├── spexphormer ├── loader │ ├── dataset │ │ ├── __init__.py │ │ ├── Amazon_with_split.py │ │ ├── HeterophilousGraphDataset.py │ │ ├── Pokec.py │ │ ├── malnet_tiny.py │ │ ├── peptides_functional.py │ │ ├── aqsol_molecules.py │ │ ├── peptides_structural.py │ │ ├── voc_superpixels.py │ │ └── coco_superpixels.py │ ├── __init__.py │ ├── planetoid.py │ ├── heterogeneous_datasets.py │ └── split_generator.py ├── config │ ├── __init__.py │ ├── dataset_config.py │ ├── wandb_config.py │ ├── split_config.py │ ├── optimizers_config.py │ ├── data_preprocess_config.py │ ├── gt_config.py │ ├── defaults_config.py │ └── posenc_config.py ├── head │ ├── __init__.py │ ├── example.py │ └── inductive_node.py ├── layer │ ├── __init__.py │ ├── Spexphormer_Attention.py │ ├── Exphormer_Attention.py │ ├── ASE_Attention.py │ └── SpExphormer_full_layer.py ├── loss │ ├── __init__.py │ ├── l1.py │ ├── multilabel_classification_loss.py │ └── weighted_cross_entropy.py ├── train │ ├── __init__.py │ ├── neighbor_sampler.py │ └── custom_train.py ├── encoder │ ├── __init__.py │ ├── linear_node_encoder.py │ └── type_dict_encoder.py ├── network │ ├── __init__.py │ └── SpExphormer_model.py ├── optimizer │ ├── __init__.py │ └── extra_optimizers.py ├── transform │ ├── __init__.py │ ├── transforms.py │ └── expander_edges.py ├── __init__.py ├── utils.py └── metrics_ogb.py ├── Spexphormer.png ├── .gitattributes ├── LICENSE ├── configs ├── heterophilic │ ├── actor │ │ ├── actor_exphormer.yaml │ │ ├── actor_ASE.yaml │ │ └── actor_spexphormer.yaml │ ├── tolokers │ │ ├── tolokers_spexphormer.yaml │ │ └── tolokers_ASE.yaml │ └── minesweeper │ │ ├── minesweeper_spexphormer.yaml │ │ └── minesweeper_ASE.yaml ├── large │ ├── pokec │ │ ├── pokec_ASE.yaml │ │ └── pokec_spexphormer.yaml │ ├── proteins │ │ ├── proteins_ASE.yaml │ │ └── proteins_spexphormer.yaml │ └── amazon2m │ │ ├── amazon2m_spexphormer.yaml │ │ └── amazon2m_ASE.yaml └── homophilic │ ├── CS │ ├── CS_ASE.yaml │ └── CS_spexphormer.yaml │ ├── photo │ ├── photo_ASE.yaml │ └── photo_spexphormer.yaml │ ├── physics │ ├── physics_ASE.yaml │ └── physics_spexphormer.yaml │ └── computer │ ├── computer_ASE.yaml │ └── computer_spexphormer.yaml ├── README.md ├── .gitignore └── main.py /spexphormer/loader/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Spexphormer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hamed1375/Sp_Exphormer/HEAD/Spexphormer.png -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Exclude notebook from language statistics 2 | *.ipynb linguist-documentation 3 | -------------------------------------------------------------------------------- /spexphormer/config/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /spexphormer/head/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, 
isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /spexphormer/layer/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /spexphormer/loader/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /spexphormer/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /spexphormer/train/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /spexphormer/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /spexphormer/network/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /spexphormer/optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /spexphormer/transform/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, 
isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /spexphormer/__init__.py: -------------------------------------------------------------------------------- 1 | # from .act import * # noqa 2 | from .config import * # noqa 3 | from .encoder import * # noqa 4 | from .head import * # noqa 5 | from .layer import * # noqa 6 | from .loader import * # noqa 7 | from .loss import * # noqa 8 | from .network import * # noqa 9 | from .optimizer import * # noqa 10 | # from .pooling import * # noqa 11 | # from .stage import * # noqa 12 | from .train import * # noqa 13 | from .transform import * # noqa -------------------------------------------------------------------------------- /spexphormer/encoder/linear_node_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.graphgym import cfg 3 | from torch_geometric.graphgym.register import register_node_encoder 4 | 5 | 6 | @register_node_encoder('LinearNode') 7 | class LinearNodeEncoder(torch.nn.Module): 8 | def __init__(self, emb_dim): 9 | super().__init__() 10 | 11 | self.encoder = torch.nn.Linear(cfg.share.dim_in, emb_dim) 12 | 13 | def forward(self, batch): 14 | batch.x = self.encoder(batch.x) 15 | return batch -------------------------------------------------------------------------------- /spexphormer/loss/l1.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch_geometric.graphgym.config import cfg 3 | from torch_geometric.graphgym.register import register_loss 4 | 5 | 6 | @register_loss('l1_losses') 7 | def l1_losses(pred, true): 8 | if cfg.model.loss_fun == 'l1': 9 | l1_loss = nn.L1Loss() 10 | loss = l1_loss(pred, true) 11 | return loss, pred 12 | elif cfg.model.loss_fun == 'smoothl1': 13 | l1_loss = nn.SmoothL1Loss() 14 | loss = l1_loss(pred, true) 15 | return loss, pred 16 | -------------------------------------------------------------------------------- /spexphormer/config/dataset_config.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.graphgym.register import register_config 2 | 3 | 4 | @register_config('dataset_cfg') 5 | def dataset_cfg(cfg): 6 | """Dataset-specific config options. 7 | """ 8 | 9 | # The number of node types to expect in TypeDictNodeEncoder. 10 | cfg.dataset.node_encoder_num_types = 0 11 | 12 | # The number of edge types to expect in TypeDictEdgeEncoder. 13 | cfg.dataset.edge_encoder_num_types = 0 14 | 15 | # VOC/COCO Superpixels dataset version based on SLIC compactness parameter. 16 | cfg.dataset.slic_compactness = 10 17 | -------------------------------------------------------------------------------- /spexphormer/config/wandb_config.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.graphgym.register import register_config 2 | from yacs.config import CfgNode as CN 3 | 4 | 5 | @register_config('cfg_wandb') 6 | def set_cfg_wandb(cfg): 7 | """Weights & Biases tracker configuration. 
8 | """ 9 | 10 | # WandB group 11 | cfg.wandb = CN() 12 | 13 | # Use wandb or not 14 | cfg.wandb.use = False 15 | 16 | # Wandb entity name, should exist beforehand 17 | cfg.wandb.entity = "gtransformers" 18 | 19 | # Wandb project name, will be created in your team if doesn't exist already 20 | cfg.wandb.project = "gtblueprint" 21 | 22 | # Optional run name 23 | cfg.wandb.name = "" 24 | -------------------------------------------------------------------------------- /spexphormer/loss/multilabel_classification_loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch_geometric.graphgym.config import cfg 3 | from torch_geometric.graphgym.register import register_loss 4 | 5 | 6 | @register_loss('multilabel_cross_entropy') 7 | def multilabel_cross_entropy(pred, true): 8 | """Multilabel cross-entropy loss. 9 | """ 10 | if cfg.dataset.task_type == 'classification_multilabel': 11 | if cfg.model.loss_fun != 'cross_entropy': 12 | raise ValueError("Only 'cross_entropy' loss_fun supported with " 13 | "'classification_multilabel' task_type.") 14 | bce_loss = nn.BCEWithLogitsLoss() 15 | is_labeled = true == true # Filter our nans. 16 | return bce_loss(pred[is_labeled], true[is_labeled].float()), pred 17 | -------------------------------------------------------------------------------- /spexphormer/config/split_config.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.graphgym.register import register_config 2 | 3 | 4 | @register_config('split') 5 | def set_cfg_split(cfg): 6 | """Reconfigure the default config value for dataset split options. 7 | 8 | Returns: 9 | Reconfigured split configuration use by the experiment. 10 | """ 11 | 12 | # Default to selecting the standard split that ships with the dataset 13 | cfg.dataset.split_mode = 'standard' 14 | 15 | # Choose a particular split to use if multiple splits are available 16 | cfg.dataset.split_index = 0 17 | 18 | # Dir to cache cross-validation splits 19 | cfg.dataset.split_dir = './splits' 20 | 21 | # Choose to run multiple splits in one program execution, if set, 22 | # takes the precedence over cfg.dataset.split_index for split selection 23 | cfg.run_multiple_splits = [] 24 | -------------------------------------------------------------------------------- /spexphormer/head/example.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from torch_geometric.graphgym.register import register_head 4 | 5 | 6 | @register_head('head') 7 | class ExampleNodeHead(nn.Module): 8 | '''Head of GNN, node prediction''' 9 | def __init__(self, dim_in, dim_out): 10 | super().__init__() 11 | self.layer_post_mp = nn.Linear(dim_in, dim_out, bias=True) 12 | 13 | def _apply_index(self, batch): 14 | if batch.node_label_index.shape[0] == batch.node_label.shape[0]: 15 | return batch.x[batch.node_label_index], batch.node_label 16 | else: 17 | return batch.x[batch.node_label_index], \ 18 | batch.node_label[batch.node_label_index] 19 | 20 | def forward(self, batch): 21 | batch = self.layer_post_mp(batch) 22 | pred, label = self._apply_index(batch) 23 | return pred, label 24 | -------------------------------------------------------------------------------- /spexphormer/config/optimizers_config.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.graphgym.register import register_config 2 | 3 | 4 | @register_config('extended_optim') 
5 | def extended_optim_cfg(cfg): 6 | """Extend optimizer config group that is first set by GraphGym in 7 | torch_geometric.graphgym.config.set_cfg 8 | """ 9 | 10 | # Number of batches to accumulate gradients over before updating parameters 11 | # Requires `custom` training loop, set `train.mode: custom` 12 | cfg.optim.batch_accumulation = 1 13 | 14 | # ReduceLROnPlateau: Factor by which the learning rate will be reduced 15 | cfg.optim.reduce_factor = 0.1 16 | 17 | # ReduceLROnPlateau: #epochs without improvement after which LR gets reduced 18 | cfg.optim.schedule_patience = 10 19 | 20 | # ReduceLROnPlateau: Lower bound on the learning rate 21 | cfg.optim.min_lr = 0.0 22 | 23 | # For schedulers with warm-up phase, set the warm-up number of epochs 24 | cfg.optim.num_warmup_epochs = 50 25 | 26 | # Clip gradient norms while training 27 | cfg.optim.clip_grad_norm = False 28 | -------------------------------------------------------------------------------- /spexphormer/head/inductive_node.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch_geometric.graphgym.config import cfg 3 | from torch_geometric.graphgym.models.layer import new_layer_config, MLP 4 | from torch_geometric.graphgym.register import register_head 5 | 6 | 7 | @register_head('inductive_node') 8 | class GNNInductiveNodeHead(nn.Module): 9 | """ 10 | GNN prediction head for inductive node prediction tasks. 11 | 12 | Args: 13 | dim_in (int): Input dimension 14 | dim_out (int): Output dimension. For binary prediction, dim_out=1. 15 | """ 16 | 17 | def __init__(self, dim_in, dim_out): 18 | super(GNNInductiveNodeHead, self).__init__() 19 | self.layer_post_mp = MLP( 20 | new_layer_config(dim_in, dim_out, cfg.gnn.layers_post_mp, 21 | has_act=False, has_bias=True, cfg=cfg)) 22 | 23 | def _apply_index(self, batch): 24 | return batch.x, batch.y 25 | 26 | def forward(self, batch): 27 | batch = self.layer_post_mp(batch) 28 | pred, label = self._apply_index(batch) 29 | return pred, label -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 
20 | -------------------------------------------------------------------------------- /spexphormer/config/data_preprocess_config.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.graphgym.register import register_config 2 | from yacs.config import CfgNode as CN 3 | 4 | 5 | def set_cfg_preprocess(cfg): 6 | """Extend configuration with preprocessing options. 7 | """ 8 | 9 | cfg.prep = CN() 10 | 11 | # Argument group for adding expander edges 12 | 13 | # If enabled, the expander edges will be available, e.g., as data.expander_edges 14 | cfg.prep.exp = True 15 | cfg.prep.exp_deg = 5 16 | cfg.prep.exp_algorithm = 'Hamiltonian' # options are 'Hamiltonian', 'Random-d', 'Random-d2' 17 | 18 | cfg.prep.use_exp_edges = True 19 | cfg.prep.replace_combined_exp_edges = False 20 | cfg.prep.exp_max_num_iters = 100 21 | cfg.prep.add_edge_index = True 22 | cfg.prep.num_virt_node = 0 23 | 24 | cfg.prep.add_self_loops = False 25 | cfg.prep.add_reverse_edges = True 26 | cfg.prep.layer_edge_indices_dir = None 27 | cfg.prep.save_edges = False 28 | cfg.prep.load_edges = False 29 | cfg.prep.num_edge_sets = 1 30 | cfg.prep.edge_set_name = None 31 | 32 | cfg.prep.default_initial = False 33 | 34 | 35 | register_config('preprocess', set_cfg_preprocess) 36 | -------------------------------------------------------------------------------- /configs/heterophilic/actor/actor_exphormer.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | metric_best: accuracy 3 | # gpu_mem: True 4 | wandb: 5 | use: True 6 | project: actor 7 | entity: expand-gnns 8 | dataset: 9 | format: PyG-actor 10 | name: actor 11 | task: node 12 | task_type: classification 13 | transductive: True 14 | node_encoder: True 15 | node_encoder_name: LinearNode 16 | node_encoder_bn: False 17 | edge_encoder: True 18 | edge_encoder_num_types: 3 19 | edge_encoder_name: TypeDictEdge2 20 | edge_encoder_bn: False 21 | split_mode: 'standard' 22 | prep: 23 | exp: True 24 | exp_deg: 15 25 | replace_combined_exp_edges: True 26 | train: 27 | mode: custom_train 28 | ckpt_best: True 29 | eval_period: 1 30 | ckpt_period: 100 31 | model: 32 | type: Exphormer 33 | loss_fun: cross_entropy 34 | gt: 35 | layer_type: Exphormer 36 | layers: 4 37 | n_heads: 2 38 | dim_hidden: 64 39 | dropout: 0.5 40 | layer_norm: False 41 | batch_norm: True 42 | gnn: 43 | head: default 44 | layers_pre_mp: 1 45 | layers_post_mp: 1 46 | optim: 47 | clip_grad_norm: True 48 | optimizer: adamW 49 | weight_decay: 1e-3 50 | base_lr: 0.01 51 | max_epoch: 50 52 | scheduler: cosine_with_warmup 53 | num_warmup_epochs: 5 54 | -------------------------------------------------------------------------------- /configs/heterophilic/tolokers/tolokers_spexphormer.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | metric_best: auc 3 | gpu_mem: True 4 | wandb: 5 | use: True 6 | project: Tolokers 7 | entity: expand-gnns 8 | dataset: 9 | format: PyG-Heterophilous 10 | name: Tolokers 11 | task: node 12 | task_type: classification 13 | transductive: True 14 | node_encoder: True 15 | node_encoder_name: LinearNode 16 | node_encoder_bn: False 17 | edge_encoder: True 18 | edge_encoder_num_types: 3 19 | edge_encoder_name: TypeDictEdge2 20 | edge_encoder_bn: False 21 | split_mode: 'standard' 22 | prep: 23 | exp: False 24 | train: 25 | mode: custom_with_sampling 26 | batch_size: 100000 27 | eval_period: 1 28 | ckpt_period: 100 29 |
edge_sample_num_neighbors: [12, 10, 10, 10] 30 | model: 31 | type: Exphormer 32 | loss_fun: cross_entropy 33 | gt: 34 | layer_type: Spexphormer 35 | layers: 4 36 | n_heads: 4 37 | dim_hidden: 32 38 | dropout: 0.25 39 | layer_norm: False 40 | batch_norm: True 41 | gnn: 42 | head: inductive_node 43 | layers_pre_mp: 1 44 | layers_post_mp: 1 45 | optim: 46 | clip_grad_norm: True 47 | optimizer: adamW 48 | weight_decay: 1e-3 49 | base_lr: 0.01 50 | max_epoch: 200 51 | scheduler: cosine_with_warmup 52 | num_warmup_epochs: 5 -------------------------------------------------------------------------------- /spexphormer/config/gt_config.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.graphgym.register import register_config 2 | from yacs.config import CfgNode as CN 3 | 4 | 5 | @register_config('cfg_gt') 6 | def set_cfg_gt(cfg): 7 | """Configuration for the Exphormer attention layer. 8 | """ 9 | 10 | # Graph Transformer argument group 11 | cfg.gt = CN() 12 | 13 | # Type of Graph Transformer layer to use 14 | cfg.gt.layer_type = 'ExphormerInitial' 15 | 16 | # Number of Transformer layers in the model 17 | cfg.gt.layers = 3 18 | 19 | # Number of attention heads in the Graph Transformer 20 | cfg.gt.n_heads = 8 21 | 22 | # Size of the hidden node and edge representation 23 | cfg.gt.dim_hidden = 64 24 | 25 | # Size of the edge embedding 26 | cfg.gt.dim_edge = None 27 | 28 | # Dropout in feed-forward module. 29 | cfg.gt.dropout = 0.0 30 | 31 | cfg.gt.layer_norm = False 32 | 33 | cfg.gt.batch_norm = True 34 | 35 | cfg.gt.residual = True 36 | 37 | cfg.gt.activation = 'relu' 38 | 39 | cfg.gt.use_edge_feats = True 40 | 41 | # Feed forward network after the Attention layer 42 | cfg.gt.FFN = True 43 | 44 | # Jumping knowledge: concatenate the outputs of all layers to make the final prediction 45 | cfg.gt.JK = False -------------------------------------------------------------------------------- /configs/large/pokec/pokec_ASE.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | metric_best: accuracy 3 | gpu_mem: True 4 | wandb: 5 | use: True 6 | project: pokec 7 | entity: expand-gnns 8 | dataset: 9 | format: SNAP 10 | name: pokec 11 | task: node 12 | task_type: classification 13 | transductive: True 14 | node_encoder: True 15 | node_encoder_name: LinearNode 16 | node_encoder_bn: False 17 | edge_encoder: True 18 | edge_encoder_num_types: 3 19 | edge_encoder_name: TypeDictEdge2 20 | edge_encoder_bn: False 21 | split_mode: 'standard' 22 | prep: 23 | exp: True 24 | exp_deg: 5 25 | train: 26 | mode: custom_train 27 | ckpt_best: True 28 | save_attention_scores: True 29 | eval_period: 5 30 | ckpt_period: 100 31 | temp_rdc_ratio: 0.95 32 | temp_min: 0.05 33 | temp_wait: 10 34 | model: 35 | type: Exphormer 36 | loss_fun: cross_entropy 37 | gt: 38 | layer_type: ASE 39 | layers: 2 40 | n_heads: 1 41 | dim_hidden: 8 42 | dropout: 0.0 43 | layer_norm: True 44 | batch_norm: False 45 | gnn: 46 | head: default 47 | layers_pre_mp: 1 48 | layers_post_mp: 1 49 | optim: 50 | clip_grad_norm: True 51 | optimizer: adamW 52 | weight_decay: 1e-3 53 | base_lr: 0.01 54 | max_epoch: 150 55 | scheduler: cosine_with_warmup 56 | num_warmup_epochs: 5 -------------------------------------------------------------------------------- /configs/heterophilic/actor/actor_ASE.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | metric_best: accuracy 3 | gpu_mem: True 4 | wandb: 5 |
use: True 6 | project: actor 7 | entity: expand-gnns 8 | dataset: 9 | format: PyG-actor 10 | name: actor 11 | task: node 12 | task_type: classification 13 | transductive: True 14 | node_encoder: True 15 | node_encoder_name: LinearNode 16 | node_encoder_bn: False 17 | edge_encoder: True 18 | edge_encoder_num_types: 3 19 | edge_encoder_name: TypeDictEdge2 20 | edge_encoder_bn: False 21 | split_mode: 'standard' 22 | prep: 23 | exp: True 24 | exp_deg: 15 25 | train: 26 | mode: custom_train 27 | ckpt_best: True 28 | save_attention_scores: True 29 | eval_period: 1 30 | ckpt_period: 100 31 | temp_rdc_ratio: 0.99 32 | temp_min: 0.05 33 | temp_wait: 20 34 | model: 35 | type: Exphormer 36 | loss_fun: cross_entropy 37 | gt: 38 | layer_type: ASE 39 | layers: 4 40 | n_heads: 1 41 | dim_hidden: 4 42 | dropout: 0.0 43 | layer_norm: True 44 | batch_norm: False 45 | gnn: 46 | head: default 47 | layers_pre_mp: 1 48 | layers_post_mp: 1 49 | optim: 50 | clip_grad_norm: True 51 | optimizer: adamW 52 | weight_decay: 1e-3 53 | base_lr: 0.01 54 | max_epoch: 100 55 | scheduler: cosine_with_warmup 56 | num_warmup_epochs: 5 57 | -------------------------------------------------------------------------------- /configs/homophilic/CS/CS_ASE.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | metric_best: accuracy 3 | gpu_mem: True 4 | wandb: 5 | use: True 6 | project: cs3 7 | entity: expand-gnns 8 | dataset: 9 | format: PyG-Coauthor 10 | name: cs 11 | task: node 12 | task_type: classification 13 | transductive: True 14 | node_encoder: True 15 | node_encoder_name: LinearNode 16 | node_encoder_bn: False 17 | edge_encoder: True 18 | edge_encoder_num_types: 3 19 | edge_encoder_name: TypeDictEdge2 20 | edge_encoder_bn: False 21 | split_mode: 'random' 22 | split: [0.6, 0.2, 0.2] 23 | prep: 24 | exp: True 25 | exp_deg: 15 26 | train: 27 | mode: custom_train 28 | ckpt_best: True 29 | save_attention_scores: True 30 | eval_period: 1 31 | ckpt_period: 100 32 | temp_rdc_ratio: 0.99 33 | temp_min: 0.05 34 | temp_wait: 5 35 | model: 36 | type: Exphormer 37 | loss_fun: cross_entropy 38 | gt: 39 | layer_type: ASE 40 | layers: 4 41 | n_heads: 1 42 | dim_hidden: 4 43 | dropout: 0.0 44 | layer_norm: True 45 | batch_norm: False 46 | gnn: 47 | head: default 48 | layers_pre_mp: 1 49 | layers_post_mp: 1 50 | optim: 51 | clip_grad_norm: True 52 | optimizer: adamW 53 | weight_decay: 1e-3 54 | base_lr: 0.01 55 | max_epoch: 200 56 | scheduler: cosine_with_warmup 57 | num_warmup_epochs: 5 -------------------------------------------------------------------------------- /configs/homophilic/photo/photo_ASE.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | metric_best: accuracy 3 | wandb: 4 | use: True 5 | project: photo3 6 | entity: expand-gnns 7 | dataset: 8 | format: PyG-Amazon 9 | name: photo 10 | task: node 11 | task_type: classification 12 | transductive: True 13 | node_encoder: True 14 | node_encoder_name: LinearNode 15 | node_encoder_bn: False 16 | edge_encoder: True 17 | edge_encoder_num_types: 3 18 | edge_encoder_name: TypeDictEdge 19 | edge_encoder_bn: False 20 | split_mode: 'standard' 21 | prep: 22 | exp: True 23 | exp_deg: 15 24 | train: 25 | mode: custom_train 26 | ckpt_best: True 27 | save_attention_scores: True 28 | eval_period: 1 29 | ckpt_period: 100 30 | temp_rdc_ratio: 0.95 31 | temp_min: 0.05 32 | temp_wait: 20 33 | sample_new_edges: False 34 | model: 35 | type: Exphormer 36 | loss_fun: cross_entropy 37 | gt: 
38 | layer_type: ASE 39 | layers: 4 40 | n_heads: 1 41 | dim_hidden: 4 42 | dropout: 0.0 43 | layer_norm: True 44 | batch_norm: False 45 | gnn: 46 | head: default 47 | layers_pre_mp: 1 48 | layers_post_mp: 1 49 | optim: 50 | clip_grad_norm: True 51 | optimizer: adamW 52 | weight_decay: 1e-3 53 | base_lr: 0.1 54 | max_epoch: 100 55 | scheduler: cosine_with_warmup 56 | num_warmup_epochs: 5 57 | -------------------------------------------------------------------------------- /configs/heterophilic/minesweeper/minesweeper_spexphormer.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | metric_best: auc 3 | gpu_mem: True 4 | wandb: 5 | use: True 6 | project: Minesweeper 7 | entity: expand-gnns 8 | dataset: 9 | format: PyG-Heterophilous 10 | name: Minesweeper 11 | task: node 12 | task_type: classification 13 | transductive: True 14 | node_encoder: True 15 | node_encoder_name: LinearNode 16 | node_encoder_bn: False 17 | edge_encoder: True 18 | edge_encoder_num_types: 3 19 | edge_encoder_name: TypeDictEdge2 20 | edge_encoder_bn: False 21 | split_mode: 'standard' 22 | prep: 23 | exp: False 24 | train: 25 | mode: custom_with_sampling 26 | batch_size: 100000 27 | eval_period: 1 28 | ckpt_period: 100 29 | edge_sample_num_neighbors: [12, 5, 5, 5] 30 | model: 31 | type: Exphormer 32 | loss_fun: cross_entropy 33 | gt: 34 | layer_type: Spexphormer 35 | layers: 4 36 | n_heads: 4 37 | dim_hidden: 32 # `gt.dim_hidden` must match `gnn.dim_inner` 38 | dropout: 0.2 39 | layer_norm: False 40 | batch_norm: True 41 | gnn: 42 | head: inductive_node 43 | layers_pre_mp: 0 44 | layers_post_mp: 1 45 | optim: 46 | clip_grad_norm: True 47 | optimizer: adamW 48 | weight_decay: 1e-3 49 | base_lr: 0.01 50 | max_epoch: 80 51 | scheduler: cosine_with_warmup 52 | num_warmup_epochs: 5 -------------------------------------------------------------------------------- /configs/homophilic/physics/physics_ASE.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | metric_best: accuracy 3 | gpu_mem: True 4 | wandb: 5 | use: True 6 | project: physics3 7 | entity: expand-gnns 8 | dataset: 9 | format: PyG-Coauthor 10 | name: physics 11 | task: node 12 | task_type: classification 13 | transductive: True 14 | node_encoder: True 15 | node_encoder_name: LinearNode 16 | node_encoder_bn: False 17 | edge_encoder: True 18 | edge_encoder_num_types: 3 19 | edge_encoder_name: TypeDictEdge2 20 | edge_encoder_bn: False 21 | split_mode: 'random' 22 | split: [0.6, 0.2, 0.2] 23 | prep: 24 | exp: True 25 | exp_deg: 15 26 | train: 27 | mode: custom_train 28 | ckpt_best: True 29 | save_attention_scores: True 30 | eval_period: 1 31 | ckpt_period: 100 32 | temp_rdc_ratio: 0.99 33 | temp_min: 0.05 34 | temp_wait: 5 35 | model: 36 | type: Exphormer 37 | loss_fun: cross_entropy 38 | gt: 39 | layer_type: ASE 40 | layers: 4 41 | n_heads: 1 42 | dim_hidden: 4 43 | dropout: 0.0 44 | layer_norm: True 45 | batch_norm: False 46 | gnn: 47 | head: default 48 | layers_pre_mp: 1 49 | layers_post_mp: 1 50 | optim: 51 | clip_grad_norm: True 52 | optimizer: adamW 53 | weight_decay: 1e-3 54 | base_lr: 0.01 55 | max_epoch: 200 56 | scheduler: cosine_with_warmup 57 | num_warmup_epochs: 5 -------------------------------------------------------------------------------- /configs/homophilic/computer/computer_ASE.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | metric_best: accuracy 3 | wandb: 4 | use: True 5 
| project: computers3 6 | entity: expand-gnns 7 | dataset: 8 | format: PyG-Amazon 9 | name: computers 10 | task: node 11 | task_type: classification 12 | transductive: True 13 | node_encoder: True 14 | node_encoder_name: LinearNode 15 | node_encoder_bn: False 16 | edge_encoder: True 17 | edge_encoder_num_types: 3 18 | edge_encoder_name: TypeDictEdge 19 | edge_encoder_bn: False 20 | split_mode: 'standard' 21 | prep: 22 | exp: True 23 | exp_deg: 15 24 | train: 25 | mode: custom_train 26 | ckpt_best: True 27 | save_attention_scores: True 28 | eval_period: 1 29 | ckpt_period: 100 30 | temp_rdc_ratio: 0.99 31 | temp_min: 0.05 32 | temp_wait: 5 33 | model: 34 | type: Exphormer 35 | loss_fun: cross_entropy 36 | edge_decoding: dot 37 | graph_pooling: add 38 | gt: 39 | layer_type: ASE 40 | layers: 4 41 | n_heads: 1 42 | dim_hidden: 4 43 | dropout: 0.0 44 | layer_norm: True 45 | batch_norm: False 46 | gnn: 47 | head: default 48 | layers_pre_mp: 1 49 | layers_post_mp: 1 50 | optim: 51 | clip_grad_norm: True 52 | optimizer: adamW 53 | weight_decay: 1e-3 54 | base_lr: 0.01 55 | max_epoch: 200 56 | scheduler: cosine_with_warmup 57 | num_warmup_epochs: 5 -------------------------------------------------------------------------------- /configs/homophilic/photo/photo_spexphormer.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | metric_best: accuracy 3 | wandb: 4 | use: True 5 | project: photo3 6 | entity: expand-gnns 7 | dataset: 8 | format: PyG-Amazon 9 | name: photo 10 | task: node 11 | task_type: classification 12 | transductive: True 13 | node_encoder: True 14 | node_encoder_name: LinearNode 15 | node_encoder_bn: False 16 | edge_encoder: True 17 | edge_encoder_num_types: 3 18 | edge_encoder_name: TypeDictEdge2 19 | edge_encoder_bn: False 20 | split_mode: 'standard' 21 | prep: 22 | exp: False 23 | train: 24 | mode: custom_with_sampling 25 | spexphormer_sampler: graph_reservoir # options: [graph_reservoir, batch_reservoir, graph_uniform, batch_uniform] 26 | batch_size: 10000 27 | eval_period: 1 28 | ckpt_period: 100 29 | edge_sample_num_neighbors: [5, 5, 5, 5] 30 | model: 31 | type: Exphormer 32 | loss_fun: cross_entropy 33 | gt: 34 | layer_type: Spexphormer 35 | layers: 4 36 | n_heads: 2 37 | dim_hidden: 56 38 | dropout: 0.5 39 | layer_norm: False 40 | batch_norm: True 41 | gnn: 42 | head: inductive_node 43 | layers_pre_mp: 1 44 | layers_post_mp: 1 45 | optim: 46 | clip_grad_norm: True 47 | optimizer: adamW 48 | weight_decay: 1e-3 49 | base_lr: 0.01 50 | max_epoch: 100 51 | scheduler: cosine_with_warmup 52 | num_warmup_epochs: 5 -------------------------------------------------------------------------------- /configs/large/proteins/proteins_ASE.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | metric_best: auc 3 | gpu_mem: True 4 | wandb: 5 | use: True 6 | project: ogbn-proteins 7 | entity: expand-gnns 8 | dataset: 9 | format: OGB 10 | name: ogbn-proteins 11 | task: node 12 | task_type: classification_multilabel 13 | transductive: True 14 | node_encoder: True 15 | node_encoder_name: LinearNode 16 | node_encoder_bn: False 17 | edge_encoder: True 18 | edge_encoder_num_types: 3 19 | edge_encoder_name: TypeDictEdge2 20 | edge_encoder_bn: False 21 | split_mode: 'standard' 22 | prep: 23 | exp: True 24 | exp_deg: 100 25 | train: 26 | mode: custom_train 27 | ckpt_best: True 28 | save_attention_scores: True 29 | eval_period: 5 30 | ckpt_period: 100 31 | temp_rdc_ratio: 0.99 32 | temp_min: 
0.05 33 | temp_wait: 20 34 | model: 35 | type: Exphormer 36 | loss_fun: cross_entropy 37 | gt: 38 | layer_type: ASE 39 | layers: 2 40 | n_heads: 1 41 | dim_hidden: 8 42 | dropout: 0.0 43 | layer_norm: True 44 | batch_norm: False 45 | gnn: 46 | head: default 47 | layers_pre_mp: 1 48 | layers_post_mp: 1 # Not used when `gnn.head: san_graph` 49 | optim: 50 | clip_grad_norm: True 51 | optimizer: adamW 52 | weight_decay: 1e-3 53 | base_lr: 0.01 54 | max_epoch: 200 55 | scheduler: cosine_with_warmup 56 | num_warmup_epochs: 5 -------------------------------------------------------------------------------- /configs/heterophilic/minesweeper/minesweeper_ASE.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | metric_best: auc 3 | gpu_mem: True 4 | wandb: 5 | use: True 6 | project: Minesweeper 7 | entity: expand-gnns 8 | dataset: 9 | format: PyG-Heterophilous 10 | name: Minesweeper 11 | task: node 12 | task_type: classification 13 | transductive: True 14 | node_encoder: True 15 | node_encoder_name: LinearNode 16 | node_encoder_bn: False 17 | edge_encoder: True 18 | edge_encoder_num_types: 3 19 | edge_encoder_name: TypeDictEdge2 20 | edge_encoder_bn: False 21 | split_mode: 'standard' 22 | prep: 23 | exp: True 24 | exp_deg: 15 25 | train: 26 | mode: custom_train 27 | ckpt_best: True 28 | save_attention_scores: True 29 | eval_period: 1 30 | ckpt_period: 100 31 | temp_rdc_ratio: 0.99 32 | temp_min: 0.05 33 | temp_wait: 5 34 | model: 35 | type: Exphormer 36 | loss_fun: cross_entropy 37 | gt: 38 | layer_type: ASE 39 | layers: 4 40 | n_heads: 1 41 | dim_hidden: 4 42 | dropout: 0.0 43 | layer_norm: True 44 | batch_norm: False 45 | gnn: 46 | head: default 47 | layers_pre_mp: 1 48 | layers_post_mp: 1 # Not used when `gnn.head: san_graph` 49 | optim: 50 | clip_grad_norm: True 51 | optimizer: adamW 52 | weight_decay: 1e-3 53 | base_lr: 0.01 54 | max_epoch: 100 55 | scheduler: cosine_with_warmup 56 | num_warmup_epochs: 5 57 | -------------------------------------------------------------------------------- /configs/heterophilic/actor/actor_spexphormer.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | metric_best: accuracy 3 | gpu_mem: True 4 | wandb: 5 | use: True 6 | project: actor 7 | entity: expand-gnns 8 | dataset: 9 | format: PyG-actor 10 | name: actor 11 | task: node 12 | task_type: classification 13 | transductive: True 14 | node_encoder: True 15 | node_encoder_name: LinearNode 16 | node_encoder_bn: False 17 | edge_encoder: True 18 | edge_encoder_num_types: 3 19 | edge_encoder_name: TypeDictEdge2 20 | edge_encoder_bn: False 21 | split_mode: 'standard' 22 | prep: 23 | exp: False 24 | train: 25 | mode: custom_with_sampling 26 | spexphormer_sampler: graph_reservoir # options: [graph_reservoir, batch_reservoir, graph_uniform, batch_uniform, graph_max] 27 | batch_size: 10000 28 | eval_period: 1 29 | ckpt_period: 100 30 | edge_sample_num_neighbors: [2, 2, 2] 31 | model: 32 | type: Exphormer 33 | loss_fun: cross_entropy 34 | gt: 35 | layer_type: Spexphormer 36 | layers: 3 37 | n_heads: 4 38 | dim_hidden: 32 39 | dropout: 0.5 40 | layer_norm: False 41 | batch_norm: True 42 | gnn: 43 | head: inductive_node 44 | layers_pre_mp: 1 45 | layers_post_mp: 1 46 | optim: 47 | clip_grad_norm: True 48 | optimizer: adamW 49 | weight_decay: 1e-3 50 | base_lr: 0.01 51 | max_epoch: 100 52 | scheduler: cosine_with_warmup 53 | num_warmup_epochs: 5 
-------------------------------------------------------------------------------- /configs/heterophilic/tolokers/tolokers_ASE.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | metric_best: auc 3 | gpu_mem: True 4 | wandb: 5 | use: True 6 | project: Tolokers 7 | entity: expand-gnns 8 | dataset: 9 | format: PyG-Heterophilous 10 | name: Tolokers 11 | task: node 12 | task_type: classification 13 | transductive: True 14 | node_encoder: True 15 | node_encoder_name: LinearNode 16 | node_encoder_bn: False 17 | edge_encoder: True 18 | edge_encoder_num_types: 3 19 | edge_encoder_name: TypeDictEdge2 20 | edge_encoder_bn: False 21 | split_mode: 'standard' 22 | prep: 23 | exp: True 24 | exp_deg: 15 25 | train: 26 | mode: custom_train 27 | ckpt_best: True 28 | save_attention_scores: True 29 | eval_period: 1 30 | ckpt_period: 100 31 | temp_rdc_ratio: 1.0 32 | temp_min: 0.05 33 | temp_wait: 5 34 | sample_new_edges: False 35 | model: 36 | type: Exphormer 37 | loss_fun: cross_entropy 38 | gt: 39 | layer_type: ASE 40 | layers: 4 41 | n_heads: 1 42 | dim_hidden: 4 43 | dropout: 0.0 44 | layer_norm: True 45 | batch_norm: False 46 | gnn: 47 | head: default 48 | layers_pre_mp: 1 49 | layers_post_mp: 1 # Not used when `gnn.head: san_graph` 50 | optim: 51 | clip_grad_norm: True 52 | optimizer: adamW 53 | weight_decay: 1e-3 54 | base_lr: 0.01 55 | max_epoch: 100 56 | scheduler: cosine_with_warmup 57 | num_warmup_epochs: 5 58 | -------------------------------------------------------------------------------- /spexphormer/loss/weighted_cross_entropy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch_geometric.graphgym.config import cfg 4 | from torch_geometric.graphgym.register import register_loss 5 | 6 | 7 | @register_loss('weighted_cross_entropy') 8 | def weighted_cross_entropy(pred, true): 9 | """Weighted cross-entropy for unbalanced classes. 
10 | """ 11 | if cfg.model.loss_fun == 'weighted_cross_entropy': 12 | # calculating label weights for weighted loss computation 13 | V = true.size(0) 14 | n_classes = pred.shape[1] if pred.ndim > 1 else 2 15 | label_count = torch.bincount(true) 16 | label_count = label_count[label_count.nonzero(as_tuple=True)].squeeze() 17 | cluster_sizes = torch.zeros(n_classes, device=pred.device).long() 18 | cluster_sizes[torch.unique(true)] = label_count 19 | weight = (V - cluster_sizes).float() / V 20 | weight *= (cluster_sizes > 0).float() 21 | # multiclass 22 | if pred.ndim > 1: 23 | pred = F.log_softmax(pred, dim=-1) 24 | return F.nll_loss(pred, true, weight=weight), pred 25 | # binary 26 | else: 27 | loss = F.binary_cross_entropy_with_logits(pred, true.float(), 28 | weight=weight[true]) 29 | return loss, torch.sigmoid(pred) 30 | -------------------------------------------------------------------------------- /configs/homophilic/CS/CS_spexphormer.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | metric_best: accuracy 3 | gpu_mem: True 4 | wandb: 5 | use: True 6 | project: cs3 7 | entity: expand-gnns 8 | dataset: 9 | format: PyG-Coauthor 10 | name: cs 11 | task: node 12 | task_type: classification 13 | transductive: True 14 | node_encoder: True 15 | node_encoder_name: LinearNode 16 | node_encoder_bn: False 17 | edge_encoder: True 18 | edge_encoder_num_types: 3 19 | edge_encoder_name: TypeDictEdge2 20 | edge_encoder_bn: False 21 | split_mode: 'random' 22 | split: [0.6, 0.2, 0.2] 23 | prep: 24 | exp: False 25 | train: 26 | mode: custom_with_sampling 27 | spexphormer_sampler: graph_reservoir # options: [graph_reservoir, batch_reservoir, graph_uniform, batch_uniform] 28 | batch_size: 100000 29 | eval_period: 1 30 | ckpt_period: 100 31 | edge_sample_num_neighbors: [5, 5, 5, 5] 32 | model: 33 | type: Exphormer 34 | loss_fun: cross_entropy 35 | gt: 36 | layer_type: Spexphormer 37 | layers: 4 38 | n_heads: 2 39 | dim_hidden: 64 40 | dropout: 0.4 41 | layer_norm: False 42 | batch_norm: True 43 | gnn: 44 | head: inductive_node 45 | layers_pre_mp: 1 46 | layers_post_mp: 1 47 | optim: 48 | clip_grad_norm: True 49 | optimizer: adamW 50 | weight_decay: 1e-3 51 | base_lr: 0.002 52 | max_epoch: 120 53 | scheduler: cosine_with_warmup 54 | num_warmup_epochs: 5 -------------------------------------------------------------------------------- /configs/large/pokec/pokec_spexphormer.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | metric_best: accuracy 3 | gpu_mem: True 4 | wandb: 5 | use: True 6 | project: pokec 7 | entity: expand-gnns 8 | dataset: 9 | format: SNAP 10 | name: pokec 11 | task: node 12 | task_type: classification 13 | transductive: True 14 | node_encoder: True 15 | node_encoder_name: LinearNode 16 | node_encoder_bn: False 17 | edge_encoder: True 18 | edge_encoder_num_types: 3 19 | edge_encoder_name: TypeDictEdge2 20 | edge_encoder_bn: False 21 | split_mode: 'standard' 22 | prep: 23 | exp: False 24 | train: 25 | mode: custom_with_sampling 26 | spexphormer_sampler: graph_reservoir # options: [graph_reservoir, batch_reservoir, graph_uniform, batch_uniform, graph_max] 27 | batch_size: 500 28 | eval_period: 5 29 | ckpt_period: 100 30 | edge_sample_num_neighbors: [20, 20] 31 | model: 32 | type: Exphormer 33 | loss_fun: cross_entropy 34 | gt: 35 | layer_type: Spexphormer 36 | layers: 2 37 | n_heads: 1 38 | dim_hidden: 64 39 | dropout: 0.2 40 | layer_norm: False 41 | batch_norm: True 42 | 
gnn: 43 | head: inductive_node 44 | layers_pre_mp: 2 45 | layers_post_mp: 2 46 | dim_inner: 64 47 | dropout: 0.2 48 | optim: 49 | clip_grad_norm: True 50 | optimizer: adamW 51 | weight_decay: 1e-3 52 | base_lr: 0.001 53 | max_epoch: 300 54 | scheduler: cosine_with_warmup 55 | num_warmup_epochs: 5 -------------------------------------------------------------------------------- /configs/homophilic/physics/physics_spexphormer.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | metric_best: accuracy 3 | gpu_mem: True 4 | wandb: 5 | use: True 6 | project: physics3 7 | entity: expand-gnns 8 | dataset: 9 | format: PyG-Coauthor 10 | name: physics 11 | task: node 12 | task_type: classification 13 | transductive: True 14 | node_encoder: True 15 | node_encoder_name: LinearNode 16 | node_encoder_bn: False 17 | edge_encoder: True 18 | edge_encoder_num_types: 3 19 | edge_encoder_name: TypeDictEdge2 20 | edge_encoder_bn: False 21 | split_mode: 'random' 22 | split: [0.6, 0.2, 0.2] 23 | prep: 24 | exp: False 25 | train: 26 | mode: custom_with_sampling 27 | spexphormer_sampler: graph_reservoir # options: [graph_reservoir, batch_reservoir, graph_uniform, batch_uniform] 28 | batch_size: 100000 29 | eval_period: 1 30 | ckpt_period: 100 31 | edge_sample_num_neighbors: [5, 5, 5, 5] 32 | model: 33 | type: Exphormer 34 | loss_fun: cross_entropy 35 | gt: 36 | layer_type: Spexphormer 37 | layers: 4 38 | n_heads: 2 39 | dim_hidden: 64 40 | dropout: 0.4 41 | layer_norm: False 42 | batch_norm: True 43 | gnn: 44 | head: inductive_node 45 | layers_pre_mp: 0 46 | layers_post_mp: 1 47 | optim: 48 | clip_grad_norm: True 49 | optimizer: adamW 50 | weight_decay: 1e-3 51 | base_lr: 0.001 52 | max_epoch: 80 53 | scheduler: cosine_with_warmup 54 | num_warmup_epochs: 5 -------------------------------------------------------------------------------- /configs/homophilic/computer/computer_spexphormer.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | metric_best: accuracy 3 | gpu_mem: True 4 | wandb: 5 | use: True 6 | project: computers3 7 | entity: expand-gnns 8 | dataset: 9 | format: PyG-Amazon 10 | name: computers 11 | task: node 12 | task_type: classification 13 | transductive: True 14 | node_encoder: True 15 | node_encoder_name: LinearNode 16 | node_encoder_bn: False 17 | edge_encoder: True 18 | edge_encoder_num_types: 3 19 | edge_encoder_name: TypeDictEdge2 20 | edge_encoder_bn: False 21 | split_mode: 'random' 22 | split: [0.6, 0.2, 0.2] 23 | prep: 24 | exp: False 25 | train: 26 | mode: custom_with_sampling 27 | spexphormer_sampler: graph_reservoir # options: [graph_reservoir, batch_reservoir, graph_uniform, batch_uniform] 28 | batch_size: 100000 29 | eval_period: 1 30 | ckpt_period: 100 31 | edge_sample_num_neighbors: [5, 5, 5, 5] 32 | model: 33 | type: Exphormer 34 | loss_fun: cross_entropy 35 | gt: 36 | layer_type: Spexphormer 37 | layers: 4 38 | n_heads: 2 39 | dim_hidden: 80 40 | dropout: 0.5 41 | layer_norm: False 42 | batch_norm: True 43 | gnn: 44 | head: inductive_node 45 | layers_pre_mp: 1 46 | layers_post_mp: 1 47 | optim: 48 | clip_grad_norm: True 49 | optimizer: adamW 50 | weight_decay: 1e-3 51 | base_lr: 0.001 52 | max_epoch: 150 53 | scheduler: cosine_with_warmup 54 | num_warmup_epochs: 5 -------------------------------------------------------------------------------- /configs/large/proteins/proteins_spexphormer.yaml: 
-------------------------------------------------------------------------------- 1 | out_dir: results 2 | metric_best: auc 3 | gpu_mem: True 4 | wandb: 5 | use: True 6 | project: proteins 7 | entity: expand-gnns 8 | dataset: 9 | format: OGB 10 | name: ogbn-proteins 11 | task: node 12 | task_type: classification_multilabel 13 | transductive: True 14 | node_encoder: True 15 | node_encoder_name: LinearNode 16 | node_encoder_bn: False 17 | edge_encoder: True 18 | edge_encoder_num_types: 3 19 | edge_encoder_name: TypeDictEdge2 20 | edge_encoder_bn: False 21 | split_mode: 'standard' 22 | prep: 23 | exp: False 24 | train: 25 | mode: custom_with_sampling 26 | spexphormer_sampler: graph_reservoir # options: [graph_reservoir, batch_reservoir, graph_uniform, batch_uniform, graph_max] 27 | batch_size: 256 28 | eval_period: 1 29 | ckpt_period: 100 30 | edge_sample_num_neighbors: [50, 30] 31 | model: 32 | type: Exphormer 33 | loss_fun: cross_entropy 34 | gt: 35 | layer_type: Spexphormer 36 | layers: 2 37 | n_heads: 1 38 | dim_hidden: 64 # `gt.dim_hidden` must match `gnn.dim_inner` 39 | dropout: 0.1 40 | layer_norm: False 41 | batch_norm: True 42 | gnn: 43 | head: inductive_node 44 | layers_pre_mp: 1 45 | layers_post_mp: 1 46 | optim: 47 | clip_grad_norm: True 48 | optimizer: adamW 49 | weight_decay: 1e-3 50 | base_lr: 0.005 51 | max_epoch: 200 52 | scheduler: cosine_with_warmup 53 | num_warmup_epochs: 5 54 | -------------------------------------------------------------------------------- /configs/large/amazon2m/amazon2m_spexphormer.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | metric_best: accuracy 3 | gpu_mem: True 4 | wandb: 5 | use: True 6 | project: amazon2m 7 | entity: expand-gnns 8 | dataset: 9 | format: OGB 10 | name: ogbn-products 11 | task: node 12 | task_type: classification 13 | transductive: True 14 | node_encoder: True 15 | node_encoder_name: LinearNode 16 | node_encoder_bn: False 17 | edge_encoder: True 18 | edge_encoder_num_types: 3 19 | edge_encoder_name: TypeDictEdge2 20 | edge_encoder_bn: False 21 | split_mode: 'random' 22 | split: [0.5, 0.25, 0.25] 23 | prep: 24 | exp: False 25 | train: 26 | mode: custom_with_sampling 27 | spexphormer_sampler: graph_reservoir # options: [graph_reservoir, batch_reservoir, graph_uniform, batch_uniform] 28 | batch_size: 1000 29 | eval_period: 10 30 | ckpt_period: 100 31 | edge_sample_num_neighbors: [10, 10] 32 | model: 33 | type: Exphormer 34 | loss_fun: cross_entropy 35 | gt: 36 | layer_type: Spexphormer 37 | layers: 2 38 | n_heads: 1 39 | dim_hidden: 128 # `gt.dim_hidden` must match `gnn.dim_inner` 40 | dropout: 0.2 41 | layer_norm: False 42 | batch_norm: True 43 | gnn: 44 | head: inductive_node 45 | layers_pre_mp: 1 46 | layers_post_mp: 1 47 | optim: 48 | clip_grad_norm: True 49 | optimizer: adamW 50 | weight_decay: 1e-3 51 | base_lr: 0.001 52 | max_epoch: 200 53 | scheduler: cosine_with_warmup 54 | num_warmup_epochs: 5 -------------------------------------------------------------------------------- /configs/large/amazon2m/amazon2m_ASE.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | metric_best: accuracy 3 | gpu_mem: True 4 | wandb: 5 | use: True 6 | project: amazon2m 7 | entity: expand-gnns 8 | dataset: 9 | format: OGB 10 | name: ogbn-products 11 | task: node 12 | task_type: classification 13 | transductive: True 14 | node_encoder: True 15 | node_encoder_name: LinearNode 16 | node_encoder_bn: False 17 | edge_encoder: True 
18 | edge_encoder_num_types: 3 19 | edge_encoder_name: TypeDictEdge2 20 | edge_encoder_bn: False 21 | split_mode: 'random' 22 | split: [0.5, 0.25, 0.25] 23 | prep: 24 | exp: True 25 | exp_deg: 15 26 | train: 27 | mode: custom_train 28 | ckpt_best: True 29 | save_attention_scores: True 30 | eval_period: 1 31 | ckpt_period: 100 32 | temp_rdc_ratio: 0.99 33 | temp_min: 0.05 34 | temp_wait: 20 35 | model: 36 | type: Exphormer 37 | loss_fun: cross_entropy 38 | edge_decoding: dot 39 | graph_pooling: add 40 | gt: 41 | layer_type: ASE 42 | layers: 2 43 | n_heads: 1 44 | dim_hidden: 4 # `gt.dim_hidden` must match `gnn.dim_inner` 45 | dropout: 0.0 46 | layer_norm: False 47 | batch_norm: True 48 | gnn: 49 | head: default 50 | layers_pre_mp: 1 51 | layers_post_mp: 1 # Not used when `gnn.head: san_graph` 52 | optim: 53 | clip_grad_norm: True 54 | optimizer: adamW 55 | weight_decay: 1e-3 56 | base_lr: 0.01 57 | max_epoch: 200 58 | scheduler: cosine_with_warmup 59 | num_warmup_epochs: 5 60 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Spexphormer: Even Sparser Graph Transformers 2 | 3 | ![Spexphormer-viz](./Spexphormer.png) 4 | 5 | 6 | Attention edges enable global information propagation, but at the cost of higher computational complexity. Beyond that cost, batching a full attention map without blindly discarding the global information propagation is essentially impossible. 7 | But are all of these attention edges important? Many investigations suggest not. How can the less important edges be identified? We build on the Exphormer model: we first train a very low-width network and gather its attention scores. These scores are then used to prune the attention map, which makes batching practical and also yields a faster model. 8 | 9 | 10 | ### Python environment setup with Conda 11 | 12 | ```bash 13 | conda create -n spexphormer python=3.12 14 | conda activate spexphormer 15 | 16 | # check your CUDA version first 17 | conda install pytorch=2.3.0 torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia 18 | conda install pyg -c pyg 19 | 20 | pip install torchmetrics 21 | pip install yacs 22 | pip install ogb 23 | pip install tensorboardX 24 | pip install wandb 25 | pip install bottleneck 26 | 27 | conda clean --all 28 | ``` 29 | 30 | 31 | ### Running Spexphormer 32 | ```bash 33 | conda activate spexphormer 34 | 35 | # Run the attention score estimator (ASE) network for the Actor dataset: 36 | python main.py --cfg configs/heterophilic/actor/actor_ASE.yaml wandb.use False 37 | 38 | # Train Spexphormer after the attention score estimation; the experiment will run for 5 random seeds (0-4): 39 | python main.py --cfg configs/heterophilic/actor/actor_spexphormer.yaml --repeat 5 wandb.use False 40 | ``` 41 | 42 | After the ASE run finishes, the attention scores are automatically saved in the 'Attention_scores' directory. The Spexphormer code then reads these scores and performs the neighbor sampling. 43 | 44 | 45 | ### W&B logging 46 | To use W&B logging, set `wandb.use True`.
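
### ASE vs. Spexphormer configs

Each dataset ships with a pair of configs. In the provided Spexphormer configs, `train.edge_sample_num_neighbors` gives one neighbor budget per attention layer, so its length matches `gt.layers` (e.g., `layers: 4` with `[5, 5, 5, 5]`). Below is a minimal sketch of the fields that typically change between the two phases; the values are illustrative, loosely based on the homophilic configs, and not a tested configuration.

```yaml
# Phase 1: narrow ASE run that records the attention scores.
train:
  mode: custom_train
  save_attention_scores: True
prep:
  exp: True         # expander edges are added in this phase
  exp_deg: 15
gt:
  layer_type: ASE
  dim_hidden: 4     # very low width, as described above
---
# Phase 2: full-width Spexphormer trained on the pruned attention map.
train:
  mode: custom_with_sampling
  spexphormer_sampler: graph_reservoir
  edge_sample_num_neighbors: [5, 5, 5, 5]   # one entry per gt.layers
prep:
  exp: False        # edges now come from the saved attention scores
gt:
  layer_type: Spexphormer
  dim_hidden: 64
```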
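
### How the saved attention scores are used (sketch)

Conceptually, the Spexphormer phase keeps only a few incoming attention edges per node, ranked by the scores the ASE network produced. The snippet below is a minimal, self-contained sketch of that idea; the function `prune_by_attention` and its plain top-k rule are illustrative assumptions, not the repository's actual sampler (the real samplers, such as `graph_reservoir`, live in `spexphormer/train/neighbor_sampler.py` and are selected via `train.spexphormer_sampler`).

```python
import torch

def prune_by_attention(edge_index: torch.Tensor,
                       att_scores: torch.Tensor,
                       k: int) -> torch.Tensor:
    """Keep at most k incoming edges per destination node, ranked by
    attention score. Illustrative sketch only, not the repo's sampler."""
    num_nodes = int(edge_index.max()) + 1
    kept = []
    for v in range(num_nodes):
        # Positions of edges whose destination is node v.
        incoming = (edge_index[1] == v).nonzero(as_tuple=True)[0]
        if incoming.numel() > k:
            # Rank node v's incoming edges by score and keep the top k.
            incoming = incoming[torch.topk(att_scores[incoming], k).indices]
        kept.append(incoming)
    return edge_index[:, torch.cat(kept)]

# Toy example: six directed edges, keep the two strongest per destination.
edge_index = torch.tensor([[0, 1, 2, 0, 1, 2],
                           [1, 1, 1, 2, 2, 2]])
att_scores = torch.tensor([0.7, 0.1, 0.2, 0.4, 0.5, 0.1])
print(prune_by_attention(edge_index, att_scores, k=2))
```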
-------------------------------------------------------------------------------- /spexphormer/config/defaults_config.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.graphgym.register import register_config 2 | 3 | 4 | @register_config('overwrite_defaults') 5 | def overwrite_defaults_cfg(cfg): 6 | """Overwrite the default config values that are first set by GraphGym in 7 | torch_geometric.graphgym.config.set_cfg 8 | 9 | WARNING: At the time of writing, the order in which custom config-setting 10 | functions like this one are executed is random; see the referenced `set_cfg`. 11 | Therefore, never reset here config options that are custom-added; only change 12 | those that exist in core GraphGym. 13 | """ 14 | 15 | # Overwrite default dataset name 16 | cfg.dataset.name = 'none' 17 | 18 | # Overwrite default rounding precision 19 | cfg.round = 5 20 | 21 | 22 | @register_config('extended_cfg') 23 | def extended_cfg(cfg): 24 | """General extended config options. 25 | """ 26 | 27 | # Additional name tag used in `run_dir` and `wandb_name` auto generation. 28 | cfg.name_tag = "" 29 | 30 | cfg.train.mode = 'custom_train' 31 | # In training, if True (and also cfg.train.enable_ckpt is True), always 32 | # checkpoint the current best model based on validation performance; 33 | # when False, follow the cfg.train.eval_period checkpointing frequency instead. 34 | cfg.train.ckpt_best = False 35 | # If True, after training, rerun the model to save the attention scores 36 | cfg.train.save_attention_scores = False 37 | cfg.train.saving_epoch = False # don't change this; the model itself uses it to track when to save 38 | cfg.train.cur_epoch = 0 39 | cfg.train.temp_rdc_ratio = 1.0 40 | cfg.train.temp_min = 0.1 41 | cfg.train.temp_wait = 10 42 | cfg.train.replace_edges = False 43 | cfg.train.number_of_edge_sets = 1 44 | cfg.train.rotate_edges = False 45 | cfg.train.layer_wise_edges = False 46 | cfg.train.sample_new_edges = False 47 | cfg.train.num_edge_samples = 50 48 | cfg.train.edge_sample_num_neighbors = [6, 4, 4, 4, 4, 4] 49 | cfg.train.spexphormer_sampler = 'graph_reservoir' 50 | cfg.train.resampling_epochs = 1 51 | -------------------------------------------------------------------------------- /spexphormer/encoder/type_dict_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.graphgym.config import cfg 3 | from torch_geometric.graphgym.register import (register_node_encoder, 4 | register_edge_encoder) 5 | 6 | @register_node_encoder('TypeDictNode') 7 | class TypeDictNodeEncoder(torch.nn.Module): 8 | def __init__(self, emb_dim): 9 | super().__init__() 10 | 11 | num_types = cfg.dataset.node_encoder_num_types 12 | if num_types < 1: 13 | raise ValueError(f"Invalid 'node_encoder_num_types': {num_types}") 14 | 15 | self.encoder = torch.nn.Embedding(num_embeddings=num_types, 16 | embedding_dim=emb_dim) 17 | # torch.nn.init.xavier_uniform_(self.encoder.weight.data) 18 | 19 | def forward(self, batch): 20 | # Encode just the first dimension if more exist 21 | batch.x = self.encoder(batch.x[:, 0]) 22 | 23 | return batch 24 | 25 | 26 | @register_edge_encoder('TypeDictEdge') 27 | class TypeDictEdgeEncoder(torch.nn.Module): 28 | def __init__(self, emb_dim): 29 | super().__init__() 30 | 31 | num_types = cfg.dataset.edge_encoder_num_types 32 | if num_types < 1: 33 | raise ValueError(f"Invalid 'edge_encoder_num_types': {num_types}") 34 | 35 | self.encoder =
torch.nn.Embedding(num_embeddings=num_types, 36 | embedding_dim=emb_dim) 37 | 38 | def forward(self, batch): 39 | batch.edge_dict = batch.edge_attr 40 | batch.edge_attr = self.encoder(batch.edge_attr) 41 | return batch 42 | 43 | 44 | @register_edge_encoder('TypeDictEdge2') 45 | class TypeDictEdgeEncoder2(torch.nn.Module): 46 | ''' 47 | Edge encoder for the type dictionary that only stores the torch.nn.Embedding weights for the batch. 48 | ''' 49 | def __init__(self, emb_dim): 50 | super().__init__() 51 | 52 | self.num_types = cfg.dataset.edge_encoder_num_types 53 | if self.num_types < 1: 54 | raise ValueError(f"Invalid 'edge_encoder_num_types': {self.num_types}") 55 | 56 | self.encoder = torch.nn.Embedding(num_embeddings=self.num_types, 57 | embedding_dim=emb_dim) 58 | 59 | def forward(self, batch): 60 | batch.edge_embeddings = self.encoder(torch.arange(self.num_types).to(self.encoder.weight.device)) 61 | return batch -------------------------------------------------------------------------------- /spexphormer/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from yacs.config import CfgNode 4 | 5 | 6 | def flatten_dict(metrics): 7 | """Flatten a list of train/val/test metrics into one dict to send to wandb. 8 | 9 | Args: 10 | metrics: List of Dicts with metrics 11 | 12 | Returns: 13 | A flat dictionary with names prefixed with "train/" , "val/" , "test/" 14 | """ 15 | prefixes = ['train', 'val', 'test'] 16 | result = {} 17 | for i in range(len(metrics)): 18 | # Take the latest metrics. 19 | stats = metrics[i][-1] 20 | result.update({f"{prefixes[i]}/{k}": v for k, v in stats.items()}) 21 | return result 22 | 23 | 24 | def cfg_to_dict(cfg_node, key_list=[]): 25 | """Convert a config node to dictionary. 26 | 27 | Yacs doesn't have a default function to convert the cfg object to plain 28 | python dict. The following function was taken from 29 | https://github.com/rbgirshick/yacs/issues/19 30 | """ 31 | _VALID_TYPES = {tuple, list, str, int, float, bool} 32 | 33 | if not isinstance(cfg_node, CfgNode): 34 | if type(cfg_node) not in _VALID_TYPES: 35 | logging.warning(f"Key {'.'.join(key_list)} with " 36 | f"value {type(cfg_node)} is not " 37 | f"a valid type; valid types: {_VALID_TYPES}") 38 | return cfg_node 39 | else: 40 | cfg_dict = dict(cfg_node) 41 | for k, v in cfg_dict.items(): 42 | cfg_dict[k] = cfg_to_dict(v, key_list + [k]) 43 | return cfg_dict 44 | 45 | 46 | def make_wandb_name(cfg): 47 | # Format dataset name. 48 | dataset_name = cfg.dataset.format 49 | if dataset_name.startswith('OGB'): 50 | dataset_name = dataset_name[3:] 51 | if dataset_name.startswith('PyG-'): 52 | dataset_name = dataset_name[4:] 53 | if dataset_name in ['GNNBenchmarkDataset', 'TUDataset']: 54 | # Shorten some verbose dataset naming schemes. 55 | dataset_name = "" 56 | if cfg.dataset.name != 'none': 57 | dataset_name += "-" if dataset_name != "" else "" 58 | if cfg.dataset.name == 'LocalDegreeProfile': 59 | dataset_name += 'LDP' 60 | else: 61 | dataset_name += cfg.dataset.name 62 | # Format model name. 63 | model_name = cfg.model.type 64 | if cfg.model.type in ['gnn', 'custom_gnn']: 65 | model_name += f".{cfg.gnn.layer_type}" 66 | elif cfg.model.type == 'GPSModel': 67 | model_name = f"GPS.{cfg.gt.layer_type}" 68 | model_name += f".{cfg.name_tag}" if cfg.name_tag else "" 69 | # Compose wandb run name. 
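# Hypothetical example: with cfg.dataset.format 'PyG-Actor', cfg.dataset.name
# 'actor', cfg.model.type 'Exphormer', empty name_tag, and run_id 0, the pieces
# above compose to 'Actor-actor' and 'Exphormer', giving 'Actor-actor.Exphormer.r0'.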
70 | name = f"{dataset_name}.{model_name}.r{cfg.run_id}" 71 | return name 72 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # CUSTOM 2 | .vscode/ 3 | scripts/ 4 | slurm_history/ 5 | datasets/* 6 | configs/temp/ 7 | pretrained/ 8 | results/* 9 | Attention_scores/* 10 | EdgeSets/* 11 | .DS_Store 12 | vocprep/benchmark_RELEASE/ 13 | vocprep/voc_viz_files/ 14 | vocprep/VOC/benchmark_RELEASE/ 15 | vocprep/VOC/*.tgz 16 | vocprep/VOC/*.pickle 17 | vocprep/VOC/*.pkl 18 | vocprep/VOC/*.zip 19 | splits/ 20 | wandb/* 21 | .idea 22 | *.log 23 | *.bak 24 | *.npy 25 | 26 | # Byte-compiled / optimized / DLL files 27 | __pycache__ 28 | __pycache__/ 29 | __pycache__/* 30 | *.py[cod] 31 | *$py.class 32 | 33 | # C extensions 34 | *.so 35 | 36 | # Distribution / packaging 37 | .Python 38 | build/ 39 | develop-eggs/ 40 | dist/ 41 | downloads/ 42 | eggs/ 43 | .eggs/ 44 | lib/ 45 | lib64/ 46 | parts/ 47 | sdist/ 48 | var/ 49 | wheels/ 50 | pip-wheel-metadata/ 51 | share/python-wheels/ 52 | *.egg-info/ 53 | .installed.cfg 54 | *.egg 55 | MANIFEST 56 | 57 | # PyInstaller 58 | # Usually these files are written by a python script from a template 59 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 60 | *.manifest 61 | *.spec 62 | 63 | # Installer logs 64 | pip-log.txt 65 | pip-delete-this-directory.txt 66 | 67 | # Unit test / coverage reports 68 | htmlcov/ 69 | .tox/ 70 | .nox/ 71 | .coverage 72 | .coverage.* 73 | .cache 74 | nosetests.xml 75 | coverage.xml 76 | *.cover 77 | *.py,cover 78 | .hypothesis/ 79 | .pytest_cache/ 80 | 81 | # Translations 82 | *.mo 83 | *.pot 84 | 85 | # Django stuff: 86 | *.log 87 | local_settings.py 88 | db.sqlite3 89 | db.sqlite3-journal 90 | 91 | # Flask stuff: 92 | instance/ 93 | .webassets-cache 94 | 95 | # Scrapy stuff: 96 | .scrapy 97 | 98 | # Sphinx documentation 99 | docs/_build/ 100 | 101 | # PyBuilder 102 | target/ 103 | 104 | # Jupyter Notebook 105 | .ipynb_checkpoints 106 | 107 | # IPython 108 | profile_default/ 109 | ipython_config.py 110 | 111 | # pyenv 112 | .python-version 113 | 114 | # pipenv 115 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 116 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 117 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 118 | # install all needed dependencies. 119 | #Pipfile.lock 120 | 121 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 122 | __pypackages__/ 123 | 124 | # Celery stuff 125 | celerybeat-schedule 126 | celerybeat.pid 127 | 128 | # SageMath parsed files 129 | *.sage.py 130 | 131 | # Environments 132 | .env 133 | .venv 134 | env/ 135 | venv/ 136 | ENV/ 137 | env.bak/ 138 | venv.bak/ 139 | 140 | # Spyder project settings 141 | .spyderproject 142 | .spyproject 143 | 144 | # Rope project settings 145 | .ropeproject 146 | 147 | # mkdocs documentation 148 | /site 149 | 150 | # mypy 151 | .mypy_cache/ 152 | .dmypy.json 153 | dmypy.json 154 | 155 | # Pyre type checker 156 | .pyre/ 157 | 158 | # vim edit buffer 159 | *.swp 160 | 161 | 162 | -------------------------------------------------------------------------------- /spexphormer/layer/Spexphormer_Attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch_geometric.graphgym.config import cfg 5 | from torch_geometric.graphgym.register import register_layer 6 | 7 | 8 | class SpexphormerAttention(nn.Module): 9 | 10 | def __init__(self, in_dim, out_dim, num_heads, layer_idx, use_bias=False): 11 | super().__init__() 12 | 13 | if out_dim % num_heads != 0: 14 | raise ValueError('hidden dimension is not dividable by the number of heads') 15 | self.out_dim = out_dim // num_heads 16 | self.num_heads = num_heads 17 | self.layer_idx = layer_idx 18 | 19 | self.edge_index_name = f'edge_index_layer_{layer_idx}' 20 | self.edge_attr_name = f'edge_type_layer_{layer_idx}' 21 | 22 | self.Q = nn.Linear(in_dim, self.out_dim * num_heads, bias=use_bias) 23 | self.K = nn.Linear(in_dim, self.out_dim * num_heads, bias=use_bias) 24 | self.E1 = nn.Linear(in_dim, self.out_dim * num_heads, bias=use_bias) 25 | self.E2 = nn.Linear(in_dim, num_heads, bias=True) 26 | self.V = nn.Linear(in_dim, self.out_dim * num_heads, bias=use_bias) 27 | 28 | def forward(self, batch): 29 | edge_index = getattr(batch, self.edge_index_name) 30 | edge_attr = getattr(batch, self.edge_attr_name) 31 | 32 | n1 = batch.num_layer_nodes[self.layer_idx].item() 33 | n2 = batch.num_layer_nodes[self.layer_idx + 1].item() 34 | assert batch.x.shape[0] == n1 35 | 36 | Q_h = self.Q(batch.x[:n2]).view(-1, self.num_heads, self.out_dim) 37 | K_h = self.K(batch.x).view(-1, self.num_heads, self.out_dim) 38 | V_h = self.V(batch.x).view(-1, self.num_heads, self.out_dim) 39 | 40 | if cfg.dataset.edge_encoder_name == 'TypeDictEdge2': 41 | E1 = self.E1(batch.edge_embeddings)[edge_attr].view(n2, -1, self.num_heads, self.out_dim) 42 | E2 = self.E2(batch.edge_embeddings)[edge_attr].view(n2, -1, self.num_heads, 1) 43 | else: 44 | E1 = self.E1(edge_attr).view(n2, -1, self.num_heads, self.out_dim) 45 | E2 = self.E2(edge_attr).view(n2, -1, self.num_heads, 1) 46 | 47 | neighbors = edge_index[0, :] 48 | deg = neighbors.shape[0]//n2 49 | neighbors = neighbors.reshape(n2, deg) 50 | 51 | K_h = K_h[neighbors] 52 | V_h = V_h[neighbors] 53 | 54 | score = torch.mul(E1, K_h) 55 | 56 | score = torch.bmm(score.view(-1, deg, self.out_dim), Q_h.view(-1, self.out_dim, 1)) 57 | score = score.view(-1, self.num_heads, deg) 58 | 59 | score = score + E2.squeeze(-1).permute([0, 2, 1]) 60 | score = score.clamp(-8, 8) 61 | score = F.softmax(score, dim=-1) 62 | 63 | V_h = V_h.permute(0, 2, 1, 3) 64 | score = score.unsqueeze(-1) 65 | h_out = torch.mul(score, V_h) 66 | h_out = h_out.sum(dim=2) 67 | h_out = h_out.reshape(n2, -1) 68 | 69 | return h_out 70 | 71 | register_layer('ExphormerFinal', SpexphormerAttention) 
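# The same SpexphormerAttention module is registered under several layer-type
# names (this one and the two below), so different configs can select it via
# `gt.layer_type` without any code changes.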
72 | register_layer('ExphormerSecond', SpexphormerAttention) 73 | register_layer('ExphormerRegularGraph', SpexphormerAttention) -------------------------------------------------------------------------------- /spexphormer/network/SpExphormer_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch_geometric.graphgym.register as register 3 | from torch_geometric.graphgym.config import cfg 4 | from torch_geometric.graphgym.models.gnn import GNNPreMP 5 | from torch_geometric.graphgym.models.layer import (new_layer_config, 6 | BatchNorm1dNode) 7 | from torch_geometric.graphgym.register import register_network 8 | 9 | from spexphormer.layer.SpExphormer_full_layer import SpExphormerFullLayer 10 | 11 | 12 | class FeatureEncoder(torch.nn.Module): 13 | """ 14 | Encoding node and edge features 15 | 16 | Args: 17 | dim_in (int): Input feature dimension 18 | """ 19 | def __init__(self, dim_in): 20 | super(FeatureEncoder, self).__init__() 21 | self.dim_in = dim_in 22 | if cfg.dataset.node_encoder: 23 | # Encode integer node features via nn.Embeddings 24 | NodeEncoder = register.node_encoder_dict[ 25 | cfg.dataset.node_encoder_name] 26 | self.node_encoder = NodeEncoder(cfg.gt.dim_hidden) 27 | if cfg.dataset.node_encoder_bn: 28 | self.node_encoder_bn = BatchNorm1dNode( 29 | new_layer_config(cfg.gt.dim_hidden, -1, -1, has_act=False, 30 | has_bias=False, cfg=cfg)) 31 | self.dim_in = cfg.gt.dim_hidden 32 | if cfg.dataset.edge_encoder: 33 | if getattr(cfg.gt, 'dim_edge', None) is None: 34 | cfg.gt.dim_edge = cfg.gt.dim_hidden 35 | 36 | EdgeEncoder = register.edge_encoder_dict[ 37 | cfg.dataset.edge_encoder_name] 38 | self.edge_encoder = EdgeEncoder(cfg.gt.dim_edge) 39 | if cfg.dataset.edge_encoder_bn: 40 | self.edge_encoder_bn = BatchNorm1dNode( 41 | new_layer_config(cfg.gt.dim_edge, -1, -1, has_act=False, 42 | has_bias=False, cfg=cfg)) 43 | 44 | def forward(self, batch): 45 | for module in self.children(): 46 | batch = module(batch) 47 | return batch 48 | 49 | 50 | class SpExphormer_Network(torch.nn.Module): 51 | ''' 52 | This model can be used for creating the variants of Exphormer and Spexphormer networks. 
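The attention variant of each layer is chosen via `cfg.gt.layer_type` ('Exphormer', 'ASE', or 'Spexphormer'); the same network class is registered under both the 'Exphormer' and 'Spexphormer' names at the bottom of this file.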
53 |     '''
54 |     def __init__(self, dim_in, dim_out):
55 |         super().__init__()
56 |         self.encoder = FeatureEncoder(dim_in)
57 |         dim_in = self.encoder.dim_in
58 | 
59 |         if cfg.gnn.layers_pre_mp > 0:
60 |             self.pre_mp = GNNPreMP(
61 |                 dim_in, cfg.gt.dim_hidden, cfg.gnn.layers_pre_mp)
62 |             dim_in = cfg.gt.dim_hidden
63 |         else:
64 |             self.pre_mp = None
65 | 
66 |         layers = []
67 | 
68 |         for i in range(cfg.gt.layers):
69 |             layers.append(SpExphormerFullLayer(
70 |                 dim_h=cfg.gt.dim_hidden,
71 |                 layer_type=cfg.gt.layer_type,
72 |                 num_heads=cfg.gt.n_heads,
73 |                 layer_idx=i,
74 |                 dropout=cfg.gt.dropout,
75 |                 layer_norm=cfg.gt.layer_norm,
76 |                 batch_norm=cfg.gt.batch_norm,
77 |                 use_ffn=cfg.gt.FFN,
78 |                 exp_edges_cfg=cfg.prep
79 |             ))
80 |         self.layers = torch.nn.Sequential(*layers)
81 | 
82 |         GNNHead = register.head_dict[cfg.gnn.head]
83 |         self.post_mp = GNNHead(dim_in=cfg.gt.dim_hidden, dim_out=dim_out)
84 | 
85 |     def forward(self, batch):
86 |         for module in self.children():
87 |             batch = module(batch)
88 | 
89 |         return batch
90 | 
91 | register_network('Exphormer', SpExphormer_Network)
92 | register_network('Spexphormer', SpExphormer_Network)
--------------------------------------------------------------------------------
/spexphormer/layer/Exphormer_Attention.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | 
5 | from torch_geometric.graphgym.config import cfg
6 | from torch_geometric.graphgym.register import register_layer
7 | 
8 | 
9 | 
10 | class ExphormerAttention(nn.Module):
11 | 
12 |     def __init__(self, in_dim, out_dim, num_heads, use_bias, layer_idx, dim_edge=None, use_virt_nodes=False):
13 |         super().__init__()
14 | 
15 |         if out_dim % num_heads != 0:
16 |             raise ValueError('hidden dimension is not divisible by the number of heads')
17 |         self.out_dim = out_dim // num_heads
18 |         self.num_heads = num_heads
19 |         self.use_virt_nodes = use_virt_nodes
20 |         self.use_bias = use_bias
21 |         self.layer_idx = layer_idx
22 | 
23 |         if dim_edge is None:
24 |             dim_edge = in_dim
25 | 
26 |         self.Q = nn.Linear(in_dim, self.out_dim * num_heads, bias=use_bias)
27 |         self.K = nn.Linear(in_dim, self.out_dim * num_heads, bias=use_bias)
28 |         self.V = nn.Linear(in_dim, self.out_dim * num_heads, bias=use_bias)
29 | 
30 |         self.use_edge_attr = cfg.gt.use_edge_feats
31 |         if self.use_edge_attr:
32 |             self.E = nn.Linear(dim_edge, self.out_dim * num_heads, bias=use_bias)
33 | 
34 | 
35 |     def propagate_attention(self, batch, edge_index):
36 |         src = batch.K_h[edge_index[0].to(torch.long)]   # (num edges) x num_heads x out_dim
37 |         dest = batch.Q_h[edge_index[1].to(torch.long)]  # (num edges) x num_heads x out_dim
38 |         score = torch.mul(src, dest)  # element-wise multiplication
39 | 
40 |         # Scale scores by sqrt(d)
41 |         score = score / np.sqrt(self.out_dim)
42 | 
43 |         # Use available edge features to modify the scores for edges
44 |         score = torch.mul(score, batch.E)  # (num real edges) x num_heads x out_dim
45 | 
46 |         # score = torch.exp(score.sum(-1, keepdim=True))
47 |         score = torch.exp(score.sum(-1, keepdim=True).clamp(-5, 5))
48 | 
49 | 
50 |         # Apply attention score to each source node to create edge messages
51 |         msg = batch.V_h[edge_index[0].to(torch.long)] * score  # (num real edges) x num_heads x out_dim
52 |         # Add-up real msgs in destination nodes as given by batch.edge_index[1]
53 |         batch.wV = torch.zeros_like(batch.V_h)  # (num nodes in batch) x num_heads x out_dim
54 |         batch.wV.index_add_(0, edge_index[1].to(torch.long), msg)
55 | 
56 |         # 
Compute attention normalization coefficient 57 | batch.Z = score.new_zeros(batch.V_h.size(0), self.num_heads, 1) # (num nodes in batch) x num_heads x 1 58 | batch.Z.index_add_(0, edge_index[1].to(torch.long), score) 59 | 60 | 61 | def forward(self, batch): 62 | edge_attr = batch.edge_attr 63 | edge_index = batch.edge_index 64 | h = batch.x 65 | num_node = batch.batch.shape[0] 66 | 67 | Q_h = self.Q(h) 68 | K_h = self.K(h) 69 | V_h = self.V(h) 70 | 71 | if self.use_edge_attr: 72 | if cfg.dataset.edge_encoder_name == 'TypeDictEdge2': 73 | E = self.E(batch.edge_embeddings)[batch.edge_attr] 74 | else: 75 | E = self.E(edge_attr) 76 | 77 | # Reshaping into [num_nodes, num_heads, feat_dim] to 78 | # get projections for multi-head attention 79 | batch.Q_h = Q_h.view(-1, self.num_heads, self.out_dim) 80 | batch.K_h = K_h.view(-1, self.num_heads, self.out_dim) 81 | batch.E = E.view(-1, self.num_heads, self.out_dim) 82 | batch.V_h = V_h.view(-1, self.num_heads, self.out_dim) 83 | 84 | self.propagate_attention(batch, edge_index) 85 | 86 | h_out = batch.wV / (batch.Z + 1e-6) 87 | 88 | h_out = h_out.view(-1, self.out_dim * self.num_heads) 89 | 90 | batch.virt_h = h_out[num_node:] 91 | h_out = h_out[:num_node] 92 | 93 | return h_out 94 | 95 | 96 | register_layer('Exphormer', ExphormerAttention) -------------------------------------------------------------------------------- /spexphormer/config/posenc_config.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.graphgym.register import register_config 2 | from yacs.config import CfgNode as CN 3 | 4 | 5 | @register_config('posenc') 6 | def set_cfg_posenc(cfg): 7 | """Extend configuration with positional encoding options. 8 | """ 9 | 10 | # Argument group for each Positional Encoding class. 11 | cfg.posenc_LapPE = CN() 12 | cfg.posenc_SignNet = CN() 13 | cfg.posenc_RWSE = CN() 14 | cfg.posenc_HKdiagSE = CN() 15 | cfg.posenc_ElstaticSE = CN() 16 | cfg.posenc_EquivStableLapPE = CN() 17 | 18 | # Effective Resistance Embeddings 19 | cfg.posenc_ERN = CN() #Effective Resistance for Nodes 20 | cfg.posenc_ERE = CN() #Effective Resistance for Edges 21 | 22 | # Common arguments to all PE types. 23 | for name in ['posenc_LapPE', 'posenc_SignNet', 24 | 'posenc_RWSE', 'posenc_HKdiagSE', 'posenc_ElstaticSE', 25 | 'posenc_ERN', 'posenc_ERE']: 26 | pecfg = getattr(cfg, name) 27 | 28 | # Use extended positional encodings 29 | pecfg.enable = False 30 | 31 | # Neural-net model type within the PE encoder: 32 | # 'DeepSet', 'Transformer', 'Linear', 'none', ... 33 | pecfg.model = 'none' 34 | 35 | # Size of Positional Encoding embedding 36 | pecfg.dim_pe = 16 37 | 38 | # Number of layers in PE encoder model 39 | pecfg.layers = 3 40 | 41 | # Number of attention heads in PE encoder when model == 'Transformer' 42 | pecfg.n_heads = 4 43 | 44 | # Number of layers to apply in LapPE encoder post its pooling stage 45 | pecfg.post_layers = 0 46 | 47 | # Choice of normalization applied to raw PE stats: 'none', 'BatchNorm' 48 | pecfg.raw_norm_type = 'none' 49 | 50 | # In addition to appending PE to the node features, pass them also as 51 | # a separate variable in the PyG graph batch object. 52 | pecfg.pass_as_var = False 53 | 54 | # Config for EquivStable LapPE 55 | cfg.posenc_EquivStableLapPE.enable = False 56 | cfg.posenc_EquivStableLapPE.raw_norm_type = 'none' 57 | 58 | # Config for Laplacian Eigen-decomposition for PEs that use it. 
59 | for name in ['posenc_LapPE', 'posenc_SignNet', 'posenc_EquivStableLapPE']: 60 | pecfg = getattr(cfg, name) 61 | pecfg.eigen = CN() 62 | 63 | # The normalization scheme for the graph Laplacian: 'none', 'sym', or 'rw' 64 | pecfg.eigen.laplacian_norm = 'sym' 65 | 66 | # The normalization scheme for the eigen vectors of the Laplacian 67 | pecfg.eigen.eigvec_norm = 'L2' 68 | 69 | # Maximum number of top smallest frequencies & eigenvectors to use 70 | pecfg.eigen.max_freqs = 10 71 | 72 | # Config for SignNet-specific options. 73 | cfg.posenc_SignNet.phi_out_dim = 4 74 | cfg.posenc_SignNet.phi_hidden_dim = 64 75 | 76 | for name in ['posenc_RWSE', 'posenc_HKdiagSE', 'posenc_ElstaticSE']: 77 | pecfg = getattr(cfg, name) 78 | 79 | # Config for Kernel-based PE specific options. 80 | pecfg.kernel = CN() 81 | 82 | # List of times to compute the heat kernel for (the time is equivalent to 83 | # the variance of the kernel) / the number of steps for random walk kernel 84 | # Can be overridden by `posenc.kernel.times_func` 85 | pecfg.kernel.times = [] 86 | 87 | # Python snippet to generate `posenc.kernel.times`, e.g. 'range(1, 17)' 88 | # If set, it will be executed via `eval()` and override posenc.kernel.times 89 | pecfg.kernel.times_func = '' 90 | 91 | # Override default, electrostatic kernel has fixed set of 10 measures. 92 | cfg.posenc_ElstaticSE.kernel.times_func = 'range(10)' 93 | 94 | # Setting accuracy for Effective Resistance Calculations: 95 | cfg.posenc_ERN.accuracy = 0.1 96 | cfg.posenc_ERE.accuracy = 0.1 97 | 98 | # To be set during the calculations: 99 | cfg.posenc_ERN.er_dim = 'none' 100 | -------------------------------------------------------------------------------- /spexphormer/metrics_ogb.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import roc_auc_score, average_precision_score 3 | 4 | """ 5 | Evaluation functions from OGB. 6 | https://github.com/snap-stanford/ogb/blob/master/ogb/graphproppred/evaluate.py 7 | """ 8 | 9 | def eval_rocauc(y_true, y_pred): 10 | ''' 11 | compute ROC-AUC averaged across tasks 12 | ''' 13 | 14 | rocauc_list = [] 15 | 16 | for i in range(y_true.shape[1]): 17 | # AUC is only defined when there is at least one positive data. 18 | if np.sum(y_true[:, i] == 1) > 0 and np.sum(y_true[:, i] == 0) > 0: 19 | # ignore nan values 20 | is_labeled = y_true[:, i] == y_true[:, i] 21 | rocauc_list.append( 22 | roc_auc_score(y_true[is_labeled, i], y_pred[is_labeled, i])) 23 | 24 | if len(rocauc_list) == 0: 25 | raise RuntimeError( 26 | 'No positively labeled data available. Cannot compute ROC-AUC.') 27 | 28 | return {'rocauc': sum(rocauc_list) / len(rocauc_list)} 29 | 30 | 31 | def eval_ap(y_true, y_pred): 32 | ''' 33 | compute Average Precision (AP) averaged across tasks 34 | ''' 35 | 36 | ap_list = [] 37 | 38 | for i in range(y_true.shape[1]): 39 | # AUC is only defined when there is at least one positive data. 40 | if np.sum(y_true[:, i] == 1) > 0 and np.sum(y_true[:, i] == 0) > 0: 41 | # ignore nan values 42 | is_labeled = y_true[:, i] == y_true[:, i] 43 | ap = average_precision_score(y_true[is_labeled, i], 44 | y_pred[is_labeled, i]) 45 | 46 | ap_list.append(ap) 47 | 48 | if len(ap_list) == 0: 49 | raise RuntimeError( 50 | 'No positively labeled data available. 
Cannot compute Average Precision.') 51 | 52 | return {'ap': sum(ap_list) / len(ap_list)} 53 | 54 | 55 | def eval_rmse(y_true, y_pred): 56 | ''' 57 | compute RMSE score averaged across tasks 58 | ''' 59 | rmse_list = [] 60 | 61 | for i in range(y_true.shape[1]): 62 | # ignore nan values 63 | is_labeled = y_true[:, i] == y_true[:, i] 64 | rmse_list.append(np.sqrt( 65 | ((y_true[is_labeled, i] - y_pred[is_labeled, i]) ** 2).mean())) 66 | 67 | return {'rmse': sum(rmse_list) / len(rmse_list)} 68 | 69 | 70 | def eval_acc(y_true, y_pred): 71 | acc_list = [] 72 | 73 | for i in range(y_true.shape[1]): 74 | is_labeled = y_true[:, i] == y_true[:, i] 75 | correct = y_true[is_labeled, i] == y_pred[is_labeled, i] 76 | acc_list.append(float(np.sum(correct)) / len(correct)) 77 | 78 | return {'acc': sum(acc_list) / len(acc_list)} 79 | 80 | 81 | def eval_F1(seq_ref, seq_pred): 82 | # ''' 83 | # compute F1 score averaged over samples 84 | # ''' 85 | 86 | precision_list = [] 87 | recall_list = [] 88 | f1_list = [] 89 | 90 | for l, p in zip(seq_ref, seq_pred): 91 | label = set(l) 92 | prediction = set(p) 93 | true_positive = len(label.intersection(prediction)) 94 | false_positive = len(prediction - label) 95 | false_negative = len(label - prediction) 96 | 97 | if true_positive + false_positive > 0: 98 | precision = true_positive / (true_positive + false_positive) 99 | else: 100 | precision = 0 101 | 102 | if true_positive + false_negative > 0: 103 | recall = true_positive / (true_positive + false_negative) 104 | else: 105 | recall = 0 106 | if precision + recall > 0: 107 | f1 = 2 * precision * recall / (precision + recall) 108 | else: 109 | f1 = 0 110 | 111 | precision_list.append(precision) 112 | recall_list.append(recall) 113 | f1_list.append(f1) 114 | 115 | return {'precision': np.average(precision_list), 116 | 'recall': np.average(recall_list), 117 | 'F1': np.average(f1_list)} 118 | -------------------------------------------------------------------------------- /spexphormer/loader/dataset/Amazon_with_split.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from typing import Callable, Optional 3 | 4 | import torch 5 | 6 | from torch_geometric.data import InMemoryDataset, download_url 7 | from torch_geometric.io import read_npz 8 | 9 | 10 | class AmazonWithSplit(InMemoryDataset): 11 | r"""The Amazon Computers and Amazon Photo networks from the 12 | `"Pitfalls of Graph Neural Network Evaluation" 13 | `_ paper. 14 | Nodes represent goods and edges represent that two goods are frequently 15 | bought together. 16 | Given product reviews as bag-of-words node features, the task is to 17 | map goods to their respective product category. 18 | 19 | Args: 20 | root (string): Root directory where the dataset should be saved. 21 | name (string): The name of the dataset (:obj:`"Computers"`, 22 | :obj:`"Photo"`). 23 | transform (callable, optional): A function/transform that takes in an 24 | :obj:`torch_geometric.data.Data` object and returns a transformed 25 | version. The data object will be transformed before every access. 26 | (default: :obj:`None`) 27 | pre_transform (callable, optional): A function/transform that takes in 28 | an :obj:`torch_geometric.data.Data` object and returns a 29 | transformed version. The data object will be transformed before 30 | being saved to disk. 
(default: :obj:`None`) 31 | """ 32 | 33 | url = 'https://github.com/shchur/gnn-benchmark/raw/master/data/npz/' 34 | 35 | def __init__( 36 | self, 37 | root: str, 38 | name: str, 39 | transform: Optional[Callable] = None, 40 | pre_transform: Optional[Callable] = None, 41 | ): 42 | self.name = name.lower() 43 | assert self.name in ['computers', 'photo'] 44 | super().__init__(root, transform, pre_transform) 45 | self.data, self.slices = torch.load(self.processed_paths[0]) 46 | 47 | self.train_mask = self.data.train_mask 48 | self.val_mask = self.data.val_mask 49 | self.test_mask = self.data.test_mask 50 | 51 | @property 52 | def raw_dir(self) -> str: 53 | return osp.join(self.root, self.name.capitalize(), 'raw') 54 | 55 | @property 56 | def processed_dir(self) -> str: 57 | return osp.join(self.root, self.name.capitalize(), 'processed') 58 | 59 | @property 60 | def raw_file_names(self) -> str: 61 | return f'amazon_electronics_{self.name.lower()}.npz' 62 | 63 | @property 64 | def processed_file_names(self) -> str: 65 | return 'data.pt' 66 | 67 | def download(self): 68 | download_url(self.url + self.raw_file_names, self.raw_dir) 69 | 70 | def process(self): 71 | data = read_npz(self.raw_paths[0]) 72 | data = data if self.pre_transform is None else self.pre_transform(data) 73 | 74 | y = data.y 75 | nclass = torch.unique(y).shape[0] 76 | 77 | percls_trn = int(round(0.6 * len(y) / nclass)) 78 | val_lb = int(round(0.2 * len(y))) 79 | 80 | indices = [] 81 | for i in range(nclass): 82 | index = (y == i).nonzero().view(-1) 83 | index = index[torch.randperm(index.size(0), device=index.device)] 84 | indices.append(index) 85 | 86 | train_index = torch.cat([i[:percls_trn] for i in indices], dim=0) 87 | rest_index = torch.cat([i[percls_trn:] for i in indices], dim=0) 88 | rest_index = rest_index[torch.randperm(rest_index.size(0))] 89 | valid_index = rest_index[:val_lb] 90 | test_index = rest_index[val_lb:] 91 | 92 | train_mask = torch.zeros(data.y.shape, dtype=torch.bool) 93 | train_mask[train_index] = True 94 | val_mask = torch.zeros(data.y.shape, dtype=torch.bool) 95 | val_mask[valid_index] = True 96 | test_mask = torch.zeros(data.y.shape, dtype=torch.bool) 97 | test_mask[test_index] = True 98 | 99 | data.train_mask = train_mask 100 | data.val_mask = val_mask 101 | data.test_mask = test_mask 102 | 103 | data, slices = self.collate([data]) 104 | torch.save((data, slices), self.processed_paths[0]) 105 | 106 | def __repr__(self) -> str: 107 | return f'{self.__class__.__name__}{self.name.capitalize()}()' -------------------------------------------------------------------------------- /spexphormer/loader/dataset/HeterophilousGraphDataset.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from typing import Callable, Optional 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from torch_geometric.data import Data, InMemoryDataset, download_url 8 | from torch_geometric.utils import to_undirected 9 | 10 | 11 | class HeterophilousGraphDataset(InMemoryDataset): 12 | r"""The heterophilous graphs :obj:`"Roman-empire"`, 13 | :obj:`"Amazon-ratings"`, :obj:`"Minesweeper"`, :obj:`"Tolokers"` and 14 | :obj:`"Questions"` from the `"A Critical Look at the Evaluation of GNNs 15 | under Heterophily: Are We Really Making Progress?" 16 | `_ paper. 17 | 18 | Args: 19 | root (str): Root directory where the dataset should be saved. 
20 | name (str): The name of the dataset (:obj:`"Roman-empire"`, 21 | :obj:`"Amazon-ratings"`, :obj:`"Minesweeper"`, :obj:`"Tolokers"`, 22 | :obj:`"Questions"`). 23 | transform (callable, optional): A function/transform that takes in an 24 | :obj:`torch_geometric.data.Data` object and returns a transformed 25 | version. The data object will be transformed before every access. 26 | (default: :obj:`None`) 27 | pre_transform (callable, optional): A function/transform that takes in 28 | an :obj:`torch_geometric.data.Data` object and returns a 29 | transformed version. The data object will be transformed before 30 | being saved to disk. (default: :obj:`None`) 31 | force_reload (bool, optional): Whether to re-process the dataset. 32 | (default: :obj:`False`) 33 | 34 | **STATS:** 35 | 36 | .. list-table:: 37 | :widths: 10 10 10 10 10 38 | :header-rows: 1 39 | 40 | * - Name 41 | - #nodes 42 | - #edges 43 | - #features 44 | - #classes 45 | * - Roman-empire 46 | - 22,662 47 | - 32,927 48 | - 300 49 | - 18 50 | * - Amazon-ratings 51 | - 24,492 52 | - 93,050 53 | - 300 54 | - 5 55 | * - Minesweeper 56 | - 10,000 57 | - 39,402 58 | - 7 59 | - 2 60 | * - Tolokers 61 | - 11,758 62 | - 519,000 63 | - 10 64 | - 2 65 | * - Questions 66 | - 48,921 67 | - 153,540 68 | - 301 69 | - 2 70 | """ 71 | url = ('https://github.com/yandex-research/heterophilous-graphs/raw/' 72 | 'main/data') 73 | 74 | def __init__( 75 | self, 76 | root: str, 77 | name: str, 78 | transform: Optional[Callable] = None, 79 | pre_transform: Optional[Callable] = None 80 | ) -> None: 81 | self.name = name.lower().replace('-', '_') 82 | assert self.name in [ 83 | 'roman_empire', 84 | 'amazon_ratings', 85 | 'minesweeper', 86 | 'tolokers', 87 | 'questions', 88 | ] 89 | 90 | super().__init__(root, transform, pre_transform) 91 | # self.load(self.processed_paths[0]) 92 | self.data, self.slices = torch.load(self.processed_paths[0]) 93 | 94 | self.train_mask = self.data.train_mask 95 | self.val_mask = self.data.val_mask 96 | self.test_mask = self.data.test_mask 97 | 98 | @property 99 | def raw_dir(self) -> str: 100 | return osp.join(self.root, self.name, 'raw') 101 | 102 | @property 103 | def processed_dir(self) -> str: 104 | return osp.join(self.root, self.name, 'processed') 105 | 106 | @property 107 | def raw_file_names(self) -> str: 108 | return f'{self.name}.npz' 109 | 110 | @property 111 | def processed_file_names(self) -> str: 112 | return 'data.pt' 113 | 114 | def download(self) -> None: 115 | download_url(f'{self.url}/{self.name}.npz', self.raw_dir) 116 | 117 | def process(self) -> None: 118 | raw = np.load(self.raw_paths[0], 'r') 119 | x = torch.from_numpy(raw['node_features']) 120 | y = torch.from_numpy(raw['node_labels']) 121 | edge_index = torch.from_numpy(raw['edges']).t().contiguous() 122 | edge_index = to_undirected(edge_index, num_nodes=x.size(0)) 123 | train_mask = torch.from_numpy(raw['train_masks']).t()[:, 0].squeeze().contiguous() 124 | val_mask = torch.from_numpy(raw['val_masks']).t()[:, 0].squeeze().contiguous() 125 | test_mask = torch.from_numpy(raw['test_masks']).t()[:, 0].squeeze().contiguous() 126 | 127 | data = Data(x=x, y=y, edge_index=edge_index, train_mask=train_mask, 128 | val_mask=val_mask, test_mask=test_mask) 129 | 130 | if self.pre_transform is not None: 131 | data = self.pre_transform(data) 132 | 133 | # self.save([data], self.processed_paths[0]) 134 | data, slices = self.collate([data]) 135 | torch.save((data, slices), self.processed_paths[0]) 136 | 137 | def __repr__(self) -> str: 138 | return 
f'{self.__class__.__name__}(name={self.name})' 139 | -------------------------------------------------------------------------------- /spexphormer/transform/transforms.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import numpy as np 5 | from torch_geometric.utils import subgraph 6 | from tqdm import tqdm 7 | 8 | 9 | def pre_transform_in_memory(dataset, transform_func, show_progress=False): 10 | """Pre-transform already loaded PyG dataset object. 11 | 12 | Apply transform function to a loaded PyG dataset object so that 13 | the transformed result is persistent for the lifespan of the object. 14 | This means the result is not saved to disk, as what PyG's `pre_transform` 15 | would do, but also the transform is applied only once and not at each 16 | data access as what PyG's `transform` hook does. 17 | 18 | Implementation is based on torch_geometric.data.in_memory_dataset.copy 19 | 20 | Args: 21 | dataset: PyG dataset object to modify 22 | transform_func: transformation function to apply to each data example 23 | show_progress: show tqdm progress bar 24 | """ 25 | if transform_func is None: 26 | return dataset 27 | 28 | data_list = [transform_func(dataset.get(i)) 29 | for i in tqdm(range(len(dataset)), 30 | disable=not show_progress, 31 | mininterval=10, 32 | miniters=len(dataset)//20)] 33 | data_list = list(filter(None, data_list)) 34 | 35 | dataset._indices = None 36 | dataset._data_list = data_list 37 | dataset.data, dataset.slices = dataset.collate(data_list) 38 | 39 | 40 | def generate_splits(data, g_split): 41 | n_nodes = len(data.x) 42 | train_mask = torch.zeros(n_nodes, dtype=bool) 43 | valid_mask = torch.zeros(n_nodes, dtype=bool) 44 | test_mask = torch.zeros(n_nodes, dtype=bool) 45 | idx = torch.randperm(n_nodes) 46 | val_num = test_num = int(n_nodes * (1 - g_split) / 2) 47 | train_mask[idx[val_num + test_num:]] = True 48 | valid_mask[idx[:val_num]] = True 49 | test_mask[idx[val_num:val_num + test_num]] = True 50 | data.train_mask = train_mask 51 | data.val_mask = valid_mask 52 | data.test_mask = test_mask 53 | return data 54 | 55 | 56 | def typecast_x(data, type_str): 57 | if type_str == 'float': 58 | data.x = data.x.float() 59 | elif type_str == 'long': 60 | data.x = data.x.long() 61 | else: 62 | raise ValueError(f"Unexpected type '{type_str}'.") 63 | return data 64 | 65 | 66 | def concat_x_and_pos(data): 67 | data.x = torch.cat((data.x, data.pos), 1) 68 | return data 69 | 70 | def move_node_feat_to_x(data): 71 | """For ogbn-proteins, move the attribute node_species to attribute x.""" 72 | data.x = data.node_species 73 | return data 74 | 75 | def proteins_node_one_hot(data): 76 | """For ogbn-proteins, move the attribute node_species to attribute x.""" 77 | uniq_ids = data.node_species.unique().tolist() 78 | index_mapping = {node_id: i for i, node_id in enumerate(uniq_ids)} 79 | data.x = data.node_species.apply_(index_mapping.get) 80 | data.x = torch.nn.functional.one_hot(data.x, num_classes=8).squeeze().contiguous() 81 | return data 82 | 83 | def clip_graphs_to_size(data, size_limit=5000): 84 | if hasattr(data, 'num_nodes'): 85 | N = data.num_nodes # Explicitly given number of nodes, e.g. ogbg-ppa 86 | else: 87 | N = data.x.shape[0] # Number of nodes, including disconnected nodes. 
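# Example: with the default size_limit=5000, a 6000-node graph keeps nodes
# [0, 5000) and edge_index is reduced to the subgraph induced on those nodes.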
88 | if N <= size_limit: 89 | return data 90 | else: 91 | logging.info(f' ...clip to {size_limit} a graph of size: {N}') 92 | if hasattr(data, 'edge_attr'): 93 | edge_attr = data.edge_attr 94 | else: 95 | edge_attr = None 96 | edge_index, edge_attr = subgraph(list(range(size_limit)), 97 | data.edge_index, edge_attr) 98 | if hasattr(data, 'x'): 99 | data.x = data.x[:size_limit] 100 | data.num_nodes = size_limit 101 | else: 102 | data.num_nodes = size_limit 103 | if hasattr(data, 'node_is_attributed'): # for ogbg-code2 dataset 104 | data.node_is_attributed = data.node_is_attributed[:size_limit] 105 | data.node_dfs_order = data.node_dfs_order[:size_limit] 106 | data.node_depth = data.node_depth[:size_limit] 107 | data.edge_index = edge_index 108 | if hasattr(data, 'edge_attr'): 109 | data.edge_attr = edge_attr 110 | return data 111 | 112 | 113 | def add_layer_edge_indices(data, dir, layers): 114 | device = data.edge_index.device 115 | for i in range(layers): 116 | # layer_edge_index = torch.load(dir+f'/edge_index_layer_{i}.pt', map_location=device) 117 | for j in range(50): 118 | layer_edge_index = np.load(dir+f'/edge_index_layer_{i}_sample_{j}.npy') 119 | layer_edge_attr = np.load(dir+f'/edge_attr_layer_{i}_sample_{j}.npy') 120 | setattr(data, f'edge_index_layer_{i}_sample_{j}', torch.from_numpy(layer_edge_index).to(device)) 121 | setattr(data, f'edge_attr_layer_{i}_sample_{j}', torch.from_numpy(layer_edge_attr).to(device)) 122 | 123 | return data 124 | -------------------------------------------------------------------------------- /spexphormer/loader/dataset/Pokec.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import os 3 | from typing import Callable, Optional 4 | import scipy 5 | import numpy as np 6 | 7 | import torch 8 | 9 | from torch_geometric.data import Data, InMemoryDataset 10 | 11 | def rand_train_test_idx(label, train_prop=.5, valid_prop=.25, ignore_negative=True): 12 | """ randomly splits label into train/valid/test splits """ 13 | if ignore_negative: 14 | labeled_nodes = torch.where(label != -1)[0] 15 | else: 16 | labeled_nodes = label 17 | 18 | # To have a fixed splits: 19 | np.random.seed(123) 20 | 21 | n = labeled_nodes.shape[0] 22 | train_num = int(n * train_prop) 23 | valid_num = int(n * valid_prop) 24 | 25 | perm = torch.as_tensor(np.random.permutation(n)) 26 | 27 | train_indices = perm[:train_num] 28 | val_indices = perm[train_num:train_num + valid_num] 29 | test_indices = perm[train_num + valid_num:] 30 | 31 | if not ignore_negative: 32 | return train_indices, val_indices, test_indices 33 | 34 | train_idx = labeled_nodes[train_indices] 35 | valid_idx = labeled_nodes[val_indices] 36 | test_idx = labeled_nodes[test_indices] 37 | 38 | return train_idx, valid_idx, test_idx 39 | 40 | class Pokec(InMemoryDataset): 41 | 42 | # you can download from here manually, google drive downloader did not work :( 43 | url = 'https://drive.google.com/file/d/1dNs5E7BrWJbgcHeQ_zuy5Ozp2tRCWG0y' 44 | 45 | def __init__( 46 | self, 47 | root: str, 48 | name: str, 49 | transform: Optional[Callable] = None, 50 | pre_transform: Optional[Callable] = None, 51 | ): 52 | self.name = name.lower() 53 | super().__init__(root, transform, pre_transform) 54 | self.data, self.slices = torch.load(self.processed_paths[0]) 55 | 56 | self.train_mask = self.data.train_mask 57 | self.val_mask = self.data.val_mask 58 | self.test_mask = self.data.test_mask 59 | 60 | @property 61 | def raw_dir(self) -> str: 62 | return osp.join(self.root, 
self.name.capitalize(), 'raw')
63 | 
64 |     @property
65 |     def processed_dir(self) -> str:
66 |         return osp.join(self.root, self.name.capitalize(), 'processed')
67 | 
68 |     @property
69 |     def raw_file_names(self) -> str:
70 |         return 'pokec.mat'
71 | 
72 |     @property
73 |     def processed_file_names(self) -> str:
74 |         return 'data.pt'
75 | 
76 |     def download(self):
77 |         # I tried a few libraries to download from Google Drive, but none of them worked, so please download it manually from:
78 |         # https://drive.google.com/file/d/1dNs5E7BrWJbgcHeQ_zuy5Ozp2tRCWG0y
79 |         pass
80 | 
81 |     def process(self):
82 |         """ requires pokec.mat """
83 |         if not osp.exists(f'{self.raw_dir}/pokec.mat'):
84 |             raise Exception("Please download the pokec.mat manually from https://drive.google.com/file/d/1dNs5E7BrWJbgcHeQ_zuy5Ozp2tRCWG0y")
85 | 
86 |         fulldata = scipy.io.loadmat(f'{self.raw_dir}/pokec.mat')
87 |         edge_index = fulldata['edge_index']
88 |         node_feat = fulldata['node_feat']
89 |         label = fulldata['label']
90 | 
91 |         edge_index = torch.tensor(edge_index, dtype=torch.long)
92 |         node_feat = torch.tensor(node_feat).float()
93 |         label = torch.tensor(label, dtype=torch.long).view(-1, 1)
94 | 
95 |         data = Data(x = node_feat, edge_index=edge_index, y=label)
96 | 
97 | 
98 |         train_prop = 0.1
99 |         val_prop = 0.1
100 |         split_dir = f'{self.raw_dir}/split_{train_prop}_{val_prop}'
101 |         tensor_split_idx = {}
102 |         if osp.exists(split_dir):
103 |             tensor_split_idx['train'] = torch.as_tensor(np.loadtxt(split_dir + '/pokec_train.txt'), dtype=torch.long)
104 |             tensor_split_idx['valid'] = torch.as_tensor(np.loadtxt(split_dir + '/pokec_valid.txt'), dtype=torch.long)
105 |             tensor_split_idx['test'] = torch.as_tensor(np.loadtxt(split_dir + '/pokec_test.txt'), dtype=torch.long)
106 |         else:
107 |             os.makedirs(split_dir)
108 |             tensor_split_idx['train'], tensor_split_idx['valid'], tensor_split_idx['test'] \
109 |                 = rand_train_test_idx(data.y, train_prop=train_prop, valid_prop=val_prop)
110 |             np.savetxt(split_dir + '/pokec_train.txt', tensor_split_idx['train'], fmt='%d')
111 |             np.savetxt(split_dir + '/pokec_valid.txt', tensor_split_idx['valid'], fmt='%d')
112 |             np.savetxt(split_dir + '/pokec_test.txt', tensor_split_idx['test'], fmt='%d')
113 | 
114 |         train_mask = torch.zeros(data.y.shape, dtype=torch.bool)
115 |         train_mask[tensor_split_idx['train']] = True
116 |         val_mask = torch.zeros(data.y.shape, dtype=torch.bool)
117 |         val_mask[tensor_split_idx['valid']] = True
118 |         test_mask = torch.zeros(data.y.shape, dtype=torch.bool)
119 |         test_mask[tensor_split_idx['test']] = True
120 | 
121 |         data.train_mask = train_mask
122 |         data.val_mask = val_mask
123 |         data.test_mask = test_mask
124 | 
125 |         data, slices = self.collate([data])
126 |         torch.save((data, slices), self.processed_paths[0])
127 | 
128 |     def __repr__(self) -> str:
129 |         return f'{self.__class__.__name__}{self.name.capitalize()}()'
--------------------------------------------------------------------------------
/spexphormer/layer/ASE_Attention.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | 
5 | import torch.nn.functional as F
6 | from pathlib import Path
7 | 
8 | from torch_geometric.graphgym.config import cfg
9 | from torch_geometric.graphgym.register import register_layer
10 | 
11 | 
12 | def separate_by_node(edges, scores):
13 |     num_nodes = edges.max().item() + 1
14 |     neighbors = [[] for _ in range(num_nodes)]
15 |     neighbors_edge_idx = [[] for _ in range(num_nodes)]
16 |     P = [[] for _ 
in range(num_nodes)] 17 | sum_scores = torch.zeros(num_nodes, dtype=scores.dtype) 18 | rev_sum_scores = torch.zeros(num_nodes, dtype=scores.dtype) 19 | 20 | for i in range(edges.size(1)): 21 | u, v = edges[:, i] 22 | neighbors[v].append(u.item()) 23 | neighbors_edge_idx[v].append(i) 24 | w = scores[i].item() 25 | sum_scores[v] += w 26 | rev_sum_scores[u] += w 27 | P[v].append(w) 28 | 29 | return neighbors, neighbors_edge_idx, P, sum_scores.tolist(), rev_sum_scores.tolist() 30 | 31 | 32 | class ASE_Attention_Layer(nn.Module): 33 | 34 | def __init__(self, in_dim, out_dim, num_heads, layer_idx, use_bias=False, dim_edge=None, use_virt_nodes=False): 35 | super().__init__() 36 | 37 | self.out_dim = out_dim // num_heads 38 | self.num_heads = num_heads 39 | self.use_virt_nodes = use_virt_nodes 40 | self.use_bias = use_bias 41 | self.layer_idx = layer_idx 42 | 43 | if dim_edge is None: 44 | dim_edge = in_dim 45 | 46 | self.QKV = nn.Linear(in_dim, self.out_dim * num_heads * 3, bias=use_bias) 47 | self.V_scale = nn.Parameter(data=torch.Tensor([0.25]), requires_grad=True) 48 | 49 | self.use_edge_attr = cfg.gt.use_edge_feats 50 | if self.use_edge_attr: 51 | self.E1 = nn.Linear(dim_edge, self.out_dim * num_heads, bias=use_bias) 52 | self.E2 = nn.Linear(dim_edge, num_heads, bias=True) 53 | 54 | self.T = 1.0 55 | 56 | def propagate_attention(self, batch, edge_index): 57 | src = batch.K_h[edge_index[0].to(torch.long)] # (num edges) x num_heads x out_dim 58 | dest = batch.Q_h[edge_index[1].to(torch.long)] # (num edges) x num_heads x out_dim 59 | score = torch.einsum('ehd,ehd->eh', src, dest) # Efficient batch matrix multiplication 60 | 61 | # Scale scores by sqrt(d) 62 | score = score / np.sqrt(self.out_dim) 63 | 64 | if self.use_edge_attr: 65 | score = score * batch.E.sum(-1) # (num real edges) x num_heads 66 | score = score.unsqueeze(-1) + batch.E2 67 | else: 68 | score = score.unsqueeze(-1) 69 | 70 | score = score / self.T 71 | score = torch.exp(score.clamp(-8, 8)) 72 | 73 | 74 | # Apply attention score to each source node to create edge messages 75 | msg = batch.V_h[edge_index[0].to(torch.long)] * score # (num real edges) x num_heads x out_dim 76 | 77 | # Add-up real msgs in destination nodes as given by batch.edge_index[1] 78 | batch.wV = torch.zeros_like(batch.V_h) # (num nodes in batch) x num_heads x out_dim 79 | batch.wV.index_add_(0, edge_index[1].to(torch.long), msg) # Using index_add_ as an alternative to scatter 80 | 81 | # Compute attention normalization coefficient 82 | batch.Z = score.new_zeros(batch.V_h.size(0), self.num_heads, 1) # (num nodes in batch) x num_heads x 1 83 | batch.Z.index_add_(0, edge_index[1].to(torch.long), score) # Using index_add_ as an alternative to scatter 84 | 85 | if cfg.train.saving_epoch: 86 | new_score = score/(batch.Z[edge_index[1]]) 87 | score_np = new_score.cpu().detach().numpy() 88 | Path(f'Attention_scores/{cfg.dataset.name}').mkdir(parents=True, exist_ok=True) 89 | with open(f'Attention_scores/{cfg.dataset.name}/seed{cfg.seed}_h{self.out_dim}_layer_{self.layer_idx}.npy', 'wb') as f: 90 | np.save(f, score_np) 91 | 92 | 93 | def forward(self, batch): 94 | if cfg.train.cur_epoch >= cfg.train.temp_wait: 95 | self.T = max(cfg.train.temp_min, cfg.train.temp_rdc_ratio ** (cfg.train.cur_epoch - cfg.train.temp_wait)) 96 | edge_attr = batch.edge_attr 97 | edge_index = batch.edge_index 98 | h = batch.x 99 | num_node = batch.batch.shape[0] 100 | 101 | QKV_h = self.QKV(h).view(-1, self.num_heads, 3 * self.out_dim) 102 | batch.Q_h, batch.K_h, batch.V_h = 
torch.split(QKV_h, self.out_dim, dim=-1)
103 |         batch.V_h = F.normalize(batch.V_h, p=2.0, dim=-1) * self.V_scale
104 | 
105 |         if self.use_edge_attr:
106 |             if cfg.dataset.edge_encoder_name == 'TypeDictEdge2':
107 |                 E = self.E1(batch.edge_embeddings)[batch.edge_attr]
108 |                 E2 = self.E2(batch.edge_embeddings)[batch.edge_attr]
109 |             else:
110 |                 E = self.E1(edge_attr)
111 |                 E2 = self.E2(edge_attr)
112 | 
113 |             batch.E = E.view(-1, self.num_heads, self.out_dim)
114 |             batch.E2 = E2.view(-1, self.num_heads, 1)
115 | 
116 |         self.propagate_attention(batch, edge_index)
117 | 
118 |         # Normalize the weighted sum of values by the normalization coefficient
119 |         h_out = batch.wV / (batch.Z + 1e-6)
120 | 
121 |         # Reshape the output to combine the heads
122 |         h_out = h_out.view(-1, self.out_dim * self.num_heads)
123 | 
124 |         # Separate virtual node embeddings if virtual nodes are used
125 |         if self.use_virt_nodes:
126 |             batch.virt_h = h_out[num_node:]
127 |             h_out = h_out[:num_node]
128 | 
129 |         return h_out
130 | 
131 | 
132 | register_layer('ASE_Attention_Layer', ASE_Attention_Layer)
133 | 
--------------------------------------------------------------------------------
/spexphormer/layer/SpExphormer_full_layer.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torch_geometric.nn as pygnn
4 | 
5 | from spexphormer.layer.Exphormer_Attention import ExphormerAttention
6 | from spexphormer.layer.ASE_Attention import ASE_Attention_Layer
7 | from spexphormer.layer.Spexphormer_Attention import SpexphormerAttention
8 | import warnings
9 | 
10 | 
11 | 
12 | class AttentionLayer(nn.Module):
13 |     """
14 |     Attention layer wrapping one of the Exphormer-family attention variants.
15 |     """
16 | 
17 |     def __init__(self, dim_h, exphormer_model_type, num_heads,
18 |                  layer_idx, dropout=0.0, layer_norm=False,
19 |                  batch_norm=True, exp_edges_cfg=None):
20 | 
21 |         super().__init__()
22 | 
23 |         self.dim_h = dim_h
24 |         self.layer_norm = layer_norm
25 |         self.batch_norm = batch_norm
26 |         self.num_heads = num_heads
27 |         self.layer_idx = layer_idx
28 | 
29 | 
30 |         if exphormer_model_type == 'Exphormer':
31 |             self.self_attn = ExphormerAttention(dim_h, dim_h, num_heads, layer_idx=self.layer_idx,
32 |                                                 use_virt_nodes=exp_edges_cfg.num_virt_node > 0, use_bias=False)
33 |         elif exphormer_model_type == 'ASE':
34 |             if num_heads != 1:
35 |                 warnings.warn('The number of heads for the initial network should always be 1 if you want to use the attention scores for a final network')
36 |             self.self_attn = ASE_Attention_Layer(dim_h, dim_h, num_heads, layer_idx=self.layer_idx)
37 |         elif exphormer_model_type == 'Spexphormer':
38 |             self.self_attn = SpexphormerAttention(dim_h, dim_h, num_heads, layer_idx=self.layer_idx)
39 |         else:
40 |             raise ValueError(f"Unsupported exphormer model: "
41 |                              f"{exphormer_model_type}")
42 |         self.exphormer_model_type = exphormer_model_type
43 | 
44 |         if self.layer_norm and self.batch_norm:
45 |             raise ValueError("Cannot apply two types of normalization together")
46 | 
47 |         # Normalization for Self-Attention representation.
48 |         if self.layer_norm:
49 |             self.norm1_attn = pygnn.norm.GraphNorm(dim_h)
50 |         if self.batch_norm:
51 |             self.norm1_attn = nn.BatchNorm1d(dim_h)
52 |         self.dropout = nn.Dropout(dropout)
53 | 
54 |     def forward(self, batch):
55 |         h = batch.x
56 |         h_in1 = h  # for first residual connection
57 | 
58 |         h_attn = self.self_attn(batch)
59 |         h_attn = self.dropout(h_attn)
60 | 
61 |         if h_attn.shape == h_in1.shape:
62 |             h_attn = h_in1 + h_attn  # Residual connection. 
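# (For Spexphormer layers the shapes can differ: attention returns outputs
# only for the first n2 "next-layer" nodes, so the residual uses that prefix.)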
63 | else: 64 | h_attn = h_in1[:h_attn.shape[0]] + h_attn 65 | 66 | if self.layer_norm: 67 | h_attn = self.norm1_attn(h_attn, batch.batch) 68 | if self.batch_norm: 69 | h_attn = self.norm1_attn(h_attn) 70 | return h_attn 71 | 72 | 73 | class SpExphormerFullLayer(nn.Module): 74 | """Variants of the Exphormer 75 | """ 76 | 77 | def __init__(self, dim_h, 78 | layer_type, num_heads, layer_idx, dropout=0.0, 79 | layer_norm=False, batch_norm=True, use_ffn=True, 80 | exp_edges_cfg=None): 81 | super().__init__() 82 | 83 | self.dim_h = dim_h 84 | self.num_heads = num_heads 85 | self.layer_norm = layer_norm 86 | self.batch_norm = batch_norm 87 | self.layer_type = layer_type 88 | self.layer_idx = layer_idx 89 | self.use_ffn = use_ffn 90 | 91 | # Local message-passing models. 92 | self.attention_layer = [] 93 | 94 | if layer_type in {'Exphormer', 'Spexphormer', 'ASE'}: 95 | self.attention_layer = AttentionLayer(dim_h=dim_h, 96 | exphormer_model_type=layer_type, 97 | num_heads=self.num_heads, 98 | layer_idx=self.layer_idx, 99 | dropout=dropout, 100 | layer_norm=self.layer_norm, 101 | batch_norm=self.batch_norm, 102 | exp_edges_cfg = exp_edges_cfg) 103 | else: 104 | raise ValueError(f"Unsupported layer type: {layer_type}") 105 | 106 | self.activation = F.relu 107 | 108 | if self.use_ffn: 109 | # Feed Forward block. 110 | self.ff_linear1 = nn.Linear(dim_h, dim_h * 2) 111 | self.ff_linear2 = nn.Linear(dim_h * 2, dim_h) 112 | 113 | self.ff_dropout1 = nn.Dropout(dropout) 114 | self.ff_dropout2 = nn.Dropout(dropout) 115 | else: 116 | self.dropout = nn.Dropout(dropout) 117 | 118 | if self.layer_norm: 119 | # self.norm2 = pygnn.norm.LayerNorm(dim_h) 120 | self.norm2 = pygnn.norm.GraphNorm(dim_h) 121 | if self.batch_norm: 122 | self.norm2 = nn.BatchNorm1d(dim_h) 123 | 124 | def forward(self, batch): 125 | h = self.attention_layer(batch) 126 | 127 | if self.use_ffn: 128 | # Feed Forward block. 129 | h = h + self._ff_block(h) 130 | else: 131 | h = self.dropout(self.activation(h)) 132 | 133 | if self.layer_norm: 134 | h = self.norm2(h, batch.batch) 135 | if self.batch_norm: 136 | h = self.norm2(h) 137 | 138 | batch.x = h 139 | return batch 140 | 141 | def _ff_block(self, x): 142 | """Feed Forward block. 143 | """ 144 | x = self.ff_dropout1(self.activation(self.ff_linear1(x))) 145 | return self.ff_dropout2(self.ff_linear2(x)) 146 | 147 | def extra_repr(self): 148 | s = f'summary: dim_h={self.dim_h}, ' \ 149 | f'layer_type={self.layer_type}, ' \ 150 | f'heads={self.num_heads}' 151 | return s -------------------------------------------------------------------------------- /spexphormer/loader/dataset/malnet_tiny.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Callable, List 2 | 3 | import os 4 | import glob 5 | import os.path as osp 6 | 7 | import torch 8 | from torch_geometric.data import (InMemoryDataset, Data, download_url, 9 | extract_tar, extract_zip) 10 | from torch_geometric.utils import remove_isolated_nodes 11 | 12 | """ 13 | This is a local copy of MalNetTiny class from PyG 14 | https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/datasets/malnet_tiny.py 15 | 16 | TODO: Delete and use PyG's version once it is part of a released version. 17 | At the time of writing this class is in the main PyG github branch but is not 18 | included in the current latest released version 2.0.2. 
19 | """ 20 | 21 | class MalNetTiny(InMemoryDataset): 22 | r"""The MalNet Tiny dataset from the 23 | `"A Large-Scale Database for Graph Representation Learning" 24 | `_ paper. 25 | :class:`MalNetTiny` contains 5,000 malicious and benign software function 26 | call graphs across 5 different types. Each graph contains at most 5k nodes. 27 | 28 | Args: 29 | root (string): Root directory where the dataset should be saved. 30 | transform (callable, optional): A function/transform that takes in an 31 | :obj:`torch_geometric.data.Data` object and returns a transformed 32 | version. The data object will be transformed before every access. 33 | (default: :obj:`None`) 34 | pre_transform (callable, optional): A function/transform that takes in 35 | an :obj:`torch_geometric.data.Data` object and returns a 36 | transformed version. The data object will be transformed before 37 | being saved to disk. (default: :obj:`None`) 38 | pre_filter (callable, optional): A function that takes in an 39 | :obj:`torch_geometric.data.Data` object and returns a boolean 40 | value, indicating whether the data object should be included in the 41 | final dataset. (default: :obj:`None`) 42 | """ 43 | 44 | url = 'http://malnet.cc.gatech.edu/graph-data/malnet-graphs-tiny.tar.gz' 45 | # 70/10/20 train, val, test split by type 46 | split_url = 'http://malnet.cc.gatech.edu/split-info/split_info_tiny.zip' 47 | 48 | def __init__(self, root: str, transform: Optional[Callable] = None, 49 | pre_transform: Optional[Callable] = None, 50 | pre_filter: Optional[Callable] = None): 51 | super().__init__(root, transform, pre_transform, pre_filter) 52 | self.data, self.slices = torch.load(self.processed_paths[0]) 53 | 54 | @property 55 | def raw_file_names(self) -> List[str]: 56 | folders = ['addisplay', 'adware', 'benign', 'downloader', 'trojan'] 57 | return [osp.join('malnet-graphs-tiny', folder) for folder in folders] 58 | 59 | @property 60 | def processed_file_names(self) -> List[str]: 61 | return ['data.pt', 'split_dict.pt'] 62 | 63 | def download(self): 64 | path = download_url(self.url, self.raw_dir) 65 | extract_tar(path, self.raw_dir) 66 | os.unlink(path) 67 | path = download_url(self.split_url, self.raw_dir) 68 | extract_zip(path, self.raw_dir) 69 | os.unlink(path) 70 | 71 | def process(self): 72 | data_list = [] 73 | split_dict = {'train': [], 'valid': [], 'test': []} 74 | 75 | parse = lambda f: set([x.split('/')[-1] 76 | for x in f.read().split('\n')[:-1]]) # -1 for empty line at EOF 77 | split_dir = osp.join(self.raw_dir, 'split_info_tiny', 'type') 78 | with open(osp.join(split_dir, 'train.txt'), 'r') as f: 79 | train_names = parse(f) 80 | assert len(train_names) == 3500 81 | with open(osp.join(split_dir, 'val.txt'), 'r') as f: 82 | val_names = parse(f) 83 | assert len(val_names) == 500 84 | with open(osp.join(split_dir, 'test.txt'), 'r') as f: 85 | test_names = parse(f) 86 | assert len(test_names) == 1000 87 | 88 | for y, raw_path in enumerate(self.raw_paths): 89 | raw_path = osp.join(raw_path, os.listdir(raw_path)[0]) 90 | filenames = glob.glob(osp.join(raw_path, '*.edgelist')) 91 | 92 | for filename in filenames: 93 | with open(filename, 'r') as f: 94 | edges = f.read().split('\n')[5:-1] 95 | edge_index = [[int(s) for s in edge.split()] for edge in edges] 96 | edge_index = torch.tensor(edge_index).t().contiguous() 97 | # Remove isolated nodes, including those with only a self-loop 98 | edge_index = remove_isolated_nodes(edge_index)[0] 99 | num_nodes = int(edge_index.max()) + 1 100 | data = Data(edge_index=edge_index, y=y, 
num_nodes=num_nodes) 101 | data_list.append(data) 102 | 103 | ind = len(data_list) - 1 104 | graph_id = osp.splitext(osp.basename(filename))[0] 105 | if graph_id in train_names: 106 | split_dict['train'].append(ind) 107 | elif graph_id in val_names: 108 | split_dict['valid'].append(ind) 109 | elif graph_id in test_names: 110 | split_dict['test'].append(ind) 111 | else: 112 | raise ValueError(f'No split assignment for "{graph_id}".') 113 | 114 | if self.pre_filter is not None: 115 | data_list = [data for data in data_list if self.pre_filter(data)] 116 | 117 | if self.pre_transform is not None: 118 | data_list = [self.pre_transform(data) for data in data_list] 119 | 120 | torch.save(self.collate(data_list), self.processed_paths[0]) 121 | torch.save(split_dict, self.processed_paths[1]) 122 | 123 | def get_idx_split(self): 124 | return torch.load(self.processed_paths[1]) 125 | -------------------------------------------------------------------------------- /spexphormer/loader/dataset/peptides_functional.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os.path as osp 3 | import pickle 4 | import shutil 5 | 6 | import pandas as pd 7 | import torch 8 | from ogb.utils import smiles2graph 9 | from ogb.utils.torch_util import replace_numpy_with_torchtensor 10 | from ogb.utils.url import decide_download 11 | from torch_geometric.data import Data, InMemoryDataset, download_url 12 | from tqdm import tqdm 13 | 14 | 15 | class PeptidesFunctionalDataset(InMemoryDataset): 16 | def __init__(self, root='datasets', smiles2graph=smiles2graph, 17 | transform=None, pre_transform=None): 18 | """ 19 | PyG dataset of 15,535 peptides represented as their molecular graph 20 | (SMILES) with 10-way multi-task binary classification of their 21 | functional classes. 22 | 23 | The goal is to use the molecular representation of peptides instead 24 | of the amino acid sequence representation (the 'peptide_seq' field in the file, 25 | provided for possible baseline benchmarking but not used here) to test 26 | GNNs' representation capability. 27 | 28 | The 10 classes represent the following functional classes (in order): 29 | ['antifungal', 'cell_cell_communication', 'anticancer', 30 | 'drug_delivery_vehicle', 'antimicrobial', 'antiviral', 31 | 'antihypertensive', 'antibacterial', 'antiparasitic', 'toxic'] 32 | 33 | Args: 34 | root (string): Root directory where the dataset should be saved. 35 | smiles2graph (callable): A callable function that converts a SMILES 36 | string into a graph object. We use the OGB featurization. 37 | * The default smiles2graph requires rdkit to be installed * 38 | """ 39 | 40 | self.original_root = root 41 | self.smiles2graph = smiles2graph 42 | self.folder = osp.join(root, 'peptides-functional') 43 | 44 | self.url = 'https://www.dropbox.com/s/ol2v01usvaxbsr8/peptide_multi_class_dataset.csv.gz?dl=1' 45 | self.version = '701eb743e899f4d793f0e13c8fa5a1b4' # MD5 hash of the intended dataset file 46 | self.url_stratified_split = 'https://www.dropbox.com/s/j4zcnx2eipuo0xz/splits_random_stratified_peptide.pickle?dl=1' 47 | self.md5sum_stratified_split = '5a0114bdadc80b94fc7ae974f13ef061' 48 | 49 | # Check version and update if necessary. 50 | release_tag = osp.join(self.folder, self.version) 51 | if osp.isdir(self.folder) and (not osp.exists(release_tag)): 52 | print(f"{self.__class__.__name__} has been updated.") 53 | if input("Will you update the dataset now? 
(y/N)\n").lower() == 'y': 54 | shutil.rmtree(self.folder) 55 | 56 | super().__init__(self.folder, transform, pre_transform) 57 | self.data, self.slices = torch.load(self.processed_paths[0]) 58 | 59 | @property 60 | def raw_file_names(self): 61 | return 'peptide_multi_class_dataset.csv.gz' 62 | 63 | @property 64 | def processed_file_names(self): 65 | return 'geometric_data_processed.pt' 66 | 67 | def _md5sum(self, path): 68 | hash_md5 = hashlib.md5() 69 | with open(path, 'rb') as f: 70 | buffer = f.read() 71 | hash_md5.update(buffer) 72 | return hash_md5.hexdigest() 73 | 74 | def download(self): 75 | if decide_download(self.url): 76 | path = download_url(self.url, self.raw_dir) 77 | # Save to disk the MD5 hash of the downloaded file. 78 | hash = self._md5sum(path) 79 | if hash != self.version: 80 | raise ValueError("Unexpected MD5 hash of the downloaded file") 81 | open(osp.join(self.root, hash), 'w').close() 82 | # Download train/val/test splits. 83 | path_split1 = download_url(self.url_stratified_split, self.root) 84 | assert self._md5sum(path_split1) == self.md5sum_stratified_split 85 | else: 86 | print('Stop download.') 87 | exit(-1) 88 | 89 | def process(self): 90 | data_df = pd.read_csv(osp.join(self.raw_dir, 91 | 'peptide_multi_class_dataset.csv.gz')) 92 | smiles_list = data_df['smiles'] 93 | 94 | print('Converting SMILES strings into graphs...') 95 | data_list = [] 96 | for i in tqdm(range(len(smiles_list))): 97 | data = Data() 98 | 99 | smiles = smiles_list[i] 100 | graph = self.smiles2graph(smiles) 101 | 102 | assert (len(graph['edge_feat']) == graph['edge_index'].shape[1]) 103 | assert (len(graph['node_feat']) == graph['num_nodes']) 104 | 105 | data.__num_nodes__ = int(graph['num_nodes']) 106 | data.edge_index = torch.from_numpy(graph['edge_index']).to( 107 | torch.int64) 108 | data.edge_attr = torch.from_numpy(graph['edge_feat']).to( 109 | torch.int64) 110 | data.x = torch.from_numpy(graph['node_feat']).to(torch.int64) 111 | data.y = torch.Tensor([eval(data_df['labels'].iloc[i])]) 112 | 113 | data_list.append(data) 114 | 115 | if self.pre_transform is not None: 116 | data_list = [self.pre_transform(data) for data in data_list] 117 | 118 | data, slices = self.collate(data_list) 119 | 120 | print('Saving...') 121 | torch.save((data, slices), self.processed_paths[0]) 122 | 123 | def get_idx_split(self): 124 | """ Get dataset splits. 125 | 126 | Returns: 127 | Dict with 'train', 'val', 'test', splits indices. 
128 | """ 129 | split_file = osp.join(self.root, 130 | "splits_random_stratified_peptide.pickle") 131 | with open(split_file, 'rb') as f: 132 | splits = pickle.load(f) 133 | split_dict = replace_numpy_with_torchtensor(splits) 134 | return split_dict 135 | 136 | 137 | if __name__ == '__main__': 138 | dataset = PeptidesFunctionalDataset() 139 | print(dataset) 140 | print(dataset.data.edge_index) 141 | print(dataset.data.edge_index.shape) 142 | print(dataset.data.x.shape) 143 | print(dataset[100]) 144 | print(dataset[100].y) 145 | print(dataset.get_idx_split()) 146 | -------------------------------------------------------------------------------- /spexphormer/loader/dataset/aqsol_molecules.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import shutil 4 | import pickle 5 | 6 | import torch 7 | from tqdm import tqdm 8 | from torch_geometric.data import (InMemoryDataset, Data, download_url, 9 | extract_zip) 10 | from torch_geometric.utils import add_self_loops 11 | 12 | 13 | class AQSOL(InMemoryDataset): 14 | r"""The AQSOL dataset from Benchmarking GNNs (Dwivedi et al., 2020) is based on AqSolDB 15 | (Sorkun et al., 2019) which is a standardized database of 9,982 molecular graphs with 16 | their aqueous solubility values, collected from 9 different data sources. 17 | 18 | The aqueous solubility targets are collected from experimental measurements and standardized 19 | to LogS units in AqSolDB. These final values as the property to regress in the AQSOL dataset 20 | which is the resultant collection in 'Benchmarking GNNs' after filtering out few graphs 21 | with no bonds/edges and a small number of graphs with missing node feature values. 22 | 23 | Thus, the total molecular graphs are 9,823. For each molecular graph, the node features are the 24 | types f heavy atoms and the edge features are the types of bonds between them, similar as ZINC. 25 | 26 | Size of Dataset: 9,982 molecules. 27 | Split: Scaffold split (8:1:1) following same code as OGB. 28 | After cleaning: 7,831 train / 996 val / 996 test 29 | Number of (unique) atoms: 65 30 | Number of (unique) bonds: 5 31 | Performance Metric: MAE, same as ZINC 32 | 33 | Atom Dict: {'Br': 0, 'C': 1, 'N': 2, 'O': 3, 'Cl': 4, 'Zn': 5, 'F': 6, 'P': 7, 'S': 8, 'Na': 9, 'Al': 10, 34 | 'Si': 11, 'Mo': 12, 'Ca': 13, 'W': 14, 'Pb': 15, 'B': 16, 'V': 17, 'Co': 18, 'Mg': 19, 'Bi': 20, 'Fe': 21, 35 | 'Ba': 22, 'K': 23, 'Ti': 24, 'Sn': 25, 'Cd': 26, 'I': 27, 'Re': 28, 'Sr': 29, 'H': 30, 'Cu': 31, 'Ni': 32, 36 | 'Lu': 33, 'Pr': 34, 'Te': 35, 'Ce': 36, 'Nd': 37, 'Gd': 38, 'Zr': 39, 'Mn': 40, 'As': 41, 'Hg': 42, 'Sb': 37 | 43, 'Cr': 44, 'Se': 45, 'La': 46, 'Dy': 47, 'Y': 48, 'Pd': 49, 'Ag': 50, 'In': 51, 'Li': 52, 'Rh': 53, 38 | 'Nb': 54, 'Hf': 55, 'Cs': 56, 'Ru': 57, 'Au': 58, 'Sm': 59, 'Ta': 60, 'Pt': 61, 'Ir': 62, 'Be': 63, 'Ge': 64} 39 | 40 | Bond Dict: {'NONE': 0, 'SINGLE': 1, 'DOUBLE': 2, 'AROMATIC': 3, 'TRIPLE': 4} 41 | 42 | Args: 43 | root (string): Root directory where the dataset should be saved. 44 | transform (callable, optional): A function/transform that takes in an 45 | :obj:`torch_geometric.data.Data` object and returns a transformed 46 | version. The data object will be transformed before every access. 47 | (default: :obj:`None`) 48 | pre_transform (callable, optional): A function/transform that takes in 49 | an :obj:`torch_geometric.data.Data` object and returns a 50 | transformed version. The data object will be transformed before 51 | being saved to disk. 
(default: :obj:`None`) 52 | pre_filter (callable, optional): A function that takes in an 53 | :obj:`torch_geometric.data.Data` object and returns a boolean 54 | value, indicating whether the data object should be included in the 55 | final dataset. (default: :obj:`None`) 56 | """ 57 | 58 | url = 'https://www.dropbox.com/s/lzu9lmukwov12kt/aqsol_graph_raw.zip?dl=1' 59 | 60 | def __init__(self, root, split='train', transform=None, pre_transform=None, 61 | pre_filter=None): 62 | self.name = "AQSOL" 63 | assert split in ['train', 'val', 'test'] 64 | super().__init__(root, transform, pre_transform, pre_filter) 65 | path = osp.join(self.processed_dir, f'{split}.pt') 66 | self.data, self.slices = torch.load(path) 67 | 68 | 69 | @property 70 | def raw_file_names(self): 71 | return ['train.pickle', 'val.pickle', 'test.pickle'] 72 | 73 | @property 74 | def processed_file_names(self): 75 | return ['train.pt', 'val.pt', 'test.pt'] 76 | 77 | def download(self): 78 | shutil.rmtree(self.raw_dir) 79 | path = download_url(self.url, self.root) 80 | extract_zip(path, self.root) 81 | os.rename(osp.join(self.root, 'asqol_graph_raw'), self.raw_dir) 82 | os.unlink(path) 83 | 84 | def process(self): 85 | for split in ['train', 'val', 'test']: 86 | with open(osp.join(self.raw_dir, f'{split}.pickle'), 'rb') as f: 87 | graphs = pickle.load(f) 88 | 89 | indices = range(len(graphs)) 90 | 91 | pbar = tqdm(total=len(indices)) 92 | pbar.set_description(f'Processing {split} dataset') 93 | 94 | data_list = [] 95 | for idx in indices: 96 | graph = graphs[idx] 97 | 98 | """ 99 | Each `graph` is a tuple (x, edge_attr, edge_index, y) 100 | Shape of x : [num_nodes, 1] 101 | Shape of edge_attr : [num_edges] 102 | Shape of edge_index : [2, num_edges] 103 | Shape of y : [1] 104 | """ 105 | 106 | x = torch.LongTensor(graph[0]).unsqueeze(-1) 107 | edge_attr = torch.LongTensor(graph[1])#.unsqueeze(-1) 108 | edge_index = torch.LongTensor(graph[2]) 109 | y = torch.tensor(graph[3]) 110 | 111 | data = Data(edge_index=edge_index) 112 | 113 | if edge_index.shape[1] == 0: 114 | continue # skipping for graphs with no bonds/edges 115 | 116 | if data.num_nodes != len(x): 117 | continue # cleaning <10 graphs with this discrepancy 118 | 119 | data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, 120 | y=y) 121 | 122 | if self.pre_filter is not None and not self.pre_filter(data): 123 | continue 124 | 125 | if self.pre_transform is not None: 126 | data = self.pre_transform(data) 127 | 128 | data_list.append(data) 129 | pbar.update(1) 130 | 131 | pbar.close() 132 | torch.save(self.collate(data_list), 133 | osp.join(self.processed_dir, f'{split}.pt')) 134 | -------------------------------------------------------------------------------- /spexphormer/train/neighbor_sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import torch 4 | from torch_geometric.data import Data 5 | from bottleneck import argpartition 6 | 7 | def sampler_uniform(adj_eidx, P, P_inv, k): 8 | deg = adj_eidx.shape[0] 9 | if k > deg: 10 | extra = np.random.choice(adj_eidx, size=k - deg, replace=True) 11 | return np.concatenate([adj_eidx, extra]) 12 | idx = np.random.choice(adj_eidx, size=k, replace=False) 13 | return idx 14 | 15 | def sampler_reservoir(adj_eidx, P, P_inv, k): 16 | deg = adj_eidx.shape[0] 17 | if k > deg: 18 | extra = np.random.choice(adj_eidx, size=k - deg, replace=True, p=P) 19 | return np.concatenate([adj_eidx, extra]) 20 | rsv = -np.log(np.random.rand(deg)) * P_inv 
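# The k smallest exponential "race" keys -log(U)/p form a weighted sample of k distinct neighbors, drawn without replacement with probability proportional to the attention scores P (Efraimidis-Spirakis weighted reservoir sampling).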
21 | idx = argpartition(rsv, k-1)[:k] 22 | return adj_eidx[idx] 23 | 24 | # The maximum could be cached for each node to avoid repeating this selection; caching is not implemented here. 25 | def sampler_get_max(adj_eidx, P, P_inv, k): 26 | deg = adj_eidx.shape[0] 27 | if k > deg: 28 | extra = np.random.choice(adj_eidx, size=k - deg, replace=True, p=P) 29 | return np.concatenate([adj_eidx, extra]) 30 | idx = argpartition(P, P.shape[0]-k)[P.shape[0]-k:] 31 | return adj_eidx[idx] 32 | 33 | 34 | class NeighborSampler(): 35 | def __init__(self, original_graph, edge_index, edge_attr, P, adj_eidx, deg, num_layers, sampler='graph_reservoir') -> None: 36 | """ 37 | Initialize the NeighborSampler. 38 | Parameters: 39 | original_graph (torch_geometric.data.Data): The original graph containing n nodes and m edges. 40 | edge_index (torch.Tensor): Tensor of shape (2, m) representing the edge indices. 41 | edge_attr (torch.Tensor): Tensor of shape (m, 1) representing the edge attributes. 42 | P (list): List of length n, where each element is a 2D numpy array of size (num_layers, d_i). 43 | P[i][l][j] represents the attention score of the j-th neighbor of node i in layer l. 44 | adj_eidx (list): List of length n, where adj_eidx[i][j] is the index of the edge between node i 45 | and its j-th neighbor in the edge_index. 46 | deg (list): List representing the number of neighbors to sample per layer. 47 | num_layers (int): Number of layers in the model. 48 | sampler (str): Sampling strategy, named in two parts: the prefix ('graph' or 'batch') indicates 49 | whether edges are sampled once for the whole graph before batching or re-sampled 50 | per batch, and the suffix indicates the sampling scheme. Default is 'graph_reservoir'. 51 | Returns: 52 | None 53 | """ 54 | 55 | self.data = original_graph 56 | self.edge_index, self.edge_attr, self.P, self.adj_eidx, self.deg, self.num_layers = \ 57 | edge_index, edge_attr, P, adj_eidx, deg, num_layers 58 | 59 | self.sampler = sampler 60 | self.P = tuple(self.P) 61 | # 1/P values to save some time in calculations 62 | self.P_inv = tuple([1.0/(p+1e-8) for p in self.P]) 63 | 64 | if self.sampler in ['graph_reservoir', 'graph_uniform']: 65 | self.sample_on_batching = False 66 | self.sampled_edge_ids = [] 67 | self.all_nodes_neighbor_sampler() 68 | else: 69 | self.sample_on_batching = True 70 | 71 | 72 | def expand_neighborhood_layer(self, layer_nodes, layer_idx): 73 | k = self.deg[layer_idx] 74 | 75 | if self.sample_on_batching: 76 | if self.sampler == 'batch_reservoir': 77 | layer_edge_ids = [sampler_reservoir(self.adj_eidx[v], self.P[v][layer_idx], self.P_inv[v][layer_idx], k) for v in layer_nodes] 78 | elif self.sampler == 'batch_uniform': 79 | layer_edge_ids = [sampler_uniform(self.adj_eidx[v], self.P[v][layer_idx], self.P_inv[v][layer_idx], k) for v in layer_nodes] 80 | else: 81 | raise ValueError(f'Unknown sampler {self.sampler}') 82 | layer_edge_ids = np.concatenate(layer_edge_ids) 83 | else: 84 | layer_edge_ids = self.sampled_edge_ids[layer_idx][layer_nodes].flatten() 85 | 86 | layer_edge_index = self.edge_index[:, layer_edge_ids] 87 | 88 | new_nodes = layer_edge_index[0, :].squeeze() 89 | new_layer_nodes = np.concatenate((layer_nodes, new_nodes)) 90 | new_layer_nodes = pd.unique(new_layer_nodes) 91 | layer_edge_index = torch.from_numpy(layer_edge_index) 92 | 93 | return new_layer_nodes, layer_edge_index, layer_edge_ids 94 | 95 | 96 | def make_batch(self, core_nodes): 97 | edges = [] 98 | nodes = [] 99 | edge_ids = [] 100 | layer_nodes = np.array(core_nodes) 101 | for layer_idx in 
reversed(range(self.num_layers)): 102 | layer_nodes, layer_edges, layer_edge_ids = self.expand_neighborhood_layer(layer_nodes, layer_idx) 103 | edges.append(layer_edges) 104 | nodes.append(layer_nodes) 105 | edge_ids.append(layer_edge_ids) 106 | 107 | edges = edges[::-1] 108 | nodes = nodes[::-1] 109 | edge_ids = edge_ids[::-1] 110 | data = Data(x=self.data.x[nodes[0]]) 111 | 112 | max_idx = np.max(nodes[0]) 113 | index_mapping = torch.zeros(max_idx+1, dtype=torch.int) 114 | index_mapping[nodes[0]] = torch.arange(nodes[0].shape[0], dtype=torch.int) 115 | for l in range(self.num_layers): 116 | mapped_edge_indexes = index_mapping[edges[l]].long() 117 | setattr(data, f'edge_index_layer_{l}', mapped_edge_indexes) 118 | setattr(data, f'edge_type_layer_{l}', self.edge_attr[edge_ids[l]]) 119 | 120 | num_layer_nodes = [len(layer_nodes) for layer_nodes in nodes] 121 | num_layer_nodes.append(len(core_nodes)) 122 | data.num_layer_nodes = torch.Tensor(num_layer_nodes).int() 123 | # print(f'data.num_layer_nodes: {data.num_layer_nodes}') 124 | data.y = self.data.y[core_nodes] 125 | 126 | return data 127 | 128 | 129 | def all_nodes_neighbor_sampler(self, layer_node_deg=None): 130 | if layer_node_deg is None: 131 | layer_node_deg = self.deg 132 | self.sampled_edge_ids = [] 133 | num_nodes = len(self.P) 134 | for layer_idx in range(self.num_layers): 135 | if self.sampler == 'graph_reservoir': 136 | layer_edge_ids = [sampler_reservoir(self.adj_eidx[v], self.P[v][layer_idx], self.P_inv[v][layer_idx], layer_node_deg[layer_idx]) for v in range(num_nodes)] 137 | elif self.sampler == 'graph_max': 138 | layer_edge_ids = [sampler_get_max(self.adj_eidx[v], self.P[v][layer_idx], self.P_inv[v][layer_idx], layer_node_deg[layer_idx]) for v in range(num_nodes)] 139 | elif self.sampler == 'graph_uniform': 140 | layer_edge_ids = [sampler_uniform(self.adj_eidx[v], self.P[v][layer_idx], self.P_inv[v][layer_idx], layer_node_deg[layer_idx]) for v in range(num_nodes)] 141 | else: 142 | raise ValueError(f'unknown neighbor sampler {self.sampler}') 143 | layer_edge_ids = np.stack(layer_edge_ids) 144 | self.sampled_edge_ids.append(layer_edge_ids) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import torch 4 | import logging 5 | 6 | from spexphormer.optimizer.extra_optimizers import ExtendedSchedulerConfig 7 | 8 | from torch_geometric.graphgym.cmd_args import parse_args 9 | from torch_geometric.graphgym.config import (cfg, dump_cfg, 10 | set_cfg, load_cfg, 11 | makedirs_rm_exist) 12 | from torch_geometric.graphgym.loader import create_loader 13 | from torch_geometric.graphgym.logger import set_printing 14 | from torch_geometric.graphgym.optim import create_optimizer, \ 15 | create_scheduler, OptimizerConfig 16 | from torch_geometric.graphgym.model_builder import create_model 17 | from torch_geometric.graphgym.train import train 18 | from torch_geometric.graphgym.utils.agg_runs import agg_runs 19 | from torch_geometric.graphgym.utils.comp_budget import params_count 20 | from torch_geometric.graphgym.utils.device import auto_select_device 21 | from torch_geometric.graphgym.register import train_dict 22 | from torch_geometric import seed_everything 23 | import random 24 | import numpy as np 25 | 26 | from spexphormer.logger import create_logger 27 | 28 | 29 | def new_optimizer_config(cfg): 30 | return OptimizerConfig(optimizer=cfg.optim.optimizer, 31 | 
base_lr=cfg.optim.base_lr, 32 | weight_decay=cfg.optim.weight_decay, 33 | momentum=cfg.optim.momentum) 34 | 35 | 36 | def new_scheduler_config(cfg): 37 | return ExtendedSchedulerConfig( 38 | scheduler=cfg.optim.scheduler, 39 | steps=cfg.optim.steps, lr_decay=cfg.optim.lr_decay, 40 | max_epoch=cfg.optim.max_epoch, reduce_factor=cfg.optim.reduce_factor, 41 | schedule_patience=cfg.optim.schedule_patience, min_lr=cfg.optim.min_lr, 42 | num_warmup_epochs=cfg.optim.num_warmup_epochs, 43 | train_mode=cfg.train.mode, eval_period=cfg.train.eval_period) 44 | 45 | 46 | def custom_set_out_dir(cfg, cfg_fname, name_tag): 47 | """Set custom main output directory path to cfg. 48 | Include the config filename and name_tag in the new :obj:`cfg.out_dir`. 49 | 50 | Args: 51 | cfg (CfgNode): Configuration node 52 | cfg_fname (string): Filename for the yaml format configuration file 53 | name_tag (string): Additional name tag to identify this execution of the 54 | configuration file, specified in :obj:`cfg.name_tag` 55 | """ 56 | run_name = os.path.splitext(os.path.basename(cfg_fname))[0] 57 | run_name += f"-{name_tag}" if name_tag else "" 58 | cfg.out_dir = os.path.join(cfg.out_dir, run_name) 59 | 60 | 61 | def custom_set_run_dir(cfg, run_id): 62 | """Custom output directory naming for each experiment run. 63 | 64 | Args: 65 | cfg (CfgNode): Configuration node 66 | run_id (int): Main for-loop iter id (the random seed or dataset split) 67 | """ 68 | cfg.run_dir = os.path.join(cfg.out_dir, str(run_id)) 69 | # Make output directory 70 | if cfg.train.auto_resume: 71 | os.makedirs(cfg.run_dir, exist_ok=True) 72 | else: 73 | makedirs_rm_exist(cfg.run_dir) 74 | 75 | 76 | def run_loop_settings(): 77 | """Create main loop execution settings based on the current cfg. 78 | 79 | Configures the main execution loop to run in one of two modes: 80 | 1. 'multi-seed' - Reproduces the default behaviour of GraphGym, where 81 | args.repeat controls how many times the experiment run is repeated. 82 | Each iteration is executed with a random seed set to an increment from 83 | the previous one, starting at the initial cfg.seed. 84 | 2. 'multi-split' - Executes the experiment run over multiple dataset splits, 85 | these can be multiple CV splits or multiple standard splits. The random 86 | seed is reset to the initial cfg.seed value for each run iteration. 
87 | 88 | Returns: 89 | List of run IDs for each loop iteration 90 | List of rng seeds to loop over 91 | List of dataset split indices to loop over 92 | """ 93 | if len(cfg.run_multiple_splits) == 0: 94 | # 'multi-seed' run mode 95 | num_iterations = args.repeat 96 | seeds = [cfg.seed + x for x in range(num_iterations)] 97 | split_indices = [cfg.dataset.split_index] * num_iterations 98 | run_ids = seeds 99 | else: 100 | # 'multi-split' run mode 101 | if args.repeat != 1: 102 | raise NotImplementedError("Running multiple repeats of multiple " 103 | "splits in one run is not supported.") 104 | num_iterations = len(cfg.run_multiple_splits) 105 | seeds = [cfg.seed] * num_iterations 106 | split_indices = cfg.run_multiple_splits 107 | run_ids = split_indices 108 | return run_ids, seeds, split_indices 109 | 110 | 111 | if __name__ == '__main__': 112 | # Load cmd line args 113 | args = parse_args() 114 | # Load config file 115 | set_cfg(cfg) 116 | load_cfg(cfg, args) 117 | custom_set_out_dir(cfg, args.cfg_file, cfg.name_tag) 118 | dump_cfg(cfg) 119 | auto_select_device() 120 | cfg.device = cfg.accelerator 121 | # Set Pytorch environment 122 | torch.set_num_threads(cfg.num_threads) 123 | # Repeat for multiple experiment runs 124 | for run_id, seed, split_index in zip(*run_loop_settings()): 125 | # Set configurations for each run 126 | custom_set_run_dir(cfg, run_id) 127 | set_printing() 128 | cfg.dataset.split_index = split_index 129 | cfg.seed = seed 130 | cfg.run_id = run_id 131 | seed_everything(cfg.seed) 132 | 133 | logging.info(f"[*] Run ID {run_id}: seed={cfg.seed}, " 134 | f"split_index={cfg.dataset.split_index}") 135 | logging.info(f" Starting now: {datetime.datetime.now()}") 136 | # Set machine learning pipeline 137 | loaders = create_loader() 138 | loggers = create_logger() 139 | model = create_model() 140 | optimizer = create_optimizer(model.parameters(), 141 | new_optimizer_config(cfg)) 142 | scheduler = create_scheduler(optimizer, new_scheduler_config(cfg)) 143 | # Print model info 144 | logging.info(model) 145 | logging.info(cfg) 146 | cfg.params = params_count(model) 147 | logging.info('Num parameters: %s', cfg.params) 148 | # Start training 149 | if cfg.train.mode == 'standard': 150 | if cfg.wandb.use: 151 | logging.warning("[W] WandB logging is not supported with the " 152 | "default train.mode, set it to `custom`") 153 | train(loggers, loaders, model, optimizer, scheduler) 154 | else: 155 | train_dict[cfg.train.mode](loggers, loaders, model, optimizer, 156 | scheduler) 157 | # Aggregate results from different seeds 158 | try: 159 | agg_runs(cfg.out_dir, cfg.metric_best) 160 | except Exception as e: 161 | logging.info(f"Failed when trying to aggregate multiple runs: {e}") 162 | # When being launched in batch mode, mark a yaml as done 163 | if args.mark_done: 164 | os.rename(args.cfg_file, f'{args.cfg_file}_done') 165 | logging.info(f"[*] All done: {datetime.datetime.now()}") 166 | -------------------------------------------------------------------------------- /spexphormer/loader/dataset/peptides_structural.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os.path as osp 3 | import pickle 4 | import shutil 5 | 6 | import pandas as pd 7 | import torch 8 | from ogb.utils import smiles2graph 9 | from ogb.utils.torch_util import replace_numpy_with_torchtensor 10 | from ogb.utils.url import decide_download 11 | from torch_geometric.data import Data, InMemoryDataset, download_url 12 | from tqdm import tqdm 13 | 14 | 15 | 
class PeptidesStructuralDataset(InMemoryDataset): 16 | def __init__(self, root='datasets', smiles2graph=smiles2graph, 17 | transform=None, pre_transform=None): 18 | """ 19 | PyG dataset of 15,535 small peptides represented as their molecular 20 | graph (SMILES) with 11 regression targets derived from the peptide's 21 | 3D structure. 22 | 23 | The original amino acid sequence representation is provided in 24 | 'peptide_seq' and the distance between atoms in 'self_dist_matrix' field 25 | of the dataset file, but not used here as any part of the input. 26 | 27 | The 11 regression targets were precomputed from molecule XYZ: 28 | Inertia_mass_[a-c]: The principal component of the inertia of the 29 | mass, with some normalizations. Sorted 30 | Inertia_valence_[a-c]: The principal component of the inertia of the 31 | Hydrogen atoms. This is basically a measure of the 3D 32 | distribution of hydrogens. Sorted 33 | length_[a-c]: The length around the 3 main geometric axis of 34 | the 3D objects (without considering atom types). Sorted 35 | Spherocity: SpherocityIndex descriptor computed by 36 | rdkit.Chem.rdMolDescriptors.CalcSpherocityIndex 37 | Plane_best_fit: Plane of best fit (PBF) descriptor computed by 38 | rdkit.Chem.rdMolDescriptors.CalcPBF 39 | Args: 40 | root (string): Root directory where the dataset should be saved. 41 | smiles2graph (callable): A callable function that converts a SMILES 42 | string into a graph object. We use the OGB featurization. 43 | * The default smiles2graph requires rdkit to be installed * 44 | """ 45 | 46 | self.original_root = root 47 | self.smiles2graph = smiles2graph 48 | self.folder = osp.join(root, 'peptides-structural') 49 | 50 | ## Unnormalized targets. 51 | # self.url = 'https://www.dropbox.com/s/464u3303eu2u4zp/peptide_structure_dataset.csv.gz?dl=1' 52 | # self.version = '9786061a34298a0684150f2e4ff13f47' 53 | 54 | ## Standardized targets to zero mean and unit variance. 55 | self.url = 'https://www.dropbox.com/s/0d4aalmq4b4e2nh/peptide_structure_normalized_dataset.csv.gz?dl=1' 56 | self.version = 'c240c1c15466b5c907c63e180fa8aa89' # MD5 hash of the intended dataset file 57 | 58 | self.url_stratified_split = 'https://www.dropbox.com/s/9dfifzft1hqgow6/splits_random_stratified_peptide_structure.pickle?dl=1' 59 | self.md5sum_stratified_split = '5a0114bdadc80b94fc7ae974f13ef061' 60 | 61 | # Check version and update if necessary. 62 | release_tag = osp.join(self.folder, self.version) 63 | if osp.isdir(self.folder) and (not osp.exists(release_tag)): 64 | print(f"{self.__class__.__name__} has been updated.") 65 | if input("Will you update the dataset now? (y/N)\n").lower() == 'y': 66 | shutil.rmtree(self.folder) 67 | 68 | super().__init__(self.folder, transform, pre_transform) 69 | self.data, self.slices = torch.load(self.processed_paths[0]) 70 | 71 | @property 72 | def raw_file_names(self): 73 | return 'peptide_structure_normalized_dataset.csv.gz' 74 | 75 | @property 76 | def processed_file_names(self): 77 | return 'geometric_data_processed.pt' 78 | 79 | def _md5sum(self, path): 80 | hash_md5 = hashlib.md5() 81 | with open(path, 'rb') as f: 82 | buffer = f.read() 83 | hash_md5.update(buffer) 84 | return hash_md5.hexdigest() 85 | 86 | def download(self): 87 | if decide_download(self.url): 88 | path = download_url(self.url, self.raw_dir) 89 | # Save to disk the MD5 hash of the downloaded file. 
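# The empty file named by this hash (created below) acts as a release tag; __init__ checks for it to decide whether the local copy matches self.version.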
90 | hash = self._md5sum(path) 91 | if hash != self.version: 92 | raise ValueError("Unexpected MD5 hash of the downloaded file") 93 | open(osp.join(self.root, hash), 'w').close() 94 | # Download train/val/test splits. 95 | path_split1 = download_url(self.url_stratified_split, self.root) 96 | assert self._md5sum(path_split1) == self.md5sum_stratified_split 97 | else: 98 | print('Stop download.') 99 | exit(-1) 100 | 101 | def process(self): 102 | data_df = pd.read_csv(osp.join(self.raw_dir, 103 | 'peptide_structure_normalized_dataset.csv.gz')) 104 | smiles_list = data_df['smiles'] 105 | target_names = ['Inertia_mass_a', 'Inertia_mass_b', 'Inertia_mass_c', 106 | 'Inertia_valence_a', 'Inertia_valence_b', 107 | 'Inertia_valence_c', 'length_a', 'length_b', 'length_c', 108 | 'Spherocity', 'Plane_best_fit'] 109 | # Assert zero mean and unit standard deviation. 110 | assert all(abs(data_df.loc[:, target_names].mean(axis=0)) < 1e-10) 111 | assert all(abs(data_df.loc[:, target_names].std(axis=0) - 1.) < 1e-10) 112 | 113 | print('Converting SMILES strings into graphs...') 114 | data_list = [] 115 | for i in tqdm(range(len(smiles_list))): 116 | data = Data() 117 | 118 | smiles = smiles_list[i] 119 | y = data_df.iloc[i][target_names] 120 | graph = self.smiles2graph(smiles) 121 | 122 | assert (len(graph['edge_feat']) == graph['edge_index'].shape[1]) 123 | assert (len(graph['node_feat']) == graph['num_nodes']) 124 | 125 | data.__num_nodes__ = int(graph['num_nodes']) 126 | data.edge_index = torch.from_numpy(graph['edge_index']).to( 127 | torch.int64) 128 | data.edge_attr = torch.from_numpy(graph['edge_feat']).to( 129 | torch.int64) 130 | data.x = torch.from_numpy(graph['node_feat']).to(torch.int64) 131 | data.y = torch.Tensor([y]) 132 | 133 | data_list.append(data) 134 | 135 | if self.pre_transform is not None: 136 | data_list = [self.pre_transform(data) for data in data_list] 137 | 138 | data, slices = self.collate(data_list) 139 | 140 | print('Saving...') 141 | torch.save((data, slices), self.processed_paths[0]) 142 | 143 | def get_idx_split(self): 144 | """ Get dataset splits. 145 | 146 | Returns: 147 | Dict with 'train', 'val', 'test', splits indices. 148 | """ 149 | split_file = osp.join(self.root, 150 | "splits_random_stratified_peptide_structure.pickle") 151 | with open(split_file, 'rb') as f: 152 | splits = pickle.load(f) 153 | split_dict = replace_numpy_with_torchtensor(splits) 154 | return split_dict 155 | 156 | 157 | if __name__ == '__main__': 158 | dataset = PeptidesStructuralDataset() 159 | print(dataset) 160 | print(dataset.data.edge_index) 161 | print(dataset.data.edge_index.shape) 162 | print(dataset.data.x.shape) 163 | print(dataset[100]) 164 | print(dataset[100].y) 165 | print(dataset.get_idx_split()) 166 | -------------------------------------------------------------------------------- /spexphormer/loader/planetoid.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from typing import Callable, List, Optional 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from torch_geometric.data import InMemoryDataset, download_url 8 | from torch_geometric.io import read_planetoid_data 9 | 10 | 11 | class Planetoid(InMemoryDataset): 12 | r"""The citation network datasets "Cora", "CiteSeer" and "PubMed" from the 13 | `"Revisiting Semi-Supervised Learning with Graph Embeddings" 14 | `_ paper. 15 | Nodes represent documents and edges represent citation links. 16 | Training, validation and test splits are given by binary masks. 
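This local copy additionally accepts a :obj:`train_percent` argument (default :obj:`0.6`) used by the :obj:`"random"` split: that fraction of each class is assigned to training, and the remaining nodes are split evenly between validation and test.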
17 | 18 | Args: 19 | root (string): Root directory where the dataset should be saved. 20 | name (string): The name of the dataset (:obj:`"Cora"`, 21 | :obj:`"CiteSeer"`, :obj:`"PubMed"`). 22 | split (string): The type of dataset split 23 | (:obj:`"public"`, :obj:`"full"`, :obj:`"geom-gcn"`, 24 | :obj:`"random"`). 25 | If set to :obj:`"public"`, the split will be the public fixed split 26 | from the `"Revisiting Semi-Supervised Learning with Graph 27 | Embeddings" `_ paper. 28 | If set to :obj:`"full"`, all nodes except those in the validation 29 | and test sets will be used for training (as in the 30 | `"FastGCN: Fast Learning with Graph Convolutional Networks via 31 | Importance Sampling" `_ paper). 32 | If set to :obj:`"geom-gcn"`, the 10 public fixed splits from the 33 | `"Geom-GCN: Geometric Graph Convolutional Networks" 34 | `_ paper are given. 35 | If set to :obj:`"random"`, train, validation, and test sets will be 36 | randomly generated, according to :obj:`num_train_per_class`, 37 | :obj:`num_val` and :obj:`num_test`. (default: :obj:`"public"`) 38 | num_train_per_class (int, optional): The number of training samples 39 | per class in case of :obj:`"random"` split. (default: :obj:`20`) 40 | num_val (int, optional): The number of validation samples in case of 41 | :obj:`"random"` split. (default: :obj:`500`) 42 | num_test (int, optional): The number of test samples in case of 43 | :obj:`"random"` split. (default: :obj:`1000`) 44 | transform (callable, optional): A function/transform that takes in an 45 | :obj:`torch_geometric.data.Data` object and returns a transformed 46 | version. The data object will be transformed before every access. 47 | (default: :obj:`None`) 48 | pre_transform (callable, optional): A function/transform that takes in 49 | an :obj:`torch_geometric.data.Data` object and returns a 50 | transformed version. The data object will be transformed before 51 | being saved to disk. (default: :obj:`None`) 52 | 53 | Stats: 54 | .. 
list-table:: 55 | :widths: 10 10 10 10 10 56 | :header-rows: 1 57 | 58 | * - Name 59 | - #nodes 60 | - #edges 61 | - #features 62 | - #classes 63 | * - Cora 64 | - 2,708 65 | - 10,556 66 | - 1,433 67 | - 7 68 | * - CiteSeer 69 | - 3,327 70 | - 9,104 71 | - 3,703 72 | - 6 73 | * - PubMed 74 | - 19,717 75 | - 88,648 76 | - 500 77 | - 3 78 | """ 79 | 80 | url = 'https://github.com/kimiyoung/planetoid/raw/master/data' 81 | geom_gcn_url = ('https://raw.githubusercontent.com/graphdml-uiuc-jlu/' 82 | 'geom-gcn/master') 83 | 84 | def __init__(self, root: str, name: str, split: str = "public", 85 | num_train_per_class: int = 20, num_val: int = 500, 86 | num_test: int = 1000, transform: Optional[Callable] = None, 87 | pre_transform: Optional[Callable] = None, train_percent = 0.6): 88 | self.name = name 89 | 90 | self.split = split.lower() 91 | assert self.split in ['public', 'full', 'geom-gcn', 'random'] 92 | 93 | super().__init__(root, transform, pre_transform) 94 | self.data, self.slices = torch.load(self.processed_paths[0]) 95 | 96 | if split == 'full': 97 | data = self.get(0) 98 | data.train_mask.fill_(True) 99 | data.train_mask[data.val_mask | data.test_mask] = False 100 | self.data, self.slices = self.collate([data]) 101 | 102 | elif split == 'random': 103 | data = self.get(0) 104 | data.train_mask.fill_(False) 105 | for c in range(self.num_classes): 106 | idx = (data.y == c).nonzero(as_tuple=False).view(-1) 107 | idx = idx[torch.randperm(idx.size(0))[:int(len(idx) * train_percent)]] 108 | data.train_mask[idx] = True 109 | 110 | remaining = (~data.train_mask).nonzero(as_tuple=False).view(-1) 111 | remaining = remaining[torch.randperm(remaining.size(0))] 112 | 113 | val_index = int(len(remaining)*0.5) 114 | data.val_mask.fill_(False) 115 | data.val_mask[remaining[:val_index]] = True 116 | 117 | data.test_mask.fill_(False) 118 | data.test_mask[remaining[val_index:]] = True 119 | 120 | self.data, self.slices = self.collate([data]) 121 | 122 | @property 123 | def raw_dir(self) -> str: 124 | if self.split == 'geom-gcn': 125 | return osp.join(self.root, self.name, 'geom-gcn', 'raw') 126 | return osp.join(self.root, self.name, 'raw') 127 | 128 | @property 129 | def processed_dir(self) -> str: 130 | if self.split == 'geom-gcn': 131 | return osp.join(self.root, self.name, 'geom-gcn', 'processed') 132 | return osp.join(self.root, self.name, 'processed') 133 | 134 | @property 135 | def raw_file_names(self) -> List[str]: 136 | names = ['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index'] 137 | return [f'ind.{self.name.lower()}.{name}' for name in names] 138 | 139 | @property 140 | def processed_file_names(self) -> str: 141 | return 'data.pt' 142 | 143 | def download(self): 144 | for name in self.raw_file_names: 145 | download_url(f'{self.url}/{name}', self.raw_dir) 146 | if self.split == 'geom-gcn': 147 | for i in range(10): 148 | url = f'{self.geom_gcn_url}/splits/{self.name.lower()}' 149 | download_url(f'{url}_split_0.6_0.2_{i}.npz', self.raw_dir) 150 | 151 | def process(self): 152 | data = read_planetoid_data(self.raw_dir, self.name) 153 | 154 | if self.split == 'geom-gcn': 155 | train_masks, val_masks, test_masks = [], [], [] 156 | for i in range(10): 157 | name = f'{self.name.lower()}_split_0.6_0.2_{i}.npz' 158 | splits = np.load(osp.join(self.raw_dir, name)) 159 | train_masks.append(torch.from_numpy(splits['train_mask'])) 160 | val_masks.append(torch.from_numpy(splits['val_mask'])) 161 | test_masks.append(torch.from_numpy(splits['test_mask'])) 162 | data.train_mask = torch.stack(train_masks, 
dim=1) 163 | data.val_mask = torch.stack(val_masks, dim=1) 164 | data.test_mask = torch.stack(test_masks, dim=1) 165 | 166 | data = data if self.pre_transform is None else self.pre_transform(data) 167 | torch.save(self.collate([data]), self.processed_paths[0]) 168 | 169 | def __repr__(self) -> str: 170 | return f'{self.name}()' 171 | -------------------------------------------------------------------------------- /spexphormer/optimizer/extra_optimizers.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | from typing import Iterator 4 | from dataclasses import dataclass 5 | 6 | import torch.optim as optim 7 | from torch.nn import Parameter 8 | from torch.optim import Adagrad, AdamW, Optimizer 9 | from torch.optim.lr_scheduler import ReduceLROnPlateau 10 | 11 | from torch_geometric.graphgym.optim import SchedulerConfig 12 | import torch_geometric.graphgym.register as register 13 | 14 | 15 | @register.register_optimizer('adagrad') 16 | def adagrad_optimizer(params: Iterator[Parameter], base_lr: float, 17 | weight_decay: float) -> Adagrad: 18 | return Adagrad(params, lr=base_lr, weight_decay=weight_decay) 19 | 20 | 21 | @register.register_optimizer('adamW') 22 | def adamW_optimizer(params: Iterator[Parameter], base_lr: float, 23 | weight_decay: float) -> AdamW: 24 | return AdamW(params, lr=base_lr, weight_decay=weight_decay) 25 | 26 | 27 | 28 | @dataclass 29 | class ExtendedSchedulerConfig(SchedulerConfig): 30 | reduce_factor: float = 0.5 31 | schedule_patience: int = 15 32 | min_lr: float = 1e-6 33 | num_warmup_epochs: int = 10 34 | train_mode: str = 'custom' 35 | eval_period: int = 1 36 | 37 | 38 | @register.register_scheduler('plateau') 39 | def plateau_scheduler(optimizer: Optimizer, patience: int, 40 | lr_decay: float) -> ReduceLROnPlateau: 41 | return ReduceLROnPlateau(optimizer, patience=patience, factor=lr_decay) 42 | 43 | 44 | @register.register_scheduler('reduce_on_plateau') 45 | def scheduler_reduce_on_plateau(optimizer: Optimizer, reduce_factor: float, 46 | schedule_patience: int, min_lr: float, 47 | train_mode: str, eval_period: int): 48 | if train_mode == 'standard': 49 | raise ValueError("ReduceLROnPlateau scheduler is not supported " 50 | "by 'standard' graphgym training mode pipeline; " 51 | "try setting config 'train.mode: custom'") 52 | 53 | if eval_period != 1: 54 | logging.warning("When config train.eval_period is not 1, the " 55 | "optim.schedule_patience of ReduceLROnPlateau " 56 | "may not behave as intended.") 57 | 58 | scheduler = optim.lr_scheduler.ReduceLROnPlateau( 59 | optimizer=optimizer, 60 | mode='min', 61 | factor=reduce_factor, 62 | patience=schedule_patience, 63 | min_lr=min_lr, 64 | verbose=True 65 | ) 66 | if not hasattr(scheduler, 'get_last_lr'): 67 | # ReduceLROnPlateau doesn't have `get_last_lr` method as of current 68 | # pytorch1.10; we add it here for consistency with other schedulers. 69 | def get_last_lr(self): 70 | """ Return last computed learning rate by current scheduler. 71 | """ 72 | return self._last_lr 73 | 74 | scheduler.get_last_lr = get_last_lr.__get__(scheduler) 75 | scheduler._last_lr = [group['lr'] 76 | for group in scheduler.optimizer.param_groups] 77 | 78 | def modified_state_dict(ref): 79 | """Returns the state of the scheduler as a :class:`dict`. 80 | Additionally modified to ignore 'get_last_lr', 'state_dict'. 81 | Including these entries in the state dict would cause issues when 82 | loading a partially trained / pretrained model from a checkpoint. 
83 | """ 84 | return {key: value for key, value in ref.__dict__.items() 85 | if key not in ['sparsifier', 'get_last_lr', 'state_dict']} 86 | 87 | scheduler.state_dict = modified_state_dict.__get__(scheduler) 88 | 89 | return scheduler 90 | 91 | 92 | @register.register_scheduler('linear_with_warmup') 93 | def linear_with_warmup_scheduler(optimizer: Optimizer, 94 | num_warmup_epochs: int, max_epoch: int): 95 | scheduler = get_linear_schedule_with_warmup( 96 | optimizer=optimizer, 97 | num_warmup_steps=num_warmup_epochs, 98 | num_training_steps=max_epoch 99 | ) 100 | return scheduler 101 | 102 | 103 | @register.register_scheduler('cosine_with_warmup') 104 | def cosine_with_warmup_scheduler(optimizer: Optimizer, 105 | num_warmup_epochs: int, max_epoch: int): 106 | scheduler = get_cosine_schedule_with_warmup( 107 | optimizer=optimizer, 108 | num_warmup_steps=num_warmup_epochs, 109 | num_training_steps=max_epoch 110 | ) 111 | return scheduler 112 | 113 | 114 | def get_linear_schedule_with_warmup( 115 | optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, 116 | last_epoch: int = -1): 117 | """ 118 | Implementation by Huggingface: 119 | https://github.com/huggingface/transformers/blob/v4.16.2/src/transformers/optimization.py 120 | 121 | Create a schedule with a learning rate that decreases linearly from the 122 | initial lr set in the optimizer to 0, after a warmup period during which it 123 | increases linearly from 0 to the initial lr set in the optimizer. 124 | Args: 125 | optimizer ([`~torch.optim.Optimizer`]): 126 | The optimizer for which to schedule the learning rate. 127 | num_warmup_steps (`int`): 128 | The number of steps for the warmup phase. 129 | num_training_steps (`int`): 130 | The total number of training steps. 131 | last_epoch (`int`, *optional*, defaults to -1): 132 | The index of the last epoch when resuming training. 133 | Return: 134 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 135 | """ 136 | 137 | def lr_lambda(current_step: int): 138 | if current_step < num_warmup_steps: 139 | return max(1e-6, float(current_step) / float(max(1, num_warmup_steps))) 140 | return max( 141 | 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) 142 | ) 143 | 144 | return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch) 145 | 146 | 147 | def get_cosine_schedule_with_warmup( 148 | optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, 149 | num_cycles: float = 0.5, last_epoch: int = -1): 150 | """ 151 | Implementation by Huggingface: 152 | https://github.com/huggingface/transformers/blob/v4.16.2/src/transformers/optimization.py 153 | 154 | Create a schedule with a learning rate that decreases following the values 155 | of the cosine function between the initial lr set in the optimizer to 0, 156 | after a warmup period during which it increases linearly between 0 and the 157 | initial lr set in the optimizer. 158 | Args: 159 | optimizer ([`~torch.optim.Optimizer`]): 160 | The optimizer for which to schedule the learning rate. 161 | num_warmup_steps (`int`): 162 | The number of steps for the warmup phase. 163 | num_training_steps (`int`): 164 | The total number of training steps. 165 | num_cycles (`float`, *optional*, defaults to 0.5): 166 | The number of waves in the cosine schedule (the defaults is to just 167 | decrease from the max value to 0 following a half-cosine). 168 | last_epoch (`int`, *optional*, defaults to -1): 169 | The index of the last epoch when resuming training. 
170 | Return: 171 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 172 | """ 173 | 174 | def lr_lambda(current_step): 175 | if current_step < num_warmup_steps: 176 | return max(1e-6, float(current_step) / float(max(1, num_warmup_steps))) 177 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 178 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) 179 | 180 | return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch) 181 | -------------------------------------------------------------------------------- /spexphormer/loader/heterogeneous_datasets.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from typing import Callable, List, Optional 3 | 4 | import numpy as np 5 | import torch 6 | 7 | import sys 8 | import math 9 | import time 10 | import pickle as pkl 11 | import scipy as sp 12 | from scipy import io 13 | import numpy as np 14 | import pandas as pd 15 | from sklearn.preprocessing import label_binarize 16 | 17 | from torch_geometric.data import Data, InMemoryDataset, download_url 18 | 19 | 20 | 21 | class GeomGCNHeterogeneousDatasets(InMemoryDataset): 22 | r"""Datasets: Penn94, Actor, Squirrel, Chameleon from 23 | `"Geom-GCN: Geometric Graph Convolutional Networks" 24 | `_ paper. 25 | Parts of code are from: 26 | 27 | Args: 28 | root (string): Root directory where the dataset should be saved. 29 | name (string): The name of the dataset (:obj:`"Penn94"`, 30 | :obj:`"Actor"`, :obj:`"Squirrel"`, :obj:`"Chameleon"`). 31 | transform (callable, optional): A function/transform that takes in an 32 | :obj:`torch_geometric.data.Data` object and returns a transformed 33 | version. The data object will be transformed before every access. 34 | (default: :obj:`None`) 35 | pre_transform (callable, optional): A function/transform that takes in 36 | an :obj:`torch_geometric.data.Data` object and returns a 37 | transformed version. The data object will be transformed before 38 | being saved to disk. 
(default: :obj:`None`) 39 | """ 40 | 41 | 42 | def __init__(self, root: str, name: str, transform: Optional[Callable] = None, 43 | pre_transform: Optional[Callable] = None, train_percent = 0.6): 44 | self.name = name 45 | 46 | self.url = 'https://github.com/DSL-Lab/Specformer/tree/main/Node/node_raw_data' 47 | self.url_penn = 'https://github.com/DSL-Lab/Specformer/raw/main/Node/node_raw_data' 48 | self.url_penn_split = 'https://github.com/DSL-Lab/Specformer/raw/main/Node/node_raw_data/fb100-Penn94-splits.npy' 49 | 50 | 51 | super().__init__(root, transform, pre_transform) 52 | self.data, self.slices = torch.load(self.processed_paths[0]) 53 | 54 | data = self.get(0) 55 | self.data, self.slices = self.collate([data]) 56 | 57 | @property 58 | def raw_dir(self) -> str: 59 | return osp.join(self.root, self.name, 'raw') 60 | 61 | @property 62 | def processed_dir(self) -> str: 63 | return osp.join(self.root, self.name, 'processed') 64 | 65 | @property 66 | def raw_file_names(self) -> List[str]: 67 | if self.name == 'Penn94': 68 | names = ['Penn94.mat', 'fb100-Penn94-splits.npy'] 69 | else: 70 | names = ['out1_graph_edges.txt', 'out1_node_feature_label.txt'] 71 | return names 72 | 73 | @property 74 | def processed_file_names(self) -> str: 75 | return 'data.pt' 76 | 77 | def download(self): 78 | if self.name == 'Penn94': 79 | for name in self.raw_file_names: 80 | download_url(f'{self.url_penn}/{name}', self.raw_dir) 81 | else: 82 | for name in self.raw_file_names: 83 | download_url(f'{self.url_penn}/{self.name}/{name}', self.raw_dir) 84 | 85 | def process(self): 86 | 87 | def feature_normalize(x): 88 | x = np.array(x) 89 | rowsum = x.sum(axis=1, keepdims=True) 90 | rowsum = np.clip(rowsum, 1, 1e10) 91 | return x / rowsum 92 | 93 | if self.name == 'Penn94': 94 | mat = io.loadmat(osp.join(self.raw_dir, 'Penn94.mat')) 95 | A = mat['A'] 96 | metadata = mat['local_info'] 97 | 98 | edge_index = A.nonzero() 99 | metadata = metadata.astype(int) 100 | label = metadata[:, 1] - 1 # gender label, -1 means unlabeled 101 | 102 | # make features into one-hot encodings 103 | feature_vals = np.hstack((np.expand_dims(metadata[:, 0], 1), metadata[:, 2:])) 104 | features = np.empty((A.shape[0], 0)) 105 | for col in range(feature_vals.shape[1]): 106 | feat_col = feature_vals[:, col] 107 | feat_onehot = label_binarize(feat_col, classes=np.unique(feat_col)) 108 | features = np.hstack((features, feat_onehot)) 109 | 110 | node_feat = torch.tensor(features, dtype=torch.float) 111 | num_nodes = metadata.shape[0] 112 | label = torch.LongTensor(label) 113 | edge_index = torch.LongTensor(edge_index).contiguous() 114 | data = Data(x=node_feat, edge_index=edge_index, y=label) 115 | 116 | split = np.load(osp.join(self.raw_dir, 'fb100-Penn94-splits.npy'), allow_pickle=True)[0] 117 | train, valid, test = split['train'], split['valid'], split['test'] 118 | train_mask = torch.zeros(data.y.shape, dtype=torch.bool) 119 | train_mask[train] = True 120 | val_mask = torch.zeros(data.y.shape, dtype=torch.bool) 121 | val_mask[valid] = True 122 | test_mask = torch.zeros(data.y.shape, dtype=torch.bool) 123 | test_mask[test] = True 124 | 125 | data.train_mask = train_mask 126 | data.val_mask = val_mask 127 | data.test_mask = test_mask 128 | 129 | elif self.name in ['chameleon', 'squirrel', 'actor']: 130 | edge_df = pd.read_csv(osp.join(self.raw_dir, 'out1_graph_edges.txt'), sep='\t') 131 | node_df = pd.read_csv(osp.join(self.raw_dir, 'out1_node_feature_label.txt'), sep='\t') 132 | feature = node_df[node_df.columns[1]] 133 | y = 
node_df[node_df.columns[2]] 134 | 135 | source = list(edge_df[edge_df.columns[0]]) 136 | target = list(edge_df[edge_df.columns[1]]) 137 | 138 | if self.name == 'actor': 139 | # for sparse features 140 | nfeat = 932 141 | x = np.zeros((len(y), nfeat)) 142 | 143 | feature = list(feature) 144 | feature = [feat.split(',') for feat in feature] 145 | for ind, feat in enumerate(feature): 146 | for ff in feat: 147 | x[ind, int(ff)] = 1. 148 | x = feature_normalize(x) 149 | else: 150 | feature = list(feature) 151 | feature = [feat.split(',') for feat in feature] 152 | new_feat = [] 153 | for feat in feature: 154 | new_feat.append([int(f) for f in feat]) 155 | x = np.array(new_feat) 156 | x = feature_normalize(x) 157 | 158 | edge_index = [source, target] 159 | edge_index = torch.LongTensor(edge_index).contiguous() 160 | data = Data(x=torch.Tensor(x), edge_index=edge_index, y=torch.LongTensor(y)) 161 | 162 | y = data.y 163 | nclass = 5 164 | 165 | percls_trn = int(round(0.5 * len(y) / nclass)) 166 | val_lb = int(round(0.25 * len(y))) 167 | 168 | indices = [] 169 | for i in range(nclass): 170 | index = (y == i).nonzero().view(-1) 171 | index = index[torch.randperm(index.size(0), device=index.device)] 172 | indices.append(index) 173 | 174 | train_index = torch.cat([i[:percls_trn] for i in indices], dim=0) 175 | rest_index = torch.cat([i[percls_trn:] for i in indices], dim=0) 176 | rest_index = rest_index[torch.randperm(rest_index.size(0))] 177 | valid_index = rest_index[:val_lb] 178 | test_index = rest_index[val_lb:] 179 | 180 | train_mask = torch.zeros(data.y.shape, dtype=torch.bool) 181 | train_mask[train_index] = True 182 | val_mask = torch.zeros(data.y.shape, dtype=torch.bool) 183 | val_mask[valid_index] = True 184 | test_mask = torch.zeros(data.y.shape, dtype=torch.bool) 185 | test_mask[test_index] = True 186 | 187 | data.train_mask = train_mask 188 | data.val_mask = val_mask 189 | data.test_mask = test_mask 190 | 191 | data = data if self.pre_transform is None else self.pre_transform(data) 192 | torch.save(self.collate([data]), self.processed_paths[0]) 193 | 194 | def __repr__(self) -> str: 195 | return f'{self.name}()' 196 | -------------------------------------------------------------------------------- /spexphormer/loader/dataset/voc_superpixels.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import shutil 4 | import pickle 5 | 6 | import torch 7 | from tqdm import tqdm 8 | from torch_geometric.data import (InMemoryDataset, Data, download_url, 9 | extract_zip) 10 | 11 | 12 | class VOCSuperpixels(InMemoryDataset): 13 | r"""The VOCSuperpixels dataset which contains image superpixels and a semantic segmentation label 14 | for each node superpixel. 15 | 16 | Construction and Preparation: 17 | - The superpixels are extracted in a similar fashion as the MNIST and CIFAR10 superpixels. 18 | - In VOCSuperpixels, the number of superpixel nodes <=500. (Note that it was <=75 for MNIST and 19 | <=150 for CIFAR10.) 20 | - The labeling of each superpixel node is done with the same value of the original pixel ground 21 | truth that is on the mean coord of the superpixel node 22 | 23 | - Based on the SBD annotations from 11355 images taken from the PASCAL VOC 2011 dataset. Original 24 | source `here`_. 
25 | 26 | num_classes = 21 27 | ignore_label = 255 28 | 29 | color map 30 | 0=background, 1=aeroplane, 2=bicycle, 3=bird, 4=boat, 5=bottle, 6=bus, 7=car, 8=cat, 9=chair, 10=cow, 31 | 11=diningtable, 12=dog, 13=horse, 14=motorbike, 15=person, 16=potted plant, 17=sheep, 18=sofa, 19=train, 32 | 20=tv/monitor 33 | 34 | Splitting: 35 | - In the original image dataset there are only train and val splits. 36 | - For VOCSuperpixels, we maintain train, val and test splits where the train set is AS IS. The original 37 | val split of the image dataset is used to divide into new val and new test splits that are eventually used 38 | in VOCSuperpixels. The policy for this val/test splitting is below. 39 | - Split the total number of val graphs into 2 sets (val, test) with 50:50 using a stratified split proportionate 40 | to the original distribution of data with respect to a meta label. 41 | - Each image is meta-labeled by majority voting of non-background ground truth node labels. Then the new val 42 | and new test sets are created with stratified sampling based on these meta-labels. This is done to preserve the 43 | same distribution of node labels in both the new val and new test sets. 44 | - Therefore, the final train, val and test splits are correspondingly the original train (8498), new val (1428) 45 | and new test (1429) splits. 46 | 47 | Args: 48 | root (string): Root directory where the dataset should be saved. 49 | name (string, optional): Option to select the graph construction format. 50 | If :obj: `"edge_wt_only_coord"`, the graphs are 8-nn graphs with the edge weights computed based on 51 | only spatial coordinates of superpixel nodes. 52 | If :obj: `"edge_wt_coord_feat"`, the graphs are 8-nn graphs with the edge weights computed based on 53 | a combination of spatial coordinates and feature values of superpixel nodes. 54 | If :obj: `"edge_wt_region_boundary"`, the graphs are region boundary graphs where two regions (i.e. 55 | superpixel nodes) have an edge between them if they share a boundary in the original image. 56 | (default: :obj:`"edge_wt_region_boundary"`) 57 | slic_compactness (int, optional): Option to select the compactness of slic that was used for superpixels 58 | (:obj:`10`, :obj:`30`). (default: :obj:`30`) 59 | transform (callable, optional): A function/transform that takes in an 60 | :obj:`torch_geometric.data.Data` object and returns a transformed 61 | version. The data object will be transformed before every access. 62 | (default: :obj:`None`) 63 | pre_transform (callable, optional): A function/transform that takes in 64 | an :obj:`torch_geometric.data.Data` object and returns a 65 | transformed version. The data object will be transformed before 66 | being saved to disk. (default: :obj:`None`) 67 | pre_filter (callable, optional): A function that takes in an 68 | :obj:`torch_geometric.data.Data` object and returns a boolean 69 | value, indicating whether the data object should be included in the 70 | final dataset. 
(default: :obj:`None`) 71 | """ 72 | 73 | url = { 74 | 10: { 75 | 'edge_wt_only_coord': 'https://www.dropbox.com/s/rk6pfnuh7tq3t37/voc_superpixels_edge_wt_only_coord.zip?dl=1', 76 | 'edge_wt_coord_feat': 'https://www.dropbox.com/s/2a53nmfp6llqg8y/voc_superpixels_edge_wt_coord_feat.zip?dl=1', 77 | 'edge_wt_region_boundary': 'https://www.dropbox.com/s/6pfz2mccfbkj7r3/voc_superpixels_edge_wt_region_boundary.zip?dl=1' 78 | }, 79 | 30: { 80 | 'edge_wt_only_coord': 'https://www.dropbox.com/s/toqulkdpb1jrswk/voc_superpixels_edge_wt_only_coord.zip?dl=1', 81 | 'edge_wt_coord_feat': 'https://www.dropbox.com/s/xywki8ysj63584d/voc_superpixels_edge_wt_coord_feat.zip?dl=1', 82 | 'edge_wt_region_boundary': 'https://www.dropbox.com/s/8x722ai272wqwl4/voc_superpixels_edge_wt_region_boundary.zip?dl=1' 83 | } 84 | } 85 | 86 | def __init__(self, root, name='edge_wt_region_boundary', slic_compactness=30, split='train', 87 | transform=None, pre_transform=None, pre_filter=None): 88 | self.name = name 89 | self.slic_compactness = slic_compactness 90 | assert split in ['train', 'val', 'test'] 91 | assert name in ['edge_wt_only_coord', 'edge_wt_coord_feat', 'edge_wt_region_boundary'] 92 | assert slic_compactness in [10, 30] 93 | super().__init__(root, transform, pre_transform, pre_filter) 94 | path = osp.join(self.processed_dir, f'{split}.pt') 95 | self.data, self.slices = torch.load(path) 96 | 97 | 98 | @property 99 | def raw_file_names(self): 100 | return ['train.pickle', 'val.pickle', 'test.pickle'] 101 | 102 | @property 103 | def raw_dir(self): 104 | return osp.join(self.root, 105 | 'slic_compactness_' + str(self.slic_compactness), 106 | self.name, 107 | 'raw') 108 | 109 | @property 110 | def processed_dir(self): 111 | return osp.join(self.root, 112 | 'slic_compactness_' + str(self.slic_compactness), 113 | self.name, 114 | 'processed') 115 | 116 | @property 117 | def processed_file_names(self): 118 | return ['train.pt', 'val.pt', 'test.pt'] 119 | 120 | def download(self): 121 | shutil.rmtree(self.raw_dir) 122 | path = download_url(self.url[self.slic_compactness][self.name], self.root) 123 | extract_zip(path, self.root) 124 | os.rename(osp.join(self.root, 'voc_superpixels_' + self.name), self.raw_dir) 125 | os.unlink(path) 126 | 127 | def process(self): 128 | for split in ['train', 'val', 'test']: 129 | with open(osp.join(self.raw_dir, f'{split}.pickle'), 'rb') as f: 130 | graphs = pickle.load(f) 131 | 132 | indices = range(len(graphs)) 133 | 134 | pbar = tqdm(total=len(indices)) 135 | pbar.set_description(f'Processing {split} dataset') 136 | 137 | data_list = [] 138 | for idx in indices: 139 | graph = graphs[idx] 140 | 141 | """ 142 | Each `graph` is a tuple (x, edge_attr, edge_index, y) 143 | Shape of x : [num_nodes, 14] 144 | Shape of edge_attr : [num_edges, 1] or [num_edges, 2] 145 | Shape of edge_index : [2, num_edges] 146 | Shape of y : [num_nodes] 147 | """ 148 | 149 | x = graph[0].to(torch.float) 150 | edge_attr = graph[1].to(torch.float) 151 | edge_index = graph[2] 152 | y = torch.LongTensor(graph[3]) 153 | 154 | data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, 155 | y=y) 156 | 157 | if self.pre_filter is not None and not self.pre_filter(data): 158 | continue 159 | 160 | if self.pre_transform is not None: 161 | data = self.pre_transform(data) 162 | 163 | data_list.append(data) 164 | pbar.update(1) 165 | 166 | pbar.close() 167 | 168 | torch.save(self.collate(data_list), 169 | osp.join(self.processed_dir, f'{split}.pt')) 
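A minimal usage sketch for the VOCSuperpixels class above (illustrative only: the root path, batch size, and the standard PyG DataLoader are assumptions, not part of this file):

from torch_geometric.loader import DataLoader
from spexphormer.loader.dataset.voc_superpixels import VOCSuperpixels

# The first instantiation triggers download() and process(); later runs load the cached split tensors.
train_set = VOCSuperpixels(root='datasets/VOCSuperpixels',
                           name='edge_wt_region_boundary',
                           slic_compactness=30,
                           split='train')
loader = DataLoader(train_set, batch_size=32, shuffle=True)
batch = next(iter(loader))
# Per process() above: batch.x is [num_nodes, 14], batch.edge_attr is [num_edges, 1] or
# [num_edges, 2], and batch.y holds one of 21 node classes.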
-------------------------------------------------------------------------------- /spexphormer/transform/expander_edges.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import scipy as sp 4 | from pathlib import Path 5 | from typing import Any, Optional 6 | import torch 7 | from spexphormer.transform.dist_transforms import laplacian_eigenv 8 | from torch_geometric.graphgym.config import cfg 9 | 10 | 11 | def generate_random_regular_graph2(num_nodes, degree, rng=None): 12 | """Generates a random 2d-regular graph with n nodes. 13 | Returns the list of edges. This list is symmetric; i.e., if 14 | (x, y) is an edge so is (y,x). 15 | Args: 16 | num_nodes: Number of nodes in the desired graph. 17 | degree: Desired degree. 18 | rng: random number generator 19 | Returns: 20 | senders: tail of each edge. 21 | receivers: head of each edge. 22 | """ 23 | 24 | if rng is None: 25 | rng = np.random.default_rng() 26 | 27 | senders = [*range(0, num_nodes)] * degree 28 | receivers = rng.permutation(senders).tolist() 29 | 30 | senders, receivers = [*senders, *receivers], [*receivers, *senders] 31 | 32 | return senders, receivers 33 | 34 | 35 | def generate_random_regular_graph1(num_nodes, degree, rng=None): 36 | """Generates a random 2d-regular graph with n nodes. 37 | Returns the list of edges. This list is symmetric; i.e., if 38 | (x, y) is an edge so is (y,x). 39 | Args: 40 | num_nodes: Number of nodes in the desired graph. 41 | degree: Desired degree. 42 | rng: random number generator 43 | Returns: 44 | senders: tail of each edge. 45 | receivers: head of each edge. 46 | """ 47 | 48 | if rng is None: 49 | rng = np.random.default_rng() 50 | 51 | senders = [*range(0, num_nodes)] * degree 52 | receivers = [] 53 | for _ in range(degree): 54 | receivers.extend(rng.permutation(list(range(num_nodes))).tolist()) 55 | 56 | senders, receivers = [*senders, *receivers], [*receivers, *senders] 57 | 58 | senders = np.array(senders) 59 | receivers = np.array(receivers) 60 | 61 | return senders, receivers 62 | 63 | 64 | def generate_random_graph_with_hamiltonian_cycles(num_nodes, degree, rng=None): 65 | """Generates a 2d-regular graph with n nodes using d random Hamiltonian cycles. 66 | Returns the list of edges. This list is symmetric; i.e., if 67 | (x, y) is an edge so is (y,x). 68 | Args: 69 | num_nodes: Number of nodes in the desired graph. 70 | degree: Desired degree. 71 | rng: random number generator 72 | Returns: 73 | senders: tail of each edge. 74 | receivers: head of each edge. 75 | """ 76 | 77 | if rng is None: 78 | rng = np.random.default_rng() 79 | 80 | senders = [] 81 | receivers = [] 82 | for _ in range(degree): 83 | permutation = rng.permutation(list(range(num_nodes))).tolist() 84 | for idx, v in enumerate(permutation): 85 | u = permutation[idx - 1] 86 | senders.extend([v, u]) 87 | receivers.extend([u, v]) 88 | 89 | senders = np.array(senders) 90 | receivers = np.array(receivers) 91 | 92 | return senders, receivers 93 | 94 | 95 | def augment_with_expander(data, degree, algorithm, rng=None, max_num_iters=100, exp_index=0): 96 | """Augments the graph in `data` with a random d-regular expander graph. 97 | The added edge list is symmetric; i.e., if 98 | (x, y) is an edge so is (y,x). 99 | Args: 100 | data: PyG Data object holding the graph to augment. 101 | degree: Desired degree of the expander. 102 | algorithm: One of 'Random-d', 'Random-d-2' or 'Hamiltonian'. 103 | rng: random number generator 104 | max_num_iters: maximum number of sampling iterations 105 | exp_index: index used to name the expander-edge attribute on `data`.
106 | Returns: 107 | data: the input Data object, with expander edges attached. 108 | """ 109 | num_nodes = data.num_nodes 110 | 111 | if rng is None: 112 | rng = np.random.default_rng() 113 | 114 | eig_val = -1 115 | eig_val_lower_bound = max(0, 2 * degree - 2 * math.sqrt(2 * degree - 1) - 0.1)  # Ramanujan-style spectral-gap target for a 2d-regular graph 116 | 117 | max_eig_val_so_far = -1 118 | max_senders = [] 119 | max_receivers = [] 120 | cur_iter = 1 121 | 122 | if num_nodes <= degree: 123 | degree = num_nodes - 1 124 | 125 | # If there are too few nodes, random graph generation will fail. In this case, we 126 | # add the complete graph instead. 127 | if num_nodes <= 10: 128 | for i in range(num_nodes): 129 | for j in range(num_nodes): 130 | if i != j: 131 | max_senders.append(i) 132 | max_receivers.append(j) 133 | else: 134 | while eig_val < eig_val_lower_bound and cur_iter <= max_num_iters: 135 | if algorithm == 'Random-d': 136 | senders, receivers = generate_random_regular_graph1(num_nodes, degree, rng) 137 | elif algorithm == 'Random-d-2': 138 | senders, receivers = generate_random_regular_graph2(num_nodes, degree, rng) 139 | elif algorithm == 'Hamiltonian': 140 | senders, receivers = generate_random_graph_with_hamiltonian_cycles(num_nodes, degree, rng) 141 | else: 142 | raise ValueError('prep.exp_algorithm should be one of Random-d, Random-d-2 or Hamiltonian') 143 | 144 | if num_nodes > 1e5:  # the eigenvalue check is too expensive here; accept the first sample 145 | max_senders = senders 146 | max_receivers = receivers 147 | break 148 | 149 | [eig_val, _] = laplacian_eigenv(senders, receivers, k=1, n=num_nodes) 150 | if len(eig_val) == 0: 151 | print("num_nodes = %d, degree = %d, cur_iter = %d, max_iters = %d, len(senders) = %d, len(receivers) = %d" % (num_nodes, degree, cur_iter, max_num_iters, len(senders), len(receivers))) 152 | eig_val = 0 153 | else: 154 | eig_val = eig_val[0] 155 | 156 | if eig_val > max_eig_val_so_far: 157 | max_eig_val_so_far = eig_val 158 | max_senders = senders 159 | max_receivers = receivers 160 | 161 | cur_iter += 1 162 | 163 | # eliminate self loops.
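# The permutation-based generators above can map a node to itself, so the
# sampled multigraph may contain self loops; keep only the index positions
# where sender != receiver. Dedicated self-loop edges with their own edge
# type are appended explicitly further below.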
164 | non_loops = [ 165 | *filter(lambda i: max_senders[i] != max_receivers[i], range(0, len(max_senders))) 166 | ] 167 | 168 | senders = np.array(max_senders)[non_loops] 169 | receivers = np.array(max_receivers)[non_loops] 170 | 171 | max_senders = torch.tensor(senders, dtype=torch.long).view(-1, 1)  # build from the loop-free lists 172 | max_receivers = torch.tensor(receivers, dtype=torch.long).view(-1, 1) 173 | expander_edges = torch.cat([max_senders, max_receivers], dim=1) 174 | 175 | 176 | data.edge_index = torch.cat([data.edge_index, expander_edges.t()], dim=1) 177 | num_exp_edges = expander_edges.shape[0] 178 | 179 | edge_type = torch.zeros(data.edge_index.shape[1], dtype=torch.long) 180 | edge_type[-num_exp_edges:] = 1 181 | 182 | # Adding self loops for complete Exphormer edges 183 | num_nodes = data.num_nodes 184 | self_loops = torch.arange(num_nodes, dtype=torch.long).unsqueeze(0).repeat(2, 1) 185 | data.edge_index = torch.cat((data.edge_index, self_loops), dim=1) 186 | self_loop_feats = torch.full((num_nodes,), 2, dtype=torch.long) 187 | edge_type = torch.cat((edge_type, self_loop_feats)) 188 | 189 | edge_type = edge_type.contiguous() 190 | 191 | if hasattr(data, 'edge_attr') and data.edge_attr is not None: 192 | num_new_edges = edge_type.shape[0] - data.edge_attr.shape[0] 193 | new_edge_attr = torch.zeros((num_new_edges, data.edge_attr.shape[1]), dtype=data.edge_attr.dtype) 194 | data.edge_attr = torch.cat((data.edge_attr, new_edge_attr), dim=0) 195 | edge_type_attr = edge_type.unsqueeze(1).to(data.edge_attr.dtype)  # [E] -> [E, 1] so the cat along dim=1 is valid 196 | data.edge_attr = torch.cat((edge_type_attr, data.edge_attr), dim=1) 197 | else: 198 | data.edge_attr = edge_type 199 | 200 | if cfg.prep.save_edges: 201 | edge_set_path = Path(f'EdgeSets/{cfg.dataset.name}') 202 | edge_set_path.mkdir(parents=True, exist_ok=True) 203 | np.save(edge_set_path / 'edges_exp.npy', data.edge_index.cpu().detach().numpy()) 204 | np.save(edge_set_path / 'edge_attr_exp.npy', data.edge_attr.cpu().detach().numpy()) 205 | else: 206 | if exp_index == 0: 207 | data.expander_edges = expander_edges 208 | else: 209 | attrname = f"expander_edges{exp_index}" 210 | setattr(data, attrname, expander_edges) 211 | 212 | 213 | return data 214 | 215 | def load_edges(data): 216 | edge_set_path = Path(f'EdgeSets/{cfg.dataset.name}') 217 | 218 | if cfg.prep.num_edge_sets == 1: 219 | edges = np.load(edge_set_path / f'edges_{cfg.prep.edge_set_name}.npy') 220 | edge_attr = np.load(edge_set_path / f'edge_attr_{cfg.prep.edge_set_name}.npy') 221 | data.edge_index = torch.from_numpy(edges) 222 | data.edge_attr = torch.from_numpy(edge_attr) 223 | else: 224 | for idx in range(cfg.prep.num_edge_sets): 225 | edges = np.load(edge_set_path / f'edges_{cfg.prep.edge_set_name}_{idx}.npy') 226 | edge_attr = np.load(edge_set_path / f'edge_attr_{cfg.prep.edge_set_name}_{idx}.npy') 227 | setattr(data, f'edge_index_{idx}', torch.from_numpy(edges)) 228 | setattr(data, f'edge_attr_{idx}', torch.from_numpy(edge_attr).long()) 229 | 230 | return data -------------------------------------------------------------------------------- /spexphormer/loader/dataset/coco_superpixels.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import shutil 4 | import pickle 5 | 6 | import torch 7 | from tqdm import tqdm 8 | from torch_geometric.data import (InMemoryDataset, Data, download_url, 9 | extract_zip) 10 | 11 | 12 | class COCOSuperpixels(InMemoryDataset): 13 | r"""The COCOSuperpixels dataset which contains image superpixels and a semantic segmentation label 14 | for each node
superpixel. 15 | 16 | Construction and Preparation: 17 | - The superpixels are extracted in a similar fashion as the MNIST and CIFAR10 superpixels. 18 | - In COCOSuperpixels, the number of superpixel nodes <=500. (Note that it was <=75 for MNIST and 19 | <=150 for CIFAR10.) 20 | - The labeling of each superpixel node is done with the same value of the original pixel ground 21 | truth that is on the mean coord of the superpixel node 22 | 23 | - Based on the COCO 2017 dataset. Original 24 | source `here`_. 25 | 26 | num_classes = 81 27 | 28 | COCO categories: 29 | person bicycle car motorcycle airplane bus train truck boat traffic light fire hydrant stop 30 | sign parking meter bench bird cat dog horse sheep cow elephant bear zebra giraffe backpack 31 | umbrella handbag tie suitcase frisbee skis snowboard sports ball kite baseball bat baseball 32 | glove skateboard surfboard tennis racket bottle wine glass cup fork knife spoon bowl banana 33 | apple sandwich orange broccoli carrot hot dog pizza donut cake chair couch potted plant bed 34 | dining table toilet tv laptop mouse remote keyboard cell phone microwave oven toaster sink 35 | refrigerator book clock vase scissors teddy bear hair drier toothbrush 36 | 37 | Splitting: 38 | - In the original image dataset there are only train and val splits. 39 | - For COCOSuperpixels, we maintain the original val split as the new test split, and divide the 40 | original train split into a new val split and a new train split. The resultant train, val and test splits 41 | have 113286, 5000, 5000 superpixel graphs. 42 | 43 | Args: 44 | root (string): Root directory where the dataset should be saved. 45 | name (string, optional): Option to select the graph construction format. 46 | If :obj: `"edge_wt_only_coord"`, the graphs are 8-nn graphs with the edge weights computed based on 47 | only spatial coordinates of superpixel nodes. 48 | If :obj: `"edge_wt_coord_feat"`, the graphs are 8-nn graphs with the edge weights computed based on 49 | a combination of spatial coordinates and feature values of superpixel nodes. 50 | If :obj: `"edge_wt_region_boundary"`, the graphs are region boundary graphs where two regions (i.e. 51 | superpixel nodes) have an edge between them if they share a boundary in the original image. 52 | (default: :obj:`"edge_wt_region_boundary"`) 53 | slic_compactness (int, optional): Option to select the compactness of the SLIC used for the superpixels 54 | (:obj:`10`, :obj:`30`). (default: :obj:`30`) 55 | transform (callable, optional): A function/transform that takes in an 56 | :obj:`torch_geometric.data.Data` object and returns a transformed 57 | version. The data object will be transformed before every access. 58 | (default: :obj:`None`) 59 | pre_transform (callable, optional): A function/transform that takes in 60 | an :obj:`torch_geometric.data.Data` object and returns a 61 | transformed version. The data object will be transformed before 62 | being saved to disk. (default: :obj:`None`) 63 | pre_filter (callable, optional): A function that takes in an 64 | :obj:`torch_geometric.data.Data` object and returns a boolean 65 | value, indicating whether the data object should be included in the 66 | final dataset.
(default: :obj:`None`) 67 | """ 68 | 69 | url = { 70 | 10: { 71 | 'edge_wt_only_coord': 'https://www.dropbox.com/s/prqizdep8gk0ndk/coco_superpixels_edge_wt_only_coord.zip?dl=1', 72 | 'edge_wt_coord_feat': 'https://www.dropbox.com/s/zftoyln1pkcshcg/coco_superpixels_edge_wt_coord_feat.zip?dl=1', 73 | 'edge_wt_region_boundary': 'https://www.dropbox.com/s/fhihfcyx2y978u8/coco_superpixels_edge_wt_region_boundary.zip?dl=1' 74 | }, 75 | 30: { 76 | 'edge_wt_only_coord': 'https://www.dropbox.com/s/hrbfkxmc5z9lsaz/coco_superpixels_edge_wt_only_coord.zip?dl=1', 77 | 'edge_wt_coord_feat': 'https://www.dropbox.com/s/4rfa2d5ij1gfu9b/coco_superpixels_edge_wt_coord_feat.zip?dl=1', 78 | 'edge_wt_region_boundary': 'https://www.dropbox.com/s/r6ihg1f4pmyjjy0/coco_superpixels_edge_wt_region_boundary.zip?dl=1' 79 | } 80 | } 81 | 82 | def __init__(self, root, name='edge_wt_region_boundary', slic_compactness=30, split='train', 83 | transform=None, pre_transform=None, pre_filter=None): 84 | self.name = name 85 | self.slic_compactness = slic_compactness 86 | assert split in ['train', 'val', 'test'] 87 | assert name in ['edge_wt_only_coord', 'edge_wt_coord_feat', 'edge_wt_region_boundary'] 88 | assert slic_compactness in [10, 30] 89 | super().__init__(root, transform, pre_transform, pre_filter) 90 | path = osp.join(self.processed_dir, f'{split}.pt') 91 | self.data, self.slices = torch.load(path) 92 | 93 | 94 | @property 95 | def raw_file_names(self): 96 | return ['train.pickle', 'val.pickle', 'test.pickle'] 97 | 98 | @property 99 | def raw_dir(self): 100 | return osp.join(self.root, 101 | 'slic_compactness_' + str(self.slic_compactness), 102 | self.name, 103 | 'raw') 104 | 105 | @property 106 | def processed_dir(self): 107 | return osp.join(self.root, 108 | 'slic_compactness_' + str(self.slic_compactness), 109 | self.name, 110 | 'processed') 111 | 112 | @property 113 | def processed_file_names(self): 114 | return ['train.pt', 'val.pt', 'test.pt'] 115 | 116 | def download(self): 117 | shutil.rmtree(self.raw_dir) 118 | path = download_url(self.url[self.slic_compactness][self.name], self.root) 119 | extract_zip(path, self.root) 120 | os.rename(osp.join(self.root, 'coco_superpixels_' + self.name), self.raw_dir) 121 | os.unlink(path) 122 | 123 | def label_remap(self): 124 | # Util function to remap the labels as the original label idxs are not contiguous 125 | 126 | original_label_ix = [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 127 | 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 128 | 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 129 | 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 130 | 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 131 | 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 132 | 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 133 | 87, 88, 89, 90] 134 | label_map = {} 135 | for i, key in enumerate(original_label_ix): 136 | label_map[key] = i 137 | 138 | return label_map 139 | 140 | def process(self): 141 | label_map = self.label_remap() 142 | for split in ['train', 'val', 'test']: 143 | with open(osp.join(self.raw_dir, f'{split}.pickle'), 'rb') as f: 144 | graphs = pickle.load(f) 145 | 146 | indices = range(len(graphs)) 147 | 148 | pbar = tqdm(total=len(indices)) 149 | pbar.set_description(f'Processing {split} dataset') 150 | 151 | data_list = [] 152 | for idx in indices: 153 | graph = graphs[idx] 154 | 155 | """ 156 | Each `graph` is a tuple (x, edge_attr, edge_index, y) 157 | Shape of x : [num_nodes, 14] 158 | Shape of edge_attr : [num_edges, 1] or [num_edges, 2] 159 | Shape of edge_index : [2, num_edges] 160 | Shape of y : 
[num_nodes] 161 | """ 162 | 163 | x = graph[0].to(torch.float) 164 | edge_attr = graph[1].to(torch.float) 165 | edge_index = graph[2] 166 | y = torch.LongTensor(graph[3]) 167 | 168 | # Label remapping. See self.label_remap() func 169 | for i, label in enumerate(y): 170 | y[i] = label_map[label.item()] 171 | 172 | data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, 173 | y=y) 174 | 175 | if self.pre_filter is not None and not self.pre_filter(data): 176 | continue 177 | 178 | if self.pre_transform is not None: 179 | data = self.pre_transform(data) 180 | 181 | data_list.append(data) 182 | pbar.update(1) 183 | 184 | pbar.close() 185 | 186 | torch.save(self.collate(data_list), 187 | osp.join(self.processed_dir, f'{split}.pt')) 188 | -------------------------------------------------------------------------------- /spexphormer/train/custom_train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | import os 4 | 5 | import numpy as np 6 | import torch 7 | 8 | from torch_geometric.graphgym.checkpoint import load_ckpt, save_ckpt, clean_ckpt 9 | from torch_geometric.graphgym.config import cfg 10 | from torch_geometric.graphgym.loss import compute_loss 11 | from torch_geometric.graphgym.register import register_train 12 | from torch_geometric.graphgym.utils.epoch import is_eval_epoch, is_ckpt_epoch 13 | 14 | from spexphormer.utils import cfg_to_dict, flatten_dict, make_wandb_name 15 | 16 | 17 | def train_epoch(logger, loader, model, optimizer, scheduler): 18 | model.train() 19 | optimizer.zero_grad() 20 | time_start = time.time() 21 | 22 | for iter, batch in enumerate(loader): 23 | batch.split = 'train' 24 | batch.to(torch.device(cfg.accelerator)) 25 | pred, true = model(batch) 26 | loss, pred_score = compute_loss(pred, true) 27 | _true = true.detach().to('cpu', non_blocking=True) 28 | _pred = pred_score.detach().to('cpu', non_blocking=True) 29 | loss.backward() 30 | 31 | if cfg.optim.clip_grad_norm: 32 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) 33 | optimizer.step() 34 | optimizer.zero_grad() 35 | logger.update_stats(true=_true, 36 | pred=_pred, 37 | loss=loss.detach().cpu().item(), 38 | lr=scheduler.get_last_lr()[0], 39 | time_used=time.time() - time_start, 40 | params=cfg.params, 41 | dataset_name=cfg.dataset.name) 42 | time_start = time.time() 43 | 44 | 45 | @torch.no_grad() 46 | def eval_epoch(logger, loader, model, split='val'): 47 | model.eval() 48 | time_start = time.time() 49 | for batch in loader: 50 | batch.split = split 51 | batch.to(torch.device(cfg.accelerator)) 52 | 53 | pred, true = model(batch) 54 | extra_stats = {} 55 | loss, pred_score = compute_loss(pred, true) 56 | _true = true.detach().to('cpu', non_blocking=True) 57 | _pred = pred_score.detach().to('cpu', non_blocking=True) 58 | logger.update_stats(true=_true, 59 | pred=_pred, 60 | loss=loss.detach().cpu().item(), 61 | lr=0, time_used=time.time() - time_start, 62 | params=cfg.params, 63 | dataset_name=cfg.dataset.name, 64 | **extra_stats) 65 | time_start = time.time() 66 | 67 | 68 | @register_train('custom_train') 69 | def custom_train_initial(loggers, loaders, model, optimizer, scheduler): 70 | """ 71 | Customized training pipeline. 
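    Registered with GraphGym under the name ``custom_train``; it is typically selected
    from a YAML config via the ``train.mode`` key (the exact key layout is assumed from
    GraphGym conventions rather than verified against this repository's configs).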
72 | 73 | Args: 74 | loggers: List of loggers 75 | loaders: List of loaders 76 | model: GNN model 77 | optimizer: PyTorch optimizer 78 | scheduler: PyTorch learning rate scheduler 79 | 80 | """ 81 | 82 | start_epoch = 0 83 | if cfg.train.auto_resume: 84 | start_epoch = load_ckpt(model, optimizer, scheduler, 85 | cfg.train.epoch_resume) 86 | if start_epoch == cfg.optim.max_epoch: 87 | logging.info('Checkpoint found, task already done') 88 | else: 89 | logging.info('Start from epoch %s', start_epoch) 90 | 91 | if cfg.wandb.use: 92 | try: 93 | import wandb 94 | except ImportError: 95 | raise ImportError('WandB is not installed.') 96 | if cfg.wandb.name == '': 97 | wandb_name = make_wandb_name(cfg) 98 | else: 99 | wandb_name = cfg.wandb.name 100 | run = wandb.init(entity=cfg.wandb.entity, project=cfg.wandb.project, 101 | name=wandb_name) 102 | run.config.update(cfg_to_dict(cfg)) 103 | 104 | num_splits = len(loggers) 105 | split_names = ['val', 'test'] 106 | full_epoch_times = [] 107 | perf = [[] for _ in range(num_splits)] 108 | for cur_epoch in range(start_epoch, cfg.optim.max_epoch): 109 | cfg.train.cur_epoch = cur_epoch 110 | 111 | start_time = time.perf_counter() 112 | train_epoch(loggers[0], loaders[0], model, optimizer, scheduler) 113 | perf[0].append(loggers[0].write_epoch(cur_epoch)) 114 | if is_eval_epoch(cur_epoch): 115 | for i in range(1, num_splits): 116 | eval_epoch(loggers[i], loaders[i], model, 117 | split=split_names[i - 1]) 118 | perf[i].append(loggers[i].write_epoch(cur_epoch)) 119 | else: 120 | for i in range(1, num_splits): 121 | perf[i].append(perf[i][-1]) 122 | 123 | val_perf = perf[1] 124 | if cfg.optim.scheduler == 'reduce_on_plateau': 125 | scheduler.step(val_perf[-1]['loss']) 126 | else: 127 | scheduler.step() 128 | full_epoch_times.append(time.perf_counter() - start_time) 129 | # Checkpoint with regular frequency (if enabled). 130 | if cfg.train.enable_ckpt and not cfg.train.ckpt_best \ 131 | and is_ckpt_epoch(cur_epoch): 132 | save_ckpt(model, optimizer, scheduler, cur_epoch) 133 | 134 | if cfg.wandb.use: 135 | run.log(flatten_dict(perf), step=cur_epoch) 136 | 137 | # Log current best stats on eval epoch. 138 | if is_eval_epoch(cur_epoch): 139 | best_epoch = np.array([vp['loss'] for vp in val_perf]).argmin() 140 | best_train = best_val = best_test = "" 141 | if cfg.metric_best != 'auto': 142 | # Select again based on val perf of `cfg.metric_best`. 143 | m = cfg.metric_best 144 | best_epoch = getattr(np.array([vp[m] for vp in val_perf]), 145 | cfg.metric_agg)() 146 | if m in perf[0][best_epoch]: 147 | best_train = f"train_{m}: {perf[0][best_epoch][m]:.4f}" 148 | else: 149 | # Note: For some datasets it is too expensive to compute 150 | # the main metric on the training set.
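# A constant 0 is then logged below as a stand-in, so the log line keeps its format.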
151 | best_train = f"train_{m}: {0:.4f}" 152 | best_val = f"val_{m}: {perf[1][best_epoch][m]:.4f}" 153 | best_test = f"test_{m}: {perf[2][best_epoch][m]:.4f}" 154 | 155 | if cfg.wandb.use: 156 | bstats = {"best/epoch": best_epoch} 157 | for i, s in enumerate(['train', 'val', 'test']): 158 | bstats[f"best/{s}_loss"] = perf[i][best_epoch]['loss'] 159 | if m in perf[i][best_epoch]: 160 | bstats[f"best/{s}_{m}"] = perf[i][best_epoch][m] 161 | run.summary[f"best_{s}_perf"] = \ 162 | perf[i][best_epoch][m] 163 | for x in ['hits@1', 'hits@3', 'hits@10', 'mrr']: 164 | if x in perf[i][best_epoch]: 165 | bstats[f"best/{s}_{x}"] = perf[i][best_epoch][x] 166 | run.log(bstats, step=cur_epoch) 167 | run.summary["full_epoch_time_avg"] = np.mean(full_epoch_times) 168 | run.summary["full_epoch_time_sum"] = np.sum(full_epoch_times) 169 | # Checkpoint the best epoch params (if enabled). 170 | if cfg.train.enable_ckpt and cfg.train.ckpt_best and \ 171 | best_epoch == cur_epoch: 172 | save_ckpt(model, optimizer, scheduler, cur_epoch) 173 | if cfg.train.ckpt_clean: # Delete old ckpt each time. 174 | clean_ckpt() 175 | logging.info( 176 | f"> Epoch {cur_epoch}: took {full_epoch_times[-1]:.1f}s " 177 | f"(avg {np.mean(full_epoch_times):.1f}s) | " 178 | f"Best so far: epoch {best_epoch}\t" 179 | f"train_loss: {perf[0][best_epoch]['loss']:.4f} {best_train}\t" 180 | f"val_loss: {perf[1][best_epoch]['loss']:.4f} {best_val}\t" 181 | f"test_loss: {perf[2][best_epoch]['loss']:.4f} {best_test}" 182 | ) 183 | if hasattr(model, 'trf_layers'): 184 | # Log SAN's gamma parameter values if they are trainable. 185 | for li, gtl in enumerate(model.trf_layers): 186 | if torch.is_tensor(gtl.attention.gamma) and \ 187 | gtl.attention.gamma.requires_grad: 188 | logging.info(f" {gtl.__class__.__name__} {li}: " 189 | f"gamma={gtl.attention.gamma.item()}") 190 | logging.info(f"Avg time per epoch: {np.mean(full_epoch_times):.2f}s") 191 | logging.info(f"Total train loop time: {np.sum(full_epoch_times) / 3600:.2f}h") 192 | 193 | 194 | if cfg.train.save_attention_scores: 195 | cfg.train.saving_epoch=True 196 | load_ckpt(model) 197 | model.eval() 198 | cfg.train.cur_epoch = int(best_epoch) 199 | os.makedirs(f'Attention_scores/{cfg.dataset.name}', exist_ok=True) 200 | with open(f'Attention_scores/{cfg.dataset.name}/edges.npy', 'wb') as f: 201 | np.save(f, loaders[0].dataset.data.edge_index.detach().numpy()) 202 | with open(f'Attention_scores/{cfg.dataset.name}/edge_attr.npy', 'wb') as f: 203 | np.save(f, loaders[0].dataset.data.edge_attr.detach().numpy()) 204 | with torch.no_grad(): 205 | for batch in loaders[1]: 206 | batch.split = 'val' 207 | batch.to(torch.device(cfg.accelerator)) 208 | _ = model(batch) 209 | cfg.train.saving_epoch=False 210 | 211 | for logger in loggers: 212 | logger.close() 213 | if cfg.train.ckpt_clean: 214 | clean_ckpt() 215 | # close wandb 216 | if cfg.wandb.use: 217 | run.finish() 218 | run = None 219 | 220 | logging.info('Task done, results saved in %s', cfg.run_dir) -------------------------------------------------------------------------------- /spexphormer/loader/split_generator.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | 5 | import numpy as np 6 | from sklearn.model_selection import KFold, StratifiedKFold, ShuffleSplit 7 | from torch_geometric.graphgym.config import cfg 8 | from torch_geometric.graphgym.loader import index2mask, set_dataset_attr 9 | import torch 10 | 11 | def prepare_splits(dataset): 12 | """Ready 
train/val/test splits. 13 | 14 | Determine the type of split from the config and call the corresponding 15 | split generation / verification function. 16 | """ 17 | split_mode = cfg.dataset.split_mode 18 | 19 | if split_mode == 'standard': 20 | setup_standard_split(dataset) 21 | elif split_mode == 'random': 22 | setup_random_split(dataset) 23 | elif split_mode.startswith('cv-'): 24 | cv_type, k = split_mode.split('-')[1:] 25 | setup_cv_split(dataset, cv_type, int(k)) 26 | else: 27 | raise ValueError(f"Unknown split mode: {split_mode}") 28 | 29 | 30 | def setup_standard_split(dataset): 31 | """Select a standard split. 32 | 33 | Use standard splits that come with the dataset. Pick one split based on the 34 | ``split_index`` from the config file if multiple splits are available. 35 | 36 | GNNBenchmarkDatasets have splits that are not prespecified as masks. Therefore, 37 | they are handled differently and are first processed to generate the masks. 38 | 39 | Raises: 40 | ValueError: If any one of the train/val/test masks is missing. 41 | IndexError: If the ``split_index`` is greater than or equal to the total 42 | number of splits available. 43 | """ 44 | split_index = cfg.dataset.split_index 45 | task_level = cfg.dataset.task 46 | 47 | if task_level == 'node': 48 | for split_name in 'train_mask', 'val_mask', 'test_mask': 49 | mask = getattr(dataset.data, split_name, None) 50 | # Check if the train/val/test split mask is available 51 | if mask is None: 52 | raise ValueError(f"Missing '{split_name}' for standard split") 53 | 54 | # Pick a specific split if multiple splits are available 55 | if mask.dim() == 2: 56 | if split_index >= mask.shape[1]: 57 | raise IndexError(f"Specified split index ({split_index}) is " 58 | f"out of range of the number of available " 59 | f"splits ({mask.shape[1]}) for {split_name}") 60 | set_dataset_attr(dataset, split_name, mask[:, split_index], 61 | len(mask[:, split_index])) 62 | else: 63 | if split_index != 0: 64 | raise IndexError("This dataset has a single standard split") 65 | 66 | elif task_level == 'graph': 67 | for split_name in 'train_graph_index', 'val_graph_index', 'test_graph_index': 68 | if not hasattr(dataset.data, split_name): 69 | raise ValueError(f"Missing '{split_name}' for standard split") 70 | if split_index != 0: 71 | raise NotImplementedError(f"Multiple standard splits not supported " 72 | f"for dataset task level: {task_level}") 73 | 74 | elif task_level == 'link_pred': 75 | for split_name in 'train_edge_index', 'val_edge_index', 'test_edge_index': 76 | if not hasattr(dataset.data, split_name): 77 | raise ValueError(f"Missing '{split_name}' for standard split") 78 | if split_index != 0: 79 | raise NotImplementedError(f"Multiple standard splits not supported " 80 | f"for dataset task level: {task_level}") 81 | 82 | else: 83 | if split_index != 0: 84 | raise NotImplementedError(f"Multiple standard splits not supported " 85 | f"for dataset task level: {task_level}") 86 | 87 | 88 | def setup_random_split(dataset): 89 | """Generate random splits. 90 | 91 | Generate random train/val/test splits based on the ratios defined in the config 92 | file. 93 | 94 | Raises: 95 | ValueError: If the number of split ratios is not equal to 3, or the ratios 96 | do not sum up to 1.
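    Example (illustrative ratios): with ``cfg.dataset.split = [0.5, 0.25, 0.25]``,
    ShuffleSplit below first reserves 50% of the samples for train, then splits the
    remainder with train_size = 0.25 / (1 - 0.5) = 0.5, i.e. 50:50 into val and test.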
97 | """ 98 | split_ratios = cfg.dataset.split 99 | 100 | # seed = cfg.seed 101 | seed = 0 102 | 103 | if len(split_ratios) != 3: 104 | raise ValueError( 105 | f"Three split ratios is expected for train/val/test, received " 106 | f"{len(split_ratios)} split ratios: {repr(split_ratios)}") 107 | elif sum(split_ratios) != 1: 108 | raise ValueError( 109 | f"The train/val/test split ratios must sum up to 1, input ratios " 110 | f"sum up to {sum(split_ratios):.2f} instead: {repr(split_ratios)}") 111 | 112 | train_index, val_test_index = next( 113 | ShuffleSplit( 114 | train_size=split_ratios[0], 115 | random_state=seed 116 | ).split(dataset.data.y, dataset.data.y) 117 | ) 118 | val_index, test_index = next( 119 | ShuffleSplit( 120 | train_size=split_ratios[1] / (1 - split_ratios[0]), 121 | random_state=seed 122 | ).split(dataset.data.y[val_test_index], dataset.data.y[val_test_index]) 123 | ) 124 | val_index = val_test_index[val_index] 125 | test_index = val_test_index[test_index] 126 | 127 | set_dataset_splits(dataset, [train_index, val_index, test_index]) 128 | 129 | 130 | def set_dataset_splits(dataset, splits): 131 | """Set given splits to the dataset object. 132 | 133 | Args: 134 | dataset: PyG dataset object 135 | splits: List of train/val/test split indices 136 | 137 | Raises: 138 | ValueError: If any pair of splits has intersecting indices 139 | """ 140 | # First check whether splits intersect and raise error if so. 141 | for i in range(len(splits) - 1): 142 | for j in range(i + 1, len(splits)): 143 | n_intersect = len(set(splits[i]) & set(splits[j])) 144 | if n_intersect != 0: 145 | raise ValueError( 146 | f"Splits must not have intersecting indices: " 147 | f"split #{i} (n = {len(splits[i])}) and " 148 | f"split #{j} (n = {len(splits[j])}) have " 149 | f"{n_intersect} intersecting indices" 150 | ) 151 | task_level = cfg.dataset.task 152 | if task_level == 'node': 153 | split_names = ['train_mask', 'val_mask', 'test_mask'] 154 | for split_name, split_index in zip(split_names, splits): 155 | mask = index2mask(torch.LongTensor(split_index), size=dataset.data.y.shape[0]) 156 | set_dataset_attr(dataset, split_name, mask, len(mask)) 157 | 158 | elif task_level == 'graph': 159 | split_names = [ 160 | 'train_graph_index', 'val_graph_index', 'test_graph_index' 161 | ] 162 | for split_name, split_index in zip(split_names, splits): 163 | set_dataset_attr(dataset, split_name, split_index, len(split_index)) 164 | 165 | else: 166 | raise ValueError(f"Unsupported dataset task level: {task_level}") 167 | 168 | 169 | def setup_cv_split(dataset, cv_type, k): 170 | """Generate cross-validation splits. 171 | 172 | Generate `k` folds for cross-validation based on `cv_type` procedure. Save 173 | these to disk or load existing splits, then select particular train/val/test 174 | split based on cfg.dataset.split_index from the config object. 
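    For example (illustrative values): with k=5 and split_index=2, fold 2 becomes the
    test set, fold 3 (= (2 + 1) % 5) becomes the val set, and folds 0, 1 and 4 are
    merged into the train set.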
175 | 176 | Args: 177 | dataset: PyG dataset object 178 | cv_type: Identifier for which sklearn fold splitter to use 179 | k: how many cross-validation folds to split the dataset into 180 | 181 | Raises: 182 | IndexError: If the `split_index` is greater than or equal to `k` 183 | """ 184 | split_index = cfg.dataset.split_index 185 | split_dir = cfg.dataset.split_dir 186 | 187 | if split_index >= k: 188 | raise IndexError(f"Specified split_index={split_index} is " 189 | f"out of range of the number of folds k={k}") 190 | 191 | os.makedirs(split_dir, exist_ok=True) 192 | save_file = os.path.join( 193 | split_dir, 194 | f"{cfg.dataset.format}_{dataset.name}_{cv_type}-{k}.json" 195 | ) 196 | if not os.path.isfile(save_file): 197 | create_cv_splits(dataset, cv_type, k, save_file) 198 | with open(save_file) as f: 199 | cv = json.load(f) 200 | assert cv['dataset'] == dataset.name, "Unexpected dataset CV splits" 201 | assert cv['n_samples'] == len(dataset), "Dataset length does not match" 202 | assert cv['n_splits'] > split_index, "Fold selection out of range" 203 | assert k == cv['n_splits'], f"Expected k={k}, but {cv['n_splits']} found" 204 | 205 | test_ids = cv[str(split_index)] 206 | val_ids = cv[str((split_index + 1) % k)] 207 | train_ids = [] 208 | for i in range(k): 209 | if i != split_index and i != (split_index + 1) % k: 210 | train_ids.extend(cv[str(i)]) 211 | 212 | set_dataset_splits(dataset, [train_ids, val_ids, test_ids]) 213 | 214 | 215 | def create_cv_splits(dataset, cv_type, k, file_name): 216 | """Create cross-validation splits and save them to file. 217 | """ 218 | n_samples = len(dataset) 219 | if cv_type == 'stratifiedkfold': 220 | kf = StratifiedKFold(n_splits=k, shuffle=True, random_state=123) 221 | kf_split = kf.split(np.zeros(n_samples), dataset.data.y) 222 | elif cv_type == 'kfold': 223 | kf = KFold(n_splits=k, shuffle=True, random_state=123) 224 | kf_split = kf.split(np.zeros(n_samples)) 225 | else: 226 | raise ValueError(f"Unexpected cross-validation type: {cv_type}") 227 | 228 | splits = {'n_samples': n_samples, 229 | 'n_splits': k, 230 | 'cross_validator': str(kf), 231 | 'dataset': dataset.name 232 | } 233 | for i, (_, ids) in enumerate(kf_split): 234 | splits[i] = ids.tolist() 235 | with open(file_name, 'w') as f: 236 | json.dump(splits, f) 237 | logging.info(f"[*] Saved newly generated CV splits by {kf} to {file_name}") 238 | --------------------------------------------------------------------------------
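A small self-contained sketch of the expander generators from spexphormer/transform/expander_edges.py (illustrative: the node count, degree and seed are arbitrary, and it assumes the package and its torch_geometric dependency are importable):

import numpy as np
from spexphormer.transform.expander_edges import generate_random_regular_graph1

rng = np.random.default_rng(0)
senders, receivers = generate_random_regular_graph1(num_nodes=8, degree=3, rng=rng)

# The returned edge list is symmetric: every (u, v) also appears as (v, u), and each
# node occurs 2 * degree = 6 times as a sender (multigraph degree, loops included).
edges = set(zip(senders.tolist(), receivers.tolist()))
assert all((v, u) in edges for (u, v) in edges)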