├── SA-result
│   ├── README
│   ├── result_1000_asd_z.npy
│   ├── result_30_asd_x.npy
│   ├── adTrained_asd_x_13.npy
│   ├── adTrained_asd_z_13.npy
│   ├── adTrained_tdc_x_13.npy
│   └── adTrained_tdc_z_13.npy
├── nan_subid.npy
├── README.md
├── LICENSE
├── main.py
├── baseline.ipynb
├── BrainVisial.ipynb
├── Models.py
├── K-fold.ipynb
├── ensamble.ipynb
├── K-fold-withoutsMRI.ipynb
└── Leave-one-site-out.ipynb

/SA-result/README:
--------------------------------------------------------------------------------
1 | The results for the important nodes and edges of the ensemble models
2 |
--------------------------------------------------------------------------------
/nan_subid.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XiJiangLabUESTC/Node-Edge-Graph-Attention-Networks/HEAD/nan_subid.npy
--------------------------------------------------------------------------------
/SA-result/result_1000_asd_z.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XiJiangLabUESTC/Node-Edge-Graph-Attention-Networks/HEAD/SA-result/result_1000_asd_z.npy
--------------------------------------------------------------------------------
/SA-result/result_30_asd_x.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XiJiangLabUESTC/Node-Edge-Graph-Attention-Networks/HEAD/SA-result/result_30_asd_x.npy
--------------------------------------------------------------------------------
/SA-result/adTrained_asd_x_13.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XiJiangLabUESTC/Node-Edge-Graph-Attention-Networks/HEAD/SA-result/adTrained_asd_x_13.npy
--------------------------------------------------------------------------------
/SA-result/adTrained_asd_z_13.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XiJiangLabUESTC/Node-Edge-Graph-Attention-Networks/HEAD/SA-result/adTrained_asd_z_13.npy
--------------------------------------------------------------------------------
/SA-result/adTrained_tdc_x_13.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XiJiangLabUESTC/Node-Edge-Graph-Attention-Networks/HEAD/SA-result/adTrained_tdc_x_13.npy
--------------------------------------------------------------------------------
/SA-result/adTrained_tdc_z_13.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XiJiangLabUESTC/Node-Edge-Graph-Attention-Networks/HEAD/SA-result/adTrained_tdc_z_13.npy
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Node-Edge-Graph-Attention-Networks
2 | Adversarial Learning Based Node-Edge Graph Attention Networks for Autism Spectrum Disorder Identification
3 |
4 | ## Description
5 | 1. For model interpretability, see [BrainVisial.ipynb](./BrainVisial.ipynb)
6 | 2. For the normally trained, FGSM-trained and PGD-trained models, see [K-fold.ipynb](./K-fold.ipynb)
7 | 3. For the models used in the ablation experiments, see [Models.py](./Models.py)
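The arrays in `SA-result` can be inspected directly with NumPy. A minimal sketch follows; the `(116, 3)` node-gradient and `(116, 116)` edge-gradient shapes match what `BrainVisial.ipynb` accumulates and saves, but the exact shape of each saved file should be verified after loading.
```python
import numpy as np

# accumulated node-importance map: 116 AAL regions x 3 node features
asd_x = np.load('SA-result/adTrained_asd_x_13.npy')
# accumulated edge-importance map: 116 x 116 functional-connectivity entries
asd_z = np.load('SA-result/adTrained_asd_z_13.npy')
print(asd_x.shape, asd_z.shape)

# rank regions and edges by accumulated gradient magnitude
top_nodes = np.argsort(-np.abs(asd_x).sum(axis=1))[:10]
top_edges = np.unravel_index(np.argsort(-np.abs(asd_z), axis=None)[:10], asd_z.shape)
print(top_nodes, top_edges)
```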
8 | ## Cite
9 | ```
10 | @article{chen2022adversarial,
11 | title={Adversarial Learning Based Node-Edge Graph Attention Networks for Autism Spectrum Disorder Identification},
12 | author={Chen, Yuzhong and Yan, Jiadong and Jiang, Mingxin and Zhang, Tuo and Zhao, Zhongbo and Zhao, Weihua and Zheng, Jian and Yao, Dezhong and Zhang, Rong and Kendrick, Keith M and others},
13 | journal={IEEE Transactions on Neural Networks and Learning Systems},
14 | year={2022},
15 | publisher={IEEE}
16 | }
17 | ```
18 | ## Contact us
19 | For more details, please contact chenyuzhong211@gmail.com
20 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 XiJiangLabUESTC
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from torch.utils.data import Dataset
5 | from torch.utils.data import DataLoader
6 | import os
7 | import time
8 | import matplotlib
9 | import matplotlib.pyplot as plt
10 | import random
11 | import pandas as pd
12 | import torch.nn.functional as F
13 | import argparse
14 | from Models import *
15 | parser = argparse.ArgumentParser(description='manual for this script')
16 | parser.add_argument('--gpus', type=int, default = 0)
17 | parser.add_argument('--layer',type=int, default = 5)
18 | parser.add_argument('--thed', type=float, default = 0.1)
19 | parser.add_argument('--filename',type=str,default='result.txt')
20 | parser.add_argument('--knn',type=int,default=5)
21 | args = parser.parse_args()
22 | if args.gpus==1:
23 |     device=torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
24 | elif args.gpus==0:
25 |     device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
26 | elif args.gpus==2:
27 |     device=torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
28 | elif args.gpus==3:
29 |     device=torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')
30 |
31 | else:
32 |     device=torch.device('cpu')
33 | from tqdm import tqdm
34 | import warnings
35 | warnings.filterwarnings("ignore")
36 | cpac_root='/media/D/yazid/ASD-classification-ANEGCN/ABIDEI_CPAC/'
37 | smri_root='/media/D/yazid/ASD-classification-ANEGCN/ABIDEI_sMRI/'
38 | nan_subid=np.load('nan_subid.npy').tolist()
39 | seed = 1234
40 | def setup_seed(seed):
41 |     torch.manual_seed(seed)
42 |     torch.cuda.manual_seed_all(seed)
43 |     np.random.seed(seed)
44 |     random.seed(seed)
45 |     torch.backends.cudnn.deterministic = True
46 | class LabelSmoothLoss(nn.Module):
47 |
48 |     def __init__(self, smoothing=0.0):
49 |         super(LabelSmoothLoss, self).__init__()
50 |         self.smoothing = smoothing
51 |
52 |     def forward(self, input, target):
53 |         log_prob = F.log_softmax(input, dim=-1)
54 |         weight = input.new_ones(input.size()) * \
55 |             self.smoothing / (input.size(-1) - 1.)
56 |         weight.scatter_(-1, target.unsqueeze(-1), (1. - self.smoothing))
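        # label smoothing: the true class keeps probability (1 - smoothing) while
        # the remaining mass is spread evenly over the other classes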
57 |         loss = (-weight * log_prob).sum(dim=-1).mean()
58 |         return loss
59 | def cal_evaluate(TP,TN,FP,FN):
60 |     if TP>0:
61 |         p = TP / (TP + FP)
62 |         r = TP / (TP + FN)
63 |         F1 = 2 * r * p / (r + p)
64 |     else:
65 |         F1=0
66 |     acc = (TP + TN) / (TP + TN + FP + FN)
67 |     return acc,F1
68 | def test(device,model,testloader):
69 |     model.eval()
70 |     TP_test,TN_test,FP_test,FN_test=0,0,0,0
71 |     with torch.no_grad():
72 |         for (X,Z,label,sub_id) in testloader:
73 |             TP,TN,FN,FP=0,0,0,0
74 |             n=X.size()[0]
75 |             X=X.to(device)
76 |             Z=Z.to(device)
77 |             label=label.to(device)
78 |             y=model(X,Z)
79 |             _,predict=torch.max(y,1)
80 |             TP+=((predict==1)&(label==1)).sum().item()
81 |             TN+=((predict==0)&(label==0)).sum().item()
82 |             FN+=((predict==0)&(label==1)).sum().item()
83 |             FP+=((predict==1)&(label==0)).sum().item()
84 |             TP_test+=TP
85 |             TN_test+=TN
86 |             FP_test+=FP
87 |             FN_test+=FN
88 |     acc,f1=cal_evaluate(TP_test,TN_test,FP_test,FN_test)
89 |     global max_acc
90 |     global modelname
91 |     global saveModel
92 |     if acc>=max_acc:
93 |         max_acc=acc
94 |         if saveModel:
95 |             torch.save(model.state_dict(),modelname)
96 |     return acc,f1,TP_test,TN_test,FP_test,FN_test
97 | class dataset(Dataset):
98 |     def __init__(self,fmri_root,smri_root,site,ASD,TDC):
99 |         super(dataset,self).__init__()
100 |         self.fmri=fmri_root
101 |         self.smri=smri_root
102 |         self.ASD=[j for i in ASD for j in i]
103 |         self.TDC=[j for i in TDC for j in i]
104 |         self.data=self.ASD+self.TDC
105 |         random.shuffle(self.data)
106 |         self.data_site={}
107 |         for i in range(len(site)):
108 |             data=ASD[i]+TDC[i]
109 |             for j in data:
110 |                 if j not in self.data_site:
111 |                     self.data_site[j]=site[i]
112 |     def __getitem__(self,index):
113 |         data=self.data[index]
114 |         sub_id=int(data[0:5])
115 |         if data in self.ASD:
116 |             data_slow5 =np.load(self.fmri+self.data_site[data]+'/group1_slow5/'+data,allow_pickle=True)
117 |             data_slow4 =np.load(self.fmri+self.data_site[data]+'/group1_slow4/'+data,allow_pickle=True)
118 |             data_voxel =np.load(self.smri+self.data_site[data]+'/group1/'+data,allow_pickle=True)
119 |             data_FCz =np.load(self.fmri+self.data_site[data]+'/group1_FC/'+data,allow_pickle=True)
120 |         elif data in self.TDC:
121 |             data_slow5 =np.load(self.fmri+self.data_site[data]+'/group2_slow5/'+data,allow_pickle=True)
122 |             data_slow4 =np.load(self.fmri+self.data_site[data]+'/group2_slow4/'+data,allow_pickle=True)
123 |             data_voxel =np.load(self.smri+self.data_site[data]+'/group2/'+data,allow_pickle=True)
124 |             data_FCz =np.load(self.fmri+self.data_site[data]+'/group2_FC/'+data,allow_pickle=True)
125 |         else:
126 |             raise ValueError('wrong input: '+data)
127 |         data_slow5=(data_slow5-np.min(data_slow5))/(np.max(data_slow5)-np.min(data_slow5))
128 |         data_slow4=(data_slow4-np.min(data_slow4))/(np.max(data_slow4)-np.min(data_slow4))
129 |         if np.any(np.isnan(data_slow5)) or np.any(np.isnan(data_slow4)) or np.any(np.isnan(data_FCz)):
130 |             print('data wrong')
131 |         #data_FCz=(data_FCz-np.min(data_FCz))/(np.max(data_FCz)-np.min(data_FCz))
132 |         if self.data[index] in self.ASD:
133 |             label=torch.tensor(1)
134 |         else:
135 |             label=torch.tensor(0)
136 |         X=np.zeros((116,3),dtype=np.float32)
137 |         X[:,0]=data_slow5
138 |         X[:,1]=data_slow4
139 |         X[:,2]=data_voxel
140 |         data_FCz=data_FCz.astype(np.float32)
141 |         Z=torch.from_numpy(data_FCz)
142 |         X=torch.from_numpy(X)
143 |         return X,Z,label,sub_id
144 |     def __len__(self):
145 |         return len(self.data)
146 |
147 | def train_pgd(model,trainloader,testloader,eps=0.02,iters=10,alpha=0.004):
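    # PGD adversarial training (Madry et al.): take `iters` signed-gradient steps
    # of size `alpha`, after each step project the perturbation back into the
    # L-inf ball of radius `eps` around the clean input, and clamp node features
    # to [0,1] and FC entries to [-1,1]; the final loss averages the clean and
    # adversarial cross-entropy terms.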
148 |     result=pd.DataFrame(columns=('Loss','Acc','F1','TP','TN','FP','FN'))
149 |     criterian1=LabelSmoothLoss(0.1).to(device)
150 |     optimizer=torch.optim.Adam(model.parameters(),lr=learning_rate)
151 |     for j in range(epoch):
152 |         loss_sum=0
153 |         model.train()
154 |         for (X,Z,label,sub_id) in trainloader:
155 |             model.train()
156 |             x=X.to(device)
157 |             z=Z.to(device)
158 |             label=label.to(device)
159 |             pretu_x,pretu_z=x,z
160 |             ori_x,ori_z=x.data,z.data
161 |             for i in range(iters):
162 |                 pretu_x.requires_grad=True
163 |                 pretu_z.requires_grad=True
164 |                 y=model(pretu_x,pretu_z)
165 |                 loss=criterian1(y,label)
166 |                 model.zero_grad()
167 |                 loss.backward()
168 |                 adv_x=pretu_x+alpha*torch.sign(pretu_x.grad.data)
169 |                 adv_z=pretu_z+alpha*torch.sign(pretu_z.grad.data)
170 |                 eta_x=torch.clamp(adv_x-ori_x,min=-eps,max=eps)
171 |                 eta_z=torch.clamp(adv_z-ori_z,min=-eps,max=eps)
172 |                 pretu_x=torch.clamp(ori_x+eta_x,min=0,max=1).detach_()
173 |                 pretu_z=torch.clamp(ori_z+eta_z,min=-1,max=1).detach_()
174 |             y=model(x,z)
175 |             yy=model(pretu_x,pretu_z)
176 |             L2=torch.tensor(0,dtype=torch.float32).to(device)
177 |             if L2_lamda>0:
178 |                 for name,parameters in model.named_parameters():
179 |                     if name[0:5]=='clase' and name[-8:]=='0.weight':
180 |                         L2+=L2_lamda*torch.norm(parameters,2)
181 |             loss=0.5*(criterian1(yy,label)+criterian1(y,label))+L2
182 |             loss_sum+=loss.item()
183 |             optimizer.zero_grad()
184 |             loss.backward()
185 |             optimizer.step()
186 |         if (j+1)%5==0 or j==0:
187 |             acc,f1,TP_test,TN_test,FP_test,FN_test=test(device,model,testloader)
188 |             result.loc[j//5]=[loss_sum,acc,f1,TP_test,TN_test,FP_test,FN_test]
189 |     result.sort_values('Acc',inplace=True,ascending=False)
190 |     return result.iloc[0]['Acc']
191 |
192 | def train_fgsm(model,trainloader,testloader,epsilon=0.05):
193 |     result=pd.DataFrame(columns=('Loss','Acc','F1','TP','TN','FP','FN'))
194 |     criterian1=LabelSmoothLoss(0.1).to(device)
195 |     optimizer=torch.optim.Adam(model.parameters(),lr=learning_rate)
196 |     acc=0.5000
197 |     for j in range(epoch):
198 |         loss_sum=0
199 |         TP,TN,FP,FN=0,0,0,0
200 |         model.train()
201 |         for (X,Z,label,sub_id) in trainloader:
202 |             x=X.to(device)
203 |             z=Z.to(device)
204 |             x.requires_grad=True
205 |             z.requires_grad=True
206 |             label=label.to(device)
207 |             y=model(x,z)
208 |             loss=criterian1(y,label)
209 |             model.zero_grad()
210 |             loss.backward(retain_graph=True)
211 |             sign_grad_x=torch.sign(x.grad.data)
212 |             sign_grad_z=torch.sign(z.grad.data)
213 |             perturbed_x=x+epsilon*sign_grad_x
214 |             perturbed_z=z+epsilon*sign_grad_z
215 |             perturbed_x=torch.clamp(perturbed_x,0,1)
216 |             perturbed_z=torch.clamp(perturbed_z,-1,1)
217 |             y=model(perturbed_x,perturbed_z)
218 |             L2=torch.tensor(0,dtype=torch.float32).to(device)
219 |             if L2_lamda>0:
220 |                 for name,parameters in model.named_parameters():
221 |                     if name[0:5]=='clase' and name[-8:]=='0.weight':
222 |                         L2+=L2_lamda*torch.norm(parameters,2)
223 |             loss=0.5*(criterian1(y,label)+loss)+L2
224 |             loss_sum+=loss.item()
225 |             optimizer.zero_grad()
226 |             loss.backward()
227 |             optimizer.step()
228 |         if (j+1)%5==0 or j==0:
229 |             acc,f1,TP_test,TN_test,FP_test,FN_test=test(device,model,testloader)
230 |             result.loc[j//5]=[loss_sum,acc,f1,TP_test,TN_test,FP_test,FN_test]
231 |     result.sort_values('Acc',inplace=True,ascending=False)
232 |     return result.iloc[0]['Acc']
233 |
234 | def train_norm(model,trainloader,testloader):
235 |     result=pd.DataFrame(columns=('Loss','Acc','F1','TP','TN','FP','FN'))
236 |     criterian1=LabelSmoothLoss(0.1).to(device)
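    # standard (non-adversarial) training baseline: label-smoothed cross-entropy
    # plus the same optional L2 penalty on the classifier-head weights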
237 |     optimizer=torch.optim.Adam(model.parameters(),lr=learning_rate)
238 |     acc=0.5000
239 |     loss_sum=0
240 |     for j in range(epoch):
241 |         loss_sum=0
242 |         TP,TN,FP,FN=0,0,0,0
243 |         model.train()
244 |         for (X,Z,label,sub_id) in trainloader:
245 |             x=X.to(device)
246 |             z=Z.to(device)
247 |             label=label.to(device)
248 |             y=model(x,z)
249 |             loss=criterian1(y,label)
250 |             L2=torch.tensor(0,dtype=torch.float32).to(device)
251 |             if L2_lamda>0:
252 |                 for name,parameters in model.named_parameters():
253 |                     if name[0:5]=='clase' and name[-8:]=='0.weight':
254 |                         L2+=L2_lamda*torch.norm(parameters,2)
255 |             loss=loss+L2
256 |             loss_sum+=loss.item()
257 |             optimizer.zero_grad()
258 |             loss.backward()
259 |             optimizer.step()
260 |         if (j+1)%5==0 or j==0:
261 |             acc,f1,TP_test,TN_test,FP_test,FN_test=test(device,model,testloader)
262 |             result.loc[j//5]=[loss_sum,acc,f1,TP_test,TN_test,FP_test,FN_test]
263 |     result.sort_values('Acc',inplace=True,ascending=False)
264 |     return result.iloc[0]['Acc']
265 | if __name__=='__main__':
266 |     setup_seed(seed)
267 |     train_site=test_site=np.load('DATAARRANGE/train_test_site.npy')
268 |     train_asd_dict=np.load('DATAARRANGE/train_asd_dict.npy',allow_pickle=True).item()
269 |     train_tdc_dict=np.load('DATAARRANGE/train_tdc_dict.npy',allow_pickle=True).item()
270 |     test_asd_dict=np.load('DATAARRANGE/test_asd_dict.npy',allow_pickle=True).item()
271 |     test_tdc_dict=np.load('DATAARRANGE/test_tdc_dict.npy',allow_pickle=True).item()
272 |     L1_lamda=0.0
273 |     L2_lamda=0.0001
274 |     learning_rate=0.0001
275 |     epoch =100
276 |     batch_size=64
277 |     gmma =1
278 |     layer =1
279 |     Acc_norm=np.zeros(10)
280 |     Acc_fgsm=np.zeros(10)
281 |     Acc_pgd =np.zeros(10)
282 |     for index in range(10):
283 |         start_t=time.time()
284 |         saveModel=False
285 |         max_acc=0.6
286 |         modelname='../SAVEDModels/PGDtrainedensamble/models_{}_{}'.format(0,index)
287 |         train_asd=train_asd_dict[index]
288 |         train_tdc=train_tdc_dict[index]
289 |         test_asd =test_asd_dict[index]
290 |         test_tdc =test_tdc_dict[index]
291 |         trainset=dataset(site=train_site,fmri_root=cpac_root,smri_root=smri_root,ASD=train_asd,TDC=train_tdc)
292 |         trainloader=DataLoader(trainset,batch_size=batch_size,shuffle=True)
293 |         testset=dataset(site=test_site,fmri_root=cpac_root,smri_root=smri_root,ASD=test_asd,TDC=test_tdc)
294 |         testloader=DataLoader(testset,batch_size=1)
295 |
296 |         # model=ANEGCN_fixed(args.layer).to(device)
297 |         # acc=train_norm(model,trainloader,testloader)
298 |         # if acc>=Acc_norm[index]:
299 |         #     Acc_norm[index]=acc
300 |
301 |         # model=ANEGCN_fixed(args.layer).to(device)
302 |         # acc=train_fgsm(model,trainloader,testloader,0.05)
303 |         # if acc>=Acc_fgsm[index]:
304 |         #     Acc_fgsm[index]=acc
305 |
306 |         model=NEGAN(args.layer).to(device)
307 |         acc=train_pgd(model,trainloader,testloader)
308 |         if acc>=Acc_pgd[index]:
309 |             Acc_pgd[index]=acc
310 |         end_t=time.time()
311 |         print('\r[%2d/10] Remaining: %.2f h Per fold: %.2f h'%(1+index,(9-index)*(end_t-start_t)/3600,(end_t-start_t)/3600),end='')
312 |         with open(args.filename,'a') as fileOut:
313 |             print('[%2d/10] Norm Acc:%.4f FGSM Acc:%.4f PGD Acc:%.4f'%(index,Acc_norm[index],Acc_fgsm[index],Acc_pgd[index]),file=fileOut)
314 |     with open(args.filename,'a') as fileOut:
315 |         print('Norm Acc:',np.mean(Acc_norm),'+',np.std(Acc_norm),'\n',
316 |               'FGSM Acc:',np.mean(Acc_fgsm),'+',np.std(Acc_fgsm),'\n',
317 |               'PGD Acc:',np.mean(Acc_pgd ),'+',np.std(Acc_pgd ),file=fileOut)
318 |
319 |
320 |
321 |
--------------------------------------------------------------------------------
/baseline.ipynb:
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import os\n", 11 | "import time\n", 12 | "import matplotlib\n", 13 | "import matplotlib.pyplot as plt\n", 14 | "import random\n", 15 | "import pandas as pd\n", 16 | "import warnings \n", 17 | "warnings.filterwarnings(\"ignore\")\n", 18 | "from tqdm.notebook import tqdm\n", 19 | "cpac_root='/media/dm/0001A094000BF891/Yazid/ABIDEI_CPAC/'\n", 20 | "smri_root='/media/dm/0001A094000BF891/Yazid/ABIDEI_sMRI/'\n", 21 | "sites=np.load('DATAARRANGE/train_test_site.npy')\n", 22 | "train_asd_dict=np.load('DATAARRANGE/train_asd_dict.npy',allow_pickle=True).item()\n", 23 | "train_tdc_dict=np.load('DATAARRANGE/train_tdc_dict.npy',allow_pickle=True).item()\n", 24 | "test_asd_dict=np.load('DATAARRANGE/test_asd_dict.npy',allow_pickle=True).item()\n", 25 | "test_tdc_dict=np.load('DATAARRANGE/test_tdc_dict.npy',allow_pickle=True).item()\n", 26 | "from sklearn.metrics import accuracy_score\n", 27 | "from sklearn.metrics import confusion_matrix" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "### DATA" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 2, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "def normalized(X):\n", 44 | " return (X-X.mean())/X.std()\n", 45 | "def get_data(cpac_root,smri_root,sites,asd_list,tdc_list):\n", 46 | " x=np.zeros((13804))\n", 47 | " y=np.zeros((1))\n", 48 | " for index in range(len(sites)):\n", 49 | " site=sites[index]\n", 50 | " asdlist=asd_list[index]\n", 51 | " tdclist=tdc_list[index]\n", 52 | " \n", 53 | " slow5_asd=cpac_root+site+'/group1_slow5/'\n", 54 | " slow4_asd=cpac_root+site+'/group1_slow4/'\n", 55 | " voxel_asd=smri_root+site+'/group1/'\n", 56 | " fc_asd =cpac_root+site+'/group1_FC/'\n", 57 | " for file in asdlist:\n", 58 | " fc_data=np.load(fc_asd+file,allow_pickle=True).flatten()\n", 59 | " s5_data=normalized(np.load(slow5_asd+file,allow_pickle=True))\n", 60 | " s4_data=normalized(np.load(slow4_asd+file,allow_pickle=True))\n", 61 | " vl_data=normalized(np.load(voxel_asd+file,allow_pickle=True))\n", 62 | " data=np.concatenate((fc_data, s5_data), axis=0)\n", 63 | " data=np.concatenate((data,s4_data),axis=0)\n", 64 | " data=np.concatenate((data,vl_data),axis=0)\n", 65 | " x=np.row_stack((x,data))\n", 66 | " y=np.row_stack((y,np.array([1])))\n", 67 | " \n", 68 | " slow5_tdc=cpac_root+site+'/group2_slow5/'\n", 69 | " slow4_tdc=cpac_root+site+'/group2_slow4/'\n", 70 | " voxel_tdc=smri_root+site+'/group2/'\n", 71 | " fc_tdc =cpac_root+site+'/group2_FC/'\n", 72 | " for file in tdclist:\n", 73 | " fc_data=np.load(fc_tdc+file,allow_pickle=True).flatten()\n", 74 | " s5_data=normalized(np.load(slow5_tdc+file,allow_pickle=True))\n", 75 | " s4_data=normalized(np.load(slow4_tdc+file,allow_pickle=True))\n", 76 | " vl_data=normalized(np.load(voxel_tdc+file,allow_pickle=True))\n", 77 | " data=np.concatenate((fc_data, s5_data), axis=0)\n", 78 | " data=np.concatenate((data,s4_data),axis=0)\n", 79 | " data=np.concatenate((data,vl_data),axis=0)\n", 80 | " x=np.row_stack((x,data))\n", 81 | " y=np.row_stack((y,np.array([2])))\n", 82 | " return x[1:,:],y[1:,:] " 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 3, 88 | "metadata": {}, 89 | "outputs": [ 90 | { 91 | "name": "stdout", 92 | "output_type": "stream", 93 | "text": [ 94 | 
"SVM\tAcc: 0.6904392721698117\n", 95 | "TN:295 | TP:398 | FP:186 FN:128\n" 96 | ] 97 | } 98 | ], 99 | "source": [ 100 | "from sklearn import svm\n", 101 | "#L_2regularization parameter: 1, RBF kernel and kernel coefficient for RBF used 1 / (n_features * X.var()) as value\n", 102 | "acc=0\n", 103 | "TN,FP,FN,TP=0,0,0,0\n", 104 | "for index in range(10):\n", 105 | " train_x,train_y=get_data(cpac_root,smri_root,sites,train_asd_dict[index],train_tdc_dict[index])\n", 106 | " test_x,test_y=get_data(cpac_root,smri_root,sites,test_asd_dict[index],test_tdc_dict[index])\n", 107 | " clf = svm.SVC()\n", 108 | " clf.fit(train_x,train_y)\n", 109 | " pred_y=clf.predict(test_x)\n", 110 | " acc+=accuracy_score(pred_y,test_y)\n", 111 | " tn, fp, fn, tp = confusion_matrix(test_y,pred_y).ravel()\n", 112 | " TN+=tn\n", 113 | " FP+=fp\n", 114 | " FN+=fn\n", 115 | " TP+=tp\n", 116 | "print('SVM\\tAcc: '+str(acc/10))\n", 117 | "print('TN:%d | TP:%d | FP:%d FN:%d'%(TN,TP,FP,FN))" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 4, 123 | "metadata": {}, 124 | "outputs": [ 125 | { 126 | "data": { 127 | "text/plain": [ 128 | "SVC()" 129 | ] 130 | }, 131 | "execution_count": 4, 132 | "metadata": {}, 133 | "output_type": "execute_result" 134 | } 135 | ], 136 | "source": [ 137 | "clf" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 4, 143 | "metadata": {}, 144 | "outputs": [ 145 | { 146 | "name": "stdout", 147 | "output_type": "stream", 148 | "text": [ 149 | "RandomForestClassifier\tAcc: 0.6453366298626709\n", 150 | "TN:272 | TP:376 | FP:209 FN:150\n" 151 | ] 152 | } 153 | ], 154 | "source": [ 155 | "from sklearn.ensemble import RandomForestClassifier\n", 156 | "#(200 trees, ‘gini’ function to measure the quality of a split)\n", 157 | "acc=0\n", 158 | "TN,FP,FN,TP=0,0,0,0\n", 159 | "for index in range(10):\n", 160 | " train_x,train_y=get_data(cpac_root,smri_root,sites,train_asd_dict[index],train_tdc_dict[index])\n", 161 | " test_x,test_y=get_data(cpac_root,smri_root,sites,test_asd_dict[index],test_tdc_dict[index])\n", 162 | " clf = RandomForestClassifier(n_estimators=200)\n", 163 | " clf.fit(train_x,train_y)\n", 164 | " pred_y=clf.predict(test_x)\n", 165 | " acc+=accuracy_score(pred_y,test_y)\n", 166 | " tn, fp, fn, tp = confusion_matrix(test_y,pred_y).ravel()\n", 167 | " TN+=tn\n", 168 | " FP+=fp\n", 169 | " FN+=fn\n", 170 | " TP+=tp\n", 171 | "print('RandomForestClassifier\\tAcc: '+str(acc/10))\n", 172 | "print('TN:%d | TP:%d | FP:%d FN:%d'%(TN,TP,FP,FN))" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 5, 178 | "metadata": {}, 179 | "outputs": [ 180 | { 181 | "name": "stdout", 182 | "output_type": "stream", 183 | "text": [ 184 | "GradientBoostingClassifier\tAcc: 0.667999858771603\n", 185 | "TN:299 | TP:372 | FP:182 FN:154\n" 186 | ] 187 | } 188 | ], 189 | "source": [ 190 | "from sklearn.ensemble import GradientBoostingClassifier\n", 191 | "#(the DecisionTreeClassifier as the base estimator, the maximum number of estimators was 200, learning rate was set to 1 and used SAMME.R as the real boosting algorithm)\n", 192 | "acc=0\n", 193 | "TN,FP,FN,TP=0,0,0,0\n", 194 | "for index in range(10):\n", 195 | " train_x,train_y=get_data(cpac_root,smri_root,sites,train_asd_dict[index],train_tdc_dict[index])\n", 196 | " test_x,test_y=get_data(cpac_root,smri_root,sites,test_asd_dict[index],test_tdc_dict[index])\n", 197 | " clf = GradientBoostingClassifier(n_estimators=200) \n", 198 | " clf.fit(train_x,train_y)\n", 199 | " 
pred_y=clf.predict(test_x)\n", 200 | " acc+=accuracy_score(pred_y,test_y)\n", 201 | " tn, fp, fn, tp = confusion_matrix(test_y,pred_y).ravel()\n", 202 | " TN+=tn\n", 203 | " FP+=fp\n", 204 | " FN+=fn\n", 205 | " TP+=tp\n", 206 | "print('GradientBoostingClassifier\\tAcc: '+str(acc/10))\n", 207 | "print('TN:%d | TP:%d | FP:%d FN:%d'%(TN,TP,FP,FN))" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": 6, 213 | "metadata": {}, 214 | "outputs": [ 215 | { 216 | "ename": "ValueError", 217 | "evalue": "Negative values in data passed to MultinomialNB (input X)", 218 | "output_type": "error", 219 | "traceback": [ 220 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 221 | "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", 222 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mtest_x\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtest_y\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mget_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcpac_root\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0msmri_root\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0msites\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtest_asd_dict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtest_tdc_dict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0mclf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mMultinomialNB\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0malpha\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.01\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_x\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtrain_y\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 8\u001b[0m \u001b[0mpred_y\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_x\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0macc\u001b[0m\u001b[0;34m+=\u001b[0m\u001b[0maccuracy_score\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpred_y\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtest_y\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 223 | "\u001b[0;32m~/anaconda3/envs/pytorch/lib/python3.8/site-packages/sklearn/naive_bayes.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, sample_weight)\u001b[0m\n\u001b[1;32m 636\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 637\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_init_counters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn_effective_classes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 638\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_count\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mY\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 639\u001b[0m \u001b[0malpha\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_check_alpha\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 640\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_update_feature_log_prob\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0malpha\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 224 | "\u001b[0;32m~/anaconda3/envs/pytorch/lib/python3.8/site-packages/sklearn/naive_bayes.py\u001b[0m in \u001b[0;36m_count\u001b[0;34m(self, X, Y)\u001b[0m\n\u001b[1;32m 769\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_count\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mY\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 770\u001b[0m \u001b[0;34m\"\"\"Count and smooth feature occurrences.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 771\u001b[0;31m \u001b[0mcheck_non_negative\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"MultinomialNB (input X)\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 772\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfeature_count_\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0msafe_sparse_dot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mY\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mT\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 773\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclass_count_\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mY\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 225 | "\u001b[0;32m~/anaconda3/envs/pytorch/lib/python3.8/site-packages/sklearn/utils/validation.py\u001b[0m in \u001b[0;36mcheck_non_negative\u001b[0;34m(X, whom)\u001b[0m\n\u001b[1;32m 1123\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1124\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mX_min\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1125\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Negative values in data passed to %s\"\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mwhom\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1126\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1127\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 226 | "\u001b[0;31mValueError\u001b[0m: Negative values in data passed to MultinomialNB (input X)" 227 | ] 228 | } 229 | ], 230 | "source": [ 231 | "from sklearn.naive_bayes import MultinomialNB\n", 232 | "acc=0\n", 233 | "for index in range(10):\n", 234 | " train_x,train_y=get_data(cpac_root,smri_root,sites,train_asd_dict[index],train_tdc_dict[index])\n", 235 | " test_x,test_y=get_data(cpac_root,smri_root,sites,test_asd_dict[index],test_tdc_dict[index])\n", 236 | " clf = MultinomialNB(alpha=0.01)\n", 237 | " clf.fit(train_x,train_y)\n", 238 | " pred_y=clf.predict(test_x)\n", 239 | " acc+=accuracy_score(pred_y,test_y) \n", 240 | "print('MultinomialNB\\tAcc: '+str(acc/10))" 241 | ] 242 | } 243 | ], 244 | 
"metadata": { 245 | "kernelspec": { 246 | "display_name": "Pytorch", 247 | "language": "python", 248 | "name": "pytorch" 249 | }, 250 | "language_info": { 251 | "codemirror_mode": { 252 | "name": "ipython", 253 | "version": 3 254 | }, 255 | "file_extension": ".py", 256 | "mimetype": "text/x-python", 257 | "name": "python", 258 | "nbconvert_exporter": "python", 259 | "pygments_lexer": "ipython3", 260 | "version": "3.6.13" 261 | }, 262 | "toc-autonumbering": false, 263 | "toc-showcode": true, 264 | "toc-showmarkdowntxt": false, 265 | "toc-showtags": false 266 | }, 267 | "nbformat": 4, 268 | "nbformat_minor": 4 269 | } 270 | -------------------------------------------------------------------------------- /BrainVisial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch.nn as nn \n", 11 | "import numpy as np\n", 12 | "from torch.utils.data import Dataset\n", 13 | "from torch.utils.data import DataLoader\n", 14 | "import torch.nn.functional as F\n", 15 | "import os\n", 16 | "import time\n", 17 | "import matplotlib\n", 18 | "import matplotlib.pyplot as plt\n", 19 | "import random\n", 20 | "import pandas as pd\n", 21 | "import torch.nn.functional as F\n", 22 | "device=torch.device('cpu')\n", 23 | "cpac_root='/media/dm/0001A094000BF891/Yazid/ABIDEI_CPAC/'\n", 24 | "smri_root='/media/dm/0001A094000BF891/Yazid/ABIDEI_sMRI/'\n", 25 | "nan_subid=np.load('nan_subid.npy').tolist()\n", 26 | "aal=np.load('Atlas/AAL.npy').tolist()\n", 27 | "Lobe_aal=np.load('SA-result/Lobe_aal.npy',allow_pickle=True).tolist()\n", 28 | "Lobe={}\n", 29 | "Lobe['Central']=[1,2,57,58,17,18]\n", 30 | "Lobe['Frontal']=[3,4,5,6,7,8,9,10,11,12,13,14,15,16,19,20,21,22,23,24,25,26,27,28,69,70]\n", 31 | "Lobe['Temporal']=[79,80,81,82,85,86,89,90]\n", 32 | "Lobe['Parietal']=[59,60,61,62,63,64,65,66,67,68]\n", 33 | "Lobe['Occipital']=[43,44,45,46,47,48,49,50,51,52,53,54,55,56]\n", 34 | "Lobe['Limbic']=[31,32,33,34,35,36,37,38,39,40,83,84,87,88]\n", 35 | "Lobe['Insula']=[29,30]\n", 36 | "Lobe['Subcortical']=[41,42,71,72,73,74,75,76,77,78]\n", 37 | "Lobe['Cerebelum']=[91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108]\n", 38 | "Lobe['Vermis']=[109,110,111,112,113,114,115,116]" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": { 45 | "tags": [] 46 | }, 47 | "outputs": [], 48 | "source": [ 49 | "class Attention(nn.Module):\n", 50 | " def __init__(self):\n", 51 | " super(Attention,self).__init__()\n", 52 | " self.conv1=nn.Conv1d(in_channels=3,out_channels=3,kernel_size=1,padding=0)\n", 53 | " self.conv2=nn.Conv1d(in_channels=116,out_channels=116,kernel_size=1,padding=0)\n", 54 | " self.softmax=nn.Softmax(dim=-1)\n", 55 | " def forward(self,Z,X):\n", 56 | " batchsize,x_dim,x_c= X.size()\n", 57 | " batchsize,z_dim,z_c= Z.size()\n", 58 | " K=self.conv1(X.permute(0,2,1))# BS,x_c,x_dim\n", 59 | " Q=K.permute(0,2,1)# BS,x_dim,x_c\n", 60 | " V=self.conv2(Z.permute(0,2,1))# Bs,z_c,z_dim\n", 61 | " attention=self.softmax(torch.matmul(Q,K))#BS,x_dim,x_dim\n", 62 | " out=torch.bmm(attention,V).permute(0,2,1)#BS,z_dim,z_c\n", 63 | " return out\n", 64 | "class NEResGCN(nn.Module):\n", 65 | " def __init__(self,layer):\n", 66 | " super(NEResGCN,self).__init__()\n", 67 | " self.layer =layer\n", 68 | " self.relu =nn.ReLU()\n", 69 | " self.atten =nn.ModuleList([Attention() for i in 
range(layer)])\n", 70 | " self.norm_n=nn.ModuleList([nn.BatchNorm1d(116) for i in range(layer)])\n", 71 | " self.norm_e=nn.ModuleList([nn.BatchNorm1d(116) for i in range(layer)])\n", 72 | " self.node_w=nn.ParameterList([nn.Parameter(torch.randn((3,3),dtype=torch.float32)) for i in range(layer)])\n", 73 | " self.edge_w=nn.ParameterList([nn.Parameter(torch.randn((116,116),dtype=torch.float32)) for i in range(layer)])\n", 74 | " self.line_n=nn.ModuleList([nn.Sequential(nn.Linear(116*3,128),nn.ReLU(),nn.BatchNorm1d(128)) for i in range(layer+1)])\n", 75 | " self.line_e=nn.ModuleList([nn.Sequential(nn.Linear(116*116,128*3),nn.ReLU(),nn.BatchNorm1d(128*3)) for i in range(layer+1)])\n", 76 | " self.clase =nn.Sequential(nn.Linear(128*4*(self.layer+1),1024),nn.Dropout(0.2),nn.ReLU(),\n", 77 | " nn.Linear(1024,2))\n", 78 | " self.ones=nn.Parameter(torch.ones((116),dtype=torch.float32),requires_grad=False)\n", 79 | " self._initialize_weights()\n", 80 | " # params initialization\n", 81 | " def _initialize_weights(self):\n", 82 | " for m in self.modules():\n", 83 | " if isinstance(m, (nn.Conv1d,nn.Linear)):\n", 84 | " nn.init.xavier_uniform_(m.weight)\n", 85 | " if m.bias is not None:\n", 86 | " nn.init.constant_(m.bias, 0)\n", 87 | " elif isinstance(m, nn.BatchNorm1d):\n", 88 | " nn.init.constant_(m.weight, 1)\n", 89 | " nn.init.constant_(m.bias, 0)\n", 90 | " def normalized(self,Z):\n", 91 | " n=Z.size()[0]\n", 92 | " A=Z[0,:,:]\n", 93 | " A=A+torch.diag(self.ones)\n", 94 | " d=A.sum(1)\n", 95 | " D=torch.diag(torch.pow(d,-1))\n", 96 | " A=D.mm(A).reshape(1,116,116)\n", 97 | " for i in range(1,n):\n", 98 | " A1=Z[i,:,:]+torch.diag(self.ones)\n", 99 | " d=A1.sum(1)\n", 100 | " D=torch.diag(torch.pow(d,-1))\n", 101 | " A1=D.mm(A1).reshape(1,116,116)\n", 102 | " A=torch.cat((A,A1),0)\n", 103 | " return A\n", 104 | " \n", 105 | " def update_A(self,Z):\n", 106 | " n=Z.size()[0]\n", 107 | " A=Z[0,:,:]\n", 108 | " Value,_=torch.topk(torch.abs(A.view(-1)),int(116*116*0.2))\n", 109 | " A=(torch.abs(A)>=Value[-1])+torch.tensor(0,dtype=torch.float32)\n", 110 | " A=A.reshape(1,116,116)\n", 111 | " for i in range(1,n):\n", 112 | " A2=Z[i,:,:]\n", 113 | " Value,_=torch.topk(torch.abs(A2.view(-1)),int(116*116*0.2))\n", 114 | " A2=(torch.abs(A2)>=Value[-1])+torch.tensor(0,dtype=torch.float32)\n", 115 | " A2=A2.reshape(1,116,116)\n", 116 | " A=torch.cat((A,A2),0)\n", 117 | " return A\n", 118 | " \n", 119 | " def forward(self,X,Z):\n", 120 | " n=X.size()[0]\n", 121 | " XX=self.line_n[0](X.view(n,-1))\n", 122 | " ZZ=self.line_e[0](Z.view(n,-1))\n", 123 | " for i in range(self.layer):\n", 124 | " A=self.atten[i](Z,X)\n", 125 | " Z1=torch.matmul(A,Z)\n", 126 | " Z2=torch.matmul(Z1,self.edge_w[i])\n", 127 | " Z=self.relu(self.norm_e[i](Z2))+Z\n", 128 | " #Z.register_hook(grad_Z_hook)\n", 129 | " #feat_Z_hook(Z)\n", 130 | " ZZ=torch.cat((ZZ,self.line_e[i+1](Z.view(n,-1))),dim=1)\n", 131 | " X1=torch.matmul(A,X)\n", 132 | " X1=torch.matmul(X1,self.node_w[i])\n", 133 | " X=self.relu(self.norm_n[i](X1))+X\n", 134 | " #X.register_hook(grad_X_hook)\n", 135 | " #feat_X_hook(X)\n", 136 | " XX=torch.cat((XX,self.line_n[i+1](X.view(n,-1))),dim=1)\n", 137 | " XZ=torch.cat((XX,ZZ),1)\n", 138 | " y=self.clase(XZ)\n", 139 | " #print(self.clase[0].weight)\n", 140 | " return y\n", 141 | "# def grad_X_hook(grad):\n", 142 | "# X_grad.append(grad)\n", 143 | "# def feat_X_hook(X):\n", 144 | "# X_feat.append(X.detach())\n", 145 | "# def grad_Z_hook(grad):\n", 146 | "# Z_grad.append(grad)\n", 147 | "# def feat_Z_hook(Z):\n", 148 | "# 
Z_feat.append(Z.detach())" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "metadata": { 155 | "tags": [] 156 | }, 157 | "outputs": [], 158 | "source": [ 159 | "def normalized(X):\n", 160 | " return (X-X.mean())/X.std()\n", 161 | "def gradient(device,model,dataloader):\n", 162 | " model.eval()\n", 163 | " result_asd_x=np.zeros((116,3))\n", 164 | " result_asd_z=np.zeros((116,116))\n", 165 | " result_tdc_x=np.zeros((116,3))\n", 166 | " result_tdc_z=np.zeros((116,116))\n", 167 | " for (X,Z,A,label,sub_id) in dataloader:\n", 168 | " model.zero_grad()\n", 169 | " X=torch.autograd.Variable(X,requires_grad=True)\n", 170 | " x=X.to(device)\n", 171 | " Z=torch.autograd.Variable(Z,requires_grad=True)\n", 172 | " z=Z.to(device)\n", 173 | " A=torch.autograd.Variable(A,requires_grad=True)\n", 174 | " a=A.to(device)\n", 175 | " y=model(x,z)\n", 176 | " if (label==torch.FloatTensor([0])).item():\n", 177 | " torch.autograd.backward(y,torch.FloatTensor([[1.,0.]]).to(device))\n", 178 | " else:\n", 179 | " torch.autograd.backward(y,torch.FloatTensor([[0.,1.]]).to(device))\n", 180 | " grad_X=X.grad.numpy()[0]\n", 181 | " grad_Z=Z.grad.numpy()[0]\n", 182 | " grad_X=normalized(grad_X)\n", 183 | " grad_Z=normalized(grad_Z)\n", 184 | " if label==torch.FloatTensor([0]).item():\n", 185 | " result_tdc_x+=grad_X\n", 186 | " result_tdc_z+=grad_Z\n", 187 | " else:\n", 188 | " result_asd_x+=grad_X\n", 189 | " result_asd_z+=grad_Z\n", 190 | " return result_asd_x,result_asd_z,result_tdc_x,result_tdc_z\n", 191 | "def grad_cam(grad,feat,top_rate):\n", 192 | " N=len(grad)//5\n", 193 | " n=grad[0].shape[2]\n", 194 | " result=torch.zeros((116,n))\n", 195 | " for i in range(N):\n", 196 | " weight=torch.zeros(5)\n", 197 | " for j in range(5):\n", 198 | " weight[j]=(grad[i*5+j]*(grad[i*5+j]>0)).sum()\n", 199 | " weight=F.softmax(weight, dim=0)\n", 200 | " feature=torch.zeros((116,n))\n", 201 | " for j in range(5):\n", 202 | " feature+=weight[j]*(grad[i*5+j][0]>0)*grad[i*5+j][0]\n", 203 | " value_x,_=torch.topk(torch.abs(feature.view(-1)),int(116*n*top_rate))\n", 204 | " result+=(torch.abs(feature)>=value_x[-1])\n", 205 | " return result " 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": null, 211 | "metadata": { 212 | "tags": [] 213 | }, 214 | "outputs": [], 215 | "source": [ 216 | "def data_arange(sites,fmri_root,smri_root,nan_subid):\n", 217 | " asd,tdc=[],[]\n", 218 | " for site in sites:\n", 219 | " mri_asd=os.listdir(smri_root+site+'/group1')\n", 220 | " mri_tdc=os.listdir(smri_root+site+'/group2')\n", 221 | " fmri_asd=os.listdir(fmri_root+site+'/group1_FC')\n", 222 | " fmri_tdc=os.listdir(fmri_root+site+'/group2_FC')\n", 223 | " site_asd=[i for i in mri_asd if i in fmri_asd ]\n", 224 | " site_tdc=[i for i in mri_tdc if i in fmri_tdc ]\n", 225 | " site_asd=[i for i in site_asd if int(i[:5]) not in nan_subid]\n", 226 | " site_tdc=[i for i in site_tdc if int(i[:5]) not in nan_subid]\n", 227 | " asd.append(site_asd)\n", 228 | " tdc.append(site_tdc) \n", 229 | " return asd,tdc\n", 230 | "class dataset(Dataset):\n", 231 | " def __init__(self,fmri_root,smri_root,site,ASD,TDC,topk=True,rate=0.2):\n", 232 | " super(dataset,self).__init__()\n", 233 | " self.fmri=fmri_root\n", 234 | " self.smri=smri_root\n", 235 | " self.ASD=[j for i in ASD for j in i]\n", 236 | " self.TDC=[j for i in TDC for j in i]\n", 237 | " self.data=self.ASD+self.TDC\n", 238 | " random.shuffle(self.data)\n", 239 | " self.data_site={}\n", 240 | " for i in range(len(site)):\n", 241 | " 
            data=ASD[i]+TDC[i]\n",
242 | "            for j in data:\n",
243 | "                if j not in self.data_site:\n",
244 | "                    self.data_site[j]=site[i] \n",
245 | "        self.rate=rate\n",
246 | "        self.topk=topk\n",
247 | "    def normalize(self,A):\n",
248 | "        d=A.sum(1)\n",
249 | "        D=torch.diag(torch.pow(d,-1))\n",
250 | "        return D.mm(A)\n",
251 | "    def __getitem__(self,index):\n",
252 | "        data=self.data[index]\n",
253 | "        sub_id=int(data[0:5])\n",
254 | "        if data in self.ASD:\n",
255 | "            data_slow5 =np.load(self.fmri+self.data_site[data]+'/group1_slow5/'+data,allow_pickle=True)\n",
256 | "            data_slow4 =np.load(self.fmri+self.data_site[data]+'/group1_slow4/'+data,allow_pickle=True)\n",
257 | "            data_voxel =np.load(self.smri+self.data_site[data]+'/group1/'+data,allow_pickle=True)\n",
258 | "            data_FCz =np.load(self.fmri+self.data_site[data]+'/group1_FC/'+data,allow_pickle=True)\n",
259 | "        elif data in self.TDC:\n",
260 | "            data_slow5 =np.load(self.fmri+self.data_site[data]+'/group2_slow5/'+data,allow_pickle=True)\n",
261 | "            data_slow4 =np.load(self.fmri+self.data_site[data]+'/group2_slow4/'+data,allow_pickle=True)\n",
262 | "            data_voxel =np.load(self.smri+self.data_site[data]+'/group2/'+data,allow_pickle=True)\n",
263 | "            data_FCz =np.load(self.fmri+self.data_site[data]+'/group2_FC/'+data,allow_pickle=True)\n",
264 | "        else:\n",
265 | "            raise ValueError('wrong input: '+data)\n",
266 | "        data_slow5=(data_slow5-np.min(data_slow5))/(np.max(data_slow5)-np.min(data_slow5))\n",
267 | "        data_slow4=(data_slow4-np.min(data_slow4))/(np.max(data_slow4)-np.min(data_slow4))\n",
268 | "        if np.any(np.isnan(data_slow5)) or np.any(np.isnan(data_slow4)) or np.any(np.isnan(data_FCz)):\n",
269 | "            print('data wrong')\n",
270 | "        #data_FCz=(data_FCz-np.min(data_FCz))/(np.max(data_FCz)-np.min(data_FCz))\n",
271 | "        if self.data[index] in self.ASD:\n",
272 | "            label=torch.tensor(1)\n",
273 | "        else:\n",
274 | "            label=torch.tensor(0)\n",
275 | "        X=np.zeros((116,3),dtype=np.float32)\n",
276 | "        X[:,0]=data_slow5\n",
277 | "        X[:,1]=data_slow4\n",
278 | "        X[:,2]=data_voxel\n",
279 | "        data_FCz=data_FCz.astype(np.float32)\n",
280 | "        Z=torch.from_numpy(data_FCz)\n",
281 | "        X=torch.from_numpy(X)\n",
282 | "        if self.topk:\n",
283 | "            Value,_=torch.topk(torch.abs(Z.view(-1)),int(116*116*self.rate))\n",
284 | "            A=(torch.abs(Z)>=Value[-1])+torch.tensor(0.0,dtype=torch.float32)\n",
285 | "            A=self.normalize(A)\n",
286 | "        return X,Z,A,label,sub_id\n",
287 | "    def __len__(self):\n",
288 | "        return len(self.data)"
289 | ]
290 | },
291 | {
292 | "cell_type": "code",
293 | "execution_count": null,
294 | "metadata": {},
295 | "outputs": [],
296 | "source": [
297 | "## gradient-based sensitivity analysis over the trained ensemble\n",
298 | "featuresMap=list()\n",
299 | "all_site=os.listdir(cpac_root)\n",
300 | "Result_ASD_X,Result_TDC_X=np.zeros((116,3)),np.zeros((116,3))\n",
301 | "Result_TDC_Z,Result_ASD_Z=np.zeros((116,116)),np.zeros((116,116))\n",
302 | "vote = 13\n",
303 | "for i in range(vote):\n",
304 | "    for index in range(10):\n",
305 | "        PATH='/media/dm/0001A094000BF891/Yazid/SAVEDModels/ensamble/models_{}_{}'.format(i,index)\n",
306 | "        model=NEResGCN(5)\n",
307 | "        model.load_state_dict(torch.load(PATH))\n",
308 | "        train_asd,train_tdc=data_arange(all_site,fmri_root=cpac_root,smri_root=smri_root,nan_subid=nan_subid)\n",
309 | "        trainset=dataset(site=all_site,fmri_root=cpac_root,smri_root=smri_root,ASD=train_asd,TDC=train_tdc)\n",
310 | "        trainloader=DataLoader(trainset,batch_size=1,shuffle=True,drop_last=True)\n",
311 | "        result_asd_x,result_asd_z,result_tdc_x,result_tdc_z=gradient(device,model,trainloader)\n",
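"        # accumulate the normalized input-gradients from all 13 ensemble repetitions x 10 folds before saving\n",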
312 | " Result_ASD_X+=result_asd_x\n", 313 | " Result_ASD_Z+=result_asd_z\n", 314 | " Result_TDC_X+=result_tdc_x\n", 315 | " Result_TDC_Z+=result_tdc_z\n", 316 | "np.save('SA-result/ensamble/adTrained_asd_x_{}.npy'.format(vote),Result_ASD_X)\n", 317 | "np.save('SA-result/ensamble/adTrained_asd_z_{}.npy'.format(vote),Result_ASD_Z)\n", 318 | "np.save('SA-result/ensamble/adTrained_tdc_x_{}.npy'.format(vote),Result_TDC_X)\n", 319 | "np.save('SA-result/ensamble/adTrained_tdc_z_{}.npy'.format(vote),Result_TDC_Z)" 320 | ] 321 | } 322 | ], 323 | "metadata": { 324 | "kernelspec": { 325 | "display_name": "Pytorch", 326 | "language": "python", 327 | "name": "pytorch" 328 | }, 329 | "language_info": { 330 | "codemirror_mode": { 331 | "name": "ipython", 332 | "version": 3 333 | }, 334 | "file_extension": ".py", 335 | "mimetype": "text/x-python", 336 | "name": "python", 337 | "nbconvert_exporter": "python", 338 | "pygments_lexer": "ipython3", 339 | "version": "3.6.13" 340 | } 341 | }, 342 | "nbformat": 4, 343 | "nbformat_minor": 5 344 | } 345 | -------------------------------------------------------------------------------- /Models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | class GraphNorm(nn.Module): 4 | def __init__(self,features): 5 | super(GraphNorm,self).__init__() 6 | self.weight = nn.Parameter(torch.randn(features)) 7 | self.bias = nn.Parameter(torch.randn(features)) 8 | self.alpha = nn.Parameter(torch.randn(features)) 9 | def forward(self,X): 10 | X=X.transpose(0,1) 11 | X=self.weight*(X-self.alpha*X.mean(0))/X.std(0) 12 | X=X.transpose(0,1)+self.bias 13 | return X 14 | 15 | class Attention(nn.Module): 16 | def __init__(self): 17 | super(Attention,self).__init__() 18 | self.conv1=nn.Conv1d(in_channels=3,out_channels=3,kernel_size=1,padding=0) 19 | self.conv2=nn.Conv1d(in_channels=116,out_channels=116,kernel_size=1,padding=0) 20 | self.softmax=nn.Softmax(dim=-1) 21 | def forward(self,Z,X): 22 | K=self.conv1(X.permute(0,2,1))# BS,x_c,x_dim 23 | Q=K.permute(0,2,1)# BS,x_dim,x_c 24 | V=self.conv2(Z.permute(0,2,1))# Bs,z_c,z_dim 25 | attention=self.softmax(torch.matmul(Q,K))#BS,x_dim,x_dim 26 | out=torch.bmm(attention,V).permute(0,2,1)#BS,z_dim,z_c 27 | return out 28 | 29 | class NEGAN(nn.Module): 30 | def __init__(self,layer): 31 | super(NEGAN,self).__init__() 32 | self.layer =layer 33 | self.relu =nn.ReLU() 34 | self.atten =nn.ModuleList([Attention() for i in range(layer)]) 35 | self.norm_n=nn.ModuleList([nn.BatchNorm1d(116) for i in range(layer)]) 36 | self.norm_e=nn.ModuleList([nn.BatchNorm1d(116) for i in range(layer)]) 37 | self.node_w=nn.ParameterList([nn.Parameter(torch.randn((3,3),dtype=torch.float32)) for i in range(layer)]) 38 | self.edge_w=nn.ParameterList([nn.Parameter(torch.randn((116,116),dtype=torch.float32)) for i in range(layer)]) 39 | self.line_n=nn.ModuleList([nn.Sequential(nn.Linear(116*3,128),nn.ReLU(),nn.BatchNorm1d(128)) for i in range(layer+1)]) 40 | self.line_e=nn.ModuleList([nn.Sequential(nn.Linear(116*116,128*3),nn.ReLU(),nn.BatchNorm1d(128*3)) for i in range(layer+1)]) 41 | self.clase =nn.Sequential(nn.Linear(128*4*(self.layer+1),1024),nn.Dropout(0.2),nn.ReLU(), 42 | nn.Linear(1024,2)) 43 | self.ones=nn.Parameter(torch.ones((116),dtype=torch.float32),requires_grad=False) 44 | self._initialize_weights() 45 | # params initialization 46 | def _initialize_weights(self): 47 | for m in self.modules(): 48 | if isinstance(m, (nn.Conv1d,nn.Linear)): 49 | nn.init.xavier_uniform_(m.weight) 50 | 
                if m.bias is not None:
51 |                     nn.init.constant_(m.bias, 0)
52 |             elif isinstance(m, nn.BatchNorm1d):
53 |                 nn.init.constant_(m.weight, 1)
54 |                 nn.init.constant_(m.bias, 0)
55 |     def normalized(self,Z):
56 |         n=Z.size()[0]
57 |         A=Z[0,:,:]
58 |         A=A+torch.diag(self.ones)
59 |         d=A.sum(1)
60 |         D=torch.diag(torch.pow(d,-1))
61 |         A=D.mm(A).reshape(1,116,116)
62 |         for i in range(1,n):
63 |             A1=Z[i,:,:]+torch.diag(self.ones)
64 |             d=A1.sum(1)
65 |             D=torch.diag(torch.pow(d,-1))
66 |             A1=D.mm(A1).reshape(1,116,116)
67 |             A=torch.cat((A,A1),0)
68 |         return A
69 |
70 |     def update_A(self,Z):
71 |         n=Z.size()[0]
72 |         A=Z[0,:,:]
73 |         Value,_=torch.topk(torch.abs(A.view(-1)),int(116*116*self.thed))
74 |         A=(torch.abs(A)>=Value[-1])+torch.tensor(0,dtype=torch.float32)
75 |         A=A.reshape(1,116,116)
76 |         for i in range(1,n):
77 |             A2=Z[i,:,:]
78 |             Value,_=torch.topk(torch.abs(A2.view(-1)),int(116*116*self.thed))
79 |             A2=(torch.abs(A2)>=Value[-1])+torch.tensor(0,dtype=torch.float32)
80 |             A2=A2.reshape(1,116,116)
81 |             A=torch.cat((A,A2),0)
82 |         return A
83 |
84 |     def forward(self,X,Z):
85 |         n=X.size()[0]
86 |         XX=self.line_n[0](X.view(n,-1))
87 |         ZZ=self.line_e[0](Z.view(n,-1))
88 |         for i in range(self.layer):
89 |             A=self.atten[i](Z,X)
90 |             Z1=torch.matmul(A,Z)
91 |             Z2=torch.matmul(Z1,self.edge_w[i])
92 |             Z=self.relu(self.norm_e[i](Z2))+Z
93 |             ZZ=torch.cat((ZZ,self.line_e[i+1](Z.view(n,-1))),dim=1)
94 |             X1=torch.matmul(A,X)
95 |             X1=torch.matmul(X1,self.node_w[i])
96 |             X=self.relu(self.norm_n[i](X1))+X
97 |             XX=torch.cat((XX,self.line_n[i+1](X.view(n,-1))),dim=1)
98 |         XZ=torch.cat((XX,ZZ),1)
99 |         y=self.clase(XZ)
100 |         return y
101 |
102 |
103 | class ANEGCN_fixed(nn.Module):
104 |     """Compared with the original model, only a single Attention layer is used; every later layer reuses the adjacency matrix produced by that first layer."""
105 |     def __init__(self,layer):
106 |         super(ANEGCN_fixed,self).__init__()
107 |         self.layer =layer
108 |         self.relu =nn.ReLU()
109 |         self.atten =Attention()
110 |         self.norm_n=nn.ModuleList([nn.BatchNorm1d(116) for i in range(layer)])
111 |         self.norm_e=nn.ModuleList([nn.BatchNorm1d(116) for i in range(layer)])
112 |         self.node_w=nn.ModuleList([nn.Linear(3,3) for i in range(layer)])
113 |         self.edge_w=nn.ModuleList([nn.Linear(116,116) for i in range(layer)])
114 |         self.line_n=nn.ModuleList([nn.Sequential(nn.Linear(116*3,128),nn.ReLU(),nn.BatchNorm1d(128)) for i in range(layer+1)])
115 |         self.line_e=nn.ModuleList([nn.Sequential(nn.Linear(116*116,128*3),nn.ReLU(),nn.BatchNorm1d(128*3)) for i in range(layer+1)])
116 |         self.clase =nn.Sequential(nn.Linear(128*4*(self.layer+1),1024),nn.Dropout(0.2),nn.ReLU(),
117 |                                   nn.Linear(1024,2))
118 |         self.ones=nn.Parameter(torch.ones((116),dtype=torch.float32),requires_grad=False)
119 |         self._initialize_weights()
120 |     # params initialization
121 |     def _initialize_weights(self):
122 |         for m in self.modules():
123 |             if isinstance(m, (nn.Conv1d,nn.Linear)):
124 |                 nn.init.xavier_uniform_(m.weight)
125 |                 if m.bias is not None:
126 |                     nn.init.constant_(m.bias, 0)
127 |             elif isinstance(m, nn.BatchNorm1d):
128 |                 nn.init.constant_(m.weight, 1)
129 |                 nn.init.constant_(m.bias, 0)
130 |     def normalized(self,Z):
131 |         n=Z.size()[0]
132 |         A=Z[0,:,:]
133 |         A=A+torch.diag(self.ones)
134 |         d=A.sum(1)
135 |         D=torch.diag(torch.pow(d,-1))
136 |         A=D.mm(A).reshape(1,116,116)
137 |         for i in range(1,n):
138 |             A1=Z[i,:,:]+torch.diag(self.ones)
139 |             d=A1.sum(1)
140 |             D=torch.diag(torch.pow(d,-1))
141 |             A1=D.mm(A1).reshape(1,116,116)
142 |             A=torch.cat((A,A1),0)
143 |         return A
144 |
145 |     def update_A(self,Z):
146 |         n=Z.size()[0]
147 |         A=Z[0,:,:]
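        # binarize the FC matrix: keep only the top `thed` fraction of entries by |value| as edges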
148 |         Value,_=torch.topk(torch.abs(A.view(-1)),int(116*116*self.thed))
149 |         A=(torch.abs(A)>=Value[-1])+torch.tensor(0,dtype=torch.float32)
150 |         A=A.reshape(1,116,116)
151 |         for i in range(1,n):
152 |             A2=Z[i,:,:]
153 |             Value,_=torch.topk(torch.abs(A2.view(-1)),int(116*116*self.thed))
154 |             A2=(torch.abs(A2)>=Value[-1])+torch.tensor(0,dtype=torch.float32)
155 |             A2=A2.reshape(1,116,116)
156 |             A=torch.cat((A,A2),0)
157 |         return A
158 |
159 |     def forward(self,X,Z):
160 |         n=X.size()[0]
161 |         XX=self.line_n[0](X.view(n,-1))
162 |         ZZ=self.line_e[0](Z.view(n,-1))
163 |         A=self.atten(Z,X)
164 |         for i in range(self.layer):
165 |             Z1=self.edge_w[i](torch.matmul(A,Z))
166 |             Z=self.relu(self.norm_e[i](Z1))+Z
167 |             ZZ=torch.cat((ZZ,self.line_e[i+1](Z.view(n,-1))),dim=1)
168 |             X1=self.node_w[i](torch.matmul(A,X))
169 |             X=self.relu(self.norm_n[i](X1))+X
170 |             #X.register_hook(grad_X_hook)
171 |             #feat_X_hook(X)
172 |             XX=torch.cat((XX,self.line_n[i+1](X.view(n,-1))),dim=1)
173 |         XZ=torch.cat((XX,ZZ),1)
174 |         y=self.clase(XZ)
175 |         #print(self.clase[0].weight)
176 |         return y
177 | class ANEGCN_noatt(nn.Module):
178 |     """Compared with the original model, the Attention module is removed; the adjacency matrix is instead determined purely by thresholding."""
179 |     def __init__(self,layer,thed):
180 |         super(ANEGCN_noatt,self).__init__()
181 |         self.layer =layer
182 |         self.thed = thed
183 |         self.relu =nn.ReLU()
184 |         self.norm_n=nn.ModuleList([nn.BatchNorm1d(116) for i in range(layer)])
185 |         self.norm_e=nn.ModuleList([nn.BatchNorm1d(116) for i in range(layer)])
186 |
187 |         # self.node_w=nn.ParameterList([nn.Parameter(torch.randn((3,3),dtype=torch.float32)) for i in range(layer)])
188 |         # self.edge_w=nn.ParameterList([nn.Parameter(torch.randn((116,116),dtype=torch.float32)) for i in range(layer)])
189 |         self.node_w=nn.ModuleList([nn.Linear(3,3) for i in range(layer)])
190 |         self.edge_w=nn.ModuleList([nn.Linear(116,116) for i in range(layer)])
191 |
192 |         self.line_n=nn.ModuleList([nn.Sequential(nn.Linear(116*3,128),nn.ReLU(),nn.BatchNorm1d(128)) for i in range(layer+1)])
193 |         self.line_e=nn.ModuleList([nn.Sequential(nn.Linear(116*116,128*3),nn.ReLU(),nn.BatchNorm1d(128*3)) for i in range(layer+1)])
194 |         self.clase =nn.Sequential(nn.Linear(128*4*(self.layer+1),1024),nn.Dropout(0.2),nn.ReLU(),
195 |                                   nn.Linear(1024,2))
196 |         self.ones=nn.Parameter(torch.ones((116),dtype=torch.float32),requires_grad=False)
197 |         self._initialize_weights()
198 |     # params initialization
199 |     def _initialize_weights(self):
200 |         for m in self.modules():
201 |             if isinstance(m, (nn.Conv1d,nn.Linear)):
202 |                 nn.init.xavier_uniform_(m.weight)
203 |                 if m.bias is not None:
204 |                     nn.init.constant_(m.bias, 0)
205 |             elif isinstance(m, nn.BatchNorm1d):
206 |                 nn.init.constant_(m.weight, 1)
207 |                 nn.init.constant_(m.bias, 0)
208 |     def normalized(self,Z):
209 |         n=Z.size()[0]
210 |         A=Z[0,:,:]
211 |         A=A+torch.diag(self.ones)
212 |         d=A.sum(1)
213 |         D=torch.diag(torch.pow(d,-1))
214 |         A=D.mm(A).reshape(1,116,116)
215 |         for i in range(1,n):
216 |             A1=Z[i,:,:]+torch.diag(self.ones)
217 |             d=A1.sum(1)
218 |             D=torch.diag(torch.pow(d,-1))
219 |             A1=D.mm(A1).reshape(1,116,116)
220 |             A=torch.cat((A,A1),0)
221 |         return A
222 |
223 |     def update_A(self,Z):
224 |         n=Z.size()[0]
225 |         A=Z[0,:,:]
226 |         Value,_=torch.topk(torch.abs(A.view(-1)),int(116*116*self.thed))
227 |         A=(torch.abs(A)>=Value[-1])+torch.tensor(0,dtype=torch.float32)
228 |         A=A.reshape(1,116,116)
229 |         for i in range(1,n):
230 |             A2=Z[i,:,:]
231 |             Value,_=torch.topk(torch.abs(A2.view(-1)),int(116*116*self.thed))
232 |             A2=(torch.abs(A2)>=Value[-1])+torch.tensor(0,dtype=torch.float32)
class ANEGCN_1(nn.Module):
    '''No longer uses the feed-forward blocks for down-sampling; plain linear layers are used instead.'''
    def __init__(self,layer):
        super(ANEGCN_1,self).__init__()
        self.layer =layer
        self.celu  =nn.ReLU()
        self.atten =nn.ModuleList([Attention() for i in range(layer)])
        self.norm_n=nn.ModuleList([nn.BatchNorm1d(116) for i in range(layer)])
        self.norm_e=nn.ModuleList([nn.BatchNorm1d(116) for i in range(layer)])
        self.node_w=nn.ModuleList([nn.Linear(3,3) for i in range(layer)])
        self.edge_w=nn.ModuleList([nn.Linear(116,116) for i in range(layer)])
        self.line_n=nn.ModuleList([nn.Sequential(nn.Linear(3,1),nn.ReLU(),nn.BatchNorm1d(116)) for i in range(layer+1)])
        self.line_e=nn.ModuleList([nn.Sequential(nn.Linear(116,3),nn.ReLU(),nn.BatchNorm1d(116)) for i in range(layer+1)])
        self.clase=nn.Sequential(nn.Linear(116*4*(self.layer+1),1024),nn.Dropout(0.2),nn.ReLU(),
                                 nn.Linear(1024,2))

    def forward(self,X,Z):
        # X: bs*N*n_features
        # Z: bs*N*N
        n=X.size()[0]
        XX=self.line_n[0](X).view(n,-1)
        ZZ=self.line_e[0](Z).view(n,-1)
        for i in range(self.layer):
            A=self.atten[i](Z,X)
            Z1=torch.matmul(A,Z)
            Z2=self.edge_w[i](Z1)
            Z=self.celu(self.norm_e[i](Z2))+Z
            ZZ=torch.cat((ZZ,self.line_e[i+1](Z).view(n,-1)),dim=1)
            X1=torch.matmul(A,X)
            X1=self.node_w[i](X1)
            X=self.celu(self.norm_n[i](X1))+X
            XX=torch.cat((XX,self.line_n[i+1](X).view(n,-1)),dim=1)
        XZ=torch.cat((XX,ZZ),1)
        y=self.clase(XZ)
        return y


class ConvDownSample(nn.Module):
    def __init__(self,in_feat):
        super(ConvDownSample,self).__init__()
        if in_feat==3:
            self.Conv=nn.Conv1d(1,1,3)             # node branch: 3 -> 1
        elif in_feat==116:
            self.Conv=nn.Conv1d(1,1,39*2,39,39)    # edge branch: 116 -> 3 (kernel 78, stride 39, padding 39)
        self.activeFunc=nn.ReLU()
        self.norm=nn.BatchNorm1d(1)
    def forward(self,X):
        h=self.norm(self.activeFunc(self.Conv(X.reshape(X.shape[0]*X.shape[1],1,X.shape[2]))))
        return h
class ANEGCN_2(nn.Module):
    '''Uses convolutional layers for down-sampling.'''
    def __init__(self,layer):
        super(ANEGCN_2,self).__init__()
        self.layer =layer
        self.celu  =nn.ReLU()
        self.atten =nn.ModuleList([Attention() for i in range(layer)])
        self.norm_n=nn.ModuleList([nn.BatchNorm1d(116) for i in range(layer)])
        self.norm_e=nn.ModuleList([nn.BatchNorm1d(116) for i in range(layer)])
        self.node_w=nn.ModuleList([nn.Linear(3,3) for i in range(layer)])
        self.edge_w=nn.ModuleList([nn.Linear(116,116) for i in range(layer)])
        self.line_n=nn.ModuleList([ConvDownSample(3) for i in range(layer+1)])
        self.line_e=nn.ModuleList([ConvDownSample(116) for i in range(layer+1)])
        self.clase=nn.Sequential(nn.Linear(116*4*(self.layer+1),1024),nn.Dropout(0.2),nn.ReLU(),
                                 nn.Linear(1024,2))

    def forward(self,X,Z):
        # X: bs*N*n_features
        # Z: bs*N*N
        n=X.size()[0]
        XX=self.line_n[0](X).view(n,-1)
        ZZ=self.line_e[0](Z).view(n,-1)
        for i in range(self.layer):
            A=self.atten[i](Z,X)
            Z1=self.edge_w[i](torch.matmul(A,Z))
            Z=self.celu(self.norm_e[i](Z1))+Z
            ZZ=torch.cat((ZZ,self.line_e[i+1](Z).view(n,-1)),dim=1)
            X1=self.node_w[i](torch.matmul(A,X))
            X=self.celu(self.norm_n[i](X1))+X
            XX=torch.cat((XX,self.line_n[i+1](X).view(n,-1)),dim=1)
        XZ=torch.cat((XX,ZZ),1)
        y=self.clase(XZ)
        return y
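# --- Illustrative aside (not part of the original file) -----------------------
# The Conv1d hyper-parameters above reproduce the widths of the pooling variants
# below: edge branch L_out = (116 + 2*39 - 78)//39 + 1 = 3, node branch 3 -> 1.
# A hypothetical smoke test:
def _check_conv_downsample_shapes(bs=2):
    x = torch.randn(bs, 116, 3)    # node features
    z = torch.randn(bs, 116, 116)  # edge features
    assert ConvDownSample(3)(x).shape == (bs * 116, 1, 1)
    assert ConvDownSample(116)(z).shape == (bs * 116, 1, 3)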
class AvgPoolDownSample(nn.Module):
    def __init__(self,in_feat):
        super(AvgPoolDownSample,self).__init__()
        if in_feat==3:
            self.pool=nn.AvgPool1d(3,1,0)      # node branch: 3 -> 1
        elif in_feat==116:
            self.pool=nn.AvgPool1d(40,38,0)    # edge branch: 116 -> 3
        self.activeFunc=nn.ReLU()
        self.norm=nn.BatchNorm1d(1)
    def forward(self,X):
        h=self.norm(self.activeFunc(self.pool(X.reshape(X.shape[0]*X.shape[1],1,X.shape[2]))))
        return h
class ANEGCN_3(nn.Module):
    '''Uses average-pooling layers for down-sampling.'''
    def __init__(self,layer):
        super(ANEGCN_3,self).__init__()
        self.layer =layer
        self.celu  =nn.ReLU()
        self.atten =nn.ModuleList([Attention() for i in range(layer)])
        self.norm_n=nn.ModuleList([nn.BatchNorm1d(116) for i in range(layer)])
        self.norm_e=nn.ModuleList([nn.BatchNorm1d(116) for i in range(layer)])
        self.node_w=nn.ModuleList([nn.Linear(3,3) for i in range(layer)])
        self.edge_w=nn.ModuleList([nn.Linear(116,116) for i in range(layer)])
        self.line_n=nn.ModuleList([AvgPoolDownSample(3) for i in range(layer+1)])
        self.line_e=nn.ModuleList([AvgPoolDownSample(116) for i in range(layer+1)])
        self.clase=nn.Sequential(nn.Linear(116*4*(self.layer+1),1024),nn.Dropout(0.2),nn.ReLU(),
                                 nn.Linear(1024,2))

    def forward(self,X,Z):
        # X: bs*N*n_features
        # Z: bs*N*N
        n=X.size()[0]
        XX=self.line_n[0](X).view(n,-1)
        ZZ=self.line_e[0](Z).view(n,-1)
        for i in range(self.layer):
            A=self.atten[i](Z,X)
            Z1=self.edge_w[i](torch.matmul(A,Z))
            Z=self.celu(self.norm_e[i](Z1))+Z
            ZZ=torch.cat((ZZ,self.line_e[i+1](Z).view(n,-1)),dim=1)
            X1=self.node_w[i](torch.matmul(A,X))
            X=self.celu(self.norm_n[i](X1))+X
            XX=torch.cat((XX,self.line_n[i+1](X).view(n,-1)),dim=1)
        XZ=torch.cat((XX,ZZ),1)
        y=self.clase(XZ)
        return y

class MaxPoolDownSample(nn.Module):
    def __init__(self,in_feat):
        super(MaxPoolDownSample,self).__init__()
        if in_feat==3:
            self.pool=nn.MaxPool1d(3,1,0)
        elif in_feat==116:
            self.pool=nn.MaxPool1d(40,38,0)
        self.activeFunc=nn.ReLU()
        self.norm=nn.BatchNorm1d(1)
    def forward(self,X):
        h=self.norm(self.activeFunc(self.pool(X.reshape(X.shape[0]*X.shape[1],1,X.shape[2]))))
        return h
class ANEGCN_4(nn.Module):
    '''Uses max-pooling layers for down-sampling.'''
    def __init__(self,layer):
        super(ANEGCN_4,self).__init__()
        self.layer =layer
        self.celu  =nn.ReLU()
        self.atten =nn.ModuleList([Attention() for i in range(layer)])
        self.norm_n=nn.ModuleList([nn.BatchNorm1d(116) for i in range(layer)])
        self.norm_e=nn.ModuleList([nn.BatchNorm1d(116) for i in range(layer)])
        self.node_w=nn.ModuleList([nn.Linear(3,3) for i in range(layer)])
        self.edge_w=nn.ModuleList([nn.Linear(116,116) for i in range(layer)])
        self.line_n=nn.ModuleList([MaxPoolDownSample(3) for
i in range(layer+1)]) 410 | self.line_e=nn.ModuleList([MaxPoolDownSample(116) for i in range(layer+1)]) 411 | self.clase=nn.Sequential(nn.Linear(116*4*(self.layer+1),1024),nn.Dropout(0.2),nn.ReLU(), 412 | nn.Linear(1024,2)) 413 | 414 | def forward(self,X,Z): 415 | # X: bs*N*n_features 416 | # Z: bs8N*N 417 | n=X.size()[0] 418 | XX=self.line_n[0](X).view(n,-1) 419 | ZZ=self.line_e[0](Z).view(n,-1) 420 | for i in range(self.layer): 421 | A=self.atten[i](Z,X) 422 | Z1=self.edge_w[i](torch.matmul(A,Z)) 423 | Z=self.celu(self.norm_e[i](Z1))+Z 424 | ZZ=torch.cat((ZZ,self.line_e[i+1](Z).view(n,-1)),dim=1) 425 | X1=self.node_w[i](torch.matmul(A,X)) 426 | X=self.celu(self.norm_n[i](X1))+X 427 | XX=torch.cat((XX,self.line_n[i+1](X).view(n,-1)),dim=1) 428 | XZ=torch.cat((XX,ZZ),1) 429 | y=self.clase(XZ) 430 | return y 431 | -------------------------------------------------------------------------------- /K-fold.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch.nn as nn \n", 11 | "import numpy as np\n", 12 | "from torch.utils.data import Dataset\n", 13 | "from torch.utils.data import DataLoader\n", 14 | "import os\n", 15 | "import time\n", 16 | "import random\n", 17 | "import pandas as pd\n", 18 | "import torch.nn.functional as F\n", 19 | "device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", 20 | "import warnings \n", 21 | "warnings.filterwarnings(\"ignore\")\n", 22 | "from tqdm import tqdm\n", 23 | "cpac_root='/media/D/yazid/ASD-classification-ANEGCN/ABIDEI_CPAC/'\n", 24 | "smri_root='/media/D/yazid/ASD-classification-ANEGCN/ABIDEI_sMRI/'\n", 25 | "nan_subid=np.load('nan_subid.npy').tolist()\n", 26 | "def setup_seed(seed):\n", 27 | " torch.manual_seed(seed)\n", 28 | " torch.cuda.manual_seed_all(seed)\n", 29 | " np.random.seed(seed)\n", 30 | " random.seed(seed)\n", 31 | " torch.backends.cudnn.deterministic = True" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 2, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "class Attention(nn.Module):\n", 41 | " def __init__(self):\n", 42 | " super(Attention,self).__init__()\n", 43 | " self.conv1=nn.Conv1d(in_channels=3,out_channels=3,kernel_size=1,padding=0)\n", 44 | " self.conv2=nn.Conv1d(in_channels=116,out_channels=116,kernel_size=1,padding=0)\n", 45 | " self.softmax=nn.Softmax(dim=-1)\n", 46 | " def forward(self,Z,X):\n", 47 | " batchsize,x_dim,x_c= X.size()\n", 48 | " batchsize,z_dim,z_c= Z.size()\n", 49 | " K=self.conv1(X.permute(0,2,1))# BS,x_c,x_dim\n", 50 | " Q=K.permute(0,2,1)# BS,x_dim,x_c\n", 51 | " V=self.conv2(Z.permute(0,2,1))# Bs,z_c,z_dim\n", 52 | " attention=self.softmax(torch.matmul(Q,K))#BS,x_dim,x_dim\n", 53 | " out=torch.bmm(attention,V).permute(0,2,1)#BS,z_dim,z_c\n", 54 | " return out\n", 55 | "class NEResGCN(nn.Module):\n", 56 | " def __init__(self,layer):\n", 57 | " super(NEResGCN,self).__init__()\n", 58 | " self.layer =layer\n", 59 | " self.relu =nn.ReLU()\n", 60 | " self.atten =nn.ModuleList([Attention() for i in range(layer)])\n", 61 | " self.norm_n=nn.ModuleList([nn.BatchNorm1d(116) for i in range(layer)])\n", 62 | " self.norm_e=nn.ModuleList([nn.BatchNorm1d(116) for i in range(layer)])\n", 63 | " self.node_w=nn.ParameterList([nn.Parameter(torch.randn((3,3),dtype=torch.float32)) for i in range(layer)])\n", 64 | " 
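# bias-free weight matrices rather than nn.Linear; forward multiplies X and Z by them directly via torch.matmul\n", "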
self.edge_w=nn.ParameterList([nn.Parameter(torch.randn((116,116),dtype=torch.float32)) for i in range(layer)])\n", 65 | " self.line_n=nn.ModuleList([nn.Sequential(nn.Linear(116*3,128),nn.ReLU(),nn.BatchNorm1d(128)) for i in range(layer+1)])\n", 66 | " self.line_e=nn.ModuleList([nn.Sequential(nn.Linear(116*116,128*3),nn.ReLU(),nn.BatchNorm1d(128*3)) for i in range(layer+1)])\n", 67 | " self.clase =nn.Sequential(nn.Linear(128*4*(self.layer+1),1024),nn.Dropout(0.2),nn.ReLU(),\n", 68 | " nn.Linear(1024,2))\n", 69 | " self.ones=nn.Parameter(torch.ones((116),dtype=torch.float32),requires_grad=False)\n", 70 | " self._initialize_weights()\n", 71 | " # params initialization\n", 72 | " def _initialize_weights(self):\n", 73 | " for m in self.modules():\n", 74 | " if isinstance(m, (nn.Conv1d,nn.Linear)):\n", 75 | " nn.init.xavier_uniform_(m.weight)\n", 76 | " if m.bias is not None:\n", 77 | " nn.init.constant_(m.bias, 0)\n", 78 | " elif isinstance(m, nn.BatchNorm1d):\n", 79 | " nn.init.constant_(m.weight, 1)\n", 80 | " nn.init.constant_(m.bias, 0)\n", 81 | " def normalized(self,Z):\n", 82 | " n=Z.size()[0]\n", 83 | " A=Z[0,:,:]\n", 84 | " A=A+torch.diag(self.ones)\n", 85 | " d=A.sum(1)\n", 86 | " D=torch.diag(torch.pow(d,-1))\n", 87 | " A=D.mm(A).reshape(1,116,116)\n", 88 | " for i in range(1,n):\n", 89 | " A1=Z[i,:,:]+torch.diag(self.ones)\n", 90 | " d=A1.sum(1)\n", 91 | " D=torch.diag(torch.pow(d,-1))\n", 92 | " A1=D.mm(A1).reshape(1,116,116)\n", 93 | " A=torch.cat((A,A1),0)\n", 94 | " return A\n", 95 | " \n", 96 | " def update_A(self,Z):\n", 97 | " n=Z.size()[0]\n", 98 | " A=Z[0,:,:]\n", 99 | " Value,_=torch.topk(torch.abs(A.view(-1)),int(116*116*0.2))\n", 100 | " A=(torch.abs(A)>=Value[-1])+torch.tensor(0,dtype=torch.float32)\n", 101 | " A=A.reshape(1,116,116)\n", 102 | " for i in range(1,n):\n", 103 | " A2=Z[i,:,:]\n", 104 | " Value,_=torch.topk(torch.abs(A2.view(-1)),int(116*116*0.2))\n", 105 | " A2=(torch.abs(A2)>=Value[-1])+torch.tensor(0,dtype=torch.float32)\n", 106 | " A2=A2.reshape(1,116,116)\n", 107 | " A=torch.cat((A,A2),0)\n", 108 | " return A\n", 109 | " \n", 110 | " def forward(self,X,Z):\n", 111 | " n=X.size()[0]\n", 112 | " XX=self.line_n[0](X.view(n,-1))\n", 113 | " ZZ=self.line_e[0](Z.view(n,-1))\n", 114 | " for i in range(self.layer):\n", 115 | " A=self.atten[i](Z,X)\n", 116 | " Z1=torch.matmul(A,Z)\n", 117 | " Z2=torch.matmul(Z1,self.edge_w[i])\n", 118 | " Z=self.relu(self.norm_e[i](Z2))+Z\n", 119 | " ZZ=torch.cat((ZZ,self.line_e[i+1](Z.view(n,-1))),dim=1)\n", 120 | " X1=torch.matmul(A,X)\n", 121 | " X1=torch.matmul(X1,self.node_w[i])\n", 122 | " X=self.relu(self.norm_n[i](X1))+X\n", 123 | " #X.register_hook(grad_X_hook)\n", 124 | " #feat_X_hook(X)\n", 125 | " XX=torch.cat((XX,self.line_n[i+1](X.view(n,-1))),dim=1)\n", 126 | " XZ=torch.cat((XX,ZZ),1)\n", 127 | " y=self.clase(XZ)\n", 128 | " #print(self.clase[0].weight)\n", 129 | " return y\n", 130 | "def grad_X_hook(grad):\n", 131 | " X_grad.append(grad)\n", 132 | "def feat_X_hook(X):\n", 133 | " X_feat.append(X.detach())\n", 134 | "X_grad=list()\n", 135 | "X_feat=list()" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 3, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "class LabelSmoothLoss(nn.Module):\n", 145 | " \n", 146 | " def __init__(self, smoothing=0.0):\n", 147 | " super(LabelSmoothLoss, self).__init__()\n", 148 | " self.smoothing = smoothing\n", 149 | " \n", 150 | " def forward(self, input, target):\n", 151 | " log_prob = F.log_softmax(input, dim=-1)\n", 152 | " weight = 
input.new_ones(input.size()) * \\\n", 153 | " self.smoothing / (input.size(-1) - 1.)\n", 154 | " weight.scatter_(-1, target.unsqueeze(-1), (1. - self.smoothing))\n", 155 | " loss = (-weight * log_prob).sum(dim=-1).mean()\n", 156 | " return loss\n", 157 | "def data_split(full_list, ratio, shuffle=True):\n", 158 | " \"\"\"\n", 159 | " 数据集拆分: 将列表full_list按比例ratio(随机)划分为2个子列表sublist_1与sublist_2\n", 160 | " :param full_list: 数据列表\n", 161 | " :param ratio: 子列表1\n", 162 | " :param shuffle: 子列表2\n", 163 | " :return:\n", 164 | " \"\"\"\n", 165 | " n_total = len(full_list)\n", 166 | " offset = int(n_total * ratio)\n", 167 | " if n_total == 0 or offset < 1:\n", 168 | " return [], full_list\n", 169 | " if shuffle:\n", 170 | " random.shuffle(full_list)\n", 171 | " sublist_1 = full_list[:offset]\n", 172 | " sublist_2 = full_list[offset:]\n", 173 | " return sublist_1, sublist_2\n", 174 | "def data_2_k(full_list,k,shuffle=True):\n", 175 | " n_total=len(full_list)\n", 176 | " if shuffle:\n", 177 | " random.shuffle(full_list)\n", 178 | " data_list_list=[]\n", 179 | " for i in range(k):\n", 180 | " data_list_list.append(full_list[int(i*n_total/k):int((i+1)*n_total/k)])\n", 181 | " return data_list_list\n", 182 | "def test(device,model,testloader):\n", 183 | " model.eval()\n", 184 | " TP_test,TN_test,FP_test,FN_test=0,0,0,0\n", 185 | " with torch.no_grad():\n", 186 | " for (X,Z,label,sub_id) in testloader:\n", 187 | " TP,TN,FN,FP=0,0,0,0\n", 188 | " n=X.size()[0]\n", 189 | " X=X.to(device)\n", 190 | " Z=Z.to(device)\n", 191 | " label=label.to(device)\n", 192 | " y=model(X,Z)\n", 193 | " _,predict=torch.max(y,1)\n", 194 | " TP+=((predict==1)&(label==1)).sum().item()\n", 195 | " TN+=((predict==0)&(label==0)).sum().item()\n", 196 | " FN+=((predict==0)&(label==1)).sum().item()\n", 197 | " FP+=((predict==1)&(label==0)).sum().item()\n", 198 | " TP_test+=TP\n", 199 | " TN_test+=TN\n", 200 | " FP_test+=FP\n", 201 | " FN_test+=FN\n", 202 | " acc,f1=cal_evaluate(TP_test,TN_test,FP_test,FN_test)\n", 203 | " global max_acc\n", 204 | " global modelname\n", 205 | " global savedModel\n", 206 | " if acc>=max_acc:\n", 207 | " max_acc=acc\n", 208 | " if saveModel:\n", 209 | " torch.save(model.state_dict(),modelname)\n", 210 | " ##read\n", 211 | " #model=NERESGCN(layer)\n", 212 | " #model.load_state_dict(torch.load(PATH))\n", 213 | " #print('Saved the model')\n", 214 | " #print('TEST: ACC:%.4f F1:%.4f [TP:%3d|TN:%3d|FP:%3d|FN:%3d]'%(acc,f1,TP_test,TN_test,FP_test,FN_test)) \n", 215 | " return acc,f1,TP_test,TN_test,FP_test,FN_test\n", 216 | "#计算边节点的字典\n", 217 | "def gradient(device,model,dataloader):\n", 218 | " model.eval()\n", 219 | " for (X,Z,A,label,sub_id) in dataloader:\n", 220 | " X=torch.autograd.Variable(X,requires_grad=True)\n", 221 | " x=X.to(device)\n", 222 | " Z=torch.autograd.Variable(Z,requires_grad=True)\n", 223 | " z=Z.to(device)\n", 224 | " A=torch.autograd.Variable(A,requires_grad=True)\n", 225 | " a=A.to(device)\n", 226 | " y=model(x,z,a)\n", 227 | " if (label==torch.FloatTensor([0])).item():\n", 228 | " print('0')\n", 229 | " #y.autograd.backward(torch.FloatTensor([[1.,0.]]).to(device))\n", 230 | " torch.autograd.backward(y,torch.FloatTensor([[1.,0.]]).to(device))\n", 231 | " else:\n", 232 | " print('1')\n", 233 | " torch.autograd.backward(y,torch.FloatTensor([[0.,1.]]).to(device))\n", 234 | " grad_X=X.grad\n", 235 | " grad_Z=Z.grad\n", 236 | " #print(grad_X)\n", 237 | " value_x,index_x=torch.topk(torch.abs(grad_X.view(-1)),10)\n", 238 | " grad_X_topk=(torch.abs(grad_X)>=value_x[-1])\n", 239 | " 
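# saliency masks: the 10 node-gradient entries above and the 100 edge-gradient entries below with the largest magnitude\n", "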
value_z,index_z=torch.topk(torch.abs(grad_Z.view(-1)),100)\n", 240 | " grad_Z_topk=(torch.abs(grad_Z)>=value_z[-1])\n", 241 | " global gradsave_dict\n", 242 | " if label==torch.FloatTensor([0]).item():\n", 243 | " np.save(gradsave_dict+'/TDC/Z/'+str(sub_id.item()),grad_Z.numpy())\n", 244 | " np.save(gradsave_dict+'/TDC/X/'+str(sub_id.item()),grad_X.numpy())\n", 245 | " else:\n", 246 | " np.save(gradsave_dict+'/ASD/Z/'+str(sub_id.item()),grad_Z.numpy())\n", 247 | " np.save(gradsave_dict+'/ASD/X/'+str(sub_id.item()),grad_X.numpy())\n", 248 | " \n", 249 | "def cal_dict():\n", 250 | " index=0\n", 251 | " A={}\n", 252 | " for i in range(116):\n", 253 | " for j in range(i+1,116):\n", 254 | " A[index]=(i,j)\n", 255 | " A[(i,j)]=index\n", 256 | " index+=1\n", 257 | " return A\n", 258 | "def cal_evaluate(TP,TN,FP,FN):\n", 259 | " if TP>0:\n", 260 | " p = TP / (TP + FP)\n", 261 | " r = TP / (TP + FN)\n", 262 | " F1 = 2 * r * p / (r + p)\n", 263 | " else:\n", 264 | " F1=0\n", 265 | " acc = (TP + TN) / (TP + TN + FP + FN)\n", 266 | " #print('ACC:%.4f F1:%.4f [TP:%d|TN:%d|FP:%d|FN:%d]'%(acc,F1,TP,TN,FP,FN))\n", 267 | " return acc,F1\n", 268 | "def data_arange(sites,fmri_root,smri_root,nan_subid):\n", 269 | " asd,tdc=[],[]\n", 270 | " for site in sites:\n", 271 | " mri_asd=os.listdir(smri_root+site+'/group1')\n", 272 | " mri_tdc=os.listdir(smri_root+site+'/group2')\n", 273 | " fmri_asd=os.listdir(fmri_root+site+'/group1_FC')\n", 274 | " fmri_tdc=os.listdir(fmri_root+site+'/group2_FC')\n", 275 | " site_asd=[i for i in mri_asd if i in fmri_asd ]\n", 276 | " site_tdc=[i for i in mri_tdc if i in fmri_tdc ]\n", 277 | " site_asd=[i for i in site_asd if int(i[:5]) not in nan_subid]\n", 278 | " site_tdc=[i for i in site_tdc if int(i[:5]) not in nan_subid]\n", 279 | " asd.append(site_asd)\n", 280 | " tdc.append(site_tdc)\n", 281 | " return asd,tdc\n", 282 | "class dataset(Dataset):\n", 283 | " def __init__(self,fmri_root,smri_root,site,ASD,TDC):\n", 284 | " super(dataset,self).__init__()\n", 285 | " self.fmri=fmri_root\n", 286 | " self.smri=smri_root\n", 287 | " self.ASD=[j for i in ASD for j in i]\n", 288 | " self.TDC=[j for i in TDC for j in i]\n", 289 | " self.data=self.ASD+self.TDC\n", 290 | " random.shuffle(self.data)\n", 291 | " self.data_site={}\n", 292 | " for i in range(len(site)):\n", 293 | " data=ASD[i]+TDC[i]\n", 294 | " for j in data:\n", 295 | " if j not in self.data_site:\n", 296 | " self.data_site[j]=site[i] \n", 297 | " def __getitem__(self,index):\n", 298 | " data=self.data[index]\n", 299 | " sub_id=int(data[0:5])\n", 300 | " if data in self.ASD:\n", 301 | " data_slow5 =np.load(self.fmri+self.data_site[data]+'/group1_slow5/'+data,allow_pickle=True)\n", 302 | " data_slow4 =np.load(self.fmri+self.data_site[data]+'/group1_slow4/'+data,allow_pickle=True)\n", 303 | " data_voxel =np.load(self.smri+self.data_site[data]+'/group1/'+data,allow_pickle=True)\n", 304 | " data_FCz =np.load(self.fmri+self.data_site[data]+'/group1_FC/'+data,allow_pickle=True)\n", 305 | " elif data in self.TDC:\n", 306 | " data_slow5 =np.load(self.fmri+self.data_site[data]+'/group2_slow5/'+data,allow_pickle=True)\n", 307 | " data_slow4 =np.load(self.fmri+self.data_site[data]+'/group2_slow4/'+data,allow_pickle=True)\n", 308 | " data_voxel =np.load(self.smri+self.data_site[data]+'/group2/'+data,allow_pickle=True)\n", 309 | " data_FCz =np.load(self.fmri+self.data_site[data]+'/group2_FC/'+data,allow_pickle=True)\n", 310 | " else:\n", 311 | " print('wrong input')\n", 312 | " 
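# min-max normalise the slow5/slow4 band features to [0,1]; the FC matrix is used as-is (its rescaling is commented out)\n", "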
data_slow5=(data_slow5-np.min(data_slow5))/(np.max(data_slow5)-np.min(data_slow5))\n", 313 | " data_slow4=(data_slow4-np.min(data_slow4))/(np.max(data_slow4)-np.min(data_slow4))\n", 314 | " if np.any(np.isnan(data_slow5)) or np.any(np.isnan(data_slow4)) or np.any(np.isnan(data_FCz)):\n", 315 | " print('data wronmg')\n", 316 | " #data_FCz=(data_FCz-np.min(data_FCz))/(np.max(data_FCz)-np.min(data_FCz))\n", 317 | " if self.data[index] in self.ASD:\n", 318 | " label=torch.tensor(1)\n", 319 | " else:\n", 320 | " label=torch.tensor(0)\n", 321 | " X=np.zeros((116,3),dtype=np.float32)\n", 322 | " X[:,0]=data_slow5\n", 323 | " X[:,1]=data_slow4\n", 324 | " X[:,2]=data_voxel\n", 325 | " data_FCz=data_FCz.astype(np.float32)\n", 326 | " Z=torch.from_numpy(data_FCz)\n", 327 | " X=torch.from_numpy(X)\n", 328 | " return X,Z,label,sub_id\n", 329 | " def __len__(self):\n", 330 | " return len(self.data)\n", 331 | "def get_acc(acc_list,toprate):\n", 332 | " acc_list.sort()\n", 333 | " return acc_list[-int(toprate*len(acc_list))]\n", 334 | "def plot_acc(acc_list,loss_list):\n", 335 | " num_bins=50\n", 336 | " fig,ax=plt.subplots()\n", 337 | " #n,bins,patches=ax[0].hist(acc_list,num_bins,density=True)\n", 338 | " x=np.arange(0,len(acc_list))\n", 339 | " ax2=ax.twinx()\n", 340 | " ax.plot(x,acc_list,'b')\n", 341 | " ax.set_ylim(0.4,1)\n", 342 | " ax2.plot(x,loss_list,'r')\n", 343 | " plt.show()\n", 344 | "def train_fgsm(model,trainloader,testloader,epsilon):\n", 345 | " result=pd.DataFrame(columns=('Loss','Acc','F1','TP','TN','FP','FN'))\n", 346 | " criterian1=LabelSmoothLoss(0.1).to(device)\n", 347 | " optimizer=torch.optim.Adam(model.parameters(),lr=learning_rate)\n", 348 | " scheduler=torch.optim.lr_scheduler.ExponentialLR(optimizer,gamma=gmma)\n", 349 | " pbar=tqdm(range(epoch),leave=False,position=0)\n", 350 | " acc=0.5000\n", 351 | " loss_sum=0\n", 352 | " for j in pbar:\n", 353 | " pbar.set_description('Loss: {:.2f} Acc:{:.4f}'.format(loss_sum,acc))\n", 354 | " loss_sum=0\n", 355 | " TP,TN,FP,FN=0,0,0,0\n", 356 | " for (X,Z,A,label,sub_id) in trainloader:\n", 357 | " model.train()\n", 358 | " x=X.to(device)\n", 359 | " z=Z.to(device)\n", 360 | " x.requires_grad=True\n", 361 | " z.requires_grad=True\n", 362 | " label=label.to(device)\n", 363 | " y=model(x,z)\n", 364 | " loss=criterian1(y,label)\n", 365 | " model.zero_grad()\n", 366 | " loss.backward(retain_graph=True)\n", 367 | " sign_grad_x=torch.sign(x.grad.data)\n", 368 | " sign_grad_z=torch.sign(z.grad.data)\n", 369 | " perturbed_x=x+epsilon*sign_grad_x \n", 370 | " perturbed_z=z+epsilon*sign_grad_z \n", 371 | " perturbed_x=torch.clamp(perturbed_x,0,1)\n", 372 | " perturbed_z=torch.clamp(perturbed_z,-1,1)\n", 373 | " y=model(perturbed_x,perturbed_z)\n", 374 | " L2=torch.tensor(0,dtype=torch.float32).to(device)\n", 375 | " if L2_lamda>0:\n", 376 | " for name,parameters in model.named_parameters():\n", 377 | " if name[0:5]=='clase' and name[-8:]=='0.weight':\n", 378 | " L2+=L2_lamda*torch.norm(parameters,2)\n", 379 | " loss=0.5*(criterian1(y,label)+loss)+L2\n", 380 | " loss_sum+=loss.item()\n", 381 | " optimizer.zero_grad()\n", 382 | " loss.backward()\n", 383 | " optimizer.step()\n", 384 | " acc,f1,TP_test,TN_test,FP_test,FN_test=test(device,model,testloader)\n", 385 | " result.loc[j]={'Loss':loss_sum,'Acc':acc,'F1':f1,'TP':TP_test,'TN':TN_test,'FP':FP_test,'FN':FN_test}\n", 386 | " result.sort_values('Acc',inplace=True,ascending=False)\n", 387 | " return result.iloc[9] \n", 388 | "def 
train_pgd(model,trainloader,testloader,eps=0.05,iters=10,alpha=2/255):\n", 389 | " result=pd.DataFrame(columns=('Loss','Acc','F1','TP','TN','FP','FN'))\n", 390 | " criterian1=LabelSmoothLoss(0.1).to(device)\n", 391 | " optimizer=torch.optim.Adam(model.parameters(),lr=learning_rate)\n", 392 | " for j in range(epoch):\n", 393 | " loss_sum=0.\n", 394 | " TP,TN,FP,FN=0,0,0,0\n", 395 | " model.train()\n", 396 | " for (X,Z,label,sub_id) in trainloader:\n", 397 | " model.train()\n", 398 | " x=X.to(device)\n", 399 | " z=Z.to(device)\n", 400 | " label=label.to(device)\n", 401 | " pretu_x,pretu_z=x,z\n", 402 | " ori_x,ori_z=x.data,z.data\n", 403 | " for i in range(iters):\n", 404 | " pretu_x.requires_grad=True\n", 405 | " pretu_z.requires_grad=True\n", 406 | " y=model(pretu_x,pretu_z)\n", 407 | " loss=criterian1(y,label)\n", 408 | " model.zero_grad()\n", 409 | " loss.backward()\n", 410 | " adv_x=pretu_x+alpha*torch.sign(pretu_x.grad.data)\n", 411 | " adv_z=pretu_z+alpha*torch.sign(pretu_z.grad.data)\n", 412 | " eta_x=torch.clamp(adv_x-ori_x,min=-eps,max=eps)\n", 413 | " eta_z=torch.clamp(adv_z-ori_z,min=-eps,max=eps)\n", 414 | " pretu_x=torch.clamp(ori_x+eta_x,min=0,max=1).detach_()\n", 415 | " pretu_z=torch.clamp(ori_z+eta_z,min=-1,max=1).detach_()\n", 416 | " y=model(x,z)\n", 417 | " yy=model(pretu_x,pretu_z)\n", 418 | " L2=torch.tensor(0,dtype=torch.float32).to(device)\n", 419 | " if L2_lamda>0:\n", 420 | " for name,parameters in model.named_parameters():\n", 421 | " if name[0:5]=='clase' and name[-8:]=='0.weight':\n", 422 | " L2+=L2_lamda*torch.norm(parameters,2)\n", 423 | " loss=0.5*(criterian1(yy,label)+criterian1(y,label))+L2\n", 424 | " loss_sum+=loss.item()\n", 425 | " optimizer.zero_grad()\n", 426 | " loss.backward()\n", 427 | " optimizer.step()\n", 428 | " if (j+1)%10==0:\n", 429 | " acc,f1,TP_test,TN_test,FP_test,FN_test=test(device,model,testloader)\n", 430 | " result.loc[(j+1)//10]=[loss_sum,acc,f1,TP_test,TN_test,FP_test,FN_test]\n", 431 | " result.sort_values('Acc',inplace=True,ascending=False)\n", 432 | " print(' FinalAcc: {:.4f}'.format(result.iloc[0]['Acc']))\n", 433 | " return result.iloc[0]['Acc']\n", 434 | "def train(model,trainloader,testloader):\n", 435 | " result=pd.DataFrame(columns=('Loss','Acc','F1','TP','TN','FP','FN'))\n", 436 | " criterian1=LabelSmoothLoss(0.1).to(device)\n", 437 | " optimizer=torch.optim.Adam(model.parameters(),lr=learning_rate)\n", 438 | " #optimizer = AdaBelief(model.parameters(), lr=1e-4, eps=1e-8, betas=(0.9,0.999), weight_decay=L2_lamda,weight_decouple = True, rectify = False)\n", 439 | " scheduler=torch.optim.lr_scheduler.ExponentialLR(optimizer,gamma=gmma)\n", 440 | " pbar=tqdm(range(epoch),leave=False,position=0)\n", 441 | " acc=0.5000\n", 442 | " loss_sum=0\n", 443 | " for j in pbar:\n", 444 | " pbar.set_description('Loss: {:.2f} Acc:{:.4f}'.format(loss_sum,acc))\n", 445 | " loss_sum=0\n", 446 | " TP,TN,FP,FN=0,0,0,0\n", 447 | " time_start=time.time()\n", 448 | " for (X,Z,A,label,sub_id) in trainloader:\n", 449 | " #print(A)\n", 450 | " #print(Z)\n", 451 | " #print(X.shape,torch.mean(X),torch.std(X))\n", 452 | " #X=X+torch.randn(X.shape)*X.std(0)\n", 453 | " #Z=Z+torch.randn(Z.shape)*Z.std(0)\n", 454 | " model.train()\n", 455 | " X=X.to(device)\n", 456 | " Z=Z.to(device)\n", 457 | " label=label.to(device)\n", 458 | " y=model(X,Z)\n", 459 | " #print(y)\n", 460 | " _,predict=torch.max(y,1)\n", 461 | " TP+=((predict==1)&(label==1)).sum().item()\n", 462 | " TN+=((predict==0)&(label==0)).sum().item()\n", 463 | " 
FN+=((predict==0)&(label==1)).sum().item()\n", 464 | " FP+=((predict==1)&(label==0)).sum().item()\n", 465 | " loss=criterian1(y,label)\n", 466 | " L1=torch.tensor(0,dtype=torch.float32).to(device)\n", 467 | " L2=torch.tensor(0,dtype=torch.float32).to(device)\n", 468 | " #print(model.parameters.weit_n)\n", 469 | " if L1_lamda>0 or L2_lamda>0:\n", 470 | " for name,parameters in model.named_parameters():\n", 471 | " if name[0:5]=='clase' and name[-8:]=='0.weight':\n", 472 | " L1+=L1_lamda*torch.norm(parameters,1)\n", 473 | " L2+=L2_lamda*torch.norm(parameters,2)\n", 474 | " loss+=(L1+L2)\n", 475 | " loss_sum+=loss.item()\n", 476 | " optimizer.zero_grad()\n", 477 | " loss.backward()\n", 478 | " optimizer.step()\n", 479 | " scheduler.step()\n", 480 | " time_end=time.time()\n", 481 | " time_cost=(time_end-time_start)/60.0\n", 482 | " acc,f1,TP_test,TN_test,FP_test,FN_test=test(device,model,testloader)\n", 483 | " result.loc[j]={'Loss':loss_sum,'Acc':acc,'F1':f1,'TP':TP_test,'TN':TN_test,'FP':FP_test,'FN':FN_test}\n", 484 | " #acc,f1=cal_evaluate(TP,TN,FP,FN)\n", 485 | " #print(\"[%2d/%d] ACC:%.2f F1:%.2f Loss: %.4f [TP:%3d|TN:%3d|FP:%3d|FN:%3d] CostTime:%4.1f min | RestTime:%.2f h\" %(\n", 486 | " # j+1,epoch,acc,f1,loss_sum,TP,TN,FP,FN,time_cost,time_cost/60*(epoch-1-j)))\n", 487 | " #losses.append(loss_sum)\n", 488 | " #print(model.parameters())\n", 489 | " #acc,f1=test(device,model,testloader)\n", 490 | " #print(acc)\n", 491 | " plot_acc(result['Acc'],result['Loss'])\n", 492 | " result.sort_values('Acc',inplace=True,ascending=False)\n", 493 | " return result.iloc[9] \n", 494 | " #print('Top:10|%.4f\\tTop:20|%.4f\\tMax|%.4f'%(get_acc(acc_test,0.1),get_acc(acc_test,0.2),max(acc_test)))" 495 | ] 496 | }, 497 | { 498 | "cell_type": "raw", 499 | "metadata": {}, 500 | "source": [ 501 | "## 机构混合\n", 502 | "test_site =os.listdir(cpac_root)\n", 503 | "train_site=test_site\n", 504 | "k_fold=10 \n", 505 | "asd,tdc=data_arange(test_site,fmri_root=cpac_root,smri_root=smri_root,nan_subid=nan_subid)\n", 506 | "train_asd_dict={i:[] for i in range(k_fold)}\n", 507 | "train_tdc_dict={i:[] for i in range(k_fold)}\n", 508 | "test_asd_dict ={i:[] for i in range(k_fold)}\n", 509 | "test_tdc_dict ={i:[] for i in range(k_fold)}\n", 510 | "for data in asd:\n", 511 | " datasplit=data_2_k(data,k_fold,True)\n", 512 | " for index in range(k_fold):\n", 513 | " test_asd_dict[index].append(datasplit[index])\n", 514 | " train_temp=[j for i in datasplit for j in i if j not in datasplit[index]]\n", 515 | " train_asd_dict[index].append(train_temp)\n", 516 | "for data in tdc:\n", 517 | " datasplit=data_2_k(data,k_fold,True)\n", 518 | " for index in range(k_fold):\n", 519 | " test_tdc_dict[index].append(datasplit[index])\n", 520 | " test_temp=[j for i in datasplit for j in i if j not in datasplit[index]]\n", 521 | " train_tdc_dict[index].append(test_temp)" 522 | ] 523 | }, 524 | { 525 | "cell_type": "raw", 526 | "metadata": {}, 527 | "source": [ 528 | "## 机构混合\n", 529 | "np.save('train_asd_dict.npy',train_asd_dict)\n", 530 | "np.save('train_tdc_dict.npy',train_tdc_dict)\n", 531 | "np.save('test_asd_dict.npy',test_asd_dict)\n", 532 | "np.save('test_tdc_dict.npy',test_tdc_dict)\n", 533 | "np.save('train_test_site.npy',train_site)" 534 | ] 535 | }, 536 | { 537 | "cell_type": "code", 538 | "execution_count": 4, 539 | "metadata": {}, 540 | "outputs": [], 541 | "source": [ 542 | "## 机构混合续\n", 543 | "train_site=test_site=np.load('DATAARRANGE/train_test_site.npy')\n", 544 | 
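"# the 10-fold split is precomputed and reloaded here so that every experiment (including ensamble.ipynb) trains on identical folds\n",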
"train_asd_dict=np.load('DATAARRANGE/train_asd_dict.npy',allow_pickle=True).item()\n", 545 | "train_tdc_dict=np.load('DATAARRANGE/train_tdc_dict.npy',allow_pickle=True).item()\n", 546 | "test_asd_dict=np.load('DATAARRANGE/test_asd_dict.npy',allow_pickle=True).item()\n", 547 | "test_tdc_dict=np.load('DATAARRANGE/test_tdc_dict.npy',allow_pickle=True).item()" 548 | ] 549 | }, 550 | { 551 | "cell_type": "code", 552 | "execution_count": 5, 553 | "metadata": {}, 554 | "outputs": [ 555 | { 556 | "name": "stdout", 557 | "output_type": "stream", 558 | "text": [ 559 | " FinalAcc: 0.7529\n", 560 | " FinalAcc: 0.6471\n", 561 | " FinalAcc: 0.7327\n", 562 | " FinalAcc: 0.7573\n", 563 | " FinalAcc: 0.7767\n", 564 | " FinalAcc: 0.7292\n", 565 | " FinalAcc: 0.6700\n", 566 | " FinalAcc: 0.6154\n", 567 | " FinalAcc: 0.6566\n", 568 | " FinalAcc: 0.6842\n", 569 | "Epision:0.001 Mean Acc:0.7022 \n" 570 | ] 571 | } 572 | ], 573 | "source": [ 574 | "# 机构混合续\n", 575 | "setup_seed(123)\n", 576 | "global figname\n", 577 | "figname='temple.png'\n", 578 | "#for epision in [0.001,0.005,0.01,0.02,0.05,0.1,0.2]:\n", 579 | "for epision in [0.001]:\n", 580 | " result=np.zeros(10)\n", 581 | " L1_lamda=0.0\n", 582 | " L2_lamda=0.0001\n", 583 | " learning_rate=0.0001\n", 584 | " epoch =100\n", 585 | " batch_size=64\n", 586 | " gmma =1\n", 587 | " layer =5\n", 588 | " for index in range(10):\n", 589 | " train_asd=train_asd_dict[index]\n", 590 | " train_tdc=train_tdc_dict[index]\n", 591 | " test_asd =test_asd_dict[index]\n", 592 | " test_tdc =test_tdc_dict[index]\n", 593 | " global max_acc\n", 594 | " global saveModel\n", 595 | " saveModel=False\n", 596 | " global modelname\n", 597 | " trainset=dataset(site=train_site,fmri_root=cpac_root,smri_root=smri_root,ASD=train_asd,TDC=train_tdc)\n", 598 | " trainloader=DataLoader(trainset,batch_size=batch_size,shuffle=True)\n", 599 | " testset=dataset(site=test_site,fmri_root=cpac_root,smri_root=smri_root,ASD=test_asd,TDC=test_tdc)\n", 600 | " testloader=DataLoader(testset,batch_size=1)\n", 601 | " model=NEResGCN(layer).to(device)\n", 602 | " modelname='/media/dm/0001A094000BF891/Yazid/SAVEDModels/normtrained/models_{}_{}'.format(index,9)\n", 603 | " max_acc=0.6\n", 604 | " result[index]=train_pgd(model,trainloader,testloader,eps=epision,iters=10,alpha=epision/5)\n", 605 | " print('Epision:{} Mean Acc:{:.4f} '.format(epision,result.mean()))" 606 | ] 607 | }, 608 | { 609 | "cell_type": "code", 610 | "execution_count": null, 611 | "metadata": {}, 612 | "outputs": [], 613 | "source": [] 614 | } 615 | ], 616 | "metadata": { 617 | "kernelspec": { 618 | "display_name": "Pytorch", 619 | "language": "python", 620 | "name": "pytorch" 621 | }, 622 | "language_info": { 623 | "codemirror_mode": { 624 | "name": "ipython", 625 | "version": 3 626 | }, 627 | "file_extension": ".py", 628 | "mimetype": "text/x-python", 629 | "name": "python", 630 | "nbconvert_exporter": "python", 631 | "pygments_lexer": "ipython3", 632 | "version": "3.6.13" 633 | }, 634 | "toc-autonumbering": false, 635 | "toc-showcode": true, 636 | "toc-showmarkdowntxt": false, 637 | "toc-showtags": false 638 | }, 639 | "nbformat": 4, 640 | "nbformat_minor": 4 641 | } 642 | -------------------------------------------------------------------------------- /ensamble.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch.nn as nn \n", 11 
| "import numpy as np\n", 12 | "from torch.utils.data import Dataset\n", 13 | "from torch.utils.data import DataLoader\n", 14 | "import os\n", 15 | "import time\n", 16 | "import random\n", 17 | "import pandas as pd\n", 18 | "import torch.nn.functional as F\n", 19 | "device=torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')\n", 20 | "from tqdm import tqdm\n", 21 | "import warnings \n", 22 | "warnings.filterwarnings(\"ignore\")\n", 23 | "cpac_root='/media/D/yazid/ASD-classification-ANEGCN/ABIDEI_CPAC/'\n", 24 | "smri_root='/media/D/yazid/ASD-classification-ANEGCN/ABIDEI_sMRI/'\n", 25 | "nan_subid=np.load('nan_subid.npy').tolist()\n", 26 | "def setup_seed(seed):\n", 27 | " torch.manual_seed(seed)\n", 28 | " torch.cuda.manual_seed_all(seed)\n", 29 | " np.random.seed(seed)\n", 30 | " random.seed(seed)\n", 31 | " torch.backends.cudnn.deterministic = True" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 2, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "class Attention(nn.Module):\n", 41 | " def __init__(self):\n", 42 | " super(Attention,self).__init__()\n", 43 | " self.conv1=nn.Conv1d(in_channels=3,out_channels=3,kernel_size=1,padding=0)\n", 44 | " self.conv2=nn.Conv1d(in_channels=116,out_channels=116,kernel_size=1,padding=0)\n", 45 | " self.softmax=nn.Softmax(dim=-1)\n", 46 | " def forward(self,Z,X):\n", 47 | " batchsize,x_dim,x_c= X.size()\n", 48 | " batchsize,z_dim,z_c= Z.size()\n", 49 | " K=self.conv1(X.permute(0,2,1))# BS,x_c,x_dim\n", 50 | " Q=K.permute(0,2,1)# BS,x_dim,x_c\n", 51 | " V=self.conv2(Z.permute(0,2,1))# Bs,z_c,z_dim\n", 52 | " attention=self.softmax(torch.matmul(Q,K))#BS,x_dim,x_dim\n", 53 | " out=torch.bmm(attention,V).permute(0,2,1)#BS,z_dim,z_c\n", 54 | " return out\n", 55 | "class ANEGCN(nn.Module):\n", 56 | " def __init__(self,layer):\n", 57 | " super(ANEGCN,self).__init__()\n", 58 | " self.layer =layer\n", 59 | " self.relu =nn.ReLU()\n", 60 | " self.atten =nn.ModuleList([Attention() for i in range(layer)])\n", 61 | " self.norm_n=nn.ModuleList([nn.BatchNorm1d(116) for i in range(layer)])\n", 62 | " self.norm_e=nn.ModuleList([nn.BatchNorm1d(116) for i in range(layer)])\n", 63 | " self.node_w=nn.ParameterList([nn.Parameter(torch.randn((3,3),dtype=torch.float32)) for i in range(layer)])\n", 64 | " self.edge_w=nn.ParameterList([nn.Parameter(torch.randn((116,116),dtype=torch.float32)) for i in range(layer)])\n", 65 | " self.line_n=nn.ModuleList([nn.Sequential(nn.Linear(116*3,128),nn.ReLU(),nn.BatchNorm1d(128)) for i in range(layer+1)])\n", 66 | " self.line_e=nn.ModuleList([nn.Sequential(nn.Linear(116*116,128*3),nn.ReLU(),nn.BatchNorm1d(128*3)) for i in range(layer+1)])\n", 67 | " self.clase =nn.Sequential(nn.Linear(128*4*(self.layer+1),1024),nn.Dropout(0.2),nn.ReLU(),\n", 68 | " nn.Linear(1024,2))\n", 69 | " self.ones=nn.Parameter(torch.ones((116),dtype=torch.float32),requires_grad=False)\n", 70 | " self._initialize_weights()\n", 71 | " # params initialization\n", 72 | " def _initialize_weights(self):\n", 73 | " for m in self.modules():\n", 74 | " if isinstance(m, (nn.Conv1d,nn.Linear)):\n", 75 | " nn.init.xavier_uniform_(m.weight)\n", 76 | " if m.bias is not None:\n", 77 | " nn.init.constant_(m.bias, 0)\n", 78 | " elif isinstance(m, nn.BatchNorm1d):\n", 79 | " nn.init.constant_(m.weight, 1)\n", 80 | " nn.init.constant_(m.bias, 0)\n", 81 | " def normalized(self,Z):\n", 82 | " n=Z.size()[0]\n", 83 | " A=Z[0,:,:]\n", 84 | " A=A+torch.diag(self.ones)\n", 85 | " d=A.sum(1)\n", 86 | " D=torch.diag(torch.pow(d,-1))\n", 87 | " 
A=D.mm(A).reshape(1,116,116)\n", 88 | " for i in range(1,n):\n", 89 | " A1=Z[i,:,:]+torch.diag(self.ones)\n", 90 | " d=A1.sum(1)\n", 91 | " D=torch.diag(torch.pow(d,-1))\n", 92 | " A1=D.mm(A1).reshape(1,116,116)\n", 93 | " A=torch.cat((A,A1),0)\n", 94 | " return A\n", 95 | " \n", 96 | " def update_A(self,Z):\n", 97 | " n=Z.size()[0]\n", 98 | " A=Z[0,:,:]\n", 99 | " Value,_=torch.topk(torch.abs(A.view(-1)),int(116*116*0.2))\n", 100 | " A=(torch.abs(A)>=Value[-1])+torch.tensor(0,dtype=torch.float32)\n", 101 | " A=A.reshape(1,116,116)\n", 102 | " for i in range(1,n):\n", 103 | " A2=Z[i,:,:]\n", 104 | " Value,_=torch.topk(torch.abs(A2.view(-1)),int(116*116*0.2))\n", 105 | " A2=(torch.abs(A2)>=Value[-1])+torch.tensor(0,dtype=torch.float32)\n", 106 | " A2=A2.reshape(1,116,116)\n", 107 | " A=torch.cat((A,A2),0)\n", 108 | " return A\n", 109 | " \n", 110 | " def forward(self,X,Z):\n", 111 | " n=X.size()[0]\n", 112 | " XX=self.line_n[0](X.view(n,-1))\n", 113 | " ZZ=self.line_e[0](Z.view(n,-1))\n", 114 | " for i in range(self.layer):\n", 115 | " A=self.atten[i](Z,X)\n", 116 | " Z1=torch.matmul(A,Z)\n", 117 | " Z2=torch.matmul(Z1,self.edge_w[i])\n", 118 | " Z=self.relu(self.norm_e[i](Z2))+Z\n", 119 | " ZZ=torch.cat((ZZ,self.line_e[i+1](Z.view(n,-1))),dim=1)\n", 120 | " X1=torch.matmul(A,X)\n", 121 | " X1=torch.matmul(X1,self.node_w[i])\n", 122 | " X=self.relu(self.norm_n[i](X1))+X\n", 123 | " #X.register_hook(grad_X_hook)\n", 124 | " #feat_X_hook(X)\n", 125 | " XX=torch.cat((XX,self.line_n[i+1](X.view(n,-1))),dim=1)\n", 126 | " XZ=torch.cat((XX,ZZ),1)\n", 127 | " y=self.clase(XZ)\n", 128 | " #print(self.clase[0].weight)\n", 129 | " return y\n", 130 | "def grad_X_hook(grad):\n", 131 | " X_grad.append(grad)\n", 132 | "def feat_X_hook(X):\n", 133 | " X_feat.append(X.detach())\n", 134 | "X_grad=list()\n", 135 | "X_feat=list()\n", 136 | "class LabelSmoothLoss(nn.Module):\n", 137 | " \n", 138 | " def __init__(self, smoothing=0.0):\n", 139 | " super(LabelSmoothLoss, self).__init__()\n", 140 | " self.smoothing = smoothing\n", 141 | " \n", 142 | " def forward(self, input, target):\n", 143 | " log_prob = F.log_softmax(input, dim=-1)\n", 144 | " weight = input.new_ones(input.size()) * \\\n", 145 | " self.smoothing / (input.size(-1) - 1.)\n", 146 | " weight.scatter_(-1, target.unsqueeze(-1), (1. 
- self.smoothing))\n", 147 | " loss = (-weight * log_prob).sum(dim=-1).mean()\n", 148 | " return loss" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 3, 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "def cal_evaluate(TP,TN,FP,FN):\n", 158 | " if TP>0:\n", 159 | " p = TP / (TP + FP)\n", 160 | " r = TP / (TP + FN)\n", 161 | " F1 = 2 * r * p / (r + p)\n", 162 | " else:\n", 163 | " F1=0\n", 164 | " acc = (TP + TN) / (TP + TN + FP + FN)\n", 165 | " #print('ACC:%.4f F1:%.4f [TP:%d|TN:%d|FP:%d|FN:%d]'%(acc,F1,TP,TN,FP,FN))\n", 166 | " return acc,F1\n", 167 | "def test(device,model,testloader):\n", 168 | " model.eval()\n", 169 | " TP_test,TN_test,FP_test,FN_test=0,0,0,0\n", 170 | " with torch.no_grad():\n", 171 | " for (X,Z,label,sub_id) in testloader:\n", 172 | " TP,TN,FN,FP=0,0,0,0\n", 173 | " n=X.size()[0]\n", 174 | " X=X.to(device)\n", 175 | " Z=Z.to(device)\n", 176 | " label=label.to(device)\n", 177 | " y=model(X,Z)\n", 178 | " _,predict=torch.max(y,1)\n", 179 | " TP+=((predict==1)&(label==1)).sum().item()\n", 180 | " TN+=((predict==0)&(label==0)).sum().item()\n", 181 | " FN+=((predict==0)&(label==1)).sum().item()\n", 182 | " FP+=((predict==1)&(label==0)).sum().item()\n", 183 | " TP_test+=TP\n", 184 | " TN_test+=TN\n", 185 | " FP_test+=FP\n", 186 | " FN_test+=FN\n", 187 | " acc,f1=cal_evaluate(TP_test,TN_test,FP_test,FN_test)\n", 188 | " global max_acc\n", 189 | " global modelname\n", 190 | " global savedModel\n", 191 | " if acc>=max_acc:\n", 192 | " max_acc=acc\n", 193 | " if saveModel:\n", 194 | " torch.save(model.state_dict(),modelname)\n", 195 | " return acc,f1,TP_test,TN_test,FP_test,FN_test" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": 4, 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "class dataset(Dataset):\n", 205 | " def __init__(self,fmri_root,smri_root,site,ASD,TDC):\n", 206 | " super(dataset,self).__init__()\n", 207 | " self.fmri=fmri_root\n", 208 | " self.smri=smri_root\n", 209 | " self.ASD=[j for i in ASD for j in i]\n", 210 | " self.TDC=[j for i in TDC for j in i]\n", 211 | " self.data=self.ASD+self.TDC\n", 212 | " random.shuffle(self.data)\n", 213 | " self.data_site={}\n", 214 | " for i in range(len(site)):\n", 215 | " data=ASD[i]+TDC[i]\n", 216 | " for j in data:\n", 217 | " if j not in self.data_site:\n", 218 | " self.data_site[j]=site[i] \n", 219 | " def __getitem__(self,index):\n", 220 | " data=self.data[index]\n", 221 | " sub_id=int(data[0:5])\n", 222 | " if data in self.ASD:\n", 223 | " data_slow5 =np.load(self.fmri+self.data_site[data]+'/group1_slow5/'+data,allow_pickle=True)\n", 224 | " data_slow4 =np.load(self.fmri+self.data_site[data]+'/group1_slow4/'+data,allow_pickle=True)\n", 225 | " data_voxel =np.load(self.smri+self.data_site[data]+'/group1/'+data,allow_pickle=True)\n", 226 | " data_FCz =np.load(self.fmri+self.data_site[data]+'/group1_FC/'+data,allow_pickle=True)\n", 227 | " elif data in self.TDC:\n", 228 | " data_slow5 =np.load(self.fmri+self.data_site[data]+'/group2_slow5/'+data,allow_pickle=True)\n", 229 | " data_slow4 =np.load(self.fmri+self.data_site[data]+'/group2_slow4/'+data,allow_pickle=True)\n", 230 | " data_voxel =np.load(self.smri+self.data_site[data]+'/group2/'+data,allow_pickle=True)\n", 231 | " data_FCz =np.load(self.fmri+self.data_site[data]+'/group2_FC/'+data,allow_pickle=True)\n", 232 | " else:\n", 233 | " print('wrong input')\n", 234 | " data_slow5=(data_slow5-np.min(data_slow5))/(np.max(data_slow5)-np.min(data_slow5))\n", 235 | " 
data_slow4=(data_slow4-np.min(data_slow4))/(np.max(data_slow4)-np.min(data_slow4))\n", 236 | " if np.any(np.isnan(data_slow5)) or np.any(np.isnan(data_slow4)) or np.any(np.isnan(data_FCz)):\n", 237 | " print('data wronmg')\n", 238 | " #data_FCz=(data_FCz-np.min(data_FCz))/(np.max(data_FCz)-np.min(data_FCz))\n", 239 | " if self.data[index] in self.ASD:\n", 240 | " label=torch.tensor(1)\n", 241 | " else:\n", 242 | " label=torch.tensor(0)\n", 243 | " X=np.zeros((116,3),dtype=np.float32)\n", 244 | " X[:,0]=data_slow5\n", 245 | " X[:,1]=data_slow4\n", 246 | " X[:,2]=data_voxel\n", 247 | " data_FCz=data_FCz.astype(np.float32)\n", 248 | " Z=torch.from_numpy(data_FCz)\n", 249 | " X=torch.from_numpy(X)\n", 250 | " return X,Z,label,sub_id\n", 251 | " def __len__(self):\n", 252 | " return len(self.data)" 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": 5, 258 | "metadata": {}, 259 | "outputs": [], 260 | "source": [ 261 | "def train_fgsm(model,trainloader,testloader,epsilon=0.05):\n", 262 | " result=pd.DataFrame(columns=('Loss','Acc','F1','TP','TN','FP','FN'))\n", 263 | " criterian1=LabelSmoothLoss(0.1).to(device)\n", 264 | " optimizer=torch.optim.Adam(model.parameters(),lr=learning_rate)\n", 265 | " scheduler=torch.optim.lr_scheduler.ExponentialLR(optimizer,gamma=gmma)\n", 266 | " acc=0.5000\n", 267 | " loss_sum=0\n", 268 | " for j in range(epoch):\n", 269 | " print('\\rLoss: {:.2f} Acc:{:.4f}'.format(loss_sum,acc),end='')\n", 270 | " loss_sum=0\n", 271 | " TP,TN,FP,FN=0,0,0,0\n", 272 | " model.train()\n", 273 | " for (X,Z,label,sub_id) in trainloader:\n", 274 | " x=X.to(device)\n", 275 | " z=Z.to(device)\n", 276 | " x.requires_grad=True\n", 277 | " z.requires_grad=True\n", 278 | " label=label.to(device)\n", 279 | " y=model(x,z)\n", 280 | " loss=criterian1(y,label)\n", 281 | " model.zero_grad()\n", 282 | " loss.backward(retain_graph=True)\n", 283 | " sign_grad_x=torch.sign(x.grad.data)\n", 284 | " sign_grad_z=torch.sign(z.grad.data)\n", 285 | " perturbed_x=x+epsilon*sign_grad_x \n", 286 | " perturbed_z=z+epsilon*sign_grad_z \n", 287 | " perturbed_x=torch.clamp(perturbed_x,0,1)\n", 288 | " perturbed_z=torch.clamp(perturbed_z,-1,1)\n", 289 | " y=model(perturbed_x,perturbed_z)\n", 290 | " L2=torch.tensor(0,dtype=torch.float32).to(device)\n", 291 | " if L2_lamda>0:\n", 292 | " for name,parameters in model.named_parameters():\n", 293 | " if name[0:5]=='clase' and name[-8:]=='0.weight':\n", 294 | " L2+=L2_lamda*torch.norm(parameters,2)\n", 295 | " loss=0.5*(criterian1(y,label)+loss)+L2\n", 296 | " loss_sum+=loss.item()\n", 297 | " optimizer.zero_grad()\n", 298 | " loss.backward()\n", 299 | " optimizer.step()\n", 300 | " if (j+1)%10 == 0 or j==0:\n", 301 | " acc,f1,TP_test,TN_test,FP_test,FN_test=test(device,model,testloader)\n", 302 | " result.loc[j]={'Loss':loss_sum,'Acc':acc,'F1':f1,'TP':TP_test,'TN':TN_test,'FP':FP_test,'FN':FN_test}\n", 303 | " result.sort_values('Acc',inplace=True,ascending=False)\n", 304 | " print('',end='')\n", 305 | " print('Acc: ', result.iloc[0]['Acc'])\n", 306 | "def train_norm(model,trainloader,testloader):\n", 307 | " result=pd.DataFrame(columns=('Loss','Acc','F1','TP','TN','FP','FN'))\n", 308 | " criterian1=LabelSmoothLoss(0.1).to(device)\n", 309 | " optimizer=torch.optim.Adam(model.parameters(),lr=learning_rate)\n", 310 | " acc=0.5000\n", 311 | " loss_sum=0\n", 312 | " for j in range(epoch):\n", 313 | " print('\\rEPOCH: [{:03d}|100] Loss: {:.2f} Acc:{:.4f}'.format(j+1,loss_sum,acc),end='')\n", 314 | " loss_sum=0\n", 315 | " TP,TN,FP,FN=0,0,0,0\n", 
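"        # the test set is scored only every 10th epoch (plus the first); see the (j+1)%10 check below\n",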
316 | " model.train()\n", 317 | " for (X,Z,label,sub_id) in trainloader:\n", 318 | " x=X.to(device)\n", 319 | " z=Z.to(device)\n", 320 | " label=label.to(device)\n", 321 | " y=model(x,z)\n", 322 | " loss=criterian1(y,label)\n", 323 | " L2=torch.tensor(0,dtype=torch.float32).to(device)\n", 324 | " if L2_lamda>0:\n", 325 | " for name,parameters in model.named_parameters():\n", 326 | " if name[0:5]=='clase' and name[-8:]=='0.weight':\n", 327 | " L2+=L2_lamda*torch.norm(parameters,2)\n", 328 | " loss=loss+L2\n", 329 | " loss_sum+=loss.item()\n", 330 | " optimizer.zero_grad()\n", 331 | " loss.backward()\n", 332 | " optimizer.step()\n", 333 | " if (j+1)%10 == 0 or j==0:\n", 334 | " acc,f1,TP_test,TN_test,FP_test,FN_test=test(device,model,testloader)\n", 335 | " result.loc[j//10]={'Loss':loss_sum,'Acc':acc,'F1':f1,'TP':TP_test,'TN':TN_test,'FP':FP_test,'FN':FN_test}\n", 336 | " result.sort_values('Acc',inplace=True,ascending=False)\n", 337 | " print(' FinalAcc: {:.4f}'.format(result.iloc[0]['Acc']))\n", 338 | "def train_pgd(model,trainloader,testloader,eps=0.2,iters=10,alpha=8/255):\n", 339 | " result=pd.DataFrame(columns=('Loss','Acc','F1','TP','TN','FP','FN'))\n", 340 | " criterian1=LabelSmoothLoss(0.1).to(device)\n", 341 | " optimizer=torch.optim.Adam(model.parameters(),lr=learning_rate)\n", 342 | " for j in range(epoch):\n", 343 | " loss_sum=0.\n", 344 | " TP,TN,FP,FN=0,0,0,0\n", 345 | " model.train()\n", 346 | " for (X,Z,label,sub_id) in trainloader:\n", 347 | " model.train()\n", 348 | " x=X.to(device)\n", 349 | " z=Z.to(device)\n", 350 | " label=label.to(device)\n", 351 | " pretu_x,pretu_z=x,z\n", 352 | " ori_x,ori_z=x.data,z.data\n", 353 | " for i in range(iters):\n", 354 | " pretu_x.requires_grad=True\n", 355 | " pretu_z.requires_grad=True\n", 356 | " y=model(pretu_x,pretu_z)\n", 357 | " loss=criterian1(y,label)\n", 358 | " model.zero_grad()\n", 359 | " loss.backward()\n", 360 | " adv_x=pretu_x+alpha*torch.sign(pretu_x.grad.data)\n", 361 | " adv_z=pretu_z+alpha*torch.sign(pretu_z.grad.data)\n", 362 | " eta_x=torch.clamp(adv_x-ori_x,min=-eps,max=eps)\n", 363 | " eta_z=torch.clamp(adv_z-ori_z,min=-eps,max=eps)\n", 364 | " pretu_x=torch.clamp(ori_x+eta_x,min=0,max=1).detach_()\n", 365 | " pretu_z=torch.clamp(ori_z+eta_z,min=-1,max=1).detach_()\n", 366 | " y=model(x,z)\n", 367 | " yy=model(pretu_x,pretu_z)\n", 368 | " L2=torch.tensor(0,dtype=torch.float32).to(device)\n", 369 | " if L2_lamda>0:\n", 370 | " for name,parameters in model.named_parameters():\n", 371 | " if name[0:5]=='clase' and name[-8:]=='0.weight':\n", 372 | " L2+=L2_lamda*torch.norm(parameters,2)\n", 373 | " loss=0.5*(criterian1(yy,label)+criterian1(y,label))+L2\n", 374 | " loss_sum+=loss.item()\n", 375 | " optimizer.zero_grad()\n", 376 | " loss.backward()\n", 377 | " optimizer.step()\n", 378 | " if (j+1)%10==0:\n", 379 | " acc,f1,TP_test,TN_test,FP_test,FN_test=test(device,model,testloader)\n", 380 | " result.loc[(j+1)//10]=[loss_sum,acc,f1,TP_test,TN_test,FP_test,FN_test]\n", 381 | " result.sort_values('Acc',inplace=True,ascending=False)\n", 382 | " print(' FinalAcc: {:.4f}'.format(result.iloc[0]['Acc']))\n", 383 | " return result.iloc[0]['Acc']" 384 | ] 385 | }, 386 | { 387 | "cell_type": "code", 388 | "execution_count": 6, 389 | "metadata": {}, 390 | "outputs": [], 391 | "source": [ 392 | "## 机构混合续\n", 393 | "train_site=test_site=np.load('DATAARRANGE/train_test_site.npy')\n", 394 | "train_asd_dict=np.load('DATAARRANGE/train_asd_dict.npy',allow_pickle=True).item()\n", 395 | 
"train_tdc_dict=np.load('DATAARRANGE/train_tdc_dict.npy',allow_pickle=True).item()\n", 396 | "test_asd_dict=np.load('DATAARRANGE/test_asd_dict.npy',allow_pickle=True).item()\n", 397 | "test_tdc_dict=np.load('DATAARRANGE/test_tdc_dict.npy',allow_pickle=True).item()" 398 | ] 399 | }, 400 | { 401 | "cell_type": "code", 402 | "execution_count": 7, 403 | "metadata": {}, 404 | "outputs": [ 405 | { 406 | "name": "stdout", 407 | "output_type": "stream", 408 | "text": [ 409 | " FinalAcc: 0.7765\n", 410 | " FinalAcc: 0.6863\n", 411 | " FinalAcc: 0.6634\n", 412 | " FinalAcc: 0.6893\n", 413 | " FinalAcc: 0.7670\n", 414 | " FinalAcc: 0.6875\n", 415 | " FinalAcc: 0.7200\n", 416 | " FinalAcc: 0.6346\n", 417 | " FinalAcc: 0.6869\n", 418 | " FinalAcc: 0.6842\n", 419 | " FinalAcc: 0.7529\n", 420 | " FinalAcc: 0.7157\n", 421 | " FinalAcc: 0.7030\n", 422 | " FinalAcc: 0.7087\n", 423 | " FinalAcc: 0.7282\n", 424 | " FinalAcc: 0.7083\n", 425 | " FinalAcc: 0.6900\n", 426 | " FinalAcc: 0.6538\n", 427 | " FinalAcc: 0.6768\n", 428 | " FinalAcc: 0.6930\n", 429 | " FinalAcc: 0.7412\n", 430 | " FinalAcc: 0.7255\n", 431 | " FinalAcc: 0.7426\n", 432 | " FinalAcc: 0.6990\n", 433 | " FinalAcc: 0.6990\n", 434 | " FinalAcc: 0.7083\n", 435 | " FinalAcc: 0.7000\n", 436 | " FinalAcc: 0.6731\n", 437 | " FinalAcc: 0.6970\n", 438 | " FinalAcc: 0.6579\n", 439 | " FinalAcc: 0.7529\n", 440 | " FinalAcc: 0.6863\n", 441 | " FinalAcc: 0.6931\n", 442 | " FinalAcc: 0.6893\n", 443 | " FinalAcc: 0.7087\n", 444 | " FinalAcc: 0.6979\n", 445 | " FinalAcc: 0.6900\n", 446 | " FinalAcc: 0.6250\n", 447 | " FinalAcc: 0.6768\n", 448 | " FinalAcc: 0.7018\n", 449 | " FinalAcc: 0.7647\n", 450 | " FinalAcc: 0.7353\n", 451 | " FinalAcc: 0.7228\n", 452 | " FinalAcc: 0.6990\n", 453 | " FinalAcc: 0.7184\n", 454 | " FinalAcc: 0.7083\n", 455 | " FinalAcc: 0.7100\n", 456 | " FinalAcc: 0.6635\n", 457 | " FinalAcc: 0.7374\n", 458 | " FinalAcc: 0.7193\n", 459 | " FinalAcc: 0.7765\n", 460 | " FinalAcc: 0.7647\n", 461 | " FinalAcc: 0.6832\n", 462 | " FinalAcc: 0.7184\n", 463 | " FinalAcc: 0.6990\n", 464 | " FinalAcc: 0.6979\n", 465 | " FinalAcc: 0.7000\n", 466 | " FinalAcc: 0.6538\n", 467 | " FinalAcc: 0.7172\n", 468 | " FinalAcc: 0.6579\n", 469 | " FinalAcc: 0.7647\n", 470 | " FinalAcc: 0.7157\n", 471 | " FinalAcc: 0.7327\n", 472 | " FinalAcc: 0.7282\n", 473 | " FinalAcc: 0.7670\n", 474 | " FinalAcc: 0.6667\n", 475 | " FinalAcc: 0.7000\n", 476 | " FinalAcc: 0.6250\n", 477 | " FinalAcc: 0.6970\n", 478 | " FinalAcc: 0.6754\n", 479 | " FinalAcc: 0.7765\n", 480 | " FinalAcc: 0.7059\n", 481 | " FinalAcc: 0.7129\n", 482 | " FinalAcc: 0.7087\n", 483 | " FinalAcc: 0.6796\n", 484 | " FinalAcc: 0.7083\n", 485 | " FinalAcc: 0.7100\n", 486 | " FinalAcc: 0.6538\n", 487 | " FinalAcc: 0.6970\n", 488 | " FinalAcc: 0.6579\n", 489 | " FinalAcc: 0.7882\n", 490 | " FinalAcc: 0.7255\n", 491 | " FinalAcc: 0.6832\n", 492 | " FinalAcc: 0.6796\n", 493 | " FinalAcc: 0.6990\n", 494 | " FinalAcc: 0.6979\n", 495 | " FinalAcc: 0.7300\n", 496 | " FinalAcc: 0.7019\n", 497 | " FinalAcc: 0.6869\n", 498 | " FinalAcc: 0.7018\n", 499 | " FinalAcc: 0.7765\n", 500 | " FinalAcc: 0.7059\n", 501 | " FinalAcc: 0.7228\n", 502 | " FinalAcc: 0.7379\n", 503 | " FinalAcc: 0.7184\n", 504 | " FinalAcc: 0.6875\n", 505 | " FinalAcc: 0.7300\n", 506 | " FinalAcc: 0.6442\n", 507 | " FinalAcc: 0.7172\n", 508 | " FinalAcc: 0.6667\n", 509 | " FinalAcc: 0.7529\n", 510 | " FinalAcc: 0.7255\n", 511 | " FinalAcc: 0.7327\n", 512 | " FinalAcc: 0.6893\n", 513 | " FinalAcc: 0.7184\n", 514 | " FinalAcc: 0.7188\n", 515 | " 
FinalAcc: 0.7100\n", 516 | " FinalAcc: 0.6827\n", 517 | " FinalAcc: 0.6667\n", 518 | " FinalAcc: 0.7018\n", 519 | " FinalAcc: 0.8118\n", 520 | " FinalAcc: 0.7353\n", 521 | " FinalAcc: 0.7327\n", 522 | " FinalAcc: 0.6699\n", 523 | " FinalAcc: 0.7282\n", 524 | " FinalAcc: 0.7708\n", 525 | " FinalAcc: 0.6600\n", 526 | " FinalAcc: 0.6346\n", 527 | " FinalAcc: 0.7071\n", 528 | " FinalAcc: 0.7018\n", 529 | " FinalAcc: 0.8000\n", 530 | " FinalAcc: 0.7353\n", 531 | " FinalAcc: 0.7228\n", 532 | " FinalAcc: 0.6990\n", 533 | " FinalAcc: 0.7379\n", 534 | " FinalAcc: 0.6875\n", 535 | " FinalAcc: 0.7600\n", 536 | " FinalAcc: 0.6442\n", 537 | " FinalAcc: 0.6667\n", 538 | " FinalAcc: 0.6754\n" 539 | ] 540 | } 541 | ], 542 | "source": [ 543 | "setup_seed(123)\n", 544 | "global max_acc\n", 545 | "global saveModel\n", 546 | "global modelname\n", 547 | "for i in range(13):\n", 548 | " L1_lamda=0.0\n", 549 | " L2_lamda=0.0001\n", 550 | " learning_rate=0.0001\n", 551 | " epoch =100\n", 552 | " batch_size=64\n", 553 | " layer =5\n", 554 | " for index in range(10):\n", 555 | " saveModel=True\n", 556 | " max_acc=0.6\n", 557 | " modelname='../SAVEDModels/PGDtrainedensamble/models_{}_{}'.format(i,index)\n", 558 | " train_asd=train_asd_dict[index]\n", 559 | " train_tdc=train_tdc_dict[index]\n", 560 | " test_asd =test_asd_dict[index]\n", 561 | " test_tdc =test_tdc_dict[index]\n", 562 | " trainset=dataset(site=train_site,fmri_root=cpac_root,smri_root=smri_root,ASD=train_asd,TDC=train_tdc)\n", 563 | " trainloader=DataLoader(trainset,batch_size=batch_size,shuffle=True)\n", 564 | " testset=dataset(site=test_site,fmri_root=cpac_root,smri_root=smri_root,ASD=test_asd,TDC=test_tdc)\n", 565 | " testloader=DataLoader(testset,batch_size=1)\n", 566 | " model=ANEGCN(layer).to(device)\n", 567 | " train_pgd(model,trainloader,testloader,eps=0.02,iters=10,alpha=0.004)" 568 | ] 569 | }, 570 | { 571 | "cell_type": "code", 572 | "execution_count": 8, 573 | "metadata": {}, 574 | "outputs": [ 575 | { 576 | "name": "stdout", 577 | "output_type": "stream", 578 | "text": [ 579 | "0.7358490566037735\n", 580 | "0.6798336798336798\n", 581 | "0.7870722433460076\n" 582 | ] 583 | } 584 | ], 585 | "source": [ 586 | "Pred_label={}\n", 587 | "True_label={}\n", 588 | "vote=13\n", 589 | "for i in range(vote):\n", 590 | " Pred_label={}\n", 591 | " True_label={}\n", 592 | " for index in range(10):\n", 593 | " PATH='../SAVEDModels/PGDtrainedensamble/models_{}_{}'.format(i,index)\n", 594 | " test_asd =test_asd_dict[index]\n", 595 | " test_tdc =test_tdc_dict[index]\n", 596 | " testset=dataset(site=test_site,fmri_root=cpac_root,smri_root=smri_root,ASD=test_asd,TDC=test_tdc)\n", 597 | " testloader=DataLoader(testset,batch_size=1)\n", 598 | " model=ANEGCN(5)\n", 599 | " model.load_state_dict(torch.load(PATH))\n", 600 | " model.eval()\n", 601 | " with torch.no_grad():\n", 602 | " for (X,Z,label,sub_id) in testloader:\n", 603 | " True_label[sub_id.item()]=label.item()\n", 604 | " y=model(X,Z)\n", 605 | " _,predict=torch.max(y,1)\n", 606 | " if sub_id.item() not in Pred_label:\n", 607 | " Pred_label[sub_id.item()]=0\n", 608 | " if predict.item()==1:\n", 609 | " Pred_label[sub_id.item()]+=1 \n", 610 | "TP,TN,FP,FN=0,0,0,0\n", 611 | "for sId in True_label:\n", 612 | " if True_label[sId]==1:\n", 613 | " if Pred_label[sId]>= ((vote+1)//2):\n", 614 | " TP+=1\n", 615 | " else:\n", 616 | " FN+=1\n", 617 | " else:\n", 618 | " if Pred_label[sId]<= ((vote-1)//2):\n", 619 | " TN+=1\n", 620 | " else:\n", 621 | " FP+=1\n", 622 | "print((TP+TN)/1007)\n", 623 | 
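"# 1007 is presumably the pooled test-set size over the 10 folds; the next two ratios are sensitivity and specificity\n",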
"print(TP/(TP+FN))\n", 624 | "print(TN/(FP+TN))" 625 | ] 626 | }, 627 | { 628 | "cell_type": "markdown", 629 | "metadata": {}, 630 | "source": [ 631 | "03: 0.7229;\n", 632 | "07: 0.7319;0.6736;0.7852;\n", 633 | "13: 0.7329;0.6881;0.7738;\n", 634 | "15: 0.7309;0.6861;0.7717;\n", 635 | "17: 0.7269;0.6798;0.7700;\n", 636 | "19: 0.7239;0.6775;0.7662;\n", 637 | "21: 0.7269;0.6819;0.7681;\n", 638 | "23: 0.7239;0.6840;0.7605;" 639 | ] 640 | }, 641 | { 642 | "cell_type": "markdown", 643 | "metadata": {}, 644 | "source": [ 645 | "13: 0.7468 0.7173 0.7738;" 646 | ] 647 | }, 648 | { 649 | "cell_type": "code", 650 | "execution_count": 10, 651 | "metadata": {}, 652 | "outputs": [ 653 | { 654 | "name": "stdout", 655 | "output_type": "stream", 656 | "text": [ 657 | "Acc: 0.6996+0.0415 Sen: 0.6698+0.0771 Spe: 0.7268+0.0702\n", 658 | "Acc: 0.7030+0.0259 Sen: 0.6305+0.0560 Spe: 0.7688+0.0477\n", 659 | "Acc: 0.7044+0.0256 Sen: 0.6602+0.0970 Spe: 0.7442+0.0728\n", 660 | "Acc: 0.6922+0.0298 Sen: 0.6183+0.1002 Spe: 0.7579+0.0887\n", 661 | "Acc: 0.7179+0.0252 Sen: 0.6572+0.1010 Spe: 0.7750+0.0571\n", 662 | "Acc: 0.7069+0.0379 Sen: 0.6794+0.0766 Spe: 0.7311+0.0842\n", 663 | "Acc: 0.7072+0.0420 Sen: 0.6651+0.1061 Spe: 0.7464+0.0674\n", 664 | "Acc: 0.7011+0.0326 Sen: 0.6562+0.0630 Spe: 0.7409+0.0696\n", 665 | "Acc: 0.7094+0.0306 Sen: 0.6657+0.1211 Spe: 0.7506+0.0757\n", 666 | "Acc: 0.7107+0.0355 Sen: 0.6757+0.0976 Spe: 0.7414+0.1029\n", 667 | "Acc: 0.7099+0.0242 Sen: 0.6797+0.0601 Spe: 0.7366+0.0465\n", 668 | "Acc: 0.7152+0.0501 Sen: 0.6783+0.0989 Spe: 0.7472+0.0879\n", 669 | "Acc: 0.7445+0.0929 Sen: 0.6958+0.1601 Spe: 0.7891+0.0919\n" 670 | ] 671 | } 672 | ], 673 | "source": [ 674 | "vote=13\n", 675 | "for i in range(vote):\n", 676 | " Acc,Sen,Spe=np.zeros(10),np.zeros(10),np.zeros(10)\n", 677 | " for index in range(10):\n", 678 | " PATH='../SAVEDModels/PGDtrainedensamble/models_{}_{}'.format(i,index)\n", 679 | " test_asd =test_asd_dict[index]\n", 680 | " test_tdc =test_tdc_dict[index]\n", 681 | " testset=dataset(site=test_site,fmri_root=cpac_root,smri_root=smri_root,ASD=test_asd,TDC=test_tdc)\n", 682 | " testloader=DataLoader(testset,batch_size=1)\n", 683 | " model=ANEGCN(5)\n", 684 | " model.load_state_dict(torch.load(PATH))\n", 685 | " model.eval()\n", 686 | " model=model.to(device)\n", 687 | " with torch.no_grad():\n", 688 | " acc,f1,TP_test,TN_test,FP_test,FN_test=test(device,model,testloader)\n", 689 | " Acc[index]=acc\n", 690 | " Sen[index]=TP_test/(TP_test+FN_test)\n", 691 | " Spe[index]=TN_test/(TN_test+FP_test)\n", 692 | " print('Acc: %.4f+%.4f Sen: %.4f+%.4f Spe: %.4f+%.4f'%(Acc.mean(),Acc.std(),Sen.mean(),Sen.std(),Spe.mean(),Spe.std()))" 693 | ] 694 | }, 695 | { 696 | "cell_type": "code", 697 | "execution_count": 2, 698 | "metadata": {}, 699 | "outputs": [ 700 | { 701 | "name": "stdout", 702 | "output_type": "stream", 703 | "text": [ 704 | "0.0011470317840576172 0.06524252891540527\n" 705 | ] 706 | } 707 | ], 708 | "source": [ 709 | "import time \n", 710 | "import numpy as np\n", 711 | "A=np.random.randn(100000)\n", 712 | "start=time.time()\n", 713 | "B=2*A\n", 714 | "end=time.time()\n", 715 | "time1=end-start\n", 716 | "start=time.time()\n", 717 | "B=[2*i for i in A]\n", 718 | "end=time.time()\n", 719 | "time2=end-start\n", 720 | "print(time1,time2)" 721 | ] 722 | }, 723 | { 724 | "cell_type": "code", 725 | "execution_count": null, 726 | "metadata": {}, 727 | "outputs": [], 728 | "source": [] 729 | } 730 | ], 731 | "metadata": { 732 | "kernelspec": { 733 | "display_name": "Pytorch", 734 | "language": 
"python", 735 | "name": "pytorch" 736 | }, 737 | "language_info": { 738 | "codemirror_mode": { 739 | "name": "ipython", 740 | "version": 3 741 | }, 742 | "file_extension": ".py", 743 | "mimetype": "text/x-python", 744 | "name": "python", 745 | "nbconvert_exporter": "python", 746 | "pygments_lexer": "ipython3", 747 | "version": "3.6.13" 748 | } 749 | }, 750 | "nbformat": 4, 751 | "nbformat_minor": 4 752 | } 753 | -------------------------------------------------------------------------------- /K-fold-withoutsMRI.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch.nn as nn \n", 11 | "import numpy as np\n", 12 | "from torch.utils.data import Dataset\n", 13 | "from torch.utils.data import DataLoader\n", 14 | "import os\n", 15 | "import time\n", 16 | "import random\n", 17 | "import pandas as pd\n", 18 | "import torch.nn.functional as F\n", 19 | "device=torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')\n", 20 | "nan_subid=np.load('nan_subid.npy').tolist()\n", 21 | "import warnings \n", 22 | "warnings.filterwarnings(\"ignore\")\n", 23 | "cpac_root='/media/D/yazid/ASD-classification-ANEGCN/ABIDEI_CPAC/'\n", 24 | "smri_root='/media/D/yazid/ASD-classification-ANEGCN/ABIDEI_sMRI/'\n", 25 | "def setup_seed(seed):\n", 26 | " torch.manual_seed(seed)\n", 27 | " torch.cuda.manual_seed_all(seed)\n", 28 | " np.random.seed(seed)\n", 29 | " random.seed(seed)\n", 30 | " torch.backends.cudnn.deterministic = True" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 2, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "class Attention(nn.Module):\n", 40 | " def __init__(self):\n", 41 | " super(Attention,self).__init__()\n", 42 | " self.conv1=nn.Conv1d(in_channels=2,out_channels=2,kernel_size=1,padding=0)\n", 43 | " self.conv2=nn.Conv1d(in_channels=116,out_channels=116,kernel_size=1,padding=0)\n", 44 | " self.softmax=nn.Softmax(dim=-1)\n", 45 | " def forward(self,Z,X):\n", 46 | " batchsize,x_dim,x_c= X.size()\n", 47 | " batchsize,z_dim,z_c= Z.size()\n", 48 | " K=self.conv1(X.permute(0,2,1))# BS,x_c,x_dim\n", 49 | " Q=K.permute(0,2,1)# BS,x_dim,x_c\n", 50 | " V=self.conv2(Z.permute(0,2,1))# Bs,z_c,z_dim\n", 51 | " attention=self.softmax(torch.matmul(Q,K))#BS,x_dim,x_dim\n", 52 | " out=torch.bmm(attention,V).permute(0,2,1)#BS,z_dim,z_c\n", 53 | " return out\n", 54 | "class ANEGCN(nn.Module):\n", 55 | " def __init__(self,layer):\n", 56 | " super(ANEGCN,self).__init__()\n", 57 | " self.layer =layer\n", 58 | " self.relu =nn.ReLU()\n", 59 | " self.atten =nn.ModuleList([Attention() for i in range(layer)])\n", 60 | " self.norm_n=nn.ModuleList([nn.BatchNorm1d(116) for i in range(layer)])\n", 61 | " self.norm_e=nn.ModuleList([nn.BatchNorm1d(116) for i in range(layer)])\n", 62 | " self.node_w=nn.ParameterList([nn.Parameter(torch.randn((2,2),dtype=torch.float32)) for i in range(layer)])\n", 63 | " self.edge_w=nn.ParameterList([nn.Parameter(torch.randn((116,116),dtype=torch.float32)) for i in range(layer)])\n", 64 | " self.line_n=nn.ModuleList([nn.Sequential(nn.Linear(116*2,128),nn.ReLU(),nn.BatchNorm1d(128)) for i in range(layer+1)])\n", 65 | " self.line_e=nn.ModuleList([nn.Sequential(nn.Linear(116*116,128*3),nn.ReLU(),nn.BatchNorm1d(128*3)) for i in range(layer+1)])\n", 66 | " self.clase =nn.Sequential(nn.Linear(128*4*(self.layer+1),1024),nn.Dropout(0.2),nn.ReLU(),\n", 67 | 
" nn.Linear(1024,2))\n", 68 | " self.ones=nn.Parameter(torch.ones((116),dtype=torch.float32),requires_grad=False)\n", 69 | " self._initialize_weights()\n", 70 | " # params initialization\n", 71 | " def _initialize_weights(self):\n", 72 | " for m in self.modules():\n", 73 | " if isinstance(m, (nn.Conv1d,nn.Linear)):\n", 74 | " nn.init.xavier_uniform_(m.weight)\n", 75 | " if m.bias is not None:\n", 76 | " nn.init.constant_(m.bias, 0)\n", 77 | " elif isinstance(m, nn.BatchNorm1d):\n", 78 | " nn.init.constant_(m.weight, 1)\n", 79 | " nn.init.constant_(m.bias, 0)\n", 80 | " def normalized(self,Z):\n", 81 | " n=Z.size()[0]\n", 82 | " A=Z[0,:,:]\n", 83 | " A=A+torch.diag(self.ones)\n", 84 | " d=A.sum(1)\n", 85 | " D=torch.diag(torch.pow(d,-1))\n", 86 | " A=D.mm(A).reshape(1,116,116)\n", 87 | " for i in range(1,n):\n", 88 | " A1=Z[i,:,:]+torch.diag(self.ones)\n", 89 | " d=A1.sum(1)\n", 90 | " D=torch.diag(torch.pow(d,-1))\n", 91 | " A1=D.mm(A1).reshape(1,116,116)\n", 92 | " A=torch.cat((A,A1),0)\n", 93 | " return A\n", 94 | " \n", 95 | " def update_A(self,Z):\n", 96 | " n=Z.size()[0]\n", 97 | " A=Z[0,:,:]\n", 98 | " Value,_=torch.topk(torch.abs(A.view(-1)),int(116*116*0.2))\n", 99 | " A=(torch.abs(A)>=Value[-1])+torch.tensor(0,dtype=torch.float32)\n", 100 | " A=A.reshape(1,116,116)\n", 101 | " for i in range(1,n):\n", 102 | " A2=Z[i,:,:]\n", 103 | " Value,_=torch.topk(torch.abs(A2.view(-1)),int(116*116*0.2))\n", 104 | " A2=(torch.abs(A2)>=Value[-1])+torch.tensor(0,dtype=torch.float32)\n", 105 | " A2=A2.reshape(1,116,116)\n", 106 | " A=torch.cat((A,A2),0)\n", 107 | " return A\n", 108 | " \n", 109 | " def forward(self,X,Z):\n", 110 | " n=X.size()[0]\n", 111 | " XX=self.line_n[0](X.view(n,-1))\n", 112 | " ZZ=self.line_e[0](Z.view(n,-1))\n", 113 | " for i in range(self.layer):\n", 114 | " A=self.atten[i](Z,X)\n", 115 | " Z1=torch.matmul(A,Z)\n", 116 | " Z2=torch.matmul(Z1,self.edge_w[i])\n", 117 | " Z=self.relu(self.norm_e[i](Z2))+Z\n", 118 | " ZZ=torch.cat((ZZ,self.line_e[i+1](Z.view(n,-1))),dim=1)\n", 119 | " X1=torch.matmul(A,X)\n", 120 | " X1=torch.matmul(X1,self.node_w[i])\n", 121 | " X=self.relu(self.norm_n[i](X1))+X\n", 122 | " #X.register_hook(grad_X_hook)\n", 123 | " #feat_X_hook(X)\n", 124 | " XX=torch.cat((XX,self.line_n[i+1](X.view(n,-1))),dim=1)\n", 125 | " XZ=torch.cat((XX,ZZ),1)\n", 126 | " y=self.clase(XZ)\n", 127 | " #print(self.clase[0].weight)\n", 128 | " return y\n", 129 | "def grad_X_hook(grad):\n", 130 | " X_grad.append(grad)\n", 131 | "def feat_X_hook(X):\n", 132 | " X_feat.append(X.detach())\n", 133 | "X_grad=list()\n", 134 | "X_feat=list()" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 3, 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [ 143 | "class LabelSmoothLoss(nn.Module):\n", 144 | " \n", 145 | " def __init__(self, smoothing=0.0):\n", 146 | " super(LabelSmoothLoss, self).__init__()\n", 147 | " self.smoothing = smoothing\n", 148 | " \n", 149 | " def forward(self, input, target):\n", 150 | " log_prob = F.log_softmax(input, dim=-1)\n", 151 | " weight = input.new_ones(input.size()) * \\\n", 152 | " self.smoothing / (input.size(-1) - 1.)\n", 153 | " weight.scatter_(-1, target.unsqueeze(-1), (1. 
- self.smoothing))\n", 154 | "        loss = (-weight * log_prob).sum(dim=-1).mean()\n", 155 | "        return loss\n", 156 | "def data_split(full_list, ratio, shuffle=True):\n", 157 | "    \"\"\"\n", 158 | "    Dataset split: randomly partition full_list by ratio into two sublists, sublist_1 and sublist_2\n", 159 | "    :param full_list: list of samples to split\n", 160 | "    :param ratio: fraction of samples placed in sublist_1\n", 161 | "    :param shuffle: whether to shuffle full_list before splitting\n", 162 | "    :return: (sublist_1, sublist_2)\n", 163 | "    \"\"\"\n", 164 | "    n_total = len(full_list)\n", 165 | "    offset = int(n_total * ratio)\n", 166 | "    if n_total == 0 or offset < 1:\n", 167 | "        return [], full_list\n", 168 | "    if shuffle:\n", 169 | "        random.shuffle(full_list)\n", 170 | "    sublist_1 = full_list[:offset]\n", 171 | "    sublist_2 = full_list[offset:]\n", 172 | "    return sublist_1, sublist_2\n", 173 | "def data_2_k(full_list,k,shuffle=True):\n", 174 | "    n_total=len(full_list)\n", 175 | "    if shuffle:\n", 176 | "        random.shuffle(full_list)\n", 177 | "    data_list_list=[]\n", 178 | "    for i in range(k):\n", 179 | "        data_list_list.append(full_list[int(i*n_total/k):int((i+1)*n_total/k)])\n", 180 | "    return data_list_list\n", 181 | "def test(device,model,testloader):\n", 182 | "    model.eval()\n", 183 | "    TP_test,TN_test,FP_test,FN_test=0,0,0,0\n", 184 | "    with torch.no_grad():\n", 185 | "        for (X,Z,label,sub_id) in testloader:\n", 186 | "            TP,TN,FN,FP=0,0,0,0\n", 187 | "            n=X.size()[0]\n", 188 | "            X=X.to(device)\n", 189 | "            Z=Z.to(device)\n", 190 | "            label=label.to(device)\n", 191 | "            y=model(X,Z)\n", 192 | "            _,predict=torch.max(y,1)\n", 193 | "            TP+=((predict==1)&(label==1)).sum().item()\n", 194 | "            TN+=((predict==0)&(label==0)).sum().item()\n", 195 | "            FN+=((predict==0)&(label==1)).sum().item()\n", 196 | "            FP+=((predict==1)&(label==0)).sum().item()\n", 197 | "            TP_test+=TP\n", 198 | "            TN_test+=TN\n", 199 | "            FP_test+=FP\n", 200 | "            FN_test+=FN\n", 201 | "    acc,f1=cal_evaluate(TP_test,TN_test,FP_test,FN_test)\n", 202 | "    global max_acc\n", 203 | "    global modelname\n", 204 | "    global saveModel\n", 205 | "    if acc>=max_acc:\n", 206 | "        max_acc=acc\n", 207 | "        if saveModel:\n", 208 | "            torch.save(model.state_dict(),modelname)\n", 209 | "    return acc,f1,TP_test,TN_test,FP_test,FN_test\n", 210 | "# dict mapping between edge indices and node pairs (built by cal_dict below)\n", 211 | "def gradient(device,model,dataloader):\n", 212 | "    model.eval()\n", 213 | "    for (X,Z,A,label,sub_id) in dataloader:\n", 214 | "        X=torch.autograd.Variable(X,requires_grad=True)\n", 215 | "        x=X.to(device)\n", 216 | "        Z=torch.autograd.Variable(Z,requires_grad=True)\n", 217 | "        z=Z.to(device)\n", 218 | "        A=torch.autograd.Variable(A,requires_grad=True)\n", 219 | "        a=A.to(device)\n", 220 | "        y=model(x,z,a)\n", 221 | "        if (label==torch.FloatTensor([0])).item():\n", 222 | "            print('0')\n", 223 | "            #y.autograd.backward(torch.FloatTensor([[1.,0.]]).to(device))\n", 224 | "            torch.autograd.backward(y,torch.FloatTensor([[1.,0.]]).to(device))\n", 225 | "        else:\n", 226 | "            print('1')\n", 227 | "            torch.autograd.backward(y,torch.FloatTensor([[0.,1.]]).to(device))\n", 228 | "        grad_X=X.grad\n", 229 | "        grad_Z=Z.grad\n", 230 | "        #print(grad_X)\n", 231 | "        value_x,index_x=torch.topk(torch.abs(grad_X.view(-1)),10)\n", 232 | "        grad_X_topk=(torch.abs(grad_X)>=value_x[-1])\n", 233 | "        value_z,index_z=torch.topk(torch.abs(grad_Z.view(-1)),100)\n", 234 | "        grad_Z_topk=(torch.abs(grad_Z)>=value_z[-1])\n", 235 | "        global gradsave_dict\n", 236 | "        if label==torch.FloatTensor([0]).item():\n", 237 | "            np.save(gradsave_dict+'/TDC/Z/'+str(sub_id.item()),grad_Z.numpy())\n", 238 | "            np.save(gradsave_dict+'/TDC/X/'+str(sub_id.item()),grad_X.numpy())\n", 239 | "        else:\n", 240 | "            
np.save(gradsave_dict+'/ASD/Z/'+str(sub_id.item()),grad_Z.numpy())\n", 241 | "            np.save(gradsave_dict+'/ASD/X/'+str(sub_id.item()),grad_X.numpy())\n", 242 | "        \n", 243 | "def cal_dict():\n", 244 | "    index=0\n", 245 | "    A={}\n", 246 | "    for i in range(116):\n", 247 | "        for j in range(i+1,116):\n", 248 | "            A[index]=(i,j)\n", 249 | "            A[(i,j)]=index\n", 250 | "            index+=1\n", 251 | "    return A\n", 252 | "def cal_evaluate(TP,TN,FP,FN):\n", 253 | "    if TP>0:\n", 254 | "        p = TP / (TP + FP)\n", 255 | "        r = TP / (TP + FN)\n", 256 | "        F1 = 2 * r * p / (r + p)\n", 257 | "    else:\n", 258 | "        F1=0\n", 259 | "    acc = (TP + TN) / (TP + TN + FP + FN)\n", 260 | "    #print('ACC:%.4f F1:%.4f [TP:%d|TN:%d|FP:%d|FN:%d]'%(acc,F1,TP,TN,FP,FN))\n", 261 | "    return acc,F1\n", 262 | "def data_arange(sites,fmri_root,smri_root,nan_subid):\n", 263 | "    asd,tdc=[],[]\n", 264 | "    for site in sites:\n", 265 | "        mri_asd=os.listdir(smri_root+site+'/group1')\n", 266 | "        mri_tdc=os.listdir(smri_root+site+'/group2')\n", 267 | "        fmri_asd=os.listdir(fmri_root+site+'/group1_FC')\n", 268 | "        fmri_tdc=os.listdir(fmri_root+site+'/group2_FC')\n", 269 | "        site_asd=[i for i in mri_asd if i in fmri_asd ]\n", 270 | "        site_tdc=[i for i in mri_tdc if i in fmri_tdc ]\n", 271 | "        site_asd=[i for i in site_asd if int(i[:5]) not in nan_subid]\n", 272 | "        site_tdc=[i for i in site_tdc if int(i[:5]) not in nan_subid]\n", 273 | "        asd.append(site_asd)\n", 274 | "        tdc.append(site_tdc)\n", 275 | "    return asd,tdc\n", 276 | "class dataset(Dataset):\n", 277 | "    def __init__(self,fmri_root,site,ASD,TDC):\n", 278 | "        super(dataset,self).__init__()\n", 279 | "        self.fmri=fmri_root\n", 280 | "        self.ASD=[j for i in ASD for j in i]\n", 281 | "        self.TDC=[j for i in TDC for j in i]\n", 282 | "        self.data=self.ASD+self.TDC\n", 283 | "        random.shuffle(self.data)\n", 284 | "        self.data_site={}\n", 285 | "        for i in range(len(site)):\n", 286 | "            data=ASD[i]+TDC[i]\n", 287 | "            for j in data:\n", 288 | "                if j not in self.data_site:\n", 289 | "                    self.data_site[j]=site[i] \n", 290 | "    def __getitem__(self,index):\n", 291 | "        data=self.data[index]\n", 292 | "        sub_id=int(data[0:5])\n", 293 | "        if data in self.ASD:\n", 294 | "            data_slow5 =np.load(self.fmri+self.data_site[data]+'/group1_slow5/'+data,allow_pickle=True)\n", 295 | "            data_slow4 =np.load(self.fmri+self.data_site[data]+'/group1_slow4/'+data,allow_pickle=True)\n", 296 | "            data_FCz   =np.load(self.fmri+self.data_site[data]+'/group1_FC/'+data,allow_pickle=True)\n", 297 | "        elif data in self.TDC:\n", 298 | "            data_slow5 =np.load(self.fmri+self.data_site[data]+'/group2_slow5/'+data,allow_pickle=True)\n", 299 | "            data_slow4 =np.load(self.fmri+self.data_site[data]+'/group2_slow4/'+data,allow_pickle=True)\n", 300 | "            data_FCz   =np.load(self.fmri+self.data_site[data]+'/group2_FC/'+data,allow_pickle=True)\n", 301 | "        else:\n", 302 | "            print('wrong input')\n", 303 | "        data_slow5=(data_slow5-np.min(data_slow5))/(np.max(data_slow5)-np.min(data_slow5))\n", 304 | "        data_slow4=(data_slow4-np.min(data_slow4))/(np.max(data_slow4)-np.min(data_slow4))\n", 305 | "        if np.any(np.isnan(data_slow5)) or np.any(np.isnan(data_slow4)) or np.any(np.isnan(data_FCz)):\n", 306 | "            print('data wrong')\n", 307 | "        #data_FCz=(data_FCz-np.min(data_FCz))/(np.max(data_FCz)-np.min(data_FCz))\n", 308 | "        if self.data[index] in self.ASD:\n", 309 | "            label=torch.tensor(1)\n", 310 | "        else:\n", 311 | "            label=torch.tensor(0)\n", 312 | "        X=np.zeros((116,2),dtype=np.float32)\n", 313 | "        X[:,0]=data_slow5\n", 314 | "        X[:,1]=data_slow4\n", 315 | "        
data_FCz=data_FCz.astype(np.float32)\n", 316 | "        Z=torch.from_numpy(data_FCz)\n", 317 | "        X=torch.from_numpy(X)\n", 318 | "        return X,Z,label,sub_id\n", 319 | "    def __len__(self):\n", 320 | "        return len(self.data)\n", 321 | "def get_acc(acc_list,toprate):\n", 322 | "    acc_list.sort()\n", 323 | "    return acc_list[-int(toprate*len(acc_list))]\n", 324 | "def plot_acc(acc_list):\n", 325 | "    import matplotlib.pyplot as plt; num_bins=50 # plt is not imported in this notebook's first cell\n", 326 | "    fig,ax=plt.subplots(2)\n", 327 | "    n,bins,patches=ax[0].hist(acc_list,num_bins,density=True)\n", 328 | "    ax[1].plot(acc_list)\n", 329 | "    ax[1].set_ylim(0.4,1)\n", 330 | "    plt.show()\n", 331 | "    print('Top:10%:',get_acc(acc_list,0.1))\n", 332 | "    print('Top:20%:',get_acc(acc_list,0.2))\n", 333 | "    print('Max:    ',max(acc_list))\n", 334 | "def train_adversarial(model,trainloader,testloader,epsilon): # FGSM adversarial training (not called in this notebook)\n", 335 | "    acc_list=[]\n", 336 | "    criterian1=LabelSmoothLoss(0.1).to(device)\n", 337 | "    optimizer=torch.optim.Adam(model.parameters(),lr=learning_rate)\n", 338 | "    scheduler=torch.optim.lr_scheduler.ExponentialLR(optimizer,gamma=gmma) # NOTE: gmma is not defined in this notebook; set it (e.g. gmma=0.95) before calling\n", 339 | "    for _ in range(epoch):\n", 340 | "        TP,TN,FP,FN=0,0,0,0\n", 341 | "        for (X,Z,label,sub_id) in trainloader: # the dataset here yields 4 items (no adjacency A)\n", 342 | "            model.train()\n", 343 | "            x=X.to(device)\n", 344 | "            z=Z.to(device)\n", 345 | "            x.requires_grad=True\n", 346 | "            z.requires_grad=True\n", 347 | "            label=label.to(device)\n", 348 | "            y=model(x,z)\n", 349 | "            loss=criterian1(y,label)\n", 350 | "            model.zero_grad()\n", 351 | "            loss.backward(retain_graph=True)\n", 352 | "            sign_grad_x=torch.sign(x.grad.data)\n", 353 | "            sign_grad_z=torch.sign(z.grad.data)\n", 354 | "            perturbed_x=x+epsilon*sign_grad_x \n", 355 | "            perturbed_z=z+epsilon*sign_grad_z\n", 356 | "            perturbed_x=torch.clamp(perturbed_x,0,1)\n", 357 | "            perturbed_z=torch.clamp(perturbed_z,-1,1)\n", 358 | "            y=model(perturbed_x,perturbed_z)\n", 359 | "            L2=torch.tensor(0,dtype=torch.float32).to(device)\n", 360 | "            if L2_lamda>0:\n", 361 | "                for name,parameters in model.named_parameters():\n", 362 | "                    if name[0:5]=='clase' and name[-8:]=='0.weight':\n", 363 | "                        L2+=L2_lamda*torch.norm(parameters,2)\n", 364 | "            loss=0.5*(criterian1(y,label)+loss)+L2\n", 365 | "            optimizer.zero_grad()\n", 366 | "            loss.backward()\n", 367 | "            optimizer.step()\n", 368 | "        acc,f1=test(device,model,testloader)\n", 369 | "        acc_list.append(acc)\n", 370 | "    plot_acc(acc_list)\n", 371 | "    print('Top:10|%.4f\\tTop:20|%.4f\\tMax|%.4f'%(get_acc(acc_list,0.1),get_acc(acc_list,0.2),max(acc_list)))\n", 372 | "def train_pgd_v1(model,trainloader,testloader,eps=0.2,iters=10,alpha=2/255): # kept for reference; superseded by the train_pgd defined below\n", 373 | "    acc_list=[]\n", 374 | "    criterian1=LabelSmoothLoss(0.1).to(device)\n", 375 | "    optimizer=torch.optim.Adam(model.parameters(),lr=learning_rate)\n", 376 | "    scheduler=torch.optim.lr_scheduler.ExponentialLR(optimizer,gamma=gmma)\n", 377 | "    for _ in range(epoch):\n", 378 | "        TP,TN,FP,FN=0,0,0,0\n", 379 | "        for (X,Z,label,sub_id) in trainloader:\n", 380 | "            model.train()\n", 381 | "            x=X.to(device)\n", 382 | "            z=Z.to(device)\n", 383 | "            label=label.to(device)\n", 384 | "            for i in range(iters):\n", 385 | "                x.requires_grad=True\n", 386 | "                z.requires_grad=True\n", 387 | "                y=model(x,z)\n", 388 | "                loss=criterian1(y,label)\n", 389 | "                model.zero_grad()\n", 390 | "                loss.backward()\n", 391 | "                adv_x=x+alpha*torch.sign(x.grad.data)\n", 392 | "                adv_z=z+alpha*torch.sign(z.grad.data)\n", 393 | "                eta_x=torch.clamp(adv_x-x,min=-eps,max=eps)\n", 394 | "                eta_z=torch.clamp(adv_z-z,min=-eps,max=eps)\n", 395 | "                x=torch.clamp(x+eta_x,min=0,max=1).detach_()\n", 396 | "                
z=torch.clamp(z+eta_z,min=-1,max=1).detach_()\n", 397 | "            y=model(x,z)\n", 398 | "            L2=torch.tensor(0,dtype=torch.float32).to(device)\n", 399 | "            if L2_lamda>0:\n", 400 | "                for name,parameters in model.named_parameters():\n", 401 | "                    if name[0:5]=='clase' and name[-8:]=='0.weight':\n", 402 | "                        L2+=L2_lamda*torch.norm(parameters,2)\n", 403 | "            loss=criterian1(y,label)+L2\n", 404 | "            optimizer.zero_grad()\n", 405 | "            loss.backward()\n", 406 | "            optimizer.step()\n", 407 | "        acc,f1=test(device,model,testloader)\n", 408 | "        acc_list.append(acc)\n", 409 | "    plot_acc(acc_list)\n", 410 | "    print('Top:10|%.4f\\tTop:20|%.4f\\tMax|%.4f'%(get_acc(acc_list,0.1),get_acc(acc_list,0.2),max(acc_list)))\n", 411 | "def train_pgd(model,trainloader,testloader,eps=0.2,iters=10,alpha=8/255): # this version (clean + adversarial loss) is the one used in the cells below\n", 412 | "    result=pd.DataFrame(columns=('Loss','Acc','F1','TP','TN','FP','FN'))\n", 413 | "    criterian1=LabelSmoothLoss(0.1).to(device)\n", 414 | "    optimizer=torch.optim.Adam(model.parameters(),lr=learning_rate)\n", 415 | "    for j in range(epoch):\n", 416 | "        loss_sum=0.\n", 417 | "        TP,TN,FP,FN=0,0,0,0\n", 418 | "        model.train()\n", 419 | "        for (X,Z,label,sub_id) in trainloader:\n", 420 | "            model.train()\n", 421 | "            x=X.to(device)\n", 422 | "            z=Z.to(device)\n", 423 | "            label=label.to(device)\n", 424 | "            pretu_x,pretu_z=x,z\n", 425 | "            ori_x,ori_z=x.data,z.data\n", 426 | "            for i in range(iters):\n", 427 | "                pretu_x.requires_grad=True\n", 428 | "                pretu_z.requires_grad=True\n", 429 | "                y=model(pretu_x,pretu_z)\n", 430 | "                loss=criterian1(y,label)\n", 431 | "                model.zero_grad()\n", 432 | "                loss.backward()\n", 433 | "                adv_x=pretu_x+alpha*torch.sign(pretu_x.grad.data)\n", 434 | "                adv_z=pretu_z+alpha*torch.sign(pretu_z.grad.data)\n", 435 | "                eta_x=torch.clamp(adv_x-ori_x,min=-eps,max=eps)\n", 436 | "                eta_z=torch.clamp(adv_z-ori_z,min=-eps,max=eps)\n", 437 | "                pretu_x=torch.clamp(ori_x+eta_x,min=0,max=1).detach_()\n", 438 | "                pretu_z=torch.clamp(ori_z+eta_z,min=-1,max=1).detach_()\n", 439 | "            y=model(x,z)\n", 440 | "            yy=model(pretu_x,pretu_z)\n", 441 | "            L2=torch.tensor(0,dtype=torch.float32).to(device)\n", 442 | "            if L2_lamda>0:\n", 443 | "                for name,parameters in model.named_parameters():\n", 444 | "                    if name[0:5]=='clase' and name[-8:]=='0.weight':\n", 445 | "                        L2+=L2_lamda*torch.norm(parameters,2)\n", 446 | "            loss=0.5*(criterian1(yy,label)+criterian1(y,label))+L2\n", 447 | "            loss_sum+=loss.item()\n", 448 | "            optimizer.zero_grad()\n", 449 | "            loss.backward()\n", 450 | "            optimizer.step()\n", 451 | "        if (j+1)%10==0:\n", 452 | "            acc,f1,TP_test,TN_test,FP_test,FN_test=test(device,model,testloader)\n", 453 | "            result.loc[(j+1)//10]=[loss_sum,acc,f1,TP_test,TN_test,FP_test,FN_test]\n", 454 | "    result.sort_values('Acc',inplace=True,ascending=False)\n", 455 | "    Sen=result.iloc[0]['TP']/(result.iloc[0]['TP']+result.iloc[0]['FN'])\n", 456 | "    Spe=result.iloc[0]['TN']/(result.iloc[0]['TN']+result.iloc[0]['FP'])\n", 457 | "    print('Acc: {:.4f} Sen: {:.4f} Spe: {:.4f}'.format(result.iloc[0]['Acc'],Sen,Spe))\n", 458 | "    return result.iloc[0]['Acc'],Sen,Spe\n", 459 | "def train(model,trainloader,testloader):\n", 460 | "    acc_list=[]\n", 461 | "    criterian1=LabelSmoothLoss(0.1).to(device)\n", 462 | "    optimizer=torch.optim.Adam(model.parameters(),lr=learning_rate)\n", 463 | "    scheduler=torch.optim.lr_scheduler.ExponentialLR(optimizer,gamma=gmma) # NOTE: gmma is not defined in this notebook; set it before calling train()\n", 464 | "    for j in range(epoch):\n", 465 | "        loss_sum=0\n", 466 | "        TP,TN,FP,FN=0,0,0,0\n", 467 | "        time_start=time.time()\n", 468 | "        loss1_sum=loss2_sum=0\n", 469 | "        for (X,Z,label,sub_id) in 
trainloader:\n", 470 | " #print(A)\n", 471 | " #print(Z)\n", 472 | " #print(X.shape,torch.mean(X),torch.std(X))\n", 473 | " model.train()\n", 474 | " X=X.to(device)\n", 475 | " Z=Z.to(device)\n", 476 | " label=label.to(device)\n", 477 | " y=model(X,Z)\n", 478 | " #print(y)\n", 479 | " _,predict=torch.max(y,1)\n", 480 | " TP+=((predict==1)&(label==1)).sum().item()\n", 481 | " TN+=((predict==0)&(label==0)).sum().item()\n", 482 | " FN+=((predict==0)&(label==1)).sum().item()\n", 483 | " FP+=((predict==1)&(label==0)).sum().item()\n", 484 | " loss=criterian1(y,label)\n", 485 | " L1=torch.tensor(0,dtype=torch.float32).to(device)\n", 486 | " L2=torch.tensor(0,dtype=torch.float32).to(device)\n", 487 | " #print(model.parameters.weit_n)\n", 488 | " if L1_lamda>0 or L2_lamda>0:\n", 489 | " for name,parameters in model.named_parameters():\n", 490 | " if name[0:5]=='clase' and name[-8:]=='0.weight':\n", 491 | " L1+=L1_lamda*torch.norm(parameters,1)\n", 492 | " L2+=L2_lamda*torch.norm(parameters,2)\n", 493 | " loss+=(L1+L2)\n", 494 | " loss_sum+=loss.item()\n", 495 | " optimizer.zero_grad()\n", 496 | " loss.backward()\n", 497 | " #print(model.clase[0].weight.grad)\n", 498 | " optimizer.step() \n", 499 | " time_end=time.time()\n", 500 | " time_cost=(time_end-time_start)/60.0\n", 501 | " acc,f1=cal_evaluate(TP,TN,FP,FN)\n", 502 | " #print(\"[%2d/%d] ACC:%.2f F1:%.2f Loss: %.4f [TP:%3d|TN:%3d|FP:%3d|FN:%3d] CostTime:%4.1f min | RestTime:%.2f h\" %(\n", 503 | " # j+1,epoch,acc,f1,loss_sum,TP,TN,FP,FN,time_cost,time_cost/60*(epoch-1-j)))\n", 504 | " #losses.append(loss_sum)\n", 505 | " #print(model.parameters())\n", 506 | " acc,f1=test(device,model,testloader)\n", 507 | " #print(acc)\n", 508 | " acc_list.append(acc)\n", 509 | " \n", 510 | " plot_acc(acc_list)\n", 511 | " print('Top:10|%.4f\\tTop:20|%.4f\\tMax|%.4f'%(get_acc(acc_list,0.1),get_acc(acc_list,0.2),max(acc_list)))" 512 | ] 513 | }, 514 | { 515 | "cell_type": "raw", 516 | "metadata": {}, 517 | "source": [ 518 | "## 机构混合\n", 519 | "test_site =os.listdir(cpac_root)\n", 520 | "train_site=test_site\n", 521 | "k_fold=10 \n", 522 | "asd,tdc=data_arange(test_site,fmri_root=cpac_root,nan_subid=nan_subid,smri_root=smri_root)\n", 523 | "train_asd_dict={i:[] for i in range(k_fold)}\n", 524 | "train_tdc_dict={i:[] for i in range(k_fold)}\n", 525 | "test_asd_dict ={i:[] for i in range(k_fold)}\n", 526 | "test_tdc_dict ={i:[] for i in range(k_fold)}\n", 527 | "for data in asd:\n", 528 | " datasplit=data_2_k(data,k_fold,True)\n", 529 | " for index in range(k_fold):\n", 530 | " test_asd_dict[index].append(datasplit[index])\n", 531 | " train_temp=[j for i in datasplit for j in i if j not in datasplit[index]]\n", 532 | " train_asd_dict[index].append(train_temp)\n", 533 | "for data in tdc:\n", 534 | " datasplit=data_2_k(data,k_fold,True)\n", 535 | " for index in range(k_fold):\n", 536 | " test_tdc_dict[index].append(datasplit[index])\n", 537 | " test_temp=[j for i in datasplit for j in i if j not in datasplit[index]]\n", 538 | " train_tdc_dict[index].append(test_temp)" 539 | ] 540 | }, 541 | { 542 | "cell_type": "raw", 543 | "metadata": {}, 544 | "source": [ 545 | "## 机构混合\n", 546 | "np.save('train_asd_dict.npy',train_asd_dict)\n", 547 | "np.save('train_tdc_dict.npy',train_tdc_dict)\n", 548 | "np.save('test_asd_dict.npy',test_asd_dict)\n", 549 | "np.save('test_tdc_dict.npy',test_tdc_dict)\n", 550 | "np.save('train_test_site.npy',train_site)" 551 | ] 552 | }, 553 | { 554 | "cell_type": "code", 555 | "execution_count": 4, 556 | "metadata": {}, 557 | "outputs": [], 558 | 
"source": [ 559 | "## 机构混合续\n", 560 | "train_site=test_site=np.load('DATAARRANGE/train_test_site.npy')\n", 561 | "train_asd_dict=np.load('DATAARRANGE/train_asd_dict.npy',allow_pickle=True).item()\n", 562 | "train_tdc_dict=np.load('DATAARRANGE/train_tdc_dict.npy',allow_pickle=True).item()\n", 563 | "test_asd_dict=np.load('DATAARRANGE/test_asd_dict.npy',allow_pickle=True).item()\n", 564 | "test_tdc_dict=np.load('DATAARRANGE/test_tdc_dict.npy',allow_pickle=True).item()" 565 | ] 566 | }, 567 | { 568 | "cell_type": "code", 569 | "execution_count": 5, 570 | "metadata": {}, 571 | "outputs": [ 572 | { 573 | "name": "stdout", 574 | "output_type": "stream", 575 | "text": [ 576 | "Acc: 0.7059 Sen: 0.6250 Spe: 0.7778\n", 577 | "Acc: 0.6765 Sen: 0.6250 Spe: 0.7222\n", 578 | "Acc: 0.6634 Sen: 0.6200 Spe: 0.7059\n", 579 | "Acc: 0.6990 Sen: 0.6458 Spe: 0.7455\n", 580 | "Acc: 0.6796 Sen: 0.5400 Spe: 0.8113\n", 581 | "Acc: 0.6979 Sen: 0.6522 Spe: 0.7400\n", 582 | "Acc: 0.7200 Sen: 0.7660 Spe: 0.6792\n", 583 | "Acc: 0.6538 Sen: 0.6667 Spe: 0.6415\n", 584 | "Acc: 0.6869 Sen: 0.5745 Spe: 0.7885\n", 585 | "Acc: 0.7193 Sen: 0.5926 Spe: 0.8333\n" 586 | ] 587 | } 588 | ], 589 | "source": [ 590 | "global max_acc\n", 591 | "global saveModel\n", 592 | "global modelname\n", 593 | "L1_lamda=0.0\n", 594 | "L2_lamda=0.0001\n", 595 | "learning_rate=0.0001\n", 596 | "epoch =100\n", 597 | "batch_size=64\n", 598 | "layer =5\n", 599 | "Acc,Sen,Spe=np.zeros(10),np.zeros(10),np.zeros(10)\n", 600 | "for index in range(10):\n", 601 | " saveModel=False\n", 602 | " max_acc=0.6\n", 603 | " train_asd=train_asd_dict[index]\n", 604 | " train_tdc=train_tdc_dict[index]\n", 605 | " test_asd =test_asd_dict[index]\n", 606 | " test_tdc =test_tdc_dict[index]\n", 607 | " trainset=dataset(site=train_site,fmri_root=cpac_root,ASD=train_asd,TDC=train_tdc)\n", 608 | " trainloader=DataLoader(trainset,batch_size=batch_size,shuffle=True)\n", 609 | " testset=dataset(site=test_site,fmri_root=cpac_root,ASD=test_asd,TDC=test_tdc)\n", 610 | " testloader=DataLoader(testset,batch_size=1)\n", 611 | " model=ANEGCN(layer).to(device)\n", 612 | " Acc[index],Sen[index],Spe[index]=train_pgd(model,trainloader,testloader,eps=0.001,iters=10,alpha=2/(255*50))" 613 | ] 614 | }, 615 | { 616 | "cell_type": "raw", 617 | "metadata": { 618 | "jupyter": { 619 | "source_hidden": true 620 | } 621 | }, 622 | "source": [ 623 | "#to find wrong data\n", 624 | "''' \n", 625 | "to find nan_subject_id \n", 626 | "'''\n", 627 | "nan_subid=[]\n", 628 | "for index in range(10):\n", 629 | " test_asd =test_asd_dict[index]\n", 630 | " test_tdc =test_tdc_dict[index]\n", 631 | " testset=dataset(site=test_site,fmri_root=cpac_root,smri_root=smri_root,ASD=test_asd,TDC=test_tdc,edge_dict=cal_dict())\n", 632 | " testloader=DataLoader(testset,batch_size=1)\n", 633 | " model=NEResGCN(5)\n", 634 | " model.eval()\n", 635 | " for (X,Z,A,label,sub_id) in testloader:\n", 636 | " sub_id=sub_id.item()\n", 637 | " y=model(X,Z,A)\n", 638 | " if torch.any(torch.isnan(y)):\n", 639 | " print(sub_id)\n", 640 | " nan_subid.append(sub_id)" 641 | ] 642 | }, 643 | { 644 | "cell_type": "code", 645 | "execution_count": null, 646 | "metadata": {}, 647 | "outputs": [], 648 | "source": [] 649 | } 650 | ], 651 | "metadata": { 652 | "kernelspec": { 653 | "display_name": "Pytorch", 654 | "language": "python", 655 | "name": "pytorch" 656 | }, 657 | "language_info": { 658 | "codemirror_mode": { 659 | "name": "ipython", 660 | "version": 3 661 | }, 662 | "file_extension": ".py", 663 | "mimetype": "text/x-python", 664 | "name": 
"python", 665 | "nbconvert_exporter": "python", 666 | "pygments_lexer": "ipython3", 667 | "version": "3.6.13" 668 | }, 669 | "toc-autonumbering": false, 670 | "toc-showcode": true, 671 | "toc-showmarkdowntxt": false, 672 | "toc-showtags": false 673 | }, 674 | "nbformat": 4, 675 | "nbformat_minor": 4 676 | } 677 | -------------------------------------------------------------------------------- /Leave-one-site-out.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch.nn as nn \n", 11 | "import numpy as np\n", 12 | "from torch.utils.data import Dataset\n", 13 | "from torch.utils.data import DataLoader\n", 14 | "import os\n", 15 | "import time\n", 16 | "import random\n", 17 | "import pandas as pd\n", 18 | "import torch.nn.functional as F\n", 19 | "device=torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')\n", 20 | "cpac_root='/media/D/yazid/ASD-classification-ANEGCN/ABIDEI_CPAC/'\n", 21 | "smri_root='/media/D/yazid/ASD-classification-ANEGCN/ABIDEI_sMRI/'\n", 22 | "nan_subid=np.load('nan_subid.npy').tolist()\n", 23 | "import warnings \n", 24 | "warnings.filterwarnings(\"ignore\")\n", 25 | "random_seed = 7777\n", 26 | "def setup_seed(seed):\n", 27 | " torch.manual_seed(seed)\n", 28 | " torch.cuda.manual_seed_all(seed)\n", 29 | " np.random.seed(seed)\n", 30 | " random.seed(seed)\n", 31 | " torch.backends.cudnn.deterministic = True" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 2, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "class Attention(nn.Module):\n", 41 | " def __init__(self):\n", 42 | " super(Attention,self).__init__()\n", 43 | " self.conv1=nn.Conv1d(in_channels=3,out_channels=3,kernel_size=1,padding=0)\n", 44 | " self.conv2=nn.Conv1d(in_channels=116,out_channels=116,kernel_size=1,padding=0)\n", 45 | " self.softmax=nn.Softmax(dim=-1)\n", 46 | " def forward(self,Z,X):\n", 47 | " K=self.conv1(X.permute(0,2,1))\n", 48 | " Q=K.permute(0,2,1)\n", 49 | " V=self.conv2(Z.permute(0,2,1))\n", 50 | " attention=self.softmax(torch.matmul(Q,K))\n", 51 | " out=torch.bmm(attention,V).permute(0,2,1)\n", 52 | " return out\n", 53 | "class NEGAN(nn.Module):\n", 54 | " def __init__(self,layer,dropout_rate):\n", 55 | " super(NEGAN,self).__init__()\n", 56 | " self.layer =layer\n", 57 | " self.relu =nn.ReLU()\n", 58 | " self.atten =nn.ModuleList([Attention() for i in range(layer)])\n", 59 | " self.norm_n=nn.ModuleList([nn.BatchNorm1d(116) for i in range(layer)])\n", 60 | " self.norm_e=nn.ModuleList([nn.BatchNorm1d(116) for i in range(layer)])\n", 61 | " self.node_w=nn.ParameterList([nn.Parameter(torch.randn((3,3),dtype=torch.float32)) for i in range(layer)])\n", 62 | " self.edge_w=nn.ParameterList([nn.Parameter(torch.randn((116,116),dtype=torch.float32)) for i in range(layer)])\n", 63 | " self.line_n=nn.ModuleList([nn.Sequential(nn.Linear(116*3,128),nn.ReLU(),nn.BatchNorm1d(128)) for i in range(layer+1)])\n", 64 | " self.line_e=nn.ModuleList([nn.Sequential(nn.Linear(116*116,128*3),nn.ReLU(),nn.BatchNorm1d(128*3)) for i in range(layer+1)])\n", 65 | " self.clase =nn.Sequential(nn.Linear(128*4*(self.layer+1),1024),nn.Dropout(dropout_rate),nn.ReLU(),\n", 66 | " nn.Linear(1024,2))\n", 67 | " self.ones=nn.Parameter(torch.ones((116),dtype=torch.float32),requires_grad=False)\n", 68 | " self._initialize_weights()\n", 69 | "\n", 70 | " # params initialization\n", 
71 | " def _initialize_weights(self):\n", 72 | " for m in self.modules():\n", 73 | " if isinstance(m, (nn.Conv1d,nn.Linear)):\n", 74 | " nn.init.xavier_uniform_(m.weight)\n", 75 | " if m.bias is not None:\n", 76 | " nn.init.constant_(m.bias, 0)\n", 77 | " elif isinstance(m, nn.BatchNorm1d):\n", 78 | " nn.init.constant_(m.weight, 1)\n", 79 | " nn.init.constant_(m.bias, 0)\n", 80 | " def normalized(self,Z):\n", 81 | " n=Z.size()[0]\n", 82 | " A=Z[0,:,:]\n", 83 | " A=A+torch.diag(self.ones)\n", 84 | " d=A.sum(1)\n", 85 | " D=torch.diag(torch.pow(d,-1))\n", 86 | " A=D.mm(A).reshape(1,116,116)\n", 87 | " for i in range(1,n):\n", 88 | " A1=Z[i,:,:]+torch.diag(self.ones)\n", 89 | " d=A1.sum(1)\n", 90 | " D=torch.diag(torch.pow(d,-1))\n", 91 | " A1=D.mm(A1).reshape(1,116,116)\n", 92 | " A=torch.cat((A,A1),0)\n", 93 | " return A\n", 94 | " \n", 95 | " def update_A(self,Z):\n", 96 | " n=Z.size()[0]\n", 97 | " A=Z[0,:,:]\n", 98 | " Value,_=torch.topk(torch.abs(A.view(-1)),int(116*116*0.2))\n", 99 | " A=(torch.abs(A)>=Value[-1])+torch.tensor(0,dtype=torch.float32)\n", 100 | " A=A.reshape(1,116,116)\n", 101 | " for i in range(1,n):\n", 102 | " A2=Z[i,:,:]\n", 103 | " Value,_=torch.topk(torch.abs(A2.view(-1)),int(116*116*0.2))\n", 104 | " A2=(torch.abs(A2)>=Value[-1])+torch.tensor(0,dtype=torch.float32)\n", 105 | " A2=A2.reshape(1,116,116)\n", 106 | " A=torch.cat((A,A2),0)\n", 107 | " return A\n", 108 | " \n", 109 | " def forward(self,X,Z):\n", 110 | " n=X.size()[0]\n", 111 | " XX=self.line_n[0](X.view(n,-1))\n", 112 | " ZZ=self.line_e[0](Z.view(n,-1))\n", 113 | " for i in range(self.layer):\n", 114 | " A=self.atten[i](Z,X)\n", 115 | " Z1=torch.matmul(A,Z)\n", 116 | " Z2=torch.matmul(Z1,self.edge_w[i])\n", 117 | " Z=self.relu(self.norm_e[i](Z2))+Z\n", 118 | " ZZ=torch.cat((ZZ,self.line_e[i+1](Z.view(n,-1))),dim=1)\n", 119 | " X1=torch.matmul(A,X)\n", 120 | " X1=torch.matmul(X1,self.node_w[i])\n", 121 | " X=self.relu(self.norm_n[i](X1))+X\n", 122 | " #X.register_hook(grad_X_hook)\n", 123 | " #feat_X_hook(X)\n", 124 | " XX=torch.cat((XX,self.line_n[i+1](X.view(n,-1))),dim=1)\n", 125 | " XZ=torch.cat((XX,ZZ),1)\n", 126 | " y=self.clase(XZ)\n", 127 | " #print(self.clase[0].weight)\n", 128 | " return y\n", 129 | "def grad_X_hook(grad):\n", 130 | " X_grad.append(grad)\n", 131 | "def feat_X_hook(X):\n", 132 | " X_feat.append(X.detach())\n", 133 | "X_grad=list()\n", 134 | "X_feat=list()" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 3, 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [ 143 | "class LabelSmoothLoss(nn.Module):\n", 144 | " \n", 145 | " def __init__(self, smoothing=0.0):\n", 146 | " super(LabelSmoothLoss, self).__init__()\n", 147 | " self.smoothing = smoothing\n", 148 | " \n", 149 | " def forward(self, input, target):\n", 150 | " log_prob = F.log_softmax(input, dim=-1)\n", 151 | " weight = input.new_ones(input.size()) * \\\n", 152 | " self.smoothing / (input.size(-1) - 1.)\n", 153 | " weight.scatter_(-1, target.unsqueeze(-1), (1. 
- self.smoothing))\n", 154 | "        loss = (-weight * log_prob).sum(dim=-1).mean()\n", 155 | "        return loss\n", 156 | "def data_split(full_list, ratio, shuffle=True):\n", 157 | "    \"\"\"\n", 158 | "    Dataset split: randomly partition full_list by ratio into two sublists, sublist_1 and sublist_2\n", 159 | "    :param full_list: list of samples to split\n", 160 | "    :param ratio: fraction of samples placed in sublist_1\n", 161 | "    :param shuffle: whether to shuffle full_list before splitting\n", 162 | "    :return: (sublist_1, sublist_2)\n", 163 | "    \"\"\"\n", 164 | "    n_total = len(full_list)\n", 165 | "    offset = int(n_total * ratio)\n", 166 | "    if n_total == 0 or offset < 1:\n", 167 | "        return [], full_list\n", 168 | "    if shuffle:\n", 169 | "        random.shuffle(full_list)\n", 170 | "    sublist_1 = full_list[:offset]\n", 171 | "    sublist_2 = full_list[offset:]\n", 172 | "    return sublist_1, sublist_2\n", 173 | "def data_2_k(full_list,k,shuffle=True):\n", 174 | "    n_total=len(full_list)\n", 175 | "    if shuffle:\n", 176 | "        random.shuffle(full_list)\n", 177 | "    data_list_list=[]\n", 178 | "    for i in range(k):\n", 179 | "        data_list_list.append(full_list[int(i*n_total/k):int((i+1)*n_total/k)])\n", 180 | "    return data_list_list\n", 181 | "def test(device,model,testloader):\n", 182 | "    model.eval()\n", 183 | "    TP_test,TN_test,FP_test,FN_test=0,0,0,0\n", 184 | "    with torch.no_grad():\n", 185 | "        for (X,Z,label,sub_id) in testloader:\n", 186 | "            TP,TN,FN,FP=0,0,0,0\n", 187 | "            n=X.size()[0]\n", 188 | "            X=X.to(device)\n", 189 | "            Z=Z.to(device)\n", 190 | "            label=label.to(device)\n", 191 | "            y=model(X,Z)\n", 192 | "            _,predict=torch.max(y,1)\n", 193 | "            TP+=((predict==1)&(label==1)).sum().item()\n", 194 | "            TN+=((predict==0)&(label==0)).sum().item()\n", 195 | "            FN+=((predict==0)&(label==1)).sum().item()\n", 196 | "            FP+=((predict==1)&(label==0)).sum().item()\n", 197 | "            TP_test+=TP\n", 198 | "            TN_test+=TN\n", 199 | "            FP_test+=FP\n", 200 | "            FN_test+=FN\n", 201 | "    acc,f1=cal_evaluate(TP_test,TN_test,FP_test,FN_test)\n", 202 | "    global max_acc\n", 203 | "    global modelname\n", 204 | "    global saveModel\n", 205 | "    if acc>=max_acc:\n", 206 | "        max_acc=acc\n", 207 | "        if saveModel:\n", 208 | "            torch.save(model.state_dict(),modelname)\n", 209 | "            ##read\n", 210 | "            #model=NERESGCN(layer)\n", 211 | "            #model.load_state_dict(torch.load(PATH))\n", 212 | "            #print('Saved the model')\n", 213 | "    #print('TEST: ACC:%.4f F1:%.4f [TP:%3d|TN:%3d|FP:%3d|FN:%3d]'%(acc,f1,TP_test,TN_test,FP_test,FN_test)) \n", 214 | "    return acc,f1,TP_test,TN_test,FP_test,FN_test\n", 215 | "# dict mapping between edge indices and node pairs (built by cal_dict below)\n", 216 | "def gradient(device,model,dataloader):\n", 217 | "    model.eval()\n", 218 | "    for (X,Z,A,label,sub_id) in dataloader:\n", 219 | "        X=torch.autograd.Variable(X,requires_grad=True)\n", 220 | "        x=X.to(device)\n", 221 | "        Z=torch.autograd.Variable(Z,requires_grad=True)\n", 222 | "        z=Z.to(device)\n", 223 | "        A=torch.autograd.Variable(A,requires_grad=True)\n", 224 | "        a=A.to(device)\n", 225 | "        y=model(x,z,a)\n", 226 | "        if (label==torch.FloatTensor([0])).item():\n", 227 | "            print('0')\n", 228 | "            #y.autograd.backward(torch.FloatTensor([[1.,0.]]).to(device))\n", 229 | "            torch.autograd.backward(y,torch.FloatTensor([[1.,0.]]).to(device))\n", 230 | "        else:\n", 231 | "            print('1')\n", 232 | "            torch.autograd.backward(y,torch.FloatTensor([[0.,1.]]).to(device))\n", 233 | "        grad_X=X.grad\n", 234 | "        grad_Z=Z.grad\n", 235 | "        #print(grad_X)\n", 236 | "        value_x,index_x=torch.topk(torch.abs(grad_X.view(-1)),10)\n", 237 | "        grad_X_topk=(torch.abs(grad_X)>=value_x[-1])\n", 238 | "        value_z,index_z=torch.topk(torch.abs(grad_Z.view(-1)),100)\n", 239 | "        grad_Z_topk=(torch.abs(grad_Z)>=value_z[-1])\n", 240 | "        global gradsave_dict\n", 
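"        # save each subject's input gradients into class-specific folders (used later to rank important nodes/edges)\n",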
241 | " if label==torch.FloatTensor([0]).item():\n", 242 | " np.save(gradsave_dict+'/TDC/Z/'+str(sub_id.item()),grad_Z.numpy())\n", 243 | " np.save(gradsave_dict+'/TDC/X/'+str(sub_id.item()),grad_X.numpy())\n", 244 | " else:\n", 245 | " np.save(gradsave_dict+'/ASD/Z/'+str(sub_id.item()),grad_Z.numpy())\n", 246 | " np.save(gradsave_dict+'/ASD/X/'+str(sub_id.item()),grad_X.numpy())\n", 247 | " \n", 248 | "def cal_dict():\n", 249 | " index=0\n", 250 | " A={}\n", 251 | " for i in range(116):\n", 252 | " for j in range(i+1,116):\n", 253 | " A[index]=(i,j)\n", 254 | " A[(i,j)]=index\n", 255 | " index+=1\n", 256 | " return A\n", 257 | "def cal_evaluate(TP,TN,FP,FN):\n", 258 | " if TP>0:\n", 259 | " p = TP / (TP + FP)\n", 260 | " r = TP / (TP + FN)\n", 261 | " F1 = 2 * r * p / (r + p)\n", 262 | " else:\n", 263 | " F1=0\n", 264 | " acc = (TP + TN) / (TP + TN + FP + FN)\n", 265 | " #print('ACC:%.4f F1:%.4f [TP:%d|TN:%d|FP:%d|FN:%d]'%(acc,F1,TP,TN,FP,FN))\n", 266 | " return acc,F1\n", 267 | "def data_arange(sites,fmri_root,smri_root,nan_subid):\n", 268 | " asd,tdc=[],[]\n", 269 | " for site in sites:\n", 270 | " mri_asd=os.listdir(smri_root+site+'/group1')\n", 271 | " mri_tdc=os.listdir(smri_root+site+'/group2')\n", 272 | " fmri_asd=os.listdir(fmri_root+site+'/group1_FC')\n", 273 | " fmri_tdc=os.listdir(fmri_root+site+'/group2_FC')\n", 274 | " site_asd=[i for i in mri_asd if i in fmri_asd ]\n", 275 | " site_tdc=[i for i in mri_tdc if i in fmri_tdc ]\n", 276 | " site_asd=[i for i in site_asd if int(i[:5]) not in nan_subid]\n", 277 | " site_tdc=[i for i in site_tdc if int(i[:5]) not in nan_subid]\n", 278 | " asd.append(site_asd)\n", 279 | " tdc.append(site_tdc)\n", 280 | " return asd,tdc\n", 281 | "class dataset(Dataset):\n", 282 | " def __init__(self,fmri_root,smri_root,site,ASD,TDC,edge_dict=cal_dict(),topk=True,rate=0.2):\n", 283 | " super(dataset,self).__init__()\n", 284 | " self.fmri=fmri_root\n", 285 | " self.smri=smri_root\n", 286 | " self.ASD=[j for i in ASD for j in i]\n", 287 | " self.TDC=[j for i in TDC for j in i]\n", 288 | " self.data=self.ASD+self.TDC\n", 289 | " random.shuffle(self.data)\n", 290 | " self.data_site={}\n", 291 | " for i in range(len(site)):\n", 292 | " data=ASD[i]+TDC[i]\n", 293 | " for j in data:\n", 294 | " if j not in self.data_site:\n", 295 | " self.data_site[j]=site[i] \n", 296 | " self.dict=edge_dict\n", 297 | " self.rate=rate\n", 298 | " self.topk=topk\n", 299 | " def normalize(self,A):\n", 300 | " d=A.sum(1)\n", 301 | " D=torch.diag(torch.pow(d,-1))\n", 302 | " return D.mm(A)\n", 303 | " def __getitem__(self,index):\n", 304 | " data=self.data[index]\n", 305 | " sub_id=int(data[0:5])\n", 306 | " if data in self.ASD:\n", 307 | " data_slow5 =np.load(self.fmri+self.data_site[data]+'/group1_slow5/'+data,allow_pickle=True)\n", 308 | " data_slow4 =np.load(self.fmri+self.data_site[data]+'/group1_slow4/'+data,allow_pickle=True)\n", 309 | " data_voxel =np.load(self.smri+self.data_site[data]+'/group1/'+data,allow_pickle=True)\n", 310 | " data_FCz =np.load(self.fmri+self.data_site[data]+'/group1_FC/'+data,allow_pickle=True)\n", 311 | " elif data in self.TDC:\n", 312 | " data_slow5 =np.load(self.fmri+self.data_site[data]+'/group2_slow5/'+data,allow_pickle=True)\n", 313 | " data_slow4 =np.load(self.fmri+self.data_site[data]+'/group2_slow4/'+data,allow_pickle=True)\n", 314 | " data_voxel =np.load(self.smri+self.data_site[data]+'/group2/'+data,allow_pickle=True)\n", 315 | " data_FCz =np.load(self.fmri+self.data_site[data]+'/group2_FC/'+data,allow_pickle=True)\n", 316 | " else:\n", 
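"            # reachable only if the file is in neither self.ASD nor self.TDC\n",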
317 | "            print('wrong input')\n", 318 | "        data_slow5=(data_slow5-np.min(data_slow5))/(np.max(data_slow5)-np.min(data_slow5))\n", 319 | "        data_slow4=(data_slow4-np.min(data_slow4))/(np.max(data_slow4)-np.min(data_slow4))\n", 320 | "        if np.any(np.isnan(data_slow5)) or np.any(np.isnan(data_slow4)) or np.any(np.isnan(data_FCz)):\n", 321 | "            print('data wrong')\n", 322 | "        #data_FCz=(data_FCz-np.min(data_FCz))/(np.max(data_FCz)-np.min(data_FCz))\n", 323 | "        if self.data[index] in self.ASD:\n", 324 | "            label=torch.tensor(1)\n", 325 | "        else:\n", 326 | "            label=torch.tensor(0)\n", 327 | "        X=np.zeros((116,3),dtype=np.float32)\n", 328 | "        X[:,0]=data_slow5\n", 329 | "        X[:,1]=data_slow4\n", 330 | "        X[:,2]=data_voxel\n", 331 | "        data_FCz=data_FCz.astype(np.float32)\n", 332 | "        Z=torch.from_numpy(data_FCz)\n", 333 | "        X=torch.from_numpy(X)\n", 334 | "        return X,Z,label,sub_id\n", 335 | "    def __len__(self):\n", 336 | "        return len(self.data)\n", 337 | "def get_acc(acc_list,toprate):\n", 338 | "    acc_list.sort()\n", 339 | "    return acc_list[-int(toprate*len(acc_list))]\n", 340 | "def plot_acc(acc_list):\n", 341 | "    import matplotlib.pyplot as plt; num_bins=50 # plt is not imported in this notebook's first cell\n", 342 | "    fig,ax=plt.subplots(2)\n", 343 | "    n,bins,patches=ax[0].hist(acc_list,num_bins,density=True)\n", 344 | "    ax[1].plot(acc_list)\n", 345 | "    ax[1].set_ylim(0.4,1)\n", 346 | "    plt.show()\n", 347 | "    print('Top:10%:',get_acc(acc_list,0.1))\n", 348 | "    print('Top:20%:',get_acc(acc_list,0.2))\n", 349 | "    print('Max:    ',max(acc_list))\n", 350 | "def train_fgsm(model,trainloader,testloader,epsilon):\n", 351 | "    result=pd.DataFrame(columns=('Loss','Acc','F1','TP','TN','FP','FN'))\n", 352 | "    criterian1=LabelSmoothLoss(0.1).to(device)\n", 353 | "    optimizer=torch.optim.Adam(model.parameters(),lr=learning_rate)\n", 354 | "    scheduler=torch.optim.lr_scheduler.ExponentialLR(optimizer,gamma=gmma)\n", 355 | "    acc=0.5000\n", 356 | "    loss_sum=0\n", 357 | "    for j in range(epoch):\n", 358 | "        print('\\r[%3d/%3d] Loss: %.2f Acc:%.4f' %(j+1,epoch,loss_sum,acc),end='')\n", 359 | "        loss_sum=0\n", 360 | "        TP,TN,FP,FN=0,0,0,0\n", 361 | "        for (X,Z,label,sub_id) in trainloader: # the dataset returns 4 items (no adjacency A)\n", 362 | "            model.train()\n", 363 | "            x=X.to(device)\n", 364 | "            z=Z.to(device)\n", 365 | "            x.requires_grad=True\n", 366 | "            z.requires_grad=True\n", 367 | "            label=label.to(device)\n", 368 | "            y=model(x,z)\n", 369 | "            loss=criterian1(y,label)\n", 370 | "            model.zero_grad()\n", 371 | "            loss.backward(retain_graph=True)\n", 372 | "            sign_grad_x=torch.sign(x.grad.data)\n", 373 | "            sign_grad_z=torch.sign(z.grad.data)\n", 374 | "            perturbed_x=x+epsilon*sign_grad_x \n", 375 | "            perturbed_z=z+epsilon*sign_grad_z \n", 376 | "            perturbed_x=torch.clamp(perturbed_x,0,1)\n", 377 | "            perturbed_z=torch.clamp(perturbed_z,-1,1)\n", 378 | "            y=model(perturbed_x,perturbed_z)\n", 379 | "            L2=torch.tensor(0,dtype=torch.float32).to(device)\n", 380 | "            if L2_lamda>0:\n", 381 | "                for name,parameters in model.named_parameters():\n", 382 | "                    if name[0:5]=='clase' and name[-8:]=='0.weight':\n", 383 | "                        L2+=L2_lamda*torch.norm(parameters,2)\n", 384 | "            loss=0.5*(criterian1(y,label)+loss)+L2\n", 385 | "            loss_sum+=loss.item()\n", 386 | "            optimizer.zero_grad()\n", 387 | "            loss.backward()\n", 388 | "            optimizer.step()\n", 389 | "        acc,f1,TP_test,TN_test,FP_test,FN_test=test(device,model,testloader)\n", 390 | "        result.loc[j]={'Loss':loss_sum,'Acc':acc,'F1':f1,'TP':TP_test,'TN':TN_test,'FP':FP_test,'FN':FN_test}\n", 391 | "    result.sort_values('Acc',inplace=True,ascending=False)\n", 392 | "    print('\\n')\n", 393 | "    return result.iloc[9] # report the 10th-highest-accuracy epoch\n", 394 | "def 
train(model,trainloader,testloader):\n", 395 | "    result=pd.DataFrame(columns=('Loss','Acc','F1','TP','TN','FP','FN'))\n", 396 | "    criterian1=LabelSmoothLoss(0.1).to(device)\n", 397 | "    optimizer=torch.optim.Adam(model.parameters(),lr=learning_rate)\n", 398 | "    #optimizer = AdaBelief(model.parameters(), lr=1e-4, eps=1e-8, betas=(0.9,0.999), weight_decay=L2_lamda,weight_decouple = True, rectify = False)\n", 399 | "    scheduler=torch.optim.lr_scheduler.ExponentialLR(optimizer,gamma=gmma)\n", 400 | "    acc=0.5000\n", 401 | "    loss_sum=0\n", 402 | "    for j in range(epoch): # j is used below in result.loc[j]; 'for _' here raised a NameError\n", 403 | "        print('\\rLoss: {:.2f} Acc:{:.4f}'.format(loss_sum,acc),end='')\n", 404 | "        loss_sum=0\n", 405 | "        TP,TN,FP,FN=0,0,0,0\n", 406 | "        time_start=time.time()\n", 407 | "        for (X,Z,label,sub_id) in trainloader: # the dataset returns 4 items (no adjacency A)\n", 408 | "            #print(A)\n", 409 | "            #print(Z)\n", 410 | "            #print(X.shape,torch.mean(X),torch.std(X))\n", 411 | "            #X=X+torch.randn(X.shape)*X.std(0)\n", 412 | "            #Z=Z+torch.randn(Z.shape)*Z.std(0)\n", 413 | "            model.train()\n", 414 | "            X=X.to(device)\n", 415 | "            Z=Z.to(device)\n", 416 | "            label=label.to(device)\n", 417 | "            y=model(X,Z)\n", 418 | "            #print(y)\n", 419 | "            _,predict=torch.max(y,1)\n", 420 | "            TP+=((predict==1)&(label==1)).sum().item()\n", 421 | "            TN+=((predict==0)&(label==0)).sum().item()\n", 422 | "            FN+=((predict==0)&(label==1)).sum().item()\n", 423 | "            FP+=((predict==1)&(label==0)).sum().item()\n", 424 | "            loss=criterian1(y,label)\n", 425 | "            L1=torch.tensor(0,dtype=torch.float32).to(device)\n", 426 | "            L2=torch.tensor(0,dtype=torch.float32).to(device)\n", 427 | "            #print(model.parameters.weit_n)\n", 428 | "            if L1_lamda>0 or L2_lamda>0:\n", 429 | "                for name,parameters in model.named_parameters():\n", 430 | "                    if name[0:5]=='clase' and name[-8:]=='0.weight':\n", 431 | "                        L1+=L1_lamda*torch.norm(parameters,1)\n", 432 | "                        L2+=L2_lamda*torch.norm(parameters,2)\n", 433 | "            loss+=(L1+L2)\n", 434 | "            loss_sum+=loss.item()\n", 435 | "            optimizer.zero_grad()\n", 436 | "            loss.backward()\n", 437 | "            optimizer.step()\n", 438 | "        scheduler.step()\n", 439 | "        time_end=time.time()\n", 440 | "        time_cost=(time_end-time_start)/60.0\n", 441 | "        acc,f1,TP_test,TN_test,FP_test,FN_test=test(device,model,testloader)\n", 442 | "        result.loc[j]={'Loss':loss_sum,'Acc':acc,'F1':f1,'TP':TP_test,'TN':TN_test,'FP':FP_test,'FN':FN_test}\n", 443 | "        #acc,f1=cal_evaluate(TP,TN,FP,FN)\n", 444 | "        #print(\"[%2d/%d] ACC:%.2f F1:%.2f Loss: %.4f [TP:%3d|TN:%3d|FP:%3d|FN:%3d] CostTime:%4.1f min | RestTime:%.2f h\" %(\n", 445 | "        #    j+1,epoch,acc,f1,loss_sum,TP,TN,FP,FN,time_cost,time_cost/60*(epoch-1-j)))\n", 446 | "        #losses.append(loss_sum)\n", 447 | "        #print(model.parameters())\n", 448 | "        #acc,f1=test(device,model,testloader)\n", 449 | "        #print(acc)\n", 450 | "        #plot_acc(result['Acc'],result['Loss'])\n", 451 | "    result.sort_values('Acc',inplace=True,ascending=False)\n", 452 | "    print('\\n')\n", 453 | "    return result.iloc[9] # report the 10th-highest-accuracy epoch\n", 454 | "def train_pgd(model,trainloader,testloader,eps=0.05,iters=10,alpha=2/255):\n", 455 | "    result=pd.DataFrame(columns=('Loss','Acc','F1','TP','TN','FP','FN'))\n", 456 | "    criterian1=LabelSmoothLoss(0.1).to(device)\n", 457 | "    optimizer=torch.optim.Adam(model.parameters(),lr=learning_rate)\n", 458 | "    for j in range(epoch):\n", 459 | "        loss_sum=0.\n", 460 | "        TP,TN,FP,FN=0,0,0,0\n", 461 | "        model.train()\n", 462 | "        for (X,Z,label,sub_id) in trainloader:\n", 463 | "            model.train()\n", 464 | "            x=X.to(device)\n", 465 | "            z=Z.to(device)\n", 466 | "            label=label.to(device)\n", 467 | "            pretu_x,pretu_z=x,z\n", 468 | "            
ori_x,ori_z=x.data,z.data\n", 469 | " for i in range(iters):\n", 470 | " pretu_x.requires_grad=True\n", 471 | " pretu_z.requires_grad=True\n", 472 | " y=model(pretu_x,pretu_z)\n", 473 | " loss=criterian1(y,label)\n", 474 | " model.zero_grad()\n", 475 | " loss.backward()\n", 476 | " adv_x=pretu_x+alpha*torch.sign(pretu_x.grad.data)\n", 477 | " adv_z=pretu_z+alpha*torch.sign(pretu_z.grad.data)\n", 478 | " eta_x=torch.clamp(adv_x-ori_x,min=-eps,max=eps)\n", 479 | " eta_z=torch.clamp(adv_z-ori_z,min=-eps,max=eps)\n", 480 | " pretu_x=torch.clamp(ori_x+eta_x,min=0,max=1).detach_()\n", 481 | " pretu_z=torch.clamp(ori_z+eta_z,min=-1,max=1).detach_()\n", 482 | " y=model(x,z)\n", 483 | " yy=model(pretu_x,pretu_z)\n", 484 | " L2=torch.tensor(0,dtype=torch.float32).to(device)\n", 485 | " if L2_lamda>0:\n", 486 | " for name,parameters in model.named_parameters():\n", 487 | " if name[0:5]=='clase' and name[-8:]=='0.weight':\n", 488 | " L2+=L2_lamda*torch.norm(parameters,2)\n", 489 | " loss=0.5*(criterian1(yy,label)+criterian1(y,label))+L2\n", 490 | " loss_sum+=loss.item()\n", 491 | " optimizer.zero_grad()\n", 492 | " loss.backward()\n", 493 | " optimizer.step()\n", 494 | " if (j+1)%10==0:\n", 495 | " acc,f1,TP_test,TN_test,FP_test,FN_test=test(device,model,testloader)\n", 496 | " result.loc[(j+1)//10]=[loss_sum,acc,f1,TP_test,TN_test,FP_test,FN_test]\n", 497 | " result.sort_values('Acc',inplace=True,ascending=False)\n", 498 | " print(' FinalAcc: {:.4f}'.format(result.iloc[0]['Acc']))\n", 499 | " return result.iloc[0]['Acc']" 500 | ] 501 | }, 502 | { 503 | "cell_type": "code", 504 | "execution_count": 4, 505 | "metadata": { 506 | "tags": [] 507 | }, 508 | "outputs": [ 509 | { 510 | "name": "stdout", 511 | "output_type": "stream", 512 | "text": [ 513 | " FinalAcc: 0.6875\n", 514 | " FinalAcc: 0.6774\n", 515 | " FinalAcc: 0.7017\n", 516 | " FinalAcc: 0.7778\n", 517 | " FinalAcc: 0.8462\n", 518 | " FinalAcc: 0.7900\n", 519 | " FinalAcc: 0.6667\n", 520 | " FinalAcc: 0.6182\n", 521 | " FinalAcc: 0.6078\n", 522 | " FinalAcc: 0.6389\n", 523 | " FinalAcc: 0.7708\n", 524 | " FinalAcc: 0.7000\n", 525 | " FinalAcc: 0.6909\n", 526 | " FinalAcc: 0.7500\n", 527 | " FinalAcc: 0.7172\n", 528 | " FinalAcc: 0.7895\n", 529 | " FinalAcc: 0.6000\n", 530 | "dropout_rate:0.0010 Acc:0.7001\n", 531 | " FinalAcc: 0.6806\n", 532 | " FinalAcc: 0.6452\n", 533 | " FinalAcc: 0.7127\n", 534 | " FinalAcc: 0.8889\n", 535 | " FinalAcc: 0.8462\n", 536 | " FinalAcc: 0.7700\n", 537 | " FinalAcc: 0.6296\n", 538 | " FinalAcc: 0.6727\n", 539 | " FinalAcc: 0.6667\n", 540 | " FinalAcc: 0.6944\n", 541 | " FinalAcc: 0.6667\n", 542 | " FinalAcc: 0.7500\n", 543 | " FinalAcc: 0.7455\n", 544 | " FinalAcc: 0.6786\n", 545 | " FinalAcc: 0.7172\n", 546 | " FinalAcc: 0.8421\n", 547 | " FinalAcc: 0.6250\n", 548 | "dropout_rate:0.0050 Acc:0.7051\n", 549 | " FinalAcc: 0.7292\n", 550 | " FinalAcc: 0.6774\n", 551 | " FinalAcc: 0.7017\n", 552 | " FinalAcc: 0.7778\n", 553 | " FinalAcc: 0.8462\n", 554 | " FinalAcc: 0.8000\n", 555 | " FinalAcc: 0.8148\n", 556 | " FinalAcc: 0.6364\n", 557 | " FinalAcc: 0.7059\n", 558 | " FinalAcc: 0.7222\n", 559 | " FinalAcc: 0.7292\n", 560 | " FinalAcc: 0.8000\n", 561 | " FinalAcc: 0.7818\n", 562 | " FinalAcc: 0.6786\n", 563 | " FinalAcc: 0.7374\n", 564 | " FinalAcc: 0.8421\n", 565 | " FinalAcc: 0.6000\n", 566 | "dropout_rate:0.0100 Acc:0.7279\n", 567 | " FinalAcc: 0.7014\n", 568 | " FinalAcc: 0.6613\n", 569 | " FinalAcc: 0.7459\n", 570 | " FinalAcc: 0.6667\n", 571 | " FinalAcc: 0.7692\n", 572 | " FinalAcc: 0.7800\n", 573 | " 
FinalAcc: 0.7407\n", 574 | " FinalAcc: 0.6727\n", 575 | " FinalAcc: 0.5882\n", 576 | " FinalAcc: 0.7222\n", 577 | " FinalAcc: 0.6667\n", 578 | " FinalAcc: 0.7500\n", 579 | " FinalAcc: 0.7455\n", 580 | " FinalAcc: 0.6071\n", 581 | " FinalAcc: 0.7071\n", 582 | " FinalAcc: 0.8421\n", 583 | " FinalAcc: 0.6500\n", 584 | "dropout_rate:0.0200 Acc:0.7110\n" 585 | ] 586 | } 587 | ], 588 | "source": [ 589 | "# leave one site out\n", 590 | "setup_seed(random_seed)\n", 591 | "L1_lamda=0.0\n", 592 | "L2_lamda=0.0001\n", 593 | "learning_rate=0.0001\n", 594 | "epoch =100\n", 595 | "batch_size=64\n", 596 | "gmma =0.95\n", 597 | "layer =5\n", 598 | "all_site=os.listdir(cpac_root)\n", 599 | "site_count={'CALTECH':19,'CMU':13,'KKI':51,'LEUVEN':62,'MAXMUN':55,'NYU':181,\n", 600 | "            'OHSU':28,'OLIN':36,'PITT':40,'SBL':9,'SDSU':27,'STANFORD':40,\n", 601 | "            'TRINITY':48,'UCLA':99,'UM':144,'USM':100,'YALE':55}\n", 602 | "for epsilon in [0.001,0.005,0.01,0.02]:\n", 603 | "    sum_acc=0\n", 604 | "    for site in all_site:\n", 605 | "        test_site =[i for i in all_site if i==site]\n", 606 | "        train_site=[i for i in all_site if i!=site]\n", 607 | "        global max_acc\n", 608 | "        max_acc=0.5\n", 609 | "        global modelname\n", 610 | "        modelname='/media/dm/0001A094000BF891/Yazid/SAVEDModels/adversialtrained/models_{}'.format(test_site[0])\n", 611 | "        global saveModel\n", 612 | "        saveModel=False\n", 613 | "        train_asd,train_tdc=data_arange(train_site,fmri_root=cpac_root,smri_root=smri_root,nan_subid=nan_subid)\n", 614 | "        test_asd,test_tdc =data_arange(test_site, fmri_root=cpac_root,smri_root=smri_root,nan_subid=nan_subid)\n", 615 | "        trainset=dataset(site=train_site,fmri_root=cpac_root,smri_root=smri_root,ASD=train_asd,TDC=train_tdc,edge_dict=cal_dict())\n", 616 | "        trainloader=DataLoader(trainset,batch_size=batch_size,shuffle=True,drop_last=True)\n", 617 | "        testset=dataset(site=test_site,fmri_root=cpac_root,smri_root=smri_root,ASD=test_asd,TDC=test_tdc,edge_dict=cal_dict())\n", 618 | "        testloader=DataLoader(testset,batch_size=1)\n", 619 | "        model=NEGAN(layer,dropout_rate=0.2).to(device) # NEGAN is the class defined above; the original NEResGCN is not defined in this notebook\n", 620 | "        acc=train_pgd(model,trainloader,testloader,eps=epsilon,iters=10,alpha=epsilon/5)\n", 621 | "        sum_acc+=(site_count[test_site[0]]*acc) # weight each site's accuracy by its subject count\n", 622 | "    print('dropout_rate:%.4f Acc:%.4f'%(epsilon,sum_acc/1007)) # NOTE: despite the label, the printed value is the PGD epsilon, not a dropout rate" 623 | ] 624 | }, 625 | { 626 | "cell_type": "raw", 627 | "metadata": {}, 628 | "source": [ 629 | "''' \n", 630 | "to find nan_subject_id \n", 631 | "'''\n", 632 | "nan_subid=[]\n", 633 | "for index in range(10):\n", 634 | "    test_asd =test_asd_dict[index]\n", 635 | "    test_tdc =test_tdc_dict[index]\n", 636 | "    testset=dataset(site=test_site,fmri_root=cpac_root,smri_root=smri_root,ASD=test_asd,TDC=test_tdc,edge_dict=cal_dict())\n", 637 | "    testloader=DataLoader(testset,batch_size=1)\n", 638 | "    model=NEResGCN(5)\n", 639 | "    model.eval()\n", 640 | "    for (X,Z,A,label,sub_id) in testloader:\n", 641 | "        sub_id=sub_id.item()\n", 642 | "        y=model(X,Z,A)\n", 643 | "        if torch.any(torch.isnan(y)):\n", 644 | "            print(sub_id)\n", 645 | "            nan_subid.append(sub_id)" 646 | ] 647 | }, 648 | { 649 | "cell_type": "markdown", 650 | "metadata": {}, 651 | "source": [ 652 | "**Epsilon**:***0.0010*** **Acc**: ***0.703***<br>
\n", 653 | "**Epsilon**:***0.0050*** **Acc**: ***715***
\n", 654 | "**Epsilon**:***0.0100*** **Acc**: ***715***
\n", 655 | "**Epsilon**:***0.0200*** **Acc**: ***727***
\n", 656 | "**Epsilon**:***0.0500*** **Acc**: ***717***
\n", 657 | "**Epsilon**:***0.1000*** **Acc**: ***675***
\n", 658 | "**Epsilon**:***0.2000*** **Acc**: ***627***
" 659 | ] 660 | } 661 | ], 662 | "metadata": { 663 | "kernelspec": { 664 | "display_name": "Pytorch", 665 | "language": "python", 666 | "name": "pytorch" 667 | }, 668 | "language_info": { 669 | "codemirror_mode": { 670 | "name": "ipython", 671 | "version": 3 672 | }, 673 | "file_extension": ".py", 674 | "mimetype": "text/x-python", 675 | "name": "python", 676 | "nbconvert_exporter": "python", 677 | "pygments_lexer": "ipython3", 678 | "version": "3.6.13" 679 | }, 680 | "toc-autonumbering": false, 681 | "toc-showcode": true, 682 | "toc-showmarkdowntxt": false, 683 | "toc-showtags": false 684 | }, 685 | "nbformat": 4, 686 | "nbformat_minor": 4 687 | } 688 | --------------------------------------------------------------------------------