├── .gitignore ├── 01-ddp_byol.py ├── 02-eval_byot_svm_alldata.py ├── README.md ├── byol_trainer.py ├── checkpoints └── checkpoints_test │ └── test.pth ├── config └── normal_sch1.yaml ├── down_stream └── xuanwu_pd.csv ├── masking_generator.py ├── network_dataset.py ├── res.txt └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /01-ddp_byol.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import yaml 3 | from network_dataset import Task1Data 4 | import os 5 | import time 6 | from torch.utils.tensorboard import SummaryWriter 7 | from torch.utils.data.dataloader import DataLoader 8 | import torch.nn.functional as F 9 | import torch.distributed as dist 10 | import argparse 11 | from byol_trainer import BYOLTrainer 12 | from torch.nn.parallel import DistributedDataParallel as DDP 13 | from torch.utils.data.distributed import DistributedSampler 14 | 15 | from timm.utils import NativeScaler 16 | 17 | import random 18 | import numpy as np 19 | 20 | seed = 42 21 | random.seed(seed) 22 | os.environ['PYTHONHASHSEED'] = str(seed) 23 | np.random.seed(seed) 24 | torch.manual_seed(seed) 25 | torch.cuda.manual_seed(seed) 26 | torch.cuda.manual_seed_all(seed) 27 | torch.backends.cudnn.deterministic = True 28 | torch.backends.cudnn.benchmark = False 29 | 30 | def get_parse(): 31 | 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument("--no_ddp", action = "store_true",default = False) 34 | parser.add_argument("--resume", action = "store_true",default = False) 35 | parser.add_argument("--config", type=str,default = "") 36 | parser.add_argument("--model_path", type=str,default = "") 37 | parser.add_argument("--use_ddp",type=bool) 38 | parser.add_argument("--local-rank", default=-1) 39 | 40 | FLAGS = parser.parse_args() 41 | FLAGS.use_ddp = not FLAGS.no_ddp 42 | return FLAGS 43 | 44 | def init_ddp(FLAGS): 45 | local_rank = FLAGS.local_rank 46 | torch.cuda.set_device(int(local_rank)) 47 | dist.init_process_group(backend='nccl') 48 | 49 | def adjust_learning_rate(optimizer, epoch, final_lr,warmup_epochs): 50 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 51 | 52 | lr_warmup = [i / warmup_epochs * final_lr for i in range(1, int(warmup_epochs +1))] 53 | if epoch < warmup_epochs: 54 | for param_group in optimizer.param_groups: 55 | param_group['lr'] = lr_warmup[epoch] 56 | else: 57 | for param_group in optimizer.param_groups: 58 | param_group['lr'] = final_lr 59 | 60 | def main(): 61 | FLAGS = get_parse() 62 | if FLAGS.use_ddp is True: 63 | print("Init ddp") 64 | init_ddp(FLAGS) 65 | 66 | config = yaml.load(open(FLAGS.config, "r"), Loader = yaml.FullLoader) 67 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 68 | data_path = str(config['data']['path']) 69 | csv_path = str(config['data']['csv']) 70 | time_len = int(config['data']['time_len']) 71 | mask_len = int(config['data']['time_mask']) 72 | mask_way = str(config['data']['mask_way']) 73 | train_dataset = Task1Data(root = data_path,csv = csv_path, mask_way=mask_way,mask_len=mask_len, time_len=time_len) 74 | if FLAGS.use_ddp is True: 75 | train_sampler = DistributedSampler(train_dataset) 76 | train_loader = DataLoader(train_dataset, batch_size=config['trainer']['batch_size'],num_workers=config['trainer']['num_workers'],pin_memory=True,sampler = train_sampler) 77 | else: 78 | train_sampler = None 79 | train_loader = DataLoader(train_dataset, batch_size=config['trainer']['batch_size'],num_workers=config['trainer']['num_workers']) 80 | 81 | feature_size = config['network']['feature_dim'] 82 | depth = config['network']['depth'] 83 | heads = config['network']['heads'] 84 | dim_feedforward = config['network']['dim_feedforward'] 85 | mm = str(config['network']['mm']) 86 | clf_mask = int(config['network']['clf_mask']) 87 | mse_mask = int(config['network']['mse_mask']) 88 | model = BYOLTrainer(depth,heads,config['trainer']['m'],feature_size,dim_feedforward,mm=mm, clf_mask = clf_mask, mse_mask = mse_mask) 89 | model.cuda().train() 90 | 91 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 92 | print(f"Model has parameters: {n_parameters / 1e6}M") 93 | model_checkpoints_folder = config['saving']['checkpoint_dir'] 94 | 95 | log_dir = config['saving']['log_dir'] 96 | 97 | if FLAGS.resume is True: 98 | checkpoint = torch.load(FLAGS.model_path, map_location = 'cpu') 99 | model.load_state_dict(checkpoint['model']) 100 | 101 | if FLAGS.use_ddp is True: 102 | model = DDP(model, find_unused_parameters=True) 103 | model_call = model.module 104 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 105 | else: 106 | model_call = model 107 | 108 | optimizer = torch.optim.AdamW(model_call.get_parameters(),lr= config['optimizer']['lr'],weight_decay=config['optimizer']['weight_decay']) 109 | 110 | if not os.path.exists(model_checkpoints_folder) and dist.get_rank() == 0: 111 | os.makedirs(model_checkpoints_folder) 112 | 113 | loss_scaler = NativeScaler() 114 | 115 | if FLAGS.use_ddp and dist.get_rank() == 0: 116 | writer = SummaryWriter(log_dir=log_dir) 117 | else: 118 | writer = None 119 | 120 | model_call.initialize_target() 121 | best_train_loss = 99999. 122 | 123 | acc_lambda = float(config['trainer']['acc_lambda']) 124 | mse_lambda = float(config['trainer']['mse_lambda']) 125 | warmup_epochs = int(config['trainer']['warmup_epochs']) 126 | 127 | # for test 128 | #for epoch_counter in range(1): 129 | for epoch_counter in range(config['trainer']['max_epochs']): 130 | epoch_counter = epoch_counter #+ epoch_from 131 | model.train() 132 | if FLAGS.use_ddp is True: 133 | train_loader.sampler.set_epoch(epoch_counter) 134 | if FLAGS.resume is not True: 135 | adjust_learning_rate(optimizer,epoch_counter,config['optimizer']['lr'],warmup_epochs) 136 | 137 | header = 'Epoch: [{}]'.format(epoch_counter) 138 | 139 | n_steps = 50 140 | 141 | total_loss = 0. 142 | byol_loss = 0. 143 | nce_loss = 0. 144 | mcl_acc = 0. 145 | 146 | count = 0. 147 | niter = 0. 148 | 149 | st = time.time() 150 | calc_st = time.time() 151 | 152 | for step, (batch_view_1, batch_view_2) in enumerate(train_loader): 153 | B = len(batch_view_1) 154 | 155 | batch_view_1 = batch_view_1.to(device,non_blocking=True).float() 156 | batch_view_2 = batch_view_2.to(device,non_blocking=True).float() 157 | 158 | loss_byol, acc, nce, mse = model(batch_view_1, batch_view_2) 159 | 160 | if mm == 'byol': 161 | loss = loss_byol 162 | elif mm == 'byol+clf': 163 | loss = loss_byol + acc_lambda* nce 164 | elif mm == 'byol+mse': 165 | loss = loss_byol + mse_lambda* mse 166 | else: 167 | loss = loss_byol + acc_lambda* nce + mse_lambda * mse 168 | 169 | optimizer.zero_grad() 170 | 171 | loss_scaler(loss, optimizer, parameters=model.parameters(),clip_grad=1,clip_mode='value') 172 | 173 | #loss.backward() 174 | #optimizer.step() 175 | 176 | model_call.update_target() # update the key encoder 177 | 178 | total_loss += len(batch_view_1) * float(loss) 179 | byol_loss += len(batch_view_1) * float(loss_byol) 180 | nce_loss += len(batch_view_1) * float(nce) 181 | mcl_acc += len(batch_view_1) * float(acc) 182 | 183 | count += len(batch_view_1) 184 | if FLAGS.use_ddp is False or (FLAGS.use_ddp is True and dist.get_rank() == 0): 185 | if step % n_steps == 0: 186 | end = time.time() 187 | print(f"Epoch: {epoch_counter} [{step}/{len(train_loader)}]: byol: {loss_byol:.5f}, nce: {nce:.5f}, mse: {mse:.5f} time: {end-st}") 188 | st = time.time() 189 | if step %n_steps == 0 and step != 0: 190 | if FLAGS.use_ddp is False or (FLAGS.use_ddp is True and dist.get_rank() == 0): 191 | need_time = (time.time() - calc_st) /n_steps * len(train_loader)/60./60. 192 | print(f"precalc time: {need_time} hours by batch: {config['trainer']['batch_size']}") 193 | calc_st = time.time() 194 | total_loss /= count 195 | byol_loss /= count 196 | nce_loss /= count 197 | mcl_acc /= count 198 | 199 | 200 | if writer is not None: 201 | 202 | writer.add_scalar('Acc', mcl_acc, global_step=epoch_counter) 203 | writer.add_scalar('Nce', nce_loss, global_step=epoch_counter) 204 | writer.add_scalar('MSE', mse, global_step=epoch_counter) 205 | writer.add_scalar('byol_loss', byol_loss, global_step=epoch_counter) 206 | writer.add_scalar('total_loss', total_loss, global_step=epoch_counter) 207 | 208 | if total_loss <= best_train_loss: 209 | best_train_loss = total_loss 210 | model_call.save_model(os.path.join(model_checkpoints_folder, 'best_model.pth')) 211 | 212 | if epoch_counter % config['saving']['n_epochs'] == 0: 213 | model_call.save_model(os.path.join(model_checkpoints_folder, f'model_{epoch_counter}.pth')) 214 | 215 | # save checkpoints 216 | model_call.save_model(os.path.join(model_checkpoints_folder, 'last_model.pth')) 217 | 218 | if __name__ == '__main__': 219 | main() -------------------------------------------------------------------------------- /02-eval_byot_svm_alldata.py: -------------------------------------------------------------------------------- 1 | from network_dataset import Task2Data 2 | import torch 3 | import yaml 4 | import numpy as np 5 | import os 6 | from sklearn import preprocessing 7 | from torch.utils.data.dataloader import DataLoader 8 | from utils import BNTF, MLPHead 9 | import random 10 | import argparse 11 | from sklearn.svm import SVC 12 | import torch 13 | from sklearn.metrics import accuracy_score, confusion_matrix, roc_auc_score 14 | from utils import get_data 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--seed",'-s', type=int,default = 42) 18 | parser.add_argument("--config_path",'-c', type=str,default = "") 19 | parser.add_argument("--data",'-d', type=str,default = "") 20 | parser.add_argument("--csv",'-f', type=str,default = "") 21 | 22 | args = parser.parse_args() 23 | args = get_data(args) 24 | 25 | 26 | shuffle_seed = int(args.seed) #42 27 | batch_size = 64 28 | seed = 42 29 | random.seed(seed) 30 | os.environ['PYTHONHASHSEED'] = str(seed) 31 | np.random.seed(seed) 32 | torch.manual_seed(seed) 33 | torch.cuda.manual_seed(seed) 34 | torch.cuda.manual_seed_all(seed) 35 | torch.backends.cudnn.deterministic = True 36 | torch.backends.cudnn.benchmark = False 37 | 38 | config_path = args.config_path 39 | 40 | 41 | config = yaml.load(open(config_path, "r"), Loader=yaml.FullLoader) 42 | 43 | root = "/path/to/alldata" 44 | mask_way=config['data']['mask_way'] 45 | mask_len=int(config['data']['time_mask']) 46 | time_len=int(config['data']['time_len']) 47 | 48 | train_dataset = Task2Data(root, args.csv, mask_way,mask_len,time_len,shuffle_seed=shuffle_seed,is_train=True,is_test=False) 49 | val_dataset = Task2Data(root, args.csv,mask_way,mask_len,time_len,shuffle_seed=shuffle_seed,is_train=False,is_test=False) 50 | test_dataset = Task2Data(root, args.csv,mask_way,mask_len,time_len,shuffle_seed=shuffle_seed,is_train=False,is_test=True) 51 | 52 | train_loader = DataLoader(train_dataset, batch_size=batch_size, 53 | num_workers=4, drop_last=False, shuffle=True) 54 | 55 | val_loader = DataLoader(val_dataset, batch_size=batch_size, 56 | num_workers=4, drop_last=False, shuffle=False) 57 | 58 | test_loader = DataLoader(test_dataset, batch_size=batch_size, 59 | num_workers=4, drop_last=False, shuffle=False) 60 | 61 | 62 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 63 | feature_size = config['network']['feature_dim'] 64 | depth = config['network']['depth'] 65 | heads = config['network']['heads'] 66 | dim_feedforward = config['network']['dim_feedforward'] 67 | 68 | encoder = BNTF(feature_size,depth,heads,dim_feedforward).to(device) 69 | 70 | # test.pth for test. best_model.pth should be used 71 | load_params = torch.load(os.path.join(config['saving']['checkpoint_dir'],'test.pth'), map_location='cpu')['online_network_state_dict'] 72 | encoder.load_state_dict(load_params) 73 | print("Parameters successfully loaded.") 74 | 75 | encoder = encoder.to(device) 76 | def get_features_from_encoder(encoder, loader,times = 1): 77 | 78 | x_train = [] 79 | y_train = [] 80 | encoder.eval() 81 | for _ in range(times): 82 | # TTA is not used 83 | for i, (x, y) in enumerate(loader): 84 | x = x.to(device).float() 85 | y = y.to(device).long() 86 | with torch.no_grad(): 87 | bz, _, _, = x.shape 88 | for atten in encoder.attention_list: 89 | x = atten(x) 90 | node_feature = encoder.dim_reduction(x) 91 | feature_vector = node_feature.reshape((bz, -1)) 92 | x_train.extend(feature_vector.detach()) 93 | y_train.extend(y.detach()) 94 | x_train = torch.stack(x_train).detach().cpu() 95 | y_train = torch.stack(y_train).detach().cpu() 96 | return x_train, y_train 97 | 98 | encoder.eval() 99 | clf = SVC(probability=True) # 5 100 | 101 | x_train, y_train = get_features_from_encoder(encoder, train_loader,1) 102 | x_val, y_val = get_features_from_encoder(encoder, val_loader,1) 103 | x_test, y_test = get_features_from_encoder(encoder, test_loader,1) 104 | print("Loading features loaded.") 105 | 106 | x_train = x_train.detach().cpu().numpy() 107 | x_val = x_val.detach().cpu().numpy() 108 | x_test = x_test.detach().cpu().numpy() 109 | 110 | scaler = preprocessing.StandardScaler() 111 | scaler.fit(x_train) 112 | 113 | x_train = scaler.transform(x_train).astype(np.float32) 114 | x_val = scaler.transform(x_val).astype(np.float32) 115 | x_test = scaler.transform(x_test).astype(np.float32) 116 | 117 | clf.fit(x_train, y_train.detach().cpu().numpy()[:,1]) 118 | pred_test = clf.predict(x_test) 119 | 120 | acc = accuracy_score(pred_test, y_test.detach().cpu().numpy()[:,1]) 121 | cm = confusion_matrix(pred_test, y_test.detach().cpu().numpy()[:,1]) 122 | sen = round(cm[1, 1] / float(cm[1, 1]+cm[1, 0]),4) 123 | spe = round(cm[0, 0] / float(cm[0, 0]+cm[0, 1]),4) 124 | 125 | res_string = f"acc: {acc:.4f} sen: {sen:.4f} spe: {spe:.4f}" 126 | print(res_string) 127 | with open(f"res.txt", 'a') as f: 128 | f.write(f"data:[{args.data}] \t seed:[{shuffle_seed}] \t {res_string} \n") -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BrainMass (Brain network analysis via mask modeling and feature Alignment by Self-Supervised learning) 2 | 3 | # Brain network foundation model 4 | 5 | ## Pre-training: 6 | 7 | For pre-training, you should change the "path" and "csv" in the config file (e.g., config/normal_sch1.yaml) 8 | 9 | The "path" is the path of timeseries data 10 | 11 | The csv is the list of pre-training files, like: 12 | 13 | file, 14 | ukb_001, 15 | ukb_002, 16 | ... 17 | 18 | (This code can also be trained for a single dataset and achieves promising performance.) 19 | 20 | you can run pre-training like: 21 | ```shell 22 | python -m torch.distributed.launch --nproc_per_node=4 01-ddp_byol.py --config config/normal_sch1.yaml 23 | ``` 24 | 25 | or use the checkpoint to pre-training like: 26 | ```shell 27 | python -m torch.distributed.launch --nproc_per_node=4 01-ddp_byol.py --config config/normal_sch1.yaml --resume --model_path checkpoints/checkpoints_test/test.pth 28 | ``` 29 | 30 | 31 | ## Downstream evaluation: 32 | 1. You need to change the root of the evaluation files in "02-eval_byot_svm_alldata.py" 43 to your path: 33 | 34 | ```python 35 | root = "/path/to/alldata" 36 | ``` 37 | 2. You need to prepare your own data in the down_stream folder: 38 | 39 | These csv files are like: 40 | 41 | new_nam,dx,is_train,site 42 | xuanwu_001,0,1,0 43 | xuanwu_002,1,0,1 44 | ... 45 | 46 | 3. For downstream evaluation, you can run like: 47 | ```shell 48 | python 02-eval_byot_svm_alldata.py -c config/normal_sch1.yaml -d abide1 -f checkpoints/checkpoints_test/test.pth 49 | ``` 50 | 51 | 52 | We find that different fMRI preprocessing steps might also lead to performance fluctuations. 53 | -------------------------------------------------------------------------------- /byol_trainer.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.distributed as dist 6 | from utils import BNTF, MLPHead, get_sinusoid_encoding_table 7 | 8 | 9 | def accuracy(output,target,top_k=(1,)): 10 | """Computes the precision@k for the specified values of k""" 11 | max_k = max(top_k) 12 | batch_size = target.size(0) 13 | 14 | _, predict = output.topk(max_k, 1, True, True) 15 | predict = predict.t() 16 | correct = predict.eq(target.view(1, -1).expand_as(predict)) 17 | 18 | res = [] 19 | for k in top_k: 20 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 21 | res.append(correct_k.mul_(100.0 / batch_size).item()) 22 | return res 23 | 24 | class BYOLTrainer(nn.Module): 25 | def __init__(self,depth,heads,m,feature_dim,dim_feedforward,mm,clf_mask=10, mse_mask=5): 26 | super().__init__() 27 | self.m = m 28 | self.model_mode = mm 29 | self.online_network = BNTF(feature_dim,depth,heads,dim_feedforward) 30 | self.target_network = BNTF(feature_dim,depth,heads,dim_feedforward) 31 | self.predictor = MLPHead(feature_dim, feature_dim*2,feature_dim) 32 | 33 | roi_num = 100 34 | self.roi_num = roi_num 35 | 36 | # for mcl 37 | self.mcl_mask = clf_mask 38 | self.mrm_mask = mse_mask 39 | self.token_num = 2 40 | 41 | self.mask_embed = nn.Parameter(torch.zeros([1, 1, roi_num])) 42 | self.cls_token = nn.Parameter(torch.zeros(1, 1, roi_num)) 43 | self.dist_token = nn.Parameter(torch.zeros(1, 1, roi_num)) 44 | 45 | self.pos_embed = get_sinusoid_encoding_table(roi_num + self.token_num, roi_num) 46 | 47 | self.norm = nn.LayerNorm(roi_num) 48 | 49 | self.cpred = nn.Sequential(nn.Linear(roi_num, dim_feedforward), nn.LeakyReLU(), nn.Linear(dim_feedforward, roi_num)) #256 50 | self.gpred = nn.Sequential(nn.Linear(roi_num, dim_feedforward), nn.LeakyReLU(), nn.Linear(dim_feedforward, roi_num)) #1024 51 | 52 | self.softmax = nn.Softmax(dim=-1) 53 | self.lsoftmax = nn.LogSoftmax(dim=-1) 54 | 55 | self.init_params() 56 | self.check_values() 57 | 58 | 59 | def init_params(self): 60 | self.mask_embed = torch.nn.init.xavier_normal_(self.mask_embed) 61 | self.cls_token = torch.nn.init.xavier_normal_(self.cls_token) 62 | self.dist_token = torch.nn.init.xavier_normal_(self.dist_token) 63 | 64 | def check_values(self): 65 | if self.model_mode not in ["byol","byol+clf","byol+mse","byol+clf+mse",'moco', 'rrp',"moco+clf+mse"]: 66 | raise KeyError(f"{self.model_mode} value error, should be in [byol,byol+clf,byol+mse,byol+clf+mse]") 67 | 68 | if self.mcl_mask > self.roi_num: 69 | raise KeyError(f"{self.mcl_mask} value error, mcl_mask should be smaller than roi_num") 70 | 71 | if self.mrm_mask > self.roi_num: 72 | raise KeyError(f"{self.mrm_mask} value error, mrm_mask should be smaller than roi_num") 73 | 74 | @staticmethod 75 | def regression_loss(x, y): 76 | x = F.normalize(x, dim=1) 77 | y = F.normalize(y, dim=1) 78 | return 2 - 2 * (x * y).sum(dim=-1) 79 | 80 | def forward_mcl(self,x): 81 | B, T, C = x.shape 82 | mask = self.mcl_mask 83 | device = x.device 84 | token_num = 2 85 | 86 | encode_samples = torch.empty((B,mask,C),device = device, requires_grad=False).float() 87 | mask_index = torch.empty((B,mask),device = device, requires_grad=False).long() 88 | mask_dense = torch.ones([B,T,C],device = device) 89 | 90 | for i in range(B): 91 | mask_id = torch.tensor(random.sample(range(0, T), mask)) 92 | mask_index[i] = mask_id 93 | encode_samples[i] = x[i, mask_index[i]].clone().detach() 94 | mask_dense[i,mask_index[i]] = 0 95 | 96 | mask_tokens = self.mask_embed.expand(B, T, -1) 97 | 98 | new_x = x * mask_dense + (1-mask_dense) * mask_tokens 99 | 100 | cls_tokens = self.cls_token.expand(B, -1, -1) 101 | dist_tokens = self.dist_token.expand(B, -1, -1) 102 | new_x = torch.cat((cls_tokens,dist_tokens, new_x), dim=1) 103 | new_x = new_x + self.pos_embed.type_as(new_x).to(x.device).clone().detach() 104 | 105 | x_vis = self.online_network(new_x,forward_with_mlp=False) 106 | 107 | pred = torch.empty((B,mask, C),device = device).float() 108 | for i in range(B): 109 | pred[i]=self.cpred(x_vis[i,mask_index[i]+token_num]) 110 | 111 | nce = torch.tensor(0.0).to(device) 112 | correct = torch.tensor(0.0).to(device) 113 | 114 | for i in range(B): 115 | total = torch.mm(encode_samples[i], torch.transpose(pred[i], 0, 1)) # e.g. size 100*100 116 | correct += torch.sum(torch.eq(torch.argmax(self.softmax(total), dim=0), torch.arange(0, mask, device=device))) # correct is a tensor 117 | nce += torch.sum(torch.diag(self.lsoftmax(total))) # nce is a tensor 118 | 119 | acc = 1. * correct / (B * mask) 120 | nce = nce / (-1. * B * mask) 121 | 122 | return acc, nce 123 | 124 | def forward_mrm(self,x): 125 | B, T, C = x.shape 126 | mask = self.mrm_mask 127 | device = x.device 128 | token_num = 2 129 | 130 | mask_index = torch.empty((B,mask),device = device, requires_grad=False).long() 131 | mask_dense = torch.ones([B,T,C],device = device) 132 | 133 | for i in range(B): 134 | mask_id = torch.tensor(random.sample(range(0, T), mask)) 135 | mask_index[i] = mask_id 136 | mask_dense[i,mask_index[i]] = 0 137 | 138 | mask_tokens = self.mask_embed.expand(B, T, -1) 139 | 140 | new_x = x * mask_dense + (1-mask_dense) * mask_tokens 141 | 142 | cls_tokens = self.cls_token.expand(B, -1, -1) 143 | dist_tokens = self.dist_token.expand(B, -1, -1) 144 | new_x = torch.cat((cls_tokens,dist_tokens, new_x), dim=1) 145 | new_x = new_x + self.pos_embed.type_as(new_x).to(x.device).clone().detach() 146 | 147 | x_vis = self.online_network(new_x,forward_with_mlp=False) 148 | 149 | pred = torch.empty((B,mask, C),device = device).float() 150 | target = torch.empty((B,mask, C),device = device).float() 151 | for i in range(B): 152 | pred[i]=self.gpred(x_vis[i,mask_index[i]+token_num]) 153 | target[i] = x[i, mask_index[i], :] 154 | 155 | mse = torch.mean((pred - target) ** 2) 156 | 157 | return mse 158 | 159 | def forward(self, batch_view_1, batch_view_2,returns = 'all'): 160 | 161 | if self.model_mode == 'byol': 162 | predictions_from_view_1 = self.predictor(self.online_network(batch_view_1)) 163 | predictions_from_view_2 = self.predictor(self.online_network(batch_view_2)) 164 | acc = 0. 165 | nce = 0. 166 | mse = 0. 167 | with torch.no_grad(): 168 | targets_to_view_2 = self.target_network(batch_view_1) 169 | targets_to_view_1 = self.target_network(batch_view_2) 170 | 171 | loss = self.regression_loss(predictions_from_view_1, targets_to_view_1) 172 | loss += self.regression_loss(predictions_from_view_2, targets_to_view_2) 173 | 174 | return loss.mean(),acc,nce,mse 175 | elif self.model_mode == 'rrp': 176 | predictions_from_view_1 = self.online_network(batch_view_1) 177 | output = self.clf(predictions_from_view_1) 178 | loss = nn.CrossEntropyLoss()(output, batch_view_2) 179 | acc = accuracy(output, batch_view_2[:, 1])[0] 180 | return loss.mean(), acc, 0., 0. 181 | elif self.model_mode == 'moco': 182 | predictions_from_view_1 = self.predictor(self.online_network(batch_view_1)) 183 | predictions_from_view_2 = self.predictor(self.online_network(batch_view_2)) 184 | acc = 0. 185 | nce = 0. 186 | mse = 0. 187 | with torch.no_grad(): 188 | targets_to_view_2 = self.target_network(batch_view_1) 189 | targets_to_view_1 = self.target_network(batch_view_2) 190 | 191 | loss = self.contrastive_loss(predictions_from_view_1, targets_to_view_1) 192 | loss += self.contrastive_loss(predictions_from_view_2, targets_to_view_2) 193 | 194 | return loss.mean(),acc,nce,mse 195 | 196 | elif self.model_mode == 'byol+clf': 197 | acc,nce = self.forward_mcl(batch_view_1) 198 | 199 | mse = 0. 200 | 201 | predictions_from_view_1 = self.predictor(self.online_network(batch_view_1)) 202 | predictions_from_view_2 = self.predictor(self.online_network(batch_view_2)) 203 | 204 | with torch.no_grad(): 205 | targets_to_view_2 = self.target_network(batch_view_1) 206 | targets_to_view_1 = self.target_network(batch_view_2) 207 | 208 | loss = self.contrastive_loss(predictions_from_view_1, targets_to_view_1) 209 | loss += self.contrastive_loss(predictions_from_view_2, targets_to_view_2) 210 | 211 | return loss.mean(),acc,nce,mse 212 | 213 | elif self.model_mode == 'byol+mse': 214 | acc, nce = 0., 0. 215 | 216 | mse = self.forward_mrm(batch_view_1) 217 | 218 | predictions_from_view_1 = self.predictor(self.online_network(batch_view_1)) 219 | predictions_from_view_2 = self.predictor(self.online_network(batch_view_2)) 220 | 221 | with torch.no_grad(): 222 | targets_to_view_2 = self.target_network(batch_view_1) 223 | targets_to_view_1 = self.target_network(batch_view_2) 224 | 225 | loss = self.regression_loss(predictions_from_view_1, targets_to_view_1) 226 | loss += self.regression_loss(predictions_from_view_2, targets_to_view_2) 227 | 228 | return loss.mean(),acc,nce,mse 229 | 230 | elif self.model_mode == 'byol+clf+mse': 231 | acc,nce = self.forward_mcl(batch_view_1) 232 | 233 | mse = self.forward_mrm(batch_view_1) 234 | 235 | predictions_from_view_1 = self.predictor(self.online_network(batch_view_1)) 236 | predictions_from_view_2 = self.predictor(self.online_network(batch_view_2)) 237 | 238 | with torch.no_grad(): 239 | targets_to_view_2 = self.target_network(batch_view_1) 240 | targets_to_view_1 = self.target_network(batch_view_2) 241 | 242 | loss = self.regression_loss(predictions_from_view_1, targets_to_view_1) 243 | loss += self.regression_loss(predictions_from_view_2, targets_to_view_2) 244 | 245 | return loss.mean(),acc,nce,mse 246 | 247 | elif self.model_mode == 'moco+clf+mse': 248 | acc,nce = self.forward_mcl(batch_view_1) 249 | 250 | mse = self.forward_mrm(batch_view_1) 251 | 252 | predictions_from_view_1 = self.predictor(self.online_network(batch_view_1)) 253 | predictions_from_view_2 = self.predictor(self.online_network(batch_view_2)) 254 | 255 | with torch.no_grad(): 256 | targets_to_view_2 = self.target_network(batch_view_1) 257 | targets_to_view_1 = self.target_network(batch_view_2) 258 | 259 | loss = self.contrastive_loss(predictions_from_view_1, targets_to_view_1) 260 | loss += self.contrastive_loss(predictions_from_view_2, targets_to_view_2) 261 | 262 | return loss.mean(),acc,nce,mse 263 | 264 | def contrastive_loss(self, q, k): 265 | # normalize 266 | q = nn.functional.normalize(q, dim=1) 267 | k = nn.functional.normalize(k, dim=1) 268 | # gather all targets 269 | k = concat_all_gather(k) 270 | # Einstein sum is more intuitive 271 | logits = torch.einsum('nc,mc->nm', [q, k]) / 0.07 272 | N = logits.shape[0] # batch size per GPU 273 | labels = (torch.arange(N, dtype=torch.long) + N * torch.distributed.get_rank()).cuda() 274 | return nn.CrossEntropyLoss()(logits, labels) * (2 * 0.07) 275 | 276 | @torch.no_grad() 277 | def initialize_target(self): 278 | for param_q, param_k in zip(self.online_network.parameters(), self.target_network.parameters()): 279 | param_k.data.copy_(param_q.data) # initialize 280 | param_k.requires_grad = False 281 | 282 | @torch.no_grad() 283 | def update_target(self): 284 | for param_q, param_k in zip(self.online_network.parameters(), self.target_network.parameters()): 285 | param_k.data = param_k.data * self.m + param_q.data * (1. - self.m) 286 | 287 | def get_parameters(self): 288 | return list(self.online_network.parameters()) + list(self.predictor.parameters()) 289 | 290 | def save_model(self, path): 291 | try: 292 | if dist.get_rank() == 0: 293 | torch.save({ 294 | 'model':self.state_dict(), 295 | 'online_network_state_dict': self.online_network.state_dict(), 296 | 'target_network_state_dict': self.target_network.state_dict(), 297 | }, path) 298 | except: 299 | torch.save({ 300 | 'model':self.state_dict(), 301 | 'online_network_state_dict': self.online_network.state_dict(), 302 | 'target_network_state_dict': self.target_network.state_dict(), 303 | }, path) 304 | 305 | @torch.no_grad() 306 | def concat_all_gather(tensor): 307 | """ 308 | Performs all_gather operation on the provided tensors. 309 | *** Warning ***: torch.distributed.all_gather has no gradient. 310 | """ 311 | tensors_gather = [torch.ones_like(tensor) 312 | for _ in range(torch.distributed.get_world_size())] 313 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 314 | 315 | output = torch.cat(tensors_gather, dim=0) 316 | return output -------------------------------------------------------------------------------- /checkpoints/checkpoints_test/test.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/podismine/BrainMass/6543c9869cbb7f5369ff8f69b6dea5875f85df02/checkpoints/checkpoints_test/test.pth -------------------------------------------------------------------------------- /config/normal_sch1.yaml: -------------------------------------------------------------------------------- 1 | network: 2 | mm: byol+clf+mse 3 | feature_dim: 1024 4 | depth: 16 5 | heads: 10 6 | dim_feedforward: 2048 7 | clf_mask: 10 8 | mse_mask: 5 9 | roi_num: 100 10 | 11 | saving: 12 | log_dir: logs/log_test 13 | checkpoint_dir: checkpoints/checkpoints_test 14 | n_epochs: 100 15 | 16 | trainer: 17 | batch_size: 64 18 | m: 0.996 # momentum update 19 | max_epochs: 2000 20 | num_workers: 8 21 | acc_lambda: 0.1 22 | mse_lambda: 2 23 | warmup_epochs: 10 24 | 25 | optimizer: 26 | lr: 0.0003 27 | weight_decay: 0.00005 28 | 29 | data: 30 | path: /path/to/timseries 31 | csv: /path/to/pretrain/data.csv 32 | mask_way: mask_per 33 | time_mask: 20 34 | time_len: 30 35 | 36 | -------------------------------------------------------------------------------- /down_stream/xuanwu_pd.csv: -------------------------------------------------------------------------------- 1 | file,dx,old_dx,is_train 2 | xuanwu_hc_00011,0,hc,1 3 | xuanwu_pd_00015,1,pd,0 4 | xuanwu_pd_00001,1,pd,1 5 | xuanwu_hc_00017,0,hc,1 6 | xuanwu_hc_00047,0,hc,1 7 | xuanwu_hc_00063,0,hc,1 8 | xuanwu_pd_00050,1,pd,1 9 | xuanwu_pd_00016,1,pd,1 10 | xuanwu_pd_00010,1,pd,1 11 | xuanwu_pd_00060,1,pd,1 12 | xuanwu_pd_00065,1,pd,1 13 | xuanwu_pd_00034,1,pd,0 14 | xuanwu_pd_00052,1,pd,1 15 | xuanwu_hc_00020,0,hc,0 16 | xuanwu_pd_00029,1,pd,1 17 | xuanwu_pd_00043,1,pd,1 18 | xuanwu_hc_00033,0,hc,1 19 | xuanwu_pd_00072,1,pd,1 20 | xuanwu_hc_00042,0,hc,0 21 | xuanwu_pd_00006,1,pd,1 22 | xuanwu_pd_00033,1,pd,0 23 | xuanwu_hc_00057,0,hc,0 24 | xuanwu_pd_00086,1,pd,0 25 | xuanwu_pd_00032,1,pd,1 26 | xuanwu_pd_00049,1,pd,0 27 | xuanwu_hc_00016,0,hc,1 28 | xuanwu_pd_00011,1,pd,1 29 | xuanwu_hc_00064,0,hc,1 30 | xuanwu_hc_00065,0,hc,1 31 | xuanwu_hc_00004,0,hc,1 32 | xuanwu_pd_00024,1,pd,1 33 | xuanwu_pd_00076,1,pd,1 34 | xuanwu_pd_00046,1,pd,1 35 | xuanwu_pd_00087,1,pd,1 36 | xuanwu_pd_00054,1,pd,1 37 | xuanwu_hc_00001,0,hc,1 38 | xuanwu_hc_00069,0,hc,1 39 | xuanwu_hc_00025,0,hc,1 40 | xuanwu_hc_00038,0,hc,1 41 | xuanwu_hc_00034,0,hc,0 42 | xuanwu_hc_00009,0,hc,1 43 | xuanwu_pd_00059,1,pd,1 44 | xuanwu_pd_00056,1,pd,0 45 | xuanwu_pd_00040,1,pd,1 46 | xuanwu_pd_00022,1,pd,0 47 | xuanwu_hc_00062,0,hc,1 48 | xuanwu_pd_00038,1,pd,1 49 | xuanwu_pd_00075,1,pd,1 50 | xuanwu_pd_00014,1,pd,1 51 | xuanwu_hc_00044,0,hc,0 52 | xuanwu_pd_00069,1,pd,1 53 | xuanwu_pd_00025,1,pd,1 54 | xuanwu_hc_00059,0,hc,0 55 | xuanwu_pd_00053,1,pd,0 56 | xuanwu_pd_00058,1,pd,0 57 | xuanwu_hc_00014,0,hc,1 58 | xuanwu_hc_00073,0,hc,1 59 | xuanwu_hc_00040,0,hc,1 60 | xuanwu_hc_00032,0,hc,1 61 | xuanwu_pd_00026,1,pd,1 62 | xuanwu_hc_00031,0,hc,1 63 | xuanwu_pd_00045,1,pd,1 64 | xuanwu_pd_00019,1,pd,0 65 | xuanwu_hc_00008,0,hc,1 66 | xuanwu_pd_00081,1,pd,1 67 | xuanwu_hc_00046,0,hc,1 68 | xuanwu_hc_00021,0,hc,1 69 | xuanwu_pd_00009,1,pd,0 70 | xuanwu_hc_00050,0,hc,1 71 | xuanwu_hc_00029,0,hc,1 72 | xuanwu_hc_00007,0,hc,0 73 | xuanwu_hc_00005,0,hc,0 74 | xuanwu_hc_00067,0,hc,0 75 | xuanwu_hc_00048,0,hc,1 76 | xuanwu_pd_00003,1,pd,1 77 | xuanwu_pd_00078,1,pd,1 78 | xuanwu_hc_00024,0,hc,0 79 | xuanwu_pd_00064,1,pd,1 80 | xuanwu_hc_00019,0,hc,1 81 | xuanwu_pd_00068,1,pd,1 82 | xuanwu_hc_00071,0,hc,0 83 | xuanwu_pd_00039,1,pd,1 84 | xuanwu_pd_00017,1,pd,1 85 | xuanwu_pd_00012,1,pd,1 86 | xuanwu_pd_00037,1,pd,0 87 | xuanwu_hc_00060,0,hc,0 88 | xuanwu_pd_00073,1,pd,0 89 | xuanwu_hc_00030,0,hc,1 90 | xuanwu_hc_00061,0,hc,1 91 | xuanwu_hc_00028,0,hc,0 92 | xuanwu_pd_00031,1,pd,1 93 | xuanwu_hc_00072,0,hc,1 94 | xuanwu_pd_00074,1,pd,1 95 | xuanwu_pd_00035,1,pd,1 96 | xuanwu_pd_00020,1,pd,0 97 | xuanwu_pd_00008,1,pd,0 98 | xuanwu_pd_00042,1,pd,0 99 | xuanwu_hc_00054,0,hc,1 100 | xuanwu_pd_00023,1,pd,1 101 | xuanwu_hc_00049,0,hc,0 102 | xuanwu_pd_00082,1,pd,1 103 | xuanwu_pd_00041,1,pd,0 104 | xuanwu_hc_00037,0,hc,1 105 | xuanwu_hc_00043,0,hc,0 106 | xuanwu_hc_00052,0,hc,1 107 | xuanwu_pd_00005,1,pd,1 108 | xuanwu_pd_00079,1,pd,1 109 | xuanwu_hc_00026,0,hc,0 110 | xuanwu_hc_00015,0,hc,1 111 | xuanwu_pd_00030,1,pd,1 112 | xuanwu_pd_00080,1,pd,0 113 | xuanwu_hc_00045,0,hc,1 114 | xuanwu_hc_00039,0,hc,1 115 | xuanwu_hc_00053,0,hc,0 116 | xuanwu_pd_00021,1,pd,1 117 | xuanwu_hc_00013,0,hc,0 118 | xuanwu_hc_00056,0,hc,1 119 | xuanwu_pd_00071,1,pd,1 120 | xuanwu_pd_00062,1,pd,0 121 | xuanwu_pd_00055,1,pd,1 122 | xuanwu_pd_00028,1,pd,0 123 | xuanwu_pd_00004,1,pd,1 124 | xuanwu_pd_00007,1,pd,0 125 | xuanwu_hc_00006,0,hc,1 126 | xuanwu_hc_00041,0,hc,1 127 | xuanwu_pd_00070,1,pd,0 128 | xuanwu_hc_00022,0,hc,0 129 | xuanwu_pd_00027,1,pd,1 130 | xuanwu_pd_00002,1,pd,0 131 | xuanwu_pd_00051,1,pd,1 132 | xuanwu_pd_00036,1,pd,1 133 | xuanwu_hc_00002,0,hc,1 134 | xuanwu_pd_00077,1,pd,0 135 | xuanwu_pd_00066,1,pd,1 136 | xuanwu_hc_00023,0,hc,1 137 | xuanwu_pd_00048,1,pd,1 138 | xuanwu_pd_00047,1,pd,0 139 | xuanwu_hc_00036,0,hc,1 140 | xuanwu_pd_00061,1,pd,1 141 | xuanwu_pd_00085,1,pd,1 142 | xuanwu_pd_00084,1,pd,1 143 | xuanwu_pd_00063,1,pd,1 144 | xuanwu_hc_00058,0,hc,1 145 | xuanwu_hc_00055,0,hc,1 146 | xuanwu_hc_00018,0,hc,1 147 | xuanwu_hc_00051,0,hc,1 148 | xuanwu_hc_00003,0,hc,0 149 | xuanwu_hc_00012,0,hc,1 150 | xuanwu_pd_00013,1,pd,1 151 | xuanwu_hc_00035,0,hc,1 152 | xuanwu_pd_00018,1,pd,1 153 | xuanwu_hc_00027,0,hc,0 154 | xuanwu_pd_00067,1,pd,1 155 | xuanwu_pd_00044,1,pd,1 156 | xuanwu_pd_00057,1,pd,0 157 | xuanwu_hc_00010,0,hc,1 158 | -------------------------------------------------------------------------------- /masking_generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class RandomMaskingGenerator: 4 | def __init__(self, input_size, mask_ratio): 5 | if not isinstance(input_size, tuple): 6 | input_size = (input_size,) * 2 7 | 8 | self.height, self.width = input_size 9 | 10 | self.num_patches = self.height * self.width 11 | self.num_mask = int(mask_ratio * self.num_patches) 12 | 13 | def __repr__(self): 14 | repr_str = "Maks: total patches {}, mask patches {}".format( 15 | self.num_patches, self.num_mask 16 | ) 17 | return repr_str 18 | 19 | def __call__(self): 20 | mask = np.hstack([ 21 | np.zeros(self.num_patches - self.num_mask), 22 | np.ones(self.num_mask), 23 | ]) 24 | np.random.shuffle(mask) 25 | return mask # [196] -------------------------------------------------------------------------------- /network_dataset.py: -------------------------------------------------------------------------------- 1 | #coding:utf8 2 | import os 3 | from torch.utils import data 4 | import numpy as np 5 | import nibabel as nib 6 | import random 7 | import pandas as pd 8 | from sklearn.model_selection import StratifiedShuffleSplit 9 | import warnings 10 | from nilearn.connectome import ConnectivityMeasure 11 | 12 | warnings.filterwarnings("ignore") 13 | 14 | class RandomMaskingGenerator: 15 | def __init__(self, input_size, mask_ratio): 16 | 17 | self.input_size = input_size 18 | 19 | self.num_patches = self.input_size 20 | self.num_mask = int(mask_ratio * self.num_patches) 21 | 22 | def __repr__(self): 23 | repr_str = "Maks: total patches {}, mask patches {}".format( 24 | self.num_patches, self.num_mask 25 | ) 26 | return repr_str 27 | 28 | def __call__(self): 29 | mask = np.hstack([ 30 | np.zeros(self.num_patches - self.num_mask), 31 | np.ones(self.num_mask), 32 | ]) 33 | np.random.shuffle(mask) 34 | return mask 35 | 36 | def mask_timeseries(timeser, mask = 30): 37 | rnd = np.random.random() 38 | time_len = timeser.shape[1] 39 | mask_index = np.array(random.sample(list(np.arange(0,time_len)),mask)) 40 | bool_mask = np.zeros((time_len)) 41 | bool_mask[mask_index]=1 42 | bool_mask = bool_mask.astype(bool) 43 | 44 | return timeser[:,~bool_mask] 45 | 46 | def mask_timeseries_per(timeser, mask = 30): 47 | rnd = np.random.random() 48 | 49 | time_len = timeser.shape[1] 50 | mask_len = int(mask * time_len /100) 51 | mask_index = np.array(random.sample(list(np.arange(0,time_len)),mask_len)) 52 | bool_mask = np.zeros((time_len)) 53 | bool_mask[mask_index]=1 54 | bool_mask = bool_mask.astype(bool) 55 | 56 | return timeser[:,~bool_mask] 57 | 58 | def random_timeseries(timeser,sample_len): 59 | time_len = timeser.shape[1] 60 | st_thres = 1 61 | if time_len <= sample_len + st_thres: 62 | return timeser 63 | 64 | select_range = time_len - sample_len 65 | if select_range < 1: 66 | return timeser 67 | 68 | st = random.sample(list(np.arange(st_thres,select_range)),1)[0] 69 | return timeser[:,st:st+sample_len] 70 | 71 | class Task1Data(data.Dataset): 72 | 73 | def __init__(self, root = None,csv = None, mask_way='mask',mask_len=10, time_len=30): 74 | self.template = 'sch' 75 | self.root = root 76 | self.mask_way = mask_way 77 | self.mask_len = mask_len 78 | self.time_len = time_len 79 | df = pd.read_csv(csv) 80 | self.names = list(df['file']) 81 | 82 | print(f"Finding files: {len(self.names)}") 83 | self.correlation_measure = ConnectivityMeasure(kind='correlation') 84 | 85 | def __getitem__(self,index): 86 | name = self.names[index] 87 | img = np.load(os.path.join(self.root, name)) 88 | if self.mask_way == 'mask': 89 | slices = [mask_timeseries(img,mask=self.mask_len).T, mask_timeseries(img,mask=self.mask_len).T] 90 | elif self.mask_way == 'mask_per': 91 | slices = [mask_timeseries_per(img,mask=self.mask_len).T, mask_timeseries_per(img,mask=self.mask_len).T] 92 | elif self.mask_way == 'random': 93 | slices = [random_timeseries(img,sample_len=self.time_len).T, random_timeseries(img,sample_len=self.time_len).T] 94 | else: 95 | raise KeyError(f"mask way error, your input is {self.mask_way}") 96 | correlation_matrix = self.correlation_measure.fit_transform(slices) 97 | correlation_matrix[correlation_matrix!=correlation_matrix]=0 98 | return correlation_matrix[0], correlation_matrix[1] 99 | 100 | def __len__(self): 101 | return len(self.names) 102 | 103 | class Task2Data(data.Dataset): 104 | 105 | def __init__(self, root= None, csv = None, mask_way='mask',mask_len=10, time_len=30,shuffle_seed=42,is_train = True, is_test = False): 106 | # self.template = 'sch' 107 | self.is_test = is_test 108 | self.is_train = is_train 109 | self.root = root 110 | 111 | self.mask_way = mask_way 112 | self.mask_len = mask_len 113 | self.time_len = time_len 114 | 115 | self.df = pd.read_csv(csv) 116 | 117 | self.names = list(self.df['file']) 118 | test_length = int(len(self.df) * 0.15) 119 | 120 | all_data = np.array(self.names) 121 | lbls = np.array(list([1 if f == 1 else 0 for f in self.df['dx'] ])) 122 | sites = np.array(self.df['site']) if 'site' in self.df.columns else lbls 123 | train_index = self.df[self.df['is_train']==1].index 124 | rest_index = self.df[self.df['is_train']==0].index 125 | 126 | data_train = all_data[train_index] 127 | labels_train = lbls[train_index] 128 | 129 | rest_data = all_data[rest_index] 130 | rest_site = sites[rest_index] 131 | rest_label = lbls[rest_index] 132 | 133 | 134 | split2 = StratifiedShuffleSplit(n_splits=1, test_size=test_length, random_state=shuffle_seed) 135 | for valid_index, test_index in split2.split(rest_data, rest_site): 136 | data_test, labels_test = rest_data[test_index], rest_label[test_index] 137 | data_val, labels_val = rest_data[valid_index], rest_label[valid_index] 138 | 139 | if is_test is True: 140 | print("Testing data:") 141 | self.imgs, self.lbls = data_test, labels_test 142 | elif is_train is True: 143 | print("Training data:") 144 | self.imgs, self.lbls = data_train, labels_train 145 | # self.imgs, self.lbls = np.concatenate([data_train, data_val],0), np.concatenate([labels_train, labels_val],0), 146 | else: 147 | print("Val data:") 148 | self.imgs, self.lbls = data_val, labels_val 149 | print(self.imgs.shape) 150 | self.correlation_measure = ConnectivityMeasure(kind='correlation') 151 | 152 | 153 | def __getitem__(self,index): 154 | name = self.imgs[index] 155 | lbl = self.lbls[index] 156 | img = np.load(os.path.join(self.root, f"{name}.npy")) 157 | if self.is_train is True: 158 | if self.mask_way == 'mask': 159 | slices = [mask_timeseries(img,self.mask_len).T] 160 | elif self.mask_way =='random': 161 | slices = [random_timeseries(img,self.time_len).T] 162 | elif self.mask_way =='mask_per': 163 | slices = [mask_timeseries_per(img,mask=self.mask_len).T] 164 | else: 165 | slices = [img.T] 166 | correlation_matrix = self.correlation_measure.fit_transform(slices).mean(0) 167 | elif self.is_test is False: 168 | slices = [img.T] 169 | correlation_matrix = self.correlation_measure.fit_transform(slices)[0] 170 | else: 171 | # slices = [img.T] 172 | slices = [mask_timeseries_per(img,mask=self.mask_len).T] 173 | correlation_matrix = self.correlation_measure.fit_transform(slices).mean(0) 174 | onehot_lbl = np.zeros((2)) 175 | onehot_lbl[lbl] = 1 176 | correlation_matrix[correlation_matrix!=correlation_matrix]=0 177 | return correlation_matrix,onehot_lbl 178 | 179 | def __len__(self): 180 | return len(self.imgs) -------------------------------------------------------------------------------- /res.txt: -------------------------------------------------------------------------------- 1 | data:[pd] seed:[42] acc: 0.8500 sen: 0.8182 spe: 0.8889 2 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # -------------------------------------------------------- 3 | # Based on BEiT, timm, DINO and DeiT code bases 4 | # https://github.com/microsoft/unilm/tree/master/beit 5 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm 6 | # https://github.com/facebookresearch/deit 7 | # https://github.com/facebookresearch/dino 8 | # --------------------------------------------------------' 9 | import io 10 | import os 11 | import math 12 | import time 13 | import json 14 | from collections import defaultdict, deque 15 | import datetime 16 | import numpy as np 17 | from timm.utils import get_state_dict 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | from pathlib import Path 22 | 23 | import torch 24 | import torch.distributed as dist 25 | # from torch._six import inf 26 | inf= 999999 27 | import random 28 | 29 | from tensorboardX import SummaryWriter 30 | from torch.nn import TransformerEncoderLayer 31 | 32 | def get_data(args): 33 | if args.data == 'abide1': 34 | args.csv = "down_stream/abide1.csv" 35 | elif args.data == 'adhd': 36 | args.csv = "down_stream/adhd.csv" 37 | elif args.data == 'oas': 38 | args.csv = "down_stream/oas.csv" 39 | elif args.data == 'mci': 40 | args.csv = "down_stream/mci.csv" 41 | elif args.data == 'ad': 42 | args.csv = "down_stream/ad.csv" 43 | elif args.data == 'pd': 44 | args.csv = "down_stream/pd.csv" 45 | elif args.data == 'pro': 46 | args.csv = "down_stream/pro.csv" 47 | elif args.data =='schizo': 48 | args.csv = "down_stream/schizo.csv" 49 | elif args.data == 'abide2': 50 | args.csv = "down_stream/abide2.csv" 51 | elif args.data == 'ucla_bp': 52 | args.csv = "down_stream/ucla_bp.csv" 53 | elif args.data == 'ucla_schizo': 54 | args.csv = "down_stream/ucla_schz.csv" 55 | elif args.data == 'ucla_adhd': 56 | args.csv = "down_stream/ucla_adhd.csv" 57 | elif args.data == 'xuanwu_pd': 58 | args.csv = "down_stream/xuanwu_pd.csv" 59 | elif args.data == 'xuanwu_rbd': 60 | args.csv = "down_stream/xuanwu_rbd.csv" 61 | elif args.data == 'pd_ad': 62 | args.csv = "down_stream/pd_ad.csv" 63 | return args 64 | 65 | class BNTF(nn.Module): 66 | def __init__(self,feature_dim,depth,heads,dim_feedforward): 67 | super().__init__() 68 | self.num_patches = 100 69 | 70 | self.attention_list = nn.ModuleList() 71 | self.node_num = 100 72 | for _ in range(int(depth)): 73 | self.attention_list.append( 74 | TransformerEncoderLayer(d_model=self.node_num, nhead=int(heads), dim_feedforward=dim_feedforward, 75 | batch_first=True) 76 | ) 77 | self.dim_reduction = nn.Sequential( 78 | nn.Linear(self.node_num, 8), 79 | nn.LeakyReLU() 80 | ) 81 | 82 | final_dim = 8 * self.node_num 83 | 84 | self.g = MLPHead(final_dim, final_dim * 2, feature_dim) 85 | 86 | def forward(self,img,forward_with_mlp=True): 87 | bz, _, _, = img.shape 88 | 89 | for atten in self.attention_list: 90 | img = atten(img) 91 | if forward_with_mlp is not True: 92 | return img 93 | node_feature = self.dim_reduction(img) 94 | node_feature = node_feature.reshape((bz, -1)) 95 | node_feature = self.g(node_feature) 96 | return node_feature 97 | 98 | 99 | class MLPHead(nn.Module): 100 | def __init__(self, in_channels, mlp_hidden_size, projection_size): 101 | super().__init__() 102 | 103 | self.net = nn.Sequential( 104 | nn.Linear(in_channels, mlp_hidden_size), 105 | nn.BatchNorm1d(mlp_hidden_size), 106 | nn.ReLU(inplace=True), 107 | nn.Linear(mlp_hidden_size, projection_size) 108 | ) 109 | 110 | def forward(self, x): 111 | return self.net(x) 112 | 113 | class FT(nn.Module): 114 | def __init__(self,feature_dim,depth,heads,dim_feedforward): 115 | super().__init__() 116 | self.encoder = BNTF(feature_dim,depth,heads,dim_feedforward) 117 | self.g2 = nn.Sequential( 118 | nn.Linear(8 * 100, 256), 119 | nn.BatchNorm1d(256), 120 | nn.LeakyReLU(), 121 | nn.Dropout(0.3), 122 | nn.Linear(256, 32), 123 | nn.BatchNorm1d(32), 124 | nn.LeakyReLU(), 125 | nn.Dropout(0.3), 126 | nn.Linear(32,2) 127 | ) 128 | def forward(self,img): 129 | bz, _, _, = img.shape 130 | 131 | for atten in self.encoder.attention_list: 132 | img = atten(img) 133 | 134 | node_feature = self.encoder.dim_reduction(img) 135 | node_feature = node_feature.reshape((bz, -1)) 136 | node_feature = F.leaky_relu(node_feature) 137 | node_feature = self.g2(node_feature) 138 | return node_feature 139 | 140 | # sin-cos position encoding 141 | # https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31 142 | def get_sinusoid_encoding_table(n_position, d_hid): 143 | ''' Sinusoid position encoding table ''' 144 | # TODO: make it with torch instead of numpy 145 | def get_position_angle_vec(position): 146 | return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] 147 | 148 | sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) 149 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 150 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 151 | 152 | return torch.FloatTensor(sinusoid_table).unsqueeze(0) 153 | 154 | 155 | class SmoothedValue(object): 156 | """Track a series of values and provide access to smoothed values over a 157 | window or the global series average. 158 | """ 159 | 160 | def __init__(self, window_size=20, fmt=None): 161 | if fmt is None: 162 | fmt = "{median:.4f} ({global_avg:.4f})" 163 | self.deque = deque(maxlen=window_size) 164 | self.total = 0.0 165 | self.count = 0 166 | self.fmt = fmt 167 | 168 | def update(self, value, n=1): 169 | self.deque.append(value) 170 | self.count += n 171 | self.total += value * n 172 | 173 | def synchronize_between_processes(self): 174 | """ 175 | Warning: does not synchronize the deque! 176 | """ 177 | if not is_dist_avail_and_initialized(): 178 | return 179 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 180 | dist.barrier() 181 | dist.all_reduce(t) 182 | t = t.tolist() 183 | print(t) 184 | self.count = int(t[0]) 185 | self.total = t[1] 186 | 187 | @property 188 | def median(self): 189 | d = torch.tensor(list(self.deque)) 190 | return d.median().item() 191 | 192 | @property 193 | def avg(self): 194 | d = torch.tensor(list(self.deque), dtype=torch.float32) 195 | return d.mean().item() 196 | 197 | @property 198 | def global_avg(self): 199 | return self.total / self.count 200 | 201 | @property 202 | def max(self): 203 | return max(self.deque) 204 | 205 | @property 206 | def value(self): 207 | return self.deque[-1] 208 | 209 | def __str__(self): 210 | return self.fmt.format( 211 | median=self.median, 212 | avg=self.avg, 213 | global_avg=self.global_avg, 214 | max=self.max, 215 | value=self.value) 216 | 217 | 218 | class MetricLogger(object): 219 | def __init__(self, delimiter="\t"): 220 | self.meters = defaultdict(SmoothedValue) 221 | self.delimiter = delimiter 222 | 223 | def update(self, **kwargs): 224 | for k, v in kwargs.items(): 225 | if v is None: 226 | continue 227 | if isinstance(v, torch.Tensor): 228 | v = v.item() 229 | assert isinstance(v, (float, int)) 230 | self.meters[k].update(v) 231 | 232 | def __getattr__(self, attr): 233 | if attr in self.meters: 234 | return self.meters[attr] 235 | if attr in self.__dict__: 236 | return self.__dict__[attr] 237 | raise AttributeError("'{}' object has no attribute '{}'".format( 238 | type(self).__name__, attr)) 239 | 240 | def __str__(self): 241 | loss_str = [] 242 | for name, meter in self.meters.items(): 243 | loss_str.append( 244 | "{}: {}".format(name, str(meter)) 245 | ) 246 | return self.delimiter.join(loss_str) 247 | 248 | def synchronize_between_processes(self): 249 | for meter in self.meters.values(): 250 | meter.synchronize_between_processes() 251 | 252 | def add_meter(self, name, meter): 253 | self.meters[name] = meter 254 | 255 | def log_every(self, iterable, print_freq, header=None): 256 | i = 0 257 | if not header: 258 | header = '' 259 | start_time = time.time() 260 | end = time.time() 261 | iter_time = SmoothedValue(fmt='{avg:.4f}') 262 | data_time = SmoothedValue(fmt='{avg:.4f}') 263 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 264 | log_msg = [ 265 | header, 266 | '[{0' + space_fmt + '}/{1}]', 267 | 'eta: {eta}', 268 | '{meters}', 269 | 'time: {time}', 270 | 'data: {data}' 271 | ] 272 | if torch.cuda.is_available(): 273 | log_msg.append('max mem: {memory:.0f}') 274 | log_msg = self.delimiter.join(log_msg) 275 | MB = 1024.0 * 1024.0 276 | for obj in iterable: 277 | data_time.update(time.time() - end) 278 | yield obj 279 | iter_time.update(time.time() - end) 280 | if i % print_freq == 0 or i == len(iterable) - 1: 281 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 282 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 283 | if torch.cuda.is_available(): 284 | print(log_msg.format( 285 | i, len(iterable), eta=eta_string, 286 | meters=str(self), 287 | time=str(iter_time), data=str(data_time), 288 | memory=torch.cuda.max_memory_allocated() / MB)) 289 | else: 290 | print(log_msg.format( 291 | i, len(iterable), eta=eta_string, 292 | meters=str(self), 293 | time=str(iter_time), data=str(data_time))) 294 | i += 1 295 | end = time.time() 296 | total_time = time.time() - start_time 297 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 298 | print('{} Total time: {} ({:.4f} s / it)'.format( 299 | header, total_time_str, total_time / len(iterable))) 300 | 301 | 302 | class TensorboardLogger(object): 303 | def __init__(self, log_dir): 304 | self.writer = SummaryWriter(logdir=log_dir) 305 | self.step = 0 306 | 307 | def set_step(self, step=None): 308 | if step is not None: 309 | self.step = step 310 | else: 311 | self.step += 1 312 | 313 | def update(self, head='scalar', step=None, **kwargs): 314 | for k, v in kwargs.items(): 315 | if v is None: 316 | continue 317 | if isinstance(v, torch.Tensor): 318 | v = v.item() 319 | assert isinstance(v, (float, int)) 320 | self.writer.add_scalar(head + "/" + k, v, self.step if step is None else step) 321 | 322 | def flush(self): 323 | self.writer.flush() 324 | 325 | def seed_worker(worker_id): 326 | worker_seed = torch.initial_seed() % 2**32 327 | np.random.seed(worker_seed) 328 | random.seed(worker_seed) 329 | 330 | def _load_checkpoint_for_ema(model_ema, checkpoint): 331 | """ 332 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object 333 | """ 334 | mem_file = io.BytesIO() 335 | torch.save(checkpoint, mem_file) 336 | mem_file.seek(0) 337 | model_ema._load_checkpoint(mem_file) 338 | 339 | 340 | def setup_for_distributed(is_master): 341 | """ 342 | This function disables printing when not in master process 343 | """ 344 | import builtins as __builtin__ 345 | builtin_print = __builtin__.print 346 | 347 | def print(*args, **kwargs): 348 | force = kwargs.pop('force', False) 349 | if is_master or force: 350 | builtin_print(*args, **kwargs) 351 | 352 | __builtin__.print = print 353 | 354 | 355 | def is_dist_avail_and_initialized(): 356 | if not dist.is_available(): 357 | return False 358 | if not dist.is_initialized(): 359 | return False 360 | return True 361 | 362 | 363 | def get_world_size(): 364 | if not is_dist_avail_and_initialized(): 365 | return 1 366 | return dist.get_world_size() 367 | 368 | 369 | def get_rank(): 370 | if not is_dist_avail_and_initialized(): 371 | return 0 372 | return dist.get_rank() 373 | 374 | 375 | def is_main_process(): 376 | return get_rank() == 0 377 | 378 | 379 | def save_on_master(*args, **kwargs): 380 | if is_main_process(): 381 | torch.save(*args, **kwargs) 382 | 383 | 384 | def init_distributed_mode(args): 385 | if args.dist_on_itp: 386 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 387 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 388 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 389 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 390 | os.environ['LOCAL_RANK'] = str(args.gpu) 391 | os.environ['RANK'] = str(args.rank) 392 | os.environ['WORLD_SIZE'] = str(args.world_size) 393 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 394 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 395 | args.rank = int(os.environ["RANK"]) 396 | args.world_size = int(os.environ['WORLD_SIZE']) 397 | args.gpu = int(os.environ['LOCAL_RANK']) 398 | elif 'SLURM_PROCID' in os.environ: 399 | args.rank = int(os.environ['SLURM_PROCID']) 400 | args.gpu = args.rank % torch.cuda.device_count() 401 | else: 402 | print('Not using distributed mode') 403 | args.distributed = False 404 | return 405 | 406 | args.distributed = True 407 | 408 | torch.cuda.set_device(args.gpu) 409 | args.dist_backend = 'nccl' 410 | print('| distributed init (rank {}): {}, gpu {}'.format( 411 | args.rank, args.dist_url, args.gpu), flush=True) 412 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 413 | world_size=args.world_size, rank=args.rank) 414 | torch.distributed.barrier() 415 | setup_for_distributed(args.rank == 0) 416 | 417 | 418 | def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"): 419 | missing_keys = [] 420 | unexpected_keys = [] 421 | error_msgs = [] 422 | # copy state_dict so _load_from_state_dict can modify it 423 | metadata = getattr(state_dict, '_metadata', None) 424 | state_dict = state_dict.copy() 425 | if metadata is not None: 426 | state_dict._metadata = metadata 427 | 428 | def load(module, prefix=''): 429 | local_metadata = {} if metadata is None else metadata.get( 430 | prefix[:-1], {}) 431 | module._load_from_state_dict( 432 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 433 | for name, child in module._modules.items(): 434 | if child is not None: 435 | load(child, prefix + name + '.') 436 | #print("loading state_dict: ", state_dict.keys());exit() 437 | load(model, prefix=prefix) 438 | 439 | warn_missing_keys = [] 440 | ignore_missing_keys = [] 441 | for key in missing_keys: 442 | keep_flag = True 443 | for ignore_key in ignore_missing.split('|'): 444 | if ignore_key in key: 445 | keep_flag = False 446 | break 447 | if keep_flag: 448 | warn_missing_keys.append(key) 449 | else: 450 | ignore_missing_keys.append(key) 451 | 452 | missing_keys = warn_missing_keys 453 | 454 | if len(missing_keys) > 0: 455 | print("Weights of {} not initialized from pretrained model: {}".format( 456 | model.__class__.__name__, missing_keys)) 457 | if len(unexpected_keys) > 0: 458 | print("Weights from pretrained model not used in {}: {}".format( 459 | model.__class__.__name__, unexpected_keys)) 460 | if len(ignore_missing_keys) > 0: 461 | print("Ignored weights of {} not initialized from pretrained model: {}".format( 462 | model.__class__.__name__, ignore_missing_keys)) 463 | if len(error_msgs) > 0: 464 | print('\n'.join(error_msgs)) 465 | 466 | 467 | class NativeScalerWithGradNormCount: 468 | state_dict_key = "amp_scaler" 469 | 470 | def __init__(self): 471 | self._scaler = torch.cuda.amp.GradScaler() 472 | 473 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 474 | self._scaler.scale(loss).backward(create_graph=create_graph) 475 | if update_grad: 476 | if clip_grad is not None: 477 | assert parameters is not None 478 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 479 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 480 | else: 481 | self._scaler.unscale_(optimizer) 482 | norm = get_grad_norm_(parameters) 483 | self._scaler.step(optimizer) 484 | self._scaler.update() 485 | else: 486 | norm = None 487 | return norm 488 | 489 | def state_dict(self): 490 | return self._scaler.state_dict() 491 | 492 | def load_state_dict(self, state_dict): 493 | self._scaler.load_state_dict(state_dict) 494 | 495 | 496 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 497 | if isinstance(parameters, torch.Tensor): 498 | parameters = [parameters] 499 | parameters = [p for p in parameters if p.grad is not None] 500 | norm_type = float(norm_type) 501 | if len(parameters) == 0: 502 | return torch.tensor(0.) 503 | device = parameters[0].grad.device 504 | if norm_type == inf: 505 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 506 | else: 507 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 508 | return total_norm 509 | 510 | 511 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, 512 | start_warmup_value=0, warmup_steps=-1): 513 | warmup_schedule = np.array([]) 514 | warmup_iters = warmup_epochs * niter_per_ep 515 | if warmup_steps > 0: 516 | warmup_iters = warmup_steps 517 | print("Set warmup steps = %d" % warmup_iters) 518 | if warmup_epochs > 0: 519 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 520 | 521 | iters = np.arange(epochs * niter_per_ep - warmup_iters) 522 | schedule = np.array( 523 | [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters]) 524 | 525 | schedule = np.concatenate((warmup_schedule, schedule)) 526 | 527 | assert len(schedule) == epochs * niter_per_ep 528 | return schedule 529 | 530 | 531 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, model_ema=None): 532 | output_dir = Path(args.output_dir) 533 | epoch_name = str(epoch) 534 | if loss_scaler is not None: 535 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] 536 | for checkpoint_path in checkpoint_paths: 537 | to_save = { 538 | 'model': model_without_ddp.state_dict(), 539 | 'optimizer': optimizer.state_dict(), 540 | 'epoch': epoch, 541 | 'scaler': loss_scaler.state_dict(), 542 | 'args': args, 543 | } 544 | 545 | if model_ema is not None: 546 | to_save['model_ema'] = get_state_dict(model_ema) 547 | 548 | save_on_master(to_save, checkpoint_path) 549 | else: 550 | client_state = {'epoch': epoch} 551 | if model_ema is not None: 552 | client_state['model_ema'] = get_state_dict(model_ema) 553 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state) 554 | 555 | 556 | def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None): 557 | output_dir = Path(args.output_dir) 558 | if loss_scaler is not None: 559 | # torch.amp 560 | if args.auto_resume and len(args.resume) == 0: 561 | import glob 562 | all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth')) 563 | latest_ckpt = -1 564 | for ckpt in all_checkpoints: 565 | t = ckpt.split('-')[-1].split('.')[0] 566 | if t.isdigit(): 567 | latest_ckpt = max(int(t), latest_ckpt) 568 | if latest_ckpt >= 0: 569 | args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt) 570 | print("Auto resume checkpoint: %s" % args.resume) 571 | 572 | if args.resume: 573 | if args.resume.startswith('https'): 574 | checkpoint = torch.hub.load_state_dict_from_url( 575 | args.resume, map_location='cpu', check_hash=True) 576 | else: 577 | checkpoint = torch.load(args.resume, map_location='cpu') 578 | model_without_ddp.load_state_dict(checkpoint['model']) 579 | print("Resume checkpoint %s" % args.resume) 580 | if 'optimizer' in checkpoint and 'epoch' in checkpoint: 581 | optimizer.load_state_dict(checkpoint['optimizer']) 582 | args.start_epoch = checkpoint['epoch'] + 1 583 | if hasattr(args, 'model_ema') and args.model_ema: 584 | _load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) 585 | if 'scaler' in checkpoint: 586 | loss_scaler.load_state_dict(checkpoint['scaler']) 587 | print("With optim & sched!") 588 | else: 589 | # deepspeed, only support '--auto_resume'. 590 | if args.auto_resume: 591 | import glob 592 | all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*')) 593 | latest_ckpt = -1 594 | for ckpt in all_checkpoints: 595 | t = ckpt.split('-')[-1].split('.')[0] 596 | if t.isdigit(): 597 | latest_ckpt = max(int(t), latest_ckpt) 598 | if latest_ckpt >= 0: 599 | args.resume = os.path.join(output_dir, 'checkpoint-%d' % latest_ckpt) 600 | print("Auto resume checkpoint: %d" % latest_ckpt) 601 | _, client_states = model.load_checkpoint(args.output_dir, tag='checkpoint-%d' % latest_ckpt) 602 | args.start_epoch = client_states['epoch'] + 1 603 | if model_ema is not None: 604 | if args.model_ema: 605 | _load_checkpoint_for_ema(model_ema, client_states['model_ema']) 606 | 607 | 608 | def create_ds_config(args): 609 | args.deepspeed_config = os.path.join(args.output_dir, "deepspeed_config.json") 610 | with open(args.deepspeed_config, mode="w") as writer: 611 | ds_config = { 612 | "train_batch_size": args.batch_size * args.update_freq * get_world_size(), 613 | "train_micro_batch_size_per_gpu": args.batch_size, 614 | "steps_per_print": 1000, 615 | "optimizer": { 616 | "type": "Adam", 617 | "adam_w_mode": True, 618 | "params": { 619 | "lr": args.lr, 620 | "weight_decay": args.weight_decay, 621 | "bias_correction": True, 622 | "betas": [ 623 | 0.9, 624 | 0.999 625 | ], 626 | "eps": 1e-8 627 | } 628 | }, 629 | "fp16": { 630 | "enabled": True, 631 | "loss_scale": 0, 632 | "initial_scale_power": 7, 633 | "loss_scale_window": 128 634 | } 635 | } 636 | 637 | writer.write(json.dumps(ds_config, indent=2)) 638 | --------------------------------------------------------------------------------