├── README.md ├── dataset_config.py ├── basic_meta_learning_module.py ├── metrics.py ├── main_reptile.py ├── util.py ├── hyp_reptile.py ├── data_reptile.py ├── triplet_vae.py ├── memory_encoder.py ├── reptile.py ├── prepare_dbpedia.py └── prepare_wikidata.py /README.md: -------------------------------------------------------------------------------- 1 | # Few-shot Learning for Uncommon Entities and Relations in Knowledge Graphs 2 | 3 | The source code for our paper [Tackling Long-Tailed Relations and Uncommon Entities in Knowledge Graph Completion](https://arxiv.org/abs/1909.11359). 4 | 5 | Requirements: python3, pytorch1.0 6 | 7 | First, download two datasets from here: https://drive.google.com/open?id=1bekOAfMrx9V3uUp6dSYWkr-L5f3fTkwP, and put them into Few-shot-KGC/data/. 8 | 9 | Train a model on Wikidata: python main_reptile.py --train --dataset wikidata --idx_device -1 10 | 11 | Train a model on DBpedia: python main_reptile.py --train --dataset dbpedia --idx_device -1 12 | 13 | Please note: --idx_device -1 means to train model on CPU. If you want to train it on GPU, replace -1 to your GPU index. For example, --idx_device 0. The code cannot simutaneously run on multiple GPUs. 14 | 15 | Dataset-specific hyper-parameters are in dataset_config.py, and other hyper-parameters are in hyp_reptile.py. 16 | 17 | 18 | -------------------------------------------------------------------------------- /dataset_config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | def get_dataset_config(dataset): 4 | config = {} 5 | if dataset == "wikidata": 6 | config["tmp_root"] = "./tmp/wikidata/" 7 | config["json_dataset_root"] = "/gds/zhwang/zhwang/data/knowledge_graph/wikidata/" 8 | config["pkl_dataset_root"] = "/gds/zhwang/zhwang/workspace/knowledge_graph/zero_shot/" 9 | config["raw_pkl_path"] = "./data/wikidata/all_data.pkl" 10 | 11 | config["min_word_cnt"] = 1 12 | config["rm_num_in_name"] = True 13 | config["max_len_dscp"] = 32 14 | config["min_task_size"] = 5 15 | config["min_num_cand"] = 100 16 | 17 | config["encoder_dim_pool_kernel"] = 2 18 | config["num_aug"] = 8 19 | config["margin"] = 1.0 20 | elif dataset == "dbpedia": 21 | config["tmp_root"] = "./tmp/dbpedia/" 22 | config["dataset_root"] = "/gds/zhwang/zhwang/data/knowledge_graph/dbpedia/dbpedia500/" 23 | config["raw_pkl_path"] = "./data/dbpedia/all_data.pkl" 24 | 25 | config["min_word_cnt"] = 10 26 | config["rm_num_in_name"] = True 27 | config["max_len_dscp"] = 200 28 | config["tasks_split_size"] = [220, 30, 69] # 319 29 | config["min_task_size"] = 5 30 | config["max_task_size"] = 1000 31 | config["min_num_cand"] = 100 32 | config["max_num_cand"] = 1000 33 | 34 | config["encoder_dim_pool_kernel"] = 4 35 | config["num_aug"] = 128 36 | config["margin"] = 2.0 37 | else: 38 | raise RuntimeError("wrong dataset in get_dataset_config()") 39 | 40 | return config 41 | -------------------------------------------------------------------------------- /basic_meta_learning_module.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from sys import exit 3 | import torch as T 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | class BasicMetaLearningModule(nn.Module): 8 | def __init__(self, hyp): 9 | super(BasicMetaLearningModule, self).__init__() 10 | self.device = hyp.device 11 | self.params = nn.ParameterDict() 12 | 13 | # saving original parameters by using deepcopy(self.encoder.state_dict()) and 
load_state_dict() leads to GPU memory leak 14 | def copy_param(self): 15 | new_params = {} 16 | for name, param in self.params.items(): 17 | new_params[name] = T.clone(param) 18 | 19 | return new_params 20 | 21 | def set_param(self, new_params): 22 | for name, param in new_params.items(): 23 | self.params[name].data = self.params[name].data.copy_(param) 24 | 25 | def xavier(self, *shape): 26 | return nn.Parameter(nn.init.xavier_uniform_(T.empty(shape, device = self.device))) 27 | 28 | def unif(self, *shape): 29 | return nn.Parameter(nn.init.uniform_(T.empty(shape, device = self.device), -0.1, 0.1)) 30 | 31 | def fill(self, value, *shape): 32 | return nn.Parameter(T.empty(shape, device = self.device).fill_(value)) 33 | 34 | def add_norm_param(self, norm_type, dim, prefix, postfix): 35 | if norm_type == "instance_norm": 36 | self.params["{}_mean_norm{}".format(prefix, postfix)] = self.fill(0.0, dim) 37 | self.params["{}_var_norm{}".format(prefix, postfix)] = self.fill(1.0, dim) 38 | self.params["{}_mean_norm{}".format(prefix, postfix)].requires_grad = False 39 | self.params["{}_var_norm{}".format(prefix, postfix)].requires_grad = False 40 | self.params["{}_w_norm{}".format(prefix, postfix)] = self.unif(dim) 41 | self.params["{}_b_norm{}".format(prefix, postfix)] = self.fill(0.0, dim) 42 | else: 43 | pass 44 | 45 | def norm(self, norm_type, x, prefix, postfix): 46 | if norm_type == "instance_norm": 47 | return F.instance_norm(x, 48 | self.params["{}_mean_norm{}".format(prefix, postfix)], 49 | self.params["{}_var_norm{}".format(prefix, postfix)], 50 | self.params["{}_w_norm{}".format(prefix, postfix)], 51 | self.params["{}_b_norm{}".format(prefix, postfix)]) 52 | else: 53 | return x 54 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | import logging 4 | from logging import info, warn, error 5 | from sys import exit 6 | 7 | from util import * 8 | 9 | class Metrics(object): 10 | def __init__(self): 11 | self.hits10 = 0 12 | self.hits5 = 0 13 | self.hits1 = 0 14 | self.mrr = 0 15 | self.task_hits10 = 0 16 | self.task_hits5 = 0 17 | self.task_hits1 = 0 18 | self.task_mrr = 0 19 | self.num_task_data = 0 20 | self.num_all_data = 0 21 | 22 | def reset_task_metrics(self, task): 23 | self.cur_task = task 24 | self.task_hits10 = 0 25 | self.task_hits5 = 0 26 | self.task_hits1 = 0 27 | self.task_mrr = 0 28 | self.task_rank = [] 29 | self.num_task_data = 0 30 | 31 | def add(self, data): 32 | idx_sorted = list(np.argsort(data)) # ascending 33 | rank = idx_sorted.index(0) + 1 34 | if rank <= 10: 35 | self.hits10 += 1 36 | self.task_hits10 += 1 37 | if rank <= 5: 38 | self.hits5 += 1 39 | self.task_hits5 += 1 40 | if rank <= 1: 41 | self.hits1 += 1 42 | self.task_hits1 += 1 43 | self.mrr += 1 / rank 44 | self.task_mrr += 1 / rank 45 | self.num_task_data += 1 46 | self.num_all_data += 1 47 | self.task_rank.append((rank, data.shape[0])) 48 | 49 | def log_task_metric(self): 50 | task_metrics = {} 51 | task_metrics["hits10"] = self.task_hits10 / self.num_task_data 52 | task_metrics["hits5"] = self.task_hits5 / self.num_task_data 53 | task_metrics["hits1"] = self.task_hits1 / self.num_task_data 54 | task_metrics["mrr"] = self.task_mrr / self.num_task_data 55 | prefix = "task {}: {} data, ".format(self.cur_task, self.num_task_data) 56 | ''' 57 | print("task rank: ", end = " ") 58 | for duo in self.task_rank: 59 | print("{}/{} ".format(duo[0], 
duo[1]), end = " ") 60 | print("") 61 | ''' 62 | self.__log_metric(task_metrics, prefix) 63 | 64 | 65 | def log_overall_metric(self): 66 | metrics = {} 67 | metrics["hits10"] = self.hits10 / self.num_all_data 68 | metrics["hits5"] = self.hits5 / self.num_all_data 69 | metrics["hits1"] = self.hits1 / self.num_all_data 70 | metrics["mrr"] = self.mrr / self.num_all_data 71 | prefix = "overall dataset: {} data, ".format(self.num_all_data) 72 | self.__log_metric(metrics, prefix) 73 | 74 | 75 | def __log_metric(self, metrics, prefix): 76 | logging_info = prefix 77 | for name, val in metrics.items(): 78 | logging_info += "{} = {} ".format(name, "{:.4f}".format(val)) 79 | info(logging_info) 80 | 81 | -------------------------------------------------------------------------------- /main_reptile.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | import logging 4 | from logging import info, warn, error 5 | from sys import exit 6 | import numpy as np 7 | import torch as T 8 | from collections import defaultdict 9 | 10 | from util import * 11 | from hyp_reptile import init_hyp, log_hyp 12 | from data_reptile import load_raw_pkl, TrainDataset, EvalDataset 13 | from reptile import REPTILE 14 | 15 | def predict(tag, eval_dataset, model, hyp): 16 | info("\n------------------------------------------\nstart evaluating on {} dataset with {}".format(tag, hyp.model)) 17 | model.eval() 18 | model.predict(eval_dataset) 19 | if hyp.is_training: 20 | model.train() 21 | info("\n------------------------------------------\nfinish evaluating on {} dataset with {}\n".format(tag, hyp.model)) 22 | 23 | def training(train_dataset, dev_dataset, test_dataset, model, hyp): 24 | info("\n------------------------------------------\nstart training") 25 | model.train() 26 | loss_items = defaultdict(float) 27 | 28 | for epoch in range(hyp.existing_epoch + 1, hyp.max_epoch + 1): 29 | loss = model.meta_train(train_dataset) 30 | for name, val in loss.items(): 31 | loss_items[name] += val 32 | if epoch % hyp.training_print_freq == 0: 33 | #info("{}/{} epoches".format(epoch, hyp.max_epoch)) 34 | total_loss = 0.0 35 | msg = "{}/{} epoches, ".format(epoch, hyp.max_epoch) 36 | for name, val in loss_items.items(): 37 | msg += "{} = {}, ".format(name, val) 38 | total_loss += val 39 | msg += "total_loss = {}".format(total_loss) 40 | info(msg) 41 | 42 | if dev_dataset is not None: 43 | predict("dev", dev_dataset, model, hyp) 44 | if test_dataset is not None: 45 | predict("test", test_dataset, model, hyp) 46 | loss_items = defaultdict(float) 47 | 48 | if not hyp.is_debugging and epoch % hyp.save_freq == 0: 49 | save_model(epoch, model, hyp) 50 | 51 | def run(hyp): 52 | if hyp.seed is not None: 53 | fix_random_seeds(int(hyp.seed)) 54 | 55 | if hyp.idx_device == -1 or not T.cuda.is_available(): 56 | hyp.device = T.device("cpu") 57 | else: 58 | hyp.device = T.device("cuda:{}".format(hyp.idx_device)) 59 | 60 | train_task, dev_task, test_task, ent_dscps, rel_dscps, i2r, w2i, i2w, rel2cand, e1rel_e2 = load_raw_pkl(hyp) 61 | 62 | hyp.dict_size = len(i2w) 63 | #hyp.char_size = len(i2c) 64 | hyp.idx_word_pad = w2i[hyp.WORD_PAD] 65 | 66 | train_dataset = TrainDataset("train", train_task, ent_dscps, rel_dscps, rel2cand, e1rel_e2, hyp) 67 | dev_dataset = EvalDataset("dev", dev_task, ent_dscps, rel_dscps, rel2cand, e1rel_e2, hyp) 68 | test_dataset = EvalDataset("test", test_task, ent_dscps, rel_dscps, rel2cand, e1rel_e2, hyp) 69 | 70 | model = eval(hyp.model)(hyp) 71 | 72 | if 
hyp.load_existing_model: 73 | model = load_model(model, hyp) 74 | 75 | if hyp.is_training: 76 | training(train_dataset, dev_dataset, test_dataset, model, hyp) 77 | else: 78 | predict("dev", dev_dataset, model, hyp) 79 | predict("test", test_dataset, model, hyp) 80 | 81 | if __name__ == "__main__": 82 | np.set_printoptions(threshold = np.inf) 83 | hyp = init_hyp() 84 | 85 | logger = init_logger(hyp) 86 | log_hyp(hyp) 87 | run(hyp) 88 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from sys import exit 3 | import logging 4 | from logging import info, warn, error 5 | import random 6 | import numpy as np 7 | import torch as T 8 | 9 | def init_logger(hyp): 10 | log_path = hyp.log_root + hyp.prefix 11 | if hyp.is_training: 12 | log_path += "_train_" 13 | if hyp.postfix != "": 14 | log_path += hyp.postfix + "_" 15 | log_path += hyp.existing_timestamp 16 | else: 17 | log_path += "_predict_" 18 | log_path += hyp.existing_timestamp 19 | log_path += "_epoch" + str(hyp.existing_epoch) 20 | log_path += ".log" 21 | 22 | logger = logging.getLogger() 23 | logger.setLevel(logging.DEBUG) 24 | formatter = logging.Formatter("%(asctime)s %(levelname)s: - %(message)s", datefmt = "%Y-%m-%d %H:%M:%S") 25 | if not hyp.is_debugging: 26 | fh = logging.FileHandler(log_path) 27 | fh.setLevel(logging.DEBUG) 28 | fh.setFormatter(formatter) 29 | logger.addHandler(fh) 30 | ch = logging.StreamHandler() 31 | ch.setLevel(logging.DEBUG) 32 | ch.setFormatter(formatter) 33 | logger.addHandler(ch) 34 | 35 | return logger 36 | 37 | def fix_random_seeds(seed): 38 | warn("random seed is fixed: {}\n".format(seed)) 39 | random.seed(seed) 40 | np.random.seed(seed) 41 | T.manual_seed(seed) 42 | if T.cuda.is_available(): 43 | T.cuda.manual_seed_all(seed) 44 | 45 | def save_model(epoch, model, hyp): 46 | path = "{}{}_{}_epoch{}.pt".format(hyp.model_root, hyp.prefix, hyp.existing_timestamp, epoch) 47 | checkpoint = { 48 | "epoch" : epoch, 49 | "idx_device" : hyp.idx_device, 50 | } 51 | state_dict = model.get_state_dict() 52 | checkpoint = {**checkpoint, **state_dict} 53 | T.save(checkpoint, path) 54 | info("finish saving model to {}".format(path)) 55 | 56 | def load_model(model, hyp): 57 | path = "{}{}_{}_epoch{}.pt".format(hyp.model_root, hyp.prefix, hyp.existing_timestamp, hyp.existing_epoch) 58 | checkpoint = T.load(path) 59 | assert hyp.existing_epoch == checkpoint["epoch"] 60 | model.set_state_dict(checkpoint) 61 | info("finish loading model from {}".format(path)) 62 | return model 63 | 64 | def build_glove_emb(w2i, hyp): 65 | glove_emb = pickle.load(open(hyp.glove_pkl_path, "rb")) 66 | 67 | if glove_emb.shape[0] != len(w2i) or glove_emb.shape[1] != hyp.dim_emb: 68 | info("rebuild glove word embedding") 69 | glove_path = "{}glove.6B.{}d.txt".format(hyp.glove_root, hyp.dim_emb) 70 | glove_emb = T.empty((len(w2i), hyp.dim_emb)) 71 | with open(glove_path, "r") as f_src: 72 | pass 73 | else: 74 | return glove_emb 75 | 76 | def extract_top_level_dict(current_dict): 77 | """ 78 | Builds a graph dictionary from the passed depth_keys, value pair. Useful for dynamically passing external params 79 | :param depth_keys: A list of strings making up the name of a variable. Used to make a graph for that params tree. 80 | :param value: Param value 81 | :param key_exists: If none then assume new dict, else load existing dict and add new key->value pairs to it. 
82 | :return: A dictionary graph of the params already added to the graph. 83 | """ 84 | output_dict = {} 85 | for key in current_dict.keys(): 86 | name = key.replace("layer_dict.", "") 87 | top_level = name.split(".")[0] 88 | sub_level = ".".join(name.split(".")[1:]) 89 | 90 | if top_level not in output_dict: 91 | if sub_level == "": 92 | output_dict[top_level] = current_dict[key] 93 | else: 94 | output_dict[top_level] = {sub_level: current_dict[key]} 95 | else: 96 | new_item = {key: value for key, value in output_dict[top_level].items()} 97 | new_item[sub_level] = current_dict[key] 98 | output_dict[top_level] = new_item 99 | 100 | #print(current_dict.keys(), output_dict.keys()) 101 | return output_dict 102 | 103 | 104 | 105 | -------------------------------------------------------------------------------- /hyp_reptile.py: -------------------------------------------------------------------------------- 1 | from sys import exit 2 | import argparse 3 | import logging 4 | from logging import info 5 | import time 6 | from dataset_config import get_dataset_config 7 | 8 | def init_hyp(dataset = None): 9 | parser = argparse.ArgumentParser() 10 | 11 | parser.add_argument("--UNK", default = "UNK", type = str) 12 | parser.add_argument("--NUM", default = "$", type = str) # for both character and word embeddings 13 | parser.add_argument("--WORD_PAD", default = "PAD", type = str) 14 | 15 | if dataset is None: 16 | parser.add_argument("--dataset", default = "", type = str) 17 | 18 | parser.add_argument("--data_root", default = "./data/", type = str) 19 | parser.add_argument("--model_root", default = "./model/", type = str) 20 | parser.add_argument("--log_root", default = "./log/", type = str) 21 | parser.add_argument("--raw_glove_root", default = "/gds/zhwang/zhwang/data/general_data/glove/", type = str) 22 | parser.add_argument("--glove_pkl_path", default = "./data/glove_emb.pkl", type = str) 23 | 24 | parser.add_argument("--postfix", default = "", type = str) 25 | parser.add_argument("--idx_device", default = -1, type = int) # -1: CPU, 0, 1, ...: GPU 26 | parser.add_argument("--model", default = "reptile", type = str) 27 | 28 | parser.add_argument("--debug", dest = "is_debugging", action = "store_true") 29 | parser.set_defaults(is_debugging = False) 30 | parser.add_argument("--train", dest = "is_training", action = "store_true") 31 | parser.add_argument("--test", dest = "is_training", action = "store_false") 32 | parser.set_defaults(is_training = None) 33 | parser.add_argument("--max_epoch", default = 1000, type = int) 34 | parser.add_argument("--training_print_freq", default = 100, type = int) 35 | parser.add_argument("--save_freq", default = 100000, type = int) 36 | 37 | parser.add_argument("--existing_epoch", default = 0, type = int) 38 | parser.add_argument("--existing_timestamp", default = "", type = str) 39 | 40 | parser.add_argument("--cand_bucket_size", default = 100, type = int) 41 | 42 | parser.add_argument("--meta_batch_size", default = 8, type = int) 43 | parser.add_argument("--num_shot", default = 1, type = int) 44 | parser.add_argument("--num_train_inner_iter", default = 5, type = int) 45 | parser.add_argument("--num_test_inner_iter", default = 5, type = int) 46 | 47 | parser.add_argument("--inner_lr", default = 1e-3, type = float) 48 | parser.add_argument("--meta_lr", default = 1e-3, type = float) 49 | 50 | parser.add_argument("--dim_emb", default = 100, type = int) 51 | 52 | parser.add_argument("--encoder_num_cnn_layer", default = 3, type = int) 53 | 
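    # Annotation (not in the original source): the encoder flags below are
    # coupled.  MemoryEncoder asserts that --encoder_dim_conv_filter is odd
    # (so that padding = dim_conv_filter // 2 preserves the sequence length),
    # expects one entry in --encoder_num_conv_filter per CNN layer, and one
    # entry in --encoder_num_head per self-attention layer (num_cnn_layer - 1
    # with the defaults), with each filter count divisible by the head count
    # at the same index.  An illustrative invocation that satisfies these
    # constraints (assumed, extending the README examples):
    #   python main_reptile.py --train --dataset dbpedia --idx_device 0 \
    #       --encoder_num_cnn_layer 3 --encoder_dim_conv_filter 3 \
    #       --encoder_num_conv_filter 100 100 100 --encoder_num_head 5 5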
parser.add_argument("--encoder_dim_conv_filter", default = 3, type = int) 54 | parser.add_argument("--encoder_num_conv_filter", nargs = "+", default = [100, 100, 100], type = int) 55 | parser.add_argument("--encoder_normalization", default = "instance_norm", type = str) 56 | parser.add_argument("--encoder_num_memory", default = 128, type = int) 57 | parser.add_argument("--no_encoder_self_atten", dest = "encoder_self_atten", action = "store_false") 58 | parser.set_defaults(encoder_self_atten = True) 59 | parser.add_argument("--encoder_num_head", nargs = "+", default = [5, 5], type = int) 60 | parser.add_argument("--encoder_act_func", default = "tanh", type = str) 61 | 62 | parser.add_argument("--num_aug", default = 0, type = int) 63 | parser.add_argument("--no_vae", dest = "use_vae", action = "store_false") 64 | parser.set_defaults(use_vae = True) 65 | parser.add_argument("--vae_num_conv_filter", default = 100, type = int) 66 | parser.add_argument("--vae_dim_hidden", default = 100, type = int) 67 | 68 | parser.add_argument("--vae_prior_nn_num_layer", default = 2, type = int) 69 | parser.add_argument("--vae_prior_nn_dim_hidden", nargs = "+", default = [100, 100], type = int) 70 | 71 | parser.add_argument("--vae_dim_latent", default = 50, type = int) 72 | parser.add_argument("--no_vae_prior", dest = "vae_use_prior", action = "store_false") 73 | parser.set_defaults(vae_use_prior = True) 74 | parser.add_argument("--prior_sigma_m", default = 1e4, type = float) 75 | parser.add_argument("--prior_sigma_s", default = 1e-4, type = float) 76 | parser.add_argument("--vae_lambda_kld", default = 1.0, type = float) 77 | parser.add_argument("--vae_lambda_reg", default = 1.0, type = float) 78 | parser.add_argument("--vae_normalization", default = "instance_norm", type = str) 79 | parser.add_argument("--vae_act_func", default = "tanh", type = str) 80 | 81 | parser.add_argument("--cnn_encoder", dest = "memory", action = "store_false") 82 | parser.set_defaults(memory = True) 83 | 84 | parser.add_argument("--sim_func", default = 'TransE', type = str) 85 | 86 | parser.add_argument("--seed", default = 1550148948, type = int) 87 | hyp = parser.parse_args() 88 | 89 | if dataset is None: # main() 90 | dataset_config = get_dataset_config(hyp.dataset) 91 | else: # prepare_data() 92 | dataset_config = get_dataset_config(dataset) 93 | for k, v in dataset_config.items(): 94 | if k == "num_aug": 95 | if not hyp.use_vae or hyp.num_aug != 0: 96 | continue 97 | setattr(hyp, k, v) 98 | 99 | if dataset is None: 100 | if hyp.is_training is None: 101 | print("--train or --test must be specified") 102 | exit(-1) 103 | elif hyp.is_training: 104 | if hyp.existing_timestamp != "" and hyp.existing_epoch != 0: 105 | hyp.load_existing_model = True 106 | else: 107 | hyp.load_existing_model = False 108 | new_timestamp = str(time.time()).split(".")[0] 109 | hyp.existing_timestamp = new_timestamp 110 | else: 111 | hyp.load_existing_model = True 112 | if hyp.existing_timestamp == "": 113 | print("existing_timestamp must be specified") 114 | exit(-1) 115 | if hyp.existing_epoch == 0: 116 | print("existing_epoch must be specified") 117 | exit(-1) 118 | 119 | if dataset is None and hyp.seed == 0: 120 | hyp.seed = hyp.existing_timestamp 121 | 122 | hyp.model = hyp.model.upper() 123 | hyp.prefix = hyp.model 124 | 125 | return hyp 126 | 127 | def log_hyp(hyp): 128 | info("----------------------------\n\nhyperparameters:\n") 129 | for k, v in vars(hyp).items(): 130 | info("{} = {}".format(k, v)) 131 | 
info("------------------------------------------\n") 132 | 133 | if __name__ == "__main__": 134 | logger = logging.getLogger() 135 | logger.setLevel(logging.DEBUG) 136 | ch = logging.StreamHandler() 137 | ch.setLevel(logging.INFO) 138 | logger.addHandler(ch) 139 | 140 | hyp = init_hyp() 141 | hyp.aa = 1 # add extra hyper-parameters after parsing 142 | log_hyp(hyp) 143 | -------------------------------------------------------------------------------- /data_reptile.py: -------------------------------------------------------------------------------- 1 | from sys import exit 2 | import pickle 3 | import logging 4 | from logging import info, error 5 | from random import sample, choices 6 | import torch as T 7 | 8 | from hyp_reptile import * 9 | 10 | def load_raw_pkl(hyp): 11 | pkl_path = hyp.raw_pkl_path + ".debug" if hyp.is_debugging else hyp.raw_pkl_path 12 | info("start loading raw pickle from {}".format(pkl_path)) 13 | all_data = pickle.load(open(pkl_path, "rb")) 14 | train_task, dev_task, test_task, ent_dscps, rel_dscps, i2r, w2i, i2w, rel_cand, e1rel_e2 = all_data 15 | 16 | info("#task in each split: #train = {}, #dev = {}, #test = {}".format( 17 | len(train_task), len(dev_task), len(test_task))) 18 | info("#ent = {}, #rel = {}, #word = {}".format( 19 | len(ent_dscps), len(i2r), len(w2i))) 20 | 21 | return all_data 22 | 23 | class BaseDataset(object): 24 | def __init__(self, dataset, ent_dscps, rel_dscps, rel2cand, e1rel_e2, hyp, 25 | task_type, num_shot, num_inner_iter): 26 | self.dataset = dataset 27 | self.ent_dscps = ent_dscps 28 | self.rel_dscps = rel_dscps 29 | self.rel2cand = rel2cand 30 | self.e1rel_e2 = e1rel_e2 31 | self.hyp = hyp 32 | 33 | self.idx_word_pad = hyp.idx_word_pad 34 | self.device = hyp.device 35 | self.task_type = task_type 36 | self.num_shot = hyp.num_shot 37 | self.num_inner_iter = num_inner_iter 38 | 39 | self.dataset_rels = list(self.dataset.keys()) 40 | self.num_all_data = 0 41 | for rel in self.dataset_rels: 42 | self.num_all_data += len(self.dataset[rel]) 43 | 44 | self.task_size = self.num_shot if task_type == "train" else 1 45 | self.idx_sup_inner = 0 46 | self.sup_task = None 47 | self.idx_sup = list(range(self.task_size)) 48 | self.update_len = lambda new_len, exist_len : new_len if new_len > exist_len else exist_len 49 | 50 | def sample_sup_noise(self, ent1, rel, ent2): 51 | existing_ent2 = self.e1rel_e2[(ent1, rel)] 52 | while 1: 53 | noise = sample(self.rel2cand[rel], 1)[0] 54 | if noise not in existing_ent2 and noise != ent2: 55 | break 56 | 57 | return noise 58 | 59 | def init_sup_tensors(self, task_src): 60 | self.sup_task = [] 61 | for i, (ent1, ent2, entn, rel) in enumerate(task_src): 62 | dscp1, dscp2, dscpn, dscpr = self.ent_dscps[ent1], self.ent_dscps[ent2], self.ent_dscps[entn], self.rel_dscps[rel] 63 | self.sup_task.append((dscp1, dscp2, dscpn, dscpr)) 64 | 65 | def get_inner_data(self): 66 | batch = [e.to(self.device) for e in self.sup_task[self.idx_sup_inner]] 67 | 68 | self.idx_sup_inner += 1 69 | if self.idx_sup_inner >= self.num_shot: 70 | self.sup_task = None 71 | 72 | return batch 73 | 74 | class TrainDataset(BaseDataset): 75 | def __init__(self, task_type, dataset, ent_dscps, rel_dscps, rel2cand, e1rel_e2, hyp): 76 | BaseDataset.__init__(self, dataset, ent_dscps, rel_dscps, rel2cand, e1rel_e2, hyp, 77 | task_type, hyp.num_shot, hyp.num_train_inner_iter) 78 | info("finish initializing {} dataset".format(task_type)) 79 | 80 | def get_sup_inner_data(self): 81 | if self.sup_task is None: 82 | self.__init_sup_task() 83 | return 
self.get_inner_data() 84 | 85 | def __init_sup_task(self): 86 | self.idx_sup_inner = 0 87 | task_src = [] 88 | 89 | rel = sample(self.dataset_rels, 1)[0] 90 | rel_duos = self.dataset[rel] 91 | if self.num_shot > len(rel_duos): 92 | task_duo = [choice(rel_duos) for i in range(self.num_shot)] 93 | else: 94 | task_duo = sample(rel_duos, self.num_shot) 95 | for ent1, ent2 in task_duo: 96 | noise = self.sample_sup_noise(ent1, rel, ent2) 97 | task_src.append((ent1, ent2, noise, rel)) 98 | 99 | self.init_sup_tensors(task_src) 100 | 101 | class EvalDataset(BaseDataset): 102 | def __init__(self, task_type, dataset, ent_dscps, rel_dscps, rel2cand, e1rel_e2, hyp): 103 | BaseDataset.__init__(self, dataset, ent_dscps, rel_dscps, rel2cand, e1rel_e2, hyp, 104 | task_type, hyp.num_shot, hyp.num_test_inner_iter) 105 | self.cand_bucket_size = hyp.cand_bucket_size 106 | 107 | self.__reset() 108 | info("finish initializing {} dataset".format(task_type)) 109 | 110 | def next_rel(self): 111 | self.idx_ent = self.num_shot 112 | if self.idx_rel >= len(self.dataset_rels): 113 | self.__reset() 114 | return False 115 | else: 116 | self.cur_rel = self.dataset_rels[self.idx_rel] 117 | self.cur_task = self.dataset[self.cur_rel] 118 | self.idx_rel += 1 119 | return True 120 | 121 | def get_sup_inner_data(self): 122 | if self.sup_task is None: 123 | self.__init_sup_task() 124 | return self.get_inner_data() 125 | 126 | def get_qur_task_cand_rel(self): 127 | cand_buckets = [] 128 | bucket = [] 129 | for cand in self.rel2cand[self.cur_rel]: # cands in rel2cand have been sorted aescendingly 130 | if len(bucket) >= self.cand_bucket_size: 131 | cand_buckets.append(bucket) 132 | bucket = [] 133 | bucket.append(cand) 134 | if len(bucket) != 0: 135 | cand_buckets.append(bucket) 136 | 137 | batch_dscp_cand = [] 138 | for bucket in cand_buckets: 139 | max_bucket_dscp_len = len(self.ent_dscps[bucket[-1]]) 140 | t_dscp_cand = T.full((len(bucket), max_bucket_dscp_len), self.idx_word_pad, dtype = T.long) 141 | for i, cand in enumerate(bucket): 142 | cand_dscp = self.ent_dscps[cand] 143 | t_dscp_cand[i, :len(cand_dscp)] = cand_dscp 144 | batch_dscp_cand.append(t_dscp_cand.to(self.device)) 145 | 146 | t_dscpr = self.rel_dscps[self.cur_rel].long().to(self.device) 147 | return batch_dscp_cand, t_dscpr 148 | 149 | # this method only iterate through entities within a relation 150 | def get_qur_inner_data(self): 151 | if self.idx_ent >= len(self.cur_task): 152 | return None 153 | 154 | ent1, ent_true = self.cur_task[self.idx_ent] 155 | self.idx_ent += 1 156 | 157 | t_dscp1 = self.ent_dscps[ent1].to(self.device) 158 | t_dscp2 = self.ent_dscps[ent_true].to(self.device) 159 | 160 | cands = [] 161 | for idx_cand, cand in enumerate(self.rel2cand[self.cur_rel]): 162 | if cand not in self.e1rel_e2[(ent1, self.cur_rel)] and cand != ent_true: 163 | cands.append(idx_cand) # local index of candidate entity in rel2cand[cur_rel] 164 | t_idx_cands = T.tensor(cands, dtype = T.long, device = self.device) 165 | 166 | qur_task = [t_dscp1, t_dscp2, t_idx_cands] 167 | return qur_task 168 | 169 | def __init_sup_task(self): 170 | self.idx_sup_inner = 0 171 | task_src = [] 172 | 173 | for ent1, ent2 in self.cur_task[:self.num_shot]: 174 | noise = self.sample_sup_noise(ent1, self.cur_rel, ent2) 175 | task_src.append((ent1, ent2, noise, self.cur_rel)) 176 | self.init_sup_tensors(task_src) 177 | 178 | def __reset(self): 179 | self.idx_rel = 0 180 | self.idx_ent = self.num_shot 181 | self.cur_rel = self.dataset_rels[self.idx_rel] 182 | self.cur_task = 
self.dataset[self.cur_rel] 183 | 184 | if __name__ == "__main__": 185 | hyp = init_hyp() 186 | hyp.device = T.device("cpu") 187 | 188 | logger = logging.getLogger() 189 | logger.setLevel(logging.DEBUG) 190 | formatter = logging.Formatter('%(asctime)s %(levelname)s: - %(message)s', datefmt = '%Y-%m-%d %H:%M:%S') 191 | ch = logging.StreamHandler() 192 | ch.setLevel(logging.INFO) 193 | ch.setFormatter(formatter) 194 | logger.addHandler(ch) 195 | 196 | train_task, dev_task, test_task, ent_dscps, rel_dscps, i2r, w2i, i2w, rel_cand, e1rel_e2 = all_data 197 | -------------------------------------------------------------------------------- /triplet_vae.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from sys import exit 3 | import torch as T 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from basic_meta_learning_module import BasicMetaLearningModule 7 | 8 | class TripletVAE(BasicMetaLearningModule): 9 | def __init__(self, hyp): 10 | super(TripletVAE, self).__init__(hyp) 11 | self.dim_input = hyp.encoder_num_conv_filter[-1] 12 | self.num_aug = hyp.num_aug 13 | self.num_conv_filter = hyp.vae_num_conv_filter 14 | self.dim_hidden = hyp.vae_dim_hidden 15 | self.prior_nn_num_layer = hyp.vae_prior_nn_num_layer 16 | self.prior_nn_dim_hidden = hyp.vae_prior_nn_dim_hidden 17 | 18 | self.dim_latent = hyp.vae_dim_latent 19 | 20 | self.use_prior = hyp.vae_use_prior 21 | self.prior_sigma_m = hyp.prior_sigma_m 22 | self.prior_sigma_s = hyp.prior_sigma_s 23 | self.lambda_kld = hyp.vae_lambda_kld 24 | self.lambda_reg = hyp.vae_lambda_reg 25 | self.normalization = hyp.vae_normalization 26 | self.act_func = T.tanh if hyp.vae_act_func == "tanh" else F.leaky_relu 27 | 28 | self.__init_params() 29 | 30 | def learn(self, dscp1, dscp2, dscpr): 31 | triplet = T.stack([dscp1, dscpr, dscp2], 1).transpose(1, 2) # shape: (num_shot, dim_input, 3) 32 | proposal_mu, proposal_sigma = self.__proposal_network(triplet) 33 | prior_mu, prior_sigma = self.__prior_network(dscpr) 34 | eps = T.randn(proposal_mu.shape, device = self.device) 35 | recon_triplet = self.__generative_network(proposal_mu, proposal_sigma, eps) 36 | 37 | log_proposal_sigma, log_prior_sigma = T.log(proposal_sigma), T.log(prior_sigma) 38 | loss_reconstruct = T.mean((triplet - recon_triplet) ** 2) 39 | kld = log_prior_sigma - log_proposal_sigma + 0.5 * (proposal_sigma ** 2 + (proposal_mu - prior_mu) ** 2) / prior_sigma ** 2 - 0.5 40 | kld = self.lambda_kld * T.mean(kld) 41 | reg = 0 42 | if self.use_prior: 43 | reg = prior_mu ** 2 / (2 * self.prior_sigma_m ** 2) - self.prior_sigma_s * (log_prior_sigma - prior_sigma) 44 | reg = self.lambda_reg * T.mean(reg) 45 | loss = loss_reconstruct + kld + reg 46 | 47 | loss_print = [["vae_reconstruct", loss_reconstruct.data.cpu().item()], 48 | ["vae_kld", kld.data.cpu().item()]] 49 | if self.use_prior: 50 | loss_print.append(["vae_regularization", reg.data.cpu().item()]) 51 | 52 | return loss, loss_print 53 | 54 | def generate(self, dscpr): 55 | prior_mu, prior_sigma = self.__prior_network(dscpr) 56 | eps = T.randn((self.num_aug, self.dim_latent), device = self.device) 57 | gen_triplets = self.__generative_network(prior_mu, prior_sigma, eps) 58 | ent1, ent2 = gen_triplets[:, :, 0], gen_triplets[:, :, 2] 59 | 60 | return ent1, ent2 61 | 62 | def __generative_network(self, mu, sigma, eps): 63 | x = mu + sigma * eps 64 | 65 | for i in range(2): 66 | x = T.mm(x, self.params["generative_w_hidden{}".format(i)]) + 
self.params["generative_b_hidden{}".format(i)] 67 | x = self.norm(self.normalization, x, "generative", i) 68 | x = self.act_func(x) 69 | x = x.unsqueeze(2) 70 | x = F.conv_transpose1d(x, self.params["generative_w_deconv0"], self.params["generative_b_deconv0"]) 71 | x = self.norm(self.normalization, x, "generative", 2) 72 | x = self.act_func(x) 73 | x = F.conv_transpose1d(x, self.params["generative_w_deconv1"], self.params["generative_b_deconv1"]) 74 | x = self.act_func(x) 75 | 76 | return x 77 | 78 | def __proposal_network(self, x): 79 | for i in range(2): 80 | x = F.conv1d(x, self.params["proposal_w_conv{}".format(i)], self.params["proposal_b_conv{}".format(i)]) 81 | x = self.norm(self.normalization, x, "proposal", i) 82 | x = self.act_func(x) 83 | x = x.squeeze(2) 84 | x = T.mm(x, self.params["proposal_w_hidden"]) + self.params["proposal_b_hidden"] 85 | x = self.norm(self.normalization, x, "proposal", 2) 86 | x = self.act_func(x) 87 | 88 | mu = T.mm(x, self.params["proposal_w_mu"]) + self.params["proposal_b_mu"] 89 | sigma = F.softplus(T.mm(x, self.params["proposal_w_sigma"]) + self.params["proposal_b_sigma"]) 90 | return mu, sigma 91 | 92 | def __prior_network(self, x): 93 | for i in range(self.prior_nn_num_layer): 94 | x = T.mm(x, self.params["prior_w_hidden{}".format(i)]) + self.params["prior_b_hidden{}".format(i)] 95 | x = self.norm(self.normalization, x, "prior", i) 96 | x = self.act_func(x) 97 | mu = T.mm(x, self.params["prior_w_mu"]) + self.params["prior_b_mu"] 98 | sigma = F.softplus(T.mm(x, self.params["prior_w_sigma"]) + self.params["prior_b_sigma"]) 99 | return mu, sigma 100 | 101 | def __init_params(self): 102 | # proposal network 103 | self.params["proposal_w_conv0"] = self.xavier(self.num_conv_filter, self.dim_input, 2) 104 | self.params["proposal_b_conv0"] = self.fill(0.0, self.num_conv_filter) 105 | self.add_norm_param(self.normalization, self.num_conv_filter, "proposal", 0) 106 | self.params["proposal_w_conv1"] = self.xavier(self.num_conv_filter, self.num_conv_filter, 2) 107 | self.params["proposal_b_conv1"] = self.fill(0.0, self.num_conv_filter) 108 | self.add_norm_param(self.normalization, self.num_conv_filter, "proposal", 1) 109 | self.params["proposal_w_hidden"] = self.xavier(self.num_conv_filter, self.dim_hidden) 110 | self.params["proposal_b_hidden"] = self.fill(0.0, self.dim_hidden) 111 | self.add_norm_param(self.normalization, self.dim_hidden, "proposal", 2) 112 | self.params["proposal_w_mu"] = self.xavier(self.dim_hidden, self.dim_latent) 113 | self.params["proposal_b_mu"] = self.fill(0.0, self.dim_latent) 114 | self.params["proposal_w_sigma"] = self.xavier(self.dim_hidden, self.dim_latent) 115 | self.params["proposal_b_sigma"] = self.fill(0.0, self.dim_latent) 116 | 117 | # prior network 118 | self.params["prior_w_hidden0"] = self.xavier(self.dim_input, self.prior_nn_dim_hidden[0]) 119 | self.params["prior_b_hidden0"] = self.fill(0.0, self.prior_nn_dim_hidden[0]) 120 | self.add_norm_param(self.normalization, self.prior_nn_dim_hidden[0], "prior", 0) 121 | for i in range(1, self.prior_nn_num_layer): 122 | self.params["prior_w_hidden{}".format(i)] = self.xavier(self.prior_nn_dim_hidden[i - 1], self.prior_nn_dim_hidden[i]) 123 | self.params["prior_b_hidden{}".format(i)] = self.fill(0.0, self.prior_nn_dim_hidden[i]) 124 | self.add_norm_param(self.normalization, self.prior_nn_dim_hidden[i], "prior", i) 125 | self.params["prior_w_mu"] = self.xavier(self.prior_nn_dim_hidden[-1], self.dim_latent) 126 | self.params["prior_b_mu"] = self.fill(0.0, self.dim_latent) 127 | 
self.params["prior_w_sigma"] = self.xavier(self.prior_nn_dim_hidden[-1], self.dim_latent) 128 | self.params["prior_b_sigma"] = self.fill(0.0, self.dim_latent) 129 | 130 | # generative network 131 | self.params["generative_w_hidden0"] = self.xavier(self.dim_latent, self.dim_hidden) 132 | self.params["generative_b_hidden0"] = self.fill(0.0, self.dim_hidden) 133 | self.add_norm_param(self.normalization, self.dim_hidden, "generative", 0) 134 | self.params["generative_w_hidden1"] = self.xavier(self.dim_hidden, self.num_conv_filter) 135 | self.params["generative_b_hidden1"] = self.fill(0.0, self.num_conv_filter) 136 | self.add_norm_param(self.normalization, self.num_conv_filter, "generative", 1) 137 | self.params["generative_w_deconv0"] = self.xavier(self.num_conv_filter, self.num_conv_filter, 2) 138 | self.params["generative_b_deconv0"] = self.fill(0.0, self.num_conv_filter) 139 | self.add_norm_param(self.normalization, self.num_conv_filter, "generative", 2) 140 | self.params["generative_w_deconv1"] = self.xavier(self.num_conv_filter, self.dim_input, 2) 141 | self.params["generative_b_deconv1"] = self.fill(0.0, self.dim_input) 142 | -------------------------------------------------------------------------------- /memory_encoder.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from sys import exit 3 | import torch as T 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from collections import OrderedDict 7 | from basic_meta_learning_module import BasicMetaLearningModule 8 | 9 | class MemoryEncoder(BasicMetaLearningModule): 10 | def __init__(self, hyp): 11 | super(MemoryEncoder, self).__init__(hyp) 12 | self.dict_size = hyp.dict_size 13 | self.dim_emb = hyp.dim_emb 14 | self.idx_word_pad = hyp.idx_word_pad 15 | 16 | self.num_cnn_layer = hyp.encoder_num_cnn_layer 17 | self.dim_conv_filter = hyp.encoder_dim_conv_filter 18 | self.num_conv_filter = hyp.encoder_num_conv_filter 19 | self.dim_pool_kernel = hyp.encoder_dim_pool_kernel 20 | self.normalization = hyp.encoder_normalization 21 | self.num_memory = hyp.encoder_num_memory 22 | self.num_head = hyp.encoder_num_head 23 | self.act_func = T.tanh if hyp.encoder_act_func == "tanh" else F.relu 24 | self.use_self_atten = hyp.encoder_self_atten 25 | 26 | assert self.dim_conv_filter % 2 != 0 27 | self.padding_size = self.dim_conv_filter // 2 28 | self.dim_head = [] 29 | for i in range(len(self.num_head)): 30 | assert self.num_conv_filter[i] % self.num_head[i] == 0 31 | self.dim_head.append(self.num_conv_filter[i] // self.num_head[i]) 32 | 33 | self.__init_params() 34 | 35 | def encode_rel(self, dscpr): 36 | r_emb = self.__rel_CNN(dscpr) 37 | 38 | if self.num_memory > 0: 39 | atten = F.cosine_similarity(self.params["w_head_memory"], r_emb.unsqueeze(0), 1) # shape: (num_memory) 40 | atten = atten.unsqueeze(0) 41 | atten = T.softmax(atten, 1).transpose(0, 1) 42 | head_memory_r = T.sum(atten * self.params["w_head_trans"], 0) 43 | 44 | atten = F.cosine_similarity(self.params["w_tail_memory"], r_emb.unsqueeze(0), 1) # shape: (num_memory) 45 | atten = atten.unsqueeze(0) 46 | atten = T.softmax(atten, 1).transpose(0, 1) 47 | tail_memory_r = T.sum(atten * self.params["w_tail_trans"], 0) 48 | 49 | return r_emb.unsqueeze(0), head_memory_r, tail_memory_r 50 | else: 51 | return r_emb.unsqueeze(0), r_emb, r_emb 52 | 53 | def __rel_CNN(self, x): 54 | if len(x.shape) == 1: 55 | x = x.unsqueeze(0) 56 | x = F.embedding(x, self.params["w_emb"], self.idx_word_pad, 1.0) 57 | x = x.permute(0, 2, 
1) 58 | x = F.conv1d(x, self.params["w_conv00"], self.params["b_conv00"], padding = self.padding_size) 59 | x = F.conv1d(x, self.params["w_conv01"], self.params["b_conv01"], padding = self.padding_size) 60 | self.norm(self.normalization, x, "encoder", 0) 61 | x = self.act_func(x) 62 | x = F.max_pool1d(x, self.dim_pool_kernel, self.dim_pool_kernel, ceil_mode = True) # len_x /= 2 63 | 64 | for i in range(1, self.num_cnn_layer): 65 | x = F.conv1d(x, self.params["w_conv{}0".format(i)], self.params["b_conv{}0".format(i)], padding = self.padding_size) 66 | x = F.conv1d(x, self.params["w_conv{}1".format(i)], self.params["b_conv{}1".format(i)], padding = self.padding_size) 67 | self.norm(self.normalization, x, "encoder", i) 68 | x = self.act_func(x) 69 | if i == self.num_cnn_layer - 1: 70 | x = T.mean(x, -1) 71 | else: 72 | x = F.max_pool1d(x, self.dim_pool_kernel, self.dim_pool_kernel, ceil_mode = True) 73 | 74 | x = x.squeeze() 75 | return x # shape: (num_conv_filter[-1]) 76 | 77 | def forward(self, x, memory_r): 78 | if len(x.shape) == 1: 79 | x = x.unsqueeze(0) 80 | x_emb = F.embedding(x, self.params["w_emb"], self.idx_word_pad, 1.0) 81 | x = x_emb.permute(0, 2, 1) 82 | 83 | x = self.__CNN(x, memory_r) 84 | x = self.__SelfAtten(x) 85 | 86 | return x # shape: (batch_size, num_conv_filter[-1]) 87 | 88 | def __CNN(self, x, memory_r): 89 | x = F.conv1d(x, self.params["w_conv00"], self.params["b_conv00"], padding = self.padding_size) 90 | x = F.conv1d(x, self.params["w_conv01"], self.params["b_conv01"], padding = self.padding_size) 91 | self.norm(self.normalization, x, "encoder", 0) 92 | x = self.act_func(x) 93 | 94 | if self.num_memory > 0: 95 | x = x.transpose(1, 2) 96 | memory_atten = F.cosine_similarity(x, memory_r.unsqueeze(0).unsqueeze(0), 2) 97 | memory_atten = T.softmax(memory_atten, 1) 98 | x = x * memory_atten.unsqueeze(-1) 99 | x = x.transpose(1, 2) 100 | 101 | x = F.max_pool1d(x, self.dim_pool_kernel, self.dim_pool_kernel, ceil_mode = True) # len_x /= 2 102 | return x 103 | 104 | def __SelfAtten(self, x): 105 | batch_size = x.shape[0] 106 | for i in range(1, self.num_cnn_layer): 107 | x = F.conv1d(x, self.params["w_conv{}0".format(i)], self.params["b_conv{}0".format(i)], padding = self.padding_size) 108 | x = F.conv1d(x, self.params["w_conv{}1".format(i)], self.params["b_conv{}1".format(i)], padding = self.padding_size) 109 | self.norm(self.normalization, x, "encoder", i) 110 | x = self.act_func(x) 111 | 112 | if self.use_self_atten: 113 | len_x = x.shape[2] 114 | num_head, dim_head = self.num_head[i - 1], self.dim_head[i - 1] 115 | x = x.transpose(1, 2) 116 | x = x.view(batch_size, len_x, num_head, dim_head).transpose(1, 2) 117 | atten = T.matmul(x, x.transpose(-2, -1)) / (dim_head ** 0.5) 118 | atten = T.softmax(atten, -1) # shape: (batch_size, len_x, num_head[i - 1], num_head[i - 1]) 119 | x = T.matmul(atten, x) 120 | x = x.transpose(1, 2).contiguous().view(batch_size, len_x, num_head * dim_head) 121 | x = x.transpose(1, 2) 122 | 123 | if i == self.num_cnn_layer - 1: 124 | x = T.mean(x, -1) 125 | else: 126 | x = F.max_pool1d(x, self.dim_pool_kernel, self.dim_pool_kernel, ceil_mode = True) 127 | 128 | return x 129 | 130 | def __init_params(self): 131 | w_emb = nn.init.xavier_uniform_(T.empty((self.dict_size, self.dim_emb), device = self.device)) 132 | w_emb[self.idx_word_pad].fill_(0.0) 133 | self.params["w_emb"] = nn.Parameter(w_emb) # w_emb 134 | 135 | num_in_filter = self.dim_emb 136 | num_out_filter = self.num_conv_filter[0] 137 | self.params["w_conv00"] = 
self.xavier(num_out_filter, num_in_filter, self.dim_conv_filter) 138 | self.params["b_conv00"] = self.fill(0.0, num_out_filter) 139 | #''' 140 | num_in_filter = self.num_conv_filter[0] 141 | num_out_filter = self.num_conv_filter[0] 142 | self.params["w_conv01"] = self.xavier(num_out_filter, num_in_filter, self.dim_conv_filter) 143 | self.params["b_conv01"] = self.fill(0.0, num_out_filter) 144 | #''' 145 | 146 | self.add_norm_param(self.normalization, num_out_filter, "encoder", 0) 147 | 148 | if self.num_memory > 0: 149 | self.params["w_head_memory"] = self.unif(self.num_memory, num_out_filter) 150 | self.params["w_tail_memory"] = self.unif(self.num_memory, num_out_filter) 151 | self.params["w_head_trans"] = self.unif(self.num_memory, num_out_filter) 152 | self.params["w_tail_trans"] = self.unif(self.num_memory, num_out_filter) 153 | 154 | for i in range(1, self.num_cnn_layer): 155 | num_in_filter = self.num_conv_filter[i - 1] 156 | num_out_filter = self.num_conv_filter[i] 157 | # conv1d 158 | self.params["w_conv{}0".format(i)] = self.xavier(num_out_filter, num_in_filter, self.dim_conv_filter) 159 | self.params["b_conv{}0".format(i)] = self.fill(0.0, num_out_filter) 160 | #''' 161 | num_in_filter = self.num_conv_filter[i] 162 | num_out_filter = self.num_conv_filter[i] 163 | self.params["w_conv{}1".format(i)] = self.xavier(num_out_filter, num_in_filter, self.dim_conv_filter) 164 | self.params["b_conv{}1".format(i)] = self.fill(0.0, num_out_filter) 165 | #''' 166 | 167 | self.add_norm_param(self.normalization, num_out_filter, "encoder", i) 168 | 169 | 170 | -------------------------------------------------------------------------------- /reptile.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from sys import exit 3 | from logging import info, error 4 | from metrics import Metrics 5 | import torch as T 6 | from torch import nn, optim 7 | import torch.nn.functional as F 8 | import numpy as np 9 | from collections import defaultdict 10 | 11 | #from cnn_encoder import CNN_Encoder 12 | from memory_encoder import MemoryEncoder 13 | from triplet_vae import TripletVAE 14 | 15 | class REPTILE(nn.Module): 16 | def __init__(self, hyp): 17 | super(REPTILE, self).__init__() 18 | self.hyp = hyp 19 | self.memory = hyp.memory 20 | self.device = hyp.device 21 | 22 | if self.memory: 23 | self.encoder = MemoryEncoder(hyp) 24 | info("finish initializing memory encoder") 25 | else: 26 | exit(-1) 27 | #self.encoder = CNN_Encoder(hyp) 28 | #info("finish initializing CNN encoder") 29 | 30 | self.vae = TripletVAE(hyp) 31 | self.modules = { 32 | "encoder" : self.encoder, 33 | "vae" : self.vae 34 | } 35 | 36 | if hyp.sim_func == "TransE": 37 | self.sim_func = lambda dscp1_emb, dscp2_emb, dscpr_emb: T.sum(T.abs(dscp1_emb + dscpr_emb - dscp2_emb), -1) 38 | else: 39 | error("sim_func {} has not implemented. 
Candidates: TransE".format(hyp.sim_func)) 40 | exit(-1) 41 | 42 | if hyp.use_vae: 43 | info("VAE is used to augment data") 44 | self.optm = optim.Adam(self.parameters(), lr = hyp.inner_lr, betas = (0.0, 0.999)) 45 | else: 46 | self.optm = optim.Adam(self.encoder.parameters(), lr = hyp.inner_lr, betas = (0.0, 0.999)) 47 | info("finish initializing REPTILE") 48 | 49 | def get_state_dict(self): 50 | state_dict = { 51 | "encoder" : self.encoder.state_dict(), 52 | "vae" : self.vae.state_dict(), 53 | "optm" : self.optm.state_dict() 54 | } 55 | return state_dict 56 | 57 | def set_state_dict(self, checkpoint): 58 | self.encoder.load_state_dict(checkpoint["encoder"]) 59 | self.vae.load_state_dict(checkpoint["vae"]) 60 | self.optm.load_state_dict(checkpoint["optm"]) 61 | 62 | def meta_train(self, train_dataset): 63 | weights_original = self.copy_param() 64 | new_weights = None 65 | batch_loss = defaultdict(float) 66 | 67 | for idx_batch in range(self.hyp.meta_batch_size): 68 | dscp1, dscp2, dscpn, dscpr = train_dataset.get_sup_inner_data() 69 | for idx_iter in range(self.hyp.num_train_inner_iter): 70 | loss = self.__inner_train_step(dscp1, dscp2, dscpn, dscpr, True) 71 | for name, val in loss: 72 | batch_loss[name] += val 73 | if idx_batch == 0: 74 | new_weights = self.copy_param() 75 | else: 76 | tmp = self.copy_param() 77 | for module_name in tmp: 78 | for param_name in tmp[module_name]: 79 | new_weights[module_name][param_name] += tmp[module_name][param_name] 80 | self.set_param(weights_original) 81 | tmp = self.hyp.num_train_inner_iter * self.hyp.meta_batch_size 82 | for name in batch_loss: 83 | batch_loss[name] /= tmp 84 | 85 | if self.hyp.meta_batch_size > 1: 86 | for module_name in new_weights: 87 | for param_name in new_weights[module_name]: 88 | new_weights[module_name][param_name] /= self.hyp.meta_batch_size 89 | 90 | meta_updated_weights = defaultdict(dict) 91 | for module_name in new_weights: 92 | for param_name in new_weights[module_name]: 93 | meta_grad = weights_original[module_name][param_name] - new_weights[module_name][param_name] 94 | meta_updated_weights[module_name][param_name] = weights_original[module_name][param_name] - self.hyp.meta_lr * meta_grad 95 | self.set_param(meta_updated_weights) 96 | 97 | return batch_loss 98 | 99 | def __inner_train_step(self, dscp1, dscp2, dscpn, dscpr, meta_training): 100 | if self.memory: 101 | dscpr_emb, head_memory_dscpr_emb, tail_memory_dscpr_emb = self.encoder.encode_rel(dscpr) 102 | dscp1_emb = self.encoder(dscp1, head_memory_dscpr_emb) 103 | dscp2_emb = self.encoder(dscp2, tail_memory_dscpr_emb) 104 | dscpn_emb = self.encoder(dscpn, tail_memory_dscpr_emb) 105 | else: 106 | dscpr_emb = self.encoder(dscpr) 107 | dscp1_emb = self.encoder(dscp1) 108 | dscp2_emb = self.encoder(dscp2) 109 | dscpn_emb = self.encoder(dscpn) 110 | 111 | if self.hyp.use_vae and meta_training: 112 | vae_loss, vae_loss_print = self.vae.learn(dscp1_emb.detach(), dscp2_emb.detach(), dscpr_emb.detach()) 113 | else: 114 | vae_loss = T.zeros(1, device = self.device) 115 | 116 | if self.hyp.use_vae and not meta_training: 117 | aug1_emb, aug2_emb = self.vae.generate(dscpr_emb.detach()) 118 | dscp1_emb = T.cat([dscp1_emb, aug1_emb.detach()], 0) 119 | dscp2_emb = T.cat([dscp2_emb, aug2_emb.detach()], 0) 120 | 121 | sim_pos = self.sim_func(dscp1_emb, dscp2_emb, dscpr_emb) 122 | sim_neg = self.sim_func(dscp1_emb, dscpn_emb, dscpr_emb) 123 | kg_loss = T.mean(F.relu(self.hyp.margin + sim_pos - sim_neg)) 124 | loss = kg_loss + vae_loss 125 | 126 | self.optm.zero_grad() 127 | 
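        # One inner-loop step: backpropagate the combined margin-ranking + VAE
        # loss and let Adam update the fast weights in place.  meta_train()
        # later averages the adapted weights over the meta-batch and applies
        # the Reptile outer update
        #   theta <- theta - meta_lr * (theta - mean over tasks of theta_adapted),
        # i.e. it nudges the shared initialization toward the task-adapted solutions.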
loss.backward() 128 | self.optm.step() 129 | 130 | loss_print = [["kg_loss", kg_loss.data.cpu().item()]] 131 | if self.hyp.use_vae and meta_training: 132 | loss_print += vae_loss_print 133 | return loss_print 134 | 135 | def __predict_triplet(self, dscp1, dscp2, idx_cands): 136 | if self.memory: 137 | dscp1_emb = self.encoder(dscp1, self.qur_head_memory_dscpr_emb) 138 | dscp2_emb = self.encoder(dscp2, self.qur_tail_memory_dscpr_emb) 139 | else: 140 | dscp1_emb = self.encoder(dscp1) 141 | dscp2_emb = self.encoder(dscp2) 142 | cands_emb = T.cat([dscp2_emb, self.qur_cands_emb[idx_cands]], 0) 143 | 144 | sim_cand = self.sim_func(dscp1_emb, cands_emb, self.qur_dscpr_emb) 145 | return sim_cand 146 | 147 | def predict(self, dataset): 148 | weights_original = self.copy_param() 149 | metrics = Metrics() 150 | 151 | while dataset.next_rel(): 152 | metrics.reset_task_metrics(dataset.cur_rel) 153 | 154 | for idx_iter in range(self.hyp.num_test_inner_iter): 155 | for idx_data in range(self.hyp.num_shot): 156 | dscp1, dscp2, dscpn, dscpr = dataset.get_sup_inner_data() 157 | for idx_iter in range(self.hyp.num_test_inner_iter): 158 | self.__inner_train_step(dscp1, dscp2, dscpn, dscpr, False) 159 | 160 | batch_dscp_cand, dscpr = dataset.get_qur_task_cand_rel() 161 | cands_emb = [] 162 | if self.memory: 163 | self.qur_dscpr_emb, self.qur_head_memory_dscpr_emb, self.qur_tail_memory_dscpr_emb = self.encoder.encode_rel(dscpr) 164 | cands_emb = [self.encoder(bucket, self.qur_tail_memory_dscpr_emb) for bucket in batch_dscp_cand] 165 | else: 166 | self.qur_dscpr_emb = self.encoder(dscpr) 167 | cands_emb = [self.encoder(bucket) for bucket in batch_dscp_cand] 168 | self.qur_cands_emb = T.cat(cands_emb, 0) 169 | 170 | while 1: 171 | qur_task = dataset.get_qur_inner_data() 172 | if qur_task is None: 173 | break 174 | sim_cands = self.__predict_triplet(*qur_task) 175 | sim_cands = sim_cands.detach().data # sim_cands[0]: y_true sim, sim_cands[1:]: y_cands sim 176 | sim_cands = sim_cands.cpu().numpy() 177 | metrics.add(sim_cands) 178 | 179 | self.set_param(weights_original) 180 | if self.memory: 181 | self.qur_dscpr_emb = self.qur_cands_emb = self.qur_head_memory_dscpr_emb = self.qur_tail_memory_dscpr_emb = None 182 | else: 183 | self.qur_dscpr_emb = self.qur_cands_emb = None 184 | metrics.log_task_metric() 185 | metrics.log_overall_metric() 186 | 187 | # saving original parameters by using deepcopy(self.encoder.state_dict()) and load_state_dict() leads to GPU memory leak 188 | def copy_param(self): 189 | new_params = {} 190 | for name, module in self.modules.items(): 191 | new_params[name] = module.copy_param() 192 | 193 | return new_params 194 | 195 | def set_param(self, new_params): 196 | for name, module_params in new_params.items(): 197 | self.modules[name].set_param(module_params) 198 | 199 | 200 | -------------------------------------------------------------------------------- /prepare_dbpedia.py: -------------------------------------------------------------------------------- 1 | import string 2 | import random 3 | import json 4 | from sys import exit 5 | import os 6 | import re 7 | from collections import defaultdict, Counter 8 | import pickle 9 | import torch as T 10 | from hyp_reptile import * 11 | 12 | def build_dscp(hyp): 13 | ents = set() 14 | rels = set() 15 | #ent_names = {} 16 | rel_names = {} 17 | ent_dscps = {} 18 | w_cnt = Counter() 19 | 20 | rm_aster = lambda s : re.sub(r'\*', '', s) # remove * 21 | tokenize = lambda s : re.sub('[{}]'.format(re.escape(string.punctuation)), r' ', s) # replace 
punctuations with space 22 | replace_num = lambda s, num_token : re.sub(r'\d+', num_token, s) # replace continuous numbers with '$' 23 | 24 | ''' 25 | with open(hyp.dataset_root + "entity_names.txt", "r", encoding = "utf-8") as f_ent: 26 | for ii, l in enumerate(f_ent): 27 | l = l.strip().encode("ascii", "ignore").decode("ascii").lower().split("\t") 28 | if len(l) != 3: 29 | continue 30 | token, name = l[0], l[2] 31 | name = replace_num(tokenize(rm_aster(name)), hyp.NUM) if hyp.rm_num_in_name else tokenize(rm_aster(name)) 32 | name = name.split() 33 | for e in name: 34 | w_cnt[e] += 1 35 | ent_names[token] = name 36 | ents.add(token) 37 | ''' 38 | 39 | with open(hyp.dataset_root + "relation_names.txt", "r", encoding = "utf-8") as f_ent: 40 | for l in f_ent: 41 | l = l.strip().encode("ascii", "ignore").decode("ascii").lower().split("\t") 42 | if len(l) != 3: 43 | continue 44 | token, name = l[0], l[2] 45 | name = replace_num(tokenize(rm_aster(name)), hyp.NUM) if hyp.rm_num_in_name else tokenize(rm_aster(name)) 46 | name = name.split() 47 | for e in name: 48 | w_cnt[e] += 1 49 | rel_names[token] = name 50 | rels.add(token) 51 | 52 | with open(hyp.dataset_root + "descriptions.txt", "r", encoding = "utf-8") as f_dscp: 53 | for ii, l in enumerate(f_dscp): 54 | l = l.strip().encode("ascii", "ignore").decode("ascii").lower().split("\t") 55 | if len(l) != 3: 56 | continue 57 | token, dscp = l[0], l[2] 58 | dscp = replace_num(tokenize(rm_aster(dscp)), hyp.NUM) 59 | dscp = dscp.split() 60 | dscp = dscp[:hyp.max_len_dscp] 61 | for e in dscp: 62 | w_cnt[e] += 1 63 | ent_dscps[token] = dscp 64 | ents.add(token) 65 | 66 | pkl_path = hyp.tmp_root + "dscp.pkl.tmp" 67 | with open(pkl_path, 'wb') as f_dump: 68 | pickle.dump((ents, rels, rel_names, ent_dscps, w_cnt), f_dump, protocol = pickle.HIGHEST_PROTOCOL) 69 | return ents, rels, rel_names, ent_dscps, w_cnt 70 | 71 | def load_dscp(hyp): 72 | pkl_path = hyp.tmp_root + "dscp.pkl.tmp" 73 | with open(pkl_path, 'rb') as f_dump: 74 | ents, rels, rel_names, ent_dscps, w_cnt = pickle.load(f_dump) 75 | return ents, rels, rel_names, ent_dscps, w_cnt 76 | 77 | def build_dataset(dscp_ents, dscp_rels, hyp): 78 | def read_dataset(src_path, all_triplets, dscp_ents, dscp_rels): 79 | with open(src_path, "r", encoding = "utf-8") as f_dataset: 80 | for l in f_dataset: 81 | l = l.strip().encode("ascii", "ignore").decode("ascii").lower().split() 82 | if len(l) != 3: 83 | #print(l) 84 | continue 85 | 86 | ent1, ent2, rel = l 87 | if ent1 not in dscp_ents: 88 | #print("ent1 missing: {}".format(ent1)) 89 | continue 90 | if ent2 not in dscp_ents: 91 | #print("ent2 missing: {}".format(ent2)) 92 | continue 93 | if rel not in dscp_rels: 94 | #print("rel missing: {}".format(rel)) 95 | continue 96 | all_triplets[rel].append((ent1, ent2)) 97 | 98 | all_triplets = defaultdict(list) 99 | print(len(all_triplets)) 100 | read_dataset(hyp.dataset_root + "train.txt", all_triplets, dscp_ents, dscp_rels) 101 | print(len(all_triplets)) 102 | read_dataset(hyp.dataset_root + "valid.txt", all_triplets, dscp_ents, dscp_rels) 103 | print(len(all_triplets)) 104 | read_dataset(hyp.dataset_root + "test.txt", all_triplets, dscp_ents, dscp_rels) 105 | print(len(all_triplets)) 106 | 107 | valid_all_triplets = {} 108 | for rel, duos in all_triplets.items(): 109 | if len(duos) >= hyp.min_task_size and len(duos) <= hyp.max_task_size: 110 | valid_all_triplets[rel] = duos 111 | print("after filtering small and large tasks, #tasks =" ,len(valid_all_triplets)) 112 | 113 | if hyp.is_debugging: 114 | train_size = 
dev_size = test_size = len(valid_all_triplets) // 3 115 | else: 116 | train_size, dev_size, test_size = hyp.tasks_split_size 117 | train_dataset = defaultdict(list) 118 | dev_dataset = defaultdict(list) 119 | test_dataset = defaultdict(list) 120 | 121 | task_ents = set() 122 | task_rels = set() 123 | for idx_rel, (rel, duos) in enumerate(valid_all_triplets.items()): 124 | if idx_rel < train_size: 125 | dst_dataset = train_dataset 126 | elif idx_rel >= train_size and idx_rel < train_size + dev_size: 127 | dst_dataset = dev_dataset 128 | else: 129 | dst_dataset = test_dataset 130 | 131 | if hyp.is_debugging: 132 | duos = duos[:10] 133 | dst_dataset[rel] = duos 134 | for ent1, ent2 in duos: 135 | task_ents.add(ent1) 136 | task_ents.add(ent2) 137 | task_rels.add(rel) 138 | 139 | print("after building few shot dataset, #train = {}, #dev = {}, #test = {}".format(len(train_dataset), len(dev_dataset), len(test_dataset))) 140 | 141 | pkl_path = hyp.tmp_root + "dataset.pkl.tmp.debug" if hyp.is_debugging else hyp.tmp_root + "dataset.pkl.tmp" 142 | with open(pkl_path, 'wb') as f_dump: 143 | pickle.dump((train_dataset, dev_dataset, test_dataset, task_ents, task_rels), f_dump, protocol = pickle.HIGHEST_PROTOCOL) 144 | 145 | return train_dataset, dev_dataset, test_dataset, task_ents, task_rels 146 | 147 | def load_dataset(hyp): 148 | pkl_name = "dataset.pkl.tmp" 149 | if hyp.is_debugging: 150 | pkl_name += ".debug" 151 | pkl_path = hyp.tmp_root + pkl_name 152 | 153 | with open(pkl_path, 'rb') as f_dump: 154 | train_dataset, dev_dataset, test_dataset, task_ents, task_rels = pickle.load(f_dump) 155 | return train_dataset, dev_dataset, test_dataset, task_ents, task_rels 156 | 157 | def build_map(ents, rels, w_cnt, hyp): 158 | w2i = {} 159 | i2w = {} 160 | for w, cnt in w_cnt.items(): 161 | if cnt < hyp.min_word_cnt: 162 | continue 163 | i = len(w2i) 164 | w2i[w] = i 165 | i2w[i] = w 166 | if hyp.UNK not in w2i: 167 | i = len(w2i) 168 | w2i[hyp.UNK] = i 169 | i2w[i] = hyp.UNK 170 | if hyp.NUM not in w2i: 171 | i = len(w2i) 172 | w2i[hyp.NUM] = i 173 | i2w[i] = hyp.NUM 174 | if hyp.WORD_PAD not in w2i: 175 | i = len(w2i) 176 | w2i[hyp.WORD_PAD] = i 177 | i2w[i] = hyp.WORD_PAD 178 | print("min_word_cnt = {}, after filtering, #word = {}".format(hyp.min_word_cnt, len(i2w))) 179 | 180 | e2i = {} 181 | i2e = {} 182 | r2i = {} 183 | i2r = {} 184 | for rel in rels: 185 | i = len(r2i) 186 | r2i[rel] = i 187 | i2r[i] = rel 188 | for ent in ents: 189 | i = len(e2i) 190 | e2i[ent] = i 191 | i2e[i] = ent 192 | 193 | print("#ent = {}, #rel = {}".format(len(ents), len(rels))) 194 | pkl_path = hyp.tmp_root + "map.pkl.tmp.debug" if hyp.is_debugging else hyp.tmp_root + "map.pkl.tmp" 195 | with open(pkl_path, 'wb') as f_dump: 196 | pickle.dump((e2i, i2e, r2i, i2r, w2i, i2w), f_dump, protocol = pickle.HIGHEST_PROTOCOL) 197 | return e2i, i2e, r2i, i2r, w2i, i2w 198 | 199 | def load_map(hyp): 200 | pkl_path = hyp.tmp_root + "map.pkl.tmp.debug" if hyp.is_debugging else hyp.tmp_root + "map.pkl.tmp" 201 | with open(pkl_path, 'rb') as f_dump: 202 | e2i, i2e, r2i, i2r, w2i, i2w = pickle.load(f_dump) 203 | return e2i, i2e, r2i, i2r, w2i, i2w 204 | 205 | def token2idx(train_dataset, dev_dataset, test_dataset, rel_names, ent_dscps, e2i, r2i, w2i): 206 | new_train_dataset = defaultdict(list) 207 | new_dev_dataset = defaultdict(list) 208 | new_test_dataset = defaultdict(list) 209 | new_rel_names = {} 210 | new_ent_dscps = {} 211 | 212 | for src, dst in zip([train_dataset, dev_dataset, test_dataset], [new_train_dataset, new_dev_dataset, 
new_test_dataset]): 213 | for rel, duos in src.items(): 214 | for ent1, ent2 in duos: 215 | dst[r2i[rel]].append((e2i[ent1], e2i[ent2])) 216 | 217 | for rel, name in rel_names.items(): 218 | if rel in r2i: 219 | idx_rel = r2i[rel] 220 | new_rel_names[idx_rel] = T.tensor([w2i[w] if w in w2i else w2i[hyp.UNK] for w in name], dtype = T.long) 221 | for ent, dscp in ent_dscps.items(): 222 | if ent in e2i: 223 | idx_ent = e2i[ent] 224 | new_ent_dscps[idx_ent] = T.tensor([w2i[w] if w in w2i else w2i[hyp.UNK] for w in dscp], dtype = T.long) 225 | 226 | print("finish transforming symbols to indices") 227 | pkl_path = "idx.pkl.tmp.debug" if hyp.is_debugging else "idx.pkl.tmp" 228 | with open(hyp.tmp_root + pkl_path, 'wb') as f_dump: 229 | pickle.dump((new_train_dataset, new_dev_dataset, new_test_dataset, new_rel_names, new_ent_dscps), f_dump, protocol = pickle.HIGHEST_PROTOCOL) 230 | return new_train_dataset, new_dev_dataset, new_test_dataset, new_rel_names, new_ent_dscps 231 | 232 | def load_idx(hyp): 233 | pkl_path = "idx.pkl.tmp.debug" if hyp.is_debugging else "idx.pkl.tmp" 234 | with open(hyp.tmp_root + pkl_path, 'rb') as f_dump: 235 | train_dataset, dev_dataset, test_dataset, rel_names, ent_dscps = pickle.load(f_dump) 236 | return train_dataset, dev_dataset, test_dataset, rel_names, ent_dscps 237 | 238 | def build_aux(train_dataset, dev_dataset, test_dataset, ent_dscps, e2i, i2r, hyp): 239 | e1rel_e2 = defaultdict(set) 240 | for src in [train_dataset, dev_dataset]: 241 | for idx_rel, duos in src.items(): 242 | for idx_ent1, idx_ent2 in duos: 243 | e1rel_e2[(idx_ent1, idx_rel)].add(idx_ent2) 244 | 245 | rel_cand = defaultdict(set) 246 | for src in [train_dataset, dev_dataset, test_dataset]: 247 | for idx_rel, duos in src.items(): 248 | for idx_ent1, idx_ent2 in duos: 249 | if idx_ent2 not in rel_cand[idx_rel]: 250 | rel_cand[idx_rel].add(idx_ent2) 251 | if len(rel_cand[idx_rel]) >= hyp.max_num_cand: 252 | break 253 | 254 | print("remaining candidates / origin candidates in dataset:") 255 | for idx_rel, cands in rel_cand.items(): 256 | purposed_num_cand = len(cands) if len(cands) > hyp.min_num_cand else hyp.min_num_cand 257 | if len(cands) < purposed_num_cand: 258 | num_left = purposed_num_cand - len(cands) 259 | while num_left > 0: 260 | while 1: 261 | idx_ent = random.randint(0, len(e2i) - 1) 262 | if idx_ent not in cands: 263 | break 264 | cands.add(idx_ent) 265 | num_left -= 1 266 | cands = list(cands) 267 | cands.sort(key = lambda idx_ent, ent_dscps = ent_dscps : len(ent_dscps[idx_ent])) 268 | #print("{}: {}".format(i2r[idx_rel], len(cands))) 269 | rel_cand[idx_rel] = cands 270 | 271 | print("finish building e1rel_e2, rel_cand") 272 | pkl_path = "aux.pkl.tmp.debug" if hyp.is_debugging else "aux.pkl.tmp" 273 | with open(hyp.tmp_root + pkl_path, 'wb') as f_dump: 274 | pickle.dump((e1rel_e2, rel_cand), f_dump, protocol = pickle.HIGHEST_PROTOCOL) 275 | return e1rel_e2, rel_cand 276 | 277 | def load_aux(hyp): 278 | pkl_path = "aux.pkl.tmp.debug" if hyp.is_debugging else "aux.pkl.tmp" 279 | with open(hyp.tmp_root + pkl_path, 'rb') as f_dump: 280 | e1rel_e2, rel_cand = pickle.load(f_dump) 281 | return e1rel_e2, rel_cand 282 | 283 | def prepare_task(hyp): 284 | dscp_ents, dscp_rels, rel_names, ent_dscps, w_cnt = build_dscp(hyp) 285 | #dscp_ents, dscp_rels, rel_names, ent_dscps, w_cnt = load_dscp(hyp) 286 | print(1) 287 | 288 | raw_train_dataset, raw_dev_dataset, raw_test_dataset, task_ents, task_rels = build_dataset(dscp_ents, dscp_rels, hyp) 289 | #raw_train_dataset, raw_dev_dataset, 
raw_test_dataset, task_ents, task_rels = load_dataset(hyp) 290 | print(2) 291 | 292 | e2i, i2e, r2i, i2r, w2i, i2w = build_map(task_ents, task_rels, w_cnt, hyp) 293 | #e2i, i2e, r2i, i2r, w2i, i2w = load_map(hyp) 294 | print(3) 295 | 296 | train_dataset, dev_dataset, test_dataset, rel_names, ent_dscps = token2idx(raw_train_dataset, raw_dev_dataset, raw_test_dataset, rel_names, ent_dscps, e2i, r2i, w2i) 297 | #train_dataset, dev_dataset, test_dataset, rel_names, ent_dscps = load_idx(hyp) 298 | print(4) 299 | 300 | e1rel_e2, rel_cand = build_aux(train_dataset, dev_dataset, test_dataset, ent_dscps, e2i, i2r, hyp) 301 | #e1rel_e2, rel_cand = load_aux(hyp) 302 | print(5) 303 | 304 | all_data = [train_dataset, dev_dataset, test_dataset, ent_dscps, rel_names, i2r, w2i, i2w, rel_cand, e1rel_e2] 305 | pkl_path = hyp.raw_pkl_path + ".debug" if hyp.is_debugging else hyp.raw_pkl_path 306 | print("start dumping to {0}".format(pkl_path)) 307 | with open(pkl_path, 'wb') as f_dump: 308 | pickle.dump(all_data, f_dump, protocol = pickle.HIGHEST_PROTOCOL) 309 | print("finish") 310 | 311 | return all_data 312 | 313 | if __name__ == "__main__": 314 | hyp = init_hyp("dbpedia") 315 | prepare_task(hyp) 316 | -------------------------------------------------------------------------------- /prepare_wikidata.py: -------------------------------------------------------------------------------- 1 | import string 2 | import random 3 | import json 4 | from sys import exit 5 | import os 6 | import re 7 | from collections import defaultdict, Counter, namedtuple 8 | import pickle 9 | import torch as T 10 | from hyp_reptile import * 11 | 12 | def build_dscp(hyp): 13 | w_cnt = Counter() 14 | chars = set() 15 | ents = set() 16 | rels = set() 17 | ent_dscps = {} 18 | rel_dscps = {} 19 | ent_parents = {} 20 | rel_parents = {} 21 | parent_ents = set() 22 | wiki_dscp_path = hyp.proj_root + "data/wikidata/dscp_all.txt" 23 | 24 | rm_aster = lambda s : re.sub(r'\*', '', s) # remove * 25 | tokenize = lambda s : re.sub('([{}])'.format(re.escape(string.punctuation)), r' \1 ', s) # separate punctuations 26 | replace_num = lambda s, num_token : re.sub(r'\d+', num_token, s) # replace continuous numbers with '$' 27 | 28 | with open(wiki_dscp_path, "r", encoding = "utf-8") as f_src: 29 | for ii, l in enumerate(f_src): 30 | l = l.strip().split("####") 31 | if len(l) != 4: 32 | continue 33 | token, name, dscp, parents = l[0], l[1].lower(), l[2].lower(), l[3].split() 34 | if name == "none" or dscp == "none": 35 | continue 36 | name = replace_num(rm_aster(name), hyp.NUM) if hyp.rm_num_in_name else rm_aster(name) 37 | dscp = replace_num(tokenize(rm_aster(dscp)), hyp.NUM) 38 | words = dscp.split() 39 | for e in words: 40 | w_cnt[e] += 1 41 | for c in name: 42 | chars.add(c) 43 | if token[0] == "P": 44 | rels.add(token) 45 | rel_dscps[token] = (name, words) 46 | rel_parents[token] = parents 47 | elif token[0] == "Q": 48 | ents.add(token) 49 | ent_dscps[token] = (name, words) 50 | ent_parents[token] = parents 51 | else: 52 | raise RuntimeError(token) 53 | for e in parents: 54 | parent_ents.add(e) 55 | 56 | pkl_path = hyp.tmp_root + "dscp.pkl.tmp" 57 | with open(pkl_path, 'wb') as f_dump: 58 | pickle.dump((ents, rels, ent_dscps, rel_dscps, w_cnt, chars, ent_parents, rel_parents, parent_ents), f_dump, protocol = pickle.HIGHEST_PROTOCOL) 59 | return ents, rels, ent_dscps, rel_dscps, w_cnt, chars, ent_parents, rel_parents, parent_ents 60 | 61 | def load_dscp(hyp): 62 | pkl_path = hyp.tmp_root + "dscp.pkl.tmp" 63 | with open(pkl_path, 'rb') as f_dump: 64 
| dscp_ents, dscp_rels, ent_dscps, rel_dscps, w_cnt, chars, ent_parents, rel_parents, parent_ents = pickle.load(f_dump) 65 | return dscp_ents, dscp_rels, ent_dscps, rel_dscps, w_cnt, chars, ent_parents, rel_parents, parent_ents 66 | 67 | def build_task(dscp_ents, dscp_rels, hyp): 68 | src_root = hyp.json_dataset_root 69 | raw_train_task = json.load(open(src_root + "train_tasks.json", "r")) 70 | raw_dev_task = json.load(open(src_root + "dev_tasks.json", "r")) 71 | raw_test_task = json.load(open(src_root + "test_tasks.json", "r")) 72 | task_ents = set() 73 | task_rels = set() 74 | train_task = defaultdict(list) 75 | dev_task = defaultdict(list) 76 | test_task = defaultdict(list) 77 | 78 | dev_size1 = 0 79 | dev_size2 = 0 80 | for rel, tups in raw_dev_task.items(): 81 | dev_size1 += len(tups) 82 | if rel not in dscp_rels: 83 | print("dev task {} has no dscp".format(rel)) 84 | continue 85 | 86 | valid_tups = [] 87 | for ii, tup in enumerate(tups): 88 | if hyp.is_debugging and ii >= 10: 89 | break 90 | if tup[0] in dscp_ents and tup[2] in dscp_ents: 91 | valid_tups.append((tup[0], tup[2])) 92 | if len(valid_tups) < hyp.min_task_size: 93 | print("dev task {} is too small: size = {}".format(rel, len(valid_tups))) 94 | continue 95 | 96 | dev_task[rel] += valid_tups 97 | dev_size2 += len(valid_tups) 98 | task_rels.add(rel) 99 | for e1, e2 in valid_tups: 100 | task_ents.add(e1) 101 | task_ents.add(e2) 102 | print("origin dev task: #task = {}, #tuple = {}, after filtering entities and relations without description, #task = {}, #tuple = {}".format(len(raw_dev_task), dev_size1, len(dev_task), dev_size2)) 103 | 104 | test_size1 = 0 105 | test_size2 = 0 106 | for rel, tups in raw_test_task.items(): 107 | test_size1 += len(tups) 108 | if rel not in dscp_rels: 109 | print("test task {} has no dscp".format(rel)) 110 | continue 111 | 112 | valid_tups = [] 113 | for ii, tup in enumerate(tups): 114 | if hyp.is_debugging and ii >= 10: 115 | break 116 | if tup[0] in dscp_ents and tup[2] in dscp_ents: 117 | valid_tups.append((tup[0], tup[2])) 118 | if len(valid_tups) < hyp.min_task_size: 119 | print("test task {} is too small: size = {}".format(rel, len(valid_tups))) 120 | continue 121 | 122 | test_task[rel] += valid_tups 123 | test_size2 += len(valid_tups) 124 | task_rels.add(rel) 125 | for e1, e2 in valid_tups: 126 | task_ents.add(e1) 127 | task_ents.add(e2) 128 | print("origin test task: #task = {}, #tuple = {}, after filtering entities and relations without description, #task = {}, #tuple = {}".format(len(raw_test_task), test_size1, len(test_task), test_size2)) 129 | 130 | train_size1 = 0 131 | train_size2 = 0 132 | train_size3 = 0 133 | for rel, tups in raw_train_task.items(): 134 | train_size1 += len(tups) 135 | if rel not in dscp_rels: 136 | print("train task {} has no dscp".format(rel)) 137 | continue 138 | 139 | possible_tups = [] 140 | for ii, tup in enumerate(tups): 141 | if hyp.is_debugging and ii >= 10: 142 | break 143 | if tup[0] in dscp_ents and tup[2] in dscp_ents: 144 | possible_tups.append((tup[0], tup[2])) 145 | train_size2 += len(possible_tups) 146 | valid_tups = [] 147 | for e1, e2 in possible_tups: 148 | valid_tups.append((e1, e2)) 149 | if len(valid_tups) < hyp.min_task_size: 150 | print("train task {} is too small: size = {}".format(rel, len(valid_tups))) 151 | continue 152 | 153 | train_task[rel] += valid_tups 154 | train_size3 += len(valid_tups) 155 | task_rels.add(rel) 156 | for e1, e2 in valid_tups: 157 | task_ents.add(e1) 158 | task_ents.add(e2) 159 | print("origin training task: 
#task = {}, #tuple = {}, after filtering by e2i, r2i, #task = {}, #tuple = {}, after filtering dev_test entities, #tuple = {}".format(len(raw_train_task), train_size1, len(train_task), train_size2, train_size3)) 160 | 161 | pkl_path = hyp.tmp_root 162 | pkl_path += "task.pkl.tmp.debug" if hyp.is_debugging else "task.pkl.tmp" 163 | with open(pkl_path, 'wb') as f_dump: 164 | pickle.dump((train_task, dev_task, test_task, task_ents, task_rels), f_dump, protocol = pickle.HIGHEST_PROTOCOL) 165 | return train_task, dev_task, test_task, task_ents, task_rels 166 | 167 | def load_task(hyp): 168 | pkl_path = hyp.tmp_root 169 | pkl_path += "task.pkl.tmp.debug" if hyp.is_debugging else "task.pkl.tmp" 170 | with open(pkl_path, 'rb') as f_dump: 171 | train_task, dev_task, test_task, task_ents, task_rels = pickle.load(f_dump) 172 | return train_task, dev_task, test_task, task_ents, task_rels 173 | 174 | def build_map(ents, rels, w_cnt, hyp): 175 | w2i = {} 176 | i2w = {} 177 | for w, cnt in w_cnt.items(): 178 | if cnt < hyp.min_word_cnt: 179 | continue 180 | i = len(w2i) 181 | w2i[w] = i 182 | i2w[i] = w 183 | if hyp.UNK not in w2i: 184 | i = len(w2i) 185 | w2i[hyp.UNK] = i 186 | i2w[i] = hyp.UNK 187 | if hyp.NUM not in w2i: 188 | i = len(w2i) 189 | w2i[hyp.NUM] = i 190 | i2w[i] = hyp.NUM 191 | if hyp.WORD_PAD not in w2i: 192 | i = len(w2i) 193 | w2i[hyp.WORD_PAD] = i 194 | i2w[i] = hyp.WORD_PAD 195 | print("min_word_cnt = {}, after filtering, #word = {}".format(hyp.min_word_cnt, len(i2w))) 196 | 197 | e2i = {} 198 | i2e = {} 199 | r2i = {} 200 | i2r = {} 201 | for rel in rels: 202 | i = len(r2i) 203 | r2i[rel] = i 204 | i2r[i] = rel 205 | for ent in ents: 206 | i = len(e2i) 207 | e2i[ent] = i 208 | i2e[i] = ent 209 | 210 | print("#ent = {}, #rel = {}".format(len(ents), len(rels))) 211 | pkl_path = hyp.tmp_root 212 | pkl_path += "map.pkl.tmp.debug" if hyp.is_debugging else "map.pkl.tmp" 213 | with open(pkl_path, 'wb') as f_dump: 214 | pickle.dump((e2i, i2e, r2i, i2r, w2i, i2w), f_dump, protocol = pickle.HIGHEST_PROTOCOL) 215 | return e2i, i2e, r2i, i2r, w2i, i2w 216 | 217 | def load_map(hyp): 218 | pkl_path = hyp.tmp_root 219 | pkl_path += "map.pkl.tmp.debug" if hyp.is_debugging else "map.pkl.tmp" 220 | with open(pkl_path, 'rb') as f_dump: 221 | e2i, i2e, r2i, i2r, w2i, i2w = pickle.load(f_dump) 222 | return e2i, i2e, r2i, i2r, w2i, i2w 223 | 224 | def token2idx(train_task, dev_task, test_task, ent_dscps, rel_dscps, e2i, r2i, w2i): 225 | new_train_task = defaultdict(list) 226 | new_dev_task = defaultdict(list) 227 | new_test_task = defaultdict(list) 228 | new_ent_dscps = {} 229 | new_rel_dscps = {} 230 | 231 | for src, dst in zip([train_task, dev_task, test_task], [new_train_task, new_dev_task, new_test_task]): 232 | for rel, duos in src.items(): 233 | for duo in duos: 234 | dst[r2i[rel]].append((e2i[duo[0]], e2i[duo[1]])) 235 | 236 | for ent, (name, dscp) in ent_dscps.items(): 237 | if ent in e2i: 238 | idx_ent = e2i[ent] 239 | new_ent_dscps[idx_ent] = T.tensor([w2i[w] if w in w2i else w2i[hyp.UNK] for w in dscp], dtype = T.long) 240 | for rel, (name, dscp) in rel_dscps.items(): 241 | if rel in r2i: 242 | idx_rel = r2i[rel] 243 | new_rel_dscps[idx_rel] = T.tensor([w2i[w] if w in w2i else w2i[hyp.UNK] for w in dscp], dtype = T.long) 244 | 245 | print("finish transforming symbols to indices") 246 | pkl_path = hyp.tmp_root 247 | pkl_path += "idx.pkl.tmp.debug" if hyp.is_debugging else "idx.pkl.tmp" 248 | with open(pkl_path, 'wb') as f_dump: 249 | pickle.dump((new_train_task, new_dev_task, new_test_task, 
new_ent_dscps, new_rel_dscps), f_dump, protocol = pickle.HIGHEST_PROTOCOL) 250 | return new_train_task, new_dev_task, new_test_task, new_ent_dscps, new_rel_dscps 251 | 252 | def load_idx(hyp): 253 | pkl_path = hyp.tmp_root 254 | pkl_path += "idx.pkl.tmp.debug" if hyp.is_debugging else "idx.pkl.tmp" 255 | with open(pkl_path, 'rb') as f_dump: 256 | train_task, dev_task, test_task, ent_dscps, rel_dscps = pickle.load(f_dump) 257 | return train_task, dev_task, test_task, ent_dscps, rel_dscps 258 | 259 | def build_aux(train_task, dev_task, test_task, e2i, i2r, ent_dscps, hyp): 260 | raw_rel_cand = json.load(open(hyp.json_dataset_root + "rel2candidates.json", "r")) 261 | new_rel_cand = {} 262 | valid_rels = list(train_task) + list(dev_task) + list(test_task) 263 | 264 | e1rel_e2 = defaultdict(set) 265 | for tasks in [train_task, dev_task]: 266 | for idx_rel, duos in tasks.items(): 267 | for idx_e1, idx_e2 in duos: 268 | e1rel_e2[(idx_e1, idx_rel)].add(idx_e2) 269 | 270 | print("remaining candidates / origin candidates in dataset:") 271 | for idx_rel in valid_rels: 272 | ents = raw_rel_cand[i2r[idx_rel]] 273 | cands = [] 274 | for ent in ents: 275 | if ent in e2i: 276 | idx_ent = e2i[ent] 277 | if idx_ent in ent_dscps: 278 | cands.append(idx_ent) 279 | 280 | purposed_num_cand = len(ents) if len(ents) > hyp.min_num_cand else hyp.min_num_cand 281 | if len(cands) < purposed_num_cand: 282 | existing_cand = set(cands) 283 | num_left = purposed_num_cand - len(cands) 284 | while num_left > 0: 285 | while 1: 286 | idx_ent = random.randint(0, len(e2i) - 1) 287 | if idx_ent not in ent_dscps: 288 | continue 289 | if idx_ent not in existing_cand: 290 | break 291 | cands.append(idx_ent) 292 | num_left -= 1 293 | cands.sort(key = lambda idx_ent, ent_dscps = ent_dscps : len(ent_dscps[idx_ent])) 294 | 295 | print("{}: {}/{}".format(i2r[idx_rel], len(cands), len(ents))) 296 | new_rel_cand[idx_rel] = cands 297 | 298 | print("finish building rel_cand and e1rel_e2") 299 | pkl_path = hyp.tmp_root 300 | pkl_path += "aux.pkl.tmp.debug" if hyp.is_debugging else "aux.pkl.tmp" 301 | with open(pkl_path, 'wb') as f_dump: 302 | pickle.dump((new_rel_cand, e1rel_e2), f_dump, protocol = pickle.HIGHEST_PROTOCOL) 303 | return new_rel_cand, e1rel_e2 304 | 305 | def load_aux(hyp): 306 | pkl_path = hyp.tmp_root 307 | pkl_path += "aux.pkl.tmp.debug" if hyp.is_debugging else "aux.pkl.tmp" 308 | with open(pkl_path, 'rb') as f_dump: 309 | rel_cand, e1rel_e2 = pickle.load(f_dump) 310 | return rel_cand, e1rel_e2 311 | 312 | def prepare_task(hyp): 313 | print("start preparing wikidata pickles") 314 | #dscp_ents, dscp_rels, ent_dscps, rel_dscps, w_cnt, chars, ent_parents, rel_parents, parent_ents = build_dscp(hyp) 315 | dscp_ents, dscp_rels, ent_dscps, rel_dscps, w_cnt, chars, ent_parents, rel_parents, parent_ents = load_dscp(hyp) 316 | print(1) 317 | 318 | #train_task, dev_task, test_task, task_ents, task_rels = build_task(dscp_ents, dscp_rels, hyp) 319 | train_task, dev_task, test_task, task_ents, task_rels = load_task(hyp) 320 | ''' 321 | print("train") 322 | for k in train_task: 323 | print(k, len(train_task[k])) 324 | print("dev") 325 | for k in dev_task: 326 | print(k, len(dev_task[k])) 327 | print("test") 328 | for k in test_task: 329 | print(k, len(test_task[k])) 330 | exit(0) 331 | ''' 332 | print(2) 333 | 334 | #e2i, i2e, r2i, i2r, w2i, i2w = build_map(task_ents, task_rels, w_cnt, hyp) 335 | e2i, i2e, r2i, i2r, w2i, i2w = load_map(hyp) 336 | print(3) 337 | 338 | #train_task, dev_task, test_task, ent_dscps, rel_dscps = 
token2idx(train_task, dev_task, test_task, ent_dscps, rel_dscps, e2i, r2i, w2i) 339 | train_task, dev_task, test_task, ent_dscps, rel_dscps = load_idx(hyp) 340 | 341 | print(len(task_ents)) 342 | print(len(e2i)) 343 | print(len(ent_dscps)) 344 | #rel_cand, e1rel_e2 = build_aux(train_task, dev_task, test_task, e2i, i2r, ent_dscps, hyp) 345 | rel_cand, e1rel_e2 = load_aux(hyp) 346 | 347 | all_data = [train_task, dev_task, test_task, ent_dscps, rel_dscps, i2r, w2i, i2w, rel_cand, e1rel_e2] 348 | pkl_path = hyp.raw_pkl_path + ".debug" if hyp.is_debugging else hyp.raw_pkl_path 349 | print("start dumping to {0}".format(pkl_path)) 350 | with open(pkl_path, 'wb') as f_dump: 351 | pickle.dump(all_data, f_dump, protocol = pickle.HIGHEST_PROTOCOL) 352 | print("finish") 353 | 354 | return all_data 355 | 356 | if __name__ == "__main__": 357 | hyp = init_hyp("wikidata") 358 | prepare_task(hyp) 359 | --------------------------------------------------------------------------------
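
Note on the output of the two prepare scripts: both prepare_dbpedia.py and prepare_wikidata.py finish by dumping a 10-element all_data list to hyp.raw_pkl_path (train split, dev split, test split, entity descriptions, relation names/descriptions, i2r, w2i, i2w, rel_cand, e1rel_e2). The sketch below is a minimal, hypothetical helper (not part of the repository) for sanity-checking that pickle after preprocessing. It assumes the element order shown in the prepare_task() functions above; the file name inspect_all_data.py and the example path are illustrative only, and PyTorch must be importable because the description entries are stored as torch.LongTensor objects.

# inspect_all_data.py -- a minimal sketch (not part of the repository) for checking
# the pickle written by prepare_task(). Pass the hyp.raw_pkl_path configured for
# your dataset; the path in the usage example below is an assumption.
import pickle
import sys

def inspect(pkl_path):
    # prepare_task() dumps a 10-element list in this order; element 4 holds relation
    # names for DBpedia and relation descriptions for Wikidata.
    with open(pkl_path, "rb") as f:
        (train_tasks, dev_tasks, test_tasks, ent_dscps, rel_texts,
         i2r, w2i, i2w, rel_cand, e1rel_e2) = pickle.load(f)

    # each split maps a relation index to a list of (head index, tail index) pairs
    for split_name, split in [("train", train_tasks), ("dev", dev_tasks), ("test", test_tasks)]:
        n_pairs = sum(len(duos) for duos in split.values())
        print("{}: {} tasks, {} (head, tail) pairs".format(split_name, len(split), n_pairs))

    print("vocabulary size: {}".format(len(w2i)))
    print("entities with descriptions: {}".format(len(ent_dscps)))
    avg_dscp_len = sum(len(d) for d in ent_dscps.values()) / max(len(ent_dscps), 1)
    print("average description length: {:.1f} tokens".format(avg_dscp_len))

    # rel_cand maps a relation index to its candidate tail entities,
    # sorted by description length in build_aux()
    cand_sizes = [len(c) for c in rel_cand.values()]
    print("candidate lists: min = {}, max = {}".format(min(cand_sizes), max(cand_sizes)))
    print("(head, relation) keys in e1rel_e2: {}".format(len(e1rel_e2)))

if __name__ == "__main__":
    # example (path is an assumption): python inspect_all_data.py ./data/wikidata/all_data.pkl
    inspect(sys.argv[1])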