├── README.md
├── tool.py
├── LICENSE
├── utils.py
├── augmentation.py
├── .gitignore
├── trainer.py
├── main.py
├── model.py
├── data_iterator.py
└── layers.py

/README.md:
--------------------------------------------------------------------------------
# CL4SRec-pytorch
A PyTorch implementation of CL4SRec from "Contrastive Learning for Sequential Recommendation", which provides three output aggregation strategies ('concat', 'mean' and 'predict') and three augmentation strategies ('mask', 'reorder' and 'crop').

## Dataset
The dataset should be organized in the following format: one user per line, where the first column is the user id, followed by the interacted items.

```python
# ./data/dataset_name.txt
user item1 item2 ...
```

## Usage
You can train CL4SRec on the Yelp dataset with the following command:
```bash
python -u main.py --dataset Yelp --cl_embs predict
```

## Acknowledgement
The Transformer layers are implemented based on [RecBole](https://github.com/RUCAIBox/RecBole).
--------------------------------------------------------------------------------
/tool.py:
--------------------------------------------------------------------------------
import os
import random

import numpy as np
import torch

def trans_to_cuda(variable):
    if torch.cuda.is_available():
        return variable.cuda()
    else:
        return variable

def trans_to_cpu(variable):
    if torch.cuda.is_available():
        return variable.cpu()
    else:
        return variable

def setup_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 JamZheng

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from tool import *

class ContrastiveLoss(nn.Module):
    def __init__(self, config):
        super(ContrastiveLoss, self).__init__()
        self.ce_loss = nn.CrossEntropyLoss()
        self.temp = config.temp
        self.batch_size = config.batch_size
        self.maxlen = config.maxlen
        self.eps = 1e-12

    def forward(self, pos_embeds1, pos_embeds2):
        batch_size = pos_embeds1.shape[0]
        # positive logits: the diagonal of the cross-view similarity matrix
        pos_diff = torch.sum(pos_embeds1 * pos_embeds2, dim=-1).view(-1, 1) / self.temp
        score11 = torch.matmul(pos_embeds1, pos_embeds1.transpose(0, 1)) / self.temp
        score22 = torch.matmul(pos_embeds2, pos_embeds2.transpose(0, 1)) / self.temp
        score12 = torch.matmul(pos_embeds1, pos_embeds2.transpose(0, 1)) / self.temp

        # drop the diagonal, keeping the in-batch negatives per view
        mask = (-torch.eye(batch_size).long() + 1).bool()
        mask = trans_to_cuda(mask)
        score11 = score11[mask].view(batch_size, -1)
        score22 = score22[mask].view(batch_size, -1)
        score12 = score12[mask].view(batch_size, -1)

        score1 = torch.cat((pos_diff, score11, score12), dim=1)  # [B, 2B - 1]
        score2 = torch.cat((pos_diff, score22, score12), dim=1)
        score = torch.cat((score1, score2), dim=0)

        # the positive logit sits in column 0 of every row
        labels = torch.zeros(batch_size * 2).long()
        score = trans_to_cuda(score)
        labels = trans_to_cuda(labels)
        assert score.shape[-1] == 2 * batch_size - 1
        return self.ce_loss(score, labels)
--------------------------------------------------------------------------------
/augmentation.py:
--------------------------------------------------------------------------------
from math import ceil, floor
import numpy as np
import copy

class Augment(object):
    def __init__(self, p):
        self.p = p

class Mask(Augment):
    def __init__(self, p, is_hard=True):
        super(Mask, self).__init__(p)
        self.is_hard = is_hard

    def __call__(self, ori_seq):
        if self.is_hard:
            return self.hard_mask(ori_seq)
        else:
            return self.soft_mask(ori_seq)

    def soft_mask(self, ori_seq):
        # mask each position independently with probability p (consistent with hard_mask)
        seq = copy.deepcopy(np.array(ori_seq))
        mask = ((np.random.rand(seq.size)) < self.p)
        if mask.sum() < 1:
            mask = ((np.random.rand(seq.size)) < self.p)
        seq[mask] = 0
        return seq.tolist()

    def hard_mask(self, ori_seq):
        # mask exactly floor(len * p) randomly chosen positions
        seq = copy.deepcopy(np.array(ori_seq))
        seq_idx = np.random.choice(np.arange(0, len(seq)), size=floor(len(seq) * self.p), replace=False)
        # seq_idx = np.sort(seq_idx)
        seq[seq_idx] = 0
        return seq.tolist()

class Reorder(Augment):
    def __init__(self, p, is_hard=True):
        super(Reorder, self).__init__(p)
        self.is_hard = is_hard

    def __call__(self, ori_seq):
        if self.is_hard:
            return self.hard_reorder(ori_seq)
        else:
            return self.soft_reorder(ori_seq)

    def hard_reorder(self, ori_seq):
        # shuffle a random contiguous span covering a proportion p of the sequence
        seq = copy.deepcopy(np.array(ori_seq))
        begin = np.random.randint(0, ceil(len(seq) - (len(seq) * self.p)))
        ori_idx = np.arange(begin, ceil(begin + len(seq) * self.p))
        shuffle_idx = copy.deepcopy(ori_idx)
        np.random.shuffle(shuffle_idx)
        seq[ori_idx] = seq[shuffle_idx]
        return seq.tolist()

    def soft_reorder(self, ori_seq):
        # shuffle ceil(len * p) randomly chosen (not necessarily contiguous) positions
        seq = copy.deepcopy(np.array(ori_seq))
        ori_idx = np.random.choice(np.arange(0, len(seq)), size=ceil(len(seq) * self.p), replace=False)
        shuffle_idx = copy.deepcopy(ori_idx)
        np.random.shuffle(shuffle_idx)
        seq[ori_idx] = seq[shuffle_idx]
        return seq.tolist()

class Crop(Augment):
    def __init__(self, p):
        super(Crop, self).__init__(p)

    def __call__(self, ori_seq):
        # keep a random contiguous subsequence covering a proportion p of the sequence
        seq = copy.deepcopy(np.array(ori_seq))
        begin = np.random.randint(0, ceil(len(seq) - (len(seq) * self.p)))
        tar_idx = np.arange(begin, ceil(begin + len(seq) * self.p))
        seq = seq[tar_idx]
        return seq.tolist()
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
import torch
import os
import numpy as np

from tqdm import tqdm

from tool import *

class Trainer(object):
    def __init__(self) -> None:
        super().__init__()

    def train(self, epoch, writer, model, train_data, config):
        logger = config.logger
        model.train()

        step = 0
        loss_sum = 0

        if not os.path.exists('runs'):
            os.mkdir('runs')

        for seqs, poss, negs, lens, aug_seqs, aug_lens in train_data:
            # torch.cuda.empty_cache()
            seqs = trans_to_cuda(torch.LongTensor(seqs))
            poss = trans_to_cuda(torch.LongTensor(poss))
            negs = trans_to_cuda(torch.LongTensor(negs))
            lens = trans_to_cuda(torch.LongTensor(lens))

            model.optimizer.zero_grad()

            loss = 0
            cl_loss_record = 0
            step += 1

            # calculate the next-item loss
            loss += model.cal_loss(seqs, poss, negs, lens)

            # calculate the contrastive loss between the two augmented views
            aug_seqs1, aug_seqs2 = trans_to_cuda(torch.LongTensor(aug_seqs[0])), trans_to_cuda(torch.LongTensor(aug_seqs[1]))
            aug_lens1, aug_lens2 = trans_to_cuda(torch.LongTensor(aug_lens[0])), trans_to_cuda(torch.LongTensor(aug_lens[1]))

            aug_embs1 = model(aug_seqs1, aug_lens1, phase=config.cl_embs)
            aug_embs2 = model(aug_seqs2, aug_lens2, phase=config.cl_embs)
            batch_size = aug_embs1.shape[0]

            if config.cl_embs == 'mean':
                cl_loss_record += model.cl_loss(aug_embs1.mean(-2), aug_embs2.mean(-2)) * config.w_clloss
            elif config.cl_embs == 'concat':
                cl_loss_record += model.cl_loss(aug_embs1.view(batch_size, -1), aug_embs2.view(batch_size, -1)) * config.w_clloss
            else:
                cl_loss_record += model.cl_loss(aug_embs1, aug_embs2) * config.w_clloss

            loss += cl_loss_record
            loss.backward()
            model.optimizer.step()
            loss_sum += loss.item()
            writer.add_scalar("loss", loss.item(), (epoch - 1) * train_data.n_batch + step)  # global step across epochs

        logger.info('Epoch:{:d}\tloss:{:.4f}'\
            .format(epoch, loss_sum / train_data.n_batch))


    def eval(self, epoch, model, config, test_data, ks, phase='valid'):
        logger = config.logger

        recall, ndcg = [0] * len(ks), [0] * len(ks)
        num = 0
        model.eval()

        test_data_iter = tqdm(test_data, total=test_data.n_batch)
        with torch.no_grad():
            for seqs, tars, lens in test_data_iter:
                seqs = trans_to_cuda(torch.LongTensor(seqs))
                lens = trans_to_cuda(torch.LongTensor(lens))
                item_scores = model.full_sort_predict(seqs, lens)
                nrecall = int(ks[-1])
                # mask the padding item
                item_scores[:, 0] -= 1e9
                if not config.repeat_rec:
                    # mask items that already appear in the input sequence
                    for seq, item_score in zip(seqs, item_scores):
                        item_score[seq] -= 1e9
                _, items = torch.topk(item_scores, nrecall, sorted=True)
                items = trans_to_cpu(items).detach().numpy()

                batch_recall, batch_ndcg = [0] * len(ks), [0] * len(ks)

                for item, tar in zip(items, tars):
                    for k, kk in enumerate(ks):
                        if tar in set(item[:kk]):
                            batch_recall[k] += 1
                            item_idx = {i: idx + 1 for idx, i in enumerate(item[:kk])}
                            batch_ndcg[k] += (1 / np.log2(item_idx[tar] + 1))

                recall = [r + br for r, br in zip(recall, batch_recall)]
                ndcg = [n + bn for n, bn in zip(ndcg, batch_ndcg)]
                num += seqs.shape[0]

        prefix = 'Valid: ' if phase == 'valid' else 'Test: '
        log_str = prefix
        for nbr_k, kk in enumerate(ks):
            log_str += 'Recall@{:2d}:\t{:.4f}\t'.format(kk, recall[nbr_k] / num)
        logger.info(log_str)
        log_str = prefix
        for nbr_k, kk in enumerate(ks):
            log_str += 'NDCG@{:2d}:\t{:.4f}\t'.format(kk, ndcg[nbr_k] / num)
        logger.info(log_str)

        return [r / num for r in recall]
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import torch
import os
import argparse
import logging
import sys
from datetime import datetime
from time import time

sys.path.append('..')
from tensorboardX import SummaryWriter
from trainer import Trainer
from model import CL4SRec
from data_iterator import TrainData, TestData
from tool import trans_to_cuda, setup_seed

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu_id', type=int, default=0)
    parser.add_argument('--dataset', type=str, default='Beauty')
    parser.add_argument('--model', type=str, default='CL4SRec')
    # contrastive learning
    parser.add_argument('--cl_embs', type=str, default='concat', help='concat | predict | mean')
    parser.add_argument('--w_clloss', type=float, default=0.1, help='the weight of the cl loss')
    parser.add_argument('--temp', type=float, default=1)
    parser.add_argument('--aug_type', type=str, default='mcr',
                        help="'mask' | 'crop' | 'reorder', or any combination of the letters m/c/r (e.g. 'mcr')")
    parser.add_argument('--is_hard', type=bool, default=True)  # note: argparse parses any non-empty string as True
    parser.add_argument('--mask_p', type=float, default=0.5)
    parser.add_argument('--crop_p', type=float, default=0.6)
    parser.add_argument('--reorder_p', type=float, default=0.2)

    parser.add_argument('--repeat_rec', type=int, default=0)
    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--dropout_prob', type=float, default=0.25)
    parser.add_argument('--filename', type=str, default='debug', help='postfix of the log filename')
    parser.add_argument('--random_seed', type=int, default=11)
    parser.add_argument('--embedding_size', type=int, default=64)
    parser.add_argument('--hidden_size', type=int, default=64)
    parser.add_argument('--inner_size', type=int, default=64)
    parser.add_argument('--batch_size', type=int, default=512)
    parser.add_argument('--hidden_dropout_prob', type=float, default=0.5)
    parser.add_argument('--attention_probs_dropout_prob', type=float, default=0.5)
    parser.add_argument('--num_hidden_layers', type=int, default=2)
    parser.add_argument('--num_attention_heads', type=int, default=2)
    parser.add_argument('--hidden_act', type=str, default='gelu')
    parser.add_argument('--lr', type=float, default=0.001, help='')
    parser.add_argument('--weight_decay', type=float, default=0.000, help='')
    parser.add_argument('--max_iter', type=int, default=100, help='(k)')
    parser.add_argument('--maxlen', type=int, default=50)
    parser.add_argument('--best_ckpt_path', type=str, default='runs/',
                        help='the directory to save the checkpoint')
    parser.add_argument('--log_dir', type=str, default='log_debug', help='the directory of logs')
    parser.add_argument('--loss_type', type=str, default='BCE', help='CE | BCE')

    parser.add_argument('--test_epoch', type=int, default=1)
    parser.add_argument('--patience', type=int, default=40)
    parser.add_argument('--max_epoch', type=int, default=500)

    return parser.parse_args()


def main():
    # initial config and seed
    config = get_args()

    os.environ['CUDA_VISIBLE_DEVICES'] = str(config.gpu_id)
    os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
    SEED = config.random_seed
    setup_seed(SEED)

    if not os.path.exists(config.log_dir):
        os.mkdir(config.log_dir)
    config.log_dir += '/' + datetime.now().strftime('%m%d')
    if not os.path.exists(config.log_dir):
        os.mkdir(config.log_dir)
    # config.log_dir = os.path.join(config.log_dir, datetime.now().strftime('%m%d'))

    filename = '{}_{}_{}_{}_{}_{}'.format(config.dataset, config.model, config.batch_size, config.aug_type, config.cl_embs,
                                          datetime.fromtimestamp(time()).strftime('%m%d%H%M'))

    config.best_ckpt_path += filename
    if not os.path.exists('runs_tensorboard'): os.mkdir('runs_tensorboard')
    if not os.path.exists('runs'): os.mkdir('runs')

    logging.basicConfig(format="%(asctime)s %(name)s:%(levelname)s:%(message)s",
                        filename='{}/{}.log'.format(config.log_dir, filename),
                        level=logging.INFO)
    logger = logging.getLogger(__name__)
    config.logger = logger

    # initial dataset
    train_data = TrainData(config)
    valid_data = TestData(config, is_valid=True)
    test_data = TestData(config, is_valid=False)
    config.n_item, config.n_user = train_data.num_items + 1, train_data.num_users + 1

    writer = SummaryWriter('runs_tensorboard/{}'.format(filename))

    logger.info("--------------------Configure Info:------------")
    for arg in vars(config):
        logger.info(f"{arg} : {getattr(config, arg)}")

    # initial model
    model = trans_to_cuda(CL4SRec(config))

    # initial trainer
    trainer = Trainer()

    # ------------------train------------------------------
    best_metrics = [0]
    trials = 0
    best_epoch = 0

    for i in range(config.max_epoch):
        epoch = i + 1
        trainer.train(epoch, writer, model, train_data, config)

        if epoch % config.test_epoch == 0:
            metrics = trainer.eval(epoch, model, config, valid_data, [5, 20], phase='valid')

            if metrics[-1] > best_metrics[-1]:
                best_epoch = epoch
                torch.save(model.state_dict(), config.best_ckpt_path)
                best_metrics = metrics
                trials = 0
            else:
                trials += 1
                # early stopping
                if trials > config.patience:
                    break

    # ------------------test------------------------------
    model.load_state_dict(torch.load(config.best_ckpt_path))
    logger.info('-------------best valid in epoch {:d}-------------'.format(best_epoch))
    trainer.eval(epoch, model, config, valid_data, ks=[5, 20], phase='valid')
    logger.info('------------test-----------------')
    trainer.eval(epoch, model, config, test_data, ks=[5, 20], phase='test')

if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from layers import TransformerEncoder
from utils import *

class CL4SRec(nn.Module):
    def __init__(self, config):
        super(CL4SRec, self).__init__()

        self.n_layers = config.num_hidden_layers
        self.n_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size  # same as embedding_size
        self.inner_size = config.inner_size  # the dimensionality in feed-forward layer
        self.hidden_dropout_prob = config.hidden_dropout_prob
        self.attn_dropout_prob = config.attention_probs_dropout_prob
        self.hidden_act = config.hidden_act
        self.layer_norm_eps = 1e-12
        self.batch_size = config.batch_size

        self.initializer_range = 0.02

        self.loss_type = config.loss_type
        self.n_item = config.n_item
        self.max_seq_length = config.maxlen
        self.temp = config.temp

        # define layers and loss
        self.item_embeddings = nn.Embedding(self.n_item, self.hidden_size, padding_idx=0)

        self.position_embedding = nn.Embedding(self.max_seq_length, self.hidden_size)
        self.trm_encoder = TransformerEncoder(
            n_layers=self.n_layers,
            n_heads=self.n_heads,
            hidden_size=self.hidden_size,
            inner_size=self.inner_size,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attn_dropout_prob=self.attn_dropout_prob,
            hidden_act=self.hidden_act,
            layer_norm_eps=self.layer_norm_eps
        )

        self.LayerNorm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
        self.dropout = nn.Dropout(self.hidden_dropout_prob)

        self.cl_loss = ContrastiveLoss(config)

        self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr, weight_decay=config.weight_decay)  # weight_decay is the L2 regularization coefficient
        # parameters initialization
        self.apply(self._init_weights)
        if self.loss_type == 'BCE':
            self.criterion = nn.BCELoss(reduction='mean')
        elif self.loss_type == 'MCE' or self.loss_type == 'CE':
            self.criterion = nn.CrossEntropyLoss()
        else:
            raise NotImplementedError("Make sure 'loss_type' is in ['BCE', 'CE', 'MCE']!")

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def get_attention_mask(self, item_seq):
        """Generate left-to-right uni-directional attention mask for multi-head attention."""
        # [B, seq_len]
        attention_mask = (item_seq > 0).long()

        # [B, 1, 1, seq_len]
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)  # torch.int64
        # mask for left-to-right unidirectional
        max_len = attention_mask.size(-1)
        attn_shape = (1, max_len, max_len)
        subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1)  # torch.uint8
        subsequent_mask = (subsequent_mask == 0).unsqueeze(1)
        subsequent_mask = subsequent_mask.long().to(item_seq.device)

        extended_attention_mask = extended_attention_mask * subsequent_mask
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

    def forward(self, item_seq, item_seq_len, phase='predict'):
        # position embedding
        position_ids = torch.arange(item_seq.size(1), dtype=torch.long, device=item_seq.device)
        position_ids = position_ids.unsqueeze(0).expand_as(item_seq)
        position_embedding = self.position_embedding(position_ids)
        # item embedding
        item_emb = self.item_embeddings(item_seq)
        input_emb = item_emb + position_embedding

        input_emb = self.LayerNorm(input_emb)
        input_emb = self.dropout(input_emb)

        extended_attention_mask = self.get_attention_mask(item_seq)

        trm_output = self.trm_encoder(input_emb, extended_attention_mask, output_all_encoded_layers=True)
        output = trm_output[-1]

        if phase == 'predict':
            # keep only the hidden state at the last valid position
            output = self.gather_indexes(output, item_seq_len - 1)

        return output  # [B, H] if phase == 'predict' else [B, L, H]

    def cross_entropy(self, seq_out, poss, negs):
        # [batch seq_len hidden_size]
        seq_len = poss.shape[1]
        pos_emb = self.item_embeddings(poss)
        neg_emb = self.item_embeddings(negs)
        # [batch*seq_len hidden_size]
        pos = pos_emb.view(-1, pos_emb.size(2))
        neg = neg_emb.view(-1, neg_emb.size(2))
        seq_emb = seq_out.view(-1, self.hidden_size)  # [batch*seq_len hidden_size]
        pos_logits = torch.sum(pos * seq_emb, -1)  # [batch*seq_len]
        neg_logits = torch.sum(neg * seq_emb, -1)  # [batch*seq_len]
        istarget = (poss > 0).view(poss.size(0) * seq_len).float()  # [batch*seq_len]
        loss = torch.sum(
            -torch.log(torch.sigmoid(pos_logits) + 1e-24) * istarget
            - torch.log(1 - torch.sigmoid(neg_logits) + 1e-24) * istarget
        ) / torch.sum(istarget)

        return loss

    def cal_loss(self, input_ids, poss, negs, lens):
        if self.loss_type == 'BCE':
            seq_embs = self.forward(input_ids, lens, phase='concat')
            loss = self.cross_entropy(seq_embs, poss, negs)
            return loss
        raise NotImplementedError("cal_loss is only implemented for loss_type 'BCE'")

    def full_sort_predict(self, item_seq, item_seq_len):
        seq_output = self.forward(item_seq, item_seq_len)
        test_items_emb = self.item_embeddings.weight
        scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1))  # [B n_items]
        return scores

    def gather_indexes(self, output, gather_index):
        """Gathers the vectors at the specific positions over a minibatch"""
        gather_index = gather_index.view(-1, 1, 1).expand(-1, -1, output.shape[-1])
        output_tensor = output.gather(dim=1, index=gather_index)
        return output_tensor.squeeze(1)
--------------------------------------------------------------------------------
/data_iterator.py:
--------------------------------------------------------------------------------
from math import ceil
import numpy as np
from random import shuffle
import copy
from augmentation import *
import logging


class Data(object):
    def __init__(self, data_name, max_len, logger):
        self.logger = logger

        file = f"./data/{data_name}.txt"

        self.max_len = max_len

        self.pos_seqs, self.num_users, self.num_items = self.read_file(file=file)

        # repeat filtering until the number of users and items stabilizes
        temp_num = self.renumber()
        re_num = self.renumber()
        while re_num != temp_num:
            temp_num = re_num
            re_num = self.renumber()

        self.num_users, self.num_items = len(self.pos_seqs), re_num - len(self.pos_seqs)

        inters = sum([len(seq) for seq in self.pos_seqs])
        logger.info("-------------after renumber------------")
        logger.info('users:' + str(self.num_users))
        logger.info('items:' + str(self.num_items))
        logger.info('average length:' + str(inters / self.num_users))
        logger.info("data density:" + str(inters / self.num_users / self.num_items))

    def renumber(self, min_freq=5):
        # drop short sequences, drop rare items, and remap the item ids to 1..n
        lens = [len(seq) for seq in self.pos_seqs]
        self.pos_seqs = [seq for seq, seqlen in zip(self.pos_seqs, lens) if seqlen >= min_freq]
        his_seqs = self.pos_seqs.copy()
        item_cnt = dict()
        for seq in his_seqs:
            for item in seq:
                if item in item_cnt.keys():
                    item_cnt[item] += 1
                else:
                    item_cnt[item] = 1
        item_set = set()
        for item in item_cnt.keys():
            if item_cnt[item] >= min_freq:
                item_set.add(item)

        iid_nbr_dict = {iid: idx for idx, iid in enumerate(sorted(list(item_set)))}
        self.pos_seqs = [[iid_nbr_dict[iid] + 1 for iid in seq if iid in item_set] for seq in
                         self.pos_seqs]

        return len(item_set) + len(self.pos_seqs)

    def read_file(self, file):
        max_item = 0
        max_uid = 0
        pos_seqs = []
        len_list = []
        with open(file, 'r') as f:
            for line in f:
                inter = line.split()
                uid = int(inter[0])
                seq = [int(item) for item in inter[1:]]
                max_item = max(max_item, max(seq))
                max_uid = max(max_uid, uid)
                len_list.append(len(seq))
                pos_seqs.append(seq)

        self.logger.info("-------raw data-----")
        self.logger.info('users:' + str(max_uid))
        self.logger.info('items:' + str(max_item))
        self.logger.info('average length:' + str(sum(len_list) / len(len_list)))

        return pos_seqs, max_uid, max_item


class TrainData(Data):
    def __init__(self, config):
        super(TrainData, self).__init__(config.dataset, config.maxlen, config.logger)
        logger = config.logger
        self.batch_size = config.batch_size
        self.cl_type = config.aug_type
        # hold out the last two items of every sequence for validation and test
        self.pos_seqs = [seq[:-2] for seq in self.pos_seqs]
        self.is_hard = config.is_hard

        self.num_users = len(self.pos_seqs)

        inters = sum([len(seq) for seq in self.pos_seqs])
        logger.info("-------------for training------------")
        logger.info('users:' + str(self.num_users))
        logger.info('items:' + str(self.num_items))
        logger.info('average length:' + str(inters / self.num_users))
        logger.info("data density:" + str(inters / self.num_users / self.num_items))

        self.n_batch = len(self.pos_seqs) // self.batch_size
        self.curr_batch = 0
        if len(self.pos_seqs) % self.batch_size:
            self.n_batch += 1

        if self.is_hard:
            self.base_transform = {'mask': Mask(config.mask_p),
                                   'reorder': Reorder(config.reorder_p),
                                   'crop': Crop(config.crop_p)}
        else:
            self.base_transform = {'mask': Mask(config.mask_p, is_hard=False),
                                   'reorder': Reorder(config.reorder_p, is_hard=False),
                                   'crop': Crop(config.crop_p)}

        self.transform_map = {'m': 'mask', 'r': 'reorder', 'c': 'crop'}

    def __iter__(self):
        return self

    def __next__(self):
        if self.curr_batch >= self.n_batch:
            shuffle(self.pos_seqs)
            self.curr_batch = 0
            raise StopIteration

        idxs = np.arange(self.curr_batch * self.batch_size,
                         min(len(self.pos_seqs), (self.curr_batch + 1) * self.batch_size))

        in_seqs, out_seqs, out_negs = [], [], []
        for idx in idxs:
            seq = self.pos_seqs[idx].copy()

            in_seq = seq[:-1]
            out_seq = seq[1:]
            negs = []

            # sample one negative item per position, avoiding items in the sequence
            for _ in in_seq:
                neg = np.random.randint(1, self.num_items + 1)
                while neg in seq:
                    neg = np.random.randint(1, self.num_items + 1)
                negs.append(neg)

            in_seqs.append(in_seq)
            out_seqs.append(out_seq)
            out_negs.append(negs)

        lens = [len(seq) for seq in in_seqs]
        max_len = min(self.max_len, max(lens))

        lens = [l if l <= max_len else max_len for l in lens]

        seqs = [seq + [0] * (max_len - len(seq)) if len(seq) <= max_len else seq[-max_len:] for seq in in_seqs]
        poss = [pos + [0] * (max_len - len(pos)) if len(pos) <= max_len else pos[-max_len:] for pos in out_seqs]
        negs = [neg + [0] * (max_len - len(neg)) if len(neg) <= max_len else neg[-max_len:] for neg in out_negs]

        # two independently augmented views of each input sequence
        in_seqs1, in_seqs2 = self.augment(in_seqs)
        aug_seqs1, lens1 = self.seqs_pad(in_seqs1)
        aug_seqs2, lens2 = self.seqs_pad(in_seqs2)
        aug_seqs = [aug_seqs1, aug_seqs2]
        aug_lens = [lens1, lens2]

        self.curr_batch += 1

        return seqs, poss, negs, lens, aug_seqs, aug_lens

    def augment(self, in_seqs):
        seqs1, seqs2 = [], []
        if self.cl_type in self.base_transform.keys():
            aug = self.base_transform[self.cl_type]
            for seq in in_seqs:
                seq1, seq2 = aug(seq), aug(seq)
                seqs1.append(seq1)
                seqs2.append(seq2)
        else:
            # self.cl_type is a combination of the letters m/c/r, e.g. 'mcr';
            # sample one operator per view and per sequence
            transform_list = []
            for t in self.cl_type:
                transform_list.append(self.transform_map[t])
            for seq in in_seqs:
                aug_method = np.random.choice(transform_list, size=2, replace=True)
                aug1, aug2 = self.base_transform[aug_method[0]], self.base_transform[aug_method[1]]
                seq1, seq2 = aug1(seq), aug2(seq)
                seqs1.append(seq1)
                seqs2.append(seq2)

        return seqs1, seqs2

    def seqs_pad(self, seqs):
        lens = [len(seq) for seq in seqs]
        max_len = self.max_len

        lens = [l if l <= max_len else max_len for l in lens]
        seqs = [seq + [0] * (max_len - len(seq)) if len(seq) <= max_len else seq[-max_len:] for seq in seqs]
        return seqs, lens


class TestData(Data):
    def __init__(self, config, is_valid):
        super(TestData, self).__init__(config.dataset, config.maxlen, config.logger)
        self.batch_size = config.batch_size
        self.n_batch = len(self.pos_seqs) // config.batch_size
        if len(self.pos_seqs) % self.batch_size:
            self.n_batch += 1

        if is_valid:
            # for validation, drop the test item (the last one) from every sequence
            self.pos_seqs = [seq[:-1] for seq in self.pos_seqs]
        self.curr_batch = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.curr_batch >= self.n_batch:
            self.curr_batch = 0
            raise StopIteration

        idxs = np.arange(self.curr_batch * self.batch_size,
                         min(len(self.pos_seqs), (self.curr_batch + 1) * self.batch_size))

        seqs, tars = [], []
        for idx in idxs:
            seq = self.pos_seqs[idx]
            seq_in = seq[:-1]
            seqs.append(seq_in)
            tars.append(seq[-1])

        max_len = min(self.max_len, max(list(map(len, seqs))))
        lens = [len(seq) if len(seq) <= max_len else max_len for seq in seqs]
        # pad or truncate every sequence to the same batch length
        seqs = [seq + [0] * (max_len - len(seq)) if len(seq) <= max_len else seq[-max_len:] for seq in seqs]
        self.curr_batch += 1

        return seqs, tars, lens
--------------------------------------------------------------------------------
/layers.py:
--------------------------------------------------------------------------------
"""
recbole.model.layers
#############################
Common Layers in recommender system
"""

import copy
import math

import torch
import torch.nn as nn
import torch.nn.functional as fn
from torch.nn.init import normal_


class MLPLayers(nn.Module):
    r""" MLPLayers

    Args:
        - layers(list): a list contains the size of each layer in mlp layers
        - dropout(float): probability of an element to be zeroed. Default: 0
        - activation(str): activation function after each layer in mlp layers. Default: 'relu'.
          candidates: 'sigmoid', 'tanh', 'relu', 'leakyrelu', 'none'

    Shape:

        - Input: (:math:`N`, \*, :math:`H_{in}`) where \* means any number of additional dimensions
          :math:`H_{in}` must equal to the first value in `layers`
        - Output: (:math:`N`, \*, :math:`H_{out}`) where :math:`H_{out}` equals to the last value in `layers`

    Examples::

        >>> m = MLPLayers([64, 32, 16], 0.2, 'relu')
        >>> input = torch.randn(128, 64)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 16])
    """

    def __init__(self, layers, dropout=0., activation='relu', bn=False, init_method=None):
        super(MLPLayers, self).__init__()
        self.layers = layers
        self.dropout = dropout
        self.activation = activation
        self.use_bn = bn
        self.init_method = init_method

        mlp_modules = []
        for idx, (input_size, output_size) in enumerate(zip(self.layers[:-1], self.layers[1:])):
            mlp_modules.append(nn.Dropout(p=self.dropout))
            mlp_modules.append(nn.Linear(input_size, output_size))
            if self.use_bn:
                mlp_modules.append(nn.BatchNorm1d(num_features=output_size))
            activation_func = activation_layer(self.activation, output_size)
            if activation_func is not None:
                mlp_modules.append(activation_func)

        self.mlp_layers = nn.Sequential(*mlp_modules)
        if self.init_method is not None:
            self.apply(self.init_weights)

    def init_weights(self, module):
        # We just initialize the module with normal distribution as the paper said
        if isinstance(module, nn.Linear):
            if self.init_method == 'norm':
                normal_(module.weight.data, 0, 0.01)
            if module.bias is not None:
                module.bias.data.fill_(0.0)

    def forward(self, input_feature):
        return self.mlp_layers(input_feature)


def activation_layer(activation_name='relu', emb_dim=None):
    """Construct activation layers

    Args:
        activation_name: str, name of activation function
        emb_dim: int, used for Dice activation

    Return:
        activation: activation layer
    """
    if activation_name is None:
        activation = None
    elif isinstance(activation_name, str):
        if activation_name.lower() == 'sigmoid':
            activation = nn.Sigmoid()
        elif activation_name.lower() == 'tanh':
            activation = nn.Tanh()
        elif activation_name.lower() == 'relu':
            activation = nn.ReLU()
        elif activation_name.lower() == 'leakyrelu':
            activation = nn.LeakyReLU()
        elif activation_name.lower() == 'dice':
            activation = Dice(emb_dim)
        elif activation_name.lower() == 'none':
            activation = None
    elif issubclass(activation_name, nn.Module):
        activation = activation_name()
    else:
        raise NotImplementedError("activation function {} is not implemented".format(activation_name))

    return activation


class VanillaAttention(nn.Module):
    """
    Vanilla attention layer is implemented by linear layer.

    Args:
        input_tensor (torch.Tensor): the input of the attention layer

    Returns:
        hidden_states (torch.Tensor): the outputs of the attention layer
        weights (torch.Tensor): the attention weights

    """

    def __init__(self, hidden_dim, attn_dim):
        super().__init__()
        self.projection = nn.Sequential(nn.Linear(hidden_dim, attn_dim), nn.ReLU(True), nn.Linear(attn_dim, 1))

    def forward(self, input_tensor):
        # (B, Len, num, H) -> (B, Len, num, 1)
        energy = self.projection(input_tensor)
        weights = torch.softmax(energy.squeeze(-1), dim=-1)
        # (B, Len, num, H) * (B, Len, num, 1) -> (B, len, H)
        hidden_states = (input_tensor * weights.unsqueeze(-1)).sum(dim=-2)
        return hidden_states, weights


class MultiHeadAttention(nn.Module):
    """
    Multi-head Self-attention layers, an attention score dropout layer is introduced.

    Args:
        input_tensor (torch.Tensor): the input of the multi-head self-attention layer
        attention_mask (torch.Tensor): the attention mask for input tensor

    Returns:
        hidden_states (torch.Tensor): the output of the multi-head self-attention layer

    """

    def __init__(self, n_heads, hidden_size, hidden_dropout_prob, attn_dropout_prob, layer_norm_eps):
        super(MultiHeadAttention, self).__init__()
        if hidden_size % n_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (hidden_size, n_heads)
            )

        self.num_attention_heads = n_heads
        self.attention_head_size = int(hidden_size / n_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)

        self.attn_dropout = nn.Dropout(attn_dropout_prob)

        self.dense = nn.Linear(hidden_size, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.out_dropout = nn.Dropout(hidden_dropout_prob)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        # [B, seq_len, self.num_attention_heads, self.attention_head_size]
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, input_tensor, attention_mask):
        # [B, seq_len, hidden_size]
        mixed_query_layer = self.query(input_tensor)
        mixed_key_layer = self.key(input_tensor)
        mixed_value_layer = self.value(input_tensor)
        # [B, seq_len, self.all_head_size]
        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)
        # Q [B, self.num_attention_heads, seq_len, self.attention_head_size]
        # K [B, self.num_attention_heads, self.attention_head_size, seq_len]
        # compute the attention score for every position in the sequence
        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # Apply the attention mask (precomputed for all layers in the model's forward() function)
        # [B, self.num_attention_heads, seq_len, seq_len]
        attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.attn_dropout(attention_probs)
        # W [B, self.num_attention_heads, seq_len, seq_len]
        # V [B, self.num_attention_heads, seq_len, self.attention_head_size]
        context_layer = torch.matmul(attention_probs, value_layer)
        # [B, seq_len, self.num_attention_heads, self.attention_head_size]
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        # [B, seq_len, self.all_head_size]
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        hidden_states = self.dense(context_layer)
        hidden_states = self.out_dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)

        return hidden_states


class FeedForward(nn.Module):
    """
    Point-wise feed-forward layer is implemented by two dense layers.

    Args:
        input_tensor (torch.Tensor): the input of the point-wise feed-forward layer

    Returns:
        hidden_states (torch.Tensor): the output of the point-wise feed-forward layer

    """

    def __init__(self, hidden_size, inner_size, hidden_dropout_prob, hidden_act, layer_norm_eps):
        super(FeedForward, self).__init__()
        self.dense_1 = nn.Linear(hidden_size, inner_size)
        self.intermediate_act_fn = self.get_hidden_act(hidden_act)

        self.dense_2 = nn.Linear(inner_size, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def get_hidden_act(self, act):
        ACT2FN = {
            "gelu": self.gelu,
            "relu": fn.relu,
            "swish": self.swish,
            "tanh": torch.tanh,
            "sigmoid": torch.sigmoid,
        }
        return ACT2FN[act]

    def gelu(self, x):
        """Implementation of the gelu activation function.

        For information: OpenAI GPT's gelu is slightly different (and gives slightly different results)::

            0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

        Also see https://arxiv.org/abs/1606.08415
        """
        return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

    def swish(self, x):
        return x * torch.sigmoid(x)

    def forward(self, input_tensor):
        hidden_states = self.dense_1(input_tensor)
        hidden_states = self.intermediate_act_fn(hidden_states)

        hidden_states = self.dense_2(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)

        return hidden_states


class TransformerLayer(nn.Module):
    """
    One transformer layer consists of a multi-head self-attention layer and a point-wise feed-forward layer.

    Args:
        hidden_states (torch.Tensor): the input of the multi-head self-attention sublayer
        attention_mask (torch.Tensor): the attention mask for the multi-head self-attention sublayer

    Returns:
        feedforward_output (torch.Tensor): The output of the point-wise feed-forward sublayer,
                                           is the output of the transformer layer.

    """

    def __init__(
        self, n_heads, hidden_size, intermediate_size, hidden_dropout_prob, attn_dropout_prob, hidden_act,
        layer_norm_eps
    ):
        super(TransformerLayer, self).__init__()
        self.multi_head_attention = MultiHeadAttention(
            n_heads, hidden_size, hidden_dropout_prob, attn_dropout_prob, layer_norm_eps
        )
        self.feed_forward = FeedForward(hidden_size, intermediate_size, hidden_dropout_prob, hidden_act, layer_norm_eps)

    def forward(self, hidden_states, attention_mask):
        attention_output = self.multi_head_attention(hidden_states, attention_mask)
        feedforward_output = self.feed_forward(attention_output)
        return feedforward_output


class TransformerEncoder(nn.Module):
    r""" One TransformerEncoder consists of several TransformerLayers.

        - n_layers(num): num of transformer layers in transformer encoder. Default: 2
        - n_heads(num): num of attention heads for multi-head attention layer. Default: 2
        - hidden_size(num): the input and output hidden size. Default: 64
        - inner_size(num): the dimensionality in feed-forward layer. Default: 256
        - hidden_dropout_prob(float): probability of an element to be zeroed. Default: 0.5
        - attn_dropout_prob(float): probability of an attention score to be zeroed. Default: 0.5
        - hidden_act(str): activation function in feed-forward layer. Default: 'gelu'
          candidates: 'gelu', 'relu', 'swish', 'tanh', 'sigmoid'
        - layer_norm_eps(float): a value added to the denominator for numerical stability. Default: 1e-12
        - noise_eps(float): the scale of the random perturbation injected when ``perturbed=True``. Default: 0.1

    """

    def __init__(
        self,
        n_layers=2,
        n_heads=2,
        hidden_size=64,
        inner_size=256,
        hidden_dropout_prob=0.5,
        attn_dropout_prob=0.5,
        hidden_act='gelu',
        layer_norm_eps=1e-12,
        noise_eps=0.1
    ):

        super(TransformerEncoder, self).__init__()
        layer = TransformerLayer(
            n_heads, hidden_size, inner_size, hidden_dropout_prob, attn_dropout_prob, hidden_act, layer_norm_eps
        )
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layers)])
        self.noise_eps = noise_eps

    def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, perturbed=False):
        """
        Args:
            hidden_states (torch.Tensor): the input of the TransformerEncoder
            attention_mask (torch.Tensor): the attention mask for the input hidden_states
            output_all_encoded_layers (Bool): whether output all transformer layers' output
            perturbed (Bool): whether to inject random noise into the hidden states before each layer

        Returns:
            all_encoder_layers (list): if output_all_encoded_layers is True, return a list consists of all transformer
                layers' output, otherwise return a list only consists of the output of last transformer layer.

        """
        all_encoder_layers = []
        for layer_module in self.layer:
            if perturbed:
                # add a small random perturbation in the sign direction of the hidden states;
                # rand_like already allocates on the input's device, so no explicit .cuda() is needed
                random_noise = torch.rand_like(hidden_states)
                hidden_states += torch.sign(hidden_states) * torch.nn.functional.normalize(random_noise, dim=-1) * self.noise_eps
            hidden_states = layer_module(hidden_states, attention_mask)
            if output_all_encoded_layers:
                all_encoder_layers.append(hidden_states)
        if not output_all_encoded_layers:
            all_encoder_layers.append(hidden_states)
        return all_encoder_layers
--------------------------------------------------------------------------------
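
The three augmentation operators in augmentation.py can be tried standalone. A minimal sketch, assuming augmentation.py is importable (the toy sequence and seed are illustrative only; outputs vary with the seed):

```python
import numpy as np
from augmentation import Mask, Reorder, Crop

np.random.seed(0)  # illustrative seed
seq = [3, 7, 2, 9, 5, 1, 8, 4]

print(Mask(0.5)(seq))     # hard mask: floor(len * p) positions replaced by the padding id 0
print(Reorder(0.2)(seq))  # hard reorder: a random contiguous ~20% span is shuffled in place
print(Crop(0.6)(seq))     # crop: a random contiguous ~60% subsequence is kept
```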
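
trainer.py reduces the encoder output to one vector per sequence before the contrastive loss, depending on `--cl_embs`. A shape-only sketch of the three strategies, where the random tensors stand in for the outputs of `model(aug_seqs, aug_lens, phase=...)`:

```python
import torch

B, L, H = 4, 50, 64
hidden = torch.randn(B, L, H)  # phase='concat' or 'mean': per-position hidden states
last = torch.randn(B, H)       # phase='predict': hidden state at the last valid position

concat_emb = hidden.view(B, -1)  # 'concat'  -> [B, L*H]
mean_emb = hidden.mean(-2)       # 'mean'    -> [B, H]
predict_emb = last               # 'predict' -> [B, H]
print(concat_emb.shape, mean_emb.shape, predict_emb.shape)
```

'concat' keeps the most information but compares L*H-dimensional vectors, while 'mean' and 'predict' compare H-dimensional summaries.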
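
ContrastiveLoss in utils.py builds, for each of the 2B view embeddings, one positive logit followed by 2B - 2 in-batch negative logits, then applies cross entropy with the positive in column 0. A minimal CPU sanity check, assuming utils.py is importable; `SimpleNamespace` here merely stands in for the parsed argparse config:

```python
import torch
from types import SimpleNamespace
from utils import ContrastiveLoss

config = SimpleNamespace(temp=1.0, batch_size=4, maxlen=50)
loss_fn = ContrastiveLoss(config)

B, H = 4, 64
z1, z2 = torch.randn(B, H), torch.randn(B, H)  # two augmented views of the same batch
loss = loss_fn(z1, z2)  # the internal assert checks score.shape[-1] == 2 * B - 1
print(loss.item())
```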