├── __init__.py ├── model ├── __init__.py ├── doc_encoder.py ├── net.py └── attention.py ├── requirements.txt ├── .gitignore ├── config.py ├── metric.py ├── main.py ├── unit_test.py ├── README.md ├── inference.py ├── dataset.py └── train.py /__init__.py: -------------------------------------------------------------------------------- 1 | from .net import NPA 2 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .net import NRMS -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | pytorch_lightning 3 | orjson 4 | gensim -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.out 2 | lightning_logs/ 3 | word2vec/* 4 | data/* 5 | .*/ 6 | __py*/ -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | hparams = { 2 | 'batch_size': 32, 3 | 'lr': 5e-4, 4 | 'name': 'ranger', 5 | 'version': 'v3', 6 | 'description': 'NRMS lr=5e-4, with weight_decay', 7 | 'pretrained_model': './word2vec/wiki_300d_5ws.model', 8 | 'model': { 9 | 'dct_size': 'auto', 10 | 'nhead': 10, 11 | 'embed_size': 300, 12 | # 'self_attn_size': 400, 13 | 'encoder_size': 250, 14 | 'v_size': 200 15 | }, 16 | 'data': { 17 | 'pos_k': 50, 18 | 'neg_k': 4, 19 | 'maxlen': 15 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def dcg_score(y_true, y_score, k=10): 5 | order = np.argsort(y_score)[::-1] 6 | y_true = np.take(y_true, order[:k]) 7 | gains = 2 ** y_true - 1 8 | discounts = np.log2(np.arange(len(y_true)) + 2) 9 | return np.sum(gains / discounts) 10 | 11 | 12 | def ndcg_score(y_true, y_score, k=10): 13 | best = dcg_score(y_true, y_true, k) 14 | actual = dcg_score(y_true, y_score, k) 15 | return actual / best 16 | 17 | def mrr_score(y_true, y_score): 18 | order = np.argsort(y_score)[::-1] 19 | y_true = np.take(y_true, order) 20 | rr_score = y_true / (np.arange(len(y_true)) + 1) 21 | return np.sum(rr_score) / np.sum(y_true) 22 | -------------------------------------------------------------------------------- /model/doc_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from model.attention import AdditiveAttention 5 | 6 | 7 | class DocEncoder(nn.Module): 8 | def __init__(self, hparams, weight=None) -> None: 9 | super(DocEncoder, self).__init__() 10 | self.hparams = hparams 11 | if weight is None: 12 | self.embedding = nn.Embedding(100, 300) 13 | else: 14 | self.embedding = nn.Embedding.from_pretrained(weight, freeze=False, padding_idx=0) 15 | self.mha = nn.MultiheadAttention(hparams['embed_size'], num_heads=hparams['nhead'], dropout=0.1) 16 | self.proj = nn.Linear(hparams['embed_size'], hparams['encoder_size']) 17 | self.additive_attn = AdditiveAttention(hparams['encoder_size'], hparams['v_size']) 18 | 19 | def forward(self, x): 20 | x = F.dropout(self.embedding(x), 0.2) 21 | x = x.permute(1, 0, 2) 22 | output, _ = self.mha(x, x, x) 23 | output = F.dropout(output.permute(1, 0, 2)) 24 | output = self.proj(output) 25 | output, _ = self.additive_attn(output) 26 | return output 27 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping 2 | from pytorch_lightning.loggers import TensorBoardLogger 3 | from pytorch_lightning import Trainer 4 | from torch import dtype 5 | from config import hparams 6 | import os 7 | from train import Model 8 | from argparse import ArgumentParser 9 | 10 | parser = ArgumentParser('train args') 11 | parser.add_argument('--gpu', default='0') 12 | parser.add_argument('--epochs', default=50, type=int) 13 | args = parser.parse_args() 14 | 15 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 16 | 17 | model = Model(hparams) 18 | checkpoint_callback = ModelCheckpoint( 19 | filepath=f'lightning_logs/{hparams["name"]}/{hparams["version"]}/' + '{epoch}-{auroc:.2f}', 20 | save_top_k=3, 21 | verbose=True, 22 | monitor='auroc', 23 | mode='max', 24 | save_last=True 25 | ) 26 | 27 | early_stop = EarlyStopping( 28 | monitor='auroc', 29 | min_delta=0.001, 30 | patience=5, 31 | strict=False, 32 | verbose=True, 33 | mode='max' 34 | ) 35 | logger = TensorBoardLogger( 36 | save_dir='lightning_logs', 37 | name=hparams['name'], 38 | version=hparams["version"] 39 | ) 40 | 41 | trainer = Trainer(max_epochs=args.epochs, 42 | gpus=1, 43 | early_stop_callback=early_stop, 44 | weights_summary='full', 45 | checkpoint_callback=checkpoint_callback, 46 | logger=logger) 47 | 48 | trainer.fit(model) 49 | -------------------------------------------------------------------------------- /unit_test.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from config import hparams 4 | 5 | def attention(): 6 | from model.attention import MultiheadAttentionContainer, InProjContainer, ScaledDotProduct, AdditiveAttention 7 | 8 | q = torch.rand(10, 16, 200) 9 | k = v = torch.rand(10, 16, 200) 10 | 11 | print('MultiHead: ') 12 | proj = InProjContainer(nn.Linear(200, 200), nn.Linear( 13 | 200, 200), nn.Linear(200, 200)) 14 | mha = MultiheadAttentionContainer(nhead=8, 15 | in_proj_container=proj, 16 | attention_layer=ScaledDotProduct(), 17 | out_proj=nn.Linear(200, 200)) 18 | output, score = mha(q, k, v) 19 | print(output.shape) 20 | print('Additive: ') 21 | attn = AdditiveAttention(200, 300) 22 | output = output.permute(1, 0, 2) 23 | output, score = attn(output) 24 | print(output.size(), score.size()) 25 | 26 | def doc_encoder(): 27 | from model.doc_encoder import DocEncoder 28 | 29 | encoder = DocEncoder(hparams['model']) 30 | x = torch.randint(0, 100, (16, 100)) 31 | output = encoder(x) 32 | print(output.shape) 33 | 34 | def NRMS(): 35 | from model.net import NRMS 36 | 37 | nrms = NRMS(hparams['model']) 38 | clicks = torch.randint(0, 100, (8, 50, 100)) 39 | cands = torch.randint(0, 100, (8, 10, 100)) 40 | logits = nrms(clicks, cands) 41 | # print(logits.shape) 42 | 43 | # doc_encoder() 44 | attention() 45 | # NRMS() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NRMS 2 | 3 | * Pytorch 1.5 4 | * [Neural News Recommendation with Multi-Head Self-Attention](https://www.aclweb.org/anthology/D19-1671/) 5 | * [Offical Implementation](https://github.com/wuch15/EMNLP2019-NRMS) (keras) 6 | 7 | ## Change 8 | 9 | * Use Range instead of Adam Optimizer 10 | * title encoder: branch for RoBERTa and ELETRA-Small 11 | * pytorch-lightning 12 | * tensorboard support 13 | * early stop 14 | 15 | 16 | ## Benchmark 17 | 18 | * Use Taiwan PTT forum as data (Tranditional Chinese) 19 | * regard a comment as the user intereseted in the post 20 | * train on one Titan RTX 21 | * train until early stop 22 | 23 | ## Data Description 24 | 25 | * articles.json 26 | 27 | ```python 28 | [{'id': 0, 'title': ['[', '公告', '] ', '八卦', '優文', '推薦', '申請']}, 29 | {'id': 1, 'title': ['[', '公告', '] ', '八卦板', '政治文', '規範', '草案', '開始', '討論']}, 30 | {'id': 2, 'title': ['[', '公告', '] ', '三月份', '置底', '閒聊', '文']}, 31 | ... 32 | ] 33 | ``` 34 | 35 | * users_list.json 36 | ```python 37 | [{'user_id': 0, 'push':[1, 2, 3]}, # 'push' is a list of articles.json id 38 | {'user_id': 1, 'push':[2, 5, 6]}, 39 | ... 40 | ] 41 | ``` 42 | 43 | 44 | ### Model 45 | 46 | * original: Use Word2Vec pretrained on Wiki-zh 47 | * Roberta: roberta-base in [this](https://github.com/ymcui/Chinese-BERT-wwm) 48 | * ELETRA: electra-smiall in [this](https://github.com/ymcui/Chinese-ELECTRA) 49 | 50 | ### training time 51 | 52 | * original(Adam): 1 hr 53 | * original(Ranger): 1 hr 4 min 54 | * Roberta(Ranger): 19 hr 46 min 55 | * ELETRA-Small(Ranger): 2hr 19min 56 | 57 | ### Score on ValidationSet 58 | 59 | 60 | #### AUROC 61 | 62 | * original(Adam): 0.86 63 | * original(Ranger): 0.89 64 | * Roberta(Ranger): 0.94 65 | * ELETRA-small(Ranger): 0.91 66 | 67 | #### ndcg@5 68 | 69 | * original(Adam): 0.73 70 | * original(Ranger): 0.79 71 | * Roberta(Ranger): 0.88 72 | * ELETRA-small(Ranger): 0.81 73 | 74 | #### ndcg@10 75 | 76 | * original(Adam): 0.67 77 | * original(Ranger): 0.72 78 | * Roberta(Ranger): 0.81 79 | * ELETRA-small(Ranger): 0.74 80 | -------------------------------------------------------------------------------- /model/net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from model.doc_encoder import DocEncoder 5 | from model.attention import AdditiveAttention 6 | 7 | 8 | class NRMS(nn.Module): 9 | def __init__(self, hparams, weight=None): 10 | super(NRMS, self).__init__() 11 | self.hparams = hparams 12 | self.doc_encoder = DocEncoder(hparams, weight=weight) 13 | # proj = InProjContainer(nn.Linear(hparams['encoder_size'], hparams['encoder_size']), 14 | # nn.Linear(hparams['encoder_size'], hparams['encoder_size']), 15 | # nn.Linear(hparams['encoder_size'], hparams['encoder_size'])) 16 | self.mha = nn.MultiheadAttention(hparams['encoder_size'], hparams['nhead'], dropout=0.1) 17 | 18 | # self.mha = MultiheadAttentionContainer(nhead=hparams['nhead'], 19 | # in_proj_container=proj, 20 | # attention_layer=ScaledDotProduct(), 21 | # out_proj=nn.Linear(hparams['encoder_size'], hparams['encoder_size'])) 22 | self.proj = nn.Linear(hparams['encoder_size'], hparams['encoder_size']) 23 | self.additive_attn = AdditiveAttention(hparams['encoder_size'], hparams['v_size']) 24 | self.criterion = nn.CrossEntropyLoss() 25 | 26 | def forward(self, clicks, cands, labels=None): 27 | """forward 28 | 29 | Args: 30 | clicks (tensor): [num_user, num_click_docs, seq_len] 31 | cands (tensor): [num_user, num_candidate_docs, seq_len] 32 | """ 33 | num_click_docs = clicks.shape[1] 34 | num_cand_docs = cands.shape[1] 35 | num_user = clicks.shape[0] 36 | seq_len = clicks.shape[2] 37 | clicks = clicks.reshape(-1, seq_len) 38 | cands = cands.reshape(-1, seq_len) 39 | click_embed = self.doc_encoder(clicks) 40 | cand_embed = self.doc_encoder(cands) 41 | click_embed = click_embed.reshape(num_user, num_click_docs, -1) 42 | cand_embed = cand_embed.reshape(num_user, num_cand_docs, -1) 43 | click_embed = click_embed.permute(1, 0, 2) 44 | click_output, _ = self.mha(click_embed, click_embed, click_embed) 45 | click_output = F.dropout(click_output.permute(1, 0, 2), 0.2) 46 | 47 | click_repr = self.proj(click_output) 48 | click_repr, _ = self.additive_attn(click_output) 49 | logits = torch.bmm(click_repr.unsqueeze(1), cand_embed.permute(0, 2, 1)).squeeze(1) # [B, 1, hid], [B, 10, hid] 50 | if labels is not None: 51 | loss = self.criterion(logits, labels) 52 | return loss, logits 53 | return torch.sigmoid(logits) 54 | # return torch.softmax(logits, -1) 55 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | from gensim.models import Word2Vec 3 | import torch 4 | from model import NRMS 5 | from typing import List 6 | from gaisTokenizer import Tokenizer 7 | 8 | 9 | class Model(pl.LightningModule): 10 | def __init__(self, hparams): 11 | super(Model, self).__init__() 12 | self.w2v = Word2Vec.load(hparams['pretrained_model']) 13 | self.w2id = {w: self.w2v.wv.vocab[w].index for w in self.w2v.wv.vocab} 14 | 15 | if hparams['model']['dct_size'] == 'auto': 16 | hparams['model']['dct_size'] = len(self.w2v.wv.vocab) 17 | self.model = NRMS(hparams['model'], torch.tensor(self.w2v.wv.vectors)) 18 | self.hparams = hparams 19 | self.maxlen = hparams['data']['maxlen'] 20 | self.tokenizer = Tokenizer('k95763565C5F785B50546754545D77505F0325160B58173C17291B3D5E2500135001671C06272B3B06281E1E5E55A9F7EB80C0E58AD1EB50AC') 21 | 22 | def forward(self, viewed, cands, topk): 23 | """forward 24 | 25 | Args: 26 | viewed (tensor): [B, viewed_num, maxlen] 27 | cands (tesnor): [B, cand_num, maxlen] 28 | Returns: 29 | val: [B] 0 ~ 1 30 | idx: [B] 31 | """ 32 | logits = self.model(viewed, cands) 33 | val, idx = logits.topk(topk) 34 | return idx, val 35 | 36 | def predict_one(self, viewed, cands, topk): 37 | """predict one user 38 | 39 | Args: 40 | viewed (List[List[str]]): 41 | cands (List[List[str]]): 42 | Returns: 43 | topk of cands 44 | """ 45 | viewed_token = torch.tensor([self.sent2idx(v) for v in viewed]).unsqueeze(0) 46 | cands_token = torch.tensor([self.sent2idx(c) for c in cands]).unsqueeze(0) 47 | idx, val = self(viewed_token, cands_token, topk) 48 | val = val.squeeze().detach().cpu().tolist() 49 | 50 | result = [cands[i] for i in idx.squeeze()] 51 | return result, val 52 | 53 | def sent2idx(self, tokens: List[str]): 54 | if ']' in tokens: 55 | tokens = tokens[tokens.index(']'):] 56 | tokens = [self.w2id[token.strip()] 57 | for token in tokens if token.strip() in self.w2id.keys()] 58 | tokens += [0] * (self.maxlen - len(tokens)) 59 | tokens = tokens[:self.maxlen] 60 | return tokens 61 | 62 | def tokenize(self, sents: str): 63 | return self.tokenizer.tokenize(sents) 64 | 65 | 66 | def print_func(r): 67 | for t in r: 68 | print(''.join(t)) 69 | 70 | if __name__ == '__main__': 71 | import json, random 72 | with open('./data/articles.json', 'r') as f: 73 | articles = json.loads(f.read()) 74 | with open('./data/users_list.json', 'r') as f: 75 | users = json.loads(f.read()) 76 | nrms = Model.load_from_checkpoint('lightning_logs/ranger/v3/epoch=30-auroc=0.89.ckpt') 77 | viewed = users[1001]['push'][:50] 78 | viewed = [articles[v]['title'] for v in viewed] 79 | print_func(viewed) 80 | cands = [a['title'] for a in random.sample(articles, 20)] + viewed[:10] 81 | result, val = nrms.predict_one(viewed, cands, 20) 82 | print('result') 83 | print_func(result) 84 | print(val) 85 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import List 3 | 4 | import orjson as json 5 | import torch 6 | from gensim.models import Word2Vec 7 | from torch.utils import data 8 | from tqdm import tqdm 9 | 10 | 11 | class Dataset(data.Dataset): 12 | def __init__(self, article_file: str, user_file: str, w2v, maxlen: int = 15, pos_num: int = 50, neg_k: int = 4): 13 | self.articles = self.load_json(article_file) 14 | self.users = self.load_json(user_file) 15 | self.maxlen = maxlen 16 | self.neg_k = neg_k 17 | self.pos_num = pos_num 18 | 19 | self.w2id = {w: w2v.wv.vocab[w].index for w in w2v.wv.vocab} 20 | 21 | def load_json(self, file: str): 22 | with open(file, 'r') as f: 23 | return json.loads(f.read()) 24 | 25 | def sent2idx(self, tokens: List[str]): 26 | # tokens = tokens[3:] 27 | if ']' in tokens: 28 | tokens = tokens[tokens.index(']'):] 29 | tokens = [self.w2id[token.strip()] 30 | for token in tokens if token.strip() in self.w2id.keys()] 31 | tokens += [0] * (self.maxlen - len(tokens)) 32 | tokens = tokens[:self.maxlen] 33 | return tokens 34 | 35 | def __len__(self): 36 | return len(self.users) 37 | 38 | def __getitem__(self, idx: int): 39 | """getitem 40 | 41 | Args: 42 | idx (int): 43 | Data: 44 | return ( 45 | user_id (int): 1 46 | click (tensor): [batch, num_click_docs, seq_len] 47 | cand (tensor): [batch, num_candidate_docs, seq_len] 48 | label: candidate docs label (0 or 1) 49 | ) 50 | """ 51 | push = self.users[idx]['push'] 52 | random.shuffle(push) 53 | push = push[:self.pos_num] 54 | uid = self.users[idx]['user_id'] 55 | click_doc = [self.sent2idx(self.articles[p]['title']) for p in push] 56 | cand_doc = [] 57 | cand_doc_label = [] 58 | # neg 59 | for i in range(self.neg_k): 60 | neg_id = -1 61 | while neg_id == -1 or neg_id in push: 62 | neg_id = random.randint(0, len(self.articles) - 1) 63 | cand_doc.append(self.sent2idx(self.articles[neg_id]['title'])) 64 | cand_doc_label.append(0) 65 | # pos 66 | try: 67 | cand_doc.append(self.sent2idx( 68 | self.articles[push[random.randint(50, len(self.push) - 1)]['title']])) 69 | cand_doc_label.append(1) 70 | except Exception: 71 | try: 72 | cand_doc.append(self.sent2idx(self.articles[push[0]]['title'])) 73 | except: 74 | print(push[0]) 75 | print(self.articles[push[0]]) 76 | cand_doc_label.append(1) 77 | 78 | tmp = list(zip(cand_doc, cand_doc_label)) 79 | random.shuffle(tmp) 80 | cand_doc, cand_doc_label = zip(*tmp) 81 | return torch.tensor(click_doc), torch.tensor(cand_doc), torch.tensor(cand_doc_label, dtype=torch.float).argmax(0) 82 | 83 | class ValDataset(Dataset): 84 | def __init__(self, num=50, *args, **kwargs) -> None: 85 | super(ValDataset, self).__init__(*args, **kwargs) 86 | self.num = num 87 | 88 | def __getitem__(self, idx: int): 89 | push = self.users[idx]['push'] 90 | random.shuffle(push) 91 | uid = self.users[idx]['user_id'] 92 | click_doc = [self.sent2idx(self.articles[p]['title']) for p in push[:self.pos_num]] 93 | 94 | true_num = 10 95 | # true_num = random.randint(1, min(self.num, len(push)) ) 96 | f_num = self.num - true_num 97 | cand_doc = random.sample(push, true_num) # true 98 | cand_doc_label = [1] * true_num 99 | cand_doc.extend(random.sample(range(0, len(self.articles)), f_num)) # false 100 | cand_doc_label.extend([0] * f_num) 101 | tmp = list(zip(cand_doc, cand_doc_label)) 102 | random.shuffle(tmp) 103 | cand_doc, cand_doc_label = zip(*tmp) 104 | cand_doc = [self.sent2idx(self.articles[cand]['title']) for cand in cand_doc] 105 | return torch.LongTensor(click_doc), torch.LongTensor(cand_doc), torch.LongTensor(cand_doc_label) 106 | 107 | 108 | if __name__ == '__main__': 109 | w2v = Word2Vec.load('./word2vec/wiki_300d_5ws.model') 110 | ds = ValDataset(50, './data/articles.json', './data/users_list.json', 111 | w2v, maxlen=30, pos_num=50, neg_k=4) 112 | print(ds[10]) 113 | for i in tqdm(ds): 114 | pass 115 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import pytorch_ranger 2 | import torch 3 | from torch.utils import data 4 | import pytorch_lightning as pl 5 | from pytorch_ranger import Ranger 6 | from dataset import Dataset, ValDataset 7 | from gensim.models import Word2Vec 8 | from model.net import NRMS 9 | from metric import ndcg_score, mrr_score 10 | 11 | 12 | class Model(pl.LightningModule): 13 | def __init__(self, hparams): 14 | super(Model, self).__init__() 15 | self.w2v = Word2Vec.load(hparams['pretrained_model']) 16 | if hparams['model']['dct_size'] == 'auto': 17 | hparams['model']['dct_size'] = len(self.w2v.wv.vocab) 18 | self.model = NRMS(hparams['model'], torch.tensor(self.w2v.wv.vectors)) 19 | self.hparams = hparams 20 | 21 | def configure_optimizers(self): 22 | # return torch.optim.Adam(self.parameters(), lr=self.hparams['lr'], weight_decay=1e-5) 23 | return pytorch_ranger.Ranger(self.parameters(), lr=self.hparams['lr'], weight_decay=1e-5) 24 | 25 | def prepare_data(self): 26 | """prepare_data 27 | 28 | load dataset 29 | """ 30 | d = self.hparams['data'] 31 | self.train_ds = Dataset( 32 | './data/articles.json', './data/users_list.json', self.w2v, maxlen=self.hparams['data']['maxlen'], pos_num=d['pos_k'], neg_k=d['neg_k']) 33 | self.val_ds = ValDataset( 34 | 50, './data/articles.json', './data/users_list.json', self.w2v) 35 | tmp = [t.unsqueeze(0) for t in self.train_ds[0]] 36 | self.logger.experiment.add_graph(self.model, tmp) 37 | # num_train = int(len(ds) * 0.85) 38 | # num_val = len(ds) - num_train 39 | # self.train_ds, self.val_ds = data.random_split(ds, (num_train, num_val)) 40 | 41 | def train_dataloader(self): 42 | """ 43 | 44 | return: 45 | train_dataloader 46 | """ 47 | return data.DataLoader(self.train_ds, batch_size=self.hparams['batch_size'], num_workers=10, shuffle=True) 48 | 49 | def val_dataloader(self): 50 | """ 51 | 52 | return: 53 | val_dataloader 54 | """ 55 | sampler = data.RandomSampler( 56 | self.val_ds, num_samples=10000, replacement=True) 57 | return data.DataLoader(self.val_ds, sampler=sampler, batch_size=self.hparams['batch_size'], num_workers=10, drop_last=True) 58 | 59 | def forward(self): 60 | """forward 61 | define as normal pytorch model 62 | """ 63 | return None 64 | 65 | def training_step(self, batch, batch_idx): 66 | """for each step(batch) 67 | 68 | Arguments: 69 | batch {[type]} -- data 70 | batch_idx {[type]} 71 | 72 | """ 73 | clicks, cands, labels = batch 74 | loss, score = self.model(clicks, cands, labels) 75 | return {'loss': loss} 76 | 77 | def training_epoch_end(self, outputs): 78 | """for each epoch end 79 | 80 | Arguments: 81 | outputs: list of training_step output 82 | """ 83 | loss_mean = torch.stack([x['loss'] for x in outputs]).mean() 84 | logs = {'train_loss': loss_mean} 85 | self.model.eval() 86 | 87 | # self.logger.log_metrics(logs, self.current_epoch) 88 | return {'progress_bar': logs, 'log': logs} 89 | 90 | def validation_step(self, batch, batch_idx): 91 | """for each step(batch) 92 | 93 | Arguments: 94 | batch {[type]} -- data 95 | batch_idx {[type]} 96 | 97 | """ 98 | clicks, cands, cands_label = batch 99 | with torch.no_grad(): 100 | logits = self.model(clicks, cands) 101 | mrr = 0. 102 | auc = 0. 103 | ndcg5, ndcg10 = 0., 0. 104 | 105 | for score, label in zip(logits, cands_label): 106 | auc += pl.metrics.functional.auroc(score, label) 107 | score = score.detach().cpu().numpy() 108 | label = label.detach().cpu().numpy() 109 | mrr += mrr_score(label, score) 110 | ndcg5 += ndcg_score(label, score, 5) 111 | ndcg10 += ndcg_score(label, score, 10) 112 | return {'auroc': (auc / logits.shape[0]).item(), 'mrr': (mrr / logits.shape[0]).item(), 'ndcg5': (ndcg5 / logits.shape[0]).item(), 'ndcg10': (ndcg10 / logits.shape[0]).item()} 113 | 114 | def validation_epoch_end(self, outputs): 115 | """ 116 | validation end 117 | 118 | Arguments: 119 | outputs: list of training_step output 120 | """ 121 | mrr = torch.tensor([x['mrr'] for x in outputs]) 122 | auroc = torch.tensor([x['auroc'] for x in outputs]) 123 | ndcg5 = torch.tensor([x['ndcg5'] for x in outputs]) 124 | ndcg10 = torch.tensor([x['ndcg10'] for x in outputs]) 125 | 126 | logs = {'auroc': auroc.mean(), 'mrr': mrr.mean( 127 | ), 'ndcg@5': ndcg5.mean(), 'ndcg@10': ndcg10.mean()} 128 | # self.logger.log_metrics(logs, self.current_epoch) 129 | self.model.train() 130 | return {'progress_bar': logs, 'log': logs} 131 | 132 | 133 | if __name__ == '__main__': 134 | from pytorch_lightning import Trainer 135 | from config import hparams 136 | import os 137 | 138 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 139 | 140 | model = Model(hparams) 141 | 142 | trainer = Trainer(max_epochs=50, 143 | gpus=1 144 | ) 145 | 146 | trainer.fit(model) 147 | -------------------------------------------------------------------------------- /model/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from typing import Tuple, Optional 5 | 6 | 7 | class AdditiveAttention(torch.nn.Module): 8 | def __init__(self, in_dim=100, v_size=200): 9 | super().__init__() 10 | 11 | self.in_dim = in_dim 12 | self.v_size = v_size 13 | # self.v = torch.nn.Parameter(torch.rand(self.v_size)) 14 | self.proj = nn.Sequential(nn.Linear(self.in_dim, self.v_size), nn.Tanh()) 15 | self.proj_v = nn.Linear(self.v_size, 1) 16 | 17 | def forward(self, context): 18 | """Additive Attention 19 | 20 | Args: 21 | context (tensor): [B, seq_len, in_dim] 22 | 23 | Returns: 24 | outputs, weights: [B, seq_len, out_dim], [B, seq_len] 25 | """ 26 | # weights = self.proj(context) @ self.v 27 | weights = self.proj_v(self.proj(context)).squeeze(-1) 28 | weights = torch.softmax(weights, dim=-1) # [B, seq_len] 29 | return torch.bmm(weights.unsqueeze(1), context).squeeze(1), weights # [B, 1, seq_len], [B, seq_len, dim] 30 | 31 | 32 | class MultiheadAttentionContainer(torch.nn.Module): 33 | def __init__(self, nhead, in_proj_container, attention_layer, out_proj): 34 | r""" A multi-head attention container 35 | Args: 36 | nhead: the number of heads in the multiheadattention model 37 | in_proj_container: A container of multi-head in-projection linear layers (a.k.a nn.Linear). 38 | attention_layer: The attention layer. 39 | out_proj: The multi-head out-projection layer (a.k.a nn.Linear). 40 | Examples:: 41 | >>> import torch 42 | >>> embed_dim, num_heads, bsz = 10, 5, 64 43 | >>> in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim), 44 | torch.nn.Linear(embed_dim, embed_dim), 45 | torch.nn.Linear(embed_dim, embed_dim)) 46 | >>> MHA = MultiheadAttentionContainer(num_heads, 47 | in_proj_container, 48 | ScaledDotProduct(), 49 | torch.nn.Linear(embed_dim, embed_dim)) 50 | >>> query = torch.rand((21, bsz, embed_dim)) 51 | >>> key = value = torch.rand((16, bsz, embed_dim)) 52 | >>> attn_output, attn_weights = MHA(query, key, value) 53 | >>> print(attn_output.shape) 54 | >>> torch.Size([21, 64, 10]) 55 | """ 56 | super(MultiheadAttentionContainer, self).__init__() 57 | self.nhead = nhead 58 | self.in_proj_container = in_proj_container 59 | self.attention_layer = attention_layer 60 | self.out_proj = out_proj 61 | 62 | def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, 63 | attn_mask: Optional[torch.Tensor] = None, 64 | bias_k: Optional[torch.Tensor] = None, 65 | bias_v: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: 66 | r""" 67 | Args: 68 | query, key, value (Tensor): map a query and a set of key-value pairs to an output. 69 | See "Attention Is All You Need" for more details. 70 | attn_mask, bias_k and bias_v (Tensor, optional): keyword arguments passed to the attention layer. 71 | See the definitions in the attention. 72 | Shape: 73 | - Inputs: 74 | - query: :math:`(L, N, E)` 75 | - key: :math:`(S, N, E)` 76 | - value: :math:`(S, N, E)` 77 | - attn_mask, bias_k and bias_v: same with the shape of the corresponding args in attention layer. 78 | - Outputs: 79 | - attn_output: :math:`(L, N, E)` 80 | - attn_output_weights: :math:`(N * H, L, S)` 81 | where where L is the target length, S is the sequence length, H is the number of attention heads, 82 | N is the batch size, and E is the embedding dimension. 83 | """ 84 | tgt_len, src_len, bsz, embed_dim = query.size( 85 | -3), key.size(-3), query.size(-2), query.size(-1) 86 | q, k, v = self.in_proj_container(query, key, value) 87 | assert q.size(-1) % self.nhead == 0, "query's embed_dim must be divisible by the number of heads" 88 | head_dim = q.size(-1) // self.nhead 89 | q = q.reshape(tgt_len, bsz * self.nhead, head_dim) 90 | 91 | assert k.size(-1) % self.nhead == 0, "key's embed_dim must be divisible by the number of heads" 92 | head_dim = k.size(-1) // self.nhead 93 | k = k.reshape(src_len, bsz * self.nhead, head_dim) 94 | 95 | assert v.size(-1) % self.nhead == 0, "value's embed_dim must be divisible by the number of heads" 96 | head_dim = v.size(-1) // self.nhead 97 | v = v.reshape(src_len, bsz * self.nhead, head_dim) 98 | 99 | attn_output, attn_output_weights = self.attention_layer(q, k, v, attn_mask=attn_mask, 100 | bias_k=bias_k, bias_v=bias_v) 101 | attn_output = attn_output.reshape(tgt_len, bsz, embed_dim) 102 | attn_output = self.out_proj(attn_output) 103 | return attn_output, attn_output_weights 104 | 105 | 106 | class ScaledDotProduct(torch.nn.Module): 107 | 108 | def __init__(self, dropout=0.0): 109 | r"""Processes a projected query and key-value pair to apply 110 | scaled dot product attention. 111 | Args: 112 | dropout (float): probability of dropping an attention weight. 113 | Examples:: 114 | >>> SDP = torchtext.models.ScaledDotProduct(0.1) 115 | >>> q = torch.randn(256, 21, 3) 116 | >>> k = v = torch.randn(256, 21, 3) 117 | >>> attn_output, attn_weights = SDP(q, k, v) 118 | >>> print(attn_output.shape, attn_weights.shape) 119 | torch.Size([256, 21, 3]) torch.Size([256, 21, 21]) 120 | """ 121 | super(ScaledDotProduct, self).__init__() 122 | self.dropout = dropout 123 | 124 | def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, 125 | attn_mask: Optional[torch.Tensor] = None, 126 | bias_k: Optional[torch.Tensor] = None, 127 | bias_v: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: 128 | r"""Uses a scaled dot product with the projected key-value pair to update 129 | the projected query. 130 | Args: 131 | query (Tensor): Projected query 132 | key (Tensor): Projected key 133 | value (Tensor): Projected value 134 | attn_mask (BoolTensor, optional): 3D mask that prevents attention to certain positions. 135 | bias_k and bias_v: (Tensor, optional): one more key and value sequence to be added at 136 | sequence dim (dim=-3). Those are used for incremental decoding. Users should provide 137 | non-None to both arguments in order to activate them. 138 | Shape: 139 | - query: :math:`(L, N * H, E / H)` 140 | - key: :math:`(S, N * H, E / H)` 141 | - value: :math:`(S, N * H, E / H)` 142 | - attn_mask: :math:`(N * H, L, S)`, positions with ``True`` are not allowed to attend 143 | while ``False`` values will be unchanged. 144 | - bias_k and bias_v:bias: :math:`(1, N * H, E / H)` 145 | - Output: :math:`(L, N * H, E / H)`, :math:`(N * H, L, S)` 146 | where L is the target length, S is the source length, H is the number 147 | of attention heads, N is the batch size, and E is the embedding dimension. 148 | """ 149 | if bias_k is not None and bias_v is not None: 150 | assert key.size(-1) == bias_k.size(-1) and key.size(-2) == bias_k.size(-2) and bias_k.size(-3) == 1, \ 151 | "Shape of bias_k is not supported" 152 | assert value.size(-1) == bias_v.size(-1) and value.size(-2) == bias_v.size(-2) and bias_v.size(-3) == 1, \ 153 | "Shape of bias_v is not supported" 154 | key = torch.cat([key, bias_k]) 155 | value = torch.cat([value, bias_v]) 156 | if attn_mask is not None: 157 | _attn_mask = attn_mask 158 | attn_mask = torch.nn.functional.pad(_attn_mask, (0, 1)) 159 | 160 | tgt_len, head_dim = query.size(-3), query.size(-1) 161 | assert query.size(-1) == key.size(-1) == value.size( 162 | -1), "The feature dim of query, key, value must be equal." 163 | assert key.size() == value.size(), "Shape of key, value must match" 164 | src_len = key.size(-3) 165 | batch_heads = max(query.size(-2), key.size(-2)) 166 | 167 | # Scale query 168 | query, key, value = query.transpose( 169 | -2, -3), key.transpose(-2, -3), value.transpose(-2, -3) 170 | query = query * (head_dim ** -0.5) 171 | if attn_mask is not None: 172 | if attn_mask.dim() != 3: 173 | raise RuntimeError('attn_mask must be a 3D tensor.') 174 | if (attn_mask.size(-1) != src_len) or (attn_mask.size(-2) != tgt_len) or \ 175 | (attn_mask.size(-3) != 1 and attn_mask.size(-3) != batch_heads): 176 | raise RuntimeError('The size of the attn_mask is not correct.') 177 | if attn_mask.dtype != torch.bool: 178 | raise RuntimeError( 179 | 'Only bool tensor is supported for attn_mask') 180 | 181 | # Dot product of q, k 182 | attn_output_weights = torch.matmul(query, key.transpose(-2, -1)) 183 | if attn_mask is not None: 184 | attn_output_weights.masked_fill_(attn_mask, -1e8,) 185 | attn_output_weights = torch.nn.functional.softmax( 186 | attn_output_weights, dim=-1) 187 | attn_output_weights = torch.nn.functional.dropout( 188 | attn_output_weights, p=self.dropout, training=self.training) 189 | attn_output = torch.matmul(attn_output_weights, value) 190 | return attn_output.transpose(-2, -3), attn_output_weights 191 | 192 | 193 | class InProjContainer(torch.nn.Module): 194 | def __init__(self, query_proj, key_proj, value_proj): 195 | r"""A in-proj container to process inputs. 196 | Args: 197 | query_proj: a proj layer for query. 198 | key_proj: a proj layer for key. 199 | value_proj: a proj layer for value. 200 | """ 201 | 202 | super(InProjContainer, self).__init__() 203 | self.query_proj = query_proj 204 | self.key_proj = key_proj 205 | self.value_proj = value_proj 206 | 207 | def forward(self, 208 | query: torch.Tensor, 209 | key: torch.Tensor, 210 | value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 211 | r"""Projects the input sequences using in-proj layers. 212 | Args: 213 | query, key, value (Tensors): sequence to be projected 214 | Shape: 215 | - query, key, value: :math:`(S, N, E)` 216 | - Output: :math:`(S, N, E)` 217 | where S is the sequence length, N is the batch size, and E is the embedding dimension. 218 | """ 219 | return self.query_proj(query), self.key_proj(key), self.value_proj(value) 220 | 221 | 222 | def generate_square_subsequent_mask(nbatch, sz): 223 | r"""Generate a square mask for the sequence. The masked positions are filled with True. 224 | Unmasked positions are filled with False. 225 | Args: 226 | nbatch: the number of batch size 227 | sz: the size of square mask 228 | """ 229 | mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose( 230 | 0, 1).repeat(nbatch, 1, 1) 231 | return mask 232 | --------------------------------------------------------------------------------