├── .gitignore
├── README.md
└── src
    ├── models
    │   ├── TwoTower.py
    │   ├── OneTowerBert.py
    │   ├── FIM.py
    │   ├── modules
    │   │   ├── attention.py
    │   │   └── encoder.py
    │   └── BaseModel.py
    ├── main
    │   ├── daemon.py
    │   ├── onetower.py
    │   ├── fim.py
    │   └── twotower.py
    ├── dev.ipynb
    └── utils
        ├── util.py
        ├── dataset.py
        └── manager.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | cache
3 | ckpts
4 | results
5 | email.py
6 | *.log
7 | __pycache__
--------------------------------------------------------------------------------
/src/models/TwoTower.py:
--------------------------------------------------------------------------------
 1 | from .BaseModel import TwoTowerBaseModel
 2 | 
 3 | 
 4 | 
 5 | class TwoTowerModel(TwoTowerBaseModel):
 6 |     def __init__(self, manager, newsEncoder, userEncoder):
 7 |         super().__init__(manager, name="-".join(["TwoTower", newsEncoder.name, userEncoder.name]))
 8 |         self.newsEncoder = newsEncoder
 9 |         self.userEncoder = userEncoder
10 | 
11 | 
--------------------------------------------------------------------------------
/src/main/daemon.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import logging
 3 | import argparse
 4 | logging.basicConfig(level=logging.INFO,
 5 |                     format="[%(asctime)s] %(levelname)s (%(name)s) %(message)s")
 6 | logger = logging.getLogger(__file__)
 7 | 
 8 | 
 9 | parser = argparse.ArgumentParser()
10 | parser.add_argument("-d", "--device", dest="device",
11 |                     help="device to run on, -1 means cpu", choices=list(range(-1, 10)), type=int, default=0)
12 | args = parser.parse_args()
13 | 
14 | 
15 | device = "cpu" if args.device == -1 else f"cuda:{args.device}"
16 | logger.info("I'm running on {} to stop the platform killing this job!".format(device))
17 | a = torch.zeros(1, device=device)
18 | while True:
19 |     if a.item() > 2:
20 |         a -= 1
21 |     else:
22 |         a += 1
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | ### Download
 2 | There are two branches in this repo. The `backup` branch is deprecated and kept only as a backup; clone the `main` branch alone with:
 3 | ```
 4 | git clone -b main --single-branch https://github.com/namespace-Pt/News-Recommendation.git
 5 | ```
 6 | 
 7 | ### Instructions
 8 | 1. Download the **MIND** dataset [here](https://msnews.github.io/)
 9 | 2. Save the MIND dataset in a directory, e.g. `~/Data`
10 | 3. Install PyTorch:
11 |    ```
12 |    pip install torch==1.9.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
13 |    ```
14 | 4. Add your email information so that you get an email every time the model is evaluated:
15 |    ```bash
16 |    cd src
17 |    mkdir data
18 |    echo "email = 'your.gmail@gmail.com'" >> data/email.py
19 |    echo "password = 'your password'" >> data/email.py
20 |    ```
21 | 5. Train a model, e.g. a two-tower model with a BERT news encoder on 2 GPUs:
22 |    ```bash
23 |    python -m main.twotower --data-root ~/Data \
24 |        --batch-size 8 \
25 |        --world-size 2 \
26 |        --news-encoder bert
27 |    ```
28 |    - `--world-size` sets the number of GPUs used in DDP training
29 |    - `--news-encoder` selects the news encoder
30 |    - more parameters can be found in [src/utils/manager.py](src/utils/manager.py)
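31 | 6. To evaluate a saved checkpoint instead of training, run the same entry point in `dev`/`test` mode. The flag names below are assumptions for illustration — see [src/utils/manager.py](src/utils/manager.py) for the actual argument names:
32 |    ```bash
33 |    python -m main.twotower --mode dev --checkpoint <checkpoint-name>
34 |    ```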
--------------------------------------------------------------------------------
/src/main/onetower.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.multiprocessing as mp
 3 | from utils.manager import Manager
 4 | from models.OneTowerBert import OneTowerBert
 5 | from torch.nn.parallel import DistributedDataParallel as DDP
 6 | 
 7 | 
 8 | def main(rank, manager):
 9 |     """ train/dev/test the model (in distributed mode)
10 | 
11 |     Args:
12 |         rank: current process id
13 |         manager: the Manager object holding hyper parameters
14 |     """
15 |     manager.setup(rank)
16 |     loaders = manager.prepare()
17 | 
18 |     model = OneTowerBert(manager).to(manager.device)
19 | 
20 |     if manager.mode == 'train':
21 |         if manager.world_size > 1:
22 |             model = DDP(model, device_ids=[manager.device], output_device=manager.device)
23 |         manager.train(model, loaders)
24 | 
25 |     elif manager.mode == 'dev':
26 |         manager.load(model)
27 |         model.dev(manager, loaders, log=True)
28 | 
29 |     elif manager.mode == 'test':
30 |         manager.load(model)
31 |         model.test(manager, loaders)
32 | 
33 | 
34 | if __name__ == "__main__":
35 |     config = {
36 |         "batch_size_eval": 100,
37 |         "enable_fields": ["title"],
38 |         "validate_step": "0.5e",
39 |     }
40 |     manager = Manager(config)
41 | 
42 |     if manager.world_size > 1:
43 |         mp.spawn(
44 |             main,
45 |             args=(manager,),
46 |             nprocs=manager.world_size,
47 |             join=True
48 |         )
49 |     else:
50 |         main(manager.device, manager)
--------------------------------------------------------------------------------
/src/main/fim.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.multiprocessing as mp
 3 | from utils.manager import Manager
 4 | from models.FIM import FIM
 5 | from torch.nn.parallel import DistributedDataParallel as DDP
 6 | 
 7 | 
 8 | def main(rank, manager):
 9 |     """ train/dev/test the model (in distributed mode)
10 | 
11 |     Args:
12 |         rank: current process id
13 |         manager: the Manager object holding hyper parameters
14 |     """
15 |     manager.setup(rank)
16 |     loaders = manager.prepare()
17 | 
18 |     model = FIM(manager).to(manager.device)
19 | 
20 |     if manager.mode == 'train':
21 |         if manager.world_size > 1:
22 |             model = DDP(model, device_ids=[manager.device], output_device=manager.device)
23 |         manager.train(model, loaders)
24 | 
25 |     elif manager.mode == 'dev':
26 |         manager.load(model)
27 |         model.dev(manager, loaders, log=True)
28 | 
29 |     elif manager.mode == 'test':
30 |         manager.load(model)
31 |         model.test(manager, loaders)
32 | 
33 | 
34 | if __name__ == "__main__":
35 |     config = {
36 |         "batch_size": 100,
37 |         "batch_size_eval": 100,
38 |         "enable_fields": ["title"],
39 |         "hidden_dim": 150,
40 |         "learning_rate": 1e-5,
41 |         "validate_step": "0.5e",
42 |     }
43 |     manager = Manager(config)
44 | 
45 |     # essential to set this to False so cuDNN can pick faster kernels for the dilated cnn
46 |     torch.backends.cudnn.deterministic = False
47 | 
48 |     if manager.world_size > 1:
49 |         mp.spawn(
50 |             main,
51 |             args=(manager,),
52 |             nprocs=manager.world_size,
53 |             join=True
54 |         )
55 |     else:
56 |         main(manager.device, manager)
--------------------------------------------------------------------------------
/src/models/OneTowerBert.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | from .BaseModel import OneTowerBaseModel
 4 | from .modules.encoder import BertCrossEncoder, TFMCrossEncoder
 5 | 
 6 | 
 7 | 
 8 | class OneTowerBert(OneTowerBaseModel):
 9 |     def __init__(self, manager):
10 |         super().__init__(manager)
11 |         self.encoder = TFMCrossEncoder(manager)
12 | 
13 |         self.pooler = nn.Linear(manager.plm_dim, 1)
14 |         self.aggregator = nn.Linear(self.his_size, 1)
15 |         nn.init.xavier_normal_(self.pooler.weight)
16 |         nn.init.xavier_normal_(self.aggregator.weight)
17 | 
18 | 
19 |     def infer(self, x):
20 |         cdd_token_id = x["cdd_token_id"].to(self.device)    # B, C, L
21 |         his_token_id = x["his_token_id"].to(self.device)    # B, N, L
22 |         cdd_attn_mask = x["cdd_attn_mask"].to(self.device)    # B, C, L
23 |         his_attn_mask = x["his_attn_mask"].to(self.device)    # B, N, L
24 | 
25 |         B, C, L, N = *cdd_token_id.shape, self.his_size
26 |         cdd_token_id = cdd_token_id.unsqueeze(-2).expand(B, C, N, L)
27 |         his_token_id = his_token_id.unsqueeze(1).expand(B, C, N, L)
28 |         cdd_attn_mask = cdd_attn_mask.unsqueeze(-2).expand(B, C, N, L)
29 |         his_attn_mask = his_attn_mask.unsqueeze(1).expand(B, C, N, L)
30 | 
31 |         concat_token_id = torch.cat([cdd_token_id, his_token_id], dim=-1)
32 |         concat_attn_mask = torch.cat([cdd_attn_mask, his_attn_mask], dim=-1)
33 | 
34 |         news_embedding = self.encoder(concat_token_id, concat_attn_mask)    # B, C, N, D
35 | 
36 |         news_his_score = self.pooler(news_embedding).squeeze(-1)    # B, C, N
37 |         logits = self.aggregator(news_his_score).squeeze(-1)    # B, C
38 |         return logits
39 | 
40 | 
41 |     def forward(self, x):
42 |         logits = self.infer(x)
43 |         labels = x["label"].to(self.device)
44 |         loss = self.crossEntropy(logits, labels)
45 |         return loss
--------------------------------------------------------------------------------
/src/main/twotower.py:
--------------------------------------------------------------------------------
 1 | import torch.multiprocessing as mp
 2 | from utils.manager import Manager
 3 | from models.TwoTower import TwoTowerModel
 4 | from torch.nn.parallel import DistributedDataParallel as DDP
 5 | from models.modules.encoder import *
 6 | 
 7 | 
 8 | def main(rank, manager):
 9 |     """ train/dev/test the model (in distributed mode)
10 | 
11 |     Args:
12 |         rank: current process id
13 |         manager: the Manager object holding hyper parameters
14 |     """
15 |     manager.setup(rank)
16 |     loaders = manager.prepare()
17 | 
18 |     if manager.newsEncoder == "cnn":
19 |         newsEncoder = CnnNewsEncoder(manager)
20 |     elif manager.newsEncoder == "bert":
21 |         newsEncoder = AllBertNewsEncoder(manager)
22 |     elif manager.newsEncoder == "tfm":
23 |         newsEncoder = TfmNewsEncoder(manager)
24 |     if manager.userEncoder == "rnn":
25 |         userEncoder = RnnUserEncoder(manager)
26 |     elif manager.userEncoder == "sum":
27 |         userEncoder = SumUserEncoder(manager)
28 |     elif manager.userEncoder == "avg":
29 |         userEncoder = AvgUserEncoder(manager)
30 |     elif manager.userEncoder == "attn":
31 |         userEncoder = AttnUserEncoder(manager)
32 |     elif manager.userEncoder == "tfm":
33 |         userEncoder = TfmUserEncoder(manager)
34 | 
35 |     model = TwoTowerModel(manager, newsEncoder, userEncoder).to(manager.device)
36 | 
37 |     if manager.mode == 'train':
38 |         if manager.world_size > 1:
39 |             model = DDP(model, device_ids=[manager.device], output_device=manager.device)
40 |         manager.train(model, loaders)
41 | 
42 |     elif manager.mode == 'dev':
43 |         manager.load(model)
44 |         model.dev(manager, loaders, log=True)
45 | 
46 | 
47 | if __name__ == "__main__":
48 |     config = {
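        # keys here override Manager's defaults (an assumption based on how
        # Manager(config) is used across the entry points; see utils/manager.py)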
"enable_fields": ["title"], 50 | "newsEncoder": "cnn", 51 | "userEncoder": "rnn", 52 | } 53 | manager = Manager(config) 54 | 55 | if manager.world_size > 1: 56 | mp.spawn( 57 | main, 58 | args=(manager,), 59 | nprocs=manager.world_size, 60 | join=True 61 | ) 62 | else: 63 | main(manager.device, manager) -------------------------------------------------------------------------------- /src/models/FIM.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .BaseModel import OneTowerBaseModel 3 | from .modules.encoder import HDCNNNewsEncoder 4 | 5 | 6 | 7 | class FIM(OneTowerBaseModel): 8 | def __init__(self, manager): 9 | super().__init__(manager) 10 | self.encoder = HDCNNNewsEncoder(manager) 11 | 12 | self.seqConv3D = nn.Sequential( 13 | nn.Conv3d(in_channels=self.encoder.level, out_channels=32, kernel_size=[3, 3, 3], padding=1), 14 | nn.ReLU(), 15 | nn.MaxPool3d(kernel_size=[3, 3, 3], stride=[3, 3, 3]), 16 | nn.Conv3d(in_channels=32, out_channels=16, kernel_size=[3, 3, 3], padding=1), 17 | nn.ReLU(), 18 | nn.MaxPool3d(kernel_size=[3, 3, 3], stride=[3, 3, 3]) 19 | ) 20 | nn.init.xavier_normal_(self.seqConv3D[0].weight) 21 | nn.init.xavier_normal_(self.seqConv3D[3].weight) 22 | 23 | 24 | final_dim = (self.his_size // 3 // 3) * (self.sequence_length // 3 // 3) ** 2 * 16 25 | self.pooler = nn.Linear(final_dim, 1) 26 | nn.init.xavier_normal_(self.pooler.weight) 27 | 28 | 29 | def infer(self, x): 30 | cdd_token_id = x["cdd_token_id"].to(self.device) 31 | cdd_token_embedding, _ = self.encoder(cdd_token_id) # B, C, V, L, D 32 | 33 | his_token_id = x["his_token_id"].to(self.device) 34 | his_token_embedding, _ = self.encoder(his_token_id) # B, N, V, L, D 35 | 36 | cdd_token_embedding = cdd_token_embedding.unsqueeze(2) 37 | his_token_embedding = his_token_embedding.unsqueeze(1) 38 | 39 | matching = cdd_token_embedding.matmul(his_token_embedding.transpose(-1, -2)) # B, C, N, V, L, L 40 | B, C, N, V, L = matching.shape[:-1] 41 | cnn_input = matching.view(-1, N, V, L, L).transpose(1, 2) # B*C, V, N, L, L 42 | cnn_output = self.seqConv3D(cnn_input).view(B, C, -1) # B*C, x 43 | 44 | logits = self.pooler(cnn_output).squeeze(-1) 45 | return logits 46 | 47 | 48 | def forward(self,x): 49 | logits = self.infer(x) 50 | labels = x["label"].to(self.device) 51 | loss = self.crossEntropy(logits, labels) 52 | return loss -------------------------------------------------------------------------------- /src/models/modules/attention.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import _softmax_backward_data, nn 4 | 5 | 6 | def scaled_dp_attention(query, key, value, attn_mask=None, return_prob=False): 7 | """ calculate scaled attended output of values 8 | Args: 9 | query: tensor of [] 10 | key: tensor of [batch_size, *, key_num, key_dim] 11 | value: tensor of [batch_size, *, key_num, value_dim] 12 | attn_mask: tensor of [batch_size, *, query_num, key_num] 13 | Returns: 14 | attn_output: tensor of [batch_size, *, query_num, value_dim] 15 | """ 16 | 17 | # make sure dimension matches 18 | assert query.shape[-1] == key.shape[-1] 19 | key = key.transpose(-2, -1) 20 | 21 | attn_score = torch.matmul(query, key)/math.sqrt(query.shape[-1]) 22 | 23 | if attn_mask is not None: 24 | attn_mask = (1 - attn_mask) * -1e5 25 | attn_prob = torch.softmax(attn_score + attn_mask, -1) 26 | else: 27 | attn_prob = torch.softmax(attn_score, -1) 28 | 29 | attn_output = torch.matmul(attn_prob, 
16 | 
17 |     # make sure dimension matches
18 |     assert query.shape[-1] == key.shape[-1]
19 |     key = key.transpose(-2, -1)
20 | 
21 |     attn_score = torch.matmul(query, key) / math.sqrt(query.shape[-1])
22 | 
23 |     if attn_mask is not None:
24 |         attn_mask = (1 - attn_mask) * -1e5
25 |         attn_prob = torch.softmax(attn_score + attn_mask, -1)
26 |     else:
27 |         attn_prob = torch.softmax(attn_score, -1)
28 | 
29 |     attn_output = torch.matmul(attn_prob, value)
30 | 
31 |     if return_prob:
32 |         return attn_output, attn_prob
33 |     else:
34 |         return attn_output
35 | 
36 | 
37 | def extend_attention_mask(encoder_attention_mask):
38 |     """
39 |     Args:
40 |         encoder_attention_mask (`torch.Tensor`): An attention mask (1 for valid positions, 0 for padding).
41 |     Returns:
42 |         `torch.Tensor`: The extended additive attention mask, with large negative values at masked positions.
43 |     """
44 |     if encoder_attention_mask.dim() == 3:
45 |         encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
46 |     elif encoder_attention_mask.dim() == 2:
47 |         encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
48 | 
49 |     encoder_extended_attention_mask = (1 - encoder_extended_attention_mask) * -1e5
50 | 
51 |     return encoder_extended_attention_mask
52 | 
53 | 
54 | 
55 | class TFMSelfAttention(nn.Module):
56 |     def __init__(self, hidden_dim, head_num, dropout_p):
57 |         super().__init__()
58 | 
59 |         self.num_attention_heads = head_num
60 |         self.attention_head_size = int(hidden_dim / head_num)
61 |         self.all_head_size = self.num_attention_heads * self.attention_head_size
62 |         if self.all_head_size != hidden_dim:
63 |             warnings.warn(f"Truncating given hidden dim {hidden_dim} to {self.all_head_size} so that it can be divided by head num {head_num}")
64 | 
65 |         self.query = nn.Linear(hidden_dim, self.all_head_size)
66 |         self.key = nn.Linear(hidden_dim, self.all_head_size)
67 |         self.value = nn.Linear(hidden_dim, self.all_head_size)
68 | 
69 |         self.dropout = nn.Dropout(dropout_p)
70 | 
71 |     def transpose_for_scores(self, x):
72 |         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
73 |         x = x.view(*new_x_shape)
74 |         return x.permute(0, 2, 1, 3)
75 | 
76 |     def forward(
77 |         self,
78 |         hidden_states,
79 |         attention_mask=None,
80 |     ):
81 |         # project hidden states into multi-head query, key and value
82 |         query_layer = self.transpose_for_scores(self.query(hidden_states))
83 |         key_layer = self.transpose_for_scores(self.key(hidden_states))
84 |         value_layer = self.transpose_for_scores(self.value(hidden_states))
85 | 
86 |         # Take the dot product between "query" and "key" to get the raw attention scores.
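        # shape note (B = batch, H = head num, L = sequence length, d = head size):
        # query/key/value_layer are [B, H, L, d], so attention_scores below is [B, H, L, L];
        # attention_mask is the additive mask produced by extend_attention_mask, and
        # despite the None default this forward expects it to be provided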
87 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 88 | 89 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 90 | attention_probs = torch.softmax(attention_scores + attention_mask, dim=-1) 91 | 92 | attention_probs = self.dropout(attention_probs) 93 | 94 | context_layer = torch.matmul(attention_probs, value_layer) 95 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 96 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 97 | context_layer = context_layer.view(*new_context_layer_shape) 98 | 99 | return context_layer 100 | 101 | 102 | class TFMSelfOutput(nn.Module): 103 | def __init__(self, hidden_dim, dropout_p): 104 | super().__init__() 105 | self.dense = nn.Linear(hidden_dim, hidden_dim) 106 | self.LayerNorm = nn.LayerNorm(hidden_dim, eps=1e-12) 107 | self.dropout = nn.Dropout(dropout_p) 108 | 109 | def forward(self, hidden_states, input_tensor): 110 | hidden_states = self.dense(hidden_states) 111 | hidden_states = self.dropout(hidden_states) 112 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 113 | return hidden_states 114 | 115 | 116 | class TFMAttention(nn.Module): 117 | def __init__(self, hidden_dim, head_num, dropout_p): 118 | super().__init__() 119 | self.self = TFMSelfAttention(hidden_dim=hidden_dim, head_num=head_num, dropout_p=dropout_p) 120 | self.output = TFMSelfOutput(hidden_dim=hidden_dim, dropout_p=dropout_p) 121 | 122 | def forward( 123 | self, 124 | hidden_states, 125 | attention_mask=None, 126 | ): 127 | self_outputs = self.self( 128 | hidden_states, 129 | attention_mask 130 | ) 131 | attention_output = self.output(self_outputs, hidden_states) 132 | return attention_output 133 | 134 | 135 | class TFMIntermediate(nn.Module): 136 | def __init__(self, hidden_dim): 137 | super().__init__() 138 | self.dense = nn.Linear(hidden_dim, 4 * hidden_dim) 139 | self.intermediate_act_fn = nn.functional.gelu 140 | 141 | def forward(self, hidden_states): 142 | hidden_states = self.dense(hidden_states) 143 | hidden_states = self.intermediate_act_fn(hidden_states) 144 | return hidden_states 145 | 146 | 147 | class TFMOutput(nn.Module): 148 | def __init__(self, hidden_dim, dropout_p): 149 | super().__init__() 150 | self.dense = nn.Linear(4 * hidden_dim, hidden_dim) 151 | self.LayerNorm = nn.LayerNorm(hidden_dim, eps=1e-12) 152 | self.dropout = nn.Dropout(dropout_p) 153 | 154 | def forward(self, hidden_states, input_tensor): 155 | hidden_states = self.dense(hidden_states) 156 | hidden_states = self.dropout(hidden_states) 157 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 158 | return hidden_states 159 | 160 | 161 | class TFMLayer(nn.Module): 162 | def __init__(self, hidden_dim, head_num, dropout_p): 163 | """ 164 | hidden_dim: transformer model dimension 165 | head_num: number of self attention heads 166 | dropout_p: dropout probability 167 | """ 168 | super().__init__() 169 | self.attention = TFMAttention(hidden_dim=hidden_dim, head_num=head_num, dropout_p=dropout_p) 170 | self.intermediate = TFMIntermediate(hidden_dim=hidden_dim) 171 | self.output = TFMOutput(hidden_dim=hidden_dim, dropout_p=dropout_p) 172 | 173 | def forward( 174 | self, 175 | hidden_states, 176 | attention_mask=None, 177 | ): 178 | attention_mask = extend_attention_mask(attention_mask) 179 | attention_output = self.attention(hidden_states, attention_mask) 180 | intermediate_output = self.intermediate(attention_output) 181 | layer_output = self.output(intermediate_output, attention_output) 182 | 
return layer_output 183 | -------------------------------------------------------------------------------- /src/dev.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os \n", 10 | "import torch\n", 11 | "import numpy as np\n", 12 | "import pandas as pd\n", 13 | "import torch.nn as nn \n", 14 | "import torch.nn.functional as F\n", 15 | "from transformers import AutoTokenizer, AutoModel\n", 16 | "from utils.manager import Manager\n", 17 | "from utils.util import load_pickle, save_pickle, BM25" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 2, 23 | "metadata": {}, 24 | "outputs": [ 25 | { 26 | "name": "stderr", 27 | "output_type": "stream", 28 | "text": [ 29 | "[2022-03-30 10:48:15,077] INFO (Manager) Hyper Parameters are:\n", 30 | "{'scale': 'demo', 'batch_size': 2, 'batch_size_eval': 2, 'checkpoint': 'none', 'verbose': None, 'his_size': 50, 'impr_size': 20, 'negative_num': 4, 'dropout_p': 0.1, 'learning_rate': 1e-05, 'scheduler': 'none', 'warmup': 0.1, 'title_length': 32, 'abs_length': 64, 'enable_fields': ['title', 'abs'], 'newsEncoder': 'cnn', 'userEncoder': 'rnn', 'hidden_dim': 768, 'head_num': 12, 'k': 4, 'plm': 'distilbert', 'seed': 3407, 'world_size': 1, 'sequence_length': 96}\n", 31 | "\n", 32 | "[2022-03-30 10:48:15,081] INFO (MIND_Train) Loading Cache at MINDdemo_train\n", 33 | "[2022-03-30 10:48:15,981] INFO (MIND_Dev) Loading Cache at MINDdemo_dev\n", 34 | "[2022-03-30 10:48:16,779] INFO (MIND_News) Loading Cache at MINDdemo_dev\n" 35 | ] 36 | } 37 | ], 38 | "source": [ 39 | "command = \"\"\"\n", 40 | "-bs 2 -bse 2 -ef title abs -s demo -plm distilbert\n", 41 | "\"\"\"\n", 42 | "manager = Manager(command=command.strip().split(\" \"))\n", 43 | "loaders = manager.prepare()" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 4, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "t = AutoTokenizer.from_pretrained(manager.plm_dir)\n", 53 | "m = AutoModel.from_pretrained(manager.plm_dir)" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 3, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "loader_train = loaders[\"train\"]\n", 63 | "loader_dev = loaders[\"dev\"]\n", 64 | "loader_news = loaders[\"news\"]\n", 65 | "\n", 66 | "dataset_train = loader_train.dataset\n", 67 | "dataset_dev = loader_dev.dataset\n", 68 | "dataset_news = loader_news.dataset\n", 69 | "\n", 70 | "X1 = iter(loader_train)\n", 71 | "X2 = iter(loader_dev)\n", 72 | "X3 = iter(loader_news)\n", 73 | "x = next(X1)\n", 74 | "x2 = next(X2)\n", 75 | "x3 = next(X3)" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 6, 81 | "metadata": {}, 82 | "outputs": [ 83 | { 84 | "data": { 85 | "text/plain": [ 86 | "\"the brands queen elizabeth, prince charles, and prince philip swear by shop the notebooks, jackets, and more that the royals can't live without.\"" 87 | ] 88 | }, 89 | "execution_count": 6, 90 | "metadata": {}, 91 | "output_type": "execute_result" 92 | } 93 | ], 94 | "source": [ 95 | "# check news\n", 96 | "index = 1\n", 97 | "cdd_token_id = x3['cdd_token_id'][index]\n", 98 | "t.decode(cdd_token_id, skip_special_tokens=True)" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 11, 104 | "metadata": {}, 105 | "outputs": [ 106 | { 107 | "ename": "RuntimeError", 108 | "evalue": "The expanded size of the tensor 
(3) must match the existing size (2) at non-singleton dimension 1. Target sizes: [2, 3, 3]. Tensor sizes: [2, 3]", 109 | "output_type": "error", 110 | "traceback": [ 111 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 112 | "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", 113 | "Input \u001b[0;32mIn [11]\u001b[0m, in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m a \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mrand(\u001b[39m2\u001b[39m,\u001b[39m3\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m a\u001b[39m.\u001b[39;49mexpand(\u001b[39m2\u001b[39;49m,\u001b[39m3\u001b[39;49m,\u001b[39m3\u001b[39;49m)\n", 114 | "\u001b[0;31mRuntimeError\u001b[0m: The expanded size of the tensor (3) must match the existing size (2) at non-singleton dimension 1. Target sizes: [2, 3, 3]. Tensor sizes: [2, 3]" 115 | ] 116 | } 117 | ], 118 | "source": [ 119 | "a = torch.rand(2,3)\n", 120 | "a.expand(2,3,3)" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 10, 126 | "metadata": {}, 127 | "outputs": [ 128 | { 129 | "name": "stdout", 130 | "output_type": "stream", 131 | "text": [ 132 | " a \n", 133 | "[CLS] 1\n", 134 | "what 1\n", 135 | "you 1\n", 136 | "need 1\n", 137 | "to 1\n", 138 | "know 1\n", 139 | "about 1\n", 140 | "the 1\n", 141 | "c 1\n", 142 | "##8 1\n", 143 | "corvette 1\n", 144 | "' 1\n", 145 | "s 1\n", 146 | "new 1\n", 147 | "dual 1\n", 148 | "- 1\n", 149 | "clutch 1\n", 150 | "transmission 1\n", 151 | "[SEP] 1\n", 152 | "the 1\n", 153 | "new 1\n", 154 | "corvette 1\n", 155 | "has 1\n", 156 | "an 1\n", 157 | "eight 1\n", 158 | "- 1\n", 159 | "speed 1\n", 160 | "tre 1\n", 161 | "##me 1\n", 162 | "##c 1\n", 163 | "dc 1\n", 164 | "##t 1\n", 165 | ". 1\n", 166 | "we 1\n", 167 | "weren 1\n", 168 | "' 1\n", 169 | "t 1\n", 170 | "crazy 1\n", 171 | "about 1\n", 172 | "it 1\n", 173 | "in 1\n", 174 | "the 1\n", 175 | "pre 1\n", 176 | "- 1\n", 177 | "production 1\n", 178 | "c 1\n", 179 | "##8 1\n", 180 | "we 1\n", 181 | "drove 1\n", 182 | ", 1\n", 183 | "but 1\n", 184 | "engineers 1\n", 185 | "tell 1\n", 186 | "us 1\n", 187 | "the 1\n", 188 | "final 1\n", 189 | "version 1\n", 190 | "will 1\n", 191 | "be 1\n", 192 | "better 1\n", 193 | ". 
1\n", 194 | "[SEP] 1\n", 195 | "[PAD] 0\n" 196 | ] 197 | } 198 | ], 199 | "source": [ 200 | "# check attention mask\n", 201 | "index = (0, 0)\n", 202 | "cdd_token_id = x['cdd_token_id'][index]\n", 203 | "cdd_attn_mask = x[\"cdd_attn_mask\"][index]\n", 204 | "his_token_id = x[\"his_token_id\"][index]\n", 205 | "his_attn_mask = x[\"his_attn_mask\"][index]\n", 206 | "\n", 207 | "cdd_token = t.convert_ids_to_tokens(cdd_token_id)\n", 208 | "his_token = t.convert_ids_to_tokens(his_token_id)\n", 209 | "\n", 210 | "line = \"{:20} a \".format(\" \"*20)\n", 211 | "print(line)\n", 212 | "for i in range(manager.sequence_length):\n", 213 | " line = \"{:20} {}\".format(cdd_token[i], cdd_attn_mask[i])\n", 214 | " print(line)\n", 215 | " if cdd_token[i] == \"[PAD]\":\n", 216 | " break" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": null, 222 | "metadata": {}, 223 | "outputs": [], 224 | "source": [ 225 | "# check train loader result\n", 226 | "nid2index = load_pickle(\"/data/v-pezhang/Code/GateFormer/src/data/cache/MIND/MINDdemo_train/news/nid2index.pkl\")\n", 227 | "uid2index = load_pickle(\"/data/v-pezhang/Code/GateFormer/src/data/cache/MIND/uid2index.pkl\")\n", 228 | "nindex2id = {v:k for k,v in nid2index.items()}\n", 229 | "uindex2id = {v:k for k,v in uid2index.items()}\n", 230 | "\n", 231 | "# check behaviors.tsv\n", 232 | "print([uindex2id[i] for i in x[\"user_index\"].tolist()], (x[\"impr_index\"] + 1).tolist())\n", 233 | "# check news.tsv\n", 234 | "print([nindex2id[i] for i in x[\"cdd_idx\"][0][:5].tolist()])\n", 235 | "print(t.batch_decode(x[\"cdd_token_id\"][0][:5], skip_special_tokens=True))\n", 236 | "\n", 237 | "print([nindex2id[i] for i in x[\"his_idx\"][0][:5].tolist()])\n", 238 | "print(t.batch_decode(x[\"his_token_id\"][0][:5], skip_special_tokens=True))" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": null, 244 | "metadata": {}, 245 | "outputs": [], 246 | "source": [ 247 | "# check dev loader result\n", 248 | "nid2index = load_pickle(\"/data/v-pezhang/Code/GateFormer/src/data/cache/MIND/MINDdemo_dev/news/nid2index.pkl\")\n", 249 | "uid2index = load_pickle(\"/data/v-pezhang/Code/GateFormer/src/data/cache/MIND/uid2index.pkl\")\n", 250 | "nindex2id = {v:k for k,v in nid2index.items()}\n", 251 | "uindex2id = {v:k for k,v in uid2index.items()}\n", 252 | "\n", 253 | "# check behaviors.tsv\n", 254 | "print([uindex2id[i] for i in x2[\"user_index\"].tolist()], (x2[\"impr_index\"] + 1).tolist())\n", 255 | "# check news.tsv\n", 256 | "print([nindex2id[i] for i in x2[\"cdd_idx\"][0][:5].tolist()])\n", 257 | "print(t.batch_decode(x2[\"cdd_token_id\"][0][:5], skip_special_tokens=True))\n", 258 | "\n", 259 | "print([nindex2id[i] for i in x2[\"his_idx\"][0][:5].tolist()])\n", 260 | "print(t.batch_decode(x2[\"his_token_id\"][0][:5], skip_special_tokens=True))" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": null, 266 | "metadata": {}, 267 | "outputs": [], 268 | "source": [] 269 | } 270 | ], 271 | "metadata": { 272 | "interpreter": { 273 | "hash": "a256a4def3bbb1bd6a1d46703c4995443a919758d62b261face579c969ba8076" 274 | }, 275 | "kernelspec": { 276 | "display_name": "Python 3.9.7 64-bit ('nn': conda)", 277 | "language": "python", 278 | "name": "python3" 279 | }, 280 | "language_info": { 281 | "codemirror_mode": { 282 | "name": "ipython", 283 | "version": 3 284 | }, 285 | "file_extension": ".py", 286 | "mimetype": "text/x-python", 287 | "name": "python", 288 | "nbconvert_exporter": "python", 289 | "pygments_lexer": 
"ipython3", 290 | "version": "3.9.7" 291 | }, 292 | "orig_nbformat": 4 293 | }, 294 | "nbformat": 4, 295 | "nbformat_minor": 2 296 | } 297 | -------------------------------------------------------------------------------- /src/utils/util.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import math 4 | import pickle 5 | import logging 6 | import pandas as pd 7 | import numpy as np 8 | from tqdm import tqdm 9 | from random import sample 10 | from collections import defaultdict 11 | from transformers import AutoModel, AutoTokenizer 12 | from sklearn.metrics import roc_auc_score, log_loss, mean_squared_error, accuracy_score, f1_score 13 | 14 | 15 | 16 | def load_pickle(path): 17 | """ load pickle file 18 | """ 19 | with open(path, "rb") as f: 20 | return pickle.load(f) 21 | 22 | 23 | def save_pickle(obj, path): 24 | with open(path, "wb") as f: 25 | pickle.dump(obj, f) 26 | 27 | 28 | def download_plm(plm_full_name, dir): 29 | # initialize bert related parameters 30 | os.makedirs(dir, exist_ok=True) 31 | tokenizer = AutoTokenizer.from_pretrained(plm_full_name) 32 | model = AutoModel.from_pretrained(plm_full_name) 33 | tokenizer.save_pretrained(dir) 34 | model.save_pretrained(dir) 35 | 36 | 37 | def pack_results(impr_indices, masks, *associated_lists): 38 | """ 39 | group lists by impr_index 40 | Args: 41 | associated_lists: list of lists, where list[i] is associated with the impr_indices[i] 42 | 43 | Returns: 44 | Iterable: grouped labels (if inputted) and preds 45 | """ 46 | list_num = len(associated_lists) 47 | dicts = [defaultdict(list) for i in range(list_num)] 48 | 49 | for x in tqdm(zip(impr_indices, masks, *associated_lists), total=len(impr_indices), desc="Packing Results", ncols=80): 50 | key = x[0] 51 | mask = x[1] 52 | values = x[2:] 53 | for i in range(list_num): 54 | dicts[i][key].extend(values[i][mask].tolist()) 55 | 56 | grouped_lists = [list(d.values()) for d in dicts] 57 | return grouped_lists 58 | 59 | 60 | def sample_news(news, k): 61 | """ Sample ratio samples from news list. 62 | If length of news is less than ratio, pad zeros. 63 | 64 | Args: 65 | news (list): input news list 66 | ratio (int): sample number 67 | 68 | Returns: 69 | list: output of sample list. 
 77 | 
 78 | 
 79 | def tokenize(sent):
 80 |     """ Split sentence into words
 81 |     Args:
 82 |         sent (str): Input sentence
 83 | 
 84 |     Return:
 85 |         list: word list
 86 |     """
 87 |     pat = re.compile(r"[-\w_]+|[.,!?;|]")
 88 | 
 89 |     return [x for x in pat.findall(sent.lower())]
 90 | 
 91 | 
 92 | def construct_nid2index(news_path, cache_dir):
 93 |     """
 94 |     Construct the news ID to news INDEX dictionary, index starting from 1
 95 |     """
 96 |     news_df = pd.read_table(news_path, index_col=None, names=[
 97 |         "newsID", "category", "subcategory", "title", "abstract", "url", "entity_title", "entity_abstract"], quoting=3)
 98 | 
 99 |     nid2index = {}
100 |     for v in news_df["newsID"]:
101 |         if v in nid2index:
102 |             continue
103 |         # plus one because news indices start from 1
104 |         nid2index[v] = len(nid2index) + 1
105 |     save_pickle(nid2index, os.path.join(cache_dir, "nid2index.pkl"))
106 | 
107 | 
108 | def construct_uid2index(data_root, cache_root):
109 |     """
110 |     Construct the user ID to user INDEX dictionary, index starting from 0
111 |     """
112 |     uid2index = {}
113 |     user_df_list = []
114 |     behaviors_file_list = [os.path.join(data_root, "MIND", directory, "behaviors.tsv") for directory in ["MINDlarge_train", "MINDlarge_dev", "MINDlarge_test"]]
115 | 
116 |     for f in behaviors_file_list:
117 |         user_df_list.append(pd.read_table(f, index_col=None, names=[
118 |             "imprID", "uid", "time", "history", "impression"], quoting=3)["uid"])
119 |     user_df = pd.concat(user_df_list).drop_duplicates()
120 |     for v in user_df:
121 |         uid2index[v] = len(uid2index)
122 |     save_pickle(uid2index, os.path.join(cache_root, "MIND", "uid2index.pkl"))
123 |     return uid2index
124 | 
125 | 
126 | def mrr_score(y_true, y_score):
127 |     """Computing mrr score metric.
128 | 
129 |     Args:
130 |         y_true (np.ndarray): ground-truth labels.
131 |         y_score (np.ndarray): predicted scores.
132 | 
133 |     Returns:
134 |         float: mrr score.
135 |     """
136 |     # rank prediction scores in descending order, getting candidate news indices
137 |     order = np.argsort(y_score)[::-1]
138 |     # reorder the ground truth accordingly
139 |     y_true = np.take(y_true, order)
140 |     # compute the reciprocal rank of every clicked news
141 |     # and average over the clicked ones
142 |     rr_score = y_true / (np.arange(len(y_true)) + 1)
143 |     return np.sum(rr_score) / np.sum(y_true)
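# e.g. with y_true = [0, 1, 0, 1] and y_score = [0.1, 0.9, 0.3, 0.5], the clicked
# items rank 1st and 2nd, so mrr_score = (1/1 + 1/2) / 2 = 0.75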
144 | 
145 | 
146 | def ndcg_score(y_true, y_score, k=10):
147 |     """Computing ndcg score metric at k.
148 | 
149 |     Args:
150 |         y_true (np.ndarray): ground-truth labels.
151 |         y_score (np.ndarray): predicted scores.
152 | 
153 |     Returns:
154 |         float: ndcg score.
155 |     """
156 |     best = dcg_score(y_true, y_true, k)
157 |     actual = dcg_score(y_true, y_score, k)
158 |     return actual / best
159 | 
160 | 
161 | def hit_score(y_true, y_score, k=10):
162 |     """Computing hit score metric at k.
163 | 
164 |     Args:
165 |         y_true (np.ndarray): ground-truth labels.
166 |         y_score (np.ndarray): predicted scores.
167 | 
168 |     Returns:
169 |         int: hit score, 1 if any clicked news appears in the top k, else 0.
170 |     """
171 |     ground_truth = np.where(y_true == 1)[0]
172 |     argsort = np.argsort(y_score)[::-1][:k]
173 |     for idx in argsort:
174 |         if idx in ground_truth:
175 |             return 1
176 |     return 0
177 | 
178 | 
179 | def dcg_score(y_true, y_score, k=10):
180 |     """Computing dcg score metric at k.
181 | 
182 |     Args:
183 |         y_true (np.ndarray): ground-truth labels.
184 |         y_score (np.ndarray): predicted scores.
185 | 
186 |     Returns:
187 |         float: dcg score.
188 |     """
189 |     k = min(np.shape(y_true)[-1], k)
190 |     order = np.argsort(y_score)[::-1]
191 |     y_true = np.take(y_true, order[:k])
192 |     gains = 2 ** y_true - 1
193 |     discounts = np.log2(np.arange(len(y_true)) + 2)
194 |     return np.sum(gains / discounts)
195 | 
196 | 
197 | def compute_metrics(labels, preds, metrics):
198 |     """Calculate metrics, such as auc and logloss.
199 |     """
200 |     res = {}
201 |     for metric in metrics:
202 |         if metric == "auc":
203 |             auc = np.mean(
204 |                 [
205 |                     roc_auc_score(each_labels, each_preds)
206 |                     for each_labels, each_preds in zip(labels, preds)
207 |                 ]
208 |             )
209 |             res["auc"] = round(auc, 4)
210 |         elif metric == "rmse":
211 |             rmse = mean_squared_error(np.asarray(labels), np.asarray(preds))
212 |             res["rmse"] = round(np.sqrt(rmse), 4)
213 |         elif metric == "logloss":
214 |             # avoid logloss nan
215 |             preds = [max(min(p, 1.0 - 10e-12), 10e-12) for p in preds]
216 |             logloss = log_loss(np.asarray(labels), np.asarray(preds))
217 |             res["logloss"] = round(logloss, 4)
218 |         elif metric == "acc":
219 |             pred = np.asarray(preds)
220 |             pred[pred >= 0.5] = 1
221 |             pred[pred < 0.5] = 0
222 |             acc = accuracy_score(np.asarray(labels), pred)
223 |             res["acc"] = round(acc, 4)
224 |         elif metric == "f1":
225 |             pred = np.asarray(preds)
226 |             pred[pred >= 0.5] = 1
227 |             pred[pred < 0.5] = 0
228 |             f1 = f1_score(np.asarray(labels), pred)
229 |             res["f1"] = round(f1, 4)
230 |         elif metric == "mean_mrr":
231 |             mean_mrr = np.mean(
232 |                 [
233 |                     mrr_score(each_labels, each_preds)
234 |                     for each_labels, each_preds in zip(labels, preds)
235 |                 ]
236 |             )
237 |             res["mean_mrr"] = round(mean_mrr, 4)
238 |         elif metric.startswith("ndcg"):  # format like: ndcg@2;4;6;8
239 |             ndcg_list = [1, 2]
240 |             ks = metric.split("@")
241 |             if len(ks) > 1:
242 |                 ndcg_list = [int(token) for token in ks[1].split(";")]
243 |             for k in ndcg_list:
244 |                 ndcg_temp = np.mean(
245 |                     [
246 |                         ndcg_score(each_labels, each_preds, k)
247 |                         for each_labels, each_preds in zip(labels, preds)
248 |                     ]
249 |                 )
250 |                 res["ndcg@{0}".format(k)] = round(ndcg_temp, 4)
251 |         elif metric.startswith("hit"):  # format like: hit@2;4;6;8
252 |             hit_list = [1, 2]
253 |             ks = metric.split("@")
254 |             if len(ks) > 1:
255 |                 hit_list = [int(token) for token in ks[1].split(";")]
256 |             for k in hit_list:
257 |                 hit_temp = np.mean(
258 |                     [
259 |                         hit_score(each_labels, each_preds, k)
260 |                         for each_labels, each_preds in zip(labels, preds)
261 |                     ]
262 |                 )
263 |                 res["hit@{0}".format(k)] = round(hit_temp, 4)
264 |         else:
265 |             raise ValueError("metric {0} is not defined".format(metric))
266 |     return res
267 | 
268 | 
269 | class Sequential_Sampler:
270 |     def __init__(self, dataset_length, num_replicas, rank) -> None:
271 |         super().__init__()
272 |         len_per_worker = dataset_length / num_replicas
273 |         self.start = round(len_per_worker * rank)
274 |         self.end = round(len_per_worker * (rank + 1))
275 | 
276 |     def __iter__(self):
277 |         start = self.start
278 |         end = self.end
279 |         return iter(range(start, end, 1))
280 | 
281 |     def __len__(self):
282 |         return self.end - self.start
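# e.g. a dataset of length 10 over 3 replicas yields the contiguous, non-overlapping
# index ranges [0, 3), [3, 7) and [7, 10) for ranks 0, 1 and 2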
283 | 
284 | 
285 | 
286 | class BM25(object):
287 |     """
288 |     compute bm25 scores on the entire corpus, instead of the corpus limited by signal_length
289 |     """
290 |     def __init__(self, k=0.9, b=0.4):
291 |         self.k = k
292 |         self.b = b
293 |         self.logger = logging.getLogger("BM25")
294 | 
295 | 
296 |     def fit(self, documents):
297 |         """
298 |         build term frequencies (how many times a term occurs in one news) and document frequencies (how many documents contain a term)
299 |         """
300 |         doc_length = 0
301 |         doc_count = len(documents)
302 | 
303 |         tfs = []
304 |         df = defaultdict(int)
305 |         for document in documents:
306 |             tf = defaultdict(int)
307 |             words = tokenize(document)
308 |             for word in words:
309 |                 tf[word] += 1
310 |             # count each word only once per document for document frequency
311 |             for word in tf:
312 |                 df[word] += 1
313 |             tfs.append(tf)
314 |             doc_length += len(words)
315 | 
316 |         self.tfs = tfs
317 | 
318 |         idf = defaultdict(float)
319 |         for word, freq in df.items():
320 |             idf[word] = math.log((doc_count - freq + 0.5) / (freq + 0.5) + 1)
321 | 
322 |         self.idf = idf
323 |         self.doc_avg_length = doc_length / doc_count
324 | 
325 | 
326 |     def __call__(self, documents):
327 |         self.logger.info("computing BM25 scores...")
328 |         if not hasattr(self, "idf"):
329 |             self.fit(documents)
330 |         sorted_documents = []
331 |         for tf in self.tfs:
332 |             # document length in tokens, consistent with doc_avg_length computed in fit
333 |             doc_length = sum(tf.values())
334 |             score_pairs = []
335 |             for word, freq in tf.items():
336 |                 # skip words such as punctuation
337 |                 if len(word) == 1:
338 |                     continue
339 |                 score = (self.idf[word] * freq * (self.k + 1)) / (freq + self.k * (1 - self.b + self.b * doc_length / self.doc_avg_length))
340 |                 score_pairs.append((word, score))
341 |             score_pairs = sorted(score_pairs, key=lambda x: x[1], reverse=True)
342 |             sorted_document = " ".join([x[0] for x in score_pairs])
343 |             sorted_documents.append(sorted_document)
344 |         return sorted_documents
345 | 
--------------------------------------------------------------------------------
/src/models/BaseModel.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import math
 3 | import torch
 4 | import logging
 5 | import subprocess
 6 | import numpy as np
 7 | import torch.nn as nn
 8 | import torch.optim as optim
 9 | import torch.distributed as dist
10 | import scipy.stats as ss
11 | from tqdm import tqdm
12 | from transformers import get_linear_schedule_with_warmup
13 | from utils.util import pack_results, compute_metrics
14 | 
15 | 
16 | 
17 | class BaseModel(nn.Module):
18 |     def __init__(self, manager, name):
19 |         super().__init__()
20 | 
21 |         self.his_size = manager.his_size
22 |         self.sequence_length = manager.sequence_length
23 |         self.hidden_dim = manager.hidden_dim
24 |         self.device = manager.device
25 |         self.rank = manager.rank
26 |         self.world_size = manager.world_size
27 | 
28 |         # set all enable_xxx hyper parameters as attributes
29 |         for k, v in vars(manager).items():
30 |             if k.startswith("enable"):
31 |                 setattr(self, k, v)
32 |         self.negative_num = manager.negative_num
33 | 
34 |         if name is None:
35 |             name = type(self).__name__
36 |         if manager.verbose is not None:
37 |             self.name = "-".join([name, manager.verbose])
38 |         else:
39 |             self.name = name
40 | 
41 |         self.crossEntropy = nn.CrossEntropyLoss()
42 |         self.logger = logging.getLogger(self.name)
43 | 
44 | 
45 |     def get_optimizer(self, manager, dataloader_length):
46 |         optimizer = optim.Adam(self.parameters(), lr=manager.learning_rate)
47 | 
48 |         scheduler = None
49 |         if manager.scheduler == "linear":
50 |             total_steps = dataloader_length * manager.epochs
51 |             scheduler = get_linear_schedule_with_warmup(optimizer,
52 |                                                         num_warmup_steps=round(manager.warmup * total_steps),
53 |                                                         num_training_steps=total_steps)
54 | 
55 |         return optimizer, scheduler
56 | 
57 | 
58 |     def _gather_tensors_variable_shape(self, local_tensor):
59 |         """
60 |         gather tensors of variable shapes from all gpus
61 | 
62 |         Args:
63 |             local_tensor: the tensor that needs to be gathered
64 | 
65 |         Returns:
66 |             all_tensors: concatenation of local_tensor from each process
67 |         """
68 |         all_tensors = [None for _ in range(self.world_size)]
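        # all_gather_object pickles each rank's tensor, so the gathered copies are
        # detached from the autograd graph; re-inserting the live local tensor below
        # keeps this rank's slice differentiable and on its original device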
 69 |         dist.all_gather_object(all_tensors, local_tensor)
 70 |         all_tensors[self.rank] = local_tensor
 71 |         return torch.cat(all_tensors, dim=0)
 72 | 
 73 | 
 74 |     def _compute_gate(self, token_id, attn_mask, gate_mask, token_weight):
 75 |         """ gating by the weight of each token
 76 | 
 77 |         Returns:
 78 |             gated_token_id: [B, K]
 79 |             gated_attn_mask: [B, K]
 80 |             gated_token_weight: [B, K]
 81 |         """
 82 |         if gate_mask is not None:
 83 |             keep_k_modifier = self.keep_k_modifier * (gate_mask.sum(dim=-1, keepdim=True) < self.k)
 84 |             pad_pos = ~((gate_mask + keep_k_modifier).bool())    # B, L
 85 |             token_weight = token_weight.masked_fill(pad_pos, -float('inf'))
 86 | 
 87 |             gated_token_weight, gated_token_idx = token_weight.topk(self.k)
 88 |             gated_token_weight = torch.softmax(gated_token_weight, dim=-1)
 89 |             gated_token_id = token_id.gather(dim=-1, index=gated_token_idx)
 90 |             gated_attn_mask = attn_mask.gather(dim=-1, index=gated_token_idx)
 91 | 
 92 |         # heuristic gate: simply keep the first k tokens after [CLS]
 93 |         else:
 94 |             if token_id.dim() == 2:
 95 |                 gated_token_id = token_id[:, 1: self.k + 1]
 96 |                 gated_attn_mask = attn_mask[:, 1: self.k + 1]
 97 |             else:
 98 |                 gated_token_id = token_id[:, :, 1: self.k + 1]
 99 |                 gated_attn_mask = attn_mask[:, :, 1: self.k + 1]
100 |             gated_token_weight = None
101 | 
102 |         return gated_token_id, gated_attn_mask, gated_token_weight
103 | 
104 | 
105 |     @torch.no_grad()
106 |     def dev(self, manager, loaders, log=False):
107 |         self.eval()
108 | 
109 |         labels, preds = self._dev(manager, loaders)
110 | 
111 |         if self.rank == 0:
112 |             metrics = compute_metrics(labels, preds, manager.metrics)
113 |             metrics["main"] = metrics["auc"]
114 |             self.logger.info(metrics)
115 |             if log:
116 |                 manager._log(self.name, metrics)
117 |         else:
118 |             metrics = None
119 | 
120 |         if manager.distributed:
121 |             dist.barrier(device_ids=[self.device])
122 | 
123 |         return metrics
124 | 
125 | 
126 |     @torch.no_grad()
127 |     def test(self, manager, loaders, log=False):
128 |         self.eval()
129 | 
130 |         preds = self._test(manager, loaders)
131 | 
132 |         if manager.rank == 0:
133 |             save_dir = "data/cache/results/{}/{}/{}".format(self.name, manager.scale, os.path.split(manager.checkpoint)[-1])
134 |             os.makedirs(save_dir, exist_ok=True)
135 |             save_path = save_dir + "/prediction.txt"
136 | 
137 |             index = 1
138 |             with open(save_path, "w") as f:
139 |                 for pred in preds:
140 |                     array = np.asarray(pred)
141 |                     rank_list = ss.rankdata(1 - array, method="min")
142 |                     line = str(index) + " [" + ",".join([str(i) for i in rank_list]) + "]" + "\n"
143 |                     f.write(line)
144 |                     index += 1
145 |             try:
146 |                 subprocess.run(f"zip -j {os.path.join(save_dir, 'prediction.zip')} {save_path}", shell=True, check=True)
147 |             except subprocess.CalledProcessError:
148 |                 self.logger.warning("zip command failed or was not found! Skip zipping.")
149 |             self.logger.info("prediction written to {}!".format(save_path))
150 | 
151 |         if manager.distributed:
152 |             dist.barrier(device_ids=[self.device])
153 | 
154 | 
155 | 
156 | class TwoTowerBaseModel(BaseModel):
157 |     def __init__(self, manager, name=None):
158 |         """
159 |         base class for two-tower models (a news encoder plus a user encoder); all news and user representations can be cached in advance to speed up inference
160 |         """
161 |         super().__init__(manager, name)
162 | 
163 | 
164 |     def _compute_logits(self, cdd_news_repr, user_repr):
165 |         """ calculate batch of click probability
166 | 
167 |         Args:
168 |             cdd_news_repr: news-level representation, [batch_size, cdd_size, hidden_dim]
169 |             user_repr: user representation, [batch_size, 1, hidden_dim]
170 | 
171 |         Returns:
172 |             score of each candidate news, [batch_size, cdd_size]
173 |         """
174 |         score = cdd_news_repr.matmul(user_repr.transpose(-2, -1)).squeeze(-1) / math.sqrt(cdd_news_repr.size(-1))
175 |         return score
176 | 
177 | 
178 |     def _encode_news(self, x, cdd=True):
179 |         if cdd:
180 |             token_id = x["cdd_token_id"].to(self.device)
181 |             attn_mask = x['cdd_attn_mask'].to(self.device)
182 |         else:
183 |             token_id = x["his_token_id"].to(self.device)
184 |             attn_mask = x["his_attn_mask"].to(self.device)
185 |         news_token_embedding, news_embedding = self.newsEncoder(token_id, attn_mask)
186 |         return news_token_embedding, news_embedding
187 | 
188 | 
189 |     def _encode_user(self, x=None, his_news_embedding=None, his_mask=None):
190 |         if x is None:
191 |             user_embedding = self.userEncoder(his_news_embedding, his_mask=his_mask)
192 |         else:
193 |             _, his_news_embedding = self._encode_news(x, cdd=False)
194 |             user_embedding = self.userEncoder(his_news_embedding, his_mask=x["his_mask"].to(self.device))
195 |         return user_embedding
196 | 
197 | 
198 |     def forward(self, x):
199 |         _, cdd_news_embedding = self._encode_news(x)
200 |         user_embedding = self._encode_user(x)
201 | 
202 |         logits = self._compute_logits(cdd_news_embedding, user_embedding)
203 |         labels = x["label"].to(self.device)
204 |         loss = self.crossEntropy(logits, labels)
205 |         return loss
206 | 
207 | 
208 |     def infer(self, x):
209 |         """
210 |         infer logits with cached news embeddings when evaluating; subclasses may adjust this function in case the user-side encoding is different
211 |         """
212 |         cdd_idx = x["cdd_idx"].to(self.device, non_blocking=True)
213 |         his_idx = x["his_idx"].to(self.device, non_blocking=True)
214 |         cdd_embedding = self.news_embeddings[cdd_idx]
215 |         his_embedding = self.news_embeddings[his_idx]
216 |         user_embedding = self._encode_user(his_news_embedding=his_embedding, his_mask=x['his_mask'].to(self.device))
217 |         logits = self._compute_logits(cdd_embedding, user_embedding)
218 |         return logits
219 | 
220 | 
221 |     @torch.no_grad()
222 |     def encode_news(self, manager, loader_news):
223 |         # every process holds the same copy of news embeddings
224 |         news_embeddings = torch.zeros((len(loader_news.dataset), self.hidden_dim), device=self.device)
225 | 
226 |         # only encode news on the master node to avoid any problems possibly raised by gathering
227 |         if manager.rank == 0:
228 |             start_idx = end_idx = 0
229 |             for i, x in enumerate(tqdm(loader_news, ncols=80, desc="Encoding News")):
230 |                 _, news_embedding = self._encode_news(x)
231 | 
232 |                 end_idx = start_idx + news_embedding.shape[0]
233 |                 news_embeddings[start_idx: end_idx] = news_embedding
234 |                 start_idx = end_idx
235 | 
236 |                 if manager.debug:
237 |                     if i > 5:
238 |                         break
239 |         # broadcast news embeddings to all gpus
240 |         if
manager.distributed: 241 | dist.broadcast(news_embeddings, 0) 242 | 243 | self.news_embeddings = news_embeddings 244 | 245 | 246 | def _dev(self, manager, loaders): 247 | self.encode_news(manager, loaders["news"]) 248 | 249 | impr_indices = [] 250 | masks = [] 251 | labels = [] 252 | preds = [] 253 | 254 | for i, x in enumerate(tqdm(loaders["dev"], ncols=80, desc="Predicting")): 255 | logits = self.infer(x) 256 | 257 | masks.extend(x["cdd_mask"].tolist()) 258 | impr_indices.extend(x["impr_index"].tolist()) 259 | labels.extend(x["label"].tolist()) 260 | preds.extend(logits.tolist()) 261 | 262 | if manager.distributed: 263 | dist.barrier(device_ids=[self.device]) 264 | outputs = [None for i in range(self.world_size)] 265 | dist.all_gather_object(outputs, (impr_indices, masks, labels, preds)) 266 | 267 | if self.rank == 0: 268 | impr_indices = [] 269 | masks = [] 270 | labels = [] 271 | preds = [] 272 | for output in outputs: 273 | impr_indices.extend(output[0]) 274 | masks.extend(output[1]) 275 | labels.extend(output[2]) 276 | preds.extend(output[3]) 277 | 278 | masks = np.asarray(masks, dtype=np.bool8) 279 | labels = np.asarray(labels, dtype=np.int32) 280 | preds = np.asarray(preds, dtype=np.float32) 281 | labels, preds = pack_results(impr_indices, masks, labels, preds) 282 | 283 | else: 284 | masks = np.asarray(masks, dtype=np.bool8) 285 | labels = np.asarray(labels, dtype=np.int32) 286 | preds = np.asarray(preds, dtype=np.float32) 287 | labels, preds = pack_results(impr_indices, masks, labels, preds) 288 | 289 | return labels, preds 290 | 291 | 292 | def _test(self, manager, loaders): 293 | self.encode_news(manager, loaders["news"]) 294 | 295 | impr_indices = [] 296 | masks = [] 297 | preds = [] 298 | 299 | for i, x in enumerate(tqdm(loaders["test"], ncols=80, desc="Predicting")): 300 | logits = self.infer(x) 301 | 302 | masks.extend(x["cdd_mask"].tolist()) 303 | impr_indices.extend(x["impr_index"].tolist()) 304 | preds.extend(logits.tolist()) 305 | 306 | if manager.distributed: 307 | dist.barrier(device_ids=[self.device]) 308 | outputs = [None for i in range(self.world_size)] 309 | dist.all_gather_object(outputs, (impr_indices, masks, preds)) 310 | 311 | if self.rank == 0: 312 | impr_indices = [] 313 | masks = [] 314 | preds = [] 315 | for output in outputs: 316 | impr_indices.extend(output[0]) 317 | masks.extend(output[1]) 318 | preds.extend(output[2]) 319 | 320 | masks = np.asarray(masks, dtype=np.bool8) 321 | preds = np.asarray(preds, dtype=np.float32) 322 | preds, = pack_results(impr_indices, masks, preds) 323 | 324 | else: 325 | masks = np.asarray(masks, dtype=np.bool8) 326 | preds = np.asarray(preds, dtype=np.float32) 327 | preds, = pack_results(impr_indices, masks, preds) 328 | 329 | return preds 330 | 331 | 332 | 333 | class OneTowerBaseModel(BaseModel): 334 | def __init__(self, manager, name=None): 335 | super().__init__(manager, name) 336 | 337 | 338 | @torch.no_grad() 339 | def _dev(self, manager, loaders): 340 | impr_indices = [] 341 | masks = [] 342 | labels = [] 343 | preds = [] 344 | 345 | for i, x in enumerate(tqdm(loaders["dev"], ncols=80, desc="Predicting")): 346 | logits = self.infer(x) 347 | 348 | masks.extend(x["cdd_mask"].tolist()) 349 | impr_indices.extend(x["impr_index"].tolist()) 350 | labels.extend(x["label"].tolist()) 351 | preds.extend(logits.tolist()) 352 | 353 | if manager.distributed: 354 | dist.barrier(device_ids=[self.device]) 355 | outputs = [None for i in range(self.world_size)] 356 | dist.all_gather_object(outputs, (impr_indices, masks, labels, preds)) 
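            # only rank 0 merges the gathered per-rank lists and groups them by
            # impression; the other ranks return their raw local lists, which dev()
            # ignores, while the outer else handles the single-process case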
357 | 358 | if self.rank == 0: 359 | impr_indices = [] 360 | masks = [] 361 | labels = [] 362 | preds = [] 363 | for output in outputs: 364 | impr_indices.extend(output[0]) 365 | masks.extend(output[1]) 366 | labels.extend(output[2]) 367 | preds.extend(output[3]) 368 | 369 | masks = np.asarray(masks, dtype=np.bool8) 370 | labels = np.asarray(labels, dtype=np.int32) 371 | preds = np.asarray(preds, dtype=np.float32) 372 | labels, preds = pack_results(impr_indices, masks, labels, preds) 373 | 374 | else: 375 | masks = np.asarray(masks, dtype=np.bool8) 376 | labels = np.asarray(labels, dtype=np.int32) 377 | preds = np.asarray(preds, dtype=np.float32) 378 | labels, preds = pack_results(impr_indices, masks, labels, preds) 379 | 380 | return labels, preds 381 | 382 | 383 | def _test(self, manager, loaders): 384 | impr_indices = [] 385 | masks = [] 386 | preds = [] 387 | 388 | for i, x in enumerate(tqdm(loaders["test"], ncols=80, desc="Predicting")): 389 | logits = self.infer(x) 390 | 391 | masks.extend(x["cdd_mask"].tolist()) 392 | impr_indices.extend(x["impr_index"].tolist()) 393 | preds.extend(logits.tolist()) 394 | 395 | if manager.distributed: 396 | dist.barrier(device_ids=[self.device]) 397 | outputs = [None for i in range(self.world_size)] 398 | dist.all_gather_object(outputs, (impr_indices, masks, preds)) 399 | 400 | if self.rank == 0: 401 | impr_indices = [] 402 | masks = [] 403 | preds = [] 404 | for output in outputs: 405 | impr_indices.extend(output[0]) 406 | masks.extend(output[1]) 407 | preds.extend(output[2]) 408 | 409 | masks = np.asarray(masks, dtype=np.bool8) 410 | preds = np.asarray(preds, dtype=np.float32) 411 | preds, = pack_results(impr_indices, masks, preds) 412 | 413 | else: 414 | masks = np.asarray(masks, dtype=np.bool8) 415 | preds = np.asarray(preds, dtype=np.float32) 416 | preds, = pack_results(impr_indices, masks, preds) 417 | 418 | return preds -------------------------------------------------------------------------------- /src/models/modules/encoder.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import torch.nn as nn 4 | from transformers import AutoModel 5 | from torch.nn.utils.rnn import pack_padded_sequence 6 | from .attention import scaled_dp_attention, extend_attention_mask, TFMLayer 7 | 8 | 9 | 10 | class BaseNewsEncoder(nn.Module): 11 | def __init__(self, manager): 12 | super().__init__() 13 | self.name = type(self).__name__[:-11] 14 | 15 | 16 | 17 | class BaseUserEncoder(nn.Module): 18 | def __init__(self): 19 | super().__init__() 20 | self.name = type(self).__name__[:-11] 21 | 22 | 23 | 24 | class CnnNewsEncoder(BaseNewsEncoder): 25 | def __init__(self, manager): 26 | super().__init__(manager) 27 | 28 | self.embedding_dim = manager.plm_dim 29 | bert = AutoModel.from_pretrained(manager.plm_dir) 30 | self.embedding = bert.embeddings.word_embeddings 31 | 32 | self.cnn = nn.Conv1d( 33 | in_channels=self.embedding_dim, 34 | out_channels=manager.hidden_dim, 35 | kernel_size=3, 36 | padding=1 37 | ) 38 | nn.init.xavier_normal_(self.cnn.weight) 39 | 40 | self.news_query = nn.Parameter(torch.randn((1, manager.hidden_dim), requires_grad=True)) 41 | nn.init.xavier_normal_(self.news_query) 42 | self.newsProject = nn.Linear(manager.hidden_dim, manager.hidden_dim) 43 | nn.init.xavier_normal_(self.newsProject.weight) 44 | self.Tanh = nn.Tanh() 45 | self.Relu = nn.ReLU() 46 | 47 | 48 | def forward(self, token_id, attn_mask, token_weight=None): 49 | """ encode news through 1-d CNN 50 | """ 51 | 
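        # token_id may carry extra leading dims, e.g. [B, N, L]; flatten to [-1, L]
        # for the 1-d convolution over embedding channels, then restore the shape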
original_shape = token_id.shape 52 | token_embedding = self.embedding(token_id) 53 | if token_weight is not None: 54 | token_embedding = token_embedding * token_weight.unsqueeze(-1) 55 | cnn_input = token_embedding.view(-1, original_shape[-1], self.embedding_dim).transpose(-2, -1) 56 | cnn_output = self.Relu(self.cnn(cnn_input)).transpose(-2, -1).view(*original_shape, -1) 57 | news_embedding = scaled_dp_attention(self.news_query, self.Tanh(self.newsProject(cnn_output)), cnn_output, attn_mask=attn_mask.unsqueeze(-2)).squeeze(dim=-2) 58 | return cnn_output, news_embedding 59 | 60 | 61 | 62 | class AllBertNewsEncoder(BaseNewsEncoder): 63 | def __init__(self, manager): 64 | super().__init__(manager) 65 | self.plm = AutoModel.from_pretrained(manager.plm_dir) 66 | self.plm.pooler = None 67 | 68 | 69 | def forward(self, token_id, attn_mask): 70 | original_shape = token_id.shape 71 | token_id = token_id.view(-1, original_shape[-1]) 72 | attn_mask = attn_mask.view(-1, original_shape[-1]) 73 | 74 | token_embedding = self.plm(token_id, attention_mask=attn_mask).last_hidden_state 75 | news_embedding = token_embedding[:, 0].view(*original_shape[:-1], -1) 76 | token_embedding = token_embedding.view(*original_shape, -1) 77 | return token_embedding, news_embedding 78 | 79 | 80 | 81 | class GatedBertNewsEncoder(BaseNewsEncoder): 82 | def __init__(self, manager): 83 | super().__init__(manager) 84 | plm = AutoModel.from_pretrained(manager.plm_dir) 85 | self.embeddings = plm.embeddings 86 | self.plm = plm.encoder 87 | 88 | self.news_query = nn.Parameter(torch.randn((1, manager.hidden_dim), requires_grad=True)) 89 | nn.init.xavier_normal_(self.news_query) 90 | # self.newsProject = nn.Linear(manager.hidden_dim, manager.hidden_dim) 91 | # nn.init.xavier_normal_(self.newsProject.weight) 92 | # self.Tanh = nn.Tanh() 93 | 94 | 95 | def forward(self, token_id, attn_mask, token_weight=None): 96 | original_shape = token_id.shape 97 | token_id = token_id.view(-1, original_shape[-1]) 98 | attn_mask = attn_mask.view(-1, original_shape[-1]) 99 | 100 | token_embedding = self.embeddings(token_id) 101 | 102 | if token_weight is not None: 103 | token_weight = token_weight.view(-1, original_shape[-1]).unsqueeze(-1) 104 | token_embedding = token_embedding * (token_weight + (1 - token_weight.detach())) 105 | 106 | extended_attn_mask = extend_attention_mask(attn_mask) 107 | token_embedding = self.plm(token_embedding, attention_mask=extended_attn_mask).last_hidden_state 108 | # we do not keep [CLS] and [SEP] after gating, so it's better to use attention pooling 109 | news_embedding = scaled_dp_attention(self.news_query, token_embedding, token_embedding, attn_mask=attn_mask.unsqueeze(-2)).squeeze(dim=-2).view(*original_shape[:-1], -1) 110 | token_embedding = token_embedding.view(*original_shape, -1) 111 | return token_embedding, news_embedding 112 | 113 | 114 | 115 | class TfmNewsEncoder(BaseNewsEncoder): 116 | def __init__(self, manager): 117 | super().__init__(manager) 118 | self.embedding_dim = manager.plm_dim 119 | bert = AutoModel.from_pretrained(manager.plm_dir) 120 | self.embedding = bert.embeddings.word_embeddings 121 | self.transformer = TFMLayer(manager.hidden_dim, manager.head_num, 0.1) 122 | 123 | self.news_query = nn.Parameter(torch.randn((1, manager.hidden_dim), requires_grad=True)) 124 | nn.init.xavier_normal_(self.news_query) 125 | self.newsProject = nn.Linear(manager.hidden_dim, manager.hidden_dim) 126 | nn.init.xavier_normal_(self.newsProject.weight) 127 | self.Tanh = nn.Tanh() 128 | 129 | 130 | def forward(self, 
token_id, attn_mask, token_weight=None):
131 | original_shape = token_id.shape
132 | token_id = token_id.view(-1, original_shape[-1])
133 | attn_mask = attn_mask.view(-1, original_shape[-1])
134 | 
135 | token_embedding = self.embedding(token_id)
136 | if token_weight is not None:
137 | token_weight = token_weight.view(-1, original_shape[-1])
138 | token_embedding = token_embedding * token_weight.unsqueeze(-1)
139 | 
140 | token_embedding = self.transformer(token_embedding, attention_mask=attn_mask)
141 | news_embedding = scaled_dp_attention(self.news_query, self.Tanh(self.newsProject(token_embedding)), token_embedding, attn_mask=attn_mask.unsqueeze(-2)).squeeze(dim=-2).view(*original_shape[:-1], -1)
142 | token_embedding = token_embedding.view(*original_shape, -1)
143 | return token_embedding, news_embedding
144 | 
145 | 
146 | 
147 | class HDCNNNewsEncoder(BaseNewsEncoder):
148 | def __init__(self, manager):
149 | super().__init__(manager)
150 | 
151 | self.hidden_dim = manager.hidden_dim
152 | 
153 | self.embedding = nn.Embedding(manager.vocab_size, 300)
154 | self.embedding_dim = 300
155 | 
156 | self.level = 3
157 | 
158 | self.cnn_d1 = nn.Conv1d(in_channels=self.embedding_dim, out_channels=manager.hidden_dim,
159 | kernel_size=3, dilation=1, padding=1)
160 | self.cnn_d2 = nn.Conv1d(in_channels=self.embedding_dim, out_channels=manager.hidden_dim,
161 | kernel_size=3, dilation=2, padding=2)
162 | self.cnn_d3 = nn.Conv1d(in_channels=self.embedding_dim, out_channels=manager.hidden_dim,
163 | kernel_size=3, dilation=3, padding=3)
164 | 
165 | self.ReLU = nn.ReLU()
166 | self.layerNorm = nn.LayerNorm(manager.hidden_dim)
167 | self.dropOut = nn.Dropout(p=manager.dropout_p)
168 | 
169 | 
170 | def forward(self, token_id, attn_mask=None, **kargs):
171 | """
172 | Returns:
173 | token_embedding: B, N, V, L, D
174 | """
175 | original_shape = token_id.shape
176 | token_embedding = self.dropOut(self.embedding(token_id))
177 | cnn_input = token_embedding.view(-1, original_shape[-1], self.embedding_dim)
178 | 
179 | token_embedding = torch.zeros(
180 | (*cnn_input.shape[:2], self.level, self.hidden_dim), device=cnn_input.device)
181 | 
182 | cnn_input = cnn_input.transpose(-1, -2)
183 | 
184 | token_embedding_d1 = self.cnn_d1(cnn_input).transpose(-2,-1)
185 | token_embedding_d1 = self.layerNorm(token_embedding_d1)
186 | token_embedding[:,:,0,:] = self.ReLU(token_embedding_d1)
187 | 
188 | 
189 | token_embedding_d2 = self.cnn_d2(cnn_input).transpose(-2,-1)
190 | token_embedding_d2 = self.layerNorm(token_embedding_d2)
191 | token_embedding[:,:,1,:] = self.ReLU(token_embedding_d2)
192 | 
193 | 
194 | token_embedding_d3 = self.cnn_d3(cnn_input).transpose(-2,-1)
195 | token_embedding_d3 = self.layerNorm(token_embedding_d3)
196 | token_embedding[:,:,2,:] = self.ReLU(token_embedding_d3)
197 | 
198 | 
199 | token_embedding = token_embedding.view(*original_shape, self.level, self.hidden_dim).transpose(-2, -3)
200 | return token_embedding, None
201 | 
202 | 
203 | 
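# a shape sketch of the hierarchical dilated convolution above (sizes are
# illustrative): for token_id of shape (B, N, L) and hidden_dim D, each
# cnn_dk preserves the sequence length L because padding equals dilation for
# kernel_size=3, while the receptive field grows with the dilation rate
# (3, 5 and 7 tokens for d=1, 2, 3); stacking the three views and applying
# the final transpose yields (B, N, V=3, L, D), matching the docstring.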
204 | class RnnUserEncoder(BaseUserEncoder):
205 | def __init__(self, manager):
206 | super().__init__()
207 | # if manager.encoderU == 'gru':
208 | self.rnn = nn.GRU(manager.hidden_dim, manager.hidden_dim, batch_first=True)
209 | # elif manager.encoderU == 'lstm':
210 | # self.rnn = nn.LSTM(manager.hidden_dim, manager.hidden_dim, batch_first=True)
211 | for name, param in self.rnn.named_parameters():
212 | if 'weight' in name:
213 | nn.init.orthogonal_(param)
214 | 
215 | 
216 | def forward(self, news_embedding, his_mask):
217 | """
218 | encode user history into a representation vector
219 | 
220 | Args:
221 | news_embedding: batch of news representations, [batch_size, *, hidden_dim]
222 | his_mask: [batch_size, his_size], 1 for clicked news, 0 for padding
223 | 
224 | Returns:
225 | user_embedding: user representation (coarse), [batch_size, 1, hidden_dim]
226 | """
227 | lens = his_mask.sum(dim=-1).cpu()
228 | rnn_input = pack_padded_sequence(news_embedding, lens, batch_first=True, enforce_sorted=False)
229 | 
230 | _, user_embedding = self.rnn(rnn_input)
231 | if isinstance(user_embedding, tuple): # LSTM returns (h_n, c_n); keep the hidden state
232 | user_embedding = user_embedding[0]
233 | return user_embedding.transpose(0,1)
234 | 
235 | 
236 | 
237 | class SumUserEncoder(BaseUserEncoder):
238 | def __init__(self, manager):
239 | super().__init__()
240 | 
241 | 
242 | def forward(self, news_embedding, **kargs):
243 | """
244 | encode user history into a representation vector
245 | 
246 | Args:
247 | news_embedding: batch of news representations, [batch_size, *, hidden_dim]
248 | his_mask: unused, absorbed by **kargs for interface compatibility
249 | 
250 | Returns:
251 | user_embedding: user representation (coarse), [batch_size, 1, hidden_dim]
252 | """
253 | user_embedding = news_embedding.sum(dim=-2, keepdim=True)
254 | return user_embedding
255 | 
256 | 
257 | 
258 | class AvgUserEncoder(BaseUserEncoder):
259 | def __init__(self, manager):
260 | super().__init__()
261 | 
262 | 
263 | def forward(self, news_embedding, **kargs):
264 | """
265 | encode user history into a representation vector
266 | 
267 | Args:
268 | news_embedding: batch of news representations, [batch_size, *, hidden_dim]
269 | his_mask: unused, absorbed by **kargs for interface compatibility
270 | 
271 | Returns:
272 | user_embedding: user representation (coarse), [batch_size, 1, hidden_dim]
273 | """
274 | user_embedding = news_embedding.mean(dim=-2, keepdim=True)
275 | return user_embedding
276 | 
277 | 
278 | 
279 | class AttnUserEncoder(BaseUserEncoder):
280 | def __init__(self, manager):
281 | super().__init__()
282 | 
283 | self.user_query = nn.Parameter(torch.randn((1, manager.hidden_dim), requires_grad=True))
284 | nn.init.xavier_normal_(self.user_query)
285 | 
286 | 
287 | def forward(self, news_embedding, **kargs):
288 | """
289 | encode user history into a representation vector
290 | 
291 | Args:
292 | news_embedding: batch of news representations, [batch_size, *, hidden_dim]
293 | his_mask: unused, absorbed by **kargs; padded positions also attend
294 | 
295 | Returns:
296 | user_embedding: user representation (coarse), [batch_size, 1, hidden_dim]
297 | """
298 | user_embedding = scaled_dp_attention(self.user_query, news_embedding, news_embedding)
299 | return user_embedding
300 | 
301 | 
302 | 
303 | class TfmUserEncoder(BaseUserEncoder):
304 | def __init__(self, manager):
305 | super().__init__()
306 | self.transformer = TFMLayer(manager.hidden_dim, manager.head_num, 0.1)
307 | self.user_query = nn.Parameter(torch.randn((1, manager.hidden_dim), requires_grad=True))
308 | nn.init.xavier_normal_(self.user_query)
309 | self.userProject = nn.Linear(manager.hidden_dim, manager.hidden_dim)
310 | nn.init.xavier_normal_(self.userProject.weight)
311 | self.Tanh = nn.Tanh()
312 | 
313 | 
314 | def forward(self, news_embedding, his_mask):
315 | """
316 | encode user history into a representation vector
317 | 
318 | Args:
319 | news_embedding: batch of news representations, [batch_size, *, hidden_dim]
320 | his_mask: [batch_size, his_size], 1 for clicked news, 0 for padding
321 | 
322 | Returns:
323 | user_embedding: user representation (coarse), [batch_size, 1, hidden_dim]
324 | """
325 | news_embedding = 
self.transformer(news_embedding, attention_mask=his_mask) 326 | user_embedding = scaled_dp_attention(self.user_query, self.Tanh(self.userProject(news_embedding)), news_embedding, attn_mask=his_mask.unsqueeze(-2)) 327 | return user_embedding 328 | 329 | 330 | 331 | class BertCrossEncoder(nn.Module): 332 | def __init__(self, manager): 333 | super().__init__() 334 | self.name = "AllBert" 335 | self.plm = AutoModel.from_pretrained(manager.plm_dir) 336 | 337 | self.news_query = nn.Parameter(torch.randn((1, manager.plm_dim), requires_grad=True)) 338 | nn.init.xavier_normal_(self.news_query) 339 | self.newsProject = nn.Linear(manager.plm_dim, manager.plm_dim) 340 | nn.init.xavier_normal_(self.newsProject.weight) 341 | self.Tanh = nn.Tanh() 342 | 343 | 344 | def forward(self, token_id, attn_mask): 345 | original_shape = token_id.shape 346 | token_id = token_id.view(-1, original_shape[-1]) 347 | attn_mask = attn_mask.view(-1, original_shape[-1]) 348 | 349 | # token_embedding = self.plm(token_id, attention_mask=attn_mask).last_hidden_state 350 | # news_embedding = token_embedding[:, 0].view(*original_shape[:-1], -1) 351 | # token_embedding = token_embedding.view(*original_shape, -1) 352 | 353 | token_embedding = self.plm(token_id, attention_mask=attn_mask).last_hidden_state 354 | token_embedding = token_embedding.view(*original_shape, -1) # B, N, L, D 355 | attn_mask = attn_mask.view(*original_shape).unsqueeze(-2) 356 | news_embedding = token_embedding.mean(dim=-2) 357 | # news_embedding = scaled_dp_attention(self.news_query, self.Tanh(self.newsProject(token_embedding)), token_embedding, attn_mask=attn_mask).squeeze(dim=-2) 358 | return news_embedding 359 | 360 | 361 | 362 | class TFMCrossEncoder(nn.Module): 363 | def __init__(self, manager): 364 | super().__init__() 365 | self.name = type(self).__name__[:-7] 366 | self.embedding_dim = manager.plm_dim 367 | bert = AutoModel.from_pretrained(manager.plm_dir) 368 | self.embedding = bert.embeddings.word_embeddings 369 | 370 | self.transformer = TFMLayer(manager.plm_dim, manager.head_num, manager.dropout_p) 371 | 372 | self.news_query = nn.Parameter(torch.randn((1, manager.plm_dim), requires_grad=True)) 373 | nn.init.xavier_normal_(self.news_query) 374 | self.newsProject = nn.Linear(manager.plm_dim, manager.plm_dim) 375 | nn.init.xavier_normal_(self.newsProject.weight) 376 | self.Tanh = nn.Tanh() 377 | 378 | 379 | def forward(self, token_id, attn_mask): 380 | original_shape = token_id.shape 381 | token_id = token_id.view(-1, original_shape[-1]) 382 | attn_mask = attn_mask.view(-1, original_shape[-1]) 383 | 384 | # token_embedding = self.plm(token_id, attention_mask=attn_mask).last_hidden_state 385 | # news_embedding = token_embedding[:, 0].view(*original_shape[:-1], -1) 386 | # token_embedding = token_embedding.view(*original_shape, -1) 387 | 388 | token_embedding = self.embedding(token_id) 389 | token_embedding = self.transformer(token_embedding, attention_mask=attn_mask) 390 | news_embedding = scaled_dp_attention(self.news_query, self.Tanh(self.newsProject(token_embedding)), token_embedding, attn_mask=attn_mask.unsqueeze(-2)).squeeze(dim=-2).view(*original_shape[:-1], -1) 391 | return news_embedding 392 | -------------------------------------------------------------------------------- /src/utils/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import subprocess 4 | import numpy as np 5 | import torch.distributed as dist 6 | from tqdm import tqdm 7 | from multiprocessing 
import Pool
8 | from transformers import AutoTokenizer
9 | from torch.utils.data import Dataset
10 | from utils.util import load_pickle, save_pickle, construct_uid2index, construct_nid2index, sample_news
11 | 
12 | 
13 | 
14 | class MIND(Dataset):
15 | def __init__(self, manager, data_dir, load_news=True, load_behaviors=True) -> None:
16 | super().__init__()
17 | self.logger = logging.getLogger(type(self).__name__)
18 | 
19 | self.his_size = manager.his_size
20 | self.impr_size = manager.impr_size
21 | 
22 | self.max_title_length = manager.max_title_length
23 | self.max_abs_length = manager.max_abs_length
24 | self.title_length = manager.title_length
25 | self.abs_length = manager.abs_length
26 | 
27 | self.negative_num = manager.negative_num
28 | 
29 | self.cache_root = manager.cache_root
30 | self.data_root = manager.data_root
31 | 
32 | data_dir_name = data_dir.split("/")[-1]
33 | self.news_cache_root = os.path.join(manager.cache_root, "MIND", data_dir_name, "news")
34 | if "train" in data_dir_name:
35 | self.behaviors_cache_dir = os.path.join(manager.cache_root, "MIND", data_dir_name, "behaviors")
36 | else:
37 | # cache by impr size
38 | self.behaviors_cache_dir = os.path.join(manager.cache_root, "MIND", data_dir_name, "behaviors", str(self.impr_size))
39 | 
40 | news_num = manager.news_nums[data_dir_name] + 1
41 | 
42 | # set all enable_xxx as attributes
43 | for k,v in vars(manager).items():
44 | if k.startswith("enable"):
45 | setattr(self, k, v)
46 | 
47 | if manager.rank == 0:
48 | if not os.path.exists(os.path.join(self.news_cache_root, "title_token_ids.pkl")):
49 | news_path = os.path.join(data_dir, "news.tsv")
50 | cache_news(news_path, self.news_cache_root, manager)
51 | if not os.path.exists(os.path.join(self.behaviors_cache_dir, "behaviors.pkl")):
52 | nid2index = load_pickle(os.path.join(self.news_cache_root, "nid2index.pkl"))
53 | cache_behaviors(os.path.join(data_dir, "behaviors.tsv"), self.behaviors_cache_dir, nid2index, manager)
54 | 
55 | if manager.distributed:
56 | dist.barrier(device_ids=[manager.device])
57 | 
58 | if manager.rank == 0:
59 | self.logger.info(f"Loading Cache at {data_dir_name}")
60 | 
61 | if load_news:
62 | pad_token_id = manager.special_token_ids["[PAD]"]
63 | sep_token_id = manager.special_token_ids["[SEP]"]
64 | cls_token_id = manager.special_token_ids["[CLS]"]
65 | punc_token_ids = manager.special_token_ids["punctuations"]
66 | 
67 | # index=0 is padded news
68 | token_ids = [[] for _ in range(news_num)]
69 | self.sequence_length = manager.sequence_length
70 | 
71 | start_idx = 0
72 | if "title" in self.enable_fields:
73 | title_token_ids = load_pickle(os.path.join(self.news_cache_root, "title_token_ids.pkl"))
74 | for i, token_id in enumerate(title_token_ids, start=1):
75 | token_id = token_id[start_idx: start_idx + self.title_length]
76 | # use [SEP] to separate title and abstract
77 | if len(token_id) > 2 - start_idx: # i.e. the field is non-empty (more than just special tokens)
78 | token_id[-1] = sep_token_id
79 | token_ids[i].extend(token_id.copy())
80 | if start_idx == 0:
81 | start_idx += 1
82 | 
83 | if "abs" in self.enable_fields:
84 | abs_token_ids = load_pickle(os.path.join(self.news_cache_root, "abs_token_ids.pkl"))
85 | for i, token_id in enumerate(abs_token_ids, start=1):
86 | # offset to remove an extra [CLS]
87 | token_id = token_id[start_idx: self.abs_length + start_idx]
88 | # use [SEP] to mark the end of the abstract
89 | if len(token_id) > 2 - start_idx:
90 | token_id[-1] = sep_token_id
91 | token_ids[i].extend(token_id.copy())
92 | if start_idx == 0:
93 | start_idx += 1
94 | 
95 | attn_masks = 
np.zeros((news_num, self.sequence_length), dtype=np.int64) 96 | for i, token_id in enumerate(token_ids): 97 | s_len = len(token_id) 98 | if s_len < self.sequence_length: 99 | token_ids[i] = token_id + [pad_token_id] * (self.sequence_length - s_len) 100 | attn_masks[i][:s_len] = 1 101 | 102 | self.token_ids = np.asarray(token_ids, dtype=np.int64) 103 | self.attn_masks = np.asarray(attn_masks, dtype=np.int64) 104 | 105 | if load_behaviors: 106 | behaviors = load_pickle(os.path.join(self.behaviors_cache_dir, "behaviors.pkl")) 107 | for k,v in behaviors.items(): 108 | setattr(self, k, v) 109 | 110 | 111 | def __len__(self): 112 | if hasattr(self, "imprs"): 113 | return len(self.imprs) 114 | else: 115 | return len(self.token_ids) 116 | 117 | 118 | 119 | class MIND_Train(MIND): 120 | def __init__(self, manager) -> None: 121 | data_dir = os.path.join(manager.data_root, "MIND", f"MIND{manager.scale}_train") 122 | super().__init__(manager, data_dir) 123 | 124 | self.negative_num = manager.negative_num 125 | 126 | 127 | def __getitem__(self, index): 128 | impr_index, positive = self.imprs[index] 129 | negatives = self.negatives[impr_index] 130 | histories = self.histories[impr_index] 131 | user_index = self.user_indices[impr_index] 132 | 133 | negatives, valid_num = sample_news(negatives, self.negative_num) 134 | cdd_idx = np.asarray([positive] + negatives, dtype=np.int64) 135 | cdd_mask = np.zeros(len(cdd_idx), dtype=np.int64) 136 | cdd_mask[:1 + valid_num] = 1 137 | 138 | his_idx = histories[:self.his_size] 139 | his_mask = np.zeros(self.his_size, dtype=np.int64) 140 | if len(his_idx) == 0: 141 | his_mask[0] = 1 142 | else: 143 | his_mask[:len(his_idx)] = 1 144 | # padding user history in case there are fewer historical clicks 145 | if len(his_idx) < self.his_size: 146 | his_idx = his_idx + [0] * (self.his_size - len(his_idx)) 147 | his_idx = np.asarray(his_idx, dtype=np.int64) 148 | 149 | # the first entry is the positive instance 150 | label = 0 151 | 152 | cdd_token_id = self.token_ids[cdd_idx] 153 | his_token_id = self.token_ids[his_idx] 154 | cdd_attn_mask = self.attn_masks[cdd_idx] 155 | his_attn_mask = self.attn_masks[his_idx] 156 | 157 | return_dict = { 158 | "impr_index": impr_index, 159 | "user_index": user_index, 160 | "cdd_idx": cdd_idx, 161 | "his_idx": his_idx, 162 | "cdd_mask": cdd_mask, 163 | "his_mask": his_mask, 164 | "cdd_token_id": cdd_token_id, 165 | "his_token_id": his_token_id, 166 | "cdd_attn_mask": cdd_attn_mask, 167 | "his_attn_mask": his_attn_mask, 168 | "label": label 169 | } 170 | return return_dict 171 | 172 | 173 | 174 | class MIND_Dev(MIND): 175 | def __init__(self, manager) -> None: 176 | data_dir = os.path.join(manager.data_root, "MIND", f"MIND{manager.scale}_dev") 177 | super().__init__(manager, data_dir) 178 | 179 | 180 | def __getitem__(self, index): 181 | impr_index, impr_news = self.imprs[index] 182 | histories = self.histories[impr_index] 183 | user_index = self.user_indices[impr_index] 184 | 185 | # use -1 as padded news' label 186 | label = np.asarray(self.labels[index] + [-1] * (self.impr_size - len(impr_news)), dtype=np.int64) 187 | 188 | cdd_mask = np.zeros(self.impr_size, dtype=np.bool8) 189 | cdd_mask[:len(impr_news)] = 1 190 | cdd_idx = np.asarray(impr_news + [0] * (self.impr_size - len(impr_news)), dtype=np.int64) 191 | 192 | his_idx = histories[:self.his_size] 193 | his_mask = np.zeros(self.his_size, dtype=np.int64) 194 | if len(his_idx) == 0: 195 | his_mask[0] = 1 196 | else: 197 | his_mask[:len(his_idx)] = 1 198 | # padding user history in case 
there are fewer historical clicks 199 | if len(his_idx) < self.his_size: 200 | his_idx = his_idx + [0] * (self.his_size - len(his_idx)) 201 | his_idx = np.asarray(his_idx, dtype=np.int64) 202 | 203 | cdd_token_id = self.token_ids[cdd_idx] 204 | his_token_id = self.token_ids[his_idx] 205 | cdd_attn_mask = self.attn_masks[cdd_idx] 206 | his_attn_mask = self.attn_masks[his_idx] 207 | 208 | return_dict = { 209 | "impr_index": impr_index, 210 | "user_index": user_index, 211 | "cdd_idx": cdd_idx, 212 | "his_idx": his_idx, 213 | "cdd_mask": cdd_mask, 214 | "his_mask": his_mask, 215 | "cdd_token_id": cdd_token_id, 216 | "his_token_id": his_token_id, 217 | "cdd_attn_mask": cdd_attn_mask, 218 | "his_attn_mask": his_attn_mask, 219 | "label": label 220 | } 221 | return return_dict 222 | 223 | 224 | 225 | class MIND_Test(MIND): 226 | def __init__(self, manager) -> None: 227 | data_dir = os.path.join(manager.data_root, "MIND", f"MIND{manager.scale}_test") 228 | super().__init__(manager, data_dir) 229 | 230 | 231 | def __getitem__(self, index): 232 | impr_index, impr_news = self.imprs[index] 233 | histories = self.histories[impr_index] 234 | user_index = self.user_indices[impr_index] 235 | 236 | cdd_mask = np.zeros(self.impr_size, dtype=np.bool8) 237 | cdd_mask[:len(impr_news)] = 1 238 | cdd_idx = np.asarray(impr_news + [0] * (self.impr_size - len(impr_news)), dtype=np.int64) 239 | 240 | his_idx = histories[:self.his_size] 241 | his_mask = np.zeros(self.his_size, dtype=np.int64) 242 | if len(his_idx) == 0: 243 | his_mask[0] = 1 244 | else: 245 | his_mask[:len(his_idx)] = 1 246 | # padding user history in case there are fewer historical clicks 247 | if len(his_idx) < self.his_size: 248 | his_idx = his_idx + [0] * (self.his_size - len(his_idx)) 249 | his_idx = np.asarray(his_idx, dtype=np.int64) 250 | 251 | cdd_token_id = self.token_ids[cdd_idx] 252 | his_token_id = self.token_ids[his_idx] 253 | cdd_attn_mask = self.attn_masks[cdd_idx] 254 | his_attn_mask = self.attn_masks[his_idx] 255 | 256 | return_dict = { 257 | "impr_index": impr_index, 258 | "user_index": user_index, 259 | "cdd_idx": cdd_idx, 260 | "his_idx": his_idx, 261 | "cdd_mask": cdd_mask, 262 | "his_mask": his_mask, 263 | "cdd_token_id": cdd_token_id, 264 | "his_token_id": his_token_id, 265 | "cdd_attn_mask": cdd_attn_mask, 266 | "his_attn_mask": his_attn_mask, 267 | } 268 | return return_dict 269 | 270 | 271 | class MIND_News(MIND): 272 | def __init__(self, manager) -> None: 273 | data_mode = "test" if manager.mode == "test" else "dev" 274 | data_dir = os.path.join(manager.data_root, "MIND", f"MIND{manager.scale}_{data_mode}") 275 | super().__init__(manager, data_dir, load_news=True, load_behaviors=False) 276 | 277 | # cut off padded news 278 | # self.token_ids = self.token_ids[1:] 279 | # self.attn_masks = self.attn_masks[1:] 280 | # if hasattr(self, "gate_masks"): 281 | # self.gate_masks = self.gate_masks[1:] 282 | 283 | 284 | def __getitem__(self, index): 285 | cdd_token_id = self.token_ids[index] 286 | cdd_attn_mask = self.attn_masks[index] 287 | 288 | return_dict = { 289 | "cdd_idx": index, 290 | "cdd_token_id": cdd_token_id, 291 | "cdd_attn_mask": cdd_attn_mask, 292 | } 293 | return return_dict 294 | 295 | 296 | 297 | 298 | def tokenize_news(news_path, cache_dir, news_num, tokenizer, max_title_length, max_abs_length): 299 | title_token_ids = [[]] * news_num 300 | abs_token_ids = [[]] * news_num 301 | 302 | with open(news_path, 'r') as f: 303 | for idx, line in enumerate(tqdm(f, total=news_num, desc="Tokenizing News", ncols=80)): 304 | 
id, category, subcategory, title, abs, _, _, _ = line.strip("\n").split("\t") 305 | 306 | title_token_id = tokenizer.encode(title, max_length=max_title_length) 307 | title_token_ids[idx] = title_token_id 308 | 309 | abs_token_id = tokenizer.encode(abs, max_length=max_abs_length) 310 | abs_token_ids[idx] = abs_token_id 311 | 312 | save_pickle(title_token_ids, os.path.join(cache_dir, "title_token_ids.pkl")) 313 | save_pickle(abs_token_ids, os.path.join(cache_dir, "abs_token_ids.pkl")) 314 | 315 | 316 | def cache_news(news_path, news_cache_root, manager): 317 | news_num = int(subprocess.check_output(["wc", "-l", news_path]).decode("utf-8").split()[0]) 318 | 319 | # different news file corresponds to different cache directory 320 | os.makedirs(news_cache_root, exist_ok=True) 321 | 322 | # TODO: bm25, entity and keyword 323 | tokenizer = AutoTokenizer.from_pretrained(manager.plm_dir) 324 | tokenize_news(news_path, news_cache_root, news_num, tokenizer, manager.max_title_length, manager.max_abs_length) 325 | 326 | if not os.path.exists(os.path.join(news_cache_root, "nid2index.pkl")): 327 | print(f"mapping news id to news index and save at {os.path.join(news_cache_root, 'nid2index.pkl')}...") 328 | construct_nid2index(news_path, news_cache_root) 329 | 330 | 331 | def cache_behaviors(behaviors_path, cache_dir, nid2index, manager): 332 | if not os.path.exists(os.path.join(manager.cache_root, 'MIND', 'uid2index.pkl')): 333 | print(f"mapping user id to user index and save at {os.path.join(manager.cache_root, 'MIND', 'uid2index.pkl')}...") 334 | uid2index = construct_uid2index(manager.data_root, manager.cache_root) 335 | else: 336 | uid2index = load_pickle(os.path.join(manager.cache_root, 'MIND', 'uid2index.pkl')) 337 | 338 | os.makedirs(cache_dir, exist_ok=True) 339 | imprs = [] 340 | histories = [] 341 | user_indices = [] 342 | impr_index = 0 343 | 344 | if "train" in behaviors_path: 345 | negatives = [] 346 | with open(behaviors_path, "r") as f: 347 | for line in tqdm(f, desc="Caching User Behaviors", ncols=80): 348 | _, uid, _, history, impression = line.strip("\n").split("\t") 349 | 350 | history = [nid2index[x] for x in history.split()] 351 | interaction_pair = impression.split() 352 | impr_news = [nid2index[x.split("-")[0]] for x in interaction_pair] 353 | label = [int(i.split("-")[1]) for i in interaction_pair] 354 | uindex = uid2index[uid] 355 | 356 | negative = [] 357 | for news, lab in zip(impr_news, label): 358 | if lab == 1: 359 | imprs.append((impr_index, news)) 360 | else: 361 | negative.append(news) 362 | 363 | histories.append(history) 364 | negatives.append(negative) 365 | user_indices.append(uindex) 366 | impr_index += 1 367 | 368 | save_dict = { 369 | "imprs": imprs, 370 | "user_indices": user_indices, 371 | "histories": histories, 372 | "negatives": negatives, 373 | } 374 | save_pickle(save_dict, os.path.join(cache_dir, "behaviors.pkl")) 375 | 376 | elif "dev" in behaviors_path: 377 | labels = [] 378 | with open(behaviors_path, "r") as f: 379 | for line in tqdm(f, desc="Caching User Behaviors", ncols=80): 380 | _, uid, _, history, impression = line.strip("\n").split("\t") 381 | 382 | history = [nid2index[x] for x in history.split()] 383 | interaction_pair = impression.split() 384 | impr_news = [nid2index[x.split("-")[0]] for x in interaction_pair] 385 | label = [int(i.split("-")[1]) for i in interaction_pair] 386 | uindex = uid2index[uid] 387 | 388 | for i in range(0, len(impr_news), manager.impr_size): 389 | imprs.append((impr_index, impr_news[i: i+manager.impr_size])) 390 | 
labels.append(label[i: i+manager.impr_size])
391 | 
392 | # each original impression contributes exactly one entry to each of the following lists
393 | histories.append(history)
394 | user_indices.append(uindex)
395 | impr_index += 1
396 | 
397 | save_dict = {
398 | "imprs": imprs,
399 | "labels": labels,
400 | "histories": histories,
401 | "user_indices": user_indices,
402 | }
403 | save_pickle(save_dict, os.path.join(cache_dir, "behaviors.pkl"))
404 | 
405 | elif "test" in behaviors_path:
406 | with open(behaviors_path, "r") as f:
407 | for line in tqdm(f, desc="Caching User Behaviors", ncols=80):
408 | _, uid, _, history, impression = line.strip("\n").split("\t")
409 | 
410 | impr_news = [nid2index[x] for x in impression.split()]
411 | history = [nid2index[x] for x in history.split()]
412 | uindex = uid2index[uid]
413 | 
414 | for i in range(0, len(impr_news), manager.impr_size):
415 | imprs.append((impr_index, impr_news[i: i+manager.impr_size]))
416 | 
417 | # each original impression contributes exactly one entry to each of the following lists
418 | histories.append(history)
419 | user_indices.append(uindex)
420 | impr_index += 1
421 | 
422 | save_dict = {
423 | "imprs": imprs,
424 | "histories": histories,
425 | "user_indices": user_indices,
426 | }
427 | save_pickle(save_dict, os.path.join(cache_dir, "behaviors.pkl"))
428 | 
429 | 
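# layout sketch of the cached behaviors.pkl for the dev split (index values
# are illustrative, not taken from the real data):
#   imprs:        [(0, [3, 17, 42]), ...]  # (impr_index, candidate news indices)
#   labels:       [[0, 1, 0], ...]         # click labels aligned with imprs
#   histories:    [[5, 9], ...]            # clicked news indices, looked up via impr_index
#   user_indices: [12, ...]                # one user index per original impression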
"--dataloaders", dest="dataloaders", help="training dataloaders", nargs="+", action="extend", choices=["train", "dev", "news", "behaviors"], default=["train", "dev", "news"]) 48 | 49 | parser.add_argument("-ck","--checkpoint", dest="checkpoint", help="load the model from checkpoint before training/evaluating", type=str, default="none") 50 | parser.add_argument("-vs","--validate-step", dest="validate_step", help="evaluate and save the model every step", type=str, default="0") 51 | parser.add_argument("-hst","--hold-step", dest="hold_step", help="don't evaluate until reaching hold step", type=str, default="0") 52 | parser.add_argument("-sav","--save-at-validate", dest="save_at_validate", help="save the model every time of validating", action="store_true", default=False) 53 | parser.add_argument("-vb","--verbose", dest="verbose", help="variant's name", type=str, default=None) 54 | parser.add_argument("--metrics", dest="metrics", help="metrics for evaluating the model", nargs="+", action="extend", default=["auc", "mean_mrr", "ndcg@5", "ndcg@10"]) 55 | 56 | parser.add_argument("-hs", "--his_size", dest="his_size",help="history size", type=int, default=50) 57 | parser.add_argument("-is", "--impr_size", dest="impr_size", help="impression size for evaluating", type=int, default=20) 58 | parser.add_argument("-nn", "--negative-num", dest="negative_num", help="number of negatives", type=int, default=4) 59 | parser.add_argument("-dp", "--dropout-p", dest="dropout_p", help="dropout probability", type=float, default=0.1) 60 | parser.add_argument("-lr", "--learning-rate", dest="learning_rate", help="learning rate", type=float, default=1e-5) 61 | parser.add_argument("-sch", "--scheduler", dest="scheduler", help="choose schedule scheme for optimizer", choices=["linear","none"], default="none") 62 | parser.add_argument("--warmup", dest="warmup", help="warmup steps of scheduler", type=float, default=0.1) 63 | 64 | parser.add_argument("-pth", "--preprocess-threads", dest="preprocess_threads", help="thread number in preprocessing", type=int, default=32) 65 | parser.add_argument("-dr", "--data-root", dest="data_root", default="../../../Data") 66 | parser.add_argument("-cr", "--cache-root", dest="cache_root", default="data/cache") 67 | 68 | parser.add_argument("-tl", "--title-length", dest="title_length", type=int, default=32) 69 | parser.add_argument("-al", "--abs-length", dest="abs_length", type=int, default=64) 70 | parser.add_argument("-mtl", "--max-title-length", dest="max_title_length", type=int, default=64) 71 | parser.add_argument("-mal", "--max-abs-length", dest="max_abs_length", type=int, default=128) 72 | 73 | parser.add_argument("-ef", "--enable-fields", dest="enable_fields", help="text fields to model", nargs="+", action="extend", choices=["title", "abs"], default=[]) 74 | 75 | parser.add_argument("-ne", "--news-encoder", dest="newsEncoder", default="cnn") 76 | parser.add_argument("-ue", "--user-encoder", dest="userEncoder", default="rnn") 77 | 78 | parser.add_argument("-hd", "--hidden-dim", dest="hidden_dim", type=int, default=768) 79 | parser.add_argument("-hn", "--head-num", dest="head_num", help="attention head number of tranformer model", type=int, default=12) 80 | 81 | parser.add_argument("-k", dest="k", help="gate number", type=int, default=4) 82 | 83 | parser.add_argument("-plm", dest="plm", help="short name of pre-trained language models", type=str, default="bert") 84 | 85 | parser.add_argument("--seed", dest="seed", default=3407, type=int) 86 | parser.add_argument("-ws", "--world-size", 
dest="world_size", help="gpu number", type=int, default=1) 87 | parser.add_argument("-br", "--base-rank", dest="base_rank", help="base device index", type=int, default=0) 88 | 89 | parser.add_argument("--debug", dest="debug", help="debug mode", action="store_true", default=False) 90 | 91 | if config: 92 | # different default settings per model 93 | parser.set_defaults(**config) 94 | if command is not None: 95 | args = vars(parser.parse_args(command)) 96 | else: 97 | args = vars(parser.parse_args()) 98 | 99 | # used for checking 100 | if args["debug"]: 101 | args["hold_step"] = "0" 102 | args["validate_step"] = "2" 103 | # default to load best checkpoint 104 | if args["mode"] != "train": 105 | if args["checkpoint"] == "none": 106 | args["checkpoint"] = "best" 107 | sequence_length = 0 108 | if "title" in args["enable_fields"]: 109 | sequence_length += args["title_length"] 110 | if "abs" in args["enable_fields"]: 111 | sequence_length += args["abs_length"] 112 | if sequence_length == 0: 113 | raise ValueError("Include at least one field!") 114 | else: 115 | args["sequence_length"] = sequence_length 116 | 117 | if args['seed'] is not None: 118 | seed = args['seed'] 119 | random.seed(seed) 120 | os.environ['PYTHONHASHSEED'] = str(seed) 121 | np.random.seed(seed) 122 | torch.manual_seed(seed) 123 | torch.cuda.manual_seed(seed) 124 | torch.cuda.manual_seed_all(seed) 125 | torch.backends.cudnn.deterministic = True 126 | torch.backends.cudnn.benchmark = True 127 | 128 | for k,v in args.items(): 129 | if not k.startswith("__"): 130 | setattr(self, k, v) 131 | 132 | plm_map = { 133 | "bert": { 134 | "full_name": "bert-base-uncased", 135 | "dim": 768, 136 | "vocab_size": 30522, 137 | "special_token_ids": { 138 | "[PAD]": 0, 139 | "[CLS]": 101, 140 | "[SEP]": 102, 141 | "punctuations": {}, 142 | } 143 | }, 144 | 145 | "distilbert": { 146 | "full_name": "distilbert-base-uncased", 147 | "dim": 768, 148 | "vocab_size": 30522, 149 | "special_token_ids": { 150 | "[PAD]": 0, 151 | "[CLS]": 101, 152 | "[SEP]": 102, 153 | "punctuations": {}, 154 | } 155 | }, 156 | 157 | } 158 | dataloader_map = { 159 | "train": ["train", "dev", "news"], 160 | "dev": ["dev", "news"], 161 | "test": ["test", "news"], 162 | } 163 | 164 | self.plm_dir = os.path.join(self.data_root, "PLM", self.plm) 165 | self.plm_dim = plm_map[self.plm]["dim"] 166 | self.special_token_ids = plm_map[self.plm]["special_token_ids"] 167 | self.vocab_size = plm_map[self.plm]["vocab_size"] 168 | self.plm_full_name = plm_map[self.plm]["full_name"] 169 | 170 | self.news_nums = { 171 | "MINDdemo_train": 51282, 172 | "MINDdemo_dev": 42416, 173 | "MINDsmall_train": 51282, 174 | "MINDsmall_dev": 42416, 175 | "MINDlarge_train": 101527, 176 | "MINDlarge_dev": 72023, 177 | "MINDlarge_test": 120961, 178 | } 179 | self.dataloaders = dataloader_map[self.mode] 180 | 181 | # default rank is 0 182 | self.rank = 0 183 | self.distributed = self.world_size > 1 184 | self.exclude_hparams = set(["news_nums", "vocab_size_map", "metrics", "plm_dim", "plm_dir", "data_root", "cache_root", "distributed", "exclude_hparams", "rank", "epochs", "mode", "debug", "special_token_ids", "validate_step", "hold_step", "exclude_hparams", "device", "save_at_validate", "preprocess_threads", "base_rank", "max_title_length", "max_abs_length"]) 185 | 186 | logger.info("Hyper Parameters are:\n{}\n".format({k:v for k,v in args.items() if k[0:2] != "__" and k not in self.exclude_hparams})) 187 | 188 | 189 | def setup(self, rank): 190 | """ 191 | set up distributed training and fix seeds 192 | """ 
189 | def setup(self, rank):
190 | """
191 | set up distributed training (process group, rank and device binding)
192 | """
193 | os.environ["TOKENIZERS_PARALLELISM"] = "True"
194 | os.environ["TORCH_DISTRIBUTED_DEBUG"] = "INFO"
195 | 
196 | if self.world_size > 1:
197 | os.environ["NCCL_DEBUG"] = "WARN"
198 | os.environ["MASTER_ADDR"] = "localhost"
199 | os.environ["MASTER_PORT"] = str(12355 + self.base_rank)
200 | 
201 | # initialize the process group
202 | # use a very long timeout to avoid spurious timeout errors
203 | dist.init_process_group("nccl", rank=rank, world_size=self.world_size, timeout=timedelta(0, 1000000))
204 | 
205 | # manager.rank will be invoked in creating DistributedSampler
206 | self.rank = rank
207 | # manager.device will be invoked in the model
208 | self.device = rank + self.base_rank
209 | 
210 | else:
211 | # one-gpu
212 | self.rank = 0
213 | 
214 | if self.device != "cpu":
215 | # os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
216 | # essential to make all_gather_object work properly
217 | torch.cuda.set_device(self.device)
218 | 
219 | 
220 | def prepare(self):
221 | """
222 | prepare dataloaders for training/evaluating
223 | 
224 | Returns:
225 | loaders: dict of dataloaders keyed by split name
226 | ("train"/"dev"/"test"/"news"); the available keys
227 | are determined by self.dataloaders, which is set
228 | from the running mode via dataloader_map in __init__
229 | 
230 | """
231 | if self.rank == 0:
232 | # download plm once
233 | if not os.path.exists(self.plm_dir):
234 | logger.info("downloading PLMs...")
235 | download_plm(self.plm_full_name, self.plm_dir)
236 | 
237 | 
238 | if self.distributed:
239 | dist.barrier(device_ids=[self.device])
240 | 
241 | loaders = {}
242 | 
243 | # training dataloaders
244 | if "train" in self.dataloaders:
245 | dataset_train = MIND_Train(self)
246 | if self.distributed:
247 | sampler_train = DistributedSampler(dataset_train, num_replicas=self.world_size, rank=self.rank, seed=self.seed)
248 | shuffle = False
249 | else:
250 | sampler_train = None
251 | shuffle = True
252 | loaders["train"] = DataLoader(dataset_train, batch_size=self.batch_size, shuffle=shuffle, sampler=sampler_train)
253 | 
254 | if "dev" in self.dataloaders:
255 | dataset_dev = MIND_Dev(self)
256 | sampler_dev = Sequential_Sampler(len(dataset_dev), num_replicas=self.world_size, rank=self.rank)
257 | loaders["dev"] = DataLoader(dataset_dev, batch_size=self.batch_size_eval, sampler=sampler_dev, drop_last=False)
258 | 
259 | if "test" in self.dataloaders:
260 | dataset_test = MIND_Test(self)
261 | sampler_test = Sequential_Sampler(len(dataset_test), num_replicas=self.world_size, rank=self.rank)
262 | loaders["test"] = DataLoader(dataset_test, batch_size=self.batch_size_eval, sampler=sampler_test, drop_last=False)
263 | 
264 | if "news" in self.dataloaders:
265 | dataset_news = MIND_News(self)
266 | # no sampler
267 | loaders["news"] = DataLoader(dataset_news, batch_size=self.batch_size_eval, drop_last=False)
268 | 
269 | return loaders
270 | 
271 | 
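# loaders returned per mode, following dataloader_map in __init__:
#   train -> "train", "dev", "news";  dev -> "dev", "news";  test -> "test", "news"
#   the "news" loader is created without a sampler (see prepare above)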
272 | def save(self, model, step, best=False):
273 | """
274 | shortcut for saving the model state dict together with the manager hyper parameters
275 | """
276 | if isinstance(model, torch.nn.parallel.DistributedDataParallel):
277 | model = model.module
278 | if best:
279 | save_path = f"data/ckpts/{model.name}/{self.scale}/best.model"
280 | else:
281 | save_path = f"data/ckpts/{model.name}/{self.scale}/{step}.model"
282 | 
283 | logger.info("saving model at {}...".format(save_path))
284 | model_dict = model.state_dict()
285 | 
286 | save_dict = {}
287 | save_dict["manager"] = {k:v for k,v in vars(self).items() if k[:2] != "__" and k not in self.exclude_hparams}
288 | save_dict["model"] = model_dict
289 | 
290 | torch.save(save_dict, save_path)
291 | 
292 | 
293 | def load(self, model):
294 | """
295 | shortcut for loading model parameters from self.checkpoint
296 | 
297 | Args:
298 | model: nn.Module
299 | 
300 | self.checkpoint may be "none" (skip loading), "best", a step number,
301 | or a path to a checkpoint file; a missing file is skipped with a warning
302 | """
303 | if isinstance(model, torch.nn.parallel.DistributedDataParallel):
304 | model = model.module
305 | 
306 | checkpoint = self.checkpoint
307 | if checkpoint == "none":
308 | return
309 | elif os.path.isfile(checkpoint):
310 | save_path = checkpoint
311 | elif checkpoint == "best":
312 | save_path = f"data/ckpts/{model.name}/{self.scale}/best.model"
313 | else:
314 | save_path = f"data/ckpts/{model.name}/{self.scale}/{checkpoint}.model"
315 | 
316 | if not os.path.exists(save_path):
317 | if self.rank == 0:
318 | logger.warning(f"Checkpoint {save_path} Not Found, Not Loading Any Checkpoints!")
319 | return
320 | 
321 | if self.rank == 0:
322 | logger.info("loading model from {}...".format(save_path))
323 | 
324 | state_dict = torch.load(save_path, map_location=torch.device(model.device))
325 | 
326 | if self.rank == 0:
327 | current_manager = vars(self)
328 | for k,v in state_dict["manager"].items():
329 | try:
330 | if v != current_manager[k] and k not in {"dataloaders", "checkpoint"}:
331 | logger.info(f"the checkpoint was saved with {k}={v}, while the current setting is {current_manager[k]}!")
332 | except KeyError:
333 | logger.info(f"manager setting {k} from the checkpoint is not found in the current configuration!")
334 | 
335 | missing_keys, unexpected_keys = model.load_state_dict(state_dict["model"], strict=False)
336 | if self.rank == 0:
337 | if len(missing_keys):
338 | logger.warning(f"Missing Keys: {missing_keys}")
339 | if len(unexpected_keys):
340 | logger.warning(f"Unexpected Keys: {unexpected_keys}")
341 | 
342 | 
343 | def _log(self, model_name, metrics):
344 | """
345 | append evaluation results to performance.log and send them by email if configured
346 | """
347 | with open("performance.log", "a+") as f:
348 | d = {}
349 | for k, v in vars(self).items():
350 | if k not in self.exclude_hparams:
351 | d[k] = v
352 | 
353 | line = "{} : {}\n{}\n\n".format(model_name, str(d), str(metrics))
354 | f.write(line)
355 | 
356 | try:
357 | from data.email import email,password
358 | subject = f"[PR] {model_name}"
359 | email_server = smtplib.SMTP_SSL('smtp.gmail.com', 465)
360 | email_server.login(email, password)
361 | message = "Subject: {}\n\n{}".format(subject, line)
362 | email_server.sendmail(email, email, message)
363 | email_server.close()
364 | except Exception:
365 | logger.warning("failed to send the notification email over SMTP")
366 | 
367 | 
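# examples of the step formats parsed in Manager.train below (counts hypothetical):
#   validate_step="0.5e" with 10000 training steps per epoch -> validate every 5000 steps
#   validate_step="0"                                        -> validate once per epoch
#   validate_step="2000"                                     -> validate every 2000 steps
#   hold_step="1e"                                           -> skip validation during the first epoch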
368 | def _train(self, model, loaders, validate_step, hold_step, optimizer, scheduler, save_at_validate=False):
369 | total_steps = 1
370 | loader_train = loaders["train"]
371 | 
372 | best_res = {"main": -1.}
373 | 
374 | if self.rank == 0:
375 | logger.info("training {}...".format(model.module.name if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model.name))
376 | 
377 | for epoch in range(self.epochs):
378 | epoch_loss = 0
379 | if self.distributed:
380 | try:
381 | loader_train.sampler.set_epoch(epoch)
382 | except AttributeError:
383 | if self.rank == 0:
384 | logger.warning(f"{type(loader_train)} has no attribute 'sampler', skipping set_epoch, so shuffling will not vary across epochs")
385 | tqdm_ = tqdm(loader_train, ncols=120)
386 | 
387 | for step, x in enumerate(tqdm_, 1):
388 | optimizer.zero_grad(set_to_none=True)
389 | loss = model(x)
390 | epoch_loss += float(loss)
391 | loss.backward()
392 | 
393 | optimizer.step()
394 | if scheduler:
395 | scheduler.step()
396 | 
397 | if step % 5 == 0:
398 | tqdm_.set_description("epoch: {:d}, step: {:d}, loss: {:.4f}".format(epoch + 1, step, epoch_loss / step))
399 | 
400 | if total_steps > hold_step and total_steps % validate_step == 0:
401 | if isinstance(model, torch.nn.parallel.DistributedDataParallel):
402 | result = model.module.dev(self, loaders)
403 | else:
404 | result = model.dev(self, loaders)
405 | # only the result of the master node is useful
406 | if self.rank == 0:
407 | result["step"] = total_steps
408 | if save_at_validate:
409 | self.save(model, total_steps)
410 | 
411 | # save the best model checkpoint
412 | if result["main"] >= best_res["main"]:
413 | best_res = result
414 | self.save(model, total_steps, best=True)
415 | self._log(model.module.name if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model.name, result)
416 | 
417 | # prevent SIGABRT
418 | if self.distributed:
419 | dist.barrier(device_ids=[self.device])
420 | # continue training
421 | model.train()
422 | 
423 | total_steps += 1
424 | 
425 | return best_res
426 | 
427 | 
428 | def train(self, model, loaders):
429 | """
430 | train the model
431 | """
432 | model.train()
433 | if self.rank == 0:
434 | # in case the folder does not exist, create one
435 | os.makedirs(f"data/ckpts/{model.module.name if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model.name}/{self.scale}", exist_ok=True)
436 | 
437 | if isinstance(model, torch.nn.parallel.DistributedDataParallel):
438 | optimizer, scheduler = model.module.get_optimizer(self, len(loaders["train"]))
439 | else:
440 | optimizer, scheduler = model.get_optimizer(self, len(loaders["train"]))
441 | 
442 | self.load(model)
443 | 
444 | if self.validate_step[-1] == "e":
445 | # validate every given fraction (or multiple) of an epoch
446 | validate_step = round(len(loaders["train"]) * float(self.validate_step[:-1]))
447 | elif self.validate_step == "0":
448 | # validate at the end of every epoch
449 | validate_step = len(loaders["train"])
450 | else:
451 | # validate at certain steps
452 | validate_step = int(self.validate_step)
453 | if self.hold_step[-1] == "e":
454 | hold_step = int(len(loaders["train"]) * float(self.hold_step[:-1]))
455 | else:
456 | hold_step = int(self.hold_step)
457 | 
458 | result = self._train(model, loaders, validate_step, hold_step, optimizer, scheduler=scheduler, save_at_validate=self.save_at_validate)
459 | 
460 | if self.rank == 0:
461 | logger.info("Best result: {}".format(result))
462 | self._log(model.module.name if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model.name, result)
463 | 
--------------------------------------------------------------------------------