├── .DS_Store ├── README.md ├── bprmf.sh ├── cke.sh ├── cofm.sh ├── fm.sh ├── jTransUP ├── .DS_Store ├── __init__.py ├── __pycache__ │ └── __init__.cpython-36.pyc ├── data │ ├── __pycache__ │ │ ├── drawer.cpython-36.pyc │ │ ├── load_kg_rating_data.cpython-36.pyc │ │ ├── load_triple_data.cpython-36.pyc │ │ ├── pre_ml1m.cpython-36.pyc │ │ ├── preprocess.cpython-36.pyc │ │ ├── preprocessRatings.cpython-36.pyc │ │ └── preprocessTriples.cpython-36.pyc │ ├── dbpedia_connector.py │ ├── drawer.py │ ├── load_kg_rating_data.py │ ├── load_rating_data.py │ ├── load_triple_data.py │ ├── preprocessRatings.py │ └── preprocessTriples.py ├── models │ ├── .DS_Store │ ├── CFKG.py │ ├── CKE.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── CFKG.cpython-36.pyc │ │ ├── CKE.cpython-36.pyc │ │ ├── __init__.cpython-36.pyc │ │ ├── base.cpython-36.pyc │ │ ├── bprmf.cpython-36.pyc │ │ ├── cofm.cpython-36.pyc │ │ ├── fm.cpython-36.pyc │ │ ├── item_recommendation.cpython-36.pyc │ │ ├── jTransUP.cpython-36.pyc │ │ ├── knowledgable_recommendation.cpython-36.pyc │ │ ├── knowledge_representation.cpython-36.pyc │ │ ├── rating.cpython-36.pyc │ │ ├── transD.cpython-36.pyc │ │ ├── transE.cpython-36.pyc │ │ ├── transH.cpython-36.pyc │ │ ├── transR.cpython-36.pyc │ │ ├── transUP.cpython-36.pyc │ │ └── transUP_bias.cpython-36.pyc │ ├── base.py │ ├── bprmf.py │ ├── cofm.py │ ├── fm.py │ ├── item_recommendation.py │ ├── jTransUP.py │ ├── knowledgable_recommendation.py │ ├── knowledge_representation.py │ ├── transD.py │ ├── transE.py │ ├── transH.py │ ├── transR.py │ └── transUP.py └── utils │ ├── .DS_Store │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── data.cpython-36.pyc │ ├── kg_log_parser.cpython-36.pyc │ ├── log_parser.cpython-36.pyc │ ├── loss.cpython-36.pyc │ ├── rec_log_parser.cpython-36.pyc │ ├── trainer.cpython-36.pyc │ └── visuliazer.cpython-36.pyc │ ├── data.py │ ├── evaluation.py │ ├── evaluation_onehot.py │ ├── kg_log_parser.py │ ├── loss.py │ ├── misc.py │ ├── 
rec_log_parser.py │ ├── trainer.py │ └── visuliazer.py ├── ktup.sh ├── ktup_eval.sh ├── log └── .DS_Store ├── requirements.txt ├── run_item_recommendation.py ├── run_knowledgable_recommendation.py ├── run_knowledge_representation.py ├── run_preprocess.py ├── run_test.py ├── swipe.sh ├── test.py ├── transe.sh ├── transh.sh ├── transr.sh ├── transup.sh └── tt.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TaoMiner/joint-kg-recommender/c165f12c039dc62c888f843351d178b8a94f3689/.DS_Store
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Unifying Knowledge Graph Learning and Recommendation: Towards a Better Understanding of User Preference
2 |
3 | This is the code for *Unifying Knowledge Graph Learning and Recommendation: Towards a Better Understanding of User Preference* (WWW'19), which proposes a model that jointly trains two tasks: item recommendation and KG representation learning.
4 |
5 | ## Environment
6 |
7 | Python 3.6
8 |
9 | PyTorch 0.3.x
10 |
11 | visdom, if the visualization flag is set to True.
12 |
13 | Other required packages are listed in *requirements.txt*.
14 |
15 | ## Run our code
16 |
17 | We implement our proposed TUP and KTUP as well as several baselines: BPRMF, FM, CFKG, CKE, CoFM, TransE, TransH and TransR. The models fall into three types: item recommendation, knowledge representation, and the joint model of the two tasks, which correspond to run_item_recommendation.py, run_knowledge_representation.py and run_knowledgable_recommendation.py, respectively. Each model has an example shell file showing how to run it.
18 |
19 | Taking item recommendation as an example, run a model with:
20 |
21 | `python run_item_recommendation.py -model_type REC_MODEL -dataset DATASET_NAME -data_path PATH_TO_DATASET_FOLDER -log_path PATH_TO_LOG -rec_test_files EVAL_FILENAMES -nohas_visualization`
22 |
23 | For knowledge representation:
24 |
25 | `python run_knowledge_representation.py -model_type KG_MODEL -dataset DATASET_NAME -data_path PATH_TO_DATASET_FOLDER -log_path PATH_TO_LOG -kg_test_files EVAL_FILENAMES -nohas_visualization`
26 |
27 | For the joint model:
28 |
29 | `python run_knowledgable_recommendation.py -model_type JOINT_MODEL -dataset DATASET_NAME -data_path PATH_TO_DATASET_FOLDER -log_path PATH_TO_LOG -rec_test_files REC_EVAL_FILENAMES -kg_test_files KG_EVAL_FILENAMES -nohas_visualization`
30 |
31 | We now describe the main flags: datasets, models and visualization.
32 |
33 | ## Datasets
34 |
35 | We use two datasets: movielens-1m (ml1m for short) and dbbook2014. We collect the related facts from DBpedia: every triple that directly involves an entity mapped to an item, regardless of whether the entity serves as the subject or the object. The processed datasets can be downloaded [here](https://drive.google.com/file/d/1FIbaWzP6AWUNG2-8q6SKQ3b9yTiiLvGW/view?usp=sharing); the original data remains under the authority of its respective owners.
36 |
37 | The flag '-data_path' specifies the root path of the datasets.
38 |
39 | The flag '-dataset' specifies the dataset: 'ml1m' for movielens-1m, 'dbbook2014' for dbbook2014, or the name of any other dataset folder under the data path.
40 |
41 | Thus, all dataset-related files should be placed in the folder '/PATH_TO_DATASET_FOLDER/DATASET_NAME/'.
42 |
43 | Now, we detail the required files.
44 |
45 | For item recommendation, the folder should contain the following files: **train.dat**, **u_map.dat** and **i_map.dat**. Each line in train.dat is a triple 'user_id item_id rating' separated by '\t'; u_map.dat and i_map.dat map each mapped user_id and item_id, respectively, to its original id. The evaluation files are specified by the flag '-rec_test_files', with multiple files separated by ':'. Note that the first evaluation file is used for validation.
46 |
47 | For KG representation, the files should be under the path '/PATH_TO_DATASET_FOLDER/DATASET_NAME/kg/'. The required files are **train.dat**, **e_map.dat** and **r_map.dat**. Similarly, each line in train.dat is a triple 'head_entity_id tail_entity_id relation_id' separated by '\t'; e_map.dat and r_map.dat map each mapped entity_id and relation_id, respectively, to the original entity and relation. The evaluation files are specified by the flag '-kg_test_files', with multiple files separated by ':'. Note that the first evaluation file is used for validation.
48 |
49 | The joint model requires all of the above files plus **i2kg_map.tsv**, where each line consists of the original item id, the entity title, and the original entity URI, separated by tabs.
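As a rough illustration of the file formats described above, the following sketch parses a triple file and a vocabulary map and splits an evaluation-file flag. It is not the repository's loader: the function names `read_triples`, `read_vocab` and `split_eval_files` are hypothetical, the vocabulary layout (`mapped_id<TAB>original_id` per line) is assumed from the description, and treating the third column as a float is also an assumption.

```python
import io


def read_triples(lines):
    """Parse tab-separated 'user_id item_id rating' lines into tuples."""
    triples = []
    for line in lines:
        parts = line.rstrip("\n").split("\t")
        if len(parts) != 3:
            continue  # skip malformed lines
        # Assumption: ids are integers, the rating may be fractional.
        triples.append((int(parts[0]), int(parts[1]), float(parts[2])))
    return triples


def read_vocab(lines):
    """Parse 'mapped_id<TAB>original_id' lines into {original_id: mapped_id}."""
    vocab = {}
    for line in lines:
        parts = line.rstrip("\n").split("\t")
        if len(parts) != 2:
            continue  # skip malformed lines
        vocab[parts[1]] = int(parts[0])
    return vocab


def split_eval_files(flag_value):
    """Split a ':'-separated flag value into (validation_file, other_files)."""
    names = flag_value.split(":")
    return names[0], names[1:]


# Tiny in-memory stand-ins for train.dat and u_map.dat.
train = io.StringIO("0\t10\t5\n1\t11\t4\n")
u_map = io.StringIO("0\tuser_a\n1\tuser_b\n")
print(read_triples(train))
print(read_vocab(u_map))
print(split_eval_files("valid.dat:test.dat"))
```

The same `read_triples` shape would apply to the KG files if the third column were read as an integer relation id instead of a rating.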
50 |
51 | For example, we run our KTUP by:
52 |
53 | `python run_knowledgable_recommendation.py -model_type jtransup -dataset ml1m -data_path ~/joint-kg-recommender/datasets/ -log_path ~/Github/joint-kg-recommender/log/ -rec_test_files valid.dat:test.dat -kg_test_files valid.dat:test.dat -nohas_visualization`
54 |
55 | This requires a folder '~/joint-kg-recommender/datasets/ml1m/' containing the required files:
56 |
57 | ```
58 | ml1m
59 | │   train.dat
60 | │   valid.dat
61 | │   test.dat
62 | │   u_map.dat
63 | │   i_map.dat
64 | │   i2kg_map.tsv
65 | │
66 | └───kg
67 | │   │   train.dat
68 | │   │   valid.dat
69 | │   │   test.dat
70 | │   │   e_map.dat
71 | │   │   r_map.dat
72 | ```
73 |
74 |
75 | ## Models
76 |
77 | We use the flag '-model_type' to specify the model, which must be one of: ['bprmf','fm','transe','transh','transr','cfkg','cke','cofm','transup','jtransup'].
78 |
79 | Specifically, ['bprmf','fm','transup'] are for item recommendation; 'transup' is our proposed **TUP**. ['transe','transh','transr'] are for KG representation. ['cfkg','cke','cofm','jtransup'] are joint models; 'jtransup' is our proposed **KTUP**.
80 |
81 | ### Model Specific Flags
82 |
83 | For TUP, there are two specific flags: '-num_preferences' and '-use_st_gumbel', which denote the number of user preferences and whether to use the hard strategy for preference induction, respectively.
84 |
85 | For joint models, there are also two flags: '-joint_ratio' and '-share_embeddings', which denote the ratio of training data in each batch between item recommendation and KG representation, and whether the two tasks share the embeddings of aligned items and entities. Note that the model 'cfkg' must use '-share_embeddings' because it is built on a unified graph of items and entities. The models 'cke' and 'jtransup' (KTUP) must use '-noshare_embeddings'.
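The constraints on the joint-model flags can be summarized in a small check. This is only a sketch: `check_share_embeddings` and `split_batch` are hypothetical helpers that do not exist in the repository, and the split assumes that '-joint_ratio' is the fraction of each batch given to item recommendation, which the description above does not state explicitly.

```python
# Sharing constraints as stated in the README; models not listed are unconstrained.
REQUIRED_SHARING = {"cfkg": True, "cke": False, "jtransup": False}


def check_share_embeddings(model_type, share_embeddings):
    """Return True if the sharing setting is allowed for the given joint model."""
    required = REQUIRED_SHARING.get(model_type)
    return required is None or required == share_embeddings


def split_batch(batch_size, joint_ratio):
    """Split a batch into (recommendation, KG) example counts.

    Assumption: joint_ratio is the share given to item recommendation.
    """
    if not 0.0 <= joint_ratio <= 1.0:
        raise ValueError("joint_ratio must be in [0, 1]")
    rec_size = int(round(batch_size * joint_ratio))
    return rec_size, batch_size - rec_size


print(check_share_embeddings("cfkg", True))
print(check_share_embeddings("jtransup", True))
print(split_batch(400, 0.7))
```

Keeping the constraints in a lookup table means a new joint model only needs one extra entry; whether the actual training loop divides batches this way should be checked against the code.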
86 |
87 | ### General Flags
88 |
89 | General parameters such as the optimizer or learning rate can also be set through flags, which are defined in './models/base.py'.
90 |
91 | ### Visualization
92 |
93 | We use the visdom package for visualization. To visualize the training and evaluation curves, start the visdom server (python -m visdom.server), set '-has_visualization' and, optionally, the port with '-visualization_port 8097'. You can then monitor the training and evaluation curves in a browser at "http://host_ip:8097".
94 |
95 | ## Reference
96 | If you use our code, please cite our paper:
97 | ```
98 | @inproceedings{cao2018unifying,
99 |   title={Unifying Knowledge Graph Learning and Recommendation: Towards a Better Understanding of User Preference},
100 |   author={Cao, Yixin and Wang, Xiang and He, Xiangnan and Hu, Zikun and Chua, Tat-Seng},
101 |   booktitle={WWW},
102 |   year={2019}
103 | }
104 | ```
--------------------------------------------------------------------------------
/bprmf.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=1 python run_item_recommendation.py -data_path ~/joint-kg-recommender/datasets/ -log_path ~/joint-kg-recommender/log/ -rec_test_files valid.dat:test.dat -l2_lambda 1e-5 -negtive_samples 1 -model_type bprmf -has_visualization -dataset dbbook2014 -batch_size 1024 -embedding_size 100 -learning_rate 0.005 -topn 10 -seed 3 -eval_interval_steps 5000 -training_steps 500000 -early_stopping_steps_to_wait 25000 -optimizer_type Adagrad -visualization_port 8097
2 |
--------------------------------------------------------------------------------
/cke.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=1 python run_knowledgable_recommendation.py -data_path ~/joint-kg-recommender/datasets/ -log_path ~/joint-kg-recommender/log/ -rec_test_files valid.dat:test.dat -kg_test_files
valid.dat:test.dat -l2_lambda 0 -model_type cke -has_visualization -dataset dbbook2014 -batch_size 400 -embedding_size 100 -learning_rate 0.001 -topn 10 -seed 3 -eval_interval_steps 19520 -training_steps 1952000 -early_stopping_steps_to_wait 97600 -optimizer_type Adam -joint_ratio 0.7 -noshare_embeddings -L1_flag -norm_lambda 1 -kg_lambda 1 -use_st_gumbel -load_ckpt_file tuned_dbbook2014/dbbook2014-bprmf-1540692224.ckpt:tuned_dbbook2014/dbbook2014-transr-1540701160.ckpt -visualization_port 8098 2 | 3 | -------------------------------------------------------------------------------- /cofm.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=1 python run_knowledgable_recommendation.py -data_path ~/joint-kg-recommender/datasets/ -log_path ~/joint-kg-recommender/log/ -rec_test_files valid.dat:test.dat -kg_test_files valid.dat:test.dat -l2_lambda 0 -model_type cofm -has_visualization -dataset dbbook2014 -batch_size 400 -embedding_size 100 -learning_rate 0.001 -topn 10 -seed 3 -eval_interval_steps 19520 -training_steps 1952000 -early_stopping_steps_to_wait 97600 -optimizer_type Adam -joint_ratio 0.7 -load_ckpt_file tuned_dbbook2014/dbbook2014-bprmf-1540692224.ckpt:tuned_dbbook2014/dbbook2014-transe-1540685958.ckpt -noshare_embeddings -L1_flag -norm_lambda 1 -kg_lambda 1 -visualization_port 8097 2 | -------------------------------------------------------------------------------- /fm.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=1 python run_item_recommendation.py -data_path ~/joint-kg-recommender/datasets/ -log_path ~/joint-kg-recommender/log/ -rec_test_files valid.dat:test.dat -l2_lambda 1e-5 -negtive_samples 1 -model_type fm -has_visualization -dataset dbbook2014 -batch_size 1024 -embedding_size 100 -learning_rate 0.1 -topn 10 -seed 3 -eval_interval_steps 500 -training_steps 50000 -early_stopping_steps_to_wait 2500 -optimizer_type Adagrad 
-visualization_port 8097 2 | 3 | -------------------------------------------------------------------------------- /jTransUP/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoMiner/joint-kg-recommender/c165f12c039dc62c888f843351d178b8a94f3689/jTransUP/.DS_Store -------------------------------------------------------------------------------- /jTransUP/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoMiner/joint-kg-recommender/c165f12c039dc62c888f843351d178b8a94f3689/jTransUP/__init__.py -------------------------------------------------------------------------------- /jTransUP/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoMiner/joint-kg-recommender/c165f12c039dc62c888f843351d178b8a94f3689/jTransUP/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /jTransUP/data/__pycache__/drawer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoMiner/joint-kg-recommender/c165f12c039dc62c888f843351d178b8a94f3689/jTransUP/data/__pycache__/drawer.cpython-36.pyc -------------------------------------------------------------------------------- /jTransUP/data/__pycache__/load_kg_rating_data.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoMiner/joint-kg-recommender/c165f12c039dc62c888f843351d178b8a94f3689/jTransUP/data/__pycache__/load_kg_rating_data.cpython-36.pyc -------------------------------------------------------------------------------- /jTransUP/data/__pycache__/load_triple_data.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/TaoMiner/joint-kg-recommender/c165f12c039dc62c888f843351d178b8a94f3689/jTransUP/data/__pycache__/load_triple_data.cpython-36.pyc -------------------------------------------------------------------------------- /jTransUP/data/__pycache__/pre_ml1m.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoMiner/joint-kg-recommender/c165f12c039dc62c888f843351d178b8a94f3689/jTransUP/data/__pycache__/pre_ml1m.cpython-36.pyc -------------------------------------------------------------------------------- /jTransUP/data/__pycache__/preprocess.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoMiner/joint-kg-recommender/c165f12c039dc62c888f843351d178b8a94f3689/jTransUP/data/__pycache__/preprocess.cpython-36.pyc -------------------------------------------------------------------------------- /jTransUP/data/__pycache__/preprocessRatings.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoMiner/joint-kg-recommender/c165f12c039dc62c888f843351d178b8a94f3689/jTransUP/data/__pycache__/preprocessRatings.cpython-36.pyc -------------------------------------------------------------------------------- /jTransUP/data/__pycache__/preprocessTriples.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoMiner/joint-kg-recommender/c165f12c039dc62c888f843351d178b8a94f3689/jTransUP/data/__pycache__/preprocessTriples.cpython-36.pyc -------------------------------------------------------------------------------- /jTransUP/data/dbpedia_connector.py: -------------------------------------------------------------------------------- 1 | from SPARQLWrapper import SPARQLWrapper, JSON 2 | import json 3 | import os 4 | import time 5 | 6 | def 
loadItemToKGMap(filename): 7 | with open(filename, 'r') as fin: 8 | item_to_kg_dict = {} 9 | for line in fin: 10 | line_split = line.strip().split('\t') 11 | if len(line_split) < 3 : continue 12 | item_id = line_split[0] 13 | db_uri = line_split[2] 14 | item_to_kg_dict[item_id] = db_uri 15 | return item_to_kg_dict 16 | 17 | def getHeadQuery(ent): 18 | return "SELECT * WHERE { <%s> ?p ?o }" % ent 19 | 20 | def getTailQuery(ent): 21 | return "SELECT * WHERE { ?s ?p <%s> }" % ent 22 | 23 | def cleanHeadResults(results): 24 | results_cl = [] 25 | predicate_set = set() 26 | entity_set = set() 27 | for result in results["results"]["bindings"]: 28 | # skip those non-eng 29 | if result['o']['type'] == 'literal' and 'xml:lang' in result['o'] and \ 30 | result['o']['xml:lang'] != 'en': 31 | continue 32 | 33 | if result['o']['type'] == 'uri': 34 | entity_set.add(result['o']['value']) 35 | predicate_set.add(result['p']['value']) 36 | results_cl.append(result) 37 | return results_cl, predicate_set, entity_set 38 | 39 | def cleanTailResults(results): 40 | results_cl = [] 41 | predicate_set = set() 42 | entity_set = set() 43 | for result in results["results"]["bindings"]: 44 | entity_set.add(result['s']['value']) 45 | predicate_set.add(result['p']['value']) 46 | results_cl.append(result) 47 | return results_cl, predicate_set, entity_set 48 | 49 | def downloadDBPedia(sparql, fout, entities, asTail=True): 50 | sec_to_wait = 60 51 | for ent in entities: 52 | print("downloading {} ...".format(ent)) 53 | while True: 54 | try: 55 | sparql.setQuery(getHeadQuery(ent)) 56 | head_results = sparql.query().convert() 57 | break 58 | except: 59 | print("http failure! wait %d seconds to retry..." 
% sec_to_wait) 60 | time.sleep(sec_to_wait) 61 | 62 | head_results_cl, predicate_set, entity_set = cleanHeadResults(head_results) 63 | head_json_str = json.dumps(head_results_cl) 64 | tail_json_str = json.dumps([]) # default, so the write below also works when asTail is False 65 | if asTail: 66 | while True: 67 | try: 68 | sparql.setQuery(getTailQuery(ent)) 69 | tail_results = sparql.query().convert() 70 | break 71 | except Exception: 72 | print("http failure! wait %d seconds to retry..." % sec_to_wait) 73 | time.sleep(sec_to_wait) 74 | tail_results_cl, tail_predicate_set, tail_entity_set = cleanTailResults(tail_results) 75 | tail_json_str = json.dumps(tail_results_cl) 76 | predicate_set |= tail_predicate_set 77 | entity_set |= tail_entity_set 78 | 79 | fout.write(ent + '\t' + head_json_str + '\t' + tail_json_str + '\n') 80 | print("finish! {} entities and {} predicates!".format(len(entity_set), len(predicate_set))) 81 | time.sleep(1) 82 | return entity_set, predicate_set 83 | 84 | if __name__ == "__main__": 85 | n_hop = 1 86 | 87 | sparql = SPARQLWrapper("http://dbpedia.org/sparql") 88 | sparql.setReturnFormat(JSON) 89 | 90 | item2kg_file = "/home/ethan/Github/joint-kg-recommender/datasets/ml1m/MappingMovielens2DBpedia-1.2.tsv" 91 | kg_path = "/home/ethan/Github/joint-kg-recommender/datasets/ml1m/kg/" 92 | 93 | # item2kg_file = "/home/ethan/Github/joint-kg-recommender/datasets/dbbook2014/DBbook_Items_DBpedia_mapping.tsv" 94 | # kg_path = "/home/ethan/Github/joint-kg-recommender/datasets/dbbook2014/kg/" 95 | 96 | all_predicate_set = set() 97 | all_entity_set = set() 98 | 99 | item2kg_dict = loadItemToKGMap(item2kg_file) 100 | item_entities = set(item2kg_dict.values()) 101 | 102 | all_entity_set.update(item_entities) 103 | 104 | input_entities = item_entities 105 | for i in range(n_hop): 106 | kg_file = os.path.join(kg_path, "kg_hop%d.dat" % i) 107 | with open(kg_file, 'a') as fout: 108 | entity_set, predicate_set = downloadDBPedia(sparql, fout, input_entities, asTail=True) 109 | input_entities = entity_set - all_entity_set 110 | 111 | all_predicate_set |= 
predicate_set 112 | all_entity_set |= entity_set 113 | 114 | predicate_file = os.path.join(kg_path, "predicate_vocab.dat") 115 | with open(predicate_file, 'w') as fout: 116 | for pred in all_predicate_set: 117 | fout.write(pred + '\n') 118 | 119 | entity_file = os.path.join(kg_path, "entity_vocab.dat") 120 | with open(entity_file, 'w') as fout: 121 | for ent in all_entity_set: 122 | fout.write(ent + '\n') -------------------------------------------------------------------------------- /jTransUP/data/drawer.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import pandas as pd 4 | 5 | def drawDBbook(): 6 | width = .55 # width of a bar 7 | ''' 8 | m1_t = pd.DataFrame({ 9 | 'users' : [4770, 812, 276, 123, 59], 10 | 'cfkg' : [13.17, 11.59, 9.31, 7.67, 5.54], 11 | 'cofm_s' : [13.5, 13, 9.89, 7.86, 5.62], 12 | 'cofm_r' : [13.38, 12.81, 9.95, 7.81, 5.64]}) 13 | ''' 14 | m1_t = pd.DataFrame({ 15 | 'users' : [8, 12, 16, 17], 16 | 'fm' : [3.92, 5.96, 9.88, 10.6], 17 | 'bprmf' : [4.15, 6.27, 9.82, 10.7], 18 | 'cfkg' : [3.53, 5.59, 8.42, 8], 19 | 'cke' : [3.82, 7.31, 12.23, 13.8], 20 | 'cofm_s' : [3.51, 6.51, 9.89, 10.13], 21 | 'cofm_r' : [3.62, 5.95, 9.58, 11.09], 22 | 'transup_h' : [3.80, 5.99, 9.70, 10.42], 23 | 'transup_s' : [4.32, 6.45, 10.19, 11.90], 24 | 'jtransup_h' : [4.24, 7.21, 12.24, 13.74], 25 | 'jtransup_s' : [4.22, 7.49, 12.07, 13.78], 26 | }) 27 | 28 | m1_t['fm'].plot(secondary_y=True, color = '#95e1d3', linestyle='-', marker='o') 29 | m1_t['bprmf'].plot(secondary_y=True, color = '#8b4c8c', linestyle='-', marker='^') 30 | m1_t['cfkg'].plot(secondary_y=True, color = '#0d627a', linestyle='-', marker='s') 31 | m1_t['cke'].plot(secondary_y=True, color = '#0c907d', linestyle='-', marker='s') 32 | m1_t['cofm_s'].plot(secondary_y=True, color = '#cce490', linestyle='-', marker='s') 33 | m1_t['cofm_r'].plot(secondary_y=True, color = '#f2f4b2', linestyle='-', marker='s') 34 | 
m1_t['transup_h'].plot(secondary_y=True, color = '#f57665', linestyle='-', marker='s') 35 | m1_t['transup_s'].plot(secondary_y=True, color = '#48466d', linestyle='-', marker='s') 36 | m1_t['jtransup_h'].plot(secondary_y=True, color = '#3d84a8', linestyle='-', marker='s') 37 | m1_t['jtransup_s'].plot(secondary_y=True, color = '#46cdcf', linestyle='-', marker='s') 38 | ax = m1_t['users'].plot(kind='bar', width = width, color = '#abedd8') 39 | 40 | plt.xlim([-width, len(m1_t['users'])-width]) 41 | ax.set_xticklabels(('2858', '1370', '877', '111'), rotation = 45) 42 | plt.legend() 43 | plt.show() 44 | 45 | def drawMl1m(): 46 | width = .55 # width of a bar 47 | 48 | m1_t = pd.DataFrame({ 49 | 'Ratings' : [17, 30, 50, 70, 89, 123, 174, 244, 347, 563], 50 | 'FM' : [7.55, 10.9, 14.11, 15.6, 15.88, 14.78, 13.4, 11.64, 9.86, 7.66], 51 | 'BPRMF' : [8.33, 12.4, 15.29, 16.59, 16.81, 15.42, 13.92, 11.98, 10, 7.67], 52 | 'CFKG' : [7.62, 12.3, 15.7, 15.35, 15.16, 14.23, 13.04, 11.36, 9.52, 7.68], 53 | 'CKE' : [10.76, 15.92, 19.32, 20.07, 19.63, 18.66, 16.71, 14.49, 11.67, 8.27], 54 | # 'CoFM_s' : [7.59, 11.96, 15.31, 16.09, 16.3, 15.88, 14.57, 12.67, 10.32, 7.88], 55 | 'CoFM' : [7.46, 11.63, 15.36, 16.14, 16.22, 15.67, 14.64, 12.44, 10.4, 7.91], 56 | 'TUP' : [11.66, 17.04, 20.26, 19.46, 19.56, 18.07, 15.61, 13.33, 10.9, 8.01], 57 | # 'TransUP_s' : [11.16, 16.94, 20.7, 19.48, 19.12, 18.17, 15.54, 13.27, 10.75, 7.92], 58 | # 'jTransUP_h' : [10.14, 15.81, 18.63, 19.78, 20, 18.87, 17.76, 15.22, 12.06, 8.34], 59 | 'KTUP' : [10.08, 15.78, 19.21, 20, 20.17, 19.65, 17.94, 15.18, 11.77, 8.34] 60 | }) 61 | 62 | ax1 = m1_t['FM'].plot(secondary_y=True, color = '#8293ff', linestyle='--', marker='^') 63 | m1_t['BPRMF'].plot(secondary_y=True, color = '#503bff', linestyle='--', marker='p') 64 | 65 | m1_t['CFKG'].plot(secondary_y=True, color = '#f08a5d', linestyle='-', marker='*') 66 | m1_t['CKE'].plot(secondary_y=True, color = '#b83b5e', linestyle='-', marker='x') 67 | # 
m1_t['cofm_s'].plot(secondary_y=True, color = '#cce490', linestyle='-', marker='s') 68 | m1_t['CoFM'].plot(secondary_y=True, color = '#6a2c70', linestyle='-', marker='+') 69 | m1_t['TUP'].plot(secondary_y=True, color = '#48466d', linestyle='--', marker='o') 70 | # m1_t['TransUP_s'].plot(secondary_y=True, color = '#48466d', linestyle='--', marker='o') 71 | 72 | # m1_t['jtransup_h'].plot(secondary_y=True, color = '#393e46', linestyle='-', marker='s') 73 | m1_t['KTUP'].plot(secondary_y=True, color = '#222831', linestyle='-', marker='s') 74 | ax = m1_t['Ratings'].plot(kind='bar', width = width, color = '#abedd8') 75 | 76 | ax.set_xlabel('User Group') 77 | ax.set_ylabel('# Avg. Ratings', color='#45b7b7') 78 | ax1.set_ylabel('F1 score (%)') 79 | 80 | plt.xlim([-width, len(m1_t['Ratings'])-width]) 81 | # ax.set_xticklabels(('684', '1364', '773', '585', '438', '750', '410', '517', '234', '285'), rotation = 45) 82 | plt.legend() 83 | plt.show() 84 | 85 | def drawMl1mTopnPrec(): 86 | width = .55 # width of a bar 87 | 88 | prec = pd.DataFrame({ 89 | 'Topn' : [1, 2, 5, 10, 20, 50, 100], 90 | 'FM' : [33.54, 31.68, 28.1, 24.54, 20.92, 15.89, 12.22], 91 | 'BPRMF' : [34.92, 32.88, 29.26, 25.74, 21.97, 16.64, 12.74], 92 | 'CFKG' : [32.5, 30.6, 27.52, 24.56, 21.16, 16.24, 12.39], 93 | 'CKE' : [39.11, 36.92, 34.16, 30.91, 26.60, 20.11, 14.8], 94 | 'CoFM' : [29.09, 29.13, 28.18, 26.12, 22.95, 17.57, 13.24], 95 | 'TransUP' : [38.53, 36.95, 33.64, 30.19, 26.13, 19.66, 14.54], 96 | 'jTransUP' : [40.84, 39.38, 35.55,31.62, 26.97, 19.95, 14.45] 97 | }) 98 | 99 | ax1 = prec['FM'].plot(secondary_y=False, color = '#8293ff', linestyle='--', marker='^') 100 | prec['BPRMF'].plot(secondary_y=False, color = '#503bff', linestyle='--', marker='p') 101 | 102 | prec['CFKG'].plot(secondary_y=False, color = '#f08a5d', linestyle='-', marker='*') 103 | prec['CKE'].plot(secondary_y=False, color = '#b83b5e', linestyle='-', marker='x') 104 | # m1_t['cofm_s'].plot(secondary_y=True, color = '#cce490', 
linestyle='-', marker='s') 105 | prec['CoFM'].plot(secondary_y=False, color = '#6a2c70', linestyle='-', marker='+') 106 | # m1_t['transup_h'].plot(secondary_y=True, color = '#f57665', linestyle='-', marker='s') 107 | prec['TransUP'].plot(secondary_y=False, color = '#48466d', linestyle='--', marker='o') 108 | 109 | # m1_t['jtransup_h'].plot(secondary_y=True, color = '#393e46', linestyle='-', marker='s') 110 | prec['jTransUP'].plot(secondary_y=False, color = '#222831', linestyle='-', marker='s') 111 | 112 | ax1.set_xlabel('Topn') 113 | ax1.set_ylabel('Precision@N') 114 | ax1.set_xticks([0,1,2,3,4,5,6]) 115 | ax1.set_xticklabels(('1', '2', '5', '10', '20', '50', '100')) 116 | 117 | # plt.legend() 118 | plt.show() 119 | 120 | def drawMl1mTopnRecall(): 121 | width = .55 # width of a bar 122 | recall = pd.DataFrame({ 123 | 'Topn' : [1, 2, 5, 10, 20, 50, 100], 124 | 'FM' : [1.72, 3.16, 6.63, 10.84, 17.48, 30.53, 44.32], 125 | 'BPRMF' : [1.83, 3.35, 7.15, 11.77, 18.81, 32.56, 46.7], 126 | 'CFKG' : [1.73, 3.14, 6.64, 11.29, 18.38, 32.19, 45.65], 127 | 'CKE' : [2.2, 3.94, 8.54, 14.64, 23.2, 39.41, 53.25], 128 | 'CoFM' : [1.35, 2.7, 6.44, 11.54, 19.27, 34.04, 47.94], 129 | 'TransUP' : [2.29, 4.25, 9.05, 15.08, 24.09, 40.46, 54.92], 130 | 'jTransUP' : [2.25, 4.18, 8.81, 14.61, 23.01, 38.19, 51.38] 131 | }) 132 | ax1 = recall['FM'].plot(secondary_y=False, color = '#8293ff', linestyle='--', marker='^') 133 | recall['BPRMF'].plot(secondary_y=False, color = '#503bff', linestyle='--', marker='p') 134 | 135 | recall['CFKG'].plot(secondary_y=False, color = '#f08a5d', linestyle='-', marker='*') 136 | recall['CKE'].plot(secondary_y=False, color = '#b83b5e', linestyle='-', marker='x') 137 | # m1_t['cofm_s'].plot(secondary_y=True, color = '#cce490', linestyle='-', marker='s') 138 | recall['CoFM'].plot(secondary_y=False, color = '#6a2c70', linestyle='-', marker='+') 139 | # m1_t['transup_h'].plot(secondary_y=True, color = '#f57665', linestyle='-', marker='s') 140 | 
recall['TransUP'].plot(secondary_y=False, color = '#48466d', linestyle='--', marker='o') 141 | 142 | # m1_t['jtransup_h'].plot(secondary_y=True, color = '#393e46', linestyle='-', marker='s') 143 | recall['jTransUP'].plot(secondary_y=False, color = '#222831', linestyle='-', marker='s') 144 | 145 | ax1.set_xlabel('Topn') 146 | ax1.set_ylabel('Recall@N') 147 | ax1.set_xticks([0,1,2,3,4,5,6]) 148 | ax1.set_xticklabels(('1', '2', '5', '10', '20', '50', '100')) 149 | 150 | # plt.legend() 151 | plt.show() 152 | 153 | def drawMl1mTopnF1(): 154 | width = .55 # width of a bar 155 | f1 = pd.DataFrame({ 156 | 'Topn' : [1, 2, 5, 10, 20, 50, 100], 157 | 'FM' : [3.1, 5.24, 9.17, 12.27, 15.12, 16.85, 16.07], 158 | 'BPRMF' : [3.29, 5.55, 9.78, 13.16, 16.11, 17.76, 16.8], 159 | 'CFKG' : [3.11, 5.18, 9.11, 12.59, 15.57, 17.37, 16.34], 160 | 'CKE' : [3.94, 6.49, 11.66, 16.19, 19.74, 21.48, 19.39], 161 | 'CoFM' : [2.44, 4.51, 8.97, 13.08, 16.68, 18.72, 17.43], 162 | 'TransUP' : [4.08, 6.92, 12.09, 16.32, 19.94, 21.41, 19.34], 163 | 'jTransUP' : [4.04, 6.89, 12.07, 16.34, 19.8, 21.1, 18.83] 164 | }) 165 | ax1 = f1['FM'].plot(secondary_y=False, color = '#8293ff', linestyle='--', marker='^') 166 | f1['BPRMF'].plot(secondary_y=False, color = '#503bff', linestyle='--', marker='p') 167 | 168 | f1['CFKG'].plot(secondary_y=False, color = '#f08a5d', linestyle='-', marker='*') 169 | f1['CKE'].plot(secondary_y=False, color = '#b83b5e', linestyle='-', marker='x') 170 | # m1_t['cofm_s'].plot(secondary_y=True, color = '#cce490', linestyle='-', marker='s') 171 | f1['CoFM'].plot(secondary_y=False, color = '#6a2c70', linestyle='-', marker='+') 172 | # m1_t['transup_h'].plot(secondary_y=True, color = '#f57665', linestyle='-', marker='s') 173 | f1['TransUP'].plot(secondary_y=False, color = '#48466d', linestyle='--', marker='o') 174 | 175 | # m1_t['jtransup_h'].plot(secondary_y=True, color = '#393e46', linestyle='-', marker='s') 176 | f1['jTransUP'].plot(secondary_y=False, color = '#222831', 
linestyle='-', marker='s') 177 | 178 | ax1.set_xlabel('Topn') 179 | ax1.set_ylabel('F1@N') 180 | ax1.set_xticks([0,1,2,3,4,5,6]) 181 | ax1.set_xticklabels(('1', '2', '5', '10', '20', '50', '100')) 182 | 183 | plt.legend() 184 | plt.show() 185 | 186 | def drawMl1mTopnHits(): 187 | width = .55 # width of a bar 188 | hits = pd.DataFrame({ 189 | 'Topn' : [1, 2, 5, 10, 20, 50, 100], 190 | 'CFKG' : [3.25, 11.23, 23.33, 32.08, 41.13, 53.18, 62.38], 191 | 'CKE' : [1.66, 6.9, 17.15, 25.01, 33.28, 45.18, 54.78], 192 | 'CoFM' : [3.32, 12.67, 26.77, 36.46, 45.67, 56.91, 64.98], 193 | 'transe' : [3.28, 13.03, 27.23, 36.59, 45.48, 56.33, 64.07], 194 | 'transh' : [3.43, 13.67, 27.97, 37.25, 45.92, 56.62, 64.35], 195 | 'transr' : [3.27, 10.48, 21.22, 29.14, 37.49, 48.94, 57.87], 196 | 'jtransup' : [3.38, 13.64, 28.28, 38.05, 47.07, 57.78, 65.14] 197 | }) 198 | ax1 = hits['transe'].plot(secondary_y=False, color = '#8293ff', linestyle='--', marker='^') 199 | hits['transh'].plot(secondary_y=False, color = '#503bff', linestyle='--', marker='p') 200 | 201 | hits['transr'].plot(secondary_y=False, color = '#f08a5d', linestyle='-', marker='*') 202 | hits['CFKG'].plot(secondary_y=False, color = '#f08a5d', linestyle='-', marker='*') 203 | hits['CKE'].plot(secondary_y=False, color = '#b83b5e', linestyle='-', marker='x') 204 | # m1_t['cofm_s'].plot(secondary_y=True, color = '#cce490', linestyle='-', marker='s') 205 | hits['CoFM'].plot(secondary_y=False, color = '#6a2c70', linestyle='-', marker='+') 206 | # m1_t['transup_h'].plot(secondary_y=True, color = '#f57665', linestyle='-', marker='s') 207 | 208 | # m1_t['jtransup_h'].plot(secondary_y=True, color = '#393e46', linestyle='-', marker='s') 209 | hits['jtransup'].plot(secondary_y=False, color = '#222831', linestyle='-', marker='s') 210 | 211 | ax1.set_xlabel('Topn') 212 | ax1.set_ylabel('Hits@N') 213 | ax1.set_xticks([0,1,2,3,4,5,6]) 214 | ax1.set_xticklabels(('1', '2', '5', '10', '20', '50', '100')) 215 | 216 | plt.legend() 217 | plt.show() 
218 | 219 | if __name__ == "__main__": 220 | plt.rcParams.update({'font.size': 14}) 221 | drawMl1m() 222 | # drawDBbook() 223 | # drawMl1mTopnPrec() 224 | # drawMl1mTopnRecall() 225 | # drawMl1mTopnF1() 226 | # drawMl1mTopnHits() -------------------------------------------------------------------------------- /jTransUP/data/load_kg_rating_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from jTransUP.data import load_rating_data, load_triple_data 3 | 4 | # two items refer to the same entity 5 | def loadR2KgMap(filename): 6 | i2kg_map = {} 7 | kg2i_map = {} 8 | with open(filename, 'r', encoding='utf-8') as fin: 9 | for line in fin: 10 | line_split = line.strip().split('\t') 11 | if len(line_split) != 3 : continue 12 | i_id = line_split[0] 13 | kg_uri = line_split[2] 14 | i2kg_map[i_id] = kg_uri 15 | kg2i_map[kg_uri] = i_id 16 | print("successful load {} item and {} entity pairs!".format(len(i2kg_map), len(kg2i_map))) 17 | return i2kg_map, kg2i_map 18 | 19 | # map: org:id 20 | # link: org(map1):org(map2) 21 | def rebuildEntityItemVocab(map1, map2, links): 22 | new_map = {} 23 | index = 0 24 | has_map2 = {} 25 | remap1 = {} 26 | for org_id1 in map1: 27 | mapped_id2 = -1 28 | if org_id1 in links: 29 | org_id2 = links[org_id1] 30 | if org_id2 in map2: 31 | mapped_id2 = map2[org_id2] 32 | has_map2[org_id2] = index 33 | new_map[index] = (map1[org_id1], mapped_id2) 34 | 35 | remap1[map1[org_id1]] = index 36 | index += 1 37 | 38 | remap2 = {} 39 | mapped_id1 = -1 40 | for org_id2 in map2: 41 | if org_id2 in has_map2 : 42 | remap2[map2[org_id2]] = has_map2[org_id2] 43 | continue 44 | new_map[index] = (mapped_id1, map2[org_id2]) 45 | 46 | remap2[map2[org_id2]] = index 47 | index += 1 48 | return new_map, remap1, remap2, len(has_map2) 49 | 50 | 51 | def load_data(data_path, rec_eval_files, kg_eval_files, batch_size, negtive_samples=1, logger=None): 52 | kg_path = os.path.join(data_path, 'kg') 53 | map_file = 
os.path.join(data_path, 'i2kg_map.tsv') 54 | 55 | rating_train_dataset, rating_eval_datasets, u_map, i_map = load_rating_data.load_data(data_path, rec_eval_files, batch_size, logger=logger, negtive_samples=negtive_samples) 56 | 57 | triple_train_dataset, triple_eval_datasets, e_map, r_map = load_triple_data.load_data(kg_path, kg_eval_files, batch_size, logger=logger, negtive_samples=negtive_samples) 58 | 59 | i2kg_map, kg2i_map = loadR2KgMap(map_file) 60 | # e_map,imap org--> new id 61 | ikg_map, e_remap, i_remap, aligned_ie_total = rebuildEntityItemVocab(e_map, i_map, kg2i_map) 62 | 63 | if logger is not None: 64 | logger.info("Find {} aligned items and entities!".format(aligned_ie_total)) 65 | 66 | return rating_train_dataset, rating_eval_datasets, u_map, i_remap, triple_train_dataset, triple_eval_datasets, e_remap, r_map, ikg_map -------------------------------------------------------------------------------- /jTransUP/data/load_rating_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from jTransUP.utils.data import MakeTrainIterator, MakeEvalIterator 4 | 5 | # org--> id 6 | def loadVocab(filename): 7 | with open(filename, 'r', encoding='utf-8') as fin: 8 | vocab = {} 9 | for line in fin: 10 | line_split = line.strip().split('\t') 11 | if len(line_split) != 2 : continue 12 | mapped_id = int(line_split[0]) 13 | org_id = line_split[1] 14 | vocab[org_id] = mapped_id 15 | 16 | return vocab 17 | 18 | # dict:{u_id:set(i_ids), ... 
} 19 | def loadRatings(filename): 20 | with open(filename, 'r', encoding='utf-8') as fin: 21 | rating_total = 0 22 | rating_list = [] 23 | rating_dict = {} 24 | for line in fin: 25 | line_split = line.strip().split('\t') 26 | if len(line_split) != 3 : continue 27 | u_id = int(line_split[0]) 28 | i_id = int(line_split[1]) 29 | r_score = int(line_split[2]) 30 | rating_list.append( (u_id, i_id) ) 31 | 32 | tmp_items = rating_dict.get(u_id, set()) 33 | tmp_items.add(i_id) 34 | rating_dict[u_id] = tmp_items 35 | 36 | rating_total += 1 37 | 38 | return rating_total, rating_list, rating_dict 39 | 40 | def load_data(data_path, eval_filenames, batch_size, negtive_samples=1, logger=None): 41 | 42 | train_file = os.path.join(data_path, "train.dat") 43 | 44 | eval_files = [] 45 | for file_name in eval_filenames: 46 | eval_files.append(os.path.join(data_path, file_name)) 47 | 48 | u_map_file = os.path.join(data_path, "u_map.dat") 49 | i_map_file = os.path.join(data_path, "i_map.dat") 50 | 51 | train_total, train_list, train_dict = loadRatings(train_file) 52 | 53 | eval_dataset = [] 54 | for eval_file in eval_files: 55 | eval_dataset.append( loadRatings(eval_file) ) 56 | 57 | if logger is not None: 58 | eval_totals = [str(eval_data[0]) for eval_data in eval_dataset] 59 | logger.info("Totally {} train ratings, {} eval ratings in files: {}!".format(train_total, ",".join(eval_totals), ";".join(eval_files))) 60 | 61 | # get user total 62 | u_map = loadVocab(u_map_file) 63 | # get item total 64 | i_map = loadVocab(i_map_file) 65 | 66 | if logger is not None: 67 | logger.info("successfully load {} users and {} items!".format(len(u_map), len(i_map))) 68 | 69 | train_iter = MakeTrainIterator(train_list, batch_size, negtive_samples=negtive_samples) 70 | 71 | new_eval_datasets = [] 72 | dt = np.dtype('int') 73 | for eval_data in eval_dataset: 74 | tmp_iter = MakeEvalIterator(list(eval_data[2].keys()), dt, batch_size) 75 | new_eval_datasets.append([tmp_iter, eval_data[0], eval_data[1], 
eval_data[2]]) 76 | 77 | train_dataset = (train_iter, train_total, train_list, train_dict) 78 | return train_dataset, new_eval_datasets, u_map, i_map 79 | 80 | if __name__ == "__main__": 81 | # Demo: the eval file names below are assumed to be the valid.dat / test.dat files written by preprocessRatings.py 82 | data_path = "/Users/caoyixin/Github/joint-kg-recommender/datasets/ml1m/" 83 | batch_size = 10 84 | eval_filenames = ["valid.dat", "test.dat"] 85 | 86 | train_dataset, eval_datasets, u_map, i_map = load_data(data_path, eval_filenames, batch_size) 87 | 88 | # train_dataset: (iterator, total, list, dict); each eval dataset: [iterator, total, list, dict] 89 | train_iter, train_total, train_list, train_dict = train_dataset 90 | eval_totals = [str(eval_data[1]) for eval_data in eval_datasets] 91 | 92 | print("user:{}, item:{}!".format(len(u_map), len(i_map))) 93 | print("totally {} train ratings, {} eval ratings!".format(train_total, ",".join(eval_totals))) -------------------------------------------------------------------------------- /jTransUP/data/load_triple_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from jTransUP.utils.data import MakeTrainIterator, MakeEvalIterator 4 | 5 | def loadTriples(filename): 6 | with open(filename, 'r', encoding='utf-8') as fin: 7 | triple_total = 0 8 | triple_list = [] 9 | triple_head_dict = {} 10 | triple_tail_dict = {} 11 | for line in fin: 12 | line_split = line.strip().split('\t') 13 | if len(line_split) != 3 : continue 14 | h_id = int(line_split[0]) 15 | t_id = int(line_split[1]) 16 | r_id = int(line_split[2]) 17 | 18 | triple_list.append( (h_id, t_id, r_id) ) 19 | 20 | tmp_heads = triple_head_dict.get( (t_id, r_id), set()) 21 | tmp_heads.add(h_id) 22 | triple_head_dict[(t_id, r_id)] = tmp_heads 23 | 24 | tmp_tails = triple_tail_dict.get( (h_id, r_id), set()) 25 | tmp_tails.add(t_id) 26 | triple_tail_dict[(h_id, r_id)] = tmp_tails 27 | 28 | triple_total += 1 29 | 30 | return triple_total, triple_list, 
triple_head_dict, triple_tail_dict 31 | 32 | # org-->id 33 | def loadVocab(filename): 34 | with open(filename, 'r', encoding='utf-8') as fin: 35 | vocab = {} 36 | for line in fin: 37 | line_split = line.strip().split('\t') 38 | if len(line_split) != 2 : continue 39 | e_id = int(line_split[0]) 40 | e_uri = line_split[1] 41 | vocab[e_uri] = e_id 42 | 43 | return vocab 44 | 45 | def load_data(kg_path, eval_filenames, batch_size, negtive_samples=1, logger=None): 46 | 47 | # each dataset has the /kg/ dictionary 48 | 49 | train_file = os.path.join(kg_path, "train.dat") 50 | eval_files = [] 51 | for file_name in eval_filenames: 52 | eval_files.append(os.path.join(kg_path, file_name)) 53 | 54 | e_map_file = os.path.join(kg_path, "e_map.dat") 55 | r_map_file = os.path.join(kg_path, "r_map.dat") 56 | 57 | train_total, train_list, train_head_dict, train_tail_dict = loadTriples(train_file) 58 | 59 | eval_dataset = [] 60 | for eval_file in eval_files: 61 | eval_dataset.append( loadTriples(eval_file) ) 62 | 63 | if logger is not None: 64 | eval_totals = [str(eval_data[0]) for eval_data in eval_dataset] 65 | logger.info("Totally {} train triples, {} eval triples in files: {}!".format(train_total, ",".join(eval_totals), ";".join(eval_files))) 66 | 67 | # get entity total 68 | e_map = loadVocab(e_map_file) 69 | # get relation total 70 | r_map = loadVocab(r_map_file) 71 | 72 | if logger is not None: 73 | logger.info("successfully load {} entities and {} relations!".format(len(e_map), len(r_map))) 74 | 75 | train_iter = MakeTrainIterator(train_list, batch_size, negtive_samples=negtive_samples) 76 | 77 | # train_total, train_list, train_head_dict, train_tail_dict 78 | new_eval_datasets = [] 79 | dt = np.dtype('int,int') 80 | for eval_data in eval_dataset: 81 | tmp_head_iter = MakeEvalIterator(list(eval_data[2].keys()), dt, batch_size) 82 | tmp_tail_iter = MakeEvalIterator(list(eval_data[3].keys()), dt, batch_size) 83 | new_eval_datasets.append([tmp_head_iter, tmp_tail_iter, 
eval_data[0], eval_data[1], eval_data[2], eval_data[3]]) 84 | 85 | train_dataset = (train_iter, train_total, train_list, train_head_dict, train_tail_dict) 86 | 87 | return train_dataset, new_eval_datasets, e_map, r_map 88 | 89 | if __name__ == "__main__": 90 | # Demo: the eval file names below are an assumption about what the KG preprocessing step writes under kg/ 91 | data_path = "/Users/caoyixin/Github/joint-kg-recommender/datasets/ml1m/" 92 | kg_path = os.path.join(data_path, 'kg') 93 | batch_size = 10 94 | eval_filenames = ["valid.dat", "test.dat"] 95 | 96 | train_dataset, eval_datasets, e_map, r_map = load_data(kg_path, eval_filenames, batch_size) 97 | 98 | # each eval dataset: [head iterator, tail iterator, total, list, head dict, tail dict] 99 | train_total = train_dataset[1] 100 | eval_totals = [str(eval_data[2]) for eval_data in eval_datasets] 101 | 102 | print("entity:{}, relation:{}!".format(len(e_map), len(r_map))) 103 | print("totally {} train triples, {} eval triples!".format(train_total, ",".join(eval_totals))) -------------------------------------------------------------------------------- /jTransUP/data/preprocessRatings.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import csv 4 | import json 5 | import os 6 | import random 7 | import math 8 | import logging 9 | 10 | class Rating(object): 11 | def __init__(self, user, item, rating): 12 | self.u = user 13 | self.i = item 14 | self.r = rating 15 | 16 | def splitRatingData(user_dict, train_ratio = 0.7, test_ratio = 0.2, shuffle_data_split=False, filter_unseen_samples=True): 17 | # valid ratio could be 1-train_ratio-test_ratio, and maybe zero 18 | 19 | assert train_ratio > 0 and train_ratio < 1, "train ratio out of range!" 20 | assert test_ratio > 0 and test_ratio < 1, "test ratio out of range!" 
21 | 22 | valid_ratio = 1 - train_ratio - test_ratio 23 | assert valid_ratio >= 0 and valid_ratio < 1, "valid ratio out of range!" 24 | 25 | train_item_set = set() 26 | tmp_train_list = [] 27 | tmp_valid_list = [] 28 | tmp_test_list = [] 29 | for user in user_dict: 30 | tmp_item_list = user_dict[user] 31 | 32 | n_items = len(tmp_item_list) 33 | n_train = math.ceil(n_items * train_ratio) 34 | n_valid = math.ceil(n_items * valid_ratio) if valid_ratio > 0 else 0 35 | # in case of zero test item 36 | if n_train >= n_items: 37 | n_train = n_items - 1 38 | n_valid = 0 39 | elif n_train + n_valid >= n_items : 40 | n_valid = n_items - 1 - n_train 41 | 42 | if shuffle_data_split : random.shuffle(tmp_item_list) 43 | 44 | for ir in tmp_item_list[0:n_train]: 45 | tmp_train_list.append( (user, ir[0], ir[1]) ) 46 | train_item_set.add(ir[0]) 47 | tmp_valid_list.extend([(user, ir[0], ir[1]) for ir in tmp_item_list[n_train:n_train+n_valid]]) 48 | 49 | tmp_test_list.extend( [(user, ir[0], ir[1]) for ir in tmp_item_list[n_train+n_valid:]] ) 50 | 51 | u_map = {} 52 | for index, user in enumerate(user_dict.keys()): 53 | u_map[user] = index 54 | i_map = {} 55 | for index, item in enumerate(train_item_set): 56 | i_map[item] = index 57 | 58 | train_list = [Rating(u_map[rating[0]], i_map[rating[1]], rating[2]) for rating in tmp_train_list] 59 | 60 | if filter_unseen_samples: 61 | valid_list = [Rating(u_map[rating[0]], i_map[rating[1]], rating[2]) for rating in tmp_valid_list if rating[1] in train_item_set] 62 | 63 | test_list = [Rating(u_map[rating[0]], i_map[rating[1]], rating[2]) for rating in tmp_test_list if rating[1] in train_item_set] 64 | else: 65 | valid_list = [Rating(u_map[rating[0]], i_map[rating[1]], rating[2]) for rating in tmp_valid_list ] 66 | 67 | test_list = [Rating(u_map[rating[0]], i_map[rating[1]], rating[2]) for rating in tmp_test_list ] 68 | 69 | return train_list, valid_list, test_list, u_map, i_map 70 | 71 | def cutLowFrequentData(rating_file, item_vocab=None, 
low_frequence=10, logger=None): 72 | df = pd.read_csv(rating_file, encoding='utf-8') 73 | df = df[['userId', 'itemId', 'rating']] 74 | df = df.values 75 | 76 | user_dict = dict() 77 | item_dict = dict() 78 | 79 | f_user_dict = dict() 80 | f_item_dict = dict() 81 | 82 | for line in df: 83 | u_id = int(line[0]) 84 | i_id = int(line[1]) 85 | r_score = int(line[2]) 86 | 87 | if item_vocab is not None and i_id not in item_vocab : continue 88 | 89 | if u_id in user_dict: 90 | user_dict[u_id].append( (i_id, r_score) ) 91 | else: 92 | user_dict[u_id] = [(i_id, r_score)] 93 | 94 | if i_id in item_dict.keys(): 95 | item_dict[i_id].append( (u_id, r_score) ) 96 | else: 97 | item_dict[i_id] = [(u_id, r_score)] 98 | 99 | if logger is not None: 100 | logger.info("Totally {} interactions between {} user and {} items!".format(len(df), len(user_dict), len(item_dict))) 101 | logger.debug("Filtering infrequent users and items (<={}) ...".format(low_frequence)) 102 | while True: 103 | flag1, flag2 = True, True 104 | 105 | for u_id in user_dict.keys(): 106 | pos_items = user_dict[u_id] 107 | valid_items = [idx for idx in pos_items if idx[0] in item_dict.keys()] 108 | 109 | if len(valid_items) >= low_frequence: 110 | f_user_dict[u_id] = valid_items 111 | else: 112 | flag1 = False 113 | 114 | total_ratings = 0 115 | for i_id in item_dict.keys(): 116 | pos_users = item_dict[i_id] 117 | valid_users = [udx for udx in pos_users if udx[0] in user_dict.keys()] 118 | 119 | if len(valid_users) >= low_frequence: 120 | f_item_dict[i_id] = valid_users 121 | total_ratings += len(valid_users) 122 | else: 123 | flag2 = False 124 | 125 | user_dict = f_user_dict.copy() 126 | item_dict = f_item_dict.copy() 127 | f_user_dict = {} 128 | f_item_dict = {} 129 | 130 | if logger is not None: 131 | logger.info("Remaining : {} interactions of {} users and {} items!".format( total_ratings, len(user_dict), len(item_dict))) 132 | if flag1 and flag2: 133 | if logger is not None: logger.debug('Filtering infrequent 
users and items done!') 134 | break 135 | 136 | return user_dict 137 | 138 | def preprocess(rating_file, out_path, train_ratio=0.7, test_ratio=0.2, shuffle_data_split=True, filter_unseen_samples=True, low_frequence=10, logger=None): 139 | train_file = os.path.join(out_path, "train.dat") 140 | test_file = os.path.join(out_path, "test.dat") 141 | valid_file = os.path.join(out_path, "valid.dat") if 1 - train_ratio - test_ratio != 0 else None 142 | 143 | u_map_file = os.path.join(out_path, "u_map.dat") 144 | i_map_file = os.path.join(out_path, "i_map.dat") 145 | 146 | str_is_shuffle = "shuffle and split" if shuffle_data_split else "split without shuffle" 147 | 148 | if logger is not None: 149 | logger.info("{} {} for {:.1f} training, {:.1f} validation and {:.1f} testing!".format( str_is_shuffle, rating_file, train_ratio, 1-train_ratio-test_ratio, test_ratio )) 150 | 151 | # only remain the items in the item_vocab 152 | item_vocab = None 153 | 154 | user_dict = cutLowFrequentData(rating_file, item_vocab=item_vocab, low_frequence=low_frequence, logger=logger) 155 | 156 | train_list, valid_list, test_list, u_map, i_map = splitRatingData(user_dict, train_ratio = train_ratio, test_ratio = test_ratio, shuffle_data_split=shuffle_data_split, filter_unseen_samples=filter_unseen_samples) 157 | 158 | if logger is not None: 159 | logger.debug("Spliting dataset is done!") 160 | logger.info("Filtering unseen users and items ..." 
if filter_unseen_samples else "Not filter unseen users and items.") 161 | logger.info("{} users and {} items, where {} train, {} valid, and {} test!".format(len(u_map), len(i_map), len(train_list), len(valid_list), len(test_list))) 162 | 163 | # save ent_dic, rel_dic 164 | with open(u_map_file, 'w', encoding='utf-8') as fout: 165 | for org_u_id in u_map: 166 | fout.write('{}\t{}\n'.format(u_map[org_u_id], org_u_id)) 167 | 168 | with open(i_map_file, 'w', encoding='utf-8') as fout: 169 | for org_i_id in i_map: 170 | fout.write('{}\t{}\n'.format(i_map[org_i_id], org_i_id)) 171 | 172 | with open(train_file, 'w', encoding='utf-8') as fout: 173 | for rating in train_list: 174 | fout.write('{}\t{}\t{}\n'.format(rating.u, rating.i, rating.r)) 175 | 176 | with open(test_file, 'w', encoding='utf-8') as fout: 177 | for rating in test_list: 178 | fout.write('{}\t{}\t{}\n'.format(rating.u, rating.i, rating.r)) 179 | 180 | if len(valid_list) > 0: 181 | with open(valid_file, 'w', encoding='utf-8') as fout: 182 | for rating in valid_list: 183 | fout.write('{}\t{}\t{}\n'.format(rating.u, rating.i, rating.r)) 184 | 185 | def loadRatings(filename): 186 | with open(filename, 'r', encoding='utf-8') as fin: 187 | user_dict = {} 188 | total_count = 0 189 | for line in fin: 190 | line_split = line.strip().split('\t') 191 | if len(line_split) != 3 : continue 192 | u = int(line_split[0]) 193 | i = int(line_split[1]) 194 | rating = int(line_split[2]) 195 | 196 | i_set = user_dict.get(u, set()) 197 | i_set.add( (i, rating) ) 198 | user_dict[u] = i_set 199 | total_count += 1 200 | return total_count, user_dict 201 | 202 | def getMaxMinRatings(user_dict): 203 | max_ratings = 0 204 | min_ratings = 10000 205 | for u in user_dict: 206 | if len(user_dict[u]) > max_ratings: 207 | max_ratings = len(user_dict[u]) 208 | if len(user_dict[u]) < min_ratings: 209 | min_ratings = len(user_dict[u]) 210 | return max_ratings, min_ratings 211 | 212 | def splitUsers(user_dict, split_num): 213 | split_num = 10 
214 | splited_users = [set() for _ in range(split_num)] 215 | max_ratings, min_ratings = getMaxMinRatings(user_dict) 216 | step = math.ceil((max_ratings - min_ratings + 1) / split_num) 217 | splited_threshold = [i for i in range(min_ratings, max_ratings, step) if i!= min_ratings] + [max_ratings] 218 | splited_threshold = [20, 40, 60, 80, 100, 150, 200, 300, 400, max_ratings] 219 | 220 | for u in user_dict: 221 | rating_num = len(user_dict[u]) 222 | for i, thr in enumerate(splited_threshold): 223 | if rating_num <= thr: 224 | splited_users[i].add(u) 225 | break 226 | return splited_threshold, splited_users 227 | 228 | def output(filename, user_dict, u_ids): 229 | count = 0 230 | with open(filename, 'w', encoding='utf-8') as fout: 231 | for u in user_dict: 232 | if u in u_ids: 233 | for ir in user_dict[u]: 234 | fout.write("{}\t{}\t{}\n".format(u, ir[0], ir[1])) 235 | count += 1 236 | return count 237 | 238 | if __name__ == "__main__": 239 | root_path = '/Users/caoyixin/Github/joint-kg-recommender/datasets/' 240 | dataset = 'ml1m' 241 | split_num = 10 242 | train_file = root_path + dataset +'/train.dat' 243 | test_file = root_path + dataset +'/test.dat' 244 | log_file = root_path + dataset + '/data_preprocess.log' 245 | 246 | logger = logging.getLogger() 247 | logger.setLevel(level=logging.DEBUG) 248 | 249 | # Formatter 250 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 251 | # FileHandler 252 | file_handler = logging.FileHandler(log_file) 253 | file_handler.setFormatter(formatter) 254 | logger.addHandler(file_handler) 255 | 256 | # StreamHandler 257 | stream_handler = logging.StreamHandler() 258 | stream_handler.setFormatter(formatter) 259 | logger.addHandler(stream_handler) 260 | 261 | train_rating_total, train_user_dict = loadRatings(train_file) 262 | test_rating_total, test_user_dict = loadRatings(test_file) 263 | 264 | max_ratings, min_ratings = getMaxMinRatings(train_user_dict) 265 | logger.info("load {} ratings for {} 
users, where min {} and max {} ratings!".format(train_rating_total, len(train_user_dict), min_ratings, max_ratings)) 266 | 267 | splited_threshold, splited_users = splitUsers(train_user_dict, split_num) 268 | 269 | for i, u_ids in enumerate(splited_users): 270 | logger.info("generating test ratings if {} user ratings num < {} ...".format(len(u_ids), splited_threshold[i])) 271 | filename = root_path + dataset +'/test{}.dat'.format(i) 272 | count = output(filename, test_user_dict, u_ids) 273 | logger.info("output {} ratings done!".format(count)) -------------------------------------------------------------------------------- /jTransUP/models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoMiner/joint-kg-recommender/c165f12c039dc62c888f843351d178b8a94f3689/jTransUP/models/.DS_Store -------------------------------------------------------------------------------- /jTransUP/models/CFKG.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable as V 5 | 6 | from jTransUP.utils.misc import to_gpu 7 | 8 | def build_model(FLAGS, user_total, item_total, entity_total, relation_total, i_map=None, e_map=None, new_map=None): 9 | model_cls = CFKG 10 | return model_cls( 11 | L1_flag = FLAGS.L1_flag, 12 | embedding_size = FLAGS.embedding_size, 13 | user_total = user_total, 14 | item_total = item_total, 15 | entity_total = entity_total, 16 | relation_total = relation_total 17 | ) 18 | 19 | class CFKG(nn.Module): 20 | def __init__(self, 21 | L1_flag, 22 | embedding_size, 23 | user_total, 24 | item_total, 25 | entity_total, 26 | relation_total 27 | ): 28 | super(CFKG, self).__init__() 29 | self.L1_flag = L1_flag 30 | self.embedding_size = embedding_size 31 | self.user_total = user_total 32 | self.item_total = item_total 33 | self.ent_total = entity_total 34 | # 
add buy relation between user and item 35 | self.rel_total = relation_total + 1 36 | self.is_pretrained = False 37 | 38 | # init user embeddings 39 | user_weight = torch.FloatTensor(self.user_total, self.embedding_size) 40 | nn.init.xavier_uniform(user_weight) 41 | 42 | self.user_embeddings = nn.Embedding(self.user_total, self.embedding_size) 43 | self.user_embeddings.weight = nn.Parameter(user_weight) 44 | normalize_user_emb = F.normalize(self.user_embeddings.weight.data, p=2, dim=1) 45 | self.user_embeddings.weight.data = normalize_user_emb 46 | self.user_embeddings = to_gpu(self.user_embeddings) 47 | 48 | # init entity and relation embeddings 49 | ent_weight = torch.FloatTensor(self.ent_total, self.embedding_size) 50 | rel_weight = torch.FloatTensor(self.rel_total, self.embedding_size) 51 | nn.init.xavier_uniform(ent_weight) 52 | nn.init.xavier_uniform(rel_weight) 53 | self.ent_embeddings = nn.Embedding(self.ent_total, self.embedding_size) 54 | self.rel_embeddings = nn.Embedding(self.rel_total, self.embedding_size) 55 | 56 | self.ent_embeddings.weight = nn.Parameter(ent_weight) 57 | self.rel_embeddings.weight = nn.Parameter(rel_weight) 58 | 59 | normalize_ent_emb = F.normalize(self.ent_embeddings.weight.data, p=2, dim=1) 60 | normalize_rel_emb = F.normalize(self.rel_embeddings.weight.data, p=2, dim=1) 61 | 62 | self.ent_embeddings.weight.data = normalize_ent_emb 63 | self.rel_embeddings.weight.data = normalize_rel_emb 64 | 65 | self.ent_embeddings = to_gpu(self.ent_embeddings) 66 | self.rel_embeddings = to_gpu(self.rel_embeddings) 67 | 68 | # share embedding 69 | self.item_embeddings = self.ent_embeddings 70 | 71 | def forward(self, ratings, triples, is_rec=True): 72 | 73 | if is_rec and ratings is not None: 74 | u_ids, i_ids = ratings 75 | batch_size = len(u_ids) 76 | 77 | u_e = self.user_embeddings(u_ids) 78 | i_e = self.item_embeddings(i_ids) 79 | 80 | buy_e = self.rel_embeddings(to_gpu(V(torch.LongTensor([self.rel_total-1])))) 81 | buy_e_expand = 
buy_e.expand(batch_size, self.embedding_size) 82 | # L1 distance 83 | if self.L1_flag: 84 | score = torch.sum(torch.abs(u_e + buy_e_expand - i_e), 1) 85 | # L2 distance 86 | else: 87 | score = torch.sum((u_e + buy_e_expand - i_e) ** 2, 1) 88 | 89 | elif not is_rec and triples is not None: 90 | h, t, r = triples 91 | h_e = self.ent_embeddings(h) 92 | t_e = self.ent_embeddings(t) 93 | r_e = self.rel_embeddings(r) 94 | 95 | # L1 distance 96 | if self.L1_flag: 97 | score = torch.sum(torch.abs(h_e + r_e - t_e), 1) 98 | # L2 distance 99 | else: 100 | score = torch.sum((h_e + r_e - t_e) ** 2, 1) 101 | else: 102 | raise NotImplementedError 103 | 104 | return score 105 | 106 | def evaluateRec(self, u_ids, all_i_ids=None): 107 | batch_size = len(u_ids) 108 | all_i = self.item_embeddings(all_i_ids) if all_i_ids is not None else self.item_embeddings.weight 109 | item_total, dim = all_i.size() 110 | # batch * dim 111 | u_e = self.user_embeddings(u_ids) 112 | # batch * item * dim 113 | u_e_expand = u_e.expand(item_total, batch_size, dim).permute(1, 0, 2) 114 | 115 | buy_e = self.rel_embeddings(to_gpu(V(torch.LongTensor([self.rel_total-1])))) 116 | buy_e_expand = buy_e.expand(batch_size, item_total, dim) 117 | 118 | c_i_e = u_e_expand + buy_e_expand 119 | 120 | # batch * item * dim 121 | i_expand = all_i.expand(batch_size, item_total, dim) 122 | 123 | if self.L1_flag: 124 | score = torch.sum(torch.abs(c_i_e-i_expand), 2) 125 | else: 126 | score = torch.sum((c_i_e-i_expand) ** 2, 2) 127 | return score 128 | 129 | def evaluateHead(self, t, r, all_e_ids=None): 130 | batch_size = len(t) 131 | all_e = self.ent_embeddings(all_e_ids) if all_e_ids is not None else self.ent_embeddings.weight 132 | ent_total, dim = all_e.size() 133 | # batch * dim 134 | t_e = self.ent_embeddings(t) 135 | r_e = self.rel_embeddings(r) 136 | 137 | c_h_e = t_e - r_e 138 | 139 | # batch * entity * dim 140 | c_h_expand = c_h_e.expand(ent_total, batch_size, dim).permute(1, 0, 2) 141 | 142 | # batch * entity * dim 
143 | ent_expand = all_e.expand(batch_size, ent_total, dim) 144 | 145 | # batch * entity 146 | if self.L1_flag: 147 | score = torch.sum(torch.abs(c_h_expand-ent_expand), 2) 148 | else: 149 | score = torch.sum((c_h_expand-ent_expand) ** 2, 2) 150 | return score 151 | 152 | def evaluateTail(self, h, r, all_e_ids=None): 153 | batch_size = len(h) 154 | all_e = self.ent_embeddings(all_e_ids) if all_e_ids is not None else self.ent_embeddings.weight 155 | ent_total, dim = all_e.size() 156 | # batch * dim 157 | h_e = self.ent_embeddings(h) 158 | r_e = self.rel_embeddings(r) 159 | 160 | c_t_e = h_e + r_e 161 | 162 | # batch * entity * dim 163 | c_t_expand = c_t_e.expand(ent_total, batch_size, dim).permute(1, 0, 2) 164 | 165 | # batch * entity * dim 166 | ent_expand = all_e.expand(batch_size, ent_total, dim) 167 | 168 | # batch * entity 169 | if self.L1_flag: 170 | score = torch.sum(torch.abs(c_t_expand-ent_expand), 2) 171 | else: 172 | score = torch.sum((c_t_expand-ent_expand) ** 2, 2) 173 | return score 174 | 175 | 176 | def disable_grad(self): 177 | for name, param in self.named_parameters(): 178 | param.requires_grad=False 179 | 180 | def enable_grad(self): 181 | for name, param in self.named_parameters(): 182 | param.requires_grad=True -------------------------------------------------------------------------------- /jTransUP/models/CKE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable as V 5 | from jTransUP.utils.misc import to_gpu, projection_transR_pytorch, projection_transR_pytorch_batch 6 | 7 | def build_model(FLAGS, user_total, item_total, entity_total, relation_total, i_map=None, e_map=None, new_map=None): 8 | model_cls = CKE 9 | return model_cls( 10 | L1_flag = FLAGS.L1_flag, 11 | embedding_size = FLAGS.embedding_size, 12 | user_total = user_total, 13 | item_total = item_total, 14 | entity_total = entity_total, 
15 | relation_total = relation_total, 16 | i_map=i_map, 17 | new_map=new_map 18 | ) 19 | 20 | class CKE(nn.Module): 21 | def __init__(self, 22 | L1_flag, 23 | embedding_size, 24 | user_total, 25 | item_total, 26 | entity_total, 27 | relation_total, 28 | i_map, 29 | new_map 30 | ): 31 | super(CKE, self).__init__() 32 | self.L1_flag = L1_flag 33 | self.embedding_size = embedding_size 34 | self.user_total = user_total 35 | self.item_total = item_total 36 | # padding when item are not aligned with any entity 37 | self.ent_total = entity_total + 1 38 | self.rel_total = relation_total 39 | self.is_pretrained = False 40 | # store item to item-entity dic 41 | self.i_map = i_map 42 | # store item-entity to (entity, item) 43 | self.new_map = new_map 44 | 45 | # bprmf 46 | # init user and item embeddings 47 | user_weight = torch.FloatTensor(self.user_total, self.embedding_size) 48 | item_weight = torch.FloatTensor(self.item_total, self.embedding_size) 49 | nn.init.xavier_uniform(user_weight) 50 | nn.init.xavier_uniform(item_weight) 51 | self.user_embeddings = nn.Embedding(self.user_total, self.embedding_size) 52 | self.item_embeddings = nn.Embedding(self.item_total, self.embedding_size) 53 | self.user_embeddings.weight = nn.Parameter(user_weight) 54 | self.item_embeddings.weight = nn.Parameter(item_weight) 55 | normalize_user_emb = F.normalize(self.user_embeddings.weight.data, p=2, dim=1) 56 | normalize_item_emb = F.normalize(self.item_embeddings.weight.data, p=2, dim=1) 57 | self.user_embeddings.weight.data = normalize_user_emb 58 | self.item_embeddings.weight.data = normalize_item_emb 59 | 60 | self.user_embeddings = to_gpu(self.user_embeddings) 61 | self.item_embeddings = to_gpu(self.item_embeddings) 62 | 63 | # transR 64 | 65 | ent_weight = torch.FloatTensor(self.ent_total-1, self.embedding_size) 66 | rel_weight = torch.FloatTensor(self.rel_total, self.embedding_size) 67 | proj_weight = torch.FloatTensor(self.rel_total, self.embedding_size * self.embedding_size) 68 | 
nn.init.xavier_uniform(ent_weight) 69 | nn.init.xavier_uniform(rel_weight) 70 | 71 | norm_ent_weight = F.normalize(ent_weight, p=2, dim=1) 72 | 73 | if self.is_pretrained: 74 | nn.init.eye(proj_weight) 75 | proj_weight = proj_weight.view(-1).expand(self.relation_total, -1) 76 | else: 77 | nn.init.xavier_uniform(proj_weight) 78 | 79 | # init user and item embeddings 80 | self.ent_embeddings = nn.Embedding(self.ent_total, self.embedding_size, padding_idx=self.ent_total-1) 81 | 82 | self.rel_embeddings = nn.Embedding(self.rel_total, self.embedding_size) 83 | self.proj_embeddings = nn.Embedding(self.rel_total, self.embedding_size * self.embedding_size) 84 | 85 | self.ent_embeddings.weight = nn.Parameter(torch.cat([norm_ent_weight, torch.zeros(1, self.embedding_size)], dim=0)) 86 | self.rel_embeddings.weight = nn.Parameter(rel_weight) 87 | self.proj_embeddings.weight = nn.Parameter(proj_weight) 88 | 89 | # normalize_ent_emb = F.normalize(self.ent_embeddings.weight.data, p=2, dim=1) 90 | normalize_rel_emb = F.normalize(self.rel_embeddings.weight.data, p=2, dim=1) 91 | # normalize_proj_emb = F.normalize(self.proj_embeddings.weight.data, p=2, dim=1) 92 | 93 | # self.ent_embeddings.weight.data = normalize_ent_emb 94 | self.rel_embeddings.weight.data = normalize_rel_emb 95 | # self.proj_embeddings.weight.data = normalize_proj_emb 96 | 97 | self.ent_embeddings = to_gpu(self.ent_embeddings) 98 | self.rel_embeddings = to_gpu(self.rel_embeddings) 99 | self.proj_embeddings = to_gpu(self.proj_embeddings) 100 | 101 | def paddingItems(self, i_ids, pad_index): 102 | padded_e_ids = [] 103 | for i_id in i_ids: 104 | new_index = self.i_map[i_id] 105 | ent_id = self.new_map[new_index][0] 106 | padded_e_ids.append(ent_id if ent_id != -1 else pad_index) 107 | return padded_e_ids 108 | 109 | def forward(self, ratings, triples, is_rec=True): 110 | 111 | if is_rec and ratings is not None: 112 | u_ids, i_ids = ratings 113 | 114 | e_ids = self.paddingItems(i_ids.data, self.ent_total-1) 115 | 
e_var = to_gpu(V(torch.LongTensor(e_ids))) 116 | 117 | u_e = self.user_embeddings(u_ids) 118 | i_e = self.item_embeddings(i_ids) 119 | e_e = self.ent_embeddings(e_var) 120 | ie_e = i_e + e_e 121 | 122 | score = torch.bmm(u_e.unsqueeze(1), ie_e.unsqueeze(2)).squeeze() 123 | elif not is_rec and triples is not None: 124 | h, t, r = triples 125 | h_e = self.ent_embeddings(h) 126 | t_e = self.ent_embeddings(t) 127 | r_e = self.rel_embeddings(r) 128 | proj_e = self.proj_embeddings(r) 129 | 130 | proj_h_e = projection_transR_pytorch(h_e, proj_e) 131 | proj_t_e = projection_transR_pytorch(t_e, proj_e) 132 | 133 | if self.L1_flag: 134 | score = torch.sum(torch.abs(proj_h_e + r_e - proj_t_e), 1) 135 | else: 136 | score = torch.sum((proj_h_e + r_e - proj_t_e) ** 2, 1) 137 | else: 138 | raise NotImplementedError 139 | 140 | return score 141 | 142 | def evaluateRec(self, u_ids, all_i_ids=None): 143 | batch_size = len(u_ids) 144 | i_ids = range(len(self.item_embeddings.weight)) 145 | e_ids = self.paddingItems(i_ids, self.ent_total-1) 146 | e_var = to_gpu(V(torch.LongTensor(e_ids))) 147 | e_e = self.ent_embeddings(e_var) 148 | 149 | all_ie_e = self.item_embeddings.weight + e_e 150 | 151 | u_e = self.user_embeddings(u_ids) 152 | 153 | return torch.matmul(u_e, all_ie_e.t()) 154 | 155 | def evaluateHead(self, t, r, all_e_ids=None): 156 | batch_size = len(t) 157 | all_e = self.ent_embeddings(all_e_ids) if all_e_ids is not None and self.is_share else self.ent_embeddings.weight 158 | ent_total, dim = all_e.size() 159 | # batch * dim 160 | t_e = self.ent_embeddings(t) 161 | r_e = self.rel_embeddings(r) 162 | # batch* dim*dim 163 | proj_e = self.proj_embeddings(r) 164 | # batch * dim 165 | proj_t_e = projection_transR_pytorch(t_e, proj_e) 166 | c_h_e = proj_t_e - r_e 167 | 168 | # batch * entity * dim 169 | c_h_expand = c_h_e.expand(ent_total, batch_size, dim).permute(1, 0, 2) 170 | 171 | # batch * entity * dim 172 | proj_ent_expand = projection_transR_pytorch_batch(all_e, proj_e) 173 | 
174 | # batch * entity 175 | if self.L1_flag: 176 | score = torch.sum(torch.abs(c_h_expand-proj_ent_expand), 2) 177 | else: 178 | score = torch.sum((c_h_expand-proj_ent_expand) ** 2, 2) 179 | return score 180 | 181 | def evaluateTail(self, h, r, all_e_ids=None): 182 | batch_size = len(h) 183 | all_e = self.ent_embeddings(all_e_ids) if all_e_ids is not None and self.is_share else self.ent_embeddings.weight 184 | ent_total, dim = all_e.size() 185 | # batch * dim 186 | h_e = self.ent_embeddings(h) 187 | r_e = self.rel_embeddings(r) 188 | # batch* dim*dim 189 | proj_e = self.proj_embeddings(r) 190 | # batch * dim 191 | proj_h_e = projection_transR_pytorch(h_e, proj_e) 192 | c_t_e = proj_h_e + r_e 193 | 194 | # batch * entity * dim 195 | c_t_expand = c_t_e.expand(ent_total, batch_size, dim).permute(1, 0, 2) 196 | 197 | # batch * entity * dim 198 | proj_ent_expand = projection_transR_pytorch_batch(all_e, proj_e) 199 | 200 | # batch * entity 201 | if self.L1_flag: 202 | score = torch.sum(torch.abs(c_t_expand-proj_ent_expand), 2) 203 | else: 204 | score = torch.sum((c_t_expand-proj_ent_expand) ** 2, 2) 205 | return score 206 | 207 | def disable_grad(self): 208 | for name, param in self.named_parameters(): 209 | param.requires_grad=False 210 | 211 | def enable_grad(self): 212 | for name, param in self.named_parameters(): 213 | param.requires_grad=True -------------------------------------------------------------------------------- /jTransUP/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoMiner/joint-kg-recommender/c165f12c039dc62c888f843351d178b8a94f3689/jTransUP/models/__init__.py -------------------------------------------------------------------------------- /jTransUP/models/__pycache__/CFKG.cpython-36.pyc: -------------------------------------------------------------------------------- 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoMiner/joint-kg-recommender/c165f12c039dc62c888f843351d178b8a94f3689/jTransUP/models/__pycache__/transUP.cpython-36.pyc -------------------------------------------------------------------------------- /jTransUP/models/__pycache__/transUP_bias.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoMiner/joint-kg-recommender/c165f12c039dc62c888f843351d178b8a94f3689/jTransUP/models/__pycache__/transUP_bias.cpython-36.pyc -------------------------------------------------------------------------------- /jTransUP/models/base.py: -------------------------------------------------------------------------------- 1 | import time 2 | import gflags 3 | import torch 4 | from copy import deepcopy 5 | from functools import reduce 6 | import os 7 | 8 | from jTransUP.data import load_rating_data, load_triple_data, load_kg_rating_data 9 | 10 | import jTransUP.models.transUP as transup 11 | import jTransUP.models.bprmf as bprmf 12 | import jTransUP.models.transH as transh 13 | import jTransUP.models.jTransUP as jtransup 14 | import jTransUP.models.fm as fm 15 | import jTransUP.models.transE as transe 16 | import jTransUP.models.transR as transr 17 | import jTransUP.models.transD as transd 18 | import jTransUP.models.cofm as cofm 19 | import jTransUP.models.CKE as cke 20 | import jTransUP.models.CFKG as cfkg 21 | 22 | def get_flags(): 23 | gflags.DEFINE_enum("model_type", "transup", ["transup", "bprmf", "fm", 24 | "transe", "transh", "transr", "transd", 25 | "cfkg", "cke", "cofm", "jtransup"], "") 26 | gflags.DEFINE_enum("dataset", "ml1m", ["ml1m", "dbbook2014", "amazon-book", "last-fm", "yelp2018"], "including ratings.csv, r2kg.tsv and a kg dictionary containing kg_hop[0-9].dat") 27 | gflags.DEFINE_bool( 28 | "filter_wrong_corrupted", 29 | True, 30 | "If set to True, filter test samples from 
train and validations.") 31 | gflags.DEFINE_bool("share_embeddings", False, "") 32 | gflags.DEFINE_bool("use_st_gumbel", False, "") 33 | gflags.DEFINE_integer("max_queue", 10, ".") 34 | gflags.DEFINE_integer("num_processes", 4, ".") 35 | 36 | gflags.DEFINE_float("learning_rate", 0.001, "Used in optimizer.") 37 | gflags.DEFINE_float("norm_lambda", 1.0, "decay of joint model.") 38 | gflags.DEFINE_float("kg_lambda", 1.0, "decay of kg model.") 39 | gflags.DEFINE_integer( 40 | "early_stopping_steps_to_wait", 41 | 70000, 42 | "Stop training if no improvement is seen for this many steps. If set to 0, early stopping is disabled.") 43 | gflags.DEFINE_bool( 44 | "L1_flag", 45 | False, 46 | "If set to True, use L1 distance as dissimilarity; else, use L2.") 47 | gflags.DEFINE_bool( 48 | "is_report", 49 | False, 50 | "If set to True, log the gold and top-n predicted items for each user during evaluation.") 51 | gflags.DEFINE_float("l2_lambda", 1e-5, "") 52 | gflags.DEFINE_integer("embedding_size", 64, ".") 53 | gflags.DEFINE_integer("negtive_samples", 1, ".") 54 | gflags.DEFINE_integer("batch_size", 512, "Minibatch size.") 55 | gflags.DEFINE_enum("optimizer_type", "Adagrad", ["Adam", "SGD", "Adagrad", "Rmsprop"], "") 56 | gflags.DEFINE_float("learning_rate_decay_when_no_progress", 0.5, 57 | "Used in optimizer. Decay the LR by this much every epoch steps if a new best has not been set in the last epoch.") 58 | 59 | gflags.DEFINE_integer( 60 | "eval_interval_steps", 61 | 14000, 62 | "Evaluate at this interval in each epoch.") 63 | gflags.DEFINE_integer( 64 | "training_steps", 65 | 1400000, 66 | "Stop training after this point.") 67 | gflags.DEFINE_float("clipping_max_value", 5.0, "") 68 | gflags.DEFINE_float("margin", 1.0, "Used in margin loss.") 69 | gflags.DEFINE_float("momentum", 0.9, "The momentum of the optimizer.") 70 | gflags.DEFINE_integer("seed", 0, "Fix the random seed.
Except for 0, which means no setting of random seed.") 71 | gflags.DEFINE_integer("topn", 10, "") 72 | gflags.DEFINE_integer("num_preferences", 4, "") 73 | gflags.DEFINE_float("joint_ratio", 0.5, "(0 - 1). The train ratio of recommendation, kg is 1 - joint_ratio.") 74 | 75 | gflags.DEFINE_string("experiment_name", None, "") 76 | gflags.DEFINE_string("data_path", None, "") 77 | gflags.DEFINE_string("rec_test_files", None, "multiple filenames separated by ':'.") 78 | gflags.DEFINE_string("kg_test_files", None, "multiple filenames separated by ':'.") 79 | gflags.DEFINE_string("log_path", None, "") 80 | gflags.DEFINE_enum("log_level", "debug", ["debug", "info"], "") 81 | gflags.DEFINE_string( 82 | "ckpt_path", None, "Where to save/load checkpoints. If not set, the same as log_path") 83 | 84 | gflags.DEFINE_string( 85 | "load_ckpt_file", None, "Where to load pretrained checkpoints under log path. multiple filenames separated by ':'.") 86 | 87 | gflags.DEFINE_boolean( 88 | "has_visualization", 89 | True, 90 | "if set True, use visdom for visualization.") 91 | gflags.DEFINE_integer("visualization_port", 8097, "") 92 | # todo: only eval when no train.dat when load data 93 | gflags.DEFINE_boolean( 94 | "eval_only_mode", 95 | False, 96 | "If set, a checkpoint is loaded and a forward pass is done to get the predicted candidates." 
97 | "Requirements: Must specify load_experiment_name.") 98 | gflags.DEFINE_string("load_experiment_name", None, "") 99 | 100 | def flag_defaults(FLAGS): 101 | 102 | if not FLAGS.experiment_name: 103 | timestamp = str(int(time.time())) 104 | FLAGS.experiment_name = "{}-{}-{}".format( 105 | FLAGS.dataset, 106 | FLAGS.model_type, 107 | timestamp, 108 | ) 109 | 110 | if not FLAGS.data_path: 111 | FLAGS.data_path = "../datasets/" 112 | 113 | if not FLAGS.log_path: 114 | FLAGS.log_path = "../log/" 115 | 116 | if not FLAGS.ckpt_path: 117 | FLAGS.ckpt_path = FLAGS.log_path 118 | 119 | if FLAGS.seed != 0: 120 | torch.manual_seed(FLAGS.seed) 121 | 122 | if FLAGS.model_type in ['cke', 'jtransup']: 123 | FLAGS.share_embeddings = False 124 | elif FLAGS.model_type == 'cfkg': 125 | FLAGS.share_embeddings = True 126 | 127 | 128 | def init_model( 129 | FLAGS, 130 | user_total, 131 | item_total, 132 | entity_total, 133 | relation_total, 134 | logger, 135 | i_map=None, 136 | e_map=None, 137 | new_map=None): 138 | # Choose model. 
139 | logger.info("Building model.") 140 | if FLAGS.model_type == "transup": 141 | build_model = transup.build_model 142 | elif FLAGS.model_type == "bprmf": 143 | build_model = bprmf.build_model 144 | elif FLAGS.model_type == "fm": 145 | build_model = fm.build_model 146 | elif FLAGS.model_type == "transe": 147 | build_model = transe.build_model 148 | elif FLAGS.model_type == "transh": 149 | build_model = transh.build_model 150 | elif FLAGS.model_type == "transr": 151 | build_model = transr.build_model 152 | elif FLAGS.model_type == "transd": 153 | build_model = transd.build_model 154 | elif FLAGS.model_type == "cofm": 155 | build_model = cofm.build_model 156 | elif FLAGS.model_type == "cke": 157 | build_model = cke.build_model 158 | elif FLAGS.model_type == "cfkg": 159 | build_model = cfkg.build_model 160 | elif FLAGS.model_type == "jtransup": 161 | build_model = jtransup.build_model 162 | else: 163 | raise NotImplementedError 164 | 165 | model = build_model(FLAGS, user_total, item_total, entity_total, relation_total, 166 | i_map=i_map, e_map=e_map, new_map=new_map) 167 | 168 | # Print model size. 
169 | logger.info("Architecture: {}".format(model)) 170 | 171 | total_params = sum([reduce(lambda x, y: x * y, w.size(), 1.0) 172 | for w in model.parameters()]) 173 | logger.info("Total params: {}".format(total_params)) 174 | 175 | return model -------------------------------------------------------------------------------- /jTransUP/models/bprmf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable as V 5 | 6 | from jTransUP.utils.misc import to_gpu 7 | 8 | def build_model(FLAGS, user_total, item_total, entity_total, relation_total, i_map=None, e_map=None, new_map=None): 9 | model_cls = BPRMF 10 | return model_cls( 11 | FLAGS.embedding_size, 12 | user_total, 13 | item_total) 14 | 15 | class BPRMF(nn.Module): 16 | def __init__(self, 17 | embedding_size, 18 | user_total, 19 | item_total, 20 | ): 21 | super(BPRMF, self).__init__() 22 | self.embedding_size = embedding_size 23 | self.user_total = user_total 24 | self.item_total = item_total 25 | self.is_pretrained = False 26 | 27 | # init user and item embeddings 28 | user_weight = torch.FloatTensor(self.user_total, self.embedding_size) 29 | item_weight = torch.FloatTensor(self.item_total, self.embedding_size) 30 | nn.init.xavier_uniform(user_weight) 31 | nn.init.xavier_uniform(item_weight) 32 | # init user and item embeddings 33 | self.user_embeddings = nn.Embedding(self.user_total, self.embedding_size) 34 | self.item_embeddings = nn.Embedding(self.item_total, self.embedding_size) 35 | self.user_embeddings.weight = nn.Parameter(user_weight) 36 | self.item_embeddings.weight = nn.Parameter(item_weight) 37 | normalize_user_emb = F.normalize(self.user_embeddings.weight.data, p=2, dim=1) 38 | normalize_item_emb = F.normalize(self.item_embeddings.weight.data, p=2, dim=1) 39 | self.user_embeddings.weight.data = normalize_user_emb 40 | self.item_embeddings.weight.data = 
normalize_item_emb 41 | 42 | self.user_embeddings = to_gpu(self.user_embeddings) 43 | self.item_embeddings = to_gpu(self.item_embeddings) 44 | 45 | 46 | def forward(self, u_ids, i_ids): 47 | u_e = self.user_embeddings(u_ids) 48 | i_e = self.item_embeddings(i_ids) 49 | return torch.bmm(u_e.unsqueeze(1), i_e.unsqueeze(2)).squeeze() 50 | 51 | def evaluate(self, u_ids): 52 | u_e = self.user_embeddings(u_ids) 53 | 54 | return torch.matmul(u_e, self.item_embeddings.weight.t()) 55 | 56 | def disable_grad(self): 57 | for name, param in self.named_parameters(): 58 | param.requires_grad=False 59 | 60 | def enable_grad(self): 61 | for name, param in self.named_parameters(): 62 | param.requires_grad=True -------------------------------------------------------------------------------- /jTransUP/models/cofm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable as V 5 | 6 | from jTransUP.utils.misc import to_gpu, projection_transH_pytorch 7 | from jTransUP.models.transH import TransHModel 8 | from jTransUP.models.transUP import TransUPModel 9 | 10 | def build_model(FLAGS, user_total, item_total, entity_total, relation_total, i_map=None, e_map=None, new_map=None): 11 | model_cls = coFM 12 | return model_cls( 13 | L1_flag = FLAGS.L1_flag, 14 | embedding_size = FLAGS.embedding_size, 15 | user_total = user_total, 16 | item_total = item_total, 17 | entity_total = entity_total, 18 | relation_total = relation_total, 19 | isShare = FLAGS.share_embeddings 20 | ) 21 | 22 | class coFM(nn.Module): 23 | def __init__(self, 24 | L1_flag, 25 | embedding_size, 26 | user_total, 27 | item_total, 28 | entity_total, 29 | relation_total, 30 | isShare 31 | ): 32 | super(coFM, self).__init__() 33 | self.L1_flag = L1_flag 34 | self.is_share = isShare 35 | self.embedding_size = embedding_size 36 | self.user_total = user_total 37 | self.item_total = 
item_total 38 | self.ent_total = entity_total 39 | self.rel_total = relation_total 40 | self.is_pretrained = False 41 | # fm 42 | user_weight = torch.FloatTensor(self.user_total, self.embedding_size) 43 | nn.init.xavier_uniform(user_weight) 44 | self.user_embeddings = nn.Embedding(self.user_total, self.embedding_size) 45 | self.user_embeddings.weight = nn.Parameter(user_weight) 46 | normalize_user_emb = F.normalize(self.user_embeddings.weight.data, p=2, dim=1) 47 | self.user_embeddings.weight.data = normalize_user_emb 48 | self.user_embeddings = to_gpu(self.user_embeddings) 49 | 50 | user_bias = torch.FloatTensor(self.user_total) 51 | item_bias = torch.FloatTensor(self.item_total) 52 | nn.init.constant(user_bias, 0) 53 | nn.init.constant(item_bias, 0) 54 | self.user_bias = nn.Embedding(self.user_total, 1) 55 | self.item_bias = nn.Embedding(self.item_total, 1) 56 | self.user_bias.weight = nn.Parameter(user_bias, 1) 57 | self.item_bias.weight = nn.Parameter(item_bias, 1) 58 | 59 | self.user_bias = to_gpu(self.user_bias) 60 | self.item_bias = to_gpu(self.item_bias) 61 | 62 | self.bias = nn.Parameter(to_gpu(torch.FloatTensor([0.0]))) 63 | 64 | # trane 65 | 66 | rel_weight = torch.FloatTensor(self.rel_total, self.embedding_size) 67 | nn.init.xavier_uniform(rel_weight) 68 | self.rel_embeddings = nn.Embedding(self.rel_total, self.embedding_size) 69 | self.rel_embeddings.weight = nn.Parameter(rel_weight) 70 | normalize_rel_emb = F.normalize(self.rel_embeddings.weight.data, p=2, dim=1) 71 | self.rel_embeddings.weight.data = normalize_rel_emb 72 | self.rel_embeddings = to_gpu(self.rel_embeddings) 73 | 74 | # shared embedding 75 | ent_weight = torch.FloatTensor(self.ent_total, self.embedding_size) 76 | nn.init.xavier_uniform(ent_weight) 77 | self.ent_embeddings = nn.Embedding(self.ent_total, self.embedding_size) 78 | self.ent_embeddings.weight = nn.Parameter(ent_weight) 79 | normalize_ent_emb = F.normalize(self.ent_embeddings.weight.data, p=2, dim=1) 80 | 
self.ent_embeddings.weight.data = normalize_ent_emb 81 | self.ent_embeddings = to_gpu(self.ent_embeddings) 82 | 83 | if self.is_share: 84 | assert self.item_total == self.ent_total, "item numbers didn't match entities!" 85 | self.item_embeddings = self.ent_embeddings 86 | else: 87 | item_weight = torch.FloatTensor(self.item_total, self.embedding_size) 88 | nn.init.xavier_uniform(item_weight) 89 | self.item_embeddings = nn.Embedding(self.item_total, self.embedding_size) 90 | self.item_embeddings.weight = nn.Parameter(item_weight) 91 | normalize_item_emb = F.normalize(self.item_embeddings.weight.data, p=2, dim=1) 92 | self.item_embeddings.weight.data = normalize_item_emb 93 | self.item_embeddings = to_gpu(self.item_embeddings) 94 | 95 | def forward(self, ratings, triples, is_rec=True): 96 | 97 | if is_rec and ratings is not None: 98 | u_ids, i_ids = ratings 99 | batch_size = len(u_ids) 100 | 101 | u_e = self.user_embeddings(u_ids) 102 | i_e = self.item_embeddings(i_ids) 103 | u_b = self.user_bias(u_ids).squeeze() 104 | i_b = self.item_bias(i_ids).squeeze() 105 | 106 | score = self.bias.expand(batch_size) + u_b + i_b + torch.bmm(u_e.unsqueeze(1), i_e.unsqueeze(2)).squeeze() 107 | 108 | elif not is_rec and triples is not None: 109 | h, t, r = triples 110 | h_e = self.ent_embeddings(h) 111 | t_e = self.ent_embeddings(t) 112 | r_e = self.rel_embeddings(r) 113 | 114 | # L1 distance 115 | if self.L1_flag: 116 | score = torch.sum(torch.abs(h_e + r_e - t_e), 1) 117 | # L2 distance 118 | else: 119 | score = torch.sum((h_e + r_e - t_e) ** 2, 1) 120 | else: 121 | raise NotImplementedError 122 | 123 | return score 124 | 125 | def evaluateRec(self, u_ids, all_i_ids=None): 126 | batch_size = len(u_ids) 127 | all_i = self.item_embeddings(all_i_ids) if all_i_ids is not None and self.is_share else self.item_embeddings.weight 128 | all_i_b = self.item_bias(all_i_ids) if all_i_ids is not None and self.is_share else self.item_bias.weight 129 | item_total, _ = all_i.size() 130 | 131 | 
u_e = self.user_embeddings(u_ids) 132 | u_b = self.user_bias(u_ids).squeeze() 133 | # expand to batch * item 134 | u_b_e = u_b.expand(item_total, batch_size).permute(1, 0) 135 | i_b_e = all_i_b.squeeze().expand(batch_size, item_total) 136 | 137 | score = self.bias.expand(batch_size, item_total) + u_b_e + i_b_e + torch.matmul(u_e, all_i.t()) 138 | 139 | return score 140 | 141 | def evaluateHead(self, t, r, all_e_ids=None): 142 | batch_size = len(t) 143 | 144 | all_e = self.ent_embeddings(all_e_ids) if all_e_ids is not None and self.is_share else self.ent_embeddings.weight 145 | ent_total, dim = all_e.size() 146 | 147 | # batch * dim 148 | t_e = self.ent_embeddings(t) 149 | r_e = self.rel_embeddings(r) 150 | 151 | c_h_e = t_e - r_e 152 | 153 | # batch * entity * dim 154 | c_h_expand = c_h_e.expand(ent_total, batch_size, dim).permute(1, 0, 2) 155 | 156 | # batch * entity * dim 157 | ent_expand = all_e.expand(batch_size, ent_total, dim) 158 | 159 | # batch * entity 160 | if self.L1_flag: 161 | score = torch.sum(torch.abs(c_h_expand-ent_expand), 2) 162 | else: 163 | score = torch.sum((c_h_expand-ent_expand) ** 2, 2) 164 | 165 | return score 166 | 167 | def evaluateTail(self, h, r, all_e_ids=None): 168 | batch_size = len(h) 169 | 170 | all_e = self.ent_embeddings(all_e_ids) if all_e_ids is not None and self.is_share else self.ent_embeddings.weight 171 | ent_total, dim = all_e.size() 172 | 173 | # batch * dim 174 | h_e = self.ent_embeddings(h) 175 | r_e = self.rel_embeddings(r) 176 | 177 | c_t_e = h_e + r_e 178 | 179 | # batch * entity * dim 180 | c_t_expand = c_t_e.expand(ent_total, batch_size, dim).permute(1, 0, 2) 181 | 182 | # batch * entity * dim 183 | ent_expand = all_e.expand(batch_size, ent_total, dim) 184 | 185 | # batch * entity 186 | if self.L1_flag: 187 | score = torch.sum(torch.abs(c_t_expand-ent_expand), 2) 188 | else: 189 | score = torch.sum((c_t_expand-ent_expand) ** 2, 2) 190 | return score 191 | 192 | def disable_grad(self): 193 | for name, param in 
self.named_parameters(): 194 | param.requires_grad=False 195 | 196 | def enable_grad(self): 197 | for name, param in self.named_parameters(): 198 | param.requires_grad=True -------------------------------------------------------------------------------- /jTransUP/models/fm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable as V 5 | 6 | from jTransUP.utils.misc import to_gpu 7 | 8 | def build_model(FLAGS, user_total, item_total, entity_total, relation_total, i_map=None, e_map=None, new_map=None): 9 | model_cls = FM 10 | return model_cls( 11 | FLAGS.embedding_size, 12 | user_total, 13 | item_total) 14 | 15 | class FM(nn.Module): 16 | def __init__(self, 17 | embedding_size, 18 | user_total, 19 | item_total, 20 | ): 21 | super(FM, self).__init__() 22 | self.embedding_size = embedding_size 23 | self.user_total = user_total 24 | self.item_total = item_total 25 | self.is_pretrained = False 26 | 27 | # init user and item embeddings 28 | user_weight = torch.FloatTensor(self.user_total, self.embedding_size) 29 | item_weight = torch.FloatTensor(self.item_total, self.embedding_size) 30 | user_bias = torch.FloatTensor(self.user_total) 31 | item_bias = torch.FloatTensor(self.item_total) 32 | nn.init.xavier_uniform(user_weight) 33 | nn.init.xavier_uniform(item_weight) 34 | nn.init.constant(user_bias, 0) 35 | nn.init.constant(item_bias, 0) 36 | # init user and item embeddings 37 | self.user_embeddings = nn.Embedding(self.user_total, self.embedding_size) 38 | self.item_embeddings = nn.Embedding(self.item_total, self.embedding_size) 39 | self.user_bias = nn.Embedding(self.user_total, 1) 40 | self.item_bias = nn.Embedding(self.item_total, 1) 41 | self.user_embeddings.weight = nn.Parameter(user_weight) 42 | self.item_embeddings.weight = nn.Parameter(item_weight) 43 | self.user_bias.weight = nn.Parameter(user_bias, 1) 44 | 
self.item_bias.weight = nn.Parameter(item_bias, 1) 45 | normalize_user_emb = F.normalize(self.user_embeddings.weight.data, p=2, dim=1) 46 | normalize_item_emb = F.normalize(self.item_embeddings.weight.data, p=2, dim=1) 47 | self.user_embeddings.weight.data = normalize_user_emb 48 | self.item_embeddings.weight.data = normalize_item_emb 49 | 50 | self.user_embeddings = to_gpu(self.user_embeddings) 51 | self.item_embeddings = to_gpu(self.item_embeddings) 52 | self.user_bias = to_gpu(self.user_bias) 53 | self.item_bias = to_gpu(self.item_bias) 54 | 55 | self.bias = nn.Parameter(to_gpu(torch.FloatTensor([0.0]))) 56 | 57 | def forward(self, u_ids, i_ids): 58 | batch_size = len(u_ids) 59 | 60 | u_e = self.user_embeddings(u_ids) 61 | i_e = self.item_embeddings(i_ids) 62 | u_b = self.user_bias(u_ids).squeeze() 63 | i_b = self.item_bias(i_ids).squeeze() 64 | 65 | y = self.bias.expand(batch_size) + u_b + i_b + torch.bmm(u_e.unsqueeze(1), i_e.unsqueeze(2)).squeeze() 66 | return y 67 | 68 | def evaluate(self, u_ids): 69 | batch_size = len(u_ids) 70 | u_e = self.user_embeddings(u_ids) 71 | u_b = self.user_bias(u_ids).squeeze() 72 | 73 | # expand to batch * item 74 | u_b_e = u_b.expand(self.item_total, batch_size).permute(1, 0) 75 | i_b_e = self.item_bias.weight.squeeze().expand(batch_size, self.item_total) 76 | 77 | y_e = self.bias.expand(batch_size, self.item_total) + u_b_e + i_b_e + torch.matmul(u_e, self.item_embeddings.weight.t()) 78 | 79 | return y_e 80 | 81 | def disable_grad(self): 82 | for name, param in self.named_parameters(): 83 | param.requires_grad=False 84 | 85 | def enable_grad(self): 86 | for name, param in self.named_parameters(): 87 | param.requires_grad=True -------------------------------------------------------------------------------- /jTransUP/models/item_recommendation.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import gflags 3 | import sys 4 | import os 5 | import json 6 | from tqdm import 
tqdm 7 | tqdm.monitor_interval = 0 8 | import math 9 | import time 10 | import random 11 | import numpy as np 12 | 13 | import torch 14 | import torch.nn as nn 15 | from torch.autograd import Variable as V 16 | 17 | from jTransUP.models.base import get_flags, flag_defaults, init_model 18 | from jTransUP.data.load_rating_data import load_data 19 | from jTransUP.utils.trainer import ModelTrainer 20 | from jTransUP.utils.misc import evalRecProcess, to_gpu, USE_CUDA 21 | from jTransUP.utils.loss import bprLoss, orthogonalLoss, normLoss 22 | from jTransUP.utils.visuliazer import Visualizer 23 | from jTransUP.utils.data import getNegRatings 24 | 25 | FLAGS = gflags.FLAGS 26 | 27 | def evaluate(FLAGS, model, eval_iter, eval_dict, all_dicts, logger, eval_descending=True, is_report=False): 28 | 29 | # Evaluate 30 | total_batches = len(eval_iter) 31 | # progress bar 32 | pbar = tqdm(total=total_batches) 33 | pbar.set_description("Run Eval") 34 | 35 | model.eval() 36 | model.disable_grad() 37 | 38 | results = [] 39 | for u_ids in eval_iter: 40 | u_var = to_gpu(V(torch.LongTensor(u_ids))) 41 | # batch * item 42 | scores = model.evaluate(u_var) 43 | preds = zip(u_ids, scores.data.cpu().numpy()) 44 | 45 | results.extend( evalRecProcess(list(preds), eval_dict, all_dicts=all_dicts, descending=eval_descending, num_processes=FLAGS.num_processes, topn=FLAGS.topn, queue_limit=FLAGS.max_queue) ) 46 | 47 | pbar.update(1) 48 | pbar.close() 49 | 50 | performances = [result[:5] for result in results] 51 | f1, p, r, hit, ndcg = np.array(performances).mean(axis=0) 52 | 53 | logger.info("f1:{:.4f}, p:{:.4f}, r:{:.4f}, hit:{:.4f}, ndcg:{:.4f}, topn:{}.".format(f1, p, r, hit, ndcg, FLAGS.topn)) 54 | 55 | if is_report: 56 | predict_tuples = [result[-1] for result in results] 57 | for pred_tuple in predict_tuples: 58 | u_id = pred_tuple[0] 59 | top_ids = pred_tuple[1] 60 | gold_ids = list(pred_tuple[2]) 61 | if FLAGS.model_type in ["transup", "jtransup", "cjtransup"]: 62 | for d in all_dicts: 63 |
gold_ids += list(d.get(u_id, set())) 64 | gold_ids += list(eval_dict.get(u_id, set())) 65 | u_var = to_gpu(V(torch.LongTensor([u_id]))) 66 | i_var = to_gpu(V(torch.LongTensor(gold_ids))) 67 | # item_num * relation_total 68 | probs, _, _ = model.reportPreference(u_var, i_var) 69 | max_rel_index = torch.max(probs, 1)[1] 70 | gold_strs = ",".join(["{}({})".format(ir[0], ir[1]) for ir in zip(gold_ids, max_rel_index.data.tolist())]) 71 | else: 72 | gold_strs = ",".join([str(i) for i in gold_ids]) 73 | logger.info("user:{}\tgold:{}\ttop:{}".format(u_id, gold_strs, ",".join([str(i) for i in top_ids]))) 74 | model.enable_grad() 75 | return f1, p, r, hit, ndcg 76 | 77 | def train_loop(FLAGS, model, trainer, train_dataset, eval_datasets, 78 | user_total, item_total, logger, vis=None, is_report=False): 79 | train_iter, train_total, train_list, train_dict = train_dataset 80 | 81 | all_dicts = None 82 | if FLAGS.filter_wrong_corrupted: 83 | all_dicts = [train_dict] + [tmp_data[3] for tmp_data in eval_datasets] 84 | 85 | # Train. 86 | logger.info("Training.") 87 | 88 | # New Training Loop 89 | pbar = None 90 | total_loss = 0.0 91 | model.train() 92 | model.enable_grad() 93 | for _ in range(trainer.step, FLAGS.training_steps): 94 | 95 | if FLAGS.early_stopping_steps_to_wait > 0 and (trainer.step - trainer.best_step) > FLAGS.early_stopping_steps_to_wait: 96 | logger.info('No improvement after ' + 97 | str(FLAGS.early_stopping_steps_to_wait) + 98 | ' steps. 
Stopping training.') 99 | if pbar is not None: pbar.close() 100 | break 101 | if trainer.step % FLAGS.eval_interval_steps == 0 : 102 | if pbar is not None: 103 | pbar.close() 104 | total_loss /= FLAGS.eval_interval_steps 105 | logger.info("train loss:{:.4f}!".format(total_loss)) 106 | 107 | performances = [] 108 | for i, eval_data in enumerate(eval_datasets): 109 | all_eval_dicts = None 110 | if FLAGS.filter_wrong_corrupted: 111 | all_eval_dicts = [train_dict] + [tmp_data[3] for j, tmp_data in enumerate(eval_datasets) if j!=i] 112 | 113 | performances.append( evaluate(FLAGS, model, eval_data[0], eval_data[3], all_eval_dicts, logger, eval_descending=True if trainer.model_target == 1 else False, is_report=is_report)) 114 | is_best = False # stays False when no performance is recorded (e.g. at step 0) 115 | if trainer.step > 0 and len(performances) > 0: 116 | is_best = trainer.new_performance(performances[0], performances) 117 | 118 | # visualization 119 | if vis is not None: 120 | vis.plot_many_stack({'Rec Train Loss': total_loss}, 121 | win_name="Loss Curve") 122 | f1_vis_dict = {} 123 | p_vis_dict = {} 124 | r_vis_dict = {} 125 | hit_vis_dict = {} 126 | ndcg_vis_dict = {} 127 | for i, performance in enumerate(performances): 128 | f1_vis_dict['Rec Eval {} F1'.format(i)] = performance[0] 129 | p_vis_dict['Rec Eval {} Precision'.format(i)] = performance[1] 130 | r_vis_dict['Rec Eval {} Recall'.format(i)] = performance[2] 131 | hit_vis_dict['Rec Eval {} Hit'.format(i)] = performance[3] 132 | ndcg_vis_dict['Rec Eval {} NDCG'.format(i)] = performance[4] 133 | 134 | if is_best: 135 | log_str = ["Best performances in {} step!".format(trainer.best_step)] 136 | log_str += ["{} : {}.".format(s, "%.5f" % f1_vis_dict[s]) for s in f1_vis_dict] 137 | log_str += ["{} : {}.".format(s, "%.5f" % p_vis_dict[s]) for s in p_vis_dict] 138 | log_str += ["{} : {}.".format(s, "%.5f" % r_vis_dict[s]) for s in r_vis_dict] 139 | log_str += ["{} : {}.".format(s, "%.5f" % hit_vis_dict[s]) for s in hit_vis_dict] 140 | log_str += ["{} : {}.".format(s, "%.5f" %
ndcg_vis_dict[s]) for s in ndcg_vis_dict] 141 | 142 | vis.log("\n".join(log_str), win_name="Best Performances") 143 | 144 | vis.plot_many_stack(f1_vis_dict, win_name="Rec F1 Score@{}".format(FLAGS.topn)) 145 | 146 | vis.plot_many_stack(p_vis_dict, win_name="Rec Precision@{}".format(FLAGS.topn)) 147 | 148 | vis.plot_many_stack(r_vis_dict, win_name="Rec Recall@{}".format(FLAGS.topn)) 149 | 150 | vis.plot_many_stack(hit_vis_dict, win_name="Rec Hit Ratio@{}".format(FLAGS.topn)) 151 | 152 | vis.plot_many_stack(ndcg_vis_dict, win_name="Rec NDCG@{}".format(FLAGS.topn)) 153 | 154 | # set model in training mode 155 | pbar = tqdm(total=FLAGS.eval_interval_steps) 156 | pbar.set_description("Training") 157 | total_loss = 0.0 158 | model.train() 159 | model.enable_grad() 160 | 161 | rating_batch = next(train_iter) 162 | u, pi, ni = getNegRatings(rating_batch, item_total, all_dicts=all_dicts) 163 | 164 | u_var = to_gpu(V(torch.LongTensor(u))) 165 | pi_var = to_gpu(V(torch.LongTensor(pi))) 166 | ni_var = to_gpu(V(torch.LongTensor(ni))) 167 | 168 | trainer.optimizer_zero_grad() 169 | 170 | # Run model. output: batch_size * cand_num 171 | pos_score = model(u_var, pi_var) 172 | neg_score = model(u_var, ni_var) 173 | 174 | # Calculate loss. 175 | losses = bprLoss(pos_score, neg_score, target=trainer.model_target) 176 | 177 | if FLAGS.model_type in ["transup","transupb"]: 178 | user_embeddings = model.user_embeddings(u_var) 179 | item_embeddings = model.item_embeddings(torch.cat([pi_var, ni_var])) 180 | losses += orthogonalLoss(model.pref_embeddings.weight, model.pref_norm_embeddings.weight) + normLoss(user_embeddings) + normLoss(item_embeddings) + normLoss(model.pref_embeddings.weight) 181 | 182 | # Backward pass. 
183 | losses.backward() 184 | 185 | # for param in model.parameters(): 186 | # print(param.grad.data.sum()) 187 | 188 | # Hard Gradient Clipping 189 | nn.utils.clip_grad_norm([param for name, param in model.named_parameters()], FLAGS.clipping_max_value) 190 | 191 | # Gradient descent step. 192 | trainer.optimizer_step() 193 | total_loss += losses.data[0] 194 | pbar.update(1) 195 | 196 | def run(only_forward=False): 197 | if FLAGS.seed != 0: 198 | random.seed(FLAGS.seed) 199 | torch.manual_seed(FLAGS.seed) 200 | 201 | # set visualization 202 | vis = None 203 | if FLAGS.has_visualization: 204 | vis = Visualizer(env=FLAGS.experiment_name, port=FLAGS.visualization_port) 205 | vis.log(json.dumps(FLAGS.FlagValuesDict(), indent=4, sort_keys=True), 206 | win_name="Parameter") 207 | 208 | # set logger 209 | log_file = os.path.join(FLAGS.log_path, FLAGS.experiment_name + ".log") 210 | logger = logging.getLogger() 211 | log_level = logging.DEBUG if FLAGS.log_level == "debug" else logging.INFO 212 | logger.setLevel(level=log_level) 213 | # Formatter 214 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 215 | # FileHandler 216 | file_handler = logging.FileHandler(log_file) 217 | file_handler.setFormatter(formatter) 218 | logger.addHandler(file_handler) 219 | # StreamHandler 220 | stream_handler = logging.StreamHandler() 221 | stream_handler.setFormatter(formatter) 222 | logger.addHandler(stream_handler) 223 | 224 | logger.info("Flag Values:\n" + json.dumps(FLAGS.FlagValuesDict(), indent=4, sort_keys=True)) 225 | 226 | # load data 227 | dataset_path = os.path.join(FLAGS.data_path, FLAGS.dataset) 228 | eval_files = FLAGS.rec_test_files.split(':') 229 | 230 | train_dataset, eval_datasets, u_map, i_map = load_data(dataset_path, eval_files, FLAGS.batch_size, logger=logger, negtive_samples=FLAGS.negtive_samples) 231 | 232 | train_iter, train_total, train_list, train_dict = train_dataset 233 | 234 | user_total = max(len(u_map), 
max(u_map.values())) 235 | item_total = max(len(i_map), max(i_map.values())) 236 | 237 | model = init_model(FLAGS, user_total, item_total, 0, 0, logger) 238 | epoch_length = math.ceil( train_total / FLAGS.batch_size ) 239 | trainer = ModelTrainer(model, logger, epoch_length, FLAGS) 240 | 241 | if FLAGS.load_ckpt_file is not None: 242 | trainer.loadEmbedding(os.path.join(FLAGS.log_path, FLAGS.load_ckpt_file), model.state_dict(), cpu=not USE_CUDA) 243 | model.is_pretrained = True 244 | 245 | # Do an evaluation-only run. 246 | if only_forward: 247 | for i, eval_data in enumerate(eval_datasets): 248 | all_dicts = None 249 | if FLAGS.filter_wrong_corrupted: 250 | all_dicts = [train_dict] + [tmp_data[3] for j, tmp_data in enumerate(eval_datasets) if j!=i] 251 | evaluate( 252 | FLAGS, 253 | model, 254 | eval_data[0], 255 | eval_data[3], 256 | all_dicts, 257 | logger, 258 | eval_descending=True if trainer.model_target == 1 else False, 259 | is_report=FLAGS.is_report) 260 | else: 261 | train_loop( 262 | FLAGS, 263 | model, 264 | trainer, 265 | train_dataset, 266 | eval_datasets, 267 | user_total, 268 | item_total, 269 | logger, 270 | vis=vis, 271 | is_report=False) 272 | if vis is not None: 273 | vis.log("Finish!", win_name="Best Performances") 274 | 275 | if __name__ == '__main__': 276 | get_flags() 277 | 278 | # Parse command line flags. 
279 | FLAGS(sys.argv) 280 | flag_defaults(FLAGS) 281 | 282 | run(only_forward=FLAGS.eval_only_mode) 283 | -------------------------------------------------------------------------------- /jTransUP/models/knowledge_representation.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import gflags 3 | import sys 4 | import os 5 | import json 6 | from tqdm import tqdm 7 | tqdm.monitor_interval=0 8 | import math 9 | import time 10 | import random 11 | import numpy as np 12 | 13 | import torch 14 | import torch.nn as nn 15 | from torch.autograd import Variable as V 16 | 17 | from jTransUP.models.base import get_flags, flag_defaults, init_model 18 | from jTransUP.data.load_triple_data import load_data 19 | from jTransUP.utils.trainer import ModelTrainer 20 | from jTransUP.utils.misc import to_gpu, evalKGProcess, USE_CUDA 21 | from jTransUP.utils.loss import bprLoss, orthogonalLoss, normLoss 22 | from jTransUP.utils.visuliazer import Visualizer 23 | from jTransUP.utils.data import getTrainTripleBatch 24 | import jTransUP.utils.loss as loss 25 | 26 | FLAGS = gflags.FLAGS 27 | 28 | def evaluate(FLAGS, model, entity_total, relation_total, eval_head_iter, eval_tail_iter, eval_head_dict, eval_tail_dict, all_head_dicts, all_tail_dicts, logger, eval_descending=True, is_report=False): 29 | # Evaluate 30 | total_batches = len(eval_head_iter) + len(eval_tail_iter) 31 | # processing bar 32 | pbar = tqdm(total=total_batches) 33 | pbar.set_description("Run Eval") 34 | 35 | model.eval() 36 | model.disable_grad() 37 | 38 | # head prediction evaluation 39 | head_results = [] 40 | for batch_trs in eval_head_iter: 41 | t = [tr[0] for tr in batch_trs] 42 | r = [tr[1] for tr in batch_trs] 43 | t_var = to_gpu(V(torch.LongTensor(t))) 44 | r_var = to_gpu(V(torch.LongTensor(r))) 45 | # batch * item 46 | scores = model.evaluateHead(t_var, r_var) 47 | preds = zip(batch_trs, scores.data.cpu().numpy()) 48 | 49 | head_results.extend(
evalKGProcess(list(preds), eval_head_dict, all_dicts=all_head_dicts, descending=eval_descending, num_processes=FLAGS.num_processes, topn=FLAGS.topn, queue_limit=FLAGS.max_queue) ) 50 | 51 | pbar.update(1) 52 | # tail prediction evaluation 53 | tail_results = [] 54 | for batch_hrs in eval_tail_iter: 55 | h = [hr[0] for hr in batch_hrs] 56 | r = [hr[1] for hr in batch_hrs] 57 | h_var = to_gpu(V(torch.LongTensor(h))) 58 | r_var = to_gpu(V(torch.LongTensor(r))) 59 | # batch * item 60 | scores = model.evaluateTail(h_var, r_var) 61 | preds = zip(batch_hrs, scores.data.cpu().numpy()) 62 | 63 | tail_results.extend( evalKGProcess(list(preds), eval_tail_dict, all_dicts=all_tail_dicts, descending=eval_descending, num_processes=FLAGS.num_processes, topn=FLAGS.topn, queue_limit=FLAGS.max_queue) ) 64 | 65 | pbar.update(1) 66 | 67 | pbar.close() 68 | 69 | # hit, rank 70 | head_performances = [result[:2] for result in head_results] 71 | tail_performances = [result[:2] for result in tail_results] 72 | 73 | head_hit, head_mean_rank = np.array(head_performances).mean(axis=0) 74 | 75 | tail_hit, tail_mean_rank = np.array(tail_performances).mean(axis=0) 76 | 77 | logger.info("head hit:{:.4f}, head mean rank:{:.4f}, topn:{}.".format(head_hit, head_mean_rank, FLAGS.topn)) 78 | 79 | logger.info("tail hit:{:.4f}, tail mean rank:{:.4f}, topn:{}.".format(tail_hit, tail_mean_rank, FLAGS.topn)) 80 | 81 | head_num = len(head_results) 82 | tail_num = len(tail_results) 83 | 84 | avg_hit = float(head_hit * head_num + tail_hit * tail_num) / (head_num + tail_num) 85 | avg_mean_rank = float(head_mean_rank * head_num + tail_mean_rank * tail_num) / (head_num + tail_num) 86 | 87 | logger.info("avg hit:{:.4f}, avg mean rank:{:.4f}, topn:{}.".format(avg_hit, avg_mean_rank, FLAGS.topn)) 88 | 89 | if is_report: 90 | for result in head_results: 91 | hit = result[0] 92 | rank = result[1] 93 | t = result[2][0] 94 | r = result[2][1] 95 | gold_h = result[3] 96 | logger.info("H\t{}\t{}\t{}\t{}".format(gold_h, t,
r, hit)) 97 | for result in tail_results: 98 | hit = result[0] 99 | rank = result[1] 100 | h = result[2][0] 101 | r = result[2][1] 102 | gold_t = result[3] 103 | logger.info("T\t{}\t{}\t{}\t{}".format(h, gold_t, r, hit)) 104 | model.enable_grad() 105 | return avg_hit, avg_mean_rank 106 | 107 | def train_loop(FLAGS, model, trainer, train_dataset, eval_datasets, 108 | entity_total, relation_total, logger, vis=None, is_report=False): 109 | train_iter, train_total, train_list, train_head_dict, train_tail_dict = train_dataset 110 | 111 | all_head_dicts = None 112 | all_tail_dicts = None 113 | if FLAGS.filter_wrong_corrupted: 114 | all_head_dicts = [train_head_dict] + [tmp_data[4] for tmp_data in eval_datasets] 115 | all_tail_dicts = [train_tail_dict] + [tmp_data[5] for tmp_data in eval_datasets] 116 | 117 | # Train. 118 | logger.info("Training.") 119 | 120 | # New Training Loop 121 | pbar = None 122 | total_loss = 0.0 123 | model.enable_grad() 124 | for _ in range(trainer.step, FLAGS.training_steps): 125 | 126 | if FLAGS.early_stopping_steps_to_wait > 0 and (trainer.step - trainer.best_step) > FLAGS.early_stopping_steps_to_wait: 127 | logger.info('No improvement after ' + 128 | str(FLAGS.early_stopping_steps_to_wait) + 129 | ' steps. 
Stopping training.') 130 | if pbar is not None: pbar.close() 131 | break 132 | if trainer.step % FLAGS.eval_interval_steps == 0 : 133 | if pbar is not None: 134 | pbar.close() 135 | total_loss /= FLAGS.eval_interval_steps 136 | logger.info("train loss:{:.4f}!".format(total_loss)) 137 | 138 | performances = [] 139 | for i, eval_data in enumerate(eval_datasets): 140 | eval_head_dicts = None 141 | eval_tail_dicts = None 142 | if FLAGS.filter_wrong_corrupted: 143 | eval_head_dicts = [train_head_dict] + [tmp_data[4] for j, tmp_data in enumerate(eval_datasets) if j!=i] 144 | eval_tail_dicts = [train_tail_dict] + [tmp_data[5] for j, tmp_data in enumerate(eval_datasets) if j!=i] 145 | 146 | performances.append( evaluate(FLAGS, model, entity_total, relation_total, eval_data[0], eval_data[1], eval_data[4], eval_data[5], eval_head_dicts, eval_tail_dicts, logger, eval_descending=False, is_report=is_report)) 147 | is_best = False # stays False when no performance is recorded (e.g. at step 0) 148 | if trainer.step > 0 and len(performances) > 0 : 149 | is_best = trainer.new_performance(performances[0], performances) 150 | # visualization 151 | if vis is not None: 152 | vis.plot_many_stack({'KG Train Loss': total_loss}, 153 | win_name="Loss Curve") 154 | hit_vis_dict = {} 155 | meanrank_vis_dict = {} 156 | for i, performance in enumerate(performances): 157 | hit_vis_dict['KG Eval {} Hit'.format(i)] = performance[0] 158 | meanrank_vis_dict['KG Eval {} MeanRank'.format(i)] = performance[1] 159 | 160 | if is_best: 161 | log_str = ["Best performances in {} step!".format(trainer.best_step)] 162 | log_str += ["{} : {}.".format(s, "%.5f" % hit_vis_dict[s]) for s in hit_vis_dict] 163 | log_str += ["{} : {}.".format(s, "%.5f" % meanrank_vis_dict[s]) for s in meanrank_vis_dict] 164 | vis.log("\n".join(log_str), win_name="Best Performances") 165 | 166 | vis.plot_many_stack(hit_vis_dict, win_name="KG Hit Ratio@{}".format(FLAGS.topn)) 167 | 168 | vis.plot_many_stack(meanrank_vis_dict, win_name="KG MeanRank") 169 | # set model in training mode 170 | pbar =
tqdm(total=FLAGS.eval_interval_steps) 171 | pbar.set_description("Training") 172 | total_loss = 0.0 173 | model.train() 174 | model.enable_grad() 175 | 176 | triple_batch = next(train_iter) 177 | ph, pt, pr, nh, nt, nr = getTrainTripleBatch(triple_batch, entity_total, all_head_dicts=all_head_dicts, all_tail_dicts=all_tail_dicts) 178 | 179 | ph_var = to_gpu(V(torch.LongTensor(ph))) 180 | pt_var = to_gpu(V(torch.LongTensor(pt))) 181 | pr_var = to_gpu(V(torch.LongTensor(pr))) 182 | nh_var = to_gpu(V(torch.LongTensor(nh))) 183 | nt_var = to_gpu(V(torch.LongTensor(nt))) 184 | nr_var = to_gpu(V(torch.LongTensor(nr))) 185 | 186 | trainer.optimizer_zero_grad() 187 | 188 | # Run model. output: batch_size * 1 189 | pos_score = model(ph_var, pt_var, pr_var) 190 | neg_score = model(nh_var, nt_var, nr_var) 191 | 192 | # Calculate loss. 193 | # losses = nn.MarginRankingLoss(margin=FLAGS.margin).forward(pos_score, neg_score, to_gpu(torch.autograd.Variable(torch.FloatTensor([trainer.model_target]*len(ph))))) 194 | 195 | losses = loss.marginLoss()(pos_score, neg_score, FLAGS.margin) 196 | 197 | ent_embeddings = model.ent_embeddings(torch.cat([ph_var, pt_var, nh_var, nt_var])) 198 | rel_embeddings = model.rel_embeddings(torch.cat([pr_var, nr_var])) 199 | 200 | if FLAGS.model_type == "transh": 201 | norm_embeddings = model.norm_embeddings(torch.cat([pr_var, nr_var])) 202 | losses += loss.orthogonalLoss(rel_embeddings, norm_embeddings) 203 | 204 | losses = losses + loss.normLoss(ent_embeddings) + loss.normLoss(rel_embeddings) 205 | 206 | # Backward pass. 207 | losses.backward() 208 | 209 | # for param in model.parameters(): 210 | # print(param.grad.data.sum()) 211 | 212 | # Hard Gradient Clipping 213 | nn.utils.clip_grad_norm([param for name, param in model.named_parameters()], FLAGS.clipping_max_value) 214 | 215 | # Gradient descent step. 
216 | trainer.optimizer_step() 217 | total_loss += losses.data[0] 218 | pbar.update(1) 219 | trainer.save(trainer.checkpoint_path + '_final') 220 | 221 | def run(only_forward=False): 222 | if FLAGS.seed != 0: 223 | random.seed(FLAGS.seed) 224 | torch.manual_seed(FLAGS.seed) 225 | 226 | # set visualization 227 | vis = None 228 | if FLAGS.has_visualization: 229 | vis = Visualizer(env=FLAGS.experiment_name, port=FLAGS.visualization_port) 230 | vis.log(json.dumps(FLAGS.FlagValuesDict(), indent=4, sort_keys=True), 231 | win_name="Parameter") 232 | 233 | # set logger 234 | log_file = os.path.join(FLAGS.log_path, FLAGS.experiment_name + ".log") 235 | 236 | logger = logging.getLogger() 237 | log_level = logging.DEBUG if FLAGS.log_level == "debug" else logging.INFO 238 | logger.setLevel(level=log_level) 239 | # Formatter 240 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 241 | # FileHandler 242 | file_handler = logging.FileHandler(log_file) 243 | file_handler.setFormatter(formatter) 244 | logger.addHandler(file_handler) 245 | # StreamHandler 246 | stream_handler = logging.StreamHandler() 247 | stream_handler.setFormatter(formatter) 248 | logger.addHandler(stream_handler) 249 | 250 | logger.info("Flag Values:\n" + json.dumps(FLAGS.FlagValuesDict(), indent=4, sort_keys=True)) 251 | 252 | # load data 253 | kg_path = os.path.join(os.path.join(FLAGS.data_path, FLAGS.dataset), 'kg') 254 | eval_files = [] 255 | if FLAGS.kg_test_files: 256 | eval_files = FLAGS.kg_test_files.split(':') 257 | 258 | train_dataset, eval_datasets, e_map, r_map = load_data(kg_path, eval_files, FLAGS.batch_size, logger=logger, negtive_samples=FLAGS.negtive_samples) 259 | 260 | entity_total = max(len(e_map), max(e_map.values())) 261 | relation_total = max(len(r_map), max(r_map.values())) 262 | 263 | train_iter, train_total, train_list, train_head_dict, train_tail_dict = train_dataset 264 | 265 | model = init_model(FLAGS, 0, 0, entity_total, relation_total, logger) 
266 | epoch_length = math.ceil( train_total / FLAGS.batch_size ) 267 | trainer = ModelTrainer(model, logger, epoch_length, FLAGS) 268 | 269 | # todo : load ckpt full path 270 | if FLAGS.load_ckpt_file is not None: 271 | trainer.loadEmbedding(os.path.join(FLAGS.log_path, FLAGS.load_ckpt_file), model.state_dict(), cpu=not USE_CUDA) 272 | model.is_pretrained = True 273 | 274 | # Do an evaluation-only run. 275 | if only_forward: 276 | # head_iter, tail_iter, eval_total, eval_list, eval_head_dict, eval_tail_dict 277 | for i, eval_data in enumerate(eval_datasets): 278 | all_head_dicts = None 279 | all_tail_dicts = None 280 | if FLAGS.filter_wrong_corrupted: 281 | all_head_dicts = [train_head_dict] + [tmp_data[4] for j, tmp_data in enumerate(eval_datasets) if j!=i] 282 | all_tail_dicts = [train_tail_dict] + [tmp_data[5] for j, tmp_data in enumerate(eval_datasets) if j!=i] 283 | evaluate( 284 | FLAGS, 285 | model, 286 | entity_total, 287 | relation_total, 288 | eval_data[0], 289 | eval_data[1], 290 | eval_data[4], 291 | eval_data[5], 292 | all_head_dicts, 293 | all_tail_dicts, 294 | logger, 295 | eval_descending=False, 296 | is_report=FLAGS.is_report) 297 | else: 298 | train_loop( 299 | FLAGS, 300 | model, 301 | trainer, 302 | train_dataset, 303 | eval_datasets, 304 | entity_total, 305 | relation_total, 306 | logger, 307 | vis=vis, 308 | is_report=False) 309 | if vis is not None: 310 | vis.log("Finish!", win_name="Best Performances") 311 | 312 | if __name__ == '__main__': 313 | get_flags() 314 | 315 | # Parse command line flags. 
316 | FLAGS(sys.argv) 317 | flag_defaults(FLAGS) 318 | 319 | run(only_forward=FLAGS.eval_only_mode) 320 | -------------------------------------------------------------------------------- /jTransUP/models/transD.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable as V 5 | 6 | from jTransUP.utils.misc import to_gpu, projection_transD_pytorch_samesize 7 | 8 | def build_model(FLAGS, user_total, item_total, entity_total, relation_total, i_map=None, e_map=None, new_map=None): 9 | model_cls = TransDModel 10 | return model_cls( 11 | L1_flag = FLAGS.L1_flag, 12 | embedding_size = FLAGS.embedding_size, 13 | ent_total = entity_total, 14 | rel_total = relation_total 15 | ) 16 | 17 | class TransDModel(nn.Module): 18 | def __init__(self, 19 | L1_flag, 20 | embedding_size, 21 | ent_total, 22 | rel_total 23 | ): 24 | super(TransDModel, self).__init__() 25 | self.L1_flag = L1_flag 26 | self.embedding_size = embedding_size 27 | self.ent_total = ent_total 28 | self.rel_total = rel_total 29 | self.is_pretrained = False 30 | 31 | ent_weight = torch.FloatTensor(self.ent_total, self.embedding_size) 32 | rel_weight = torch.FloatTensor(self.rel_total, self.embedding_size) 33 | ent_proj_weight = torch.FloatTensor(self.ent_total, self.embedding_size) 34 | rel_proj_weight = torch.FloatTensor(self.rel_total, self.embedding_size) 35 | nn.init.xavier_uniform(ent_weight) 36 | nn.init.xavier_uniform(rel_weight) 37 | ent_proj_weight.zero_() 38 | rel_proj_weight.zero_() 39 | # init entity and relation embeddings, plus their projection vectors 40 | self.ent_embeddings = nn.Embedding(self.ent_total, self.embedding_size) 41 | self.rel_embeddings = nn.Embedding(self.rel_total, self.embedding_size) 42 | self.ent_proj_embeddings = nn.Embedding(self.ent_total, self.embedding_size) 43 | self.rel_proj_embeddings = nn.Embedding(self.rel_total, self.embedding_size) 44 | 45 |
self.ent_embeddings.weight = nn.Parameter(ent_weight) 46 | self.rel_embeddings.weight = nn.Parameter(rel_weight) 47 | self.ent_proj_embeddings.weight = nn.Parameter(ent_proj_weight) 48 | self.rel_proj_embeddings.weight = nn.Parameter(rel_proj_weight) 49 | 50 | normalize_ent_emb = F.normalize(self.ent_embeddings.weight.data, p=2, dim=1) 51 | normalize_rel_emb = F.normalize(self.rel_embeddings.weight.data, p=2, dim=1) 52 | 53 | self.ent_embeddings.weight.data = normalize_ent_emb 54 | self.rel_embeddings.weight.data = normalize_rel_emb 55 | 56 | self.ent_embeddings = to_gpu(self.ent_embeddings) 57 | self.rel_embeddings = to_gpu(self.rel_embeddings) 58 | self.ent_proj_embeddings = to_gpu(self.ent_proj_embeddings) 59 | self.rel_proj_embeddings = to_gpu(self.rel_proj_embeddings) 60 | 61 | def forward(self, h, t, r): 62 | h_e = self.ent_embeddings(h) 63 | t_e = self.ent_embeddings(t) 64 | r_e = self.rel_embeddings(r) 65 | h_proj = self.ent_proj_embeddings(h) 66 | t_proj = self.ent_proj_embeddings(t) 67 | r_proj = self.rel_proj_embeddings(r) 68 | 69 | proj_h_e = projection_transD_pytorch_samesize(h_e, h_proj, r_proj) 70 | proj_t_e = projection_transD_pytorch_samesize(t_e, t_proj, r_proj) 71 | 72 | if self.L1_flag: 73 | score = torch.sum(torch.abs(proj_h_e + r_e - proj_t_e), 1) 74 | else: 75 | score = torch.sum((proj_h_e + r_e - proj_t_e) ** 2, 1) 76 | return score 77 | 78 | def evaluateHead(self, t, r): 79 | batch_size = len(t) 80 | # batch * dim 81 | t_e = self.ent_embeddings(t) 82 | r_e = self.rel_embeddings(r) 83 | # batch* dim 84 | t_proj = self.ent_proj_embeddings(t) 85 | r_proj = self.rel_proj_embeddings(r) 86 | # batch * dim 87 | proj_t_e = projection_transD_pytorch_samesize(t_e, t_proj, r_proj) 88 | c_h_e = proj_t_e - r_e 89 | 90 | # batch * entity * dim 91 | c_h_expand = c_h_e.expand(self.ent_total, batch_size, self.embedding_size).permute(1, 0, 2) 92 | 93 | # batch * entity * dim 94 | t_proj_expand = t_proj.expand(self.ent_total, batch_size, 
self.embedding_size).permute(1, 0, 2) 95 | r_proj_expand = r_proj.expand(self.ent_total, batch_size, self.embedding_size).permute(1, 0, 2) 96 | 97 | ent_expand = self.ent_embeddings.weight.expand(batch_size, self.ent_total, self.embedding_size) 98 | proj_ent_expand = projection_transD_pytorch_samesize(ent_expand, t_proj_expand, r_proj_expand) 99 | 100 | # batch * entity 101 | if self.L1_flag: 102 | score = torch.sum(torch.abs(c_h_expand-proj_ent_expand), 2) 103 | else: 104 | score = torch.sum((c_h_expand-proj_ent_expand) ** 2, 2) 105 | return score 106 | 107 | def evaluateTail(self, h, r): 108 | batch_size = len(h) 109 | # batch * dim 110 | h_e = self.ent_embeddings(h) 111 | r_e = self.rel_embeddings(r) 112 | # batch* dim 113 | h_proj = self.ent_proj_embeddings(h) 114 | r_proj = self.rel_proj_embeddings(r) 115 | # batch * dim 116 | proj_h_e = projection_transD_pytorch_samesize(h_e, h_proj, r_proj) 117 | c_t_e = proj_h_e + r_e 118 | 119 | # batch * entity * dim 120 | c_t_expand = c_t_e.expand(self.ent_total, batch_size, self.embedding_size).permute(1, 0, 2) 121 | 122 | # batch * entity * dim 123 | h_proj_expand = h_proj.expand(self.ent_total, batch_size, self.embedding_size).permute(1, 0, 2) 124 | r_proj_expand = r_proj.expand(self.ent_total, batch_size, self.embedding_size).permute(1, 0, 2) 125 | 126 | ent_expand = self.ent_embeddings.weight.expand(batch_size, self.ent_total, self.embedding_size) 127 | proj_ent_expand = projection_transD_pytorch_samesize(ent_expand, h_proj_expand, r_proj_expand) 128 | 129 | # batch * entity 130 | if self.L1_flag: 131 | score = torch.sum(torch.abs(c_t_expand-proj_ent_expand), 2) 132 | else: 133 | score = torch.sum((c_t_expand-proj_ent_expand) ** 2, 2) 134 | return score 135 | 136 | def disable_grad(self): 137 | for name, param in self.named_parameters(): 138 | param.requires_grad=False 139 | 140 | def enable_grad(self): 141 | for name, param in self.named_parameters(): 142 | param.requires_grad=True
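For reference, the scoring that `forward` and `evaluateTail` in transD.py perform can be sketched without PyTorch. This is a minimal illustration under stated assumptions, not the repository's code: `project` and `tail_scores` are hypothetical names, and the projection formula `e + (e_p . e) * r_p` is the standard same-dimension TransD form, which `projection_transD_pytorch_samesize` is assumed (not confirmed here) to follow. Each candidate is projected with its own projection vector, as `forward` does for `h` and `t`.

```python
# Illustrative sketch only: plain-Python TransD tail scoring.
# Assumption: the projection is e + (e_p . e) * r_p (standard same-size TransD);
# the repository helper projection_transD_pytorch_samesize may differ.

def project(e, e_proj, r_proj):
    """Project entity vector e with its projection vector and the relation's."""
    dot = sum(a * b for a, b in zip(e, e_proj))
    return [a + dot * rp for a, rp in zip(e, r_proj)]


def tail_scores(h, r, ent, ent_proj, rel, rel_proj, l1=True):
    """Distance of every candidate tail to (projected head + relation).

    Lower is better, matching eval_descending=False in the KG evaluation.
    """
    head = project(ent[h], ent_proj[h], rel_proj[r])
    target = [a + b for a, b in zip(head, rel[r])]
    scores = []
    for t in range(len(ent)):
        cand = project(ent[t], ent_proj[t], rel_proj[r])
        diff = [a - b for a, b in zip(target, cand)]
        if l1:
            scores.append(sum(abs(d) for d in diff))   # L1 distance
        else:
            scores.append(sum(d * d for d in diff))    # squared L2 distance
    return scores


if __name__ == "__main__":
    # Toy data: three 2-d entities, one relation, zero projection vectors.
    ent = [[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]]
    rel = [[1.0, 0.0]]
    zero_ent_proj = [[0.0, 0.0] for _ in ent]
    zero_rel_proj = [[0.0, 0.0]]
    print(tail_scores(0, 0, ent, zero_ent_proj, rel, zero_rel_proj))
```

With all-zero projection vectors, which is how the model above initialises `ent_proj_weight` and `rel_proj_weight`, the projection is the identity and the scores coincide with plain TransE distances.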
-------------------------------------------------------------------------------- /jTransUP/models/transE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable as V 5 | 6 | from jTransUP.utils.misc import to_gpu 7 | 8 | def build_model(FLAGS, user_total, item_total, entity_total, relation_total, i_map=None, e_map=None, new_map=None): 9 | model_cls = TransEModel 10 | return model_cls( 11 | L1_flag = FLAGS.L1_flag, 12 | embedding_size = FLAGS.embedding_size, 13 | ent_total = entity_total, 14 | rel_total = relation_total 15 | ) 16 | 17 | class TransEModel(nn.Module): 18 | def __init__(self, 19 | L1_flag, 20 | embedding_size, 21 | ent_total, 22 | rel_total 23 | ): 24 | super(TransEModel, self).__init__() 25 | self.L1_flag = L1_flag 26 | self.embedding_size = embedding_size 27 | self.ent_total = ent_total 28 | self.rel_total = rel_total 29 | self.is_pretrained = False 30 | 31 | ent_weight = torch.FloatTensor(self.ent_total, self.embedding_size) 32 | rel_weight = torch.FloatTensor(self.rel_total, self.embedding_size) 33 | nn.init.xavier_uniform(ent_weight) 34 | nn.init.xavier_uniform(rel_weight) 35 | # init user and item embeddings 36 | self.ent_embeddings = nn.Embedding(self.ent_total, self.embedding_size) 37 | self.rel_embeddings = nn.Embedding(self.rel_total, self.embedding_size) 38 | 39 | self.ent_embeddings.weight = nn.Parameter(ent_weight) 40 | self.rel_embeddings.weight = nn.Parameter(rel_weight) 41 | 42 | normalize_ent_emb = F.normalize(self.ent_embeddings.weight.data, p=2, dim=1) 43 | normalize_rel_emb = F.normalize(self.rel_embeddings.weight.data, p=2, dim=1) 44 | 45 | self.ent_embeddings.weight.data = normalize_ent_emb 46 | self.rel_embeddings.weight.data = normalize_rel_emb 47 | 48 | self.ent_embeddings = to_gpu(self.ent_embeddings) 49 | self.rel_embeddings = to_gpu(self.rel_embeddings) 50 | 51 | def 
forward(self, h, t, r): 52 | h_e = self.ent_embeddings(h) 53 | t_e = self.ent_embeddings(t) 54 | r_e = self.rel_embeddings(r) 55 | 56 | # L1 distance 57 | if self.L1_flag: 58 | score = torch.sum(torch.abs(h_e + r_e - t_e), 1) 59 | # L2 distance 60 | else: 61 | score = torch.sum((h_e + r_e - t_e) ** 2, 1) 62 | 63 | return score 64 | 65 | def evaluateHead(self, t, r): 66 | batch_size = len(t) 67 | # batch * dim 68 | t_e = self.ent_embeddings(t) 69 | r_e = self.rel_embeddings(r) 70 | 71 | c_h_e = t_e - r_e 72 | 73 | # batch * entity * dim 74 | c_h_expand = c_h_e.expand(self.ent_total, batch_size, self.embedding_size).permute(1, 0, 2) 75 | 76 | # batch * entity * dim 77 | ent_expand = self.ent_embeddings.weight.expand(batch_size, self.ent_total, self.embedding_size) 78 | 79 | # batch * entity 80 | if self.L1_flag: 81 | score = torch.sum(torch.abs(c_h_expand-ent_expand), 2) 82 | else: 83 | score = torch.sum((c_h_expand-ent_expand) ** 2, 2) 84 | return score 85 | 86 | def evaluateTail(self, h, r): 87 | batch_size = len(h) 88 | # batch * dim 89 | h_e = self.ent_embeddings(h) 90 | r_e = self.rel_embeddings(r) 91 | 92 | c_t_e = h_e + r_e 93 | 94 | # batch * entity * dim 95 | c_t_expand = c_t_e.expand(self.ent_total, batch_size, self.embedding_size).permute(1, 0, 2) 96 | 97 | # batch * entity * dim 98 | ent_expand = self.ent_embeddings.weight.expand(batch_size, self.ent_total, self.embedding_size) 99 | 100 | # batch * entity 101 | if self.L1_flag: 102 | score = torch.sum(torch.abs(c_t_expand-ent_expand), 2) 103 | else: 104 | score = torch.sum((c_t_expand-ent_expand) ** 2, 2) 105 | return score 106 | 107 | def disable_grad(self): 108 | for name, param in self.named_parameters(): 109 | param.requires_grad=False 110 | 111 | def enable_grad(self): 112 | for name, param in self.named_parameters(): 113 | param.requires_grad=True -------------------------------------------------------------------------------- /jTransUP/models/transH.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable as V 5 | 6 | from jTransUP.utils.misc import to_gpu, projection_transH_pytorch 7 | 8 | def build_model(FLAGS, user_total, item_total, entity_total, relation_total, i_map=None, e_map=None, new_map=None): 9 | model_cls = TransHModel 10 | return model_cls( 11 | L1_flag = FLAGS.L1_flag, 12 | embedding_size = FLAGS.embedding_size, 13 | ent_total = entity_total, 14 | rel_total = relation_total 15 | ) 16 | 17 | class TransHModel(nn.Module): 18 | def __init__(self, 19 | L1_flag, 20 | embedding_size, 21 | ent_total, 22 | rel_total 23 | ): 24 | super(TransHModel, self).__init__() 25 | self.L1_flag = L1_flag 26 | self.embedding_size = embedding_size 27 | self.ent_total = ent_total 28 | self.rel_total = rel_total 29 | self.is_pretrained = False 30 | 31 | ent_weight = torch.FloatTensor(self.ent_total, self.embedding_size) 32 | rel_weight = torch.FloatTensor(self.rel_total, self.embedding_size) 33 | norm_weight = torch.FloatTensor(self.rel_total, self.embedding_size) 34 | nn.init.xavier_uniform(ent_weight) 35 | nn.init.xavier_uniform(rel_weight) 36 | nn.init.xavier_uniform(norm_weight) 37 | # init user and item embeddings 38 | self.ent_embeddings = nn.Embedding(self.ent_total, self.embedding_size) 39 | self.rel_embeddings = nn.Embedding(self.rel_total, self.embedding_size) 40 | self.norm_embeddings = nn.Embedding(self.rel_total, self.embedding_size) 41 | 42 | self.ent_embeddings.weight = nn.Parameter(ent_weight) 43 | self.rel_embeddings.weight = nn.Parameter(rel_weight) 44 | self.norm_embeddings.weight = nn.Parameter(norm_weight) 45 | 46 | normalize_ent_emb = F.normalize(self.ent_embeddings.weight.data, p=2, dim=1) 47 | normalize_rel_emb = F.normalize(self.rel_embeddings.weight.data, p=2, dim=1) 48 | normalize_norm_emb = F.normalize(self.norm_embeddings.weight.data, p=2, dim=1) 49 | 
50 | self.ent_embeddings.weight.data = normalize_ent_emb 51 | self.rel_embeddings.weight.data = normalize_rel_emb 52 | self.norm_embeddings.weight.data = normalize_norm_emb 53 | 54 | self.ent_embeddings = to_gpu(self.ent_embeddings) 55 | self.rel_embeddings = to_gpu(self.rel_embeddings) 56 | self.norm_embeddings = to_gpu(self.norm_embeddings) 57 | 58 | def forward(self, h, t, r): 59 | h_e = self.ent_embeddings(h) 60 | t_e = self.ent_embeddings(t) 61 | r_e = self.rel_embeddings(r) 62 | norm_e = self.norm_embeddings(r) 63 | 64 | proj_h_e = projection_transH_pytorch(h_e, norm_e) 65 | proj_t_e = projection_transH_pytorch(t_e, norm_e) 66 | 67 | if self.L1_flag: 68 | score = torch.sum(torch.abs(proj_h_e + r_e - proj_t_e), 1) 69 | else: 70 | score = torch.sum((proj_h_e + r_e - proj_t_e) ** 2, 1) 71 | return score 72 | 73 | def evaluateHead(self, t, r): 74 | batch_size = len(t) 75 | # batch * dim 76 | t_e = self.ent_embeddings(t) 77 | r_e = self.rel_embeddings(r) 78 | norm_e = self.norm_embeddings(r) 79 | 80 | proj_t_e = projection_transH_pytorch(t_e, norm_e) 81 | c_h_e = proj_t_e - r_e 82 | 83 | # batch * entity * dim 84 | c_h_expand = c_h_e.expand(self.ent_total, batch_size, self.embedding_size).permute(1, 0, 2) 85 | 86 | # batch * entity * dim 87 | norm_expand = norm_e.expand(self.ent_total, batch_size, self.embedding_size).permute(1, 0, 2) 88 | ent_expand = self.ent_embeddings.weight.expand(batch_size, self.ent_total, self.embedding_size) 89 | proj_ent_e = projection_transH_pytorch(ent_expand, norm_expand) 90 | 91 | # batch * entity 92 | if self.L1_flag: 93 | score = torch.sum(torch.abs(c_h_expand-proj_ent_e), 2) 94 | else: 95 | score = torch.sum((c_h_expand-proj_ent_e) ** 2, 2) 96 | return score 97 | 98 | def evaluateTail(self, h, r): 99 | batch_size = len(h) 100 | # batch * dim 101 | h_e = self.ent_embeddings(h) 102 | r_e = self.rel_embeddings(r) 103 | norm_e = self.norm_embeddings(r) 104 | 105 | proj_h_e = projection_transH_pytorch(h_e, norm_e) 106 | c_t_e = 
proj_h_e + r_e 107 | 108 | # batch * entity * dim 109 | c_t_expand = c_t_e.expand(self.ent_total, batch_size, self.embedding_size).permute(1, 0, 2) 110 | 111 | # batch * entity * dim 112 | norm_expand = norm_e.expand(self.ent_total, batch_size, self.embedding_size).permute(1, 0, 2) 113 | ent_expand = self.ent_embeddings.weight.expand(batch_size, self.ent_total, self.embedding_size) 114 | proj_ent_e = projection_transH_pytorch(ent_expand, norm_expand) 115 | 116 | # batch * entity 117 | if self.L1_flag: 118 | score = torch.sum(torch.abs(c_t_expand-proj_ent_e), 2) 119 | else: 120 | score = torch.sum((c_t_expand-proj_ent_e) ** 2, 2) 121 | return score 122 | 123 | def disable_grad(self): 124 | for name, param in self.named_parameters(): 125 | param.requires_grad=False 126 | 127 | def enable_grad(self): 128 | for name, param in self.named_parameters(): 129 | param.requires_grad=True -------------------------------------------------------------------------------- /jTransUP/models/transR.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable as V 5 | 6 | from jTransUP.utils.misc import to_gpu, projection_transR_pytorch, projection_transR_pytorch_batch 7 | 8 | def build_model(FLAGS, user_total, item_total, entity_total, relation_total, i_map=None, e_map=None, new_map=None): 9 | model_cls = TransRModel 10 | return model_cls( 11 | L1_flag = FLAGS.L1_flag, 12 | embedding_size = FLAGS.embedding_size, 13 | ent_total = entity_total, 14 | rel_total = relation_total 15 | ) 16 | 17 | class TransRModel(nn.Module): 18 | def __init__(self, 19 | L1_flag, 20 | embedding_size, 21 | ent_total, 22 | rel_total 23 | ): 24 | super(TransRModel, self).__init__() 25 | self.L1_flag = L1_flag 26 | self.embedding_size = embedding_size 27 | self.ent_total = ent_total 28 | self.rel_total = rel_total 29 | self.is_pretrained = False 30 | 
self.max_entity_batch = 10 31 | 32 | ent_weight = torch.FloatTensor(self.ent_total, self.embedding_size) 33 | rel_weight = torch.FloatTensor(self.rel_total, self.embedding_size) 34 | proj_weight = torch.FloatTensor(self.rel_total, self.embedding_size * self.embedding_size) 35 | nn.init.xavier_uniform(ent_weight) 36 | nn.init.xavier_uniform(rel_weight) 37 | 38 | if self.is_pretrained: 39 | # identity projection per relation, flattened to embedding_size * embedding_size 40 | proj_weight = torch.eye(self.embedding_size).view(-1).expand(self.rel_total, -1).clone() 41 | else: 42 | nn.init.xavier_uniform(proj_weight) 43 | 44 | # init entity, relation and projection embeddings 45 | self.ent_embeddings = nn.Embedding(self.ent_total, self.embedding_size) 46 | self.rel_embeddings = nn.Embedding(self.rel_total, self.embedding_size) 47 | self.proj_embeddings = nn.Embedding(self.rel_total, self.embedding_size * self.embedding_size) 48 | 49 | self.ent_embeddings.weight = nn.Parameter(ent_weight) 50 | self.rel_embeddings.weight = nn.Parameter(rel_weight) 51 | self.proj_embeddings.weight = nn.Parameter(proj_weight) 52 | 53 | normalize_ent_emb = F.normalize(self.ent_embeddings.weight.data, p=2, dim=1) 54 | normalize_rel_emb = F.normalize(self.rel_embeddings.weight.data, p=2, dim=1) 55 | # normalize_proj_emb = F.normalize(self.proj_embeddings.weight.data, p=2, dim=1) 56 | 57 | self.ent_embeddings.weight.data = normalize_ent_emb 58 | self.rel_embeddings.weight.data = normalize_rel_emb 59 | # self.proj_embeddings.weight.data = normalize_proj_emb 60 | 61 | self.ent_embeddings = to_gpu(self.ent_embeddings) 62 | self.rel_embeddings = to_gpu(self.rel_embeddings) 63 | self.proj_embeddings = to_gpu(self.proj_embeddings) 64 | 65 | def forward(self, h, t, r): 66 | h_e = self.ent_embeddings(h) 67 | t_e = self.ent_embeddings(t) 68 | r_e = self.rel_embeddings(r) 69 | proj_e = self.proj_embeddings(r) 70 | 71 | proj_h_e = projection_transR_pytorch(h_e, proj_e) 72 | proj_t_e = projection_transR_pytorch(t_e, proj_e) 73 | 74 | if self.L1_flag: 75 | score = torch.sum(torch.abs(proj_h_e + r_e - 
proj_t_e), 1) 76 | else: 77 | score = torch.sum((proj_h_e + r_e - proj_t_e) ** 2, 1) 78 | return score 79 | 80 | def evaluateHead(self, t, r): 81 | batch_size = len(t) 82 | 83 | # batch * dim 84 | t_e = self.ent_embeddings(t) 85 | r_e = self.rel_embeddings(r) 86 | # batch* dim*dim 87 | proj_e = self.proj_embeddings(r) 88 | # batch * dim 89 | proj_t_e = projection_transR_pytorch(t_e, proj_e) 90 | c_h_e = proj_t_e - r_e 91 | 92 | # batch * entity * dim 93 | c_h_expand = c_h_e.expand(self.ent_total, batch_size, self.embedding_size).permute(1, 0, 2) 94 | 95 | # batch * entity * dim 96 | proj_ent_expand = projection_transR_pytorch_batch(self.ent_embeddings.weight, proj_e) 97 | 98 | # batch * entity 99 | if self.L1_flag: 100 | score = torch.sum(torch.abs(c_h_expand-proj_ent_expand), 2) 101 | else: 102 | score = torch.sum((c_h_expand-proj_ent_expand) ** 2, 2) 103 | return score 104 | 105 | def evaluateTail(self, h, r): 106 | batch_size = len(h) 107 | 108 | # batch * dim 109 | h_e = self.ent_embeddings(h) 110 | r_e = self.rel_embeddings(r) 111 | # batch* dim*dim 112 | proj_e = self.proj_embeddings(r) 113 | # batch * dim 114 | proj_h_e = projection_transR_pytorch(h_e, proj_e) 115 | c_t_e = proj_h_e + r_e 116 | 117 | # batch * entity * dim 118 | c_t_expand = c_t_e.expand(self.ent_total, batch_size, self.embedding_size).permute(1, 0, 2) 119 | 120 | # batch * entity * dim 121 | proj_ent_expand = projection_transR_pytorch_batch(self.ent_embeddings.weight, proj_e) 122 | 123 | # batch * entity 124 | if self.L1_flag: 125 | score = torch.sum(torch.abs(c_t_expand-proj_ent_expand), 2) 126 | else: 127 | score = torch.sum((c_t_expand-proj_ent_expand) ** 2, 2) 128 | return score 129 | 130 | def disable_grad(self): 131 | for name, param in self.named_parameters(): 132 | param.requires_grad=False 133 | 134 | def enable_grad(self): 135 | for name, param in self.named_parameters(): 136 | param.requires_grad=True 
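The TransR scoring used in `forward` above can be sketched without PyTorch to make the arithmetic explicit. This is a minimal illustration, not the repository's implementation: `project_transr` and `transr_score` are hypothetical helper names, and the sketch assumes the projection matrix is stored row-major as a flat list of length `rel_dim * ent_dim`, which mirrors how `proj_embeddings` holds one flattened matrix per relation.

```python
def project_transr(entity, flat_proj):
    """Multiply a flattened (rel_dim x ent_dim) matrix by an entity vector."""
    ent_dim = len(entity)
    rel_dim = len(flat_proj) // ent_dim
    return [
        sum(flat_proj[row * ent_dim + col] * entity[col] for col in range(ent_dim))
        for row in range(rel_dim)
    ]


def transr_score(head, tail, rel, flat_proj, l1=True):
    """Distance between the projected head plus relation and the projected tail."""
    proj_h = project_transr(head, flat_proj)
    proj_t = project_transr(tail, flat_proj)
    diffs = [ph + r - pt for ph, r, pt in zip(proj_h, rel, proj_t)]
    if l1:
        return sum(abs(d) for d in diffs)
    return sum(d * d for d in diffs)


if __name__ == "__main__":
    identity = [1.0, 0.0, 0.0, 1.0]  # 2 x 2 identity, flattened row-major
    head, rel, tail = [1.0, 2.0], [0.5, -1.0], [1.5, 1.0]
    # With an identity projection TransR reduces to TransE: head + rel == tail.
    print(transr_score(head, tail, rel, identity))
```

A lower score means the triple is considered more plausible, which is why the evaluation helpers rank candidate entities by ascending score.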
-------------------------------------------------------------------------------- /jTransUP/models/transUP.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable as V 5 | 6 | from jTransUP.utils.misc import to_gpu, projection_transH_pytorch 7 | 8 | def build_model(FLAGS, user_total, item_total, entity_total, relation_total, i_map=None, e_map=None, new_map=None): 9 | model_cls = TransUPModel 10 | return model_cls( 11 | L1_flag = FLAGS.L1_flag, 12 | embedding_size = FLAGS.embedding_size, 13 | user_total = user_total, 14 | item_total = item_total, 15 | preference_total = FLAGS.num_preferences, 16 | use_st_gumbel = FLAGS.use_st_gumbel 17 | ) 18 | 19 | class TransUPModel(nn.Module): 20 | def __init__(self, 21 | L1_flag, 22 | embedding_size, 23 | user_total, 24 | item_total, 25 | preference_total, 26 | use_st_gumbel 27 | ): 28 | super(TransUPModel, self).__init__() 29 | self.L1_flag = L1_flag 30 | self.embedding_size = embedding_size 31 | self.user_total = user_total 32 | self.item_total = item_total 33 | self.preference_total = preference_total 34 | self.is_pretrained = False 35 | self.use_st_gumbel = use_st_gumbel 36 | 37 | user_weight = torch.FloatTensor(self.user_total, self.embedding_size) 38 | item_weight = torch.FloatTensor(self.item_total, self.embedding_size) 39 | pref_weight = torch.FloatTensor(self.preference_total, self.embedding_size) 40 | norm_weight = torch.FloatTensor(self.preference_total, self.embedding_size) 41 | nn.init.xavier_uniform(user_weight) 42 | nn.init.xavier_uniform(item_weight) 43 | nn.init.xavier_uniform(pref_weight) 44 | nn.init.xavier_uniform(norm_weight) 45 | # init user and item embeddings 46 | self.user_embeddings = nn.Embedding(self.user_total, self.embedding_size) 47 | self.item_embeddings = nn.Embedding(self.item_total, self.embedding_size) 48 | self.user_embeddings.weight = 
nn.Parameter(user_weight) 49 | self.item_embeddings.weight = nn.Parameter(item_weight) 50 | normalize_user_emb = F.normalize(self.user_embeddings.weight.data, p=2, dim=1) 51 | normalize_item_emb = F.normalize(self.item_embeddings.weight.data, p=2, dim=1) 52 | self.user_embeddings.weight.data = normalize_user_emb 53 | self.item_embeddings.weight.data = normalize_item_emb 54 | # init preference parameters 55 | self.pref_embeddings = nn.Embedding(self.preference_total, self.embedding_size) 56 | self.pref_norm_embeddings = nn.Embedding(self.preference_total, self.embedding_size) 57 | self.pref_embeddings.weight = nn.Parameter(pref_weight) 58 | self.pref_norm_embeddings.weight = nn.Parameter(norm_weight) 59 | normalize_pref_emb = F.normalize(self.pref_embeddings.weight.data, p=2, dim=1) 60 | normalize_norm_emb = F.normalize(self.pref_norm_embeddings.weight.data, p=2, dim=1) 61 | self.pref_embeddings.weight.data = normalize_pref_emb 62 | self.pref_norm_embeddings.weight.data = normalize_norm_emb 63 | 64 | self.user_embeddings = to_gpu(self.user_embeddings) 65 | self.item_embeddings = to_gpu(self.item_embeddings) 66 | self.pref_embeddings = to_gpu(self.pref_embeddings) 67 | self.pref_norm_embeddings = to_gpu(self.pref_norm_embeddings) 68 | 69 | def forward(self, u_ids, i_ids): 70 | u_e = self.user_embeddings(u_ids) 71 | i_e = self.item_embeddings(i_ids) 72 | 73 | _, r_e, norm = self.getPreferences(u_e, i_e, use_st_gumbel=self.use_st_gumbel) 74 | 75 | proj_u_e = projection_transH_pytorch(u_e, norm) 76 | proj_i_e = projection_transH_pytorch(i_e, norm) 77 | 78 | if self.L1_flag: 79 | score = torch.sum(torch.abs(proj_u_e + r_e - proj_i_e), 1) 80 | else: 81 | score = torch.sum((proj_u_e + r_e - proj_i_e) ** 2, 1) 82 | return score 83 | 84 | def evaluate(self, u_ids): 85 | batch_size = len(u_ids) 86 | u = self.user_embeddings(u_ids) 87 | # expand u and i to pair wise match, batch * item * dim 88 | u_e = u.expand(self.item_total, batch_size, self.embedding_size).permute(1, 0, 2) 
89 | i_e = self.item_embeddings.weight.expand(batch_size, self.item_total, self.embedding_size) 90 | 91 | # batch * item * dim 92 | _, r_e, norm = self.getPreferences(u_e, i_e, use_st_gumbel=self.use_st_gumbel) 93 | 94 | proj_u_e = projection_transH_pytorch(u_e, norm) 95 | proj_i_e = projection_transH_pytorch(i_e, norm) 96 | 97 | # batch * item 98 | if self.L1_flag: 99 | score = torch.sum(torch.abs(proj_u_e + r_e - proj_i_e), 2) 100 | else: 101 | score = torch.sum((proj_u_e + r_e - proj_i_e) ** 2, 2) 102 | return score 103 | 104 | # u_e, i_e : batch * dim or batch * item * dim 105 | def getPreferences(self, u_e, i_e, use_st_gumbel=False): 106 | # use item and user embedding to compute preference distribution 107 | # pre_probs: batch * rel, or batch * item * rel 108 | pre_probs = torch.matmul(u_e + i_e, torch.t(self.pref_embeddings.weight)) / 2 109 | if use_st_gumbel: 110 | pre_probs = self.st_gumbel_softmax(pre_probs) 111 | 112 | r_e = torch.matmul(pre_probs, self.pref_embeddings.weight) 113 | norm = torch.matmul(pre_probs, self.pref_norm_embeddings.weight) 114 | 115 | return pre_probs, r_e, norm 116 | 117 | # batch or batch * item 118 | def convert_to_one_hot(self, indices, num_classes): 119 | """ 120 | Args: 121 | indices (Variable): A vector containing indices, 122 | whose size is (batch_size,). 123 | num_classes (Variable): The number of classes, which would be 124 | the second dimension of the resulting one-hot matrix. 125 | Returns: 126 | result: The one-hot matrix of size (batch_size, num_classes). 
127 | """ 128 | 129 | old_shape = indices.shape 130 | new_shape = torch.Size([i for i in old_shape] + [num_classes]) 131 | indices = indices.unsqueeze(len(old_shape)) 132 | 133 | one_hot = V(indices.data.new(new_shape).zero_() 134 | .scatter_(len(old_shape), indices.data, 1)) 135 | return one_hot 136 | 137 | 138 | def masked_softmax(self, logits): 139 | eps = 1e-20 140 | probs = F.softmax(logits, dim=len(logits.shape)-1) 141 | return probs 142 | 143 | def st_gumbel_softmax(self, logits, temperature=1.0): 144 | """ 145 | Return the result of Straight-Through Gumbel-Softmax Estimation. 146 | It approximates the discrete sampling via Gumbel-Softmax trick 147 | and applies the biased ST estimator. 148 | In the forward propagation, it emits the discrete one-hot result, 149 | and in the backward propagation it approximates the categorical 150 | distribution via smooth Gumbel-Softmax distribution. 151 | Args: 152 | logits (Variable): A un-normalized probability values, 153 | which has the size (batch_size, num_classes) 154 | temperature (float): A temperature parameter. The higher 155 | the value is, the smoother the distribution is. 156 | Returns: 157 | y: The sampled output, which has the property explained above. 
158 | """ 159 | 160 | eps = 1e-20 161 | u = logits.data.new(*logits.size()).uniform_() 162 | gumbel_noise = V(-torch.log(-torch.log(u + eps) + eps)) 163 | y = logits + gumbel_noise 164 | y = self.masked_softmax(logits=y / temperature) 165 | y_argmax = y.max(len(y.shape)-1)[1] 166 | y_hard = self.convert_to_one_hot( 167 | indices=y_argmax, 168 | num_classes=y.size(len(y.shape)-1)).float() 169 | y = (y_hard - y).detach() + y 170 | return y 171 | 172 | def reportPreference(self, u_id, i_ids): 173 | item_num = len(i_ids) 174 | # item_num * dim 175 | u_e = self.user_embeddings(u_id.expand(item_num)) 176 | # item_num * dim 177 | i_e = self.item_embeddings(i_ids) 178 | # item_num * relation_total 179 | 180 | return self.getPreferences(u_e, i_e, use_st_gumbel=self.use_st_gumbel) 181 | 182 | def disable_grad(self): 183 | for name, param in self.named_parameters(): 184 | param.requires_grad=False 185 | 186 | def enable_grad(self): 187 | for name, param in self.named_parameters(): 188 | param.requires_grad=True -------------------------------------------------------------------------------- /jTransUP/utils/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoMiner/joint-kg-recommender/c165f12c039dc62c888f843351d178b8a94f3689/jTransUP/utils/.DS_Store -------------------------------------------------------------------------------- /jTransUP/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoMiner/joint-kg-recommender/c165f12c039dc62c888f843351d178b8a94f3689/jTransUP/utils/__init__.py -------------------------------------------------------------------------------- /jTransUP/utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/TaoMiner/joint-kg-recommender/c165f12c039dc62c888f843351d178b8a94f3689/jTransUP/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /jTransUP/utils/__pycache__/data.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoMiner/joint-kg-recommender/c165f12c039dc62c888f843351d178b8a94f3689/jTransUP/utils/__pycache__/data.cpython-36.pyc -------------------------------------------------------------------------------- /jTransUP/utils/__pycache__/kg_log_parser.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoMiner/joint-kg-recommender/c165f12c039dc62c888f843351d178b8a94f3689/jTransUP/utils/__pycache__/kg_log_parser.cpython-36.pyc -------------------------------------------------------------------------------- /jTransUP/utils/__pycache__/log_parser.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoMiner/joint-kg-recommender/c165f12c039dc62c888f843351d178b8a94f3689/jTransUP/utils/__pycache__/log_parser.cpython-36.pyc -------------------------------------------------------------------------------- /jTransUP/utils/__pycache__/loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoMiner/joint-kg-recommender/c165f12c039dc62c888f843351d178b8a94f3689/jTransUP/utils/__pycache__/loss.cpython-36.pyc -------------------------------------------------------------------------------- /jTransUP/utils/__pycache__/rec_log_parser.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/TaoMiner/joint-kg-recommender/c165f12c039dc62c888f843351d178b8a94f3689/jTransUP/utils/__pycache__/rec_log_parser.cpython-36.pyc -------------------------------------------------------------------------------- /jTransUP/utils/__pycache__/trainer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoMiner/joint-kg-recommender/c165f12c039dc62c888f843351d178b8a94f3689/jTransUP/utils/__pycache__/trainer.cpython-36.pyc -------------------------------------------------------------------------------- /jTransUP/utils/__pycache__/visuliazer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoMiner/joint-kg-recommender/c165f12c039dc62c888f843351d178b8a94f3689/jTransUP/utils/__pycache__/visuliazer.cpython-36.pyc -------------------------------------------------------------------------------- /jTransUP/utils/data.py: -------------------------------------------------------------------------------- 1 | import random 2 | from copy import deepcopy 3 | import numpy as np 4 | 5 | def getTrainRatingBatch(rating_batch, item_total, all_dicts=None): 6 | u_ids = [rating.u for rating in rating_batch] 7 | pi_ids = [rating.i for rating in rating_batch] 8 | # return u, pi, ni; each list contains batch-size ids 9 | u, pi, ni = getNegRatings(rating_batch.tolist(), item_total, all_dicts=all_dicts) 10 | return u, pi, ni 11 | 12 | def getTrainTripleBatch(triple_batch, entity_total, all_head_dicts=None, all_tail_dicts=None): 13 | negTripleList = [corrupt_head_filter(triple, entity_total, headDicts=all_head_dicts) if random.random() < 0.5 14 | else corrupt_tail_filter(triple, entity_total, tailDicts=all_tail_dicts) for triple in triple_batch] 15 | # return positive and negative (head, tail, relation) id lists, each of batch size 16 | ph, pt, pr = getTripleElements(triple_batch) 17 | nh, nt, nr = getTripleElements(negTripleList) 18 | 
return ph, pt, pr, nh, nt, nr 19 | 20 | # Randomly replace the head of a triple, 21 | # checking whether the result is a false negative (an existing triple). 22 | # If it is, resample. 23 | def corrupt_head_filter(triple, entityTotal, headDicts=None): 24 | while True: 25 | newHead = random.randrange(entityTotal) 26 | if newHead == triple[0] : continue 27 | if headDicts is not None: 28 | has_exist = False 29 | tr = (triple[1], triple[2]) 30 | for head_dict in headDicts: 31 | if tr in head_dict and newHead in head_dict[tr]: 32 | has_exist = True 33 | break 34 | if has_exist: continue 35 | else: break 36 | else: break 37 | return (newHead, triple[1], triple[2]) 38 | 39 | # Randomly replace the tail of a triple, 40 | # checking whether the result is a false negative (an existing triple). 41 | # If it is, resample. 42 | def corrupt_tail_filter(triple, entityTotal, tailDicts=None): 43 | while True: 44 | newTail = random.randrange(entityTotal) 45 | if newTail == triple[1] : continue 46 | if tailDicts is not None: 47 | has_exist = False 48 | hr = (triple[0], triple[2]) 49 | for tail_dict in tailDicts: 50 | if hr in tail_dict and newTail in tail_dict[hr]: 51 | has_exist = True 52 | break 53 | if has_exist: continue 54 | else: break 55 | else: break 56 | return (triple[0], newTail, triple[2]) 57 | 58 | def getTripleElements(tripleList): 59 | headList = [triple[0] for triple in tripleList] 60 | tailList = [triple[1] for triple in tripleList] 61 | relList = [triple[2] for triple in tripleList] 62 | return headList, tailList, relList 63 | 64 | def getNegRatings(ratingList, itemTotal, all_dicts=None): 65 | ni = [] 66 | neg_set = set() 67 | for rating in ratingList: 68 | c_u = rating[0] 69 | oldItem = rating[1] 70 | # items this user has already rated; never sample them as negatives 71 | fliter_items = set() 72 | if all_dicts is not None: 73 | # all_dicts: iterable of {user: rated items} dicts 74 | for dic in all_dicts: 75 | if c_u in dic: 76 | fliter_items.update(dic[c_u]) 77 | while True: 78 | newItem = random.randrange(itemTotal) 79 | if newItem != oldItem and newItem not in fliter_items 
and newItem not in neg_set : 80 | break 81 | ni.append(newItem) 82 | neg_set.add(newItem) 83 | u = [rating[0] for rating in ratingList] 84 | pi = [rating[1] for rating in ratingList] 85 | return u, pi, ni 86 | 87 | def MakeTrainIterator( 88 | train_data, 89 | batch_size, 90 | negtive_samples=1): 91 | train_list = np.array(train_data) 92 | 93 | def data_iter(): 94 | dataset_size = len(train_list) 95 | order = list(range(dataset_size)) * negtive_samples 96 | random.shuffle(order) 97 | start = -1 * batch_size 98 | 99 | while True: 100 | start += batch_size 101 | if start > dataset_size - batch_size: 102 | # Start another epoch. 103 | start = 0 104 | random.shuffle(order) 105 | batch_indices = order[start:start + batch_size] 106 | 107 | # numpy 108 | yield train_list[batch_indices].tolist() 109 | 110 | return data_iter() 111 | 112 | def MakeEvalIterator( 113 | eval_data, 114 | data_type, 115 | batch_size): 116 | # Make a list of minibatches from a dataset to use as an iterator. 117 | 118 | eval_list = np.asarray(eval_data, data_type) 119 | dataset_size = len(eval_list) 120 | order = list(range(dataset_size)) 121 | data_iter = [] 122 | start = -batch_size 123 | while True: 124 | start += batch_size 125 | 126 | if start >= dataset_size: 127 | break 128 | 129 | batch_indices = order[start:start + batch_size] 130 | candidate_batch = eval_list[batch_indices] 131 | data_iter.append(candidate_batch.tolist()) 132 | 133 | return data_iter -------------------------------------------------------------------------------- /jTransUP/utils/evaluation.py: -------------------------------------------------------------------------------- 1 | #-*- coding: UTF-8 -*- 2 | from __future__ import division 3 | import numpy as np 4 | import math 5 | import pandas as pd 6 | import time 7 | 8 | def get_performance(recommend_list, purchased_list): 9 | """Compute F1, precision, recall, hit indicator and NDCG for one user. 10 | Args: 11 | recommend_list: n * 1, the items returned by the recommendation algorithm. 12 | purchased_list: m * 1, the user's actual purchase records. 13 | Returns: 14 | f1, precision, recall, hit (1 if any recommended item was purchased, else 0), ndcg. 15 | """ 16 | k = 0 17 | 
hit_number = 0 18 | rank_list = [] 19 | for i_id in recommend_list: 20 | k += 1 21 | if i_id in purchased_list: 22 | hit_number += 1 23 | rank_list.append(1) 24 | else: 25 | rank_list.append(0) 26 | 27 | k_gold = len(purchased_list) 28 | f = 0.0 29 | p = 0.0 30 | r = 0.0 31 | ndcg = 0.0 32 | 33 | if hit_number > 0: 34 | p = float(hit_number) / k 35 | r = float(hit_number) / k_gold 36 | f = 2 * p * r / (p + r) 37 | ndcg = ndcg_at_k(rank_list, k) 38 | 39 | return f, p, r, 1 if hit_number > 0 else 0, ndcg 40 | 41 | def dcg_at_k(r, k, method=1): 42 | """Score is discounted cumulative gain (dcg) 43 | Relevance is positive real values. Can use binary 44 | as the previous methods. 45 | Example from 46 | http://www.stanford.edu/class/cs276/handouts/EvaluationNew-handout-6-per.pdf 47 | >>> r = [3, 2, 3, 0, 0, 1, 2, 2, 3, 0] 48 | >>> dcg_at_k(r, 1) 49 | 3.0 50 | >>> dcg_at_k(r, 1, method=1) 51 | 3.0 52 | >>> dcg_at_k(r, 2) 53 | 5.0 54 | >>> dcg_at_k(r, 2, method=1) 55 | 4.2618595071429155 56 | >>> dcg_at_k(r, 10) 57 | 9.6051177391888114 58 | >>> dcg_at_k(r, 11) 59 | 9.6051177391888114 60 | Args: 61 | r: Relevance scores (list or numpy) in rank order 62 | (first element is the first item) 63 | k: Number of results to consider 64 | method: If 0 then weights are [1.0, 1.0, 0.6309, 0.5, 0.4307, ...] 65 | If 1 then weights are [1.0, 0.6309, 0.5, 0.4307, ...] 66 | Returns: 67 | Discounted cumulative gain 68 | """ 69 | r = np.asfarray(r)[:k] 70 | if r.size: 71 | if method == 0: 72 | return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1))) 73 | elif method == 1: 74 | return np.sum(r / np.log2(np.arange(2, r.size + 2))) 75 | else: 76 | raise ValueError('method must be 0 or 1.') 77 | return 0. 78 | 79 | 80 | def ndcg_at_k(r, k, method=0): 81 | """Score is normalized discounted cumulative gain (ndcg) 82 | Relevance is positive real values. Can use binary 83 | as the previous methods. 
84 | Example from 85 | http://www.stanford.edu/class/cs276/handouts/EvaluationNew-handout-6-per.pdf 86 | >>> r = [3, 2, 3, 0, 0, 1, 2, 2, 3, 0] 87 | >>> ndcg_at_k(r, 1) 88 | 1.0 89 | >>> r = [2, 1, 2, 0] 90 | >>> ndcg_at_k(r, 4) 91 | 0.9203032077642922 92 | >>> ndcg_at_k(r, 4, method=1) 93 | 0.96519546960144276 94 | >>> ndcg_at_k([0], 1) 95 | 0.0 96 | >>> ndcg_at_k([1], 2) 97 | 1.0 98 | Args: 99 | r: Relevance scores (list or numpy) in rank order 100 | (first element is the first item) 101 | k: Number of results to consider 102 | method: If 0 then weights are [1.0, 1.0, 0.6309, 0.5, 0.4307, ...] 103 | If 1 then weights are [1.0, 0.6309, 0.5, 0.4307, ...] 104 | Returns: 105 | Normalized discounted cumulative gain 106 | """ 107 | dcg_max = dcg_at_k(sorted(r, reverse=True), k, method) 108 | if not dcg_max: 109 | return 0. 110 | return dcg_at_k(r, k, method) / dcg_max 111 | 112 | def evalAll(recommend_list, purchased_list): 113 | """Compute F1, precision, recall, hit ratio and NDCG, averaged over all users. 114 | Args: 115 | recommend_list: per-user lists of items returned by the recommendation algorithm. 116 | purchased_list: per-user lists of the users' actual purchase records. 117 | Returns: 118 | f1, precision, recall, hit ratio, ndcg. 119 | """ 120 | assert len(recommend_list) == len(purchased_list), "Eval user number not match!" 
121 | 122 | results = [] 123 | for list_pair in zip(recommend_list, purchased_list): 124 | f, p, r, hit_ratio, ndcg = get_performance(list_pair[0], list_pair[1]) 125 | results.append([f, p, r, hit_ratio, ndcg]) 126 | # f1, prec, rec, hit_ratio, ndcg 127 | performance = np.array(results).mean(axis=0) 128 | return performance[0], performance[1], performance[2], performance[3], performance[4] 129 | 130 | if __name__ == "__main__": 131 | a = np.random.randint(0, 10, size=(2, 3)) 132 | b = np.random.randint(0, 10, size=(2, 4)) 133 | print(a) 134 | print(b) 135 | f1, prec, rec, hit_ratio, ndcg = evalAll(a, b) 136 | print("{},{},{},{},{}".format(f1, prec, rec, hit_ratio, ndcg)) -------------------------------------------------------------------------------- /jTransUP/utils/evaluation_onehot.py: -------------------------------------------------------------------------------- 1 | import math 2 | import heapq 3 | import numpy as np 4 | 5 | 6 | def eval_model_pro(y_gnd, y_pre, K, row_len): 7 | mat_gnd = np.reshape(y_gnd, (-1, row_len)) 8 | mat_pre = np.reshape(y_pre, (-1, row_len)) 9 | 10 | hits, ndcgs= eval_model(mat_gnd, mat_pre, K) 11 | return hits, ndcgs 12 | 13 | 14 | def eval_model(y_gnd, y_pre, K): 15 | ndcgs, hits = [], [] 16 | y_gnd = y_gnd.tolist() 17 | y_pre = y_pre.tolist() 18 | 19 | for i, i_gnd in enumerate(y_gnd): 20 | i_pre = y_pre[i] 21 | hit, ndcg = eval_one_rating(i_gnd, i_pre, K) 22 | hits.append(hit) 23 | ndcgs.append(ndcg) 24 | 25 | hits = np.array(hits).mean() 26 | ndcgs = np.array(ndcgs).mean() 27 | 28 | return hits, ndcgs 29 | 30 | 31 | def eval_one_rating(i_gnd, i_pre, K): 32 | if sum(i_pre) == 0: 33 | return 0, 0 34 | map_score = {} 35 | for item, score in enumerate(i_pre): 36 | map_score[item] = score 37 | 38 | rank_list = heapq.nlargest(K, map_score, key=map_score.get) #return rank index 39 | target_item = i_gnd.index(1) 40 | 41 | hit = get_hit_ratio(rank_list, target_item) 42 | ndcg = get_ndcg(rank_list, target_item) 43 | 44 | return hit, ndcg 45 
| 46 | 47 | def get_hit_ratio(rank_list, target_item): 48 | for item in rank_list: 49 | if item == target_item: 50 | return 1 51 | return 0 52 | 53 | 54 | def get_ndcg(rank_list, target_item): 55 | for i, item in enumerate(rank_list): 56 | if item == target_item: 57 | return math.log(2) / math.log(i + 2) 58 | return 0 59 | -------------------------------------------------------------------------------- /jTransUP/utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.autograd as autograd 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from jTransUP.utils.misc import to_gpu 7 | 8 | class marginLoss(nn.Module): 9 | def __init__(self): 10 | super(marginLoss, self).__init__() 11 | 12 | def forward(self, pos, neg, margin): 13 | zero_tensor = to_gpu(torch.FloatTensor(pos.size())) 14 | zero_tensor.zero_() 15 | zero_tensor = autograd.Variable(zero_tensor) 16 | return torch.sum(torch.max(pos - neg + margin, zero_tensor)) 17 | 18 | def orthogonalLoss(rel_embeddings, norm_embeddings): 19 | return torch.sum(torch.sum(norm_embeddings * rel_embeddings, dim=1, keepdim=True) ** 2 / torch.sum(rel_embeddings ** 2, dim=1, keepdim=True)) 20 | 21 | def normLoss(embeddings, dim=1): 22 | norm = torch.sum(embeddings ** 2, dim=dim, keepdim=True) 23 | return torch.sum(torch.max(norm - to_gpu(autograd.Variable(torch.FloatTensor([1.0]))), to_gpu(autograd.Variable(torch.FloatTensor([0.0]))))) 24 | ''' 25 | def normLoss(embeddings, dim=1): 26 | norm = torch.sum(embeddings ** 2, dim=dim, keepdim=True) 27 | return torch.sum(torch.max(norm - 1.0, to_gpu(autograd.Variable(torch.FloatTensor([0.0]))))) 28 | ''' 29 | def bprLoss(pos, neg, target=1.0): 30 | loss = - F.logsigmoid(target * ( pos - neg )) 31 | return loss.mean() 32 | 33 | def pNormLoss(emb1, emb2, L1_flag=False): 34 | if L1_flag: 35 | distance = torch.sum(torch.abs(emb1 - emb2), 1) 36 | else: 37 | distance = torch.sum((emb1 - emb2) ** 2, 1) 38 
| return distance.mean() -------------------------------------------------------------------------------- /jTransUP/utils/misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import deque 3 | import numpy as np 4 | from jTransUP.utils.evaluation import ndcg_at_k 5 | import heapq 6 | import time 7 | from itertools import groupby 8 | import multiprocessing 9 | import math 10 | 11 | USE_CUDA = torch.cuda.is_available() 12 | 13 | def to_gpu(var): 14 | if USE_CUDA: 15 | return var.cuda() 16 | return var 17 | 18 | def projection_transH_pytorch(original, norm): 19 | return original - torch.sum(original * norm, dim=len(original.size())-1, keepdim=True) * norm 20 | 21 | def projection_transR_pytorch(original, proj_matrix): 22 | ent_embedding_size = original.shape[1] 23 | rel_embedding_size = proj_matrix.shape[1] // ent_embedding_size 24 | original = original.view(-1, ent_embedding_size, 1) 25 | proj_matrix = proj_matrix.view(-1, rel_embedding_size, ent_embedding_size) 26 | return torch.matmul(proj_matrix, original).view(-1, rel_embedding_size) 27 | 28 | # original: E*d2, proj: b*d1*d2 29 | def projection_transR_pytorch_batch(original, proj_matrix): 30 | ent_embedding_size = original.shape[1] 31 | rel_embedding_size = proj_matrix.shape[1] // ent_embedding_size 32 | proj_matrix = proj_matrix.view(-1, rel_embedding_size, ent_embedding_size) 33 | return torch.matmul(proj_matrix, original.transpose(0,1)).transpose(1,2) 34 | 35 | # batch * dim 36 | def projection_transD_pytorch_samesize(entity_embedding, entity_projection, relation_projection): 37 | return entity_embedding + torch.sum(entity_embedding * entity_projection, dim=len(entity_embedding.size())-1, keepdim=True) * relation_projection 38 | 39 | class Accumulator(object): 40 | """Accumulator. 
Makes it easy to keep a trailing list of statistics.""" 41 | 42 | def __init__(self, maxlen=None): 43 | self.maxlen = maxlen 44 | self.cache = dict() 45 | 46 | def add(self, key, val): 47 | self.cache.setdefault(key, deque(maxlen=self.maxlen)).append(val) 48 | 49 | def get(self, key, clear=True): 50 | ret = self.cache.get(key, []) 51 | if clear: 52 | try: 53 | del self.cache[key] 54 | except BaseException: 55 | pass 56 | return ret 57 | 58 | def get_avg(self, key, clear=True): 59 | return np.array(self.get(key, clear)).mean() 60 | 61 | class MyEvalKGProcess(multiprocessing.Process): 62 | def __init__(self, L, eval_dict, all_dicts=None, descending=True, topn=10, queue=None): 63 | super(MyEvalKGProcess, self).__init__() 64 | self.queue = queue 65 | self.L = L 66 | self.eval_dict = eval_dict 67 | self.all_dicts = all_dicts 68 | self.topn = topn 69 | self.descending = descending 70 | 71 | def run(self): 72 | while True: 73 | pred_scores = self.queue.get() 74 | try: 75 | self.process_data(pred_scores, self.eval_dict, all_dicts=self.all_dicts) 76 | except: 77 | time.sleep(5) 78 | self.process_data(pred_scores, self.eval_dict, all_dicts=self.all_dicts) 79 | self.queue.task_done() 80 | 81 | def process_data(self, pred_scores, eval_dict, all_dicts=None): 82 | for pred in pred_scores: 83 | if pred[0] not in eval_dict: continue 84 | gold = eval_dict[pred[0]] 85 | # ids to be filtered 86 | fliter_samples = None 87 | if all_dicts is not None: 88 | fliter_samples = set() 89 | for dic in all_dicts: 90 | if pred[0] in dic: 91 | fliter_samples.update(dic[pred[0]]) 92 | 93 | per_scores = pred[1] if not self.descending else -pred[1] 94 | hits, gold_ranks, gold_ids = getKGPerformance(per_scores, gold, fliter_samples=fliter_samples, topn=self.topn) 95 | self.L.extend( list(zip(hits, gold_ranks, [pred[0]]*len(hits), gold_ids)) ) 96 | 97 | # pred_scores: batch * item, [(id, numpy.array), ...], all_dicts:(train_dict, valid_dict, test_dict) 98 | def evalKGProcess(pred_scores, eval_dict, 
all_dicts=None, descending=True, num_processes=multiprocessing.cpu_count(), topn=10, queue_limit=10): 99 | offset = math.ceil(float(len(pred_scores)) / queue_limit) 100 | grouped_lists = [pred_scores[i:i+offset] for i in range(0,len(pred_scores),offset)] 101 | 102 | with multiprocessing.Manager() as manager: 103 | L = manager.list() 104 | queue = multiprocessing.JoinableQueue() 105 | workerList = [] 106 | for i in range(num_processes): 107 | worker = MyEvalKGProcess(L, eval_dict, all_dicts=all_dicts, descending=descending, topn=topn, queue=queue) 108 | workerList.append(worker) 109 | worker.daemon = True 110 | worker.start() 111 | 112 | for sub_list in grouped_lists: 113 | if len(sub_list) == 0 : continue 114 | queue.put(sub_list) 115 | queue.join() 116 | 117 | results = list(L) 118 | 119 | for worker in workerList: 120 | worker.terminate() 121 | 122 | return results 123 | 124 | # pred: numpy.array, gold,filter: set() 125 | def getKGPerformance(pred, gold, fliter_samples=None, topn=10): 126 | # index of pred is also ids, id's rank 127 | pred_ranks = np.argsort(pred) 128 | 129 | gold_ranks = [] 130 | hits = [] 131 | gold_ids = [] 132 | current_rank = 0 133 | topn_to_skip = 0 134 | for rank_id in pred_ranks: 135 | if fliter_samples is not None and rank_id in fliter_samples : 136 | if current_rank < topn : topn_to_skip += 1 137 | continue 138 | if rank_id in gold: 139 | gold_ranks.append(current_rank) 140 | gold_ids.append(rank_id) 141 | hits.append(1 if current_rank < topn else 0) 142 | if len(gold_ranks) == len(gold) : break 143 | else: 144 | current_rank += 1 145 | 146 | return hits, gold_ranks, gold_ids 147 | 148 | class MyEvalRecProcess(multiprocessing.Process): 149 | def __init__(self, L, eval_dict, all_dicts=None, descending=True, topn=10, queue=None): 150 | super(MyEvalRecProcess, self).__init__() 151 | self.queue = queue 152 | self.L = L 153 | self.eval_dict = eval_dict 154 | self.all_dicts = all_dicts 155 | self.topn = topn 156 | self.descending = descending 
157 | 158 | def run(self): 159 | while True: 160 | pred_scores = self.queue.get() 161 | try: 162 | self.process_data(pred_scores, self.eval_dict, all_dicts=self.all_dicts) 163 | except: 164 | time.sleep(5) 165 | self.process_data(pred_scores, self.eval_dict, all_dicts=self.all_dicts) 166 | self.queue.task_done() 167 | 168 | def process_data(self, pred_scores, eval_dict, all_dicts=None): 169 | for pred in pred_scores: 170 | if pred[0] not in eval_dict: continue 171 | gold = eval_dict[pred[0]] 172 | # ids to be filtered 173 | fliter_samples = None 174 | if all_dicts is not None: 175 | fliter_samples = set() 176 | for dic in all_dicts: 177 | if pred[0] in dic: 178 | fliter_samples.update(dic[pred[0]]) 179 | 180 | per_scores = pred[1] if not self.descending else -pred[1] 181 | f1, p, r, hit, ndcg, top_ids = getRecPerformance(per_scores, gold, fliter_samples=fliter_samples, topn=self.topn) 182 | 183 | self.L.append( [f1, p, r, hit, ndcg, (pred[0], top_ids, gold)] ) 184 | 185 | # pred_scores: batch * item, [(id, numpy.array), ...], all_dicts:(train_dict, valid_dict, test_dict) 186 | def evalRecProcess(pred_scores, eval_dict, all_dicts=None, descending=True, num_processes=multiprocessing.cpu_count(), topn=10, queue_limit=10): 187 | offset = math.ceil(float(len(pred_scores)) / queue_limit) 188 | grouped_lists = [pred_scores[i:i+offset] for i in range(0,len(pred_scores),offset)] 189 | 190 | with multiprocessing.Manager() as manager: 191 | L = manager.list() 192 | queue = multiprocessing.JoinableQueue() 193 | workerList = [] 194 | for i in range(num_processes): 195 | worker = MyEvalRecProcess(L, eval_dict, all_dicts=all_dicts, descending=descending, topn=topn, queue=queue) 196 | workerList.append(worker) 197 | worker.daemon = True 198 | worker.start() 199 | 200 | for sub_list in grouped_lists: 201 | if len(sub_list) == 0 : continue 202 | queue.put(sub_list) 203 | queue.join() 204 | 205 | results = list(L) 206 | 207 | for worker in workerList: 208 | worker.terminate() 209 | 
210 | return results 211 | 212 | # pred: numpy.array, gold,filter: set() 213 | def getRecPerformance(pred, gold, fliter_samples=None, topn=10): 214 | # index of pred is also ids 215 | pred_ranks = np.argsort(pred) 216 | 217 | hits = [] 218 | current_rank = 0 219 | topn_to_skip = 0 220 | top_ids = [] 221 | for rank_id in pred_ranks: 222 | if fliter_samples is not None and rank_id in fliter_samples : 223 | if current_rank < topn : topn_to_skip += 1 224 | continue 225 | 226 | hits.append(1 if rank_id in gold else 0) 227 | top_ids.append(rank_id) 228 | current_rank += 1 229 | if current_rank >= topn : break 230 | 231 | # hit number, how many preds in gold 232 | hits_count = sum(hits) 233 | 234 | k = len(hits) 235 | k_gold = len(gold) 236 | f1 = 0.0 237 | p = 0.0 238 | r = 0.0 239 | ndcg = 0.0 240 | hit = 1 if hits_count > 0 else 0 241 | 242 | if hits_count > 0: 243 | p = float(hits_count) / k 244 | r = float(hits_count) / k_gold 245 | f1 = 2 * p * r / (p + r) 246 | ndcg = ndcg_at_k(hits, k) 247 | 248 | return f1, p, r, hit, ndcg, top_ids 249 | 250 | def recursively_set_device(inp, gpu=USE_CUDA): 251 | if hasattr(inp, 'keys'): 252 | for k in list(inp.keys()): 253 | inp[k] = recursively_set_device(inp[k], USE_CUDA) 254 | elif isinstance(inp, list): 255 | return [recursively_set_device(ii, USE_CUDA) for ii in inp] 256 | elif isinstance(inp, tuple): 257 | return tuple(recursively_set_device(ii, USE_CUDA) for ii in inp) 258 | elif hasattr(inp, 'cpu'): 259 | if USE_CUDA: 260 | inp = inp.cuda() 261 | else: 262 | inp = inp.cpu() 263 | return inp -------------------------------------------------------------------------------- /jTransUP/utils/rec_log_parser.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | # two items refer to the same entity 4 | def loadR2KgMap(filename, item_vocab=None, kg_vocab=None): 5 | i2kg_map = {} 6 | kg2i_map = {} 7 | with open(filename, 'r', encoding='utf-8') as fin: 8 | for line in fin: 9 |
line_split = line.strip().split('\t') 10 | if len(line_split) != 3 : continue 11 | i_id = int(line_split[0]) 12 | kg_uri = line_split[2] 13 | if item_vocab is not None and kg_vocab is not None: 14 | if i_id not in item_vocab or kg_uri not in kg_vocab : continue 15 | i_id = item_vocab[i_id] 16 | kg_uri = kg_vocab[kg_uri] 17 | i2kg_map[i_id] = kg_uri 18 | kg2i_map[kg_uri] = i_id 19 | print("successful load {} item and {} entity pairs!".format(len(i2kg_map), len(kg2i_map))) 20 | return i2kg_map, kg2i_map 21 | 22 | def loadRecVocab(filename): 23 | with open(filename, 'r', encoding='utf-8') as fin: 24 | vocab = {} 25 | vocab_reverse = {} 26 | for line in fin: 27 | line_split = line.strip().split('\t') 28 | if len(line_split) != 2 : continue 29 | mapped_id = int(line_split[0]) 30 | org_id = int(line_split[1]) 31 | vocab[org_id] = mapped_id 32 | vocab_reverse[mapped_id] = org_id 33 | print("load {} vocab!".format(len(vocab))) 34 | 35 | return vocab, vocab_reverse 36 | 37 | def loadKGVocab(filename): 38 | with open(filename, 'r', encoding='utf-8') as fin: 39 | vocab = {} 40 | vocab_reverse = {} 41 | for line in fin: 42 | line_split = line.strip().split('\t') 43 | if len(line_split) != 2 : continue 44 | mapped_id = int(line_split[0]) 45 | org_id = line_split[1] 46 | vocab[org_id] = mapped_id 47 | vocab_reverse[mapped_id] = org_id 48 | print("load {} vocab!".format(len(vocab))) 49 | return vocab, vocab_reverse 50 | 51 | rel_p = re.compile(r'([\d]+)\(([\d]+)\)') 52 | def parseRecResults(log_filename, model_type): 53 | results = {} 54 | user_item_relations = {} 55 | correct_num = 0 56 | wrong_num = 0 57 | with open(log_filename, 'r', encoding='utf-8') as fin: 58 | for line in fin: 59 | line_split = line.strip().split('\t') 60 | if len(line_split) != 3: continue 61 | u_id = int(line_split[0].split('user:')[-1]) 62 | pred_ids = set([int(i) for i in re.split(r':|,', line_split[2])[1:]]) 63 | 64 | if model_type not in ['transup', 'jtransup']: 65 | gold_ids = set([int(i) for i in 
re.split(r':|,', line_split[1])[1:]]) 66 | else: 67 | tmp_gold_ids = [s for s in re.split(r':|,', line_split[1])[1:]] 68 | gold_ids = set() 69 | for s in tmp_gold_ids: 70 | m = rel_p.match(s) 71 | if m is not None: 72 | gold_id = int(m[1]) 73 | rel_id = int(m[2]) 74 | gold_ids.add(gold_id) 75 | user_item_relations[(u_id,gold_id)] = rel_id 76 | correct = pred_ids & gold_ids 77 | wrong = pred_ids - correct 78 | correct_num += len(correct) 79 | wrong_num += len(wrong) 80 | results[u_id] = (correct, wrong, gold_ids) 81 | print("parse {} users, average {} correct and {} wrong!".format(len(results), correct_num/len(results), wrong_num/len(results))) 82 | print("parse {} user item relations!".format(len(user_item_relations))) 83 | 84 | return results, user_item_relations 85 | 86 | def loadTriples(filename): 87 | with open(filename, 'r', encoding='utf-8') as fin: 88 | triple_total = 0 89 | triple_list = [] 90 | triple_head_dict = {} 91 | triple_tail_dict = {} 92 | for line in fin: 93 | line_split = line.strip().split('\t') 94 | if len(line_split) != 3 : continue 95 | h_id = int(line_split[0]) 96 | t_id = int(line_split[1]) 97 | r_id = int(line_split[2]) 98 | 99 | triple_list.append( (h_id, t_id, r_id) ) 100 | 101 | tmp_heads = triple_head_dict.get( (t_id, r_id), set()) 102 | tmp_heads.add(h_id) 103 | triple_head_dict[(t_id, r_id)] = tmp_heads 104 | 105 | tmp_tails = triple_tail_dict.get( (h_id, r_id), set()) 106 | tmp_tails.add(t_id) 107 | triple_tail_dict[(h_id, r_id)] = tmp_tails 108 | 109 | triple_total += 1 110 | 111 | return triple_total, triple_list, triple_head_dict, triple_tail_dict 112 | 113 | def compareLogs(log1, log2, model1, model2, i2kg_map, triple_head_dict, triple_tail_dict, output_file, u_map_reverse=None, i_map_reverse=None, e_map_reverse=None, r_map_reverse=None, users=None): 114 | results1, preference1 = parseRecResults(log1, model1) 115 | results2, preference2 = parseRecResults(log2, model2) 116 | with open(output_file, 'w', encoding='utf-8') as fout:
117 | for u_id in results1: 118 | if u_id not in results2 or (users is not None and u_map_reverse[u_id] not in users): continue 119 | correct1, wrong1, gold_ids1 = results1[u_id] 120 | correct2, wrong2, gold_ids2 = results2[u_id] 121 | target_items = correct2 - correct1 122 | # if len(target_items) == 0 : continue 123 | analysis(fout, u_id, target_items, preference2, gold_ids2, i2kg_map, triple_head_dict, triple_tail_dict, u_map_reverse=u_map_reverse, i_map_reverse=i_map_reverse, e_map_reverse=e_map_reverse, r_map_reverse=r_map_reverse) 124 | 125 | def analysis(fout, u_id, target_items, preference, gold_ids, i2kg_map, triple_head_dict, triple_tail_dict, u_map_reverse=None, i_map_reverse=None, e_map_reverse=None, r_map_reverse=None): 126 | out_reverse = False 127 | if u_map_reverse is not None and i_map_reverse is not None and e_map_reverse is not None and r_map_reverse is not None : 128 | out_reverse = True 129 | if out_reverse: 130 | out_str = "user:{}\ttarget:{}".format(u_map_reverse[u_id], ",".join([e_map_reverse[i2kg_map[ti]] for ti in target_items if ti in i2kg_map])) 131 | else: 132 | out_str = "user:{}\ttarget:{}".format(u_id, ",".join([str(ti) for ti in target_items])) 133 | 134 | target_rel_ids = [preference.get((u_id, ti), -1) for ti in target_items] 135 | tmp_gold_ids = gold_ids 136 | 137 | gold_rel_ids = [preference.get((u_id, gi), -1) for gi in tmp_gold_ids] 138 | u_prefs_target = set(target_rel_ids) - set([-1]) 139 | u_prefs_gold = set(gold_rel_ids) - set([-1]) 140 | 141 | share_prefs = u_prefs_target & u_prefs_gold 142 | if len(share_prefs) == 0: return None 143 | org_pref = [r_map_reverse[r] for r in share_prefs] 144 | if 'http://dbpedia.org/ontology/starring' not in org_pref or 'http://dbpedia.org/ontology/director' not in org_pref: return None 145 | remap_tis = [i2kg_map.get(ti, -1) for ti in target_items] 146 | remap_gids = [i2kg_map.get(gold_id, -1) for gold_id in tmp_gold_ids] 147 | for tr in zip(remap_tis, target_rel_ids): 148 | if tr[1] not 
in share_prefs : continue 149 | h_ids = triple_head_dict.get(tr, set()) 150 | t_ids = triple_tail_dict.get(tr, set()) 151 | if len(h_ids) == 0 and len(t_ids) == 0 : continue 152 | for gid in remap_gids: 153 | if gid == tr[0] : continue 154 | gh_ids = triple_head_dict.get((gid, tr[1]), set()) 155 | gt_ids = triple_tail_dict.get((gid, tr[1]), set()) 156 | if len(gh_ids) == 0 and len(gt_ids) == 0 : continue 157 | h_gh = h_ids & gh_ids 158 | h_gt = h_ids & gt_ids 159 | t_gh = t_ids & gh_ids 160 | t_gt = t_ids & gt_ids 161 | if len(h_gh) > 0: 162 | if out_reverse: 163 | tmp_str = "{},{},{},[{}]".format(e_map_reverse[tr[0]], e_map_reverse[gid], r_map_reverse[tr[1]], ",".join([str(e_map_reverse[i]) for i in h_gh])) 164 | else: 165 | tmp_str = "{},{},{},[{}]".format(tr[0], gid, tr[1], ",".join([str(i) for i in h_gh])) 166 | out_str += "\nhh:{}".format(tmp_str) 167 | if len(h_gt) > 0: 168 | if out_reverse: 169 | tmp_str = "{},{},{},[{}]".format(e_map_reverse[tr[0]], e_map_reverse[gid], r_map_reverse[tr[1]], ",".join([str(e_map_reverse[i]) for i in h_gt])) 170 | else: 171 | tmp_str = "{},{},{},[{}]".format(tr[0], gid, tr[1], ",".join([str(i) for i in h_gt])) 172 | out_str += "\nht:{}".format(tmp_str) 173 | if len(t_gh) > 0: 174 | if out_reverse: 175 | tmp_str = "{},{},{},[{}]".format(e_map_reverse[tr[0]], e_map_reverse[gid], r_map_reverse[tr[1]], ",".join([str(e_map_reverse[i]) for i in t_gh])) 176 | else: 177 | tmp_str = "{},{},{},[{}]".format(tr[0], gid, tr[1], ",".join([str(i) for i in t_gh])) 178 | out_str += "\nth:{}".format(tmp_str) 179 | if len(t_gt) > 0: 180 | if out_reverse: 181 | tmp_str = "{},{},{},[{}]".format(e_map_reverse[tr[0]], e_map_reverse[gid], r_map_reverse[tr[1]], ",".join([str(e_map_reverse[i]) for i in t_gt])) 182 | else: 183 | tmp_str = "{},{},{},[{}]".format(tr[0], gid, tr[1], ",".join([str(i) for i in t_gt])) 184 | out_str += "\ntt:{}".format(tmp_str) 185 | fout.write("{}\n".format(out_str)) 186 | 187 | 188 | def output(fout, u_id, target_items, 
preference, gold_ids, i2kg_map, triple_head_dict, triple_tail_dict, u_map_reverse=None, i_map_reverse=None, e_map_reverse=None, r_map_reverse=None): 189 | out_reverse = False 190 | if u_map_reverse is not None and i_map_reverse is not None and e_map_reverse is not None and r_map_reverse is not None : 191 | out_reverse = True 192 | if out_reverse: 193 | out_str = "user:{}\ttarget:{}".format(u_map_reverse[u_id], ",".join([e_map_reverse[i2kg_map[ti]] for ti in target_items if ti in i2kg_map])) 194 | else: 195 | out_str = "user:{}\ttarget:{}".format(u_id, ",".join([str(ti) for ti in target_items])) 196 | rel_ids = [preference.get((u_id, ti), -1) for ti in target_items] 197 | remap_tis = [i2kg_map.get(ti, -1) for ti in target_items] 198 | tmp_gold_ids = gold_ids - target_items 199 | remap_gids = [i2kg_map.get(gold_id, -1) for gold_id in tmp_gold_ids] 200 | 201 | for tr in zip(remap_tis, rel_ids): 202 | h_ids = triple_head_dict.get(tr, set()) 203 | t_ids = triple_tail_dict.get(tr, set()) 204 | if len(h_ids) == 0 and len(t_ids) == 0 : continue 205 | for gid in remap_gids: 206 | gh_ids = triple_head_dict.get((gid, tr[1]), set()) 207 | gt_ids = triple_tail_dict.get((gid, tr[1]), set()) 208 | if len(gh_ids) == 0 and len(gt_ids) == 0 : continue 209 | h_gh = h_ids & gh_ids 210 | h_gt = h_ids & gt_ids 211 | t_gh = t_ids & gh_ids 212 | t_gt = t_ids & gt_ids 213 | if len(h_gh) > 0: 214 | if out_reverse: 215 | tmp_str = "{},{},{},[{}]".format(e_map_reverse[tr[0]], e_map_reverse[gid], r_map_reverse[tr[1]], ",".join([str(e_map_reverse[i]) for i in h_gh])) 216 | else: 217 | tmp_str = "{},{},{},[{}]".format(tr[0], gid, tr[1], ",".join([str(i) for i in h_gh])) 218 | out_str += "\nhh:{}".format(tmp_str) 219 | if len(h_gt) > 0: 220 | if out_reverse: 221 | tmp_str = "{},{},{},[{}]".format(e_map_reverse[tr[0]], e_map_reverse[gid], r_map_reverse[tr[1]], ",".join([str(e_map_reverse[i]) for i in h_gt])) 222 | else: 223 | tmp_str = "{},{},{},[{}]".format(tr[0], gid, tr[1], ",".join([str(i) for 
i in h_gt])) 224 | out_str += "\nht:{}".format(tmp_str) 225 | if len(t_gh) > 0: 226 | if out_reverse: 227 | tmp_str = "{},{},{},[{}]".format(e_map_reverse[tr[0]], e_map_reverse[gid], r_map_reverse[tr[1]], ",".join([str(e_map_reverse[i]) for i in t_gh])) 228 | else: 229 | tmp_str = "{},{},{},[{}]".format(tr[0], gid, tr[1], ",".join([str(i) for i in t_gh])) 230 | out_str += "\nth:{}".format(tmp_str) 231 | if len(t_gt) > 0: 232 | if out_reverse: 233 | tmp_str = "{},{},{},[{}]".format(e_map_reverse[tr[0]], e_map_reverse[gid], r_map_reverse[tr[1]], ",".join([str(e_map_reverse[i]) for i in t_gt])) 234 | else: 235 | tmp_str = "{},{},{},[{}]".format(tr[0], gid, tr[1], ",".join([str(i) for i in t_gt])) 236 | out_str += "\ntt:{}".format(tmp_str) 237 | fout.write("{}\n".format(out_str)) 238 | 239 | 240 | model1 = 'bprmf' 241 | model2 = 'jtransup' 242 | root_path = '/Users/caoyixin/Github/joint-kg-recommender' 243 | dataset_path = root_path + '/datasets/ml1m' 244 | 245 | log1 = root_path + '/log/log/tuned_ml1m/ml1m-bprmf-analysis.log' 246 | log2 = root_path + '/log/log/tuned_ml1m/ml1m-cjtransup-nogumbel_analysis_old.log' 247 | 248 | users = None 249 | 250 | u_map_file = dataset_path + '/u_map.dat' 251 | i_map_file = dataset_path + '/i_map.dat' 252 | e_map_file = dataset_path + '/kg/e_map.dat' 253 | r_map_file = dataset_path + '/kg/r_map.dat' 254 | i2kg_map_file = dataset_path + '/i2kg_map.tsv' 255 | train_triple_file = dataset_path + '/kg/train.dat' 256 | output_file = root_path + '/log/parse_bprmf_jtransup.log' 257 | 258 | user_vocab, user_vocab_reverse = loadRecVocab(u_map_file) 259 | item_vocab, item_vocab_reverse = loadRecVocab(i_map_file) 260 | kg_vocab, kg_vocab_reverse = loadKGVocab(e_map_file) 261 | rel_vocab, rel_vocab_reverse = loadKGVocab(r_map_file) 262 | i2kg_map, kg2i_map = loadR2KgMap(i2kg_map_file, item_vocab=item_vocab, kg_vocab=kg_vocab) 263 | 264 | triple_total, triple_list, triple_head_dict, triple_tail_dict = loadTriples(train_triple_file) 265 | 266 | 
compareLogs(log1, log2, model1, model2, i2kg_map, triple_head_dict, triple_tail_dict, output_file, u_map_reverse=user_vocab_reverse, i_map_reverse=item_vocab_reverse, e_map_reverse=kg_vocab_reverse, r_map_reverse=rel_vocab_reverse, users=users) 267 | 268 | -------------------------------------------------------------------------------- /jTransUP/utils/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | 4 | import os 5 | from jTransUP.utils.misc import to_gpu, recursively_set_device, USE_CUDA 6 | 7 | def get_checkpoint_path(FLAGS, suffix=".ckpt"): 8 | # Set checkpoint path. 9 | if FLAGS.ckpt_path.endswith(".ckpt"): 10 | checkpoint_path = FLAGS.ckpt_path 11 | else: 12 | checkpoint_path = os.path.join(FLAGS.ckpt_path, FLAGS.experiment_name + suffix) 13 | return checkpoint_path 14 | 15 | def get_model_target(model_type): 16 | target = 1 if model_type in ["bprmf", "cofm", "fm"] else -1 17 | return target 18 | 19 | check_rho = 1.0 20 | class ModelTrainer(object): 21 | def __init__(self, model, logger, epoch_length, FLAGS): 22 | self.model = model 23 | self.logger = logger 24 | self.epoch_length = epoch_length 25 | self.model_target = get_model_target(FLAGS.model_type) 26 | 27 | self.logger.info('One epoch is ' + str(self.epoch_length) + ' steps.') 28 | 29 | self.parameters = [param for name, param in model.named_parameters()] 30 | self.optimizer_type = FLAGS.optimizer_type 31 | 32 | self.l2_lambda = FLAGS.l2_lambda 33 | self.learning_rate_decay_when_no_progress = FLAGS.learning_rate_decay_when_no_progress 34 | self.momentum = FLAGS.momentum 35 | self.eval_interval_steps = FLAGS.eval_interval_steps 36 | 37 | self.step = 0 38 | self.best_step = 0 39 | 40 | # record best dev, test acc 41 | self.best_dev_performance = 0.0 42 | self.best_performances = None 43 | 44 | # GPU support. 
45 | to_gpu(model) 46 | 47 | self.optimizer_reset(FLAGS.learning_rate) 48 | 49 | self.checkpoint_path = get_checkpoint_path(FLAGS) 50 | 51 | # Load checkpoint if available. 52 | if FLAGS.eval_only_mode and os.path.isfile(FLAGS.load_experiment_name): 53 | self.logger.info("Found checkpoint, restoring.") 54 | self.load(FLAGS.load_experiment_name, cpu=not USE_CUDA) 55 | self.logger.info( 56 | "Resuming at step: {} with best dev performance: {} and test performance : {}.".format( 57 | self.best_step, self.best_dev_performance, self.best_performances)) 58 | 59 | def reset(self): 60 | self.step = 0 61 | self.best_step = 0 62 | 63 | def optimizer_reset(self, learning_rate): 64 | self.learning_rate = learning_rate 65 | 66 | if self.optimizer_type == "Adam": 67 | self.optimizer = optim.Adam(self.parameters, lr=learning_rate, 68 | weight_decay=self.l2_lambda) 69 | elif self.optimizer_type == "SGD": 70 | self.optimizer = optim.SGD(self.parameters, lr=learning_rate, 71 | weight_decay=self.l2_lambda, momentum=self.momentum) 72 | elif self.optimizer_type == "Adagrad": 73 | self.optimizer = optim.Adagrad(self.parameters, lr=learning_rate, 74 | weight_decay=self.l2_lambda) 75 | elif self.optimizer_type == "Rmsprop": 76 | self.optimizer = optim.RMSprop(self.parameters, lr=learning_rate, 77 | weight_decay=self.l2_lambda, momentum=self.momentum) 78 | 79 | def optimizer_step(self): 80 | self.optimizer.step() 81 | self.step += 1 82 | 83 | def optimizer_zero_grad(self): 84 | self.optimizer.zero_grad() 85 | 86 | def new_performance(self, dev_performance, performances): 87 | is_best = False 88 | # Track best dev error 89 | performance_to_care = dev_performance[0] 90 | if performance_to_care > check_rho * self.best_dev_performance: 91 | self.best_step = self.step 92 | self.logger.info( "Checkpointing ..." 
) 93 | self.save(self.checkpoint_path) 94 | self.best_performances = performances 95 | self.best_dev_performance = performance_to_care 96 | is_best = True 97 | # Learning rate decay 98 | if self.learning_rate_decay_when_no_progress != 1.0: 99 | last_epoch_start = self.step - (self.step % self.epoch_length) 100 | if self.step - last_epoch_start <= self.eval_interval_steps and self.best_step < (last_epoch_start - self.epoch_length): 101 | self.logger.info('No improvement after one epoch. Lowering learning rate.') 102 | self.optimizer_reset(self.learning_rate * self.learning_rate_decay_when_no_progress) 103 | return is_best 104 | 105 | def checkpoint(self): 106 | self.logger.info("Checkpointing.") 107 | self.save(self.checkpoint_path) 108 | 109 | def save(self, filename): 110 | if USE_CUDA: 111 | recursively_set_device(self.model.state_dict(), gpu=-1) 112 | recursively_set_device(self.optimizer.state_dict(), gpu=-1) 113 | 114 | # Always sends Tensors to CPU. 115 | save_dict = { 116 | 'step': self.step, 117 | 'best_step': self.best_step, 118 | 'best_dev_performance': self.best_dev_performance, 119 | 'model_state_dict': self.model.state_dict(), 120 | 'optimizer_state_dict': self.optimizer.state_dict() 121 | } 122 | torch.save(save_dict, filename) 123 | 124 | if USE_CUDA: 125 | recursively_set_device(self.model.state_dict(), gpu=USE_CUDA) 126 | recursively_set_device(self.optimizer.state_dict(), gpu=USE_CUDA) 127 | 128 | def load(self, filename, cpu=False): 129 | if cpu: 130 | # Load GPU-based checkpoints on CPU 131 | checkpoint = torch.load( 132 | filename, map_location=lambda storage, loc: storage) 133 | else: 134 | checkpoint = torch.load(filename) 135 | model_state_dict = checkpoint['model_state_dict'] 136 | 137 | self.model.load_state_dict(model_state_dict, strict=False) 138 | self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 139 | 140 | self.step = checkpoint['step'] 141 | self.best_step = checkpoint['best_step'] 142 | self.best_dev_performance = 
checkpoint['best_dev_performance'] 143 | 144 | def loadEmbedding(self, filename, embedding_names, cpu=False, e_remap=None, i_remap=None): 145 | assert os.path.isfile(filename), "Checkpoint file not found!" 146 | self.logger.info("Found checkpoint, restoring pre-trained embeddings.") 147 | 148 | if cpu: 149 | # Load GPU-based checkpoints on CPU 150 | checkpoint = torch.load( 151 | filename, map_location=lambda storage, loc: storage) 152 | else: 153 | checkpoint = torch.load(filename) 154 | old_model_state_dict = checkpoint['model_state_dict'] 155 | 156 | model_dict = self.model.state_dict() 157 | 158 | # 1. filter out unnecessary keys 159 | pretrained_dict = {k: v for k, v in old_model_state_dict.items() if k in embedding_names} 160 | # 2. overwrite entries in the existing state dict 161 | model_dict.update(pretrained_dict) 162 | 163 | # for cke load 164 | if 'ent_embeddings.weight' in old_model_state_dict and 'ent_embeddings.weight' in model_dict and len(old_model_state_dict['ent_embeddings.weight'])+1 == len(self.model.ent_embeddings.weight.data): 165 | 166 | loaded_embeddings = old_model_state_dict['ent_embeddings.weight'] 167 | del (model_dict['ent_embeddings.weight']) 168 | self.model.ent_embeddings.weight.data[:len(loaded_embeddings), :] = loaded_embeddings[:, :] 169 | self.logger.info('Restored ' + str(len(loaded_embeddings)) + ' entities from checkpoint.') 170 | 171 | # for cfkg load 172 | if 'rel_embeddings.weight' in old_model_state_dict and 'rel_embeddings.weight' in model_dict and len(old_model_state_dict['rel_embeddings.weight'])+1 == len(self.model.rel_embeddings.weight.data): 173 | 174 | loaded_embeddings = old_model_state_dict['rel_embeddings.weight'] 175 | del (model_dict['rel_embeddings.weight']) 176 | self.model.rel_embeddings.weight.data[:len(loaded_embeddings), :] = loaded_embeddings[:, :] 177 | self.logger.info('Restored ' + str(len(loaded_embeddings)) + ' relations from checkpoint.') 178 | 179 | # restore entities 180 | if e_remap is not None 
and 'ent_embeddings.weight' in model_dict and 'ent_embeddings.weight' in embedding_names: 181 | loaded_embeddings = model_dict['ent_embeddings.weight'] 182 | del (model_dict['ent_embeddings.weight']) 183 | 184 | count = 0 185 | for index in e_remap: 186 | mapped_index = e_remap[index] 187 | self.model.ent_embeddings.weight.data[mapped_index, :] = loaded_embeddings[index, :] 188 | count += 1 189 | self.logger.info('Restored ' + str(count) + ' entities from checkpoint.') 190 | 191 | # restore entities 192 | if i_remap is not None and 'item_embeddings.weight' in model_dict and 'item_embeddings.weight' in embedding_names: 193 | loaded_embeddings = model_dict['item_embeddings.weight'] 194 | del (model_dict['item_embeddings.weight']) 195 | 196 | count = 0 197 | for index in i_remap: 198 | mapped_index = i_remap[index] 199 | self.model.item_embeddings.weight.data[mapped_index, :] = loaded_embeddings[index, :] 200 | count += 1 201 | self.logger.info('Restored ' + str(count) + ' items from checkpoint.') 202 | # for cofm 203 | if 'item_bias.weight' in model_dict and 'item_bias.weight' in pretrained_dict: 204 | loaded_embeddings = model_dict['item_bias.weight'] 205 | del (model_dict['item_bias.weight']) 206 | 207 | count = 0 208 | for index in i_remap: 209 | mapped_index = i_remap[index] 210 | self.model.item_bias.weight.data[mapped_index] = loaded_embeddings[index] 211 | count += 1 212 | self.logger.info('Restored ' + str(count) + ' items bias from checkpoint.') 213 | 214 | # 3. 
load the new state dict 215 | self.model.load_state_dict(model_dict, strict=False) 216 | 217 | self.logger.info("Load Embeddings of {} from {}.".format(", ".join(list(pretrained_dict.keys())), filename)) 218 | -------------------------------------------------------------------------------- /jTransUP/utils/visuliazer.py: -------------------------------------------------------------------------------- 1 | import visdom 2 | import time 3 | import numpy as np 4 | 5 | class Visualizer(object): 6 | def __init__(self, env='default', **kwargs): 7 | self.vis = visdom.Visdom(env=env, **kwargs) 8 | self.index = {} 9 | 10 | def log(self, output_str, win_name="Log"): 11 | x = self.index.get(win_name, 0) 12 | self.vis.text(output_str, win=win_name, append=False if x == 0 else True) 13 | self.index[win_name] = x + 1 14 | 15 | def plot_many_stack(self, points, win_name="", options={}): 16 | ''' 17 | self.plot('loss',1.00) 18 | ''' 19 | name=list(points.keys()) 20 | if len(win_name) < 1: 21 | win_name = " ".join(name) 22 | 23 | options['legend'] = name 24 | options['title'] = win_name 25 | 26 | x = self.index.get(win_name, 0) 27 | val=list(points.values()) 28 | if len(val)==1: 29 | y=np.array(val) 30 | else: 31 | y=np.array(val).reshape(-1,len(val)) 32 | 33 | self.vis.line(Y=y,X=np.ones(y.shape)*x, 34 | win=win_name, 35 | opts=options, 36 | update=None if x == 0 else 'append' 37 | ) 38 | self.index[win_name] = x + 1 39 | -------------------------------------------------------------------------------- /ktup.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python run_knowledgable_recommendation.py -data_path ~/joint-kg-recommender/datasets/ -log_path ~/joint-kg-recommender/log/ -rec_test_files valid.dat:test.dat -kg_test_files valid.dat:test.dat -l2_lambda 0 -model_type cjtransup -has_visualization -dataset dbbook2014 -batch_size 400 -embedding_size 100 -learning_rate 0.001 -topn 10 -seed 3 -eval_interval_steps 19520 
-training_steps 1952000 -early_stopping_steps_to_wait 97600 -optimizer_type Adam -joint_ratio 0.7 -noshare_embeddings -L1_flag -norm_lambda 1 -kg_lambda 1 -nouse_st_gumbel -load_ckpt_file tuned_dbbook2014/dbbook2014-transup-1540700102.ckpt:tuned_dbbook2014/dbbook2014-transh-1540701096.ckpt -visualization_port 8098 2 | 3 | -------------------------------------------------------------------------------- /ktup_eval.sh: -------------------------------------------------------------------------------- 1 | rec_test_files=$1 2 | 3 | for rec_test_files in test0.dat test1.dat test2.dat test3.dat test4.dat 4 | do 5 | CUDA_VISIBLE_DEVICES=0 python run_knowledgable_recommendation.py -data_path ~/joint-kg-recommender/datasets/ -log_path ~/joint-kg-recommender/log/ -rec_test_files $rec_test_files -l2_lambda 0 -model_type cjtransup -has_visualization -dataset ml1m -batch_size 400 -embedding_size 100 -learning_rate 0.001 -topn 10 -seed 3 -eval_interval_steps 19520 -training_steps 1952000 -early_stopping_steps_to_wait 97600 -optimizer_type Adam -joint_ratio 0.7 -noshare_embeddings -L1_flag -norm_lambda 1 -kg_lambda 1 -nouse_st_gumbel -visualization_port 8098 -eval_only_mode -is_report -load_ckpt_file 6 | done 7 | 8 | kg_test_files=$2 9 | for kg_test_files in one2one.dat one2N.dat N2one.dat N2N.dat 10 | do 11 | CUDA_VISIBLE_DEVICES=0 python run_knowledgable_recommendation.py -data_path ~/joint-kg-recommender/datasets/ -log_path ~/joint-kg-recommender/log/ -kg_test_files test0.dat -kg_test_files $kg_test_files -l2_lambda 0 -model_type cjtransup -has_visualization -dataset dbbook2014 -batch_size 400 -embedding_size 100 -learning_rate 0.001 -topn 10 -seed 3 -eval_interval_steps 19520 -training_steps 1952000 -early_stopping_steps_to_wait 97600 -optimizer_type Adam -joint_ratio 0.7 -noshare_embeddings -L1_flag -norm_lambda 1 -kg_lambda 1 -nouse_st_gumbel -visualization_port 8098 -eval_only_mode -is_report -load_ckpt_file 12 | done
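Note on the second loop of ktup_eval.sh: the file names one2one.dat, one2N.dat, N2one.dat and N2N.dat suggest that the KG test triples are split by relation mapping category. The repository dump does not show how these files were produced, so the following is only a minimal sketch of the conventional way to compute such a split. It assumes the (head, tail, relation) column order read by loadTriples and the 1.5 threshold commonly used in the TransE evaluation protocol; the function name categorize_relations is hypothetical and not part of this repository.

```python
from collections import defaultdict


def categorize_relations(triples, threshold=1.5):
    """Label each relation id as one2one, one2N, N2one or N2N.

    triples: iterable of (head_id, tail_id, relation_id), the same
    order that loadTriples reads. threshold is an assumed cut-off.
    """
    tails_of = defaultdict(set)  # (relation, head) -> distinct tails
    heads_of = defaultdict(set)  # (relation, tail) -> distinct heads
    for h, t, r in triples:
        tails_of[(r, h)].add(t)
        heads_of[(r, t)].add(h)

    tail_counts = defaultdict(list)  # relation -> tails-per-head counts
    head_counts = defaultdict(list)  # relation -> heads-per-tail counts
    for (r, _), tails in tails_of.items():
        tail_counts[r].append(len(tails))
    for (r, _), heads in heads_of.items():
        head_counts[r].append(len(heads))

    categories = {}
    for r in tail_counts:
        # Average number of tails per head and heads per tail.
        tph = sum(tail_counts[r]) / len(tail_counts[r])
        hpt = sum(head_counts[r]) / len(head_counts[r])
        many_tails = tph >= threshold
        many_heads = hpt >= threshold
        if many_heads and many_tails:
            categories[r] = "N2N"
        elif many_tails:
            categories[r] = "one2N"
        elif many_heads:
            categories[r] = "N2one"
        else:
            categories[r] = "one2one"
    return categories


if __name__ == "__main__":
    demo = [(0, 1, 0), (2, 3, 0), (0, 1, 1), (0, 2, 1)]
    print(categorize_relations(demo))
```

Test triples could then be grouped by the category of their relation and written to one file per category before running the loop.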
--------------------------------------------------------------------------------
/log/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TaoMiner/joint-kg-recommender/c165f12c039dc62c888f843351d178b8a94f3689/log/.DS_Store

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
conda install python-gflags
pip install visdom
conda install -c conda-forge tqdm
--------------------------------------------------------------------------------
/run_item_recommendation.py:
--------------------------------------------------------------------------------
import sys
import gflags
from jTransUP.models import item_recommendation
from jTransUP.models.base import get_flags, flag_defaults

FLAGS = gflags.FLAGS

if __name__ == '__main__':
    get_flags()
    # Parse command line flags.
    FLAGS(sys.argv)
    flag_defaults(FLAGS)
    item_recommendation.run(only_forward=FLAGS.eval_only_mode)
--------------------------------------------------------------------------------
/run_knowledgable_recommendation.py:
--------------------------------------------------------------------------------
import sys
import gflags
from jTransUP.models import knowledgable_recommendation
from jTransUP.models.base import get_flags, flag_defaults

FLAGS = gflags.FLAGS

if __name__ == '__main__':
    get_flags()
    # Parse command line flags.
    FLAGS(sys.argv)
    flag_defaults(FLAGS)
    knowledgable_recommendation.run(only_forward=FLAGS.eval_only_mode)
--------------------------------------------------------------------------------
/run_knowledge_representation.py:
--------------------------------------------------------------------------------
import sys
import gflags
from jTransUP.models import knowledge_representation
from jTransUP.models.base import get_flags, flag_defaults

FLAGS = gflags.FLAGS

if __name__ == '__main__':
    get_flags()
    # Parse command line flags.
    FLAGS(sys.argv)
    flag_defaults(FLAGS)
    knowledge_representation.run(only_forward=FLAGS.eval_only_mode)
--------------------------------------------------------------------------------
/run_preprocess.py:
--------------------------------------------------------------------------------
from jTransUP.data.preprocessRatings import preprocess as preprocessRating
from jTransUP.data.preprocessTriples import preprocess as preprocessKG
import os
import logging

data_path = "/Users/caoyixin/Github/joint-kg-recommender/datasets/"
dataset = 'dbbook2014'

dataset_path = os.path.join(data_path, dataset)
kg_path = os.path.join(dataset_path, 'kg')

rating_file = os.path.join(dataset_path, 'ratings.csv')
triple_file = os.path.join(kg_path, "kg_hop0.dat")
relation_file = os.path.join(kg_path, "relation_filter.dat")
i2kg_file = os.path.join(dataset_path, "i2kg_map.tsv")

log_path = dataset_path

logger = logging.getLogger()
logger.setLevel(level=logging.DEBUG)

log_file = os.path.join(dataset_path, "data_preprocess.log")
# Formatter
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# FileHandler
file_handler = logging.FileHandler(log_file)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)

# StreamHandler
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)

preprocessRating(rating_file, dataset_path, low_frequence=5, logger=logger)

preprocessKG([triple_file], kg_path, entity_file=i2kg_file, relation_file=relation_file, logger=logger)

--------------------------------------------------------------------------------
/run_test.py:
--------------------------------------------------------------------------------
from jTransUP.data.load_rating_data import load_data
from jTransUP.utils.data import MakeTrainIterator, MakeEvalIterator
import os

trainDict, testDict, validDict, allRatingDict, user_total, item_total, trainTotal, testTotal, validTotal = load_data("/Users/caoyixin/Github/joint-kg-recommender/datasets/ml1m/")
print("user:{}, item:{}!".format(user_total, item_total))
print("totally ratings for {} train, {} valid, and {} test!".format(trainTotal, validTotal, testTotal))
u_id = 249
print("u:{} has brought items for train {}, valid {} and test {}!".format(u_id, trainDict[u_id], validDict[u_id] if validDict is not None else [], testDict[u_id]))

eval_iter = MakeEvalIterator(validDict, item_total, 100, allRatingDict=allRatingDict)
eval_total = 0
item_count = 0
while True:
    rating_batch = next(eval_iter)
    if rating_batch is None: break
    u, pi = rating_batch
    # Count how many evaluation entries belong to the inspected user.
    for i in u:
        if i == u_id: item_count += 1
    eval_total += len(u)
print(eval_total)
print("user {} has {} eval!".format(u_id, item_count))

--------------------------------------------------------------------------------
/swipe.sh:
--------------------------------------------------------------------------------
model_type=$1

# NOTE: the loop variable overwrites the value read from $1.
for model_type in bprmf fm transup
do
    CUDA_VISIBLE_DEVICES=1 python run_item_recommendation.py \
        -data_path ~/joint-kg-recommender/datasets/ -log_path ~/joint-kg-recommender/log/ \
        -rec_test_files valid.dat:test.dat -num_preferences 10 -l2_lambda 1e-5 -negtive_samples 1 \
        -model_type $model_type -has_visualization -dataset ml1m -batch_size 512 -embedding_size 100 \
        -learning_rate 0.005 -topn 10 -seed 3 \
        -eval_interval_steps 14000 -training_steps 1400000 -early_stopping_steps_to_wait 70000
done

CUDA_VISIBLE_DEVICES=1 python run_knowledge_representation.py \
    -data_path ~/joint-kg-recommender/datasets/ -log_path ~/joint-kg-recommender/log/ \
    -kg_test_files valid.dat:test.dat -l2_lambda 1e-5 -negtive_samples 1 \
    -model_type transe -has_visualization -dataset ml1m -batch_size 512 -embedding_size 100 \
    -learning_rate 0.005 -topn 10 -seed 3 \
    -eval_interval_steps 1250 -training_steps 125000 -early_stopping_steps_to_wait 6250

CUDA_VISIBLE_DEVICES=1 python run_knowledge_representation.py \
    -data_path ~/joint-kg-recommender/datasets/ -log_path ~/joint-kg-recommender/log/ \
    -kg_test_files valid.dat:test.dat -l2_lambda 1e-5 -negtive_samples 1 \
    -model_type transh -has_visualization -dataset ml1m -batch_size 512 -embedding_size 100 \
    -learning_rate 0.005 -topn 10 -seed 3 \
    -eval_interval_steps 1250 -training_steps 125000 -early_stopping_steps_to_wait 6250

CUDA_VISIBLE_DEVICES=1 python run_knowledgable_recommendation.py \
    -data_path ~/joint-kg-recommender/datasets/ -log_path ~/joint-kg-recommender/log/ \
    -rec_test_files valid.dat:test.dat -kg_test_files valid.dat:test.dat \
    -l2_lambda 1e-5 -negtive_samples 1 -model_type jtransup -has_visualization \
    -dataset ml1m -batch_size 512 -embedding_size 100 -learning_rate 0.005 -topn 10 -seed 3 \
    -share_embeddings -joint_ratio 0.9 \
    -eval_interval_steps 16000 -training_steps 1600000 -early_stopping_steps_to_wait 90000

CUDA_VISIBLE_DEVICES=1 python run_knowledgable_recommendation.py \
    -data_path ~/joint-kg-recommender/datasets/ -log_path ~/joint-kg-recommender/log/ \
    -rec_test_files valid.dat:test.dat -kg_test_files valid.dat:test.dat \
    -l2_lambda 1e-5 -negtive_samples 1 -model_type jtransup -has_visualization \
    -dataset ml1m -batch_size 512 -embedding_size 100 -learning_rate 0.005 -topn 10 -seed 3 \
    -noshare_embeddings -joint_ratio 0.9 \
    -eval_interval_steps 16000 -training_steps 1600000 -early_stopping_steps_to_wait 90000
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
from jTransUP.utils.visuliazer import Visualizer
import random
import visdom


vis = Visualizer()
for i in range(100):
    x = random.random()
    vis.plot_many_stack({'x': x}, win_name="train_loss", options={'height':10,'width':10})
    y = random.random()
    z = random.random()
    vis.plot_many_stack({'y': y, 'z':z}, win_name="valid_loss", options={'height':300,'width':400})
'''

vis = visdom.Visdom()

trace = dict(x=[1, 2, 3], y=[4, 5, 6], mode="markers+lines", type='custom',
             marker={'color': 'red', 'symbol': 104, 'size': "10"},
             text=["one", "two", "three"], name='1st Trace')
layout = dict(title="First Plot", xaxis={'title': 'x1'}, yaxis={'title': 'x2'})

vis._send({'data': [trace], 'layout': layout, 'win': 'mywin'})
'''
--------------------------------------------------------------------------------
/transe.sh:
--------------------------------------------------------------------------------
CUDA_VISIBLE_DEVICES=0 python run_knowledge_representation.py \
    -data_path ~/joint-kg-recommender/datasets/ -log_path ~/joint-kg-recommender/log/ \
    -kg_test_files valid.dat:test.dat -l2_lambda 0 -model_type transe -has_visualization \
    -dataset dbbook2014 -batch_size 256 -embedding_size 100 -learning_rate 0.001 -topn 10 -seed 3 \
    -eval_interval_steps 9150 -training_steps 915000 -early_stopping_steps_to_wait 45750 \
    -optimizer_type Adam -L1_flag -norm_lambda 1 -kg_lambda 1 -visualization_port 8098

--------------------------------------------------------------------------------
/transh.sh:
--------------------------------------------------------------------------------
CUDA_VISIBLE_DEVICES=0 python run_knowledge_representation.py \
    -data_path ~/joint-kg-recommender/datasets/ -log_path ~/joint-kg-recommender/log/ \
    -kg_test_files valid.dat:test.dat -l2_lambda 0 -model_type transh -has_visualization \
    -dataset dbbook2014 -batch_size 256 -embedding_size 100 -learning_rate 0.001 -topn 10 -seed 3 \
    -eval_interval_steps 9150 -training_steps 915000 -early_stopping_steps_to_wait 45750 \
    -optimizer_type Adam -L1_flag -norm_lambda 1 -kg_lambda 1 -visualization_port 8098 \
    -load_ckpt_file ~/joint-kg-recommender/log/tuned_dbbook2014/dbbook2014-transe-1540685958.ckpt

--------------------------------------------------------------------------------
/transr.sh:
--------------------------------------------------------------------------------
CUDA_VISIBLE_DEVICES=0 python run_knowledge_representation.py \
    -data_path ~/joint-kg-recommender/datasets/ -log_path ~/joint-kg-recommender/log/ \
    -kg_test_files valid.dat:test.dat -l2_lambda 0 -model_type transr -has_visualization \
    -dataset dbbook2014 -batch_size 256 -embedding_size 100 -learning_rate 0.001 -topn 10 -seed 3 \
    -eval_interval_steps 9150 -training_steps 915000 -early_stopping_steps_to_wait 45750 \
    -optimizer_type Adam -L1_flag -norm_lambda 1 -kg_lambda 1 -visualization_port 8098 \
    -load_ckpt_file ~/joint-kg-recommender/log/tuned_dbbook2014/dbbook2014-transe-1540685958.ckpt

--------------------------------------------------------------------------------
/transup.sh:
--------------------------------------------------------------------------------
CUDA_VISIBLE_DEVICES=1 python run_item_recommendation.py \
    -data_path ~/joint-kg-recommender/datasets/ -log_path ~/joint-kg-recommender/log/ \
    -rec_test_files valid.dat:test.dat -l2_lambda 1e-5 -negtive_samples 1 \
    -model_type transup -has_visualization -dataset dbbook2014 \
    -batch_size 1024 -embedding_size 100 -learning_rate 0.005 -topn 10 -seed 3 \
    -eval_interval_steps 500 \
    -training_steps 50000 -early_stopping_steps_to_wait 2500 -optimizer_type Adagrad \
    -L1_flag -num_preferences 13 -nouse_st_gumbel -visualization_port 8097 \
    -load_ckpt_file /tuned_dbbook2014/dbbook2014-bprmf-1540692224.ckpt

--------------------------------------------------------------------------------
/tt.py:
--------------------------------------------------------------------------------


def loadRatings(filename):
    with open(filename, 'r', encoding='utf-8') as fin:
        user_dict = {}
        total_count = 0
        for line in fin:
            line_split = line.strip().split('\t')
            if len(line_split) != 3 : continue
            u = int(line_split[0])
            i = int(line_split[1])
            rating = int(line_split[2])

            i_set = user_dict.get(u, set())
            i_set.add( (i, rating) )
            user_dict[u] = i_set
            total_count += 1
    return total_count, user_dict

'''
train_filename = '/Users/caoyixin/Github/joint-kg-recommender/datasets/ml1m/train.dat'
train_total_count, train_user_dict = loadRatings(train_filename)
for i in range(0, 10):
    filename = '/Users/caoyixin/Github/joint-kg-recommender/datasets/ml1m/test{}.dat'.format(i)

    test_total_count, test_user_dict = loadRatings(filename)
    count = 0
    for u in test_user_dict:
        count += len(train_user_dict[u])

    print(count/len(test_user_dict))
'''
train_filename = '/Users/caoyixin/Github/joint-kg-recommender/datasets/ml1m/train.dat'
train_total_count, train_user_dict = loadRatings(train_filename)
valid_filename = '/Users/caoyixin/Github/joint-kg-recommender/datasets/ml1m/valid.dat'
valid_total_count, valid_user_dict = loadRatings(valid_filename)
test_filename = '/Users/caoyixin/Github/joint-kg-recommender/datasets/ml1m/test.dat'
test_total_count, test_user_dict = loadRatings(test_filename)

count = 0

for d in [train_user_dict, valid_user_dict, ]
--------------------------------------------------------------------------------
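
The listing of tt.py stops in the middle of its final loop, so what that loop computed is not known. As a minimal sketch, assuming it was meant to report simple per-split statistics, the code here re-implements the rating loader and adds a hypothetical `summarize_splits` helper. The file written in `demo` is made up purely to exercise the functions and is not data from the repository.

```python
import os
import tempfile


def load_ratings(filename):
    """Read tab-separated `user, item, rating` lines into {user: {(item, rating)}}."""
    user_dict = {}
    total_count = 0
    with open(filename, "r", encoding="utf-8") as fin:
        for line in fin:
            fields = line.strip().split("\t")
            if len(fields) != 3:
                continue  # skip malformed lines, as tt.py does
            u, i, rating = (int(x) for x in fields)
            user_dict.setdefault(u, set()).add((i, rating))
            total_count += 1
    return total_count, user_dict


def summarize_splits(splits):
    """Hypothetical helper: (number of users, number of ratings) per split.

    `splits` maps a split name to the user dict returned by load_ratings.
    """
    return {
        name: (len(d), sum(len(items) for items in d.values()))
        for name, d in splits.items()
    }


def demo():
    # Tiny made-up file, only to exercise the functions.
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "train.dat")
        with open(path, "w", encoding="utf-8") as f:
            f.write("1\t10\t5\n1\t11\t4\n2\t10\t3\nbad line\n")
        total, users = load_ratings(path)
    return total, users, summarize_splits({"train": users})
```

With the real data, the dictionaries returned for train.dat, valid.dat and test.dat could be passed to `summarize_splits` together to compare the splits side by side.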