├── recbole
│   ├── config
│   │   └── __init__.py
│   ├── sampler
│   │   └── __init__.py
│   ├── properties
│   │   ├── model
│   │   │   ├── DirectAU.yaml
│   │   │   └── MAWU.yaml
│   │   ├── dataset
│   │   │   ├── gowalla.yaml
│   │   │   ├── yelp.yaml
│   │   │   └── beauty.yaml
│   │   └── overall.yaml
│   ├── quick_start
│   │   ├── __init__.py
│   │   └── quick_start.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── general_recommender
│   │   │   ├── __init__.py
│   │   │   ├── directau.py
│   │   │   └── mawu.py
│   │   ├── init.py
│   │   └── loss.py
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── dataloader
│   │   │   ├── __init__.py
│   │   │   ├── user_dataloader.py
│   │   │   ├── knowledge_dataloader.py
│   │   │   ├── general_dataloader.py
│   │   │   └── abstract_dataloader.py
│   │   ├── dataset
│   │   │   ├── __init__.py
│   │   │   ├── kg_seq_dataset.py
│   │   │   ├── decisiontree_dataset.py
│   │   │   ├── customized_dataset.py
│   │   │   └── sequential_dataset.py
│   │   ├── utils.py
│   │   └── interaction.py
│   ├── trainer
│   │   ├── __init__.py
│   │   └── hyper_tuning.py
│   ├── evaluator
│   │   ├── __init__.py
│   │   ├── evaluator.py
│   │   ├── register.py
│   │   ├── utils.py
│   │   ├── base_metric.py
│   │   └── collector.py
│   └── utils
│       ├── __init__.py
│       ├── argument_list.py
│       ├── wandblogger.py
│       ├── enum_type.py
│       ├── url.py
│       ├── logger.py
│       ├── case_study.py
│       └── utils.py
├── MANIFEST.in
├── .gitignore
├── run_recbole.py
├── LICENSE
├── setup.py
├── README.md
├── mawu.yaml
└── style.cfg

/recbole/config/__init__.py:
--------------------------------------------------------------------------------
from recbole.config.configurator import Config

--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
recursive-include recbole/properties *
recursive-include recbole/dataset_example *

--------------------------------------------------------------------------------
/recbole/sampler/__init__.py:
--------------------------------------------------------------------------------
from recbole.sampler.sampler import Sampler, KGSampler, RepeatableSampler, SeqSampler

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
/log/*
/log_tensorboard/*
/dataset/*
/saved/*
/output_files/*

__pycache__/
*.py[cod]

--------------------------------------------------------------------------------
/recbole/properties/model/DirectAU.yaml:
--------------------------------------------------------------------------------
embedding_size: 64
encoder: MF
neg_sampling: ~
n_layers: 2
gamma: 0.4

--------------------------------------------------------------------------------
/recbole/quick_start/__init__.py:
--------------------------------------------------------------------------------
from recbole.quick_start.quick_start import run_recbole, objective_function, load_data_and_model

--------------------------------------------------------------------------------
/recbole/properties/model/MAWU.yaml:
--------------------------------------------------------------------------------
embedding_size: 64
encoder: MF
neg_sampling: ~
n_layers: 2
gamma1: 0.5
gamma2: 0.5

--------------------------------------------------------------------------------
/recbole/model/__init__.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

--------------------------------------------------------------------------------
/recbole/__init__.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

__version__ = '1.0.1'

--------------------------------------------------------------------------------
/recbole/data/__init__.py:
--------------------------------------------------------------------------------
from recbole.data.utils import *

__all__ = ['create_dataset', 'data_preparation', 'save_split_dataloaders', 'load_split_dataloaders']

--------------------------------------------------------------------------------
/recbole/trainer/__init__.py:
--------------------------------------------------------------------------------
from recbole.trainer.hyper_tuning import HyperTuning
from recbole.trainer.trainer import *

__all__ = ['Trainer', 'KGTrainer', 'KGATTrainer', 'S3RecTrainer']

--------------------------------------------------------------------------------
/recbole/evaluator/__init__.py:
--------------------------------------------------------------------------------
from recbole.evaluator.base_metric import *
from recbole.evaluator.metrics import *
from recbole.evaluator.evaluator import *
from recbole.evaluator.register import *
from recbole.evaluator.collector import *

--------------------------------------------------------------------------------
/recbole/data/dataloader/__init__.py:
--------------------------------------------------------------------------------
from recbole.data.dataloader.abstract_dataloader import *
from recbole.data.dataloader.general_dataloader import *
from recbole.data.dataloader.knowledge_dataloader import *
from recbole.data.dataloader.user_dataloader import *

--------------------------------------------------------------------------------
/recbole/model/general_recommender/__init__.py:
--------------------------------------------------------------------------------
from recbole.model.general_recommender.mf import MF
from recbole.model.general_recommender.lightgcn import LightGCN
from recbole.model.general_recommender.directau import DirectAU
from recbole.model.general_recommender.mawu import MAWU

--------------------------------------------------------------------------------
/recbole/data/dataset/__init__.py:
--------------------------------------------------------------------------------
from recbole.data.dataset.dataset import Dataset
from recbole.data.dataset.sequential_dataset import SequentialDataset
from recbole.data.dataset.kg_dataset import KnowledgeBasedDataset
from recbole.data.dataset.kg_seq_dataset import KGSeqDataset
from recbole.data.dataset.decisiontree_dataset import DecisionTreeDataset
from recbole.data.dataset.customized_dataset import *

--------------------------------------------------------------------------------
/recbole/data/dataset/kg_seq_dataset.py:
--------------------------------------------------------------------------------
# @Time   : 2020/9/23
# @Author : Xingyu Pan
# @Email  : panxingyu@ruc.edu.cn

"""
recbole.data.kg_seq_dataset
#############################
"""

from recbole.data.dataset import SequentialDataset, KnowledgeBasedDataset


class KGSeqDataset(SequentialDataset, KnowledgeBasedDataset):
    """Dataset containing the data processing of both sequential models and knowledge-based models.

    Inherit from :class:`~recbole.data.dataset.sequential_dataset.SequentialDataset` and
    :class:`~recbole.data.dataset.kg_dataset.KnowledgeBasedDataset`.
    """

    def __init__(self, config):
        super().__init__(config)

--------------------------------------------------------------------------------
/recbole/utils/__init__.py:
--------------------------------------------------------------------------------
from recbole.utils.logger import init_logger, set_color
from recbole.utils.utils import get_local_time, ensure_dir, get_model, get_trainer, \
    early_stopping, calculate_valid_score, dict2str, init_seed, get_tensorboard, get_gpu_usage
from recbole.utils.enum_type import *
from recbole.utils.argument_list import *
from recbole.utils.wandblogger import WandbLogger

__all__ = [
    'init_logger', 'get_local_time', 'ensure_dir', 'get_model', 'get_trainer', 'early_stopping',
    'calculate_valid_score', 'dict2str', 'Enum', 'ModelType', 'KGDataLoaderState', 'EvaluatorType', 'InputType',
    'FeatureType', 'FeatureSource', 'init_seed', 'general_arguments', 'training_arguments', 'evaluation_arguments',
    'dataset_arguments', 'get_tensorboard', 'set_color', 'get_gpu_usage', 'WandbLogger'
]

--------------------------------------------------------------------------------
/run_recbole.py:
--------------------------------------------------------------------------------
# @Time   : 2020/7/20
# @Author : Shanlei Mu
# @Email  : slmu@ruc.edu.cn

# UPDATE
# @Time   : 2020/10/3, 2020/10/1
# @Author : Yupeng Hou, Zihan Lin
# @Email  : houyupeng@ruc.edu.cn, zhlin@ruc.edu.cn


import argparse

from recbole.quick_start import run_recbole


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', '-m', type=str, default='BPR', help='name of models')
    parser.add_argument('--dataset', '-d', type=str, default='ml-100k', help='name of datasets')
    parser.add_argument('--config_files', type=str, default=None, help='config files')

    args, _ = parser.parse_known_args()

    config_file_list = args.config_files.strip().split(' ') if args.config_files else None
    run_recbole(model=args.model, dataset=args.dataset, config_file_list=config_file_list)

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 RUCAIBox

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/recbole/properties/dataset/gowalla.yaml:
--------------------------------------------------------------------------------
# Atomic File Format
field_separator: "\t"
seq_separator: " "

# Common Features
USER_ID_FIELD: user_id
ITEM_ID_FIELD: item_id
RATING_FIELD: rating
TIME_FIELD: timestamp
seq_len: ~
# Label for Point-wise DataLoader
LABEL_FIELD: label
threshold: ~
# NegSample Prefix for Pair-wise DataLoader
NEG_PREFIX: neg_
# Sequential Model Needed
ITEM_LIST_LENGTH_FIELD: item_length
LIST_SUFFIX: _list
MAX_ITEM_LIST_LENGTH: 50
POSITION_FIELD: position_id
# Knowledge-based Model Needed
HEAD_ENTITY_ID_FIELD: head_id
TAIL_ENTITY_ID_FIELD: tail_id
RELATION_ID_FIELD: relation_id
ENTITY_ID_FIELD: entity_id

# Selectively Loading
load_col:
    inter: [user_id, item_id, timestamp]
unload_col: ~
unused_col: ~

# Filtering
rm_dup_inter: last
val_interval: ~
filter_inter_by_user_or_item: True
user_inter_num_interval: '[10,inf)'
item_inter_num_interval: '[10,inf)'

# Preprocessing
alias_of_user_id: ~
alias_of_item_id: ~
alias_of_entity_id: ~
alias_of_relation_id: ~
preload_weight: ~
normalize_field: ~
normalize_all: True

--------------------------------------------------------------------------------
/recbole/properties/overall.yaml:
--------------------------------------------------------------------------------
# general
gpu_id: 0
use_gpu: True
seed: 2020
state: INFO
reproducibility: True
data_path: 'dataset/'
checkpoint_dir: 'saved'
save_dataset_dir: 'saved/_dataset'
show_progress: False
save_dataset: True
dataset_save_path: ~  # saved/gowalla-dataset.pth
save_dataloaders: True
dataloaders_save_path: ~  # 'saved/gowalla-for-BPR-dataloader.pth'
save_pos_pairs: True
log_wandb: False
wandb_project: 'cikm2023'

# training settings
epochs: 1000
train_batch_size: 2048
learner: adam
learning_rate: 0.001
neg_sampling:
    uniform: 10
in_batch: False
eval_step: 1
stopping_step: 10
clip_grad_norm: ~
# clip_grad_norm: {'max_norm': 5, 'norm_type': 2}
weight_decay: 0
loss_decimal_place: 4
require_pow: False

# evaluation settings
eval_args:
    split: {'RS': [0.7, 0.1, 0.2]}
    group_by: user
    order: RO
    mode: full
repeatable: False
# metrics: ["Recall", "MRR", "NDCG", "Hit", "Precision"]
metrics: ["Recall", "NDCG"]
topk: [5, 10, 20, 50]
valid_metric: NDCG@20
valid_metric_bigger: True
eval_batch_size: 1048576
metric_decimal_place: 4

--------------------------------------------------------------------------------
/recbole/properties/dataset/yelp.yaml:
--------------------------------------------------------------------------------
# Atomic File Format
field_separator: "\t"
seq_separator: " "

# Common Features
USER_ID_FIELD: user_id
ITEM_ID_FIELD: item_id
seq_len: ~

# Label for Point-wise DataLoader
LABEL_FIELD: label
threshold: ~

# NegSample Prefix for Pair-wise DataLoader
NEG_PREFIX: neg_

# Selectively Loading
load_col:
    inter: [user_id, item_id]
# the others
unload_col: ~
unused_col: ~
additional_feat_suffix: ~

# Filtering
rm_dup_inter: last
val_interval: ~
filter_inter_by_user_or_item: True
user_inter_num_interval: "[10,inf)"
item_inter_num_interval: "[10,inf)"

# Preprocessing
alias_of_user_id: ~
alias_of_item_id: ~
alias_of_entity_id: ~
alias_of_relation_id: ~
preload_weight: ~
normalize_field: ~
normalize_all: ~

# Sequential Model Needed
ITEM_LIST_LENGTH_FIELD: item_length
LIST_SUFFIX: _list
MAX_ITEM_LIST_LENGTH: 50
POSITION_FIELD: position_id

# Knowledge-based Model Needed
HEAD_ENTITY_ID_FIELD: head_id
TAIL_ENTITY_ID_FIELD: tail_id
RELATION_ID_FIELD: relation_id
ENTITY_ID_FIELD: entity_id

# Benchmark .inter
benchmark_filename: ~

--------------------------------------------------------------------------------
/recbole/properties/dataset/beauty.yaml:
--------------------------------------------------------------------------------
# Atomic File Format
field_separator: "\t"
seq_separator: " "

# Common Features
USER_ID_FIELD: user_id
ITEM_ID_FIELD: item_id
RATING_FIELD: rating
TIME_FIELD: timestamp
seq_len: ~

# Label for Point-wise DataLoader
LABEL_FIELD: label
threshold: ~

# NegSample Prefix for Pair-wise DataLoader
NEG_PREFIX: neg_

# Selectively Loading
load_col:
    inter: [user_id, item_id, rating, timestamp]
# the others
unload_col: ~
unused_col: ~
additional_feat_suffix: ~

# Filtering
rm_dup_inter: last
val_interval: ~
filter_inter_by_user_or_item: True
user_inter_num_interval: "[5,inf)"
item_inter_num_interval: "[5,inf)"

# Preprocessing
alias_of_user_id: ~
alias_of_item_id: ~
alias_of_entity_id: ~
alias_of_relation_id: ~
preload_weight: ~
normalize_field: ~
normalize_all: ~

# Sequential Model Needed
ITEM_LIST_LENGTH_FIELD: item_length
LIST_SUFFIX: _list
MAX_ITEM_LIST_LENGTH: 50
POSITION_FIELD: position_id

# Knowledge-based Model Needed
HEAD_ENTITY_ID_FIELD: head_id
TAIL_ENTITY_ID_FIELD: tail_id
RELATION_ID_FIELD: relation_id
ENTITY_ID_FIELD: entity_id

# Benchmark .inter
benchmark_filename: ~

--------------------------------------------------------------------------------
/recbole/utils/argument_list.py:
--------------------------------------------------------------------------------
# @Time   : 2020/10/19
# @Author : Shanlei Mu
# @Email  : slmu@ruc.edu.cn

# yapf: disable

general_arguments = [
    'gpu_id', 'use_gpu',
    'seed',
    'reproducibility',
    'state',
    'data_path',
    'checkpoint_dir',
    'show_progress',
    'config_file',
    'save_dataset',
    'dataset_save_path',
    'save_dataloaders',
    'dataloaders_save_path',
    'log_wandb',
]

training_arguments = [
    'epochs', 'train_batch_size',
    'learner', 'learning_rate',
    'neg_sampling',
    'eval_step', 'stopping_step',
    'clip_grad_norm',
    'weight_decay',
    'loss_decimal_place',
]

evaluation_arguments = [
    'eval_args',
    'repeatable',
    'metrics', 'topk', 'valid_metric', 'valid_metric_bigger',
    'eval_batch_size',
    'metric_decimal_place',
]

dataset_arguments = [
    'field_separator', 'seq_separator',
    'USER_ID_FIELD', 'ITEM_ID_FIELD', 'RATING_FIELD', 'TIME_FIELD',
    'seq_len',
    'LABEL_FIELD', 'threshold',
    'NEG_PREFIX',
    'ITEM_LIST_LENGTH_FIELD', 'LIST_SUFFIX', 'MAX_ITEM_LIST_LENGTH', 'POSITION_FIELD',
    'HEAD_ENTITY_ID_FIELD', 'TAIL_ENTITY_ID_FIELD', 'RELATION_ID_FIELD', 'ENTITY_ID_FIELD',
    'load_col', 'unload_col', 'unused_col', 'additional_feat_suffix',
    'rm_dup_inter', 'val_interval', 'filter_inter_by_user_or_item',
    'user_inter_num_interval', 'item_inter_num_interval',
    'alias_of_user_id', 'alias_of_item_id', 'alias_of_entity_id', 'alias_of_relation_id',
    'preload_weight', 'normalize_field', 'normalize_all',
    'benchmark_filename',
]

--------------------------------------------------------------------------------
/recbole/model/init.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Time   : 2020/9/16
# @Author : Shanlei Mu
# @Email  : slmu@ruc.edu.cn

"""
recbole.model.init
########################
"""

import torch.nn as nn
from torch.nn.init import xavier_normal_, xavier_uniform_, constant_


def xavier_normal_initialization(module):
    r"""using `xavier_normal_`_ in PyTorch to initialize the parameters in
    nn.Embedding and nn.Linear layers. For bias in nn.Linear layers,
    using constant 0 to initialize.

    .. _`xavier_normal_`:
        https://pytorch.org/docs/stable/nn.init.html?highlight=xavier_normal_#torch.nn.init.xavier_normal_

    Examples:
        >>> self.apply(xavier_normal_initialization)
    """
    if isinstance(module, nn.Embedding):
        xavier_normal_(module.weight.data)
    elif isinstance(module, nn.Linear):
        xavier_normal_(module.weight.data)
        if module.bias is not None:
            constant_(module.bias.data, 0)


def xavier_uniform_initialization(module):
    r"""using `xavier_uniform_`_ in PyTorch to initialize the parameters in
    nn.Embedding and nn.Linear layers. For bias in nn.Linear layers,
    using constant 0 to initialize.

    .. _`xavier_uniform_`:
        https://pytorch.org/docs/stable/nn.init.html?highlight=xavier_uniform_#torch.nn.init.xavier_uniform_

    Examples:
        >>> self.apply(xavier_uniform_initialization)
    """
    if isinstance(module, nn.Embedding):
        xavier_uniform_(module.weight.data)
    elif isinstance(module, nn.Linear):
        xavier_uniform_(module.weight.data)
        if module.bias is not None:
            constant_(module.bias.data, 0)

--------------------------------------------------------------------------------
/recbole/data/dataloader/user_dataloader.py:
--------------------------------------------------------------------------------
# @Time   : 2020/9/23
# @Author : Yushuo Chen
# @Email  : chenyushuo@ruc.edu.cn

# UPDATE
# @Time   : 2020/9/23, 2020/12/28
# @Author : Yushuo Chen, Xingyu Pan
# @email  : chenyushuo@ruc.edu.cn, panxy@ruc.edu.cn

"""
recbole.data.dataloader.user_dataloader
################################################
"""
import torch

from recbole.data.dataloader.abstract_dataloader import AbstractDataLoader
from recbole.data.interaction import Interaction


class UserDataLoader(AbstractDataLoader):
    """:class:`UserDataLoader` will return a batch of data which only contains user-id when it is iterated.

    Args:
        config (Config): The config of dataloader.
        dataset (Dataset): The dataset of dataloader.
        sampler (Sampler): The sampler of dataloader.
        shuffle (bool, optional): Whether the dataloader will be shuffled after a round. Defaults to ``False``.

    Attributes:
        shuffle (bool): Whether the dataloader will be shuffled after a round.
            However, in :class:`UserDataLoader`, it's guaranteed to be ``True``.
    """

    def __init__(self, config, dataset, sampler, shuffle=False):
        if shuffle is False:
            shuffle = True
            self.logger.warning('UserDataLoader must shuffle the data.')

        self.uid_field = dataset.uid_field
        self.user_list = Interaction({self.uid_field: torch.arange(dataset.user_num)})

        super().__init__(config, dataset, sampler, shuffle=shuffle)

    def _init_batch_size_and_step(self):
        batch_size = self.config['train_batch_size']
        self.step = batch_size
        self.set_batch_size(batch_size)

    @property
    def pr_end(self):
        return len(self.user_list)

    def _shuffle(self):
        self.user_list.shuffle()

    def _next_batch_data(self):
        cur_data = self.user_list[self.pr:self.pr + self.step]
        self.pr += self.step
        return cur_data

--------------------------------------------------------------------------------
/recbole/evaluator/evaluator.py:
--------------------------------------------------------------------------------
# -*- encoding: utf-8 -*-
# @Time   : 2021/6/25
# @Author : Zhichao Feng
# @email  : fzcbupt@gmail.com

"""
recbole.evaluator.evaluator
#####################################
"""

from recbole.evaluator.register import metrics_dict
from recbole.evaluator.collector import DataStruct
from collections import OrderedDict


class Evaluator(object):
    """Evaluator is used to check parameter correctness and summarize the results of all metrics."""

    def __init__(self, config):
        self.config = config
        self.metrics = [metric.lower() for metric in self.config['metrics']]
        self.metric_class = {}

        for metric in self.metrics:
            self.metric_class[metric] = metrics_dict[metric](self.config)

    def evaluate(self, dataobject: DataStruct):
        """Calculate all the metrics. It is called at the end of each epoch.

        Args:
            dataobject (DataStruct): It contains all the information needed for metrics.

        Returns:
            collections.OrderedDict: such as ``{'hit@20': 0.3824, 'recall@20': 0.0527, 'hit@10': 0.3153, 'recall@10': 0.0329, 'gauc': 0.9236}``
        """
        result_dict = OrderedDict()
        for metric in self.metrics:
            metric_val = self.metric_class[metric].calculate_metric(dataobject)
            result_dict.update(metric_val)
        return result_dict

    def evaluate_unbiased(self, dataobject: DataStruct):
        """Calculate all the metrics in the unbiased setting. It is called at the end of each epoch.

        Args:
            dataobject (DataStruct): It contains all the information needed for metrics.

        Returns:
            collections.OrderedDict: such as ``{'hit@20': 0.3824, 'recall@20': 0.0527, 'hit@10': 0.3153, 'recall@10': 0.0329, 'gauc': 0.9236}``
        """
        result_dict = OrderedDict()
        for metric in self.metrics:
            metric_val = self.metric_class[metric].calculate_metric(dataobject, unbiased=True)
            result_dict.update(metric_val)
        return result_dict
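
# Usage sketch (illustrative, not part of the original file). The Evaluator is
# normally driven by RecBole's Trainer; `struct` below stands for the DataStruct
# that a Collector produces after an evaluation epoch.
#
#     evaluator = Evaluator(config)        # with config['metrics'] = ['Recall', 'NDCG']
#     result = evaluator.evaluate(struct)  # e.g. OrderedDict([('recall@20', ...), ('ndcg@20', ...)])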

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import os

from setuptools import setup, find_packages

install_requires = ['numpy>=1.17.2', 'torch>=1.7.0', 'scipy==1.6.0', 'pandas>=1.0.5', 'tqdm>=4.48.2',
                    'colorlog==4.7.2', 'colorama==0.4.4',
                    'scikit_learn>=0.23.2', 'pyyaml>=5.1.0', 'tensorboard>=2.5.0']

setup_requires = []

extras_require = {
    'hyperopt': ['hyperopt>=0.2.4']
}

classifiers = ["License :: OSI Approved :: MIT License"]

long_description = 'RecBole is developed based on Python and PyTorch for ' \
                   'reproducing and developing recommendation algorithms in ' \
                   'a unified, comprehensive and efficient framework for ' \
                   'research purpose. In the first version, our library ' \
                   'includes 53 recommendation algorithms, covering four ' \
                   'major categories: General Recommendation, Sequential ' \
                   'Recommendation, Context-aware Recommendation and ' \
                   'Knowledge-based Recommendation. View RecBole homepage ' \
                   'for more information: https://recbole.io'

# Readthedocs requires Sphinx extensions to be specified as part of
# install_requires in order to build properly.
on_rtd = os.environ.get('READTHEDOCS', None) == 'True'
if on_rtd:
    install_requires.extend(setup_requires)

setup(
    name='recbole',
    version='1.0.1',  # please remember to edit recbole/__init__.py in response, once updating the version
    description='A unified, comprehensive and efficient recommendation library',
    long_description=long_description,
    long_description_content_type="text/markdown",
    url='https://github.com/RUCAIBox/RecBole',
    author='RecBoleTeam',
    author_email='recbole@outlook.com',
    packages=[
        package for package in find_packages()
        if package.startswith('recbole')
    ],
    include_package_data=True,
    install_requires=install_requires,
    setup_requires=setup_requires,
    extras_require=extras_require,
    zip_safe=False,
    classifiers=classifiers,
)

--------------------------------------------------------------------------------
/recbole/utils/wandblogger.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Time   : 2022/8/2
# @Author : Ayush Thakur
# @Email  : ayusht@wandb.com

r"""
recbole.utils.wandblogger
################################
"""


class WandbLogger(object):
    """WandbLogger to log metrics to Weights and Biases."""

    def __init__(self, config):
        """
        Args:
            config (dict): A dictionary of parameters used by RecBole.
        """
        self.config = config
        self.log_wandb = config.log_wandb
        self.setup()

    def setup(self):
        if self.log_wandb:
            try:
                import wandb
                self._wandb = wandb
            except ImportError:
                raise ImportError(
                    "To use the Weights and Biases Logger please install wandb. "
                    "Run `pip install wandb` to install it."
                )

            # Initialize a W&B run
            if self._wandb.run is None:
                self._wandb.init(
                    project=self.config.wandb_project,
                    config=self.config
                )

            self._set_steps()

    def log_metrics(self, metrics, head='train', commit=True):
        if self.log_wandb:
            if head:
                metrics = self._add_head_to_metrics(metrics, head)
            self._wandb.log(metrics, commit=commit)

    def log_eval_metrics(self, metrics, head='eval'):
        if self.log_wandb:
            metrics = self._add_head_to_metrics(metrics, head)
            for k, v in metrics.items():
                self._wandb.run.summary[k] = v

    def _set_steps(self):
        self._wandb.define_metric('train/*', step_metric='train_step')
        self._wandb.define_metric('valid/*', step_metric='valid_step')

    def _add_head_to_metrics(self, metrics, head):
        head_metrics = dict()
        for k, v in metrics.items():
            if '_step' in k:
                head_metrics[k] = v
            else:
                head_metrics[f'{head}/{k}'] = v

        return head_metrics
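
# Usage sketch (illustrative, not part of the original file):
#
#     wandblogger = WandbLogger(config)  # with config['log_wandb'] = True
#     wandblogger.log_metrics({'train_step': 1, 'loss': 0.5}, head='train')
#     wandblogger.log_eval_metrics({'recall@20': 0.0527}, head='eval')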

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Toward a Better Understanding of Loss Functions for Collaborative Filtering (CIKM'23)

This is the official code for MAWU in the paper "[Toward a Better Understanding of Loss Functions for Collaborative Filtering](https://arxiv.org/abs/2308.06091)", *The 32nd ACM International Conference on Information and Knowledge Management*. This code is implemented on [RecBole](https://github.com/RUCAIBox/RecBole).

Note that a summary of our paper is on [our lab blog](https://dial.skku.edu/blog/2023_mawu) (in Korean).

## How to run

### Set conda environment
```
conda env create -f mawu.yaml
conda activate mawu
```

### Run commands for DirectAU
```
python run_recbole.py --model=DirectAU --dataset=beauty --encoder=MF --weight_decay=1e-4 --gamma=0.4 &&
python run_recbole.py --model=DirectAU --dataset=beauty --encoder=LightGCN --weight_decay=1e-4 --gamma=0.4 &&
python run_recbole.py --model=DirectAU --dataset=gowalla --encoder=MF --weight_decay=1e-6 --gamma=2 &&
python run_recbole.py --model=DirectAU --dataset=gowalla --encoder=LightGCN --weight_decay=1e-6 --gamma=2 &&
python run_recbole.py --model=DirectAU --dataset=yelp --encoder=MF --weight_decay=1e-6 --gamma=2 &&
python run_recbole.py --model=DirectAU --dataset=yelp --encoder=LightGCN --weight_decay=1e-6 --gamma=2
```

### Run commands for MAWU
```
python run_recbole.py --model=MAWU --dataset=beauty --encoder=MF --weight_decay=1e-4 --gamma1=1 --gamma2=0.1 &&
python run_recbole.py --model=MAWU --dataset=beauty --encoder=LightGCN --weight_decay=1e-4 --gamma1=0.9 --gamma2=0.2 &&
python run_recbole.py --model=MAWU --dataset=gowalla --encoder=MF --weight_decay=1e-6 --gamma1=2.6 --gamma2=1.4 &&
python run_recbole.py --model=MAWU --dataset=gowalla --encoder=LightGCN --weight_decay=1e-6 --gamma1=2.4 --gamma2=1.6 &&
python run_recbole.py --model=MAWU --dataset=yelp --encoder=MF --weight_decay=1e-6 --gamma1=0.8 --gamma2=0.6 &&
python run_recbole.py --model=MAWU --dataset=yelp --encoder=LightGCN --weight_decay=1e-6 --gamma1=1.2 --gamma2=0.6
```
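
### Run from Python (optional)
The shell commands above can also be launched programmatically. A minimal sketch (the hyperparameters mirror the first MAWU command; passing overrides via `config_dict` is standard RecBole usage, but verify the keyword arguments against your RecBole version):
```
from recbole.quick_start import run_recbole

# Equivalent to the first MAWU command above (Beauty dataset, MF encoder).
run_recbole(
    model='MAWU',
    dataset='beauty',
    config_dict={'encoder': 'MF', 'weight_decay': 1e-4, 'gamma1': 1, 'gamma2': 0.1},
)
```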

## Citation
If you find our work helpful, please cite our paper.
```
@inproceedings{park2023mawu,
  title={Toward a Better Understanding of Loss Functions for Collaborative Filtering},
  author={Seongmin Park and
          Mincheol Yoon and
          Jae-woong Lee and
          Hogun Park and
          Jongwuk Lee},
  booktitle={The 32nd ACM International Conference on Information and Knowledge Management (CIKM)},
  year={2023}
}
```

--------------------------------------------------------------------------------
/recbole/data/dataset/decisiontree_dataset.py:
--------------------------------------------------------------------------------
# @Time   : 2020/12/17
# @Author : Chen Yang
# @Email  : 254170321@qq.com

"""
recbole.data.decisiontree_dataset
##########################
"""

from recbole.data.dataset import Dataset
from recbole.utils import FeatureType


class DecisionTreeDataset(Dataset):
    """:class:`DecisionTreeDataset` is based on :class:`~recbole.data.dataset.dataset.Dataset`,
    and converts token-type features into a hash form for decision-tree-based models.
    """

    def __init__(self, config):
        super().__init__(config)

    def _judge_token_and_convert(self, feat):
        # get columns whose type is token
        col_list = []
        for col_name in feat:
            if col_name == self.uid_field or col_name == self.iid_field:
                continue
            if self.field2type[col_name] == FeatureType.TOKEN:
                col_list.append(col_name)
            elif self.field2type[col_name] in {FeatureType.TOKEN_SEQ, FeatureType.FLOAT_SEQ}:
                feat = feat.drop([col_name], axis=1, inplace=False)

        # get hash map
        for col in col_list:
            self.hash_map[col] = dict({})
            self.hash_count[col] = 0

        del_col = []
        for col in self.hash_map:
            if col in feat.keys():
                for value in feat[col]:
                    if value not in self.hash_map[col]:
                        self.hash_map[col][value] = self.hash_count[col]
                        self.hash_count[col] = self.hash_count[col] + 1
                        if self.hash_count[col] > self.config['token_num_threshold']:
                            del_col.append(col)
                            break

        for col in del_col:
            del self.hash_count[col]
            del self.hash_map[col]
            col_list.remove(col)
        self.convert_col_list.extend(col_list)

        # transform the original data
        for col in self.hash_map.keys():
            if col in feat.keys():
                feat[col] = feat[col].map(self.hash_map[col])

        return feat

    def _convert_token_to_hash(self):
        """Convert the data of token type to hash form."""
        self.hash_map = {}
        self.hash_count = {}
        self.convert_col_list = []
        if self.config['convert_token_to_onehot']:
            for feat_name in ['inter_feat', 'user_feat', 'item_feat']:
                feat = getattr(self, feat_name)
                if feat is not None:
                    feat = self._judge_token_and_convert(feat)
                    setattr(self, feat_name, feat)

    def _from_scratch(self):
        super()._from_scratch()
        self._convert_token_to_hash()

--------------------------------------------------------------------------------
/recbole/utils/enum_type.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Time   : 2020/8/9
# @Author : Yupeng Hou
# @Email  : houyupeng@ruc.edu.cn

"""
recbole.utils.enum_type
#######################
"""

from enum import Enum


class ModelType(Enum):
    """Type of models.

    - ``GENERAL``: General Recommendation
    - ``SEQUENTIAL``: Sequential Recommendation
    - ``CONTEXT``: Context-aware Recommendation
    - ``KNOWLEDGE``: Knowledge-based Recommendation
    """

    GENERAL = 1
    SEQUENTIAL = 2
    CONTEXT = 3
    KNOWLEDGE = 4
    TRADITIONAL = 5
    DECISIONTREE = 6


class KGDataLoaderState(Enum):
    """States for Knowledge-based DataLoader.

    - ``RSKG``: Return both knowledge graph information and user-item interaction information.
    - ``RS``: Only return the user-item interaction.
    - ``KG``: Only return the triplets with negative examples in a knowledge graph.
    """

    RSKG = 1
    RS = 2
    KG = 3


class EvaluatorType(Enum):
    """Type for evaluation metrics.

    - ``RANKING``: Ranking-based metrics like NDCG, Recall, etc.
    - ``VALUE``: Value-based metrics like AUC, etc.
    """

    RANKING = 1
    VALUE = 2


class InputType(Enum):
    """Type of Models' input.

    - ``POINTWISE``: Point-wise input, like ``uid, iid, label``.
    - ``PAIRWISE``: Pair-wise input, like ``uid, pos_iid, neg_iid``.
    """

    POINTWISE = 1
    PAIRWISE = 2
    LISTWISE = 3
    SETWISE = 4
    DualSETWISE = 5


class FeatureType(Enum):
    """Type of features.

    - ``TOKEN``: Token features like user_id and item_id.
    - ``FLOAT``: Float features like rating and timestamp.
    - ``TOKEN_SEQ``: Token sequence features like review.
    - ``FLOAT_SEQ``: Float sequence features like pretrained vector.
    """

    TOKEN = 'token'
    FLOAT = 'float'
    TOKEN_SEQ = 'token_seq'
    FLOAT_SEQ = 'float_seq'


class FeatureSource(Enum):
    """Source of features.

    - ``INTERACTION``: Features from ``.inter`` (other than ``user_id`` and ``item_id``).
    - ``USER``: Features from ``.user`` (other than ``user_id``).
    - ``ITEM``: Features from ``.item`` (other than ``item_id``).
    - ``USER_ID``: ``user_id`` feature in ``inter_feat`` and ``user_feat``.
    - ``ITEM_ID``: ``item_id`` feature in ``inter_feat`` and ``item_feat``.
    - ``KG``: Features from ``.kg``.
    - ``NET``: Features from ``.net``.
    """

    INTERACTION = 'inter'
    USER = 'user'
    ITEM = 'item'
    USER_ID = 'user_id'
    ITEM_ID = 'item_id'
    KG = 'kg'
    NET = 'net'

--------------------------------------------------------------------------------
/recbole/utils/url.py:
--------------------------------------------------------------------------------
'''
recbole.utils.url
################################
Reference code:
    https://github.com/snap-stanford/ogb/blob/master/ogb/utils/url.py
'''

import urllib.request as ur
import zipfile
import os
import os.path as osp
import errno
from logging import getLogger

from tqdm import tqdm

GBFACTOR = float(1 << 30)


def decide_download(url):
    d = ur.urlopen(url)
    size = int(d.info()['Content-Length']) / GBFACTOR

    ### confirm if larger than 1GB
    if size > 1:
        return input('This will download %.2fGB. Will you proceed? (y/N)\n' % (size)).lower() == 'y'
    else:
        return True


def makedirs(path):
    try:
        os.makedirs(osp.expanduser(osp.normpath(path)))
    except OSError as e:
        # only swallow the error if the directory already exists
        if e.errno != errno.EEXIST or not osp.isdir(path):
            raise e


def download_url(url, folder):
    '''Downloads the content of an URL to a specific folder.

    Args:
        url (string): The url.
        folder (string): The folder.
    '''

    filename = url.rpartition('/')[2]
    path = osp.join(folder, filename)
    logger = getLogger()

    if osp.exists(path) and osp.getsize(path) > 0:  # pragma: no cover
        logger.info(f'Using existing file {filename}')
        return path

    logger.info(f'Downloading {url}')

    makedirs(folder)
    data = ur.urlopen(url)

    size = int(data.info()['Content-Length'])

    chunk_size = 1024 * 1024
    num_iter = int(size / chunk_size) + 2

    downloaded_size = 0

    try:
        with open(path, 'wb') as f:
            pbar = tqdm(range(num_iter))
            for i in pbar:
                chunk = data.read(chunk_size)
                downloaded_size += len(chunk)
                pbar.set_description('Downloaded {:.2f} GB'.format(float(downloaded_size) / GBFACTOR))
                f.write(chunk)
    except:
        if os.path.exists(path):
            os.remove(path)
        raise RuntimeError('Stopped downloading due to interruption.')

    return path


def extract_zip(path, folder):
    '''Extracts a zip archive to a specific folder.

    Args:
        path (string): The path to the zip archive.
        folder (string): The folder.
    '''
    logger = getLogger()
    logger.info(f'Extracting {path}')
    with zipfile.ZipFile(path, 'r') as f:
        f.extractall(folder)


def rename_atomic_files(folder, old_name, new_name):
    '''Rename all atomic files in a given folder.

    Args:
        folder (string): The folder.
        old_name (string): Old name for atomic files.
        new_name (string): New name for atomic files.
    '''
    files = os.listdir(folder)
    for f in files:
        base, suf = os.path.splitext(f)
        if old_name not in base:
            continue
        assert suf in {'.inter', '.user', '.item'}
        os.rename(os.path.join(folder, f), os.path.join(folder, base.replace(old_name, new_name) + suf))


if __name__ == '__main__':
    pass
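
# Usage sketch (illustrative, not part of the original file). The URL and target
# folder below are hypothetical.
#
#     url = 'https://example.com/ml-100k.zip'
#     if decide_download(url):
#         path = download_url(url, 'dataset/ml-100k')
#         extract_zip(path, 'dataset/ml-100k')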

--------------------------------------------------------------------------------
/recbole/utils/logger.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Time   : 2020/8/7
# @Author : Zihan Lin
# @Email  : linzihan.super@foxmail.com

# UPDATE
# @Time   : 2021/3/7
# @Author : Jiawei Guan
# @Email  : guanjw@ruc.edu.cn

"""
recbole.utils.logger
###############################
"""

import logging
import os
import colorlog
import re

from recbole.utils.utils import get_local_time, ensure_dir
from colorama import init

log_colors_config = {
    'DEBUG': 'cyan',
    'WARNING': 'yellow',
    'ERROR': 'red',
    'CRITICAL': 'red',
}


class RemoveColorFilter(logging.Filter):

    def filter(self, record):
        if record:
            ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')
            record.msg = ansi_escape.sub('', str(record.msg))
        return True


def set_color(log, color, highlight=True):
    color_set = ['black', 'red', 'green', 'yellow', 'blue', 'pink', 'cyan', 'white']
    try:
        index = color_set.index(color)
    except ValueError:
        index = len(color_set) - 1
    prev_log = '\033['
    if highlight:
        prev_log += '1;3'
    else:
        prev_log += '0;3'
    prev_log += str(index) + 'm'
    return prev_log + log + '\033[0m'


def init_logger(config):
    """
    A logger that can show a message on standard output and write it into the
    file named `filename` simultaneously.
    All the message that you want to log MUST be str.

    Args:
        config (Config): An instance object of Config, used to record parameter information.

    Example:
        >>> init_logger(config)
        >>> logger = logging.getLogger()
        >>> logger.debug(train_state)
        >>> logger.info(train_result)
    """
    init(autoreset=True)
    LOGROOT = './log/'
    dir_name = os.path.dirname(LOGROOT)
    ensure_dir(dir_name)
    model_name = os.path.join(dir_name, config['model'])
    ensure_dir(model_name)
    logfilename = '{}/{}.log'.format(config['model'], get_local_time())

    logfilepath = os.path.join(LOGROOT, logfilename)

    filefmt = "%(asctime)-15s %(levelname)s %(message)s"
    filedatefmt = "%a %d %b %Y %H:%M:%S"
    fileformatter = logging.Formatter(filefmt, filedatefmt)

    sfmt = "%(log_color)s%(asctime)-15s %(levelname)s %(message)s"
    sdatefmt = "%d %b %H:%M"
    sformatter = colorlog.ColoredFormatter(sfmt, sdatefmt, log_colors=log_colors_config)
    if config['state'] is None or config['state'].lower() == 'info':
        level = logging.INFO
    elif config['state'].lower() == 'debug':
        level = logging.DEBUG
    elif config['state'].lower() == 'error':
        level = logging.ERROR
    elif config['state'].lower() == 'warning':
        level = logging.WARNING
    elif config['state'].lower() == 'critical':
        level = logging.CRITICAL
    else:
        level = logging.INFO

    fh = logging.FileHandler(logfilepath)
    fh.setLevel(level)
    fh.setFormatter(fileformatter)
    remove_color_filter = RemoveColorFilter()
    fh.addFilter(remove_color_filter)

    sh = logging.StreamHandler()
    sh.setLevel(level)
    sh.setFormatter(sformatter)

    logging.basicConfig(level=level, handlers=[sh, fh])

--------------------------------------------------------------------------------
/recbole/evaluator/register.py:
--------------------------------------------------------------------------------
# @Time   : 2021/6/23
# @Author : Zihan Lin
# @Email  : zhlin@ruc.edu.cn

# UPDATE
# @Time   : 2021/8/29
# @Author : Zhichao Feng
# @email  : fzcbupt@gmail.com

"""
recbole.evaluator.register
################################################
"""
import inspect
import sys


def cluster_info(module_name):
    """Collect information of all metrics, including:

    - ``metric_need``: Information needed to calculate this metric, the combination of ``rec.items, rec.topk,
      rec.meanrank, rec.score, data.num_items, data.num_users, data.count_items, data.count_users, data.label``.
    - ``metric_type``: Whether the scores required by metric are grouped by user, range in ``EvaluatorType.RANKING``
      and ``EvaluatorType.VALUE``.
    - ``smaller``: Whether the smaller metric value represents better performance,
      range in ``True`` and ``False``, default to ``False``.

    Note:
        For ``metric_type``: in current RecBole, all the "grouped-score" metrics are ranking-based and all the
        "non-grouped-score" metrics are value-based. To keep with our paper, we adopted the more formal terms:
        ``RANKING`` and ``VALUE``.

    Args:
        module_name (str): the name of module ``recbole.evaluator.metrics``.

    Returns:
        dict: Three dictionaries containing the above information
        and a dictionary matching metric names to metric classes.
    """
    smaller_m = []
    m_dict, m_info, m_types = {}, {}, {}
    metric_class = inspect.getmembers(
        sys.modules[module_name], lambda x: inspect.isclass(x) and x.__module__ == module_name
    )
    for name, metric_cls in metric_class:
        name = name.lower()
        m_dict[name] = metric_cls
        if hasattr(metric_cls, 'metric_need'):
            m_info[name] = metric_cls.metric_need
        else:
            raise AttributeError(f"Metric '{name}' has no attribute [metric_need].")
        if hasattr(metric_cls, 'metric_type'):
            m_types[name] = metric_cls.metric_type
        else:
            raise AttributeError(f"Metric '{name}' has no attribute [metric_type].")
        if metric_cls.smaller is True:
            smaller_m.append(name)
    return smaller_m, m_info, m_types, m_dict


metric_module_name = 'recbole.evaluator.metrics'
smaller_metrics, metric_information, metric_types, metrics_dict = cluster_info(metric_module_name)


class Register(object):
    """The Register module loads the registry according to the metrics in the config.
    It is a member of DataCollector.
    The DataCollector collects the resources needed by the Evaluator under the guidance of the Register.
    """

    def __init__(self, config):

        self.config = config
        self.metrics = [metric.lower() for metric in self.config['metrics']]
        self._build_register()

    def _build_register(self):
        for metric in self.metrics:
            metric_needs = metric_information[metric]
            for info in metric_needs:
                setattr(self, info, True)

    def has_metric(self, metric: str):
        if metric.lower() in self.metrics:
            return True
        else:
            return False

    def need(self, key: str):
        if hasattr(self, key):
            return getattr(self, key)
        return False
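
# Illustrative sketch (not part of the original file): the attributes that
# cluster_info expects from a metric class. A new metric defined in
# recbole.evaluator.metrics would look roughly like this -- 'ExampleHit' and its
# body are hypothetical.
#
#     import numpy as np
#     from recbole.utils import EvaluatorType
#     from recbole.evaluator.base_metric import TopkMetric
#
#     class ExampleHit(TopkMetric):
#         metric_type = EvaluatorType.RANKING  # scores grouped by user
#         metric_need = ['rec.topk']           # resources the Collector must gather
#         smaller = False                      # larger value means better performance
#
#         def calculate_metric(self, dataobject):
#             pos_index, _ = self.used_info(dataobject)
#             result = (np.cumsum(pos_index, axis=1) > 0).astype(int)  # hit@k per user
#             return self.topk_result('examplehit', result)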

--------------------------------------------------------------------------------
/recbole/utils/case_study.py:
--------------------------------------------------------------------------------
# @Time   : 2020/12/25
# @Author : Yushuo Chen
# @Email  : chenyushuo@ruc.edu.cn

# UPDATE
# @Time   : 2020/12/25
# @Author : Yushuo Chen
# @email  : chenyushuo@ruc.edu.cn

"""
recbole.utils.case_study
#####################################
"""

import numpy as np
import torch

from recbole.data.interaction import Interaction


@torch.no_grad()
def full_sort_scores(uid_series, model, test_data, device=None):
    """Calculate the scores of all items for each user in uid_series.

    Note:
        The score of [pad] and history items will be set to -inf.

    Args:
        uid_series (numpy.ndarray or list): User id series.
        model (AbstractRecommender): Model to predict.
        test_data (FullSortEvalDataLoader): The test_data of model.
        device (torch.device, optional): The device which model will run on. Defaults to ``None``.
            Note: ``device=None`` is equivalent to ``device=torch.device('cpu')``.

    Returns:
        torch.Tensor: the scores of all items for each user in uid_series.
    """
    device = device or torch.device('cpu')
    uid_series = torch.tensor(uid_series)
    uid_field = test_data.dataset.uid_field
    dataset = test_data.dataset
    model.eval()

    if not test_data.is_sequential:
        input_interaction = dataset.join(Interaction({uid_field: uid_series}))
        history_item = test_data.uid2history_item[list(uid_series)]
        history_row = torch.cat([torch.full_like(hist_iid, i) for i, hist_iid in enumerate(history_item)])
        history_col = torch.cat(list(history_item))
        history_index = history_row, history_col
    else:
        _, index = (dataset.inter_feat[uid_field] == uid_series[:, None]).nonzero(as_tuple=True)
        input_interaction = dataset[index]
        history_index = None

    # Get scores of all items
    input_interaction = input_interaction.to(device)
    try:
        scores = model.full_sort_predict(input_interaction)
    except NotImplementedError:
        input_interaction = input_interaction.repeat_interleave(dataset.item_num)
        input_interaction.update(test_data.dataset.get_item_feature().to(device).repeat(len(uid_series)))
        scores = model.predict(input_interaction)

    scores = scores.view(-1, dataset.item_num)
    scores[:, 0] = -np.inf  # set scores of [pad] to -inf
    if history_index is not None:
        scores[history_index] = -np.inf  # set scores of history items to -inf

    return scores


def full_sort_topk(uid_series, model, test_data, k, device=None):
    """Calculate the top-k items' scores and ids for each user in uid_series.

    Note:
        The score of [pad] and history items will be set to -inf.

    Args:
        uid_series (numpy.ndarray): User id series.
        model (AbstractRecommender): Model to predict.
        test_data (FullSortEvalDataLoader): The test_data of model.
        k (int): The top-k items.
        device (torch.device, optional): The device which model will run on. Defaults to ``None``.
            Note: ``device=None`` is equivalent to ``device=torch.device('cpu')``.

    Returns:
        tuple:
            - topk_scores (torch.Tensor): The scores of topk items.
            - topk_index (torch.Tensor): The index of topk items, which is also the internal ids of items.
    """
    scores = full_sort_scores(uid_series, model, test_data, device)
    return torch.topk(scores, k)
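
# Minimal usage sketch (illustrative, not part of the original file). The saved
# model file name and user tokens are hypothetical; load_data_and_model comes
# from recbole.quick_start.
#
#     from recbole.quick_start import load_data_and_model
#
#     config, model, dataset, train_data, valid_data, test_data = \
#         load_data_and_model(model_file='saved/MAWU-beauty.pth')
#     uid_series = dataset.token2id(dataset.uid_field, ['196', '186'])  # external -> internal ids
#     topk_scores, topk_iids = full_sort_topk(uid_series, model, test_data, k=10)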

--------------------------------------------------------------------------------
/recbole/evaluator/utils.py:
--------------------------------------------------------------------------------
# -*- encoding: utf-8 -*-
# @Time   : 2020/08/04
# @Author : Kaiyuan Li
# @email  : tsotfsk@outlook.com

# UPDATE
# @Time   : 2020/09/28, 2020/08/09
# @Author : Kaiyuan Li, Zhichao Feng
# @email  : tsotfsk@outlook.com, fzcbupt@gmail.com

"""
recbole.evaluator.utils
################################
"""

import itertools

import numpy as np
import torch


def pad_sequence(sequences, len_list, pad_to=None, padding_value=0):
    """Pad sequences to a matrix.

    Args:
        sequences (list): list of variable length sequences.
        len_list (list): the length of the tensors in the sequences.
        pad_to (int, optional): if pad_to is not None, the sequences will pad to the length you set,
            else the sequence will pad to the max length of the sequences.
        padding_value (int, optional): value for padded elements. Default: 0.
            Note that the current implementation always pads with ``-inf`` (so padded
            entries never rank), regardless of this argument.

    Returns:
        torch.Tensor: [seq_num, max_len] or [seq_num, pad_to]
    """
    max_len = np.max(len_list) if pad_to is None else pad_to
    min_len = np.min(len_list)
    device = sequences[0].device
    if max_len == min_len:
        result = torch.cat(sequences, dim=0).view(-1, max_len)
    else:
        extra_len_list = np.subtract(max_len, len_list).tolist()
        padding_nums = max_len * len(len_list) - np.sum(len_list)
        padding_tensor = torch.tensor([-np.inf], device=device).repeat(padding_nums)
        padding_list = torch.split(padding_tensor, extra_len_list)
        result = list(itertools.chain.from_iterable(zip(sequences, padding_list)))
        result = torch.cat(result)

    return result.view(-1, max_len)


def trunc(scores, method):
    """Round the scores by using the given method.

    Args:
        scores (numpy.ndarray): scores
        method (str): one of ['ceil', 'floor', 'around']

    Raises:
        NotImplementedError: method error

    Returns:
        numpy.ndarray: processed scores
    """

    try:
        cut_method = getattr(np, method)
    except AttributeError:  # getattr raises AttributeError, not NotImplementedError
        raise NotImplementedError("module 'numpy' has no function named '{}'".format(method))
    scores = cut_method(scores)
    return scores


def cutoff(scores, threshold):
    """Cut off the scores based on threshold.

    Args:
        scores (numpy.ndarray): scores
        threshold (float): between 0 and 1

    Returns:
        numpy.ndarray: processed scores
    """
    return np.where(scores > threshold, 1, 0)


def _binary_clf_curve(trues, preds):
    """Calculate true and false positives per binary classification threshold.

    Args:
        trues (numpy.ndarray): the true scores' list
        preds (numpy.ndarray): the predict scores' list

    Returns:
        fps (numpy.ndarray): A count of false positives, at index i being the number of negative
            samples assigned a score >= thresholds[i]
        tps (numpy.ndarray): An increasing count of true positives, at index i being the number
            of positive samples assigned a score >= thresholds[i].

    Note:
        To improve efficiency, we referred to the source code (which is available at sklearn.metrics.roc_curve)
        in SkLearn and made some optimizations.
    """
    trues = (trues == 1)

    desc_idxs = np.argsort(preds)[::-1]
    preds = preds[desc_idxs]
    trues = trues[desc_idxs]

    unique_val_idxs = np.where(np.diff(preds))[0]
    threshold_idxs = np.r_[unique_val_idxs, trues.size - 1]

    tps = np.cumsum(trues)[threshold_idxs]
    fps = 1 + threshold_idxs - tps
    return fps, tps

--------------------------------------------------------------------------------
/mawu.yaml:
--------------------------------------------------------------------------------
name: mawu
channels:
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=conda_forge
  - _openmp_mutex=4.5=2_gnu
  - asttokens=2.2.1=pyhd8ed1ab_0
  - backcall=0.2.0=pyh9f0ad1d_0
  - backports=1.0=pyhd8ed1ab_3
  - backports.functools_lru_cache=1.6.5=pyhd8ed1ab_0
  - bzip2=1.0.8=h7f98852_4
  - ca-certificates=2023.7.22=hbcca054_0
  - certifi=2023.7.22=pyhd8ed1ab_0
  - comm=0.1.3=pyhd8ed1ab_0
  - debugpy=1.6.7=py38h8dc9893_0
  - decorator=5.1.1=pyhd8ed1ab_0
  - executing=1.2.0=pyhd8ed1ab_0
  - importlib-metadata=6.8.0=pyha770c72_0
  - importlib_metadata=6.8.0=hd8ed1ab_0
  - ipykernel=6.25.0=pyh71e2992_0
  - jedi=0.18.2=pyhd8ed1ab_0
  - jupyter_client=8.3.0=pyhd8ed1ab_0
  - jupyter_core=4.12.0=py38h578d9bd_0
  - ld_impl_linux-64=2.40=h41732ed_0
  - libedit=3.1.20191231=he28a2e2_2
  - libffi=3.4.2=h7f98852_5
  - libgcc-ng=13.1.0=he5830b7_0
  - libgomp=13.1.0=he5830b7_0
  - libnsl=2.0.0=h7f98852_0
  - libsodium=1.0.18=h36c2ea0_1
  - libsqlite=3.42.0=h2797004_0
  - libstdcxx-ng=13.1.0=hfd8a6a1_0
  - libuuid=2.38.1=h0b41bf4_0
  - libzlib=1.2.13=hd590300_5
  - matplotlib-inline=0.1.6=pyhd8ed1ab_0
  - ncurses=6.4=hcb278e6_0
  - nest-asyncio=1.5.6=pyhd8ed1ab_0
  - openssl=3.1.1=hd590300_1
  - packaging=23.1=pyhd8ed1ab_0
  - parso=0.8.3=pyhd8ed1ab_0
  - pexpect=4.8.0=pyh1a96a4e_2
  - pickleshare=0.7.5=py_1003
  - pip=23.2.1=pyhd8ed1ab_0
  - prompt_toolkit=3.0.39=hd8ed1ab_0
  - psutil=5.9.5=py38h1de0b5d_0
  - ptyprocess=0.7.0=pyhd3deb0d_0
  - pure_eval=0.2.2=pyhd8ed1ab_0
  - pygments=2.15.1=pyhd8ed1ab_0
  - python=3.8.17=he550d4f_0_cpython
  - python-dateutil=2.8.2=pyhd8ed1ab_0
  - python_abi=3.8=3_cp38
  - pyzmq=25.1.0=py38h509eb50_0
  - readline=8.2=h8228510_1
  - setuptools=68.0.0=pyhd8ed1ab_0
  - six=1.16.0=pyh6c4a22f_0
  - sqlite=3.42.0=h2c6b66d_0
  - stack_data=0.6.2=pyhd8ed1ab_0
  - tk=8.6.12=h27826a3_0
  - tornado=6.3.2=py38h01eb140_0
  - traitlets=5.9.0=pyhd8ed1ab_0
  - typing_extensions=4.7.1=pyha770c72_0
  - wcwidth=0.2.6=pyhd8ed1ab_0
  - wheel=0.41.0=pyhd8ed1ab_0
  - xz=5.2.6=h166bdaf_0
  - zeromq=4.3.4=h9c3ff4c_1
  - zipp=3.16.2=pyhd8ed1ab_0
  - zlib=1.2.13=hd590300_5
  - pip:
    - appdirs==1.4.4
    - charset-normalizer==3.1.0
    - click==8.1.3
    - cloudpickle==2.2.1
    - cmake==3.26.3
    - colorama==0.4.4
    - colorlog==4.7.2
    - contourpy==1.0.7
    - cycler==0.11.0
    - docker-pycreds==0.4.0
    - fast-pytorch-kmeans==0.1.9
    - filelock==3.11.0
    - fonttools==4.39.2
    - future==0.18.3
    - gitdb==4.0.10
    - gitpython==3.1.31
    - hyperopt==0.2.4
    - idna==3.4
    - importlib-resources==5.12.0
    - ipython==8.11.0
    - ipywidgets==8.0.5
    - jinja2==3.1.2
    - joblib==1.2.0
    - jupyterlab-widgets==3.0.6
    - kiwisolver==1.4.4
    - lit==16.0.1
    - llvmlite==0.36.0
97 | - matplotlib==3.1.2 98 | - mpmath==1.3.0 99 | - networkx==3.0 100 | - numba==0.53.1 101 | - numpy==1.23.5 102 | - nvidia-cublas-cu11==11.10.3.66 103 | - nvidia-cuda-cupti-cu11==11.7.101 104 | - nvidia-cuda-nvrtc-cu11==11.7.99 105 | - nvidia-cuda-runtime-cu11==11.7.99 106 | - nvidia-cudnn-cu11==8.5.0.96 107 | - nvidia-cufft-cu11==10.9.0.58 108 | - nvidia-curand-cu11==10.2.10.91 109 | - nvidia-cusolver-cu11==11.4.0.1 110 | - nvidia-cusparse-cu11==11.7.4.91 111 | - nvidia-nccl-cu11==2.14.3 112 | - nvidia-nvtx-cu11==11.7.91 113 | - oauthlib==3.2.2 114 | - opentsne==0.7.1 115 | - pandas==1.0.5 116 | - pathtools==0.1.2 117 | - pillow==9.4.0 118 | - plotly==5.14.1 119 | - prompt-toolkit==3.0.38 120 | - py4j==0.10.9.7 121 | - pynvml==11.5.0 122 | - pytorch-ranger==0.1.1 123 | - pytz==2023.3 124 | - pyyaml==6.0 125 | - recbole==1.1.1 126 | - requests==2.28.2 127 | - scikit-learn==0.23.2 128 | - scipy==1.6.0 129 | - seaborn==0.12.2 130 | - sentry-sdk==1.19.1 131 | - setproctitle==1.3.2 132 | - smmap==5.0.0 133 | - sympy==1.11.1 134 | - tabulate==0.9.0 135 | - tenacity==8.2.2 136 | - thop==0.1.1-2209072238 137 | - threadpoolctl==3.1.0 138 | - torch==2.0.0 139 | - torch-geometric==2.3.0 140 | - torch-optimizer==0.3.0 141 | - torchaudio==0.7.2 142 | - torchvision==0.8.2+cu110 143 | - tqdm==4.65.0 144 | - triton==2.0.0 145 | - typing-extensions==4.5.0 146 | - urllib3==1.26.15 147 | - wandb==0.14.2 148 | - widgetsnbextension==4.0.6 149 | prefix: /home/tako/anaconda3/envs/sm 150 | -------------------------------------------------------------------------------- /recbole/evaluator/base_metric.py: -------------------------------------------------------------------------------- 1 | # @Time : 2020/10/21 2 | # @Author : Kaiyuan Li 3 | # @email : tsotfsk@outlook.com 4 | 5 | # UPDATE 6 | # @Time : 2020/10/21, 2021/8/29 7 | # @Author : Kaiyuan Li, Zhichao Feng 8 | # @email : tsotfsk@outlook.com, fzcbupt@gmail.com 9 | 10 | """ 11 | recbole.evaluator.abstract_metric 12 | ##################################### 13 | """ 14 | 15 | import torch 16 | from recbole.utils import EvaluatorType 17 | 18 | 19 | class AbstractMetric(object): 20 | """:class:`AbstractMetric` is the base object of all metrics. If you want to 21 | implement a metric, you should inherit this class. 22 | 23 | Args: 24 | config (Config): the config of evaluator. 25 | """ 26 | smaller = False 27 | 28 | def __init__(self, config): 29 | self.decimal_place = config['metric_decimal_place'] 30 | 31 | def calculate_metric(self, dataobject): 32 | """Get the dictionary of a metric. 33 | 34 | Args: 35 | dataobject(DataStruct): it contains all the information needed to calculate metrics. 36 | 37 | Returns: 38 | dict: such as ``{'metric@10': 0.3153, 'metric@20': 0.3824}`` 39 | """ 40 | raise NotImplementedError('Method [calculate_metric] should be implemented.') 41 | 42 | 43 | class TopkMetric(AbstractMetric): 44 | """:class:`TopkMetric` is a base object of top-k metrics. If you want to 45 | implement a top-k metric, you can inherit this class. 46 | 47 | Args: 48 | config (Config): The config of evaluator. 49 | """ 50 | metric_type = EvaluatorType.RANKING 51 | metric_need = ['rec.topk'] 52 | 53 | def __init__(self, config): 54 | super().__init__(config) 55 | self.topk = config['topk'] 56 | 57 | def used_info(self, dataobject): 58 | """Get the bool matrix indicating whether the corresponding item is positive 59 | and the number of positive items for each user.
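Returns (an editorial note, not upstream documentation; shapes inferred from the ``torch.split`` call below): a bool array of shape ``(n_users, max(self.topk))`` marking which of each user's top-k slots hold positive items, and an int array of shape ``(n_users,)`` with each user's number of positives.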
60 | """ 61 | rec_mat = dataobject.get('rec.topk') 62 | topk_idx, pos_len_list = torch.split(rec_mat, [max(self.topk), 1], dim=1) 63 | return topk_idx.to(torch.bool).numpy(), pos_len_list.squeeze(-1).numpy() 64 | 65 | def topk_result(self, metric, value): 66 | """Match the metric value to the `k` and put them in `dictionary` form. 67 | 68 | Args: 69 | metric(str): the name of calculated metric. 70 | value(numpy.ndarray): metrics for each user, including values from `metric@1` to `metric@max(self.topk)`. 71 | 72 | Returns: 73 | dict: metric values required in the configuration. 74 | """ 75 | metric_dict = {} 76 | avg_result = value.mean(axis=0) 77 | for k in self.topk: 78 | key = '{}@{}'.format(metric, k) 79 | metric_dict[key] = round(avg_result[k - 1], self.decimal_place) 80 | return metric_dict 81 | 82 | def metric_info(self, pos_index, pos_len=None): 83 | """Calculate the value of the metric. 84 | 85 | Args: 86 | pos_index(numpy.ndarray): a bool matrix, shape of ``n_users * max(topk)``. The item with the (j+1)-th \ 87 | highest score of i-th user is positive if ``pos_index[i][j] == True`` and negative otherwise. 88 | pos_len(numpy.ndarray): a vector representing the number of positive items per user, shape of ``(n_users,)``. 89 | 90 | Returns: 91 | numpy.ndarray: metrics for each user, including values from `metric@1` to `metric@max(self.topk)`. 92 | """ 93 | raise NotImplementedError('Method [metric_info] of top-k metric should be implemented.') 94 | 95 | 96 | class LossMetric(AbstractMetric): 97 | """:class:`LossMetric` is a base object of loss based metrics and AUC. If you want to 98 | implement an loss based metric, you can inherit this class. 99 | 100 | Args: 101 | config (Config): The config of evaluator. 102 | """ 103 | metric_type = EvaluatorType.VALUE 104 | metric_need = ['rec.score', 'data.label'] 105 | 106 | def __init__(self, config): 107 | super().__init__(config) 108 | 109 | def used_info(self, dataobject): 110 | """Get scores that model predicted and the ground truth.""" 111 | preds = dataobject.get('rec.score') 112 | trues = dataobject.get('data.label') 113 | 114 | return preds.squeeze(-1).numpy(), trues.squeeze(-1).numpy() 115 | 116 | def output_metric(self, metric, dataobject): 117 | preds, trues = self.used_info(dataobject) 118 | result = self.metric_info(preds, trues) 119 | return {metric: round(result, self.decimal_place)} 120 | 121 | def metric_info(self, preds, trues): 122 | """Calculate the value of the metric. 123 | 124 | Args: 125 | preds (numpy.ndarray): the scores predicted by model, a one-dimensional vector. 126 | trues (numpy.ndarray): the label of items, which has the same shape as ``preds``. 127 | 128 | Returns: 129 | float: The value of the metric. 130 | """ 131 | raise NotImplementedError('Method [metric_info] of loss-based metric should be implemented.') 132 | -------------------------------------------------------------------------------- /recbole/data/dataset/customized_dataset.py: -------------------------------------------------------------------------------- 1 | # @Time : 2020/10/19 2 | # @Author : Yupeng Hou 3 | # @Email : houyupeng@ruc.edu.cn 4 | 5 | # UPDATE 6 | # @Time : 2021/7/9 7 | # @Author : Yupeng Hou 8 | # @Email : houyupeng@ruc.edu.cn 9 | 10 | """ 11 | recbole.data.customized_dataset 12 | ################################## 13 | 14 | We only recommend building customized datasets by inheriting. 15 | 16 | Customized datasets named ``[Model Name]Dataset`` can be automatically called. 
17 | """ 18 | 19 | import numpy as np 20 | import torch 21 | 22 | from recbole.data.dataset import KGSeqDataset, SequentialDataset 23 | from recbole.data.interaction import Interaction 24 | from recbole.sampler import SeqSampler 25 | from recbole.utils.enum_type import FeatureType 26 | 27 | 28 | class GRU4RecKGDataset(KGSeqDataset): 29 | 30 | def __init__(self, config): 31 | super().__init__(config) 32 | 33 | 34 | class KSRDataset(KGSeqDataset): 35 | 36 | def __init__(self, config): 37 | super().__init__(config) 38 | 39 | 40 | class DIENDataset(SequentialDataset): 41 | """:class:`DIENDataset` is based on :class:`~recbole.data.dataset.sequential_dataset.SequentialDataset`. 42 | It is different from :class:`SequentialDataset` in `data_augmentation`. 43 | It add users' negative item list to interaction. 44 | 45 | The original version of sampling negative item list is implemented by Zhichao Feng (fzcbupt@gmail.com) in 2021/2/25, 46 | and he updated the codes in 2021/3/19. In 2021/7/9, Yupeng refactored SequentialDataset & SequentialDataLoader, 47 | then refactored DIENDataset, either. 48 | 49 | Attributes: 50 | augmentation (bool): Whether the interactions should be augmented in RecBole. 51 | seq_sample (recbole.sampler.SeqSampler): A sampler used to sample negative item sequence. 52 | neg_item_list_field (str): Field name for negative item sequence. 53 | neg_item_list (torch.tensor): all users' negative item history sequence. 54 | """ 55 | 56 | def __init__(self, config): 57 | super().__init__(config) 58 | 59 | list_suffix = config['LIST_SUFFIX'] 60 | neg_prefix = config['NEG_PREFIX'] 61 | self.seq_sampler = SeqSampler(self) 62 | self.neg_item_list_field = neg_prefix + self.iid_field + list_suffix 63 | self.neg_item_list = self.seq_sampler.sample_neg_sequence(self.inter_feat[self.iid_field]) 64 | 65 | def data_augmentation(self): 66 | """Augmentation processing for sequential dataset. 67 | 68 | E.g., ``u1`` has purchase sequence ````, 69 | then after augmentation, we will generate three cases. 70 | 71 | ``u1, | i2`` 72 | 73 | (Which means given user_id ``u1`` and item_seq ````, 74 | we need to predict the next item ``i2``.) 
75 | 76 | The other cases are below: 77 | 78 | ``u1, <i1, i2> | i3`` 79 | 80 | ``u1, <i1, i2, i3> | i4`` 81 | """ 82 | self.logger.debug('data_augmentation') 83 | 84 | self._aug_presets() 85 | 86 | self._check_field('uid_field', 'time_field') 87 | max_item_list_len = self.config['MAX_ITEM_LIST_LENGTH'] 88 | self.sort(by=[self.uid_field, self.time_field], ascending=True) 89 | last_uid = None 90 | uid_list, item_list_index, target_index, item_list_length = [], [], [], [] 91 | seq_start = 0 92 | for i, uid in enumerate(self.inter_feat[self.uid_field].numpy()): 93 | if last_uid != uid: 94 | last_uid = uid 95 | seq_start = i 96 | else: 97 | if i - seq_start > max_item_list_len: 98 | seq_start += 1 99 | uid_list.append(uid) 100 | item_list_index.append(slice(seq_start, i)) 101 | target_index.append(i) 102 | item_list_length.append(i - seq_start) 103 | 104 | uid_list = np.array(uid_list) 105 | item_list_index = np.array(item_list_index) 106 | target_index = np.array(target_index) 107 | item_list_length = np.array(item_list_length, dtype=np.int64) 108 | 109 | new_length = len(item_list_index) 110 | new_data = self.inter_feat[target_index] 111 | new_dict = { 112 | self.item_list_length_field: torch.tensor(item_list_length), 113 | } 114 | 115 | for field in self.inter_feat: 116 | if field != self.uid_field: 117 | list_field = getattr(self, f'{field}_list_field') 118 | list_len = self.field2seqlen[list_field] 119 | shape = (new_length, list_len) if isinstance(list_len, int) else (new_length,) + list_len 120 | list_ftype = self.field2type[list_field] 121 | dtype = torch.int64 if list_ftype in [FeatureType.TOKEN, FeatureType.TOKEN_SEQ] else torch.float64 122 | new_dict[list_field] = torch.zeros(shape, dtype=dtype) 123 | 124 | value = self.inter_feat[field] 125 | for i, (index, length) in enumerate(zip(item_list_index, item_list_length)): 126 | new_dict[list_field][i][:length] = value[index] 127 | 128 | # DIEN 129 | if field == self.iid_field: 130 | new_dict[self.neg_item_list_field] = torch.zeros(shape, dtype=dtype) 131 | for i, (index, length) in enumerate(zip(item_list_index, item_list_length)): 132 | new_dict[self.neg_item_list_field][i][:length] = self.neg_item_list[index] 133 | 134 | new_data.update(Interaction(new_dict)) 135 | self.inter_feat = new_data 136 | -------------------------------------------------------------------------------- /recbole/model/loss.py: -------------------------------------------------------------------------------- 1 | # @Time : 2020/6/26 2 | # @Author : Shanlei Mu 3 | # @Email : slmu@ruc.edu.cn 4 | 5 | # UPDATE: 6 | # @Time : 2020/8/7, 2021/12/22 7 | # @Author : Shanlei Mu, Gaowei Zhang 8 | # @Email : slmu@ruc.edu.cn, 1462034631@qq.com 9 | 10 | 11 | """ 12 | recbole.model.loss 13 | ####################### 14 | Common Loss in recommender system 15 | """ 16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | 22 | class BPRLoss(nn.Module): 23 | """ BPRLoss, based on Bayesian Personalized Ranking 24 | 25 | Args: 26 | - gamma(float): Small value to avoid division by zero 27 | 28 | Shape: 29 | - Pos_score: (N) 30 | - Neg_score: (N), same shape as the Pos_score 31 | - Output: scalar.
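The quantity minimized in :meth:`forward` below is ``-log(gamma + sigmoid(pos_score - neg_score)).mean()``, where ``gamma`` only guards against ``log(0)``.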
32 | 33 | Examples:: 34 | 35 | >>> loss = BPRLoss() 36 | >>> pos_score = torch.randn(3, requires_grad=True) 37 | >>> neg_score = torch.randn(3, requires_grad=True) 38 | >>> output = loss(pos_score, neg_score) 39 | >>> output.backward() 40 | """ 41 | 42 | def __init__(self, gamma=1e-10): 43 | super(BPRLoss, self).__init__() 44 | self.gamma = gamma 45 | 46 | def forward(self, pos_score, neg_score): 47 | loss = -torch.log(self.gamma + torch.sigmoid(pos_score - neg_score)).mean() 48 | return loss 49 | 50 | 51 | class SSMLoss(nn.Module): 52 | def __init__(self): 53 | """ SampledSoftmaxCrossEntropyLoss. 54 | Takes no constructor arguments; see :meth:`forward` for the expected inputs. 55 | """ 56 | super(SSMLoss, self).__init__() 57 | 58 | def forward(self, pos_score, neg_score, num_neg_items=None): 59 | """ 60 | :param pos_score: predicted values of shape (batch_size, ) 61 | :param neg_score: predicted values of shape (batch_size, num_negs) 62 | """ 63 | scores = torch.cat((pos_score.unsqueeze(1), neg_score), dim=1) 64 | probs = F.softmax(scores, dim=1) 65 | hit_probs = probs[:, 0] 66 | loss = -torch.log(hit_probs).mean() 67 | return loss 68 | 69 | # TODO: models using CCL must be modified 70 | class CCLoss(nn.Module): 71 | def __init__(self, negative_weight=None): 72 | """ 73 | :param negative_weight: float, the weight applied to the negative samples. 74 | When negative_weight is None, the negative losses are summed, which is 75 | equivalent to weighting the mean by the number of negatives. 76 | """ 77 | super(CCLoss, self).__init__() 78 | self._negative_weight = negative_weight 79 | 80 | def forward(self, pos_score, neg_score, margin, num_neg_items=None): 81 | """ 82 | :param margin: float, margin in CosineContrastiveLoss 83 | :param pos_score: predicted values of shape (batch_size, ) 84 | :param neg_score: predicted values of shape (batch_size, num_negs) 85 | """ 86 | pos_loss = torch.relu(1 - pos_score) # TODO: try removing the relu 87 | neg_loss = torch.relu(neg_score - margin) 88 | if self._negative_weight: 89 | if num_neg_items is None: 90 | loss = pos_loss + neg_loss.mean(dim=-1) * self._negative_weight 91 | else: 92 | loss = pos_loss + neg_loss.sum(dim=-1) / num_neg_items * self._negative_weight 93 | else: 94 | loss = pos_loss + neg_loss.sum(dim=-1) 95 | return loss.mean() 96 | 97 | 98 | class DualCCLoss(nn.Module): 99 | def __init__(self, margin=0, negative_weight=None): 100 | """ 101 | :param margin: float, margin in CosineContrastiveLoss 102 | :param negative_weight: float, the weight applied to the negative samples. 103 | When negative_weight is None, the negative losses are summed rather than
averaged, 104 | which is equivalent to weighting the mean by the number of negatives. 105 | """ 106 | super(DualCCLoss, self).__init__() 107 | self._margin = margin 108 | self._negative_weight = negative_weight 109 | 110 | def forward(self, pos_score, neg_score): 111 | """ 112 | :param pos_score: predicted values of shape (batch_size, ) 113 | :param neg_score: predicted values of shape (batch_size, num_negs) 114 | """ 115 | pos_loss = torch.relu(1 - pos_score) # TODO: try removing the relu 116 | neg_loss = torch.relu(neg_score - self._margin) 117 | if self._negative_weight: 118 | loss = pos_loss.mean(dim=-1) + neg_loss.mean(dim=-1) * self._negative_weight 119 | else: 120 | loss = pos_loss.sum(dim=-1) + neg_loss.sum(dim=-1) 121 | return loss.mean() 122 | 123 | 124 | class RegLoss(nn.Module): 125 | """ RegLoss, L2 regularization on model parameters 126 | 127 | """ 128 | 129 | def __init__(self): 130 | super(RegLoss, self).__init__() 131 | 132 | def forward(self, parameters): 133 | reg_loss = None 134 | for W in parameters: 135 | if reg_loss is None: 136 | reg_loss = W.norm(2) 137 | else: 138 | reg_loss = reg_loss + W.norm(2) 139 | return reg_loss 140 | 141 | 142 | class EmbLoss(nn.Module): 143 | """ EmbLoss, regularization on embeddings 144 | 145 | """ 146 | 147 | def __init__(self, norm=2): 148 | super(EmbLoss, self).__init__() 149 | self.norm = norm 150 | 151 | def forward(self, *embeddings, require_pow=False): 152 | if require_pow: 153 | emb_loss = torch.zeros(1).to(embeddings[-1].device) 154 | for embedding in embeddings: 155 | emb_loss += torch.pow(input=torch.norm(embedding, p=self.norm), exponent=self.norm) 156 | emb_loss /= embeddings[-1].shape[0] 157 | emb_loss /= self.norm 158 | return emb_loss 159 | else: 160 | emb_loss = torch.zeros(1).to(embeddings[-1].device) 161 | for embedding in embeddings: 162 | emb_loss += torch.norm(embedding, p=self.norm) 163 | emb_loss /= embeddings[-1].shape[0] 164 | return emb_loss 165 | 166 | 167 | class EmbMarginLoss(nn.Module): 168 | """ EmbMarginLoss, regularization on embeddings 169 | """ 170 | 171 | def __init__(self, power=2): 172 | super(EmbMarginLoss, self).__init__() 173 | self.power = power 174 | 175 | def forward(self, *embeddings): 176 | dev = embeddings[-1].device 177 | cache_one = torch.tensor(1.0).to(dev) 178 | cache_zero = torch.tensor(0.0).to(dev) 179 | emb_loss = torch.tensor(0.).to(dev) 180 | for embedding in embeddings: 181 | norm_e = torch.sum(embedding ** self.power, dim=1, keepdim=True) 182 | emb_loss += torch.sum(torch.max(norm_e - cache_one, cache_zero)) 183 | return emb_loss 184 | -------------------------------------------------------------------------------- /recbole/quick_start/quick_start.py: -------------------------------------------------------------------------------- 1 | # @Time : 2020/10/6 2 | # @Author : Shanlei Mu 3 | # @Email : slmu@ruc.edu.cn 4 | 5 | """ 6 | recbole.quick_start 7 | ######################## 8 | """ 9 | import logging 10 | import os 11 | from logging import getLogger 12 | 13 | import torch 14 | 15 | from recbole.config import Config 16 | from recbole.data import create_dataset, data_preparation 17 | from recbole.utils import (get_model, get_trainer, init_logger, init_seed, 18 | set_color) 19 | 20 | 21 | def run_recbole(model=None, dataset=None, config_file_list=None, config_dict=None, saved=True): 22 | r""" A fast-running API, which includes the complete process of 23 | training and testing a model on a specified dataset 24 | 25 | Args: 26 | model (str, optional): Model name. Defaults to ``None``. 27 | dataset (str, optional): Dataset name.
Defaults to ``None``. 28 | config_file_list (list, optional): Config files used to modify experiment parameters. Defaults to ``None``. 29 | config_dict (dict, optional): Parameters dictionary used to modify experiment parameters. Defaults to ``None``. 30 | saved (bool, optional): Whether to save the model. Defaults to ``True``. 31 | """ 32 | # configurations initialization 33 | config = Config(model=model, dataset=dataset, config_file_list=config_file_list, config_dict=config_dict) 34 | init_seed(config['seed'], config['reproducibility']) 35 | # logger initialization 36 | init_logger(config) 37 | logger = getLogger() 38 | 39 | logger.info(config) 40 | 41 | # dataset filtering 42 | dataset = create_dataset(config) 43 | logger.info(dataset) 44 | 45 | # dataset splitting 46 | train_data, valid_data, test_data = data_preparation(config, dataset) 47 | 48 | # pre-compute and save positive pairs 49 | # if not os.path.exists(f'saved/{dataset.dataset_name}-pos_pairs'): 50 | # if config['save_pos_pairs']: 51 | # # make indices of positive pairs by users 52 | # rating_matrix = train_data.dataset.inter_matrix(form='csr') 53 | 54 | # # TODO: think of a more efficient implementation 55 | # pos_pairs_dict = {} 56 | # for i, items_by_user in enumerate(rating_matrix): 57 | # items_by_user = torch.tensor(items_by_user.toarray().squeeze()) 58 | 59 | # nonzero_index = torch.nonzero(items_by_user).squeeze() 60 | # index = torch.combinations(nonzero_index) # the slowest part 61 | 62 | # pos_pairs_dict[i] = index 63 | 64 | # with open(f'saved/{dataset.dataset_name}-pos_pairs', 'wb') as f: 65 | # pickle.dump(pos_pairs_dict, f) 66 | 67 | # model loading and initialization 68 | init_seed(config['seed'], config['reproducibility']) 69 | model = get_model(config['model'])(config, train_data.dataset).to(config['device']) 70 | logger.info(model) 71 | 72 | # trainer loading and initialization 73 | trainer = get_trainer(config['MODEL_TYPE'], config['model'])(config, model) 74 | 75 | # model training 76 | best_valid_score, best_valid_result, train_time = trainer.fit( 77 | train_data, valid_data, saved=saved, show_progress=config['show_progress'], # test_data=test_data 78 | ) 79 | 80 | # model evaluation 81 | test_result = trainer.evaluate(test_data, load_best_model=saved, show_progress=config['show_progress']) 82 | 83 | logger.info(set_color('best valid ', 'yellow') + f': {best_valid_result}') 84 | logger.info(set_color('test result', 'yellow') + f': {test_result}') 85 | 86 | print('\t'.join(list(test_result.keys()))) 87 | print('\t'.join(list(str(i) for i in test_result.values()))) 88 | 89 | return { 90 | 'best_valid_score': best_valid_score, 91 | 'valid_score_bigger': config['valid_metric_bigger'], 92 | 'best_valid_result': best_valid_result, 93 | 'test_result': test_result 94 | } 95 | 96 | 97 | def objective_function(config_dict=None, config_file_list=None, saved=True): 98 | r""" The default objective_function used in HyperTuning 99 | 100 | Args: 101 | config_dict (dict, optional): Parameters dictionary used to modify experiment parameters. Defaults to ``None``. 102 | config_file_list (list, optional): Config files used to modify experiment parameters. Defaults to ``None``. 103 | saved (bool, optional): Whether to save the model. Defaults to ``True``.
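Example (a minimal sketch; the model/dataset values are illustrative assumptions):: objective_function(config_dict={'model': 'MAWU', 'dataset': 'gowalla'})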
104 | """ 105 | 106 | config = Config(config_dict=config_dict, config_file_list=config_file_list) 107 | init_seed(config['seed'], config['reproducibility']) 108 | logging.basicConfig(level=logging.ERROR) 109 | 110 | # make directories for saving results 111 | os.makedirs(f"output_files/{config['dataset']}", exist_ok=True) 112 | 113 | dataset = create_dataset(config) 114 | train_data, valid_data, test_data = data_preparation(config, dataset) 115 | init_seed(config['seed'], config['reproducibility']) 116 | model = get_model(config['model'])(config, train_data.dataset).to(config['device']) 117 | trainer = get_trainer(config['MODEL_TYPE'], config['model'])(config, model) 118 | best_valid_score, best_valid_result, train_time = trainer.fit(train_data, valid_data, verbose=False, saved=saved) 119 | test_result = trainer.evaluate(test_data, load_best_model=saved) 120 | 121 | return { 122 | 'best_valid_score': best_valid_score, 123 | 'valid_score_bigger': config['valid_metric_bigger'], 124 | 'best_valid_result': best_valid_result, 125 | 'test_result': test_result, 126 | 'saved_model_file': trainer.saved_model_file, 127 | 'train_time': train_time 128 | } 129 | 130 | 131 | def load_data_and_model(model_file): 132 | r"""Load filtered dataset, split dataloaders and saved model. 133 | 134 | Args: 135 | model_file (str): The path of saved model file. 136 | 137 | Returns: 138 | tuple: 139 | - config (Config): An instance object of Config, which record parameter information in :attr:`model_file`. 140 | - model (AbstractRecommender): The model load from :attr:`model_file`. 141 | - dataset (Dataset): The filtered dataset. 142 | - train_data (AbstractDataLoader): The dataloader for training. 143 | - valid_data (AbstractDataLoader): The dataloader for validation. 144 | - test_data (AbstractDataLoader): The dataloader for testing. 
145 | """ 146 | checkpoint = torch.load(model_file) 147 | config = checkpoint['config'] 148 | init_seed(config['seed'], config['reproducibility']) 149 | init_logger(config) 150 | logger = getLogger() 151 | logger.info(config) 152 | 153 | dataset = create_dataset(config) 154 | logger.info(dataset) 155 | train_data, valid_data, test_data = data_preparation(config, dataset) 156 | 157 | init_seed(config['seed'], config['reproducibility']) 158 | model = get_model(config['model'])(config, train_data.dataset).to(config['device']) 159 | model.load_state_dict(checkpoint['state_dict']) 160 | model.load_other_parameter(checkpoint.get('other_parameter')) 161 | 162 | return config, model, dataset, train_data, valid_data, test_data 163 | -------------------------------------------------------------------------------- /recbole/utils/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/7/17 3 | # @Author : Shanlei Mu 4 | # @Email : slmu@ruc.edu.cn 5 | 6 | # UPDATE 7 | # @Time : 2021/3/8 8 | # @Author : Jiawei Guan 9 | # @Email : guanjw@ruc.edu.cn 10 | 11 | """ 12 | recbole.utils.utils 13 | ################################ 14 | """ 15 | 16 | import datetime 17 | import importlib 18 | import os 19 | import random 20 | 21 | import numpy as np 22 | import torch 23 | from torch.utils.tensorboard import SummaryWriter 24 | 25 | from recbole.utils.enum_type import ModelType 26 | 27 | 28 | def get_local_time(): 29 | r"""Get current time 30 | 31 | Returns: 32 | str: current time 33 | """ 34 | cur = datetime.datetime.now() 35 | cur = cur.strftime('%b-%d-%Y_%H-%M-%S') 36 | 37 | return cur 38 | 39 | 40 | def ensure_dir(dir_path): 41 | r"""Make sure the directory exists, if it does not exist, create it 42 | 43 | Args: 44 | dir_path (str): directory path 45 | 46 | """ 47 | if not os.path.exists(dir_path): 48 | os.makedirs(dir_path) 49 | 50 | 51 | def get_model(model_name): 52 | r"""Automatically select model class based on model name 53 | 54 | Args: 55 | model_name (str): model name 56 | 57 | Returns: 58 | Recommender: model class 59 | """ 60 | model_submodule = [ 61 | 'general_recommender', 'context_aware_recommender', 'sequential_recommender', 'knowledge_aware_recommender', 62 | 'exlib_recommender' 63 | ] 64 | 65 | model_file_name = model_name.lower() 66 | model_module = None 67 | for submodule in model_submodule: 68 | module_path = '.'.join(['recbole.model', submodule, model_file_name]) 69 | if importlib.util.find_spec(module_path, __name__): 70 | model_module = importlib.import_module(module_path, __name__) 71 | break 72 | 73 | if model_module is None: 74 | raise ValueError('`model_name` [{}] is not the name of an existing model.'.format(model_name)) 75 | model_class = getattr(model_module, model_name) 76 | return model_class 77 | 78 | 79 | def get_trainer(model_type, model_name): 80 | r"""Automatically select trainer class based on model type and model name 81 | 82 | Args: 83 | model_type (ModelType): model type 84 | model_name (str): model name 85 | 86 | Returns: 87 | Trainer: trainer class 88 | """ 89 | try: 90 | return getattr(importlib.import_module('recbole.trainer'), model_name + 'Trainer') 91 | except AttributeError: 92 | if model_type == ModelType.KNOWLEDGE: 93 | return getattr(importlib.import_module('recbole.trainer'), 'KGTrainer') 94 | elif model_type == ModelType.TRADITIONAL: 95 | return getattr(importlib.import_module('recbole.trainer'), 'TraditionalTrainer') 96 | else: 97 | return 
getattr(importlib.import_module('recbole.trainer'), 'Trainer') 98 | 99 | 100 | def early_stopping(value, best, cur_step, max_step, bigger=True): 101 | r""" validation-based early stopping 102 | 103 | Args: 104 | value (float): current result 105 | best (float): best result 106 | cur_step (int): the number of consecutive steps that did not exceed the best result 107 | max_step (int): threshold steps for stopping 108 | bigger (bool, optional): whether the bigger the better 109 | 110 | Returns: 111 | tuple: 112 | - float, 113 | best result after this step 114 | - int, 115 | the number of consecutive steps that did not exceed the best result after this step 116 | - bool, 117 | whether to stop 118 | - bool, 119 | whether to update 120 | """ 121 | stop_flag = False 122 | update_flag = False 123 | if bigger: 124 | if value >= best: 125 | cur_step = 0 126 | best = value 127 | update_flag = True 128 | else: 129 | cur_step += 1 130 | if cur_step > max_step: 131 | stop_flag = True 132 | else: 133 | if value <= best: 134 | cur_step = 0 135 | best = value 136 | update_flag = True 137 | else: 138 | cur_step += 1 139 | if cur_step > max_step: 140 | stop_flag = True 141 | return best, cur_step, stop_flag, update_flag 142 | 143 | 144 | def calculate_valid_score(valid_result, valid_metric=None): 145 | r""" return valid score from valid result 146 | 147 | Args: 148 | valid_result (dict): valid result 149 | valid_metric (str, optional): the selected metric in valid result for valid score 150 | 151 | Returns: 152 | float: valid score 153 | """ 154 | if valid_metric: 155 | return valid_result[valid_metric] 156 | else: 157 | return valid_result['Recall@10'] 158 | 159 | 160 | def dict2str(result_dict): 161 | r""" convert result dict to str 162 | 163 | Args: 164 | result_dict (dict): result dict 165 | 166 | Returns: 167 | str: result str 168 | """ 169 | 170 | return ' '.join([str(metric) + ' : ' + str(value) for metric, value in result_dict.items()]) 171 | 172 | 173 | def init_seed(seed, reproducibility): 174 | r""" init random seed for random functions in numpy, torch, cuda and cudnn 175 | 176 | Args: 177 | seed (int): random seed 178 | reproducibility (bool): Whether to require reproducibility 179 | """ 180 | random.seed(seed) 181 | np.random.seed(seed) 182 | torch.manual_seed(seed) 183 | torch.cuda.manual_seed(seed) 184 | torch.cuda.manual_seed_all(seed) 185 | if reproducibility: 186 | torch.backends.cudnn.benchmark = False 187 | torch.backends.cudnn.deterministic = True 188 | else: 189 | torch.backends.cudnn.benchmark = True 190 | torch.backends.cudnn.deterministic = False 191 | 192 | 193 | def get_tensorboard(logger): 194 | r""" Creates a SummaryWriter of Tensorboard that can log PyTorch models and metrics into a directory for 195 | visualization within the TensorBoard UI. 196 | For the convenience of the user, the naming rule of the SummaryWriter's log_dir is the same as the logger. 197 | 198 | Args: 199 | logger: its output filename is used to name the SummaryWriter's log_dir. 200 | If the filename is not available, we will name the log_dir according to the current time. 201 | 202 | Returns: 203 | SummaryWriter: it will write out events and summaries to the event file. 
204 | """ 205 | base_path = 'log_tensorboard' 206 | 207 | dir_name = None 208 | for handler in logger.handlers: 209 | if hasattr(handler, "baseFilename"): 210 | dir_name = os.path.basename(getattr(handler, 'baseFilename')).split('.')[0] 211 | break 212 | if dir_name is None: 213 | dir_name = '{}-{}'.format('model', get_local_time()) 214 | 215 | dir_path = os.path.join(base_path, dir_name) 216 | writer = SummaryWriter(dir_path) 217 | return writer 218 | 219 | 220 | def get_gpu_usage(device=None): 221 | r""" Return the reserved memory and total memory of given device in a string. 222 | Args: 223 | device: cuda.device. It is the device that the model run on. 224 | 225 | Returns: 226 | str: it contains the info about reserved memory and total memory of given device. 227 | """ 228 | 229 | reserved = torch.cuda.max_memory_reserved(device) / 1024 ** 3 230 | total = torch.cuda.get_device_properties(device).total_memory / 1024 ** 3 231 | 232 | return '{:.2f} G/{:.2f} G'.format(reserved, total) 233 | -------------------------------------------------------------------------------- /recbole/data/dataloader/knowledge_dataloader.py: -------------------------------------------------------------------------------- 1 | # @Time : 2020/7/7 2 | # @Author : Yupeng Hou 3 | # @Email : houyupeng@ruc.edu.cn 4 | 5 | # UPDATE 6 | # @Time : 2020/9/18, 2020/9/21, 2020/8/31 7 | # @Author : Yupeng Hou, Yushuo Chen, Kaiyuan Li 8 | # @email : houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn, tsotfsk@outlook.com 9 | 10 | """ 11 | recbole.data.dataloader.knowledge_dataloader 12 | ################################################ 13 | """ 14 | from recbole.data.dataloader.abstract_dataloader import AbstractDataLoader 15 | from recbole.data.dataloader.general_dataloader import TrainDataLoader 16 | from recbole.data.interaction import Interaction 17 | from recbole.utils import InputType, KGDataLoaderState 18 | 19 | 20 | class KGDataLoader(AbstractDataLoader): 21 | """:class:`KGDataLoader` is a dataloader which would return the triplets with negative examples 22 | in a knowledge graph. 23 | 24 | Args: 25 | config (Config): The config of dataloader. 26 | dataset (Dataset): The dataset of dataloader. 27 | sampler (KGSampler): The knowledge graph sampler of dataloader. 28 | shuffle (bool, optional): Whether the dataloader will be shuffle after a round. Defaults to ``False``. 29 | 30 | Attributes: 31 | shuffle (bool): Whether the dataloader will be shuffle after a round. 32 | However, in :class:`KGDataLoader`, it's guaranteed to be ``True``. 
33 | """ 34 | 35 | def __init__(self, config, dataset, sampler, shuffle=False): 36 | if shuffle is False: 37 | shuffle = True 38 | self.logger.warning('kg based dataloader must shuffle the data') 39 | 40 | self.neg_sample_num = 1 41 | 42 | self.neg_prefix = config['NEG_PREFIX'] 43 | self.hid_field = dataset.head_entity_field 44 | self.tid_field = dataset.tail_entity_field 45 | 46 | # kg negative cols 47 | self.neg_tid_field = self.neg_prefix + self.tid_field 48 | dataset.copy_field_property(self.neg_tid_field, self.tid_field) 49 | 50 | super().__init__(config, dataset, sampler, shuffle=shuffle) 51 | 52 | def _init_batch_size_and_step(self): 53 | batch_size = self.config['train_batch_size'] 54 | self.step = batch_size 55 | self.set_batch_size(batch_size) 56 | 57 | @property 58 | def pr_end(self): 59 | return len(self.dataset.kg_feat) 60 | 61 | def _shuffle(self): 62 | self.dataset.kg_feat.shuffle() 63 | 64 | def _next_batch_data(self): 65 | cur_data = self.dataset.kg_feat[self.pr:self.pr + self.step] 66 | head_ids = cur_data[self.hid_field].numpy() 67 | neg_tail_ids = self.sampler.sample_by_entity_ids(head_ids, self.neg_sample_num) 68 | cur_data.update(Interaction({self.neg_tid_field: neg_tail_ids})) 69 | self.pr += self.step 70 | return cur_data 71 | 72 | 73 | class KnowledgeBasedDataLoader(AbstractDataLoader): 74 | """:class:`KnowledgeBasedDataLoader` is used for knowledge based model. 75 | 76 | It has three states, which is saved in :attr:`state`. 77 | In different states, :meth:`~_next_batch_data` will return different :class:`~recbole.data.interaction.Interaction`. 78 | Detailed, please see :attr:`~state`. 79 | 80 | Args: 81 | config (Config): The config of dataloader. 82 | dataset (Dataset): The dataset of dataloader. 83 | sampler (Sampler): The sampler of dataloader. 84 | kg_sampler (KGSampler): The knowledge graph sampler of dataloader. 85 | shuffle (bool, optional): Whether the dataloader will be shuffle after a round. Defaults to ``False``. 86 | 87 | Attributes: 88 | state (KGDataLoaderState): 89 | This dataloader has three states: 90 | 91 | - :obj:`~recbole.utils.enum_type.KGDataLoaderState.RS` 92 | - :obj:`~recbole.utils.enum_type.KGDataLoaderState.KG` 93 | - :obj:`~recbole.utils.enum_type.KGDataLoaderState.RSKG` 94 | 95 | In the first state, this dataloader would only return the triplets with negative 96 | examples in a knowledge graph. 97 | 98 | In the second state, this dataloader would only return the user-item interaction. 99 | 100 | In the last state, this dataloader would return both knowledge graph information 101 | and user-item interaction information. 
102 | """ 103 | 104 | def __init__(self, config, dataset, sampler, kg_sampler, shuffle=False): 105 | 106 | # using sampler 107 | self.general_dataloader = TrainDataLoader(config, dataset, sampler, shuffle=shuffle) 108 | 109 | # using kg_sampler 110 | self.kg_dataloader = KGDataLoader(config, dataset, kg_sampler, shuffle=True) 111 | 112 | self.state = None 113 | 114 | super().__init__(config, dataset, sampler, shuffle=shuffle) 115 | 116 | def _init_batch_size_and_step(self): 117 | pass 118 | 119 | def update_config(self, config): 120 | self.general_dataloader.update_config(config) 121 | self.kg_dataloader.update_config(config) 122 | 123 | def __iter__(self): 124 | if self.state is None: 125 | raise ValueError( 126 | 'The dataloader\'s state must be set when using the kg based dataloader, ' 127 | 'you should call set_mode() before __iter__()' 128 | ) 129 | if self.state == KGDataLoaderState.KG: 130 | return self.kg_dataloader.__iter__() 131 | elif self.state == KGDataLoaderState.RS: 132 | return self.general_dataloader.__iter__() 133 | elif self.state == KGDataLoaderState.RSKG: 134 | self.kg_dataloader.__iter__() 135 | self.general_dataloader.__iter__() 136 | return self 137 | 138 | def _shuffle(self): 139 | pass 140 | 141 | def __next__(self): 142 | if self.general_dataloader.pr >= self.general_dataloader.pr_end: 143 | self.general_dataloader.pr = 0 144 | self.kg_dataloader.pr = 0 145 | raise StopIteration() 146 | return self._next_batch_data() 147 | 148 | def __len__(self): 149 | if self.state == KGDataLoaderState.KG: 150 | return len(self.kg_dataloader) 151 | else: 152 | return len(self.general_dataloader) 153 | 154 | @property 155 | def pr_end(self): 156 | if self.state == KGDataLoaderState.KG: 157 | return self.kg_dataloader.pr_end 158 | else: 159 | return self.general_dataloader.pr_end 160 | 161 | def _next_batch_data(self): 162 | try: 163 | kg_data = self.kg_dataloader.__next__() 164 | except StopIteration: 165 | kg_data = self.kg_dataloader.__next__() 166 | rec_data = self.general_dataloader.__next__() 167 | rec_data.update(kg_data) 168 | return rec_data 169 | 170 | def set_mode(self, state): 171 | """Set the mode of :class:`KnowledgeBasedDataLoader`, it can be set to three states: 172 | 173 | - KGDataLoaderState.RS 174 | - KGDataLoaderState.KG 175 | - KGDataLoaderState.RSKG 176 | 177 | The state of :class:`KnowledgeBasedDataLoader` would affect the result of _next_batch_data(). 178 | 179 | Args: 180 | state (KGDataLoaderState): the state of :class:`KnowledgeBasedDataLoader`. 181 | """ 182 | if state not in set(KGDataLoaderState): 183 | raise NotImplementedError(f'Kg data loader has no state named [{self.state}].') 184 | self.state = state 185 | 186 | def get_model(self, model): 187 | """Let the general_dataloader get the model, used for dynamic sampling. 
188 | """ 189 | self.general_dataloader.get_model(model) -------------------------------------------------------------------------------- /recbole/model/general_recommender/directau.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | import scipy.sparse as sp 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from recbole.model.abstract_recommender import GeneralRecommender 11 | from recbole.model.init import xavier_normal_initialization 12 | from recbole.utils import InputType 13 | 14 | 15 | class DirectAU(GeneralRecommender): 16 | input_type = InputType.POINTWISE 17 | 18 | def __init__(self, config, dataset): 19 | super(DirectAU, self).__init__(config, dataset) 20 | 21 | # load parameters info 22 | self.embedding_size = config['embedding_size'] 23 | self.gamma = config['gamma'] 24 | self.encoder_name = config['encoder'] 25 | self.use_pretrain = config['use_pretrain'] 26 | 27 | # define layers and loss 28 | if self.encoder_name == 'MF': 29 | self.encoder = MFEncoder(self.n_users, self.n_items, self.embedding_size) 30 | elif self.encoder_name == 'LightGCN': 31 | self.n_layers = config['n_layers'] 32 | self.interaction_matrix = dataset.inter_matrix(form='coo').astype(np.float32) 33 | self.norm_adj = self.get_norm_adj_mat().to(self.device) 34 | self.encoder = LGCNEncoder(self.n_users, self.n_items, self.embedding_size, self.norm_adj, self.n_layers) 35 | else: 36 | raise ValueError('Non-implemented Encoder.') 37 | 38 | # storage variables for full sort evaluation acceleration 39 | self.restore_user_e = None 40 | self.restore_item_e = None 41 | 42 | # parameters initialization 43 | self.apply(xavier_normal_initialization) 44 | 45 | def get_norm_adj_mat(self): 46 | # build adj matrix 47 | A = sp.dok_matrix((self.n_users + self.n_items, self.n_users + self.n_items), dtype=np.float32) 48 | inter_M = self.interaction_matrix 49 | inter_M_t = self.interaction_matrix.transpose() 50 | data_dict = dict(zip(zip(inter_M.row, inter_M.col + self.n_users), [1] * inter_M.nnz)) 51 | data_dict.update(dict(zip(zip(inter_M_t.row + self.n_users, inter_M_t.col), [1] * inter_M_t.nnz))) 52 | A._update(data_dict) 53 | # norm adj matrix 54 | sumArr = (A > 0).sum(axis=1) 55 | # add epsilon to avoid divide by zero Warning 56 | diag = np.array(sumArr.flatten())[0] + 1e-7 57 | diag = np.power(diag, -0.5) 58 | D = sp.diags(diag) 59 | L = D * A * D 60 | # covert norm_adj matrix to tensor 61 | L = sp.coo_matrix(L) 62 | row = L.row 63 | col = L.col 64 | i = torch.LongTensor([row, col]) 65 | data = torch.FloatTensor(L.data) 66 | SparseL = torch.sparse.FloatTensor(i, data, torch.Size(L.shape)) 67 | return SparseL 68 | 69 | def forward(self, user, item): 70 | user_e, item_e = self.encoder(user, item) 71 | return F.normalize(user_e, dim=-1), F.normalize(item_e, dim=-1) 72 | 73 | @staticmethod 74 | def alignment(x, y, alpha=2): 75 | return (x - y).norm(p=2, dim=1).pow(alpha).mean() 76 | 77 | @staticmethod 78 | def uniformity(x, t=2): 79 | return torch.pdist(x, p=2).pow(2).mul(-t).exp().mean().log() 80 | 81 | def calculate_loss(self, interaction): 82 | if self.restore_user_e is not None or self.restore_item_e is not None: 83 | self.restore_user_e, self.restore_item_e = None, None 84 | 85 | user = interaction[self.USER_ID] 86 | item = interaction[self.ITEM_ID] 87 | 88 | user_e, item_e = self.forward(user, item) 89 | align = self.alignment(user_e, item_e) 90 | uniform = self.gamma * 
(self.uniformity(user_e) + self.uniformity(item_e)) / 2 91 | 92 | return align + uniform 93 | 94 | def predict(self, interaction): 95 | user = interaction[self.USER_ID] 96 | item = interaction[self.ITEM_ID] 97 | user_e = self.encoder.user_embedding(user) 98 | item_e = self.encoder.item_embedding(item) 99 | return torch.mul(user_e, item_e).sum(dim=1) 100 | 101 | def full_sort_predict(self, interaction): 102 | user = interaction[self.USER_ID] 103 | if self.encoder_name == 'LightGCN': 104 | if self.restore_user_e is None or self.restore_item_e is None: 105 | self.restore_user_e, self.restore_item_e = self.encoder.get_all_embeddings() 106 | user_e = self.restore_user_e[user] 107 | all_item_e = self.restore_item_e 108 | else: 109 | user_e = self.encoder.user_embedding(user) 110 | all_item_e = self.encoder.item_embedding.weight 111 | score = torch.matmul(user_e, all_item_e.transpose(0, 1)) 112 | return score.view(-1) 113 | 114 | # def save_params(self): 115 | # user_embeddings, item_embeddings = self.encoder.get_all_embeddings() 116 | # np.save('user-DirectAU.npy', user_embeddings.data.cpu().numpy()) 117 | # np.save('item-DirectAU.npy', item_embeddings.data.cpu().numpy()) 118 | 119 | # def check(self, interaction): 120 | # user = interaction[self.USER_ID] 121 | # item = interaction[self.ITEM_ID] 122 | # user_e, item_e = self.forward(user, item) 123 | # 124 | # user_e = user_e.detach() 125 | # item_e = item_e.detach() 126 | # 127 | # alignment_loss = self.alignment(user_e, item_e) 128 | # uniform_loss = (self.uniformity(user_e) + self.uniformity(item_e)) / 2 129 | # 130 | # return alignment_loss, uniform_loss 131 | 132 | 133 | class MFEncoder(nn.Module): 134 | def __init__(self, user_num, item_num, emb_size): 135 | super(MFEncoder, self).__init__() 136 | self.user_embedding = nn.Embedding(user_num, emb_size) 137 | self.item_embedding = nn.Embedding(item_num, emb_size) 138 | 139 | def forward(self, user_id, item_id): 140 | u_embed = self.user_embedding(user_id) 141 | i_embed = self.item_embedding(item_id) 142 | return u_embed, i_embed 143 | 144 | def get_all_embeddings(self): 145 | user_embeddings = self.user_embedding.weight 146 | item_embeddings = self.item_embedding.weight 147 | return user_embeddings, item_embeddings 148 | 149 | 150 | class LGCNEncoder(nn.Module): 151 | def __init__(self, user_num, item_num, emb_size, norm_adj, n_layers=3): 152 | super(LGCNEncoder, self).__init__() 153 | self.n_users = user_num 154 | self.n_items = item_num 155 | self.n_layers = n_layers 156 | self.norm_adj = norm_adj 157 | 158 | self.user_embedding = torch.nn.Embedding(user_num, emb_size) 159 | self.item_embedding = torch.nn.Embedding(item_num, emb_size) 160 | 161 | def get_ego_embeddings(self): 162 | user_embeddings = self.user_embedding.weight 163 | item_embeddings = self.item_embedding.weight 164 | ego_embeddings = torch.cat([user_embeddings, item_embeddings], dim=0) 165 | return ego_embeddings 166 | 167 | def get_all_embeddings(self): 168 | all_embeddings = self.get_ego_embeddings() 169 | embeddings_list = [all_embeddings] 170 | 171 | for layer_idx in range(self.n_layers): 172 | all_embeddings = torch.sparse.mm(self.norm_adj, all_embeddings) 173 | embeddings_list.append(all_embeddings) 174 | lightgcn_all_embeddings = torch.stack(embeddings_list, dim=1) 175 | lightgcn_all_embeddings = torch.mean(lightgcn_all_embeddings, dim=1) 176 | 177 | user_all_embeddings, item_all_embeddings = torch.split(lightgcn_all_embeddings, [self.n_users, self.n_items]) 178 | return user_all_embeddings, item_all_embeddings 179 | 
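# Editorial note: get_all_embeddings() above performs LightGCN's layer combination, averaging the ego embeddings with the outputs of every propagation layer (torch.stack(...) followed by mean over dim=1) before splitting them into user/item tables; forward() below simply indexes into those tables.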
180 | def forward(self, user_id, item_id): 181 | user_all_embeddings, item_all_embeddings = self.get_all_embeddings() 182 | u_embed = user_all_embeddings[user_id] 183 | i_embed = item_all_embeddings[item_id] 184 | return u_embed, i_embed 185 | -------------------------------------------------------------------------------- /recbole/data/dataset/sequential_dataset.py: -------------------------------------------------------------------------------- 1 | # @Time : 2020/9/16 2 | # @Author : Yushuo Chen 3 | # @Email : chenyushuo@ruc.edu.cn 4 | 5 | # UPDATE: 6 | # @Time : 2020/9/16, 2021/7/1, 2021/7/11 7 | # @Author : Yushuo Chen, Xingyu Pan, Yupeng Hou 8 | # @Email : chenyushuo@ruc.edu.cn, xy_pan@foxmail.com, houyupeng@ruc.edu.cn 9 | 10 | """ 11 | recbole.data.sequential_dataset 12 | ############################### 13 | """ 14 | 15 | import numpy as np 16 | import torch 17 | 18 | from recbole.data.dataset import Dataset 19 | from recbole.data.interaction import Interaction 20 | from recbole.utils.enum_type import FeatureType, FeatureSource 21 | 22 | 23 | class SequentialDataset(Dataset): 24 | """:class:`SequentialDataset` is based on :class:`~recbole.data.dataset.dataset.Dataset`, 25 | and provides augmentation interface to adapt to Sequential Recommendation, 26 | which can accelerate the data loader. 27 | 28 | Attributes: 29 | max_item_list_len (int): Max length of historical item list. 30 | item_list_length_field (str): Field name for item lists' length. 31 | """ 32 | 33 | def __init__(self, config): 34 | self.max_item_list_len = config['MAX_ITEM_LIST_LENGTH'] 35 | self.item_list_length_field = config['ITEM_LIST_LENGTH_FIELD'] 36 | super().__init__(config) 37 | if config['benchmark_filename'] is not None: 38 | self._benchmark_presets() 39 | 40 | def _change_feat_format(self): 41 | """Change feat format from :class:`pandas.DataFrame` to :class:`Interaction`, 42 | then perform data augmentation. 43 | """ 44 | super()._change_feat_format() 45 | 46 | if self.config['benchmark_filename'] is not None: 47 | return 48 | self.logger.debug('Augmentation for sequential recommendation.') 49 | self.data_augmentation() 50 | 51 | def _aug_presets(self): 52 | list_suffix = self.config['LIST_SUFFIX'] 53 | for field in self.inter_feat: 54 | if field != self.uid_field: 55 | list_field = field + list_suffix 56 | setattr(self, f'{field}_list_field', list_field) 57 | ftype = self.field2type[field] 58 | 59 | if ftype in [FeatureType.TOKEN, FeatureType.TOKEN_SEQ]: 60 | list_ftype = FeatureType.TOKEN_SEQ 61 | else: 62 | list_ftype = FeatureType.FLOAT_SEQ 63 | 64 | if ftype in [FeatureType.TOKEN_SEQ, FeatureType.FLOAT_SEQ]: 65 | list_len = (self.max_item_list_len, self.field2seqlen[field]) 66 | else: 67 | list_len = self.max_item_list_len 68 | 69 | self.set_field_property(list_field, list_ftype, FeatureSource.INTERACTION, list_len) 70 | 71 | self.set_field_property(self.item_list_length_field, FeatureType.TOKEN, FeatureSource.INTERACTION, 1) 72 | 73 | def data_augmentation(self): 74 | """Augmentation processing for sequential dataset. 75 | 76 | E.g., ``u1`` has purchase sequence ``<i1, i2, i3, i4>``, 77 | then after augmentation, we will generate three cases. 78 | 79 | ``u1, <i1> | i2`` 80 | 81 | (Which means given user_id ``u1`` and item_seq ``<i1>``, 82 | we need to predict the next item ``i2``.)
83 | 84 | The other cases are below: 85 | 86 | ``u1, <i1, i2> | i3`` 87 | 88 | ``u1, <i1, i2, i3> | i4`` 89 | """ 90 | self.logger.debug('data_augmentation') 91 | 92 | self._aug_presets() 93 | 94 | self._check_field('uid_field', 'time_field') 95 | max_item_list_len = self.config['MAX_ITEM_LIST_LENGTH'] 96 | self.sort(by=[self.uid_field, self.time_field], ascending=True) 97 | last_uid = None 98 | uid_list, item_list_index, target_index, item_list_length = [], [], [], [] 99 | seq_start = 0 100 | for i, uid in enumerate(self.inter_feat[self.uid_field].numpy()): 101 | if last_uid != uid: 102 | last_uid = uid 103 | seq_start = i 104 | else: 105 | if i - seq_start > max_item_list_len: 106 | seq_start += 1 107 | uid_list.append(uid) 108 | item_list_index.append(slice(seq_start, i)) 109 | target_index.append(i) 110 | item_list_length.append(i - seq_start) 111 | 112 | uid_list = np.array(uid_list) 113 | item_list_index = np.array(item_list_index) 114 | target_index = np.array(target_index) 115 | item_list_length = np.array(item_list_length, dtype=np.int64) 116 | 117 | new_length = len(item_list_index) 118 | new_data = self.inter_feat[target_index] 119 | new_dict = { 120 | self.item_list_length_field: torch.tensor(item_list_length), 121 | } 122 | 123 | for field in self.inter_feat: 124 | if field != self.uid_field: 125 | list_field = getattr(self, f'{field}_list_field') 126 | list_len = self.field2seqlen[list_field] 127 | shape = (new_length, list_len) if isinstance(list_len, int) else (new_length,) + list_len 128 | new_dict[list_field] = torch.zeros(shape, dtype=self.inter_feat[field].dtype) 129 | 130 | value = self.inter_feat[field] 131 | for i, (index, length) in enumerate(zip(item_list_index, item_list_length)): 132 | new_dict[list_field][i][:length] = value[index] 133 | 134 | new_data.update(Interaction(new_dict)) 135 | self.inter_feat = new_data 136 | 137 | def _benchmark_presets(self): 138 | list_suffix = self.config['LIST_SUFFIX'] 139 | for field in self.inter_feat: 140 | if field + list_suffix in self.inter_feat: 141 | list_field = field + list_suffix 142 | setattr(self, f'{field}_list_field', list_field) 143 | self.set_field_property(self.item_list_length_field, FeatureType.TOKEN, FeatureSource.INTERACTION, 1) 144 | self.inter_feat[self.item_list_length_field] = self.inter_feat[self.item_id_list_field].agg(len) 145 | 146 | def inter_matrix(self, form='coo', value_field=None): 147 | """Get a sparse matrix that describes interactions between user_id and item_id. 148 | Sparse matrix has shape (user_num, item_num). 149 | For a row of <src, tgt>, ``matrix[src, tgt] = 1`` if ``value_field`` is ``None``, 150 | else ``matrix[src, tgt] = self.inter_feat[src, tgt]``. 151 | 152 | Args: 153 | form (str, optional): Sparse matrix format. Defaults to ``coo``. 154 | value_field (str, optional): Data of sparse matrix, which should exist in ``df_feat``. 155 | Defaults to ``None``. 156 | 157 | Returns: 158 | scipy.sparse: Sparse matrix in form ``coo`` or ``csr``.
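Example (a sketch mirroring the commented-out usage in quick_start.py):: rating_matrix = dataset.inter_matrix(form='csr')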
159 | """ 160 | if not self.uid_field or not self.iid_field: 161 | raise ValueError('dataset does not exist uid/iid, thus can not converted to sparse matrix.') 162 | 163 | l1_idx = (self.inter_feat[self.item_list_length_field] == 1) 164 | l1_inter_dict = self.inter_feat[l1_idx].interaction 165 | new_dict = {} 166 | list_suffix = self.config['LIST_SUFFIX'] 167 | candidate_field_set = set() 168 | for field in l1_inter_dict: 169 | if field != self.uid_field and field + list_suffix in l1_inter_dict: 170 | candidate_field_set.add(field) 171 | new_dict[field] = torch.cat([self.inter_feat[field], l1_inter_dict[field + list_suffix][:, 0]]) 172 | elif (not field.endswith(list_suffix)) and (field != self.item_list_length_field): 173 | new_dict[field] = torch.cat([self.inter_feat[field], l1_inter_dict[field]]) 174 | local_inter_feat = Interaction(new_dict) 175 | return self._create_sparse_matrix(local_inter_feat, self.uid_field, self.iid_field, form, value_field) 176 | 177 | def build(self): 178 | """Processing dataset according to evaluation setting, including Group, Order and Split. 179 | See :class:`~recbole.config.eval_setting.EvalSetting` for details. 180 | 181 | Args: 182 | eval_setting (:class:`~recbole.config.eval_setting.EvalSetting`): 183 | Object contains evaluation settings, which guide the data processing procedure. 184 | 185 | Returns: 186 | list: List of built :class:`Dataset`. 187 | """ 188 | ordering_args = self.config['eval_args']['order'] 189 | if ordering_args != 'TO': 190 | raise ValueError(f'The ordering args for sequential recommendation has to be \'TO\'') 191 | 192 | return super().build() 193 | -------------------------------------------------------------------------------- /recbole/model/general_recommender/mawu.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | import scipy.sparse as sp 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from recbole.model.abstract_recommender import GeneralRecommender 11 | from recbole.model.init import xavier_normal_initialization 12 | from recbole.utils import InputType 13 | 14 | 15 | class MAWU(GeneralRecommender): 16 | input_type = InputType.POINTWISE 17 | 18 | def __init__(self, config, dataset): 19 | super(MAWU, self).__init__(config, dataset) 20 | 21 | # load parameters info 22 | self.embedding_size = config['embedding_size'] 23 | self.gamma1 = config['gamma1'] 24 | self.gamma2 = config['gamma2'] 25 | 26 | self.encoder_name = config['encoder'] 27 | if config['margin'] == "None": 28 | self.margin = None 29 | else: 30 | self.margin = config['margin'] 31 | 32 | # define layers and loss 33 | if self.encoder_name == 'MF': 34 | self.encoder = MFEncoder(self.n_users, self.n_items, self.embedding_size) 35 | elif self.encoder_name == 'LightGCN': 36 | self.n_layers = config['n_layers'] 37 | self.interaction_matrix = dataset.inter_matrix(form='coo').astype(np.float32) 38 | self.norm_adj = self.get_norm_adj_mat().to(self.device) 39 | self.encoder = LGCNEncoder(self.n_users, self.n_items, self.embedding_size, self.norm_adj, self.n_layers) 40 | else: 41 | raise ValueError('Non-implemented Encoder.') 42 | 43 | # user, item margin 44 | self.user_margin = nn.Embedding(self.n_users, 1) 45 | self.item_margin = nn.Embedding(self.n_items, 1) 46 | 47 | # storage variables for full sort evaluation acceleration 48 | self.restore_user_e = None 49 | self.restore_item_e = None 50 | 51 | # parameters initialization 52 | 
self.apply(xavier_normal_initialization) 53 | 54 | def get_norm_adj_mat(self): 55 | # build adj matrix 56 | A = sp.dok_matrix((self.n_users + self.n_items, self.n_users + self.n_items), dtype=np.float32) 57 | inter_M = self.interaction_matrix 58 | inter_M_t = self.interaction_matrix.transpose() 59 | data_dict = dict(zip(zip(inter_M.row, inter_M.col + self.n_users), [1] * inter_M.nnz)) 60 | data_dict.update(dict(zip(zip(inter_M_t.row + self.n_users, inter_M_t.col), [1] * inter_M_t.nnz))) 61 | A._update(data_dict) 62 | # norm adj matrix 63 | sumArr = (A > 0).sum(axis=1) 64 | # add epsilon to avoid divide by zero Warning 65 | diag = np.array(sumArr.flatten())[0] + 1e-7 66 | diag = np.power(diag, -0.5) 67 | D = sp.diags(diag) 68 | L = D * A * D 69 | # convert norm_adj matrix to tensor 70 | L = sp.coo_matrix(L) 71 | row = L.row 72 | col = L.col 73 | i = torch.LongTensor([row, col]) 74 | data = torch.FloatTensor(L.data) 75 | SparseL = torch.sparse.FloatTensor(i, data, torch.Size(L.shape)) 76 | return SparseL 77 | 78 | def forward(self, user, item): 79 | user_e, item_e = self.encoder(user, item) 80 | return F.normalize(user_e, dim=-1), F.normalize(item_e, dim=-1) 81 | 82 | @staticmethod 83 | def alignment(x, y, alpha=2): 84 | return (x - y).norm(p=2, dim=1).pow(alpha).mean() 85 | 86 | @staticmethod 87 | def alignment_dot(x, y): 88 | return -torch.sum(x * y, dim=-1).mean() 89 | 90 | @staticmethod 91 | def alignment_margin(x, y, margin): 92 | cos_sim = torch.sum(x * y, dim=-1) # dot product of normalized vectors = cosine similarity 93 | angle_ui = torch.arccos(torch.clamp(cos_sim,-1+1e-7,1-1e-7)) # clipping 94 | angle_ui_plus_margin = angle_ui + (1 - torch.sigmoid(margin)) 95 | angle_ui_plus_margin = torch.clamp(angle_ui_plus_margin, 0., np.pi) 96 | 97 | cos_sim_margin = torch.cos(angle_ui_plus_margin) 98 | 99 | return -cos_sim_margin.mean() 100 | 101 | @staticmethod 102 | def uniformity(x, t=2): 103 | return torch.pdist(x, p=2).pow(2).mul(-t).exp().mean().log() 104 | 105 | @staticmethod 106 | def uniformity_dot(x, t=2): 107 | cos_sim = F.cosine_similarity(x[:,:,None], x.t()[None,:,:]) 108 | # take lower triangular matrix 109 | cos_sim = torch.tril(cos_sim, diagonal=-1) 110 | # convert cos_sim to distance 111 | cos_sim = 2 - 2 * cos_sim 112 | 113 | return cos_sim.mul(-t).exp().mean().log() 114 | 115 | def calculate_loss(self, interaction): 116 | if self.restore_user_e is not None or self.restore_item_e is not None: 117 | self.restore_user_e, self.restore_item_e = None, None 118 | 119 | user = interaction[self.USER_ID] 120 | item = interaction[self.ITEM_ID] 121 | user_e, item_e = self.forward(user, item) 122 | 123 | # adaptive margin 124 | user_margin = self.user_margin(user) 125 | item_margin = self.item_margin(item) 126 | margin = user_margin + item_margin 127 | 128 | # margin-aware alignment and weighted uniformity losses 129 | align_margin = self.alignment_margin(user_e, item_e, margin) 130 | uniform = self.gamma1 * self.uniformity_dot(user_e) + self.gamma2 * self.uniformity_dot(item_e) 131 | 132 | loss = align_margin + uniform 133 | 134 | # clip user/item margin 135 | self.user_margin.weight.data = torch.clamp(self.user_margin.weight.data, min=0.0, max=1.0) 136 | self.item_margin.weight.data = torch.clamp(self.item_margin.weight.data, min=0.0, max=1.0) 137 | 138 | return loss 139 | 140 | def predict(self, interaction): 141 | user = interaction[self.USER_ID] 142 | item = interaction[self.ITEM_ID] 143 | user_e = self.encoder.user_embedding(user) 144 | item_e = self.encoder.item_embedding(item) 145 | return torch.mul(user_e,
item_e).sum(dim=1) 146 | 147 | def full_sort_predict(self, interaction): 148 | user = interaction[self.USER_ID] 149 | if self.encoder_name == 'LightGCN': 150 | if self.restore_user_e is None or self.restore_item_e is None: 151 | self.restore_user_e, self.restore_item_e = self.encoder.get_all_embeddings() 152 | user_e = self.restore_user_e[user] 153 | all_item_e = self.restore_item_e 154 | else: 155 | user_e = self.encoder.user_embedding(user) 156 | all_item_e = self.encoder.item_embedding.weight 157 | score = torch.matmul(user_e, all_item_e.transpose(0, 1)) 158 | return score.view(-1) 159 | 160 | # def save_params(self): 161 | # user_embeddings, item_embeddings = self.encoder.get_all_embeddings() 162 | # np.save('user-DirectAU.npy', user_embeddings.data.cpu().numpy()) 163 | # np.save('item-DirectAU.npy', item_embeddings.data.cpu().numpy()) 164 | 165 | # def check(self, interaction): 166 | # user = interaction[self.USER_ID] 167 | # item = interaction[self.ITEM_ID] 168 | # user_e, item_e = self.forward(user, item) 169 | # 170 | # user_e = user_e.detach() 171 | # item_e = item_e.detach() 172 | # 173 | # alignment_loss = self.alignment(user_e, item_e) 174 | # uniform_loss = (self.uniformity(user_e) + self.uniformity(item_e)) / 2 175 | # 176 | # return alignment_loss, uniform_loss 177 | 178 | 179 | class MFEncoder(nn.Module): 180 | def __init__(self, user_num, item_num, emb_size): 181 | super(MFEncoder, self).__init__() 182 | self.user_embedding = nn.Embedding(user_num, emb_size) 183 | self.item_embedding = nn.Embedding(item_num, emb_size) 184 | 185 | def forward(self, user_id, item_id): 186 | u_embed = self.user_embedding(user_id) 187 | i_embed = self.item_embedding(item_id) 188 | return u_embed, i_embed 189 | 190 | def get_all_embeddings(self): 191 | user_embeddings = self.user_embedding.weight 192 | item_embeddings = self.item_embedding.weight 193 | return user_embeddings, item_embeddings 194 | 195 | 196 | class LGCNEncoder(nn.Module): 197 | def __init__(self, user_num, item_num, emb_size, norm_adj, n_layers=3): 198 | super(LGCNEncoder, self).__init__() 199 | self.n_users = user_num 200 | self.n_items = item_num 201 | self.n_layers = n_layers 202 | self.norm_adj = norm_adj 203 | 204 | self.user_embedding = torch.nn.Embedding(user_num, emb_size) 205 | self.item_embedding = torch.nn.Embedding(item_num, emb_size) 206 | 207 | def get_ego_embeddings(self): 208 | user_embeddings = self.user_embedding.weight 209 | item_embeddings = self.item_embedding.weight 210 | ego_embeddings = torch.cat([user_embeddings, item_embeddings], dim=0) 211 | return ego_embeddings 212 | 213 | def get_all_embeddings(self): 214 | all_embeddings = self.get_ego_embeddings() 215 | embeddings_list = [all_embeddings] 216 | 217 | for layer_idx in range(self.n_layers): 218 | all_embeddings = torch.sparse.mm(self.norm_adj, all_embeddings) 219 | embeddings_list.append(all_embeddings) 220 | lightgcn_all_embeddings = torch.stack(embeddings_list, dim=1) 221 | lightgcn_all_embeddings = torch.mean(lightgcn_all_embeddings, dim=1) 222 | 223 | user_all_embeddings, item_all_embeddings = torch.split(lightgcn_all_embeddings, [self.n_users, self.n_items]) 224 | return user_all_embeddings, item_all_embeddings 225 | 226 | def forward(self, user_id, item_id): 227 | user_all_embeddings, item_all_embeddings = self.get_all_embeddings() 228 | u_embed = user_all_embeddings[user_id] 229 | i_embed = item_all_embeddings[item_id] 230 | return u_embed, i_embed 231 | -------------------------------------------------------------------------------- 
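Editor's note: the following is a minimal usage sketch, not a file from the repository. It shows how MAWU's margin-aware alignment and weighted uniformity terms combine on toy data; the batch size, embedding dimension, and random margins are assumptions for illustration, gamma1/gamma2 mirror the MAWU.yaml defaults, and the pdist-based uniformity variant is used for brevity (calculate_loss above uses uniformity_dot).

import math
import torch
import torch.nn.functional as F

# Toy batch: 4 user-item pairs with 8-dim embeddings, L2-normalized as in MAWU.forward.
user_e = F.normalize(torch.randn(4, 8), dim=-1)
item_e = F.normalize(torch.randn(4, 8), dim=-1)
margin = torch.rand(4)  # stands in for user_margin + item_margin, shape (batch,)

# Margin-aware alignment: widen each user-item angle by a learned margin before scoring.
cos_sim = torch.sum(user_e * item_e, dim=-1)
angle = torch.arccos(torch.clamp(cos_sim, -1 + 1e-7, 1 - 1e-7))
angle_m = torch.clamp(angle + (1 - torch.sigmoid(margin)), 0.0, math.pi)
align_margin = -torch.cos(angle_m).mean()

# Uniformity: log of the mean Gaussian potential over pairwise distances (MAWU.uniformity).
def uniformity(x, t=2):
    return torch.pdist(x, p=2).pow(2).mul(-t).exp().mean().log()

gamma1 = gamma2 = 0.5  # defaults from MAWU.yaml
loss = align_margin + gamma1 * uniformity(user_e) + gamma2 * uniformity(item_e)

Note that the margin enters as 1 - sigmoid(margin), so a larger learned margin adds a smaller extra angle, and the clamp keeps the shifted angle inside [0, pi] before taking its cosine.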
/recbole/data/dataloader/general_dataloader.py: -------------------------------------------------------------------------------- 1 | # @Time : 2020/7/7 2 | # @Author : Yupeng Hou 3 | # @Email : houyupeng@ruc.edu.cn 4 | 5 | # UPDATE 6 | # @Time : 2020/9/9, 2020/9/29, 2021/7/15 7 | # @Author : Yupeng Hou, Yushuo Chen, Xingyu Pan 8 | # @email : houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn, xy_pan@foxmail.com 9 | 10 | """ 11 | recbole.data.dataloader.general_dataloader 12 | ################################################ 13 | """ 14 | 15 | import numpy as np 16 | import torch 17 | 18 | from recbole.data.dataloader.abstract_dataloader import AbstractDataLoader, NegSampleDataLoader 19 | from recbole.data.interaction import Interaction, cat_interactions 20 | from recbole.utils import InputType, ModelType 21 | 22 | 23 | class TrainDataLoader(NegSampleDataLoader): 24 | """:class:`TrainDataLoader` is a dataloader for training. 25 | It can generate negative interactions when :attr:`training_neg_sample_num` is not zero. 26 | For every batch, we guarantee that each positive interaction and its sampled negative interactions 27 | appear in the same batch. 28 | 29 | Args: 30 | config (Config): The config of dataloader. 31 | dataset (Dataset): The dataset of dataloader. 32 | sampler (Sampler): The sampler of dataloader. 33 | shuffle (bool, optional): Whether the dataloader will be shuffled after a round. Defaults to ``False``. 34 | """ 35 | 36 | def __init__(self, config, dataset, sampler, shuffle=False): 37 | self._set_neg_sample_args(config, dataset, config['MODEL_INPUT_TYPE'], config['train_neg_sample_args']) 38 | super().__init__(config, dataset, sampler, shuffle=shuffle) 39 | 40 | def _init_batch_size_and_step(self): 41 | batch_size = self.config['train_batch_size'] 42 | if self.neg_sample_args['strategy'] == 'by': 43 | if self.dl_format == InputType.SETWISE: 44 | self.step = batch_size 45 | self.set_batch_size(batch_size) 46 | else: 47 | batch_num = max(batch_size // self.times, 1) 48 | new_batch_size = batch_num * self.times 49 | self.step = batch_num 50 | self.set_batch_size(new_batch_size) 51 | else: 52 | self.step = batch_size 53 | self.set_batch_size(batch_size) 54 | 55 | def update_config(self, config): 56 | self._set_neg_sample_args(config, self.dataset, config['MODEL_INPUT_TYPE'], config['train_neg_sample_args']) 57 | super().update_config(config) 58 | 59 | @property 60 | def pr_end(self): 61 | return len(self.dataset) 62 | 63 | def _shuffle(self): 64 | self.dataset.shuffle() 65 | 66 | def _next_batch_data(self): 67 | cur_data = self._neg_sampling(self.dataset[self.pr:self.pr + self.step]) 68 | self.pr += self.step 69 | return cur_data 70 | 71 | 72 | class NegSampleEvalDataLoader(NegSampleDataLoader): 73 | """:class:`NegSampleEvalDataLoader` is a dataloader for neg-sampling evaluation. 74 | It is similar to :class:`TrainDataLoader` which can generate negative items, 75 | and it also guarantees that all the interactions corresponding to each user appear in the same batch, 76 | with positive interactions placed before negative ones. 77 | 78 | Args: 79 | config (Config): The config of dataloader. 80 | dataset (Dataset): The dataset of dataloader. 81 | sampler (Sampler): The sampler of dataloader. 82 | shuffle (bool, optional): Whether the dataloader will be shuffled after a round. Defaults to ``False``.
83 | """ 84 | 85 | def __init__(self, config, dataset, sampler, shuffle=False): 86 | self._set_neg_sample_args(config, dataset, InputType.POINTWISE, config['eval_neg_sample_args']) 87 | if self.neg_sample_args['strategy'] == 'by': 88 | user_num = dataset.user_num 89 | dataset.sort(by=dataset.uid_field, ascending=True) 90 | self.uid_list = [] 91 | start, end = dict(), dict() 92 | for i, uid in enumerate(dataset.inter_feat[dataset.uid_field].numpy()): 93 | if uid not in start: 94 | self.uid_list.append(uid) 95 | start[uid] = i 96 | end[uid] = i 97 | self.uid2index = np.array([None] * user_num) 98 | self.uid2items_num = np.zeros(user_num, dtype=np.int64) 99 | for uid in self.uid_list: 100 | self.uid2index[uid] = slice(start[uid], end[uid] + 1) 101 | self.uid2items_num[uid] = end[uid] - start[uid] + 1 102 | self.uid_list = np.array(self.uid_list) 103 | 104 | super().__init__(config, dataset, sampler, shuffle=shuffle) 105 | 106 | def _init_batch_size_and_step(self): 107 | batch_size = self.config['eval_batch_size'] 108 | if self.neg_sample_args['strategy'] == 'by': 109 | inters_num = sorted(self.uid2items_num * self.times, reverse=True) 110 | batch_num = 1 111 | new_batch_size = inters_num[0] 112 | for i in range(1, len(inters_num)): 113 | if new_batch_size + inters_num[i] > batch_size: 114 | break 115 | batch_num = i + 1 116 | new_batch_size += inters_num[i] 117 | self.step = batch_num 118 | self.set_batch_size(new_batch_size) 119 | else: 120 | self.step = batch_size 121 | self.set_batch_size(batch_size) 122 | 123 | def update_config(self, config): 124 | self._set_neg_sample_args(config, self.dataset, InputType.POINTWISE, config['eval_neg_sample_args']) 125 | super().update_config(config) 126 | 127 | @property 128 | def pr_end(self): 129 | if self.neg_sample_args['strategy'] == 'by': 130 | return len(self.uid_list) 131 | else: 132 | return len(self.dataset) 133 | 134 | def _shuffle(self): 135 | self.logger.warnning('NegSampleEvalDataLoader can\'t shuffle') 136 | 137 | def _next_batch_data(self): 138 | if self.neg_sample_args['strategy'] == 'by': 139 | uid_list = self.uid_list[self.pr:self.pr + self.step] 140 | data_list = [] 141 | idx_list = [] 142 | positive_u = [] 143 | positive_i = torch.tensor([], dtype=torch.int64) 144 | 145 | for idx, uid in enumerate(uid_list): 146 | index = self.uid2index[uid] 147 | data_list.append(self._neg_sampling(self.dataset[index])) 148 | idx_list += [idx for i in range(self.uid2items_num[uid] * self.times)] 149 | positive_u += [idx for i in range(self.uid2items_num[uid])] 150 | positive_i = torch.cat((positive_i, self.dataset[index][self.iid_field]), 0) 151 | 152 | cur_data = cat_interactions(data_list) 153 | idx_list = torch.from_numpy(np.array(idx_list)) 154 | positive_u = torch.from_numpy(np.array(positive_u)) 155 | 156 | self.pr += self.step 157 | 158 | return cur_data, idx_list, positive_u, positive_i 159 | else: 160 | cur_data = self._neg_sampling(self.dataset[self.pr:self.pr + self.step]) 161 | self.pr += self.step 162 | return cur_data, None, None, None 163 | 164 | 165 | class FullSortEvalDataLoader(AbstractDataLoader): 166 | """:class:`FullSortEvalDataLoader` is a dataloader for full-sort evaluation. In order to speed up calculation, 167 | this dataloader would only return then user part of interactions, positive items and used items. 168 | It would not return negative items. 169 | 170 | Args: 171 | config (Config): The config of dataloader. 172 | dataset (Dataset): The dataset of dataloader. 173 | sampler (Sampler): The sampler of dataloader. 
174 | shuffle (bool, optional): Whether the dataloader will be shuffled after a round. Defaults to ``False``. 175 | """ 176 | 177 | def __init__(self, config, dataset, sampler, shuffle=False): 178 | self.uid_field = dataset.uid_field 179 | self.iid_field = dataset.iid_field 180 | self.is_sequential = config['MODEL_TYPE'] == ModelType.SEQUENTIAL 181 | if not self.is_sequential: 182 | user_num = dataset.user_num 183 | self.uid_list = [] 184 | self.uid2items_num = np.zeros(user_num, dtype=np.int64) 185 | self.uid2positive_item = np.array([None] * user_num) 186 | self.uid2history_item = np.array([None] * user_num) 187 | 188 | dataset.sort(by=self.uid_field, ascending=True) 189 | last_uid = None 190 | positive_item = set() 191 | uid2used_item = sampler.used_ids 192 | for uid, iid in zip(dataset.inter_feat[self.uid_field].numpy(), dataset.inter_feat[self.iid_field].numpy()): 193 | if uid != last_uid: 194 | self._set_user_property(last_uid, uid2used_item[last_uid], positive_item) 195 | last_uid = uid 196 | self.uid_list.append(uid) 197 | positive_item = set() 198 | positive_item.add(iid) 199 | self._set_user_property(last_uid, uid2used_item[last_uid], positive_item) 200 | self.uid_list = torch.tensor(self.uid_list, dtype=torch.int64) 201 | self.user_df = dataset.join(Interaction({self.uid_field: self.uid_list})) 202 | 203 | super().__init__(config, dataset, sampler, shuffle=shuffle) 204 | 205 | def _set_user_property(self, uid, used_item, positive_item): 206 | if uid is None: 207 | return 208 | history_item = used_item - positive_item 209 | self.uid2positive_item[uid] = torch.tensor(list(positive_item), dtype=torch.int64) 210 | self.uid2items_num[uid] = len(positive_item) 211 | self.uid2history_item[uid] = torch.tensor(list(history_item), dtype=torch.int64) 212 | 213 | def _init_batch_size_and_step(self): 214 | batch_size = self.config['eval_batch_size'] 215 | if not self.is_sequential: 216 | # Since the unit is an interaction, evaluating a single user yields item_num interactions. 217 | # batch_num is the number of users evaluated at once (determined by the configured batch_size). 218 | batch_num = max(batch_size // self.dataset.item_num, 1) 219 | new_batch_size = batch_num * self.dataset.item_num 220 | self.step = batch_num 221 | self.set_batch_size(new_batch_size) 222 | else: 223 | self.step = batch_size 224 | self.set_batch_size(batch_size) 225 | 226 | @property 227 | def pr_end(self): 228 | if not self.is_sequential: 229 | return len(self.uid_list) 230 | else: 231 | return len(self.dataset) 232 | 233 | def _shuffle(self): 234 | self.logger.warning('FullSortEvalDataLoader can\'t shuffle') 235 | 236 | def _next_batch_data(self): 237 | if not self.is_sequential: 238 | user_df = self.user_df[self.pr:self.pr + self.step] 239 | uid_list = list(user_df[self.uid_field]) 240 | 241 | history_item = self.uid2history_item[uid_list] 242 | positive_item = self.uid2positive_item[uid_list] 243 | 244 | history_u = torch.cat([torch.full_like(hist_iid, i) for i, hist_iid in enumerate(history_item)]) 245 | history_i = torch.cat(list(history_item)) 246 | 247 | # positive_u: user index 248 | # positive_i: item index 249 | positive_u = torch.cat([torch.full_like(pos_iid, i) for i, pos_iid in enumerate(positive_item)]) 250 | positive_i = torch.cat(list(positive_item)) 251 | 252 | self.pr += self.step 253 | return user_df, (history_u, history_i), positive_u, positive_i 254 | else: 255 | interaction = self.dataset[self.pr:self.pr + self.step] 256 | inter_num = len(interaction) 257 | positive_u = torch.arange(inter_num) 258 | positive_i = interaction[self.iid_field] 259 | 260 |
self.pr += self.step 261 | return interaction, None, positive_u, positive_i 262 | -------------------------------------------------------------------------------- /recbole/data/utils.py: -------------------------------------------------------------------------------- 1 | # @Time : 2020/7/21 2 | # @Author : Yupeng Hou 3 | # @Email : houyupeng@ruc.edu.cn 4 | 5 | # UPDATE: 6 | # @Time : 2021/7/9, 2020/9/17, 2020/8/31, 2021/2/20, 2021/3/1 7 | # @Author : Yupeng Hou, Yushuo Chen, Kaiyuan Li, Haoran Cheng, Jiawei Guan 8 | # @Email : houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn, tsotfsk@outlook.com, chenghaoran29@foxmail.com, guanjw@ruc.edu.cn 9 | 10 | """ 11 | recbole.data.utils 12 | ######################## 13 | """ 14 | 15 | import copy 16 | import importlib 17 | import os 18 | import pickle 19 | from logging import getLogger 20 | from recbole.data.dataloader import * 21 | from recbole.sampler import KGSampler, Sampler, RepeatableSampler 22 | from recbole.utils import ModelType, ensure_dir, get_local_time, set_color 23 | from recbole.utils.argument_list import dataset_arguments 24 | 25 | 26 | def create_dataset(config): 27 | """Create dataset according to :attr:`config['model']` and :attr:`config['MODEL_TYPE']`. 28 | If the file :attr:`config['dataset_save_path']` exists and 29 | its saved dataset :attr:`config` is equal to the current dataset :attr:`config`, 30 | it will return the saved dataset in :attr:`config['dataset_save_path']`. 31 | 32 | Args: 33 | config (Config): An instance object of Config, used to record parameter information. 34 | 35 | Returns: 36 | Dataset: Constructed dataset. 37 | """ 38 | dataset_module = importlib.import_module('recbole.data.dataset') 39 | if hasattr(dataset_module, config['model'] + 'Dataset'): 40 | dataset_class = getattr(dataset_module, config['model'] + 'Dataset') 41 | else: 42 | model_type = config['MODEL_TYPE'] 43 | type2class = { 44 | ModelType.GENERAL: 'Dataset', 45 | ModelType.SEQUENTIAL: 'SequentialDataset', 46 | ModelType.CONTEXT: 'Dataset', 47 | ModelType.KNOWLEDGE: 'KnowledgeBasedDataset', 48 | ModelType.TRADITIONAL: 'Dataset', 49 | ModelType.DECISIONTREE: 'Dataset', 50 | } 51 | dataset_class = getattr(dataset_module, type2class[model_type]) 52 | 53 | 54 | default_file = os.path.join(config['save_dataset_dir'], f'{config["dataset"]}-{dataset_class.__name__}.pth') 55 | file = config['dataset_save_path'] or default_file 56 | if os.path.exists(file): 57 | with open(file, 'rb') as f: 58 | dataset = pickle.load(f) 59 | dataset_args_unchanged = True 60 | for arg in dataset_arguments + ['seed', 'repeatable']: 61 | if config[arg] != dataset.config[arg]: 62 | dataset_args_unchanged = False 63 | break 64 | if dataset_args_unchanged: 65 | logger = getLogger() 66 | logger.info(set_color('Load filtered dataset from', 'pink') + f': [{file}]') 67 | return dataset 68 | 69 | dataset = dataset_class(config) 70 | if config['save_dataset']: 71 | dataset.save() 72 | return dataset 73 | 74 | 75 | def save_split_dataloaders(config, dataloaders): 76 | """Save split dataloaders. 77 | 78 | Args: 79 | config (Config): An instance object of Config, used to record parameter information. 80 | dataloaders (tuple of AbstractDataLoader): The split dataloaders.
81 | """ 82 | ensure_dir(config['save_dataset_dir']) 83 | save_path = config['save_dataset_dir'] 84 | saved_dataloaders_file = f'{config["dataset"]}-for-{config["model"]}-dataloader.pth' 85 | file_path = os.path.join(save_path, saved_dataloaders_file) 86 | logger = getLogger() 87 | logger.info(set_color('Saving split dataloaders into', 'pink') + f': [{file_path}]') 88 | with open(file_path, 'wb') as f: 89 | pickle.dump(dataloaders, f) 90 | 91 | 92 | def load_split_dataloaders(config): 93 | """Load split dataloaders if saved dataloaders exist and 94 | their :attr:`config` of dataset are the same as current :attr:`config` of dataset. 95 | 96 | Args: 97 | config (Config): An instance object of Config, used to record parameter information. 98 | 99 | Returns: 100 | dataloaders (tuple of AbstractDataLoader or None): The split dataloaders. 101 | """ 102 | 103 | default_file = os.path.join(config['save_dataset_dir'], f'{config["dataset"]}-for-{config["model"]}-dataloader.pth') 104 | dataloaders_save_path = config['dataloaders_save_path'] or default_file 105 | if not os.path.exists(dataloaders_save_path): 106 | return None 107 | with open(dataloaders_save_path, 'rb') as f: 108 | train_data, valid_data, test_data = pickle.load(f) 109 | for arg in dataset_arguments + ['seed', 'repeatable', 'eval_args']: 110 | if config[arg] != train_data.config[arg]: 111 | return None 112 | train_data.update_config(config) 113 | valid_data.update_config(config) 114 | test_data.update_config(config) 115 | logger = getLogger() 116 | logger.info(set_color('Load split dataloaders from', 'pink') + f': [{dataloaders_save_path}]') 117 | return train_data, valid_data, test_data 118 | 119 | 120 | def data_preparation(config, dataset): 121 | """Split the dataset by :attr:`config['eval_args']` and create training, validation and test dataloader. 122 | 123 | Note: 124 | If we can load split dataloaders by :meth:`load_split_dataloaders`, we will not create new split dataloaders. 125 | 126 | Args: 127 | config (Config): An instance object of Config, used to record parameter information. 128 | dataset (Dataset): An instance object of Dataset, which contains all interaction records. 129 | 130 | Returns: 131 | tuple: 132 | - train_data (AbstractDataLoader): The dataloader for training. 133 | - valid_data (AbstractDataLoader): The dataloader for validation. 134 | - test_data (AbstractDataLoader): The dataloader for testing. 
135 | """ 136 | dataloaders = load_split_dataloaders(config) 137 | if dataloaders is not None: 138 | train_data, valid_data, test_data = dataloaders 139 | else: 140 | model_type = config['MODEL_TYPE'] 141 | built_datasets = dataset.build() 142 | 143 | train_dataset, valid_dataset, test_dataset = built_datasets 144 | train_sampler, valid_sampler, test_sampler = create_samplers(config, dataset, built_datasets) 145 | 146 | if model_type != ModelType.KNOWLEDGE: 147 | train_data = get_dataloader(config, 'train')(config, train_dataset, train_sampler, shuffle=True) 148 | else: 149 | kg_sampler = KGSampler(dataset, config['train_neg_sample_args']['distribution']) 150 | train_data = get_dataloader(config, 'train')(config, train_dataset, train_sampler, kg_sampler, shuffle=True) 151 | 152 | valid_data = get_dataloader(config, 'evaluation')(config, valid_dataset, valid_sampler, shuffle=False) 153 | test_data = get_dataloader(config, 'evaluation')(config, test_dataset, test_sampler, shuffle=False) 154 | if config['save_dataloaders']: 155 | save_split_dataloaders(config, dataloaders=(train_data, valid_data, test_data)) 156 | 157 | logger = getLogger() 158 | logger.info( 159 | set_color('[Training]: ', 'pink') + set_color('train_batch_size', 'cyan') + ' = ' + 160 | set_color(f'[{config["train_batch_size"]}]', 'yellow') + set_color(' negative sampling', 'cyan') + ': ' + 161 | set_color(f'[{config["neg_sampling"]}]', 'yellow') 162 | ) 163 | logger.info( 164 | set_color('[Evaluation]: ', 'pink') + set_color('eval_batch_size', 'cyan') + ' = ' + 165 | set_color(f'[{config["eval_batch_size"]}]', 'yellow') + set_color(' eval_args', 'cyan') + ': ' + 166 | set_color(f'[{config["eval_args"]}]', 'yellow') 167 | ) 168 | return train_data, valid_data, test_data 169 | 170 | 171 | def get_dataloader(config, phase): 172 | """Return a dataloader class according to :attr:`config` and :attr:`phase`. 173 | 174 | Args: 175 | config (Config): An instance object of Config, used to record parameter information. 176 | phase (str): The stage of dataloader. It can only take two values: 'train' or 'evaluation'. 177 | 178 | Returns: 179 | type: The dataloader class that meets the requirements in :attr:`config` and :attr:`phase`. 180 | """ 181 | register_table = { 182 | "MultiDAE": _get_AE_dataloader, 183 | "MultiVAE": _get_AE_dataloader, 184 | 'MacridVAE': _get_AE_dataloader, 185 | 'AE': _get_AE_dataloader, 186 | 'AE_NORM': _get_AE_dataloader, 187 | 'CDAE': _get_AE_dataloader, 188 | 'ENMF': _get_AE_dataloader, 189 | 'RaCT': _get_AE_dataloader, 190 | 'RecVAE': _get_AE_dataloader, 191 | } 192 | 193 | if config['model'] in register_table: 194 | return register_table[config['model']](config, phase) 195 | 196 | model_type = config['MODEL_TYPE'] 197 | if phase == 'train': 198 | if model_type != ModelType.KNOWLEDGE: 199 | return TrainDataLoader 200 | else: 201 | return KnowledgeBasedDataLoader 202 | else: 203 | eval_strategy = config['eval_neg_sample_args']['strategy'] 204 | if eval_strategy in {'none', 'by'}: 205 | return NegSampleEvalDataLoader 206 | elif eval_strategy == 'full': 207 | return FullSortEvalDataLoader 208 | 209 | 210 | def _get_AE_dataloader(config, phase): 211 | """Customized function for VAE models to get correct dataloader class. 212 | 213 | Args: 214 | config (Config): An instance object of Config, used to record parameter information. 215 | phase (str): The stage of dataloader. It can only take two values: 'train' or 'evaluation'. 
216 | 217 | Returns: 218 | type: The dataloader class that meets the requirements in :attr:`config` and :attr:`phase`. 219 | """ 220 | if phase == 'train': 221 | return UserDataLoader 222 | else: 223 | eval_strategy = config['eval_neg_sample_args']['strategy'] 224 | if eval_strategy in {'none', 'by'}: 225 | return NegSampleEvalDataLoader 226 | elif eval_strategy == 'full': 227 | return FullSortEvalDataLoader 228 | 229 | 230 | def create_samplers(config, dataset, built_datasets): 231 | """Create sampler for training, validation and testing. 232 | 233 | Args: 234 | config (Config): An instance object of Config, used to record parameter information. 235 | dataset (Dataset): An instance object of Dataset, which contains all interaction records. 236 | built_datasets (list of Dataset): A list of split Dataset, which contains the datasets for 237 | training, validation and testing. 238 | 239 | Returns: 240 | tuple: 241 | - train_sampler (AbstractSampler): The sampler for training. 242 | - valid_sampler (AbstractSampler): The sampler for validation. 243 | - test_sampler (AbstractSampler): The sampler for testing. 244 | """ 245 | phases = ['train', 'valid', 'test'] 246 | train_neg_sample_args = config['train_neg_sample_args'] 247 | eval_neg_sample_args = config['eval_neg_sample_args'] 248 | sampler = None 249 | train_sampler, valid_sampler, test_sampler = None, None, None 250 | 251 | if train_neg_sample_args['strategy'] != 'none': 252 | if not config['repeatable']: 253 | sampler = Sampler(phases, built_datasets, train_neg_sample_args['distribution']) 254 | else: 255 | sampler = RepeatableSampler(phases, dataset, train_neg_sample_args['distribution']) 256 | train_sampler = sampler.set_phase('train') 257 | 258 | if eval_neg_sample_args['strategy'] != 'none': 259 | if sampler is None: 260 | if not config['repeatable']: 261 | sampler = Sampler(phases, built_datasets, eval_neg_sample_args['distribution']) 262 | else: 263 | sampler = RepeatableSampler(phases, dataset, eval_neg_sample_args['distribution']) 264 | else: 265 | sampler.set_distribution(eval_neg_sample_args['distribution']) 266 | valid_sampler = sampler.set_phase('valid') 267 | test_sampler = sampler.set_phase('test') 268 | 269 | return train_sampler, valid_sampler, test_sampler 270 | -------------------------------------------------------------------------------- /recbole/data/dataloader/abstract_dataloader.py: -------------------------------------------------------------------------------- 1 | # @Time : 2020/7/7 2 | # @Author : Yupeng Hou 3 | # @Email : houyupeng@ruc.edu.cn 4 | 5 | # UPDATE 6 | # @Time : 2020/10/22, 2020/9/23 7 | # @Author : Yupeng Hou, Yushuo Chen 8 | # @email : houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn 9 | 10 | """ 11 | recbole.data.dataloader.abstract_dataloader 12 | ################################################ 13 | """ 14 | 15 | import math 16 | import copy 17 | from logging import getLogger 18 | import numpy as np 19 | 20 | import torch 21 | 22 | from recbole.data.interaction import Interaction 23 | from recbole.utils import InputType, FeatureType, FeatureSource 24 | 25 | 26 | class AbstractDataLoader: 27 | """:class:`AbstractDataLoader` is an abstract object which, when iterated, returns a batch of data loaded as 28 | :class:`~recbole.data.interaction.Interaction`. 29 | It is also the ancestor of all other dataloaders. 30 | 31 | Args: 32 | config (Config): The config of dataloader. 33 | dataset (Dataset): The dataset of dataloader. 34 | sampler (Sampler): The sampler of dataloader.
35 | shuffle (bool, optional): Whether the dataloader will be shuffled after a round. Defaults to ``False``. 36 | 37 | Attributes: 38 | dataset (Dataset): The dataset of this dataloader. 39 | shuffle (bool): If ``True``, dataloader will shuffle before every epoch. 40 | pr (int): Pointer of dataloader. 41 | step (int): The increment of :attr:`pr` for each batch. 42 | batch_size (int): The max interaction number for each batch. 43 | """ 44 | 45 | def __init__(self, config, dataset, sampler, shuffle=False): 46 | self.config = config 47 | self.logger = getLogger() 48 | self.dataset = dataset 49 | self.sampler = sampler 50 | self.batch_size = self.step = self.model = None 51 | self.shuffle = shuffle 52 | self.pr = 0 53 | self._init_batch_size_and_step() 54 | 55 | def _init_batch_size_and_step(self): 56 | """Initializing :attr:`step` and :attr:`batch_size`.""" 57 | raise NotImplementedError('Method [init_batch_size_and_step] should be implemented') 58 | 59 | def update_config(self, config): 60 | """Update configure of dataloader, such as :attr:`batch_size`, :attr:`step` etc. 61 | 62 | Args: 63 | config (Config): The new config of dataloader. 64 | """ 65 | self.config = config 66 | self._init_batch_size_and_step() 67 | 68 | def __len__(self): 69 | return math.ceil(self.pr_end / self.step) 70 | 71 | def __iter__(self): 72 | if self.shuffle: 73 | self._shuffle() 74 | return self 75 | 76 | def __next__(self): 77 | if self.pr >= self.pr_end: 78 | self.pr = 0 79 | raise StopIteration() 80 | return self._next_batch_data() 81 | 82 | @property 83 | def pr_end(self): 84 | """This property marks the end of dataloader.pr which is used in :meth:`__next__`.""" 85 | raise NotImplementedError('Method [pr_end] should be implemented') 86 | 87 | def _shuffle(self): 88 | """Shuffle the order of data, and it will be called by :meth:`__iter__` if self.shuffle is True. 89 | """ 90 | raise NotImplementedError('Method [shuffle] should be implemented.') 91 | 92 | def _next_batch_data(self): 93 | """Assemble next batch of data in form of Interaction, and return these data. 94 | 95 | Returns: 96 | Interaction: The next batch of data. 97 | """ 98 | raise NotImplementedError('Method [next_batch_data] should be implemented.') 99 | 100 | def set_batch_size(self, batch_size): 101 | """Reset the batch_size of the dataloader, but it can't be called when dataloader is being iterated. 102 | 103 | Args: 104 | batch_size (int): the new batch_size of dataloader. 105 | """ 106 | if self.pr != 0: 107 | raise PermissionError('Cannot change dataloader\'s batch_size while iterating') 108 | self.batch_size = batch_size 109 | 110 | 111 | class NegSampleDataLoader(AbstractDataLoader): 112 | """:class:`NegSampleDataLoader` is an abstract class which can sample negative examples by ratio. 113 | It has two neg-sampling methods: one is 1-by-1 neg-sampling (pairwise), 114 | and the other is 1-by-multi neg-sampling (pointwise). 115 | 116 | Args: 117 | config (Config): The config of dataloader. 118 | dataset (Dataset): The dataset of dataloader. 119 | sampler (Sampler): The sampler of dataloader. 120 | shuffle (bool, optional): Whether the dataloader will be shuffled after a round. Defaults to ``False``.
121 | """ 122 | 123 | def __init__(self, config, dataset, sampler, shuffle=True): 124 | super().__init__(config, dataset, sampler, shuffle=shuffle) 125 | 126 | def _set_neg_sample_args(self, config, dataset, dl_format, neg_sample_args): 127 | self.uid_field = dataset.uid_field # 'user_id' 128 | self.iid_field = dataset.iid_field # 'item_id' 129 | self.dl_format = dl_format 130 | self.neg_sample_args = neg_sample_args 131 | self.times = 1 132 | if self.neg_sample_args['strategy'] == 'by': 133 | self.neg_sample_num = self.neg_sample_args['by'] 134 | 135 | if self.dl_format == InputType.POINTWISE: 136 | self.times = 1 + self.neg_sample_num 137 | self.sampling_func = self._neg_sample_by_point_wise_sampling 138 | 139 | self.label_field = config['LABEL_FIELD'] 140 | dataset.set_field_property(self.label_field, FeatureType.FLOAT, FeatureSource.INTERACTION, 1) 141 | elif self.dl_format == InputType.PAIRWISE: 142 | self.times = self.neg_sample_num 143 | self.sampling_func = self._neg_sample_by_pair_wise_sampling 144 | 145 | self.neg_prefix = config['NEG_PREFIX'] 146 | self.neg_item_id = self.neg_prefix + self.iid_field 147 | 148 | columns = [self.iid_field] if dataset.item_feat is None else dataset.item_feat.columns 149 | for item_feat_col in columns: 150 | neg_item_feat_col = self.neg_prefix + item_feat_col 151 | dataset.copy_field_property(neg_item_feat_col, item_feat_col) 152 | elif self.dl_format == InputType.SETWISE: 153 | self.times = self.neg_sample_num 154 | self.sampling_func = self._neg_sample_by_set_wise_sampling 155 | 156 | self.neg_prefix = config['NEG_PREFIX'] 157 | self.neg_item_id = self.neg_prefix + self.iid_field 158 | 159 | columns = [self.iid_field] if dataset.item_feat is None else dataset.item_feat.columns 160 | for item_feat_col in columns: 161 | neg_item_feat_col = self.neg_prefix + item_feat_col 162 | dataset.copy_field_property(neg_item_feat_col, item_feat_col) 163 | else: 164 | raise ValueError(f'`neg sampling by` with dl_format [{self.dl_format}] not been implemented.') 165 | 166 | elif self.neg_sample_args['strategy'] != 'none': 167 | raise ValueError(f'`neg_sample_args` [{self.neg_sample_args["strategy"]}] is not supported!') 168 | 169 | def _neg_sampling(self, inter_feat): 170 | if 'dynamic' in self.neg_sample_args.keys() and self.neg_sample_args['dynamic'] != 'none': 171 | candidate_num = self.neg_sample_args['dynamic'] 172 | user_ids = inter_feat[self.uid_field].numpy() 173 | item_ids = inter_feat[self.iid_field].numpy() 174 | neg_candidate_ids = self.sampler.sample_by_user_ids(user_ids, item_ids, self.neg_sample_num * candidate_num) 175 | self.model.eval() 176 | interaction = copy.deepcopy(inter_feat).to(self.model.device) 177 | interaction = interaction.repeat(self.neg_sample_num * candidate_num) 178 | neg_item_feat = Interaction({self.iid_field: neg_candidate_ids.to(self.model.device)}) 179 | interaction.update(neg_item_feat) 180 | scores = self.model.predict(interaction).reshape(candidate_num, -1) 181 | indices = torch.max(scores, dim=0)[1].detach() 182 | neg_candidate_ids = neg_candidate_ids.reshape(candidate_num, -1) 183 | neg_item_ids = neg_candidate_ids[indices, [i for i in range(neg_candidate_ids.shape[1])]].view(-1) 184 | self.model.train() 185 | return self.sampling_func(inter_feat, neg_item_ids) # return Interaction with negative items 186 | elif self.neg_sample_args['in_batch']: 187 | user_ids = inter_feat[self.uid_field].numpy() 188 | item_ids = inter_feat[self.iid_field].numpy() 189 | 190 | batch_size = len(user_ids) 191 | self.times = batch_size - 1 
192 | 193 | # TODO 2: the same handling should also be applied to point-wise and pair-wise sampling. 194 | # neg_item_ids[i] = item_ids[user_ids != user_id] 195 | 196 | # Method 1 197 | # neg_item_ids = np.zeros((batch_size, self.times), dtype=np.int64) 198 | # num_neg_items = np.zeros(batch_size, dtype=np.int64) 199 | # # exclude items the user has already interacted with 200 | # for i, used_ids_by_user in enumerate(self.sampler.used_ids[user_ids]): 201 | # neg_items_by_user = set(item_ids) - used_ids_by_user 202 | # num_neg_items_by_user = len(neg_items_by_user) 203 | # neg_item_ids[i, :num_neg_items_by_user] = np.array(list(neg_items_by_user)) 204 | 205 | # # store the number of negative items per user 206 | # num_neg_items[i] = num_neg_items_by_user 207 | 208 | # Method 2 209 | # remove only the diagonal elements (each row's own positive item). 210 | neg_item_ids = np.tile(item_ids, (batch_size, 1)).astype(np.int64) # np.tile takes no dtype argument 211 | neg_item_ids = neg_item_ids[~np.eye(neg_item_ids.shape[0], dtype=bool)].reshape(neg_item_ids.shape[0], -1) 212 | 213 | neg_item_ids = torch.tensor(neg_item_ids) 214 | # Method 2 keeps all batch_size - 1 negatives per row, so no per-row count is needed. 215 | num_neg_items = None 216 | 217 | return self.sampling_func(inter_feat, neg_item_ids, num_neg_items) 218 | 219 | elif self.neg_sample_args['strategy'] == 'by': 220 | user_ids = inter_feat[self.uid_field].numpy() 221 | item_ids = inter_feat[self.iid_field].numpy() 222 | neg_item_ids = self.sampler.sample_by_user_ids(user_ids, item_ids, self.neg_sample_num) 223 | return self.sampling_func(inter_feat, neg_item_ids) 224 | else: 225 | return inter_feat 226 | 227 | def _neg_sample_by_set_wise_sampling(self, inter_feat, neg_item_ids, num_neg_items=None): 228 | if neg_item_ids.ndim == 1: 229 | neg_item_ids = neg_item_ids.view(self.times, -1).T 230 | 231 | if num_neg_items is None: 232 | pass 233 | else: 234 | # add the number of negative items to the Interaction 235 | num_neg_item_feat = Interaction({'num_neg_items': num_neg_items}) 236 | inter_feat.update(num_neg_item_feat) 237 | 238 | neg_item_feat = Interaction({self.iid_field: neg_item_ids}) 239 | neg_item_feat = self.dataset.join(neg_item_feat) 240 | neg_item_feat.add_prefix(self.neg_prefix) 241 | inter_feat.update(neg_item_feat) 242 | return inter_feat 243 | 244 | def _neg_sample_by_pair_wise_sampling(self, inter_feat, neg_item_ids): 245 | if neg_item_ids.ndim == 2: 246 | neg_item_ids = neg_item_ids.nonzero(as_tuple=True)[0] 247 | inter_feat = inter_feat.repeat(self.times) 248 | neg_item_feat = Interaction({self.iid_field: neg_item_ids}) 249 | neg_item_feat = self.dataset.join(neg_item_feat) 250 | neg_item_feat.add_prefix(self.neg_prefix) 251 | inter_feat.update(neg_item_feat) 252 | return inter_feat 253 | 254 | def _neg_sample_by_point_wise_sampling(self, inter_feat, neg_item_ids): 255 | pos_inter_num = len(inter_feat) 256 | new_data = inter_feat.repeat(self.times) 257 | new_data[self.iid_field][pos_inter_num:] = neg_item_ids 258 | new_data = self.dataset.join(new_data) 259 | labels = torch.zeros(pos_inter_num * self.times) 260 | labels[:pos_inter_num] = 1.0 261 | new_data.update(Interaction({self.label_field: labels})) 262 | return new_data 263 | 264 | def get_model(self, model): 265 | self.model = model 266 | -------------------------------------------------------------------------------- /recbole/evaluator/collector.py: -------------------------------------------------------------------------------- 1 | # @Time : 2021/6/23 2 | # @Author : Zihan Lin 3 | # @Email : zhlin@ruc.edu.cn 4 | 5 | # UPDATE 6 | # @Time : 2021/7/18 7 | # @Author : Zhichao Feng 8 | # @email : fzcbupt@gmail.com 9 | 10 | """ 11 | recbole.evaluator.collector 12 |
################################################ 13 | """ 14 | 15 | from recbole.evaluator.register import Register 16 | import torch 17 | import copy 18 | 19 | 20 | class DataStruct(object): 21 | 22 | def __init__(self): 23 | self._data_dict = {} 24 | 25 | def __getitem__(self, name: str): 26 | return self._data_dict[name] 27 | 28 | def __setitem__(self, name: str, value): 29 | self._data_dict[name] = value 30 | 31 | def __delitem__(self, name: str): 32 | self._data_dict.pop(name) 33 | 34 | def __contains__(self, key: str): 35 | return key in self._data_dict 36 | 37 | def get(self, name: str): 38 | if name not in self._data_dict: 39 | raise IndexError("Cannot load the data without registration!") 40 | return self[name] 41 | 42 | def set(self, name: str, value): 43 | self._data_dict[name] = value 44 | 45 | def update_tensor(self, name: str, value: torch.Tensor): 46 | if name not in self._data_dict: 47 | self._data_dict[name] = value.cpu().clone().detach() 48 | else: 49 | if not isinstance(self._data_dict[name], torch.Tensor): 50 | raise ValueError("{} is not a tensor.".format(name)) 51 | self._data_dict[name] = torch.cat((self._data_dict[name], value.cpu().clone().detach()), dim=0) 52 | 53 | def __str__(self): 54 | data_info = '\nContaining:\n' 55 | for data_key in self._data_dict.keys(): 56 | data_info += data_key + '\n' 57 | return data_info 58 | 59 | 60 | class Collector(object): 61 | """The collector is used to collect the resources for the evaluator. 62 | As the evaluation metrics are various, the needed resources contain not only the recommended results 63 | but also other resources from the data and the model. They can all be collected by the collector 64 | during the training and evaluation process. 65 | 66 | This class is only used in Trainer. 67 | 68 | """ 69 | 70 | def __init__(self, config): 71 | self.config = config 72 | self.data_struct = DataStruct() 73 | self.register = Register(config) 74 | self.full = ('full' in config['eval_args']['mode']) 75 | self.topk = self.config['topk'] 76 | self.device = self.config['device'] 77 | 78 | def data_collect(self, train_data): 79 | """ Collect the evaluation resource from training data. 80 | Args: 81 | train_data (AbstractDataLoader): the training dataloader which contains the training data. 82 | 83 | """ 84 | if self.register.need('data.num_items'): 85 | item_id = self.config['ITEM_ID_FIELD'] 86 | self.data_struct.set('data.num_items', train_data.dataset.num(item_id)) 87 | if self.register.need('data.num_users'): 88 | user_id = self.config['USER_ID_FIELD'] 89 | self.data_struct.set('data.num_users', train_data.dataset.num(user_id)) 90 | if self.register.need('data.count_items'): 91 | self.data_struct.set('data.count_items', train_data.dataset.item_counter) 92 | if self.register.need('data.count_users'): 93 | self.data_struct.set('data.count_users', train_data.dataset.user_counter) 94 | 95 | def _average_rank(self, scores): 96 | """Get the ranking of an ordered tensor, and take the average of the ranking for positions with equal values.
97 | 98 | Args: 99 | scores (tensor): a row-wise ordered tensor, with shape `(N, M)` 100 | 101 | Returns: 102 | torch.Tensor: average_rank 103 | 104 | Example: 105 | >>> average_rank(tensor([[1,2,2,2,3,3,6],[2,2,2,2,4,5,5]])) 106 | tensor([[1.0000, 3.0000, 3.0000, 3.0000, 5.5000, 5.5000, 7.0000], 107 | [2.5000, 2.5000, 2.5000, 2.5000, 5.0000, 6.5000, 6.5000]]) 108 | 109 | Reference: 110 | https://github.com/scipy/scipy/blob/v0.17.1/scipy/stats/stats.py#L5262-L5352 111 | 112 | """ 113 | length, width = scores.shape 114 | true_tensor = torch.full((length, 1), True, dtype=torch.bool, device=self.device) 115 | 116 | obs = torch.cat([true_tensor, scores[:, 1:] != scores[:, :-1]], dim=1) 117 | # bias added to dense 118 | bias = torch.arange(0, length, device=self.device).repeat(width).reshape(width, -1). \ 119 | transpose(1, 0).reshape(-1) 120 | dense = obs.view(-1).cumsum(0) + bias 121 | 122 | # cumulative counts of each unique value 123 | count = torch.where(torch.cat([obs, true_tensor], dim=1))[1] 124 | # get average rank 125 | avg_rank = .5 * (count[dense] + count[dense - 1] + 1).view(length, -1) 126 | 127 | return avg_rank 128 | 129 | def eval_batch_collect( 130 | self, scores_tensor: torch.Tensor, interaction, positive_u: torch.Tensor, positive_i: torch.Tensor 131 | ): 132 | """ Collect the evaluation resource from batched eval data and batched model output. 133 | Args: 134 | scores_tensor (Torch.Tensor): the output tensor of model with the shape of `(N, )` 135 | interaction(Interaction): batched eval data. 136 | positive_u(Torch.Tensor): the row index of positive items for each user. 137 | positive_i(Torch.Tensor): the positive item id for each user. 138 | """ 139 | if self.register.need('rec.items'): 140 | 141 | # get topk 142 | _, topk_idx = torch.topk(scores_tensor, max(self.topk), dim=-1) # n_users x k 143 | self.data_struct.update_tensor('rec.items', topk_idx) 144 | 145 | if self.register.need('rec.topk'): 146 | 147 | _, topk_idx = torch.topk(scores_tensor, max(self.topk), dim=-1) # n_users x k 148 | pos_matrix = torch.zeros_like(scores_tensor, dtype=torch.int) 149 | pos_matrix[positive_u, positive_i] = 1 150 | pos_len_list = pos_matrix.sum(dim=1, keepdim=True) 151 | pos_idx = torch.gather(pos_matrix, dim=1, index=topk_idx) 152 | result = torch.cat((pos_idx, pos_len_list), dim=1) 153 | self.data_struct.update_tensor('rec.topk', result) 154 | 155 | if self.register.need('rec.meanrank'): 156 | 157 | desc_scores, desc_index = torch.sort(scores_tensor, dim=-1, descending=True) 158 | 159 | # get the index of positive items in the ranking list 160 | pos_matrix = torch.zeros_like(scores_tensor) 161 | pos_matrix[positive_u, positive_i] = 1 162 | pos_index = torch.gather(pos_matrix, dim=1, index=desc_index) 163 | 164 | avg_rank = self._average_rank(desc_scores) 165 | pos_rank_sum = torch.where(pos_index == 1, avg_rank, torch.zeros_like(avg_rank)).sum(dim=-1, keepdim=True) 166 | 167 | pos_len_list = pos_matrix.sum(dim=1, keepdim=True) 168 | user_len_list = desc_scores.argmin(dim=1, keepdim=True) 169 | result = torch.cat((pos_rank_sum, user_len_list, pos_len_list), dim=1) 170 | self.data_struct.update_tensor('rec.meanrank', result) 171 | 172 | if self.register.need('rec.score'): 173 | 174 | self.data_struct.update_tensor('rec.score', scores_tensor) 175 | 176 | if self.register.need('data.label'): 177 | self.label_field = self.config['LABEL_FIELD'] 178 | self.data_struct.update_tensor('data.label', interaction[self.label_field].to(self.device)) 179 | 180 | def eval_batch_collect_unbiased( 181 | self,
scores_tensor: torch.Tensor, interaction, positive_u: torch.Tensor, positive_i: torch.Tensor, pscore: torch.Tensor 182 | ): 183 | """ Collect the evaluation resource from batched eval data and batched model output. 184 | Args: 185 | scores_tensor (Torch.Tensor): the output tensor of model with the shape of `(N, )` 186 | interaction(Interaction): batched eval data. 187 | positive_u(Torch.Tensor): the row index of positive items for each user. 188 | positive_i(Torch.Tensor): the positive item id for each user. 189 | """ 190 | if self.register.need('rec.items'): 191 | # get topk 192 | _, topk_idx = torch.topk(scores_tensor, max(self.topk), dim=-1) # n_users x k 193 | self.data_struct.update_tensor('rec.items', topk_idx) 194 | 195 | if self.register.need('rec.topk'): 196 | _, topk_idx = torch.topk(scores_tensor, max(self.topk), dim=-1) # n_users x k 197 | pos_matrix = torch.zeros_like(scores_tensor, dtype=torch.int) 198 | pos_matrix[positive_u, positive_i] = 1 199 | pos_len_list = pos_matrix.sum(dim=1, keepdim=True) 200 | pos_idx = torch.gather(pos_matrix, dim=1, index=topk_idx) 201 | result = torch.cat((pos_idx, pos_len_list), dim=1) # for each top-k predicted item, whether it is a true positive (pos_idx), with the true number of positives (pos_len_list) appended as the last column 202 | self.data_struct.update_tensor('rec.topk', result) 203 | 204 | ##### pscore 205 | pscore_result = torch.zeros_like(pos_idx, dtype=torch.float) 206 | # use the configured device rather than hard-coding 'cuda', so CPU-only runs also work 207 | pscore_result = pscore_result.to(self.device) 208 | pscore = pscore.to(self.device) 209 | pscore_result[:, :] = pscore[topk_idx] 210 | self.data_struct.update_tensor('data.pscore', pscore_result) 211 | 212 | ##### pscore all 213 | pos_matrix_tmp = torch.zeros_like(scores_tensor, dtype=torch.float) 214 | pos_matrix_tmp[positive_u, positive_i] = pscore[positive_i].to(self.device) 215 | pscore_result_all_user = pos_matrix_tmp.sum(dim=1, keepdim=True) 216 | self.data_struct.update_tensor('data.pscore_all', pscore_result_all_user) 217 | 218 | if self.register.need('rec.meanrank'): 219 | desc_scores, desc_index = torch.sort(scores_tensor, dim=-1, descending=True) 220 | 221 | # get the index of positive items in the ranking list 222 | pos_matrix = torch.zeros_like(scores_tensor) 223 | pos_matrix[positive_u, positive_i] = 1 224 | pos_index = torch.gather(pos_matrix, dim=1, index=desc_index) 225 | 226 | avg_rank = self._average_rank(desc_scores) 227 | pos_rank_sum = torch.where(pos_index == 1, avg_rank, torch.zeros_like(avg_rank)).sum(dim=-1, keepdim=True) 228 | 229 | pos_len_list = pos_matrix.sum(dim=1, keepdim=True) 230 | user_len_list = desc_scores.argmin(dim=1, keepdim=True) 231 | result = torch.cat((pos_rank_sum, user_len_list, pos_len_list), dim=1) 232 | self.data_struct.update_tensor('rec.meanrank', result) 233 | 234 | if self.register.need('rec.score'): 235 | self.data_struct.update_tensor('rec.score', scores_tensor) 236 | 237 | if self.register.need('data.label'): 238 | self.label_field = self.config['LABEL_FIELD'] 239 | self.data_struct.update_tensor('data.label', interaction[self.label_field].to(self.device)) 240 | 241 | def model_collect(self, model: torch.nn.Module): 242 | """ Collect the evaluation resource from model. 243 | Args: 244 | model (nn.Module): the trained recommendation model. 245 | """ 246 | pass 247 | # TODO: 248 | 249 | def eval_collect(self, eval_pred: torch.Tensor, data_label: torch.Tensor): 250 | """ Collect the evaluation resource from total output and label. 251 | It was designed for those models that cannot predict in batches.
252 | Args: 253 | eval_pred (torch.Tensor): the output score tensor of model. 254 | data_label (torch.Tensor): the label tensor. 255 | """ 256 | if self.register.need('rec.score'): 257 | self.data_struct.update_tensor('rec.score', eval_pred) 258 | 259 | if self.register.need('data.label'): 260 | self.label_field = self.config['LABEL_FIELD'] 261 | self.data_struct.update_tensor('data.label', data_label.to(self.device)) 262 | 263 | def get_data_struct(self): 264 | """ Get all the evaluation resources that have been collected, 265 | and reset the outdated ones. 266 | """ 267 | returned_struct = copy.deepcopy(self.data_struct) 268 | for key in ['rec.topk', 'rec.meanrank', 'rec.score', 'rec.items', 'data.label']: 269 | if key in self.data_struct: 270 | del self.data_struct[key] 271 | return returned_struct 272 | 273 | def get_data_struct_unbiased(self): 274 | """ Get all the evaluation resources that have been collected, 275 | and reset the outdated ones. 276 | """ 277 | returned_struct = copy.deepcopy(self.data_struct) 278 | for key in ['rec.topk', 'rec.meanrank', 'rec.score', 'rec.items', 'data.label', 'data.pscore', 'data.pscore_all']: 279 | if key in self.data_struct: 280 | del self.data_struct[key] 281 | return returned_struct -------------------------------------------------------------------------------- /style.cfg: -------------------------------------------------------------------------------- 1 | [style] 2 | # Align closing bracket with visual indentation. 3 | align_closing_bracket_with_visual_indent=True 4 | 5 | # Allow dictionary keys to exist on multiple lines. For example: 6 | # 7 | # x = { 8 | # ('this is the first element of a tuple', 9 | # 'this is the second element of a tuple'): 10 | # value, 11 | # } 12 | allow_multiline_dictionary_keys=False 13 | 14 | # Allow lambdas to be formatted on more than one line. 15 | allow_multiline_lambdas=False 16 | 17 | # Allow splitting before a default / named assignment in an argument list. 18 | allow_split_before_default_or_named_assigns=True 19 | 20 | # Allow splits before the dictionary value. 21 | allow_split_before_dict_value=True 22 | 23 | # Let spacing indicate operator precedence. For example: 24 | # 25 | # a = 1 * 2 + 3 / 4 26 | # b = 1 / 2 - 3 * 4 27 | # c = (1 + 2) * (3 - 4) 28 | # d = (1 - 2) / (3 + 4) 29 | # e = 1 * 2 - 3 30 | # f = 1 + 2 + 3 + 4 31 | # 32 | # will be formatted as follows to indicate precedence: 33 | # 34 | # a = 1*2 + 3/4 35 | # b = 1/2 - 3*4 36 | # c = (1+2) * (3-4) 37 | # d = (1-2) / (3+4) 38 | # e = 1*2 - 3 39 | # f = 1 + 2 + 3 + 4 40 | # 41 | arithmetic_precedence_indication=False 42 | 43 | # Number of blank lines surrounding top-level function and class 44 | # definitions. 45 | blank_lines_around_top_level_definition=2 46 | 47 | # Insert a blank line before a class-level docstring. 48 | blank_line_before_class_docstring=False 49 | 50 | # Insert a blank line before a module docstring. 51 | blank_line_before_module_docstring=True 52 | 53 | # Insert a blank line before a 'def' or 'class' immediately nested 54 | # within another 'def' or 'class'. For example: 55 | # 56 | # class Foo: 57 | # # <------ this blank line 58 | # def method(): 59 | # ... 60 | blank_line_before_nested_class_or_def=True 61 | 62 | # Do not split consecutive brackets. Only relevant when 63 | # dedent_closing_brackets is set.
For example: 64 | # 65 | # call_func_that_takes_a_dict( 66 | # { 67 | # 'key1': 'value1', 68 | # 'key2': 'value2', 69 | # } 70 | # ) 71 | # 72 | # would reformat to: 73 | # 74 | # call_func_that_takes_a_dict({ 75 | # 'key1': 'value1', 76 | # 'key2': 'value2', 77 | # }) 78 | coalesce_brackets=True 79 | 80 | # The column limit. 81 | column_limit=120 82 | 83 | # The style for continuation alignment. Possible values are: 84 | # 85 | # - SPACE: Use spaces for continuation alignment. This is default behavior. 86 | # - FIXED: Use fixed number (CONTINUATION_INDENT_WIDTH) of columns 87 | # (ie: CONTINUATION_INDENT_WIDTH/INDENT_WIDTH tabs or 88 | # CONTINUATION_INDENT_WIDTH spaces) for continuation alignment. 89 | # - VALIGN-RIGHT: Vertically align continuation lines to multiple of 90 | # INDENT_WIDTH columns. Slightly right (one tab or a few spaces) if 91 | # cannot vertically align continuation lines with indent characters. 92 | continuation_align_style=SPACE 93 | 94 | # Indent width used for line continuations. 95 | continuation_indent_width=4 96 | 97 | # Put closing brackets on a separate line, dedented, if the bracketed 98 | # expression can't fit in a single line. Applies to all kinds of brackets, 99 | # including function definitions and calls. For example: 100 | # 101 | # config = { 102 | # 'key1': 'value1', 103 | # 'key2': 'value2', 104 | # } # <--- this bracket is dedented and on a separate line 105 | # 106 | # time_series = self.remote_client.query_entity_counters( 107 | # entity='dev3246.region1', 108 | # key='dns.query_latency_tcp', 109 | # transform=Transformation.AVERAGE(window=timedelta(seconds=60)), 110 | # start_ts=now()-timedelta(days=3), 111 | # end_ts=now(), 112 | # ) # <--- this bracket is dedented and on a separate line 113 | dedent_closing_brackets=True 114 | 115 | # Disable the heuristic which places each list element on a separate line 116 | # if the list is comma-terminated. 117 | disable_ending_comma_heuristic=False 118 | 119 | # Place each dictionary entry onto its own line. 120 | each_dict_entry_on_separate_line=True 121 | 122 | # Require multiline dictionary even if it would normally fit on one line. 123 | # For example: 124 | # 125 | # config = { 126 | # 'key1': 'value1' 127 | # } 128 | force_multiline_dict=False 129 | 130 | # The regex for an i18n comment. The presence of this comment stops 131 | # reformatting of that line, because the comments are required to be 132 | # next to the string they translate. 133 | i18n_comment= 134 | 135 | # The i18n function call names. The presence of this function stops 136 | # reformatting on that line, because the string it has cannot be moved 137 | # away from the i18n comment. 138 | i18n_function_call= 139 | 140 | # Indent blank lines. 141 | indent_blank_lines=False 142 | 143 | # Put closing brackets on a separate line, indented, if the bracketed 144 | # expression can't fit in a single line. Applies to all kinds of brackets, 145 | # including function definitions and calls.
For example: 146 | # 147 | # config = { 148 | # 'key1': 'value1', 149 | # 'key2': 'value2', 150 | # } # <--- this bracket is indented and on a separate line 151 | # 152 | # time_series = self.remote_client.query_entity_counters( 153 | # entity='dev3246.region1', 154 | # key='dns.query_latency_tcp', 155 | # transform=Transformation.AVERAGE(window=timedelta(seconds=60)), 156 | # start_ts=now()-timedelta(days=3), 157 | # end_ts=now(), 158 | # ) # <--- this bracket is indented and on a separate line 159 | indent_closing_brackets=False 160 | 161 | # Indent the dictionary value if it cannot fit on the same line as the 162 | # dictionary key. For example: 163 | # 164 | # config = { 165 | # 'key1': 166 | # 'value1', 167 | # 'key2': value1 + 168 | # value2, 169 | # } 170 | indent_dictionary_value=False 171 | 172 | # The number of columns to use for indentation. 173 | indent_width=4 174 | 175 | # Join short lines into one line. E.g., single line 'if' statements. 176 | join_multiple_lines=True 177 | 178 | # Do not include spaces around selected binary operators. For example: 179 | # 180 | # 1 + 2 * 3 - 4 / 5 181 | # 182 | # will be formatted as follows when configured with "*,/": 183 | # 184 | # 1 + 2*3 - 4/5 185 | no_spaces_around_selected_binary_operators= 186 | 187 | # Use spaces around default or named assigns. 188 | spaces_around_default_or_named_assign=False 189 | 190 | # Adds a space after the opening '{' and before the ending '}' dict delimiters. 191 | # 192 | # {1: 2} 193 | # 194 | # will be formatted as: 195 | # 196 | # { 1: 2 } 197 | spaces_around_dict_delimiters=False 198 | 199 | # Adds a space after the opening '[' and before the ending ']' list delimiters. 200 | # 201 | # [1, 2] 202 | # 203 | # will be formatted as: 204 | # 205 | # [ 1, 2 ] 206 | spaces_around_list_delimiters=False 207 | 208 | # Use spaces around the power operator. 209 | spaces_around_power_operator=True 210 | 211 | # Use spaces around the subscript / slice operator. For example: 212 | # 213 | # my_list[1 : 10 : 2] 214 | spaces_around_subscript_colon=False 215 | 216 | # Adds a space after the opening '(' and before the ending ')' tuple delimiters. 217 | # 218 | # (1, 2, 3) 219 | # 220 | # will be formatted as: 221 | # 222 | # ( 1, 2, 3 ) 223 | spaces_around_tuple_delimiters=False 224 | 225 | # The number of spaces required before a trailing comment. 226 | # This can be a single value (representing the number of spaces 227 | # before each trailing comment) or list of values (representing 228 | # alignment column values; trailing comments within a block will 229 | # be aligned to the first column value that is greater than the maximum 230 | # line length within the block). 
For example: 231 | # 232 | # With spaces_before_comment=5: 233 | # 234 | # 1 + 1 # Adding values 235 | # 236 | # will be formatted as: 237 | # 238 | # 1 + 1 # Adding values <-- 5 spaces between the end of the statement and comment 239 | # 240 | # With spaces_before_comment=15, 20: 241 | # 242 | # 1 + 1 # Adding values 243 | # two + two # More adding 244 | # 245 | # longer_statement # This is a longer statement 246 | # short # This is a shorter statement 247 | # 248 | # a_very_long_statement_that_extends_beyond_the_final_column # Comment 249 | # short # This is a shorter statement 250 | # 251 | # will be formatted as: 252 | # 253 | # 1 + 1 # Adding values <-- end of line comments in block aligned to col 15 254 | # two + two # More adding 255 | # 256 | # longer_statement # This is a longer statement <-- end of line comments in block aligned to col 20 257 | # short # This is a shorter statement 258 | # 259 | # a_very_long_statement_that_extends_beyond_the_final_column # Comment <-- the end of line comments are aligned based on the line length 260 | # short # This is a shorter statement 261 | # 262 | spaces_before_comment=2 263 | 264 | # Insert a space between the ending comma and closing bracket of a list, 265 | # etc. 266 | space_between_ending_comma_and_closing_bracket=False 267 | 268 | # Use spaces inside brackets, braces, and parentheses. For example: 269 | # 270 | # method_call( 1 ) 271 | # my_dict[ 3 ][ 1 ][ get_index( *args, **kwargs ) ] 272 | # my_set = { 1, 2, 3 } 273 | space_inside_brackets=False 274 | 275 | # Split before arguments 276 | split_all_comma_separated_values=False 277 | 278 | # Split before arguments, but do not split all subexpressions recursively 279 | # (unless needed). 280 | split_all_top_level_comma_separated_values=False 281 | 282 | # Split before arguments if the argument list is terminated by a 283 | # comma. 284 | split_arguments_when_comma_terminated=False 285 | 286 | # Set to True to prefer splitting before '+', '-', '*', '/', '//', or '@' 287 | # rather than after. 288 | split_before_arithmetic_operator=False 289 | 290 | # Set to True to prefer splitting before '&', '|' or '^' rather than 291 | # after. 292 | split_before_bitwise_operator=True 293 | 294 | # Split before the closing bracket if a list or dict literal doesn't fit on 295 | # a single line. 296 | split_before_closing_bracket=True 297 | 298 | # Split before a dictionary or set generator (comp_for). For example, note 299 | # the split before the 'for': 300 | # 301 | # foo = { 302 | # variable: 'Hello world, have a nice day!' 303 | # for variable in bar if variable != 42 304 | # } 305 | split_before_dict_set_generator=True 306 | 307 | # Split before the '.' if we need to split a longer expression: 308 | # 309 | # foo = ('This is a really long string: {}, {}, {}, {}'.format(a, b, c, d)) 310 | # 311 | # would reformat to something like: 312 | # 313 | # foo = ('This is a really long string: {}, {}, {}, {}' 314 | # .format(a, b, c, d)) 315 | split_before_dot=False 316 | 317 | # Split after the opening paren which surrounds an expression if it doesn't 318 | # fit on a single line. 319 | split_before_expression_after_opening_paren=False 320 | 321 | # If an argument / parameter list is going to be split, then split before 322 | # the first argument. 323 | split_before_first_argument=False 324 | 325 | # Set to True to prefer splitting before 'and' or 'or' rather than 326 | # after. 327 | split_before_logical_operator=True 328 | 329 | # Split named assignments onto individual lines. 
330 | split_before_named_assigns=True 331 | 332 | # Set to True to split list comprehensions and generators that have 333 | # non-trivial expressions and multiple clauses before each of these 334 | # clauses. For example: 335 | # 336 | # result = [ 337 | # a_long_var + 100 for a_long_var in xrange(1000) 338 | # if a_long_var % 10] 339 | # 340 | # would reformat to something like: 341 | # 342 | # result = [ 343 | # a_long_var + 100 344 | # for a_long_var in xrange(1000) 345 | # if a_long_var % 10] 346 | split_complex_comprehension=False 347 | 348 | # The penalty for splitting right after the opening bracket. 349 | split_penalty_after_opening_bracket=300 350 | 351 | # The penalty for splitting the line after a unary operator. 352 | split_penalty_after_unary_operator=10000 353 | 354 | # The penalty of splitting the line around the '+', '-', '*', '/', '//', 355 | # ``%``, and '@' operators. 356 | split_penalty_arithmetic_operator=300 357 | 358 | # The penalty for splitting right before an if expression. 359 | split_penalty_before_if_expr=0 360 | 361 | # The penalty of splitting the line around the '&', '|', and '^' 362 | # operators. 363 | split_penalty_bitwise_operator=300 364 | 365 | # The penalty for splitting a list comprehension or generator 366 | # expression. 367 | split_penalty_comprehension=80 368 | 369 | # The penalty for characters over the column limit. 370 | split_penalty_excess_character=7000 371 | 372 | # The penalty incurred by adding a line split to the unwrapped line. The 373 | # more line splits added the higher the penalty. 374 | split_penalty_for_added_line_split=30 375 | 376 | # The penalty of splitting a list of "import as" names. For example: 377 | # 378 | # from a_very_long_or_indented_module_name_yada_yad import (long_argument_1, 379 | # long_argument_2, 380 | # long_argument_3) 381 | # 382 | # would reformat to something like: 383 | # 384 | # from a_very_long_or_indented_module_name_yada_yad import ( 385 | # long_argument_1, long_argument_2, long_argument_3) 386 | split_penalty_import_names=0 387 | 388 | # The penalty of splitting the line around the 'and' and 'or' 389 | # operators. 390 | split_penalty_logical_operator=300 391 | 392 | # Use the Tab character for indentation. 
393 | use_tabs=False
394 |
395 |
--------------------------------------------------------------------------------
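A quick orientation on applying this style file (this note and snippet are not part of the repository): yapf reads the settings above either from the command line (yapf --style style.cfg -i some_file.py) or programmatically. The sample string below is made up, and most yapf releases return a (formatted_source, changed) tuple from FormatCode; check your installed version.

    from yapf.yapflib.yapf_api import FormatCode

    messy = "x = { 'a':37,'b':42 }\n"
    # style_config points at the style.cfg listed above.
    formatted, changed = FormatCode(messy, style_config='style.cfg')
    print(formatted)  # -> x = {'a': 37, 'b': 42}, per the dict-delimiter settings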
/recbole/trainer/hyper_tuning.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2020/7/19 19:06
3 | # @Author : Shanlei Mu
4 | # @Email : slmu@ruc.edu.cn
5 | # @File : hyper_tuning.py
6 |
7 | """
8 | recbole.trainer.hyper_tuning
9 | ############################
10 | """
11 |
12 | from functools import partial
13 |
14 | import numpy as np
15 |
16 | from recbole.utils.utils import dict2str
17 |
18 |
19 | def _recursiveFindNodes(root, node_type='switch'):
20 |     from hyperopt.pyll.base import Apply
21 |     nodes = []
22 |     if isinstance(root, (list, tuple)):
23 |         for node in root:
24 |             nodes.extend(_recursiveFindNodes(node, node_type))
25 |     elif isinstance(root, dict):
26 |         for node in root.values():
27 |             nodes.extend(_recursiveFindNodes(node, node_type))
28 |     elif isinstance(root, Apply):
29 |         if root.name == node_type:
30 |             nodes.append(root)
31 |
32 |         for node in root.pos_args:
33 |             if node.name == node_type:
34 |                 nodes.append(node)
35 |         for _, node in root.named_args:
36 |             if node.name == node_type:
37 |                 nodes.append(node)
38 |     return nodes
39 |
40 |
41 | def _parameters(space):
42 |     # Analyze the domain instance to find parameters
43 |     parameters = {}
44 |     if isinstance(space, dict):
45 |         space = list(space.values())
46 |     for node in _recursiveFindNodes(space, 'switch'):
47 |         # Find the name of this parameter
48 |         paramNode = node.pos_args[0]
49 |         assert paramNode.name == 'hyperopt_param'
50 |         paramName = paramNode.pos_args[0].obj
51 |
52 |         # Find all possible choices for this parameter
53 |         values = [literal.obj for literal in node.pos_args[1:]]
54 |         parameters[paramName] = np.array(range(len(values)))
55 |     return parameters
56 |
57 |
58 | def _spacesize(space):
59 |     # Compute the number of possible combinations
60 |     params = _parameters(space)
61 |     return np.prod([len(values) for values in params.values()])
62 |
63 |
64 | class ExhaustiveSearchError(Exception):
65 |     r"""Raised when a search space contains stochastic symbols that exhaustive
66 |     search cannot enumerate.
67 |     """
68 |     pass
69 |
70 |
71 | def _validate_space_exhaustive_search(space):
72 |     from hyperopt.pyll.base import dfs, as_apply
73 |     from hyperopt.pyll.stochastic import implicit_stochastic_symbols
74 |     supported_stochastic_symbols = ['randint', 'quniform', 'qloguniform', 'qnormal', 'qlognormal', 'categorical']
75 |     for node in dfs(as_apply(space)):
76 |         if node.name in implicit_stochastic_symbols:
77 |             if node.name not in supported_stochastic_symbols:
78 |                 raise ExhaustiveSearchError(
79 |                     'Exhaustive search is only possible with the following stochastic symbols: '
80 |                     + ', '.join(supported_stochastic_symbols)
81 |                 )
82 |
83 |
84 | def exhaustive_search(new_ids, domain, trials, seed, nbMaxSucessiveFailures=1000):
85 |     r"""Exhaustive search strategy for HyperTuning: suggest new points while skipping
86 |     parameter combinations that previous trials already covered.
87 |     """
88 |     from hyperopt import pyll
89 |     from hyperopt.base import miscs_update_idxs_vals
90 |     # Build a hash set for previous trials
91 |     hashset = set([
92 |         hash(
93 |             frozenset([(key, value[0]) if len(value) > 0 else (key, None)
94 |                        for key, value in trial['misc']['vals'].items()])
95 |         ) for trial in trials.trials
96 |     ])
97 |
98 |     rng = np.random.RandomState(seed)
99 |     rval = []
100 |     for _, new_id in enumerate(new_ids):
101 |         newSample = False
102 |         nbSucessiveFailures = 0
103 |         while not newSample:
104 |             # -- sample new specs, idxs, vals
105 |             idxs, vals = pyll.rec_eval(domain.s_idxs_vals, memo={
106 |                 domain.s_new_ids: [new_id],
107 |                 domain.s_rng: rng,
108 |             })
109 |             new_result = domain.new_result()
110 |             new_misc = dict(tid=new_id, cmd=domain.cmd, workdir=domain.workdir)
111 |             miscs_update_idxs_vals([new_misc], idxs, vals)
112 |
113 |             # Compare with previous hashes
114 |             h = hash(frozenset([(key, value[0]) if len(value) > 0 else (key, None) for key, value in vals.items()]))
115 |             if h not in hashset:
116 |                 newSample = True
117 |             else:
118 |                 # Duplicated sample, ignore
119 |                 nbSucessiveFailures += 1
120 |
121 |             if nbSucessiveFailures > nbMaxSucessiveFailures:
122 |                 # No more samples to produce
123 |                 return []
124 |
125 |         rval.extend(trials.new_trial_docs([new_id], [None], [new_result], [new_misc]))
126 |     return rval
127 |
128 |
129 | class HyperTuning(object):
130 |     r"""HyperTuning manages the hyper-parameter tuning process of recommender system models.
131 |     Given an objective function, a parameter search space and an optimization algorithm,
132 |     HyperTuning finds the best-performing parameter combination in the space.
133 |
134 |     Note:
135 |         HyperTuning is based on hyperopt (https://github.com/hyperopt/hyperopt).
136 |
137 |         Thanks to sbrodeur for the exhaustive search code:
138 |         https://github.com/hyperopt/hyperopt/issues/200
139 |     """
140 |
141 |     def __init__(
142 |         self,
143 |         objective_function,
144 |         space=None,
145 |         params_file=None,
146 |         params_dict=None,
147 |         fixed_config_file_list=None,
148 |         algo='exhaustive',
149 |         max_evals=100
150 |     ):
151 |         self.best_score = None
152 |         self.best_params = None
153 |         self.best_test_result = None
154 |         self.params2result = {}
155 |
156 |         self.objective_function = objective_function
157 |         self.max_evals = max_evals
158 |         self.fixed_config_file_list = fixed_config_file_list
159 |         if space:
160 |             self.space = space
161 |         elif params_file:
162 |             self.space = self._build_space_from_file(params_file)
163 |         elif params_dict:
164 |             self.space = self._build_space_from_dict(params_dict)
165 |         else:
166 |             raise ValueError('At least one of `space`, `params_file` and `params_dict` should be provided.')
167 |         if isinstance(algo, str):
168 |             if algo == 'exhaustive':
169 |                 self.algo = partial(exhaustive_search, nbMaxSucessiveFailures=1000)
170 |                 self.max_evals = _spacesize(self.space)
171 |             else:
172 |                 raise ValueError('Illegal algo [{}]'.format(algo))
173 |         else:
174 |             self.algo = algo
175 |
176 |     @staticmethod
177 |     def _build_space_from_file(file):
178 |         from hyperopt import hp
179 |         space = {}
180 |         with open(file, 'r') as fp:
181 |             for line in fp:
182 |                 para_list = line.strip().split(' ')
183 |                 if len(para_list) < 3:
184 |                     continue
185 |                 para_name, para_type, para_value = para_list[0], para_list[1], "".join(para_list[2:])
186 |                 if para_type == 'choice':
187 |                     para_value = eval(para_value)  # e.g. the string "[64,128]" becomes the list [64, 128]
188 |                     space[para_name] = hp.choice(para_name, para_value)
189 |                 elif para_type == 'uniform':
190 |                     low, high = para_value.strip().split(',')
191 |                     space[para_name] = hp.uniform(para_name, float(low), float(high))
192 |                 elif para_type == 'quniform':
193 |                     low, high, q = para_value.strip().split(',')
194 |                     space[para_name] = hp.quniform(para_name, float(low), float(high), float(q))
195 |                 elif para_type == 'loguniform':
196 |                     low, high = para_value.strip().split(',')
197 |                     space[para_name] = hp.loguniform(para_name, float(low), float(high))
198 |                 else:
199 |                     raise ValueError('Illegal param type [{}]'.format(para_type))
200 |         return space
201 |
202 |     @staticmethod
203 |     def _build_space_from_dict(config_dict):
204 |         from hyperopt import hp
205 |         space = {}
206 |         for para_type in config_dict:
207 |             if para_type == 'choice':
208 |                 for para_name in config_dict['choice']:
209 |                     para_value = config_dict['choice'][para_name]
210 |                     space[para_name] = hp.choice(para_name, para_value)
211 |             elif para_type == 'uniform':
212 |                 for para_name in config_dict['uniform']:
213 |                     para_value = config_dict['uniform'][para_name]
214 |                     low = para_value[0]
215 |                     high = para_value[1]
216 |                     space[para_name] = hp.uniform(para_name, float(low), float(high))
217 |             elif para_type == 'quniform':
218 |                 for para_name in config_dict['quniform']:
219 |                     para_value = config_dict['quniform'][para_name]
220 |                     low = para_value[0]
221 |                     high = para_value[1]
222 |                     q = para_value[2]
223 |                     space[para_name] = hp.quniform(para_name, float(low), float(high), float(q))
224 |             elif para_type == 'loguniform':
225 |                 for para_name in config_dict['loguniform']:
226 |                     para_value = config_dict['loguniform'][para_name]
227 |                     low = para_value[0]
228 |                     high = para_value[1]
229 |                     space[para_name] = hp.loguniform(para_name, float(low), float(high))
230 |             else:
231 |                 raise ValueError('Illegal param type [{}]'.format(para_type))
232 |         return space
233 |
234 |     @staticmethod
235 |     def params2str(params):
236 |         r"""Convert a parameter dict to a readable string.
237 |
238 |         Args:
239 |             params (dict): parameters dict
240 |         Returns:
241 |             str: parameters string
242 |         """
243 |         params_str = ''
244 |         for param_name in params:
245 |             params_str += param_name + ':' + str(params[param_name]) + ', '
246 |         return params_str[:-2]
247 |
248 |     @staticmethod
249 |     def _print_result(result_dict: dict):
250 |         print('current best valid score: %.4f' % result_dict['best_valid_score'])
251 |         print('current best valid result:')
252 |         print(result_dict['best_valid_result'])
253 |         print('current test result:')
254 |         print(result_dict['test_result'])
255 |         print('current saved model file:')
256 |         print(result_dict['saved_model_file'])
257 |         print()
258 |
259 |     def export_result(self, output_file=None):
260 |         r"""Write the searched parameters and corresponding results to the file.
261 |
262 |         Args:
263 |             output_file (str): the output file
264 |
265 |         """
266 |         with open(output_file, 'w') as fp:
267 |             fp.write('Best parameter:\n' + dict2str(self.best_params) + '\n')
268 |             fp.write('Best param. Test result:\n' + dict2str(self.params2result[self.params2str(self.best_params)]['test_result']) + '\n')
269 |             fp.write(', '.join([str(value) for value in self.params2result[self.params2str(self.best_params)]['test_result'].values()]) + '\n')
270 |             fp.write('Best param. Saved model file:\n' + self.params2result[self.params2str(self.best_params)]['saved_model_file'] + '\n')
271 |             fp.write('Train time:\n' + str(self.params2result[self.params2str(self.best_params)]['train_time']) + '\n\n')
272 |
273 |             for params in self.params2result:
274 |                 fp.write(params + '\n')
275 |                 fp.write('Valid result:\n' + dict2str(self.params2result[params]['best_valid_result']) + '\n')
276 |                 fp.write('Test result:\n' + dict2str(self.params2result[params]['test_result']) + '\n')
277 |                 fp.write(', '.join([str(value) for value in self.params2result[params]['test_result'].values()]) + '\n')
278 |                 fp.write('Saved model file:\n' + self.params2result[params]['saved_model_file'] + '\n')
279 |                 fp.write('Train time:\n' + str(self.params2result[params]['train_time']) + '\n\n')
280 |
281 |     def trial(self, params):
282 |         r"""Given a set of parameters, return results and optimization status.
283 |
284 |         Args:
285 |             params (dict): the parameter dictionary
286 |         """
287 |         import hyperopt
288 |         config_dict = params.copy()
289 |         params_str = self.params2str(params)
290 |         print('running parameters:', config_dict)
291 |         result_dict = self.objective_function(config_dict, self.fixed_config_file_list)
292 |         self.params2result[params_str] = result_dict
293 |         score, bigger = result_dict['best_valid_score'], result_dict['valid_score_bigger']
294 |
295 |         if self.best_score is None:  # 0.0 is a valid score, so compare against None explicitly
296 |             self.best_score = score
297 |             self.best_params = params
298 |             self._print_result(result_dict)
299 |         else:
300 |             if bigger:
301 |                 if score > self.best_score:
302 |                     self.best_score = score
303 |                     self.best_params = params
304 |                     self._print_result(result_dict)
305 |             else:
306 |                 if score < self.best_score:
307 |                     self.best_score = score
308 |                     self.best_params = params
309 |                     self._print_result(result_dict)
310 |
311 |         if bigger:
312 |             score = -score  # hyperopt minimizes the loss, so negate scores that should be maximized
313 |         return {'loss': score, 'status': hyperopt.STATUS_OK}
314 |
315 |     def run(self):
316 |         r"""Begin the search for the best parameters.
317 |
318 |         """
319 |         from hyperopt import fmin
320 |         fmin(self.trial, self.space, algo=self.algo, max_evals=self.max_evals)
321 |
--------------------------------------------------------------------------------
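For orientation, a minimal sketch (not part of the repository) of how HyperTuning is typically driven. The file name model.hyper and its two parameter lines are made-up examples of the one-per-line "name type value(s)" format parsed by _build_space_from_file; objective_function is exported by recbole.quick_start, and mawu.yaml is the config file in the repository root.

    from recbole.quick_start import objective_function
    from recbole.trainer import HyperTuning

    # Hypothetical model.hyper, one "name type value(s)" triple per line:
    #   learning_rate loguniform -8,0
    #   embedding_size choice [64,128]
    hp = HyperTuning(objective_function, params_file='model.hyper',
                     fixed_config_file_list=['mawu.yaml'], algo='exhaustive')
    hp.run()                                  # 'exhaustive' evaluates every combination
    hp.export_result(output_file='hyper_result.txt')
    print('best params:', hp.best_params)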
/recbole/data/interaction.py:
--------------------------------------------------------------------------------
1 | # @Time : 2020/7/10
2 | # @Author : Yupeng Hou
3 | # @Email : houyupeng@ruc.edu.cn
4 |
5 | # UPDATE
6 | # @Time : 2020/9/15, 2020/9/16, 2020/8/12
7 | # @Author : Yupeng Hou, Yushuo Chen, Xingyu Pan
8 | # @email : houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn, panxy@ruc.edu.cn
9 |
10 | """
11 | recbole.data.interaction
12 | ############################
13 | """
14 |
15 | import numpy as np
16 | import pandas as pd
17 | import torch
18 | import torch.nn.utils.rnn as rnn_utils
19 |
20 |
21 | def _convert_to_tensor(data):
22 |     """Convert common data types (list, pandas.Series, numpy.ndarray, torch.Tensor) into torch.Tensor.
23 |
24 |     Args:
25 |         data (list, pandas.Series, numpy.ndarray, torch.Tensor): Origin data.
26 |
27 |     Returns:
28 |         torch.Tensor: Converted tensor from `data`.
29 |     """
30 |     elem = data[0]
31 |     if isinstance(elem, (float, int, np.floating, np.integer)):  # np.float was removed from NumPy; the abstract scalar types cover np.int64 etc.
32 |         new_data = torch.as_tensor(data)
33 |     elif isinstance(elem, (list, tuple, pd.Series, np.ndarray, torch.Tensor)):
34 |         seq_data = [torch.as_tensor(d) for d in data]
35 |         new_data = rnn_utils.pad_sequence(seq_data, batch_first=True)  # pad variable-length sequences into one tensor
36 |     else:
37 |         raise ValueError(f'[{type(elem)}] is not supported!')
38 |     if new_data.dtype == torch.float64:
39 |         new_data = new_data.float()
40 |     return new_data
41 |
42 |
43 | class Interaction(object):
44 |     """The basic class representing a batch of interaction records.
45 |
46 |     Note:
47 |         During training, there are no strict rules for the data in one Interaction object.
48 |
49 |         During testing, it must be guaranteed that all interaction records of one single
50 |         user do not appear in different Interaction objects, and that records of the same
51 |         user are contiguous. Meanwhile, the positive cases of one user always need to occur
52 |         **earlier** than this user's negative cases.
53 |
54 |     A correct example:
55 |         ======= ======= =======
56 |         user_id item_id label
57 |         ======= ======= =======
58 |         1       2       1
59 |         1       6       1
60 |         1       3       1
61 |         1       1       0
62 |         2       3       1
63 |         ...     ...     ...
64 |         ======= ======= =======
65 |
66 |     Some wrong examples for Interaction objects used in testing:
67 |
68 |     1.
69 |         ======= ======= ======= ============
70 |         user_id item_id label
71 |         ======= ======= ======= ============
72 |         1       2       1
73 |         1       6       0       # positive cases of one user always need to
74 |                                 # occur earlier than this user's
75 |                                 # negative cases
76 |         1       3       1
77 |         1       1       0
78 |         2       3       1
79 |         ...     ...     ...
80 |         ======= ======= ======= ============
81 |
82 |     2.
83 |         ======= ======= ======= ========
84 |         user_id item_id label
85 |         ======= ======= ======= ========
86 |         1       2       1
87 |         1       6       1
88 |         1       3       1
89 |         2       3       1       # records of the same user should be contiguous
90 |         1       1       0
91 |         ...     ...     ...
92 |         ======= ======= ======= ========
93 |
94 |     Attributes:
95 |         interaction (dict or pandas.DataFrame): keys are meaningful strings (also called field names),
96 |             and values are torch.Tensor or numpy.ndarray with shape (batch_size, \\*).
97 |     """
98 |
99 |     def __init__(self, interaction):
100 |         self.interaction = dict()
101 |         if isinstance(interaction, dict):
102 |             for key, value in interaction.items():
103 |                 if isinstance(value, (list, np.ndarray)):
104 |                     self.interaction[key] = _convert_to_tensor(value)
105 |                 elif isinstance(value, torch.Tensor):
106 |                     self.interaction[key] = value
107 |                 else:
108 |                     raise ValueError(f'The type of {key} [{type(value)}] is not supported!')
109 |         elif isinstance(interaction, pd.DataFrame):
110 |             for key in interaction:
111 |                 value = interaction[key].values
112 |                 self.interaction[key] = _convert_to_tensor(value)
113 |         else:
114 |             raise ValueError(f'[{type(interaction)}] is not supported for initializing `Interaction`!')
115 |         self.length = -1
116 |         for k in self.interaction:
117 |             self.length = max(self.length, self.interaction[k].unsqueeze(-1).shape[0])  # batch size = largest first dimension
118 |
119 |     def __iter__(self):
120 |         return self.interaction.__iter__()
121 |
122 |     def __getattr__(self, item):
123 |         if 'interaction' not in self.__dict__:
124 |             raise AttributeError("'Interaction' object has no attribute 'interaction'")
125 |         if item in self.interaction:
126 |             return self.interaction[item]
127 |         raise AttributeError(f"'Interaction' object has no attribute '{item}'")
128 |
129 |     def __getitem__(self, index):
130 |         if isinstance(index, str):
131 |             return self.interaction[index]
132 |         else:
133 |             ret = {}
134 |             for k in self.interaction:
135 |                 ret[k] = self.interaction[k][index]
136 |             return Interaction(ret)
137 |
138 |     def __setitem__(self, key, value):
139 |         if not isinstance(key, str):
140 |             raise KeyError(f'{type(key)} object does not support item assignment')
141 |         self.interaction[key] = value
142 |
143 |     def __delitem__(self, key):
144 |         if key not in self.interaction:
145 |             raise KeyError(f'[{key}] does not exist in this interaction')
146 |         del self.interaction[key]
147 |
148 |     def __contains__(self, item):
149 |         return item in self.interaction
150 |
151 |     def __len__(self):
152 |         return self.length
153 |
154 |     def __str__(self):
155 |         info = [f'The batch_size of interaction: {self.length}']
156 |         for k in self.interaction:
157 |             inter = self.interaction[k]
158 |             temp_str = f"    {k}, {inter.shape}, {inter.device.type}, {inter.dtype}"
159 |             info.append(temp_str)
160 |         info.append('\n')
161 |         return '\n'.join(info)
162 |
163 |     def __repr__(self):
164 |         return self.__str__()
165 |
166 |     @property
167 |     def columns(self):
168 |         """
169 |         Returns:
170 |             list of str: The columns of interaction.
171 |         """
172 |         return list(self.interaction.keys())
173 |
174 |     def to(self, device, selected_field=None):
175 |         """Transfer Tensors in this Interaction object to the specified device.
176 |
177 |         Args:
178 |             device (torch.device): target device.
179 |             selected_field (str or iterable object, optional): if specified, only Tensors
180 |             with keys in selected_field will be sent to device.
181 |
182 |         Returns:
183 |             Interaction: a copied Interaction object whose Tensors have been sent to
184 |             the specified device.
185 |         """
186 |         ret = {}
187 |         if isinstance(selected_field, str):
188 |             selected_field = [selected_field]
189 |
190 |         if selected_field is not None:
191 |             selected_field = set(selected_field)
192 |             for k in self.interaction:
193 |                 if k in selected_field:
194 |                     ret[k] = self.interaction[k].to(device)
195 |                 else:
196 |                     ret[k] = self.interaction[k]
197 |         else:
198 |             for k in self.interaction:
199 |                 ret[k] = self.interaction[k].to(device)
200 |         return Interaction(ret)
201 |
202 |     def cpu(self):
203 |         """Transfer Tensors in this Interaction object to cpu.
204 |
205 |         Returns:
206 |             Interaction: a copied Interaction object whose Tensors have been sent to cpu.
207 |         """
208 |         ret = {}
209 |         for k in self.interaction:
210 |             ret[k] = self.interaction[k].cpu()
211 |         return Interaction(ret)
212 |
213 |     def numpy(self):
214 |         """Transfer Tensors to numpy arrays.
215 |
216 |         Returns:
217 |             dict: keys are the same as in the Interaction object, and values are the
218 |             corresponding numpy arrays converted from Tensors.
219 |         """
220 |         ret = {}
221 |         for k in self.interaction:
222 |             ret[k] = self.interaction[k].numpy()
223 |         return ret
224 |
225 |     def repeat(self, sizes):
226 |         """Repeats each tensor along the batch dim.
227 |
228 |         Args:
229 |             sizes (int): repeat times.
230 |
231 |         Example:
232 |             >>> a = Interaction({'k': torch.zeros(4)})
233 |             >>> a.repeat(3)
234 |             The batch_size of interaction: 12
235 |                 k, torch.Size([12]), cpu, torch.float32
236 |
237 |             >>> a = Interaction({'k': torch.zeros(4, 7)})
238 |             >>> a.repeat(3)
239 |             The batch_size of interaction: 12
240 |                 k, torch.Size([12, 7]), cpu, torch.float32
241 |
242 |         Returns:
243 |             Interaction: a copied Interaction object with repeated Tensors.
244 |         """
245 |         ret = {}
246 |         for k in self.interaction:
247 |             ret[k] = self.interaction[k].repeat([sizes] + [1] * (len(self.interaction[k].shape) - 1))  # repeat only the batch dimension
248 |         return Interaction(ret)
249 |
250 |     def repeat_interleave(self, repeats, dim=0):
251 |         """Similar to repeat_interleave of PyTorch.
252 |
253 |         Details can be found in:
254 |
255 |             https://pytorch.org/docs/stable/tensors.html?highlight=repeat#torch.Tensor.repeat_interleave
256 |
257 |         Note:
258 |             ``torch.repeat_interleave()`` is supported in PyTorch >= 1.2.0.
259 |         """
260 |         ret = {}
261 |         for k in self.interaction:
262 |             ret[k] = self.interaction[k].repeat_interleave(repeats, dim=dim)
263 |         return Interaction(ret)
264 |
265 |     def update(self, new_inter):
266 |         """Similar to ``dict.update()``.
267 |
268 |         Args:
269 |             new_inter (Interaction): current interaction will be updated by new_inter.
270 |         """
271 |         for k in new_inter.interaction:
272 |             self.interaction[k] = new_inter.interaction[k]
273 |
274 |     def drop(self, column):
275 |         """Drop column in interaction.
276 |
277 |         Args:
278 |             column (str): the column to be dropped.
279 |         """
280 |         if column not in self.interaction:
281 |             raise ValueError(f'Column [{column}] is not in [{self}].')
282 |         del self.interaction[column]
283 |
284 |     def _reindex(self, index):
285 |         """Reset the index of interaction inplace.
286 |
287 |         Args:
288 |             index: the new index of current interaction.
289 |         """
290 |         for k in self.interaction:
291 |             self.interaction[k] = self.interaction[k][index]
292 |
293 |     def shuffle(self):
294 |         """Shuffle current interaction inplace.
295 |         """
296 |         index = torch.randperm(self.length)
297 |         self._reindex(index)
298 |
299 |     def sort(self, by, ascending=True):
300 |         """Sort the current interaction inplace.
301 |
302 |         Args:
303 |             by (str or list of str): Field(s) used as the key(s) in the sorting process.
304 |             ascending (bool or list of bool, optional): Results are ascending if ``True``, otherwise descending.
305 |                 Defaults to ``True``.
306 |         """
307 |         if isinstance(by, str):
308 |             if by not in self.interaction:
309 |                 raise ValueError(f'[{by}] does not exist in interaction [{self}].')
310 |             by = [by]
311 |         elif isinstance(by, (list, tuple)):
312 |             for b in by:
313 |                 if b not in self.interaction:
314 |                     raise ValueError(f'[{b}] does not exist in interaction [{self}].')
315 |         else:
316 |             raise TypeError(f'Wrong type of by [{by}].')
317 |
318 |         if isinstance(ascending, bool):
319 |             ascending = [ascending]
320 |         elif isinstance(ascending, (list, tuple)):
321 |             for a in ascending:
322 |                 if not isinstance(a, bool):
323 |                     raise TypeError(f'Wrong type of ascending [{ascending}].')
324 |         else:
325 |             raise TypeError(f'Wrong type of ascending [{ascending}].')
326 |
327 |         if len(by) != len(ascending):
328 |             if len(ascending) == 1:
329 |                 ascending = ascending * len(by)
330 |             else:
331 |                 raise ValueError(f'by [{by}] and ascending [{ascending}] should have the same length.')
332 |
333 |         for b, a in zip(by[::-1], ascending[::-1]):  # sort by the least significant key first (stable sort)
334 |             index = np.argsort(self.interaction[b], kind='stable')
335 |             if not a:
336 |                 index = index[::-1].copy()  # copy() drops the negative stride, which Tensor indexing rejects
337 |             self._reindex(index)
338 |
339 |     def add_prefix(self, prefix):
340 |         """Add prefix to current interaction's columns.
341 |
342 |         Args:
343 |             prefix (str): The prefix to be added.
344 |         """
345 |         self.interaction = {prefix + key: value for key, value in self.interaction.items()}
346 |
347 |
348 | def cat_interactions(interactions):
349 |     """Concatenate a list of interactions into a single interaction.
350 |
351 |     Args:
352 |         interactions (list of :class:`Interaction`): List of interactions to be concatenated.
353 |
354 |     Returns:
355 |         :class:`Interaction`: Concatenated interaction.
356 |     """
357 |     if not isinstance(interactions, (list, tuple)):
358 |         raise TypeError(f'Interactions [{interactions}] should be list or tuple.')
359 |     if len(interactions) == 0:
360 |         raise ValueError(f'Interactions [{interactions}] should not be empty.')
361 |
362 |     columns_set = set(interactions[0].columns)
363 |     for inter in interactions:
364 |         if columns_set != set(inter.columns):
365 |             raise ValueError(f'Interactions [{interactions}] should have the same columns.')
366 |
367 |     new_inter = {col: torch.cat([inter[col] for inter in interactions]) for col in columns_set}
368 |     return Interaction(new_inter)
369 |
--------------------------------------------------------------------------------
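To close, a minimal usage sketch for Interaction (not part of the repository; the field names and values are illustrative only):

    import torch
    from recbole.data.interaction import Interaction, cat_interactions

    batch = Interaction({
        'user_id': torch.tensor([1, 1, 2]),
        'item_id': torch.tensor([2, 6, 3]),
        'label': torch.tensor([1.0, 1.0, 1.0]),
    })
    print(len(batch))          # 3, the batch size
    print(batch['user_id'])    # tensor([1, 1, 2])
    doubled = batch.repeat(2)                  # new Interaction with batch size 6
    merged = cat_interactions([batch, batch])  # concatenation along the batch dim
    batch.sort(by='item_id')                   # stable in-place sort by a field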