├── .idea ├── .gitignore ├── IOCRec.iml ├── inspectionProfiles │ └── profiles_settings.xml ├── misc.xml ├── modules.xml └── vcs.xml ├── README.md ├── dataset └── toys │ └── toys.seq ├── log └── toys │ ├── IOCRec-seed-2024_1.log │ ├── IOCRec-seed2023_1.log │ ├── IOCRec-seed2024-hidden128_1.log │ ├── IOCRec-seed2024-hidden256_1.log │ └── IOCRec_1.log ├── runIOCRec.py ├── src ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ └── __init__.cpython-39.pyc ├── dataset │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── data_processor.cpython-38.pyc │ │ └── dataset.cpython-38.pyc │ ├── data_processor.py │ └── dataset.py ├── evaluation │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── estimator.cpython-38.pyc │ │ └── metrics.cpython-38.pyc │ ├── estimator.py │ └── metrics.py ├── model │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── abstract_recommeder.cpython-38.pyc │ │ ├── data_augmentation.cpython-38.pyc │ │ ├── loss.cpython-38.pyc │ │ └── sequential_encoder.cpython-38.pyc │ ├── abstract_recommeder.py │ ├── cl_based_seq_recommender │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── cl4srec.cpython-38.pyc │ │ │ └── iocrec.cpython-38.pyc │ │ ├── cl4srec.py │ │ └── iocrec.py │ ├── data_augmentation.py │ ├── loss.py │ ├── sequential_encoder.py │ └── sequential_recommender │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ └── sasrec.cpython-38.pyc │ │ └── sasrec.py ├── train │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── config.cpython-38.pyc │ │ └── trainer.cpython-38.pyc │ ├── config.py │ └── trainer.py └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── recorder.cpython-38.pyc │ └── utils.cpython-38.pyc │ ├── recorder.py │ └── utils.py └── toys.sh /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored 
files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /.idea/IOCRec.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # IOCRec 2 | PyTorch implementation of the paper: Multi-Intention Oriented Contrastive Learning for Sequential Recommendation (WSDM 2023). 3 | 4 | We implement IOCRec in PyTorch and obtain quite similar results on the Toys dataset under the same experimental setting. The default hyper-parameters are set to the optimal values for Toys reported in the paper. The training logs are also provided for reproduction. 
5 | 6 | ``` 7 | 2023-07-07 15:48:05 INFO ------------------------------------------------Best Evaluation------------------------------------------------ 8 | 2023-07-07 15:48:05 INFO Best Result at Epoch: 33 Early Stop at Patience: 10 9 | 2023-07-07 15:48:05 INFO hit@5:0.4513 hit@10:0.5453 hit@20:0.6621 hit@50:0.7935 ndcg@5:0.3588 ndcg@10:0.3891 ndcg@20:0.4186 ndcg@50:0.4455 10 | 2023-07-07 15:48:07 INFO -----------------------------------------------------Test Results------------------------------------------------------ 11 | 2023-07-07 15:48:07 INFO hit@5:0.4022 hit@10:0.5005 hit@20:0.6205 hit@50:0.7594 ndcg@5:0.3145 ndcg@10:0.3462 ndcg@20:0.3765 ndcg@50:0.4048 12 | ``` 13 | ## Datasets 14 | We provide the Toys dataset. 15 | 16 | ## Quick Start 17 | You can run the model with the following command: 18 | ``` 19 | python runIOCRec.py --dataset toys --eval_mode uni100 --embed_size 64 --k_intention 4 20 | ``` 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /log/toys/IOCRec-seed2023_1.log: -------------------------------------------------------------------------------- 1 | 2024-07-23 15:17:34 INFO log save at : log\toys\IOCRec-seed2023_1.log 2 | 2024-07-23 15:17:34 INFO model save at: save\IOCRec-toys-seed2023-2024-07-23_15-17-34.pth 3 | 2024-07-23 15:17:35 INFO [1] Model Hyper-Parameter --------------------- 4 | 2024-07-23 15:17:35 INFO model: IOCRec 5 | 2024-07-23 15:17:35 INFO model_type: SEQUENTIAL 6 | 2024-07-23 15:17:35 INFO aug_types: ['crop', 'mask', 'reorder'] 7 | 2024-07-23 15:17:35 INFO crop_ratio: 0.2 8 | 2024-07-23 15:17:35 INFO mask_ratio: 0.7 9 | 2024-07-23 15:17:35 INFO reorder_ratio: 0.2 10 | 2024-07-23 15:17:35 INFO all_hidden: True 11 | 2024-07-23 15:17:35 INFO tao: 1.0 12 | 2024-07-23 15:17:35 INFO lamda: 0.1 13 | 2024-07-23 15:17:35 INFO k_intention: 4 14 | 2024-07-23 15:17:35 INFO embed_size: 64 15 | 2024-07-23 15:17:35 INFO ffn_hidden: 512 16 | 2024-07-23 15:17:35 INFO num_blocks: 3 17 | 2024-07-23 
15:17:35 INFO num_heads: 2 18 | 2024-07-23 15:17:35 INFO hidden_dropout: 0.5 19 | 2024-07-23 15:17:35 INFO attn_dropout: 0.5 20 | 2024-07-23 15:17:35 INFO layer_norm_eps: 1e-12 21 | 2024-07-23 15:17:35 INFO initializer_range: 0.02 22 | 2024-07-23 15:17:35 INFO loss_type: CE 23 | 2024-07-23 15:17:35 INFO [2] Experiment Hyper-Parameter ---------------- 24 | 2024-07-23 15:17:35 INFO [2-1] data hyper-parameter -------------------- 25 | 2024-07-23 15:17:35 INFO dataset: toys 26 | 2024-07-23 15:17:35 INFO data_aug: True 27 | 2024-07-23 15:17:35 INFO seq_filter_len: 0 28 | 2024-07-23 15:17:35 INFO if_filter_target: False 29 | 2024-07-23 15:17:35 INFO max_len: 50 30 | 2024-07-23 15:17:35 INFO [2-2] pretraining hyper-parameter ------------- 31 | 2024-07-23 15:17:35 INFO do_pretraining: False 32 | 2024-07-23 15:17:35 INFO pretraining_task: MISP 33 | 2024-07-23 15:17:35 INFO pretraining_epoch: 10 34 | 2024-07-23 15:17:35 INFO pretraining_batch: 512 35 | 2024-07-23 15:17:35 INFO pretraining_lr: 0.001 36 | 2024-07-23 15:17:35 INFO pretraining_l2: 0.0 37 | 2024-07-23 15:17:35 INFO [2-3] training hyper-parameter ---------------- 38 | 2024-07-23 15:17:35 INFO epoch_num: 150 39 | 2024-07-23 15:17:35 INFO train_batch: 256 40 | 2024-07-23 15:17:35 INFO learning_rate: 0.001 41 | 2024-07-23 15:17:35 INFO l2: 0 42 | 2024-07-23 15:17:35 INFO patience: 10 43 | 2024-07-23 15:17:35 INFO device: cuda:0 44 | 2024-07-23 15:17:35 INFO num_worker: 0 45 | 2024-07-23 15:17:35 INFO seed: 2023 46 | 2024-07-23 15:17:35 INFO [2-4] evaluation hyper-parameter -------------- 47 | 2024-07-23 15:17:35 INFO split_type: valid_and_test 48 | 2024-07-23 15:17:35 INFO split_mode: LS 49 | 2024-07-23 15:17:35 INFO eval_mode: uni100 50 | 2024-07-23 15:17:35 INFO metric: ['hit', 'ndcg'] 51 | 2024-07-23 15:17:35 INFO k: [5, 10, 20, 50] 52 | 2024-07-23 15:17:35 INFO valid_metric: hit@10 53 | 2024-07-23 15:17:35 INFO eval_batch: 256 54 | 2024-07-23 15:17:35 INFO [2-5] save hyper-parameter -------------------- 55 | 
2024-07-23 15:17:35 INFO log_save: log 56 | 2024-07-23 15:17:35 INFO save: save 57 | 2024-07-23 15:17:35 INFO model_saved: None 58 | 2024-07-23 15:17:35 INFO [3] Data Statistic ---------------------------- 59 | 2024-07-23 15:17:35 INFO dataset: toys 60 | 2024-07-23 15:17:35 INFO user number: 19412 61 | 2024-07-23 15:17:35 INFO item number: 11925 62 | 2024-07-23 15:17:35 INFO average seq length: 8.6337 63 | 2024-07-23 15:17:35 INFO density: 0.0007 sparsity: 0.9993 64 | 2024-07-23 15:17:35 INFO data after augmentation: 65 | 2024-07-23 15:17:35 INFO train samples: 109361 eval samples: 19412 test samples: 19412 66 | 2024-07-23 15:17:35 INFO [1] Model Architecture ------------------------ 67 | 2024-07-23 15:17:35 INFO total parameters: 1035520 68 | 2024-07-23 15:17:35 INFO IOCRec( 69 | (cross_entropy): CrossEntropyLoss() 70 | (item_embedding): Embedding(11927, 64, padding_idx=0) 71 | (position_embedding): Embedding(50, 64) 72 | (input_layer_norm): LayerNorm((64,), eps=1e-12, elementwise_affine=True) 73 | (input_dropout): Dropout(p=0.5, inplace=False) 74 | (local_encoder): Transformer( 75 | (encoder_layers): ModuleList( 76 | (0-2): 3 x EncoderLayer( 77 | (attn_layer_norm): LayerNorm((64,), eps=1e-12, elementwise_affine=True) 78 | (pff_layer_norm): LayerNorm((64,), eps=1e-12, elementwise_affine=True) 79 | (self_attention): MultiHeadAttentionLayer( 80 | (fc_q): Linear(in_features=64, out_features=64, bias=True) 81 | (fc_k): Linear(in_features=64, out_features=64, bias=True) 82 | (fc_v): Linear(in_features=64, out_features=64, bias=True) 83 | (attn_dropout): Dropout(p=0.5, inplace=False) 84 | (fc_o): Linear(in_features=64, out_features=64, bias=True) 85 | ) 86 | (pff): PointWiseFeedForwardLayer( 87 | (fc1): Linear(in_features=64, out_features=512, bias=True) 88 | (fc2): Linear(in_features=512, out_features=64, bias=True) 89 | ) 90 | (hidden_dropout): Dropout(p=0.5, inplace=False) 91 | (pff_out_drop): Dropout(p=0.5, inplace=False) 92 | ) 93 | ) 94 | ) 95 | 
(global_seq_encoder): GlobalSeqEncoder( 96 | (dropout): Dropout(p=0.5, inplace=False) 97 | (K_linear): Linear(in_features=64, out_features=64, bias=True) 98 | (V_linear): Linear(in_features=64, out_features=64, bias=True) 99 | ) 100 | (disentangle_encoder): DisentangleEncoder( 101 | (pos_fai): Embedding(50, 64) 102 | (W): Linear(in_features=64, out_features=64, bias=True) 103 | (layer_norm_1): LayerNorm((64,), eps=1e-05, elementwise_affine=True) 104 | (layer_norm_2): LayerNorm((64,), eps=1e-05, elementwise_affine=True) 105 | (layer_norm_3): LayerNorm((64,), eps=1e-05, elementwise_affine=True) 106 | (layer_norm_4): LayerNorm((64,), eps=1e-05, elementwise_affine=True) 107 | (layer_norm_5): LayerNorm((64,), eps=1e-05, elementwise_affine=True) 108 | ) 109 | (nce_loss): InfoNCELoss( 110 | (criterion): CrossEntropyLoss() 111 | ) 112 | ) 113 | 2024-07-23 15:17:35 INFO Start training... 114 | 2024-07-23 15:18:09 INFO ----------------------------------------------------Epoch 1---------------------------------------------------- 115 | 2024-07-23 15:18:09 INFO Training Time :[33.3 s] Training Loss = 11.0718 116 | 2024-07-23 15:18:10 INFO Evaluation Time:[1.0 s] Eval Loss = 9.3946 117 | 2024-07-23 15:18:10 INFO hit@5:0.0718 hit@10:0.1202 hit@20:0.2062 hit@50:0.3898 ndcg@5:0.0443 ndcg@10:0.0598 ndcg@20:0.0813 ndcg@50:0.1180 118 | 2024-07-23 15:18:42 INFO ----------------------------------------------------Epoch 2---------------------------------------------------- 119 | 2024-07-23 15:18:42 INFO Training Time :[32.8 s] Training Loss = 9.9107 120 | 2024-07-23 15:18:43 INFO Evaluation Time:[0.9 s] Eval Loss = 9.3376 121 | 2024-07-23 15:18:43 INFO hit@5:0.1525 hit@10:0.2575 hit@20:0.4075 hit@50:0.4748 ndcg@5:0.0966 ndcg@10:0.1302 ndcg@20:0.1681 ndcg@50:0.1825 122 | 2024-07-23 15:19:16 INFO ----------------------------------------------------Epoch 3---------------------------------------------------- 123 | 2024-07-23 15:19:16 INFO Training Time :[32.8 s] Training Loss = 9.5619 124 | 
2024-07-23 15:19:17 INFO Evaluation Time:[0.9 s] Eval Loss = 9.1699 125 | 2024-07-23 15:19:17 INFO hit@5:0.2255 hit@10:0.3208 hit@20:0.4547 hit@50:0.5217 ndcg@5:0.1567 ndcg@10:0.1874 ndcg@20:0.2211 ndcg@50:0.2354 126 | 2024-07-23 15:19:50 INFO ----------------------------------------------------Epoch 4---------------------------------------------------- 127 | 2024-07-23 15:19:50 INFO Training Time :[32.8 s] Training Loss = 9.2151 128 | 2024-07-23 15:19:51 INFO Evaluation Time:[0.9 s] Eval Loss = 9.0117 129 | 2024-07-23 15:19:51 INFO hit@5:0.2879 hit@10:0.3840 hit@20:0.5012 hit@50:0.5750 ndcg@5:0.2079 ndcg@10:0.2389 ndcg@20:0.2683 ndcg@50:0.2839 130 | 2024-07-23 15:20:24 INFO ----------------------------------------------------Epoch 5---------------------------------------------------- 131 | 2024-07-23 15:20:24 INFO Training Time :[32.9 s] Training Loss = 8.9182 132 | 2024-07-23 15:20:25 INFO Evaluation Time:[1.0 s] Eval Loss = 8.8102 133 | 2024-07-23 15:20:25 INFO hit@5:0.3426 hit@10:0.4392 hit@20:0.5495 hit@50:0.6274 ndcg@5:0.2529 ndcg@10:0.2841 ndcg@20:0.3119 ndcg@50:0.3282 134 | 2024-07-23 15:20:58 INFO ----------------------------------------------------Epoch 6---------------------------------------------------- 135 | 2024-07-23 15:20:58 INFO Training Time :[33.1 s] Training Loss = 8.6244 136 | 2024-07-23 15:20:59 INFO Evaluation Time:[0.9 s] Eval Loss = 8.581 137 | 2024-07-23 15:20:59 INFO hit@5:0.3856 hit@10:0.4870 hit@20:0.5982 hit@50:0.6972 ndcg@5:0.2900 ndcg@10:0.3227 ndcg@20:0.3508 ndcg@50:0.3712 138 | 2024-07-23 15:21:32 INFO ----------------------------------------------------Epoch 7---------------------------------------------------- 139 | 2024-07-23 15:21:32 INFO Training Time :[33.1 s] Training Loss = 8.3363 140 | 2024-07-23 15:21:33 INFO Evaluation Time:[0.9 s] Eval Loss = 8.4352 141 | 2024-07-23 15:21:33 INFO hit@5:0.4089 hit@10:0.5100 hit@20:0.6219 hit@50:0.7328 ndcg@5:0.3132 ndcg@10:0.3459 ndcg@20:0.3741 ndcg@50:0.3968 142 | 2024-07-23 15:22:06 
INFO ----------------------------------------------------Epoch 8---------------------------------------------------- 143 | 2024-07-23 15:22:06 INFO Training Time :[32.9 s] Training Loss = 8.0958 144 | 2024-07-23 15:22:07 INFO Evaluation Time:[0.9 s] Eval Loss = 8.3217 145 | 2024-07-23 15:22:07 INFO hit@5:0.4293 hit@10:0.5254 hit@20:0.6391 hit@50:0.7704 ndcg@5:0.3325 ndcg@10:0.3635 ndcg@20:0.3922 ndcg@50:0.4189 146 | 2024-07-23 15:22:40 INFO ----------------------------------------------------Epoch 9---------------------------------------------------- 147 | 2024-07-23 15:22:40 INFO Training Time :[33.0 s] Training Loss = 7.8828 148 | 2024-07-23 15:22:41 INFO Evaluation Time:[0.9 s] Eval Loss = 8.3024 149 | 2024-07-23 15:22:41 INFO hit@5:0.4324 hit@10:0.5290 hit@20:0.6475 hit@50:0.7881 ndcg@5:0.3393 ndcg@10:0.3705 ndcg@20:0.4003 ndcg@50:0.4288 150 | 2024-07-23 15:23:13 INFO ----------------------------------------------------Epoch 10---------------------------------------------------- 151 | 2024-07-23 15:23:13 INFO Training Time :[32.9 s] Training Loss = 7.6411 152 | 2024-07-23 15:23:14 INFO Evaluation Time:[0.9 s] Eval Loss = 8.3549 153 | 2024-07-23 15:23:14 INFO hit@5:0.4394 hit@10:0.5328 hit@20:0.6481 hit@50:0.8010 ndcg@5:0.3455 ndcg@10:0.3756 ndcg@20:0.4047 ndcg@50:0.4356 154 | 2024-07-23 15:23:47 INFO ----------------------------------------------------Epoch 11---------------------------------------------------- 155 | 2024-07-23 15:23:47 INFO Training Time :[32.8 s] Training Loss = 7.4921 156 | 2024-07-23 15:23:48 INFO Evaluation Time:[0.9 s] Eval Loss = 8.4414 157 | 2024-07-23 15:23:48 INFO hit@5:0.4396 hit@10:0.5349 hit@20:0.6530 hit@50:0.8046 ndcg@5:0.3475 ndcg@10:0.3783 ndcg@20:0.4081 ndcg@50:0.4386 158 | 2024-07-23 15:24:21 INFO ----------------------------------------------------Epoch 12---------------------------------------------------- 159 | 2024-07-23 15:24:21 INFO Training Time :[32.9 s] Training Loss = 7.3233 160 | 2024-07-23 15:24:22 INFO Evaluation 
Time:[0.9 s] Eval Loss = 8.4664 161 | 2024-07-23 15:24:22 INFO hit@5:0.4418 hit@10:0.5354 hit@20:0.6513 hit@50:0.7995 ndcg@5:0.3486 ndcg@10:0.3788 ndcg@20:0.4079 ndcg@50:0.4379 162 | 2024-07-23 15:24:55 INFO ----------------------------------------------------Epoch 13---------------------------------------------------- 163 | 2024-07-23 15:24:55 INFO Training Time :[32.9 s] Training Loss = 7.2029 164 | 2024-07-23 15:24:56 INFO Evaluation Time:[0.9 s] Eval Loss = 8.5498 165 | 2024-07-23 15:24:56 INFO hit@5:0.4423 hit@10:0.5357 hit@20:0.6529 hit@50:0.8065 ndcg@5:0.3504 ndcg@10:0.3806 ndcg@20:0.4102 ndcg@50:0.4412 166 | 2024-07-23 15:25:29 INFO ----------------------------------------------------Epoch 14---------------------------------------------------- 167 | 2024-07-23 15:25:29 INFO Training Time :[32.9 s] Training Loss = 7.0766 168 | 2024-07-23 15:25:30 INFO Evaluation Time:[0.9 s] Eval Loss = 8.6568 169 | 2024-07-23 15:25:30 INFO hit@5:0.4435 hit@10:0.5385 hit@20:0.6561 hit@50:0.8055 ndcg@5:0.3522 ndcg@10:0.3829 ndcg@20:0.4125 ndcg@50:0.4427 170 | 2024-07-23 15:26:03 INFO ----------------------------------------------------Epoch 15---------------------------------------------------- 171 | 2024-07-23 15:26:03 INFO Training Time :[32.9 s] Training Loss = 6.9914 172 | 2024-07-23 15:26:04 INFO Evaluation Time:[0.9 s] Eval Loss = 8.8304 173 | 2024-07-23 15:26:04 INFO hit@5:0.4402 hit@10:0.5366 hit@20:0.6507 hit@50:0.8021 ndcg@5:0.3506 ndcg@10:0.3818 ndcg@20:0.4106 ndcg@50:0.4412 174 | 2024-07-23 15:26:04 INFO EarlyStopping Counter: 1 out of 10 175 | 2024-07-23 15:26:36 INFO ----------------------------------------------------Epoch 16---------------------------------------------------- 176 | 2024-07-23 15:26:36 INFO Training Time :[32.9 s] Training Loss = 6.9066 177 | 2024-07-23 15:26:37 INFO Evaluation Time:[0.9 s] Eval Loss = 8.8853 178 | 2024-07-23 15:26:37 INFO hit@5:0.4405 hit@10:0.5362 hit@20:0.6516 hit@50:0.8025 ndcg@5:0.3508 ndcg@10:0.3817 ndcg@20:0.4108 
ndcg@50:0.4413 179 | 2024-07-23 15:26:37 INFO EarlyStopping Counter: 2 out of 10 180 | 2024-07-23 15:27:10 INFO ----------------------------------------------------Epoch 17---------------------------------------------------- 181 | 2024-07-23 15:27:10 INFO Training Time :[32.9 s] Training Loss = 6.8368 182 | 2024-07-23 15:27:11 INFO Evaluation Time:[0.9 s] Eval Loss = 8.9803 183 | 2024-07-23 15:27:11 INFO hit@5:0.4457 hit@10:0.5386 hit@20:0.6551 hit@50:0.8060 ndcg@5:0.3545 ndcg@10:0.3846 ndcg@20:0.4139 ndcg@50:0.4444 184 | 2024-07-23 15:27:44 INFO ----------------------------------------------------Epoch 18---------------------------------------------------- 185 | 2024-07-23 15:27:44 INFO Training Time :[32.8 s] Training Loss = 6.7691 186 | 2024-07-23 15:27:45 INFO Evaluation Time:[0.9 s] Eval Loss = 9.0 187 | 2024-07-23 15:27:45 INFO hit@5:0.4424 hit@10:0.5354 hit@20:0.6513 hit@50:0.8080 ndcg@5:0.3506 ndcg@10:0.3806 ndcg@20:0.4098 ndcg@50:0.4414 188 | 2024-07-23 15:27:45 INFO EarlyStopping Counter: 1 out of 10 189 | 2024-07-23 15:28:18 INFO ----------------------------------------------------Epoch 19---------------------------------------------------- 190 | 2024-07-23 15:28:18 INFO Training Time :[32.9 s] Training Loss = 6.7210 191 | 2024-07-23 15:28:19 INFO Evaluation Time:[0.9 s] Eval Loss = 9.1515 192 | 2024-07-23 15:28:19 INFO hit@5:0.4448 hit@10:0.5383 hit@20:0.6550 hit@50:0.8110 ndcg@5:0.3541 ndcg@10:0.3843 ndcg@20:0.4137 ndcg@50:0.4453 193 | 2024-07-23 15:28:19 INFO EarlyStopping Counter: 2 out of 10 194 | 2024-07-23 15:28:52 INFO ----------------------------------------------------Epoch 20---------------------------------------------------- 195 | 2024-07-23 15:28:52 INFO Training Time :[32.9 s] Training Loss = 6.6472 196 | 2024-07-23 15:28:52 INFO Evaluation Time:[0.9 s] Eval Loss = 9.3712 197 | 2024-07-23 15:28:52 INFO hit@5:0.4424 hit@10:0.5359 hit@20:0.6491 hit@50:0.8114 ndcg@5:0.3503 ndcg@10:0.3805 ndcg@20:0.4090 ndcg@50:0.4418 198 | 2024-07-23 15:28:52 
INFO EarlyStopping Counter: 3 out of 10 199 | 2024-07-23 15:29:25 INFO ----------------------------------------------------Epoch 21---------------------------------------------------- 200 | 2024-07-23 15:29:25 INFO Training Time :[32.8 s] Training Loss = 6.6233 201 | 2024-07-23 15:29:26 INFO Evaluation Time:[1.0 s] Eval Loss = 9.2789 202 | 2024-07-23 15:29:26 INFO hit@5:0.4415 hit@10:0.5368 hit@20:0.6493 hit@50:0.8083 ndcg@5:0.3498 ndcg@10:0.3806 ndcg@20:0.4089 ndcg@50:0.4411 203 | 2024-07-23 15:29:26 INFO EarlyStopping Counter: 4 out of 10 204 | 2024-07-23 15:29:59 INFO ----------------------------------------------------Epoch 22---------------------------------------------------- 205 | 2024-07-23 15:29:59 INFO Training Time :[32.9 s] Training Loss = 6.5730 206 | 2024-07-23 15:30:00 INFO Evaluation Time:[0.9 s] Eval Loss = 9.4025 207 | 2024-07-23 15:30:00 INFO hit@5:0.4408 hit@10:0.5347 hit@20:0.6493 hit@50:0.8103 ndcg@5:0.3508 ndcg@10:0.3810 ndcg@20:0.4099 ndcg@50:0.4424 208 | 2024-07-23 15:30:00 INFO EarlyStopping Counter: 5 out of 10 209 | 2024-07-23 15:30:33 INFO ----------------------------------------------------Epoch 23---------------------------------------------------- 210 | 2024-07-23 15:30:33 INFO Training Time :[33.0 s] Training Loss = 6.5330 211 | 2024-07-23 15:30:34 INFO Evaluation Time:[0.9 s] Eval Loss = 9.5505 212 | 2024-07-23 15:30:34 INFO hit@5:0.4380 hit@10:0.5326 hit@20:0.6483 hit@50:0.8110 ndcg@5:0.3489 ndcg@10:0.3794 ndcg@20:0.4085 ndcg@50:0.4413 213 | 2024-07-23 15:30:34 INFO EarlyStopping Counter: 6 out of 10 214 | 2024-07-23 15:31:07 INFO ----------------------------------------------------Epoch 24---------------------------------------------------- 215 | 2024-07-23 15:31:07 INFO Training Time :[32.9 s] Training Loss = 6.4867 216 | 2024-07-23 15:31:08 INFO Evaluation Time:[0.9 s] Eval Loss = 9.2692 217 | 2024-07-23 15:31:08 INFO hit@5:0.4389 hit@10:0.5329 hit@20:0.6490 hit@50:0.8044 ndcg@5:0.3481 ndcg@10:0.3784 ndcg@20:0.4076 
ndcg@50:0.4391 218 | 2024-07-23 15:31:08 INFO EarlyStopping Counter: 7 out of 10 219 | 2024-07-23 15:31:41 INFO ----------------------------------------------------Epoch 25---------------------------------------------------- 220 | 2024-07-23 15:31:41 INFO Training Time :[32.8 s] Training Loss = 6.4528 221 | 2024-07-23 15:31:42 INFO Evaluation Time:[1.0 s] Eval Loss = 9.3407 222 | 2024-07-23 15:31:42 INFO hit@5:0.4401 hit@10:0.5327 hit@20:0.6511 hit@50:0.8073 ndcg@5:0.3492 ndcg@10:0.3790 ndcg@20:0.4088 ndcg@50:0.4404 223 | 2024-07-23 15:31:42 INFO EarlyStopping Counter: 8 out of 10 224 | 2024-07-23 15:32:14 INFO ----------------------------------------------------Epoch 26---------------------------------------------------- 225 | 2024-07-23 15:32:14 INFO Training Time :[32.8 s] Training Loss = 6.4163 226 | 2024-07-23 15:32:15 INFO Evaluation Time:[0.9 s] Eval Loss = 9.3725 227 | 2024-07-23 15:32:15 INFO hit@5:0.4367 hit@10:0.5293 hit@20:0.6491 hit@50:0.8035 ndcg@5:0.3471 ndcg@10:0.3771 ndcg@20:0.4071 ndcg@50:0.4384 228 | 2024-07-23 15:32:15 INFO EarlyStopping Counter: 9 out of 10 229 | 2024-07-23 15:32:48 INFO ----------------------------------------------------Epoch 27---------------------------------------------------- 230 | 2024-07-23 15:32:48 INFO Training Time :[32.9 s] Training Loss = 6.3712 231 | 2024-07-23 15:32:49 INFO Evaluation Time:[1.0 s] Eval Loss = 9.3215 232 | 2024-07-23 15:32:49 INFO hit@5:0.4349 hit@10:0.5317 hit@20:0.6480 hit@50:0.8015 ndcg@5:0.3455 ndcg@10:0.3767 ndcg@20:0.4059 ndcg@50:0.4370 233 | 2024-07-23 15:32:49 INFO EarlyStopping Counter: 10 out of 10 234 | 2024-07-23 15:32:49 INFO ------------------------------------------------Best Evaluation------------------------------------------------ 235 | 2024-07-23 15:32:49 INFO Best Result at Epoch: 17 Early Stop at Patience: 10 236 | 2024-07-23 15:32:49 INFO hit@5:0.4457 hit@10:0.5386 hit@20:0.6551 hit@50:0.8060 ndcg@5:0.3545 ndcg@10:0.3846 ndcg@20:0.4139 ndcg@50:0.4444 237 | 2024-07-23 15:32:50 
INFO -----------------------------------------------------Test Results------------------------------------------------------ 238 | 2024-07-23 15:32:50 INFO hit@5:0.4030 hit@10:0.4966 hit@20:0.6120 hit@50:0.7719 ndcg@5:0.3163 ndcg@10:0.3464 ndcg@20:0.3755 ndcg@50:0.4080 239 | -------------------------------------------------------------------------------- /log/toys/IOCRec-seed2024-hidden256_1.log: -------------------------------------------------------------------------------- 1 | 2024-07-23 15:32:52 INFO log save at : log\toys\IOCRec-seed2024-hidden256_1.log 2 | 2024-07-23 15:32:52 INFO model save at: save\IOCRec-toys-seed2024-hidden256-2024-07-23_15-32-52.pth 3 | 2024-07-23 15:32:53 INFO [1] Model Hyper-Parameter --------------------- 4 | 2024-07-23 15:32:53 INFO model: IOCRec 5 | 2024-07-23 15:32:53 INFO model_type: SEQUENTIAL 6 | 2024-07-23 15:32:53 INFO aug_types: ['crop', 'mask', 'reorder'] 7 | 2024-07-23 15:32:53 INFO crop_ratio: 0.2 8 | 2024-07-23 15:32:53 INFO mask_ratio: 0.7 9 | 2024-07-23 15:32:53 INFO reorder_ratio: 0.2 10 | 2024-07-23 15:32:53 INFO all_hidden: True 11 | 2024-07-23 15:32:53 INFO tao: 1.0 12 | 2024-07-23 15:32:53 INFO lamda: 0.1 13 | 2024-07-23 15:32:53 INFO k_intention: 4 14 | 2024-07-23 15:32:53 INFO embed_size: 64 15 | 2024-07-23 15:32:53 INFO ffn_hidden: 256 16 | 2024-07-23 15:32:53 INFO num_blocks: 3 17 | 2024-07-23 15:32:53 INFO num_heads: 2 18 | 2024-07-23 15:32:53 INFO hidden_dropout: 0.5 19 | 2024-07-23 15:32:53 INFO attn_dropout: 0.5 20 | 2024-07-23 15:32:53 INFO layer_norm_eps: 1e-12 21 | 2024-07-23 15:32:53 INFO initializer_range: 0.02 22 | 2024-07-23 15:32:53 INFO loss_type: CE 23 | 2024-07-23 15:32:53 INFO [2] Experiment Hyper-Parameter ---------------- 24 | 2024-07-23 15:32:53 INFO [2-1] data hyper-parameter -------------------- 25 | 2024-07-23 15:32:53 INFO dataset: toys 26 | 2024-07-23 15:32:53 INFO data_aug: True 27 | 2024-07-23 15:32:53 INFO seq_filter_len: 0 28 | 2024-07-23 15:32:53 INFO if_filter_target: False 29 | 
2024-07-23 15:32:53 INFO max_len: 50 30 | 2024-07-23 15:32:53 INFO [2-2] pretraining hyper-parameter ------------- 31 | 2024-07-23 15:32:53 INFO do_pretraining: False 32 | 2024-07-23 15:32:53 INFO pretraining_task: MISP 33 | 2024-07-23 15:32:53 INFO pretraining_epoch: 10 34 | 2024-07-23 15:32:53 INFO pretraining_batch: 512 35 | 2024-07-23 15:32:53 INFO pretraining_lr: 0.001 36 | 2024-07-23 15:32:53 INFO pretraining_l2: 0.0 37 | 2024-07-23 15:32:53 INFO [2-3] training hyper-parameter ---------------- 38 | 2024-07-23 15:32:53 INFO epoch_num: 150 39 | 2024-07-23 15:32:53 INFO train_batch: 256 40 | 2024-07-23 15:32:53 INFO learning_rate: 0.001 41 | 2024-07-23 15:32:53 INFO l2: 0 42 | 2024-07-23 15:32:53 INFO patience: 10 43 | 2024-07-23 15:32:53 INFO device: cuda:0 44 | 2024-07-23 15:32:53 INFO num_worker: 0 45 | 2024-07-23 15:32:53 INFO seed: 2024 46 | 2024-07-23 15:32:53 INFO [2-4] evaluation hyper-parameter -------------- 47 | 2024-07-23 15:32:53 INFO split_type: valid_and_test 48 | 2024-07-23 15:32:53 INFO split_mode: LS 49 | 2024-07-23 15:32:53 INFO eval_mode: uni100 50 | 2024-07-23 15:32:53 INFO metric: ['hit', 'ndcg'] 51 | 2024-07-23 15:32:53 INFO k: [5, 10, 20, 50] 52 | 2024-07-23 15:32:53 INFO valid_metric: hit@10 53 | 2024-07-23 15:32:53 INFO eval_batch: 256 54 | 2024-07-23 15:32:53 INFO [2-5] save hyper-parameter -------------------- 55 | 2024-07-23 15:32:53 INFO log_save: log 56 | 2024-07-23 15:32:53 INFO save: save 57 | 2024-07-23 15:32:53 INFO model_saved: None 58 | 2024-07-23 15:32:53 INFO [3] Data Statistic ---------------------------- 59 | 2024-07-23 15:32:53 INFO dataset: toys 60 | 2024-07-23 15:32:53 INFO user number: 19412 61 | 2024-07-23 15:32:53 INFO item number: 11925 62 | 2024-07-23 15:32:53 INFO average seq length: 8.6337 63 | 2024-07-23 15:32:53 INFO density: 0.0007 sparsity: 0.9993 64 | 2024-07-23 15:32:53 INFO data after augmentation: 65 | 2024-07-23 15:32:53 INFO train samples: 109361 eval samples: 19412 test samples: 19412 66 | 2024-07-23 
15:32:53 INFO [1] Model Architecture ------------------------ 67 | 2024-07-23 15:32:53 INFO total parameters: 936448 68 | 2024-07-23 15:32:53 INFO IOCRec( 69 | (cross_entropy): CrossEntropyLoss() 70 | (item_embedding): Embedding(11927, 64, padding_idx=0) 71 | (position_embedding): Embedding(50, 64) 72 | (input_layer_norm): LayerNorm((64,), eps=1e-12, elementwise_affine=True) 73 | (input_dropout): Dropout(p=0.5, inplace=False) 74 | (local_encoder): Transformer( 75 | (encoder_layers): ModuleList( 76 | (0-2): 3 x EncoderLayer( 77 | (attn_layer_norm): LayerNorm((64,), eps=1e-12, elementwise_affine=True) 78 | (pff_layer_norm): LayerNorm((64,), eps=1e-12, elementwise_affine=True) 79 | (self_attention): MultiHeadAttentionLayer( 80 | (fc_q): Linear(in_features=64, out_features=64, bias=True) 81 | (fc_k): Linear(in_features=64, out_features=64, bias=True) 82 | (fc_v): Linear(in_features=64, out_features=64, bias=True) 83 | (attn_dropout): Dropout(p=0.5, inplace=False) 84 | (fc_o): Linear(in_features=64, out_features=64, bias=True) 85 | ) 86 | (pff): PointWiseFeedForwardLayer( 87 | (fc1): Linear(in_features=64, out_features=256, bias=True) 88 | (fc2): Linear(in_features=256, out_features=64, bias=True) 89 | ) 90 | (hidden_dropout): Dropout(p=0.5, inplace=False) 91 | (pff_out_drop): Dropout(p=0.5, inplace=False) 92 | ) 93 | ) 94 | ) 95 | (global_seq_encoder): GlobalSeqEncoder( 96 | (dropout): Dropout(p=0.5, inplace=False) 97 | (K_linear): Linear(in_features=64, out_features=64, bias=True) 98 | (V_linear): Linear(in_features=64, out_features=64, bias=True) 99 | ) 100 | (disentangle_encoder): DisentangleEncoder( 101 | (pos_fai): Embedding(50, 64) 102 | (W): Linear(in_features=64, out_features=64, bias=True) 103 | (layer_norm_1): LayerNorm((64,), eps=1e-05, elementwise_affine=True) 104 | (layer_norm_2): LayerNorm((64,), eps=1e-05, elementwise_affine=True) 105 | (layer_norm_3): LayerNorm((64,), eps=1e-05, elementwise_affine=True) 106 | (layer_norm_4): LayerNorm((64,), eps=1e-05, 
elementwise_affine=True) 107 | (layer_norm_5): LayerNorm((64,), eps=1e-05, elementwise_affine=True) 108 | ) 109 | (nce_loss): InfoNCELoss( 110 | (criterion): CrossEntropyLoss() 111 | ) 112 | ) 113 | 2024-07-23 15:32:53 INFO Start training... 114 | 2024-07-23 15:33:23 INFO ----------------------------------------------------Epoch 1---------------------------------------------------- 115 | 2024-07-23 15:33:23 INFO Training Time :[30.5 s] Training Loss = 11.2264 116 | 2024-07-23 15:33:24 INFO Evaluation Time:[0.9 s] Eval Loss = 9.386 117 | 2024-07-23 15:33:24 INFO hit@5:0.0798 hit@10:0.1371 hit@20:0.2449 hit@50:0.3658 ndcg@5:0.0499 ndcg@10:0.0682 ndcg@20:0.0951 ndcg@50:0.1199 118 | 2024-07-23 15:33:54 INFO ----------------------------------------------------Epoch 2---------------------------------------------------- 119 | 2024-07-23 15:33:54 INFO Training Time :[30.1 s] Training Loss = 9.9087 120 | 2024-07-23 15:33:55 INFO Evaluation Time:[0.9 s] Eval Loss = 9.3211 121 | 2024-07-23 15:33:55 INFO hit@5:0.1836 hit@10:0.2819 hit@20:0.4099 hit@50:0.4318 ndcg@5:0.1191 ndcg@10:0.1507 ndcg@20:0.1832 ndcg@50:0.1879 122 | 2024-07-23 15:34:25 INFO ----------------------------------------------------Epoch 3---------------------------------------------------- 123 | 2024-07-23 15:34:25 INFO Training Time :[30.1 s] Training Loss = 9.6230 124 | 2024-07-23 15:34:26 INFO Evaluation Time:[0.8 s] Eval Loss = 9.2058 125 | 2024-07-23 15:34:26 INFO hit@5:0.2080 hit@10:0.3090 hit@20:0.4284 hit@50:0.4485 ndcg@5:0.1420 ndcg@10:0.1745 ndcg@20:0.2048 ndcg@50:0.2092 126 | 2024-07-23 15:34:56 INFO ----------------------------------------------------Epoch 4---------------------------------------------------- 127 | 2024-07-23 15:34:56 INFO Training Time :[30.2 s] Training Loss = 9.3300 128 | 2024-07-23 15:34:57 INFO Evaluation Time:[0.9 s] Eval Loss = 9.095 129 | 2024-07-23 15:34:57 INFO hit@5:0.2436 hit@10:0.3391 hit@20:0.4570 hit@50:0.4859 ndcg@5:0.1707 ndcg@10:0.2014 ndcg@20:0.2313 
ndcg@50:0.2375 130 | 2024-07-23 15:35:27 INFO ----------------------------------------------------Epoch 5---------------------------------------------------- 131 | 2024-07-23 15:35:27 INFO Training Time :[30.0 s] Training Loss = 9.0510 132 | 2024-07-23 15:35:28 INFO Evaluation Time:[0.9 s] Eval Loss = 8.9176 133 | 2024-07-23 15:35:28 INFO hit@5:0.2988 hit@10:0.3964 hit@20:0.5138 hit@50:0.5666 ndcg@5:0.2144 ndcg@10:0.2460 ndcg@20:0.2756 ndcg@50:0.2867 134 | 2024-07-23 15:35:58 INFO ----------------------------------------------------Epoch 6---------------------------------------------------- 135 | 2024-07-23 15:35:58 INFO Training Time :[30.0 s] Training Loss = 8.7750 136 | 2024-07-23 15:35:59 INFO Evaluation Time:[0.9 s] Eval Loss = 8.7004 137 | 2024-07-23 15:35:59 INFO hit@5:0.3550 hit@10:0.4554 hit@20:0.5695 hit@50:0.6418 ndcg@5:0.2613 ndcg@10:0.2937 ndcg@20:0.3225 ndcg@50:0.3375 138 | 2024-07-23 15:36:29 INFO ----------------------------------------------------Epoch 7---------------------------------------------------- 139 | 2024-07-23 15:36:29 INFO Training Time :[30.0 s] Training Loss = 8.5188 140 | 2024-07-23 15:36:30 INFO Evaluation Time:[0.9 s] Eval Loss = 8.5374 141 | 2024-07-23 15:36:30 INFO hit@5:0.3887 hit@10:0.4897 hit@20:0.6031 hit@50:0.6945 ndcg@5:0.2926 ndcg@10:0.3252 ndcg@20:0.3539 ndcg@50:0.3727 142 | 2024-07-23 15:37:00 INFO ----------------------------------------------------Epoch 8---------------------------------------------------- 143 | 2024-07-23 15:37:00 INFO Training Time :[30.0 s] Training Loss = 8.2575 144 | 2024-07-23 15:37:00 INFO Evaluation Time:[0.9 s] Eval Loss = 8.4156 145 | 2024-07-23 15:37:00 INFO hit@5:0.4123 hit@10:0.5096 hit@20:0.6229 hit@50:0.7218 ndcg@5:0.3151 ndcg@10:0.3466 ndcg@20:0.3752 ndcg@50:0.3955 146 | 2024-07-23 15:37:30 INFO ----------------------------------------------------Epoch 9---------------------------------------------------- 147 | 2024-07-23 15:37:30 INFO Training Time :[30.1 s] Training Loss = 8.0445 148 
| 2024-07-23 15:37:31 INFO Evaluation Time:[0.8 s] Eval Loss = 8.3391 149 | 2024-07-23 15:37:31 INFO hit@5:0.4217 hit@10:0.5192 hit@20:0.6349 hit@50:0.7477 ndcg@5:0.3260 ndcg@10:0.3573 ndcg@20:0.3866 ndcg@50:0.4096 150 | 2024-07-23 15:38:01 INFO ----------------------------------------------------Epoch 10---------------------------------------------------- 151 | 2024-07-23 15:38:01 INFO Training Time :[30.1 s] Training Loss = 7.8524 152 | 2024-07-23 15:38:02 INFO Evaluation Time:[0.8 s] Eval Loss = 8.3403 153 | 2024-07-23 15:38:02 INFO hit@5:0.4295 hit@10:0.5239 hit@20:0.6416 hit@50:0.7662 ndcg@5:0.3334 ndcg@10:0.3639 ndcg@20:0.3935 ndcg@50:0.4189 154 | 2024-07-23 15:38:32 INFO ----------------------------------------------------Epoch 11---------------------------------------------------- 155 | 2024-07-23 15:38:32 INFO Training Time :[30.0 s] Training Loss = 7.6800 156 | 2024-07-23 15:38:33 INFO Evaluation Time:[0.9 s] Eval Loss = 8.339 157 | 2024-07-23 15:38:33 INFO hit@5:0.4319 hit@10:0.5293 hit@20:0.6466 hit@50:0.7620 ndcg@5:0.3390 ndcg@10:0.3704 ndcg@20:0.4000 ndcg@50:0.4236 158 | 2024-07-23 15:39:03 INFO ----------------------------------------------------Epoch 12---------------------------------------------------- 159 | 2024-07-23 15:39:03 INFO Training Time :[30.0 s] Training Loss = 7.5258 160 | 2024-07-23 15:39:04 INFO Evaluation Time:[0.9 s] Eval Loss = 8.4081 161 | 2024-07-23 15:39:04 INFO hit@5:0.4367 hit@10:0.5316 hit@20:0.6482 hit@50:0.7733 ndcg@5:0.3438 ndcg@10:0.3745 ndcg@20:0.4038 ndcg@50:0.4293 162 | 2024-07-23 15:39:34 INFO ----------------------------------------------------Epoch 13---------------------------------------------------- 163 | 2024-07-23 15:39:34 INFO Training Time :[30.0 s] Training Loss = 7.4114 164 | 2024-07-23 15:39:35 INFO Evaluation Time:[0.9 s] Eval Loss = 8.5115 165 | 2024-07-23 15:39:35 INFO hit@5:0.4343 hit@10:0.5312 hit@20:0.6486 hit@50:0.7707 ndcg@5:0.3443 ndcg@10:0.3755 ndcg@20:0.4050 ndcg@50:0.4300 166 | 2024-07-23 
15:39:35 INFO EarlyStopping Counter: 1 out of 10 167 | 2024-07-23 15:40:05 INFO ----------------------------------------------------Epoch 14---------------------------------------------------- 168 | 2024-07-23 15:40:05 INFO Training Time :[30.0 s] Training Loss = 7.2897 169 | 2024-07-23 15:40:06 INFO Evaluation Time:[0.8 s] Eval Loss = 8.5864 170 | 2024-07-23 15:40:06 INFO hit@5:0.4365 hit@10:0.5341 hit@20:0.6490 hit@50:0.7624 ndcg@5:0.3448 ndcg@10:0.3762 ndcg@20:0.4052 ndcg@50:0.4285 171 | 2024-07-23 15:40:36 INFO ----------------------------------------------------Epoch 15---------------------------------------------------- 172 | 2024-07-23 15:40:36 INFO Training Time :[30.0 s] Training Loss = 7.2119 173 | 2024-07-23 15:40:37 INFO Evaluation Time:[0.9 s] Eval Loss = 8.69 174 | 2024-07-23 15:40:37 INFO hit@5:0.4393 hit@10:0.5335 hit@20:0.6532 hit@50:0.7737 ndcg@5:0.3484 ndcg@10:0.3787 ndcg@20:0.4089 ndcg@50:0.4336 175 | 2024-07-23 15:40:37 INFO EarlyStopping Counter: 1 out of 10 176 | 2024-07-23 15:41:07 INFO ----------------------------------------------------Epoch 16---------------------------------------------------- 177 | 2024-07-23 15:41:07 INFO Training Time :[30.2 s] Training Loss = 7.1417 178 | 2024-07-23 15:41:08 INFO Evaluation Time:[0.9 s] Eval Loss = 8.6575 179 | 2024-07-23 15:41:08 INFO hit@5:0.4397 hit@10:0.5368 hit@20:0.6544 hit@50:0.7579 ndcg@5:0.3503 ndcg@10:0.3817 ndcg@20:0.4113 ndcg@50:0.4325 180 | 2024-07-23 15:41:38 INFO ----------------------------------------------------Epoch 17---------------------------------------------------- 181 | 2024-07-23 15:41:38 INFO Training Time :[30.1 s] Training Loss = 7.0573 182 | 2024-07-23 15:41:39 INFO Evaluation Time:[0.8 s] Eval Loss = 8.809 183 | 2024-07-23 15:41:39 INFO hit@5:0.4395 hit@10:0.5355 hit@20:0.6549 hit@50:0.7616 ndcg@5:0.3498 ndcg@10:0.3808 ndcg@20:0.4109 ndcg@50:0.4329 184 | 2024-07-23 15:41:39 INFO EarlyStopping Counter: 1 out of 10 185 | 2024-07-23 15:42:09 INFO 
----------------------------------------------------Epoch 18---------------------------------------------------- 186 | 2024-07-23 15:42:09 INFO Training Time :[30.0 s] Training Loss = 6.9430 187 | 2024-07-23 15:42:09 INFO Evaluation Time:[0.9 s] Eval Loss = 8.8837 188 | 2024-07-23 15:42:09 INFO hit@5:0.4430 hit@10:0.5368 hit@20:0.6553 hit@50:0.7630 ndcg@5:0.3530 ndcg@10:0.3832 ndcg@20:0.4131 ndcg@50:0.4353 189 | 2024-07-23 15:42:09 INFO EarlyStopping Counter: 2 out of 10 190 | 2024-07-23 15:42:39 INFO ----------------------------------------------------Epoch 19---------------------------------------------------- 191 | 2024-07-23 15:42:39 INFO Training Time :[30.0 s] Training Loss = 6.8674 192 | 2024-07-23 15:42:40 INFO Evaluation Time:[0.9 s] Eval Loss = 8.9556 193 | 2024-07-23 15:42:40 INFO hit@5:0.4411 hit@10:0.5384 hit@20:0.6570 hit@50:0.7684 ndcg@5:0.3518 ndcg@10:0.3831 ndcg@20:0.4131 ndcg@50:0.4360 194 | 2024-07-23 15:43:10 INFO ----------------------------------------------------Epoch 20---------------------------------------------------- 195 | 2024-07-23 15:43:10 INFO Training Time :[30.0 s] Training Loss = 6.8115 196 | 2024-07-23 15:43:11 INFO Evaluation Time:[0.9 s] Eval Loss = 9.0637 197 | 2024-07-23 15:43:11 INFO hit@5:0.4445 hit@10:0.5388 hit@20:0.6559 hit@50:0.7718 ndcg@5:0.3537 ndcg@10:0.3841 ndcg@20:0.4137 ndcg@50:0.4375 198 | 2024-07-23 15:43:41 INFO ----------------------------------------------------Epoch 21---------------------------------------------------- 199 | 2024-07-23 15:43:41 INFO Training Time :[30.0 s] Training Loss = 6.7497 200 | 2024-07-23 15:43:42 INFO Evaluation Time:[0.9 s] Eval Loss = 9.1482 201 | 2024-07-23 15:43:42 INFO hit@5:0.4407 hit@10:0.5364 hit@20:0.6575 hit@50:0.7716 ndcg@5:0.3513 ndcg@10:0.3822 ndcg@20:0.4127 ndcg@50:0.4363 202 | 2024-07-23 15:43:42 INFO EarlyStopping Counter: 1 out of 10 203 | 2024-07-23 15:44:12 INFO ----------------------------------------------------Epoch 
22---------------------------------------------------- 204 | 2024-07-23 15:44:12 INFO Training Time :[30.1 s] Training Loss = 6.7074 205 | 2024-07-23 15:44:13 INFO Evaluation Time:[0.9 s] Eval Loss = 9.1342 206 | 2024-07-23 15:44:13 INFO hit@5:0.4402 hit@10:0.5359 hit@20:0.6563 hit@50:0.7722 ndcg@5:0.3520 ndcg@10:0.3828 ndcg@20:0.4132 ndcg@50:0.4371 207 | 2024-07-23 15:44:13 INFO EarlyStopping Counter: 2 out of 10 208 | 2024-07-23 15:44:43 INFO ----------------------------------------------------Epoch 23---------------------------------------------------- 209 | 2024-07-23 15:44:43 INFO Training Time :[29.9 s] Training Loss = 6.9449 210 | 2024-07-23 15:44:44 INFO Evaluation Time:[0.9 s] Eval Loss = 9.2002 211 | 2024-07-23 15:44:44 INFO hit@5:0.4395 hit@10:0.5361 hit@20:0.6530 hit@50:0.7706 ndcg@5:0.3510 ndcg@10:0.3821 ndcg@20:0.4116 ndcg@50:0.4359 212 | 2024-07-23 15:44:44 INFO EarlyStopping Counter: 3 out of 10 213 | 2024-07-23 15:45:14 INFO ----------------------------------------------------Epoch 24---------------------------------------------------- 214 | 2024-07-23 15:45:14 INFO Training Time :[30.0 s] Training Loss = 6.6846 215 | 2024-07-23 15:45:15 INFO Evaluation Time:[0.9 s] Eval Loss = 9.3021 216 | 2024-07-23 15:45:15 INFO hit@5:0.4405 hit@10:0.5362 hit@20:0.6533 hit@50:0.7735 ndcg@5:0.3501 ndcg@10:0.3810 ndcg@20:0.4104 ndcg@50:0.4353 217 | 2024-07-23 15:45:15 INFO EarlyStopping Counter: 4 out of 10 218 | 2024-07-23 15:45:45 INFO ----------------------------------------------------Epoch 25---------------------------------------------------- 219 | 2024-07-23 15:45:45 INFO Training Time :[30.0 s] Training Loss = 6.5977 220 | 2024-07-23 15:45:46 INFO Evaluation Time:[0.9 s] Eval Loss = 9.1295 221 | 2024-07-23 15:45:46 INFO hit@5:0.4438 hit@10:0.5363 hit@20:0.6579 hit@50:0.7648 ndcg@5:0.3533 ndcg@10:0.3832 ndcg@20:0.4137 ndcg@50:0.4358 222 | 2024-07-23 15:45:46 INFO EarlyStopping Counter: 5 out of 10 223 | 2024-07-23 15:46:16 INFO 
----------------------------------------------------Epoch 26---------------------------------------------------- 224 | 2024-07-23 15:46:16 INFO Training Time :[30.0 s] Training Loss = 6.5443 225 | 2024-07-23 15:46:16 INFO Evaluation Time:[0.9 s] Eval Loss = 9.4042 226 | 2024-07-23 15:46:16 INFO hit@5:0.4404 hit@10:0.5320 hit@20:0.6549 hit@50:0.7668 ndcg@5:0.3508 ndcg@10:0.3804 ndcg@20:0.4113 ndcg@50:0.4344 227 | 2024-07-23 15:46:16 INFO EarlyStopping Counter: 6 out of 10 228 | 2024-07-23 15:46:46 INFO ----------------------------------------------------Epoch 27---------------------------------------------------- 229 | 2024-07-23 15:46:46 INFO Training Time :[30.0 s] Training Loss = 6.5445 230 | 2024-07-23 15:46:47 INFO Evaluation Time:[0.9 s] Eval Loss = 9.3367 231 | 2024-07-23 15:46:47 INFO hit@5:0.4423 hit@10:0.5359 hit@20:0.6557 hit@50:0.7659 ndcg@5:0.3524 ndcg@10:0.3826 ndcg@20:0.4127 ndcg@50:0.4355 232 | 2024-07-23 15:46:47 INFO EarlyStopping Counter: 7 out of 10 233 | 2024-07-23 15:47:17 INFO ----------------------------------------------------Epoch 28---------------------------------------------------- 234 | 2024-07-23 15:47:17 INFO Training Time :[30.0 s] Training Loss = 6.4637 235 | 2024-07-23 15:47:18 INFO Evaluation Time:[0.9 s] Eval Loss = 9.4813 236 | 2024-07-23 15:47:18 INFO hit@5:0.4383 hit@10:0.5310 hit@20:0.6539 hit@50:0.7706 ndcg@5:0.3502 ndcg@10:0.3802 ndcg@20:0.4110 ndcg@50:0.4351 237 | 2024-07-23 15:47:18 INFO EarlyStopping Counter: 8 out of 10 238 | 2024-07-23 15:47:48 INFO ----------------------------------------------------Epoch 29---------------------------------------------------- 239 | 2024-07-23 15:47:48 INFO Training Time :[29.9 s] Training Loss = 6.4458 240 | 2024-07-23 15:47:49 INFO Evaluation Time:[0.9 s] Eval Loss = 9.5087 241 | 2024-07-23 15:47:49 INFO hit@5:0.4374 hit@10:0.5339 hit@20:0.6544 hit@50:0.7711 ndcg@5:0.3489 ndcg@10:0.3801 ndcg@20:0.4104 ndcg@50:0.4345 242 | 2024-07-23 15:47:49 INFO EarlyStopping Counter: 9 out of 10 
"""Command-line entry point: parse hyper-parameters and launch training."""
import argparse


def build_parser():
    """Build the argument parser holding every model/training/evaluation option.

    Returns:
        argparse.ArgumentParser: parser whose defaults reproduce the reference
        IOCRec run on the ``toys`` dataset.

    List-valued options (``--aug_types``, ``--k``, ``--metric``) accept one or
    more space-separated values on the command line, e.g. ``--k 5 10``; without
    ``nargs`` a value given on the command line would arrive as a single string
    instead of a list like the default.
    """
    parser = argparse.ArgumentParser()

    # Model
    parser.add_argument('--model', default='IOCRec', type=str)
    parser.add_argument('--model_type', default='Sequential', choices=['Sequential', 'Knowledge'])
    # Contrast Learning Hyper Params
    parser.add_argument('--aug_types', default=['crop', 'mask', 'reorder'], nargs='+', type=str,
                        help='augmentation types')
    parser.add_argument('--crop_ratio', default=0.2, type=float,
                        help='Crop augmentation: proportion of cropped subsequence in origin sequence')
    parser.add_argument('--mask_ratio', default=0.7, type=float,
                        help='Mask augmentation: proportion of masked items in origin sequence')
    parser.add_argument('--reorder_ratio', default=0.2, type=float,
                        help='Reorder augmentation: proportion of reordered subsequence in origin sequence')
    # NOTE: store_false -> the option is True by default; passing the flag turns it OFF.
    parser.add_argument('--all_hidden', action='store_false', help='all hidden states for cl')
    parser.add_argument('--tao', default=1., type=float, help='temperature for softmax')
    parser.add_argument('--lamda', default=0.1, type=float,
                        help='weight for contrast learning loss, only work when jointly training')
    parser.add_argument('--k_intention', default=4, type=int, help='number of disentangled intention')
    # Transformer
    parser.add_argument('--embed_size', default=128, type=int)
    parser.add_argument('--ffn_hidden', default=512, type=int, help='hidden dim for feed forward network')
    parser.add_argument('--num_blocks', default=3, type=int, help='number of transformer block')
    parser.add_argument('--num_heads', default=2, type=int, help='number of head for multi-head attention')
    parser.add_argument('--hidden_dropout', default=0.5, type=float, help='hidden state dropout rate')
    parser.add_argument('--attn_dropout', default=0.5, type=float, help='dropout rate for attention')
    parser.add_argument('--layer_norm_eps', default=1e-12, type=float, help='transformer layer norm eps')
    parser.add_argument('--initializer_range', default=0.02, type=float, help='transformer params initialize range')
    # Data
    parser.add_argument('--dataset', default='toys', type=str)
    # Training
    parser.add_argument('--epoch_num', default=150, type=int)
    # NOTE: store_false -> augmentation is ON by default; passing the flag turns it OFF.
    parser.add_argument('--data_aug', action='store_false', help='data augmentation')
    parser.add_argument('--train_batch', default=256, type=int)
    parser.add_argument('--learning_rate', default=1e-3, type=float)
    parser.add_argument('--l2', default=0, type=float, help='l2 normalization')
    # type=int: without it a command-line value would be a str, unlike the int default.
    parser.add_argument('--patience', default=10, type=int, help='early stop patience')
    parser.add_argument('--seed', default=2024, type=int, help='random seed, -1 means no fixed seed')
    parser.add_argument('--mark', default='seed2024', type=str, help='log suffix mark')
    # Evaluation
    parser.add_argument('--split_type', default='valid_and_test', choices=['valid_only', 'valid_and_test'])
    parser.add_argument('--split_mode', default='LS', type=str, help='[LS (leave-one-out), LS_R@0.x, PS (pre-split)]')
    parser.add_argument('--eval_mode', default='full', help='[uni100, pop100, full]')
    parser.add_argument('--k', default=[5, 10, 20, 50], nargs='+', type=int, help='rank k for each metric')
    parser.add_argument('--metric', default=['hit', 'ndcg'], nargs='+', type=str, help='[hit, ndcg, mrr, recall]')
    parser.add_argument('--valid_metric', default='hit@10', help='specifies which indicator to apply early stop')

    return parser


if __name__ == '__main__':
    config = build_parser().parse_args()

    # Imported after argument parsing so that `--help` and argument errors are
    # reported without first importing the training stack.
    from src.train.trainer import load_trainer

    trainer = load_trainer(config)
    trainer.start_training()
https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /src/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/dataset/__init__.py -------------------------------------------------------------------------------- /src/dataset/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/dataset/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/dataset/__pycache__/data_processor.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/dataset/__pycache__/data_processor.cpython-38.pyc -------------------------------------------------------------------------------- /src/dataset/__pycache__/dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/dataset/__pycache__/dataset.cpython-38.pyc -------------------------------------------------------------------------------- /src/dataset/data_processor.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | import copy 3 | import logging 4 | # import dgl 5 | import numpy as np 6 | import torch 7 | import os 8 | from scipy import sparse 9 | from src.utils.utils import save_pickle, load_pickle 10 | 11 | 12 | class DataProcessor: 13 | def __init__(self, config): 14 | self.config = config 15 | 
self.model_name = config.model 16 | self.dataset = config.dataset 17 | self.data_aug = config.data_aug 18 | 19 | self.use_tar_seq = False 20 | self.tar_seq_len = 1 21 | if self.model_name in ['KERL', 'MELOD']: 22 | self.use_tar_seq = True 23 | self.tar_seq_len = config.episode_len 24 | self.filter_len = config.seq_filter_len 25 | self.filter_target = config.if_filter_target 26 | self.item_id_pad = 1 # increase item id: x -> x + 1 27 | self.sep = config.separator 28 | self.graph_type_list = [g_type.upper() for g_type in config.graph_type] 29 | self.valid_graphs = ['GGNN', 'BIPARTITE', 'TRANSITION', 'HYPER'] 30 | 31 | self.data_path = None 32 | self.train_data = None 33 | self.eval_data = None 34 | self.test_data = None 35 | self.kg_map = None 36 | self.do_data_split = True 37 | self.popularity = None 38 | 39 | self.split_type = config.split_type 40 | self.split_mode = config.split_mode 41 | self.eval_ratio = 0.2 # default 42 | 43 | # data statistic 44 | self.max_item = 0 45 | self.num_items = 0 # with padding item 0 46 | self.num_users = 0 47 | self.seq_avg_len = 0. 48 | self.total_seq_len = 0. 49 | self.density = 0. 50 | self.sparsity = 0. 
def prepare_data(self):
    """Load (or split) the sequence data and bundle it for the trainer.

    When ``self.do_data_split`` is set, the raw sequences are loaded and split
    here; otherwise the pre-split files are read.

    Returns:
        tuple: ``(data_dict, extra_data_dict)`` where ``data_dict`` has the keys
        ``'train'``, ``'eval'``, ``'test'`` and ``'raw_train'`` and
        ``extra_data_dict`` is whatever ``_prepare_additional_data`` returns.
    """
    if not self.do_data_split:
        # pre-split data: train/eval/test files already exist
        self._load_pre_split_data()
    else:
        self._train_test_split(self._load_row_data())

    # snapshot the splits before the model-specific extras are prepared
    splits = dict(
        train=self.train_data,
        eval=self.eval_data,
        test=self.test_data,
        raw_train=self.row_train_data,
    )
    extras = self._prepare_additional_data()
    return splits, extras
def calc_co_occurrence(self, kg_relation_dict):
    """Weight each item's KG-related items by how often they co-occur in training.

    The result is cached: if ``<data_path>/co_occurrence.pkl`` exists it is
    loaded and returned as-is, otherwise it is computed and saved there.

    Args:
        kg_relation_dict: maps an item id to ``{'s': [...], 'c': [...]}``, the
            lists of its substitute and complement item ids. These ids are
            shifted by +1 before being matched against the training sequences
            (presumably to line up with the +1 item-id shift applied when the
            sequences are read -- confirm against the KG file's id convention).

    Returns:
        dict: ``{item: {'s': [...], 'c': [...]}}`` where each list is aligned
        with the corresponding input list and holds add-one smoothed
        co-occurrence frequencies normalised to sum to 1 (empty list when the
        item has no related items of that kind).
    """
    dict_path = self.data_path + '/co_occurrence.pkl'
    if os.path.exists(dict_path):
        return load_pickle(dict_path)

    def pair_key(item_a, item_b):
        # order-independent key for an item pair: smaller id first
        first, second = (item_a, item_b) if item_a < item_b else (item_b, item_a)
        return f'{first}-{second}'

    # count co-occurrences over the un-augmented training sequences
    pair_counts = {}
    for item_seq, target in zip(*self.row_train_data):
        # every unordered pair inside the input sequence
        for i, item_1 in enumerate(item_seq):
            for item_2 in item_seq[i + 1:]:
                key = pair_key(item_1, item_2)
                pair_counts[key] = pair_counts.get(key, 0.) + 1.
        # every input item paired with the train target item
        # (fix: the pair used to be built from the whole `item_seq` list when
        # item >= target, producing keys such as '2-[5, 3]' that never matched)
        for item in item_seq:
            key = pair_key(item, target)
            pair_counts[key] = pair_counts.get(key, 0.) + 1.

    def relation_weights(item, related_items):
        # normalised, add-one smoothed co-occurrence frequency of each related item
        if len(related_items) == 0:
            return []
        shifted_item = item + 1
        smoothed = [1. + float(pair_counts.get(pair_key(shifted_item, related + 1), 0.))
                    for related in related_items]
        total = sum(smoothed)
        return [freq / total for freq in smoothed]

    # calc co-occurrence for items with kg relations
    co_occurrence_dict = {}
    for item, relations in kg_relation_dict.items():
        co_occurrence_dict[item] = {
            's': relation_weights(item, relations['s']),  # substitute items
            'c': relation_weights(item, relations['c']),  # complement items
        }

    save_pickle(co_occurrence_dict, dict_path)
    return co_occurrence_dict
int(entities2id[triple[0]]) 245 | r_ = int(relations2id[triple[1]]) 246 | t_ = int(entities2id[triple[2]]) 247 | 248 | triple_list.append([h_, r_, t_]) 249 | if r_ in relation_head: 250 | if h_ in relation_head[r_]: 251 | relation_head[r_][h_] += 1 252 | else: 253 | relation_head[r_][h_] = 1 254 | else: 255 | relation_head[r_] = {} 256 | relation_head[r_][h_] = 1 257 | 258 | if r_ in relation_tail: 259 | if t_ in relation_tail[r_]: 260 | relation_tail[r_][t_] += 1 261 | else: 262 | relation_tail[r_][t_] = 1 263 | else: 264 | relation_tail[r_] = {} 265 | relation_tail[r_][t_] = 1 266 | 267 | for r_ in relation_head: 268 | sum1, sum2 = 0, 0 269 | for head in relation_head[r_]: 270 | sum1 += 1 271 | sum2 += relation_head[r_][head] 272 | tph = sum2 / sum1 273 | relation_tph[r_] = tph 274 | 275 | for r_ in relation_tail: 276 | sum1, sum2 = 0, 0 277 | for tail in relation_tail[r_]: 278 | sum1 += 1 279 | sum2 += relation_tail[r_][tail] 280 | hpt = sum2 / sum1 281 | relation_hpt[r_] = hpt 282 | 283 | print("Complete load. 
def _set_data_path(self):
    """Point ``self.data_path`` at ``<project root>/dataset/<dataset name>``.

    The project root is taken to be three levels above this file
    (``<root>/src/dataset/data_processor.py``). The previous implementation
    split the path on a literal backslash, which only works on Windows; on
    POSIX the root collapsed to an empty string and the data path silently
    became relative to the current working directory.
    """
    cur_path = os.path.abspath(__file__)
    # strip the file name, then 'dataset', then 'src'
    root = os.path.dirname(os.path.dirname(os.path.dirname(cur_path)))
    self.data_path = os.path.join(root, f'dataset/{self.dataset}')
def _load_pre_split_data(self):
    """
    Load data that was split beforehand: xx.train, xx.eval and (optionally) xx.test.

    Sets ``row_train_data``, ``train_data``, ``eval_data`` and, when
    ``split_type`` is ``'valid_and_test'``, ``test_data``; finally refreshes the
    dataset statistics from all loaded sequences.
    """

    def read_split(file_name):
        # read one split file located in the dataset directory
        return self._read_seq_data(file_name)

    def to_inputs_and_targets(sequences):
        # last item is the target; sequences with a single item are dropped
        usable = [seq for seq in sequences if len(seq) > 1]
        return [seq[:-1] for seq in usable], [seq[-1] for seq in usable]

    def rejoin(inputs, targets):
        # rebuild full sequences from (input, target) pairs
        return [seq + [target] for seq, target in zip(inputs, targets)]

    train_path = os.path.join(self.data_path, f'{self.dataset}.train')
    eval_path = os.path.join(self.data_path, f'{self.dataset}.eval')
    train_sequences = read_split(train_path)
    eval_sequences = read_split(eval_path)

    train_x, train_y = to_inputs_and_targets(train_sequences)
    eval_x, eval_y = to_inputs_and_targets(eval_sequences)

    # keep an untouched copy before the training data is augmented
    self.row_train_data = copy.deepcopy(train_x), copy.deepcopy(train_y)
    self._data_augmentation(train_x, train_y)

    has_test_split = self.split_type == 'valid_and_test'
    test_sequences_full = []
    if has_test_split:
        test_path = os.path.join(self.data_path, f'{self.dataset}.test')
        test_x, test_y = to_inputs_and_targets(read_split(test_path))
        self.test_data = (test_x, test_y)
        test_sequences_full = rejoin(test_x, test_y)

    self.train_data = (train_x, train_y)
    self.eval_data = (eval_x, eval_y)

    # statistics are computed over train + eval (+ test) sequences
    every_sequence = rejoin(train_x, train_y)
    every_sequence.extend(rejoin(eval_x, eval_y))
    if has_test_split:
        every_sequence.extend(test_sequences_full)

    self._set_statistic(every_sequence)
def _leave_one_out_split(self, seq_data):
    """Split every sequence leave-one-out style into train and eval samples.

    The last item of a sequence is the eval target and the second-to-last item
    is the train target; each target is predicted from the items before it.
    Sequences too short to provide a sample for a split are skipped for that
    split only.

    Returns:
        tuple: ``(train_x, train_y, eval_x, eval_y)``.
    """
    train_x, train_y = [], []
    eval_x, eval_y = [], []
    for item_seq in seq_data:
        length = len(item_seq)
        if length > 2:
            train_x.append(item_seq[:-2])
            train_y.append(item_seq[-2])
        if length > 1:
            eval_x.append(item_seq[:-1])
            eval_y.append(item_seq[-1])
    return train_x, train_y, eval_x, eval_y
def prepare_specified_graph(self):
    """
    Build every graph requested in ``self.graph_type_list``.

    Each graph type is built by the method named
    ``prepare_<lowercased type>_graph``. The 'GGNN' type is skipped here
    because the session graph is constructed in the dataset.

    :return: dict mapping graph type to the built graph, or None when
             nothing had to be built
    :raises KeyError: if a requested type is not in ``self.valid_graphs``
    """
    requested = self.graph_type_list
    assert isinstance(requested, list), f'graph_type should be a list.'

    built = {}
    for name in requested:
        if name not in self.valid_graphs:
            raise KeyError(f'Invalid graph type:{self.graph_type_list}. Choose from {self.valid_graphs}')
        if name != 'GGNN':  # session graph will be constructed in dataset
            builder = getattr(self, f'prepare_{name.lower()}_graph')
            built[name] = builder()
    return built or None
def prepare_hyper_graph(self):
    """
    Build the incidence matrix of the session hyper-graph.

    Every raw training sequence is one hyper-edge (column) and every item
    it contains is a node (row); each occurrence contributes a weight of 1.

    Returns:
        incidence_matrix (sparse matrix): incidence matrix for session hyper-graph
    """
    # one (item, session id) pair per item occurrence, in reading order
    pairs = [(item, session_id)
             for session_id, session in enumerate(self.row_train_data[0])
             for item in session]
    rows = [item for item, _ in pairs]
    cols = [session_id for _, session_id in pairs]
    weights = np.ones(len(pairs), dtype=float)
    return sparse.coo_matrix((weights, (rows, cols)))
x: x != target, item_seq[:-1])) 524 | item_seq.append(target) 525 | return item_seq 526 | 527 | def _split_by_ratio(self, test_x, test_y): 528 | """ 529 | random split by specified ratio 530 | """ 531 | eval_size = int(len(test_y) * self.eval_ratio) 532 | index = np.arange(len(test_y)) 533 | np.random.shuffle(index) 534 | 535 | eval_index = index[:eval_size] 536 | test_index = index[eval_size:] 537 | 538 | eval_x = [test_x[i] for i in eval_index] 539 | eval_y = [test_y[i] for i in eval_index] 540 | 541 | test_x = [test_x[i] for i in test_index] 542 | test_y = [test_y[i] for i in test_index] 543 | 544 | return eval_x, eval_y, test_x, test_y 545 | 546 | def _data_augmentation(self, train_x, train_y): 547 | if not self.data_aug: 548 | return 549 | if not self.use_tar_seq: 550 | aug_train_x = [item_seq[:last] for item_seq in train_x for last in range(1, len(item_seq))] 551 | aug_train_y = [item_seq[nextI] for item_seq in train_x for nextI in range(1, len(item_seq))] 552 | else: 553 | aug_train_x = [item_seq[:last] for item_seq in train_x for last in range(1, len(item_seq))] 554 | aug_train_y = [item_seq[nextI: nextI + self.tar_seq_len] for item_seq in train_x for nextI in 555 | range(1, len(item_seq))] 556 | train_x.extend(aug_train_x) 557 | train_y.extend(aug_train_y) 558 | -------------------------------------------------------------------------------- /src/dataset/dataset.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import random 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import Dataset, DataLoader, default_collate 7 | from src.utils.utils import neg_sample 8 | from src.model.data_augmentation import Crop, Mask, Reorder 9 | from src.model.data_augmentation import AUGMENTATIONS 10 | 11 | 12 | def load_specified_dataset(model_name, config): 13 | if model_name in ['CL4SRec', 'ICLRec', 'IOCRec']: 14 | return CL4SRecDataset 15 | return SequentialDataset 16 | 17 | 18 | class 
BaseSequentialDataset(Dataset): 19 | def __init__(self, config, data_pair, additional_data_dict=None, train=True): 20 | super(BaseSequentialDataset, self).__init__() 21 | self.batch_batch_dict = {} 22 | self.num_items = config.num_items 23 | self.config = config 24 | self.train = train 25 | self.dataset = config.dataset 26 | self.max_len = config.max_len 27 | self.item_seq = data_pair[0] 28 | self.label = data_pair[1] 29 | 30 | def get_SRtask_input(self, idx): 31 | item_seq = self.item_seq[idx] 32 | target = self.label[idx] 33 | 34 | seq_len = len(item_seq) if len(item_seq) < self.max_len else self.max_len 35 | item_seq = item_seq[-self.max_len:] 36 | item_seq = item_seq + (self.max_len - seq_len) * [0] 37 | 38 | assert len(item_seq) == self.max_len 39 | 40 | return (torch.tensor(item_seq, dtype=torch.long), 41 | torch.tensor(seq_len, dtype=torch.long), 42 | torch.tensor(target, dtype=torch.long)) 43 | 44 | def __getitem__(self, idx): 45 | return self.get_SRtask_input(idx) 46 | 47 | def __len__(self): 48 | return len(self.item_seq) 49 | 50 | def collate_fn(self, x): 51 | return self.basic_SR_collate_fn(x) 52 | 53 | def basic_SR_collate_fn(self, x): 54 | """ 55 | x: [(seq_1, len_1, tar_1), ..., (seq_n, len_n, tar_n)] 56 | """ 57 | item_seq, seq_len, target = default_collate(x) 58 | self.batch_batch_dict['item_seq'] = item_seq 59 | self.batch_batch_dict['seq_len'] = seq_len 60 | self.batch_batch_dict['target'] = target 61 | return self.batch_batch_dict 62 | 63 | 64 | class SequentialDataset(BaseSequentialDataset): 65 | def __init__(self, config, data_pair, additional_data_dict=None, train=True): 66 | super(SequentialDataset, self).__init__(config, data_pair, additional_data_dict, train) 67 | 68 | 69 | class CL4SRecDataset(BaseSequentialDataset): 70 | def __init__(self, config, data_pair, additional_data_dict=None, train=True): 71 | super(CL4SRecDataset, self).__init__(config, data_pair, additional_data_dict, train) 72 | self.mask_id = self.num_items 73 | 
self.aug_types = config.aug_types 74 | self.n_views = 2 75 | self.augmentations = [] 76 | 77 | self.load_augmentor() 78 | 79 | def load_augmentor(self): 80 | for aug in self.aug_types: 81 | if aug == 'mask': 82 | self.augmentations.append(Mask(gamma=self.config.mask_ratio, mask_id=self.mask_id)) 83 | else: 84 | self.augmentations.append(AUGMENTATIONS[aug](getattr(self.config, f'{aug}_ratio'))) 85 | 86 | def __getitem__(self, index): 87 | # for eval and test 88 | if not self.train: 89 | return self.get_SRtask_input(index) 90 | 91 | # for training 92 | # contrast learning augmented views 93 | item_seq = self.item_seq[index] 94 | target = self.label[index] 95 | aug_type = np.random.choice([i for i in range(len(self.augmentations))], 96 | size=self.n_views, replace=True) 97 | aug_seq_1 = self.augmentations[aug_type[0]](item_seq) 98 | aug_seq_2 = self.augmentations[aug_type[1]](item_seq) 99 | 100 | aug_seq_1 = aug_seq_1[-self.max_len:] 101 | aug_seq_2 = aug_seq_2[-self.max_len:] 102 | 103 | aug_len_1 = len(aug_seq_1) 104 | aug_len_2 = len(aug_seq_2) 105 | 106 | aug_seq_1 = aug_seq_1 + [0] * (self.max_len - len(aug_seq_1)) 107 | aug_seq_2 = aug_seq_2 + [0] * (self.max_len - len(aug_seq_2)) 108 | assert len(aug_seq_1) == self.max_len 109 | assert len(aug_seq_2) == self.max_len 110 | 111 | # recommendation sequences 112 | seq_len = len(item_seq) if len(item_seq) < self.max_len else self.max_len 113 | item_seq = item_seq[-self.max_len:] 114 | item_seq = item_seq + (self.max_len - seq_len) * [0] 115 | 116 | assert len(item_seq) == self.max_len 117 | 118 | cur_tensors = (torch.tensor(item_seq, dtype=torch.long), 119 | torch.tensor(seq_len, dtype=torch.long), 120 | torch.tensor(target, dtype=torch.long), 121 | torch.tensor(aug_seq_1, dtype=torch.long), 122 | torch.tensor(aug_seq_2, dtype=torch.long), 123 | torch.tensor(aug_len_1, dtype=torch.long), 124 | torch.tensor(aug_len_2, dtype=torch.long)) 125 | 126 | return cur_tensors 127 | 128 | def collate_fn(self, x): 129 | if not 
self.train: 130 | return self.basic_SR_collate_fn(x) 131 | 132 | item_seq, seq_len, target, aug_seq_1, aug_seq_2, aug_len_1, aug_len_2 = default_collate(x) 133 | 134 | self.batch_batch_dict['item_seq'] = item_seq 135 | self.batch_batch_dict['seq_len'] = seq_len 136 | self.batch_batch_dict['target'] = target 137 | self.batch_batch_dict['aug_seq_1'] = aug_seq_1 138 | self.batch_batch_dict['aug_seq_2'] = aug_seq_2 139 | self.batch_batch_dict['aug_len_1'] = aug_len_1 140 | self.batch_batch_dict['aug_len_2'] = aug_len_2 141 | 142 | return self.batch_batch_dict 143 | 144 | 145 | class MISPPretrainDataset(Dataset): 146 | """ 147 | Masked Item & Segment Prediction (MISP) 148 | """ 149 | 150 | def __init__(self, config, data_pair, additional_data_dict=None): 151 | self.mask_id = config.num_items 152 | self.mask_ratio = config.mask_ratio 153 | self.num_items = config.num_items + 1 154 | self.config = config 155 | self.item_seq = data_pair[0] 156 | self.label = data_pair[1] 157 | self.max_len = config.max_len 158 | self.long_sequence = [] 159 | 160 | for seq in self.item_seq: 161 | self.long_sequence.extend(seq) 162 | 163 | def __len__(self): 164 | return len(self.item_seq) 165 | 166 | def __getitem__(self, index): 167 | sequence = self.item_seq[index] # pos_items 168 | 169 | # Masked Item Prediction 170 | masked_item_sequence = [] 171 | neg_items = [] 172 | pos_items = sequence 173 | 174 | item_set = set(sequence) 175 | for item in sequence[:-1]: 176 | prob = random.random() 177 | if prob < self.mask_ratio: 178 | masked_item_sequence.append(self.mask_id) 179 | neg_items.append(neg_sample(item_set, self.num_items)) 180 | else: 181 | masked_item_sequence.append(item) 182 | neg_items.append(item) 183 | # add mask at the last position 184 | masked_item_sequence.append(self.mask_id) 185 | neg_items.append(neg_sample(item_set, self.num_items)) 186 | 187 | assert len(masked_item_sequence) == len(sequence) 188 | assert len(pos_items) == len(sequence) 189 | assert len(neg_items) == 
len(sequence) 190 | 191 | # Segment Prediction 192 | if len(sequence) < 2: 193 | masked_segment_sequence = sequence 194 | pos_segment = sequence 195 | neg_segment = sequence 196 | else: 197 | sample_length = random.randint(1, len(sequence) // 2) 198 | start_id = random.randint(0, len(sequence) - sample_length) 199 | neg_start_id = random.randint(0, len(self.long_sequence) - sample_length) 200 | pos_segment = sequence[start_id: start_id + sample_length] 201 | neg_segment = self.long_sequence[neg_start_id:neg_start_id + sample_length] 202 | masked_segment_sequence = sequence[:start_id] + [self.mask_id] * sample_length + sequence[ 203 | start_id + sample_length:] 204 | pos_segment = [self.mask_id] * start_id + pos_segment + [self.mask_id] * ( 205 | len(sequence) - (start_id + sample_length)) 206 | neg_segment = [self.mask_id] * start_id + neg_segment + [self.mask_id] * ( 207 | len(sequence) - (start_id + sample_length)) 208 | 209 | assert len(masked_segment_sequence) == len(sequence) 210 | assert len(pos_segment) == len(sequence) 211 | assert len(neg_segment) == len(sequence) 212 | 213 | # crop sequence 214 | masked_item_sequence = masked_item_sequence[-self.max_len:] 215 | pos_items = pos_items[-self.max_len:] 216 | neg_items = neg_items[-self.max_len:] 217 | masked_segment_sequence = masked_segment_sequence[-self.max_len:] 218 | pos_segment = pos_segment[-self.max_len:] 219 | neg_segment = neg_segment[-self.max_len:] 220 | 221 | # padding sequence 222 | pad_len = self.max_len - len(sequence) 223 | masked_item_sequence = masked_item_sequence + [0] * pad_len 224 | pos_items = pos_items + [0] * pad_len 225 | neg_items = neg_items + [0] * pad_len 226 | masked_segment_sequence = masked_segment_sequence + [0] * pad_len 227 | pos_segment = pos_segment + [0] * pad_len 228 | neg_segment = neg_segment + [0] * pad_len 229 | 230 | assert len(masked_item_sequence) == self.max_len 231 | assert len(pos_items) == self.max_len 232 | assert len(neg_items) == self.max_len 233 | assert 
len(masked_segment_sequence) == self.max_len 234 | assert len(pos_segment) == self.max_len 235 | assert len(neg_segment) == self.max_len 236 | 237 | cur_tensors = (torch.tensor(masked_item_sequence, dtype=torch.long), 238 | torch.tensor(pos_items, dtype=torch.long), 239 | torch.tensor(neg_items, dtype=torch.long), 240 | torch.tensor(masked_segment_sequence, dtype=torch.long), 241 | torch.tensor(pos_segment, dtype=torch.long), 242 | torch.tensor(neg_segment, dtype=torch.long)) 243 | return cur_tensors 244 | 245 | def collate_fn(self, x): 246 | tensor_dict = {} 247 | tensor_list = [torch.cat([x[i][j].unsqueeze(0) for i in range(len(x))], 0).long() for j in range(len(x[0]))] 248 | masked_item_sequence, pos_items, neg_items, \ 249 | masked_segment_sequence, pos_segment, neg_segment = tensor_list 250 | 251 | tensor_dict['masked_item_sequence'] = masked_item_sequence 252 | tensor_dict['pos_items'] = pos_items 253 | tensor_dict['neg_items'] = neg_items 254 | tensor_dict['masked_segment_sequence'] = masked_segment_sequence 255 | tensor_dict['pos_segment'] = pos_segment 256 | tensor_dict['neg_segment'] = neg_segment 257 | 258 | return tensor_dict 259 | 260 | 261 | class MIMPretrainDataset(Dataset): 262 | def __init__(self, config, data_pair, additional_data_dict=None): 263 | self.config = config 264 | self.aug_types = config.aug_types 265 | self.mask_id = config.num_items 266 | 267 | self.item_seq = data_pair[0] 268 | self.label = data_pair[1] 269 | self.max_len = config.max_len 270 | self.n_views = 2 271 | self.augmentations = [] 272 | self.load_augmentor() 273 | 274 | def load_augmentor(self): 275 | for aug in self.aug_types: 276 | if aug == 'mask': 277 | self.augmentations.append(Mask(gamma=self.config.mask_ratio, mask_id=self.mask_id)) 278 | else: 279 | self.augmentations.append(AUGMENTATIONS[aug](getattr(self.config, f'{aug}_ratio'))) 280 | 281 | def __getitem__(self, index): 282 | aug_type = np.random.choice([i for i in range(len(self.augmentations))], 283 | 
size=self.n_views, replace=False) 284 | item_seq = self.item_seq[index] 285 | aug_seq_1 = self.augmentations[aug_type[0]](item_seq) 286 | aug_seq_2 = self.augmentations[aug_type[1]](item_seq) 287 | 288 | aug_seq_1 = aug_seq_1[-self.max_len:] 289 | aug_seq_2 = aug_seq_2[-self.max_len:] 290 | 291 | aug_len_1 = len(aug_seq_1) 292 | aug_len_2 = len(aug_seq_2) 293 | 294 | aug_seq_1 = aug_seq_1 + [0] * (self.max_len - len(aug_seq_1)) 295 | aug_seq_2 = aug_seq_2 + [0] * (self.max_len - len(aug_seq_2)) 296 | assert len(aug_seq_1) == self.max_len 297 | assert len(aug_seq_2) == self.max_len 298 | 299 | aug_seq_tensors = (torch.tensor(aug_seq_1, dtype=torch.long), 300 | torch.tensor(aug_seq_2, dtype=torch.long), 301 | torch.tensor(aug_len_1, dtype=torch.long), 302 | torch.tensor(aug_len_2, dtype=torch.long)) 303 | 304 | return aug_seq_tensors 305 | 306 | def __len__(self): 307 | ''' 308 | consider n_view of a single sequence as one sample 309 | ''' 310 | return len(self.item_seq) 311 | 312 | def collate_fn(self, x): 313 | tensor_dict = {} 314 | tensor_list = [torch.cat([x[i][j].unsqueeze(0) for i in range(len(x))], 0).long() for j in range(len(x[0]))] 315 | aug_seq_1, aug_seq_2, aug_len_1, aug_len_2 = tensor_list 316 | 317 | tensor_dict['aug_seq_1'] = aug_seq_1 318 | tensor_dict['aug_seq_2'] = aug_seq_2 319 | tensor_dict['aug_len_1'] = aug_len_1 320 | tensor_dict['aug_len_2'] = aug_len_2 321 | 322 | return tensor_dict 323 | 324 | 325 | class PIDPretrainDataset(Dataset): 326 | def __init__(self, config, data_pair, additional_data_dict=None): 327 | self.num_items = config.num_items 328 | self.item_seq = data_pair[0] 329 | self.label = data_pair[1] 330 | self.config = config 331 | self.max_len = config.max_len 332 | self.pseudo_ratio = config.pseudo_ratio 333 | 334 | def __getitem__(self, index): 335 | item_seq = self.item_seq[index] 336 | pseudo_seq = [] 337 | target = [] 338 | 339 | for item in item_seq: 340 | if random.random() < self.pseudo_ratio: 341 | pseudo_item = 
def collate_fn(self, x):
    """
    Stack (pseudo_seq, target) samples into a batch dict.

    :param x: list of samples, each a pair of tensors as produced by __getitem__
    :return: dict with keys 'pseudo_seq' and 'target', both cast to long
    """
    num_fields = len(x[0])
    # NOTE(review): .long() is applied to every field, so the float labels
    # built in __getitem__ also become int64 here -- confirm this is intended.
    stacked = [torch.stack([sample[field] for sample in x], 0).long()
               for field in range(num_fields)]
    pseudo_seq, target = stacked
    return {'pseudo_seq': pseudo_seq, 'target': target}
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/evaluation/__pycache__/estimator.cpython-38.pyc -------------------------------------------------------------------------------- /src/evaluation/__pycache__/metrics.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/evaluation/__pycache__/metrics.cpython-38.pyc -------------------------------------------------------------------------------- /src/evaluation/estimator.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | from tqdm import tqdm 4 | from src.evaluation.metrics import Metric 5 | from src.utils.utils import batch_to_device 6 | 7 | 8 | class Estimator: 9 | def __init__(self, config): 10 | self.popularity = None 11 | self.config = config 12 | self.metrics = config.metric 13 | self.k_list = config.k 14 | self.dev = config.device 15 | self.metric_res_dict = {} 16 | self.eval_loss = 0. 17 | self.max_k = max(self.k_list) 18 | self.split_type = config.split_type 19 | self.eval_mode = config.eval_mode 20 | self.neg_size = 0 21 | if self.eval_mode != 'full': 22 | self.neg_size = int(re.findall(r'\d+', self.eval_mode)[0]) 23 | self.eval_mode = self.eval_mode[:3] 24 | self._reset_metrics() 25 | 26 | def _reset_metrics(self): 27 | for metric in self.metrics: 28 | for k in self.k_list: 29 | self.metric_res_dict[f'{metric}@{k}'] = 0. 30 | self.eval_loss = 0. 
@torch.no_grad()
def evaluate(self, eval_loader, model):
    """
    Run the model over ``eval_loader`` and collect metrics and loss.

    Metric scores are accumulated per batch and then divided by the number
    of samples in ``eval_loader.dataset``; the loss is the mean of the
    per-batch losses.

    :param eval_loader: data loader yielding batch dicts
    :param model: recommender, called on the batch dict and providing ``get_loss``
    :return: (metric result dict, mean evaluation loss)
    """
    model.eval()
    self._reset_metrics()

    num_samples = float(len(eval_loader.dataset))
    num_batches = len(eval_loader)

    progress = tqdm(enumerate(eval_loader), total=num_batches)
    progress.set_description(f'do evaluation...')
    for _, batch_dict in progress:
        batch_to_device(batch_dict, self.dev)
        logits = model(batch_dict)
        # the loss is computed on the full logits, before candidates are filtered
        batch_loss = model.get_loss(batch_dict, logits)
        candidate_logits = self.neg_sample_select(batch_dict, logits)
        self.calc_metrics(candidate_logits, batch_dict['target'])
        self.eval_loss += batch_loss.item()

    # turn the accumulated sums into per-sample averages
    for metric in self.metrics:
        for k in self.k_list:
            self.metric_res_dict[f'{metric}@{k}'] /= num_samples

    return self.metric_res_dict, self.eval_loss / float(len(eval_loader))
83 | 84 | for metric in self.metrics: 85 | for k in self.k_list: 86 | score = getattr(Metric, f'{metric.upper()}')(top_k_item_sorted, target, k) 87 | self.metric_res_dict[f'{metric}@{k}'] += score 88 | 89 | def calc_metrics_(self, prediction, target): 90 | _, topk_index = torch.topk(prediction, self.max_k, -1) # [batch, max_k] 91 | topk_socre = torch.gather(prediction, index=topk_index, dim=-1) 92 | idx_sorted = torch.argsort(topk_socre, dim=-1, descending=True) 93 | max_k_item_sorted = torch.gather(topk_index, index=idx_sorted, dim=-1) 94 | 95 | metric_res_dict = {} 96 | for metric in self.metrics: 97 | for k in self.k_list: 98 | score = getattr(Metric, f'{metric.upper()}')(max_k_item_sorted, target, k) 99 | metric_res_dict[f'{metric}@{k}'] += score 100 | 101 | return metric_res_dict 102 | 103 | def neg_sample_select(self, data_dict, prediction): 104 | if self.eval_mode == 'full': 105 | return prediction 106 | item_seq, target = data_dict['item_seq'], data_dict['target'] 107 | # sample negative items 108 | target = target.unsqueeze(-1) 109 | mask_item = torch.cat([item_seq, target], dim=-1) # [batch, max_len + 1] 110 | 111 | if self.eval_mode == 'uni': 112 | sample_prob = torch.ones_like(prediction, device=self.dev) / prediction.size(-1) 113 | elif self.eval_mode == 'pop': 114 | if self.popularity.size(0) != prediction.size(-1): # ignore mask item 115 | self.popularity = torch.cat([self.popularity, torch.zeros((1,)).to(self.dev)], -1) 116 | sample_prob = self.popularity.unsqueeze(0).repeat(prediction.size(0), 1) 117 | else: 118 | raise NotImplementedError('Choose eval_model from [full, popxxx, unixxx]') 119 | sample_prob = sample_prob.scatter(dim=-1, index=mask_item, value=0.) 
120 | neg_item = torch.multinomial(sample_prob, self.neg_size) # [batch, neg_size] 121 | # mask non-rank items 122 | rank_item = torch.cat([neg_item, target], dim=-1) # [batch, neg_size + 1] 123 | mask = torch.ones_like(prediction, device=self.dev).bool() 124 | mask = mask.scatter(dim=-1, index=rank_item, value=False) 125 | masked_pred = torch.masked_fill(prediction, mask, 0.) 126 | 127 | return masked_pred 128 | 129 | -------------------------------------------------------------------------------- /src/evaluation/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class Metric: 6 | @staticmethod 7 | def HIT(prediction, target, k): 8 | """ 9 | calculate Hit-Ratio (HR) @ k 10 | :param prediction: [batch, max_k], sorted along dim -1 11 | :param target: [batch] 12 | :param k: scalar 13 | :return: average hit-ratio score among all input data 14 | """ 15 | prediction, target = Metric.process(prediction, target, k) 16 | hit = ((prediction - target) == 0).sum(dim=-1).double() 17 | hit = hit.sum().item() 18 | return hit 19 | 20 | @staticmethod 21 | def NDCG(prediction, target, k): 22 | """ 23 | calculate Normalized Discounted Cumulative Gain (NDCG) @ k. 24 | Note that the Ideal Discounted Cumulative Gain (IDCG) is equal to all users, so it can be ignored. 25 | :param prediction: [batch, max_k], sorted along dim -1 26 | :param target: [batch] 27 | :param k: scalar 28 | :return: average hit-ratio score among all input data 29 | """ 30 | prediction, target = Metric.process(prediction, target, k) 31 | hit = ((prediction - target) == 0).sum(dim=-1).double() # [batch_size] 32 | row, col = ((prediction - target) == 0.).nonzero(as_tuple=True) # [hit_size] 33 | ndcg = hit.scatter(index=row, src=1. 
/ torch.log2(col + 2).double(), dim=-1) 34 | ndcg = ndcg.sum().item() 35 | return ndcg 36 | 37 | @staticmethod 38 | def MRR(prediction, target, k): 39 | """ 40 | calculate Mean Reciprocal Rank (MRR) @ k 41 | :param prediction: [batch, max_k], sorted along dim -1 42 | :param target: [batch] 43 | :param k: scalar 44 | :return: average hit-ratio score among all input data 45 | """ 46 | prediction, target = Metric.process(prediction, target, k) 47 | hit = ((prediction - target) == 0).sum(dim=-1).double() # [batch_size] 48 | row, col = ((prediction - target) == 0.).nonzero(as_tuple=True) # [hit_size] 49 | mrr = hit.scatter(index=row, src=1. / (col + 1).double(), dim=-1) 50 | mrr = mrr.sum().item() 51 | return mrr 52 | 53 | @staticmethod 54 | def RECALL(prediction, target, k): 55 | """ 56 | calculate recall @ k, similar to hit-ration under SR (Sequential recommendation) setting 57 | :param prediction: [batch, max_k], sorted along dim -1 58 | :param target: [batch] 59 | :param k: scalar 60 | :return: average hit-ratio score among all input data 61 | """ 62 | return Metric.HIT(prediction, target, k) 63 | 64 | @staticmethod 65 | def process(prediction, target, k): 66 | if k < prediction.size(-1): 67 | prediction = prediction[:, :k] # [batch, k] 68 | target = target.unsqueeze(-1) # [batch, 1] 69 | return prediction, target 70 | 71 | 72 | if __name__ == '__main__': 73 | a = torch.arange(12).view(3, -1) 74 | a[1, -1] = 0 75 | print(a) 76 | hit = (a == 0).sum(dim=-1).float() 77 | hit_index, rank = (a == 0).nonzero(as_tuple=True) 78 | print(hit_index, rank) 79 | score = torch.scatter(hit, index=hit_index, src=1. 
/ torch.log2(rank + 2), dim=-1) 80 | print(score) 81 | score = score.mean().cpu().numpy() 82 | print(score) -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- 1 | from src.model.sequential_recommender import * 2 | from src.model.cl_based_seq_recommender import * 3 | -------------------------------------------------------------------------------- /src/model/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/model/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/model/__pycache__/abstract_recommeder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/model/__pycache__/abstract_recommeder.cpython-38.pyc -------------------------------------------------------------------------------- /src/model/__pycache__/data_augmentation.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/model/__pycache__/data_augmentation.cpython-38.pyc -------------------------------------------------------------------------------- /src/model/__pycache__/loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/model/__pycache__/loss.cpython-38.pyc -------------------------------------------------------------------------------- /src/model/__pycache__/sequential_encoder.cpython-38.pyc: 
def train_forward(self, data_dict: dict):
    """
    Run ``forward`` on the batch and return the loss of the resulting logits.

    Args:
        data_dict: dict
    """
    return self.get_loss(data_dict, self.forward(data_dict))
torch.gather(logits, -1, neg_item) 48 | loss = -torch.mean(F.logsigmoid(pos_score - neg_score)) 49 | elif self.loss_type.upper() == 'CE': # CE loss 50 | # prediction = F.softmax(logits, -1) 51 | loss = self.cross_entropy(logits, target) 52 | # pos_score = torch.gather(prediction, -1, target.unsqueeze(-1)) 53 | # loss = -torch.mean(torch.log(pos_score)) 54 | else: 55 | loss = torch.zeros((1,)).to(self.dev) 56 | return loss 57 | 58 | def gather_index(self, output, index): 59 | """ 60 | :param output: [batch, max_len, H] 61 | :param index: [batch] 62 | :return: [batch, H} 63 | """ 64 | gather_index = index.view(-1, 1, 1).repeat(1, 1, output.size(-1)) 65 | gather_output = output.gather(dim=1, index=gather_index) 66 | return gather_output.squeeze() 67 | 68 | def get_target_and_length(self, target_info): 69 | """ 70 | :param target_info: target information dict 71 | :return: 72 | """ 73 | target = target_info['target'] # [batch, ep_len] 74 | try: 75 | tar_len = target_info['target_len'] 76 | except: 77 | raise Exception(f"{self.__class__.__name__} requires target sequences, set use_tar_seq to true in " 78 | f"experimental settings") 79 | return target, tar_len 80 | 81 | def get_negative_items(self, input_item, target, num_samples=1): 82 | """ 83 | :param input_item: [batch_size, max_len] 84 | :param sample_size: [batch_size, num_samples] 85 | :return: 86 | """ 87 | sample_prob = torch.ones(input_item.size(0), self.num_items, device=target.device) 88 | sample_prob.scatter_(-1, input_item, 0.) 89 | sample_prob.scatter_(-1, target.unsqueeze(-1), 0.) 
90 | neg_items = torch.multinomial(sample_prob, num_samples) 91 | 92 | return neg_items 93 | 94 | def pack_to_batch(self, prediction): 95 | if prediction.dim() < 2: 96 | prediction = prediction.unsqueeze(0) 97 | return prediction 98 | 99 | def calc_total_params(self): 100 | """ 101 | Calculate Total Parameters 102 | :return: number of parameters 103 | """ 104 | return sum([p.nelement() for p in self.parameters()]) 105 | 106 | def load_pretrain_model(self, pretrain_model): 107 | """ 108 | load pretraining model, default: load all parameters 109 | """ 110 | self.load_state_dict(pretrain_model.state_dict()) 111 | del pretrain_model 112 | 113 | def MISP_pretrain_forward(self, data_dict: dict): 114 | pass 115 | 116 | def MIM_pretrain_forward(self, data_dict: dict): 117 | pass 118 | 119 | def PID_pretrain_forward(self, data_dict: dict): 120 | pass 121 | 122 | 123 | class AbstractRLRecommender(AbstractRecommender): 124 | def __init__(self, config): 125 | super(AbstractRLRecommender, self).__init__(config) 126 | 127 | def sample_neg_action(self, masked_action, neg_size): 128 | """ 129 | :param masked_action: [batch, max_len] 130 | :return: neg_action, [batch, neg_size] 131 | """ 132 | sample_prob = torch.ones(masked_action.size(0), self.num_items, device=masked_action.device) 133 | sample_prob = sample_prob.scatter(-1, masked_action, 0.) 
134 | neg_action = torch.multinomial(sample_prob, neg_size) 135 | 136 | return neg_action 137 | 138 | def state_transfer(self, pre_item_seq, action, seq_len): 139 | """ 140 | Parameters 141 | ---------- 142 | pre_item_seq: torch.LongTensor, [batch_size, max_len] 143 | action: torch.LongTensor, [batch_size] 144 | seq_len: torch.LongTensor, [batch_size] 145 | 146 | Return 147 | ------ 148 | next_state_seq: torch.LongTensor, [batch_size, max_len] 149 | """ 150 | new_item_seq = pre_item_seq.clone().detach() 151 | action = action.unsqueeze(-1) 152 | seq_len = seq_len.unsqueeze(-1) 153 | max_len = pre_item_seq.size(1) 154 | 155 | padding_col = torch.zeros_like(action, dtype=torch.long, device=action.device) 156 | new_item_seq = torch.cat([new_item_seq, padding_col], -1) 157 | new_item_seq = new_item_seq.scatter(-1, seq_len, action) 158 | new_item_seq = new_item_seq[:, 1:] 159 | 160 | new_seq_len = seq_len.squeeze() + 1 161 | new_seq_len[new_seq_len > max_len] = max_len 162 | 163 | return new_item_seq, new_seq_len 164 | -------------------------------------------------------------------------------- /src/model/cl_based_seq_recommender/__init__.py: -------------------------------------------------------------------------------- 1 | from .cl4srec import CL4SRec, CL4SRec_config 2 | from .iocrec import IOCRec, IOCRec_config 3 | 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /src/model/cl_based_seq_recommender/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/model/cl_based_seq_recommender/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/model/cl_based_seq_recommender/__pycache__/cl4srec.cpython-38.pyc: -------------------------------------------------------------------------------- 
"""
2021. Contrastive Learning for Sequential Recommendation. In
SIGIR ’21: Proceedings of the 44th International ACM SIGIR Conference on
Research and Development in Information Retrieval (SIGIR’21)
"""

import sys
import torch.nn.functional as F
from src.model.abstract_recommeder import AbstractRecommender
import argparse
import torch
import torch.nn as nn
from src.model.sequential_encoder import Transformer
from src.model.loss import InfoNCELoss
from src.utils.utils import HyperParamDict


class CL4SRec(AbstractRecommender):
    """Transformer sequence recommender with a contrastive (InfoNCE) loss.

    The item vocabulary is extended by one id (``self.mask_id``) that is
    reserved for the mask augmentation.
    """

    def __init__(self, config, additional_data_dict):
        super(CL4SRec, self).__init__(config)
        # the extra id appended to the vocabulary is the mask token
        self.mask_id = self.num_items
        self.num_items = self.num_items + 1

        self.do_pretraining = config.do_pretraining
        self.embed_size = config.embed_size
        self.initializer_range = config.initializer_range
        self.aug_views = 2
        self.tao = config.tao
        self.all_hidden = config.all_hidden
        self.lamda = config.lamda  # cl loss weight if on pretraining

        self.item_embedding = nn.Embedding(self.num_items, self.embed_size, padding_idx=0)
        self.position_embedding = nn.Embedding(self.max_len, self.embed_size)
        self.input_layer_norm = nn.LayerNorm(self.embed_size, eps=config.layer_norm_eps)
        self.input_dropout = nn.Dropout(config.hidden_dropout)
        self.trm_encoder = Transformer(embed_size=self.embed_size,
                                       ffn_hidden=config.ffn_hidden,
                                       num_blocks=config.num_blocks,
                                       num_heads=config.num_heads,
                                       attn_dropout=config.attn_dropout,
                                       hidden_dropout=config.hidden_dropout,
                                       layer_norm_eps=config.layer_norm_eps)
        self.nce_loss = InfoNCELoss(temperature=self.tao,
                                    similarity_type='dot')
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise embeddings / linear layers with N(0, initializer_range).

        LayerNorm is reset to weight 1 and bias 0; linear biases are zeroed.
        """
        if isinstance(module, (nn.Embedding, nn.Linear)):
            module.weight.data.normal_(mean=0.0, std=self.initializer_range)
            # The bias reset has to live inside this branch: as a separate
            # ``elif isinstance(module, nn.Linear)`` it could never be reached.
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.)
            module.bias.data.zero_()

    def train_forward(self, data_dict):
        """Return the training loss for one batch.

        When ``self.do_pretraining`` is truthy only the recommendation loss is
        returned; otherwise the contrastive loss is added, weighted by
        ``self.lamda``.
        NOTE(review): the flag name reads as the opposite of what this branch
        does -- confirm the intended meaning against the trainer.
        """
        logits = self.forward(data_dict)
        rec_loss = self.get_loss(data_dict, logits)

        if self.do_pretraining:
            return rec_loss

        # jointly training
        cl_loss = self.MIM_pretrain_forward(data_dict)

        return rec_loss + self.lamda * cl_loss

    def forward(self, data_dict):
        """Score every item for each sequence; returns [batch, num_items]."""
        item_seq, seq_len, _ = self.load_basic_SR_data(data_dict)
        seq_embedding = self.seq_encoding(item_seq, seq_len)
        candidates = self.item_embedding.weight
        logits = seq_embedding @ candidates.t()

        return logits

    def position_encoding(self, item_input):
        """Embed items, add position embeddings, then LayerNorm and dropout.

        :param item_input: [batch, L] item ids, L <= max_len
        :return: [batch, L, embed_size]
        """
        seq_embedding = self.item_embedding(item_input)
        # Use the actual input length rather than self.max_len so inputs
        # shorter than max_len are accepted as well.
        position = torch.arange(item_input.size(1), device=item_input.device).unsqueeze(0)
        position = position.expand_as(item_input).long()
        pos_embedding = self.position_embedding(position)
        seq_embedding = seq_embedding + pos_embedding
        seq_embedding = self.input_layer_norm(seq_embedding)
        seq_embedding = self.input_dropout(seq_embedding)

        return seq_embedding

    def seq_encoding(self, item_seq, seq_len, return_all=False):
        """Encode sequences with the transformer.

        :return: all hidden states when ``return_all`` is true, otherwise the
            hidden state at position ``seq_len - 1`` of each sequence
        """
        seq_embedding = self.position_encoding(item_seq)
        out_seq_embedding = self.trm_encoder(item_seq, seq_embedding)
        if not return_all:
            out_seq_embedding = self.gather_index(out_seq_embedding, seq_len - 1)
        return out_seq_embedding

    def MIM_pretrain_forward(self, data_dict):
        """Contrastive loss between the two augmented views of the batch.

        Reads 'aug_seq_1' / 'aug_len_1' and 'aug_seq_2' / 'aug_len_2' from
        ``data_dict``, encodes both views and feeds them to ``self.nce_loss``.
        """
        aug_seq_1, aug_len_1 = data_dict['aug_seq_1'], data_dict['aug_len_1']
        aug_seq_2, aug_len_2 = data_dict['aug_seq_2'], data_dict['aug_len_2']
        aug_seq_encoding_1 = self.seq_encoding(aug_seq_1, aug_len_1, return_all=self.all_hidden)
        aug_seq_encoding_2 = self.seq_encoding(aug_seq_2, aug_len_2, return_all=self.all_hidden)
        cl_loss = self.nce_loss(aug_seq_encoding_1, aug_seq_encoding_2)

        return cl_loss


def CL4SRec_config():
    """Build the default hyper-parameter container for CL4SRec."""
    parser = HyperParamDict('CL4SRec-Pretraining default hyper-parameters')
    parser.add_argument('--model', default='CL4SRec', type=str)
    parser.add_argument('--model_type', default='Sequential', choices=['Sequential', 'Knowledge'])
    # Contrast Learning Hyper Params
    parser.add_argument('--do_pretraining', action='store_false', help='if do pretraining')
    parser.add_argument('--training_fashion', default='pretraining', choices=['pretraining', 'jointly_training'])
    parser.add_argument('--pretraining_task', default='MIM', type=str, choices=['MISP', 'MIM', 'PID'],
                        help=('pretraining task:'
                              'MISP: Mask Item Prediction and Mask Segment Prediction'
                              'MIM: Mutual Information Maximization'
                              'PID: Pseudo Item Discrimination'))
    parser.add_argument('--aug_types', default=['crop', 'mask', 'reorder'], help='augmentation types')
    parser.add_argument('--crop_ratio', default=0.7, type=float,
                        help='Crop augmentation: proportion of cropped subsequence in origin sequence')
    parser.add_argument('--mask_ratio', default=0.5, type=float,
                        help='Mask augmentation: proportion of masked items in origin sequence')
    parser.add_argument('--reorder_ratio', default=0.8, type=float,
                        help='Reorder augmentation: proportion of reordered subsequence in origin sequence')
    parser.add_argument('--all_hidden', action='store_false', help='all hidden states for cl')
    parser.add_argument('--tao', default=1., type=float, help='temperature for softmax')
    parser.add_argument('--lamda', default=0.1, type=float,
                        help='weight for contrast learning loss, only work when jointly training')

    # Transformer
    parser.add_argument('--embed_size', default=128, type=int)
    parser.add_argument('--ffn_hidden', default=512, type=int, help='hidden dim for feed forward network')
    parser.add_argument('--num_blocks', default=2, type=int, help='number of transformer block')
    parser.add_argument('--num_heads', default=2, type=int, help='number of head for multi-head attention')
    parser.add_argument('--hidden_dropout', default=0.5, type=float, help='hidden state dropout rate')
    parser.add_argument('--attn_dropout', default=0., type=float, help='dropout rate for attention')
    parser.add_argument('--layer_norm_eps', default=1e-12, type=float, help='transformer layer norm eps')
    parser.add_argument('--initializer_range', default=0.02, type=float, help='transformer params initialize range')

    parser.add_argument('--loss_type', default='CE', type=str, choices=['CE', 'BPR', 'BCE', 'CUSTOM'])

    return parser


if __name__ == '__main__':
    a = torch.arange(5)
    b = a.clone() + 10
    c = torch.stack([a, b], 0)
    print(c)
    print(c.transpose(0, 1))
# Reference:
# Xuewei Li et al., "Multi-Intention Oriented Contrastive Learning for
# Sequential Recommendation" in WSDM 2023.

import copy
import math
import sys
import torch.nn.functional as F
from src.model.abstract_recommeder import AbstractRecommender
import argparse
import torch
import torch.nn as nn
from src.model.sequential_encoder import Transformer
from src.model.loss import InfoNCELoss
from src.utils.utils import HyperParamDict


class IOCRec(AbstractRecommender):
    """Multi-intention contrastive sequential recommender.

    A transformer ("local") encoder and a ``GlobalSeqEncoder`` each produce
    per-position sequence embeddings; ``DisentangleEncoder`` splits them into
    ``k_intention`` intention-specific embeddings, which are used both for
    item scoring and for the InfoNCE loss between two augmented views.
    """

    def __init__(self, config, additional_data_dict):
        super(IOCRec, self).__init__(config)
        # the extra id appended to the vocabulary is the mask token
        self.mask_id = self.num_items
        self.num_items = self.num_items + 1
        self.embed_size = config.embed_size
        self.initializer_range = config.initializer_range
        self.aug_views = 2
        self.tao = config.tao
        self.all_hidden = config.all_hidden
        self.lamda = config.lamda
        self.k_intention = config.k_intention

        self.item_embedding = nn.Embedding(self.num_items, self.embed_size, padding_idx=0)
        self.position_embedding = nn.Embedding(self.max_len, self.embed_size)
        self.input_layer_norm = nn.LayerNorm(self.embed_size, eps=config.layer_norm_eps)
        self.input_dropout = nn.Dropout(config.hidden_dropout)
        self.local_encoder = Transformer(embed_size=self.embed_size,
                                         ffn_hidden=config.ffn_hidden,
                                         num_blocks=config.num_blocks,
                                         num_heads=config.num_heads,
                                         attn_dropout=config.attn_dropout,
                                         hidden_dropout=config.hidden_dropout,
                                         layer_norm_eps=config.layer_norm_eps)
        self.global_seq_encoder = GlobalSeqEncoder(embed_size=self.embed_size,
                                                   max_len=self.max_len,
                                                   dropout=config.hidden_dropout)
        self.disentangle_encoder = DisentangleEncoder(k_intention=self.k_intention,
                                                      embed_size=self.embed_size,
                                                      max_len=self.max_len)
        self.nce_loss = InfoNCELoss(temperature=self.tao,
                                    similarity_type='dot')
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise embeddings / linear layers with N(0, initializer_range).

        LayerNorm is reset to weight 1 and bias 0; linear biases are zeroed.
        """
        if isinstance(module, (nn.Embedding, nn.Linear)):
            module.weight.data.normal_(mean=0.0, std=self.initializer_range)
            # The bias reset has to live inside this branch: as a separate
            # ``elif isinstance(module, nn.Linear)`` it could never be reached.
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.)
            module.bias.data.zero_()

    def _disentangle(self, item_seq, seq_len, return_all):
        """Encode one batch of sequences into intention embeddings [B, K, L, D]."""
        local_emb = self.local_seq_encoding(item_seq, seq_len, return_all=return_all)
        global_emb = self.global_seq_encoding(item_seq, seq_len)
        return self.disentangle_encoder(local_emb, global_emb, seq_len)

    def train_forward(self, data_dict):
        """Return ``rec_loss + lamda * cl_loss`` for one batch.

        The recommendation loss is the cross entropy of ``forward``'s logits;
        the contrastive loss compares the flattened intention embeddings of
        the two augmented views ('aug_seq_1' / 'aug_seq_2' in ``data_dict``).
        NOTE(review): the disentangle encoder documents [B, L, D] inputs, so
        ``all_hidden`` presumably has to be true here -- confirm.
        """
        _, _, target = self.load_basic_SR_data(data_dict)
        aug_seq_1, aug_len_1 = data_dict['aug_seq_1'], data_dict['aug_len_1']
        aug_seq_2, aug_len_2 = data_dict['aug_seq_2'], data_dict['aug_len_2']

        # rec task
        max_logits = self.forward(data_dict)
        rec_loss = self.cross_entropy(max_logits, target)

        # cl task: one row per (sequence, intention) pair, [B * K, L * D]
        rows = target.size(0) * self.k_intention
        intention_1 = self._disentangle(aug_seq_1, aug_len_1, self.all_hidden).view(rows, -1)
        intention_2 = self._disentangle(aug_seq_2, aug_len_2, self.all_hidden).view(rows, -1)
        cl_loss = self.nce_loss(intention_1, intention_2)

        return rec_loss + self.lamda * cl_loss

    def forward(self, data_dict):
        """Score every item; returns the max over intentions, [B, num_items]."""
        item_seq, seq_len, _ = self.load_basic_SR_data(data_dict)
        disentangled_intention_emb = self._disentangle(item_seq, seq_len, True)  # [B, K, L, D]

        # take the embedding at the last valid position of every sequence
        gather_index = seq_len.view(-1, 1, 1, 1).repeat(1, self.k_intention, 1, self.embed_size)
        # squeeze(2) removes only the gathered position axis; a bare squeeze()
        # would also drop the batch axis for a batch of one, after which the
        # max below would reduce over items instead of intentions.
        disentangled_intention_emb = disentangled_intention_emb.gather(2, gather_index - 1).squeeze(2)  # [B, K, D]
        candidates = self.item_embedding.weight.unsqueeze(0)  # [1, num_items, D]
        logits = disentangled_intention_emb @ candidates.permute(0, 2, 1)  # [B, K, num_items]
        max_logits, _ = torch.max(logits, 1)

        return max_logits

    def position_encoding(self, item_input):
        """Embed items, add position embeddings, then LayerNorm and dropout.

        :param item_input: [batch, L] item ids, L <= max_len
        :return: [batch, L, embed_size]
        """
        seq_embedding = self.item_embedding(item_input)
        # Use the actual input length rather than self.max_len so inputs
        # shorter than max_len are accepted as well.
        position = torch.arange(item_input.size(1), device=item_input.device).unsqueeze(0)
        position = position.expand_as(item_input).long()
        pos_embedding = self.position_embedding(position)
        seq_embedding = seq_embedding + pos_embedding
        seq_embedding = self.input_layer_norm(seq_embedding)
        seq_embedding = self.input_dropout(seq_embedding)

        return seq_embedding

    def local_seq_encoding(self, item_seq, seq_len, return_all=False):
        """Encode sequences with the transformer.

        :return: all hidden states when ``return_all`` is true, otherwise the
            hidden state at position ``seq_len - 1`` of each sequence
        """
        seq_embedding = self.position_encoding(item_seq)
        out_seq_embedding = self.local_encoder(item_seq, seq_embedding)
        if not return_all:
            out_seq_embedding = self.gather_index(out_seq_embedding, seq_len - 1)
        return out_seq_embedding

    def global_seq_encoding(self, item_seq, seq_len):
        """Encode sequences with the global encoder (shares the item table)."""
        return self.global_seq_encoder(item_seq, seq_len, self.item_embedding)
class GlobalSeqEncoder(nn.Module):
    """Attention over a sequence's items using learned per-position queries."""

    def __init__(self, embed_size, max_len, dropout=0.5):
        super(GlobalSeqEncoder, self).__init__()
        self.embed_size = embed_size
        self.max_len = max_len
        self.dropout = nn.Dropout(dropout)

        # one learned query vector per output position
        self.Q_s = nn.Parameter(torch.randn(max_len, embed_size))
        self.K_linear = nn.Linear(embed_size, embed_size)
        self.V_linear = nn.Linear(embed_size, embed_size)

    def forward(self, item_seq, seq_len, item_embeddings):
        """
        Args:
            item_seq (tensor): [B, L]
            seq_len (tensor): [B], not used by this encoder
            item_embeddings (callable): item embedding module, maps the id
                tensor [B, L] to embeddings [B, L, D]

        Returns:
            global_seq_emb: [B, max_len, D]
        """
        item_emb = item_embeddings(item_seq)  # [B, L, D]
        item_key = self.K_linear(item_emb)
        item_value = self.V_linear(item_emb)

        # NOTE(review): padded positions take part in the softmax as well --
        # confirm that this is intended.
        attn_logits = self.Q_s @ item_key.permute(0, 2, 1)  # [B, max_len, L]
        attn_score = F.softmax(attn_logits, -1)
        global_seq_emb = self.dropout(attn_score @ item_value)

        return global_seq_emb


class DisentangleEncoder(nn.Module):
    """Split per-position embeddings into ``k_intention`` intention views."""

    def __init__(self, k_intention, embed_size, max_len):
        super(DisentangleEncoder, self).__init__()
        self.embed_size = embed_size

        self.intentions = nn.Parameter(torch.randn(k_intention, embed_size))
        self.pos_fai = nn.Embedding(max_len, embed_size)
        self.rou = nn.Parameter(torch.randn(embed_size, ))
        self.W = nn.Linear(embed_size, embed_size)
        self.layer_norm_1 = nn.LayerNorm(embed_size)
        self.layer_norm_2 = nn.LayerNorm(embed_size)
        self.layer_norm_3 = nn.LayerNorm(embed_size)
        self.layer_norm_4 = nn.LayerNorm(embed_size)
        self.layer_norm_5 = nn.LayerNorm(embed_size)

    def forward(self, local_item_emb, global_item_emb, seq_len):
        """
        Args:
            local_item_emb: [B, L, D]
            global_item_emb: [B, L, D]
            seq_len: [B]
        Returns:
            disentangled_intention_emb: [B, K, L, D], the sum of the
            disentangled local and global embeddings
        """
        local_disen_emb = self.intention_disentangling(local_item_emb, seq_len)
        global_disen_emb = self.intention_disentangling(global_item_emb, seq_len)

        return local_disen_emb + global_disen_emb

    def item2IntentionScore(self, item_emb):
        """Soft assignment of every position to the K intentions.

        Args:
            item_emb: [B, L, D]
        Returns:
            score: [B, L, K], softmax over K
        """
        item_emb_norm = self.layer_norm_1(item_emb)  # [B, L, D]
        intention_norm = self.layer_norm_2(self.intentions).unsqueeze(0)  # [1, K, D]

        logits = item_emb_norm @ intention_norm.permute(0, 2, 1)  # [B, L, K]
        score = F.softmax(logits / math.sqrt(self.embed_size), -1)

        return score

    def item2AttnWeight(self, item_emb, seq_len):
        """Attention of the last valid position over all positions.

        Args:
            item_emb: [B, L, D]
            seq_len: [B]
        Returns:
            score: [B, L], softmax over L
        """
        B, L = item_emb.size(0), item_emb.size(1)
        dev = item_emb.device

        # query: embedding at the last valid position plus its position term
        last_pos = seq_len - 1
        item_query_row = item_emb[torch.arange(B, device=dev), last_pos]  # [B, D]
        item_query_row = item_query_row + self.pos_fai(last_pos) + self.rou
        item_query = self.layer_norm_3(item_query_row).unsqueeze(1)  # [B, 1, D]

        pos_fai_tensor = self.pos_fai(torch.arange(L, device=dev)).unsqueeze(0)  # [1, L, D]
        item_key_hat = self.layer_norm_4(item_emb + pos_fai_tensor)
        item_key = item_key_hat + torch.relu(self.W(item_key_hat))

        logits = item_query @ item_key.permute(0, 2, 1)  # [B, 1, L]
        # squeeze(1) removes only the query axis; a bare squeeze() also drops
        # the batch axis when B == 1 and the position axis when L == 1, which
        # makes the softmax normalise over the wrong dimension.
        logits = logits.squeeze(1) / math.sqrt(self.embed_size)
        score = F.softmax(logits, -1)

        return score

    def intention_disentangling(self, item_emb, seq_len):
        """
        Args:
            item_emb: [B, L, D]
            seq_len: [B]
        Returns:
            item_disentangled_emb: [B, K, L, D]
        """
        # get score
        item2intention_score = self.item2IntentionScore(item_emb)
        item_attn_weight = self.item2AttnWeight(item_emb, seq_len)

        # get disentangled embedding
        score_fuse = item2intention_score * item_attn_weight.unsqueeze(-1)  # [B, L, K]
        score_fuse = score_fuse.permute(0, 2, 1).unsqueeze(-1)  # [B, K, L, 1]
        item_emb_k = item_emb.unsqueeze(1)  # [B, 1, L, D]
        disentangled_item_emb = self.layer_norm_5(score_fuse * item_emb_k)
        return disentangled_item_emb


def IOCRec_config():
    """Build the default hyper-parameter container for IOCRec."""
    parser = HyperParamDict('IOCRec default hyper-parameters')
    parser.add_argument('--model', default='IOCRec', type=str)
    parser.add_argument('--model_type', default='Sequential', choices=['Sequential', 'Knowledge'])
    # Contrast Learning Hyper Params
    parser.add_argument('--aug_types', default=['crop', 'mask', 'reorder'], help='augmentation types')
    parser.add_argument('--crop_ratio', default=0.4, type=float,
                        help='Crop augmentation: proportion of cropped subsequence in origin sequence')
    parser.add_argument('--mask_ratio', default=0.3, type=float,
                        help='Mask augmentation: proportion of masked items in origin sequence')
    parser.add_argument('--reorder_ratio', default=0.2, type=float,
                        help='Reorder augmentation: proportion of reordered subsequence in origin sequence')
    parser.add_argument('--all_hidden', action='store_false', help='all hidden states for cl')
    parser.add_argument('--tao', default=1., type=float, help='temperature for softmax')
    parser.add_argument('--lamda', default=0.1, type=float,
                        help='weight for contrast learning loss, only work when jointly training')
    parser.add_argument('--k_intention', default=4, type=int, help='number of disentangled intention')
    # Transformer
    parser.add_argument('--embed_size', default=64, type=int)
    parser.add_argument('--ffn_hidden', default=128, type=int, help='hidden dim for feed forward network')
    parser.add_argument('--num_blocks', default=3, type=int, help='number of transformer block')
    parser.add_argument('--num_heads', default=2, type=int, help='number of head for multi-head attention')
    parser.add_argument('--hidden_dropout', default=0.5, type=float, help='hidden state dropout rate')
    parser.add_argument('--attn_dropout', default=0.5, type=float, help='dropout rate for attention')
    parser.add_argument('--layer_norm_eps', default=1e-12, type=float, help='transformer layer norm eps')
    parser.add_argument('--initializer_range', default=0.02, type=float, help='transformer params initialize range')

    parser.add_argument('--loss_type', default='CE', type=str, choices=['CE', 'BPR', 'BCE', 'CUSTOM'])

    return parser


if __name__ == '__main__':
    a = torch.randn(20, 5, 50, 32)
    seq_len = torch.randperm(4).long()
    gather_index = seq_len.view(-1, 1, 1, 1).repeat(1, 5, 1, 32)
    res = torch.gather(a, 2, gather_index)

    print(res.size())
272 | parser.add_argument('--ffn_hidden', default=128, type=int, help='hidden dim for feed forward network') 273 | parser.add_argument('--num_blocks', default=3, type=int, help='number of transformer block') 274 | parser.add_argument('--num_heads', default=2, type=int, help='number of head for multi-head attention') 275 | parser.add_argument('--hidden_dropout', default=0.5, type=float, help='hidden state dropout rate') 276 | parser.add_argument('--attn_dropout', default=0.5, type=float, help='dropout rate for attention') 277 | parser.add_argument('--layer_norm_eps', default=1e-12, type=float, help='transformer layer norm eps') 278 | parser.add_argument('--initializer_range', default=0.02, type=float, help='transformer params initialize range') 279 | 280 | parser.add_argument('--loss_type', default='CE', type=str, choices=['CE', 'BPR', 'BCE', 'CUSTOM']) 281 | 282 | return parser 283 | 284 | 285 | if __name__ == '__main__': 286 | a = torch.randn(20, 5, 50, 32) 287 | seq_len = torch.randperm(4).long() 288 | gather_index = seq_len.view(-1, 1, 1, 1).repeat(1, 5, 1, 32) 289 | res = torch.gather(a, 2, gather_index) 290 | 291 | print(res.size()) 292 | -------------------------------------------------------------------------------- /src/model/data_augmentation.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import random 4 | import numpy as np 5 | import torch 6 | import math 7 | 8 | 9 | class AbstractDataAugmentor: 10 | def __init__(self, aug_ratio): 11 | self.aug_ratio = aug_ratio 12 | 13 | def transform(self, item_seq, seq_len): 14 | """ 15 | :param item_seq: torch.LongTensor, [batch, max_len] 16 | :param seq_len: torch.LongTensor, [batch] 17 | :return: aug_seq: torch.LongTensor, [batch, max_len] 18 | """ 19 | raise NotImplementedError 20 | 21 | 22 | class CropAugmentor(AbstractDataAugmentor): 23 | """ 24 | Torch version. 
25 | """ 26 | 27 | def __init__(self, aug_ratio): 28 | super(CropAugmentor, self).__init__(aug_ratio) 29 | 30 | def transform(self, item_seq, seq_len): 31 | """ 32 | :param item_seq: torch.LongTensor, [batch, max_len] 33 | :param seq_len: torch.LongTensor, [batch] 34 | :return: aug_seq: torch.LongTensor, [batch, max_len] 35 | """ 36 | max_len = item_seq.size(-1) 37 | aug_seq_len = torch.ceil(seq_len * self.aug_ratio).long() 38 | # get start index 39 | index = torch.arange(max_len, device=seq_len.device) 40 | index = index.expand_as(item_seq) 41 | up_bound = (seq_len - aug_seq_len).unsqueeze(-1) 42 | prob = torch.zeros_like(item_seq, device=seq_len.device).float() 43 | prob[index <= up_bound] = 1. 44 | start_index = torch.multinomial(prob, 1) 45 | # item indices in subsequence 46 | gather_index = torch.arange(max_len, device=seq_len.device) 47 | gather_index = gather_index.expand_as(item_seq) 48 | gather_index = gather_index + start_index 49 | max_seq_len = aug_seq_len.unsqueeze(-1) 50 | gather_index[index >= max_seq_len] = 0 51 | # augmented subsequence 52 | aug_seq = torch.gather(item_seq, -1, gather_index).long() 53 | aug_seq[index >= max_seq_len] = 0 54 | 55 | return aug_seq, aug_seq_len 56 | 57 | 58 | class MaskDAugmentor(AbstractDataAugmentor): 59 | """ 60 | Torch version. 
61 | """ 62 | 63 | def __init__(self, aug_ratio, mask_id=0): 64 | super(MaskDAugmentor, self).__init__(aug_ratio) 65 | self.mask_id = mask_id 66 | 67 | def transform(self, item_seq, seq_len): 68 | """ 69 | :param item_seq: torch.LongTensor, [batch, max_len] 70 | :param seq_len: torch.LongTensor, [batch] 71 | :return: aug_seq: torch.LongTensor, [batch, max_len] 72 | """ 73 | max_len = item_seq.size(-1) 74 | aug_seq = item_seq.clone() 75 | aug_seq_len = seq_len.clone() 76 | # get mask item id 77 | mask_item_size = math.ceil(max_len * self.aug_ratio) 78 | prob = torch.ones_like(item_seq, device=seq_len.device).float() 79 | masked_item_id = torch.multinomial(prob, mask_item_size) 80 | # mask 81 | aug_seq = aug_seq.scatter(-1, masked_item_id, self.mask_id) 82 | 83 | return aug_seq, aug_seq_len 84 | 85 | 86 | class ReorderAugmentor(AbstractDataAugmentor): 87 | """ 88 | Torch version. 89 | """ 90 | 91 | def __init__(self, aug_ratio): 92 | super(ReorderAugmentor, self).__init__(aug_ratio) 93 | 94 | def transform(self, item_seq, seq_len): 95 | """ 96 | Parameters 97 | ---------- 98 | item_seq: [batch_size, max_len] 99 | seq_len: [batch_size] 100 | 101 | Returns 102 | ------- 103 | aug_item_seq: [batch_size, max_len] 104 | aug_seq_len: [batch_size] 105 | """ 106 | dev = item_seq.device 107 | batch_size, max_len = item_seq.shape 108 | 109 | # get start position 110 | reorder_size = (seq_len * self.aug_ratio).ceil().long().unsqueeze(-1) # [B, 1] 111 | position_tensor = torch.arange(max_len).repeat(batch_size, 1).to(dev) # [B, L] 112 | sample_prob = (position_tensor <= seq_len.unsqueeze(-1) - reorder_size).bool().float() 113 | start_index = torch.multinomial(sample_prob, num_samples=1) # [B, 1] 114 | 115 | # get reorder item mask 116 | reorder_item_mask = (start_index <= position_tensor) & (position_tensor < start_index + reorder_size) 117 | 118 | # reorder operation 119 | tmp_reorder_tensor = torch.zeros_like(item_seq).long().to(dev) 120 | tmp_reorder_tensor[reorder_item_mask] 
= item_seq[reorder_item_mask] 121 | rand_index = torch.randperm(max_len) 122 | tmp_reorder_tensor = tmp_reorder_tensor[:, rand_index] 123 | 124 | # put reordered items back 125 | aug_item_seq = item_seq.clone() 126 | aug_item_seq[reorder_item_mask] = tmp_reorder_tensor[tmp_reorder_tensor > 0] 127 | 128 | return aug_item_seq, seq_len 129 | 130 | 131 | class RepeatAugmentor(AbstractDataAugmentor): 132 | """ 133 | Torch version. 134 | """ 135 | 136 | def __init__(self, aug_ratio): 137 | super(RepeatAugmentor, self).__init__(aug_ratio) 138 | 139 | def transform_1(self, item_seq, seq_len): 140 | """ 141 | Parameters 142 | ---------- 143 | item_seq: [batch_size, max_len] 144 | seq_len: [batch_size] 145 | 146 | Returns 147 | ------- 148 | aug_item_seq: [batch_size, max_len] 149 | aug_seq_len: [batch_size] 150 | """ 151 | dev = item_seq.device 152 | batch_size, max_len = item_seq.shape 153 | sample_size = math.ceil(max_len * self.aug_ratio) 154 | 155 | # get reordered index 156 | valid_pos_tensor = torch.arange(max_len).repeat(batch_size, 1).to(dev) 157 | valid_pos_tensor[item_seq == 0] = -1 158 | rand_index = torch.randperm(max_len).to(dev) 159 | rand_pos_tensor = valid_pos_tensor[:, rand_index] 160 | reordered_index = item_seq.clone() 161 | reordered_index[reordered_index > 0] = rand_pos_tensor[rand_pos_tensor > 0] 162 | 163 | # break off 164 | 165 | # sample repeat elements 166 | sample_prob = torch.ones_like(item_seq).float().to(dev) 167 | repeat_pos = torch.multinomial(sample_prob, num_samples=sample_size) 168 | sorted_repeat_pos, _ = torch.sort(repeat_pos, dim=-1) 169 | repeat_element = item_seq.gather(dim=-1, index=sorted_repeat_pos).long() 170 | 171 | # augmented item sequences 172 | padding_seq = torch.zeros((batch_size, sample_size)).to(dev) 173 | aug_item_seq = torch.cat([item_seq, padding_seq], dim=-1).long() # [B, L + L'] 174 | 175 | # get insert position mask of sampled item 176 | valid_pos_tensor = torch.arange(sample_size).unsqueeze(0).to(dev) 177 | 
ele_insert_pos = sorted_repeat_pos + valid_pos_tensor 178 | insert_mask = torch.zeros_like(aug_item_seq).to(item_seq) 179 | insert_mask = insert_mask.scatter(dim=-1, index=ele_insert_pos, value=1).bool() 180 | 181 | # set elements 182 | aug_item_seq[insert_mask] = repeat_element.flatten() 183 | aug_item_seq[~insert_mask] = item_seq.flatten() 184 | 185 | # slice 186 | full_seq_size = (seq_len == max_len).bool().sum() 187 | new_seq_len = (aug_item_seq > 0).sum(-1) 188 | _, sorted_idx = torch.sort(new_seq_len, dim=0, descending=True) 189 | _, restore_idx = torch.sort(sorted_idx, dim=0) 190 | 191 | sorted_aug_seq = aug_item_seq[sorted_idx] 192 | full_aug_seq = sorted_aug_seq[:full_seq_size] 193 | full_aug_seq = full_aug_seq[:, -max_len:] 194 | non_full_aug_seq = sorted_aug_seq[full_seq_size:] 195 | non_full_aug_seq = non_full_aug_seq[:, :max_len] 196 | aug_item_seq = torch.cat([full_aug_seq, non_full_aug_seq], dim=0) 197 | 198 | # restore position 199 | aug_item_seq = aug_item_seq[restore_idx] 200 | aug_seq_len = (aug_item_seq > 0).bool().sum(-1) 201 | 202 | return aug_item_seq, aug_seq_len 203 | 204 | def transform(self, item_seq, seq_len): 205 | """ 206 | Parameters 207 | ---------- 208 | item_seq: [batch_size, max_len] 209 | seq_len: [batch_size] 210 | 211 | Returns 212 | ------- 213 | aug_item_seq: [batch_size, max_len] 214 | aug_seq_len: [batch_size] 215 | """ 216 | dev = item_seq.device 217 | batch_size, max_len = item_seq.shape 218 | sample_size = int(item_seq.size(-1) * self.aug_ratio) 219 | 220 | # sample repeat elements 221 | sample_prob = torch.ones_like(item_seq).float().to(dev) 222 | repeat_pos = torch.multinomial(sample_prob, num_samples=sample_size) 223 | sorted_repeat_pos, _ = torch.sort(repeat_pos, dim=-1) 224 | repeat_element = item_seq.gather(dim=-1, index=sorted_repeat_pos).long() 225 | 226 | # augmented item sequences 227 | padding_seq = torch.zeros((batch_size, sample_size)).to(dev) 228 | aug_item_seq = torch.cat([item_seq, padding_seq], 
class DropAugmentor(AbstractDataAugmentor):
    """
    Torch version of item drop operation.

    ``int(max_len * aug_ratio)`` positions are sampled per row and set to 0;
    the surviving non-zero items are then packed to the front of the row so
    that zeros only remain at the tail.
    """

    def __init__(self, aug_ratio):
        super(DropAugmentor, self).__init__(aug_ratio)

    def transform(self, item_seq, seq_len, drop_prob=None):
        """
        Parameters
        ----------
        item_seq: [batch_size, max_len]
        seq_len: [batch_size] (not read; lengths are recomputed from the
            non-zero entries of the result)
        drop_prob: [batch_size, max_len], optional sampling weights for the
            positions to drop; uniform over all positions when None

        Returns
        -------
        aug_item_seq: [batch_size, max_len]
        aug_seq_len: [batch_size]
        """
        dev = item_seq.device
        max_len = item_seq.size(-1)
        drop_size = int(max_len * self.aug_ratio)

        if drop_size > 0:
            # sample drop item indices
            if drop_prob is None:
                drop_prob = torch.ones_like(item_seq).float().to(dev)
            drop_indices = torch.multinomial(drop_prob, num_samples=drop_size)  # [B, drop_size]
            # fill 0 items
            row_dropped_item_seq = item_seq.scatter(-1, drop_indices, 0).long()
        else:
            # torch.multinomial refuses to draw zero samples, so a ratio that
            # rounds down to no drops must bypass sampling instead of crashing.
            row_dropped_item_seq = item_seq.long()

        valid_item_mask = row_dropped_item_seq > 0
        dropped_seq_len = valid_item_mask.sum(-1)  # [B]
        position_tensor = torch.arange(max_len, device=dev).unsqueeze(0)  # [1, L]
        valid_pos_mask = position_tensor < dropped_seq_len.unsqueeze(-1)  # [B, L]

        # post-process: pack the surviving items to the front of each row
        dropped_item_seq = torch.zeros_like(item_seq)
        dropped_item_seq[valid_pos_mask] = row_dropped_item_seq[valid_item_mask]

        # avoid all 0 item: rows left empty get their original first item back
        empty_rows = dropped_seq_len == 0
        dropped_item_seq[empty_rows, 0] = item_seq[empty_rows, 0]
        dropped_seq_len = (dropped_item_seq > 0).sum(-1)  # [B]

        return dropped_item_seq, dropped_seq_len
315 | """ 316 | 317 | def __init__(self, aug_ratio): 318 | super(CauseCropAugmentor, self).__init__(aug_ratio) 319 | 320 | def transform(self, item_seq, seq_len, critical_mask=None): 321 | """ 322 | :param item_seq: torch.LongTensor, [batch, max_len] 323 | :param seq_len: torch.LongTensor, [batch] 324 | :return: aug_seq: torch.LongTensor, [batch, max_len] 325 | """ 326 | max_len = item_seq.size(-1) 327 | aug_seq_len = torch.ceil(seq_len * self.aug_ratio).long() 328 | # get start index 329 | index = torch.arange(max_len, device=seq_len.device) 330 | index = index.expand_as(item_seq) 331 | up_bound = (seq_len - aug_seq_len).unsqueeze(-1) 332 | prob = torch.zeros_like(item_seq, device=seq_len.device).float() 333 | prob[index <= up_bound] = 1. 334 | start_index = torch.multinomial(prob, 1) 335 | # item indices in subsequence 336 | gather_index = torch.arange(max_len, device=seq_len.device) 337 | gather_index = gather_index.expand_as(item_seq) 338 | gather_index = gather_index + start_index 339 | max_seq_len = aug_seq_len.unsqueeze(-1) 340 | gather_index[index >= max_seq_len] = 0 341 | # augmented subsequence 342 | aug_seq = torch.gather(item_seq, -1, gather_index).long() 343 | aug_seq[index >= max_seq_len] = 0 344 | 345 | return aug_seq, aug_seq_len 346 | 347 | 348 | class CauseReorderAugmentor(AbstractDataAugmentor): 349 | """ 350 | Torch version. 
class Crop(object):
    """Randomly crop a contiguous sub-sequence from the original sequence.

    The crop length is ``int(tao * len(sequence))``. When that rounds down to
    zero a single randomly chosen item is returned instead, and an empty
    input yields an empty list.
    """

    def __init__(self, tao=0.2):
        self.tao = tao  # fraction of the sequence to keep

    def __call__(self, sequence):
        """Return a new cropped sequence; ``sequence`` itself is never modified."""
        seq_len = len(sequence)
        if seq_len == 0:
            # nothing to crop; indexing an empty sequence would raise IndexError
            return []

        sub_seq_length = int(self.tao * seq_len)
        # randint generate int x in range: a <= x <= b
        start_index = random.randint(0, seq_len - sub_seq_length)
        if sub_seq_length < 1:
            # start_index may equal seq_len here, so clamp it to the last item
            return copy.deepcopy([sequence[min(start_index, seq_len - 1)]])
        # deep copy only the slice that is returned, so callers can mutate the
        # result without touching the original items
        return copy.deepcopy(sequence[start_index:start_index + sub_seq_length])
class InfoNCELoss(nn.Module):
    """
    Pair-wise Noise Contrastive Estimation Loss.

    Row ``i`` of one view is the positive of row ``i`` of the other view.
    All other rows of both views act as negatives; a row's similarity with
    itself is excluded.
    """

    def __init__(self, temperature, similarity_type):
        super(InfoNCELoss, self).__init__()
        self.temperature = temperature  # temperature
        self.sim_type = similarity_type  # cos or dot
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, aug_hidden_view1, aug_hidden_view2, mask=None):
        """
        Args:
            aug_hidden_view1 (FloatTensor, [batch, max_len, dim] or [batch, dim]): augmented sequence representation1
            aug_hidden_view2 (FloatTensor, [batch, max_len, dim] or [batch, dim]): augmented sequence representation2
            mask (BoolTensor, optional): broadcastable to the [2 * batch, 2 * batch] logits;
                True entries are set to -inf before the cross entropy

        Returns: nce_loss (FloatTensor, (,)): calculated nce loss

        Raises:
            ValueError: if the similarity type is neither 'cos' nor 'dot'
        """
        if self.sim_type not in ('cos', 'dot'):
            raise ValueError(f"Invalid similarity_type for cs loss: [current:{self.sim_type}]. "
                             f"Please choose from ['cos', 'dot']")

        if aug_hidden_view1.ndim > 2:
            # flatten tensor; reshape (unlike view) also accepts non-contiguous input
            aug_hidden_view1 = aug_hidden_view1.reshape(aug_hidden_view1.size(0), -1)
            aug_hidden_view2 = aug_hidden_view2.reshape(aug_hidden_view2.size(0), -1)

        if self.sim_type == 'cos':
            sim11 = self.cosinesim(aug_hidden_view1, aug_hidden_view1)
            sim22 = self.cosinesim(aug_hidden_view2, aug_hidden_view2)
            sim12 = self.cosinesim(aug_hidden_view1, aug_hidden_view2)
        else:
            # calc similarity
            sim11 = aug_hidden_view1 @ aug_hidden_view1.t()
            sim22 = aug_hidden_view2 @ aug_hidden_view2.t()
            sim12 = aug_hidden_view1 @ aug_hidden_view2.t()

        # mask non-calc value: a sample is never its own negative
        sim11[..., range(sim11.size(0)), range(sim11.size(0))] = float('-inf')
        sim22[..., range(sim22.size(0)), range(sim22.size(0))] = float('-inf')

        # the positive of row i sits at column i in both halves of the logits
        cl_logits1 = torch.cat([sim12, sim11], -1)
        cl_logits2 = torch.cat([sim22, sim12.t()], -1)
        cl_logits = torch.cat([cl_logits1, cl_logits2], 0) / self.temperature
        if mask is not None:
            cl_logits = torch.masked_fill(cl_logits, mask, float('-inf'))
        target = torch.arange(cl_logits.size(0)).long().to(aug_hidden_view1.device)
        cl_loss = self.criterion(cl_logits, target)

        return cl_loss

    def cosinesim(self, aug_hidden1, aug_hidden2):
        """Return the [n1, n2] matrix of cosine similarities between the rows of two 2-D tensors."""
        # F.normalize clamps the norm from below, so an all-zero row yields a
        # similarity of 0 instead of the NaN produced by a 0 / 0 division.
        h1 = F.normalize(aug_hidden1, p=2, dim=-1)
        h2 = F.normalize(aug_hidden2, p=2, dim=-1)
        return torch.matmul(h1, h2.t())
def lunif(x, t=2):
    """Return ``log(mean(exp(-t * d^2)) + 1e-6)`` over all pairwise Euclidean distances ``d`` between rows of ``x``."""
    pairwise_sq = torch.pdist(x, p=2) ** 2
    mean_kernel = torch.exp(pairwise_sq * (-t)).mean()
    return torch.log(mean_kernel + 1e-6)
attn_dropout=attn_dropout, 17 | hidden_dropout=hidden_dropout, 18 | layer_norm_eps=layer_norm_eps) 19 | self.encoder_layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(num_blocks)]) 20 | 21 | def forward(self, item_input, seq_embedding): 22 | """ 23 | Only output the sequence representations of the last layer in Transformer. 24 | out_seq_embed: torch.FloatTensor, [batch_size, max_len, embed_size] 25 | """ 26 | mask = self.create_mask(item_input) 27 | for layer in self.encoder_layers: 28 | seq_embedding = layer(seq_embedding, mask) 29 | return seq_embedding 30 | 31 | def create_mask(self, input_seq): 32 | """ 33 | Parameters: 34 | input_seq: torch.LongTensor, [batch_size, max_len] 35 | Return: 36 | mask: torch.BoolTensor, [batch_size, 1, max_len, max_len] 37 | """ 38 | mask = (input_seq != 0).bool().unsqueeze(1).unsqueeze(2) # [batch_size, 1, 1, max_len] 39 | mask = mask.expand(-1, -1, mask.size(-1), -1) 40 | if not self.bidirectional: 41 | mask = torch.tril(mask) 42 | return mask 43 | 44 | def set_attention_direction(self, bidirection=False): 45 | self.bidirectional = bidirection 46 | 47 | 48 | class EncoderLayer(nn.Module): 49 | def __init__(self, embed_size, ffn_hidden, num_heads, attn_dropout, hidden_dropout, layer_norm_eps): 50 | super(EncoderLayer, self).__init__() 51 | 52 | self.attn_layer_norm = nn.LayerNorm(embed_size, eps=layer_norm_eps) 53 | self.pff_layer_norm = nn.LayerNorm(embed_size, eps=layer_norm_eps) 54 | 55 | self.self_attention = MultiHeadAttentionLayer(embed_size, num_heads, attn_dropout) 56 | self.pff = PointWiseFeedForwardLayer(embed_size, ffn_hidden) 57 | 58 | self.hidden_dropout = nn.Dropout(hidden_dropout) 59 | self.pff_out_drop = nn.Dropout(hidden_dropout) 60 | 61 | def forward(self, input_seq, inputs_mask): 62 | """ 63 | input: 64 | inputs: torch.FloatTensor, [batch_size, max_len, embed_size] 65 | inputs_mask: torch.BoolTensor, [batch_size, 1, 1, max_len] 66 | return: 67 | out_seq_embed: torch.FloatTensor, [batch_size, 
class MultiHeadAttentionLayer(nn.Module):
    """Multi-head scaled dot-product attention with separate Q/K/V and output projections."""

    def __init__(self, embed_size, nhead, attn_dropout):
        super(MultiHeadAttentionLayer, self).__init__()
        self.embed_size = embed_size
        self.nhead = nhead

        head_dim, remainder = divmod(self.embed_size, self.nhead)
        if remainder != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (self.embed_size, self.nhead)
            )
        self.head_dim = head_dim

        # input projections for query, key and value
        self.fc_q = nn.Linear(self.embed_size, self.embed_size)
        self.fc_k = nn.Linear(self.embed_size, self.embed_size)
        self.fc_v = nn.Linear(self.embed_size, self.embed_size)

        self.attn_dropout = nn.Dropout(attn_dropout)
        self.fc_o = nn.Linear(self.embed_size, self.embed_size)
        # sqrt(head_dim), stored as a buffer so it moves with the module
        self.register_buffer('scale', torch.tensor(self.head_dim).float().sqrt())

    def _split_heads(self, projected):
        """Reshape [N, len, embed_size] into [N, n_head, len, head_dim]."""
        per_head = projected.view(projected.size(0), -1, self.nhead, self.head_dim)
        return per_head.permute((0, 2, 1, 3))

    def forward(self, query, key, value, inputs_mask=None):
        """
        :param query: [query_size, max_len, embed_size]
        :param key: [key_size, max_len, embed_size]
        :param value: [key_size, max_len, embed_size]
        :param inputs_mask: [N, 1, max_len, max_len], entries equal to 0 are masked out
        :return: tuple of the attended output [N, max_len, embed_size] and the
            attention probabilities (after dropout) [N, n_head, max_len, max_len]
        """
        n_queries = query.size(0)

        q_heads = self._split_heads(self.fc_q(query))
        k_heads = self._split_heads(self.fc_k(key))
        v_heads = self._split_heads(self.fc_v(value))

        # scaled dot-product scores, [N, n_head, max_len, max_len]
        scores = torch.matmul(q_heads, k_heads.transpose(-2, -1)) / self.scale
        if inputs_mask is not None:
            # a large negative score drives the softmax weight of masked cells to 0
            scores = scores.masked_fill(inputs_mask == 0, -1.e10)

        attention_prob = self.attn_dropout(F.softmax(scores, dim=-1))

        # weighted sum of values, then merge the heads back into embed_size
        context = torch.matmul(attention_prob, v_heads)  # [N, n_head, max_len, head_dim]
        context = context.permute((0, 2, 1, 3)).contiguous()
        merged = context.view((n_queries, -1, self.embed_size))
        return self.fc_o(merged), attention_prob
class SASRec(AbstractRecommender):
    """
    Transformer-encoder based sequential recommender.

    Item and position embeddings are summed, layer-normalised and passed
    through ``Transformer``; the hidden state selected by ``gather_index`` at
    ``seq_len - 1`` is scored against every row of the item embedding table.
    """

    def __init__(self, config, additional_data_dict):
        """
        :param config: hyper-parameter object; reads embed_size, ffn_hidden,
            initializer_range, num_blocks, num_heads, attn_dropout,
            hidden_dropout and layer_norm_eps
        :param additional_data_dict: accepted for interface compatibility, not used here
        """
        super(SASRec, self).__init__(config)
        self.embed_size = config.embed_size
        self.hidden_size = config.ffn_hidden
        self.initializer_range = config.initializer_range

        # module
        self.item_embedding = nn.Embedding(self.num_items, self.embed_size, padding_idx=0)
        self.position_embedding = nn.Embedding(self.max_len, self.embed_size)
        self.trm_encoder = Transformer(embed_size=self.embed_size,
                                       ffn_hidden=self.hidden_size,
                                       num_blocks=config.num_blocks,
                                       num_heads=config.num_heads,
                                       attn_dropout=config.attn_dropout,
                                       hidden_dropout=config.hidden_dropout,
                                       layer_norm_eps=config.layer_norm_eps)

        self.input_layer_norm = nn.LayerNorm(self.embed_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise one sub-module; invoked for every sub-module through ``self.apply``."""
        if isinstance(module, (nn.Embedding, nn.Linear)):
            # NOTE(review): this also overwrites the padding_idx row of
            # item_embedding with random values - confirm that is intended.
            module.weight.data.normal_(mean=0.0, std=self.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.)
            module.bias.data.zero_()
        # Must be a separate `if`: as an `elif` it could never run, because
        # nn.Linear is already matched by the first branch, and Linear biases
        # were silently left at their default initialisation.
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def forward(self, data_dict):
        """
        :param data_dict: batch dict consumed by ``load_basic_SR_data``
        :return: logits, one score per row of the item embedding table
        """
        item_seq, seq_len, _ = self.load_basic_SR_data(data_dict)
        seq_embedding = self.position_encoding(item_seq)
        out_seq_embedding = self.trm_encoder(item_seq, seq_embedding)
        seq_embedding = self.gather_index(out_seq_embedding, seq_len - 1)

        # get prediction
        candidates = self.item_embedding.weight
        logits = seq_embedding @ candidates.t()

        return logits

    def position_encoding(self, item_input):
        """Return dropout(layer_norm(item embedding + position embedding)) for ``item_input``."""
        seq_embedding = self.item_embedding(item_input)
        # positions 0 .. max_len - 1, broadcast to the shape of item_input
        position = torch.arange(self.max_len, device=item_input.device).unsqueeze(0)
        position = position.expand_as(item_input).long()
        pos_embedding = self.position_embedding(position)
        seq_embedding = seq_embedding + pos_embedding
        seq_embedding = self.dropout(self.input_layer_norm(seq_embedding))

        return seq_embedding
https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/train/__init__.py -------------------------------------------------------------------------------- /src/train/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/train/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/train/__pycache__/config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/train/__pycache__/config.cpython-38.pyc -------------------------------------------------------------------------------- /src/train/__pycache__/trainer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/train/__pycache__/trainer.cpython-38.pyc -------------------------------------------------------------------------------- /src/train/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch.cuda 3 | from easydict import EasyDict 4 | from src.utils.utils import HyperParamDict 5 | 6 | EXP_HYPER_LIST = {'Data': {'dataset': None, 'data_aug': None, 'seq_filter_len': None, 7 | 'if_filter_target': None, 'max_len': None}, 8 | 'Pretraining': {'do_pretraining': None, 'pretraining_task': None, 'pretraining_epoch': None, 9 | 'pretraining_batch': None, 'pretraining_lr': None, 'pretraining_l2': None}, 10 | 'Training': {'epoch_num': None, 'train_batch': None, 11 | 'learning_rate': None, 'l2': None, 'patience': None, 12 | 'device': None, 'num_worker': None, 'seed': None}, 13 | 'Evaluation': {'split_type': None, 'split_mode': 
def get_device():
    """Return 'cuda:0' when torch reports CUDA as available, otherwise 'cpu'."""
    if torch.cuda.is_available():
        return 'cuda:0'
    return 'cpu'
parser.add_argument('--pretraining_epoch', default=10, type=int) 55 | parser.add_argument('--pretraining_batch', default=512, type=int) 56 | parser.add_argument('--pretraining_lr', default=1e-3, type=float) 57 | parser.add_argument('--pretraining_l2', default=0., type=float, help='l2 normalization') 58 | # Training 59 | parser.add_argument('--epoch_num', default=100, type=int) 60 | parser.add_argument('--seed', default=1034, type=int, help="random seed, only -1 means don't set random seed") 61 | parser.add_argument('--train_batch', default=256, type=int) 62 | parser.add_argument('--learning_rate', default=1e-3, type=float) 63 | parser.add_argument('--l2', default=0., type=float, help='l2 normalization') 64 | parser.add_argument('--patience', default=5, type=int, help='early stop patience') 65 | parser.add_argument('--device', default=get_device(), choices=['cuda:0', 'cpu'], 66 | help='training on gpu or cpu, default gpu') 67 | parser.add_argument('--num_worker', default=0, type=int, 68 | help='num_workers for dataloader, best: 6') 69 | parser.add_argument('--mark', default='', type=str, 70 | help='mark of this run which will be added to the name of the log') 71 | 72 | # Evaluation 73 | parser.add_argument('--split_type', default='valid_and_test', choices=['valid_only', 'valid_and_test']) 74 | parser.add_argument('--split_mode', default='LS', type=str, 75 | help='LS: Leave-one-out splitting.' 76 | 'LS_R@0.2: use LS and a ratio 0.x of test data for validate if use valid_and_test.' 
77 | 'PS: Pre-Splitting, prepare xx.train and xx.eval, also xx.test if use valid_and_test') 78 | parser.add_argument('--eval_mode', default='full', help='[uni100, uni200, full]') 79 | parser.add_argument('--metric', default=['hit', 'ndcg'], help='[hit, ndcg, mrr, recall]') 80 | parser.add_argument('--k', default=[5, 10], help='top k for each metric') 81 | parser.add_argument('--valid_metric', default='hit@10', help='specifies which indicator to apply early stop') 82 | parser.add_argument('--eval_batch', default=256, type=int) 83 | 84 | # save 85 | parser.add_argument('--log_save', default='log', type=str, help='log saving path') 86 | parser.add_argument('--save', default='save', type=str, help='model saving path') 87 | parser.add_argument('--model_saved', default=None, type=str) 88 | 89 | return parser 90 | 91 | 92 | def config_override(model_config, cmd_config): 93 | default_config = get_default_config() 94 | command_args = set([arg for arg in vars(cmd_config)]) 95 | # overwrite model config by cmd config 96 | for arg in vars(model_config): 97 | if arg in command_args: 98 | setattr(model_config, arg, getattr(cmd_config, arg)) 99 | 100 | # overwrite default config by cmd config 101 | for arg in vars(default_config): 102 | if arg in command_args: 103 | setattr(default_config, arg, getattr(cmd_config, arg)) 104 | 105 | # overwrite default config by model config 106 | for arg in vars(model_config): 107 | setattr(default_config, arg, getattr(model_config, arg)) 108 | 109 | return default_config 110 | -------------------------------------------------------------------------------- /src/train/trainer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import DataLoader 6 | from tqdm import tqdm 7 | from src.dataset import dataset 8 | from src.dataset.dataset import load_specified_dataset 9 | from src.dataset.data_processor import DataProcessor 10 | 
from src.evaluation.estimator import Estimator
from src.utils.recorder import Recorder
import src.model as model
from src.train.config import experiment_hyper_load, config_override
from src.utils.utils import set_seed, batch_to_device, KMeans


# torch.autograd.set_detect_anomaly(True)

def load_trainer(config):
    """Return an ICLTrainer for the 'ICLRec' model, a plain Trainer otherwise."""
    if config.model in ['ICLRec']:
        return ICLTrainer(config)
    return Trainer(config)


class Trainer:
    """Drive optional pre-training, training, validation and testing of a model."""

    def __init__(self, config):
        """Merge the configuration, build the components and prepare the data.

        :param config: command-line config; it is merged with the model's own
            config and the defaults, and the result is stored in ``self.config``.
        """
        self.config = config
        self.model_name = config.model
        self._config_override(self.model_name, config)
        # pretraining
        self.pretraining_model = None
        self.do_pretraining = self.config.do_pretraining
        self.pretraining_task = self.config.pretraining_task
        self.pretraining_epoch = self.config.pretraining_epoch
        self.pretraining_batch = self.config.pretraining_batch
        self.pretraining_lr = self.config.pretraining_lr
        self.pretraining_l2 = self.config.pretraining_l2

        # training
        self.training_model = None
        self.num_worker = self.config.num_worker
        self.train_batch = self.config.train_batch
        self.eval_batch = self.config.eval_batch
        self.lr = self.config.learning_rate
        self.l2 = self.config.l2
        self.epoch_num = self.config.epoch_num
        self.dev = torch.device(self.config.device)
        self.split_type = self._set_split_mode(self.config.split_type)
        self.do_test = self.split_type == 'valid_and_test'

        # components
        self.data_processor = DataProcessor(self.config)
        self.estimator = Estimator(self.config)
        self.recorder = Recorder(self.config)

        # set random seed
        set_seed(self.config.seed)

        # preparing data
        data_dict, additional_data_dict = self.data_processor.prepare_data()
        self.data_dict = data_dict  # store standard train/eval/test data
        self.additional_data_dict = additional_data_dict  # extra data (model specified)

        # check_duplication(self.data_dict['train'][0])

        self.estimator.load_item_popularity(self.data_processor.popularity)
        self._set_num_items()

    def start_training(self):
        """Run pre-training when enabled, then the main training."""
        if self.do_pretraining:
            self.pretrain()
        self.train()

    def pretrain(self):
        """Pre-train the model on the configured pre-training task.

        :raises NotImplementedError: if the task is not MISP, MIM or PID.
        """
        if self.pretraining_task in ['MISP', 'MIM', 'PID']:
            pretrain_dataset = getattr(dataset, f'{self.pretraining_task}PretrainDataset')
            pretrain_dataset = pretrain_dataset(self.config, self.data_dict['train'],
                                                self.additional_data_dict)
        else:
            raise NotImplementedError(f'No such pretraining task: {self.pretraining_task}, '
                                      f'choosing from [MISP, MIM, PID]')
        # Use the dedicated pre-training batch size; the original passed
        # train_batch here, which left `pretraining_batch` without effect.
        train_loader = DataLoader(pretrain_dataset, batch_size=self.pretraining_batch,
                                  collate_fn=pretrain_dataset.collate_fn,
                                  shuffle=True, num_workers=0, drop_last=False)

        pretrain_model = self._load_model()

        opt = torch.optim.Adam(filter(lambda x: x.requires_grad, pretrain_model.parameters()),
                               self.pretraining_lr, weight_decay=self.pretraining_l2)

        self.experiment_setting_verbose(pretrain_model, training=False)

        # the task-specific forward is looked up once, e.g. `MISP_pretrain_forward`
        pretrain_forward = f'{self.pretraining_task}_pretrain_forward'

        logging.info('Start pretraining...')
        for epoch in range(self.pretraining_epoch):
            pretrain_model.train()
            self.recorder.epoch_restart()
            self.recorder.tik_start()
            train_iter = tqdm(enumerate(train_loader), total=len(train_loader))
            train_iter.set_description(f'pretraining epoch: {epoch}')
            for i, batch_dict in train_iter:
                batch_to_device(batch_dict, self.dev)
                loss = getattr(pretrain_model, pretrain_forward)(batch_dict)
                opt.zero_grad()
                loss.backward()
                opt.step()

                self.recorder.save_batch_loss(loss.item())
            self.recorder.tik_end()
            self.recorder.train_log_verbose(len(train_loader))

        self.pretraining_model = pretrain_model
        logging.info('Pre-training is over, prepare for training...')

    def train(self):
        """Train the model with per-epoch validation, early stopping and a final test."""
        SpecifiedDataSet = load_specified_dataset(self.model_name, self.config)
        train_loader = self._build_loader(SpecifiedDataSet, self.data_dict['train'])
        eval_loader = self._build_loader(SpecifiedDataSet, self.data_dict['eval'], train=False)

        opt = self._prepare_training()

        logging.info('Start training...')
        for epoch in range(self.epoch_num):
            self.training_model.train()
            self.recorder.epoch_restart()
            self.recorder.tik_start()
            self._run_training_batches(train_loader, opt, epoch)
            self.recorder.tik_end()
            self.recorder.train_log_verbose(len(train_loader))

            if self._evaluate_epoch(eval_loader):
                break

        self._finish_training()

    def _build_loader(self, dataset_cls, data_pair, train=True):
        """Wrap ``data_pair`` in ``dataset_cls`` and return its DataLoader.

        Training loaders shuffle and use ``train_batch``; evaluation loaders
        keep the order, use ``eval_batch`` and build the dataset with
        ``train=False``.
        """
        if train:
            data_set = dataset_cls(self.config, data_pair, self.additional_data_dict)
            batch_size = self.train_batch
        else:
            data_set = dataset_cls(self.config, data_pair, self.additional_data_dict, train=False)
            batch_size = self.eval_batch
        return DataLoader(data_set, batch_size=batch_size, collate_fn=data_set.collate_fn,
                          shuffle=train, num_workers=self.num_worker, drop_last=False)

    def _prepare_training(self):
        """Load the training model, reset the recorder, log the setup and return the optimizer."""
        self.training_model = self._load_model()

        opt = torch.optim.Adam(filter(lambda x: x.requires_grad, self.training_model.parameters()), self.lr,
                               weight_decay=self.l2)
        self.recorder.reset()
        self.experiment_setting_verbose(self.training_model)
        return opt

    def _run_training_batches(self, train_loader, opt, epoch, extra_fields=None):
        """Run one pass over ``train_loader``, optimizing and recording each batch loss.

        :param extra_fields: optional dict merged into every batch before the
            batch is moved to the device.
        """
        train_iter = tqdm(enumerate(train_loader), total=len(train_loader))
        train_iter.set_description('training ')
        for step, batch_dict in train_iter:
            # training forward
            batch_dict['epoch'] = epoch
            batch_dict['step'] = step
            if extra_fields:
                batch_dict.update(extra_fields)
            batch_to_device(batch_dict, self.dev)
            loss = self.training_model.train_forward(batch_dict)
            # skip the update when the model returned a loss without a graph
            if torch.is_tensor(loss) and loss.requires_grad:
                opt.zero_grad()
                loss.backward()
                opt.step()
            self.recorder.save_batch_loss(loss.item())

    def _evaluate_epoch(self, eval_loader):
        """Evaluate on ``eval_loader``, record the result and return the early-stop flag."""
        self.recorder.tik_start()
        eval_metric_result, eval_loss = self.estimator.evaluate(eval_loader, self.training_model)
        self.recorder.tik_end(mode='eval')
        self.recorder.log_verbose_and_save(eval_metric_result, eval_loss, self.training_model)
        return self.recorder.early_stop

    def _finish_training(self):
        """Report the best validation result and, when a test split is used, the test result."""
        self.recorder.report_best_res()
        # test model
        if self.do_test:
            test_metric_res = self.test_model(self.data_dict['test'])
            self.recorder.report_test_result(test_metric_res)

    def _set_split_mode(self, split_mode):
        """Validate and return the split type ('valid_and_test' or 'valid_only')."""
        assert split_mode in ['valid_and_test', 'valid_only'], f'Invalid split mode: {split_mode} !'
        return split_mode

    def _load_model(self):
        """Return the pre-trained model if one exists, otherwise a new model.

        :raises KeyError: if ``config.model_type`` is not sequential, graph or knowledge.
        """
        if self.do_pretraining and self.pretraining_model is not None:  # return pretraining model
            return self.pretraining_model

        # return new model
        model_type = self.config.model_type.upper()
        if model_type == 'SEQUENTIAL':
            return self._load_sequential_model()
        if model_type in ['GRAPH', 'KNOWLEDGE']:
            return self._load_model_with_additional_data()
        raise KeyError(f'Invalid model_type:{self.config.model_type}. Choose from [sequential, knowledge, graph]')

    def _load_sequential_model(self):
        """Instantiate the configured model class and move it to the device."""
        Model = getattr(model, self.model_name)
        specified_seq_model = Model(self.config, self.additional_data_dict).to(self.dev)
        return specified_seq_model

    def _load_model_with_additional_data(self):
        """Instantiate the configured graph/knowledge model and move it to the device."""
        Model = getattr(model, self.model_name)
        specified_model = Model(self.config, self.additional_data_dict).to(self.dev)
        return specified_model

    def _config_override(self, model_name, cmd_config):
        """Load ``<model_name>_config`` and merge it with defaults and ``cmd_config``."""
        self.model_config = getattr(model, f'{model_name}_config')()
        self.config = config_override(self.model_config, cmd_config)
        # capitalize
        self.config.model_type = self.config.model_type.upper()
        graph_type = self.config.graph_type
        if isinstance(graph_type, str):
            # A plain string (e.g. the default 'None') is one graph type;
            # iterating over it would upper-case single characters instead.
            graph_type = [graph_type]
        self.config.graph_type = [g_type.upper() for g_type in graph_type]

    def _set_num_items(self):
        """Copy the item count found by the data processor into the config."""
        self.config.num_items = self.data_processor.num_items

    def experiment_setting_verbose(self, model, training=True):
        """Log model/experiment hyper-parameters, data statistics and the model.

        Nothing is logged for the training phase when pre-training is enabled
        (the pre-training phase logs with ``training=False``).
        """
        if self.do_pretraining and training:
            return
        # model config
        logging.info('[1] Model Hyper-Parameter '.ljust(47, '-'))
        model_param_set = self.model_config.keys()
        for arg in vars(self.config):
            if arg in model_param_set:
                logging.info(f'{arg}: {getattr(self.config, arg)}')
        # experiment config
        logging.info('[2] Experiment Hyper-Parameter '.ljust(47, '-'))
        # verbose_order = ['Data', 'Training', 'Evaluation', 'Save']
        hyper_types, exp_setting = experiment_hyper_load(self.config)
        for i, hyper_type in enumerate(hyper_types):
            hyper_start_log = (f'[2-{i + 1}] ' + hyper_type.lower() + ' hyper-parameter ').ljust(47, '-')
            logging.info(hyper_start_log)
            for hyper, value in exp_setting[hyper_type].items():
                logging.info(f'{hyper}: {value}')
        # data statistic
        self.data_processor.data_log_verbose(3)
        # model architecture
        self.report_model_info(model)

    def report_model_info(self, model):
        """Log the parameter count and the printed structure of ``model``."""
        # model architecture
        logging.info('[1] Model Architecture '.ljust(47, '-'))
        logging.info(f'total parameters: {model.calc_total_params()}')
        logging.info(model)

    def test_model(self, test_data_pair=None):
        """Load the best saved weights into the training model and evaluate on test data.

        :param test_data_pair: test split handed to the model-specific dataset.
        :return: the metric result produced by the estimator.
        """
        SpecifiedDataSet = load_specified_dataset(self.model_name, self.config)
        test_dataset = SpecifiedDataSet(self.config, test_data_pair,
                                        self.additional_data_dict, train=False)
        test_loader = DataLoader(test_dataset, batch_size=self.eval_batch, num_workers=self.num_worker,
                                 collate_fn=test_dataset.collate_fn, drop_last=False, shuffle=False)
        # load the best model
        self.recorder.load_best_model(self.training_model)
        self.training_model.eval()
        test_metric_result = self.estimator.test(test_loader, self.training_model)

        return test_metric_result

    def start_test(self):
        """Build the model, log the setup and report its result on the test split."""
        self.training_model = self._load_model()
        self.experiment_setting_verbose(self.training_model)
        test_metric_res = self.test_model(self.data_dict['test'])
        self.recorder.report_test_result(test_metric_res)


class ICLTrainer(Trainer):
    """Trainer that re-fits a KMeans intent clustering before every training epoch."""

    def __init__(self, config):
        """Set up the base trainer and the KMeans clusterer.

        The clustered vector has ``embed_size`` dimensions for the "mean"
        representation and ``embed_size * max_len`` otherwise.
        """
        super(ICLTrainer, self).__init__(config)
        self.num_intent_cluster = config.num_intent_cluster
        self.seq_representation_type = config.seq_representation_type
        # initialize Kmeans
        if self.seq_representation_type == "mean":
            hidden_size = self.config.embed_size
        else:
            hidden_size = self.config.embed_size * self.config.max_len
        self.cluster = KMeans(
            num_cluster=self.num_intent_cluster,
            seed=self.config.seed,
            hidden_size=hidden_size,
            device=self.config.device,
        )

    def train(self):
        """Train like Trainer.train, clustering sequence encodings at the start of each epoch."""
        SpecifiedDataSet = load_specified_dataset(self.model_name, self.config)
        intent_cluster_loader = self._build_loader(SpecifiedDataSet, self.data_dict['raw_train'])
        train_loader = self._build_loader(SpecifiedDataSet, self.data_dict['train'])
        eval_loader = self._build_loader(SpecifiedDataSet, self.data_dict['eval'], train=False)

        opt = self._prepare_training()

        logging.info('Start training...')
        for epoch in range(self.epoch_num):
            self.training_model.train()
            self.recorder.epoch_restart()
            self.recorder.tik_start()

            self._fit_intent_clusters(intent_cluster_loader)
            # every batch carries the freshly trained clusterer
            self._run_training_batches(train_loader, opt, epoch, extra_fields={'cluster': self.cluster})
            self.recorder.tik_end()
            self.recorder.train_log_verbose(len(train_loader))

            if self._evaluate_epoch(eval_loader):
                break

        self._finish_training()

    def _fit_intent_clusters(self, intent_cluster_loader):
        """Encode every sequence of the loader and train the KMeans clusterer on the result."""
        import gc

        # collect cluster data
        intent_cluster_iter = tqdm(enumerate(intent_cluster_loader), total=len(intent_cluster_loader))
        intent_cluster_iter.set_description('prepare clustering ')

        kmeans_training_data = []  # store all user intent representations in training data
        # The encodings are detached right away, so no autograd graph is needed.
        with torch.no_grad():
            for step, batch_dict in intent_cluster_iter:
                batch_to_device(batch_dict, self.dev)
                item_seq, seq_len = batch_dict['item_seq'], batch_dict['seq_len']
                sequence_output = self.training_model.seq_encoding(item_seq, seq_len, return_all=True)
                # average sum
                if self.seq_representation_type == "mean":
                    sequence_output = torch.mean(sequence_output, dim=1, keepdim=False)
                sequence_output = sequence_output.view(sequence_output.shape[0], -1)  # otherwise concat
                sequence_output = sequence_output.detach().cpu().numpy()
                kmeans_training_data.append(sequence_output)
        kmeans_training_data = np.concatenate(kmeans_training_data, axis=0)  # [user_size, dim]

        # train cluster
        self.cluster.train(kmeans_training_data)

        # clean memory
        del kmeans_training_data
        gc.collect()


if __name__ == '__main__':
    pass

# Repository dump entries preserved from the original listing:
#   /src/utils/__init__.py:
#     https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/utils/__init__.py
#   /src/utils/__pycache__/__init__.cpython-38.pyc:
#     https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/utils/__pycache__/__init__.cpython-38.pyc
#   /src/utils/__pycache__/recorder.cpython-38.pyc:
#     https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/utils/__pycache__/recorder.cpython-38.pyc
#   /src/utils/__pycache__/utils.cpython-38.pyc:
# Repository dump entry preserved from the original listing
# (URL of /src/utils/__pycache__/utils.cpython-38.pyc):
#   https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/utils/__pycache__/utils.cpython-38.pyc
# --------------------------------------------------------------------------------
# /src/utils/recorder.py:
# --------------------------------------------------------------------------------
import pickle
import logging
import datetime
import torch
import os
import time as t
import numpy as np


class Recorder:
    """Track losses, timings and metrics of a run; handle logging, checkpoints and early stopping."""

    def __init__(self, config):
        """Read the relevant settings from ``config`` and set up logging and the checkpoint path."""
        self.epoch = 0
        self.model_name = config.model
        self.dataset = config.dataset
        self.run_mark = config.mark
        self.log_path = config.log_save
        # metric
        self.metrics = config.metric
        self.k_list = config.k
        # records
        self.batch_loss_rec = 0.
        self.metric_records = {}
        self.time_record = {'train': 0., 'eval': 0.}
        self.decimal_round = 4
        self.mark = config.mark
        self.model_saved = config.model_saved

        # early stop
        self.early_stop = False
        self.core_metric = config.valid_metric
        self.patience = int(config.patience)
        self.best_metric_rec = {'epoch': 0, 'score': 0.}
        self.step_2_stop = self.patience
        # log report
        self.block_size = 6
        self.half_underline = self.block_size * len(self.metrics) * len(self.k_list)

        self._recoder_init(config)

    def reset(self):
        """Restart the epoch counter."""
        self.epoch = 0

    def _recoder_init(self, config):
        """Initialise logging, the metric records and the checkpoint path."""
        self._init_log()
        self._init_record()
        self._model_saving_init(config)

    def _model_saving_init(self, config):
        """Create the save directory and, unless a path was given, derive the checkpoint path."""
        # check saving path (also creates missing parent directories)
        os.makedirs(config.save, exist_ok=True)
        # init model saving path
        timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        if self.model_saved is None:
            file_name = f'{config.model}-{self.dataset}-{self.mark}-{timestamp}.pth'
            # os.path.join instead of a hard-coded '\\', which on POSIX systems
            # produced a file literally named 'save\<model>-...pth'.
            self.model_saved = os.path.join(config.save, file_name)
        logging.info(f'model save at: {self.model_saved}')

    def _init_log(self):
        """Configure logging to a not yet existing ``<model>[-<mark>]_<n>.log`` file and to the console."""
        save_path = os.path.join(self.log_path, self.dataset)
        os.makedirs(save_path, exist_ok=True)
        log_model_name = self.model_name + f'-{self.run_mark}' if len(self.run_mark) > 0 else self.model_name
        # Count upwards until a free name is found. The original stopped after
        # 100 candidates and then overwrote an existing log (filemode='w').
        times = 1
        log_file = os.path.join(save_path, '%s_%d.log' % (log_model_name, times))
        while os.path.isfile(log_file):
            times += 1
            log_file = os.path.join(save_path, '%s_%d.log' % (log_model_name, times))

        logging.basicConfig(
            format='%(asctime)s %(levelname)-8s %(message)s',
            level=logging.INFO,
            datefmt='%Y-%m-%d %H:%M:%S',
            filename=log_file,
            filemode='w'
        )
        console = logging.StreamHandler()
        console.setLevel(logging.INFO)
        formatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s')
        console.setFormatter(formatter)
        logging.getLogger('').addHandler(console)
        logging.info('log save at : {}'.format(log_file))

    def _init_record(self):
        """Create an empty history for every ``metric@k`` and check the early-stop metric."""
        for metric in self.metrics:
            for k in self.k_list:
                self.metric_records[f'{metric}@{k}'] = []
        assert self.core_metric in self.metric_records.keys(), f'Invalid valid_metric: [{self.core_metric}], ' \
                                                               f'choose from: {self.metric_records.keys()} !'

    def save_model(self, model):
        """Save the parameters (state dict) of ``model`` to the checkpoint path."""
        # save entire model
        # torch.save(model, self.model_saving)

        # only save model parameters
        torch.save(model.state_dict(), self.model_saved)

    def load_best_model(self, model):
        """Load the saved parameters into ``model`` in place."""
        # load entire model
        # return torch.load(self.model_saving)

        # load parameters; mapping to CPU first lets a checkpoint written on a
        # GPU be restored on any machine, load_state_dict copies the values
        # into the model's own parameters afterwards.
        model.load_state_dict(torch.load(self.model_saved, map_location='cpu'))

    def epoch_restart(self):
        """Clear the accumulated loss and advance the epoch counter."""
        self.batch_loss_rec = 0.
        self.epoch += 1

    def save_batch_loss(self, batch_loss):
        """Add ``batch_loss`` to the loss accumulated for the current epoch."""
        self.batch_loss_rec += batch_loss

    def tik_start(self):
        """Remember the current time as the start of a timed section."""
        self._clock = t.time()

    def tik_end(self, mode='train'):
        """Store the seconds elapsed since tik_start() under ``mode``."""
        end_clock = t.time()
        self.time_record[mode] = end_clock - self._clock

    def _save_best_result(self, metric_res, model):
        """
        Append the scores of this epoch to the histories and run the early-stop check.

        :param metric_res: dict
        """
        for metric, score in metric_res.items():
            self.metric_records.get(metric).append(score)
        # early stop
        core_metric_res = metric_res.get(self.core_metric)
        self.early_stop_check(core_metric_res, model)

    def early_stop_check(self, core_metric_res, model):
        """Save ``model`` on a new best score; otherwise count down towards early stopping."""
        if core_metric_res > self.best_metric_rec.get('score'):
            self.best_metric_rec['score'] = core_metric_res
            self.best_metric_rec['epoch'] = self.epoch
            self.step_2_stop = self.patience
            # find a better model -> save
            self.save_model(model)
        else:
            self.step_2_stop -= 1
            logging.info(f'EarlyStopping Counter: {self.patience - self.step_2_stop} out of {self.patience}')
            # `<=` so that a patience of 0 stops on the first epoch without
            # improvement instead of counting below zero forever
            if self.step_2_stop <= 0:
                self.early_stop = True

    def train_log_verbose(self, num_batch):
        """Log the epoch header, the training time and the mean batch loss."""
        training_loss = self.batch_loss_rec / num_batch
        logging.info('-' * self.half_underline + f'----Epoch {self.epoch}----' + '-' * self.half_underline)
        output_str = " Training Time :[%.1f s]\tTraining Loss = %.4f" % (self.time_record['train'], training_loss)
        logging.info(output_str)

    def log_verbose_and_save(self, metric_score, eval_loss, model):
        """Log the evaluation time, loss and metrics, then record them and check early stopping."""
        res_str = ''
        for metric, score in metric_score.items():
            score = round(score, self.decimal_round)
            res_str += f'{metric}:{score:1.4f}\t'

        eval_time = round(self.time_record['eval'], 1)
        if eval_loss <= 0:
            eval_loss = '**'  # placeholder shown for a non-positive evaluation loss
        else:
            eval_loss = round(eval_loss, 4)
        logging.info(f"Evaluation Time:[{eval_time} s]\t Eval Loss = {eval_loss}")
        logging.info(res_str)

        # save results and model
        self._save_best_result(metric_score, model)

    def report_best_res(self):
        """Log the metrics recorded at the best epoch."""
        best_epoch = self.best_metric_rec['epoch']
        logging.info('-' * self.half_underline + 'Best Evaluation' + '-' * self.half_underline)
        logging.info(f"Best Result at Epoch: {best_epoch}\t Early Stop at Patience: {self.patience}")
        # load best results
        best_metrics_res = {}
        for metric, metric_res_list in self.metric_records.items():
            if not metric_res_list:
                # no epoch was evaluated; indexing would raise IndexError
                continue
            best_metrics_res[metric] = metric_res_list[best_epoch - 1]
        res_str = ''
        for metric, score in best_metrics_res.items():
            score = round(score, self.decimal_round)
            res_str += f'{metric}:{score:1.4f}\t'
        logging.info(res_str)

    def report_test_result(self, test_metric_res):
        """Log the test metrics."""
        res_str = ''
        for metric, score in test_metric_res.items():
            score = round(score, self.decimal_round)
            res_str += f'{metric}:{score:1.4f}\t'
        logging.info('-' * self.half_underline + f'-----Test Results------' + '-' * self.half_underline)
        logging.info(res_str)


# --------------------------------------------------------------------------------
# /src/utils/utils.py:
# --------------------------------------------------------------------------------
import pickle
import logging
import datetime
import random
from typing import List

# import faiss
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import time as t
import numpy as np
from easydict import EasyDict


def set_seed(seed):
    """Seed Python, NumPy and PyTorch; a seed of -1 leaves everything unseeded."""
    if seed == -1:
        return
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # some cudnn methods can be random even after fixing the seed
    # unless you tell it to be deterministic
    torch.backends.cudnn.deterministic = True


def batch_to_device(tensor_dict: dict, dev):
    """Move every tensor value of ``tensor_dict`` to ``dev`` in place; other values are untouched."""
    for key, obj in tensor_dict.items():
        if torch.is_tensor(obj):
            tensor_dict[key] = obj.to(dev)


def load_pickle(file_path):
    """Return the object unpickled from ``file_path``.

    NOTE: unpickling can execute arbitrary code; only use trusted files.
    """
    with open(file_path, 'rb') as fr:
        return pickle.load(fr)


def save_pickle(obj, file_path):
    """Pickle ``obj`` to ``file_path``."""
    with open(file_path, 'wb') as f:
        pickle.dump(obj, f)


def pkl_to_txt(dataset='beauty'):
    """Convert ``<dataset>_seq.pkl`` into a text file with one space-separated sequence per line."""
    data_dir = '../../dataset'
    file_path = os.path.join(data_dir, f'{dataset}/{dataset}_seq.pkl')
    record = load_pickle(file_path)

    # write to xxx.txt
    target_file = os.path.join(data_dir, f'{dataset}/{dataset}_seq.txt')
    with open(target_file, 'w') as fw:
        for seq in record:
            fw.write(' '.join(map(str, seq)) + '\n')


def freeze(layer):
    """Disable gradients for every parameter of ``layer``.

    ``parameters()`` covers the module's own parameters as well as those of
    all sub-modules; walking ``children()`` only (as before) silently skipped
    parameters held directly by ``layer``, e.g. for a bare ``nn.Linear``.
    """
    for param in layer.parameters():
        param.requires_grad = False


def neg_sample(item_set, item_size):
    """Draw a random item id from [1, item_size - 1] (both ends inclusive) that is not in ``item_set``.

    NOTE: loops forever if ``item_set`` already contains every id of that range.
    """
    item = random.randint(1, item_size - 1)
    while item in item_set:
        item = random.randint(1, item_size - 1)
    return item


def get_activate(act='relu'):
    """Return a new activation module for ``act``.

    :raises KeyError: if ``act`` is not relu, leaky_relu, gelu, tanh or sigmoid.
    """
    if act == 'relu':
        return nn.ReLU()
    elif act == 'leaky_relu':
        return nn.LeakyReLU()
    elif act == 'gelu':
        return nn.GELU()
    elif act == 'tanh':
        return nn.Tanh()
    elif act == 'sigmoid':
        return nn.Sigmoid()
    else:
        raise KeyError(f'Not support current activate function: {act}, please add by yourself.')


class HyperParamDict(EasyDict):
    """EasyDict with an argparse-like ``add_argument`` that stores default values.

    ``keys()``, ``values()`` and ``items()`` only cover parameters registered
    through ``add_argument``.
    """

    def __init__(self, description=None):
        super(HyperParamDict, self).__init__({})
        self.description = description
        self.attr_registered = []

    def add_argument(self, param_name, type=object, default=None, action=None, choices=None, help=None):
        """Register ``param_name`` (leading dashes are stripped) with its default value.

        A truthy ``default`` is converted with ``type``; ``choices`` must be a
        list or tuple containing the default; an ``action`` of 'store_true' /
        'store_false' replaces the default by False / True.
        """
        param_name = self._parse_param_name(param_name)
        if default and type:
            try:
                default = type(default)
            except Exception:
                # e.g. type=object cannot be called with an argument
                assert isinstance(default, type), f'KeyError. Type of param {param_name} should be {type}.'
        if choices:
            assert isinstance(choices, (list, tuple)), f'choices should be a list.'
            assert default in choices, f'KeyError. Please choose {param_name} from {choices}. ' \
                                       f'Now {param_name} = {default}.'
        if action:
            default = self._parse_action(action)
        if help:
            assert isinstance(help, str), f'help should be a str.'
        # keep the registration list free of duplicates on re-registration
        if param_name not in self.attr_registered:
            self.attr_registered.append(param_name)
        setattr(self, param_name, default)

    @staticmethod
    def _parse_param_name(param_name: str):
        """Turn an option string such as '--max_len' into the attribute name 'max_len'.

        Inner hyphens become underscores (as argparse does); cutting at the
        last '-' (as before) reduced '--max-len' to just 'len'.
        """
        return param_name.lstrip('-').replace('-', '_')

    @staticmethod
    def _parse_action(action):
        """Return the default implied by the action: False for 'store_true', True for 'store_false'."""
        action_infos = action.split('_')
        assert action_infos[0] == 'store' and action_infos[-1] in ['false', 'true'], \
            f"Wrong action format: {action}. Please choose from ['store_false', 'store_true']."
        return action_infos[-1] != 'true'

    def keys(self):
        return self.attr_registered

    def values(self):
        return [self.get(key) for key in self.attr_registered]

    def items(self):
        return [(key, self.get(key)) for key in self.attr_registered]

    def __str__(self):
        param_list = [f'({key}: {value})' for key, value in self.items()]
        return 'HyperParamDict{' + ', '.join(param_list) + '}'


def _load_faiss():
    """Import faiss on demand so that this module stays importable without it."""
    try:
        import faiss
    except ImportError as err:
        raise ImportError('KMeans requires the `faiss` package with GPU support, '
                          'which is not installed.') from err
    return faiss


class KMeans(object):
    """K-means clustering backed by faiss with a flat L2 GPU index."""

    def __init__(self, num_cluster, seed, hidden_size, gpu_id=0, device="cpu"):
        """
        Args:
            num_cluster: number of clusters
            seed: seed handed to the faiss clustering
            hidden_size: dimension of the clustered vectors
            gpu_id: GPU used by the faiss index
            device: torch device the centroids are moved to
        """
        self.seed = seed
        self.num_cluster = num_cluster
        self.max_points_per_centroid = 4096
        self.min_points_per_centroid = 0
        self.gpu_id = gpu_id  # was hard-coded to 0, ignoring the argument
        self.device = device
        self.first_batch = True
        self.hidden_size = hidden_size
        self.clus, self.index = self.__init_cluster(self.hidden_size)
        self.centroids = []  # cluster centroids

    def __init_cluster(
            self, hidden_size, verbose=False, niter=20, nredo=5, max_points_per_centroid=4096,
            min_points_per_centroid=0
    ):
        """Create the faiss clustering object and the GPU index it searches."""
        faiss = _load_faiss()
        logging.info(f" cluster train iterations: {niter}")
        clus = faiss.Clustering(hidden_size, self.num_cluster)
        clus.verbose = verbose
        clus.niter = niter
        clus.nredo = nredo
        clus.seed = self.seed
        clus.max_points_per_centroid = max_points_per_centroid
        clus.min_points_per_centroid = min_points_per_centroid

        res = faiss.StandardGpuResources()
        res.noTempMemory()
        cfg = faiss.GpuIndexFlatConfig()
        cfg.useFloat16 = False
        cfg.device = self.gpu_id
        index = faiss.GpuIndexFlatL2(res, hidden_size, cfg)
        return clus, index

    def train(self, x):
        """Fit the clustering on ``x`` and store the L2-normalised centroids on ``self.device``.

        Clustering is only run when ``x`` has more rows than clusters.
        """
        faiss = _load_faiss()
        # train to get centroids
        if x.shape[0] > self.num_cluster:
            self.clus.train(x, self.index)
        # get cluster centroids
        centroids = faiss.vector_to_array(self.clus.centroids).reshape(self.num_cluster, self.hidden_size)
        # convert to cuda Tensors for broadcast
        centroids = torch.Tensor(centroids).to(self.device)
        self.centroids = nn.functional.normalize(centroids, p=2, dim=1)

    def query(self, x):
        """
        Args
            x: batch intent representations of shape [B, D]
        Returns
            seq2cluster: assigned centroid id for each intent of shape [B]
            centroids_assignments: assigned centroid representation for each intent of shape [B, D]
        """
        # self.index.add(x)
        # D : cluster distances of shape [B, 1] I: cluster assignments of shape [B, 1]
        D, I = self.index.search(x, 1)  # for each sample, find cluster distance and assignments
        seq2clusterID = [int(n[0]) for n in I]
        # print("cluster number:", self.num_cluster,"cluster in batch:", len(set(seq2cluster)))
        seq2clusterID = torch.LongTensor(seq2clusterID).to(self.device)
        centroids_assignments = self.centroids[seq2clusterID]
        return seq2clusterID, centroids_assignments


def get_gpu_usage(device=None):
    r""" Return the reserved memory and total memory of given device in a string.
    Args:
        device: cuda.device. It is the device that the model run on.

    Returns:
        str: it contains the info about reserved memory and total memory of given device.
    """

    reserved = torch.cuda.max_memory_reserved(device) / 1024 ** 3
    total = torch.cuda.get_device_properties(device).total_memory / 1024 ** 3

    return '{:.2f} G/{:.2f} G'.format(reserved, total)


if __name__ == '__main__':
    datasets = ['ml-10m']
    for dataset in datasets:
        pkl_to_txt(dataset)

# --------------------------------------------------------------------------------
# /toys.sh:
# --------------------------------------------------------------------------------
# python runIOCRec.py --dataset toys --eval_mode uni100 --embed_size 64 --k_intention 4 --seed 2023 --mark seed2023
# python runIOCRec.py --dataset toys --eval_mode uni100 --embed_size 64 --ffn_hidden 256 --k_intention 4 --seed 2024 --mark seed2024-hidden256
# python runIOCRec.py --dataset toys --eval_mode uni100 --embed_size 64 --ffn_hidden 128 --k_intention 4 --seed 2024 --mark seed2024-hidden128
# --------------------------------------------------------------------------------