├── README.md └── src ├── args.py ├── components.py ├── dataset.py └── model.py /README.md: -------------------------------------------------------------------------------- 1 | ## Pytorch-HeGAN 2 | 3 | Pytorch implement of HeGAN(Adversarial Learning on Heterogeneous Information Networks) -------------------------------------------------------------------------------- /src/args.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Filename : args.py 3 | @Create Time : 6/18/2020 8:27 PM 4 | @Author : Rylynn 5 | @Description : 6 | 7 | """ 8 | import argparse 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--name', type=str, help='Model name') 12 | parser.add_argument('--batch_size', type=int, default=128, help='Batch size') 13 | parser.add_argument('--epoch', type=int, default=20, help='Training epoch') 14 | parser.add_argument('--epoch_g', type=int, default=10, help='Generator epoch in each training epoch') 15 | parser.add_argument('--epoch_d', type=int, default=10, help='Discriminator epoch in each training epoch') 16 | parser.add_argument('--sample_num', type=int, default=20, help='Number of sampling in one node') 17 | parser.add_argument('--gen_lr', type=float, default=0.001, help='Learning rate of Generator') 18 | parser.add_argument('--dis_lr', type=float, default=0.001, help='Learning rate of Discriminator') 19 | parser.add_argument('--node_embed_size', type=int, default=64, help='Node Embedding size') 20 | parser.add_argument('--node_size', type=int, default=37791, help='Node Embedding size') 21 | parser.add_argument('--relation_size', type=int, default=6, help='Relation size') 22 | args = parser.parse_args() -------------------------------------------------------------------------------- /src/components.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Filename : components.py 3 | @Create Time : 6/18/2020 3:43 PM 4 | @Author : Rylynn 5 | @Description : 6 | 7 | """ 8 | from collections import OrderedDict 9 | 10 | import torch 11 | import torch.nn as nn 12 | 13 | 14 | class Discriminator(nn.Module): 15 | def __init__(self, args): 16 | super(Discriminator, self).__init__() 17 | self.args = args 18 | 19 | self.node_embed = nn.Embedding.from_pretrained(torch.from_numpy(args.pretrain_embed)).float() 20 | self.relation_embed = nn.Embedding(args.relation_size, args.node_embed_size * args.node_embed_size) 21 | 22 | nn.init.xavier_uniform(self.relation_embed.weight) 23 | 24 | self.sigmoid = nn.Sigmoid() 25 | 26 | def forward_fake(self, node_idx, relation_idx, fake_node_embed): 27 | node_embed = self.node_embed(node_idx) 28 | node_embed = node_embed.reshape((-1, 1, self.args.node_embed_size)) 29 | relation_embed = self.relation_embed(relation_idx) 30 | relation_embed = relation_embed.reshape((-1, self.args.node_embed_size, self.args.node_embed_size)) 31 | temp = torch.matmul(node_embed, relation_embed) 32 | 33 | score = torch.sum(torch.mul(temp, fake_node_embed), 0) 34 | prob = self.sigmoid(score) 35 | return prob 36 | 37 | def forward(self, node_idx, relation_idx, node_neighbor_idx): 38 | node_embed = self.node_embed(node_idx) 39 | node_embed = node_embed.reshape((-1, 1, self.args.node_embed_size)) 40 | relation_embed = self.relation_embed(relation_idx) 41 | relation_embed = relation_embed.reshape((-1, self.args.node_embed_size, self.args.node_embed_size)) 42 | temp = torch.matmul(node_embed, relation_embed) 43 | 44 | score = torch.sum(torch.mul(temp, self.node_embed(node_neighbor_idx)), 0) 45 | 46 | prob = self.sigmoid(score) 47 | return prob 48 | 49 | def multify(self, node_idx, relation_idx): 50 | """ 51 | get e_u^D * M_r^b 52 | :param node_idx: 53 | :param relation_idx: 54 | :return: 55 | """ 56 | node_embed = self.node_embed(node_idx) 57 | relation_embed = self.relation_embed(relation_idx) 58 | relation_embed = relation_embed.reshape((-1, self.args.node_embed_size, self.args.node_embed_size)) 59 | temp = torch.matmul(node_embed, relation_embed) 60 | return temp 61 | 62 | 63 | class Generator(nn.Module): 64 | def __init__(self, args): 65 | super(Generator, self).__init__() 66 | self.args = args 67 | 68 | node_embed_size = args.node_embed_size 69 | 70 | self.node_embed = nn.Embedding.from_pretrained(torch.from_numpy(args.pretrain_embed)).float() 71 | self.relation_embed = nn.Embedding(args.relation_size, args.node_embed_size * args.node_embed_size) 72 | 73 | nn.init.xavier_uniform(self.relation_embed.weight) 74 | 75 | self.fc = nn.Sequential( 76 | OrderedDict([ 77 | ("w_1", nn.Linear(node_embed_size, node_embed_size)), 78 | ("a_1", nn.LeakyReLU()), 79 | ("w_2", nn.Linear(node_embed_size, node_embed_size)), 80 | ("a_2", nn.LeakyReLU()) 81 | ]) 82 | ) 83 | 84 | nn.init.xavier_uniform(self.fc[0].weight) 85 | nn.init.xavier_uniform(self.fc[2].weight) 86 | 87 | self.sigmoid = nn.Sigmoid() 88 | 89 | def forward(self, node_idx, relation_idx, dis_temp): 90 | fake_nodes = self.generate_fake_nodes(node_idx, relation_idx) 91 | score = torch.matmul(dis_temp, fake_nodes) 92 | prob = self.sigmoid(score) 93 | return prob 94 | 95 | def loss(self, prob): 96 | loss = torch.sum(torch.log(- prob)) 97 | return loss 98 | 99 | def generate_fake_nodes(self, node_idx, relation_idx): 100 | node_embed = self.node_embed(node_idx) 101 | node_embed = node_embed.reshape((-1, 1, self.args.node_embed_size)) 102 | relation_embed = self.relation_embed(relation_idx) 103 | relation_embed = relation_embed.reshape((-1, self.args.node_embed_size, self.args.node_embed_size)) 104 | temp = torch.matmul(node_embed, relation_embed) 105 | 106 | # add noise 107 | temp = temp + torch.randn(temp.shape, requires_grad=False).cuda() 108 | output = self.fc(temp) 109 | 110 | return output 111 | -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Filename : dataset.py 3 | @Create Time : 6/18/2020 7:37 PM 4 | @Author : Rylynn 5 | @Description : 6 | 7 | """ 8 | import random 9 | 10 | from torch.utils.data import Dataset 11 | 12 | import logging 13 | import numpy as np 14 | from args import args 15 | 16 | 17 | class DBLPDataset(): 18 | def __init__(self): 19 | logging.info("Loading data from dataset...") 20 | self.node_size, self.relation_size, self.graph = read_graph('../data/DBLP/dblp_triple.dat') 21 | logging.info("Loading data finished, total {} nodes, {} relations.".format(self.node_size, self.relation_size)) 22 | self.generator_dataset = self.GeneratorDataset(self.graph) 23 | self.discriminator_dataset = self.DiscriminatorDataset(self.graph) 24 | 25 | class GeneratorDataset(Dataset): 26 | def __init__(self, graph): 27 | self.graph = graph 28 | self.node_list = list(graph.keys()) 29 | self.node_idx = [] 30 | self.relation_idx = [] 31 | 32 | self.sample() 33 | 34 | def __getitem__(self, index): 35 | return self.node_idx[index], self.relation_idx[index] 36 | 37 | def __len__(self): 38 | return len(self.node_idx) 39 | 40 | def sample(self): 41 | logging.info("Sampling data for Generator") 42 | 43 | for node_id in self.node_list: 44 | for i in range(args.sample_num): 45 | relations = list(self.graph[node_id].keys()) 46 | relation_id = random.sample(relations, 1)[0] 47 | 48 | self.node_idx.append(node_id) 49 | self.relation_idx.append(relation_id) 50 | 51 | class DiscriminatorDataset(Dataset): 52 | def __init__(self, graph): 53 | self.graph = graph 54 | self.node_list = list(graph.keys()) 55 | 56 | # real node and real relation 57 | self.pos_node_idx = [] 58 | self.pos_relation_idx = [] 59 | self.pos_node_neighbor_idx = [] 60 | 61 | # real node and wrong relation 62 | self.neg_node_idx = [] 63 | self.neg_relation_idx = [] 64 | self.neg_node_neighbor_idx = [] 65 | 66 | self.sample() 67 | 68 | def __getitem__(self, index): 69 | return self.pos_node_idx[index], self.pos_relation_idx[index], self.pos_node_neighbor_idx[index], \ 70 | self.neg_node_idx[index], self.neg_relation_idx[index], self.neg_node_neighbor_idx[index] 71 | 72 | def __len__(self): 73 | return len(self.pos_node_idx) 74 | 75 | def sample(self): 76 | logging.info("Sampling data for Discriminator") 77 | for node_id in self.node_list: 78 | for i in range(args.sample_num): 79 | # sample real node and true relation 80 | relations = list(self.graph[node_id].keys()) 81 | relation_id = random.sample(relations, 1)[0] 82 | neighbors = self.graph[node_id][relation_id] 83 | node_neighbor_id = neighbors[np.random.randint(0, len(neighbors))] 84 | 85 | self.pos_node_idx.append(node_id) 86 | self.pos_relation_idx.append(relation_id) 87 | self.pos_node_neighbor_idx.append(node_neighbor_id) 88 | 89 | # sample real node and wrong relation 90 | self.neg_node_idx.append(node_id) 91 | self.neg_node_neighbor_idx.append(node_neighbor_id) 92 | neg_relation_id = np.random.randint(0, args.relation_size) 93 | while neg_relation_id == relation_id: 94 | neg_relation_id = np.random.randint(0, args.relation_size) 95 | self.neg_relation_idx.append(neg_relation_id) 96 | 97 | 98 | def read_graph(graph_filename): 99 | # p -> a : 0 100 | # a -> p : 1 101 | # p -> c : 2 102 | # c -> p : 3 103 | # p -> t : 4 104 | # t -> p : 5 105 | # graph_filename = '../data/dblp/dblp_triple.dat' 106 | 107 | relations = set() 108 | nodes = set() 109 | graph = {} 110 | 111 | with open(graph_filename) as infile: 112 | for line in infile.readlines(): 113 | source_node, target_node, relation = line.strip().split(' ') 114 | source_node = int(source_node) 115 | target_node = int(target_node) 116 | relation = int(relation) 117 | 118 | nodes.add(source_node) 119 | nodes.add(target_node) 120 | relations.add(relation) 121 | 122 | if source_node not in graph: 123 | graph[source_node] = {} 124 | 125 | if relation not in graph[source_node]: 126 | graph[source_node][relation] = [] 127 | 128 | graph[source_node][relation].append(target_node) 129 | 130 | n_node = len(nodes) 131 | return n_node, len(relations), graph 132 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Filename : model.py 3 | @Create Time : 6/18/2020 3:43 PM 4 | @Author : Rylynn 5 | @Description : 6 | 7 | """ 8 | import argparse 9 | import datetime 10 | import logging 11 | import numpy as np 12 | import torch 13 | from torch.utils.tensorboard import SummaryWriter 14 | from tqdm import tqdm 15 | import torch.nn as nn 16 | 17 | from torch.utils.data import DataLoader 18 | from args import args 19 | from dataset import * 20 | from components import * 21 | 22 | logging.basicConfig(level=logging.INFO, 23 | format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s') 24 | 25 | 26 | class HeGAN(object): 27 | def __init__(self, args): 28 | self.args = args 29 | logging.info("Loading pretrain embedding file...") 30 | pretrain_emb = self.read_pretrain_embed('../pretrain/dblp_pre_train.emb', node_size=args.node_size, embed_size=args.node_embed_size) 31 | self.args.pretrain_embed = pretrain_emb 32 | logging.info("Pretrain embedding file loaded.") 33 | 34 | logging.info("Building Generator...") 35 | generator = Generator(self.args) 36 | self.generator = generator.cuda() 37 | 38 | logging.info("Building Discriminator...") 39 | discriminator = Discriminator(self.args) 40 | self.discriminator = discriminator.cuda() 41 | 42 | self.name = "HeGAN-" + datetime.datetime.now().strftime('%Y-%m-%d_%H%M%S') 43 | if args.name: 44 | self.name = args.name + datetime.datetime.now().strftime('%Y-%m-%d_%H%M%S') 45 | 46 | 47 | def read_pretrain_embed(self, pretrain_file, node_size, embed_size): 48 | embedding_matrix = np.random.rand(node_size, embed_size) 49 | i = -1 50 | with open(pretrain_file) as infile: 51 | for line in infile.readlines()[1:]: 52 | i += 1 53 | emd = line.strip().split() 54 | embedding_matrix[int(emd[0]), :] = self.str_list_to_float(emd[1:]) 55 | return embedding_matrix 56 | 57 | def str_list_to_float(self, str_list): 58 | return [float(item) for item in str_list] 59 | 60 | def train(self): 61 | writer = SummaryWriter("./log/" + self.name) 62 | 63 | dblp_dataset = DBLPDataset() 64 | gen_data_loader = DataLoader(dblp_dataset.generator_dataset, shuffle=True, batch_size=self.args.batch_size, 65 | num_workers=8, pin_memory=True) 66 | dis_data_loader = DataLoader(dblp_dataset.discriminator_dataset, shuffle=True, batch_size=self.args.batch_size, 67 | num_workers=8, pin_memory=True) 68 | 69 | discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), self.args.dis_lr) 70 | generator_optimizer = torch.optim.Adam(self.generator.parameters(), self.args.gen_lr) 71 | 72 | logging.info("Training Begin...") 73 | for total_idx in range(self.args.epoch): 74 | # Training discriminator 75 | for dis_idx in range(self.args.epoch_g): 76 | dis_batch_loss = 0 77 | for (batch_idx, data) in tqdm(enumerate(dis_data_loader)): 78 | pos_node_idx = data[0].cuda() 79 | pos_relation_idx = data[1].cuda() 80 | pos_node_neighbor_idx = data[2].cuda() 81 | 82 | neg_node_idx = data[3].cuda() 83 | neg_relation_idx = data[4].cuda() 84 | neg_node_neighbor_idx = data[5].cuda() 85 | 86 | fake_nodes_embed = self.generator.generate_fake_nodes(pos_node_idx, pos_relation_idx) 87 | 88 | prob_pos = self.discriminator(pos_node_idx, pos_relation_idx, pos_node_neighbor_idx) 89 | prob_neg = self.discriminator(neg_node_idx, neg_relation_idx, neg_node_neighbor_idx) 90 | prob_fake = self.discriminator.forward_fake(pos_node_idx, pos_relation_idx, fake_nodes_embed) 91 | 92 | discriminator_loss_pos = torch.sum(torch.log(prob_pos)) 93 | discriminator_loss_neg = torch.sum(torch.log(-prob_neg)) 94 | discriminator_loss_fake = torch.sum(torch.log(-prob_fake)) 95 | 96 | discriminator_loss = discriminator_loss_pos + discriminator_loss_neg + discriminator_loss_fake 97 | dis_batch_loss += discriminator_loss.item() 98 | 99 | discriminator_optimizer.zero_grad() 100 | discriminator_loss.backward() 101 | discriminator_optimizer.step() 102 | 103 | logging.info("Total epoch: {}, Discriminator epoch: {}, loss: {}.". 104 | format(total_idx, dis_idx, dis_batch_loss / len(dis_data_loader))) 105 | writer.add_scalar("dis_loss", dis_batch_loss / len(dis_data_loader)) 106 | 107 | # Training generator 108 | for gen_idx in range(self.args.epoch_d): 109 | gen_batch_loss = 0 110 | for (batch_idx, data) in tqdm(enumerate(gen_data_loader)): 111 | node_idx = data[0].cuda() 112 | relation_idx = data[1].cuda() 113 | 114 | temp = self.discriminator.multify(node_idx, relation_idx) 115 | prob = self.generator(node_idx, relation_idx, dis_temp=temp) 116 | generator_loss = self.generator.loss(prob) 117 | 118 | l2_regularization = torch.tensor([0], dtype=torch.float32) 119 | for param in self.generator.parameters(): 120 | l2_regularization += torch.norm(param, 2) 121 | 122 | generator_loss = generator_loss + l2_regularization 123 | gen_batch_loss += generator_loss.item() 124 | 125 | generator_optimizer.zero_grad() 126 | generator_loss.backward() 127 | generator_optimizer.step() 128 | 129 | logging.info("Total epoch: {}, Discriminator epoch: {}, loss: {}.". 130 | format(total_idx, gen_idx, gen_batch_loss / len(gen_data_loader))) 131 | writer.add_scalar("gen_loss", gen_batch_loss / len(gen_data_loader)) 132 | 133 | writer.close() 134 | 135 | 136 | 137 | 138 | def main(): 139 | he_gan = HeGAN(args=args) 140 | he_gan.train() 141 | 142 | def test(): 143 | dblp_dataset = DBLPDataset() 144 | dl = DataLoader(dblp_dataset.discriminator_dataset, shuffle=True, batch_size=args.batch_size) 145 | for (idx, a) in enumerate(dl): 146 | print(a) 147 | 148 | 149 | if __name__ == '__main__': 150 | # test() 151 | main() 152 | --------------------------------------------------------------------------------