├── config ├── __init__.py ├── config.py └── config.yaml ├── doc └── pipline.png ├── datasets ├── __init__.py ├── data_helper.py └── visual_data.py ├── models ├── __init__.py ├── HGNN.py └── layers.py ├── LICENSE ├── README.md ├── train.py └── utils └── hypergraph_utils.py /config/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import get_config 2 | -------------------------------------------------------------------------------- /doc/pipline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iMoonLab/HGNN/HEAD/doc/pipline.png -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_helper import load_ft 2 | from .visual_data import load_feature_construct_H -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .layers import HGNN_conv, HGNN_fc, HGNN_embedding, HGNN_classifier 2 | from .HGNN import HGNN 3 | -------------------------------------------------------------------------------- /models/HGNN.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from models import HGNN_conv 3 | import torch.nn.functional as F 4 | 5 | 6 | class HGNN(nn.Module): 7 | def __init__(self, in_ch, n_class, n_hid, dropout=0.5): 8 | super(HGNN, self).__init__() 9 | self.dropout = dropout 10 | self.hgc1 = HGNN_conv(in_ch, n_hid) 11 | self.hgc2 = HGNN_conv(n_hid, n_class) 12 | 13 | def forward(self, x, G): 14 | x = F.relu(self.hgc1(x, G)) 15 | x = F.dropout(x, self.dropout) 16 | x = self.hgc2(x, G) 17 | return x 18 | -------------------------------------------------------------------------------- /datasets/data_helper.py: -------------------------------------------------------------------------------- 1 | import scipy.io as scio 2 | import numpy as np 3 | 4 | 5 | def load_ft(data_dir, feature_name='GVCNN'): 6 | data = scio.loadmat(data_dir) 7 | lbls = data['Y'].astype(np.long) 8 | if lbls.min() == 1: 9 | lbls = lbls - 1 10 | idx = data['indices'].item() 11 | 12 | if feature_name == 'MVCNN': 13 | fts = data['X'][0].item().astype(np.float32) 14 | elif feature_name == 'GVCNN': 15 | fts = data['X'][1].item().astype(np.float32) 16 | else: 17 | print(f'wrong feature name{feature_name}!') 18 | raise IOError 19 | 20 | idx_train = np.where(idx == 1)[0] 21 | idx_test = np.where(idx == 0)[0] 22 | return fts, lbls, idx_train, idx_test 23 | 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Yue's Group of THU 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /config/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import os.path as osp 4 | 5 | 6 | def get_config(dir='config/config.yaml'): 7 | # add direction join function when parse the yaml file 8 | def join(loader, node): 9 | seq = loader.construct_sequence(node) 10 | return os.path.sep.join(seq) 11 | 12 | # add string concatenation function when parse the yaml file 13 | def concat(loader, node): 14 | seq = loader.construct_sequence(node) 15 | seq = [str(tmp) for tmp in seq] 16 | return ''.join(seq) 17 | 18 | yaml.add_constructor('!join', join) 19 | yaml.add_constructor('!concat', concat) 20 | with open(dir, 'r') as f: 21 | cfg = yaml.load(f) 22 | 23 | check_dirs(cfg) 24 | 25 | return cfg 26 | 27 | 28 | def check_dir(folder, mk_dir=True): 29 | if not osp.exists(folder): 30 | if mk_dir: 31 | print(f'making direction {folder}!') 32 | os.mkdir(folder) 33 | else: 34 | raise Exception(f'Not exist direction {folder}') 35 | 36 | 37 | def check_dirs(cfg): 38 | check_dir(cfg['data_root'], mk_dir=False) 39 | 40 | check_dir(cfg['result_root']) 41 | check_dir(cfg['ckpt_folder']) 42 | check_dir(cfg['result_sub_folder']) 43 | -------------------------------------------------------------------------------- /config/config.yaml: -------------------------------------------------------------------------------- 1 | #------------------------------------------------------------------# 2 | # must check the direction of data_root and result_root before run # 3 | #------------------------------------------------------------------# 4 | 5 | #configure feature path 6 | # **** please must modify the data_root before first running **** 7 | data_root: &d_r /home/fengyifan/data/features 8 | modelnet40_ft: !join [*d_r, ModelNet40_mvcnn_gvcnn.mat] 9 | ntu2012_ft: !join [*d_r, NTU2012_mvcnn_gvcnn.mat] 10 | 11 | 12 | #Hypergraph 13 | graph_type: &g_t hypergraph 14 | K_neigs: [10] 15 | #K_neigs: [10, 15 ] 16 | m_prob: 1.0 17 | is_probH: True 18 | #--------------------------------------- 19 | # change me 20 | use_mvcnn_feature_for_structure: True 21 | use_gvcnn_feature_for_structure: True 22 | #--------------------------------------- 23 | 24 | 25 | #Model 26 | #-------------------------------------------------- 27 | # select the dataset you use, ModelNet40 or NTU2012 28 | on_dataset: &o_d ModelNet40 29 | #on_dataset: &o_d NTU2012 30 | #-------------------------------------------------- 31 | 32 | #--------------------------------------- 33 | # change me 34 | use_mvcnn_feature: False 35 | use_gvcnn_feature: True 36 | #--------------------------------------- 37 | 38 | 39 | #Result 40 | # configure result path 41 | # **** please must modify the result_root before first running **** 42 | result_root: &r_r /home/fengyifan/result/hgnn 43 | result_sub_folder: !join [*r_r, !concat [ *g_t, _, *o_d ]] 44 | ckpt_folder: !join [*r_r, ckpt] 45 | 46 | 47 | #Train 48 | max_epoch: 600 49 | n_hid: 128 50 | lr: 0.001 51 | milestones: [100] 52 | gamma: 0.9 53 | drop_out: 0.5 54 | print_freq: 50 55 | weight_decay: 0.0005 56 | decay_step: 200 57 | decay_rate: 0.7 -------------------------------------------------------------------------------- /models/layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.nn.parameter import Parameter 6 | 7 | 8 | class HGNN_conv(nn.Module): 9 | def __init__(self, in_ft, out_ft, bias=True): 10 | super(HGNN_conv, self).__init__() 11 | 12 | self.weight = Parameter(torch.Tensor(in_ft, out_ft)) 13 | if bias: 14 | self.bias = Parameter(torch.Tensor(out_ft)) 15 | else: 16 | self.register_parameter('bias', None) 17 | self.reset_parameters() 18 | 19 | def reset_parameters(self): 20 | stdv = 1. / math.sqrt(self.weight.size(1)) 21 | self.weight.data.uniform_(-stdv, stdv) 22 | if self.bias is not None: 23 | self.bias.data.uniform_(-stdv, stdv) 24 | 25 | def forward(self, x: torch.Tensor, G: torch.Tensor): 26 | x = x.matmul(self.weight) 27 | if self.bias is not None: 28 | x = x + self.bias 29 | x = G.matmul(x) 30 | return x 31 | 32 | 33 | class HGNN_fc(nn.Module): 34 | def __init__(self, in_ch, out_ch): 35 | super(HGNN_fc, self).__init__() 36 | self.fc = nn.Linear(in_ch, out_ch) 37 | 38 | def forward(self, x): 39 | return self.fc(x) 40 | 41 | 42 | class HGNN_embedding(nn.Module): 43 | def __init__(self, in_ch, n_hid, dropout=0.5): 44 | super(HGNN_embedding, self).__init__() 45 | self.dropout = dropout 46 | self.hgc1 = HGNN_conv(in_ch, n_hid) 47 | self.hgc2 = HGNN_conv(n_hid, n_hid) 48 | 49 | def forward(self, x, G): 50 | x = F.relu(self.hgc1(x, G)) 51 | x = F.dropout(x, self.dropout) 52 | x = F.relu(self.hgc2(x, G)) 53 | return x 54 | 55 | 56 | class HGNN_classifier(nn.Module): 57 | def __init__(self, n_hid, n_class): 58 | super(HGNN_classifier, self).__init__() 59 | self.fc1 = nn.Linear(n_hid, n_class) 60 | 61 | def forward(self, x): 62 | x = self.fc1(x) 63 | return x -------------------------------------------------------------------------------- /datasets/visual_data.py: -------------------------------------------------------------------------------- 1 | from datasets import load_ft 2 | from utils import hypergraph_utils as hgut 3 | 4 | 5 | def load_feature_construct_H(data_dir, 6 | m_prob=1, 7 | K_neigs=[10], 8 | is_probH=True, 9 | split_diff_scale=False, 10 | use_mvcnn_feature=False, 11 | use_gvcnn_feature=True, 12 | use_mvcnn_feature_for_structure=False, 13 | use_gvcnn_feature_for_structure=True): 14 | """ 15 | 16 | :param data_dir: directory of feature data 17 | :param m_prob: parameter in hypergraph incidence matrix construction 18 | :param K_neigs: the number of neighbor expansion 19 | :param is_probH: probability Vertex-Edge matrix or binary 20 | :param use_mvcnn_feature: 21 | :param use_gvcnn_feature: 22 | :param use_mvcnn_feature_for_structure: 23 | :param use_gvcnn_feature_for_structure: 24 | :return: 25 | """ 26 | # init feature 27 | if use_mvcnn_feature or use_mvcnn_feature_for_structure: 28 | mvcnn_ft, lbls, idx_train, idx_test = load_ft(data_dir, feature_name='MVCNN') 29 | if use_gvcnn_feature or use_gvcnn_feature_for_structure: 30 | gvcnn_ft, lbls, idx_train, idx_test = load_ft(data_dir, feature_name='GVCNN') 31 | if 'mvcnn_ft' not in dir() and 'gvcnn_ft' not in dir(): 32 | raise Exception('None feature initialized') 33 | 34 | # construct feature matrix 35 | fts = None 36 | if use_mvcnn_feature: 37 | fts = hgut.feature_concat(fts, mvcnn_ft) 38 | if use_gvcnn_feature: 39 | fts = hgut.feature_concat(fts, gvcnn_ft) 40 | if fts is None: 41 | raise Exception(f'None feature used for model!') 42 | 43 | # construct hypergraph incidence matrix 44 | print('Constructing hypergraph incidence matrix! \n(It may take several minutes! Please wait patiently!)') 45 | H = None 46 | if use_mvcnn_feature_for_structure: 47 | tmp = hgut.construct_H_with_KNN(mvcnn_ft, K_neigs=K_neigs, 48 | split_diff_scale=split_diff_scale, 49 | is_probH=is_probH, m_prob=m_prob) 50 | H = hgut.hyperedge_concat(H, tmp) 51 | if use_gvcnn_feature_for_structure: 52 | tmp = hgut.construct_H_with_KNN(gvcnn_ft, K_neigs=K_neigs, 53 | split_diff_scale=split_diff_scale, 54 | is_probH=is_probH, m_prob=m_prob) 55 | H = hgut.hyperedge_concat(H, tmp) 56 | if H is None: 57 | raise Exception('None feature to construct hypergraph incidence matrix!') 58 | 59 | return fts, lbls, idx_train, idx_test, H 60 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## News 2 | We have release a deep learning toolbox named DHG for graph neural networks and hypergraph neural networks. You can find many intresting things in it. Many correlation structures like simple graph, directed graph, bipartite graph, and simple hypergraph are all supported in the toolbox, as well as their visualization. More details refer to [DHG](https://github.com/iMoonLab/DeepHypergraph)! 3 | 4 | ## Hypergraph Neural Networks 5 | Created by Yifan Feng, Haoxuan You, Zizhao Zhang, Rongrong, Ji, Yue Gao from Xiamen University and Tsinghua University. 6 | 7 | ![pipline](doc/pipline.png) 8 | 9 | ### Introduction 10 | This work will appear in AAAI 2019. We proposed a novel framework(HGNN) for data representation learning, which could take multi-modal data and exhibit superior performance gain compared with single modal or graph-based multi-modal methods. You can also check our [paper](http://gaoyue.org/paper/HGNN.pdf) for a deeper introduction. 11 | 12 | HGNN could encode high-order data correlation in a hypergraph structure. Confronting the challenges of learning representation for complex data in real practice, we propose to incorporate such data structure in a hypergraph, which is more flexible on data modeling, especially when dealing with complex data. In this method, a hyperedge convolution operation is designed to handle the data correlation during representation learning. In this way, traditional hypergraph learning procedure can be conducted using hyperedge convolution operations efficiently. HGNN is able to learn the hidden layer representation considering the high-order data structure, which is a general framework considering the complex data correlations. 13 | 14 | In this repository, we release code and data for train a Hypergrpah Nerual Networks for node classification on ModelNet40 dataset and NTU2012 dataset. The visual objects' feature is extracted by [MVCNN(Su et al.)](http://vis-www.cs.umass.edu/mvcnn/docs/su15mvcnn.pdf) and [GVCNN(Feng et al.)](http://openaccess.thecvf.com/content_cvpr_2018/papers/Feng_GVCNN_Group-View_Convolutional_CVPR_2018_paper.pdf). 15 | 16 | 17 | ### Citation 18 | if you find our work useful in your research, please consider citing: 19 | 20 | @article{feng2018hypergraph, 21 | title={Hypergraph Neural Networks}, 22 | author={Feng, Yifan and You, Haoxuan and Zhang, Zizhao and Ji, Rongrong and Gao, Yue}, 23 | journal={AAAI 2019}, 24 | year={2018} 25 | } 26 | 27 | ### Installation 28 | Install [Pytorch 0.4.0](https://pytorch.org/). You also need to install yaml. The code has been tested with Python 3.6, Pytorch 0.4.0 and CUDA 9.0 on Ubuntu 16.04. 29 | 30 | ### Usage 31 | 32 | **Firstly, you should download the feature files of modelnet40 and ntu2012 datasets. 33 | Then, configure the "data_root" and "result_root" path in config/config.yaml.** 34 | 35 | Download datasets for training/evaluation (should be placed under "data_root") 36 | - [ModelNet40_mvcnn_gvcnn_feature](https://drive.google.com/file/d/1euw3bygLzRQm_dYj1FoRduXvsRRUG2Gr/view?usp=sharing) 37 | - [NTU2012_mvcnn_gvcnn_feature](https://drive.google.com/file/d/1Vx4K15bW3__JPRV0KUoDWtQX8sB-vbO5/view?usp=sharing) 38 | 39 | 40 | 41 | To train and evaluate HGNN for node classification: 42 | ``` 43 | python train.py 44 | ``` 45 | You can select the feature that contribute to construct hypregraph incidence matrix by changing the status of parameters "use_mvcnn_feature_for_structure" and "use_gvcnn_feature_for_structure" in config.yaml file. Similarly, changing the status of parameter "use_gvcnn_feature" and "use_gvcnn_feature" can control the feature HGNN feed, and both true will concatenate the mvcnn feature and gvcnn feature as the node feature in HGNN. 46 | 47 | ```yaml 48 | # config/config.yaml 49 | use_mvcnn_feature_for_structure: True 50 | use_gvcnn_feature_for_structure: True 51 | use_mvcnn_feature: False 52 | use_gvcnn_feature: True 53 | ``` 54 | To change the experimental dataset (ModelNet40 or NTU2012) 55 | ```yaml 56 | # config/config.yaml 57 | #Model 58 | on_dataset: &o_d ModelNet40 59 | #on_dataset: &o_d NTU2012 60 | ``` 61 | ### License 62 | Our code is released under MIT License (see LICENSE file for details). 63 | 64 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import copy 4 | import torch 5 | import torch.optim as optim 6 | import pprint as pp 7 | import utils.hypergraph_utils as hgut 8 | from models import HGNN 9 | from config import get_config 10 | from datasets import load_feature_construct_H 11 | 12 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 13 | cfg = get_config('config/config.yaml') 14 | 15 | # initialize data 16 | data_dir = cfg['modelnet40_ft'] if cfg['on_dataset'] == 'ModelNet40' \ 17 | else cfg['ntu2012_ft'] 18 | fts, lbls, idx_train, idx_test, H = \ 19 | load_feature_construct_H(data_dir, 20 | m_prob=cfg['m_prob'], 21 | K_neigs=cfg['K_neigs'], 22 | is_probH=cfg['is_probH'], 23 | use_mvcnn_feature=cfg['use_mvcnn_feature'], 24 | use_gvcnn_feature=cfg['use_gvcnn_feature'], 25 | use_mvcnn_feature_for_structure=cfg['use_mvcnn_feature_for_structure'], 26 | use_gvcnn_feature_for_structure=cfg['use_gvcnn_feature_for_structure']) 27 | G = hgut.generate_G_from_H(H) 28 | n_class = int(lbls.max()) + 1 29 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 30 | 31 | # transform data to device 32 | fts = torch.Tensor(fts).to(device) 33 | lbls = torch.Tensor(lbls).squeeze().long().to(device) 34 | G = torch.Tensor(G).to(device) 35 | idx_train = torch.Tensor(idx_train).long().to(device) 36 | idx_test = torch.Tensor(idx_test).long().to(device) 37 | 38 | 39 | def train_model(model, criterion, optimizer, scheduler, num_epochs=25, print_freq=500): 40 | since = time.time() 41 | 42 | best_model_wts = copy.deepcopy(model.state_dict()) 43 | best_acc = 0.0 44 | 45 | for epoch in range(num_epochs): 46 | if epoch % print_freq == 0: 47 | print('-' * 10) 48 | print(f'Epoch {epoch}/{num_epochs - 1}') 49 | 50 | # Each epoch has a training and validation phase 51 | for phase in ['train', 'val']: 52 | if phase == 'train': 53 | scheduler.step() 54 | model.train() # Set model to training mode 55 | else: 56 | model.eval() # Set model to evaluate mode 57 | 58 | running_loss = 0.0 59 | running_corrects = 0 60 | 61 | idx = idx_train if phase == 'train' else idx_test 62 | 63 | # Iterate over data. 64 | optimizer.zero_grad() 65 | with torch.set_grad_enabled(phase == 'train'): 66 | outputs = model(fts, G) 67 | loss = criterion(outputs[idx], lbls[idx]) 68 | _, preds = torch.max(outputs, 1) 69 | 70 | # backward + optimize only if in training phase 71 | if phase == 'train': 72 | loss.backward() 73 | optimizer.step() 74 | 75 | # statistics 76 | running_loss += loss.item() * fts.size(0) 77 | running_corrects += torch.sum(preds[idx] == lbls.data[idx]) 78 | 79 | epoch_loss = running_loss / len(idx) 80 | epoch_acc = running_corrects.double() / len(idx) 81 | 82 | if epoch % print_freq == 0: 83 | print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}') 84 | 85 | # deep copy the model 86 | if phase == 'val' and epoch_acc > best_acc: 87 | best_acc = epoch_acc 88 | best_model_wts = copy.deepcopy(model.state_dict()) 89 | 90 | if epoch % print_freq == 0: 91 | print(f'Best val Acc: {best_acc:4f}') 92 | print('-' * 20) 93 | 94 | time_elapsed = time.time() - since 95 | print(f'\nTraining complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s') 96 | print(f'Best val Acc: {best_acc:4f}') 97 | 98 | # load best model weights 99 | model.load_state_dict(best_model_wts) 100 | return model 101 | 102 | 103 | def _main(): 104 | print(f"Classification on {cfg['on_dataset']} dataset!!! class number: {n_class}") 105 | print(f"use MVCNN feature: {cfg['use_mvcnn_feature']}") 106 | print(f"use GVCNN feature: {cfg['use_gvcnn_feature']}") 107 | print(f"use MVCNN feature for structure: {cfg['use_mvcnn_feature_for_structure']}") 108 | print(f"use GVCNN feature for structure: {cfg['use_gvcnn_feature_for_structure']}") 109 | print('Configuration -> Start') 110 | pp.pprint(cfg) 111 | print('Configuration -> End') 112 | 113 | model_ft = HGNN(in_ch=fts.shape[1], 114 | n_class=n_class, 115 | n_hid=cfg['n_hid'], 116 | dropout=cfg['drop_out']) 117 | model_ft = model_ft.to(device) 118 | 119 | optimizer = optim.Adam(model_ft.parameters(), lr=cfg['lr'], 120 | weight_decay=cfg['weight_decay']) 121 | # optimizer = optim.SGD(model_ft.parameters(), lr=0.01, weight_decay=cfg['weight_decay) 122 | schedular = optim.lr_scheduler.MultiStepLR(optimizer, 123 | milestones=cfg['milestones'], 124 | gamma=cfg['gamma']) 125 | criterion = torch.nn.CrossEntropyLoss() 126 | 127 | model_ft = train_model(model_ft, criterion, optimizer, schedular, cfg['max_epoch'], print_freq=cfg['print_freq']) 128 | 129 | 130 | if __name__ == '__main__': 131 | _main() 132 | -------------------------------------------------------------------------------- /utils/hypergraph_utils.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Utility functions for Hypergraph 3 | # 4 | # Author: Yifan Feng 5 | # Date: November 2018 6 | # -------------------------------------------------------- 7 | import numpy as np 8 | 9 | 10 | def Eu_dis(x): 11 | """ 12 | Calculate the distance among each raw of x 13 | :param x: N X D 14 | N: the object number 15 | D: Dimension of the feature 16 | :return: N X N distance matrix 17 | """ 18 | x = np.mat(x) 19 | aa = np.sum(np.multiply(x, x), 1) 20 | ab = x * x.T 21 | dist_mat = aa + aa.T - 2 * ab 22 | dist_mat[dist_mat < 0] = 0 23 | dist_mat = np.sqrt(dist_mat) 24 | dist_mat = np.maximum(dist_mat, dist_mat.T) 25 | return dist_mat 26 | 27 | 28 | def feature_concat(*F_list, normal_col=False): 29 | """ 30 | Concatenate multiple modality feature. If the dimension of a feature matrix is more than two, 31 | the function will reduce it into two dimension(using the last dimension as the feature dimension, 32 | the other dimension will be fused as the object dimension) 33 | :param F_list: Feature matrix list 34 | :param normal_col: normalize each column of the feature 35 | :return: Fused feature matrix 36 | """ 37 | features = None 38 | for f in F_list: 39 | if f is not None and f != []: 40 | # deal with the dimension that more than two 41 | if len(f.shape) > 2: 42 | f = f.reshape(-1, f.shape[-1]) 43 | # normal each column 44 | if normal_col: 45 | f_max = np.max(np.abs(f), axis=0) 46 | f = f / f_max 47 | # facing the first feature matrix appended to fused feature matrix 48 | if features is None: 49 | features = f 50 | else: 51 | features = np.hstack((features, f)) 52 | if normal_col: 53 | features_max = np.max(np.abs(features), axis=0) 54 | features = features / features_max 55 | return features 56 | 57 | 58 | def hyperedge_concat(*H_list): 59 | """ 60 | Concatenate hyperedge group in H_list 61 | :param H_list: Hyperedge groups which contain two or more hypergraph incidence matrix 62 | :return: Fused hypergraph incidence matrix 63 | """ 64 | H = None 65 | for h in H_list: 66 | if h is not None and h != []: 67 | # for the first H appended to fused hypergraph incidence matrix 68 | if H is None: 69 | H = h 70 | else: 71 | if type(h) != list: 72 | H = np.hstack((H, h)) 73 | else: 74 | tmp = [] 75 | for a, b in zip(H, h): 76 | tmp.append(np.hstack((a, b))) 77 | H = tmp 78 | return H 79 | 80 | 81 | def generate_G_from_H(H, variable_weight=False): 82 | """ 83 | calculate G from hypgraph incidence matrix H 84 | :param H: hypergraph incidence matrix H 85 | :param variable_weight: whether the weight of hyperedge is variable 86 | :return: G 87 | """ 88 | if type(H) != list: 89 | return _generate_G_from_H(H, variable_weight) 90 | else: 91 | G = [] 92 | for sub_H in H: 93 | G.append(generate_G_from_H(sub_H, variable_weight)) 94 | return G 95 | 96 | 97 | def _generate_G_from_H(H, variable_weight=False): 98 | """ 99 | calculate G from hypgraph incidence matrix H 100 | :param H: hypergraph incidence matrix H 101 | :param variable_weight: whether the weight of hyperedge is variable 102 | :return: G 103 | """ 104 | H = np.array(H) 105 | n_edge = H.shape[1] 106 | # the weight of the hyperedge 107 | W = np.ones(n_edge) 108 | # the degree of the node 109 | DV = np.sum(H * W, axis=1) 110 | # the degree of the hyperedge 111 | DE = np.sum(H, axis=0) 112 | 113 | invDE = np.mat(np.diag(np.power(DE, -1))) 114 | DV2 = np.mat(np.diag(np.power(DV, -0.5))) 115 | W = np.mat(np.diag(W)) 116 | H = np.mat(H) 117 | HT = H.T 118 | 119 | if variable_weight: 120 | DV2_H = DV2 * H 121 | invDE_HT_DV2 = invDE * HT * DV2 122 | return DV2_H, W, invDE_HT_DV2 123 | else: 124 | G = DV2 * H * W * invDE * HT * DV2 125 | return G 126 | 127 | 128 | def construct_H_with_KNN_from_distance(dis_mat, k_neig, is_probH=True, m_prob=1): 129 | """ 130 | construct hypregraph incidence matrix from hypergraph node distance matrix 131 | :param dis_mat: node distance matrix 132 | :param k_neig: K nearest neighbor 133 | :param is_probH: prob Vertex-Edge matrix or binary 134 | :param m_prob: prob 135 | :return: N_object X N_hyperedge 136 | """ 137 | n_obj = dis_mat.shape[0] 138 | # construct hyperedge from the central feature space of each node 139 | n_edge = n_obj 140 | H = np.zeros((n_obj, n_edge)) 141 | for center_idx in range(n_obj): 142 | dis_mat[center_idx, center_idx] = 0 143 | dis_vec = dis_mat[center_idx] 144 | nearest_idx = np.array(np.argsort(dis_vec)).squeeze() 145 | avg_dis = np.average(dis_vec) 146 | if not np.any(nearest_idx[:k_neig] == center_idx): 147 | nearest_idx[k_neig - 1] = center_idx 148 | 149 | for node_idx in nearest_idx[:k_neig]: 150 | if is_probH: 151 | H[node_idx, center_idx] = np.exp(-dis_vec[0, node_idx] ** 2 / (m_prob * avg_dis) ** 2) 152 | else: 153 | H[node_idx, center_idx] = 1.0 154 | return H 155 | 156 | 157 | def construct_H_with_KNN(X, K_neigs=[10], split_diff_scale=False, is_probH=True, m_prob=1): 158 | """ 159 | init multi-scale hypergraph Vertex-Edge matrix from original node feature matrix 160 | :param X: N_object x feature_number 161 | :param K_neigs: the number of neighbor expansion 162 | :param split_diff_scale: whether split hyperedge group at different neighbor scale 163 | :param is_probH: prob Vertex-Edge matrix or binary 164 | :param m_prob: prob 165 | :return: N_object x N_hyperedge 166 | """ 167 | if len(X.shape) != 2: 168 | X = X.reshape(-1, X.shape[-1]) 169 | 170 | if type(K_neigs) == int: 171 | K_neigs = [K_neigs] 172 | 173 | dis_mat = Eu_dis(X) 174 | H = [] 175 | for k_neig in K_neigs: 176 | H_tmp = construct_H_with_KNN_from_distance(dis_mat, k_neig, is_probH, m_prob) 177 | if not split_diff_scale: 178 | H = hyperedge_concat(H, H_tmp) 179 | else: 180 | H.append(H_tmp) 181 | return H 182 | --------------------------------------------------------------------------------