├── .gitignore ├── README.md ├── config ├── FB15k-237.yaml ├── wn18rr.yaml └── wn18rr_predictorplus.yaml ├── data ├── FB15k-237 │ ├── README.txt │ ├── entities.dict │ ├── relations.dict │ ├── rnnlogic_rules.txt │ ├── test.txt │ ├── train.txt │ └── valid.txt └── wn18rr │ ├── entities.dict │ ├── relations.dict │ ├── rnnlogic_rules.txt │ ├── test.txt │ ├── train.txt │ └── valid.txt ├── miner ├── main.cpp ├── pyrnnlogic.cpp ├── rnnlogic.cpp ├── rnnlogic.h └── setup.py ├── requirements.txt └── src ├── comm.py ├── data.py ├── embedding.py ├── generators.py ├── layers.py ├── predictors.py ├── run_predictorplus.py ├── run_rnnlogic.py ├── trainer.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | output 2 | __pycache__ 3 | *.zip 4 | checkpoint 5 | *.pth 6 | *.pyd 7 | *.lib 8 | *.obj 9 | *.exp 10 | *build/ 11 | *.egg 12 | *.egg-info/ 13 | drugdiscovery 14 | scratch 15 | dataset 16 | .idea -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RNNLogic+ 2 | 3 | In this folder, we provide the refactored code of RNNLogic+, an improved version of RNNLogic introduced in Section 3.4 of the paper. 4 | 5 | The idea of RNNLogic+ is to first learn useful logic rules by running RNNLogic (without embeddings), and then use these rules to train a more powerful reasoning predictor. In this way, even without using knowledge graph embeddings, RNNLogic+ achieves results close to those of RNNLogic with embeddings. 6 | 7 | To run RNNLogic+, follow the steps below. 8 | 9 | ## Step 1: Mine logic rules 10 | 11 | In the first step, we mine some low-quality logic rules, which are used to pre-train the rule generator in RNNLogic+ to speed up training.
12 | 13 | To do that, go to the folder `miner`, and compile the code by running the following command: 14 | 15 | `g++ -O3 rnnlogic.h rnnlogic.cpp main.cpp -o rnnlogic -lpthread` 16 | 17 | Afterwards, run the following command to mine logic rules: 18 | 19 | `./rnnlogic -data-path ../data/FB15k-237 -max-length 3 -threads 40 -lr 0.01 -wd 0.0005 -temp 100 -iterations 1 -top-n 0 -top-k 0 -top-n-out 0 -output-file mined_rules.txt` 20 | 21 | The code runs on CPUs, so it is best to use a server with many CPUs and to increase the number of threads via the option `-threads`. The program will output a file called `mined_rules.txt`, which you can then move to your dataset folder. 22 | 23 | **In `data/FB15k-237` and `data/wn18rr`, we have provided these mined rules, so you can skip this step.** 24 | 25 | ## Step 2: Run RNNLogic+ 26 | 27 | Next, we are ready to run RNNLogic+. To do so, first edit the configuration files in the folder `config`, and then go to the folder `src`. 28 | 29 | If you would like to use single-GPU training, please edit lines 39 and 60 of the configuration files, and then run: 30 | 31 | `python run_rnnlogic.py --config ../config/FB15k-237.yaml` 32 | 33 | `python run_rnnlogic.py --config ../config/wn18rr.yaml` 34 | 35 | If you would like to use multi-GPU training, please run: 36 | 37 | `python -m torch.distributed.launch --nproc_per_node=4 run_rnnlogic.py --config ../config/FB15k-237.yaml` 38 | 39 | `python -m torch.distributed.launch --nproc_per_node=4 run_rnnlogic.py --config ../config/wn18rr.yaml` 40 | 41 | ## Results and Discussion 42 | 43 | Using the default configuration files, we are able to achieve the following results without using knowledge graph embeddings: 44 | 45 | **FB15k-237:** 46 | 47 | ``` 48 | Hit1 : 0.242949 49 | Hit3 : 0.358812 50 | Hit10: 0.494145 51 | MR : 384.201315 52 | MRR : 0.327182 53 | ``` 54 | 55 | **WN18RR:** 56 | 57 | ``` 58 | Hit1 : 0.439614 59 | Hit3 : 0.483718 60 | Hit10: 0.537939 61 | MR : 6377.744942 62 | MRR : 0.471933 63 | ``` 64 | 65 | **Discussion:** 66 | 67 | Note that for efficiency considerations, the default configurations are quite conservative, and it is easy to 
further improve the results. 68 | 69 | For example: 70 | 71 | - The current configuration files only consider logic rules of length at most 3. You might consider longer logic rules for better reasoning results. 72 | - The current configuration files set the number of training iterations to 5. You might increase the value for better results. 73 | - The current configuration files set the hidden dimension of the reasoning predictor to 16. You might also increase the value for better results. 74 | -------------------------------------------------------------------------------- /config/FB15k-237.yaml: -------------------------------------------------------------------------------- 1 | save_path: FB15k-237 2 | load_path: null 3 | seed: 1 4 | 5 | data: 6 | data_path: ../data/FB15k-237 7 | rule_file: ../data/FB15k-237/mined_rules.txt 8 | batch_size: 32 9 | 10 | EM: 11 | num_iters: 5 12 | prior_weight: 0.001 13 | num_rules: 100 14 | max_length: 3 15 | 16 | generator: 17 | gpu: 0 18 | model: 19 | embedding_dim: 512 20 | hidden_dim: 256 21 | num_layers: 1 22 | pre_train: 23 | num_epoch: 10000 24 | lr: 0.001 25 | print_every: 1000 26 | batch_size: 512 27 | train: 28 | num_epoch: 100 29 | lr: 0.00001 30 | print_every: 1000 31 | batch_size: 512 32 | post_train: 33 | num_epoch: 1000 34 | lr: 0.00001 35 | print_every: 1000 36 | batch_size: 512 37 | 38 | predictor: 39 | gpus: [0, 1, 2, 3] 40 | model: 41 | entity_feature: bias 42 | optimizer: 43 | lr: 0.001 44 | weight_decay: 0 45 | train: 46 | smoothing: 0.2 47 | batch_per_epoch: 1000000 48 | print_every: 1000 49 | eval: 50 | expectation: True 51 | H_score: 52 | print_every: 1000 53 | 54 | final_prediction: 55 | num_iters: 5 56 | num_rules: [100] 57 | max_length: [3] 58 | 59 | predictorplus: 60 | gpus: [0, 1, 2, 3] 61 | model: 62 | hidden_dim: 16 63 | optimizer: 64 | lr: 0.005 65 | weight_decay: 0 66 | train: 67 | smoothing: 0.2 68 | batch_per_epoch: 1000000 69 | print_every: 1000 70 | eval: 71 | expectation: True 72 | 
-------------------------------------------------------------------------------- /config/wn18rr.yaml: -------------------------------------------------------------------------------- 1 | save_path: wn18rr 2 | load_path: null 3 | seed: 1 4 | 5 | data: 6 | data_path: ../data/wn18rr 7 | rule_file: ../data/wn18rr/mined_rules.txt 8 | batch_size: 32 9 | 10 | EM: 11 | num_iters: 5 12 | prior_weight: 0.001 13 | num_rules: 100 14 | max_length: 3 15 | 16 | generator: 17 | gpu: 0 18 | model: 19 | embedding_dim: 512 20 | hidden_dim: 256 21 | num_layers: 1 22 | pre_train: 23 | num_epoch: 10000 24 | lr: 0.001 25 | print_every: 1000 26 | batch_size: 512 27 | train: 28 | num_epoch: 100 29 | lr: 0.00001 30 | print_every: 1000 31 | batch_size: 512 32 | post_train: 33 | num_epoch: 1000 34 | lr: 0.00001 35 | print_every: 1000 36 | batch_size: 512 37 | 38 | predictor: 39 | gpus: [0, 1, 2, 3] 40 | model: 41 | entity_feature: bias 42 | optimizer: 43 | lr: 0.001 44 | weight_decay: 0 45 | train: 46 | smoothing: 0.2 47 | batch_per_epoch: 1000000 48 | print_every: 1000 49 | eval: 50 | expectation: True 51 | H_score: 52 | print_every: 1000 53 | 54 | final_prediction: 55 | num_iters: 5 56 | num_rules: [100, 100, 500, 200, 100] 57 | max_length: [1, 2, 3, 4, 5] 58 | 59 | predictorplus: 60 | gpus: [0, 1, 2, 3] 61 | model: 62 | hidden_dim: 16 63 | aggregator: pna 64 | optimizer: 65 | lr: 0.005 66 | weight_decay: 0 67 | train: 68 | smoothing: 0.2 69 | batch_per_epoch: 1000000 70 | print_every: 1000 71 | eval: 72 | expectation: True 73 | -------------------------------------------------------------------------------- /config/wn18rr_predictorplus.yaml: -------------------------------------------------------------------------------- 1 | gpus: [0, 1, 2, 3] 2 | save_path: saved_wn18rr 3 | load_path: null 4 | seed: 1 5 | num_iters: 10 6 | 7 | data: 8 | data_path: ../data/wn18rr 9 | rule_file: ../data/wn18rr/rnnlogic_rules.txt 10 | batch_size: 32 11 | 12 | predictor: 13 | model: 14 | type: emb 15 | 
num_layers: 3 16 | hidden_dim: 16 17 | entity_feature: bias 18 | aggregator: pna 19 | embedding_path: null 20 | optimizer: 21 | lr: 0.005 22 | weight_decay: 0 23 | train: 24 | smoothing: 0.2 25 | batch_per_epoch: 1000000 26 | print_every: 1000 27 | eval: 28 | expectation: True 29 | 30 | -------------------------------------------------------------------------------- /data/FB15k-237/README.txt: -------------------------------------------------------------------------------- 1 | FB15K-237 Knowledge Base Completion Dataset 2 | 3 | This dataset contains knowledge base relation triples and textual mentions of Freebase entity pairs, as used in the work published in [1] and [2]. 4 | The knowledge base triples are a subset of the FB15K set [3], originally derived from Freebase. The textual mentions are derived from 200 million sentences from the ClueWeb12 [5] corpus coupled with Freebase entity mention annotations [4]. 5 | 6 | 7 | FILE FORMAT DETAILS 8 | 9 | The files train.txt, valid.txt, and test.txt contain the training, development, and test set knowledge base triples used in both [1] and [2]. 10 | The file text_cvsc.txt contains the textual triples used in [2] and the file text_emnlp.txt contains the textual triples used in [1]. 11 | 12 | The knowledge base triples contain lines like this: 13 | 14 | /m/0grwj /people/person/profession /m/05sxg2 15 | 16 | The format is: 17 | 18 | mid1 relation mid2 19 | 20 | The separator is a tab character; the mids are Freebase ids of entities, and the relation is a single or a two-hop relation from Freebase, where an intermediate complex value type entity has been collapsed out. 21 | 22 | The textual mentions files have lines like this: 23 | 24 | /m/02qkt [XXX]:<-nn>:fact:<-pobj>:in:<-prep>:game:<-nsubj>:'s::pivot::[YYY] /m/05sb1 3 25 | 26 | This indicates the mids of two Freebase entities, together with a fully lexicalized dependency path between the entities.
The last element in the tuple is the number of occurrences of the specified entity pair with the given dependency path in sentences from ClueWeb12. 27 | The dependency paths are specified as sequences of words (like the word "fact" above) and labeled dependency links (like <-nn> above). The direction of traversal of a dependency arc is indicated by whether there is a - sign in front of the arc label, e.g. <-nsubj> vs. <nsubj>. 28 | 29 | 30 | REFERENCES 31 | 32 | [1] Kristina Toutanova, Danqi Chen, Patrick Pantel, Hoifung Poon, Pallavi Choudhury, and Michael Gamon. Representing text for joint embedding of text and knowledge bases. In Proceedings of EMNLP 2015. 33 | [2] Kristina Toutanova and Danqi Chen. Observed versus latent features for knowledge base and text inference. In Proceedings of the 3rd Workshop on Continuous Vector Space Models and Their Compositionality 2015. 34 | [3] Antoine Bordes, Nicolas Usunier, Alberto Garcia-Duran, Jason Weston, and Oksana Yakhnenko. Translating embeddings for modeling multi-relational data. In Advances in Neural Information Processing Systems (NIPS) 2013. 35 | [4] Evgeniy Gabrilovich, Michael Ringgaard, and Amarnag Subramanya. FACC1: Freebase annotation of ClueWeb corpora, Version 1 (release date 2013-06-26, format version 1, correction level 0). http://lemurproject.org/clueweb12/FACC1/ 36 | [5] http://lemurproject.org/clueweb12/ 37 | 38 | 39 | CONTACT 40 | 41 | Please contact Kristina Toutanova kristout@microsoft.com if you have questions about the dataset.
42 | -------------------------------------------------------------------------------- /data/FB15k-237/relations.dict: -------------------------------------------------------------------------------- 1 | 0 /organization/organization/headquarters./location/mailing_address/state_province_region 2 | 1 /education/educational_institution/colors 3 | 2 /people/person/profession 4 | 3 /film/film/costume_design_by 5 | 4 /film/film/genre 6 | 5 /celebrities/celebrity/celebrity_friends./celebrities/friendship/friend 7 | 6 /tv/tv_producer/programs_produced./tv/tv_producer_term/producer_type 8 | 7 /film/film/executive_produced_by 9 | 8 /sports/sports_team/roster./basketball/basketball_roster_position/position 10 | 9 /award/award_nominee/award_nominations./award/award_nomination/nominated_for 11 | 10 /award/award_category/winners./award/award_honor/award_winner 12 | 11 /award/award_winner/awards_won./award/award_honor/award_winner 13 | 12 /music/artist/origin 14 | 13 /food/food/nutrients./food/nutrition_fact/nutrient 15 | 14 /film/film/distributors./film/film_film_distributor_relationship/region 16 | 15 /time/event/instance_of_recurring_event 17 | 16 /sports/professional_sports_team/draft_picks./sports/sports_league_draft_pick/school 18 | 17 /film/film/language 19 | 18 /location/statistical_region/places_exported_to./location/imports_and_exports/exported_to 20 | 19 /music/group_member/membership./music/group_membership/group 21 | 20 /tv/tv_network/programs./tv/tv_network_duration/program 22 | 21 /award/award_winning_work/awards_won./award/award_honor/award_winner 23 | 22 /people/person/places_lived./people/place_lived/location 24 | 23 /travel/travel_destination/climate./travel/travel_destination_monthly_climate/month 25 | 24 /broadcast/content/artist 26 | 25 /base/americancomedy/celebrity_impressionist/celebrities_impersonated 27 | 26 /base/popstra/celebrity/breakup./base/popstra/breakup/participant 28 | 27 /organization/organization/place_founded 29 | 28 
/people/person/employment_history./business/employment_tenure/company 30 | 29 /location/statistical_region/gdp_nominal_per_capita./measurement_unit/dated_money_value/currency 31 | 30 /people/person/place_of_birth 32 | 31 /location/location/contains 33 | 32 /base/popstra/celebrity/dated./base/popstra/dated/participant 34 | 33 /user/ktrueman/default_domain/international_organization/member_states 35 | 34 /government/legislative_session/members./government/government_position_held/legislative_sessions 36 | 35 /film/film/estimated_budget./measurement_unit/dated_money_value/currency 37 | 36 /organization/non_profit_organization/registered_with./organization/non_profit_registration/registering_agency 38 | 37 /organization/organization/headquarters./location/mailing_address/country 39 | 38 /base/biblioness/bibs_location/country 40 | 39 /education/educational_institution/students_graduates./education/education/student 41 | 40 /music/group_member/membership./music/group_membership/role 42 | 41 /location/administrative_division/country 43 | 42 /award/ranked_item/appears_in_ranked_lists./award/ranking/list 44 | 43 /base/eating/practicer_of_diet/diet 45 | 44 /film/special_film_performance_type/film_performance_type./film/performance/film 46 | 45 /award/award_nominated_work/award_nominations./award/award_nomination/nominated_for 47 | 46 /film/director/film 48 | 47 /base/x2010fifaworldcupsouthafrica/world_cup_squad/current_world_cup_squad./base/x2010fifaworldcupsouthafrica/current_world_cup_squad/current_club 49 | 48 /olympics/olympic_games/participating_countries 50 | 49 /music/performance_role/regular_performances./music/group_membership/role 51 | 50 /music/artist/track_contributions./music/track_contribution/role 52 | 51 /base/aareas/schema/administrative_area/administrative_area_type 53 | 52 /film/film/distributors./film/film_film_distributor_relationship/film_distribution_medium 54 | 53 /olympics/olympic_games/sports 55 | 54 
/soccer/football_team/current_roster./soccer/football_roster_position/position 56 | 55 /olympics/olympic_participating_country/athletes./olympics/olympic_athlete_affiliation/olympics 57 | 56 /military/military_combatant/military_conflicts./military/military_combatant_group/combatants 58 | 57 /tv/tv_personality/tv_regular_appearances./tv/tv_regular_personal_appearance/program 59 | 58 /common/topic/webpage./common/webpage/category 60 | 59 /music/genre/artists 61 | 60 /film/film/featured_film_locations 62 | 61 /location/location/adjoin_s./location/adjoining_relationship/adjoins 63 | 62 /sports/sports_team/colors 64 | 63 /tv/tv_program/program_creator 65 | 64 /business/business_operation/operating_income./measurement_unit/dated_money_value/currency 66 | 65 /ice_hockey/hockey_team/current_roster./sports/sports_team_roster/position 67 | 66 /film/film/prequel 68 | 67 /organization/endowed_organization/endowment./measurement_unit/dated_money_value/currency 69 | 68 /film/film_set_designer/film_sets_designed 70 | 69 /film/film/film_art_direction_by 71 | 70 /language/human_language/countries_spoken_in 72 | 71 /people/marriage_union_type/unions_of_this_type./people/marriage/location_of_ceremony 73 | 72 /tv/tv_writer/tv_programs./tv/tv_program_writer_relationship/tv_program 74 | 73 /government/political_party/politicians_in_this_party./government/political_party_tenure/politician 75 | 74 /sports/sports_team/roster./american_football/football_historical_roster_position/position_s 76 | 75 /film/film/release_date_s./film/film_regional_release_date/film_release_region 77 | 76 /film/film/release_date_s./film/film_regional_release_date/film_regional_debut_venue 78 | 77 /award/award_winning_work/awards_won./award/award_honor/honored_for 79 | 78 /location/capital_of_administrative_division/capital_of./location/administrative_division_capital_relationship/administrative_division 80 | 79 /location/hud_foreclosure_area/estimated_number_of_mortgages./measurement_unit/dated_integer/source 
81 | 80 /award/award_category/winners./award/award_honor/ceremony 82 | 81 /people/person/languages 83 | 82 /film/actor/film./film/performance/film 84 | 83 /business/business_operation/revenue./measurement_unit/dated_money_value/currency 85 | 84 /base/petbreeds/city_with_dogs/top_breeds./base/petbreeds/dog_city_relationship/dog_breed 86 | 85 /sports/sports_team_location/teams 87 | 86 /film/film/music 88 | 87 /sports/professional_sports_team/draft_picks./sports/sports_league_draft_pick/draft 89 | 88 /education/educational_institution/students_graduates./education/education/major_field_of_study 90 | 89 /people/ethnicity/geographic_distribution 91 | 90 /sports/sports_league/teams./sports/sports_league_participation/team 92 | 91 /education/educational_degree/people_with_this_degree./education/education/student 93 | 92 /government/politician/government_positions_held./government/government_position_held/jurisdiction_of_office 94 | 93 /base/aareas/schema/administrative_area/capital 95 | 94 /film/film/film_production_design_by 96 | 95 /user/jg/default_domain/olympic_games/sports 97 | 96 /award/award_category/category_of 98 | 97 /education/educational_institution/school_type 99 | 98 /sports/sports_team/roster./baseball/baseball_roster_position/position 100 | 99 /tv/tv_producer/programs_produced./tv/tv_producer_term/program 101 | 100 /location/us_county/county_seat 102 | 101 /education/university/fraternities_and_sororities 103 | 102 /film/film/other_crew./film/film_crew_gig/crewmember 104 | 103 /military/military_conflict/combatants./military/military_combatant_group/combatants 105 | 104 /base/popstra/celebrity/canoodled./base/popstra/canoodled/participant 106 | 105 /education/educational_degree/people_with_this_degree./education/education/institution 107 | 106 /organization/organization/child./organization/organization_relationship/child 108 | 107 /travel/travel_destination/how_to_get_here./travel/transportation/mode_of_transportation 109 | 108 
/award/award_category/nominees./award/award_nomination/nominated_for 110 | 109 /medicine/symptom/symptom_of 111 | 110 /people/ethnicity/people 112 | 111 /film/film/other_crew./film/film_crew_gig/film_crew_role 113 | 112 /government/governmental_body/members./government/government_position_held/legislative_sessions 114 | 113 /business/business_operation/industry 115 | 114 /film/film/country 116 | 115 /people/profession/specialization_of 117 | 116 /location/hud_county_place/place 118 | 117 /organization/role/leaders./organization/leadership/organization 119 | 118 /music/instrument/instrumentalists 120 | 119 /time/event/locations 121 | 120 /film/film/produced_by 122 | 121 /music/performance_role/track_performances./music/track_contribution/role 123 | 122 /film/film/runtime./film/film_cut/film_release_region 124 | 123 /olympics/olympic_sport/athletes./olympics/olympic_athlete_affiliation/country 125 | 124 /tv/tv_program/regular_cast./tv/regular_tv_appearance/actor 126 | 125 /award/award_nominee/award_nominations./award/award_nomination/award 127 | 126 /people/person/spouse_s./people/marriage/type_of_union 128 | 127 /film/actor/dubbing_performances./film/dubbing_performance/language 129 | 128 /sports/sports_position/players./sports/sports_team_roster/team 130 | 129 /award/award_ceremony/awards_presented./award/award_honor/honored_for 131 | 130 /sports/sports_team/sport 132 | 131 /tv/tv_program/country_of_origin 133 | 132 /award/award_category/disciplines_or_subjects 134 | 133 /base/popstra/celebrity/friendship./base/popstra/friendship/participant 135 | 134 /people/ethnicity/languages_spoken 136 | 135 /tv/tv_program/genre 137 | 136 /education/educational_degree/people_with_this_degree./education/education/major_field_of_study 138 | 137 /people/person/sibling_s./people/sibling_relationship/sibling 139 | 138 /business/business_operation/assets./measurement_unit/dated_money_value/currency 140 | 139 /olympics/olympic_games/medals_awarded./olympics/olympic_medal_honor/medal 
141 | 140 /film/film/edited_by 142 | 141 /film/actor/film./film/performance/special_performance_type 143 | 142 /education/educational_institution_campus/educational_institution 144 | 143 /film/film/written_by 145 | 144 /sports/sports_position/players./sports/sports_team_roster/position 146 | 145 /base/schemastaging/organization_extra/phone_number./base/schemastaging/phone_sandbox/service_location 147 | 146 /film/film/personal_appearances./film/personal_film_appearance/person 148 | 147 /user/tsegaran/random/taxonomy_subject/entry./user/tsegaran/random/taxonomy_entry/taxonomy 149 | 148 /people/person/gender 150 | 149 /people/deceased_person/place_of_death 151 | 150 /location/statistical_region/rent50_2./measurement_unit/dated_money_value/currency 152 | 151 /music/performance_role/guest_performances./music/recording_contribution/performance_role 153 | 152 /olympics/olympic_participating_country/medals_won./olympics/olympic_medal_honor/medal 154 | 153 /dataworld/gardening_hint/split_to 155 | 154 /location/country/capital 156 | 155 /award/award_winning_work/awards_won./award/award_honor/award 157 | 156 /tv/tv_program/tv_producer./tv/tv_producer_term/producer_type 158 | 157 /base/biblioness/bibs_location/state 159 | 158 /influence/influence_node/peers./influence/peer_relationship/peers 160 | 159 /film/film/story_by 161 | 160 /location/administrative_division/first_level_division_of 162 | 161 /baseball/baseball_team/team_stats./baseball/baseball_team_stats/season 163 | 162 /award/hall_of_fame/inductees./award/hall_of_fame_induction/inductee 164 | 163 /sports/sports_team/roster./american_football/football_roster_position/position 165 | 164 /base/schemastaging/organization_extra/phone_number./base/schemastaging/phone_sandbox/service_language 166 | 165 /sports/sports_position/players./american_football/football_historical_roster_position/position_s 167 | 166 /media_common/netflix_genre/titles 168 | 167 /people/person/spouse_s./people/marriage/spouse 169 | 168 
/people/cause_of_death/people 170 | 169 /organization/organization_founder/organizations_founded 171 | 170 /government/government_office_category/officeholders./government/government_position_held/jurisdiction_of_office 172 | 171 /tv/tv_program/languages 173 | 172 /base/popstra/location/vacationers./base/popstra/vacation_choice/vacationer 174 | 173 /influence/influence_node/influenced_by 175 | 174 /location/country/second_level_divisions 176 | 175 /sports/sport/pro_athletes./sports/pro_sports_played/athlete 177 | 176 /government/legislative_session/members./government/government_position_held/district_represented 178 | 177 /olympics/olympic_sport/athletes./olympics/olympic_athlete_affiliation/olympics 179 | 178 /medicine/disease/risk_factors 180 | 179 /award/award_ceremony/awards_presented./award/award_honor/award_winner 181 | 180 /american_football/football_team/current_roster./sports/sports_team_roster/position 182 | 181 /music/artist/contribution./music/recording_contribution/performance_role 183 | 182 /education/educational_institution/campuses 184 | 183 /location/country/form_of_government 185 | 184 /base/marchmadness/ncaa_basketball_tournament/seeds./base/marchmadness/ncaa_tournament_seed/team 186 | 185 /education/field_of_study/students_majoring./education/education/major_field_of_study 187 | 186 /people/person/nationality 188 | 187 /film/film/release_date_s./film/film_regional_release_date/film_release_distribution_medium 189 | 188 /film/film/film_format 190 | 189 /soccer/football_player/current_team./sports/sports_team_roster/team 191 | 190 /government/politician/government_positions_held./government/government_position_held/legislative_sessions 192 | 191 /film/film/cinematography 193 | 192 /people/deceased_person/place_of_burial 194 | 193 /base/aareas/schema/administrative_area/administrative_parent 195 | 194 /music/genre/parent_genre 196 | 195 /sports/sports_league_draft/picks./sports/sports_league_draft_pick/school 197 | 196 
/location/statistical_region/religions./location/religion_percentage/religion 198 | 197 /location/location/time_zones 199 | 198 /olympics/olympic_participating_country/medals_won./olympics/olympic_medal_honor/olympics 200 | 199 /film/film_distributor/films_distributed./film/film_film_distributor_relationship/film 201 | 200 /film/film/dubbing_performances./film/dubbing_performance/actor 202 | 201 /organization/organization/headquarters./location/mailing_address/citytown 203 | 202 /sports/pro_athlete/teams./sports/sports_team_roster/team 204 | 203 /education/university/local_tuition./measurement_unit/dated_money_value/currency 205 | 204 /music/record_label/artist 206 | 205 /business/job_title/people_with_this_title./business/employment_tenure/company 207 | 206 /music/instrument/family 208 | 207 /user/alexander/philosophy/philosopher/interests 209 | 208 /location/statistical_region/gdp_real./measurement_unit/adjusted_money_value/adjustment_currency 210 | 209 /tv/non_character_role/tv_regular_personal_appearances./tv/tv_regular_personal_appearance/person 211 | 210 /location/hud_county_place/county 212 | 211 /government/politician/government_positions_held./government/government_position_held/basic_title 213 | 212 /base/schemastaging/organization_extra/phone_number./base/schemastaging/phone_sandbox/contact_category 214 | 213 /people/person/religion 215 | 214 /education/university/domestic_tuition./measurement_unit/dated_money_value/currency 216 | 215 /award/award_nominee/award_nominations./award/award_nomination/award_nominee 217 | 216 /music/performance_role/regular_performances./music/group_membership/group 218 | 217 /education/university/international_tuition./measurement_unit/dated_money_value/currency 219 | 218 /film/film/film_festivals 220 | 219 /location/statistical_region/gdp_nominal./measurement_unit/dated_money_value/currency 221 | 220 /base/saturdaynightlive/snl_cast_member/seasons./base/saturdaynightlive/snl_season_tenure/cast_members 222 | 221 
/education/field_of_study/students_majoring./education/education/student 223 | 222 /location/statistical_region/gni_per_capita_in_ppp_dollars./measurement_unit/dated_money_value/currency 224 | 223 /base/localfood/seasonal_month/produce_available./base/localfood/produce_availability/seasonal_months 225 | 224 /film/film_subject/films 226 | 225 /soccer/football_team/current_roster./sports/sports_team_roster/position 227 | 226 /location/location/partially_contains 228 | 227 /celebrities/celebrity/sexual_relationships./celebrities/romantic_relationship/celebrity 229 | 228 /people/person/spouse_s./people/marriage/location_of_ceremony 230 | 229 /base/culturalevent/event/entity_involved 231 | 230 /organization/organization_member/member_of./organization/organization_membership/organization 232 | 231 /base/locations/continents/countries_within 233 | 232 /location/country/official_language 234 | 233 /film/film/production_companies 235 | 234 /base/schemastaging/person_extra/net_worth./measurement_unit/dated_money_value/currency 236 | 235 /medicine/disease/notable_people_with_this_condition 237 | 236 /film/person_or_entity_appearing_in_film/films./film/personal_film_appearance/type_of_appearance 238 | 237 !/organization/organization/headquarters./location/mailing_address/state_province_region 239 | 238 !/education/educational_institution/colors 240 | 239 !/people/person/profession 241 | 240 !/film/film/costume_design_by 242 | 241 !/film/film/genre 243 | 242 !/celebrities/celebrity/celebrity_friends./celebrities/friendship/friend 244 | 243 !/tv/tv_producer/programs_produced./tv/tv_producer_term/producer_type 245 | 244 !/film/film/executive_produced_by 246 | 245 !/sports/sports_team/roster./basketball/basketball_roster_position/position 247 | 246 !/award/award_nominee/award_nominations./award/award_nomination/nominated_for 248 | 247 !/award/award_category/winners./award/award_honor/award_winner 249 | 248 !/award/award_winner/awards_won./award/award_honor/award_winner 250 | 249 
!/music/artist/origin 251 | 250 !/food/food/nutrients./food/nutrition_fact/nutrient 252 | 251 !/film/film/distributors./film/film_film_distributor_relationship/region 253 | 252 !/time/event/instance_of_recurring_event 254 | 253 !/sports/professional_sports_team/draft_picks./sports/sports_league_draft_pick/school 255 | 254 !/film/film/language 256 | 255 !/location/statistical_region/places_exported_to./location/imports_and_exports/exported_to 257 | 256 !/music/group_member/membership./music/group_membership/group 258 | 257 !/tv/tv_network/programs./tv/tv_network_duration/program 259 | 258 !/award/award_winning_work/awards_won./award/award_honor/award_winner 260 | 259 !/people/person/places_lived./people/place_lived/location 261 | 260 !/travel/travel_destination/climate./travel/travel_destination_monthly_climate/month 262 | 261 !/broadcast/content/artist 263 | 262 !/base/americancomedy/celebrity_impressionist/celebrities_impersonated 264 | 263 !/base/popstra/celebrity/breakup./base/popstra/breakup/participant 265 | 264 !/organization/organization/place_founded 266 | 265 !/people/person/employment_history./business/employment_tenure/company 267 | 266 !/location/statistical_region/gdp_nominal_per_capita./measurement_unit/dated_money_value/currency 268 | 267 !/people/person/place_of_birth 269 | 268 !/location/location/contains 270 | 269 !/base/popstra/celebrity/dated./base/popstra/dated/participant 271 | 270 !/user/ktrueman/default_domain/international_organization/member_states 272 | 271 !/government/legislative_session/members./government/government_position_held/legislative_sessions 273 | 272 !/film/film/estimated_budget./measurement_unit/dated_money_value/currency 274 | 273 !/organization/non_profit_organization/registered_with./organization/non_profit_registration/registering_agency 275 | 274 !/organization/organization/headquarters./location/mailing_address/country 276 | 275 !/base/biblioness/bibs_location/country 277 | 276 
!/education/educational_institution/students_graduates./education/education/student 278 | 277 !/music/group_member/membership./music/group_membership/role 279 | 278 !/location/administrative_division/country 280 | 279 !/award/ranked_item/appears_in_ranked_lists./award/ranking/list 281 | 280 !/base/eating/practicer_of_diet/diet 282 | 281 !/film/special_film_performance_type/film_performance_type./film/performance/film 283 | 282 !/award/award_nominated_work/award_nominations./award/award_nomination/nominated_for 284 | 283 !/film/director/film 285 | 284 !/base/x2010fifaworldcupsouthafrica/world_cup_squad/current_world_cup_squad./base/x2010fifaworldcupsouthafrica/current_world_cup_squad/current_club 286 | 285 !/olympics/olympic_games/participating_countries 287 | 286 !/music/performance_role/regular_performances./music/group_membership/role 288 | 287 !/music/artist/track_contributions./music/track_contribution/role 289 | 288 !/base/aareas/schema/administrative_area/administrative_area_type 290 | 289 !/film/film/distributors./film/film_film_distributor_relationship/film_distribution_medium 291 | 290 !/olympics/olympic_games/sports 292 | 291 !/soccer/football_team/current_roster./soccer/football_roster_position/position 293 | 292 !/olympics/olympic_participating_country/athletes./olympics/olympic_athlete_affiliation/olympics 294 | 293 !/military/military_combatant/military_conflicts./military/military_combatant_group/combatants 295 | 294 !/tv/tv_personality/tv_regular_appearances./tv/tv_regular_personal_appearance/program 296 | 295 !/common/topic/webpage./common/webpage/category 297 | 296 !/music/genre/artists 298 | 297 !/film/film/featured_film_locations 299 | 298 !/location/location/adjoin_s./location/adjoining_relationship/adjoins 300 | 299 !/sports/sports_team/colors 301 | 300 !/tv/tv_program/program_creator 302 | 301 !/business/business_operation/operating_income./measurement_unit/dated_money_value/currency 303 | 302 
!/ice_hockey/hockey_team/current_roster./sports/sports_team_roster/position 304 | 303 !/film/film/prequel 305 | 304 !/organization/endowed_organization/endowment./measurement_unit/dated_money_value/currency 306 | 305 !/film/film_set_designer/film_sets_designed 307 | 306 !/film/film/film_art_direction_by 308 | 307 !/language/human_language/countries_spoken_in 309 | 308 !/people/marriage_union_type/unions_of_this_type./people/marriage/location_of_ceremony 310 | 309 !/tv/tv_writer/tv_programs./tv/tv_program_writer_relationship/tv_program 311 | 310 !/government/political_party/politicians_in_this_party./government/political_party_tenure/politician 312 | 311 !/sports/sports_team/roster./american_football/football_historical_roster_position/position_s 313 | 312 !/film/film/release_date_s./film/film_regional_release_date/film_release_region 314 | 313 !/film/film/release_date_s./film/film_regional_release_date/film_regional_debut_venue 315 | 314 !/award/award_winning_work/awards_won./award/award_honor/honored_for 316 | 315 !/location/capital_of_administrative_division/capital_of./location/administrative_division_capital_relationship/administrative_division 317 | 316 !/location/hud_foreclosure_area/estimated_number_of_mortgages./measurement_unit/dated_integer/source 318 | 317 !/award/award_category/winners./award/award_honor/ceremony 319 | 318 !/people/person/languages 320 | 319 !/film/actor/film./film/performance/film 321 | 320 !/business/business_operation/revenue./measurement_unit/dated_money_value/currency 322 | 321 !/base/petbreeds/city_with_dogs/top_breeds./base/petbreeds/dog_city_relationship/dog_breed 323 | 322 !/sports/sports_team_location/teams 324 | 323 !/film/film/music 325 | 324 !/sports/professional_sports_team/draft_picks./sports/sports_league_draft_pick/draft 326 | 325 !/education/educational_institution/students_graduates./education/education/major_field_of_study 327 | 326 !/people/ethnicity/geographic_distribution 328 | 327 
!/sports/sports_league/teams./sports/sports_league_participation/team 329 | 328 !/education/educational_degree/people_with_this_degree./education/education/student 330 | 329 !/government/politician/government_positions_held./government/government_position_held/jurisdiction_of_office 331 | 330 !/base/aareas/schema/administrative_area/capital 332 | 331 !/film/film/film_production_design_by 333 | 332 !/user/jg/default_domain/olympic_games/sports 334 | 333 !/award/award_category/category_of 335 | 334 !/education/educational_institution/school_type 336 | 335 !/sports/sports_team/roster./baseball/baseball_roster_position/position 337 | 336 !/tv/tv_producer/programs_produced./tv/tv_producer_term/program 338 | 337 !/location/us_county/county_seat 339 | 338 !/education/university/fraternities_and_sororities 340 | 339 !/film/film/other_crew./film/film_crew_gig/crewmember 341 | 340 !/military/military_conflict/combatants./military/military_combatant_group/combatants 342 | 341 !/base/popstra/celebrity/canoodled./base/popstra/canoodled/participant 343 | 342 !/education/educational_degree/people_with_this_degree./education/education/institution 344 | 343 !/organization/organization/child./organization/organization_relationship/child 345 | 344 !/travel/travel_destination/how_to_get_here./travel/transportation/mode_of_transportation 346 | 345 !/award/award_category/nominees./award/award_nomination/nominated_for 347 | 346 !/medicine/symptom/symptom_of 348 | 347 !/people/ethnicity/people 349 | 348 !/film/film/other_crew./film/film_crew_gig/film_crew_role 350 | 349 !/government/governmental_body/members./government/government_position_held/legislative_sessions 351 | 350 !/business/business_operation/industry 352 | 351 !/film/film/country 353 | 352 !/people/profession/specialization_of 354 | 353 !/location/hud_county_place/place 355 | 354 !/organization/role/leaders./organization/leadership/organization 356 | 355 !/music/instrument/instrumentalists 357 | 356 !/time/event/locations 358 
| 357 !/film/film/produced_by 359 | 358 !/music/performance_role/track_performances./music/track_contribution/role 360 | 359 !/film/film/runtime./film/film_cut/film_release_region 361 | 360 !/olympics/olympic_sport/athletes./olympics/olympic_athlete_affiliation/country 362 | 361 !/tv/tv_program/regular_cast./tv/regular_tv_appearance/actor 363 | 362 !/award/award_nominee/award_nominations./award/award_nomination/award 364 | 363 !/people/person/spouse_s./people/marriage/type_of_union 365 | 364 !/film/actor/dubbing_performances./film/dubbing_performance/language 366 | 365 !/sports/sports_position/players./sports/sports_team_roster/team 367 | 366 !/award/award_ceremony/awards_presented./award/award_honor/honored_for 368 | 367 !/sports/sports_team/sport 369 | 368 !/tv/tv_program/country_of_origin 370 | 369 !/award/award_category/disciplines_or_subjects 371 | 370 !/base/popstra/celebrity/friendship./base/popstra/friendship/participant 372 | 371 !/people/ethnicity/languages_spoken 373 | 372 !/tv/tv_program/genre 374 | 373 !/education/educational_degree/people_with_this_degree./education/education/major_field_of_study 375 | 374 !/people/person/sibling_s./people/sibling_relationship/sibling 376 | 375 !/business/business_operation/assets./measurement_unit/dated_money_value/currency 377 | 376 !/olympics/olympic_games/medals_awarded./olympics/olympic_medal_honor/medal 378 | 377 !/film/film/edited_by 379 | 378 !/film/actor/film./film/performance/special_performance_type 380 | 379 !/education/educational_institution_campus/educational_institution 381 | 380 !/film/film/written_by 382 | 381 !/sports/sports_position/players./sports/sports_team_roster/position 383 | 382 !/base/schemastaging/organization_extra/phone_number./base/schemastaging/phone_sandbox/service_location 384 | 383 !/film/film/personal_appearances./film/personal_film_appearance/person 385 | 384 !/user/tsegaran/random/taxonomy_subject/entry./user/tsegaran/random/taxonomy_entry/taxonomy 386 | 385 
!/people/person/gender 387 | 386 !/people/deceased_person/place_of_death 388 | 387 !/location/statistical_region/rent50_2./measurement_unit/dated_money_value/currency 389 | 388 !/music/performance_role/guest_performances./music/recording_contribution/performance_role 390 | 389 !/olympics/olympic_participating_country/medals_won./olympics/olympic_medal_honor/medal 391 | 390 !/dataworld/gardening_hint/split_to 392 | 391 !/location/country/capital 393 | 392 !/award/award_winning_work/awards_won./award/award_honor/award 394 | 393 !/tv/tv_program/tv_producer./tv/tv_producer_term/producer_type 395 | 394 !/base/biblioness/bibs_location/state 396 | 395 !/influence/influence_node/peers./influence/peer_relationship/peers 397 | 396 !/film/film/story_by 398 | 397 !/location/administrative_division/first_level_division_of 399 | 398 !/baseball/baseball_team/team_stats./baseball/baseball_team_stats/season 400 | 399 !/award/hall_of_fame/inductees./award/hall_of_fame_induction/inductee 401 | 400 !/sports/sports_team/roster./american_football/football_roster_position/position 402 | 401 !/base/schemastaging/organization_extra/phone_number./base/schemastaging/phone_sandbox/service_language 403 | 402 !/sports/sports_position/players./american_football/football_historical_roster_position/position_s 404 | 403 !/media_common/netflix_genre/titles 405 | 404 !/people/person/spouse_s./people/marriage/spouse 406 | 405 !/people/cause_of_death/people 407 | 406 !/organization/organization_founder/organizations_founded 408 | 407 !/government/government_office_category/officeholders./government/government_position_held/jurisdiction_of_office 409 | 408 !/tv/tv_program/languages 410 | 409 !/base/popstra/location/vacationers./base/popstra/vacation_choice/vacationer 411 | 410 !/influence/influence_node/influenced_by 412 | 411 !/location/country/second_level_divisions 413 | 412 !/sports/sport/pro_athletes./sports/pro_sports_played/athlete 414 | 413 
!/government/legislative_session/members./government/government_position_held/district_represented 415 | 414 !/olympics/olympic_sport/athletes./olympics/olympic_athlete_affiliation/olympics 416 | 415 !/medicine/disease/risk_factors 417 | 416 !/award/award_ceremony/awards_presented./award/award_honor/award_winner 418 | 417 !/american_football/football_team/current_roster./sports/sports_team_roster/position 419 | 418 !/music/artist/contribution./music/recording_contribution/performance_role 420 | 419 !/education/educational_institution/campuses 421 | 420 !/location/country/form_of_government 422 | 421 !/base/marchmadness/ncaa_basketball_tournament/seeds./base/marchmadness/ncaa_tournament_seed/team 423 | 422 !/education/field_of_study/students_majoring./education/education/major_field_of_study 424 | 423 !/people/person/nationality 425 | 424 !/film/film/release_date_s./film/film_regional_release_date/film_release_distribution_medium 426 | 425 !/film/film/film_format 427 | 426 !/soccer/football_player/current_team./sports/sports_team_roster/team 428 | 427 !/government/politician/government_positions_held./government/government_position_held/legislative_sessions 429 | 428 !/film/film/cinematography 430 | 429 !/people/deceased_person/place_of_burial 431 | 430 !/base/aareas/schema/administrative_area/administrative_parent 432 | 431 !/music/genre/parent_genre 433 | 432 !/sports/sports_league_draft/picks./sports/sports_league_draft_pick/school 434 | 433 !/location/statistical_region/religions./location/religion_percentage/religion 435 | 434 !/location/location/time_zones 436 | 435 !/olympics/olympic_participating_country/medals_won./olympics/olympic_medal_honor/olympics 437 | 436 !/film/film_distributor/films_distributed./film/film_film_distributor_relationship/film 438 | 437 !/film/film/dubbing_performances./film/dubbing_performance/actor 439 | 438 !/organization/organization/headquarters./location/mailing_address/citytown 440 | 439 
!/sports/pro_athlete/teams./sports/sports_team_roster/team 441 | 440 !/education/university/local_tuition./measurement_unit/dated_money_value/currency 442 | 441 !/music/record_label/artist 443 | 442 !/business/job_title/people_with_this_title./business/employment_tenure/company 444 | 443 !/music/instrument/family 445 | 444 !/user/alexander/philosophy/philosopher/interests 446 | 445 !/location/statistical_region/gdp_real./measurement_unit/adjusted_money_value/adjustment_currency 447 | 446 !/tv/non_character_role/tv_regular_personal_appearances./tv/tv_regular_personal_appearance/person 448 | 447 !/location/hud_county_place/county 449 | 448 !/government/politician/government_positions_held./government/government_position_held/basic_title 450 | 449 !/base/schemastaging/organization_extra/phone_number./base/schemastaging/phone_sandbox/contact_category 451 | 450 !/people/person/religion 452 | 451 !/education/university/domestic_tuition./measurement_unit/dated_money_value/currency 453 | 452 !/award/award_nominee/award_nominations./award/award_nomination/award_nominee 454 | 453 !/music/performance_role/regular_performances./music/group_membership/group 455 | 454 !/education/university/international_tuition./measurement_unit/dated_money_value/currency 456 | 455 !/film/film/film_festivals 457 | 456 !/location/statistical_region/gdp_nominal./measurement_unit/dated_money_value/currency 458 | 457 !/base/saturdaynightlive/snl_cast_member/seasons./base/saturdaynightlive/snl_season_tenure/cast_members 459 | 458 !/education/field_of_study/students_majoring./education/education/student 460 | 459 !/location/statistical_region/gni_per_capita_in_ppp_dollars./measurement_unit/dated_money_value/currency 461 | 460 !/base/localfood/seasonal_month/produce_available./base/localfood/produce_availability/seasonal_months 462 | 461 !/film/film_subject/films 463 | 462 !/soccer/football_team/current_roster./sports/sports_team_roster/position 464 | 463 !/location/location/partially_contains 465 | 
464 !/celebrities/celebrity/sexual_relationships./celebrities/romantic_relationship/celebrity 466 | 465 !/people/person/spouse_s./people/marriage/location_of_ceremony 467 | 466 !/base/culturalevent/event/entity_involved 468 | 467 !/organization/organization_member/member_of./organization/organization_membership/organization 469 | 468 !/base/locations/continents/countries_within 470 | 469 !/location/country/official_language 471 | 470 !/film/film/production_companies 472 | 471 !/base/schemastaging/person_extra/net_worth./measurement_unit/dated_money_value/currency 473 | 472 !/medicine/disease/notable_people_with_this_condition 474 | 473 !/film/person_or_entity_appearing_in_film/films./film/personal_film_appearance/type_of_appearance 475 | -------------------------------------------------------------------------------- /data/wn18rr/relations.dict: -------------------------------------------------------------------------------- 1 | 0 _hypernym 2 | 1 _derivationally_related_form 3 | 2 _instance_hypernym 4 | 3 _also_see 5 | 4 _member_meronym 6 | 5 _synset_domain_topic_of 7 | 6 _has_part 8 | 7 _member_of_domain_usage 9 | 8 _member_of_domain_region 10 | 9 _verb_group 11 | 10 _similar_to 12 | 11 !_hypernym 13 | 12 !_derivationally_related_form 14 | 13 !_instance_hypernym 15 | 14 !_also_see 16 | 15 !_member_meronym 17 | 16 !_synset_domain_topic_of 18 | 17 !_has_part 19 | 18 !_member_of_domain_usage 20 | 19 !_member_of_domain_region 21 | 20 !_verb_group 22 | 21 !_similar_to 23 | -------------------------------------------------------------------------------- /miner/main.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include "rnnlogic.h" 15 | 16 | char data_path[MAX_STRING], output_file[MAX_STRING]; 17 | int max_length = 2, num_threads = 1, iterations = 10, 
top_k = 10, top_n = 100, top_n_out = 100; 18 | double total_loss = 0, learning_rate = 0.01, weight_decay = 0.0, temperature = 100.0; 19 | double miner_portion = 1.0, predictor_portion = 1.0; 20 | 21 | KnowledgeGraph KG; 22 | RuleMiner RM; 23 | ReasoningPredictor RP; 24 | RuleGenerator RG; 25 | Result result; 26 | 27 | void train() 28 | { 29 | printf("%lf %lf\n", miner_portion, predictor_portion); 30 | 31 | KG.read_data(data_path); 32 | 33 | RM.init_knowledge_graph(&KG); 34 | RM.search(max_length, miner_portion, num_threads); 35 | 36 | RP.init_knowledge_graph(&KG); 37 | RG.init_knowledge_graph(&KG); 38 | RG.set_pool(RM.get_logic_rules()); 39 | 40 | for (int k = 0; k != iterations; k++) 41 | { 42 | RG.random_from_pool(top_n); 43 | RP.set_logic_rules(RG.get_logic_rules()); 44 | RP.learn(learning_rate, weight_decay, temperature, false, predictor_portion, num_threads); 45 | RP.H_score(top_k, 1, 0, predictor_portion, num_threads); 46 | RG.update(RP.get_logic_rules()); 47 | } 48 | RG.out_rules(output_file, top_n_out); 49 | } 50 | 51 | int ArgPos(char *str, int argc, char **argv) 52 | { 53 | int a; 54 | for (a = 1; a < argc; a++) if (!strcmp(str, argv[a])) 55 | { 56 | if (a == argc - 1) 57 | { 58 | printf("Argument missing for %s\n", str); 59 | exit(1); 60 | } 61 | return a; 62 | } 63 | return -1; 64 | } 65 | 66 | int main(int argc, char **argv) 67 | { 68 | int i; 69 | if (argc == 1) 70 | { 71 | return 0; 72 | } 73 | data_path[0] = 0; 74 | if ((i = ArgPos((char *)"-data-path", argc, argv)) > 0) strcpy(data_path, argv[i + 1]); 75 | if ((i = ArgPos((char *)"-output-file", argc, argv)) > 0) strcpy(output_file, argv[i + 1]); 76 | if ((i = ArgPos((char *)"-max-length", argc, argv)) > 0) max_length = atoi(argv[i + 1]); 77 | if ((i = ArgPos((char *)"-threads", argc, argv)) > 0) num_threads = atoi(argv[i + 1]); 78 | if ((i = ArgPos((char *)"-iterations", argc, argv)) > 0) iterations = atoi(argv[i + 1]); 79 | if ((i = ArgPos((char *)"-lr", argc, argv)) > 0) learning_rate = 
atof(argv[i + 1]); 80 | if ((i = ArgPos((char *)"-wd", argc, argv)) > 0) weight_decay = atof(argv[i + 1]); 81 | if ((i = ArgPos((char *)"-temp", argc, argv)) > 0) temperature = atof(argv[i + 1]); 82 | if ((i = ArgPos((char *)"-top-k", argc, argv)) > 0) top_k = atoi(argv[i + 1]); 83 | if ((i = ArgPos((char *)"-top-n", argc, argv)) > 0) top_n = atoi(argv[i + 1]); 84 | if ((i = ArgPos((char *)"-top-n-out", argc, argv)) > 0) top_n_out = atoi(argv[i + 1]); 85 | 86 | if (top_n == 0) top_n = 2000000000; 87 | if (top_n_out == 0) top_n_out = 2000000000; 88 | 89 | train(); 90 | return 0; 91 | } 92 | -------------------------------------------------------------------------------- /miner/pyrnnlogic.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | namespace py = pybind11; 4 | 5 | #include "rnnlogic.h" 6 | #include 7 | 8 | void *new_knowledge_graph(char *data_path) 9 | { 10 | KnowledgeGraph *p_kg = new KnowledgeGraph; 11 | p_kg->read_data(data_path); 12 | return (void *)(p_kg); 13 | } 14 | 15 | int num_entities(void *pt) 16 | { 17 | KnowledgeGraph *p_kg = (KnowledgeGraph *)(pt); 18 | return p_kg->get_entity_size(); 19 | } 20 | 21 | int num_relations(void *pt) 22 | { 23 | KnowledgeGraph *p_kg = (KnowledgeGraph *)(pt); 24 | return p_kg->get_relation_size(); 25 | } 26 | 27 | int num_train_triplets(void *pt) 28 | { 29 | KnowledgeGraph *p_kg = (KnowledgeGraph *)(pt); 30 | return p_kg->get_train_size(); 31 | } 32 | 33 | int num_valid_triplets(void *pt) 34 | { 35 | KnowledgeGraph *p_kg = (KnowledgeGraph *)(pt); 36 | return p_kg->get_valid_size(); 37 | } 38 | 39 | int num_test_triplets(void *pt) 40 | { 41 | KnowledgeGraph *p_kg = (KnowledgeGraph *)(pt); 42 | return p_kg->get_test_size(); 43 | } 44 | 45 | void *new_rule_miner(void *pt) 46 | { 47 | KnowledgeGraph *p_kg = (KnowledgeGraph *)(pt); 48 | RuleMiner *p_rm = new RuleMiner; 49 | p_rm->init_knowledge_graph(p_kg); 50 | return (void *)(p_rm); 51 | } 52 | 53 | void 
run_rule_miner(void *pt, int max_length, double portion, int num_threads) 54 | { 55 | RuleMiner *p_rm = (RuleMiner *)(pt); 56 | p_rm->search(max_length, portion, num_threads); 57 | return; 58 | } 59 | 60 | std::vector< std::vector<int> > get_logic_rules(void *pt) 61 | { 62 | RuleMiner *p_rm = (RuleMiner *)(pt); 63 | std::vector< std::vector<int> > rules; 64 | std::vector<int> rule; 65 | 66 | int num_relations = p_rm->get_relation_size(); 67 | for (int r = 0; r != num_relations; r++) 68 | { 69 | for (int k = 0; k != int((p_rm->get_logic_rules())[r].size()); k++) 70 | { 71 | rule.clear(); 72 | rule.push_back((p_rm->get_logic_rules())[r][k].r_head); 73 | for (int i = 0; i != int((p_rm->get_logic_rules())[r][k].r_body.size()); i++) 74 | { 75 | rule.push_back((p_rm->get_logic_rules())[r][k].r_body[i]); 76 | } 77 | rules.push_back(rule); 78 | } 79 | } 80 | return rules; 81 | } 82 | 83 | void *new_reasoning_predictor(void *pt) 84 | { 85 | KnowledgeGraph *p_kg = (KnowledgeGraph *)(pt); 86 | ReasoningPredictor *p_rp = new ReasoningPredictor; 87 | p_rp->init_knowledge_graph(p_kg); 88 | return (void *)(p_rp); 89 | } 90 | 91 | void load_reasoning_predictor(void *pt, char *file_name) 92 | { 93 | ReasoningPredictor *p_rp = (ReasoningPredictor *)(pt); 94 | p_rp->in_rules(file_name); 95 | return; 96 | } 97 | 98 | bool check_valid(void *pt, int h, int r, int t) 99 | { 100 | KnowledgeGraph *p_kg = (KnowledgeGraph *)(pt); 101 | Triplet triplet; 102 | triplet.h = h; triplet.r = r; triplet.t = t; 103 | return p_kg->check_true(triplet); 104 | } 105 | 106 | std::vector<int> get_data(void *pt, char *mode, double portion, int num_threads) 107 | { 108 | ReasoningPredictor *p_rp = (ReasoningPredictor *)(pt); 109 | std::vector<int> data; 110 | 111 | if (strcmp(mode, "train") == 0) 112 | { 113 | p_rp->out_train(&data, portion, num_threads); 114 | } 115 | if (strcmp(mode, "valid") == 0) 116 | { 117 | p_rp->out_test(&data, false, num_threads); 118 | } 119 | if (strcmp(mode, "test") == 0) 120 | { 121 | 
p_rp->out_test(&data, true, num_threads); 122 | } 123 | 124 | return data; 125 | } 126 | 127 | std::vector<int> get_data_single(void *pt, char *mode, int h, int r, int t) 128 | { 129 | ReasoningPredictor *p_rp = (ReasoningPredictor *)(pt); 130 | std::vector<int> data; 131 | 132 | if (strcmp(mode, "train") == 0) 133 | { 134 | p_rp->out_train_single(h, r, t, &data); 135 | } 136 | else 137 | { 138 | p_rp->out_test_single(h, r, t, &data); 139 | } 140 | 141 | return data; 142 | } 143 | 144 | std::vector<int> get_count(void *pt, char *mode, int num_threads) 145 | { 146 | ReasoningPredictor *p_rp = (ReasoningPredictor *)(pt); 147 | std::vector<int> data; 148 | 149 | if (strcmp(mode, "valid") == 0) 150 | { 151 | p_rp->out_test_count(&data, false, num_threads); 152 | } 153 | if (strcmp(mode, "test") == 0) 154 | { 155 | p_rp->out_test_count(&data, true, num_threads); 156 | } 157 | 158 | return data; 159 | } 160 | 161 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 162 | { 163 | m.doc() = "This is pyrnnlogic"; 164 | m.def("new_knowledge_graph", new_knowledge_graph, py::arg("data_path")); 165 | m.def("num_entities", num_entities, py::arg("pt")); 166 | m.def("num_relations", num_relations, py::arg("pt")); 167 | m.def("num_train_triplets", num_train_triplets, py::arg("pt")); 168 | m.def("num_valid_triplets", num_valid_triplets, py::arg("pt")); 169 | m.def("num_test_triplets", num_test_triplets, py::arg("pt")); 170 | m.def("new_rule_miner", new_rule_miner, py::arg("pt")); 171 | m.def("run_rule_miner", run_rule_miner, py::arg("pt"), py::arg("max_length"), py::arg("portion"), py::arg("num_threads")); 172 | m.def("get_logic_rules", get_logic_rules, py::arg("pt")); 173 | m.def("new_reasoning_predictor", new_reasoning_predictor, py::arg("pt")); 174 | m.def("load_reasoning_predictor", load_reasoning_predictor, py::arg("pt"), py::arg("file_name")); 175 | m.def("check_valid", check_valid, py::arg("pt"), py::arg("h"), py::arg("r"), py::arg("t")); 176 | m.def("get_data", get_data, py::arg("pt"), py::arg("mode"), 
py::arg("portion"), py::arg("num_threads")); 177 | m.def("get_data_single", get_data_single, py::arg("pt"), py::arg("mode"), py::arg("h"), py::arg("r"), py::arg("t")); 178 | m.def("get_count", get_count, py::arg("pt"), py::arg("mode"), py::arg("num_threads")); 179 | } 180 | -------------------------------------------------------------------------------- /miner/rnnlogic.cpp: -------------------------------------------------------------------------------- 1 | #include "rnnlogic.h" 2 | 3 | double sigmoid(double x) 4 | { 5 | return 1.0 / (1.0 + exp(-x)); 6 | } 7 | 8 | double abs_val(double x) 9 | { 10 | if (x < 0) return -x; 11 | else return x; 12 | } 13 | 14 | /***************************** 15 | ArgStruct 16 | *****************************/ 17 | 18 | ArgStruct::ArgStruct(void *_ptr, int _id) 19 | { 20 | ptr = _ptr; 21 | id = _id; 22 | } 23 | 24 | /***************************** 25 | Triplet 26 | *****************************/ 27 | 28 | bool operator < (Triplet u, Triplet v) 29 | { 30 | if (u.r == v.r) 31 | { 32 | if (u.h == v.h) return u.t < v.t; 33 | return u.h < v.h; 34 | } 35 | return u.r < v.r; 36 | } 37 | 38 | bool operator == (Triplet u, Triplet v) 39 | { 40 | if (u.h == v.h && u.t == v.t && u.r == v.r) return true; 41 | return false; 42 | } 43 | 44 | /***************************** 45 | RankListEntry 46 | *****************************/ 47 | 48 | bool operator < (RankListEntry u, RankListEntry v) 49 | { 50 | return u.val > v.val; 51 | } 52 | 53 | /***************************** 54 | Parameter 55 | *****************************/ 56 | 57 | Parameter::Parameter() 58 | { 59 | data = 0; m = 0; v = 0; t = 0; 60 | } 61 | 62 | void Parameter::clear() 63 | { 64 | data = 0; m = 0; v = 0; t = 0; 65 | } 66 | 67 | void Parameter::update(double grad, double learning_rate, double weight_decay) 68 | { 69 | double g = grad - weight_decay * data; 70 | 71 | t += 1; 72 | m = 0.9 * m + 0.1 * g; 73 | v = 0.999 * v + 0.001 * g * g; 74 | 75 | double bias1 = 1 - exp(log(0.9) * t); 76 | 
double bias2 = 1 - exp(log(0.999) * t); 77 | 78 | double mt = m / bias1; 79 | double vt = sqrt(v) / sqrt(bias2) + 0.00000001; 80 | 81 | data += learning_rate * mt / vt; 82 | } 83 | 84 | /***************************** 85 | Rule 86 | *****************************/ 87 | 88 | Rule::Rule() 89 | { 90 | r_body.clear(); r_head = -1; 91 | type = -1; 92 | H = 0; 93 | cn = 0; 94 | prior = 0; 95 | wt.clear(); 96 | } 97 | 98 | Rule::~Rule() 99 | { 100 | r_body.clear(); r_head = -1; 101 | type = -1; 102 | H = 0; 103 | cn = 0; 104 | prior = 0; 105 | wt.clear(); 106 | } 107 | 108 | void Rule::clear() 109 | { 110 | r_body.clear(); r_head = -1; 111 | type = -1; 112 | H = 0; 113 | cn = 0; 114 | prior = 0; 115 | wt.clear(); 116 | } 117 | 118 | bool operator < (Rule u, Rule v) 119 | { 120 | if (u.type == v.type) 121 | { 122 | if (u.r_head == v.r_head) 123 | { 124 | for (int k = 0; k != u.type; k++) 125 | { 126 | if (u.r_body[k] != v.r_body[k]) 127 | return u.r_body[k] < v.r_body[k]; 128 | } 129 | } 130 | return u.r_head < v.r_head; 131 | } 132 | return u.type < v.type; 133 | } 134 | 135 | bool operator == (Rule u, Rule v) 136 | { 137 | if (u.r_body == v.r_body && u.r_head == v.r_head && u.type == v.type) return true; 138 | return false; 139 | } 140 | 141 | /***************************** 142 | Result 143 | *****************************/ 144 | 145 | Result::Result() 146 | { 147 | h1 = 0; h3 = 0; h10 = 0; mr = 0; mrr = 0; 148 | } 149 | 150 | Result::Result(double mr_, double mrr_, double h1_, double h3_, double h10_) 151 | { 152 | h1 = h1_; h3 = h3_; h10 = h10_; mr = mr_; mrr = mrr_; 153 | } 154 | 155 | /***************************** 156 | KnowledgeGraph 157 | *****************************/ 158 | 159 | KnowledgeGraph::KnowledgeGraph() 160 | { 161 | entity_size = 0; relation_size = 0; 162 | train_triplet_size = 0; valid_triplet_size = 0; test_triplet_size = 0; 163 | all_triplet_size = 0; 164 | 165 | ent2id.clear(); rel2id.clear(); 166 | id2ent.clear(); id2rel.clear(); 167 | 
train_triplets.clear(); valid_triplets.clear(); test_triplets.clear(); 168 | set_train_triplets.clear(); set_all_triplets.clear(); 169 | e2r2n = NULL; 170 | } 171 | 172 | KnowledgeGraph::~KnowledgeGraph() 173 | { 174 | ent2id.clear(); rel2id.clear(); 175 | id2ent.clear(); id2rel.clear(); 176 | train_triplets.clear(); valid_triplets.clear(); test_triplets.clear(); 177 | set_train_triplets.clear(); set_all_triplets.clear(); 178 | for (int k = 0; k != entity_size; k++) 179 | { 180 | for (int r = 0; r != relation_size; r++) 181 | e2r2n[k][r].clear(); 182 | delete [] e2r2n[k]; 183 | } 184 | delete [] e2r2n; 185 | } 186 | 187 | int KnowledgeGraph::get_entity_size() 188 | { 189 | return entity_size; 190 | } 191 | 192 | int KnowledgeGraph::get_relation_size() 193 | { 194 | return relation_size; 195 | } 196 | 197 | int KnowledgeGraph::get_train_size() 198 | { 199 | return train_triplet_size; 200 | } 201 | 202 | int KnowledgeGraph::get_valid_size() 203 | { 204 | return valid_triplet_size; 205 | } 206 | 207 | int KnowledgeGraph::get_test_size() 208 | { 209 | return test_triplet_size; 210 | } 211 | 212 | void KnowledgeGraph::read_data(char *data_path) 213 | { 214 | char s_head[MAX_STRING], s_tail[MAX_STRING], s_ent[MAX_STRING], s_rel[MAX_STRING], s_file[MAX_STRING]; 215 | int h, t, r, id; 216 | Triplet triplet; 217 | std::map<std::string, int>::iterator iter; 218 | FILE *fi; 219 | 220 | strcpy(s_file, data_path); 221 | strcat(s_file, "/entities.dict"); 222 | fi = fopen(s_file, "rb"); 223 | if (fi == NULL) 224 | { 225 | printf("ERROR: file of entities not found!\n"); 226 | exit(1); 227 | } 228 | while (1) 229 | { 230 | if (fscanf(fi, "%d %s", &id, s_ent) != 2) break; 231 | 232 | ent2id[s_ent] = id; 233 | id2ent[id] = s_ent; 234 | entity_size += 1; 235 | } 236 | fclose(fi); 237 | 238 | strcpy(s_file, data_path); 239 | strcat(s_file, "/relations.dict"); 240 | fi = fopen(s_file, "rb"); 241 | if (fi == NULL) 242 | { 243 | printf("ERROR: file of relations not found!\n"); 244 | exit(1); 245 | } 246 | 
while (1) 247 | { 248 | if (fscanf(fi, "%d %s", &id, s_rel) != 2) break; 249 | 250 | rel2id[s_rel] = id; 251 | id2rel[id] = s_rel; 252 | relation_size += 1; 253 | } 254 | fclose(fi); 255 | 256 | strcpy(s_file, data_path); 257 | strcat(s_file, "/train.txt"); 258 | fi = fopen(s_file, "rb"); 259 | if (fi == NULL) 260 | { 261 | printf("ERROR: file of train triplets not found!\n"); 262 | exit(1); 263 | } 264 | while (1) 265 | { 266 | if (fscanf(fi, "%s %s %s", s_head, s_rel, s_tail) != 3) break; 267 | if (ent2id.count(s_head) == 0 || ent2id.count(s_tail) == 0 || rel2id.count(s_rel) == 0) continue; 268 | 269 | h = ent2id[s_head]; t = ent2id[s_tail]; r = rel2id[s_rel]; 270 | triplet.h = h; triplet.t = t; triplet.r = r; 271 | train_triplets.push_back(triplet); 272 | set_train_triplets.insert(triplet); 273 | set_all_triplets.insert(triplet); 274 | } 275 | fclose(fi); 276 | 277 | train_triplet_size = int(train_triplets.size()); 278 | e2r2n = new std::vector<int> * [entity_size]; 279 | for (int k = 0; k != entity_size; k++) e2r2n[k] = new std::vector<int> [relation_size]; 280 | for (int k = 0; k != train_triplet_size; k++) 281 | { 282 | h = train_triplets[k].h; r = train_triplets[k].r; t = train_triplets[k].t; 283 | e2r2n[h][r].push_back(t); 284 | } 285 | 286 | strcpy(s_file, data_path); 287 | strcat(s_file, "/valid.txt"); 288 | fi = fopen(s_file, "rb"); 289 | if (fi == NULL) 290 | { 291 | printf("ERROR: file of valid triplets not found!\n"); 292 | exit(1); 293 | } 294 | while (1) 295 | { 296 | if (fscanf(fi, "%s %s %s", s_head, s_rel, s_tail) != 3) break; 297 | if (ent2id.count(s_head) == 0 || ent2id.count(s_tail) == 0 || rel2id.count(s_rel) == 0) continue; 298 | 299 | h = ent2id[s_head]; t = ent2id[s_tail]; r = rel2id[s_rel]; 300 | triplet.h = h; triplet.t = t; triplet.r = r; 301 | valid_triplets.push_back(triplet); 302 | set_all_triplets.insert(triplet); 303 | } 304 | fclose(fi); 305 | valid_triplet_size = int(valid_triplets.size()); 306 | 307 | strcpy(s_file, data_path); 308 | 
strcat(s_file, "/test.txt"); 309 | fi = fopen(s_file, "rb"); 310 | if (fi == NULL) 311 | { 312 | printf("ERROR: file of test triplets not found!\n"); 313 | exit(1); 314 | } 315 | while (1) 316 | { 317 | if (fscanf(fi, "%s %s %s", s_head, s_rel, s_tail) != 3) break; 318 | if (ent2id.count(s_head) == 0 || ent2id.count(s_tail) == 0 || rel2id.count(s_rel) == 0) continue; 319 | 320 | h = ent2id[s_head]; t = ent2id[s_tail]; r = rel2id[s_rel]; 321 | triplet.h = h; triplet.t = t; triplet.r = r; 322 | test_triplets.push_back(triplet); 323 | set_all_triplets.insert(triplet); 324 | } 325 | fclose(fi); 326 | test_triplet_size = int(test_triplets.size()); 327 | 328 | all_triplet_size = int(set_all_triplets.size()); 329 | 330 | printf("#Entities: %d \n", entity_size); 331 | printf("#Relations: %d \n", relation_size); 332 | printf("#Train triplets: %d \n", train_triplet_size); 333 | printf("#Valid triplets: %d \n", valid_triplet_size); 334 | printf("#Test triplets: %d \n", test_triplet_size); 335 | printf("#All triplets: %d \n", all_triplet_size); 336 | } 337 | 338 | bool KnowledgeGraph::check_observed(Triplet triplet) 339 | { 340 | if (set_train_triplets.count(triplet) != 0) return true; 341 | else return false; 342 | } 343 | 344 | bool KnowledgeGraph::check_true(Triplet triplet) 345 | { 346 | if (set_all_triplets.count(triplet) != 0) return true; 347 | else return false; 348 | } 349 | 350 | void KnowledgeGraph::rule_search(int r, int e, int goal, int *path, int depth, int max_depth, std::set<Rule> *rule_set, Triplet removed_triplet) 351 | { 352 | if (e == goal) 353 | { 354 | Rule rule; 355 | rule.type = depth; 356 | rule.r_head = r; 357 | rule.r_body.clear(); 358 | for (int k = 0; k != depth; k++) 359 | { 360 | rule.r_body.push_back(path[k]); 361 | } 362 | rule_set->insert(rule); 363 | return; 364 | } 365 | if (depth == max_depth) 366 | { 367 | return; 368 | } 369 | 370 | int cur_r, cur_n, len; 371 | for (cur_r = 0; cur_r != relation_size; cur_r++) 372 | { 373 | len = 
int(e2r2n[e][cur_r].size()); 374 | for (int k = 0; k != len; k++) 375 | { 376 | cur_n = e2r2n[e][cur_r][k]; 377 | if (e == removed_triplet.h && cur_r == removed_triplet.r && cur_n == removed_triplet.t) continue; 378 | path[depth] = cur_r; 379 | rule_search(r, cur_n, goal, path, depth+1, max_depth, rule_set, removed_triplet); 380 | } 381 | } 382 | } 383 | 384 | /* 385 | void KnowledgeGraph::rule_destination(int e, Rule rule, std::vector<int> &dests, Triplet removed_triplet) 386 | { 387 | std::queue< std::pair<int, int> > queue; 388 | queue.push(std::make_pair(e, 0)); 389 | int current_e, current_d, current_r, next_e; 390 | while (!queue.empty()) 391 | { 392 | std::pair<int, int> pair = queue.front(); 393 | current_e = pair.first; 394 | current_d = pair.second; 395 | queue.pop(); 396 | if (current_d == int(rule.r_body.size())) 397 | { 398 | dests.push_back(current_e); 399 | continue; 400 | } 401 | current_r = rule.r_body[current_d]; 402 | for (int k = 0; k != int(e2r2n[current_e][current_r].size()); k++) 403 | { 404 | next_e = e2r2n[current_e][current_r][k]; 405 | if (current_e == removed_triplet.h && current_r == removed_triplet.r && next_e == removed_triplet.t) continue; 406 | queue.push(std::make_pair(next_e, current_d + 1)); 407 | } 408 | } 409 | } 410 | */ 411 | 412 | void KnowledgeGraph::rule_destination(int e, Rule rule, std::map<int, int> *dest2count, Triplet removed_triplet) 413 | { 414 | std::map<int, int> m[rule.r_body.size() + 1]; 415 | std::map<int, int> *pt; 416 | std::map<int, int>::iterator iter; 417 | int current_e, current_r, current_n, next_e; 418 | 419 | (*dest2count).clear(); 420 | m[0][e] = 1; 421 | for (int k = 0; k != int(rule.r_body.size()); k++) 422 | { 423 | if (k == int(rule.r_body.size()) - 1) pt = dest2count; 424 | else pt = &(m[k + 1]); 425 | 426 | for (iter = m[k].begin(); iter != m[k].end(); iter++) 427 | { 428 | current_e = iter->first; 429 | current_n = iter->second; 430 | current_r = rule.r_body[k]; 431 | 432 | for (int i = 0; i != int(e2r2n[current_e][current_r].size()); i++) 433 | { 434 | 
next_e = e2r2n[current_e][current_r][i]; 435 | if (current_e == removed_triplet.h && current_r == removed_triplet.r && next_e == removed_triplet.t) continue; 436 | 437 | if ((*pt).count(next_e) == 0) (*pt)[next_e] = 0; 438 | (*pt)[next_e] += current_n; 439 | } 440 | } 441 | } 442 | } 443 | 444 | /***************************** 445 | RuleMiner 446 | *****************************/ 447 | 448 | RuleMiner::RuleMiner() 449 | { 450 | num_threads = 4; max_length = 3; 451 | portion = 1; 452 | total_count = 0; 453 | rel2ruleset = NULL; rel2rules = NULL; 454 | sem_init(&mutex, 0, 1); 455 | p_kg = NULL; 456 | } 457 | 458 | RuleMiner::~RuleMiner() 459 | { 460 | total_count = 0; 461 | if (rel2ruleset != NULL) 462 | { 463 | for (int r = 0; r != p_kg->relation_size; r++) rel2ruleset[r].clear(); 464 | delete [] rel2ruleset; 465 | rel2ruleset = NULL; 466 | } 467 | if (rel2rules != NULL) 468 | { 469 | for (int r = 0; r != p_kg->relation_size; r++) rel2rules[r].clear(); 470 | delete [] rel2rules; 471 | rel2rules = NULL; 472 | } 473 | sem_init(&mutex, 0, 1); 474 | p_kg = NULL; 475 | } 476 | 477 | void RuleMiner::init_knowledge_graph(KnowledgeGraph *_p_kg) 478 | { 479 | p_kg = _p_kg; 480 | rel2rules = new std::vector<Rule> [p_kg->relation_size]; 481 | rel2ruleset = new std::set<Rule> [p_kg->relation_size]; 482 | } 483 | 484 | void RuleMiner::clear() 485 | { 486 | total_count = 0; 487 | for (int k = 0; k != p_kg->relation_size; k++) 488 | { 489 | rel2rules[k].clear(); 490 | rel2ruleset[k].clear(); 491 | } 492 | sem_init(&mutex, 0, 1); 493 | } 494 | 495 | std::vector<Rule> *RuleMiner::get_logic_rules() 496 | { 497 | return rel2rules; 498 | } 499 | 500 | int RuleMiner::get_relation_size() 501 | { 502 | return p_kg->relation_size; 503 | } 504 | 505 | void RuleMiner::search_thread(int thread) 506 | { 507 | int triplet_size = p_kg->train_triplet_size; 508 | int bg = int(triplet_size / num_threads) * thread; 509 | int ed = bg + int(triplet_size / num_threads * portion); 510 | if (thread == num_threads - 1 && 
portion == 1) ed = triplet_size; 511 | 512 | std::set<Rule>::iterator iter; 513 | std::set<Rule> rule_set; 514 | std::vector<int> dests; 515 | Rule rule; 516 | int path[MAX_LENGTH], h, r, t; 517 | 518 | for (int T = bg; T != ed; T++) 519 | { 520 | if (T % 10 == 0) 521 | { 522 | total_count += 10; 523 | printf("Rule Discovery | Progress: %.3lf%% %c", (double)total_count / (double)(triplet_size * portion + 1) * 100, 13); 524 | fflush(stdout); 525 | } 526 | 527 | h = p_kg->train_triplets[T].h; 528 | r = p_kg->train_triplets[T].r; 529 | t = p_kg->train_triplets[T].t; 530 | 531 | rule_set.clear(); 532 | p_kg->rule_search(r, h, t, path, 0, max_length, &rule_set, p_kg->train_triplets[T]); 533 | 534 | for (iter = rule_set.begin(); iter != rule_set.end(); iter++) 535 | { 536 | if (iter->type == 1 && iter->r_body[0] == r) 537 | { 538 | rule_set.erase(iter); 539 | break; 540 | } 541 | } 542 | 543 | for (iter = rule_set.begin(); iter != rule_set.end(); iter++) 544 | { 545 | rule = *iter; 546 | sem_wait(&mutex); 547 | rel2ruleset[r].insert(rule); 548 | sem_post(&mutex); 549 | } 550 | } 551 | dests.clear(); 552 | rule_set.clear(); 553 | pthread_exit(NULL); 554 | } 555 | 556 | void *RuleMiner::search_thread_caller(void *arg) 557 | { 558 | RuleMiner *ptr = (RuleMiner *)(((ArgStruct *)arg)->ptr); 559 | int thread = ((ArgStruct *)arg)->id; 560 | ptr->search_thread(thread); 561 | pthread_exit(NULL); 562 | } 563 | 564 | void RuleMiner::search(int _max_length, double _portion, int _num_threads) 565 | { 566 | max_length = _max_length; 567 | portion = _portion; 568 | num_threads = _num_threads; 569 | 570 | std::random_shuffle((p_kg->train_triplets).begin(), (p_kg->train_triplets).end()); 571 | 572 | pthread_t *pt = (pthread_t *)malloc(num_threads * sizeof(pthread_t)); 573 | for (int k = 0; k != num_threads; k++) pthread_create(&pt[k], NULL, RuleMiner::search_thread_caller, new ArgStruct(this, k)); 574 | for (int k = 0; k != num_threads; k++) pthread_join(pt[k], NULL); 575 | printf("Rule Discovery | DONE! 
\n"); 576 | free(pt); 577 | 578 | int rel; 579 | Rule rule; 580 | std::set::iterator iter; 581 | for (rel = 0; rel != p_kg->relation_size; rel++) 582 | { 583 | for (iter = rel2ruleset[rel].begin(); iter != rel2ruleset[rel].end(); iter++) 584 | { 585 | rule = *iter; 586 | rel2rules[rel].push_back(rule); 587 | } 588 | } 589 | } 590 | 591 | void RuleMiner::save(char *file_name) 592 | { 593 | FILE *fo = fopen(file_name, "wb"); 594 | int cn = 0; 595 | Rule rule; 596 | for (int r = 0; r != p_kg->relation_size; r++) 597 | { 598 | for (int k = 0; k != int(rel2rules[r].size()); k++) 599 | { 600 | rule = rel2rules[r][k]; 601 | fprintf(fo, "%d %d", rule.type, rule.r_head); 602 | for (int k = 0; k != int(rule.r_body.size()); k++) 603 | fprintf(fo, " %d", rule.r_body[k]); 604 | fprintf(fo, " %lf %lf %lf\n", rule.H, rule.wt.data, rule.prior); 605 | 606 | cn += 1; 607 | } 608 | } 609 | fclose(fo); 610 | 611 | printf("#Logic rules saved: %d\n", cn); 612 | } 613 | 614 | void RuleMiner::load(char *file_name) 615 | { 616 | if (rel2rules != NULL) 617 | { 618 | for (int r = 0; r != p_kg->relation_size; r++) rel2rules[r].clear(); 619 | delete [] rel2rules; 620 | } 621 | 622 | rel2rules = new std::vector [p_kg->relation_size]; 623 | 624 | FILE *fi = fopen(file_name, "rb"); 625 | int type, r_head, r_body, cn = 0; 626 | double H, wt, prior; 627 | Rule rule; 628 | while (1) 629 | { 630 | if (fscanf(fi, "%d %d", &type, &r_head) != 2) break; 631 | 632 | rule.clear(); 633 | rule.type = type; 634 | rule.r_head = r_head; 635 | for (int k = 0; k != type; k++) 636 | { 637 | if (fscanf(fi, "%d", &r_body) != 1) 638 | { 639 | printf("ERROR: format error in rule files!\n"); 640 | exit(1); 641 | } 642 | rule.r_body.push_back(r_body); 643 | } 644 | if (fscanf(fi, "%lf %lf %lf", &H, &wt, &prior) != 3) 645 | { 646 | printf("ERROR: format error in rule files!\n"); 647 | exit(1); 648 | } 649 | rule.H = H; 650 | rule.wt.data = wt; 651 | rule.prior = prior; 652 | 653 | rel2rules[r_head].push_back(rule); 654 | 
cn += 1; 655 | } 656 | fclose(fi); 657 | 658 | printf("#Logic rules loaded: %d\n", cn); 659 | } 660 | 661 | /***************************** 662 | ReasoningPredictor 663 | *****************************/ 664 | 665 | ReasoningPredictor::ReasoningPredictor() 666 | { 667 | num_threads = 4; top_k = 100; 668 | temperature = 100; learning_rate = 0.01; weight_decay = 0.0005; 669 | portion = 1.0; 670 | prior_weight = 0; H_temperature = 1; 671 | total_count = 0; total_loss = 0; 672 | rel2rules = NULL; 673 | test = true; fast = true; 674 | ranks.clear(); 675 | sem_init(&mutex, 0, 1); 676 | p_kg = NULL; 677 | } 678 | 679 | ReasoningPredictor::~ReasoningPredictor() 680 | { 681 | total_count = 0; total_loss = 0; 682 | if (rel2rules != NULL) 683 | { 684 | for (int r = 0; r != p_kg->relation_size; r++) rel2rules[r].clear(); 685 | delete [] rel2rules; 686 | rel2rules = NULL; 687 | } 688 | ranks.clear(); 689 | sem_init(&mutex, 0, 1); 690 | p_kg = NULL; 691 | } 692 | 693 | void ReasoningPredictor::init_knowledge_graph(KnowledgeGraph *_p_kg) 694 | { 695 | p_kg = _p_kg; 696 | } 697 | 698 | void ReasoningPredictor::set_logic_rules(std::vector<Rule> * _rel2rules) 699 | { 700 | if (rel2rules != NULL) 701 | { 702 | for (int r = 0; r != p_kg->relation_size; r++) rel2rules[r].clear(); 703 | delete [] rel2rules; 704 | } 705 | 706 | rel2rules = new std::vector<Rule> [p_kg->relation_size]; 707 | for (int r = 0; r != p_kg->relation_size; r++) 708 | { 709 | rel2rules[r] = _rel2rules[r]; 710 | for (int k = 0; k != int(rel2rules[r].size()); k++) 711 | { 712 | rel2rules[r][k].wt.clear(); 713 | rel2rules[r][k].H = 0; 714 | } 715 | } 716 | } 717 | 718 | std::vector<Rule> *ReasoningPredictor::get_logic_rules() 719 | { 720 | return rel2rules; 721 | } 722 | 723 | int ReasoningPredictor::get_relation_size() 724 | { 725 | return p_kg->relation_size; 726 | } 727 | 728 | void ReasoningPredictor::learn_thread(int thread) 729 | { 730 | int triplet_size = p_kg->train_triplet_size; 731 | int bg = int(triplet_size / num_threads) * 
thread; 732 | int ed = bg + int(triplet_size / num_threads * portion); 733 | if (thread == num_threads - 1 && portion == 1) ed = triplet_size; 734 | 735 | Triplet triplet; 736 | int h, r, t, dest, count, index; 737 | double logit, target, grad; 738 | 739 | std::map<int, double> dest2logit, dest2grad; 740 | std::map<int, int>::iterator iter_ii; 741 | std::map<int, double>::iterator iter_id; 742 | 743 | int max_num_rules = 0; 744 | for (r = 0; r != p_kg->relation_size; r++) if (int(rel2rules[r].size()) > max_num_rules) 745 | max_num_rules = int(rel2rules[r].size()); 746 | std::map<int, int> *index2dest2count = new std::map<int, int> [max_num_rules]; 747 | 748 | for (int T = bg; T != ed; T++) 749 | { 750 | if (T % 10 == 0) 751 | { 752 | total_count += 10; 753 | printf("Learning Rule Weights | Progress: %.3lf%% | Loss: %.6lf %c", (double)total_count / (double)(triplet_size * portion + 1) * 100, total_loss / total_count, 13); 754 | fflush(stdout); 755 | } 756 | 757 | h = p_kg->train_triplets[T].h; 758 | r = p_kg->train_triplets[T].r; 759 | t = p_kg->train_triplets[T].t; 760 | 761 | dest2logit.clear(); 762 | for (index = 0; index != int(rel2rules[r].size()); index++) 763 | { 764 | p_kg->rule_destination(h, rel2rules[r][index], &(index2dest2count[index]), p_kg->train_triplets[T]); 765 | 766 | for (iter_ii = index2dest2count[index].begin(); iter_ii != index2dest2count[index].end(); iter_ii++) 767 | { 768 | dest = iter_ii->first; 769 | count = iter_ii->second; 770 | 771 | if (dest2logit.count(dest) == 0) dest2logit[dest] = 0; 772 | dest2logit[dest] += rel2rules[r][index].wt.data * count / temperature; 773 | } 774 | } 775 | 776 | double max_val = -1000000, sum_val = 0; 777 | for (iter_id = dest2logit.begin(); iter_id != dest2logit.end(); iter_id++) 778 | max_val = std::max(max_val, iter_id->second); 779 | for (iter_id = dest2logit.begin(); iter_id != dest2logit.end(); iter_id++) 780 | sum_val += exp(iter_id->second - max_val); 781 | for (iter_id = dest2logit.begin(); iter_id != dest2logit.end(); iter_id++) 782 | 
dest2logit[iter_id->first] = exp(dest2logit[iter_id->first] - max_val) / sum_val; 783 | 784 | dest2grad.clear(); 785 | for (iter_id = dest2logit.begin(); iter_id != dest2logit.end(); iter_id++) 786 | { 787 | dest = iter_id->first; 788 | logit = iter_id->second; 789 | 790 | triplet = p_kg->train_triplets[T]; 791 | triplet.t = dest; 792 | if (p_kg->check_observed(triplet) == true) target = 1.0; 793 | else target = 0; 794 | grad = (target - logit) / temperature; 795 | dest2grad[dest] = grad; 796 | 797 | total_loss += abs_val(target - logit) / dest2logit.size(); 798 | } 799 | 800 | for (index = 0; index != int(rel2rules[r].size()); index++) 801 | { 802 | for (iter_ii = index2dest2count[index].begin(); iter_ii != index2dest2count[index].end(); iter_ii++) 803 | { 804 | dest = iter_ii->first; 805 | count = iter_ii->second; 806 | grad = dest2grad[dest]; 807 | if (fast) rel2rules[r][index].wt.update(grad * count, learning_rate, weight_decay); 808 | else for (int k = 0; k != count; k++) rel2rules[r][index].wt.update(grad, learning_rate, weight_decay); 809 | 810 | } 811 | } 812 | } 813 | dest2logit.clear(); 814 | delete [] index2dest2count; 815 | pthread_exit(NULL); 816 | } 817 | 818 | void *ReasoningPredictor::learn_thread_caller(void *arg) 819 | { 820 | ReasoningPredictor *ptr = (ReasoningPredictor *)(((ArgStruct *)arg)->ptr); 821 | int thread = ((ArgStruct *)arg)->id; 822 | ptr->learn_thread(thread); 823 | pthread_exit(NULL); 824 | } 825 | 826 | void ReasoningPredictor::learn(double _learning_rate, double _weight_decay, double _temperature, bool _fast, double _portion, int _num_threads) 827 | { 828 | learning_rate = _learning_rate; 829 | weight_decay = _weight_decay; 830 | temperature = _temperature; 831 | fast = _fast; 832 | portion = _portion; 833 | num_threads = _num_threads; 834 | 835 | total_count = 0; 836 | total_loss = 0; 837 | 838 | std::random_shuffle((p_kg->train_triplets).begin(), (p_kg->train_triplets).end()); 839 | 840 | pthread_t *pt = (pthread_t 
*)malloc(num_threads * sizeof(pthread_t)); 841 | for (int k = 0; k != num_threads; k++) pthread_create(&pt[k], NULL, ReasoningPredictor::learn_thread_caller, new ArgStruct(this, k)); 842 | for (int k = 0; k != num_threads; k++) pthread_join(pt[k], NULL); 843 | printf("Learning Rule Weights | DONE! | Loss: %.6lf \n", total_loss / total_count); 844 | free(pt); 845 | } 846 | 847 | void ReasoningPredictor::H_score_thread(int thread) 848 | { 849 | int triplet_size = p_kg->train_triplet_size; 850 | int bg = int(triplet_size / num_threads) * thread; 851 | int ed = bg + int(triplet_size / num_threads * portion); 852 | if (thread == num_threads - 1 && portion == 1) ed = triplet_size; 853 | 854 | std::vector<int> dests; 855 | int h, r, t, dest, count, index; 856 | 857 | std::set<int> destset; 858 | std::map<int, int>::iterator iter_ii; 859 | 860 | int max_num_rules = 0; 861 | for (int r = 0; r != p_kg->relation_size; r++) if (int(rel2rules[r].size()) > max_num_rules) 862 | max_num_rules = int(rel2rules[r].size()); 863 | std::map<int, int> *index2dest2count = new std::map<int, int> [max_num_rules]; 864 | 865 | RankListEntry *rule2score = new RankListEntry [max_num_rules]; 866 | 867 | for (int T = bg; T != ed; T++) 868 | { 869 | if (T % 10 == 0) 870 | { 871 | total_count += 10; 872 | printf("Computing H Score | Progress: %.3lf%% %c", (double)total_count / (double)(triplet_size * portion + 1) * 100, 13); 873 | fflush(stdout); 874 | } 875 | 876 | h = p_kg->train_triplets[T].h; 877 | r = p_kg->train_triplets[T].r; 878 | t = p_kg->train_triplets[T].t; 879 | 880 | destset.clear(); 881 | for (index = 0; index != int(rel2rules[r].size()); index++) 882 | { 883 | p_kg->rule_destination(h, rel2rules[r][index], &(index2dest2count[index]), p_kg->train_triplets[T]); 884 | 885 | for (iter_ii = index2dest2count[index].begin(); iter_ii != index2dest2count[index].end(); iter_ii++) 886 | { 887 | dest = iter_ii->first; 888 | 889 | if (destset.count(dest) == 0) destset.insert(dest); 890 | } 891 | } 892 | 893 | for (index = 0; index != 
int(rel2rules[r].size()); index++) 894 | { 895 | rule2score[index].id = index; 896 | rule2score[index].val = rel2rules[r][index].prior * prior_weight; 897 | for (iter_ii = index2dest2count[index].begin(); iter_ii != index2dest2count[index].end(); iter_ii++) 898 | { 899 | dest = iter_ii->first; 900 | count = iter_ii->second; 901 | 902 | if (dest == t) rule2score[index].val += rel2rules[r][index].wt.data * count; 903 | rule2score[index].val -= rel2rules[r][index].wt.data * count / destset.size(); 904 | } 905 | } 906 | 907 | if (top_k == 0) 908 | { 909 | double max_val = -1000000, sum_val = 0; 910 | 911 | for (int k = 0; k != int(rel2rules[r].size()); k++) 912 | rule2score[k].val /= H_temperature; 913 | for (int k = 0; k != int(rel2rules[r].size()); k++) 914 | max_val = std::max(max_val, rule2score[k].val); 915 | for (int k = 0; k != int(rel2rules[r].size()); k++) 916 | sum_val += exp(rule2score[k].val - max_val); 917 | for (int k = 0; k != int(rel2rules[r].size()); k++) 918 | { 919 | index = rule2score[k].id; 920 | rel2rules[r][index].H += exp(rule2score[k].val - max_val) / sum_val / triplet_size; 921 | } 922 | } 923 | else 924 | { 925 | std::sort(rule2score, rule2score + int(rel2rules[r].size())); 926 | 927 | for (int k = 0; k != int(rel2rules[r].size()); k++) 928 | { 929 | if (k == top_k) break; 930 | 931 | index = rule2score[k].id; 932 | rel2rules[r][index].H += 1.0 / top_k / triplet_size; 933 | } 934 | } 935 | } 936 | destset.clear(); 937 | delete [] rule2score; 938 | delete [] index2dest2count; 939 | pthread_exit(NULL); 940 | } 941 | 942 | void *ReasoningPredictor::H_score_thread_caller(void *arg) 943 | { 944 | ReasoningPredictor *ptr = (ReasoningPredictor *)(((ArgStruct *)arg)->ptr); 945 | int thread = ((ArgStruct *)arg)->id; 946 | ptr->H_score_thread(thread); 947 | pthread_exit(NULL); 948 | } 949 | 950 | void ReasoningPredictor::H_score(int _top_k, double _H_temperature, double _prior_weight, double _portion, int _num_threads) 951 | { 952 | top_k = _top_k; 953 
| H_temperature = _H_temperature; 954 | prior_weight = _prior_weight; 955 | portion = _portion; 956 | num_threads = _num_threads; 957 | 958 | total_count = 0; 959 | total_loss = 0; 960 | 961 | pthread_t *pt = (pthread_t *)malloc(num_threads * sizeof(pthread_t)); 962 | for (int k = 0; k != num_threads; k++) pthread_create(&pt[k], NULL, ReasoningPredictor::H_score_thread_caller, new ArgStruct(this, k)); 963 | for (int k = 0; k != num_threads; k++) pthread_join(pt[k], NULL); 964 | printf("Computing H Score | DONE! \n"); 965 | free(pt); 966 | } 967 | 968 | void ReasoningPredictor::evaluate_thread(int thread) 969 | { 970 | std::vector<Triplet> *p_triplets; 971 | if (test) p_triplets = &(p_kg->test_triplets); 972 | else p_triplets = &(p_kg->valid_triplets); 973 | 974 | int triplet_size = int((*p_triplets).size()); 975 | int bg = int(triplet_size / num_threads) * thread; 976 | int ed = int(triplet_size / num_threads) * (thread + 1); 977 | if (thread == num_threads - 1) ed = triplet_size; 978 | 979 | Triplet triplet; 980 | int h, r, t, dest, count, index, num_g, num_ge; 981 | double t_val; 982 | 983 | std::map<int, int> dest2count; 984 | std::map<int, int>::iterator iter_ii; 985 | 986 | RankListEntry *rank_list = new RankListEntry [p_kg->entity_size]; 987 | 988 | for (int T = bg; T != ed; T++) 989 | { 990 | if (T % 10 == 0) 991 | { 992 | total_count += 10; 993 | printf("Evaluation | Progress: %.3lf%% %c", (double)total_count / (double)(triplet_size + 1) * 100, 13); 994 | fflush(stdout); 995 | } 996 | 997 | h = (*p_triplets)[T].h; 998 | r = (*p_triplets)[T].r; 999 | t = (*p_triplets)[T].t; 1000 | 1001 | for (int k = 0; k != p_kg->entity_size; k++) 1002 | { 1003 | rank_list[k].id = k; 1004 | rank_list[k].val = 0; 1005 | } 1006 | 1007 | for (index = 0; index != int(rel2rules[r].size()); index++) 1008 | { 1009 | p_kg->rule_destination(h, rel2rules[r][index], &dest2count, (*p_triplets)[T]); 1010 | 1011 | for (iter_ii = dest2count.begin(); iter_ii != dest2count.end(); iter_ii++) 1012 | { 1013 | dest = 
iter_ii->first; 1014 | count = iter_ii->second; 1015 | 1016 | rank_list[dest].val += rel2rules[r][index].wt.data * count; 1017 | } 1018 | } 1019 | 1020 | t_val = rank_list[t].val; 1021 | 1022 | std::sort(rank_list, rank_list + p_kg->entity_size); 1023 | 1024 | num_g = 0; num_ge = 0; 1025 | triplet = (*p_triplets)[T]; 1026 | for (int k = 0; k != p_kg->entity_size; k++) 1027 | { 1028 | triplet.t = rank_list[k].id; 1029 | if (p_kg->check_true(triplet) == true && rank_list[k].id != t) continue; 1030 | 1031 | if (rank_list[k].val > t_val) num_g += 1; 1032 | if (rank_list[k].val >= t_val) num_ge += 1; 1033 | if (rank_list[k].val < t_val) break; 1034 | } 1035 | 1036 | sem_wait(&mutex); 1037 | ranks.push_back(std::make_pair(num_g, num_ge)); 1038 | sem_post(&mutex); 1039 | } 1040 | dest2count.clear(); 1041 | delete [] rank_list; 1042 | pthread_exit(NULL); 1043 | } 1044 | 1045 | void *ReasoningPredictor::evaluate_thread_caller(void *arg) 1046 | { 1047 | ReasoningPredictor *ptr = (ReasoningPredictor *)(((ArgStruct *)arg)->ptr); 1048 | int thread = ((ArgStruct *)arg)->id; 1049 | ptr->evaluate_thread(thread); 1050 | pthread_exit(NULL); 1051 | } 1052 | 1053 | Result ReasoningPredictor::evaluate(bool _test, int _num_threads) 1054 | { 1055 | test = _test; 1056 | num_threads = _num_threads; 1057 | 1058 | ranks.clear(); 1059 | total_count = 0; 1060 | total_loss = 0; 1061 | sem_init(&mutex, 0, 1); 1062 | 1063 | pthread_t *pt = (pthread_t *)malloc(num_threads * sizeof(pthread_t)); 1064 | for (int k = 0; k != num_threads; k++) pthread_create(&pt[k], NULL, ReasoningPredictor::evaluate_thread_caller, new ArgStruct(this, k)); 1065 | for (int k = 0; k != num_threads; k++) pthread_join(pt[k], NULL); 1066 | if (test == true) printf("Evaluation Test | DONE! \n"); 1067 | else printf("Evaluation Valid | DONE! 
\n"); 1068 | free(pt); 1069 | 1070 | int num_entities = p_kg->entity_size; 1071 | double *table_mr = (double *)calloc(num_entities + 1, sizeof(double)); 1072 | double *table_mrr = (double *)calloc(num_entities + 1, sizeof(double)); 1073 | double *table_hit1 = (double *)calloc(num_entities + 1, sizeof(double)); 1074 | double *table_hit3 = (double *)calloc(num_entities + 1, sizeof(double)); 1075 | double *table_hit10 = (double *)calloc(num_entities + 1, sizeof(double)); 1076 | for (int rank = 1; rank <= num_entities; rank++) 1077 | { 1078 | table_mr[rank] = rank; 1079 | table_mrr[rank] = 1.0 / rank; 1080 | if (rank <= 1) table_hit1[rank] = 1; 1081 | if (rank <= 3) table_hit3[rank] = 1; 1082 | if (rank <= 10) table_hit10[rank] = 1; 1083 | } 1084 | for (int rank = 1; rank <= num_entities; rank++) 1085 | { 1086 | table_mr[rank] += table_mr[rank - 1]; 1087 | table_mrr[rank] += table_mrr[rank - 1]; 1088 | table_hit1[rank] += table_hit1[rank - 1]; 1089 | table_hit3[rank] += table_hit3[rank - 1]; 1090 | table_hit10[rank] += table_hit10[rank - 1]; 1091 | } 1092 | 1093 | double mr = 0, mrr = 0, hit1 = 0, hit3 = 0, hit10 = 0; 1094 | std::vector< std::pair >::iterator iter; 1095 | for (iter = ranks.begin(); iter != ranks.end(); iter++) 1096 | { 1097 | int num_g = iter->first; 1098 | int num_ge = iter->second; 1099 | mr += (table_mr[num_ge] - table_mr[num_g]) / (num_ge - num_g); 1100 | mrr += (table_mrr[num_ge] - table_mrr[num_g]) / (num_ge - num_g); 1101 | hit1 += (table_hit1[num_ge] - table_hit1[num_g]) / (num_ge - num_g); 1102 | hit3 += (table_hit3[num_ge] - table_hit3[num_g]) / (num_ge - num_g); 1103 | hit10 += (table_hit10[num_ge] - table_hit10[num_g]) / (num_ge - num_g); 1104 | } 1105 | 1106 | free(table_mr); 1107 | free(table_mrr); 1108 | free(table_hit1); 1109 | free(table_hit3); 1110 | free(table_hit10); 1111 | 1112 | mr /= ranks.size(); 1113 | mrr /= ranks.size(); 1114 | hit1 /= ranks.size(); 1115 | hit3 /= ranks.size(); 1116 | hit10 /= ranks.size(); 1117 | 1118 | 
Result result(mr, mrr, hit1, hit3, hit10); 1119 | return result; 1120 | } 1121 | 1122 | void ReasoningPredictor::out_train_thread(int thread) 1123 | { 1124 | int triplet_size = p_kg->train_triplet_size; 1125 | int bg = int(triplet_size / num_threads) * thread; 1126 | int ed = bg + int(triplet_size / num_threads * portion); 1127 | if (thread == num_threads - 1 && portion == 1) ed = triplet_size; 1128 | 1129 | Triplet triplet; 1130 | int h, r, t, dest, count, index, valid; 1131 | 1132 | std::map<int, int> dest2count; 1133 | std::map<int, int>::iterator iter_ii; 1134 | std::map<int, std::map<int, int> > dest2index2count; 1135 | std::map<int, std::map<int, int> >::iterator iter_im; 1136 | 1137 | for (int T = bg; T != ed; T++) 1138 | { 1139 | if (T % 10 == 0) 1140 | { 1141 | total_count += 10; 1142 | printf("Generating Data | Progress: %.3lf%% %c", (double)total_count / (double)(triplet_size * portion + 1) * 100, 13); 1143 | fflush(stdout); 1144 | } 1145 | 1146 | h = p_kg->train_triplets[T].h; 1147 | r = p_kg->train_triplets[T].r; 1148 | t = p_kg->train_triplets[T].t; 1149 | 1150 | dest2index2count.clear(); 1151 | for (index = 0; index != int(rel2rules[r].size()); index++) 1152 | { 1153 | p_kg->rule_destination(h, rel2rules[r][index], &dest2count, p_kg->train_triplets[T]); 1154 | 1155 | for (iter_ii = dest2count.begin(); iter_ii != dest2count.end(); iter_ii++) 1156 | { 1157 | dest = iter_ii->first; 1158 | count = iter_ii->second; 1159 | 1160 | if (dest2index2count.count(dest) == 0) dest2index2count[dest] = std::map<int, int>(); 1161 | dest2index2count[dest][index] = count; 1162 | } 1163 | } 1164 | 1165 | for (iter_im = dest2index2count.begin(); iter_im != dest2index2count.end(); iter_im++) 1166 | { 1167 | dest = iter_im->first; 1168 | 1169 | valid = 0; 1170 | triplet = p_kg->train_triplets[T]; 1171 | triplet.t = dest; 1172 | if (p_kg->check_observed(triplet) == true) valid = 1; 1173 | 1174 | thread_data[thread].push_back(h); 1175 | thread_data[thread].push_back(r); 1176 | thread_data[thread].push_back(t); 1177 | 
thread_data[thread].push_back(valid); 1178 | thread_data[thread].push_back(dest); 1179 | thread_data[thread].push_back(int((iter_im->second).size())); 1180 | 1181 | for (iter_ii = (iter_im->second).begin(); iter_ii != (iter_im->second).end(); iter_ii++) 1182 | { 1183 | thread_data[thread].push_back(iter_ii->first); 1184 | } 1185 | 1186 | for (iter_ii = (iter_im->second).begin(); iter_ii != (iter_im->second).end(); iter_ii++) 1187 | { 1188 | thread_data[thread].push_back(iter_ii->second); 1189 | } 1190 | 1191 | thread_split[thread].push_back(int(thread_data[thread].size())); 1192 | } 1193 | } 1194 | dest2count.clear(); 1195 | dest2index2count.clear(); 1196 | pthread_exit(NULL); 1197 | } 1198 | 1199 | void *ReasoningPredictor::out_train_thread_caller(void *arg) 1200 | { 1201 | ReasoningPredictor *ptr = (ReasoningPredictor *)(((ArgStruct *)arg)->ptr); 1202 | int thread = ((ArgStruct *)arg)->id; 1203 | ptr->out_train_thread(thread); 1204 | pthread_exit(NULL); 1205 | } 1206 | 1207 | void ReasoningPredictor::out_train(std::vector<int> *data, double _portion, int _num_threads) 1208 | { 1209 | portion = _portion; 1210 | num_threads = _num_threads; 1211 | 1212 | total_count = 0; 1213 | for (int k = 0; k != num_threads; k++) thread_data[k].clear(); 1214 | for (int k = 0; k != num_threads; k++) thread_split[k].clear(); 1215 | 1216 | std::random_shuffle((p_kg->train_triplets).begin(), (p_kg->train_triplets).end()); 1217 | 1218 | pthread_t *pt = (pthread_t *)malloc(num_threads * sizeof(pthread_t)); 1219 | for (int k = 0; k != num_threads; k++) pthread_create(&pt[k], NULL, ReasoningPredictor::out_train_thread_caller, new ArgStruct(this, k)); 1220 | for (int k = 0; k != num_threads; k++) pthread_join(pt[k], NULL); 1221 | printf("Generating Training Data | DONE! 
\n"); 1222 | free(pt); 1223 | 1224 | (*data).clear(); 1225 | std::vector split; 1226 | split.push_back(0); 1227 | int base = 0; 1228 | for (int k = 0; k != num_threads; k++) 1229 | { 1230 | (*data).insert((*data).end(), thread_data[k].begin(), thread_data[k].end()); 1231 | for (int i = 0; i != int(thread_split[k].size()); i++) split.push_back(base + thread_split[k][i]); 1232 | base += int(thread_data[k].size()); 1233 | 1234 | thread_data[k].clear(); 1235 | thread_split[k].clear(); 1236 | } 1237 | 1238 | int data_length = int((*data).size()); 1239 | for (int k = 0; k != int(split.size()); k++) (*data).push_back(split[k]); 1240 | (*data).push_back(data_length); 1241 | 1242 | /*if (file_name != NULL) 1243 | { 1244 | FILE *fo = fopen(file_name, "wb"); 1245 | Instance instance; 1246 | std::map::iterator iter; 1247 | for (int k = 0; k != int(instances.size()); k++) 1248 | { 1249 | instance = instances[k]; 1250 | fprintf(fo, "%d %d %d %d\n", instance.h, instance.r, instance.t, int(instance.vec_destrule.size())); 1251 | for (int i = 0; i != int(instance.vec_destrule.size()); i++) 1252 | { 1253 | fprintf(fo, "%d %d %d", instance.vec_destrule[i].valid, instance.vec_destrule[i].dest, int(instance.vec_destrule[i].index2count.size())); 1254 | for (iter = instance.vec_destrule[i].index2count.begin(); iter != instance.vec_destrule[i].index2count.end(); iter++) fprintf(fo, " %d:%d", iter->first, iter->second); 1255 | fprintf(fo, "\n"); 1256 | } 1257 | } 1258 | fclose(fo); 1259 | }*/ 1260 | } 1261 | 1262 | void ReasoningPredictor::out_test_thread(int thread) 1263 | { 1264 | std::vector *p_triplets; 1265 | if (test) p_triplets = &(p_kg->test_triplets); 1266 | else p_triplets = &(p_kg->valid_triplets); 1267 | 1268 | int triplet_size = int((*p_triplets).size()); 1269 | int bg = int(triplet_size / num_threads) * thread; 1270 | int ed = int(triplet_size / num_threads) * (thread + 1); 1271 | if (thread == num_threads - 1) ed = triplet_size; 1272 | 1273 | Triplet triplet; 1274 | int h, r, 
t, dest, count, index, valid; 1275 | 1276 | std::map<int, int> dest2count; 1277 | std::map<int, int>::iterator iter_ii; 1278 | std::map<int, std::map<int, int> > dest2index2count; 1279 | std::map<int, std::map<int, int> >::iterator iter_im; 1280 | 1281 | for (int T = bg; T != ed; T++) 1282 | { 1283 | if (T % 10 == 0) 1284 | { 1285 | total_count += 10; 1286 | printf("Generating Data | Progress: %.3lf%% %c", (double)total_count / (double)(triplet_size + 1) * 100, 13); 1287 | fflush(stdout); 1288 | } 1289 | 1290 | h = (*p_triplets)[T].h; 1291 | r = (*p_triplets)[T].r; 1292 | t = (*p_triplets)[T].t; 1293 | 1294 | dest2index2count.clear(); 1295 | for (index = 0; index != int(rel2rules[r].size()); index++) 1296 | { 1297 | p_kg->rule_destination(h, rel2rules[r][index], &dest2count, (*p_triplets)[T]); 1298 | 1299 | for (iter_ii = dest2count.begin(); iter_ii != dest2count.end(); iter_ii++) 1300 | { 1301 | dest = iter_ii->first; 1302 | count = iter_ii->second; 1303 | 1304 | if (dest2index2count.count(dest) == 0) dest2index2count[dest] = std::map<int, int>(); 1305 | dest2index2count[dest][index] = count; 1306 | } 1307 | } 1308 | 1309 | for (iter_im = dest2index2count.begin(); iter_im != dest2index2count.end(); iter_im++) 1310 | { 1311 | dest = iter_im->first; 1312 | 1313 | triplet = (*p_triplets)[T]; 1314 | triplet.t = dest; 1315 | if (p_kg->check_true(triplet) == true && dest != t) continue; 1316 | 1317 | valid = 0; 1318 | triplet = (*p_triplets)[T]; 1319 | triplet.t = dest; 1320 | if (p_kg->check_true(triplet) == true) valid = 1; 1321 | 1322 | thread_data[thread].push_back(h); 1323 | thread_data[thread].push_back(r); 1324 | thread_data[thread].push_back(t); 1325 | thread_data[thread].push_back(valid); 1326 | thread_data[thread].push_back(dest); 1327 | thread_data[thread].push_back(int((iter_im->second).size())); 1328 | 1329 | for (iter_ii = (iter_im->second).begin(); iter_ii != (iter_im->second).end(); iter_ii++) 1330 | { 1331 | thread_data[thread].push_back(iter_ii->first); 1332 | } 1333 | 1334 | for (iter_ii = (iter_im->second).begin(); iter_ii != 
(iter_im->second).end(); iter_ii++) 1335 | { 1336 | thread_data[thread].push_back(iter_ii->second); 1337 | } 1338 | 1339 | thread_split[thread].push_back(int(thread_data[thread].size())); 1340 | } 1341 | } 1342 | dest2count.clear(); 1343 | dest2index2count.clear(); 1344 | pthread_exit(NULL); 1345 | } 1346 | 1347 | void *ReasoningPredictor::out_test_thread_caller(void *arg) 1348 | { 1349 | ReasoningPredictor *ptr = (ReasoningPredictor *)(((ArgStruct *)arg)->ptr); 1350 | int thread = ((ArgStruct *)arg)->id; 1351 | ptr->out_test_thread(thread); 1352 | pthread_exit(NULL); 1353 | } 1354 | 1355 | void ReasoningPredictor::out_test(std::vector<int> *data, bool _test, int _num_threads) 1356 | { 1357 | test = _test; 1358 | num_threads = _num_threads; 1359 | 1360 | total_count = 0; 1361 | for (int k = 0; k != num_threads; k++) thread_data[k].clear(); 1362 | for (int k = 0; k != num_threads; k++) thread_split[k].clear(); 1363 | 1364 | pthread_t *pt = (pthread_t *)malloc(num_threads * sizeof(pthread_t)); 1365 | for (int k = 0; k != num_threads; k++) pthread_create(&pt[k], NULL, ReasoningPredictor::out_test_thread_caller, new ArgStruct(this, k)); 1366 | for (int k = 0; k != num_threads; k++) pthread_join(pt[k], NULL); 1367 | if (test == true) printf("Generating Test Data | DONE! \n"); 1368 | else printf("Generating Validation Data | DONE! 
\n"); 1369 | free(pt); 1370 | 1371 | (*data).clear(); 1372 | std::vector split; 1373 | split.push_back(0); 1374 | int base = 0; 1375 | for (int k = 0; k != num_threads; k++) 1376 | { 1377 | (*data).insert((*data).end(), thread_data[k].begin(), thread_data[k].end()); 1378 | for (int i = 0; i != int(thread_split[k].size()); i++) split.push_back(base + thread_split[k][i]); 1379 | base += int(thread_data[k].size()); 1380 | 1381 | thread_data[k].clear(); 1382 | thread_split[k].clear(); 1383 | } 1384 | 1385 | int data_length = int((*data).size()); 1386 | for (int k = 0; k != int(split.size()); k++) (*data).push_back(split[k]); 1387 | (*data).push_back(data_length); 1388 | 1389 | /*if (file_name != NULL) 1390 | { 1391 | FILE *fo = fopen(file_name, "wb"); 1392 | Instance instance; 1393 | std::map::iterator iter; 1394 | for (int k = 0; k != int(instances.size()); k++) 1395 | { 1396 | instance = instances[k]; 1397 | fprintf(fo, "%d %d %d %d\n", instance.h, instance.r, instance.t, int(instance.vec_destrule.size())); 1398 | for (int i = 0; i != int(instance.vec_destrule.size()); i++) 1399 | { 1400 | fprintf(fo, "%d %d %d", instance.vec_destrule[i].valid, instance.vec_destrule[i].dest, int(instance.vec_destrule[i].index2count.size())); 1401 | for (iter = instance.vec_destrule[i].index2count.begin(); iter != instance.vec_destrule[i].index2count.end(); iter++) fprintf(fo, " %d:%d", iter->first, iter->second); 1402 | fprintf(fo, "\n"); 1403 | } 1404 | } 1405 | fclose(fo); 1406 | }*/ 1407 | } 1408 | 1409 | void ReasoningPredictor::in_rules(char *file_name) 1410 | { 1411 | if (rel2rules != NULL) 1412 | { 1413 | for (int r = 0; r != p_kg->relation_size; r++) rel2rules[r].clear(); 1414 | delete [] rel2rules; 1415 | } 1416 | 1417 | rel2rules = new std::vector [p_kg->relation_size]; 1418 | 1419 | FILE *fi = fopen(file_name, "rb"); 1420 | 1421 | int r, r_head, n, id, type, r_body; 1422 | Rule rule; 1423 | while (1) 1424 | { 1425 | if (fscanf(fi, "%d %d", &r, &n) != 2) break; 1426 | 1427 
| for (int k = 0; k != n; k++) 1428 | { 1429 | rule.clear(); 1430 | fscanf(fi, "%d %d %d", &id, &r_head, &type); 1431 | rule.r_head = r_head; 1432 | rule.type = type; 1433 | for (int i = 0; i != type; i++) 1434 | { 1435 | fscanf(fi, "%d", &r_body); 1436 | rule.r_body.push_back(r_body); 1437 | } 1438 | rel2rules[r].push_back(rule); 1439 | } 1440 | } 1441 | } 1442 | 1443 | void ReasoningPredictor::out_rules(char *file_name) 1444 | { 1445 | FILE *fo = fopen(file_name, "wb"); 1446 | for (int r = 0; r != p_kg->relation_size; r++) 1447 | { 1448 | fprintf(fo, "%d %d\n", r, int(rel2rules[r].size())); 1449 | for (int k = 0; k != int(rel2rules[r].size()); k++) 1450 | { 1451 | fprintf(fo, "%d %d %d", k, rel2rules[r][k].r_head, rel2rules[r][k].type); 1452 | for (int i = 0; i != int(rel2rules[r][k].r_body.size()); i++) fprintf(fo, " %d", rel2rules[r][k].r_body[i]); 1453 | fprintf(fo, "\n"); 1454 | } 1455 | } 1456 | fclose(fo); 1457 | } 1458 | 1459 | void ReasoningPredictor::out_train_single(int h, int r, int t, std::vector<int> *data) 1460 | { 1461 | int dest, count, index, valid; 1462 | 1463 | std::vector<int> split; 1464 | split.push_back(0); 1465 | 1466 | std::map<int, int> dest2count; 1467 | std::map<int, int>::iterator iter_ii; 1468 | std::map<int, std::map<int, int> > dest2index2count; 1469 | std::map<int, std::map<int, int> >::iterator iter_im; 1470 | 1471 | Triplet triplet; 1472 | triplet.h = h; triplet.r = r; triplet.t = t; 1473 | 1474 | dest2index2count.clear(); 1475 | for (index = 0; index != int(rel2rules[r].size()); index++) 1476 | { 1477 | p_kg->rule_destination(h, rel2rules[r][index], &dest2count, triplet); 1478 | 1479 | for (iter_ii = dest2count.begin(); iter_ii != dest2count.end(); iter_ii++) 1480 | { 1481 | dest = iter_ii->first; 1482 | count = iter_ii->second; 1483 | 1484 | if (dest2index2count.count(dest) == 0) dest2index2count[dest] = std::map<int, int>(); 1485 | dest2index2count[dest][index] = count; 1486 | } 1487 | } 1488 | 1489 | for (iter_im = dest2index2count.begin(); iter_im != dest2index2count.end(); iter_im++) 1490 | { 1491 | dest =
iter_im->first; 1492 | 1493 | valid = 0; 1494 | triplet.h = h; 1495 | triplet.r = r; 1496 | triplet.t = dest; 1497 | if (p_kg->check_observed(triplet) == true) valid = 1; 1498 | 1499 | (*data).push_back(h); 1500 | (*data).push_back(r); 1501 | (*data).push_back(t); 1502 | (*data).push_back(valid); 1503 | (*data).push_back(dest); 1504 | (*data).push_back(int((iter_im->second).size())); 1505 | 1506 | for (iter_ii = (iter_im->second).begin(); iter_ii != (iter_im->second).end(); iter_ii++) 1507 | { 1508 | (*data).push_back(iter_ii->first); 1509 | } 1510 | 1511 | for (iter_ii = (iter_im->second).begin(); iter_ii != (iter_im->second).end(); iter_ii++) 1512 | { 1513 | (*data).push_back(iter_ii->second); 1514 | } 1515 | 1516 | split.push_back(int((*data).size())); 1517 | } 1518 | 1519 | int data_length = int((*data).size()); 1520 | for (int k = 0; k != int(split.size()); k++) (*data).push_back(split[k]); 1521 | (*data).push_back(data_length); 1522 | 1523 | dest2count.clear(); 1524 | dest2index2count.clear(); 1525 | } 1526 | 1527 | void ReasoningPredictor::out_test_single(int h, int r, int t, std::vector<int> *data) 1528 | { 1529 | int dest, count, index, valid; 1530 | 1531 | std::vector<int> split; 1532 | split.push_back(0); 1533 | 1534 | std::map<int, int> dest2count; 1535 | std::map<int, int>::iterator iter_ii; 1536 | std::map<int, std::map<int, int> > dest2index2count; 1537 | std::map<int, std::map<int, int> >::iterator iter_im; 1538 | 1539 | Triplet triplet; 1540 | triplet.h = h; triplet.r = r; triplet.t = t; 1541 | 1542 | dest2index2count.clear(); 1543 | for (index = 0; index != int(rel2rules[r].size()); index++) 1544 | { 1545 | p_kg->rule_destination(h, rel2rules[r][index], &dest2count, triplet); 1546 | 1547 | for (iter_ii = dest2count.begin(); iter_ii != dest2count.end(); iter_ii++) 1548 | { 1549 | dest = iter_ii->first; 1550 | count = iter_ii->second; 1551 | 1552 | if (dest2index2count.count(dest) == 0) dest2index2count[dest] = std::map<int, int>(); 1553 | dest2index2count[dest][index] = count; 1554 | } 1555 | } 1556 | 1557 | for (iter_im =
dest2index2count.begin(); iter_im != dest2index2count.end(); iter_im++) 1558 | { 1559 | dest = iter_im->first; 1560 | 1561 | triplet.h = h; 1562 | triplet.r = r; 1563 | triplet.t = dest; 1564 | if (p_kg->check_true(triplet) == true && dest != t) continue; 1565 | 1566 | valid = 0; 1567 | if (p_kg->check_true(triplet) == true) valid = 1; 1568 | 1569 | (*data).push_back(h); 1570 | (*data).push_back(r); 1571 | (*data).push_back(t); 1572 | (*data).push_back(valid); 1573 | (*data).push_back(dest); 1574 | (*data).push_back(int((iter_im->second).size())); 1575 | 1576 | for (iter_ii = (iter_im->second).begin(); iter_ii != (iter_im->second).end(); iter_ii++) 1577 | { 1578 | (*data).push_back(iter_ii->first); 1579 | } 1580 | 1581 | for (iter_ii = (iter_im->second).begin(); iter_ii != (iter_im->second).end(); iter_ii++) 1582 | { 1583 | (*data).push_back(iter_ii->second); 1584 | } 1585 | 1586 | split.push_back(int((*data).size())); 1587 | } 1588 | 1589 | int data_length = int((*data).size()); 1590 | for (int k = 0; k != int(split.size()); k++) (*data).push_back(split[k]); 1591 | (*data).push_back(data_length); 1592 | 1593 | dest2count.clear(); 1594 | dest2index2count.clear(); 1595 | } 1596 | 1597 | void ReasoningPredictor::out_test_count_thread(int thread) 1598 | { 1599 | std::vector<Triplet> *p_triplets; 1600 | if (test) p_triplets = &(p_kg->test_triplets); 1601 | else p_triplets = &(p_kg->valid_triplets); 1602 | 1603 | int triplet_size = int((*p_triplets).size()); 1604 | int bg = int(triplet_size / num_threads) * thread; 1605 | int ed = int(triplet_size / num_threads) * (thread + 1); 1606 | if (thread == num_threads - 1) ed = triplet_size; 1607 | 1608 | Triplet triplet; 1609 | int h, r, t, dest, count, index; 1610 | 1611 | std::map<int, int> dest2count; 1612 | std::map<int, int>::iterator iter_ii; 1613 | std::set<int> destset; 1614 | std::set<int>::iterator iter_i; 1615 | 1616 | for (int T = bg; T != ed; T++) 1617 | { 1618 | if (T % 10 == 0) 1619 | { 1620 | total_count += 10; 1621 | printf("Generating Count | 
Progress: %.3lf%% %c", (double)total_count / (double)(triplet_size + 1) * 100, 13); 1622 | fflush(stdout); 1623 | } 1624 | 1625 | h = (*p_triplets)[T].h; 1626 | r = (*p_triplets)[T].r; 1627 | t = (*p_triplets)[T].t; 1628 | 1629 | destset.clear(); 1630 | for (index = 0; index != int(rel2rules[r].size()); index++) 1631 | { 1632 | p_kg->rule_destination(h, rel2rules[r][index], &dest2count, (*p_triplets)[T]); 1633 | 1634 | for (iter_ii = dest2count.begin(); iter_ii != dest2count.end(); iter_ii++) 1635 | { 1636 | dest = iter_ii->first; 1637 | 1638 | destset.insert(dest); 1639 | } 1640 | } 1641 | 1642 | count = 0; 1643 | for (iter_i = destset.begin(); iter_i != destset.end(); iter_i++) 1644 | { 1645 | dest = *iter_i; 1646 | 1647 | triplet = (*p_triplets)[T]; 1648 | triplet.t = dest; 1649 | if (p_kg->check_true(triplet) == true && dest != t) continue; 1650 | 1651 | count += 1; 1652 | } 1653 | 1654 | thread_data[thread].push_back(count); 1655 | } 1656 | destset.clear(); 1657 | dest2count.clear(); 1658 | pthread_exit(NULL); 1659 | } 1660 | 1661 | void *ReasoningPredictor::out_test_count_thread_caller(void *arg) 1662 | { 1663 | ReasoningPredictor *ptr = (ReasoningPredictor *)(((ArgStruct *)arg)->ptr); 1664 | int thread = ((ArgStruct *)arg)->id; 1665 | ptr->out_test_count_thread(thread); 1666 | pthread_exit(NULL); 1667 | } 1668 | 1669 | void ReasoningPredictor::out_test_count(std::vector<int> *data, bool _test, int _num_threads) 1670 | { 1671 | test = _test; 1672 | num_threads = _num_threads; 1673 | 1674 | total_count = 0; 1675 | for (int k = 0; k != num_threads; k++) thread_data[k].clear(); 1676 | 1677 | pthread_t *pt = (pthread_t *)malloc(num_threads * sizeof(pthread_t)); 1678 | for (int k = 0; k != num_threads; k++) pthread_create(&pt[k], NULL, ReasoningPredictor::out_test_count_thread_caller, new ArgStruct(this, k)); 1679 | for (int k = 0; k != num_threads; k++) pthread_join(pt[k], NULL); 1680 | if (test == true) printf("Generating Test Data | DONE! 
\n"); 1681 | else printf("Generating Validation Data | DONE! \n"); 1682 | free(pt); 1683 | 1684 | (*data).clear(); 1685 | for (int k = 0; k != num_threads; k++) 1686 | { 1687 | (*data).insert((*data).end(), thread_data[k].begin(), thread_data[k].end()); 1688 | 1689 | thread_data[k].clear(); 1690 | } 1691 | } 1692 | 1693 | /***************************** 1694 | RuleGenerator 1695 | *****************************/ 1696 | 1697 | RuleGenerator::RuleGenerator() 1698 | { 1699 | rel2rules = NULL; 1700 | rel2pool = NULL; 1701 | mapping = NULL; 1702 | p_kg = NULL; 1703 | } 1704 | 1705 | RuleGenerator::~RuleGenerator() 1706 | { 1707 | if (rel2rules != NULL) 1708 | { 1709 | for (int r = 0; r != p_kg->relation_size; r++) rel2rules[r].clear(); 1710 | delete [] rel2rules; 1711 | rel2rules = NULL; 1712 | } 1713 | if (rel2pool != NULL) 1714 | { 1715 | for (int r = 0; r != p_kg->relation_size; r++) rel2pool[r].clear(); 1716 | delete [] rel2pool; 1717 | rel2pool = NULL; 1718 | } 1719 | if (mapping != NULL) 1720 | { 1721 | for (int r = 0; r != p_kg->relation_size; r++) mapping[r].clear(); 1722 | delete [] mapping; 1723 | mapping = NULL; 1724 | } 1725 | p_kg = NULL; 1726 | } 1727 | 1728 | void RuleGenerator::init_knowledge_graph(KnowledgeGraph *_p_kg) 1729 | { 1730 | p_kg = _p_kg; 1731 | } 1732 | 1733 | void RuleGenerator::set_logic_rules(std::vector<Rule> * _rel2rules) 1734 | { 1735 | if (rel2rules != NULL) 1736 | { 1737 | for (int r = 0; r != p_kg->relation_size; r++) rel2rules[r].clear(); 1738 | delete [] rel2rules; 1739 | } 1740 | 1741 | rel2rules = new std::vector<Rule> [p_kg->relation_size]; 1742 | for (int r = 0; r != p_kg->relation_size; r++) 1743 | { 1744 | rel2rules[r] = _rel2rules[r]; 1745 | for (int k = 0; k != int(rel2rules[r].size()); k++) 1746 | { 1747 | rel2rules[r][k].wt.clear(); 1748 | rel2rules[r][k].H = 0; 1749 | } 1750 | } 1751 | } 1752 | 1753 | std::vector<Rule> *RuleGenerator::get_logic_rules() 1754 | { 1755 | return rel2rules; 1756 | } 1757 | 1758 | void 
RuleGenerator::set_pool(std::vector<Rule> * _rel2rules) 1759 | { 1760 | if (rel2pool != NULL) 1761 | { 1762 | for (int r = 0; r != p_kg->relation_size; r++) rel2pool[r].clear(); 1763 | delete [] rel2pool; 1764 | } 1765 | 1766 | rel2pool = new std::vector<Rule> [p_kg->relation_size]; 1767 | for (int r = 0; r != p_kg->relation_size; r++) 1768 | { 1769 | rel2pool[r] = _rel2rules[r]; 1770 | for (int k = 0; k != int(rel2pool[r].size()); k++) 1771 | { 1772 | rel2pool[r][k].wt.clear(); 1773 | rel2pool[r][k].H = 0; 1774 | rel2pool[r][k].cn = 0; 1775 | } 1776 | } 1777 | } 1778 | 1779 | void RuleGenerator::sample_from_pool(int _number, double _temperature) 1780 | { 1781 | if (rel2rules != NULL) 1782 | { 1783 | for (int r = 0; r != p_kg->relation_size; r++) rel2rules[r].clear(); 1784 | delete [] rel2rules; 1785 | } 1786 | if (mapping != NULL) 1787 | { 1788 | for (int r = 0; r != p_kg->relation_size; r++) mapping[r].clear(); 1789 | delete [] mapping; 1790 | } 1791 | 1792 | rel2rules = new std::vector<Rule> [p_kg->relation_size]; 1793 | mapping = new std::vector<int> [p_kg->relation_size]; 1794 | for (int r = 0; r != p_kg->relation_size; r++) 1795 | { 1796 | std::vector<double> probability; 1797 | double max_val = -1000000, sum_val = 0; 1798 | for (int k = 0; k != int(rel2pool[r].size()); k++) 1799 | max_val = std::max(max_val, rel2pool[r][k].H); 1800 | for (int k = 0; k != int(rel2pool[r].size()); k++) 1801 | sum_val += exp((rel2pool[r][k].H - max_val) / _temperature); 1802 | for (int k = 0; k != int(rel2pool[r].size()); k++) 1803 | probability.push_back(exp((rel2pool[r][k].H - max_val) / _temperature) / sum_val); 1804 | 1805 | for (int k = 0; k != _number; k++) 1806 | { 1807 | double sum_prob = 0, rand_val = double(rand()) / double(RAND_MAX); 1808 | for (int index = 0; index != int(rel2pool[r].size()); index++) 1809 | { 1810 | sum_prob += probability[index]; 1811 | if (sum_prob > rand_val) 1812 | { 1813 | rel2rules[r].push_back(rel2pool[r][index]); 1814 | mapping[r].push_back(index); 1815 | break; 1816 | } 
1817 | } 1818 | } 1819 | } 1820 | } 1821 | 1822 | void RuleGenerator::random_from_pool(int _number) 1823 | { 1824 | if (rel2rules != NULL) 1825 | { 1826 | for (int r = 0; r != p_kg->relation_size; r++) rel2rules[r].clear(); 1827 | delete [] rel2rules; 1828 | } 1829 | if (mapping != NULL) 1830 | { 1831 | for (int r = 0; r != p_kg->relation_size; r++) mapping[r].clear(); 1832 | delete [] mapping; 1833 | } 1834 | 1835 | rel2rules = new std::vector<Rule> [p_kg->relation_size]; 1836 | mapping = new std::vector<int> [p_kg->relation_size]; 1837 | std::vector<int> rand_index; 1838 | for (int r = 0; r != p_kg->relation_size; r++) 1839 | { 1840 | rand_index.clear(); 1841 | for (int k = 0; k != int(rel2pool[r].size()); k++) rand_index.push_back(k); 1842 | std::random_shuffle(rand_index.begin(), rand_index.end()); 1843 | for (int k = 0; k != int(rel2pool[r].size()); k++) 1844 | { 1845 | if (k >= _number) break; 1846 | int index = rand_index[k]; 1847 | rel2rules[r].push_back(rel2pool[r][index]); 1848 | mapping[r].push_back(index); 1849 | } 1850 | } 1851 | } 1852 | 1853 | void RuleGenerator::best_from_pool(int _number) 1854 | { 1855 | if (rel2rules != NULL) 1856 | { 1857 | for (int r = 0; r != p_kg->relation_size; r++) rel2rules[r].clear(); 1858 | delete [] rel2rules; 1859 | } 1860 | if (mapping != NULL) 1861 | { 1862 | for (int r = 0; r != p_kg->relation_size; r++) mapping[r].clear(); 1863 | delete [] mapping; 1864 | } 1865 | 1866 | rel2rules = new std::vector<Rule> [p_kg->relation_size]; 1867 | mapping = new std::vector<int> [p_kg->relation_size]; 1868 | std::vector<RankListEntry> rank_list; 1869 | RankListEntry entry; 1870 | for (int r = 0; r != p_kg->relation_size; r++) 1871 | { 1872 | rank_list.clear(); 1873 | for (int k = 0; k != int(rel2pool[r].size()); k++) 1874 | { 1875 | entry.id = k; 1876 | entry.val = rel2pool[r][k].H; 1877 | rank_list.push_back(entry); 1878 | } 1879 | std::sort(rank_list.begin(), rank_list.end()); 1880 | for (int k = 0; k != int(rel2pool[r].size()); k++) 1881 | { 1882 | if (k >= _number) 
break; 1883 | int index = rank_list[k].id; 1884 | rel2rules[r].push_back(rel2pool[r][index]); 1885 | mapping[r].push_back(index); 1886 | } 1887 | } 1888 | } 1889 | 1890 | void RuleGenerator::update(std::vector<Rule> * _rel2rules) 1891 | { 1892 | for (int r = 0; r != p_kg->relation_size; r++) 1893 | { 1894 | for (int k = 0; k != int(rel2rules[r].size()); k++) rel2rules[r][k].H = _rel2rules[r][k].H; 1895 | for (int k = 0; k != int(rel2rules[r].size()); k++) 1896 | { 1897 | int index = mapping[r][k]; 1898 | rel2pool[r][index].H = (rel2pool[r][index].H * rel2pool[r][index].cn + rel2rules[r][k].H) / (rel2pool[r][index].cn + 1); 1899 | rel2pool[r][index].cn += 1; 1900 | } 1901 | } 1902 | } 1903 | 1904 | void RuleGenerator::out_rules(char *file_name, int num_rules) 1905 | { 1906 | FILE *fo = fopen(file_name, "wb"); 1907 | std::vector<RankListEntry> rank_list; 1908 | RankListEntry entry; 1909 | for (int r = 0; r != p_kg->relation_size; r++) 1910 | { 1911 | rank_list.clear(); 1912 | for (int k = 0; k != int(rel2pool[r].size()); k++) 1913 | { 1914 | entry.id = k; 1915 | entry.val = rel2pool[r][k].H; 1916 | rank_list.push_back(entry); 1917 | } 1918 | std::sort(rank_list.begin(), rank_list.end()); 1919 | 1920 | int actual_rules = int(rel2pool[r].size()); 1921 | if (num_rules < actual_rules) actual_rules = num_rules; 1922 | 1923 | //fprintf(fo, "%d %d\n", r, actual_rules); 1924 | for (int k = 0; k != int(rel2pool[r].size()); k++) 1925 | { 1926 | if (k == num_rules) break; 1927 | int index = rank_list[k].id; 1928 | //fprintf(fo, "%d %d %d", k, rel2pool[r][index].r_head, rel2pool[r][index].type); 1929 | fprintf(fo, "%d", rel2pool[r][index].r_head); 1930 | for (int i = 0; i != int(rel2pool[r][index].r_body.size()); i++) fprintf(fo, " %d", rel2pool[r][index].r_body[i]); 1931 | fprintf(fo, " %.16lf", rel2pool[r][index].H); 1932 | fprintf(fo, "\n"); 1933 | } 1934 | } 1935 | fclose(fo); 1936 | } 1937 | -------------------------------------------------------------------------------- /miner/rnnlogic.h: 
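For reference, `RuleGenerator::out_rules` above writes one rule per line: the head relation id, the body relation ids, and the rule's averaged H score, all separated by spaces. A minimal Python sketch of a reader for that format follows; the function name `load_rules` is ours, not part of the repo, and it assumes the plain space-separated layout produced by the `fprintf` calls above:

```python
def load_rules(path):
    """Parse rules written by RuleGenerator::out_rules.

    Each non-empty line has the form:
        r_head r_body_1 ... r_body_k H
    where the ids are ints and H is the rule's score.
    Returns a list of (r_head, [r_body...], H) tuples.
    """
    rules = []
    with open(path) as fi:
        for line in fi:
            fields = line.split()
            if not fields:
                continue
            r_head = int(fields[0])
            r_body = [int(x) for x in fields[1:-1]]
            h_score = float(fields[-1])
            rules.append((r_head, r_body, h_score))
    return rules
```

This is the same layout as the provided `rnnlogic_rules.txt` files, so a reader like this is enough to inspect or re-rank mined rules outside the C++ miner.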
-------------------------------------------------------------------------------- 1 | #include <cstdio> 2 | #include <cstdlib> 3 | #include <cstring> 4 | #include <cmath> 5 | #include <ctime> 6 | #include <pthread.h> 7 | #include <semaphore.h> 8 | #include <string> 9 | #include <vector> 10 | #include <map> 11 | #include <set> 12 | #include <algorithm> 13 | 14 | #define MAX_STRING 1000 15 | #define MAX_THREADS 100 16 | #define MAX_LENGTH 100 17 | 18 | double sigmoid(double x); 19 | double abs_val(double x); 20 | 21 | struct ArgStruct 22 | { 23 | void *ptr; 24 | int id; 25 | 26 | ArgStruct(void *_ptr, int _id); 27 | }; 28 | 29 | struct Triplet 30 | { 31 | int h, t, r; 32 | 33 | friend bool operator < (Triplet u, Triplet v); 34 | friend bool operator == (Triplet u, Triplet v); 35 | }; 36 | 37 | struct RankListEntry 38 | { 39 | int id; 40 | double val; 41 | 42 | friend bool operator < (RankListEntry u, RankListEntry v); 43 | }; 44 | 45 | struct Parameter 46 | { 47 | double data, m, v, t; 48 | 49 | Parameter(); 50 | 51 | void clear(); 52 | void update(double grad, double learning_rate, double weight_decay=0); 53 | }; 54 | 55 | struct Rule 56 | { 57 | std::vector<int> r_body; 58 | int r_head; 59 | int type; 60 | double H, cn, prior; 61 | Parameter wt; 62 | 63 | Rule(); 64 | ~Rule(); 65 | 66 | void clear(); 67 | friend bool operator < (Rule u, Rule v); 68 | friend bool operator == (Rule u, Rule v); 69 | }; 70 | 71 | struct Result 72 | { 73 | double h1, h3, h10, mr, mrr; 74 | 75 | Result(); 76 | Result(double mr_, double mrr_, double h1_, double h3_, double h10_); 77 | }; 78 | 79 | struct DestRule 80 | { 81 | int dest, valid; 82 | std::map<int, int> index2count; 83 | 84 | void clear() 85 | { 86 | dest = -1; 87 | valid = -1; 88 | index2count.clear(); 89 | } 90 | }; 91 | 92 | struct Instance 93 | { 94 | int h, r, t; 95 | std::vector<DestRule> vec_destrule; 96 | 97 | void clear() 98 | { 99 | h = -1; 100 | r = -1; 101 | t = -1; 102 | vec_destrule.clear(); 103 | } 104 | }; 105 | 106 | class KnowledgeGraph 107 | { 108 | protected: 109 | int entity_size, relation_size, train_triplet_size, valid_triplet_size, 
test_triplet_size, all_triplet_size; 110 | std::map<std::string, int> ent2id, rel2id; 111 | std::map<int, std::string> id2ent, id2rel; 112 | std::vector<Triplet> train_triplets, valid_triplets, test_triplets; 113 | std::vector<int> **e2r2n; 114 | std::set<Triplet> set_train_triplets, set_all_triplets; 115 | 116 | public: 117 | friend class RuleMiner; 118 | friend class ReasoningPredictor; 119 | friend class RuleGenerator; 120 | 121 | KnowledgeGraph(); 122 | ~KnowledgeGraph(); 123 | 124 | int get_entity_size(); 125 | int get_relation_size(); 126 | int get_train_size(); 127 | int get_valid_size(); 128 | int get_test_size(); 129 | 130 | void read_data(char *data_path); 131 | bool check_observed(Triplet triplet); 132 | bool check_true(Triplet triplet); 133 | void rule_search(int r, int e, int goal, int *path, int depth, int max_depth, std::set<Rule> *rule_set, Triplet removed_triplet); 134 | void rule_destination(int e, Rule rule, std::map<int, int> *dest2count, Triplet removed_triplet); 135 | }; 136 | 137 | class RuleMiner 138 | { 139 | protected: 140 | KnowledgeGraph *p_kg; 141 | int num_threads, max_length; 142 | double portion; 143 | long long total_count; 144 | std::vector<Rule> *rel2rules; 145 | std::set<Rule> *rel2ruleset; 146 | sem_t mutex; 147 | 148 | public: 149 | RuleMiner(); 150 | ~RuleMiner(); 151 | 152 | void init_knowledge_graph(KnowledgeGraph *_p_kg); 153 | void clear(); 154 | std::vector<Rule> *get_logic_rules(); 155 | int get_relation_size(); 156 | 157 | void search_thread(int thread); 158 | static void *search_thread_caller(void *arg); 159 | void search(int _max_length, double _portion, int _num_threads); 160 | 161 | void save(char *file_name); 162 | void load(char *file_name); 163 | }; 164 | 165 | class ReasoningPredictor 166 | { 167 | protected: 168 | KnowledgeGraph *p_kg; 169 | std::vector<Rule> *rel2rules; 170 | int num_threads, top_k; 171 | double temperature, learning_rate, weight_decay; 172 | double portion; 173 | double prior_weight, H_temperature; 174 | bool test, fast; 175 | long long total_count; 176 | double total_loss; 177 | 
std::vector< std::pair<int, int> > ranks; 178 | sem_t mutex; 179 | 180 | std::vector<int> thread_data[MAX_THREADS], thread_split[MAX_THREADS]; 181 | 182 | public: 183 | ReasoningPredictor(); 184 | ~ReasoningPredictor(); 185 | 186 | void init_knowledge_graph(KnowledgeGraph *_p_kg); 187 | void set_logic_rules(std::vector<Rule> * _rel2rules); 188 | std::vector<Rule> *get_logic_rules(); 189 | int get_relation_size(); 190 | 191 | void learn_thread(int thread); 192 | static void *learn_thread_caller(void *arg); 193 | void learn(double _learning_rate, double _weight_decay, double _temperature, bool _fast, double _portion, int _num_threads); 194 | 195 | void H_score_thread(int thread); 196 | static void *H_score_thread_caller(void *arg); 197 | void H_score(int _top_k, double _H_temperature, double _prior_weight, double _portion, int _num_threads); 198 | 199 | void evaluate_thread(int thread); 200 | static void *evaluate_thread_caller(void *arg); 201 | Result evaluate(bool _test, int _num_threads); 202 | 203 | void out_train_thread(int thread); 204 | static void *out_train_thread_caller(void *arg); 205 | void out_train(std::vector<int> *data, double _portion, int _num_threads); 206 | 207 | void out_test_thread(int thread); 208 | static void *out_test_thread_caller(void *arg); 209 | void out_test(std::vector<int> *data, bool _test, int _num_threads); 210 | 211 | void in_rules(char *file_name); 212 | void out_rules(char *file_name); 213 | 214 | void out_train_single(int h, int r, int t, std::vector<int> *data); 215 | void out_test_single(int h, int r, int t, std::vector<int> *data); 216 | 217 | void out_test_count_thread(int thread); 218 | static void *out_test_count_thread_caller(void *arg); 219 | void out_test_count(std::vector<int> *data, bool _test, int _num_threads); 220 | }; 221 | 222 | class RuleGenerator 223 | { 224 | protected: 225 | KnowledgeGraph *p_kg; 226 | std::vector<Rule> *rel2rules, *rel2pool; 227 | std::vector<int> *mapping; 228 | public: 229 | RuleGenerator(); 230 | ~RuleGenerator(); 231 | 232 | void 
init_knowledge_graph(KnowledgeGraph *_p_kg); 233 | void set_logic_rules(std::vector<Rule> * _rel2rules); 234 | std::vector<Rule> *get_logic_rules(); 235 | 236 | void set_pool(std::vector<Rule> * _rel2rules); 237 | void sample_from_pool(int _number, double _temperature=1); 238 | void random_from_pool(int _number); 239 | void best_from_pool(int _number); 240 | void update(std::vector<Rule> * _rel2rules); 241 | void out_rules(char *file_name, int num_rules); 242 | }; 243 | -------------------------------------------------------------------------------- /miner/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils import cpp_extension as cpp 3 | 4 | setup( 5 | name='rnnlogic_ext', 6 | ext_modules=[ 7 | cpp.CppExtension('rnnlogic_ext', ["rnnlogic.cpp", "pyrnnlogic.cpp"], 8 | extra_compile_args={"cxx": ["-O3"]}) 9 | ], 10 | cmdclass={"build_ext": cpp.BuildExtension} 11 | ) 12 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | torchaudio 4 | numpy 5 | easydict -------------------------------------------------------------------------------- /src/comm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import multiprocessing 3 | from collections import defaultdict 4 | 5 | import torch 6 | from torch import distributed as dist 7 | 8 | 9 | cpu_group = None 10 | gpu_group = None 11 | 12 | 13 | def get_rank(): 14 | """ 15 | Get the rank of this process in distributed processes. 16 | Return 0 for single process case. 17 | """ 18 | if dist.is_initialized(): 19 | return dist.get_rank() 20 | if "RANK" in os.environ: 21 | return int(os.environ["RANK"]) 22 | return 0 23 | 24 | 25 | def get_world_size(): 26 | """ 27 | Get the total number of distributed processes. 28 | Return 1 for single process case. 
29 | """ 30 | if dist.is_initialized(): 31 | return dist.get_world_size() 32 | if "WORLD_SIZE" in os.environ: 33 | return int(os.environ["WORLD_SIZE"]) 34 | return 1 35 | 36 | 37 | def get_group(device): 38 | """ 39 | Get the process group corresponding to the given device. 40 | Parameters: 41 | device (torch.device): query device 42 | """ 43 | group = cpu_group if device.type == "cpu" else gpu_group 44 | if group is None: 45 | raise ValueError("%s group is not initialized. Use comm.init_process_group() to initialize it" 46 | % device.type.upper()) 47 | return group 48 | 49 | 50 | def init_process_group(backend, init_method=None, **kwargs): 51 | """ 52 | Initialize CPU and/or GPU process groups. 53 | Parameters: 54 | backend (str): Communication backend. Use ``nccl`` for GPUs and ``gloo`` for CPUs. 55 | init_method (str, optional): URL specifying how to initialize the process group 56 | """ 57 | global cpu_group 58 | global gpu_group 59 | 60 | dist.init_process_group(backend, init_method, **kwargs) 61 | gpu_group = dist.group.WORLD 62 | if backend == "nccl": 63 | cpu_group = dist.new_group(backend="gloo") 64 | else: 65 | cpu_group = gpu_group 66 | 67 | 68 | def get_cpu_count(): 69 | """ 70 | Get the number of CPUs on this node. 71 | """ 72 | return multiprocessing.cpu_count() 73 | 74 | 75 | def synchronize(): 76 | """ 77 | Synchronize among all distributed processes. 
78 | """ 79 | if get_world_size() > 1: 80 | dist.barrier() 81 | 82 | 83 | def _recursive_read(obj): 84 | values = defaultdict(list) 85 | sizes = defaultdict(list) 86 | if isinstance(obj, torch.Tensor): 87 | values[obj.dtype] += [obj.flatten()] 88 | sizes[obj.dtype] += [torch.tensor([obj.numel()], device=obj.device)] 89 | elif isinstance(obj, dict): 90 | for v in obj.values(): 91 | child_values, child_sizes = _recursive_read(v) 92 | for k, v in child_values.items(): 93 | values[k] += v 94 | for k, v in child_sizes.items(): 95 | sizes[k] += v 96 | elif isinstance(obj, list) or isinstance(obj, tuple): 97 | for v in obj: 98 | child_values, child_sizes = _recursive_read(v) 99 | for k, v in child_values.items(): 100 | values[k] += v 101 | for k, v in child_sizes.items(): 102 | sizes[k] += v 103 | else: 104 | raise ValueError("Unknown type `%s`" % type(obj)) 105 | return values, sizes 106 | 107 | 108 | def _recursive_write(obj, values, sizes=None): 109 | if isinstance(obj, torch.Tensor): 110 | if sizes is None: 111 | size = torch.tensor([obj.numel()], device=obj.device) 112 | else: 113 | s = sizes[obj.dtype] 114 | size, s = s.split([1, len(s) - 1]) 115 | sizes[obj.dtype] = s 116 | v = values[obj.dtype] 117 | new_obj, v = v.split([size, v.shape[-1] - size], dim=-1) 118 | # compatible with reduce / stack / cat 119 | new_obj = new_obj.view(new_obj.shape[:-1] + (-1,) + obj.shape[1:]) 120 | values[obj.dtype] = v 121 | return new_obj, values 122 | elif isinstance(obj, dict): 123 | new_obj = {} 124 | for k, v in obj.items(): 125 | new_obj[k], values = _recursive_write(v, values, sizes) 126 | elif isinstance(obj, list) or isinstance(obj, tuple): 127 | new_obj = [] 128 | for v in obj: 129 | new_v, values = _recursive_write(v, values, sizes) 130 | new_obj.append(new_v) 131 | else: 132 | raise ValueError("Unknown type `%s`" % type(obj)) 133 | return new_obj, values 134 | 135 | 136 | def reduce(obj, op="sum", dst=None): 137 | """ 138 | Reduce any nested container of tensors. 
139 | Parameters: 140 | obj (Object): any container object. Can be nested list, tuple or dict. 141 | op (str, optional): element-wise reduction operator. 142 | Available operators are ``sum``, ``mean``, ``min``, ``max``, ``product``. 143 | dst (int, optional): rank of destination worker. If not specified, broadcast the result to all workers. 144 | Example:: 145 | >>> # assume 4 workers 146 | >>> rank = comm.get_rank() 147 | >>> x = torch.rand(5) 148 | >>> obj = {"polynomial": x ** rank} 149 | >>> obj = comm.reduce(obj) 150 | >>> assert torch.allclose(obj["polynomial"], x ** 3 + x ** 2 + x + 1) 151 | """ 152 | values = _recursive_read(obj)[0] 153 | values = {k: torch.cat(v) for k, v in values.items()} 154 | 155 | is_mean = op == "mean" 156 | if is_mean: 157 | op = "sum" 158 | op = getattr(dist.ReduceOp, op.upper()) 159 | 160 | reduced = {} 161 | for k, v in values.items(): 162 | dtype = v.dtype 163 | # NCCL can't solve bool. Cast them to byte 164 | if dtype == torch.bool: 165 | v = v.byte() 166 | group = get_group(v.device) 167 | if dst is None: 168 | dist.all_reduce(v, op=op, group=group) 169 | else: 170 | dist.reduce(v, op=op, dst=dst, group=group) 171 | if is_mean: 172 | v = v / get_world_size() 173 | reduced[k] = v.type(dtype) 174 | 175 | return _recursive_write(obj, reduced)[0] 176 | 177 | 178 | def stack(obj, dst=None): 179 | """ 180 | Stack any nested container of tensors. The new dimension will be added at the 0-th axis. 181 | Parameters: 182 | obj (Object): any container object. Can be nested list, tuple or dict. 183 | dst (int, optional): rank of destination worker. If not specified, broadcast the result to all workers. 
184 | Example:: 185 | >>> # assume 4 workers 186 | >>> rank = comm.get_rank() 187 | >>> x = torch.rand(5) 188 | >>> obj = {"exponent": x ** rank} 189 | >>> obj = comm.stack(obj) 190 | >>> truth = torch.stack([torch.ones_like(x), x, x ** 2, x ** 3]) 191 | >>> assert torch.allclose(obj["exponent"], truth) 192 | """ 193 | values = _recursive_read(obj)[0] 194 | values = {k: torch.cat(v) for k, v in values.items()} 195 | 196 | stacked = {} 197 | for k, v in values.items(): 198 | dtype = v.dtype 199 | # NCCL can't solve bool. Cast them to byte 200 | if dtype == torch.bool: 201 | dtype = torch.uint8 202 | s = torch.zeros(get_world_size(), *v.shape, dtype=dtype, device=v.device) 203 | s[get_rank()] = v 204 | group = get_group(s.device) 205 | if dst is None: 206 | dist.all_reduce(s, op=dist.ReduceOp.SUM, group=group) 207 | else: 208 | dist.reduce(s, op=dist.ReduceOp.SUM, dst=dst, group=group) 209 | stacked[k] = s.type(v.dtype) 210 | 211 | return _recursive_write(obj, stacked)[0] 212 | 213 | 214 | def cat(obj, dst=None): 215 | """ 216 | Concatenate any nested container of tensors along the 0-th axis. 217 | Parameters: 218 | obj (Object): any container object. Can be nested list, tuple or dict. 219 | dst (int, optional): rank of destination worker. If not specified, broadcast the result to all workers. 220 | Example:: 221 | >>> # assume 4 workers 222 | >>> rank = comm.get_rank() 223 | >>> rng = torch.arange(10) 224 | >>> obj = {"range": rng[rank * (rank + 1) // 2: (rank + 1) * (rank + 2) // 2]} 225 | >>> obj = comm.cat(obj) 226 | >>> assert torch.allclose(obj["range"], rng) 227 | """ 228 | values, sizes = _recursive_read(obj) 229 | sizes = {k: torch.cat(v) for k, v in sizes.items()} 230 | 231 | sizes = stack(sizes) 232 | cated = {} 233 | for k, value in values.items(): 234 | size = sizes[k].t().flatten() # sizes[k]: (num_worker, num_obj) 235 | dtype = value[0].dtype 236 | # NCCL can't solve bool. 
Cast them to byte 237 | if dtype == torch.bool: 238 | dtype = torch.uint8 239 | s = torch.zeros(size.sum(), dtype=dtype, device=value[0].device) 240 | obj_id = get_rank() 241 | world_size = get_world_size() 242 | offset = size[:obj_id].sum() 243 | for v in value: 244 | assert offset + v.numel() <= len(s) 245 | s[offset: offset + v.numel()] = v 246 | offset += size[obj_id: obj_id + world_size].sum() 247 | obj_id += world_size 248 | group = get_group(s.device) 249 | if dst is None: 250 | dist.all_reduce(s, op=dist.ReduceOp.SUM, group=group) 251 | else: 252 | dist.reduce(s, op=dist.ReduceOp.SUM, dst=dst, group=group) 253 | cated[k] = s.type(value[0].dtype) 254 | sizes = {k: v.sum(dim=0) for k, v in sizes.items()} 255 | 256 | return _recursive_write(obj, cated, sizes)[0] -------------------------------------------------------------------------------- /src/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from torch_scatter import scatter 4 | import numpy as np 5 | import os 6 | import random 7 | from easydict import EasyDict 8 | 9 | class KnowledgeGraph(object): 10 | def __init__(self, data_path): 11 | self.data_path = data_path 12 | 13 | self.entity2id = dict() 14 | self.relation2id = dict() 15 | self.id2entity = dict() 16 | self.id2relation = dict() 17 | 18 | with open(os.path.join(data_path, 'entities.dict')) as fi: 19 | for line in fi: 20 | id, entity = line.strip().split('\t') 21 | self.entity2id[entity] = int(id) 22 | self.id2entity[int(id)] = entity 23 | 24 | with open(os.path.join(data_path, 'relations.dict')) as fi: 25 | for line in fi: 26 | id, relation = line.strip().split('\t') 27 | self.relation2id[relation] = int(id) 28 | self.id2relation[int(id)] = relation 29 | 30 | self.entity_size = len(self.entity2id) 31 | self.relation_size = len(self.relation2id) 32 | 33 | self.train_facts = list() 34 | self.valid_facts = list() 35 | self.test_facts = list() 36 | 
self.hr2o = dict() 37 | self.hr2oo = dict() 38 | self.hr2ooo = dict() 39 | self.relation2adjacency = [[[], []] for k in range(self.relation_size)] 40 | self.relation2ht2index = [dict() for k in range(self.relation_size)] 41 | self.relation2outdegree = [[0 for i in range(self.entity_size)] for k in range(self.relation_size)] 42 | 43 | with open(os.path.join(data_path, "train.txt")) as fi: 44 | for line in fi: 45 | h, r, t = line.strip().split('\t') 46 | h, r, t = self.entity2id[h], self.relation2id[r], self.entity2id[t] 47 | self.train_facts.append((h, r, t)) 48 | 49 | hr_index = self.encode_hr(h, r) 50 | 51 | if hr_index not in self.hr2o: 52 | self.hr2o[hr_index] = list() 53 | self.hr2o[hr_index].append(t) 54 | 55 | if hr_index not in self.hr2oo: 56 | self.hr2oo[hr_index] = list() 57 | self.hr2oo[hr_index].append(t) 58 | 59 | if hr_index not in self.hr2ooo: 60 | self.hr2ooo[hr_index] = list() 61 | self.hr2ooo[hr_index].append(t) 62 | 63 | self.relation2adjacency[r][0].append(t) 64 | self.relation2adjacency[r][1].append(h) 65 | 66 | ht_index = self.encode_ht(h, t) 67 | assert ht_index not in self.relation2ht2index[r] 68 | index = len(self.relation2ht2index[r]) 69 | self.relation2ht2index[r][ht_index] = index 70 | 71 | self.relation2outdegree[r][t] += 1 72 | 73 | with open(os.path.join(data_path, "valid.txt")) as fi: 74 | for line in fi: 75 | h, r, t = line.strip().split('\t') 76 | h, r, t = self.entity2id[h], self.relation2id[r], self.entity2id[t] 77 | self.valid_facts.append((h, r, t)) 78 | 79 | hr_index = self.encode_hr(h, r) 80 | 81 | if hr_index not in self.hr2oo: 82 | self.hr2oo[hr_index] = list() 83 | self.hr2oo[hr_index].append(t) 84 | 85 | if hr_index not in self.hr2ooo: 86 | self.hr2ooo[hr_index] = list() 87 | self.hr2ooo[hr_index].append(t) 88 | 89 | with open(os.path.join(data_path, "test.txt")) as fi: 90 | for line in fi: 91 | h, r, t = line.strip().split('\t') 92 | h, r, t = self.entity2id[h], self.relation2id[r], self.entity2id[t] 93 | 
self.test_facts.append((h, r, t)) 94 | 95 | hr_index = self.encode_hr(h, r) 96 | 97 | if hr_index not in self.hr2ooo: 98 | self.hr2ooo[hr_index] = list() 99 | self.hr2ooo[hr_index].append(t) 100 | 101 | for r in range(self.relation_size): 102 | index = torch.LongTensor(self.relation2adjacency[r]) 103 | value = torch.ones(index.size(1)) 104 | self.relation2adjacency[r] = [index, value] 105 | 106 | self.relation2outdegree[r] = torch.LongTensor(self.relation2outdegree[r]) 107 | 108 | print("Data loading | DONE!") 109 | 110 | def encode_hr(self, h, r): 111 | return r * self.entity_size + h 112 | 113 | def decode_hr(self, index): 114 | h, r = index % self.entity_size, index // self.entity_size 115 | return h, r 116 | 117 | def encode_ht(self, h, t): 118 | return t * self.entity_size + h 119 | 120 | def decode_ht(self, index): 121 | h, t = index % self.entity_size, index // self.entity_size 122 | return h, t 123 | 124 | def get_updated_adjacency(self, r, edges_to_remove): 125 | if edges_to_remove is None: 126 | return None 127 | index = self.relation2adjacency[r][0] 128 | value = self.relation2adjacency[r][1] 129 | mask = (index.unsqueeze(1) == edges_to_remove.unsqueeze(-1)) 130 | mask = mask.all(dim=0).any(dim=0) 131 | mask = ~mask 132 | index = index[:, mask] 133 | value = value[mask] 134 | return [index, value] 135 | 136 | def grounding(self, h, r, rule, edges_to_remove): 137 | device = h.device 138 | with torch.no_grad(): 139 | x = torch.nn.functional.one_hot(h, self.entity_size).transpose(0, 1).unsqueeze(-1) 140 | if device.type == "cuda": 141 | x = x.cuda(device) 142 | for r_body in rule: 143 | if r_body == r: 144 | x = self.propagate(x, r_body, edges_to_remove) 145 | else: 146 | x = self.propagate(x, r_body, None) 147 | return x.squeeze(-1).transpose(0, 1) 148 | 149 | def propagate(self, x, relation, edges_to_remove=None): 150 | device = x.device 151 | node_in = self.relation2adjacency[relation][0][1] 152 | node_out = self.relation2adjacency[relation][0][0] 153 | if 
device.type == "cuda": 154 | node_in = node_in.cuda(device) 155 | node_out = node_out.cuda(device) 156 | 157 | message = x[node_in] 158 | E, B, D = message.size() 159 | 160 | if edges_to_remove == None: 161 | x = scatter(message, node_out, dim=0, dim_size=x.size(0)) 162 | else: 163 | # message: edge * batch * dim 164 | message = message.view(-1, D) 165 | bias = torch.arange(B) 166 | if device.type == "cuda": 167 | bias = bias.cuda(device) 168 | edges_to_remove = edges_to_remove * B + bias 169 | message[edges_to_remove] = 0 170 | message = message.view(E, B, D) 171 | x = scatter(message, node_out, dim=0, dim_size=x.size(0)) 172 | 173 | return x 174 | 175 | class TrainDataset(Dataset): 176 | def __init__(self, graph, batch_size): 177 | self.graph = graph 178 | self.batch_size = batch_size 179 | 180 | self.r2instances = [[] for r in range(self.graph.relation_size)] 181 | for h, r, t in self.graph.train_facts: 182 | self.r2instances[r].append((h, r, t)) 183 | 184 | self.make_batches() 185 | 186 | def make_batches(self): 187 | for r in range(self.graph.relation_size): 188 | random.shuffle(self.r2instances[r]) 189 | 190 | self.batches = list() 191 | for r, instances in enumerate(self.r2instances): 192 | for k in range(0, len(instances), self.batch_size): 193 | start = k 194 | end = min(k + self.batch_size, len(instances)) 195 | self.batches.append(instances[start:end]) 196 | random.shuffle(self.batches) 197 | 198 | def __len__(self): 199 | return len(self.batches) 200 | 201 | def __getitem__(self, idx): 202 | data = self.batches[idx] 203 | 204 | all_h = torch.LongTensor([_[0] for _ in data]) 205 | all_r = torch.LongTensor([_[1] for _ in data]) 206 | all_t = torch.LongTensor([_[2] for _ in data]) 207 | target = torch.zeros(len(data), self.graph.entity_size) 208 | edges_to_remove = [] 209 | for k, (h, r, t) in enumerate(data): 210 | hr_index = self.graph.encode_hr(h, r) 211 | t_index = torch.LongTensor(self.graph.hr2o[hr_index]) 212 | target[k][t_index] = 1 213 | 214 | 
ht_index = self.graph.encode_ht(h, t) 215 | edge = self.graph.relation2ht2index[r][ht_index] 216 | edges_to_remove.append(edge) 217 | edges_to_remove = torch.LongTensor(edges_to_remove) 218 | 219 | return all_h, all_r, all_t, target, edges_to_remove 220 | 221 | class ValidDataset(Dataset): 222 | def __init__(self, graph, batch_size): 223 | self.graph = graph 224 | self.batch_size = batch_size 225 | 226 | facts = self.graph.valid_facts 227 | 228 | r2instances = [[] for r in range(self.graph.relation_size)] 229 | for h, r, t in facts: 230 | r2instances[r].append((h, r, t)) 231 | 232 | self.batches = list() 233 | for r, instances in enumerate(r2instances): 234 | random.shuffle(instances) 235 | for k in range(0, len(instances), self.batch_size): 236 | start = k 237 | end = min(k + self.batch_size, len(instances)) 238 | self.batches.append(instances[start:end]) 239 | 240 | def __len__(self): 241 | return len(self.batches) 242 | 243 | def __getitem__(self, idx): 244 | data = self.batches[idx] 245 | 246 | all_h = torch.LongTensor([_[0] for _ in data]) 247 | all_r = torch.LongTensor([_[1] for _ in data]) 248 | all_t = torch.LongTensor([_[2] for _ in data]) 249 | 250 | mask = torch.ones(len(data), self.graph.entity_size).bool() 251 | for k, (h, r, t) in enumerate(data): 252 | hr_index = self.graph.encode_hr(h, r) 253 | t_index = torch.LongTensor(self.graph.hr2oo[hr_index]) 254 | mask[k][t_index] = 0 255 | 256 | return all_h, all_r, all_t, mask 257 | 258 | class TestDataset(Dataset): 259 | def __init__(self, graph, batch_size): 260 | self.graph = graph 261 | self.batch_size = batch_size 262 | 263 | facts = self.graph.test_facts 264 | 265 | r2instances = [[] for r in range(self.graph.relation_size)] 266 | for h, r, t in facts: 267 | r2instances[r].append((h, r, t)) 268 | 269 | self.batches = list() 270 | for r, instances in enumerate(r2instances): 271 | random.shuffle(instances) 272 | for k in range(0, len(instances), self.batch_size): 273 | start = k 274 | end = min(k + 
self.batch_size, len(instances)) 275 | self.batches.append(instances[start:end]) 276 | 277 | def __len__(self): 278 | return len(self.batches) 279 | 280 | def __getitem__(self, idx): 281 | data = self.batches[idx] 282 | 283 | all_h = torch.LongTensor([_[0] for _ in data]) 284 | all_r = torch.LongTensor([_[1] for _ in data]) 285 | all_t = torch.LongTensor([_[2] for _ in data]) 286 | 287 | mask = torch.ones(len(data), self.graph.entity_size).bool() 288 | for k, (h, r, t) in enumerate(data): 289 | hr_index = self.graph.encode_hr(h, r) 290 | t_index = torch.LongTensor(self.graph.hr2ooo[hr_index]) 291 | mask[k][t_index] = 0 292 | 293 | return all_h, all_r, all_t, mask 294 | 295 | class RuleDataset(Dataset): 296 | def __init__(self, num_relations, input): 297 | self.rules = list() 298 | self.num_relations = num_relations 299 | self.ending_idx = num_relations 300 | self.padding_idx = num_relations + 1 301 | 302 | if type(input) == list: 303 | rules = input 304 | elif type(input) == str: 305 | rules = list() 306 | with open(input, 'r') as fi: 307 | for line in fi: 308 | rule = line.strip().split() 309 | rule = [int(_) for _ in rule[0:-1]] + [float(rule[-1]) * 1000] 310 | rules.append(rule) 311 | 312 | self.rules = [] 313 | for rule in rules: 314 | rule_len = len(rule) 315 | formatted_rule = [rule[0:-1] + [self.ending_idx], self.padding_idx, rule[-1] + 1e-5] 316 | self.rules.append(formatted_rule) 317 | 318 | def __len__(self): 319 | return len(self.rules) 320 | 321 | def __getitem__(self, idx): 322 | return self.rules[idx] 323 | 324 | @staticmethod 325 | def collate_fn(data): 326 | inputs = [item[0][0:len(item[0])-1] for item in data] 327 | target = [item[0][1:len(item[0])] for item in data] 328 | weight = [float(item[-1]) for item in data] 329 | max_len = max([len(_) for _ in inputs]) 330 | padding_index = [int(item[-2]) for item in data] 331 | 332 | for k in range(len(data)): 333 | for i in range(max_len - len(inputs[k])): 334 | inputs[k].append(padding_index[k]) 335 | 
target[k].append(padding_index[k]) 336 | 337 | inputs = torch.tensor(inputs, dtype=torch.long) 338 | target = torch.tensor(target, dtype=torch.long) 339 | weight = torch.tensor(weight) 340 | mask = (target != torch.tensor(padding_index, dtype=torch.long).unsqueeze(1)) 341 | 342 | return inputs, target, mask, weight 343 | 344 | def Iterator(dataloader): 345 | while True: 346 | for data in dataloader: 347 | yield data -------------------------------------------------------------------------------- /src/embedding.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import json 5 | 6 | class RotatE(torch.nn.Module): 7 | def __init__(self, path): 8 | super(RotatE, self).__init__() 9 | self.path = path 10 | 11 | cfg_file = os.path.join(path, 'config.json') 12 | with open(cfg_file, 'r') as fi: 13 | cfg = json.load(fi) 14 | self.emb_dim = cfg['hidden_dim'] 15 | self.gamma = cfg['gamma'] 16 | self.range = (self.gamma + 2.0) / self.emb_dim 17 | self.num_entities = cfg['nentity'] 18 | 19 | eemb_file = os.path.join(path, 'entity_embedding.npy') 20 | eemb = np.load(eemb_file) 21 | self.eemb = torch.nn.parameter.Parameter(torch.tensor(eemb)) 22 | 23 | remb_file = os.path.join(path, 'relation_embedding.npy') 24 | remb = np.load(remb_file) 25 | remb = torch.tensor(remb) 26 | self.remb = torch.nn.parameter.Parameter(torch.cat([remb, -remb], dim=0)) 27 | 28 | def product(self, vec1, vec2): 29 | re_1, im_1 = torch.chunk(vec1, 2, dim=-1) 30 | re_2, im_2 = torch.chunk(vec2, 2, dim=-1) 31 | 32 | re_res = re_1 * re_2 - im_1 * im_2 33 | im_res = re_1 * im_2 + im_1 * re_2 34 | 35 | return torch.cat([re_res, im_res], dim=-1) 36 | 37 | def project(self, vec): 38 | pi = 3.141592653589793238462643383279 39 | vec = vec / (self.range / pi) 40 | 41 | re_r = torch.cos(vec) 42 | im_r = torch.sin(vec) 43 | 44 | return torch.cat([re_r, im_r], dim=-1) 45 | 46 | def diff(self, vec1, vec2): 47 | diff = vec1 - vec2 
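As a hedged aside (not part of the repository), the complex arithmetic that `product`, `project`, and `diff` implement above can be sketched in plain Python: a relation is a per-dimension rotation phase, rotating the head embedding should land on the tail, and the score is `gamma` minus the accumulated distance. The helper name `rotate_score` and the toy values below are illustrative assumptions, not code from this repo.

```python
import cmath

# Illustrative sketch only: RotatE-style scoring with plain complex numbers.
# A relation is a list of phases; rotating each head coordinate by its phase
# should land on the tail coordinate, and score = gamma - sum |h * r - t|.
def rotate_score(head, phases, tail, gamma=12.0):
    dist = 0.0
    for h, theta, t in zip(head, phases, tail):
        r = cmath.exp(1j * theta)  # unit-modulus rotation, |r| == 1
        dist += abs(h * r - t)
    return gamma - dist

# A perfect match: the tail is exactly the rotated head, so the distance
# term vanishes and the score equals gamma.
head = [1 + 0j, 0 + 1j]
phases = [cmath.pi / 2, cmath.pi]
tail = [h * cmath.exp(1j * p) for h, p in zip(head, phases)]
print(rotate_score(head, phases, tail))  # → 12.0
```

The real module stores the real and imaginary halves concatenated in one tensor and reads `gamma` from `config.json`; the sketch only mirrors the arithmetic.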
48 | re_diff, im_diff = torch.chunk(diff, 2, dim=-1) 49 | diff = torch.stack([re_diff, im_diff], dim=0) 50 | diff = diff.norm(dim=0) 51 | return diff 52 | 53 | def dist(self, all_h, all_r, all_t): 54 | h_emb = self.eemb.index_select(0, all_h).squeeze() 55 | r_emb = self.remb.index_select(0, all_r).squeeze() 56 | t_emb = self.eemb.index_select(0, all_t).squeeze() 57 | 58 | r_emb = self.project(r_emb) 59 | e_emb = self.product(h_emb, r_emb) 60 | dist = self.diff(e_emb, t_emb) 61 | 62 | return dist.sum(dim=-1) 63 | 64 | def forward(self, all_h, all_r): 65 | all_h_ = all_h.unsqueeze(-1).expand(-1, self.num_entities).reshape(-1) 66 | all_r_ = all_r.unsqueeze(-1).expand(-1, self.num_entities).reshape(-1) 67 | all_e_ = torch.tensor(list(range(self.num_entities)), dtype=torch.long, device=all_h.device).unsqueeze(0).expand(all_r.size(0), -1).reshape(-1) 68 | kge_score = self.gamma - self.dist(all_h_, all_r_, all_e_) 69 | kge_score = kge_score.view(-1, self.num_entities) 70 | return kge_score 71 | 72 | -------------------------------------------------------------------------------- /src/generators.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class Generator(torch.nn.Module): 4 | def __init__(self, graph, num_layers, embedding_dim, hidden_dim): 5 | super(Generator, self).__init__() 6 | self.graph = graph 7 | self.num_relations = graph.relation_size 8 | 9 | self.num_layers = num_layers 10 | self.embedding_dim = embedding_dim 11 | self.hidden_dim = hidden_dim 12 | 13 | self.vocab_size = self.num_relations + 2 14 | self.label_size = self.num_relations + 1 15 | self.ending_idx = self.num_relations 16 | self.padding_idx = self.num_relations + 1 17 | 18 | self.embedding = torch.nn.Embedding(self.vocab_size, self.embedding_dim, padding_idx=self.padding_idx) 19 | self.rnn = torch.nn.LSTM(self.embedding_dim * 2, self.hidden_dim, self.num_layers, batch_first=True) 20 | self.linear = torch.nn.Linear(self.hidden_dim, 
self.label_size) 21 | self.criterion = torch.nn.CrossEntropyLoss(reduction='none') 22 | 23 | def forward(self, inputs, relation, hidden): 24 | embedding = self.embedding(inputs) 25 | embedding_r = self.embedding(relation).unsqueeze(1).expand(-1, inputs.size(1), -1) 26 | embedding = torch.cat([embedding, embedding_r], dim=-1) 27 | outputs, hidden = self.rnn(embedding, hidden) 28 | logits = self.linear(outputs) 29 | return logits, hidden 30 | 31 | def loss(self, inputs, target, mask, weight, hidden): 32 | logits, hidden = self.forward(inputs, inputs[:, 0], hidden) 33 | logits = torch.masked_select(logits, mask.unsqueeze(-1)).view(-1, self.label_size) 34 | target = torch.masked_select(target, mask) 35 | weight = torch.masked_select((mask.t() * weight).t(), mask) 36 | loss = (self.criterion(logits, target) * weight).sum() / weight.sum() 37 | return loss 38 | 39 | -------------------------------------------------------------------------------- /src/layers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.nn.init import xavier_normal_ 8 | 9 | class MLP(nn.Module): 10 | def __init__(self, input_dim, hidden_dims, short_cut=False, batch_norm=False, activation="relu", dropout=0): 11 | super(MLP, self).__init__() 12 | 13 | self.dims = [input_dim] + hidden_dims 14 | self.short_cut = short_cut 15 | 16 | if isinstance(activation, str): 17 | self.activation = getattr(F, activation) 18 | else: 19 | self.activation = activation 20 | if dropout: 21 | self.dropout = nn.Dropout(dropout) 22 | else: 23 | self.dropout = None 24 | 25 | self.layers = nn.ModuleList() 26 | for i in range(len(self.dims) - 1): 27 | self.layers.append(nn.Linear(self.dims[i], self.dims[i + 1])) 28 | if batch_norm: 29 | self.batch_norms = nn.ModuleList() 30 | for i in range(len(self.dims) - 2): 31 | 
self.batch_norms.append(nn.BatchNorm1d(self.dims[i + 1])) 32 | else: 33 | self.batch_norms = None 34 | 35 | def forward(self, input): 36 | layer_input = input 37 | 38 | for i, layer in enumerate(self.layers): 39 | hidden = layer(layer_input) 40 | if i < len(self.layers) - 1: 41 | if self.batch_norms: 42 | x = hidden.flatten(0, -2) 43 | hidden = self.batch_norms[i](x).view_as(hidden) 44 | hidden = self.activation(hidden) 45 | if self.dropout: 46 | hidden = self.dropout(hidden) 47 | if self.short_cut and hidden.shape == layer_input.shape: 48 | hidden = hidden + layer_input 49 | layer_input = hidden 50 | 51 | return hidden 52 | 53 | class FuncToNodeSum(nn.Module): 54 | def __init__(self, vector_dim): 55 | super(FuncToNodeSum, self).__init__() 56 | 57 | self.vector_dim = vector_dim 58 | self.layer_norm = nn.LayerNorm(self.vector_dim) 59 | self.add_model = MLP(self.vector_dim, [self.vector_dim]) 60 | 61 | self.eps = 1e-6 62 | 63 | def forward(self, A_fn, x_f, b_n): 64 | device = x_f.device 65 | batch_size = b_n.max().item() + 1 66 | 67 | degree = A_fn.sum(0) + 1 68 | weight = torch.transpose(A_fn, 0, 1).unsqueeze(-1) 69 | message = x_f.unsqueeze(0) 70 | 71 | weight_zero = weight == 0 72 | features = (message * weight).sum(1) 73 | output = self.add_model(features) 74 | output = self.layer_norm(output) 75 | output = torch.relu(output) 76 | 77 | return output 78 | 79 | class FuncToNode(nn.Module): 80 | def __init__(self, vector_dim): 81 | super(FuncToNode, self).__init__() 82 | 83 | self.vector_dim = vector_dim 84 | self.layer_norm = nn.LayerNorm(self.vector_dim) 85 | self.add_model = MLP(self.vector_dim * 12, [self.vector_dim]) 86 | 87 | self.eps = 1e-6 88 | 89 | def forward(self, A_fn, x_f, b_n): 90 | device = x_f.device 91 | batch_size = b_n.max().item() + 1 92 | 93 | degree = A_fn.sum(0) + 1 94 | weight = torch.transpose(A_fn, 0, 1).unsqueeze(-1) 95 | message = x_f.unsqueeze(0) 96 | 97 | weight_zero = weight == 0 98 | sum = (message * weight).sum(1) 99 | sq_sum = 
((message ** 2) * weight).sum(1) 100 | min = message.expand(weight.size(0), -1, -1).masked_fill(weight_zero, float('inf')).min(1)[0] 101 | max = message.expand(weight.size(0), -1, -1).masked_fill(weight_zero, float('-inf')).max(1)[0] 102 | 103 | degree_out = degree.unsqueeze(-1) 104 | mean = sum / degree_out.clamp(min=self.eps) 105 | sq_mean = sq_sum / degree_out.clamp(min=self.eps) 106 | std = (sq_mean - mean ** 2).clamp(min=self.eps).sqrt() 107 | features = torch.cat([mean, min, max, std], dim=-1) 108 | 109 | scale = degree_out.log() 110 | sum_scale = torch.zeros(batch_size, device=device) 111 | cn_scale = torch.zeros(batch_size, device=device) 112 | ones = torch.ones(scale.size(0), device=device) 113 | sum_scale.scatter_add_(0, b_n, scale.squeeze(-1)) 114 | cn_scale.scatter_add_(0, b_n, ones) 115 | mean_scale = sum_scale / cn_scale.clamp(min=self.eps) 116 | scale = scale / mean_scale[b_n].unsqueeze(-1).clamp(min=self.eps) 117 | scales = torch.cat([torch.ones_like(scale), scale, 1 / scale.clamp(min=self.eps)], dim=-1) 118 | 119 | update = features.unsqueeze(-1) * scales.unsqueeze(-2) 120 | update = update.flatten(-2) 121 | 122 | output = self.add_model(update) 123 | output = self.layer_norm(output) 124 | output = torch.relu(output) 125 | 126 | return output -------------------------------------------------------------------------------- /src/predictors.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.nn.init import xavier_normal_ 8 | import copy 9 | import random 10 | import logging 11 | from collections import defaultdict 12 | from layers import MLP, FuncToNode, FuncToNodeSum 13 | from embedding import RotatE 14 | from torch.utils.data import Dataset, DataLoader 15 | from torch_scatter import scatter, scatter_add, scatter_min, scatter_max, scatter_mean 16 | 17 | class 
Predictor(torch.nn.Module): 18 | def __init__(self, graph, entity_feature='bias'): 19 | super(Predictor, self).__init__() 20 | self.graph = graph 21 | self.num_entities = graph.entity_size 22 | self.num_relations = graph.relation_size 23 | self.entity_feature = entity_feature 24 | if entity_feature == 'bias': 25 | self.bias = torch.nn.parameter.Parameter(torch.zeros(self.num_entities)) 26 | 27 | def set_rules(self, input): 28 | self.rules = list() 29 | if type(input) == list: 30 | for rule in input: 31 | rule_ = (rule[0], rule[1:]) 32 | self.rules.append(rule_) 33 | logging.info('Predictor: read {} rules from list.'.format(len(self.rules))) 34 | elif type(input) == str: 35 | with open(input, 'r') as fi: 36 | for line in fi: 37 | rule = line.strip().split() 38 | rule = [int(_) for _ in rule] 39 | rule_ = (rule[0], rule[1:]) 40 | self.rules.append(rule_) 41 | logging.info('Predictor: read {} rules from file.'.format(len(self.rules))) 42 | else: 43 | raise ValueError 44 | self.num_rules = len(self.rules) 45 | 46 | self.relation2rules = [[] for r in range(self.num_relations)] 47 | for index, rule in enumerate(self.rules): 48 | relation = rule[0] 49 | self.relation2rules[relation].append([index, rule]) 50 | 51 | self.rule_weights = torch.nn.parameter.Parameter(torch.zeros(self.num_rules)) 52 | 53 | def forward(self, all_h, all_r, edges_to_remove): 54 | query_r = all_r[0].item() 55 | assert (all_r != query_r).sum() == 0 56 | device = all_r.device 57 | 58 | score = torch.zeros(all_r.size(0), self.num_entities, device=device) 59 | mask = torch.zeros(all_r.size(0), self.num_entities, device=device) 60 | for index, (r_head, r_body) in self.relation2rules[query_r]: 61 | assert r_head == query_r 62 | 63 | x = self.graph.grounding(all_h, r_head, r_body, edges_to_remove) 64 | score += x * self.rule_weights[index] 65 | mask += x 66 | 67 | if mask.sum().item() == 0: 68 | if self.entity_feature == 'bias': 69 | return mask + self.bias.unsqueeze(0), (1 - mask).bool() 70 | else: 71 | 
return mask - float('-inf'), mask.bool() 72 | 73 | if self.entity_feature == 'bias': 74 | score = score + self.bias.unsqueeze(0) 75 | mask = torch.ones_like(mask).bool() 76 | else: 77 | mask = (mask != 0) 78 | score = score.masked_fill(~mask, float('-inf')) 79 | 80 | return score, mask 81 | 82 | def compute_H(self, all_h, all_r, all_t, edges_to_remove): 83 | query_r = all_r[0].item() 84 | assert (all_r != query_r).sum() == 0 85 | device = all_r.device 86 | 87 | rule_score = list() 88 | rule_index = list() 89 | mask = torch.zeros(all_r.size(0), self.num_entities, device=device) 90 | for index, (r_head, r_body) in self.relation2rules[query_r]: 91 | assert r_head == query_r 92 | 93 | x = self.graph.grounding(all_h, r_head, r_body, edges_to_remove) 94 | score = x * self.rule_weights[index] 95 | mask += x 96 | 97 | rule_score.append(score) 98 | rule_index.append(index) 99 | 100 | rule_index = torch.tensor(rule_index, dtype=torch.long, device=device) 101 | pos_index = F.one_hot(all_t, self.num_entities).bool() 102 | if device.type == "cuda": 103 | pos_index = pos_index.cuda(device) 104 | neg_index = (mask != 0) 105 | 106 | if len(rule_score) == 0: 107 | return None, None 108 | 109 | rule_H_score = list() 110 | for score in rule_score: 111 | pos_score = (score * pos_index).sum(1) / torch.clamp(pos_index.sum(1), min=1) 112 | neg_score = (score * neg_index).sum(1) / torch.clamp(neg_index.sum(1), min=1) 113 | H_score = pos_score - neg_score 114 | rule_H_score.append(H_score.unsqueeze(-1)) 115 | 116 | rule_H_score = torch.cat(rule_H_score, dim=-1) 117 | rule_H_score = torch.softmax(rule_H_score, dim=-1).sum(0) 118 | 119 | return rule_H_score, rule_index 120 | 121 | class PredictorPlus(torch.nn.Module): 122 | def __init__(self, graph, type='emb', num_layers=3, hidden_dim=16, entity_feature='bias', aggregator='sum', embedding_path=None): 123 | super(PredictorPlus, self).__init__() 124 | self.graph = graph 125 | 126 | self.type = type 127 | self.num_layers = num_layers 128 | 
self.hidden_dim = hidden_dim 129 | self.entity_feature = entity_feature 130 | self.aggregator = aggregator 131 | self.embedding_path = embedding_path 132 | 133 | self.num_entities = graph.entity_size 134 | self.num_relations = graph.relation_size 135 | self.padding_index = graph.relation_size 136 | 137 | self.vocab_emb = torch.nn.Embedding(self.num_relations + 1, self.hidden_dim, padding_idx=self.num_relations) 138 | 139 | if self.type == 'lstm': 140 | self.rnn = torch.nn.LSTM(self.hidden_dim, self.hidden_dim, self.num_layers, batch_first=True) 141 | elif self.type == 'gru': 142 | self.rnn = torch.nn.GRU(self.hidden_dim, self.hidden_dim, self.num_layers, batch_first=True) 143 | elif self.type == 'rnn': 144 | self.rnn = torch.nn.RNN(self.hidden_dim, self.hidden_dim, self.num_layers, batch_first=True) 145 | elif self.type == 'emb': 146 | self.rule_emb = None 147 | else: 148 | raise NotImplementedError 149 | 150 | if aggregator == 'sum': 151 | self.rule_to_entity = FuncToNodeSum(self.hidden_dim) 152 | elif aggregator == 'pna': 153 | self.rule_to_entity = FuncToNode(self.hidden_dim) 154 | else: 155 | raise NotImplementedError 156 | 157 | self.relation_emb = torch.nn.Embedding(self.num_relations, self.hidden_dim) 158 | self.score_model = MLP(self.hidden_dim * 2, [128, 1]) # 128 for FB15k 159 | 160 | if entity_feature == 'bias': 161 | self.bias = torch.nn.parameter.Parameter(torch.zeros(self.num_entities)) 162 | elif entity_feature == 'RotatE': 163 | self.RotatE = RotatE(embedding_path) 164 | 165 | def set_rules(self, input): 166 | self.rules = list() 167 | if type(input) == list: 168 | for rule in input: 169 | rule_ = (rule[0], rule[1:]) 170 | self.rules.append(rule_) 171 | logging.info('Predictor+: read {} rules from list.'.format(len(self.rules))) 172 | elif type(input) == str: 173 | self.rules = list() 174 | with open(input, 'r') as fi: 175 | for line in fi: 176 | rule = line.strip().split() 177 | rule = [int(_) for _ in rule] 178 | rule_ = (rule[0], rule[1:]) 179 | 
self.rules.append(rule_) 180 | logging.info('Predictor+: read {} rules from file.'.format(len(self.rules))) 181 | else: 182 | raise ValueError 183 | self.num_rules = len(self.rules) 184 | self.max_length = max([len(rule[1]) for rule in self.rules]) 185 | 186 | self.relation2rules = [[] for r in range(self.num_relations)] 187 | for index, rule in enumerate(self.rules): 188 | relation = rule[0] 189 | self.relation2rules[relation].append([index, rule]) 190 | 191 | self.rule_features = [] 192 | for rule in self.rules: 193 | rule_ = [rule[0]] + rule[1] + [self.padding_index for i in range(self.max_length - len(rule[1]))] 194 | self.rule_features.append(rule_) 195 | self.rule_features = torch.tensor(self.rule_features, dtype=torch.long) 196 | 197 | if self.type == 'emb': 198 | self.rule_emb = nn.parameter.Parameter(torch.zeros(self.num_rules, self.hidden_dim)) 199 | nn.init.kaiming_uniform_(self.rule_emb, a=math.sqrt(5), mode="fan_in") 200 | 201 | def encode_rules(self, rule_features): 202 | rule_masks = rule_features != self.num_relations 203 | x = self.vocab_emb(rule_features) 204 | output, hidden = self.rnn(x) 205 | idx = (rule_masks.sum(-1) - 1).long() 206 | idx = idx.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, self.hidden_dim) 207 | rule_emb = torch.gather(output, 1, idx).squeeze(1) 208 | return rule_emb 209 | 210 | def forward(self, all_h, all_r, edges_to_remove): 211 | query_r = all_r[0].item() 212 | assert (all_r != query_r).sum() == 0 213 | device = all_r.device 214 | 215 | if device.type == "cuda": 216 | self.rule_features = self.rule_features.cuda(device) 217 | 218 | rule_index = list() 219 | rule_count = list() 220 | mask = torch.zeros(all_h.size(0), self.graph.entity_size, device=device) 221 | for index, (r_head, r_body) in self.relation2rules[query_r]: 222 | assert r_head == query_r 223 | 224 | count = self.graph.grounding(all_h, r_head, r_body, edges_to_remove).float() 225 | mask += count 226 | 227 | rule_index.append(index) 228 | rule_count.append(count) 
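The loop above grounds each rule body of the query relation and stacks per-entity path counts. A minimal standalone sketch of that grounding idea (hypothetical `ground_rule` helper, with plain Python lists in place of the sparse one-hot propagation in `KnowledgeGraph.grounding`/`propagate`):

```python
# Illustrative sketch only: count, for every entity, how many paths from the
# head entity follow the rule body's relation sequence.
def ground_rule(adjacency, head, rule_body, num_entities):
    counts = [0] * num_entities
    counts[head] = 1  # one-hot start at the head entity
    for relation in rule_body:
        nxt = [0] * num_entities
        for h, t in adjacency.get(relation, []):
            nxt[t] += counts[h]  # every edge forwards the running path count
        counts = nxt
    return counts

# Tiny graph: relation 0 has edges 0->1 and 0->2; relation 1 has 1->3 and 2->3.
adj = {0: [(0, 1), (0, 2)], 1: [(1, 3), (2, 3)]}
print(ground_rule(adj, 0, [0, 1], 4))  # → [0, 0, 0, 2]
```

Entity 3 gets a count of 2 because two distinct length-2 paths reach it; such counts are exactly the per-rule evidence collected into `rule_count`.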
229 | 230 | if mask.sum().item() == 0: 231 | if self.entity_feature == 'bias': 232 | return mask + self.bias.unsqueeze(0), (1 - mask).bool() 233 | elif self.entity_feature == 'RotatE': 234 | bias = self.RotatE(all_h, all_r) 235 | return mask + bias, (1 - mask).bool() 236 | else: 237 | return mask - float('-inf'), mask.bool() 238 | 239 | candidate_set = torch.nonzero(mask.view(-1), as_tuple=True)[0] 240 | batch_id_of_candidate = candidate_set // self.graph.entity_size 241 | 242 | rule_index = torch.tensor(rule_index, dtype=torch.long, device=device) 243 | rule_count = torch.stack(rule_count, dim=0) 244 | rule_count = rule_count.reshape(rule_index.size(0), -1)[:, candidate_set] 245 | 246 | if self.type == 'emb': 247 | rule_emb = self.rule_emb[rule_index] 248 | else: 249 | rule_emb = self.encode_rules(self.rule_features[rule_index]) 250 | 251 | output = self.rule_to_entity(rule_count, rule_emb, batch_id_of_candidate) 252 | 253 | rel = self.relation_emb(all_r[0]).unsqueeze(0).expand(output.size(0), -1) 254 | feature = torch.cat([output, rel], dim=-1) 255 | output = self.score_model(feature).squeeze(-1) 256 | 257 | score = torch.zeros(all_h.size(0) * self.graph.entity_size, device=device) 258 | score.scatter_(0, candidate_set, output) 259 | score = score.view(all_h.size(0), self.graph.entity_size) 260 | if self.entity_feature == 'bias': 261 | score = score + self.bias.unsqueeze(0) 262 | mask = torch.ones_like(mask).bool() 263 | elif self.entity_feature == 'RotatE': 264 | bias = self.RotatE(all_h, all_r) 265 | score = score + bias 266 | mask = torch.ones_like(mask).bool() 267 | else: 268 | mask = (mask != 0) 269 | score = score.masked_fill(~mask, float('-inf')) 270 | 271 | return score, mask 272 | -------------------------------------------------------------------------------- /src/run_predictorplus.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import os.path as osp 4 | import logging 5 | import 
argparse 6 | import random 7 | import json 8 | from easydict import EasyDict 9 | import numpy as np 10 | from datetime import datetime 11 | import torch 12 | from torch.utils.data import DataLoader 13 | 14 | from data import KnowledgeGraph, TrainDataset, ValidDataset, TestDataset 15 | from predictors import PredictorPlus 16 | from utils import load_config, save_config, set_logger, set_seed 17 | from trainer import TrainerPredictor 18 | import comm 19 | 20 | def parse_args(args=None): 21 | parser = argparse.ArgumentParser( 22 | description='RNNLogic', 23 | usage='train.py [] [-h | --help]' 24 | ) 25 | parser.add_argument('--config', default='../predictor.yaml', type=str) 26 | parser.add_argument("--local_rank", type=int, default=0) 27 | return parser.parse_args(args) 28 | 29 | def main(args): 30 | cfgs = load_config(args.config) 31 | cfg = cfgs[0] 32 | 33 | if cfg.save_path is None: 34 | cfg.save_path = os.path.join('../outputs', datetime.now().strftime('%Y%m-%d%H-%M%S')) 35 | 36 | if cfg.save_path and not os.path.exists(cfg.save_path): 37 | os.makedirs(cfg.save_path) 38 | 39 | save_config(cfg, cfg.save_path) 40 | 41 | set_logger(cfg.save_path) 42 | set_seed(cfg.seed) 43 | 44 | graph = KnowledgeGraph(cfg.data.data_path) 45 | train_set = TrainDataset(graph, cfg.data.batch_size) 46 | valid_set = ValidDataset(graph, cfg.data.batch_size) 47 | test_set = TestDataset(graph, cfg.data.batch_size) 48 | 49 | predictor = PredictorPlus(graph, **cfg.predictor.model) 50 | predictor.set_rules(cfg.data.rule_file) 51 | optim = torch.optim.Adam(predictor.parameters(), **cfg.predictor.optimizer) 52 | 53 | solver = TrainerPredictor(predictor, train_set, valid_set, test_set, optim, gpus=cfg.gpus) 54 | best_valid_mrr = 0.0 55 | test_mrr = 0.0 56 | for k in range(cfg.num_iters): 57 | if comm.get_rank() == 0: 58 | logging.info('-------------------------') 59 | logging.info('| Iteration: {}/{}'.format(k + 1, cfg.num_iters)) 60 | 
logging.info('-------------------------') 61 | 62 | solver.train(**cfg.predictor.train) 63 | valid_mrr_iter = solver.evaluate('valid', expectation=cfg.predictor.eval.expectation) 64 | test_mrr_iter = solver.evaluate('test', expectation=cfg.predictor.eval.expectation) 65 | if valid_mrr_iter > best_valid_mrr: 66 | best_valid_mrr = valid_mrr_iter 67 | test_mrr = test_mrr_iter 68 | solver.save(os.path.join(cfg.save_path, 'predictor.pt')) 69 | 70 | if __name__ == '__main__': 71 | main(parse_args()) -------------------------------------------------------------------------------- /src/run_rnnlogic.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import os.path as osp 4 | import logging 5 | import argparse 6 | import random 7 | import json 8 | from easydict import EasyDict 9 | import numpy as np 10 | from datetime import datetime 11 | import torch 12 | from torch.utils.data import DataLoader 13 | 14 | from data import KnowledgeGraph, TrainDataset, ValidDataset, TestDataset, RuleDataset 15 | from predictors import Predictor, PredictorPlus 16 | from generators import Generator 17 | from utils import load_config, save_config, set_logger, set_seed 18 | from trainer import TrainerPredictor, TrainerGenerator 19 | import comm 20 | 21 | def parse_args(args=None): 22 | parser = argparse.ArgumentParser( 23 | description='RNNLogic', 24 | usage='train.py [] [-h | --help]' 25 | ) 26 | parser.add_argument('--config', default='../rnnlogic.yaml', type=str) 27 | parser.add_argument("--local_rank", type=int, default=0) 28 | return parser.parse_args(args) 29 | 30 | def main(args): 31 | cfgs = load_config(args.config) 32 | cfg = cfgs[0] 33 | 34 | if cfg.save_path is None: 35 | cfg.save_path = os.path.join('../outputs', datetime.now().strftime('%Y%m-%d%H-%M%S')) 36 | 37 | if cfg.save_path and not os.path.exists(cfg.save_path): 38 | os.makedirs(cfg.save_path) 39 | 40 | save_config(cfg, cfg.save_path) 41 | 42 | 
42 |     set_logger(cfg.save_path)
43 |     set_seed(cfg.seed)
44 | 
45 |     graph = KnowledgeGraph(cfg.data.data_path)
46 |     train_set = TrainDataset(graph, cfg.data.batch_size)
47 |     valid_set = ValidDataset(graph, cfg.data.batch_size)
48 |     test_set = TestDataset(graph, cfg.data.batch_size)
49 | 
50 |     dataset = RuleDataset(graph.relation_size, cfg.data.rule_file)
51 | 
52 |     if comm.get_rank() == 0:
53 |         logging.info('-------------------------')
54 |         logging.info('| Pre-train Generator')
55 |         logging.info('-------------------------')
56 |     generator = Generator(graph, **cfg.generator.model)
57 |     solver_g = TrainerGenerator(generator, gpu=cfg.generator.gpu)
58 |     solver_g.train(dataset, **cfg.generator.pre_train)
59 | 
60 |     replay_buffer = list()
61 |     for k in range(cfg.EM.num_iters):
62 |         if comm.get_rank() == 0:
63 |             logging.info('-------------------------')
64 |             logging.info('| EM Iteration: {}/{}'.format(k + 1, cfg.EM.num_iters))
65 |             logging.info('-------------------------')
66 | 
67 |         # Sample logic rules.
68 |         sampled_rules = solver_g.sample(cfg.EM.num_rules, cfg.EM.max_length)
69 |         prior = [rule[-1] for rule in sampled_rules]
70 |         rules = [rule[0:-1] for rule in sampled_rules]
71 | 
72 |         # Train a reasoning predictor with sampled logic rules.
73 |         predictor = Predictor(graph, **cfg.predictor.model)
74 |         predictor.set_rules(rules)
75 |         optim = torch.optim.Adam(predictor.parameters(), **cfg.predictor.optimizer)
76 | 
77 |         solver_p = TrainerPredictor(predictor, train_set, valid_set, test_set, optim, gpus=cfg.predictor.gpus)
78 |         solver_p.train(**cfg.predictor.train)
79 |         valid_mrr_iter = solver_p.evaluate('valid', expectation=cfg.predictor.eval.expectation)
80 |         test_mrr_iter = solver_p.evaluate('test', expectation=cfg.predictor.eval.expectation)
81 | 
82 |         # E-step: Compute H scores of logic rules.
83 |         likelihood = solver_p.compute_H(**cfg.predictor.H_score)
84 |         posterior = [l + p * cfg.EM.prior_weight for l, p in zip(likelihood, prior)]
85 |         for i in range(len(rules)):
86 |             rules[i].append(posterior[i])
87 |         replay_buffer += rules
88 | 
89 |         # M-step: Update the rule generator.
90 |         dataset = RuleDataset(graph.relation_size, rules)
91 |         solver_g.train(dataset, **cfg.generator.train)
92 | 
93 |     if replay_buffer != []:
94 |         if comm.get_rank() == 0:
95 |             logging.info('-------------------------')
96 |             logging.info('| Post-train Generator')
97 |             logging.info('-------------------------')
98 |         dataset = RuleDataset(graph.relation_size, replay_buffer)
99 |         solver_g.train(dataset, **cfg.generator.post_train)
100 | 
101 |     if comm.get_rank() == 0:
102 |         logging.info('-------------------------')
103 |         logging.info('| Beam Search Best Rules')
104 |         logging.info('-------------------------')
105 | 
106 |     sampled_rules = list()
107 |     for num_rules, max_length in zip(cfg.final_prediction.num_rules, cfg.final_prediction.max_length):
108 |         sampled_rules_ = solver_g.beam_search(num_rules, max_length)
109 |         sampled_rules += sampled_rules_
110 | 
111 |     prior = [rule[-1] for rule in sampled_rules]
112 |     rules = [rule[0:-1] for rule in sampled_rules]
113 | 
114 |     if comm.get_rank() == 0:
115 |         logging.info('-------------------------')
116 |         logging.info('| Train Final Predictor+')
117 |         logging.info('-------------------------')
118 | 
119 |     predictor = PredictorPlus(graph, **cfg.predictorplus.model)
120 |     predictor.set_rules(rules)
121 |     optim = torch.optim.Adam(predictor.parameters(), **cfg.predictorplus.optimizer)
122 | 
123 |     solver_p = TrainerPredictor(predictor, train_set, valid_set, test_set, optim, gpus=cfg.predictorplus.gpus)
124 |     best_valid_mrr = 0.0
125 |     test_mrr = 0.0
126 |     for k in range(cfg.final_prediction.num_iters):
127 |         if comm.get_rank() == 0:
128 |             logging.info('-------------------------')
129 |             logging.info('| Iteration: {}/{}'.format(k + 1, cfg.final_prediction.num_iters))
130 |             logging.info('-------------------------')
131 | 
132 |         solver_p.train(**cfg.predictorplus.train)
133 |         valid_mrr_iter = solver_p.evaluate('valid', expectation=cfg.predictorplus.eval.expectation)
134 |         test_mrr_iter = solver_p.evaluate('test', expectation=cfg.predictorplus.eval.expectation)
135 | 
136 |         if valid_mrr_iter > best_valid_mrr:
137 |             best_valid_mrr = valid_mrr_iter
138 |             test_mrr = test_mrr_iter
139 |             solver_p.save(os.path.join(cfg.save_path, 'predictor.pt'))
140 | 
141 |     if comm.get_rank() == 0:
142 |         logging.info('-------------------------')
143 |         logging.info('| Final Test MRR: {:.6f}'.format(test_mrr))
144 |         logging.info('-------------------------')
145 | 
146 | if __name__ == '__main__':
147 |     main(parse_args())
--------------------------------------------------------------------------------
/src/trainer.py:
--------------------------------------------------------------------------------
1 | import comm
2 | from utils import *
3 | import torch
4 | from torch import distributed as dist
5 | from torch import nn
6 | from torch.utils import data as torch_data
7 | from itertools import islice
8 | from data import RuleDataset, Iterator
9 | 
10 | class TrainerPredictor(object):
11 | 
12 |     def __init__(self, model, train_set, valid_set, test_set, optimizer, scheduler=None, gpus=None, num_worker=0):
13 |         self.rank = comm.get_rank()
14 |         self.world_size = comm.get_world_size()
15 |         self.gpus = gpus
16 |         self.num_worker = num_worker
17 | 
18 |         if gpus is None:
19 |             self.device = torch.device("cpu")
20 |         else:
21 |             if len(gpus) != self.world_size:
22 |                 error_msg = "World size is %d but found %d GPUs in the argument"
23 |                 if self.world_size == 1:
24 |                     error_msg += ". Did you launch with `python -m torch.distributed.launch`?"
25 |                 raise ValueError(error_msg % (self.world_size, len(gpus)))
26 |             self.device = torch.device(gpus[self.rank % len(gpus)])
27 | 
28 |         if self.world_size > 1 and not dist.is_initialized():
29 |             if self.rank == 0:
30 |                 logging.info("Initializing distributed process group")
31 |             backend = "gloo" if gpus is None else "nccl"
32 |             comm.init_process_group(backend, init_method="env://")
33 | 
34 |         if self.rank == 0:
35 |             logging.info("Preprocess training set")
36 |         if self.world_size > 1:
37 |             model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
38 |         if self.device.type == "cuda":
39 |             model = model.cuda(self.device)
40 | 
41 |         self.model = model
42 |         self.train_set = train_set
43 |         self.valid_set = valid_set
44 |         self.test_set = test_set
45 |         self.optimizer = optimizer
46 |         self.scheduler = scheduler
47 | 
48 |     def train(self, batch_per_epoch, smoothing, print_every):
49 |         if comm.get_rank() == 0:
50 |             logging.info('>>>>> Predictor: Training')
51 |         self.train_set.make_batches()
52 |         sampler = torch_data.DistributedSampler(self.train_set, self.world_size, self.rank)
53 |         dataloader = torch_data.DataLoader(self.train_set, 1, sampler=sampler, num_workers=self.num_worker)
54 |         batch_per_epoch = batch_per_epoch or len(dataloader)
55 |         model = self.model
56 |         if self.world_size > 1:
57 |             if self.device.type == "cuda":
58 |                 model = nn.parallel.DistributedDataParallel(model, device_ids=[self.device], find_unused_parameters=True)
59 |             else:
60 |                 model = nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
61 |         model.train()
62 | 
63 |         total_loss = 0.0
64 |         total_size = 0.0
65 | 
66 |         sampler.set_epoch(0)
67 | 
68 |         for batch_id, batch in enumerate(islice(dataloader, batch_per_epoch)):
69 |             all_h, all_r, all_t, target, edges_to_remove = batch
70 |             all_h = all_h.squeeze(0)
71 |             all_r = all_r.squeeze(0)
72 |             all_t = all_t.squeeze(0)
73 |             target = target.squeeze(0)
74 |             edges_to_remove = edges_to_remove.squeeze(0)
75 |             target_t = torch.nn.functional.one_hot(all_t, self.train_set.graph.entity_size)
76 | 
77 |             if self.device.type == "cuda":
78 |                 all_h = all_h.cuda(device=self.device)
79 |                 all_r = all_r.cuda(device=self.device)
80 |                 target = target.cuda(device=self.device)
81 |                 edges_to_remove = edges_to_remove.cuda(device=self.device)
82 |                 target_t = target_t.cuda(device=self.device)
83 | 
84 |             target = target * smoothing + target_t * (1 - smoothing)
85 | 
86 |             logits, mask = model(all_h, all_r, edges_to_remove)
87 |             if mask.sum().item() != 0:
88 |                 logits = (torch.softmax(logits, dim=1) + 1e-8).log()
89 |                 loss = -(logits[mask] * target[mask]).sum() / torch.clamp(target[mask].sum(), min=1)
90 |                 loss.backward()
91 | 
92 |                 self.optimizer.step()
93 |                 self.optimizer.zero_grad()
94 | 
95 |                 total_loss += loss.item()
96 |                 total_size += mask.sum().item()
97 | 
98 |             if (batch_id + 1) % print_every == 0:
99 |                 if comm.get_rank() == 0:
100 |                     logging.info('{} {} {:.6f} {:.1f}'.format(batch_id + 1, len(dataloader), total_loss / print_every, total_size / print_every))
101 |                 total_loss = 0.0
102 |                 total_size = 0.0
103 | 
104 |         if self.scheduler:
105 |             self.scheduler.step()
106 | 
107 |     @torch.no_grad()
108 |     def compute_H(self, print_every):
109 |         if comm.get_rank() == 0:
110 |             logging.info('>>>>> Predictor: Computing H scores of rules')
111 |         sampler = torch_data.DistributedSampler(self.train_set, self.world_size, self.rank)
112 |         dataloader = torch_data.DataLoader(self.train_set, 1, sampler=sampler, num_workers=self.num_worker)
113 |         model = self.model
114 | 
115 |         model.eval()
116 |         all_H_score = torch.zeros(model.num_rules, device=self.device)
117 |         for batch_id, batch in enumerate(dataloader):
118 |             all_h, all_r, all_t, target, edges_to_remove = batch
119 |             all_h = all_h.squeeze(0)
120 |             all_r = all_r.squeeze(0)
121 |             all_t = all_t.squeeze(0)
122 |             target = target.squeeze(0)
123 |             edges_to_remove = edges_to_remove.squeeze(0)
124 | 
125 |             if self.device.type == "cuda":
126 |                 all_h = all_h.cuda(device=self.device)
127 |                 all_r = all_r.cuda(device=self.device)
128 |                 target = target.cuda(device=self.device)
129 |                 edges_to_remove = edges_to_remove.cuda(device=self.device)
130 | 
131 |             H, index = model.compute_H(all_h, all_r, all_t, edges_to_remove)
132 |             if H is not None and index is not None:
133 |                 all_H_score[index] += H / len(model.graph.train_facts)
134 | 
135 |             if (batch_id + 1) % print_every == 0:
136 |                 if comm.get_rank() == 0:
137 |                     logging.info('{} {}'.format(batch_id + 1, len(dataloader)))
138 | 
139 |         if self.world_size > 1:
140 |             all_H_score = comm.stack(all_H_score)
141 |             all_H_score = all_H_score.sum(0)
142 | 
143 |         return all_H_score.data.cpu().numpy().tolist()
144 | 
145 |     @torch.no_grad()
146 |     def evaluate(self, split, expectation=True):
147 |         if comm.get_rank() == 0:
148 |             logging.info('>>>>> Predictor: Evaluating on {}'.format(split))
149 |         test_set = getattr(self, "%s_set" % split)
150 |         sampler = torch_data.DistributedSampler(test_set, self.world_size, self.rank)
151 |         dataloader = torch_data.DataLoader(test_set, 1, sampler=sampler, num_workers=self.num_worker)
152 |         model = self.model
153 | 
154 |         model.eval()
155 |         concat_logits = []
156 |         concat_all_h = []
157 |         concat_all_r = []
158 |         concat_all_t = []
159 |         concat_flag = []
160 |         concat_mask = []
161 |         for batch in dataloader:
162 |             all_h, all_r, all_t, flag = batch
163 |             all_h = all_h.squeeze(0)
164 |             all_r = all_r.squeeze(0)
165 |             all_t = all_t.squeeze(0)
166 |             flag = flag.squeeze(0)
167 |             if self.device.type == "cuda":
168 |                 all_h = all_h.cuda(device=self.device)
169 |                 all_r = all_r.cuda(device=self.device)
170 |                 all_t = all_t.cuda(device=self.device)
171 |                 flag = flag.cuda(device=self.device)
172 | 
173 |             logits, mask = model(all_h, all_r, None)
174 | 
175 |             concat_logits.append(logits)
176 |             concat_all_h.append(all_h)
177 |             concat_all_r.append(all_r)
178 |             concat_all_t.append(all_t)
179 |             concat_flag.append(flag)
180 |             concat_mask.append(mask)
181 | 
182 |         concat_logits = torch.cat(concat_logits, dim=0)
183 |         concat_all_h = torch.cat(concat_all_h, dim=0)
184 |         concat_all_r = torch.cat(concat_all_r, dim=0)
185 |         concat_all_t = torch.cat(concat_all_t, dim=0)
186 |         concat_flag = torch.cat(concat_flag, dim=0)
187 |         concat_mask = torch.cat(concat_mask, dim=0)
188 | 
189 |         ranks = []
190 |         for k in range(concat_all_t.size(0)):
191 |             h = concat_all_h[k]
192 |             r = concat_all_r[k]
193 |             t = concat_all_t[k]
194 |             if concat_mask[k, t].item() == True:
195 |                 val = concat_logits[k, t]
196 |                 L = (concat_logits[k][concat_flag[k]] > val).sum().item() + 1
197 |                 H = (concat_logits[k][concat_flag[k]] >= val).sum().item() + 2
198 |             else:
199 |                 L = 1
200 |                 H = test_set.graph.entity_size + 1
201 |             ranks += [[h, r, t, L, H]]
202 |         ranks = torch.tensor(ranks, dtype=torch.long, device=self.device)
203 | 
204 |         if self.world_size > 1:
205 |             ranks = comm.cat(ranks)
206 | 
207 |         query2LH = dict()
208 |         for h, r, t, L, H in ranks.data.cpu().numpy().tolist():
209 |             query2LH[(h, r, t)] = (L, H)
210 | 
211 |         hit1, hit3, hit10, mr, mrr = 0.0, 0.0, 0.0, 0.0, 0.0
212 |         for (L, H) in query2LH.values():
213 |             if expectation:
214 |                 for rank in range(L, H):
215 |                     if rank <= 1:
216 |                         hit1 += 1.0 / (H - L)
217 |                     if rank <= 3:
218 |                         hit3 += 1.0 / (H - L)
219 |                     if rank <= 10:
220 |                         hit10 += 1.0 / (H - L)
221 |                     mr += rank / (H - L)
222 |                     mrr += 1.0 / rank / (H - L)
223 |             else:
224 |                 rank = H - 1
225 |                 if rank <= 1:
226 |                     hit1 += 1
227 |                 if rank <= 3:
228 |                     hit3 += 1
229 |                 if rank <= 10:
230 |                     hit10 += 1
231 |                 mr += rank
232 |                 mrr += 1.0 / rank
233 | 
234 |         hit1 /= len(ranks)
235 |         hit3 /= len(ranks)
236 |         hit10 /= len(ranks)
237 |         mr /= len(ranks)
238 |         mrr /= len(ranks)
239 | 
240 |         if comm.get_rank() == 0:
241 |             logging.info('Data : {}'.format(len(query2LH)))
242 |             logging.info('Hit1 : {:.6f}'.format(hit1))
243 |             logging.info('Hit3 : {:.6f}'.format(hit3))
244 |             logging.info('Hit10: {:.6f}'.format(hit10))
245 |             logging.info('MR   : {:.6f}'.format(mr))
246 |             logging.info('MRR  : {:.6f}'.format(mrr))
247 | 
248 |         return mrr
249 | 
250 |     def load(self, checkpoint, load_optimizer=True):
251 |         """
252 |         Load a checkpoint from file.
253 |         Parameters:
254 |             checkpoint (file-like): checkpoint file
255 |             load_optimizer (bool, optional): load optimizer state or not
256 |         """
257 |         if comm.get_rank() == 0:
258 |             logging.info("Load checkpoint from %s" % checkpoint)
259 |         checkpoint = os.path.expanduser(checkpoint)
260 |         state = torch.load(checkpoint, map_location=self.device)
261 | 
262 |         self.model.load_state_dict(state["model"])
263 | 
264 |         if load_optimizer:
265 |             self.optimizer.load_state_dict(state["optimizer"])
266 |             for state in self.optimizer.state.values():
267 |                 for k, v in state.items():
268 |                     if isinstance(v, torch.Tensor):
269 |                         state[k] = v.to(self.device)
270 | 
271 |         comm.synchronize()
272 | 
273 |     def save(self, checkpoint):
274 |         """
275 |         Save checkpoint to file.
276 |         Parameters:
277 |             checkpoint (file-like): checkpoint file
278 |         """
279 |         if comm.get_rank() == 0:
280 |             logging.info("Save checkpoint to %s" % checkpoint)
281 |         checkpoint = os.path.expanduser(checkpoint)
282 |         if self.rank == 0:
283 |             state = {
284 |                 "model": self.model.state_dict(),
285 |                 "optimizer": self.optimizer.state_dict()
286 |             }
287 |             torch.save(state, checkpoint)
288 | 
289 |         comm.synchronize()
290 | 
291 | class TrainerGenerator(object):
292 | 
293 |     def __init__(self, model, gpu):
294 |         self.model = model
295 | 
296 |         if gpu is None:
297 |             self.device = torch.device("cpu")
298 |         else:
299 |             self.device = torch.device(gpu)
300 | 
301 |             model = model.cuda(self.device)
302 | 
303 |     def train(self, rule_set, num_epoch=10000, lr=1e-3, print_every=100, batch_size=512):
304 |         if comm.get_rank() == 0:
305 |             logging.info('>>>>> Generator: Training')
306 |         model = self.model
307 |         model.train()
308 | 
309 |         dataloader = torch_data.DataLoader(rule_set, batch_size, shuffle=True, collate_fn=RuleDataset.collate_fn)
310 |         iterator = Iterator(dataloader)
311 |         optimizer = torch.optim.Adam(model.parameters(), lr=lr)
312 | 
313 |         total_loss = 0.0
314 |         for epoch in range(num_epoch):
315 |             batch = next(iterator)
316 |             inputs, target, mask, weight = batch
317 |             hidden = self.zero_state(inputs.size(0))
318 | 
319 |             if self.device.type == "cuda":
320 |                 inputs = inputs.cuda(self.device)
321 |                 target = target.cuda(self.device)
322 |                 mask = mask.cuda(self.device)
323 |                 weight = weight.cuda(self.device)
324 | 
325 |             loss = model.loss(inputs, target, mask, weight, hidden)
326 |             loss.backward()
327 | 
328 |             optimizer.step()
329 |             optimizer.zero_grad()
330 | 
331 |             total_loss += loss.item()
332 | 
333 |             if (epoch + 1) % print_every == 0:
334 |                 if comm.get_rank() == 0:
335 |                     logging.info('{} {} {:.6f}'.format(epoch + 1, num_epoch, total_loss / print_every))
336 |                 total_loss = 0.0
337 | 
338 |     def zero_state(self, batch_size):
339 |         state_shape = (self.model.num_layers, batch_size, self.model.hidden_dim)
340 |         h0 = c0 = torch.zeros(*state_shape, requires_grad=False, device=self.device)
341 |         return (h0, c0)
342 | 
343 |     @torch.no_grad()
344 |     def log_probability(self, rules):
345 |         if rules == []:
346 |             return []
347 | 
348 |         model = self.model
349 |         model.eval()
350 | 
351 |         rules = [rule + [model.ending_idx] for rule in rules]
352 |         max_len = max([len(rule) for rule in rules])
353 |         for k in range(len(rules)):
354 |             rule_len = len(rules[k])
355 |             for i in range(max_len - rule_len):
356 |                 rules[k] += [model.padding_idx]
357 |         rules = torch.tensor(rules, dtype=torch.long, device=self.device)
358 |         inputs = rules[:, :-1]
359 |         target = rules[:, 1:]
360 |         n, l = target.size(0), target.size(1)
361 |         mask = (target != model.padding_idx)
362 |         hidden = self.zero_state(inputs.size(0))
363 |         logits, hidden = model(inputs, inputs[:, 0], hidden)
364 |         logits = torch.log_softmax(logits, -1)
365 |         logits = logits * mask.unsqueeze(-1)
366 |         target = (target * mask).unsqueeze(-1)
367 |         log_prob = torch.gather(logits, -1, target).squeeze(-1) * mask
368 |         log_prob = log_prob.sum(-1)
369 |         return log_prob.data.cpu().numpy().tolist()
370 | 
371 |     @torch.no_grad()
372 |     def next_relation_log_probability(self, seq, temperature):
373 |         model = self.model
374 |         model.eval()
375 | 
376 |         inputs = torch.tensor([seq], dtype=torch.long, device=self.device)
377 |         relation = torch.tensor([seq[0]], dtype=torch.long, device=self.device)
378 |         hidden = self.zero_state(1)
379 |         logits, hidden = model(inputs, relation, hidden)
380 |         log_prob = torch.log_softmax(logits[0, -1, :] / temperature, dim=-1).data.cpu().numpy().tolist()
381 |         return log_prob
382 | 
383 |     @torch.no_grad()
384 |     def beam_search(self, num_samples, max_len, temperature=0.2):
385 |         if comm.get_rank() == 0:
386 |             logging.info('>>>>> Generator: Rule generation with beam search')
387 |         model = self.model
388 |         model.eval()
389 | 
390 |         max_len += 1
391 |         all_rules = []
392 |         for relation in range(model.num_relations):
393 |             found_rules = []
394 |             prev_rules = [[[relation], 0]]
395 |             for k in range(max_len):
396 |                 current_rules = list()
397 |                 for _i, (rule, score) in enumerate(prev_rules):
398 |                     assert rule[-1] != model.ending_idx
399 |                     log_prob = self.next_relation_log_probability(rule, temperature)
400 |                     for i in (range(model.label_size) if (k + 1) != max_len else [model.ending_idx]):
401 |                         new_rule = rule + [i]
402 |                         new_score = score + log_prob[i]
403 |                         (current_rules if i != model.ending_idx else found_rules).append((new_rule, new_score))
404 | 
405 |                 prev_rules = sorted(current_rules, key=lambda x: x[1], reverse=True)[:num_samples]
406 |                 found_rules = sorted(found_rules, key=lambda x: x[1], reverse=True)[:num_samples]
407 | 
408 |             ret = [rule[0:-1] + [score] for rule, score in found_rules]
409 |             all_rules += ret
410 |         return all_rules
411 | 
412 |     @torch.no_grad()
413 |     def sample(self, num_samples, max_len, temperature=1.0):
414 |         if comm.get_rank() == 0:
415 |             logging.info('>>>>> Generator: Rule generation with sampling')
416 |         model = self.model
417 |         model.eval()
418 | 
419 |         all_rules = []
420 |         for relation in range(model.num_relations):
421 |             rules = torch.zeros([num_samples, max_len + 1], dtype=torch.long, device=self.device) + model.ending_idx
422 |             log_probabilities = torch.zeros([num_samples, max_len + 1], device=self.device)
423 |             head = torch.tensor([relation for k in range(num_samples)], dtype=torch.long, device=self.device)
424 | 
425 |             rules[:, 0] = relation
426 |             hidden = self.zero_state(num_samples)
427 | 
428 |             for pst in range(max_len):
429 |                 inputs = rules[:, pst].unsqueeze(-1)
430 |                 logits, hidden = model(inputs, head, hidden)
431 |                 logits /= temperature
432 |                 log_probability = torch.log_softmax(logits.squeeze(1), dim=-1)
433 |                 probability = torch.softmax(logits.squeeze(1), dim=-1)
434 |                 sample = torch.multinomial(probability, 1)
435 |                 log_probability = log_probability.gather(1, sample)
436 | 
437 |                 mask = (rules[:, pst] != model.ending_idx)
438 | 
439 |                 rules[mask, pst + 1] = sample.squeeze(-1)[mask]
440 |                 log_probabilities[mask, pst + 1] = log_probability.squeeze(-1)[mask]
441 | 
442 |             length = (rules != model.ending_idx).sum(-1).unsqueeze(-1) - 1
443 |             formatted_rules = torch.cat([length, rules], dim=1)
444 | 
445 |             log_probabilities = log_probabilities.sum(-1)
446 | 
447 |             formatted_rules = formatted_rules.data.cpu().numpy().tolist()
448 |             log_probabilities = log_probabilities.data.cpu().numpy().tolist()
449 |             for k in range(num_samples):
450 |                 length = formatted_rules[k][0]
451 |                 formatted_rules[k] = formatted_rules[k][1: 2 + length] + [log_probabilities[k]]
452 | 
453 |             rule_set = set([tuple(rule) for rule in formatted_rules])
454 |             formatted_rules = [list(rule) for rule in rule_set]
455 | 
456 |             all_rules += formatted_rules
457 | 
458 |         return all_rules
459 | 
460 |     def load(self, checkpoint):
461 |         """
462 |         Load a checkpoint from file.
463 |         Parameters:
464 |             checkpoint (file-like): checkpoint file
465 | 
466 |         """
467 |         if comm.get_rank() == 0:
468 |             logging.info("Load checkpoint from %s" % checkpoint)
469 |         checkpoint = os.path.expanduser(checkpoint)
470 |         state = torch.load(checkpoint, map_location=self.device)
471 |         self.model.load_state_dict(state["model"])
472 | 
473 |     def save(self, checkpoint):
474 |         """
475 |         Save checkpoint to file.
476 |         Parameters:
477 |             checkpoint (file-like): checkpoint file
478 |         """
479 |         if comm.get_rank() == 0:
480 |             logging.info("Save checkpoint to %s" % checkpoint)
481 |         checkpoint = os.path.expanduser(checkpoint)
482 |         state = {
483 |             "model": self.model.state_dict()
484 |         }
485 |         torch.save(state, checkpoint)
486 | 
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import logging
4 | import argparse
5 | import random
6 | import json
7 | import yaml
8 | import easydict, jinja2
9 | import numpy as np
10 | import torch
11 | 
12 | def load_config(cfg_file):
13 |     with open(cfg_file, "r") as fin:
14 |         raw_text = fin.read()
15 | 
16 |     if "---" in raw_text:
17 |         configs = []
18 |         grid, template = raw_text.split("---")
19 |         grid = yaml.safe_load(grid)
20 |         template = jinja2.Template(template)
21 |         for hyperparam in meshgrid(grid):  # NOTE: a meshgrid helper (grid expansion) is not defined in this repository
22 |             config = easydict.EasyDict(yaml.safe_load(template.render(hyperparam)))
23 |             configs.append(config)
24 |     else:
25 |         configs = [easydict.EasyDict(yaml.safe_load(raw_text))]
26 | 
27 |     return configs
28 | 
29 | def save_config(cfg, path):
30 |     with open(os.path.join(path, 'config.yaml'), 'w') as fo:
31 |         yaml.dump(dict(cfg), fo)
32 | 
33 | def set_seed(seed):
34 |     torch.manual_seed(seed)
35 |     np.random.seed(seed)
36 |     random.seed(seed)
37 |     torch.cuda.manual_seed(seed)
38 | 
39 | def save_model(model, optim, args):
40 |     argparse_dict = vars(args)
41 |     with open(os.path.join(args.save_path, 'config.json'), 'w') as fjson:
42 |         json.dump(argparse_dict, fjson)
43 | 
44 |     params = {
45 |         'model': model.state_dict(),
46 |         'optim': optim.state_dict()
47 |     }
48 | 
49 |     torch.save(params, os.path.join(args.save_path, 'checkpoint'))
50 | 
51 | def load_model(model, optim, args):
52 |     checkpoint = torch.load(args.load_path)
53 |     model.load_state_dict(checkpoint['model'])
54 |     optim.load_state_dict(checkpoint['optim'])
55 | 
56 | def set_logger(save_path):
57 |     log_file = os.path.join(save_path, 'run.log')
58 | 
59 |     logging.basicConfig(
60 |         format='%(asctime)s %(levelname)-8s %(message)s',
61 |         level=logging.INFO,
62 |         datefmt='%Y-%m-%d %H:%M:%S',
63 |         filename=log_file,
64 |         filemode='w'
65 |     )
66 |     console = logging.StreamHandler()
67 |     console.setLevel(logging.INFO)
68 |     formatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s')
69 |     console.setFormatter(formatter)
70 |     logging.getLogger('').addHandler(console)
--------------------------------------------------------------------------------
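As an editorial aside (not part of the repository), the tie-aware scoring in `TrainerPredictor.evaluate` is worth isolating: each test query yields a rank interval `[L, H)` (L counts strictly better candidates plus one, H counts at-least-as-good candidates plus two), and with `expectation=True` the metrics are averaged over a rank assumed uniform in that interval. A minimal pure-Python sketch of that averaging (the function name `expected_metrics` is illustrative, not from the codebase):

```python
def expected_metrics(lh_pairs):
    """Expected Hits@k / MR / MRR when each query's true rank is
    uniformly distributed over the integer interval [L, H)."""
    hit1 = hit3 = hit10 = mr = mrr = 0.0
    for L, H in lh_pairs:
        n = H - L  # number of tied rank positions for this query
        for rank in range(L, H):
            if rank <= 1:
                hit1 += 1.0 / n
            if rank <= 3:
                hit3 += 1.0 / n
            if rank <= 10:
                hit10 += 1.0 / n
            mr += rank / n
            mrr += 1.0 / rank / n
    m = len(lh_pairs)
    return {"Hit1": hit1 / m, "Hit3": hit3 / m, "Hit10": hit10 / m,
            "MR": mr / m, "MRR": mrr / m}

# A query ranked strictly first: [L, H) = [1, 2).
print(expected_metrics([(1, 2)])["MRR"])  # 1.0
# A two-way tie at the top: rank is 1 or 2 with equal probability.
print(expected_metrics([(1, 3)])["MRR"])  # (1.0 + 0.5) / 2 = 0.75
```

This is why ties between equally scored tail entities do not arbitrarily inflate or deflate Hits@k in the reported numbers.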