├── .gitignore
├── Hypformer
│   ├── README.md
│   ├── __init__.py
│   ├── hypformer.py
│   ├── main.py
│   ├── manifolds
│   │   ├── __init__.py
│   │   ├── hyp_layer.py
│   │   ├── layer.py
│   │   ├── lmath.py
│   │   ├── lorentz.py
│   │   └── utils.py
│   └── requirement.txt
├── LICENSE
├── README.md
├── data
│   ├── 20news
│   │   └── 20news.pkl
│   ├── OGB
│   │   └── data
│   ├── Planetoid
│   │   ├── citeseer
│   │   │   └── raw
│   │   │       ├── ind.citeseer.allx
│   │   │       ├── ind.citeseer.ally
│   │   │       ├── ind.citeseer.graph
│   │   │       ├── ind.citeseer.test.index
│   │   │       ├── ind.citeseer.tx
│   │   │       ├── ind.citeseer.ty
│   │   │       ├── ind.citeseer.x
│   │   │       └── ind.citeseer.y
│   │   ├── cora
│   │   │   └── raw
│   │   │       ├── ind.cora.allx
│   │   │       ├── ind.cora.ally
│   │   │       ├── ind.cora.graph
│   │   │       ├── ind.cora.test.index
│   │   │       ├── ind.cora.tx
│   │   │       ├── ind.cora.ty
│   │   │       ├── ind.cora.x
│   │   │       └── ind.cora.y
│   │   └── pubmed
│   │       └── raw
│   │           ├── ind.pubmed.allx
│   │           ├── ind.pubmed.ally
│   │           ├── ind.pubmed.graph
│   │           ├── ind.pubmed.test.index
│   │           ├── ind.pubmed.tx
│   │           ├── ind.pubmed.ty
│   │           ├── ind.pubmed.x
│   │           └── ind.pubmed.y
│   ├── hgcn_data
│   │   ├── airport
│   │   │   ├── airport.p
│   │   │   ├── airport_alldata.p
│   │   │   └── routes.dat
│   │   └── disease_nc
│   │       ├── disease_nc.edges.csv
│   │       ├── disease_nc.feats.npz
│   │       └── disease_nc.labels.npy
│   └── mini_imagenet
│       └── mini_imagenet.pkl
├── figures
│   └── framework.jpg
├── large
│   ├── data_utils.py
│   ├── dataset.py
│   ├── eval.py
│   ├── examples
│   │   ├── 5-runs
│   │   │   ├── run_amazon2M.sh
│   │   │   ├── run_arxiv.sh
│   │   │   └── run_protein.sh
│   │   ├── amazon2M.sh
│   │   ├── arxiv.sh
│   │   └── protein.sh
│   ├── gnns.py
│   ├── hypformer.py
│   ├── load_data.py
│   ├── logger.py
│   ├── main-batch.py
│   ├── main.py
│   ├── manifolds
│   │   ├── __init__.py
│   │   ├── layer.py
│   │   ├── lmath.py
│   │   ├── lorentz.py
│   │   └── utils.py
│   └── parse.py
├── medium
│   ├── data_utils.py
│   ├── dataset.py
│   ├── examples
│   │   ├── 5-runs
│   │   │   ├── run_airport.sh
│   │   │   ├── run_citeseer.sh
│   │   │   ├── run_cora.sh
│   │   │   └── run_pubmed.sh
│   │   ├── airport.sh
│   │   ├── citeseer.sh
│   │   ├── cora.sh
│   │   └── pubmed.sh
│   ├── gnns.py
│   ├── hypformer.py
│   ├── logger.py
│   ├── main.py
│   ├── manifolds
│   │   ├── __init__.py
│   │   ├── hyp_layer.py
│   │   ├── lorentz.py
│   │   ├── lorentz_math.py
│   │   └── manifold_utils.py
│   └── parse.py
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | results
2 | wandb
3 | *__pycache__
4 | rsync_git.sh
--------------------------------------------------------------------------------
/Hypformer/README.md:
--------------------------------------------------------------------------------
1 | # Simplified Hypformer Code
2 |
3 | This folder contains simplified code for Hypformer, designed to be easily adaptable to research applications in GNNs, text, image processing, and more.
4 |
5 | ## Overview
6 |
7 | The Hypformer implementation includes two types of attention mechanisms, sketched below:
8 | 1. Full attention (softmax-based)
9 | 2. Linear attention (kernel-based)
10 |
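For intuition, here is a minimal Euclidean sketch of the two attention families (illustrative only; the actual hyperbolic versions operate on Lorentz-manifold features and live in `hypformer.py`):

```python
import torch
import torch.nn.functional as F

def softmax_attention(q, k, v):
    # Full attention: O(n^2) pairwise scores, then a row-wise softmax.
    scores = (q @ k.transpose(-2, -1)) / q.shape[-1] ** 0.5
    return scores.softmax(dim=-1) @ v

def linear_attention(q, k, v):
    # Kernel-based attention with phi(x) = elu(x) + 1: summarize phi(k)^T v
    # once, avoiding the n-by-n attention matrix entirely.
    phi_q, phi_k = F.elu(q) + 1, F.elu(k) + 1
    kv = phi_k.transpose(-2, -1) @ v                                  # (d, d_v)
    norm = phi_q @ phi_k.sum(dim=-2, keepdim=True).transpose(-2, -1)  # (n, 1)
    return (phi_q @ kv) / norm.clamp_min(1e-6)

q = k = v = torch.randn(10, 16)
print(softmax_attention(q, k, v).shape, linear_attention(q, k, v).shape)
```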
11 | ## Prerequisites
12 |
13 | Before running the code, ensure you have the required dependencies installed, particularly `geoopt` for hyperbolic operations.
14 |
15 | ## Installation
16 |
17 | 1. Install geoopt:
18 | ```bash
19 | pip install geoopt
20 | ```
21 |
22 | 2. Install other dependencies (if any):
23 | ```bash
24 | pip install -r requirement.txt
25 | ```
26 |
27 | ## Usage
28 |
29 | To run a simple demonstration of Hypformer:
30 |
31 | 1. Execute the main script:
32 | ```bash
33 | python main.py
34 | ```
35 |
36 | 2. Expected output:
37 | ```
38 | Input shape: torch.Size([10, 16])
39 | Output shape: torch.Size([10, 5])
40 | ```
41 |
42 | ## Code Structure
43 |
44 | - `main.py`: Contains a simple example showcasing the usage of Hypformer.
45 | - `hypformer.py`: The core implementation of the Hypformer model.
46 |
47 | ## Customization
48 |
49 | To adapt Hypformer for your specific research needs:
50 |
51 | 1. Open `hypformer.py`
52 | 2. Modify the attention mechanisms or model architecture as required
53 | 3. Adjust the input/output dimensions in `main.py` to match your data
54 |
55 | ## Example
56 |
57 | Here's a basic example of how to use HypFormer in your code, mirroring `main.py`:
58 |
59 | ```python
60 | import torch
61 | from types import SimpleNamespace
62 |
63 | from hypformer import HypFormer
64 |
65 | # Hyperbolic hyperparameters (see main.py for the full Args object)
66 | args = SimpleNamespace(k_in=1.0, k_out=1.0, decoder_type='hyp', device='cpu',
67 |                        add_positional_encoding=True, attention_type='full',
68 |                        power_k=2, trans_heads_concat=False)
69 |
70 | # Initialize HypFormer
71 | model = HypFormer(in_channels=16, hidden_channels=32, out_channels=5,
72 |                   trans_num_layers=2, trans_num_heads=4, trans_dropout=0.1,
73 |                   trans_use_bn=True, trans_use_residual=True,
74 |                   trans_use_weight=True, trans_use_act=True, args=args)
75 |
76 | # Create sample input
77 | x = torch.randn(10, 16)
78 |
79 | # Forward pass
80 | output = model(x)
81 |
82 | print(f"Input shape: {x.shape}")
83 | print(f"Output shape: {output.shape}")
84 | ```
85 |
86 | ## Contributing
87 |
88 | We welcome contributions to improve the Hypformer implementation. Please feel free to submit pull requests or open issues for any bugs or feature requests.
89 |
90 | ## License
91 |
92 | This project is licensed under the MIT License - see the [LICENSE](../LICENSE) file in the parent directory for details.
93 |
94 | ## Contact
95 |
96 | For any questions or concerns about this simplified implementation, please open an issue in the repository or contact the first author.
--------------------------------------------------------------------------------
/Hypformer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marlin-codes/hyperbolicTransformer/798e8fd5e0698cd3c548562dffc00cee53c0f09a/Hypformer/__init__.py
--------------------------------------------------------------------------------
/Hypformer/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from hypformer import HypFormer
6 |
7 | torch.manual_seed(42)
8 |
9 | # Generate pseudo input data of shape (n, d)
10 | num_sample = 10
11 | num_features = 16
12 | num_classes = 5
13 |
14 | # Generate random node features
15 | x = torch.randn(num_sample, num_features)
16 |
17 | # Define model parameters
18 | in_channels = num_features
19 | hidden_channels = 32
20 | out_channels = num_classes
21 |
22 | # Create an args object with necessary attributes
23 | class Args:
24 |     def __init__(self):
25 |         self.k_in = 1.0
26 |         self.k_out = 1.0
27 |         self.decoder_type = 'hyp'
28 |         self.device = 'cpu'
29 |         self.add_positional_encoding = True
30 |         self.attention_type = 'full'
31 |         self.power_k = 2
32 |         self.trans_heads_concat = False
33 |
34 | args = Args()
35 |
36 | # Instantiate the model
37 | model = HypFormer(
38 |     in_channels=in_channels,
39 |     hidden_channels=hidden_channels,
40 |     out_channels=out_channels,
41 |     trans_num_layers=2,
42 |     trans_num_heads=4,
43 |     trans_dropout=0.1,
44 |     trans_use_bn=True,
45 |     trans_use_residual=True,
46 |     trans_use_weight=True,
47 |     trans_use_act=True,
48 |     args=args
49 | )
50 |
51 | # Forward pass
52 | output = model(x)
53 |
54 | print(f"Input shape: {x.shape}")
55 | print(f"Output shape: {output.shape}")
56 |
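# Note: the Args above selects the softmax variant via `attention_type = 'full'`.
# The kernel-based variant mentioned in the README should be reachable through
# the same flag (assuming 'linear' is the spelling hypformer.py checks for):
#
#   args.attention_type = 'linear'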
--------------------------------------------------------------------------------
/Hypformer/manifolds/__init__.py:
--------------------------------------------------------------------------------
1 | from .layer import (HypLayerNorm, HypDropout,
2 |                     HypActivation, HypNormalization,
3 |                     Optimizer, HypLinear)
4 | from .lorentz import Lorentz
--------------------------------------------------------------------------------
/Hypformer/manifolds/hyp_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import torch.nn as nn
4 | import torch.nn.functional
5 | import torch.nn.init as init
6 | from manifolds.lorentz import Lorentz
7 | import math
8 | from geoopt import ManifoldParameter
9 | from geoopt.optim.rsgd import RiemannianSGD
10 | from geoopt.optim.radam import RiemannianAdam
11 |
12 |
13 |
14 | class HypLayerNorm(nn.Module):
15 |     def __init__(self, manifold, in_features, manifold_out=None):
16 |         super(HypLayerNorm, self).__init__()
17 |         self.in_features = in_features
18 |         self.manifold = manifold
19 |         self.manifold_out = manifold_out
20 |         self.layer = nn.LayerNorm(self.in_features)
21 |         self.reset_parameters()
22 |
23 |     def reset_parameters(self):
24 |         self.layer.reset_parameters()
25 |
26 |     def forward(self, x):
27 |         x_space = x[..., 1:]
28 |         x_space = self.layer(x_space)
29 |         x_time = ((x_space**2).sum(dim=-1, keepdims=True) + self.manifold.k).sqrt()
30 |         x = torch.cat([x_time, x_space], dim=-1)
31 |
32 |         if self.manifold_out is not None:
33 |             x = x * (self.manifold_out.k / self.manifold.k).sqrt()
34 |         return x
35 |
36 | class HypNormalization(nn.Module):
37 |     def __init__(self, manifold, manifold_out=None):
38 |         super(HypNormalization, self).__init__()
39 |         self.manifold = manifold
40 |         self.manifold_out = manifold_out
41 |
42 |     def forward(self, x):
43 |         x_space = x[..., 1:]
44 |         x_space = x_space / x_space.norm(dim=-1, keepdim=True)
45 |         x_time = ((x_space**2).sum(dim=-1, keepdims=True) + self.manifold.k).sqrt()
46 |         x = torch.cat([x_time, x_space], dim=-1)
47 |         if self.manifold_out is not None:
48 |             x = x * (self.manifold_out.k / self.manifold.k).sqrt()
49 |         return x
50 |
51 | class HypActivation(nn.Module):
52 |     def __init__(self, manifold, activation, manifold_out=None):
53 |         super(HypActivation, self).__init__()
54 |         self.manifold = manifold
55 |         self.manifold_out = manifold_out
56 |         self.activation = activation
57 |
58 |     def forward(self, x):
59 |         x_space = x[..., 1:]
60 |         x_space = self.activation(x_space)
61 |         x_time = ((x_space**2).sum(dim=-1, keepdims=True) + self.manifold.k).sqrt()
62 |         x = torch.cat([x_time, x_space], dim=-1)
63 |         if self.manifold_out is not None:
64 |             x = x * (self.manifold_out.k / self.manifold.k).sqrt()
65 |         return x
66 |
67 | class HypDropout(nn.Module):
68 |     def __init__(self, manifold, dropout, manifold_out=None):
69 |         super(HypDropout, self).__init__()
70 |         self.manifold = manifold
71 |         self.manifold_out = manifold_out
72 |         self.dropout = nn.Dropout(dropout)
73 |
74 |     def forward(self, x, training=False):
75 |         if training:
76 |             x_space = x[..., 1:]
77 |             x_space = self.dropout(x_space)
78 |             x_time = ((x_space**2).sum(dim=-1, keepdims=True) + self.manifold.k).sqrt()
79 |             x = torch.cat([x_time, x_space], dim=-1)
80 |             if self.manifold_out is not None:
81 |                 x = x * (self.manifold_out.k / self.manifold.k).sqrt()
82 |         return x
83 |
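# All layers above share one recipe: transform only the space-like coordinates,
# then recompute the time coordinate so the output stays on the Lorentz
# manifold, i.e. <x, x>_L = -x_t^2 + ||x_s||^2 = -k. A sanity-check sketch
# (with a stand-in curvature k = 1.0, not tied to any particular manifold):
#
#   import torch
#   k = torch.tensor(1.0)
#   x_space = torch.randn(4, 7)
#   x_time = ((x_space ** 2).sum(dim=-1, keepdim=True) + k).sqrt()
#   x = torch.cat([x_time, x_space], dim=-1)
#   inner = -x[..., 0:1] ** 2 + (x[..., 1:] ** 2).sum(dim=-1, keepdim=True)
#   assert torch.allclose(inner, -k.expand_as(inner), atol=1e-5)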
84 | class HypLinear(nn.Module):
85 |     """
86 |     Parameters:
87 |         manifold (manifold): The manifold to use for the linear transformation.
88 |         in_features (int): The size of each input sample.
89 |         out_features (int): The size of each output sample.
90 |         bias (bool, optional): If set to False, the layer will not learn an additive bias. Default is True.
91 |     """
92 |
93 |     def __init__(self, manifold, in_features, out_features, bias=True, manifold_out=None):
94 |         super().__init__()
95 |         self.in_features = in_features + 1  # + 1 for time dimension
96 |         self.out_features = out_features
97 |         self.bias = bias
98 |         self.manifold = manifold
99 |         self.manifold_out = manifold_out
100 |
101 |         self.linear = nn.Linear(self.in_features, self.out_features, bias=bias)
102 |         self.reset_parameters()
103 |
104 |     def reset_parameters(self):
105 |         init.xavier_uniform_(self.linear.weight, gain=math.sqrt(2))
106 |         if self.bias:
107 |             init.constant_(self.linear.bias, 0)
108 |
109 |     def forward(self, x, x_manifold='hyp'):
110 |         if x_manifold != 'hyp':
111 |             x = torch.cat([torch.ones_like(x)[..., 0:1], x], dim=-1)
112 |             x = self.manifold.expmap0(x)
113 |         x_space = self.linear(x)
114 |
115 |         x_time = ((x_space**2).sum(dim=-1, keepdims=True) + self.manifold.k).sqrt()
116 |         x = torch.cat([x_time, x_space], dim=-1)
117 |         if self.manifold_out is not None:
118 |             x = x * (self.manifold_out.k / self.manifold.k).sqrt()
119 |         return x
120 |
121 | class HypCLS(nn.Module):
122 |     def __init__(self, manifold, in_channels, out_channels, bias=True):
123 |         super().__init__()
124 |         self.manifold = manifold
125 |         self.in_channels = in_channels
126 |         self.out_channels = out_channels
127 |         cls_emb = self.manifold.random_normal((self.out_channels, self.in_channels + 1), mean=0, std=1. / math.sqrt(self.in_channels + 1))
128 |         self.cls = ManifoldParameter(cls_emb, self.manifold, requires_grad=True)
129 |         # scalar zero keeps forward() valid when bias=False
130 |         self.bias = nn.Parameter(torch.zeros(self.out_channels)) if bias else 0
131 |
132 |     def cinner(self, x, y):
133 |         x = x.clone()
134 |         x.narrow(-1, 0, 1).mul_(-1)
135 |         return x @ y.transpose(-1, -2)
136 |
137 |     def forward(self, x, x_manifold='hyp', return_type='neg_dist'):
138 |         if x_manifold != 'hyp':
139 |             x = self.manifold.expmap0(torch.cat([torch.zeros_like(x)[..., 0:1], x], dim=-1))  # project to Lorentz
140 |
141 |         dist = -2 * self.manifold.k - 2 * self.cinner(x, self.cls) + self.bias
142 |         dist = dist.clamp(min=0)
143 |
144 |         if return_type == 'neg_dist':
145 |             return - dist
146 |         elif return_type == 'prob':
147 |             return 10 / (1.0 + dist)
148 |         elif return_type == 'neg_log_prob':
149 |             return - 10*torch.log(1.0 + dist)
150 |         else:
151 |             raise NotImplementedError
152 |
153 |
154 | class Optimizer(object):
155 |     def __init__(self, model, args):
156 |         # Extract optimizer types and parameters from arguments
157 |         euc_optimizer_type = getattr(args, 'euc_optimizer_type', args.optimizer_type)  # Euclidean optimizer type
158 |         hyp_optimizer_type = args.hyp_optimizer_type  # Hyperbolic optimizer type
159 |         euc_lr = getattr(args, 'euc_lr', args.lr)  # Euclidean learning rate
160 |         hyp_lr = args.hyp_lr  # Hyperbolic learning rate
161 |         euc_weight_decay = getattr(args, 'euc_weight_decay', args.weight_decay)  # Euclidean weight decay
162 |         hyp_weight_decay = args.hyp_weight_decay  # Hyperbolic weight decay
163 |
164 |         # Separate parameters for Euclidean and Hyperbolic parts of the model
165 |         euc_params = [p for n, p in model.named_parameters() if p.requires_grad and not isinstance(p, ManifoldParameter)]  # Euclidean parameters
166 |         hyp_params = [p for n, p in model.named_parameters() if p.requires_grad and isinstance(p, ManifoldParameter)]  # 
Hyperbolic parameters 167 | 168 | # Print the number of Euclidean and Hyperbolic parameters 169 | # print(f">> Number of Euclidean parameters: {sum(p.numel() for p in euc_params)}") 170 | # print(f">> Number of Hyperbolic parameters: {sum(p.numel() for p in hyp_params)}") 171 | # Initialize Euclidean optimizer 172 | 173 | if euc_optimizer_type == 'adam': 174 | optimizer_euc = torch.optim.Adam(euc_params, lr=euc_lr, weight_decay=euc_weight_decay) 175 | elif euc_optimizer_type == 'sgd': 176 | optimizer_euc = torch.optim.SGD(euc_params, lr=euc_lr, weight_decay=euc_weight_decay) 177 | else: 178 | raise NotImplementedError("Unsupported Euclidean optimizer type") 179 | 180 | # Initialize Hyperbolic optimizer if there are Hyperbolic parameters 181 | if hyp_params: 182 | if hyp_optimizer_type == 'radam': 183 | optimizer_hyp = RiemannianAdam(hyp_params, lr=hyp_lr, stabilize=50, weight_decay=hyp_weight_decay) 184 | elif hyp_optimizer_type == 'rsgd': 185 | optimizer_hyp = RiemannianSGD(hyp_params, lr=hyp_lr, stabilize=50, weight_decay=hyp_weight_decay) 186 | else: 187 | raise NotImplementedError("Unsupported Hyperbolic optimizer type") 188 | 189 | # Store both optimizers 190 | self.optimizer = [optimizer_euc, optimizer_hyp] 191 | else: 192 | # Store only Euclidean optimizer if there are no Hyperbolic parameters 193 | self.optimizer = [optimizer_euc] 194 | 195 | def step(self): 196 | # Perform optimization step for each optimizer 197 | for optimizer in self.optimizer: 198 | optimizer.step() 199 | 200 | def zero_grad(self): 201 | # Reset gradients to zero for each optimizer 202 | for optimizer in self.optimizer: 203 | optimizer.zero_grad() 204 | 205 | -------------------------------------------------------------------------------- /Hypformer/manifolds/layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | from geoopt import ManifoldParameter 5 | from geoopt.optim.rsgd import RiemannianSGD 6 | from geoopt.optim.radam import RiemannianAdam 7 | import math 8 | 9 | 10 | class HypLayerNorm(nn.Module): 11 | """ 12 | Hyperbolic Layer Normalization Layer 13 | 14 | Parameters: 15 | manifold (Manifold): The manifold to use for normalization. 16 | in_features (int): The number of input features. 17 | manifold_out (Manifold, optional): The output manifold. Default is None. 18 | """ 19 | 20 | def __init__(self, manifold, in_features, manifold_out=None): 21 | super(HypLayerNorm, self).__init__() 22 | self.in_features = in_features 23 | self.manifold = manifold 24 | self.manifold_out = manifold_out 25 | self.layer = nn.LayerNorm(self.in_features) 26 | self.reset_parameters() 27 | 28 | def reset_parameters(self): 29 | """Reset layer parameters.""" 30 | self.layer.reset_parameters() 31 | 32 | def forward(self, x): 33 | """Forward pass for hyperbolic layer normalization.""" 34 | x_space = x[..., 1:] 35 | x_space = self.layer(x_space) 36 | x_time = ((x_space ** 2).sum(dim=-1, keepdims=True) + self.manifold.k).sqrt() 37 | x = torch.cat([x_time, x_space], dim=-1) 38 | 39 | if self.manifold_out is not None: 40 | x = x * (self.manifold_out.k / self.manifold.k).sqrt() 41 | return x 42 | 43 | 44 | class HypNormalization(nn.Module): 45 | """ 46 | Hyperbolic Normalization Layer 47 | 48 | Parameters: 49 | manifold (Manifold): The manifold to use for normalization. 50 | manifold_out (Manifold, optional): The output manifold. Default is None. 
51 | """ 52 | 53 | def __init__(self, manifold, manifold_out=None): 54 | super(HypNormalization, self).__init__() 55 | self.manifold = manifold 56 | self.manifold_out = manifold_out 57 | 58 | def forward(self, x): 59 | """Forward pass for hyperbolic normalization.""" 60 | x_space = x[..., 1:] 61 | x_space = x_space / x_space.norm(dim=-1, keepdim=True) 62 | x_time = ((x_space ** 2).sum(dim=-1, keepdims=True) + self.manifold.k).sqrt() 63 | x = torch.cat([x_time, x_space], dim=-1) 64 | if self.manifold_out is not None: 65 | x = x * (self.manifold_out.k / self.manifold.k).sqrt() 66 | return x 67 | 68 | 69 | class HypActivation(nn.Module): 70 | """ 71 | Hyperbolic Activation Layer 72 | 73 | Parameters: 74 | manifold (Manifold): The manifold to use for the activation. 75 | activation (function): The activation function. 76 | manifold_out (Manifold, optional): The output manifold. Default is None. 77 | """ 78 | 79 | def __init__(self, manifold, activation, manifold_out=None): 80 | super(HypActivation, self).__init__() 81 | self.manifold = manifold 82 | self.manifold_out = manifold_out 83 | self.activation = activation 84 | 85 | def forward(self, x): 86 | """Forward pass for hyperbolic activation.""" 87 | x_space = x[..., 1:] 88 | x_space = self.activation(x_space) 89 | x_time = ((x_space ** 2).sum(dim=-1, keepdims=True) + self.manifold.k).sqrt() 90 | x = torch.cat([x_time, x_space], dim=-1) 91 | if self.manifold_out is not None: 92 | x = x * (self.manifold_out.k / self.manifold.k).sqrt() 93 | return x 94 | 95 | 96 | class HypDropout(nn.Module): 97 | """ 98 | Hyperbolic Dropout Layer 99 | 100 | Parameters: 101 | manifold (Manifold): The manifold to use for the dropout. 102 | dropout (float): The dropout probability. 103 | manifold_out (Manifold, optional): The output manifold. Default is None. 104 | """ 105 | 106 | def __init__(self, manifold, dropout, manifold_out=None): 107 | super(HypDropout, self).__init__() 108 | self.manifold = manifold 109 | self.manifold_out = manifold_out 110 | self.dropout = nn.Dropout(dropout) 111 | 112 | def forward(self, x, training=False): 113 | """Forward pass for hyperbolic dropout.""" 114 | if training: 115 | x_space = x[..., 1:] 116 | x_space = self.dropout(x_space) 117 | x_time = ((x_space ** 2).sum(dim=-1, keepdims=True) + self.manifold.k).sqrt() 118 | x = torch.cat([x_time, x_space], dim=-1) 119 | if self.manifold_out is not None: 120 | x = x * (self.manifold_out.k / self.manifold.k).sqrt() 121 | return x 122 | 123 | 124 | class HypLinear(nn.Module): 125 | """ 126 | Hyperbolic Linear Layer 127 | 128 | Parameters: 129 | manifold (Manifold): The manifold to use for the linear transformation. 130 | in_features (int): The size of each input sample. 131 | out_features (int): The size of each output sample. 132 | bias (bool, optional): If set to False, the layer will not learn an additive bias. Default is True. 133 | dropout (float, optional): The dropout probability. Default is 0.0. 134 | manifold_out (Manifold, optional): The output manifold. Default is None. 
135 |     """
136 |
137 |     def __init__(self, manifold, in_features, out_features, bias=True, dropout=0.0, manifold_out=None):
138 |         super().__init__()
139 |         self.in_features = in_features + 1  # +1 for time dimension
140 |         self.out_features = out_features
141 |         self.bias = bias
142 |         self.manifold = manifold
143 |         self.manifold_out = manifold_out
144 |
145 |         self.linear = nn.Linear(self.in_features, self.out_features, bias=bias)
146 |         self.dropout_rate = dropout
147 |         self.reset_parameters()
148 |
149 |     def reset_parameters(self):
150 |         """Reset layer parameters."""
151 |         init.xavier_uniform_(self.linear.weight, gain=math.sqrt(2))
152 |         if self.bias:
153 |             init.constant_(self.linear.bias, 0)
154 |
155 |     def forward(self, x, x_manifold='hyp'):
156 |         """Forward pass for hyperbolic linear layer."""
157 |         if x_manifold != 'hyp':
158 |             x = torch.cat([torch.ones_like(x)[..., 0:1], x], dim=-1)
159 |             x = self.manifold.expmap0(x)
160 |         x_space = self.linear(x)
161 |
162 |         x_time = ((x_space ** 2).sum(dim=-1, keepdims=True) + self.manifold.k).sqrt()
163 |         x = torch.cat([x_time, x_space], dim=-1)
164 |         if self.manifold_out is not None:
165 |             x = x * (self.manifold_out.k / self.manifold.k).sqrt()
166 |         return x
167 |
168 | class HypCLS(nn.Module):
169 |     def __init__(self, manifold, in_channels, out_channels, bias=True):
170 |         """
171 |         Initializes the HypCLS class with the given parameters.
172 |
173 |         Parameters:
174 |         - `manifold` (Manifold): The manifold object.
175 |         - `in_channels` (int): The number of input channels.
176 |         - `out_channels` (int): The number of output channels.
177 |         - `bias` (bool, optional): Whether to include a bias term. Defaults to True.
178 |         """
179 |         super().__init__()
180 |         self.manifold = manifold
181 |         self.in_channels = in_channels
182 |         self.out_channels = out_channels
183 |         cls_emb = self.manifold.random_normal((self.out_channels, self.in_channels + 1), mean=0, std=1. / math.sqrt(self.in_channels + 1))
184 |         self.cls = ManifoldParameter(cls_emb, self.manifold, requires_grad=True)
185 |         # scalar zero keeps forward() valid when bias=False
186 |         self.bias = nn.Parameter(torch.zeros(self.out_channels)) if bias else 0
187 |
188 |     def cinner(self, x, y):
189 |         x = x.clone()
190 |         x.narrow(-1, 0, 1).mul_(-1)
191 |         return x @ y.transpose(-1, -2)
192 |
193 |     def forward(self, x, x_manifold='hyp', return_type='neg_dist'):
194 |         if x_manifold != 'hyp':
195 |             x = self.manifold.expmap0(torch.cat([torch.zeros_like(x)[..., 0:1], x], dim=-1))  # project to Lorentz
196 |
197 |         dist = -2 * self.manifold.k - 2 * self.cinner(x, self.cls) + self.bias
198 |         dist = dist.clamp(min=0)
199 |
200 |         if return_type == 'neg_dist':
201 |             return - dist
202 |         elif return_type == 'prob':
203 |             return 1.0 / (1.0 + dist)
204 |         elif return_type == 'neg_log_prob':
205 |             return - 1.0*torch.log(1.0 + dist)
206 |         else:
207 |             raise NotImplementedError
208 |
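# HypCLS above scores each class by the squared Lorentz distance to a learnable
# prototype, d_L^2(x, c) = -2k - 2<x, c>_L, where `cinner` computes <., .>_L by
# flipping the sign of the time coordinate. Sanity-check sketch with a stand-in
# k = 1.0 (the distance from an on-manifold point to itself is 0):
#
#   import torch
#   k = torch.tensor(1.0)
#   x_space = torch.randn(3, 7)
#   x_time = ((x_space ** 2).sum(dim=-1, keepdim=True) + k).sqrt()
#   x = torch.cat([x_time, x_space], dim=-1)
#   x_flip = x.clone()
#   x_flip[..., 0] = -x_flip[..., 0]             # cinner's time-sign flip
#   d2 = -2 * k - 2 * (x_flip @ x.T).diagonal()  # d_L^2(x_i, x_i)
#   assert torch.allclose(d2, torch.zeros_like(d2), atol=1e-4)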
209 | class Optimizer(object):
210 |     """
211 |     Optimizer for Euclidean and Hyperbolic parameters
212 |
213 |     Parameters:
214 |         model (nn.Module): The model containing the parameters to optimize.
215 |         args (Namespace): The arguments containing optimizer settings.
216 |     """
217 |
218 |     def __init__(self, model, args):
219 |         euc_optimizer_type = args.optimizer_type
220 |         hyp_optimizer_type = args.hyp_optimizer_type
221 |         euc_lr = args.lr
222 |         hyp_lr = args.hyp_lr
223 |         euc_weight_decay = args.weight_decay
224 |         hyp_weight_decay = args.hyp_weight_decay
225 |
226 |         euc_params = [p for n, p in model.named_parameters() if
227 |                       p.requires_grad and not isinstance(p, ManifoldParameter)]
228 |         hyp_params = [p for n, p in model.named_parameters() if p.requires_grad and isinstance(p, ManifoldParameter)]
229 |
230 |         print(f">> Number of Euclidean parameters: {sum(p.numel() for p in euc_params)}")
231 |         print(f">> Number of Hyperbolic parameters: {sum(p.numel() for p in hyp_params)}")
232 |         self.optimizer = []  # Optimizers for Euclidean and Hyperbolic parts of the model
233 |
234 |         if euc_params:
235 |             if euc_optimizer_type == 'adam':
236 |                 optimizer_euc = torch.optim.Adam(euc_params, lr=euc_lr, weight_decay=euc_weight_decay)
237 |             elif euc_optimizer_type == 'sgd':
238 |                 optimizer_euc = torch.optim.SGD(euc_params, lr=euc_lr, weight_decay=euc_weight_decay)
239 |             else:
240 |                 raise NotImplementedError(f"Unknown Euclidean optimizer type: {euc_optimizer_type}")
241 |             self.optimizer.append(optimizer_euc)
242 |
243 |         if hyp_params:
244 |             if hyp_optimizer_type == 'radam':
245 |                 optimizer_hyp = RiemannianAdam(hyp_params, lr=hyp_lr, stabilize=10, weight_decay=hyp_weight_decay)
246 |             elif hyp_optimizer_type == 'rsgd':
247 |                 optimizer_hyp = RiemannianSGD(hyp_params, lr=hyp_lr, stabilize=10, weight_decay=hyp_weight_decay)
248 |             else:
249 |                 raise NotImplementedError(f"Unknown Hyperbolic optimizer type: {hyp_optimizer_type}")
250 |             self.optimizer.append(optimizer_hyp)
251 |
252 |     def step(self):
253 |         """Performs a single optimization step."""
254 |         for optimizer in self.optimizer:
255 |             optimizer.step()
256 |
257 |     def zero_grad(self):
258 |         """Sets the gradients of all optimized tensors to zero."""
259 |         for optimizer in self.optimizer:
260 |             optimizer.zero_grad()
261 |
--------------------------------------------------------------------------------
/Hypformer/manifolds/utils.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | from typing import Tuple, Any, Union, List
3 | import functools
4 | import operator
5 | import torch
6 | import geoopt
7 |
8 |
9 | max_norm = 85
10 | eps = 1e-8
11 |
12 | __all__ = [
13 |     "copy_or_set_",
14 |     "strip_tuple",
15 |     "size2shape",
16 |     "make_tuple",
17 |     "broadcast_shapes",
18 |     "ismanifold",
19 |     "canonical_manifold",
20 |     "list_range",
21 |     "idx2sign",
22 |     "drop_dims",
23 |     "canonical_dims",
24 |     "sign",
25 |     "prod",
26 |     "clamp_abs",
27 |     "sabs",
28 | ]
29 |
30 |
31 | def copy_or_set_(dest: torch.Tensor, source: torch.Tensor) -> torch.Tensor:
32 |     """
33 |     Copy or inplace set from :code:`source` to :code:`dest`.
34 |
35 |     A workaround to respect strides of :code:`dest` when copying :code:`source`.
36 |     The original issue was raised `here `_
37 |     when working with matrix manifolds. Inplace set operation is more efficient,
38 |     but the resulting storage might be incompatible afterwards. To avoid the issue we refer to
39 |     the safe option and use :code:`copy_` if strides do not match.
40 | 41 | Parameters 42 | ---------- 43 | dest : torch.Tensor 44 | Destination tensor where to store new data 45 | source : torch.Tensor 46 | Source data to put in the new tensor 47 | 48 | Returns 49 | ------- 50 | dest 51 | torch.Tensor, modified inplace 52 | """ 53 | if dest.stride() != source.stride(): 54 | return dest.copy_(source) 55 | else: 56 | return dest.set_(source) 57 | 58 | 59 | def strip_tuple(tup: Tuple) -> Union[Tuple, Any]: 60 | if len(tup) == 1: 61 | return tup[0] 62 | else: 63 | return tup 64 | 65 | 66 | def make_tuple(obj: Union[Tuple, List, Any]) -> Tuple: 67 | if isinstance(obj, list): 68 | obj = tuple(obj) 69 | if not isinstance(obj, tuple): 70 | return (obj,) 71 | else: 72 | return obj 73 | 74 | 75 | def prod(items): 76 | return functools.reduce(operator.mul, items, 1) 77 | 78 | 79 | def sign(x): 80 | return torch.sign(x.sign() + 0.5) 81 | 82 | 83 | def sabs(x, eps: float = 1e-15): 84 | return x.abs().add_(eps) 85 | 86 | 87 | def clamp_abs(x, eps: float = 1e-15): 88 | s = sign(x) 89 | return s * sabs(x, eps=eps) 90 | 91 | 92 | def idx2sign(idx: int, dim: int, neg: bool = True): 93 | """ 94 | Unify idx to be negative or positive, that helps in cases of broadcasting. 95 | 96 | Parameters 97 | ---------- 98 | idx : int 99 | current index 100 | dim : int 101 | maximum dimension 102 | neg : bool 103 | indicate we need negative index 104 | 105 | Returns 106 | ------- 107 | int 108 | """ 109 | if neg: 110 | if idx < 0: 111 | return idx 112 | else: 113 | return (idx + 1) % -(dim + 1) 114 | else: 115 | return idx % dim 116 | 117 | 118 | def drop_dims(tensor: torch.Tensor, dims: List[int]): 119 | # Workaround to drop several dims in :func:`torch.squeeze`. 120 | seen: int = 0 121 | for d in dims: 122 | tensor = tensor.squeeze(d - seen) 123 | seen += 1 124 | return tensor 125 | 126 | 127 | def list_range(end: int): 128 | res: List[int] = [] 129 | for d in range(end): 130 | res.append(d) 131 | return res 132 | 133 | 134 | def canonical_dims(dims: List[int], maxdim: int): 135 | result: List[int] = [] 136 | for idx in dims: 137 | result.append(idx2sign(idx, maxdim, neg=False)) 138 | return result 139 | 140 | 141 | def size2shape(*size: Union[Tuple[int], int]) -> Tuple[int]: 142 | return make_tuple(strip_tuple(size)) 143 | 144 | 145 | def broadcast_shapes(*shapes: Tuple[int]) -> Tuple[int]: 146 | """Apply numpy broadcasting rules to shapes.""" 147 | result = [] 148 | for dims in itertools.zip_longest(*map(reversed, shapes), fillvalue=1): 149 | dim: int = 1 150 | for d in dims: 151 | if dim != 1 and d != 1 and d != dim: 152 | raise ValueError("Shapes can't be broadcasted") 153 | elif d > dim: 154 | dim = d 155 | result.append(dim) 156 | return tuple(reversed(result)) 157 | 158 | 159 | def ismanifold(instance, cls): 160 | """ 161 | Check if interface of an instance is compatible with given class. 
162 |
163 |     Parameters
164 |     ----------
165 |     instance : geoopt.Manifold
166 |         check if a given manifold is compatible with cls API
167 |     cls : type
168 |         manifold type
169 |
170 |     Returns
171 |     -------
172 |     bool
173 |         comparison result
174 |     """
175 |     if not issubclass(cls, geoopt.manifolds.Manifold):
176 |         raise TypeError(
177 |             "`cls` should be a subclass of geoopt.manifolds.Manifold")
178 |     if not isinstance(instance, geoopt.manifolds.Manifold):
179 |         return False
180 |     else:
181 |         # this is the case to care about, Scaled class is a proxy, but fails instance checks
182 |         while isinstance(instance, geoopt.Scaled):
183 |             instance = instance.base
184 |         return isinstance(instance, cls)
185 |
186 |
187 | def canonical_manifold(manifold: "geoopt.Manifold"):
188 |     """
189 |     Get a canonical manifold.
190 |
191 |     If a manifold is wrapped with Scaled, some attributes may not be available. This should help if you really need them.
192 |
193 |     Parameters
194 |     ----------
195 |     manifold : geoopt.Manifold
196 |
197 |     Returns
198 |     -------
199 |     geoopt.Manifold
200 |         an unwrapped manifold
201 |     """
202 |     while isinstance(manifold, geoopt.Scaled):
203 |         manifold = manifold.base
204 |     return manifold
205 |
206 |
207 | def cosh(x: torch.Tensor) -> torch.Tensor:
208 |     x = clamp(x, min=-max_norm, max=max_norm)
209 |     return torch.cosh(x)
210 |
211 |
212 | def sinh(x: torch.Tensor) -> torch.Tensor:
213 |     x = clamp(x, min=-max_norm, max=max_norm)
214 |     return torch.sinh(x)
215 |
216 |
217 | def sqrt(x: torch.Tensor) -> torch.Tensor:
218 |     x = clamp(x, min=1e-9)  # Smaller epsilon due to precision around x=0.
219 |     return torch.sqrt(x)
220 |
221 |
222 | class LeakyClamp(torch.autograd.Function):
223 |
224 |     @staticmethod
225 |     def forward(ctx: Any, x: torch.Tensor, min: float, max: float) -> torch.Tensor:
226 |         with torch.no_grad():
227 |             ctx.save_for_backward(x.ge(min) & x.le(max))
228 |             return torch.clamp(x, min=min, max=max)
229 |
230 |     @staticmethod
231 |     def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]:
232 |         mask, = ctx.saved_tensors
233 |         mask = mask.type_as(grad_output)
234 |         return grad_output * mask + grad_output * (1 - mask) * eps, None, None
235 |
236 |
237 | def clamp(x: torch.Tensor, min: float = float("-inf"), max: float = float("+inf")) -> torch.Tensor:
238 |     return LeakyClamp.apply(x, min, max)
239 |
240 |
241 | class Atanh(torch.autograd.Function):
242 |     """
243 |     Numerically stable arctanh that never returns NaNs.
244 |     x = clamp(x, min=-1+eps, max=1-eps)
245 |     Returns atanh(x) = arctanh(x) = 0.5*(log(1+x)-log(1-x)).
246 |     """
247 |
248 |     @staticmethod
249 |     def forward(ctx: Any, x: torch.Tensor) -> torch.Tensor:
250 |         x = clamp(x, min=-1. + 4 * eps, max=1. - 4 * eps)
251 |         ctx.save_for_backward(x)
252 |         res = (torch.log_(1 + x).sub_(torch.log_(1 - x))).mul_(0.5)
253 |         return res
254 |
255 |     @staticmethod
256 |     def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor:
257 |         x, = ctx.saved_tensors
258 |         return grad_output / (1 - x**2)
259 |
260 |
261 | def atanh(x: torch.Tensor) -> torch.Tensor:
262 |     """
263 |     Numerically stable arctanh that never returns NaNs.
264 |
265 |     :param x: The input tensor.
266 |     :return: 0.5 * (log(1 + x) - log(1 - x))
267 |     """
268 |     return Atanh.apply(x)
269 |
270 |
271 | class Acosh(torch.autograd.Function):
272 |     """
273 |     Numerically stable arccosh that never returns NaNs.
274 |     Returns acosh(x) = arccosh(x) = log(x + sqrt(max(x^2 - 1, eps))).
275 | """ 276 | 277 | @staticmethod 278 | def forward(ctx: Any, x: torch.Tensor) -> torch.Tensor: 279 | with torch.no_grad(): 280 | x = clamp(x, min=1 + eps) 281 | z = sqrt(x * x - 1.) 282 | ctx.save_for_backward(z) 283 | return torch.log(x + z) 284 | 285 | @staticmethod 286 | def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor: 287 | z, = ctx.saved_tensors 288 | # z_ = clamp(z, min=eps) 289 | z_ = z 290 | return grad_output / z_ 291 | 292 | 293 | def acosh(x: torch.Tensor) -> torch.Tensor: 294 | """ 295 | Numerically stable arccosh that never returns NaNs. 296 | 297 | :param x: The input tensor. 298 | :return: log(x + sqrt(max(x^2 - 1, eps)) 299 | """ 300 | return Acosh.apply(x) -------------------------------------------------------------------------------- /Hypformer/requirement.txt: -------------------------------------------------------------------------------- 1 | geoopt 2 | torch 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [2024] [Graph and Geometric Learning Lab / Menglin Yang, Yale University] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Hypformer: Exploring Efficient Hyperbolic Transformer Fully in Hyperbolic Space 2 | 3 | [![arXiv](https://img.shields.io/badge/arXiv-2407.01290-b31b1b.svg)](https://arxiv.org/abs/2407.01290) 4 | [![Conference](https://img.shields.io/badge/KDD-2024-blue)](https://kdd.org/kdd2024/) 5 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 6 | 7 | This is the PyTorch implementation of the paper ["Hypformer: Exploring Efficient Hyperbolic Transformer Fully in Hyperbolic Space"](https://arxiv.org/abs/2407.01290) to be presented at KDD 2024. 
8 |
9 |
10 | Menglin Yang, Harshit Verma, Delvin Ce Zhang, Jiahong Liu, Irwin King, Rex Ying
11 |
12 | Arxiv: https://arxiv.org/abs/2407.01290
13 |
14 | Code: https://github.com/Graph-and-Geometric-Learning/hyperbolic-transformer
15 |
16 | ## Framework
17 |
18 | ![framework](./figures/framework.jpg)
19 |
20 | ## Updates (August 20, 2024 🔥)
21 |
22 | - [x] Large-scale graph evaluation
23 | - [x] Medium-scale graph evaluation
24 | - [ ] Image and text data evaluation (To be updated)
25 | - [x] Simplified version for reuse
26 | - [ ] Baselines (To be updated)
27 |
28 | ## 1. Requirements
29 |
30 | To install the required packages, run:
31 |
32 | ```bash
33 | pip install -r requirements.txt
34 | ```
35 |
36 | ## 2. Dataset
37 |
38 | Please check the `./data` folder for available datasets.
39 |
40 | Note: OGB datasets will be downloaded automatically when needed.
41 |
42 | ## 3. Run Hyperbolic Transformer
43 |
44 | The code has been evaluated on NVIDIA A100 GPUs.
45 |
46 | To run the code:
47 |
48 | 1. Navigate to the `large` directory: `cd large`
49 |
50 | 2. Check the `examples` folder, where the `5-runs` folder contains scripts to get averaged results.
51 |
52 | 3. For a single run, execute one of the following commands:
53 | ```bash
54 | bash examples/amazon2M.sh
55 | bash examples/arxiv.sh
56 | bash examples/protein.sh
57 | ```
58 | 4. Navigate to the `medium` directory: `cd medium`. For a single run, execute one of the following commands:
59 |
60 | ```bash
61 | bash examples/cora.sh
62 | bash examples/citeseer.sh
63 | bash examples/pubmed.sh
64 | bash examples/airport.sh
65 | ```
66 |
67 | ## 4. Reuse Hyperbolic Transformer Modules
68 |
69 | To reuse the Hyperbolic Transformer modules, please check the folder `./Hypformer`; a minimal reuse sketch follows the module list below.
70 |
71 | For example:
72 | `Hyperbolic LayerNorm` in [hyp_layer.py](./Hypformer/manifolds/hyp_layer.py)
73 |
74 | ```python
75 | import torch
76 | import torch.nn as nn
77 | class HypLayerNorm(nn.Module):
78 |     def __init__(self, manifold, in_features, manifold_out=None):
79 |         super(HypLayerNorm, self).__init__()
80 |         self.in_features = in_features
81 |         self.manifold = manifold
82 |         self.manifold_out = manifold_out
83 |         self.layer = nn.LayerNorm(self.in_features)
84 |         self.reset_parameters()
85 |
86 |     def reset_parameters(self):
87 |         self.layer.reset_parameters()
88 |
89 |     def forward(self, x):
90 |         x_space = x[..., 1:]
91 |         x_space = self.layer(x_space)
92 |         x_time = ((x_space**2).sum(dim=-1, keepdims=True) + self.manifold.k).sqrt()
93 |         x = torch.cat([x_time, x_space], dim=-1)
94 |
95 |         if self.manifold_out is not None:
96 |             x = x * (self.manifold_out.k / self.manifold.k).sqrt()
97 |         return x
98 | ```
99 |
100 | - `Hyperbolic Linear Transformation` in [hyp_layer.py](./Hypformer/manifolds/hyp_layer.py)
101 | - `Hyperbolic Dropout Operations` in [hyp_layer.py](./Hypformer/manifolds/hyp_layer.py)
102 | - `Hyperbolic Activation Operations` in [hyp_layer.py](./Hypformer/manifolds/hyp_layer.py)
103 | - `Hyperbolic Classification Layer` in [hyp_layer.py](./Hypformer/manifolds/hyp_layer.py)
104 | - `Hyperbolic full/linear Attention` in [hypformer.py](./Hypformer/hypformer.py)
105 |
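As referenced above, here is a minimal end-to-end reuse sketch. Run it from the repository root; the `Lorentz(k=1.0)` constructor is an assumption inferred from the `manifold.k` usage in these layers — check [lorentz.py](./Hypformer/manifolds/lorentz.py) for the exact signature.

```python
import torch
import torch.nn.functional as F

from Hypformer.manifolds.layer import HypLinear, HypActivation
from Hypformer.manifolds.lorentz import Lorentz

manifold = Lorentz(k=1.0)  # assumed constructor; see lorentz.py
lin1 = HypLinear(manifold, in_features=16, out_features=32)
act = HypActivation(manifold, activation=F.relu)
lin2 = HypLinear(manifold, in_features=32, out_features=8)

x = torch.randn(10, 16)        # Euclidean inputs
h = lin1(x, x_manifold='euc')  # lift onto the Lorentz manifold: (10, 33)
h = act(h)
out = lin2(h)                  # stays hyperbolic: (10, 9) = time + 8 space dims
print(out.shape)
```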
106 | ## 5. Acknowledgments
107 |
108 | This project was heavily built upon the following projects. We thank the authors for their awesome contributions:
109 |
110 | - SGFormer - https://github.com/qitianwu/SGFormer
111 | - HGCN - https://github.com/HazyResearch/hgcn/tree/master
112 | - fully HNN - https://github.com/chenweize1998/fully-hyperbolic-nn
113 | - Open Graph Benchmark - https://ogb.stanford.edu/
114 | - Geoopt - https://github.com/geoopt/geoopt
115 | - GraphGPS - https://github.com/rampasek/GraphGPS
116 | - GraphFormers - https://github.com/microsoft/GraphFormers
117 | - GraphTrans - https://github.com/ucbrise/graphtrans
118 | - Nodeformer - https://github.com/qitianwu/NodeFormer
119 |
120 | ## Citation
121 |
122 | If you find this work useful in your research, please consider citing our paper:
123 |
124 | ```bibtex
125 | @inproceedings{yang2024hypformer,
126 |     title={Hypformer: Exploring Efficient Hyperbolic Transformer Fully in Hyperbolic Space},
127 |     author={Yang, Menglin and Verma, Harshit and Zhang, Delvin Ce and Liu, Jiahong and King, Irwin and Ying, Rex},
128 |     booktitle={Proceedings of the 2024 ACM SIGKDD International Conference on Knowledge Discovery and Data Mining},
129 |     year={2024}
130 | }
131 | ```
132 |
133 | ## License
134 |
135 | This project is licensed under the MIT License - see the [LICENSE](./LICENSE) file for details.
136 |
137 | ## Contact
138 |
139 | For any questions or concerns, please open an issue in this repository or contact menglin.yang@{yale.edu,outlook.com}
--------------------------------------------------------------------------------
/data/20news/20news.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marlin-codes/hyperbolicTransformer/798e8fd5e0698cd3c548562dffc00cee53c0f09a/data/20news/20news.pkl
--------------------------------------------------------------------------------
/data/OGB/data:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marlin-codes/hyperbolicTransformer/798e8fd5e0698cd3c548562dffc00cee53c0f09a/data/OGB/data
--------------------------------------------------------------------------------
/data/Planetoid/citeseer/raw/ind.citeseer.allx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marlin-codes/hyperbolicTransformer/798e8fd5e0698cd3c548562dffc00cee53c0f09a/data/Planetoid/citeseer/raw/ind.citeseer.allx
--------------------------------------------------------------------------------
/data/Planetoid/citeseer/raw/ind.citeseer.ally:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marlin-codes/hyperbolicTransformer/798e8fd5e0698cd3c548562dffc00cee53c0f09a/data/Planetoid/citeseer/raw/ind.citeseer.ally
--------------------------------------------------------------------------------
/data/Planetoid/citeseer/raw/ind.citeseer.graph:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marlin-codes/hyperbolicTransformer/798e8fd5e0698cd3c548562dffc00cee53c0f09a/data/Planetoid/citeseer/raw/ind.citeseer.graph
--------------------------------------------------------------------------------
/data/Planetoid/citeseer/raw/ind.citeseer.test.index:
--------------------------------------------------------------------------------
1 | 2488
2 | 2644
3 | 3261
4 | 2804
5 | 3176
6 | 2432
7 | 3310
8 | 2410
9 | 2812
10 | 2520
11 | 2994
12 | 3282
13 | 2680
14 | 2848
15 | 2670
16 | 3005
17 | 2977
18 | 2592
19 | 2967
20 | 
2461 21 | 3184 22 | 2852 23 | 2768 24 | 2905 25 | 2851 26 | 3129 27 | 3164 28 | 2438 29 | 2793 30 | 2763 31 | 2528 32 | 2954 33 | 2347 34 | 2640 35 | 3265 36 | 2874 37 | 2446 38 | 2856 39 | 3149 40 | 2374 41 | 3097 42 | 3301 43 | 2664 44 | 2418 45 | 2655 46 | 2464 47 | 2596 48 | 3262 49 | 3278 50 | 2320 51 | 2612 52 | 2614 53 | 2550 54 | 2626 55 | 2772 56 | 3007 57 | 2733 58 | 2516 59 | 2476 60 | 2798 61 | 2561 62 | 2839 63 | 2685 64 | 2391 65 | 2705 66 | 3098 67 | 2754 68 | 3251 69 | 2767 70 | 2630 71 | 2727 72 | 2513 73 | 2701 74 | 3264 75 | 2792 76 | 2821 77 | 3260 78 | 2462 79 | 3307 80 | 2639 81 | 2900 82 | 3060 83 | 2672 84 | 3116 85 | 2731 86 | 3316 87 | 2386 88 | 2425 89 | 2518 90 | 3151 91 | 2586 92 | 2797 93 | 2479 94 | 3117 95 | 2580 96 | 3182 97 | 2459 98 | 2508 99 | 3052 100 | 3230 101 | 3215 102 | 2803 103 | 2969 104 | 2562 105 | 2398 106 | 3325 107 | 2343 108 | 3030 109 | 2414 110 | 2776 111 | 2383 112 | 3173 113 | 2850 114 | 2499 115 | 3312 116 | 2648 117 | 2784 118 | 2898 119 | 3056 120 | 2484 121 | 3179 122 | 3132 123 | 2577 124 | 2563 125 | 2867 126 | 3317 127 | 2355 128 | 3207 129 | 3178 130 | 2968 131 | 3319 132 | 2358 133 | 2764 134 | 3001 135 | 2683 136 | 3271 137 | 2321 138 | 2567 139 | 2502 140 | 3246 141 | 2715 142 | 3066 143 | 2390 144 | 2381 145 | 3162 146 | 2741 147 | 2498 148 | 2790 149 | 3038 150 | 3321 151 | 2481 152 | 3050 153 | 3161 154 | 3122 155 | 2801 156 | 2957 157 | 3177 158 | 2965 159 | 2621 160 | 3208 161 | 2921 162 | 2802 163 | 2357 164 | 2677 165 | 2519 166 | 2860 167 | 2696 168 | 2368 169 | 3241 170 | 2858 171 | 2419 172 | 2762 173 | 2875 174 | 3222 175 | 3064 176 | 2827 177 | 3044 178 | 2471 179 | 3062 180 | 2982 181 | 2736 182 | 2322 183 | 2709 184 | 2766 185 | 2424 186 | 2602 187 | 2970 188 | 2675 189 | 3299 190 | 2554 191 | 2964 192 | 2597 193 | 2753 194 | 2979 195 | 2523 196 | 2912 197 | 2896 198 | 2317 199 | 3167 200 | 2813 201 | 2482 202 | 2557 203 | 3043 204 | 3244 205 | 2985 206 | 2460 207 | 2363 208 | 3272 209 | 3045 210 | 3192 211 | 2453 212 | 2656 213 | 2834 214 | 2443 215 | 3202 216 | 2926 217 | 2711 218 | 2633 219 | 2384 220 | 2752 221 | 3285 222 | 2817 223 | 2483 224 | 2919 225 | 2924 226 | 2661 227 | 2698 228 | 2361 229 | 2662 230 | 2819 231 | 3143 232 | 2316 233 | 3196 234 | 2739 235 | 2345 236 | 2578 237 | 2822 238 | 3229 239 | 2908 240 | 2917 241 | 2692 242 | 3200 243 | 2324 244 | 2522 245 | 3322 246 | 2697 247 | 3163 248 | 3093 249 | 3233 250 | 2774 251 | 2371 252 | 2835 253 | 2652 254 | 2539 255 | 2843 256 | 3231 257 | 2976 258 | 2429 259 | 2367 260 | 3144 261 | 2564 262 | 3283 263 | 3217 264 | 3035 265 | 2962 266 | 2433 267 | 2415 268 | 2387 269 | 3021 270 | 2595 271 | 2517 272 | 2468 273 | 3061 274 | 2673 275 | 2348 276 | 3027 277 | 2467 278 | 3318 279 | 2959 280 | 3273 281 | 2392 282 | 2779 283 | 2678 284 | 3004 285 | 2634 286 | 2974 287 | 3198 288 | 2342 289 | 2376 290 | 3249 291 | 2868 292 | 2952 293 | 2710 294 | 2838 295 | 2335 296 | 2524 297 | 2650 298 | 3186 299 | 2743 300 | 2545 301 | 2841 302 | 2515 303 | 2505 304 | 3181 305 | 2945 306 | 2738 307 | 2933 308 | 3303 309 | 2611 310 | 3090 311 | 2328 312 | 3010 313 | 3016 314 | 2504 315 | 2936 316 | 3266 317 | 3253 318 | 2840 319 | 3034 320 | 2581 321 | 2344 322 | 2452 323 | 2654 324 | 3199 325 | 3137 326 | 2514 327 | 2394 328 | 2544 329 | 2641 330 | 2613 331 | 2618 332 | 2558 333 | 2593 334 | 2532 335 | 2512 336 | 2975 337 | 3267 338 | 2566 339 | 2951 340 | 3300 341 | 2869 342 | 2629 343 | 2747 344 | 3055 345 | 2831 346 | 3105 347 | 3168 348 | 3100 349 | 2431 350 | 
2828 351 | 2684 352 | 3269 353 | 2910 354 | 2865 355 | 2693 356 | 2884 357 | 3228 358 | 2783 359 | 3247 360 | 2770 361 | 3157 362 | 2421 363 | 2382 364 | 2331 365 | 3203 366 | 3240 367 | 2351 368 | 3114 369 | 2986 370 | 2688 371 | 2439 372 | 2996 373 | 3079 374 | 3103 375 | 3296 376 | 2349 377 | 2372 378 | 3096 379 | 2422 380 | 2551 381 | 3069 382 | 2737 383 | 3084 384 | 3304 385 | 3022 386 | 2542 387 | 3204 388 | 2949 389 | 2318 390 | 2450 391 | 3140 392 | 2734 393 | 2881 394 | 2576 395 | 3054 396 | 3089 397 | 3125 398 | 2761 399 | 3136 400 | 3111 401 | 2427 402 | 2466 403 | 3101 404 | 3104 405 | 3259 406 | 2534 407 | 2961 408 | 3191 409 | 3000 410 | 3036 411 | 2356 412 | 2800 413 | 3155 414 | 3224 415 | 2646 416 | 2735 417 | 3020 418 | 2866 419 | 2426 420 | 2448 421 | 3226 422 | 3219 423 | 2749 424 | 3183 425 | 2906 426 | 2360 427 | 2440 428 | 2946 429 | 2313 430 | 2859 431 | 2340 432 | 3008 433 | 2719 434 | 3058 435 | 2653 436 | 3023 437 | 2888 438 | 3243 439 | 2913 440 | 3242 441 | 3067 442 | 2409 443 | 3227 444 | 2380 445 | 2353 446 | 2686 447 | 2971 448 | 2847 449 | 2947 450 | 2857 451 | 3263 452 | 3218 453 | 2861 454 | 3323 455 | 2635 456 | 2966 457 | 2604 458 | 2456 459 | 2832 460 | 2694 461 | 3245 462 | 3119 463 | 2942 464 | 3153 465 | 2894 466 | 2555 467 | 3128 468 | 2703 469 | 2323 470 | 2631 471 | 2732 472 | 2699 473 | 2314 474 | 2590 475 | 3127 476 | 2891 477 | 2873 478 | 2814 479 | 2326 480 | 3026 481 | 3288 482 | 3095 483 | 2706 484 | 2457 485 | 2377 486 | 2620 487 | 2526 488 | 2674 489 | 3190 490 | 2923 491 | 3032 492 | 2334 493 | 3254 494 | 2991 495 | 3277 496 | 2973 497 | 2599 498 | 2658 499 | 2636 500 | 2826 501 | 3148 502 | 2958 503 | 3258 504 | 2990 505 | 3180 506 | 2538 507 | 2748 508 | 2625 509 | 2565 510 | 3011 511 | 3057 512 | 2354 513 | 3158 514 | 2622 515 | 3308 516 | 2983 517 | 2560 518 | 3169 519 | 3059 520 | 2480 521 | 3194 522 | 3291 523 | 3216 524 | 2643 525 | 3172 526 | 2352 527 | 2724 528 | 2485 529 | 2411 530 | 2948 531 | 2445 532 | 2362 533 | 2668 534 | 3275 535 | 3107 536 | 2496 537 | 2529 538 | 2700 539 | 2541 540 | 3028 541 | 2879 542 | 2660 543 | 3324 544 | 2755 545 | 2436 546 | 3048 547 | 2623 548 | 2920 549 | 3040 550 | 2568 551 | 3221 552 | 3003 553 | 3295 554 | 2473 555 | 3232 556 | 3213 557 | 2823 558 | 2897 559 | 2573 560 | 2645 561 | 3018 562 | 3326 563 | 2795 564 | 2915 565 | 3109 566 | 3086 567 | 2463 568 | 3118 569 | 2671 570 | 2909 571 | 2393 572 | 2325 573 | 3029 574 | 2972 575 | 3110 576 | 2870 577 | 3284 578 | 2816 579 | 2647 580 | 2667 581 | 2955 582 | 2333 583 | 2960 584 | 2864 585 | 2893 586 | 2458 587 | 2441 588 | 2359 589 | 2327 590 | 3256 591 | 3099 592 | 3073 593 | 3138 594 | 2511 595 | 2666 596 | 2548 597 | 2364 598 | 2451 599 | 2911 600 | 3237 601 | 3206 602 | 3080 603 | 3279 604 | 2934 605 | 2981 606 | 2878 607 | 3130 608 | 2830 609 | 3091 610 | 2659 611 | 2449 612 | 3152 613 | 2413 614 | 2722 615 | 2796 616 | 3220 617 | 2751 618 | 2935 619 | 3238 620 | 2491 621 | 2730 622 | 2842 623 | 3223 624 | 2492 625 | 3074 626 | 3094 627 | 2833 628 | 2521 629 | 2883 630 | 3315 631 | 2845 632 | 2907 633 | 3083 634 | 2572 635 | 3092 636 | 2903 637 | 2918 638 | 3039 639 | 3286 640 | 2587 641 | 3068 642 | 2338 643 | 3166 644 | 3134 645 | 2455 646 | 2497 647 | 2992 648 | 2775 649 | 2681 650 | 2430 651 | 2932 652 | 2931 653 | 2434 654 | 3154 655 | 3046 656 | 2598 657 | 2366 658 | 3015 659 | 3147 660 | 2944 661 | 2582 662 | 3274 663 | 2987 664 | 2642 665 | 2547 666 | 2420 667 | 2930 668 | 2750 669 | 2417 670 | 2808 671 | 3141 672 | 2997 673 | 
2995 674 | 2584 675 | 2312 676 | 3033 677 | 3070 678 | 3065 679 | 2509 680 | 3314 681 | 2396 682 | 2543 683 | 2423 684 | 3170 685 | 2389 686 | 3289 687 | 2728 688 | 2540 689 | 2437 690 | 2486 691 | 2895 692 | 3017 693 | 2853 694 | 2406 695 | 2346 696 | 2877 697 | 2472 698 | 3210 699 | 2637 700 | 2927 701 | 2789 702 | 2330 703 | 3088 704 | 3102 705 | 2616 706 | 3081 707 | 2902 708 | 3205 709 | 3320 710 | 3165 711 | 2984 712 | 3185 713 | 2707 714 | 3255 715 | 2583 716 | 2773 717 | 2742 718 | 3024 719 | 2402 720 | 2718 721 | 2882 722 | 2575 723 | 3281 724 | 2786 725 | 2855 726 | 3014 727 | 2401 728 | 2535 729 | 2687 730 | 2495 731 | 3113 732 | 2609 733 | 2559 734 | 2665 735 | 2530 736 | 3293 737 | 2399 738 | 2605 739 | 2690 740 | 3133 741 | 2799 742 | 2533 743 | 2695 744 | 2713 745 | 2886 746 | 2691 747 | 2549 748 | 3077 749 | 3002 750 | 3049 751 | 3051 752 | 3087 753 | 2444 754 | 3085 755 | 3135 756 | 2702 757 | 3211 758 | 3108 759 | 2501 760 | 2769 761 | 3290 762 | 2465 763 | 3025 764 | 3019 765 | 2385 766 | 2940 767 | 2657 768 | 2610 769 | 2525 770 | 2941 771 | 3078 772 | 2341 773 | 2916 774 | 2956 775 | 2375 776 | 2880 777 | 3009 778 | 2780 779 | 2370 780 | 2925 781 | 2332 782 | 3146 783 | 2315 784 | 2809 785 | 3145 786 | 3106 787 | 2782 788 | 2760 789 | 2493 790 | 2765 791 | 2556 792 | 2890 793 | 2400 794 | 2339 795 | 3201 796 | 2818 797 | 3248 798 | 3280 799 | 2570 800 | 2569 801 | 2937 802 | 3174 803 | 2836 804 | 2708 805 | 2820 806 | 3195 807 | 2617 808 | 3197 809 | 2319 810 | 2744 811 | 2615 812 | 2825 813 | 2603 814 | 2914 815 | 2531 816 | 3193 817 | 2624 818 | 2365 819 | 2810 820 | 3239 821 | 3159 822 | 2537 823 | 2844 824 | 2758 825 | 2938 826 | 3037 827 | 2503 828 | 3297 829 | 2885 830 | 2608 831 | 2494 832 | 2712 833 | 2408 834 | 2901 835 | 2704 836 | 2536 837 | 2373 838 | 2478 839 | 2723 840 | 3076 841 | 2627 842 | 2369 843 | 2669 844 | 3006 845 | 2628 846 | 2788 847 | 3276 848 | 2435 849 | 3139 850 | 3235 851 | 2527 852 | 2571 853 | 2815 854 | 2442 855 | 2892 856 | 2978 857 | 2746 858 | 3150 859 | 2574 860 | 2725 861 | 3188 862 | 2601 863 | 2378 864 | 3075 865 | 2632 866 | 2794 867 | 3270 868 | 3071 869 | 2506 870 | 3126 871 | 3236 872 | 3257 873 | 2824 874 | 2989 875 | 2950 876 | 2428 877 | 2405 878 | 3156 879 | 2447 880 | 2787 881 | 2805 882 | 2720 883 | 2403 884 | 2811 885 | 2329 886 | 2474 887 | 2785 888 | 2350 889 | 2507 890 | 2416 891 | 3112 892 | 2475 893 | 2876 894 | 2585 895 | 2487 896 | 3072 897 | 3082 898 | 2943 899 | 2757 900 | 2388 901 | 2600 902 | 3294 903 | 2756 904 | 3142 905 | 3041 906 | 2594 907 | 2998 908 | 3047 909 | 2379 910 | 2980 911 | 2454 912 | 2862 913 | 3175 914 | 2588 915 | 3031 916 | 3012 917 | 2889 918 | 2500 919 | 2791 920 | 2854 921 | 2619 922 | 2395 923 | 2807 924 | 2740 925 | 2412 926 | 3131 927 | 3013 928 | 2939 929 | 2651 930 | 2490 931 | 2988 932 | 2863 933 | 3225 934 | 2745 935 | 2714 936 | 3160 937 | 3124 938 | 2849 939 | 2676 940 | 2872 941 | 3287 942 | 3189 943 | 2716 944 | 3115 945 | 2928 946 | 2871 947 | 2591 948 | 2717 949 | 2546 950 | 2777 951 | 3298 952 | 2397 953 | 3187 954 | 2726 955 | 2336 956 | 3268 957 | 2477 958 | 2904 959 | 2846 960 | 3121 961 | 2899 962 | 2510 963 | 2806 964 | 2963 965 | 3313 966 | 2679 967 | 3302 968 | 2663 969 | 3053 970 | 2469 971 | 2999 972 | 3311 973 | 2470 974 | 2638 975 | 3120 976 | 3171 977 | 2689 978 | 2922 979 | 2607 980 | 2721 981 | 2993 982 | 2887 983 | 2837 984 | 2929 985 | 2829 986 | 3234 987 | 2649 988 | 2337 989 | 2759 990 | 2778 991 | 2771 992 | 2404 993 | 2589 994 | 3123 995 | 3209 996 | 
2729 997 | 3252 998 | 2606 999 | 2579 1000 | 2552 1001 | -------------------------------------------------------------------------------- /data/Planetoid/citeseer/raw/ind.citeseer.tx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marlin-codes/hyperbolicTransformer/798e8fd5e0698cd3c548562dffc00cee53c0f09a/data/Planetoid/citeseer/raw/ind.citeseer.tx -------------------------------------------------------------------------------- /data/Planetoid/citeseer/raw/ind.citeseer.ty: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marlin-codes/hyperbolicTransformer/798e8fd5e0698cd3c548562dffc00cee53c0f09a/data/Planetoid/citeseer/raw/ind.citeseer.ty -------------------------------------------------------------------------------- /data/Planetoid/citeseer/raw/ind.citeseer.x: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marlin-codes/hyperbolicTransformer/798e8fd5e0698cd3c548562dffc00cee53c0f09a/data/Planetoid/citeseer/raw/ind.citeseer.x -------------------------------------------------------------------------------- /data/Planetoid/citeseer/raw/ind.citeseer.y: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marlin-codes/hyperbolicTransformer/798e8fd5e0698cd3c548562dffc00cee53c0f09a/data/Planetoid/citeseer/raw/ind.citeseer.y -------------------------------------------------------------------------------- /data/Planetoid/cora/raw/ind.cora.allx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marlin-codes/hyperbolicTransformer/798e8fd5e0698cd3c548562dffc00cee53c0f09a/data/Planetoid/cora/raw/ind.cora.allx -------------------------------------------------------------------------------- /data/Planetoid/cora/raw/ind.cora.ally: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marlin-codes/hyperbolicTransformer/798e8fd5e0698cd3c548562dffc00cee53c0f09a/data/Planetoid/cora/raw/ind.cora.ally -------------------------------------------------------------------------------- /data/Planetoid/cora/raw/ind.cora.graph: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marlin-codes/hyperbolicTransformer/798e8fd5e0698cd3c548562dffc00cee53c0f09a/data/Planetoid/cora/raw/ind.cora.graph -------------------------------------------------------------------------------- /data/Planetoid/cora/raw/ind.cora.test.index: -------------------------------------------------------------------------------- 1 | 2692 2 | 2532 3 | 2050 4 | 1715 5 | 2362 6 | 2609 7 | 2622 8 | 1975 9 | 2081 10 | 1767 11 | 2263 12 | 1725 13 | 2588 14 | 2259 15 | 2357 16 | 1998 17 | 2574 18 | 2179 19 | 2291 20 | 2382 21 | 1812 22 | 1751 23 | 2422 24 | 1937 25 | 2631 26 | 2510 27 | 2378 28 | 2589 29 | 2345 30 | 1943 31 | 1850 32 | 2298 33 | 1825 34 | 2035 35 | 2507 36 | 2313 37 | 1906 38 | 1797 39 | 2023 40 | 2159 41 | 2495 42 | 1886 43 | 2122 44 | 2369 45 | 2461 46 | 1925 47 | 2565 48 | 1858 49 | 2234 50 | 2000 51 | 1846 52 | 2318 53 | 1723 54 | 2559 55 | 2258 56 | 1763 57 | 1991 58 | 1922 59 | 2003 60 | 2662 61 | 2250 62 | 2064 63 | 2529 64 | 1888 65 | 2499 66 | 2454 67 | 2320 68 | 2287 69 | 2203 70 | 2018 71 | 2002 72 | 2632 73 | 2554 74 | 2314 75 | 2537 76 | 1760 77 | 2088 78 | 2086 79 | 2218 80 | 2605 81 
| 1953 82 | 2403 83 | 1920 84 | 2015 85 | 2335 86 | 2535 87 | 1837 88 | 2009 89 | 1905 90 | 2636 91 | 1942 92 | 2193 93 | 2576 94 | 2373 95 | 1873 96 | 2463 97 | 2509 98 | 1954 99 | 2656 100 | 2455 101 | 2494 102 | 2295 103 | 2114 104 | 2561 105 | 2176 106 | 2275 107 | 2635 108 | 2442 109 | 2704 110 | 2127 111 | 2085 112 | 2214 113 | 2487 114 | 1739 115 | 2543 116 | 1783 117 | 2485 118 | 2262 119 | 2472 120 | 2326 121 | 1738 122 | 2170 123 | 2100 124 | 2384 125 | 2152 126 | 2647 127 | 2693 128 | 2376 129 | 1775 130 | 1726 131 | 2476 132 | 2195 133 | 1773 134 | 1793 135 | 2194 136 | 2581 137 | 1854 138 | 2524 139 | 1945 140 | 1781 141 | 1987 142 | 2599 143 | 1744 144 | 2225 145 | 2300 146 | 1928 147 | 2042 148 | 2202 149 | 1958 150 | 1816 151 | 1916 152 | 2679 153 | 2190 154 | 1733 155 | 2034 156 | 2643 157 | 2177 158 | 1883 159 | 1917 160 | 1996 161 | 2491 162 | 2268 163 | 2231 164 | 2471 165 | 1919 166 | 1909 167 | 2012 168 | 2522 169 | 1865 170 | 2466 171 | 2469 172 | 2087 173 | 2584 174 | 2563 175 | 1924 176 | 2143 177 | 1736 178 | 1966 179 | 2533 180 | 2490 181 | 2630 182 | 1973 183 | 2568 184 | 1978 185 | 2664 186 | 2633 187 | 2312 188 | 2178 189 | 1754 190 | 2307 191 | 2480 192 | 1960 193 | 1742 194 | 1962 195 | 2160 196 | 2070 197 | 2553 198 | 2433 199 | 1768 200 | 2659 201 | 2379 202 | 2271 203 | 1776 204 | 2153 205 | 1877 206 | 2027 207 | 2028 208 | 2155 209 | 2196 210 | 2483 211 | 2026 212 | 2158 213 | 2407 214 | 1821 215 | 2131 216 | 2676 217 | 2277 218 | 2489 219 | 2424 220 | 1963 221 | 1808 222 | 1859 223 | 2597 224 | 2548 225 | 2368 226 | 1817 227 | 2405 228 | 2413 229 | 2603 230 | 2350 231 | 2118 232 | 2329 233 | 1969 234 | 2577 235 | 2475 236 | 2467 237 | 2425 238 | 1769 239 | 2092 240 | 2044 241 | 2586 242 | 2608 243 | 1983 244 | 2109 245 | 2649 246 | 1964 247 | 2144 248 | 1902 249 | 2411 250 | 2508 251 | 2360 252 | 1721 253 | 2005 254 | 2014 255 | 2308 256 | 2646 257 | 1949 258 | 1830 259 | 2212 260 | 2596 261 | 1832 262 | 1735 263 | 1866 264 | 2695 265 | 1941 266 | 2546 267 | 2498 268 | 2686 269 | 2665 270 | 1784 271 | 2613 272 | 1970 273 | 2021 274 | 2211 275 | 2516 276 | 2185 277 | 2479 278 | 2699 279 | 2150 280 | 1990 281 | 2063 282 | 2075 283 | 1979 284 | 2094 285 | 1787 286 | 2571 287 | 2690 288 | 1926 289 | 2341 290 | 2566 291 | 1957 292 | 1709 293 | 1955 294 | 2570 295 | 2387 296 | 1811 297 | 2025 298 | 2447 299 | 2696 300 | 2052 301 | 2366 302 | 1857 303 | 2273 304 | 2245 305 | 2672 306 | 2133 307 | 2421 308 | 1929 309 | 2125 310 | 2319 311 | 2641 312 | 2167 313 | 2418 314 | 1765 315 | 1761 316 | 1828 317 | 2188 318 | 1972 319 | 1997 320 | 2419 321 | 2289 322 | 2296 323 | 2587 324 | 2051 325 | 2440 326 | 2053 327 | 2191 328 | 1923 329 | 2164 330 | 1861 331 | 2339 332 | 2333 333 | 2523 334 | 2670 335 | 2121 336 | 1921 337 | 1724 338 | 2253 339 | 2374 340 | 1940 341 | 2545 342 | 2301 343 | 2244 344 | 2156 345 | 1849 346 | 2551 347 | 2011 348 | 2279 349 | 2572 350 | 1757 351 | 2400 352 | 2569 353 | 2072 354 | 2526 355 | 2173 356 | 2069 357 | 2036 358 | 1819 359 | 1734 360 | 1880 361 | 2137 362 | 2408 363 | 2226 364 | 2604 365 | 1771 366 | 2698 367 | 2187 368 | 2060 369 | 1756 370 | 2201 371 | 2066 372 | 2439 373 | 1844 374 | 1772 375 | 2383 376 | 2398 377 | 1708 378 | 1992 379 | 1959 380 | 1794 381 | 2426 382 | 2702 383 | 2444 384 | 1944 385 | 1829 386 | 2660 387 | 2497 388 | 2607 389 | 2343 390 | 1730 391 | 2624 392 | 1790 393 | 1935 394 | 1967 395 | 2401 396 | 2255 397 | 2355 398 | 2348 399 | 1931 400 | 2183 401 | 2161 402 | 2701 403 | 1948 404 | 2501 405 | 2192 
406 | 2404 407 | 2209 408 | 2331 409 | 1810 410 | 2363 411 | 2334 412 | 1887 413 | 2393 414 | 2557 415 | 1719 416 | 1732 417 | 1986 418 | 2037 419 | 2056 420 | 1867 421 | 2126 422 | 1932 423 | 2117 424 | 1807 425 | 1801 426 | 1743 427 | 2041 428 | 1843 429 | 2388 430 | 2221 431 | 1833 432 | 2677 433 | 1778 434 | 2661 435 | 2306 436 | 2394 437 | 2106 438 | 2430 439 | 2371 440 | 2606 441 | 2353 442 | 2269 443 | 2317 444 | 2645 445 | 2372 446 | 2550 447 | 2043 448 | 1968 449 | 2165 450 | 2310 451 | 1985 452 | 2446 453 | 1982 454 | 2377 455 | 2207 456 | 1818 457 | 1913 458 | 1766 459 | 1722 460 | 1894 461 | 2020 462 | 1881 463 | 2621 464 | 2409 465 | 2261 466 | 2458 467 | 2096 468 | 1712 469 | 2594 470 | 2293 471 | 2048 472 | 2359 473 | 1839 474 | 2392 475 | 2254 476 | 1911 477 | 2101 478 | 2367 479 | 1889 480 | 1753 481 | 2555 482 | 2246 483 | 2264 484 | 2010 485 | 2336 486 | 2651 487 | 2017 488 | 2140 489 | 1842 490 | 2019 491 | 1890 492 | 2525 493 | 2134 494 | 2492 495 | 2652 496 | 2040 497 | 2145 498 | 2575 499 | 2166 500 | 1999 501 | 2434 502 | 1711 503 | 2276 504 | 2450 505 | 2389 506 | 2669 507 | 2595 508 | 1814 509 | 2039 510 | 2502 511 | 1896 512 | 2168 513 | 2344 514 | 2637 515 | 2031 516 | 1977 517 | 2380 518 | 1936 519 | 2047 520 | 2460 521 | 2102 522 | 1745 523 | 2650 524 | 2046 525 | 2514 526 | 1980 527 | 2352 528 | 2113 529 | 1713 530 | 2058 531 | 2558 532 | 1718 533 | 1864 534 | 1876 535 | 2338 536 | 1879 537 | 1891 538 | 2186 539 | 2451 540 | 2181 541 | 2638 542 | 2644 543 | 2103 544 | 2591 545 | 2266 546 | 2468 547 | 1869 548 | 2582 549 | 2674 550 | 2361 551 | 2462 552 | 1748 553 | 2215 554 | 2615 555 | 2236 556 | 2248 557 | 2493 558 | 2342 559 | 2449 560 | 2274 561 | 1824 562 | 1852 563 | 1870 564 | 2441 565 | 2356 566 | 1835 567 | 2694 568 | 2602 569 | 2685 570 | 1893 571 | 2544 572 | 2536 573 | 1994 574 | 1853 575 | 1838 576 | 1786 577 | 1930 578 | 2539 579 | 1892 580 | 2265 581 | 2618 582 | 2486 583 | 2583 584 | 2061 585 | 1796 586 | 1806 587 | 2084 588 | 1933 589 | 2095 590 | 2136 591 | 2078 592 | 1884 593 | 2438 594 | 2286 595 | 2138 596 | 1750 597 | 2184 598 | 1799 599 | 2278 600 | 2410 601 | 2642 602 | 2435 603 | 1956 604 | 2399 605 | 1774 606 | 2129 607 | 1898 608 | 1823 609 | 1938 610 | 2299 611 | 1862 612 | 2420 613 | 2673 614 | 1984 615 | 2204 616 | 1717 617 | 2074 618 | 2213 619 | 2436 620 | 2297 621 | 2592 622 | 2667 623 | 2703 624 | 2511 625 | 1779 626 | 1782 627 | 2625 628 | 2365 629 | 2315 630 | 2381 631 | 1788 632 | 1714 633 | 2302 634 | 1927 635 | 2325 636 | 2506 637 | 2169 638 | 2328 639 | 2629 640 | 2128 641 | 2655 642 | 2282 643 | 2073 644 | 2395 645 | 2247 646 | 2521 647 | 2260 648 | 1868 649 | 1988 650 | 2324 651 | 2705 652 | 2541 653 | 1731 654 | 2681 655 | 2707 656 | 2465 657 | 1785 658 | 2149 659 | 2045 660 | 2505 661 | 2611 662 | 2217 663 | 2180 664 | 1904 665 | 2453 666 | 2484 667 | 1871 668 | 2309 669 | 2349 670 | 2482 671 | 2004 672 | 1965 673 | 2406 674 | 2162 675 | 1805 676 | 2654 677 | 2007 678 | 1947 679 | 1981 680 | 2112 681 | 2141 682 | 1720 683 | 1758 684 | 2080 685 | 2330 686 | 2030 687 | 2432 688 | 2089 689 | 2547 690 | 1820 691 | 1815 692 | 2675 693 | 1840 694 | 2658 695 | 2370 696 | 2251 697 | 1908 698 | 2029 699 | 2068 700 | 2513 701 | 2549 702 | 2267 703 | 2580 704 | 2327 705 | 2351 706 | 2111 707 | 2022 708 | 2321 709 | 2614 710 | 2252 711 | 2104 712 | 1822 713 | 2552 714 | 2243 715 | 1798 716 | 2396 717 | 2663 718 | 2564 719 | 2148 720 | 2562 721 | 2684 722 | 2001 723 | 2151 724 | 2706 725 | 2240 726 | 2474 727 | 2303 728 | 2634 
729 | 2680 730 | 2055 731 | 2090 732 | 2503 733 | 2347 734 | 2402 735 | 2238 736 | 1950 737 | 2054 738 | 2016 739 | 1872 740 | 2233 741 | 1710 742 | 2032 743 | 2540 744 | 2628 745 | 1795 746 | 2616 747 | 1903 748 | 2531 749 | 2567 750 | 1946 751 | 1897 752 | 2222 753 | 2227 754 | 2627 755 | 1856 756 | 2464 757 | 2241 758 | 2481 759 | 2130 760 | 2311 761 | 2083 762 | 2223 763 | 2284 764 | 2235 765 | 2097 766 | 1752 767 | 2515 768 | 2527 769 | 2385 770 | 2189 771 | 2283 772 | 2182 773 | 2079 774 | 2375 775 | 2174 776 | 2437 777 | 1993 778 | 2517 779 | 2443 780 | 2224 781 | 2648 782 | 2171 783 | 2290 784 | 2542 785 | 2038 786 | 1855 787 | 1831 788 | 1759 789 | 1848 790 | 2445 791 | 1827 792 | 2429 793 | 2205 794 | 2598 795 | 2657 796 | 1728 797 | 2065 798 | 1918 799 | 2427 800 | 2573 801 | 2620 802 | 2292 803 | 1777 804 | 2008 805 | 1875 806 | 2288 807 | 2256 808 | 2033 809 | 2470 810 | 2585 811 | 2610 812 | 2082 813 | 2230 814 | 1915 815 | 1847 816 | 2337 817 | 2512 818 | 2386 819 | 2006 820 | 2653 821 | 2346 822 | 1951 823 | 2110 824 | 2639 825 | 2520 826 | 1939 827 | 2683 828 | 2139 829 | 2220 830 | 1910 831 | 2237 832 | 1900 833 | 1836 834 | 2197 835 | 1716 836 | 1860 837 | 2077 838 | 2519 839 | 2538 840 | 2323 841 | 1914 842 | 1971 843 | 1845 844 | 2132 845 | 1802 846 | 1907 847 | 2640 848 | 2496 849 | 2281 850 | 2198 851 | 2416 852 | 2285 853 | 1755 854 | 2431 855 | 2071 856 | 2249 857 | 2123 858 | 1727 859 | 2459 860 | 2304 861 | 2199 862 | 1791 863 | 1809 864 | 1780 865 | 2210 866 | 2417 867 | 1874 868 | 1878 869 | 2116 870 | 1961 871 | 1863 872 | 2579 873 | 2477 874 | 2228 875 | 2332 876 | 2578 877 | 2457 878 | 2024 879 | 1934 880 | 2316 881 | 1841 882 | 1764 883 | 1737 884 | 2322 885 | 2239 886 | 2294 887 | 1729 888 | 2488 889 | 1974 890 | 2473 891 | 2098 892 | 2612 893 | 1834 894 | 2340 895 | 2423 896 | 2175 897 | 2280 898 | 2617 899 | 2208 900 | 2560 901 | 1741 902 | 2600 903 | 2059 904 | 1747 905 | 2242 906 | 2700 907 | 2232 908 | 2057 909 | 2147 910 | 2682 911 | 1792 912 | 1826 913 | 2120 914 | 1895 915 | 2364 916 | 2163 917 | 1851 918 | 2391 919 | 2414 920 | 2452 921 | 1803 922 | 1989 923 | 2623 924 | 2200 925 | 2528 926 | 2415 927 | 1804 928 | 2146 929 | 2619 930 | 2687 931 | 1762 932 | 2172 933 | 2270 934 | 2678 935 | 2593 936 | 2448 937 | 1882 938 | 2257 939 | 2500 940 | 1899 941 | 2478 942 | 2412 943 | 2107 944 | 1746 945 | 2428 946 | 2115 947 | 1800 948 | 1901 949 | 2397 950 | 2530 951 | 1912 952 | 2108 953 | 2206 954 | 2091 955 | 1740 956 | 2219 957 | 1976 958 | 2099 959 | 2142 960 | 2671 961 | 2668 962 | 2216 963 | 2272 964 | 2229 965 | 2666 966 | 2456 967 | 2534 968 | 2697 969 | 2688 970 | 2062 971 | 2691 972 | 2689 973 | 2154 974 | 2590 975 | 2626 976 | 2390 977 | 1813 978 | 2067 979 | 1952 980 | 2518 981 | 2358 982 | 1789 983 | 2076 984 | 2049 985 | 2119 986 | 2013 987 | 2124 988 | 2556 989 | 2105 990 | 2093 991 | 1885 992 | 2305 993 | 2354 994 | 2135 995 | 2601 996 | 1770 997 | 1995 998 | 2504 999 | 1749 1000 | 2157 1001 | -------------------------------------------------------------------------------- /data/Planetoid/cora/raw/ind.cora.tx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marlin-codes/hyperbolicTransformer/798e8fd5e0698cd3c548562dffc00cee53c0f09a/data/Planetoid/cora/raw/ind.cora.tx -------------------------------------------------------------------------------- /data/Planetoid/cora/raw/ind.cora.ty: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/marlin-codes/hyperbolicTransformer/798e8fd5e0698cd3c548562dffc00cee53c0f09a/data/Planetoid/cora/raw/ind.cora.ty -------------------------------------------------------------------------------- /data/Planetoid/cora/raw/ind.cora.x: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marlin-codes/hyperbolicTransformer/798e8fd5e0698cd3c548562dffc00cee53c0f09a/data/Planetoid/cora/raw/ind.cora.x -------------------------------------------------------------------------------- /data/Planetoid/cora/raw/ind.cora.y: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marlin-codes/hyperbolicTransformer/798e8fd5e0698cd3c548562dffc00cee53c0f09a/data/Planetoid/cora/raw/ind.cora.y -------------------------------------------------------------------------------- /data/Planetoid/pubmed/raw/ind.pubmed.allx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marlin-codes/hyperbolicTransformer/798e8fd5e0698cd3c548562dffc00cee53c0f09a/data/Planetoid/pubmed/raw/ind.pubmed.allx -------------------------------------------------------------------------------- /data/Planetoid/pubmed/raw/ind.pubmed.ally: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marlin-codes/hyperbolicTransformer/798e8fd5e0698cd3c548562dffc00cee53c0f09a/data/Planetoid/pubmed/raw/ind.pubmed.ally -------------------------------------------------------------------------------- /data/Planetoid/pubmed/raw/ind.pubmed.graph: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marlin-codes/hyperbolicTransformer/798e8fd5e0698cd3c548562dffc00cee53c0f09a/data/Planetoid/pubmed/raw/ind.pubmed.graph -------------------------------------------------------------------------------- /data/Planetoid/pubmed/raw/ind.pubmed.test.index: -------------------------------------------------------------------------------- 1 | 18747 2 | 19392 3 | 19181 4 | 18843 5 | 19221 6 | 18962 7 | 19560 8 | 19097 9 | 18966 10 | 19014 11 | 18756 12 | 19313 13 | 19000 14 | 19569 15 | 19359 16 | 18854 17 | 18970 18 | 19073 19 | 19661 20 | 19180 21 | 19377 22 | 18750 23 | 19401 24 | 18788 25 | 19224 26 | 19447 27 | 19017 28 | 19241 29 | 18890 30 | 18908 31 | 18965 32 | 19001 33 | 18849 34 | 19641 35 | 18852 36 | 19222 37 | 19172 38 | 18762 39 | 19156 40 | 19162 41 | 18856 42 | 18763 43 | 19318 44 | 18826 45 | 19712 46 | 19192 47 | 19695 48 | 19030 49 | 19523 50 | 19249 51 | 19079 52 | 19232 53 | 19455 54 | 18743 55 | 18800 56 | 19071 57 | 18885 58 | 19593 59 | 19394 60 | 19390 61 | 18832 62 | 19445 63 | 18838 64 | 19632 65 | 19548 66 | 19546 67 | 18825 68 | 19498 69 | 19266 70 | 19117 71 | 19595 72 | 19252 73 | 18730 74 | 18913 75 | 18809 76 | 19452 77 | 19520 78 | 19274 79 | 19555 80 | 19388 81 | 18919 82 | 19099 83 | 19637 84 | 19403 85 | 18720 86 | 19526 87 | 18905 88 | 19451 89 | 19408 90 | 18923 91 | 18794 92 | 19322 93 | 19431 94 | 18912 95 | 18841 96 | 19239 97 | 19125 98 | 19258 99 | 19565 100 | 18898 101 | 19482 102 | 19029 103 | 18778 104 | 19096 105 | 19684 106 | 19552 107 | 18765 108 | 19361 109 | 19171 110 | 19367 111 | 19623 112 | 19402 113 | 19327 114 | 19118 115 | 18888 116 | 18726 117 | 19510 118 | 18831 119 | 19490 120 | 19576 121 | 19050 122 | 18729 123 | 18896 124 | 19246 125 | 19012 126 | 18862 127 | 18873 128 | 19193 129 
| 19693 130 | 19474 131 | 18953 132 | 19115 133 | 19182 134 | 19269 135 | 19116 136 | 18837 137 | 18872 138 | 19007 139 | 19212 140 | 18798 141 | 19102 142 | 18772 143 | 19660 144 | 19511 145 | 18914 146 | 18886 147 | 19672 148 | 19360 149 | 19213 150 | 18810 151 | 19420 152 | 19512 153 | 18719 154 | 19432 155 | 19350 156 | 19127 157 | 18782 158 | 19587 159 | 18924 160 | 19488 161 | 18781 162 | 19340 163 | 19190 164 | 19383 165 | 19094 166 | 18835 167 | 19487 168 | 19230 169 | 18791 170 | 18882 171 | 18937 172 | 18928 173 | 18755 174 | 18802 175 | 19516 176 | 18795 177 | 18786 178 | 19273 179 | 19349 180 | 19398 181 | 19626 182 | 19130 183 | 19351 184 | 19489 185 | 19446 186 | 18959 187 | 19025 188 | 18792 189 | 18878 190 | 19304 191 | 19629 192 | 19061 193 | 18785 194 | 19194 195 | 19179 196 | 19210 197 | 19417 198 | 19583 199 | 19415 200 | 19443 201 | 18739 202 | 19662 203 | 18904 204 | 18910 205 | 18901 206 | 18960 207 | 18722 208 | 18827 209 | 19290 210 | 18842 211 | 19389 212 | 19344 213 | 18961 214 | 19098 215 | 19147 216 | 19334 217 | 19358 218 | 18829 219 | 18984 220 | 18931 221 | 18742 222 | 19320 223 | 19111 224 | 19196 225 | 18887 226 | 18991 227 | 19469 228 | 18990 229 | 18876 230 | 19261 231 | 19270 232 | 19522 233 | 19088 234 | 19284 235 | 19646 236 | 19493 237 | 19225 238 | 19615 239 | 19449 240 | 19043 241 | 19674 242 | 19391 243 | 18918 244 | 19155 245 | 19110 246 | 18815 247 | 19131 248 | 18834 249 | 19715 250 | 19603 251 | 19688 252 | 19133 253 | 19053 254 | 19166 255 | 19066 256 | 18893 257 | 18757 258 | 19582 259 | 19282 260 | 19257 261 | 18869 262 | 19467 263 | 18954 264 | 19371 265 | 19151 266 | 19462 267 | 19598 268 | 19653 269 | 19187 270 | 19624 271 | 19564 272 | 19534 273 | 19581 274 | 19478 275 | 18985 276 | 18746 277 | 19342 278 | 18777 279 | 19696 280 | 18824 281 | 19138 282 | 18728 283 | 19643 284 | 19199 285 | 18731 286 | 19168 287 | 18948 288 | 19216 289 | 19697 290 | 19347 291 | 18808 292 | 18725 293 | 19134 294 | 18847 295 | 18828 296 | 18996 297 | 19106 298 | 19485 299 | 18917 300 | 18911 301 | 18776 302 | 19203 303 | 19158 304 | 18895 305 | 19165 306 | 19382 307 | 18780 308 | 18836 309 | 19373 310 | 19659 311 | 18947 312 | 19375 313 | 19299 314 | 18761 315 | 19366 316 | 18754 317 | 19248 318 | 19416 319 | 19658 320 | 19638 321 | 19034 322 | 19281 323 | 18844 324 | 18922 325 | 19491 326 | 19272 327 | 19341 328 | 19068 329 | 19332 330 | 19559 331 | 19293 332 | 18804 333 | 18933 334 | 18935 335 | 19405 336 | 18936 337 | 18945 338 | 18943 339 | 18818 340 | 18797 341 | 19570 342 | 19464 343 | 19428 344 | 19093 345 | 19433 346 | 18986 347 | 19161 348 | 19255 349 | 19157 350 | 19046 351 | 19292 352 | 19434 353 | 19298 354 | 18724 355 | 19410 356 | 19694 357 | 19214 358 | 19640 359 | 19189 360 | 18963 361 | 19218 362 | 19585 363 | 19041 364 | 19550 365 | 19123 366 | 19620 367 | 19376 368 | 19561 369 | 18944 370 | 19706 371 | 19056 372 | 19283 373 | 18741 374 | 19319 375 | 19144 376 | 19542 377 | 18821 378 | 19404 379 | 19080 380 | 19303 381 | 18793 382 | 19306 383 | 19678 384 | 19435 385 | 19519 386 | 19566 387 | 19278 388 | 18946 389 | 19536 390 | 19020 391 | 19057 392 | 19198 393 | 19333 394 | 19649 395 | 19699 396 | 19399 397 | 19654 398 | 19136 399 | 19465 400 | 19321 401 | 19577 402 | 18907 403 | 19665 404 | 19386 405 | 19596 406 | 19247 407 | 19473 408 | 19568 409 | 19355 410 | 18925 411 | 19586 412 | 18982 413 | 19616 414 | 19495 415 | 19612 416 | 19023 417 | 19438 418 | 18817 419 | 19692 420 | 19295 421 | 19414 422 | 19676 423 | 19472 424 | 19107 425 | 
19062 426 | 19035 427 | 18883 428 | 19409 429 | 19052 430 | 19606 431 | 19091 432 | 19651 433 | 19475 434 | 19413 435 | 18796 436 | 19369 437 | 19639 438 | 19701 439 | 19461 440 | 19645 441 | 19251 442 | 19063 443 | 19679 444 | 19545 445 | 19081 446 | 19363 447 | 18995 448 | 19549 449 | 18790 450 | 18855 451 | 18833 452 | 18899 453 | 19395 454 | 18717 455 | 19647 456 | 18768 457 | 19103 458 | 19245 459 | 18819 460 | 18779 461 | 19656 462 | 19076 463 | 18745 464 | 18971 465 | 19197 466 | 19711 467 | 19074 468 | 19128 469 | 19466 470 | 19139 471 | 19309 472 | 19324 473 | 18814 474 | 19092 475 | 19627 476 | 19060 477 | 18806 478 | 18929 479 | 18737 480 | 18942 481 | 18906 482 | 18858 483 | 19456 484 | 19253 485 | 19716 486 | 19104 487 | 19667 488 | 19574 489 | 18903 490 | 19237 491 | 18864 492 | 19556 493 | 19364 494 | 18952 495 | 19008 496 | 19323 497 | 19700 498 | 19170 499 | 19267 500 | 19345 501 | 19238 502 | 18909 503 | 18892 504 | 19109 505 | 19704 506 | 18902 507 | 19275 508 | 19680 509 | 18723 510 | 19242 511 | 19112 512 | 19169 513 | 18956 514 | 19343 515 | 19650 516 | 19541 517 | 19698 518 | 19521 519 | 19087 520 | 18976 521 | 19038 522 | 18775 523 | 18968 524 | 19671 525 | 19412 526 | 19407 527 | 19573 528 | 19027 529 | 18813 530 | 19357 531 | 19460 532 | 19673 533 | 19481 534 | 19036 535 | 19614 536 | 18787 537 | 19195 538 | 18732 539 | 18884 540 | 19613 541 | 19657 542 | 19575 543 | 19226 544 | 19589 545 | 19234 546 | 19617 547 | 19707 548 | 19484 549 | 18740 550 | 19424 551 | 18784 552 | 19419 553 | 19159 554 | 18865 555 | 19105 556 | 19315 557 | 19480 558 | 19664 559 | 19378 560 | 18803 561 | 19605 562 | 18870 563 | 19042 564 | 19426 565 | 18848 566 | 19223 567 | 19509 568 | 19532 569 | 18752 570 | 19691 571 | 18718 572 | 19209 573 | 19362 574 | 19090 575 | 19492 576 | 19567 577 | 19687 578 | 19018 579 | 18830 580 | 19530 581 | 19554 582 | 19119 583 | 19442 584 | 19558 585 | 19527 586 | 19427 587 | 19291 588 | 19543 589 | 19422 590 | 19142 591 | 18897 592 | 18950 593 | 19425 594 | 19002 595 | 19588 596 | 18978 597 | 19551 598 | 18930 599 | 18736 600 | 19101 601 | 19215 602 | 19150 603 | 19263 604 | 18949 605 | 18974 606 | 18759 607 | 19335 608 | 19200 609 | 19129 610 | 19328 611 | 19437 612 | 18988 613 | 19429 614 | 19368 615 | 19406 616 | 19049 617 | 18811 618 | 19296 619 | 19256 620 | 19385 621 | 19602 622 | 18770 623 | 19337 624 | 19580 625 | 19476 626 | 19045 627 | 19132 628 | 19089 629 | 19120 630 | 19265 631 | 19483 632 | 18767 633 | 19227 634 | 18934 635 | 19069 636 | 18820 637 | 19006 638 | 19459 639 | 18927 640 | 19037 641 | 19280 642 | 19441 643 | 18823 644 | 19015 645 | 19114 646 | 19618 647 | 18957 648 | 19176 649 | 18853 650 | 19648 651 | 19201 652 | 19444 653 | 19279 654 | 18751 655 | 19302 656 | 19505 657 | 18733 658 | 19601 659 | 19533 660 | 18863 661 | 19708 662 | 19387 663 | 19346 664 | 19152 665 | 19206 666 | 18851 667 | 19338 668 | 19681 669 | 19380 670 | 19055 671 | 18766 672 | 19085 673 | 19591 674 | 19547 675 | 18958 676 | 19146 677 | 18840 678 | 19051 679 | 19021 680 | 19207 681 | 19235 682 | 19086 683 | 18979 684 | 19300 685 | 18939 686 | 19100 687 | 19619 688 | 19287 689 | 18980 690 | 19277 691 | 19326 692 | 19108 693 | 18920 694 | 19625 695 | 19374 696 | 19078 697 | 18734 698 | 19634 699 | 19339 700 | 18877 701 | 19423 702 | 19652 703 | 19683 704 | 19044 705 | 18983 706 | 19330 707 | 19529 708 | 19714 709 | 19468 710 | 19075 711 | 19540 712 | 18839 713 | 19022 714 | 19286 715 | 19537 716 | 19175 717 | 19463 718 | 19167 719 | 19705 720 | 19562 721 | 
19244 722 | 19486 723 | 19611 724 | 18801 725 | 19178 726 | 19590 727 | 18846 728 | 19450 729 | 19205 730 | 19381 731 | 18941 732 | 19670 733 | 19185 734 | 19504 735 | 19633 736 | 18997 737 | 19113 738 | 19397 739 | 19636 740 | 19709 741 | 19289 742 | 19264 743 | 19353 744 | 19584 745 | 19126 746 | 18938 747 | 19669 748 | 18964 749 | 19276 750 | 18774 751 | 19173 752 | 19231 753 | 18973 754 | 18769 755 | 19064 756 | 19040 757 | 19668 758 | 18738 759 | 19082 760 | 19655 761 | 19236 762 | 19352 763 | 19609 764 | 19628 765 | 18951 766 | 19384 767 | 19122 768 | 18875 769 | 18992 770 | 18753 771 | 19379 772 | 19254 773 | 19301 774 | 19506 775 | 19135 776 | 19010 777 | 19682 778 | 19400 779 | 19579 780 | 19316 781 | 19553 782 | 19208 783 | 19635 784 | 19644 785 | 18891 786 | 19024 787 | 18989 788 | 19250 789 | 18850 790 | 19317 791 | 18915 792 | 19607 793 | 18799 794 | 18881 795 | 19479 796 | 19031 797 | 19365 798 | 19164 799 | 18744 800 | 18760 801 | 19502 802 | 19058 803 | 19517 804 | 18735 805 | 19448 806 | 19243 807 | 19453 808 | 19285 809 | 18857 810 | 19439 811 | 19016 812 | 18975 813 | 19503 814 | 18998 815 | 18981 816 | 19186 817 | 18994 818 | 19240 819 | 19631 820 | 19070 821 | 19174 822 | 18900 823 | 19065 824 | 19220 825 | 19229 826 | 18880 827 | 19308 828 | 19372 829 | 19496 830 | 18771 831 | 19325 832 | 19538 833 | 19033 834 | 18874 835 | 19077 836 | 19211 837 | 18764 838 | 19458 839 | 19571 840 | 19121 841 | 19019 842 | 19059 843 | 19497 844 | 18969 845 | 19666 846 | 19297 847 | 19219 848 | 19622 849 | 19184 850 | 18977 851 | 19702 852 | 19539 853 | 19329 854 | 19095 855 | 19675 856 | 18972 857 | 19514 858 | 19703 859 | 19188 860 | 18866 861 | 18812 862 | 19314 863 | 18822 864 | 18845 865 | 19494 866 | 19411 867 | 18916 868 | 19686 869 | 18967 870 | 19294 871 | 19143 872 | 19204 873 | 18805 874 | 19689 875 | 19233 876 | 18758 877 | 18748 878 | 19011 879 | 19685 880 | 19336 881 | 19608 882 | 19454 883 | 19124 884 | 18868 885 | 18807 886 | 19544 887 | 19621 888 | 19228 889 | 19154 890 | 19141 891 | 19145 892 | 19153 893 | 18860 894 | 19163 895 | 19393 896 | 19268 897 | 19160 898 | 19305 899 | 19259 900 | 19471 901 | 19524 902 | 18783 903 | 19396 904 | 18894 905 | 19430 906 | 19690 907 | 19348 908 | 19597 909 | 19592 910 | 19677 911 | 18889 912 | 19331 913 | 18773 914 | 19137 915 | 19009 916 | 18932 917 | 19599 918 | 18816 919 | 19054 920 | 19067 921 | 19477 922 | 19191 923 | 18921 924 | 18940 925 | 19578 926 | 19183 927 | 19004 928 | 19072 929 | 19710 930 | 19005 931 | 19610 932 | 18955 933 | 19457 934 | 19148 935 | 18859 936 | 18993 937 | 19642 938 | 19047 939 | 19418 940 | 19535 941 | 19600 942 | 19312 943 | 19039 944 | 19028 945 | 18879 946 | 19003 947 | 19026 948 | 19013 949 | 19149 950 | 19177 951 | 19217 952 | 18987 953 | 19354 954 | 19525 955 | 19202 956 | 19084 957 | 19032 958 | 18749 959 | 18867 960 | 19048 961 | 18999 962 | 19260 963 | 19630 964 | 18727 965 | 19356 966 | 19083 967 | 18926 968 | 18789 969 | 19370 970 | 18861 971 | 19311 972 | 19557 973 | 19531 974 | 19436 975 | 19140 976 | 19310 977 | 19501 978 | 18721 979 | 19604 980 | 19713 981 | 19262 982 | 19563 983 | 19507 984 | 19440 985 | 19572 986 | 19513 987 | 19515 988 | 19518 989 | 19421 990 | 19470 991 | 19499 992 | 19663 993 | 19508 994 | 18871 995 | 19528 996 | 19500 997 | 19307 998 | 19288 999 | 19594 1000 | 19271 1001 | -------------------------------------------------------------------------------- /data/Planetoid/pubmed/raw/ind.pubmed.tx: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/marlin-codes/hyperbolicTransformer/798e8fd5e0698cd3c548562dffc00cee53c0f09a/data/Planetoid/pubmed/raw/ind.pubmed.tx -------------------------------------------------------------------------------- /data/Planetoid/pubmed/raw/ind.pubmed.ty: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marlin-codes/hyperbolicTransformer/798e8fd5e0698cd3c548562dffc00cee53c0f09a/data/Planetoid/pubmed/raw/ind.pubmed.ty -------------------------------------------------------------------------------- /data/Planetoid/pubmed/raw/ind.pubmed.x: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marlin-codes/hyperbolicTransformer/798e8fd5e0698cd3c548562dffc00cee53c0f09a/data/Planetoid/pubmed/raw/ind.pubmed.x -------------------------------------------------------------------------------- /data/Planetoid/pubmed/raw/ind.pubmed.y: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marlin-codes/hyperbolicTransformer/798e8fd5e0698cd3c548562dffc00cee53c0f09a/data/Planetoid/pubmed/raw/ind.pubmed.y -------------------------------------------------------------------------------- /data/hgcn_data/airport/airport.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marlin-codes/hyperbolicTransformer/798e8fd5e0698cd3c548562dffc00cee53c0f09a/data/hgcn_data/airport/airport.p -------------------------------------------------------------------------------- /data/hgcn_data/airport/airport_alldata.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marlin-codes/hyperbolicTransformer/798e8fd5e0698cd3c548562dffc00cee53c0f09a/data/hgcn_data/airport/airport_alldata.p -------------------------------------------------------------------------------- /data/hgcn_data/disease_nc/disease_nc.feats.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marlin-codes/hyperbolicTransformer/798e8fd5e0698cd3c548562dffc00cee53c0f09a/data/hgcn_data/disease_nc/disease_nc.feats.npz -------------------------------------------------------------------------------- /data/hgcn_data/disease_nc/disease_nc.labels.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marlin-codes/hyperbolicTransformer/798e8fd5e0698cd3c548562dffc00cee53c0f09a/data/hgcn_data/disease_nc/disease_nc.labels.npy -------------------------------------------------------------------------------- /data/mini_imagenet/mini_imagenet.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marlin-codes/hyperbolicTransformer/798e8fd5e0698cd3c548562dffc00cee53c0f09a/data/mini_imagenet/mini_imagenet.pkl -------------------------------------------------------------------------------- /figures/framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marlin-codes/hyperbolicTransformer/798e8fd5e0698cd3c548562dffc00cee53c0f09a/figures/framework.jpg -------------------------------------------------------------------------------- /large/data_utils.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | from collections import defaultdict 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | import numpy as np 7 | from scipy import sparse as sp 8 | from sklearn.metrics import roc_auc_score, f1_score 9 | 10 | from torch_sparse import SparseTensor 11 | from google_drive_downloader import GoogleDriveDownloader as gdd 12 | 13 | def rand_train_test_idx(label, train_prop=.5, valid_prop=.25, ignore_negative=True): 14 | """ randomly splits label into train/valid/test splits """ 15 | if ignore_negative: 16 | labeled_nodes = torch.where(label != -1)[0] 17 | else: 18 | labeled_nodes = label 19 | 20 | n = labeled_nodes.shape[0] 21 | train_num = int(n * train_prop) 22 | valid_num = int(n * valid_prop) 23 | 24 | perm = torch.as_tensor(np.random.permutation(n)) 25 | 26 | train_indices = perm[:train_num] 27 | val_indices = perm[train_num:train_num + valid_num] 28 | test_indices = perm[train_num + valid_num:] 29 | 30 | if not ignore_negative: 31 | return train_indices, val_indices, test_indices 32 | 33 | train_idx = labeled_nodes[train_indices] 34 | valid_idx = labeled_nodes[val_indices] 35 | test_idx = labeled_nodes[test_indices] 36 | 37 | return train_idx, valid_idx, test_idx 38 | 39 | def load_fixed_splits(data_dir, dataset, name, protocol): 40 | splits_lst = [] 41 | if name in ['cora', 'citeseer', 'pubmed'] and protocol == 'semi': 42 | splits = {} 43 | splits['train'] = torch.as_tensor(dataset.train_idx) 44 | splits['valid'] = torch.as_tensor(dataset.valid_idx) 45 | splits['test'] = torch.as_tensor(dataset.test_idx) 46 | splits_lst.append(splits) 47 | elif name in ['cora', 'citeseer', 'pubmed', 'chameleon', 'squirrel', 'film', 'cornell', 'texas', 'wisconsin']: 48 | for i in range(10): 49 | splits_file_path = '{}/geom-gcn/splits/{}'.format(data_dir, name) + '_split_0.6_0.2_'+str(i)+'.npz' 50 | splits = {} 51 | with np.load(splits_file_path) as splits_file: 52 | splits['train'] = torch.BoolTensor(splits_file['train_mask']) 53 | splits['valid'] = torch.BoolTensor(splits_file['val_mask']) 54 | splits['test'] = torch.BoolTensor(splits_file['test_mask']) 55 | splits_lst.append(splits) 56 | else: 57 | raise NotImplementedError 58 | 59 | return splits_lst 60 | 61 | def class_rand_splits(label, label_num_per_class, valid_num=500, test_num=1000): 62 | train_idx, non_train_idx = [], [] 63 | idx = torch.arange(label.shape[0]) 64 | class_list = label.squeeze().unique() 65 | for i in range(class_list.shape[0]): 66 | c_i = class_list[i] 67 | idx_i = idx[label.squeeze() == c_i] 68 | n_i = idx_i.shape[0] 69 | rand_idx = idx_i[torch.randperm(n_i)] 70 | train_idx += rand_idx[:label_num_per_class].tolist() 71 | non_train_idx += rand_idx[label_num_per_class:].tolist() 72 | train_idx = torch.as_tensor(train_idx) 73 | non_train_idx = torch.as_tensor(non_train_idx) 74 | non_train_idx = non_train_idx[torch.randperm(non_train_idx.shape[0])] 75 | valid_idx, test_idx = non_train_idx[:valid_num], non_train_idx[valid_num:valid_num+test_num] 76 | 77 | return train_idx, valid_idx, test_idx 78 |
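
Editorial note on the two split helpers above: `rand_train_test_idx` draws a random train/valid/test partition over the labeled nodes (entries equal to -1 are treated as unlabeled), while `class_rand_splits` fixes a per-class budget of training nodes. A minimal usage sketch, assuming `data_utils.py` and its imports (e.g. `torch_sparse`) are importable; the tensors here are invented toy data:

```python
# Editorial sketch, not part of data_utils.py; assumes data_utils is importable.
import torch
from data_utils import rand_train_test_idx, class_rand_splits  # assumed import path

label = torch.tensor([0, 1, -1, 2, 1, 0, 2, 1])  # -1 marks unlabeled nodes
train_idx, valid_idx, test_idx = rand_train_test_idx(label, train_prop=0.5, valid_prop=0.25)
print(train_idx.numel(), valid_idx.numel(), test_idx.numel())  # 3 1 3 -- only the 7 labeled nodes are split

label2 = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2])
tr, va, te = class_rand_splits(label2, label_num_per_class=1, valid_num=2, test_num=4)
# exactly one training node per class; the remaining nodes are shuffled into valid/test
```
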
79 | def even_quantile_labels(vals, nclasses, verbose=True): 80 | """ partitions vals into nclasses by a quantile based split, 81 | where the first class is less than the 1/nclasses quantile, 82 | the second class is less than the 2/nclasses quantile, and so on 83 | 84 | vals is np array 85 | returns an np array of int class labels 86 | """ 87 | label = -1 * np.ones(vals.shape[0], dtype=int)  # plain int: np.int was removed in NumPy >= 1.24 88 | interval_lst = [] 89 | lower = -np.inf 90 | for k in range(nclasses - 1): 91 | upper = np.quantile(vals, (k + 1) / nclasses) 92 | interval_lst.append((lower, upper)) 93 | inds = (vals >= lower) * (vals < upper) 94 | label[inds] = k 95 | lower = upper 96 | label[vals >= lower] = nclasses - 1 97 | interval_lst.append((lower, np.inf)) 98 | if verbose: 99 | print('Class Label Intervals:') 100 | for class_idx, interval in enumerate(interval_lst): 101 | print(f'Class {class_idx}: [{interval[0]}, {interval[1]})') 102 | return label 103 | 104 | 105 | def to_planetoid(dataset): 106 | """ 107 | Takes in a NCDataset and returns the dataset in H2GCN Planetoid form, as follows: 108 | x => the feature vectors of the training instances as scipy.sparse.csr.csr_matrix object; 109 | tx => the feature vectors of the test instances as scipy.sparse.csr.csr_matrix object; 110 | allx => the feature vectors of both labeled and unlabeled training instances 111 | (a superset of ind.dataset_str.x) as scipy.sparse.csr.csr_matrix object; 112 | y => the one-hot labels of the labeled training instances as numpy.ndarray object; 113 | ty => the one-hot labels of the test instances as numpy.ndarray object; 114 | ally => the labels for instances in ind.dataset_str.allx as numpy.ndarray object; 115 | graph => a dict in the format {index: [index_of_neighbor_nodes]} as collections.defaultdict 116 | object; 117 | split_idx => The ogb dictionary that contains the train, valid, test splits 118 | """ 119 | split_idx = dataset.get_idx_split('random', 0.25) 120 | train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"] 121 | 122 | graph, label = dataset[0] 123 | 124 | label = torch.squeeze(label) 125 | 126 | print("generate x") 127 | x = graph['node_feat'][train_idx].numpy() 128 | x = sp.csr_matrix(x) 129 | 130 | tx = graph['node_feat'][test_idx].numpy() 131 | tx = sp.csr_matrix(tx) 132 | 133 | allx = graph['node_feat'].numpy() 134 | allx = sp.csr_matrix(allx) 135 | 136 | y = F.one_hot(label[train_idx]).numpy() 137 | ty = F.one_hot(label[test_idx]).numpy() 138 | ally = F.one_hot(label).numpy() 139 | 140 | edge_index = graph['edge_index'].T 141 | 142 | graph = defaultdict(list) 143 | 144 | for i in range(0, label.shape[0]): 145 | graph[i].append(i) 146 | 147 | for start_edge, end_edge in edge_index: 148 | graph[start_edge.item()].append(end_edge.item()) 149 | 150 | return x, tx, allx, y, ty, ally, graph, split_idx 151 |
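
To make the quantile bucketing above concrete: `even_quantile_labels` turns a continuous array into `nclasses` roughly equal-sized integer classes. A toy sketch (editorial, not a repository file; the function is assumed importable from `data_utils`):

```python
# Editorial sketch, not part of data_utils.py.
import numpy as np
from data_utils import even_quantile_labels  # assumed import path

vals = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])
labels = even_quantile_labels(vals, nclasses=3, verbose=False)
# the 1/3 and 2/3 quantile cut points fall near 0.37 and 0.63,
# so each third of the sorted values gets its own class:
print(labels)  # [0 0 0 1 1 1 2 2 2]
```
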
152 | 153 | def to_sparse_tensor(edge_index, edge_feat, num_nodes): 154 | """ converts the edge_index into SparseTensor 155 | """ 156 | num_edges = edge_index.size(1) 157 | 158 | (row, col), N, E = edge_index, num_nodes, num_edges 159 | perm = (col * N + row).argsort() 160 | row, col = row[perm], col[perm] 161 | 162 | value = edge_feat[perm] 163 | adj_t = SparseTensor(row=col, col=row, value=value, 164 | sparse_sizes=(N, N), is_sorted=True) 165 | 166 | # Pre-process some important attributes. 167 | adj_t.storage.rowptr() 168 | adj_t.storage.csr2csc() 169 | 170 | return adj_t 171 | 172 | 173 | def normalize(edge_index): 174 | """ normalizes the edge_index 175 | """ 176 | adj_t = edge_index.set_diag() 177 | deg = adj_t.sum(dim=1).to(torch.float) 178 | deg_inv_sqrt = deg.pow(-0.5) 179 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 180 | adj_t = deg_inv_sqrt.view(-1, 1) * adj_t * deg_inv_sqrt.view(1, -1) 181 | return adj_t 182 | 183 | 184 | def gen_normalized_adjs(dataset): 185 | """ returns the normalized adjacency matrix 186 | """ 187 | row, col = dataset.graph['edge_index'] 188 | N = dataset.graph['num_nodes'] 189 | adj = SparseTensor(row=row, col=col, sparse_sizes=(N, N)) 190 | deg = adj.sum(dim=1).to(torch.float) 191 | D_isqrt = deg.pow(-0.5) 192 | D_isqrt[D_isqrt == float('inf')] = 0 193 | 194 | DAD = D_isqrt.view(-1,1) * adj * D_isqrt.view(1,-1) 195 | DA = D_isqrt.view(-1,1) * D_isqrt.view(-1,1) * adj 196 | AD = adj * D_isqrt.view(1,-1) * D_isqrt.view(1,-1) 197 | return DAD, DA, AD 198 | 199 | def eval_f1(y_true, y_pred): 200 | acc_list = [] 201 | y_true = y_true.detach().cpu().numpy() 202 | y_pred = y_pred.argmax(dim=-1, keepdim=True).detach().cpu().numpy() 203 | 204 | for i in range(y_true.shape[1]): 205 | f1 = f1_score(y_true[:, i], y_pred[:, i], average='micro')  # score column i, mirroring eval_acc below 206 | acc_list.append(f1) 207 | 208 | return sum(acc_list)/len(acc_list) 209 | 210 | def eval_acc(y_true, y_pred): 211 | acc_list = [] 212 | y_true = y_true.detach().cpu().numpy() 213 | y_pred = y_pred.argmax(dim=-1, keepdim=True).detach().cpu().numpy() 214 | 215 | for i in range(y_true.shape[1]): 216 | is_labeled = y_true[:, i] == y_true[:, i] 217 | correct = y_true[is_labeled, i] == y_pred[is_labeled, i] 218 | acc_list.append(float(np.sum(correct))/len(correct)) 219 | 220 | return sum(acc_list)/len(acc_list) 221 | 222 | 223 | def eval_rocauc(y_true, y_pred): 224 | """ adapted from ogb 225 | https://github.com/snap-stanford/ogb/blob/master/ogb/nodeproppred/evaluate.py""" 226 | rocauc_list = [] 227 | y_true = y_true.detach().cpu().numpy() 228 | if y_true.shape[1] == 1: 229 | # use the predicted class for single-class classification 230 | y_pred = F.softmax(y_pred, dim=-1)[:,1].unsqueeze(1).cpu().numpy() 231 | else: 232 | y_pred = y_pred.detach().cpu().numpy() 233 | 234 | for i in range(y_true.shape[1]): 235 | # AUC is only defined when there is at least one positive data. 236 | if np.sum(y_true[:, i] == 1) > 0 and np.sum(y_true[:, i] == 0) > 0: 237 | is_labeled = y_true[:, i] == y_true[:, i] 238 | score = roc_auc_score(y_true[is_labeled, i], y_pred[is_labeled, i]) 239 | 240 | rocauc_list.append(score) 241 | 242 | if len(rocauc_list) == 0: 243 | raise RuntimeError( 244 | 'No positively labeled data available. Cannot compute ROC-AUC.') 245 | 246 | return sum(rocauc_list)/len(rocauc_list) 247 | 248 | def convert_to_adj(edge_index,n_node): 249 | '''convert from pyg format edge_index to n by n adj matrix''' 250 | adj=torch.zeros((n_node,n_node)) 251 | row,col=edge_index 252 | adj[row,col]=1 253 | return adj 254 | 255 | def adj_mul(adj_i, adj, N): 256 | adj_i_sp = torch.sparse_coo_tensor(adj_i, torch.ones(adj_i.shape[1], dtype=torch.float).to(adj.device), (N, N)) 257 | adj_sp = torch.sparse_coo_tensor(adj, torch.ones(adj.shape[1], dtype=torch.float).to(adj.device), (N, N)) 258 | adj_j = torch.sparse.mm(adj_i_sp, adj_sp) 259 | adj_j = adj_j.coalesce().indices() 260 | return adj_j 261 |
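
`adj_mul` above composes two adjacencies in edge-index form through one sparse matrix multiplication, which is how higher-hop edge indices are built for the GNN branch. A small sketch (editorial, not a repository file; `adj_mul` is assumed importable from `data_utils`):

```python
# Editorial sketch, not part of data_utils.py.
import torch
from data_utils import adj_mul  # assumed import path

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])  # undirected path 0 - 1 - 2
two_hop = adj_mul(edge_index, edge_index, N=3)
# indices of A @ A: pairs reachable in exactly two steps, including
# walks that return to the start node, e.g. (0, 2), (2, 0) and (1, 1)
print(two_hop)
```
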
262 | import subprocess 263 | def get_gpu_memory_map(): 264 | """Get the current gpu usage. 265 | Returns 266 | ------- 267 | usage: numpy.ndarray 268 | Indexed by device id as integers. 269 | Values are memory usage as integers in MB. 270 | """ 271 | result = subprocess.check_output( 272 | [ 273 | 'nvidia-smi', '--query-gpu=memory.used', 274 | '--format=csv,nounits,noheader' 275 | ], encoding='utf-8') 276 | # Convert lines into an array, one entry per device 277 | gpu_memory = np.array([int(x) for x in result.strip().split('\n')]) 278 | # gpu_memory_map = dict(zip(range(len(gpu_memory)), gpu_memory)) 279 | return gpu_memory 280 | 300 | def count_parameters(model): 301 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 302 | 303 | 304 | def compute_degrees(edge_index, num_nodes): 305 | out_degrees = torch.bincount(edge_index[0, :], minlength=num_nodes) 306 | in_degrees = torch.bincount(edge_index[1, :], minlength=num_nodes) 307 | 308 | return out_degrees + in_degrees 309 | 310 | dataset_drive_url = { 311 | 'snap-patents' : '1ldh23TSY1PwXia6dU0MYcpyEgX-w3Hia', 312 | 'pokec' : '1dNs5E7BrWJbgcHeQ_zuy5Ozp2tRCWG0y', 313 | 'yelp-chi': '1fAXtTVQS4CfEk4asqrFw9EPmlUPGbGtJ', 314 | } 315 | 316 | splits_drive_url = { 317 | 'snap-patents' : '12xbBRqd8mtG_XkNLH8dRRNZJvVM4Pw-N', 318 | 'pokec' : '1ZhpAiyTNc0cE_hhgyiqxnkKREHK7MK-_', 319 | }
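
`compute_degrees` above is what the evaluation code below uses to bucket test nodes by degree. A toy check (editorial, not a repository file; the function is assumed importable from `data_utils`):

```python
# Editorial sketch, not part of data_utils.py.
import torch
from data_utils import compute_degrees  # assumed import path

edge_index = torch.tensor([[0, 0, 1],
                           [1, 2, 2]])  # directed edges 0->1, 0->2, 1->2
deg = compute_degrees(edge_index, num_nodes=3)
print(deg)  # tensor([2, 2, 2]) -- out-degree plus in-degree per node
```
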
-------------------------------------------------------------------------------- /large/eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from torch_geometric.utils import subgraph 5 | 6 | @torch.no_grad() 7 | def evaluate(model, dataset, split_idx, eval_func, criterion, args, degrees, threshold, result=None): 8 | if result is not None: 9 | out = result 10 | else: 11 | model.eval() 12 | out = model(dataset.graph['node_feat'], dataset.graph['edge_index']) 13 | 14 | train_acc = eval_func( 15 | dataset.label[split_idx['train']], out[split_idx['train']]) 16 | valid_acc = eval_func( 17 | dataset.label[split_idx['valid']], out[split_idx['valid']]) 18 | test_acc = eval_func( 19 | dataset.label[split_idx['test']], out[split_idx['test']]) 20 | degrees_in_test = degrees[split_idx['test']] 21 | top20_indices = split_idx['test'][degrees_in_test > threshold] 22 | bottom80_indices = split_idx['test'][degrees_in_test <= threshold]  # <= so nodes with degree == threshold are not dropped, matching evaluate_large 23 | top20_acc = eval_func( 24 | dataset.label[top20_indices], out[top20_indices] 25 | ) 26 | bottom80_acc = eval_func( 27 | dataset.label[bottom80_indices], out[bottom80_indices] 28 | ) 29 | 30 | if args.dataset in ('yelp-chi', 'deezer-europe', 'twitch-e', 'fb100', 'ogbn-proteins'): 31 | if dataset.label.shape[1] == 1: 32 | true_label = F.one_hot(dataset.label, dataset.label.max() + 1).squeeze(1) 33 | else: 34 | true_label = dataset.label 35 | valid_loss = criterion(out[split_idx['valid']], true_label.squeeze(1)[ 36 | split_idx['valid']].to(torch.float)) 37 | else: 38 | out = F.log_softmax(out, dim=1) 39 | valid_loss = criterion( 40 | out[split_idx['valid']], dataset.label.squeeze(1)[split_idx['valid']]) 41 | 42 | return train_acc, valid_acc, test_acc, valid_loss, out, top20_acc, bottom80_acc 43 | 44 | @torch.no_grad() 45 | def evaluate_large(model, dataset, split_idx, eval_func, criterion, args, degrees, threshold, device="cpu", result=None): 46 | if result is not None: 47 | out = result 48 | else: 49 | model.eval() 50 | 51 | model.to(torch.device(device)) 52 | dataset.label = dataset.label.to(torch.device(device)) 53 | edge_index, x = dataset.graph['edge_index'].to(torch.device(device)), dataset.graph['node_feat'].to(torch.device(device)) 54 | out = model(x, edge_index) 55 | 56 | train_acc = eval_func( 57 | dataset.label[split_idx['train']], out[split_idx['train']]) 58 | valid_acc = eval_func( 59 | dataset.label[split_idx['valid']], out[split_idx['valid']]) 60 | test_acc = eval_func( 61 | dataset.label[split_idx['test']], out[split_idx['test']]) 62 | degrees_in_test = degrees[split_idx['test']] 63 | top20_indices = split_idx['test'][degrees_in_test > threshold] 64 | bottom80_indices = split_idx['test'][degrees_in_test <= threshold] 65 | 66 | top20_acc = eval_func( 67 | dataset.label[top20_indices], out[top20_indices] 68 | ) * top20_indices.shape[0] / split_idx['test'].shape[0] 69 | bottom80_acc = eval_func( 70 | dataset.label[bottom80_indices], out[bottom80_indices] 71 | ) * bottom80_indices.shape[0] / split_idx['test'].shape[0] 72 | # top_acc = test_acc[] 73 | if args.dataset in ('yelp-chi', 'deezer-europe', 'twitch-e', 'fb100', 'ogbn-proteins'): 74 | if dataset.label.shape[1] == 1: 75 | true_label = F.one_hot(dataset.label, dataset.label.max() + 1).squeeze(1) 76 | else: 77 | true_label = dataset.label 78 | valid_loss = criterion(out[split_idx['valid']], true_label.squeeze(1)[ 79 | split_idx['valid']].to(torch.float)) 80 | else: 81 | out = F.log_softmax(out, dim=1) 82 | valid_loss = criterion( 83 | out[split_idx['valid']], dataset.label.squeeze(1)[split_idx['valid']]) 84 | 85 | return train_acc, valid_acc, test_acc, valid_loss, valid_loss, top20_acc, bottom80_acc 86 | 87 | def evaluate_batch(model, dataset, split_idx, args, device, n, true_label): 88 | num_batch = n // args.batch_size + 1 89 | edge_index, x = dataset.graph['edge_index'], dataset.graph['node_feat'] 90 | train_mask = torch.zeros(n, dtype=torch.bool) 91 | train_mask[split_idx['train']] = True 92 | valid_mask = torch.zeros(n, dtype=torch.bool) 93 | valid_mask[split_idx['valid']] = True 94 | test_mask = torch.zeros(n, dtype=torch.bool) 95 | test_mask[split_idx['test']] = True 96 | 97 | model.to(device) 98 | model.eval() 99 | 100 | idx = torch.randperm(n) 101 | train_total, train_correct=0, 0 102 | valid_total, valid_correct=0, 0 103 | test_total, test_correct=0, 0 104 | 105 | with torch.no_grad(): 106 | for i in range(num_batch): 107 | idx_i = idx[i*args.batch_size:(i+1)*args.batch_size] 108 | x_i = x[idx_i].to(device) 109 | edge_index_i, _ = subgraph(idx_i, edge_index, num_nodes=n, relabel_nodes=True) 110 | edge_index_i = edge_index_i.to(device) 111 | y_i = true_label[idx_i].to(device) 112 | train_mask_i = train_mask[idx_i] 113 | valid_mask_i = valid_mask[idx_i] 114 | test_mask_i = test_mask[idx_i] 115 | 116 | out_i = model(x_i, edge_index_i) 117 | 118 | cur_train_total, cur_train_correct=eval_acc(y_i[train_mask_i], out_i[train_mask_i]) 119 | train_total+=cur_train_total 120 | train_correct+=cur_train_correct 121 | cur_valid_total, cur_valid_correct=eval_acc(y_i[valid_mask_i], out_i[valid_mask_i]) 122 | valid_total+=cur_valid_total 123 | 
valid_correct+=cur_valid_correct 124 | cur_test_total, cur_test_correct=eval_acc(y_i[test_mask_i], out_i[test_mask_i]) 125 | test_total+=cur_test_total 126 | test_correct+=cur_test_correct 127 | 128 | # train_acc = eval_func( 129 | # dataset.label[split_idx['train']], out[split_idx['train']]) 130 | # valid_acc = eval_func( 131 | # dataset.label[split_idx['valid']], out[split_idx['valid']]) 132 | # test_acc = eval_func( 133 | # dataset.label[split_idx['test']], out[split_idx['test']]) 134 | train_acc=train_correct/train_total 135 | valid_acc=valid_correct/valid_total 136 | test_acc=test_correct/test_total 137 | 138 | return train_acc, valid_acc, test_acc, 0, None 139 | 140 | def eval_acc(true, pred): 141 | ''' 142 | true: (n, 1) 143 | pred: (n, c) 144 | ''' 145 | pred=torch.max(pred,dim=1,keepdim=True)[1] 146 | # cmp=torch.eq(true, pred) 147 | # print(f'pred:{pred}') 148 | # print(cmp) 149 | true_cnt=(true==pred).sum() 150 | 151 | return true.shape[0], true_cnt.item() 152 | 153 | 154 | if __name__=='__main__': 155 | x=torch.arange(4).unsqueeze(1) 156 | y=torch.Tensor([[3,0,0,0], 157 | [3,2,1.5,2.8], 158 | [0,0,2,1], 159 | [0,0,1,3] 160 | ]) 161 | a, b=eval_acc(x, y) 162 | print(x) 163 | print(a,b) 164 | -------------------------------------------------------------------------------- /large/examples/5-runs/run_amazon2M.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | graph_weight=0.2 3 | k_in=0.5 4 | k_out=1.0 5 | power_k=2.0 6 | lr=0.005 7 | 8 | python main-batch.py \ 9 | --method hypformer \ 10 | --dataset amazon2m \ 11 | --metric acc \ 12 | --lr $lr \ 13 | --hidden_channels 256 \ 14 | --gnn_num_layers 3 \ 15 | --gnn_dropout 0.0 \ 16 | --weight_decay 0. \ 17 | --gnn_use_residual 1 \ 18 | --gnn_use_weight 1 \ 19 | --gnn_use_bn 1 \ 20 | --gnn_use_init 1 \ 21 | --gnn_use_act 1 \ 22 | --trans_num_layers 1 \ 23 | --trans_dropout 0. \ 24 | --trans_use_residual 1 \ 25 | --trans_use_weight 1 \ 26 | --trans_use_bn 1 \ 27 | --use_graph 1 \ 28 | --graph_weight $graph_weight \ 29 | --batch_size 100000 \ 30 | --seed 123 \ 31 | --runs 5 \ 32 | --epochs 200 \ 33 | --eval_step 1 \ 34 | --device 0 \ 35 | --k_in $k_in \ 36 | --k_out $k_out \ 37 | --power_k $power_k \ 38 | --attention_type linear_focused 39 | 40 | -------------------------------------------------------------------------------- /large/examples/5-runs/run_arxiv.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | hidden_channel=256 4 | lr=0.005 5 | graph_weight=0.2 6 | k_in=2.0 7 | k_out=0.5 8 | weight_decay=0.0 9 | 10 | python main.py \ 11 | --method hypformer \ 12 | --dataset ogbn-arxiv \ 13 | --metric acc \ 14 | --lr $lr \ 15 | --hidden_channels $hidden_channel \ 16 | --gnn_num_layers 3 \ 17 | --gnn_dropout 0.4 \ 18 | --gnn_use_residual 1 \ 19 | --gnn_use_weight 1 \ 20 | --gnn_use_bn 1 \ 21 | --gnn_use_act 1 \ 22 | --trans_num_layers 1 \ 23 | --trans_dropout 0. 
\ 24 | --weight_decay $weight_decay \ 25 | --trans_use_residual 1 \ 26 | --trans_use_weight 1 \ 27 | --trans_num_heads 2 \ 28 | --use_graph 1 \ 29 | --graph_weight $graph_weight \ 30 | --seed 123 \ 31 | --runs 5 \ 32 | --epochs 1000 \ 33 | --eval_step 1 \ 34 | --device 0 \ 35 | --k_in $k_in \ 36 | --k_out $k_out \ 37 | --attention_type linear_focused 38 | -------------------------------------------------------------------------------- /large/examples/5-runs/run_protein.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | hidden_channels=256 4 | lr=0.001 5 | graph_weight=0.2 6 | k_in=0.5 7 | k_out=1.0 8 | gnn_dropout=0 9 | trans_dropout=0.0 10 | weight_decay=0.0 11 | batch_size=10000 12 | power_k=1.0 13 | 14 | python main-batch.py \ 15 | --method hypformer \ 16 | --dataset ogbn-proteins \ 17 | --metric rocauc \ 18 | --lr $lr \ 19 | --hidden_channels $hidden_channels \ 20 | --gnn_num_layers 2 \ 21 | --gnn_dropout $gnn_dropout \ 22 | --gnn_use_residual 1 \ 23 | --gnn_use_weight 1 \ 24 | --gnn_use_bn 1 \ 25 | --gnn_use_act 1 \ 26 | --trans_num_layers 1 \ 27 | --trans_num_heads 1 \ 28 | --trans_dropout $trans_dropout \ 29 | --weight_decay $weight_decay \ 30 | --trans_use_residual 1 \ 31 | --trans_use_weight 1 \ 32 | --graph_weight $graph_weight \ 33 | --batch_size $batch_size \ 34 | --seed 123 \ 35 | --runs 5 \ 36 | --epochs 500 \ 37 | --eval_step 5 \ 38 | --device 0 \ 39 | --power_k $power_k \ 40 | --data_dir $hypformer_data_dir \ 41 | --decoder_type euc \ 42 | --attention_type linear_focused -------------------------------------------------------------------------------- /large/examples/amazon2M.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | graph_weight=0.2 3 | k_in=0.5 4 | k_out=1.0 5 | power_k=2.0 6 | lr=0.005 7 | 8 | python main-batch.py \ 9 | --method hypformer \ 10 | --dataset amazon2m \ 11 | --metric acc \ 12 | --lr $lr \ 13 | --hidden_channels 256 \ 14 | --gnn_num_layers 3 \ 15 | --gnn_dropout 0.0 \ 16 | --weight_decay 0. \ 17 | --gnn_use_residual 1 \ 18 | --gnn_use_weight 1 \ 19 | --gnn_use_bn 1 \ 20 | --gnn_use_init 1 \ 21 | --gnn_use_act 1 \ 22 | --trans_num_layers 1 \ 23 | --trans_dropout 0. \ 24 | --trans_use_residual 1 \ 25 | --trans_use_weight 1 \ 26 | --trans_use_bn 1 \ 27 | --use_graph 1 \ 28 | --graph_weight $graph_weight \ 29 | --batch_size 100000 \ 30 | --seed 123 \ 31 | --runs 1 \ 32 | --save_result 0 \ 33 | --epochs 200 \ 34 | --eval_step 1 \ 35 | --device 0 \ 36 | --k_in $k_in \ 37 | --k_out $k_out \ 38 | --power_k $power_k \ 39 | --attention_type linear_focused 40 | 41 | -------------------------------------------------------------------------------- /large/examples/arxiv.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | hidden_channel=256 4 | lr=0.005 5 | graph_weight=0.2 6 | k_in=2.0 7 | k_out=0.5 8 | weight_decay=0.0 9 | 10 | python main.py \ 11 | --method hypformer \ 12 | --dataset ogbn-arxiv \ 13 | --metric acc \ 14 | --lr $lr \ 15 | --hidden_channels $hidden_channel \ 16 | --gnn_num_layers 3 \ 17 | --gnn_dropout 0.4 \ 18 | --gnn_use_residual 1 \ 19 | --gnn_use_weight 1 \ 20 | --gnn_use_bn 1 \ 21 | --gnn_use_act 1 \ 22 | --trans_num_layers 1 \ 23 | --trans_dropout 0. 
\ 24 | --weight_decay $weight_decay \ 25 | --trans_use_residual 1 \ 26 | --trans_use_weight 1 \ 27 | --trans_num_heads 2 \ 28 | --use_graph 1 \ 29 | --graph_weight $graph_weight \ 30 | --seed 123 \ 31 | --runs 1 \ 32 | --save_result 0 \ 33 | --epochs 1000 \ 34 | --eval_step 1 \ 35 | --device 0 \ 36 | --k_in $k_in \ 37 | --k_out $k_out \ 38 | --attention_type linear_focused 39 | -------------------------------------------------------------------------------- /large/examples/protein.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | hidden_channels=256 4 | lr=0.001 5 | graph_weight=0.2 6 | k_in=0.5 7 | k_out=1.0 8 | gnn_dropout=0 9 | trans_dropout=0.0 10 | weight_decay=0.0 11 | batch_size=10000 12 | power_k=1.0 13 | 14 | python main-batch.py \ 15 | --method hypformer \ 16 | --dataset ogbn-proteins \ 17 | --metric rocauc \ 18 | --lr $lr \ 19 | --hidden_channels $hidden_channels \ 20 | --gnn_num_layers 2 \ 21 | --gnn_dropout $gnn_dropout \ 22 | --gnn_use_residual 1 \ 23 | --gnn_use_weight 1 \ 24 | --gnn_use_bn 1 \ 25 | --gnn_use_act 1 \ 26 | --trans_num_layers 1 \ 27 | --trans_num_heads 1 \ 28 | --trans_dropout $trans_dropout \ 29 | --weight_decay $weight_decay \ 30 | --trans_use_residual 1 \ 31 | --trans_use_weight 1 \ 32 | --graph_weight $graph_weight \ 33 | --batch_size $batch_size \ 34 | --seed 123 \ 35 | --runs 1 \ 36 | --save_result 0 \ 37 | --epochs 500 \ 38 | --eval_step 5 \ 39 | --device 0 \ 40 | --power_k $power_k \ 41 | --data_dir $hypformer_data_dir \ 42 | --decoder_type euc \ 43 | --attention_type linear_focused -------------------------------------------------------------------------------- /large/load_data.py: -------------------------------------------------------------------------------- 1 | import scipy.io 2 | import numpy as np 3 | import scipy.sparse 4 | import torch 5 | import csv 6 | import json 7 | from os import path 8 | 9 | DATAPATH = '../../data/' 10 | 11 | def load_fb100(filename): 12 | # e.g. 
filename = Rutgers89 or Cornell5 or Wisconsin87 or Amherst41 13 | # columns are: student/faculty, gender, major, 14 | # second major/minor, dorm/house, year/ high school 15 | # 0 denotes missing entry 16 | mat = scipy.io.loadmat(DATAPATH + 'facebook100/' + filename + '.mat') 17 | A = mat['A'] 18 | metadata = mat['local_info'] 19 | return A, metadata 20 | 21 | def load_twitch(lang): 22 | assert lang in ('DE', 'ENGB', 'ES', 'FR', 'PTBR', 'RU', 'TW'), 'Invalid dataset' 23 | filepath = DATAPATH + f"twitch/{lang}" 24 | label = [] 25 | node_ids = [] 26 | src = [] 27 | targ = [] 28 | uniq_ids = set() 29 | with open(f"{filepath}/musae_{lang}_target.csv", 'r') as f: 30 | reader = csv.reader(f) 31 | next(reader) 32 | for row in reader: 33 | node_id = int(row[5]) 34 | # handle FR case of non-unique rows 35 | if node_id not in uniq_ids: 36 | uniq_ids.add(node_id) 37 | label.append(int(row[2]=="True")) 38 | node_ids.append(int(row[5])) 39 | 40 | node_ids = np.array(node_ids, dtype=int)  # plain int: np.int was removed in NumPy >= 1.24 41 | with open(f"{filepath}/musae_{lang}_edges.csv", 'r') as f: 42 | reader = csv.reader(f) 43 | next(reader) 44 | for row in reader: 45 | src.append(int(row[0])) 46 | targ.append(int(row[1])) 47 | with open(f"{filepath}/musae_{lang}_features.json", 'r') as f: 48 | j = json.load(f) 49 | src = np.array(src) 50 | targ = np.array(targ) 51 | label = np.array(label) 52 | inv_node_ids = {node_id:idx for (idx, node_id) in enumerate(node_ids)} 53 | reorder_node_ids = np.zeros_like(node_ids) 54 | for i in range(label.shape[0]): 55 | reorder_node_ids[i] = inv_node_ids[i] 56 | 57 | n = label.shape[0] 58 | A = scipy.sparse.csr_matrix((np.ones(len(src)), 59 | (np.array(src), np.array(targ))), 60 | shape=(n,n)) 61 | features = np.zeros((n,3170)) 62 | for node, feats in j.items(): 63 | if int(node) >= n: 64 | continue 65 | features[int(node), np.array(feats, dtype=int)] = 1 66 | # features = features[:, np.sum(features, axis=0) != 0] # remove zero cols. 
not need for cross graph task 67 | new_label = label[reorder_node_ids] 68 | label = new_label 69 | 70 | return A, label, features 71 | -------------------------------------------------------------------------------- /large/logger.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import wandb 3 | import csv 4 | import os 5 | from datetime import datetime 6 | 7 | def mkdirs(path): 8 | if not os.path.exists(path): 9 | os.makedirs(path) 10 | return path 11 | 12 | class Logger(object): 13 | """ Adapted from https://github.com/snap-stanford/ogb/ """ 14 | def __init__(self, runs, args=None): 15 | self.args = args 16 | self.results = [[] for _ in range(runs)] 17 | 18 | def add_result(self, run, result): 19 | assert len(result) == 7 20 | assert run >= 0 and run < len(self.results) 21 | self.results[run].append(result) 22 | 23 | @staticmethod 24 | def get_results_string(best_result): 25 | result_string = '' 26 | r = best_result[:, 0] 27 | result_string += f'Highest Train: {r.mean():.2f} ± {r.std():.2f}\t' 28 | r = best_result[:, 1] 29 | result_string += f'Highest Test: {r.mean():.2f} ± {r.std():.2f}\t' 30 | r = best_result[:, 2] 31 | result_string += f'Highest Valid: {r.mean():.2f} ± {r.std():.2f}\t' 32 | r = best_result[:, 3] 33 | result_string += f' Final Train: {r.mean():.2f} ± {r.std():.2f}\t' 34 | r = best_result[:, 4] 35 | result_string += f' Final Test: {r.mean():.2f} ± {r.std():.2f}' 36 | 37 | return result_string 38 | 39 | def print_statistics(self, run=None, mode='max_acc'): 40 | if run is not None: 41 | # Ensure all elements are tensors and convert them properly 42 | result = [torch.tensor(r) * 100 if isinstance(r, (float, int)) else torch.tensor(r) * 100 for r in self.results[run]] 43 | result = torch.stack(result) # Stack the list of tensors into a single tensor 44 | 45 | if self.args.save_whole_test_result: 46 | now = datetime.now() 47 | _month_day = now.strftime("%m%d") 48 | timestamp = now.strftime("%m%d-%H%M%S") 49 | results_path = mkdirs(f'results/runs/{self.args.dataset}/{_month_day}/{self.args.wandb_name}') 50 | with open(f'{results_path}/{run}-{self.args.run_id}-results.csv', 'w', newline='') as f: 51 | writer = csv.writer(f) 52 | # Write the header (optional) 53 | writer.writerow(["Epoch", "Train Acc", "Val Acc", "Test Acc", "Val Loss"]) 54 | 55 | # Write the data 56 | for epoch in range(len(self.results[run])): 57 | # Add 1 to run and epoch indices to match with human counting 58 | formatted_row = ['{:.4f}'.format(float(x)) for x in self.results[run][epoch]] 59 | writer.writerow([epoch * self.args.eval_step] + formatted_row) 60 | # Write the args 61 | writer.writerow([]) 62 | writer.writerow(["Args"]) 63 | for key, value in vars(self.args).items(): 64 | writer.writerow([key, value]) 65 | 66 | print(f"Saved results to {self.args.dataset}-{self.args.wandb_name}-{run}-results.csv") 67 | 68 | argmax = result[:, 1].argmax().item() 69 | argmin = result[:, 3].argmin().item() 70 | if mode == 'max_acc': 71 | ind = argmax 72 | else: 73 | ind = argmin 74 | print('==========================') 75 | print_str1 = f'>> Run {run + 1:02d}:\n' + \ 76 | f'\t Highest Train: {result[:, 0].max():.2f} ' + \ 77 | f'\t Highest Valid: {result[:, 1].max():.2f} ' + \ 78 | f'\t Highest Test: {result[:, 2].max():.2f}\n' + \ 79 | f'\t Chosen epoch based on Valid loss: {argmin * self.args.eval_step} ' + \ 80 | f'\t Final Train: {result[argmin, 0]:.2f} ' + \ 81 | f'\t Final Valid: {result[argmin, 1]:.2f} ' + \ 82 | f'\t Final Test: {result[argmin, 
75 | print_str1 = f'>> Run {run + 1:02d}:\n' + \ 76 | f'\t Highest Train: {result[:, 0].max():.2f} ' + \ 77 | f'\t Highest Valid: {result[:, 1].max():.2f} ' + \ 78 | f'\t Highest Test: {result[:, 2].max():.2f}\n' + \ 79 | f'\t Chosen epoch based on Valid loss: {argmin * self.args.eval_step} ' + \ 80 | f'\t Final Train: {result[argmin, 0]:.2f} ' + \ 81 | f'\t Final Valid: {result[argmin, 1]:.2f} ' + \ 82 | f'\t Final Test: {result[argmin, 2]:.2f}' 83 | print(print_str1) 84 | 85 | print_str = f'>> Run {run + 1:02d}:' + \ 86 | f'\t Highest Train: {result[:, 0].max():.2f} ' + \ 87 | f'\t Highest Valid: {result[:, 1].max():.2f} ' + \ 88 | f'\t Highest Test: {result[:, 2].max():.2f}\n' + \ 89 | f'\t Chosen epoch based on Valid acc: {ind * self.args.eval_step} ' + \ 90 | f'\t Final Train: {result[ind, 0]:.2f} ' + \ 91 | f'\t Final Valid: {result[ind, 1]:.2f} ' + \ 92 | f'\t Final Test: {result[ind, 2]:.2f}' 93 | print(print_str) 94 | self.test = result[ind, 2] 95 | else: 96 | best_results = [] 97 | max_val_epoch = 0 98 | 99 | for r in self.results: 100 | r = torch.stack([torch.as_tensor(res) * 100 for res in r]) # Stack the list of tensors into a single tensor 101 | 102 | train1 = r[:, 0].max().item() 103 | test1 = r[:, 2].max().item() 104 | valid = r[:, 1].max().item() 105 | if mode == 'max_acc': 106 | train2 = r[r[:, 1].argmax(), 0].item() 107 | test2 = r[r[:, 1].argmax(), 2].item() 108 | max_val_epoch = r[:, 1].argmax() 109 | else: 110 | train2 = r[r[:, 3].argmin(), 0].item() 111 | test2 = r[r[:, 3].argmin(), 2].item() 112 | best_results.append((train1, test1, valid, train2, test2)) 113 | 114 | best_result = torch.tensor(best_results) 115 | 116 | print(f'All runs:') 117 | r = best_result[:, 0] 118 | print(f'Highest Train: {r.mean():.2f} ± {r.std():.2f}') 119 | r = best_result[:, 1] 120 | print(f'Highest Test: {r.mean():.2f} ± {r.std():.2f}') 121 | r = best_result[:, 2] 122 | print(f'Highest Valid: {r.mean():.2f} ± {r.std():.2f}') 123 | r = best_result[:, 3] 124 | print(f' Final Train: {r.mean():.2f} ± {r.std():.2f}') 125 | r = best_result[:, 4] 126 | print(f' Final Test: {r.mean():.2f} ± {r.std():.2f}') 127 | 128 | self.test = r.mean() 129 | # if self.args.use_wandb: 130 | # wandb.log({ 131 | # 'Average Highest Train': r.mean().item(), 132 | # 'Std Highest Train': r.std().item(), 133 | # 'Average Highest Test': best_result[:, 1].mean().item(), 134 | # 'Std Highest Test': best_result[:, 1].std().item(), 135 | # 'Average Highest Valid': best_result[:, 2].mean().item(), 136 | # 'Std Highest Valid': best_result[:, 2].std().item(), 137 | # 'Average Final Train': best_result[:, 3].mean().item(), 138 | # 'Std Final Train': best_result[:, 3].std().item(), 139 | # 'Average Final Test': best_result[:, 4].mean().item(), 140 | # 'Std Final Test': best_result[:, 4].std().item() 141 | # }) 142 | return self.get_results_string(best_result) 143 | 144 | def save(self, params, results, filename): 145 | with open(filename, 'a', encoding='utf-8') as file: 146 | file.write(f"{results}\n") 147 | file.write(f"{params}\n") 148 | file.write('=='*50) 149 | file.write('\n') 150 | file.write('\n') 151 | 152 | # (os is already imported at the top of this file) 153 | def save_result(args, results): 154 | if args.save_result: 155 | if not os.path.exists(f'results/{args.dataset}'): 156 | os.makedirs(f'results/{args.dataset}') 157 | filename = f'results/{args.dataset}/{args.method}.csv' 158 | print(f"Saving results to {filename}") 159 | with open(f"{filename}", 'a+') as write_obj: 160 | write_obj.write( 161 | f"{args.method} " + f"{args.kernel}: " + f"{args.weight_decay} " + f"{args.dropout} " + \ 162 | f"{args.num_layers} " + f"{args.alpha}: " + f"{args.hidden_channels}: " + \ 163 | f"{results.mean():.2f} $\\pm$ {results.std():.2f} \n") 164 |
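A minimal sketch of how `Logger` is driven (hypothetical numbers and a hypothetical `DummyArgs` stub; in the real pipeline each eval step appends a 7-tuple of train/valid/test accuracy, validation loss, a spare slot, and top/bottom-degree accuracies, and `args` is the full argparse namespace; assumes the requirements, including wandb, are installed):

```python
from logger import Logger

class DummyArgs:  # only the attributes Logger reads in this sketch
    save_whole_test_result = 0
    eval_step = 1

logger = Logger(runs=2, args=DummyArgs())
for run in range(2):
    for epoch in range(3):
        # (train_acc, valid_acc, test_acc, valid_loss, unused, top_acc, bottom_acc)
        logger.add_result(run, (0.90, 0.80, 0.70, 0.50, 0.0, 0.75, 0.65))
    logger.print_statistics(run)      # per-run summary
print(logger.print_statistics())      # aggregate mean ± std over all runs
```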
-------------------------------------------------------------------------------- /large/main-batch.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import time 4 | import warnings 5 | from datetime import datetime 6 | import os 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch_geometric.utils import (to_undirected, remove_self_loops, add_self_loops, 12 | subgraph, k_hop_subgraph) 13 | import wandb 14 | 15 | from logger import Logger 16 | from dataset import load_dataset 17 | from data_utils import (normalize, gen_normalized_adjs, eval_acc, eval_rocauc, eval_f1, 18 | to_sparse_tensor, load_fixed_splits, adj_mul, compute_degrees) 19 | from eval import evaluate_large, evaluate_batch 20 | from parse import parse_method, parser_add_main_args 21 | from manifolds import Optimizer 22 | 23 | warnings.filterwarnings('ignore') 24 | 25 | 26 | def fix_seed(seed): 27 | random.seed(seed) 28 | np.random.seed(seed) 29 | torch.manual_seed(seed) 30 | torch.cuda.manual_seed(seed) 31 | torch.backends.cudnn.deterministic = True 32 | 33 | 34 | def parse_args(): 35 | parser = argparse.ArgumentParser(description='General Training Pipeline') 36 | parser_add_main_args(parser) 37 | return parser.parse_args() 38 | 39 | 40 | def get_device(use_cpu, device_id): 41 | if use_cpu: 42 | print('>> Using CPU (🐢🐢) ') 43 | return torch.device("cpu") 44 | else: 45 | device = torch.device(f"cuda:{device_id}" if torch.cuda.is_available() else "cpu") 46 | print('>> Using GPU (✈️✈️) ' if device.type == 'cuda' else '>> CUDA not available, falling back to CPU (🐢🐢) ') 47 | return device 48 | 49 | 50 | def load_and_preprocess_data(args): 51 | print(f'>> Loading dataset {args.dataset} (⏳⏳)') 52 | if 'hypformer_data_dir' in os.environ: 53 | args.data_dir = os.environ['hypformer_data_dir'] 54 | dataset = load_dataset(args.data_dir, args.dataset, args.sub_dataset) 55 | 56 | if len(dataset.label.shape) == 1: 57 | dataset.label = dataset.label.unsqueeze(1) 58 | 59 | return dataset 60 | 61 | 62 | def get_data_splits(args, dataset): 63 | if args.rand_split: 64 | split_idx_lst = [dataset.get_idx_split(train_prop=args.train_prop, valid_prop=args.valid_prop) 65 | for _ in range(args.runs)] 66 | print('>> Using random split ') 67 | elif args.rand_split_class: 68 | split_idx_lst = [dataset.get_idx_split(split_type='class', label_num_per_class=args.label_num_per_class) 69 | for _ in range(args.runs)] 70 | print('>> Using random class split ') 71 | elif args.dataset in ['ogbn-proteins', 'ogbn-arxiv', 'ogbn-products', 'amazon2m', 'ogbn-papers100M', 72 | 'ogbn-papers100M-sub']: 73 | split_idx_lst = [dataset.load_fixed_splits() 74 | for _ in range(args.runs)] 75 | print('>> Using fixed split ') 76 | else: 77 | split_idx_lst = load_fixed_splits(args.data_dir, dataset, name=args.dataset, protocol=args.protocol) 78 | print('>> Using fixed split ') 79 | 80 | return split_idx_lst 81 | 82 | 83 | def print_dataset_info(dataset): 84 | n = dataset.graph['num_nodes'] 85 | e = dataset.graph['edge_index'].shape[1] 86 | c = max(dataset.label.max().item() + 1, dataset.label.shape[1]) 87 | d = dataset.graph['node_feat'].shape[1] 88 | 89 | print(f">> Dataset {dataset.name} | num nodes {n} | num edges {e} | num node feats {d} | num classes {c}") 90 | 91 | return n, c, d 92 | 93 | 94 | def compute_and_print_degrees(dataset): 95 | degrees = compute_degrees(dataset.graph['edge_index'], dataset.graph['num_nodes']) 96 | print(f">> Total degree is {degrees.sum()}") 97 | # print(degrees) 98 | print(f">> Degree shape is {degrees.shape}") 99 | 100 | print(f">> Highest degree is {degrees.max().item()}") 101 | print(f">> Lowest degree is {degrees.min().item()}") 102 | 103 | sorted_degrees, _ =
torch.sort(degrees) 104 | percentile_index = int(len(sorted_degrees) * 0.8) 105 | threshold = sorted_degrees[percentile_index] 106 | print(f'>> Mean degree: {degrees.float().mean().item():.2f}') 107 | print(f'>> Std degree: {degrees.float().std().item():.2f}') 108 | print(f'>> Number of nodes with degree 0: {(degrees == 0).sum().item()}') 109 | print(f'>> Threshold: {threshold:.2f}') 110 | 111 | return degrees, threshold 112 | 113 | 114 | def select_loss_function(args): 115 | if args.dataset in ('yelp-chi', 'deezer-europe', 'twitch-e', 'fb100', 'ogbn-proteins'): 116 | criterion = nn.BCEWithLogitsLoss() 117 | print(f'>> Using BCEWithLogitsLoss for {args.dataset}') 118 | else: 119 | criterion = nn.NLLLoss() 120 | print(f'>> Using NLLLoss for {args.dataset}') 121 | 122 | return criterion 123 | 124 | 125 | def select_eval_function(args): 126 | if args.metric == 'rocauc': 127 | eval_func = eval_rocauc 128 | print('>> Using ROC-AUC metric ') 129 | elif args.metric == 'f1': 130 | eval_func = eval_f1 131 | print('>> Using F1 metric ') 132 | else: 133 | eval_func = eval_acc 134 | print('>> Using Accuracy metric ') 135 | 136 | return eval_func 137 | 138 | 139 | def preprocess_graph(dataset, args): 140 | dataset.graph['edge_index'], _ = remove_self_loops(dataset.graph['edge_index']) 141 | dataset.graph['edge_index'], _ = add_self_loops(dataset.graph['edge_index'], num_nodes=dataset.graph['num_nodes']) 142 | 143 | if not args.directed and args.dataset != 'ogbn-proteins': 144 | dataset.graph['edge_index'] = to_undirected(dataset.graph['edge_index']) 145 | print('>> Symmetrized the graph ') 146 | 147 | 148 | def convert_labels_to_one_hot(dataset, args): 149 | if args.dataset in ('yelp-chi', 'deezer-europe', 'twitch-e', 'fb100', 'ogbn-proteins'): 150 | if dataset.label.shape[1] == 1: 151 | return F.one_hot(dataset.label, dataset.label.max() + 1).squeeze(1) 152 | return dataset.label 153 | 154 | 155 | def initialize_wandb(args, run): 156 | if args.wandb_name == '0': 157 | now = datetime.now() 158 | timestamp = now.strftime("%m%d-%H%M") 159 | args.wandb_name = timestamp 160 | if args.use_wandb: 161 | wandb.init(project=f'HyperbolicFormer({args.dataset})', config=vars(args), 162 | name=f'{args.dataset}-Params-{args.wandb_name}-run-{run}') 163 | 164 | def train_and_evaluate(args, dataset, split_idx_lst, device, criterion, eval_func): 165 | n, c, d = print_dataset_info(dataset) 166 | degrees, threshold = compute_and_print_degrees(dataset) 167 | preprocess_graph(dataset, args) 168 | true_label = convert_labels_to_one_hot(dataset, args) 169 | logger = Logger(args.runs, args) 170 | 171 | for run in range(args.runs): 172 | initialize_wandb(args, run) 173 | 174 | split_idx = split_idx_lst[run] if args.dataset not in ['cora', 'citeseer', 175 | 'pubmed'] or args.protocol != 'semi' else split_idx_lst[ 176 | 0] 177 | train_mask = torch.zeros(n, dtype=torch.bool) 178 | train_mask[split_idx['train']] = True 179 | 180 | model = parse_method(args, c, d, device) 181 | model.reset_parameters() 182 | optimizer = Optimizer(model, args) 183 | 184 | for epoch in range(args.epochs): 185 | loss = train_one_epoch(epoch, args, dataset, device, model, optimizer, criterion, n, train_mask, true_label, degrees, 186 | threshold, eval_func, split_idx, logger, run) 187 | 188 | logger.print_statistics(run) 189 | if args.use_wandb: 190 | wandb.finish() 191 | 192 | results = logger.print_statistics() 193 | logger.save(vars(args), results, f'results/{args.dataset}.csv') 194 | 195 | 196 | def train_one_epoch(epoch, args, dataset, device, 
model, optimizer, criterion, n, train_mask, true_label, degrees, threshold, 197 | eval_func, split_idx, logger, run): 198 | model.to(device) 199 | model.train() 200 | train_start = time.time() 201 | idx = torch.randperm(n) 202 | num_batch = n // args.batch_size + (n % args.batch_size > 0) 203 | 204 | for i in range(num_batch): 205 | batch_data = prepare_batch_data(i, args,dataset, idx, n, train_mask, true_label, degrees, device) 206 | optimizer.zero_grad() 207 | out_i = model(batch_data['x_i'], batch_data['edge_index_i']) 208 | loss = compute_loss(args, criterion, out_i, batch_data['train_mask_i'], batch_data['y_i']) 209 | loss.backward() 210 | optimizer.step() 211 | 212 | print(f'🔥🔥 Epoch: {epoch:02d}, Loss: {loss:.4f} || Train Time: {time.time() - train_start:.2f}s') 213 | evaluate_epoch(epoch, args, model, dataset, split_idx, eval_func, criterion, degrees, threshold, device, logger, 214 | run, loss, true_label) 215 | 216 | return loss 217 | 218 | 219 | def prepare_batch_data(i, args, dataset, idx, n, train_mask, true_label, degrees, device): 220 | idx_i = idx[i * args.batch_size:(i + 1) * args.batch_size] 221 | train_mask_i = train_mask[idx_i] 222 | x_i = dataset.graph['node_feat'][idx_i].to(device) 223 | edge_index_i, _ = subgraph(idx_i, dataset.graph['edge_index'], num_nodes=n, relabel_nodes=True) 224 | edge_index_i = edge_index_i.to(device) 225 | y_i = true_label[idx_i].to(device) 226 | 227 | return {'idx_i': idx_i, 'train_mask_i': train_mask_i, 'x_i': x_i, 'edge_index_i': edge_index_i, 'y_i': y_i} 228 | 229 | 230 | def compute_loss(args, criterion, out_i, train_mask_i, y_i): 231 | if args.dataset in ('yelp-chi', 'deezer-europe', 'twitch-e', 'fb100', 'ogbn-proteins'): 232 | return criterion(out_i[train_mask_i], y_i.squeeze(1)[train_mask_i].to(torch.float)) 233 | else: 234 | out_i = F.log_softmax(out_i, dim=1) 235 | return criterion(out_i[train_mask_i], y_i.squeeze(1)[train_mask_i]) 236 | 237 | 238 | def evaluate_epoch(epoch, args, model, dataset, split_idx, eval_func, criterion, degrees, threshold, device, logger, 239 | run, loss, true_label): 240 | if (epoch + 1) % args.eval_step == 0: 241 | if args.dataset == 'ogbn-papers100M': 242 | result = evaluate_batch(model, dataset, split_idx, args, device, dataset.graph['num_nodes'], true_label) 243 | else: 244 | result = evaluate_large(model, dataset, split_idx, eval_func, criterion, args, degrees, threshold, 245 | device=device) 246 | logger.add_result(run, result) 247 | 248 | if epoch % args.display_step == 0: 249 | display_evaluation_results(epoch, result, split_idx, degrees, threshold, loss) 250 | 251 | if args.use_wandb: 252 | wandb.log({"run": run, "epoch": epoch, "loss": loss.item(), "train_acc": result[0], 253 | "val_acc": result[1], "test_acc": result[2], "val_loss": result[3]}) 254 | 255 | 256 | def display_evaluation_results(epoch, result, split_idx, degrees, threshold, loss): 257 | degrees_in_test = degrees[split_idx['test']] 258 | top_indices = split_idx['test'][degrees_in_test > threshold] 259 | bottom_indices = split_idx['test'][degrees_in_test <= threshold] 260 | 261 | max_top_acc = top_indices.shape[0] / split_idx['test'].shape[0] 262 | max_bottom_acc = bottom_indices.shape[0] / split_idx['test'].shape[0] 263 | print_str = (f'👉Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {100 * result[0]:.2f}%, ' 264 | f'Valid: {100 * result[1]:.2f}%, Test: {100 * result[2]:.2f}%, ' 265 | f'Top: {100 * result[5]:.2f} | {100 * max_top_acc:.2f}%, ' 266 | f'Bottom: {100 * result[6]:.2f} | {100 * max_bottom_acc:.2f}%') 267 | print(print_str) 
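# Illustrative sketch (kept in comments so it never runs at import time): how the
# 80th-percentile degree threshold from compute_and_print_degrees partitions the
# test set into the Top/Bottom buckets reported above. Values are hypothetical.
#
#   degrees = torch.tensor([1, 4, 2, 9, 3, 7, 5, 2, 8, 6])
#   sorted_degrees, _ = torch.sort(degrees)
#   threshold = sorted_degrees[int(len(sorted_degrees) * 0.8)]   # -> 8
#   test_idx = torch.arange(10)
#   top = test_idx[degrees[test_idx] > threshold]                # only the degree-9 node
#   bottom = test_idx[degrees[test_idx] <= threshold]            # the remaining nine nodes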
268 | 269 | 270 | def main(): 271 | args = parse_args() 272 | print('===' * 40) 273 | print('⚙️,', args) 274 | 275 | fix_seed(args.seed) 276 | device = get_device(args.cpu, args.device) 277 | dataset = load_and_preprocess_data(args) 278 | split_idx_lst = get_data_splits(args, dataset) 279 | 280 | criterion = select_loss_function(args) 281 | eval_func = select_eval_function(args) 282 | 283 | train_and_evaluate(args, dataset, split_idx_lst, device, criterion, eval_func) 284 | 285 | 286 | if __name__ == "__main__": 287 | main() 288 | -------------------------------------------------------------------------------- /large/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import time 5 | import warnings 6 | from datetime import datetime 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch_geometric.utils import to_undirected, remove_self_loops, add_self_loops 13 | from torch_scatter import scatter 14 | 15 | import wandb 16 | 17 | from logger import Logger 18 | from dataset import load_dataset 19 | from data_utils import (normalize, gen_normalized_adjs, eval_acc, eval_rocauc, eval_f1, 20 | to_sparse_tensor, load_fixed_splits, adj_mul, get_gpu_memory_map, 21 | count_parameters, compute_degrees) 22 | from eval import evaluate, evaluate_large 23 | from parse import parse_method, parser_add_main_args 24 | from manifolds import Optimizer 25 | 26 | warnings.filterwarnings('ignore') 27 | 28 | 29 | def fix_seed(seed): 30 | random.seed(seed) 31 | np.random.seed(seed) 32 | torch.manual_seed(seed) 33 | torch.cuda.manual_seed(seed) 34 | torch.backends.cudnn.deterministic = True 35 | 36 | 37 | def parse_args(): 38 | parser = argparse.ArgumentParser(description='General Training Pipeline') 39 | parser_add_main_args(parser) 40 | return parser.parse_args() 41 | 42 | 43 | def get_device(use_cpu, device_id): 44 | if use_cpu: 45 | print('>> Using CPU') 46 | return torch.device("cpu") 47 | else: 48 | device = torch.device(f"cuda:{device_id}" if torch.cuda.is_available() else "cpu") 49 | print('>> Using GPU') 50 | return device 51 | 52 | 53 | def load_and_preprocess_data(args, device): 54 | print(f'>> Loading dataset {args.dataset}') 55 | if 'hypformer_data_dir' in os.environ: 56 | args.data_dir = os.environ['hypformer_data_dir'] 57 | dataset = load_dataset(args.data_dir, args.dataset, args.sub_dataset) 58 | 59 | if len(dataset.label.shape) == 1: 60 | dataset.label = dataset.label.unsqueeze(1) 61 | dataset.label = dataset.label.to(device) 62 | 63 | return dataset 64 | 65 | 66 | def get_data_splits(args, dataset): 67 | if args.rand_split: 68 | split_idx_lst = [dataset.get_idx_split(train_prop=args.train_prop, valid_prop=args.valid_prop) 69 | for _ in range(args.runs)] 70 | print('>> Using random split') 71 | elif args.rand_split_class: 72 | split_idx_lst = [dataset.get_idx_split(split_type='class', label_num_per_class=args.label_num_per_class) 73 | for _ in range(args.runs)] 74 | print('>> Using random class split') 75 | elif args.dataset in ['ogbn-proteins', 'ogbn-arxiv', 'ogbn-products']: 76 | split_idx_lst = [dataset.load_fixed_splits() for _ in range(args.runs)] 77 | print('>> Using fixed split') 78 | else: 79 | split_idx_lst = load_fixed_splits(args.data_dir, dataset, name=args.dataset, protocol=args.protocol) 80 | print('>> Using fixed split') 81 | 82 | return split_idx_lst 83 | 84 | 85 | def print_dataset_info(dataset): 86 | n = 
dataset.graph['num_nodes'] 87 | e = dataset.graph['edge_index'].shape[1] 88 | c = max(dataset.label.max().item() + 1, dataset.label.shape[1]) 89 | d = dataset.graph['node_feat'].shape[1] 90 | 91 | print(f">> Dataset {dataset.name} | num nodes {n} | num edges {e} | num node feats {d} | num classes {c}") 92 | return n, c, d 93 | 94 | 95 | def compute_and_print_degrees(dataset): 96 | degrees = compute_degrees(dataset.graph['edge_index'], dataset.graph['num_nodes']) 97 | print(f">> Total degree is {degrees.sum()}") 98 | print(f">> Degree shape is {degrees.shape}") 99 | 100 | print(f">> Highest degree is {degrees.max().item()}") 101 | print(f">> Lowest degree is {degrees.min().item()}") 102 | 103 | sorted_degrees, _ = torch.sort(degrees) 104 | percentile_index = int(len(sorted_degrees) * 0.8) 105 | threshold = sorted_degrees[percentile_index] 106 | print(f'>> Mean degree: {degrees.float().mean().item():.2f}') 107 | print(f'>> Std degree: {degrees.float().std().item():.2f}') 108 | print(f'>> Number of nodes with degree 0: {(degrees == 0).sum().item()}') 109 | print(f'>> Threshold: {threshold:.2f}') 110 | 111 | less_than_degree = (degrees <= threshold).sum().item() 112 | greater_than_degree = (degrees > threshold).sum().item() 113 | print( 114 | f">> Number of nodes with degree less than {threshold}: {less_than_degree}, it accounts for {less_than_degree / degrees.shape[0]:.2f}") 115 | print( 116 | f">> Number of nodes with degree greater than {threshold}: {greater_than_degree}, it accounts for {greater_than_degree / degrees.shape[0]:.2f}") 117 | 118 | return degrees, threshold 119 | 120 | 121 | def preprocess_graph(dataset, args, n): 122 | dataset.graph['edge_index'], _ = remove_self_loops(dataset.graph['edge_index']) 123 | dataset.graph['edge_index'], _ = add_self_loops(dataset.graph['edge_index'], num_nodes=n) 124 | 125 | if not args.directed and args.dataset != 'ogbn-proteins': 126 | dataset.graph['edge_index'] = to_undirected(dataset.graph['edge_index']) 127 | 128 | dataset.graph['edge_index'], dataset.graph['node_feat'] = dataset.graph['edge_index'].to(args.device), \ 129 | dataset.graph['node_feat'].to(args.device) 130 | return dataset 131 | 132 | 133 | def initialize_wandb(args, run): 134 | if args.wandb_name == '0': 135 | now = datetime.now() 136 | timestamp = now.strftime("%m%d-%H%M") 137 | args.wandb_name = timestamp 138 | if args.use_wandb: 139 | wandb.init(project=f'HyperbolicFormer({args.dataset})', config=vars(args), 140 | name=f'{args.dataset}-Params-{args.wandb_name}-run-{run}') 141 | 142 | 143 | 144 | def select_loss_function(args): 145 | if args.dataset in ('yelp-chi', 'deezer-europe', 'twitch-e', 'fb100', 'ogbn-proteins'): 146 | criterion = nn.BCEWithLogitsLoss() 147 | print(f'>> Using BCEWithLogitsLoss for {args.dataset}') 148 | else: 149 | criterion = nn.NLLLoss() 150 | print(f'>> Using NLLLoss for {args.dataset}') 151 | return criterion 152 | 153 | 154 | def select_eval_function(args): 155 | if args.metric == 'rocauc': 156 | eval_func = eval_rocauc 157 | print('>> Using ROC-AUC metric') 158 | elif args.metric == 'f1': 159 | eval_func = eval_f1 160 | print('>> Using F1 metric') 161 | else: 162 | eval_func = eval_acc 163 | print('>> Using Accuracy metric') 164 | return eval_func 165 | 166 | 167 | def train_and_evaluate(args, dataset, split_idx_lst, device, criterion, eval_func): 168 | n, c, d = print_dataset_info(dataset) 169 | degrees, threshold = compute_and_print_degrees(dataset) 170 | preprocess_graph(dataset, args, n) 171 | logger = Logger(args.runs, args) 172 | 173 | for 
run in range(args.runs): 174 | initialize_wandb(args, run) 175 | split_idx = split_idx_lst[0] if args.dataset in ['cora', 'citeseer', 'pubmed'] and args.protocol == 'semi' else \ 176 | split_idx_lst[run] 177 | train_idx = split_idx['train'].to(device) 178 | 179 | model = parse_method(args, c, d, device) 180 | model.reset_parameters() 181 | optimizer = Optimizer(model, args) 182 | 183 | for epoch in range(args.epochs): 184 | train_start = time.time() 185 | loss = train_one_epoch(epoch, args, dataset, device, model, optimizer, criterion, train_idx) 186 | print(f'🔥🔥 Epoch: {epoch:02d}, Loss: {loss:.4f} || Train Time: {time.time() - train_start:.2f}s') 187 | 188 | if epoch % args.eval_step == 0: 189 | evaluate_and_log(epoch, args, model, dataset, split_idx, eval_func, criterion, degrees, threshold, 190 | device, logger, run, loss) 191 | 192 | logger.print_statistics(run) 193 | if args.use_wandb: 194 | wandb.finish() 195 | 196 | results = logger.print_statistics() 197 | if args.save_result: 198 | logger.save(vars(args), results, f'results/{args.dataset}.csv') 199 | 200 | 201 | def train_one_epoch(epoch, args, dataset, device, model, optimizer, criterion, train_idx): 202 | model.to(device) 203 | model.train() 204 | optimizer.zero_grad() 205 | 206 | out = model(dataset.graph['node_feat'], dataset.graph['edge_index']) 207 | true_label = get_true_label(dataset, args) 208 | loss = compute_loss(out, criterion, train_idx, true_label, args) 209 | loss.backward() 210 | optimizer.step() 211 | 212 | return loss 213 | 214 | 215 | def get_true_label(dataset, args): 216 | if args.dataset in ('yelp-chi', 'deezer-europe', 'twitch-e', 'fb100', 'ogbn-proteins'): 217 | if dataset.label.shape[1] == 1: 218 | return F.one_hot(dataset.label, dataset.label.max() + 1).squeeze(1) 219 | return dataset.label 220 | 221 | 222 | def compute_loss(out, criterion, train_idx, true_label, args): 223 | if args.dataset in ('yelp-chi', 'deezer-europe', 'twitch-e', 'fb100', 'ogbn-proteins'): 224 | return criterion(out[train_idx], true_label.squeeze(1)[train_idx].to(torch.float)) 225 | else: 226 | out = F.log_softmax(out, dim=1) 227 | return criterion(out[train_idx], true_label.squeeze(1)[train_idx]) 228 | 229 | 230 | def evaluate_and_log(epoch, args, model, dataset, split_idx, eval_func, criterion, degrees, threshold, device, logger, 231 | run, loss): 232 | result = evaluate_large(model, dataset, split_idx, eval_func, criterion, args, degrees, threshold, device=device) 233 | logger.add_result(run, result) 234 | if epoch % args.display_step == 0: 235 | display_evaluation_results(epoch, result, split_idx, degrees, threshold, loss) 236 | if args.use_wandb: 237 | wandb.log({"run": run, "epoch": epoch, "loss": loss.item(), "train_acc": result[0], "val_acc": result[1], 238 | "test_acc": result[2], "val_loss": result[3]}) 239 | 240 | 241 | def display_evaluation_results(epoch, result, split_idx, degrees, threshold, loss): 242 | degrees_in_test = degrees[split_idx['test']] 243 | top_indices = split_idx['test'][degrees_in_test > threshold] 244 | bottom_indices = split_idx['test'][degrees_in_test <= threshold] 245 | 246 | max_top_acc = top_indices.shape[0] / split_idx['test'].shape[0] 247 | max_bottom_acc = bottom_indices.shape[0] / split_idx['test'].shape[0] 248 | # print_str = (f'👉Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {100 * result[0]:.2f}%, ' 249 | # f'Valid: {100 * result[1]:.2f}%, Test: {100 * result[2]:.2f}%, ' 250 | # f'Top: {100 * result[5]:.2f} | {100 * max_top_acc:.2f}%, ' 251 | # f'Bottom: {100 * result[6]:.2f} | {100 * 
max_bottom_acc:.2f}%') 252 | print_str = (f'👉Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {100 * result[0]:.2f}%, ' 253 | f'Valid: {100 * result[1]:.2f}%, Test: {100 * result[2]:.2f}%') 254 | print(print_str) 255 | 256 | 257 | def main(): 258 | args = parse_args() 259 | print(args) 260 | 261 | fix_seed(args.seed) 262 | device = get_device(args.cpu, args.device) 263 | dataset = load_and_preprocess_data(args, device) 264 | split_idx_lst = get_data_splits(args, dataset) 265 | 266 | criterion = select_loss_function(args) 267 | eval_func = select_eval_function(args) 268 | 269 | train_and_evaluate(args, dataset, split_idx_lst, device, criterion, eval_func) 270 | 271 | 272 | if __name__ == "__main__": 273 | main() 274 | -------------------------------------------------------------------------------- /large/manifolds/__init__.py: -------------------------------------------------------------------------------- 1 | from .layer import (HypLayerNorm, HypDropout, 2 | HypActivation, HypNormalization, 3 | Optimizer, HypLinear) 4 | from .lorentz import Lorentz -------------------------------------------------------------------------------- /large/manifolds/layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | from geoopt import ManifoldParameter 5 | from geoopt.optim.rsgd import RiemannianSGD 6 | from geoopt.optim.radam import RiemannianAdam 7 | import math 8 | 9 | 10 | class HypLayerNorm(nn.Module): 11 | """ 12 | Hyperbolic Layer Normalization Layer 13 | 14 | Parameters: 15 | manifold (Manifold): The manifold to use for normalization. 16 | in_features (int): The number of input features. 17 | manifold_out (Manifold, optional): The output manifold. Default is None. 18 | """ 19 | 20 | def __init__(self, manifold, in_features, manifold_out=None): 21 | super(HypLayerNorm, self).__init__() 22 | self.in_features = in_features 23 | self.manifold = manifold 24 | self.manifold_out = manifold_out 25 | self.layer = nn.LayerNorm(self.in_features) 26 | self.reset_parameters() 27 | 28 | def reset_parameters(self): 29 | """Reset layer parameters.""" 30 | self.layer.reset_parameters() 31 | 32 | def forward(self, x): 33 | """Forward pass for hyperbolic layer normalization.""" 34 | x_space = x[..., 1:] 35 | x_space = self.layer(x_space) 36 | x_time = ((x_space ** 2).sum(dim=-1, keepdims=True) + self.manifold.k).sqrt() 37 | x = torch.cat([x_time, x_space], dim=-1) 38 | 39 | if self.manifold_out is not None: 40 | x = x * (self.manifold_out.k / self.manifold.k).sqrt() 41 | return x 42 | 43 | 44 | class HypNormalization(nn.Module): 45 | """ 46 | Hyperbolic Normalization Layer 47 | 48 | Parameters: 49 | manifold (Manifold): The manifold to use for normalization. 50 | manifold_out (Manifold, optional): The output manifold. Default is None. 
51 | """ 52 | 53 | def __init__(self, manifold, manifold_out=None): 54 | super(HypNormalization, self).__init__() 55 | self.manifold = manifold 56 | self.manifold_out = manifold_out 57 | 58 | def forward(self, x): 59 | """Forward pass for hyperbolic normalization.""" 60 | x_space = x[..., 1:] 61 | x_space = x_space / x_space.norm(dim=-1, keepdim=True) 62 | x_time = ((x_space ** 2).sum(dim=-1, keepdims=True) + self.manifold.k).sqrt() 63 | x = torch.cat([x_time, x_space], dim=-1) 64 | if self.manifold_out is not None: 65 | x = x * (self.manifold_out.k / self.manifold.k).sqrt() 66 | return x 67 | 68 | 69 | class HypActivation(nn.Module): 70 | """ 71 | Hyperbolic Activation Layer 72 | 73 | Parameters: 74 | manifold (Manifold): The manifold to use for the activation. 75 | activation (function): The activation function. 76 | manifold_out (Manifold, optional): The output manifold. Default is None. 77 | """ 78 | 79 | def __init__(self, manifold, activation, manifold_out=None): 80 | super(HypActivation, self).__init__() 81 | self.manifold = manifold 82 | self.manifold_out = manifold_out 83 | self.activation = activation 84 | 85 | def forward(self, x): 86 | """Forward pass for hyperbolic activation.""" 87 | x_space = x[..., 1:] 88 | x_space = self.activation(x_space) 89 | x_time = ((x_space ** 2).sum(dim=-1, keepdims=True) + self.manifold.k).sqrt() 90 | x = torch.cat([x_time, x_space], dim=-1) 91 | if self.manifold_out is not None: 92 | x = x * (self.manifold_out.k / self.manifold.k).sqrt() 93 | return x 94 | 95 | 96 | class HypDropout(nn.Module): 97 | """ 98 | Hyperbolic Dropout Layer 99 | 100 | Parameters: 101 | manifold (Manifold): The manifold to use for the dropout. 102 | dropout (float): The dropout probability. 103 | manifold_out (Manifold, optional): The output manifold. Default is None. 104 | """ 105 | 106 | def __init__(self, manifold, dropout, manifold_out=None): 107 | super(HypDropout, self).__init__() 108 | self.manifold = manifold 109 | self.manifold_out = manifold_out 110 | self.dropout = nn.Dropout(dropout) 111 | 112 | def forward(self, x, training=False): # callers pass the module's training flag explicitly 113 | """Forward pass for hyperbolic dropout.""" 114 | if training: 115 | x_space = x[..., 1:] 116 | x_space = self.dropout(x_space) 117 | x_time = ((x_space ** 2).sum(dim=-1, keepdims=True) + self.manifold.k).sqrt() 118 | x = torch.cat([x_time, x_space], dim=-1) 119 | if self.manifold_out is not None: 120 | x = x * (self.manifold_out.k / self.manifold.k).sqrt() 121 | return x 122 | 123 |
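# Shared pattern across the layers in this file: apply the Euclidean op to the
# space-like coordinates only, then recompute the time-like coordinate so the
# point stays on the Lorentz manifold, i.e. x_time = sqrt(||x_space||^2 + k):
#
#   x_space = euclidean_op(x[..., 1:])
#   x_time = ((x_space ** 2).sum(dim=-1, keepdim=True) + manifold.k).sqrt()
#   x = torch.cat([x_time, x_space], dim=-1)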
124 | class HypLinear(nn.Module): 125 | """ 126 | Hyperbolic Linear Layer 127 | 128 | Parameters: 129 | manifold (Manifold): The manifold to use for the linear transformation. 130 | in_features (int): The size of each input sample. 131 | out_features (int): The size of each output sample. 132 | bias (bool, optional): If set to False, the layer will not learn an additive bias. Default is True. 133 | dropout (float, optional): The dropout probability. Default is 0.0. 134 | manifold_out (Manifold, optional): The output manifold. Default is None. 135 | """ 136 | 137 | def __init__(self, manifold, in_features, out_features, bias=True, dropout=0.0, manifold_out=None): 138 | super().__init__() 139 | self.in_features = in_features + 1 # +1 for time dimension 140 | self.out_features = out_features 141 | self.bias = bias 142 | self.manifold = manifold 143 | self.manifold_out = manifold_out 144 | 145 | self.linear = nn.Linear(self.in_features, self.out_features, bias=bias) 146 | self.dropout_rate = dropout 147 | self.reset_parameters() 148 | 149 | def reset_parameters(self): 150 | """Reset layer parameters.""" 151 | init.xavier_uniform_(self.linear.weight, gain=math.sqrt(2)) 152 | if self.bias: 153 | init.constant_(self.linear.bias, 0) 154 | 155 | def forward(self, x, x_manifold='hyp'): 156 | """Forward pass for hyperbolic linear layer.""" 157 | if x_manifold != 'hyp': 158 | x = torch.cat([torch.ones_like(x)[..., 0:1], x], dim=-1) 159 | x = self.manifold.expmap0(x) 160 | x_space = self.linear(x) 161 | 162 | x_time = ((x_space ** 2).sum(dim=-1, keepdims=True) + self.manifold.k).sqrt() 163 | x = torch.cat([x_time, x_space], dim=-1) 164 | if self.manifold_out is not None: 165 | x = x * (self.manifold_out.k / self.manifold.k).sqrt() 166 | return x 167 | 168 | class HypCLS(nn.Module): 169 | def __init__(self, manifold, in_channels, out_channels, bias=True): 170 | """ 171 | Initializes the HypCLS class with the given parameters. 172 | 173 | Parameters: 174 | - `manifold` (Manifold): The manifold object. 175 | - `in_channels` (int): The number of input channels. 176 | - `out_channels` (int): The number of output channels. 177 | - `bias` (bool, optional): Whether to include a bias term. Defaults to True. 178 | 179 | Returns: 180 | None 181 | """ 182 | super().__init__() 183 | self.manifold = manifold 184 | self.in_channels = in_channels 185 | self.out_channels = out_channels 186 | cls_emb = self.manifold.random_normal((self.out_channels, self.in_channels + 1), mean=0, std=1. / math.sqrt(self.in_channels + 1)) 187 | self.cls = ManifoldParameter(cls_emb, self.manifold, requires_grad=True) 188 | # register a zero bias either way (frozen when bias=False) so forward() never hits a missing attribute 189 | self.bias = nn.Parameter(torch.zeros(self.out_channels), requires_grad=bias) 190 | 191 | def cinner(self, x, y): 192 | x = x.clone() 193 | x.narrow(-1, 0, 1).mul_(-1) 194 | return x @ y.transpose(-1, -2) 195 | 196 | def forward(self, x, x_manifold='hyp', return_type='neg_dist'): 197 | if x_manifold != 'hyp': 198 | x = self.manifold.expmap0(torch.cat([torch.zeros_like(x)[..., 0:1], x], dim=-1)) # project to Lorentz 199 | 200 | dist = -2 * self.manifold.k - 2 * self.cinner(x, self.cls) + self.bias 201 | dist = dist.clamp(min=0) 202 | 203 | if return_type == 'neg_dist': 204 | return - dist 205 | elif return_type == 'prob': 206 | return 1.0 / (1.0 + dist) 207 | elif return_type == 'neg_log_prob': 208 | return -1.0 * torch.log(1.0 + dist) 209 | else: 210 | raise NotImplementedError 211 |
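# Minimal usage sketch for HypCLS (kept in comments; shapes are hypothetical and
# `Lorentz` refers to the manifold class shipped in this package):
#
#   manifold = Lorentz(k=1.0)
#   head = HypCLS(manifold, in_channels=32, out_channels=5)
#   x_euc = torch.randn(10, 32)             # ordinary Euclidean features
#   logits = head(x_euc, x_manifold='euc')  # negative squared Lorentz distances, shape (10, 5)
#   pred = logits.argmax(dim=-1)            # nearest class prototype wins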
212 | class Optimizer(object): 213 | """ 214 | Optimizer for Euclidean and Hyperbolic parameters 215 | 216 | Parameters: 217 | model (nn.Module): The model containing the parameters to optimize. 218 | args (Namespace): The arguments containing optimizer settings. 219 | """ 220 | 221 | def __init__(self, model, args): 222 | euc_optimizer_type = args.optimizer_type 223 | hyp_optimizer_type = args.hyp_optimizer_type 224 | euc_lr = args.lr 225 | hyp_lr = args.hyp_lr 226 | euc_weight_decay = args.weight_decay 227 | hyp_weight_decay = args.hyp_weight_decay 228 | 229 | euc_params = [p for n, p in model.named_parameters() if 230 | p.requires_grad and not isinstance(p, ManifoldParameter)] 231 | hyp_params = [p for n, p in model.named_parameters() if p.requires_grad and isinstance(p, ManifoldParameter)] 232 | 233 | # print(f">> Number of Euclidean parameters: {sum(p.numel() for p in euc_params)}") 234 | # print(f">> Number of Hyperbolic parameters: {sum(p.numel() for p in hyp_params)}") 235 | self.optimizer = [] # Optimizers for Euclidean and Hyperbolic parts of the model 236 | 237 | if euc_params: 238 | if euc_optimizer_type == 'adam': 239 | optimizer_euc = torch.optim.Adam(euc_params, lr=euc_lr, weight_decay=euc_weight_decay) 240 | elif euc_optimizer_type == 'sgd': 241 | optimizer_euc = torch.optim.SGD(euc_params, lr=euc_lr, weight_decay=euc_weight_decay) 242 | else: 243 | raise NotImplementedError(f"Unknown Euclidean optimizer type: {euc_optimizer_type}") 244 | self.optimizer.append(optimizer_euc) 245 | 246 | if hyp_params: 247 | if hyp_optimizer_type == 'radam': 248 | optimizer_hyp = RiemannianAdam(hyp_params, lr=hyp_lr, stabilize=10, weight_decay=hyp_weight_decay) 249 | elif hyp_optimizer_type == 'rsgd': 250 | optimizer_hyp = RiemannianSGD(hyp_params, lr=hyp_lr, stabilize=10, weight_decay=hyp_weight_decay) 251 | else: 252 | raise NotImplementedError(f"Unknown Hyperbolic optimizer type: {hyp_optimizer_type}") 253 | self.optimizer.append(optimizer_hyp) 254 | 255 | def step(self): 256 | """Performs a single optimization step.""" 257 | for optimizer in self.optimizer: 258 | optimizer.step() 259 | 260 | def zero_grad(self): 261 | """Sets the gradients of all optimized tensors to zero.""" 262 | for optimizer in self.optimizer: 263 | optimizer.zero_grad() 264 | -------------------------------------------------------------------------------- /large/manifolds/utils.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from typing import Tuple, Any, Union, List 3 | import functools 4 | import operator 5 | import torch 6 | import geoopt 7 | 8 | 9 | max_norm = 85 10 | eps = 1e-8 11 | 12 | __all__ = [ 13 | "copy_or_set_", 14 | "strip_tuple", 15 | "size2shape", 16 | "make_tuple", 17 | "broadcast_shapes", 18 | "ismanifold", 19 | "canonical_manifold", 20 | "list_range", 21 | "idx2sign", 22 | "drop_dims", 23 | "canonical_dims", 24 | "sign", 25 | "prod", 26 | "clamp_abs", 27 | "sabs", 28 | ] 29 | 30 | 31 | def copy_or_set_(dest: torch.Tensor, source: torch.Tensor) -> torch.Tensor: 32 | """ 33 | Copy or inplace set from :code:`source` to :code:`dest`. 34 | 35 | A workaround to respect strides of :code:`dest` when copying :code:`source`. 36 | The original issue was raised in geoopt 37 | when working with matrix manifolds. An inplace set operation is more efficient, 38 | but the resulting storage might be incompatible afterwards. To avoid the issue we fall back to 39 | the safe option and use :code:`copy_` if strides do not match.
40 | 41 | Parameters 42 | ---------- 43 | dest : torch.Tensor 44 | Destination tensor where to store new data 45 | source : torch.Tensor 46 | Source data to put in the new tensor 47 | 48 | Returns 49 | ------- 50 | dest 51 | torch.Tensor, modified inplace 52 | """ 53 | if dest.stride() != source.stride(): 54 | return dest.copy_(source) 55 | else: 56 | return dest.set_(source) 57 | 58 | 59 | def strip_tuple(tup: Tuple) -> Union[Tuple, Any]: 60 | if len(tup) == 1: 61 | return tup[0] 62 | else: 63 | return tup 64 | 65 | 66 | def make_tuple(obj: Union[Tuple, List, Any]) -> Tuple: 67 | if isinstance(obj, list): 68 | obj = tuple(obj) 69 | if not isinstance(obj, tuple): 70 | return (obj,) 71 | else: 72 | return obj 73 | 74 | 75 | def prod(items): 76 | return functools.reduce(operator.mul, items, 1) 77 | 78 | 79 | def sign(x): 80 | return torch.sign(x.sign() + 0.5) 81 | 82 | 83 | def sabs(x, eps: float = 1e-15): 84 | return x.abs().add_(eps) 85 | 86 | 87 | def clamp_abs(x, eps: float = 1e-15): 88 | s = sign(x) 89 | return s * sabs(x, eps=eps) 90 | 91 | 92 | def idx2sign(idx: int, dim: int, neg: bool = True): 93 | """ 94 | Unify idx to be negative or positive, that helps in cases of broadcasting. 95 | 96 | Parameters 97 | ---------- 98 | idx : int 99 | current index 100 | dim : int 101 | maximum dimension 102 | neg : bool 103 | indicate we need negative index 104 | 105 | Returns 106 | ------- 107 | int 108 | """ 109 | if neg: 110 | if idx < 0: 111 | return idx 112 | else: 113 | return (idx + 1) % -(dim + 1) 114 | else: 115 | return idx % dim 116 | 117 | 118 | def drop_dims(tensor: torch.Tensor, dims: List[int]): 119 | # Workaround to drop several dims in :func:`torch.squeeze`. 120 | seen: int = 0 121 | for d in dims: 122 | tensor = tensor.squeeze(d - seen) 123 | seen += 1 124 | return tensor 125 | 126 | 127 | def list_range(end: int): 128 | res: List[int] = [] 129 | for d in range(end): 130 | res.append(d) 131 | return res 132 | 133 | 134 | def canonical_dims(dims: List[int], maxdim: int): 135 | result: List[int] = [] 136 | for idx in dims: 137 | result.append(idx2sign(idx, maxdim, neg=False)) 138 | return result 139 | 140 | 141 | def size2shape(*size: Union[Tuple[int], int]) -> Tuple[int]: 142 | return make_tuple(strip_tuple(size)) 143 | 144 | 145 | def broadcast_shapes(*shapes: Tuple[int]) -> Tuple[int]: 146 | """Apply numpy broadcasting rules to shapes.""" 147 | result = [] 148 | for dims in itertools.zip_longest(*map(reversed, shapes), fillvalue=1): 149 | dim: int = 1 150 | for d in dims: 151 | if dim != 1 and d != 1 and d != dim: 152 | raise ValueError("Shapes can't be broadcasted") 153 | elif d > dim: 154 | dim = d 155 | result.append(dim) 156 | return tuple(reversed(result)) 157 | 158 | 159 | def ismanifold(instance, cls): 160 | """ 161 | Check if interface of an instance is compatible with given class. 
162 | 163 | Parameters 164 | ---------- 165 | instance : geoopt.Manifold 166 | check if a given manifold is compatible with cls API 167 | cls : type 168 | manifold type 169 | 170 | Returns 171 | ------- 172 | bool 173 | comparison result 174 | """ 175 | if not issubclass(cls, geoopt.manifolds.Manifold): 176 | raise TypeError( 177 | "`cls` should be a subclass of geoopt.manifolds.Manifold") 178 | if not isinstance(instance, geoopt.manifolds.Manifold): 179 | return False 180 | else: 181 | # this is the case to care about, Scaled class is a proxy, but fails instance checks 182 | while isinstance(instance, geoopt.Scaled): 183 | instance = instance.base 184 | return isinstance(instance, cls) 185 | 186 | 187 | def canonical_manifold(manifold: "geoopt.Manifold"): 188 | """ 189 | Get a canonical manifold. 190 | 191 | If a manifold is wrapped with Scaled, some attributes may not be available. This should help if you really need them. 192 | 193 | Parameters 194 | ---------- 195 | manifold : geoopt.Manifold 196 | 197 | Returns 198 | ------- 199 | geoopt.Manifold 200 | an unwrapped manifold 201 | """ 202 | while isinstance(manifold, geoopt.Scaled): 203 | manifold = manifold.base 204 | return manifold 205 | 206 | 207 | def cosh(x: torch.Tensor) -> torch.Tensor: 208 | x = clamp(x, min=-max_norm, max=max_norm) 209 | return torch.cosh(x) 210 | 211 | 212 | def sinh(x: torch.Tensor) -> torch.Tensor: 213 | x = clamp(x, min=-max_norm, max=max_norm) 214 | return torch.sinh(x) 215 | 216 | 217 | def sqrt(x: torch.Tensor) -> torch.Tensor: 218 | x = clamp(x, min=1e-9) # Smaller epsilon due to precision around x=0. 219 | return torch.sqrt(x) 220 | 221 | 222 | class LeakyClamp(torch.autograd.Function): 223 | 224 | @staticmethod 225 | def forward(ctx: Any, x: torch.Tensor, min: float, max: float) -> torch.Tensor: 226 | with torch.no_grad(): 227 | ctx.save_for_backward(x.ge(min) & x.le(max)) 228 | return torch.clamp(x, min=min, max=max) 229 | 230 | @staticmethod 231 | def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]: 232 | mask, = ctx.saved_tensors 233 | mask = mask.type_as(grad_output) 234 | return grad_output * mask + grad_output * (1 - mask) * eps, None, None 235 | 236 | 237 | def clamp(x: torch.Tensor, min: float = float("-inf"), max: float = float("+inf")) -> torch.Tensor: 238 | return LeakyClamp.apply(x, min, max) 239 | 240 | 241 | class Atanh(torch.autograd.Function): 242 | """ 243 | Numerically stable arctanh that never returns NaNs. 244 | x = clamp(x, min=-1+eps, max=1-eps) 245 | Returns atanh(x) = arctanh(x) = 0.5*(log(1+x)-log(1-x)). 246 | """ 247 | 248 | @staticmethod 249 | def forward(ctx: Any, x: torch.Tensor) -> torch.Tensor: 250 | x = clamp(x, min=-1. + 4 * eps, max=1. - 4 * eps) 251 | ctx.save_for_backward(x) 252 | res = (torch.log_(1 + x).sub_(torch.log_(1 - x))).mul_(0.5) 253 | return res 254 | 255 | @staticmethod 256 | def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor: 257 | x, = ctx.saved_tensors 258 | return grad_output / (1 - x**2) 259 | 260 | 261 | def atanh(x: torch.Tensor) -> torch.Tensor: 262 | """ 263 | Numerically stable arctanh that never returns NaNs. 264 | 265 | :param x: The input tensor. 266 | :return: 0.5 * (log(1 + x) - log(1 - x)), computed on x clamped into (-1, 1) 267 | """ 268 | return Atanh.apply(x) 269 | 270 | 271 | class Acosh(torch.autograd.Function): 272 | """ 273 | Numerically stable arccosh that never returns NaNs. 274 | Returns acosh(x) = arccosh(x) = log(x + sqrt(max(x^2 - 1, eps))).
275 | """ 276 | 277 | @staticmethod 278 | def forward(ctx: Any, x: torch.Tensor) -> torch.Tensor: 279 | with torch.no_grad(): 280 | x = clamp(x, min=1 + eps) 281 | z = sqrt(x * x - 1.) 282 | ctx.save_for_backward(z) 283 | return torch.log(x + z) 284 | 285 | @staticmethod 286 | def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor: 287 | z, = ctx.saved_tensors 288 | # z_ = clamp(z, min=eps) 289 | z_ = z 290 | return grad_output / z_ 291 | 292 | 293 | def acosh(x: torch.Tensor) -> torch.Tensor: 294 | """ 295 | Numerically stable arccosh that never returns NaNs. 296 | 297 | :param x: The input tensor. 298 | :return: log(x + sqrt(max(x^2 - 1, eps)) 299 | """ 300 | return Acosh.apply(x) -------------------------------------------------------------------------------- /large/parse.py: -------------------------------------------------------------------------------- 1 | 2 | from hypformer import HypFormer 3 | 4 | def parse_method(args, c, d, device): 5 | model = HypFormer(d, args.hidden_channels, c, graph_weight=args.graph_weight, aggregate=args.aggregate, 6 | trans_num_layers=args.trans_num_layers, trans_dropout=args.trans_dropout, trans_num_heads=args.trans_num_heads, 7 | trans_use_bn=args.trans_use_bn, trans_use_residual=args.trans_use_residual, trans_use_weight=args.trans_use_weight, trans_use_act=args.trans_use_act, 8 | gnn_num_layers=args.gnn_num_layers, gnn_dropout=args.gnn_dropout, gnn_use_bn=args.gnn_use_bn, 9 | gnn_use_residual=args.gnn_use_residual, gnn_use_weight=args.gnn_use_weight, gnn_use_init=args.gnn_use_init, gnn_use_act=args.gnn_use_act, 10 | args=args).to(device) 11 | return model 12 | 13 | 14 | 15 | def parser_add_main_args(parser): 16 | # dataset and evaluation 17 | parser.add_argument('--dataset', type=str, default='proteins', help='Name of the dataset to be used (default: proteins)') 18 | parser.add_argument('--sub_dataset', type=str, default='', help='Sub-dataset to be used (if any)') 19 | parser.add_argument('--data_dir', type=str, default='../data', help='Directory where the data is stored') 20 | parser.add_argument('--device', type=int, default=0, help='GPU device ID to be used (default: 0)') 21 | parser.add_argument('--seed', type=int, default=123, help='Random seed for reproducibility (default: 123)') 22 | parser.add_argument('--cpu', type=int, choices=[0, 1], default=0, help='Use CPU instead of GPU (0: False, 1: True)') 23 | parser.add_argument('--epochs', type=int, default=500, help='Number of training epochs (default: 500)') 24 | parser.add_argument('--runs', type=int, default=1, help='Number of distinct runs (default: 1)') 25 | parser.add_argument('--directed', type=int, choices=[0, 1], default=0, help='Set to use directed graph (0: False, 1: True)') 26 | parser.add_argument('--train_prop', type=float, default=.5, help='Proportion of training labels (default: 0.5)') 27 | parser.add_argument('--valid_prop', type=float, default=.25, help='Proportion of validation labels (default: 0.25)') 28 | parser.add_argument('--protocol', type=str, default='semi', 29 | help='Protocol for cora datasets: semi or supervised (default: semi)') 30 | parser.add_argument('--rand_split', type=int, choices=[0, 1], help='Use random splits (0: False, 1: True)') 31 | parser.add_argument('--rand_split_class', type=int, choices=[0, 1], 32 | help='Use random splits with a fixed number of labeled nodes per class (0: False, 1: True)') 33 | parser.add_argument('--label_num_per_class', type=int, default=20, 34 | help='Number of labeled nodes per class (default: 20)') 35 | 
parser.add_argument('--metric', type=str, default='acc', choices=['acc', 'rocauc', 'f1'], 36 | help='Evaluation metric (default: acc)') 37 | 38 | parser.add_argument('--use_graph', type=int, choices=[0, 1], help='Use input graph (0: False, 1: True)') 39 | parser.add_argument('--aggregate', type=str, default='add', help='Aggregate type: add or cat (default: add)') 40 | parser.add_argument('--graph_weight', type=float, default=0.8, help='Weight for the graph (default: 0.8)') 41 | parser.add_argument('--gnn_use_bn', type=int, choices=[0, 1], 42 | help='Use batch normalization in each GNN layer (0: False, 1: True)') 43 | parser.add_argument('--gnn_use_residual', type=int, choices=[0, 1], 44 | help='Use residual connections in each GNN layer (0: False, 1: True)') 45 | parser.add_argument('--gnn_use_weight', type=int, choices=[0, 1], help='Use weight for GNN convolution (0: False, 1: True)') 46 | parser.add_argument('--gnn_use_init', type=int, choices=[0, 1], 47 | help='Use initial features in each GNN layer (0: False, 1: True)') 48 | parser.add_argument('--gnn_use_act', type=int, choices=[0, 1], help='Use activation function in each GNN layer (0: False, 1: True)') 49 | parser.add_argument('--gnn_num_layers', type=int, default=2, help='Number of GNN layers (default: 2)') 50 | parser.add_argument('--gnn_dropout', type=float, default=0.0, help='Dropout rate for GNN layers (default: 0.0)') 51 | 52 | # all-pair attention (Transformer) branch 53 | parser.add_argument('--method', type=str, default='hypformer', help='method to be used (default: hypformer)') 54 | parser.add_argument('--hidden_channels', type=int, default=32, help='Number of hidden channels (default: 32)') 55 | parser.add_argument('--trans_num_heads', type=int, default=1, 56 | help='Number of heads for attention in Transformer (default: 1)') 57 | parser.add_argument('--trans_use_weight', type=int, choices=[0, 1], 58 | help='Use weight for Transformer convolution (0: False, 1: True)') 59 | parser.add_argument('--trans_use_bn', type=int, choices=[0, 1], 60 | help='Use layer normalization in Transformer (0: False, 1: True)') 61 | parser.add_argument('--trans_use_residual', type=int, choices=[0, 1], 62 | help='Use residual connections in each Transformer layer (0: False, 1: True)') 63 | parser.add_argument('--trans_use_act', type=int, choices=[0, 1], 64 | help='Use activation function in each Transformer layer (0: False, 1: True)') 65 | parser.add_argument('--trans_num_layers', type=int, default=2, help='Number of Transformer layers (default: 2)') 66 | parser.add_argument('--trans_dropout', type=float, help='Dropout rate for Transformer layers') 67 | parser.add_argument('--add_positional_encoding', type=int, default=1, 68 | help='Add positional encoding to Transformer layers (default: 1)') 69 | 70 | # display and utility 71 | parser.add_argument('--display_step', type=int, default=1, help='Frequency of display updates (default: 1)') 72 | parser.add_argument('--eval_step', type=int, default=1, help='Frequency of evaluation steps (default: 1)') 73 | parser.add_argument('--cached', type=int, choices=[0, 1], help='Use cached data for faster processing (0: False, 1: True)') 74 | parser.add_argument('--print_prop', type=int, choices=[0, 1], 75 | help='Print proportions of predicted classes (0: False, 1: True)') 76 | parser.add_argument('--save_result', type=int, choices=[0, 1], default=0, help='Save the result of the run (0: False, 1: True)') 77 | parser.add_argument('--save_model', type=int, choices=[0, 1], help='Save the model after training (0: 
False, 1: True)') 78 | parser.add_argument('--use_pretrained', type=int, choices=[0, 1], help='Use a pre-trained model (0: False, 1: True)') 79 | parser.add_argument('--save_att', type=int, choices=[0, 1], 80 | help='Save attention weights for visualization (0: False, 1: True)') 81 | parser.add_argument('--model_dir', type=str, default='checkpoints/', 82 | help='Directory to save the model checkpoints (default: checkpoints/)') 83 | 84 | # other gnn parameters (for baselines) 85 | parser.add_argument('--hops', type=int, default=2, help='Number of hops for SGC (default: 2)') 86 | parser.add_argument('--gat_heads', type=int, default=4, help='Number of attention heads for GAT (default: 4)') 87 | parser.add_argument('--out_heads', type=int, default=1, help='Number of output heads for GAT (default: 1)') 88 | 89 | # training 90 | parser.add_argument('--weight_decay', type=float, default=0.0, help='Weight decay (default: 0.0)') 91 | parser.add_argument('--hyp_weight_decay', type=float, default=0.005, 92 | help='Weight decay for Hyperbolic space (default: 0.005)') 93 | 94 | parser.add_argument('--optimizer_type', type=str, default='adam', choices=['adam', 'sgd'], 95 | help='Optimizer type for Euclidean space (default: adam)') 96 | parser.add_argument('--hyp_optimizer_type', type=str, default='radam', choices=['radam', 'rsgd'], 97 | help='Optimizer type for Hyperbolic space (default: radam)') 98 | 99 | parser.add_argument('--lr', type=float, default=0.01, help='Learning rate (default: 0.01)') 100 | parser.add_argument('--hyp_lr', type=float, default=0.01, help='Hyperbolic learning rate (default: 0.01)') 101 | 102 | parser.add_argument('--batch_size', type=int, default=10000, 103 | help='Mini-batch size for training large graphs (default: 10000)') 104 | parser.add_argument('--patience', type=int, default=200, help='Early stopping patience (default: 200)') 105 | parser.add_argument('--k_in', type=float, default=1.0, help='Curvature for input layer (default: 1.0)') 106 | parser.add_argument('--k_hidden', type=float, default=1.0, help='Curvature for hidden layer (default: 1.0)') 107 | parser.add_argument('--k_out', type=float, default=1.0, help='Curvature for output layer (default: 1.0)') 108 | parser.add_argument('--use_wandb', type=int, choices=[0, 1], help='Use Weights and Biases for logging (0: False, 1: True)') 109 | parser.add_argument('--wandb_name', type=str, default='0', help='Weights and Biases project name (default: 0)') 110 | parser.add_argument('--power_k', type=float, default=2.0, help='Power k for query and key (default: 2.0)') 111 | parser.add_argument('--attention_type', type=str, default='linear_focused', 112 | help='Attention type: linear_focused or full (default: linear_focused)') 113 | parser.add_argument('--run_id', type=str, default='0', help='Run ID (default: 0)') 114 | parser.add_argument('--save_whole_test_result', type=int, default=1, help='Save whole test result (default: 1)') 115 | parser.add_argument('--decoder_type', type=str, default='euc', help='Decoder type (default: euc)') 116 | parser.add_argument('--trans_heads_concat', type=int, default=0, help='Use heads concatenation for Transformer (default: 0)') 117 | 118 | 119 | 120 | -------------------------------------------------------------------------------- /medium/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from scipy import sparse
as sp 8 | from sklearn.metrics import f1_score, roc_auc_score 9 | from torch_sparse import SparseTensor 10 | 11 | 12 | def rand_train_test_idx(label, train_prop=0.5, valid_prop=0.25, ignore_negative=True): 13 | """randomly splits label into train/valid/test splits""" 14 | if ignore_negative: 15 | labeled_nodes = torch.where(label != -1)[0] 16 | else: 17 | labeled_nodes = label 18 | 19 | n = labeled_nodes.shape[0] 20 | train_num = int(n * train_prop) 21 | valid_num = int(n * valid_prop) 22 | 23 | perm = torch.as_tensor(np.random.permutation(n)) 24 | 25 | train_indices = perm[:train_num] 26 | val_indices = perm[train_num: train_num + valid_num] 27 | test_indices = perm[train_num + valid_num:] 28 | 29 | if not ignore_negative: 30 | return train_indices, val_indices, test_indices 31 | 32 | train_idx = labeled_nodes[train_indices] 33 | valid_idx = labeled_nodes[val_indices] 34 | test_idx = labeled_nodes[test_indices] 35 | 36 | return train_idx, valid_idx, test_idx 37 | 38 | 39 | def class_rand_splits(label, label_num_per_class, valid_num=500, test_num=1000): 40 | """use all remaining data points as test data, so test_num will not be used""" 41 | train_idx, non_train_idx = [], [] 42 | idx = torch.arange(label.shape[0]).to(label.device) 43 | class_list = label.squeeze().unique() 44 | for i in range(class_list.shape[0]): 45 | c_i = class_list[i] 46 | idx_i = idx[label.squeeze() == c_i] 47 | n_i = idx_i.shape[0] 48 | rand_idx = idx_i[torch.randperm(n_i)] 49 | train_idx += rand_idx[:label_num_per_class].tolist() 50 | non_train_idx += rand_idx[label_num_per_class:].tolist() 51 | train_idx = torch.as_tensor(train_idx) 52 | non_train_idx = torch.as_tensor(non_train_idx) 53 | non_train_idx = non_train_idx[torch.randperm(non_train_idx.shape[0])] 54 | valid_idx, test_idx = ( 55 | non_train_idx[:valid_num], 56 | non_train_idx[valid_num: valid_num + test_num], 57 | ) 58 | print(f"train:{train_idx.shape}, valid:{valid_idx.shape}, test:{test_idx.shape}") 59 | split_idx = {"train": train_idx, "valid": valid_idx, "test": test_idx} 60 | return split_idx 61 | 62 | 63 | def normalize_feat(mx): 64 | """Row-normalize np or sparse matrix.""" 65 | rowsum = np.array(mx.sum(1)) 66 | r_inv = np.power(rowsum, -1).flatten() 67 | r_inv[np.isinf(r_inv)] = 0.0 68 | r_mat_inv = sp.diags(r_inv) 69 | mx = r_mat_inv.dot(mx) 70 | return mx 71 | 72 | 73 | def eval_acc(y_true, y_pred): 74 | acc_list = [] 75 | y_true = y_true.detach().cpu().numpy() 76 | y_pred = y_pred.argmax(dim=-1, keepdim=True).detach().cpu().numpy() 77 | 78 | for i in range(y_true.shape[1]): 79 | is_labeled = y_true[:, i] == y_true[:, i] 80 | correct = y_true[is_labeled, i] == y_pred[is_labeled, i] 81 | acc_list.append(float(np.sum(correct)) / len(correct)) 82 | 83 | return sum(acc_list) / len(acc_list) 84 | 85 | 86 | @torch.no_grad() 87 | def evaluate(model, dataset, split_idx, eval_func, criterion, args, result=None): 88 | if result is not None: 89 | out = result 90 | else: 91 | model.eval() 92 | if args.method == "fast_transgnn" or args.method == "glcn": 93 | out, _ = model(dataset) 94 | else: 95 | out = model(dataset) 96 | 97 | train_acc = eval_func(dataset.label[split_idx["train"]], out[split_idx["train"]]) 98 | valid_acc = eval_func(dataset.label[split_idx["valid"]], out[split_idx["valid"]]) 99 | test_acc = eval_func(dataset.label[split_idx["test"]], out[split_idx["test"]]) 100 | if args.dataset in ( 101 | "yelp-chi", 102 | "deezer-europe", 103 | "twitch-e", 104 | "fb100", 105 | "ogbn-proteins", 106 | ): 107 | if dataset.label.shape[1] == 1: 108 | 
true_label = F.one_hot(dataset.label, dataset.label.max() + 1).squeeze(1) 109 | else: 110 | true_label = dataset.label 111 | valid_loss = criterion( 112 | out[split_idx["valid"]], 113 | true_label.squeeze(1)[split_idx["valid"]].to(torch.float), 114 | ) 115 | else: 116 | out = F.log_softmax(out, dim=1) 117 | valid_loss = criterion( 118 | out[split_idx["valid"]], dataset.label.squeeze(1)[split_idx["valid"]] 119 | ) 120 | 121 | return train_acc, valid_acc, test_acc, valid_loss, out 122 | 123 | 124 | def load_fixed_splits(dataset, name, protocol): 125 | splits_lst = [] 126 | if name in ["cora", "citeseer", "pubmed", "airport", "disease"] and protocol == "semi": 127 | splits = {} 128 | splits["train"] = torch.as_tensor(dataset.train_idx) 129 | splits["valid"] = torch.as_tensor(dataset.valid_idx) 130 | splits["test"] = torch.as_tensor(dataset.test_idx) 131 | splits_lst.append(splits) 132 | elif name in ["chameleon", "squirrel"]: 133 | file_path = f"../../data/wiki_new/{name}/{name}_filtered.npz" 134 | data = np.load(file_path) 135 | train_masks = data["train_masks"] # (10, N), 10 splits 136 | val_masks = data["val_masks"] 137 | test_masks = data["test_masks"] 138 | N = train_masks.shape[1] 139 | 140 | node_idx = np.arange(N) 141 | for i in range(10): 142 | splits = {} 143 | splits["train"] = torch.as_tensor(node_idx[train_masks[i]]) 144 | splits["valid"] = torch.as_tensor(node_idx[val_masks[i]]) 145 | splits["test"] = torch.as_tensor(node_idx[test_masks[i]]) 146 | splits_lst.append(splits) 147 | 148 | elif name in ["film"]: 149 | for i in range(10): 150 | splits_file_path = ( 151 | "../../data/geom-gcn/{}/{}".format(name, name) 152 | + "_split_0.6_0.2_" 153 | + str(i) 154 | + ".npz" 155 | ) 156 | splits = {} 157 | with np.load(splits_file_path) as splits_file: 158 | splits["train"] = torch.BoolTensor(splits_file["train_mask"]) 159 | splits["valid"] = torch.BoolTensor(splits_file["val_mask"]) 160 | splits["test"] = torch.BoolTensor(splits_file["test_mask"]) 161 | splits_lst.append(splits) 162 | else: 163 | raise NotImplementedError 164 | 165 | return splits_lst 166 | 167 | 168 | def split_data(labels, val_prop, test_prop, seed=1234): 169 | np.random.seed(seed) 170 | nb_nodes = labels.shape[0] 171 | all_idx = np.arange(nb_nodes) 172 | pos_idx = labels.nonzero()[0] 173 | neg_idx = (1. 
- labels).nonzero()[0] 174 | np.random.shuffle(pos_idx) 175 | np.random.shuffle(neg_idx) 176 | pos_idx = pos_idx.tolist() 177 | neg_idx = neg_idx.tolist() 178 | nb_pos_neg = min(len(pos_idx), len(neg_idx)) 179 | nb_val = round(val_prop * nb_pos_neg) 180 | nb_test = round(test_prop * nb_pos_neg) 181 | idx_val_pos, idx_test_pos, idx_train_pos = pos_idx[:nb_val], pos_idx[nb_val:nb_val + nb_test], pos_idx[ 182 | nb_val + nb_test:] 183 | idx_val_neg, idx_test_neg, idx_train_neg = neg_idx[:nb_val], neg_idx[nb_val:nb_val + nb_test], neg_idx[ 184 | nb_val + nb_test:] 185 | return idx_val_pos + idx_val_neg, idx_test_pos + idx_test_neg, idx_train_pos + idx_train_neg 186 | -------------------------------------------------------------------------------- /medium/examples/5-runs/run_airport.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python main.py \ 3 | --dataset airport \ 4 | --method hypformer \ 5 | --lr 0.005 \ 6 | --weight_decay 1e-3 \ 7 | --hidden_channels 256 \ 8 | --use_graph 1 \ 9 | --gnn_dropout 0.4 \ 10 | --gnn_use_bn 1 \ 11 | --gnn_num_layers 3 \ 12 | --gnn_use_init 1 \ 13 | --trans_num_layers 1 \ 14 | --trans_use_residual 1 \ 15 | --trans_use_bn 0 \ 16 | --graph_weight 0.2 \ 17 | --trans_dropout 0.2 \ 18 | --device 0 \ 19 | --runs 5 \ 20 | --power_k 2.0 \ 21 | --epochs 5000 \ 22 | --k_in 1.0 \ 23 | --k_out 2.0 \ 24 | --data_dir ../data \ 25 | --decoder_type hyp 26 | -------------------------------------------------------------------------------- /medium/examples/5-runs/run_citeseer.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Highest Test: 74.40 ± 0.26 Final Test: 73.32 ± 0.50 4 | # activate conda env before running 5 | 6 | python main.py \ 7 | --dataset citeseer \ 8 | --method hypformer \ 9 | --lr 0.005 \ 10 | --hidden_channels 256 \ 11 | --use_graph 1 \ 12 | --weight_decay 0.005 \ 13 | --gnn_num_layers 5 \ 14 | --graph_weight 0.4 \ 15 | --gnn_dropout 0.5 \ 16 | --gnn_use_weight 0 \ 17 | --gnn_use_bn 0 \ 18 | --gnn_use_residual 1 \ 19 | --gnn_use_init 0 \ 20 | --gnn_use_act 1 \ 21 | --trans_num_layers 1 \ 22 | --trans_dropout 0.5 \ 23 | --trans_use_residual 1 \ 24 | --trans_use_weight 1 \ 25 | --trans_num_heads 1 \ 26 | --trans_use_bn 0 \ 27 | --trans_use_act 0 \ 28 | --rand_split_class 1 \ 29 | --valid_num 500 \ 30 | --test_num 1000 \ 31 | --no_feat_norm 1 \ 32 | --add_positional_encoding 1 \ 33 | --epochs 500 \ 34 | --seed 123 \ 35 | --device 0 \ 36 | --runs 5 \ 37 | --power_k 3.0 \ 38 | --k_in 1.0 \ 39 | --k_out 1.0 \ 40 | --attention_type linear_focused \ 41 | --decoder_type euc \ 42 | --save_result 0 43 | -------------------------------------------------------------------------------- /medium/examples/5-runs/run_cora.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python main.py \ 4 | --dataset cora \ 5 | --method hypformer \ 6 | --lr 0.005 \ 7 | --hidden_channels 256 \ 8 | --use_graph 1 \ 9 | --gnn_num_layers 6 \ 10 | --graph_weight 0.5 \ 11 | --weight_decay 0.005 \ 12 | --gnn_use_residual 1 \ 13 | --gnn_dropout 0.5 \ 14 | --gnn_use_bn 0 \ 15 | --gnn_use_init 0 \ 16 | --trans_num_layers 1 \ 17 | --trans_dropout 0.5 \ 18 | --trans_use_residual 1 \ 19 | --rand_split_class 1 \ 20 | --no_feat_norm 1 \ 21 | --trans_use_bn 0 \ 22 | --trans_num_heads 1 \ 23 | --valid_num 500 \ 24 | --test_num 1000 \ 25 | --no_feat_norm 1 \ 26 | --epochs 500 \ 27 | --seed 123 \ 28 | --device 0 \ 29 | --runs 5 \ 30 | 
--power_k 1.0 \ 31 | --attention_type linear_focused \ 32 | --decoder_type euc \ 33 | --k_in 1.0 \ 34 | --k_out 3.0 \ 35 | --data_dir ../data 36 | -------------------------------------------------------------------------------- /medium/examples/5-runs/run_pubmed.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python main.py \ 4 | --dataset pubmed \ 5 | --method hypformer \ 6 | --lr 0.005 \ 7 | --weight_decay 5e-4 \ 8 | --hidden_channels 256 \ 9 | --use_graph 1 \ 10 | --gnn_num_layers 4 \ 11 | --graph_weight 0.8 \ 12 | --gnn_use_residual 1 \ 13 | --gnn_use_weight 0 \ 14 | --gnn_dropout 0.5 \ 15 | --gnn_use_bn 0 \ 16 | --gnn_use_init 0 \ 17 | --gnn_use_act 0 \ 18 | --trans_num_layers 1 \ 19 | --trans_use_weight 1 \ 20 | --trans_use_act 0 \ 21 | --trans_dropout 0.5 \ 22 | --trans_use_residual 1 \ 23 | --rand_split_class 1 \ 24 | --valid_num 500 \ 25 | --test_num 1000 \ 26 | --no_feat_norm 1 \ 27 | --epochs 500 \ 28 | --seed 123 \ 29 | --device 0 \ 30 | --runs 5 \ 31 | --power_k 3.0 \ 32 | --k_in 1.0 \ 33 | --k_out 2.0 \ 34 | --attention_type linear_focused \ 35 | --decoder_type hyp \ 36 | --hyp_lr 0.005 \ 37 | --hyp_weight_decay 5e-4 \ 38 | --data_dir ../data \ 39 | --save_result 0 -------------------------------------------------------------------------------- /medium/examples/airport.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python main.py \ 4 | --dataset airport \ 5 | --method hypformer \ 6 | --lr 0.005 \ 7 | --weight_decay 1e-3 \ 8 | --hidden_channels 256 \ 9 | --use_graph 1 \ 10 | --gnn_dropout 0.4 \ 11 | --gnn_use_bn 1 \ 12 | --gnn_num_layers 3 \ 13 | --gnn_use_init 1 \ 14 | --trans_num_layers 1 \ 15 | --trans_use_residual 1 \ 16 | --trans_use_bn 0 \ 17 | --graph_weight 0.2 \ 18 | --trans_dropout 0.2 \ 19 | --device 0 \ 20 | --runs 1 \ 21 | --power_k 2.0 \ 22 | --epochs 5000 \ 23 | --k_in 1.0 \ 24 | --k_out 2.0 \ 25 | --data_dir ../data \ 26 | --decoder_type hyp 27 | -------------------------------------------------------------------------------- /medium/examples/citeseer.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Highest Test: 74.40 ± 0.26 Final Test: 73.32 ± 0.50 4 | # activate conda env before running 5 | 6 | python main.py \ 7 | --dataset citeseer \ 8 | --method hypformer \ 9 | --lr 0.005 \ 10 | --hidden_channels 256 \ 11 | --use_graph 1 \ 12 | --weight_decay 0.005 \ 13 | --gnn_num_layers 5 \ 14 | --graph_weight 0.4 \ 15 | --gnn_dropout 0.5 \ 16 | --gnn_use_weight 0 \ 17 | --gnn_use_bn 0 \ 18 | --gnn_use_residual 1 \ 19 | --gnn_use_init 0 \ 20 | --gnn_use_act 1 \ 21 | --trans_num_layers 1 \ 22 | --trans_dropout 0.5 \ 23 | --trans_use_residual 1 \ 24 | --trans_use_weight 1 \ 25 | --trans_num_heads 1 \ 26 | --trans_use_bn 0 \ 27 | --trans_use_act 0 \ 28 | --rand_split_class 1 \ 29 | --valid_num 500 \ 30 | --test_num 1000 \ 31 | --no_feat_norm 1 \ 32 | --add_positional_encoding 1 \ 33 | --epochs 500 \ 34 | --seed 123 \ 35 | --device 0 \ 36 | --runs 1 \ 37 | --power_k 3.0 \ 38 | --k_in 1.0 \ 39 | --k_out 1.0 \ 40 | --attention_type linear_focused \ 41 | --decoder_type euc \ 42 | --save_result 0 43 | 
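# Note: --rand_split_class draws class-balanced random splits (20 labeled nodes per class by default, see parse.py), i.e. the standard semi-supervised protocol with the 500 validation / 1000 test nodes set above.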
-------------------------------------------------------------------------------- /medium/examples/cora.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python main.py \ 4 | --dataset cora \ 5 | --method hypformer \ 6 | --lr 0.005 \ 7 | --hidden_channels 256 \ 8 | --use_graph 1 \ 9 | --gnn_num_layers 3 \ 10 | --graph_weight 0.5 \ 11 | --weight_decay 0.005 \ 12 | --gnn_use_residual 1 \ 13 | --gnn_dropout 0.5 \ 14 | --gnn_use_bn 0 \ 15 | --gnn_use_init 0 \ 16 | --trans_num_layers 1 \ 17 | --trans_dropout 0.5 \ 18 | --trans_use_residual 1 \ 19 | --rand_split_class 1 \ 20 | --no_feat_norm 1 \ 21 | --trans_use_bn 0 \ 22 | --trans_num_heads 1 \ 23 | --valid_num 500 \ 24 | --test_num 1000 \ 25 | --no_feat_norm 1 \ 26 | --epochs 500 \ 27 | --seed 123 \ 28 | --device 0 \ 29 | --runs 1 \ 30 | --power_k 1.0 \ 31 | --attention_type linear_focused \ 32 | --decoder_type euc \ 33 | --k_in 1.0 \ 34 | --k_out 3.0 \ 35 | --data_dir ../data 36 | -------------------------------------------------------------------------------- /medium/examples/pubmed.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python main.py \ 4 | --dataset pubmed \ 5 | --method hypformer \ 6 | --lr 0.005 \ 7 | --weight_decay 5e-4 \ 8 | --hidden_channels 256 \ 9 | --use_graph 1 \ 10 | --gnn_num_layers 4 \ 11 | --graph_weight 0.8 \ 12 | --gnn_use_residual 1 \ 13 | --gnn_use_weight 0 \ 14 | --gnn_dropout 0.5 \ 15 | --gnn_use_bn 0 \ 16 | --gnn_use_init 0 \ 17 | --gnn_use_act 0 \ 18 | --trans_num_layers 1 \ 19 | --trans_use_weight 1 \ 20 | --trans_use_act 0 \ 21 | --trans_dropout 0.5 \ 22 | --trans_use_residual 1 \ 23 | --rand_split_class 1 \ 24 | --valid_num 500 \ 25 | --test_num 1000 \ 26 | --no_feat_norm 1 \ 27 | --epochs 500 \ 28 | --seed 123 \ 29 | --device 0 \ 30 | --runs 1 \ 31 | --power_k 3.0 \ 32 | --k_in 1.0 \ 33 | --k_out 2.0 \ 34 | --attention_type linear_focused \ 35 | --decoder_type hyp \ 36 | --hyp_lr 0.005 \ 37 | --hyp_weight_decay 5e-4 \ 38 | --data_dir ../data \ 39 | --save_result 0 -------------------------------------------------------------------------------- /medium/logger.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | 4 | 5 | class Logger(object): 6 | def __init__(self, runs, info=None): 7 | self.info = info 8 | self.results = [[] for _ in range(runs)] 9 | 10 | def add_result(self, run, result): 11 | assert len(result) == 4 12 | assert run >= 0 and run < len(self.results) 13 | self.results[run].append(result) 14 | 15 | @staticmethod 16 | def get_results_string(best_result): 17 | result_string = '' 18 | r = best_result[:, 0] 19 | result_string += f'Highest Train: {r.mean():.2f} ± {r.std():.2f}\t' 20 | r = best_result[:, 1] 21 | result_string += f'Highest Test: {r.mean():.2f} ± {r.std():.2f}\t' 22 | r = best_result[:, 2] 23 | result_string += f'Highest Valid: {r.mean():.2f} ± {r.std():.2f}\t' 24 | r = best_result[:, 3] 25 | result_string += f' Final Train: {r.mean():.2f} ± {r.std():.2f}\t' 26 | r = best_result[:, 4] 27 | result_string += f' Final Test: {r.mean():.2f} ± {r.std():.2f}' 28 | 29 | return result_string 30 | 
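# Column convention: each per-epoch result is (train_acc, valid_acc, test_acc, valid_loss); the aggregated best_result tensor used by get_results_string and print_statistics has columns (highest train, highest test, highest valid, final train, final test), where "final" means measured at the best-validation epoch.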
31 | def print_statistics(self, run=None, mode='max_acc'): 32 | if run is not None: 33 | result = 100 * torch.tensor(self.results[run]) 34 | argmax = result[:, 1].argmax().item() 35 | argmin = result[:, 3].argmin().item() 36 | if mode == 'max_acc': 37 | ind = argmax 38 | else: 39 | ind = argmin 40 | 41 | print_str = f'Run {run + 1:02d}:' + \ 42 | f'Highest Train: {result[:, 0].max():.2f} ' + \ 43 | f'Highest Valid: {result[:, 1].max():.2f} ' + \ 44 | f'Highest Test: {result[:, 2].max():.2f} ' + \ 45 | f'Chosen epoch: {ind + 1}\n' + \ 46 | f'Final Train: {result[ind, 0]:.2f} ' + \ 47 | f'Final Test: {result[ind, 2]:.2f}' 48 | print(print_str) 49 | self.test = result[ind, 2] 50 | 51 | else: 52 | best_results = [] 53 | max_val_epoch = 0 54 | for r in self.results: 55 | r = 100 * torch.tensor(r) 56 | train1 = r[:, 0].max().item() 57 | test1 = r[:, 2].max().item() 58 | valid = r[:, 1].max().item() 59 | if mode == 'max_acc': 60 | train2 = r[r[:, 1].argmax(), 0].item() 61 | test2 = r[r[:, 1].argmax(), 2].item() 62 | max_val_epoch = r[:, 1].argmax().item() 63 | else: 64 | train2 = r[r[:, 3].argmin(), 0].item() 65 | test2 = r[r[:, 3].argmin(), 2].item() 66 | best_results.append((train1, test1, valid, train2, test2)) 67 | 68 | best_result = torch.tensor(best_results) 69 | 70 | print_str = f'{len(self.results)} runs: ' 71 | r = best_result[:, 0] 72 | print_str += f'Highest Train: {r.mean():.2f} ± {r.std():.2f} ' 73 | print_str += f'Highest val epoch:{max_val_epoch}\n' 74 | r = best_result[:, 1] 75 | print_str += f'Highest Test: {r.mean():.2f} ± {r.std():.2f} ' 76 | r = best_result[:, 4] 77 | print_str += f'Final Test: {r.mean():.2f} ± {r.std():.2f}' 78 | 79 | self.test = r.mean() 80 | return print_str 81 | 82 | def output(self, out_path, info): 83 | with open(out_path, 'a', encoding='utf-8') as f: 84 | f.write(info) 85 | f.write(f'test acc:{self.test}\n') 86 | 87 | def save(self, params, results, filename): 88 | with open(filename, 'a', encoding='utf-8') as file: 89 | file.write(f"{results}\n") 90 | file.write(f"{params}\n") 91 | file.write('==' * 50) 92 | file.write('\n') 93 | file.write('\n') 94 | 95 | 96 | def save_result(args, results): 97 | if not os.path.exists(f'results/{args.dataset}'): 98 | os.makedirs(f'results/{args.dataset}') 99 | filename = f'results/{args.dataset}/{args.method}.csv' 100 | print(f"Saving results to {filename}") 101 | with open(f"{filename}", 'a+') as write_obj: 102 | # Get a dictionary of arguments and their values 103 | args_dict = vars(args) 104 | # Write each argument and its value to the file 105 | for arg, value in args_dict.items(): 106 | write_obj.write(f"{arg}: {value} ") 107 | # Write the results 108 | write_obj.write(f"{results.mean():.2f} $\pm$ {results.std():.2f} \n") 109 | -------------------------------------------------------------------------------- /medium/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import os 4 | import random 5 | import sys 6 | import warnings 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from data_utils import class_rand_splits, eval_acc, evaluate, load_fixed_splits 13 | from dataset import load_nc_dataset 14 | from logger import Logger 15 | from parse import parse_method, parser_add_main_args 16 | from sklearn.neighbors import kneighbors_graph 17 | 18 | from manifolds.hyp_layer import Optimizer 19 | 20 | warnings.filterwarnings('ignore') 21 | 22 | 23 | def mkdirs(path): 24 | if not os.path.exists(path): 25 | os.makedirs(path) 26 | return path 27 | 28 | 29 | def fix_seed(seed): 30 | random.seed(seed) 31 | np.random.seed(seed) 32 | torch.manual_seed(seed) 33 | torch.cuda.manual_seed(seed) 34 | torch.backends.cudnn.deterministic = True 35 | torch.backends.cudnn.benchmark = False 36 | torch.use_deterministic_algorithms(True, warn_only=True) 37 | 38 | 39 | ### Parse args ### 40 | parser = argparse.ArgumentParser(description='Medium Data Training Pipeline') 41 | parser_add_main_args(parser) 42 | args = parser.parse_args() 43 | print('====' * 20) 44 | print(args) 45 | fix_seed(args.seed) 46 | 
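# Reproducibility note: fix_seed above pins cuDNN to deterministic kernels (slower but repeatable); warn_only=True makes ops without deterministic implementations warn instead of raising.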
47 | if args.cpu: 48 | device = torch.device("cpu") 49 | print('>> Using CPU') 50 | else: 51 | device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu") 52 | print('>> Using GPU: ' + str(args.device)) 53 | 54 | ### Load and preprocess data ### 55 | dataset = load_nc_dataset(args) 56 | 57 | if len(dataset.label.shape) == 1: 58 | dataset.label = dataset.label.unsqueeze(1) 59 | dataset.label = dataset.label.to(device) 60 | 61 | dataset_name = args.dataset 62 | 63 | if args.rand_split: 64 | print('>> loading random splits ...') 65 | split_idx_lst = [dataset.get_idx_split(train_prop=args.train_prop, valid_prop=args.valid_prop) 66 | for _ in range(args.runs)] 67 | elif args.rand_split_class: 68 | print('>> loading random class splits ...') 69 | split_idx_lst = [class_rand_splits( 70 | dataset.label, args.label_num_per_class, args.valid_num, args.test_num)] 71 | else: 72 | print('>> loading fixed splits ...') 73 | split_idx_lst = load_fixed_splits( 74 | dataset, name=args.dataset, protocol=args.protocol) 75 | 76 | if args.dataset in ('mini', '20news'): 77 | adj_knn = kneighbors_graph(dataset.graph['node_feat'], n_neighbors=args.knn_num, include_self=True) 78 | edge_index = torch.tensor(adj_knn.nonzero(), dtype=torch.long) 79 | dataset.graph['edge_index'] = edge_index 80 | 81 | n = dataset.graph['num_nodes'] 82 | num_class = max(dataset.label.max().item() + 1, dataset.label.shape[1]) 83 | num_class = int(num_class) 84 | node_feat_dim = dataset.graph['node_feat'].shape[1] 85 | args.in_channels = node_feat_dim 86 | args.out_channels = num_class 87 | 88 | dataset.graph['edge_index'] = dataset.graph['edge_index'].to(device) 89 | dataset.graph['node_feat'] = dataset.graph['node_feat'].to(device) 90 | 91 | print(f">> num nodes {n} | num classes {num_class} | num node feats {node_feat_dim}") 92 | 93 | if args.dataset in ('deezer-europe',): 94 | criterion = nn.BCEWithLogitsLoss() 95 | else: 96 | criterion = nn.NLLLoss() 97 | 98 | eval_func = eval_acc 99 | 
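# nn.NLLLoss expects log-probabilities, hence the F.log_softmax applied to model outputs in the training loop below; BCEWithLogitsLoss (used for deezer-europe) takes raw logits instead.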
100 | # =============================================================================== 101 | logger = Logger(args.runs, args) 102 | for run in range(args.runs): 103 | print(f'🔥Run {run + 1}/{args.runs}') 104 | if args.dataset in ['cora', 'citeseer', 'pubmed', 'airport', 'disease'] and args.protocol == 'semi': 105 | split_idx = split_idx_lst[0] 106 | else: 107 | split_idx = split_idx_lst[run] 108 | train_idx = split_idx['train'].to(device) # get train split 109 | model = parse_method(args, device) # load model 110 | optimizer = Optimizer(model, args) # load optimizer 111 | 112 | best_val = float('-inf') 113 | patience = 0 114 | for epoch in range(args.epochs): 115 | model.train() 116 | optimizer.zero_grad() 117 | emb = None 118 | out = model(dataset) 119 | out = F.log_softmax(out, dim=1) 120 | loss = criterion( 121 | out[train_idx], dataset.label.squeeze(1)[train_idx]) 122 | loss.backward() 123 | optimizer.step() 124 | 125 | result = evaluate(model, dataset, split_idx, eval_func, criterion, args) 126 | logger.add_result(run, result[:-1]) 127 | 128 | if result[1] > best_val: 129 | best_val = result[1] 130 | patience = 0 131 | else: 132 | patience += 1 133 | if patience >= args.patience: 134 | break 135 | 136 | if epoch % args.display_step == 0: 137 | print(f'Epoch: {epoch:02d}, ' 138 | f'Loss: {loss:.4f}, ' 139 | f'Train: {100 * result[0]:.2f}%, ' 140 | f'Valid: {100 * result[1]:.2f}%, ' 141 | f'Test: {100 * result[2]:.2f}%') 142 | logger.print_statistics(run) 143 | 144 | if args.output_attention: 145 | attentions = model.get_attentions(dataset.graph['node_feat'].to(device)) 146 | np.save(f'results/att/{args.dataset}_{args.method}_{args.hidden_channels}_attentions.npy', attentions.detach().cpu().numpy()) 147 | # delete the model and optimizer and start a new run 148 | del model, optimizer 149 | 150 | results = logger.print_statistics() # computed for any run count; logger.save below relies on it 151 | if args.runs > 1: 152 | print('====================') 153 | print(results) 154 | print('====================') 155 | 156 | 157 | out_folder = 'results' 158 | if not os.path.exists(out_folder): 159 | os.mkdir(out_folder) 160 | 161 | if args.save_result: 162 | mkdirs(f'results/{args.dataset}') 163 | csvfilename = f'results/{args.dataset}/{args.dataset}_{args.method}_{args.hidden_channels}.csv' 164 | logger.save(vars(args), results, csvfilename) 165 | -------------------------------------------------------------------------------- /medium/manifolds/__init__.py: -------------------------------------------------------------------------------- 1 | from .hyp_layer import (HypLayerNorm, HypDropout, 2 | HypActivation, HypNormalization, 3 | Optimizer, HypCLS, HypLinear) 4 | from .lorentz import Lorentz -------------------------------------------------------------------------------- /medium/manifolds/hyp_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional 5 | import torch.nn.init as init 6 | import math 7 | from geoopt import ManifoldParameter 8 | from geoopt.optim.rsgd import RiemannianSGD 9 | from geoopt.optim.radam import RiemannianAdam 10 | 11 | class HypLayerNorm(nn.Module): 12 | def __init__(self, manifold, in_features, manifold_out=None): 13 | super(HypLayerNorm, self).__init__() 14 | self.in_features = in_features 15 | self.manifold = manifold 16 | self.manifold_out = manifold_out 17 | self.layer = nn.LayerNorm(self.in_features) 18 | self.reset_parameters() 19 | 20 | def reset_parameters(self): 21 | self.layer.reset_parameters() 22 | 23 | def forward(self, x): 24 | x_space = x[..., 1:] 25 | x_space = self.layer(x_space) 26 | x_time = ((x_space**2).sum(dim=-1, keepdims=True) + self.manifold.k).sqrt() 27 | x = torch.cat([x_time, x_space], dim=-1) 28 | 29 | # Adjust for a different manifold if specified 30 | if self.manifold_out is not None: 31 | x = x * (self.manifold_out.k / self.manifold.k).sqrt() 32 | return x 33 | 
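# This pattern recurs in every layer of this file: transform the space-like coordinates, then recompute the time coordinate as x_time = sqrt(||x_space||^2 + k), so the output satisfies the Lorentz constraint -x_time^2 + ||x_space||^2 = -k.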
34 | class HypNormalization(nn.Module): 35 | def __init__(self, manifold, manifold_out=None): 36 | super(HypNormalization, self).__init__() 37 | self.manifold = manifold 38 | self.manifold_out = manifold_out 39 | 40 | def forward(self, x): 41 | x_space = x[..., 1:] 42 | x_space = x_space / x_space.norm(dim=-1, keepdim=True) 43 | x_time = ((x_space**2).sum(dim=-1, keepdims=True) + self.manifold.k).sqrt() 44 | x = torch.cat([x_time, x_space], dim=-1) 45 | 46 | # Adjust for a different manifold if specified 47 | if self.manifold_out is not None: 48 | x = x * (self.manifold_out.k / self.manifold.k).sqrt() 49 | return x 50 | 51 | class HypActivation(nn.Module): 52 | def __init__(self, manifold, activation, manifold_out=None): 53 | super(HypActivation, self).__init__() 54 | self.manifold = manifold 55 | self.manifold_out = manifold_out 56 | self.activation = activation 57 | 58 | def forward(self, x): 59 | x_space = x[..., 1:] 60 | x_space = self.activation(x_space) 61 | x_time = ((x_space**2).sum(dim=-1, keepdims=True) + self.manifold.k).sqrt() 62 | x = torch.cat([x_time, x_space], dim=-1) 63 | 64 | # Adjust for a different manifold if specified 65 | if self.manifold_out is not None: 66 | x = x * (self.manifold_out.k / self.manifold.k).sqrt() 67 | return x 68 | 69 | class HypDropout(nn.Module): 70 | def __init__(self, manifold, dropout, manifold_out=None): 71 | super(HypDropout, self).__init__() 72 | self.manifold = manifold 73 | self.manifold_out = manifold_out 74 | self.dropout = nn.Dropout(dropout) 75 | 76 | def forward(self, x, training=False): 77 | if training: 78 | x_space = x[..., 1:] 79 | x_space = self.dropout(x_space) 80 | x_time = ((x_space**2).sum(dim=-1, keepdims=True) + self.manifold.k).sqrt() 81 | x = torch.cat([x_time, x_space], dim=-1) 82 | 83 | # Adjust for a different manifold if specified 84 | if self.manifold_out is not None: 85 | x = x * (self.manifold_out.k / self.manifold.k).sqrt() 86 | return x 87 | 88 | class HypLinear(nn.Module): 89 | """ 90 | Parameters: 91 | manifold (manifold): The manifold to use for the linear transformation. 92 | in_features (int): The size of each input sample. 93 | out_features (int): The size of each output sample. 94 | bias (bool, optional): If set to False, the layer will not learn an additive bias. Default is True. 95 | manifold_out (manifold, optional): If given, the output is rescaled onto this manifold. Default is None. 96 | """ 97 | 98 | def __init__(self, manifold, in_features, out_features, bias=True, manifold_out=None): 99 | super().__init__() 100 | self.in_features = in_features + 1 # + 1 for time dimension 101 | self.out_features = out_features 102 | self.bias = bias 103 | self.manifold = manifold 104 | self.manifold_out = manifold_out 105 | 106 | self.linear = nn.Linear(self.in_features, self.out_features, bias=bias) 107 | self.reset_parameters() 108 | 109 | def reset_parameters(self): 110 | init.xavier_uniform_(self.linear.weight, gain=math.sqrt(2)) 111 | if self.linear.bias is not None: init.constant_(self.linear.bias, 0) 112 | 113 | def forward(self, x, x_manifold='hyp'): 114 | if x_manifold != 'hyp': 115 | x = torch.cat([torch.ones_like(x)[..., 0:1], x], dim=-1) 116 | x = self.manifold.expmap0(x) 117 | x_space = self.linear(x) 118 | 119 | x_time = ((x_space**2).sum(dim=-1, keepdims=True) + self.manifold.k).sqrt() 120 | x = torch.cat([x_time, x_space], dim=-1) 121 | 122 | # Adjust for a different manifold if specified 123 | if self.manifold_out is not None: 124 | x = x * (self.manifold_out.k / self.manifold.k).sqrt() 125 | return x 126 | 
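# Quick shape check (a sketch with hypothetical sizes, assuming manifold is this repo's Lorentz manifold):
#   lin = HypLinear(manifold, in_features=16, out_features=32)
#   y = lin(torch.randn(8, 16), x_manifold='euc')  # y.shape == (8, 33): one time dimension plus 32 space dimensions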
127 | class HypCLS(nn.Module): 128 | def __init__(self, manifold, in_channels, out_channels, bias=True): 129 | super().__init__() 130 | self.manifold = manifold 131 | self.in_channels = in_channels 132 | self.out_channels = out_channels 133 | cls_emb = self.manifold.random_normal((self.out_channels, self.in_channels + 1), mean=0, std=1. / math.sqrt(self.in_channels + 1)) 134 | self.cls = ManifoldParameter(cls_emb, self.manifold, requires_grad=True) 135 | # keep a (possibly frozen) zero bias so that forward() can always add it 136 | self.bias = nn.Parameter(torch.zeros(self.out_channels), requires_grad=bias) 137 | 138 | def cinner(self, x, y): 139 | x = x.clone() 140 | x.narrow(-1, 0, 1).mul_(-1) 141 | return x @ y.transpose(-1, -2) 142 | 143 | def forward(self, x, x_manifold='hyp', return_type='neg_dist'): 144 | if x_manifold != 'hyp': 145 | x = self.manifold.expmap0(torch.cat([torch.zeros_like(x)[..., 0:1], x], dim=-1)) # project to Lorentz 146 | 147 | dist = -2 * self.manifold.k - 2 * self.manifold.cinner(x, self.cls) + self.bias # squared Lorentzian distance to each class embedding 148 | 149 | # dist = self.manifold.cdist(x, self.cls) + self.bias 150 | # dist = dist.clamp(min=1e-6) 151 | 152 | if return_type == 'neg_dist': 153 | return - dist 154 | elif return_type == 'prob': 155 | return 10 / (1.0 + dist) 156 | elif return_type == 'neg_log_prob': 157 | return - 10*torch.log(1.0 + dist) 158 | else: 159 | raise NotImplementedError 160 | 161 | 162 | class Optimizer(object): 163 | def __init__(self, model, args): 164 | # Extract optimizer types and parameters from arguments 165 | euc_optimizer_type = getattr(args, 'euc_optimizer_type', args.optimizer_type) # Euclidean optimizer type 166 | hyp_optimizer_type = getattr(args, 'hyp_optimizer_type', args.hyp_optimizer_type) # Hyperbolic optimizer type 167 | euc_lr = getattr(args, 'euc_lr', args.lr) # Euclidean learning rate 168 | hyp_lr = getattr(args, 'hyp_lr', args.hyp_lr) # Hyperbolic learning rate 169 | euc_weight_decay = getattr(args, 'euc_weight_decay', args.weight_decay) # Euclidean weight decay 170 | hyp_weight_decay = getattr(args, 'hyp_weight_decay', args.hyp_weight_decay) # Hyperbolic weight decay 171 | 172 | # Separate parameters for Euclidean and Hyperbolic parts of the model 173 | euc_params = [p for n, p in model.named_parameters() if p.requires_grad and not isinstance(p, ManifoldParameter)] # Euclidean parameters 174 | hyp_params = [p for n, p in model.named_parameters() if p.requires_grad and isinstance(p, ManifoldParameter)] # Hyperbolic parameters 175 | 176 | # Print the number of Euclidean and Hyperbolic parameters 177 | # print(f">> Number of Euclidean parameters: {sum(p.numel() for p in euc_params)}") 178 | # print(f">> Number of Hyperbolic parameters: {sum(p.numel() for p in hyp_params)}") 179 | # Initialize Euclidean optimizer 180 | 181 | if euc_optimizer_type == 'adam': 182 | optimizer_euc = torch.optim.Adam(euc_params, lr=euc_lr, weight_decay=euc_weight_decay) 183 | elif euc_optimizer_type == 'sgd': 184 | optimizer_euc = torch.optim.SGD(euc_params, lr=euc_lr, weight_decay=euc_weight_decay) 185 | else: 186 | raise NotImplementedError("Unsupported Euclidean optimizer type") 187 | 188 | # Initialize Hyperbolic optimizer if there are Hyperbolic parameters 189 | if hyp_params: 190 | if hyp_optimizer_type == 'radam': 191 | optimizer_hyp = RiemannianAdam(hyp_params, lr=hyp_lr, stabilize=50, weight_decay=hyp_weight_decay) 192 | elif hyp_optimizer_type == 'rsgd': 193 | optimizer_hyp = RiemannianSGD(hyp_params, lr=hyp_lr, stabilize=50, weight_decay=hyp_weight_decay) 194 | else: 195 | raise NotImplementedError("Unsupported Hyperbolic optimizer type") 196 | 197 | # Store both optimizers 198 | self.optimizer = [optimizer_euc, optimizer_hyp] 199 | else: 200 | # Store only Euclidean optimizer if there are no Hyperbolic parameters 201 | self.optimizer = [optimizer_euc] 202 | 203 | def step(self): 204 | # Perform optimization step for each optimizer 205 | for optimizer in self.optimizer: 206 | optimizer.step() 207 | 208 | def zero_grad(self): 209 | # Reset gradients to zero for each optimizer 210 | for optimizer in self.optimizer: 211 | optimizer.zero_grad() 212 | 213 | 
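# Typical usage (sketch mirroring medium/main.py): opt = Optimizer(model, args); opt.zero_grad(); loss.backward(); opt.step().
# This steps the Euclidean optimizer and, when ManifoldParameters are present, the Riemannian one together.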
-------------------------------------------------------------------------------- /medium/manifolds/manifold_utils.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from typing import Tuple, Any, Union, List 3 | import functools 4 | import operator 5 | import torch 6 | import geoopt 7 | 8 | 9 | max_norm = 85 10 | eps = 1e-8 11 | 12 | __all__ = [ 13 | "copy_or_set_", 14 | "strip_tuple", 15 | "size2shape", 16 | "make_tuple", 17 | "broadcast_shapes", 18 | "ismanifold", 19 | "canonical_manifold", 20 | "list_range", 21 | "idx2sign", 22 | "drop_dims", 23 | "canonical_dims", 24 | "sign", 25 | "prod", 26 | "clamp_abs", 27 | "sabs", 28 | ] 29 | 30 | 31 | def copy_or_set_(dest: torch.Tensor, source: torch.Tensor) -> torch.Tensor: 32 | """ 33 | Copy or inplace set from :code:`source` to :code:`dest`. 34 | 35 | A workaround to respect strides of :code:`dest` when copying :code:`source`. 36 | The original issue was raised `here `_ 37 | when working with matrix manifolds. An in-place set operation is more efficient, 38 | but the resulting storage might be incompatible afterwards. To avoid the issue we fall back to 39 | the safe option and use :code:`copy_` if strides do not match. 40 | 41 | Parameters 42 | ---------- 43 | dest : torch.Tensor 44 | Destination tensor where to store new data 45 | source : torch.Tensor 46 | Source data to put in the new tensor 47 | 48 | Returns 49 | ------- 50 | dest 51 | torch.Tensor, modified inplace 52 | """ 53 | if dest.stride() != source.stride(): 54 | return dest.copy_(source) 55 | else: 56 | return dest.set_(source) 57 | 58 | 59 | def strip_tuple(tup: Tuple) -> Union[Tuple, Any]: 60 | if len(tup) == 1: 61 | return tup[0] 62 | else: 63 | return tup 64 | 65 | 66 | def make_tuple(obj: Union[Tuple, List, Any]) -> Tuple: 67 | if isinstance(obj, list): 68 | obj = tuple(obj) 69 | if not isinstance(obj, tuple): 70 | return (obj,) 71 | else: 72 | return obj 73 | 74 | 75 | def prod(items): 76 | return functools.reduce(operator.mul, items, 1) 77 | 78 | 79 | def sign(x): 80 | # Ensure sign is either +1 or -1, mapping zeros to +1. 81 | return torch.sign(x.sign() + 0.5) 82 | 83 | 84 | def sabs(x, eps: float = 1e-15): 85 | return x.abs().add_(eps) 86 | 87 | 88 | def clamp_abs(x, eps: float = 1e-15): 89 | s = sign(x) 90 | return s * sabs(x, eps=eps) 91 | 92 | 93 | def idx2sign(idx: int, dim: int, neg: bool = True): 94 | """ 95 | Unify idx to be negative or positive; this helps in cases of broadcasting. 96 | 97 | Parameters 98 | ---------- 99 | idx : int 100 | current index 101 | dim : int 102 | maximum dimension 103 | neg : bool 104 | indicate we need negative index 105 | 106 | Returns 107 | ------- 108 | int 109 | """ 110 | if neg: 111 | if idx < 0: 112 | return idx 113 | else: 114 | return (idx + 1) % -(dim + 1) 115 | else: 116 | return idx % dim 117 | 118 | 119 | def drop_dims(tensor: torch.Tensor, dims: List[int]): 120 | # Workaround to drop several dims in :func:`torch.squeeze`.
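# e.g. drop_dims(torch.zeros(2, 1, 3, 1), [1, 3]) returns shape (2, 3); the running 'seen' offset compensates for dimensions already squeezed out.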
121 | seen: int = 0 122 | for d in dims: 123 | tensor = tensor.squeeze(d - seen) 124 | seen += 1 125 | return tensor 126 | 127 | 128 | def list_range(end: int): 129 | res: List[int] = [] 130 | for d in range(end): 131 | res.append(d) 132 | return res 133 | 134 | 135 | def canonical_dims(dims: List[int], maxdim: int): 136 | result: List[int] = [] 137 | for idx in dims: 138 | result.append(idx2sign(idx, maxdim, neg=False)) 139 | return result 140 | 141 | 142 | def size2shape(*size: Union[Tuple[int], int]) -> Tuple[int]: 143 | return make_tuple(strip_tuple(size)) 144 | 145 | 146 | def broadcast_shapes(*shapes: Tuple[int]) -> Tuple[int]: 147 | """Apply numpy broadcasting rules to shapes.""" 148 | result = [] 149 | for dims in itertools.zip_longest(*map(reversed, shapes), fillvalue=1): 150 | dim: int = 1 151 | for d in dims: 152 | if dim != 1 and d != 1 and d != dim: 153 | raise ValueError("Shapes can't be broadcast") 154 | elif d > dim: 155 | dim = d 156 | result.append(dim) 157 | return tuple(reversed(result)) 158 | 159 | 160 | def ismanifold(instance, cls): 161 | """ 162 | Check if the interface of an instance is compatible with a given class. 163 | 164 | Parameters 165 | ---------- 166 | instance : geoopt.Manifold 167 | check if a given manifold is compatible with cls API 168 | cls : type 169 | manifold type 170 | 171 | Returns 172 | ------- 173 | bool 174 | comparison result 175 | """ 176 | if not issubclass(cls, geoopt.manifolds.Manifold): 177 | raise TypeError( 178 | "`cls` should be a subclass of geoopt.manifolds.Manifold") 179 | if not isinstance(instance, geoopt.manifolds.Manifold): 180 | return False 181 | else: 182 | # this is the case to care about, Scaled class is a proxy, but fails instance checks 183 | while isinstance(instance, geoopt.Scaled): 184 | instance = instance.base 185 | return isinstance(instance, cls) 186 | 187 | 188 | def canonical_manifold(manifold: "geoopt.Manifold"): 189 | """ 190 | Get a canonical manifold. 191 | 192 | If a manifold is wrapped with Scaled, some attributes may not be available. This should help if you really need them. 193 | 194 | Parameters 195 | ---------- 196 | manifold : geoopt.Manifold 197 | 198 | Returns 199 | ------- 200 | geoopt.Manifold 201 | an unwrapped manifold 202 | """ 203 | while isinstance(manifold, geoopt.Scaled): 204 | manifold = manifold.base 205 | return manifold 206 | 207 | 208 | def cosh(x: torch.Tensor) -> torch.Tensor: 209 | x = clamp(x, min=-max_norm, max=max_norm) 210 | return torch.cosh(x) 211 | 212 | 213 | def sinh(x: torch.Tensor) -> torch.Tensor: 214 | x = clamp(x, min=-max_norm, max=max_norm) 215 | return torch.sinh(x) 216 | 217 | 
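# cosh/sinh clamp their inputs to |x| <= max_norm (85) because cosh(x) ~ e^|x| / 2 overflows float32 not far beyond that range.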
218 | def sqrt(x: torch.Tensor) -> torch.Tensor: 219 | x = clamp(x, min=1e-9) # Smaller epsilon due to precision around x=0. 220 | return torch.sqrt(x) 221 | 222 | 223 | class LeakyClamp(torch.autograd.Function): 224 | 225 | @staticmethod 226 | def forward(ctx: Any, x: torch.Tensor, min: float, max: float) -> torch.Tensor: 227 | with torch.no_grad(): 228 | ctx.save_for_backward(x.ge(min) & x.le(max)) 229 | return torch.clamp(x, min=min, max=max) 230 | 231 | @staticmethod 232 | def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]: 233 | mask, = ctx.saved_tensors 234 | mask = mask.type_as(grad_output) 235 | return grad_output * mask + grad_output * (1 - mask) * eps, None, None 236 | 237 | 238 | def clamp(x: torch.Tensor, min: float = float("-inf"), max: float = float("+inf")) -> torch.Tensor: 239 | return LeakyClamp.apply(x, min, max) 240 | 241 | 242 | class Atanh(torch.autograd.Function): 243 | """ 244 | Numerically stable arctanh that never returns NaNs. 245 | Clamps x to [-1 + 4*eps, 1 - 4*eps] before evaluating. 246 | Returns atanh(x) = arctanh(x) = 0.5*(log(1+x)-log(1-x)). 247 | """ 248 | 249 | @staticmethod 250 | def forward(ctx: Any, x: torch.Tensor) -> torch.Tensor: 251 | x = clamp(x, min=-1. + 4 * eps, max=1. - 4 * eps) 252 | ctx.save_for_backward(x) 253 | res = (torch.log_(1 + x).sub_(torch.log_(1 - x))).mul_(0.5) 254 | return res 255 | 256 | @staticmethod 257 | def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor: 258 | x, = ctx.saved_tensors 259 | return grad_output / (1 - x**2) 260 | 261 | 262 | def atanh(x: torch.Tensor) -> torch.Tensor: 263 | """ 264 | Numerically stable arctanh that never returns NaNs. 265 | 266 | :param x: The input tensor. 267 | :return: 0.5 * (log(1 + x) - log(1 - x)) 268 | """ 269 | return Atanh.apply(x) 270 | 271 | 272 | class Acosh(torch.autograd.Function): 273 | """ 274 | Numerically stable arccosh that never returns NaNs. 275 | Returns acosh(x) = arccosh(x) = log(x + sqrt(max(x^2 - 1, eps))). 276 | """ 277 | 278 | @staticmethod 279 | def forward(ctx: Any, x: torch.Tensor) -> torch.Tensor: 280 | with torch.no_grad(): 281 | x = clamp(x, min=1 + eps) 282 | z = sqrt(x * x - 1.) 283 | ctx.save_for_backward(z) 284 | return torch.log(x + z) 285 | 286 | @staticmethod 287 | def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor: 288 | z, = ctx.saved_tensors 289 | # z_ = clamp(z, min=eps) 290 | z_ = z 291 | return grad_output / z_ 292 | 293 | 
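# The backward pass uses d/dx arccosh(x) = 1 / sqrt(x^2 - 1) = 1 / z, reusing z from the forward pass; the commented-out clamp would guard against z == 0 exactly at x == 1.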
294 | def acosh(x: torch.Tensor) -> torch.Tensor: 295 | """ 296 | Numerically stable arccosh that never returns NaNs. 297 | 298 | :param x: The input tensor. 299 | :return: log(x + sqrt(max(x^2 - 1, eps))) 300 | """ 301 | return Acosh.apply(x) -------------------------------------------------------------------------------- /medium/parse.py: -------------------------------------------------------------------------------- 1 | from hypformer import HypFormer 2 | 3 | 4 | def parse_method(args, device): 5 | if args.method == 'hypformer': 6 | model = HypFormer(args=args).to(device) 7 | else: 8 | raise ValueError(f'Invalid method {args.method}') 9 | return model 10 | 11 | 12 | def parser_add_main_args(parser): 13 | # dataset and evaluation 14 | parser.add_argument('--data_dir', type=str, default='../data', help='location of the data') 15 | parser.add_argument('--dataset', type=str, default='cora', help='name of dataset') 16 | parser.add_argument('--sub_dataset', type=str, default='gcn_data', help='name of sub dataset') 17 | parser.add_argument('--device', type=int, default=0, help='which gpu to use if any (default: 0)') 18 | parser.add_argument('--seed', type=int, default=42) 19 | parser.add_argument('--cpu', type=int, default=0, help='use CPU instead of GPU') 20 | parser.add_argument('--epochs', type=int, default=500) 21 | parser.add_argument('--runs', type=int, default=1, help='number of distinct runs') 22 | parser.add_argument('--train_prop', type=float, default=.5, help='training label proportion') 23 | parser.add_argument('--valid_prop', type=float, default=.25, help='validation label proportion') 24 | parser.add_argument('--protocol', type=str, default='semi', help='evaluation protocol for the citation datasets: semi or supervised') 25 | parser.add_argument('--rand_split', type=int, default=0, help='use random splits') 26 | parser.add_argument('--rand_split_class', type=int, default=0, 27 | help='use random splits with a fixed number of labeled nodes for each class') 28 | parser.add_argument('--label_num_per_class', type=int, default=20, 29 | help='labeled nodes per class (randomly selected)') 30 | parser.add_argument('--valid_num', type=int, default=500, help='total number of validation nodes') 31 | parser.add_argument('--test_num', type=int, default=500, help='total number of test nodes') 32 | parser.add_argument('--no_feat_norm', type=int, default=1) 33 | 34 | # display and utility 35 | parser.add_argument('--display_step', type=int, default=50, help='how often to print') 36 | 37 | # model 38 | parser.add_argument('--method', type=str, default='hypformer', help='model to train; parse_method above only implements hypformer') 39 | parser.add_argument('--hidden_channels', type=int, default=32) 40 | 41 | # gnn branch 42 | parser.add_argument('--use_graph', type=int, default=1, help='use graph encoder or not') 43 | parser.add_argument('--graph_weight', type=float, default=0.5, help='weight for graph encoder') 44 | parser.add_argument('--gnn_use_bn', type=int, default=1, help='use batchnorm for each GNN layer or not') 45 | parser.add_argument('--gnn_use_residual', type=int, default=1, help='use residual link for each GNN layer or not') 46 | parser.add_argument('--gnn_use_weight', type=int, default=0, help='use weight for GNN convolution') 47 | parser.add_argument('--gnn_use_init', type=int, default=0, help='use initial feat for each GNN layer or not') 48 | parser.add_argument('--gnn_use_act', type=int, default=1, help='use activation for each GNN layer or not') 49 | parser.add_argument('--gnn_num_layers', type=int, default=2, help='number of layers for GNN') 50 | parser.add_argument('--gnn_dropout', type=float, default=0.5) 51 | parser.add_argument('--knn_num', type=int, default=5, help='number of k for KNN graph') 52 | 
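# For the non-graph datasets ('mini', '20news'), medium/main.py builds a kNN graph over node features with sklearn's kneighbors_graph using --knn_num neighbors before training.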
53 | # attention (Transformer) branch 54 | parser.add_argument('--trans_num_heads', type=int, default=1, help='number of attention heads') 55 | parser.add_argument('--trans_heads_concat', type=int, default=0, help='concatenate multi-head attentions or not') 56 | parser.add_argument('--trans_use_weight', type=int, default=1, help='use weight for transformer convolution or not') 57 | parser.add_argument('--trans_use_bn', type=int, default=0, help='use batchnorm for each transformer layer or not') 58 | parser.add_argument('--trans_use_residual', type=int, default=0, 59 | help='use residual link for each transformer layer or not') 60 | parser.add_argument('--trans_use_act', type=int, default=0, help='use activation for each transformer layer or not') 61 | parser.add_argument('--trans_num_layers', type=int, default=1, help='number of layers for all-pair attention') 62 | 63 | parser.add_argument('--trans_dropout', type=float, default=0.0, help='transformer dropout') 64 | parser.add_argument('--k_in', type=float, default=1.0, help='manifold_in curvature') 65 | parser.add_argument('--k_hidden', type=float, default=1.0, help='manifold_hidden curvature') 66 | parser.add_argument('--k_out', type=float, default=1.0, help='manifold_out curvature') 67 | parser.add_argument('--power_k', type=float, default=2.0, help='power k for query and key') 68 | parser.add_argument('--attention_type', type=str, default='linear_focused', 69 | help='attention type: linear_focused or full') 70 | parser.add_argument('--add_positional_encoding', type=int, default=1, help='add positional encoding or not') 71 | parser.add_argument('--output_attention', type=int, default=0, help='output attention or not') 72 | 73 | # training 74 | parser.add_argument('--patience', type=int, default=200, help='early stopping patience') 75 | parser.add_argument('--lr', type=float, default=0.01) 76 | parser.add_argument('--hyp_lr', type=float, default=0.01) 77 | 78 | parser.add_argument('--optimizer_type', type=str, default='adam', choices=['adam', 'sgd']) 79 | parser.add_argument('--hyp_optimizer_type', type=str, default='radam', choices=['radam', 'rsgd']) 80 | parser.add_argument('--weight_decay', type=float, default=0.005) 81 | parser.add_argument('--hyp_weight_decay', type=float, default=0.005) 82 | 83 | parser.add_argument('--decoder_type', type=str, default='euc', help='decoder type: euc or hyp') 84 | 85 | parser.add_argument('--use_wandb', type=int, default=0, help='use wandb for logging') 86 | parser.add_argument('--wandb_name', type=str, default='0', help='wandb run name') 87 | parser.add_argument('--run_id', type=str, default='0', help='Run ID (default: 0)') 88 | parser.add_argument('--save_result', type=int, default=0, help='save whole test result') 89 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | googledrivedownloader==0.4 2 | networkx==3.2.1 3 | numpy==1.26.4 4 | ogb==1.3.6 5 | scikit-learn==1.4.2 6 | scipy==1.12.0 7 | geoopt==0.5.0 8 | torch==2.2.1 9 | torch_cluster==1.6.3 10 | torch_geometric==2.5.3 11 | torch_scatter==2.1.2 12 | torch_sparse==0.6.18 13 | torch_spline_conv==1.2.2 14 | wandb 15 | tqdm 16 | --------------------------------------------------------------------------------