├── CNN-LSTM_iftarg_AllBrain.py ├── CNN_FNC_iftarg.py ├── MSCT.py ├── MSCT_FNC.py ├── MsCNN_ICA.py ├── README.md ├── area ├── AllBrain.nii ├── Angular_L.nii ├── Angular_R.nii ├── Cingulum_Ant_L.nii ├── Cingulum_Ant_R.nii ├── Cingulum_Mid_L.nii ├── Cingulum_Mid_R.nii ├── Cingulum_Post_L.nii ├── Cingulum_Post_R.nii ├── Frontal_Inf_Oper_L.nii ├── Frontal_Inf_Oper_R.nii ├── Frontal_Inf_Orb_L.nii ├── Frontal_Inf_Orb_R.nii ├── Frontal_Inf_Tri_L.nii ├── Frontal_Inf_Tri_R.nii ├── Fusiform_L.nii ├── Fusiform_R.nii ├── Hippocampus_L.nii ├── Hippocampus_R.nii ├── Parahippo_L.nii ├── Parahippo_R.nii ├── Parietal_Inf_L.nii ├── Parietal_Inf_R.nii ├── Parietal_Sup_L.nii ├── Parietal_Sup_R.nii ├── Precuneus_L.nii ├── Precuneus_R.nii ├── SupraMarginal_L.nii └── SupraMarginal_R.nii ├── multisc_CNN-LSTM_iftarg_AllBrain.py ├── multisc_CNN_FNC_iftarg.py ├── niicat.py ├── original_data.py └── sequence.py /CNN-LSTM_iftarg_AllBrain.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import csv 5 | import torch 6 | import numpy as np 7 | import random 8 | import torch.nn.functional as F 9 | import time 10 | import nibabel as nib 11 | from torch import nn 12 | from torch import optim 13 | from torch.autograd import Variable 14 | from imblearn.over_sampling import RandomOverSampler 15 | 16 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 17 | print(device) 18 | 19 | class CnnLSTM(nn.Module): 20 | def __init__(self): 21 | super(CnnLSTM, self).__init__() 22 | self.conv = torch.nn.Sequential( 23 | torch.nn.Conv3d( 24 | in_channels=1, 25 | out_channels=4, 26 | kernel_size=3, 27 | padding=0), 28 | torch.nn.ReLU(), 29 | torch.nn.MaxPool3d(3), 30 | torch.nn.Conv3d(4, 8, kernel_size=7, padding=0), 31 | torch.nn.ReLU() 32 | ) 33 | 34 | self.lstm = nn.LSTM( 35 | input_size=22984, 36 | hidden_size=1024, 37 | num_layers=2, 38 | batch_first=True 39 | ) 40 | 41 | self.fc = nn.Sequential( 42 | nn.Linear(1024, 256), 43 | nn.ReLU(), 44 | nn.Dropout(p=0.5), 45 | nn.Linear(256, 32), 46 | nn.ReLU(), 47 | nn.Dropout(p=0.5), 48 | nn.Linear(32, 2) 49 | ) 50 | 51 | def forward(self, x): 52 | x = self.conv(x) 53 | x = torch.flatten(x,start_dim=1) 54 | x = x.reshape(16, 22984, 1) 55 | x = torch.transpose(x, 0, 2) 56 | x = torch.transpose(x, 1, 2) 57 | out, (h_n,h_c) = self.lstm(x, None) 58 | out = F.relu(out) 59 | out = out[:, -1, :] 60 | out = self.fc(out) 61 | return out 62 | 63 | def dataReader(sub_list, task_list): 64 | Input = [] 65 | label = [] 66 | for task in task_list: 67 | for sub in sub_list: 68 | path_label = '/home/zmx/ds002311/event/' 69 | path_input = '/home/zmx/ds002311/preprocessed_4D/' + task + '/' 70 | if sub < 10: 71 | num = 'sub-0' + str(sub) 72 | else: 73 | num = 'sub-' + str(sub) 74 | 75 | input_name = num + '_' + task 76 | if task == 'mot_1': 77 | label_name = num + '_func_' + num + '_task-mot_run-01_events' 78 | elif task == 'mot_2': 79 | label_name = num + '_func_' + num + '_task-mot_run-02_events' 80 | elif task == 'mot_3': 81 | label_name = num + '_func_' + num + '_task-mot_run-03_events' 82 | 83 | img = nib.load(path_input + input_name + '.nii') 84 | img = np.array(img.get_fdata()) 85 | template = nib.load('/home/zmx/fMRI/Template/area/AllBrain.nii') 86 | template = np.array(template.get_fdata()) 87 | for i in range(61): 88 | for j in range(73): 89 | for k in range(61): 90 | if template[i][j][k]==0: 91 | img[i][j][k] = np.zeros(405) 92 | 93 | with open(path_label + label_name + '.tsv','rt') as 
csvfile: 94 | reader = csv.DictReader(csvfile, delimiter='\t') 95 | cond = [row['cond'] for row in reader] 96 | 97 | for i in range(len(cond)): 98 | if cond[i] == 'targ_easy': 99 | cond[i] = 0 100 | elif cond[i] == 'targ_hard': 101 | cond[i] = 0 102 | elif cond[i] == 'lure_hard': 103 | cond[i] = 1 104 | del cond[24] # 最后一段时间不全 105 | label.extend(cond) 106 | label = list(map(int,label)) 107 | 108 | data = img[:,:,:,12:] # 从第13个时间点开始,删除前12个时间点 109 | 110 | for i in range(24): # 最后一段时间不全 111 | Input.append(data[:,:,:,16*i:16*i+16]) 112 | 113 | Input = np.array(Input) 114 | label = np.array(label) 115 | 116 | max_value = np.max(Input) # 获得最大值 117 | min_value = np.min(Input) # 获得最小值 118 | scalar = max_value - min_value # 获得间隔数量 119 | Input = list(map(lambda x: x / scalar, Input)) # 归一化 120 | 121 | Input = np.array(Input) 122 | return Input, label 123 | 124 | 125 | LSTM = CnnLSTM() 126 | LSTM.to(device) 127 | epochs = 50 128 | # 定义loss和optimizer 129 | optimizer = optim.Adam(LSTM.parameters(), lr=0.0001) 130 | criterion = nn.CrossEntropyLoss() 131 | 132 | # 训练 133 | correct = 0 134 | total = 0 135 | for epoch in range(epochs): 136 | 137 | if epoch % 2 == 0: # 衰减的学习率 138 | for p in optimizer.param_groups: 139 | p['lr'] *= 0.9 140 | #sub = [1,3,5,6,7,8,9,10,13,14,15,18,21,22,23] 141 | #task = ['mot_1','mot_2','mot_3'] 142 | for task in [['mot_1'],['mot_2']]: 143 | for sub in [[1,3],[5,6]]: 144 | train_x, train_y = dataReader(sub,task) 145 | 146 | smo = RandomOverSampler(random_state=42) # 处理样本数量不对称 147 | nsamples, nx, ny, nz, nt = train_x.shape 148 | d2_train_dataset = train_x.reshape((nsamples,nx*ny*nz*nt)) 149 | train_x, train_y = smo.fit_sample(d2_train_dataset, train_y) 150 | train_x = train_x.reshape(len(train_x), nx, ny, nz, nt) 151 | 152 | state = np.random.get_state() # 打乱顺序 153 | np.random.shuffle(train_x) 154 | np.random.set_state(state) 155 | np.random.shuffle(train_y) 156 | 157 | train_x = torch.from_numpy(train_x) 158 | train_x = torch.tensor(train_x, dtype=torch.float32) 159 | train_y = torch.from_numpy(train_y) 160 | 161 | for b_x, b_y in zip(train_x,train_y): 162 | b_x = b_x.reshape(-1, 61, 73, 61, 16) 163 | b_y = b_y.reshape(-1) 164 | b_x = torch.transpose(b_x, 3, 4) 165 | b_x = torch.transpose(b_x, 2, 3) 166 | b_x = torch.transpose(b_x, 1, 2) 167 | b_x = torch.transpose(b_x, 0, 1) 168 | b_x, b_y = b_x.to(device), b_y.to(device) 169 | output = LSTM(b_x) 170 | loss = criterion(output,b_y) 171 | 172 | optimizer.zero_grad() 173 | loss.backward() 174 | optimizer.step() 175 | 176 | _, predicted = torch.max(output.data, 1) 177 | total += b_y.size(0) 178 | correct += (predicted == b_y).sum().item() 179 | 180 | TimeStr = time.asctime(time.localtime(time.time())) 181 | print('Epoch: {} --- {}'.format(epoch, TimeStr)) 182 | print('Train Accuracy of the model: {} %'.format(100 * correct / total)) 183 | print('Target: {}'.format(b_y)) 184 | print('Output: {}'.format(torch.max(output, 1)[1])) 185 | print('Train Loss of the model: {}'.format(loss)) 186 | 187 | 188 | 189 | # # 测试 190 | with torch.no_grad(): 191 | print('--------test--------') 192 | print('--------test--------') 193 | print('--------test--------') 194 | correct = 0 195 | total = 0 196 | for task in [['mot_1']]: 197 | for sub in [10]: 198 | test_x, test_y = dataReader(sub,task) 199 | test_x = torch.from_numpy(test_x) 200 | test_x = torch.tensor(test_x, dtype=torch.float32) 201 | test_y = torch.from_numpy(test_y) 202 | for t_x, t_y in zip(test_x, test_y): 203 | t_x = t_x.reshape(-1, 61, 73, 61, 16) 204 | t_y = t_y.reshape(-1) 205 | 
t_x = torch.transpose(t_x, 3, 4) 206 | t_x = torch.transpose(t_x, 2, 3) 207 | t_x = torch.transpose(t_x, 1, 2) 208 | t_x = torch.transpose(t_x, 0, 1) 209 | t_x, t_y = t_x.to(device), t_y.to(device) 210 | output = LSTM(t_x) 211 | _, predicted = torch.max(output.data, 1) 212 | total += t_y.size(0) 213 | correct += (predicted == t_y).sum().item() 214 | print('Test Accuracy of the model: {} %'.format(100 * correct / total)) 215 | print('Output: {}'.format(torch.max(output, 1)[1])) 216 | -------------------------------------------------------------------------------- /CNN_FNC_iftarg.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import csv 5 | import torch 6 | import numpy as np 7 | import random 8 | import torch.nn.functional as F 9 | import time 10 | import nibabel as nib 11 | import matplotlib.pyplot as plt 12 | from math import * 13 | from torch import nn 14 | from torch import optim 15 | from torch.autograd import Variable 16 | from imblearn.over_sampling import SMOTE,ADASYN,RandomOverSampler 17 | 18 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 19 | print(device) 20 | 21 | class Cnn(nn.Module): 22 | def __init__(self): 23 | super(Cnn, self).__init__() 24 | self.conv = torch.nn.Sequential( 25 | torch.nn.Conv2d( 26 | in_channels=1, 27 | out_channels=8, 28 | kernel_size=3, 29 | padding=0), 30 | nn.BatchNorm2d(8), 31 | torch.nn.ReLU(), 32 | torch.nn.MaxPool2d(2), 33 | torch.nn.Conv2d(8, 16, kernel_size=3, padding=0), 34 | nn.BatchNorm2d(16), 35 | torch.nn.ReLU() 36 | ) 37 | 38 | self.fc = nn.Sequential( 39 | nn.Linear(28224, 1024), 40 | nn.ReLU(), 41 | nn.Dropout(p=0.5), 42 | nn.Linear(1024, 128), 43 | nn.ReLU(), 44 | nn.Dropout(p=0.5), 45 | nn.Linear(128, 2) 46 | ) 47 | 48 | def forward(self, x): 49 | x = self.conv(x) 50 | x = torch.flatten(x,start_dim=1) 51 | x = self.fc(x) 52 | return x 53 | 54 | def dataReader(sub_list, task_list): 55 | Input = [] 56 | label = [] 57 | for task in task_list: 58 | for sub in sub_list: 59 | path_label = '/home/zmx/ds002311/event/' 60 | if sub < 10: 61 | num = 'sub-0' + str(sub) 62 | else: 63 | num = 'sub-' + str(sub) 64 | 65 | path_input = '/home/zmx/ds002311/FNC/' + num + '/' + task + '/' 66 | 67 | input_name = num + '_' + task 68 | 69 | if task == 'mot_1': 70 | label_name = num + '_func_' + num + '_task-mot_run-01_events' 71 | elif task == 'mot_2': 72 | label_name = num + '_func_' + num + '_task-mot_run-02_events' 73 | elif task == 'mot_3': 74 | label_name = num + '_func_' + num + '_task-mot_run-03_events' 75 | 76 | for i in range(24): 77 | img = np.loadtxt(open(path_input + input_name + '_' + str(i+1) + ".csv","rb"),delimiter=",",skiprows=0) 78 | Input.append(img) 79 | 80 | 81 | with open(path_label + label_name + '.tsv','rt') as csvfile: 82 | reader = csv.DictReader(csvfile, delimiter='\t') 83 | cond = [row['cond'] for row in reader] 84 | 85 | for i in range(len(cond)): 86 | if cond[i] == 'targ_easy': 87 | cond[i] = 0 88 | elif cond[i] == 'targ_hard': 89 | cond[i] = 0 90 | elif cond[i] == 'lure_hard': 91 | cond[i] = 1 92 | del cond[24] # 最后一段时间不全 93 | label.extend(cond) 94 | label = list(map(int,label)) 95 | 96 | Input = np.array(Input) 97 | label = np.array(label) 98 | 99 | return Input, label 100 | 101 | 102 | train_task = ['mot_1','mot_2'] 103 | train_sub = [1,3,5,6,7,8,9,10,13,14,15,18,21,22,23] 104 | 105 | train_x, train_y = dataReader(train_sub,train_task) 106 | 107 | plt.ion() #开启interactive mode 成功的关键函数 108 | 
plt.figure(1) 109 | t_now = 0 110 | t = [] 111 | loss_list = [] 112 | acc_list = [] 113 | 114 | CNN = Cnn() 115 | CNN.to(device) 116 | epochs = 10 117 | batch_size = 16 118 | # 定义loss和optimizer 119 | optimizer = optim.Adam(CNN.parameters(), lr=0.001, weight_decay=0.01) 120 | criterion = nn.CrossEntropyLoss().to(device) 121 | 122 | # 训练 123 | correct = 0 124 | total = 0 125 | CNN.train(mode=True) 126 | for epoch in range(epochs): 127 | if epoch % 5 == 0: # 衰减的学习率 128 | for p in optimizer.param_groups: 129 | p['lr'] *= 0.9 130 | 131 | # smo = RandomOverSampler(random_state=42) # 处理样本数量不对称 132 | # smo = ADASYN(random_state=42) 133 | smo = SMOTE(random_state=42) 134 | nsamples, nx, ny = train_x.shape 135 | d2_train_dataset = train_x.reshape((nsamples,nx*ny)) 136 | train_x_smo, train_y_smo = smo.fit_sample(d2_train_dataset, train_y) 137 | train_x_smo = train_x_smo.reshape(len(train_x_smo), nx, ny) 138 | 139 | state = np.random.get_state() # 打乱顺序 140 | np.random.shuffle(train_x_smo) 141 | np.random.set_state(state) 142 | np.random.shuffle(train_y_smo) 143 | 144 | train_x_smo = torch.from_numpy(train_x_smo) 145 | train_x_smo = train_x_smo.type(torch.FloatTensor) 146 | train_y_smo = torch.from_numpy(train_y_smo) 147 | 148 | for i in range(0, len(train_x_smo) - batch_size, batch_size): 149 | loss_batch = 0 150 | 151 | for b_x, b_y in zip(train_x_smo[i:i+batch_size],train_y_smo[i:i+batch_size]): 152 | b_x = b_x.reshape(-1, 90, 90) 153 | b_x = b_x.reshape(-1, 1, 90, 90) 154 | b_y = b_y.reshape(-1) 155 | b_x, b_y = b_x.to(device), b_y.to(device) 156 | output = CNN(b_x) 157 | loss = criterion(output,b_y) 158 | loss_batch += loss 159 | 160 | _, predicted = torch.max(output.data, 1) 161 | total += b_y.size(0) 162 | correct += (predicted == b_y).sum().item() 163 | # print('Target: {}'.format(b_y)) 164 | # print('Output: {}'.format(torch.max(output, 1)[1])) 165 | 166 | loss_batch = loss_batch / batch_size 167 | optimizer.zero_grad() 168 | loss_batch.backward() 169 | optimizer.step() 170 | 171 | TimeStr = time.asctime(time.localtime(time.time())) 172 | print('Epoch: {} --- {}'.format(epoch, TimeStr)) 173 | print('Train Accuracy of the model: {} %'.format(100 * correct / total)) 174 | print('Train Loss of this batch: {}'.format(loss_batch)) 175 | 176 | # if i % 5 == 0: # 隔一定数量的batch画图 177 | # t.append(t_now) 178 | # loss_list.append(loss_batch) 179 | # acc_list.append(100 * correct / total) 180 | # plt.subplot(2,1,1) 181 | # plt.plot(t,loss_list,'-r') 182 | # plt.title('loss',fontsize=10) 183 | # plt.tight_layout(h_pad=1) 184 | # plt.subplot(2,1,2) 185 | # plt.plot(t,acc_list,'-b') 186 | # plt.title('acc',fontsize=10) 187 | # plt.draw() 188 | # plt.pause(0.01) 189 | # t_now += 5 190 | 191 | 192 | 193 | 194 | 195 | # # 测试 196 | CNN.eval() 197 | with torch.no_grad(): 198 | print('--------test--------') 199 | print('--------test--------') 200 | print('--------test--------') 201 | correct = 0 202 | total = 0 203 | test_task = ['mot_3'] 204 | test_sub = [1,3,5,6,7,8,9,10,13,14,15,18,21,22,23] 205 | test_x, test_y = dataReader(test_sub,test_task) 206 | 207 | smo = RandomOverSampler(random_state=42) 208 | nsamples, nx, ny = test_x.shape 209 | d2_train_dataset = test_x.reshape((nsamples,nx*ny)) 210 | test_x, test_y = smo.fit_sample(d2_train_dataset, test_y) 211 | test_x = test_x.reshape(len(test_x), nx, ny) 212 | 213 | test_x = torch.from_numpy(test_x) 214 | test_x = torch.tensor(test_x, dtype=torch.float32) 215 | test_y = torch.from_numpy(test_y) 216 | for t_x, t_y in zip(test_x, test_y): 217 | t_x = 
t_x.reshape(-1, 90, 90) 218 | t_x = t_x.reshape(-1, 1, 90, 90) 219 | t_y = t_y.reshape(-1) 220 | t_x, t_y = t_x.to(device), t_y.to(device) 221 | output = CNN(t_x) 222 | loss = criterion(output,t_y) 223 | _, predicted = torch.max(output.data, 1) 224 | total += t_y.size(0) 225 | correct += (predicted == t_y).sum().item() 226 | print('Test Accuracy of the model: {} %'.format(100 * correct / total)) 227 | print('Target: {}'.format(t_y)) 228 | print('Output: {}'.format(torch.max(output, 1)[1])) 229 | print('Test Loss of the model: {}'.format(loss)) 230 | -------------------------------------------------------------------------------- /MSCT.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings('ignore') 3 | import csv 4 | import torch 5 | import numpy as np 6 | import random 7 | import torch.nn.functional as F 8 | import time 9 | import nibabel as nib 10 | from torch import nn 11 | from torch import optim 12 | import torch.utils.data as Data 13 | from torch.autograd import Variable 14 | from imblearn.over_sampling import RandomOverSampler,ADASYN,SMOTE 15 | import matplotlib.pyplot as plt 16 | 17 | RANDOM_SEED = 408 18 | # RANDOM_SEED = 296 19 | NUM_IC = 32 20 | 21 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 22 | print(device) 23 | 24 | 25 | class Cnn(nn.Module): 26 | def __init__(self): 27 | super(Cnn, self).__init__() 28 | self.conv1 = torch.nn.Sequential( 29 | torch.nn.Conv2d( 30 | in_channels=1, 31 | out_channels=32, 32 | kernel_size=[3,NUM_IC-1], 33 | padding=[1,0]), 34 | torch.nn.BatchNorm2d(32), 35 | torch.nn.LeakyReLU(), 36 | ) 37 | self.conv2 = torch.nn.Sequential( 38 | torch.nn.Conv2d( 39 | in_channels=1, 40 | out_channels=32, 41 | kernel_size=[5,NUM_IC-1], 42 | padding=[2,0]), 43 | torch.nn.BatchNorm2d(32), 44 | torch.nn.LeakyReLU(), 45 | ) 46 | self.conv3 = torch.nn.Sequential( 47 | torch.nn.Conv2d( 48 | in_channels=1, 49 | out_channels=32, 50 | kernel_size=[7,NUM_IC-1], 51 | padding=[3, 0]), 52 | torch.nn.BatchNorm2d(32), 53 | torch.nn.LeakyReLU(), 54 | ) 55 | self.conv4 = torch.nn.Sequential( 56 | torch.nn.Conv2d( 57 | in_channels=1, 58 | out_channels=32, 59 | kernel_size=[9,NUM_IC-1], 60 | padding=[4, 0]), 61 | torch.nn.BatchNorm2d(32), 62 | torch.nn.LeakyReLU(), 63 | ) 64 | self.lstm = nn.GRU( 65 | input_size=256, 66 | hidden_size=128, 67 | num_layers=2, 68 | batch_first=True, 69 | dropout=0.3 70 | ) 71 | self.lstm2 = nn.GRU( 72 | input_size=256+128, 73 | hidden_size=32, 74 | num_layers=3, 75 | batch_first=True, 76 | dropout=0.3 77 | ) 78 | 79 | self.fc = nn.Sequential( 80 | nn.Linear(32, 8), 81 | nn.Dropout(p=0.3), 82 | nn.Linear(8, 2), 83 | ) 84 | 85 | def forward(self, x): 86 | import torch 87 | x1 = self.conv1(x) 88 | x2 = self.conv2(x) 89 | x3 = self.conv3(x) 90 | x4 = self.conv4(x) 91 | x1 = torch.transpose(x1,1,2) 92 | x1 = x1.flatten(2).cpu() 93 | x2 = torch.transpose(x2, 1, 2) 94 | x2 = x2.flatten(2).cpu() 95 | x3 = torch.transpose(x3, 1, 2) 96 | x3 = x3.flatten(2).cpu() 97 | x4 = torch.transpose(x4, 1, 2) 98 | x4 = x4.flatten(2).cpu() 99 | x = torch.cat((x1,x2), 2) 100 | x = torch.cat((x, x3), 2) 101 | x = torch.cat((x, x4), 2) 102 | x = x.to(device) 103 | 104 | out, (h_n) = self.lstm(x, None) 105 | out = torch.cat((out.cpu(),x.cpu()),2).to(device) 106 | out, (h_n) = self.lstm2(out, None) 107 | out = torch.mean(out,dim=1,keepdim=False) 108 | x = self.fc(out) 109 | return x, out 110 | 111 | 112 | def dataReader(sub_list, task_list): 113 | Input = [] 114 | label = 
[] 115 | for task in task_list: 116 | for sub in sub_list: 117 | path_label = '/home/zmx/ds002311/event/' 118 | if sub < 10: 119 | num = 'sub-0' + str(sub) 120 | else: 121 | num = 'sub-' + str(sub) 122 | datanum=int(np.argwhere(np.array(sub_list)==sub)[0])+1 123 | if datanum<10: 124 | dn='00' + str(datanum) 125 | else: 126 | dn='0' + str(datanum) 127 | 128 | input_name = num + '_' + task 129 | 130 | if task == 'mot_1': 131 | label_name = num + '_func_' + num + '_task-mot_run-01_events' 132 | datapath = '/home/zmx/ds002311/ICA/' + 'ica_sub'+dn+'_timecourses_ica_s1_.nii' 133 | elif task == 'mot_2': 134 | label_name = num + '_func_' + num + '_task-mot_run-02_events' 135 | datapath = '/home/zmx/ds002311/ICA/' + 'ica_sub'+dn+'_timecourses_ica_s2_.nii' 136 | elif task == 'mot_3': 137 | label_name = num + '_func_' + num + '_task-mot_run-03_events' 138 | datapath = '/home/zmx/ds002311/ICA/' + 'ica_sub'+dn+'_timecourses_ica_s3_.nii' 139 | 140 | img = nib.load(datapath) 141 | img = np.array(img.get_fdata()) 142 | data = img[12:,:] 143 | # data = np.delete(data,[0,1,2,3,4,5,6,9,11,12,19,30,32,37],1) 144 | data = np.delete(data,[0,2,3,11,12,32,37],1) 145 | 146 | for i in range(24): 147 | time_course = data[16 * i + 1:16 * i + 13,:] 148 | time_course = np.array(time_course) 149 | 150 | time_course = time_course.T # 沿时间轴标准化 151 | for ic in range(NUM_IC): 152 | for x in time_course[ic]: 153 | x = float(x - np.mean(time_course[ic]))/np.std(time_course[ic]) 154 | time_course = time_course.T 155 | Input.append(time_course) 156 | 157 | with open(path_label + label_name + '.tsv', 'rt') as csvfile: 158 | reader = csv.DictReader(csvfile, delimiter='\t') 159 | cond = [row['cond'] for row in reader] 160 | 161 | for i in range(len(cond)): 162 | if cond[i] == 'targ_easy': 163 | cond[i] = 0 164 | elif cond[i] == 'targ_hard': 165 | cond[i] = 0 166 | elif cond[i] == 'lure_hard': 167 | cond[i] = 1 168 | del cond[24] # 最后一段时间不全 169 | label.extend(cond) 170 | label = list(map(int, label)) 171 | 172 | Input = np.array(Input) 173 | label = np.array(label) 174 | 175 | return Input, label 176 | 177 | 178 | CNN = Cnn() 179 | CNN.to(device) 180 | epochs = 50 181 | # 定义loss和optimizer 182 | optimizer = optim.Adam(CNN.parameters(), lr=0.001, weight_decay=0.02) 183 | criterion = nn.CrossEntropyLoss() 184 | scheduler=optim.lr_scheduler.ReduceLROnPlateau(optimizer,patience=3,factor=0.5,threshold=1e-3,) 185 | 186 | # 训练 187 | train_correct = 0 188 | train_total = 0 189 | TP = 0 190 | FP = 0 191 | FN = 0 192 | batch_size= 64 193 | data_x, data_y = dataReader([1, 3, 5, 6, 7, 8, 9, 10, 13, 14, 15, 18, 21, 22, 23], ['mot_1', 'mot_2', 'mot_3']) 194 | 195 | from sklearn.model_selection import train_test_split 196 | 197 | train_x, test_x, train_y, test_y = train_test_split(data_x, data_y, test_size=0.2, random_state=RANDOM_SEED) 198 | smo = SMOTE(random_state=RANDOM_SEED) # 处理样本数量不对称 199 | nsamples, nx, ny = train_x.shape 200 | d2_train_dataset = train_x.reshape((nsamples, nx * ny)) 201 | train_x, train_y = smo.fit_sample(d2_train_dataset, train_y) 202 | train_x = train_x.reshape(len(train_x), nx, ny) 203 | 204 | train_x = torch.from_numpy(train_x) 205 | train_x = torch.tensor(train_x, dtype=torch.float32) 206 | train_y = torch.from_numpy(train_y) 207 | 208 | test_x = torch.from_numpy(test_x) 209 | test_x = torch.tensor(test_x, dtype=torch.float32) 210 | test_y = torch.from_numpy(test_y) 211 | 212 | torch_dataset = Data.TensorDataset(train_x, train_y) 213 | loader = Data.DataLoader( 214 | dataset=torch_dataset, 215 | batch_size=batch_size, 
# 批大小 216 | shuffle=True, 217 | num_workers=2, 218 | ) 219 | 220 | plt.ion() # 开启interactive mode 绘制图形 221 | plt.figure(1) 222 | t_now = 0 223 | t = [] 224 | loss_list = [] 225 | acc_list = [] 226 | pre_list = [] 227 | rec_list = [] 228 | test_acc_list = [] 229 | test_pre_list = [] 230 | test_rec_list = [] 231 | test_if_legend = 1 232 | if_legend = 1 233 | from sklearn.manifold import TSNE 234 | tsne=TSNE(n_components=2, random_state=RANDOM_SEED, init ='pca') 235 | 236 | for epoch in range(1, epochs + 1): 237 | CNN.train(mode=True) 238 | X_train_tsne=[] 239 | for i, (x, y) in enumerate(loader): 240 | batch_x = Variable(x) 241 | batch_x = batch_x.reshape(len(batch_x),1,12,NUM_IC) 242 | batch_y = Variable(y) 243 | batch_x = batch_x.to(device) 244 | batch_y = batch_y.to(device) 245 | output, temp = CNN(batch_x) 246 | loss = criterion(output,batch_y) 247 | 248 | _, predicted = torch.max(output.data, 1) 249 | zes = Variable(torch.zeros(len(batch_x)).type(torch.LongTensor)) #全0变量 250 | ons = Variable(torch.ones(len(batch_x)).type(torch.LongTensor)) #全1变量 251 | zes = zes.to(device) 252 | ons = ons.to(device) 253 | TP += ((predicted==ons)&(batch_y==ons)).sum() 254 | FP += ((predicted==ons)&(batch_y==zes)).sum() 255 | FN += ((predicted==zes)&(batch_y==ons)).sum() 256 | train_total += batch_y.size(0) 257 | train_correct += (predicted == batch_y).sum().item() 258 | if epoch == epochs - 1: 259 | temp = temp.cpu() 260 | temp = temp.detach().numpy() 261 | X_train_tsne.extend(temp) 262 | 263 | optimizer.zero_grad() 264 | loss.backward() 265 | optimizer.step() 266 | 267 | if i % 5 == 0: # 隔一定数量的batch画图 268 | t.append(t_now) 269 | loss_list.append(loss) 270 | acc_list.append(100 * train_correct / train_total) 271 | pre_list.append(100 * TP / (TP + FP)) 272 | rec_list.append(100 * TP / (TP + FN)) 273 | plt.subplot(1,3,1) 274 | plt.plot(t,loss_list,'-r') 275 | plt.title('train loss',fontsize=15) 276 | plt.xlabel(u"Number of train batches") 277 | plt.ylabel(u"loss") 278 | plt.tight_layout(h_pad=2) 279 | plt.subplot(1,3,2) 280 | plt.plot(t,acc_list,color='orange',label='accuracy') 281 | plt.plot(t,pre_list,color='green',label='precision') 282 | plt.plot(t,rec_list,color='blue',label='recall') 283 | if if_legend: 284 | plt.legend() 285 | if_legend = 0 286 | plt.title('train Accuracy&Precision&Recall',fontsize=15) 287 | plt.xlabel(u"Number of train batches") 288 | plt.ylabel(u"%") 289 | t_now += 5 290 | plt.draw() 291 | plt.pause(0.01) 292 | 293 | TimeStr = time.asctime(time.localtime(time.time())) 294 | print('Epoch: {} --- {} --- '.format(epoch, TimeStr)) 295 | print('Train Accuracy of the model: {} %'.format(100 * train_correct / train_total)) 296 | print('Train Loss of the model: {}'.format(loss)) 297 | print('Learning rate: {}'.format(optimizer.param_groups[0]['lr'])) 298 | # 调整学习率 299 | scheduler.step(loss) 300 | 301 | CNN.eval() 302 | with torch.no_grad(): 303 | test_correct = 0 304 | test_total = 0 305 | test_avg_loss = 0 306 | test_TP = 0 307 | test_FP = 0 308 | test_FN = 0 309 | X_test_tsne=[] 310 | 311 | test_x = test_x.reshape(len(test_x),1,12,NUM_IC) 312 | test_x, test_y = test_x.to(device), test_y.to(device) 313 | 314 | test_output, temp2 = CNN(test_x) 315 | test_avg_loss += criterion(test_output, test_y) 316 | _, predicted = torch.max(test_output.data, 1) 317 | zes = Variable(torch.zeros(len(test_x)).type(torch.LongTensor)) #全0变量 318 | ons = Variable(torch.ones(len(test_x)).type(torch.LongTensor)) #全1变量 319 | zes = zes.to(device) 320 | ons = ons.to(device) 321 | test_TP += 
((predicted==ons)&(test_y==ons)).sum() 322 | test_FP += ((predicted==ons)&(test_y==zes)).sum() 323 | test_FN += ((predicted==zes)&(test_y==ons)).sum() 324 | test_total += test_y.size(0) 325 | test_correct += (predicted == test_y).sum().item() 326 | 327 | if epoch == epochs - 1: 328 | temp2 = temp2.cpu() 329 | temp2 = temp2.detach().numpy() 330 | X_test_tsne = temp2 331 | 332 | test_avg_loss = test_avg_loss / len(test_y) 333 | test_acc_list.append(100 * test_correct / test_total) 334 | test_pre_list.append(100 * test_TP / (test_TP + test_FP)) 335 | test_rec_list.append(100 * test_TP / (test_TP + test_FN)) 336 | plt.subplot(1,3,3) 337 | plt.plot([n for n in range(0,epoch)],test_acc_list,color='orange',label='accuracy') 338 | plt.plot([n for n in range(0,epoch)],test_pre_list,color='green',label='precision') 339 | plt.plot([n for n in range(0,epoch)],test_rec_list,color='blue',label='recall') 340 | if test_if_legend: 341 | plt.legend() 342 | test_if_legend = 0 343 | plt.title('test Accuracy&Precision&Recall',fontsize=15) 344 | plt.xlabel(u"Number of train epoches") 345 | plt.ylabel(u"%") 346 | print('Test Accuracy of the model: {} %'.format(100 * test_correct / test_total)) 347 | print('Test loss of the model: {}'.format(test_avg_loss)) 348 | 349 | plt.pause(0) 350 | 351 | # 绘制数据分布 352 | ''' 353 | X_train_tsne = np.array(X_train_tsne) 354 | train_size = X_train_tsne.shape[0] 355 | X_test_tsne = np.array(X_test_tsne) 356 | X_tsne = np.concatenate((X_train_tsne,X_test_tsne)) 357 | X_tsne = tsne.fit_transform(X_tsne) 358 | x_min, x_max = X_tsne.min(0), X_tsne.max(0) 359 | X_norm = (X_tsne - x_min) / (x_max - x_min) # 归一化 360 | 361 | lure_train_x = [] 362 | lure_train_y = [] 363 | targ_train_x = [] 364 | targ_train_y = [] 365 | lure_test_x = [] 366 | lure_test_y = [] 367 | targ_test_x = [] 368 | targ_test_y = [] 369 | for m in range(X_norm.shape[0]): 370 | if m < train_size: 371 | if int(train_y[m]): 372 | lure_train_x.append(X_norm[m, 0]) 373 | lure_train_y.append(X_norm[m, 1]) 374 | else: 375 | targ_train_x.append(X_norm[m, 0]) 376 | targ_train_y.append(X_norm[m, 1]) 377 | else: 378 | if int(test_y[m - train_size]): 379 | lure_test_x.append(X_norm[m, 0]) 380 | lure_test_y.append(X_norm[m, 1]) 381 | else: 382 | targ_test_x.append(X_norm[m, 0]) 383 | targ_test_y.append(X_norm[m, 1]) 384 | 385 | plt.subplot(1,2,1) 386 | plt.title("train result") 387 | plt.scatter(targ_train_x, targ_train_y, s=120,marker = ".", color='orange', label='targ', edgecolor='black',alpha=1) 388 | plt.scatter(lure_train_x, lure_train_y, s=120,marker = ".", color='green', label='lure', edgecolor='black',alpha=1) 389 | plt.legend() 390 | 391 | plt.subplot(1,2,2) 392 | plt.title("test result") 393 | plt.scatter(targ_test_x, targ_test_y, s=120,marker = ".", color='orange', label='targ', edgecolor='black',alpha=1) 394 | plt.scatter(lure_test_x, lure_test_y, s=120,marker = ".", color='green', label='lure', edgecolor='black',alpha=1) 395 | plt.legend() 396 | 397 | plt.draw() 398 | plt.pause(0) 399 | ''' 400 | -------------------------------------------------------------------------------- /MSCT_FNC.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import torch 3 | import numpy as np 4 | import random 5 | import torch.nn.functional as F 6 | import time 7 | import nibabel as nib 8 | from torch import nn 9 | from torch import optim 10 | from torch.autograd import Variable 11 | from imblearn.over_sampling import RandomOverSampler,ADASYN 12 | 13 | device = torch.device("cuda:0" if 
torch.cuda.is_available() else "cpu") 14 | print(device) 15 | 16 | 17 | class Cnn(nn.Module): 18 | def __init__(self): 19 | super(Cnn, self).__init__() 20 | self.scale1_conv = torch.nn.Sequential( 21 | torch.nn.Conv2d( 22 | in_channels=1, 23 | out_channels=32, 24 | kernel_size=[3,31], 25 | padding=[1,0]), 26 | torch.nn.BatchNorm2d(32), 27 | torch.nn.ReLU() 28 | ) 29 | self.scale2_conv = torch.nn.Sequential( 30 | torch.nn.BatchNorm2d(1), 31 | torch.nn.Conv2d( 32 | in_channels=1, 33 | out_channels=32, 34 | kernel_size=[5,31], 35 | padding=[2,0]), 36 | torch.nn.BatchNorm2d(32), 37 | torch.nn.ReLU() 38 | ) 39 | self.FNC_conv = torch.nn.Sequential( 40 | torch.nn.Conv2d( 41 | in_channels=1, 42 | out_channels=8, 43 | kernel_size=3, 44 | padding=0), 45 | nn.BatchNorm2d(8), 46 | torch.nn.ReLU(), 47 | torch.nn.MaxPool2d(2), 48 | torch.nn.Conv2d(8, 16, kernel_size=3, padding=0), 49 | nn.BatchNorm2d(16), 50 | torch.nn.ReLU() 51 | ) 52 | self.lstm = nn.LSTM( 53 | input_size=64, 54 | hidden_size=32, 55 | num_layers=1, 56 | batch_first=True, 57 | # dropout=0.5 58 | ) 59 | self.lstm2 = nn.LSTM( 60 | input_size=96, 61 | hidden_size=32, 62 | num_layers=1, 63 | batch_first=True, 64 | # dropout=0.5 65 | ) 66 | 67 | self.fc = nn.Sequential( 68 | nn.Linear(32, 8), 69 | nn.ReLU(), 70 | nn.Dropout(p=0.5), 71 | nn.Linear(8, 2), 72 | ) 73 | 74 | def forward(self, x, x_FNC): 75 | x1 = self.scale1_conv(x) 76 | x2 = self.scale2_conv(x) 77 | x1=torch.transpose(x1, 1, 2) 78 | x1=x1.flatten(2).cpu() 79 | x2 = torch.transpose(x2, 1, 2) 80 | x2 = x2.flatten(2).cpu() 81 | x=torch.cat((x1,x2),2).to(device) 82 | out, (h_n, h_c) = self.lstm(x, None) 83 | out=torch.cat((out.cpu(),x.cpu()),2).to(device) 84 | out, (h_n, h_c) = self.lstm2(out, None) 85 | out = F.relu(out) 86 | out = out[:, -1, :] 87 | out_FNC = self.FNC_conv(x_FNC) 88 | print(out.shape) 89 | print(out_FNC.shape) 90 | x = self.fc(out) 91 | return x 92 | 93 | 94 | def dataReader(sub_list, task_list): 95 | Input = [] 96 | label = [] 97 | for task in task_list: 98 | for sub in sub_list: 99 | path_label = '/home/zmx/ds002311/event/' 100 | if sub < 10: 101 | num = 'sub-0' + str(sub) 102 | else: 103 | num = 'sub-' + str(sub) 104 | datanum=int(np.argwhere(np.array(sub_list)==sub)[0])+1 105 | if datanum<10: 106 | dn='00' + str(datanum) 107 | else: 108 | dn='0' + str(datanum) 109 | 110 | input_name = num + '_' + task 111 | 112 | if task == 'mot_1': 113 | label_name = num + '_func_' + num + '_task-mot_run-01_events' 114 | datapath = '/home/zmx/ds002311/ICA/' + 'ica_sub'+dn+'_timecourses_ica_s1_.nii' 115 | elif task == 'mot_2': 116 | label_name = num + '_func_' + num + '_task-mot_run-02_events' 117 | datapath = '/home/zmx/ds002311/ICA/' + 'ica_sub'+dn+'_timecourses_ica_s2_.nii' 118 | elif task == 'mot_3': 119 | label_name = num + '_func_' + num + '_task-mot_run-03_events' 120 | datapath = '/home/zmx/ds002311/ICA/' + 'ica_sub'+dn+'_timecourses_ica_s3_.nii' 121 | 122 | img = nib.load(datapath) 123 | img = np.array(img.get_fdata()) 124 | data=img[12:,:] 125 | 126 | for i in range(24): 127 | time_course = data[16 * i:16 * i + 16,:] 128 | time_course = np.array(time_course) 129 | 130 | time_course = time_course.T # 沿时间轴标准化 131 | for ic in range(31): 132 | for x in time_course[ic]: 133 | x = float(x - np.mean(time_course[ic]))/np.std(time_course[ic]) 134 | time_course = time_course.T 135 | Input.append(time_course) 136 | 137 | with open(path_label + label_name + '.tsv', 'rt') as csvfile: 138 | reader = csv.DictReader(csvfile, delimiter='\t') 139 | cond = [row['cond'] for row in 
reader] 140 | 141 | for i in range(len(cond)): 142 | if cond[i] == 'targ_easy': 143 | cond[i] = 0 144 | elif cond[i] == 'targ_hard': 145 | cond[i] = 0 146 | elif cond[i] == 'lure_hard': 147 | cond[i] = 1 148 | del cond[24] # 最后一段时间不全 149 | label.extend(cond) 150 | label = list(map(int, label)) 151 | 152 | Input = np.array(Input) 153 | label = np.array(label) 154 | 155 | return Input, label, FNC 156 | 157 | 158 | CNN = Cnn() 159 | CNN.to(device) 160 | epochs = 50 161 | # 定义loss和optimizer 162 | optimizer = optim.Adam(CNN.parameters(), lr=0.001, weight_decay=0.002) 163 | criterion = nn.CrossEntropyLoss() 164 | 165 | # 训练 166 | train_correct = 0 167 | train_total = 0 168 | batch_size=16 169 | data_x, data_y, data_FNC = dataReader([1, 3, 5, 6, 7, 8, 9, 10, 13, 14, 15, 18, 21, 22, 23], ['mot_1', 'mot_2', 'mot_3']) 170 | 171 | from sklearn.model_selection import train_test_split 172 | 173 | train_x, test_x, train_y, test_y, train_FNC, test_FNC = train_test_split(data_x, data_y, data_FNC, test_size=0.2) 174 | smo = ADASYN(random_state=42) # 处理样本数量不对称 175 | nsamples, nx, ny = train_x.shape 176 | d2_train_dataset = train_x.reshape((nsamples, nx * ny)) 177 | train_x, train_y = smo.fit_sample(d2_train_dataset, train_y) 178 | train_x = train_x.reshape(len(train_x), nx, ny) 179 | 180 | smo = RandomOverSampler(random_state=42) 181 | nsamples, nx, ny = test_x.shape 182 | d2_train_dataset = test_x.reshape((nsamples, nx * ny)) 183 | test_x, test_y = smo.fit_sample(d2_train_dataset, test_y) 184 | test_x = test_x.reshape(len(test_x), nx, ny) 185 | 186 | test_x = torch.from_numpy(test_x) 187 | test_x = torch.tensor(test_x, dtype=torch.float32) 188 | test_y = torch.from_numpy(test_y) 189 | 190 | for epoch in range(epochs): 191 | #CNN.train(mode=True) 192 | if epoch: 193 | train_x=np.array(train_x) 194 | train_y=np.array(train_y) 195 | 196 | state = np.random.get_state() # 打乱顺序 197 | np.random.shuffle(train_x) 198 | np.random.set_state(state) 199 | np.random.shuffle(train_y) 200 | 201 | train_x = torch.from_numpy(train_x) 202 | train_x = torch.tensor(train_x, dtype=torch.float32) 203 | train_y = torch.from_numpy(train_y) 204 | train_y = torch.tensor(train_y, dtype=torch.long) 205 | for i in range(0, len(train_x) - batch_size, batch_size): 206 | loss_batch = 0 207 | for b_x, b_y in zip(train_x[i:i+batch_size], train_y[i:i+batch_size]): 208 | b_x = b_x.reshape(-1, 16, 31) 209 | b_x = b_x.reshape(-1, 1, 16, 31) 210 | b_y = b_y.reshape(-1) 211 | b_x, b_y = b_x.to(device), b_y.to(device) 212 | 213 | output = CNN(b_x) 214 | loss = criterion(output, b_y) 215 | loss_batch += loss 216 | 217 | _, predicted = torch.max(output.data, 1) 218 | train_total += b_y.size(0) 219 | train_correct += (predicted == b_y).sum().item() 220 | 221 | loss_batch = loss_batch / batch_size 222 | optimizer.zero_grad() 223 | loss_batch.backward() 224 | optimizer.step() 225 | 226 | TimeStr = time.asctime(time.localtime(time.time())) 227 | print('Epoch: {} --- {}'.format(epoch, TimeStr)) 228 | print('Train Accuracy of the model: {} %'.format(100 * train_correct / train_total)) 229 | print('Train Loss of the model: {}'.format(loss_batch)) 230 | 231 | # 衰减的学习率 232 | for p in optimizer.param_groups: 233 | p['lr'] *= 0.95 234 | 235 | # 每个epoch测试一次查看loss和准确率 236 | #CNN.eval() 237 | with torch.no_grad(): 238 | test_correct = 0 239 | test_total = 0 240 | test_avg_loss = 0 241 | 242 | for t_x, t_y in zip(test_x, test_y): 243 | t_x = t_x.reshape(-1, 16, 31) 244 | t_x = t_x.reshape(-1, 1, 16, 31) 245 | t_y = t_y.reshape(-1) 246 | t_x, t_y = t_x.to(device), 
t_y.to(device) 247 | test_output = CNN(t_x) 248 | test_avg_loss += criterion(test_output, t_y) 249 | _, predicted = torch.max(test_output.data, 1) 250 | test_total += t_y.size(0) 251 | test_correct += (predicted == t_y).sum().item() 252 | test_avg_loss = test_avg_loss / len(test_y) 253 | print('Test Accuracy of the model: {} %'.format(100 * test_correct / test_total)) 254 | print('Test loss of the model: {}'.format(test_avg_loss)) 255 | -------------------------------------------------------------------------------- /MsCNN_ICA.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings('ignore') 3 | import csv 4 | import torch 5 | import numpy as np 6 | import random 7 | import torch.nn.functional as F 8 | import time 9 | import nibabel as nib 10 | from torch import nn 11 | from torch import optim 12 | from torch.autograd import Variable 13 | from imblearn.over_sampling import RandomOverSampler,ADASYN,SMOTE 14 | import matplotlib.pyplot as plt 15 | 16 | RANDOM_SEED = 408 # 测试集0.05时 17 | # RANDOM_SEED = 296 # 测试集0.1时 18 | 19 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 20 | print(device) 21 | 22 | 23 | class Cnn(nn.Module): 24 | def __init__(self): 25 | super(Cnn, self).__init__() 26 | self.conv1 = torch.nn.Sequential( 27 | torch.nn.BatchNorm2d(1), 28 | torch.nn.Conv2d( 29 | in_channels=1, 30 | out_channels=32, 31 | kernel_size=[3,31], 32 | padding=[1,0]), 33 | torch.nn.BatchNorm2d(32), 34 | torch.nn.LeakyReLU(), 35 | ) 36 | self.conv2 = torch.nn.Sequential( 37 | torch.nn.BatchNorm2d(1), 38 | torch.nn.Conv2d( 39 | in_channels=1, 40 | out_channels=32, 41 | kernel_size=[5,31], 42 | padding=[2,0]), 43 | torch.nn.BatchNorm2d(32), 44 | torch.nn.LeakyReLU(), 45 | ) 46 | self.conv3 = torch.nn.Sequential( 47 | torch.nn.BatchNorm2d(1), 48 | torch.nn.Conv2d( 49 | in_channels=1, 50 | out_channels=32, 51 | kernel_size=[7, 31], 52 | padding=[3, 0]), 53 | torch.nn.BatchNorm2d(32), 54 | torch.nn.LeakyReLU(), 55 | ) 56 | self.conv4 = torch.nn.Sequential( 57 | torch.nn.BatchNorm2d(1), 58 | torch.nn.Conv2d( 59 | in_channels=1, 60 | out_channels=32, 61 | kernel_size=[9, 31], 62 | padding=[4, 0]), 63 | torch.nn.BatchNorm2d(32), 64 | torch.nn.LeakyReLU(), 65 | ) 66 | 67 | self.fc = nn.Sequential( 68 | nn.Linear(12*128, 512), 69 | nn.Dropout(p=0.5), 70 | nn.Linear(512, 64), 71 | nn.Linear(64, 2), 72 | ) 73 | 74 | def forward(self, x): 75 | import torch 76 | x1 = self.conv1(x) 77 | x2 = self.conv2(x) 78 | x3 = self.conv3(x) 79 | x4 = self.conv4(x) 80 | x1 = torch.transpose(x1,1,2) 81 | x1 = x1.flatten(2).cpu() 82 | x2 = torch.transpose(x2, 1, 2) 83 | x2 = x2.flatten(2).cpu() 84 | x3 = torch.transpose(x3, 1, 2) 85 | x3 = x3.flatten(2).cpu() 86 | x4 = torch.transpose(x4, 1, 2) 87 | x4 = x4.flatten(2).cpu() 88 | x = torch.cat((x1,x2), 2) 89 | x = torch.cat((x, x3), 2) 90 | x = torch.cat((x, x4), 2) 91 | out = x.to(device) 92 | out = torch.flatten(out,start_dim=1) 93 | x = self.fc(out) 94 | return x, out 95 | 96 | 97 | def dataReader(sub_list, task_list): 98 | Input = [] 99 | label = [] 100 | for task in task_list: 101 | for sub in sub_list: 102 | path_label = '/home/zmx/ds002311/event/' 103 | if sub < 10: 104 | num = 'sub-0' + str(sub) 105 | else: 106 | num = 'sub-' + str(sub) 107 | datanum=int(np.argwhere(np.array(sub_list)==sub)[0])+1 108 | if datanum<10: 109 | dn='00' + str(datanum) 110 | else: 111 | dn='0' + str(datanum) 112 | 113 | input_name = num + '_' + task 114 | 115 | if task == 'mot_1': 116 | label_name = 
num + '_func_' + num + '_task-mot_run-01_events' 117 | datapath = '/home/zmx/ds002311/ICA/' + 'ica_sub'+dn+'_timecourses_ica_s1_.nii' 118 | elif task == 'mot_2': 119 | label_name = num + '_func_' + num + '_task-mot_run-02_events' 120 | datapath = '/home/zmx/ds002311/ICA/' + 'ica_sub'+dn+'_timecourses_ica_s2_.nii' 121 | elif task == 'mot_3': 122 | label_name = num + '_func_' + num + '_task-mot_run-03_events' 123 | datapath = '/home/zmx/ds002311/ICA/' + 'ica_sub'+dn+'_timecourses_ica_s3_.nii' 124 | 125 | img = nib.load(datapath) 126 | img = np.array(img.get_fdata()) 127 | data=img[12:,:] 128 | 129 | for i in range(24): 130 | time_course = data[16 * i + 1:16 * i + 13,:] 131 | time_course = np.array(time_course) 132 | 133 | # time_course = time_course.T # 沿时间轴标准化 134 | # for ic in range(31): 135 | # for x in time_course[ic]: 136 | # x = float(x - np.mean(time_course[ic]))/np.std(time_course[ic]) 137 | # time_course = time_course.T 138 | Input.append(time_course) 139 | 140 | with open(path_label + label_name + '.tsv', 'rt') as csvfile: 141 | reader = csv.DictReader(csvfile, delimiter='\t') 142 | cond = [row['cond'] for row in reader] 143 | 144 | for i in range(len(cond)): 145 | if cond[i] == 'targ_easy': 146 | cond[i] = 0 147 | elif cond[i] == 'targ_hard': 148 | cond[i] = 0 149 | elif cond[i] == 'lure_hard': 150 | cond[i] = 1 151 | del cond[24] # 最后一段时间不全 152 | label.extend(cond) 153 | label = list(map(int, label)) 154 | 155 | Input = np.array(Input) 156 | label = np.array(label) 157 | 158 | return Input, label 159 | 160 | 161 | CNN = Cnn() 162 | CNN.to(device) 163 | epochs = 50 164 | # 定义loss和optimizer 165 | optimizer = optim.Adam(CNN.parameters(), lr=0.001, weight_decay=0.05) 166 | criterion = nn.CrossEntropyLoss() 167 | scheduler=optim.lr_scheduler.ReduceLROnPlateau(optimizer,patience=3,factor=0.5,threshold=1e-3,) 168 | 169 | # 训练 170 | train_correct = 0 171 | train_total = 0 172 | TP = 0 173 | FP = 0 174 | FN = 0 175 | batch_size= 64 176 | data_x, data_y = dataReader([1, 3, 5, 6, 7, 8, 9, 10, 13, 14, 15, 18, 21, 22, 23], ['mot_1', 'mot_2', 'mot_3']) 177 | 178 | from sklearn.model_selection import train_test_split 179 | 180 | train_x, test_x, train_y, test_y = train_test_split(data_x, data_y, test_size=0.05, random_state=RANDOM_SEED) 181 | smo = SMOTE(random_state=RANDOM_SEED) # 处理样本数量不对称 182 | nsamples, nx, ny = train_x.shape 183 | d2_train_dataset = train_x.reshape((nsamples, nx * ny)) 184 | train_x, train_y = smo.fit_sample(d2_train_dataset, train_y) 185 | train_x = train_x.reshape(len(train_x), nx, ny) 186 | 187 | smo = RandomOverSampler(random_state=RANDOM_SEED) 188 | nsamples, nx, ny = test_x.shape 189 | d2_train_dataset = test_x.reshape((nsamples, nx * ny)) 190 | test_x, test_y = smo.fit_sample(d2_train_dataset, test_y) 191 | test_x = test_x.reshape(len(test_x), nx, ny) 192 | 193 | test_x = torch.from_numpy(test_x) 194 | test_x = torch.tensor(test_x, dtype=torch.float32) 195 | test_y = torch.from_numpy(test_y) 196 | 197 | plt.ion() # 开启interactive mode 绘制图形 198 | plt.figure(1) 199 | t_now = 0 200 | t = [] 201 | loss_list = [] 202 | acc_list = [] 203 | pre_list = [] 204 | rec_list = [] 205 | if_legend = 1 206 | from sklearn.manifold import TSNE 207 | tsne=TSNE(n_components=2, random_state=RANDOM_SEED, init ='pca') 208 | 209 | for epoch in range(epochs): 210 | CNN.train(mode=True) 211 | X_train_tsne=[] 212 | if epoch: 213 | train_x=np.array(train_x) 214 | train_y=np.array(train_y) 215 | 216 | state = np.random.get_state() # 打乱顺序 217 | np.random.shuffle(train_x) 218 | 
np.random.set_state(state) 219 | np.random.shuffle(train_y) 220 | 221 | train_x = torch.from_numpy(train_x) 222 | train_x = torch.tensor(train_x, dtype=torch.float32) 223 | train_y = torch.from_numpy(train_y) 224 | train_y = torch.tensor(train_y, dtype=torch.long) 225 | for i in range(0, len(train_x), batch_size): 226 | loss_batch = 0 227 | for b_x, b_y in zip(train_x[i:i+batch_size], train_y[i:i+batch_size]): 228 | b_x = b_x.reshape(-1, 12, 31) 229 | b_x = b_x.reshape(-1, 1, 12, 31) 230 | b_y = b_y.reshape(-1) 231 | b_x, b_y = b_x.to(device), b_y.to(device) 232 | 233 | output, temp = CNN(b_x) 234 | loss = criterion(output, b_y) 235 | loss_batch += loss 236 | 237 | _, predicted = torch.max(output.data, 1) 238 | train_total += b_y.size(0) 239 | train_correct += (predicted == b_y).sum().item() 240 | if predicted == 1 and b_y[0] == 1: 241 | TP += 1 242 | elif predicted == 1 and b_y[0] == 0: 243 | FP += 1 244 | elif predicted == 0 and b_y[0] == 1: 245 | FN += 1 246 | if epoch == epochs - 1: 247 | temp = temp.cpu() 248 | temp = temp.detach().numpy() 249 | X_train_tsne.append(temp) 250 | 251 | loss_batch = loss_batch / batch_size 252 | optimizer.zero_grad() 253 | loss_batch.backward() 254 | optimizer.step() 255 | 256 | if i % (5*batch_size) == 0: # 隔一定数量的batch画图 257 | t.append(t_now) 258 | loss_list.append(loss_batch) 259 | acc_list.append(100 * train_correct / train_total) 260 | pre_list.append(100 * TP / (TP + FP)) 261 | rec_list.append(100 * TP / (TP + FN)) 262 | plt.subplot(2,1,1) 263 | plt.plot(t,loss_list,'-r') 264 | plt.title('train loss',fontsize=15) 265 | plt.xlabel(u"Number of train batches") 266 | plt.ylabel(u"loss") 267 | plt.tight_layout(h_pad=2) 268 | plt.subplot(2,1,2) 269 | plt.plot(t,acc_list,color='orange',label='accuracy') 270 | plt.plot(t,pre_list,color='green',label='precision') 271 | plt.plot(t,rec_list,color='blue',label='recall') 272 | if if_legend: 273 | plt.legend() 274 | if_legend = 0 275 | plt.title('Accuracy&Precision&Recall',fontsize=15) 276 | plt.xlabel(u"Number of train batches") 277 | plt.ylabel(u"%") 278 | t_now += 5 279 | plt.draw() 280 | plt.pause(0.01) 281 | 282 | TimeStr = time.asctime(time.localtime(time.time())) 283 | print('Epoch: {} --- {} --- '.format(epoch, TimeStr)) 284 | print('Train Accuracy of the model: {} %'.format(100 * train_correct / train_total)) 285 | print('Train Loss of the model: {}'.format(loss_batch)) 286 | print('Learning rate: {}'.format(optimizer.param_groups[0]['lr'])) 287 | # 调整学习率 288 | scheduler.step(loss_batch) 289 | 290 | CNN.eval() 291 | with torch.no_grad(): 292 | test_correct = 0 293 | test_total = 0 294 | test_avg_loss = 0 295 | X_test_tsne=[] 296 | 297 | for t_x, t_y in zip(test_x, test_y): 298 | t_x = t_x.reshape(-1, 12, 31) 299 | t_x = t_x.reshape(-1, 1, 12, 31) 300 | t_y = t_y.reshape(-1) 301 | t_x, t_y = t_x.to(device), t_y.to(device) 302 | test_output, temp2 = CNN(t_x) 303 | test_avg_loss += criterion(test_output, t_y) 304 | _, predicted = torch.max(test_output.data, 1) 305 | test_total += t_y.size(0) 306 | test_correct += (predicted == t_y).sum().item() 307 | 308 | if epoch == epochs - 1: 309 | temp2 = temp2.cpu() 310 | temp2 = temp2.detach().numpy() 311 | X_test_tsne.append(temp2) 312 | 313 | test_avg_loss = test_avg_loss / len(test_y) 314 | print('Test Accuracy of the model: {} %'.format(100 * test_correct / test_total)) 315 | print('Test loss of the model: {}'.format(test_avg_loss)) 316 | 317 | ''' 318 | 319 | X_train_tsne = np.array(X_train_tsne) 320 | train_size = X_train_tsne.shape[0] 321 | X_train_tsne = 
X_train_tsne.reshape(X_train_tsne.shape[0], 1536) 322 | X_test_tsne = np.array(X_test_tsne) 323 | test_size = X_test_tsne.shape[0] 324 | X_test_tsne = X_test_tsne.reshape(X_test_tsne.shape[0], 1536) 325 | X_tsne = np.concatenate((X_train_tsne,X_test_tsne)) 326 | X_tsne = tsne.fit_transform(X_tsne) 327 | x_min, x_max = X_tsne.min(0), X_tsne.max(0) 328 | X_norm = (X_tsne - x_min) / (x_max - x_min) # min-max normalize to [0, 1] 329 | 330 | lure_train_x = [] 331 | lure_train_y = [] 332 | targ_train_x = [] 333 | targ_train_y = [] 334 | lure_test_x = [] 335 | lure_test_y = [] 336 | targ_test_x = [] 337 | targ_test_y = [] 338 | for m in range(X_norm.shape[0]): 339 | if m < train_size: 340 | if train_y[m]: 341 | lure_train_x.append(X_norm[m, 0]) 342 | lure_train_y.append(X_norm[m, 1]) 343 | else: 344 | targ_train_x.append(X_norm[m, 0]) 345 | targ_train_y.append(X_norm[m, 1]) 346 | else: 347 | if test_y[m - train_size]: 348 | lure_test_x.append(X_norm[m, 0]) 349 | lure_test_y.append(X_norm[m, 1]) 350 | else: 351 | targ_test_x.append(X_norm[m, 0]) 352 | targ_test_y.append(X_norm[m, 1]) 353 | 354 | plt.subplot(1,2,1) 355 | plt.title("train result") 356 | plt.scatter(targ_train_x, targ_train_y, s=120,marker = ".", color='orange', label='targ', edgecolor='black',alpha=1) 357 | plt.scatter(lure_train_x, lure_train_y, s=120,marker = ".", color='green', label='lure', edgecolor='black',alpha=1) 358 | plt.legend() 359 | 360 | plt.subplot(1,2,2) 361 | plt.title("test result") 362 | plt.scatter(targ_test_x, targ_test_y, s=120,marker = ".", color='orange', label='targ', edgecolor='black',alpha=1) 363 | plt.scatter(lure_test_x, lure_test_y, s=120,marker = ".", color='green', label='lure', edgecolor='black',alpha=1) 364 | plt.legend() 365 | 366 | plt.draw() 367 | plt.pause(0) 368 | ''' 369 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-fMRI-signal-classifier 2 | Classifiers for task accuracy and task difficulty, based on CNN and LSTM models.
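
The scripts train binary target-vs-lure ("iftarg") classifiers on the ds002311 `mot` task runs:

- `CNN-LSTM_iftarg_AllBrain.py` / `multisc_CNN-LSTM_iftarg_AllBrain.py` — 3D CNN (single-scale and multi-scale) feeding an LSTM, on whole-brain volumes masked with `area/AllBrain.nii`.
- `CNN_FNC_iftarg.py` / `multisc_CNN_FNC_iftarg.py` — 2D CNNs on 90×90 FNC matrices stored as per-window CSV files.
- `MsCNN_ICA.py`, `MSCT.py`, `MSCT_FNC.py` — multi-scale temporal CNNs (optionally followed by GRU/LSTM layers) on ICA time courses.
- `area/` — NIfTI masks: whole brain plus bilateral AAL-style region masks.

All data paths are hard-coded under `/home/zmx/ds002311/` and must be adapted to your own layout. The scripts share the same labelling convention: drop the first 12 volumes of a run, segment the rest into 24 blocks of 16 TRs, and label each block from the events TSV (`targ_easy`/`targ_hard` → 0, `lure_hard` → 1; the 25th event is dropped because its block is incomplete). The ICA-based scripts additionally keep only 12 of the 16 TRs in each block. The helper below is a minimal sketch of the whole-brain variant of that convention, not part of the repo; the paths in the usage line follow the hard-coded ones in the scripts and are only illustrative.

```python
import csv
import numpy as np
import nibabel as nib

def load_windows(nii_path, events_tsv):
    """Split one preprocessed 4D run into 16-TR windows with target/lure labels."""
    img = np.array(nib.load(nii_path).get_fdata())        # e.g. (61, 73, 61, 405)
    data = img[..., 12:]                                   # drop the first 12 volumes
    windows = [data[..., 16 * i:16 * i + 16] for i in range(24)]

    with open(events_tsv, 'rt') as f:
        cond = [row['cond'] for row in csv.DictReader(f, delimiter='\t')]
    del cond[24]                                           # drop the 25th event: its window is incomplete
    labels = [1 if c == 'lure_hard' else 0 for c in cond]  # targ_easy/targ_hard -> 0, lure_hard -> 1
    return np.array(windows), np.array(labels)

x, y = load_windows('/home/zmx/ds002311/preprocessed_4D/mot_1/sub-01_mot_1.nii',
                    '/home/zmx/ds002311/event/sub-01_func_sub-01_task-mot_run-01_events.tsv')
```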
3 | -------------------------------------------------------------------------------- /area/AllBrain.nii: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MXMX0811/pytorch-fMRI-signal-classifier/9a402ae037f04c5fa53fb1017e96f66a32d078b0/area/AllBrain.nii -------------------------------------------------------------------------------- /area/Angular_L.nii: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MXMX0811/pytorch-fMRI-signal-classifier/9a402ae037f04c5fa53fb1017e96f66a32d078b0/area/Angular_L.nii -------------------------------------------------------------------------------- /area/Angular_R.nii: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MXMX0811/pytorch-fMRI-signal-classifier/9a402ae037f04c5fa53fb1017e96f66a32d078b0/area/Angular_R.nii -------------------------------------------------------------------------------- /area/Cingulum_Ant_L.nii: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MXMX0811/pytorch-fMRI-signal-classifier/9a402ae037f04c5fa53fb1017e96f66a32d078b0/area/Cingulum_Ant_L.nii -------------------------------------------------------------------------------- /area/Cingulum_Ant_R.nii: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MXMX0811/pytorch-fMRI-signal-classifier/9a402ae037f04c5fa53fb1017e96f66a32d078b0/area/Cingulum_Ant_R.nii -------------------------------------------------------------------------------- /area/Cingulum_Mid_L.nii: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MXMX0811/pytorch-fMRI-signal-classifier/9a402ae037f04c5fa53fb1017e96f66a32d078b0/area/Cingulum_Mid_L.nii -------------------------------------------------------------------------------- /area/Cingulum_Mid_R.nii: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MXMX0811/pytorch-fMRI-signal-classifier/9a402ae037f04c5fa53fb1017e96f66a32d078b0/area/Cingulum_Mid_R.nii -------------------------------------------------------------------------------- /area/Cingulum_Post_L.nii: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MXMX0811/pytorch-fMRI-signal-classifier/9a402ae037f04c5fa53fb1017e96f66a32d078b0/area/Cingulum_Post_L.nii -------------------------------------------------------------------------------- /area/Cingulum_Post_R.nii: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MXMX0811/pytorch-fMRI-signal-classifier/9a402ae037f04c5fa53fb1017e96f66a32d078b0/area/Cingulum_Post_R.nii -------------------------------------------------------------------------------- /area/Frontal_Inf_Oper_L.nii: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MXMX0811/pytorch-fMRI-signal-classifier/9a402ae037f04c5fa53fb1017e96f66a32d078b0/area/Frontal_Inf_Oper_L.nii -------------------------------------------------------------------------------- /area/Frontal_Inf_Oper_R.nii: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MXMX0811/pytorch-fMRI-signal-classifier/9a402ae037f04c5fa53fb1017e96f66a32d078b0/area/Frontal_Inf_Oper_R.nii -------------------------------------------------------------------------------- /area/Frontal_Inf_Orb_L.nii: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MXMX0811/pytorch-fMRI-signal-classifier/9a402ae037f04c5fa53fb1017e96f66a32d078b0/area/Frontal_Inf_Orb_L.nii -------------------------------------------------------------------------------- /area/Frontal_Inf_Orb_R.nii: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MXMX0811/pytorch-fMRI-signal-classifier/9a402ae037f04c5fa53fb1017e96f66a32d078b0/area/Frontal_Inf_Orb_R.nii -------------------------------------------------------------------------------- /area/Frontal_Inf_Tri_L.nii: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MXMX0811/pytorch-fMRI-signal-classifier/9a402ae037f04c5fa53fb1017e96f66a32d078b0/area/Frontal_Inf_Tri_L.nii -------------------------------------------------------------------------------- /area/Frontal_Inf_Tri_R.nii: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MXMX0811/pytorch-fMRI-signal-classifier/9a402ae037f04c5fa53fb1017e96f66a32d078b0/area/Frontal_Inf_Tri_R.nii -------------------------------------------------------------------------------- /area/Fusiform_L.nii: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MXMX0811/pytorch-fMRI-signal-classifier/9a402ae037f04c5fa53fb1017e96f66a32d078b0/area/Fusiform_L.nii -------------------------------------------------------------------------------- /area/Fusiform_R.nii: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MXMX0811/pytorch-fMRI-signal-classifier/9a402ae037f04c5fa53fb1017e96f66a32d078b0/area/Fusiform_R.nii -------------------------------------------------------------------------------- /area/Hippocampus_L.nii: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MXMX0811/pytorch-fMRI-signal-classifier/9a402ae037f04c5fa53fb1017e96f66a32d078b0/area/Hippocampus_L.nii -------------------------------------------------------------------------------- /area/Hippocampus_R.nii: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MXMX0811/pytorch-fMRI-signal-classifier/9a402ae037f04c5fa53fb1017e96f66a32d078b0/area/Hippocampus_R.nii -------------------------------------------------------------------------------- /area/Parahippo_L.nii: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MXMX0811/pytorch-fMRI-signal-classifier/9a402ae037f04c5fa53fb1017e96f66a32d078b0/area/Parahippo_L.nii -------------------------------------------------------------------------------- /area/Parahippo_R.nii: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MXMX0811/pytorch-fMRI-signal-classifier/9a402ae037f04c5fa53fb1017e96f66a32d078b0/area/Parahippo_R.nii -------------------------------------------------------------------------------- /area/Parietal_Inf_L.nii: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/MXMX0811/pytorch-fMRI-signal-classifier/9a402ae037f04c5fa53fb1017e96f66a32d078b0/area/Parietal_Inf_L.nii -------------------------------------------------------------------------------- /area/Parietal_Inf_R.nii: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MXMX0811/pytorch-fMRI-signal-classifier/9a402ae037f04c5fa53fb1017e96f66a32d078b0/area/Parietal_Inf_R.nii -------------------------------------------------------------------------------- /area/Parietal_Sup_L.nii: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MXMX0811/pytorch-fMRI-signal-classifier/9a402ae037f04c5fa53fb1017e96f66a32d078b0/area/Parietal_Sup_L.nii -------------------------------------------------------------------------------- /area/Parietal_Sup_R.nii: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MXMX0811/pytorch-fMRI-signal-classifier/9a402ae037f04c5fa53fb1017e96f66a32d078b0/area/Parietal_Sup_R.nii -------------------------------------------------------------------------------- /area/Precuneus_L.nii: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MXMX0811/pytorch-fMRI-signal-classifier/9a402ae037f04c5fa53fb1017e96f66a32d078b0/area/Precuneus_L.nii -------------------------------------------------------------------------------- /area/Precuneus_R.nii: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MXMX0811/pytorch-fMRI-signal-classifier/9a402ae037f04c5fa53fb1017e96f66a32d078b0/area/Precuneus_R.nii -------------------------------------------------------------------------------- /area/SupraMarginal_L.nii: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MXMX0811/pytorch-fMRI-signal-classifier/9a402ae037f04c5fa53fb1017e96f66a32d078b0/area/SupraMarginal_L.nii -------------------------------------------------------------------------------- /area/SupraMarginal_R.nii: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MXMX0811/pytorch-fMRI-signal-classifier/9a402ae037f04c5fa53fb1017e96f66a32d078b0/area/SupraMarginal_R.nii -------------------------------------------------------------------------------- /multisc_CNN-LSTM_iftarg_AllBrain.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import csv 5 | import torch 6 | import numpy as np 7 | import random 8 | import torch.nn.functional as F 9 | import time 10 | import nibabel as nib 11 | from torch import nn 12 | from torch import optim 13 | from torch.autograd import Variable 14 | from imblearn.over_sampling import SMOTE,ADASYN,RandomOverSampler 15 | 16 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 17 | print(device) 18 | 19 | class CnnLSTM(nn.Module): 20 | def __init__(self): 21 | super(CnnLSTM, self).__init__() 22 | self.conv1 = torch.nn.Sequential( 23 | torch.nn.Conv3d( 24 | in_channels=1, 25 | out_channels=4, 26 | kernel_size=4, 27 | padding=0), 28 | nn.BatchNorm3d(4), 29 | torch.nn.ReLU(), 30 | torch.nn.MaxPool3d(4), 31 | torch.nn.Conv3d(4, 8, kernel_size=4, padding=0), 32 | 
            nn.BatchNorm3d(8),
            torch.nn.ReLU()
        )

        self.conv2 = torch.nn.Sequential(
            torch.nn.Conv3d(
                in_channels=1,
                out_channels=4,
                kernel_size=8,
                padding=0),
            nn.BatchNorm3d(4),
            torch.nn.ReLU(),
            torch.nn.MaxPool3d(4),
            torch.nn.Conv3d(4, 8, kernel_size=8, padding=0),
            nn.BatchNorm3d(8),
            torch.nn.ReLU()
        )

        self.conv3 = torch.nn.Sequential(
            torch.nn.Conv3d(
                in_channels=1,
                out_channels=4,
                kernel_size=12,
                padding=0),
            nn.BatchNorm3d(4),
            torch.nn.ReLU(),
            torch.nn.MaxPool3d(3),
            torch.nn.Conv3d(4, 8, kernel_size=12, padding=0),
            nn.BatchNorm3d(8),
            torch.nn.ReLU()
        )

        self.lstm = nn.LSTM(
            input_size=17944,
            hidden_size=512,
            num_layers=2,
            batch_first=True
        )

        self.fc = nn.Sequential(
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(128, 32),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(32, 2)
        )

    def forward(self, x):
        x1 = self.conv1(x)
        x1 = torch.flatten(x1,start_dim=1)
        x2 = self.conv2(x)
        x2 = torch.flatten(x2,start_dim=1)
        x3 = self.conv3(x)
        x3 = torch.flatten(x3,start_dim=1)
        x4 = torch.cat((x1, x2, x3),1)
        x4 = x4.reshape(16, 17944, 1)
        x4 = torch.transpose(x4, 0, 2)
        x4 = torch.transpose(x4, 1, 2)
        out, (h_n,h_c) = self.lstm(x4, None)
        out = F.relu(out)
        out = out[:, -1, :]
        out = self.fc(out)
        return out

def dataReader(sub_list, task_list):
    Input = []
    label = []
    for task in task_list:
        for sub in sub_list:
            path_label = '/home/zmx/ds002311/event/'
            path_input = '/home/zmx/ds002311/preprocessed_4D/' + task + '/'
            if sub < 10:
                num = 'sub-0' + str(sub)
            else:
                num = 'sub-' + str(sub)

            input_name = num + '_' + task
            if task == 'mot_1':
                label_name = num + '_func_' + num + '_task-mot_run-01_events'
            elif task == 'mot_2':
                label_name = num + '_func_' + num + '_task-mot_run-02_events'
            elif task == 'mot_3':
                label_name = num + '_func_' + num + '_task-mot_run-03_events'

            img = nib.load(path_input + input_name + '.nii')
            img = np.array(img.get_fdata())
            template = nib.load('/home/zmx/fMRI/Template/area/AllBrain.nii')
            template = np.array(template.get_fdata())
            for i in range(61):
                for j in range(73):
                    for k in range(61):
                        if template[i][j][k]==0:
                            img[i][j][k] = np.zeros(405)

            with open(path_label + label_name + '.tsv','rt') as csvfile:
                reader = csv.DictReader(csvfile, delimiter='\t')
                cond = [row['cond'] for row in reader]

            for i in range(len(cond)):
                if cond[i] == 'targ_easy':
                    cond[i] = 0
                elif cond[i] == 'targ_hard':
                    cond[i] = 0
                elif cond[i] == 'lure_hard':
                    cond[i] = 1

            del cond[24]  # the last time window is incomplete, so drop its label
            label.extend(cond)
            label = list(map(int,label))

            data = img[:,:,:,12:]  # start from the 13th time point, i.e. drop the first 12 volumes

            for i in range(24):  # the last time window is incomplete
                Input.append(data[:,:,:,16*i:16*i+16])

    Input = np.array(Input)
    label = np.array(label)

    max_value = np.max(Input)  # maximum value
    min_value = np.min(Input)  # minimum value
    scalar = max_value - min_value  # value range
    Input = list(map(lambda x: x / scalar, Input))  # normalization

    Input = np.array(Input)
    return Input, label


LSTM = CnnLSTM()
LSTM.to(device)
epochs = 12
batch_size = 4
# define the loss function and the optimizer
optimizer = optim.Adam(LSTM.parameters(), lr=0.00005)
criterion = nn.CrossEntropyLoss()

# training
correct = 0
total = 0
for epoch in range(epochs):
    if epoch % 5 == 0:  # learning-rate decay
        for p in optimizer.param_groups:
            p['lr'] *= 0.9
    #sub = [1,3,5,6,7,8,9,10,13,14,15,18,21,22,23]
    #task = ['mot_1','mot_2','mot_3']
    for task in [['mot_1'],['mot_2']]:
        for sub in [[1,3],[5,6]]:
            train_x, train_y = dataReader(sub,task)

            smo = RandomOverSampler(random_state=42)  # oversample to handle the class imbalance
            nsamples, nx, ny, nz, nt = train_x.shape
            d2_train_dataset = train_x.reshape((nsamples,nx*ny*nz*nt))
            train_x, train_y = smo.fit_resample(d2_train_dataset, train_y)  # fit_sample() in older imblearn releases
            train_x = train_x.reshape(len(train_x), nx, ny, nz, nt)

            state = np.random.get_state()  # shuffle inputs and labels with the same permutation
            np.random.shuffle(train_x)
            np.random.set_state(state)
            np.random.shuffle(train_y)

            train_x = torch.from_numpy(train_x)
            train_x = torch.tensor(train_x, dtype=torch.float32)
            train_y = torch.from_numpy(train_y)

            for i in range(0, len(train_x) - batch_size, batch_size):
                loss_batch = 0

                for b_x, b_y in zip(train_x[i:i+batch_size],train_y[i:i+batch_size]):
                    b_x = b_x.reshape(-1, 61, 73, 61, 16)
                    b_y = b_y.reshape(-1)
                    b_x = torch.transpose(b_x, 3, 4)
                    b_x = torch.transpose(b_x, 2, 3)
                    b_x = torch.transpose(b_x, 1, 2)
                    b_x = torch.transpose(b_x, 0, 1)
                    b_x, b_y = b_x.to(device), b_y.to(device)
                    output = LSTM(b_x)
                    loss = criterion(output,b_y)
                    loss_batch += loss

                    _, predicted = torch.max(output.data, 1)
                    total += b_y.size(0)
                    correct += (predicted == b_y).sum().item()

                loss_batch = loss_batch / batch_size
                optimizer.zero_grad()
                loss_batch.backward()
                optimizer.step()

    TimeStr = time.asctime(time.localtime(time.time()))
    print('Epoch: {} --- {}'.format(epoch, TimeStr))
    print('Train Accuracy of the model: {} %'.format(100 * correct / total))
    print('Train Loss of the model: {}'.format(loss))



# test
LSTM.eval()  # switch off dropout for evaluation (mirrors multisc_CNN_FNC_iftarg.py)
with torch.no_grad():
    print('--------test--------')
    print('--------test--------')
    print('--------test--------')
    correct = 0
    total = 0
    test_task = ['mot_1']
    test_sub = [10]
    test_x, test_y = dataReader(test_sub,test_task)

    smo = RandomOverSampler(random_state=42)
    nsamples, nx, ny, nz, nt = test_x.shape
    d2_train_dataset = test_x.reshape((nsamples,nx*ny*nz*nt))
    test_x, test_y = smo.fit_resample(d2_train_dataset, test_y)  # fit_sample() in older imblearn releases
    test_x = test_x.reshape(len(test_x), nx, ny, nz, nt)

    test_x = torch.from_numpy(test_x)
    test_x = torch.tensor(test_x, dtype=torch.float32)
    test_y = torch.from_numpy(test_y)
    for t_x, t_y in zip(test_x, test_y):
        t_x = t_x.reshape(-1, 61, 73, 61, 16)
        t_y = t_y.reshape(-1)
        t_x = torch.transpose(t_x, 3, 4)
        t_x = torch.transpose(t_x, 2, 3)
        t_x = torch.transpose(t_x, 1, 2)
        t_x = torch.transpose(t_x, 0, 1)
        t_x, t_y = t_x.to(device), t_y.to(device)
        output = LSTM(t_x)
        loss = criterion(output, t_y)  # compute the test loss here; the original printed the stale training loss
        _, predicted = torch.max(output.data, 1)
        total += t_y.size(0)
        correct += (predicted == t_y).sum().item()
    print('Test Accuracy of the model: {} %'.format(100 * correct / total))
    print('Target: {}'.format(t_y))
    print('Output: {}'.format(torch.max(output, 1)[1]))
    print('Test Loss of the model: {}'.format(loss))
--------------------------------------------------------------------------------
/multisc_CNN_FNC_iftarg.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import csv
import torch
import numpy as np
import random
import torch.nn.functional as F
import time
import nibabel as nib
import matplotlib.pyplot as plt
from math import *
from torch import nn
from torch import optim
from torch.autograd import Variable
from imblearn.over_sampling import SMOTE,ADASYN,RandomOverSampler

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

class Cnn(nn.Module):
    def __init__(self):
        super(Cnn, self).__init__()
        self.conv1 = torch.nn.Sequential(
            torch.nn.Conv2d(
                in_channels=1,
                out_channels=4,
                kernel_size=3,
                padding=0),
            nn.BatchNorm2d(4),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2),
            torch.nn.Conv2d(4, 8, kernel_size=3, padding=0),
            nn.BatchNorm2d(8),
            torch.nn.ReLU()
        )

        self.conv2 = torch.nn.Sequential(
            torch.nn.Conv2d(
                in_channels=1,
                out_channels=4,
                kernel_size=5,
                padding=0),
            nn.BatchNorm2d(4),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2),
            torch.nn.Conv2d(4, 8, kernel_size=5, padding=0),
            nn.BatchNorm2d(8),
            torch.nn.ReLU()
        )

        self.conv3 = torch.nn.Sequential(
            torch.nn.Conv2d(
                in_channels=1,
                out_channels=4,
                kernel_size=7,
                padding=0),
            nn.BatchNorm2d(4),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2),
            torch.nn.Conv2d(4, 8, kernel_size=7, padding=0),
            nn.BatchNorm2d(8),
            torch.nn.ReLU()
        )

        self.fc = nn.Sequential(
            nn.Linear(36648, 1024),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(1024, 128),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(128, 2)
        )

    def forward(self, x):
        x1 = self.conv1(x)
        x1 = torch.flatten(x1,start_dim=1)
        x2 = self.conv2(x)
        x2 = torch.flatten(x2,start_dim=1)
        x3 = self.conv3(x)
        x3 = torch.flatten(x3,start_dim=1)
        x4 = torch.cat((x1, x2, x3),1)
        out = self.fc(x4)
        return out

def dataReader(sub_list, task_list):
    Input = []
    label = []
    for task in task_list:
        for sub in sub_list:
            path_label = '/home/zmx/ds002311/event/'
            if sub < 10:
                num = 'sub-0' + str(sub)
            else:
                num = 'sub-' + str(sub)

            path_input = '/home/zmx/ds002311/FNC/' + num + '/' + task + '/'

            input_name = num + '_' + task

            if task == 'mot_1':
                label_name = num + '_func_' + num + '_task-mot_run-01_events'
            elif task == 'mot_2':
                label_name = num + '_func_' + num + '_task-mot_run-02_events'
            elif task == 'mot_3':
                label_name = num + '_func_' + num + '_task-mot_run-03_events'

            for i in range(24):
                img = np.loadtxt(open(path_input + input_name + '_' + str(i+1) + ".csv","rb"),delimiter=",",skiprows=0)
                Input.append(img)


            with open(path_label + label_name + '.tsv','rt') as csvfile:
                reader = csv.DictReader(csvfile, delimiter='\t')
                cond = [row['cond'] for row in reader]

            for i in range(len(cond)):
                if cond[i] == 'targ_easy':
                    cond[i] = 0
                elif cond[i] == 'targ_hard':
                    cond[i] = 0
                elif cond[i] == 'lure_hard':
                    cond[i] = 1
            del cond[24]  # the last time window is incomplete, so drop its label
            label.extend(cond)
            label = list(map(int,label))

    Input = np.array(Input)
    label = np.array(label)

    return Input, label


train_task = ['mot_1','mot_2']
train_sub = [1,3,5,6,7,8,9,10,13,14,15,18,21,22,23]

train_x, train_y = dataReader(train_sub,train_task)

plt.ion()  # enable interactive mode so the figure can be refreshed during training
plt.figure(1)
t_now = 0
t = []
loss_list = []
acc_list = []

CNN = Cnn()
CNN.to(device)
epochs = 20
batch_size = 16
# define the loss function and the optimizer
optimizer = optim.Adam(CNN.parameters(), lr=0.001, weight_decay=0.02)
criterion = nn.CrossEntropyLoss().to(device)

# training
correct = 0
total = 0
CNN.train(mode=True)
for epoch in range(epochs):
    if epoch % 5 == 0:  # learning-rate decay
        for p in optimizer.param_groups:
            p['lr'] *= 0.9

    # smo = RandomOverSampler(random_state=42)  # oversample to handle the class imbalance
    # smo = ADASYN(random_state=42)
    smo = SMOTE(random_state=42)
    nsamples, nx, ny = train_x.shape
    d2_train_dataset = train_x.reshape((nsamples,nx*ny))
    train_x_smo, train_y_smo = smo.fit_resample(d2_train_dataset, train_y)  # fit_sample() in older imblearn releases
    train_x_smo = train_x_smo.reshape(len(train_x_smo), nx, ny)

    state = np.random.get_state()  # shuffle inputs and labels with the same permutation
    np.random.shuffle(train_x_smo)
    np.random.set_state(state)
    np.random.shuffle(train_y_smo)

    train_x_smo = torch.from_numpy(train_x_smo)
    train_x_smo = train_x_smo.type(torch.FloatTensor)
    train_y_smo = torch.from_numpy(train_y_smo)

    for i in range(0, len(train_x_smo) - batch_size, batch_size):
        loss_batch = 0

        for b_x, b_y in zip(train_x_smo[i:i+batch_size],train_y_smo[i:i+batch_size]):
            b_x = b_x.reshape(-1, 90, 90)
            b_x = b_x.reshape(-1, 1, 90, 90)
            b_y = b_y.reshape(-1)
            b_x, b_y = b_x.to(device), b_y.to(device)
            output = CNN(b_x)
            loss = criterion(output,b_y)
            loss_batch += loss

            _, predicted = torch.max(output.data, 1)
            total += b_y.size(0)
            correct += (predicted == b_y).sum().item()
            # print('Target: {}'.format(b_y))
            # print('Output: {}'.format(torch.max(output, 1)[1]))

        loss_batch = loss_batch / batch_size
        optimizer.zero_grad()
        loss_batch.backward()
        optimizer.step()

        TimeStr = time.asctime(time.localtime(time.time()))
        print('Epoch: {} --- {}'.format(epoch, TimeStr))
        print('Train Accuracy of the model: {} %'.format(100 * correct / total))
        print('Train Loss of this batch: {}'.format(loss_batch))

        if i % 5 == 0:  # update the plot every few batches
            t.append(t_now)
            loss_list.append(loss_batch.item())  # .item() so matplotlib receives a plain float, not a CUDA tensor
            acc_list.append(100 * correct / total)
            plt.subplot(2,1,1)
            plt.plot(t,loss_list,'-r')
            plt.title('loss',fontsize=10)
            plt.tight_layout(h_pad=1)
            plt.subplot(2,1,2)
            plt.plot(t,acc_list,'-b')
            plt.title('acc',fontsize=10)
            plt.draw()
            plt.pause(0.01)
            t_now += 5




# test
CNN.eval()
with torch.no_grad():
    print('--------test--------')
    print('--------test--------')
    print('--------test--------')
    correct = 0
    total = 0
    test_task = ['mot_3']
    test_sub = [1,3,5,6,7,8,9,10,13,14,15,18,21,22,23]
    test_x, test_y = dataReader(test_sub,test_task)

    smo = RandomOverSampler(random_state=42)
    nsamples, nx, ny = test_x.shape
    d2_train_dataset = test_x.reshape((nsamples,nx*ny))
    test_x, test_y = smo.fit_resample(d2_train_dataset, test_y)  # fit_sample() in older imblearn releases
    test_x = test_x.reshape(len(test_x), nx, ny)

    test_x = torch.from_numpy(test_x)
    test_x = torch.tensor(test_x, dtype=torch.float32)
    test_y = torch.from_numpy(test_y)
    for t_x, t_y in zip(test_x, test_y):
        t_x = t_x.reshape(-1, 90, 90)
        t_x = t_x.reshape(-1, 1, 90, 90)
        t_y = t_y.reshape(-1)
        t_x, t_y = t_x.to(device), t_y.to(device)
        output = CNN(t_x)
        loss = criterion(output,t_y)
        _, predicted = torch.max(output.data, 1)
        total += t_y.size(0)
        correct += (predicted == t_y).sum().item()
    print('Test Accuracy of the model: {} %'.format(100 * correct / total))
    print('Target: {}'.format(t_y))
    print('Output: {}'.format(torch.max(output, 1)[1]))
    print('Test Loss of the model: {}'.format(loss))
--------------------------------------------------------------------------------
/niicat.py:
--------------------------------------------------------------------------------
#!/usr/bin/python
# -*- coding: UTF-8 -*-
import os

sub = [1,3,5,6,7,8,9,10,13,14,15,18,21,22,23]
phase= ['loc','mot_1','mot_2','mot_3','prememory','postmemory']
for index in sub:
    task = 'mrcat'
    n = index
    if index < 10:
        index = '0' + str(index)
    else:
        index = str(index)

    for i in range(1,406):
        if i < 10:
            i = str(i)
            task = task + ' /home/zmx/ds002311/sub-' + index + '/func/mot_3/swrasub-' + index + '_task-mot_run-03_bold_0000'+ i +'.nii '
        elif i < 100:
            i = str(i)
            task = task + ' /home/zmx/ds002311/sub-' + index + '/func/mot_3/swrasub-' + index + '_task-mot_run-03_bold_000'+ i +'.nii '
        else:
            i = str(i)
            task = task + ' /home/zmx/ds002311/sub-' + index + '/func/mot_3/swrasub-' + index + '_task-mot_run-03_bold_00'+ i +'.nii '
    #print(task)
    task = task + '/home/zmx/ds002311/func_net/sub-' + index + '_mot_3.nii'
    os.system(task)
--------------------------------------------------------------------------------
/original_data.py:
--------------------------------------------------------------------------------
import warnings
warnings.filterwarnings('ignore')
import csv
import torch
import numpy as np
import random
import torch.nn.functional as F
import time
import nibabel as nib
from torch import nn
from torch import optim
from torch.autograd import Variable
from imblearn.over_sampling import RandomOverSampler,ADASYN,SMOTE
import matplotlib.pyplot as plt

RANDOM_SEED = 408  # seed used with a 0.05 test split
# RANDOM_SEED = 296  # seed used with a 0.1 test split


def dataReader(sub_list, task_list):
    Input = []
    label = []
    for task in task_list:
        for sub in sub_list:
            path_label = '/home/zmx/ds002311/event/'
            if sub < 10:
                num = 'sub-0' + str(sub)
            else:
                num = 'sub-' + str(sub)

            path_input = '/home/zmx/ds002311/FNC/' + num + '/' + task + '/'

            input_name = num + '_' + task

            if task == 'mot_1':
                label_name = num + '_func_' + num + '_task-mot_run-01_events'
            elif task == 'mot_2':
                label_name = num + '_func_' + num + '_task-mot_run-02_events'
            elif task == 'mot_3':
                label_name = num + '_func_' + num + '_task-mot_run-03_events'

            for i in range(24):
                img = np.loadtxt(open(path_input + input_name + '_' + str(i+1) + ".csv","rb"),delimiter=",",skiprows=0)
                Input.append(img)


            with open(path_label + label_name + '.tsv','rt') as csvfile:
                reader = csv.DictReader(csvfile, delimiter='\t')
                cond = [row['cond'] for row in reader]

            for i in range(len(cond)):
                if cond[i] == 'targ_easy':
                    cond[i] = 0
                elif cond[i] == 'targ_hard':
                    cond[i] = 0
                elif cond[i] == 'lure_hard':
                    cond[i] = 1
            # the last time window is incomplete, so drop its label before extending;
            # the original called del label[24] after extending, which removes the wrong entry for every subject after the first
            del cond[24]
            label.extend(cond)
            label = list(map(int,label))

    Input = np.array(Input)
    label = np.array(label)

    return Input, label


data_x, data_y = dataReader([1, 3, 5, 6, 7, 8, 9, 10, 13, 14, 15, 18, 21, 22, 23], ['mot_1', 'mot_2', 'mot_3'])

from sklearn.model_selection import train_test_split

train_x, test_x, train_y, test_y = train_test_split(data_x, data_y, test_size=0.05, random_state=RANDOM_SEED)
original_train = train_x
original_y = train_y
smo = SMOTE(random_state=RANDOM_SEED)  # oversample to handle the class imbalance
nsamples, nx, ny = train_x.shape
d2_train_dataset = train_x.reshape((nsamples, nx * ny))
train_x, train_y = smo.fit_resample(d2_train_dataset, train_y)  # fit_sample() in older imblearn releases
train_x = train_x.reshape(len(train_x), nx, ny)
smote_train = train_x
smote_y = train_y

plt.ion()  # enable interactive mode for live plotting
plt.figure(1)
t_now = 0
t = []
loss_list = []
acc_list = []
from sklearn.manifold import TSNE
tsne=TSNE(n_components=2, random_state=RANDOM_SEED, init ='pca')

X_train_tsne = original_train
print(X_train_tsne.shape)
train_size = X_train_tsne.shape[0]
X_train_tsne = X_train_tsne.reshape(X_train_tsne.shape[0], 90*90)
X_test_tsne = smote_train
print(X_test_tsne.shape)
test_size = X_test_tsne.shape[0]
X_test_tsne = X_test_tsne.reshape(X_test_tsne.shape[0], 90*90)
X_tsne = np.concatenate((X_train_tsne,X_test_tsne))
X_tsne = tsne.fit_transform(X_tsne)
x_min, x_max = X_tsne.min(0), X_tsne.max(0)
X_norm = (X_tsne - x_min) / (x_max - x_min)  # normalize to [0, 1]

lure_train_x = []
lure_train_y = []
targ_train_x = []
targ_train_y = []
lure_test_x = []
lure_test_y = []
targ_test_x = []
targ_test_y = []
for m in range(X_norm.shape[0]):
    if m < train_size:
        if original_y[m]:
            lure_train_x.append(X_norm[m, 0])
            lure_train_y.append(X_norm[m, 1])
        else:
            targ_train_x.append(X_norm[m, 0])
            targ_train_y.append(X_norm[m, 1])
    else:
        if smote_y[m - train_size]:
            lure_test_x.append(X_norm[m, 0])
            lure_test_y.append(X_norm[m, 1])
        else:
            targ_test_x.append(X_norm[m, 0])
            targ_test_y.append(X_norm[m, 1])

plt.subplot(1,2,1)
plt.title("2D distribution of raw FNC data")
plt.scatter(targ_train_x, targ_train_y, s=120,marker = ".", color='orange', label='targ', edgecolor='black',alpha=1)
plt.scatter(lure_train_x, lure_train_y, s=120,marker = ".", color='green', label='lure', edgecolor='black',alpha=1)
plt.legend()

plt.subplot(1,2,2)
plt.title("2D distribution of FNC data with SMOTE")
plt.scatter(targ_test_x, targ_test_y, s=120,marker = ".", color='orange', label='targ', edgecolor='black',alpha=1)
plt.scatter(lure_test_x, lure_test_y, s=120,marker = ".", color='green', label='lure', edgecolor='black',alpha=1)
plt.legend()

plt.draw()
plt.pause(0)
--------------------------------------------------------------------------------
/sequence.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import nibabel as nib
import numpy as np
import pandas as pd
import time
import multiprocessing

def extraction(templet_arr,name,path):
    img = nib.load(path + name + '.nii')
    img = np.array(img.get_fdata())
    for i in range(61):
        for j in range(73):
            for k in range(61):
                if templet_arr[i][j][k]==0:
                    img[i][j][k]*=0
    img_sum = img.sum(axis=0)
    img_sum = img_sum.sum(axis=0)
    img_sum = img_sum.sum(axis=0)
    img_sum = img_sum / (templet_arr.sum() / np.max(templet_arr))
    return img_sum

def process(region):
    global name
    TimeStr = time.asctime(time.localtime(time.time()))
    print(TimeStr + '\t' + name + '\t' + region)
    img = nib.load('/home/zmx/fMRI/Template/area/' + region + '.nii')
    img_arr = np.array(img.get_fdata())
    img_res = extraction(img_arr,name,path)
    varDict[region+'_res'] = list(img_res)

#sub = [1,3,5,6,7,8,9,10,13,14,15,18,21,22,23]
sub = [1]
area = ['Hippocampus_L','Hippocampus_R','Parahippo_L','Parahippo_R','Fusiform_L','Fusiform_R',\
       'Precuneus_L','Precuneus_R','Parietal_Inf_L','Parietal_Inf_R','Parietal_Sup_L','Parietal_Sup_R',\
       'Angular_L','Angular_R','Cingulum_Ant_L','Cingulum_Ant_R','Cingulum_Mid_L','Cingulum_Mid_R',\
       'Cingulum_Post_L','Cingulum_Post_R','Frontal_Inf_Oper_L','Frontal_Inf_Oper_R','Frontal_Inf_Orb_L','Frontal_Inf_Orb_R',\
       'Frontal_Inf_Tri_L','Frontal_Inf_Tri_R','SupraMarginal_L','SupraMarginal_R']
#task = ['loc','mot_1','mot_2','mot_3','prememory','postmemory']
task = ['loc']

varDict = multiprocessing.Manager().dict()

for t in task:
    path = '/home/zmx/ds002311/preprocessed_4D/' + t + '/'
    for i in sub:
        if i < 10:
            name = 'sub-0' + str(i) + '_' + t  # the original read 'bsub-0', which looks like a typo for 'sub-0'
        else:
            name = 'sub-' + str(i) + '_' + t

        pool = multiprocessing.Pool(processes = 2)

        for region in area:
            pool.apply_async(process, (region,))

        pool.close()
        pool.join()
        varDict = dict(varDict)
        dataframe = pd.DataFrame(varDict)
        dataframe.to_csv(path + name + ".csv",index=False)
--------------------------------------------------------------------------------
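Note: the FNC classifiers above read 24 per-trial 90 x 90 matrices per run from /home/zmx/ds002311/FNC/, but the script that produces those CSV files is not part of this listing. Purely as a hedged illustration — assuming each matrix is a Pearson functional-network-connectivity matrix computed over one 16-volume trial window of 90 ROI time series, and using a hypothetical helper fnc_matrix() with hypothetical file names — such a CSV could be generated roughly like this:

import numpy as np

def fnc_matrix(roi_ts, start, length=16):
    # roi_ts: (n_timepoints, n_rois) array of ROI mean time series
    window = roi_ts[start:start + length, :]      # one 16-volume trial window
    return np.corrcoef(window, rowvar=False)      # (n_rois, n_rois) Pearson correlation matrix

# hypothetical usage: a (405, 90) time-series file, 24 windows after skipping the first 12 volumes
roi_ts = np.loadtxt('sub-01_mot_1_roi_timeseries.csv', delimiter=',')
for trial in range(24):
    fnc = fnc_matrix(roi_ts, 12 + 16 * trial)
    np.savetxt('sub-01_mot_1_' + str(trial + 1) + '.csv', fnc, delimiter=',')

This sketch only mirrors the windowing used by the classifiers (12 volumes skipped, 16-volume trials); the repository's actual FNC computation may differ.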