├── .gitignore
├── README.md
├── asset
│   └── model.jpg
├── config.py
├── data
│   ├── dataloader.py
│   ├── dataset.py
│   └── pool.py
├── evaluator.py
├── main.py
├── model
│   ├── DPGNN.py
│   ├── __init__.py
│   ├── abstract.py
│   └── layer.py
├── prop
│   ├── DPGNN.yaml
│   └── overall.yaml
├── trainer.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | target/
76 | 
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 | 
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 | 
84 | # pyenv
85 | .python-version
86 | 
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 | 
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | 
131 | # VS Code
132 | .vscode/
133 | 
134 | # Repo
135 | dataset/
136 | log/
137 | saved/
138 | *remained/
139 | wandb/
140 | nohup.out
141 | *_score/
142 | *.score
143 | assets/
144 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DPGNN
2 | This is the official PyTorch implementation for the paper:
3 | 
4 | > Modeling Two-Way Selection Preference for Person-Job Fit. RecSys 2022
5 | 
6 | ## Overview
7 | 
8 | We propose a dual-perspective graph representation learning approach to model directed interactions between candidates and jobs for **person-job fit**, named Dual-Perspective Graph Neural Network (**DPGNN**).
9 | 
10 | ![DPGNN model architecture](./asset/model.jpg)
11 | 
12 | ## Requirements
13 | 
14 | ```
15 | torch==1.10.0+cu113
16 | torch_geometric==2.0.2
17 | cudatoolkit==11.3.1
18 | ```
19 | 
20 | ## Dataset
21 | 
22 | `dataset_path` in `prop/overall.yaml` should contain the following files:
23 | 
24 | ```
25 | dataset_path/
26 | ├── data.{train_all/train_all_add/valid_g/valid_j/test_g/test_j/user_add/job_add}
27 | ├── {geek/job}.bert.npy
28 | └── {geek/job}.token
29 | ```
30 | 
31 | ### Train
32 | 
33 | ```bash
34 | python main.py
35 | ```
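36 | 
37 | Hyper-parameters can also be overridden from the command line as `key=value` pairs, which `get_arguments()` in `main.py` parses and passes to `Config` with priority over the YAML files in `prop/`. The flag values below are illustrative only:
38 | 
39 | ```bash
40 | python main.py --gpu_id=0 --learning_rate=0.001 --epochs=300
41 | ```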
42 | 
43 | ## Acknowledgement
44 | 
45 | The implementation is based on the open-source recommendation libraries [RecBole](https://github.com/RUCAIBox/RecBole) and [RecBole-PJF](https://github.com/RUCAIBox/RecBole-PJF).
46 | 
47 | Please consider citing the following papers if you use our code.
48 | 
49 | ```
50 | @inproceedings{yang2022modeling,
51 |   author = {Chen Yang and Yupeng Hou and Yang Song and Tao Zhang and Ji-Rong Wen and Wayne Xin Zhao},
52 |   title = {Modeling Two-Way Selection Preference for Person-Job Fit},
53 |   booktitle = {{RecSys}},
54 |   year = {2022}
55 | }
56 | 
57 | @inproceedings{zhao2021recbole,
58 |   title={Recbole: Towards a unified, comprehensive and efficient framework for recommendation algorithms},
59 |   author={Wayne Xin Zhao and Shanlei Mu and Yupeng Hou and Zihan Lin and Kaiyuan Li and Yushuo Chen and Yujie Lu and Hui Wang and Changxin Tian and Xingyu Pan and Yingqian Min and Zhichao Feng and Xinyan Fan and Xu Chen and Pengfei Wang and Wendi Ji and Yaliang Li and Xiaoling Wang and Ji-Rong Wen},
60 |   booktitle={{CIKM}},
61 |   year={2021}
62 | }
63 | 
64 | @article{zhao2022recbole,
65 |   title={RecBole 2.0: Towards a More Up-to-Date Recommendation Library},
66 |   author={Zhao, Wayne Xin and Hou, Yupeng and Pan, Xingyu and Yang, Chen and Zhang, Zeyu and Lin, Zihan and Zhang, Jingsen and Bian, Shuqing and Tang, Jiakai and Sun, Wenqi and others},
67 |   journal={arXiv preprint arXiv:2206.07351},
68 |   year={2022}
69 | }
70 | ```
71 | 
72 | 
--------------------------------------------------------------------------------
/asset/model.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/flust/DPGNN/26112cb2379c30c8be17ad2c3f6b878e1c5be622/asset/model.jpg
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | 
4 | import yaml
5 | import torch
6 | 
7 | 
8 | class Config(object):
9 |     """ Configurator module that loads the defined parameters.
10 |     Configurator module will first load parameters from ``prop/overall.yaml`` and ``prop/[model].yaml``,
11 |     then load parameters from ``config_dict``.
12 |     The priority order is as follows:
13 |     config dict > yaml config file
14 |     """
15 | 
16 |     def __init__(self, model, config_dict=None):
17 |         """
18 |         Args:
19 |             model (str): the model name.
20 |             config_dict (dict): the external parameter dictionary, default is None.
21 | """ 22 | self.params = self._load_parameters(model, config_dict) 23 | self._init_device() 24 | 25 | def _load_parameters(self, model, params_from_config_dict): 26 | params = {'model': model} 27 | params_from_file = self._load_config_files(model) 28 | for p in [params_from_file, params_from_config_dict]: 29 | if p is not None: 30 | params.update(p) 31 | return params 32 | 33 | def _load_config_files(self, model): 34 | yaml_loader = self._build_yaml_loader() 35 | file_config_dict = dict() 36 | file_list = [ 37 | 'prop/overall.yaml', 38 | f'prop/{model}.yaml' 39 | ] 40 | for file in file_list: 41 | with open(file, 'r', encoding='utf-8') as f: 42 | file_config_dict.update(yaml.load(f.read(), Loader=yaml_loader)) 43 | return file_config_dict 44 | 45 | def _build_yaml_loader(self): 46 | loader = yaml.FullLoader 47 | loader.add_implicit_resolver( 48 | u'tag:yaml.org,2002:float', 49 | re.compile(u'''^(?: 50 | [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)? 51 | |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+) 52 | |\\.[0-9_]+(?:[eE][-+][0-9]+)? 53 | |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]* 54 | |[-+]?\\.(?:inf|Inf|INF) 55 | |\\.(?:nan|NaN|NAN))$''', re.X), 56 | list(u'-+0123456789.')) 57 | return loader 58 | 59 | def _init_device(self): 60 | use_gpu = self.params['use_gpu'] 61 | if use_gpu: 62 | os.environ["CUDA_VISIBLE_DEVICES"] = str(self.params['gpu_id']) 63 | self.params['device'] = torch.device("cuda" if torch.cuda.is_available() and use_gpu else "cpu") 64 | 65 | def __getitem__(self, item): 66 | return self.params[item] if item in self.params else None 67 | 68 | def __contains__(self, key): 69 | if not isinstance(key, str): 70 | raise TypeError("index must be a str.") 71 | return key in self.params 72 | 73 | def __str__(self): 74 | return '\n\t'.join(['Parameters:'] + [ 75 | f'{arg}={value}' 76 | for arg, value in self.params.items() 77 | ]) 78 | 79 | def __repr__(self): 80 | return self.__str__() -------------------------------------------------------------------------------- /data/dataloader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | 3 | 4 | def construct_dataloader(config, datasets): 5 | param_list = [ 6 | [*datasets], 7 | [config['train_batch_size']] * 1 + [config['eval_batch_size']] * 4, 8 | [True, False, False, False, False], 9 | [config['num_workers']] * 5, 10 | [config['pin_memory']] * 5 11 | ] 12 | dataloaders = [ 13 | DataLoader( 14 | dataset=ds, 15 | batch_size=bs, 16 | shuffle=shuffle, 17 | num_workers=nw, 18 | pin_memory=pm 19 | ) for ds, bs, shuffle, nw, pm in zip(*param_list) 20 | ] 21 | return dataloaders 22 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from logging import getLogger 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch.utils.data import Dataset 8 | from tqdm import tqdm 9 | 10 | from utils import dynamic_load 11 | import random 12 | 13 | def create_datasets(config, pool): 14 | data_list = [] 15 | if config['pattern'] == 'geek': 16 | # train on user data 17 | data_list.extend(['train_g', 'valid_g', 'test_g']) 18 | elif config['pattern'] == 'job': 19 | # train on job data 20 | data_list.extend(['train_j', 'valid_j', 'test_j']) 21 | else: 22 | # others: train on full data 23 | # data_list.extend(['train_all', 'valid_g', 'valid_j']) 24 | data_list.extend(['train_all_add', 
'valid_g', 'valid_j'])
25 | 
26 |     # test set for geek & test set for job
27 |     data_list.extend(['test_g', 'test_j'])
28 | 
29 |     return [
30 |         dynamic_load(config, 'data.dataset', 'Dataset')(config, pool, phase)
31 |         for phase in data_list
32 |     ]
33 | 
34 | 
35 | class PJFDataset(Dataset):
36 |     def __init__(self, config, pool, phase):
37 |         super(PJFDataset, self).__init__()
38 |         self.config = config
39 |         self.phase = phase
40 |         self.logger = getLogger()
41 | 
42 |         self._init_attributes(pool)
43 |         self._load_inters()
44 | 
45 |     def _init_attributes(self, pool):
46 |         self.geek_num = pool.geek_num
47 |         self.job_num = pool.job_num
48 |         self.geek_token2id = pool.geek_token2id
49 |         self.job_token2id = pool.job_token2id
50 | 
51 |     def _load_inters(self):
52 |         filepath = os.path.join(self.config['dataset_path'], f'data.{self.phase}')
53 |         self.logger.info(f'Loading from {filepath}')
54 | 
55 |         self.geek_ids, self.job_ids, self.labels = [], [], []
56 |         with open(filepath, 'r', encoding='utf-8') as file:
57 |             for line in tqdm(file):
58 |                 # geek_token, job_token, ts, label = line.strip().split('\t')
59 |                 geek_token, job_token, label = line.strip().split('\t')[:3]
60 |                 geek_id = self.geek_token2id[geek_token]
61 |                 self.geek_ids.append(geek_id)
62 |                 job_id = self.job_token2id[job_token]
63 |                 self.job_ids.append(job_id)
64 |                 self.labels.append(int(label))
65 |         self.geek_ids = torch.LongTensor(self.geek_ids)
66 |         self.job_ids = torch.LongTensor(self.job_ids)
67 |         self.labels = torch.FloatTensor(self.labels)
68 | 
69 |     def __len__(self):
70 |         return self.labels.shape[0]
71 | 
72 |     def __getitem__(self, index):
73 |         return {
74 |             'geek_id': self.geek_ids[index],
75 |             'job_id': self.job_ids[index],
76 |             'label': self.labels[index]
77 |         }
78 | 
79 |     def __str__(self):
80 |         return '\n\t'.join([f'{self.phase} Dataset:'] + [
81 |             f'{self.labels.shape[0]} interactions'
82 |         ])
83 | 
84 |     def __repr__(self):
85 |         return self.__str__()
86 | 
87 | 
88 | class DPGNNDataset(PJFDataset):
89 |     def __init__(self, config, pool, phase):
90 |         super(DPGNNDataset, self).__init__(config, pool, phase)
91 |         self.pool = pool
92 | 
93 |     def _load_inters(self):
94 |         filepath = os.path.join(self.config['dataset_path'], f'data.{self.phase}')
95 |         self.logger.info(f'Loading from {filepath}')
96 | 
97 |         self.geek_ids, self.job_ids, self.labels = [], [], []
98 |         with open(filepath, 'r', encoding='utf-8') as file:
99 |             for line in tqdm(file):
100 |                 geek_token, job_token, label = line.strip().split('\t')[:3]
101 |                 if self.phase.startswith('train') and label[0] == '0':  # keep only positive interactions for training
102 |                     continue
103 |                 geek_id = self.geek_token2id[geek_token]
104 |                 self.geek_ids.append(geek_id)
105 |                 job_id = self.job_token2id[job_token]
106 |                 self.job_ids.append(job_id)
107 |                 self.labels.append(int(label))
108 |         self.geek_ids = torch.LongTensor(self.geek_ids)
109 |         self.job_ids = torch.LongTensor(self.job_ids)
110 |         self.labels = torch.FloatTensor(self.labels)
111 | 
112 |     def __getitem__(self, index):
113 |         # .item() turns the 0-dim id tensors into plain Python ints; tensors hash
114 |         # by identity, so indexing pool.geek2jobs / pool.job2geeks with the raw
115 |         # tensors would make the membership checks below silently never match
116 |         geek_id = self.geek_ids[index].item()
117 |         job_id = self.job_ids[index].item()
118 |         neg_geek = random.randint(1, self.geek_num - 1)
119 |         neg_job = random.randint(1, self.job_num - 1)
120 | 
121 |         while neg_job in self.pool.geek2jobs[geek_id]:
122 |             neg_job = random.randint(1, self.job_num - 1)
123 |         while neg_geek in self.pool.job2geeks[job_id]:
124 |             neg_geek = random.randint(1, self.geek_num - 1)
125 | 
126 |         return {
127 |             'geek_id': self.geek_ids[index],
128 |             'job_id': self.job_ids[index],
129 |             'neg_geek': neg_geek,
130 |             'neg_job': neg_job,
131 |             'label_pos': torch.Tensor([1]),
132 |             'label_neg': torch.Tensor([0]),
133 |             'label': self.labels[index]
134 |         }
135 | 
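136 | # Expected on-disk format of the `data.{phase}` files parsed above: one
137 | # tab-separated interaction per line (tokens below are illustrative),
138 | #
139 | #     <geek_token>	<job_token>	<label>      e.g.  geek_42	job_7	1
140 | #
141 | # where label is 0/1; extra trailing columns (e.g. a timestamp) are dropped
142 | # by the `[:3]` slice.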
--------------------------------------------------------------------------------
/data/pool.py:
--------------------------------------------------------------------------------
1 | import os
2 | from logging import getLogger
3 | from collections import defaultdict
4 | 
5 | from tqdm import tqdm
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 | import random
10 | 
11 | from scipy.sparse import coo_matrix
12 | 
13 | 
14 | class PJFPool(object):
15 |     def __init__(self, config):
16 |         self.logger = getLogger()
17 |         self.config = config
18 |         self._load_ids()
19 |         self._load_inter()
20 | 
21 |     def _load_ids(self):
22 |         for target in ['geek', 'job']:
23 |             token2id = {}
24 |             id2token = []
25 |             filepath = os.path.join(self.config['dataset_path'], f'{target}.token')
26 |             self.logger.info(f'Loading {filepath}')
27 |             with open(filepath, 'r') as file:
28 |                 for i, line in enumerate(file):
29 |                     token = line.strip()
30 |                     token2id[token] = i
31 |                     id2token.append(token)
32 |             setattr(self, f'{target}_token2id', token2id)
33 |             setattr(self, f'{target}_id2token', id2token)
34 |             setattr(self, f'{target}_num', len(id2token))
35 | 
36 |     def _load_inter(self):
37 |         self.geek2jobs = defaultdict(list)
38 |         self.job2geeks = defaultdict(list)
39 | 
40 |         with open(os.path.join(self.config['dataset_path'], 'data.train_all_add')) as data_all:
41 |             for l in tqdm(data_all):
42 |                 gid, jid, label = l.strip().split('\t')[:3]
43 |                 gid = self.geek_token2id[gid]
44 |                 jid = self.job_token2id[jid]
45 |                 self.geek2jobs[gid].append(jid)
46 |                 self.job2geeks[jid].append(gid)
47 | 
48 |     def __str__(self):
49 |         return '\n\t'.join(['Pool:'] + [
50 |             f'{self.geek_num} geeks',
51 |             f'{self.job_num} jobs'
52 |         ])
53 | 
54 |     def __repr__(self):
55 |         return self.__str__()
56 | 
57 | 
58 | class DPGNNPool(PJFPool):
59 |     def __init__(self, config):
60 |         super(DPGNNPool, self).__init__(config)
61 |         success_file = os.path.join(self.config['dataset_path'], 'data.train_all')
62 |         self.interaction_matrix = self._load_edge(success_file)
63 | 
64 |         user_add_file = os.path.join(self.config['dataset_path'], 'data.user_add')
65 |         job_add_file = os.path.join(self.config['dataset_path'], 'data.job_add')
66 | 
67 |         # add_sample_rate = config['add_sample_rate']
68 |         self.user_add_matrix = self._load_edge(user_add_file)
69 |         self.job_add_matrix = self._load_edge(job_add_file)
70 | 
71 |         if config['ADD_BERT']:
72 |             self._load_bert()
73 | 
74 |     def _load_edge(self, filepath):
75 |         self.logger.info(f'Loading from {filepath}')
76 |         self.geek_ids, self.job_ids, self.labels = [], [], []
77 |         with open(filepath, 'r', encoding='utf-8') as file:
78 |             for line in tqdm(file):
79 |                 geek_token, job_token, label = line.strip().split('\t')[:3]
80 | 
81 |                 geek_id = self.geek_token2id[geek_token]
82 |                 job_id = self.job_token2id[job_token]
83 |                 self.geek_ids.append(geek_id)
84 |                 self.job_ids.append(job_id)
85 |                 self.labels.append(int(label))
86 | 
87 |         self.geek_ids = torch.LongTensor(self.geek_ids)
88 |         self.job_ids = torch.LongTensor(self.job_ids)
89 |         self.labels = torch.FloatTensor(self.labels)
90 | 
91 |         src = self.geek_ids
92 |         tgt = self.job_ids
93 |         data = self.labels
94 |         interaction_matrix = coo_matrix((data, (src, tgt)), shape=(self.geek_num, self.job_num))
95 |         return interaction_matrix
96 | 
97 |     def _load_bert(self):
98 |         u_filepath = os.path.join(self.config['dataset_path'], 'geek.bert.npy')
99 |         self.logger.info(f'Loading from {u_filepath}')
100 |         j_filepath = 
os.path.join(self.config['dataset_path'], 'job.bert.npy') 101 | # bert_filepath = os.path.join(self.config['dataset_path'], f'data.{self.phase}.bert.npy') 102 | self.logger.info(f'Loading from {j_filepath}') 103 | 104 | u_array = np.load(u_filepath).astype(np.float64) 105 | # add padding 106 | u_array = np.vstack([u_array, np.zeros((1, u_array.shape[1]))]) 107 | 108 | j_array = np.load(j_filepath).astype(np.float64) 109 | # add padding 110 | j_array = np.vstack([j_array, np.zeros((1, j_array.shape[1]))]) 111 | 112 | self.geek_token2bertid = {} 113 | self.job_token2bertid = {} 114 | for i in range(u_array.shape[0]): 115 | self.geek_token2bertid[str(u_array[i, 0].astype(int))] = i 116 | for i in range(j_array.shape[0]): 117 | self.job_token2bertid[str(j_array[i, 0].astype(int))] = i 118 | 119 | self.u_bert_vec = torch.FloatTensor(u_array[:, 1:]) 120 | self.j_bert_vec = torch.FloatTensor(j_array[:, 1:]) 121 | -------------------------------------------------------------------------------- /evaluator.py: -------------------------------------------------------------------------------- 1 | import os 2 | from logging import getLogger 3 | 4 | import numpy as np 5 | from sklearn.metrics import roc_auc_score 6 | from sklearn.metrics import log_loss 7 | import pdb 8 | 9 | class Evaluator: 10 | def __init__(self, config): 11 | self.logger = getLogger() 12 | self.metric2func = { 13 | 'ndcg': self._calcu_nDCG, 14 | 'map': self._calcu_MAP, 15 | 'p': self._calcu_Precision, 16 | 'r': self._calcu_Recall, 17 | 'auc': roc_auc_score, 18 | 'logloss': log_loss, 19 | } 20 | self.topk = config['topk'] 21 | self.maxtopk = max(self.topk) 22 | self.precision = config['metric_decimal_place'] 23 | self.metrics = ['r@5', 'p@5', 'ndcg@5', 'mrr'] 24 | 25 | self.base = [] 26 | self.idcg = [] 27 | for i in range(self.maxtopk): 28 | self.base.append(np.log(2) / np.log(i + 2)) 29 | if i > 0: 30 | self.idcg.append(self.base[i] + self.idcg[i - 1]) 31 | else: 32 | self.idcg.append(self.base[i]) 33 | 34 | #self._load_geek2weak(config['dataset_path']) 35 | 36 | def collect(self, interaction, scores, reverse=False): 37 | uid2topk = {} 38 | scores = scores.cpu().numpy() 39 | labels = interaction['label'].numpy() 40 | if reverse: 41 | actor = 'job_id' 42 | else: 43 | actor = 'geek_id' 44 | for i, uid in enumerate(interaction[actor].numpy()): 45 | if uid not in uid2topk: 46 | uid2topk[uid] = [] 47 | uid2topk[uid].append((scores[i], labels[i])) 48 | return uid2topk 49 | 50 | def evaluate(self, uid2topk_list, group='all'): 51 | uid2topk = self._merge_uid2topk(uid2topk_list) 52 | uid2topk = self._filter_illegal(uid2topk) 53 | uid2topk = self._filter_group(uid2topk, group) 54 | result = {} 55 | result.update(self._calcu_ranking_metrics(uid2topk)) 56 | result.update(self._calcu_cls_metrics(uid2topk)) 57 | 58 | for m in result: 59 | result[m] = round(result[m], self.precision) 60 | return result, self._format_str(result) 61 | 62 | def _format_str(self, result): 63 | res = '' 64 | for metric in self.metrics: 65 | res += '\n\t{}:\t{:.4f}'.format(metric, result[metric]) 66 | return res 67 | 68 | def _calcu_ranking_metrics(self, uid2topk): 69 | result = {} 70 | for m in ['ndcg', 'map', 'p', 'r']: 71 | for k in self.topk: 72 | metric = f'{m}@{k}' 73 | if metric in self.metrics: 74 | result[metric] = self.metric2func[m](uid2topk, k) 75 | if 'mrr' in self.metrics: 76 | result['mrr'] = self._calcu_MRR(uid2topk) 77 | return result 78 | 79 | def _calcu_cls_metrics(self, uid2topk): 80 | scores, labels = self._flatten_cls_list(uid2topk) 81 | result = 
{}
82 |         for m in ['auc', 'logloss']:
83 |             if m in self.metrics:
84 |                 result[m] = self.metric2func[m](labels, scores)
85 |         if 'gauc' in self.metrics:
86 |             result['gauc'] = self._calcu_GAUC(uid2topk)
87 |         return result
88 | 
89 |     def _calcu_GAUC(self, uid2topk):
90 |         weight_sum = auc_sum = 0
91 |         for uid, lst in uid2topk.items():
92 |             score_list, lb_list = zip(*lst)
93 |             scores = np.array(score_list)
94 |             labels = np.array(lb_list)
95 |             w = len(labels)
96 |             auc = roc_auc_score(labels, scores)
97 |             weight_sum += w
98 |             auc_sum += auc * w
99 |         return float(auc_sum / weight_sum)
100 | 
101 |     def _calcu_nDCG(self, uid2topk, k):
102 |         tot = 0
103 |         for uid in uid2topk:
104 |             dcg = 0
105 |             pos = 0
106 |             for i, (score, lb) in enumerate(uid2topk[uid][:k]):
107 |                 dcg += lb * self.base[i]
108 |                 pos += lb
109 |             tot += dcg / self.idcg[int(pos) - 1]
110 |         return tot / len(uid2topk)
111 | 
112 |     def _calcu_Precision(self, uid2topk, k):
113 |         tot = 0
114 |         for uid in uid2topk:
115 |             rec = 0
116 |             rel = 0
117 |             for i, (score, lb) in enumerate(uid2topk[uid][:k]):
118 |                 rec += 1
119 |                 rel += lb
120 |             tot += rel / rec
121 |         return tot / len(uid2topk)
122 | 
123 |     def _calcu_Recall(self, uid2topk, k):
124 |         tot = 0
125 |         for uid in uid2topk:
126 |             rec = 0
127 |             rel = 0
128 |             for i, (score, lb) in enumerate(uid2topk[uid]):
129 |                 if i < k:
130 |                     rec += lb
131 |                 rel += lb
132 |             tot += rec / rel
133 |         return tot / len(uid2topk)
134 | 
135 |     def _calcu_MRR(self, uid2topk):
136 |         tot = 0
137 |         for uid in uid2topk:
138 |             for i, (score, lb) in enumerate(uid2topk[uid]):
139 |                 if lb == 1:
140 |                     tot += 1 / (i + 1)
141 |                     break
142 |         return tot / len(uid2topk)
143 | 
144 |     def _calcu_MAP(self, uid2topk, k):
145 |         tot = 0
146 |         for uid in uid2topk:
147 |             tp = 0
148 |             pos = 0
149 |             ap = 0
150 |             for i, (score, lb) in enumerate(uid2topk[uid][:k]):
151 |                 if lb == 1:
152 |                     tp += 1
153 |                     pos += 1
154 |                     ap += tp / (i + 1)
155 |             if pos > 0:
156 |                 tot += ap / pos
157 |         return tot / len(uid2topk)
158 | 
159 |     def _merge_uid2topk(self, uid2topk_list):
160 |         uid2topk = {}
161 |         for single_uid2topk in uid2topk_list:
162 |             for uid in single_uid2topk:
163 |                 if uid not in uid2topk:
164 |                     uid2topk[uid] = []
165 |                 uid2topk[uid].extend(single_uid2topk[uid])
166 |         return self._sort_uid2topk(uid2topk)
167 | 
168 |     def _load_geek2weak(self, dataset_path):
169 |         self.geek2weak = []
170 |         filepath = os.path.join(dataset_path, 'geek.weak')
171 |         self.logger.info(f'Loading {filepath}')
172 |         with open(filepath, 'r') as file:
173 |             for line in file:
174 |                 token, weak = line.strip().split('\t')
175 |                 self.geek2weak.append(int(weak))
176 | 
177 |     def _filter_illegal(self, uid2topk):
178 |         new_uid2topk = {}
179 |         for uid, lst in uid2topk.items():
180 |             score_list, lb_list = zip(*lst)
181 |             lb_sum = sum(lb_list)
182 |             if lb_sum == 0 or lb_sum == len(lb_list):
183 |                 continue
184 |             new_uid2topk[uid] = uid2topk[uid]
185 |         return new_uid2topk
186 | 
187 |     def _filter_group(self, uid2topk, group):
188 |         if group == 'all':
189 |             return uid2topk
190 |         elif group in ['weak', 'skilled']:
191 |             self.logger.info(f'Evaluating on [{group}]')
192 |             flag = 1 if group == 'weak' else 0
193 |             new_uid2topk = {}
194 |             # NOTE: group-wise filtering depends on self.geek2weak, whose loader
195 |             # (_load_geek2weak) is commented out in __init__; with it disabled,
196 |             # this branch returns an empty dict.
197 |             # for uid in uid2topk:
198 |             #     if abs(self.geek2weak[uid] - flag) < 0.1:
199 |             #         new_uid2topk[uid] = uid2topk[uid]
200 |             return new_uid2topk
201 |         else:
202 |             raise NotImplementedError(f'Group [{group}] is not supported')
203 | 
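204 |     # Each uid2topk value is a list of (score, label) pairs per user, e.g.
205 |     # {42: [(0.9, 1.0), (0.7, 0.0)]} (values illustrative). The double sort in
206 |     # _sort_uid2topk below is deliberate: a stable sort by score (descending) after
207 |     # a sort by label (ascending) ranks negatives above positives on score ties,
208 |     # so tied predictions are scored pessimistically.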
209 |     def _sort_uid2topk(self, uid2topk):
210 |         for uid in uid2topk:
211 |             uid2topk[uid].sort(key=lambda t: t[1], reverse=False)
212 |             uid2topk[uid].sort(key=lambda t: t[0], reverse=True)
213 |         return uid2topk
214 | 
215 |     def _flatten_cls_list(self, uid2topk):
216 |         scores = []
217 |         labels = []
218 |         for uid, lst in uid2topk.items():
219 |             score_list, lb_list = zip(*lst)
220 |             scores.append(np.array(score_list))
221 |             labels.append(np.array(lb_list))
222 |         scores = np.concatenate(scores)
223 |         labels = np.concatenate(labels)
224 |         assert scores.shape[0] == labels.shape[0]
225 |         return scores, labels
226 | 
227 | 
228 | 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from logging import getLogger
2 | import os
3 | import sys
4 | 
5 | from config import Config
6 | from data.dataset import create_datasets
7 | from data.dataloader import construct_dataloader
8 | from trainer import Trainer
9 | from utils import init_seed, init_logger, dynamic_load
10 | 
11 | def get_arguments():
12 |     args = dict()
13 |     for arg in sys.argv[1:]:
14 |         arg_name, arg_value = arg.split('=')
15 |         try:
16 |             arg_value = int(arg_value)
17 |         except ValueError:
18 |             try:
19 |                 arg_value = float(arg_value)
20 |             except ValueError:
21 |                 pass
22 |         arg_name = arg_name.strip('-')
23 |         args[arg_name] = arg_value
24 |     print(args)
25 |     return args
26 | 
27 | def main_process(model='DPGNN', config_dict=None, saved=True):
28 |     """Main process API for experiments of DPGNN
29 | 
30 |     Args:
31 |         model (str): Model name.
32 |         config_dict (dict): Parameters dictionary used to modify experiment parameters.
33 |             Defaults to ``None``.
34 |         saved (bool): Whether to save the model parameters. Defaults to ``True``.
35 |     """
36 | 
37 |     # configurations initialization
38 |     config = Config(model, config_dict=config_dict)
39 |     init_seed(config['seed'], config['reproducibility'])
40 | 
41 |     # logger initialization
42 |     init_logger(config)
43 |     logger = getLogger()
44 |     logger.info(config)
45 | 
46 |     # data preparation
47 |     pool = dynamic_load(config, 'data.pool', 'Pool')(config)
48 |     logger.info(pool)
49 | 
50 |     datasets = create_datasets(config, pool)
51 |     for ds in datasets:
52 |         logger.info(ds)
53 | 
54 |     # load dataset
55 |     train_data, valid_data_g, valid_data_j, test_data_g, test_data_j = construct_dataloader(config, datasets)
56 | 
57 |     # model loading and initialization
58 |     model = dynamic_load(config, 'model')(config, pool).to(config['device'])
59 |     logger.info(model)
60 | 
61 |     # trainer loading and initialization
62 |     trainer = Trainer(config, model)
63 | 
64 |     # model training
65 |     best_valid_score, best_valid_result_g, best_valid_result_j = trainer.fit(train_data, valid_data_g, valid_data_j, saved=saved)
66 | 
67 |     logger.info('best valid result for geek: {}'.format(best_valid_result_g))
68 |     logger.info('best valid result for job: {}'.format(best_valid_result_j))
69 | 
70 |     # model evaluation for user
71 |     test_result, test_result_str = trainer.evaluate(test_data_g, load_best_model=True)
72 |     logger.info('test for user result [all]: {}'.format(test_result_str))
73 | 
74 |     # model evaluation for job
75 |     test_result, test_result_str = trainer.evaluate(test_data_j, load_best_model=True, reverse=True)
76 |     logger.info('test for job result [all]: {}'.format(test_result_str))
77 | 
78 |     return {
79 |         'best_valid_score': best_valid_score,
80 |         'best_valid_result_g': best_valid_result_g,
81 |         'best_valid_result_j': best_valid_result_j,
82 |         'test_result': test_result
83 |     }
84 | 
85 | 
86 | if __name__ == "__main__":
87 |     args = get_arguments()
88 |     main_process(model='DPGNN', 
config_dict=args) 89 | -------------------------------------------------------------------------------- /model/DPGNN.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.nn.init import xavier_normal_ 6 | from model.abstract import PJFModel 7 | from model.layer import BPRLoss, EmbLoss 8 | from model.layer import GCNConv 9 | 10 | from torch_geometric.nn import MessagePassing 11 | from torch_geometric.utils import degree 12 | 13 | 14 | class DPGNN(PJFModel): 15 | def __init__(self, config, pool): 16 | super(DPGNN, self).__init__(config, pool) 17 | self.config = config 18 | self.pool = pool 19 | 20 | # load dataset info 21 | self.interaction_matrix = pool.interaction_matrix.astype(np.float32) 22 | self.user_add_matrix = pool.user_add_matrix.astype(np.float32) 23 | self.job_add_matrix = pool.job_add_matrix.astype(np.float32) 24 | 25 | # load parameters info 26 | self.latent_dim = config['embedding_size'] # int type:the embedding size of lightGCN 27 | self.n_layers = config['n_layers'] # int type:the layer num of lightGCN 28 | self.reg_weight = config['reg_weight'] # float32 type: the weight decay for l2 normalization 29 | self.mul_weight = config['mutual_weight'] 30 | self.temperature = config['temperature'] 31 | self.n_users = pool.geek_num 32 | self.n_items = pool.job_num 33 | 34 | # layers 35 | self.user_embedding_a = nn.Embedding(self.n_users, self.latent_dim) 36 | self.item_embedding_a = nn.Embedding(self.n_items, self.latent_dim) 37 | self.user_embedding_p = nn.Embedding(self.n_users, self.latent_dim) 38 | self.item_embedding_p = nn.Embedding(self.n_items, self.latent_dim) 39 | self.gcn_conv = GCNConv(dim=self.latent_dim) 40 | self.mf_loss = BPRLoss() 41 | self.reg_loss = EmbLoss() 42 | self.mutual_loss = nn.CrossEntropyLoss().to(self.device) 43 | self.loss = 0 44 | 45 | # bert part 46 | self.ADD_BERT = config['ADD_BERT'] 47 | self.BERT_e_size = 0 48 | if self.ADD_BERT: 49 | self.BERT_e_size = config['BERT_output_size'] 50 | self.bert_lr = nn.Linear(config['BERT_embedding_size'], 51 | self.BERT_e_size).to(self.config['device']) 52 | self._load_bert() 53 | 54 | # generate intermediate data 55 | self.edge_index, self.edge_weight = self.get_norm_adj_mat() 56 | self.edge_index = self.edge_index.to(self.device) 57 | self.edge_weight = self.edge_weight.to(self.device) 58 | 59 | self.apply(self._init_weights) 60 | 61 | def _init_weights(self, module): 62 | if isinstance(module, nn.Embedding): 63 | xavier_normal_(module.weight.data) 64 | 65 | def _load_bert(self): 66 | self.bert_user = torch.FloatTensor([]).to(self.config['device']) 67 | for i in range(self.n_users): 68 | geek_token = self.pool.geek_id2token[i] 69 | bert_id = self.pool.geek_token2bertid[geek_token] 70 | bert_u_vec = self.pool.u_bert_vec[bert_id, :].unsqueeze(0).to(self.config['device']) 71 | self.bert_user = torch.cat([self.bert_user, bert_u_vec], dim=0) 72 | del bert_u_vec 73 | 74 | self.bert_job = torch.FloatTensor([]).to(self.config['device']) 75 | for i in range(self.n_items): 76 | job_token = self.pool.job_id2token[i] 77 | bert_id = self.pool.job_token2bertid[job_token] 78 | bert_j_vec = self.pool.j_bert_vec[bert_id].unsqueeze(0).to(self.config['device']) 79 | self.bert_job = torch.cat([self.bert_job, bert_j_vec], dim=0) 80 | del bert_j_vec 81 | 82 | def get_norm_adj_mat(self): 83 | r"""Get the normalized interaction matrix of users and items. 
84 |         Construct the square matrix from the training data and symmetrically
85 |         normalize it with the node degrees.
86 |         .. math::
87 |             \hat{A} = D^{-0.5} \times A \times D^{-0.5}
88 |         Returns:
89 |             The normalized interaction matrix in Tensor.
90 |         """
91 |         # node index layout (n_all = n_users + n_items; every user/job gets an
92 |         # active and a passive copy): user active [0, n_users); item passive
93 |         # [n_users, n_all); user passive [n_all, n_all + n_users); item active
94 |         # [n_all + n_users, 2 * n_all)
95 |         n_all = self.n_users + self.n_items
96 | 
97 |         # success edge
98 |         row = torch.LongTensor(self.interaction_matrix.row)
99 |         col = torch.LongTensor(self.interaction_matrix.col) + self.n_users
100 |         edge_index1 = torch.stack([row, col])
101 |         edge_index2 = torch.stack([col, row])
102 |         edge_index3 = torch.stack([row + n_all, col + n_all])
103 |         edge_index4 = torch.stack([col + n_all, row + n_all])
104 |         edge_index_suc = torch.cat([edge_index1, edge_index2, edge_index3, edge_index4], dim=1)
105 | 
106 |         # user_add edge
107 |         row = torch.LongTensor(self.user_add_matrix.row)
108 |         col = torch.LongTensor(self.user_add_matrix.col) + self.n_users
109 |         edge_index1 = torch.stack([row, col])
110 |         edge_index2 = torch.stack([col, row])
111 |         edge_index_user_add = torch.cat([edge_index1, edge_index2], dim=1)
112 | 
113 |         # job_add edge
114 |         row = torch.LongTensor(self.job_add_matrix.row)
115 |         col = torch.LongTensor(self.job_add_matrix.col) + self.n_users
116 |         edge_index1 = torch.stack([row + n_all, col + n_all])
117 |         edge_index2 = torch.stack([col + n_all, row + n_all])
118 |         edge_index_job_add = torch.cat([edge_index1, edge_index2], dim=1)
119 | 
120 |         # self edge
121 |         geek = torch.LongTensor(torch.arange(0, self.n_users))
122 |         job = torch.LongTensor(torch.arange(0, self.n_items) + self.n_users)
123 |         edge_index_geek_1 = torch.stack([geek, geek + n_all])
124 |         edge_index_geek_2 = torch.stack([geek + n_all, geek])
125 |         edge_index_job_1 = torch.stack([job, job + n_all])
126 |         edge_index_job_2 = torch.stack([job + n_all, job])
127 |         edge_index_self = torch.cat([edge_index_geek_1, edge_index_geek_2, edge_index_job_1, edge_index_job_2], dim=1)
128 | 
129 |         # all edge
130 |         edge_index = torch.cat([edge_index_suc, edge_index_user_add, edge_index_job_add, edge_index_self], dim=1)
131 | 
132 |         deg = degree(edge_index[0], (self.n_users + self.n_items) * 2)
133 |         norm_deg = 1. / torch.sqrt(torch.where(deg == 0, torch.ones([1]), deg))
134 | 
135 |         edge_weight = norm_deg[edge_index[0]] * norm_deg[edge_index[1]]
136 | 
137 |         return edge_index, edge_weight
138 | 
139 |     def get_ego_embeddings(self):
140 |         r"""Get the embeddings of users and items and combine them into one embedding matrix.
141 | 
142 |         Returns:
143 |             Tensor of the embedding matrix, shape [2 * (n_users + n_items), embedding_size(+BERT_output_size if ADD_BERT)].
144 |         """
145 |         user_embeddings_a = self.user_embedding_a.weight
146 |         item_embeddings_a = self.item_embedding_a.weight
147 |         user_embeddings_p = self.user_embedding_p.weight
148 |         item_embeddings_p = self.item_embedding_p.weight
149 |         ego_embeddings = torch.cat([user_embeddings_a,
150 |                                     item_embeddings_p,
151 |                                     user_embeddings_p,
152 |                                     item_embeddings_a], dim=0)
153 | 
154 |         if self.ADD_BERT:
155 |             self.bert_u = self.bert_lr(self.bert_user)
156 |             self.bert_j = self.bert_lr(self.bert_job)
157 | 
158 |             bert_e = torch.cat([self.bert_u,
159 |                                 self.bert_j,
160 |                                 self.bert_u,
161 |                                 self.bert_j], dim=0)
162 |             return torch.cat([ego_embeddings, bert_e], dim=1)
163 | 
164 |         return ego_embeddings
165 | 
166 |     def info_nce_loss(self, index, is_user):
167 |         all_embeddings = self.get_ego_embeddings()
168 |         user_e_a, item_e_p, user_e_p, item_e_a = torch.split(all_embeddings,
169 |             [self.n_users, self.n_items, self.n_users, self.n_items])
170 |         if is_user:
171 |             u_e_a = F.normalize(user_e_a[index], dim=1)
172 |             u_e_p = F.normalize(user_e_p[index], dim=1)
173 |             similarity_matrix = torch.matmul(u_e_a, u_e_p.T)
174 |         else:
175 |             i_e_a = F.normalize(item_e_a[index], dim=1)
176 |             i_e_p = F.normalize(item_e_p[index], dim=1)
177 |             similarity_matrix = torch.matmul(i_e_a, i_e_p.T)
178 | 
179 |         mask = torch.eye(index.shape[0], dtype=torch.bool).to(self.device)
180 | 
181 |         positives = similarity_matrix[mask].view(index.shape[0], -1)
182 |         negatives = similarity_matrix[~mask].view(index.shape[0], -1)
183 | 
184 |         logits = torch.cat([positives, negatives], dim=1)
185 |         labels = torch.zeros(logits.shape[0], dtype=torch.long).to(self.device)
186 |         logits = logits / self.temperature
187 | 
188 |         return logits, labels
189 | 
190 |     def forward(self):
191 |         all_embeddings = self.get_ego_embeddings()
192 |         embeddings_list = [all_embeddings]
193 | 
194 |         for layer_idx in range(self.n_layers):
195 |             all_embeddings = self.gcn_conv(all_embeddings, self.edge_index, self.edge_weight)
196 |             embeddings_list.append(all_embeddings)
197 |         lightgcn_all_embeddings = torch.stack(embeddings_list, dim=1)
198 |         lightgcn_all_embeddings = torch.mean(lightgcn_all_embeddings, dim=1)
199 | 
200 |         user_e_a, item_e_p, user_e_p, item_e_a = torch.split(lightgcn_all_embeddings,
201 |             [self.n_users, self.n_items, self.n_users, self.n_items])
202 |         return user_e_a, item_e_p, user_e_p, item_e_a
203 | 
204 |     def calculate_loss(self, interaction):
205 |         user = interaction['geek_id']
206 |         item = interaction['job_id']
207 |         neg_user = interaction['neg_geek']
208 |         neg_item = interaction['neg_job']
209 | 
210 |         user_e_a, item_e_p, user_e_p, item_e_a = self.forward()
211 | 
212 |         # user active
213 |         u_e_a = user_e_a[user]
214 |         n_u_e_a = user_e_a[neg_user]
215 |         # item passive
216 |         i_e_p = item_e_p[item]
217 |         n_i_e_p = item_e_p[neg_item]
218 | 
219 |         # user passive
220 |         u_e_p = user_e_p[user]
221 |         n_u_e_p = user_e_p[neg_user]
222 |         # item active
223 |         i_e_a = item_e_a[item]
224 |         n_i_e_a = item_e_a[neg_item]
225 | 
226 |         r_pos = torch.mul(u_e_a, i_e_p).sum(dim=1)
227 |         s_pos = torch.mul(u_e_p, i_e_a).sum(dim=1)
228 | 
229 |         r_neg1 = torch.mul(u_e_a, n_i_e_p).sum(dim=1)
230 |         s_neg1 = torch.mul(u_e_p, n_i_e_a).sum(dim=1)
231 | 
232 |         r_neg2 = torch.mul(n_u_e_a, i_e_p).sum(dim=1)
233 |         s_neg2 = torch.mul(n_u_e_p, i_e_a).sum(dim=1)
234 | 
235 |         # BPR loss over the summed two-way scores: the positive pair
236 |         # (2 * r_pos + 2 * s_pos) is ranked against the pooled scores of the sampled
237 |         # negative job (r_neg1 + s_neg1) and negative user (r_neg2 + s_neg2); the
238 |         # factor 2 presumably matches the scale of the four pooled negative terms
239 | 
240 |         mf_loss_u = self.mf_loss(2 * r_pos + 2 * s_pos, r_neg1 + s_neg1 + r_neg2 + s_neg2)
241 | 
242 |         # calculate Emb Loss
243 |         u_ego_embeddings_a = self.user_embedding_a(user)
244 |         u_ego_embeddings_p = self.user_embedding_p(user)
245 |         pos_ego_embeddings_a = self.item_embedding_a(item)
246 |         pos_ego_embeddings_p = self.item_embedding_p(item)
247 |         neg_ego_embeddings_a = self.item_embedding_a(neg_item)
248 |         neg_ego_embeddings_p = self.item_embedding_p(neg_item)
249 |         neg_u_ego_embeddings_a = self.user_embedding_a(neg_user)
250 |         neg_u_ego_embeddings_p = self.user_embedding_p(neg_user)
251 | 
252 |         reg_loss = self.reg_loss(u_ego_embeddings_a, u_ego_embeddings_p,
253 |                                  pos_ego_embeddings_a, pos_ego_embeddings_p,
254 |                                  neg_ego_embeddings_a, neg_ego_embeddings_p,
255 |                                  neg_u_ego_embeddings_a, neg_u_ego_embeddings_p)
256 | 
257 |         loss = mf_loss_u + self.reg_weight * reg_loss
258 | 
259 |         logits_user, labels = self.info_nce_loss(user, is_user=True)
260 |         loss += self.mul_weight * self.mutual_loss(logits_user, labels)
261 | 
262 |         logits_job, labels = self.info_nce_loss(item, is_user=False)
263 |         loss += self.mul_weight * self.mutual_loss(logits_job, labels)
264 | 
265 |         return loss
266 | 
267 |     def predict(self, interaction):
268 |         user = interaction['geek_id']
269 |         item = interaction['job_id']
270 | 
271 |         user_e_a, item_e_p, user_e_p, item_e_a = self.forward()
272 | 
273 |         # user active
274 |         u_e_a = user_e_a[user]
275 |         # item passive
276 |         i_e_p = item_e_p[item]
277 | 
278 |         # user passive
279 |         u_e_p = user_e_p[user]
280 |         # item active
281 |         i_e_a = item_e_a[item]
282 | 
283 |         I_geek = torch.mul(u_e_a, i_e_p).sum(dim=1)
284 |         I_job = torch.mul(u_e_p, i_e_a).sum(dim=1)
285 |         # overall match score: geek-to-job preference plus job-to-geek preference
286 |         scores = I_geek + I_job
287 | 
288 |         return scores
289 | 
290 | 
291 | 
292 | 
293 | 
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | from model.DPGNN import DPGNN
2 | 
--------------------------------------------------------------------------------
/model/abstract.py:
--------------------------------------------------------------------------------
1 | from logging import getLogger
2 | 
3 | import numpy as np
4 | import torch.nn as nn
5 | 
6 | 
7 | class PJFModel(nn.Module):
8 |     r"""Base class for all Person-Job Fit models
9 |     """
10 | 
11 |     def __init__(self, config, pool):
12 |         super(PJFModel, self).__init__()
13 | 
14 |         self.logger = getLogger()
15 |         self.device = config['device']
16 | 
17 |         self.geek_num = pool.geek_num
18 |         self.job_num = pool.job_num
19 | 
20 |     def calculate_loss(self, interaction):
21 |         """Calculate the training loss for a batch of data.
22 | 
23 |         Args:
24 |             interaction (dict): the interaction batch, as a dict of tensors.
25 | 
26 |         Returns:
27 |             torch.Tensor: Training loss, shape: []
28 |         """
29 |         raise NotImplementedError
30 | 
31 |     def predict(self, interaction):
32 |         """Predict the scores between users and items.
33 | 
34 |         Args:
35 |             interaction (dict): the interaction batch, as a dict of tensors.
36 | 
37 |         Returns:
38 |             torch.Tensor: Predicted scores for given users and items, shape: [batch_size]
39 |         """
40 |         raise NotImplementedError
41 | 
42 |     def __str__(self):
43 |         """Model prints with number of trainable parameters
44 |         """
45 |         model_parameters = filter(lambda p: p.requires_grad, self.parameters())
46 |         params = sum([np.prod(p.size()) for p in model_parameters])
47 |         return super(PJFModel, self).__str__() + '\n\tTrainable parameters: {}'.format(params)
48 | 
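49 | # Contract sketch for concrete models (illustrative only; DPGNN in
50 | # model/DPGNN.py is the real implementation):
51 | #
52 | #     class MyModel(PJFModel):
53 | #         def calculate_loss(self, interaction):
54 | #             return loss        # scalar torch.Tensor, used by Trainer._train_epoch
55 | #         def predict(self, interaction):
56 | #             return scores      # shape [batch_size], consumed by Evaluator.collect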
--------------------------------------------------------------------------------
/model/layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.nn.init import normal_
6 | from torch_geometric.nn import MessagePassing
7 | from torch_geometric.utils import add_self_loops, degree, remove_self_loops, softmax
8 | from torch.nn import Linear
9 | 
10 | from typing import Union, Tuple, Optional, Any
11 | 
12 | 
13 | class MLPLayers(nn.Module):
14 |     """ MLPLayers
15 | 
16 |     Args:
17 |         - layers(list): a list containing the size of each layer in mlp layers
18 |         - dropout(float): probability of an element to be zeroed. Default: 0
19 |         - activation(str): activation function after each layer in mlp layers. Default: 'relu'.
20 |                            candidates: 'sigmoid', 'tanh', 'relu', 'leakyrelu', 'none'
21 | 
22 |     Shape:
23 | 
24 |         - Input: (:math:`N`, \*, :math:`H_{in}`) where \* means any number of additional dimensions
25 |           :math:`H_{in}` must equal the first value in `layers`
26 |         - Output: (:math:`N`, \*, :math:`H_{out}`) where :math:`H_{out}` equals the last value in `layers`
27 | 
28 |     Examples:
29 | 
30 |         >>> m = MLPLayers([64, 32, 16], 0.2, 'relu')
31 |         >>> input = torch.randn(128, 64)
32 |         >>> output = m(input)
33 |         >>> print(output.size())
34 |         torch.Size([128, 16])
35 |     """
36 | 
37 |     def __init__(self, layers, dropout=0., activation='relu', bn=False, init_method=None):
38 |         super(MLPLayers, self).__init__()
39 |         self.layers = layers
40 |         self.dropout = dropout
41 |         self.activation = activation
42 |         self.use_bn = bn
43 |         self.init_method = init_method
44 | 
45 |         mlp_modules = []
46 |         for idx, (input_size, output_size) in enumerate(zip(self.layers[:-1], self.layers[1:])):
47 |             mlp_modules.append(nn.Dropout(p=self.dropout))
48 |             mlp_modules.append(nn.Linear(input_size, output_size))
49 |             if self.use_bn:
50 |                 mlp_modules.append(nn.BatchNorm1d(num_features=output_size))
51 |             activation_func = activation_layer(self.activation, output_size)
52 |             if activation_func is not None:
53 |                 mlp_modules.append(activation_func)
54 | 
55 |         self.mlp_layers = nn.Sequential(*mlp_modules)
56 |         if self.init_method is not None:
57 |             self.apply(self.init_weights)
58 | 
59 |     def init_weights(self, module):
60 |         # We just initialize the module with normal distribution as the paper said
61 |         if isinstance(module, nn.Linear):
62 |             if self.init_method == 'norm':
63 |                 normal_(module.weight.data, 0, 0.01)
64 |             if module.bias is not None:
65 |                 module.bias.data.fill_(0.0)
66 | 
67 |     def forward(self, input_feature):
68 |         return self.mlp_layers(input_feature)
69 | 
70 | 
71 | def activation_layer(activation_name='relu', emb_dim=None):
72 |     """Construct activation layers
73 | 
74 |     Args:
75 |         activation_name: str, name of activation function
76 |         emb_dim: int, used for Dice activation
77 | 
78 |     Return:
79 |         activation: activation layer
80 |     """
81 |     if activation_name is None:
82 |         activation = None
83 |     elif isinstance(activation_name, str):
84 |         if activation_name.lower() == 'sigmoid':
85 |             activation = nn.Sigmoid()
86 |         elif activation_name.lower() == 'tanh':
87 |             activation = nn.Tanh()
88 |         elif activation_name.lower() == 'relu':
89 |             activation = nn.ReLU()
90 |         elif activation_name.lower() == 'leakyrelu':
91 |             activation = nn.LeakyReLU()
92 |         elif activation_name.lower() == 'none':
93 |             activation = None
94 |     elif issubclass(activation_name, nn.Module):
95 |         activation = activation_name()
96 |     else:
97 |         raise NotImplementedError("activation function {} is not implemented".format(activation_name))
98 | 
99 |     return activation
100 | 
101 | 
102 | class SimpleFusionLayer(nn.Module):
103 |     def __init__(self, hd_size):
104 |         super(SimpleFusionLayer, self).__init__()
105 |         self.fc = nn.Linear(hd_size * 4, hd_size)
106 | 
107 |     def forward(self, a, b):
108 |         assert a.shape == b.shape
109 |         x = torch.cat([a, b, a * b, a - b], dim=-1)
110 |         x = self.fc(x)
111 |         x = torch.tanh(x)
112 |         return x
113 | 
114 | 
115 | class FusionLayer(nn.Module):
116 |     def __init__(self, hd_size):
117 |         super(FusionLayer, self).__init__()
118 |         self.m = SimpleFusionLayer(hd_size)
119 |         self.g = nn.Sequential(
120 |             nn.Linear(hd_size * 2, 1),
121 |             nn.Sigmoid()
122 |         )
123 | 
124 |     def _single_layer(self, a, b):
125 |         ma = self.m(a, b)
126 |         x = torch.cat([a, b], dim=-1)
127 |         ga = self.g(x)
128 |         return ga * ma + (1 - ga) * a
129 | 
130 |     def forward(self, a, b):
131 |         assert a.shape == b.shape
132 |         a = self._single_layer(a, b)
133 |         b = self._single_layer(b, a)
134 |         return torch.cat([a, b], dim=-1)
135 | 
136 | 
137 | class BPRLoss(nn.Module):
138 |     def __init__(self, gamma=1e-10):
139 |         super(BPRLoss, self).__init__()
140 |         self.gamma = gamma
141 | 
142 |     def forward(self, pos_score, neg_score):
143 |         loss = -torch.log(self.gamma + torch.sigmoid(pos_score - neg_score)).mean()
144 |         return loss
145 | 
146 | 
147 | class BiBPRLoss(nn.Module):
148 |     def __init__(self, gamma=1e-10):
149 |         super(BiBPRLoss, self).__init__()
150 |         self.gamma = gamma
151 | 
152 |     def forward(self, pos_score, neg_score_g, neg_score_j, omega=1):
153 |         loss = - torch.log(self.gamma + torch.sigmoid(pos_score - neg_score_g)).mean() \
154 |                - omega * torch.log(self.gamma + torch.sigmoid(pos_score - neg_score_j)).mean()
155 |         return loss
156 | 
157 | 
158 | class HingeLoss(torch.nn.Module):
159 |     def __init__(self):
160 |         super(HingeLoss, self).__init__()
161 |         self.delta = 0.05
162 | 
163 |     def forward(self, pos_score, neg_score):
164 |         hinge_loss = torch.clamp(neg_score - pos_score + self.delta, min=0).mean()
165 |         return hinge_loss
166 | 
167 | 
168 | class EmbLoss(nn.Module):
169 |     """ EmbLoss, regularization on embeddings
170 |     """
171 | 
172 |     def __init__(self, norm=2):
173 |         super(EmbLoss, self).__init__()
174 |         self.norm = norm
175 | 
176 |     def forward(self, *embeddings):
177 |         emb_loss = torch.zeros(1).to(embeddings[-1].device)
178 |         for embedding in embeddings:
179 |             emb_loss += torch.norm(embedding, p=self.norm)
180 |         emb_loss /= embeddings[-1].shape[0]
181 |         return emb_loss
182 | 
183 | 
184 | class GCNConv(MessagePassing):
185 |     def __init__(self, dim):
186 |         super(GCNConv, self).__init__(aggr='add')
187 |         self.dim = dim
188 | 
189 |     def forward(self, x, edge_index, edge_weight):
190 |         return self.propagate(edge_index, x=x, edge_weight=edge_weight)
191 | 
192 |     def message(self, x_j, edge_weight):
193 |         return edge_weight.view(-1, 1) * x_j
194 | 
195 |     def __repr__(self):
196 |         return '{}({})'.format(self.__class__.__name__, self.dim)
197 | 
198 | 
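199 | # Minimal usage sketch for GCNConv (values illustrative): propagate 4-dim node
200 | # features over one weighted, directed edge 0 -> 1 with sum aggregation.
201 | #
202 | #     conv = GCNConv(dim=4)
203 | #     x = torch.randn(2, 4)                   # 2 nodes, 4 features each
204 | #     edge_index = torch.tensor([[0], [1]])   # a single edge from node 0 to node 1
205 | #     edge_weight = torch.tensor([0.5])
206 | #     out = conv(x, edge_index, edge_weight)  # out[1] == 0.5 * x[0]; out[0] is zeros
207 | #
208 | 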
209 | class GNNConv(MessagePassing):
210 |     def __init__(self, dim, n_geek, n_job):
211 |         super(GNNConv, self).__init__(node_dim=0, aggr='add')
212 |         self.dim = dim
213 | 
214 |     def forward(self, x, edge_index, edge_weight):
215 |         return self.propagate(edge_index, x=x, edge_weight=edge_weight)
216 | 
217 |     def message(self, x_j, edge_weight):
218 |         return edge_weight.view(-1, 1) * x_j
219 | 
220 |     def __repr__(self):
221 |         return '{}({})'.format(self.__class__.__name__, self.dim)
222 | 
223 | 
--------------------------------------------------------------------------------
/prop/DPGNN.yaml:
--------------------------------------------------------------------------------
1 | # Model
2 | embedding_size: 128
3 | n_layers: 3
4 | reg_weight: 1e-05
5 | mutual_weight: 0.05
6 | temperature: 0.2
7 | 
8 | ADD_BERT: True
9 | BERT_embedding_size: 768
10 | BERT_output_size: 32
11 | 
12 | # Training
13 | learning_rate: 0.001
14 | 
15 | # General
16 | train_batch_size: 4096
17 | eval_batch_size: 4096
18 | 
--------------------------------------------------------------------------------
/prop/overall.yaml:
--------------------------------------------------------------------------------
1 | # Device
2 | use_gpu: True
3 | gpu_id: 1
4 | 
5 | # Training
6 | learner: Adam
7 | epochs: 3000 # 300
8 | eval_step: 1
9 | stopping_step: 5
10 | clip_grad_norm: ~
11 | 
12 | # Evaluation
13 | topk: [5]
14 | valid_metric: r@5
15 | 
16 | # DataLoader
17 | num_workers: 4
18 | pin_memory: True
19 | 
20 | # General
21 | checkpoint_dir: ./saved/
22 | dataset_path: ./dataset/
23 | 
24 | loss_decimal_place: 4
25 | metric_decimal_place: 4
26 | 
27 | # Reproducibility
28 | seed: 2020
29 | reproducibility: True
30 | 
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | from logging import getLogger
3 | from time import time
4 | 
5 | import torch
6 | import torch.optim as optim
7 | from torch.nn.utils.clip_grad import clip_grad_norm_
8 | from tqdm import tqdm
9 | 
10 | from utils import ensure_dir, get_local_time, dict2device
11 | from evaluator import Evaluator
12 | 
13 | 
14 | 
15 | class Trainer(object):
16 |     """The Trainer for training and evaluation strategies.
17 | 
18 |     Initializing the Trainer needs two parameters: `config` and `model`.
19 |     - `config` records the parameter information for controlling training and evaluation,
20 |       such as `learning_rate`, `epochs`, `eval_step` and so on.
21 |     - `model` is the instantiated object of a model class.
22 | """ 23 | 24 | def __init__(self, config, model): 25 | self.config = config 26 | self.model = model 27 | 28 | self.logger = getLogger() 29 | self.learner = config['learner'].lower() 30 | self.learning_rate = config['learning_rate'] 31 | self.epochs = config['epochs'] 32 | self.eval_step = min(config['eval_step'], self.epochs) 33 | self.stopping_step = config['stopping_step'] 34 | self.clip_grad_norm = config['clip_grad_norm'] 35 | self.valid_metric = config['valid_metric'].lower() 36 | self.test_batch_size = config['eval_batch_size'] 37 | self.device = config['device'] 38 | self.checkpoint_dir = config['checkpoint_dir'] 39 | ensure_dir(self.checkpoint_dir) 40 | saved_model_file = '{}-{}.pth'.format(self.config['model'], get_local_time()) 41 | self.saved_model_file = os.path.join(self.checkpoint_dir, saved_model_file) 42 | 43 | self.start_epoch = 0 44 | self.cur_step = 0 45 | self.best_valid_score = -1 46 | self.best_valid_result = None 47 | self.train_loss_dict = dict() 48 | self.optimizer = self._build_optimizer() 49 | self.evaluator = Evaluator(config) 50 | 51 | # pa = os.path.join(self.checkpoint_dir, 'BPJFNN-Dec-27-2021_15-42-32.pth') 52 | # self.resume_checkpoint(pa) 53 | 54 | def _build_optimizer(self): 55 | """Init the Optimizer 56 | 57 | Returns: 58 | torch.optim: the optimizer 59 | """ 60 | opt2method = { 61 | 'adam': optim.Adam, 62 | 'sgd': optim.SGD, 63 | 'adagrad': optim.Adagrad, 64 | 'rmsprop': optim.RMSprop, 65 | 'sparse_adam': optim.SparseAdam 66 | } 67 | 68 | if self.learner in opt2method: 69 | optimizer = opt2method[self.learner](filter(lambda p: p.requires_grad, self.model.parameters()), lr=self.learning_rate) 70 | else: 71 | self.logger.warning('Received unrecognized optimizer, set default Adam optimizer') 72 | optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate) 73 | return optimizer 74 | 75 | def _train_epoch(self, train_data, epoch_idx): 76 | """Train the model in an epoch 77 | 78 | Args: 79 | train_data (DataLoader): The train data. 80 | epoch_idx (int): The current epoch id. 81 | 82 | Returns: 83 | float/tuple: The sum of loss returned by all batches in this epoch. If the loss in each batch contains 84 | multiple parts and the model return these multiple parts loss instead of the sum of loss, it will return a 85 | tuple which includes the sum of loss in each part. 86 | """ 87 | self.model.train() 88 | total_loss = None 89 | iter_data = ( 90 | tqdm( 91 | enumerate(train_data), 92 | total=len(train_data), 93 | desc=f"Train {epoch_idx:>5}", 94 | ) 95 | ) 96 | 97 | for batch_idx, interaction in iter_data: 98 | interaction = dict2device(interaction, self.device) 99 | self.optimizer.zero_grad() 100 | losses = self.model.calculate_loss(interaction) 101 | if isinstance(losses, tuple): 102 | loss = sum(losses) 103 | loss_tuple = tuple(per_loss.item() for per_loss in losses) 104 | total_loss = loss_tuple if total_loss is None else tuple(map(sum, zip(total_loss, loss_tuple))) 105 | else: 106 | loss = losses 107 | total_loss = losses.item() if total_loss is None else total_loss + losses.item() 108 | self._check_nan(loss) 109 | # loss.backward(retain_graph=True) 110 | loss.backward() 111 | if self.clip_grad_norm: 112 | clip_grad_norm_(self.model.parameters(), **self.clip_grad_norm) 113 | self.optimizer.step() 114 | del interaction, losses 115 | return total_loss 116 | 117 | def _valid_epoch(self, valid_data, reverse=False): 118 | """Valid the model with valid data 119 | 120 | Args: 121 | valid_data (DataLoader): the valid data. 
122 | 
123 |         Returns:
124 |             float: valid score
125 |             dict: valid result
126 |         """
127 |         valid_result, valid_result_str = self.evaluate(valid_data, load_best_model=False, reverse=reverse)
128 |         valid_score = valid_result[self.valid_metric]
129 |         return valid_score, valid_result, valid_result_str
130 | 
131 |     def _save_checkpoint(self, epoch):
132 |         """Store the model parameters information and training information.
133 | 
134 |         Args:
135 |             epoch (int): the current epoch id
136 |         """
137 |         state = {
138 |             'config': self.config,
139 |             'epoch': epoch,
140 |             'cur_step': self.cur_step,
141 |             'best_valid_score': self.best_valid_score,
142 |             'state_dict': self.model.state_dict(),
143 |             'optimizer': self.optimizer.state_dict(),
144 |         }
145 |         torch.save(state, self.saved_model_file)
146 | 
147 |     def resume_checkpoint(self, resume_file):
148 |         """Load the model parameters information and training information.
149 | 
150 |         Args:
151 |             resume_file (file): the checkpoint file
152 |         """
153 |         resume_file = str(resume_file)
154 |         checkpoint = torch.load(resume_file)
155 |         self.start_epoch = checkpoint['epoch'] + 1
156 |         self.cur_step = checkpoint['cur_step']
157 |         self.best_valid_score = checkpoint['best_valid_score']
158 | 
159 |         # load architecture params from checkpoint
160 |         if checkpoint['config']['model'].lower() != self.config['model'].lower():
161 |             self.logger.warning('Architecture configuration given in config file is different from that of checkpoint. '
162 |                                 'This may yield an exception while state_dict is being loaded.')
163 |         self.model.load_state_dict(checkpoint['state_dict'])
164 | 
165 |         # load optimizer state from checkpoint only when optimizer type is not changed
166 |         self.optimizer.load_state_dict(checkpoint['optimizer'])
167 |         message_output = 'Checkpoint loaded. Resume training from epoch {}'.format(self.start_epoch)
168 |         self.logger.info(message_output)
169 | 
170 |     def _check_nan(self, loss):
171 |         if torch.isnan(loss):
172 |             raise ValueError('Training loss is nan')
173 | 
174 |     def _generate_train_loss_output(self, epoch_idx, s_time, e_time, losses):
175 |         des = self.config['loss_decimal_place'] or 4
176 |         train_loss_output = 'epoch %d training [time: %.2fs, ' % (epoch_idx, e_time - s_time)
177 |         if isinstance(losses, tuple):
178 |             des = 'train_loss%d: %.' + str(des) + 'f'
179 |             train_loss_output += ', '.join(des % (idx + 1, loss) for idx, loss in enumerate(losses))
180 |         else:
181 |             des = '%.' + str(des) + 'f'
182 |             train_loss_output += 'train loss: ' + des % losses
183 |         return train_loss_output + ']'
184 | 
185 |     def fit(self, train_data, valid_data_g=None, valid_data_j=None, verbose=True, saved=True):
186 |         """Train the model based on the train data and the valid data.
187 | 
188 |         Args:
189 |             train_data (DataLoader): the train data
190 |             valid_data_g (DataLoader, optional): the valid data of geek, default: None.
191 |                 If it's None, early stopping is disabled.
192 |             valid_data_j (DataLoader, optional): the valid data of job, default: None.
193 |                 If it's None, early stopping is disabled.
194 |             verbose (bool, optional): whether to write training and evaluation information to logger, default: True
195 |             saved (bool, optional): whether to save the model parameters, default: True
196 | 
197 |         Returns:
198 |             (float, dict): best valid score and best valid result. 
199 |         """
200 |         if saved and self.start_epoch >= self.epochs:
201 |             self._save_checkpoint(-1)
202 | 
203 |         for epoch_idx in range(self.start_epoch, self.epochs):
204 |             # train
205 |             training_start_time = time()
206 |             train_loss = self._train_epoch(train_data, epoch_idx)
207 |             self.train_loss_dict[epoch_idx] = sum(train_loss) if isinstance(train_loss, tuple) else train_loss
208 |             training_end_time = time()
209 |             train_loss_output = \
210 |                 self._generate_train_loss_output(epoch_idx, training_start_time, training_end_time, train_loss)
211 |             if verbose:
212 |                 self.logger.info(train_loss_output)
213 | 
214 |             # eval
215 |             if self.eval_step <= 0 or not valid_data_g or not valid_data_j:
216 |                 if saved:
217 |                     self._save_checkpoint(epoch_idx)
218 |                     update_output = 'Saving current: %s' % self.saved_model_file
219 |                     if verbose:
220 |                         self.logger.info(update_output)
221 |                 continue
222 |             if (epoch_idx + 1) % self.eval_step == 0:
223 |                 valid_start_time = time()
224 | 
225 |                 valid_score_g, valid_result_g, valid_result_str_g = self._valid_epoch(valid_data_g, reverse=False)
226 |                 valid_score_j, valid_result_j, valid_result_str_j = self._valid_epoch(valid_data_j, reverse=True)  # evaluate in the job direction
227 | 
228 |                 valid_score = (valid_score_g + valid_score_j) / 2
229 | 
230 |                 self.best_valid_score, self.cur_step, stop_flag, update_flag = self._early_stopping(
231 |                     valid_score, self.best_valid_score, self.cur_step, max_step=self.stopping_step)
232 |                 valid_end_time = time()
233 |                 valid_score_output = "epoch %d evaluating [time: %.2fs, valid_score: %f]" % \
234 |                                      (epoch_idx, valid_end_time - valid_start_time, valid_score)
235 |                 valid_result_g_output = 'valid result for geek: ' + valid_result_str_g
236 |                 valid_result_j_output = 'valid result for job: ' + valid_result_str_j
237 | 
238 |                 if verbose:
239 |                     self.logger.info(valid_score_output)
240 |                     self.logger.info(valid_result_g_output)
241 |                     self.logger.info(valid_result_j_output)
242 |                 if update_flag:
243 |                     if saved:
244 |                         self._save_checkpoint(epoch_idx)
245 |                         update_output = 'Saving current best: %s' % self.saved_model_file
246 |                         if verbose:
247 |                             self.logger.info(update_output)
248 |                     self.best_valid_result_g = valid_result_g
249 |                     self.best_valid_result_j = valid_result_j
250 | 
251 |                 if stop_flag:
252 |                     stop_output = 'Finished training, best eval result in epoch %d' % \
253 |                                   (epoch_idx - self.cur_step * self.eval_step)
254 |                     if verbose:
255 |                         self.logger.info(stop_output)
256 |                     break
257 |         return self.best_valid_score, self.best_valid_result_g, self.best_valid_result_j
258 | 
259 |     def _early_stopping(self, value, best, cur_step, max_step):
260 |         """Validation-based early stopping.
261 | 
262 |         Args:
263 |             value (float): current result
264 |             best (float): best result so far
265 |             cur_step (int): the number of consecutive evaluation steps that did not exceed the best result
266 |             max_step (int): threshold of steps for stopping
267 | 
268 |         Returns:
269 |             tuple:
270 |             - float,
271 |               best result after this step
272 |             - int,
273 |               the number of consecutive steps that did not exceed the best result after this step
274 |             - bool,
275 |               whether to stop
276 |             - bool,
277 |               whether to update
278 |         """
279 |         stop_flag = False
280 |         update_flag = False
281 |         if value > best:
282 |             cur_step = 0
283 |             best = value
284 |             update_flag = True
285 |         else:
286 |             cur_step += 1
287 |             if cur_step > max_step:
288 |                 stop_flag = True
289 |         return best, cur_step, stop_flag, update_flag
290 | 
291 |     @torch.no_grad()
292 |     def evaluate(self, eval_data, load_best_model=True, model_file=None,
293 |                  save_score=False, group='all', reverse=False):
294 |         """Evaluate the model based on the eval data.
295 | 
296 |         Args:
297 |             eval_data (DataLoader): the eval data
298 |             load_best_model (bool, optional): whether to load the best model saved during training, default: True.
299 |                 It should be set to True if users want to test the model after training.
300 |             model_file (str, optional): the saved model file, default: None. If users want to test a previously
301 |                 trained model file, they can set this parameter.
302 |             save_score (bool): save a .score file to the running dir if ``True``. Defaults to ``False``.
303 |             group (str): which group to evaluate, can be ``all``, ``weak``, ``skilled``.
304 |             reverse (bool): evaluate in the job direction if ``True``. Defaults to ``False``.
305 | 
306 |         Returns:
307 |             (dict, str): the eval result (metric name -> metric value) and its printable form
308 |         """
309 |         if not eval_data:
310 |             return None, None
311 | 
312 |         if save_score:
313 |             model_name = self.config['model']  # pieces of the intended .score file name
314 |             tag = 'job' if reverse else 'user'  # (note: no .score file is written in the code below)
315 | 
316 |         if load_best_model:
317 |             if model_file:
318 |                 checkpoint_file = model_file
319 |             else:
320 |                 checkpoint_file = self.saved_model_file
321 |             checkpoint = torch.load(checkpoint_file)
322 |             self.model.load_state_dict(checkpoint['state_dict'])
323 |             message_output = 'Loading model structure and parameters from {}'.format(checkpoint_file)
324 |             self.logger.info(message_output)
325 | 
326 |         self.model.eval()
327 | 
328 |         batch_matrix_list = []
329 |         iter_data = (
330 |             tqdm(
331 |                 enumerate(eval_data),
332 |                 total=len(eval_data),
333 |                 desc="Evaluate",
334 |             )
335 |         )
336 | 
337 |         for batch_idx, batched_data in iter_data:
338 |             interaction = batched_data
339 |             scores = self.model.predict(dict2device(interaction, self.device))
340 |             batch_matrix = self.evaluator.collect(interaction, scores, reverse)
341 |             batch_matrix_list.append(batch_matrix)
342 |         result, result_str = self.evaluator.evaluate(batch_matrix_list, group)
343 | 
344 |         return result, result_str
345 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import logging
4 | import datetime
5 | import importlib
6 | 
7 | import numpy as np
8 | import torch
9 | 
10 | 
11 | def init_seed(seed, reproducibility):
12 |     """Initialize the random seeds for the random functions in numpy, torch, cuda and cudnn.
13 | 
14 |     Args:
15 |         seed (int): random seed
16 |         reproducibility (bool): whether to require reproducibility
17 |     """
18 |     random.seed(seed)
19 |     np.random.seed(seed)
20 |     torch.manual_seed(seed)
21 |     torch.cuda.manual_seed(seed)
22 |     torch.cuda.manual_seed_all(seed)
23 |     if reproducibility:
24 |         torch.backends.cudnn.benchmark = False
25 |         torch.backends.cudnn.deterministic = True
26 |     else:
27 |         torch.backends.cudnn.benchmark = True
28 |         torch.backends.cudnn.deterministic = False
29 | 
30 | 
31 | def get_local_time():
32 |     """Get the current local time.
33 | 
34 |     Returns:
35 |         str: current time, formatted as ``%b-%d-%Y_%H-%M-%S``
36 |     """
37 |     cur = datetime.datetime.now()
38 |     cur = cur.strftime('%b-%d-%Y_%H-%M-%S')
39 | 
40 |     return cur
41 | 
42 | 
43 | def init_logger(config):
44 |     """
45 |     A logger that can show a message on standard output and write it into a
46 |     log file under ``./log/`` simultaneously.
47 |     All messages that you want to log MUST be str.
48 | 
49 |     Args:
50 |         config (Config): an instance object of Config, used to record parameter information.
51 | 
52 |     Example:
53 |         >>> init_logger(config)
54 |         >>> logger = logging.getLogger()
55 |         >>> logger.debug(train_state)
56 |         >>> logger.info(train_result)
57 |     """
58 |     LOGROOT = './log/'
59 |     dir_name = os.path.dirname(LOGROOT)
60 |     ensure_dir(dir_name)
61 | 
62 |     logfilename = '{}-{}.log'.format(config['model'], get_local_time())
63 | 
64 |     logfilepath = os.path.join(LOGROOT, logfilename)
65 | 
66 |     filefmt = "%(asctime)-15s %(levelname)s %(message)s"
67 |     filedatefmt = "%a %d %b %Y %H:%M:%S"
68 |     fileformatter = logging.Formatter(filefmt, filedatefmt)
69 | 
70 |     sfmt = "%(asctime)-15s %(levelname)s %(message)s"
71 |     sdatefmt = "%d %b %H:%M"
72 |     sformatter = logging.Formatter(sfmt, sdatefmt)
73 |     if config['state'] is None or config['state'].lower() == 'info':
74 |         level = logging.INFO
75 |     elif config['state'].lower() == 'debug':
76 |         level = logging.DEBUG
77 |     elif config['state'].lower() == 'error':
78 |         level = logging.ERROR
79 |     elif config['state'].lower() == 'warning':
80 |         level = logging.WARNING
81 |     elif config['state'].lower() == 'critical':
82 |         level = logging.CRITICAL
83 |     else:
84 |         level = logging.INFO
85 |     fh = logging.FileHandler(logfilepath)
86 |     fh.setLevel(level)
87 |     fh.setFormatter(fileformatter)
88 | 
89 |     sh = logging.StreamHandler()
90 |     sh.setLevel(level)
91 |     sh.setFormatter(sformatter)
92 | 
93 |     logging.basicConfig(
94 |         level=level,
95 |         handlers=[fh, sh]
96 |     )
97 | 
98 | 
99 | def dynamic_load(config, module_path, class_name=''):
100 |     """Import ``module_path`` and return the attribute named ``config['model'] + class_name``."""
101 |     module = importlib.import_module(module_path)
102 |     return getattr(module, config['model'] + class_name)
103 | 
104 | 
105 | def ensure_dir(dir_path):
106 |     """Make sure the directory exists; if it does not exist, create it.
107 | 
108 |     Args:
109 |         dir_path (str): directory path
110 |     """
111 |     if not os.path.exists(dir_path):
112 |         os.makedirs(dir_path)
113 | 
114 | 
115 | def dict2device(dct, device):
116 |     """Move every tensor in the dict ``dct`` to ``device``."""
117 |     new_dct = {}
118 |     for k in dct:
119 |         new_dct[k] = dct[k].to(device)  # assumes every value is a tensor (or has a compatible .to())
120 |     return new_dct
121 | 
--------------------------------------------------------------------------------
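
For readers wiring these pieces together, here is a minimal end-to-end sketch. It is illustrative only: the real entry point is `main.py` (not reproduced in this excerpt), and several names are assumptions rather than facts from the code above — the trainer class being called `Trainer`, the `seed`/`reproducibility` config keys, the model constructor signature, and the dataloader variables `train_data`/`valid_data_g`/`valid_data_j`/`test_data_g`/`test_data_j`.

```python
# Hypothetical driver script, for illustration only; see main.py for the actual entry point.
import logging

from config import Config
from trainer import Trainer  # assumed class name
from utils import init_seed, init_logger, dynamic_load

config = Config(model='DPGNN')                        # merges prop/overall.yaml and prop/DPGNN.yaml
init_seed(config['seed'], config['reproducibility'])  # assumed config keys
init_logger(config)                                   # logs to stdout and ./log/DPGNN-<time>.log
logger = logging.getLogger()

# dynamic_load imports the given module and returns the attribute named
# config['model'] + class_name, here the DPGNN class from the model package.
ModelClass = dynamic_load(config, 'model')
model = ModelClass(config).to(config['device'])       # assumed constructor signature

# train_data / valid_data_g / valid_data_j / test_data_* are assumed to come
# from data/dataloader.py; their construction is omitted here.
trainer = Trainer(config, model)
best_score, best_result_g, best_result_j = trainer.fit(train_data, valid_data_g, valid_data_j)
logger.info('best valid score: %s' % best_score)

# Evaluate both directions against the best checkpoint saved during fit().
result_g, result_str_g = trainer.evaluate(test_data_g, load_best_model=True, reverse=False)
result_j, result_str_j = trainer.evaluate(test_data_j, load_best_model=True, reverse=True)
```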
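The early-stopping bookkeeping in `_early_stopping` and the "best eval result in epoch" arithmetic in `fit` are easiest to see in a short trace. A sketch, assuming `eval_step = 1` and `stopping_step` (`max_step`) `= 2`; it calls the private `_early_stopping` helper directly, purely for illustration:

```python
# Trace of _early_stopping: the best score appears at epoch 0, training stops at epoch 3.
best, cur_step = -1, 0
for epoch_idx, score in enumerate([0.30, 0.28, 0.29, 0.27]):
    best, cur_step, stop, update = trainer._early_stopping(score, best, cur_step, max_step=2)
    print(epoch_idx, best, cur_step, stop, update)
# 0 0.3 0 False True
# 1 0.3 1 False False
# 2 0.3 2 False False
# 3 0.3 3 True False
# fit() then reports the best epoch as epoch_idx - cur_step * eval_step = 3 - 3 * 1 = 0.
```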