├── requirements.txt
├── BertSAGE
│   ├── infer.py
│   ├── train.py
│   ├── model.py
│   └── dataloader.py
├── utils
│   ├── utils.py
│   └── atomic_utils.py
├── README.md
└── preproc_atomic
    ├── match_head.py
    └── match_tail.py
/requirements.txt:
--------------------------------------------------------------------------------
torch==1.7.1
transformers==3.4.0
networkx
numpy
pandas
tqdm
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
import math

def chunks_list(l, group_number):
    """Split list `l` into `group_number` roughly equal-sized chunks."""
    group_size = math.ceil(len(l) / group_number)
    final_data_groups = list()
    for i in range(0, len(l), group_size):
        final_data_groups.append(l[i:min(i+group_size, len(l))])
    return final_data_groups
--------------------------------------------------------------------------------
/utils/atomic_utils.py:
--------------------------------------------------------------------------------
ALL_SUBJS = ["person", "man", "woman",
             "someone", "somebody", "i", "he", "she", "you"]
SUBJ2POSS = {"i": "my", "he": "his", "she": "her", "you": "your"}

# Variables and rules
SUBJS = ["person", "man", "woman",
         "someone", "somebody", "i", "he", "she", "you"]
O_SUBJS = ["i", "you", "he", "she"]
ATOMIC_SUBJS = ["PersonX", "PersonY", "PersonZ"]

stative_rules = {
    "in": [], "out": [],
    "both_dir": ["Synchronous", "Reason", "Result", "Condition",
                 "Conjunction", "Restatement", "Alternative"]
}
cause_agent_rules = {
    "out": ["Succession", "Condition", "Reason"],
    "in": ["Precedence", "Result"],
    "both_dir": ["Synchronous", "Conjunction"],
}
effect_agent_rules = {
    "out": ["Precedence", "Result"],
    "in": ["Succession", "Condition", "Reason"],
    "both_dir": ["Synchronous", "Conjunction"],
}
# This rule requires the subjects to be different
effect_theme_rules = {
    "out": ["Precedence", "Result"],
    "in": ["Succession", "Condition", "Reason"],
    "both_dir": ["Synchronous", "Conjunction"],
}

ASER_rules_dict = {
    "stative": stative_rules,
    "cause_agent": cause_agent_rules,
    "effect_agent": effect_agent_rules,
    "effect_theme": effect_theme_rules,
}

# functions:
def get_ppn_substitue_dict(head_split):
    """
    input (list): the split result of a head

    output: a dict that maps personal pronouns in
      head_split to subjects in ATOMIC_SUBJS
    """
    atomic_head_pp_list = []
    for token in head_split:
        if token in SUBJS:
            if not token in atomic_head_pp_list:
                atomic_head_pp_list.append(token)
    head_pp2atomic_pp = {}
    cnt = 0
    for pp in atomic_head_pp_list:
        head_pp2atomic_pp[pp] = ATOMIC_SUBJS[cnt]
        cnt += 1
        if cnt >= len(ATOMIC_SUBJS):
            break
    return head_pp2atomic_pp

def filter_event(event):
    """
    Function for filtering eventualities.
    input (str): the string of an eventuality

    output: whether to filter it out or not.
    """
    tokens = event.split()
    # if tokens[-1] in SUBJS and tokens[-2] == "tell":
    #     return True
    # if tokens[-1] in ["know", "say", "think"]:
    #     return True
    # filter eventualities with no more than 2 tokens
    if len(tokens) <= 2:
        return True
    # filter hot verbs
    if any(kw in tokens for kw in ["say", "do", "know", "tell", "think"]):
        return True
    # filter out errors potentially caused by parser mistakes
    if tokens[0] in ["who", "what", "when", "where", "how", "why", "which", "whom", "whose"]:
        return True
    return False
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DISCOS-commonsense

This is the GitHub repo for The Web Conference (WWW) 2021 paper [DISCOS: Bridging the Gap between Discourse Knowledge and Commonsense Knowledge](https://arxiv.org/abs/2101.00154).

Check out the follow-up work, a benchmark that evaluates the performance of transforming discourse knowledge into commonsense knowledge: [(EMNLP 2021, CSKB Population) Benchmarking Commonsense Knowledge Base Population with an Effective Evaluation Dataset](https://arxiv.org/abs/2109.07679), and the code [repository](https://github.com/HKUST-KnowComp/CSKB-Population).

### How to train the Commonsense Knowledge Graph Population (CKGP) model

Below are the instructions for training the CKGP model. We use the filtered graph introduced in Section 5.1.1 for this experiment.

First, git clone this repo, and then download the prepared aligned graph from [here](https://hkustconnect-my.sharepoint.com/:f:/g/personal/tfangaa_connect_ust_hk/EqYM_lq9gl1DhJu6HnezBvYBzuOfk60iDhg_zCTq9gZrLw?e=18OxwY). The `data/graph_cache` folder contains the data that can be directly used for training and testing. The `data/graph_raw_data` folder contains the graph file obtained after aligning ATOMIC and ASER, with pre-defined negative edges for the CKGP task. The `data/infer_candidates` folder contains the candidate (h, r, t) tuples to be scored by our BertSAGE model.

Next, install the dependencies. The recommended Python version is 3.8+.

`pip install -r requirements.txt`

Note that loading the files in `data/graph_cache` requires the same dependency versions as pinned in `requirements.txt`; e.g., the `transformers` package must be version 3.4.0.

Next, train the `BertSAGE` model. For example, here is the command to train on the `oReact` relation:

```
python -u BertSAGE/train.py --model graphsage \
    --load_edge_types ASER \
    --neg_prop 1 \
    --graph_cach_path data/graph_cache/neg_{}_{}_{}.pickle \
    --negative_sample prepared_neg \
    --file_path data/graph_raw_data/G_aser_oReact_1hop_thresh_100_neg_other_20_inv_10.pickle
```

For other relations, you can find the corresponding `.pickle` files in the `data` folder.
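
If you only want to inspect a prepared graph cache, it can be loaded directly with `pickle`. Below is a minimal sketch (not part of the training pipeline): the cache file name is the one produced by the `oReact` command above, and unpickling requires the pinned `torch`/`transformers` versions since the cache stores the pickled `GraphDataset` object written by `BertSAGE/train.py`:

```python
import pickle

# The cache is the pickled GraphDataset object written by BertSAGE/train.py.
cache_path = "data/graph_cache/neg_prepared_neg_ASER_G_aser_oReact_1hop_thresh_100_neg_other_20_inv_10.pickle"
with open(cache_path, "rb") as f:
    data_loader = pickle.load(f)

# Each entry of the adjacency list holds the training-graph neighbors of one node.
print("number of nodes:", len(data_loader.get_adj_list()))
```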

For the inference part, after training you can run:

```
python -u BertSAGE/infer.py --gpu 0 --model graphsage \
    --model_path models/G_aser_oReact_1hop_thresh_100_neg_other_20_inv_10/graphsage_best_bert_bs64_opt_SGD_lr0.01_decay0.8_500_layer1_neighnum_4_graph_ASER_acc.pth \
    --infer_path data/infer_candidates/G_aser_oReact_1hop_thresh_100_neg_other_20_inv_10.npy \
    --graph_cach_path data/graph_cache/neg_prepared_neg_ASER_G_aser_oReact_1hop_thresh_100_neg_other_20_inv_10.pickle
```


### The Acquired Knowledge Graph DISCOS-ATOMIC

By populating the knowledge in ATOMIC onto the whole ASER graph, we acquire a large-scale ATOMIC-like knowledge graph, keeping the tuples scored above 0.5 by BertSAGE. We also present the acquisition results of DISCOS under the COMET setting, i.e., generating t given h and r. The new knowledge graph can be downloaded [here](https://hkustconnect-my.sharepoint.com/:f:/g/personal/tfangaa_connect_ust_hk/ElHMMtHsCwZLg-AdP8ZdJT8BCBwTOyAOil1XLt4EfPYWUg?e=49u0i3).

The 3.4M if-then knowledge tuples are populated using the whole ASER-core graph, without neighbor filtering. You may find the processed training graph and inference candidates [here](https://hkustconnect-my.sharepoint.com/:f:/g/personal/tfangaa_connect_ust_hk/EmC5tdRCmQlMrfwBHrVHYE4B5_UhIfqL1uxNSNofLPMYQQ?e=dkQObG).
--------------------------------------------------------------------------------
/preproc_atomic/match_head.py:
--------------------------------------------------------------------------------
import sys
sys.path.append('../')
import math
import numpy as np
from tqdm import tqdm
from itertools import chain
from aser.extract.eventuality_extractor import SeedRuleEventualityExtractor
from itertools import permutations, combinations_with_replacement
from multiprocessing import Pool
from utils.atomic_utils import ALL_SUBJS, SUBJ2POSS

def instantiate_ppn(line):
    strs = e_extractor.parse_text(line)
    if len(strs) > 0:
        strs = strs[0]['tokens']
    else:
        return []
    pp_index = []
    wildcard_index = []
    for i, word in enumerate(strs):
        if word in ["PersonX", "PersonY", "PersonZ", "alex", "bob", "she", "he", "i", "you"]:
            pp_index.append(i)
        # Deprecate replacing WILDCARD. This will be handled independently
        # elif word in ["WILDCARD", "something"]:
        #     wildcard_index.append(i)
    # permutations of all possible substitutions
    perm_pp = list(combinations_with_replacement(ALL_SUBJS, len(pp_index)))
    perm_wildcard = list(combinations_with_replacement(['something', 'thing'], len(wildcard_index)))
    all_perms = [list(tmp_a)+list(tmp_b) for tmp_a in perm_pp for tmp_b in perm_wildcard]
    all_index = pp_index + wildcard_index

    # deal with possessive cases
    modified_idx = []
    if "'s" in strs:
        for idx in pp_index:
            if strs[min(idx + 1, len(strs)-1)] == "'s":
                modified_idx.append(idx)
    for perm in all_perms:
        # deal with the possessive case
        if len(modified_idx) == 0:
            # if none of the PPs is followed by "'s", then just replace the heads
            yield ' '.join([strs[i] if not i in all_index else perm[all_index.index(i)] for i in range(len(strs))])
        else:
            # else, replace the PersonX's with my, her, his, etc.
            new_strs = [strs[i] if not i in all_index else perm[all_index.index(i)] for i in range(len(strs))]
            for idx in modified_idx:
                if new_strs[idx] in SUBJ2POSS:
                    new_strs[idx] = SUBJ2POSS[new_strs[idx]]
                    new_strs[idx+1] = "\\REMOVE"
            while "\\REMOVE" in new_strs:
                new_strs.remove("\\REMOVE")
            yield ' '.join(new_strs)

def unfold_parse_results(e):
    # return the words of the extractor results
    if len(e) == 0:
        return ""
    if len(e[0]) == 0:
        return ""
    return " ".join(e[0][0].words)

def extract(ATOMIC_lines, idx):
    extracted_event_list = [[] for _ in range(len(ATOMIC_lines))]
    for i in tqdm(range(idx, len(ATOMIC_lines), num_thread)):
        line = ATOMIC_lines[i]
        possible_heads = instantiate_ppn(line)
        all_head_words = [unfold_parse_results(e_extractor.extract_from_text(tmp_text)) \
                          for tmp_text in possible_heads]
        extracted_event_list[i] = all_head_words
    return extracted_event_list

# main
stanford_path = "stanford-corenlp-full/"
e_extractor = SeedRuleEventualityExtractor(
    corenlp_path=stanford_path,
    corenlp_port=13000)

ATOMIC_path = "all_agg_event.txt"
ATOMIC_lines = open(ATOMIC_path).readlines()

num_thread = 5
# the parser supports at most 5 concurrent threads
workers = Pool(num_thread)
all_results = []
for i in range(num_thread):
    tmp_result = workers.apply_async(
        extract,
        args=(ATOMIC_lines, i))
    all_results.append(tmp_result)

workers.close()
workers.join()

all_results = [tmp_result.get() for tmp_result in all_results]
all_results = [list(chain(*item)) for item in zip(*all_results)]

np.save('ASER-format-words/ATOMIC_head_words_withpersonz', all_results)
--------------------------------------------------------------------------------
/BertSAGE/infer.py:
--------------------------------------------------------------------------------
import warnings
warnings.filterwarnings("ignore")
import torch
import os
from dataloader import *
from model import *
import argparse
import pickle
from tqdm import tqdm

parser = argparse.ArgumentParser()

## Required parameters
parser.add_argument("--gpu", default='0', type=str, required=False,
                    help="choose which gpu to use")
parser.add_argument("--model", default='simple', type=str, required=False,
                    choices=["graphsage", "simple"],
                    help="choose model")
parser.add_argument("--model_path", default='', type=str, required=True,
                    help="model path")
parser.add_argument("--encoder", default='bert', type=str, required=False,
                    choices=["bert", "roberta"],
                    help="choose encoder")
parser.add_argument("--infer_path", default='', type=str, required=True,
                    help="npy file to run inference on")
parser.add_argument("--graph_cach_path", default="graph_cache/.pickle",
                    type=str, required=False,
                    help="path of graph cache")
parser.add_argument("--num_layers", default=1, type=int, required=False,
                    help="number of graphsage layers")
parser.add_argument("--num_neighbor_samples", default=4, type=int, required=False,
                    help="num neighbor samples in GraphSAGE")

args = parser.parse_args()
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
test_batch_size = 128

if not os.path.exists("preds"):
    os.mkdir("preds")

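# Unlike train.py, the cache path is used as-is here: pass the concrete cache
# file that train.py wrote (neg_<negative_sample>_<load_edge_types>_<graph_basename>.pickle),
# not a path template.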
# graph_cache = args.graph_cach_path.format(
#     args.negative_sample,
#     args.load_edge_types,
#     os.path.basename(args.file_path).split(".")[0])
graph_cache = args.graph_cach_path

if args.model == "simple":
    with open(graph_cache, "rb") as reader:
        graph_dataset = pickle.load(reader)
    print("after loading graph cache from", graph_cache)
    data_loader = InferenceSimpleDataset(args.infer_path, device, args.encoder, graph_dataset)
elif args.model == "graphsage":
    with open(graph_cache, "rb") as reader:
        graph_dataset = pickle.load(reader)
    print("after loading graph cache from", graph_cache)
    data_loader = InferenceGraphDataset(args.infer_path, device, args.encoder, graph_dataset)

if args.model == "simple":
    model = SimpleClassifier(encoder=args.encoder,
                             adj_lists=None,
                             nodes_tokenized=data_loader.get_nodes_tokenized(),
                             device=device,
                             )
elif args.model == 'graphsage':
    model = LinkPrediction(encoder=args.encoder,
                           adj_lists=data_loader.get_adj_list(),
                           nodes_tokenized=data_loader.get_nodes_tokenized(),
                           device=device,
                           num_layers=args.num_layers,
                           num_neighbor_samples=args.num_neighbor_samples,
                           )

model.load_state_dict(torch.load(args.model_path))
model.eval()

def infer(data_loader, model):
    def infer_mode(mode):
        all_predictions = []
        all_values = []
        all_hids = []
        for batch in tqdm(data_loader.get_batch(batch_size=test_batch_size, mode=mode)):
            b_s, _ = batch.shape  # batch_size, 2+1
            all_nodes = batch[:, :2].reshape([-1])
            hids = batch[:, 2].tolist()

            logits = model(all_nodes, b_s)  # (batch_size, 2)

            logits = torch.softmax(logits, dim=1)
            values = logits[:, 1]
            _, predicted = torch.max(logits, dim=1)
            predicted = predicted.tolist()
            values = values.tolist()
            all_predictions.extend(predicted)
            all_values.extend(values)
            all_hids.extend(hids)
        return all_predictions, all_values, all_hids

    with torch.no_grad():
        all_predictions = {}
        all_values = {}
        all_hids = {}
        for mode in ["head", "tail", "new"]:
            all_predictions[mode], all_values[mode], all_hids[mode] = infer_mode(mode)
    return all_predictions, all_values, all_hids

preds, vals, hids = infer(data_loader, model)
for mode in ["head", "tail", "new"]:
    print(mode, "num 1:", sum(np.array(preds[mode])==1), sum(np.array(preds[mode])==1)/len(preds[mode]))
# save the plausible predictions: (head, tail) pair, positive-class score, and hid
plausible_knowledge = {}
for mode in ["head", "tail", "new"]:
    plausible_knowledge[mode] = []
    for i, (p, v, hid) in enumerate(zip(preds[mode], vals[mode], hids[mode])):
        plausible_knowledge[mode].append((data_loader.data[mode][i][:2], v, hid))
np.save("preds/"+os.path.basename(args.infer_path).split(".")[0]+"_preds"+"_"+args.model, plausible_knowledge)
--------------------------------------------------------------------------------
/BertSAGE/train.py:
--------------------------------------------------------------------------------
import warnings
warnings.filterwarnings("ignore")
import torch
import os
from dataloader import *
from model import *
import argparse
import pickle

parser = argparse.ArgumentParser()

## Required parameters
parser.add_argument("--gpu", default='0', type=str, required=False,
                    help="choose which gpu to use")
parser.add_argument("--model", default='simple', type=str, required=False,
                    choices=["graphsage", "simple"],
                    help="choose model")
parser.add_argument("--encoder", default='bert', type=str, required=False,
                    choices=["bert", "roberta"],
                    help="choose encoder")
parser.add_argument("--num_layers", default=1, type=int, required=False,
                    help="number of graphsage layers")
parser.add_argument("--lr", default=0.01, type=float, required=False,
                    help="learning rate")
parser.add_argument("--lrdecay", default=0.8, type=float, required=False,
                    help="multiplicative learning rate decay, applied every --decay_every steps")
parser.add_argument("--decay_every", default=500, type=int, required=False,
                    help="decay the learning rate every x steps")
parser.add_argument("--test_every", default=250, type=int, required=False,
                    help="show test result every x steps")
parser.add_argument("--batch_size", default=64, type=int, required=False,
                    help="batch size")
parser.add_argument("--epochs", default=3, type=int, required=False,
                    help="number of training epochs")
parser.add_argument("--num_neighbor_samples", default=4, type=int, required=False,
                    help="num neighbor samples in GraphSAGE")
parser.add_argument("--load_edge_types", default='ATOMIC', type=str, required=False,
                    choices=["ATOMIC", "ASER", "ATOMIC+ASER"],
                    help="load what edges to data_loader.adj_lists")
parser.add_argument("--graph_cach_path", default="graph_cache/neg_{}_{}_{}.pickle",
                    type=str, required=False,
                    help="path template of the graph cache")
parser.add_argument("--optimizer", default='SGD', type=str, required=False,
                    choices=["SGD", "ADAM"],
                    help="optimizer to be used")
parser.add_argument("--negative_sample", default='from_all', type=str, required=False,
                    choices=["prepared_neg", "from_all", "fix_head"],
                    help="negative sampling method")
parser.add_argument("--file_path", default='', type=str, required=True,
                    help="load training graph pickle")
parser.add_argument("--metric", default='acc', type=str, required=False,
                    choices=["f1", "acc"],
                    help="evaluation metric, either f1 or acc")
parser.add_argument("--neg_prop", default=1.0, type=float, required=False,
                    help="the proportion of negative samples: num_neg/num_pos")

args = parser.parse_args()
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
lr = args.lr
show_step = args.test_every
batch_size = args.batch_size
num_epochs = args.epochs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
test_batch_size = 64
neg_prop = args.neg_prop

file_path = args.file_path

graph_cache = args.graph_cach_path.format(args.negative_sample, args.load_edge_types, os.path.basename(file_path).split(".")[0])
if not os.path.exists("models"):
    os.mkdir("models")
model_dir = "models/"+os.path.basename(file_path).split(".")[0]
if not os.path.exists(model_dir):
    os.mkdir(model_dir)
if args.model == "simple":
    model_save_path = os.path.join(model_dir, '{}_best_{}_bs{}_opt_{}_lr{}_decay{}_{}_{}.pth'\
        .format(args.model, args.encoder, batch_size, args.optimizer,
                args.lr, args.lrdecay, args.decay_every, args.metric))
elif args.model == "graphsage":
    model_save_path = os.path.join(model_dir, '{}_best_{}_bs{}_opt_{}_lr{}_decay{}_{}_layer{}_neighnum_{}_graph_{}_{}.pth'\
        .format(args.model, args.encoder, batch_size, args.optimizer, args.lr,
                args.lrdecay, args.decay_every, args.num_layers,
                args.num_neighbor_samples, args.load_edge_types, args.metric))

print(graph_cache)
if not os.path.exists(graph_cache):
    data_loader = GraphDataset(file_path, device, args.encoder,
                               negative_sample=args.negative_sample, load_edge_types=args.load_edge_types,
                               neg_prop=neg_prop)
    with open(graph_cache, "wb") as writer:
        pickle.dump(data_loader, writer, pickle.HIGHEST_PROTOCOL)
    print("after dumping graph cache to", graph_cache)
else:
    with open(graph_cache, "rb") as reader:
        data_loader = pickle.load(reader)
    print("after loading graph cache from", graph_cache)


if args.model == "simple":
    model = SimpleClassifier(encoder=args.encoder,
                             adj_lists=data_loader.get_adj_list(),
                             nodes_tokenized=data_loader.get_nodes_tokenized(),
                             device=device,
                             )
elif args.model == 'graphsage':
    model = LinkPrediction(encoder=args.encoder,
                           adj_lists=data_loader.get_adj_list(),
                           nodes_tokenized=data_loader.get_nodes_tokenized(),
                           device=device,
                           num_layers=args.num_layers,
                           num_neighbor_samples=args.num_neighbor_samples,
                           )

criterion = torch.nn.CrossEntropyLoss()
if args.optimizer == "SGD":
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
elif args.optimizer == "ADAM":
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
optimizer.zero_grad()

step = 0

best_valid_acc = 0
best_test_acc = 0
best_valid_pos_acc = 0
best_test_pos_acc = 0

my_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=args.lrdecay)

for epoch in range(num_epochs):
    for batch in data_loader.get_batch(batch_size=batch_size, mode="train"):
        # torch.cuda.empty_cache()
        step += 1
        if step % args.decay_every == 0:
            # lr = lr * args.lrdecay
            # optimizer = torch.optim.SGD(model.parameters(), lr=lr)
            my_lr_scheduler.step()
        # batch list((node_id1, node_id2))
        edges, labels = batch
        b_s, _ = edges.shape  # batch_size, 2
        all_nodes = edges.reshape([-1])

        logits = model(all_nodes, b_s)
        loss = criterion(logits, labels)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        model.zero_grad()

        # evaluate
        if step % show_step == 0:
            val_acc, val_pos_acc = eval(data_loader, model, test_batch_size, criterion, "valid", args.metric)
            test_acc, test_pos_acc = eval(data_loader, model, test_batch_size, criterion, "test", args.metric)
            if val_acc > best_valid_acc:
                best_valid_acc = val_acc
                best_test_acc = test_acc
                best_valid_pos_acc = val_pos_acc
                best_test_pos_acc = test_pos_acc

                torch.save(model.state_dict(), model_save_path)

            print(args.metric, ": epoch {}, step {}, current valid: {}, "
                  "current test: {}, current valid pos: {}, "
                  "current test pos: {},".format(epoch, step, val_acc, test_acc, val_pos_acc, test_pos_acc))
            print(args.metric, ": current best val: {}, test: {}, "
                  "current best val pos: {}, test: {}".format(best_valid_acc, best_test_acc, best_valid_pos_acc, best_test_pos_acc))
--------------------------------------------------------------------------------
/preproc_atomic/match_tail.py:
--------------------------------------------------------------------------------
import sys
sys.path.append('../')
import math
import json
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm
from aser.extract.eventuality_extractor import SeedRuleEventualityExtractor
from itertools import permutations, combinations_with_replacement, chain
from multiprocessing import Pool
from utils.atomic_utils import ALL_SUBJS, SUBJ2POSS

def instantiate_ppn(line):
    # slightly different from the version in match_head.py:
    # 1. doesn't parse the sentence again
    # 2. returns the original sentence if there's no PersonX/Y/Z
    strs = line.split()
    if len(strs) == 0:
        return []
    # strs = e_extractor.parse_text(line)
    # if len(strs) > 0:
    #     strs = strs[0]['tokens']
    # else:
    #     return []
    pp_index = []
    wildcard_index = []
    for i, word in enumerate(strs):
        if word in ["PersonX", "PersonY", "PersonZ"]:
            pp_index.append(i)
        # Deprecate replacing WILDCARD. This will be handled independently
        # elif word in ["WILDCARD", "something"]:
        #     wildcard_index.append(i)
    # permutations of all possible substitutions
    perm_pp = list(combinations_with_replacement(ALL_SUBJS, len(pp_index)))
    perm_wildcard = list(combinations_with_replacement(['something', 'thing'], len(wildcard_index)))
    all_perms = [list(tmp_a)+list(tmp_b) for tmp_a in perm_pp for tmp_b in perm_wildcard]
    all_index = pp_index + wildcard_index
    if len(all_index) == 0:
        yield line
    else:
        modified_idx = []
        if "'s" in strs:
            for idx in pp_index:
                if strs[min(idx + 1, len(strs)-1)] == "'s":
                    modified_idx.append(idx)
        for perm in all_perms:
            # deal with the possessive case
            if len(modified_idx) == 0:
                # if none of the PPs is followed by "'s", then just replace the heads
                yield ' '.join([strs[i] if not i in all_index else perm[all_index.index(i)] for i in range(len(strs))])
            else:
                new_strs = [strs[i] if not i in all_index else perm[all_index.index(i)] for i in range(len(strs))]
                for idx in modified_idx:
                    if new_strs[idx] in SUBJ2POSS:
                        new_strs[idx] = SUBJ2POSS[new_strs[idx]]
                        new_strs[idx+1] = "\\REMOVE"
                while "\\REMOVE" in new_strs:
                    new_strs.remove("\\REMOVE")
                yield ' '.join(new_strs)

def contain_subject(dependencies):
    return any(dep in [item[1] for item in dependencies] for dep in ['nsubj', 'nsubjpass'])

def fill_sentence(sent, r, has_subject):
    if r in ['oEffect', 'xEffect']:
        # + subject
        if has_subject:
            return [sent]
        else:
            return [' '.join([subj, sent]) for subj in ALL_SUBJS]
    elif r in ['oReact', 'xReact']:
        # + subject / + subject is
        if has_subject:
            return [sent]
        else:
            return [' '.join([subj, sent]) for subj in ALL_SUBJS] + \
                   [' '.join([subj, 'is', sent]) for subj in ALL_SUBJS]
    elif r in ['xAttr']:
        # + subject is
        if has_subject:
            return [sent]
        else:
            return [' '.join([subj, 'is', sent]) for subj in ALL_SUBJS]
    elif r in ['oWant', 'xWant']:
        # + subject want / + subject
        if has_subject:
            return [sent]
        else:
            # if the tail starts with 'to'
            if sent.lower().split()[0] == 'to':
                return [' '.join([subj, 'want', sent]) for subj in ALL_SUBJS] \
                       + [' '.join([subj, " ".join(sent.lower().split()[1:])]) for subj in ALL_SUBJS]
            else:
                return [' '.join([subj, 'want to', sent]) for subj in ALL_SUBJS] \
                       + [' '.join([subj, sent]) for subj in ALL_SUBJS]
    elif r in ['xIntent']:
        # + subject intent / + subject
        if has_subject:
            return [sent]
        else:
            # if the tail starts with 'to'
            if sent.lower().split()[0] == 'to':
                return [' '.join([subj, 'intent', sent]) for subj in ALL_SUBJS] \
                       + [' '.join([subj, " ".join(sent.lower().split()[1:])]) for subj in ALL_SUBJS]
            else:
                return [' '.join([subj, 'intent to', sent]) for subj in ALL_SUBJS] \
                       + [' '.join([subj, sent]) for subj in ALL_SUBJS]
    elif r in ['xNeed']:
        # + subject need / + subject
        if has_subject:
            return [sent]
        else:
            # if the tail starts with 'to'
            if sent.lower().split()[0] == 'to':
                return [' '.join([subj, 'need', sent]) for subj in ALL_SUBJS] \
                       + [' '.join([subj, " ".join(sent.lower().split()[1:])]) for subj in ALL_SUBJS]
            else:
                return [' '.join([subj, 'need to', sent]) for subj in ALL_SUBJS] \
                       + [' '.join([subj, sent]) for subj in ALL_SUBJS]

def unfold_parse_results(e):
    # return the words of the extractor results
    if len(e) == 0:
        return ""
    if len(e[0]) == 0:
        return ""
    return " ".join(e[0][0].words)

def process_pp(sent):
    """
    Deal with the variants "person x", "person y", "personx", "persony", "x", "y"
    by normalizing them to PersonX / PersonY.
    """
    fill_words = {"person x": "PersonX", "person y": "PersonY",
                  "personx": "PersonX", "persony": "PersonY",
                  "x": "PersonX", "y": "PersonY"}
    for strs in PP_filter_list:
        if strs in sent:
            sent = sent.replace(strs, fill_words[strs])
            break
    sent_split = sent.split()
    if "x" in sent_split or "y" in sent_split:
        sent = " ".join([fill_words.get(item, item) for item in sent_split])
    return sent


def extract(atomic_data, r, idx):
    extracted_event_list = [[] for i in range(len(atomic_data))]
    for i in tqdm(range(idx, len(atomic_data[r]), num_thread)):
        tmp_node = []
        for sent in json.loads(atomic_data[r][i]):
            if sent == 'none':
                continue
            # filter the text
            sent = sent.lower()
            sent = process_pp(sent)
            parsed_result = e_extractor.parse_text(sent)[0]
            filled_sentences = fill_sentence(sent, r, contain_subject(parsed_result['dependencies']))
            filled_sentences = list(chain(*[instantiate_ppn(s) for s in filled_sentences]))
            tmp_node.append([unfold_parse_results(e_extractor.extract_from_text(tmp_text)) \
                             for tmp_text in filled_sentences])
        extracted_event_list[i] = tmp_node

    return extracted_event_list


parser = argparse.ArgumentParser()
parser.add_argument("--relation", default='xWant', type=str, required=True,
                    choices=['oEffect', 'oReact', 'oWant', 'xAttr',
                             'xEffect', 'xIntent', 'xNeed', 'xReact', 'xWant'],
                    help="choose which relation to process")
parser.add_argument("--port", default=14000, type=int, required=False,
                    help="port of the stanford parser")
args = parser.parse_args()

PP_filter_list = ["person x", "person y", "personx", "persony"]

relation = args.relation

e_extractor = SeedRuleEventualityExtractor(
    corenlp_path="stanford-corenlp-full/",
    corenlp_port=args.port)
atomic_data = pd.read_csv('v4_atomic_all_agg.csv')

num_thread = 5
workers = Pool(num_thread)
all_results = []
for i in range(num_thread):
    tmp_result = workers.apply_async(
        extract,
        args=(atomic_data, relation, i))
    all_results.append(tmp_result)

workers.close()
workers.join()

all_results = [tmp_result.get() for tmp_result in all_results]
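# Each worker only fills the indices i ≡ worker_id (mod num_thread), so zip the
# per-index results across workers and flatten them back into a single list.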
all_results = [list(chain(*item)) for item in zip(*all_results)]

np.save('ASER-format-words-final/ATOMIC_tails_'+relation, all_results)
--------------------------------------------------------------------------------
/BertSAGE/model.py:
--------------------------------------------------------------------------------
import sys, os
import torch
import random

import torch.nn as nn
import torch.nn.functional as F

from transformers import BertModel, RobertaModel
from itertools import chain
from torch.nn.utils.rnn import pad_sequence
import numpy as np

MAX_SEQ_LENGTH = 30

def eval(data_loader, model, test_batch_size, criterion, mode="test", metric="acc"):
    loss = 0
    correct_num = 0
    total_num = 0
    model.eval()
    num_steps = 0
    # the accuracy of positive examples
    correct_pos = 0
    total_pos = 0

    with torch.no_grad():
        for batch in data_loader.get_batch(batch_size=test_batch_size, mode=mode):
            edges, labels = batch
            b_s, _ = edges.shape  # batch_size, 2
            all_nodes = edges.reshape([-1])

            logits = model(all_nodes, b_s)  # (batch_size, 2)

            loss += criterion(logits, labels).item()

            predicted = torch.max(logits, dim=1)[1]
            correct_num += (predicted == labels).sum().item()
            total_num += b_s
            num_steps += 1

            correct_pos += ((predicted == labels) & (labels == 1)).sum().item()
            total_pos += (labels == 1).sum().item()

    model.train()
    # precision/recall over the positive class
    TP = correct_pos
    FN = total_pos - correct_pos
    R = TP / (TP+FN)
    FP = total_num - correct_num - FN
    P = TP / (TP+FP)

    if metric == "acc":
        return correct_num / total_num, correct_pos/total_pos
    elif metric == "f1":
        return 2*P*R/(P+R), correct_pos/total_pos

class LinkPrediction(nn.Module):
    def __init__(self, encoder, adj_lists, nodes_tokenized, device, num_layers=1, num_neighbor_samples=10):
        super(LinkPrediction, self).__init__()

        self.graph_model = GraphSage(
            encoder=encoder,
            num_layers=num_layers,
            input_size=768,
            output_size=768,
            adj_lists=adj_lists,
            nodes_tokenized=nodes_tokenized,
            device=device,
            agg_func='MEAN',
            num_neighbor_samples=num_neighbor_samples)

        self.link_classifier = Classification(768*2, 2, device)

    def forward(self, all_nodes, b_s):
        embs = self.graph_model(all_nodes)  # (2*batch_size, emb_size)
        logits = self.link_classifier(embs.view([b_s, -1]))  # (batch_size, 2*emb_size) -> (batch_size, 2)

        return logits

class SimpleClassifier(nn.Module):
    def __init__(self, encoder, adj_lists, nodes_tokenized, device):
        super(SimpleClassifier, self).__init__()
        self.nodes_tokenized = nodes_tokenized
        self.device = device
        # the encoder backbone (BERT or RoBERTa) is stored under one attribute name
        if encoder == "bert":
            self.roberta_model = BertModel.from_pretrained("bert-base-uncased").to(device)
        elif encoder == "roberta":
            self.roberta_model = RobertaModel.from_pretrained('roberta-base').to(device)

        self.link_classifier = Classification(768*2, 2, device)

    def get_roberta_embs(self, input_ids):
        """
        input_ids: tensor (num_node, max_length)

        output:
            tensor: (num_node, emb_size)
        """
        outputs = self.roberta_model(input_ids)
        return torch.mean(outputs[0], dim=1)  # aggregate embs by mean pooling
    def forward(self, all_nodes, b_s):
        embs = self.get_roberta_embs(
            pad_sequence([self.nodes_tokenized[int(node)] for node in all_nodes], padding_value=1).transpose(0, 1).to(self.device)
        )

        logits = self.link_classifier(embs.view([b_s, -1]))  # (batch_size, 2*emb_size) -> (batch_size, 2)

        return logits


class Classification(nn.Module):

    def __init__(self, emb_size, num_classes, device):
        super(Classification, self).__init__()

        # self.weight = nn.Parameter(torch.FloatTensor(emb_size, num_classes))
        self.linear = nn.Linear(emb_size, num_classes).to(device)

    def forward(self, embs):
        logits = self.linear(embs)
        return logits

class SageLayer(nn.Module):
    """
    Encodes a node using the 'convolutional' GraphSAGE approach.
    """
    def __init__(self, input_size, out_size):
        super(SageLayer, self).__init__()

        self.input_size = input_size
        self.out_size = out_size

        self.linear = nn.Linear(self.input_size*2, self.out_size)

    def forward(self, self_feats, aggregate_feats, neighs=None):
        """
        self_feats -- embeddings of the nodes themselves
        aggregate_feats -- aggregated embeddings of their sampled neighbors
        """
        combined = torch.cat([self_feats, aggregate_feats], dim=1)
        # [b_s, emb_size * 2]
        combined = F.relu(self.linear(combined))  # [b_s, emb_size]
        return combined

class GraphSage(nn.Module):
    """docstring for GraphSage"""
    def __init__(self, encoder, num_layers, input_size, output_size,
                 adj_lists, nodes_tokenized, device, agg_func='MEAN', num_neighbor_samples=10):
        super(GraphSage, self).__init__()

        self.input_size = input_size
        self.out_size = output_size
        self.num_layers = num_layers
        self.device = device
        self.agg_func = agg_func
        self.num_neighbor_samples = num_neighbor_samples

        if encoder == "bert":
            self.roberta_model = BertModel.from_pretrained("bert-base-uncased").to(device)
        elif encoder == "roberta":
            self.roberta_model = RobertaModel.from_pretrained('roberta-base').to(device)

        self.adj_lists = adj_lists
        self.nodes_tokenized = nodes_tokenized

        for index in range(1, num_layers+1):
            layer_size = self.out_size if index != 1 else self.input_size
            setattr(self, 'sage_layer'+str(index), SageLayer(layer_size, self.out_size).to(device))
        # self.fill_tensor = torch.FloatTensor(1, 768).fill_(0).to(self.device)
        self.fill_tensor = torch.nn.Parameter(torch.rand(1, 768)).to(self.device)

    def get_roberta_embs(self, input_ids):
        """
        input_ids: tensor (num_node, max_length)

        output:
            tensor: (num_node, emb_size)
        """
        outputs = self.roberta_model(input_ids)
        return torch.mean(outputs[0], dim=1)  # aggregate embs by mean pooling

    def forward(self, nodes_batch):
        """
        Generates embeddings for a batch of nodes.
        nodes_batch -- (list of ids) batch of nodes to compute embeddings for
        """
        lower_layer_nodes = list(nodes_batch)  # node idx

        nodes_batch_layers = [(lower_layer_nodes,)]

        for i in range(self.num_layers):
            lower_layer_neighs, lower_layer_nodes = self._get_unique_neighs_list(lower_layer_nodes, num_sample=self.num_neighbor_samples)
            # lower_layer_neighs: list(list())
            # lower_layer_nodes: list(nodes of next layer)
            nodes_batch_layers.insert(0, (lower_layer_nodes, lower_layer_neighs))

        all_nodes = np.unique([int(n) for n in list(chain(*[layer[0] for layer in nodes_batch_layers]))])
        all_nodes_idx = dict([(node, idx) for idx, node in enumerate(all_nodes)])

        all_neigh_nodes = pad_sequence([self.nodes_tokenized[node] for node in all_nodes], padding_value=1).transpose(0, 1)[:, :MAX_SEQ_LENGTH].to(self.device)

        pre_hidden_embs = self.get_roberta_embs(
            all_neigh_nodes
        )
        # (num_all_node, emb_size)

        for layer_idx in range(1, self.num_layers+1):
            this_layer_nodes = nodes_batch_layers[layer_idx][0]  # all nodes in this layer
            neigh_nodes, neighbors_list = nodes_batch_layers[layer_idx-1]  # previous layer
            # list(), list(list())

            aggregate_feats = self.aggregate(neighbors_list, pre_hidden_embs, all_nodes_idx)
            # (this_layer_nodes_num, emb_size)

            sage_layer = getattr(self, 'sage_layer'+str(layer_idx))

            cur_hidden_embs = sage_layer(self_feats=pre_hidden_embs[[all_nodes_idx[int(n)] for n in this_layer_nodes]],
                                         aggregate_feats=aggregate_feats)

            # cur_hidden_embs = torch.cat([pre_hidden_embs[[all_nodes_idx[int(n)] for n in this_layer_nodes]].unsqueeze(1),
            #                              aggregate_feats.unsqueeze(1)], dim=1)  # (b_s, 2, emb_size)
            # cur_hidden_embs = torch.mean(cur_hidden_embs, dim=1)

            pre_hidden_embs[[all_nodes_idx[int(n)] for n in this_layer_nodes]] = cur_hidden_embs

        # (input_batch_node_size, emb_size)
        # output the embeddings of the input nodes
        return pre_hidden_embs[[all_nodes_idx[int(n)] for n in nodes_batch]]

    def _nodes_map(self, nodes, hidden_embs, neighs):
        layer_nodes, samp_neighs, layer_nodes_dict = neighs
        assert len(samp_neighs) == len(nodes)
        index = [layer_nodes_dict[x] for x in nodes]
        return index

    def _get_unique_neighs_list(self, nodes, num_sample=10):
        # TODO
        neighbors_list = [self.adj_lists[int(node)] for node in nodes]
        if num_sample is not None:
            samp_neighs = [np.random.choice(neighbors, num_sample) if len(neighbors) > 0 else [] for neighbors in neighbors_list]
        else:
            samp_neighs = neighbors_list
        _unique_nodes_list = np.unique(list(chain(*samp_neighs)))
        return samp_neighs, _unique_nodes_list

    def aggregate(self, neighbors_list, pre_hidden_embs, all_nodes_idx):
        if self.agg_func == 'MEAN':
            # mean-pool the neighbor embeddings; nodes without neighbors get the learnable fill_tensor
            agg_list = [torch.mean(pre_hidden_embs[[int(all_nodes_idx[n]) for n in neighbors]], dim=0).unsqueeze(0)
                        if len(neighbors) > 0 else self.fill_tensor for neighbors in neighbors_list]
            if len(agg_list) > 0:
                return torch.cat(agg_list, dim=0)
            else:
                return torch.FloatTensor(0, pre_hidden_embs.shape[1]).fill_(0).to(self.device)
        if self.agg_func == 'MAX':
            # not implemented
            return 0
--------------------------------------------------------------------------------
/BertSAGE/dataloader.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import pandas as pd
import networkx as nx

import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from collections import Counter
from copy import deepcopy
from itertools import chain

from transformers import BertTokenizer, RobertaTokenizer
from tqdm import tqdm

MAX_NODE_LENGTH = 10
# for xWant, neg_prop=4, all_in_aser: a threshold of 10 filters out 0.6% of nodes; 11 filters out 0.1%.


np.random.seed(229)

class InferenceGraphDataset():
    def __init__(self, np_file_path, device, encoder, training_graph):

        self.training_graph = training_graph

        self.data = np.load(np_file_path, allow_pickle=True)[()]

        # train_edges = dict([(tuple(edge), True) for edge in self.training_graph.train_edges])

        # self.node2id = dict([(node, i) for i, node in enumerate(self.data.flatten())])
        # self.id2node = dict([(i, node) for i, node in enumerate(self.data.flatten())])

        self.data_id = {}

        # drop candidates whose head or tail exceeds MAX_NODE_LENGTH tokens
        self.data_id["head"] = np.array([[self.training_graph.node2id[head],
                                          self.training_graph.node2id[tail], hid] for head, tail, hid in self.data["head"]
                                         if len(head.split()) < MAX_NODE_LENGTH and len(tail.split()) < MAX_NODE_LENGTH])

        self.data_id["tail"] = np.array([[self.training_graph.node2id[head],
                                          self.training_graph.node2id[tail], -1] for head, tail in self.data["tail"]
                                         if len(head.split()) < MAX_NODE_LENGTH and len(tail.split()) < MAX_NODE_LENGTH])

        self.data_id["new"] = np.array([[self.training_graph.node2id[head],
                                         self.training_graph.node2id[tail], -1] for head, tail in self.data["new"]
                                        if len(head.split()) < MAX_NODE_LENGTH and len(tail.split()) < MAX_NODE_LENGTH])

    def get_nodes_tokenized(self):
        return self.training_graph.get_nodes_tokenized()
    def get_adj_list(self):
        return self.training_graph.get_adj_list()
    def get_batch(self, batch_size=16, mode="head"):
        for i in range(0, len(self.data_id[mode]), batch_size):
            yield self.data_id[mode][i:min(i+batch_size, len(self.data_id[mode]))]

class InferenceSimpleDataset():
    def __init__(self, np_file_path, device, encoder, training_graph):
        self.training_graph = training_graph
        self.data = np.load(np_file_path, allow_pickle=True)[()]
        # train_edges = dict([(tuple(edge), True) for edge in self.training_graph.train_edges])

        self.data_id = {}

        self.data_id["head"] = np.array([[self.training_graph.node2id[head],
                                          self.training_graph.node2id[tail], hid] for head, tail, hid in self.data["head"]
                                         if len(head.split()) < MAX_NODE_LENGTH and len(tail.split()) < MAX_NODE_LENGTH])
        self.data_id["tail"] = np.array([[self.training_graph.node2id[head],
                                          self.training_graph.node2id[tail], -1] for head, tail in self.data["tail"]
                                         if len(head.split()) < MAX_NODE_LENGTH and len(tail.split()) < MAX_NODE_LENGTH])
        self.data_id["new"] = np.array([[self.training_graph.node2id[head],
                                         self.training_graph.node2id[tail], -1] for head, tail in self.data["new"]
                                        if len(head.split()) < MAX_NODE_LENGTH and len(tail.split()) < MAX_NODE_LENGTH])

    def get_nodes_tokenized(self):
        return self.training_graph.get_nodes_tokenized()
    def get_batch(self, batch_size=16, mode="head"):
        for i in range(0, len(self.data_id[mode]), batch_size):
            yield self.data_id[mode][i:min(i+batch_size, len(self.data_id[mode]))]

class GraphDataset():

    def __init__(self, nx_file_path,
                 device,
                 encoder,
                 split=[0.8, 0.1, 0.1],
                 max_train_num=1000000,
                 load_edge_types="ATOMIC",
                 negative_sample="fix_head",
                 atomic_csv_path="/home/tfangaa/Downloads/ATOMIC/v4_atomic_all_agg.csv",
                 random_split=False,
                 neg_prop=1.0):
        """
        load_edge_types controls the edges to be loaded to self.adj_list
        """
        assert load_edge_types in ["ATOMIC", "ASER", "ATOMIC+ASER"], \
            "should be in [\"ATOMIC\", \"ASER\", \"ATOMIC+ASER\"]"

        # 1. Load graph
        self.atomic_csv_path = atomic_csv_path
        G = nx.read_gpickle(nx_file_path)

        print("dataset statistics:\nnumber of nodes:{}\nnumber of edges:{}\n".format(len(G.nodes()), len(G.edges())))

        self.id2node = {}
        self.node2id = {}

        filter_nodes = []
        for node in G.nodes():
            if len(node.split()) > MAX_NODE_LENGTH:  # filter extra large nodes
                filter_nodes.append(node)
        print("num of removed nodes:", len(filter_nodes))
        G.remove_nodes_from(filter_nodes)


        for i, node in enumerate(G.nodes()):
            self.id2node[i] = node
            self.node2id[node] = i

        # 2. Prepare training and testing edges
        all_edges_shuffle = list(G.edges.data())
        np.random.shuffle(all_edges_shuffle)
        ATOMIC_edges = [edge for edge in all_edges_shuffle if edge[2]["relation"]=="ATOMIC"]
        atomic_edge_cnter = Counter([(edge[2]['hid'], edge[2]['tid']) for edge in ATOMIC_edges])
        self.train_id, self.val_id, self.test_id = [int(s*len(ATOMIC_edges)) for s in np.cumsum(split)/np.sum(split)]

        edge_by_htid = dict()
        for e in ATOMIC_edges:
            htid = (e[2]['hid'], e[2]['tid'])
            if htid not in edge_by_htid:
                edge_by_htid[htid] = [[self.node2id[e[0]], self.node2id[e[1]]]]
            else:
                edge_by_htid[htid].append([self.node2id[e[0]], self.node2id[e[1]]])

        # randomly select edges
        current_edge_list = deepcopy(list(atomic_edge_cnter.keys()))
        if random_split:
            train_edges_pos_all, val_edges_pos_all, test_edges_pos_all = self.get_random_split(current_edge_list, edge_by_htid)
        else:
            train_edges_pos_all, val_edges_pos_all, test_edges_pos_all = self.get_split_from_atomic(edge_by_htid)

        print('Number of positive training examples:{}, validating:{}, testing:{}'.format(len(train_edges_pos_all), len(val_edges_pos_all), len(test_edges_pos_all)))

        if len(train_edges_pos_all) > max_train_num:
            train_edges_pos = list(np.array(train_edges_pos_all)[np.random.permutation(len(train_edges_pos_all))[:max_train_num]])
            val_edges_pos = list(np.array(val_edges_pos_all)[np.random.permutation(len(val_edges_pos_all))[:int(max_train_num * split[1]/split[0])]])
            test_edges_pos = list(np.array(test_edges_pos_all)[np.random.permutation(len(test_edges_pos_all))[:int(max_train_num * split[2]/split[0])]])
        else:
            train_edges_pos = train_edges_pos_all
            val_edges_pos = val_edges_pos_all
            test_edges_pos = test_edges_pos_all

        edge_dict = dict([((self.node2id[head], self.node2id[tail]), True) for head, tail in G.edges()])

        print('Number of positive training examples after truncating:{}, validating:{}, testing:{}'.format(len(train_edges_pos), len(val_edges_pos), len(test_edges_pos)))

        # 3. Sample negative edges

        if negative_sample == "fix_head":
            # bipartite graph
            all_heads = [self.node2id[node] for node, out_degree in G.out_degree if out_degree > 0]
            all_tails = [self.node2id[node] for node, out_degree in G.out_degree if out_degree == 0]
            neg_edges = []
            num_neg = len(train_edges_pos) + len(val_edges_pos) + len(test_edges_pos)
            for i in range(int(num_neg * neg_prop)):
                hd_idx = np.random.randint(0, len(all_heads))
                tl_idx = np.random.randint(0, len(all_tails))
                while (all_heads[hd_idx], all_tails[tl_idx]) in edge_dict:
                    hd_idx = np.random.randint(0, len(all_heads))
                    tl_idx = np.random.randint(0, len(all_tails))
                neg_edges.append([all_heads[hd_idx], all_tails[tl_idx]])
        elif negative_sample == "from_all":
            neg_edges = []
            num_neg = len(train_edges_pos) + len(val_edges_pos) + len(test_edges_pos)
            for i in range(int(num_neg * neg_prop)):
                rnd = np.random.randint(0, len(self.node2id), 2)
                tmp_edge = (rnd[0], rnd[1])
                while tmp_edge in edge_dict or tmp_edge[0] == tmp_edge[1]:
                    rnd = np.random.randint(0, len(self.node2id), 2)
                    tmp_edge = (rnd[0], rnd[1])
                neg_edges.append(list(tmp_edge))
        elif negative_sample == "prepared_neg":
            # some of the negative samples are pre-prepared
            neg_train = [[self.node2id[head], self.node2id[tail]]
                         for head, tail, feat in G.edges.data() if feat["relation"]=="neg_trn"]
            neg_val = [[self.node2id[head], self.node2id[tail]]
                       for head, tail, feat in G.edges.data() if feat["relation"]=="neg_dev"]
            neg_test = [[self.node2id[head], self.node2id[tail]]
                        for head, tail, feat in G.edges.data() if feat["relation"]=="neg_tst"]
            print("num of prepared neg for train:{}, dev:{}, test:{}".format(len(neg_train), len(neg_val), len(neg_test)))
            neg_edges = []
            num_neg = len(train_edges_pos) + len(val_edges_pos) + len(test_edges_pos)
            # sample the remaining negatives randomly from all node pairs
            for i in range(int(num_neg * neg_prop) - len(neg_train) - len(neg_val) - len(neg_test)):
                rnd = np.random.randint(0, len(self.node2id), 2)
                tmp_edge = (rnd[0], rnd[1])
                while tmp_edge in edge_dict or tmp_edge[0] == tmp_edge[1]:
                    rnd = np.random.randint(0, len(self.node2id), 2)
                    tmp_edge = (rnd[0], rnd[1])
                neg_edges.append(list(tmp_edge))
            trn_val_idx = int(len(train_edges_pos)*neg_prop) - len(neg_train)
            val_tst_idx = trn_val_idx + int(len(val_edges_pos)*neg_prop) - len(neg_val)
            neg_edges = neg_train + neg_edges[:trn_val_idx] \
                        + neg_val + neg_edges[trn_val_idx:val_tst_idx] \
                        + neg_test + neg_edges[val_tst_idx:]
            assert len(neg_edges) == int((len(train_edges_pos) + len(val_edges_pos) + len(test_edges_pos))*neg_prop)

        train_edges_neg = neg_edges[:int(len(train_edges_pos)*neg_prop)]
        val_edges_neg = neg_edges[int(len(train_edges_pos)*neg_prop):int((len(train_edges_pos)+len(val_edges_pos))*neg_prop)]
        test_edges_neg = neg_edges[int((len(train_edges_pos)+len(val_edges_pos))*neg_prop):]
        print('Number of negative examples after truncating:{}, validating:{}, testing:{}'.format(len(train_edges_neg), len(val_edges_neg), len(test_edges_neg)))
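        # build label arrays (0 = negative, 1 = positive) and shuffle edges and
        # labels with the same permutation for each split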
        self.train_labels = np.array([0] * len(train_edges_neg) + [1] * len(train_edges_pos))
        self.train_edges = np.array(train_edges_neg + train_edges_pos)
        train_shuffle_idx = np.random.permutation(len(self.train_edges))
        self.train_labels, self.train_edges = self.train_labels[train_shuffle_idx], self.train_edges[train_shuffle_idx]

        self.val_labels = np.array([0] * len(val_edges_neg) + [1] * len(val_edges_pos))
        self.val_edges = np.array(val_edges_neg + val_edges_pos)
        val_shuffle_idx = np.random.permutation(len(self.val_edges))
        self.val_labels, self.val_edges = self.val_labels[val_shuffle_idx], self.val_edges[val_shuffle_idx]

        self.test_labels = np.array([0] * len(test_edges_neg) + [1] * len(test_edges_pos))
        self.test_edges = np.array(test_edges_neg + test_edges_pos)
        test_shuffle_idx = np.random.permutation(len(self.test_edges))
        self.test_labels, self.test_edges = self.test_labels[test_shuffle_idx], self.test_edges[test_shuffle_idx]

        print('finish preparing neg samples')

        self.mode_edges = {
            "train": torch.tensor(self.train_edges).to(device),
            "valid": torch.tensor(self.val_edges).to(device),
            "test": torch.tensor(self.test_edges).to(device)
        }
        self.mode_labels = {
            "train": torch.tensor(self.train_labels).to(device),
            "valid": torch.tensor(self.val_labels).to(device),
            "test": torch.tensor(self.test_labels).to(device)
        }

        # Prepare the adjacency list, masking all validation and test edges;
        # the adj list contains all the training edges
        self.adj_list = [[] for i in range(len(self.id2node))]

        # Edges are all the edges except for those in the test/val set
        val_edges_dict = dict([((edge[0], edge[1]), True) for edge in val_edges_pos])
        test_edges_dict = dict([((edge[0], edge[1]), True) for edge in test_edges_pos])

        for head, tail, feat in G.edges.data():
            if load_edge_types == "ATOMIC":
                if feat["relation"] != "ATOMIC":
                    continue
            elif load_edge_types == "ASER":
                if feat["relation"] != "ASER":
                    continue
            elif load_edge_types == "ATOMIC+ASER":
                pass
            if (self.node2id[head], self.node2id[tail]) not in val_edges_dict \
                    and (self.node2id[head], self.node2id[tail]) not in test_edges_dict:
                self.adj_list[self.node2id[head]].append(self.node2id[tail])

        # 4. Tokenize nodes

        if encoder == "bert":
            self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        elif encoder == "roberta":
            self.tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
        self.id2nodestoken = dict([(self.node2id[line], torch.tensor(self.tokenizer.encode(line,
            add_special_tokens=True)).to(device)) for line in tqdm(self.node2id)])

    def get_random_split(self, current_edge_list, edge_by_htid):
        train_edges_pos_all = []
        covered_atomic_edge = {}

        while len(train_edges_pos_all) <= self.train_id:
            # select an edge
            rnd_id = np.random.randint(len(current_edge_list))
            htid = current_edge_list[rnd_id]
            current_edge_list.pop(rnd_id)
            train_edges_pos_all.extend(edge_by_htid[htid])

        val_edges_pos_all = []
        while len(train_edges_pos_all) + len(val_edges_pos_all) <= self.val_id:
            rnd_id = np.random.randint(len(current_edge_list))
            htid = current_edge_list[rnd_id]
            current_edge_list.pop(rnd_id)
            val_edges_pos_all.extend(edge_by_htid[htid])

        test_edges_pos_all = []
        while len(train_edges_pos_all) + len(val_edges_pos_all) + len(test_edges_pos_all) < self.test_id:
            rnd_id = np.random.randint(len(current_edge_list))
            htid = current_edge_list[rnd_id]
            current_edge_list.pop(rnd_id)
            test_edges_pos_all.extend(edge_by_htid[htid])
        return train_edges_pos_all, val_edges_pos_all, test_edges_pos_all

    def get_split_from_atomic(self, edge_by_htid):
        # follow the official ATOMIC trn/dev/tst split, keyed by head id
        train_edges_pos_all = []
        val_edges_pos_all = []
        test_edges_pos_all = []
        atomic_raw = pd.read_csv(self.atomic_csv_path)
        splits = dict((i, spl) for i, spl in enumerate(atomic_raw['split']))
        for htid, edge in edge_by_htid.items():
            if splits[htid[0]] == "trn":
                train_edges_pos_all.extend(edge)
            elif splits[htid[0]] == "dev":
                val_edges_pos_all.extend(edge)
            elif splits[htid[0]] == "tst":
                test_edges_pos_all.extend(edge)
        return train_edges_pos_all, val_edges_pos_all, test_edges_pos_all

    def get_adj_list(self):
        return self.adj_list
    def get_nid2text(self):
        return self.id2node

    def get_nodes_tokenized(self):
        return self.id2nodestoken

    def get_batch(self, batch_size=16, mode="train"):
        assert mode in ["train", "valid", "test"], "invalid mode"

        for i in range(0, len(self.mode_edges[mode]), batch_size):
            yield self.mode_edges[mode][i:min(i+batch_size, len(self.mode_edges[mode]))], \
                  self.mode_labels[mode][i:min(i+batch_size, len(self.mode_edges[mode]))]
--------------------------------------------------------------------------------