├── .gitignore ├── LICENSE ├── figure └── mse.png ├── readme.md └── source ├── __init__.py ├── __main__.py ├── components ├── __init__.py ├── logger.py ├── lr_scheduler.py └── optimizer.py ├── conf ├── config.yaml ├── dataset │ ├── ABCD.yaml │ └── ABIDE.yaml ├── datasz │ ├── 100p.yaml │ ├── 10p.yaml │ ├── 20p.yaml │ ├── 30p.yaml │ ├── 40p.yaml │ ├── 50p.yaml │ ├── 60p.yaml │ ├── 70p.yaml │ ├── 80p.yaml │ └── 90p.yaml ├── model │ ├── bnt.yaml │ ├── brainnetcnn.yaml │ ├── fbnetgen.yaml │ └── transformer.yaml ├── optimizer │ └── adam.yaml ├── preprocess │ ├── mixup.yaml │ └── non_mixup.yaml └── training │ └── basic_training.yaml ├── dataset ├── __init__.py ├── abcd.py ├── abide.py ├── dataloader.py └── preprocess.py ├── models ├── BNT │ ├── __init__.py │ ├── bnt.py │ ├── components │ │ ├── __init__.py │ │ └── transformer_encoder.py │ └── ptdec │ │ ├── __init__.py │ │ ├── cluster.py │ │ └── dec.py ├── __init__.py ├── base.py ├── brainnetcnn.py ├── fbnetgen.py └── transformer.py ├── training ├── FBNettraining.py ├── __init__.py └── training.py └── utils ├── __init__.py ├── accuracy.py ├── count_params.py ├── gumbel_softmax.py ├── hyperboloid.py ├── meter.py └── prepossess.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | *__pycache__* 3 | *.npy 4 | *.pt 5 | wandb/* 6 | result/* 7 | outputs/* 8 | multirun/* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Xuan Kan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /figure/mse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wayfear/BrainNetworkTransformer/8a588aadad0166209269fa114e5df4e42209e207/figure/mse.png -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Brain Network Transformer 2 | 3 | Brain Network Transformer is the open-source implementation of the NeurIPS 2022 paper [Brain Network Transformer](https://arxiv.org/abs/2210.06681). 
4 | 5 | 6 | [![stars - BrainNetworkTransformer](https://img.shields.io/github/stars/Wayfear/BrainNetworkTransformer?style=social)](https://github.com/Wayfear/BrainNetworkTransformer) 7 | [![forks - BrainNetworkTransformer](https://img.shields.io/github/forks/Wayfear/BrainNetworkTransformer?style=social)](https://github.com/Wayfear/BrainNetworkTransformer) 8 | ![language](https://img.shields.io/github/languages/top/Wayfear/BrainNetworkTransformer?color=lightgrey) 9 | ![lines](https://img.shields.io/tokei/lines/github/Wayfear/BrainNetworkTransformer?color=red) 10 | ![license](https://img.shields.io/github/license/Wayfear/BrainNetworkTransformer) 11 | --- 12 | 13 | 14 | ## Dataset 15 | 16 | Download the ABIDE dataset from [here](https://drive.google.com/file/d/14UGsikYH_SQ-d_GvY2Um2oEHw3WNxDY3/view?usp=sharing). 17 | 18 | ## Usage 19 | 20 | 1. Change the *path* attribute in *source/conf/dataset/ABIDE.yaml* to the path of your dataset. 21 | 22 | 2. Run the following command to train the models. 23 | 24 | ```bash 25 | python -m source --multirun datasz=100p model=bnt,fbnetgen,brainnetcnn,transformer dataset=ABIDE,ABCD repeat_time=5 preprocess=mixup 26 | ``` 27 | 28 | - **datasz**, default=(10p, 20p, 30p, 40p, 50p, 60p, 70p, 80p, 90p, 100p). How much data to use for training. The value is a percentage of the training split; for example, 10p means that 10% of the training samples are used. 29 | 30 | - **model**, default=(bnt,fbnetgen,brainnetcnn,transformer). Which model to train. The value is a list of model names: bnt is Brain Network Transformer, fbnetgen is FBNetGen, brainnetcnn is BrainNetCNN, and transformer is VanillaTF. 31 | 32 | - **dataset**, default=(ABIDE,ABCD). Which dataset(s) to use. The value is a list of dataset names (ABIDE and/or ABCD). 33 | 34 | - **repeat_time**, default=5. How many times to repeat each experiment. The value is an integer; for example, 5 means each configuration is run 5 times. 35 | 36 | - **preprocess**, default=(mixup, non_mixup). Which preprocessing to apply. The value is a list of preprocessing names: mixup applies continuous mixup augmentation, while non_mixup feeds the data into the models without this augmentation. 37 | 38 | 39 | ## Installation 40 | 41 | ```bash 42 | conda create --name bnt python=3.9 43 | conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch 44 | conda install -c conda-forge wandb 45 | pip install hydra-core --upgrade 46 | conda install -c conda-forge scikit-learn 47 | conda install -c conda-forge pandas 48 | ``` 49 | 50 | 51 | ## Dependencies 52 | 53 | - python=3.9 54 | - cudatoolkit=11.3 55 | - torchvision=0.13.1 56 | - pytorch=1.12.1 57 | - torchaudio=0.12.1 58 | - wandb=0.13.1 59 | - scikit-learn=1.1.1 60 | - pandas=1.4.3 61 | - hydra-core=1.2.0 62 | 63 | 64 | ## Regression Performance 65 | 66 | We also report regression performance, which is not included in the paper. The results are the test MSE for predicting the NIH Toolbox Picture Vocabulary Test Age 3+ v2.0 Uncorrected Standard Score, listed as "nihtbx_picvocab_uncorrected" on this [page](https://nda.nih.gov/data_structure.html?short_name=tlbx_cogsum01). As the figure below shows, BNT (50.3) performs best among these models by a large margin. 
67 | 68 | ![mse](figure/mse.png) 69 | 70 | 71 | ## Citation 72 | 73 | Please cite our paper if you find this code useful for your work: 74 | ```bibtex 75 | @inproceedings{ 76 | kan2022bnt, 77 | title={BRAIN NETWORK TRANSFORMER}, 78 | author={Xuan Kan and Wei Dai and Hejie Cui and Zilong Zhang and Ying Guo and Carl Yang}, 79 | booktitle={Advances in Neural Information Processing Systems}, 80 | year={2022}, 81 | } 82 | ``` 83 | -------------------------------------------------------------------------------- /source/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wayfear/BrainNetworkTransformer/8a588aadad0166209269fa114e5df4e42209e207/source/__init__.py -------------------------------------------------------------------------------- /source/__main__.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import wandb 3 | import hydra 4 | from omegaconf import DictConfig, open_dict 5 | from .dataset import dataset_factory 6 | from .models import model_factory 7 | from .components import lr_scheduler_factory, optimizers_factory, logger_factory 8 | from .training import training_factory 9 | from datetime import datetime 10 | 11 | 12 | def model_training(cfg: DictConfig): 13 | 14 | with open_dict(cfg): 15 | cfg.unique_id = datetime.now().strftime("%m-%d-%H-%M-%S") 16 | 17 | dataloaders = dataset_factory(cfg) 18 | logger = logger_factory(cfg) 19 | model = model_factory(cfg) 20 | optimizers = optimizers_factory( 21 | model=model, optimizer_configs=cfg.optimizer) 22 | lr_schedulers = lr_scheduler_factory(lr_configs=cfg.optimizer, 23 | cfg=cfg) 24 | training = training_factory(cfg, model, optimizers, 25 | lr_schedulers, dataloaders, logger) 26 | 27 | training.train() 28 | 29 | 30 | @hydra.main(version_base=None, config_path="conf", config_name="config") 31 | def main(cfg: DictConfig): 32 | 33 | group_name = f"{cfg.dataset.name}_{cfg.model.name}_{cfg.datasz.percentage}_{cfg.preprocess.name}" 34 | # _{cfg.training.name}\ 35 | # _{cfg.optimizer[0].lr_scheduler.mode}" 36 | 37 | for _ in range(cfg.repeat_time): 38 | run = wandb.init(project=cfg.project, entity=cfg.wandb_entity, reinit=True, 39 | group=f"{group_name}", tags=[f"{cfg.dataset.name}"]) 40 | model_training(cfg) 41 | 42 | run.finish() 43 | 44 | 45 | if __name__ == '__main__': 46 | main() 47 | -------------------------------------------------------------------------------- /source/components/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import initialize_logger, logger_factory 2 | from .lr_scheduler import LRScheduler, lr_scheduler_factory 3 | from .optimizer import optimizer_factory, optimizers_factory 4 | -------------------------------------------------------------------------------- /source/components/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | from typing import Tuple 4 | from omegaconf import DictConfig 5 | 6 | 7 | def get_formatter() -> logging.Formatter: 8 | return logging.Formatter('[%(asctime)s][%(filename)s][L%(lineno)d][%(levelname)s] %(message)s') 9 | 10 | 11 | def initialize_logger() -> logging.Logger: 12 | logger = logging.getLogger() 13 | logger.setLevel(logging.INFO) 14 | for handler in logger.handlers: 15 | handler.close() 16 | logger.handlers.clear() 17 | 18 | formatter = get_formatter() 19 | stream_handler = 
logging.StreamHandler() 20 | stream_handler.setFormatter(formatter) 21 | logger.addHandler(stream_handler) 22 | 23 | return logger 24 | 25 | 26 | def set_file_handler(log_file_path: Path) -> logging.Logger: 27 | logger = initialize_logger() 28 | formatter = get_formatter() 29 | file_handler = logging.FileHandler(str(log_file_path)) 30 | file_handler.setFormatter(formatter) 31 | logger.addHandler(file_handler) 32 | 33 | return logger 34 | 35 | 36 | def logger_factory(config: DictConfig) -> Tuple[logging.Logger]: 37 | log_path = Path(config.log_path) / config.unique_id 38 | log_path.mkdir(exist_ok=True, parents=True) 39 | logger = set_file_handler(log_file_path=log_path 40 | / config.unique_id) 41 | return logger 42 | -------------------------------------------------------------------------------- /source/components/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import math 3 | from typing import List 4 | from omegaconf import DictConfig 5 | import torch 6 | 7 | 8 | class LRScheduler: 9 | def __init__(self, cfg: DictConfig, optimizer_cfg: DictConfig): 10 | self.lr_config = optimizer_cfg.lr_scheduler 11 | self.training_config = cfg 12 | self.lr = optimizer_cfg.lr 13 | 14 | assert self.lr_config.mode in [ 15 | 'step', 'poly', 'cos', 'linear', 'decay'] 16 | 17 | def update(self, optimizer: torch.optim.Optimizer, step: int): 18 | lr_config = self.lr_config 19 | lr_mode = lr_config.mode 20 | base_lr = lr_config.base_lr 21 | target_lr = lr_config.target_lr 22 | 23 | warm_up_from = lr_config.warm_up_from 24 | warm_up_steps = lr_config.warm_up_steps 25 | total_steps = self.training_config.total_steps 26 | 27 | assert 0 <= step <= total_steps 28 | if step < warm_up_steps: 29 | current_ratio = step / warm_up_steps 30 | self.lr = warm_up_from + (base_lr - warm_up_from) * current_ratio 31 | else: 32 | current_ratio = (step - warm_up_steps) / \ 33 | (total_steps - warm_up_steps) 34 | if lr_mode == 'step': 35 | count = bisect.bisect_left(lr_config.milestones, current_ratio) 36 | self.lr = base_lr * pow(lr_config.decay_factor, count) 37 | elif lr_mode == 'poly': 38 | poly = pow(1 - current_ratio, lr_config.poly_power) 39 | self.lr = target_lr + (base_lr - target_lr) * poly 40 | elif lr_mode == 'cos': 41 | cosine = math.cos(math.pi * current_ratio) 42 | self.lr = target_lr + (base_lr - target_lr) * (1 + cosine) / 2 43 | elif lr_mode == 'linear': 44 | self.lr = target_lr + \ 45 | (base_lr - target_lr) * (1 - current_ratio) 46 | elif lr_mode == 'decay': 47 | epoch = step // self.training_config.steps_per_epoch 48 | self.lr = base_lr * lr_config.lr_decay ** epoch 49 | 50 | for param_group in optimizer.param_groups: 51 | param_group['lr'] = self.lr 52 | 53 | 54 | def lr_scheduler_factory(lr_configs: List[DictConfig], cfg: DictConfig) -> List[LRScheduler]: 55 | return [LRScheduler(cfg=cfg, optimizer_cfg=lr_config) for lr_config in lr_configs] 56 | -------------------------------------------------------------------------------- /source/components/optimizer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import defaultdict 3 | from typing import List 4 | from omegaconf import DictConfig 5 | import torch 6 | 7 | 8 | def get_param_group_no_wd(model: torch.nn.Module, match_rule: str = None, except_rule: str = None): 9 | param_group_no_wd = [] 10 | names_no_wd = [] 11 | param_group_normal = [] 12 | 13 | type2num = defaultdict(lambda: 0) 14 | for name, m in 
model.named_modules(): 15 | if match_rule is not None and match_rule not in name: 16 | continue 17 | if except_rule is not None and except_rule in name: 18 | continue 19 | if isinstance(m, torch.nn.Conv2d): 20 | if m.bias is not None: 21 | param_group_no_wd.append(m.bias) 22 | names_no_wd.append(name + '.bias') 23 | type2num[m.__class__.__name__ + '.bias'] += 1 24 | elif isinstance(m, torch.nn.Linear): 25 | if m.bias is not None: 26 | param_group_no_wd.append(m.bias) 27 | names_no_wd.append(name + '.bias') 28 | type2num[m.__class__.__name__ + '.bias'] += 1 29 | elif isinstance(m, torch.nn.BatchNorm2d) \ 30 | or isinstance(m, torch.nn.BatchNorm1d): 31 | if m.weight is not None: 32 | param_group_no_wd.append(m.weight) 33 | names_no_wd.append(name + '.weight') 34 | type2num[m.__class__.__name__ + '.weight'] += 1 35 | if m.bias is not None: 36 | param_group_no_wd.append(m.bias) 37 | names_no_wd.append(name + '.bias') 38 | type2num[m.__class__.__name__ + '.bias'] += 1 39 | 40 | for name, p in model.named_parameters(): 41 | if match_rule is not None and match_rule not in name: 42 | continue 43 | if except_rule is not None and except_rule in name: 44 | continue 45 | if name not in names_no_wd: 46 | param_group_normal.append(p) 47 | 48 | params_length = len(param_group_normal) + len(param_group_no_wd) 49 | logging.info(f'Parameters [no weight decay] length [{params_length}]') 50 | return [{'params': param_group_normal}, {'params': param_group_no_wd, 'weight_decay': 0.0}], type2num 51 | 52 | 53 | def optimizer_factory(model: torch.nn.Module, optimizer_config: DictConfig) -> torch.optim.Optimizer: 54 | parameters = { 55 | 'lr': 0.0, 56 | 'weight_decay': optimizer_config.weight_decay 57 | } 58 | 59 | if optimizer_config.no_weight_decay: 60 | params, _ = get_param_group_no_wd(model, 61 | match_rule=optimizer_config.match_rule, 62 | except_rule=optimizer_config.except_rule) 63 | else: 64 | params = list(model.parameters()) 65 | logging.info(f'Parameters [normal] length [{len(params)}]') 66 | 67 | parameters['params'] = params 68 | 69 | optimizer_type = optimizer_config.name 70 | if optimizer_type == 'SGD': 71 | parameters['momentum'] = optimizer_config.momentum 72 | parameters['nesterov'] = optimizer_config.nesterov 73 | return getattr(torch.optim, optimizer_type)(**parameters) 74 | 75 | 76 | def optimizers_factory(model: torch.nn.Module, optimizer_configs: List[DictConfig]) -> List[torch.optim.Optimizer]: 77 | if model is None: 78 | return None 79 | return [optimizer_factory(model=model, optimizer_config=single_config) for single_config in optimizer_configs] 80 | -------------------------------------------------------------------------------- /source/conf/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: ABCD 3 | - model: bnt # brainnetcnn, fbnetgen, bnt, transformer 4 | - optimizer: adam 5 | - training: basic_training 6 | - datasz: 100p 7 | - preprocess: mixup 8 | 9 | repeat_time: 5 10 | log_path: result 11 | save_learnable_graph: False 12 | 13 | # wandb 14 | wandb_entity: eggroup 15 | project: brainnetworktransformer -------------------------------------------------------------------------------- /source/conf/dataset/ABCD.yaml: -------------------------------------------------------------------------------- 1 | name: abcd 2 | batch_size: 16 3 | test_batch_size: 16 4 | val_batch_size: 16 5 | train_set: 0.7 6 | val_set: 0.1 7 | node_feature: /local/scratch/xkan/ABCD/ABCD/abcd_rest-pearson-HCP2016.npy 8 | time_seires: 
/local/scratch/xkan/ABCD/ABCD/abcd_rest-timeseires-HCP2016.npy 9 | node_id: /local/scratch/xkan/ABCD/ABCD/ids_HCP2016.txt 10 | seires_id: /local/scratch/xkan/ABCD/ABCD/ids_HCP2016_timeseires.txt 11 | label: /local/scratch/xkan/ABCD/ABCD/id2sex.txt 12 | drop_last: True 13 | stratified: False -------------------------------------------------------------------------------- /source/conf/dataset/ABIDE.yaml: -------------------------------------------------------------------------------- 1 | name: abide 2 | batch_size: 16 3 | test_batch_size: 16 4 | val_batch_size: 16 5 | train_set: 0.7 6 | val_set: 0.1 7 | path: /local/home/xkan/ABIDE/abide.npy 8 | stratified: True 9 | drop_last: True -------------------------------------------------------------------------------- /source/conf/datasz/100p.yaml: -------------------------------------------------------------------------------- 1 | percentage: 1. 2 | -------------------------------------------------------------------------------- /source/conf/datasz/10p.yaml: -------------------------------------------------------------------------------- 1 | percentage: 0.1 2 | -------------------------------------------------------------------------------- /source/conf/datasz/20p.yaml: -------------------------------------------------------------------------------- 1 | percentage: 0.2 2 | -------------------------------------------------------------------------------- /source/conf/datasz/30p.yaml: -------------------------------------------------------------------------------- 1 | percentage: 0.3 2 | -------------------------------------------------------------------------------- /source/conf/datasz/40p.yaml: -------------------------------------------------------------------------------- 1 | percentage: 0.4 2 | -------------------------------------------------------------------------------- /source/conf/datasz/50p.yaml: -------------------------------------------------------------------------------- 1 | percentage: 0.5 2 | -------------------------------------------------------------------------------- /source/conf/datasz/60p.yaml: -------------------------------------------------------------------------------- 1 | percentage: 0.6 2 | -------------------------------------------------------------------------------- /source/conf/datasz/70p.yaml: -------------------------------------------------------------------------------- 1 | percentage: 0.7 2 | -------------------------------------------------------------------------------- /source/conf/datasz/80p.yaml: -------------------------------------------------------------------------------- 1 | percentage: 0.8 2 | -------------------------------------------------------------------------------- /source/conf/datasz/90p.yaml: -------------------------------------------------------------------------------- 1 | percentage: 0.9 2 | -------------------------------------------------------------------------------- /source/conf/model/bnt.yaml: -------------------------------------------------------------------------------- 1 | # seq, gnn, fbnetgen 2 | name: BrainNetworkTransformer 3 | sizes: [360, 100] # Note: The input node size should not be included here 4 | pooling: [false, true] 5 | pos_encoding: none # identity, none 6 | orthogonal: true 7 | freeze_center: true 8 | project_assignment: true 9 | pos_embed_dim: 360 -------------------------------------------------------------------------------- /source/conf/model/brainnetcnn.yaml: -------------------------------------------------------------------------------- 1 | name: 
BrainNetCNN 2 | -------------------------------------------------------------------------------- /source/conf/model/fbnetgen.yaml: -------------------------------------------------------------------------------- 1 | name: FBNETGEN 2 | 3 | # gru or cnn 4 | extractor_type: gru 5 | embedding_size: 16 6 | window_size: 4 7 | 8 | cnn_pool_size: 16 9 | 10 | # product or linear 11 | graph_generation: product 12 | 13 | num_gru_layers: 4 14 | 15 | dropout: 0.5 16 | 17 | group_loss: true 18 | sparsity_loss: true 19 | sparsity_loss_weight: 1.0e-4 20 | 21 | # training parameter 22 | train: FBNetTrain 23 | -------------------------------------------------------------------------------- /source/conf/model/transformer.yaml: -------------------------------------------------------------------------------- 1 | name: GraphTransformer 2 | self_attention_layer: 2 3 | readout: concat 4 | -------------------------------------------------------------------------------- /source/conf/optimizer/adam.yaml: -------------------------------------------------------------------------------- 1 | - name: Adam 2 | lr: 1.0e-4 3 | match_rule: None 4 | except_rule: None 5 | no_weight_decay: false 6 | weight_decay: 1.0e-4 7 | lr_scheduler: 8 | mode: cos # ['step', 'poly', 'cos'] 9 | base_lr: 1.0e-4 10 | target_lr: 1.0e-5 11 | 12 | decay_factor: 0.1 # for step mode 13 | milestones: [0.3, 0.6, 0.9] 14 | poly_power: 2.0 # for poly mode 15 | lr_decay: 0.98 16 | 17 | warm_up_from: 0.0 18 | warm_up_steps: 0 19 | -------------------------------------------------------------------------------- /source/conf/preprocess/mixup.yaml: -------------------------------------------------------------------------------- 1 | name: continus_mixup 2 | continus: True -------------------------------------------------------------------------------- /source/conf/preprocess/non_mixup.yaml: -------------------------------------------------------------------------------- 1 | name: non_mixup 2 | continus: False -------------------------------------------------------------------------------- /source/conf/training/basic_training.yaml: -------------------------------------------------------------------------------- 1 | name: Train 2 | epochs: 200 -------------------------------------------------------------------------------- /source/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from omegaconf import DictConfig, open_dict 2 | from .abcd import load_abcd_data 3 | from .abide import load_abide_data 4 | from .dataloader import init_dataloader, init_stratified_dataloader 5 | from typing import List 6 | import torch.utils as utils 7 | 8 | 9 | def dataset_factory(cfg: DictConfig) -> List[utils.data.DataLoader]: 10 | 11 | assert cfg.dataset.name in ['abcd', 'abide'] 12 | 13 | datasets = eval( 14 | f"load_{cfg.dataset.name}_data")(cfg) 15 | 16 | dataloaders = init_stratified_dataloader(cfg, *datasets) \ 17 | if cfg.dataset.stratified \ 18 | else init_dataloader(cfg, *datasets) 19 | 20 | return dataloaders 21 | -------------------------------------------------------------------------------- /source/dataset/abcd.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from sklearn import preprocessing 4 | import pandas as pd 5 | from .preprocess import StandardScaler 6 | from omegaconf import DictConfig, open_dict 7 | 8 | 9 | def load_abcd_data(cfg: DictConfig): 10 | 11 | ts_data = np.load(cfg.dataset.time_seires, allow_pickle=True) 12 | 
pearson_data = np.load(cfg.dataset.node_feature, allow_pickle=True) 13 | label_df = pd.read_csv(cfg.dataset.label) 14 | 15 | with open(cfg.dataset.node_id, 'r') as f: 16 | lines = f.readlines() 17 | pearson_id = [line[:-1] for line in lines] 18 | 19 | with open(cfg.dataset.seires_id, 'r') as f: 20 | lines = f.readlines() 21 | ts_id = [line[:-1] for line in lines] 22 | 23 | id2pearson = dict(zip(pearson_id, pearson_data)) 24 | 25 | id2gender = dict(zip(label_df['id'], label_df['sex'])) 26 | 27 | final_timeseires, final_label, final_pearson = [], [], [] 28 | 29 | for ts, l in zip(ts_data, ts_id): 30 | if l in id2gender and l in id2pearson: 31 | if np.any(np.isnan(id2pearson[l])) == False: 32 | final_timeseires.append(ts) 33 | final_label.append(id2gender[l]) 34 | final_pearson.append(id2pearson[l]) 35 | 36 | encoder = preprocessing.LabelEncoder() 37 | 38 | encoder.fit(label_df["sex"]) 39 | 40 | labels = encoder.transform(final_label) 41 | 42 | scaler = StandardScaler(mean=np.mean( 43 | final_timeseires), std=np.std(final_timeseires)) 44 | 45 | final_timeseires = scaler.transform(final_timeseires) 46 | 47 | final_timeseires, final_pearson, labels = [np.array( 48 | data) for data in (final_timeseires, final_pearson, labels)] 49 | 50 | final_timeseires, final_pearson, labels = [torch.from_numpy( 51 | data).float() for data in (final_timeseires, final_pearson, labels)] 52 | 53 | with open_dict(cfg): 54 | 55 | cfg.dataset.node_sz, cfg.dataset.node_feature_sz = final_pearson.shape[1:] 56 | cfg.dataset.timeseries_sz = final_timeseires.shape[2] 57 | 58 | return final_timeseires, final_pearson, labels 59 | -------------------------------------------------------------------------------- /source/dataset/abide.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from .preprocess import StandardScaler 4 | from omegaconf import DictConfig, open_dict 5 | 6 | 7 | def load_abide_data(cfg: DictConfig): 8 | 9 | data = np.load(cfg.dataset.path, allow_pickle=True).item() 10 | final_timeseires = data["timeseires"] 11 | final_pearson = data["corr"] 12 | labels = data["label"] 13 | site = data['site'] 14 | 15 | scaler = StandardScaler(mean=np.mean( 16 | final_timeseires), std=np.std(final_timeseires)) 17 | 18 | final_timeseires = scaler.transform(final_timeseires) 19 | 20 | final_timeseires, final_pearson, labels = [torch.from_numpy( 21 | data).float() for data in (final_timeseires, final_pearson, labels)] 22 | 23 | with open_dict(cfg): 24 | 25 | cfg.dataset.node_sz, cfg.dataset.node_feature_sz = final_pearson.shape[1:] 26 | cfg.dataset.timeseries_sz = final_timeseires.shape[2] 27 | 28 | return final_timeseires, final_pearson, labels, site 29 | -------------------------------------------------------------------------------- /source/dataset/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as utils 3 | from omegaconf import DictConfig, open_dict 4 | from typing import List 5 | from sklearn.model_selection import StratifiedShuffleSplit 6 | import numpy as np 7 | import torch.nn.functional as F 8 | 9 | 10 | def init_dataloader(cfg: DictConfig, 11 | final_timeseires: torch.tensor, 12 | final_pearson: torch.tensor, 13 | labels: torch.tensor) -> List[utils.DataLoader]: 14 | labels = F.one_hot(labels.to(torch.int64)) 15 | length = final_timeseires.shape[0] 16 | train_length = int(length*cfg.dataset.train_set*cfg.datasz.percentage) 17 | val_length = 
int(length*cfg.dataset.val_set) 18 | if cfg.datasz.percentage == 1.0: 19 | test_length = length-train_length-val_length 20 | else: 21 | test_length = int(length*(1-cfg.dataset.val_set-cfg.dataset.train_set)) 22 | 23 | with open_dict(cfg): 24 | # total_steps, steps_per_epoch for lr schedular 25 | cfg.steps_per_epoch = ( 26 | train_length - 1) // cfg.dataset.batch_size + 1 27 | cfg.total_steps = cfg.steps_per_epoch * cfg.training.epochs 28 | 29 | dataset = utils.TensorDataset( 30 | final_timeseires[:train_length+val_length+test_length], 31 | final_pearson[:train_length+val_length+test_length], 32 | labels[:train_length+val_length+test_length] 33 | ) 34 | 35 | train_dataset, val_dataset, test_dataset = utils.random_split( 36 | dataset, [train_length, val_length, test_length]) 37 | train_dataloader = utils.DataLoader( 38 | train_dataset, batch_size=cfg.dataset.batch_size, shuffle=True, drop_last=cfg.dataset.drop_last) 39 | 40 | val_dataloader = utils.DataLoader( 41 | val_dataset, batch_size=cfg.dataset.batch_size, shuffle=True, drop_last=False) 42 | 43 | test_dataloader = utils.DataLoader( 44 | test_dataset, batch_size=cfg.dataset.batch_size, shuffle=True, drop_last=False) 45 | 46 | return [train_dataloader, val_dataloader, test_dataloader] 47 | 48 | 49 | def init_stratified_dataloader(cfg: DictConfig, 50 | final_timeseires: torch.tensor, 51 | final_pearson: torch.tensor, 52 | labels: torch.tensor, 53 | stratified: np.array) -> List[utils.DataLoader]: 54 | labels = F.one_hot(labels.to(torch.int64)) 55 | length = final_timeseires.shape[0] 56 | train_length = int(length*cfg.dataset.train_set*cfg.datasz.percentage) 57 | val_length = int(length*cfg.dataset.val_set) 58 | if cfg.datasz.percentage == 1.0: 59 | test_length = length-train_length-val_length 60 | else: 61 | test_length = int(length*(1-cfg.dataset.val_set-cfg.dataset.train_set)) 62 | 63 | with open_dict(cfg): 64 | # total_steps, steps_per_epoch for lr schedular 65 | cfg.steps_per_epoch = ( 66 | train_length - 1) // cfg.dataset.batch_size + 1 67 | cfg.total_steps = cfg.steps_per_epoch * cfg.training.epochs 68 | 69 | split = StratifiedShuffleSplit( 70 | n_splits=1, test_size=val_length+test_length, train_size=train_length, random_state=42) 71 | for train_index, test_valid_index in split.split(final_timeseires, stratified): 72 | final_timeseires_train, final_pearson_train, labels_train = final_timeseires[ 73 | train_index], final_pearson[train_index], labels[train_index] 74 | final_timeseires_val_test, final_pearson_val_test, labels_val_test = final_timeseires[ 75 | test_valid_index], final_pearson[test_valid_index], labels[test_valid_index] 76 | stratified = stratified[test_valid_index] 77 | 78 | split2 = StratifiedShuffleSplit( 79 | n_splits=1, test_size=test_length) 80 | for test_index, valid_index in split2.split(final_timeseires_val_test, stratified): 81 | final_timeseires_test, final_pearson_test, labels_test = final_timeseires_val_test[ 82 | test_index], final_pearson_val_test[test_index], labels_val_test[test_index] 83 | final_timeseires_val, final_pearson_val, labels_val = final_timeseires_val_test[ 84 | valid_index], final_pearson_val_test[valid_index], labels_val_test[valid_index] 85 | 86 | train_dataset = utils.TensorDataset( 87 | final_timeseires_train, 88 | final_pearson_train, 89 | labels_train 90 | ) 91 | 92 | val_dataset = utils.TensorDataset( 93 | final_timeseires_val, final_pearson_val, labels_val 94 | ) 95 | 96 | test_dataset = utils.TensorDataset( 97 | final_timeseires_test, final_pearson_test, labels_test 98 | ) 99 | 
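# Wrap the stratified splits in DataLoaders below; only the training loader uses cfg.dataset.drop_last.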
100 | train_dataloader = utils.DataLoader( 101 | train_dataset, batch_size=cfg.dataset.batch_size, shuffle=True, drop_last=cfg.dataset.drop_last) 102 | 103 | val_dataloader = utils.DataLoader( 104 | val_dataset, batch_size=cfg.dataset.batch_size, shuffle=True, drop_last=False) 105 | 106 | test_dataloader = utils.DataLoader( 107 | test_dataset, batch_size=cfg.dataset.batch_size, shuffle=True, drop_last=False) 108 | 109 | return [train_dataloader, val_dataloader, test_dataloader] 110 | -------------------------------------------------------------------------------- /source/dataset/preprocess.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from omegaconf import DictConfig 3 | 4 | 5 | class StandardScaler: 6 | """ 7 | Standard the input 8 | """ 9 | 10 | def __init__(self, mean: np.array, std: np.array): 11 | self.mean = mean 12 | self.std = std 13 | 14 | def transform(self, data: np.array): 15 | return (data - self.mean) / self.std 16 | 17 | def inverse_transform(self, data: np.array): 18 | return (data * self.std) + self.mean 19 | 20 | 21 | def reduce_sample_size(config: DictConfig, *args): 22 | sz = args[0].shape[0] 23 | used_sz = int(sz * config.datasz.percentage) 24 | return [d[:used_sz] for d in args] 25 | -------------------------------------------------------------------------------- /source/models/BNT/__init__.py: -------------------------------------------------------------------------------- 1 | from .bnt import BrainNetworkTransformer 2 | -------------------------------------------------------------------------------- /source/models/BNT/bnt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import TransformerEncoderLayer 4 | from .ptdec import DEC 5 | from typing import List 6 | from .components import InterpretableTransformerEncoder 7 | from omegaconf import DictConfig 8 | from ..base import BaseModel 9 | 10 | 11 | class TransPoolingEncoder(nn.Module): 12 | """ 13 | Transformer encoder with Pooling mechanism. 
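When pooling is enabled, a DEC-based soft cluster assignment reduces the number of nodes from input_node_num to output_node_num.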
14 | Input size: (batch_size, input_node_num, input_feature_size) 15 | Output size: (batch_size, output_node_num, input_feature_size) 16 | """ 17 | 18 | def __init__(self, input_feature_size, input_node_num, hidden_size, output_node_num, pooling=True, orthogonal=True, freeze_center=False, project_assignment=True): 19 | super().__init__() 20 | self.transformer = InterpretableTransformerEncoder(d_model=input_feature_size, nhead=4, 21 | dim_feedforward=hidden_size, 22 | batch_first=True) 23 | 24 | self.pooling = pooling 25 | if pooling: 26 | encoder_hidden_size = 32 27 | self.encoder = nn.Sequential( 28 | nn.Linear(input_feature_size * 29 | input_node_num, encoder_hidden_size), 30 | nn.LeakyReLU(), 31 | nn.Linear(encoder_hidden_size, encoder_hidden_size), 32 | nn.LeakyReLU(), 33 | nn.Linear(encoder_hidden_size, 34 | input_feature_size * input_node_num), 35 | ) 36 | self.dec = DEC(cluster_number=output_node_num, hidden_dimension=input_feature_size, encoder=self.encoder, 37 | orthogonal=orthogonal, freeze_center=freeze_center, project_assignment=project_assignment) 38 | 39 | def is_pooling_enabled(self): 40 | return self.pooling 41 | 42 | def forward(self, x): 43 | x = self.transformer(x) 44 | if self.pooling: 45 | x, assignment = self.dec(x) 46 | return x, assignment 47 | return x, None 48 | 49 | def get_attention_weights(self): 50 | return self.transformer.get_attention_weights() 51 | 52 | def loss(self, assignment): 53 | return self.dec.loss(assignment) 54 | 55 | 56 | class BrainNetworkTransformer(BaseModel): 57 | 58 | def __init__(self, config: DictConfig): 59 | 60 | super().__init__() 61 | 62 | self.attention_list = nn.ModuleList() 63 | forward_dim = config.dataset.node_sz 64 | 65 | self.pos_encoding = config.model.pos_encoding 66 | if self.pos_encoding == 'identity': 67 | self.node_identity = nn.Parameter(torch.zeros( 68 | config.dataset.node_sz, config.model.pos_embed_dim), requires_grad=True) 69 | forward_dim = config.dataset.node_sz + config.model.pos_embed_dim 70 | nn.init.kaiming_normal_(self.node_identity) 71 | 72 | sizes = config.model.sizes 73 | sizes[0] = config.dataset.node_sz 74 | in_sizes = [config.dataset.node_sz] + sizes[:-1] 75 | do_pooling = config.model.pooling 76 | self.do_pooling = do_pooling 77 | for index, size in enumerate(sizes): 78 | self.attention_list.append( 79 | TransPoolingEncoder(input_feature_size=forward_dim, 80 | input_node_num=in_sizes[index], 81 | hidden_size=1024, 82 | output_node_num=size, 83 | pooling=do_pooling[index], 84 | orthogonal=config.model.orthogonal, 85 | freeze_center=config.model.freeze_center, 86 | project_assignment=config.model.project_assignment)) 87 | 88 | self.dim_reduction = nn.Sequential( 89 | nn.Linear(forward_dim, 8), 90 | nn.LeakyReLU() 91 | ) 92 | 93 | self.fc = nn.Sequential( 94 | nn.Linear(8 * sizes[-1], 256), 95 | nn.LeakyReLU(), 96 | nn.Linear(256, 32), 97 | nn.LeakyReLU(), 98 | nn.Linear(32, 2) 99 | ) 100 | 101 | def forward(self, 102 | time_seires: torch.tensor, 103 | node_feature: torch.tensor): 104 | 105 | bz, _, _, = node_feature.shape 106 | 107 | if self.pos_encoding == 'identity': 108 | pos_emb = self.node_identity.expand(bz, *self.node_identity.shape) 109 | node_feature = torch.cat([node_feature, pos_emb], dim=-1) 110 | 111 | assignments = [] 112 | 113 | for atten in self.attention_list: 114 | node_feature, assignment = atten(node_feature) 115 | assignments.append(assignment) 116 | 117 | node_feature = self.dim_reduction(node_feature) 118 | 119 | node_feature = node_feature.reshape((bz, -1)) 120 | 121 | return 
self.fc(node_feature) 122 | 123 | def get_attention_weights(self): 124 | return [atten.get_attention_weights() for atten in self.attention_list] 125 | 126 | def get_cluster_centers(self) -> torch.Tensor: 127 | """ 128 | Get the cluster centers, as computed by the encoder. 129 | 130 | :return: [number of clusters, hidden dimension] Tensor of dtype float 131 | """ 132 | return self.dec.get_cluster_centers() 133 | 134 | def loss(self, assignments): 135 | """ 136 | Compute KL loss for the given assignments. Note that not all encoders contain a pooling layer. 137 | Inputs: assignments: [batch size, number of clusters] 138 | Output: KL loss 139 | """ 140 | decs = list( 141 | filter(lambda x: x.is_pooling_enabled(), self.attention_list)) 142 | assignments = list(filter(lambda x: x is not None, assignments)) 143 | loss_all = None 144 | 145 | for index, assignment in enumerate(assignments): 146 | if loss_all is None: 147 | loss_all = decs[index].loss(assignment) 148 | else: 149 | loss_all += decs[index].loss(assignment) 150 | return loss_all 151 | -------------------------------------------------------------------------------- /source/models/BNT/components/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer_encoder import InterpretableTransformerEncoder 2 | -------------------------------------------------------------------------------- /source/models/BNT/components/transformer_encoder.py: -------------------------------------------------------------------------------- 1 | from torch.nn import TransformerEncoderLayer 2 | from torch import Tensor 3 | from typing import Optional 4 | import torch.nn.functional as F 5 | 6 | 7 | class InterpretableTransformerEncoder(TransformerEncoderLayer): 8 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=F.relu, 9 | layer_norm_eps=1e-5, batch_first=False, norm_first=False, 10 | device=None, dtype=None) -> None: 11 | super().__init__(d_model, nhead, dim_feedforward, dropout, activation, 12 | layer_norm_eps, batch_first, norm_first, device, dtype) 13 | self.attention_weights: Optional[Tensor] = None 14 | 15 | def _sa_block(self, x: Tensor, 16 | attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor: 17 | x, weights = self.self_attn(x, x, x, 18 | attn_mask=attn_mask, 19 | key_padding_mask=key_padding_mask, 20 | need_weights=True) 21 | self.attention_weights = weights 22 | return self.dropout1(x) 23 | 24 | def get_attention_weights(self) -> Optional[Tensor]: 25 | return self.attention_weights 26 | -------------------------------------------------------------------------------- /source/models/BNT/ptdec/__init__.py: -------------------------------------------------------------------------------- 1 | from .dec import DEC 2 | -------------------------------------------------------------------------------- /source/models/BNT/ptdec/cluster.py: -------------------------------------------------------------------------------- 1 | """ 2 | From https://github.com/vlukiyanov/pt-dec 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.nn import Parameter 8 | from typing import Optional 9 | from torch.nn.functional import softmax 10 | 11 | 12 | class ClusterAssignment(nn.Module): 13 | def __init__( 14 | self, 15 | cluster_number: int, 16 | embedding_dimension: int, 17 | alpha: float = 1.0, 18 | cluster_centers: Optional[torch.Tensor] = None, 19 | orthogonal=True, 20 | freeze_center=True, 21 | project_assignment=True 22 | ) -> None: 23 | """ 24 | 
Module to handle the soft assignment, for a description see in 3.1.1. in Xie/Girshick/Farhadi, 25 | where the Student's t-distribution is used measure similarity between feature vector and each 26 | cluster centroid. 27 | 28 | :param cluster_number: number of clusters 29 | :param embedding_dimension: embedding dimension of feature vectors 30 | :param alpha: parameter representing the degrees of freedom in the t-distribution, default 1.0 31 | :param cluster_centers: clusters centers to initialise, if None then use Xavier uniform 32 | """ 33 | super(ClusterAssignment, self).__init__() 34 | self.embedding_dimension = embedding_dimension 35 | self.cluster_number = cluster_number 36 | self.alpha = alpha 37 | self.project_assignment = project_assignment 38 | if cluster_centers is None: 39 | initial_cluster_centers = torch.zeros( 40 | self.cluster_number, self.embedding_dimension, dtype=torch.float 41 | ) 42 | nn.init.xavier_uniform_(initial_cluster_centers) 43 | 44 | else: 45 | initial_cluster_centers = cluster_centers 46 | 47 | if orthogonal: 48 | orthogonal_cluster_centers = torch.zeros( 49 | self.cluster_number, self.embedding_dimension, dtype=torch.float 50 | ) 51 | orthogonal_cluster_centers[0] = initial_cluster_centers[0] 52 | for i in range(1, cluster_number): 53 | project = 0 54 | for j in range(i): 55 | project += self.project( 56 | initial_cluster_centers[j], initial_cluster_centers[i]) 57 | initial_cluster_centers[i] -= project 58 | orthogonal_cluster_centers[i] = initial_cluster_centers[i] / \ 59 | torch.norm(initial_cluster_centers[i], p=2) 60 | 61 | initial_cluster_centers = orthogonal_cluster_centers 62 | 63 | self.cluster_centers = Parameter( 64 | initial_cluster_centers, requires_grad=(not freeze_center)) 65 | 66 | @staticmethod 67 | def project(u, v): 68 | return (torch.dot(u, v)/torch.dot(u, u))*u 69 | 70 | def forward(self, batch: torch.Tensor) -> torch.Tensor: 71 | """ 72 | Compute the soft assignment for a batch of feature vectors, returning a batch of assignments 73 | for each cluster. 74 | 75 | :param batch: FloatTensor of [batch size, embedding dimension] 76 | :return: FloatTensor [batch size, number of clusters] 77 | """ 78 | 79 | if self.project_assignment: 80 | 81 | assignment = batch@self.cluster_centers.T 82 | # prove 83 | assignment = torch.pow(assignment, 2) 84 | 85 | norm = torch.norm(self.cluster_centers, p=2, dim=-1) 86 | soft_assign = assignment/norm 87 | return softmax(soft_assign, dim=-1) 88 | 89 | else: 90 | 91 | norm_squared = torch.sum( 92 | (batch.unsqueeze(1) - self.cluster_centers) ** 2, 2) 93 | numerator = 1.0 / (1.0 + (norm_squared / self.alpha)) 94 | power = float(self.alpha + 1) / 2 95 | numerator = numerator ** power 96 | return numerator / torch.sum(numerator, dim=1, keepdim=True) 97 | 98 | def get_cluster_centers(self) -> torch.Tensor: 99 | """ 100 | Get the cluster centers. 
101 | 102 | :return: FloatTensor [number of clusters, embedding dimension] 103 | """ 104 | return self.cluster_centers 105 | -------------------------------------------------------------------------------- /source/models/BNT/ptdec/dec.py: -------------------------------------------------------------------------------- 1 | """ 2 | From https://github.com/vlukiyanov/pt-dec 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | from typing import Tuple 8 | from .cluster import ClusterAssignment 9 | 10 | 11 | class DEC(nn.Module): 12 | def __init__( 13 | self, 14 | cluster_number: int, 15 | hidden_dimension: int, 16 | encoder: torch.nn.Module, 17 | alpha: float = 1.0, 18 | orthogonal=True, 19 | freeze_center=True, project_assignment=True 20 | ): 21 | """ 22 | Module which holds all the moving parts of the DEC algorithm, as described in 23 | Xie/Girshick/Farhadi; this includes the AutoEncoder stage and the ClusterAssignment stage. 24 | 25 | :param cluster_number: number of clusters 26 | :param hidden_dimension: hidden dimension, output of the encoder 27 | :param encoder: encoder to use 28 | :param alpha: parameter representing the degrees of freedom in the t-distribution, default 1.0 29 | """ 30 | super(DEC, self).__init__() 31 | self.encoder = encoder 32 | self.hidden_dimension = hidden_dimension 33 | self.cluster_number = cluster_number 34 | self.alpha = alpha 35 | self.assignment = ClusterAssignment( 36 | cluster_number, self.hidden_dimension, alpha, orthogonal=orthogonal, freeze_center=freeze_center, project_assignment=project_assignment 37 | ) 38 | 39 | self.loss_fn = nn.KLDivLoss(size_average=False) 40 | 41 | def forward(self, batch: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 42 | """ 43 | Compute the cluster assignment using the ClusterAssignment after running the batch 44 | through the encoder part of the associated AutoEncoder module. 45 | 46 | :param batch: [batch size, embedding dimension] FloatTensor 47 | :return: [batch size, number of clusters] FloatTensor 48 | """ 49 | node_num = batch.size(1) 50 | batch_size = batch.size(0) 51 | 52 | # [batch size, embedding dimension] 53 | flattened_batch = batch.view(batch_size, -1) 54 | encoded = self.encoder(flattened_batch) 55 | # [batch size * node_num, hidden dimension] 56 | encoded = encoded.view(batch_size * node_num, -1) 57 | # [batch size * node_num, cluster_number] 58 | assignment = self.assignment(encoded) 59 | # [batch size, node_num, cluster_number] 60 | assignment = assignment.view(batch_size, node_num, -1) 61 | # [batch size, node_num, hidden dimension] 62 | encoded = encoded.view(batch_size, node_num, -1) 63 | # Multiply the encoded vectors by the cluster assignment to get the final node representations 64 | # [batch size, cluster_number, hidden dimension] 65 | node_repr = torch.bmm(assignment.transpose(1, 2), encoded) 66 | return node_repr, assignment 67 | 68 | def target_distribution(self, batch: torch.Tensor) -> torch.Tensor: 69 | """ 70 | Compute the target distribution p_ij, given the batch (q_ij), as in 3.1.3 Equation 3 of 71 | Xie/Girshick/Farhadi; this is used the KL-divergence loss function. 
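Concretely, p_ij = (q_ij^2 / f_j) / sum_k (q_ik^2 / f_k), where f_j = sum_i q_ij is the soft cluster frequency.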
72 | 73 | :param batch: [batch size, number of clusters] Tensor of dtype float 74 | :return: [batch size, number of clusters] Tensor of dtype float 75 | """ 76 | weight = (batch ** 2) / torch.sum(batch, 0) 77 | return (weight.t() / torch.sum(weight, 1)).t() 78 | 79 | def loss(self, assignment): 80 | flattened_assignment = assignment.view(-1, assignment.size(-1)) 81 | target = self.target_distribution(flattened_assignment).detach() 82 | return self.loss_fn(flattened_assignment.log(), target) / flattened_assignment.size(0) 83 | 84 | def get_cluster_centers(self) -> torch.Tensor: 85 | """ 86 | Get the cluster centers, as computed by the encoder. 87 | 88 | :return: [number of clusters, hidden dimension] Tensor of dtype float 89 | """ 90 | return self.assignment.get_cluster_centers() 91 | -------------------------------------------------------------------------------- /source/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer import GraphTransformer 2 | from omegaconf import DictConfig 3 | from .brainnetcnn import BrainNetCNN 4 | from .fbnetgen import FBNETGEN 5 | from .BNT import BrainNetworkTransformer 6 | 7 | 8 | def model_factory(config: DictConfig): 9 | if config.model.name in ["LogisticRegression", "SVC"]: 10 | return None 11 | return eval(config.model.name)(config).cuda() 12 | -------------------------------------------------------------------------------- /source/models/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class BaseModel(nn.Module): 7 | 8 | def __init__(self) -> None: 9 | super().__init__() 10 | 11 | @abstractmethod 12 | def forward(self, 13 | time_seires: torch.tensor, 14 | node_feature: torch.tensor) -> torch.tensor: 15 | pass 16 | -------------------------------------------------------------------------------- /source/models/brainnetcnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from omegaconf import DictConfig 4 | from .base import BaseModel 5 | 6 | 7 | class E2EBlock(torch.nn.Module): 8 | '''E2Eblock.''' 9 | 10 | def __init__(self, in_planes, planes, roi_num, bias=True): 11 | super().__init__() 12 | self.d = roi_num 13 | self.cnn1 = torch.nn.Conv2d(in_planes, planes, (1, self.d), bias=bias) 14 | self.cnn2 = torch.nn.Conv2d(in_planes, planes, (self.d, 1), bias=bias) 15 | 16 | def forward(self, x): 17 | a = self.cnn1(x) 18 | b = self.cnn2(x) 19 | return torch.cat([a]*self.d, 3)+torch.cat([b]*self.d, 2) 20 | 21 | 22 | class BrainNetCNN(BaseModel): 23 | def __init__(self, config: DictConfig): 24 | super().__init__() 25 | self.in_planes = 1 26 | self.d = config.dataset.node_sz 27 | 28 | self.e2econv1 = E2EBlock(1, 32, config.dataset.node_sz, bias=True) 29 | self.e2econv2 = E2EBlock(32, 64, config.dataset.node_sz, bias=True) 30 | self.E2N = torch.nn.Conv2d(64, 1, (1, self.d)) 31 | self.N2G = torch.nn.Conv2d(1, 256, (self.d, 1)) 32 | self.dense1 = torch.nn.Linear(256, 128) 33 | self.dense2 = torch.nn.Linear(128, 30) 34 | self.dense3 = torch.nn.Linear(30, 2) 35 | 36 | def forward(self, 37 | time_seires: torch.tensor, 38 | node_feature: torch.tensor): 39 | node_feature = node_feature.unsqueeze(dim=1) 40 | out = F.leaky_relu(self.e2econv1(node_feature), negative_slope=0.33) 41 | out = F.leaky_relu(self.e2econv2(out), negative_slope=0.33) 42 | out = F.leaky_relu(self.E2N(out), 
negative_slope=0.33) 43 | out = F.dropout(F.leaky_relu( 44 | self.N2G(out), negative_slope=0.33), p=0.5) 45 | out = out.view(out.size(0), -1) 46 | out = F.dropout(F.leaky_relu( 47 | self.dense1(out), negative_slope=0.33), p=0.5) 48 | out = F.dropout(F.leaky_relu( 49 | self.dense2(out), negative_slope=0.33), p=0.5) 50 | out = F.leaky_relu(self.dense3(out), negative_slope=0.33) 51 | 52 | return out 53 | -------------------------------------------------------------------------------- /source/models/fbnetgen.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.nn import Conv1d, MaxPool1d, Linear, GRU 6 | from omegaconf import DictConfig 7 | from .base import BaseModel 8 | 9 | 10 | class GruKRegion(nn.Module): 11 | 12 | def __init__(self, kernel_size=8, layers=4, out_size=8, dropout=0.5): 13 | super().__init__() 14 | self.gru = GRU(kernel_size, kernel_size, layers, 15 | bidirectional=True, batch_first=True) 16 | 17 | self.kernel_size = kernel_size 18 | 19 | self.linear = nn.Sequential( 20 | nn.Dropout(dropout), 21 | Linear(kernel_size*2, kernel_size), 22 | nn.LeakyReLU(negative_slope=0.2), 23 | Linear(kernel_size, out_size) 24 | ) 25 | 26 | def forward(self, raw): 27 | 28 | b, k, d = raw.shape 29 | 30 | x = raw.view((b*k, -1, self.kernel_size)) 31 | 32 | x, h = self.gru(x) 33 | 34 | x = x[:, -1, :] 35 | 36 | x = x.view((b, k, -1)) 37 | 38 | x = self.linear(x) 39 | return x 40 | 41 | 42 | class ConvKRegion(nn.Module): 43 | 44 | def __init__(self, k=1, out_size=8, kernel_size=8, pool_size=16, time_series=512): 45 | super().__init__() 46 | self.conv1 = Conv1d(in_channels=k, out_channels=32, 47 | kernel_size=kernel_size, stride=2) 48 | 49 | output_dim_1 = (time_series-kernel_size)//2+1 50 | 51 | self.conv2 = Conv1d(in_channels=32, out_channels=32, 52 | kernel_size=8) 53 | output_dim_2 = output_dim_1 - 8 + 1 54 | self.conv3 = Conv1d(in_channels=32, out_channels=16, 55 | kernel_size=8) 56 | output_dim_3 = output_dim_2 - 8 + 1 57 | self.max_pool1 = MaxPool1d(pool_size) 58 | output_dim_4 = output_dim_3 // pool_size * 16 59 | self.in0 = nn.InstanceNorm1d(time_series) 60 | self.in1 = nn.BatchNorm1d(32) 61 | self.in2 = nn.BatchNorm1d(32) 62 | self.in3 = nn.BatchNorm1d(16) 63 | 64 | self.linear = nn.Sequential( 65 | Linear(output_dim_4, 32), 66 | nn.LeakyReLU(negative_slope=0.2), 67 | Linear(32, out_size) 68 | ) 69 | 70 | def forward(self, x): 71 | 72 | b, k, d = x.shape 73 | 74 | x = torch.transpose(x, 1, 2) 75 | 76 | x = self.in0(x) 77 | 78 | x = torch.transpose(x, 1, 2) 79 | x = x.contiguous() 80 | 81 | x = x.view((b*k, 1, d)) 82 | 83 | x = self.conv1(x) 84 | 85 | x = self.in1(x) 86 | x = self.conv2(x) 87 | 88 | x = self.in2(x) 89 | x = self.conv3(x) 90 | 91 | x = self.in3(x) 92 | x = self.max_pool1(x) 93 | 94 | x = x.view((b, k, -1)) 95 | 96 | x = self.linear(x) 97 | 98 | return x 99 | 100 | 101 | class Embed2GraphByProduct(nn.Module): 102 | 103 | def __init__(self, input_dim, roi_num=264): 104 | super().__init__() 105 | 106 | def forward(self, x): 107 | 108 | m = torch.einsum('ijk,ipk->ijp', x, x) 109 | 110 | m = torch.unsqueeze(m, -1) 111 | 112 | return m 113 | 114 | 115 | class GNNPredictor(nn.Module): 116 | 117 | def __init__(self, node_input_dim, roi_num=360): 118 | super().__init__() 119 | inner_dim = roi_num 120 | self.roi_num = roi_num 121 | self.gcn = nn.Sequential( 122 | nn.Linear(node_input_dim, inner_dim), 123 | 
nn.LeakyReLU(negative_slope=0.2), 124 | Linear(inner_dim, inner_dim) 125 | ) 126 | self.bn1 = torch.nn.BatchNorm1d(inner_dim) 127 | 128 | self.gcn1 = nn.Sequential( 129 | nn.Linear(inner_dim, inner_dim), 130 | nn.LeakyReLU(negative_slope=0.2), 131 | ) 132 | self.bn2 = torch.nn.BatchNorm1d(inner_dim) 133 | self.gcn2 = nn.Sequential( 134 | nn.Linear(inner_dim, 64), 135 | nn.LeakyReLU(negative_slope=0.2), 136 | nn.Linear(64, 8), 137 | nn.LeakyReLU(negative_slope=0.2), 138 | ) 139 | self.bn3 = torch.nn.BatchNorm1d(inner_dim) 140 | 141 | self.fcn = nn.Sequential( 142 | nn.Linear(8*roi_num, 256), 143 | nn.LeakyReLU(negative_slope=0.2), 144 | nn.Linear(256, 32), 145 | nn.LeakyReLU(negative_slope=0.2), 146 | nn.Linear(32, 2) 147 | ) 148 | 149 | def forward(self, m, node_feature): 150 | bz = m.shape[0] 151 | 152 | x = torch.einsum('ijk,ijp->ijp', m, node_feature) 153 | 154 | x = self.gcn(x) 155 | 156 | x = x.reshape((bz*self.roi_num, -1)) 157 | x = self.bn1(x) 158 | x = x.reshape((bz, self.roi_num, -1)) 159 | 160 | x = torch.einsum('ijk,ijp->ijp', m, x) 161 | 162 | x = self.gcn1(x) 163 | 164 | x = x.reshape((bz*self.roi_num, -1)) 165 | x = self.bn2(x) 166 | x = x.reshape((bz, self.roi_num, -1)) 167 | 168 | x = torch.einsum('ijk,ijp->ijp', m, x) 169 | 170 | x = self.gcn2(x) 171 | 172 | x = self.bn3(x) 173 | 174 | x = x.view(bz, -1) 175 | 176 | return self.fcn(x) 177 | 178 | 179 | class Embed2GraphByLinear(nn.Module): 180 | 181 | def __init__(self, input_dim, roi_num=360): 182 | super().__init__() 183 | 184 | self.fc_out = nn.Linear(input_dim * 2, input_dim) 185 | self.fc_cat = nn.Linear(input_dim, 1) 186 | 187 | def encode_onehot(labels): 188 | classes = set(labels) 189 | classes_dict = {c: np.identity(len(classes))[i, :] for i, c in 190 | enumerate(classes)} 191 | labels_onehot = np.array(list(map(classes_dict.get, labels)), 192 | dtype=np.int32) 193 | return labels_onehot 194 | 195 | off_diag = np.ones([roi_num, roi_num]) 196 | rel_rec = np.array(encode_onehot( 197 | np.where(off_diag)[0]), dtype=np.float32) 198 | rel_send = np.array(encode_onehot( 199 | np.where(off_diag)[1]), dtype=np.float32) 200 | self.rel_rec = torch.FloatTensor(rel_rec).cuda() 201 | self.rel_send = torch.FloatTensor(rel_send).cuda() 202 | 203 | def forward(self, x): 204 | 205 | batch_sz, region_num, _ = x.shape 206 | receivers = torch.matmul(self.rel_rec, x) 207 | 208 | senders = torch.matmul(self.rel_send, x) 209 | x = torch.cat([senders, receivers], dim=2) 210 | x = torch.relu(self.fc_out(x)) 211 | x = self.fc_cat(x) 212 | 213 | x = torch.relu(x) 214 | 215 | m = torch.reshape( 216 | x, (batch_sz, region_num, region_num, -1)) 217 | return m 218 | 219 | 220 | class FBNETGEN(BaseModel): 221 | 222 | def __init__(self, config: DictConfig): 223 | super().__init__() 224 | 225 | assert config.model.extractor_type in ['cnn', 'gru'] 226 | assert config.model.graph_generation in ['linear', 'product'] 227 | assert config.dataset.timeseries_sz % config.model.window_size == 0 228 | 229 | self.graph_generation = config.model.graph_generation 230 | if config.model.extractor_type == 'cnn': 231 | self.extract = ConvKRegion( 232 | out_size=config.model.embedding_size, kernel_size=config.model.window_size, 233 | time_series=config.dataset.timeseries_sz) 234 | elif config.model.extractor_type == 'gru': 235 | self.extract = GruKRegion( 236 | out_size=config.model.embedding_size, kernel_size=config.model.window_size, 237 | layers=config.model.num_gru_layers) 238 | if self.graph_generation == "linear": 239 | self.emb2graph = Embed2GraphByLinear( 
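# scores every (sender, receiver) node-embedding pair with a small MLP to build a dense adjacency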
240 | config.model.embedding_size, roi_num=config.dataset.node_sz) 241 | elif self.graph_generation == "product": 242 | self.emb2graph = Embed2GraphByProduct( 243 | config.model.embedding_size, roi_num=config.dataset.node_sz) 244 | 245 | self.predictor = GNNPredictor( 246 | config.dataset.node_feature_sz, roi_num=config.dataset.node_sz) 247 | 248 | def forward(self, time_seires, node_feature): 249 | x = self.extract(time_seires) 250 | x = F.softmax(x, dim=-1) 251 | m = self.emb2graph(x) 252 | m = m[:, :, :, 0] 253 | 254 | return self.predictor(m, node_feature), m 255 | -------------------------------------------------------------------------------- /source/models/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import TransformerEncoderLayer 4 | from omegaconf import DictConfig 5 | from .base import BaseModel 6 | 7 | 8 | class GraphTransformer(BaseModel): 9 | 10 | def __init__(self, cfg: DictConfig): 11 | 12 | super().__init__() 13 | 14 | self.attention_list = nn.ModuleList() 15 | self.readout = cfg.model.readout 16 | self.node_num = cfg.dataset.node_sz 17 | 18 | for _ in range(cfg.model.self_attention_layer): 19 | self.attention_list.append( 20 | TransformerEncoderLayer(d_model=cfg.dataset.node_feature_sz, nhead=4, dim_feedforward=1024, 21 | batch_first=True) 22 | ) 23 | 24 | final_dim = cfg.dataset.node_feature_sz 25 | 26 | if self.readout == "concat": 27 | self.dim_reduction = nn.Sequential( 28 | nn.Linear(cfg.dataset.node_feature_sz, 8), 29 | nn.LeakyReLU() 30 | ) 31 | final_dim = 8 * self.node_num 32 | 33 | elif self.readout == "sum": 34 | self.norm = nn.BatchNorm1d(cfg.dataset.node_feature_sz) 35 | 36 | self.fc = nn.Sequential( 37 | nn.Linear(final_dim, 256), 38 | nn.LeakyReLU(), 39 | nn.Linear(256, 32), 40 | nn.LeakyReLU(), 41 | nn.Linear(32, 2) 42 | ) 43 | 44 | def forward(self, time_seires, node_feature): 45 | bz, _, _, = node_feature.shape 46 | 47 | for atten in self.attention_list: 48 | node_feature = atten(node_feature) 49 | 50 | if self.readout == "concat": 51 | node_feature = self.dim_reduction(node_feature) 52 | node_feature = node_feature.reshape((bz, -1)) 53 | 54 | elif self.readout == "mean": 55 | node_feature = torch.mean(node_feature, dim=1) 56 | elif self.readout == "max": 57 | node_feature, _ = torch.max(node_feature, dim=1) 58 | elif self.readout == "sum": 59 | node_feature = torch.sum(node_feature, dim=1) 60 | node_feature = self.norm(node_feature) 61 | 62 | return self.fc(node_feature) 63 | 64 | def get_attention_weights(self): 65 | return [atten.get_attention_weights() for atten in self.attention_list] 66 | -------------------------------------------------------------------------------- /source/training/FBNettraining.py: -------------------------------------------------------------------------------- 1 | from source.utils import accuracy, isfloat 2 | import torch 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from sklearn.metrics import roc_auc_score 6 | from sklearn.metrics import precision_recall_fscore_support, classification_report 7 | from source.utils import continus_mixup_data, mixup_cluster_loss, inner_loss, intra_loss 8 | from omegaconf import DictConfig 9 | from typing import List 10 | import torch.utils.data as utils 11 | from source.components import LRScheduler 12 | import logging 13 | from .training import Train 14 | 15 | 16 | class FBNetTrain(Train): 17 | 18 | def __init__(self, cfg: DictConfig, 19 | model: torch.nn.Module, 20 
| optimizers: List[torch.optim.Optimizer], 21 | lr_schedulers: List[LRScheduler], 22 | dataloaders: List[utils.DataLoader], 23 | logger: logging.Logger) -> None: 24 | 25 | super().__init__(cfg, model, optimizers, lr_schedulers, dataloaders, logger) 26 | self.group_loss = cfg.model.group_loss 27 | self.sparsity_loss = cfg.model.sparsity_loss 28 | self.sparsity_loss_weight = cfg.model.sparsity_loss_weight 29 | 30 | def train_per_epoch(self, optimizer, lr_scheduler): 31 | self.model.train() 32 | 33 | for time_series, node_feature, label in self.train_dataloader: 34 | self.current_step += 1 35 | label = label.float() 36 | 37 | lr_scheduler.update(optimizer=optimizer, step=self.current_step) 38 | 39 | time_series, node_feature, label = time_series.cuda(), node_feature.cuda(), label.cuda() 40 | 41 | if self.config.preprocess.continus: 42 | time_series, node_feature, label = continus_mixup_data( 43 | time_series, node_feature, y=label) 44 | 45 | predict, learnable_matrix = self.model(time_series, node_feature) 46 | 47 | loss = 2 * self.loss_fn(predict, label) 48 | 49 | if self.group_loss: 50 | if self.config.preprocess.continus: 51 | loss += mixup_cluster_loss(learnable_matrix, 52 | label) 53 | else: 54 | loss += 2 * intra_loss(label[:, 1], learnable_matrix) + \ 55 | inner_loss(label[:, 1], learnable_matrix) 56 | 57 | if self.sparsity_loss: 58 | sparsity_loss = self.sparsity_loss_weight * \ 59 | torch.norm(learnable_matrix, p=1) 60 | loss += sparsity_loss 61 | 62 | self.train_loss.update_with_weight(loss.item(), label.shape[0]) 63 | optimizer.zero_grad() 64 | loss.backward() 65 | optimizer.step() 66 | top1 = accuracy(predict, label[:, 1])[0] 67 | self.train_accuracy.update_with_weight(top1, label.shape[0]) 68 | # wandb.log({"LR": lr_scheduler.lr, 69 | # "Iter loss": loss.item()}) 70 | 71 | def test_per_epoch(self, dataloader, loss_meter, acc_meter): 72 | labels = [] 73 | result = [] 74 | 75 | self.model.eval() 76 | 77 | for time_series, node_feature, label in dataloader: 78 | label = label.float() 79 | time_series, node_feature, label = time_series.cuda(), node_feature.cuda(), label.cuda() 80 | output, _ = self.model(time_series, node_feature) 81 | 82 | loss = self.loss_fn(output, label) 83 | loss_meter.update_with_weight( 84 | loss.item(), label.shape[0]) 85 | top1 = accuracy(output, label[:, 1])[0] 86 | acc_meter.update_with_weight(top1, label.shape[0]) 87 | result += F.softmax(output, dim=1)[:, 1].tolist() 88 | labels += label[:, 1].tolist() 89 | 90 | auc = roc_auc_score(labels, result) 91 | result = np.array(result) 92 | result[result > 0.5] = 1 93 | result[result <= 0.5] = 0 94 | metric = precision_recall_fscore_support( 95 | labels, result, average='micro') 96 | 97 | report = classification_report( 98 | labels, result, output_dict=True, zero_division=0) 99 | 100 | recall = [0, 0] 101 | for k in report: 102 | if isfloat(k): 103 | recall[int(float(k))] = report[k]['recall'] 104 | return [auc] + list(metric) + recall 105 | -------------------------------------------------------------------------------- /source/training/__init__.py: -------------------------------------------------------------------------------- 1 | from operator import mod 2 | from .training import Train 3 | from .FBNettraining import FBNetTrain 4 | from omegaconf import DictConfig 5 | from typing import List 6 | import torch 7 | from source.components import LRScheduler 8 | import logging 9 | import torch.utils.data as utils 10 | 11 | 12 | def training_factory(config: DictConfig, 13 | model: torch.nn.Module, 14 | optimizers: 
List[torch.optim.Optimizer], 15 | lr_schedulers: List[LRScheduler], 16 | dataloaders: List[utils.DataLoader], 17 | logger: logging.Logger) -> Train: 18 | 19 | train = config.model.get("train", None) 20 | if not train: 21 | train = config.training.name 22 | return eval(train)(cfg=config, 23 | model=model, 24 | optimizers=optimizers, 25 | lr_schedulers=lr_schedulers, 26 | dataloaders=dataloaders, 27 | logger=logger) 28 | -------------------------------------------------------------------------------- /source/training/training.py: -------------------------------------------------------------------------------- 1 | from source.utils import accuracy, TotalMeter, count_params, isfloat 2 | import torch 3 | import numpy as np 4 | from pathlib import Path 5 | import torch.nn.functional as F 6 | from sklearn.metrics import roc_auc_score 7 | from sklearn.metrics import precision_recall_fscore_support, classification_report 8 | from source.utils import continus_mixup_data 9 | import wandb 10 | from omegaconf import DictConfig 11 | from typing import List 12 | import torch.utils.data as utils 13 | from source.components import LRScheduler 14 | import logging 15 | 16 | 17 | class Train: 18 | 19 | def __init__(self, cfg: DictConfig, 20 | model: torch.nn.Module, 21 | optimizers: List[torch.optim.Optimizer], 22 | lr_schedulers: List[LRScheduler], 23 | dataloaders: List[utils.DataLoader], 24 | logger: logging.Logger) -> None: 25 | 26 | self.config = cfg 27 | self.logger = logger 28 | self.model = model 29 | self.logger.info(f'#model params: {count_params(self.model)}') 30 | self.train_dataloader, self.val_dataloader, self.test_dataloader = dataloaders 31 | self.epochs = cfg.training.epochs 32 | self.total_steps = cfg.total_steps 33 | self.optimizers = optimizers 34 | self.lr_schedulers = lr_schedulers 35 | self.loss_fn = torch.nn.CrossEntropyLoss(reduction='sum') 36 | self.save_path = Path(cfg.log_path) / cfg.unique_id 37 | self.save_learnable_graph = cfg.save_learnable_graph 38 | 39 | self.init_meters() 40 | 41 | def init_meters(self): 42 | self.train_loss, self.val_loss,\ 43 | self.test_loss, self.train_accuracy,\ 44 | self.val_accuracy, self.test_accuracy = [ 45 | TotalMeter() for _ in range(6)] 46 | 47 | def reset_meters(self): 48 | for meter in [self.train_accuracy, self.val_accuracy, 49 | self.test_accuracy, self.train_loss, 50 | self.val_loss, self.test_loss]: 51 | meter.reset() 52 | 53 | def train_per_epoch(self, optimizer, lr_scheduler): 54 | self.model.train() 55 | 56 | for time_series, node_feature, label in self.train_dataloader: 57 | label = label.float() 58 | self.current_step += 1 59 | 60 | lr_scheduler.update(optimizer=optimizer, step=self.current_step) 61 | 62 | time_series, node_feature, label = time_series.cuda(), node_feature.cuda(), label.cuda() 63 | 64 | if self.config.preprocess.continus: 65 | time_series, node_feature, label = continus_mixup_data( 66 | time_series, node_feature, y=label) 67 | 68 | predict = self.model(time_series, node_feature) 69 | 70 | loss = self.loss_fn(predict, label) 71 | 72 | self.train_loss.update_with_weight(loss.item(), label.shape[0]) 73 | optimizer.zero_grad() 74 | loss.backward() 75 | optimizer.step() 76 | top1 = accuracy(predict, label[:, 1])[0] 77 | self.train_accuracy.update_with_weight(top1, label.shape[0]) 78 | # wandb.log({"LR": lr_scheduler.lr, 79 | # "Iter loss": loss.item()}) 80 | 81 | def test_per_epoch(self, dataloader, loss_meter, acc_meter): 82 | labels = [] 83 | result = [] 84 | 85 | self.model.eval() 86 | 87 | for time_series, 
node_feature, label in dataloader: 88 | time_series, node_feature, label = time_series.cuda(), node_feature.cuda(), label.cuda() 89 | output = self.model(time_series, node_feature) 90 | 91 | label = label.float() 92 | 93 | loss = self.loss_fn(output, label) 94 | loss_meter.update_with_weight( 95 | loss.item(), label.shape[0]) 96 | top1 = accuracy(output, label[:, 1])[0] 97 | acc_meter.update_with_weight(top1, label.shape[0]) 98 | result += F.softmax(output, dim=1)[:, 1].tolist() 99 | labels += label[:, 1].tolist() 100 | 101 | auc = roc_auc_score(labels, result) 102 | result, labels = np.array(result), np.array(labels) 103 | result[result > 0.5] = 1 104 | result[result <= 0.5] = 0 105 | metric = precision_recall_fscore_support( 106 | labels, result, average='micro') 107 | 108 | report = classification_report( 109 | labels, result, output_dict=True, zero_division=0) 110 | 111 | recall = [0, 0] 112 | for k in report: 113 | if isfloat(k): 114 | recall[int(float(k))] = report[k]['recall'] 115 | return [auc] + list(metric) + recall 116 | 117 | def generate_save_learnable_matrix(self): 118 | 119 | # wandb.log({'heatmap_with_text': wandb.plots.HeatMap(x_labels, y_labels, matrix_values, show_text=False)}) 120 | learable_matrixs = [] 121 | 122 | labels = [] 123 | 124 | for time_series, node_feature, label in self.test_dataloader: 125 | label = label.long() 126 | time_series, node_feature, label = time_series.cuda(), node_feature.cuda(), label.cuda() 127 | _, learable_matrix, _ = self.model(time_series, node_feature) 128 | 129 | learable_matrixs.append(learable_matrix.cpu().detach().numpy()) 130 | labels += label.tolist() 131 | 132 | self.save_path.mkdir(exist_ok=True, parents=True) 133 | np.save(self.save_path/"learnable_matrix.npy", {'matrix': np.vstack( 134 | learable_matrixs), "label": np.array(labels)}, allow_pickle=True) 135 | 136 | def save_result(self, results: torch.Tensor): 137 | self.save_path.mkdir(exist_ok=True, parents=True) 138 | np.save(self.save_path/"training_process.npy", 139 | results, allow_pickle=True) 140 | 141 | torch.save(self.model.state_dict(), self.save_path/"model.pt") 142 | 143 | def train(self): 144 | training_process = [] 145 | self.current_step = 0 146 | for epoch in range(self.epochs): 147 | self.reset_meters() 148 | self.train_per_epoch(self.optimizers[0], self.lr_schedulers[0]) 149 | val_result = self.test_per_epoch(self.val_dataloader, 150 | self.val_loss, self.val_accuracy) 151 | 152 | test_result = self.test_per_epoch(self.test_dataloader, 153 | self.test_loss, self.test_accuracy) 154 | 155 | self.logger.info(" | ".join([ 156 | f'Epoch[{epoch}/{self.epochs}]', 157 | f'Train Loss:{self.train_loss.avg: .3f}', 158 | f'Train Accuracy:{self.train_accuracy.avg: .3f}%', 159 | f'Test Loss:{self.test_loss.avg: .3f}', 160 | f'Test Accuracy:{self.test_accuracy.avg: .3f}%', 161 | f'Val AUC:{val_result[0]:.4f}', 162 | f'Test AUC:{test_result[0]:.4f}', 163 | f'Test Sen:{test_result[-1]:.4f}', 164 | f'LR:{self.lr_schedulers[0].lr:.4f}' 165 | ])) 166 | 167 | wandb.log({ 168 | "Train Loss": self.train_loss.avg, 169 | "Train Accuracy": self.train_accuracy.avg, 170 | "Test Loss": self.test_loss.avg, 171 | "Test Accuracy": self.test_accuracy.avg, 172 | "Val AUC": val_result[0], 173 | "Test AUC": test_result[0], 174 | 'Test Sensitivity': test_result[-1], 175 | 'Test Specificity': test_result[-2], 176 | 'micro F1': test_result[-4], 177 | 'micro recall': test_result[-5], 178 | 'micro precision': test_result[-6], 179 | }) 180 | 181 | training_process.append({ 182 | "Epoch": epoch, 183 
| "Train Loss": self.train_loss.avg, 184 | "Train Accuracy": self.train_accuracy.avg, 185 | "Test Loss": self.test_loss.avg, 186 | "Test Accuracy": self.test_accuracy.avg, 187 | "Test AUC": test_result[0], 188 | 'Test Sensitivity': test_result[-1], 189 | 'Test Specificity': test_result[-2], 190 | 'micro F1': test_result[-4], 191 | 'micro recall': test_result[-5], 192 | 'micro precision': test_result[-6], 193 | "Val AUC": val_result[0], 194 | "Val Loss": self.val_loss.avg, 195 | }) 196 | 197 | if self.save_learnable_graph: 198 | self.generate_save_learnable_matrix() 199 | self.save_result(training_process) 200 | -------------------------------------------------------------------------------- /source/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .accuracy import accuracy, isfloat 2 | from .meter import WeightedMeter, AverageMeter, TotalMeter 3 | from .count_params import count_params 4 | from .prepossess import mixup_criterion, continus_mixup_data, mixup_cluster_loss, intra_loss, inner_loss 5 | -------------------------------------------------------------------------------- /source/utils/accuracy.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | 5 | 6 | def accuracy(output: torch.Tensor, target: torch.Tensor, top_k=(1,)) -> List[float]: 7 | """Computes the precision@k for the specified values of k""" 8 | max_k = max(top_k) 9 | batch_size = target.size(0) 10 | 11 | _, predict = output.topk(max_k, 1, True, True) 12 | predict = predict.t() 13 | correct = predict.eq(target.view(1, -1).expand_as(predict)) 14 | 15 | res = [] 16 | for k in top_k: 17 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 18 | res.append(correct_k.mul_(100.0 / batch_size).item()) 19 | return res 20 | 21 | 22 | def isfloat(num): 23 | try: 24 | float(num) 25 | return True 26 | except ValueError: 27 | return False -------------------------------------------------------------------------------- /source/utils/count_params.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def count_params(model: nn.Module, only_requires_grad: bool = False): 5 | "count number trainable parameters in a pytorch model" 6 | if only_requires_grad: 7 | total_params = sum(p.numel() 8 | for p in model.parameters() if p.requires_grad) 9 | else: 10 | total_params = sum(p.numel() for p in model.parameters()) 11 | return total_params 12 | -------------------------------------------------------------------------------- /source/utils/gumbel_softmax.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def sample_gumbel(shape, eps=1e-20): 6 | U = torch.rand(shape).cuda() 7 | return -torch.autograd.Variable(torch.log(-torch.log(U + eps) + eps)) 8 | 9 | 10 | def gumbel_softmax_sample(logits, temperature, eps=1e-10): 11 | sample = sample_gumbel(logits.size(), eps=eps) 12 | y = logits + sample 13 | return F.softmax(y / temperature, dim=-1) 14 | 15 | 16 | def gumbel_softmax(logits, temperature, hard=False, eps=1e-10): 17 | """Sample from the Gumbel-Softmax distribution and optionally discretize. 18 | Args: 19 | logits: [batch_size, n_class] unnormalized log-probs 20 | temperature: non-negative scalar 21 | hard: if True, take argmax, but differentiate w.r.t. 
soft sample y 22 | Returns: 23 | [batch_size, n_class] sample from the Gumbel-Softmax distribution. 24 | If hard=True, then the returned sample will be one-hot, otherwise it will 25 | be a probabilitiy distribution that sums to 1 across classes 26 | """ 27 | y_soft = gumbel_softmax_sample(logits, temperature=temperature, eps=eps) 28 | if hard: 29 | shape = logits.size() 30 | _, k = y_soft.data.max(-1) 31 | y_hard = torch.zeros(*shape).cuda() 32 | y_hard = y_hard.zero_().scatter_(-1, k.view(shape[:-1] + (1,)), 1.0) 33 | y = torch.autograd.Variable(y_hard - y_soft.data) + y_soft 34 | else: 35 | y = y_soft 36 | return y 37 | -------------------------------------------------------------------------------- /source/utils/hyperboloid.py: -------------------------------------------------------------------------------- 1 | """Hyperboloid manifold.""" 2 | 3 | import torch 4 | 5 | from manifolds.base import Manifold 6 | from utils.math_utils import arcosh, cosh, sinh 7 | 8 | 9 | class Hyperboloid(Manifold): 10 | """ 11 | Hyperboloid manifold class. 12 | 13 | We use the following convention: -x0^2 + x1^2 + ... + xd^2 = -K 14 | 15 | c = 1 / K is the hyperbolic curvature. 16 | """ 17 | 18 | def __init__(self): 19 | super(Hyperboloid, self).__init__() 20 | self.name = 'Hyperboloid' 21 | self.eps = {torch.float32: 1e-7, torch.float64: 1e-15} 22 | self.min_norm = 1e-15 23 | self.max_norm = 1e6 24 | 25 | def minkowski_dot(self, x, y, keepdim=True): 26 | res = torch.sum(x * y, dim=-1) - 2 * x[..., 0] * y[..., 0] 27 | if keepdim: 28 | res = res.view(res.shape + (1,)) 29 | return res 30 | 31 | def minkowski_norm(self, u, keepdim=True): 32 | dot = self.minkowski_dot(u, u, keepdim=keepdim) 33 | return torch.sqrt(torch.clamp(dot, min=self.eps[u.dtype])) 34 | 35 | def sqdist(self, x, y, c): 36 | K = 1. / c 37 | prod = self.minkowski_dot(x, y) 38 | theta = torch.clamp(-prod / K, min=1.0 + self.eps[x.dtype]) 39 | sqdist = K * arcosh(theta) ** 2 40 | # clamp distance to avoid nans in Fermi-Dirac decoder 41 | return torch.clamp(sqdist, max=50.0) 42 | 43 | def proj(self, x, c): 44 | K = 1. / c 45 | d = x.size(-1) - 1 46 | y = x.narrow(-1, 1, d) 47 | y_sqnorm = torch.norm(y, p=2, dim=1, keepdim=True) ** 2 48 | mask = torch.ones_like(x) 49 | mask[:, 0] = 0 50 | vals = torch.zeros_like(x) 51 | vals[:, 0:1] = torch.sqrt(torch.clamp(K + y_sqnorm, min=self.eps[x.dtype])) 52 | return vals + mask * x 53 | 54 | def proj_tan(self, u, x, c): 55 | K = 1. / c 56 | d = x.size(1) - 1 57 | ux = torch.sum(x.narrow(-1, 1, d) * u.narrow(-1, 1, d), dim=1, keepdim=True) 58 | mask = torch.ones_like(u) 59 | mask[:, 0] = 0 60 | vals = torch.zeros_like(u) 61 | vals[:, 0:1] = ux / torch.clamp(x[:, 0:1], min=self.eps[x.dtype]) 62 | return vals + mask * u 63 | 64 | def proj_tan0(self, u, c): 65 | narrowed = u.narrow(-1, 0, 1) 66 | vals = torch.zeros_like(u) 67 | vals[:, 0:1] = narrowed 68 | return u - vals 69 | 70 | def expmap(self, u, x, c): 71 | K = 1. / c 72 | sqrtK = K ** 0.5 73 | normu = self.minkowski_norm(u) 74 | normu = torch.clamp(normu, max=self.max_norm) 75 | theta = normu / sqrtK 76 | theta = torch.clamp(theta, min=self.min_norm) 77 | result = cosh(theta) * x + sinh(theta) * u / theta 78 | return self.proj(result, c) 79 | 80 | def logmap(self, x, y, c): 81 | K = 1. 
/ c 82 | xy = torch.clamp(self.minkowski_dot(x, y) + K, max=-self.eps[x.dtype]) - K 83 | u = y + xy * x * c 84 | normu = self.minkowski_norm(u) 85 | normu = torch.clamp(normu, min=self.min_norm) 86 | dist = self.sqdist(x, y, c) ** 0.5 87 | result = dist * u / normu 88 | return self.proj_tan(result, x, c) 89 | 90 | def expmap0(self, u, c): 91 | K = 1. / c 92 | sqrtK = K ** 0.5 93 | d = u.size(-1) - 1 94 | x = u.narrow(-1, 1, d).view(-1, d) 95 | x_norm = torch.norm(x, p=2, dim=1, keepdim=True) 96 | x_norm = torch.clamp(x_norm, min=self.min_norm) 97 | theta = x_norm / sqrtK 98 | res = torch.ones_like(u) 99 | res[:, 0:1] = sqrtK * cosh(theta) 100 | res[:, 1:] = sqrtK * sinh(theta) * x / x_norm 101 | return self.proj(res, c) 102 | 103 | def logmap0(self, x, c): 104 | K = 1. / c 105 | sqrtK = K ** 0.5 106 | d = x.size(-1) - 1 107 | y = x.narrow(-1, 1, d).view(-1, d) 108 | y_norm = torch.norm(y, p=2, dim=1, keepdim=True) 109 | y_norm = torch.clamp(y_norm, min=self.min_norm) 110 | res = torch.zeros_like(x) 111 | theta = torch.clamp(x[:, 0:1] / sqrtK, min=1.0 + self.eps[x.dtype]) 112 | res[:, 1:] = sqrtK * arcosh(theta) * y / y_norm 113 | return res 114 | 115 | def mobius_add(self, x, y, c): 116 | u = self.logmap0(y, c) 117 | v = self.ptransp0(x, u, c) 118 | return self.expmap(v, x, c) 119 | 120 | def mobius_matvec(self, m, x, c): 121 | u = self.logmap0(x, c) 122 | mu = u @ m.transpose(-1, -2) 123 | return self.expmap0(mu, c) 124 | 125 | def ptransp(self, x, y, u, c): 126 | logxy = self.logmap(x, y, c) 127 | logyx = self.logmap(y, x, c) 128 | sqdist = torch.clamp(self.sqdist(x, y, c), min=self.min_norm) 129 | alpha = self.minkowski_dot(logxy, u) / sqdist 130 | res = u - alpha * (logxy + logyx) 131 | return self.proj_tan(res, y, c) 132 | 133 | def ptransp0(self, x, u, c): 134 | K = 1. / c 135 | sqrtK = K ** 0.5 136 | x0 = x.narrow(-1, 0, 1) 137 | d = x.size(-1) - 1 138 | y = x.narrow(-1, 1, d) 139 | y_norm = torch.clamp(torch.norm(y, p=2, dim=1, keepdim=True), min=self.min_norm) 140 | y_normalized = y / y_norm 141 | v = torch.ones_like(x) 142 | v[:, 0:1] = - y_norm 143 | v[:, 1:] = (sqrtK - x0) * y_normalized 144 | alpha = torch.sum(y_normalized * u[:, 1:], dim=1, keepdim=True) / sqrtK 145 | res = u - alpha * v 146 | return self.proj_tan(res, x, c) 147 | 148 | def to_poincare(self, x, c): 149 | K = 1. 
/ c 150 | sqrtK = K ** 0.5 151 | d = x.size(-1) - 1 152 | return sqrtK * x.narrow(-1, 1, d) / (x[:, 0:1] + sqrtK) 153 | 154 | -------------------------------------------------------------------------------- /source/utils/meter.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | 4 | class WeightedMeter: 5 | def __init__(self, name: str = None): 6 | self.name = name 7 | self.count = 0 8 | self.sum = 0.0 9 | self.avg = 0.0 10 | self.val = 0.0 11 | 12 | def update(self, val: float, num: int = 1): 13 | self.count += num 14 | self.sum += val * num 15 | self.avg = self.sum / self.count 16 | self.val = val 17 | 18 | def reset(self, total: float = 0, count: int = 0): 19 | self.count = count 20 | self.sum = total 21 | self.avg = total / max(count, 1) 22 | self.val = total / max(count, 1) 23 | 24 | 25 | class AverageMeter: 26 | def __init__(self, length: int, name: str = None): 27 | assert length > 0 28 | self.name = name 29 | self.count = 0 30 | self.sum = 0.0 31 | self.current: int = -1 32 | self.history: List[float] = [None] * length 33 | 34 | @property 35 | def val(self) -> float: 36 | return self.history[self.current] 37 | 38 | @property 39 | def avg(self) -> float: 40 | return self.sum / self.count 41 | 42 | def update(self, val: float): 43 | self.current = (self.current + 1) % len(self.history) 44 | self.sum += val 45 | 46 | old = self.history[self.current] 47 | if old is None: 48 | self.count += 1 49 | else: 50 | self.sum -= old 51 | self.history[self.current] = val 52 | 53 | 54 | class TotalMeter: 55 | def __init__(self): 56 | self.sum = 0.0 57 | self.count = 0 58 | 59 | def update(self, val: float): 60 | self.sum += val 61 | self.count += 1 62 | 63 | def update_with_weight(self, val: float, count: int): 64 | self.sum += val*count 65 | self.count += count 66 | 67 | def reset(self): 68 | self.sum = 0 69 | self.count = 0 70 | 71 | @property 72 | def avg(self): 73 | if self.count == 0: 74 | return -1 75 | return self.sum / self.count 76 | -------------------------------------------------------------------------------- /source/utils/prepossess.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def continus_mixup_data(*xs, y=None, alpha=1.0, device='cuda'): 6 | '''Returns mixed inputs, pairs of targets, and lambda''' 7 | if alpha > 0: 8 | lam = np.random.beta(alpha, alpha) 9 | else: 10 | lam = 1 11 | batch_size = y.size()[0] 12 | index = torch.randperm(batch_size).to(device) 13 | new_xs = [lam * x + (1 - lam) * x[index, :] for x in xs] 14 | y = lam * y + (1-lam) * y[index] 15 | return *new_xs, y 16 | 17 | 18 | def mixup_data_by_class(x, nodes, y, alpha=1.0, device='cuda'): 19 | '''Returns mixed inputs, pairs of targets, and lambda''' 20 | 21 | mix_xs, mix_nodes, mix_ys = [], [], [] 22 | 23 | for t_y in y.unique(): 24 | idx = y == t_y 25 | 26 | t_mixed_x, t_mixed_nodes, _, _, _ = continus_mixup_data( 27 | x[idx], nodes[idx], y[idx], alpha=alpha, device=device) 28 | mix_xs.append(t_mixed_x) 29 | mix_nodes.append(t_mixed_nodes) 30 | 31 | mix_ys.append(y[idx]) 32 | 33 | return torch.cat(mix_xs, dim=0), torch.cat(mix_nodes, dim=0), torch.cat(mix_ys, dim=0) 34 | 35 | 36 | def mixup_criterion(criterion, pred, y_a, y_b, lam): 37 | return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b) 38 | 39 | 40 | def mixup_cluster_loss(matrixs, y, intra_weight=2): 41 | 42 | y_1 = y[:, 1] 43 | 44 | y_0 = y[:, 0] 45 | 46 | bz, roi_num, _ = matrixs.shape 
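    # What follows flattens each learnable matrix, builds one "soft" center per class from the
    # (possibly mixed) soft labels y_0 / y_1, penalizes the soft-label-weighted L1 distance of
    # each sample's matrix to its class center, and adds a term (scaled by intra_weight) that
    # rewards separation between the two centers.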
47 | matrixs = matrixs.reshape((bz, -1)) 48 | sum_1 = torch.sum(y_1) 49 | sum_0 = torch.sum(y_0) 50 | loss = 0.0 51 | 52 | if sum_0 > 0: 53 | center_0 = torch.matmul(y_0, matrixs)/sum_0 54 | diff_0 = torch.norm(matrixs-center_0, p=1, dim=1) 55 | loss += torch.matmul(y_0, diff_0)/(sum_0*roi_num*roi_num) 56 | if sum_1 > 0: 57 | center_1 = torch.matmul(y_1, matrixs)/sum_1 58 | diff_1 = torch.norm(matrixs-center_1, p=1, dim=1) 59 | loss += torch.matmul(y_1, diff_1)/(sum_1*roi_num*roi_num) 60 | if sum_0 > 0 and sum_1 > 0: 61 | loss += intra_weight * \ 62 | (1 - torch.norm(center_0-center_1, p=1)/(roi_num*roi_num)) 63 | 64 | return loss 65 | 66 | 67 | def inner_loss(label, matrixs):  # penalizes within-class variance of the learned matrices 68 | 69 | loss = 0 70 | 71 | if torch.sum(label == 0) > 1: 72 | loss += torch.mean(torch.var(matrixs[label == 0], dim=0)) 73 | 74 | if torch.sum(label == 1) > 1: 75 | loss += torch.mean(torch.var(matrixs[label == 1], dim=0)) 76 | 77 | return loss 78 | 79 | 80 | def intra_loss(label, matrixs):  # despite the name, rewards separation between the two class-mean matrices 81 | a, b = None, None 82 | 83 | if torch.sum(label == 0) > 0: 84 | a = torch.mean(matrixs[label == 0], dim=0) 85 | 86 | if torch.sum(label == 1) > 0: 87 | b = torch.mean(matrixs[label == 1], dim=0) 88 | if a is not None and b is not None: 89 | return 1 - torch.mean(torch.pow(a-b, 2)) 90 | else: 91 | return 0 92 | --------------------------------------------------------------------------------
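For reference, the mixup preprocessing and the group (cluster) loss defined in `source/utils/prepossess.py` are the pieces that `FBNetTrain.train_per_epoch` combines during training. The sketch below is a minimal, CPU-only illustration of how they are called on toy tensors; it is not part of the repository, the tensor sizes are illustrative, and it assumes the `source` package is importable from the repository root.

```python
import torch

from source.utils import continus_mixup_data, mixup_cluster_loss

# Illustrative sizes only; the real node and feature sizes come from the dataset config.
batch_size, node_sz, timepoints = 4, 360, 100

time_series = torch.randn(batch_size, node_sz, timepoints)
node_feature = torch.randn(batch_size, node_sz, node_sz)
# two-class one-hot labels, matching the label format used by the training loops
label = torch.nn.functional.one_hot(
    torch.randint(0, 2, (batch_size,)), num_classes=2).float()

# mixup draws one Beta(alpha, alpha) coefficient and interpolates every input
# (and the labels) with a random permutation of the batch
time_series, node_feature, label = continus_mixup_data(
    time_series, node_feature, y=label, alpha=1.0, device='cpu')

# stand-in for the learnable adjacency matrix returned by FBNETGEN's forward pass
learnable_matrix = torch.rand(batch_size, node_sz, node_sz)
group_loss = mixup_cluster_loss(learnable_matrix, label)
print(group_loss.item())
```

Note that `mixup_cluster_loss` expects two-column soft labels (the mixed one-hot vectors), while `inner_loss` and `intra_loss` take the hard class indices `label[:, 1]`, which is why `FBNetTrain.train_per_epoch` switches between them based on `preprocess.continus`.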