├── remi.png ├── src ├── test.py ├── DNN.py ├── Pop.py ├── MIND.py ├── GRU4Rec.py ├── REMI.py ├── train.py ├── ComiRec.py ├── utils.py ├── evalution.py └── BasicModel.py ├── env.yaml ├── README.md └── process └── data.py /remi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tokkiu/REMI/HEAD/remi.png -------------------------------------------------------------------------------- /src/test.py: -------------------------------------------------------------------------------- 1 | if __name__ == '__main__': 2 | for i in range(5): 3 | try: 4 | print(i) 5 | break 6 | except Exception as e: 7 | print(i, ',,') -------------------------------------------------------------------------------- /src/DNN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from BasicModel import BasicModel 6 | 7 | 8 | class DNN(BasicModel): 9 | 10 | def __init__(self, item_num, hidden_size, batch_size, seq_len=50): 11 | super(DNN, self).__init__(item_num, hidden_size, batch_size, seq_len) 12 | self.linear = nn.Linear(self.hidden_size, self.hidden_size, bias=False) 13 | self.relu = nn.ReLU() 14 | self.reset_parameters() 15 | 16 | 17 | def forward(self, item_list, label_list, mask, times, device, train=True): 18 | # mask: [b, s] 19 | mask = torch.unsqueeze(mask, -1) # [b, s, 1] 20 | item_eb = self.embeddings(item_list) # [b, s, h] 21 | item_eb_mean = torch.sum(item_eb, dim=1) / (torch.sum(mask, dim=1, dtype=torch.float) + 1e-9) # [b, h] 22 | user_eb = self.linear(item_eb_mean) 23 | # todo check this back 24 | # user_eb = self.relu(user_eb) # [b,h] 25 | 26 | scores = self.calculate_score(user_eb) 27 | 28 | return user_eb, scores 29 | -------------------------------------------------------------------------------- /src/Pop.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from BasicModel import BasicModel 6 | 7 | 8 | class Pop(BasicModel): 9 | 10 | def __init__(self, item_num, hidden_size, batch_size, seq_len=50, device=None): 11 | super(Pop, self).__init__(item_num, hidden_size, batch_size, seq_len) 12 | self.name = 'Pop' 13 | self.item_cnt = torch.zeros(item_num, 1, dtype=torch.long, device=device, requires_grad=False) 14 | self.max_cnt = None 15 | self.fake_loss = torch.nn.Parameter(torch.zeros(1)) 16 | self.other_parameter_name = ['item_cnt', 'max_cnt'] 17 | 18 | 19 | def forward(self, item_list, label_list, mask, times, device, train=True): 20 | pass 21 | 22 | def calculate_loss(self, item): 23 | self.item_cnt[item, :] = self.item_cnt[item, :] + 1 24 | 25 | self.max_cnt = torch.max(self.item_cnt, dim=0)[0] 26 | 27 | return torch.nn.Parameter(torch.zeros(1)) 28 | 29 | def predict(self, item): 30 | result = torch.true_divide(self.item_cnt[item, :], self.max_cnt) 31 | return result.squeeze(-1) 32 | 33 | def full_sort_predict(self, batch_user_num): 34 | result = self.item_cnt.to(torch.float64) / self.max_cnt.to(torch.float64) 35 | result = torch.repeat_interleave(result.unsqueeze(0), batch_user_num, dim=0) 36 | return result.view(-1) -------------------------------------------------------------------------------- /src/MIND.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from BasicModel import BasicModel, CapsuleNetwork 6 | 7 | 8 | class MIND(BasicModel): 9 | 10 | def __init__(self, item_num, hidden_size, batch_size, interest_num=4, seq_len=50, routing_times=3, relu_layer=True): 11 | super(MIND, self).__init__(item_num, hidden_size, batch_size, seq_len) 12 | self.interest_num = interest_num 13 | self.routing_times = routing_times 14 | self.hard_readout = True 15 | self.capsule_network = CapsuleNetwork(self.hidden_size, self.seq_len, bilinear_type=0, interest_num=self.interest_num, 16 | routing_times=self.routing_times, hard_readout=self.hard_readout, relu_layer=relu_layer) 17 | self.reset_parameters() 18 | 19 | 20 | def forward(self, item_list, label_list, mask, times, device, train=True): 21 | 22 | item_eb = self.embeddings(item_list) 23 | item_eb = item_eb * torch.reshape(mask, (-1, self.seq_len, 1)) 24 | if train: 25 | label_eb = self.embeddings(label_list) 26 | user_eb = self.capsule_network(item_eb, mask, device) 27 | 28 | if not train: 29 | return user_eb, None 30 | 31 | readout, selection = self.read_out(user_eb, label_eb) 32 | scores = self.calculate_score(readout) 33 | 34 | return user_eb, scores, readout, selection 35 | -------------------------------------------------------------------------------- /src/GRU4Rec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from BasicModel import BasicModel 6 | 7 | 8 | class GRU4Rec(BasicModel): 9 | 10 | def __init__(self, item_num, hidden_size, batch_size, seq_len=50, num_layers=3, dropout=0.1): 11 | super(GRU4Rec, self).__init__(item_num, hidden_size, batch_size, seq_len) 12 | 13 | self.gru = nn.GRU( 14 | input_size = self.hidden_size, 15 | hidden_size = self.hidden_size*2, 16 | num_layers=num_layers, 17 | batch_first=True, 18 | bias=False, 19 | ) 20 | self.dense = nn.Linear(hidden_size*2, hidden_size) 21 | self.emb_dropout = nn.Dropout(dropout) 22 | 23 | self.apply(self._init_weights) 24 | 25 | 26 | 27 | def forward(self, item_list, label_list, mask, times, device, train=True): 28 | 29 | item_eb = self.embeddings(item_list) # [b, s, h] 30 | item_seq_emb_dropout = self.emb_dropout(item_eb) 31 | 32 | output, fin_state = self.gru(item_seq_emb_dropout) # [b, s, h], [num_layers, b, h] 33 | # user_eb = fin_state[-1] 34 | # scores = self.calculate_score(user_eb) 35 | item_len_list = mask.sum(dim=1) 36 | # print('log item', item_list[:3]) 37 | # print('log len', item_len_list[:3]) 38 | gru_output = self.dense(output) 39 | # the embedding of the predicted item, shape of (batch_size, embedding_size) 40 | user_eb = self.gather_indexes(gru_output, item_len_list - 1) 41 | scores = self.calculate_score(user_eb) 42 | 43 | return user_eb, scores 44 | 45 | def gather_indexes(self, output, gather_index): 46 | """Gathers the vectors at the specific positions over a minibatch""" 47 | gather_index = gather_index.view(-1, 1, 1).expand(-1, -1, output.shape[-1]) 48 | output_tensor = output.gather(dim=1, index=gather_index) 49 | return output_tensor.squeeze(1) 50 | -------------------------------------------------------------------------------- /env.yaml: -------------------------------------------------------------------------------- 1 | name: comitf 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=5.1=1_gnu 8 | - blas=1.0=mkl 9 | - ca-certificates=2022.4.26=h06a4308_0 10 | - certifi=2022.6.15=py37h06a4308_0 11 | - cudatoolkit=11.3.1=h2bc3f7f_2 12 | - faiss-gpu=1.7.2=py3.7_h28a55e0_0_cuda11.3 13 | - intel-openmp=2021.4.0=h06a4308_3561 14 | - libedit=3.1.20210910=h7f8727e_0 15 | - libfaiss=1.7.2=hfc2d529_0_cuda11.3 16 | - libffi=3.2.1=hf484d3e_1007 17 | - libgcc-ng=11.2.0=h1234567_1 18 | - libgomp=11.2.0=h1234567_1 19 | - libstdcxx-ng=11.2.0=h1234567_1 20 | - mkl=2021.4.0=h06a4308_640 21 | - mkl-service=2.4.0=py37h7f8727e_0 22 | - mkl_fft=1.3.1=py37hd3c417c_0 23 | - mkl_random=1.2.2=py37h51133e4_0 24 | - ncurses=6.3=h5eee18b_3 25 | - numpy-base=1.21.5=py37ha15fc14_3 26 | - openssl=1.0.2u=h7b6447c_0 27 | - pip=21.2.2=py37h06a4308_0 28 | - python=3.7.0=h6e4f718_3 29 | - readline=7.0=h7b6447c_5 30 | - setuptools=61.2.0=py37h06a4308_0 31 | - sqlite=3.33.0=h62c20be_0 32 | - tk=8.6.12=h1ccaba5_0 33 | - wheel=0.37.1=pyhd3eb1b0_0 34 | - xz=5.2.5=h7f8727e_1 35 | - zlib=1.2.12=h7f8727e_2 36 | - pip: 37 | - absl-py==0.15.0 38 | - astor==0.8.1 39 | - astunparse==1.6.3 40 | - cached-property==1.5.2 41 | - cachetools==5.2.0 42 | - charset-normalizer==2.1.0 43 | - clang==5.0 44 | - flatbuffers==1.12 45 | - gast==0.2.2 46 | - google-auth==2.9.0 47 | - google-auth-oauthlib==0.4.6 48 | - google-pasta==0.2.0 49 | - grpcio==1.47.0 50 | - h5py==3.1.0 51 | - idna==3.3 52 | - importlib-metadata==4.12.0 53 | - keras==2.9.0 54 | - keras-applications==1.0.8 55 | - keras-preprocessing==1.1.2 56 | - markdown==3.3.7 57 | - mock==4.0.3 58 | - numpy==1.19.5 59 | - oauthlib==3.2.0 60 | - opt-einsum==3.3.0 61 | - pandas==1.1.5 62 | - pillow==9.2.0 63 | - protobuf==3.19.4 64 | - pyasn1==0.4.8 65 | - pyasn1-modules==0.2.8 66 | - python-dateutil==2.8.2 67 | - pytz==2022.1 68 | - requests==2.28.1 69 | - requests-oauthlib==1.3.1 70 | - rsa==4.8 71 | - six==1.15.0 72 | - tensorboard==1.15.0 73 | - tensorboard-data-server==0.6.1 74 | - tensorboard-plugin-wit==1.8.1 75 | - tensorboardx==2.5.1 76 | - tensorflow-estimator==1.15.1 77 | - tensorflow-gpu==1.15.0 78 | - termcolor==1.1.0 79 | - torch==1.8.0+cu111 80 | - torchaudio==0.8.0 81 | - torchvision==0.9.0+cu111 82 | - typing-extensions==3.7.4.3 83 | - urllib3==1.26.10 84 | - werkzeug==2.1.2 85 | - wrapt==1.12.1 86 | - zipp==3.8.0 87 | prefix: /home/jingqi/anaconda3/envs/comitf 88 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # REMI 2 | The source code for our RecSys 2023 Paper [**"Rethinking Multi-Interest Learning for Candidate Matching in Recommender Systems"**](https://arxiv.org/abs/2302.14532) 3 | 4 | ## Overview 5 | We propose the REMI framework, consisting of an Interest-aware Hard Negative mining strategy (IHN) and a Routing Regularization (RR) method. IHN emphasizes interest-aware hard negatives by proposing an ideal sampling distribution and developing a Monte-Carlo strategy for efficient approximation. RR prevents routing collapse by introducing a novel regularization term on the item-to-interest routing matrices. These two components enhance the learned multi-interest representations from both the optimization objective and the composition information. REMI is a general framework that can be readily applied to various existing multi-interest candidate matching methods. Experiments on three real-world datasets show our method can significantly improve state-of-the-art methods with easy implementation and negligible computational overhead. The source code will be released. 6 | 7 | ![avatar](remi.png) 8 | 9 | ## Preparation 10 | 11 | Our code is based on PyTorch 1.7.1 and runnable for both windows and ubuntu server. Required python packages: 12 | 13 | > + numpy 14 | > + torch 15 | > + faiss-gpu 16 | 17 | 18 | ## Dataset 19 | Original links of datasets are: 20 | 21 | http://jmcauley.ucsd.edu/data/amazon/index.html 22 | https://tianchi.aliyun.com/dataset/dataDetail?dataId=649&userId=1 23 | https://github.com/RUCAIBox/RecSysDatasets 24 | 25 | You can run python process/data.py {dataset_name} to preprocess the datasets. 26 | 27 | ## Usage 28 | 29 | ### Prepare data 30 | Make sure `reviews_Books_5.json` file is located in current directory 31 | (Download the raw file at https://cseweb.ucsd.edu/~jmcauley/datasets/amazon_v2 32 | ) 33 | ``` 34 | $ git clone https://github.com/Tokkiu/REMI.git 35 | $ cd REMI 36 | $ python process/data.py book 37 | ``` 38 | 39 | ### Prepare environment 40 | ``` 41 | conda env create --file env.yml 42 | ``` 43 | 44 | ### Train and evaluate 45 | 46 | Run the baseline. 47 | ``` 48 | $ python src/train.py --model_type ComiRec-SA --gpu 0 --dataset book 49 | $ python src/train.py --model_type ComiRec-SA --gpu 0 --dataset gowalla 50 | $ python src/train.py --model_type ComiRec-SA --gpu 0 --dataset rocket 51 | ``` 52 | 53 | Available baselines: 54 | * ComiRec-SA 55 | * ComiRec-DR 56 | * MIND 57 | * Pop 58 | * GRU4Rec 59 | * DNN 60 | 61 | Reproduce the reported result. 62 | ``` 63 | $ python src/train.py --model_type REMI --gpu 0 --dataset book --rlambda 100 --rbeta 10 64 | $ python src/train.py --model_type REMI --gpu 0 --dataset gowalla --rlambda 100 --rbeta 1 65 | $ python src/train.py --model_type REMI --gpu 0 --dataset rocket --rlambda 100 --rbeta 0.1 66 | ``` 67 | 68 | The *--rbeta* is used to activate **IHN** module. (available in whole framework) 69 | The *--rlambda* is used to activate **RR** module. (available in REMI) 70 | 71 | ## Cite 72 | 73 | If you find this repo useful, please cite 74 | ``` 75 | @inproceedings{xie2023rethinking, 76 | title = {Rethinking Multi-Interest Learning for Candidate Matching in Recommender Systems}, 77 | author = {Yueqi Xie and Jingqi Gao and Peilin Zhou and Qichen Ye and Yining Hua and Jaeboum Kim and Fangzhao Wu and Sunghun Kim}, 78 | booktitle = {Proceedings of the 17th ACM Conference on Recommender Systems}, 79 | year = {2023}, 80 | } 81 | ``` 82 | 83 | ## Credit 84 | This repo is based on the following repositories: 85 | * [pytorch_ComiRec](https://github.com/ShiningCosmos/pytorch_ComiRec) 86 | * [ComiRec](https://github.com/THUDM/ComiRec) 87 | * [InfoNCE](https://github.com/Stonesjtu/Pytorch-NCE/) 88 | 89 | ## Contact 90 | Feel free to contact us if there is any question. (YueqiXIE, yxieay@connect.ust.hk; Jingqi Gao, mrgao.ary@gmail.com; Peilin Zhou, zhoupalin@gmail.com; Russell KIM, russellkim@upstage.ai) -------------------------------------------------------------------------------- /src/REMI.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from BasicModel import BasicModel 6 | 7 | class REMI(BasicModel): 8 | 9 | def __init__(self, item_num, hidden_size, batch_size, interest_num=4, seq_len=50, add_pos=True, beta=0, args=None, device=None): 10 | super(REMI, self).__init__(item_num, hidden_size, batch_size, seq_len, beta) 11 | self.interest_num = interest_num 12 | self.num_heads = interest_num 13 | self.interest_num = interest_num 14 | self.hard_readout = True 15 | self.add_pos = add_pos 16 | if self.add_pos: 17 | self.position_embedding = nn.Parameter(torch.Tensor(1, self.seq_len, self.hidden_size)) 18 | self.linear1 = nn.Sequential( 19 | nn.Linear(self.hidden_size, self.hidden_size * 4, bias=False), 20 | nn.Tanh() 21 | ) 22 | self.linear2 = nn.Linear(self.hidden_size * 4, self.num_heads, bias=False) 23 | self.reset_parameters() 24 | 25 | 26 | def forwardLogits(self, item_eb, mask): 27 | item_eb = item_eb * torch.reshape(mask, (-1, self.seq_len, 1)) 28 | item_eb = torch.reshape(item_eb, (-1, self.seq_len, self.hidden_size)) 29 | if self.add_pos: 30 | # 位置嵌入堆叠一个batch,然后与历史物品嵌入相加 31 | item_eb_add_pos = item_eb + self.position_embedding.repeat(item_eb.shape[0], 1, 1) 32 | # item_eb_add_pos = item_eb + self.position_embedding[:, -1, :].repeat(item_eb.shape[0], 1, 1) 33 | else: 34 | item_eb_add_pos = item_eb 35 | 36 | # shape=(batch_size, maxlen, hidden_size*4) 37 | item_hidden = self.linear1(item_eb_add_pos) 38 | # shape=(batch_size, maxlen, num_heads) 39 | item_att_w = self.linear2(item_hidden) 40 | # shape=(batch_size, num_heads, maxlen) 41 | item_att_w = torch.transpose(item_att_w, 2, 1).contiguous() 42 | 43 | 44 | atten_mask = torch.unsqueeze(mask, dim=1).repeat(1, self.num_heads, 1) # shape=(batch_size, num_heads, maxlen) 45 | paddings = torch.ones_like(atten_mask, dtype=torch.float) * (-2 ** 32 + 1) # softmax之后无限接近于0 46 | 47 | # print(item_eb.size(), item_att_w.size(), atten_mask.size(), mask.size()) 48 | 49 | item_att_w = torch.where(torch.eq(atten_mask, 0), paddings, item_att_w) 50 | item_att_w = F.softmax(item_att_w, dim=-1) # 矩阵A,shape=(batch_size, num_heads, maxlen) 51 | return item_att_w 52 | 53 | def forward(self, item_list, label_list, mask, times, device, train=True): 54 | item_eb = self.embeddings(item_list) 55 | item_eb = item_eb * torch.reshape(mask, (-1, self.seq_len, 1)) 56 | if train: 57 | label_eb = self.embeddings(label_list) 58 | 59 | # 历史物品嵌入序列,shape=(batch_size, maxlen, embedding_dim) 60 | item_eb = torch.reshape(item_eb, (-1, self.seq_len, self.hidden_size)) 61 | 62 | if self.add_pos: 63 | # 位置嵌入堆叠一个batch,然后与历史物品嵌入相加 64 | item_eb_add_pos = item_eb + self.position_embedding.repeat(item_eb.shape[0], 1, 1) 65 | else: 66 | item_eb_add_pos = item_eb 67 | 68 | # shape=(batch_size, maxlen, hidden_size*4) 69 | item_hidden = self.linear1(item_eb_add_pos) 70 | # shape=(batch_size, maxlen, num_heads) 71 | item_att_w = self.linear2(item_hidden) 72 | # shape=(batch_size, num_heads, maxlen) 73 | item_att_w = torch.transpose(item_att_w, 2, 1).contiguous() 74 | 75 | atten_mask = torch.unsqueeze(mask, dim=1).repeat(1, self.num_heads, 1) # shape=(batch_size, num_heads, maxlen) 76 | paddings = torch.ones_like(atten_mask, dtype=torch.float) * (-2 ** 32 + 1) # softmax之后无限接近于0 77 | 78 | item_att_w = torch.where(torch.eq(atten_mask, 0), paddings, item_att_w) 79 | item_att_w = F.softmax(item_att_w, dim=-1) # 矩阵A,shape=(batch_size, num_heads, maxlen) 80 | 81 | # interest_emb即论文中的Vu 82 | interest_emb = torch.matmul(item_att_w, # shape=(batch_size, num_heads, maxlen) 83 | item_eb # shape=(batch_size, maxlen, embedding_dim) 84 | ) # shape=(batch_size, num_heads, embedding_dim) 85 | 86 | # 用户多兴趣向量 87 | user_eb = interest_emb # shape=(batch_size, num_heads, embedding_dim) 88 | 89 | if not train: 90 | return user_eb, None 91 | 92 | readout, selection = self.read_out(user_eb, label_eb) 93 | scores = None if self.is_sampler else self.calculate_score(readout) 94 | 95 | return user_eb, scores, item_att_w, readout, selection 96 | 97 | def calculate_atten_loss(self, attention): 98 | C_mean = torch.mean(attention, dim=2, keepdim=True) 99 | C_reg = (attention - C_mean) 100 | # C_reg = C_reg.matmul(C_reg.transpose(1,2)) / self.hidden_size 101 | C_reg = torch.bmm(C_reg, C_reg.transpose(1, 2)) / self.hidden_size 102 | dr = torch.diagonal(C_reg, dim1=-2, dim2=-1) 103 | n2 = torch.norm(dr, dim=(1)) ** 2 104 | return n2.sum() -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | pid = os.getpid() 5 | print('pid:%d' % (pid)) 6 | 7 | # os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4" 8 | 9 | import torch 10 | from utils import get_parser, setup_seed 11 | from evalution import train, test, output 12 | 13 | 14 | if __name__ == '__main__': 15 | print(sys.argv) 16 | parser = get_parser() 17 | args = parser.parse_args() 18 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 19 | args.gpu = '0' 20 | if args.gpu: 21 | device = torch.device("cuda:"+args.gpu if torch.cuda.is_available() else "cpu") 22 | print("use cuda:"+args.gpu if torch.cuda.is_available() else "use cpu, cuda:"+args.gpu+" not available") 23 | else: 24 | device = torch.device("cpu") 25 | print("use cpu") 26 | 27 | SEED = args.random_seed 28 | setup_seed(SEED) 29 | 30 | if args.dataset == 'book': 31 | path = './data/book_data/' 32 | item_count = 367982 + 1 33 | batch_size = 128 34 | seq_len = 20 35 | test_iter = 1000 36 | if args.dataset == 'bookv': 37 | path = './data/bookv_data/' 38 | item_count = 703121 + 1 39 | batch_size = 128 40 | seq_len = 20 41 | test_iter = 1000 42 | # behaviors: 27158711 43 | if args.dataset == 'bookr': 44 | path = './data/bookr_data/' 45 | item_count = 1163015 + 1 46 | batch_size = 128 47 | seq_len = 20 48 | test_iter = 1000 49 | # behaviors: 28723363 50 | if args.dataset == 'gowalla': 51 | path = './data/gowalla_data/' 52 | item_count = 174605 + 1 53 | batch_size = 256 54 | seq_len = 40 55 | test_iter = 1000 56 | if args.dataset == 'gowalla10': 57 | path = './data/gowalla10_data/' 58 | item_count = 57445 + 1 59 | batch_size = 256 60 | seq_len = 40 61 | test_iter = 1000 62 | # behaviors: 2061264 63 | elif args.dataset == 'familyTV': 64 | path = './data/familyTV_data/' 65 | item_count = 867632 + 1 66 | batch_size = 256 67 | seq_len = 30 68 | test_iter = 1000 69 | elif args.dataset == 'kindle': 70 | path = './data/kindle_data/' 71 | item_count = 260154 + 1 72 | batch_size = 128 73 | seq_len = 20 74 | test_iter = 200 75 | elif args.dataset == 'taobao': 76 | batch_size = 256 77 | seq_len = 50 78 | test_iter = 500 79 | path = './data/taobao_data/' 80 | item_count = 1708531 81 | elif args.dataset == 'cloth': 82 | batch_size = 256 83 | seq_len = 20 84 | test_iter = 200 85 | path = './data/cloth_data/' 86 | item_count = 737822 + 1 87 | 88 | elif args.dataset == 'tmall': 89 | batch_size = 256 90 | seq_len = 100 91 | test_iter = 200 92 | path = './data/tmall_data/' 93 | item_count = 946102 + 1 94 | elif args.dataset == 'rocket': 95 | batch_size = 256 96 | seq_len = 20 97 | test_iter = 200 98 | path = './data/rocket_data/' 99 | item_count = 81635 + 1 100 | 101 | train_file = path + args.dataset + '_train.txt' 102 | valid_file = path + args.dataset + '_valid.txt' 103 | test_file = path + args.dataset + '_test.txt' 104 | cate_file = path + args.dataset + '_item_cate.txt' 105 | dataset = args.dataset 106 | 107 | print("Param dataset=" + str(args.dataset)) 108 | print("Param model_type=" + str(args.model_type)) 109 | print("Param hidden_size=" + str(args.hidden_size)) 110 | print("Param dropout=" + str(args.dropout)) 111 | print("Param layers=" + str(args.layers)) 112 | print("Param interest_num=" + str(args.interest_num)) 113 | print("Param add_pos=" + str(args.add_pos == 1)) 114 | 115 | print("Param weight_decay=" + str(args.weight_decay)) 116 | 117 | batch_size = 128 118 | 119 | prob_dic = { 120 | 0: 'uniform', 121 | 1: 'log' 122 | } 123 | 124 | print("Param sampled_n=" + str(args.sampled_n)) 125 | print("Param beta=" + str(args.rbeta)) 126 | print("Param sampled_loss=" + str(args.sampled_loss)) 127 | print("Param sample_prob=" + prob_dic[args.sample_prob]) 128 | 129 | 130 | if args.p == 'train': 131 | train(device=device, train_file=train_file, valid_file=valid_file, test_file=test_file, 132 | dataset=dataset, model_type=args.model_type, item_count=item_count, batch_size=batch_size, 133 | lr=args.learning_rate, seq_len=seq_len, hidden_size=args.hidden_size, 134 | interest_num=args.interest_num, topN=args.topN, max_iter=args.max_iter, test_iter=test_iter, 135 | decay_step=args.lr_dc_step, lr_decay=args.lr_dc, patience=args.patience, exp=args.exp, args=args) 136 | elif args.p == 'test': 137 | test(device=device, test_file=test_file, cate_file=cate_file, dataset=dataset, model_type=args.model_type, 138 | item_count=item_count, batch_size=batch_size, lr=args.learning_rate, seq_len=seq_len, 139 | hidden_size=args.hidden_size, interest_num=args.interest_num, topN=args.topN, coef=args.coef, exp=args.exp) 140 | elif args.p == 'output': 141 | output(device=device, dataset=dataset, model_type=args.model_type, item_count=item_count, 142 | batch_size=batch_size, lr=args.learning_rate, seq_len=seq_len, hidden_size=args.hidden_size, 143 | interest_num=args.interest_num, topN=args.topN, exp=args.exp) 144 | else: 145 | print('do nothing...') 146 | 147 | -------------------------------------------------------------------------------- /src/ComiRec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from BasicModel import BasicModel, CapsuleNetwork 6 | 7 | 8 | class ComiRec_DR(BasicModel): 9 | 10 | def __init__(self, item_num, hidden_size, batch_size, interest_num=4, seq_len=50, routing_times=3, relu_layer=False, hard_readout=True): 11 | super(ComiRec_DR, self).__init__(item_num, hidden_size, batch_size, seq_len) 12 | self.interest_num = interest_num 13 | self.routing_times = routing_times 14 | #self.hard_readout = True 15 | self.hard_readout = hard_readout 16 | self.capsule_network = CapsuleNetwork(self.hidden_size, self.seq_len, bilinear_type=2, interest_num=self.interest_num, 17 | routing_times=self.routing_times, hard_readout=self.hard_readout, relu_layer=relu_layer) 18 | self.reset_parameters() 19 | 20 | 21 | def forward(self, item_list, label_list, mask, times, device, train=True): 22 | 23 | item_eb = self.embeddings(item_list) 24 | item_eb = item_eb * torch.reshape(mask, (-1, self.seq_len, 1)) 25 | if train: 26 | label_eb = self.embeddings(label_list) 27 | user_eb = self.capsule_network(item_eb, mask, device) 28 | 29 | if not train: 30 | return user_eb, None 31 | 32 | readout, _ = self.read_out(user_eb, label_eb) 33 | scores = self.calculate_score(readout) 34 | 35 | return user_eb, scores, readout 36 | 37 | 38 | class ComiRec_SA(BasicModel): 39 | 40 | def __init__(self, item_num, hidden_size, batch_size, interest_num=4, seq_len=50, add_pos=True, beta=0, args=None, device=None): 41 | super(ComiRec_SA, self).__init__(item_num, hidden_size, batch_size, seq_len, beta) 42 | self.interest_num = interest_num 43 | self.num_heads = interest_num 44 | self.interest_num = interest_num 45 | self.hard_readout = True 46 | self.add_pos = add_pos 47 | if self.add_pos: 48 | self.position_embedding = nn.Parameter(torch.Tensor(1, self.seq_len, self.hidden_size)) 49 | self.linear1 = nn.Sequential( 50 | nn.Linear(self.hidden_size, self.hidden_size * 4, bias=False), 51 | nn.Tanh() 52 | ) 53 | self.linear2 = nn.Linear(self.hidden_size * 4, self.num_heads, bias=False) 54 | self.reset_parameters() 55 | 56 | 57 | def forwardLogits(self, item_eb, mask): 58 | item_eb = item_eb * torch.reshape(mask, (-1, self.seq_len, 1)) 59 | item_eb = torch.reshape(item_eb, (-1, self.seq_len, self.hidden_size)) 60 | if self.add_pos: 61 | # 位置嵌入堆叠一个batch,然后与历史物品嵌入相加 62 | item_eb_add_pos = item_eb + self.position_embedding.repeat(item_eb.shape[0], 1, 1) 63 | # item_eb_add_pos = item_eb + self.position_embedding[:, -1, :].repeat(item_eb.shape[0], 1, 1) 64 | else: 65 | item_eb_add_pos = item_eb 66 | 67 | # shape=(batch_size, maxlen, hidden_size*4) 68 | item_hidden = self.linear1(item_eb_add_pos) 69 | # shape=(batch_size, maxlen, num_heads) 70 | item_att_w = self.linear2(item_hidden) 71 | # shape=(batch_size, num_heads, maxlen) 72 | item_att_w = torch.transpose(item_att_w, 2, 1).contiguous() 73 | 74 | 75 | atten_mask = torch.unsqueeze(mask, dim=1).repeat(1, self.num_heads, 1) # shape=(batch_size, num_heads, maxlen) 76 | paddings = torch.ones_like(atten_mask, dtype=torch.float) * (-2 ** 32 + 1) # softmax之后无限接近于0 77 | 78 | # print(item_eb.size(), item_att_w.size(), atten_mask.size(), mask.size()) 79 | 80 | item_att_w = torch.where(torch.eq(atten_mask, 0), paddings, item_att_w) 81 | item_att_w = F.softmax(item_att_w, dim=-1) # 矩阵A,shape=(batch_size, num_heads, maxlen) 82 | return item_att_w 83 | 84 | def forward(self, item_list, label_list, mask, times, device, train=True): 85 | item_eb = self.embeddings(item_list) 86 | item_eb = item_eb * torch.reshape(mask, (-1, self.seq_len, 1)) 87 | if train: 88 | label_eb = self.embeddings(label_list) 89 | 90 | # 历史物品嵌入序列,shape=(batch_size, maxlen, embedding_dim) 91 | item_eb = torch.reshape(item_eb, (-1, self.seq_len, self.hidden_size)) 92 | 93 | if self.add_pos: 94 | # 位置嵌入堆叠一个batch,然后与历史物品嵌入相加 95 | item_eb_add_pos = item_eb + self.position_embedding.repeat(item_eb.shape[0], 1, 1) 96 | else: 97 | item_eb_add_pos = item_eb 98 | 99 | # shape=(batch_size, maxlen, hidden_size*4) 100 | item_hidden = self.linear1(item_eb_add_pos) 101 | # shape=(batch_size, maxlen, num_heads) 102 | item_att_w = self.linear2(item_hidden) 103 | # shape=(batch_size, num_heads, maxlen) 104 | item_att_w = torch.transpose(item_att_w, 2, 1).contiguous() 105 | 106 | atten_mask = torch.unsqueeze(mask, dim=1).repeat(1, self.num_heads, 1) # shape=(batch_size, num_heads, maxlen) 107 | paddings = torch.ones_like(atten_mask, dtype=torch.float) * (-2 ** 32 + 1) # softmax之后无限接近于0 108 | 109 | item_att_w = torch.where(torch.eq(atten_mask, 0), paddings, item_att_w) 110 | item_att_w = F.softmax(item_att_w, dim=-1) # 矩阵A,shape=(batch_size, num_heads, maxlen) 111 | 112 | # interest_emb即论文中的Vu 113 | interest_emb = torch.matmul(item_att_w, # shape=(batch_size, num_heads, maxlen) 114 | item_eb # shape=(batch_size, maxlen, embedding_dim) 115 | ) # shape=(batch_size, num_heads, embedding_dim) 116 | 117 | # 用户多兴趣向量 118 | user_eb = interest_emb # shape=(batch_size, num_heads, embedding_dim) 119 | 120 | if not train: 121 | return user_eb, None 122 | 123 | readout, selection = self.read_out(user_eb, label_eb) 124 | scores = None if self.is_sampler else self.calculate_score(readout) 125 | 126 | return user_eb, scores, item_att_w, readout, selection 127 | -------------------------------------------------------------------------------- /process/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import random 5 | from collections import defaultdict 6 | 7 | random.seed(1230) 8 | 9 | name = 'cloth' 10 | filter_size = 5 11 | if len(sys.argv) > 1: 12 | name = sys.argv[1] 13 | if len(sys.argv) > 2: 14 | filter_size = int(sys.argv[2]) 15 | 16 | users = defaultdict(list) 17 | item_count = defaultdict(int) 18 | 19 | def read_from_amazon(source): 20 | with open(source, 'r') as f: 21 | for line in f: 22 | r = json.loads(line.strip()) 23 | uid = r['reviewerID'] 24 | iid = r['asin'] 25 | item_count[iid] += 1 26 | ts = float(r['unixReviewTime']) 27 | users[uid].append((iid, ts)) 28 | 29 | def read_from_kindle(source): 30 | with open(source, 'r') as f: 31 | for line in f: 32 | # r = json.loads(line.strip()) 33 | r = line.split(",") 34 | uid = r[1] 35 | iid = r[0] 36 | item_count[iid] += 1 37 | ts = float(r[3]) 38 | users[uid].append((iid, ts)) 39 | 40 | def read_from_movie(source): 41 | with open(source, 'r') as f: 42 | for line in f: 43 | # r = json.loads(line.strip()) 44 | r = line.split(",") 45 | uid = r[1] 46 | iid = r[0] 47 | item_count[iid] += 1 48 | ts = float(r[3]) 49 | users[uid].append((iid, ts)) 50 | 51 | def read_from_taobao(source): 52 | with open(source, 'r') as f: 53 | for line in f: 54 | conts = line.strip().split(',') 55 | uid = int(conts[0]) 56 | iid = int(conts[1]) 57 | if conts[3] != 'pv': 58 | continue 59 | item_count[iid] += 1 60 | ts = int(conts[4]) 61 | users[uid].append((iid, ts)) 62 | 63 | def read_from_gowalla(source): 64 | with open(source, 'r') as f: 65 | for line in f: 66 | conts = line.strip().split('\t') 67 | uid = int(conts[0]) 68 | iid = int(conts[1]) 69 | item_count[iid] += 1 70 | ts = int(float(conts[2])) 71 | users[uid].append((iid, ts)) 72 | 73 | def read_from_tmall(source): 74 | with open(source, 'r') as f: 75 | for line in f: 76 | conts = line.strip().split('\t') 77 | uid = int(conts[0]) 78 | iid = int(conts[2]) 79 | item_count[iid] += 1 80 | ts = int(float(conts[4])) 81 | users[uid].append((iid, ts)) 82 | 83 | def read_from_rocket(source): 84 | with open(source, 'r') as f: 85 | for line in f: 86 | conts = line.strip().split('\t') 87 | uid = int(conts[1]) 88 | iid = int(conts[2]) 89 | item_count[iid] += 1 90 | ts = int(float(conts[0])) 91 | users[uid].append((iid, ts)) 92 | 93 | 94 | if name == 'book': 95 | read_from_amazon('reviews_Books_5.json') 96 | if name == 'bookr': 97 | read_from_kindle('Books.csv') 98 | if name == 'bookv': 99 | read_from_amazon('Books_5.json') 100 | if name == 'kindle': 101 | read_from_kindle('Kindle_Store.csv') 102 | if name == 'movie': 103 | read_from_movie('ratings_Movies_and_TV.csv') 104 | if name == 'beauty': 105 | read_from_kindle('All_Beauty.csv') 106 | if name == 'cloth': 107 | read_from_kindle('Clothing_Shoes_and_Jewelry.csv') 108 | elif name == 'taobao': 109 | read_from_taobao('UserBehavior.csv') 110 | elif name == 'tmall': 111 | read_from_tmall('tmall-click.inter') 112 | elif name == 'rocket': 113 | read_from_rocket('retailrocket-view.inter') 114 | # avg items: 10.586210988489379 115 | # total items: 81635 116 | # total behaviors: 356840 117 | elif name == 'gowalla': 118 | read_from_gowalla('gowalla.inter') 119 | elif name == 'gowalla10': 120 | filter_size = 10 121 | read_from_gowalla('gowalla.inter') 122 | # avg items: 31.466796934631944 123 | # total items: 174605 124 | # total behaviors: 2061264 125 | 126 | items = list(item_count.items()) 127 | items.sort(key=lambda x:x[1], reverse=True) 128 | 129 | item_total = 0 130 | print("Use core", filter_size) 131 | for index, (iid, num) in enumerate(items): 132 | if num >= filter_size: 133 | item_total = index + 1 134 | else: 135 | break 136 | 137 | item_map = dict(zip([items[i][0] for i in range(item_total)], list(range(1, item_total+1)))) 138 | 139 | user_ids = list(users.keys()) 140 | filter_user_ids = [] 141 | filter_seq_count = 0 142 | for user in user_ids: 143 | item_list = users[user] 144 | index = 0 145 | for item, timestamp in item_list: 146 | if item in item_map: 147 | index += 1 148 | if index >= filter_size: 149 | filter_user_ids.append(user) 150 | filter_seq_count += index 151 | user_ids = filter_user_ids 152 | 153 | random.shuffle(user_ids) 154 | num_users = len(user_ids) 155 | user_map = dict(zip(user_ids, list(range(num_users)))) 156 | split_1 = int(num_users * 0.8) 157 | split_2 = int(num_users * 0.9) 158 | train_users = user_ids[:split_1] 159 | valid_users = user_ids[split_1:split_2] 160 | test_users = user_ids[split_2:] 161 | 162 | def export_map(name, map_dict): 163 | with open(name, 'w') as f: 164 | for key, value in map_dict.items(): 165 | f.write('%s,%d\n' % (key, value)) 166 | 167 | def export_data(name, user_list): 168 | total_data = 0 169 | with open(name, 'w') as f: 170 | for user in user_list: 171 | if user not in user_map: 172 | continue 173 | item_list = users[user] 174 | item_list.sort(key=lambda x:x[1]) 175 | index = 0 176 | for item, timestamp in item_list: 177 | if item in item_map: 178 | f.write('%d,%d,%d,%d\n' % (user_map[user], item_map[item], index, timestamp)) 179 | index += 1 180 | total_data += 1 181 | return total_data 182 | 183 | path = './data/' + name + '_data/' 184 | if not os.path.exists(path): 185 | os.mkdir(path) 186 | 187 | export_map(path + name + '_user_map.txt', user_map) 188 | export_map(path + name + '_item_map.txt', item_map) 189 | 190 | total_train = export_data(path + name + '_train.txt', train_users) 191 | total_valid = export_data(path + name + '_valid.txt', valid_users) 192 | total_test = export_data(path + name + '_test.txt', test_users) 193 | print('avg items: ', filter_seq_count/len(user_ids)) 194 | print('total items: ', item_total) 195 | print('total behaviors: ', total_train + total_valid + total_test) 196 | print('total users: ', len(filter_user_ids)) -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import shutil 5 | import numpy as np 6 | import torch 7 | 8 | from torch.utils.data import DataLoader 9 | from DNN import DNN 10 | from Pop import Pop 11 | from GRU4Rec import GRU4Rec 12 | from MIND import MIND 13 | from ComiRec import ComiRec_DR, ComiRec_SA 14 | from REMI import REMI 15 | 16 | 17 | def get_parser(): 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('-p', type=str, default='train', help='train | test') # train or test or output 21 | parser.add_argument('--dataset', type=str, default='book', help='book | taobao') # 数据集 22 | parser.add_argument('--random_seed', type=int, default=2021) 23 | parser.add_argument('--hidden_size', type=int, default=64) # 隐藏层维度、嵌入维度 24 | parser.add_argument('--interest_num', type=int, default=4) # 兴趣的数量 25 | parser.add_argument('--model_type', type=str, default='MIND', help='DNN | GRU4Rec | MIND | ..') # 模型类型 26 | parser.add_argument('--learning_rate', type=float, default=0.001, help='learning_rate') # 学习率 27 | parser.add_argument('--lr_dc', type=float, default=0.1, help='learning rate decay rate') 28 | parser.add_argument('--lr_dc_step', type=int, default=30, help='(k), the number of steps after which the learning rate decay') 29 | parser.add_argument('--max_iter', type=int, default=1000, help='(k)') # 最大迭代次数,单位是k(1000) 30 | parser.add_argument('--patience', type=int, default=50) # patience,用于early stopping 31 | parser.add_argument('--topN', type=int, default=50) # default=50 32 | parser.add_argument('--gpu', type=str, default=None) # None -> cpu 33 | parser.add_argument('--coef', default=None) # 多样性,用于test 34 | parser.add_argument('--exp', default='e1') 35 | parser.add_argument('--add_pos', type=int, default=1) 36 | parser.add_argument('--dropout', type=float, default=0.2) 37 | parser.add_argument('--layers', type=int, default=1) 38 | parser.add_argument('--weight_decay', type=float, default=0) 39 | parser.add_argument('--sampled_n', type=int, default=1280) 40 | parser.add_argument('--sampled_loss', type=str, default='sampled') 41 | parser.add_argument('--sample_prob', type=int, default=0) 42 | 43 | # For REMI 44 | parser.add_argument('--rbeta', type=float, default=0) 45 | parser.add_argument('--rlambda', type=float, default=0) 46 | 47 | return parser 48 | 49 | 50 | class DataIterator(torch.utils.data.IterableDataset): 51 | 52 | def __init__(self, source, 53 | batch_size=128, 54 | seq_len=100, 55 | train_flag=1, 56 | time_span = 128 57 | ): 58 | print("Using time span", time_span) 59 | self.read(source) # 读取数据,获取用户列表和对应的按时间戳排序的物品序列,每个用户对应一个物品list 60 | self.users = list(self.users) # 用户列表 61 | 62 | self.time_span = time_span 63 | self.batch_size = batch_size # 用于训练 64 | self.eval_batch_size = batch_size # 用于验证、测试 65 | self.train_flag = train_flag # train_flag=1表示训练 66 | self.seq_len = seq_len # 历史物品序列的最大长度 67 | self.index = 0 # 验证和测试时选择用户的位置的标记 68 | print("total user:", len(self.users)) 69 | print("total items:", len(self.items)) 70 | 71 | def __iter__(self): 72 | return self 73 | 74 | # def next(self): 75 | # return self.__next__() 76 | 77 | def read(self, source): 78 | self.graph = {} # key:user_id,value:一个list,放着该user_id所有(item_id,time_stamp)元组,排序后value只保留item_id 79 | self.time_graph = {} 80 | self.users = set() 81 | self.items = set() 82 | self.times = set() 83 | with open(source, 'r') as f: 84 | for line in f: 85 | conts = line.strip().split(',') 86 | user_id = int(conts[0]) 87 | item_id = int(conts[1]) 88 | if len(conts) == 3: 89 | time_stamp = int(conts[2]) 90 | else: 91 | idx = int(conts[2]) 92 | time_stamp = int(conts[3]) 93 | self.users.add(user_id) 94 | self.items.add(item_id) 95 | self.times.add(time_stamp) 96 | if user_id not in self.graph: 97 | self.graph[user_id] = [] 98 | self.graph[user_id].append((item_id, time_stamp)) 99 | for user_id, value in self.graph.items(): # 每个user的物品序列按时间戳排序 100 | value.sort(key=lambda x: x[1]) 101 | time_list = list(map(lambda x: x[1], value)) 102 | time_min = min(time_list) 103 | # self.graph[user_id] = list(map(lambda x: [x[0], ], items)) 104 | self.graph[user_id] = [x[0] for x in value] # 排序后只保留了item_id 105 | self.time_graph[user_id] = [int(round((x[1] - time_min) / 86400.0) + 1) for x in value] # 排序后只保留了item_id 106 | self.users = list(self.users) # 用户列表 107 | self.items = list(self.items) # 物品列表 108 | 109 | def compute_time_matrix(self, time_seq, item_num): 110 | time_matrix = np.zeros([self.seq_len, self.seq_len], dtype=np.int32) 111 | for i in range(item_num): 112 | for j in range(item_num): 113 | span = abs(time_seq[i] - time_seq[j]) 114 | if span > self.time_span: 115 | time_matrix[i][j] = self.time_span 116 | else: 117 | time_matrix[i][j] = span 118 | return time_matrix.tolist() 119 | 120 | def compute_adj_matrix(self, mask_seq, item_num): 121 | node_num = len(mask_seq) 122 | 123 | adj_matrix = np.zeros([node_num, node_num + 2], dtype=np.int32) 124 | 125 | adj_matrix[0][0] = 1 126 | adj_matrix[0][1] = 1 127 | adj_matrix[0][-1] = 1 128 | 129 | adj_matrix[item_num - 1][item_num - 1] = 1 130 | adj_matrix[item_num - 1][item_num] = 1 131 | adj_matrix[item_num - 1][-1] = 1 132 | 133 | for i in range(1, item_num - 1): 134 | adj_matrix[i][i] = 1 135 | adj_matrix[i][i + 1] = 1 136 | adj_matrix[i][-1] = 1 137 | 138 | if (item_num < node_num): 139 | for i in range(item_num, node_num): 140 | adj_matrix[i][0] = 1 141 | adj_matrix[i][1] = 1 142 | adj_matrix[i][-1] = 1 143 | 144 | return adj_matrix.tolist() 145 | 146 | def __next__(self): 147 | if self.train_flag == 1: # 训练 148 | user_id_list = random.sample(self.users, self.batch_size) # 随机抽取batch_size个user 149 | else: # 验证、测试,按顺序选取eval_batch_size个user,直到遍历完所有user 150 | total_user = len(self.users) 151 | if self.index >= total_user: 152 | self.index = 0 153 | raise StopIteration 154 | user_id_list = self.users[self.index: self.index+self.eval_batch_size] 155 | self.index += self.eval_batch_size 156 | 157 | item_id_list = [] 158 | hist_time_list = [] 159 | hist_item_list = [] 160 | time_matrix_list = [] 161 | hist_mask_list = [] 162 | adj_matrix_list = [] 163 | 164 | for user_id in user_id_list: 165 | item_list = self.graph[user_id] # 排序后的user的item序列 166 | time_list = self.time_graph[user_id] # 排序后的user的item序列 167 | # 这里训练和(验证、测试)采取了不同的数据选取方式 168 | if self.train_flag == 1: # 训练,选取训练时的label 169 | k = random.choice(range(4, len(item_list))) # 从[4,len(item_list))中随机选择一个index 170 | item_id_list.append(item_list[k]) # 该index对应的item加入item_id_list 171 | else: # 验证、测试,选取该user后20%的item用于验证、测试 172 | k = int(len(item_list) * 0.8) 173 | item_id_list.append(item_list[k:]) 174 | # k前的item序列为历史item序列 175 | if k >= self.seq_len: # 选取seq_len个物品 176 | hist_item_list.append(item_list[k-self.seq_len: k]) 177 | hist_mask_list.append([1.0] * self.seq_len) 178 | hist_time_list.append(time_list[k-self.seq_len: k]) 179 | time_matrix_list.append(self.compute_time_matrix(time_list[k - self.seq_len: k], self.seq_len)) 180 | adj_matrix_list.append(self.compute_adj_matrix([1.0] * self.seq_len, self.seq_len)) 181 | 182 | else: 183 | hist_item_list.append(item_list[:k] + [0] * (self.seq_len - k)) 184 | hist_mask_list.append([1.0] * k + [0.0] * (self.seq_len - k)) 185 | hist_time_list.append(time_list[:k] + [0] * (self.seq_len - k)) 186 | time_matrix_list.append(self.compute_time_matrix(time_list[:k] + [0] * (self.seq_len - k), k)) 187 | adj_matrix_list.append(self.compute_adj_matrix([1.0] * k + [0.0] * (self.seq_len - k), k)) 188 | 189 | # 返回用户列表(batch_size)、物品列表(label)(batch_size)、 190 | # 历史物品列表(batch_size,seq_len)、历史物品的mask列表(batch_size,seq_len) 191 | return user_id_list, item_id_list, hist_item_list, hist_mask_list, (time_matrix_list, adj_matrix_list) 192 | # return user_id_list, item_id_list, hist_item_list, hist_mask_list, hist_time_list 193 | 194 | 195 | def get_DataLoader(source, batch_size, seq_len, train_flag=1, args=None): 196 | dataIterator = DataIterator(source, batch_size, seq_len, train_flag) 197 | return DataLoader(dataIterator, batch_size=None, batch_sampler=None) 198 | 199 | 200 | def setup_seed(seed): 201 | torch.manual_seed(seed) 202 | torch.cuda.manual_seed_all(seed) 203 | np.random.seed(seed) 204 | random.seed(seed) 205 | torch.backends.cudnn.deterministic = True 206 | 207 | 208 | # 获取模型 209 | def get_model(dataset, model_type, item_count, batch_size, hidden_size, interest_num, seq_len, routing_times=3, args=None, device=None): 210 | # def get_model(dataset, model_type, item_count, batch_size, hidden_size, interest_num, seq_len, beta, routing_times=3,): 211 | add_pos = True 212 | if args: 213 | add_pos = args.add_pos == 1 214 | if model_type == 'DNN': 215 | model = DNN(item_count, hidden_size, batch_size, seq_len) 216 | elif model_type == 'Pop': 217 | model = Pop(item_count, hidden_size, batch_size, seq_len, device) 218 | elif model_type == 'GRU4Rec': 219 | # todo check this back 220 | model = GRU4Rec(item_count, hidden_size, batch_size, seq_len, num_layers=args.layers, dropout=args.dropout) 221 | elif model_type == 'MIND': 222 | relu_layer = True if dataset == 'book' else False 223 | model = MIND(item_count, hidden_size, batch_size, interest_num, seq_len, routing_times=routing_times, relu_layer=relu_layer) 224 | elif model_type == 'ComiRec-DR': 225 | relu_layer = False 226 | hard_readout = False if dataset == 'kindle' else True 227 | model = ComiRec_DR(item_count, hidden_size, batch_size, interest_num, seq_len, routing_times=routing_times, relu_layer=relu_layer, hard_readout=hard_readout) 228 | elif model_type in ['ComiRec-SA']: 229 | # import pdb; pdb.set_trace() 230 | model = ComiRec_SA(item_count, hidden_size, batch_size, interest_num, seq_len, add_pos=add_pos, args = args, device = device) 231 | elif model_type == "REMI": 232 | model = REMI(item_count, hidden_size, batch_size, interest_num, seq_len, add_pos=add_pos, beta=args.rbeta, 233 | args=args, device=device) 234 | else: 235 | print ("Invalid model_type : %s", model_type) 236 | return 237 | model.name = model_type 238 | return model 239 | 240 | 241 | # 生成实验名称 242 | def get_exp_name(dataset, model_type, batch_size, lr, hidden_size, seq_len, interest_num, topN, save=True, exp='e1'): 243 | extr_name = exp 244 | para_name = '_'.join([dataset, model_type, 'b'+str(batch_size), 'lr'+str(lr), 'd'+str(hidden_size), 245 | 'len'+str(seq_len), 'in'+str(interest_num), 'top'+str(topN)]) 246 | exp_name = para_name + '_' + extr_name 247 | 248 | while os.path.exists('best_model/' + exp_name) and save: 249 | # flag = input('The exp name already exists. Do you want to cover? (y/n)') 250 | # if flag == 'y' or flag == 'Y': 251 | shutil.rmtree('best_model/' + exp_name) 252 | break 253 | # else: 254 | # extr_name = input('Please input the experiment name: ') 255 | # exp_name = para_name + '_' + extr_name 256 | 257 | return exp_name 258 | 259 | 260 | def save_model(model, Path): 261 | if not os.path.exists(Path): 262 | os.makedirs(Path) 263 | torch.save(model.state_dict(), Path + 'model.pt') 264 | 265 | 266 | def load_model(model, path): 267 | model.load_state_dict(torch.load(path + 'model.pt')) 268 | print('model loaded from %s' % path) 269 | 270 | 271 | def to_tensor(var, device): 272 | var = torch.Tensor(var) 273 | var = var.to(device) 274 | return var.long() 275 | 276 | 277 | # 读取物品类别信息,返回一个dict,key:item_id,value:cate_id 278 | def load_item_cate(source): 279 | item_cate = {} 280 | with open(source, 'r') as f: 281 | for line in f: 282 | conts = line.strip().split(',') 283 | item_id = int(conts[0]) 284 | cate_id = int(conts[1]) 285 | item_cate[item_id] = cate_id 286 | return item_cate 287 | 288 | 289 | # 计算物品多样性,item_list中的所有item两两计算 290 | def compute_diversity(item_list, item_cate_map): 291 | n = len(item_list) 292 | diversity = 0.0 293 | for i in range(n): 294 | for j in range(i+1, n): 295 | diversity += item_cate_map[item_list[i]] != item_cate_map[item_list[j]] 296 | diversity /= ((n-1) * n / 2) 297 | return diversity 298 | -------------------------------------------------------------------------------- /src/evalution.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import math 3 | import sys 4 | import time 5 | import faiss 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import os 10 | import signal 11 | 12 | 13 | error_flag = {'sig':0} 14 | 15 | def sig_handler(signum, frame): 16 | error_flag['sig'] = signum 17 | print("segfault core", signum) 18 | 19 | signal.signal(signal.SIGSEGV, sig_handler) 20 | 21 | from utils import get_DataLoader, get_exp_name, get_model, load_model, save_model, to_tensor, load_item_cate, compute_diversity 22 | def evaluate_pop(model, test_data, hidden_size, device, topN=20): 23 | total = 0 24 | total_recall = 0.0 25 | total_ndcg = 0.0 26 | total_hitrate = 0 27 | total_diversity = 0.0 28 | for _, (users, targets, items, mask, times) in enumerate(test_data): 29 | res = model.full_sort_predict(1) 30 | 31 | 32 | item_list = res.argsort(0, True)[:topN] 33 | assert len(item_list.size()) == 1 34 | assert item_list.size(0) == topN 35 | 36 | for i, iid_list in enumerate(targets): # 每个用户的label列表,此处item_id为一个二维list,验证和测试是多label的 37 | recall = 0 38 | dcg = 0.0 39 | # item_list = set(res[0]) # I[i]是一个batch中第i个用户的近邻搜索结果,i∈[0, batch_size) 40 | for no, iid in enumerate(item_list): # 对于每一个label物品 41 | if iid in iid_list: # 如果该label物品在近邻搜索的结果中 42 | recall += 1 43 | dcg += 1.0 / math.log(no + 2, 2) 44 | idcg = 0.0 45 | for no in range(recall): 46 | idcg += 1.0 / math.log(no + 2, 2) 47 | total_recall += recall * 1.0 / len(iid_list) 48 | if recall > 0: # recall>0当然表示有命中 49 | total_ndcg += dcg / idcg 50 | total_hitrate += 1 51 | 52 | total += len(targets) # total增加每个批次的用户数量 53 | 54 | recall = total_recall / total # 召回率,每个用户召回率的平均值 55 | ndcg = total_ndcg / total # NDCG 56 | hitrate = total_hitrate * 1.0 / total # 命中率 57 | return {'recall': recall, 'ndcg': ndcg, 'hitrate': hitrate} 58 | 59 | def evaluate(model, test_data, hidden_size, device, k=20, coef=None, item_cate_map=None, args=None): 60 | if model.name == 'Pop': 61 | return evaluate_pop(model, test_data, hidden_size, device, k) 62 | topN = k # 评价时选取topN 63 | if coef is not None: 64 | coef = float(coef) 65 | 66 | gpu_indexs = [None] 67 | for i in range(1000): 68 | try: 69 | item_embs = model.output_items().cpu().detach().numpy() 70 | res = faiss.StandardGpuResources() # 使用单个GPU 71 | flat_config = faiss.GpuIndexFlatConfig() 72 | flat_config.device = device.index 73 | 74 | gpu_indexs[0] = faiss.GpuIndexFlatIP(res, hidden_size, flat_config) # 建立GPU index用于Inner Product近邻搜索 75 | gpu_indexs[0].add(item_embs) # 给index添加向量数据 76 | if error_flag['sig'] == 0: 77 | break 78 | else: 79 | print("core received", error_flag['sig']) 80 | error_flag['sig'] = 0 81 | except Exception as e: 82 | print("error received", e) 83 | print("Faiss re-try", i) 84 | time.sleep(5) 85 | 86 | 87 | total = 0 88 | total_recall = 0.0 89 | total_ndcg = 0.0 90 | total_hitrate = 0 91 | total_diversity = 0.0 92 | 93 | for _, (users, targets, items, mask, times) in enumerate(test_data): # 一个batch的数据 94 | 95 | # 获取用户嵌入 96 | # 多兴趣模型,shape=(batch_size, num_interest, embedding_dim) 97 | # 其他模型,shape=(batch_size, embedding_dim) 98 | time_mat, adj_mat = times 99 | time_tensor = (to_tensor(time_mat, device), to_tensor(adj_mat, device)) 100 | user_embs,_ = model(to_tensor(items, device), None, to_tensor(mask, device), time_tensor, device, train=False) 101 | user_embs = user_embs.cpu().detach().numpy() 102 | gpu_index = gpu_indexs[0] 103 | # 用内积来近邻搜索,实际是内积的值越大,向量越近(越相似) 104 | if len(user_embs.shape) == 2: # 非多兴趣模型评估 105 | D, I = gpu_index.search(user_embs, topN) # Inner Product近邻搜索,D为distance,I是index 106 | for i, iid_list in enumerate(targets): # 每个用户的label列表,此处item_id为一个二维list,验证和测试是多label的 107 | recall = 0 108 | dcg = 0.0 109 | item_list = set(I[i]) # I[i]是一个batch中第i个用户的近邻搜索结果,i∈[0, batch_size) 110 | for no, iid in enumerate(item_list): # 对于每一个label物品 111 | if iid in iid_list: # 如果该label物品在近邻搜索的结果中 112 | recall += 1 113 | dcg += 1.0 / math.log(no+2, 2) 114 | idcg = 0.0 115 | for no in range(recall): 116 | idcg += 1.0 / math.log(no+2, 2) 117 | total_recall += recall * 1.0 / len(iid_list) 118 | if recall > 0: # recall>0当然表示有命中 119 | total_ndcg += dcg / idcg 120 | total_hitrate += 1 121 | if coef is not None: 122 | total_diversity += compute_diversity(I[i], item_cate_map) # 两个参数分别为推荐物品列表和物品类别字典 123 | else: # 多兴趣模型评估 124 | ni = user_embs.shape[1] # num_interest 125 | user_embs = np.reshape(user_embs, [-1, user_embs.shape[-1]]) # shape=(batch_size*num_interest, embedding_dim) 126 | D, I = gpu_index.search(user_embs, topN) # Inner Product近邻搜索,D为distance,I是index 127 | for i, iid_list in enumerate(targets): # 每个用户的label列表,此处item_id为一个二维list,验证和测试是多label的 128 | recall = 0 129 | dcg = 0.0 130 | item_list_set = set() 131 | if coef is None: # 不考虑物品多样性 132 | # 将num_interest个兴趣向量的所有topN近邻物品(num_interest*topN个物品)集合起来按照距离重新排序 133 | item_list = list(zip(np.reshape(I[i*ni:(i+1)*ni], -1), np.reshape(D[i*ni:(i+1)*ni], -1))) 134 | item_list.sort(key=lambda x:x[1], reverse=True) # 降序排序,内积越大,向量越近 135 | for j in range(len(item_list)): # 按距离由近到远遍历推荐物品列表,最后选出最近的topN个物品作为最终的推荐物品 136 | if item_list[j][0] not in item_list_set and item_list[j][0] != 0: 137 | item_list_set.add(item_list[j][0]) 138 | if len(item_list_set) >= topN: 139 | break 140 | else: # 考虑物品多样性 141 | coef = float(coef) 142 | # 所有兴趣向量的近邻物品集中起来按距离再次排序 143 | origin_item_list = list(zip(np.reshape(I[i*ni:(i+1)*ni], -1), np.reshape(D[i*ni:(i+1)*ni], -1))) 144 | origin_item_list.sort(key=lambda x:x[1], reverse=True) 145 | item_list = [] # 存放(item_id, distance, item_cate)三元组,要用到物品类别,所以只存放有类别的物品 146 | tmp_item_set = set() # 近邻推荐物品中有类别的物品的集合 147 | for (x, y) in origin_item_list: # x为索引,y为距离 148 | if x not in tmp_item_set and x in item_cate_map: 149 | item_list.append((x, y, item_cate_map[x])) 150 | tmp_item_set.add(x) 151 | cate_dict = defaultdict(int) 152 | for j in range(topN): # 选出topN个物品 153 | max_index = 0 154 | # score = distance - λ * 已选出的物品中与该物品的类别相同的物品的数量(score越大越好) 155 | max_score = item_list[0][1] - coef * cate_dict[item_list[0][2]] 156 | for k in range(1, len(item_list)): # 遍历所有候选物品,每个循环找出一个score最大的item 157 | # 第一次遍历必然先选出第一个物品 158 | if item_list[k][1] - coef * cate_dict[item_list[k][2]] > max_score: 159 | max_index = k 160 | max_score = item_list[k][1] - coef * cate_dict[item_list[k][2]] 161 | elif item_list[k][1] < max_score: # 当距离得分小于max_score时,后续物品得分一定小于max_score 162 | break 163 | item_list_set.add(item_list[max_index][0]) 164 | # 选出来的物品的类别对应的value加1,这里是为了尽可能选出类别不同的物品 165 | cate_dict[item_list[max_index][2]] += 1 166 | item_list.pop(max_index) # 候选物品列表中删掉选出来的物品 167 | 168 | 169 | 170 | # 上述if-else只是为了用不同方式计算得到最后推荐的结果item列表 171 | for no, iid in enumerate(item_list_set): # 对于每一个推荐的物品 172 | if iid in iid_list: # 如果该推荐的物品在label物品列表中 173 | recall += 1 174 | dcg += 1.0 / math.log(no+2, 2) 175 | idcg = 0.0 176 | for no in range(recall): 177 | idcg += 1.0 / math.log(no+2, 2) 178 | total_recall += recall * 1.0 / len(iid_list) # len(iid_list)表示label数量 179 | if recall > 0: # recall>0当然表示有命中 180 | total_ndcg += dcg / idcg 181 | total_hitrate += 1 182 | if coef is not None: 183 | total_diversity += compute_diversity(list(item_list_set), item_cate_map) 184 | 185 | total += len(targets) # total增加每个批次的用户数量 186 | 187 | recall = total_recall / total # 召回率,每个用户召回率的平均值 188 | ndcg = total_ndcg / total # NDCG 189 | hitrate = total_hitrate * 1.0 / total # 命中率 190 | if coef is None: 191 | return {'recall': recall, 'ndcg': ndcg, 'hitrate': hitrate} 192 | diversity = total_diversity * 1.0 / total # 多样性 193 | return {'recall': recall, 'ndcg': ndcg, 'hitrate': hitrate, 'diversity': diversity} 194 | 195 | torch.set_printoptions( 196 | precision=2, # 精度,保留小数点后几位,默认4 197 | threshold=np.inf, 198 | edgeitems=3, 199 | linewidth=200, # 每行最多显示的字符数,默认80,超过则换行显示 200 | profile=None, 201 | sci_mode=False # 用科学技术法显示数据,默认True 202 | ) 203 | 204 | def train(device, train_file, valid_file, test_file, dataset, model_type, item_count, batch_size, lr, seq_len, 205 | hidden_size, interest_num, topN, max_iter, test_iter, decay_step, lr_decay, patience, exp, args): 206 | # if model_type in ['MIND', 'ComiRec-DR']: 207 | # lr = 0.005 208 | 209 | print("Param lr=" + str(lr)) 210 | exp_name = get_exp_name(dataset, model_type, batch_size, lr, hidden_size, seq_len, interest_num, topN, exp=exp) # 实验名称 211 | best_model_path = "best_model/" + exp_name + '/' # 模型保存路径 212 | 213 | # prepare data 214 | train_data = get_DataLoader(train_file, batch_size, seq_len, train_flag=1, args=args) 215 | valid_data = get_DataLoader(valid_file, batch_size, seq_len, train_flag=0, args=args) 216 | 217 | model = get_model(dataset, model_type, item_count, batch_size, hidden_size, interest_num, seq_len, args=args, device=device) 218 | model = model.to(device) 219 | model.set_device(device) 220 | 221 | model.set_sampler(args, device=device) 222 | 223 | loss_fn = nn.CrossEntropyLoss() 224 | optimizer = torch.optim.Adam(model.parameters(), lr = lr, weight_decay=args.weight_decay) 225 | 226 | trials = 0 227 | 228 | print('training begin') 229 | sys.stdout.flush() 230 | 231 | start_time = time.time() 232 | model.loss_fct = loss_fn 233 | try: 234 | total_loss, total_loss_1, total_loss_2, total_loss_3, total_loss_4, total_loss_5 = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 235 | iter = 0 236 | best_metric = 0 # 最佳指标值,在这里是最佳recall值 237 | #scheduler.step() 238 | for i, (users, targets, items, mask, times) in enumerate(train_data): 239 | model.train() 240 | iter += 1 241 | optimizer.zero_grad() 242 | pos_items = to_tensor(targets, device) 243 | interests, atten, readout, selection = None, None, None, None 244 | time_mat, adj_mat = times 245 | times_tensor = (to_tensor(time_mat, device), to_tensor(adj_mat, device)) 246 | if model_type in ['ComiRec-SA', "REMI"]: 247 | interests, scores, atten, readout, selection = model(to_tensor(items, device), pos_items, to_tensor(mask, device), times_tensor, device) 248 | 249 | if model_type == 'ComiRec-DR': 250 | interests, scores, readout = model(to_tensor(items, device), pos_items, to_tensor(mask, device), times_tensor, device) 251 | 252 | if model_type == 'MIND': 253 | interests, scores, readout, selection = model(to_tensor(items, device), pos_items, to_tensor(mask, device), times_tensor, device) 254 | 255 | if model_type in ['GRU4Rec', 'DNN']: 256 | readout, scores = model(to_tensor(items, device), pos_items, to_tensor(mask, device), times_tensor, device) 257 | 258 | if model_type == 'Pop': 259 | loss = model.calculate_loss(pos_items) 260 | else: 261 | loss = model.calculate_sampled_loss(readout, pos_items, selection, interests) if model.is_sampler else model.calculate_full_loss(loss_fn, scores, to_tensor(targets, device), interests) 262 | 263 | if model_type == "REMI": 264 | loss += args.rlambda * model.calculate_atten_loss(atten) 265 | 266 | loss.backward() 267 | optimizer.step() 268 | 269 | total_loss += loss 270 | 271 | if iter%test_iter == 0: 272 | model.eval() 273 | metrics = evaluate(model, valid_data, hidden_size, device, topN, args=args) 274 | log_str = 'iter: %d, train loss: %.4f' % (iter, total_loss / test_iter) # 打印loss 275 | if metrics != {}: 276 | log_str += ', ' + ', '.join(['valid ' + key + ': %.6f' % value for key, value in metrics.items()]) 277 | print(exp_name) 278 | print(log_str) 279 | 280 | # 保存recall最佳的模型 281 | if 'recall' in metrics: 282 | recall = metrics['recall'] 283 | if recall > best_metric: 284 | best_metric = recall 285 | save_model(model, best_model_path) 286 | trials = 0 287 | else: 288 | trials += 1 289 | if trials > patience: # early stopping 290 | print("early stopping!") 291 | break 292 | 293 | # 每次test之后loss_sum置零 294 | total_loss = 0.0 295 | test_time = time.time() 296 | print("time interval: %.4f min" % ((test_time-start_time)/60.0)) 297 | sys.stdout.flush() 298 | 299 | if iter >= max_iter * 1000: # 超过最大迭代次数,退出训练 300 | break 301 | 302 | except KeyboardInterrupt: 303 | print('-' * 99) 304 | print('Exiting from training early') 305 | 306 | load_model(model, best_model_path) 307 | model.eval() 308 | 309 | # 训练结束后用valid_data测试一次 310 | metrics = evaluate(model, valid_data, hidden_size, device, topN, args=args) 311 | print(', '.join(['Valid ' + key + ': %.6f' % value for key, value in metrics.items()])) 312 | 313 | # 训练结束后用test_data测试一次 314 | print("Test result:") 315 | test_data = get_DataLoader(test_file, batch_size, seq_len, train_flag=0, args=args) 316 | metrics = evaluate(model, test_data, hidden_size, device, 20, args=args) 317 | for key, value in metrics.items(): 318 | print('test ' + key + '@20' + '=%.6f' % value) 319 | 320 | metrics = evaluate(model, test_data, hidden_size, device, 50, args=args) 321 | for key, value in metrics.items(): 322 | print('test ' + key + '@50' + '=%.6f' % value) 323 | 324 | 325 | def test(device, test_file, cate_file, dataset, model_type, item_count, batch_size, lr, seq_len, 326 | hidden_size, interest_num, topN, coef=None, exp='test'): 327 | 328 | exp_name = get_exp_name(dataset, model_type, batch_size, lr, hidden_size, seq_len, interest_num, topN, save=False, exp=exp) # 实验名称 329 | best_model_path = "best_model/" + exp_name + '/' # 模型保存路径A 330 | 331 | model = get_model(dataset, model_type, item_count, batch_size, hidden_size, interest_num, seq_len) 332 | load_model(model, best_model_path) 333 | model = model.to(device) 334 | model.eval() 335 | 336 | test_data = get_DataLoader(test_file, batch_size, seq_len, train_flag=0) 337 | item_cate_map = load_item_cate(cate_file) # 读取物品的类型 338 | metrics = evaluate(model, test_data, hidden_size, device, topN, coef=coef, item_cate_map=item_cate_map) 339 | print(', '.join(['test ' + key + ': %.6f' % value for key, value in metrics.items()])) 340 | 341 | 342 | def output(device, dataset, model_type, item_count, batch_size, lr, seq_len, 343 | hidden_size, interest_num, topN, exp='eval'): 344 | 345 | exp_name = get_exp_name(dataset, model_type, batch_size, lr, hidden_size, seq_len, interest_num, topN, save=False, exp=exp) # 实验名称 346 | best_model_path = "best_model/" + exp_name + '/' # 模型保存路径 347 | 348 | model = get_model(dataset, model_type, item_count, batch_size, hidden_size, interest_num, seq_len) 349 | load_model(model, best_model_path) 350 | model = model.to(device) 351 | model.eval() 352 | 353 | item_embs = model.output_items() # 获取物品嵌入 354 | np.save('output/' + exp_name + '_emb.npy', item_embs) # 保存物品嵌入 355 | -------------------------------------------------------------------------------- /src/BasicModel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from math import isclose 5 | import math 6 | # A backoff probability to stabilize log operation 7 | BACKOFF_PROB = 1e-10 8 | 9 | class AliasMultinomial(torch.nn.Module): 10 | '''Alias sampling method to speedup multinomial sampling 11 | The alias method treats multinomial sampling as a combination of uniform sampling and 12 | bernoulli sampling. It achieves significant acceleration when repeatedly sampling from 13 | the save multinomial distribution. 14 | Attributes: 15 | - probs: the probability density of desired multinomial distribution 16 | Refs: 17 | - https://hips.seas.harvard.edu/blog/2013/03/03/the-alias-method-efficient-sampling-with-many-discrete-outcomes/ 18 | ''' 19 | def __init__(self, probs): 20 | super(AliasMultinomial, self).__init__() 21 | 22 | # @todo calculate divergence 23 | assert abs(probs.sum().item() - 1) < 1e-5, 'The noise distribution must sum to 1' 24 | 25 | cpu_probs = probs.cpu() 26 | K = len(probs) 27 | 28 | # such a name helps to avoid the namespace check for nn.Module 29 | self_prob = [0] * K 30 | self_alias = [0] * K 31 | 32 | # Sort the data into the outcomes with probabilities 33 | # that are larger and smaller than 1/K. 34 | smaller = [] 35 | larger = [] 36 | for idx, prob in enumerate(cpu_probs): 37 | self_prob[idx] = K*prob 38 | if self_prob[idx] < 1.0: 39 | smaller.append(idx) 40 | else: 41 | larger.append(idx) 42 | 43 | # Loop though and create little binary mixtures that 44 | # appropriately allocate the larger outcomes over the 45 | # overall uniform mixture. 46 | while len(smaller) > 0 and len(larger) > 0: 47 | small = smaller.pop() 48 | large = larger.pop() 49 | 50 | self_alias[small] = large 51 | self_prob[large] = (self_prob[large] - 1.0) + self_prob[small] 52 | 53 | if self_prob[large] < 1.0: 54 | smaller.append(large) 55 | else: 56 | larger.append(large) 57 | 58 | for last_one in smaller+larger: 59 | self_prob[last_one] = 1 60 | 61 | self.register_buffer('prob', torch.Tensor(self_prob)) 62 | self.register_buffer('alias', torch.LongTensor(self_alias)) 63 | 64 | def draw(self, *size): 65 | """Draw N samples from multinomial 66 | Args: 67 | - size: the output size of samples 68 | """ 69 | max_value = self.alias.size(0) 70 | 71 | kk = self.alias.new(*size).random_(0, max_value).long().view(-1) 72 | prob = self.prob[kk] 73 | alias = self.alias[kk] 74 | # b is whether a random number is greater than q 75 | b = torch.bernoulli(prob).long() 76 | oq = kk.mul(b) 77 | oj = alias.mul(1 - b) 78 | 79 | return (oq + oj).view(size) 80 | 81 | class NCELoss(nn.Module): 82 | """Noise Contrastive Estimation 83 | NCE is to eliminate the computational cost of softmax 84 | normalization. 85 | There are 3 loss modes in this NCELoss module: 86 | - nce: enable the NCE approximation 87 | - sampled: enabled sampled softmax approximation 88 | - full: use the original cross entropy as default loss 89 | They can be switched by directly setting `nce.loss_type = 'nce'`. 90 | Ref: 91 | X.Chen etal Recurrent neural network language 92 | model training with noise contrastive estimation 93 | for speech recognition 94 | https://core.ac.uk/download/pdf/42338485.pdf 95 | Attributes: 96 | noise: the distribution of noise 97 | noise_ratio: $\frac{#noises}{#real data samples}$ (k in paper) 98 | norm_term: the normalization term (lnZ in paper), can be heuristically 99 | determined by the number of classes, plz refer to the code. 100 | reduction: reduce methods, same with pytorch's loss framework, 'none', 101 | 'elementwise_mean' and 'sum' are supported. 102 | loss_type: loss type of this module, currently 'full', 'sampled', 'nce' 103 | are supported 104 | Shape: 105 | - noise: :math:`(V)` where `V = vocabulary size` 106 | - target: :math:`(B, N)` 107 | - loss: a scalar loss by default, :math:`(B, N)` if `reduction='none'` 108 | Input: 109 | target: the supervised training label. 110 | args&kwargs: extra arguments passed to underlying index module 111 | Return: 112 | loss: if `reduction='sum' or 'elementwise_mean'` the scalar NCELoss ready for backward, 113 | else the loss matrix for every individual targets. 114 | """ 115 | 116 | def __init__(self, 117 | noise, 118 | noise_ratio=100, 119 | norm_term='auto', 120 | reduction='elementwise_mean', 121 | per_word=False, 122 | loss_type='nce', 123 | beta = 0, 124 | device=None 125 | ): 126 | super(NCELoss, self).__init__() 127 | self.device = device 128 | # Re-norm the given noise frequency list and compensate words with 129 | # extremely low prob for numeric stability 130 | self.update_noise(noise) 131 | 132 | # @todo quick path to 'full' mode 133 | # @todo if noise_ratio is 1, use all items as samples 134 | self.noise_ratio = noise_ratio 135 | self.beta = beta 136 | if norm_term == 'auto': 137 | self.norm_term = math.log(noise.numel()) 138 | else: 139 | self.norm_term = norm_term 140 | self.reduction = reduction 141 | self.per_word = per_word 142 | self.bce_with_logits = nn.BCEWithLogitsLoss(reduction='none') 143 | self.ce = nn.CrossEntropyLoss(reduction='none') 144 | self.loss_type = loss_type 145 | 146 | def update_noise(self, noise): 147 | probs = noise / noise.sum() 148 | probs = probs.clamp(min=BACKOFF_PROB) 149 | renormed_probs = probs / probs.sum() 150 | # import pdb; pdb.set_trace() 151 | self.register_buffer('logprob_noise', renormed_probs.log()) 152 | self.alias = AliasMultinomial(renormed_probs) 153 | 154 | def forward(self, target, input, embs, interests=None, loss_fn = None, *args, **kwargs): 155 | """compute the loss with output and the desired target 156 | The `forward` is the same among all NCELoss submodules, it 157 | takes care of generating noises and calculating the loss 158 | given target and noise scores. 159 | """ 160 | 161 | batch = target.size(0) 162 | max_len = target.size(1) 163 | if self.loss_type != 'full': 164 | 165 | # use all or sampled 166 | # noise_samples = self.get_noise(batch, max_len) 167 | noise_samples = torch.arange(embs.size(0)).to(self.device).unsqueeze(0).unsqueeze(0).repeat(batch, 1, 1) if self.noise_ratio == 1 else self.get_noise(batch, max_len) 168 | 169 | logit_noise_in_noise = self.logprob_noise[noise_samples.data.view(-1)].view_as(noise_samples) 170 | logit_target_in_noise = self.logprob_noise[target.data.view(-1)].view_as(target) 171 | 172 | # B,N,Nr 173 | 174 | # (B,N), (B,N,Nr) 175 | logit_noise_in_noise = self.logprob_noise[noise_samples.data.view(-1)].view_as(noise_samples) 176 | logit_target_in_noise = self.logprob_noise[target.data.view(-1)].view_as(target) 177 | 178 | logit_target_in_model, logit_noise_in_model = self._get_logit(target, noise_samples, input, embs, *args, **kwargs) 179 | 180 | 181 | 182 | if self.loss_type == 'nce': 183 | if self.training: 184 | loss = self.nce_loss( 185 | logit_target_in_model, logit_noise_in_model, 186 | logit_noise_in_noise, logit_target_in_noise, 187 | ) 188 | else: 189 | # directly output the approximated posterior 190 | loss = - logit_target_in_model 191 | elif self.loss_type == 'sampled': 192 | loss = self.sampled_softmax_loss( 193 | logit_target_in_model, logit_noise_in_model, 194 | logit_noise_in_noise, logit_target_in_noise, 195 | ) 196 | # NOTE: The mix mode is still under investigation 197 | elif self.loss_type == 'mix' and self.training: 198 | loss = 0.5 * self.nce_loss( 199 | logit_target_in_model, logit_noise_in_model, 200 | logit_noise_in_noise, logit_target_in_noise, 201 | ) 202 | loss += 0.5 * self.sampled_softmax_loss( 203 | logit_target_in_model, logit_noise_in_model, 204 | logit_noise_in_noise, logit_target_in_noise, 205 | ) 206 | 207 | else: 208 | current_stage = 'training' if self.training else 'inference' 209 | raise NotImplementedError( 210 | 'loss type {} not implemented at {}'.format( 211 | self.loss_type, current_stage 212 | ) 213 | ) 214 | 215 | else: 216 | # Fallback into conventional cross entropy 217 | loss = self.ce_loss(target, *args, **kwargs) 218 | 219 | if self.reduction == 'elementwise_mean': 220 | return loss.mean() 221 | elif self.reduction == 'sum': 222 | return loss.sum() 223 | else: 224 | return loss 225 | 226 | def get_noise(self, batch_size, max_len): 227 | """Generate noise samples from noise distribution""" 228 | 229 | noise_size = (batch_size, max_len, self.noise_ratio) 230 | if self.per_word: 231 | noise_samples = self.alias.draw(*noise_size) 232 | else: 233 | noise_samples = self.alias.draw(1, 1, self.noise_ratio).expand(*noise_size) 234 | 235 | noise_samples = noise_samples.contiguous() 236 | return noise_samples 237 | 238 | def _get_logit(self, target_idx, noise_idx,input, embs, *args, **kwargs): 239 | """Get the logits of NCE estimated probability for target and noise 240 | Both NCE and sampled softmax Loss are unchanged when the probabilities are scaled 241 | evenly, here we subtract the maximum value as in softmax, for numeric stability. 242 | Shape: 243 | - Target_idx: :math:`(N)` 244 | - Noise_idx: :math:`(N, N_r)` where `N_r = noise ratio` 245 | """ 246 | 247 | target_logit, noise_logit = self.get_score(target_idx, noise_idx, input, embs, *args, **kwargs) 248 | 249 | # import pdb; pdb.set_trace() 250 | target_logit = target_logit.sub(self.norm_term) 251 | noise_logit = noise_logit.sub(self.norm_term) 252 | # import pdb; pdb.set_trace() 253 | return target_logit, noise_logit 254 | 255 | def get_score(self, target_idx, noise_idx, input, embs, *args, **kwargs): 256 | """Get the target and noise score 257 | Usually logits are used as score. 258 | This method should be override by inherit classes 259 | Returns: 260 | - target_score: real valued score for each target index 261 | - noise_score: real valued score for each noise index 262 | """ 263 | original_size = target_idx.size() 264 | 265 | # flatten the following matrix 266 | input = input.contiguous().view(-1, input.size(-1)) 267 | target_idx = target_idx.view(-1) 268 | noise_idx = noise_idx[0, 0].view(-1) 269 | # import pdb; pdb.set_trace() 270 | target_batch = embs[target_idx] 271 | # import pdb; pdb.set_trace() 272 | # target_bias = self.bias.index_select(0, target_idx) # N 273 | target_score = torch.sum(input * target_batch, dim=1) # N X E * N X E 274 | 275 | noise_batch = embs[noise_idx] # Nr X H 276 | noise_score = torch.matmul( 277 | input, noise_batch.t() 278 | ) 279 | return target_score.view(original_size), noise_score.view(*original_size, -1) 280 | 281 | def ce_loss(self, target_idx, *args, **kwargs): 282 | """Get the conventional CrossEntropyLoss 283 | The returned loss should be of the same size of `target` 284 | Args: 285 | - target_idx: batched target index 286 | - args, kwargs: any arbitrary input if needed by sub-class 287 | Returns: 288 | - loss: the estimated loss for each target 289 | """ 290 | raise NotImplementedError() 291 | 292 | def nce_loss(self, logit_target_in_model, logit_noise_in_model, logit_noise_in_noise, logit_target_in_noise): 293 | """Compute the classification loss given all four probabilities 294 | Args: 295 | - logit_target_in_model: logit of target words given by the model (RNN) 296 | - logit_noise_in_model: logit of noise words given by the model 297 | - logit_noise_in_noise: logit of noise words given by the noise distribution 298 | - logit_target_in_noise: logit of target words given by the noise distribution 299 | Returns: 300 | - loss: a mis-classification loss for every single case 301 | """ 302 | 303 | # NOTE: prob <= 1 is not guaranteed 304 | logit_model = torch.cat([logit_target_in_model.unsqueeze(2), logit_noise_in_model], dim=2) 305 | logit_noise = torch.cat([logit_target_in_noise.unsqueeze(2), logit_noise_in_noise], dim=2) 306 | 307 | # predicted probability of the word comes from true data distribution 308 | # The posterior can be computed as following 309 | # p_true = logit_model.exp() / (logit_model.exp() + self.noise_ratio * logit_noise.exp()) 310 | # For numeric stability we compute the logits of true label and 311 | # directly use bce_with_logits. 312 | # Ref https://pytorch.org/docs/stable/nn.html?highlight=bce#torch.nn.BCEWithLogitsLoss 313 | logit_true = logit_model - logit_noise - math.log(self.noise_ratio) 314 | 315 | label = torch.zeros_like(logit_model) 316 | label[:, :, 0] = 1 317 | 318 | loss = self.bce_with_logits(logit_true, label).sum(dim=2) 319 | return loss 320 | 321 | def sampled_softmax_loss(self, logit_target_in_model, logit_noise_in_model, logit_noise_in_noise, logit_target_in_noise): 322 | """Compute the sampled softmax loss based on the tensorflow's impl""" 323 | ori_logits = torch.cat([logit_target_in_model.unsqueeze(2), logit_noise_in_model], dim=2) 324 | q_logits = torch.cat([logit_target_in_noise.unsqueeze(2), logit_noise_in_noise], dim=2) 325 | 326 | # subtract Q for correction of biased sampling 327 | logits = ori_logits - q_logits 328 | labels = torch.zeros_like(logits.narrow(2, 0, 1)).squeeze(2).long() 329 | 330 | if self.beta == 0: 331 | loss = self.ce( 332 | logits.view(-1, logits.size(-1)), 333 | labels.view(-1), 334 | ).view_as(labels) 335 | 336 | if self.beta != 0: 337 | x = ori_logits.view(-1, ori_logits.size(-1)) 338 | x = x - torch.max(x, dim = -1)[0].unsqueeze(-1) 339 | pos = torch.exp(x[:,0]) 340 | neg = torch.exp(x[:,1:]) 341 | imp = (self.beta * x[:,1:] - torch.max(self.beta * x[:,1:],dim = -1)[0].unsqueeze(-1)).exp() 342 | reweight_neg = (imp*neg).sum(dim = -1) / imp.mean(dim = -1) 343 | if torch.isinf(reweight_neg).any() or torch.isnan(reweight_neg).any(): 344 | import pdb; pdb.set_trace() 345 | Ng = reweight_neg 346 | 347 | stable_logsoftmax = -(x[:,0] - torch.log(pos + Ng)) 348 | loss = torch.unsqueeze(stable_logsoftmax, 1) 349 | 350 | return loss 351 | 352 | 353 | def build_noise(number, args=None): 354 | if args.sample_prob == 0: 355 | return build_uniform_noise(number) 356 | if args.sample_prob == 1: 357 | return build_log_noise(number) 358 | 359 | def build_log_noise(number): 360 | total = number 361 | freq = torch.Tensor([1.0] * number).cuda() 362 | noise = freq / total 363 | for i in range(number): 364 | noise[i] = (np.log(i + 2) - np.log(i + 1)) / np.log(number + 1) 365 | 366 | assert abs(noise.sum() - 1) < 0.001 367 | return noise 368 | 369 | def build_uniform_noise(number): 370 | total = number 371 | freq = torch.Tensor([1.0] * number).cuda() 372 | noise = freq / total 373 | assert abs(noise.sum() - 1) < 0.001 374 | return noise 375 | 376 | import time 377 | from torch.nn.init import xavier_uniform_, xavier_normal_ 378 | 379 | class BasicModel(nn.Module): 380 | 381 | def __init__(self, item_num, hidden_size, batch_size, seq_len=50, beta=0): 382 | super(BasicModel, self).__init__() 383 | self.name = 'base' 384 | self.hidden_size = hidden_size 385 | self.batch_size = batch_size 386 | self.item_num = item_num 387 | self.seq_len = seq_len 388 | self.beta = beta 389 | self.embeddings = nn.Embedding(self.item_num, self.hidden_size, padding_idx=0) 390 | self.interest_num = 0 391 | 392 | def set_device(self, device): 393 | self.device = device 394 | 395 | def set_sampler(self, args, device=None): 396 | self.is_sampler = True 397 | if args.sampled_n == 0: 398 | self.is_sampler = False 399 | return 400 | 401 | self.sampled_n = args.sampled_n 402 | 403 | noise = build_noise(self.item_num, args) 404 | 405 | self.sample_loss = NCELoss(noise=noise, 406 | noise_ratio=self.sampled_n, 407 | norm_term=0, 408 | reduction='elementwise_mean', 409 | per_word=False, 410 | loss_type=args.sampled_loss, 411 | beta=self.beta, 412 | device=device 413 | ) 414 | 415 | def _init_weights(self, module): 416 | if isinstance(module, nn.Embedding): 417 | xavier_normal_(module.weight) 418 | elif isinstance(module, nn.GRU): 419 | xavier_uniform_(module.weight_hh_l0) 420 | xavier_uniform_(module.weight_ih_l0) 421 | 422 | def reset_parameters(self, initializer=None): 423 | for weight in self.parameters(): 424 | torch.nn.init.kaiming_normal_(weight) 425 | 426 | 427 | def read_out(self, user_eb, label_eb): 428 | 429 | # 这个模型训练过程中label是可见的,此处的item_eb就是label物品的嵌入 430 | atten = torch.matmul(user_eb, # shape=(batch_size, interest_num, hidden_size) 431 | torch.reshape(label_eb, (-1, self.hidden_size, 1)) # shape=(batch_size, hidden_size, 1) 432 | ) # shape=(batch_size, interest_num, 1) 433 | 434 | atten = F.softmax(torch.pow(torch.reshape(atten, (-1, self.interest_num)), 1), dim=-1) # shape=(batch_size, interest_num) 435 | 436 | if self.hard_readout: # 选取interest_num个兴趣胶囊中的一个,MIND和ComiRec都是用的这种方式 437 | readout = torch.reshape(user_eb, (-1, self.hidden_size))[ 438 | (torch.argmax(atten, dim=-1) + torch.arange(label_eb.shape[0], device=user_eb.device) * self.interest_num).long()] 439 | else: # 综合interest_num个兴趣胶囊,论文及代码实现中没有使用这种方法 440 | readout = torch.matmul(torch.reshape(atten, (label_eb.shape[0], 1, self.interest_num)), # shape=(batch_size, 1, interest_num) 441 | user_eb # shape=(batch_size, interest_num, hidden_size) 442 | ) # shape=(batch_size, 1, hidden_size) 443 | readout = torch.reshape(readout, (label_eb.shape[0], self.hidden_size)) # shape=(batch_size, hidden_size) 444 | # readout是vu堆叠成的矩阵(一个batch的vu)(vu可以说就是最终的用户嵌入) 445 | selection = torch.argmax(atten, dim=-1) 446 | return readout, selection 447 | 448 | 449 | def calculate_score(self, user_eb): 450 | all_items = self.embeddings.weight 451 | scores = torch.matmul(user_eb, all_items.transpose(1, 0)) # [b, n] 452 | return scores 453 | 454 | 455 | def output_items(self): 456 | return self.embeddings.weight 457 | 458 | def calculate_full_loss(self, loss_fn, scores, target, interests): 459 | return loss_fn(scores, target) 460 | 461 | 462 | def calculate_sampled_loss(self, readout, pos_items, selection, interests): 463 | return self.sample_loss(pos_items.unsqueeze(-1), readout, self.embeddings.weight) 464 | 465 | 466 | import numpy as np 467 | import random 468 | import math 469 | class LogUniformSampler(object): 470 | def __init__(self, ntokens, device): 471 | 472 | self.N = ntokens 473 | self.prob = [0] * self.N 474 | 475 | self.generate_distribution() 476 | self.prob_tensor = torch.tensor(self.prob) 477 | self.cans = torch.arange(0, self.N) 478 | 479 | def generate_distribution(self): 480 | for i in range(self.N): 481 | self.prob[i] = (np.log(i+2) - np.log(i+1)) / np.log(self.N + 1) 482 | 483 | def probability(self, idx): 484 | return self.prob[idx] 485 | 486 | def expected_count(self, num_tries, samples): 487 | freq = list() 488 | for sample_idx in samples: 489 | freq.append(-(np.exp(num_tries * np.log(1-self.prob[sample_idx]))-1)) 490 | return freq 491 | 492 | def accidental_match(self, labels, samples): 493 | sample_dict = dict() 494 | 495 | for idx in range(len(samples)): 496 | sample_dict[samples[idx]] = idx 497 | 498 | result = list() 499 | for idx in range(len(labels)): 500 | if labels[idx] in sample_dict: 501 | result.append((idx, sample_dict[labels[idx]])) 502 | 503 | return result 504 | 505 | def sample(self, size, labels): 506 | log_N = np.log(self.N) 507 | 508 | x = np.random.uniform(low=0.0, high=1.0, size=size) 509 | value = np.floor(np.exp(x * log_N)).astype(int) - 1 510 | samples = value.tolist() 511 | 512 | true_freq = self.expected_count(size, labels.tolist()) 513 | sample_freq = self.expected_count(size, samples) 514 | if random.random() < 0.0002: 515 | print('By softmax', [round(i, 3) for i in true_freq], [round(i, 3) for i in sample_freq]) 516 | 517 | return samples, true_freq, sample_freq 518 | 519 | def sample_uniform_prob(self, size, labels): 520 | idx = self.prob_tensor.multinomial(num_samples=size, replacement=False) 521 | b = self.cans[idx] 522 | 523 | true_freq = self.expected_count(size, labels.tolist()) 524 | sample_freq = self.expected_count(size, b) 525 | if random.random() < 0.0002: 526 | print('By uniform prob', [round(i, 3) for i in true_freq], [round(i, 3) for i in sample_freq]) 527 | 528 | return b, true_freq, sample_freq 529 | 530 | def sample_uniform(self, size, labels): 531 | indice = random.sample(range(self.N), size) 532 | indice = torch.tensor(indice) 533 | 534 | true_freq = self.expected_count(size, labels.tolist()) 535 | sample_freq = self.expected_count(size, indice) 536 | # print('By uniform', true_freq, sample_freq) 537 | 538 | return indice, true_freq, sample_freq 539 | 540 | def sample_unique(self, size, labels): 541 | # Slow. Not Recommended. 542 | log_N = np.log(self.N) 543 | samples = list() 544 | 545 | while (len(samples) < size): 546 | x = np.random.uniform(low=0.0, high=1.0, size=1)[0] 547 | value = np.floor(np.exp(x * log_N)).astype(int) - 1 548 | if value in samples: 549 | continue 550 | else: 551 | samples.append(value) 552 | 553 | true_freq = self.expected_count(size, labels.tolist()) 554 | sample_freq = self.expected_count(size, samples) 555 | 556 | return samples, true_freq, sample_freq 557 | 558 | 559 | class CapsuleNetwork(nn.Module): 560 | 561 | def __init__(self, hidden_size, seq_len, bilinear_type=2, interest_num=4, routing_times=3, hard_readout=True, relu_layer=False): 562 | super(CapsuleNetwork, self).__init__() 563 | self.hidden_size = hidden_size # h 564 | self.seq_len = seq_len # s 565 | self.bilinear_type = bilinear_type 566 | self.interest_num = interest_num 567 | self.routing_times = routing_times 568 | self.hard_readout = hard_readout 569 | self.relu_layer = relu_layer 570 | self.stop_grad = True 571 | self.relu = nn.Sequential( 572 | nn.Linear(self.hidden_size, self.hidden_size, bias=False), 573 | nn.ReLU() 574 | ) 575 | if self.bilinear_type == 0: # MIND 576 | self.linear = nn.Linear(self.hidden_size, self.hidden_size, bias=False) 577 | elif self.bilinear_type == 1: 578 | self.linear = nn.Linear(self.hidden_size, self.hidden_size * self.interest_num, bias=False) 579 | else: # ComiRec_DR 580 | self.w = nn.Parameter(torch.Tensor(1, self.seq_len, self.interest_num * self.hidden_size, self.hidden_size)) 581 | 582 | 583 | def forward(self, item_eb, mask, device): 584 | if self.bilinear_type == 0: # MIND 585 | item_eb_hat = self.linear(item_eb) # [b, s, h] 586 | item_eb_hat = item_eb_hat.repeat(1, 1, self.interest_num) # [b, s, h*in] 587 | elif self.bilinear_type == 1: 588 | item_eb_hat = self.linear(item_eb) 589 | else: # ComiRec_DR 590 | u = torch.unsqueeze(item_eb, dim=2) # shape=(batch_size, maxlen, 1, embedding_dim) 591 | item_eb_hat = torch.sum(self.w[:, :self.seq_len, :, :] * u, dim=3) # shape=(batch_size, maxlen, hidden_size*interest_num) 592 | 593 | item_eb_hat = torch.reshape(item_eb_hat, (-1, self.seq_len, self.interest_num, self.hidden_size)) 594 | item_eb_hat = torch.transpose(item_eb_hat, 1, 2).contiguous() 595 | item_eb_hat = torch.reshape(item_eb_hat, (-1, self.interest_num, self.seq_len, self.hidden_size)) 596 | 597 | # [b, in, s, h] 598 | if self.stop_grad: # 截断反向传播,item_emb_hat不计入梯度计算中 599 | item_eb_hat_iter = item_eb_hat.detach() 600 | else: 601 | item_eb_hat_iter = item_eb_hat 602 | 603 | # b的shape=(b, in, s) 604 | if self.bilinear_type > 0: # b初始化为0(一般的胶囊网络算法) 605 | capsule_weight = torch.zeros(item_eb_hat.shape[0], self.interest_num, self.seq_len, device=device, requires_grad=False) 606 | else: # MIND使用高斯分布随机初始化b 607 | capsule_weight = torch.randn(item_eb_hat.shape[0], self.interest_num, self.seq_len, device=device, requires_grad=False) 608 | 609 | for i in range(self.routing_times): # 动态路由传播3次 610 | atten_mask = torch.unsqueeze(mask, 1).repeat(1, self.interest_num, 1) # [b, in, s] 611 | paddings = torch.zeros_like(atten_mask, dtype=torch.float) 612 | 613 | # 计算c,进行mask,最后shape=[b, in, 1, s] 614 | capsule_softmax_weight = F.softmax(capsule_weight, dim=-1) 615 | capsule_softmax_weight = torch.where(torch.eq(atten_mask, 0), paddings, capsule_softmax_weight) # mask 616 | capsule_softmax_weight = torch.unsqueeze(capsule_softmax_weight, 2) 617 | 618 | if i < 2: 619 | # s=c*u_hat , (batch_size, interest_num, 1, seq_len) * (batch_size, interest_num, seq_len, hidden_size) 620 | interest_capsule = torch.matmul(capsule_softmax_weight, item_eb_hat_iter) # shape=(batch_size, interest_num, 1, hidden_size) 621 | cap_norm = torch.sum(torch.square(interest_capsule), -1, True) # shape=(batch_size, interest_num, 1, 1) 622 | scalar_factor = cap_norm / (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9) # shape同上 623 | interest_capsule = scalar_factor * interest_capsule # squash(s)->v,shape=(batch_size, interest_num, 1, hidden_size) 624 | 625 | # 更新b 626 | delta_weight = torch.matmul(item_eb_hat_iter, # shape=(batch_size, interest_num, seq_len, hidden_size) 627 | torch.transpose(interest_capsule, 2, 3).contiguous() # shape=(batch_size, interest_num, hidden_size, 1) 628 | ) # u_hat*v, shape=(batch_size, interest_num, seq_len, 1) 629 | delta_weight = torch.reshape(delta_weight, (-1, self.interest_num, self.seq_len)) # shape=(batch_size, interest_num, seq_len) 630 | capsule_weight = capsule_weight + delta_weight # 更新b 631 | else: 632 | interest_capsule = torch.matmul(capsule_softmax_weight, item_eb_hat) 633 | cap_norm = torch.sum(torch.square(interest_capsule), -1, True) 634 | scalar_factor = cap_norm / (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9) 635 | interest_capsule = scalar_factor * interest_capsule 636 | 637 | interest_capsule = torch.reshape(interest_capsule, (-1, self.interest_num, self.hidden_size)) 638 | 639 | if self.relu_layer: # MIND模型使用book数据库时,使用relu_layer 640 | interest_capsule = self.relu(interest_capsule) 641 | 642 | return interest_capsule 643 | --------------------------------------------------------------------------------