├── src ├── __init__.py ├── dataloaders │ ├── __init__.py │ ├── negative_samplers │ │ ├── __init__.py │ │ ├── base.py │ │ └── random.py │ ├── base.py │ └── rec_dataloader.py ├── models │ ├── __init__.py │ ├── sasrec.py │ ├── bert4rec.py │ ├── relative_position.py │ ├── embedding.py │ ├── utils.py │ ├── heads.py │ └── new_transformer.py ├── datasets │ ├── ijcai.py │ ├── retail.py │ ├── yelp.py │ ├── ml_10m.py │ ├── __init__.py │ └── base.py ├── configs │ ├── yelp │ │ ├── yelp_full.yaml │ │ ├── yelp_battn.yaml │ │ ├── yelp_bhead.yaml │ │ ├── yelp_brpb.yaml │ │ ├── yelp_wo_brpb.yaml │ │ ├── yelp_bert_mb.yaml │ │ ├── yelp_wo_battn.yaml │ │ ├── yelp_wo_bhead.yaml │ │ ├── yelp_bert_one.yaml │ │ └── yelp_full_neg.yaml │ ├── ijcai │ │ ├── ijcai_brpb.yaml │ │ ├── ijcai_full.yaml │ │ ├── ijcai_battn.yaml │ │ ├── ijcai_bhead.yaml │ │ ├── ijcai_wo_bhead.yaml │ │ ├── ijcai_wo_brpb.yaml │ │ ├── ijcai_bert_mb.yaml │ │ ├── ijcai_bert_one.yaml │ │ ├── ijcai_wo_battn.yaml │ │ └── ijcai_full_neg.yaml │ └── retail │ │ ├── retail_brpb.yaml │ │ ├── retail_full.yaml │ │ ├── retail_battn.yaml │ │ ├── retail_bert_mb.yaml │ │ ├── retail_bhead.yaml │ │ ├── retail_wo_battn.yaml │ │ ├── retail_wo_bhead.yaml │ │ ├── retail_wo_brpb.yaml │ │ ├── retail_bert_one.yaml │ │ └── retail_full_neg.yaml ├── utils.py ├── model.py └── datamodule.py ├── requirements.txt ├── mb-str.jpg ├── run.py ├── README.md └── LICENSE /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch-lightning == 1.5.3 2 | torch >= 1.3.0 -------------------------------------------------------------------------------- /mb-str.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanenming/mb-str/HEAD/mb-str.jpg 
-------------------------------------------------------------------------------- /src/dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .rec_dataloader import RecDataloader 3 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .bert4rec import BERT 3 | from .heads import DotProductPredictionHead, CGCDotProductPredictionHead -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | 2 | import pytorch_lightning as pl 3 | from pytorch_lightning.utilities.cli import LightningCLI 4 | from src.model import RecModel 5 | from src.datamodule import RecDataModule 6 | 7 | 8 | def cli_main(): 9 | cli = LightningCLI(RecModel, RecDataModule, save_config_overwrite=True) 10 | 11 | if __name__ == '__main__': 12 | cli_main() -------------------------------------------------------------------------------- /src/dataloaders/negative_samplers/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .random import RandomNegativeSampler 3 | 4 | 5 | NEGATIVE_SAMPLERS = { 6 | RandomNegativeSampler.code(): RandomNegativeSampler, 7 | } 8 | 9 | def negative_sampler_factory(code, train, val, user_count, item_count, sample_size, save_folder): 10 | negative_sampler = NEGATIVE_SAMPLERS[code] 11 | return negative_sampler(train, val, user_count, item_count, sample_size, save_folder) 12 | -------------------------------------------------------------------------------- /src/datasets/ijcai.py: -------------------------------------------------------------------------------- 1 | 2 | from .base import AbstractDataset 3 | 4 | import pandas as pd 5 | 6 | class IjcaiDataset(AbstractDataset): 7 | @classmethod 8 | def code(cls): 9 
| return 'ijcai' 10 | 11 | def load_df(self): 12 | folder_path = self._get_rawdata_root_path() 13 | file_path = folder_path.joinpath('ijcai.txt') 14 | df = pd.read_csv(file_path, sep='\t', header=None) 15 | df.columns = ['uid', 'sid', 'behavior', 'timestamp'] 16 | return df -------------------------------------------------------------------------------- /src/datasets/retail.py: -------------------------------------------------------------------------------- 1 | 2 | from .base import AbstractDataset 3 | 4 | import pandas as pd 5 | 6 | class RetailDataset(AbstractDataset): 7 | @classmethod 8 | def code(cls): 9 | return 'retail' 10 | 11 | def load_df(self): 12 | folder_path = self._get_rawdata_root_path() 13 | file_path = folder_path.joinpath('retail.txt') 14 | df = pd.read_csv(file_path, sep='\t', header=None) 15 | df.columns = ['uid', 'sid', 'behavior', 'timestamp'] 16 | return df -------------------------------------------------------------------------------- /src/datasets/yelp.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from .base import AbstractDataset 4 | 5 | import pandas as pd 6 | 7 | class YelpDataset(AbstractDataset): 8 | @classmethod 9 | def code(cls): 10 | return 'yelp' 11 | 12 | def load_df(self): 13 | folder_path = self._get_rawdata_root_path() 14 | file_path = folder_path.joinpath('yelp.txt') 15 | df = pd.read_csv(file_path, sep='\t', header=None) 16 | df.columns = ['uid', 'sid', 'behavior', 'timestamp'] 17 | return df -------------------------------------------------------------------------------- /src/datasets/ml_10m.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from .base import AbstractDataset 4 | 5 | import pandas as pd 6 | 7 | class ML10MDataset(AbstractDataset): 8 | @classmethod 9 | def code(cls): 10 | return 'ml-10m' 11 | 12 | def load_df(self): 13 | folder_path = self._get_rawdata_root_path() 14 | file_path = 
folder_path.joinpath('ml-10m.txt') 15 | df = pd.read_csv(file_path, sep='\t', header=None) 16 | df.columns = ['uid', 'sid', 'behavior', 'timestamp'] 17 | return df 18 | -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from .ml_10m import ML10MDataset 4 | from .yelp import YelpDataset 5 | from .retail import RetailDataset 6 | from .ijcai import IjcaiDataset 7 | 8 | DATASETS = { 9 | ML10MDataset.code(): ML10MDataset, 10 | YelpDataset.code(): YelpDataset, 11 | RetailDataset.code(): RetailDataset, 12 | IjcaiDataset.code(): IjcaiDataset 13 | } 14 | 15 | 16 | def dataset_factory( 17 | dataset_code, 18 | target_behavior, 19 | multi_behavior, 20 | min_uc, 21 | ): 22 | dataset = DATASETS[dataset_code] 23 | return dataset(target_behavior, multi_behavior, min_uc) 24 | -------------------------------------------------------------------------------- /src/configs/yelp/yelp_full.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.1 10 | n_layer: 2 11 | num_items: 22734 12 | battn: True 13 | bpff: True 14 | brpb: True 15 | b_head: True 16 | 17 | data: 18 | dataset_code: 'yelp' 19 | target_behavior: 'pos' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 22734 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 128 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/yelp/full 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 30 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | seed_everything: 42 43 | 44 | 
optimizer: 45 | class_path: torch.optim.Adam 46 | init_args: 47 | lr: 0.001 48 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/yelp/yelp_battn.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.1 10 | n_layer: 2 11 | num_items: 22734 12 | battn: True 13 | bpff: True 14 | brpb: False 15 | b_head: False 16 | 17 | data: 18 | dataset_code: 'yelp' 19 | target_behavior: 'pos' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 22734 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 128 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/yelp/battn 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 10 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | seed_everything: 42 43 | 44 | optimizer: 45 | class_path: torch.optim.Adam 46 | init_args: 47 | lr: 0.001 48 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/yelp/yelp_bhead.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.1 10 | n_layer: 2 11 | num_items: 22734 12 | battn: False 13 | bpff: False 14 | brpb: False 15 | b_head: True 16 | 17 | data: 18 | dataset_code: 'yelp' 19 | target_behavior: 'pos' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 22734 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 
99 28 | train_batch_size: 128 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/yelp/bhead 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 10 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | seed_everything: 42 43 | 44 | optimizer: 45 | class_path: torch.optim.Adam 46 | init_args: 47 | lr: 0.001 48 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/yelp/yelp_brpb.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 4 8 | n_b: 4 9 | dropout: 0.1 10 | n_layer: 2 11 | num_items: 22734 12 | battn: False 13 | bpff: False 14 | brpb: True 15 | b_head: False 16 | 17 | data: 18 | dataset_code: 'yelp' 19 | target_behavior: 'pos' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 22734 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 128 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/yelp/brpb 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 10 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | seed_everything: 42 43 | 44 | optimizer: 45 | class_path: torch.optim.Adam 46 | init_args: 47 | lr: 0.001 48 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/yelp/yelp_wo_brpb.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 
0.1 10 | n_layer: 2 11 | num_items: 22734 12 | battn: True 13 | bpff: True 14 | brpb: False 15 | b_head: True 16 | 17 | data: 18 | dataset_code: 'yelp' 19 | target_behavior: 'pos' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 22734 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 128 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/yelp/wo-brpb 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 10 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | seed_everything: 42 43 | 44 | optimizer: 45 | class_path: torch.optim.Adam 46 | init_args: 47 | lr: 0.001 48 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/yelp/yelp_bert_mb.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 22734 12 | battn: False 13 | bpff: False 14 | brpb: False 15 | b_head: False 16 | 17 | data: 18 | dataset_code: 'yelp' 19 | target_behavior: 'pos' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 22734 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 128 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/yelp/bert-mb 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 10 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | seed_everything: 42 43 | 44 | optimizer: 45 | class_path: torch.optim.Adam 46 | init_args: 47 | 
lr: 0.001 48 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/yelp/yelp_wo_battn.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.1 10 | n_layer: 2 11 | num_items: 22734 12 | battn: False 13 | bpff: False 14 | brpb: True 15 | b_head: True 16 | 17 | data: 18 | dataset_code: 'yelp' 19 | target_behavior: 'pos' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 22734 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 128 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/yelp/wo-battn 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 10 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | seed_everything: 42 43 | 44 | optimizer: 45 | class_path: torch.optim.Adam 46 | init_args: 47 | lr: 0.001 48 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/yelp/yelp_wo_bhead.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.1 10 | n_layer: 2 11 | num_items: 22734 12 | battn: True 13 | bpff: True 14 | brpb: True 15 | b_head: False 16 | 17 | data: 18 | dataset_code: 'yelp' 19 | target_behavior: 'pos' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 22734 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 128 29 | val_batch_size: 128 30 | 
predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/yelp/wo-bhead 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 10 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | seed_everything: 42 43 | 44 | optimizer: 45 | class_path: torch.optim.Adam 46 | init_args: 47 | lr: 0.001 48 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/yelp/yelp_bert_one.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 1 9 | dropout: 0.1 10 | n_layer: 2 11 | num_items: 22734 12 | battn: False 13 | bpff: False 14 | brpb: False 15 | b_head: False 16 | 17 | data: 18 | dataset_code: 'yelp' 19 | target_behavior: 'pos' 20 | multi_behavior: False 21 | min_uc: 3 22 | num_items: 22734 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 128 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/yelp/bert-one 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 10 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | seed_everything: 42 43 | 44 | optimizer: 45 | class_path: torch.optim.Adam 46 | init_args: 47 | lr: 0.001 48 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/ijcai/ijcai_brpb.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 874306 12 | 
battn: False 13 | bpff: False 14 | brpb: True 15 | b_head: False 16 | 17 | data: 18 | dataset_code: 'ijcai' 19 | target_behavior: 'buy' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 874306 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 64 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/ijcai/brpb 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 10 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | accumulate_grad_batches: 2 43 | seed_everything: 42 44 | 45 | optimizer: 46 | class_path: torch.optim.Adam 47 | init_args: 48 | lr: 0.001 49 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/ijcai/ijcai_full.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 874306 12 | battn: True 13 | bpff: True 14 | brpb: True 15 | b_head: True 16 | 17 | data: 18 | dataset_code: 'ijcai' 19 | target_behavior: 'buy' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 874306 23 | max_len: 50 24 | mask_prob: 0.3 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 32 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/ijcai/full 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 10 39 | mode: max 40 | gpus: '2' 41 | accelerator: ddp 42 | accumulate_grad_batches: 4 43 | seed_everything: 42 44 | 45 | optimizer: 46 | class_path: torch.optim.Adam 47 | 
init_args: 48 | lr: 0.001 49 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/ijcai/ijcai_battn.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 874306 12 | battn: True 13 | bpff: True 14 | brpb: False 15 | b_head: False 16 | 17 | data: 18 | dataset_code: 'ijcai' 19 | target_behavior: 'buy' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 874306 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 43 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/ijcai/battn 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 10 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | accumulate_grad_batches: 3 43 | seed_everything: 42 44 | 45 | optimizer: 46 | class_path: torch.optim.Adam 47 | init_args: 48 | lr: 0.001 49 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/ijcai/ijcai_bhead.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 874306 12 | battn: False 13 | bpff: False 14 | brpb: False 15 | b_head: True 16 | 17 | data: 18 | dataset_code: 'ijcai' 19 | target_behavior: 'buy' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 874306 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | 
train_batch_size: 64 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/ijcai/bhead 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 10 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | accumulate_grad_batches: 2 43 | seed_everything: 42 44 | 45 | optimizer: 46 | class_path: torch.optim.Adam 47 | init_args: 48 | lr: 0.001 49 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/ijcai/ijcai_wo_bhead.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 874306 12 | battn: True 13 | bpff: True 14 | brpb: True 15 | b_head: False 16 | 17 | data: 18 | dataset_code: 'ijcai' 19 | target_behavior: 'buy' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 874306 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 32 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/ijcai/wo-bhead 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 5 39 | mode: max 40 | gpus: '2' 41 | accelerator: ddp 42 | accumulate_grad_batches: 4 43 | seed_everything: 42 44 | 45 | optimizer: 46 | class_path: torch.optim.Adam 47 | init_args: 48 | lr: 0.001 49 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/ijcai/ijcai_wo_brpb.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | 
max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 874306 12 | battn: True 13 | bpff: True 14 | brpb: False 15 | b_head: True 16 | 17 | data: 18 | dataset_code: 'ijcai' 19 | target_behavior: 'buy' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 874306 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 32 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/ijcai/wo-brpb 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 5 39 | mode: max 40 | gpus: '2' 41 | accelerator: ddp 42 | accumulate_grad_batches: 4 43 | seed_everything: 42 44 | 45 | optimizer: 46 | class_path: torch.optim.Adam 47 | init_args: 48 | lr: 0.001 49 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/retail/retail_brpb.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 99037 12 | battn: False 13 | bpff: False 14 | brpb: True 15 | b_head: False 16 | 17 | data: 18 | dataset_code: 'retail' 19 | target_behavior: 'buy' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 99037 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 128 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/retail/brpb 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 5 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 
42 | accumulate_grad_batches: 1 43 | seed_everything: 42 44 | 45 | optimizer: 46 | class_path: torch.optim.Adam 47 | init_args: 48 | lr: 0.001 49 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/retail/retail_full.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 99037 12 | battn: True 13 | bpff: True 14 | brpb: True 15 | b_head: True 16 | 17 | data: 18 | dataset_code: 'retail' 19 | target_behavior: 'buy' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 99037 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 128 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/retail/full 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 10 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | accumulate_grad_batches: 1 43 | seed_everything: 42 44 | 45 | optimizer: 46 | class_path: torch.optim.Adam 47 | init_args: 48 | lr: 0.001 49 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/ijcai/ijcai_bert_mb.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 874306 12 | battn: False 13 | bpff: False 14 | brpb: False 15 | b_head: False 16 | 17 | data: 18 | dataset_code: 'ijcai' 19 | target_behavior: 'buy' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 874306 23 | max_len: 50 
24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 64 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/ijcai/bert-mb 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 5 39 | mode: max 40 | gpus: '1' 41 | accumulate_grad_batches: 2 42 | accelerator: ddp 43 | seed_everything: 42 44 | 45 | optimizer: 46 | class_path: torch.optim.Adam 47 | init_args: 48 | lr: 0.001 49 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/ijcai/ijcai_bert_one.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 874306 12 | battn: False 13 | bpff: False 14 | brpb: False 15 | b_head: False 16 | 17 | data: 18 | dataset_code: 'ijcai' 19 | target_behavior: 'buy' 20 | multi_behavior: False 21 | min_uc: 3 22 | num_items: 874306 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 64 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/ijcai/bert-one 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 5 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | accumulate_grad_batches: 2 43 | seed_everything: 42 44 | 45 | optimizer: 46 | class_path: torch.optim.Adam 47 | init_args: 48 | lr: 0.001 49 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/ijcai/ijcai_wo_battn.yaml: 
-------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 874306 12 | battn: False 13 | bpff: False 14 | brpb: True 15 | b_head: True 16 | 17 | data: 18 | dataset_code: 'ijcai' 19 | target_behavior: 'buy' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 874306 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 64 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/ijcai/wo-battn 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 5 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | accumulate_grad_batches: 2 43 | seed_everything: 42 44 | 45 | optimizer: 46 | class_path: torch.optim.Adam 47 | init_args: 48 | lr: 0.001 49 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/retail/retail_battn.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 99037 12 | battn: True 13 | bpff: True 14 | brpb: False 15 | b_head: False 16 | 17 | data: 18 | dataset_code: 'retail' 19 | target_behavior: 'buy' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 99037 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 128 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/retail/battn 34 | callbacks: 35 | - 
class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 10 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | accumulate_grad_batches: 1 43 | seed_everything: 42 44 | 45 | optimizer: 46 | class_path: torch.optim.Adam 47 | init_args: 48 | lr: 0.001 49 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/retail/retail_bert_mb.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 99037 12 | battn: False 13 | bpff: False 14 | brpb: False 15 | b_head: False 16 | 17 | data: 18 | dataset_code: 'retail' 19 | target_behavior: 'buy' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 99037 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 128 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/retail/bert-mb 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 5 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | accumulate_grad_batches: 1 43 | seed_everything: 42 44 | 45 | optimizer: 46 | class_path: torch.optim.Adam 47 | init_args: 48 | lr: 0.001 49 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/retail/retail_bhead.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 99037 12 | battn: False 13 | bpff: False 14 | brpb: 
False 15 | b_head: True 16 | 17 | data: 18 | dataset_code: 'retail' 19 | target_behavior: 'buy' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 99037 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 128 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/retail/bhead 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 10 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | accumulate_grad_batches: 1 43 | seed_everything: 42 44 | 45 | optimizer: 46 | class_path: torch.optim.Adam 47 | init_args: 48 | lr: 0.001 49 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/retail/retail_wo_battn.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 99037 12 | battn: False 13 | bpff: False 14 | brpb: True 15 | b_head: True 16 | 17 | data: 18 | dataset_code: 'retail' 19 | target_behavior: 'buy' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 99037 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 128 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/retail/wo-battn 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 5 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | accumulate_grad_batches: 1 43 | seed_everything: 42 44 | 45 | optimizer: 46 | class_path: torch.optim.Adam 47 | init_args: 48 | lr: 0.001 49 
| weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/retail/retail_wo_bhead.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 99037 12 | battn: True 13 | bpff: True 14 | brpb: True 15 | b_head: False 16 | 17 | data: 18 | dataset_code: 'retail' 19 | target_behavior: 'buy' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 99037 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 128 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/retail/wo-bhead 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 5 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | accumulate_grad_batches: 1 43 | seed_everything: 42 44 | 45 | optimizer: 46 | class_path: torch.optim.Adam 47 | init_args: 48 | lr: 0.001 49 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/retail/retail_wo_brpb.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 99037 12 | battn: True 13 | bpff: True 14 | brpb: False 15 | b_head: True 16 | 17 | data: 18 | dataset_code: 'retail' 19 | target_behavior: 'buy' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 99037 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 128 
29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/retail/wo-brpb 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 10 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | accumulate_grad_batches: 1 43 | seed_everything: 42 44 | 45 | optimizer: 46 | class_path: torch.optim.Adam 47 | init_args: 48 | lr: 0.001 49 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/retail/retail_bert_one.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 1 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 99037 12 | battn: False 13 | bpff: False 14 | brpb: False 15 | b_head: False 16 | 17 | data: 18 | dataset_code: 'retail' 19 | target_behavior: 'buy' 20 | multi_behavior: False 21 | min_uc: 3 22 | num_items: 99037 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 128 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/retail/bert-one 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 5 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | accumulate_grad_batches: 1 43 | seed_everything: 42 44 | 45 | optimizer: 46 | class_path: torch.optim.Adam 47 | init_args: 48 | lr: 0.001 49 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/yelp/yelp_full_neg.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 
6 | d_model: 32 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.1 10 | n_layer: 2 11 | num_items: 22734 12 | battn: True 13 | bpff: True 14 | brpb: True 15 | b_head: True 16 | 17 | data: 18 | dataset_code: 'yelp' 19 | target_behavior: 'pos' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 22734 23 | max_len: 50 24 | mask_prob: 0.3 25 | num_workers: 4 26 | train_negative_sampler_code: random_train 27 | train_negative_sample_size: 100 28 | val_negative_sampler_code: random 29 | val_negative_sample_size: 99 30 | train_batch_size: 128 31 | val_batch_size: 128 32 | predict_only_target: False 33 | 34 | trainer: 35 | default_root_dir: logs/yelp/full-neg 36 | callbacks: 37 | - class_path: pytorch_lightning.callbacks.EarlyStopping 38 | init_args: 39 | monitor: 'Val:NDCG@10' 40 | patience: 50 41 | mode: max 42 | gpus: '1' 43 | accelerator: ddp 44 | seed_everything: 42 45 | 46 | optimizer: 47 | class_path: torch.optim.Adam 48 | init_args: 49 | lr: 0.001 50 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/ijcai/ijcai_full_neg.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 874306 12 | battn: True 13 | bpff: True 14 | brpb: True 15 | b_head: True 16 | 17 | data: 18 | dataset_code: 'ijcai' 19 | target_behavior: 'buy' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 874306 23 | max_len: 50 24 | mask_prob: 0.3 25 | num_workers: 4 26 | train_negative_sampler_code: random 27 | train_negative_sample_size: 1 28 | val_negative_sampler_code: random 29 | val_negative_sample_size: 99 30 | train_batch_size: 128 31 | val_batch_size: 128 32 | predict_only_target: False 33 | 34 | trainer: 35 | default_root_dir: logs/ijcai/full-neg-1 36 | callbacks: 37 | - class_path: 
pytorch_lightning.callbacks.EarlyStopping 38 | init_args: 39 | monitor: 'Val:NDCG@10' 40 | patience: 10 41 | mode: max 42 | gpus: '1' 43 | accelerator: ddp 44 | accumulate_grad_batches: 1 45 | seed_everything: 42 46 | 47 | optimizer: 48 | class_path: torch.optim.Adam 49 | init_args: 50 | lr: 0.001 51 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/retail/retail_full_neg.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 99037 12 | battn: True 13 | bpff: True 14 | brpb: True 15 | b_head: True 16 | 17 | data: 18 | dataset_code: 'retail' 19 | target_behavior: 'buy' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 99037 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | train_negative_sampler_code: random 27 | train_negative_sample_size: 100 28 | val_negative_sampler_code: random 29 | val_negative_sample_size: 99 30 | train_batch_size: 128 31 | val_batch_size: 128 32 | predict_only_target: False 33 | 34 | trainer: 35 | default_root_dir: logs/retail/full-neg 36 | callbacks: 37 | - class_path: pytorch_lightning.callbacks.EarlyStopping 38 | init_args: 39 | monitor: 'Val:NDCG@10' 40 | patience: 10 41 | mode: max 42 | gpus: '1' 43 | accelerator: ddp 44 | accumulate_grad_batches: 1 45 | seed_everything: 42 46 | 47 | optimizer: 48 | class_path: torch.optim.Adam 49 | init_args: 50 | lr: 0.001 51 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/models/sasrec.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch import nn as nn 4 | import pytorch_lightning as pl 5 | from .embedding import BERTEmbedding 6 | from .transformer import TransformerBlock 7 | 8 | 
class SAS(pl.LightningModule):
    """Causal self-attention encoder over sequences of item ids.

    Item ids are embedded with ``BERTEmbedding`` and passed through
    ``n_layer`` ``TransformerBlock`` layers. The attention mask combines a
    padding mask (ids that are not ``> 0``) with a lower-triangular mask.

    Args:
        max_len: forwarded to ``BERTEmbedding`` as ``max_len``.
        num_items: number of items; the embedding vocabulary has
            ``num_items + 1`` entries (presumably id 0 is padding -- TODO
            confirm against the dataloader).
        n_layer: number of transformer blocks to stack.
        n_head: forwarded to every ``TransformerBlock``.
        d_model: embedding size; each block also receives ``d_model * 4``
            as its third argument.
        dropout: forwarded to the embedding and to every block.
    """

    def __init__(self,
        max_len,
        num_items,
        n_layer,
        n_head,
        d_model,
        dropout
    ):
        super().__init__()
        self.d_model = d_model
        self.num_items = num_items

        vocab_size = num_items + 1
        self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=d_model, max_len=max_len, dropout=dropout)
        # multi-layers transformer blocks, deep network
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(d_model, n_head, d_model * 4, dropout) for _ in range(n_layer)])

    def forward(self, x):
        """Encode ``x`` and return the output of the last transformer block.

        ``x`` must be a 2-D tensor of ids (the mask construction below
        repeats it along a new second axis). The mask has shape
        ``(x.size(0), 1, x.size(1), x.size(1))``.
        """
        # Padding mask: key positions whose id is not > 0 are masked out.
        mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)
        tl = x.shape[1]  # time dim len for enforce causality
        # Lower-triangular mask: position i may only attend to positions <= i.
        mask *= torch.tril(torch.ones((tl, tl), dtype=torch.bool, device=self.device))
        # embedding the indexed sequence to sequence of vectors
        x = self.embedding(x)
        # running over multiple transformer blocks
        for transformer in self.transformer_blocks:
            x = transformer.forward(x, mask)
        return x


# --------------------------------------------------------------------------------
# /src/dataloaders/negative_samplers/base.py:
# --------------------------------------------------------------------------------
from abc import *
from pathlib import Path
import pickle


class AbstractNegativeSampler(metaclass=ABCMeta):
    """Base class for negative samplers whose output is cached on disk.

    Subclasses provide ``code()`` (a short identifier used in the cache file
    name and in the sampler registry) and ``generate_negative_samples()``.

    Args:
        train: stored as ``self.train`` for use by subclasses.
        val: stored as ``self.val`` for use by subclasses.
        user_count: stored as ``self.user_count``.
        item_count: stored as ``self.item_count``.
        sample_size: number of negatives requested; part of the cache name.
        save_folder: folder in which the pickle cache is read and written.
    """

    def __init__(self, train, val, user_count, item_count, sample_size, save_folder):
        self.train = train
        self.val = val
        self.user_count = user_count
        self.item_count = item_count
        self.sample_size = sample_size
        self.save_folder = save_folder

    @classmethod
    @abstractmethod
    def code(cls):
        """Return the identifier of this sampler."""
        pass

    @abstractmethod
    def generate_negative_samples(self):
        """Compute and return the negative samples (not cached here)."""
        pass

    def get_negative_samples(self):
        """Return the negative samples, loading them from the cache if present.

        If the cache file exists it is unpickled and returned. Otherwise the
        samples are generated, written to the cache file and returned.

        NOTE: the cache key is only ``code()`` and ``sample_size``; a cache
        produced from different data in the same folder is reused as is.
        """
        savefile_path = self._get_save_path()
        if savefile_path.is_file():
            print('Negatives samples exist. Loading.')
            # NOTE(security): pickle executes arbitrary code on load; only
            # point ``save_folder`` at locations you trust.
            # The context manager closes the handle the original leaked.
            with savefile_path.open('rb') as f:
                negative_samples = pickle.load(f)
            return negative_samples
        print("Negative samples don't exist. Generating.")
        negative_samples = self.generate_negative_samples()
        # Make sure the target folder exists so the expensive sampling result
        # is not lost to a FileNotFoundError when writing the cache.
        savefile_path.parent.mkdir(parents=True, exist_ok=True)
        with savefile_path.open('wb') as f:
            pickle.dump(negative_samples, f)
        return negative_samples

    def _get_save_path(self):
        """Return the cache path ``<save_folder>/<code>-sample_size<N>.pkl``."""
        folder = Path(self.save_folder)
        filename = '{}-sample_size{}.pkl'.format(self.code(), self.sample_size)
        return folder.joinpath(filename)


# --------------------------------------------------------------------------------
# /src/dataloaders/base.py:
# --------------------------------------------------------------------------------
from .negative_samplers import negative_sampler_factory

from abc import *


class AbstractDataloader(metaclass=ABCMeta):
    """Base class that loads a dataset and its validation negatives.

    Args:
        dataset: object exposing ``_get_preprocessed_folder_path()`` and
            ``load_dataset()``; the latter must return a mapping with the
            keys ``train``, ``val``, ``train_b``, ``val_b``, ``val_num``,
            ``umap``, ``smap`` and ``bmap``.
        val_negative_sampler_code: registry code passed to
            ``negative_sampler_factory``.
        val_negative_sample_size: number of negatives requested from the
            validation sampler.
    """

    def __init__(self,
        dataset,
        val_negative_sampler_code,
        val_negative_sample_size
    ):
        # The negative-sample cache lives next to the preprocessed dataset.
        save_folder = dataset._get_preprocessed_folder_path()
        dataset = dataset.load_dataset()
        self.train = dataset['train']
        self.val = dataset['val']
        self.train_b = dataset['train_b']
        self.val_b = dataset['val_b']
        self.val_num = dataset['val_num']
        self.umap = dataset['umap']
        self.smap = dataset['smap']
        self.bmap = dataset['bmap']
        # Counts are the sizes of the three maps loaded above.
        self.user_count = len(self.umap)
        self.item_count = len(self.smap)
        self.behavior_count = len(self.bmap)

        val_negative_sampler = negative_sampler_factory(val_negative_sampler_code, self.train, self.val,
                                                        self.user_count, self.item_count,
                                                        val_negative_sample_size,
                                                        save_folder)
        self.val_negative_samples = val_negative_sampler.get_negative_samples()

    @abstractmethod
    def get_train_loader(self):
        """Return the training data loader."""
        pass

    @abstractmethod
    def get_val_loader(self):
        """Return the validation data loader."""
        pass


# --------------------------------------------------------------------------------
# /src/utils.py:
# --------------------------------------------------------------------------------
import torch


def _dcg_weights(width):
    """Return the DCG discounts ``1 / log2(rank + 1)`` for ranks ``1..width``."""
    position = torch.arange(2, 2 + width)
    return 1 / torch.log2(position.float())


def _ideal_dcg(answer_count, weights):
    """Return, per row, the DCG of a ranking that lists every positive first.

    ``weights`` holds at most ``k`` discounts, so slicing it with a row's
    positive count keeps ``min(count, k)`` of them. ``int(n)`` lets the
    counts come from float label tensors as well as integer ones.
    """
    return torch.Tensor([weights[:int(n)].sum() for n in answer_count])


def recall(scores, labels, k):
    """Return the mean Recall@k over the rows of ``scores``.

    Args:
        scores: 2-D tensor of candidate scores (higher ranks first).
        labels: tensor of the same shape marking the positive candidates.
        k: cut-off rank.

    Returns:
        A Python float. Rows without any positive contribute NaN (0 / 0).
    """
    scores = scores.cpu()
    labels = labels.cpu()
    rank = (-scores).argsort(dim=1)
    cut = rank[:, :k]
    hit = labels.gather(1, cut)
    return (hit.sum(1).float() / labels.sum(1).float()).mean().item()


def ndcg(scores, labels, k):
    """Return the mean NDCG@k over the rows of ``scores`` as a 0-dim tensor.

    ``k`` may exceed the number of candidate columns; the ranking is then
    simply evaluated over all candidates instead of raising a shape error.
    """
    scores = scores.cpu()
    labels = labels.cpu()
    rank = (-scores).argsort(dim=1)
    cut = rank[:, :k]
    hits = labels.gather(1, cut)
    # Size the discounts from the columns actually kept, not from k.
    weights = _dcg_weights(hits.size(1))
    dcg = (hits.float() * weights).sum(1)
    idcg = _ideal_dcg(labels.sum(1), weights)
    return (dcg / idcg).mean()


def recalls_and_ndcgs_for_ks(scores, labels, ks):
    """Compute Recall@k and NDCG@k for every ``k`` in ``ks`` in one pass.

    Returns:
        A dict with keys ``'Recall@<k>'`` (Python floats) and ``'NDCG@<k>'``
        (0-dim tensors), one pair per requested ``k``.
    """
    metrics = {}

    scores = scores.cpu()
    labels = labels.cpu()
    answer_count = labels.sum(1)
    answer_count_float = answer_count.float()
    labels_float = labels.float()
    rank = (-scores).argsort(dim=1)
    cut = rank
    # Largest k first so each smaller cut is a prefix of the previous one.
    for k in sorted(ks, reverse=True):
        cut = cut[:, :k]
        hits = labels_float.gather(1, cut)
        metrics['Recall@%d' % k] = (hits.sum(1) / answer_count_float).mean().item()

        # Size the discounts from the columns actually kept, not from k.
        weights = _dcg_weights(hits.size(1))
        dcg = (hits * weights).sum(1)
        idcg = _ideal_dcg(answer_count, weights)
        metrics['NDCG@%d' % k] = (dcg / idcg).mean()

    return metrics


def split_at_index(dim, index, t):
    """Split ``t`` along axis ``dim`` into ``t[..., :index]`` and ``t[..., index:]``.

    Leading axes before ``dim`` are taken whole. ``dim`` must be
    non-negative because it is used as a count of leading full slices.
    """
    pre_slices = (slice(None),) * dim
    l = (*pre_slices, slice(None, index))
    r = (*pre_slices, slice(index, None))
    return t[l], t[r]


# --------------------------------------------------------------------------------
# /src/dataloaders/negative_samplers/random.py:
# --------------------------------------------------------------------------------
from .base import AbstractNegativeSampler

from tqdm import trange

import numpy as np


class RandomNegativeSampler(AbstractNegativeSampler):
    """Negative sampler that draws unseen items uniformly at random."""

    @classmethod
    def code(cls):
        """Return the registry code of this sampler."""
        return 'random'

    def generate_negative_samples(self):
        """Draw ``sample_size`` distinct unseen items for every user in ``val``.

        Users are ids ``1..user_count``; those absent from ``self.val`` are
        skipped. Items are drawn from ``1..item_count`` with the global NumPy
        RNG and rejected while they appear in the user's train/val items or
        among the negatives already drawn for that user.

        Returns:
            dict mapping user id to the list of sampled item ids.

        NOTE: the rejection loop does not terminate if a user has fewer than
        ``sample_size`` unseen items.
        """
        negative_samples = {}
        print('Sampling negative items')
        for user in trange(1, self.user_count + 1):
            if user not in self.val:
                continue
            # One set holds both the user's interactions and the negatives
            # picked so far, so each rejection test is a single O(1) lookup
            # instead of an additional scan over the ``samples`` list.
            excluded = set(self.train[user])
            excluded.update(self.val[user])

            samples = []
            for _ in range(self.sample_size):
                item = np.random.choice(self.item_count) + 1
                while item in excluded:
                    item = np.random.choice(self.item_count) + 1
                samples.append(item)
                excluded.add(item)

            negative_samples[user] = samples

        return negative_samples

# class RandomNegativeSamplerTrain(AbstractNegativeSampler):
#     @classmethod
#     def code(cls):
#         return 'random_train'

#     def generate_negative_samples(self):
#         negative_samples = {}
#         print('Sampling negative items')
#         for user in trange(1, self.user_count+1):
#             seen = set(self.train[user])
#             if user in self.val.keys():
#                 seen.update(self.val[user])
#             samples = []
#             for _ in range(self.sample_size):
#                 item = np.random.choice(self.item_count) + 1
#                 while item in seen or item in samples:
#                     item = np.random.choice(self.item_count) + 1
#                 samples.append(item)

#             negative_samples[user] = samples

#         return negative_samples


# --------------------------------------------------------------------------------
# /src/models/bert4rec.py:
# --------------------------------------------------------------------------------
import torch
from torch import nn as nn
import pytorch_lightning as pl
from .embedding import BERTEmbedding, SimpleEmbedding
from .new_transformer import TransformerBlock


class BERT(pl.LightningModule):
    """Bidirectional transformer encoder over item and behavior sequences.

    Args:
        max_len: accepted for configuration compatibility; not used by the
            embedding that is currently selected.
        num_items: number of items in the catalogue.
        n_layer: number of transformer blocks to stack.
        n_head: forwarded to every ``TransformerBlock``.
        n_b: number of behaviors; also enlarges the vocabulary (see below).
        d_model: embedding size; each block also receives ``d_model * 4``.
        dropout: forwarded to the embedding and to every block.
        battn: forwarded to every ``TransformerBlock``.
        bpff: forwarded to every ``TransformerBlock``.
        brpb: forwarded to every ``TransformerBlock``.
    """

    def __init__(self,
        max_len: int = None,
        num_items: int = None,
        n_layer: int = None,
        n_head: int = None,
        n_b: int = None,
        d_model: int = None,
        dropout: float = .0,
        battn: bool = None,
        bpff: bool = None,
        brpb: bool = None,
    ):
        super().__init__()
        self.d_model = d_model
        self.num_items = num_items
        self.n_b = n_b
        self.battn = battn
        self.bpff = bpff
        self.brpb = brpb

        # add padding and mask (presumably one mask id per behavior, hence
        # the extra ``n_b`` entries -- TODO confirm against the dataloader)
        vocab_size = num_items + 1 + n_b
        # The original code selected the embedding with ``if True:`` (a
        # disabled ``if self.brpb:``), so the ``BERTEmbedding`` branch (sum of
        # positional and token embeddings, the only user of ``max_len``) was
        # unreachable. The simple embedding is therefore always used.
        self.embedding = SimpleEmbedding(vocab_size=vocab_size, embed_size=d_model, dropout=dropout)
        # multi-layers transformer blocks
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(d_model, n_head, d_model * 4, n_b, battn, bpff, brpb, dropout) for _ in range(n_layer)])

    def forward(self, x, b_seq):
        """Encode item ids ``x`` together with their behavior sequence ``b_seq``.

        Returns the output of the last transformer block.
        """
        # get padding masks
        mask = (x > 0)
        # embedding the indexed sequence to sequence of vectors
        x = self.embedding(x)
        # running over multiple transformer blocks
        for transformer in self.transformer_blocks:
            x = transformer.forward(x, b_seq, mask)
        return x


# --------------------------------------------------------------------------------
# /README.md:
# --------------------------------------------------------------------------------