├── Code ├── Loss_custom.py ├── __init__.py ├── dataloader.py ├── model.py ├── sampler.py ├── train.py ├── utils.py └── vision.py ├── LICENSE ├── README.md └── main3.png /Code/Loss_custom.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | import numpy as np 4 | import scipy.linalg 5 | import pdb 6 | from scipy.stats import pearsonr 7 | 8 | 9 | def signal_corelation(signal1, signal2): 10 | if np.all(signal1 == 0) or np.all(signal2 == 0): 11 | pcc = 0 12 | else: 13 | pcc, p_value = pearsonr(signal1, signal2) 14 | return pcc 15 | 16 | 17 | def P_MSE(gen, real): 18 | """Persent mean square error 19 | """ 20 | rate = 10.0 21 | 22 | row = gen.shape[1] 23 | col = gen.shape[2] 24 | batch = gen.shape[0] 25 | thres = torch.full([row, col], 0.01).to(gen.device).float() 26 | nonzero_real = torch.where(0.0 != real, real, thres).to(gen.device).float() 27 | mul = (gen - real).to(gen.device).float() 28 | per = torch.div(mul, nonzero_real).to(gen.device).float() 29 | loss = torch.sum(torch.abs(per)) 30 | loss = torch.div(loss, float(batch)) * rate 31 | 32 | return loss 33 | 34 | 35 | def main_PMSE(): 36 | a = torch.full([2, 3, 3], -2).float() 37 | b = torch.randint(-2, 3, (2, 3, 3)).float() 38 | loss = P_MSE(a, b) 39 | # print loss 40 | 41 | 42 | def Pearson_loss_regions(gen, real): 43 | batch = gen.shape[0] 44 | loss = 0.0 45 | eps = 1e-6 46 | for i in range(batch): 47 | region_num = gen[i].shape[0] 48 | for region in range(region_num): 49 | gen_vec = gen[i][region].to(gen.device) 50 | real_vec = real[i][region].to(gen.device) 51 | gen_mean = gen_vec - torch.mean(gen_vec) + eps 52 | real_mean = real_vec - torch.mean(real_vec) + eps 53 | r_num = torch.sum(gen_mean * real_mean) 54 | r_den = torch.sqrt(torch.sum(torch.pow(gen_mean, 2)) * torch.sum(torch.pow(real_mean, 2))) 55 | pear = r_num / r_den 56 | # pear_official = signal_corelation(gen_vec.numpy(), real_vec.numpy()) 57 | # print pear, pear_official 58 | loss = loss + torch.pow(pear - 1.0, 2) 59 | # print loss 60 | loss = torch.div(loss, float(batch)) 61 | return loss 62 | 63 | 64 | def Pearson_loss_whole(gen, real): 65 | batch = gen.shape[0] 66 | loss = 0.0 67 | eps = 1e-6 68 | rate = 1.0 69 | for i in range(batch): 70 | gen_vec = gen[i].view(-1).to(gen.device) 71 | real_vec = real[i].view(-1).to(gen.device) 72 | gen_mean = gen_vec - torch.mean(gen_vec) + eps 73 | real_mean = real_vec - torch.mean(real_vec) + eps 74 | r_num = torch.sum(gen_mean * real_mean) 75 | r_den = torch.sqrt(torch.sum(torch.pow(gen_mean, 2)) * torch.sum(torch.pow(real_mean, 2))) 76 | pear = r_num / r_den 77 | loss = loss + torch.pow(pear - 1.0, 2) 78 | loss = torch.div(loss, float(batch)) * rate 79 | return loss 80 | 81 | 82 | def main_Pearson_loss(): 83 | a = torch.randn(6,120,120).float() 84 | b = torch.randn(6,120,120).float() 85 | loss = Pearson_loss_whole(a, b) 86 | print(loss) 87 | 88 | a = torch.randn(6, 120, 120).float() 89 | b = torch.randn(6, 120, 120).float() 90 | loss = Pearson_loss_regions(a, b) 91 | print(loss) 92 | 93 | 94 | if __name__ == '__main__': 95 | main_Pearson_loss() 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | -------------------------------------------------------------------------------- /Code/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | -------------------------------------------------------------------------------- 
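Note on the PCC loss above: for identical inputs the Pearson term equals 1, so Pearson_loss_whole returns 0, and for differing inputs it reduces (up to the eps stabilizer) to (r - 1)^2 per sample, where r is the Pearson coefficient of the flattened matrices. The snippet below is a minimal sanity-check sketch, not part of the repository, that compares the loss against scipy.stats.pearsonr:

import torch
from scipy.stats import pearsonr
from Loss_custom import Pearson_loss_whole

a = torch.randn(1, 4, 4)
b = torch.randn(1, 4, 4)

# identical inputs: perfect correlation, so the loss is 0
print(Pearson_loss_whole(a, a.clone()).item())

# differing inputs: the loss closely matches (r - 1)^2 from scipy on the flattened matrices
r, _ = pearsonr(a.view(-1).numpy(), b.view(-1).numpy())
print(Pearson_loss_whole(a, b).item(), (r - 1.0) ** 2)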
/Code/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import os 3 | import random 4 | import os.path 5 | import numpy as np 6 | import torch 7 | from scipy.linalg import sqrtm 8 | from scipy.stats import pearsonr 9 | import matplotlib.pylab as plt 10 | import seaborn as sns 11 | 12 | import pdb 13 | 14 | 15 | def subjects_error_list(subjects, data_path, empty_list): 16 | 17 | def zero_row_col(matrix, subject, subjects_error): 18 | row_zero = np.zeros(shape=(matrix.shape[1],)) 19 | col_zero = np.zeros(shape=(matrix.shape[0],)) 20 | row_index = [] 21 | col_index = [] 22 | 23 | for row in range(matrix.shape[0]): 24 | if (matrix[row] == row_zero).all(): 25 | row_index.append(row) 26 | 27 | for col in range(matrix.shape[1]): 28 | if (matrix[:, col] == col_zero).all(): 29 | col_index.append(col) 30 | 31 | if empty_list != sorted(row_index) or empty_list != sorted(col_index): 32 | if subject not in subjects_error: 33 | subjects_error.append(subject) 34 | 35 | subjects_error = [] 36 | 37 | for subject in subjects: 38 | adj_matrix = np.loadtxt(data_path + '/' + subject + '/' + 'common_fiber_matrix.txt') 39 | func_matrix = np.loadtxt(data_path + '/' + subject + '/' + 'pcc_fmri_feature_matrix_0.txt') 40 | zero_row_col(adj_matrix, subject, subjects_error) 41 | zero_row_col(func_matrix, subject, subjects_error) 42 | 43 | return subjects_error 44 | 45 | 46 | def adj_matrix_normlize(adj): 47 | adj_log = adj + 1 48 | adj_log = np.log2(adj_log) 49 | # adj_log_norm = (adj_log - adj_log.min())/(adj_log.max()-adj_log.min()) 50 | # I = np.identity(adj.shape[0]) 51 | # adj_log_norm_diag = I-(I-1)*adj_log_norm 52 | return adj_log 53 | 54 | 55 | def load_data(data_path, empty_list): 56 | # pdb.set_trace() 57 | subjects = [sub for sub in os.listdir(data_path) if not sub.startswith('.')] 58 | subjects_error = subjects_error_list(subjects, data_path, empty_list) 59 | data = [] 60 | SubjID_list = [subject for subject in subjects if subject not in subjects_error] 61 | 62 | # pdb.set_trace() 63 | for subject in SubjID_list: 64 | adj_matrix_path = data_path + '/' + subject + '/' + 'common_fiber_matrix.txt' 65 | func_matrix_path = data_path + '/' + subject + '/' + 'pcc_fmri_feature_matrix_0.txt' 66 | data.append((subject, adj_matrix_path, func_matrix_path)) 67 | random.shuffle(data) 68 | return data 69 | 70 | 71 | def normlize_data(data, empty_list): 72 | eps = 1e-9 73 | all_funcs = [] 74 | all_adjs = [] 75 | for index in range(len(data)): 76 | subject, adj_matrix_path, func_matrix_path = data[index] 77 | 78 | adj_matrix = np.loadtxt(adj_matrix_path) 79 | adj_matrix = np.delete(adj_matrix, empty_list, axis=0) 80 | adj_matrix = np.delete(adj_matrix, empty_list, axis=1) 81 | adj_matrix = adj_matrix_normlize(adj_matrix) 82 | 83 | func_matrix = np.loadtxt(func_matrix_path) 84 | func_matrix = np.delete(func_matrix, empty_list, axis=0) 85 | func_matrix = np.delete(func_matrix, empty_list, axis=1) 86 | 87 | all_adjs.append(adj_matrix) 88 | all_funcs.append(func_matrix) 89 | 90 | all_adjs = np.stack(all_adjs) 91 | all_funcs = np.stack(all_funcs) 92 | 93 | adj_mean = all_adjs.mean((0, 1, 2), keepdims=True).squeeze(0) 94 | adj_std = all_adjs.std((0, 1, 2), keepdims=True).squeeze(0) 95 | 96 | func_mean = all_funcs.mean((0, 1, 2), keepdims=True).squeeze(0) 97 | func_std = all_funcs.std((0, 1, 2), keepdims=True).squeeze(0) 98 | 99 | return (torch.from_numpy(adj_mean) + eps, torch.from_numpy(adj_std) + eps, torch.from_numpy(func_mean) + eps, 100 | 
torch.from_numpy(func_std) + eps) 101 | 102 | 103 | 104 | class MICCAI(data.Dataset): 105 | def __init__(self, data_path, all_data, data_mean, empty_list, train=True, test=False): 106 | self.data_path = data_path 107 | self.train = train # training set or val set 108 | self.test = test 109 | self.adj_mean, self.adj_std, self.feat_mean, self.feat_std = data_mean 110 | self.empty_list = empty_list 111 | 112 | # pdb.set_trace() 113 | if self.train: 114 | self.data = all_data[:500] 115 | elif not test: 116 | self.data = all_data[500:600] 117 | else: 118 | self.data = all_data[600:] 119 | 120 | random.shuffle(self.data) 121 | 122 | def __getitem__(self, index): 123 | subject, adj_matrix_path, func_matrix_path = self.data[index] 124 | 125 | adj_matrix = np.loadtxt(adj_matrix_path) 126 | adj_matrix = np.delete(adj_matrix, self.empty_list, axis=0) 127 | adj_matrix = np.delete(adj_matrix, self.empty_list, axis=1) 128 | adj_matrix = adj_matrix_normlize(adj_matrix) 129 | 130 | func_matrix = np.loadtxt(func_matrix_path) 131 | func_matrix = np.delete(func_matrix, self.empty_list, axis=0) 132 | func_matrix = np.delete(func_matrix, self.empty_list, axis=1) 133 | 134 | adj_matrix = torch.from_numpy(adj_matrix) 135 | adj_matrix = (adj_matrix - self.adj_mean) / self.adj_std 136 | 137 | func_matrix = torch.from_numpy(func_matrix) 138 | func_matrix = (func_matrix - self.feat_mean) / self.feat_std 139 | 140 | return subject, adj_matrix, func_matrix 141 | 142 | def debug_getitem__(self, index=0): 143 | # pdb.set_trace() 144 | subject, adj_matrix_path, func_matrix_path = self.data[index] 145 | 146 | adj_matrix = np.loadtxt(adj_matrix_path) 147 | adj_matrix = np.delete(adj_matrix, self.empty_list, axis=0) 148 | adj_matrix = np.delete(adj_matrix, self.empty_list, axis=1) 149 | # adj_matrix = adj_matrix_normlize(adj_matrix) 150 | 151 | func_matrix = np.loadtxt(func_matrix_path) 152 | func_matrix = np.delete(func_matrix, self.empty_list, axis=0) 153 | func_matrix = np.delete(func_matrix, self.empty_list, axis=1) 154 | 155 | adj_matrix = torch.from_numpy(adj_matrix) 156 | func_matrix = torch.from_numpy(func_matrix) 157 | 158 | pdb.set_trace() 159 | 160 | return subject, adj_matrix, func_matrix 161 | 162 | def __len__(self): 163 | return len(self.data) 164 | 165 | 166 | def get_loader(data_path, all_data, data_mean, empty_list, training, test, batch_size=16, num_workers=4): 167 | dataset = MICCAI(data_path, all_data, data_mean, empty_list, training, test) 168 | 169 | data_loader = torch.utils.data.DataLoader(dataset=dataset, 170 | batch_size=batch_size, 171 | num_workers=num_workers) 172 | 173 | return data_loader 174 | 175 | 176 | if __name__ == '__main__': 177 | data_path = './data/HCP_1064_matrix_atlas2' 178 | all_data = load_data(data_path) 179 | data_mean = normlize_data(all_data) 180 | # get_common_adj(all_data, data_mean) 181 | # dataset = MICCAI(data_path, all_data) 182 | # for i in range(len(dataset)): 183 | # x, y, z = dataset.debug_getitem__(i) 184 | 185 | 186 | 187 | 188 | -------------------------------------------------------------------------------- /Code/model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | 8 | from torch.nn.parameter import Parameter 9 | from torch.nn.modules.module import Module 10 | import pdb 11 | 12 | 13 | class BatchNorm(nn.Module): 14 | def __init__(self, num_features, eps=1e-05, 
momentum=0.1, affine=True, track_running_stats=True): 15 | super(BatchNorm, self).__init__() 16 | 17 | self.batchnorm_layer = nn.BatchNorm1d(num_features, eps, momentum, affine, track_running_stats) 18 | 19 | def forward(self, x): 20 | x = x.permute(0, 2, 1) 21 | x = self.batchnorm_layer(x) 22 | x = x.permute(0, 2, 1) 23 | return x 24 | 25 | 26 | class BatchNormLinear(nn.Module): 27 | def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True): 28 | super(BatchNormLinear, self).__init__() 29 | 30 | self.batchnorm_layer = nn.BatchNorm1d(num_features, eps, momentum, affine, track_running_stats) 31 | 32 | def forward(self, x): 33 | x = self.batchnorm_layer(x) 34 | return x 35 | 36 | 37 | class BatchNormAdj(nn.Module): 38 | def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True): 39 | super(BatchNormAdj, self).__init__() 40 | self.batchnorm_layer = nn.BatchNorm1d(num_features, eps, momentum, affine, track_running_stats) 41 | 42 | def forward(self, x): 43 | batch_size = x.size(0) 44 | num_region = x.size(1) 45 | x = x.contiguous().view(batch_size, -1) 46 | x = self.batchnorm_layer(x) 47 | x = x.contiguous().view(batch_size, num_region, -1) 48 | return x 49 | 50 | 51 | class LayerNorm(nn.Module): 52 | def __init__(self, num_features, eps=1e-05, elementwise_affine=True): 53 | super(LayerNorm, self).__init__() 54 | 55 | self.LayerNorm = nn.LayerNorm(num_features, eps, elementwise_affine)#num_features = (input.size()[1:]) 56 | 57 | def forward(self, x): 58 | x = self.LayerNorm(x) 59 | return x 60 | 61 | 62 | class GraphConvolution(Module): 63 | 64 | def __init__(self, in_features, out_features, bias=True): 65 | super(GraphConvolution, self).__init__() 66 | self.in_features = in_features 67 | self.out_features = out_features 68 | self.weight = Parameter(torch.FloatTensor(in_features, out_features)) 69 | if bias: 70 | self.bias = Parameter(torch.FloatTensor(out_features)) 71 | else: 72 | self.register_parameter('bias', None) 73 | self.reset_parameters() 74 | 75 | def reset_parameters(self): 76 | stdv = 1. 
/ math.sqrt(self.weight.size(1)) 77 | self.weight.data.uniform_(-stdv, stdv) 78 | if self.bias is not None: 79 | self.bias.data.uniform_(-stdv, stdv) 80 | 81 | def forward(self, input, adj): 82 | support = torch.matmul(input, self.weight) 83 | output = torch.einsum('bij,bjd->bid', [adj, support]) 84 | if self.bias is not None: 85 | return output + self.bias 86 | else: 87 | return output 88 | 89 | def __repr__(self): 90 | return self.__class__.__name__ + ' (' \ 91 | + str(self.in_features) + ' -> ' \ 92 | + str(self.out_features) + ')' 93 | 94 | 95 | class GCNGenerator(nn.Module): 96 | def __init__(self, in_feature, out1_feature, out2_feature, out3_feature, dropout): 97 | super(GCNGenerator, self).__init__() 98 | 99 | self.gc1 = GraphConvolution(in_feature, in_feature) 100 | self.LayerNorm1 = LayerNorm([in_feature, in_feature]) 101 | 102 | self.gc2_01 = GraphConvolution(in_feature, int(in_feature*2)) 103 | self.LayerNorm2_01 = LayerNorm([in_feature, int(in_feature*2)]) 104 | self.gc2_12 = GraphConvolution(int(in_feature*2), in_feature) 105 | self.LayerNorm2_12 = LayerNorm([in_feature, in_feature]) 106 | 107 | self.gc3_01 = GraphConvolution(in_feature, int(in_feature/2)) 108 | self.LayerNorm3_01 = LayerNorm([in_feature, int(in_feature/2)]) 109 | self.gc3_13 = GraphConvolution(int(in_feature/2), in_feature) 110 | self.LayerNorm3_13 = LayerNorm([in_feature, in_feature]) 111 | 112 | # the theta: can compare different initializations 113 | self.weight = Parameter(torch.FloatTensor([0.0, 0.0, 0.0, ])) 114 | 115 | 116 | def forward(self, topo, funcs, batchSize, isTest=False): 117 | 118 | topo = funcs # to compare with different updating methods: topo != funcs 119 | 120 | x1 = self.gc1(funcs, topo) 121 | x1 = self.LayerNorm1(x1) 122 | x1 = F.leaky_relu(x1, 0.05, inplace=True) 123 | 124 | x2 = self.gc2_01(funcs, topo) 125 | x2 = self.LayerNorm2_01(x2) 126 | x2 = F.leaky_relu(x2, 0.05, inplace=True) 127 | x2 = self.gc2_12(x2, topo) 128 | x2 = self.LayerNorm2_12(x2) 129 | x2 = F.leaky_relu(x2, 0.05, inplace=True) 130 | 131 | x3 = self.gc3_01(funcs, topo) 132 | x3 = self.LayerNorm3_01(x3) 133 | x3 = F.leaky_relu(x3, 0.05, inplace=True) 134 | x3 = self.gc3_13(x3, topo) 135 | x3 = self.LayerNorm3_13(x3) 136 | x3 = F.leaky_relu(x3, 0.05, inplace=True) 137 | 138 | x = self.weight[0]*x1 + self.weight[1]*x2 + self.weight[2]*x3 139 | outputs = x + torch.transpose(x, 1, 2) 140 | if isTest is True: 141 | return outputs.squeeze().unsqueeze(0) 142 | else: 143 | return outputs.squeeze() 144 | 145 | 146 | class CNNGenerator1(nn.Module): 147 | def __init__(self, in_feature, out1_feature, out2_feature, out3_feature, dropout): 148 | super(CNNGenerator1, self).__init__() 149 | self.conv1 = nn.Conv2d(2, 4, kernel_size=15, stride=1, padding=7) 150 | self.conv2 = nn.Conv2d(2, 8, kernel_size=5, stride=1, padding=2) 151 | self.dropout = nn.Dropout(dropout) 152 | self.weight1 = Parameter(torch.FloatTensor([0.25, 0.25, 0.25, 0.25, ])) 153 | self.weight2 = Parameter(torch.FloatTensor([0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, ])) 154 | self.LayerNorm = LayerNorm([in_feature, in_feature]) 155 | 156 | def forward(self, topo, func_matrix, batchSize, isTest=False): 157 | # topo not used in CNN based generator 158 | func_matrix = func_matrix.unsqueeze(1) 159 | x0 = torch.cat((func_matrix, func_matrix), 1) 160 | 161 | x1 = self.conv1(x0) 162 | x1 = self.LayerNorm(x1) 163 | x1 = F.leaky_relu(x1, 0.5, inplace=True) 164 | x1 = self.weight1[0] * x1[:, 0] + self.weight1[1] * x1[:, 1] + self.weight1[2] * x1[:, 2] + 
self.weight1[3] * x1[:, 3] 165 | 166 | x2 = self.conv2(x0) 167 | x2 = self.LayerNorm(x2) 168 | x2 = F.leaky_relu(x2, 0.05, inplace=True) 169 | x2 = self.weight2[0] * x2[:, 0] + self.weight2[1] * x2[:, 1] + self.weight2[2] * x2[:, 2] + self.weight2[3] * x2[:, 3] + self.weight2[4] * x2[:, 4] + self.weight2[5] * x2[:, 5] + self.weight2[6] * x2[:, 6] + self.weight2[7] * x2[:, 7] 170 | 171 | x = x1 + x2 172 | outputs = F.leaky_relu(x, 0.5, inplace=True) 173 | outputs = outputs + torch.transpose(outputs, 1, 2) 174 | if isTest is True: 175 | return outputs.squeeze().unsqueeze(0) 176 | else: 177 | return outputs.squeeze() 178 | 179 | 180 | class CNNGenerator2(nn.Module): 181 | def __init__(self, in_feature, out1_feature, out2_feature, out3_feature, dropout): 182 | super(CNNGenerator2, self).__init__() 183 | 184 | self.features = nn.Sequential( 185 | nn.Conv2d(1, 64, kernel_size=11, stride=4, padding=2), 186 | nn.ReLU(inplace=True), 187 | nn.MaxPool2d(kernel_size=3, stride=2), 188 | nn.Conv2d(64, 192, kernel_size=5, padding=2), 189 | nn.ReLU(inplace=True), 190 | nn.MaxPool2d(kernel_size=3, stride=2), 191 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 192 | nn.ReLU(inplace=True), 193 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 194 | nn.ReLU(inplace=True), 195 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 196 | nn.ReLU(inplace=True), 197 | nn.MaxPool2d(kernel_size=3, stride=2), 198 | ) 199 | 200 | self.dropout = nn.Dropout(dropout) 201 | self.Linear1 = nn.Linear(3*3*256, in_feature) 202 | self.Linear2 = nn.Linear(3*3*256, in_feature) 203 | # atlas1: 3*3*256; atlas2: 256 204 | 205 | 206 | def forward(self, topo, func_matrix, batchSize, isTest=False): 207 | 208 | x = func_matrix.unsqueeze(1) 209 | x = self.features(x) 210 | 211 | x = torch.flatten(x, 1) 212 | 213 | x = torch.bmm(self.Linear1(x).unsqueeze(2), self.Linear2(x).unsqueeze(1)) 214 | 215 | outputs = x + torch.transpose(x, 1, 2) 216 | if isTest is True: 217 | return outputs.squeeze().unsqueeze(0) 218 | else: 219 | return outputs.squeeze() 220 | 221 | 222 | class Discriminator(nn.Module): 223 | def __init__(self, in_feature, out1_feature, out2_feature, out3_feature, dropout): 224 | super(Discriminator, self).__init__() 225 | 226 | self.gc1 = GraphConvolution(in_feature, out1_feature) 227 | self.batchnorm1 = BatchNorm(out1_feature) 228 | self.gc2 = GraphConvolution(out1_feature, out2_feature) 229 | self.batchnorm2 = BatchNorm(out2_feature) 230 | self.gc3 = GraphConvolution(out2_feature, out3_feature) 231 | self.batchnorm3 = BatchNorm(out3_feature) 232 | self.batchnorm4 = BatchNormLinear(1024) 233 | self.dropout = dropout 234 | self.Linear1 = nn.Linear(out3_feature * in_feature, 1024) # 148 for atlas1 and 68 for atlas2 235 | self.dropout = dropout 236 | self.Linear2 = nn.Linear(1024, 1) 237 | 238 | 239 | def batch_eye(self, size): 240 | batch_size = size[0] 241 | n = size[1] 242 | I = torch.eye(n).unsqueeze(0) 243 | I = I.repeat(batch_size, 1, 1) 244 | return I 245 | 246 | def forward(self, adj_matrix, batchSize, isTest=False): 247 | x = self.batch_eye(adj_matrix.shape).to(adj_matrix.device).float() 248 | 249 | x = self.gc1(x, adj_matrix) 250 | x = nn.LeakyReLU(0.2, True)(x) 251 | x = self.batchnorm1(x) 252 | 253 | x = self.gc2(x, adj_matrix) 254 | x = nn.LeakyReLU(0.2, True)(x) 255 | x = self.batchnorm2(x) 256 | 257 | x = self.gc3(x, adj_matrix) 258 | x = nn.LeakyReLU(0.2, True)(x) 259 | 260 | x = F.dropout(x, self.dropout, training=self.training) 261 | x = x.contiguous().view(batchSize, -1) 262 | x = self.Linear1(x) 263 | x = 
nn.LeakyReLU(0.2, True)(x) 264 | 265 | x = F.dropout(x, self.dropout, training=self.training) 266 | x = self.Linear2(x) 267 | x = torch.sigmoid(x) 268 | x = x.contiguous().view(batchSize, -1) 269 | outputs = x 270 | return outputs.squeeze() 271 | -------------------------------------------------------------------------------- /Code/sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torch.utils.data 4 | import random 5 | 6 | 7 | class BalancedBatchSampler(torch.utils.data.sampler.Sampler): 8 | def __init__(self, dataset): 9 | self.dataset = {} 10 | self.balanced_max = 0 11 | # Save all the indices for all the classes 12 | for idx in range(0, len(dataset)): 13 | label = self._get_label(dataset, idx) 14 | if label not in self.dataset: 15 | self.dataset[label] = [] 16 | self.dataset[label].append(idx) 17 | self.balanced_max = len(self.dataset[label]) \ 18 | if len(self.dataset[label]) > self.balanced_max else self.balanced_max 19 | 20 | # Oversample the classes with fewer elements than the max 21 | for label in self.dataset: 22 | while len(self.dataset[label]) < self.balanced_max: 23 | self.dataset[label].append(random.choice(self.dataset[label])) 24 | 25 | self.keys = list(self.dataset.keys()) 26 | self.currentkey = 0 27 | 28 | def __iter__(self): 29 | while len(self.dataset[self.keys[self.currentkey]]) > 0: 30 | yield self.dataset[self.keys[self.currentkey]].pop() 31 | self.currentkey = (self.currentkey + 1) % len(self.keys) 32 | 33 | 34 | def _get_label(self, dataset, idx): 35 | return str(dataset.data[idx][3]) 36 | 37 | def __len__(self): 38 | return self.balanced_max*len(self.keys) -------------------------------------------------------------------------------- /Code/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import torch 4 | from dataloader import get_loader, load_data, normlize_data 5 | from model import GCNGenerator, Discriminator, CNNGenerator2, CNNGenerator1 6 | from utils import * 7 | 8 | from tensorboardX import SummaryWriter 9 | import shutil 10 | import pdb 11 | from Loss_custom import Pearson_loss_regions, Pearson_loss_whole 12 | 13 | # Device configuration 14 | # device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu") 15 | 16 | 17 | def main(args): 18 | device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu") 19 | # Create model directory 20 | 21 | if os.path.exists(args.runs_path): 22 | shutil.rmtree(args.runs_path) 23 | os.makedirs(args.runs_path) 24 | 25 | if os.path.exists(args.model_path): 26 | shutil.rmtree(args.model_path) 27 | os.makedirs(args.model_path) 28 | 29 | if os.path.exists(args.results_path): 30 | shutil.rmtree(args.results_path) 31 | os.makedirs(args.results_path) 32 | 33 | if os.path.exists(args.middle_results_path): 34 | shutil.rmtree(args.middle_results_path) 35 | os.makedirs(args.middle_results_path) 36 | 37 | if args.atlas == 'atlas1': 38 | empty_list = [41, 116] 39 | elif args.atlas == 'atlas2': 40 | empty_list = [3, 38] 41 | 42 | all_data = load_data(args.data_path, empty_list) 43 | data_mean = normlize_data(all_data, empty_list) 44 | 45 | exp_prec = [] 46 | for exp_num in range(2): 47 | # Build the models 48 | if os.path.exists(args.middle_results_path + '/' + str(exp_num)): 49 | shutil.rmtree(args.middle_results_path + '/' + str(exp_num)) 50 | os.makedirs(args.middle_results_path + '/' + str(exp_num)) 51 | 52 | 
generator = GCNGenerator(args.input_size, args.out1_feature, args.out2_feature, args.out3_feature, 0.6) 53 | discriminator = Discriminator(args.input_size, args.out1_feature, args.out2_feature, args.out3_feature, 0.6) 54 | 55 | test_generator = GCNGenerator(args.input_size, args.out1_feature, args.out2_feature, args.out3_feature, 0.6) 56 | test_discriminator = Discriminator(args.input_size, args.out1_feature, args.out2_feature, args.out3_feature, 57 | 0.6) 58 | 59 | generator = generator.to(device) 60 | discriminator = discriminator.to(device) 61 | 62 | test_generator = test_generator.to(device) 63 | test_discriminator = test_discriminator.to(device) 64 | 65 | adversarial_loss = torch.nn.BCELoss() 66 | # adversarial_loss = torch.nn.MSELoss() 67 | 68 | optimizer_G = torch.optim.Adam(generator.parameters(), lr=args.learning_rate, betas=(args.beta1, args.beta2), 69 | weight_decay=args.weight_decay) 70 | scheduler_G = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_G, 'max', patience=args.patience) 71 | 72 | optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=args.learning_rate, 73 | betas=(args.beta1, args.beta2), weight_decay=args.weight_decay) 74 | scheduler_D = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_D, 'max', patience=args.patience) 75 | 76 | exp_log_dir = os.path.join(args.runs_path, str(exp_num)) 77 | if not os.path.isdir(exp_log_dir): 78 | os.makedirs(exp_log_dir) 79 | writer = SummaryWriter(log_dir=exp_log_dir) 80 | 81 | best_prec = 1000000 82 | for epoch in range(args.num_epochs): 83 | train_data_loader = get_loader(args.data_path, all_data, data_mean, empty_list, True, False, args.batch_size, 84 | num_workers=args.num_workers) 85 | val_data_loader = get_loader(args.data_path, all_data, data_mean, empty_list, False, False, args.batch_size, 86 | num_workers=args.num_workers) 87 | 88 | if epoch < args.pre_epochs: 89 | pre_train(args, train_data_loader, generator, discriminator, adversarial_loss, optimizer_G, 90 | optimizer_D, writer, epoch, exp_num, device) 91 | else: 92 | optimizer_G = torch.optim.Adam(generator.parameters(), lr=args.learning_rate / 1, 93 | betas=(args.beta1, args.beta2), weight_decay=args.weight_decay) 94 | optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=args.learning_rate / 1, 95 | betas=(args.beta1, args.beta2), weight_decay=args.weight_decay) 96 | 97 | train(args, train_data_loader, generator, discriminator, adversarial_loss, optimizer_G, 98 | optimizer_D, writer, epoch, exp_num, device) 99 | prec = validate(args, val_data_loader, generator, discriminator, adversarial_loss, writer, epoch, 100 | exp_num, device) 101 | scheduler_G.step(prec) 102 | scheduler_D.step(prec) 103 | 104 | # Save the model checkpoints 105 | if prec < best_prec: 106 | best_prec = prec 107 | torch.save(generator.state_dict(), os.path.join( 108 | args.model_path, 'GAN_generator-{}.ckpt'.format(exp_num))) 109 | torch.save(discriminator.state_dict(), os.path.join( 110 | args.model_path, 'GAN_discriminator-{}.ckpt'.format(exp_num))) 111 | 112 | test_data_loader = get_loader(args.data_path, all_data, data_mean, empty_list, False, True, 1, num_workers=args.num_workers) 113 | test_prec = test(args, test_data_loader, test_generator, test_discriminator, adversarial_loss, writer, epoch, exp_num, device) 114 | print("Test Prec:", test_prec) 115 | exp_prec.append(test_prec) 116 | writer.close() 117 | 118 | del generator 119 | del discriminator 120 | del test_generator 121 | del test_discriminator 122 | del adversarial_loss 123 | del optimizer_G 124 | del 
optimizer_D 125 | 126 | del scheduler_G 127 | del scheduler_D 128 | del writer 129 | print(exp_prec) 130 | 131 | 132 | def pre_train(args, data_loader, generator, discriminator, adversarial_loss, optimizer_G, optimizer_D, writer, 133 | epoch, exp_num, device): 134 | generator.train() 135 | 136 | if epoch % 50 == 0: 137 | for name, param in generator.named_parameters(): 138 | if name == 'weight': 139 | print(param) 140 | 141 | batch_time = AverageMeter() # forward prop. + back prop. time 142 | data_time = AverageMeter() # data loading time 143 | losses = AverageMeter() # loss (per word decoded) 144 | start = time.time() 145 | for i, (subject, adj_matrix, func_matrix) in enumerate(data_loader): 146 | # pdb.set_trace() 147 | batchSize = adj_matrix.shape[0] 148 | data_time.update(time.time() - start) 149 | 150 | funcs = func_matrix.to(device).float() 151 | adjs_real = adj_matrix.to(device).float() 152 | 153 | adjs_gen = generator(funcs, funcs, batchSize, isTest=False) 154 | 155 | if epoch > 1: 156 | topo = adjs_gen 157 | adjs_gen = generator(topo, funcs, batchSize, isTest=False) 158 | 159 | loss = torch.nn.functional.mse_loss(adjs_gen, adjs_real) * args.nodes + Pearson_loss_regions(adjs_gen, 160 | adjs_real) + Pearson_loss_whole( 161 | adjs_gen, adjs_real) 162 | loss.backward() 163 | optimizer_G.step() 164 | 165 | losses.update(loss.item()) 166 | 167 | batch_time.update(time.time() - start) 168 | 169 | start = time.time() 170 | 171 | if epoch % 25 == 0: 172 | adjs_gen_middle = adjs_gen.squeeze().cpu().detach().numpy() 173 | for j in range(batchSize): 174 | np.savetxt(args.middle_results_path + '/' + str(exp_num) + '/' + subject[j] + '_adjs-gen_exp(' + str( 175 | exp_num) + ')_epoch(' + str(epoch) + ').txt', 176 | adjs_gen_middle[j], fmt='%.9f') 177 | 178 | if epoch == 0: 179 | adjs_real_middle = adjs_real.squeeze().cpu().detach().numpy() 180 | for j in range(batchSize): 181 | np.savetxt(args.middle_results_path + '/' + str(exp_num) + '/' + subject[j] + '_adjs-real_exp(' + str( 182 | exp_num) + ').txt', adjs_real_middle[j], fmt='%.9f') 183 | 184 | if i % args.log_step == 0: 185 | print('Pre_Train Epoch: [{0}][{1}/{2}]\t' 186 | 'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 187 | 'Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t' 188 | 'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(epoch, i, len(data_loader), 189 | batch_time=batch_time, 190 | data_time=data_time, 191 | loss=losses, )) 192 | writer.add_histogram('GAN_pre_train_loss', losses.avg, epoch) 193 | 194 | 195 | def train(args, data_loader, generator, discriminator, adversarial_loss, optimizer_G, optimizer_D, writer, epoch, 196 | exp_num, device): 197 | # Train the models 198 | # pdb.set_trace() 199 | generator.train() 200 | discriminator.train() 201 | 202 | if epoch % 50 == 0: 203 | for name, param in generator.named_parameters(): 204 | if name == 'weight': 205 | print(param) 206 | 207 | batch_time = AverageMeter() # forward prop. + back prop. 
time 208 | data_time = AverageMeter() # data loading time 209 | G_losses = AverageMeter() # loss (per word decoded) 210 | D_losses = AverageMeter() # loss (per word decoded) 211 | gen_top5accs = AverageMeter() 212 | real_top5accs = AverageMeter() 213 | fake_top5accs = AverageMeter() 214 | 215 | start = time.time() 216 | for i, (subject, adj_matrix, func_matrix) in enumerate(data_loader): 217 | # pdb.set_trace() 218 | batchSize = adj_matrix.shape[0] 219 | data_time.update(time.time() - start) 220 | 221 | funcs = func_matrix.to(device).float() 222 | adjs_real = adj_matrix.to(device).float() 223 | 224 | valid = torch.ones((adjs_real.size(0),), dtype=torch.float).to(device) 225 | fake = torch.zeros((adjs_real.size(0),), dtype=torch.float).to(device) 226 | 227 | optimizer_D.zero_grad() 228 | discriminator.zero_grad() 229 | score_real = discriminator(adjs_real, batchSize) 230 | real_loss = adversarial_loss(score_real, valid) 231 | # real_loss = adversarial_loss(score_real, torch.randint(0,2,(adjs_real.size(0), 1)).to(device).float()) 232 | real_acc = float((score_real.view(-1) > 0.5).sum().item()) / float(batchSize) 233 | 234 | adjs_gen = generator(funcs, funcs, batchSize, isTest=False) 235 | 236 | if epoch > 1: 237 | topo = adjs_gen 238 | adjs_gen = generator(topo, funcs, batchSize, isTest=False) 239 | 240 | score_fake = discriminator(adjs_gen.detach(), batchSize) 241 | fake_loss = adversarial_loss(score_fake, fake) 242 | # fake_loss = adversarial_loss(score_fake, torch.randint(0,2,(adjs_gen.size(0), 1)).to(device).float()) 243 | d_loss = (real_loss + fake_loss) * 0.5 244 | 245 | fake_acc = float((score_fake.view(-1) < 0.5).sum().item()) / float(batchSize) 246 | 247 | if ((epoch - 200 + 5) % 5 == 0 and real_acc < 0.8) or ((epoch - 200 + 5) % 5 == 0 and fake_acc < 0.5): 248 | d_loss.backward() 249 | optimizer_D.step() 250 | # fake_acc = float((score_fake.view(-1) < 0.5).sum().item())/float(batchSize) 251 | 252 | optimizer_G.zero_grad() 253 | generator.zero_grad() 254 | g_score = discriminator(adjs_gen, batchSize) 255 | g_loss = adversarial_loss(g_score, valid) + torch.nn.functional.mse_loss(adjs_gen,adjs_real) * args.nodes + Pearson_loss_regions(adjs_gen, adjs_real) + Pearson_loss_whole(adjs_gen, adjs_real) 256 | # if itern%5 == 0: 257 | g_loss.backward() 258 | optimizer_G.step() 259 | gen_acc = float((g_score.view(-1) >= 0.5).sum().item()) / float(batchSize) 260 | 261 | # top5 = accuracy(torch.sigmoid(scores), labels, 1) 262 | # pdb.set_trace() 263 | 264 | G_losses.update(g_loss.item()) 265 | D_losses.update(d_loss.item()) 266 | 267 | real_top5accs.update(real_acc) 268 | fake_top5accs.update(fake_acc) 269 | gen_top5accs.update(gen_acc) 270 | 271 | batch_time.update(time.time() - start) 272 | 273 | start = time.time() 274 | 275 | if epoch % 25 == 0: 276 | adjs_gen_middle = adjs_gen.squeeze().cpu().detach().numpy() 277 | for j in range(batchSize): 278 | np.savetxt(args.middle_results_path + '/' + str(exp_num) + '/' + subject[j] + '_adjs-gen_exp(' + str( 279 | exp_num) + ')_epoch(' + str(epoch) + ').txt', 280 | adjs_gen_middle[j], fmt='%.9f') 281 | 282 | if epoch == 0: 283 | adjs_real_middle = adjs_real.squeeze().cpu().detach().numpy() 284 | for j in range(batchSize): 285 | np.savetxt(args.middle_results_path + '/' + str(exp_num) + '/' + subject[j] + '_adjs-real_exp(' + str( 286 | exp_num) + ').txt', adjs_real_middle[j], fmt='%.9f') 287 | 288 | # Print log info 289 | if i % args.log_step == 0: 290 | print('Train Epoch: [{0}][{1}/{2}]\t' 291 | 'Batch Time {batch_time.val:.3f} 
({batch_time.avg:.3f})\t' 292 | 'G_Loss {G_loss.val:.4f} ({G_loss.avg:.4f})\t' 293 | 'D_Loss {D_loss.val:.4f} ({D_loss.avg:.4f})\t' 294 | 'Gen Accuracy {gen_top5.val:.3f} ({gen_top5.avg:.3f})\t' 295 | 'Real Accuracy {real_top5.val:.3f} ({real_top5.avg:.3f})\t' 296 | 'Fake Accuracy {fake_top5.val:.3f} ({fake_top5.avg:.3f})'.format(epoch, i, len(data_loader), 297 | batch_time=batch_time, 298 | G_loss=G_losses, 299 | D_loss=D_losses, 300 | gen_top5=gen_top5accs, 301 | real_top5=real_top5accs, 302 | fake_top5=fake_top5accs, )) 303 | writer.add_histogram('GAN_train_G_loss', G_losses.avg, epoch) 304 | writer.add_histogram('GAN_train_D_loss', D_losses.avg, epoch) 305 | writer.add_histogram('GAN_train_gen_acc', gen_top5accs.avg, epoch) 306 | writer.add_histogram('GAN_train_real_acc', real_top5accs.avg, epoch) 307 | writer.add_histogram('GAN_train_fake_acc', fake_top5accs.avg, epoch) 308 | 309 | 310 | def validate(args, data_loader, generator, discriminator, adversarial_loss, writer, epoch, exp_num, device): 311 | # Evaluate the models 312 | generator.train() 313 | discriminator.train() 314 | 315 | batch_time = AverageMeter() # forward prop. + back prop. time 316 | data_time = AverageMeter() # data loading time 317 | G_losses = AverageMeter() # loss (per word decoded) 318 | D_losses = AverageMeter() # loss (per word decoded) 319 | real_losses = AverageMeter() 320 | fake_losses = AverageMeter() 321 | gen_top5accs = AverageMeter() 322 | real_top5accs = AverageMeter() 323 | fake_top5accs = AverageMeter() 324 | D_top5accs = AverageMeter() 325 | 326 | start = time.time() 327 | for i, (subject, adj_matrix, func_matrix) in enumerate(data_loader): 328 | batchSize = adj_matrix.shape[0] 329 | 330 | data_time.update(time.time() - start) 331 | 332 | funcs = func_matrix.to(device).float() 333 | adjs_real = adj_matrix.to(device).float() 334 | 335 | valid = torch.ones((adjs_real.size(0),), dtype=torch.float).to(device) 336 | fake = torch.zeros((adjs_real.size(0),), dtype=torch.float).to(device) 337 | 338 | adjs_gen = generator(funcs, funcs, batchSize, isTest=False) 339 | 340 | if epoch > 1: 341 | topo = adjs_gen 342 | adjs_gen = generator(topo, funcs, batchSize, isTest=False) 343 | 344 | # Forward, backward and optimize 345 | with torch.no_grad(): 346 | score_real = discriminator(adjs_real, batchSize) 347 | score_fake = discriminator(adjs_gen.detach(), batchSize) 348 | g_score = discriminator(adjs_gen, batchSize) 349 | g_loss = adversarial_loss(g_score, valid) + torch.nn.functional.mse_loss(adjs_gen, 350 | adjs_real) * args.nodes + Pearson_loss_regions( 351 | adjs_gen, adjs_real) + Pearson_loss_whole(adjs_gen, adjs_real) 352 | real_loss = adversarial_loss(score_real, valid) 353 | fake_loss = adversarial_loss(score_fake, fake) 354 | d_loss = (real_loss + fake_loss) * 0.5 355 | # loss = criterion(torch.sigmoid(scores), onehot_labels) 356 | 357 | real_acc = float((score_real.view(-1) > 0.5).sum().item()) / float(batchSize) 358 | fake_acc = float((score_fake.view(-1) < 0.5).sum().item()) / float(batchSize) 359 | D_acc = (real_acc + fake_acc) / 2.0 360 | gen_acc = float((g_score.view(-1) >= 0.5).sum().item()) / float(batchSize) 361 | 362 | G_losses.update(g_loss.item()) 363 | D_losses.update(d_loss.item()) 364 | real_losses.update(real_loss.item()) 365 | fake_losses.update(fake_loss.item()) 366 | 367 | real_top5accs.update(real_acc) 368 | fake_top5accs.update(fake_acc) 369 | D_top5accs.update(D_acc) 370 | gen_top5accs.update(gen_acc) 371 | 372 | batch_time.update(time.time() - start) 373 | 374 | start = time.time() 
375 | 376 | if epoch % 25 == 0: 377 | adjs_gen_middle = adjs_gen.squeeze().cpu().detach().numpy() 378 | for j in range(batchSize): 379 | np.savetxt(args.middle_results_path + '/' + str(exp_num) + '/' + subject[j] + '_adjs-gen_exp(' + str( 380 | exp_num) + ')_epoch(' + str(epoch) + ').txt', 381 | adjs_gen_middle[j], fmt='%.9f') 382 | 383 | if epoch == 300: 384 | adjs_real_middle = adjs_real.squeeze().cpu().detach().numpy() 385 | for j in range(batchSize): 386 | np.savetxt(args.middle_results_path + '/' + str(exp_num) + '/' + subject[j] + '_adjs-real_exp(' + str( 387 | exp_num) + ').txt', adjs_real_middle[j], fmt='%.9f') 388 | 389 | print('Val Epoch: [{0}/{1}]\t' 390 | 'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 391 | 'G_Loss {G_loss.val:.4f} ({G_loss.avg:.4f})\t' 392 | 'D_Loss {D_loss.val:.4f} ({D_loss.avg:.4f})\t' 393 | 'Gen Accuracy {gen_top5.val:.3f} ({gen_top5.avg:.3f})\t' 394 | 'Real Accuracy {real_top5.val:.3f} ({real_top5.avg:.3f})\t' 395 | 'Fake Accuracy {fake_top5.val:.3f} ({fake_top5.avg:.3f})'.format(len(data_loader), len(data_loader), 396 | batch_time=batch_time, 397 | G_loss=G_losses, 398 | D_loss=D_losses, 399 | gen_top5=gen_top5accs, 400 | real_top5=real_top5accs, 401 | fake_top5=fake_top5accs, )) 402 | writer.add_histogram('GAN_val_G_loss', G_losses.avg, epoch) 403 | writer.add_histogram('GAN_val_D_loss', D_losses.avg, epoch) 404 | writer.add_histogram('GAN_val_gen_acc', gen_top5accs.avg, epoch) 405 | writer.add_histogram('GAN_val_real_acc', real_top5accs.avg, epoch) 406 | writer.add_histogram('GAN_val_fake_acc', fake_top5accs.avg, epoch) 407 | return G_losses.avg 408 | 409 | 410 | def test(args, data_loader, test_generator, test_discriminator, adversarial_loss, writer, epoch, exp_num, device): 411 | if os.path.exists(args.results_path + '/' + str(exp_num)): 412 | shutil.rmtree(args.results_path + '/' + str(exp_num)) 413 | os.makedirs(args.results_path + '/' + str(exp_num)) 414 | # Evaluate the models 415 | test_generator.load_state_dict(torch.load(os.path.join(args.model_path, 'GAN_generator-{}.ckpt'.format(exp_num)))) 416 | test_generator.eval() 417 | 418 | test_discriminator.load_state_dict( 419 | torch.load(os.path.join(args.model_path, 'GAN_discriminator-{}.ckpt'.format(exp_num)))) 420 | test_discriminator.eval() 421 | 422 | batch_time = AverageMeter() # forward prop. + back prop. 
time 423 | data_time = AverageMeter() # data loading time 424 | G_losses = AverageMeter() 425 | D_losses = AverageMeter() 426 | real_losses = AverageMeter() 427 | fake_losses = AverageMeter() 428 | D_top5accs = AverageMeter() 429 | gen_top5accs = AverageMeter() 430 | real_top5accs = AverageMeter() 431 | fake_top5accs = AverageMeter() 432 | 433 | start = time.time() 434 | for i, (subject, adj_matrix, func_matrix) in enumerate(data_loader): 435 | batchSize = adj_matrix.shape[0] 436 | 437 | data_time.update(time.time() - start) 438 | 439 | funcs = func_matrix.to(device).float() 440 | adjs_real = adj_matrix.to(device).float() 441 | 442 | valid = torch.ones((adjs_real.size(0),), dtype=torch.float).to(device) 443 | fake = torch.zeros((adjs_real.size(0),), dtype=torch.float).to(device) 444 | 445 | adjs_gen = test_generator(funcs, funcs, batchSize, isTest=True) 446 | 447 | if epoch > 1: 448 | topo = adjs_gen 449 | adjs_gen = test_generator(topo, funcs, batchSize, isTest=True) 450 | 451 | adjs_gen_final = adjs_gen.squeeze().cpu().detach().numpy() 452 | np.savetxt(args.results_path + '/' + str(exp_num) + '/' + subject[0] + '_adjs_gen_' + str(exp_num) + '.txt', 453 | adjs_gen_final, fmt='%.9f') 454 | 455 | np.savetxt(args.results_path + '/' + str(exp_num) + '/' + subject[0] + '_adjs_real_' + str(exp_num) + '.txt', 456 | adjs_real.squeeze().cpu().detach().numpy(), fmt='%.9f') 457 | 458 | # Forward, backward and optimiz 459 | with torch.no_grad(): 460 | score_real = test_discriminator(adjs_real, batchSize) 461 | score_fake = test_discriminator(adjs_gen.detach(), batchSize) 462 | g_score = test_discriminator(adjs_gen, batchSize) 463 | # pdb.set_trace() 464 | g_loss = adversarial_loss(g_score.view(-1), valid) + torch.nn.functional.mse_loss(adjs_gen, adjs_real) * args.nodes + Pearson_loss_regions(adjs_gen, adjs_real) + Pearson_loss_whole(adjs_gen, adjs_real) 465 | real_loss = adversarial_loss(score_real.view(-1), valid) 466 | fake_loss = adversarial_loss(score_fake.view(-1), fake) 467 | d_loss = (real_loss + fake_loss) * 0.5 468 | # loss = criterion(torch.sigmoid(scores), onehot_labels) 469 | 470 | real_acc = float((score_real.view(-1) > 0.5).sum().item()) / float(batchSize) 471 | fake_acc = float((score_fake.view(-1) < 0.5).sum().item()) / float(batchSize) 472 | D_acc = (real_acc + fake_acc) / 2.0 473 | gen_acc = float((g_score.view(-1) >= 0.5).sum().item()) / float(batchSize) 474 | 475 | G_losses.update(g_loss.item()) 476 | D_losses.update(d_loss.item()) 477 | real_losses.update(real_loss.item()) 478 | fake_losses.update(fake_loss.item()) 479 | 480 | real_top5accs.update(real_acc) 481 | fake_top5accs.update(fake_acc) 482 | D_top5accs.update(D_acc) 483 | gen_top5accs.update(gen_acc) 484 | batch_time.update(time.time() - start) 485 | 486 | start = time.time() 487 | 488 | print('Test Epoch: [{0}/{1}]\t' 489 | 'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 490 | 'G_Loss {G_loss.val:.4f} ({G_loss.avg:.4f})\t' 491 | 'D_Loss {D_loss.val:.4f} ({D_loss.avg:.4f})\t' 492 | 'Gen Accuracy {gen_top5.val:.3f} ({gen_top5.avg:.3f})\t' 493 | 'Real Accuracy {real_top5.val:.3f} ({real_top5.avg:.3f})\t' 494 | 'Fake Accuracy {fake_top5.val:.3f} ({fake_top5.avg:.3f})'.format(len(data_loader), len(data_loader), 495 | batch_time=batch_time, 496 | G_loss=G_losses, 497 | D_loss=D_losses, 498 | gen_top5=gen_top5accs, 499 | real_top5=real_top5accs, 500 | fake_top5=fake_top5accs, )) 501 | return D_top5accs.avg 502 | 503 | 504 | if __name__ == '__main__': 505 | parser = argparse.ArgumentParser() 506 | 
parser.add_argument('-atlas', '--atlas', type=str, default='atlas1', help='atlas to use: atlas1 or atlas2') 507 | 508 | parser.add_argument('--data_path', type=str, default='./data/HCP_1064_SC_FC_atlas1', help='path for data') 509 | 510 | parser.add_argument('--model_path', type=str, default='./models', help='path for saving trained models') 511 | parser.add_argument('--results_path', type=str, default='./results', help='path for generated adjs') 512 | parser.add_argument('--middle_results_path', type=str, default='./middle_results', help='path for generated middle adjs') 513 | parser.add_argument('--runs_path', type=str, default='./runs', help='path for runs') 514 | 515 | parser.add_argument('--log_step', type=int, default=10, help='step size for printing log info') 516 | parser.add_argument('--save_step', type=int, default=1, help='step size for saving trained models') 517 | 518 | parser.add_argument('--nodes', type=int, default=68, help='number of regions, 148 for atlas1 and 68 for atlas2') 519 | parser.add_argument('--input_size', type=int, default=68, help='dimension of input feature, 148 for atlas1 and 68 for atlas2') 520 | parser.add_argument('--out1_feature', type=int, default=68, help='dimension of discriminator gcn1, 148 for atlas1 and 68 for atlas2') 521 | parser.add_argument('--out2_feature', type=int, default=256, help='dimension of discriminator gcn2, 256 for both atlas1 and atlas2') 522 | parser.add_argument('--out3_feature', type=int, default=68, help='dimension of discriminator gcn3, 148 for atlas1 and 68 for atlas2') 523 | 524 | parser.add_argument('--num_epochs', type=int, default=1000) 525 | parser.add_argument('--pre_epochs', type=int, default=200) 526 | parser.add_argument('--batch_size', type=int, default=32) 527 | parser.add_argument('--num_workers', type=int, default=4) 528 | parser.add_argument('--learning_rate', type=float, default=0.001) 529 | parser.add_argument('--beta1', type=float, default=0.5) 530 | parser.add_argument('--beta2', type=float, default=0.999) 531 | parser.add_argument('--weight_decay', type=float, default=0.01) 532 | parser.add_argument('--patience', type=float, default=10) 533 | 534 | parser.add_argument('-gpu_id', '--gpu_id', type=int, default=0) 535 | 536 | args = parser.parse_args() 537 | 538 | if args.atlas == 'atlas1': 539 | args.data_path = './data/HCP_1064_SC_FC_atlas1' 540 | args.nodes = 148 541 | args.input_size = 148 542 | args.out1_feature = 148 543 | args.out3_feature = 148 544 | args.model_path = './atlas1/models' 545 | args.results_path = './atlas1/results' 546 | args.middle_results_path = './atlas1/middle_results' 547 | args.runs_path = './atlas1/runs' 548 | elif args.atlas == 'atlas2': 549 | args.data_path = './data/HCP_1064_SC_FC_atlas2' 550 | args.nodes = 68 551 | args.input_size = 68 552 | args.out1_feature = 68 553 | args.out3_feature = 68 554 | args.model_path = './atlas2/models' 555 | args.results_path = './atlas2/results' 556 | args.middle_results_path = './atlas2/middle_results' 557 | args.runs_path = './atlas2/runs' 558 | else: 559 | print('wrong atlas type!') 560 | exit() 561 | 562 | print(args) 563 | main(args) 564 | -------------------------------------------------------------------------------- /Code/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import h5py 4 | import json 5 | import torch 6 | import torch.nn as nn 7 | import pdb 8 | 9 | 10 | class AverageMeter(object): 11 | """ 12 | Keeps track of most recent, average, sum, and count of a metric.
13 | """ 14 | 15 | def __init__(self): 16 | self.reset() 17 | 18 | def reset(self): 19 | self.val = 0 20 | self.avg = 0 21 | self.sum = 0 22 | self.count = 0 23 | 24 | def update(self, val, n=1): 25 | self.val = val 26 | self.sum += val * n 27 | self.count += n 28 | self.avg = self.sum / self.count 29 | 30 | 31 | def adjust_learning_rate(optimizer, shrink_factor): 32 | """ 33 | Shrinks learning rate by a specified factor. 34 | :param optimizer: optimizer whose learning rate must be shrunk. 35 | :param shrink_factor: factor in interval (0, 1) to multiply learning rate with. 36 | """ 37 | 38 | print("\nDECAYING learning rate.") 39 | for param_group in optimizer.param_groups: 40 | param_group['lr'] = param_group['lr'] * shrink_factor 41 | print("The new learning rate is %f\n" % (optimizer.param_groups[0]['lr'],)) 42 | 43 | 44 | def accuracy(scores, targets, k): 45 | """ 46 | Computes top-k accuracy, from predicted and true labels. 47 | :param scores: scores from the model 48 | :param targets: true labels 49 | :param k: k in top-k accuracy 50 | :return: top-k accuracy 51 | """ 52 | # pdb.set_trace() 53 | batch_size = targets.size(0) 54 | _, ind = scores.topk(k, 1, True, True) 55 | correct = ind.eq(targets.view(-1, 1).expand_as(ind)) 56 | correct_total = correct.view(-1).float().sum() # 0D tensor 57 | return correct_total.item() * (100.0 / batch_size) 58 | 59 | 60 | def one_hot(scores): 61 | scores_onehot = [] 62 | for i in range(scores.shape[0]): 63 | if scores[i] > 0.5: 64 | scores_onehot.append([0,1]) 65 | else: 66 | scores_onehot.append([1,0]) 67 | scores_onehot = np.stack(scores_onehot) 68 | scores_onehot = torch.from_numpy(scores_onehot).to(scores.device) 69 | return scores_onehot 70 | 71 | 72 | 73 | 74 | -------------------------------------------------------------------------------- /Code/vision.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | 4 | from graphviz import Digraph 5 | 6 | 7 | def make_dot(var, params=None): 8 | if params is not None: 9 | assert all(isinstance(p, Variable) for p in params.values()) 10 | param_map = {id(v): k for k, v in params.items()} 11 | 12 | node_attr = dict(style='filled', shape='box', align='left', 13 | fontsize='12', ranksep='0.1', height='0.2') 14 | dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12")) 15 | seen = set() 16 | 17 | def size_to_str(size): 18 | return '(' + (', ').join(['%d' % v for v in size]) + ')' 19 | 20 | output_nodes = (var.grad_fn,) if not isinstance(var, tuple) else tuple(v.grad_fn for v in var) 21 | 22 | def add_nodes(var): 23 | if var not in seen: 24 | if torch.is_tensor(var): 25 | # note: this used to show .saved_tensors in pytorch0.2, but stopped 26 | # working as it was moved to ATen and Variable-Tensor merged 27 | dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange') 28 | elif hasattr(var, 'variable'): 29 | u = var.variable 30 | name = param_map[id(u)] if params is not None else '' 31 | node_name = '%s\n %s' % (name, size_to_str(u.size())) 32 | dot.node(str(id(var)), node_name, fillcolor='lightblue') 33 | elif var in output_nodes: 34 | dot.node(str(id(var)), str(type(var).__name__), fillcolor='darkolivegreen1') 35 | else: 36 | dot.node(str(id(var)), str(type(var).__name__)) 37 | seen.add(var) 38 | if hasattr(var, 'next_functions'): 39 | for u in var.next_functions: 40 | if u[0] is not None: 41 | dot.edge(str(id(u[0])), str(id(var))) 42 | add_nodes(u[0]) 43 | if hasattr(var, 'saved_tensors'): 44 | for t 
in var.saved_tensors: 45 | dot.edge(str(id(t)), str(id(var))) 46 | add_nodes(t) 47 | 48 | if isinstance(var, tuple): 49 | for v in var: 50 | add_nodes(v.grad_fn) 51 | else: 52 | add_nodes(var.grad_fn) 53 | return dot 54 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 qidianzl 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Recovering-Brain-Structure-Network-Using-Functional-Connectivity 2 | ### Framework: 3 | ![framework](main3.png) 4 | 5 | ### Papers: 6 | This repository provides a PyTorch implementation of the models adopted in the two papers: 7 | 8 | - Zhang, Lu, Li Wang, and Dajiang Zhu. "Recovering brain structural connectivity from functional connectivity via multi-gcn based generative adversarial network." International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, Cham, 2020. 9 | - Zhang, Lu, Li Wang, and Dajiang Zhu. "Predicting brain structural network using functional connectivity." Medical Image Analysis 79 (2022): 102463. 10 | 11 | The first paper proposes the Multi-GCN GAN model and the structure-preserving (SP) loss, and the second paper further expands the research to different datasets, different atlases, different functional connectivity generation methods, different models, and new evaluation measures. New results have been obtained. 12 | 13 | 14 | ### Code: 15 | #### dataloader.py 16 | This file includes the preprocessing and normalization operations for the data. All the details are introduced in the two papers. The only element that needs attention is the empty list, which records the IDs of the empty ROIs of a specific atlas. For example, two brain regions in the Destrieux atlas are empty (Medial_wall in both the left and right hemispheres), so the corresponding two rows and columns in the generated SC and FC matrices are all zeros. We delete these rows and columns. 17 | 18 | #### model.py 19 | We implemented different models in this file, including two different CNN-based generators, a Multi-GCN-based generator and a GCN-based discriminator. Different models can be chosen by directly calling the corresponding classes when running the train.py file.
Different model architectures are as follows: 20 | - CNN (CNN-based generator, MSE loss and PCC loss) 21 | - Multi-GCN (Multi-GCN-based generator, MSE loss and PCC loss) 22 | - CNN based GAN (CNN-based generator and GCN-based discriminator, SP loss) 23 | - MGCN-GAN (Multi-GCN-based generator and GCN-based discriminator, SP loss) 24 | 25 | When adopting the proposed MGCN-GAN architecture, the different topology updating methods and different initializations of the learnable combination coefficients of the multiple GCNs (theta) can be changed directly in this file, and we have annotated in this file how to change them. For the linear regression model, we directly called *LinearRegression* from the *sklearn.linear_model* package. 26 | 27 | #### Loss_custom.py 28 | The proposed SP loss includes three components: GAN loss, MSE loss and PCC loss. In this file, we implemented the PCC loss. For the MSE loss and GAN loss, we directly called the loss functions from the torch.nn module in the train.py file. By directly editing the train.py file, different loss functions can be chosen, including: 29 | - GAN Loss 30 | - MSE+GAN loss 31 | - PCC+GAN loss 32 | - SP loss 33 | 34 | #### train.py 35 | Run this file to start training. All the hyper-parameters can be defined in this file. 36 | 37 | Run `python ./train.py -atlas='atlas1' -gpu_id=1`. 38 | 39 | Tested with: 40 | - PyTorch 1.9.0 41 | - Python 3.7.0 42 | 43 | ### Data: 44 | We used 1064 subjects from the HCP dataset and 132 subjects from the ADNI dataset in our research. For each subject, we generated the structural connectivity (SC) and the functional connectivity (FC) matrices. All of the connectivity matrices can be shared for research purposes. Please contact the author to get the data by sending an email to lu.zhang2@mavs.uta.edu. 45 | 46 | ### Citation: 47 | If you use the code or data of this project, please cite: 48 | 49 | @inproceedings{zhang2020recovering, 50 | title={Recovering brain structural connectivity from functional connectivity via multi-gcn based generative adversarial network}, 51 | author={Zhang, Lu and Wang, Li and Zhu, Dajiang}, 52 | booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention}, 53 | pages={53--61}, 54 | year={2020}, 55 | organization={Springer} 56 | } 57 | 58 | @article{zhang2022predicting, 59 | title={Predicting brain structural network using functional connectivity}, 60 | author={Zhang, Lu and Wang, Li and Zhu, Dajiang and Alzheimer's Disease Neuroimaging Initiative and others}, 61 | journal={Medical Image Analysis}, 62 | volume={79}, 63 | pages={102463}, 64 | year={2022}, 65 | publisher={Elsevier} 66 | } 67 | 68 | 69 | -------------------------------------------------------------------------------- /main3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qidianzl/Recovering-Brain-Structure-Network-Using-Functional-Connectivity/0dcaa353be7f6741a5bb67cd7e28f6a295de0c1a/main3.png --------------------------------------------------------------------------------
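For reference, once train.py has saved a checkpoint, the trained generator can be reloaded to predict an SC matrix from a normalized FC matrix. The snippet below is a minimal sketch, not taken from the repository: it assumes the atlas2 configuration (68 regions), the checkpoint name produced by train.py for experiment 0, and an `fc` tensor standing in for an FC matrix that has already been normalized with the training mean/std from normlize_data.

import torch
from model import GCNGenerator

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# atlas2 defaults from train.py: 68 regions, feature widths 68/256/68, dropout 0.6
generator = GCNGenerator(68, 68, 256, 68, 0.6).to(device)
generator.load_state_dict(torch.load('./atlas2/models/GAN_generator-0.ckpt', map_location=device))
generator.eval()

fc = torch.randn(1, 68, 68).to(device)  # placeholder for one normalized FC matrix
with torch.no_grad():
    sc_gen = generator(fc, fc, 1, isTest=True)      # first pass: FC serves as both topology and features
    sc_gen = generator(sc_gen, fc, 1, isTest=True)  # topology-update pass, mirroring test() in train.py
print(sc_gen.shape)  # torch.Size([1, 68, 68])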