├── .idea
│   ├── .gitignore
│   ├── IOCRec.iml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── README.md
├── dataset
│   └── toys
│       └── toys.seq
├── log
│   └── toys
│       ├── IOCRec-seed-2024_1.log
│       ├── IOCRec-seed2023_1.log
│       ├── IOCRec-seed2024-hidden128_1.log
│       ├── IOCRec-seed2024-hidden256_1.log
│       └── IOCRec_1.log
├── runIOCRec.py
├── src
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   └── __init__.cpython-39.pyc
│   ├── dataset
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── data_processor.cpython-38.pyc
│   │   │   └── dataset.cpython-38.pyc
│   │   ├── data_processor.py
│   │   └── dataset.py
│   ├── evaluation
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── estimator.cpython-38.pyc
│   │   │   └── metrics.cpython-38.pyc
│   │   ├── estimator.py
│   │   └── metrics.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── abstract_recommeder.cpython-38.pyc
│   │   │   ├── data_augmentation.cpython-38.pyc
│   │   │   ├── loss.cpython-38.pyc
│   │   │   └── sequential_encoder.cpython-38.pyc
│   │   ├── abstract_recommeder.py
│   │   ├── cl_based_seq_recommender
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── cl4srec.cpython-38.pyc
│   │   │   │   └── iocrec.cpython-38.pyc
│   │   │   ├── cl4srec.py
│   │   │   └── iocrec.py
│   │   ├── data_augmentation.py
│   │   ├── loss.py
│   │   ├── sequential_encoder.py
│   │   └── sequential_recommender
│   │       ├── __init__.py
│   │       ├── __pycache__
│   │       │   ├── __init__.cpython-38.pyc
│   │       │   └── sasrec.cpython-38.pyc
│   │       └── sasrec.py
│   ├── train
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── config.cpython-38.pyc
│   │   │   └── trainer.cpython-38.pyc
│   │   ├── config.py
│   │   └── trainer.py
│   └── utils
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-38.pyc
│       │   ├── recorder.cpython-38.pyc
│       │   └── utils.cpython-38.pyc
│       ├── recorder.py
│       └── utils.py
└── toys.sh
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/.idea/IOCRec.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 | <component name="InspectionProjectProfileManager">
2 |   <settings>
3 |     <option name="USE_PROJECT_PROFILE" value="false" />
4 |     <version value="1.0" />
5 |   </settings>
6 | </component>
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <project version="4">
3 |   <component name="ProjectModuleManager">
4 |     <modules>
5 |       <module fileurl="file://$PROJECT_DIR$/.idea/IOCRec.iml" filepath="$PROJECT_DIR$/.idea/IOCRec.iml" />
6 |     </modules>
7 |   </component>
8 | </project>
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <project version="4">
3 |   <component name="VcsDirectoryMappings">
4 |     <mapping directory="$PROJECT_DIR$" vcs="Git" />
5 |   </component>
6 | </project>
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # IOCRec
2 | PyTorch implementation of the paper: Multi-Intention Oriented Contrastive Learning for Sequential Recommendation (WSDM '23).
3 |
4 | We implement IOCRec in PyTorch and obtain results on Toys that closely match those reported in the paper under the same experimental setting. The default hyper-parameters are set to the optimal values for Toys reported in the paper, and the training log is provided for reproduction.
5 |
6 | ```
7 | 2023-07-07 15:48:05 INFO ------------------------------------------------Best Evaluation------------------------------------------------
8 | 2023-07-07 15:48:05 INFO Best Result at Epoch: 33 Early Stop at Patience: 10
9 | 2023-07-07 15:48:05 INFO hit@5:0.4513 hit@10:0.5453 hit@20:0.6621 hit@50:0.7935 ndcg@5:0.3588 ndcg@10:0.3891 ndcg@20:0.4186 ndcg@50:0.4455
10 | 2023-07-07 15:48:07 INFO -----------------------------------------------------Test Results------------------------------------------------------
11 | 2023-07-07 15:48:07 INFO hit@5:0.4022 hit@10:0.5005 hit@20:0.6205 hit@50:0.7594 ndcg@5:0.3145 ndcg@10:0.3462 ndcg@20:0.3765 ndcg@50:0.4048
12 | ```
13 | ## Datasets
14 | We provide the Toys dataset.
15 |
16 | ## Quick Start
17 | You can run the model with the following command (`--eval_mode uni100` ranks the ground-truth item against 100 uniformly sampled negative items):
18 | ```
19 | python runIOCRec.py --dataset toys --eval_mode uni100 --embed_size 64 --k_intention 4
20 | ```
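21 | 
22 | The other provided logs vary a few hyper-parameters. For example, `log/toys/IOCRec-seed2024-hidden256_1.log` appears to correspond to a command like the following (inferred from the hyper-parameters recorded in that log, not documented in the original README):
23 | ```
24 | python runIOCRec.py --dataset toys --eval_mode uni100 --embed_size 64 --k_intention 4 --ffn_hidden 256 --mark seed2024-hidden256
25 | ```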
--------------------------------------------------------------------------------
/log/toys/IOCRec-seed2023_1.log:
--------------------------------------------------------------------------------
1 | 2024-07-23 15:17:34 INFO log save at : log\toys\IOCRec-seed2023_1.log
2 | 2024-07-23 15:17:34 INFO model save at: save\IOCRec-toys-seed2023-2024-07-23_15-17-34.pth
3 | 2024-07-23 15:17:35 INFO [1] Model Hyper-Parameter ---------------------
4 | 2024-07-23 15:17:35 INFO model: IOCRec
5 | 2024-07-23 15:17:35 INFO model_type: SEQUENTIAL
6 | 2024-07-23 15:17:35 INFO aug_types: ['crop', 'mask', 'reorder']
7 | 2024-07-23 15:17:35 INFO crop_ratio: 0.2
8 | 2024-07-23 15:17:35 INFO mask_ratio: 0.7
9 | 2024-07-23 15:17:35 INFO reorder_ratio: 0.2
10 | 2024-07-23 15:17:35 INFO all_hidden: True
11 | 2024-07-23 15:17:35 INFO tao: 1.0
12 | 2024-07-23 15:17:35 INFO lamda: 0.1
13 | 2024-07-23 15:17:35 INFO k_intention: 4
14 | 2024-07-23 15:17:35 INFO embed_size: 64
15 | 2024-07-23 15:17:35 INFO ffn_hidden: 512
16 | 2024-07-23 15:17:35 INFO num_blocks: 3
17 | 2024-07-23 15:17:35 INFO num_heads: 2
18 | 2024-07-23 15:17:35 INFO hidden_dropout: 0.5
19 | 2024-07-23 15:17:35 INFO attn_dropout: 0.5
20 | 2024-07-23 15:17:35 INFO layer_norm_eps: 1e-12
21 | 2024-07-23 15:17:35 INFO initializer_range: 0.02
22 | 2024-07-23 15:17:35 INFO loss_type: CE
23 | 2024-07-23 15:17:35 INFO [2] Experiment Hyper-Parameter ----------------
24 | 2024-07-23 15:17:35 INFO [2-1] data hyper-parameter --------------------
25 | 2024-07-23 15:17:35 INFO dataset: toys
26 | 2024-07-23 15:17:35 INFO data_aug: True
27 | 2024-07-23 15:17:35 INFO seq_filter_len: 0
28 | 2024-07-23 15:17:35 INFO if_filter_target: False
29 | 2024-07-23 15:17:35 INFO max_len: 50
30 | 2024-07-23 15:17:35 INFO [2-2] pretraining hyper-parameter -------------
31 | 2024-07-23 15:17:35 INFO do_pretraining: False
32 | 2024-07-23 15:17:35 INFO pretraining_task: MISP
33 | 2024-07-23 15:17:35 INFO pretraining_epoch: 10
34 | 2024-07-23 15:17:35 INFO pretraining_batch: 512
35 | 2024-07-23 15:17:35 INFO pretraining_lr: 0.001
36 | 2024-07-23 15:17:35 INFO pretraining_l2: 0.0
37 | 2024-07-23 15:17:35 INFO [2-3] training hyper-parameter ----------------
38 | 2024-07-23 15:17:35 INFO epoch_num: 150
39 | 2024-07-23 15:17:35 INFO train_batch: 256
40 | 2024-07-23 15:17:35 INFO learning_rate: 0.001
41 | 2024-07-23 15:17:35 INFO l2: 0
42 | 2024-07-23 15:17:35 INFO patience: 10
43 | 2024-07-23 15:17:35 INFO device: cuda:0
44 | 2024-07-23 15:17:35 INFO num_worker: 0
45 | 2024-07-23 15:17:35 INFO seed: 2023
46 | 2024-07-23 15:17:35 INFO [2-4] evaluation hyper-parameter --------------
47 | 2024-07-23 15:17:35 INFO split_type: valid_and_test
48 | 2024-07-23 15:17:35 INFO split_mode: LS
49 | 2024-07-23 15:17:35 INFO eval_mode: uni100
50 | 2024-07-23 15:17:35 INFO metric: ['hit', 'ndcg']
51 | 2024-07-23 15:17:35 INFO k: [5, 10, 20, 50]
52 | 2024-07-23 15:17:35 INFO valid_metric: hit@10
53 | 2024-07-23 15:17:35 INFO eval_batch: 256
54 | 2024-07-23 15:17:35 INFO [2-5] save hyper-parameter --------------------
55 | 2024-07-23 15:17:35 INFO log_save: log
56 | 2024-07-23 15:17:35 INFO save: save
57 | 2024-07-23 15:17:35 INFO model_saved: None
58 | 2024-07-23 15:17:35 INFO [3] Data Statistic ----------------------------
59 | 2024-07-23 15:17:35 INFO dataset: toys
60 | 2024-07-23 15:17:35 INFO user number: 19412
61 | 2024-07-23 15:17:35 INFO item number: 11925
62 | 2024-07-23 15:17:35 INFO average seq length: 8.6337
63 | 2024-07-23 15:17:35 INFO density: 0.0007 sparsity: 0.9993
64 | 2024-07-23 15:17:35 INFO data after augmentation:
65 | 2024-07-23 15:17:35 INFO train samples: 109361 eval samples: 19412 test samples: 19412
66 | 2024-07-23 15:17:35 INFO [1] Model Architecture ------------------------
67 | 2024-07-23 15:17:35 INFO total parameters: 1035520
68 | 2024-07-23 15:17:35 INFO IOCRec(
69 |   (cross_entropy): CrossEntropyLoss()
70 |   (item_embedding): Embedding(11927, 64, padding_idx=0)
71 |   (position_embedding): Embedding(50, 64)
72 |   (input_layer_norm): LayerNorm((64,), eps=1e-12, elementwise_affine=True)
73 |   (input_dropout): Dropout(p=0.5, inplace=False)
74 |   (local_encoder): Transformer(
75 |     (encoder_layers): ModuleList(
76 |       (0-2): 3 x EncoderLayer(
77 |         (attn_layer_norm): LayerNorm((64,), eps=1e-12, elementwise_affine=True)
78 |         (pff_layer_norm): LayerNorm((64,), eps=1e-12, elementwise_affine=True)
79 |         (self_attention): MultiHeadAttentionLayer(
80 |           (fc_q): Linear(in_features=64, out_features=64, bias=True)
81 |           (fc_k): Linear(in_features=64, out_features=64, bias=True)
82 |           (fc_v): Linear(in_features=64, out_features=64, bias=True)
83 |           (attn_dropout): Dropout(p=0.5, inplace=False)
84 |           (fc_o): Linear(in_features=64, out_features=64, bias=True)
85 |         )
86 |         (pff): PointWiseFeedForwardLayer(
87 |           (fc1): Linear(in_features=64, out_features=512, bias=True)
88 |           (fc2): Linear(in_features=512, out_features=64, bias=True)
89 |         )
90 |         (hidden_dropout): Dropout(p=0.5, inplace=False)
91 |         (pff_out_drop): Dropout(p=0.5, inplace=False)
92 |       )
93 |     )
94 |   )
95 |   (global_seq_encoder): GlobalSeqEncoder(
96 |     (dropout): Dropout(p=0.5, inplace=False)
97 |     (K_linear): Linear(in_features=64, out_features=64, bias=True)
98 |     (V_linear): Linear(in_features=64, out_features=64, bias=True)
99 |   )
100 |   (disentangle_encoder): DisentangleEncoder(
101 |     (pos_fai): Embedding(50, 64)
102 |     (W): Linear(in_features=64, out_features=64, bias=True)
103 |     (layer_norm_1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
104 |     (layer_norm_2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
105 |     (layer_norm_3): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
106 |     (layer_norm_4): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
107 |     (layer_norm_5): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
108 |   )
109 |   (nce_loss): InfoNCELoss(
110 |     (criterion): CrossEntropyLoss()
111 |   )
112 | )
113 | 2024-07-23 15:17:35 INFO Start training...
114 | 2024-07-23 15:18:09 INFO ----------------------------------------------------Epoch 1----------------------------------------------------
115 | 2024-07-23 15:18:09 INFO Training Time :[33.3 s] Training Loss = 11.0718
116 | 2024-07-23 15:18:10 INFO Evaluation Time:[1.0 s] Eval Loss = 9.3946
117 | 2024-07-23 15:18:10 INFO hit@5:0.0718 hit@10:0.1202 hit@20:0.2062 hit@50:0.3898 ndcg@5:0.0443 ndcg@10:0.0598 ndcg@20:0.0813 ndcg@50:0.1180
118 | 2024-07-23 15:18:42 INFO ----------------------------------------------------Epoch 2----------------------------------------------------
119 | 2024-07-23 15:18:42 INFO Training Time :[32.8 s] Training Loss = 9.9107
120 | 2024-07-23 15:18:43 INFO Evaluation Time:[0.9 s] Eval Loss = 9.3376
121 | 2024-07-23 15:18:43 INFO hit@5:0.1525 hit@10:0.2575 hit@20:0.4075 hit@50:0.4748 ndcg@5:0.0966 ndcg@10:0.1302 ndcg@20:0.1681 ndcg@50:0.1825
122 | 2024-07-23 15:19:16 INFO ----------------------------------------------------Epoch 3----------------------------------------------------
123 | 2024-07-23 15:19:16 INFO Training Time :[32.8 s] Training Loss = 9.5619
124 | 2024-07-23 15:19:17 INFO Evaluation Time:[0.9 s] Eval Loss = 9.1699
125 | 2024-07-23 15:19:17 INFO hit@5:0.2255 hit@10:0.3208 hit@20:0.4547 hit@50:0.5217 ndcg@5:0.1567 ndcg@10:0.1874 ndcg@20:0.2211 ndcg@50:0.2354
126 | 2024-07-23 15:19:50 INFO ----------------------------------------------------Epoch 4----------------------------------------------------
127 | 2024-07-23 15:19:50 INFO Training Time :[32.8 s] Training Loss = 9.2151
128 | 2024-07-23 15:19:51 INFO Evaluation Time:[0.9 s] Eval Loss = 9.0117
129 | 2024-07-23 15:19:51 INFO hit@5:0.2879 hit@10:0.3840 hit@20:0.5012 hit@50:0.5750 ndcg@5:0.2079 ndcg@10:0.2389 ndcg@20:0.2683 ndcg@50:0.2839
130 | 2024-07-23 15:20:24 INFO ----------------------------------------------------Epoch 5----------------------------------------------------
131 | 2024-07-23 15:20:24 INFO Training Time :[32.9 s] Training Loss = 8.9182
132 | 2024-07-23 15:20:25 INFO Evaluation Time:[1.0 s] Eval Loss = 8.8102
133 | 2024-07-23 15:20:25 INFO hit@5:0.3426 hit@10:0.4392 hit@20:0.5495 hit@50:0.6274 ndcg@5:0.2529 ndcg@10:0.2841 ndcg@20:0.3119 ndcg@50:0.3282
134 | 2024-07-23 15:20:58 INFO ----------------------------------------------------Epoch 6----------------------------------------------------
135 | 2024-07-23 15:20:58 INFO Training Time :[33.1 s] Training Loss = 8.6244
136 | 2024-07-23 15:20:59 INFO Evaluation Time:[0.9 s] Eval Loss = 8.581
137 | 2024-07-23 15:20:59 INFO hit@5:0.3856 hit@10:0.4870 hit@20:0.5982 hit@50:0.6972 ndcg@5:0.2900 ndcg@10:0.3227 ndcg@20:0.3508 ndcg@50:0.3712
138 | 2024-07-23 15:21:32 INFO ----------------------------------------------------Epoch 7----------------------------------------------------
139 | 2024-07-23 15:21:32 INFO Training Time :[33.1 s] Training Loss = 8.3363
140 | 2024-07-23 15:21:33 INFO Evaluation Time:[0.9 s] Eval Loss = 8.4352
141 | 2024-07-23 15:21:33 INFO hit@5:0.4089 hit@10:0.5100 hit@20:0.6219 hit@50:0.7328 ndcg@5:0.3132 ndcg@10:0.3459 ndcg@20:0.3741 ndcg@50:0.3968
142 | 2024-07-23 15:22:06 INFO ----------------------------------------------------Epoch 8----------------------------------------------------
143 | 2024-07-23 15:22:06 INFO Training Time :[32.9 s] Training Loss = 8.0958
144 | 2024-07-23 15:22:07 INFO Evaluation Time:[0.9 s] Eval Loss = 8.3217
145 | 2024-07-23 15:22:07 INFO hit@5:0.4293 hit@10:0.5254 hit@20:0.6391 hit@50:0.7704 ndcg@5:0.3325 ndcg@10:0.3635 ndcg@20:0.3922 ndcg@50:0.4189
146 | 2024-07-23 15:22:40 INFO ----------------------------------------------------Epoch 9----------------------------------------------------
147 | 2024-07-23 15:22:40 INFO Training Time :[33.0 s] Training Loss = 7.8828
148 | 2024-07-23 15:22:41 INFO Evaluation Time:[0.9 s] Eval Loss = 8.3024
149 | 2024-07-23 15:22:41 INFO hit@5:0.4324 hit@10:0.5290 hit@20:0.6475 hit@50:0.7881 ndcg@5:0.3393 ndcg@10:0.3705 ndcg@20:0.4003 ndcg@50:0.4288
150 | 2024-07-23 15:23:13 INFO ----------------------------------------------------Epoch 10----------------------------------------------------
151 | 2024-07-23 15:23:13 INFO Training Time :[32.9 s] Training Loss = 7.6411
152 | 2024-07-23 15:23:14 INFO Evaluation Time:[0.9 s] Eval Loss = 8.3549
153 | 2024-07-23 15:23:14 INFO hit@5:0.4394 hit@10:0.5328 hit@20:0.6481 hit@50:0.8010 ndcg@5:0.3455 ndcg@10:0.3756 ndcg@20:0.4047 ndcg@50:0.4356
154 | 2024-07-23 15:23:47 INFO ----------------------------------------------------Epoch 11----------------------------------------------------
155 | 2024-07-23 15:23:47 INFO Training Time :[32.8 s] Training Loss = 7.4921
156 | 2024-07-23 15:23:48 INFO Evaluation Time:[0.9 s] Eval Loss = 8.4414
157 | 2024-07-23 15:23:48 INFO hit@5:0.4396 hit@10:0.5349 hit@20:0.6530 hit@50:0.8046 ndcg@5:0.3475 ndcg@10:0.3783 ndcg@20:0.4081 ndcg@50:0.4386
158 | 2024-07-23 15:24:21 INFO ----------------------------------------------------Epoch 12----------------------------------------------------
159 | 2024-07-23 15:24:21 INFO Training Time :[32.9 s] Training Loss = 7.3233
160 | 2024-07-23 15:24:22 INFO Evaluation Time:[0.9 s] Eval Loss = 8.4664
161 | 2024-07-23 15:24:22 INFO hit@5:0.4418 hit@10:0.5354 hit@20:0.6513 hit@50:0.7995 ndcg@5:0.3486 ndcg@10:0.3788 ndcg@20:0.4079 ndcg@50:0.4379
162 | 2024-07-23 15:24:55 INFO ----------------------------------------------------Epoch 13----------------------------------------------------
163 | 2024-07-23 15:24:55 INFO Training Time :[32.9 s] Training Loss = 7.2029
164 | 2024-07-23 15:24:56 INFO Evaluation Time:[0.9 s] Eval Loss = 8.5498
165 | 2024-07-23 15:24:56 INFO hit@5:0.4423 hit@10:0.5357 hit@20:0.6529 hit@50:0.8065 ndcg@5:0.3504 ndcg@10:0.3806 ndcg@20:0.4102 ndcg@50:0.4412
166 | 2024-07-23 15:25:29 INFO ----------------------------------------------------Epoch 14----------------------------------------------------
167 | 2024-07-23 15:25:29 INFO Training Time :[32.9 s] Training Loss = 7.0766
168 | 2024-07-23 15:25:30 INFO Evaluation Time:[0.9 s] Eval Loss = 8.6568
169 | 2024-07-23 15:25:30 INFO hit@5:0.4435 hit@10:0.5385 hit@20:0.6561 hit@50:0.8055 ndcg@5:0.3522 ndcg@10:0.3829 ndcg@20:0.4125 ndcg@50:0.4427
170 | 2024-07-23 15:26:03 INFO ----------------------------------------------------Epoch 15----------------------------------------------------
171 | 2024-07-23 15:26:03 INFO Training Time :[32.9 s] Training Loss = 6.9914
172 | 2024-07-23 15:26:04 INFO Evaluation Time:[0.9 s] Eval Loss = 8.8304
173 | 2024-07-23 15:26:04 INFO hit@5:0.4402 hit@10:0.5366 hit@20:0.6507 hit@50:0.8021 ndcg@5:0.3506 ndcg@10:0.3818 ndcg@20:0.4106 ndcg@50:0.4412
174 | 2024-07-23 15:26:04 INFO EarlyStopping Counter: 1 out of 10
175 | 2024-07-23 15:26:36 INFO ----------------------------------------------------Epoch 16----------------------------------------------------
176 | 2024-07-23 15:26:36 INFO Training Time :[32.9 s] Training Loss = 6.9066
177 | 2024-07-23 15:26:37 INFO Evaluation Time:[0.9 s] Eval Loss = 8.8853
178 | 2024-07-23 15:26:37 INFO hit@5:0.4405 hit@10:0.5362 hit@20:0.6516 hit@50:0.8025 ndcg@5:0.3508 ndcg@10:0.3817 ndcg@20:0.4108 ndcg@50:0.4413
179 | 2024-07-23 15:26:37 INFO EarlyStopping Counter: 2 out of 10
180 | 2024-07-23 15:27:10 INFO ----------------------------------------------------Epoch 17----------------------------------------------------
181 | 2024-07-23 15:27:10 INFO Training Time :[32.9 s] Training Loss = 6.8368
182 | 2024-07-23 15:27:11 INFO Evaluation Time:[0.9 s] Eval Loss = 8.9803
183 | 2024-07-23 15:27:11 INFO hit@5:0.4457 hit@10:0.5386 hit@20:0.6551 hit@50:0.8060 ndcg@5:0.3545 ndcg@10:0.3846 ndcg@20:0.4139 ndcg@50:0.4444
184 | 2024-07-23 15:27:44 INFO ----------------------------------------------------Epoch 18----------------------------------------------------
185 | 2024-07-23 15:27:44 INFO Training Time :[32.8 s] Training Loss = 6.7691
186 | 2024-07-23 15:27:45 INFO Evaluation Time:[0.9 s] Eval Loss = 9.0
187 | 2024-07-23 15:27:45 INFO hit@5:0.4424 hit@10:0.5354 hit@20:0.6513 hit@50:0.8080 ndcg@5:0.3506 ndcg@10:0.3806 ndcg@20:0.4098 ndcg@50:0.4414
188 | 2024-07-23 15:27:45 INFO EarlyStopping Counter: 1 out of 10
189 | 2024-07-23 15:28:18 INFO ----------------------------------------------------Epoch 19----------------------------------------------------
190 | 2024-07-23 15:28:18 INFO Training Time :[32.9 s] Training Loss = 6.7210
191 | 2024-07-23 15:28:19 INFO Evaluation Time:[0.9 s] Eval Loss = 9.1515
192 | 2024-07-23 15:28:19 INFO hit@5:0.4448 hit@10:0.5383 hit@20:0.6550 hit@50:0.8110 ndcg@5:0.3541 ndcg@10:0.3843 ndcg@20:0.4137 ndcg@50:0.4453
193 | 2024-07-23 15:28:19 INFO EarlyStopping Counter: 2 out of 10
194 | 2024-07-23 15:28:52 INFO ----------------------------------------------------Epoch 20----------------------------------------------------
195 | 2024-07-23 15:28:52 INFO Training Time :[32.9 s] Training Loss = 6.6472
196 | 2024-07-23 15:28:52 INFO Evaluation Time:[0.9 s] Eval Loss = 9.3712
197 | 2024-07-23 15:28:52 INFO hit@5:0.4424 hit@10:0.5359 hit@20:0.6491 hit@50:0.8114 ndcg@5:0.3503 ndcg@10:0.3805 ndcg@20:0.4090 ndcg@50:0.4418
198 | 2024-07-23 15:28:52 INFO EarlyStopping Counter: 3 out of 10
199 | 2024-07-23 15:29:25 INFO ----------------------------------------------------Epoch 21----------------------------------------------------
200 | 2024-07-23 15:29:25 INFO Training Time :[32.8 s] Training Loss = 6.6233
201 | 2024-07-23 15:29:26 INFO Evaluation Time:[1.0 s] Eval Loss = 9.2789
202 | 2024-07-23 15:29:26 INFO hit@5:0.4415 hit@10:0.5368 hit@20:0.6493 hit@50:0.8083 ndcg@5:0.3498 ndcg@10:0.3806 ndcg@20:0.4089 ndcg@50:0.4411
203 | 2024-07-23 15:29:26 INFO EarlyStopping Counter: 4 out of 10
204 | 2024-07-23 15:29:59 INFO ----------------------------------------------------Epoch 22----------------------------------------------------
205 | 2024-07-23 15:29:59 INFO Training Time :[32.9 s] Training Loss = 6.5730
206 | 2024-07-23 15:30:00 INFO Evaluation Time:[0.9 s] Eval Loss = 9.4025
207 | 2024-07-23 15:30:00 INFO hit@5:0.4408 hit@10:0.5347 hit@20:0.6493 hit@50:0.8103 ndcg@5:0.3508 ndcg@10:0.3810 ndcg@20:0.4099 ndcg@50:0.4424
208 | 2024-07-23 15:30:00 INFO EarlyStopping Counter: 5 out of 10
209 | 2024-07-23 15:30:33 INFO ----------------------------------------------------Epoch 23----------------------------------------------------
210 | 2024-07-23 15:30:33 INFO Training Time :[33.0 s] Training Loss = 6.5330
211 | 2024-07-23 15:30:34 INFO Evaluation Time:[0.9 s] Eval Loss = 9.5505
212 | 2024-07-23 15:30:34 INFO hit@5:0.4380 hit@10:0.5326 hit@20:0.6483 hit@50:0.8110 ndcg@5:0.3489 ndcg@10:0.3794 ndcg@20:0.4085 ndcg@50:0.4413
213 | 2024-07-23 15:30:34 INFO EarlyStopping Counter: 6 out of 10
214 | 2024-07-23 15:31:07 INFO ----------------------------------------------------Epoch 24----------------------------------------------------
215 | 2024-07-23 15:31:07 INFO Training Time :[32.9 s] Training Loss = 6.4867
216 | 2024-07-23 15:31:08 INFO Evaluation Time:[0.9 s] Eval Loss = 9.2692
217 | 2024-07-23 15:31:08 INFO hit@5:0.4389 hit@10:0.5329 hit@20:0.6490 hit@50:0.8044 ndcg@5:0.3481 ndcg@10:0.3784 ndcg@20:0.4076 ndcg@50:0.4391
218 | 2024-07-23 15:31:08 INFO EarlyStopping Counter: 7 out of 10
219 | 2024-07-23 15:31:41 INFO ----------------------------------------------------Epoch 25----------------------------------------------------
220 | 2024-07-23 15:31:41 INFO Training Time :[32.8 s] Training Loss = 6.4528
221 | 2024-07-23 15:31:42 INFO Evaluation Time:[1.0 s] Eval Loss = 9.3407
222 | 2024-07-23 15:31:42 INFO hit@5:0.4401 hit@10:0.5327 hit@20:0.6511 hit@50:0.8073 ndcg@5:0.3492 ndcg@10:0.3790 ndcg@20:0.4088 ndcg@50:0.4404
223 | 2024-07-23 15:31:42 INFO EarlyStopping Counter: 8 out of 10
224 | 2024-07-23 15:32:14 INFO ----------------------------------------------------Epoch 26----------------------------------------------------
225 | 2024-07-23 15:32:14 INFO Training Time :[32.8 s] Training Loss = 6.4163
226 | 2024-07-23 15:32:15 INFO Evaluation Time:[0.9 s] Eval Loss = 9.3725
227 | 2024-07-23 15:32:15 INFO hit@5:0.4367 hit@10:0.5293 hit@20:0.6491 hit@50:0.8035 ndcg@5:0.3471 ndcg@10:0.3771 ndcg@20:0.4071 ndcg@50:0.4384
228 | 2024-07-23 15:32:15 INFO EarlyStopping Counter: 9 out of 10
229 | 2024-07-23 15:32:48 INFO ----------------------------------------------------Epoch 27----------------------------------------------------
230 | 2024-07-23 15:32:48 INFO Training Time :[32.9 s] Training Loss = 6.3712
231 | 2024-07-23 15:32:49 INFO Evaluation Time:[1.0 s] Eval Loss = 9.3215
232 | 2024-07-23 15:32:49 INFO hit@5:0.4349 hit@10:0.5317 hit@20:0.6480 hit@50:0.8015 ndcg@5:0.3455 ndcg@10:0.3767 ndcg@20:0.4059 ndcg@50:0.4370
233 | 2024-07-23 15:32:49 INFO EarlyStopping Counter: 10 out of 10
234 | 2024-07-23 15:32:49 INFO ------------------------------------------------Best Evaluation------------------------------------------------
235 | 2024-07-23 15:32:49 INFO Best Result at Epoch: 17 Early Stop at Patience: 10
236 | 2024-07-23 15:32:49 INFO hit@5:0.4457 hit@10:0.5386 hit@20:0.6551 hit@50:0.8060 ndcg@5:0.3545 ndcg@10:0.3846 ndcg@20:0.4139 ndcg@50:0.4444
237 | 2024-07-23 15:32:50 INFO -----------------------------------------------------Test Results------------------------------------------------------
238 | 2024-07-23 15:32:50 INFO hit@5:0.4030 hit@10:0.4966 hit@20:0.6120 hit@50:0.7719 ndcg@5:0.3163 ndcg@10:0.3464 ndcg@20:0.3755 ndcg@50:0.4080
239 |
--------------------------------------------------------------------------------
/log/toys/IOCRec-seed2024-hidden256_1.log:
--------------------------------------------------------------------------------
1 | 2024-07-23 15:32:52 INFO log save at : log\toys\IOCRec-seed2024-hidden256_1.log
2 | 2024-07-23 15:32:52 INFO model save at: save\IOCRec-toys-seed2024-hidden256-2024-07-23_15-32-52.pth
3 | 2024-07-23 15:32:53 INFO [1] Model Hyper-Parameter ---------------------
4 | 2024-07-23 15:32:53 INFO model: IOCRec
5 | 2024-07-23 15:32:53 INFO model_type: SEQUENTIAL
6 | 2024-07-23 15:32:53 INFO aug_types: ['crop', 'mask', 'reorder']
7 | 2024-07-23 15:32:53 INFO crop_ratio: 0.2
8 | 2024-07-23 15:32:53 INFO mask_ratio: 0.7
9 | 2024-07-23 15:32:53 INFO reorder_ratio: 0.2
10 | 2024-07-23 15:32:53 INFO all_hidden: True
11 | 2024-07-23 15:32:53 INFO tao: 1.0
12 | 2024-07-23 15:32:53 INFO lamda: 0.1
13 | 2024-07-23 15:32:53 INFO k_intention: 4
14 | 2024-07-23 15:32:53 INFO embed_size: 64
15 | 2024-07-23 15:32:53 INFO ffn_hidden: 256
16 | 2024-07-23 15:32:53 INFO num_blocks: 3
17 | 2024-07-23 15:32:53 INFO num_heads: 2
18 | 2024-07-23 15:32:53 INFO hidden_dropout: 0.5
19 | 2024-07-23 15:32:53 INFO attn_dropout: 0.5
20 | 2024-07-23 15:32:53 INFO layer_norm_eps: 1e-12
21 | 2024-07-23 15:32:53 INFO initializer_range: 0.02
22 | 2024-07-23 15:32:53 INFO loss_type: CE
23 | 2024-07-23 15:32:53 INFO [2] Experiment Hyper-Parameter ----------------
24 | 2024-07-23 15:32:53 INFO [2-1] data hyper-parameter --------------------
25 | 2024-07-23 15:32:53 INFO dataset: toys
26 | 2024-07-23 15:32:53 INFO data_aug: True
27 | 2024-07-23 15:32:53 INFO seq_filter_len: 0
28 | 2024-07-23 15:32:53 INFO if_filter_target: False
29 | 2024-07-23 15:32:53 INFO max_len: 50
30 | 2024-07-23 15:32:53 INFO [2-2] pretraining hyper-parameter -------------
31 | 2024-07-23 15:32:53 INFO do_pretraining: False
32 | 2024-07-23 15:32:53 INFO pretraining_task: MISP
33 | 2024-07-23 15:32:53 INFO pretraining_epoch: 10
34 | 2024-07-23 15:32:53 INFO pretraining_batch: 512
35 | 2024-07-23 15:32:53 INFO pretraining_lr: 0.001
36 | 2024-07-23 15:32:53 INFO pretraining_l2: 0.0
37 | 2024-07-23 15:32:53 INFO [2-3] training hyper-parameter ----------------
38 | 2024-07-23 15:32:53 INFO epoch_num: 150
39 | 2024-07-23 15:32:53 INFO train_batch: 256
40 | 2024-07-23 15:32:53 INFO learning_rate: 0.001
41 | 2024-07-23 15:32:53 INFO l2: 0
42 | 2024-07-23 15:32:53 INFO patience: 10
43 | 2024-07-23 15:32:53 INFO device: cuda:0
44 | 2024-07-23 15:32:53 INFO num_worker: 0
45 | 2024-07-23 15:32:53 INFO seed: 2024
46 | 2024-07-23 15:32:53 INFO [2-4] evaluation hyper-parameter --------------
47 | 2024-07-23 15:32:53 INFO split_type: valid_and_test
48 | 2024-07-23 15:32:53 INFO split_mode: LS
49 | 2024-07-23 15:32:53 INFO eval_mode: uni100
50 | 2024-07-23 15:32:53 INFO metric: ['hit', 'ndcg']
51 | 2024-07-23 15:32:53 INFO k: [5, 10, 20, 50]
52 | 2024-07-23 15:32:53 INFO valid_metric: hit@10
53 | 2024-07-23 15:32:53 INFO eval_batch: 256
54 | 2024-07-23 15:32:53 INFO [2-5] save hyper-parameter --------------------
55 | 2024-07-23 15:32:53 INFO log_save: log
56 | 2024-07-23 15:32:53 INFO save: save
57 | 2024-07-23 15:32:53 INFO model_saved: None
58 | 2024-07-23 15:32:53 INFO [3] Data Statistic ----------------------------
59 | 2024-07-23 15:32:53 INFO dataset: toys
60 | 2024-07-23 15:32:53 INFO user number: 19412
61 | 2024-07-23 15:32:53 INFO item number: 11925
62 | 2024-07-23 15:32:53 INFO average seq length: 8.6337
63 | 2024-07-23 15:32:53 INFO density: 0.0007 sparsity: 0.9993
64 | 2024-07-23 15:32:53 INFO data after augmentation:
65 | 2024-07-23 15:32:53 INFO train samples: 109361 eval samples: 19412 test samples: 19412
66 | 2024-07-23 15:32:53 INFO [1] Model Architecture ------------------------
67 | 2024-07-23 15:32:53 INFO total parameters: 936448
68 | 2024-07-23 15:32:53 INFO IOCRec(
69 |   (cross_entropy): CrossEntropyLoss()
70 |   (item_embedding): Embedding(11927, 64, padding_idx=0)
71 |   (position_embedding): Embedding(50, 64)
72 |   (input_layer_norm): LayerNorm((64,), eps=1e-12, elementwise_affine=True)
73 |   (input_dropout): Dropout(p=0.5, inplace=False)
74 |   (local_encoder): Transformer(
75 |     (encoder_layers): ModuleList(
76 |       (0-2): 3 x EncoderLayer(
77 |         (attn_layer_norm): LayerNorm((64,), eps=1e-12, elementwise_affine=True)
78 |         (pff_layer_norm): LayerNorm((64,), eps=1e-12, elementwise_affine=True)
79 |         (self_attention): MultiHeadAttentionLayer(
80 |           (fc_q): Linear(in_features=64, out_features=64, bias=True)
81 |           (fc_k): Linear(in_features=64, out_features=64, bias=True)
82 |           (fc_v): Linear(in_features=64, out_features=64, bias=True)
83 |           (attn_dropout): Dropout(p=0.5, inplace=False)
84 |           (fc_o): Linear(in_features=64, out_features=64, bias=True)
85 |         )
86 |         (pff): PointWiseFeedForwardLayer(
87 |           (fc1): Linear(in_features=64, out_features=256, bias=True)
88 |           (fc2): Linear(in_features=256, out_features=64, bias=True)
89 |         )
90 |         (hidden_dropout): Dropout(p=0.5, inplace=False)
91 |         (pff_out_drop): Dropout(p=0.5, inplace=False)
92 |       )
93 |     )
94 |   )
95 |   (global_seq_encoder): GlobalSeqEncoder(
96 |     (dropout): Dropout(p=0.5, inplace=False)
97 |     (K_linear): Linear(in_features=64, out_features=64, bias=True)
98 |     (V_linear): Linear(in_features=64, out_features=64, bias=True)
99 |   )
100 |   (disentangle_encoder): DisentangleEncoder(
101 |     (pos_fai): Embedding(50, 64)
102 |     (W): Linear(in_features=64, out_features=64, bias=True)
103 |     (layer_norm_1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
104 |     (layer_norm_2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
105 |     (layer_norm_3): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
106 |     (layer_norm_4): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
107 |     (layer_norm_5): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
108 |   )
109 |   (nce_loss): InfoNCELoss(
110 |     (criterion): CrossEntropyLoss()
111 |   )
112 | )
113 | 2024-07-23 15:32:53 INFO Start training...
114 | 2024-07-23 15:33:23 INFO ----------------------------------------------------Epoch 1----------------------------------------------------
115 | 2024-07-23 15:33:23 INFO Training Time :[30.5 s] Training Loss = 11.2264
116 | 2024-07-23 15:33:24 INFO Evaluation Time:[0.9 s] Eval Loss = 9.386
117 | 2024-07-23 15:33:24 INFO hit@5:0.0798 hit@10:0.1371 hit@20:0.2449 hit@50:0.3658 ndcg@5:0.0499 ndcg@10:0.0682 ndcg@20:0.0951 ndcg@50:0.1199
118 | 2024-07-23 15:33:54 INFO ----------------------------------------------------Epoch 2----------------------------------------------------
119 | 2024-07-23 15:33:54 INFO Training Time :[30.1 s] Training Loss = 9.9087
120 | 2024-07-23 15:33:55 INFO Evaluation Time:[0.9 s] Eval Loss = 9.3211
121 | 2024-07-23 15:33:55 INFO hit@5:0.1836 hit@10:0.2819 hit@20:0.4099 hit@50:0.4318 ndcg@5:0.1191 ndcg@10:0.1507 ndcg@20:0.1832 ndcg@50:0.1879
122 | 2024-07-23 15:34:25 INFO ----------------------------------------------------Epoch 3----------------------------------------------------
123 | 2024-07-23 15:34:25 INFO Training Time :[30.1 s] Training Loss = 9.6230
124 | 2024-07-23 15:34:26 INFO Evaluation Time:[0.8 s] Eval Loss = 9.2058
125 | 2024-07-23 15:34:26 INFO hit@5:0.2080 hit@10:0.3090 hit@20:0.4284 hit@50:0.4485 ndcg@5:0.1420 ndcg@10:0.1745 ndcg@20:0.2048 ndcg@50:0.2092
126 | 2024-07-23 15:34:56 INFO ----------------------------------------------------Epoch 4----------------------------------------------------
127 | 2024-07-23 15:34:56 INFO Training Time :[30.2 s] Training Loss = 9.3300
128 | 2024-07-23 15:34:57 INFO Evaluation Time:[0.9 s] Eval Loss = 9.095
129 | 2024-07-23 15:34:57 INFO hit@5:0.2436 hit@10:0.3391 hit@20:0.4570 hit@50:0.4859 ndcg@5:0.1707 ndcg@10:0.2014 ndcg@20:0.2313 ndcg@50:0.2375
130 | 2024-07-23 15:35:27 INFO ----------------------------------------------------Epoch 5----------------------------------------------------
131 | 2024-07-23 15:35:27 INFO Training Time :[30.0 s] Training Loss = 9.0510
132 | 2024-07-23 15:35:28 INFO Evaluation Time:[0.9 s] Eval Loss = 8.9176
133 | 2024-07-23 15:35:28 INFO hit@5:0.2988 hit@10:0.3964 hit@20:0.5138 hit@50:0.5666 ndcg@5:0.2144 ndcg@10:0.2460 ndcg@20:0.2756 ndcg@50:0.2867
134 | 2024-07-23 15:35:58 INFO ----------------------------------------------------Epoch 6----------------------------------------------------
135 | 2024-07-23 15:35:58 INFO Training Time :[30.0 s] Training Loss = 8.7750
136 | 2024-07-23 15:35:59 INFO Evaluation Time:[0.9 s] Eval Loss = 8.7004
137 | 2024-07-23 15:35:59 INFO hit@5:0.3550 hit@10:0.4554 hit@20:0.5695 hit@50:0.6418 ndcg@5:0.2613 ndcg@10:0.2937 ndcg@20:0.3225 ndcg@50:0.3375
138 | 2024-07-23 15:36:29 INFO ----------------------------------------------------Epoch 7----------------------------------------------------
139 | 2024-07-23 15:36:29 INFO Training Time :[30.0 s] Training Loss = 8.5188
140 | 2024-07-23 15:36:30 INFO Evaluation Time:[0.9 s] Eval Loss = 8.5374
141 | 2024-07-23 15:36:30 INFO hit@5:0.3887 hit@10:0.4897 hit@20:0.6031 hit@50:0.6945 ndcg@5:0.2926 ndcg@10:0.3252 ndcg@20:0.3539 ndcg@50:0.3727
142 | 2024-07-23 15:37:00 INFO ----------------------------------------------------Epoch 8----------------------------------------------------
143 | 2024-07-23 15:37:00 INFO Training Time :[30.0 s] Training Loss = 8.2575
144 | 2024-07-23 15:37:00 INFO Evaluation Time:[0.9 s] Eval Loss = 8.4156
145 | 2024-07-23 15:37:00 INFO hit@5:0.4123 hit@10:0.5096 hit@20:0.6229 hit@50:0.7218 ndcg@5:0.3151 ndcg@10:0.3466 ndcg@20:0.3752 ndcg@50:0.3955
146 | 2024-07-23 15:37:30 INFO ----------------------------------------------------Epoch 9----------------------------------------------------
147 | 2024-07-23 15:37:30 INFO Training Time :[30.1 s] Training Loss = 8.0445
148 | 2024-07-23 15:37:31 INFO Evaluation Time:[0.8 s] Eval Loss = 8.3391
149 | 2024-07-23 15:37:31 INFO hit@5:0.4217 hit@10:0.5192 hit@20:0.6349 hit@50:0.7477 ndcg@5:0.3260 ndcg@10:0.3573 ndcg@20:0.3866 ndcg@50:0.4096
150 | 2024-07-23 15:38:01 INFO ----------------------------------------------------Epoch 10----------------------------------------------------
151 | 2024-07-23 15:38:01 INFO Training Time :[30.1 s] Training Loss = 7.8524
152 | 2024-07-23 15:38:02 INFO Evaluation Time:[0.8 s] Eval Loss = 8.3403
153 | 2024-07-23 15:38:02 INFO hit@5:0.4295 hit@10:0.5239 hit@20:0.6416 hit@50:0.7662 ndcg@5:0.3334 ndcg@10:0.3639 ndcg@20:0.3935 ndcg@50:0.4189
154 | 2024-07-23 15:38:32 INFO ----------------------------------------------------Epoch 11----------------------------------------------------
155 | 2024-07-23 15:38:32 INFO Training Time :[30.0 s] Training Loss = 7.6800
156 | 2024-07-23 15:38:33 INFO Evaluation Time:[0.9 s] Eval Loss = 8.339
157 | 2024-07-23 15:38:33 INFO hit@5:0.4319 hit@10:0.5293 hit@20:0.6466 hit@50:0.7620 ndcg@5:0.3390 ndcg@10:0.3704 ndcg@20:0.4000 ndcg@50:0.4236
158 | 2024-07-23 15:39:03 INFO ----------------------------------------------------Epoch 12----------------------------------------------------
159 | 2024-07-23 15:39:03 INFO Training Time :[30.0 s] Training Loss = 7.5258
160 | 2024-07-23 15:39:04 INFO Evaluation Time:[0.9 s] Eval Loss = 8.4081
161 | 2024-07-23 15:39:04 INFO hit@5:0.4367 hit@10:0.5316 hit@20:0.6482 hit@50:0.7733 ndcg@5:0.3438 ndcg@10:0.3745 ndcg@20:0.4038 ndcg@50:0.4293
162 | 2024-07-23 15:39:34 INFO ----------------------------------------------------Epoch 13----------------------------------------------------
163 | 2024-07-23 15:39:34 INFO Training Time :[30.0 s] Training Loss = 7.4114
164 | 2024-07-23 15:39:35 INFO Evaluation Time:[0.9 s] Eval Loss = 8.5115
165 | 2024-07-23 15:39:35 INFO hit@5:0.4343 hit@10:0.5312 hit@20:0.6486 hit@50:0.7707 ndcg@5:0.3443 ndcg@10:0.3755 ndcg@20:0.4050 ndcg@50:0.4300
166 | 2024-07-23 15:39:35 INFO EarlyStopping Counter: 1 out of 10
167 | 2024-07-23 15:40:05 INFO ----------------------------------------------------Epoch 14----------------------------------------------------
168 | 2024-07-23 15:40:05 INFO Training Time :[30.0 s] Training Loss = 7.2897
169 | 2024-07-23 15:40:06 INFO Evaluation Time:[0.8 s] Eval Loss = 8.5864
170 | 2024-07-23 15:40:06 INFO hit@5:0.4365 hit@10:0.5341 hit@20:0.6490 hit@50:0.7624 ndcg@5:0.3448 ndcg@10:0.3762 ndcg@20:0.4052 ndcg@50:0.4285
171 | 2024-07-23 15:40:36 INFO ----------------------------------------------------Epoch 15----------------------------------------------------
172 | 2024-07-23 15:40:36 INFO Training Time :[30.0 s] Training Loss = 7.2119
173 | 2024-07-23 15:40:37 INFO Evaluation Time:[0.9 s] Eval Loss = 8.69
174 | 2024-07-23 15:40:37 INFO hit@5:0.4393 hit@10:0.5335 hit@20:0.6532 hit@50:0.7737 ndcg@5:0.3484 ndcg@10:0.3787 ndcg@20:0.4089 ndcg@50:0.4336
175 | 2024-07-23 15:40:37 INFO EarlyStopping Counter: 1 out of 10
176 | 2024-07-23 15:41:07 INFO ----------------------------------------------------Epoch 16----------------------------------------------------
177 | 2024-07-23 15:41:07 INFO Training Time :[30.2 s] Training Loss = 7.1417
178 | 2024-07-23 15:41:08 INFO Evaluation Time:[0.9 s] Eval Loss = 8.6575
179 | 2024-07-23 15:41:08 INFO hit@5:0.4397 hit@10:0.5368 hit@20:0.6544 hit@50:0.7579 ndcg@5:0.3503 ndcg@10:0.3817 ndcg@20:0.4113 ndcg@50:0.4325
180 | 2024-07-23 15:41:38 INFO ----------------------------------------------------Epoch 17----------------------------------------------------
181 | 2024-07-23 15:41:38 INFO Training Time :[30.1 s] Training Loss = 7.0573
182 | 2024-07-23 15:41:39 INFO Evaluation Time:[0.8 s] Eval Loss = 8.809
183 | 2024-07-23 15:41:39 INFO hit@5:0.4395 hit@10:0.5355 hit@20:0.6549 hit@50:0.7616 ndcg@5:0.3498 ndcg@10:0.3808 ndcg@20:0.4109 ndcg@50:0.4329
184 | 2024-07-23 15:41:39 INFO EarlyStopping Counter: 1 out of 10
185 | 2024-07-23 15:42:09 INFO ----------------------------------------------------Epoch 18----------------------------------------------------
186 | 2024-07-23 15:42:09 INFO Training Time :[30.0 s] Training Loss = 6.9430
187 | 2024-07-23 15:42:09 INFO Evaluation Time:[0.9 s] Eval Loss = 8.8837
188 | 2024-07-23 15:42:09 INFO hit@5:0.4430 hit@10:0.5368 hit@20:0.6553 hit@50:0.7630 ndcg@5:0.3530 ndcg@10:0.3832 ndcg@20:0.4131 ndcg@50:0.4353
189 | 2024-07-23 15:42:09 INFO EarlyStopping Counter: 2 out of 10
190 | 2024-07-23 15:42:39 INFO ----------------------------------------------------Epoch 19----------------------------------------------------
191 | 2024-07-23 15:42:39 INFO Training Time :[30.0 s] Training Loss = 6.8674
192 | 2024-07-23 15:42:40 INFO Evaluation Time:[0.9 s] Eval Loss = 8.9556
193 | 2024-07-23 15:42:40 INFO hit@5:0.4411 hit@10:0.5384 hit@20:0.6570 hit@50:0.7684 ndcg@5:0.3518 ndcg@10:0.3831 ndcg@20:0.4131 ndcg@50:0.4360
194 | 2024-07-23 15:43:10 INFO ----------------------------------------------------Epoch 20----------------------------------------------------
195 | 2024-07-23 15:43:10 INFO Training Time :[30.0 s] Training Loss = 6.8115
196 | 2024-07-23 15:43:11 INFO Evaluation Time:[0.9 s] Eval Loss = 9.0637
197 | 2024-07-23 15:43:11 INFO hit@5:0.4445 hit@10:0.5388 hit@20:0.6559 hit@50:0.7718 ndcg@5:0.3537 ndcg@10:0.3841 ndcg@20:0.4137 ndcg@50:0.4375
198 | 2024-07-23 15:43:41 INFO ----------------------------------------------------Epoch 21----------------------------------------------------
199 | 2024-07-23 15:43:41 INFO Training Time :[30.0 s] Training Loss = 6.7497
200 | 2024-07-23 15:43:42 INFO Evaluation Time:[0.9 s] Eval Loss = 9.1482
201 | 2024-07-23 15:43:42 INFO hit@5:0.4407 hit@10:0.5364 hit@20:0.6575 hit@50:0.7716 ndcg@5:0.3513 ndcg@10:0.3822 ndcg@20:0.4127 ndcg@50:0.4363
202 | 2024-07-23 15:43:42 INFO EarlyStopping Counter: 1 out of 10
203 | 2024-07-23 15:44:12 INFO ----------------------------------------------------Epoch 22----------------------------------------------------
204 | 2024-07-23 15:44:12 INFO Training Time :[30.1 s] Training Loss = 6.7074
205 | 2024-07-23 15:44:13 INFO Evaluation Time:[0.9 s] Eval Loss = 9.1342
206 | 2024-07-23 15:44:13 INFO hit@5:0.4402 hit@10:0.5359 hit@20:0.6563 hit@50:0.7722 ndcg@5:0.3520 ndcg@10:0.3828 ndcg@20:0.4132 ndcg@50:0.4371
207 | 2024-07-23 15:44:13 INFO EarlyStopping Counter: 2 out of 10
208 | 2024-07-23 15:44:43 INFO ----------------------------------------------------Epoch 23----------------------------------------------------
209 | 2024-07-23 15:44:43 INFO Training Time :[29.9 s] Training Loss = 6.9449
210 | 2024-07-23 15:44:44 INFO Evaluation Time:[0.9 s] Eval Loss = 9.2002
211 | 2024-07-23 15:44:44 INFO hit@5:0.4395 hit@10:0.5361 hit@20:0.6530 hit@50:0.7706 ndcg@5:0.3510 ndcg@10:0.3821 ndcg@20:0.4116 ndcg@50:0.4359
212 | 2024-07-23 15:44:44 INFO EarlyStopping Counter: 3 out of 10
213 | 2024-07-23 15:45:14 INFO ----------------------------------------------------Epoch 24----------------------------------------------------
214 | 2024-07-23 15:45:14 INFO Training Time :[30.0 s] Training Loss = 6.6846
215 | 2024-07-23 15:45:15 INFO Evaluation Time:[0.9 s] Eval Loss = 9.3021
216 | 2024-07-23 15:45:15 INFO hit@5:0.4405 hit@10:0.5362 hit@20:0.6533 hit@50:0.7735 ndcg@5:0.3501 ndcg@10:0.3810 ndcg@20:0.4104 ndcg@50:0.4353
217 | 2024-07-23 15:45:15 INFO EarlyStopping Counter: 4 out of 10
218 | 2024-07-23 15:45:45 INFO ----------------------------------------------------Epoch 25----------------------------------------------------
219 | 2024-07-23 15:45:45 INFO Training Time :[30.0 s] Training Loss = 6.5977
220 | 2024-07-23 15:45:46 INFO Evaluation Time:[0.9 s] Eval Loss = 9.1295
221 | 2024-07-23 15:45:46 INFO hit@5:0.4438 hit@10:0.5363 hit@20:0.6579 hit@50:0.7648 ndcg@5:0.3533 ndcg@10:0.3832 ndcg@20:0.4137 ndcg@50:0.4358
222 | 2024-07-23 15:45:46 INFO EarlyStopping Counter: 5 out of 10
223 | 2024-07-23 15:46:16 INFO ----------------------------------------------------Epoch 26----------------------------------------------------
224 | 2024-07-23 15:46:16 INFO Training Time :[30.0 s] Training Loss = 6.5443
225 | 2024-07-23 15:46:16 INFO Evaluation Time:[0.9 s] Eval Loss = 9.4042
226 | 2024-07-23 15:46:16 INFO hit@5:0.4404 hit@10:0.5320 hit@20:0.6549 hit@50:0.7668 ndcg@5:0.3508 ndcg@10:0.3804 ndcg@20:0.4113 ndcg@50:0.4344
227 | 2024-07-23 15:46:16 INFO EarlyStopping Counter: 6 out of 10
228 | 2024-07-23 15:46:46 INFO ----------------------------------------------------Epoch 27----------------------------------------------------
229 | 2024-07-23 15:46:46 INFO Training Time :[30.0 s] Training Loss = 6.5445
230 | 2024-07-23 15:46:47 INFO Evaluation Time:[0.9 s] Eval Loss = 9.3367
231 | 2024-07-23 15:46:47 INFO hit@5:0.4423 hit@10:0.5359 hit@20:0.6557 hit@50:0.7659 ndcg@5:0.3524 ndcg@10:0.3826 ndcg@20:0.4127 ndcg@50:0.4355
232 | 2024-07-23 15:46:47 INFO EarlyStopping Counter: 7 out of 10
233 | 2024-07-23 15:47:17 INFO ----------------------------------------------------Epoch 28----------------------------------------------------
234 | 2024-07-23 15:47:17 INFO Training Time :[30.0 s] Training Loss = 6.4637
235 | 2024-07-23 15:47:18 INFO Evaluation Time:[0.9 s] Eval Loss = 9.4813
236 | 2024-07-23 15:47:18 INFO hit@5:0.4383 hit@10:0.5310 hit@20:0.6539 hit@50:0.7706 ndcg@5:0.3502 ndcg@10:0.3802 ndcg@20:0.4110 ndcg@50:0.4351
237 | 2024-07-23 15:47:18 INFO EarlyStopping Counter: 8 out of 10
238 | 2024-07-23 15:47:48 INFO ----------------------------------------------------Epoch 29----------------------------------------------------
239 | 2024-07-23 15:47:48 INFO Training Time :[29.9 s] Training Loss = 6.4458
240 | 2024-07-23 15:47:49 INFO Evaluation Time:[0.9 s] Eval Loss = 9.5087
241 | 2024-07-23 15:47:49 INFO hit@5:0.4374 hit@10:0.5339 hit@20:0.6544 hit@50:0.7711 ndcg@5:0.3489 ndcg@10:0.3801 ndcg@20:0.4104 ndcg@50:0.4345
242 | 2024-07-23 15:47:49 INFO EarlyStopping Counter: 9 out of 10
243 | 2024-07-23 15:48:19 INFO ----------------------------------------------------Epoch 30----------------------------------------------------
244 | 2024-07-23 15:48:19 INFO Training Time :[30.0 s] Training Loss = 6.4197
245 | 2024-07-23 15:48:20 INFO Evaluation Time:[0.8 s] Eval Loss = 9.4156
246 | 2024-07-23 15:48:20 INFO hit@5:0.4365 hit@10:0.5323 hit@20:0.6538 hit@50:0.7658 ndcg@5:0.3473 ndcg@10:0.3782 ndcg@20:0.4088 ndcg@50:0.4320
247 | 2024-07-23 15:48:20 INFO EarlyStopping Counter: 10 out of 10
248 | 2024-07-23 15:48:20 INFO ------------------------------------------------Best Evaluation------------------------------------------------
249 | 2024-07-23 15:48:20 INFO Best Result at Epoch: 20 Early Stop at Patience: 10
250 | 2024-07-23 15:48:20 INFO hit@5:0.4445 hit@10:0.5388 hit@20:0.6559 hit@50:0.7718 ndcg@5:0.3537 ndcg@10:0.3841 ndcg@20:0.4137 ndcg@50:0.4375
251 | 2024-07-23 15:48:21 INFO -----------------------------------------------------Test Results------------------------------------------------------
252 | 2024-07-23 15:48:21 INFO hit@5:0.4025 hit@10:0.4954 hit@20:0.6142 hit@50:0.7306 ndcg@5:0.3152 ndcg@10:0.3451 ndcg@20:0.3750 ndcg@50:0.3989
253 |
--------------------------------------------------------------------------------
/runIOCRec.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from src.train.trainer import load_trainer
3 | 
4 | if __name__ == '__main__':
5 |     parser = argparse.ArgumentParser()
6 | 
7 |     # Model
8 |     parser.add_argument('--model', default='IOCRec', type=str)
9 |     parser.add_argument('--model_type', default='Sequential', choices=['Sequential', 'Knowledge'])
10 |     # Contrastive Learning Hyper-Parameters
11 |     parser.add_argument('--aug_types', default=['crop', 'mask', 'reorder'], help='augmentation types')
12 |     parser.add_argument('--crop_ratio', default=0.2, type=float,
13 |                         help='Crop augmentation: proportion of the cropped subsequence in the original sequence')
14 |     parser.add_argument('--mask_ratio', default=0.7, type=float,
15 |                         help='Mask augmentation: proportion of masked items in the original sequence')
16 |     parser.add_argument('--reorder_ratio', default=0.2, type=float,
17 |                         help='Reorder augmentation: proportion of the reordered subsequence in the original sequence')
18 |     parser.add_argument('--all_hidden', action='store_false', help='use all hidden states for contrastive learning')
19 |     parser.add_argument('--tao', default=1., type=float, help='temperature for softmax')
20 |     parser.add_argument('--lamda', default=0.1, type=float,
21 |                         help='weight of the contrastive learning loss; only used for joint training')
22 |     parser.add_argument('--k_intention', default=4, type=int, help='number of disentangled intentions')
23 |     # Transformer
24 |     parser.add_argument('--embed_size', default=128, type=int)
25 |     parser.add_argument('--ffn_hidden', default=512, type=int, help='hidden dim of the feed-forward network')
26 |     parser.add_argument('--num_blocks', default=3, type=int, help='number of transformer blocks')
27 |     parser.add_argument('--num_heads', default=2, type=int, help='number of heads for multi-head attention')
28 |     parser.add_argument('--hidden_dropout', default=0.5, type=float, help='hidden state dropout rate')
29 |     parser.add_argument('--attn_dropout', default=0.5, type=float, help='dropout rate for attention')
30 |     parser.add_argument('--layer_norm_eps', default=1e-12, type=float, help='transformer layer norm eps')
31 |     parser.add_argument('--initializer_range', default=0.02, type=float, help='transformer parameter initialization range')
32 |     # Data
33 |     parser.add_argument('--dataset', default='toys', type=str)
34 |     # Training
35 |     parser.add_argument('--epoch_num', default=150, type=int)
36 |     parser.add_argument('--data_aug', action='store_false', help='data augmentation')
37 |     parser.add_argument('--train_batch', default=256, type=int)
38 |     parser.add_argument('--learning_rate', default=1e-3, type=float)
39 |     parser.add_argument('--l2', default=0, type=float, help='l2 regularization')
40 |     parser.add_argument('--patience', default=10, type=int, help='early stop patience')
41 |     parser.add_argument('--seed', default=2024, type=int, help='random seed, -1 means no fixed seed')
42 |     parser.add_argument('--mark', default='seed2024', type=str, help='log suffix mark')
43 |     # Evaluation
44 |     parser.add_argument('--split_type', default='valid_and_test', choices=['valid_only', 'valid_and_test'])
45 |     parser.add_argument('--split_mode', default='LS', type=str, help='[LS (leave-one-out), LS_R@0.x, PS (pre-split)]')
46 |     parser.add_argument('--eval_mode', default='full', help='[uni100, pop100, full]')
47 |     parser.add_argument('--k', default=[5, 10, 20, 50], help='rank k for each metric')
48 |     parser.add_argument('--metric', default=['hit', 'ndcg'], help='[hit, ndcg, mrr, recall]')
49 |     parser.add_argument('--valid_metric', default='hit@10', help='metric used for early stopping')
50 | 
51 |     config = parser.parse_args()
52 | 
53 |     trainer = load_trainer(config)
54 |     trainer.start_training()
55 |
56 |
57 |
--------------------------------------------------------------------------------
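
The `--tao` (temperature) and `--lamda` (contrastive-loss weight) flags above feed the InfoNCE objective that appears as `InfoNCELoss` in the logged model printouts, presumably defined in `src/model/loss.py` (not included in this dump). As a rough guide to how these two knobs enter training, here is a minimal sketch of a standard InfoNCE loss with a `CrossEntropyLoss` criterion; this is an assumed formulation for illustration, not the repository's exact code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class InfoNCELossSketch(nn.Module):
    """Standard InfoNCE over two augmented views (assumed formulation)."""

    def __init__(self, temperature=1.0):  # temperature plays the role of --tao
        super().__init__()
        self.temperature = temperature
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, z_i, z_j):
        # z_i, z_j: [batch, dim] representations of two augmentations of the same sequences.
        batch_size = z_i.size(0)
        z = F.normalize(torch.cat([z_i, z_j], dim=0), dim=-1)  # [2B, dim]
        sim = z @ z.t() / self.temperature                     # [2B, 2B] similarity logits
        sim.fill_diagonal_(float('-inf'))                      # a view is never its own positive
        # The positive for each row is its counterpart from the other view.
        targets = torch.cat([torch.arange(batch_size) + batch_size,
                             torch.arange(batch_size)]).to(sim.device)
        return self.criterion(sim, targets)


# Joint objective as suggested by the --lamda help text (sketch):
#   loss = recommendation_cross_entropy + lamda * info_nce(z_view1, z_view2)
```
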
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/__init__.py
--------------------------------------------------------------------------------
/src/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/src/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/src/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/dataset/__init__.py
--------------------------------------------------------------------------------
/src/dataset/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/dataset/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/src/dataset/__pycache__/data_processor.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/dataset/__pycache__/data_processor.cpython-38.pyc
--------------------------------------------------------------------------------
/src/dataset/__pycache__/dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/dataset/__pycache__/dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/src/dataset/data_processor.py:
--------------------------------------------------------------------------------
1 | import codecs
2 | import copy
3 | import logging
4 | # import dgl
5 | import numpy as np
6 | import torch
7 | import os
8 | from scipy import sparse
9 | from src.utils.utils import save_pickle, load_pickle
10 |
11 |
12 | class DataProcessor:
13 |     def __init__(self, config):
14 |         self.config = config
15 |         self.model_name = config.model
16 |         self.dataset = config.dataset
17 |         self.data_aug = config.data_aug
18 | 
19 |         self.use_tar_seq = False
20 |         self.tar_seq_len = 1
21 |         if self.model_name in ['KERL', 'MELOD']:
22 |             self.use_tar_seq = True
23 |             self.tar_seq_len = config.episode_len
24 |         self.filter_len = config.seq_filter_len
25 |         self.filter_target = config.if_filter_target
26 |         self.item_id_pad = 1  # increase item id: x -> x + 1
27 |         self.sep = config.separator
28 |         self.graph_type_list = [g_type.upper() for g_type in config.graph_type]
29 |         self.valid_graphs = ['GGNN', 'BIPARTITE', 'TRANSITION', 'HYPER']
30 | 
31 |         self.data_path = None
32 |         self.train_data = None
33 |         self.eval_data = None
34 |         self.test_data = None
35 |         self.kg_map = None
36 |         self.do_data_split = True
37 |         self.popularity = None
38 | 
39 |         self.split_type = config.split_type
40 |         self.split_mode = config.split_mode
41 |         self.eval_ratio = 0.2  # default
42 | 
43 |         # data statistic
44 |         self.max_item = 0
45 |         self.num_items = 0  # with padding item 0
46 |         self.num_users = 0
47 |         self.seq_avg_len = 0.
48 |         self.total_seq_len = 0.
49 |         self.density = 0.
50 |         self.sparsity = 0.
51 | 
52 |         self._init_data_processer()
53 | 
54 |     def _init_data_processer(self):
55 |         if self.split_mode == 'PS':
56 |             self.do_data_split = False
57 |         elif 'LS_R' == self.split_mode.split('@')[0]:
58 |             self.eval_ratio = float(self.split_mode.split('@')[-1])
59 |             self.split_mode = 'LS_R'
60 |         self._set_data_path()
61 | 
62 |     def prepare_data(self):
63 |         if self.do_data_split:
64 |             seq_data_list = self._load_row_data()
65 |             self._train_test_split(seq_data_list)
66 |         else:  # load pre-split data
67 |             self._load_pre_split_data()
68 | 
69 |         data_dict = {'train': self.train_data,
70 |                      'eval': self.eval_data,
71 |                      'test': self.test_data,
72 |                      'raw_train': self.row_train_data}
73 | 
74 |         extra_data_dict = self._prepare_additional_data()
75 | 
76 |         return data_dict, extra_data_dict
77 | 
78 |     def _prepare_additional_data(self):
79 |         additional_data_dict = {}
80 |         if self.config.model_type.lower() == 'graph':
81 |             specified_graph_dict = self.prepare_specified_graph()
82 |             if specified_graph_dict is not None:
83 |                 additional_data_dict['graph_dict'] = specified_graph_dict
84 |         if self.config.model_type.lower() == 'knowledge':
85 |             if self.config.kg_data_type == 'pretrain':
86 |                 kg_map = self.prepare_kg_map().float().to(self.config.device)
87 |                 additional_data_dict['kg_map'] = kg_map
88 |             elif self.config.kg_data_type == 'jointly_train':
89 |                 triple_list, relation_tph, relation_hpt = self.load_kg_triplet(entity_id_shift=1)
90 |                 additional_data_dict['triple_list'] = triple_list
91 |                 additional_data_dict['relation_tph'] = relation_tph
92 |                 additional_data_dict['relation_hpt'] = relation_hpt
93 |             else:
94 |                 pass
95 |         if self.model_name in ['DuoRec']:
96 |             train_target_item = self.train_data[1]
97 |             same_target_index = self.semantic_augmentation(train_target_item)
98 |             additional_data_dict['same_target_index'] = same_target_index
99 |         elif self.model_name in ['KGSCL']:
100 |             kg_relation_dict = load_pickle(os.path.join(self.data_path, 'kg_relation_set.pkl'))
101 |             co_occurrence_dict = self.calc_co_occurrence(kg_relation_dict)
102 |             additional_data_dict['kg_relation_dict'] = kg_relation_dict
103 |             additional_data_dict['co_occurrence_dict'] = co_occurrence_dict
104 | 
105 |         return additional_data_dict
106 | 
107 |     def semantic_augmentation(self, target_item):
108 |         target_item = np.array(target_item, dtype=np.int64)
109 |         aug_path = self.data_path + '/semantic_augmentation.pkl'
110 |         target_deduplicated = set(target_item)
111 |         if os.path.exists(aug_path):
112 |             same_target_index = load_pickle(aug_path)
113 |         else:
114 |             same_target_index = {}
115 |             for _, item_id in enumerate(target_deduplicated):
116 |                 # all training-sample indices whose target is this item id, including the sample itself
117 |                 all_index_same_id_w_self = np.where(target_item == item_id)[0].tolist()
118 |                 if item_id not in same_target_index.keys():
119 |                     same_target_index[item_id] = []
120 |                 same_target_index[item_id].extend(all_index_same_id_w_self)
121 |             save_pickle(same_target_index, aug_path)
122 |         return same_target_index
123 | 
124 |     def calc_co_occurrence(self, kg_relation_dict):
125 |         dict_path = self.data_path + '/co_occurrence.pkl'
126 |         if os.path.exists(dict_path):
127 |             co_occurrence_dict = load_pickle(dict_path)
128 |             return co_occurrence_dict
129 |         else:
130 |             co_occurrence_all_data = {}
131 |             for item_seq, target in zip(*self.row_train_data):
132 |                 for i in range(len(item_seq)):
133 |                     for j in range(i + 1, len(item_seq)):
134 |                         item_1, item_2 = item_seq[i], item_seq[j]
135 |                         co_item_pair = (item_1, item_2) if item_1 < item_2 else (item_2, item_1)
136 |                         co_item_pair = f'{co_item_pair[0]}-{co_item_pair[1]}'
137 |                         if co_item_pair not in co_occurrence_all_data.keys():
138 |                             co_occurrence_all_data[co_item_pair] = 0.
139 |                         co_occurrence_all_data[co_item_pair] += 1.
140 |                 # train target item
141 |                 for item in item_seq:
142 |                     co_item_pair = (item, target) if item < target else (target, item)
143 |                     co_item_pair = f'{co_item_pair[0]}-{co_item_pair[1]}'
144 |                     if co_item_pair not in co_occurrence_all_data.keys():
145 |                         co_occurrence_all_data[co_item_pair] = 0.
146 |                     co_occurrence_all_data[co_item_pair] += 1.
147 | 
148 |             # calc co-occurrence for items with kg relations
149 |             co_occurrence_dict = {}
150 |             for item in kg_relation_dict.keys():
151 |                 shifted_item = item + 1
152 |                 co_occurrence_dict[item] = {'s': [], 'c': []}
153 | 
154 |                 # frequency of substitute items
155 |                 substitute_freq_list = []
156 |                 substitute_item_list = kg_relation_dict[item]['s']
157 |                 if len(substitute_item_list) > 0:
158 |                     for sub_item in substitute_item_list:
159 |                         shifted_sub_item = sub_item + 1
160 |                         co_item_pair = (shifted_item, shifted_sub_item) if item < sub_item else (
161 |                             shifted_sub_item, shifted_item)
162 |                         co_item_pair = f'{co_item_pair[0]}-{co_item_pair[1]}'
163 |                         if co_item_pair not in co_occurrence_all_data.keys():
164 |                             substitute_freq_list.append(1.)
165 |                         else:
166 |                             substitute_freq_list.append(1. + float(co_occurrence_all_data[co_item_pair]))
167 |                     sum_freq = sum(substitute_freq_list)
168 |                     substitute_freq_list = [freq / sum_freq for freq in substitute_freq_list]
169 |                     assert len(substitute_freq_list) == len(substitute_item_list)
170 |                     co_occurrence_dict[item]['s'] = substitute_freq_list
171 | 
172 |                 # frequency of complement items
173 |                 complement_freq_list = []
174 |                 complement_item_list = kg_relation_dict[item]['c']
175 |                 if len(complement_item_list) > 0:
176 |                     for sub_item in complement_item_list:
177 |                         shifted_sub_item = sub_item + 1
178 |                         co_item_pair = (shifted_item, shifted_sub_item) if item < sub_item else (
179 |                             shifted_sub_item, shifted_item)
180 |                         co_item_pair = f'{co_item_pair[0]}-{co_item_pair[1]}'
181 |                         if co_item_pair not in co_occurrence_all_data.keys():
182 |                             complement_freq_list.append(1.)
183 |                         else:
184 |                             complement_freq_list.append(1. + co_occurrence_all_data[co_item_pair])
185 |                     sum_freq = sum(complement_freq_list)
186 |                     complement_freq_list = [freq / sum_freq for freq in complement_freq_list]
187 |                     assert len(complement_freq_list) == len(complement_item_list)
188 |                     co_occurrence_dict[item]['c'] = complement_freq_list
189 | 
190 |             save_pickle(co_occurrence_dict, dict_path)
191 |             return co_occurrence_dict
192 |
193 | def load_kg_triplet(self, entity_id_shift=1):
194 | """
195 | entity_id_shift: entity_id += 1
196 |
197 | return:
198 | triple_list: list of kg triplets
199 |             relation_tph: dict, average number of tails per head under each relation r
200 |             relation_hpt: dict, average number of heads per tail under each relation r
201 | """
202 | # required kg files
203 | train_triplet = f'{self.data_path}/kg_data/train.txt'
204 | entity2id = f'{self.data_path}/kg_data/entities.dict'
205 | relation2id = f'{self.data_path}/kg_data/relations.dict'
206 |
207 | entities2id = {}
208 | relations2id = {}
209 | relation_tph = {}
210 | relation_hpt = {}
211 |
212 | entity = []
213 | relation = []
214 | # read entity and relation sets
215 | with open(entity2id, 'r') as f1, open(relation2id, 'r') as f2:
216 | lines1 = f1.readlines()
217 | lines2 = f2.readlines()
218 | for line in lines1:
219 | line = line.strip().split('\t')
220 | if len(line) != 2:
221 | continue
222 | entities2id[line[0]] = int(line[1]) + entity_id_shift
223 | entity.append(int(line[1]))
224 |
225 | for line in lines2:
226 | line = line.strip().split('\t')
227 | if len(line) != 2:
228 | continue
229 | relations2id[line[0]] = int(line[1])
230 | relation.append(int(line[1]))
231 |
232 | triple_list = []
233 | relation_head = {}
234 | relation_tail = {}
235 |
236 | # read kg triplets, convert entity name to entity id
237 | with codecs.open(train_triplet, 'r') as f:
238 | content = f.readlines()
239 | for line in content:
240 | triple = line.strip().split("\t")
241 | if len(triple) != 3:
242 | continue
243 |
244 | h_ = int(entities2id[triple[0]])
245 | r_ = int(relations2id[triple[1]])
246 | t_ = int(entities2id[triple[2]])
247 |
248 | triple_list.append([h_, r_, t_])
249 | if r_ in relation_head:
250 | if h_ in relation_head[r_]:
251 | relation_head[r_][h_] += 1
252 | else:
253 | relation_head[r_][h_] = 1
254 | else:
255 | relation_head[r_] = {}
256 | relation_head[r_][h_] = 1
257 |
258 | if r_ in relation_tail:
259 | if t_ in relation_tail[r_]:
260 | relation_tail[r_][t_] += 1
261 | else:
262 | relation_tail[r_][t_] = 1
263 | else:
264 | relation_tail[r_] = {}
265 | relation_tail[r_][t_] = 1
266 |
267 | for r_ in relation_head:
268 | sum1, sum2 = 0, 0
269 | for head in relation_head[r_]:
270 | sum1 += 1
271 | sum2 += relation_head[r_][head]
272 | tph = sum2 / sum1
273 | relation_tph[r_] = tph
274 |
275 | for r_ in relation_tail:
276 | sum1, sum2 = 0, 0
277 | for tail in relation_tail[r_]:
278 | sum1 += 1
279 | sum2 += relation_tail[r_][tail]
280 | hpt = sum2 / sum1
281 | relation_hpt[r_] = hpt
282 |
283 | print("Complete load. entity : %d , relation : %d , train triple : %d" % (
284 | len(entity), len(relation), len(triple_list)))
285 |
286 | # update config
287 | self.config.entity_num = len(entity) + 1 # add padding item 0
288 | self.config.relation_num = len(relation)
289 |
290 | # tph: tails per head, hpt: heads per tail
291 | return triple_list, relation_tph, relation_hpt
292 |
293 | def _set_data_path(self):
294 | # find file path
295 | cur_path = os.path.abspath(__file__)
296 |         root = os.path.dirname(os.path.dirname(os.path.dirname(cur_path)))  # portable; splitting on '\\' only worked on Windows
297 | self.data_path = os.path.join(root, f'dataset/{self.dataset}')
298 |
299 | def _read_seq_data(self, file_path):
300 | # read data file
301 | data_list = []
302 | with open(file_path, 'r') as fr:
303 | for line in fr.readlines():
304 | item_seq = list(map(int, line.strip().split(self.sep)))
305 | # remove target items
306 | if self.filter_target:
307 | item_seq = self._filter_target(item_seq)
308 | # drop too short sequence
309 | if len(item_seq) < self.filter_len:
310 | continue
311 | item_seq = [item + self.item_id_pad for item in item_seq] # shift item id x to x + 1
312 | # statistic
313 | self.max_item = max(self.max_item, max(item_seq))
314 | self.total_seq_len += float(len(item_seq))
315 | self.num_users += 1
316 | data_list.append(item_seq)
317 | return data_list
318 |
319 | def _load_row_data(self):
320 | """
321 | load total data sequences
322 | """
323 | file_path = os.path.join(self.data_path, f'{self.dataset}.seq')
324 | seq_data_list = self._read_seq_data(file_path)
325 | self._set_statistic(seq_data_list)
326 |
327 | return seq_data_list
328 |
329 | def _set_statistic(self, seq_data_list=None):
330 | self.seq_avg_len = round(float(self.total_seq_len) / self.num_users, 4)
331 | self.density = round(float(self.total_seq_len) / self.num_users / self.max_item, 4)
332 | self.sparsity = 1 - self.density
333 | self.num_users = int(self.num_users)
334 | self.num_items = int(self.max_item + 1) # with padding item 0
335 |
336 | # calculate popularity
337 | self.popularity = [0. for _ in range(self.num_items)]
338 | for item_seq in seq_data_list:
339 | for item in item_seq:
340 | self.popularity[item] += 1.
341 | self.popularity = [p / self.total_seq_len for p in self.popularity]
342 |
343 | def _load_pre_split_data(self):
344 | """
345 | load data after split, xx.train, xx.eval, xx.test
346 | """
347 | # load xx.train, xx.eval
348 | train_file = os.path.join(self.data_path, f'{self.dataset}.train')
349 | eval_file = os.path.join(self.data_path, f'{self.dataset}.eval')
350 |
351 | train_data_list = self._read_seq_data(train_file)
352 | eval_data_list = self._read_seq_data(eval_file)
353 |
354 | train_x = [seq[:-1] for seq in train_data_list if len(seq) > 1]
355 | train_y = [seq[-1] for seq in train_data_list if len(seq) > 1]
356 | eval_x = [seq[:-1] for seq in eval_data_list if len(seq) > 1]
357 | eval_y = [seq[-1] for seq in eval_data_list if len(seq) > 1]
358 |
359 | self.row_train_data = copy.deepcopy(train_x), copy.deepcopy(train_y)
360 | # training data augmentation
361 | self._data_augmentation(train_x, train_y)
362 |
363 | # load xx.test
364 | if self.split_type == 'valid_and_test':
365 | test_file = os.path.join(self.data_path, f'{self.dataset}.test')
366 | test_data_list = self._read_seq_data(test_file)
367 | test_x = [seq[:-1] for seq in test_data_list if len(seq) > 1]
368 | test_y = [seq[-1] for seq in test_data_list if len(seq) > 1]
369 | self.test_data = (test_x, test_y)
370 | test_seq = [seq + [target] for seq, target in zip(test_x, test_y)]
371 |
372 | self.train_data = (train_x, train_y)
373 | self.eval_data = (eval_x, eval_y)
374 |
375 | # gather all sequences
376 | all_data_list = [seq + [target] for seq, target in zip(train_x, train_y)]
377 | eval_seq = [seq + [target] for seq, target in zip(eval_x, eval_y)]
378 | all_data_list.extend(eval_seq)
379 | if self.split_type == 'valid_and_test':
380 | all_data_list.extend(test_seq)
381 |
382 | self._set_statistic(all_data_list)
383 |
384 | def _train_test_split(self, seq_data_list):
385 | if self.split_type == 'valid_only':
386 | train_x, train_y, eval_x, eval_y = self._leave_one_out_split(seq_data_list)
387 | else: # valid and test
388 | if self.split_mode == 'LS':
389 | train_x = [item_seq[:-3] for item_seq in seq_data_list if len(item_seq) > 3]
390 | train_y = [item_seq[-3] for item_seq in seq_data_list if len(item_seq) > 3]
391 | eval_x = [item_seq[:-2] for item_seq in seq_data_list if len(item_seq) > 2]
392 | eval_y = [item_seq[-2] for item_seq in seq_data_list if len(item_seq) > 2]
393 | test_x = [item_seq[:-1] for item_seq in seq_data_list if len(item_seq) > 1]
394 | test_y = [item_seq[-1] for item_seq in seq_data_list if len(item_seq) > 1]
395 | else: # LS_R
396 | train_x, train_y, test_x, test_y = self._leave_one_out_split(seq_data_list)
397 | # split eval and test data by ratio
398 | eval_x, eval_y, test_x, test_y = self._split_by_ratio(test_x, test_y)
399 | self.test_data = (test_x, test_y)
400 |
401 | self.row_train_data = (copy.deepcopy(train_x), copy.deepcopy(train_y))
402 | # training data augmentation
403 | self._data_augmentation(train_x, train_y)
404 |
405 | self.train_data = (train_x, train_y)
406 | self.eval_data = (eval_x, eval_y)
407 |
408 | def _leave_one_out_split(self, seq_data):
409 | train_x = [item_seq[:-2] for item_seq in seq_data if len(item_seq) > 2]
410 | train_y = [item_seq[-2] for item_seq in seq_data if len(item_seq) > 2]
411 | eval_x = [item_seq[:-1] for item_seq in seq_data if len(item_seq) > 1]
412 | eval_y = [item_seq[-1] for item_seq in seq_data if len(item_seq) > 1]
413 | return train_x, train_y, eval_x, eval_y
414 |
415 | def prepare_kg_map(self):
416 | kg_path = os.path.join(self.data_path, f'{self.dataset}_kg.npy')
417 | kg_map_np = np.load(kg_path)
418 | zero_kg_emb = np.zeros((1, kg_map_np.shape[-1]))
419 | kg_map_np_pad = np.concatenate([zero_kg_emb, kg_map_np], axis=0)
420 | return torch.from_numpy(kg_map_np_pad)
421 |
422 | def prepare_specified_graph(self):
423 | assert isinstance(self.graph_type_list, list), f'graph_type should be a list.'
424 | graph_dict = {}
425 | for g_type in self.graph_type_list:
426 | if g_type not in self.valid_graphs:
427 |                 raise KeyError(f'Invalid graph type: {g_type}. Choose from {self.valid_graphs}')
428 | if g_type == 'GGNN':
429 | continue # session graph will be constructed in dataset
430 | graph_dict[g_type] = getattr(self, f'prepare_{g_type.lower()}_graph')()
431 | if len(graph_dict) == 0:
432 | return None
433 | return graph_dict
434 |
435 | def prepare_bipartite_graph(self, bidirectional=True):
436 | u_i_edges = []
437 | i_u_edges = []
438 |
439 | for user, item_list in enumerate(self.row_train_data[0]):
440 | for item in item_list:
441 | u_i_edges.append((user, item))
442 | if bidirectional: # bidirectional graph
443 | i_u_edges.append((item, user))
444 | num_nodes_dict = {'user': len(self.row_train_data[0]), 'item': self.num_items}
445 |
446 | logging.info('loading graph...')
447 | if bidirectional:
448 | bipartite = dgl.heterograph({
449 | ('user', 'contact', 'item'): u_i_edges,
450 | ('item', 'contact_by', 'user'): i_u_edges
451 | }, num_nodes_dict=num_nodes_dict)
452 | else:
453 | bipartite = dgl.heterograph({
454 | ('user', 'contact', 'item'): u_i_edges
455 | }, num_nodes_dict=num_nodes_dict)
456 |
457 | logging.info(bipartite)
458 | return bipartite
459 |
460 | def prepare_transition_graph(self):
461 | """
462 | Returns:
463 | adj (sparse matrix): adjacency matrix of item transition graph
464 | """
465 | src, des = [], []
466 | for item_seq in self.row_train_data[0]:
467 | for i in range(len(item_seq) - 1):
468 | src.append(item_seq[i])
469 | des.append(item_seq[i + 1])
470 | adj = sparse.coo_matrix((np.ones(len(src), dtype=float), (src, des)), shape=(self.num_items, self.num_items))
471 |
472 | # adj = sparse.load_npz(f'{self.data_path}/global_adj.npz')
473 |
474 | # frequency divide in degree of destination node as edge weight
475 | # in_degree = adj.sum(0).reshape(1, -1)
476 | # in_degree[in_degree == 0] = 1
477 | # adj = adj.multiply(1 / in_degree)
478 |
479 | return adj
480 |
481 | def prepare_hyper_graph(self):
482 | """
483 | Returns:
484 | incidence_matrix (sparse matrix): incidence matrix for session hyper-graph
485 | """
486 | # incidence_matrix = sparse.load_npz(f'{self.data_path}/global_adj.npz')
487 | # return incidence_matrix
488 |
489 | nodes, edges = [], []
490 | for edge, item_seq in enumerate(self.row_train_data[0]):
491 | for item in item_seq:
492 | nodes.append(item)
493 | edges.append(edge)
494 | incidence_matrix = sparse.coo_matrix((np.ones(len(nodes), dtype=float), (nodes, edges)))
495 |
496 | return incidence_matrix
497 |
498 | def data_log_verbose(self, order):
499 |
500 | logging.info(f'[{order}] Data Statistic '.ljust(47, '-'))
501 | logging.info(f'dataset: {self.dataset}')
502 | logging.info(f'user number: {self.num_users}')
503 | logging.info(f'item number: {self.max_item}')
504 | logging.info(f'average seq length: {self.seq_avg_len}')
505 | logging.info(f'density: {self.density} sparsity: {self.sparsity}')
506 | if self.data_aug:
507 | logging.info(f'data after augmentation:')
508 | if self.split_type == 'valid_only':
509 | logging.info(f'train samples: {len(self.train_data[0])}\teval samples: {len(self.eval_data[0])}')
510 | else:
511 | logging.info(f'train samples: {len(self.train_data[0])}\teval samples: {len(self.eval_data[0])}\ttest '
512 | f'samples: {len(self.test_data[0])}')
513 | else:
514 | logging.info(f'data without augmentation:')
515 | if self.split_type == 'valid_only':
516 | logging.info(f'train samples: {len(self.train_data[0])}\teval samples: {len(self.eval_data[0])}')
517 | else:
518 | logging.info(f'train samples: {len(self.train_data[0])}\teval samples: {len(self.eval_data[0])}\ttest '
519 | f'samples: {len(self.test_data[0])}')
520 |
521 | def _filter_target(self, item_seq):
522 | target = item_seq[-1]
523 | item_seq = list(filter(lambda x: x != target, item_seq[:-1]))
524 | item_seq.append(target)
525 | return item_seq
526 |
527 | def _split_by_ratio(self, test_x, test_y):
528 | """
529 | random split by specified ratio
530 | """
531 | eval_size = int(len(test_y) * self.eval_ratio)
532 | index = np.arange(len(test_y))
533 | np.random.shuffle(index)
534 |
535 | eval_index = index[:eval_size]
536 | test_index = index[eval_size:]
537 |
538 | eval_x = [test_x[i] for i in eval_index]
539 | eval_y = [test_y[i] for i in eval_index]
540 |
541 | test_x = [test_x[i] for i in test_index]
542 | test_y = [test_y[i] for i in test_index]
543 |
544 | return eval_x, eval_y, test_x, test_y
545 |
546 | def _data_augmentation(self, train_x, train_y):
547 | if not self.data_aug:
548 | return
549 |         # every proper prefix of a training sequence becomes an extra training sample
550 |         aug_train_x = [item_seq[:last] for item_seq in train_x for last in range(1, len(item_seq))]
551 |         if not self.use_tar_seq:
552 |             aug_train_y = [item_seq[next_i] for item_seq in train_x for next_i in range(1, len(item_seq))]
553 |         else:
554 |             aug_train_y = [item_seq[next_i: next_i + self.tar_seq_len] for item_seq in train_x
555 |                            for next_i in range(1, len(item_seq))]
556 | train_x.extend(aug_train_x)
557 | train_y.extend(aug_train_y)
558 |
--------------------------------------------------------------------------------
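Editor's note: a minimal, standalone sketch of the two core steps above, leave-one-out splitting (_leave_one_out_split) and prefix augmentation (_data_augmentation), on toy data. The helper names below are illustrative, not part of the repository's API.

def leave_one_out_split(seqs):
    # second-to-last item -> training target, last item -> evaluation target
    train_x = [s[:-2] for s in seqs if len(s) > 2]
    train_y = [s[-2] for s in seqs if len(s) > 2]
    eval_x = [s[:-1] for s in seqs if len(s) > 1]
    eval_y = [s[-1] for s in seqs if len(s) > 1]
    return train_x, train_y, eval_x, eval_y

def prefix_augment(train_x, train_y):
    # every proper prefix of a training sequence becomes an extra sample
    aug_x = [s[:last] for s in train_x for last in range(1, len(s))]
    aug_y = [s[last] for s in train_x for last in range(1, len(s))]
    return train_x + aug_x, train_y + aug_y

seqs = [[1, 2, 3, 4], [5, 6, 7]]
tx, ty, ex, ey = leave_one_out_split(seqs)
# tx = [[1, 2], [5]], ty = [3, 6]; ex = [[1, 2, 3], [5, 6]], ey = [4, 7]
ax, ay = prefix_augment(tx, ty)
# adds ([1], 2) from the first sequence; [5] is too short to yield a prefix
print(ax, ay)  # [[1, 2], [5], [1]] [3, 6, 2]
--------------------------------------------------------------------------------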
/src/dataset/dataset.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import math
3 | import random
4 | import numpy as np
5 | import torch
6 | from torch.utils.data import Dataset, DataLoader, default_collate
7 | from src.utils.utils import neg_sample
8 | from src.model.data_augmentation import Crop, Mask, Reorder
9 | from src.model.data_augmentation import AUGMENTATIONS
10 |
11 |
12 | def load_specified_dataset(model_name, config):
13 | if model_name in ['CL4SRec', 'ICLRec', 'IOCRec']:
14 | return CL4SRecDataset
15 | return SequentialDataset
16 |
17 |
18 | class BaseSequentialDataset(Dataset):
19 | def __init__(self, config, data_pair, additional_data_dict=None, train=True):
20 | super(BaseSequentialDataset, self).__init__()
21 |         self.batch_dict = {}  # reused container for collated batches (renamed from the doubled 'batch_batch_dict')
22 | self.num_items = config.num_items
23 | self.config = config
24 | self.train = train
25 | self.dataset = config.dataset
26 | self.max_len = config.max_len
27 | self.item_seq = data_pair[0]
28 | self.label = data_pair[1]
29 |
30 | def get_SRtask_input(self, idx):
31 | item_seq = self.item_seq[idx]
32 | target = self.label[idx]
33 |
34 | seq_len = len(item_seq) if len(item_seq) < self.max_len else self.max_len
35 | item_seq = item_seq[-self.max_len:]
36 | item_seq = item_seq + (self.max_len - seq_len) * [0]
37 |
38 | assert len(item_seq) == self.max_len
39 |
40 | return (torch.tensor(item_seq, dtype=torch.long),
41 | torch.tensor(seq_len, dtype=torch.long),
42 | torch.tensor(target, dtype=torch.long))
43 |
44 | def __getitem__(self, idx):
45 | return self.get_SRtask_input(idx)
46 |
47 | def __len__(self):
48 | return len(self.item_seq)
49 |
50 | def collate_fn(self, x):
51 | return self.basic_SR_collate_fn(x)
52 |
53 | def basic_SR_collate_fn(self, x):
54 | """
55 | x: [(seq_1, len_1, tar_1), ..., (seq_n, len_n, tar_n)]
56 | """
57 | item_seq, seq_len, target = default_collate(x)
58 |         self.batch_dict['item_seq'] = item_seq
59 |         self.batch_dict['seq_len'] = seq_len
60 |         self.batch_dict['target'] = target
61 |         return self.batch_dict
62 |
63 |
64 | class SequentialDataset(BaseSequentialDataset):
65 | def __init__(self, config, data_pair, additional_data_dict=None, train=True):
66 | super(SequentialDataset, self).__init__(config, data_pair, additional_data_dict, train)
67 |
68 |
69 | class CL4SRecDataset(BaseSequentialDataset):
70 | def __init__(self, config, data_pair, additional_data_dict=None, train=True):
71 | super(CL4SRecDataset, self).__init__(config, data_pair, additional_data_dict, train)
72 | self.mask_id = self.num_items
73 | self.aug_types = config.aug_types
74 | self.n_views = 2
75 | self.augmentations = []
76 |
77 | self.load_augmentor()
78 |
79 | def load_augmentor(self):
80 | for aug in self.aug_types:
81 | if aug == 'mask':
82 | self.augmentations.append(Mask(gamma=self.config.mask_ratio, mask_id=self.mask_id))
83 | else:
84 | self.augmentations.append(AUGMENTATIONS[aug](getattr(self.config, f'{aug}_ratio')))
85 |
86 | def __getitem__(self, index):
87 | # for eval and test
88 | if not self.train:
89 | return self.get_SRtask_input(index)
90 |
91 | # for training
92 | # contrast learning augmented views
93 | item_seq = self.item_seq[index]
94 | target = self.label[index]
95 | aug_type = np.random.choice([i for i in range(len(self.augmentations))],
96 | size=self.n_views, replace=True)
97 | aug_seq_1 = self.augmentations[aug_type[0]](item_seq)
98 | aug_seq_2 = self.augmentations[aug_type[1]](item_seq)
99 |
100 | aug_seq_1 = aug_seq_1[-self.max_len:]
101 | aug_seq_2 = aug_seq_2[-self.max_len:]
102 |
103 | aug_len_1 = len(aug_seq_1)
104 | aug_len_2 = len(aug_seq_2)
105 |
106 | aug_seq_1 = aug_seq_1 + [0] * (self.max_len - len(aug_seq_1))
107 | aug_seq_2 = aug_seq_2 + [0] * (self.max_len - len(aug_seq_2))
108 | assert len(aug_seq_1) == self.max_len
109 | assert len(aug_seq_2) == self.max_len
110 |
111 | # recommendation sequences
112 | seq_len = len(item_seq) if len(item_seq) < self.max_len else self.max_len
113 | item_seq = item_seq[-self.max_len:]
114 | item_seq = item_seq + (self.max_len - seq_len) * [0]
115 |
116 | assert len(item_seq) == self.max_len
117 |
118 | cur_tensors = (torch.tensor(item_seq, dtype=torch.long),
119 | torch.tensor(seq_len, dtype=torch.long),
120 | torch.tensor(target, dtype=torch.long),
121 | torch.tensor(aug_seq_1, dtype=torch.long),
122 | torch.tensor(aug_seq_2, dtype=torch.long),
123 | torch.tensor(aug_len_1, dtype=torch.long),
124 | torch.tensor(aug_len_2, dtype=torch.long))
125 |
126 | return cur_tensors
127 |
128 | def collate_fn(self, x):
129 | if not self.train:
130 | return self.basic_SR_collate_fn(x)
131 |
132 | item_seq, seq_len, target, aug_seq_1, aug_seq_2, aug_len_1, aug_len_2 = default_collate(x)
133 |
134 |         self.batch_dict['item_seq'] = item_seq
135 |         self.batch_dict['seq_len'] = seq_len
136 |         self.batch_dict['target'] = target
137 |         self.batch_dict['aug_seq_1'] = aug_seq_1
138 |         self.batch_dict['aug_seq_2'] = aug_seq_2
139 |         self.batch_dict['aug_len_1'] = aug_len_1
140 |         self.batch_dict['aug_len_2'] = aug_len_2
141 |
142 |         return self.batch_dict
143 |
144 |
145 | class MISPPretrainDataset(Dataset):
146 | """
147 | Masked Item & Segment Prediction (MISP)
148 | """
149 |
150 | def __init__(self, config, data_pair, additional_data_dict=None):
151 | self.mask_id = config.num_items
152 | self.mask_ratio = config.mask_ratio
153 | self.num_items = config.num_items + 1
154 | self.config = config
155 | self.item_seq = data_pair[0]
156 | self.label = data_pair[1]
157 | self.max_len = config.max_len
158 | self.long_sequence = []
159 |
160 | for seq in self.item_seq:
161 | self.long_sequence.extend(seq)
162 |
163 | def __len__(self):
164 | return len(self.item_seq)
165 |
166 | def __getitem__(self, index):
167 | sequence = self.item_seq[index] # pos_items
168 |
169 | # Masked Item Prediction
170 | masked_item_sequence = []
171 | neg_items = []
172 | pos_items = sequence
173 |
174 | item_set = set(sequence)
175 | for item in sequence[:-1]:
176 | prob = random.random()
177 | if prob < self.mask_ratio:
178 | masked_item_sequence.append(self.mask_id)
179 | neg_items.append(neg_sample(item_set, self.num_items))
180 | else:
181 | masked_item_sequence.append(item)
182 | neg_items.append(item)
183 | # add mask at the last position
184 | masked_item_sequence.append(self.mask_id)
185 | neg_items.append(neg_sample(item_set, self.num_items))
186 |
187 | assert len(masked_item_sequence) == len(sequence)
188 | assert len(pos_items) == len(sequence)
189 | assert len(neg_items) == len(sequence)
190 |
191 | # Segment Prediction
192 | if len(sequence) < 2:
193 | masked_segment_sequence = sequence
194 | pos_segment = sequence
195 | neg_segment = sequence
196 | else:
197 | sample_length = random.randint(1, len(sequence) // 2)
198 | start_id = random.randint(0, len(sequence) - sample_length)
199 | neg_start_id = random.randint(0, len(self.long_sequence) - sample_length)
200 | pos_segment = sequence[start_id: start_id + sample_length]
201 | neg_segment = self.long_sequence[neg_start_id:neg_start_id + sample_length]
202 | masked_segment_sequence = sequence[:start_id] + [self.mask_id] * sample_length + sequence[
203 | start_id + sample_length:]
204 | pos_segment = [self.mask_id] * start_id + pos_segment + [self.mask_id] * (
205 | len(sequence) - (start_id + sample_length))
206 | neg_segment = [self.mask_id] * start_id + neg_segment + [self.mask_id] * (
207 | len(sequence) - (start_id + sample_length))
208 |
209 | assert len(masked_segment_sequence) == len(sequence)
210 | assert len(pos_segment) == len(sequence)
211 | assert len(neg_segment) == len(sequence)
212 |
213 | # crop sequence
214 | masked_item_sequence = masked_item_sequence[-self.max_len:]
215 | pos_items = pos_items[-self.max_len:]
216 | neg_items = neg_items[-self.max_len:]
217 | masked_segment_sequence = masked_segment_sequence[-self.max_len:]
218 | pos_segment = pos_segment[-self.max_len:]
219 | neg_segment = neg_segment[-self.max_len:]
220 |
221 | # padding sequence
222 | pad_len = self.max_len - len(sequence)
223 | masked_item_sequence = masked_item_sequence + [0] * pad_len
224 | pos_items = pos_items + [0] * pad_len
225 | neg_items = neg_items + [0] * pad_len
226 | masked_segment_sequence = masked_segment_sequence + [0] * pad_len
227 | pos_segment = pos_segment + [0] * pad_len
228 | neg_segment = neg_segment + [0] * pad_len
229 |
230 | assert len(masked_item_sequence) == self.max_len
231 | assert len(pos_items) == self.max_len
232 | assert len(neg_items) == self.max_len
233 | assert len(masked_segment_sequence) == self.max_len
234 | assert len(pos_segment) == self.max_len
235 | assert len(neg_segment) == self.max_len
236 |
237 | cur_tensors = (torch.tensor(masked_item_sequence, dtype=torch.long),
238 | torch.tensor(pos_items, dtype=torch.long),
239 | torch.tensor(neg_items, dtype=torch.long),
240 | torch.tensor(masked_segment_sequence, dtype=torch.long),
241 | torch.tensor(pos_segment, dtype=torch.long),
242 | torch.tensor(neg_segment, dtype=torch.long))
243 | return cur_tensors
244 |
245 | def collate_fn(self, x):
246 | tensor_dict = {}
247 | tensor_list = [torch.cat([x[i][j].unsqueeze(0) for i in range(len(x))], 0).long() for j in range(len(x[0]))]
248 | masked_item_sequence, pos_items, neg_items, \
249 | masked_segment_sequence, pos_segment, neg_segment = tensor_list
250 |
251 | tensor_dict['masked_item_sequence'] = masked_item_sequence
252 | tensor_dict['pos_items'] = pos_items
253 | tensor_dict['neg_items'] = neg_items
254 | tensor_dict['masked_segment_sequence'] = masked_segment_sequence
255 | tensor_dict['pos_segment'] = pos_segment
256 | tensor_dict['neg_segment'] = neg_segment
257 |
258 | return tensor_dict
259 |
260 |
261 | class MIMPretrainDataset(Dataset):
262 | def __init__(self, config, data_pair, additional_data_dict=None):
263 | self.config = config
264 | self.aug_types = config.aug_types
265 | self.mask_id = config.num_items
266 |
267 | self.item_seq = data_pair[0]
268 | self.label = data_pair[1]
269 | self.max_len = config.max_len
270 | self.n_views = 2
271 | self.augmentations = []
272 | self.load_augmentor()
273 |
274 | def load_augmentor(self):
275 | for aug in self.aug_types:
276 | if aug == 'mask':
277 | self.augmentations.append(Mask(gamma=self.config.mask_ratio, mask_id=self.mask_id))
278 | else:
279 | self.augmentations.append(AUGMENTATIONS[aug](getattr(self.config, f'{aug}_ratio')))
280 |
281 | def __getitem__(self, index):
282 | aug_type = np.random.choice([i for i in range(len(self.augmentations))],
283 | size=self.n_views, replace=False)
284 | item_seq = self.item_seq[index]
285 | aug_seq_1 = self.augmentations[aug_type[0]](item_seq)
286 | aug_seq_2 = self.augmentations[aug_type[1]](item_seq)
287 |
288 | aug_seq_1 = aug_seq_1[-self.max_len:]
289 | aug_seq_2 = aug_seq_2[-self.max_len:]
290 |
291 | aug_len_1 = len(aug_seq_1)
292 | aug_len_2 = len(aug_seq_2)
293 |
294 | aug_seq_1 = aug_seq_1 + [0] * (self.max_len - len(aug_seq_1))
295 | aug_seq_2 = aug_seq_2 + [0] * (self.max_len - len(aug_seq_2))
296 | assert len(aug_seq_1) == self.max_len
297 | assert len(aug_seq_2) == self.max_len
298 |
299 | aug_seq_tensors = (torch.tensor(aug_seq_1, dtype=torch.long),
300 | torch.tensor(aug_seq_2, dtype=torch.long),
301 | torch.tensor(aug_len_1, dtype=torch.long),
302 | torch.tensor(aug_len_2, dtype=torch.long))
303 |
304 | return aug_seq_tensors
305 |
306 | def __len__(self):
307 | '''
308 |         the n_views augmented views of one sequence count as a single sample
309 | '''
310 | return len(self.item_seq)
311 |
312 | def collate_fn(self, x):
313 | tensor_dict = {}
314 | tensor_list = [torch.cat([x[i][j].unsqueeze(0) for i in range(len(x))], 0).long() for j in range(len(x[0]))]
315 | aug_seq_1, aug_seq_2, aug_len_1, aug_len_2 = tensor_list
316 |
317 | tensor_dict['aug_seq_1'] = aug_seq_1
318 | tensor_dict['aug_seq_2'] = aug_seq_2
319 | tensor_dict['aug_len_1'] = aug_len_1
320 | tensor_dict['aug_len_2'] = aug_len_2
321 |
322 | return tensor_dict
323 |
324 |
325 | class PIDPretrainDataset(Dataset):
326 | def __init__(self, config, data_pair, additional_data_dict=None):
327 | self.num_items = config.num_items
328 | self.item_seq = data_pair[0]
329 | self.label = data_pair[1]
330 | self.config = config
331 | self.max_len = config.max_len
332 | self.pseudo_ratio = config.pseudo_ratio
333 |
334 | def __getitem__(self, index):
335 | item_seq = self.item_seq[index]
336 | pseudo_seq = []
337 | target = []
338 |
339 | for item in item_seq:
340 | if random.random() < self.pseudo_ratio:
341 | pseudo_item = neg_sample(item_seq, self.num_items)
342 | pseudo_seq.append(pseudo_item)
343 | target.append(0)
344 | else:
345 | pseudo_seq.append(item)
346 | target.append(1)
347 |
348 | pseudo_seq = pseudo_seq[-self.max_len:]
349 | target = target[-self.max_len:]
350 |
351 | pseudo_seq = pseudo_seq + [0] * (self.max_len - len(pseudo_seq))
352 | target = target + [0] * (self.max_len - len(target))
353 | assert len(pseudo_seq) == self.max_len
354 | assert len(target) == self.max_len
355 | pseudo_seq_tensors = (torch.tensor(pseudo_seq, dtype=torch.long),
356 | torch.tensor(target, dtype=torch.float))
357 |
358 | return pseudo_seq_tensors
359 |
360 | def __len__(self):
361 | '''
362 |         each pseudo-labeled sequence counts as one sample
363 | '''
364 | return len(self.item_seq)
365 |
366 | def collate_fn(self, x):
367 | tensor_dict = {}
368 | tensor_list = [torch.cat([x[i][j].unsqueeze(0) for i in range(len(x))], 0).long() for j in range(len(x[0]))]
369 | pseudo_seq, target = tensor_list
370 |
371 | tensor_dict['pseudo_seq'] = pseudo_seq
372 | tensor_dict['target'] = target
373 |
374 | return tensor_dict
375 |
376 |
377 |
--------------------------------------------------------------------------------
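Editor's note: a minimal sketch of feeding BaseSequentialDataset (defined above) to a PyTorch DataLoader via its custom collate_fn. The SimpleNamespace config is a hypothetical stand-in carrying only the fields the dataset reads; the project builds its real config in src/train/config.py.

from types import SimpleNamespace
from torch.utils.data import DataLoader
from src.dataset.dataset import BaseSequentialDataset

config = SimpleNamespace(num_items=10, dataset='toys', max_len=5)
data_pair = ([[1, 2, 3], [4, 5]], [6, 7])  # (item sequences, targets)

dataset = BaseSequentialDataset(config, data_pair)
loader = DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=dataset.collate_fn)
batch = next(iter(loader))
# batch['item_seq']: LongTensor [2, 5], right-padded with 0 -> [[1, 2, 3, 0, 0], [4, 5, 0, 0, 0]]
# batch['seq_len']:  LongTensor [2] -> [3, 2]
# batch['target']:   LongTensor [2] -> [6, 7]
--------------------------------------------------------------------------------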
/src/evaluation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/evaluation/__init__.py
--------------------------------------------------------------------------------
/src/evaluation/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/evaluation/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/src/evaluation/__pycache__/estimator.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/evaluation/__pycache__/estimator.cpython-38.pyc
--------------------------------------------------------------------------------
/src/evaluation/__pycache__/metrics.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/evaluation/__pycache__/metrics.cpython-38.pyc
--------------------------------------------------------------------------------
/src/evaluation/estimator.py:
--------------------------------------------------------------------------------
1 | import re
2 | import torch
3 | from tqdm import tqdm
4 | from src.evaluation.metrics import Metric
5 | from src.utils.utils import batch_to_device
6 |
7 |
8 | class Estimator:
9 | def __init__(self, config):
10 | self.popularity = None
11 | self.config = config
12 | self.metrics = config.metric
13 | self.k_list = config.k
14 | self.dev = config.device
15 | self.metric_res_dict = {}
16 | self.eval_loss = 0.
17 | self.max_k = max(self.k_list)
18 | self.split_type = config.split_type
19 | self.eval_mode = config.eval_mode
20 | self.neg_size = 0
21 | if self.eval_mode != 'full':
22 | self.neg_size = int(re.findall(r'\d+', self.eval_mode)[0])
23 | self.eval_mode = self.eval_mode[:3]
24 | self._reset_metrics()
25 |
26 | def _reset_metrics(self):
27 | for metric in self.metrics:
28 | for k in self.k_list:
29 | self.metric_res_dict[f'{metric}@{k}'] = 0.
30 | self.eval_loss = 0.
31 |
32 | def load_item_popularity(self, pop):
33 | self.popularity = torch.tensor(pop, dtype=torch.float, device=self.dev)
34 |
35 | @torch.no_grad()
36 | def evaluate(self, eval_loader, model):
37 | model.eval()
38 | self._reset_metrics()
39 |
40 | eval_sample_size = len(eval_loader.dataset)
41 | eval_iter = tqdm(enumerate(eval_loader), total=len(eval_loader))
42 | eval_iter.set_description(f'do evaluation...')
43 | for _, batch_dict in eval_iter:
44 | batch_to_device(batch_dict, self.dev)
45 | logits = model(batch_dict)
46 | model_loss = model.get_loss(batch_dict, logits)
47 | logits = self.neg_sample_select(batch_dict, logits)
48 | self.calc_metrics(logits, batch_dict['target'])
49 | self.eval_loss += model_loss.item()
50 |
51 | for metric in self.metrics:
52 | for k in self.k_list:
53 | self.metric_res_dict[f'{metric}@{k}'] /= float(eval_sample_size)
54 |
55 | eval_loss = self.eval_loss / float(len(eval_loader))
56 |
57 | return self.metric_res_dict, eval_loss
58 |
59 | @torch.no_grad()
60 | def test(self, test_loader, model):
61 | model.eval()
62 | self._reset_metrics()
63 |
64 | test_sample_size = len(test_loader.dataset)
65 | test_iter = tqdm(enumerate(test_loader), total=len(test_loader))
66 | test_iter.set_description(f'do test...')
67 | for _, batch_dict in test_iter:
68 | batch_to_device(batch_dict, self.dev)
69 | logits = model(batch_dict)
70 | logits = self.neg_sample_select(batch_dict, logits)
71 | self.calc_metrics(logits, batch_dict['target'])
72 |
73 | for metric in self.metrics:
74 | for k in self.k_list:
75 | self.metric_res_dict[f'{metric}@{k}'] /= float(test_sample_size)
76 | return self.metric_res_dict
77 |
78 | def calc_metrics(self, prediction, target):
79 | _, topk_index = torch.topk(prediction, self.max_k, -1) # [batch, max_k]
80 |         topk_score = torch.gather(prediction, index=topk_index, dim=-1)
81 |         idx_sorted = torch.argsort(topk_score, dim=-1, descending=True)
82 |         top_k_item_sorted = torch.gather(topk_index, index=idx_sorted, dim=-1)
83 |
84 | for metric in self.metrics:
85 | for k in self.k_list:
86 | score = getattr(Metric, f'{metric.upper()}')(top_k_item_sorted, target, k)
87 | self.metric_res_dict[f'{metric}@{k}'] += score
88 |
89 | def calc_metrics_(self, prediction, target):
90 | _, topk_index = torch.topk(prediction, self.max_k, -1) # [batch, max_k]
91 |         topk_score = torch.gather(prediction, index=topk_index, dim=-1)
92 |         idx_sorted = torch.argsort(topk_score, dim=-1, descending=True)
93 |         max_k_item_sorted = torch.gather(topk_index, index=idx_sorted, dim=-1)
94 |
95 | metric_res_dict = {}
96 | for metric in self.metrics:
97 | for k in self.k_list:
98 | score = getattr(Metric, f'{metric.upper()}')(max_k_item_sorted, target, k)
99 |                 metric_res_dict[f'{metric}@{k}'] = score  # fixed: '+=' raised KeyError on a fresh dict
100 |
101 | return metric_res_dict
102 |
103 | def neg_sample_select(self, data_dict, prediction):
104 | if self.eval_mode == 'full':
105 | return prediction
106 | item_seq, target = data_dict['item_seq'], data_dict['target']
107 | # sample negative items
108 | target = target.unsqueeze(-1)
109 | mask_item = torch.cat([item_seq, target], dim=-1) # [batch, max_len + 1]
110 |
111 | if self.eval_mode == 'uni':
112 | sample_prob = torch.ones_like(prediction, device=self.dev) / prediction.size(-1)
113 | elif self.eval_mode == 'pop':
114 | if self.popularity.size(0) != prediction.size(-1): # ignore mask item
115 | self.popularity = torch.cat([self.popularity, torch.zeros((1,)).to(self.dev)], -1)
116 | sample_prob = self.popularity.unsqueeze(0).repeat(prediction.size(0), 1)
117 | else:
118 |             raise NotImplementedError('Choose eval_mode from [full, popxxx, unixxx]')
119 | sample_prob = sample_prob.scatter(dim=-1, index=mask_item, value=0.)
120 | neg_item = torch.multinomial(sample_prob, self.neg_size) # [batch, neg_size]
121 | # mask non-rank items
122 | rank_item = torch.cat([neg_item, target], dim=-1) # [batch, neg_size + 1]
123 | mask = torch.ones_like(prediction, device=self.dev).bool()
124 | mask = mask.scatter(dim=-1, index=rank_item, value=False)
125 |         masked_pred = prediction.masked_fill(mask, float('-inf'))  # -inf (not 0.) so non-candidates can never outrank negative logits
126 |
127 | return masked_pred
128 |
129 |
--------------------------------------------------------------------------------
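Editor's note: the heart of Estimator.neg_sample_select is "rank the target against a small set of sampled negatives and suppress everything else". Below is a standalone toy walk-through of that masking logic; all tensors are invented for illustration, and -inf masking follows the fixed method above.

import torch

torch.manual_seed(0)
batch, num_items, neg_size = 2, 8, 3
prediction = torch.randn(batch, num_items)           # model scores over all items
item_seq = torch.tensor([[1, 2, 0], [3, 0, 0]])      # padded interaction histories
target = torch.tensor([4, 5]).unsqueeze(-1)          # [batch, 1]

sample_prob = torch.ones(batch, num_items)
sample_prob.scatter_(-1, torch.cat([item_seq, target], -1), 0.)  # never sample seen items or the target
neg_item = torch.multinomial(sample_prob, neg_size)  # [batch, neg_size]

rank_item = torch.cat([neg_item, target], -1)        # the only items allowed in the ranking
mask = torch.ones_like(prediction, dtype=torch.bool)
mask.scatter_(-1, rank_item, False)
masked_pred = prediction.masked_fill(mask, float('-inf'))
print(masked_pred.topk(3, dim=-1).indices)           # top-3 drawn from candidates only
--------------------------------------------------------------------------------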
/src/evaluation/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | class Metric:
6 | @staticmethod
7 | def HIT(prediction, target, k):
8 | """
9 | calculate Hit-Ratio (HR) @ k
10 | :param prediction: [batch, max_k], sorted along dim -1
11 | :param target: [batch]
12 | :param k: scalar
13 |         :return: summed hit score over the batch (averaged later by the caller)
14 | """
15 | prediction, target = Metric.process(prediction, target, k)
16 | hit = ((prediction - target) == 0).sum(dim=-1).double()
17 | hit = hit.sum().item()
18 | return hit
19 |
20 | @staticmethod
21 | def NDCG(prediction, target, k):
22 | """
23 | calculate Normalized Discounted Cumulative Gain (NDCG) @ k.
24 |         Note that with a single target item the Ideal DCG (IDCG) is the same for all users, so it can be ignored.
25 | :param prediction: [batch, max_k], sorted along dim -1
26 | :param target: [batch]
27 | :param k: scalar
28 |         :return: summed NDCG score over the batch (averaged later by the caller)
29 | """
30 | prediction, target = Metric.process(prediction, target, k)
31 | hit = ((prediction - target) == 0).sum(dim=-1).double() # [batch_size]
32 | row, col = ((prediction - target) == 0.).nonzero(as_tuple=True) # [hit_size]
33 | ndcg = hit.scatter(index=row, src=1. / torch.log2(col + 2).double(), dim=-1)
34 | ndcg = ndcg.sum().item()
35 | return ndcg
36 |
37 | @staticmethod
38 | def MRR(prediction, target, k):
39 | """
40 | calculate Mean Reciprocal Rank (MRR) @ k
41 | :param prediction: [batch, max_k], sorted along dim -1
42 | :param target: [batch]
43 | :param k: scalar
44 |         :return: summed MRR score over the batch (averaged later by the caller)
45 | """
46 | prediction, target = Metric.process(prediction, target, k)
47 | hit = ((prediction - target) == 0).sum(dim=-1).double() # [batch_size]
48 | row, col = ((prediction - target) == 0.).nonzero(as_tuple=True) # [hit_size]
49 | mrr = hit.scatter(index=row, src=1. / (col + 1).double(), dim=-1)
50 | mrr = mrr.sum().item()
51 | return mrr
52 |
53 | @staticmethod
54 | def RECALL(prediction, target, k):
55 | """
56 |         calculate Recall @ k; with a single target item this equals Hit-Ratio in the sequential recommendation (SR) setting
57 | :param prediction: [batch, max_k], sorted along dim -1
58 | :param target: [batch]
59 | :param k: scalar
60 |         :return: summed recall score over the batch (averaged later by the caller)
61 | """
62 | return Metric.HIT(prediction, target, k)
63 |
64 | @staticmethod
65 | def process(prediction, target, k):
66 | if k < prediction.size(-1):
67 | prediction = prediction[:, :k] # [batch, k]
68 | target = target.unsqueeze(-1) # [batch, 1]
69 | return prediction, target
70 |
71 |
72 | if __name__ == '__main__':
73 | a = torch.arange(12).view(3, -1)
74 | a[1, -1] = 0
75 | print(a)
76 | hit = (a == 0).sum(dim=-1).float()
77 | hit_index, rank = (a == 0).nonzero(as_tuple=True)
78 | print(hit_index, rank)
79 | score = torch.scatter(hit, index=hit_index, src=1. / torch.log2(rank + 2), dim=-1)
80 | print(score)
81 | score = score.mean().cpu().numpy()
82 | print(score)
--------------------------------------------------------------------------------
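Editor's note: a worked toy example of the Metric class above. Both HIT and NDCG return batch sums; the Estimator divides by the sample size afterwards.

import torch
from src.evaluation.metrics import Metric

prediction = torch.tensor([[7, 3, 9, 1],    # user 0: target 3 sits at rank 2 (index 1)
                           [2, 8, 5, 6]])   # user 1: target 4 misses the top-4
target = torch.tensor([3, 4])

print(Metric.HIT(prediction, target, k=4))   # 1.0 -> one hit, summed (not averaged)
print(Metric.NDCG(prediction, target, k=4))  # 1 / log2(1 + 2) ≈ 0.6309 for the index-1 hit
--------------------------------------------------------------------------------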
/src/model/__init__.py:
--------------------------------------------------------------------------------
1 | from src.model.sequential_recommender import *
2 | from src.model.cl_based_seq_recommender import *
3 |
--------------------------------------------------------------------------------
/src/model/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/model/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/src/model/__pycache__/abstract_recommeder.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/model/__pycache__/abstract_recommeder.cpython-38.pyc
--------------------------------------------------------------------------------
/src/model/__pycache__/data_augmentation.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/model/__pycache__/data_augmentation.cpython-38.pyc
--------------------------------------------------------------------------------
/src/model/__pycache__/loss.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/model/__pycache__/loss.cpython-38.pyc
--------------------------------------------------------------------------------
/src/model/__pycache__/sequential_encoder.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/model/__pycache__/sequential_encoder.cpython-38.pyc
--------------------------------------------------------------------------------
/src/model/abstract_recommeder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class AbstractRecommender(nn.Module):
7 | def __init__(self, config):
8 | super(AbstractRecommender, self).__init__()
9 | self.num_items = config.num_items
10 | self.loss_type = config.loss_type
11 | self.max_len = config.max_len
12 | self.dev = config.device
13 | self.cross_entropy = nn.CrossEntropyLoss()
14 |
15 | def forward(self, data_dict: dict):
16 | """
17 | Args:
18 | data_dict: dict
19 | """
20 | pass
21 |
22 | def train_forward(self, data_dict: dict):
23 | """
24 | Args:
25 | data_dict: dict
26 | """
27 | logits = self.forward(data_dict)
28 | return self.get_loss(data_dict, logits)
29 |
30 | def load_basic_SR_data(self, data_dict):
31 | return data_dict['item_seq'], data_dict['seq_len'], data_dict['target']
32 |
33 | def get_loss(self, data_dict, logits, item_seq=None, target=None):
34 | if item_seq is None:
35 | item_seq = data_dict['item_seq']
36 | if target is None:
37 | target = data_dict['target']
38 |
39 | if self.loss_type.upper() == 'BCE':
40 | neg_item = self.get_negative_items(item_seq, target, num_samples=1)
41 | pos_score = torch.gather(logits, -1, target.unsqueeze(-1))
42 | neg_score = torch.gather(logits, -1, neg_item)
43 |             loss = -torch.mean(F.logsigmoid(pos_score).squeeze(-1) + F.logsigmoid(-neg_score).sum(-1))  # stable form of log(1 - sigmoid); squeeze avoids a [B, B] broadcast
44 | elif self.loss_type.upper() == 'BPR': # BPR loss
45 | neg_item = self.get_negative_items(item_seq, target, num_samples=1)
46 | pos_score = torch.gather(logits, -1, target.unsqueeze(-1))
47 | neg_score = torch.gather(logits, -1, neg_item)
48 | loss = -torch.mean(F.logsigmoid(pos_score - neg_score))
49 | elif self.loss_type.upper() == 'CE': # CE loss
50 | # prediction = F.softmax(logits, -1)
51 | loss = self.cross_entropy(logits, target)
52 | # pos_score = torch.gather(prediction, -1, target.unsqueeze(-1))
53 | # loss = -torch.mean(torch.log(pos_score))
54 | else:
55 | loss = torch.zeros((1,)).to(self.dev)
56 | return loss
57 |
58 | def gather_index(self, output, index):
59 | """
60 | :param output: [batch, max_len, H]
61 | :param index: [batch]
62 |         :return: [batch, H]
63 | """
64 | gather_index = index.view(-1, 1, 1).repeat(1, 1, output.size(-1))
65 | gather_output = output.gather(dim=1, index=gather_index)
66 |         return gather_output.squeeze(1)  # squeeze only the gathered dim; bare squeeze() breaks batch_size == 1
67 |
68 | def get_target_and_length(self, target_info):
69 | """
70 | :param target_info: target information dict
71 | :return:
72 | """
73 | target = target_info['target'] # [batch, ep_len]
74 | try:
75 | tar_len = target_info['target_len']
76 |         except KeyError:
77 | raise Exception(f"{self.__class__.__name__} requires target sequences, set use_tar_seq to true in "
78 | f"experimental settings")
79 | return target, tar_len
80 |
81 | def get_negative_items(self, input_item, target, num_samples=1):
82 | """
83 | :param input_item: [batch_size, max_len]
84 |         :param target: [batch_size], positive items excluded from sampling
85 |         :return: neg_items, [batch_size, num_samples]
86 | """
87 | sample_prob = torch.ones(input_item.size(0), self.num_items, device=target.device)
88 | sample_prob.scatter_(-1, input_item, 0.)
89 | sample_prob.scatter_(-1, target.unsqueeze(-1), 0.)
90 | neg_items = torch.multinomial(sample_prob, num_samples)
91 |
92 | return neg_items
93 |
94 | def pack_to_batch(self, prediction):
95 | if prediction.dim() < 2:
96 | prediction = prediction.unsqueeze(0)
97 | return prediction
98 |
99 | def calc_total_params(self):
100 | """
101 | Calculate Total Parameters
102 | :return: number of parameters
103 | """
104 | return sum([p.nelement() for p in self.parameters()])
105 |
106 | def load_pretrain_model(self, pretrain_model):
107 | """
108 | load pretraining model, default: load all parameters
109 | """
110 | self.load_state_dict(pretrain_model.state_dict())
111 | del pretrain_model
112 |
113 | def MISP_pretrain_forward(self, data_dict: dict):
114 | pass
115 |
116 | def MIM_pretrain_forward(self, data_dict: dict):
117 | pass
118 |
119 | def PID_pretrain_forward(self, data_dict: dict):
120 | pass
121 |
122 |
123 | class AbstractRLRecommender(AbstractRecommender):
124 | def __init__(self, config):
125 | super(AbstractRLRecommender, self).__init__(config)
126 |
127 | def sample_neg_action(self, masked_action, neg_size):
128 | """
129 | :param masked_action: [batch, max_len]
130 | :return: neg_action, [batch, neg_size]
131 | """
132 | sample_prob = torch.ones(masked_action.size(0), self.num_items, device=masked_action.device)
133 | sample_prob = sample_prob.scatter(-1, masked_action, 0.)
134 | neg_action = torch.multinomial(sample_prob, neg_size)
135 |
136 | return neg_action
137 |
138 | def state_transfer(self, pre_item_seq, action, seq_len):
139 | """
140 | Parameters
141 | ----------
142 | pre_item_seq: torch.LongTensor, [batch_size, max_len]
143 | action: torch.LongTensor, [batch_size]
144 | seq_len: torch.LongTensor, [batch_size]
145 |
146 | Return
147 | ------
148 | next_state_seq: torch.LongTensor, [batch_size, max_len]
149 | """
150 | new_item_seq = pre_item_seq.clone().detach()
151 | action = action.unsqueeze(-1)
152 | seq_len = seq_len.unsqueeze(-1)
153 | max_len = pre_item_seq.size(1)
154 |
155 | padding_col = torch.zeros_like(action, dtype=torch.long, device=action.device)
156 | new_item_seq = torch.cat([new_item_seq, padding_col], -1)
157 | new_item_seq = new_item_seq.scatter(-1, seq_len, action)
158 | new_item_seq = new_item_seq[:, 1:]
159 |
160 | new_seq_len = seq_len.squeeze() + 1
161 | new_seq_len[new_seq_len > max_len] = max_len
162 |
163 | return new_item_seq, new_seq_len
164 |
--------------------------------------------------------------------------------
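Editor's note: a worked toy example of the BPR branch of AbstractRecommender.get_loss, with hand-picked logits and one sampled negative per user (all tensors invented for illustration):

import torch
import torch.nn.functional as F

logits = torch.tensor([[0.2, 1.5, -0.3, 0.7],
                       [1.1, 0.1, 0.9, -0.5]])        # [batch, num_items]
target = torch.tensor([1, 0])                          # positive item per user
neg_item = torch.tensor([[2], [3]])                    # one sampled negative per user

pos_score = torch.gather(logits, -1, target.unsqueeze(-1))  # [[1.5], [1.1]]
neg_score = torch.gather(logits, -1, neg_item)              # [[-0.3], [-0.5]]
bpr_loss = -torch.mean(F.logsigmoid(pos_score - neg_score))
print(bpr_loss)  # -(log sigmoid(1.8) + log sigmoid(1.6)) / 2 ≈ 0.168
--------------------------------------------------------------------------------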
/src/model/cl_based_seq_recommender/__init__.py:
--------------------------------------------------------------------------------
1 | from .cl4srec import CL4SRec, CL4SRec_config
2 | from .iocrec import IOCRec, IOCRec_config
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/src/model/cl_based_seq_recommender/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/model/cl_based_seq_recommender/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/src/model/cl_based_seq_recommender/__pycache__/cl4srec.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/model/cl_based_seq_recommender/__pycache__/cl4srec.cpython-38.pyc
--------------------------------------------------------------------------------
/src/model/cl_based_seq_recommender/__pycache__/iocrec.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/model/cl_based_seq_recommender/__pycache__/iocrec.cpython-38.pyc
--------------------------------------------------------------------------------
/src/model/cl_based_seq_recommender/cl4srec.py:
--------------------------------------------------------------------------------
1 | """
2 | 2021. Contrastive Learning for Sequential Recommendation. In
3 | SIGIR ’21: Proceedings of the 44th International ACM SIGIR Conference on
4 | Research and Development in Information Retrieval (SIGIR’21)
5 | """
6 |
7 | import sys
8 | import torch.nn.functional as F
9 | from src.model.abstract_recommeder import AbstractRecommender
10 | import argparse
11 | import torch
12 | import torch.nn as nn
13 | from src.model.sequential_encoder import Transformer
14 | from src.model.loss import InfoNCELoss
15 | from src.utils.utils import HyperParamDict
16 |
17 |
18 | class CL4SRec(AbstractRecommender):
19 | def __init__(self, config, additional_data_dict):
20 | super(CL4SRec, self).__init__(config)
21 | self.mask_id = self.num_items
22 | self.num_items = self.num_items + 1
23 |
24 | self.do_pretraining = config.do_pretraining
25 | self.embed_size = config.embed_size
26 | self.initializer_range = config.initializer_range
27 | self.aug_views = 2
28 | self.tao = config.tao
29 | self.all_hidden = config.all_hidden
30 |         self.lamda = config.lamda  # cl loss weight, used when jointly training
31 |
32 | self.item_embedding = nn.Embedding(self.num_items, self.embed_size, padding_idx=0)
33 | self.position_embedding = nn.Embedding(self.max_len, self.embed_size)
34 | self.input_layer_norm = nn.LayerNorm(self.embed_size, eps=config.layer_norm_eps)
35 | self.input_dropout = nn.Dropout(config.hidden_dropout)
36 | self.trm_encoder = Transformer(embed_size=self.embed_size,
37 | ffn_hidden=config.ffn_hidden,
38 | num_blocks=config.num_blocks,
39 | num_heads=config.num_heads,
40 | attn_dropout=config.attn_dropout,
41 | hidden_dropout=config.hidden_dropout,
42 | layer_norm_eps=config.layer_norm_eps)
43 | # self.proj = nn.Linear(self.embed_size, self.embed_size)
44 | self.nce_loss = InfoNCELoss(temperature=self.tao,
45 | similarity_type='dot')
46 | # self.target_dropout = nn.Dropout(0.2)
47 | # self.temperature = 0.07
48 | self.apply(self._init_weights)
49 |
50 |     def _init_weights(self, module):
51 |         if isinstance(module, (nn.Embedding, nn.Linear)):
52 |             module.weight.data.normal_(mean=0.0, std=self.initializer_range)
53 |             if isinstance(module, nn.Linear) and module.bias is not None:
54 |                 module.bias.data.zero_()  # the former trailing elif was unreachable
55 |         elif isinstance(module, nn.LayerNorm):
56 |             module.weight.data.fill_(1.)
57 |             module.bias.data.zero_()
58 |
59 | def train_forward(self, data_dict):
60 | logits = self.forward(data_dict)
61 | rec_loss = self.get_loss(data_dict, logits)
62 |
63 | if self.do_pretraining:
64 | return rec_loss
65 |
66 | # jointly training
67 | cl_loss = self.MIM_pretrain_forward(data_dict)
68 |
69 | return rec_loss + self.lamda * cl_loss
70 |
71 | def forward(self, data_dict):
72 | item_seq, seq_len, _ = self.load_basic_SR_data(data_dict)
73 | seq_embedding = self.seq_encoding(item_seq, seq_len)
74 | candidates = self.item_embedding.weight
75 |
76 | # add extra normalization
77 | # seq_embedding = F.normalize(seq_embedding)
78 | # candidates = self.target_dropout(candidates)
79 | # candidates = F.normalize(candidates)
80 |
81 | # logits = seq_embedding @ candidates.t() / self.temperature
82 | logits = seq_embedding @ candidates.t()
83 |
84 | return logits
85 |
86 | def position_encoding(self, item_input):
87 | seq_embedding = self.item_embedding(item_input)
88 | position = torch.arange(self.max_len, device=item_input.device).unsqueeze(0)
89 | position = position.expand_as(item_input).long()
90 | pos_embedding = self.position_embedding(position)
91 | seq_embedding += pos_embedding
92 | seq_embedding = self.input_layer_norm(seq_embedding)
93 | seq_embedding = self.input_dropout(seq_embedding)
94 |
95 | return seq_embedding
96 |
97 | def seq_encoding(self, item_seq, seq_len, return_all=False):
98 | seq_embedding = self.position_encoding(item_seq)
99 | out_seq_embedding = self.trm_encoder(item_seq, seq_embedding)
100 | if not return_all:
101 | out_seq_embedding = self.gather_index(out_seq_embedding, seq_len - 1)
102 | return out_seq_embedding
103 |
104 | def MIM_pretrain_forward(self, data_dict):
105 | aug_seq_1, aug_len_1 = data_dict['aug_seq_1'], data_dict['aug_len_1']
106 | aug_seq_2, aug_len_2 = data_dict['aug_seq_2'], data_dict['aug_len_2']
107 | # aug_seq_2, aug_len_2 = data_dict['item_seq'], data_dict['seq_len']
108 | # sequence encoding, [batch,embed_size]
109 | aug_seq_encoding_1 = self.seq_encoding(aug_seq_1, aug_len_1, return_all=self.all_hidden)
110 | aug_seq_encoding_2 = self.seq_encoding(aug_seq_2, aug_len_2, return_all=self.all_hidden)
111 |
112 | # add normalization
113 | # aug_seq_encoding_1 = F.normalize(aug_seq_encoding_1)
114 | # aug_seq_encoding_2 = F.normalize(aug_seq_encoding_2)
115 |
116 | # aug_seq_encoding_1 = self.proj(aug_seq_encoding_1)
117 | # aug_seq_encoding_2 = self.proj(aug_seq_encoding_2)
118 | cl_loss = self.nce_loss(aug_seq_encoding_1, aug_seq_encoding_2)
119 | # B = aug_seq_1.size(0)
120 | # aug_seq_encoding_1 = aug_seq_encoding_1.view(B, -1)
121 | # aug_seq_encoding_2 = aug_seq_encoding_2.view(B, -1)
122 | # dot_loss = (aug_seq_encoding_1 * aug_seq_encoding_2).sum(-1).mean()
123 |
124 | return cl_loss
125 |
126 |
127 | def CL4SRec_config():
128 | parser = HyperParamDict('CL4SRec-Pretraining default hyper-parameters')
129 | parser.add_argument('--model', default='CL4SRec', type=str)
130 | parser.add_argument('--model_type', default='Sequential', choices=['Sequential', 'Knowledge'])
131 | # Contrast Learning Hyper Params
132 | parser.add_argument('--do_pretraining', action='store_false', help='if do pretraining')
133 | parser.add_argument('--training_fashion', default='pretraining', choices=['pretraining', 'jointly_training'])
134 | parser.add_argument('--pretraining_task', default='MIM', type=str, choices=['MISP', 'MIM', 'PID'],
135 | help='pretraining task:' \
136 | 'MISP: Mask Item Prediction and Mask Segment Prediction' \
137 | 'MIM: Mutual Information Maximization' \
138 | 'PID: Pseudo Item Discrimination'
139 | )
140 | parser.add_argument('--aug_types', default=['crop', 'mask', 'reorder'], help='augmentation types')
141 | parser.add_argument('--crop_ratio', default=0.7, type=float,
142 | help='Crop augmentation: proportion of cropped subsequence in origin sequence')
143 | parser.add_argument('--mask_ratio', default=0.5, type=float,
144 | help='Mask augmentation: proportion of masked items in origin sequence')
145 | parser.add_argument('--reorder_ratio', default=0.8, type=float,
146 | help='Reorder augmentation: proportion of reordered subsequence in origin sequence')
147 | parser.add_argument('--all_hidden', action='store_false', help='all hidden states for cl')
148 | parser.add_argument('--tao', default=1., type=float, help='temperature for softmax')
149 | parser.add_argument('--lamda', default=0.1, type=float,
150 | help='weight for contrast learning loss, only work when jointly training')
151 |
152 | # Transformer
153 | parser.add_argument('--embed_size', default=128, type=int)
154 | parser.add_argument('--ffn_hidden', default=512, type=int, help='hidden dim for feed forward network')
155 | parser.add_argument('--num_blocks', default=2, type=int, help='number of transformer block')
156 | parser.add_argument('--num_heads', default=2, type=int, help='number of head for multi-head attention')
157 | parser.add_argument('--hidden_dropout', default=0.5, type=float, help='hidden state dropout rate')
158 | parser.add_argument('--attn_dropout', default=0., type=float, help='dropout rate for attention')
159 | parser.add_argument('--layer_norm_eps', default=1e-12, type=float, help='transformer layer norm eps')
160 | parser.add_argument('--initializer_range', default=0.02, type=float, help='transformer params initialize range')
161 |
162 | parser.add_argument('--loss_type', default='CE', type=str, choices=['CE', 'BPR', 'BCE', 'CUSTOM'])
163 |
164 | return parser
165 |
166 |
167 | if __name__ == '__main__':
168 | a = torch.arange(5)
169 | b = a.clone() + 10
170 | c = torch.stack([a, b], 0)
171 | print(c)
172 | print(c.transpose(0, 1))
173 |
--------------------------------------------------------------------------------
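Editor's note: a generic re-implementation of the two-view InfoNCE objective that MIM_pretrain_forward above delegates to. This sketch assumes dot-product similarity and in-batch negatives; the repository's own InfoNCELoss (src/model/loss.py) may differ in details such as masking or similarity type.

import torch
import torch.nn.functional as F

def info_nce(z1, z2, temperature=1.0):
    # z1, z2: [batch, dim] encodings of two augmented views of the same sequences
    z = torch.cat([z1, z2], dim=0)              # [2B, dim]
    sim = z @ z.t() / temperature               # pairwise dot-product similarities
    sim.fill_diagonal_(float('-inf'))           # a view is not its own positive
    batch = z1.size(0)
    # the positive of row i is the other view of the same sequence
    pos = torch.cat([torch.arange(batch) + batch, torch.arange(batch)])
    return F.cross_entropy(sim, pos)

z1, z2 = torch.randn(4, 8), torch.randn(4, 8)
print(info_nce(z1, z2, temperature=1.0))
--------------------------------------------------------------------------------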
/src/model/cl_based_seq_recommender/iocrec.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2023/7/7
3 | # @Author : Chenglong Shi
4 | # @Email : hiderulo@163.com
5 |
6 | r"""
7 | IOCRec
8 | ################################################
9 |
10 | Reference:
11 | Xuewei Li et al., "Multi-Intention Oriented Contrastive Learning for Sequential Recommendation" in WSDM 2023.
12 |
13 | """
14 |
15 | import copy
16 | import math
17 | import sys
18 | import torch.nn.functional as F
19 | from src.model.abstract_recommeder import AbstractRecommender
20 | import argparse
21 | import torch
22 | import torch.nn as nn
23 | from src.model.sequential_encoder import Transformer
24 | from src.model.loss import InfoNCELoss
25 | from src.utils.utils import HyperParamDict
26 |
27 |
28 | class IOCRec(AbstractRecommender):
29 | def __init__(self, config, additional_data_dict):
30 | super(IOCRec, self).__init__(config)
31 | self.mask_id = self.num_items
32 | self.num_items = self.num_items + 1
33 | self.embed_size = config.embed_size
34 | self.initializer_range = config.initializer_range
35 | self.aug_views = 2
36 | self.tao = config.tao
37 | self.all_hidden = config.all_hidden
38 | self.lamda = config.lamda
39 | self.k_intention = config.k_intention
40 |
41 | self.item_embedding = nn.Embedding(self.num_items, self.embed_size, padding_idx=0)
42 | self.position_embedding = nn.Embedding(self.max_len, self.embed_size)
43 | self.input_layer_norm = nn.LayerNorm(self.embed_size, eps=config.layer_norm_eps)
44 | self.input_dropout = nn.Dropout(config.hidden_dropout)
45 | self.local_encoder = Transformer(embed_size=self.embed_size,
46 | ffn_hidden=config.ffn_hidden,
47 | num_blocks=config.num_blocks,
48 | num_heads=config.num_heads,
49 | attn_dropout=config.attn_dropout,
50 | hidden_dropout=config.hidden_dropout,
51 | layer_norm_eps=config.layer_norm_eps)
52 | self.global_seq_encoder = GlobalSeqEncoder(embed_size=self.embed_size,
53 | max_len=self.max_len,
54 | dropout=config.hidden_dropout)
55 | self.disentangle_encoder = DisentangleEncoder(k_intention=self.k_intention,
56 | embed_size=self.embed_size,
57 | max_len=self.max_len)
58 | self.nce_loss = InfoNCELoss(temperature=self.tao,
59 | similarity_type='dot')
60 | self.apply(self._init_weights)
61 |
62 |     def _init_weights(self, module):
63 |         if isinstance(module, nn.LayerNorm):
64 |             module.weight.data.fill_(1.)
65 |             module.bias.data.zero_()
66 |         elif isinstance(module, (nn.Embedding, nn.Linear)):
67 |             module.weight.data.normal_(mean=0.0, std=self.initializer_range)
68 |             # check bias here: a trailing elif on nn.Linear would be unreachable,
69 |             # since the branch above already matches nn.Linear
70 |             if isinstance(module, nn.Linear) and module.bias is not None:
71 |                 module.bias.data.zero_()
70 |
71 | def train_forward(self, data_dict):
72 | _, _, target = self.load_basic_SR_data(data_dict)
73 | aug_seq_1, aug_len_1 = data_dict['aug_seq_1'], data_dict['aug_len_1']
74 | aug_seq_2, aug_len_2 = data_dict['aug_seq_2'], data_dict['aug_len_2']
75 |
76 | # rec task
77 | max_logits = self.forward(data_dict)
78 | rec_loss = self.cross_entropy(max_logits, target)
79 |
80 | # cl task
81 | B = target.size(0)
82 | aug_local_emb_1 = self.local_seq_encoding(aug_seq_1, aug_len_1, return_all=self.all_hidden)
83 | aug_global_emb_1 = self.global_seq_encoding(aug_seq_1, aug_len_1)
84 | disentangled_intention_1 = self.disentangle_encoder(aug_local_emb_1, aug_global_emb_1, aug_len_1)
85 | disentangled_intention_1 = disentangled_intention_1.view(B * self.k_intention, -1) # [B * K, L * D]
86 |
87 | aug_local_emb_2 = self.local_seq_encoding(aug_seq_2, aug_len_2, return_all=self.all_hidden)
88 | aug_global_emb_2 = self.global_seq_encoding(aug_seq_2, aug_len_2)
89 | disentangled_intention_2 = self.disentangle_encoder(aug_local_emb_2, aug_global_emb_2, aug_len_2)
90 | disentangled_intention_2 = disentangled_intention_2.view(B * self.k_intention, -1) # [B * K, L * D]
91 |
92 | cl_loss = self.nce_loss(disentangled_intention_1, disentangled_intention_2)
93 |
94 | return rec_loss + self.lamda * cl_loss
95 |
96 | def forward(self, data_dict):
97 | item_seq, seq_len, _ = self.load_basic_SR_data(data_dict)
98 | local_seq_emb = self.local_seq_encoding(item_seq, seq_len, return_all=True) # [B, L, D]
99 | global_seq_emb = self.global_seq_encoding(item_seq, seq_len)
100 | disentangled_intention_emb = self.disentangle_encoder(local_seq_emb, global_seq_emb, seq_len) # [B, K, L, D]
101 |
102 | gather_index = seq_len.view(-1, 1, 1, 1).repeat(1, self.k_intention, 1, self.embed_size)
103 |         disentangled_intention_emb = disentangled_intention_emb.gather(2, gather_index - 1).squeeze(2)  # [B, K, D]
104 | candidates = self.item_embedding.weight.unsqueeze(0) # [1, num_items, D]
105 | logits = disentangled_intention_emb @ candidates.permute(0, 2, 1) # [B, K, num_items]
106 | max_logits, _ = torch.max(logits, 1)
107 |
108 | return max_logits
109 |
110 | def position_encoding(self, item_input):
111 | seq_embedding = self.item_embedding(item_input)
112 | position = torch.arange(self.max_len, device=item_input.device).unsqueeze(0)
113 | position = position.expand_as(item_input).long()
114 | pos_embedding = self.position_embedding(position)
115 | seq_embedding += pos_embedding
116 | seq_embedding = self.input_layer_norm(seq_embedding)
117 | seq_embedding = self.input_dropout(seq_embedding)
118 |
119 | return seq_embedding
120 |
121 | def local_seq_encoding(self, item_seq, seq_len, return_all=False):
122 | seq_embedding = self.position_encoding(item_seq)
123 | out_seq_embedding = self.local_encoder(item_seq, seq_embedding)
124 | if not return_all:
125 | out_seq_embedding = self.gather_index(out_seq_embedding, seq_len - 1)
126 | return out_seq_embedding
127 |
128 | def global_seq_encoding(self, item_seq, seq_len):
129 | return self.global_seq_encoder(item_seq, seq_len, self.item_embedding)
130 |
131 |
132 | class GlobalSeqEncoder(nn.Module):
133 | def __init__(self, embed_size, max_len, dropout=0.5):
134 | super(GlobalSeqEncoder, self).__init__()
135 | self.embed_size = embed_size
136 | self.max_len = max_len
137 | self.dropout = nn.Dropout(dropout)
138 |
139 | self.Q_s = nn.Parameter(torch.randn(max_len, embed_size))
140 | self.K_linear = nn.Linear(embed_size, embed_size)
141 | self.V_linear = nn.Linear(embed_size, embed_size)
142 |
143 | def forward(self, item_seq, seq_len, item_embeddings):
144 | """
145 | Args:
146 | item_seq (tensor): [B, L]
147 | seq_len (tensor): [B]
148 |             item_embeddings (nn.Embedding): item embedding table whose weight is [num_items, D]
149 |
150 | Returns:
151 | global_seq_emb: [B, L, D]
152 | """
153 | item_emb = item_embeddings(item_seq) # [B, L, D]
154 | item_key = self.K_linear(item_emb)
155 | item_value = self.V_linear(item_emb)
156 |
157 | attn_logits = self.Q_s @ item_key.permute(0, 2, 1) # [B, L, L]
158 | attn_score = F.softmax(attn_logits, -1)
159 | global_seq_emb = self.dropout(attn_score @ item_value)
160 |
161 | return global_seq_emb
162 |
163 |
164 | class DisentangleEncoder(nn.Module):
165 | def __init__(self, k_intention, embed_size, max_len):
166 | super(DisentangleEncoder, self).__init__()
167 | self.embed_size = embed_size
168 |
169 | self.intentions = nn.Parameter(torch.randn(k_intention, embed_size))
170 |         self.pos_fai = nn.Embedding(max_len, embed_size)  # learnable position embedding (phi)
171 |         self.rou = nn.Parameter(torch.randn(embed_size, ))  # learnable query bias vector (rho)
172 | self.W = nn.Linear(embed_size, embed_size)
173 | self.layer_norm_1 = nn.LayerNorm(embed_size)
174 | self.layer_norm_2 = nn.LayerNorm(embed_size)
175 | self.layer_norm_3 = nn.LayerNorm(embed_size)
176 | self.layer_norm_4 = nn.LayerNorm(embed_size)
177 | self.layer_norm_5 = nn.LayerNorm(embed_size)
178 |
179 | def forward(self, local_item_emb, global_item_emb, seq_len):
180 | """
181 | Args:
182 | local_item_emb: [B, L, D]
183 | global_item_emb: [B, L, D]
184 | seq_len: [B]
185 | Returns:
186 | disentangled_intention_emb: [B, K, L, D]
187 | """
188 | local_disen_emb = self.intention_disentangling(local_item_emb, seq_len)
189 |         global_disen_emb = self.intention_disentangling(global_item_emb, seq_len)
190 |         disentangled_intention_emb = local_disen_emb + global_disen_emb
191 |
192 | return disentangled_intention_emb
193 |
194 | def item2IntentionScore(self, item_emb):
195 | """
196 | Args:
197 | item_emb: [B, L, D]
198 | Returns:
199 | score: [B, L, K]
200 | """
201 | item_emb_norm = self.layer_norm_1(item_emb) # [B, L, D]
202 | intention_norm = self.layer_norm_2(self.intentions).unsqueeze(0) # [1, K, D]
203 |
204 | logits = item_emb_norm @ intention_norm.permute(0, 2, 1) # [B, L, K]
205 | score = F.softmax(logits / math.sqrt(self.embed_size), -1)
206 |
207 | return score
208 |
209 | def item2AttnWeight(self, item_emb, seq_len):
210 | """
211 | Args:
212 | item_emb: [B, L, D]
213 | seq_len: [B]
214 | Returns:
215 | score: [B, L]
216 | """
217 | B, L = item_emb.size(0), item_emb.size(1)
218 | dev = item_emb.device
219 | item_query_row = item_emb[torch.arange(B).to(dev), seq_len - 1] # [B, D]
220 | item_query_row += self.pos_fai(seq_len - 1) + self.rou
221 | item_query = self.layer_norm_3(item_query_row).unsqueeze(1) # [B, 1, D]
222 |
223 | pos_fai_tensor = self.pos_fai(torch.arange(L).to(dev)).unsqueeze(0) # [1, L, D]
224 | item_key_hat = self.layer_norm_4(item_emb + pos_fai_tensor)
225 | item_key = item_key_hat + torch.relu(self.W(item_key_hat))
226 |
227 | logits = item_query @ item_key.permute(0, 2, 1) # [B, 1, L]
228 |         logits = logits.squeeze(1) / math.sqrt(self.embed_size)
229 | score = F.softmax(logits, -1)
230 |
231 | return score
232 |
233 | def intention_disentangling(self, item_emb, seq_len):
234 | """
235 | Args:
236 |             item_emb: [B, L, D]
237 | seq_len: [B]
238 | Returns:
239 | item_disentangled_emb: [B, K, L, D]
240 | """
241 | # get score
242 | item2intention_score = self.item2IntentionScore(item_emb)
243 | item_attn_weight = self.item2AttnWeight(item_emb, seq_len)
244 |
245 | # get disentangled embedding
246 | score_fuse = item2intention_score * item_attn_weight.unsqueeze(-1) # [B, L, K]
247 | score_fuse = score_fuse.permute(0, 2, 1).unsqueeze(-1) # [B, K, L, 1]
248 | item_emb_k = item_emb.unsqueeze(1) # [B, 1, L, D]
249 | disentangled_item_emb = self.layer_norm_5(score_fuse * item_emb_k)
250 | return disentangled_item_emb
251 |
252 |
253 | def IOCRec_config():
254 | parser = HyperParamDict('IOCRec default hyper-parameters')
255 | parser.add_argument('--model', default='IOCRec', type=str)
256 | parser.add_argument('--model_type', default='Sequential', choices=['Sequential', 'Knowledge'])
257 | # Contrast Learning Hyper Params
258 | parser.add_argument('--aug_types', default=['crop', 'mask', 'reorder'], help='augmentation types')
259 | parser.add_argument('--crop_ratio', default=0.4, type=float,
260 | help='Crop augmentation: proportion of cropped subsequence in origin sequence')
261 | parser.add_argument('--mask_ratio', default=0.3, type=float,
262 | help='Mask augmentation: proportion of masked items in origin sequence')
263 | parser.add_argument('--reorder_ratio', default=0.2, type=float,
264 | help='Reorder augmentation: proportion of reordered subsequence in origin sequence')
265 |     parser.add_argument('--all_hidden', action='store_false', help='use all hidden states for contrastive learning (default: enabled)')
266 | parser.add_argument('--tao', default=1., type=float, help='temperature for softmax')
267 |     parser.add_argument('--lamda', default=0.1, type=float,
268 |                         help='weight for contrastive learning loss; only takes effect when jointly training')
269 | parser.add_argument('--k_intention', default=4, type=int, help='number of disentangled intention')
270 | # Transformer
271 | parser.add_argument('--embed_size', default=64, type=int)
272 | parser.add_argument('--ffn_hidden', default=128, type=int, help='hidden dim for feed forward network')
273 | parser.add_argument('--num_blocks', default=3, type=int, help='number of transformer block')
274 | parser.add_argument('--num_heads', default=2, type=int, help='number of head for multi-head attention')
275 | parser.add_argument('--hidden_dropout', default=0.5, type=float, help='hidden state dropout rate')
276 | parser.add_argument('--attn_dropout', default=0.5, type=float, help='dropout rate for attention')
277 | parser.add_argument('--layer_norm_eps', default=1e-12, type=float, help='transformer layer norm eps')
278 | parser.add_argument('--initializer_range', default=0.02, type=float, help='transformer params initialize range')
279 |
280 | parser.add_argument('--loss_type', default='CE', type=str, choices=['CE', 'BPR', 'BCE', 'CUSTOM'])
281 |
282 | return parser
283 |
284 |
285 | if __name__ == '__main__':
286 |     a = torch.randn(20, 5, 50, 32)
287 |     seq_len = torch.randint(1, 51, (20,)).long()  # valid lengths in [1, 50], one per batch row
288 |     gather_index = (seq_len - 1).view(-1, 1, 1, 1).repeat(1, 5, 1, 32)
289 |     res = torch.gather(a, 2, gather_index)
290 |
291 |     print(res.size())  # torch.Size([20, 5, 1, 32])
292 |
--------------------------------------------------------------------------------
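The prediction step in IOCRec.forward above gathers the last valid hidden state per intention and then takes the item-wise maximum over the K intention channels. Below is a minimal shape-level sketch with random stand-in tensors (all sizes are illustrative, not the model's defaults):

import torch

B, K, L, D, num_items = 4, 4, 50, 64, 1000
disentangled = torch.randn(B, K, L, D)                         # [B, K, L, D]
seq_len = torch.randint(1, L + 1, (B,))                        # valid lengths in [1, L]
gather_index = (seq_len - 1).view(-1, 1, 1, 1).expand(B, K, 1, D)
last_hidden = disentangled.gather(2, gather_index).squeeze(2)  # [B, K, D]
candidates = torch.randn(num_items, D)                         # stand-in embedding table
logits = last_hidden @ candidates.t()                          # [B, K, num_items]
max_logits, _ = logits.max(dim=1)                              # strongest intention per item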
/src/model/data_augmentation.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import math
3 | import random
4 | import numpy as np
5 | import torch
7 |
8 |
9 | class AbstractDataAugmentor:
10 | def __init__(self, aug_ratio):
11 | self.aug_ratio = aug_ratio
12 |
13 | def transform(self, item_seq, seq_len):
14 | """
15 | :param item_seq: torch.LongTensor, [batch, max_len]
16 | :param seq_len: torch.LongTensor, [batch]
17 | :return: aug_seq: torch.LongTensor, [batch, max_len]
18 | """
19 | raise NotImplementedError
20 |
21 |
22 | class CropAugmentor(AbstractDataAugmentor):
23 | """
24 | Torch version.
25 | """
26 |
27 | def __init__(self, aug_ratio):
28 | super(CropAugmentor, self).__init__(aug_ratio)
29 |
30 | def transform(self, item_seq, seq_len):
31 | """
32 | :param item_seq: torch.LongTensor, [batch, max_len]
33 | :param seq_len: torch.LongTensor, [batch]
34 | :return: aug_seq: torch.LongTensor, [batch, max_len]
35 | """
36 | max_len = item_seq.size(-1)
37 | aug_seq_len = torch.ceil(seq_len * self.aug_ratio).long()
38 | # get start index
39 | index = torch.arange(max_len, device=seq_len.device)
40 | index = index.expand_as(item_seq)
41 | up_bound = (seq_len - aug_seq_len).unsqueeze(-1)
42 | prob = torch.zeros_like(item_seq, device=seq_len.device).float()
43 | prob[index <= up_bound] = 1.
44 | start_index = torch.multinomial(prob, 1)
45 | # item indices in subsequence
46 | gather_index = torch.arange(max_len, device=seq_len.device)
47 | gather_index = gather_index.expand_as(item_seq)
48 | gather_index = gather_index + start_index
49 | max_seq_len = aug_seq_len.unsqueeze(-1)
50 | gather_index[index >= max_seq_len] = 0
51 | # augmented subsequence
52 | aug_seq = torch.gather(item_seq, -1, gather_index).long()
53 | aug_seq[index >= max_seq_len] = 0
54 |
55 | return aug_seq, aug_seq_len
56 |
57 |
58 | class MaskDAugmentor(AbstractDataAugmentor):
59 | """
60 | Torch version.
61 | """
62 |
63 | def __init__(self, aug_ratio, mask_id=0):
64 | super(MaskDAugmentor, self).__init__(aug_ratio)
65 | self.mask_id = mask_id
66 |
67 | def transform(self, item_seq, seq_len):
68 | """
69 | :param item_seq: torch.LongTensor, [batch, max_len]
70 | :param seq_len: torch.LongTensor, [batch]
71 | :return: aug_seq: torch.LongTensor, [batch, max_len]
72 | """
73 | max_len = item_seq.size(-1)
74 | aug_seq = item_seq.clone()
75 | aug_seq_len = seq_len.clone()
76 | # get mask item id
77 | mask_item_size = math.ceil(max_len * self.aug_ratio)
78 | prob = torch.ones_like(item_seq, device=seq_len.device).float()
79 | masked_item_id = torch.multinomial(prob, mask_item_size)
80 | # mask
81 | aug_seq = aug_seq.scatter(-1, masked_item_id, self.mask_id)
82 |
83 | return aug_seq, aug_seq_len
84 |
85 |
86 | class ReorderAugmentor(AbstractDataAugmentor):
87 | """
88 | Torch version.
89 | """
90 |
91 | def __init__(self, aug_ratio):
92 | super(ReorderAugmentor, self).__init__(aug_ratio)
93 |
94 | def transform(self, item_seq, seq_len):
95 | """
96 | Parameters
97 | ----------
98 | item_seq: [batch_size, max_len]
99 | seq_len: [batch_size]
100 |
101 | Returns
102 | -------
103 | aug_item_seq: [batch_size, max_len]
104 | aug_seq_len: [batch_size]
105 | """
106 | dev = item_seq.device
107 | batch_size, max_len = item_seq.shape
108 |
109 | # get start position
110 | reorder_size = (seq_len * self.aug_ratio).ceil().long().unsqueeze(-1) # [B, 1]
111 | position_tensor = torch.arange(max_len).repeat(batch_size, 1).to(dev) # [B, L]
112 | sample_prob = (position_tensor <= seq_len.unsqueeze(-1) - reorder_size).bool().float()
113 | start_index = torch.multinomial(sample_prob, num_samples=1) # [B, 1]
114 |
115 | # get reorder item mask
116 | reorder_item_mask = (start_index <= position_tensor) & (position_tensor < start_index + reorder_size)
117 |
118 | # reorder operation
119 | tmp_reorder_tensor = torch.zeros_like(item_seq).long().to(dev)
120 | tmp_reorder_tensor[reorder_item_mask] = item_seq[reorder_item_mask]
121 | rand_index = torch.randperm(max_len)
122 | tmp_reorder_tensor = tmp_reorder_tensor[:, rand_index]
123 |
124 | # put reordered items back
125 | aug_item_seq = item_seq.clone()
126 | aug_item_seq[reorder_item_mask] = tmp_reorder_tensor[tmp_reorder_tensor > 0]
127 |
128 | return aug_item_seq, seq_len
129 |
130 |
131 | class RepeatAugmentor(AbstractDataAugmentor):
132 | """
133 | Torch version.
134 | """
135 |
136 | def __init__(self, aug_ratio):
137 | super(RepeatAugmentor, self).__init__(aug_ratio)
138 |
139 | def transform_1(self, item_seq, seq_len):
140 | """
141 | Parameters
142 | ----------
143 | item_seq: [batch_size, max_len]
144 | seq_len: [batch_size]
145 |
146 | Returns
147 | -------
148 | aug_item_seq: [batch_size, max_len]
149 | aug_seq_len: [batch_size]
150 | """
151 | dev = item_seq.device
152 | batch_size, max_len = item_seq.shape
153 | sample_size = math.ceil(max_len * self.aug_ratio)
154 |
155 | # get reordered index
156 | valid_pos_tensor = torch.arange(max_len).repeat(batch_size, 1).to(dev)
157 | valid_pos_tensor[item_seq == 0] = -1
158 | rand_index = torch.randperm(max_len).to(dev)
159 | rand_pos_tensor = valid_pos_tensor[:, rand_index]
160 | reordered_index = item_seq.clone()
161 | reordered_index[reordered_index > 0] = rand_pos_tensor[rand_pos_tensor > 0]
162 |
163 |         # break off: the reordered-index computation above is unused in this variant
164 |
165 | # sample repeat elements
166 | sample_prob = torch.ones_like(item_seq).float().to(dev)
167 | repeat_pos = torch.multinomial(sample_prob, num_samples=sample_size)
168 | sorted_repeat_pos, _ = torch.sort(repeat_pos, dim=-1)
169 | repeat_element = item_seq.gather(dim=-1, index=sorted_repeat_pos).long()
170 |
171 | # augmented item sequences
172 | padding_seq = torch.zeros((batch_size, sample_size)).to(dev)
173 | aug_item_seq = torch.cat([item_seq, padding_seq], dim=-1).long() # [B, L + L']
174 |
175 | # get insert position mask of sampled item
176 | valid_pos_tensor = torch.arange(sample_size).unsqueeze(0).to(dev)
177 | ele_insert_pos = sorted_repeat_pos + valid_pos_tensor
178 | insert_mask = torch.zeros_like(aug_item_seq).to(item_seq)
179 | insert_mask = insert_mask.scatter(dim=-1, index=ele_insert_pos, value=1).bool()
180 |
181 | # set elements
182 | aug_item_seq[insert_mask] = repeat_element.flatten()
183 | aug_item_seq[~insert_mask] = item_seq.flatten()
184 |
185 | # slice
186 | full_seq_size = (seq_len == max_len).bool().sum()
187 | new_seq_len = (aug_item_seq > 0).sum(-1)
188 | _, sorted_idx = torch.sort(new_seq_len, dim=0, descending=True)
189 | _, restore_idx = torch.sort(sorted_idx, dim=0)
190 |
191 | sorted_aug_seq = aug_item_seq[sorted_idx]
192 | full_aug_seq = sorted_aug_seq[:full_seq_size]
193 | full_aug_seq = full_aug_seq[:, -max_len:]
194 | non_full_aug_seq = sorted_aug_seq[full_seq_size:]
195 | non_full_aug_seq = non_full_aug_seq[:, :max_len]
196 | aug_item_seq = torch.cat([full_aug_seq, non_full_aug_seq], dim=0)
197 |
198 | # restore position
199 | aug_item_seq = aug_item_seq[restore_idx]
200 | aug_seq_len = (aug_item_seq > 0).bool().sum(-1)
201 |
202 | return aug_item_seq, aug_seq_len
203 |
204 | def transform(self, item_seq, seq_len):
205 | """
206 | Parameters
207 | ----------
208 | item_seq: [batch_size, max_len]
209 | seq_len: [batch_size]
210 |
211 | Returns
212 | -------
213 | aug_item_seq: [batch_size, max_len]
214 | aug_seq_len: [batch_size]
215 | """
216 | dev = item_seq.device
217 | batch_size, max_len = item_seq.shape
218 | sample_size = int(item_seq.size(-1) * self.aug_ratio)
219 |
220 | # sample repeat elements
221 | sample_prob = torch.ones_like(item_seq).float().to(dev)
222 | repeat_pos = torch.multinomial(sample_prob, num_samples=sample_size)
223 | sorted_repeat_pos, _ = torch.sort(repeat_pos, dim=-1)
224 | repeat_element = item_seq.gather(dim=-1, index=sorted_repeat_pos).long()
225 |
226 | # augmented item sequences
227 | padding_seq = torch.zeros((batch_size, sample_size)).to(dev)
228 | aug_item_seq = torch.cat([item_seq, padding_seq], dim=-1).long() # [B, L + L']
229 |
230 | # get insert position mask of sampled item
231 | position_tensor = torch.arange(sample_size).unsqueeze(0).to(dev)
232 | ele_insert_pos = sorted_repeat_pos + position_tensor
233 | insert_mask = torch.zeros_like(aug_item_seq).to(item_seq)
234 | insert_mask = insert_mask.scatter(dim=-1, index=ele_insert_pos, value=1).bool()
235 |
236 | # set elements
237 | aug_item_seq[insert_mask] = repeat_element.flatten()
238 | aug_item_seq[~insert_mask] = item_seq.flatten()
239 |
240 | # slice
241 | full_seq_size = (seq_len == max_len).bool().sum()
242 | new_seq_len = (aug_item_seq > 0).sum(-1)
243 | _, sorted_idx = torch.sort(new_seq_len, dim=0, descending=True)
244 | _, restore_idx = torch.sort(sorted_idx, dim=0)
245 |
246 | sorted_aug_seq = aug_item_seq[sorted_idx]
247 | full_aug_seq = sorted_aug_seq[:full_seq_size]
248 | full_aug_seq = full_aug_seq[:, -max_len:]
249 | non_full_aug_seq = sorted_aug_seq[full_seq_size:]
250 | non_full_aug_seq = non_full_aug_seq[:, :max_len]
251 | aug_item_seq = torch.cat([full_aug_seq, non_full_aug_seq], dim=0)
252 |
253 | # restore position
254 | aug_item_seq = aug_item_seq[restore_idx]
255 | aug_seq_len = (aug_item_seq > 0).bool().sum(-1)
256 |
257 | return aug_item_seq, aug_seq_len
258 |
259 |
260 | class DropAugmentor(AbstractDataAugmentor):
261 | """
262 | Torch version of item drop operation.
263 | """
264 |
265 | def __init__(self, aug_ratio):
266 | super(DropAugmentor, self).__init__(aug_ratio)
267 |
268 | def transform(self, item_seq, seq_len, drop_prob=None):
269 | """
270 | Parameters
271 | ----------
272 | item_seq: [batch_size, max_len]
273 | seq_len: [batch_size]
274 | drop_prob: [batch_size, max_len]
275 |
276 | Returns
277 | -------
278 | aug_item_seq: [batch_size, max_len]
279 | aug_seq_len: [batch_size]
280 | """
281 |
282 | dev = item_seq.device
283 | batch_size, max_len = item_seq.shape
284 | drop_size = int(item_seq.size(-1) * self.aug_ratio)
285 |
286 | # sample drop item indices
287 | if drop_prob is None:
288 | drop_prob = torch.ones_like(item_seq).float().to(dev)
289 | drop_indices = torch.multinomial(drop_prob, num_samples=drop_size) # [B, drop_size]
290 |
291 |         # fill dropped positions with 0
292 |         raw_dropped_item_seq = item_seq.scatter(-1, drop_indices, 0).long()
293 |         valid_item_mask = (raw_dropped_item_seq > 0).bool()
294 |         dropped_seq_len = valid_item_mask.sum(-1)  # [B]
295 |         position_tensor = torch.arange(max_len).repeat(batch_size, 1).to(dev)  # [B, L]
296 |         valid_pos_mask = (position_tensor < dropped_seq_len.unsqueeze(-1)).bool()
297 |
298 |         # post-process: left-align the surviving items
299 |         dropped_item_seq = torch.zeros_like(item_seq).to(dev)
300 |         dropped_item_seq[valid_pos_mask] = raw_dropped_item_seq[valid_item_mask]
301 |
302 | # avoid all 0 item
303 | empty_seq_mask = (dropped_seq_len == 0).bool()
304 | empty_seq_mask = empty_seq_mask.unsqueeze(-1).repeat(1, max_len)
305 | empty_seq_mask[:, 1:] = 0
306 | dropped_item_seq[empty_seq_mask] = item_seq[empty_seq_mask]
307 | dropped_seq_len = (dropped_item_seq > 0).sum(-1) # [B]
308 |
309 | return dropped_item_seq, dropped_seq_len
310 |
311 |
312 | class CauseCropAugmentor(AbstractDataAugmentor):
313 | """
314 | Torch version.
315 | """
316 |
317 | def __init__(self, aug_ratio):
318 | super(CauseCropAugmentor, self).__init__(aug_ratio)
319 |
320 | def transform(self, item_seq, seq_len, critical_mask=None):
321 | """
322 | :param item_seq: torch.LongTensor, [batch, max_len]
323 | :param seq_len: torch.LongTensor, [batch]
324 | :return: aug_seq: torch.LongTensor, [batch, max_len]
325 | """
326 | max_len = item_seq.size(-1)
327 | aug_seq_len = torch.ceil(seq_len * self.aug_ratio).long()
328 | # get start index
329 | index = torch.arange(max_len, device=seq_len.device)
330 | index = index.expand_as(item_seq)
331 | up_bound = (seq_len - aug_seq_len).unsqueeze(-1)
332 | prob = torch.zeros_like(item_seq, device=seq_len.device).float()
333 | prob[index <= up_bound] = 1.
334 | start_index = torch.multinomial(prob, 1)
335 | # item indices in subsequence
336 | gather_index = torch.arange(max_len, device=seq_len.device)
337 | gather_index = gather_index.expand_as(item_seq)
338 | gather_index = gather_index + start_index
339 | max_seq_len = aug_seq_len.unsqueeze(-1)
340 | gather_index[index >= max_seq_len] = 0
341 | # augmented subsequence
342 | aug_seq = torch.gather(item_seq, -1, gather_index).long()
343 | aug_seq[index >= max_seq_len] = 0
344 |
345 | return aug_seq, aug_seq_len
346 |
347 |
348 | class CauseReorderAugmentor(AbstractDataAugmentor):
349 | """
350 | Torch version.
351 | """
352 |
353 | def __init__(self, aug_ratio):
354 | super(CauseReorderAugmentor, self).__init__(aug_ratio)
355 |
356 | def transform(self, item_seq, seq_len):
357 | """
358 | Parameters
359 | ----------
360 | item_seq: [batch_size, max_len]
361 | seq_len: [batch_size]
362 |
363 | Returns
364 | -------
365 | aug_item_seq: [batch_size, max_len]
366 | aug_seq_len: [batch_size]
367 | """
368 | dev = item_seq.device
369 | batch_size, max_len = item_seq.shape
370 |
371 | # get start position
372 | reorder_size = (seq_len * self.aug_ratio).ceil().long().unsqueeze(-1) # [B, 1]
373 | position_tensor = torch.arange(max_len).repeat(batch_size, 1).to(dev) # [B, L]
374 | sample_prob = (position_tensor <= seq_len.unsqueeze(-1) - reorder_size).bool().float()
375 | start_index = torch.multinomial(sample_prob, num_samples=1) # [B, 1]
376 |
377 | # get reorder item mask
378 | reorder_item_mask = (start_index <= position_tensor) & (position_tensor < start_index + reorder_size)
379 |
380 | # reorder operation
381 | tmp_reorder_tensor = torch.zeros_like(item_seq).long().to(dev)
382 | tmp_reorder_tensor[reorder_item_mask] = item_seq[reorder_item_mask]
383 | rand_index = torch.randperm(max_len)
384 | tmp_reorder_tensor = tmp_reorder_tensor[:, rand_index]
385 |
386 | # put reordered items back
387 | aug_item_seq = item_seq.clone()
388 | aug_item_seq[reorder_item_mask] = tmp_reorder_tensor[tmp_reorder_tensor > 0]
389 |
390 | return aug_item_seq, seq_len
391 |
392 |
393 | class Crop(object):
394 | """Randomly crop a subseq from the original sequence"""
395 |
396 | def __init__(self, tao=0.2):
397 | self.tao = tao
398 |
399 | def __call__(self, sequence):
400 | # make a deep copy to avoid original sequence be modified
401 | copied_sequence = copy.deepcopy(sequence)
402 |
403 | # # add length constraints
404 | # if len(copied_sequence) < 5:
405 | # return copied_sequence
406 |
407 | sub_seq_length = int(self.tao * len(copied_sequence))
408 | # randint generate int x in range: a <= x <= b
409 | start_index = random.randint(0, len(copied_sequence) - sub_seq_length)
410 | if sub_seq_length < 1:
411 | return [copied_sequence[min(start_index, len(sequence) - 1)]]
412 | else:
413 | cropped_seq = copied_sequence[start_index:start_index + sub_seq_length]
414 | return cropped_seq
415 |
416 |
417 | class Mask(object):
418 | """Randomly mask k items given a sequence"""
419 |
420 | def __init__(self, gamma=0.7, mask_id=0):
421 | self.gamma = gamma
422 | self.mask_id = mask_id
423 |
424 | def __call__(self, sequence):
425 | # make a deep copy to avoid original sequence be modified
426 | copied_sequence = copy.deepcopy(sequence)
427 |
428 | # # add length constraints
429 | # if len(copied_sequence) < 5:
430 | # return copied_sequence
431 |
432 | mask_nums = int(self.gamma * len(copied_sequence))
433 | mask_idx = random.sample([i for i in range(len(copied_sequence))], k=mask_nums)
434 | for idx in mask_idx:
435 | copied_sequence[idx] = self.mask_id
436 | return copied_sequence
437 |
438 |
439 | class Reorder(object):
440 | """Randomly shuffle a continuous sub-sequence"""
441 |
442 | def __init__(self, beta=0.2):
443 | self.beta = beta
444 |
445 | def __call__(self, sequence):
446 | # make a deep copy to avoid original sequence be modified
447 | copied_sequence = copy.deepcopy(sequence)
448 |
449 | # # add length constraints
450 | # if len(copied_sequence) < 5:
451 | # return copied_sequence
452 |
453 | sub_seq_len = int(self.beta * len(copied_sequence))
454 | start_index = random.randint(0, len(copied_sequence) - sub_seq_len)
455 | sub_seq = copied_sequence[start_index:start_index + sub_seq_len]
456 | random.shuffle(sub_seq)
457 | reordered_seq = copied_sequence[:start_index] + sub_seq + \
458 | copied_sequence[start_index + sub_seq_len:]
459 | assert len(copied_sequence) == len(reordered_seq)
460 | return reordered_seq
461 |
462 |
463 | class Repeat(object):
464 | """Randomly repeat p% of items in sequence"""
465 |
466 | def __init__(self, p=0.2, min_rep_size=1):
467 | self.p = p # max repeat ratio
468 | self.min_rep_size = min_rep_size
469 |
470 | def __call__(self, sequence):
471 | # make a deep copy to avoid original sequence be modified
472 | copied_sequence = copy.deepcopy(sequence)
473 | max_repeat_nums = math.ceil(self.p * len(copied_sequence))
474 | repeat_nums = \
475 | random.sample([i for i in range(self.min_rep_size, max(self.min_rep_size, max_repeat_nums) + 1)], k=1)[0]
476 | repeat_idx = random.sample([i for i in range(len(copied_sequence))], k=repeat_nums)
477 | repeat_idx.sort()
478 | new_seq = []
479 | cur_idx = 0
480 | for i, item in enumerate(copied_sequence):
481 | new_seq.append(item)
482 | if cur_idx < len(repeat_idx) and i == repeat_idx[cur_idx]:
483 | new_seq.append(item)
484 | cur_idx += 1
485 | return new_seq
486 |
487 |
488 | class Drop(object):
489 | """Randomly repeat p% of items in sequence"""
490 |
491 | def __init__(self, p=0.2):
492 | self.p = p # max repeat ratio
493 |
494 | def __call__(self, sequence):
495 | # make a deep copy to avoid original sequence be modified
496 | copied_sequence = copy.deepcopy(sequence)
497 | drop_num = math.floor(self.p * len(copied_sequence))
498 | drop_idx = random.sample([i for i in range(len(copied_sequence))], k=drop_num)
499 | drop_idx.sort()
500 | new_seq = []
501 | cur_idx = 0
502 | for i, item in enumerate(copied_sequence):
503 | if cur_idx < len(drop_idx) and i == drop_idx[cur_idx]:
504 | cur_idx += 1
505 | continue
506 | new_seq.append(item)
507 | return new_seq
508 |
509 |
510 | AUGMENTATIONS = {'crop': Crop, 'mask': Mask, 'reorder': Reorder, 'repeat': Repeat, 'drop': Drop}
511 |
512 |
--------------------------------------------------------------------------------
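A short usage sketch of the two augmentation families above: the tensor-level augmentors operate on right-padded ID batches, while the list-based ones (registered in AUGMENTATIONS) operate on plain Python sequences. The item IDs below are illustrative:

import torch
from src.model.data_augmentation import CropAugmentor, AUGMENTATIONS

# Tensor-level: a right-padded batch of one sequence with 4 valid items.
item_seq = torch.tensor([[5, 9, 2, 7, 0, 0]])
seq_len = torch.tensor([4])
aug_seq, aug_len = CropAugmentor(aug_ratio=0.5).transform(item_seq, seq_len)

# List-level: callables from the registry, ratio passed positionally.
cropped = AUGMENTATIONS['crop'](0.5)([5, 9, 2, 7])
masked = AUGMENTATIONS['mask'](0.3)([5, 9, 2, 7])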
/src/model/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class InfoNCELoss(nn.Module):
7 | """
8 | Pair-wise Noise Contrastive Estimation Loss
9 | """
10 |
11 | def __init__(self, temperature, similarity_type):
12 | super(InfoNCELoss, self).__init__()
13 | self.temperature = temperature # temperature
14 | self.sim_type = similarity_type # cos or dot
15 | self.criterion = nn.CrossEntropyLoss()
16 |
17 | def forward(self, aug_hidden_view1, aug_hidden_view2, mask=None):
18 | """
19 | Args:
20 | aug_hidden_view1 (FloatTensor, [batch, max_len, dim] or [batch, dim]): augmented sequence representation1
21 |             aug_hidden_view2 (FloatTensor, [batch, max_len, dim] or [batch, dim]): augmented sequence representation2
22 |
23 |         Returns: nce_loss (FloatTensor, scalar): calculated nce loss
24 | """
25 | if aug_hidden_view1.ndim > 2:
26 | # flatten tensor
27 | aug_hidden_view1 = aug_hidden_view1.view(aug_hidden_view1.size(0), -1)
28 | aug_hidden_view2 = aug_hidden_view2.view(aug_hidden_view2.size(0), -1)
29 |
30 | if self.sim_type not in ['cos', 'dot']:
31 | raise Exception(f"Invalid similarity_type for cs loss: [current:{self.sim_type}]. "
32 | f"Please choose from ['cos', 'dot']")
33 |
34 | if self.sim_type == 'cos':
35 | sim11 = self.cosinesim(aug_hidden_view1, aug_hidden_view1)
36 | sim22 = self.cosinesim(aug_hidden_view2, aug_hidden_view2)
37 | sim12 = self.cosinesim(aug_hidden_view1, aug_hidden_view2)
38 | elif self.sim_type == 'dot':
39 | # calc similarity
40 | sim11 = aug_hidden_view1 @ aug_hidden_view1.t()
41 | sim22 = aug_hidden_view2 @ aug_hidden_view2.t()
42 | sim12 = aug_hidden_view1 @ aug_hidden_view2.t()
43 |         # mask self-similarity so it never acts as a negative (applies to both cos and dot branches)
44 |         sim11[..., range(sim11.size(0)), range(sim11.size(0))] = float('-inf')
45 |         sim22[..., range(sim22.size(0)), range(sim22.size(0))] = float('-inf')
46 |
47 | cl_logits1 = torch.cat([sim12, sim11], -1)
48 | cl_logits2 = torch.cat([sim22, sim12.t()], -1)
49 | cl_logits = torch.cat([cl_logits1, cl_logits2], 0) / self.temperature
50 | if mask is not None:
51 | cl_logits = torch.masked_fill(cl_logits, mask, float('-inf'))
52 | target = torch.arange(cl_logits.size(0)).long().to(aug_hidden_view1.device)
53 | cl_loss = self.criterion(cl_logits, target)
54 |
55 | return cl_loss
56 |
57 | def cosinesim(self, aug_hidden1, aug_hidden2):
58 | h = torch.matmul(aug_hidden1, aug_hidden2.T)
59 | h1_norm2 = aug_hidden1.pow(2).sum(dim=-1).sqrt().view(h.shape[0], 1)
60 |         h2_norm2 = aug_hidden2.pow(2).sum(dim=-1).sqrt().view(1, h.shape[1])
61 | return h / (h1_norm2 @ h2_norm2)
62 |
63 |
64 | class InfoNCELoss_2(nn.Module):
65 | """
66 | Pair-wise Noise Contrastive Estimation Loss, another implementation.
67 | """
68 |
69 | def __init__(self, temperature, similarity_type, batch_size):
70 | super(InfoNCELoss_2, self).__init__()
71 | self.tem = temperature # temperature
72 | self.sim_type = similarity_type # cos or dot
73 | self.batch_size = batch_size
74 | self.mask = self.mask_correlated_samples(self.batch_size)
75 | self.criterion = nn.CrossEntropyLoss()
76 |
77 | def forward(self, aug_hidden1, aug_hidden2):
78 | """
79 | Args:
80 | aug_hidden1 (FloatTensor, [batch, max_len, dim] or [batch, dim]): augmented sequence representation1
81 |             aug_hidden2 (FloatTensor, [batch, max_len, dim] or [batch, dim]): augmented sequence representation2
82 |
83 |         Returns: nce_loss (FloatTensor, scalar): calculated nce loss
84 | """
85 | if aug_hidden1.ndim > 2:
86 | # flatten tensor
87 | aug_hidden1 = aug_hidden1.view(aug_hidden1.size(0), -1)
88 | aug_hidden2 = aug_hidden2.view(aug_hidden2.size(0), -1)
89 |
90 | current_batch = aug_hidden1.size(0)
91 | N = 2 * current_batch
92 | all_hidden = torch.cat((aug_hidden1, aug_hidden2), dim=0) # [2*B, D]
93 |
94 | if self.sim_type == 'cos':
95 | all_hidden = F.normalize(all_hidden)
96 | sim = torch.mm(all_hidden, all_hidden.T) / self.tem
97 | # sim = F.cosine_similarity(all_hidden.unsqueeze(1), all_hidden.unsqueeze(0), dim=2) / self.tem
98 | elif self.sim_type == 'dot':
99 | sim = torch.mm(all_hidden, all_hidden.T) / self.tem
100 | else:
101 | raise Exception(f"Invalid similarity_type for cs loss: [current:{self.sim_type}]. "
102 | f"Please choose from ['cos', 'dot']")
103 |
104 | sim_i_j = torch.diag(sim, current_batch)
105 | sim_j_i = torch.diag(sim, -current_batch)
106 | positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)
107 | if self.batch_size != current_batch:
108 | mask = self.mask_correlated_samples(current_batch)
109 | else:
110 | mask = self.mask
111 | negative_samples = sim[mask].reshape(N, -1)
112 | labels = torch.zeros(N).to(positive_samples.device).long()
113 | logits = torch.cat((positive_samples, negative_samples), dim=1)
114 | nce_loss = self.criterion(logits, labels)
115 |
116 | return nce_loss
117 |
118 | def mask_correlated_samples(self, batch_size):
119 | N = 2 * batch_size
120 | mask = torch.ones((N, N)).bool()
121 | mask = mask.fill_diagonal_(0)
122 | index1 = torch.arange(batch_size) + batch_size
123 | index2 = torch.arange(batch_size)
124 | index = torch.cat([index1, index2], 0).unsqueeze(-1) # [2*B, 1]
125 | mask = torch.scatter(mask, -1, index, 0)
126 | return mask
127 |
128 |
129 | def lalign(x, y, alpha=2):
130 | return (x - y).norm(dim=-1).pow(alpha).mean()
131 |
132 |
133 | def lunif(x, t=2):
134 | sq_dlist = torch.pdist(x, p=2).pow(2)
135 | return torch.log(sq_dlist.mul(-t).exp().mean() + 1e-6)
136 |
--------------------------------------------------------------------------------
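Both InfoNCE implementations above accept either [batch, dim] or [batch, max_len, dim] views (higher-rank inputs are flattened per sample). A minimal sketch with random [batch, dim] inputs, sizes illustrative:

import torch
from src.model.loss import InfoNCELoss, InfoNCELoss_2

view1, view2 = torch.randn(8, 64), torch.randn(8, 64)
loss_a = InfoNCELoss(temperature=1.0, similarity_type='dot')(view1, view2)
loss_b = InfoNCELoss_2(temperature=1.0, similarity_type='dot', batch_size=8)(view1, view2)
print(loss_a.item(), loss_b.item())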
/src/model/sequential_encoder.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import copy
4 | import math
5 | import torch.nn.functional as F
6 |
7 |
8 | class Transformer(nn.Module):
9 |     def __init__(self, embed_size, ffn_hidden, num_blocks, num_heads, attn_dropout, hidden_dropout,
10 |                  layer_norm_eps=1e-12, bidirectional=False):
11 | super(Transformer, self).__init__()
12 | self.bidirectional = bidirectional
13 | encoder_layer = EncoderLayer(embed_size=embed_size,
14 | ffn_hidden=ffn_hidden,
15 | num_heads=num_heads,
16 | attn_dropout=attn_dropout,
17 | hidden_dropout=hidden_dropout,
18 | layer_norm_eps=layer_norm_eps)
19 | self.encoder_layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(num_blocks)])
20 |
21 | def forward(self, item_input, seq_embedding):
22 | """
23 | Only output the sequence representations of the last layer in Transformer.
24 | out_seq_embed: torch.FloatTensor, [batch_size, max_len, embed_size]
25 | """
26 | mask = self.create_mask(item_input)
27 | for layer in self.encoder_layers:
28 | seq_embedding = layer(seq_embedding, mask)
29 | return seq_embedding
30 |
31 | def create_mask(self, input_seq):
32 | """
33 | Parameters:
34 | input_seq: torch.LongTensor, [batch_size, max_len]
35 | Return:
36 | mask: torch.BoolTensor, [batch_size, 1, max_len, max_len]
37 | """
38 | mask = (input_seq != 0).bool().unsqueeze(1).unsqueeze(2) # [batch_size, 1, 1, max_len]
39 | mask = mask.expand(-1, -1, mask.size(-1), -1)
40 | if not self.bidirectional:
41 | mask = torch.tril(mask)
42 | return mask
43 |
44 | def set_attention_direction(self, bidirection=False):
45 | self.bidirectional = bidirection
46 |
47 |
48 | class EncoderLayer(nn.Module):
49 | def __init__(self, embed_size, ffn_hidden, num_heads, attn_dropout, hidden_dropout, layer_norm_eps):
50 | super(EncoderLayer, self).__init__()
51 |
52 | self.attn_layer_norm = nn.LayerNorm(embed_size, eps=layer_norm_eps)
53 | self.pff_layer_norm = nn.LayerNorm(embed_size, eps=layer_norm_eps)
54 |
55 | self.self_attention = MultiHeadAttentionLayer(embed_size, num_heads, attn_dropout)
56 | self.pff = PointWiseFeedForwardLayer(embed_size, ffn_hidden)
57 |
58 | self.hidden_dropout = nn.Dropout(hidden_dropout)
59 | self.pff_out_drop = nn.Dropout(hidden_dropout)
60 |
61 | def forward(self, input_seq, inputs_mask):
62 | """
63 | input:
64 | inputs: torch.FloatTensor, [batch_size, max_len, embed_size]
65 |             inputs_mask: torch.BoolTensor, [batch_size, 1, max_len, max_len]
66 | return:
67 | out_seq_embed: torch.FloatTensor, [batch_size, max_len, embed_size]
68 | """
69 | out_seq, att_matrix = self.self_attention(input_seq, input_seq, input_seq, inputs_mask)
70 | input_seq = self.attn_layer_norm(input_seq + self.hidden_dropout(out_seq))
71 | out_seq = self.pff(input_seq)
72 | out_seq = self.pff_layer_norm(input_seq + self.pff_out_drop(out_seq))
73 | return out_seq
74 |
75 |
76 | class MultiHeadAttentionLayer(nn.Module):
77 | def __init__(self, embed_size, nhead, attn_dropout):
78 | super(MultiHeadAttentionLayer, self).__init__()
79 | self.embed_size = embed_size
80 | self.nhead = nhead
81 |
82 | if self.embed_size % self.nhead != 0:
83 | raise ValueError(
84 | "The hidden size (%d) is not a multiple of the number of attention "
85 | "heads (%d)" % (self.embed_size, self.nhead)
86 | )
87 | self.head_dim = self.embed_size // self.nhead
88 |
89 | # Q K V input linear layer
90 | self.fc_q = nn.Linear(self.embed_size, self.embed_size)
91 | self.fc_k = nn.Linear(self.embed_size, self.embed_size)
92 | self.fc_v = nn.Linear(self.embed_size, self.embed_size)
93 |
94 | self.attn_dropout = nn.Dropout(attn_dropout)
95 | self.fc_o = nn.Linear(self.embed_size, self.embed_size)
96 | self.register_buffer('scale', torch.sqrt(torch.tensor(self.head_dim).float()))
97 |
98 | def forward(self, query, key, value, inputs_mask=None):
99 | """
100 | :param query: [query_size, max_len, embed_size]
101 | :param key: [key_size, max_len, embed_size]
102 | :param value: [key_size, max_len, embed_size]
103 | :param inputs_mask: [N, 1, max_len, max_len]
104 | :return: [N, max_len, embed_size]
105 | """
106 | batch_size = query.size(0)
107 | Q = self.fc_q(query)
108 | K = self.fc_k(key)
109 | V = self.fc_v(value)
110 |
111 | # [batch_size, n_head, max_len, head_dim]
112 | Q = Q.view(query.size(0), -1, self.nhead, self.head_dim).permute((0, 2, 1, 3))
113 | K = K.view(key.size(0), -1, self.nhead, self.head_dim).permute((0, 2, 1, 3))
114 | V = V.view(value.size(0), -1, self.nhead, self.head_dim).permute((0, 2, 1, 3))
115 |
116 | # calculate attention score
117 | energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
118 | if inputs_mask is not None:
119 | energy = energy.masked_fill(inputs_mask == 0, -1.e10)
120 |
121 | attention_prob = F.softmax(energy, dim=-1)
122 | attention_prob = self.attn_dropout(attention_prob)
123 |
124 | out = torch.matmul(attention_prob, V) # [batch_size, n_head, max_len, head_dim]
125 | out = out.permute((0, 2, 1, 3)).contiguous() # memory layout
126 | out = out.view((batch_size, -1, self.embed_size))
127 | out = self.fc_o(out)
128 | return out, attention_prob
129 |
130 |
131 | class PointWiseFeedForwardLayer(nn.Module):
132 | def __init__(self, embed_size, hidden_size):
133 | super(PointWiseFeedForwardLayer, self).__init__()
134 |
135 | self.fc1 = nn.Linear(embed_size, hidden_size)
136 | self.fc2 = nn.Linear(hidden_size, embed_size)
137 |
138 | def forward(self, inputs):
139 | out = self.fc2(F.gelu(self.fc1(inputs)))
140 | return out
141 |
--------------------------------------------------------------------------------
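A minimal shape-level sketch of the causal Transformer encoder above; the item IDs and input embeddings are random stand-ins, and dropout is set to 0 for determinism:

import torch
from src.model.sequential_encoder import Transformer

B, L, D = 2, 10, 64
item_seq = torch.randint(1, 100, (B, L))
item_seq[:, 7:] = 0                          # right padding, masked out by create_mask
seq_emb = torch.randn(B, L, D)
encoder = Transformer(embed_size=D, ffn_hidden=128, num_blocks=2, num_heads=2,
                      attn_dropout=0.0, hidden_dropout=0.0)
out = encoder(item_seq, seq_emb)             # [B, L, D], unidirectional by default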
/src/model/sequential_recommender/__init__.py:
--------------------------------------------------------------------------------
1 | from src.model.sequential_recommender.sasrec import SASRec, SASRec_config
2 |
--------------------------------------------------------------------------------
/src/model/sequential_recommender/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/model/sequential_recommender/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/src/model/sequential_recommender/__pycache__/sasrec.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/model/sequential_recommender/__pycache__/sasrec.cpython-38.pyc
--------------------------------------------------------------------------------
/src/model/sequential_recommender/sasrec.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch.nn.functional as F
3 | from src.model.abstract_recommeder import AbstractRecommender
4 | import argparse
5 | import torch
6 | import torch.nn as nn
7 | from torch.nn.init import xavier_normal_, xavier_uniform_
8 | from src.model.sequential_encoder import Transformer
9 | from src.utils.utils import HyperParamDict
10 |
11 |
12 | class SASRec(AbstractRecommender):
13 | def __init__(self, config, additional_data_dict):
14 | super(SASRec, self).__init__(config)
15 | self.embed_size = config.embed_size
16 | self.hidden_size = config.ffn_hidden
17 | self.initializer_range = config.initializer_range
18 |
19 | # module
20 | self.item_embedding = nn.Embedding(self.num_items, self.embed_size, padding_idx=0)
21 | self.position_embedding = nn.Embedding(self.max_len, self.embed_size)
22 | self.trm_encoder = Transformer(embed_size=self.embed_size,
23 | ffn_hidden=self.hidden_size,
24 | num_blocks=config.num_blocks,
25 | num_heads=config.num_heads,
26 | attn_dropout=config.attn_dropout,
27 | hidden_dropout=config.hidden_dropout,
28 | layer_norm_eps=config.layer_norm_eps)
29 |
30 | self.input_layer_norm = nn.LayerNorm(self.embed_size, eps=config.layer_norm_eps)
31 | self.dropout = nn.Dropout(config.hidden_dropout)
32 |
33 | self.apply(self._init_weights)
34 |
35 |     def _init_weights(self, module):
36 |         if isinstance(module, nn.LayerNorm):
37 |             module.weight.data.fill_(1.)
38 |             module.bias.data.zero_()
39 |         elif isinstance(module, (nn.Embedding, nn.Linear)):
40 |             module.weight.data.normal_(mean=0.0, std=self.initializer_range)
41 |             # a trailing elif on nn.Linear would be unreachable; zero the bias here
42 |             if isinstance(module, nn.Linear) and module.bias is not None:
43 |                 module.bias.data.zero_()
43 |
44 | # def train_forward(self, data_dict: dict):
45 | # item_seq, seq_len, target = self.load_basic_sr_data(data_dict)
46 | # seq_embedding = self.position_encoding(item_seq)
47 | # out_seq_embedding = self.trm_encoder(item_seq, seq_embedding)
48 | # loss = self.calc_loss_(item_seq, out_seq_embedding, target)
49 | #
50 | # return loss
51 |
52 | def forward(self, data_dict):
53 | item_seq, seq_len, _ = self.load_basic_SR_data(data_dict)
54 | seq_embedding = self.position_encoding(item_seq)
55 | out_seq_embedding = self.trm_encoder(item_seq, seq_embedding)
56 | seq_embedding = self.gather_index(out_seq_embedding, seq_len - 1)
57 |
58 | # get prediction
59 | candidates = self.item_embedding.weight
60 | logits = seq_embedding @ candidates.t()
61 |
62 | return logits
63 |
64 | def position_encoding(self, item_input):
65 | seq_embedding = self.item_embedding(item_input)
66 | position = torch.arange(self.max_len, device=item_input.device).unsqueeze(0)
67 | position = position.expand_as(item_input).long()
68 | pos_embedding = self.position_embedding(position)
69 | seq_embedding += pos_embedding
70 | seq_embedding = self.dropout(self.input_layer_norm(seq_embedding))
71 |
72 | return seq_embedding
73 |
74 | # def calc_loss_(self, item_seq, out_seq_embedding, target):
75 | # """
76 | # For no data augmentation situation.
77 | # item_seq: [B, L]
78 | # out_seq_embedding: [B, L, D]
79 | # target: [B, L]
80 | # """
81 | # embed_size = out_seq_embedding.size(-1)
82 | # valid_mask = (item_seq > 0).view(-1).bool()
83 | # out_seq_embedding = out_seq_embedding.view(-1, embed_size)
84 | # target = target.view(-1)
85 | #
86 | # candidates = self.item_embedding.weight
87 | # logits = out_seq_embedding @ candidates.transpose(0, 1)
88 | # logits = logits[valid_mask]
89 | # target = target[valid_mask]
90 | #
91 | # loss = self.cross_entropy(logits, target)
92 | #
93 | # return loss
94 |
95 |
96 | def SASRec_config():
97 | parser = HyperParamDict('SASRec default hyper-parameters')
98 | parser.add_argument('--model', default='SASRec', type=str)
99 | parser.add_argument('--model_type', default='Sequential', choices=['Sequential', 'Knowledge'])
100 | parser.add_argument('--embed_size', default=128, type=int)
101 | parser.add_argument('--ffn_hidden', default=512, type=int, help='hidden dim for feed forward network')
102 | parser.add_argument('--num_blocks', default=2, type=int, help='number of transformer block')
103 | parser.add_argument('--num_heads', default=2, type=int, help='number of head for multi-head attention')
104 | parser.add_argument('--hidden_dropout', default=0.5, type=float, help='hidden state dropout rate')
105 | parser.add_argument('--attn_dropout', default=0., type=float, help='dropout rate for attention')
106 | parser.add_argument('--layer_norm_eps', default=1e-12, type=float, help='transformer layer norm eps')
107 | parser.add_argument('--initializer_range', default=0.02, type=float, help='transformer params initialize range')
108 | parser.add_argument('--loss_type', default='CE', type=str, choices=['CE', 'BPR', 'BCE', 'CUSTOM'])
109 | return parser
110 |
--------------------------------------------------------------------------------
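SASRec's position_encoding above simply adds a learned positional table to the item embeddings before LayerNorm and dropout. A standalone shape sketch (all dimensions are illustrative):

import torch
import torch.nn as nn

num_items, max_len, D = 1000, 50, 64
item_emb = nn.Embedding(num_items, D, padding_idx=0)
pos_emb = nn.Embedding(max_len, D)
item_seq = torch.randint(1, num_items, (4, max_len))
position = torch.arange(max_len).unsqueeze(0).expand_as(item_seq)
h = item_emb(item_seq) + pos_emb(position)   # [4, 50, 64]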
/src/train/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/train/__init__.py
--------------------------------------------------------------------------------
/src/train/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/train/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/src/train/__pycache__/config.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/train/__pycache__/config.cpython-38.pyc
--------------------------------------------------------------------------------
/src/train/__pycache__/trainer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/train/__pycache__/trainer.cpython-38.pyc
--------------------------------------------------------------------------------
/src/train/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch.cuda
3 | from easydict import EasyDict
4 | from src.utils.utils import HyperParamDict
5 |
6 | EXP_HYPER_LIST = {'Data': {'dataset': None, 'data_aug': None, 'seq_filter_len': None,
7 | 'if_filter_target': None, 'max_len': None},
8 | 'Pretraining': {'do_pretraining': None, 'pretraining_task': None, 'pretraining_epoch': None,
9 | 'pretraining_batch': None, 'pretraining_lr': None, 'pretraining_l2': None},
10 | 'Training': {'epoch_num': None, 'train_batch': None,
11 | 'learning_rate': None, 'l2': None, 'patience': None,
12 | 'device': None, 'num_worker': None, 'seed': None},
13 | 'Evaluation': {'split_type': None, 'split_mode': None, 'eval_mode': None, 'metric': None, 'k': None,
14 | 'valid_metric': None, 'eval_batch': None},
15 | 'Save': {'log_save': None, 'save': None, 'model_saved': None}}
16 |
17 |
18 | def experiment_hyper_load(exp_config):
19 | hyper_types = EXP_HYPER_LIST.keys()
20 | for hyper_dict in EXP_HYPER_LIST.values():
21 | for hyper in hyper_dict.keys():
22 | hyper_dict[hyper] = getattr(exp_config, hyper)
23 | return list(hyper_types), EXP_HYPER_LIST
24 |
25 |
26 | def get_device():
27 | return 'cuda:0' if torch.cuda.is_available() else 'cpu'
28 |
29 |
30 | def get_default_config():
31 | parser = HyperParamDict()
32 | # Model
33 | parser.add_argument('--model', default='URCL4SRec')
34 | # Data
35 | parser.add_argument('--dataset', default='toys', type=str,
36 |                         choices=['home', 'grocery', 'yelp_s3', 'toys'])
37 |     parser.add_argument('--data_aug', action='store_false', help='whether to use data augmentation (default: enabled)')
38 |     parser.add_argument('--seq_filter_len', default=0, type=int, help='filter out sequences shorter than this length')
39 | parser.add_argument('--if_filter_target', action='store_true',
40 | help='if filter target appearing in previous sequence')
41 | parser.add_argument('--separator', default=' ', type=str, help='separator to split item sequence')
42 | parser.add_argument('--graph_type', default='None', type=str, help='do not use graph',
43 | choices=['None', 'BIPARTITE', 'TRANSITION'])
44 | parser.add_argument('--max_len', default=50, type=int, help='max sequence length')
45 | parser.add_argument('--kg_data_type', default='pretrain', type=str, choices=['pretrain', 'jointly_train', 'other'])
46 | # Pretraining
47 | parser.add_argument('--do_pretraining', default=False, action='store_true')
48 |     parser.add_argument('--pretraining_task', default='MISP', type=str, choices=['MISP', 'MIM', 'PID'],
49 |                         help='pretraining task: '
50 |                              'MISP: Mask Item Prediction and Mask Segment Prediction; '
51 |                              'MIM: Mutual Information Maximization; '
52 |                              'PID: Pseudo Item Discrimination'
53 |                         )
54 | parser.add_argument('--pretraining_epoch', default=10, type=int)
55 | parser.add_argument('--pretraining_batch', default=512, type=int)
56 | parser.add_argument('--pretraining_lr', default=1e-3, type=float)
57 | parser.add_argument('--pretraining_l2', default=0., type=float, help='l2 normalization')
58 | # Training
59 | parser.add_argument('--epoch_num', default=100, type=int)
60 |     parser.add_argument('--seed', default=1034, type=int, help='random seed; -1 means no seed is set')
61 | parser.add_argument('--train_batch', default=256, type=int)
62 | parser.add_argument('--learning_rate', default=1e-3, type=float)
63 | parser.add_argument('--l2', default=0., type=float, help='l2 normalization')
64 | parser.add_argument('--patience', default=5, type=int, help='early stop patience')
65 | parser.add_argument('--device', default=get_device(), choices=['cuda:0', 'cpu'],
66 | help='training on gpu or cpu, default gpu')
67 | parser.add_argument('--num_worker', default=0, type=int,
68 | help='num_workers for dataloader, best: 6')
69 | parser.add_argument('--mark', default='', type=str,
70 | help='mark of this run which will be added to the name of the log')
71 |
72 | # Evaluation
73 | parser.add_argument('--split_type', default='valid_and_test', choices=['valid_only', 'valid_and_test'])
74 |     parser.add_argument('--split_mode', default='LS', type=str,
75 |                         help='LS: Leave-one-out splitting. '
76 |                              'LS_R@0.2: LS, holding out the given ratio (e.g. 0.2) of test data for validation when using valid_and_test. '
77 |                              'PS: Pre-Splitting, prepare xx.train and xx.eval, plus xx.test when using valid_and_test')
78 | parser.add_argument('--eval_mode', default='full', help='[uni100, uni200, full]')
79 | parser.add_argument('--metric', default=['hit', 'ndcg'], help='[hit, ndcg, mrr, recall]')
80 | parser.add_argument('--k', default=[5, 10], help='top k for each metric')
81 |     parser.add_argument('--valid_metric', default='hit@10', help='metric used for early stopping, e.g. hit@10')
82 | parser.add_argument('--eval_batch', default=256, type=int)
83 |
84 | # save
85 | parser.add_argument('--log_save', default='log', type=str, help='log saving path')
86 | parser.add_argument('--save', default='save', type=str, help='model saving path')
87 | parser.add_argument('--model_saved', default=None, type=str)
88 |
89 | return parser
90 |
91 |
92 | def config_override(model_config, cmd_config):
93 | default_config = get_default_config()
94 | command_args = set([arg for arg in vars(cmd_config)])
95 | # overwrite model config by cmd config
96 | for arg in vars(model_config):
97 | if arg in command_args:
98 | setattr(model_config, arg, getattr(cmd_config, arg))
99 |
100 | # overwrite default config by cmd config
101 | for arg in vars(default_config):
102 | if arg in command_args:
103 | setattr(default_config, arg, getattr(cmd_config, arg))
104 |
105 | # overwrite default config by model config
106 | for arg in vars(model_config):
107 | setattr(default_config, arg, getattr(model_config, arg))
108 |
109 | return default_config
110 |
--------------------------------------------------------------------------------
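config_override above applies a three-level precedence: command-line arguments overwrite the model defaults, and both overwrite the global defaults (the model config is copied in last, but it has already absorbed the command-line values). A hedged sketch using plain Namespaces as stand-ins for parsed configs:

from argparse import Namespace
from src.train.config import config_override

model_cfg = Namespace(embed_size=64, lamda=0.1)   # model defaults
cmd_cfg = Namespace(embed_size=128)               # command line wins
final = config_override(model_cfg, cmd_cfg)
print(final.embed_size)                           # 128
print(final.lamda)                                # 0.1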
/src/train/trainer.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import numpy as np
4 | import torch
5 | from torch.utils.data import DataLoader
6 | from tqdm import tqdm
7 | from src.dataset import dataset
8 | from src.dataset.dataset import load_specified_dataset
9 | from src.dataset.data_processor import DataProcessor
10 | from src.evaluation.estimator import Estimator
11 | from src.utils.recorder import Recorder
12 | import src.model as model
13 | from src.train.config import experiment_hyper_load, config_override
14 | from src.utils.utils import set_seed, batch_to_device, KMeans
15 |
16 |
17 | # torch.autograd.set_detect_anomaly(True)
18 |
19 | def load_trainer(config):
20 | if config.model in ['ICLRec']:
21 | return ICLTrainer(config)
22 | return Trainer(config)
23 |
24 |
25 | class Trainer:
26 | def __init__(self, config):
27 | self.config = config
28 | self.model_name = config.model
29 | self._config_override(self.model_name, config)
30 | # pretraining
31 | self.pretraining_model = None
32 | self.do_pretraining = self.config.do_pretraining
33 | self.pretraining_task = self.config.pretraining_task
34 | self.pretraining_epoch = self.config.pretraining_epoch
35 | self.pretraining_batch = self.config.pretraining_batch
36 | self.pretraining_lr = self.config.pretraining_lr
37 | self.pretraining_l2 = self.config.pretraining_l2
38 |
39 | # training
40 | self.training_model = None
41 | self.num_worker = self.config.num_worker
42 | self.train_batch = self.config.train_batch
43 | self.eval_batch = self.config.eval_batch
44 | self.lr = self.config.learning_rate
45 | self.l2 = self.config.l2
46 | self.epoch_num = self.config.epoch_num
47 | self.dev = torch.device(self.config.device)
48 | self.split_type = self._set_split_mode(self.config.split_type)
49 | self.do_test = self.split_type == 'valid_and_test'
50 |
51 | # components
52 | self.data_processor = DataProcessor(self.config)
53 | self.estimator = Estimator(self.config)
54 | self.recorder = Recorder(self.config)
55 |
56 | # set random seed
57 | set_seed(self.config.seed)
58 |
59 | # preparing data
60 | data_dict, additional_data_dict = self.data_processor.prepare_data()
61 | self.data_dict = data_dict # store standard train/eval/test data
62 | self.additional_data_dict = additional_data_dict  # extra, model-specific data
63 |
64 | # check_duplication(self.data_dict['train'][0])
65 |
66 | self.estimator.load_item_popularity(self.data_processor.popularity)
67 | self._set_num_items()
68 |
69 | def start_training(self):
70 | if self.do_pretraining:
71 | self.pretrain()
72 | self.train()
73 |
74 | def pretrain(self):
75 | if self.pretraining_task in ['MISP', 'MIM', 'PID']:
76 | pretrain_dataset = getattr(dataset, f'{self.pretraining_task}PretrainDataset')
77 | pretrain_dataset = pretrain_dataset(self.config, self.data_dict['train'],
78 | self.additional_data_dict)
79 | else:
80 | raise NotImplementedError(f'No such pretraining task: {self.pretraining_task}, '
81 | f'choose from [MISP, MIM, PID]')
82 | train_loader = DataLoader(pretrain_dataset, batch_size=self.train_batch, collate_fn=pretrain_dataset.collate_fn,
83 | shuffle=True, num_workers=0, drop_last=False)
84 |
85 | pretrain_model = self._load_model()
86 |
87 | opt = torch.optim.Adam(filter(lambda x: x.requires_grad, pretrain_model.parameters()),
88 | self.pretraining_lr, weight_decay=self.pretraining_l2)
89 |
90 | self.experiment_setting_verbose(pretrain_model, training=False)
91 |
92 | logging.info('Start pretraining...')
93 | for epoch in range(self.pretraining_epoch):
94 | pretrain_model.train()
95 | self.recorder.epoch_restart()
96 | self.recorder.tik_start()
97 | train_iter = tqdm(enumerate(train_loader), total=len(train_loader))
98 | train_iter.set_description(f'pretraining epoch: {epoch}')
99 | for i, batch_dict in train_iter:
100 | batch_to_device(batch_dict, self.dev)
101 | loss = getattr(pretrain_model, f'{self.pretraining_task}_pretrain_forward')(batch_dict)
102 | opt.zero_grad()
103 | loss.backward()
104 | opt.step()
105 |
106 | self.recorder.save_batch_loss(loss.item())
107 | self.recorder.tik_end()
108 | self.recorder.train_log_verbose(len(train_loader))
109 |
110 | self.pretraining_model = pretrain_model
111 | logging.info('Pre-training finished, preparing for training...')
112 |
113 | def train(self):
114 | SpecifiedDataSet = load_specified_dataset(self.model_name, self.config)
115 | train_dataset = SpecifiedDataSet(self.config, self.data_dict['train'],
116 | self.additional_data_dict)
117 | train_loader = DataLoader(train_dataset, batch_size=self.train_batch, collate_fn=train_dataset.collate_fn,
118 | shuffle=True, num_workers=self.num_worker, drop_last=False)
119 |
120 | eval_dataset = SpecifiedDataSet(self.config, self.data_dict['eval'],
121 | self.additional_data_dict, train=False)
122 | eval_loader = DataLoader(eval_dataset, batch_size=self.eval_batch, collate_fn=eval_dataset.collate_fn,
123 | shuffle=False, num_workers=self.num_worker, drop_last=False)
124 |
125 | self.training_model = self._load_model()
126 |
127 | opt = torch.optim.Adam(filter(lambda x: x.requires_grad, self.training_model.parameters()), self.lr,
128 | weight_decay=self.l2)
129 | self.recorder.reset()
130 | self.experiment_setting_verbose(self.training_model)
131 |
132 | logging.info('Start training...')
133 | for epoch in range(self.epoch_num):
134 | self.training_model.train()
135 | self.recorder.epoch_restart()
136 | self.recorder.tik_start()
137 | train_iter = tqdm(enumerate(train_loader), total=len(train_loader))
138 | train_iter.set_description('training ')
139 | for step, batch_dict in train_iter:
140 | # training forward
141 | batch_dict['epoch'] = epoch
142 | batch_dict['step'] = step
143 | batch_to_device(batch_dict, self.dev)
144 | loss = self.training_model.train_forward(batch_dict)
145 | if torch.is_tensor(loss) and loss.requires_grad:
146 | opt.zero_grad()
147 | loss.backward()
148 | opt.step()
149 | self.recorder.save_batch_loss(loss.item())
150 | self.recorder.tik_end()
151 | self.recorder.train_log_verbose(len(train_loader))
152 |
153 | # evaluation
154 | self.recorder.tik_start()
155 | eval_metric_result, eval_loss = self.estimator.evaluate(eval_loader, self.training_model)
156 | self.recorder.tik_end(mode='eval')
157 | self.recorder.log_verbose_and_save(eval_metric_result, eval_loss, self.training_model)
158 |
159 | if self.recorder.early_stop:
160 | break
161 |
162 | self.recorder.report_best_res()
163 | # test model
164 | if self.do_test:
165 | test_metric_res = self.test_model(self.data_dict['test'])
166 | self.recorder.report_test_result(test_metric_res)
167 |
168 | def _set_split_mode(self, split_mode):
169 | assert split_mode in ['valid_and_test', 'valid_only'], f'Invalid split mode: {split_mode} !'
170 | return split_mode
171 |
172 | def _load_model(self):
173 | if self.do_pretraining and self.pretraining_model is not None: # return pretraining model
174 | return self.pretraining_model
175 |
176 | # return new model
177 | if self.config.model_type.upper() == 'SEQUENTIAL':
178 | return self._load_sequential_model()
179 | elif self.config.model_type.upper() in ['GRAPH', 'KNOWLEDGE']:
180 | return self._load_model_with_additional_data()
181 | else:
182 | raise KeyError(f'Invalid model_type: {self.config.model_type}. Choose from [sequential, knowledge, graph]')
183 |
184 | def _load_sequential_model(self):
185 | Model = getattr(model, self.model_name)
186 | specified_seq_model = Model(self.config, self.additional_data_dict).to(self.dev)
187 | return specified_seq_model
188 |
189 | def _load_model_with_additional_data(self):
190 | Model = getattr(model, self.model_name)
191 | specified_model = Model(self.config, self.additional_data_dict).to(self.dev)
192 | return specified_model
193 |
194 | def _config_override(self, model_name, cmd_config):
195 | self.model_config = getattr(model, f'{model_name}_config')()
196 | self.config = config_override(self.model_config, cmd_config)
197 | # normalize string configs to uppercase
198 | self.config.model_type = self.config.model_type.upper()
199 | self.config.graph_type = [g_type.upper() for g_type in self.config.graph_type]
200 |
201 | def _set_num_items(self):
202 | self.config.num_items = self.data_processor.num_items
203 |
204 | def experiment_setting_verbose(self, model, training=True):
205 | if self.do_pretraining and training:
206 | return
207 | # model config
208 | logging.info('[1] Model Hyper-Parameter '.ljust(47, '-'))
209 | model_param_set = self.model_config.keys()
210 | for arg in vars(self.config):
211 | if arg in model_param_set:
212 | logging.info(f'{arg}: {getattr(self.config, arg)}')
213 | # experiment config
214 | logging.info('[2] Experiment Hyper-Parameter '.ljust(47, '-'))
215 | # verbose_order = ['Data', 'Training', 'Evaluation', 'Save']
216 | hyper_types, exp_setting = experiment_hyper_load(self.config)
217 | for i, hyper_type in enumerate(hyper_types):
218 | hyper_start_log = (f'[2-{i + 1}] ' + hyper_type.lower() + ' hyper-parameter ').ljust(47, '-')
219 | logging.info(hyper_start_log)
220 | for hyper, value in exp_setting[hyper_type].items():
221 | logging.info(f'{hyper}: {value}')
222 | # data statistic
223 | self.data_processor.data_log_verbose(3)
224 | # model architecture
225 | self.report_model_info(model)
226 |
227 | def report_model_info(self, model):
228 | # model architecture
229 | logging.info('[4] Model Architecture '.ljust(47, '-'))
230 | logging.info(f'total parameters: {model.calc_total_params()}')
231 | logging.info(model)
232 |
233 | def test_model(self, test_data_pair=None):
234 | SpecifiedDataSet = load_specified_dataset(self.model_name, self.config)
235 | test_dataset = SpecifiedDataSet(self.config, test_data_pair,
236 | self.additional_data_dict, train=False)
237 | test_loader = DataLoader(test_dataset, batch_size=self.eval_batch, num_workers=self.num_worker,
238 | collate_fn=test_dataset.collate_fn, drop_last=False, shuffle=False)
239 | # load the best model
240 | self.recorder.load_best_model(self.training_model)
241 | self.training_model.eval()
242 | test_metric_result = self.estimator.test(test_loader, self.training_model)
243 |
244 | return test_metric_result
245 |
246 | def start_test(self):
247 | self.training_model = self._load_model()
248 | self.experiment_setting_verbose(self.training_model)
249 | test_metric_res = self.test_model(self.data_dict['test'])
250 | self.recorder.report_test_result(test_metric_res)
251 |
252 |
253 | class ICLTrainer(Trainer):
254 | def __init__(self, config):
255 | super(ICLTrainer, self).__init__(config)
256 | self.num_intent_cluster = config.num_intent_cluster
257 | self.seq_representation_type = config.seq_representation_type
258 | # initialize Kmeans
259 | if self.seq_representation_type == "mean":
260 | cluster = KMeans(
261 | num_cluster=self.num_intent_cluster,
262 | seed=self.config.seed,
263 | hidden_size=self.config.embed_size,
264 | device=self.config.device,
265 | )
266 | else:
267 | cluster = KMeans(
268 | num_cluster=self.num_intent_cluster,
269 | seed=self.config.seed,
270 | hidden_size=self.config.embed_size * self.config.max_len,
271 | device=self.config.device,
272 | )
273 | self.cluster = cluster
274 |
275 | def train(self):
276 | SpecifiedDataSet = load_specified_dataset(self.model_name, self.config)
277 | intent_cluster_dataset = SpecifiedDataSet(self.config, self.data_dict['raw_train'],
278 | self.additional_data_dict)
279 | intent_cluster_loader = DataLoader(intent_cluster_dataset, batch_size=self.train_batch,
280 | collate_fn=intent_cluster_dataset.collate_fn,
281 | shuffle=True, num_workers=self.num_worker, drop_last=False)
282 |
283 | train_dataset = SpecifiedDataSet(self.config, self.data_dict['train'],
284 | self.additional_data_dict)
285 | train_loader = DataLoader(train_dataset, batch_size=self.train_batch, collate_fn=train_dataset.collate_fn,
286 | shuffle=True, num_workers=self.num_worker, drop_last=False)
287 |
288 | eval_dataset = SpecifiedDataSet(self.config, self.data_dict['eval'],
289 | self.additional_data_dict, train=False)
290 | eval_loader = DataLoader(eval_dataset, batch_size=self.eval_batch, collate_fn=eval_dataset.collate_fn,
291 | shuffle=False, num_workers=self.num_worker, drop_last=False)
292 |
293 | self.training_model = self._load_model()
294 |
295 | opt = torch.optim.Adam(filter(lambda x: x.requires_grad, self.training_model.parameters()), self.lr,
296 | weight_decay=self.l2)
297 | self.recorder.reset()
298 | self.experiment_setting_verbose(self.training_model)
299 |
300 | logging.info('Start training...')
301 | for epoch in range(self.epoch_num):
302 | self.training_model.train()
303 | self.recorder.epoch_restart()
304 | self.recorder.tik_start()
305 |
306 | # collect cluster data
307 | intent_cluster_iter = tqdm(enumerate(intent_cluster_loader), total=len(intent_cluster_loader))
308 | intent_cluster_iter.set_description('prepare clustering ')
309 |
310 | kmeans_training_data = [] # store all user intent representations in training data
311 | for step, batch_dict in intent_cluster_iter:
312 | batch_to_device(batch_dict, self.dev)
313 | item_seq, seq_len = batch_dict['item_seq'], batch_dict['seq_len']
314 | sequence_output = self.training_model.seq_encoding(item_seq, seq_len, return_all=True)
315 | # mean-pool over positions; otherwise flatten (concatenate all position embeddings)
316 | if self.seq_representation_type == "mean":
317 | sequence_output = torch.mean(sequence_output, dim=1, keepdim=False)
318 | sequence_output = sequence_output.view(sequence_output.shape[0], -1)
319 | sequence_output = sequence_output.detach().cpu().numpy()
320 | kmeans_training_data.append(sequence_output)
321 | kmeans_training_data = np.concatenate(kmeans_training_data, axis=0) # [user_size, dim]
322 |
323 | # train cluster
324 | self.cluster.train(kmeans_training_data)
325 |
326 | # clean memory
327 | del kmeans_training_data
328 | import gc
329 |
330 | gc.collect()
331 |
332 | train_iter = tqdm(enumerate(train_loader), total=len(train_loader))
333 | train_iter.set_description('training ')
334 | for step, batch_dict in train_iter:
335 | # training forward
336 | batch_dict['epoch'] = epoch
337 | batch_dict['step'] = step
338 | batch_dict['cluster'] = self.cluster
339 | batch_to_device(batch_dict, self.dev)
340 | loss = self.training_model.train_forward(batch_dict)
341 | if torch.is_tensor(loss) and loss.requires_grad:
342 | opt.zero_grad()
343 | loss.backward()
344 | opt.step()
345 | self.recorder.save_batch_loss(loss.item())
346 | self.recorder.tik_end()
347 | self.recorder.train_log_verbose(len(train_loader))
348 |
349 | # evaluation
350 | self.recorder.tik_start()
351 | eval_metric_result, eval_loss = self.estimator.evaluate(eval_loader, self.training_model)
352 | self.recorder.tik_end(mode='eval')
353 | self.recorder.log_verbose_and_save(eval_metric_result, eval_loss, self.training_model)
354 |
355 | if self.recorder.early_stop:
356 | break
357 |
358 | self.recorder.report_best_res()
359 | # test model
360 | if self.do_test:
361 | test_metric_res = self.test_model(self.data_dict['test'])
362 | self.recorder.report_test_result(test_metric_res)
363 |
364 |
365 | if __name__ == '__main__':
366 | pass
367 |
--------------------------------------------------------------------------------
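Each `ICLTrainer.train` epoch has two phases: it first encodes every training sequence into an intent representation (mean-pooled or flattened), fits KMeans on those vectors, and then runs a normal training pass with the fitted cluster object attached to each batch so `train_forward` can query centroid assignments for the contrastive loss. A stripped-down sketch of that per-epoch loop, with the model, cluster, loaders, and optimizer as stand-ins for the real objects (device transfer of the full batch dict is elided for brevity):

import numpy as np
import torch

def icl_epoch(model, cluster, cluster_loader, train_loader, opt, device):
    """One ICL-style epoch: re-fit intent clusters, then train with them."""
    model.train()
    # Phase 1: collect intent representations over the raw training data
    reps = []
    with torch.no_grad():
        for batch in cluster_loader:
            seq_out = model.seq_encoding(batch['item_seq'].to(device),
                                         batch['seq_len'].to(device),
                                         return_all=True)
            reps.append(seq_out.mean(dim=1).cpu().numpy())  # mean pooling over positions
    cluster.train(np.concatenate(reps, axis=0))  # fit centroids on [num_users, dim]

    # Phase 2: standard training, with the fitted cluster exposed to the loss
    for batch in train_loader:
        batch['cluster'] = cluster
        loss = model.train_forward(batch)
        opt.zero_grad()
        loss.backward()
        opt.step()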
/src/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/utils/__init__.py
--------------------------------------------------------------------------------
/src/utils/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/utils/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/recorder.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/utils/__pycache__/recorder.cpython-38.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LFM-bot/IOCRec/a1b91db459920b4640dd14f1044684a3a5c162d5/src/utils/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/src/utils/recorder.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import logging
3 | import datetime
4 | import torch
5 | import os
6 | import time as t
7 | import numpy as np
8 |
9 |
10 | class Recorder:
11 | def __init__(self, config):
12 | self.epoch = 0
13 | self.model_name = config.model
14 | self.dataset = config.dataset
15 | self.run_mark = config.mark
16 | self.log_path = config.log_save
17 | # metric
18 | self.metrics = config.metric
19 | self.k_list = config.k
20 | # records
21 | self.batch_loss_rec = 0.
22 | self.metric_records = {}
23 | self.time_record = {'train': 0., 'eval': 0.}
24 | self.decimal_round = 4
25 | self.mark = config.mark
26 | self.model_saved = config.model_saved
27 |
28 | # early stop
29 | self.early_stop = False
30 | self.core_metric = config.valid_metric
31 | self.patience = int(config.patience)
32 | self.best_metric_rec = {'epoch': 0, 'score': 0.}
33 | self.step_2_stop = self.patience
34 | # log report
35 | self.block_size = 6
36 | self.half_underline = self.block_size * len(self.metrics) * len(self.k_list)
37 |
38 | self._recoder_init(config)
39 |
40 | def reset(self):
41 | self.epoch = 0
42 |
43 | def _recoder_init(self, config):
44 | self._init_log()
45 | self._init_record()
46 | self._model_saving_init(config)
47 |
48 | def _model_saving_init(self, config):
49 | # check saving path
50 | if not os.path.exists(config.save):
51 | os.mkdir(config.save)
52 | # init model saving path
53 | curr_time = datetime.datetime.now()
54 | timestamp = datetime.datetime.strftime(curr_time, '%Y-%m-%d_%H-%M-%S')
55 | if self.model_saved is None:
56 | self.model_saved = os.path.join(config.save, f'{config.model}-{self.dataset}-{self.mark}-{timestamp}.pth')
57 | logging.info(f'model save at: {self.model_saved}')
58 |
59 | def _init_log(self):
60 | save_path = os.path.join(self.log_path, self.dataset)
61 | if not os.path.isdir(save_path):
62 | os.makedirs(save_path)
63 | times = 1
64 | log_model_name = self.model_name + f'-{self.run_mark}' if len(self.run_mark) > 0 else self.model_name
65 | log_file = os.path.join(save_path, '%s_%d.log' % (log_model_name, times))
66 | for i in range(100):
67 | if not os.path.isfile(log_file):
68 | break
69 | log_file = os.path.join(save_path, '%s_%d.log' % (log_model_name, times + i + 1))
70 |
71 | logging.basicConfig(
72 | format='%(asctime)s %(levelname)-8s %(message)s',
73 | level=logging.INFO,
74 | datefmt='%Y-%m-%d %H:%M:%S',
75 | filename=log_file,
76 | filemode='w'
77 | )
78 | console = logging.StreamHandler()
79 | console.setLevel(logging.INFO)
80 | formatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s')
81 | console.setFormatter(formatter)
82 | logging.getLogger('').addHandler(console)
83 | logging.info('log save at : {}'.format(log_file))
84 |
85 | def _init_record(self):
86 | for metric in self.metrics:
87 | for k in self.k_list:
88 | self.metric_records[f'{metric}@{k}'] = []
89 | assert self.core_metric in self.metric_records.keys(), f'Invalid valid_metric: [{self.core_metric}], ' \
90 | f'choose from: {self.metric_records.keys()} !'
91 |
92 | def save_model(self, model):
93 | # save entire model
94 | # torch.save(model, self.model_saving)
95 |
96 | # only save model parameters
97 | torch.save(model.state_dict(), self.model_saved)
98 |
99 | def load_best_model(self, model):
100 | # load entire model
101 | # return torch.load(self.model_saving)
102 |
103 | # load parameters
104 | model.load_state_dict(torch.load(self.model_saved))
105 |
106 | def epoch_restart(self):
107 | self.batch_loss_rec = 0.
108 | self.epoch += 1
109 |
110 | def save_batch_loss(self, batch_loss):
111 | self.batch_loss_rec += batch_loss
112 |
113 | def tik_start(self):
114 | self._clock = t.time()
115 |
116 | def tik_end(self, mode='train'):
117 | end_clock = t.time()
118 | self.time_record[mode] = end_clock - self._clock
119 |
120 | def _save_best_result(self, metric_res, model):
121 | """
122 | :param metric_res: dict
123 | """
124 | for metric, score in metric_res.items():
125 | self.metric_records.get(metric).append(score)
126 | # early stop
127 | core_metric_res = metric_res.get(self.core_metric)
128 | self.early_stop_check(core_metric_res, model)
129 |
130 | def early_stop_check(self, core_metric_res, model):
131 | if core_metric_res > self.best_metric_rec.get('score'):
132 | self.best_metric_rec['score'] = core_metric_res
133 | self.best_metric_rec['epoch'] = self.epoch
134 | self.step_2_stop = self.patience
135 | # find a better model -> save
136 | self.save_model(model)
137 | else:
138 | self.step_2_stop -= 1
139 | logging.info(f'EarlyStopping Counter: {self.patience - self.step_2_stop} out of {self.patience}')
140 | if self.step_2_stop == 0:
141 | self.early_stop = True
142 |
143 | def train_log_verbose(self, num_batch):
144 | training_loss = self.batch_loss_rec / num_batch
145 | logging.info('-' * self.half_underline + f'----Epoch {self.epoch}----' + '-' * self.half_underline)
146 | output_str = " Training Time :[%.1f s]\tTraining Loss = %.4f" % (self.time_record['train'], training_loss)
147 | logging.info(output_str)
148 |
149 | def log_verbose_and_save(self, metric_score, eval_loss, model):
150 | res_str = ''
151 | for metric, score in metric_score.items():
152 | score = round(score, self.decimal_round)
153 | res_str += f'{metric}:{score:1.4f}\t'
154 |
155 | eval_time = round(self.time_record['eval'], 1)
156 | if eval_loss <= 0:
157 | eval_loss = '**'
158 | else:
159 | eval_loss = round(eval_loss, 4)
160 | logging.info(f"Evaluation Time:[{eval_time} s]\t Eval Loss = {eval_loss}")
161 | logging.info(res_str)
162 |
163 | # save results and model
164 | self._save_best_result(metric_score, model)
165 |
166 | def report_best_res(self):
167 | best_epoch = self.best_metric_rec['epoch']
168 | logging.info('-' * self.half_underline + 'Best Evaluation' + '-' * self.half_underline)
169 | logging.info(f"Best Result at Epoch: {best_epoch}\t Early Stop at Patience: {self.patience}")
170 | # load best results
171 | best_metrics_res = {}
172 | for metric, metric_res_list in self.metric_records.items():
173 | best_metrics_res[metric] = metric_res_list[best_epoch - 1]
174 | res_str = ''
175 | for metric, score in best_metrics_res.items():
176 | score = round(score, self.decimal_round)
177 | res_str += f'{metric}:{score:1.4f}\t'
178 | logging.info(res_str)
179 |
180 | def report_test_result(self, test_metric_res):
181 | res_str = ''
182 | for metric, score in test_metric_res.items():
183 | score = round(score, self.decimal_round)
184 | res_str += f'{metric}:{score:1.4f}\t'
185 | logging.info('-' * self.half_underline + '-----Test Results------' + '-' * self.half_underline)
186 | logging.info(res_str)
--------------------------------------------------------------------------------
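The early-stop contract in `Recorder` is simple: every evaluation feeds its metric dict in; whenever the core metric (`valid_metric`) improves, the model is checkpointed and the patience counter resets, otherwise the counter ticks down until `early_stop` flips. A toy standalone re-implementation of that bookkeeping (not the class above), with a worked trace:

class PatienceTracker:
    """Toy version of Recorder's early-stop bookkeeping."""
    def __init__(self, patience):
        self.patience = patience
        self.best = float('-inf')
        self.left = patience
        self.stopped = False

    def update(self, score):
        if score > self.best:  # improvement: checkpoint and reset the counter
            self.best, self.left = score, self.patience
        else:                  # no improvement: tick down
            self.left -= 1
            self.stopped = self.left == 0
        return self.stopped

t = PatienceTracker(patience=3)
for s in [0.31, 0.35, 0.34, 0.34, 0.33]:
    if t.update(s):
        break
print(t.best)  # 0.35; training stops after 3 non-improving epochs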
/src/utils/utils.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import logging
3 | import datetime
4 | import random
5 | from typing import List
6 |
7 | import faiss  # required by the KMeans helper below (ICLRec intent clustering)
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import os
12 | import time as t
13 | import numpy as np
14 | from easydict import EasyDict
15 |
16 |
17 | def set_seed(seed):
18 | if seed == -1:
19 | return
20 | random.seed(seed)
21 | os.environ['PYTHONHASHSEED'] = str(seed)
22 | np.random.seed(seed)
23 | torch.manual_seed(seed)
24 | torch.cuda.manual_seed(seed)
25 | torch.cuda.manual_seed_all(seed)
26 | # some cudnn methods can be random even after fixing the seed
27 | # unless you tell it to be deterministic
28 | torch.backends.cudnn.deterministic = True
29 |
30 |
31 | def batch_to_device(tensor_dict: dict, dev):
32 | for key, obj in tensor_dict.items():
33 | if torch.is_tensor(obj):
34 | tensor_dict[key] = obj.to(dev)
35 |
36 |
37 | def load_pickle(file_path):
38 | with open(file_path, 'rb') as fr:
39 | return pickle.load(fr)
40 |
41 |
42 | def save_pickle(obj, file_path):
43 | with open(file_path, 'wb') as f:
44 | pickle.dump(obj, f)
45 |
46 |
47 | def pkl_to_txt(dataset='beauty'):
48 | data_dir = '../../dataset'
49 | file_path = os.path.join(data_dir, f'{dataset}/{dataset}_seq.pkl')
50 | record = load_pickle(file_path)
51 |
52 | # write to xxx.txt
53 | target_file = os.path.join(data_dir, f'{dataset}/{dataset}_seq.txt')
54 | with open(target_file, 'w') as fw:
55 | for seq in record:
56 | seq = list(map(str, seq))
57 | seq_str = ' '.join(seq) + '\n'
58 | fw.write(seq_str)
59 |
60 |
61 | def freeze(layer):
62 | for child in layer.children():
63 | for param in child.parameters():
64 | param.requires_grad = False
65 |
66 |
67 | def neg_sample(item_set, item_size):  # randint bounds are inclusive on both ends
68 | item = random.randint(1, item_size - 1)
69 | while item in item_set:
70 | item = random.randint(1, item_size - 1)
71 | return item
72 |
73 |
74 | def get_activate(act='relu'):
75 | if act == 'relu':
76 | return nn.ReLU()
77 | elif act == 'leaky_relu':
78 | return nn.LeakyReLU()
79 | elif act == 'gelu':
80 | return nn.GELU()
81 | elif act == 'tanh':
82 | return nn.Tanh()
83 | elif act == 'sigmoid':
84 | return nn.Sigmoid()
85 | else:
86 | raise KeyError(f'Unsupported activation function: {act}, please add it yourself.')
87 |
88 |
89 | class HyperParamDict(EasyDict):
90 | def __init__(self, description=None):
91 | super(HyperParamDict, self).__init__({})
92 | self.description = description
93 | self.attr_registered = []
94 |
95 | def add_argument(self, param_name, type=object, default=None, action=None, choices=None, help=None):
96 | param_name = self._parse_param_name(param_name)
97 | if default and type:
98 | try:
99 | default = type(default)
100 | except Exception:
101 | assert isinstance(default, type), f'KeyError. Type of param {param_name} should be {type}.'
102 | if choices:
103 | assert isinstance(choices, List), f'choices should be a list.'
104 | assert default in choices, f'KeyError. Please choose {param_name} from {choices}. ' \
105 | f'Now {param_name} = {default}.'
106 | if action:
107 | default = self._parse_action(action)
108 | if help:
109 | assert isinstance(help, str), f'help should be a str.'
110 | self.attr_registered.append(param_name)
111 | self.__setattr__(param_name, default)
112 |
113 | @staticmethod
114 | def _parse_param_name(param_name: str):
115 | index = param_name.rfind('-') # find last pos of -, return -1 on failure
116 | return param_name[index + 1:]
117 |
118 | @staticmethod
119 | def _parse_action(action):
120 | action_infos = action.split('_')
121 | assert action_infos[0] == 'store' and action_infos[-1] in ['false', 'true'], \
122 | f"Wrong action format: {action}. Please choose from ['store_false', 'store_true']."
123 | res = False if action_infos[-1] == 'true' else True
124 | return res
125 |
126 | def keys(self):
127 | return self.attr_registered
128 |
129 | def values(self):
130 | return [self.get(key) for key in self.attr_registered]
131 |
132 | def items(self):
133 | return [(key, self.get(key)) for key in self.attr_registered]
134 |
135 | def __str__(self):
136 | info_str = 'HyperParamDict{'
137 | param_list = []
138 | for key, value in self.items():
139 | param_list.append(f'({key}: {value})')
140 | info_str += ', '.join(param_list) + '}'
141 | return info_str
142 |
143 |
144 | class KMeans(object):
145 | def __init__(self, num_cluster, seed, hidden_size, gpu_id=0, device="cpu"):
146 | """
147 | Args:
148 | num_cluster: number of clusters
149 | """
150 | self.seed = seed
151 | self.num_cluster = num_cluster
152 | self.max_points_per_centroid = 4096
153 | self.min_points_per_centroid = 0
154 | self.gpu_id = gpu_id
155 | self.device = device
156 | self.first_batch = True
157 | self.hidden_size = hidden_size
158 | self.clus, self.index = self.__init_cluster(self.hidden_size)
159 | self.centroids = [] # cluster centroids
160 |
161 | def __init_cluster(
162 | self, hidden_size, verbose=False, niter=20, nredo=5, max_points_per_centroid=4096, min_points_per_centroid=0
163 | ):
164 | logging.info(f" cluster train iterations: {niter}")
165 | clus = faiss.Clustering(hidden_size, self.num_cluster)
166 | clus.verbose = verbose
167 | clus.niter = niter
168 | clus.nredo = nredo
169 | clus.seed = self.seed
170 | clus.max_points_per_centroid = max_points_per_centroid
171 | clus.min_points_per_centroid = min_points_per_centroid
172 |
173 | res = faiss.StandardGpuResources()
174 | res.noTempMemory()
175 | cfg = faiss.GpuIndexFlatConfig()
176 | cfg.useFloat16 = False
177 | cfg.device = self.gpu_id
178 | index = faiss.GpuIndexFlatL2(res, hidden_size, cfg)
179 | return clus, index
180 |
181 | def train(self, x):
182 | # train to get centroids
183 | if x.shape[0] > self.num_cluster:
184 | self.clus.train(x, self.index)
185 | # get cluster centroids
186 | centroids = faiss.vector_to_array(self.clus.centroids).reshape(self.num_cluster, self.hidden_size)
187 | # convert to cuda Tensors for broadcast
188 | centroids = torch.Tensor(centroids).to(self.device)
189 | self.centroids = nn.functional.normalize(centroids, p=2, dim=1)
190 |
191 | def query(self, x):
192 | """
193 | Args
194 | x: batch intent representations of shape [B, D]
195 | Returns
196 | seq2cluster: assigned centroid id for each intent of shape [B]
197 | centroids_assignments: assigned centroid representation for each intent of shape [B, D]
198 | """
199 | # self.index.add(x)
200 | # D : cluster distances of shape [B, 1] I: cluster assignments of shape [B, 1]
201 | D, I = self.index.search(x, 1) # for each sample, find cluster distance and assignments
202 | seq2clusterID = [int(n[0]) for n in I]
203 | # print("cluster number:", self.num_cluster,"cluster in batch:", len(set(seq2cluster)))
204 | seq2clusterID = torch.LongTensor(seq2clusterID).to(self.device)
205 | centroids_assignments = self.centroids[seq2clusterID]
206 | return seq2clusterID, centroids_assignments
207 |
208 |
209 | def get_gpu_usage(device=None):
210 | r""" Return the reserved memory and total memory of given device in a string.
211 | Args:
212 | device: cuda.device. It is the device that the model run on.
213 |
214 | Returns:
215 | str: it contains the info about reserved memory and total memory of given device.
216 | """
217 |
218 | reserved = torch.cuda.max_memory_reserved(device) / 1024 ** 3
219 | total = torch.cuda.get_device_properties(device).total_memory / 1024 ** 3
220 |
221 | return '{:.2f} G/{:.2f} G'.format(reserved, total)
222 |
223 |
224 | if __name__ == '__main__':
225 | datasets = ['ml-10m']
226 | for dataset in datasets:
227 | pkl_to_txt(dataset)
228 |
--------------------------------------------------------------------------------
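`HyperParamDict` mimics the `argparse` surface (`add_argument`, `store_true`/`store_false` actions, `choices` validation) so per-model default configs can be declared the same way the command-line parser is, and then merged by `config_override`. A small usage sketch, assuming the class is importable from `src.utils.utils`; the `use_dropout` flag is a hypothetical example, not a real option in this repo:

from src.utils.utils import HyperParamDict

config = HyperParamDict('demo defaults')
config.add_argument('--embed_size', type=int, default=64)
config.add_argument('--eval_mode', default='full', choices=['uni100', 'uni200', 'full'])
config.add_argument('--use_dropout', action='store_true')  # default False, as in argparse

print(config.embed_size)   # 64 (leading dashes are stripped from the name)
print(config.eval_mode)    # 'full'
print(config.use_dropout)  # False
print(config)              # HyperParamDict{(embed_size: 64), (eval_mode: full), (use_dropout: False)}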
/toys.sh:
--------------------------------------------------------------------------------
1 | python runIOCRec.py --dataset toys --eval_mode uni100 --embed_size 64 --k_intention 4 --seed 2023 --mark seed2023
2 | python runIOCRec.py --dataset toys --eval_mode uni100 --embed_size 64 --ffn_hidden 256 --k_intention 4 --seed 2024 --mark seed2024-hidden256
3 | python runIOCRec.py --dataset toys --eval_mode uni100 --embed_size 64 --ffn_hidden 128 --k_intention 4 --seed 2024 --mark seed2024-hidden128
--------------------------------------------------------------------------------
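Each line in toys.sh produces one log under log/&lt;dataset&gt;/, named by Recorder._init_log as &lt;model&gt;-&lt;mark&gt;_&lt;times&gt;.log (or &lt;model&gt;_&lt;times&gt;.log when --mark is empty), with the trailing counter incremented so earlier runs are never overwritten; this is how the log files listed in the project tree get their names. A quick sketch mirroring that naming rule (the printed path assumes no earlier log exists):

import os

def log_name(model, mark, save_path, times=1):
    # same naming rule as Recorder._init_log
    stem = f'{model}-{mark}' if mark else model
    name = f'{stem}_{times}.log'
    while os.path.isfile(os.path.join(save_path, name)):
        times += 1
        name = f'{stem}_{times}.log'
    return os.path.join(save_path, name)

# e.g. the first toys.sh line (--mark seed2023) would log to:
print(log_name('IOCRec', 'seed2023', 'log/toys'))  # log/toys/IOCRec-seed2023_1.log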