├── README.md
├── simclr
│   ├── README.md
│   ├── nt_xent.py
│   ├── simclr_datasets.py
│   ├── simclr_eval.py
│   ├── simclr_models.py
│   └── simclr_ecg.py
├── multitask
│   ├── README.md
│   ├── splitters.py
│   ├── dataloader.py
│   ├── finetune_ft30.py
│   ├── loader.py
│   ├── finetune.py
│   ├── util.py
│   ├── batch.py
│   ├── model.py
│   ├── pretrain_supervised_weighting.py
│   └── pretrain_supervised_weighting_ft30.py
└── LICENSE
/README.md:
--------------------------------------------------------------------------------
1 | # Meta-Learning to Improve Pre-Training
2 | This folder contains code to run the experiments in the paper [Meta-Learning to Improve Pre-Training, NeurIPS 2021](https://openreview.net/forum?id=Wiq6Mg8btwT). Please refer to the README files in the `multitask/` and `simclr/` folders for experiments in the multitask PT and self-supervised PT domains, respectively.
3 | 
4 | We also include a self-contained notebook that can be run on Google Colab to examine the synthetic MNIST data augmentation domain [here](https://colab.research.google.com/drive/1k5sc7Ij1wxCRFdD7aObJmuj6qdBcEFdm?usp=sharing).
5 | 
--------------------------------------------------------------------------------
/simclr/README.md:
--------------------------------------------------------------------------------
1 | # Meta-Parameterized SimCLR
2 | This folder contains code to run the meta-parameterized SimCLR experiments in the paper.
3 | 
4 | 
5 | ## Getting started
6 | Download the dataset following the instructions [here](https://www.physionet.org/content/ptb-xl/1.0.1/). Once you have downloaded the dataset, set the `path` variable in `simclr_datasets.py` to the path to the data.
7 | 
8 | Install the core libraries: `pip install torch higher wfdb pandas numpy`
9 | 
10 | 
11 | ## Pre-training
12 | 
13 | **No augmentation learning:** To train a SimCLR model with the default augmentations, run:
14 | ```python simclr_ecg.py --warmup_epochs 100 --epochs 50 --teacherarch warpexmag --gpu GPU --seed SEED```
15 | 
16 | Because the number of warmup epochs exceeds the total number of epochs, the augmentations are never optimized.
17 | 
18 | 
19 | **Augmentation learning:** To train a SimCLR model and optimize the augmentations, with `N` MetaFT examples, run:
20 | ```python simclr_ecg.py --warmup_epochs 1 --epochs 50 --teacherarch warpexmag --gpu GPU --seed SEED --ex N```
21 | 
22 | 
23 | 
24 | ## Fine-tuning
25 | 
26 | To fine-tune a pre-trained model on `NFT` fine-tuning examples (where the MetaFT dataset used during PT had `N` data points), with FT seed `RUNSEED` and dataset seed (i.e., PT seed) `SEED`:
27 | 
28 | `python simclr_eval.py --gpu GPU --checkpoint /PATH/TO/CHECKPOINT --transfer_eval --runseed RUNSEED --seed SEED --ex NFT`
29 | 
30 | Note: the partial FT access setting corresponds to `NFT` being greater than the `N` MetaFT examples used during PT.
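For reference, the NT-Xent contrastive loss implemented in `nt_xent.py` (the next file) can be exercised standalone. A minimal sketch, assuming `simclr/` is on the import path; the batch size, projection dimension, and temperature below are illustrative values, not the paper's settings:

```python
import torch
from nt_xent import NTXentLoss

batch_size, proj_dim = 64, 256
criterion = NTXentLoss(device='cpu', batch_size=batch_size,
                       temperature=0.5, use_cosine_similarity=True)

zis = torch.randn(batch_size, proj_dim)  # projections of augmented view 1
zjs = torch.randn(batch_size, proj_dim)  # projections of augmented view 2

loss = criterion(zis, zjs)  # scalar loss, averaged over the 2 * batch_size samples
```

Note that the loss expects exactly `batch_size` embeddings per view, which is why the pre-training loader in `simclr_datasets.py` is built with `drop_last=True`.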
31 | -------------------------------------------------------------------------------- /simclr/nt_xent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class NTXentLoss(torch.nn.Module): 6 | 7 | def __init__(self, device, batch_size, temperature, use_cosine_similarity): 8 | super(NTXentLoss, self).__init__() 9 | self.batch_size = batch_size 10 | self.temperature = temperature 11 | self.device = device 12 | self.softmax = torch.nn.Softmax(dim=-1) 13 | self.mask_samples_from_same_repr = self._get_correlated_mask().type(torch.bool) 14 | self.similarity_function = self._get_similarity_function(use_cosine_similarity) 15 | self.criterion = torch.nn.CrossEntropyLoss(reduction="sum") 16 | 17 | def _get_similarity_function(self, use_cosine_similarity): 18 | if use_cosine_similarity: 19 | self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1) 20 | return self._cosine_simililarity 21 | else: 22 | return self._dot_simililarity 23 | 24 | def _get_correlated_mask(self): 25 | diag = np.eye(2 * self.batch_size) 26 | l1 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=-self.batch_size) 27 | l2 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=self.batch_size) 28 | mask = torch.from_numpy((diag + l1 + l2)) 29 | mask = (1 - mask).type(torch.bool) 30 | return mask.to(self.device) 31 | 32 | @staticmethod 33 | def _dot_simililarity(x, y): 34 | v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2) 35 | # x shape: (N, 1, C) 36 | # y shape: (1, C, 2N) 37 | # v shape: (N, 2N) 38 | return v 39 | 40 | def _cosine_simililarity(self, x, y): 41 | # x shape: (N, 1, C) 42 | # y shape: (1, 2N, C) 43 | # v shape: (N, 2N) 44 | v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0)) 45 | return v 46 | 47 | def forward(self, zis, zjs): 48 | representations = torch.cat([zjs, zis], dim=0) 49 | 50 | similarity_matrix = self.similarity_function(representations, representations) 51 | 52 | # filter out the scores from the positive samples 53 | l_pos = torch.diag(similarity_matrix, self.batch_size) 54 | r_pos = torch.diag(similarity_matrix, -self.batch_size) 55 | positives = torch.cat([l_pos, r_pos]).view(2 * self.batch_size, 1) 56 | 57 | negatives = similarity_matrix[self.mask_samples_from_same_repr].view(2 * self.batch_size, -1) 58 | 59 | logits = torch.cat((positives, negatives), dim=1) 60 | logits /= self.temperature 61 | 62 | labels = torch.zeros(2 * self.batch_size).to(self.device).long() 63 | loss = self.criterion(logits, labels) 64 | 65 | return loss / (2 * self.batch_size) -------------------------------------------------------------------------------- /multitask/README.md: -------------------------------------------------------------------------------- 1 | # Meta-Parameterized Multitask PT 2 | This folder contains code to run the meta-parameterized multitask PT experiments in the paper. Note that most of the dataloading, model definition, and training code is based on the implementation of `Strategies for Pre-training Graph Neural Nets', Hu et al., ICLR 2020 at [this link](https://github.com/snap-stanford/pretrain-gnns/). 3 | 4 | 5 | ## Getting started 6 | Follow the installation instructions from Hu et al., ICLR 2020 at [this link](https://github.com/snap-stanford/pretrain-gnns/). Download the biological dataset, and all the python dependencies. Also install the `higher` library: `pip install higher`. 
Once you have downloaded the dataset, set the `root_supervised` variable in the training scripts and in `util.py` to that path.
7 | 
8 | 
9 | ## Pre-training
10 | 
11 | ### Full FT Access
12 | To train the full FT access model, run:
13 | ```python pretrain_supervised_weighting.py --gpu 0 --savefol exw-adamhyper```
14 | 
15 | ### Partial FT Access
16 | To train the partial FT access model with a fraction 0.5 of the total data, run:
17 | ```python pretrain_supervised_weighting.py --gpu 0 --savefol exw-adamhyper --smallft 0.5```
18 | 
19 | To train the partial FT access model on a subset of the FT tasks, say fold 0 (tasks 0-29 in meta PT, tasks 30-39 at FT time), run:
20 | ```python pretrain_supervised_weighting_ft30.py --gpu 0 --fold 0```
21 | 
22 | ## Fine-tuning
23 | 
24 | ### Full FT Access and Partial FT access with small MetaFT dataset
25 | To fine-tune the full FT access model, or a model trained with a small MetaFT dataset, run the command below, replacing `OUTPUT_FILENAME` with the desired output filename, `SEED` with the fine-tuning seed, `LR` with the FT learning rate, and `PATH/TO/CHECKPOINT` with the path to the pre-trained checkpoint.
26 | 
27 | ```python finetune.py --device 0 --filename OUTPUT_FILENAME --runseed SEED --lr LR --model_file PATH/TO/CHECKPOINT```
28 | 
29 | By default, this does full transfer. To do linear evaluation instead, add the `--lineval` flag to the end of the above command. We used an LR of 1e-5 for full transfer and 1e-4 for linear eval.
30 | 
31 | 
32 | ### Partial FT Access model with subset of tasks
33 | To fine-tune, run the command below, replacing `OUTPUT_FILENAME` with the desired output filename, `SEED` with the fine-tuning seed, `LR` with the FT learning rate, `PATH/TO/CHECKPOINT` with the path to the pre-trained checkpoint, and `FOLD` with the desired fold.
34 | 
35 | ```python finetune_ft30.py --device 0 --filename OUTPUT_FILENAME --runseed SEED --lr LR --model_file PATH/TO/CHECKPOINT --fold FOLD```
36 | 
37 | By default, this does full transfer. To do linear evaluation instead, add the `--lineval` flag to the end of the above command. We used an LR of 1e-5 for full transfer and 1e-4 for linear eval.
38 | 
--------------------------------------------------------------------------------
/multitask/splitters.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import numpy as np
4 | 
5 | def random_split(dataset, frac_train=0.8, frac_valid=0.1, frac_test=0.1,
6 |                  seed=0):
7 |     """
8 |     Adapted from graph-pretrain
9 |     :param dataset:
10 |     :param task_idx:
11 |     :param null_value:
12 |     :param frac_train:
13 |     :param frac_valid:
14 |     :param frac_test:
15 |     :param seed:
16 |     :return: train, valid, test slices of the input dataset obj.
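    Example (illustrative usage):
        train, valid, test = random_split(dataset, frac_train=0.8,
                                          frac_valid=0.1, frac_test=0.1,
                                          seed=0)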
17 |     """
18 |     np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.0)
19 | 
20 |     num_mols = len(dataset)
21 |     random.seed(seed)
22 |     all_idx = list(range(num_mols))
23 |     random.shuffle(all_idx)
24 | 
25 |     train_idx = all_idx[:int(frac_train * num_mols)]
26 |     valid_idx = all_idx[int(frac_train * num_mols):int(frac_valid * num_mols)
27 |                         + int(frac_train * num_mols)]
28 |     test_idx = all_idx[int(frac_valid * num_mols) + int(frac_train * num_mols):]
29 | 
30 |     assert len(set(train_idx).intersection(set(valid_idx))) == 0
31 |     assert len(set(valid_idx).intersection(set(test_idx))) == 0
32 |     assert len(train_idx) + len(valid_idx) + len(test_idx) == num_mols
33 | 
34 |     train_dataset = dataset[torch.tensor(train_idx)]
35 |     valid_dataset = dataset[torch.tensor(valid_idx)]
36 |     if frac_test == 0:
37 |         test_dataset = None
38 |     else:
39 |         test_dataset = dataset[torch.tensor(test_idx)]
40 | 
41 |     return train_dataset, valid_dataset, test_dataset
42 | 
43 | def species_split(dataset, train_valid_species_id_list=[3702, 6239, 511145,
44 |                                                         7227, 10090, 4932, 7955],
45 |                   test_species_id_list=[9606]):
46 |     """
47 |     Split dataset based on species_id attribute
48 |     :param dataset:
49 |     :param train_valid_species_id_list:
50 |     :param test_species_id_list:
51 |     :return: train_valid dataset, test dataset
52 |     """
53 |     # NB: pytorch geometric dataset object can be indexed using slices or
54 |     # byte tensors. We will use byte tensors here
55 | 
56 |     train_valid_byte_tensor = torch.zeros(len(dataset), dtype=torch.uint8)
57 |     for id in train_valid_species_id_list:
58 |         train_valid_byte_tensor += (dataset.data.species_id == id)
59 | 
60 |     test_species_byte_tensor = torch.zeros(len(dataset), dtype=torch.uint8)
61 |     for id in test_species_id_list:
62 |         test_species_byte_tensor += (dataset.data.species_id == id)
63 | 
64 |     assert ((train_valid_byte_tensor + test_species_byte_tensor) == 1).all()
65 | 
66 |     train_valid_dataset = dataset[train_valid_byte_tensor]
67 |     test_dataset = dataset[test_species_byte_tensor]
68 | 
69 |     return train_valid_dataset, test_dataset
70 | 
71 | if __name__ == "__main__":
72 |     from collections import Counter
--------------------------------------------------------------------------------
/multitask/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data
2 | from torch.utils.data.dataloader import default_collate
3 | 
4 | from batch import BatchFinetune, BatchMasking, BatchAE, BatchSubstructContext
5 | 
6 | class DataLoaderFinetune(torch.utils.data.DataLoader):
7 |     r"""Data loader which merges data objects from a
8 |     :class:`torch_geometric.data.dataset` to a mini-batch.
9 |     Args:
10 |         dataset (Dataset): The dataset from which to load the data.
11 |         batch_size (int, optional): How many samples per batch to load.
12 |             (default: :obj:`1`)
13 |         shuffle (bool, optional): If set to :obj:`True`, the data will be
14 |             reshuffled at every epoch (default: :obj:`True`)
15 |     """
16 | 
17 |     def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs):
18 |         super(DataLoaderFinetune, self).__init__(
19 |             dataset,
20 |             batch_size,
21 |             shuffle,
22 |             collate_fn=lambda data_list: BatchFinetune.from_data_list(data_list),
23 |             **kwargs)
24 | 
25 | class DataLoaderMasking(torch.utils.data.DataLoader):
26 |     r"""Data loader which merges data objects from a
27 |     :class:`torch_geometric.data.dataset` to a mini-batch.
28 |     Args:
29 |         dataset (Dataset): The dataset from which to load the data.
30 |         batch_size (int, optional): How many samples per batch to load.
31 |             (default: :obj:`1`)
32 |         shuffle (bool, optional): If set to :obj:`True`, the data will be
33 |             reshuffled at every epoch (default: :obj:`True`)
34 |     """
35 | 
36 |     def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs):
37 |         super(DataLoaderMasking, self).__init__(
38 |             dataset,
39 |             batch_size,
40 |             shuffle,
41 |             collate_fn=lambda data_list: BatchMasking.from_data_list(data_list),
42 |             **kwargs)
43 | 
44 | 
45 | class DataLoaderAE(torch.utils.data.DataLoader):
46 |     r"""Data loader which merges data objects from a
47 |     :class:`torch_geometric.data.dataset` to a mini-batch.
48 |     Args:
49 |         dataset (Dataset): The dataset from which to load the data.
50 |         batch_size (int, optional): How many samples per batch to load.
51 |             (default: :obj:`1`)
52 |         shuffle (bool, optional): If set to :obj:`True`, the data will be
53 |             reshuffled at every epoch (default: :obj:`True`)
54 |     """
55 | 
56 |     def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs):
57 |         super(DataLoaderAE, self).__init__(
58 |             dataset,
59 |             batch_size,
60 |             shuffle,
61 |             collate_fn=lambda data_list: BatchAE.from_data_list(data_list),
62 |             **kwargs)
63 | 
64 | 
65 | class DataLoaderSubstructContext(torch.utils.data.DataLoader):
66 |     r"""Data loader which merges data objects from a
67 |     :class:`torch_geometric.data.dataset` to a mini-batch.
68 |     Args:
69 |         dataset (Dataset): The dataset from which to load the data.
70 |         batch_size (int, optional): How many samples per batch to load.
71 |             (default: :obj:`1`)
72 |         shuffle (bool, optional): If set to :obj:`True`, the data will be
73 |             reshuffled at every epoch (default: :obj:`True`)
74 |     """
75 | 
76 |     def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs):
77 |         super(DataLoaderSubstructContext, self).__init__(
78 |             dataset,
79 |             batch_size,
80 |             shuffle,
81 |             collate_fn=lambda data_list: BatchSubstructContext.from_data_list(data_list),
82 |             **kwargs)
83 | 
84 | 
85 | 
--------------------------------------------------------------------------------
/simclr/simclr_datasets.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.utils.data import DataLoader, Dataset
4 | from torch.utils.data.sampler import SubsetRandomSampler
5 | import random
6 | # np.random.seed(0)
7 | 
8 | import pandas as pd
9 | import wfdb
10 | import ast
11 | 
12 | 
13 | class ECGSimCLR(Dataset):
14 |     def __init__(self, x, y, transform, simclr=True):
15 |         super(ECGSimCLR, self).__init__()
16 |         # do some padding here.
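        # Added note: PTB-XL records at 100 Hz are 10 s long, i.e. 1000
        # samples; the check below zero-pads the time axis from 1000 to 1024,
        # presumably so that strided/pooling layers divide the length evenly.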
17 | if x.shape[1] != 1024 and x.shape[1] == 1000: 18 | # pad 19 | x = np.pad(x, [[0,0], [0,24], [0,0]]) 20 | self.x = x.astype(np.float32) 21 | self.y = y.astype(np.float32) 22 | self.transform = transform 23 | self.simclr = simclr 24 | 25 | def __len__(self): 26 | return len(self.x) 27 | 28 | def __getitem__(self, idx): 29 | x = self.x[idx] 30 | y = self.y[idx] 31 | if self.transform is not None: 32 | x = self.transform(x) 33 | if self.simclr: 34 | return x 35 | else: 36 | sample = (x, y) 37 | return sample 38 | 39 | 40 | class ECGDataSetWrapper(object): 41 | 42 | def __init__(self, batch_size, num_workers=0): 43 | self.batch_size = batch_size 44 | self.num_workers = num_workers 45 | 46 | def get_data_loaders(self, args, evaluate = False): 47 | 48 | def load_raw_data(df, sampling_rate, path): 49 | if sampling_rate == 100: 50 | data = [wfdb.rdsamp(path+f) for f in df.filename_lr] 51 | else: 52 | data = [wfdb.rdsamp(path+f) for f in df.filename_hr] 53 | data = np.array([signal for signal, meta in data]) 54 | return data 55 | 56 | def aggregate_diagnostic(y_dic): 57 | tmp = np.zeros(5) 58 | idxd = {'NORM' : 0, 'MI' : 1, 'STTC' : 2, 'CD' : 3, 'HYP' : 4} 59 | for key in y_dic.keys(): 60 | if key in agg_df.index: 61 | cls = agg_df.loc[key].diagnostic_class 62 | tmp[idxd[cls]] = 1 63 | return tmp 64 | 65 | path = 'path/to/dataset/' 66 | sampling_rate=100 67 | 68 | # load and convert annotation data 69 | Y = pd.read_csv(path+'ptbxl_database.csv', index_col='ecg_id') 70 | Y.scp_codes = Y.scp_codes.apply(lambda x: ast.literal_eval(x)) 71 | 72 | # Load raw signal data 73 | X = load_raw_data(Y, sampling_rate, path) 74 | 75 | # Load scp_statements.csv for diagnostic aggregation 76 | agg_df = pd.read_csv(path+'scp_statements.csv', index_col=0) 77 | agg_df = agg_df[agg_df.diagnostic == 1] 78 | 79 | # Apply diagnostic superclass 80 | Y['diagnostic_superclass'] = Y.scp_codes.apply(aggregate_diagnostic) 81 | 82 | # Split data into train and test 83 | test_fold = 10 84 | # Train 85 | X_train = X[np.where(Y.strat_fold != test_fold)] 86 | y_train = Y[(Y.strat_fold != test_fold)].diagnostic_superclass 87 | y_train = np.stack(y_train, axis=0) 88 | 89 | # Test 90 | X_test = X[np.where(Y.strat_fold == test_fold)] 91 | y_test = Y[Y.strat_fold == test_fold].diagnostic_superclass 92 | y_test = np.stack(y_test, axis=0) 93 | 94 | data_augment = self.get_simclr_pipeline_transform() 95 | data_augment = SimCLRDataTransform(data_augment) 96 | 97 | FT_TASKS = 5 98 | 99 | # Normalisation: follow PTB-XL demo code. 
Do zero mean, unit var normalisation across all leads, timesteps, and patients 100 | meansig = np.mean(X_train.reshape(-1)) 101 | stdsig = np.std(X_train.reshape(-1)) 102 | X_train = (X_train - meansig)/stdsig 103 | X_test = (X_test - meansig)/stdsig 104 | 105 | pretrain_dataset = ECGSimCLR(X_train, y_train, data_augment) 106 | 107 | torch.manual_seed(args.seed) 108 | random.seed(args.seed) 109 | np.random.seed(args.seed) 110 | 111 | rng = np.random.RandomState(args.seed) 112 | idxs = np.arange(len(y_train)) 113 | rng.shuffle(idxs) 114 | 115 | 116 | if args.ex >= 50: 117 | train_samp = int(0.8*args.ex) 118 | val_samp = args.ex - train_samp 119 | else: 120 | if args.ex == 25: 121 | train_samp = 15 122 | val_samp = 10 123 | elif args.ex == 10: 124 | train_samp = val_samp = 5 125 | 126 | train_idxs = idxs[:train_samp] 127 | val_idxs = idxs[train_samp:train_samp+val_samp] 128 | 129 | ft_train = ECGSimCLR(X_train[train_idxs], y_train[train_idxs], transform=None, simclr=False) 130 | ft_val = ECGSimCLR(X_train[val_idxs], y_train[val_idxs], transform=None, simclr=False) 131 | ft_test = ECGSimCLR(X_test, y_test, transform=None, simclr=False) 132 | 133 | pretrain_loader = torch.utils.data.DataLoader(dataset=pretrain_dataset, 134 | batch_size=args.batch_size, 135 | shuffle=True, 136 | num_workers=0, 137 | drop_last=True) 138 | ft_train_loader = torch.utils.data.DataLoader(dataset=ft_train, 139 | batch_size=args.batch_size, 140 | shuffle=True, 141 | num_workers=0) 142 | ft_val_loader = torch.utils.data.DataLoader(dataset=ft_val, 143 | batch_size=args.batch_size, 144 | shuffle=True, 145 | num_workers=0) 146 | ft_test_loader = torch.utils.data.DataLoader(dataset=ft_test, 147 | batch_size=args.batch_size, 148 | shuffle=True, 149 | num_workers=0) 150 | 151 | return pretrain_loader, ft_train_loader, ft_val_loader, ft_test_loader, None, FT_TASKS 152 | 153 | def get_simclr_pipeline_transform(self): 154 | def rand_crop_ecg(ecg): 155 | cropped_ecg = ecg.copy() 156 | for j in range(ecg.shape[1]): 157 | crop_len = np.random.randint(len(ecg)) // 2 158 | crop_start = max(0, np.random.randint(-crop_len, len(ecg))) 159 | cropped_ecg[crop_start: crop_start + crop_len, j] = 0 160 | return cropped_ecg 161 | def rand_add_noise(ecg): 162 | noise_frac = np.random.rand() * .1 163 | return ecg + noise_frac * ecg.std(axis=0) * np.random.randn(*ecg.shape) 164 | data_transforms = [rand_crop_ecg, rand_add_noise] 165 | return data_transforms 166 | 167 | 168 | 169 | class SimCLRDataTransform(object): 170 | def __init__(self, transform): 171 | self.transform = transform 172 | 173 | def __call__(self, sample): 174 | xi = sample.copy() 175 | for t in self.transform: 176 | xi = t(xi) 177 | 178 | xj = sample.copy() 179 | for t in self.transform: 180 | xj = t(xj) 181 | 182 | return xi.astype(np.float32), xj.astype(np.float32) 183 | -------------------------------------------------------------------------------- /simclr/simclr_eval.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | import os 4 | import torch 5 | import torch.nn as nn 6 | import torch.utils.data as utils 7 | import torch.nn.functional as F 8 | import higher 9 | import pickle 10 | from torch.utils.data import Dataset, DataLoader, Subset 11 | from torchvision import datasets, transforms 12 | from torch.autograd import grad 13 | from tqdm import tqdm 14 | from simclr_models import * 15 | from simclr_datasets import * 16 | from nt_xent import NTXentLoss 17 | 18 | from torch.backends import 
cudnn 19 | cudnn.deterministic = True 20 | cudnn.benchmark = False 21 | 22 | import argparse 23 | 24 | parser = argparse.ArgumentParser(description='Eval SIMCLR ECG') 25 | 26 | parser.add_argument('--seed', type=int, default=0) 27 | parser.add_argument('--runseed', type=int, default=0) 28 | parser.add_argument('--gpu', type=int, default=0) 29 | parser.add_argument('--ex', type=int, default=500, help='num data points') 30 | 31 | parser.add_argument('--finetune_lr', type=float, default=1e-3) 32 | parser.add_argument('--epochs', default=200, type=int) 33 | parser.add_argument('--studentarch', type=str, default='resnet18') 34 | parser.add_argument('--dataset', type=str, default='ecg') 35 | parser.add_argument('--batch_size', type=int, default=64) 36 | parser.add_argument('--savefol', type=str, default='simclr-ecg-eval') 37 | parser.add_argument('--transfer_eval', action='store_true') 38 | parser.add_argument('--checkpoint', type=str) 39 | 40 | args = parser.parse_args() 41 | 42 | torch.manual_seed(args.runseed) 43 | torch.multiprocessing.set_sharing_strategy('file_system') 44 | 45 | 46 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 47 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 48 | 49 | if args.transfer_eval: 50 | args.savefol += f'-transfereval-{args.ex}ex' 51 | else: 52 | args.savefol += f'-lineval-{args.ex}ex' 53 | 54 | class AverageMeter(object): 55 | """Computes and stores the average and current value""" 56 | def __init__(self): 57 | self.reset() 58 | 59 | def reset(self): 60 | self.val = 0 61 | self.avg = 0 62 | self.sum = 0 63 | self.count = 0 64 | 65 | def update(self, val, n=1): 66 | self.val = val 67 | self.sum += val * n 68 | self.count += n 69 | self.avg = self.sum / self.count 70 | 71 | 72 | def model_saver(epoch, student, head, teacher, pt_opt, pt_sched, ft_opt, hyp_opt, path): 73 | torch.save({ 74 | 'student_sd': student.state_dict(), 75 | 'teacher_sd': teacher.state_dict() if teacher is not None else None, 76 | 'head_sd': head.state_dict(), 77 | 'ft_opt_state_dict': ft_opt.state_dict(), 78 | }, path + f'/checkpoint_epoch{epoch}.pt') 79 | 80 | 81 | def get_save_path(): 82 | modfol = f"""seed{args.seed}-runseed{args.runseed}-student{args.studentarch}-ftlr{args.finetune_lr}-epochs{args.epochs}-ckpt{args.checkpoint}""" 83 | pth = os.path.join(args.savefol, modfol) 84 | os.makedirs(pth, exist_ok=True) 85 | return pth 86 | 87 | def get_loss(student,head,teacher, x, y): 88 | head_op = head(student.logits(x)) 89 | l_obj = nn.BCEWithLogitsLoss() 90 | clf_loss = l_obj(head_op, y) 91 | y_loss_stud = clf_loss 92 | acc_stud = 0 #torch.mean(torch.sigmoid(head_op) > 0.5 * y).item() 93 | return y_loss_stud, acc_stud 94 | 95 | # Utility function to update lossdict 96 | def update_lossdict(lossdict, update, action='append'): 97 | for k in update.keys(): 98 | if action == 'append': 99 | if k in lossdict: 100 | lossdict[k].append(update[k]) 101 | else: 102 | lossdict[k] = [update[k]] 103 | elif action == 'sum': 104 | if k in lossdict: 105 | lossdict[k] += update[k] 106 | else: 107 | lossdict[k] = update[k] 108 | else: 109 | raise NotImplementedError 110 | return lossdict 111 | 112 | 113 | from sklearn.metrics import roc_auc_score 114 | 115 | # Evaluate student on complete train/test set. 
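# Added note: eval_student returns a dict with the mean BCE loss over the
# dataset and a per-task AUROC list; tasks whose labels contain only one
# class are recorded as NaN rather than dropped, preserving task alignment.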
116 | def eval_student(student, head, dl): 117 | student.eval() 118 | net_loss = 0 119 | correct = 0 120 | y_pred = [] 121 | y_true = [] 122 | l_obj = nn.BCEWithLogitsLoss(reduction='sum') 123 | # clf_loss = l_obj(head_op, y) 124 | with torch.no_grad(): 125 | for data, target in dl: 126 | y_true.append(target.detach().cpu().numpy()) 127 | data, target = data.to(device), target.to(device) 128 | output = head(student.logits(data)) 129 | net_loss += l_obj(output, target).item() # sum up batch loss 130 | y_pred.append(output.detach().cpu().numpy()) 131 | # pred = torch.sigmoid(output) > 0.5 132 | # correct += torch.sum(pred == target).item() 133 | 134 | y_pred = np.concatenate(y_pred, axis=0) 135 | y_true = np.concatenate(y_true, axis=0) 136 | net_loss /= len(dl.dataset) 137 | # acc = 100. * correct / len(dl.dataset * y_pred.shape[1]) 138 | 139 | roc_list = [] 140 | for i in range(y_true.shape[1]): 141 | try: 142 | #AUC is only defined when there is at least one positive data. 143 | if np.sum(y_true[:,i] == 1) > 0 and np.sum(y_true[:,i] == 0) > 0: 144 | roc_list.append(roc_auc_score(y_true[:,i], y_pred[:,i])) 145 | else: 146 | roc_list.append(np.nan) 147 | except ValueError: 148 | roc_list.append(np.nan) 149 | 150 | print('Average loss: {:.4f}, AUC: {:.4f}'.format(net_loss, np.mean(roc_list))) 151 | return {'epoch_loss': net_loss, 'auc' : roc_list} 152 | 153 | import copy 154 | 155 | def do_train_step(student, head, optimizer, x, y): 156 | student.eval() 157 | 158 | x = x.to(device) 159 | y = y.to(device) 160 | 161 | loss, acc = get_loss(student, head, None, x, y) 162 | optimizer.zero_grad() 163 | loss.backward() 164 | optimizer.step() 165 | 166 | return loss.item(), acc 167 | 168 | def train(): 169 | ft_loss_meter = AverageMeter() 170 | ft_acc_meter = AverageMeter() 171 | 172 | DSHandle = ECGDataSetWrapper(args.batch_size) 173 | pretrain_dl, train_dl, val_dl, test_dl, _, NUM_TASKS_FT = DSHandle.get_data_loaders(args, evaluate=True) 174 | 175 | torch.manual_seed(args.runseed) 176 | import random 177 | random.seed(args.runseed) 178 | np.random.seed(args.runseed) 179 | 180 | if args.studentarch == 'resnet18': 181 | student = ecg_simclr_resnet18().to(device) 182 | head = MultitaskHead(256, NUM_TASKS_FT).to(device) 183 | else: 184 | raise NotImplementedError 185 | 186 | if args.checkpoint is None: 187 | print("No checkpoint! 
Training from scratch") 188 | else: 189 | ckpt = torch.load(args.checkpoint) 190 | student.load_state_dict(ckpt['student_sd'], strict=False) 191 | # head.load_state_dict(ckpt['head_sd']) # Not loading head, so commented out 192 | print("Loading student; not doing pretraining") 193 | 194 | if args.transfer_eval: 195 | finetune_optim = torch.optim.Adam(list(head.parameters()) + list(student.parameters()), lr=args.finetune_lr) 196 | else: 197 | finetune_optim = torch.optim.Adam(head.parameters(), lr=args.finetune_lr) 198 | 199 | stud_finetune_train_ld = {'loss' : [], 'acc' : []} 200 | stud_finetune_val_ld = {'loss' : [], 'acc' : []} 201 | stud_finetune_test_ld = {} 202 | 203 | for n in range(args.epochs): 204 | progress_bar = tqdm(train_dl) 205 | for i, (x,y) in enumerate(progress_bar): 206 | progress_bar.set_description('Finetune Epoch ' + str(n)) 207 | ft_train_loss, ft_train_acc = do_train_step(student, head, finetune_optim, x, y) 208 | ft_loss_meter.update(ft_train_loss) 209 | ft_acc_meter.update(ft_train_acc) 210 | progress_bar.set_postfix( 211 | finetune_train_loss='%.4f' % ft_loss_meter.avg , 212 | finetune_train_acc='%.4f' % ft_acc_meter.avg , 213 | ) 214 | # append to lossdict 215 | stud_finetune_train_ld['loss'].append(ft_train_loss) 216 | stud_finetune_train_ld['acc'].append(ft_train_acc) 217 | 218 | 219 | ft_test_ld = eval_student(student,head, test_dl) 220 | stud_finetune_test_ld = update_lossdict(stud_finetune_test_ld, ft_test_ld) 221 | 222 | ft_val_ld = eval_student(student,head, val_dl) 223 | stud_finetune_val_ld = update_lossdict(stud_finetune_val_ld, ft_val_ld) 224 | 225 | ft_train_ld = eval_student(student,head, train_dl) 226 | stud_finetune_train_ld = update_lossdict(stud_finetune_train_ld, ft_train_ld) 227 | # save the logs 228 | tosave = { 229 | 'finetune_train_ld' : stud_finetune_train_ld, 230 | 'finetune_val_ld' : stud_finetune_val_ld, 231 | 'finetune_test_ld' : stud_finetune_test_ld, 232 | } 233 | torch.save(tosave, os.path.join(get_save_path(), 'eval_logs.ckpt')) 234 | # if n % 5 == 0: 235 | # model_saver(n, student, head, pretrain_optim, finetune_optim, save_path) 236 | # print(f"Saved model at epoch {n}") 237 | ft_loss_meter.reset() 238 | ft_acc_meter.reset() 239 | return student, head, finetune_optim 240 | 241 | 242 | res = train() 243 | 244 | 245 | 246 | 247 | 248 | -------------------------------------------------------------------------------- /multitask/finetune_ft30.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from loader import BioDataset 4 | from dataloader import DataLoaderFinetune 5 | from splitters import random_split, species_split 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | 12 | from tqdm import tqdm 13 | import numpy as np 14 | 15 | from model import GNN, GNN_graphpred 16 | from sklearn.metrics import roc_auc_score 17 | 18 | import pandas as pd 19 | 20 | import os 21 | import pickle 22 | 23 | criterion = nn.BCEWithLogitsLoss() 24 | 25 | def get_subset_fold(data): 26 | include_start = args.fold*10 27 | include_end = args.fold*10 + 10 28 | seg = data[:,include_start:include_end] 29 | # seg2 = data[:,exclude_end:] 30 | # return torch.cat([seg1, seg2], dim=1) 31 | return seg 32 | 33 | def train(args, model, device, loader, optimizer): 34 | if args.lineval: 35 | model.eval() 36 | else: 37 | model.train() 38 | 39 | for step, batch in enumerate(loader): 40 | batch = batch.to(device) 41 | pred = model(batch) 42 | # 
HACK to just eval on last 10 43 | y = batch.go_target_downstream.view(-1, 40) 44 | y = get_subset_fold(y).view(pred.shape).to(torch.float64) 45 | 46 | optimizer.zero_grad() 47 | loss = criterion(pred.double(), y) 48 | loss.backward() 49 | 50 | optimizer.step() 51 | 52 | 53 | def eval(args, model, device, loader): 54 | model.eval() 55 | y_true = [] 56 | y_scores = [] 57 | 58 | for step, batch in enumerate(loader): 59 | batch = batch.to(device) 60 | 61 | with torch.no_grad(): 62 | pred = model(batch) 63 | 64 | y = batch.go_target_downstream.view(-1, 40) 65 | y_true.append(get_subset_fold(y).detach().cpu()) 66 | y_scores.append(pred.detach().cpu()) 67 | 68 | y_true = torch.cat(y_true, dim = 0).numpy() 69 | y_scores = torch.cat(y_scores, dim = 0).numpy() 70 | 71 | roc_list = [] 72 | for i in range(y_true.shape[1]): 73 | #AUC is only defined when there is at least one positive data. 74 | if np.sum(y_true[:,i] == 1) > 0 and np.sum(y_true[:,i] == 0) > 0: 75 | roc_list.append(roc_auc_score(y_true[:,i], y_scores[:,i])) 76 | else: 77 | roc_list.append(np.nan) 78 | 79 | return np.array(roc_list) #y_true.shape[1] 80 | 81 | # def main(): 82 | # Training settings 83 | parser = argparse.ArgumentParser(description='PyTorch implementation of pre-training of graph neural networks') 84 | parser.add_argument('--device', type=int, default=0, 85 | help='which gpu to use if any (default: 0)') 86 | parser.add_argument('--batch_size', type=int, default=32, 87 | help='input batch size for training (default: 32)') 88 | parser.add_argument('--epochs', type=int, default=50, 89 | help='number of epochs to train (default: 50)') 90 | parser.add_argument('--lr', type=float, default=0.001, 91 | help='learning rate (default: 0.001)') 92 | parser.add_argument('--decay', type=float, default=0, 93 | help='weight decay (default: 0)') 94 | parser.add_argument('--num_layer', type=int, default=5, 95 | help='number of GNN message passing layers (default: 5).') 96 | parser.add_argument('--emb_dim', type=int, default=300, 97 | help='embedding dimensions (default: 300)') 98 | parser.add_argument('--dropout_ratio', type=float, default=0.5, 99 | help='dropout ratio (default: 0.5)') 100 | parser.add_argument('--graph_pooling', type=str, default="mean", 101 | help='graph level pooling (sum, mean, max, set2set, attention)') 102 | parser.add_argument('--JK', type=str, default="last", 103 | help='how the node features across layers are combined. 
last, sum, max or concat') 104 | parser.add_argument('--model_file', type=str, default = '', help='filename to read the model (if there is any)') 105 | parser.add_argument('--filename', type=str, default = '', help='output filename') 106 | parser.add_argument('--gnn_type', type=str, default="gin") 107 | parser.add_argument('--seed', type=int, default=42, help = "Seed for splitting dataset.") 108 | parser.add_argument('--runseed', type=int, default=0, help = "Seed for running experiments.") 109 | parser.add_argument('--num_workers', type=int, default = 0, help='number of workers for dataset loading') 110 | parser.add_argument('--eval_train', type=int, default = 0, help='evaluating training or not') 111 | parser.add_argument('--split', type=str, default = "species", help='Random or species split') 112 | parser.add_argument('--lineval', action='store_true') 113 | parser.add_argument('--fold', type=int, default=3, help='fold*10:fold*10+10 defines test tasks') 114 | args = parser.parse_args() 115 | 116 | if args.lineval: 117 | args.filename += '_lineval' 118 | 119 | torch.manual_seed(args.runseed) 120 | np.random.seed(args.runseed) 121 | device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu") 122 | if torch.cuda.is_available(): 123 | torch.cuda.manual_seed_all(args.runseed) 124 | 125 | 126 | # Need to set this! 127 | root_supervised = '/path/to/dataset/' 128 | 129 | dataset = BioDataset(root_supervised, data_type='supervised') 130 | 131 | print(dataset) 132 | 133 | if args.split == "random": 134 | print("random splitting") 135 | train_dataset, valid_dataset, test_dataset = random_split(dataset, seed = args.seed) 136 | elif args.split == "species": 137 | trainval_dataset, test_dataset = species_split(dataset) 138 | train_dataset, valid_dataset, _ = random_split(trainval_dataset, seed = args.seed, frac_train=0.85, frac_valid=0.15, frac_test=0) 139 | test_dataset_broad, test_dataset_none, _ = random_split(test_dataset, seed = args.seed, frac_train=0.5, frac_valid=0.5, frac_test=0) 140 | print("species splitting") 141 | else: 142 | raise ValueError("Unknown split name.") 143 | 144 | train_loader = DataLoaderFinetune(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers = args.num_workers) 145 | val_loader = DataLoaderFinetune(valid_dataset, batch_size=10*args.batch_size, shuffle=False, num_workers = args.num_workers) 146 | 147 | if args.split == "random": 148 | test_loader = DataLoaderFinetune(test_dataset, batch_size=10*args.batch_size, shuffle=False, num_workers = args.num_workers) 149 | else: 150 | ### for species splitting 151 | test_easy_loader = DataLoaderFinetune(test_dataset_broad, batch_size=10*args.batch_size, shuffle=False, num_workers = args.num_workers) 152 | test_hard_loader = DataLoaderFinetune(test_dataset_none, batch_size=10*args.batch_size, shuffle=False, num_workers = args.num_workers) 153 | 154 | num_tasks = 10 #len(dataset[0].go_target_downstream) 155 | 156 | print(train_dataset[0]) 157 | 158 | #set up model 159 | model = GNN_graphpred(args.num_layer, args.emb_dim, num_tasks, JK = args.JK, drop_ratio = args.dropout_ratio, graph_pooling = args.graph_pooling, gnn_type = args.gnn_type) 160 | 161 | if not args.model_file == "": 162 | model.from_pretrained(args.model_file) 163 | 164 | model.to(device) 165 | 166 | #set up optimizer 167 | if args.lineval: 168 | optimizer = optim.Adam(model.graph_pred_linear.parameters(), lr=args.lr, weight_decay=args.decay) 169 | else: 170 | optimizer = optim.Adam(model.parameters(), 
lr=args.lr, weight_decay=args.decay) 171 | 172 | train_acc_list = [] 173 | val_acc_list = [] 174 | 175 | ### for random splitting 176 | test_acc_list = [] 177 | 178 | ### for species splitting 179 | test_acc_easy_list = [] 180 | test_acc_hard_list = [] 181 | 182 | 183 | if not args.filename == "": 184 | if os.path.exists(args.filename): 185 | print("removed existing file!!") 186 | os.remove(args.filename) 187 | 188 | for epoch in range(1, args.epochs+1): 189 | print("====epoch " + str(epoch)) 190 | 191 | train(args, model, device, train_loader, optimizer) 192 | 193 | print("====Evaluation") 194 | if args.eval_train: 195 | train_acc = eval(args, model, device, train_loader) 196 | else: 197 | train_acc = 0 198 | print("ommitting training evaluation") 199 | val_acc = eval(args, model, device, val_loader) 200 | 201 | val_acc_list.append(val_acc) 202 | train_acc_list.append(train_acc) 203 | 204 | if args.split == "random": 205 | test_acc = eval(args, model, device, test_loader) 206 | test_acc_list.append(test_acc) 207 | else: 208 | test_acc_easy = eval(args, model, device, test_easy_loader) 209 | test_acc_hard = eval(args, model, device, test_hard_loader) 210 | test_acc_easy_list.append(test_acc_easy) 211 | test_acc_hard_list.append(test_acc_hard) 212 | print(test_acc_easy) 213 | print(test_acc_hard) 214 | 215 | print("") 216 | 217 | os.makedirs(f"result-ft30-fold{args.fold}/finetune_seed" + str(args.runseed), exist_ok=True) 218 | # torch.save(model.state_dict(), 'models/' + args.filename) 219 | 220 | if not args.filename == "": 221 | with open(f"result-ft30-fold{args.fold}/finetune_seed" + str(args.runseed)+ "/" + args.filename, 'wb') as f: 222 | if args.split == "random": 223 | pickle.dump({"train": np.array(train_acc_list), "val": np.array(val_acc_list), "test": np.array(test_acc_list)}, f) 224 | else: 225 | pickle.dump({"train": np.array(train_acc_list), "val": np.array(val_acc_list), "test_easy": np.array(test_acc_easy_list), "test_hard": np.array(test_acc_hard_list)}, f) 226 | 227 | 228 | # if __name__ == "__main__": 229 | # main() 230 | -------------------------------------------------------------------------------- /multitask/loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | import networkx as nx 5 | import pandas as pd 6 | import numpy as np 7 | from torch.utils import data 8 | from torch_geometric.data import Data 9 | from torch_geometric.data import InMemoryDataset 10 | from torch_geometric.data import Batch 11 | from itertools import repeat, product, chain 12 | from collections import Counter, deque 13 | from networkx.algorithms.traversal.breadth_first_search import generic_bfs_edges 14 | 15 | def nx_to_graph_data_obj(g, center_id, allowable_features_downstream=None, 16 | allowable_features_pretrain=None, 17 | node_id_to_go_labels=None): 18 | """ 19 | Converts nx graph of PPI to pytorch geometric Data object. 20 | :param g: nx graph object of ego graph 21 | :param center_id: node id of center node in the ego graph 22 | :param allowable_features_downstream: list of possible go function node 23 | features for the downstream task. The resulting go_target_downstream node 24 | feature vector will be in this order. 25 | :param allowable_features_pretrain: list of possible go function node 26 | features for the pretraining task. The resulting go_target_pretrain node 27 | feature vector will be in this order. 
28 | :param node_id_to_go_labels: dict that maps node id to a list of its 29 | corresponding go labels 30 | :return: pytorch geometric Data object with the following attributes: 31 | edge_attr 32 | edge_index 33 | x 34 | species_id 35 | center_node_idx 36 | go_target_downstream (only if node_id_to_go_labels is not None) 37 | go_target_pretrain (only if node_id_to_go_labels is not None) 38 | """ 39 | n_nodes = g.number_of_nodes() 40 | n_edges = g.number_of_edges() 41 | 42 | # nodes 43 | nx_node_ids = [n_i for n_i in g.nodes()] # contains list of nx node ids 44 | # in a particular ordering. Will be used as a mapping to convert 45 | # between nx node ids and data obj node indices 46 | 47 | x = torch.tensor(np.ones(n_nodes).reshape(-1, 1), dtype=torch.float) 48 | # we don't have any node labels, so set to dummy 1. dim n_nodes x 1 49 | 50 | center_node_idx = nx_node_ids.index(center_id) 51 | center_node_idx = torch.tensor([center_node_idx], dtype=torch.long) 52 | 53 | # edges 54 | edges_list = [] 55 | edge_features_list = [] 56 | for node_1, node_2, attr_dict in g.edges(data=True): 57 | edge_feature = [attr_dict['w1'], attr_dict['w2'], attr_dict['w3'], 58 | attr_dict['w4'], attr_dict['w5'], attr_dict['w6'], 59 | attr_dict['w7'], 0, 0] # last 2 indicate self-loop 60 | # and masking 61 | edge_feature = np.array(edge_feature, dtype=int) 62 | # convert nx node ids to data obj node index 63 | i = nx_node_ids.index(node_1) 64 | j = nx_node_ids.index(node_2) 65 | edges_list.append((i, j)) 66 | edge_features_list.append(edge_feature) 67 | edges_list.append((j, i)) 68 | edge_features_list.append(edge_feature) 69 | 70 | # data.edge_index: Graph connectivity in COO format with shape [2, num_edges] 71 | edge_index = torch.tensor(np.array(edges_list).T, dtype=torch.long) 72 | 73 | # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features] 74 | edge_attr = torch.tensor(np.array(edge_features_list), 75 | dtype=torch.float) 76 | 77 | try: 78 | species_id = int(nx_node_ids[0].split('.')[0]) # nx node id is of the form: 79 | # species_id.protein_id 80 | species_id = torch.tensor([species_id], dtype=torch.long) 81 | except: # occurs when nx node id has no species id info. For the extract 82 | # substructure context pair transform, where we convert a data obj to 83 | # a nx graph obj (which does not have original node id info) 84 | species_id = torch.tensor([0], dtype=torch.long) # dummy species 85 | # id is 0 86 | 87 | # construct data obj 88 | data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr) 89 | data.species_id = species_id 90 | data.center_node_idx = center_node_idx 91 | 92 | if node_id_to_go_labels: # supervised case with go node labels 93 | # Construct a dim n_pretrain_go_classes tensor and a 94 | # n_downstream_go_classes tensor for the center node. 0 is no data 95 | # or negative, 1 is positive. 
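        # Illustrative example (assumed sizes): if allowable_features_downstream
        # has length 40 and the center node's GO labels match the entries at
        # indices {3, 17}, the code below produces a 40-dim multi-hot vector
        # with ones at positions 3 and 17.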
96 | downstream_go_node_feature = [0] * len(allowable_features_downstream) 97 | pretrain_go_node_feature = [0] * len(allowable_features_pretrain) 98 | if center_id in node_id_to_go_labels: 99 | go_labels = node_id_to_go_labels[center_id] 100 | # get indices of allowable_features_downstream that match with elements 101 | # in go_labels 102 | _, node_feature_indices, _ = np.intersect1d( 103 | allowable_features_downstream, go_labels, return_indices=True) 104 | for idx in node_feature_indices: 105 | downstream_go_node_feature[idx] = 1 106 | # get indices of allowable_features_pretrain that match with 107 | # elements in go_labels 108 | _, node_feature_indices, _ = np.intersect1d( 109 | allowable_features_pretrain, go_labels, return_indices=True) 110 | for idx in node_feature_indices: 111 | pretrain_go_node_feature[idx] = 1 112 | data.go_target_downstream = torch.tensor(np.array(downstream_go_node_feature), 113 | dtype=torch.long) 114 | data.go_target_pretrain = torch.tensor(np.array(pretrain_go_node_feature), 115 | dtype=torch.long) 116 | 117 | return data 118 | 119 | def graph_data_obj_to_nx(data): 120 | """ 121 | Converts pytorch geometric Data obj to network x data object. 122 | :param data: pytorch geometric Data object 123 | :return: nx graph object 124 | """ 125 | G = nx.Graph() 126 | 127 | # edges 128 | edge_index = data.edge_index.cpu().numpy() 129 | edge_attr = data.edge_attr.cpu().numpy() 130 | n_edges = edge_index.shape[1] 131 | for j in range(0, n_edges, 2): 132 | begin_idx = int(edge_index[0, j]) 133 | end_idx = int(edge_index[1, j]) 134 | w1, w2, w3, w4, w5, w6, w7, _, _ = edge_attr[j].astype(bool) 135 | if not G.has_edge(begin_idx, end_idx): 136 | G.add_edge(begin_idx, end_idx, w1=w1, w2=w2, w3=w3, w4=w4, w5=w5, 137 | w6=w6, w7=w7) 138 | 139 | # # add center node id information in final nx graph object 140 | # nx.set_node_attributes(G, {data.center_node_idx.item(): True}, 'is_centre') 141 | 142 | return G 143 | 144 | class BioDataset(InMemoryDataset): 145 | def __init__(self, 146 | root, 147 | data_type, 148 | empty=False, 149 | transform=None, 150 | pre_transform=None, 151 | pre_filter=None): 152 | """ 153 | Adapted from qm9.py. Disabled the download functionality 154 | :param root: the data directory that contains a raw and processed dir 155 | :param data_type: either supervised or unsupervised 156 | :param empty: if True, then will not load any data obj. For 157 | initializing empty dataset 158 | :param transform: 159 | :param pre_transform: 160 | :param pre_filter: 161 | """ 162 | self.root = root 163 | self.data_type = data_type 164 | # print(root) 165 | 166 | super(BioDataset, self).__init__(root, transform, pre_transform, pre_filter) 167 | if not empty: 168 | self.data, self.slices = torch.load(self.processed_paths[0]) 169 | 170 | @property 171 | def raw_file_names(self): 172 | #raise NotImplementedError('Data is assumed to be processed') 173 | if self.data_type == 'supervised': # 8 labelled species 174 | file_name_list = ['3702', '6239', '511145', '7227', '9606', '10090', '4932', '7955'] 175 | else: # unsupervised: 8 labelled species, and 42 top unlabelled species by n_nodes. 
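            # Added note: the IDs below are NCBI taxonomy IDs
            # (e.g. 9606 = human, 10090 = mouse).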
176 | file_name_list = ['3702', '6239', '511145', '7227', '9606', '10090', 177 | '4932', '7955', '3694', '39947', '10116', '443255', '9913', '13616', 178 | '3847', '4577', '8364', '9823', '9615', '9544', '9796', '3055', '7159', 179 | '9031', '7739', '395019', '88036', '9685', '9258', '9598', '485913', 180 | '44689', '9593', '7897', '31033', '749414', '59729', '536227', '4081', 181 | '8090', '9601', '749927', '13735', '448385', '457427', '3711', '479433', 182 | '479432', '28377', '9646'] 183 | return file_name_list 184 | 185 | 186 | @property 187 | def processed_file_names(self): 188 | return 'geometric_data_processed.pt' 189 | 190 | def download(self): 191 | raise NotImplementedError('Must indicate valid location of raw data. ' 192 | 'No download allowed') 193 | 194 | def process(self): 195 | raise NotImplementedError('Data is assumed to be processed') 196 | 197 | if __name__ == "__main__": 198 | 199 | 200 | 201 | root_supervised = 'dataset/supervised' 202 | 203 | d_supervised = BioDataset(root_supervised, data_type='supervised') 204 | 205 | print(d_supervised) 206 | 207 | root_unsupervised = 'dataset/unsupervised' 208 | d_unsupervised = BioDataset(root_unsupervised, data_type='unsupervised') 209 | 210 | print(d_unsupervised) 211 | 212 | 213 | -------------------------------------------------------------------------------- /multitask/finetune.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from loader import BioDataset 4 | from dataloader import DataLoaderFinetune 5 | from splitters import random_split, species_split 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | 12 | from tqdm import tqdm 13 | import numpy as np 14 | 15 | from model import GNN, GNN_graphpred 16 | from sklearn.metrics import roc_auc_score 17 | 18 | import pandas as pd 19 | 20 | import os 21 | import pickle 22 | 23 | criterion = nn.BCEWithLogitsLoss() 24 | 25 | def train(args, model, device, loader, optimizer): 26 | if args.lineval: 27 | model.eval() 28 | else: 29 | model.train() 30 | 31 | for step, batch in enumerate(tqdm(loader, desc="Iteration")): 32 | batch = batch.to(device) 33 | pred = model(batch) 34 | y = batch.go_target_downstream.view(pred.shape).to(torch.float64) 35 | 36 | optimizer.zero_grad() 37 | loss = criterion(pred.double(), y) 38 | loss.backward() 39 | 40 | optimizer.step() 41 | 42 | 43 | def eval(args, model, device, loader): 44 | model.eval() 45 | y_true = [] 46 | y_scores = [] 47 | 48 | for step, batch in enumerate(tqdm(loader, desc="Iteration")): 49 | batch = batch.to(device) 50 | 51 | with torch.no_grad(): 52 | pred = model(batch) 53 | 54 | y_true.append(batch.go_target_downstream.view(pred.shape).detach().cpu()) 55 | y_scores.append(pred.detach().cpu()) 56 | 57 | y_true = torch.cat(y_true, dim = 0).numpy() 58 | y_scores = torch.cat(y_scores, dim = 0).numpy() 59 | 60 | roc_list = [] 61 | for i in range(y_true.shape[1]): 62 | #AUC is only defined when there is at least one positive data. 
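        # Added note: single-class tasks are recorded as NaN (not dropped) so
        # the per-task positions in the returned array stay aligned across epochs.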
63 | if np.sum(y_true[:,i] == 1) > 0 and np.sum(y_true[:,i] == 0) > 0: 64 | roc_list.append(roc_auc_score(y_true[:,i], y_scores[:,i])) 65 | else: 66 | roc_list.append(np.nan) 67 | 68 | return np.array(roc_list) #y_true.shape[1] 69 | 70 | def main(): 71 | # Training settings 72 | parser = argparse.ArgumentParser(description='PyTorch implementation of pre-training of graph neural networks') 73 | parser.add_argument('--device', type=int, default=0, 74 | help='which gpu to use if any (default: 0)') 75 | parser.add_argument('--batch_size', type=int, default=32, 76 | help='input batch size for training (default: 32)') 77 | parser.add_argument('--epochs', type=int, default=50, 78 | help='number of epochs to train (default: 50)') 79 | parser.add_argument('--lr', type=float, default=0.001, 80 | help='learning rate (default: 0.001)') 81 | parser.add_argument('--decay', type=float, default=0, 82 | help='weight decay (default: 0)') 83 | parser.add_argument('--num_layer', type=int, default=5, 84 | help='number of GNN message passing layers (default: 5).') 85 | parser.add_argument('--emb_dim', type=int, default=300, 86 | help='embedding dimensions (default: 300)') 87 | parser.add_argument('--dropout_ratio', type=float, default=0.5, 88 | help='dropout ratio (default: 0.5)') 89 | parser.add_argument('--graph_pooling', type=str, default="mean", 90 | help='graph level pooling (sum, mean, max, set2set, attention)') 91 | parser.add_argument('--JK', type=str, default="last", 92 | help='how the node features across layers are combined. last, sum, max or concat') 93 | parser.add_argument('--model_file', type=str, default = '', help='filename to read the model (if there is any)') 94 | parser.add_argument('--filename', type=str, default = '', help='output filename') 95 | parser.add_argument('--gnn_type', type=str, default="gin") 96 | parser.add_argument('--seed', type=int, default=42, help = "Seed for splitting dataset.") 97 | parser.add_argument('--runseed', type=int, default=0, help = "Seed for running experiments.") 98 | parser.add_argument('--num_workers', type=int, default = 0, help='number of workers for dataset loading') 99 | parser.add_argument('--eval_train', type=int, default = 0, help='evaluating training or not') 100 | parser.add_argument('--split', type=str, default = "species", help='Random or species split') 101 | parser.add_argument('--lineval', action='store_true') 102 | args = parser.parse_args() 103 | 104 | if args.lineval: 105 | args.filename += '_lineval' 106 | 107 | torch.manual_seed(args.runseed) 108 | np.random.seed(args.runseed) 109 | device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu") 110 | if torch.cuda.is_available(): 111 | torch.cuda.manual_seed_all(args.runseed) 112 | 113 | print("Checking for completion...") 114 | savepath = "result/finetune_seed" + str(args.runseed)+ "/" + args.filename 115 | if os.path.exists(savepath): 116 | print("Done this one: ", savepath) 117 | return 118 | 119 | 120 | # Need to set this! 
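    # e.g. root_supervised = '/data/bio/supervised/'  (hypothetical path; it
    # should point at the directory containing the dataset's raw/ and
    # processed/ subdirectories)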
121 | root_supervised = '/path/to/dataset/' 122 | 123 | dataset = BioDataset(root_supervised, data_type='supervised') 124 | 125 | print(dataset) 126 | 127 | if args.split == "random": 128 | print("random splitting") 129 | train_dataset, valid_dataset, test_dataset = random_split(dataset, seed = args.seed) 130 | elif args.split == "species": 131 | trainval_dataset, test_dataset = species_split(dataset) 132 | train_dataset, valid_dataset, _ = random_split(trainval_dataset, seed = args.seed, frac_train=0.85, frac_valid=0.15, frac_test=0) 133 | test_dataset_broad, test_dataset_none, _ = random_split(test_dataset, seed = args.seed, frac_train=0.5, frac_valid=0.5, frac_test=0) 134 | print("species splitting") 135 | else: 136 | raise ValueError("Unknown split name.") 137 | 138 | train_loader = DataLoaderFinetune(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers = args.num_workers) 139 | val_loader = DataLoaderFinetune(valid_dataset, batch_size=10*args.batch_size, shuffle=False, num_workers = args.num_workers) 140 | 141 | if args.split == "random": 142 | test_loader = DataLoaderFinetune(test_dataset, batch_size=10*args.batch_size, shuffle=False, num_workers = args.num_workers) 143 | else: 144 | ### for species splitting 145 | test_easy_loader = DataLoaderFinetune(test_dataset_broad, batch_size=10*args.batch_size, shuffle=False, num_workers = args.num_workers) 146 | test_hard_loader = DataLoaderFinetune(test_dataset_none, batch_size=10*args.batch_size, shuffle=False, num_workers = args.num_workers) 147 | 148 | num_tasks = len(dataset[0].go_target_downstream) 149 | 150 | print(train_dataset[0]) 151 | 152 | #set up model 153 | model = GNN_graphpred(args.num_layer, args.emb_dim, num_tasks, JK = args.JK, drop_ratio = args.dropout_ratio, graph_pooling = args.graph_pooling, gnn_type = args.gnn_type) 154 | 155 | if not args.model_file == "": 156 | model.from_pretrained(args.model_file) 157 | 158 | model.to(device) 159 | 160 | #set up optimizer 161 | if args.lineval: 162 | optimizer = optim.Adam(model.graph_pred_linear.parameters(), lr=args.lr, weight_decay=args.decay) 163 | else: 164 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay) 165 | 166 | train_acc_list = [] 167 | val_acc_list = [] 168 | 169 | ### for random splitting 170 | test_acc_list = [] 171 | 172 | ### for species splitting 173 | test_acc_easy_list = [] 174 | test_acc_hard_list = [] 175 | 176 | 177 | if not args.filename == "": 178 | if os.path.exists(args.filename): 179 | print("removed existing file!!") 180 | os.remove(args.filename) 181 | 182 | for epoch in range(1, args.epochs+1): 183 | print("====epoch " + str(epoch)) 184 | 185 | train(args, model, device, train_loader, optimizer) 186 | 187 | print("====Evaluation") 188 | if args.eval_train: 189 | train_acc = eval(args, model, device, train_loader) 190 | else: 191 | train_acc = 0 192 | print("ommitting training evaluation") 193 | val_acc = eval(args, model, device, val_loader) 194 | 195 | val_acc_list.append(val_acc) 196 | train_acc_list.append(train_acc) 197 | 198 | if args.split == "random": 199 | test_acc = eval(args, model, device, test_loader) 200 | test_acc_list.append(test_acc) 201 | else: 202 | test_acc_easy = eval(args, model, device, test_easy_loader) 203 | test_acc_hard = eval(args, model, device, test_hard_loader) 204 | test_acc_easy_list.append(test_acc_easy) 205 | test_acc_hard_list.append(test_acc_hard) 206 | print(test_acc_easy) 207 | print(test_acc_hard) 208 | 209 | print("") 210 | 211 | os.makedirs("result/finetune_seed" 
+ str(args.runseed), exist_ok=True) 212 | torch.save(model.state_dict(), 'models/' + args.filename) 213 | 214 | if not args.filename == "": 215 | with open("result/finetune_seed" + str(args.runseed)+ "/" + args.filename, 'wb') as f: 216 | if args.split == "random": 217 | pickle.dump({"train": np.array(train_acc_list), "val": np.array(val_acc_list), "test": np.array(test_acc_list)}, f) 218 | else: 219 | pickle.dump({"train": np.array(train_acc_list), "val": np.array(val_acc_list), "test_easy": np.array(test_acc_easy_list), "test_hard": np.array(test_acc_hard_list)}, f) 220 | 221 | 222 | if __name__ == "__main__": 223 | main() 224 | -------------------------------------------------------------------------------- /multitask/util.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import numpy as np 4 | import networkx as nx 5 | from loader import BioDataset, graph_data_obj_to_nx, nx_to_graph_data_obj 6 | 7 | def combine_dataset(dataset1, dataset2): 8 | data_list = [data for data in dataset1] 9 | data_list.extend([data for data in dataset2]) 10 | # NEED TO SET THIS 11 | root_supervised = '/path/to/dataset/' 12 | dataset = BioDataset(root_supervised, data_type='supervised', empty = True) 13 | 14 | dataset.data, dataset.slices = dataset.collate(data_list) 15 | return dataset 16 | 17 | class NegativeEdge: 18 | def __init__(self): 19 | """ 20 | Randomly sample negative edges 21 | """ 22 | pass 23 | 24 | def __call__(self, data): 25 | num_nodes = data.num_nodes 26 | num_edges = data.num_edges 27 | 28 | edge_set = set([str(data.edge_index[0,i].cpu().item()) + "," + str(data.edge_index[1,i].cpu().item()) for i in range(data.edge_index.shape[1])]) 29 | 30 | redandunt_sample = torch.randint(0, num_nodes, (2,5*num_edges)) 31 | sampled_ind = [] 32 | sampled_edge_set = set([]) 33 | for i in range(5*num_edges): 34 | node1 = redandunt_sample[0,i].cpu().item() 35 | node2 = redandunt_sample[1,i].cpu().item() 36 | edge_str = str(node1) + "," + str(node2) 37 | if not edge_str in edge_set and not edge_str in sampled_edge_set and not node1 == node2: 38 | sampled_edge_set.add(edge_str) 39 | sampled_ind.append(i) 40 | if len(sampled_ind) == num_edges/2: 41 | break 42 | 43 | data.negative_edge_index = redandunt_sample[:,sampled_ind] 44 | 45 | return data 46 | 47 | class MaskEdge: 48 | def __init__(self, mask_rate): 49 | """ 50 | Assume edge_attr is of the form: 51 | [w1, w2, w3, w4, w5, w6, w7, self_loop, mask] 52 | :param mask_rate: % of edges to be masked 53 | """ 54 | self.mask_rate = mask_rate 55 | 56 | def __call__(self, data, masked_edge_indices=None): 57 | """ 58 | 59 | :param data: pytorch geometric data object. Assume that the edge 60 | ordering is the default pytorch geometric ordering, where the two 61 | directions of a single edge occur in pairs. 62 | Eg. data.edge_index = tensor([[0, 1, 1, 2, 2, 3], 63 | [1, 0, 2, 1, 3, 2]]) 64 | :param masked_edge_indices: If None, then randomly sample num_edges * mask_rate + 1 65 | number of edge indices. Otherwise should correspond to the 1st 66 | direction of an edge pair. 
i.e., all indices should be even numbers
68 | :return: None, creates new attributes in the original data object:
69 | data.mask_edge_idx: indices of masked edges
70 | data.mask_edge_labels: corresponding ground truth edge feature for
71 | each masked edge
72 | data.edge_attr: modified in place: the edge features (
73 | both directions) that correspond to the masked edges have the masked
74 | edge feature
75 | """
76 | if masked_edge_indices is None:
77 | # sample x distinct edges to be masked, based on mask rate. But
78 | # will sample at least 1 edge
79 | num_edges = int(data.edge_index.size()[1] / 2) # num unique edges
80 | sample_size = int(num_edges * self.mask_rate + 1)
81 | # during sampling, we only pick the 1st direction of a particular
82 | # edge pair
83 | masked_edge_indices = [2 * i for i in
84 | random.sample(range(num_edges), sample_size)]
85 | 
86 | data.masked_edge_idx = torch.tensor(np.array(masked_edge_indices))
87 | 
88 | # create ground truth edge features for the edges that correspond to
89 | # the masked indices
90 | mask_edge_labels_list = []
91 | for idx in masked_edge_indices:
92 | mask_edge_labels_list.append(data.edge_attr[idx].view(1, -1))
93 | data.mask_edge_label = torch.cat(mask_edge_labels_list, dim=0)
94 | 
95 | # create new masked edge_attr, where both directions of the masked
96 | # edges have the masked edge type. For message passing in gcn
97 | 
98 | # append the 2nd direction of the masked edges
99 | all_masked_edge_indices = masked_edge_indices + [i + 1 for i in
100 | masked_edge_indices]
101 | for idx in all_masked_edge_indices:
102 | data.edge_attr[idx] = torch.tensor(np.array([0, 0, 0, 0, 0,
103 | 0, 0, 0, 1]),
104 | dtype=torch.float)
105 | 
106 | return data
107 | # # debugging
108 | # print(masked_edge_indices)
109 | # print(all_masked_edge_indices)
110 | 
111 | 
112 | def reset_idxes(G):
113 | """
114 | Resets node indices such that they are numbered from 0 to num_nodes - 1
115 | :param G: networkx graph
116 | :return: copy of G with relabelled node indices, mapping
117 | """
118 | mapping = {}
119 | for new_idx, old_idx in enumerate(G.nodes()):
120 | mapping[old_idx] = new_idx
121 | new_G = nx.relabel_nodes(G, mapping, copy=True)
122 | return new_G, mapping
123 | 
124 | 
125 | class ExtractSubstructureContextPair:
126 | def __init__(self, l1, center=True):
127 | """
128 | Randomly selects a node from the data object, and adds attributes
129 | that contain the substructure that corresponds to the whole graph, and the
130 | context substructure that corresponds to
131 | the subgraph that is between l1 and the edge of the graph. If
132 | center=True, then will select the center node as the root node.
133 | :param center: if True, select the center node as the root node;
134 | otherwise randomly select a node
135 | :param l1: number of hops from the root node that bounds the context region
136 | """
137 | self.center = center
138 | self.l1 = l1
139 | 
140 | 
141 | if self.l1 == 0:
142 | self.l1 = -1
143 | 
144 | def __call__(self, data, root_idx=None):
145 | """
146 | 
147 | :param data: pytorch geometric data object
148 | :param root_idx: Usually None. Otherwise directly sets node idx of
149 | root (
150 | for debugging only)
151 | :return: None.
Creates new attributes in original data object:
152 | data.center_substruct_idx
153 | data.x_substruct
154 | data.edge_attr_substruct
155 | data.edge_index_substruct
156 | data.x_context
157 | data.edge_attr_context
158 | data.edge_index_context
159 | data.overlap_context_substruct_idx
160 | """
161 | num_atoms = data.x.size()[0]
162 | G = graph_data_obj_to_nx(data)
163 | 
164 | if root_idx is None:
165 | if self.center:
166 | root_idx = data.center_node_idx.item()
167 | else:
168 | root_idx = random.sample(range(num_atoms), 1)[0]
169 | 
170 | # in the PPI case, the subgraph is the entire PPI graph
171 | data.x_substruct = data.x
172 | data.edge_attr_substruct = data.edge_attr
173 | data.edge_index_substruct = data.edge_index
174 | data.center_substruct_idx = data.center_node_idx
175 | 
176 | 
177 | # Get context that is between l1 and the max diameter of the PPI graph
178 | l1_node_idxes = nx.single_source_shortest_path_length(G, root_idx,
179 | self.l1).keys()
180 | # l2_node_idxes = nx.single_source_shortest_path_length(G, root_idx,
181 | # self.l2).keys()
182 | l2_node_idxes = range(num_atoms)
183 | context_node_idxes = set(l1_node_idxes).symmetric_difference(
184 | set(l2_node_idxes))
185 | if len(context_node_idxes) > 0:
186 | context_G = G.subgraph(context_node_idxes)
187 | context_G, context_node_map = reset_idxes(context_G) # need to
188 | # reset node idx to 0 -> num_nodes - 1, other data obj does not
189 | # make sense
190 | context_data = nx_to_graph_data_obj(context_G, 0) # use a dummy
191 | # center node idx
192 | data.x_context = context_data.x
193 | data.edge_attr_context = context_data.edge_attr
194 | data.edge_index_context = context_data.edge_index
195 | 
196 | # Get indices of overlapping nodes between substruct and context,
197 | # WRT context ordering
198 | context_substruct_overlap_idxes = list(context_node_idxes)
199 | if len(context_substruct_overlap_idxes) > 0:
200 | context_substruct_overlap_idxes_reorder = [context_node_map[old_idx]
201 | for old_idx in context_substruct_overlap_idxes]
202 | data.overlap_context_substruct_idx = \
203 | torch.tensor(context_substruct_overlap_idxes_reorder)
204 | 
205 | return data
206 | 
207 | def __repr__(self):
208 | return '{}(l1={}, center={})'.format(self.__class__.__name__,
209 | self.l1, self.center)
210 | 
211 | if __name__ == "__main__":
212 | root_supervised = 'dataset/supervised'
213 | thresholds = [266, 1, 777, 652, 300, 900, 670]
214 | d_supervised = BioDataset(root_supervised, thresholds,
215 | max_search_depth=2, max_n_neighbors_sampled=10,
216 | n_subgraphs=1e12, data_type='supervised')
217 | # test ExtractSubstructureContextPair for PPI networks
218 | data = d_supervised[0]
219 | sub_context_transform = ExtractSubstructureContextPair(1, center=True)
220 | sub_context_transform(data)
221 | 
-------------------------------------------------------------------------------- /multitask/batch.py: --------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.data import Data, Batch
3 | 
4 | class BatchFinetune(Data):
5 | r"""A plain old python object modeling a batch of graphs as one big
6 | (disconnected) graph. With :class:`torch_geometric.data.Data` being the
7 | base class, all its methods can also be used here.
8 | In addition, single graphs can be reconstructed via the assignment vector
9 | :obj:`batch`, which maps each node to its respective graph identifier.
10 | """ 11 | 12 | def __init__(self, batch=None, **kwargs): 13 | super(BatchMasking, self).__init__(**kwargs) 14 | self.batch = batch 15 | 16 | @staticmethod 17 | def from_data_list(data_list): 18 | r"""Constructs a batch object from a python list holding 19 | :class:`torch_geometric.data.Data` objects. 20 | The assignment vector :obj:`batch` is created on the fly.""" 21 | keys = [set(data.keys) for data in data_list] 22 | keys = list(set.union(*keys)) 23 | assert 'batch' not in keys 24 | 25 | batch = BatchMasking() 26 | 27 | for key in keys: 28 | batch[key] = [] 29 | batch.batch = [] 30 | 31 | cumsum_node = 0 32 | cumsum_edge = 0 33 | 34 | for i, data in enumerate(data_list): 35 | num_nodes = data.num_nodes 36 | batch.batch.append(torch.full((num_nodes, ), i, dtype=torch.long)) 37 | for key in data.keys: 38 | item = data[key] 39 | if key in ['edge_index', 'center_node_idx']: 40 | item = item + cumsum_node 41 | batch[key].append(item) 42 | 43 | cumsum_node += num_nodes 44 | cumsum_edge += data.edge_index.shape[1] 45 | 46 | for key in keys: 47 | batch[key] = torch.cat( 48 | batch[key], dim=data_list[0].__cat_dim__(key, batch[key][0])) 49 | batch.batch = torch.cat(batch.batch, dim=-1) 50 | return batch.contiguous() 51 | 52 | @property 53 | def num_graphs(self): 54 | """Returns the number of graphs in the batch.""" 55 | return self.batch[-1].item() + 1 56 | 57 | 58 | class BatchMasking(Data): 59 | r"""A plain old python object modeling a batch of graphs as one big 60 | (dicconnected) graph. With :class:`torch_geometric.data.Data` being the 61 | base class, all its methods can also be used here. 62 | In addition, single graphs can be reconstructed via the assignment vector 63 | :obj:`batch`, which maps each node to its respective graph identifier. 64 | """ 65 | 66 | def __init__(self, batch=None, **kwargs): 67 | super(BatchMasking, self).__init__(**kwargs) 68 | self.batch = batch 69 | 70 | @staticmethod 71 | def from_data_list(data_list): 72 | r"""Constructs a batch object from a python list holding 73 | :class:`torch_geometric.data.Data` objects. 74 | The assignment vector :obj:`batch` is created on the fly.""" 75 | keys = [set(data.keys) for data in data_list] 76 | keys = list(set.union(*keys)) 77 | assert 'batch' not in keys 78 | 79 | batch = BatchMasking() 80 | 81 | for key in keys: 82 | batch[key] = [] 83 | batch.batch = [] 84 | 85 | cumsum_node = 0 86 | cumsum_edge = 0 87 | 88 | for i, data in enumerate(data_list): 89 | num_nodes = data.num_nodes 90 | batch.batch.append(torch.full((num_nodes, ), i, dtype=torch.long)) 91 | for key in data.keys: 92 | item = data[key] 93 | if key in ['edge_index']: 94 | item = item + cumsum_node 95 | elif key == 'masked_edge_idx': 96 | item = item + cumsum_edge 97 | batch[key].append(item) 98 | 99 | cumsum_node += num_nodes 100 | cumsum_edge += data.edge_index.shape[1] 101 | 102 | for key in keys: 103 | batch[key] = torch.cat( 104 | batch[key], dim=data_list[0].__cat_dim__(key, batch[key][0])) 105 | batch.batch = torch.cat(batch.batch, dim=-1) 106 | return batch.contiguous() 107 | 108 | def cumsum(self, key, item): 109 | r"""If :obj:`True`, the attribute :obj:`key` with content :obj:`item` 110 | should be added up cumulatively before concatenated together. 111 | .. note:: 112 | This method is for internal use only, and should only be overridden 113 | if the batch concatenation process is corrupted for a specific data 114 | attribute. 
115 | """ 116 | return key in ['edge_index', 'face', 'masked_atom_indices', 'connected_edge_indices'] 117 | 118 | @property 119 | def num_graphs(self): 120 | """Returns the number of graphs in the batch.""" 121 | return self.batch[-1].item() + 1 122 | 123 | class BatchAE(Data): 124 | r"""A plain old python object modeling a batch of graphs as one big 125 | (dicconnected) graph. With :class:`torch_geometric.data.Data` being the 126 | base class, all its methods can also be used here. 127 | In addition, single graphs can be reconstructed via the assignment vector 128 | :obj:`batch`, which maps each node to its respective graph identifier. 129 | """ 130 | 131 | def __init__(self, batch=None, **kwargs): 132 | super(BatchAE, self).__init__(**kwargs) 133 | self.batch = batch 134 | 135 | @staticmethod 136 | def from_data_list(data_list): 137 | r"""Constructs a batch object from a python list holding 138 | :class:`torch_geometric.data.Data` objects. 139 | The assignment vector :obj:`batch` is created on the fly.""" 140 | keys = [set(data.keys) for data in data_list] 141 | keys = list(set.union(*keys)) 142 | assert 'batch' not in keys 143 | 144 | batch = BatchAE() 145 | 146 | for key in keys: 147 | batch[key] = [] 148 | batch.batch = [] 149 | 150 | cumsum_node = 0 151 | 152 | for i, data in enumerate(data_list): 153 | num_nodes = data.num_nodes 154 | batch.batch.append(torch.full((num_nodes, ), i, dtype=torch.long)) 155 | for key in data.keys: 156 | item = data[key] 157 | if key in ['edge_index', 'negative_edge_index']: 158 | item = item + cumsum_node 159 | batch[key].append(item) 160 | 161 | cumsum_node += num_nodes 162 | 163 | for key in keys: 164 | batch[key] = torch.cat( 165 | batch[key], dim=batch.cat_dim(key)) 166 | batch.batch = torch.cat(batch.batch, dim=-1) 167 | return batch.contiguous() 168 | 169 | @property 170 | def num_graphs(self): 171 | """Returns the number of graphs in the batch.""" 172 | return self.batch[-1].item() + 1 173 | 174 | def cat_dim(self, key): 175 | return -1 if key in ["edge_index", "negative_edge_index"] else 0 176 | 177 | 178 | 179 | class BatchSubstructContext(Data): 180 | r"""A plain old python object modeling a batch of graphs as one big 181 | (dicconnected) graph. With :class:`torch_geometric.data.Data` being the 182 | base class, all its methods can also be used here. 183 | In addition, single graphs can be reconstructed via the assignment vector 184 | :obj:`batch`, which maps each node to its respective graph identifier. 185 | """ 186 | 187 | """ 188 | Specialized batching for substructure context pair! 189 | """ 190 | 191 | def __init__(self, batch=None, **kwargs): 192 | super(BatchSubstructContext, self).__init__(**kwargs) 193 | self.batch = batch 194 | 195 | @staticmethod 196 | def from_data_list(data_list): 197 | r"""Constructs a batch object from a python list holding 198 | :class:`torch_geometric.data.Data` objects. 
199 | The assignment vector :obj:`batch` is created on the fly.""" 200 | #keys = [set(data.keys) for data in data_list] 201 | #keys = list(set.union(*keys)) 202 | #assert 'batch' not in keys 203 | 204 | batch = BatchSubstructContext() 205 | keys = ["center_substruct_idx", "edge_attr_substruct", "edge_index_substruct", "x_substruct", "overlap_context_substruct_idx", "edge_attr_context", "edge_index_context", "x_context"] 206 | 207 | for key in keys: 208 | #print(key) 209 | batch[key] = [] 210 | 211 | #batch.batch = [] 212 | #used for pooling the context 213 | batch.batch_overlapped_context = [] 214 | batch.overlapped_context_size = [] 215 | 216 | cumsum_main = 0 217 | cumsum_substruct = 0 218 | cumsum_context = 0 219 | 220 | i = 0 221 | 222 | for data in data_list: 223 | #If there is no context, just skip!! 224 | if hasattr(data, "x_context"): 225 | num_nodes = data.num_nodes 226 | num_nodes_substruct = len(data.x_substruct) 227 | num_nodes_context = len(data.x_context) 228 | 229 | #batch.batch.append(torch.full((num_nodes, ), i, dtype=torch.long)) 230 | batch.batch_overlapped_context.append(torch.full((len(data.overlap_context_substruct_idx), ), i, dtype=torch.long)) 231 | batch.overlapped_context_size.append(len(data.overlap_context_substruct_idx)) 232 | 233 | ###batching for the main graph 234 | #for key in data.keys: 235 | # if not "context" in key and not "substruct" in key: 236 | # item = data[key] 237 | # item = item + cumsum_main if batch.cumsum(key, item) else item 238 | # batch[key].append(item) 239 | 240 | ###batching for the substructure graph 241 | for key in ["center_substruct_idx", "edge_attr_substruct", "edge_index_substruct", "x_substruct"]: 242 | item = data[key] 243 | item = item + cumsum_substruct if batch.cumsum(key, item) else item 244 | batch[key].append(item) 245 | 246 | 247 | ###batching for the context graph 248 | for key in ["overlap_context_substruct_idx", "edge_attr_context", "edge_index_context", "x_context"]: 249 | item = data[key] 250 | item = item + cumsum_context if batch.cumsum(key, item) else item 251 | batch[key].append(item) 252 | 253 | cumsum_main += num_nodes 254 | cumsum_substruct += num_nodes_substruct 255 | cumsum_context += num_nodes_context 256 | i += 1 257 | 258 | for key in keys: 259 | batch[key] = torch.cat( 260 | batch[key], dim=batch.cat_dim(key)) 261 | #batch.batch = torch.cat(batch.batch, dim=-1) 262 | batch.batch_overlapped_context = torch.cat(batch.batch_overlapped_context, dim=-1) 263 | batch.overlapped_context_size = torch.LongTensor(batch.overlapped_context_size) 264 | 265 | return batch.contiguous() 266 | 267 | def cat_dim(self, key): 268 | return -1 if key in ["edge_index", "edge_index_substruct", "edge_index_context"] else 0 269 | 270 | def cumsum(self, key, item): 271 | r"""If :obj:`True`, the attribute :obj:`key` with content :obj:`item` 272 | should be added up cumulatively before concatenated together. 273 | .. note:: 274 | This method is for internal use only, and should only be overridden 275 | if the batch concatenation process is corrupted for a specific data 276 | attribute. 
277 | """ 278 | return key in ["edge_index", "edge_index_substruct", "edge_index_context", "overlap_context_substruct_idx", "center_substruct_idx"] 279 | 280 | @property 281 | def num_graphs(self): 282 | """Returns the number of graphs in the batch.""" 283 | return self.batch[-1].item() + 1 284 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /simclr/simclr_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from torch.distributions.normal import Normal 6 | 7 | 8 | 9 | def conv5x1(in_channels, out_channels, stride=1): 10 | return nn.Conv1d(in_channels, out_channels, kernel_size=5, stride=stride, padding=2) 11 | 12 | def conv9x1(in_channels, out_channels, stride=1): 13 | return nn.Conv1d(in_channels, out_channels, kernel_size=9, stride=stride, padding=4) 14 | 15 | def conv15x1(in_channels, out_channels, stride=1): 16 | return nn.Conv1d(in_channels, out_channels, kernel_size=15, stride=stride, padding=7) 17 | 18 | class BasicBlock(nn.Module): 19 | expansion = 1 20 | 21 | def __init__(self, in_channels, out_channels, stride=1, downsample=None): 22 | super(BasicBlock, self).__init__() 23 | norm_layer = nn.BatchNorm1d 24 | 25 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 26 | self.conv1 = conv15x1(in_channels, out_channels, stride) 27 | self.bn1 = norm_layer(out_channels) 28 | self.relu = nn.ReLU(inplace=True) 29 | self.conv2 = conv15x1(out_channels, out_channels) 30 | self.bn2 = norm_layer(out_channels) 31 | self.downsample = downsample 32 | self.stride = stride 33 | 34 | def forward(self, x): 35 | identity = x 36 | 37 | out = self.conv1(x) 38 | out = self.bn1(out) 39 | out = self.relu(out) 40 | # out = self.dropout(out) 41 | out = self.conv2(out) 42 | out = self.bn2(out) 43 | 44 | if self.downsample is not None: 45 | identity = self.downsample(x) 46 | 47 | 48 | out += identity 49 | out = self.relu(out) 50 | 51 | return out 52 | 53 | class ResNet(nn.Module): 54 | 55 | def __init__(self, block, layers, num_outputs=5, zero_init_residual=False, 56 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 57 | norm_layer=nn.BatchNorm1d): 58 | 59 | super(ResNet, self).__init__() 60 | 61 | self.num_outputs = num_outputs 62 | 63 | self._norm_layer = norm_layer 64 | 65 | self.inplanes = 32 66 | 67 | self.conv1 = nn.Conv1d(12, self.inplanes, kernel_size=15, stride=2, padding=3, 68 | bias=False) 69 | self.bn1 = norm_layer(self.inplanes) 70 | self.relu = nn.ReLU(inplace=True) 71 | self.layer1 = self._make_layer(block, 32, layers[0]) 72 | self.layer2 = self._make_layer(block, 64, layers[1], stride=2) 73 | self.layer3 = self._make_layer(block, 128, layers[2], stride=2) 74 | self.layer4 = self._make_layer(block, 256, layers[3], stride=2) 75 | self.avgpool = nn.AdaptiveAvgPool1d(1) 76 | 77 | self.proj = nn.Sequential(nn.Linear(256, 256, bias=False), 78 | nn.ReLU(inplace=True), nn.Linear(256, 128, bias=True)) 79 | 80 | for m in self.modules(): 81 | if isinstance(m, nn.Conv1d): 82 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 83 | elif isinstance(m, (nn.BatchNorm1d, nn.GroupNorm)): 84 | nn.init.constant_(m.weight, 1) 85 | nn.init.constant_(m.bias, 0) 86 | 87 | # Zero-initialize the last BN in each residual branch, 88 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
89 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
90 | if zero_init_residual:
91 | for m in self.modules():
92 | # only BasicBlock is defined in this file, so the torchvision
93 | # Bottleneck branch (an undefined name here) is dropped
94 | if isinstance(m, BasicBlock):
95 | nn.init.constant_(m.bn2.weight, 0)
96 | 
97 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
98 | norm_layer = self._norm_layer
99 | downsample = None
100 | if stride != 1 or self.inplanes != planes * block.expansion:
101 | # print("Got to downsample place")
102 | downsample = nn.Sequential(
103 | nn.Conv1d(self.inplanes, planes, kernel_size=1, stride=stride, bias=False),
104 | norm_layer(planes),
105 | )
106 | 
107 | layers = []
108 | layers.append(block(self.inplanes, planes, stride, downsample))
109 | self.inplanes = planes * block.expansion
110 | for _ in range(1, blocks):
111 | layers.append(block(self.inplanes, planes))
112 | 
113 | return nn.Sequential(*layers)
114 | 
115 | def _forward_impl(self, x):
116 | x = self.conv1(x)
117 | x = self.bn1(x)
118 | x = self.relu(x)
119 | x = self.layer1(x)
120 | x = self.layer2(x)
121 | x = self.layer3(x)
122 | x = self.layer4(x)
123 | x = self.avgpool(x)
124 | x = torch.flatten(x, 1)
125 | return x
126 | 
127 | def logits(self,x):
128 | BS, L, C = x.shape
129 | x = x.transpose(1,2)
130 | return self._forward_impl(x)
131 | 
132 | def forward(self, x):
133 | BS, L, C = x.shape
134 | x = x.transpose(1,2)
135 | feature = self._forward_impl(x)
136 | out = self.proj(feature)
137 | return F.normalize(feature, dim=-1), F.normalize(out, dim=-1)
138 | 
139 | 
140 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): # arch/pretrained/progress are unused; kept for a torchvision-style API
141 | model = ResNet(block, layers, **kwargs)
142 | return model
143 | 
144 | 
145 | def ecg_simclr_resnet18(pretrained=False, progress=True, **kwargs):
146 | r"""1D ResNet-18-style encoder for 12-lead ECG, adapted from
147 | `"Deep Residual Learning for Image Recognition" `_
148 | 
149 | Args:
150 | pretrained (bool): unused here; kept for a torchvision-style interface
151 | progress (bool): unused here; kept for a torchvision-style interface
152 | """
153 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
154 | **kwargs)
155 | 
156 | 
157 | 
158 | 
159 | # (Normal is already imported at the top of this file)
160 | from scipy.ndimage import gaussian_filter1d
161 | 
162 | class MultitaskHead(torch.nn.Module):
163 | def __init__(self, feats, num_classes=5):
164 | super(MultitaskHead, self).__init__()
165 | self.fc_pi = torch.nn.Linear(feats, num_classes)
166 | 
167 | def forward(self, x):
168 | out_pi = self.fc_pi(x)
169 | return out_pi
170 | 
171 | 
172 | #### Adapted from https://github.com/voxelmorph/voxelmorph
173 | class SpatialTransformer(nn.Module):
174 | """
175 | N-D Spatial Transformer
176 | """
177 | 
178 | def __init__(self, size, mode='bilinear'):
179 | super().__init__()
180 | 
181 | self.mode = mode
182 | 
183 | # create sampling grid
184 | vectors = [torch.arange(0, s) for s in size]
185 | grids = torch.meshgrid(vectors)
186 | grid = torch.stack(grids)
187 | grid = torch.unsqueeze(grid, 0)
188 | grid = grid.type(torch.FloatTensor)
189 | 
190 | # registering the grid as a buffer cleanly moves it to the GPU, but it also
191 | # adds it to the state dict. this is annoying since everything in the state dict
192 | # is included when saving weights to disk, so the model files are way bigger
193 | # than they need to be. so far, there does not appear to be an elegant solution.
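# (newer PyTorch versions do offer register_buffer(name, tensor, persistent=False),
# which keeps a buffer out of the state dict; noted as an option, not used here)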
194 | # see: https://discuss.pytorch.org/t/how-to-register-buffer-without-polluting-state-dict 195 | self.register_buffer('grid', grid) 196 | # print("grid shape", grid.shape) 197 | 198 | def forward(self, src, flow): 199 | # new locations 200 | new_locs = self.grid + flow 201 | shape = flow.shape[2:] 202 | 203 | 204 | # need to normalize grid values to [-1, 1] for resampler 205 | for i in range(len(shape)): 206 | new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5) 207 | 208 | if len(src.shape) == 3: 209 | src = src.unsqueeze(-1).repeat(1,1,1,2) 210 | new_locs = new_locs.unsqueeze(-1).repeat(1,1,1,2) 211 | 212 | # move channels dim to last position 213 | # also not sure why, but the channels need to be reversed 214 | if len(shape) == 2: 215 | new_locs = new_locs.permute(0, 2, 3, 1) 216 | new_locs = new_locs[..., [1, 0]] 217 | elif len(shape) == 3: 218 | new_locs = new_locs.permute(0, 2, 3, 4, 1) 219 | new_locs = new_locs[..., [2, 1, 0]] 220 | 221 | samp = F.grid_sample(src, new_locs, align_corners=True, mode=self.mode) 222 | return samp.squeeze(2) #samp[:,:,:,0] 223 | 224 | 225 | class VecInt(nn.Module): 226 | """ 227 | Integrates a vector field via scaling and squaring. 228 | """ 229 | 230 | def __init__(self, inshape, nsteps): 231 | super().__init__() 232 | 233 | assert nsteps >= 0, 'nsteps should be >= 0, found: %d' % nsteps 234 | self.nsteps = nsteps 235 | self.scale = 1.0 / (2 ** self.nsteps) 236 | self.transformer = SpatialTransformer(inshape) 237 | 238 | def forward(self, vec): 239 | vec = vec * self.scale 240 | # print("vecshape", vec.shape) 241 | for _ in range(self.nsteps): 242 | vec = vec + self.transformer(vec, vec) 243 | return vec 244 | 245 | 246 | 247 | class ResizeTransformTime(nn.Module): 248 | """ 249 | Resize a transform, which involves resizing the vector field *and* rescaling it. 
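Concretely: resampling the field to sf times its length also multiplies the
displacement values by sf, so vectors keep pointing at the same relative
positions (this is what the factor * x step in forward() below does).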
250 | """ 251 | 252 | def __init__(self, sf, ndims): 253 | super().__init__() 254 | self.sf = sf 255 | self.mode = 'linear' 256 | 257 | def forward(self, x): 258 | factor = self.sf 259 | if factor < 1: 260 | x = F.interpolate(x, align_corners=False, scale_factor=factor, mode=self.mode, recompute_scale_factor=False) 261 | x = factor * x 262 | elif factor > 1: 263 | x = factor * x 264 | x = F.interpolate(x, align_corners=False, scale_factor=factor, mode=self.mode, recompute_scale_factor=False) 265 | return x 266 | 267 | 268 | 269 | from scipy.ndimage import gaussian_filter1d 270 | 271 | 272 | # per example magnitude 273 | class RandWarpAugLearnExMag(nn.Module): 274 | def __init__(self, inshape, int_steps = 5, int_downsize = 4, flow_mag=4, smooth_size = 25): 275 | 276 | super().__init__() 277 | 278 | ndims=1 279 | self.inshape=inshape 280 | resize = int_steps > 0 and int_downsize > 1 281 | self.resize = ResizeTransformTime(1/int_downsize, ndims) if resize else None 282 | self.fullsize = ResizeTransformTime(int_downsize, ndims) if resize else None 283 | 284 | # configure optional integration layer for diffeomorphic warp 285 | down_shape = [inshape[0]//int_downsize] 286 | self.integrate = VecInt(down_shape, int_steps) if int_steps > 0 else None 287 | 288 | # configure transformer 289 | self.transformer = SpatialTransformer(inshape) 290 | 291 | # set up smoothing filter 292 | self.flow_mag = torch.nn.parameter.Parameter(torch.Tensor([float(flow_mag)])) 293 | self.smooth_size= smooth_size 294 | self.smooth_pad = smooth_centre = (smooth_size-1)//2 295 | smooth_kernel = np.zeros(smooth_size) 296 | smooth_kernel[smooth_centre] = 1 297 | filt = gaussian_filter1d(smooth_kernel, smooth_centre).astype(np.float32) 298 | self.smooth_kernel = torch.from_numpy(filt) 299 | 300 | self.net = nn.Sequential(nn.Conv1d(12, 32, 15,stride=2), 301 | nn.BatchNorm1d(32), 302 | nn.ReLU(inplace=True), 303 | nn.Conv1d(32, 32, 15, stride=2), 304 | nn.BatchNorm1d(32), 305 | nn.ReLU(inplace=True), 306 | nn.Conv1d(32, 32, 15, stride=2), 307 | nn.BatchNorm1d(32), 308 | nn.ReLU(inplace=True), 309 | nn.Conv1d(32, 32, 15, stride=2), 310 | nn.BatchNorm1d(32), 311 | nn.ReLU(inplace=True), 312 | nn.AdaptiveAvgPool1d(1), 313 | nn.Flatten()) 314 | 315 | self.flow_mag_layer = nn.Linear(32,1) 316 | 317 | def forward(self, source): 318 | BS, L, C = source.shape 319 | source = source.transpose(1,2) 320 | x=source 321 | fm = 2*torch.sigmoid(self.flow_mag_layer(self.net(x))) 322 | 323 | fm_std = 100*(self.flow_mag**2) 324 | 325 | flow_field = fm.view(BS, 1, 1) * fm_std*torch.randn(x.shape[0], 1, self.inshape[0]).to(x.device) 326 | 327 | # resize flow for integration 328 | pos_flow = flow_field 329 | if self.resize: 330 | pos_flow = self.resize(pos_flow) 331 | 332 | preint_flow = pos_flow 333 | 334 | # integrate to produce diffeomorphic warp 335 | if self.integrate: 336 | pos_flow = self.integrate(pos_flow) 337 | 338 | # resize to final resolution 339 | if self.fullsize: 340 | pos_flow = self.fullsize(pos_flow) 341 | 342 | # DO SOME SMOOTHING OF THE FLOW FIELD HERE. 
343 | pos_flow = F.conv1d(pos_flow, self.smooth_kernel.view(1,1,self.smooth_size).to(x.device), padding=self.smooth_pad, stride=1)
344 | 
345 | # warp image with flow field
346 | y_source = self.transformer(source, pos_flow)
347 | return y_source.transpose(1,2)
348 | 
349 | 
-------------------------------------------------------------------------------- /multitask/model.py: --------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.nn import MessagePassing
3 | from torch_geometric.utils import add_self_loops, degree, softmax
4 | from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
5 | import torch.nn.functional as F
6 | from loader import BioDataset
7 | from dataloader import DataLoaderFinetune
8 | from torch_scatter import scatter_add
9 | from torch_geometric.nn.inits import glorot, zeros
10 | 
11 | class GINConv(MessagePassing):
12 | """
13 | Extension of GIN aggregation to incorporate edge information by concatenation.
14 | 
15 | Args:
16 | emb_dim (int): dimensionality of embeddings for nodes and edges.
17 | input_layer (bool): whether the GIN conv is applied to the input layer or not. (Input node labels are uniform...)
18 | 
19 | See https://arxiv.org/abs/1810.00826
20 | """
21 | def __init__(self, emb_dim, aggr = "add", input_layer = False):
22 | super(GINConv, self).__init__()
23 | # multi-layer perceptron
24 | self.mlp = torch.nn.Sequential(torch.nn.Linear(2*emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, emb_dim))
25 | 
26 | ### Mapping 0/1 edge features to embedding
27 | self.edge_encoder = torch.nn.Linear(9, emb_dim)
28 | 
29 | ### Mapping uniform input features to embedding.
30 | self.input_layer = input_layer
31 | if self.input_layer:
32 | self.input_node_embeddings = torch.nn.Embedding(2, emb_dim)
33 | torch.nn.init.xavier_uniform_(self.input_node_embeddings.weight.data)
34 | 
35 | self.aggr = aggr
36 | 
37 | def forward(self, x, edge_index, edge_attr):
38 | #add self loops in the edge space
39 | # print(edge_index)
40 | edge_index, _ = add_self_loops(edge_index, num_nodes = x.size(0))
41 | # print(edge_index.type())
42 | 
43 | #add features corresponding to self-loop edges.
44 | self_loop_attr = torch.zeros(x.size(0), 9)
45 | self_loop_attr[:,7] = 1 # attribute for self-loop edge
46 | self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
47 | edge_attr = torch.cat((edge_attr, self_loop_attr), dim = 0)
48 | 
49 | edge_embeddings = self.edge_encoder(edge_attr)
50 | 
51 | if self.input_layer:
52 | x = self.input_node_embeddings(x.to(torch.int64).view(-1,))
53 | 
54 | # print(edge_index.dtype)
55 | edge_index = edge_index.long()
56 | # print(edge_index.dtype)
57 | 
58 | return self.propagate(edge_index , x=x, edge_attr=edge_embeddings)
59 | 
60 | def message(self, x_j, edge_attr):
61 | return torch.cat([x_j, edge_attr], dim = 1)
62 | 
63 | def update(self, aggr_out):
64 | return self.mlp(aggr_out)
65 | 
66 | 
67 | class GCNConv(MessagePassing):
68 | 
69 | def __init__(self, emb_dim, aggr = "add", input_layer = False):
70 | super(GCNConv, self).__init__()
71 | 
72 | self.emb_dim = emb_dim
73 | self.linear = torch.nn.Linear(emb_dim, emb_dim)
74 | 
75 | ### Mapping 0/1 edge features to embedding
76 | self.edge_encoder = torch.nn.Linear(9, emb_dim)
77 | 
78 | ### Mapping uniform input features to embedding.
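# (as with GINConv above: input-layer node features are uniform binary labels,
# so they are embedded via a 2-entry lookup table instead of a linear layer)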
79 | self.input_layer = input_layer
80 | if self.input_layer:
81 | self.input_node_embeddings = torch.nn.Embedding(2, emb_dim)
82 | torch.nn.init.xavier_uniform_(self.input_node_embeddings.weight.data)
83 | 
84 | self.aggr = aggr
85 | 
86 | def norm(self, edge_index, num_nodes, dtype):
87 | ### assuming that self-loops have been already added in edge_index
88 | edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
89 | device=edge_index.device)
90 | row, col = edge_index
91 | deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
92 | deg_inv_sqrt = deg.pow(-0.5)
93 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
94 | 
95 | return deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
96 | 
97 | 
98 | def forward(self, x, edge_index, edge_attr):
99 | #add self loops in the edge space (unpack the tuple, as in GINConv above)
100 | edge_index, _ = add_self_loops(edge_index, num_nodes = x.size(0))
101 | 
102 | #add features corresponding to self-loop edges.
103 | self_loop_attr = torch.zeros(x.size(0), 9)
104 | self_loop_attr[:,7] = 1 # attribute for self-loop edge
105 | self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
106 | edge_attr = torch.cat((edge_attr, self_loop_attr), dim = 0)
107 | 
108 | edge_embeddings = self.edge_encoder(edge_attr)
109 | 
110 | if self.input_layer:
111 | x = self.input_node_embeddings(x.to(torch.int64).view(-1,))
112 | 
113 | norm = self.norm(edge_index, x.size(0), x.dtype)
114 | 
115 | x = self.linear(x)
116 | 
117 | return self.propagate(edge_index, x=x, edge_attr=edge_embeddings, norm = norm)
118 | 
119 | def message(self, x_j, edge_attr, norm):
120 | return norm.view(-1, 1) * (x_j + edge_attr)
121 | 
122 | 
123 | class GATConv(MessagePassing):
124 | def __init__(self, emb_dim, heads=2, negative_slope=0.2, aggr = "add", input_layer = False):
125 | super(GATConv, self).__init__()
126 | 
127 | self.aggr = aggr
128 | 
129 | self.emb_dim = emb_dim
130 | self.heads = heads
131 | self.negative_slope = negative_slope
132 | 
133 | self.weight_linear = torch.nn.Linear(emb_dim, heads * emb_dim)
134 | self.att = torch.nn.Parameter(torch.Tensor(1, heads, 2 * emb_dim))
135 | 
136 | self.bias = torch.nn.Parameter(torch.Tensor(emb_dim))
137 | 
138 | ### Mapping 0/1 edge features to embedding
139 | self.edge_encoder = torch.nn.Linear(9, heads * emb_dim)
140 | 
141 | ### Mapping uniform input features to embedding.
142 | self.input_layer = input_layer
143 | if self.input_layer:
144 | self.input_node_embeddings = torch.nn.Embedding(2, emb_dim)
145 | torch.nn.init.xavier_uniform_(self.input_node_embeddings.weight.data)
146 | 
147 | self.reset_parameters()
148 | 
149 | def reset_parameters(self):
150 | glorot(self.att)
151 | zeros(self.bias)
152 | 
153 | def forward(self, x, edge_index, edge_attr):
154 | #add self loops in the edge space (unpack the tuple, as in GINConv above)
155 | edge_index, _ = add_self_loops(edge_index, num_nodes = x.size(0))
156 | 
157 | #add features corresponding to self-loop edges.
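# edge_attr here is the 9-dim vector [w1..w7, self_loop, mask] described in
# util.py's MaskEdge; index 7 is the self-loop indicator set below.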
158 | self_loop_attr = torch.zeros(x.size(0), 9)
159 | self_loop_attr[:,7] = 1 # attribute for self-loop edge
160 | self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
161 | edge_attr = torch.cat((edge_attr, self_loop_attr), dim = 0)
162 | 
163 | edge_embeddings = self.edge_encoder(edge_attr)
164 | 
165 | if self.input_layer:
166 | x = self.input_node_embeddings(x.to(torch.int64).view(-1,))
167 | 
168 | x = self.weight_linear(x).view(-1, self.heads, self.emb_dim)
169 | return self.propagate(edge_index, x=x, edge_attr=edge_embeddings)
170 | 
171 | def message(self, edge_index, x_i, x_j, edge_attr):
172 | edge_attr = edge_attr.view(-1, self.heads, self.emb_dim)
173 | x_j += edge_attr
174 | 
175 | alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)
176 | 
177 | alpha = F.leaky_relu(alpha, self.negative_slope)
178 | alpha = softmax(alpha, edge_index[0])
179 | 
180 | return x_j * alpha.view(-1, self.heads, 1)
181 | 
182 | def update(self, aggr_out):
183 | aggr_out = aggr_out.mean(dim=1)
184 | aggr_out = aggr_out + self.bias
185 | 
186 | return aggr_out
187 | 
188 | 
189 | class GraphSAGEConv(MessagePassing):
190 | def __init__(self, emb_dim, aggr = "mean", input_layer = False):
191 | super(GraphSAGEConv, self).__init__()
192 | 
193 | self.emb_dim = emb_dim
194 | self.linear = torch.nn.Linear(emb_dim, emb_dim)
195 | 
196 | ### Mapping 0/1 edge features to embedding
197 | self.edge_encoder = torch.nn.Linear(9, emb_dim)
198 | 
199 | ### Mapping uniform input features to embedding.
200 | self.input_layer = input_layer
201 | if self.input_layer:
202 | self.input_node_embeddings = torch.nn.Embedding(2, emb_dim)
203 | torch.nn.init.xavier_uniform_(self.input_node_embeddings.weight.data)
204 | 
205 | self.aggr = aggr
206 | 
207 | def forward(self, x, edge_index, edge_attr):
208 | #add self loops in the edge space (unpack the tuple, as in GINConv above)
209 | edge_index, _ = add_self_loops(edge_index, num_nodes = x.size(0))
210 | 
211 | #add features corresponding to self-loop edges.
212 | self_loop_attr = torch.zeros(x.size(0), 9)
213 | self_loop_attr[:,7] = 1 # attribute for self-loop edge
214 | self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
215 | edge_attr = torch.cat((edge_attr, self_loop_attr), dim = 0)
216 | 
217 | edge_embeddings = self.edge_encoder(edge_attr)
218 | 
219 | if self.input_layer:
220 | x = self.input_node_embeddings(x.to(torch.int64).view(-1,))
221 | 
222 | x = self.linear(x)
223 | 
224 | return self.propagate(edge_index, x=x, edge_attr=edge_embeddings)
225 | 
226 | def message(self, x_j, edge_attr):
227 | return x_j + edge_attr
228 | 
229 | def update(self, aggr_out):
230 | return F.normalize(aggr_out, p = 2, dim = -1)
231 | 
232 | 
233 | class GNN(torch.nn.Module):
234 | """
235 | Extension of GIN to incorporate edge information by concatenation.
236 | 
237 | Args:
238 | num_layer (int): the number of GNN layers
239 | emb_dim (int): dimensionality of embeddings
240 | JK (str): last, concat, max or sum.
241 | max_pool_layer (int): the layer from which we use max pool rather than add pool for neighbor aggregation 242 | drop_ratio (float): dropout rate 243 | gnn_type: gin, gat, graphsage, gcn 244 | 245 | See https://arxiv.org/abs/1810.00826 246 | JK-net: https://arxiv.org/abs/1806.03536 247 | 248 | Output: 249 | node representations 250 | 251 | """ 252 | def __init__(self, num_layer, emb_dim, JK = "last", drop_ratio = 0, gnn_type = "gin"): 253 | super(GNN, self).__init__() 254 | self.num_layer = num_layer 255 | self.drop_ratio = drop_ratio 256 | self.JK = JK 257 | 258 | if self.num_layer < 2: 259 | raise ValueError("Number of GNN layers must be greater than 1.") 260 | 261 | ###List of message-passing GNN convs 262 | self.gnns = torch.nn.ModuleList() 263 | for layer in range(num_layer): 264 | if layer == 0: 265 | input_layer = True 266 | else: 267 | input_layer = False 268 | 269 | if gnn_type == "gin": 270 | self.gnns.append(GINConv(emb_dim, aggr = "add", input_layer = input_layer)) 271 | elif gnn_type == "gcn": 272 | self.gnns.append(GCNConv(emb_dim, input_layer = input_layer)) 273 | elif gnn_type == "gat": 274 | self.gnns.append(GATConv(emb_dim, input_layer = input_layer)) 275 | elif gnn_type == "graphsage": 276 | self.gnns.append(GraphSAGEConv(emb_dim, input_layer = input_layer)) 277 | 278 | #def forward(self, x, edge_index, edge_attr): 279 | def forward(self, x, edge_index, edge_attr): 280 | h_list = [x] 281 | for layer in range(self.num_layer): 282 | h = self.gnns[layer](h_list[layer], edge_index, edge_attr) 283 | if layer == self.num_layer - 1: 284 | #remove relu from the last layer 285 | h = F.dropout(h, self.drop_ratio, training = self.training) 286 | else: 287 | h = F.dropout(F.relu(h), self.drop_ratio, training = self.training) 288 | h_list.append(h) 289 | 290 | if self.JK == "last": 291 | node_representation = h_list[-1] 292 | elif self.JK == "sum": 293 | h_list = [h.unsqueeze_(0) for h in h_list] 294 | node_representation = torch.sum(torch.cat(h_list[1:], dim = 0), dim = 0)[0] 295 | 296 | return node_representation 297 | 298 | 299 | class GNN_graphpred(torch.nn.Module): 300 | """ 301 | Extension of GIN to incorporate edge information by concatenation. 302 | 303 | Args: 304 | num_layer (int): the number of GNN layers 305 | emb_dim (int): dimensionality of embeddings 306 | num_tasks (int): number of tasks in multi-task learning scenario 307 | drop_ratio (float): dropout rate 308 | JK (str): last, concat, max or sum. 
309 | graph_pooling (str): sum, mean, max, attention, set2set 310 | 311 | See https://arxiv.org/abs/1810.00826 312 | JK-net: https://arxiv.org/abs/1806.03536 313 | """ 314 | def __init__(self, num_layer, emb_dim, num_tasks, JK = "last", drop_ratio = 0, graph_pooling = "mean", gnn_type = "gin"): 315 | super(GNN_graphpred, self).__init__() 316 | self.num_layer = num_layer 317 | self.drop_ratio = drop_ratio 318 | self.JK = JK 319 | self.emb_dim = emb_dim 320 | self.num_tasks = num_tasks 321 | 322 | if self.num_layer < 2: 323 | raise ValueError("Number of GNN layers must be greater than 1.") 324 | 325 | self.gnn = GNN(num_layer, emb_dim, JK, drop_ratio, gnn_type = gnn_type) 326 | 327 | #Different kind of graph pooling 328 | if graph_pooling == "sum": 329 | self.pool = global_add_pool 330 | elif graph_pooling == "mean": 331 | self.pool = global_mean_pool 332 | elif graph_pooling == "max": 333 | self.pool = global_max_pool 334 | elif graph_pooling == "attention": 335 | self.pool = GlobalAttention(gate_nn = torch.nn.Linear(emb_dim, 1)) 336 | else: 337 | raise ValueError("Invalid graph pooling type.") 338 | 339 | self.graph_pred_linear = torch.nn.Linear(2*self.emb_dim, self.num_tasks) 340 | 341 | def from_pretrained(self, model_file): 342 | sd = torch.load(model_file) 343 | if 'student_sd' not in sd: 344 | self.gnn.load_state_dict(torch.load(model_file)) 345 | else: 346 | sd_red = {k[4:] : sd['student_sd'][k] for k in sd['student_sd'].keys() if 'gnn' in k} 347 | self.gnn.load_state_dict(sd_red) 348 | 349 | 350 | def forward(self, data): 351 | x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch 352 | # print(edge_index) 353 | node_representation = self.gnn(x, edge_index, edge_attr) 354 | 355 | pooled = self.pool(node_representation, batch) 356 | center_node_rep = node_representation[data.center_node_idx] 357 | 358 | graph_rep = torch.cat([pooled, center_node_rep], dim = 1) 359 | 360 | return self.graph_pred_linear(graph_rep) 361 | 362 | def logits(self, data): 363 | x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch 364 | # print(edge_index) 365 | node_representation = self.gnn(x, edge_index, edge_attr) 366 | 367 | pooled = self.pool(node_representation, batch) 368 | center_node_rep = node_representation[data.center_node_idx] 369 | 370 | graph_rep = torch.cat([pooled, center_node_rep], dim = 1) 371 | 372 | return graph_rep 373 | 374 | def head(self, feats): 375 | return self.graph_pred_linear(feats) 376 | 377 | class GNNHead(torch.nn.Module): 378 | def __init__(self, feats, num_classes): 379 | super(GNNHead, self).__init__() 380 | self.fc_pi = torch.nn.Linear(feats, num_classes) 381 | 382 | def forward(self, x): 383 | out_pi = self.fc_pi(x) 384 | return out_pi 385 | 386 | class TaskWeight(torch.nn.Module): 387 | def __init__(self, num_classes): 388 | super(TaskWeight, self).__init__() 389 | self.num_classes=num_classes 390 | self.weights = torch.nn.parameter.Parameter(torch.zeros(num_classes)) 391 | 392 | def forward(self): 393 | # form weights purely using classes. 
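# self.weights is initialized to zeros, so 2*sigmoid(0) = 1: every task starts
# with weight 1.0, and each learned weight is constrained to lie in (0, 2).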
394 | out_exw = 2*torch.sigmoid(self.weights) 395 | # print(out_exw.shape) 396 | return out_exw 397 | 398 | if __name__ == "__main__": 399 | pass 400 | 401 | 402 | 403 | -------------------------------------------------------------------------------- /simclr/simclr_ecg.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | import os 4 | import torch 5 | import torch.nn as nn 6 | import torch.utils.data as utils 7 | import torch.nn.functional as F 8 | import higher 9 | import pickle 10 | from torch.utils.data import Dataset, DataLoader, Subset 11 | from torchvision import datasets, transforms 12 | from torch.autograd import grad 13 | from tqdm import tqdm 14 | from simclr_models import * 15 | from simclr_datasets import * 16 | from nt_xent import NTXentLoss 17 | 18 | 19 | import argparse 20 | 21 | parser = argparse.ArgumentParser(description='ECG SIMCLR IFT') 22 | 23 | parser.add_argument('--seed', type=int, default=0) 24 | parser.add_argument('--gpu', type=int) 25 | 26 | parser.add_argument('--pretrain_lr', type=float, default=1e-4) 27 | parser.add_argument('--finetune_lr', type=float, default=1e-4) 28 | parser.add_argument('--hyper_lr', type=float, default=1e-4) 29 | parser.add_argument('--epochs', default=50, type=int) 30 | parser.add_argument('--ex', default=500, type=int) 31 | parser.add_argument('--warmup_epochs', type=int, default=1) 32 | parser.add_argument('--pretrain_steps', type=int, default = 10) 33 | parser.add_argument('--finetune_steps', type=int, default = 1) 34 | parser.add_argument('--studentarch', type=str, default='resnet18') 35 | parser.add_argument('--teacherarch') 36 | parser.add_argument('--dataset', type=str, default='ecg') 37 | parser.add_argument('--neumann', type=int, default=1) 38 | parser.add_argument('--batch_size', type=int, default=256) 39 | parser.add_argument('--savefol', type=str, default='simclr') 40 | parser.add_argument('--save', action='store_false') 41 | parser.add_argument('--no_probs', action='store_true') 42 | parser.add_argument('--temperature', type=float, default=0.5) 43 | parser.add_argument('--checkpoint', type=str) 44 | parser.add_argument('--teach_checkpoint', type=str) 45 | 46 | args = parser.parse_args() 47 | 48 | torch.manual_seed(args.seed) 49 | torch.cuda.manual_seed_all(args.seed) 50 | torch.multiprocessing.set_sharing_strategy('file_system') 51 | 52 | if args.no_probs: 53 | args.savefol += 'determ' 54 | 55 | args.savefol += f'-{args.ex}ex' 56 | 57 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 58 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 59 | 60 | 61 | class AverageMeter(object): 62 | """Computes and stores the average and current value""" 63 | def __init__(self): 64 | self.reset() 65 | 66 | def reset(self): 67 | self.val = 0 68 | self.avg = 0 69 | self.sum = 0 70 | self.count = 0 71 | 72 | def update(self, val, n=1): 73 | self.val = val 74 | self.sum += val * n 75 | self.count += n 76 | self.avg = self.sum / self.count 77 | 78 | 79 | def model_saver(epoch, student, head, teacher, pt_opt, pt_sched, ft_opt, hyp_opt, path): 80 | torch.save({ 81 | 'student_sd': student.state_dict(), 82 | 'teacher_sd': teacher.state_dict() if teacher is not None else None, 83 | 'head_sd': head.state_dict(), 84 | 'pt_opt_state_dict': pt_opt.state_dict(), 85 | 'pt_sched_state_dict': pt_sched.state_dict(), 86 | 'ft_opt_state_dict': ft_opt.state_dict(), 87 | 'hyp_opt_state_dict': hyp_opt.state_dict() if teacher is not None else None, 88 | 
}, path + f'/checkpoint_epoch{epoch}.pt') 89 | 90 | 91 | def get_save_path(): 92 | modfol = f"""seed{args.seed}-dataset{args.dataset}-student{args.studentarch}-teacher{args.teacherarch}-ptlr{args.pretrain_lr}-ftlr{args.finetune_lr}-hyplr{args.hyper_lr}-warmup{args.warmup_epochs}-pt_steps{args.pretrain_steps}-ft_steps{args.finetune_steps}-neumann{args.neumann}""" 93 | if args.teach_checkpoint: 94 | args.savefol += '-teachckpt' 95 | modfol = os.path.join(modfol, args.teach_checkpoint) 96 | pth = os.path.join(args.savefol, modfol) 97 | os.makedirs(pth, exist_ok=True) 98 | return pth 99 | 100 | def zero_hypergrad(hyper_params): 101 | """ 102 | 103 | :param get_hyper_train: 104 | :return: 105 | """ 106 | current_index = 0 107 | for p in hyper_params: 108 | p_num_params = np.prod(p.shape) 109 | if p.grad is not None: 110 | p.grad = p.grad * 0 111 | current_index += p_num_params 112 | 113 | 114 | def store_hypergrad(hyper_params, total_d_val_loss_d_lambda): 115 | """ 116 | 117 | :param get_hyper_train: 118 | :param total_d_val_loss_d_lambda: 119 | :return: 120 | """ 121 | current_index = 0 122 | for p in hyper_params: 123 | p_num_params = np.prod(p.shape) 124 | p.grad = total_d_val_loss_d_lambda[current_index:current_index + p_num_params].view(p.shape) 125 | current_index += p_num_params 126 | 127 | 128 | def neumann_hyperstep_preconditioner(d_val_loss_d_theta, d_train_loss_d_w, elementary_lr, num_neumann_terms, model, head): 129 | preconditioner = d_val_loss_d_theta.detach() 130 | counter = preconditioner 131 | 132 | # Do the fixed point iteration to approximate the vector-inverseHessian product 133 | i = 0 134 | while i < num_neumann_terms: # for i in range(num_neumann_terms): 135 | old_counter = counter 136 | 137 | # This increments counter to counter * (I - hessian) = counter - counter * hessian 138 | hessian_term = gather_flat_grad( 139 | grad(d_train_loss_d_w, list(model.parameters())+list(head.parameters()), grad_outputs=counter.view(-1), retain_graph=True)) 140 | counter = old_counter - elementary_lr * hessian_term 141 | 142 | preconditioner = preconditioner + counter 143 | i += 1 144 | return elementary_lr * preconditioner 145 | 146 | def get_hyper_train_flat(hyper_params): 147 | return torch.cat([p.view(-1) for p in hyper_params]) 148 | 149 | def gather_flat_grad(loss_grad): 150 | return torch.cat([p.reshape(-1) for p in loss_grad]) #g_vector 151 | 152 | def get_loss(student,head,teacher, x, y): 153 | head_op = head(student.logits(x)) 154 | pi_stud = student(x) 155 | l_obj = nn.BCEWithLogitsLoss() 156 | clf_loss = l_obj(head_op, y) 157 | y_loss_stud = clf_loss + 0*torch.sum(pi_stud[0])+ 0*torch.sum(pi_stud[1]) 158 | acc_stud = 0 #torch.mean(torch.sigmoid(head_op) > 0.5 * y).item() 159 | return y_loss_stud, acc_stud 160 | 161 | 162 | def hyper_step(model, head, teacher, hyper_params, pretrain_loader, optimizer, d_val_loss_d_theta, elementary_lr, neum_steps): 163 | zero_hypergrad(hyper_params) 164 | num_weights = sum(p.numel() for p in model.parameters()) + sum(p.numel() for p in head.parameters()) 165 | 166 | d_train_loss_d_w = torch.zeros(num_weights).to(device) 167 | model.train(), model.zero_grad(),head.train(), head.zero_grad() 168 | 169 | # NOTE: This should be the pretrain set: gradient of PRETRAINING loss wrt pretrain parameters. 
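# Sketch of the hypergradient assembled in this function, writing theta for the
# (student, head) weights and phi for the teacher's augmentation parameters:
#   dL_val/dphi ~= -(dL_val/dtheta) [d^2 L_PT / dtheta^2]^{-1} (d^2 L_PT / dtheta dphi).
# d_val_loss_d_theta is supplied by the unrolled fine-tuning loop; the
# inverse-Hessian-vector product is approximated by the truncated Neumann series
# below (neum_steps terms), which appears to follow Lorraine et al. (2020),
# "Optimizing Millions of Hyperparameters by Implicit Differentiation".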
170 | for batch_idx, (xis, xjs) in enumerate(pretrain_loader): 171 | xis = xis.to(device) 172 | xjs = xjs.to(device) 173 | if teacher is not None: 174 | xis = teacher(xis) 175 | xjs = teacher(xjs) 176 | train_loss= get_loss_simclr(model, xis, xjs) 177 | train_loss = train_loss + train_loss*head(model.logits(xis)).sum()*0 178 | optimizer.zero_grad() 179 | d_train_loss_d_w += gather_flat_grad(grad(train_loss, list(model.parameters())+list(head.parameters()), 180 | create_graph=True, allow_unused=True)) 181 | break 182 | optimizer.zero_grad() 183 | 184 | # Initialize the preconditioner and counter 185 | preconditioner = d_val_loss_d_theta 186 | 187 | preconditioner = neumann_hyperstep_preconditioner(d_val_loss_d_theta, d_train_loss_d_w, elementary_lr, 188 | neum_steps, model, head) 189 | 190 | 191 | # THIS SHOULD BE PRETRAIN LOSS AGAIN. 192 | indirect_grad = gather_flat_grad( 193 | grad(d_train_loss_d_w, hyper_params, grad_outputs=preconditioner.view(-1))) 194 | hypergrad = indirect_grad # + direct_Grad 195 | 196 | zero_hypergrad(hyper_params) 197 | store_hypergrad(hyper_params, -hypergrad) 198 | return hypergrad 199 | 200 | nt_xent_criterion = NTXentLoss(device, args.batch_size, args.temperature, use_cosine_similarity=True) 201 | def get_loss_simclr(student, xis, xjs): 202 | 203 | # print(xis.type(), xjs.type()) 204 | # get the representations and the projections 205 | ris, zis = student(xis) # [N,C] 206 | 207 | # get the representations and the projections 208 | rjs, zjs = student(xjs) # [N,C] 209 | 210 | # normalize projection feature vectors 211 | # zis = F.normalize(zis, dim=1) 212 | # zjs = F.normalize(zjs, dim=1) 213 | 214 | loss = nt_xent_criterion(zis, zjs) 215 | return loss 216 | 217 | def do_pretrain(student, head, teacher, optimizer, xis, xjs): 218 | student.train() 219 | xis = xis.to(device) 220 | xjs = xjs.to(device) 221 | if teacher is not None: 222 | xis = teacher(xis) 223 | xjs = teacher(xjs) 224 | loss = get_loss_simclr(student, xis, xjs) 225 | optimizer.zero_grad() 226 | loss.backward() 227 | optimizer.step() 228 | 229 | return loss.item() 230 | 231 | 232 | def inner_loop_finetune(student, head, teacher, optimizer, train_dl, val_dl, num_steps): 233 | stud_loss = 0. 234 | stud_acc = 0. 
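# Note: `head` is passed in as a `higher` functional module and `optimizer` as its
# differentiable optimizer (see the higher.innerloop_ctx call site in train() below),
# so the num_steps head updates stay on the autograd tape; the returned ft_grad is the
# FT validation loss gradient w.r.t. the student and the time-0 head weights, i.e. the
# d_val_loss_d_theta term consumed by hyper_step.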
235 | 236 | student.eval() 237 | if teacher is not None: 238 | teacher.train() 239 | 240 | for i, (x,y) in enumerate(train_dl): 241 | x = x.to(device) 242 | y = y.to(device) 243 | y_loss, acc = get_loss(student, head, teacher, x, y) 244 | optimizer.step(y_loss) 245 | # logging 246 | stud_loss += y_loss.item() 247 | stud_acc += acc 248 | 249 | if i == num_steps - 1: 250 | break 251 | stud_loss /= num_steps 252 | stud_acc /= num_steps 253 | 254 | # Now compute the val loss 255 | avgloss = None 256 | avgacc = None 257 | for i, (x,y) in enumerate(val_dl): 258 | x = x.to(device) 259 | y = y.to(device) 260 | y_loss, acc = get_loss(student, head, teacher, x, y) 261 | 262 | if avgloss is None: 263 | avgloss = y_loss 264 | avgacc = acc 265 | else: 266 | avgloss += y_loss 267 | avgacc += acc 268 | break 269 | # print(avgloss.item()) 270 | # Now compute a finetuning gradient 271 | ft_grad = torch.autograd.grad(avgloss, list(student.parameters())+list(head.parameters(time=0)), allow_unused=True) 272 | return (stud_loss, stud_acc), (avgloss, avgacc), ft_grad, head 273 | 274 | def do_ft_head(student, head, optimizer, dl): 275 | student.eval() 276 | for x,y in dl: 277 | x = x.to(device) 278 | y = y.to(device) 279 | 280 | loss, acc = get_loss(student, head, None, x, y) 281 | optimizer.zero_grad() 282 | loss.backward() 283 | optimizer.step() 284 | optimizer.zero_grad() 285 | break 286 | return loss.item(), acc 287 | 288 | 289 | # Utility function to update lossdict 290 | def update_lossdict(lossdict, update, action='append'): 291 | for k in update.keys(): 292 | if action == 'append': 293 | if k in lossdict: 294 | lossdict[k].append(update[k]) 295 | else: 296 | lossdict[k] = [update[k]] 297 | elif action == 'sum': 298 | if k in lossdict: 299 | lossdict[k] += update[k] 300 | else: 301 | lossdict[k] = update[k] 302 | else: 303 | raise NotImplementedError 304 | return lossdict 305 | 306 | 307 | from sklearn.metrics import roc_auc_score 308 | 309 | # Evaluate student on complete train/test set. 310 | def eval_student(student, head, dl): 311 | student.eval() 312 | net_loss = 0 313 | correct = 0 314 | y_pred = [] 315 | y_true = [] 316 | l_obj = nn.BCEWithLogitsLoss(reduction='sum') 317 | # clf_loss = l_obj(head_op, y) 318 | with torch.no_grad(): 319 | for data, target in dl: 320 | y_true.append(target.detach().cpu().numpy()) 321 | data, target = data.to(device), target.to(device) 322 | output = head(student.logits(data)) 323 | net_loss += l_obj(output, target).item() # sum up batch loss 324 | y_pred.append(output.detach().cpu().numpy()) 325 | # pred = torch.sigmoid(output) > 0.5 326 | # correct += torch.sum(pred == target).item() 327 | 328 | y_pred = np.concatenate(y_pred, axis=0) 329 | y_true = np.concatenate(y_true, axis=0) 330 | net_loss /= len(dl.dataset) 331 | # acc = 100. * correct / len(dl.dataset * y_pred.shape[1]) 332 | 333 | roc_list = [] 334 | for i in range(y_true.shape[1]): 335 | try: 336 | #AUC is only defined when there is at least one positive data. 
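# (A task where only one class occurs contributes np.nan, so the np.mean(roc_list)
# printed below is NaN whenever any task is degenerate; np.nanmean gives a robust summary.)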
337 | if np.sum(y_true[:,i] == 1) > 0 and np.sum(y_true[:,i] == 0) > 0: 338 | roc_list.append(roc_auc_score(y_true[:,i], y_pred[:,i])) 339 | else: 340 | roc_list.append(np.nan) 341 | except ValueError: 342 | roc_list.append(np.nan) 343 | 344 | print('Average loss: {:.4f}, AUC: {:.4f}'.format(net_loss, np.mean(roc_list))) 345 | return {'epoch_loss': net_loss, 'auc' : roc_list} 346 | 347 | 348 | import copy 349 | 350 | def train(): 351 | pt_meter = AverageMeter() 352 | ft_loss_meter = AverageMeter() 353 | ft_acc_meter = AverageMeter() 354 | 355 | # create save path 356 | if args.save: 357 | save_path = get_save_path() 358 | 359 | 360 | if args.teacherarch == 'warpexmag': 361 | teacher = RandWarpAugLearnExMag(inshape=[1024]).to(device) 362 | hyp_params = list(teacher.parameters()) 363 | 364 | hyp_optim = torch.optim.Adam([ 365 | {'params': teacher.net.parameters(), 'lr': args.hyper_lr}, 366 | {'params': teacher.flow_mag_layer.parameters(), 'lr': args.hyper_lr}, 367 | {'params': [teacher.flow_mag], 'lr': 1}]) 368 | 369 | hyp_scheduler = None 370 | else: 371 | args.teacherarch = None 372 | teacher = None 373 | hyp_params = None 374 | hyp_optim = None 375 | hyp_scheduler = None 376 | 377 | DSHandle = ECGDataSetWrapper(args.batch_size) 378 | pretrain_dl, train_dl, val_dl, test_dl, _, NUM_TASKS_FT = DSHandle.get_data_loaders(args) 379 | 380 | if args.studentarch == 'resnet18': 381 | student = ecg_simclr_resnet18().to(device) 382 | head = MultitaskHead(256, NUM_TASKS_FT).to(device) 383 | else: 384 | raise NotImplementedError 385 | 386 | 387 | pretrain_optim = torch.optim.Adam(student.parameters(), lr=args.pretrain_lr) 388 | pretrain_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(pretrain_optim, T_max=args.epochs, eta_min=0, 389 | last_epoch=-1) 390 | finetune_optim = torch.optim.Adam(head.parameters(), lr=args.finetune_lr) 391 | 392 | if args.checkpoint: 393 | ckpt = torch.load(args.checkpoint) 394 | student.load_state_dict(ckpt['student_sd']) 395 | if teacher is not None and ckpt['teacher_sd'] is not None: 396 | teacher.load_state_dict(ckpt['teacher_sd']) 397 | head.load_state_dict(ckpt['head_sd']) 398 | pretrain_optim.load_state_dict(ckpt['pt_opt_state_dict']) 399 | pretrain_scheduler.load_state_dict(ckpt['pt_sched_state_dict']) 400 | finetune_optim.load_state_dict(ckpt['ft_opt_state_dict']) 401 | if teacher is not None and ckpt['hyp_opt_state_dict'] is not None: 402 | hyp_optim.load_state_dict(ckpt['hyp_opt_state_dict']) 403 | load_ep = int(os.path.split(args.checkpoint)[-1][16:-3]) + 1 404 | print(f"Restored from epoch {load_ep}") 405 | else: 406 | print("Training from scratch") 407 | load_ep = 0 408 | 409 | if args.teach_checkpoint: 410 | print("LOADING PT AUG MODEL") 411 | ckpt = torch.load(args.teach_checkpoint) 412 | teacher.load_state_dict(ckpt['aug_sd']) 413 | print("LOAD SUCCESSFUL") 414 | 415 | stud_pretrain_ld = {'loss' : [], 'acc' : [] } 416 | stud_finetune_train_ld = {'loss' : [], 'acc' : []} 417 | stud_finetune_val_ld = {'loss' : [], 'acc' : []} 418 | stud_finetune_test_ld = {} 419 | 420 | num_finetune_steps = args.finetune_steps 421 | num_neumann_steps = args.neumann 422 | 423 | steps = 0 424 | for n in range(load_ep,args.epochs): 425 | # gradnorm = {'gradnorm' : []} 426 | progress_bar = tqdm(pretrain_dl) 427 | for i, (xis,xjs) in enumerate(progress_bar): 428 | progress_bar.set_description('Epoch ' + str(n)) 429 | # step_num = i + n*len(pretrain_dl) 430 | if teacher is not None: 431 | zero_hypergrad(hyp_params) 432 | 433 | if n < args.warmup_epochs or teacher is None: 434 | 
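# Warmup branch (also used when no teacher is configured): plain SimCLR pre-training
# plus ordinary head fine-tuning, with no unrolled FT loop and no hypergradient step,
# so the teacher's augmentation parameters are not updated here.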
pt_loss = do_pretrain(student, head, teacher, pretrain_optim, xis, xjs) 435 | pt_meter.update(pt_loss) 436 | if teacher is not None: 437 | ft_train_loss, ft_train_acc = do_ft_head(student, head, finetune_optim, train_dl) 438 | else: 439 | ft_train_loss, ft_train_acc = 0,0 440 | ft_loss_meter.update(ft_train_loss) 441 | ft_acc_meter.update(ft_train_acc) 442 | ft_val_loss, ft_val_acc, hypg = 0,0,0 443 | else: 444 | # if steps % args.pretrain_steps == 0: 445 | pt_loss = do_pretrain(student, head, teacher, pretrain_optim, xis, xjs) 446 | pt_meter.update(pt_loss) 447 | if steps % args.pretrain_steps == 0: 448 | with higher.innerloop_ctx(head, finetune_optim, copy_initial_weights=True) as (fnet, diffopt): 449 | (ft_train_loss, ft_train_acc), (ft_val_loss, ft_val_acc), ft_grad, fnet = \ 450 | inner_loop_finetune(student, fnet, teacher, diffopt, train_dl, val_dl, num_finetune_steps) 451 | head.load_state_dict(fnet.state_dict()) 452 | ft_loss_meter.update(ft_train_loss) 453 | ft_acc_meter.update(ft_train_acc) 454 | ft_grad = gather_flat_grad(ft_grad) 455 | for param_group in pretrain_optim.param_groups: 456 | cur_lr = param_group['lr'] 457 | break 458 | 459 | hypg = hyper_step(student, head, teacher, hyp_params, pretrain_dl, pretrain_optim, ft_grad, cur_lr, num_neumann_steps) 460 | hypg = hypg.norm().item() 461 | hyp_optim.step() 462 | else: 463 | ft_train_loss, ft_train_acc = do_ft_head(student, head, finetune_optim, train_dl) 464 | ft_loss_meter.update(ft_train_loss) 465 | ft_acc_meter.update(ft_train_acc) 466 | ft_val_loss, ft_val_acc, hypg = 0,0,0 467 | steps += 1 468 | progress_bar.set_postfix( 469 | pretrain_loss='%.4f' % pt_meter.avg , 470 | finetune_train_loss='%.4f' % ft_loss_meter.avg , 471 | finetune_train_acc='%.4f' % ft_acc_meter.avg , 472 | ) 473 | 474 | # append to lossdict 475 | stud_pretrain_ld['loss'].append(pt_loss) 476 | stud_finetune_train_ld['loss'].append(ft_train_loss) 477 | stud_finetune_train_ld['acc'].append(ft_train_acc) 478 | stud_finetune_val_ld['loss'].append(ft_val_loss) 479 | stud_finetune_val_ld['acc'].append(ft_val_acc) 480 | 481 | if teacher is not None: 482 | ft_test_ld = eval_student(student,head, test_dl) 483 | stud_finetune_test_ld = update_lossdict(stud_finetune_test_ld, ft_test_ld) 484 | 485 | ft_val_ld = eval_student(student, head, val_dl) 486 | stud_finetune_val_ld = update_lossdict(stud_finetune_val_ld, ft_val_ld) 487 | 488 | ft_train_ld = eval_student(student,head, train_dl) 489 | stud_finetune_train_ld = update_lossdict(stud_finetune_train_ld, ft_train_ld) 490 | 491 | 492 | if hyp_scheduler is not None: 493 | hyp_scheduler.step() 494 | # reset the meter 495 | pt_meter.reset() 496 | ft_loss_meter.reset() 497 | ft_acc_meter.reset() 498 | # save the logs 499 | if args.save: 500 | tosave = { 501 | 'pretrain_ld' : stud_pretrain_ld, 502 | 'finetune_train_ld' : stud_finetune_train_ld, 503 | 'finetune_val_ld' : stud_finetune_val_ld, 504 | 'finetune_test_ld' : stud_finetune_test_ld, 505 | } 506 | torch.save(tosave, os.path.join(save_path, 'logs.ckpt')) 507 | if n == args.epochs- 1: 508 | model_saver(n, student, head, teacher, pretrain_optim, pretrain_scheduler, finetune_optim, hyp_optim, save_path) 509 | print(f"Saved model at epoch {n}") 510 | return student, head, teacher, pretrain_optim, pretrain_scheduler, finetune_optim, hyp_optim 511 | 512 | 513 | res = train() 514 | 515 | 516 | 517 | 518 | 519 | 520 | 521 | 522 | -------------------------------------------------------------------------------- /multitask/pretrain_supervised_weighting.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from splitters import random_split, species_split 4 | from loader import BioDataset 5 | from torch_geometric.data import DataLoader 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | 12 | from tqdm import tqdm 13 | import numpy as np 14 | 15 | from dataloader import DataLoaderFinetune 16 | from model import GNN, GNN_graphpred 17 | from sklearn.metrics import roc_auc_score 18 | 19 | import pandas as pd 20 | 21 | from util import combine_dataset 22 | from model import GNN, GNN_graphpred, GNNHead, TaskWeight 23 | 24 | from PIL import Image 25 | import numpy as np 26 | import os 27 | import torch.utils.data as utils 28 | import higher 29 | import pickle 30 | from torch.autograd import grad 31 | import random 32 | 33 | parser = argparse.ArgumentParser(description='PyTorch implementation of pre-training of graph neural networks') 34 | parser.add_argument('--batch_size', type=int, default=32, 35 | help='input batch size for training (default: 32)') 36 | 37 | parser.add_argument('--epochs', type=int, default=100) 38 | parser.add_argument('--warmup_epochs', type=int, default=1) 39 | parser.add_argument('--pretrain_steps', type=int, default = 10) 40 | parser.add_argument('--finetune_steps', type=int, default = 1) 41 | parser.add_argument('--neumann', type=int, default=1) 42 | 43 | parser.add_argument('--pretrain_lr', type=float, default=1e-4, 44 | help='pre-training learning rate (default: 1e-4)') 45 | 46 | parser.add_argument('--finetune_lr', type=float, default=1e-4) 47 | 48 | parser.add_argument('--hyper_lr', type=float, default=1) 49 | 50 | 51 | parser.add_argument('--decay', type=float, default=0, 52 | help='weight decay (default: 0)') 53 | parser.add_argument('--num_layer', type=int, default=5, 54 | help='number of GNN message passing layers (default: 5).') 55 | parser.add_argument('--emb_dim', type=int, default=300, 56 | help='embedding dimensions (default: 300)') 57 | parser.add_argument('--dropout_ratio', type=float, default=0.2, 58 | help='dropout ratio (default: 0.2)') 59 | parser.add_argument('--graph_pooling', type=str, default="mean", 60 | help='graph level pooling (sum, mean, max, set2set, attention)') 61 | parser.add_argument('--JK', type=str, default="last", 62 | help='how the node features across layers are combined. 
last, sum, max or concat') 63 | parser.add_argument('--input_model_file', type=str, default = '', help='filename to read the model (if there is any)') 64 | parser.add_argument('--savefol', type=str, default='exw-adamhyper') 65 | parser.add_argument('--gnn_type', type=str, default="gin") 66 | parser.add_argument('--num_workers', type=int, default = 0, help='number of workers for dataset loading') 67 | parser.add_argument('--seed', type=int, default=42, help = "Seed for splitting dataset.") 68 | parser.add_argument('--split', type=str, default = "species", help='Random or species split') 69 | parser.add_argument('--gpu', type=int, default=0) 70 | parser.add_argument('--smallft', type=float, default = 1, help='split the val set down or not') 71 | parser.add_argument('--checkpoint', type=str) 72 | parser.add_argument('--ft_bn', action='store_true') 73 | args = parser.parse_args() 74 | 75 | 76 | torch.manual_seed(0) 77 | np.random.seed(0) 78 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 79 | device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") 80 | print(device) 81 | if torch.cuda.is_available(): 82 | torch.cuda.manual_seed_all(0) 83 | 84 | class AverageMeter(object): 85 | """Computes and stores the average and current value""" 86 | def __init__(self): 87 | self.reset() 88 | 89 | def reset(self): 90 | self.val = 0 91 | self.avg = 0 92 | self.sum = 0 93 | self.count = 0 94 | 95 | def update(self, val, n=1): 96 | self.val = val 97 | self.sum += val * n 98 | self.count += n 99 | self.avg = self.sum / self.count 100 | 101 | 102 | # Need to set this! 103 | root_supervised = '/path/to/dataset/' 104 | 105 | dataset = BioDataset(root_supervised, data_type='supervised') 106 | 107 | 108 | print("Making PT dataset...") 109 | if args.split == "random": 110 | print("random splitting") 111 | train_dataset, valid_dataset, test_dataset = random_split(dataset, seed = args.seed) 112 | print(train_dataset) 113 | print(valid_dataset) 114 | pretrain_dataset = combine_dataset(train_dataset, valid_dataset) 115 | print(pretrain_dataset) 116 | elif args.split == "species": 117 | print("species splitting") 118 | trainval_dataset, test_dataset = species_split(dataset) 119 | test_dataset_broad, test_dataset_none, _ = random_split(test_dataset, seed = args.seed, frac_train=0.5, frac_valid=0.5, frac_test=0) 120 | print(trainval_dataset) 121 | print(test_dataset_broad) 122 | pretrain_dataset = combine_dataset(trainval_dataset, test_dataset_broad) 123 | print(pretrain_dataset) 124 | else: 125 | raise ValueError("Unknown split name.") 126 | 127 | pretrain_loader = DataLoader(pretrain_dataset, batch_size=args.batch_size, shuffle=True, num_workers = args.num_workers) 128 | pretrain_val_loader = None 129 | NUM_TASKS_PT = len(pretrain_dataset[0].go_target_pretrain) 130 | 131 | 132 | print("Making FT dataset...") 133 | if args.split == "random": 134 | print("random splitting") 135 | train_dataset, valid_dataset, test_dataset = random_split(dataset, seed = args.seed) 136 | elif args.split == "species": 137 | trainval_dataset, test_dataset = species_split(dataset) 138 | train_dataset, valid_dataset, _ = random_split(trainval_dataset, seed = args.seed, frac_train=0.85, frac_valid=0.15, frac_test=0) 139 | test_dataset_broad, test_dataset_none, _ = random_split(test_dataset, seed = args.seed, frac_train=0.5, frac_valid=0.5, frac_test=0) 140 | print("species splitting") 141 | else: 142 | raise ValueError("Unknown split name.") 143 | 144 | if args.smallft != 1: 145 | random.seed(0) 146 | 
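# args.smallft < 1 keeps only that fraction of the FT train/val sets, simulating
# scarce fine-tuning data; the fixed seeds make the subsample reproducible.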
torch.manual_seed(0) 147 | len_train = int(args.smallft*len(train_dataset)) 148 | len_val = int(args.smallft*len(valid_dataset)) 149 | train_dataset, _ = torch.utils.data.random_split(train_dataset, [len_train, len(train_dataset)-len_train]) 150 | valid_dataset, _ = torch.utils.data.random_split(valid_dataset, [len_val, len(valid_dataset)-len_val]) 151 | 152 | 153 | finetune_train_loader = DataLoaderFinetune(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers = args.num_workers) 154 | finetune_val_loader = DataLoaderFinetune(valid_dataset, batch_size=args.batch_size, shuffle=True, num_workers = args.num_workers) 155 | 156 | if args.split == "random": 157 | test_loader = DataLoaderFinetune(test_dataset, batch_size=10*args.batch_size, shuffle=False, num_workers = args.num_workers) 158 | else: 159 | ### for species splitting 160 | finetune_test_loader_easy = DataLoaderFinetune(test_dataset_broad, batch_size=10*args.batch_size, shuffle=False, num_workers = args.num_workers) 161 | finetune_test_loader_hard = DataLoaderFinetune(test_dataset_none, batch_size=10*args.batch_size, shuffle=False, num_workers = args.num_workers) 162 | 163 | NUM_TASKS_FT = len(dataset[0].go_target_downstream) 164 | 165 | 166 | 167 | def model_saver(epoch, student, head, teacher, pt_opt, ft_opt, hyp_opt, path): 168 | torch.save({ 169 | 'student_sd': student.state_dict(), 170 | 'teacher_sd': teacher.state_dict(), 171 | 'head_sd': head.state_dict(), 172 | 'pt_opt_state_dict': pt_opt.state_dict(), 173 | 'ft_opt_state_dict': ft_opt.state_dict(), 174 | 'hyp_opt_state_dict': hyp_opt.state_dict(), 175 | }, path + f'/checkpoint_epoch{epoch}.pt') 176 | 177 | 178 | def get_save_path(): 179 | modfol = f"""ptlr{args.pretrain_lr}-ftlr{args.finetune_lr}-hyplr{args.hyper_lr}-warmup{args.warmup_epochs}-pt_steps{args.pretrain_steps}-ft_steps{args.finetune_steps}-neumann{args.neumann}-ft_bn{args.ft_bn}-smallft{args.smallft}""" 180 | pth = os.path.join(args.savefol, modfol) 181 | os.makedirs(pth, exist_ok=True) 182 | return pth 183 | 184 | def zero_hypergrad(hyper_params): 185 | """ 186 | 187 | :param get_hyper_train: 188 | :return: 189 | """ 190 | current_index = 0 191 | for p in hyper_params: 192 | p_num_params = np.prod(p.shape) 193 | if p.grad is not None: 194 | p.grad = p.grad * 0 195 | current_index += p_num_params 196 | 197 | 198 | def store_hypergrad(hyper_params, total_d_val_loss_d_lambda): 199 | """ 200 | 201 | :param get_hyper_train: 202 | :param total_d_val_loss_d_lambda: 203 | :return: 204 | """ 205 | current_index = 0 206 | for p in hyper_params: 207 | p_num_params = np.prod(p.shape) 208 | p.grad = total_d_val_loss_d_lambda[current_index:current_index + p_num_params].view(p.shape) 209 | current_index += p_num_params 210 | 211 | 212 | def neumann_hyperstep_preconditioner(d_val_loss_d_theta, d_train_loss_d_w, elementary_lr, num_neumann_terms, model, head): 213 | preconditioner = d_val_loss_d_theta.detach() 214 | counter = preconditioner 215 | 216 | # Do the fixed point iteration to approximate the vector-inverseHessian product 217 | i = 0 218 | while i < num_neumann_terms: # for i in range(num_neumann_terms): 219 | old_counter = counter 220 | 221 | # This increments counter to counter * (I - hessian) = counter - counter * hessian 222 | hessian_term = gather_flat_grad( 223 | grad(d_train_loss_d_w, list(model.parameters())+list(head.parameters()), grad_outputs=counter.view(-1), retain_graph=True)) 224 | # hessian_term[hessian_term == None] = 0 225 | counter = old_counter - elementary_lr * hessian_term 226 | 227 
| preconditioner = preconditioner + counter 228 | i += 1 229 | return elementary_lr * preconditioner 230 | 231 | def get_hyper_train_flat(hyper_params): 232 | return torch.cat([p.view(-1) for p in hyper_params]) 233 | 234 | def gather_flat_grad(loss_grad): 235 | return torch.cat([p.reshape(-1) for p in loss_grad]) #g_vector 236 | 237 | 238 | # Forward pass on student and teacher, getting neg log lik of label and batch acc 239 | def get_loss(student,head,teacher, batch, ft): 240 | batch = batch.to(device) 241 | feats = student.logits(batch) 242 | pi_stud = student.head(feats) 243 | head_op = head(feats) 244 | if ft: 245 | y = batch.go_target_downstream.view(head_op.shape).to(torch.float64) 246 | l_obj = nn.BCEWithLogitsLoss() 247 | loss = l_obj(head_op.double(), y) 248 | y_loss_stud = loss + 0*torch.sum(pi_stud.double()) 249 | else: 250 | y = batch.go_target_pretrain.view(pi_stud.shape).to(torch.float64) 251 | comm = teacher.forward() 252 | l_obj = nn.BCEWithLogitsLoss(reduction = 'none') 253 | loss = l_obj(pi_stud.double(), y) 254 | # now apply the classweight 255 | comm = comm.unsqueeze(0) 256 | y_loss_stud = torch.mean(comm*loss) + 0*torch.sum(head_op) * 0*torch.sum(loss) # force dependence on other things so autograd does not complain 257 | return y_loss_stud, 0 258 | 259 | 260 | def hyper_step(model, head, teacher, hyper_params, pretrain_loader, optimizer, d_val_loss_d_theta, elementary_lr, neum_steps): 261 | zero_hypergrad(hyper_params) 262 | num_weights = sum(p.numel() for p in model.parameters()) + sum(p.numel() for p in head.parameters()) 263 | num_hypers = sum(p.numel() for p in hyper_params) 264 | 265 | d_train_loss_d_w = torch.zeros(num_weights).cuda() 266 | model.train(), model.zero_grad(),head.train(), head.zero_grad() 267 | 268 | # NOTE: This should be the pretrain set: gradient of PRETRAINING loss wrt pretrain parameters. 269 | for batch_idx, batch in enumerate(pretrain_loader): 270 | batch = batch.to(device) 271 | train_loss, _ = get_loss(model, head, teacher, batch, ft=False) 272 | optimizer.zero_grad() 273 | d_train_loss_d_w += gather_flat_grad(grad(train_loss, list(model.parameters())+list(head.parameters()), 274 | create_graph=True)) 275 | break 276 | optimizer.zero_grad() 277 | 278 | # Initialize the preconditioner and counter 279 | preconditioner = d_val_loss_d_theta 280 | 281 | preconditioner = neumann_hyperstep_preconditioner(d_val_loss_d_theta, d_train_loss_d_w, elementary_lr, 282 | neum_steps, model, head) 283 | 284 | # THIS SHOULD BE PRETRAIN LOSS AGAIN. 285 | indirect_grad = gather_flat_grad( 286 | grad(d_train_loss_d_w, hyper_params, grad_outputs=preconditioner.view(-1))) 287 | hypergrad = indirect_grad # no direct grad term 288 | 289 | zero_hypergrad(hyper_params) 290 | store_hypergrad(hyper_params, -hypergrad) 291 | return hypergrad 292 | 293 | def do_pretrain(student, head, teacher, optimizer, batch): 294 | student.train() 295 | 296 | batch = batch.to(device) 297 | 298 | loss, acc = get_loss(student, head, teacher, batch, ft=False) 299 | optimizer.zero_grad() 300 | loss.backward() 301 | optimizer.step() 302 | 303 | return loss.item(), acc 304 | 305 | 306 | 307 | def inner_loop_finetune(student, head, teacher, optimizer, train_dl, val_dl, num_steps): 308 | stud_loss = 0. 309 | stud_acc = 0. 310 | 311 | if args.ft_bn: 312 | student.train() 313 | else: 314 | student.eval() # BN should be eval model when we do only head unrolling! 
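# Only the head is unrolled through `higher`, so the GNN body defaults to eval mode
# here to freeze BatchNorm running statistics; pass --ft_bn to keep BN updating instead.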
315 | teacher.train() 316 | 317 | for i, batch in enumerate(train_dl): 318 | batch = batch.to(device) 319 | y_loss, acc = get_loss(student, head, teacher, batch, ft=True) 320 | optimizer.step(y_loss) 321 | # logging 322 | stud_loss += y_loss.item() 323 | stud_acc += acc 324 | 325 | if i == num_steps - 1: 326 | break 327 | stud_loss /= num_steps 328 | stud_acc /= num_steps 329 | 330 | # Now compute the val loss 331 | avgloss = None 332 | avgacc = None 333 | for i, batch in enumerate(val_dl): 334 | batch = batch.to(device) 335 | y_loss, acc = get_loss(student, head, teacher, batch, ft=True) 336 | 337 | if avgloss is None: 338 | avgloss = y_loss 339 | avgacc = acc 340 | else: 341 | avgloss += y_loss 342 | avgacc += acc 343 | break 344 | # print(avgloss.item()) 345 | # Now compute a finetuning gradient 346 | ft_grad = torch.autograd.grad(avgloss, list(student.parameters())+list(head.parameters(time=0)), allow_unused=True) 347 | return (stud_loss, stud_acc), (avgloss, avgacc), ft_grad, head 348 | 349 | 350 | 351 | # Utility function to update lossdict 352 | def update_lossdict(lossdict, update, action='append'): 353 | for k in update.keys(): 354 | if action == 'append': 355 | if k in lossdict: 356 | lossdict[k].append(update[k]) 357 | else: 358 | lossdict[k] = [update[k]] 359 | elif action == 'sum': 360 | if k in lossdict: 361 | lossdict[k] += update[k] 362 | else: 363 | lossdict[k] = update[k] 364 | else: 365 | raise NotImplementedError 366 | return lossdict 367 | 368 | 369 | 370 | # Evaluate student on complete train/test set. 371 | def eval_student(student, head, dl): 372 | student.eval() 373 | y_true = [] 374 | y_scores = [] 375 | 376 | for step, batch in enumerate(tqdm(dl, desc="Iteration")): 377 | batch = batch.to(device) 378 | 379 | with torch.no_grad(): 380 | pred = head(student.logits(batch)) 381 | 382 | y_true.append(batch.go_target_downstream.view(pred.shape).detach().cpu()) 383 | y_scores.append(pred.detach().cpu()) 384 | 385 | y_true = torch.cat(y_true, dim = 0).numpy() 386 | y_scores = torch.cat(y_scores, dim = 0).numpy() 387 | 388 | roc_list = [] 389 | for i in range(y_true.shape[1]): 390 | try: 391 | #AUC is only defined when there is at least one positive data. 
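# (np.nan marks tasks with a single observed class in this split; summarize the
# returned per-task AUC array with np.nanmean.)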
392 | if np.sum(y_true[:,i] == 1) > 0 and np.sum(y_true[:,i] == 0) > 0: 393 | roc_list.append(roc_auc_score(y_true[:,i], y_scores[:,i])) 394 | else: 395 | roc_list.append(np.nan) 396 | except ValueError: 397 | roc_list.append(np.nan) 398 | 399 | return {'auc':np.array(roc_list)} #y_true.shape[1] 400 | 401 | 402 | import copy 403 | import time 404 | def do_ft_head(student, head, optimizer, dl): 405 | if args.ft_bn: 406 | student.train() 407 | else: 408 | student.eval() 409 | for batch in dl: 410 | batch = batch.to(device) 411 | feats = student.logits(batch).detach() 412 | head_op = head(feats) 413 | 414 | l_obj = nn.BCEWithLogitsLoss() 415 | y = batch.go_target_downstream.view(head_op.shape).to(torch.float64) 416 | loss = l_obj(head_op.double(), y) 417 | acc = 0 418 | 419 | optimizer.zero_grad() 420 | loss.backward() 421 | optimizer.step() 422 | 423 | optimizer.zero_grad() 424 | break 425 | return loss.item(), acc 426 | 427 | 428 | 429 | def train(pretrain_dl, train_dl, val_dl, test_dl_easy, test_dl_hard): 430 | # also creates save path 431 | save_path = get_save_path() 432 | 433 | student = GNN_graphpred(args.num_layer, args.emb_dim, NUM_TASKS_PT, JK = args.JK, drop_ratio = args.dropout_ratio, graph_pooling = args.graph_pooling, gnn_type = args.gnn_type) 434 | student = student.to(device) 435 | head = GNNHead(2*args.emb_dim, NUM_TASKS_FT).to(device) 436 | 437 | teacher = TaskWeight(num_classes=NUM_TASKS_PT).to(device) 438 | 439 | hyp_params = list(teacher.parameters()) 440 | 441 | pretrain_optim = torch.optim.Adam(student.parameters(), lr=args.pretrain_lr, weight_decay=args.decay) 442 | 443 | finetune_optim = torch.optim.Adam(head.parameters(), lr=args.finetune_lr) 444 | hyp_optim = torch.optim.Adam(hyp_params, lr=args.hyper_lr) 445 | 446 | stud_pretrain_ld = {'loss' : [], 'acc' : [] } 447 | stud_finetune_train_ld = {'loss' : [], 'acc' : []} 448 | stud_finetune_val_ld = {'loss' : [], 'acc' : []} 449 | stud_finetune_test_ld_easy = {} 450 | stud_finetune_test_ld_hard = {} 451 | 452 | num_finetune_steps = args.finetune_steps 453 | num_neumann_steps = args.neumann 454 | 455 | if args.checkpoint: 456 | ckpt = torch.load(args.checkpoint) 457 | student.load_state_dict(ckpt['student_sd']) 458 | teacher.load_state_dict(ckpt['teacher_sd']) 459 | head.load_state_dict(ckpt['head_sd']) 460 | pretrain_optim.load_state_dict(ckpt['pt_opt_state_dict']) 461 | finetune_optim.load_state_dict(ckpt['ft_opt_state_dict']) 462 | hyp_optim.load_state_dict(ckpt['hyp_opt_state_dict']) 463 | load_ep = int(os.path.split(args.checkpoint)[-1][16:-3]) + 1 464 | print(f"Loaded checkpoint {args.checkpoint}, epoch {load_ep}") 465 | else: 466 | load_ep = 0 467 | steps = 0 468 | for n in range(load_ep, args.epochs): 469 | pt_loss_meter = AverageMeter() 470 | pt_acc_meter = AverageMeter() 471 | ft_train_loss_meter = AverageMeter() 472 | ft_train_acc_meter = AverageMeter() 473 | ft_val_loss_meter = AverageMeter() 474 | ft_val_acc_meter = AverageMeter() 475 | 476 | progress_bar = tqdm(pretrain_dl) 477 | for i, batch in enumerate(progress_bar): 478 | progress_bar.set_description('Epoch ' + str(n)) 479 | zero_hypergrad(hyp_params) 480 | 481 | if n < args.warmup_epochs: 482 | student.zero_grad(), head.zero_grad() 483 | pt_loss, pt_acc = do_pretrain(student, head, teacher, pretrain_optim, batch) 484 | pt_loss_meter.update(pt_loss) 485 | pt_acc_meter.update(pt_acc) 486 | 487 | student.zero_grad(), head.zero_grad() 488 | ft_train_loss, ft_train_acc = do_ft_head(student, head, finetune_optim, train_dl) 489 | 
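# (Warmup: the task-weight teacher receives no hypergradient update until
# epoch args.warmup_epochs.)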
ft_train_loss_meter.update(ft_train_loss) 490 | ft_train_acc_meter.update(ft_train_acc) 491 | ft_val_loss, ft_val_acc, hypg = 0,0,0 492 | else: 493 | pt_loss, pt_acc = do_pretrain(student, head, teacher, pretrain_optim, batch) 494 | pt_loss_meter.update(pt_loss) 495 | pt_acc_meter.update(pt_acc) 496 | if steps % args.pretrain_steps == 0: 497 | with higher.innerloop_ctx(head, finetune_optim, copy_initial_weights=True) as (fnet, diffopt): 498 | (ft_train_loss, ft_train_acc), (ft_val_loss, ft_val_acc), ft_grad, fnet = \ 499 | inner_loop_finetune(student, fnet, teacher, diffopt, train_dl, val_dl, num_finetune_steps) 500 | head.load_state_dict(fnet.state_dict()) 501 | 502 | ft_grad = gather_flat_grad(ft_grad) 503 | for param_group in pretrain_optim.param_groups: 504 | cur_lr = param_group['lr'] 505 | break 506 | 507 | hypg = hyper_step(student, head, teacher, hyp_params, pretrain_loader, pretrain_optim, ft_grad, cur_lr, num_neumann_steps) 508 | hypg = hypg.norm().item() 509 | 510 | hyp_optim.step() 511 | ft_train_loss_meter.update(ft_train_loss) 512 | ft_train_acc_meter.update(ft_train_acc) 513 | ft_val_loss_meter.update(ft_val_loss) 514 | ft_val_acc_meter.update(ft_val_acc) 515 | else: 516 | ft_train_loss, ft_train_acc = do_ft_head(student, head, finetune_optim, train_dl) 517 | ft_train_loss_meter.update(ft_train_loss) 518 | ft_train_acc_meter.update(ft_train_acc) 519 | ft_val_loss, ft_val_acc, hypg = 0,0,0 520 | 521 | 522 | steps += 1 523 | progress_bar.set_postfix( 524 | pt_l='%.4f' % pt_loss_meter.avg , 525 | pt_a='%.4f' % pt_acc_meter.avg , 526 | ft_tl='%.4f' % ft_train_loss_meter.avg , 527 | ft_ta='%.4f' % ft_train_acc_meter.avg , 528 | ft_vl='%.4f' % ft_val_loss_meter.avg , 529 | ft_va='%.4f' % ft_val_acc_meter.avg , 530 | hyp_norm='%.6f' % hypg 531 | ) 532 | 533 | # append to lossdict 534 | stud_pretrain_ld['loss'].append(pt_loss) 535 | stud_pretrain_ld['acc'].append(pt_acc) 536 | stud_finetune_train_ld['loss'].append(ft_train_loss) 537 | stud_finetune_train_ld['acc'].append(ft_train_acc) 538 | stud_finetune_val_ld['loss'].append(ft_val_loss) 539 | stud_finetune_val_ld['acc'].append(ft_val_acc) 540 | 541 | 542 | ft_test_ld = eval_student(student,head, test_dl_easy) 543 | print("Easy test") 544 | print(ft_test_ld) 545 | stud_finetune_test_ld_easy = update_lossdict(stud_finetune_test_ld_easy, ft_test_ld) 546 | 547 | ft_test_ld = eval_student(student,head, test_dl_hard) 548 | print("Hard test") 549 | print(ft_test_ld) 550 | stud_finetune_test_ld_hard = update_lossdict(stud_finetune_test_ld_hard, ft_test_ld) 551 | 552 | ft_val_ld = eval_student(student, head, val_dl) 553 | print("Val") 554 | print(ft_val_ld) 555 | stud_finetune_val_ld = update_lossdict(stud_finetune_val_ld, ft_val_ld) 556 | 557 | 558 | if pretrain_val_loader is not None: 559 | print("Evaluating on pretrain val set") 560 | eval_student(student, None, pretrain_val_loader) 561 | 562 | tosave = { 563 | 'pretrain_ld' : stud_pretrain_ld, 564 | 'finetune_train_ld' : stud_finetune_train_ld, 565 | 'finetune_val_ld' : stud_finetune_val_ld, 566 | 'finetune_test_ld_easy' : stud_finetune_test_ld_easy, 567 | 'finetune_test_ld_hard' : stud_finetune_test_ld_hard, 568 | } 569 | torch.save(tosave, os.path.join(save_path, 'logs.ckpt')) 570 | if n % 20 == 0 or n == args.epochs - 1: 571 | model_saver(n, student, head, teacher, pretrain_optim, finetune_optim, hyp_optim, save_path) 572 | print(f"Saved model at epoch {n}") 573 | return student, head, teacher, pretrain_optim, finetune_optim, hyp_optim 574 | 575 | 576 | res = train(pretrain_loader, 
finetune_train_loader, finetune_val_loader, finetune_test_loader_easy, finetune_test_loader_hard) 577 | 578 | # Save the final state unconditionally (this script defines no --save flag, so `if args.save:` would raise AttributeError). 579 | save_path = get_save_path() 580 | student, head, teacher, pretrain_optim, finetune_optim, hyp_optim = res 581 | model_saver(args.epochs, student, head, teacher, pretrain_optim, finetune_optim, hyp_optim, save_path) 582 | 583 | 584 | 585 | 586 | 587 | 588 | 589 | -------------------------------------------------------------------------------- /multitask/pretrain_supervised_weighting_ft30.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from splitters import random_split, species_split 4 | from loader import BioDataset 5 | from torch_geometric.data import DataLoader 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | 12 | from tqdm import tqdm 13 | import numpy as np 14 | 15 | from dataloader import DataLoaderFinetune 16 | from model import GNN, GNN_graphpred 17 | from sklearn.metrics import roc_auc_score 18 | 19 | import pandas as pd 20 | 21 | from util import combine_dataset 22 | from model import GNN, GNN_graphpred, GNNHead, TaskWeight 23 | 24 | from PIL import Image 25 | import numpy as np 26 | import os 27 | import torch.utils.data as utils 28 | import higher 29 | import pickle 30 | from torch.autograd import grad 31 | import random 32 | 33 | parser = argparse.ArgumentParser(description='PyTorch implementation of pre-training of graph neural networks') 34 | parser.add_argument('--batch_size', type=int, default=32, 35 | help='input batch size for training (default: 32)') 36 | 37 | parser.add_argument('--epochs', type=int, default=100) 38 | parser.add_argument('--warmup_epochs', type=int, default=1) 39 | parser.add_argument('--pretrain_steps', type=int, default = 10) 40 | parser.add_argument('--finetune_steps', type=int, default = 1) 41 | parser.add_argument('--neumann', type=int, default=1) 42 | 43 | parser.add_argument('--pretrain_lr', type=float, default=1e-4, 44 | help='pre-training learning rate (default: 1e-4)') 45 | 46 | parser.add_argument('--finetune_lr', type=float, default=1e-4) 47 | 48 | parser.add_argument('--hyper_lr', type=float, default=1) 49 | 50 | 51 | parser.add_argument('--decay', type=float, default=0, 52 | help='weight decay (default: 0)') 53 | parser.add_argument('--num_layer', type=int, default=5, 54 | help='number of GNN message passing layers (default: 5).') 55 | parser.add_argument('--emb_dim', type=int, default=300, 56 | help='embedding dimensions (default: 300)') 57 | parser.add_argument('--dropout_ratio', type=float, default=0.2, 58 | help='dropout ratio (default: 0.2)') 59 | parser.add_argument('--graph_pooling', type=str, default="mean", 60 | help='graph level pooling (sum, mean, max, set2set, attention)') 61 | parser.add_argument('--JK', type=str, default="last", 62 | help='how the node features across layers are combined. 
last, sum, max or concat') 63 | parser.add_argument('--input_model_file', type=str, default = '', help='filename to read the model (if there is any)') 64 | parser.add_argument('--savefol', type=str, default='exw-adamhyper-ft30tasks') 65 | parser.add_argument('--gnn_type', type=str, default="gin") 66 | parser.add_argument('--num_workers', type=int, default = 0, help='number of workers for dataset loading') 67 | parser.add_argument('--seed', type=int, default=42, help = "Seed for splitting dataset.") 68 | parser.add_argument('--split', type=str, default = "species", help='Random or species split') 69 | parser.add_argument('--gpu', type=int, default=0) 70 | parser.add_argument('--fold', type=int, default=3, help='fold*10:fold*10+10 defines test tasks') 71 | 72 | parser.add_argument('--checkpoint', type=str) 73 | parser.add_argument('--ft_bn', action='store_true') 74 | args = parser.parse_args() 75 | args.savefol += f'-fold{args.fold}' 76 | 77 | torch.manual_seed(0) 78 | np.random.seed(0) 79 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 80 | device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") 81 | print(device) 82 | if torch.cuda.is_available(): 83 | torch.cuda.manual_seed_all(0) 84 | 85 | class AverageMeter(object): 86 | """Computes and stores the average and current value""" 87 | def __init__(self): 88 | self.reset() 89 | 90 | def reset(self): 91 | self.val = 0 92 | self.avg = 0 93 | self.sum = 0 94 | self.count = 0 95 | 96 | def update(self, val, n=1): 97 | self.val = val 98 | self.sum += val * n 99 | self.count += n 100 | self.avg = self.sum / self.count 101 | 102 | 103 | 104 | # Need to set this! 105 | root_supervised = '/path/to/dataset/' 106 | 107 | dataset = BioDataset(root_supervised, data_type='supervised') 108 | 109 | 110 | print("Making PT dataset...") 111 | if args.split == "random": 112 | print("random splitting") 113 | train_dataset, valid_dataset, test_dataset = random_split(dataset, seed = args.seed) 114 | print(train_dataset) 115 | print(valid_dataset) 116 | pretrain_dataset = combine_dataset(train_dataset, valid_dataset) 117 | print(pretrain_dataset) 118 | elif args.split == "species": 119 | print("species splitting") 120 | trainval_dataset, test_dataset = species_split(dataset) 121 | test_dataset_broad, test_dataset_none, _ = random_split(test_dataset, seed = args.seed, frac_train=0.5, frac_valid=0.5, frac_test=0) 122 | print(trainval_dataset) 123 | print(test_dataset_broad) 124 | pretrain_dataset = combine_dataset(trainval_dataset, test_dataset_broad) 125 | print(pretrain_dataset) 126 | #train_dataset, valid_dataset, _ = random_split(trainval_dataset, seed = args.seed, frac_train=0.85, frac_valid=0.15, frac_test=0) 127 | else: 128 | raise ValueError("Unknown split name.") 129 | 130 | pretrain_loader = DataLoader(pretrain_dataset, batch_size=args.batch_size, shuffle=True, num_workers = args.num_workers) 131 | pretrain_val_loader = None 132 | NUM_TASKS_PT = len(pretrain_dataset[0].go_target_pretrain) 133 | 134 | 135 | print("Making FT dataset...") 136 | if args.split == "random": 137 | print("random splitting") 138 | train_dataset, valid_dataset, test_dataset = random_split(dataset, seed = args.seed) 139 | elif args.split == "species": 140 | trainval_dataset, test_dataset = species_split(dataset) 141 | train_dataset, valid_dataset, _ = random_split(trainval_dataset, seed = args.seed, frac_train=0.85, frac_valid=0.15, frac_test=0) 142 | test_dataset_broad, test_dataset_none, _ = random_split(test_dataset, seed = args.seed, 
frac_train=0.5, frac_valid=0.5, frac_test=0) 143 | print("species splitting") 144 | else: 145 | raise ValueError("Unknown split name.") 146 | 147 | finetune_train_loader = DataLoaderFinetune(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers = args.num_workers) 148 | finetune_val_loader = DataLoaderFinetune(valid_dataset, batch_size=args.batch_size, shuffle=True, num_workers = args.num_workers) 149 | 150 | if args.split == "random": 151 | test_loader = DataLoaderFinetune(test_dataset, batch_size=10*args.batch_size, shuffle=False, num_workers = args.num_workers) 152 | else: 153 | ### for species splitting 154 | finetune_test_loader_easy = DataLoaderFinetune(test_dataset_broad, batch_size=10*args.batch_size, shuffle=False, num_workers = args.num_workers) 155 | finetune_test_loader_hard = DataLoaderFinetune(test_dataset_none, batch_size=10*args.batch_size, shuffle=False, num_workers = args.num_workers) 156 | 157 | # HACK. reduce number of tasks by 10 for meta PT. 158 | NUM_TASKS_FT = len(dataset[0].go_target_downstream) - 10 159 | 160 | 161 | 162 | def model_saver(epoch, student, head, teacher, pt_opt, ft_opt, hyp_opt, path): 163 | torch.save({ 164 | 'student_sd': student.state_dict(), 165 | 'teacher_sd': teacher.state_dict(), 166 | 'head_sd': head.state_dict(), 167 | 'pt_opt_state_dict': pt_opt.state_dict(), 168 | 'ft_opt_state_dict': ft_opt.state_dict(), 169 | 'hyp_opt_state_dict': hyp_opt.state_dict(), 170 | }, path + f'/checkpoint_epoch{epoch}.pt') 171 | 172 | 173 | def get_save_path(): 174 | modfol = f"""ptlr{args.pretrain_lr}-ftlr{args.finetune_lr}-hyplr{args.hyper_lr}-warmup{args.warmup_epochs}-pt_steps{args.pretrain_steps}-ft_steps{args.finetune_steps}-neumann{args.neumann}-ft_bn{args.ft_bn}""" 175 | pth = os.path.join(args.savefol, modfol) 176 | os.makedirs(pth, exist_ok=True) 177 | return pth 178 | 179 | def zero_hypergrad(hyper_params): 180 | """ 181 | 182 | :param get_hyper_train: 183 | :return: 184 | """ 185 | current_index = 0 186 | for p in hyper_params: 187 | p_num_params = np.prod(p.shape) 188 | if p.grad is not None: 189 | p.grad = p.grad * 0 190 | current_index += p_num_params 191 | 192 | 193 | def store_hypergrad(hyper_params, total_d_val_loss_d_lambda): 194 | """ 195 | 196 | :param get_hyper_train: 197 | :param total_d_val_loss_d_lambda: 198 | :return: 199 | """ 200 | current_index = 0 201 | for p in hyper_params: 202 | p_num_params = np.prod(p.shape) 203 | p.grad = total_d_val_loss_d_lambda[current_index:current_index + p_num_params].view(p.shape) 204 | current_index += p_num_params 205 | 206 | 207 | def neumann_hyperstep_preconditioner(d_val_loss_d_theta, d_train_loss_d_w, elementary_lr, num_neumann_terms, model, head): 208 | preconditioner = d_val_loss_d_theta.detach() 209 | counter = preconditioner 210 | 211 | # Do the fixed point iteration to approximate the vector-inverseHessian product 212 | i = 0 213 | while i < num_neumann_terms: # for i in range(num_neumann_terms): 214 | old_counter = counter 215 | 216 | # This increments counter to counter * (I - hessian) = counter - counter * hessian 217 | hessian_term = gather_flat_grad( 218 | grad(d_train_loss_d_w, list(model.parameters())+list(head.parameters()), grad_outputs=counter.view(-1), retain_graph=True)) 219 | # hessian_term[hessian_term == None] = 0 220 | counter = old_counter - elementary_lr * hessian_term 221 | 222 | preconditioner = preconditioner + counter 223 | i += 1 224 | return elementary_lr * preconditioner 225 | 226 | def get_hyper_train_flat(hyper_params): 227 | return 
torch.cat([p.view(-1) for p in hyper_params]) 228 | 229 | def gather_flat_grad(loss_grad): 230 | return torch.cat([p.reshape(-1) for p in loss_grad]) #g_vector 231 | 232 | def get_subset_fold(data): 233 | exclude_start = args.fold*10 234 | exclude_end = args.fold*10 + 10 235 | seg1 = data[:,:exclude_start] 236 | seg2 = data[:,exclude_end:] 237 | return torch.cat([seg1, seg2], dim=1) 238 | 239 | 240 | # Forward pass on student and teacher, getting neg log lik of label and batch acc 241 | def get_loss(student,head,teacher, batch, ft): 242 | batch = batch.to(device) 243 | feats = student.logits(batch) 244 | pi_stud = student.head(feats) 245 | head_op = head(feats) 246 | if ft: 247 | # HACK HERE TO ONLY USE THE FIRST 30 TASKS 248 | y = batch.go_target_downstream.view(-1, 40) 249 | y = get_subset_fold(y).view(head_op.shape).to(torch.float64) 250 | l_obj = nn.BCEWithLogitsLoss() 251 | loss = l_obj(head_op.double(), y) 252 | y_loss_stud = loss + 0*torch.sum(pi_stud.double()) 253 | else: 254 | y = batch.go_target_pretrain.view(pi_stud.shape).to(torch.float64) 255 | comm = teacher.forward() 256 | l_obj = nn.BCEWithLogitsLoss(reduction = 'none') 257 | loss = l_obj(pi_stud.double(), y) 258 | # now apply the classweight 259 | comm = comm.unsqueeze(0) 260 | y_loss_stud = torch.mean(comm*loss) + 0*torch.sum(head_op) * 0*torch.sum(loss) 261 | return y_loss_stud, 0 262 | 263 | 264 | def hyper_step(model, head, teacher, hyper_params, pretrain_loader, optimizer, d_val_loss_d_theta, elementary_lr, neum_steps): 265 | zero_hypergrad(hyper_params) 266 | num_weights = sum(p.numel() for p in model.parameters()) + sum(p.numel() for p in head.parameters()) 267 | num_hypers = sum(p.numel() for p in hyper_params) 268 | 269 | d_train_loss_d_w = torch.zeros(num_weights).cuda() 270 | model.train(), model.zero_grad(),head.train(), head.zero_grad() 271 | 272 | # NOTE: This should be the pretrain set: gradient of PRETRAINING loss wrt pretrain parameters. 273 | for batch_idx, batch in enumerate(pretrain_loader): 274 | batch = batch.to(device) 275 | train_loss, _ = get_loss(model, head, teacher, batch, ft=False) 276 | optimizer.zero_grad() 277 | d_train_loss_d_w += gather_flat_grad(grad(train_loss, list(model.parameters())+list(head.parameters()), 278 | create_graph=True)) 279 | break 280 | optimizer.zero_grad() 281 | 282 | # Initialize the preconditioner and counter 283 | preconditioner = d_val_loss_d_theta 284 | 285 | preconditioner = neumann_hyperstep_preconditioner(d_val_loss_d_theta, d_train_loss_d_w, elementary_lr, 286 | neum_steps, model, head) 287 | indirect_grad = gather_flat_grad( 288 | grad(d_train_loss_d_w, hyper_params, grad_outputs=preconditioner.view(-1))) 289 | hypergrad = indirect_grad # + direct_Grad 290 | 291 | zero_hypergrad(hyper_params) 292 | store_hypergrad(hyper_params, -hypergrad) 293 | return hypergrad 294 | 295 | def do_pretrain(student, head, teacher, optimizer, batch): 296 | student.train() 297 | 298 | batch = batch.to(device) 299 | 300 | loss, acc = get_loss(student, head, teacher, batch, ft=False) 301 | optimizer.zero_grad() 302 | loss.backward() 303 | optimizer.step() 304 | 305 | return loss.item(), acc 306 | 307 | 308 | 309 | def inner_loop_finetune(student, head, teacher, optimizer, train_dl, val_dl, num_steps): 310 | stud_loss = 0. 311 | stud_acc = 0. 
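# Same unrolled head fine-tuning as in pretrain_supervised_weighting.py, except that
# get_loss(..., ft=True) restricts FT labels to the 30 retained tasks via
# get_subset_fold; BatchNorm stays frozen during the unroll unless --ft_bn is set.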
312 | 313 | if args.ft_bn: 314 | student.train() 315 | else: 316 | student.eval() # BN train for only head unroll 317 | teacher.train() 318 | 319 | for i, batch in enumerate(train_dl): 320 | batch = batch.to(device) 321 | y_loss, acc = get_loss(student, head, teacher, batch, ft=True) 322 | optimizer.step(y_loss) 323 | # logging 324 | stud_loss += y_loss.item() 325 | stud_acc += acc 326 | 327 | if i == num_steps - 1: 328 | break 329 | stud_loss /= num_steps 330 | stud_acc /= num_steps 331 | 332 | # Now compute the val loss 333 | avgloss = None 334 | avgacc = None 335 | for i, batch in enumerate(val_dl): 336 | batch = batch.to(device) 337 | y_loss, acc = get_loss(student, head, teacher, batch, ft=True) 338 | 339 | if avgloss is None: 340 | avgloss = y_loss 341 | avgacc = acc 342 | else: 343 | avgloss += y_loss 344 | avgacc += acc 345 | break 346 | # Now compute a finetuning gradient 347 | ft_grad = torch.autograd.grad(avgloss, list(student.parameters())+list(head.parameters(time=0)), allow_unused=True) 348 | return (stud_loss, stud_acc), (avgloss, avgacc), ft_grad, head 349 | 350 | 351 | 352 | # Utility function to update lossdict 353 | def update_lossdict(lossdict, update, action='append'): 354 | for k in update.keys(): 355 | if action == 'append': 356 | if k in lossdict: 357 | lossdict[k].append(update[k]) 358 | else: 359 | lossdict[k] = [update[k]] 360 | elif action == 'sum': 361 | if k in lossdict: 362 | lossdict[k] += update[k] 363 | else: 364 | lossdict[k] = update[k] 365 | else: 366 | raise NotImplementedError 367 | return lossdict 368 | 369 | 370 | 371 | # Evaluate student on complete train/test set. 372 | def eval_student(student, head, dl): 373 | student.eval() 374 | y_true = [] 375 | y_scores = [] 376 | 377 | for step, batch in enumerate(tqdm(dl, desc="Iteration")): 378 | batch = batch.to(device) 379 | 380 | with torch.no_grad(): 381 | pred = head(student.logits(batch)) 382 | 383 | y_true.append(get_subset_fold(batch.go_target_downstream.view(-1,40).detach().cpu())) 384 | y_scores.append(pred.detach().cpu()) 385 | 386 | y_true = torch.cat(y_true, dim = 0).numpy() 387 | y_scores = torch.cat(y_scores, dim = 0).numpy() 388 | 389 | roc_list = [] 390 | for i in range(y_true.shape[1]): 391 | try: 392 | #AUC is only defined when there is at least one positive data. 
393 | if np.sum(y_true[:,i] == 1) > 0 and np.sum(y_true[:,i] == 0) > 0: 394 | roc_list.append(roc_auc_score(y_true[:,i], y_scores[:,i])) 395 | else: 396 | roc_list.append(np.nan) 397 | except ValueError: 398 | roc_list.append(np.nan) 399 | 400 | return {'auc':np.array(roc_list)} #y_true.shape[1] 401 | 402 | 403 | import copy 404 | import time 405 | def do_ft_head(student, head, optimizer, dl): 406 | # t = time.time() 407 | if args.ft_bn: 408 | student.train() 409 | else: 410 | student.eval() 411 | for batch in dl: 412 | batch = batch.to(device) 413 | feats = student.logits(batch).detach() 414 | head_op = head(feats) 415 | 416 | l_obj = nn.BCEWithLogitsLoss() 417 | # HACK to get in right shape 418 | y = batch.go_target_downstream.view(-1, 40) 419 | y = get_subset_fold(y).view(head_op.shape).to(torch.float64) 420 | loss = l_obj(head_op.double(), y) 421 | acc = 0 422 | 423 | optimizer.zero_grad() 424 | loss.backward() 425 | optimizer.step() 426 | 427 | optimizer.zero_grad() 428 | break 429 | return loss.item(), acc 430 | 431 | 432 | 433 | def train(pretrain_dl, train_dl, val_dl, test_dl_easy, test_dl_hard): 434 | # create save path 435 | save_path = get_save_path() 436 | 437 | student = GNN_graphpred(args.num_layer, args.emb_dim, NUM_TASKS_PT, JK = args.JK, drop_ratio = args.dropout_ratio, graph_pooling = args.graph_pooling, gnn_type = args.gnn_type) 438 | student = student.to(device) 439 | head = GNNHead(2*args.emb_dim, NUM_TASKS_FT).to(device) 440 | 441 | teacher = TaskWeight(num_classes=NUM_TASKS_PT).to(device) 442 | 443 | hyp_params = list(teacher.parameters()) 444 | 445 | pretrain_optim = torch.optim.Adam(student.parameters(), lr=args.pretrain_lr, weight_decay=args.decay) 446 | 447 | finetune_optim = torch.optim.Adam(head.parameters(), lr=args.finetune_lr) 448 | hyp_optim = torch.optim.Adam(hyp_params, lr=args.hyper_lr) 449 | 450 | stud_pretrain_ld = {'loss' : [], 'acc' : [] } 451 | stud_finetune_train_ld = {'loss' : [], 'acc' : []} 452 | stud_finetune_val_ld = {'loss' : [], 'acc' : []} 453 | stud_finetune_test_ld_easy = {} 454 | stud_finetune_test_ld_hard = {} 455 | 456 | num_finetune_steps = args.finetune_steps 457 | num_neumann_steps = args.neumann 458 | 459 | if args.checkpoint: 460 | ckpt = torch.load(args.checkpoint) 461 | student.load_state_dict(ckpt['student_sd']) 462 | teacher.load_state_dict(ckpt['teacher_sd']) 463 | head.load_state_dict(ckpt['head_sd']) 464 | pretrain_optim.load_state_dict(ckpt['pt_opt_state_dict']) 465 | finetune_optim.load_state_dict(ckpt['ft_opt_state_dict']) 466 | hyp_optim.load_state_dict(ckpt['hyp_opt_state_dict']) 467 | load_ep = int(os.path.split(args.checkpoint)[-1][16:-3]) + 1 468 | print(f"Loaded checkpoint {args.checkpoint}, epoch {load_ep}") 469 | else: 470 | load_ep = 0 471 | steps = 0 472 | for n in range(load_ep, args.epochs): 473 | pt_loss_meter = AverageMeter() 474 | pt_acc_meter = AverageMeter() 475 | ft_train_loss_meter = AverageMeter() 476 | ft_train_acc_meter = AverageMeter() 477 | ft_val_loss_meter = AverageMeter() 478 | ft_val_acc_meter = AverageMeter() 479 | 480 | progress_bar = tqdm(pretrain_dl) 481 | # progress_bar = pretrain_dl 482 | for i, batch in enumerate(progress_bar): 483 | progress_bar.set_description('Epoch ' + str(n)) 484 | # step_num = i + n*len(pretrain_dl) 485 | zero_hypergrad(hyp_params) 486 | 487 | if n < args.warmup_epochs: 488 | student.zero_grad(), head.zero_grad() 489 | pt_loss, pt_acc = do_pretrain(student, head, teacher, pretrain_optim, batch) 490 | pt_loss_meter.update(pt_loss) 491 | pt_acc_meter.update(pt_acc) 
                student.zero_grad()
                head.zero_grad()
                ft_train_loss, ft_train_acc = do_ft_head(student, head, finetune_optim, train_dl)
                ft_train_loss_meter.update(ft_train_loss)
                ft_train_acc_meter.update(ft_train_acc)
                ft_val_loss, ft_val_acc, hypg = 0, 0, 0

            else:
                pt_loss, pt_acc = do_pretrain(student, head, teacher, pretrain_optim, batch)
                pt_loss_meter.update(pt_loss)
                pt_acc_meter.update(pt_acc)
                if steps % args.pretrain_steps == 0:
                    # Every args.pretrain_steps PT steps: run a differentiable
                    # FT inner loop with `higher`, then take one hypergradient
                    # step on the task weights.
                    with higher.innerloop_ctx(head, finetune_optim, copy_initial_weights=True) as (fnet, diffopt):
                        (ft_train_loss, ft_train_acc), (ft_val_loss, ft_val_acc), ft_grad, fnet = \
                            inner_loop_finetune(student, fnet, teacher, diffopt, train_dl, val_dl, num_finetune_steps)
                        head.load_state_dict(fnet.state_dict())

                    ft_grad = gather_flat_grad(ft_grad)
                    # Current PT learning rate, needed by the Neumann-series
                    # approximation of the inverse Hessian inside hyper_step.
                    for param_group in pretrain_optim.param_groups:
                        cur_lr = param_group['lr']
                        break

                    hypg = hyper_step(student, head, teacher, hyp_params, pretrain_dl,
                                      pretrain_optim, ft_grad, cur_lr, num_neumann_steps)
                    hypg = hypg.norm().item()

                    hyp_optim.step()
                    ft_train_loss_meter.update(ft_train_loss)
                    ft_train_acc_meter.update(ft_train_acc)
                    ft_val_loss_meter.update(ft_val_loss)
                    ft_val_acc_meter.update(ft_val_acc)
                else:
                    ft_train_loss, ft_train_acc = do_ft_head(student, head, finetune_optim, train_dl)
                    ft_train_loss_meter.update(ft_train_loss)
                    ft_train_acc_meter.update(ft_train_acc)
                    ft_val_loss, ft_val_acc, hypg = 0, 0, 0

            steps += 1
            progress_bar.set_postfix(
                pt_l='%.4f' % pt_loss_meter.avg,
                pt_a='%.4f' % pt_acc_meter.avg,
                ft_tl='%.4f' % ft_train_loss_meter.avg,
                ft_ta='%.4f' % ft_train_acc_meter.avg,
                ft_vl='%.4f' % ft_val_loss_meter.avg,
                ft_va='%.4f' % ft_val_acc_meter.avg,
                hyp_norm='%.6f' % hypg,
            )

            # append per-step stats to the loss dicts
            stud_pretrain_ld['loss'].append(pt_loss)
            stud_pretrain_ld['acc'].append(pt_acc)
            stud_finetune_train_ld['loss'].append(ft_train_loss)
            stud_finetune_train_ld['acc'].append(ft_train_acc)
            stud_finetune_val_ld['loss'].append(ft_val_loss)
            stud_finetune_val_ld['acc'].append(ft_val_acc)

        # End-of-epoch evaluation on both FT test splits and the FT val split.
        ft_test_ld = eval_student(student, head, test_dl_easy)
        print("Easy test")
        print(ft_test_ld)
        stud_finetune_test_ld_easy = update_lossdict(stud_finetune_test_ld_easy, ft_test_ld)

        ft_test_ld = eval_student(student, head, test_dl_hard)
        print("Hard test")
        print(ft_test_ld)
        stud_finetune_test_ld_hard = update_lossdict(stud_finetune_test_ld_hard, ft_test_ld)

        ft_val_ld = eval_student(student, head, val_dl)
        print("Val")
        print(ft_val_ld)
        stud_finetune_val_ld = update_lossdict(stud_finetune_val_ld, ft_val_ld)

        if pretrain_val_loader is not None:
            print("Evaluating on pretrain val set")
            eval_student(student, None, pretrain_val_loader)

        tosave = {
            'pretrain_ld': stud_pretrain_ld,
            'finetune_train_ld': stud_finetune_train_ld,
            'finetune_val_ld': stud_finetune_val_ld,
            'finetune_test_ld_easy': stud_finetune_test_ld_easy,
            'finetune_test_ld_hard': stud_finetune_test_ld_hard,
        }
        torch.save(tosave, os.path.join(save_path, 'logs.ckpt'))
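        # Periodic checkpointing: model_saver writes the student, head,
        # teacher, and all three optimizer states, so a run can resume from
        # --checkpoint with the epoch recovered from the filename above.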
        if n % 20 == 0 or n == args.epochs - 1:
            model_saver(n, student, head, teacher, pretrain_optim, finetune_optim, hyp_optim, save_path)
            print(f"Saved model at epoch {n}")

    return student, head, teacher, pretrain_optim, finetune_optim, hyp_optim


res = train(pretrain_loader, finetune_train_loader, finetune_val_loader,
            finetune_test_loader_easy, finetune_test_loader_hard)

if args.save:
    save_path = get_save_path()
    student, head, teacher, pretrain_optim, finetune_optim, hyp_optim = res
    model_saver(args.epochs, student, head, teacher, pretrain_optim, finetune_optim, hyp_optim, save_path)
--------------------------------------------------------------------------------