├── Image_embedder
│   ├── MoCo
│   │   ├── data_aug
│   │   │   ├── contrastive_learning_dataset.py
│   │   │   ├── gaussian_blur.py
│   │   │   └── view_generator.py
│   │   ├── exceptions
│   │   │   ├── __pycache__
│   │   │   │   └── exceptions.cpython-38.pyc
│   │   │   └── exceptions.py
│   │   ├── load_data.py
│   │   ├── main_micle.py
│   │   ├── moco
│   │   │   ├── __pycache__
│   │   │   │   ├── builder.cpython-38.pyc
│   │   │   │   └── loader.cpython-38.pyc
│   │   │   ├── builder.py
│   │   │   └── loader.py
│   │   └── utils.py
│   └── SimCLR
│       ├── data_aug
│       │   ├── contrastive_learning_dataset.py
│       │   ├── gaussian_blur.py
│       │   └── view_generator.py
│       ├── dataloader.py
│       ├── exceptions
│       │   ├── __pycache__
│       │   │   └── exceptions.cpython-38.pyc
│       │   └── exceptions.py
│       ├── load_data.py
│       ├── medical_aug.py
│       ├── models
│       │   └── resnet_simclr.py
│       ├── run_with_pretrain_with_micle.py
│       ├── simclr_micle.py
│       └── utils.py
├── MultiplexNetwork
│   ├── __pycache__
│   │   ├── embedder.cpython-38.pyc
│   │   └── evaluate.cpython-38.pyc
│   ├── data
│   │   └── abide.pkl
│   ├── embedder.py
│   ├── evaluate.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── attention.cpython-38.pyc
│   │   │   ├── discriminator.cpython-38.pyc
│   │   │   ├── gcn.cpython-38.pyc
│   │   │   └── readout.cpython-38.pyc
│   │   ├── attention.py
│   │   ├── discriminator.py
│   │   ├── gcn.py
│   │   └── readout.py
│   ├── main.py
│   ├── models
│   │   ├── DMGI.py
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── DMGI.cpython-38.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   └── logreg.cpython-38.pyc
│   │   └── logreg.py
│   ├── saved_model
│   │   ├── best_abide_DMGI_type0,type1,type2,type3.pkl
│   │   ├── best_cmmd_DMGI_type0,type1,type2,type3.pkl
│   │   └── best_cmmd_train60_0806_DMGI_type0,type1,type2,type3.pkl
│   └── utils
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-38.pyc
│       │   └── process.cpython-38.pyc
│       └── process.py
├── Preprocessing
│   ├── Image_preprocessing
│   │   ├── abide_atlas.py
│   │   ├── adni_atlas.py
│   │   ├── cmmd_save.py
│   │   ├── duke_save.py
│   │   └── oasis_atlas.py
│   ├── Non_image_preprocessing
│   │   ├── abide_kmeans.py
│   │   ├── adni_kmeans.py
│   │   ├── cmmd.ipynb
│   │   ├── duke_kmeans.py
│   │   └── oasis_kmeans.py
│   ├── README.md
│   └── sample_all_path.txt
└── README.md
/Image_embedder/MoCo/data_aug/contrastive_learning_dataset.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import transforms 2 | from data_aug.gaussian_blur import GaussianBlur 3 | from torchvision import transforms, datasets 4 | from data_aug.view_generator import ContrastiveLearningViewGenerator 5 | from exceptions.exceptions import InvalidDatasetSelection 6 | 7 | 8 | class ContrastiveLearningDataset: 9 | def __init__(self, root_folder): 10 | self.root_folder = root_folder 11 | 12 | @staticmethod 13 | def get_moco_pipeline_transform(size, s=1): 14 | """Return a set of data augmentation transformations as described in the SimCLR paper.""" 15 | color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s) 16 | data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=size), 17 | transforms.RandomHorizontalFlip(), 18 | transforms.RandomApply([color_jitter], p=0.8), 19 | GaussianBlur(kernel_size=int(0.1 * size)), 20 | transforms.ToTensor()]) 21 | return data_transforms 22 | 23 | def get_dataset(self, name, n_views): 24 | valid_datasets = {'cifar10': lambda: datasets.CIFAR10(self.root_folder, train=True, 25 | transform=ContrastiveLearningViewGenerator( 26 | self.get_moco_pipeline_transform(32), 27 | n_views), 28 | download=True), 29 | 30 | 'stl10': lambda: datasets.STL10(self.root_folder, split='unlabeled', 31 | transform=ContrastiveLearningViewGenerator( 32 | self.get_moco_pipeline_transform(96), 33 | n_views), 34 
| download=True)} 35 | 36 | try: 37 | dataset_fn = valid_datasets[name] 38 | except KeyError: 39 | raise InvalidDatasetSelection() 40 | else: 41 | return dataset_fn() 42 | -------------------------------------------------------------------------------- /Image_embedder/MoCo/data_aug/gaussian_blur.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torchvision.transforms import transforms 5 | 6 | np.random.seed(0) 7 | 8 | 9 | class GaussianBlur(object): 10 | """blur a single image on CPU""" 11 | def __init__(self, kernel_size): 12 | radias = kernel_size // 2 13 | kernel_size = radias * 2 + 1 14 | self.blur_h = nn.Conv2d(3, 3, kernel_size=(kernel_size, 1), 15 | stride=1, padding=0, bias=False, groups=3) 16 | self.blur_v = nn.Conv2d(3, 3, kernel_size=(1, kernel_size), 17 | stride=1, padding=0, bias=False, groups=3) 18 | self.k = kernel_size 19 | self.r = radias 20 | 21 | self.blur = nn.Sequential( 22 | nn.ReflectionPad2d(radias), 23 | self.blur_h, 24 | self.blur_v 25 | ) 26 | 27 | self.pil_to_tensor = transforms.ToTensor() 28 | self.tensor_to_pil = transforms.ToPILImage() 29 | 30 | def __call__(self, img): 31 | img = self.pil_to_tensor(img).unsqueeze(0) 32 | 33 | sigma = np.random.uniform(0.1, 2.0) 34 | x = np.arange(-self.r, self.r + 1) 35 | x = np.exp(-np.power(x, 2) / (2 * sigma * sigma)) 36 | x = x / x.sum() 37 | x = torch.from_numpy(x).view(1, -1).repeat(3, 1) 38 | 39 | self.blur_h.weight.data.copy_(x.view(3, 1, self.k, 1)) 40 | self.blur_v.weight.data.copy_(x.view(3, 1, 1, self.k)) 41 | 42 | with torch.no_grad(): 43 | img = self.blur(img) 44 | img = img.squeeze() 45 | 46 | img = self.tensor_to_pil(img) 47 | 48 | return img -------------------------------------------------------------------------------- /Image_embedder/MoCo/data_aug/view_generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | np.random.seed(0) 4 | 5 | 6 | class ContrastiveLearningViewGenerator(object): 7 | """Take two random crops of one image as the query and key.""" 8 | 9 | def __init__(self, base_transform, n_views=2): 10 | self.base_transform = base_transform 11 | self.n_views = n_views 12 | 13 | def __call__(self, x): 14 | return [self.base_transform(x) for i in range(self.n_views)] 15 | -------------------------------------------------------------------------------- /Image_embedder/MoCo/exceptions/__pycache__/exceptions.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sein-Kim/Multimodal-Medical/c235485673e3f6040ab301c312b17bc62ef720d6/Image_embedder/MoCo/exceptions/__pycache__/exceptions.cpython-38.pyc -------------------------------------------------------------------------------- /Image_embedder/MoCo/exceptions/exceptions.py: -------------------------------------------------------------------------------- 1 | class BaseMoCoException(Exception): 2 | """Base exception""" 3 | 4 | 5 | class InvalidBackboneError(BaseMoCoException): 6 | """Raised when the choice of backbone Convnet is invalid.""" 7 | 8 | 9 | class InvalidDatasetSelection(BaseMoCoException): 10 | """Raised when the choice of dataset is invalid.""" 11 | -------------------------------------------------------------------------------- /Image_embedder/MoCo/load_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 
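# load_data reads every 2-D slice image in the chosen dataset folder, converts it
# to 3-channel grayscale, resizes it to 96x96, and groups slices that share the
# same parsed id key from the filename (used later by the micle loss).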
4 | import torchvision.transforms as transforms 5 | 6 | from tqdm import tqdm 7 | import cv2 8 | 9 | 10 | def load_data(args): 11 | tf = transforms.ToTensor() 12 | data_name = args.data 13 | if data_name =='ADNI': 14 | root_dir = "./../../medical_dataset/ADNI/image_2D/" 15 | elif data_name =='OASIS': 16 | root_dir = "./../../medical_dataset/OASIS/image_2D/" 17 | elif data_name =='ABIDE': 18 | root_dir = "./../../medical_dataset/ABIDE/image_2D/" 19 | elif data_name =='CMMD': 20 | root_dir = "./../../medical_dataset/CMMD/image_2D/" 21 | elif data_name =='QIN': 22 | root_dir = './../../medical_dataset/QIN/image_2D/' 23 | 24 | files = os.listdir(root_dir) 25 | i=0 26 | name_list = [] 27 | for file in tqdm(files): 28 | if i ==0: 29 | 30 | img = cv2.imread(root_dir +'/' + file) 31 | gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 32 | gray_three = cv2.merge([gray,gray,gray]) 33 | img__ = cv2.resize(gray_three, (96,96), interpolation = cv2.INTER_AREA) 34 | # with torch.cuda.device(args.gpu): 35 | # img_t = torch.tensor(img__, dtype=torch.float, device=args.device) 36 | with torch.cuda.device(args.gpu): 37 | img_t = tf(img__).cuda() 38 | # print(args.device) 39 | # print(img_t) 40 | images = img_t.unsqueeze(0) 41 | if 'CMMD' in root_dir: 42 | id_ = file.split('-')[0][-1] +file.split('-')[1][0:4] 43 | num = '0' 44 | else: 45 | id_, num, coordinate = file.split('_') 46 | #0000 for escape from unexpected duplication of id + num 47 | names = int(id_ +'0000' + num) 48 | name_list.append(names) 49 | 50 | del img 51 | del gray 52 | del gray_three 53 | del img__ 54 | del img_t 55 | else: 56 | img = cv2.imread(root_dir +'/'+file) 57 | gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 58 | gray_three = cv2.merge([gray,gray,gray]) 59 | img__ = cv2.resize(gray_three, (96,96), interpolation = cv2.INTER_AREA) 60 | with torch.cuda.device(args.gpu): 61 | img_t = tf(img__).cuda() 62 | images = torch.cat((images,img_t.unsqueeze(0)),0) 63 | 64 | if 'CMMD' in root_dir: 65 | id_ = file.split('-')[0][-1] +file.split('-')[1][0:4] 66 | num = '0' 67 | else: 68 | id_, num, coordinate = file.split('_') 69 | names = int(id_ +'0000' + num) 70 | name_list.append(names) 71 | 72 | del img 73 | del gray 74 | del gray_three 75 | del img__ 76 | del img_t 77 | i+=1 78 | train_images = images.detach().cpu().numpy() 79 | train_names = np.expand_dims(np.array(name_list),axis=1) 80 | 81 | 82 | unique_label = list(set(name_list)) 83 | labels_np = np.array(name_list) 84 | ii = 0 85 | image_list = [] 86 | label_list__=[] 87 | for l in tqdm(unique_label): 88 | if ii==0: 89 | a = np.where(labels_np == l) 90 | index_ = list(a[0]) 91 | coor_img_t = images[index_] 92 | image_list.append(coor_img_t) 93 | label_list__.append(l) 94 | del coor_img_t 95 | 96 | else: 97 | a = np.where(labels_np == l) 98 | index_ = list(a[0]) 99 | coor_img_t = images[index_] 100 | image_list.append(coor_img_t) 101 | label_list__.append(l) 102 | del coor_img_t 103 | ii+=1 104 | del images 105 | return train_images,train_names, image_list, label_list__ 106 | 
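The id key built by load_data can be reproduced in isolation. A minimal sketch with a hypothetical filename (real files follow the id_num_coordinate pattern split above; the '0000' padding guards against accidental id/num collisions):

# hypothetical slice filename
file = "50601_2_40.png"
id_, num, coordinate = file.split('_')   # -> '50601', '2', '40.png'
names = int(id_ + '0000' + num)          # -> 5060100002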
-------------------------------------------------------------------------------- /Image_embedder/MoCo/main_micle.py: -------------------------------------------------------------------------------- 1 | # Code based on https://github.com/facebookresearch/moco 2 | import argparse 3 | import builtins 4 | import math 5 | import os 6 | import random 7 | import shutil 8 | import time 9 | import warnings 10 | 11 | import numpy as np 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | import torch.nn.parallel 17 | import torch.backends.cudnn as cudnn 18 | import torch.distributed as dist 19 | import torch.optim 20 | import torch.multiprocessing as mp 21 | import torch.utils.data 22 | import torch.utils.data.distributed 23 | # from torchvision.models import resnet18 24 | import torchvision.transforms as transforms 25 | import torchvision.datasets as datasets 26 | import torchvision.models as models 27 | 28 | import moco.loader 29 | import moco.builder 30 | 31 | from load_data import load_data 32 | model_names = sorted(name for name in models.__dict__ 33 | if name.islower() and not name.startswith("__") 34 | and callable(models.__dict__[name])) 35 | 36 | parser = argparse.ArgumentParser(description='PyTorch MoCo') 37 | parser.add_argument('--data', type=str, default='ADNI') 38 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', 39 | choices=model_names, 40 | help='model architecture: ' + 41 | ' | '.join(model_names) + 42 | ' (default: resnet18)') 43 | parser.add_argument('-j', '--workers', default=12, type=int, metavar='N', 44 | help='number of data loading workers (default: 12)') 45 | parser.add_argument('--epochs', default=200, type=int, metavar='N', 46 | help='number of total epochs to run') 47 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 48 | help='manual epoch number (useful on restarts)') 49 | parser.add_argument('-b', '--batch-size', default=256, type=int, 50 | metavar='N') 51 | parser.add_argument('--lr', '--learning-rate', default=0.001, type=float, 52 | metavar='LR', help='initial learning rate', dest='lr') 53 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 54 | help='momentum of SGD solver') 55 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 56 | metavar='W', help='weight decay (default: 1e-4)', 57 | dest='weight_decay') 58 | 59 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 60 | help='path to latest checkpoint (default: none)') 61 | parser.add_argument('--seed', default=None, type=int, 62 | help='seed for initializing training. 
') 63 | parser.add_argument('--disable-cuda', action='store_true', 64 | help='Disable CUDA') 65 | parser.add_argument('--fp16-precision', action='store_true', 66 | help='Whether or not to use 16-bit precision GPU training.') 67 | parser.add_argument('--out_dim', default=128, type=int, 68 | help='feature dimension (default: 128)') 69 | parser.add_argument('--log-every-n-steps', default=100, type=int, 70 | help='Log every n steps') 71 | parser.add_argument('--temperature', default=0.07, type=float, 72 | help='softmax temperature (default: 0.07)') 73 | parser.add_argument('--n-views', default=2, type=int, metavar='N', 74 | help='Number of views for contrastive learning training.') 75 | parser.add_argument('--ngpus_per_node', default=1, type=int, help='ngpus per node.') 76 | parser.add_argument('--print_freq', default=100, type=int, help='print frequency.') 77 | parser.add_argument('--gpu', default=0, type=int, help='Gpu index.') 78 | parser.add_argument('--moco-dim', default=128, type=int, 79 | help='feature dimension (default: 128)') 80 | parser.add_argument('--moco-k', default=65536, type=int, 81 | help='queue size; number of negative keys (default: 65536)') 82 | parser.add_argument('--moco-m', default=0.999, type=float, 83 | help='moco momentum of updating key encoder (default: 0.999)') 84 | parser.add_argument('--moco-t', default=0.07, type=float, 85 | help='softmax temperature (default: 0.07)') 86 | parser.add_argument('--mlp', action='store_true', 87 | help='use mlp head') 88 | parser.add_argument('--aug-plus', action='store_true', 89 | help='use moco v2 data augmentation') 90 | parser.add_argument('--cos', action='store_true', 91 | help='use cosine lr schedule') 92 | 93 | def main(): 94 | args = parser.parse_args() 95 | assert args.n_views == 2 96 | # check if gpu training is available 97 | if not args.disable_cuda and torch.cuda.is_available(): 98 | args.device = torch.device('cuda') 99 | cudnn.deterministic = True 100 | cudnn.benchmark = True 101 | else: 102 | args.device = torch.device('cpu') 103 | args.gpu = -1 104 | 105 | 106 | train_images, train_labels, extract_images, extract_names = load_data(args) 107 | print("preparing images -- end") 108 | train_images_list = [] 109 | for i in range(len(train_images)): 110 | a = [train_images[i], medical_aug(train_images[i])] 111 | train_images_list.append(a) 112 | del train_images 113 | train_data_testing = TensorData(train_images_list, train_labels) 114 | train_loader = torch.utils.data.DataLoader( 115 | train_data_testing, batch_size=args.batch_size, shuffle=True, 116 | num_workers=args.workers, pin_memory=True, drop_last=True) 117 | 118 | model = moco.builder.MoCo( 119 | models.__dict__[args.arch], 120 | args.moco_dim, args.moco_k, args.moco_m, args.moco_t, args.mlp).cuda() 121 | 122 | criterion = nn.CrossEntropyLoss().cuda() 123 | 124 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 125 | momentum=args.momentum, 126 | weight_decay=args.weight_decay) 127 | 128 | args.resume = './checkpoint_0190.pth.tar' 129 | if args.resume: 130 | if os.path.isfile(args.resume): 131 | print("=> loading checkpoint '{}'".format(args.resume)) 132 | if args.gpu is None: 133 | checkpoint = torch.load(args.resume) 134 | else: 135 | loc = 'cuda:{}'.format(args.gpu) 136 | checkpoint = torch.load(args.resume, map_location=loc) 137 | args.start_epoch = checkpoint['epoch'] 138 | model.load_state_dict(checkpoint['state_dict']) 139 | optimizer.load_state_dict(checkpoint['optimizer']) 140 | print("=> loaded checkpoint '{}' (epoch {})" 141 | 
.format(args.resume, checkpoint['epoch'])) 142 | else: 143 | print("=> no checkpoint found at '{}'".format(args.resume)) 144 | 145 | cudnn.benchmark = True 146 | args.start_epoch = 0 147 | for epoch in range(args.start_epoch, args.epochs): 148 | train(train_loader, model, criterion, optimizer, epoch, args) 149 | 150 | if (epoch+1)%10 ==0: 151 | save_checkpoint({ 152 | 'epoch': epoch + 1, 153 | 'arch': args.arch, 154 | 'state_dict': model.state_dict(), 155 | 'optimizer' : optimizer.state_dict(), 156 | }, is_best=False, filename='./micle_abide/checkpoint_{:04d}.pth.tar'.format(epoch)) 157 | 158 | with torch.cuda.device(args.gpu): 159 | 160 | extract_names = np.array(extract_names) 161 | model.eval() 162 | kkk = 0 163 | for im in extract_images: 164 | if kkk == 0: 165 | im_tf = im.to(args.device) 166 | anchor_features = model.encoder_q(im_tf) # query encoder only; a full forward would also mutate the queue 167 | feat = torch.mean(anchor_features, dim=0).unsqueeze(0).detach().cpu() 168 | else: 169 | im_tf = im.to(args.device) 170 | anchor_features = model.encoder_q(im_tf) 171 | feat_ = torch.mean(anchor_features, dim=0).unsqueeze(0).detach().cpu() 172 | feat = torch.cat((feat,feat_),dim=0) 173 | del feat_ 174 | kkk+=1 175 | n = feat.detach().cpu().numpy() 176 | np.savetxt('./extracted_feature/train_feature.csv',n,delimiter=',') 177 | np.savetxt('./extracted_feature/train_id.csv',extract_names,delimiter=',') 178 | 179 | def train(train_loader, model, criterion, optimizer, epoch, args): 180 | batch_time = AverageMeter('Time', ':6.3f') 181 | data_time = AverageMeter('Data', ':6.3f') 182 | losses = AverageMeter('Loss', ':.4e') 183 | micles = AverageMeter('Micle', ':.4e') 184 | top1 = AverageMeter('Acc@1', ':6.2f') 185 | top5 = AverageMeter('Acc@5', ':6.2f') 186 | progress = ProgressMeter( 187 | len(train_loader), 188 | [batch_time, data_time, losses, micles,top1, top5], 189 | prefix="Epoch: [{}]".format(epoch)) 190 | 191 | # switch to train mode 192 | model.train() 193 | 194 | end = time.time() 195 | for i, (images, _) in enumerate(train_loader): 196 | data_time.update(time.time() - end) 197 | 198 | if args.gpu is not None: 199 | images[:,0] = images[:,0].cuda(args.gpu, non_blocking=True) 200 | images[:,1] = images[:,1].cuda(args.gpu, non_blocking=True) 201 | output, target, q = model(im_q=images[:,0].cuda(), im_k=images[:,1].cuda()) 202 | loss = criterion(output, target) 203 | _1, _2, features = model(images[:,0].cuda(), images[:,0].cuda()) 204 | micle_loss = micle(args,features, _) 205 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 206 | losses.update(loss.item(), images.size(0)) 207 | micles.update(micle_loss.item(), images.size(0)) 208 | top1.update(acc1[0], images.size(0)) 209 | top5.update(acc5[0], images.size(0)) 210 | 211 | # compute gradient and do SGD step 212 | optimizer.zero_grad() 213 | loss_with_micle = loss + micle_loss 214 | # loss.backward() 215 | loss_with_micle.backward() 216 | optimizer.step() 217 | 218 | # measure elapsed time 219 | batch_time.update(time.time() - end) 220 | end = time.time() 221 | args.print_freq = 1 222 | if i % args.print_freq == 0: 223 | progress.display(i) 224 | 225 | 226 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 227 | torch.save(state, filename) 228 | if is_best: 229 | shutil.copyfile(filename, 'model_best.pth.tar') 230 | 231 | def medical_aug(img_t, size=96, s =1): 232 | img_t = torch.FloatTensor(img_t) 233 | color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s) 234 | data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=size), 235 | 
transforms.RandomHorizontalFlip(), 236 | transforms.RandomApply([color_jitter], p=0.8), 237 | transforms.ToTensor()]) 238 | trans = transforms.ToPILImage() 239 | img = trans(img_t) 240 | aug_img = data_transforms(img).detach().numpy() 241 | return aug_img 242 | 243 | def micle(args, features, indx): 244 | index = indx.detach().numpy() 245 | unique, counts = np.unique(index, return_counts=True) 246 | count_dict = dict(zip(unique, counts)) 247 | loss = torch.zeros(1).to(args.device) 248 | mse = nn.MSELoss().to(args.device) 249 | count = 0 250 | edge_count = 0 251 | len_dict = 0 252 | for key in count_dict: 253 | len_dict +=1 254 | if count_dict[key]>2: 255 | which = np.where(index == key)[0] 256 | mask = torch.tensor(which).to(args.device) 257 | 258 | features_ = features[mask] 259 | features_ = F.normalize(features_, dim=1) 260 | similarity_matrix = torch.matmul(features_, features_.T) 261 | similarity_matrix = F.normalize(similarity_matrix) 262 | mask = torch.eye(similarity_matrix.shape[0], dtype=torch.bool).to(args.device) 263 | positive = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1) 264 | 265 | labels = torch.ones(positive.shape,dtype=torch.long).to(args.device) 266 | 267 | loss += mse(positive.to(torch.float32),labels.to(torch.float32)) 268 | count +=1 269 | if not (count==0): 270 | loss =loss/count 271 | return loss 272 | 273 | class AverageMeter(object): 274 | def __init__(self, name, fmt=':f'): 275 | self.name = name 276 | self.fmt = fmt 277 | self.reset() 278 | 279 | def reset(self): 280 | self.val = 0 281 | self.avg = 0 282 | self.sum = 0 283 | self.count = 0 284 | 285 | def update(self, val, n=1): 286 | self.val = val 287 | self.sum += val * n 288 | self.count += n 289 | self.avg = self.sum / self.count 290 | 291 | def __str__(self): 292 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 293 | return fmtstr.format(**self.__dict__) 294 | class TensorData(): 295 | def __init__(self, x_data, y_data): 296 | self.x_data = torch.FloatTensor(x_data) 297 | self.y_data = torch.LongTensor(y_data) 298 | 299 | self.len = self.y_data.shape[0] 300 | 301 | def __getitem__(self, index): 302 | return self.x_data[index], self.y_data[index] 303 | def __len__(self): 304 | return self.len 305 | 306 | class ProgressMeter(object): 307 | def __init__(self, num_batches, meters, prefix=""): 308 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 309 | self.meters = meters 310 | self.prefix = prefix 311 | 312 | def display(self, batch): 313 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 314 | entries += [str(meter) for meter in self.meters] 315 | print('\t'.join(entries)) 316 | 317 | def _get_batch_fmtstr(self, num_batches): 318 | num_digits = len(str(num_batches // 1)) 319 | fmt = '{:' + str(num_digits) + 'd}' 320 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 321 | 322 | 323 | def adjust_learning_rate(optimizer, epoch, args): 324 | """Decay the learning rate based on schedule""" 325 | lr = args.lr 326 | if args.cos: # cosine lr schedule 327 | lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)) 328 | else: # stepwise lr schedule 329 | for milestone in args.schedule: 330 | lr *= 0.1 if epoch >= milestone else 1. 
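# Cosine branch: lr_t = lr0 * 0.5 * (1 + cos(pi * epoch / epochs)). The stepwise
# branch reads args.schedule, which this script's parser never defines, so only
# cosine runs can call this helper safely (main() does not invoke it).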
331 | for param_group in optimizer.param_groups: 332 | param_group['lr'] = lr 333 | 334 | 335 | def accuracy(output, target, topk=(1,)): 336 | """Computes the accuracy over the k top predictions for the specified values of k""" 337 | with torch.no_grad(): 338 | maxk = max(topk) 339 | batch_size = target.size(0) 340 | 341 | _, pred = output.topk(maxk, 1, True, True) 342 | pred = pred.t() 343 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 344 | 345 | res = [] 346 | for k in topk: 347 | # print(torch.reshape(correct[:k],(-1,)).shape) 348 | correct_k = torch.reshape(correct[:k],(-1,)).float().sum(0, keepdim=True) 349 | # correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 350 | res.append(correct_k.mul_(100.0 / batch_size)) 351 | return res 352 | 353 | 354 | if __name__ == '__main__': 355 | main()
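A typical single-GPU invocation of this script, assuming the hard-coded resume file ./checkpoint_0190.pth.tar and the ./micle_abide/ and ./extracted_feature/ output folders already exist (a sketch, not a documented entry point):

python main_micle.py --data ABIDE --arch resnet18 --batch-size 256 --mlp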
-------------------------------------------------------------------------------- /Image_embedder/MoCo/moco/__pycache__/builder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sein-Kim/Multimodal-Medical/c235485673e3f6040ab301c312b17bc62ef720d6/Image_embedder/MoCo/moco/__pycache__/builder.cpython-38.pyc -------------------------------------------------------------------------------- /Image_embedder/MoCo/moco/__pycache__/loader.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sein-Kim/Multimodal-Medical/c235485673e3f6040ab301c312b17bc62ef720d6/Image_embedder/MoCo/moco/__pycache__/loader.cpython-38.pyc -------------------------------------------------------------------------------- /Image_embedder/MoCo/moco/builder.py: -------------------------------------------------------------------------------- 1 | # Code based on https://github.com/facebookresearch/moco 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class MoCo(nn.Module): 7 | def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07, mlp=False): 8 | 9 | super(MoCo, self).__init__() 10 | 11 | self.K = K 12 | self.m = m 13 | self.T = T 14 | self.encoder_q = base_encoder(num_classes=dim) 15 | self.encoder_k = base_encoder(num_classes=dim) 16 | 17 | if mlp: 18 | dim_mlp = self.encoder_q.fc.weight.shape[1] 19 | self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc) 20 | self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc) 21 | 22 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 23 | param_k.data.copy_(param_q.data) 24 | param_k.requires_grad = False 25 | self.register_buffer("queue", torch.randn(dim, K)) 26 | self.queue = nn.functional.normalize(self.queue, dim=0) 27 | 28 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) 29 | 30 | @torch.no_grad() 31 | def _momentum_update_key_encoder(self): 32 | 33 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 34 | param_k.data = param_k.data * self.m + param_q.data * (1. - self.m) 35 | 36 | @torch.no_grad() 37 | def _dequeue_and_enqueue(self, keys): 38 | keys = concat_all_gather(keys) 39 | 40 | batch_size = keys.shape[0] 41 | 42 | ptr = int(self.queue_ptr) 43 | assert self.K % batch_size == 0 44 | 45 | self.queue[:, ptr:ptr + batch_size] = keys.T 46 | ptr = (ptr + batch_size) % self.K 47 | 48 | self.queue_ptr[0] = ptr 49 | 50 | @torch.no_grad() 51 | def _batch_shuffle_ddp(self, x): 52 | 53 | batch_size_this = x.shape[0] 54 | x_gather = concat_all_gather(x) 55 | batch_size_all = x_gather.shape[0] 56 | 57 | num_gpus = batch_size_all // batch_size_this 58 | 59 | idx_shuffle = torch.randperm(batch_size_all).cuda() 60 | 61 | idx_unshuffle = torch.argsort(idx_shuffle) 62 | 63 | idx_this = idx_shuffle 64 | 65 | return x_gather[idx_this], idx_unshuffle 66 | 67 | @torch.no_grad() 68 | def _batch_unshuffle_ddp(self, x, idx_unshuffle): 69 | 70 | batch_size_this = x.shape[0] 71 | x_gather = concat_all_gather(x) 72 | batch_size_all = x_gather.shape[0] 73 | 74 | num_gpus = batch_size_all // batch_size_this 75 | idx_this = idx_unshuffle 76 | 77 | return x_gather[idx_this] 78 | 79 | #for image training 80 | def forward(self, im_q, im_k): 81 | q = self.encoder_q(im_q) 82 | q = nn.functional.normalize(q, dim=1) 83 | 84 | with torch.no_grad(): 85 | self._momentum_update_key_encoder() 86 | 87 | im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k) 88 | 89 | k = self.encoder_k(im_k) 90 | k = nn.functional.normalize(k, dim=1) 91 | 92 | k = self._batch_unshuffle_ddp(k, idx_unshuffle) 93 | 94 | l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1) 95 | l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()]) 96 | 97 | logits = torch.cat([l_pos, l_neg], dim=1) 98 | 99 | logits /= self.T 100 | 101 | labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda() 102 | 103 | self._dequeue_and_enqueue(k) 104 | 105 | return logits, labels, q 106 | 107 | @torch.no_grad() 108 | def concat_all_gather(tensor): 109 | output = tensor 110 | 111 | return output -------------------------------------------------------------------------------- /Image_embedder/MoCo/moco/loader.py: -------------------------------------------------------------------------------- 1 | # Code based on https://github.com/facebookresearch/moco 2 | from PIL import ImageFilter 3 | import random 4 | 5 | 6 | class TwoCropsTransform: 7 | 8 | def __init__(self, base_transform): 9 | self.base_transform = base_transform 10 | 11 | def __call__(self, x): 12 | q = self.base_transform(x) 13 | k = self.base_transform(x) 14 | return [q, k] 15 | class GaussianBlur(object): 16 | def __init__(self, sigma=[.1, 2.]): 17 | self.sigma = sigma 18 | 19 | def __call__(self, x): 20 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 21 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 22 | return x -------------------------------------------------------------------------------- /Image_embedder/MoCo/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | import torch 5 | import yaml 6 | 7 | 8 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 9 | torch.save(state, filename) 10 | if is_best: 11 | shutil.copyfile(filename, 'model_best.pth.tar') 12 | 13 | 14 | def save_config_file(model_checkpoints_folder, args): 15 | if not os.path.exists(model_checkpoints_folder): 16 | os.makedirs(model_checkpoints_folder) 17 | with open(os.path.join(model_checkpoints_folder, 'config.yml'), 'w') as outfile: 18 | yaml.dump(args, outfile, 
default_flow_style=False) 19 | 20 | 21 | def accuracy(output, target, topk=(1,)): 22 | """Computes the accuracy over the k top predictions for the specified values of k""" 23 | with torch.no_grad(): 24 | maxk = max(topk) 25 | batch_size = target.size(0) 26 | 27 | _, pred = output.topk(maxk, 1, True, True) 28 | pred = pred.t() 29 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 30 | 31 | res = [] 32 | for k in topk: 33 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 34 | res.append(correct_k.mul_(100.0 / batch_size)) 35 | return res 36 | -------------------------------------------------------------------------------- /Image_embedder/SimCLR/data_aug/contrastive_learning_dataset.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import transforms 2 | from data_aug.gaussian_blur import GaussianBlur 3 | from torchvision import transforms, datasets 4 | from data_aug.view_generator import ContrastiveLearningViewGenerator 5 | from exceptions.exceptions import InvalidDatasetSelection 6 | 7 | 8 | class ContrastiveLearningDataset: 9 | def __init__(self, root_folder): 10 | self.root_folder = root_folder 11 | 12 | @staticmethod 13 | def get_simclr_pipeline_transform(size, s=1): 14 | """Return a set of data augmentation transformations as described in the SimCLR paper.""" 15 | color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s) 16 | data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=size), 17 | transforms.RandomHorizontalFlip(), 18 | transforms.RandomApply([color_jitter], p=0.8), 19 | # transforms.RandomGrayscale(p=0.2), 20 | GaussianBlur(kernel_size=int(0.1 * size)), 21 | transforms.ToTensor()]) 22 | return data_transforms 23 | 24 | def get_dataset(self, name, n_views): 25 | valid_datasets = {'cifar10': lambda: datasets.CIFAR10(self.root_folder, train=True, 26 | transform=ContrastiveLearningViewGenerator( 27 | self.get_simclr_pipeline_transform(32), 28 | n_views), 29 | download=True), 30 | 31 | 'stl10': lambda: datasets.STL10(self.root_folder, split='unlabeled', 32 | transform=ContrastiveLearningViewGenerator( 33 | self.get_simclr_pipeline_transform(96), 34 | n_views), 35 | download=True)} 36 | 37 | try: 38 | dataset_fn = valid_datasets[name] 39 | except KeyError: 40 | raise InvalidDatasetSelection() 41 | else: 42 | return dataset_fn() 43 | -------------------------------------------------------------------------------- /Image_embedder/SimCLR/data_aug/gaussian_blur.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torchvision.transforms import transforms 5 | 6 | np.random.seed(0) 7 | 8 | 9 | class GaussianBlur(object): 10 | """blur a single image on CPU""" 11 | def __init__(self, kernel_size): 12 | radias = kernel_size // 2 13 | kernel_size = radias * 2 + 1 14 | self.blur_h = nn.Conv2d(3, 3, kernel_size=(kernel_size, 1), 15 | stride=1, padding=0, bias=False, groups=3) 16 | self.blur_v = nn.Conv2d(3, 3, kernel_size=(1, kernel_size), 17 | stride=1, padding=0, bias=False, groups=3) 18 | self.k = kernel_size 19 | self.r = radias 20 | 21 | self.blur = nn.Sequential( 22 | nn.ReflectionPad2d(radias), 23 | self.blur_h, 24 | self.blur_v 25 | ) 26 | 27 | self.pil_to_tensor = transforms.ToTensor() 28 | self.tensor_to_pil = transforms.ToPILImage() 29 | 30 | def __call__(self, img): 31 | img = self.pil_to_tensor(img).unsqueeze(0) 32 | 33 | sigma = np.random.uniform(0.1, 2.0) 
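# Build a normalized 1-D Gaussian kernel for the sampled sigma,
# w_i proportional to exp(-i^2 / (2 * sigma^2)), and load it into both
# depthwise convs; the horizontal + vertical passes equal a full 2-D blur.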
34 | x = np.arange(-self.r, self.r + 1) 35 | x = np.exp(-np.power(x, 2) / (2 * sigma * sigma)) 36 | x = x / x.sum() 37 | x = torch.from_numpy(x).view(1, -1).repeat(3, 1) 38 | 39 | self.blur_h.weight.data.copy_(x.view(3, 1, self.k, 1)) 40 | self.blur_v.weight.data.copy_(x.view(3, 1, 1, self.k)) 41 | 42 | with torch.no_grad(): 43 | img = self.blur(img) 44 | img = img.squeeze() 45 | 46 | img = self.tensor_to_pil(img) 47 | 48 | return img -------------------------------------------------------------------------------- /Image_embedder/SimCLR/data_aug/view_generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | np.random.seed(0) 4 | 5 | 6 | class ContrastiveLearningViewGenerator(object): 7 | """Take two random crops of one image as the query and key.""" 8 | 9 | def __init__(self, base_transform, n_views=2): 10 | self.base_transform = base_transform 11 | self.n_views = n_views 12 | 13 | def __call__(self, x): 14 | return [self.base_transform(x) for i in range(self.n_views)] 15 | -------------------------------------------------------------------------------- /Image_embedder/SimCLR/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class TensorData(): 5 | 6 | def __init__(self, x_data, y_data): 7 | 8 | self.x_data = torch.FloatTensor(x_data) 9 | self.y_data = torch.LongTensor(y_data) 10 | 11 | self.len = self.y_data.shape[0] 12 | def __getitem__(self, index): 13 | 14 | return self.x_data[index], self.y_data[index] 15 | 16 | def __len__(self): 17 | 18 | return self.len -------------------------------------------------------------------------------- /Image_embedder/SimCLR/exceptions/__pycache__/exceptions.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sein-Kim/Multimodal-Medical/c235485673e3f6040ab301c312b17bc62ef720d6/Image_embedder/SimCLR/exceptions/__pycache__/exceptions.cpython-38.pyc -------------------------------------------------------------------------------- /Image_embedder/SimCLR/exceptions/exceptions.py: -------------------------------------------------------------------------------- 1 | class BaseSimCLRException(Exception): 2 | """Base exception""" 3 | 4 | 5 | class InvalidBackboneError(BaseSimCLRException): 6 | """Raised when the choice of backbone Convnet is invalid.""" 7 | 8 | 9 | class InvalidDatasetSelection(BaseSimCLRException): 10 | """Raised when the choice of dataset is invalid.""" 11 | -------------------------------------------------------------------------------- /Image_embedder/SimCLR/load_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torchvision.transforms as transforms 5 | 6 | from tqdm import tqdm 7 | import cv2 8 | 9 | 10 | def load_data(args): 11 | tf = transforms.ToTensor() 12 | data_name = args.data 13 | if data_name =='ADNI': 14 | root_dir = "./../../medical_dataset/ADNI/image_2D/" 15 | elif data_name =='OASIS': 16 | root_dir = "./../../medical_dataset/OASIS/image_2D/" 17 | elif data_name =='ABIDE': 18 | root_dir = "./../../medical_dataset/ABIDE/image_2D/" 19 | elif data_name =='CMMD': 20 | root_dir = "./../../medical_dataset/CMMD/image_2D/" 21 | elif data_name =='QIN': 22 | root_dir = './../../medical_dataset/QIN/image_2D/' 23 | 24 | files = os.listdir(root_dir) 25 | i=0 26 | name_list = [] 27 | for file in tqdm(files): 28 | if i ==0: 
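# the first file initializes the running images tensor; every later file
# is concatenated onto it in the else branch below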
29 | 30 | img = cv2.imread(root_dir +'/' + file) 31 | gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 32 | gray_three = cv2.merge([gray,gray,gray]) 33 | img__ = cv2.resize(gray_three, (96,96), interpolation = cv2.INTER_AREA) 34 | # with torch.cuda.device(args.gpu_index): 35 | # img_t = torch.tensor(img__, dtype=torch.float, device=args.device) 36 | with torch.cuda.device(args.gpu_index): 37 | img_t = tf(img__).cuda() 38 | # print(args.device) 39 | # print(img_t) 40 | images = img_t.unsqueeze(0) 41 | if 'CMMD' in root_dir: 42 | id_ = file.split('-')[0][-1] +file.split('-')[1][0:4] 43 | num = '0' 44 | else: 45 | id_, num, coordinate = file.split('_') 46 | #0000 for escape from unexpected duplication of id + num 47 | names = int(id_ +'0000' + num) 48 | name_list.append(names) 49 | 50 | del img 51 | del gray 52 | del gray_three 53 | del img__ 54 | del img_t 55 | else: 56 | img = cv2.imread(root_dir +'/'+file) 57 | gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 58 | gray_three = cv2.merge([gray,gray,gray]) 59 | img__ = cv2.resize(gray_three, (96,96), interpolation = cv2.INTER_AREA) 60 | with torch.cuda.device(args.gpu_index): 61 | img_t = tf(img__).cuda() 62 | images = torch.cat((images,img_t.unsqueeze(0)),0) 63 | 64 | if 'CMMD' in root_dir: 65 | id_ = file.split('-')[0][-1] +file.split('-')[1][0:4] 66 | num = '0' 67 | else: 68 | id_, num, coordinate = file.split('_') 69 | names = int(id_ +'0000' + num) 70 | name_list.append(names) 71 | 72 | del img 73 | del gray 74 | del gray_three 75 | del img__ 76 | del img_t 77 | i+=1 78 | train_images = images.detach().cpu().numpy() 79 | train_names = np.expand_dims(np.array(name_list),axis=1) 80 | 81 | 82 | unique_label = list(set(name_list)) 83 | labels_np = np.array(name_list) 84 | ii = 0 85 | image_list = [] 86 | label_list__=[] 87 | for l in tqdm(unique_label): 88 | if ii==0: 89 | a = np.where(labels_np == l) 90 | index_ = list(a[0]) 91 | coor_img_t = images[index_] 92 | image_list.append(coor_img_t) 93 | label_list__.append(l) 94 | del coor_img_t 95 | 96 | else: 97 | a = np.where(labels_np == l) 98 | index_ = list(a[0]) 99 | coor_img_t = images[index_] 100 | image_list.append(coor_img_t) 101 | label_list__.append(l) 102 | del coor_img_t 103 | ii+=1 104 | del images 105 | return train_images,train_names, image_list, label_list__ 106 | -------------------------------------------------------------------------------- /Image_embedder/SimCLR/medical_aug.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from torchvision import models 4 | 5 | from torchvision.transforms import transforms 6 | from torchvision import transforms, datasets 7 | 8 | def medical_aug(img_t, size=96, s =1): 9 | color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s) 10 | data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=size), 11 | transforms.RandomHorizontalFlip(), 12 | transforms.RandomApply([color_jitter], p=0.8), 13 | transforms.RandomGrayscale(p=0.2), 14 | transforms.ToTensor()]) 15 | trans = transforms.ToPILImage() 16 | img = trans(img_t) 17 | aug_img = data_transforms(img) 18 | return aug_img
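medical_aug takes a single CHW tensor image and returns one augmented view. A minimal usage sketch (the random tensor is a stand-in for a preprocessed slice):

import torch
from medical_aug import medical_aug

slice_t = torch.rand(3, 96, 96)   # stand-in for one loaded 96x96 slice
aug_t = medical_aug(slice_t)      # augmented view, FloatTensor of shape (3, 96, 96)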
-------------------------------------------------------------------------------- /Image_embedder/SimCLR/models/resnet_simclr.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torchvision.models as models 3 | 4 | from exceptions.exceptions import InvalidBackboneError 5 | 6 | 7 | class ResNetSimCLR(nn.Module): 8 | 9 | def __init__(self, base_model, out_dim): 10 | super(ResNetSimCLR, self).__init__() 11 | self.resnet_dict = {"resnet18": models.resnet18(pretrained=False, num_classes=out_dim), 12 | "resnet50": models.resnet50(pretrained=False, num_classes=out_dim)} 13 | 14 | self.backbone = self._get_basemodel(base_model) 15 | dim_mlp = self.backbone.fc.in_features 16 | 17 | # add mlp projection head 18 | self.backbone.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.backbone.fc) 19 | 20 | def _get_basemodel(self, model_name): 21 | try: 22 | model = self.resnet_dict[model_name] 23 | except KeyError: 24 | raise InvalidBackboneError( 25 | "Invalid backbone architecture. Check the config file and pass one of: resnet18 or resnet50") 26 | else: 27 | return model 28 | 29 | def forward(self, x): 30 | return self.backbone(x) 31 | -------------------------------------------------------------------------------- /Image_embedder/SimCLR/run_with_pretrain_with_micle.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | # from codecs import namereplace_errors 3 | import torch 4 | import torch.backends.cudnn as cudnn 5 | from torchvision import models 6 | from data_aug.contrastive_learning_dataset import ContrastiveLearningDataset 7 | from models.resnet_simclr import ResNetSimCLR 8 | from simclr_micle import SimCLR_micle 9 | from load_data import load_data 10 | from dataloader import TensorData 11 | import numpy as np 12 | 13 | model_names = sorted(name for name in models.__dict__ 14 | if name.islower() and not name.startswith("__") 15 | and callable(models.__dict__[name])) 16 | 17 | parser = argparse.ArgumentParser(description='PyTorch SimCLR') 18 | parser.add_argument('--data', type=str, default='ADNI') 19 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', 20 | choices=model_names, 21 | help='model architecture: ' + 22 | ' | '.join(model_names) + 23 | ' (default: resnet18)') 24 | parser.add_argument('-j', '--workers', default=12, type=int, metavar='N', 25 | help='number of data loading workers (default: 12)') 26 | parser.add_argument('--epochs', default=50, type=int, metavar='N', 27 | help='number of total epochs to run') 28 | parser.add_argument('-b', '--batch-size', default=256, type=int, 29 | metavar='N', 30 | help='mini-batch size (default: 256), this is the total ' 31 | 'batch size of all GPUs on the current node when ' 32 | 'using Data Parallel or Distributed Data Parallel') 33 | parser.add_argument('--lr', '--learning-rate', default=0.0003, type=float, 34 | metavar='LR', help='initial learning rate', dest='lr') 35 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 36 | metavar='W', help='weight decay (default: 1e-4)', 37 | dest='weight_decay') 38 | parser.add_argument('--seed', default=None, type=int, 39 | help='seed for initializing training. 
') 40 | parser.add_argument('--disable-cuda', action='store_true', 41 | help='Disable CUDA') 42 | parser.add_argument('--fp16-precision', action='store_true', 43 | help='Whether or not to use 16-bit precision GPU training.') 44 | 45 | parser.add_argument('--out_dim', default=1024, type=int, 46 | help='feature dimension (default: 128)') 47 | parser.add_argument('--log-every-n-steps', default=100, type=int, 48 | help='Log every n steps') 49 | parser.add_argument('--temperature', default=0.07, type=float, 50 | help='softmax temperature (default: 0.07)') 51 | parser.add_argument('--n-views', default=2, type=int, metavar='N', 52 | help='Number of views for contrastive learning training.') 53 | parser.add_argument('--gpu_index', default=0, type=int, help='Gpu index.') 54 | 55 | def main(): 56 | args = parser.parse_args() 57 | assert args.n_views == 2, "Only two view training is supported. Please use --n-views 2." 58 | if not args.disable_cuda and torch.cuda.is_available(): 59 | args.device = torch.device(args.gpu_index) 60 | 61 | cudnn.deterministic = True 62 | cudnn.benchmark = True 63 | else: 64 | args.device = torch.device('cpu') 65 | args.gpu_index = -1 66 | 67 | print(args.device) 68 | print("preparing images") 69 | train_images, train_labels, extract_images, extract_names = load_data(args) 70 | print("load images -- end") 71 | train_data_testing = TensorData(train_images, train_labels) 72 | #for inductive setting 73 | train_loader = torch.utils.data.DataLoader( 74 | train_data_testing, batch_size=args.batch_size, shuffle=True, 75 | num_workers=args.workers, pin_memory=True, drop_last=True) 76 | ########## 77 | 78 | 79 | 80 | #load model 81 | model = ResNetSimCLR(base_model=args.arch, out_dim=args.out_dim) 82 | with torch.cuda.device(args.gpu_index): 83 | checkpoint = torch.load('./runs/Pretrain_model_emb1024/checkpoint_0100.pth.tar', map_location=args.device) 84 | 85 | state_dict = checkpoint['state_dict'] 86 | model.load_state_dict(state_dict, strict=False) 87 | 88 | optimizer = torch.optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay) 89 | 90 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader), eta_min=0, 91 | last_epoch=-1) 92 | 93 | with torch.cuda.device(args.gpu_index): 94 | simclr = SimCLR_micle(model=model, optimizer=optimizer, scheduler=scheduler, args=args) 95 | simclr.train(train_loader) 96 | 97 | extract_names = np.array(extract_names) 98 | model.eval() 99 | kkk = 0 100 | for im in extract_images: 101 | if kkk == 0: 102 | im_tf = im.to(args.device) 103 | anchor_features = model(im_tf) 104 | feat = torch.mean(anchor_features, dim=0).unsqueeze(0).detach().cpu() 105 | else: 106 | im_tf = im.to(args.device) 107 | anchor_features = model(im_tf) 108 | feat_ = torch.mean(anchor_features, dim=0).unsqueeze(0).detach().cpu() 109 | feat = torch.cat((feat,feat_),dim=0) 110 | del feat_ 111 | kkk+=1 112 | n = feat.detach().cpu().numpy() 113 | np.savetxt('./extracted_feature/train_feature.csv',n,delimiter=',') 114 | np.savetxt('./extracted_feature/train_id.csv',extract_names,delimiter=',') 115 | 116 | 117 | if __name__ == "__main__": 118 | main() 119 | -------------------------------------------------------------------------------- /Image_embedder/SimCLR/simclr_micle.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from torch.cuda.amp import GradScaler, autocast 8 | from tqdm import tqdm 9 | 
from utils import save_config_file, accuracy, save_checkpoint 10 | import numpy as np 11 | 12 | from medical_aug import medical_aug 13 | torch.manual_seed(0) 14 | 15 | 16 | class SimCLR_micle(object): 17 | 18 | def __init__(self, *args, **kwargs): 19 | self.args = kwargs['args'] 20 | self.model = kwargs['model'].to(self.args.device) 21 | self.optimizer = kwargs['optimizer'] 22 | self.scheduler = kwargs['scheduler'] 23 | self.criterion = torch.nn.CrossEntropyLoss().to(self.args.device) 24 | self.mse = torch.nn.MSELoss().to(self.args.device) 25 | 26 | def info_nce_loss(self, features): 27 | 28 | labels = torch.cat([torch.arange(self.args.batch_size) for i in range(self.args.n_views)], dim=0) 29 | labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float() 30 | labels = labels.to(self.args.device) 31 | 32 | features = F.normalize(features, dim=1) 33 | 34 | similarity_matrix = torch.matmul(features, features.T) 35 | 36 | mask = torch.eye(labels.shape[0], dtype=torch.bool).to(self.args.device) 37 | labels = labels[~mask].view(labels.shape[0], -1) 38 | similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1) 39 | 40 | positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1) 41 | 42 | negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1) 43 | 44 | logits = torch.cat([positives, negatives], dim=1) 45 | labels = torch.zeros(logits.shape[0], dtype=torch.long).to(self.args.device) 46 | 47 | logits = logits / self.args.temperature 48 | return logits, labels 49 | 50 | def micle(self, features, indx): 51 | index = indx.detach().numpy() 52 | unique, counts = np.unique(index, return_counts=True) 53 | count_dict = dict(zip(unique, counts)) 54 | loss = torch.zeros(1).to(self.args.device) 55 | count = 0 56 | len_dict = 0 57 | for key in count_dict: 58 | len_dict +=1 59 | if count_dict[key]>2: 60 | which = np.where(index == key)[0] 61 | mask = torch.tensor(which).to(self.args.device) 62 | 63 | features_ = features[mask] 64 | features_ = F.normalize(features_, dim=1) 65 | 66 | similarity_matrix = torch.matmul(features_, features_.T) 67 | similarity_matrix = F.normalize(similarity_matrix) 68 | mask = torch.eye(similarity_matrix.shape[0], dtype=torch.bool).to(self.args.device) 69 | positive = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1) 70 | 71 | labels = torch.ones(positive.shape,dtype=torch.long).to(self.args.device) 72 | 73 | loss += self.mse(positive.to(torch.float32),labels.to(torch.float32)) 74 | count +=1 75 | if not (count==0): 76 | loss =loss/count 77 | return loss 78 | 79 | def train(self, train_loader): 80 | 81 | scaler = GradScaler(enabled=self.args.fp16_precision) 82 | 83 | 84 | n_iter = 0 85 | 86 | for epoch_counter in range(self.args.epochs): 87 | for images, _ in tqdm(train_loader): 88 | for i in range(images.shape[0]): 89 | aug = medical_aug(images[i]) 90 | images = torch.cat((images,aug.unsqueeze(0)),0) 91 | _ = torch.cat((_,_[i].unsqueeze(1)),0) 92 | images = images.to(self.args.device) 93 | 94 | with autocast(enabled=self.args.fp16_precision): 95 | features = self.model(images) 96 | logits, labels = self.info_nce_loss(features) 97 | loss = self.criterion(logits, labels) 98 | 99 | with autocast(enabled=self.args.fp16_precision): 100 | features = self.model(images) 101 | micle_loss = self.micle(features, _) 102 | 103 | self.optimizer.zero_grad() 104 | overall_loss = loss + micle_loss 105 | 106 | 107 | overall_loss = overall_loss.to(torch.float32) 108 | scaler.scale(overall_loss).backward() 109 | 110 | 
scaler.step(self.optimizer) 111 | scaler.update() 112 | n_iter += 1 113 | 114 | if epoch_counter >= 10: 115 | self.scheduler.step() 116 | 117 | checkpoint_name = 'checkpoint_{:04d}.pth.tar'.format(self.args.epochs) 118 | save_checkpoint({ 119 | 'epoch': self.args.epochs, 120 | 'arch': self.args.arch, 121 | 'state_dict': self.model.state_dict(), 122 | 'optimizer': self.optimizer.state_dict(), 123 | }, filename=checkpoint_name) 124 | -------------------------------------------------------------------------------- /Image_embedder/SimCLR/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | import torch 5 | import yaml 6 | 7 | 8 | def save_checkpoint(state, filename='checkpoint.pth.tar'): 9 | torch.save(state, filename) 10 | 11 | 12 | def save_config_file(model_checkpoints_folder, args): 13 | if not os.path.exists(model_checkpoints_folder): 14 | os.makedirs(model_checkpoints_folder) 15 | with open(os.path.join(model_checkpoints_folder, 'config.yml'), 'w') as outfile: 16 | yaml.dump(args, outfile, default_flow_style=False) 17 | 18 | 19 | def accuracy(output, target, topk=(1,)): 20 | with torch.no_grad(): 21 | maxk = max(topk) 22 | batch_size = target.size(0) 23 | 24 | _, pred = output.topk(maxk, 1, True, True) 25 | pred = pred.t() 26 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 27 | 28 | res = [] 29 | for k in topk: 30 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 31 | res.append(correct_k.mul_(100.0 / batch_size)) 32 | return res 33 | -------------------------------------------------------------------------------- /MultiplexNetwork/__pycache__/embedder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sein-Kim/Multimodal-Medical/c235485673e3f6040ab301c312b17bc62ef720d6/MultiplexNetwork/__pycache__/embedder.cpython-38.pyc -------------------------------------------------------------------------------- /MultiplexNetwork/__pycache__/evaluate.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sein-Kim/Multimodal-Medical/c235485673e3f6040ab301c312b17bc62ef720d6/MultiplexNetwork/__pycache__/evaluate.cpython-38.pyc -------------------------------------------------------------------------------- /MultiplexNetwork/data/abide.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sein-Kim/Multimodal-Medical/c235485673e3f6040ab301c312b17bc62ef720d6/MultiplexNetwork/data/abide.pkl -------------------------------------------------------------------------------- /MultiplexNetwork/embedder.py: -------------------------------------------------------------------------------- 1 | # Code based on https://github.com/pcy1302/DMGI 2 | import time 3 | import numpy as np 4 | import torch 5 | from utils import process 6 | import torch.nn as nn 7 | from layers import AvgReadout 8 | 9 | class embedder: 10 | def __init__(self, args): 11 | args.batch_size = 1 12 | args.sparse = True 13 | args.metapaths_list = args.metapaths.split(",") 14 | args.gpu_num_ = args.gpu_num 15 | if args.gpu_num_ == 'cpu': 16 | args.device = 'cpu' 17 | else: 18 | args.device = torch.device("cuda:" + str(args.gpu_num_) if torch.cuda.is_available() else "cpu") 19 | 20 | adj, features, labels, idx_train, idx_val, idx_test = process.loads(args) 21 | features = [process.preprocess_features(feature) for feature in 
features] 22 | 23 | args.nb_nodes = features[0].shape[0] 24 | args.ft_size = features[0].shape[1] 25 | args.nb_classes = labels.shape[1] 26 | args.nb_graphs = len(adj) 27 | args.adj = adj 28 | adj = [process.normalize_adj(adj_) for adj_ in adj] 29 | self.adj = [process.sparse_mx_to_torch_sparse_tensor(adj_) for adj_ in adj] 30 | 31 | self.features = [torch.FloatTensor(feature[np.newaxis]) for feature in features] 32 | 33 | self.labels = torch.FloatTensor(labels[np.newaxis]).to(args.device) 34 | self.idx_train = torch.LongTensor(idx_train).to(args.device) 35 | self.idx_val = torch.LongTensor(idx_val).to(args.device) 36 | self.idx_test = torch.LongTensor(idx_test).to(args.device) 37 | 38 | self.train_lbls = torch.argmax(self.labels[0, self.idx_train], dim=1) 39 | self.val_lbls = torch.argmax(self.labels[0, self.idx_val], dim=1) 40 | self.test_lbls = torch.argmax(self.labels[0, self.idx_test], dim=1) 41 | 42 | # How to aggregate 43 | args.readout_func = AvgReadout() 44 | 45 | # Summary aggregation 46 | args.readout_act_func = nn.Sigmoid() 47 | 48 | self.args = args 49 | 50 | def currentTime(self): 51 | now = time.localtime() 52 | s = "%04d-%02d-%02d %02d:%02d:%02d" % ( 53 | now.tm_year, now.tm_mon, now.tm_mday, now.tm_hour, now.tm_min, now.tm_sec) 54 | 55 | return s 56 | -------------------------------------------------------------------------------- /MultiplexNetwork/evaluate.py: -------------------------------------------------------------------------------- 1 | # Code based on https://github.com/pcy1302/DMGI 2 | import torch 3 | torch.manual_seed(0) 4 | torch.cuda.manual_seed_all(0) 5 | torch.backends.cudnn.deterministic = True 6 | torch.backends.cudnn.benchmark = False 7 | from models import LogReg 8 | import torch.nn as nn 9 | import numpy as np 10 | np.random.seed(0) 11 | from sklearn.metrics import f1_score 12 | 13 | 14 | def evaluate(embeds, idx_train, idx_val, idx_test, labels, device, isTest=True): 15 | hid_units = embeds.shape[2] 16 | nb_classes = labels.shape[2] 17 | xent = nn.CrossEntropyLoss() 18 | train_embs = embeds[0, idx_train] 19 | val_embs = embeds[0, idx_val] 20 | test_embs = embeds[0, idx_test] 21 | 22 | train_lbls = torch.argmax(labels[0, idx_train], dim=1) 23 | val_lbls = torch.argmax(labels[0, idx_val], dim=1) 24 | test_lbls = torch.argmax(labels[0, idx_test], dim=1) 25 | 26 | accs = [] 27 | micro_f1s = [] 28 | macro_f1s = [] 29 | macro_f1s_val = [] 30 | for _ in range(50): 31 | log = LogReg(hid_units, nb_classes) 32 | opt = torch.optim.Adam(log.parameters(), lr=0.01, weight_decay=0.0) 33 | log.to(device) 34 | 35 | val_accs = []; test_accs = [] 36 | val_micro_f1s = []; test_micro_f1s = [] 37 | val_macro_f1s = []; test_macro_f1s = [] 38 | for iter_ in range(50): 39 | # train 40 | log.train() 41 | opt.zero_grad() 42 | 43 | logits = log(train_embs) 44 | loss = xent(logits, train_lbls) 45 | 46 | loss.backward() 47 | opt.step() 48 | 49 | # val 50 | logits = log(val_embs) 51 | preds = torch.argmax(logits, dim=1) 52 | 53 | val_acc = torch.sum(preds == val_lbls).float() / val_lbls.shape[0] 54 | val_f1_macro = f1_score(val_lbls.cpu(), preds.cpu(), average='macro') 55 | val_f1_micro = f1_score(val_lbls.cpu(), preds.cpu(), average='micro') 56 | 57 | val_accs.append(val_acc.item()) 58 | val_macro_f1s.append(val_f1_macro) 59 | val_micro_f1s.append(val_f1_micro) 60 | 61 | # test 62 | logits = log(test_embs) 63 | preds = torch.argmax(logits, dim=1) 64 | 65 | test_acc = torch.sum(preds == test_lbls).float() / test_lbls.shape[0] 66 | test_f1_macro = f1_score(test_lbls.cpu(), 
preds.cpu(), average='macro') 67 | test_f1_micro = f1_score(test_lbls.cpu(), preds.cpu(), average='micro') 68 | 69 | test_accs.append(test_acc.item()) 70 | test_macro_f1s.append(test_f1_macro) 71 | test_micro_f1s.append(test_f1_micro) 72 | 73 | 74 | max_iter = val_accs.index(max(val_accs)) 75 | accs.append(test_accs[max_iter]) 76 | 77 | max_iter = val_macro_f1s.index(max(val_macro_f1s)) 78 | macro_f1s.append(test_macro_f1s[max_iter]) 79 | macro_f1s_val.append(val_macro_f1s[max_iter]) ### 80 | 81 | max_iter = val_micro_f1s.index(max(val_micro_f1s)) 82 | micro_f1s.append(test_micro_f1s[max_iter]) 83 | if isTest: 84 | print("\t[Classification] Macro-F1: {:.4f} ({:.4f}) | Micro-F1: {:.4f} ({:.4f})".format(np.mean(macro_f1s), 85 | np.std(macro_f1s), 86 | np.mean(micro_f1s), 87 | np.std(micro_f1s))) 88 | print("\t[Maximums] Macro-F1: {:.4f} | Micro-F1: {:.4f} | Test accuracy: {:.4f}".format(np.max(macro_f1s),np.max(micro_f1s),np.max(accs))) 89 | else: 90 | return np.mean(macro_f1s_val), np.mean(macro_f1s) 91 | 92 | test_embs = np.array(test_embs.cpu()) 93 | test_lbls = np.array(test_lbls.cpu()) -------------------------------------------------------------------------------- /MultiplexNetwork/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Code based on https://github.com/pcy1302/DMGI 2 | from .readout import AvgReadout 3 | from .discriminator import Discriminator 4 | from .attention import Attention 5 | 6 | -------------------------------------------------------------------------------- /MultiplexNetwork/layers/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sein-Kim/Multimodal-Medical/c235485673e3f6040ab301c312b17bc62ef720d6/MultiplexNetwork/layers/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /MultiplexNetwork/layers/__pycache__/attention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sein-Kim/Multimodal-Medical/c235485673e3f6040ab301c312b17bc62ef720d6/MultiplexNetwork/layers/__pycache__/attention.cpython-38.pyc -------------------------------------------------------------------------------- /MultiplexNetwork/layers/__pycache__/discriminator.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sein-Kim/Multimodal-Medical/c235485673e3f6040ab301c312b17bc62ef720d6/MultiplexNetwork/layers/__pycache__/discriminator.cpython-38.pyc -------------------------------------------------------------------------------- /MultiplexNetwork/layers/__pycache__/gcn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sein-Kim/Multimodal-Medical/c235485673e3f6040ab301c312b17bc62ef720d6/MultiplexNetwork/layers/__pycache__/gcn.cpython-38.pyc -------------------------------------------------------------------------------- /MultiplexNetwork/layers/__pycache__/readout.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sein-Kim/Multimodal-Medical/c235485673e3f6040ab301c312b17bc62ef720d6/MultiplexNetwork/layers/__pycache__/readout.cpython-38.pyc -------------------------------------------------------------------------------- /MultiplexNetwork/layers/attention.py: 
-------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | class Attention(nn.Module): 7 | def __init__(self, args): 8 | super(Attention, self).__init__() 9 | self.args = args 10 | self.A = nn.ModuleList([nn.Linear(args.hid_units, 1) for _ in range(args.nb_graphs)]) 11 | self.weight_init() 12 | 13 | def weight_init(self): 14 | for i in range(self.args.nb_graphs): 15 | nn.init.xavier_normal_(self.A[i].weight) 16 | self.A[i].bias.data.fill_(0.0) 17 | 18 | def forward(self, feat_pos, feat_neg, summary): 19 | feat_pos, feat_pos_attn, p = self.attn_feature(feat_pos) 20 | feat_neg, feat_neg_attn, p_ = self.attn_feature(feat_neg) 21 | summary, summary_attn = self.attn_summary(summary) 22 | 23 | return feat_pos, feat_neg, summary, p 24 | 25 | 26 | def attn_feature(self, features): 27 | features_attn = [] 28 | for i in range(self.args.nb_graphs): 29 | features_attn.append((self.A[i](features[i].squeeze()))) 30 | features_attn = F.softmax(torch.cat(features_attn, 1), -1) 31 | p = torch.mean(features_attn,dim=0) 32 | features = torch.cat(features,1).squeeze(0) 33 | features_attn_reshaped = features_attn.transpose(1, 0).contiguous().view(-1, 1) 34 | features = features * features_attn_reshaped.expand_as(features) 35 | features = features.view(self.args.nb_graphs, self.args.nb_nodes, self.args.hid_units).sum(0).unsqueeze(0) 36 | 37 | return features, features_attn, p 38 | 39 | def attn_summary(self, features): 40 | features_attn = [] 41 | for i in range(self.args.nb_graphs): 42 | features_attn.append((self.A[i](features[i].squeeze()))) 43 | features_attn = F.softmax(torch.cat(features_attn), dim=-1).unsqueeze(1) 44 | features = torch.cat(features, 0) 45 | features_attn_expanded = features_attn.expand_as(features) 46 | features = (features * features_attn_expanded).sum(0).unsqueeze(0) 47 | 48 | return features, features_attn 49 | -------------------------------------------------------------------------------- /MultiplexNetwork/layers/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | torch.manual_seed(0) 3 | torch.cuda.manual_seed_all(0) 4 | torch.backends.cudnn.deterministic = True 5 | torch.backends.cudnn.benchmark = False 6 | import torch.nn as nn 7 | 8 | 9 | class Discriminator(nn.Module): 10 | def __init__(self, n_h): 11 | super(Discriminator, self).__init__() 12 | self.f_k_bilinear = nn.Bilinear(n_h, n_h, 1) 13 | 14 | for m in self.modules(): 15 | self.weights_init(m) 16 | 17 | def weights_init(self, m): 18 | if isinstance(m, nn.Bilinear): 19 | torch.nn.init.xavier_uniform_(m.weight.data) 20 | if m.bias is not None: 21 | m.bias.data.fill_(0.0) 22 | 23 | def forward(self, c, h_pl, h_mi, s_bias1=None, s_bias2=None): 24 | c_x = torch.unsqueeze(c, 1) # c: summary vector, h_pl: positive, h_mi: negative 25 | c_x = c_x.expand_as(h_pl) 26 | 27 | sc_1 = torch.squeeze(self.f_k_bilinear(h_pl, c_x), 2) # sc_1 = 1 x nb_nodes 28 | sc_2 = torch.squeeze(self.f_k_bilinear(h_mi, c_x), 2) # sc_2 = 1 x nb_nodes 29 | 30 | if s_bias1 is not None: 31 | sc_1 += s_bias1 32 | if s_bias2 is not None: 33 | sc_2 += s_bias2 34 | logits = torch.cat((sc_1, sc_2), 1) 35 | 36 | return logits -------------------------------------------------------------------------------- /MultiplexNetwork/layers/gcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | torch.manual_seed(0) 3 | 
torch.cuda.manual_seed_all(0) 4 | torch.backends.cudnn.deterministic = True 5 | torch.backends.cudnn.benchmark = False 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import pdb 9 | import math 10 | 11 | class GCN(nn.Module): 12 | def __init__(self, in_ft, out_ft, act, drop_prob, isBias=False): 13 | super(GCN, self).__init__() 14 | 15 | self.fc_1 = nn.Linear(in_ft, out_ft, bias=False) 16 | 17 | if act == 'prelu': 18 | self.act = nn.PReLU() 19 | elif act == 'relu': 20 | self.act = nn.ReLU() 21 | elif act == 'leakyrelu': 22 | self.act = nn.LeakyReLU() 23 | elif act == 'relu6': 24 | self.act = nn.ReLU6() 25 | elif act == 'rrelu': 26 | self.act = nn.RReLU() 27 | elif act == 'selu': 28 | self.act = nn.SELU() 29 | elif act == 'celu': 30 | self.act = nn.CELU() 31 | elif act == 'sigmoid': 32 | self.act = nn.Sigmoid() 33 | elif act == 'identity': 34 | self.act = nn.Identity() 35 | 36 | if isBias: 37 | self.bias_1 = nn.Parameter(torch.FloatTensor(out_ft)) 38 | self.bias_1.data.fill_(0.0) 39 | else: 40 | self.register_parameter('bias', None) 41 | 42 | for m in self.modules(): 43 | self.weights_init(m) 44 | 45 | self.drop_prob = drop_prob 46 | self.isBias = isBias 47 | 48 | def weights_init(self, m): 49 | if isinstance(m, nn.Linear): 50 | torch.nn.init.xavier_uniform_(m.weight.data) 51 | if m.bias is not None: 52 | m.bias.data.fill_(0.0) 53 | 54 | # Shape of seq: (batch, nodes, features) 55 | def forward(self, seq, adj, sparse=False): 56 | seq = F.dropout(seq, self.drop_prob, training=self.training) 57 | seq = self.fc_1(seq) 58 | if sparse: 59 | seq = torch.unsqueeze(torch.spmm(adj, torch.squeeze(seq, 0)), 0) 60 | else: 61 | seq = torch.bmm(adj, seq) 62 | 63 | if self.isBias: 64 | seq += self.bias_1 65 | 66 | return self.act(seq) 67 | 68 | -------------------------------------------------------------------------------- /MultiplexNetwork/layers/readout.py: -------------------------------------------------------------------------------- 1 | import torch 2 | torch.manual_seed(0) 3 | torch.cuda.manual_seed_all(0) 4 | torch.backends.cudnn.deterministic = True 5 | torch.backends.cudnn.benchmark = False 6 | import torch.nn as nn 7 | 8 | class AvgReadout(nn.Module): 9 | def __init__(self): 10 | super(AvgReadout, self).__init__() 11 | 12 | def forward(self, seq): 13 | return torch.mean(seq, 1) -------------------------------------------------------------------------------- /MultiplexNetwork/main.py: -------------------------------------------------------------------------------- 1 | # Code based on https://github.com/pcy1302/DMGI 2 | import numpy as np 3 | np.random.seed(0) 4 | import torch 5 | torch.autograd.set_detect_anomaly(True) 6 | torch.manual_seed(0) 7 | torch.cuda.manual_seed_all(0) 8 | torch.backends.cudnn.deterministic = True 9 | torch.backends.cudnn.benchmark = False 10 | import argparse 11 | 12 | def parse_args(): 13 | # input arguments 14 | parser = argparse.ArgumentParser(description='DMGI') 15 | 16 | parser.add_argument('--embedder', nargs='?', default='DMGI') 17 | parser.add_argument('--dataset', nargs='?', default='ADNI') 18 | parser.add_argument('--metapaths', nargs='?', default='type0,type1,type2,type3') 19 | 20 | parser.add_argument('--nb_epochs', type=int, default=10000) 21 | parser.add_argument('--hid_units', type=int, default=64) 22 | parser.add_argument('--lr', type = float, default = 0.0005) 23 | parser.add_argument('--l2_coef', type=float, default=0.0001) 24 | parser.add_argument('--drop_prob', type=float, default=0.5) 25 | parser.add_argument('--reg_coef', 
type=float, default=0.001) 26 | parser.add_argument('--sup_coef', type=float, default=0.1) 27 | parser.add_argument('--sc', type=float, default=3.0, help='GCN self connection') 28 | parser.add_argument('--margin', type=float, default=0.1) 29 | parser.add_argument('--gpu_num', type=int, default=0) 30 | parser.add_argument('--patience', type=int, default=100) 31 | parser.add_argument('--nheads', type=int, default=1) 32 | parser.add_argument('--activation', nargs='?', default='relu') 33 | parser.add_argument('--isSemi', action='store_true', default=True) 34 | parser.add_argument('--isBias', action='store_true', default=False) 35 | parser.add_argument('--isAttn', action='store_true', default=True) 36 | 37 | return parser.parse_known_args() 38 | 39 | 40 | def main(): 41 | args, unknown = parse_args() 42 | from models import DMGI 43 | embedder = DMGI(args) 44 | embedder.training() 45 | 46 | if __name__ == '__main__': 47 | main() 48 | -------------------------------------------------------------------------------- /MultiplexNetwork/models/DMGI.py: -------------------------------------------------------------------------------- 1 | # Code based on https://github.com/pcy1302/DMGI/blob/master/models/DMGI.py 2 | 3 | import torch 4 | 5 | torch.manual_seed(0) 6 | torch.cuda.manual_seed_all(0) 7 | 8 | torch.backends.cudnn.deterministic = True 9 | torch.backends.cudnn.benchmark = False 10 | import torch.nn as nn 11 | from embedder import embedder 12 | from layers import GCN, Discriminator, Attention 13 | import numpy as np 14 | np.random.seed(0) 15 | 16 | from evaluate import evaluate 17 | from models import LogReg 18 | import pickle as pkl 19 | from tqdm import trange 20 | 21 | 22 | class DMGI(embedder): 23 | def __init__(self, args): 24 | embedder.__init__(self, args) 25 | self.args = args 26 | 27 | def training(self): 28 | features = [feature.to(self.args.device) for feature in self.features] 29 | adj = [adj_.to(self.args.device) for adj_ in self.adj] 30 | model = modeler(self.args).to(self.args.device) 31 | optimiser = torch.optim.Adam(model.parameters(), lr=self.args.lr, weight_decay=self.args.l2_coef) 32 | cnt_wait = 0; best = 1e9 33 | b_xent = nn.BCEWithLogitsLoss() 34 | xent = nn.CrossEntropyLoss() 35 | for epoch in trange(self.args.nb_epochs): 36 | xent_loss = None 37 | model.train() 38 | optimiser.zero_grad() 39 | idx = np.random.permutation(self.args.nb_nodes) 40 | 41 | shuf = [feature[:, idx, :] for feature in features] 42 | shuf = [shuf_ft.to(self.args.device) for shuf_ft in shuf] 43 | 44 | lbl_1 = torch.ones(self.args.batch_size, self.args.nb_nodes) 45 | lbl_2 = torch.zeros(self.args.batch_size, self.args.nb_nodes) 46 | lbl = torch.cat((lbl_1, lbl_2), 1).to(self.args.device) 47 | 48 | result = model(features, adj, shuf, self.args.sparse, None, None, None) 49 | logits = result['logits'] 50 | 51 | for view_idx, logit in enumerate(logits): 52 | if xent_loss is None: 53 | xent_loss = b_xent(logit, lbl) 54 | else: 55 | xent_loss += b_xent(logit, lbl) 56 | 57 | loss = xent_loss 58 | 59 | reg_loss = result['reg_loss'] 60 | loss += self.args.reg_coef * reg_loss 61 | 62 | if self.args.isSemi: 63 | sup = result['semi'] 64 | semi_loss = xent(sup[self.idx_train], self.train_lbls) 65 | loss += self.args.sup_coef * semi_loss 66 | 67 | if loss < best: 68 | best = loss 69 | cnt_wait = 0 70 | torch.save(model.state_dict(), 'saved_model/best_{}_{}_{}.pkl'.format(self.args.dataset, self.args.embedder, self.args.metapaths)) 71 | else: 72 | cnt_wait += 1 73 | 74 | if cnt_wait == self.args.patience: 75 | break 
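# Early stopping: the checkpoint above keeps the parameters with the lowest total
# loss seen so far, and training halts once `patience` consecutive epochs bring no
# improvement. Because `break` fires before the lines below, the epoch that
# exhausts patience performs no backward()/step().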
76 | 77 | loss.backward() 78 | optimiser.step() 79 | 80 | model.load_state_dict(torch.load('saved_model/best_{}_{}_{}.pkl'.format(self.args.dataset, self.args.embedder, self.args.metapaths))) 81 | 82 | # Evaluation 83 | model.eval() 84 | evaluate(model.H.data.detach(), self.idx_train, self.idx_val, self.idx_test, self.labels, self.args.device) 85 | # evaluate(result['h1'].data.detach(), self.idx_train, self.idx_val, self.idx_test, self.labels, self.args.device) 86 | 87 | 88 | class modeler(nn.Module): 89 | def __init__(self, args): 90 | super(modeler, self).__init__() 91 | self.args = args 92 | self.gcn = nn.ModuleList([GCN(args.ft_size, args.hid_units, args.activation, args.drop_prob, args.isBias) for _ in range(args.nb_graphs)]) 93 | 94 | self.disc = Discriminator(args.hid_units) 95 | self.H = nn.Parameter(torch.FloatTensor(1, args.nb_nodes, args.hid_units)) 96 | self.readout_func = self.args.readout_func 97 | if args.isAttn: 98 | self.attn = nn.ModuleList([Attention(args) for _ in range(args.nheads)]) 99 | 100 | if args.isSemi: 101 | self.logistic = LogReg(args.hid_units, args.nb_classes).to(args.device) 102 | 103 | self.init_weight() 104 | 105 | def init_weight(self): 106 | nn.init.xavier_normal_(self.H) 107 | 108 | def forward(self, feature, adj, shuf, sparse, msk, samp_bias1, samp_bias2): 109 | h_1_all = []; h_2_all = []; c_all = []; logits = [] 110 | result = {} 111 | for i in range(self.args.nb_graphs): 112 | h_1 = self.gcn[i](feature[i], adj[i], sparse) 113 | c = self.readout_func(h_1) 114 | c = self.args.readout_act_func(c) 115 | h_2 = self.gcn[i](shuf[i], adj[i], sparse) 116 | logit = self.disc(c, h_1, h_2, samp_bias1, samp_bias2) 117 | h_1_all.append(h_1) 118 | h_2_all.append(h_2) 119 | c_all.append(c) 120 | logits.append(logit) 121 | result['logits'] = logits 122 | if self.args.isAttn: 123 | h_1_all_lst = []; h_2_all_lst = []; c_all_lst = [] 124 | for h_idx in range(self.args.nheads): 125 | h_1_all_, h_2_all_, c_all_, p = self.attn[h_idx](h_1_all, h_2_all, c_all) 126 | h_1_all_lst.append(h_1_all_); h_2_all_lst.append(h_2_all_); c_all_lst.append(c_all_) 127 | h_1_all = torch.mean(torch.cat(h_1_all_lst, 0), 0).unsqueeze(0) 128 | h_2_all = torch.mean(torch.cat(h_2_all_lst, 0), 0).unsqueeze(0) 129 | else: 130 | h_1_all = torch.mean(torch.cat(h_1_all), 0).unsqueeze(0) 131 | h_2_all = torch.mean(torch.cat(h_2_all), 0).unsqueeze(0) 132 | pos_reg_loss = ((self.H - h_1_all) ** 2).sum() 133 | neg_reg_loss = ((self.H - h_2_all) ** 2).sum() 134 | reg_loss = pos_reg_loss - neg_reg_loss 135 | result['reg_loss'] = reg_loss 136 | if self.args.isSemi: 137 | semi = self.logistic(self.H).squeeze(0) 138 | # semi = self.logistic(h_1_all).squeeze(0) 139 | result['h1'] = h_1_all 140 | result['semi'] = semi 141 | return result -------------------------------------------------------------------------------- /MultiplexNetwork/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .logreg import LogReg 2 | from .DMGI import DMGI 3 | -------------------------------------------------------------------------------- /MultiplexNetwork/models/__pycache__/DMGI.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sein-Kim/Multimodal-Medical/c235485673e3f6040ab301c312b17bc62ef720d6/MultiplexNetwork/models/__pycache__/DMGI.cpython-38.pyc -------------------------------------------------------------------------------- /MultiplexNetwork/models/__pycache__/__init__.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sein-Kim/Multimodal-Medical/c235485673e3f6040ab301c312b17bc62ef720d6/MultiplexNetwork/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /MultiplexNetwork/models/__pycache__/logreg.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sein-Kim/Multimodal-Medical/c235485673e3f6040ab301c312b17bc62ef720d6/MultiplexNetwork/models/__pycache__/logreg.cpython-38.pyc -------------------------------------------------------------------------------- /MultiplexNetwork/models/logreg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | torch.manual_seed(0) 3 | torch.cuda.manual_seed_all(0) 4 | torch.backends.cudnn.deterministic = True 5 | torch.backends.cudnn.benchmark = False 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | class LogReg(nn.Module): 10 | def __init__(self, ft_in, nb_classes): 11 | super(LogReg, self).__init__() 12 | self.fc = nn.Linear(ft_in, nb_classes) 13 | 14 | for m in self.modules(): 15 | self.weights_init(m) 16 | 17 | def weights_init(self, m): 18 | if isinstance(m, nn.Linear): 19 | torch.nn.init.xavier_uniform_(m.weight.data) 20 | if m.bias is not None: 21 | m.bias.data.fill_(0.0) 22 | 23 | def forward(self, seq): 24 | ret = self.fc(seq) 25 | return ret 26 | 27 | -------------------------------------------------------------------------------- /MultiplexNetwork/saved_model/best_abide_DMGI_type0,type1,type2,type3.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sein-Kim/Multimodal-Medical/c235485673e3f6040ab301c312b17bc62ef720d6/MultiplexNetwork/saved_model/best_abide_DMGI_type0,type1,type2,type3.pkl -------------------------------------------------------------------------------- /MultiplexNetwork/saved_model/best_cmmd_DMGI_type0,type1,type2,type3.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sein-Kim/Multimodal-Medical/c235485673e3f6040ab301c312b17bc62ef720d6/MultiplexNetwork/saved_model/best_cmmd_DMGI_type0,type1,type2,type3.pkl -------------------------------------------------------------------------------- /MultiplexNetwork/saved_model/best_cmmd_train60_0806_DMGI_type0,type1,type2,type3.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sein-Kim/Multimodal-Medical/c235485673e3f6040ab301c312b17bc62ef720d6/MultiplexNetwork/saved_model/best_cmmd_train60_0806_DMGI_type0,type1,type2,type3.pkl -------------------------------------------------------------------------------- /MultiplexNetwork/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sein-Kim/Multimodal-Medical/c235485673e3f6040ab301c312b17bc62ef720d6/MultiplexNetwork/utils/__init__.py -------------------------------------------------------------------------------- /MultiplexNetwork/utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sein-Kim/Multimodal-Medical/c235485673e3f6040ab301c312b17bc62ef720d6/MultiplexNetwork/utils/__pycache__/__init__.cpython-38.pyc 
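`LogReg` above is the linear probe that `evaluate.py` retrains on frozen embeddings (50 runs of 50 iterations each). Below is a minimal usage sketch, not code from the repository: the 64-dim embeddings, 2 classes, and random tensors are placeholder assumptions, and it assumes execution from the MultiplexNetwork directory so that `models` is importable.

import torch
import torch.nn as nn

from models import LogReg

embs = torch.randn(100, 64)         # stand-in for frozen node embeddings (hypothetical shape)
lbls = torch.randint(0, 2, (100,))  # stand-in class labels

probe = LogReg(64, 2)               # same constructor as evaluate.py: LogReg(hid_units, nb_classes)
opt = torch.optim.Adam(probe.parameters(), lr=0.01, weight_decay=0.0)
xent = nn.CrossEntropyLoss()

for _ in range(50):                 # iteration budget used in evaluate.py
    probe.train()
    opt.zero_grad()
    loss = xent(probe(embs), lbls)  # train only the linear head on frozen features
    loss.backward()
    opt.step()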
-------------------------------------------------------------------------------- /MultiplexNetwork/utils/__pycache__/process.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sein-Kim/Multimodal-Medical/c235485673e3f6040ab301c312b17bc62ef720d6/MultiplexNetwork/utils/__pycache__/process.cpython-38.pyc -------------------------------------------------------------------------------- /MultiplexNetwork/utils/process.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle as pkl 3 | import networkx as nx 4 | import scipy.sparse as sp 5 | import sys 6 | import torch 7 | import torch.nn as nn 8 | import scipy.io as sio 9 | import pdb 10 | 11 | def loads(args): 12 | dataset = args.dataset 13 | metapaths = args.metapaths_list 14 | sc = args.sc 15 | 16 | data = pkl.load(open('data/{}.pkl'.format(dataset), "rb")) 17 | label = data['label'] 18 | N = label.shape[0] 19 | 20 | truefeatures = data['feature'].astype(float) 21 | rownetworks = [data[metapath] + np.eye(N)*sc for metapath in metapaths] 22 | 23 | rownetworks = [sp.csr_matrix(rownetwork) for rownetwork in rownetworks] 24 | 25 | truefeatures = sp.lil_matrix(truefeatures) 26 | 27 | idx_train = data['train_idx'].ravel() 28 | idx_val = data['val_idx'].ravel() 29 | idx_test = data['test_idx'].ravel() 30 | 31 | truefeatures_list = [] 32 | for _ in range(len(rownetworks)): 33 | truefeatures_list.append(truefeatures) 34 | 35 | return rownetworks, truefeatures_list, label, idx_train, idx_val, idx_test 36 | 37 | def parse_skipgram(fname): 38 | with open(fname) as f: 39 | toks = list(f.read().split()) 40 | nb_nodes = int(toks[0]) 41 | nb_features = int(toks[1]) 42 | ret = np.empty((nb_nodes, nb_features)) 43 | it = 2 44 | for i in range(nb_nodes): 45 | cur_nd = int(toks[it]) - 1 46 | it += 1 47 | for j in range(nb_features): 48 | cur_ft = float(toks[it]) 49 | ret[cur_nd][j] = cur_ft 50 | it += 1 51 | return ret 52 | 53 | def accuracy(output, labels): 54 | preds = output.max(1)[1].type_as(labels) 55 | correct = preds.eq(labels).double() 56 | correct = correct.sum() 57 | return correct / len(labels) 58 | 59 | def adj_to_bias(adj, sizes, nhood=1): 60 | nb_graphs = adj.shape[0] 61 | mt = np.empty(adj.shape) 62 | for g in range(nb_graphs): 63 | mt[g] = np.eye(adj.shape[1]) 64 | for _ in range(nhood): 65 | mt[g] = np.matmul(mt[g], (adj[g] + np.eye(adj.shape[1]))) 66 | for i in range(sizes[g]): 67 | for j in range(sizes[g]): 68 | if mt[g][i][j] > 0.0: 69 | mt[g][i][j] = 1.0 70 | return -1e9 * (1.0 - mt) 71 | 72 | def sample_mask(idx, l): 73 | """Create mask.""" 74 | mask = np.zeros(l) 75 | mask[idx] = 1 76 | return np.array(mask, dtype=bool) # builtin bool: np.bool is removed in NumPy >= 1.24 77 | 78 | def sparse_to_tuple(sparse_mx, insert_batch=False): 79 | def to_tuple(mx): 80 | if not sp.isspmatrix_coo(mx): 81 | mx = mx.tocoo() 82 | if insert_batch: 83 | coords = np.vstack((np.zeros(mx.row.shape[0]), mx.row, mx.col)).transpose() 84 | values = mx.data 85 | shape = (1,) + mx.shape 86 | else: 87 | coords = np.vstack((mx.row, mx.col)).transpose() 88 | values = mx.data 89 | shape = mx.shape 90 | return coords, values, shape 91 | 92 | if isinstance(sparse_mx, list): 93 | for i in range(len(sparse_mx)): 94 | sparse_mx[i] = to_tuple(sparse_mx[i]) 95 | else: 96 | sparse_mx = to_tuple(sparse_mx) 97 | 98 | return sparse_mx 99 | 100 | def preprocess_features(features): 101 | rowsum = np.array(features.sum(1)) 102 | r_inv = np.power(rowsum, -1).flatten() 103 |
r_inv[np.isinf(r_inv)] = 0. 104 | r_mat_inv = sp.diags(r_inv) 105 | features = r_mat_inv.dot(features) 106 | return features.todense() 107 | 108 | def normalize_adj(adj): 109 | adj = sp.coo_matrix(adj) 110 | rowsum = np.array(adj.sum(1)) 111 | d_inv_sqrt = np.power(rowsum, -0.5).flatten() 112 | d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0. 113 | d_mat_inv_sqrt = sp.diags(d_inv_sqrt) 114 | return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo() 115 | 116 | def preprocess_adj(adj): 117 | adj_normalized = normalize_adj(adj + sp.eye(adj.shape[0])) 118 | return sparse_to_tuple(adj_normalized) 119 | 120 | def sparse_mx_to_torch_sparse_tensor(sparse_mx): 121 | sparse_mx = sparse_mx.tocoo().astype(np.float32) 122 | indices = torch.from_numpy( 123 | np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)) 124 | values = torch.from_numpy(sparse_mx.data) 125 | shape = torch.Size(sparse_mx.shape) 126 | return torch.sparse.FloatTensor(indices, values, shape) -------------------------------------------------------------------------------- /Preprocessing/Image_preprocessing/abide_atlas.py: -------------------------------------------------------------------------------- 1 | from re import S 2 | import SimpleITK as sitk 3 | import matplotlib.pyplot as plt 4 | import time 5 | from tqdm.notebook import tqdm 6 | import random 7 | data_path_txt = './../../medical_dataset/ABIDE/all_paths.txt' 8 | data_path = './../../medical_dataset/ABIDE/image_2D/' 9 | 10 | f = open(data_path_txt,'r') 11 | lines = f.readlines() 12 | all_data = [] 13 | for line in lines: 14 | name, path_ = line[:-1].split('\t') 15 | all_data.append([int(name),path_]) 16 | f.close() 17 | 18 | labels = [] 19 | for data in all_data: 20 | labels.append(data[0]) 21 | 22 | i = 0 23 | atlas = sitk.ReadImage('C:/Users/user/Desktop/abide/ABIDE/50002/MP-RAGE/2000-01-01_00_00_00.0/S164623/ABIDE_50002_MRI_MP-RAGE_br_raw_20120830172854796_S164623_I328631.nii') 24 | elastixImageFilter = sitk.ElastixImageFilter() 25 | elastixImageFilter.SetFixedImage(atlas) 26 | thres=-1 27 | dict_data_count = {} 28 | 29 | for data in all_data: 30 | if i >thres: 31 | time_1 = time.time() 32 | read_sitk = sitk.ReadImage(data[1]) 33 | elastixImageFilter.SetMovingImage(read_sitk) 34 | elastixImageFilter.SetParameterMap(sitk.GetDefaultParameterMap('translation')) 35 | elastixImageFilter.Execute() 36 | read_sitk = elastixImageFilter.GetResultImage() 37 | 38 | img_vol = sitk.GetArrayFromImage(read_sitk) 39 | if img_vol.shape[0] >=100: 40 | side = [i+90 for i in range(40)] 41 | front = [i+100 for i in range(50)] 42 | cross = [i+90 for i in range(40)] 43 | random.shuffle(side) 44 | random.shuffle(front) 45 | random.shuffle(cross) 46 | image_2D_path = data_path +str(int(data[0])) +'_' +str(i) +'_' 47 | if not (data[0] in list(dict_data_count.keys())): 48 | dict_data_count[data[0]] = 1 49 | for s in side[:20]: 50 | save_path = image_2D_path+ 'side' + str(s) +'.png' 51 | plt.imshow(img_vol[:,:,s], cmap='gray') 52 | plt.axis('off') 53 | # plt.show 54 | plt.savefig(save_path, bbox_inches='tight',pad_inches = 0) 55 | del save_path 56 | plt.close('all') 57 | plt.clf() 58 | ################## 59 | for f in front[:20]: 60 | save_path = image_2D_path+ 'front' + str(f) +'.png' 61 | plt.imshow(img_vol[:,f], cmap='gray') 62 | plt.axis('off') 63 | # plt.show 64 | plt.savefig(save_path, bbox_inches='tight',pad_inches = 0) 65 | del save_path 66 | plt.close('all') 67 | plt.clf() 68 | ############################ 69 | for c in cross[:20]: 70 | save_path = image_2D_path+ 'cross' + str(c) 
+'.png' 71 | plt.imshow(img_vol[50], cmap='gray') 72 | plt.axis('off') 73 | # plt.show 74 | plt.savefig(save_path, bbox_inches='tight',pad_inches = 0) 75 | del save_path 76 | plt.close('all') 77 | plt.clf() 78 | 79 | del read_sitk 80 | del img_vol 81 | del image_2D_path 82 | i+=1 -------------------------------------------------------------------------------- /Preprocessing/Image_preprocessing/adni_atlas.py: -------------------------------------------------------------------------------- 1 | from re import S 2 | import SimpleITK as sitk 3 | import matplotlib.pyplot as plt 4 | import time 5 | from tqdm.notebook import tqdm 6 | import random 7 | data_path_txt = './../../medical_dataset/ADNI/all_paths.txt' 8 | data_path = './../../medical_dataset/ADNI/image_2D/' 9 | 10 | f = open(data_path_txt,'r') 11 | lines = f.readlines() 12 | all_data = [] 13 | for line in lines: 14 | name, path_ = line[:-1].split('\t') 15 | all_data.append([int(name),path_]) 16 | f.close() 17 | 18 | labels = [] 19 | for data in all_data: 20 | labels.append(data[0]) 21 | 22 | i = 0 23 | atlas = sitk.ReadImage('C:/Users/user/Desktop/ADNI1_Complete 1Yr 3T/ADNI/002_S_0413/MPR____N3__Scaled/2006-05-19_16_17_47.0/I40657/ADNI_002_S_0413_MR_MPR____N3__Scaled_Br_20070216232854688_S14782_I40657.nii') 24 | elastixImageFilter = sitk.ElastixImageFilter() 25 | elastixImageFilter.SetFixedImage(atlas) 26 | thres=-1 27 | dict_data_count = {} 28 | 29 | for data in all_data: 30 | if i >thres: 31 | time_1 = time.time() 32 | read_sitk = sitk.ReadImage(data[1]) 33 | elastixImageFilter.SetMovingImage(read_sitk) 34 | elastixImageFilter.SetParameterMap(sitk.GetDefaultParameterMap('translation')) 35 | elastixImageFilter.Execute() 36 | read_sitk = elastixImageFilter.GetResultImage() 37 | 38 | img_vol = sitk.GetArrayFromImage(read_sitk) 39 | if img_vol.shape[0] >=100: 40 | side = [i+90 for i in range(40)] 41 | front = [i+100 for i in range(50)] 42 | cross = [i+90 for i in range(40)] 43 | random.shuffle(side) 44 | random.shuffle(front) 45 | random.shuffle(cross) 46 | image_2D_path = data_path +str(int(data[0])) +'_' +str(i) +'_' 47 | if not (data[0] in list(dict_data_count.keys())): 48 | dict_data_count[data[0]] = 1 49 | for s in side[:20]: 50 | save_path = image_2D_path+ 'side' + str(s) +'.png' 51 | plt.imshow(img_vol[:,:,s], cmap='gray') 52 | plt.axis('off') 53 | # plt.show 54 | plt.savefig(save_path, bbox_inches='tight',pad_inches = 0) 55 | del save_path 56 | plt.close('all') 57 | plt.clf() 58 | ################## 59 | for f in front[:20]: 60 | save_path = image_2D_path+ 'front' + str(f) +'.png' 61 | plt.imshow(img_vol[:,f], cmap='gray') 62 | plt.axis('off') 63 | # plt.show 64 | plt.savefig(save_path, bbox_inches='tight',pad_inches = 0) 65 | del save_path 66 | plt.close('all') 67 | plt.clf() 68 | ############################ 69 | for c in cross[:20]: 70 | save_path = image_2D_path+ 'cross' + str(c) +'.png' 71 | plt.imshow(img_vol[50], cmap='gray') 72 | plt.axis('off') 73 | # plt.show 74 | plt.savefig(save_path, bbox_inches='tight',pad_inches = 0) 75 | del save_path 76 | plt.close('all') 77 | plt.clf() 78 | 79 | del read_sitk 80 | del img_vol 81 | del image_2D_path 82 | i+=1 -------------------------------------------------------------------------------- /Preprocessing/Image_preprocessing/cmmd_save.py: -------------------------------------------------------------------------------- 1 | import SimpleITK as sitk 2 | import numpy as np 3 | import time 4 | import matplotlib.pyplot as plt 5 | 6 | data_path_txt = './../../medical_dataset/CMMD/all_paths' 
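# Assumed input format, inferred from the split('\t') parsing in the loop below:
# each line of the paths file holds "<patient id>\t<image file path>".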
7 | data_path = './../../medical_dataset/CMMD/image_2D/' 8 | f = open(data_path_txt,'r') 9 | lines = f.readlines() 10 | 11 | i = 0 12 | for line in lines: 13 | time_1 = time.time() 14 | ids_ = line.split('\t')[0] 15 | filename = line.split('\t')[1][:-1] 16 | 17 | image_2D_path = data_path + ids_ + '_'+str(i) + '.png' 18 | images = sitk.ReadImage(filename) 19 | images_array = sitk.GetArrayFromImage(images).astype('float32') 20 | img = np.squeeze(images_array) 21 | copy_img = img.copy() 22 | min = np.min(copy_img) 23 | max = np.max(copy_img) 24 | 25 | copy_img1 = copy_img - np.min(copy_img) 26 | copy_img = copy_img1/np.max(copy_img1) 27 | copy_img *= 2**8-1 28 | copy_img = copy_img.astype(np.uint8) 29 | plt.imshow(copy_img, cmap='gray') 30 | plt.axis('off') 31 | plt.savefig(image_2D_path, bbox_inches='tight',pad_inches = 0) 32 | plt.clf() 33 | 34 | del copy_img 35 | del copy_img1 36 | del images 37 | del images_array 38 | del img 39 | 40 | 41 | 42 | i+=1 -------------------------------------------------------------------------------- /Preprocessing/Image_preprocessing/duke_save.py: -------------------------------------------------------------------------------- 1 | import SimpleITK as sitk 2 | import numpy as np 3 | import time 4 | import matplotlib.pyplot as plt 5 | 6 | 7 | data_path_txt = './../../medical_dataset/QIN/all_paths' 8 | data_path = './../../medical_dataset/QIN/image_2D/' 9 | f = open(data_path_txt,'r') 10 | lines = f.readlines() 11 | 12 | i = 0 13 | for line in lines: 14 | time_1 = time.time() 15 | ids_ = line.split('\t')[0] 16 | filename = line.split('\t')[1][:-1] 17 | 18 | image_2D_path = data_path + ids_ + '_'+str(i) + '.png' 19 | images = sitk.ReadImage(filename) 20 | images_array = sitk.GetArrayFromImage(images).astype('float32') 21 | img = np.squeeze(images_array) 22 | copy_img = img.copy() 23 | min = np.min(copy_img) 24 | max = np.max(copy_img) 25 | 26 | copy_img1 = copy_img - np.min(copy_img) 27 | copy_img = copy_img1/np.max(copy_img1) 28 | copy_img *= 2**8-1 29 | copy_img = copy_img.astype(np.uint8) 30 | plt.imshow(copy_img, cmap='gray') 31 | plt.axis('off') 32 | plt.savefig(image_2D_path, bbox_inches='tight',pad_inches = 0) 33 | plt.clf() 34 | 35 | del copy_img 36 | del copy_img1 37 | del images 38 | del images_array 39 | del img 40 | 41 | 42 | 43 | i+=1 44 | -------------------------------------------------------------------------------- /Preprocessing/Image_preprocessing/oasis_atlas.py: -------------------------------------------------------------------------------- 1 | from re import S 2 | import SimpleITK as sitk 3 | import matplotlib.pyplot as plt 4 | import time 5 | from tqdm.notebook import tqdm 6 | import random 7 | data_path_txt = './../../medical_dataset/OASIS/all_paths.txt' 8 | data_path = './../../medical_dataset/OASIS/image_2D/' 9 | 10 | f = open(data_path_txt,'r') 11 | lines = f.readlines() 12 | all_data = [] 13 | for line in lines: 14 | name, path_ = line[:-1].split('\t') 15 | all_data.append([int(name),path_]) 16 | f.close() 17 | 18 | labels = [] 19 | for data in all_data: 20 | labels.append(data[0]) 21 | 22 | i = 0 23 | atlas = sitk.ReadImage('C:/Users/user/Desktop/images/OAS30936_d6483_2.nii.gz') 24 | elastixImageFilter = sitk.ElastixImageFilter() 25 | elastixImageFilter.SetFixedImage(atlas) 26 | thres=-1 27 | dict_data_count = {} 28 | 29 | for data in all_data: 30 | if i >thres: 31 | time_1 = time.time() 32 | read_sitk = sitk.ReadImage(data[1]) 33 | elastixImageFilter.SetMovingImage(read_sitk) 34 | 
elastixImageFilter.SetParameterMap(sitk.GetDefaultParameterMap('translation')) 35 | elastixImageFilter.Execute() 36 | read_sitk = elastixImageFilter.GetResultImage() 37 | 38 | img_vol = sitk.GetArrayFromImage(read_sitk) 39 | if img_vol.shape[0] >=100: 40 | side = [i+90 for i in range(40)] 41 | front = [i+100 for i in range(50)] 42 | cross = [i+90 for i in range(40)] 43 | random.shuffle(side) 44 | random.shuffle(front) 45 | random.shuffle(cross) 46 | image_2D_path = data_path +str(int(data[0])) +'_' +str(i) +'_' 47 | if not (data[0] in list(dict_data_count.keys())): 48 | dict_data_count[data[0]] = 1 49 | for s in side[:20]: 50 | save_path = image_2D_path+ 'side' + str(s) +'.png' 51 | plt.imshow(img_vol[:,:,s], cmap='gray') 52 | plt.axis('off') 53 | # plt.show 54 | plt.savefig(save_path, bbox_inches='tight',pad_inches = 0) 55 | del save_path 56 | plt.close('all') 57 | plt.clf() 58 | ################## 59 | for f in front[:20]: 60 | save_path = image_2D_path+ 'front' + str(f) +'.png' 61 | plt.imshow(img_vol[:,f], cmap='gray') 62 | plt.axis('off') 63 | # plt.show 64 | plt.savefig(save_path, bbox_inches='tight',pad_inches = 0) 65 | del save_path 66 | plt.close('all') 67 | plt.clf() 68 | ############################ 69 | for c in cross[:20]: 70 | save_path = image_2D_path+ 'cross' + str(c) +'.png' 71 | plt.imshow(img_vol[50], cmap='gray') 72 | plt.axis('off') 73 | # plt.show 74 | plt.savefig(save_path, bbox_inches='tight',pad_inches = 0) 75 | del save_path 76 | plt.close('all') 77 | plt.clf() 78 | 79 | del read_sitk 80 | del img_vol 81 | del image_2D_path 82 | i+=1 -------------------------------------------------------------------------------- /Preprocessing/Non_image_preprocessing/abide_kmeans.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import pickle 4 | import argparse 5 | 6 | def parse_args(): 7 | parser = argparse.ArgumentParser(description='abide_kmeans') 8 | parser.add_argument('--K', type=int, default=4) 9 | parser.add_argument('--thres',type=str, default='0.9,0.9,0.9,0.9') 10 | return parser.parse_known_args() 11 | args, unknown = parse_args() 12 | K = args.K 13 | clinical = pd.read_csv('./../non-image/ABIDE/FinalMerged_MM.csv') 14 | 15 | clinical.drop(columns=['SITE_ID'], inplace=True) 16 | clinical.drop(columns = ['FIQ_TEST_TYPE', 'VIQ_TEST_TYPE', 'PIQ_TEST_TYPE'], inplace= True) 17 | clinical['HANDEDNESS_CATEGORY']= clinical['HANDEDNESS_CATEGORY'].fillna(clinical['HANDEDNESS_CATEGORY'].mode()[0]) 18 | clinical = clinical.fillna(-9999) 19 | 20 | use_l = ['SUB_ID', 21 | 'DX_GROUP', 22 | 'AGE_AT_SCAN', 23 | 'SEX', 24 | 'HANDEDNESS_CATEGORY', 25 | 'HANDEDNESS_SCORES', 26 | 'FIQ', 27 | 'VIQ', 28 | 'PIQ', 29 | 'ADOS_STEREO_BEHAV', 30 | 'ADOS_GOTHAM_SOCAFFECT', 31 | 'ADOS_GOTHAM_RRB', 32 | 'ADOS_GOTHAM_TOTAL', 33 | 'ADOS_GOTHAM_SEVERITY', 34 | 'SRS_RAW_TOTAL', 35 | 'EYE_STATUS_AT_SCAN'] 36 | clinical_dum = clinical[use_l] 37 | df= clinical_dum.drop(columns = ['SUB_ID','DX_GROUP']) 38 | normalized_df=((df-df.mean())/df.std()).fillna(0) 39 | feature_dict = {} 40 | label_dict = {} 41 | i = 0 42 | sub_id = clinical_dum['SUB_ID'] 43 | use_clinical_dummy_no = normalized_df 44 | 45 | for sub in sub_id: 46 | feature_dict[sub] = use_clinical_dummy_no.iloc[i].to_numpy() 47 | label_dict[sub] = clinical_dum['DX_GROUP'].iloc[i] 48 | i+=1 49 | non_img = {'label': label_dict, 'feature': feature_dict} 50 | with open('./../non_image/ABIDE/abide_nonimg.pkl', 'wb') as f: 51 | pickle.dump(non_img, f, 
pickle.HIGHEST_PROTOCOL) 52 | 53 | path_ = './../SimCLR/extracted_feature/abide/' 54 | train_feature = np.loadtxt('./' + path_ + 'train_feature.csv',delimiter=',',dtype=np.float32) 55 | # valid_feature = np.loadtxt('./' + path_ + 'valid_feature.csv',delimiter=',',dtype=np.float32) 56 | test_feature = np.loadtxt('./' + path_ + 'test_feature.csv',delimiter=',',dtype=np.float32) 57 | 58 | train_id = pd.read_csv('./' + path_ +'train_id.csv', header=None) 59 | # valid_id = pd.read_csv('./' + path_ +'valid_id.csv', header=None) 60 | test_id = pd.read_csv('./' + path_ +'test_id.csv', header=None) 61 | 62 | with open('./../non-image/ABIDE/abide_nonimg.pkl', 'rb') as fr: 63 | data = pickle.load(fr) 64 | before_adj = [] 65 | labels= [] 66 | train_index = [] 67 | valid_index = [] 68 | test_index = [] 69 | 70 | 71 | all_feature = [] 72 | concate_feature = []#brain feature + patient feature 73 | k = 0 74 | id_list = [] 75 | for i in range(len(train_id)): 76 | a = int(str(train_id[0][i])[:5]) 77 | id_list.append(a) 78 | l = data['label'][a] 79 | l_ = data['feature'][a] 80 | # l = rid_label_dict[id_[i]] 81 | train_index.append(k) 82 | 83 | labels.append(l) 84 | before_adj.append(l_) 85 | all_feature.append(list(train_feature[k])) 86 | concate_feature.append(list(train_feature[k]) + list(l_)) 87 | k+=1 88 | 89 | k_test = 0 90 | for i in range(len(test_id)): 91 | a = int(str(test_id[0][i])[:5]) 92 | id_list.append(a) 93 | l = data['label'][a] 94 | l_ = data['feature'][a] 95 | # l = rid_label_dict[id_[i]] 96 | test_index.append(k) 97 | 98 | labels.append(l) 99 | before_adj.append(l_) 100 | all_feature.append(list(test_feature[k_test])) 101 | concate_feature.append(list(test_feature[k_test]) + list(l_)) 102 | k+=1 103 | k_test +=1 104 | 105 | 106 | modi_label = [] 107 | for i in labels: 108 | if i == 1: 109 | modi_label.append(0) 110 | else: 111 | modi_label.append(1) 112 | 113 | 114 | labels = modi_label 115 | 116 | indexes = train_index + valid_index + test_index 117 | train_ = len(indexes[:int(len(indexes)*0.6)]) 118 | valid_ = len(indexes[int(len(indexes)*0.6):int(len(indexes)*0.7)]) 119 | test_ = len(indexes[int(len(indexes)*0.7):]) 120 | 121 | import random 122 | random.shuffle(indexes) 123 | train_index = indexes[:train_] 124 | valid_index = indexes[train_:train_+valid_] 125 | test_index = indexes[train_+valid_:] 126 | before_adj_num = np.array(before_adj) 127 | all_feature_num = np.array(all_feature) 128 | concate_feature_num = np.array(concate_feature) 129 | from sklearn.preprocessing import minmax_scale 130 | before_adj_num = minmax_scale(before_adj_num, axis=0, copy =True) 131 | 132 | import sklearn.metrics.pairwise 133 | cos_sim = sklearn.metrics.pairwise.cosine_similarity(before_adj_num,before_adj_num) 134 | 135 | will_remove_list_1 = [] 136 | will_remove_list_2 = [] 137 | for id_ in id_list: 138 | if clinical[clinical['SUB_ID'] ==id_]['ADI_R_SOCIAL_TOTAL_A'].item() == -9999.0: 139 | will_remove_list_1.append(id_) 140 | if clinical[clinical['SUB_ID'] ==id_]['SRS_RAW_TOTAL'].item() == -9999.0: 141 | will_remove_list_2.append(id_) 142 | 143 | transpose_num = clinical[use_l[2:]].T 144 | from sklearn.cluster import KMeans 145 | kmeans = KMeans(n_clusters=K) 146 | kmeans.fit(transpose_num) 147 | type_list = [] 148 | for k in range(K): 149 | type_list.append([]) 150 | for i in range(len(use_l[2:])): 151 | type_list[kmeans.labels_[i]].append(use_l[2:][i]) 152 | 153 | from sklearn.preprocessing import minmax_scale 154 | 155 | k = 0 156 | threses = args.thres.split(',') 157 | threses = [float(th) for th 
in threses] 158 | save_list= [] 159 | for types in type_list: 160 | before_adj_ = [] 161 | print('type' + str(k)) 162 | print("********") 163 | ll = clinical[['DX_GROUP', 'SUB_ID']+types].drop(columns=['DX_GROUP']) 164 | ll = ll.fillna(0) 165 | for id_ in id_list: 166 | lp = clinical[clinical['SUB_ID'] ==id_].index.item() 167 | p = ll.drop(columns=['SUB_ID']).loc[lp].tolist() 168 | before_adj_.append(p) 169 | p_ = np.array(before_adj_) 170 | p_ = minmax_scale(p_, axis=0, copy =True) 171 | 172 | cos_ = sklearn.metrics.pairwise.cosine_similarity(p_,p_) 173 | adj_ = np.zeros(cos_.shape) 174 | thres = threses[k] 175 | for i in range(cos_.shape[0]): 176 | for j in range(cos_.shape[0]): 177 | if cos_[i][j] > thres: 178 | adj_[i][j] = 1 179 | if k ==2: 180 | if id_list[i] in will_remove_list_1: 181 | adj_[i][j] = 0 182 | if id_list[j] in will_remove_list_1: 183 | adj_[i][j] = 0 184 | if k ==3: 185 | if id_list[i] in will_remove_list_2: 186 | adj_[i][j] = 0 187 | if id_list[j] in will_remove_list_2: 188 | adj_[i][j] = 0 189 | else: 190 | adj_[i][j] = 0 191 | for i in range(cos_.shape[0]): 192 | for j in range(cos_.shape[0]): 193 | if i == j: 194 | adj_[i][j] = 1 195 | save_list.append(adj_) 196 | k+=1 197 | 198 | y = np.zeros((len(labels),2)) 199 | for i in range(len(labels)): 200 | if labels[i] ==0: 201 | y[i,0]=1 202 | else: 203 | y[i,1]=1 204 | 205 | train_index = np.array(train_index) 206 | valid_index = np.array(valid_index) 207 | test_index = np.array(test_index) 208 | if K ==6: 209 | adj_0 = save_list[0] 210 | adj_1 = save_list[1] 211 | adj_2 = save_list[2] 212 | adj_3 = save_list[3] 213 | adj_4 = save_list[4] 214 | adj_5 = save_list[5] 215 | multi = {'label':y,'type0':adj_0,'type1':adj_1,'type2':adj_2,'type3':adj_3,'type4':adj_4,'type5':adj_5,'train_idx':train_index,'val_idx':valid_index,'test_idx':test_index,'feature':concate_feature_num} 216 | elif K ==5: 217 | adj_0 = save_list[0] 218 | adj_1 = save_list[1] 219 | adj_2 = save_list[2] 220 | adj_3 = save_list[3] 221 | adj_4 = save_list[4] 222 | multi = {'label':y,'type0':adj_0,'type1':adj_1,'type2':adj_2,'type3':adj_3,'type4':adj_4,'train_idx':train_index,'val_idx':valid_index,'test_idx':test_index,'feature':concate_feature_num} 223 | elif K ==4: 224 | adj_0 = save_list[0] 225 | adj_1 = save_list[1] 226 | adj_2 = save_list[2] 227 | adj_3 = save_list[3] 228 | multi = {'label':y,'type0':adj_0,'type1':adj_1,'type2':adj_2,'type3':adj_3,'train_idx':train_index,'val_idx':valid_index,'test_idx':test_index,'feature':concate_feature_num} 229 | elif K ==3: 230 | adj_0 = save_list[0] 231 | adj_1 = save_list[1] 232 | adj_2 = save_list[2] 233 | multi = {'label':y,'type0':adj_0,'type1':adj_1,'type2':adj_2,'train_idx':train_index,'val_idx':valid_index,'test_idx':test_index,'feature':concate_feature_num} 234 | elif K ==2: 235 | adj_0 = save_list[0] 236 | adj_1 = save_list[1] 237 | multi = {'label':y,'type0':adj_0,'type1':adj_1,'train_idx':train_index,'val_idx':valid_index,'test_idx':test_index,'feature':concate_feature_num} 238 | 239 | with open('./../MultiplexNetwork/data/abide.pkl', 'wb') as f: 240 | pickle.dump(multi, f, pickle.HIGHEST_PROTOCOL) -------------------------------------------------------------------------------- /Preprocessing/Non_image_preprocessing/adni_kmeans.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import pickle 4 | import argparse 5 | 6 | def parse_args(): 7 | parser = argparse.ArgumentParser(description='adni_kmeans') 8 | 
parser.add_argument('--K', type=int, default=4) 9 | parser.add_argument('--thres',type=str, default='0.9,0.9,0.9,0.9') 10 | return parser.parse_known_args() 11 | args, unknown = parse_args() 12 | K = args.K 13 | clinical = pd.read_csv('./../non-image/ADNI/additional_adni1_2.csv') 14 | 15 | clinical.drop(columns = ['ICV_bl','WholeBrain_bl','IMAGEUID_bl','Ventricles_bl','Fusiform_bl','MidTemp_bl','Month', 'M','race','education','Month_bl','Years_bl','ethnicity','apoe4','age_'],inplace=True) 16 | total_index = [] 17 | for id in clinical['RID'].unique(): 18 | local_index = [] 19 | nan_count_index = [] 20 | for idx in clinical[clinical['RID'] == id].index: 21 | local_index.append(idx) 22 | nan_count_index.append(clinical[clinical['RID'] == id].loc[idx].isna().sum()) 23 | tmp = min(nan_count_index) 24 | total_index.append(local_index[nan_count_index.index(tmp)]) 25 | use_clinical = clinical.loc[total_index] 26 | 27 | path_ = './../SimCLR/extracted_feature/adni/' 28 | train_feature = np.loadtxt('./' + path_ + 'train_feature.csv',delimiter=',',dtype=np.float32) 29 | valid_feature = np.loadtxt('./' + path_ + 'valid_feature.csv',delimiter=',',dtype=np.float32) 30 | test_feature = np.loadtxt('./' + path_ + 'test_feature.csv',delimiter=',',dtype=np.float32) 31 | 32 | train_id = pd.read_csv('./' + path_ +'train_id.csv', header=None) 33 | valid_id = pd.read_csv('./' + path_ +'valid_id.csv', header=None) 34 | test_id = pd.read_csv('./' + path_ +'test_id.csv', header=None) 35 | 36 | id_list = [] 37 | image_list = [] 38 | for i in range(len(train_id)): 39 | a = str(train_id[0][i]) 40 | l = int(a[:a[:-3].rfind('0000')]) 41 | id_list.append(l) 42 | image_list.append(train_feature[i]) 43 | for i in range(len(valid_id)): 44 | a = str(valid_id[0][i]) 45 | l = int(a[:a[:-3].rfind('0000')]) 46 | id_list.append(l) 47 | image_list.append(valid_feature[i]) 48 | for i in range(len(test_id)): 49 | a = str(test_id[0][i]) 50 | l = int(a[:a[:-3].rfind('0000')]) 51 | id_list.append(l) 52 | image_list.append(test_feature[i]) 53 | 54 | use_col = use_clinical.keys().tolist() 55 | 56 | for col in use_col: 57 | use_clinical[col].fillna(use_clinical[col].mode()[0],inplace=True) 58 | 59 | k_means_list = use_col[3:] 60 | 61 | from sklearn.preprocessing import minmax_scale 62 | use_clinical_minmax = minmax_scale(use_clinical[k_means_list], axis=0, copy =True) 63 | use_k_means = use_clinical.copy() 64 | use_k_means[k_means_list] = use_clinical_minmax 65 | 66 | transpose_num = use_k_means[k_means_list].T 67 | from sklearn.cluster import KMeans 68 | kmeans = KMeans(n_clusters=K) 69 | kmeans.fit(transpose_num) 70 | 71 | type_list = [] 72 | for k in range(K): 73 | type_list.append([]) 74 | for i in range(len(k_means_list)): 75 | type_list[kmeans.labels_[i]].append(k_means_list[i]) 76 | 77 | non_image_feat = [] 78 | labels = [] 79 | use_clinical_dum = pd.get_dummies(use_clinical.drop(columns=['label','RID']),columns=use_col[3:7]) 80 | for id_ in id_list: 81 | lp = use_clinical[use_clinical['RID'] == id_].index.item() 82 | lab = use_clinical['label'].loc[lp] 83 | non_image_feat.append(minmax_scale(use_clinical_dum, axis=0, copy =True)[lp]) 84 | # non_image_feat.append(use_clinical_dum.loc[lp]) 85 | labels.append(lab) 86 | 87 | import sklearn.metrics.pairwise 88 | from sklearn.preprocessing import minmax_scale 89 | 90 | k = 0 91 | save_list= [] 92 | threses = args.thres.split(',') 93 | threses = [float(th) for th in threses] 94 | for types in type_list: 95 | before_adj_ = [] 96 | use_clinical_dummy_multi = use_clinical[['label']+types] 97 
| ll = use_clinical_dummy_multi.drop(columns=['label']) 98 | ll = ll.fillna(0) 99 | for id_ in id_list: 100 | lp = use_clinical[use_clinical['RID'] ==id_].index.item() 101 | p = ll.loc[lp].tolist() 102 | before_adj_.append(p) 103 | p_ = np.array(before_adj_) 104 | p_ = minmax_scale(p_, axis=0, copy =True) 105 | 106 | cos_ = sklearn.metrics.pairwise.cosine_similarity(p_,p_) 107 | adj_ = np.zeros(cos_.shape) 108 | thres = threses[k] 109 | for i in range(cos_.shape[0]): 110 | for j in range(cos_.shape[0]): 111 | if cos_[i][j] > thres: 112 | adj_[i][j] = 1 113 | else: 114 | adj_[i][j] = 0 115 | save_list.append(adj_) 116 | k+=1 117 | 118 | y = np.zeros((len(labels),3)) 119 | for i in range(len(labels)): 120 | if labels[i] ==0: 121 | y[i,0]=1 122 | elif labels[i]==3: 123 | y[i,1]=1 124 | elif labels[i] ==4: 125 | y[i,2] = 1 126 | 127 | concat_feature = [] 128 | for i in range(len(image_list)): 129 | concat = np.concatenate((np.expand_dims(image_list[i],axis=0),np.expand_dims(non_image_feat[i],axis=0)),axis=1) 130 | concat_feature.append(concat[0]) 131 | 132 | concate_feature_num = np.array(concat_feature) 133 | 134 | indexes = [i for i in range(len(labels))] 135 | train_ = len(indexes[:int(len(indexes)*0.6)]) 136 | valid_ = len(indexes[int(len(indexes)*0.6):int(len(indexes)*0.7)]) 137 | test_ = len(indexes[int(len(indexes)*0.7):]) 138 | 139 | 140 | import random 141 | random.shuffle(indexes) 142 | train_index = indexes[:train_] 143 | valid_index = indexes[train_:train_+valid_] 144 | test_index = indexes[train_+valid_:] 145 | 146 | train_index = np.array(train_index) 147 | valid_index = np.array(valid_index) 148 | test_index = np.array(test_index) 149 | 150 | if K ==6: 151 | adj_0 = save_list[0] 152 | adj_1 = save_list[1] 153 | adj_2 = save_list[2] 154 | adj_3 = save_list[3] 155 | adj_4 = save_list[4] 156 | adj_5 = save_list[5] 157 | multi = {'label':y,'type0':adj_0,'type1':adj_1,'type2':adj_2,'type3':adj_3,'type4':adj_4,'type5':adj_5,'train_idx':train_index,'val_idx':valid_index,'test_idx':test_index,'feature':concate_feature_num} 158 | elif K ==5: 159 | adj_0 = save_list[0] 160 | adj_1 = save_list[1] 161 | adj_2 = save_list[2] 162 | adj_3 = save_list[3] 163 | adj_4 = save_list[4] 164 | multi = {'label':y,'type0':adj_0,'type1':adj_1,'type2':adj_2,'type3':adj_3,'type4':adj_4,'train_idx':train_index,'val_idx':valid_index,'test_idx':test_index,'feature':concate_feature_num} 165 | elif K ==4: 166 | adj_0 = save_list[0] 167 | adj_1 = save_list[1] 168 | adj_2 = save_list[2] 169 | adj_3 = save_list[3] 170 | multi = {'label':y,'type0':adj_0,'type1':adj_1,'type2':adj_2,'type3':adj_3,'train_idx':train_index,'val_idx':valid_index,'test_idx':test_index,'feature':concate_feature_num} 171 | 172 | elif K ==3: 173 | adj_0 = save_list[0] 174 | adj_1 = save_list[1] 175 | adj_2 = save_list[2] 176 | multi = {'label':y,'type0':adj_0,'type1':adj_1,'type2':adj_2,'train_idx':train_index,'val_idx':valid_index,'test_idx':test_index,'feature':concate_feature_num} 177 | elif K ==2: 178 | adj_0 = save_list[0] 179 | adj_1 = save_list[1] 180 | multi = {'label':y,'type0':adj_0,'type1':adj_1,'train_idx':train_index,'val_idx':valid_index,'test_idx':test_index,'feature':concate_feature_num} 181 | 182 | with open('./../MultiplexNetwork/data/adni.pkl', 'wb') as f: 183 | pickle.dump(multi, f, pickle.HIGHEST_PROTOCOL) -------------------------------------------------------------------------------- /Preprocessing/Non_image_preprocessing/cmmd.ipynb: -------------------------------------------------------------------------------- 1 | { 
2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd\n", 10 | "import numpy as np\n", 11 | "import pickle" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "image_feat = np.loadtxt('./../SimCLR/cmmd_image_feature.csv',delimiter=',',dtype=np.float32)\n", 21 | "ids = pd.read_csv('./../SimCLR/cmmd_ids.csv',header=None)" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "ids_list = []\n", 31 | "for i in range(len(image_feat)):\n", 32 | " a = 'D' + str(ids[0][i])[0] +'-' + str(ids[0][i])[1:5]\n", 33 | " ids_list.append(a)\n", 34 | " i+=1" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "data = pd.read_excel('./CMMD_clinicaldata_revision (1).xlsx')\n" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "data.drop(columns='number', inplace=True)" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "data['subtype'].fillna(0,inplace=True)" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "data['age_'] = pd.qcut(data['Age'],5,labels=False)" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "data.drop(columns=['Age'],inplace=True)" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "data['label'] = data['classification'].apply(lambda x: 0 if x == 'Benign' else 1)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "data.drop(columns=['classification'],inplace=True)" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "data" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "sub = data['ID1']\n", 116 | "cdr = data['label']\n", 117 | "data_non_dmgi_drop = data.drop(columns=['ID1','label'])\n", 118 | "data_dum = pd.get_dummies(data_non_dmgi_drop, columns=data_non_dmgi_drop.keys())" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "feature_dict = {}\n", 128 | "label_dict = {}\n", 129 | "i = 0\n", 130 | "# use_clinical_dummy_no = use_clinical_dummy.drop(columns=['cdr','age_','apoe']).fillna(0)\n", 131 | "\n", 132 | "for s in sub:\n", 133 | " feature_dict[s] = data_dum.iloc[i].to_numpy()\n", 134 | " label_dict[s] = cdr[i]\n", 135 | " i+=1" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "oasis3_1 = {'label': label_dict, 'feature': feature_dict}\n", 145 | "with open('./cmmd.pkl', 'wb') as f:\n", 146 | " pickle.dump(oasis3_1, f, pickle.HIGHEST_PROTOCOL)" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | 
"execution_count": null, 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "concat_feature = []\n", 156 | "labels = []\n", 157 | "i =0\n", 158 | "for idd in ids_list:\n", 159 | " pp = image_feat[i].tolist() + feature_dict[idd].tolist()\n", 160 | " concat_feature.append(pp)\n", 161 | " labels.append(label_dict[idd])\n", 162 | " i+=1" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": null, 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [ 171 | "concat_num = np.array(concat_feature)" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": null, 177 | "metadata": {}, 178 | "outputs": [], 179 | "source": [ 180 | "data" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": null, 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "type_list = [['LeftRight'],['abnormality'],['subtype'],['age_']]\n", 190 | "from sklearn.preprocessing import minmax_scale\n", 191 | "import sklearn.metrics.pairwise\n", 192 | "\n", 193 | "k = 0\n", 194 | "save_list= []\n", 195 | "threses = [0.9, 0.9, 0.9, 0.5]\n", 196 | "for types in type_list:\n", 197 | " before_adj_ = []\n", 198 | " use_clinical_dummy_multi = pd.get_dummies(data[['label']+types], columns=types)\n", 199 | " ll = use_clinical_dummy_multi.drop(columns=['label'])\n", 200 | " ll = ll.fillna(0)\n", 201 | " for id_ in ids_list:\n", 202 | " # print(data[data['ID1'] ==id_].index[0])\n", 203 | " lp = data[data['ID1'] ==id_].index[0]\n", 204 | " p = ll.loc[lp].tolist()\n", 205 | " before_adj_.append(p)\n", 206 | " p_ = np.array(before_adj_)\n", 207 | " p_ = minmax_scale(p_, axis=0, copy =True)\n", 208 | "\n", 209 | " cos_ = sklearn.metrics.pairwise.cosine_similarity(p_,p_)\n", 210 | " adj_ = np.zeros(cos_.shape)\n", 211 | " thres = threses[k]\n", 212 | " for i in range(cos_.shape[0]):\n", 213 | " for j in range(cos_.shape[0]):\n", 214 | " if cos_[i][j] > thres:\n", 215 | " adj_[i][j] = 1\n", 216 | " else:\n", 217 | " adj_[i][j] = 0\n", 218 | " save_list.append(adj_)\n", 219 | " k+=1" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "metadata": {}, 226 | "outputs": [], 227 | "source": [ 228 | "adj_0 = save_list[0]\n", 229 | "adj_1 = save_list[1]\n", 230 | "adj_2 = save_list[2]\n", 231 | "adj_3 = save_list[3]" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": null, 237 | "metadata": {}, 238 | "outputs": [], 239 | "source": [ 240 | "indexes = [i for i in range(len(labels))]\n", 241 | "train_ = len(indexes[:int(len(indexes)*0.6)])\n", 242 | "valid_ = len(indexes[int(len(indexes)*0.6):int(len(indexes)*0.7)])\n", 243 | "test_ = len(indexes[int(len(indexes)*0.7):])\n", 244 | "\n", 245 | "import random\n", 246 | "random.shuffle(indexes)\n", 247 | "train_index = indexes[:train_]\n", 248 | "valid_index = indexes[train_:train_+valid_]\n", 249 | "test_index = indexes[train_+valid_:]\n" 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": null, 255 | "metadata": {}, 256 | "outputs": [], 257 | "source": [ 258 | "y = np.zeros((len(labels),2))\n", 259 | "for i in range(len(labels)):\n", 260 | " if labels[i] ==0:\n", 261 | " y[i,0]=1\n", 262 | " elif labels[i]==1:\n", 263 | " y[i,1]=1\n", 264 | "\n", 265 | "train_index = np.array(train_index)\n", 266 | "valid_index = np.array(valid_index)\n", 267 | "test_index = np.array(test_index)" 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": null, 273 | "metadata": {}, 274 | "outputs": [], 275 | "source": [ 
276 | "concat_num.shape" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": null, 282 | "metadata": {}, 283 | "outputs": [], 284 | "source": [ 285 | "oasis3_multi = {'label':y,'type0':adj_0,'type1':adj_1,'type2':adj_2,'type3':adj_3,'train_idx':train_index,'val_idx':valid_index,'test_idx':test_index,'feature':concat_num}\n" 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": null, 291 | "metadata": {}, 292 | "outputs": [], 293 | "source": [ 294 | "with open('./../../MultiplexNetwork/data/cmmd.pkl', 'wb') as f:\n", 295 | " pickle.dump(oasis3_multi, f, pickle.HIGHEST_PROTOCOL)" 296 | ] 297 | } 298 | ], 299 | "metadata": { 300 | "kernelspec": { 301 | "display_name": "Python 3.8.5 ('study')", 302 | "language": "python", 303 | "name": "python3" 304 | }, 305 | "language_info": { 306 | "codemirror_mode": { 307 | "name": "ipython", 308 | "version": 3 309 | }, 310 | "file_extension": ".py", 311 | "mimetype": "text/x-python", 312 | "name": "python", 313 | "nbconvert_exporter": "python", 314 | "pygments_lexer": "ipython3", 315 | "version": "3.8.5" 316 | }, 317 | "orig_nbformat": 4, 318 | "vscode": { 319 | "interpreter": { 320 | "hash": "4e7dcdda9aa57128db4f79b31d827bdb0aa0e537d1eeb024b6c7498f481347ff" 321 | } 322 | } 323 | }, 324 | "nbformat": 4, 325 | "nbformat_minor": 2 326 | } 327 | -------------------------------------------------------------------------------- /Preprocessing/Non_image_preprocessing/duke_kmeans.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import pickle 4 | import argparse 5 | 6 | def parse_args(): 7 | parser = argparse.ArgumentParser(description='qin_kmeans') 8 | parser.add_argument('--K', type=int, default=4) 9 | return parser.parse_known_args() 10 | args, unknown = parse_args() 11 | K = args.K 12 | data = pd.read_excel('./../non-image/QIN/QIN_clinical.xlsx') 13 | na_drop_list = [] 14 | for col in data.keys(): 15 | if sum(data[col].isna())/len(data) > 0.1: 16 | print(col) 17 | na_drop_list.append(col) 18 | # print(sum(data[col].isna())) 19 | na_drop_list = ['Staging(Tumor Size)# [T]','Mol Subtype','Number of Ovaries In Situ \n','Race and Ethnicity'] + na_drop_list +['Days to last local recurrence free assessment (from the date of diagnosis) ', 'Days to last distant recurrence free assemssment(from the date of diagnosis) ','Days to Surgery (from the date of diagnosis)','Days to MRI (From the Date of Diagnosis)'] 20 | data.drop(columns = na_drop_list, inplace=True) 21 | for col in data.keys(): 22 | data[col] = data[col].fillna(data[col].mode()[0]) # assign back; fillna without inplace returns a new Series 23 | 24 | numerical = ['Date of Birth (Days)'] 25 | data_non_dmgi = data.copy() 26 | # data_non_dmgi['age_'] = pd.qcut(data_non_dmgi['ageAtEntry'], 5, labels=False) 27 | # use_clinical['height_'] = pd.qcut(use_clinical['height'], 3, labels=False) 28 | # use_clinical['weight_'] = pd.qcut(use_clinical['weight'], 3, labels=False) 29 | 30 | for col in numerical: 31 | data_non_dmgi[col] = pd.qcut(data_non_dmgi[col],5, labels=False) 32 | sub = data_non_dmgi['Patient ID'] 33 | cdr = data_non_dmgi['Tumor Grade'] 34 | data_non_dmgi_drop = data_non_dmgi.drop(columns=['Patient ID','Tumor Grade']) 35 | data_non_dmgi_drop_dum = pd.get_dummies(data_non_dmgi_drop, columns=data_non_dmgi_drop.keys()) 36 | feature_dict = {} 37 | label_dict = {} 38 | i = 0 39 | # use_clinical_dummy_no = use_clinical_dummy.drop(columns=['cdr','age_','apoe']).fillna(0) 40 | 41 | for s in sub: 42 | feature_dict[s] = 
data_non_dmgi_drop_dum.iloc[i].to_numpy() 43 | label_dict[s] = cdr[i] 44 | i+=1 45 | data_dmgi = data.drop(columns=['Date of Birth (Days)']) 46 | path_ = '../../moco/' 47 | features = np.loadtxt('./' + path_ + 'breast_feature.csv',delimiter=',',dtype=np.float32) 48 | # valid_feature = np.loadtxt('./' + path_ + 'valid_feature.csv',delimiter=',',dtype=np.float32) 49 | ids = pd.read_csv('./' + path_ +'breast_id.csv', header=None) 50 | # valid_id = pd.read_csv('./' + path_ +'valid_id.csv', header=None) 51 | for col in data_dmgi: 52 | data_dmgi[col].fillna(data_dmgi[col].mode()[0],inplace=True) 53 | idx_list = [] 54 | for col in data_dmgi.keys(): 55 | idx = data_dmgi[data_dmgi[col] =='NP'].index 56 | idx_list.append(idx) 57 | if len(idx)>0: 58 | data_dmgi.loc[idx, col] = -1 59 | k_means_list = data_dmgi.drop(columns = ['Patient ID', 'Tumor Grade']).keys() 60 | 61 | transpose_num = data_dmgi[k_means_list].T 62 | from sklearn.cluster import KMeans 63 | kmeans = KMeans(n_clusters=K) 64 | kmeans.fit(transpose_num) 65 | type_list = [] 66 | for k in range(K): 67 | type_list.append([]) 68 | for i in range(len(k_means_list)): 69 | type_list[kmeans.labels_[i]].append(k_means_list[i]) 70 | 71 | id_list = [] 72 | for idss in ids[0]: 73 | a = 'Breast_MRI_' + str(idss)[1:4] 74 | id_list.append(a) 75 | from sklearn.preprocessing import minmax_scale 76 | import sklearn 77 | 78 | import sklearn.metrics.pairwise 79 | # build one thresholded cosine-similarity adjacency per clinical-feature group 80 | k = 0 81 | save_list= [] 82 | threses = [0.8, 0.8, 0.8, 0.8] 83 | 84 | for types in type_list: 85 | before_adj_ = [] 86 | print('type' + str(k)) 87 | print("********") 88 | use_clinical_dummy_multi = pd.get_dummies(data_dmgi[['Tumor Grade']+types], columns=types) 89 | ll = use_clinical_dummy_multi.drop(columns=['Tumor Grade']) 90 | ll = ll.fillna(0) 91 | for id_ in id_list: 92 | lp = data_dmgi[data_dmgi['Patient ID'] ==id_].index.item() 93 | p = ll.loc[lp].tolist() 94 | before_adj_.append(p) 95 | p_ = np.array(before_adj_) 96 | p_ = minmax_scale(p_, axis=0, copy =True) 97 | 98 | cos_ = sklearn.metrics.pairwise.cosine_similarity(p_,p_) 99 | adj_ = np.zeros(cos_.shape) 100 | thres = threses[k] # use the per-type threshold list defined above 101 | for i in range(cos_.shape[0]): 102 | for j in range(cos_.shape[0]): 103 | if cos_[i][j] > thres: 104 | adj_[i][j] = 1 105 | else: 106 | adj_[i][j] = 0 107 | print('dense') 108 | print(sum(sum(adj_))/(adj_.shape[0]*adj_.shape[0])) 109 | print("********") 110 | save_list.append(adj_) 111 | k+=1 112 | 113 | feat = [] 114 | lab = [] 115 | 116 | for i in range(len(id_list)): 117 | a_ = id_list[i] 118 | img_feat = features[i] 119 | l_ = data_dmgi.drop(columns=['Patient ID', 'Tumor Grade']) 120 | l_ = l_.fillna(0) 121 | 122 | l_dum = pd.get_dummies(l_, columns=l_.keys()) 123 | l = data_dmgi[data_dmgi['Patient ID'] ==a_].index.item() 124 | p = l_dum.loc[l].tolist() 125 | p_num = np.array(p) 126 | pp = np.concatenate((img_feat,p_num)) 127 | feat.append(pp) 128 | label = data_dmgi.loc[l,'Tumor Grade'].item() 129 | lab.append(label) 130 | 131 | concate_feature = np.array(feat) 132 | indexes = [i for i in range(len(lab))] 133 | # indexes=train_index.tolist() + valid_index.tolist() + test_index.tolist() 134 | train_ = len(indexes[:int(len(indexes)*0.6)]) 135 | valid_ = len(indexes[int(len(indexes)*0.6):int(len(indexes)*0.7)]) 136 | test_ = len(indexes[int(len(indexes)*0.7):]) 137 | 138 | import random 139 | random.shuffle(indexes) 140 | train_index = indexes[:train_] 141 | valid_index = indexes[train_:train_+valid_] 142 | test_index = 
indexes[train_+valid_:] 143 | # train_index = [i for i in range(train_)] 144 | # valid_index = [i+max(train_index)+1 for i in range(valid_)] 145 | # test_index = [i+max(valid_index)+1 for i in range(test_)] 146 | 147 | y = np.zeros((len(lab),3)) 148 | for i in range(len(lab)): 149 | if lab[i] ==1: 150 | y[i,0]=1 151 | elif lab[i]==2: 152 | y[i,1]=1 153 | else: 154 | y[i,2] =1 155 | 156 | train_index = np.array(train_index) 157 | valid_index = np.array(valid_index) 158 | test_index = np.array(test_index) 159 | if K ==6: 160 | adj_0 = save_list[0] 161 | adj_1 = save_list[1] 162 | adj_2 = save_list[2] 163 | adj_3 = save_list[3] 164 | adj_4 = save_list[4] 165 | adj_5 = save_list[5] 166 | oasis3_multi = {'label':y,'type0':adj_0,'type1':adj_1,'type2':adj_2,'type3':adj_3,'type4':adj_4,'type5':adj_5,'train_idx':train_index,'val_idx':valid_index,'test_idx':test_index,'feature':concate_feature} 167 | elif K ==5: 168 | adj_0 = save_list[0] 169 | adj_1 = save_list[1] 170 | adj_2 = save_list[2] 171 | adj_3 = save_list[3] 172 | adj_4 = save_list[4] 173 | oasis3_multi = {'label':y,'type0':adj_0,'type1':adj_1,'type2':adj_2,'type3':adj_3,'type4':adj_4,'train_idx':train_index,'val_idx':valid_index,'test_idx':test_index,'feature':concate_feature} 174 | elif K ==4: 175 | adj_0 = save_list[0] 176 | adj_1 = save_list[1] 177 | adj_2 = save_list[2] 178 | adj_3 = save_list[3] 179 | oasis3_multi = {'label':y,'type0':adj_0,'type1':adj_1,'type2':adj_2,'type3':adj_3,'train_idx':train_index,'val_idx':valid_index,'test_idx':test_index,'feature':concate_feature} 180 | elif K ==3: 181 | adj_0 = save_list[0] 182 | adj_1 = save_list[1] 183 | adj_2 = save_list[2] 184 | oasis3_multi = {'label':y,'type0':adj_0,'type1':adj_1,'type2':adj_2,'train_idx':train_index,'val_idx':valid_index,'test_idx':test_index,'feature':concate_feature} 185 | elif K ==2: 186 | adj_0 = save_list[0] 187 | adj_1 = save_list[1] 188 | oasis3_multi = {'label':y,'type0':adj_0,'type1':adj_1,'train_idx':train_index,'val_idx':valid_index,'test_idx':test_index,'feature':concate_feature} 189 | 190 | with open('./MultiplexNetwork/data/qin.pkl', 'wb') as f: 191 | pickle.dump(oasis3_multi, f, pickle.HIGHEST_PROTOCOL) 192 | -------------------------------------------------------------------------------- /Preprocessing/Non_image_preprocessing/oasis_kmeans.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import pickle 4 | import argparse 5 | 6 | def parse_args(): 7 | parser = argparse.ArgumentParser(description='qin_kmeans') 8 | parser.add_argument('--K', type=int, default=4) 9 | parser.add_argument('--thres',type=str, default='0.9,0.9,0.9,0.9') 10 | return parser.parse_known_args() 11 | args, unknown = parse_args() 12 | K = args.K 13 | train_rate = 0.6 14 | clinical = pd.read_csv('./../../non-image/OASIS3/clinical data 2.csv') 15 | clinical_basic = pd.read_csv('./../../non-image/OASIS3/subject_basic.csv') 16 | clinical_judge = pd.read_csv('./../../non-image/OASIS3/Judgement.csv') 17 | clinical_basic_ = clinical_basic.drop(columns=['PETs','MR Sessions'], axis=1) 18 | faq_list = ['BILLS','TAXES','SHOPPING','GAMES','STOVE','MEALPREP','EVENTS','PAYATTN','REMDATES','TRAVEL'] 19 | clinical_basic_[faq_list] = 0 20 | clinical_basic_.drop(columns=['YOB'],inplace = True) 21 | for col in clinical_basic_.columns: 22 | clinical_basic_[col] = clinical_basic_[col].fillna(clinical_basic_[col].mode()[0]) 23 | total_index = [] 24 | sign = 0 25 | locate_cdr2=0 26 | for id in 
clinical['Subject'].unique(): 27 | local_index = [] 28 | nan_count_index = [] 29 | sign=0 30 | locate_cdr2 = 0 31 | for idx in clinical[clinical['Subject'] == id].index: 32 | local_index.append(idx) 33 | nan_count_index.append(clinical[clinical['Subject'] == id].loc[idx].isna().sum()) 34 | if (clinical[clinical['Subject']==id].loc[idx]['cdr'] == 2.0): 35 | sign=1 36 | locate_cdr2 = idx 37 | tmp = min(nan_count_index) 38 | # total_index.append(local_index[nan_count_index.index(tmp)]) 39 | 40 | if sign == 1: 41 | total_index.append(locate_cdr2) 42 | else: 43 | total_index.append(local_index[nan_count_index.index(tmp)]) 44 | 45 | use_clinical = clinical.loc[total_index] 46 | use_clinical['label'] = 0 47 | use_clinical['label'] = use_clinical['dx1'].apply(lambda x: 0 if (x=='Cognitively normal') else (2 if (x=='AD Dementia') else 1)) 48 | use_clinical['age_'] = pd.qcut(use_clinical['ageAtEntry'], 5, labels=False) 49 | use_clinical['height_'] = pd.qcut(use_clinical['height'], 3, labels=False) 50 | use_clinical['weight_'] = pd.qcut(use_clinical['weight'], 3, labels=False) 51 | 52 | jud = clinical_judge[['Subject','DECSUB','DECIN','DECCLIN']].copy() # copy to avoid chained-assignment warnings below 53 | for col in jud.columns: 54 | jud[col] = jud[col].fillna(jud[col].mode()[0]) 55 | li = ['DECSUB', 'DECIN', 'DECCLIN'] 56 | for il in li: 57 | clinical_basic_[il] = 0 58 | for sub in clinical_basic_['Subject']: 59 | try: 60 | a = jud[jud['Subject']==sub][il].max() 61 | clinical_basic_.loc[clinical_basic_[clinical_basic_['Subject']==sub].index,il] = a 62 | except: 63 | clinical_basic_.loc[clinical_basic_[clinical_basic_['Subject']==sub].index,il] = 0 64 | 65 | clinical_basic_['cdr'] = use_clinical['cdr'].to_list() 66 | use_clinical_2 = use_clinical[['Subject','cdr','age_','homehobb','apoe']] 67 | 68 | for col in use_clinical_2.columns.tolist()[2:]: 69 | clinical_basic_[col] = use_clinical_2[col].to_list() 70 | for col in clinical_basic_.columns: 71 | clinical_basic_[col] = clinical_basic_[col].fillna(clinical_basic_[col].mode()[0]) 72 | use_col = clinical_basic_.columns.tolist()[3:17] + clinical_basic_.columns.tolist()[-6:] 73 | all_col = use_col+['Subject','cdr'] 74 | 75 | clinical_basic_all_dum = pd.get_dummies(clinical_basic_[use_col], columns=use_col) 76 | sub = clinical_basic_['Subject'] 77 | cdr = clinical_basic_['cdr'] 78 | # clinical_basic_all_dum.drop(columns=['Subject','cdr'], inplace=True) 79 | feature_dict = {} 80 | label_dict = {} 81 | i = 0 82 | 83 | for s in sub: 84 | feature_dict[s] = clinical_basic_all_dum.iloc[i].to_numpy() 85 | label_dict[s] = cdr[i] 86 | i+=1 87 | non_img = {'label': label_dict, 'feature': feature_dict} 88 | with open('./../non-image/OASIS3/oasis_nonimg.pkl', 'wb') as f: # same 'non-image' path as the read below 89 | pickle.dump(non_img, f, pickle.HIGHEST_PROTOCOL) 90 | path_ = '../SimCLR/extracted_feature/oasis/' 91 | train_feature = np.loadtxt('./' + path_ + 'train_feature.csv',delimiter=',',dtype=np.float32) 92 | # valid_feature = np.loadtxt('./' + path_ + 'valid_feature.csv',delimiter=',',dtype=np.float32) 93 | test_feature = np.loadtxt('./' + path_ + 'test_feature.csv',delimiter=',',dtype=np.float32) 94 | 95 | train_id = pd.read_csv('./' + path_ +'train_id.csv', header=None) 96 | # valid_id = pd.read_csv('./' + path_ +'valid_id.csv', header=None) 97 | test_id = pd.read_csv('./' + path_ +'test_id.csv', header=None) 98 | 99 | 100 | with open('./../non-image/OASIS3/oasis_nonimg.pkl', 'rb') as fr: 101 | data = pickle.load(fr) 102 | before_adj = [] 103 | labels= [] 104 | train_index = [] 105 | valid_index = [] 106 | test_index = [] 107 | 108 | 
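# Illustrative sanity check (a sketch, assuming the CSVs loaded above hold one
# SimCLR embedding per row with subject IDs in parallel header-less CSVs):
assert len(train_feature) == len(train_id), 'train feature/id row mismatch'
assert len(test_feature) == len(test_id), 'test feature/id row mismatch'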
109 | all_feature = [] 110 | concate_feature = []#brain feature + patient feature 111 | k = 0 112 | id_list = [] 113 | for i in range(len(train_id)): 114 | a = 'OAS' + str(train_id[0][i])[:5] 115 | id_list.append(a) 116 | l = data['label'][a] 117 | l_ = data['feature'][a] 118 | # l = rid_label_dict[id_[i]] 119 | train_index.append(k) 120 | 121 | labels.append(l) 122 | before_adj.append(l_) 123 | all_feature.append(list(train_feature[k])) 124 | concate_feature.append(list(train_feature[k]) + list(l_)) 125 | k+=1 126 | 127 | k_test = 0 128 | for i in range(len(test_id)): 129 | a = 'OAS' + str(test_id[0][i])[:5] 130 | id_list.append(a) 131 | l = data['label'][a] 132 | l_ = data['feature'][a] 133 | # l = rid_label_dict[id_[i]] 134 | test_index.append(k) 135 | 136 | labels.append(l) 137 | before_adj.append(l_) 138 | all_feature.append(list(test_feature[k_test])) 139 | concate_feature.append(list(test_feature[k_test]) + list(l_)) 140 | k+=1 141 | k_test +=1 142 | modi_label = [] 143 | for i in labels: 144 | if i == 0.0: 145 | modi_label.append(0) 146 | elif i == 0.5: 147 | modi_label.append(1) 148 | elif i ==1.0: 149 | modi_label.append(2) 150 | else: 151 | modi_label.append(3) 152 | 153 | labels = modi_label 154 | indexes = train_index + valid_index + test_index 155 | # indexes=train_index.tolist() + valid_index.tolist() + test_index.tolist() 156 | train_ = len(indexes[:int(len(indexes)*train_rate)]) 157 | valid_ = len(indexes[int(len(indexes)*train_rate):int(len(indexes)*(train_rate+0.1))]) 158 | test_ = len(indexes[int(len(indexes)*(train_rate+0.1)):]) 159 | 160 | import random 161 | random.shuffle(indexes) 162 | train_index = indexes[:train_] 163 | valid_index = indexes[train_:train_+valid_] 164 | test_index = indexes[train_+valid_:] 165 | # train_index = [i for i in range(train_)] 166 | # valid_index = [i+max(train_index)+1 for i in range(valid_)] 167 | # test_index = [i+max(valid_index)+1 for i in range(test_)] 168 | 169 | before_adj_num = np.array(before_adj) 170 | all_feature_num = np.array(all_feature) 171 | concate_feature_num = np.array(concate_feature) 172 | import sklearn.metrics.pairwise 173 | cos_sim = sklearn.metrics.pairwise.cosine_similarity(before_adj_num,before_adj_num) 174 | 175 | adj = np.zeros(cos_sim.shape) 176 | thres = 0.4#0.925 177 | for i in range(cos_sim.shape[0]): 178 | for j in range(cos_sim.shape[0]): 179 | if cos_sim[i][j] > thres: 180 | adj[i][j] = 1 181 | else: 182 | adj[i][j] = 0 183 | print('dense') 184 | print(sum(sum(adj))/(adj.shape[0]*adj.shape[0])) 185 | # print(use_col) 186 | # k_means_list = use_col[:-4] + use_col[-3:] 187 | k_means_list = ['DECSUB', 'DECIN', 'DECCLIN', 'age_', 'homehobb','apoe','UDS B9: Clin. Judgements','UDS B5: NPI-Q','UDS B8: Phys. Neuro Findings','Psych Assessments'] 188 | 189 | # ['UDS B9: Clin. Judgements', 'UDS A5: Sub Health Hist.', 'UDS B6: GDS', 'UDS A1: Sub Demos'] 190 | # ['UDS B7: FAQs', 'UDS A2: Informant Demos', 'age_', 'UDS B5: NPI-Q'] 191 | # ['DECCLIN', 'UDS A3: Partcpt Family Hist.', 'UDS B3: UPDRS', 'ADRC Clinical Data'] 192 | # ['UDS B2: HIS and CVD', 'homehobb', 'UDS D1: Clinician Diagnosis', 'UDS B8: Phys. 
Neuro Findings', 'Psych Assessments', 'DECIN', 'apoe'] 193 | 194 | transpose_num = clinical_basic_[k_means_list].T 195 | from sklearn.cluster import KMeans 196 | kmeans = KMeans(n_clusters=K) 197 | kmeans.fit(transpose_num) 198 | type_list = [] 199 | for k in range(K): 200 | type_list.append([]) 201 | for i in range(len(k_means_list)): 202 | type_list[kmeans.labels_[i]].append(k_means_list[i]) 203 | feature_dict = {} 204 | label_dict = {} 205 | i = 0 206 | from sklearn.preprocessing import minmax_scale 207 | 208 | k = 0 209 | save_list= [] 210 | threses = args.thres.split(',') 211 | threses = [float(th) for th in threses] 212 | for types in type_list: 213 | before_adj_ = [] 214 | use_clinical_dummy_multi = pd.get_dummies(clinical_basic_[['cdr']+types], columns=types) 215 | ll = use_clinical_dummy_multi.drop(columns=['cdr']) 216 | ll = ll.fillna(0) 217 | for id_ in id_list: 218 | lp = clinical_basic_[clinical_basic_['Subject'] ==id_].index.item() 219 | p = ll.loc[lp].tolist() 220 | before_adj_.append(p) 221 | p_ = np.array(before_adj_) 222 | p_ = minmax_scale(p_, axis=0, copy =True) 223 | 224 | cos_ = sklearn.metrics.pairwise.cosine_similarity(p_,p_) 225 | adj_ = np.zeros(cos_.shape) 226 | thres = threses[k] 227 | for i in range(cos_.shape[0]): 228 | for j in range(cos_.shape[0]): 229 | if cos_[i][j] > thres: 230 | adj_[i][j] = 1 231 | else: 232 | adj_[i][j] = 0 233 | save_list.append(adj_) 234 | k+=1 235 | 236 | y = np.zeros((len(labels),4)) 237 | for i in range(len(labels)): 238 | if labels[i] ==0: 239 | y[i,0]=1 240 | elif labels[i]==1: 241 | y[i,1]=1 242 | elif labels[i] ==2: 243 | y[i,2] = 1 244 | else: 245 | y[i,3] =1 246 | 247 | train_index = np.array(train_index) 248 | valid_index = np.array(valid_index) 249 | test_index = np.array(test_index) 250 | if K ==6: 251 | adj_0 = save_list[0] 252 | adj_1 = save_list[1] 253 | adj_2 = save_list[2] 254 | adj_3 = save_list[3] 255 | adj_4 = save_list[4] 256 | adj_5 = save_list[5] 257 | oasis3_multi = {'label':y,'type0':adj_0,'type1':adj_1,'type2':adj_2,'type3':adj_3,'type4':adj_4,'type5':adj_5,'train_idx':train_index,'val_idx':valid_index,'test_idx':test_index,'feature':concate_feature_num} 258 | elif K ==4: 259 | adj_0 = save_list[0] 260 | adj_1 = save_list[1] 261 | adj_2 = save_list[2] 262 | adj_3 = save_list[3] 263 | oasis3_multi = {'label':y,'type0':adj_0,'type1':adj_1,'type2':adj_2,'type3':adj_3,'train_idx':train_index,'val_idx':valid_index,'test_idx':test_index,'feature':concate_feature_num} 264 | 265 | elif K ==5: 266 | adj_0 = save_list[0] 267 | adj_1 = save_list[1] 268 | adj_2 = save_list[2] 269 | adj_3 = save_list[3] 270 | adj_4 = save_list[4] 271 | oasis3_multi = {'label':y,'type0':adj_0,'type1':adj_1,'type2':adj_2,'type3':adj_3,'type4':adj_4,'train_idx':train_index,'val_idx':valid_index,'test_idx':test_index,'feature':concate_feature_num} 272 | elif K ==3: 273 | adj_0 = save_list[0] 274 | adj_1 = save_list[1] 275 | adj_2 = save_list[2] 276 | oasis3_multi = {'label':y,'type0':adj_0,'type1':adj_1,'type2':adj_2,'train_idx':train_index,'val_idx':valid_index,'test_idx':test_index,'feature':concate_feature_num} 277 | elif K ==2: 278 | adj_0 = save_list[0] 279 | adj_1 = save_list[1] 280 | oasis3_multi = {'label':y,'type0':adj_0,'type1':adj_1,'train_idx':train_index,'val_idx':valid_index,'test_idx':test_index,'feature':concate_feature_num} 281 | 282 | with open('./../MultiplexNetwork/data/oasis.pkl', 'wb') as f: 283 | pickle.dump(oasis3_multi, f, pickle.HIGHEST_PROTOCOL) 
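# Every per-type adjacency built above follows the same recipe: one-hot encode a
# group of clinical columns, min-max scale, take pairwise cosine similarity, and
# binarize at a threshold. A vectorized sketch of that recipe, for reference
# (the nested loops above compute the same matrices):
import numpy as np
from sklearn.preprocessing import minmax_scale
from sklearn.metrics.pairwise import cosine_similarity

def build_threshold_adjacency(dummy_rows, thres):
    """Binary patient-patient adjacency from one-hot clinical vectors."""
    p = minmax_scale(np.asarray(dummy_rows, dtype=float), axis=0, copy=True)
    cos = cosine_similarity(p, p)
    return (cos > thres).astype(float)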
-------------------------------------------------------------------------------- /Preprocessing/README.md: -------------------------------------------------------------------------------- 1 | # Preprocessing datasets 2 | 3 | ### How to Run (ADNI dataset) 4 | - Download the ADNI dataset from https://adni.loni.usc.edu/. 5 | - Preprocess the ADNI dataset (Saving images) 6 | - Download the .nii files from the ADNI site 7 | - make an './ADNI/all_paths.txt' file in which each line is "Patient ID" + '\t' + "path of the patient's .nii file" (see the format note below) 8 | - then run 9 | 

 10 | {python adni_atlas.py}
 11 | 
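- The expected format of 'all_paths.txt' is shown in 'Preprocessing/sample_all_path.txt': one "Patient ID" and one image path per line, separated by a tab.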
12 | 13 | - After saving the images, to get the image representations, run the following in the SimCLR folder 14 | 

 15 | {python run_with_pretrain_with_micle.py}
 16 | 
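- This step is assumed to save the extracted embeddings and the matching patient IDs as .csv files (e.g. 'train_feature.csv' / 'train_id.csv'), which the *_kmeans.py scripts below load with np.loadtxt and pd.read_csv.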
17 | - After getting the image representations and the non-image features, run 18 | 

 19 | {python adni_kmeans.py}
 20 | 
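- This writes a pickle under 'MultiplexNetwork/data/' holding a dict with keys 'label' (one-hot labels), 'type0'..'type{K-1}' (binary patient-patient adjacency matrices), 'train_idx' / 'val_idx' / 'test_idx', and 'feature' (image embedding concatenated with one-hot clinical features).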
21 | 22 | - Then run 23 |

 24 | {python main.py}
 25 | 
26 | 27 | ### How to Run (OASIS-3 dataset) 28 | - Download the OASIS-3 dataset from https://www.oasis-brains.org/. 29 | - Preprocess the OASIS-3 dataset (Saving images) 30 | - Download the .nii files from the OASIS-3 site 31 | - make an './OASIS/all_paths.txt' file in which each line is "Patient ID" + '\t' + "path of the patient's .nii file" 32 | - then run 33 | 

 34 | {python oasis_atlas.py}
 35 | 
36 | 37 | - After saving the images, to get the image representations, run the following in the SimCLR folder 38 | 

 39 | {python run_with_pretrain_with_micle.py}
 40 | 
41 | 42 | 43 | - After getting the image representations and the non-image features, run 44 | 

 45 | {python oasis_kmeans.py}
 46 | 
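- 'oasis_kmeans.py' also accepts '--K' (number of clinical-feature groups, default 4) and '--thres' (comma-separated per-type similarity thresholds, default '0.9,0.9,0.9,0.9'), e.g.
{python oasis_kmeans.py --K 4 --thres 0.9,0.9,0.9,0.9}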
47 | 48 | 49 | - Then run 50 |

 51 | {python main.py}
 52 | 
53 | 54 | ### How to Run (ABIDE dataset) 55 | - Download the ABIDE dataset from https://adni.loni.usc.edu/. (same site as ADNI) 56 | - Preprocess the ABIDE dataset (Saving images) 57 | - Download the .nii files from the ABIDE site 58 | - make an './ABIDE/all_paths.txt' file in which each line is "Patient ID" + '\t' + "path of the patient's .nii file" 59 | - then run 60 | 

 61 | {python abide_atlas.py}
 62 | 
63 | 64 | - After saving the images, to get the image representations, run the following in the SimCLR folder 65 | 

 66 | {python run_with_pretrain_with_micle.py}
 67 | 
68 | 69 | 70 | - After getting the image representations and the non-image features, run 71 | 

 72 | {python abide_kmeans.py}
 73 | 
74 | 75 | - Then run 76 |

 77 | {python main.py}
 78 | 
79 | 80 | ### How to Run (QIN-Breast dataset) 81 | - Download the QIN-Breast dataset from https://wiki.cancerimagingarchive.net/display/Public/QIN-Breast. 82 | - Preprocess the QIN-Breast dataset (Saving images) 83 | - Download the .dcm files from the QIN-Breast site 84 | - make an './QIN/all_paths.txt' file in which each line is "Patient ID" + '\t' + "path of the patient's .dcm file" 85 | - then run 86 | 

 87 | {python qin_save.py}
 88 | 
89 | 90 | - After saving the images, to get the image representations, run the following in the SimCLR folder 91 | 

 92 | {python run_with_pretrain_with_micle.py}
 93 | 
94 | 95 | - After getting the image representations and the non-image features, run 96 | 

 97 | {python qin_kmeans.py}
 98 | 
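- Note: in this repository, the script that builds the QIN-Breast pickle is 'Preprocessing/Non_image_preprocessing/duke_kmeans.py' (it reads 'QIN_clinical.xlsx', accepts '--K' with default 4, and writes 'qin.pkl'); run it if 'qin_kmeans.py' is not present.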
99 | 100 | - Then run 101 |

102 | {python main.py}
103 | 
104 | 105 | ### How to Run (CMMD dataset) 106 | - Download the CMMD dataset from https://wiki.cancerimagingarchive.net/pages/viewpage.action?pageId=70230508. 107 | - Preprocess the CMMD dataset (Saving images) 108 | - Download the .dcm files from the CMMD site 109 | - make an './CMMD/all_paths.txt' file in which each line is "Patient ID" + '\t' + "path of the patient's .dcm file" 110 | - then run 111 | 

112 | {python cmmd_save.py}
113 | 
114 | 115 | - After saving the images, to get the image representations, run the following in the SimCLR folder 116 | 

117 | {python run_with_pretrain_with_micle.py}
118 | 
119 | 120 | - After getting the image representations and the non-image features, run the notebook below 121 | 

122 | {cmmd.ipynb}
123 | 
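- The notebook builds four adjacency types from 'LeftRight', 'abnormality', 'subtype', and 'age_' (cosine-similarity thresholds 0.9, 0.9, 0.9, 0.5) and writes './MultiplexNetwork/data/cmmd.pkl'.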
124 | 125 | - Then run 126 |

127 | {python main.py}
128 | 
129 | -------------------------------------------------------------------------------- /Preprocessing/sample_all_path.txt: -------------------------------------------------------------------------------- 1 | 50001 C:/Users/user/Desktop/dataset/50001/sample_00/S164623/50001_MRI_raw_20120830172854796_8631.nii 2 | 50002 C:/Users/user/Desktop/dataset/50002/sample_00/S164416/50002_MRI_raw_20120830155445855_410.nii 3 | 50003 C:/Users/user/Desktop/dataset/50003/sample_00/S164726/50003_MRI_raw_20120830181140636_736.nii 4 | 50004 C:/Users/user/Desktop/dataset/50004/sample_00/S165234/50004_MRI_raw_20120830220635982_244.nii 5 | 50005 C:/Users/user/Desktop/dataset/50005/sample_00/S165414/50005_MRI_raw_20120830234259859_424.nii 6 | 50006 C:/Users/user/Desktop/dataset/50006/sample_00/S165121/50006_MRI_raw_20120830205554190_131.nii 7 | 50007 C:/Users/user/Desktop/dataset/50007/sample_00/S164971/50007_MRI_raw_20120830194128385_981.nii -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Heterogeneous Graph Learning for Multi-modal Medical Data Analysis 2 | 3 | The official source code for the [**Heterogeneous Graph Learning for Multi-modal Medical Data Analysis**](https://arxiv.org/abs/2211.15158) paper, accepted at AAAI 2023. 4 | 5 | ### Overview 6 | Routine clinical visits of a patient produce not only image data, but also non-image data containing clinical information regarding the patient, i.e., medical data is multi-modal in nature. Such heterogeneous modalities offer different and complementary perspectives on the same patient, resulting in more accurate clinical decisions when they are properly combined. However, despite its significance, how to effectively fuse the multi-modal medical data into a unified framework has received relatively little attention. In this paper, we propose an effective graph-based framework called HetMed (Heterogeneous Graph Learning for Multi-modal Medical Data Analysis) for fusing the multi-modal medical data. Specifically, we construct a multiplex network that incorporates multiple types of non-image features of patients to capture the complex relationship between patients in a systematic way, which leads to more accurate clinical decisions. Extensive experiments on various real-world datasets demonstrate the superiority and practicality of HetMed. 7 | 8 | figure 9 | - Multiple modalities of medical data provide different and complementary views of the same patient. 10 | 11 | ### Run our framework 12 | 13 | - Due to memory limits and the data-use policies of the datasets, we cannot upload the medical image or non-image data. 14 | - In "Preprocessing", we describe how to obtain the datasets and how to preprocess them. 15 | 16 | - Instead, we upload embeddings of some datasets produced by the Image embedder, together with the non-image information. 17 | - To reproduce our results, use these .pkl files to run the DMGI model. 18 | 19 | - ABIDE dataset 20 | 

21 | cd MultiplexNetwork
22 | python main.py --data abide --methapath type0,type1,type2,type3 --isSemi --isAttn --sup_coef 1.0
23 | 
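- Here '--methapath' selects which adjacency types from the .pkl are used, '--isSemi' enables the semi-supervised head weighted by '--sup_coef', and '--isAttn' enables the attention module (see 'MultiplexNetwork/main.py' for the full argument list).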
24 | 25 | - CMMD dataset 26 | 

27 | cd MultiplexNetwork
28 | python main.py --data cmmd --methapath type0,type1,type2,type3 --isSemi --isAttn --patience 20 --sup_coef 0.01
29 | 
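- To sanity-check a shipped pickle before training (a minimal sketch, assuming the .pkl follows the schema written by the preprocessing scripts):
```
import pickle

with open('MultiplexNetwork/data/abide.pkl', 'rb') as f:
    data = pickle.load(f)

# expected keys: 'label', 'type0'..'type3', 'train_idx', 'val_idx', 'test_idx', 'feature'
print(sorted(data.keys()))
print(data['feature'].shape, data['label'].shape)
```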
30 | 31 | 32 | ### Cite (Bibtex) 33 | - Please refer to the following paper if you find HetMed useful in your research: 34 | - Kim, Sein and Lee, Namkyeong and Lee, Junseok and Hyun, Dongmin and Park, Chanyoung. "Heterogeneous Graph Learning for Multi-modal Medical Data Analysis" AAAI 2023. 35 | - Bibtex 36 | ``` 37 | @article{kim2022heterogeneous, 38 | title={Heterogeneous Graph Learning for Multi-modal Medical Data Analysis}, 39 | author={Kim, Sein and Lee, Namkyeong and Lee, Junseok and Hyun, Dongmin and Park, Chanyoung}, 40 | journal={arXiv preprint arXiv:2211.15158}, 41 | year={2022} 42 | } 43 | ``` 44 | --------------------------------------------------------------------------------