├── src ├── __init__.py ├── dataloaders │ ├── __init__.py │ ├── negative_samplers │ │ ├── __init__.py │ │ ├── base.py │ │ └── random.py │ ├── base.py │ └── rec_dataloader.py ├── models │ ├── __init__.py │ ├── sasrec.py │ ├── bert4rec.py │ ├── relative_position.py │ ├── embedding.py │ ├── utils.py │ ├── heads.py │ └── new_transformer.py ├── datasets │ ├── ijcai.py │ ├── retail.py │ ├── yelp.py │ ├── ml_10m.py │ ├── __init__.py │ └── base.py ├── configs │ ├── yelp │ │ ├── yelp_full.yaml │ │ ├── yelp_battn.yaml │ │ ├── yelp_bhead.yaml │ │ ├── yelp_brpb.yaml │ │ ├── yelp_wo_brpb.yaml │ │ ├── yelp_bert_mb.yaml │ │ ├── yelp_wo_battn.yaml │ │ ├── yelp_wo_bhead.yaml │ │ ├── yelp_bert_one.yaml │ │ └── yelp_full_neg.yaml │ ├── ijcai │ │ ├── ijcai_brpb.yaml │ │ ├── ijcai_full.yaml │ │ ├── ijcai_battn.yaml │ │ ├── ijcai_bhead.yaml │ │ ├── ijcai_wo_bhead.yaml │ │ ├── ijcai_wo_brpb.yaml │ │ ├── ijcai_bert_mb.yaml │ │ ├── ijcai_bert_one.yaml │ │ ├── ijcai_wo_battn.yaml │ │ └── ijcai_full_neg.yaml │ └── retail │ │ ├── retail_brpb.yaml │ │ ├── retail_full.yaml │ │ ├── retail_battn.yaml │ │ ├── retail_bert_mb.yaml │ │ ├── retail_bhead.yaml │ │ ├── retail_wo_battn.yaml │ │ ├── retail_wo_bhead.yaml │ │ ├── retail_wo_brpb.yaml │ │ ├── retail_bert_one.yaml │ │ └── retail_full_neg.yaml ├── utils.py ├── model.py └── datamodule.py ├── requirements.txt ├── mb-str.jpg ├── run.py ├── README.md └── LICENSE /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch-lightning == 1.5.3 2 | torch >= 1.3.0 -------------------------------------------------------------------------------- /mb-str.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanenming/mb-str/HEAD/mb-str.jpg -------------------------------------------------------------------------------- /src/dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .rec_dataloader import RecDataloader 3 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .bert4rec import BERT 3 | from .heads import DotProductPredictionHead, CGCDotProductPredictionHead -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | 2 | import pytorch_lightning as pl 3 | from pytorch_lightning.utilities.cli import LightningCLI 4 | from src.model import RecModel 5 | from src.datamodule import RecDataModule 6 | 7 | 8 | def cli_main(): 9 | cli = LightningCLI(RecModel, RecDataModule, save_config_overwrite=True) 10 | 11 | if __name__ == '__main__': 12 | cli_main() -------------------------------------------------------------------------------- /src/dataloaders/negative_samplers/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .random import RandomNegativeSampler 3 | 4 | 5 | NEGATIVE_SAMPLERS = { 6 | RandomNegativeSampler.code(): RandomNegativeSampler, 7 | } 8 | 9 | def negative_sampler_factory(code, train, val, user_count, item_count, sample_size, save_folder): 10 | negative_sampler = NEGATIVE_SAMPLERS[code] 11 | return 
negative_sampler(train, val, user_count, item_count, sample_size, save_folder) 12 | -------------------------------------------------------------------------------- /src/datasets/ijcai.py: -------------------------------------------------------------------------------- 1 | 2 | from .base import AbstractDataset 3 | 4 | import pandas as pd 5 | 6 | class IjcaiDataset(AbstractDataset): 7 | @classmethod 8 | def code(cls): 9 | return 'ijcai' 10 | 11 | def load_df(self): 12 | folder_path = self._get_rawdata_root_path() 13 | file_path = folder_path.joinpath('ijcai.txt') 14 | df = pd.read_csv(file_path, sep='\t', header=None) 15 | df.columns = ['uid', 'sid', 'behavior', 'timestamp'] 16 | return df -------------------------------------------------------------------------------- /src/datasets/retail.py: -------------------------------------------------------------------------------- 1 | 2 | from .base import AbstractDataset 3 | 4 | import pandas as pd 5 | 6 | class RetailDataset(AbstractDataset): 7 | @classmethod 8 | def code(cls): 9 | return 'retail' 10 | 11 | def load_df(self): 12 | folder_path = self._get_rawdata_root_path() 13 | file_path = folder_path.joinpath('retail.txt') 14 | df = pd.read_csv(file_path, sep='\t', header=None) 15 | df.columns = ['uid', 'sid', 'behavior', 'timestamp'] 16 | return df -------------------------------------------------------------------------------- /src/datasets/yelp.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from .base import AbstractDataset 4 | 5 | import pandas as pd 6 | 7 | class YelpDataset(AbstractDataset): 8 | @classmethod 9 | def code(cls): 10 | return 'yelp' 11 | 12 | def load_df(self): 13 | folder_path = self._get_rawdata_root_path() 14 | file_path = folder_path.joinpath('yelp.txt') 15 | df = pd.read_csv(file_path, sep='\t', header=None) 16 | df.columns = ['uid', 'sid', 'behavior', 'timestamp'] 17 | return df -------------------------------------------------------------------------------- /src/datasets/ml_10m.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from .base import AbstractDataset 4 | 5 | import pandas as pd 6 | 7 | class ML10MDataset(AbstractDataset): 8 | @classmethod 9 | def code(cls): 10 | return 'ml-10m' 11 | 12 | def load_df(self): 13 | folder_path = self._get_rawdata_root_path() 14 | file_path = folder_path.joinpath('ml-10m.txt') 15 | df = pd.read_csv(file_path, sep='\t', header=None) 16 | df.columns = ['uid', 'sid', 'behavior', 'timestamp'] 17 | return df 18 | -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from .ml_10m import ML10MDataset 4 | from .yelp import YelpDataset 5 | from .retail import RetailDataset 6 | from .ijcai import IjcaiDataset 7 | 8 | DATASETS = { 9 | ML10MDataset.code(): ML10MDataset, 10 | YelpDataset.code(): YelpDataset, 11 | RetailDataset.code(): RetailDataset, 12 | IjcaiDataset.code(): IjcaiDataset 13 | } 14 | 15 | 16 | def dataset_factory( 17 | dataset_code, 18 | target_behavior, 19 | multi_behavior, 20 | min_uc, 21 | ): 22 | dataset = DATASETS[dataset_code] 23 | return dataset(target_behavior, multi_behavior, min_uc) 24 | -------------------------------------------------------------------------------- /src/configs/yelp/yelp_full.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: 
src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.1 10 | n_layer: 2 11 | num_items: 22734 12 | battn: True 13 | bpff: True 14 | brpb: True 15 | b_head: True 16 | 17 | data: 18 | dataset_code: 'yelp' 19 | target_behavior: 'pos' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 22734 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 128 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/yelp/full 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 30 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | seed_everything: 42 43 | 44 | optimizer: 45 | class_path: torch.optim.Adam 46 | init_args: 47 | lr: 0.001 48 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/yelp/yelp_battn.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.1 10 | n_layer: 2 11 | num_items: 22734 12 | battn: True 13 | bpff: True 14 | brpb: False 15 | b_head: False 16 | 17 | data: 18 | dataset_code: 'yelp' 19 | target_behavior: 'pos' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 22734 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 128 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/yelp/battn 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 10 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | seed_everything: 42 43 | 44 | optimizer: 45 | class_path: torch.optim.Adam 46 | init_args: 47 | lr: 0.001 48 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/yelp/yelp_bhead.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.1 10 | n_layer: 2 11 | num_items: 22734 12 | battn: False 13 | bpff: False 14 | brpb: False 15 | b_head: True 16 | 17 | data: 18 | dataset_code: 'yelp' 19 | target_behavior: 'pos' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 22734 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 128 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/yelp/bhead 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 10 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | seed_everything: 42 43 | 44 | optimizer: 45 | class_path: torch.optim.Adam 46 | init_args: 47 | lr: 0.001 48 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/yelp/yelp_brpb.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: 
src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 4 8 | n_b: 4 9 | dropout: 0.1 10 | n_layer: 2 11 | num_items: 22734 12 | battn: False 13 | bpff: False 14 | brpb: True 15 | b_head: False 16 | 17 | data: 18 | dataset_code: 'yelp' 19 | target_behavior: 'pos' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 22734 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 128 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/yelp/brpb 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 10 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | seed_everything: 42 43 | 44 | optimizer: 45 | class_path: torch.optim.Adam 46 | init_args: 47 | lr: 0.001 48 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/yelp/yelp_wo_brpb.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.1 10 | n_layer: 2 11 | num_items: 22734 12 | battn: True 13 | bpff: True 14 | brpb: False 15 | b_head: True 16 | 17 | data: 18 | dataset_code: 'yelp' 19 | target_behavior: 'pos' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 22734 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 128 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/yelp/wo-brpb 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 10 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | seed_everything: 42 43 | 44 | optimizer: 45 | class_path: torch.optim.Adam 46 | init_args: 47 | lr: 0.001 48 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/yelp/yelp_bert_mb.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 22734 12 | battn: False 13 | bpff: False 14 | brpb: False 15 | b_head: False 16 | 17 | data: 18 | dataset_code: 'yelp' 19 | target_behavior: 'pos' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 22734 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 128 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/yelp/bert-mb 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 10 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | seed_everything: 42 43 | 44 | optimizer: 45 | class_path: torch.optim.Adam 46 | init_args: 47 | lr: 0.001 48 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/yelp/yelp_wo_battn.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | 
class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.1 10 | n_layer: 2 11 | num_items: 22734 12 | battn: False 13 | bpff: False 14 | brpb: True 15 | b_head: True 16 | 17 | data: 18 | dataset_code: 'yelp' 19 | target_behavior: 'pos' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 22734 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 128 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/yelp/wo-battn 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 10 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | seed_everything: 42 43 | 44 | optimizer: 45 | class_path: torch.optim.Adam 46 | init_args: 47 | lr: 0.001 48 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/yelp/yelp_wo_bhead.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.1 10 | n_layer: 2 11 | num_items: 22734 12 | battn: True 13 | bpff: True 14 | brpb: True 15 | b_head: False 16 | 17 | data: 18 | dataset_code: 'yelp' 19 | target_behavior: 'pos' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 22734 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 128 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/yelp/wo-bhead 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 10 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | seed_everything: 42 43 | 44 | optimizer: 45 | class_path: torch.optim.Adam 46 | init_args: 47 | lr: 0.001 48 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/yelp/yelp_bert_one.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 1 9 | dropout: 0.1 10 | n_layer: 2 11 | num_items: 22734 12 | battn: False 13 | bpff: False 14 | brpb: False 15 | b_head: False 16 | 17 | data: 18 | dataset_code: 'yelp' 19 | target_behavior: 'pos' 20 | multi_behavior: False 21 | min_uc: 3 22 | num_items: 22734 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 128 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/yelp/bert-one 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 10 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | seed_everything: 42 43 | 44 | optimizer: 45 | class_path: torch.optim.Adam 46 | init_args: 47 | lr: 0.001 48 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/ijcai/ijcai_brpb.yaml: -------------------------------------------------------------------------------- 1 | model: 2 
| backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 874306 12 | battn: False 13 | bpff: False 14 | brpb: True 15 | b_head: False 16 | 17 | data: 18 | dataset_code: 'ijcai' 19 | target_behavior: 'buy' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 874306 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 64 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/ijcai/brpb 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 10 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | accumulate_grad_batches: 2 43 | seed_everything: 42 44 | 45 | optimizer: 46 | class_path: torch.optim.Adam 47 | init_args: 48 | lr: 0.001 49 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/ijcai/ijcai_full.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 874306 12 | battn: True 13 | bpff: True 14 | brpb: True 15 | b_head: True 16 | 17 | data: 18 | dataset_code: 'ijcai' 19 | target_behavior: 'buy' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 874306 23 | max_len: 50 24 | mask_prob: 0.3 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 32 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/ijcai/full 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 10 39 | mode: max 40 | gpus: '2' 41 | accelerator: ddp 42 | accumulate_grad_batches: 4 43 | seed_everything: 42 44 | 45 | optimizer: 46 | class_path: torch.optim.Adam 47 | init_args: 48 | lr: 0.001 49 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/ijcai/ijcai_battn.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 874306 12 | battn: True 13 | bpff: True 14 | brpb: False 15 | b_head: False 16 | 17 | data: 18 | dataset_code: 'ijcai' 19 | target_behavior: 'buy' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 874306 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 43 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/ijcai/battn 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 10 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | accumulate_grad_batches: 3 43 | seed_everything: 42 44 | 45 | optimizer: 46 | class_path: torch.optim.Adam 47 | init_args: 48 | lr: 0.001 49 | weight_decay: 0.000001 -------------------------------------------------------------------------------- 
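The config files above and below differ only in which behavior-aware components they switch on; the file-name suffixes encode the ablation (`*_full` = all four components on, `*_battn`/`*_bhead`/`*_brpb` = a single component on, `*_wo_*` = a single component off, `*_bert_mb` = plain BERT on multi-behavior sequences, `*_bert_one` = plain BERT on a single behavior). As a minimal sketch for inspecting a config's flags — this helper is hypothetical (not part of the repo) and assumes PyYAML plus the `model.backbone.init_args` layout shown in these files; note `b_head` may sit under `model` rather than `init_args`:

```python
import yaml

def ablation_flags(config_path):
    """Return the four MB-STR component switches declared in a config file."""
    with open(config_path) as f:
        cfg = yaml.safe_load(f)
    model = cfg["model"]
    init_args = model["backbone"]["init_args"]
    # fall back to the model level for flags (e.g. b_head) not set on the backbone
    return {k: init_args.get(k, model.get(k)) for k in ("battn", "bpff", "brpb", "b_head")}

# e.g. ablation_flags("src/configs/ijcai/ijcai_full.yaml")
# -> {'battn': True, 'bpff': True, 'brpb': True, 'b_head': True}
```

Any of these files can be passed straight to the CLI entry point, e.g. `python run.py fit --config src/configs/ijcai/ijcai_full.yaml`.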
/src/configs/ijcai/ijcai_bhead.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 874306 12 | battn: False 13 | bpff: False 14 | brpb: False 15 | b_head: True 16 | 17 | data: 18 | dataset_code: 'ijcai' 19 | target_behavior: 'buy' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 874306 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 64 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/ijcai/bhead 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 10 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | accumulate_grad_batches: 2 43 | seed_everything: 42 44 | 45 | optimizer: 46 | class_path: torch.optim.Adam 47 | init_args: 48 | lr: 0.001 49 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/ijcai/ijcai_wo_bhead.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 874306 12 | battn: True 13 | bpff: True 14 | brpb: True 15 | b_head: False 16 | 17 | data: 18 | dataset_code: 'ijcai' 19 | target_behavior: 'buy' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 874306 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 32 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/ijcai/wo-bhead 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 5 39 | mode: max 40 | gpus: '2' 41 | accelerator: ddp 42 | accumulate_grad_batches: 4 43 | seed_everything: 42 44 | 45 | optimizer: 46 | class_path: torch.optim.Adam 47 | init_args: 48 | lr: 0.001 49 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/ijcai/ijcai_wo_brpb.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 874306 12 | battn: True 13 | bpff: True 14 | brpb: False 15 | b_head: True 16 | 17 | data: 18 | dataset_code: 'ijcai' 19 | target_behavior: 'buy' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 874306 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 32 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/ijcai/wo-brpb 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 5 39 | mode: max 40 | gpus: '2' 41 | accelerator: ddp 42 | accumulate_grad_batches: 4 43 | seed_everything: 42 44 | 45 | optimizer: 46 | class_path: torch.optim.Adam 47 | init_args: 48 | lr: 
0.001 49 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/retail/retail_brpb.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 99037 12 | battn: False 13 | bpff: False 14 | brpb: True 15 | b_head: False 16 | 17 | data: 18 | dataset_code: 'retail' 19 | target_behavior: 'buy' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 99037 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 128 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/retail/brpb 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 5 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | accumulate_grad_batches: 1 43 | seed_everything: 42 44 | 45 | optimizer: 46 | class_path: torch.optim.Adam 47 | init_args: 48 | lr: 0.001 49 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/retail/retail_full.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 99037 12 | battn: True 13 | bpff: True 14 | brpb: True 15 | b_head: True 16 | 17 | data: 18 | dataset_code: 'retail' 19 | target_behavior: 'buy' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 99037 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 128 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/retail/full 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 10 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | accumulate_grad_batches: 1 43 | seed_everything: 42 44 | 45 | optimizer: 46 | class_path: torch.optim.Adam 47 | init_args: 48 | lr: 0.001 49 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/ijcai/ijcai_bert_mb.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 874306 12 | battn: False 13 | bpff: False 14 | brpb: False 15 | b_head: False 16 | 17 | data: 18 | dataset_code: 'ijcai' 19 | target_behavior: 'buy' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 874306 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 64 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/ijcai/bert-mb 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 5 39 | mode: max 40 | gpus: '1' 41 | accumulate_grad_batches: 2 42 | 
accelerator: ddp 43 | seed_everything: 42 44 | 45 | optimizer: 46 | class_path: torch.optim.Adam 47 | init_args: 48 | lr: 0.001 49 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/ijcai/ijcai_bert_one.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 874306 12 | battn: False 13 | bpff: False 14 | brpb: False 15 | b_head: False 16 | 17 | data: 18 | dataset_code: 'ijcai' 19 | target_behavior: 'buy' 20 | multi_behavior: False 21 | min_uc: 3 22 | num_items: 874306 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 64 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/ijcai/bert-one 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 5 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | accumulate_grad_batches: 2 43 | seed_everything: 42 44 | 45 | optimizer: 46 | class_path: torch.optim.Adam 47 | init_args: 48 | lr: 0.001 49 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/ijcai/ijcai_wo_battn.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 874306 12 | battn: False 13 | bpff: False 14 | brpb: True 15 | b_head: True 16 | 17 | data: 18 | dataset_code: 'ijcai' 19 | target_behavior: 'buy' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 874306 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 64 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/ijcai/wo-battn 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 5 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | accumulate_grad_batches: 2 43 | seed_everything: 42 44 | 45 | optimizer: 46 | class_path: torch.optim.Adam 47 | init_args: 48 | lr: 0.001 49 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/retail/retail_battn.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 99037 12 | battn: True 13 | bpff: True 14 | brpb: False 15 | b_head: False 16 | 17 | data: 18 | dataset_code: 'retail' 19 | target_behavior: 'buy' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 99037 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 128 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/retail/battn 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | 
init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 10 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | accumulate_grad_batches: 1 43 | seed_everything: 42 44 | 45 | optimizer: 46 | class_path: torch.optim.Adam 47 | init_args: 48 | lr: 0.001 49 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/retail/retail_bert_mb.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 99037 12 | battn: False 13 | bpff: False 14 | brpb: False 15 | b_head: False 16 | 17 | data: 18 | dataset_code: 'retail' 19 | target_behavior: 'buy' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 99037 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 128 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/retail/bert-mb 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 5 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | accumulate_grad_batches: 1 43 | seed_everything: 42 44 | 45 | optimizer: 46 | class_path: torch.optim.Adam 47 | init_args: 48 | lr: 0.001 49 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/retail/retail_bhead.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 99037 12 | battn: False 13 | bpff: False 14 | brpb: False 15 | b_head: True 16 | 17 | data: 18 | dataset_code: 'retail' 19 | target_behavior: 'buy' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 99037 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 128 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/retail/bhead 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 10 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | accumulate_grad_batches: 1 43 | seed_everything: 42 44 | 45 | optimizer: 46 | class_path: torch.optim.Adam 47 | init_args: 48 | lr: 0.001 49 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/retail/retail_wo_battn.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 99037 12 | battn: False 13 | bpff: False 14 | brpb: True 15 | b_head: True 16 | 17 | data: 18 | dataset_code: 'retail' 19 | target_behavior: 'buy' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 99037 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 128 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | 
trainer: 33 | default_root_dir: logs/retail/wo-battn 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 5 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | accumulate_grad_batches: 1 43 | seed_everything: 42 44 | 45 | optimizer: 46 | class_path: torch.optim.Adam 47 | init_args: 48 | lr: 0.001 49 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/retail/retail_wo_bhead.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 99037 12 | battn: True 13 | bpff: True 14 | brpb: True 15 | b_head: False 16 | 17 | data: 18 | dataset_code: 'retail' 19 | target_behavior: 'buy' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 99037 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 128 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/retail/wo-bhead 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 5 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | accumulate_grad_batches: 1 43 | seed_everything: 42 44 | 45 | optimizer: 46 | class_path: torch.optim.Adam 47 | init_args: 48 | lr: 0.001 49 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/retail/retail_wo_brpb.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 99037 12 | battn: True 13 | bpff: True 14 | brpb: False 15 | b_head: True 16 | 17 | data: 18 | dataset_code: 'retail' 19 | target_behavior: 'buy' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 99037 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: random 27 | val_negative_sample_size: 99 28 | train_batch_size: 128 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/retail/wo-brpb 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 10 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | accumulate_grad_batches: 1 43 | seed_everything: 42 44 | 45 | optimizer: 46 | class_path: torch.optim.Adam 47 | init_args: 48 | lr: 0.001 49 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/retail/retail_bert_one.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 1 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 99037 12 | battn: False 13 | bpff: False 14 | brpb: False 15 | b_head: False 16 | 17 | data: 18 | dataset_code: 'retail' 19 | target_behavior: 'buy' 20 | multi_behavior: False 21 | min_uc: 3 22 | num_items: 99037 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | val_negative_sampler_code: 
random 27 | val_negative_sample_size: 99 28 | train_batch_size: 128 29 | val_batch_size: 128 30 | predict_only_target: False 31 | 32 | trainer: 33 | default_root_dir: logs/retail/bert-one 34 | callbacks: 35 | - class_path: pytorch_lightning.callbacks.EarlyStopping 36 | init_args: 37 | monitor: 'Val:NDCG@10' 38 | patience: 5 39 | mode: max 40 | gpus: '1' 41 | accelerator: ddp 42 | accumulate_grad_batches: 1 43 | seed_everything: 42 44 | 45 | optimizer: 46 | class_path: torch.optim.Adam 47 | init_args: 48 | lr: 0.001 49 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/yelp/yelp_full_neg.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 32 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.1 10 | n_layer: 2 11 | num_items: 22734 12 | battn: True 13 | bpff: True 14 | brpb: True 15 | b_head: True 16 | 17 | data: 18 | dataset_code: 'yelp' 19 | target_behavior: 'pos' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 22734 23 | max_len: 50 24 | mask_prob: 0.3 25 | num_workers: 4 26 | train_negative_sampler_code: random_train 27 | train_negative_sample_size: 100 28 | val_negative_sampler_code: random 29 | val_negative_sample_size: 99 30 | train_batch_size: 128 31 | val_batch_size: 128 32 | predict_only_target: False 33 | 34 | trainer: 35 | default_root_dir: logs/yelp/full-neg 36 | callbacks: 37 | - class_path: pytorch_lightning.callbacks.EarlyStopping 38 | init_args: 39 | monitor: 'Val:NDCG@10' 40 | patience: 50 41 | mode: max 42 | gpus: '1' 43 | accelerator: ddp 44 | seed_everything: 42 45 | 46 | optimizer: 47 | class_path: torch.optim.Adam 48 | init_args: 49 | lr: 0.001 50 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/ijcai/ijcai_full_neg.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 874306 12 | battn: True 13 | bpff: True 14 | brpb: True 15 | b_head: True 16 | 17 | data: 18 | dataset_code: 'ijcai' 19 | target_behavior: 'buy' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 874306 23 | max_len: 50 24 | mask_prob: 0.3 25 | num_workers: 4 26 | train_negative_sampler_code: random 27 | train_negative_sample_size: 1 28 | val_negative_sampler_code: random 29 | val_negative_sample_size: 99 30 | train_batch_size: 128 31 | val_batch_size: 128 32 | predict_only_target: False 33 | 34 | trainer: 35 | default_root_dir: logs/ijcai/full-neg-1 36 | callbacks: 37 | - class_path: pytorch_lightning.callbacks.EarlyStopping 38 | init_args: 39 | monitor: 'Val:NDCG@10' 40 | patience: 10 41 | mode: max 42 | gpus: '1' 43 | accelerator: ddp 44 | accumulate_grad_batches: 1 45 | seed_everything: 42 46 | 47 | optimizer: 48 | class_path: torch.optim.Adam 49 | init_args: 50 | lr: 0.001 51 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/configs/retail/retail_full_neg.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone: 3 | class_path: src.model.BERT 4 | init_args: 5 | max_len: 50 6 | d_model: 16 7 | n_head: 2 8 | n_b: 4 9 | dropout: 0.2 10 | n_layer: 2 11 | num_items: 99037 12 | battn: True 13 | bpff: True 14 | 
brpb: True 15 | b_head: True 16 | 17 | data: 18 | dataset_code: 'retail' 19 | target_behavior: 'buy' 20 | multi_behavior: True 21 | min_uc: 3 22 | num_items: 99037 23 | max_len: 50 24 | mask_prob: 0.2 25 | num_workers: 4 26 | train_negative_sampler_code: random 27 | train_negative_sample_size: 100 28 | val_negative_sampler_code: random 29 | val_negative_sample_size: 99 30 | train_batch_size: 128 31 | val_batch_size: 128 32 | predict_only_target: False 33 | 34 | trainer: 35 | default_root_dir: logs/retail/full-neg 36 | callbacks: 37 | - class_path: pytorch_lightning.callbacks.EarlyStopping 38 | init_args: 39 | monitor: 'Val:NDCG@10' 40 | patience: 10 41 | mode: max 42 | gpus: '1' 43 | accelerator: ddp 44 | accumulate_grad_batches: 1 45 | seed_everything: 42 46 | 47 | optimizer: 48 | class_path: torch.optim.Adam 49 | init_args: 50 | lr: 0.001 51 | weight_decay: 0.000001 -------------------------------------------------------------------------------- /src/models/sasrec.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch import nn as nn 4 | import pytorch_lightning as pl 5 | from .embedding import BERTEmbedding 6 | from .transformer import TransformerBlock 7 | 8 | 9 | 10 | class SAS(pl.LightningModule): 11 | def __init__(self, 12 | max_len, 13 | num_items, 14 | n_layer, 15 | n_head, 16 | d_model, 17 | dropout 18 | ): 19 | super().__init__() 20 | self.d_model = d_model 21 | self.num_items = num_items 22 | 23 | vocab_size = num_items + 1 24 | self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=d_model, max_len=max_len, dropout=dropout) 25 | # multi-layer transformer blocks (deep network) 26 | self.transformer_blocks = nn.ModuleList( 27 | [TransformerBlock(d_model, n_head, d_model * 4, dropout) for _ in range(n_layer)]) 28 | 29 | def forward(self, x): 30 | mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1) 31 | tl = x.shape[1] # time-dimension length, used to enforce causality 32 | mask *= torch.tril(torch.ones((tl, tl), dtype=torch.bool, device=self.device)) 33 | # embed the indexed sequence into a sequence of vectors 34 | x = self.embedding(x) 35 | # running over multiple transformer blocks 36 | for transformer in self.transformer_blocks: 37 | x = transformer.forward(x, mask) 38 | return x -------------------------------------------------------------------------------- /src/dataloaders/negative_samplers/base.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from abc import * 4 | from pathlib import Path 5 | import pickle 6 | 7 | 8 | class AbstractNegativeSampler(metaclass=ABCMeta): 9 | def __init__(self, train, val, user_count, item_count, sample_size, save_folder): 10 | self.train = train 11 | self.val = val 12 | self.user_count = user_count 13 | self.item_count = item_count 14 | self.sample_size = sample_size 15 | self.save_folder = save_folder 16 | 17 | @classmethod 18 | @abstractmethod 19 | def code(cls): 20 | pass 21 | 22 | @abstractmethod 23 | def generate_negative_samples(self): 24 | pass 25 | 26 | def get_negative_samples(self): 27 | savefile_path = self._get_save_path() 28 | if savefile_path.is_file(): 29 | print('Negative samples exist. Loading.') 30 | negative_samples = pickle.load(savefile_path.open('rb')) 31 | return negative_samples 32 | print("Negative samples don't exist. 
Generating.") 33 | negative_samples = self.generate_negative_samples() 34 | with savefile_path.open('wb') as f: 35 | pickle.dump(negative_samples, f) 36 | return negative_samples 37 | 38 | def _get_save_path(self): 39 | folder = Path(self.save_folder) 40 | filename = '{}-sample_size{}.pkl'.format(self.code(), self.sample_size) 41 | return folder.joinpath(filename) -------------------------------------------------------------------------------- /src/dataloaders/base.py: -------------------------------------------------------------------------------- 1 | 2 | from .negative_samplers import negative_sampler_factory 3 | 4 | from abc import * 5 | 6 | 7 | class AbstractDataloader(metaclass=ABCMeta): 8 | def __init__(self, 9 | dataset, 10 | val_negative_sampler_code, 11 | val_negative_sample_size 12 | ): 13 | save_folder = dataset._get_preprocessed_folder_path() 14 | dataset = dataset.load_dataset() 15 | self.train = dataset['train'] 16 | self.val = dataset['val'] 17 | self.train_b = dataset['train_b'] 18 | self.val_b = dataset['val_b'] 19 | self.val_num = dataset['val_num'] 20 | self.umap = dataset['umap'] 21 | self.smap = dataset['smap'] 22 | self.bmap = dataset['bmap'] 23 | self.user_count = len(self.umap) 24 | self.item_count = len(self.smap) 25 | self.behavior_count = len(self.bmap) 26 | 27 | val_negative_sampler = negative_sampler_factory(val_negative_sampler_code, self.train, self.val, 28 | self.user_count, self.item_count, 29 | val_negative_sample_size, 30 | save_folder) 31 | self.val_negative_samples = val_negative_sampler.get_negative_samples() 32 | 33 | 34 | @abstractmethod 35 | def get_train_loader(self): 36 | pass 37 | 38 | @abstractmethod 39 | def get_val_loader(self): 40 | pass -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | def recall(scores, labels, k): 5 | scores = scores.cpu() 6 | labels = labels.cpu() 7 | rank = (-scores).argsort(dim=1) 8 | cut = rank[:, :k] 9 | hit = labels.gather(1, cut) 10 | return (hit.sum(1).float() / labels.sum(1).float()).mean().item() 11 | 12 | 13 | def ndcg(scores, labels, k): 14 | scores = scores.cpu() 15 | labels = labels.cpu() 16 | rank = (-scores).argsort(dim=1) 17 | cut = rank[:, :k] 18 | hits = labels.gather(1, cut) 19 | position = torch.arange(2, 2+k) 20 | weights = 1 / torch.log2(position.float()) 21 | dcg = (hits.float() * weights).sum(1) 22 | idcg = torch.Tensor([weights[:min(n, k)].sum() for n in labels.sum(1)]) 23 | ndcg = dcg / idcg 24 | return ndcg.mean() 25 | 26 | 27 | def recalls_and_ndcgs_for_ks(scores, labels, ks): 28 | metrics = {} 29 | 30 | scores = scores.cpu() 31 | labels = labels.cpu() 32 | answer_count = labels.sum(1) 33 | answer_count_float = answer_count.float() 34 | labels_float = labels.float() 35 | rank = (-scores).argsort(dim=1) 36 | cut = rank 37 | for k in sorted(ks, reverse=True): 38 | cut = cut[:, :k] 39 | hits = labels_float.gather(1, cut) 40 | metrics['Recall@%d' % k] = (hits.sum(1) / answer_count_float).mean().item() 41 | 42 | position = torch.arange(2, 2+k) 43 | weights = 1 / torch.log2(position.float()) 44 | dcg = (hits * weights).sum(1) 45 | idcg = torch.Tensor([weights[:min(n, k)].sum() for n in answer_count]) 46 | ndcg = (dcg / idcg).mean() 47 | metrics['NDCG@%d' % k] = ndcg 48 | 49 | return metrics 50 | 51 | def split_at_index(dim, index, t): 52 | pre_slices = (slice(None),) * dim 53 | l = (*pre_slices, slice(None, index)) 54 | r = (*pre_slices, 
slice(index, None)) 55 | return t[l], t[r] -------------------------------------------------------------------------------- /src/dataloaders/negative_samplers/random.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from .base import AbstractNegativeSampler 4 | 5 | from tqdm import trange 6 | 7 | import numpy as np 8 | 9 | 10 | class RandomNegativeSampler(AbstractNegativeSampler): 11 | @classmethod 12 | def code(cls): 13 | return 'random' 14 | 15 | def generate_negative_samples(self): 16 | negative_samples = {} 17 | print('Sampling negative items') 18 | for user in trange(1, self.user_count+1): 19 | if user not in self.val.keys(): 20 | continue 21 | seen = set(self.train[user]) 22 | seen.update(self.val[user]) 23 | 24 | samples = [] 25 | for _ in range(self.sample_size): 26 | item = np.random.choice(self.item_count) + 1 27 | while item in seen or item in samples: 28 | item = np.random.choice(self.item_count) + 1 29 | samples.append(item) 30 | 31 | negative_samples[user] = samples 32 | 33 | return negative_samples 34 | 35 | # class RandomNegativeSamplerTrain(AbstractNegativeSampler): 36 | # @classmethod 37 | # def code(cls): 38 | # return 'random_train' 39 | 40 | # def generate_negative_samples(self): 41 | # negative_samples = {} 42 | # print('Sampling negative items') 43 | # for user in trange(1, self.user_count+1): 44 | # seen = set(self.train[user]) 45 | # if user in self.val.keys(): 46 | # seen.update(self.val[user]) 47 | # samples = [] 48 | # for _ in range(self.sample_size): 49 | # item = np.random.choice(self.item_count) + 1 50 | # while item in seen or item in samples: 51 | # item = np.random.choice(self.item_count) + 1 52 | # samples.append(item) 53 | 54 | # negative_samples[user] = samples 55 | 56 | # return negative_samples -------------------------------------------------------------------------------- /src/models/bert4rec.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch import nn as nn 4 | import pytorch_lightning as pl 5 | from .embedding import BERTEmbedding, SimpleEmbedding 6 | from .new_transformer import TransformerBlock 7 | 8 | 9 | class BERT(pl.LightningModule): 10 | def __init__(self, 11 | max_len: int = None, 12 | num_items: int = None, 13 | n_layer: int = None, 14 | n_head: int = None, 15 | n_b: int = None, 16 | d_model: int = None, 17 | dropout: float = .0, 18 | battn: bool = None, 19 | bpff: bool = None, 20 | brpb: bool = None, 21 | ): 22 | super().__init__() 23 | self.d_model = d_model 24 | self.num_items = num_items 25 | self.n_b = n_b 26 | self.battn = battn 27 | self.bpff = bpff 28 | self.brpb = brpb 29 | 30 | vocab_size = num_items + 1 + n_b # add padding and mask tokens 31 | # NOTE: originally gated on self.brpb; the branch is disabled, so SimpleEmbedding is always used 32 | if True: 33 | # simple embedding; behavioral relative positional bias is added inside the transformer blocks 34 | self.embedding = SimpleEmbedding(vocab_size=vocab_size, embed_size=d_model, dropout=dropout) 35 | else: 36 | # embedding for BERT, sum of positional and token embeddings 37 | self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=d_model, max_len=max_len, dropout=dropout) 38 | # multi-layer transformer blocks 39 | self.transformer_blocks = nn.ModuleList( 40 | [TransformerBlock(d_model, n_head, d_model * 4, n_b, battn, bpff, brpb, dropout) for _ in range(n_layer)]) 41 | 42 | def forward(self, x, b_seq): 43 | # get padding masks 44 | mask = (x > 0) 45 | # embed the indexed sequence into a sequence of vectors 46 | x = self.embedding(x) 47 | # running 
over multiple transformer blocks 48 | for transformer in self.transformer_blocks: 49 | x = transformer.forward(x, b_seq, mask) 50 | return x -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 8 |
9 | 10 | # Multi-Behavior Sequential Transformer Recommender 11 | 12 |
13 | 14 | ![](mb-str.jpg) 15 | 16 |
17 | 18 |
19 | The code has been tested on an NVIDIA 1080 Ti. 20 | 21 | ## Quick Start 22 | 0. Install PyTorch and the other dependencies: 23 | ```bash 24 | pip install -r requirements.txt 25 | ``` 26 | 1. Download the datasets from `https://cloud.tsinghua.edu.cn/d/dc03b3300d4d483d817d/` and put them into the `data/` folder. 27 | 2. Run the model with a `yaml` configuration file, for example: 28 | ```bash 29 | python run.py fit --config src/configs/yelp/yelp_full.yaml 30 | ``` 31 | 32 | ## Cite us 33 | 34 | ``` 35 | @inproceedings{DBLP:conf/sigir/YuanG0GLT22, 36 | author = {Enming Yuan and 37 | Wei Guo and 38 | Zhicheng He and 39 | Huifeng Guo and 40 | Chengkai Liu and 41 | Ruiming Tang}, 42 | editor = {Enrique Amig{\'{o}} and 43 | Pablo Castells and 44 | Julio Gonzalo and 45 | Ben Carterette and 46 | J. Shane Culpepper and 47 | Gabriella Kazai}, 48 | title = {Multi-Behavior Sequential Transformer Recommender}, 49 | booktitle = {{SIGIR} '22: The 45th International {ACM} {SIGIR} Conference on Research 50 | and Development in Information Retrieval, Madrid, Spain, July 11 - 51 | 15, 2022}, 52 | pages = {1642--1652}, 53 | publisher = {{ACM}}, 54 | year = {2022}, 55 | url = {https://doi.org/10.1145/3477495.3532023}, 56 | doi = {10.1145/3477495.3532023}, 57 | timestamp = {Sat, 09 Jul 2022 09:25:34 +0200}, 58 | biburl = {https://dblp.org/rec/conf/sigir/YuanG0GLT22.bib}, 59 | bibsource = {dblp computer science bibliography, https://dblp.org} 60 | } 61 | ``` 62 | -------------------------------------------------------------------------------- /src/models/relative_position.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn import functional as F 7 | 8 | 9 | class RelativePositionBias(nn.Module): 10 | def __init__(self, num_buckets=32, max_distance=128, n_heads=2): 11 | super(RelativePositionBias, self).__init__() 12 | self.num_buckets = num_buckets 13 | self.max_distance = max_distance 14 | self.relative_attention_bias = nn.Embedding(self.num_buckets, n_heads) 15 | 16 | @staticmethod 17 | def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128): 18 | ret = 0 19 | n = -relative_position 20 | num_buckets //= 2 21 | ret += (n < 0).to(torch.long) * num_buckets # mtf.to_int32(mtf.less(n, 0)) * num_buckets 22 | n = torch.abs(n) 23 | 24 | # now n is in the range [0, inf) 25 | 26 | # half of the buckets are for exact increments in positions 27 | max_exact = num_buckets // 2 28 | is_small = n < max_exact 29 | 30 | # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance 31 | val_if_large = max_exact + ( 32 | torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) 33 | ).long() 34 | val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) 35 | 36 | ret += torch.where(is_small, n, val_if_large) 37 | return ret 38 | 39 | def forward(self, qlen, klen): 40 | """ Compute binned relative position bias """ 41 | device = self.relative_attention_bias.weight.device 42 | q_pos = torch.arange(qlen, dtype = torch.long, device = device) 43 | k_pos = torch.arange(klen, dtype = torch.long, device = device) 44 | relative_position = k_pos[None, :] - q_pos[:, None] 45 | """ 46 | k 47 | 0 1 2 3 48 | q -1 0 1 2 49 | -2 -1 0 1 50 | -3 -2 -1 0 51 | """ 52 | rp_bucket = self._relative_position_bucket( 53 | relative_position, # shape (qlen, klen) 54 | num_buckets=self.num_buckets, 55 |
max_distance=self.max_distance, 56 | ) 57 | rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device) 58 | values = self.relative_attention_bias(rp_bucket) # shape (qlen, klen, num_heads) 59 | values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, qlen, klen) 60 | return values -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import pytorch_lightning as pl 4 | from .models import CGCDotProductPredictionHead, DotProductPredictionHead 5 | from .models.bert4rec import BERT 6 | from .utils import recalls_and_ndcgs_for_ks 7 | 8 | 9 | class RecModel(pl.LightningModule): 10 | def __init__(self, 11 | backbone: BERT, 12 | b_head: bool = False, 13 | ): 14 | super().__init__() 15 | self.backbone = backbone 16 | self.n_b = backbone.n_b 17 | if b_head: 18 | self.head = CGCDotProductPredictionHead(backbone.d_model, self.n_b, 3, 1, backbone.num_items, self.backbone.embedding.token) 19 | else: 20 | self.head = DotProductPredictionHead(backbone.d_model, backbone.num_items, self.backbone.embedding.token) 21 | self.loss = torch.nn.CrossEntropyLoss(ignore_index=0) 22 | 23 | def forward(self, input_ids, b_seq): 24 | return self.backbone(input_ids, b_seq) 25 | 26 | 27 | def training_step(self, batch, batch_idx): 28 | input_ids = batch['input_ids'] 29 | b_seq = batch['behaviors'] 30 | outputs = self(input_ids, b_seq) 31 | outputs = outputs.view(-1, outputs.size(-1)) # BT x H 32 | labels = batch['labels'] 33 | labels = labels.view(-1) # BT 34 | 35 | valid = labels>0 36 | valid_index = valid.nonzero().squeeze() # M 37 | valid_outputs = outputs[valid_index] 38 | valid_b_seq = b_seq.view(-1)[valid_index] # M 39 | valid_labels = labels[valid_index] 40 | valid_logits = self.head(valid_outputs, valid_b_seq) # M 41 | 42 | loss = self.loss(valid_logits, valid_labels) 43 | loss = loss.unsqueeze(0) 44 | return {'loss':loss} 45 | 46 | 47 | def training_epoch_end(self, training_step_outputs): 48 | loss = torch.cat([o['loss'] for o in training_step_outputs], 0).mean() 49 | self.log('train_loss', loss) 50 | 51 | def validation_step(self, batch, batch_idx): 52 | input_ids = batch['input_ids'] 53 | b_seq = batch['behaviors'] 54 | outputs = self(input_ids, b_seq) 55 | 56 | # get scores (B x C) for evaluation 57 | last_outputs = outputs[:, -1, :] 58 | last_b_seq = b_seq[:,-1] 59 | candidates = batch['candidates'].squeeze() # B x C 60 | logits = self.head(last_outputs, last_b_seq, candidates) 61 | labels = batch['labels'].squeeze() 62 | metrics = recalls_and_ndcgs_for_ks(logits, labels, [1, 5, 10, 20, 50]) 63 | 64 | return metrics 65 | 66 | def validation_epoch_end(self, validation_step_outputs): 67 | keys = validation_step_outputs[0].keys() 68 | for k in keys: 69 | tmp = [] 70 | for o in validation_step_outputs: 71 | tmp.append(o[k]) 72 | self.log(f'Val:{k}', torch.Tensor(tmp).mean()) -------------------------------------------------------------------------------- /src/models/embedding.py: -------------------------------------------------------------------------------- 1 | 2 | from torch import nn as nn 3 | 4 | class PositionalEmbedding(nn.Module): 5 | 6 | def __init__(self, max_len, d_model): 7 | super().__init__() 8 | self.pe = nn.Embedding(max_len, d_model) 9 | self.apply(self._init_weights) 10 | 11 | def forward(self, x): 12 | batch_size = x.size(0) 13 | return self.pe.weight.unsqueeze(0).repeat(batch_size, 1, 1) 14 | 15 | def _init_weights(self, 
module):
16 |         """Initialize the weights."""
17 |         if isinstance(module, nn.Embedding):
18 |             module.weight.data.normal_(mean=0.0, std=0.02)
19 |             if module.padding_idx is not None:
20 |                 module.weight.data[module.padding_idx].zero_()
21 | 
22 | class BERTEmbedding(nn.Module):
23 |     """
24 |     BERT embedding, composed of the following features:
25 |         1. TokenEmbedding : a standard embedding matrix
26 |         2. PositionalEmbedding : a learned positional embedding
27 |     The sum of these features is the output of BERTEmbedding.
28 |     """
29 |     def __init__(self, vocab_size, embed_size, max_len, dropout=0.1):
30 |         """
31 |         :param vocab_size: total vocab size
32 |         :param embed_size: embedding size of token embedding
33 |         :param dropout: dropout rate
34 |         """
35 |         super().__init__()
36 |         self.token = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size, padding_idx=0)
37 |         self.position = PositionalEmbedding(max_len=max_len, d_model=embed_size)
38 |         self.dropout = nn.Dropout(p=dropout)
39 |         self.embed_size = embed_size
40 |         self.apply(self._init_weights)
41 | 
42 |     def forward(self, sequence):
43 |         x = self.token(sequence) + self.position(sequence)
44 |         return self.dropout(x)
45 | 
46 |     def _init_weights(self, module):
47 |         """Initialize the weights."""
48 |         if isinstance(module, nn.Embedding):
49 |             module.weight.data.normal_(mean=0.0, std=0.02)
50 |             if module.padding_idx is not None:
51 |                 module.weight.data[module.padding_idx].zero_()
52 | 
53 | 
54 | class SimpleEmbedding(nn.Module):
55 |     """
56 |     Simplified embedding with a single feature:
57 |         1. TokenEmbedding : a standard embedding matrix
58 |     """
59 |     def __init__(self, vocab_size, embed_size, dropout=0.1):
60 |         """
61 |         :param vocab_size: total vocab size
62 |         :param embed_size: embedding size of token embedding
63 |         :param dropout: dropout rate
64 |         """
65 |         super().__init__()
66 |         self.token = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size, padding_idx=0)
67 |         self.dropout = nn.Dropout(p=dropout)
68 |         self.embed_size = embed_size
69 |         self.apply(self._init_weights)
70 | 
71 |     def forward(self, sequence):
72 |         x = self.token(sequence)
73 |         return self.dropout(x)
74 | 
75 |     def _init_weights(self, module):
76 |         """Initialize the weights."""
77 |         if isinstance(module, nn.Embedding):
78 |             module.weight.data.normal_(mean=0.0, std=0.02)
79 |             if module.padding_idx is not None:
80 |                 module.weight.data[module.padding_idx].zero_()
--------------------------------------------------------------------------------
/src/models/utils.py:
--------------------------------------------------------------------------------
1 | 
2 | from torch import nn as nn
3 | import torch
4 | import math
5 | import torch.nn.functional as F
6 | 
7 | class GradMultiply(torch.autograd.Function):
8 |     @staticmethod
9 |     def forward(ctx, x, scale):
10 |         ctx.scale = scale
11 |         res = x.new(x)
12 |         return res
13 | 
14 |     @staticmethod
15 |     def backward(ctx, grad):
16 |         return grad * ctx.scale, None
17 | 
18 | class PositionwiseFeedForward(nn.Module):
19 |     "Implements FFN equation."
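    # The FFN equation referenced above, from "Attention Is All You Need":
    #     FFN(x) = W_2 * Dropout(ReLU(W_1 * x))
    # Minimal shape sketch with illustrative sizes (not from any shipped config):
    #     ff = PositionwiseFeedForward(d_model=64, d_ff=256)
    #     y = ff(torch.randn(8, 50, 64))  # -> (8, 50, 64); the model dim is preserved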
20 |     def __init__(self, d_model, d_ff, dropout=0.1):
21 |         super(PositionwiseFeedForward, self).__init__()
22 |         self.w_1 = nn.Linear(d_model, d_ff)
23 |         self.w_2 = nn.Linear(d_ff, d_model)
24 |         self.dropout = nn.Dropout(dropout)
25 |         self.activation = nn.ReLU()
26 |         self.apply(self._init_weights)
27 | 
28 |     def forward(self, x):
29 |         return self.w_2(self.dropout(self.activation(self.w_1(x))))
30 | 
31 |     def _init_weights(self, module):
32 |         """Initialize the weights."""
33 |         if isinstance(module, nn.Linear):
34 |             module.weight.data.normal_(mean=0.0, std=0.02)
35 |             if module.bias is not None:
36 |                 module.bias.data.zero_()
37 | 
38 | class SublayerConnection(nn.Module):
39 |     """
40 |     Sublayer connection: residual add with dropout, followed by layer norm.
41 |     """
42 |     def __init__(self, size, dropout=0):
43 |         super(SublayerConnection, self).__init__()
44 |         self.norm = nn.LayerNorm(size)
45 |         self.dropout = nn.Dropout(dropout)
46 | 
47 |     def forward(self, x, sublayer):
48 |         "Apply residual connection to any sublayer with the same size."
49 |         return self.norm(x + self.dropout(sublayer(x)))
50 | 
51 | class BehaviorSpecificPFF(nn.Module):
52 |     """
53 |     Behavior-specific pointwise feed-forward network.
54 |     """
55 |     def __init__(self, d_model, d_ff, n_b, bpff=False, dropout=0.1):
56 |         super().__init__()
57 |         self.n_b = n_b
58 |         self.bpff = bpff
59 |         if bpff and n_b > 1:
60 |             self.pff = nn.ModuleList([PositionwiseFeedForward(d_model=d_model, d_ff=d_ff, dropout=dropout) for i in range(n_b)])
61 |         else:
62 |             self.pff = PositionwiseFeedForward(d_model=d_model, d_ff=d_ff, dropout=dropout)
63 | 
64 |     def multi_behavior_pff(self, x, b_seq):
65 |         """
66 |         x: B x T x H
67 |         b_seq: B x T, 0 means padding.
68 |         """
69 |         outputs = [torch.zeros_like(x)]
70 |         for i in range(self.n_b):
71 |             outputs.append(self.pff[i](x))
72 |         return torch.einsum('nBTh, BTn -> BTh', torch.stack(outputs, dim=0), F.one_hot(b_seq, num_classes=self.n_b+1).float())
73 | 
74 |     def forward(self, x, b_seq=None):
75 |         if self.bpff and self.n_b > 1:
76 |             return self.multi_behavior_pff(x, b_seq)
77 |         else:
78 |             return self.pff(x)
79 | 
80 | class MMoE(nn.Module):
81 |     def __init__(self, d_model, d_ff, n_b, n_e=1, bmmoe=False, dropout=0.1):
82 |         super(MMoE, self).__init__()
83 |         self.n_b = n_b
84 |         self.n_e = n_e
85 |         self.bmmoe = bmmoe
86 |         if self.bmmoe and n_e > 1:
87 |             self.softmax = nn.Softmax(dim=-1)
88 |             self.experts = nn.ModuleList([PositionwiseFeedForward(d_model, d_ff, dropout) for i in range(self.n_e)])
89 |             self.w_gates = nn.Parameter(torch.randn(self.n_b, d_model, self.n_e), requires_grad=True)
90 |         else:
91 |             self.pff = PositionwiseFeedForward(d_model=d_model, d_ff=d_ff, dropout=dropout)
92 | 
93 |     def forward(self, x, b_seq):
94 |         if self.bmmoe and self.n_e > 1:
95 |             experts_o = [e(x) for e in self.experts]
96 |             experts_o_tensor = torch.stack(experts_o)
97 |             gates_o = self.softmax(torch.einsum('bnd,tde->tbne', x, self.w_gates))
98 |             output = torch.einsum('ebnd,tbne->tbnd', experts_o_tensor, gates_o)
99 |             outputs = torch.cat([torch.zeros_like(x).unsqueeze(0), output])
100 |             return torch.einsum('tbnd, bnt -> bnd', outputs, F.one_hot(b_seq, num_classes=self.n_b+1).float())
101 |         else:
102 |             return self.pff(x)
--------------------------------------------------------------------------------
/src/models/heads.py:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | 
7 | # head used for bert4rec
8 | class DotProductPredictionHead(nn.Module):
| """share embedding parameters""" 10 | def __init__(self, d_model, num_items, token_embeddings): 11 | super().__init__() 12 | self.token_embeddings = token_embeddings 13 | self.vocab_size = num_items + 1 14 | self.out = nn.Sequential( 15 | nn.Linear(d_model, d_model), 16 | nn.ReLU(), 17 | ) 18 | self.bias = nn.Parameter(torch.zeros(1, self.vocab_size)) 19 | 20 | def forward(self, x, b_seq, candidates=None): 21 | x = self.out(x) # B x H or M x H 22 | if candidates is not None: # x : B x H 23 | emb = self.token_embeddings(candidates) # B x C x H 24 | logits = (x.unsqueeze(1) * emb).sum(-1) # B x C 25 | bias = self.bias.expand(logits.size(0), -1).gather(1, candidates) # B x C 26 | logits += bias 27 | else: # x : M x H 28 | emb = self.token_embeddings.weight[:self.vocab_size] # V x H 29 | logits = torch.matmul(x, emb.transpose(0, 1)) # M x V 30 | logits += self.bias 31 | return logits 32 | 33 | 34 | class CGCDotProductPredictionHead(nn.Module): 35 | """ 36 | model with shared expert and behavior specific expert 37 | 3 shared expert, 38 | 1 specific expert per behavior. 39 | """ 40 | def __init__(self, d_model, n_b, n_e_sh, n_e_sp, num_items, token_embeddings): 41 | super().__init__() 42 | self.n_b = n_b 43 | self.n_e_sh = n_e_sh 44 | self.n_e_sp = n_e_sp 45 | self.vocab_size = num_items + 1 46 | self.softmax = nn.Softmax(dim=-1) 47 | self.shared_experts = nn.ModuleList([nn.Sequential(nn.Linear(d_model, d_model)) for i in range(self.n_e_sh)]) 48 | self.specific_experts = nn.ModuleList([nn.Sequential(nn.Linear(d_model, d_model)) for i in range(self.n_b * self.n_e_sp)]) 49 | # self.shared_experts = nn.ModuleList([nn.Sequential(nn.Linear(d_model, d_model), nn.ReLU(),nn.Linear(d_model, d_model)) for i in range(self.n_e_sh)]) 50 | # self.specific_experts = nn.ModuleList([nn.Sequential(nn.Linear(d_model, d_model), nn.ReLU(),nn.Linear(d_model, d_model)) for i in range(self.n_b * self.n_e_sp)]) 51 | self.w_gates = nn.Parameter(torch.randn(self.n_b, d_model, self.n_e_sh + self.n_e_sp), requires_grad=True) 52 | self.token_embeddings = token_embeddings 53 | self.ln = nn.LayerNorm(d_model) 54 | 55 | def forward(self, x, b_seq, candidates=None): 56 | x = self.mmoe_process(x, b_seq) 57 | if candidates is not None: # x : B x H 58 | emb = self.token_embeddings(candidates) # B x C x H 59 | logits = (x.unsqueeze(1) * emb).sum(-1) # B x C 60 | else: # x : M x H 61 | emb = self.token_embeddings.weight[:self.vocab_size] # V x H 62 | logits = torch.matmul(x, emb.transpose(0, 1)) # M x V 63 | return logits 64 | 65 | def mmoe_process(self, x, b_seq): 66 | shared_experts_o = [e(x) for e in self.shared_experts] 67 | specific_experts_o = [e(x) for e in self.specific_experts] 68 | gates_o = self.softmax(torch.einsum('nd,tde->tne', x, self.w_gates)) 69 | # rearange 70 | experts_o_tensor = torch.stack([torch.stack(shared_experts_o+specific_experts_o[i*self.n_e_sp:(i+1)*self.n_e_sp]) for i in range(self.n_b)]) 71 | # torch.stack([torch.stack(shared_experts_o+specific_experts_o[i*2: (i+1)*2]) for i in range(4)]) 72 | output = torch.einsum('tend,tne->tnd', experts_o_tensor, gates_o) 73 | outputs = torch.cat([torch.zeros_like(x).unsqueeze(0), output]) 74 | return x + self.ln(torch.einsum('tnd, nt -> nd', outputs, F.one_hot(b_seq, num_classes=self.n_b+1).float())) 75 | 76 | # class DotProductPredictionHead(nn.Module): 77 | # """share embedding parameters""" 78 | # def __init__(self, d_model, num_items, token_embeddings): 79 | # super().__init__() 80 | # self.token_embeddings = token_embeddings 81 | # self.vocab_size = num_items 
+ 1 82 | 83 | # def forward(self, x, b_seq, candidates=None): 84 | # if candidates is not None: # x : B x H 85 | # emb = self.token_embeddings(candidates) # B x C x H 86 | # logits = (x.unsqueeze(1) * emb).sum(-1) # B x C 87 | # else: # x : M x H 88 | # emb = self.token_embeddings.weight[:self.vocab_size] # V x H 89 | # logits = torch.matmul(x, emb.transpose(0, 1)) # M x V 90 | # return logits -------------------------------------------------------------------------------- /src/datamodule.py: -------------------------------------------------------------------------------- 1 | 2 | import pytorch_lightning as pl 3 | from typing import List, Union 4 | from .datasets import dataset_factory 5 | from .dataloaders import RecDataloader 6 | 7 | class RecDataModule(pl.LightningDataModule): 8 | def __init__( 9 | self, 10 | dataset_code: str = None, 11 | target_behavior: str = None, 12 | multi_behavior: Union[bool, List] = None, 13 | min_uc: int = None, 14 | num_items: int = None, 15 | max_len: int = None, 16 | mask_prob: float = None, 17 | num_workers: int = None, 18 | val_negative_sampler_code: str = None, 19 | val_negative_sample_size: int = None, 20 | train_batch_size: int = None, 21 | val_batch_size: int = None, 22 | predict_only_target: bool = None, 23 | ): 24 | super().__init__() 25 | self.dataset_code = dataset_code 26 | self.min_uc = min_uc 27 | self.target_behavior = target_behavior 28 | self.multi_behavior = multi_behavior 29 | self.num_items = num_items 30 | self.max_len = max_len 31 | self.mask_prob = mask_prob 32 | self.num_workers = num_workers 33 | self.val_negative_sampler_code = val_negative_sampler_code 34 | self.val_negative_sample_size = val_negative_sample_size 35 | self.train_batch_size = train_batch_size 36 | self.val_batch_size = val_batch_size 37 | self.predict_only_target = predict_only_target 38 | 39 | def prepare_data(self): 40 | # download, split, etc... 
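        # (only the on-disk side effect matters here: dataset_factory runs
        # preprocessing once and caches the pickle, which setup() below
        # reloads on every process; a hypothetical `data:` block of a
        # LightningCLI yaml wiring a subset of the arguments above could
        # look like -- illustrative values only, not a shipped config:
        #     data:
        #       dataset_code: yelp
        #       target_behavior: pos
        #       multi_behavior: true
        #       min_uc: 5
        #       max_len: 50
        #       mask_prob: 0.3
        #       train_batch_size: 128
        #       val_batch_size: 128
        # )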
41 | # only called on 1 GPU/TPU in distributed 42 | dataset_factory( 43 | self.dataset_code, 44 | self.target_behavior, 45 | self.multi_behavior, 46 | self.min_uc, 47 | ) 48 | 49 | def setup(self, stage): 50 | # make assignments here (val/train/test split) 51 | # called on every process in DDP 52 | self.dataset = dataset_factory( 53 | self.dataset_code, 54 | self.target_behavior, 55 | self.multi_behavior, 56 | self.min_uc, 57 | ) 58 | 59 | self.dataloader = RecDataloader( 60 | self.dataset, 61 | self.max_len, 62 | self.mask_prob, 63 | self.num_items, 64 | self.num_workers, 65 | self.val_negative_sampler_code, 66 | self.val_negative_sample_size, 67 | self.train_batch_size, 68 | self.val_batch_size, 69 | self.predict_only_target, 70 | ) 71 | 72 | def train_dataloader(self): 73 | return self.dataloader.get_train_loader() 74 | def val_dataloader(self): 75 | return self.dataloader.get_val_loader() 76 | 77 | class RecDataModuleNeg(pl.LightningDataModule): 78 | def __init__( 79 | self, 80 | dataset_code: str = None, 81 | target_behavior: str = None, 82 | multi_behavior: bool = None, 83 | min_uc: int = None, 84 | num_items: int = None, 85 | max_len: int = None, 86 | mask_prob: float = None, 87 | num_workers: int = None, 88 | train_negative_sampler_code: str = None, 89 | train_negative_sample_size: int = None, 90 | val_negative_sampler_code: str = None, 91 | val_negative_sample_size: int = None, 92 | train_batch_size: int = None, 93 | val_batch_size: int = None, 94 | predict_only_target: bool = None, 95 | ): 96 | super().__init__() 97 | self.dataset_code = dataset_code 98 | self.min_uc = min_uc 99 | self.target_behavior = target_behavior 100 | self.multi_behavior = multi_behavior 101 | self.num_items = num_items 102 | self.max_len = max_len 103 | self.mask_prob = mask_prob 104 | self.num_workers = num_workers 105 | self.train_negative_sampler_code = train_negative_sampler_code 106 | self.train_negative_sample_size = train_negative_sample_size 107 | self.val_negative_sampler_code = val_negative_sampler_code 108 | self.val_negative_sample_size = val_negative_sample_size 109 | self.train_batch_size = train_batch_size 110 | self.val_batch_size = val_batch_size 111 | self.predict_only_target = predict_only_target 112 | 113 | def prepare_data(self): 114 | # download, split, etc... 
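        # NOTE: setup() below builds a RecDataloaderNeg, which (unlike
        # RecDataloader) is not imported at the top of this file; it is
        # assumed to be provided by src.dataloaders when the *_neg
        # configs are used.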
115 | # only called on 1 GPU/TPU in distributed 116 | dataset_factory( 117 | self.dataset_code, 118 | self.target_behavior, 119 | self.multi_behavior, 120 | self.min_uc, 121 | ) 122 | 123 | def setup(self, stage): 124 | # make assignments here (val/train/test split) 125 | # called on every process in DDP 126 | self.dataset = dataset_factory( 127 | self.dataset_code, 128 | self.target_behavior, 129 | self.multi_behavior, 130 | self.min_uc, 131 | ) 132 | 133 | self.dataloader = RecDataloaderNeg( 134 | self.dataset, 135 | self.max_len, 136 | self.mask_prob, 137 | self.num_items, 138 | self.num_workers, 139 | self.train_negative_sampler_code, 140 | self.train_negative_sample_size, 141 | self.val_negative_sampler_code, 142 | self.val_negative_sample_size, 143 | self.train_batch_size, 144 | self.val_batch_size, 145 | self.predict_only_target, 146 | ) 147 | 148 | def train_dataloader(self): 149 | return self.dataloader.get_train_loader() 150 | def val_dataloader(self): 151 | return self.dataloader.get_val_loader() -------------------------------------------------------------------------------- /src/datasets/base.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | RAW_DATASET_ROOT_FOLDER = 'data' 4 | 5 | import pandas as pd 6 | from tqdm import tqdm 7 | tqdm.pandas() 8 | 9 | from abc import * 10 | from pathlib import Path 11 | import pickle 12 | 13 | class AbstractDataset(metaclass=ABCMeta): 14 | def __init__(self, 15 | target_behavior, 16 | multi_behavior, 17 | min_uc 18 | ): 19 | self.target_behavior = target_behavior 20 | self.multi_behavior = multi_behavior 21 | self.min_uc = min_uc 22 | self.bmap = None 23 | assert self.min_uc >= 2, 'Need at least 2 items per user for validation and test' 24 | self.split = 'leave_one_out' 25 | 26 | @classmethod 27 | @abstractmethod 28 | def code(cls): 29 | pass 30 | 31 | @classmethod 32 | def raw_code(cls): 33 | return cls.code() 34 | 35 | @abstractmethod 36 | def load_df(self): 37 | pass 38 | 39 | def load_dataset(self): 40 | self.preprocess() 41 | dataset_path = self._get_preprocessed_dataset_path() 42 | dataset = pickle.load(dataset_path.open('rb')) 43 | return dataset 44 | 45 | def preprocess(self): 46 | dataset_path = self._get_preprocessed_dataset_path() 47 | if dataset_path.is_file(): 48 | print('Already preprocessed. 
Skip preprocessing')
49 |             return
50 |         if not dataset_path.parent.is_dir():
51 |             dataset_path.parent.mkdir(parents=True)
52 |         df = self.load_df()
53 |         df = self.make_implicit(df)
54 |         df = self.filter_triplets(df)
55 |         df, umap, smap, bmap = self.densify_index(df)
56 |         self.bmap = bmap
57 |         train, train_b, val, val_b, val_num = self.split_df(df, len(umap))
58 |         dataset = {'train': train,
59 |                    'val': val,
60 |                    'train_b': train_b,
61 |                    'val_b': val_b,
62 |                    'val_num': val_num,
63 |                    'umap': umap,
64 |                    'smap': smap,
65 |                    'bmap': bmap}
66 |         with dataset_path.open('wb') as f:
67 |             pickle.dump(dataset, f)
68 | 
69 |     def make_implicit(self, df):
70 |         print('Behavior selection')
71 |         if self.multi_behavior:
72 |             pass
73 |         else:
74 |             df = df[df['behavior'] == self.target_behavior]
75 |         return df
76 | 
77 |     def filter_triplets(self, df):
78 |         print('Filtering triplets')
79 |         if self.min_uc > 0:
80 |             user_sizes = df.groupby('uid').size()
81 |             good_users = user_sizes.index[user_sizes >= self.min_uc]
82 |             df = df[df['uid'].isin(good_users)]
83 |         return df
84 | 
85 |     def densify_index(self, df):
86 |         print('Densifying index')
87 |         umap = {u: (i+1) for i, u in enumerate(set(df['uid']))}
88 |         smap = {s: (i+1) for i, s in enumerate(set(df['sid']))}
89 |         bmap = {b: (i+1) for i, b in enumerate(set(df['behavior']))}
90 |         df['uid'] = df['uid'].map(umap)
91 |         df['sid'] = df['sid'].map(smap)
92 |         df['behavior'] = df['behavior'].map(bmap)
93 |         return df, umap, smap, bmap
94 | 
95 |     # def densify_index(self, df):
96 |     #     print('Densifying index')
97 |     #     umap = {u: u for u in set(df['uid'])}
98 |     #     smap = {s: s for s in set(df['sid'])}
99 |     #     bmap = {'pv': 1, 'fav':2, 'cart':3, 'buy':4} if 'buy' in set(df['behavior']) else {'tip': 1, 'neg':2, 'neutral':3, 'pos':4}
100 |     #     df['behavior'] = df['behavior'].map(bmap)
101 |     #     return df, umap, smap, bmap
102 | 
103 |     def split_df(self, df, user_count):
104 |         if self.split == 'leave_one_out':
105 |             print('Splitting')
106 |             user_group = df.groupby('uid')
107 |             # Since the raw input is already sorted by timestamp, we do not need to sort again;
108 |             # if you use a randomly permuted df, use the following lines of code instead.
109 | # user2items = user_group.progress_apply(lambda d: list(d.sort_values(by='timestamp')['sid'])) 110 | # user2behaviors = user_group.progress_apply(lambda d: list(d.sort_values(by='timestamp')['behavior'])) 111 | user2items = user_group.progress_apply(lambda d: list(d['sid'])) 112 | user2behaviors = user_group.progress_apply(lambda d: list(d['behavior'])) 113 | train, train_b, val, val_b, = {}, {}, {}, {} 114 | for user in range(1, user_count+1): 115 | items = user2items[user] 116 | behaviors = user2behaviors[user] 117 | # only evaluate the target behavior 118 | if behaviors[-1] == self.bmap[self.target_behavior]: 119 | train[user], val[user] = items[:-1], items[-1:] 120 | train_b[user], val_b[user] = behaviors[:-1], behaviors[-1:] 121 | else: 122 | train[user] = items 123 | train_b[user] = behaviors 124 | return train, train_b, val, val_b, len(val) 125 | else: 126 | raise NotImplementedError 127 | 128 | def _get_rawdata_root_path(self): 129 | return Path(RAW_DATASET_ROOT_FOLDER) 130 | 131 | def _get_preprocessed_root_path(self): 132 | root = self._get_rawdata_root_path() 133 | return root.joinpath('preprocessed') 134 | 135 | def _get_preprocessed_folder_path(self): 136 | preprocessed_root = self._get_preprocessed_root_path() 137 | folder_name = '{}-min_uc{}-target_B{}_MB{}-split{}' \ 138 | .format(self.code(), self.min_uc, self.target_behavior, self.multi_behavior, self.split) 139 | return preprocessed_root.joinpath(folder_name) 140 | 141 | def _get_preprocessed_dataset_path(self): 142 | folder = self._get_preprocessed_folder_path() 143 | return folder.joinpath('dataset.pkl') 144 | -------------------------------------------------------------------------------- /src/dataloaders/rec_dataloader.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from .base import AbstractDataloader 4 | 5 | import torch 6 | import numpy as np 7 | import torch.utils.data as data_utils 8 | 9 | class RecDataloader(AbstractDataloader): 10 | def __init__( 11 | self, 12 | dataset, 13 | seg_len, 14 | mask_prob, 15 | num_items, 16 | num_workers, 17 | val_negative_sampler_code, 18 | val_negative_sample_size, 19 | train_batch_size, 20 | val_batch_size, 21 | predict_only_target=False, 22 | ): 23 | super().__init__(dataset, 24 | val_negative_sampler_code, 25 | val_negative_sample_size) 26 | self.target_code = self.bmap.get('buy') if self.bmap.get('buy') else self.bmap.get('pos') 27 | self.seg_len = seg_len 28 | self.mask_prob = mask_prob 29 | self.num_items = num_items 30 | self.num_workers = num_workers 31 | self.train_batch_size = train_batch_size 32 | self.val_batch_size = val_batch_size 33 | self.predict_only_target = predict_only_target 34 | 35 | def get_train_loader(self): 36 | dataset = self._get_train_dataset() 37 | dataloader = data_utils.DataLoader(dataset, batch_size=self.train_batch_size, 38 | shuffle=True, num_workers=self.num_workers) 39 | return dataloader 40 | 41 | def _get_train_dataset(self): 42 | dataset = RecTrainDataset(self.train, self.train_b, self.seg_len, self.mask_prob, self.num_items, self.target_code, self.predict_only_target) 43 | return dataset 44 | 45 | def get_val_loader(self): 46 | dataset = self._get_eval_dataset() 47 | dataloader = data_utils.DataLoader(dataset, batch_size=self.val_batch_size, 48 | shuffle=False, num_workers=self.num_workers) 49 | return dataloader 50 | 51 | def _get_eval_dataset(self): 52 | dataset = RecEvalDataset(self.train, self.train_b, self.val, self.val_b, self.val_num, self.seg_len, self.num_items, 
self.target_code, self.val_negative_samples) 53 | return dataset 54 | 55 | class RecTrainDataset(data_utils.Dataset): 56 | def __init__(self, u2seq, u2b, max_len, mask_prob, num_items, target_code, predict_only_target): 57 | self.u2seq = u2seq 58 | self.u2b = u2b 59 | self.users = sorted(self.u2seq.keys()) 60 | self.max_len = max_len 61 | self.mask_prob = mask_prob 62 | self.num_items = num_items 63 | self.target_code = target_code 64 | self.predict_only_target = predict_only_target 65 | 66 | def __len__(self): 67 | return len(self.users) 68 | 69 | def __getitem__(self, index): 70 | user = self.users[index] 71 | seq = self.u2seq[user] 72 | b_seq = self.u2b[user] 73 | 74 | tokens = [] 75 | behaviors = [] 76 | labels = [] 77 | for s,b in zip(seq, b_seq): 78 | prob = np.random.rand() 79 | if prob < self.mask_prob and not self.predict_only_target: 80 | tokens.append(self.num_items+1) 81 | labels.append(s) 82 | elif prob < self.mask_prob and self.predict_only_target and b == self.target_code: 83 | tokens.append(self.num_items+1) 84 | labels.append(s) 85 | else: 86 | tokens.append(s) 87 | labels.append(0) 88 | behaviors.append(b) 89 | 90 | if len(tokens) <= self.max_len or np.random.rand()<0.8: 91 | # if len(tokens) <= self.max_len: 92 | # if True: 93 | tokens = tokens[-self.max_len:] 94 | labels = labels[-self.max_len:] 95 | behaviors = behaviors[-self.max_len:] 96 | 97 | padding_len = self.max_len - len(tokens) 98 | 99 | tokens = [0] * padding_len + tokens 100 | labels = [0] * padding_len + labels 101 | behaviors = [0] * padding_len + behaviors 102 | else: 103 | begin_idx = np.random.randint(0, len(tokens)-self.max_len+1) 104 | tokens = tokens[begin_idx:begin_idx+self.max_len] 105 | labels = labels[begin_idx:begin_idx+self.max_len] 106 | behaviors = behaviors[begin_idx:begin_idx+self.max_len] 107 | 108 | return { 109 | 'input_ids':torch.LongTensor(tokens), 110 | 'labels':torch.LongTensor(labels), 111 | 'behaviors':torch.LongTensor(behaviors) 112 | } 113 | 114 | 115 | class RecEvalDataset(data_utils.Dataset): 116 | def __init__(self, u2seq, u2b, u2answer, u2ab, val_num, max_len, num_items, target_code, negative_samples): 117 | self.u2seq = u2seq 118 | self.u2b = u2b 119 | self.u2answer = u2answer 120 | self.users = sorted(self.u2answer.keys()) 121 | self.u2ab = u2ab 122 | self.val_num = val_num 123 | self.max_len = max_len 124 | self.negative_samples = negative_samples 125 | self.num_items = num_items 126 | self.target_code = target_code 127 | 128 | def __len__(self): 129 | return self.val_num 130 | 131 | def __getitem__(self, index): 132 | user = self.users[index] 133 | seq = self.u2seq[user] 134 | answer = self.u2answer[user] 135 | negs = self.negative_samples[user] 136 | 137 | candidates = answer + negs 138 | labels = [1] * len(answer) + [0] * len(negs) 139 | 140 | seq = seq + [self.num_items + 1] 141 | seq = seq[-self.max_len:] 142 | seq_b = self.u2b[user] + self.u2ab[user] 143 | seq_b = seq_b[-self.max_len:] 144 | padding_len = self.max_len - len(seq) 145 | seq = [0] * padding_len + seq 146 | seq_b = [0] * padding_len + seq_b 147 | 148 | return { 149 | 'input_ids':torch.LongTensor(seq), 150 | 'candidates':torch.LongTensor(candidates), 151 | 'labels':torch.LongTensor(labels), 152 | 'behaviors': torch.LongTensor(seq_b) 153 | } -------------------------------------------------------------------------------- /src/models/new_transformer.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from torch import nn as nn 4 | import torch.nn.functional as F 
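# A worked example of the behavior-pair indexing used throughout this module
# (illustrative values, assuming n_b = 4 behaviors, ids in [1, n_b], 0 = padding):
#     b_mat[:, i, j] = (b_seq[:, i] - 1) * n_b + b_seq[:, j], zeroed wherever
#     either position is padding, so a query with behavior 2 attending to a key
#     with behavior 3 gets pair index (2 - 1) * 4 + 3 = 7, one of the
#     n_b * n_b + 1 possible indices used to size alpha1, alpha2, and the
#     relative-position-bias tables below.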
5 | import torch
6 | import math
7 | from .utils import SublayerConnection, BehaviorSpecificPFF
8 | from .relative_position import RelativePositionBias
9 | 
10 | class Attention(nn.Module):
11 |     def __init__(self, dropout=0.1):
12 |         super().__init__()
13 |         self.dropout = nn.Dropout(dropout)
14 | 
15 |     def forward(self, query, key, value, b_mat=None, rpb=None, W1=None, alpha1=None, W2=None, alpha2=None, mask=None):
16 |         # 1. Calculate Q-K similarity, with / without multi-behavior dependencies.
17 |         if b_mat is not None:
18 |             W1_ = torch.einsum('Bhmn,CBh->Chmn', W1, F.softmax(alpha1, 1))
19 |             att_all = torch.einsum('bhim,Chmn,bhjn->bhijC', query, W1_, key)
20 |             h = W1.size(1)
21 |             scores = att_all.gather(4, b_mat[:,None,:,:,None].repeat(1,h,1,1,1)).squeeze(4) \
22 |                 / math.sqrt(query.size(-1)) + rpb
23 |         else:
24 |             scores = torch.matmul(query, key.transpose(-2, -1)) \
25 |                 / math.sqrt(query.size(-1)) + rpb
26 | 
27 |         # 2. Apply the padding mask and softmax.
28 |         if mask is not None:
29 |             assert len(mask.shape) == 2
30 |             mask = (mask[:,:,None] & mask[:,None,:]).unsqueeze(1)
31 |             if scores.dtype == torch.float16:
32 |                 scores = scores.masked_fill(mask == 0, -65500)
33 |             else:
34 |                 scores = scores.masked_fill(mask == 0, -1e30)
35 |         p_attn = self.dropout(nn.functional.softmax(scores, dim=-1))
36 | 
37 |         # 3. Information aggregation, with / without multi-behavior dependencies.
38 |         if b_mat is not None:
39 |             h = W2.size(1)
40 |             one_hot_b_mat = F.one_hot(b_mat[:,None,:,:], num_classes=alpha2.size(0)).repeat(1,h,1,1,1)
41 |             W2_ = torch.einsum('BhdD,CBh->ChdD', W2, F.softmax(alpha2, 1))
42 |             return torch.einsum('bhij, bhijC, ChdD, bhjd -> bhiD', p_attn, one_hot_b_mat, W2_, value)
43 |             # return torch.matmul(p_attn, value)
44 |         else:
45 |             return torch.matmul(p_attn, value)
46 | 
47 | class MultiHeadedAttention(nn.Module):
48 |     def __init__(self, h, n_b, battn, brpb, d_model, dropout=0.1):
49 |         super().__init__()
50 |         assert d_model % h == 0
51 | 
52 |         # We assume d_v always equals d_k
53 |         self.d_k = d_model // h
54 |         self.h = h
55 |         self.n_b = n_b
56 |         self.battn = battn
57 |         self.brpb = brpb
58 | 
59 |         if battn and n_b > 1:  # behavior-specific mutual attention
60 |             self.W1 = nn.Parameter(torch.randn(self.n_b, self.h, self.d_k, self.d_k))
61 |             self.alpha1 = nn.Parameter(torch.randn(self.n_b * self.n_b + 1, self.n_b, self.h))
62 |             self.W2 = nn.Parameter(torch.randn(self.n_b, self.h, self.d_k, self.d_k))
63 |             self.alpha2 = nn.Parameter(torch.randn(self.n_b * self.n_b + 1, self.n_b, self.h))
64 |             self.linear_layers = nn.Parameter(torch.randn(3, self.n_b+1, d_model, self.h, self.d_k))
65 |         else:
66 |             self.W1 = None
67 |             self.W2 = None
68 |             self.alpha1, self.alpha2 = None, None
69 |             self.linear_layers = nn.Parameter(torch.randn(3, d_model, self.h, self.d_k))
70 |         self.linear_layers.data.normal_(mean=0.0, std=0.02)
71 | 
72 |         if self.brpb:
73 |             self.rpb = nn.ModuleList([RelativePositionBias(32,40,self.h) for i in range(self.n_b * self.n_b + 1)])
74 |         self.attention = Attention(dropout)
75 |         self.dropout = nn.Dropout(dropout)
76 | 
77 |     def forward(self, query, key, value, b_seq=None, mask=None):
78 |         batch_size, seq_len = query.size(0), query.size(1)
79 |         b_mat = ((b_seq[:,:,None]-1)*self.n_b + b_seq[:,None,:]) * (b_seq[:,:,None]*b_seq[:,None,:]!=0)
80 |         # 0.
rel pos bias 81 | if self.brpb: 82 | rel_pos_bias = torch.stack([layer(seq_len, seq_len) for layer in self.rpb], -1).repeat(batch_size,1,1,1,1) 83 | rel_pos_bias = rel_pos_bias.gather(4, b_mat[:,None,:,:,None].repeat(1,self.h,1,1,1)).squeeze(4) 84 | else: 85 | rel_pos_bias = 0 86 | 87 | if self.battn and self.n_b>1: # behavior-specific mutual attention 88 | # 1) Do all the linear projections in batch from d_model => h x d_k 89 | query, key, value = [torch.einsum("bnd, Bdhk, bnB->bhnk", x, self.linear_layers[l], F.one_hot(b_seq,num_classes=self.n_b+1).float()) 90 | for l, x in zip(range(3), (query, key, value))] 91 | else: 92 | # 1) Do all the linear projections in batch from d_model => h x d_k 93 | query, key, value = [torch.einsum("bnd, dhk->bhnk", x, self.linear_layers[l]) 94 | for l, x in zip(range(3), (query, key, value))] 95 | b_mat = None 96 | 97 | # 2) Apply attention on all the projected vectors in batch. 98 | x = self.attention(query, key, value, b_mat=b_mat, rpb=rel_pos_bias, W1=self.W1, alpha1=self.alpha1, W2=self.W2, alpha2=self.alpha2, mask=mask) 99 | 100 | # 3) "Concat" using a view. 101 | x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k) 102 | 103 | return x 104 | 105 | class TransformerBlock(nn.Module): 106 | def __init__(self, hidden, attn_heads, feed_forward_hidden, n_b, battn, bpff, brpb, dropout): 107 | """ 108 | :param hidden: hidden size of transformer 109 | :param attn_heads: head sizes of multi-head attention 110 | :param feed_forward_hidden: feed_forward_hidden, usually 4*hidden_size 111 | :param dropout: dropout rate 112 | :param n_b: number of behaviors 113 | :param battn: use multi-behavior cross attention 114 | :param bpff: use behavior-specific multi-gated mixture of experts 115 | :param brpb: use behavior-specific relative position bias 116 | """ 117 | super().__init__() 118 | self.attention = MultiHeadedAttention(h=attn_heads, n_b=n_b, battn=battn, brpb=brpb, d_model=hidden, dropout=dropout) 119 | self.feed_forward = BehaviorSpecificPFF(d_model=hidden, d_ff=feed_forward_hidden, n_b=n_b, bpff=bpff, dropout=dropout) 120 | self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout) 121 | self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout) 122 | self.dropout = nn.Dropout(p=dropout) 123 | 124 | def forward(self, x, b_seq, mask): 125 | x = self.input_sublayer(x, lambda _x: self.attention(_x, _x, _x, b_seq, mask=mask)) 126 | x = self.output_sublayer(x, lambda _x: self.feed_forward(_x, b_seq)) 127 | return self.dropout(x) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | --------------------------------------------------------------------------------