├── models ├── PGD_NET_Spiral_T_1.pt ├── PGD_NET_Spiral_T_2.pt ├── PGD_NET_Spiral_T_5.pt └── miccai_bloch_gen_200210-180246.pt ├── data.py ├── bloch.py ├── operators.py ├── README.md ├── network_arch.py ├── utils.py └── main.py /models/PGD_NET_Spiral_T_1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/edongdongchen/PGD-Net/HEAD/models/PGD_NET_Spiral_T_1.pt -------------------------------------------------------------------------------- /models/PGD_NET_Spiral_T_2.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/edongdongchen/PGD-Net/HEAD/models/PGD_NET_Spiral_T_2.pt -------------------------------------------------------------------------------- /models/PGD_NET_Spiral_T_5.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/edongdongchen/PGD-Net/HEAD/models/PGD_NET_Spiral_T_5.pt -------------------------------------------------------------------------------- /models/miccai_bloch_gen_200210-180246.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/edongdongchen/PGD-Net/HEAD/models/miccai_bloch_gen_200210-180246.pt -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.dataset import Dataset 2 | import torch 3 | import os 4 | import numpy as np 5 | import scipy.io as scio 6 | 7 | class MRFData(Dataset): 8 | def __init__(self, mod='train', sampling='S'): 9 | ''' 10 | The data was from a partner company and we are restricted from sharing. 11 | However, our code can be flexibly transferred or directly used on other customized MRF dataset. 12 | ''' 13 | 14 | # users need to specify their own ground truth source 15 | # X: MRF images 16 | # Y: MRF (kspace) measurements 17 | # M: tissue property maps 18 | self.scaling = 1 19 | if mod=='train': 20 | mat_path = './matfile/train_dataXS11_s.mat' 21 | if mod=='test': 22 | mat_path = './matfile/test_dataXS11_s.mat' 23 | 24 | mat_data = scio.loadmat(mat_path) 25 | X = np.transpose(mat_data['X'], (0,3,1,2)) 26 | Y = np.stack([mat_data['y_s_real'], mat_data['y_s_imag']], axis=-1) 27 | 28 | M = np.transpose(mat_data['MRF_maps'], (0,3,1,2)) 29 | 30 | self.x = torch.from_numpy(X) 31 | self.y = torch.from_numpy(Y) 32 | self.m = torch.from_numpy(M) 33 | 34 | self.y = torch.from_numpy(Y).unsqueeze(-2) 35 | 36 | 37 | print('MRF-{}-dataset:\nCS Fourier y: {},\nMRF image x:{},\nTissue map m: {}'.format(sampling, self.y.shape, self.x.shape, self.m.shape)) 38 | 39 | 40 | def __getitem__(self, index): 41 | return self.x[index], self.m[index], self.y[index] 42 | 43 | def __len__(self): 44 | return len(self.x) 45 | 46 | 47 | class BlochData(Dataset): 48 | def __init__(self, mat_path='./matfile/Ramp2D_200reps_guido_trainingset.mat'): 49 | assert os.path.exists(mat_path) 50 | mat_data = scio.loadmat(mat_path) 51 | X = mat_data['X'] 52 | Y = mat_data['Y'] 53 | Y = np.concatenate((Y, np.ones((len(Y),1), dtype=float)), axis=1) 54 | # MRF image 55 | self.X = np.reshape(X, (len(X), 1, 1, 10),'F') # N * H * W * C 56 | self.X = np.transpose(self.X, (0, 3, 1, 2)) # covert to N * C * H * W 57 | self.X = torch.from_numpy(self.X) 58 | # tissue map 59 | self.M = np.reshape(Y, (len(Y), 1, 1, 3),'F') # N * H * W * C 60 | self.M = np.transpose(self.M, (0,3,1,2)) # covert to N * C * H * W 61 | self.M = torch.from_numpy(self.M) 62 | 63 | def __getitem__(self, index): 64 | x, m = self.X[index], self.M[index] 65 | return x, m 66 | 67 | def __len__(self): 68 | return len(self.X) -------------------------------------------------------------------------------- /bloch.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from utils import get_timestamp, logT 6 | import data 7 | 8 | """ 9 | Training code of neural network BLOCH estimator for MR fingerprinting in the paper 10 | @inproceedings{chen2020compressive, 11 | author = {Dongdong Chen and Mike E. Davies and Mohammad Golbabaee}, 12 | title = {Compressive MR Fingerprinting reconstruction with Neural Proximal Gradient iterations}, 13 | booktitle={International Conference on Medical image computing and computer-assisted intervention (MICCAI)}, 14 | year = {2020} 15 | } 16 | """ 17 | 18 | class BlochDecoder(nn.Module): 19 | def __init__(self, in_channels=2, out_channels=10): 20 | super(BlochDecoder, self).__init__() 21 | self.Conv = nn.Sequential( 22 | nn.Conv2d(in_channels=in_channels, out_channels=300, kernel_size=1, bias=True), 23 | nn.ReLU(inplace=True), 24 | nn.Conv2d(in_channels=300, out_channels=out_channels, kernel_size=1, bias=True) 25 | ) 26 | 27 | def forward(self, map):# map: N x C x H x W 28 | t1t2, pd = map[:, 0:2, :, :], map[:, 2:3, :, :] 29 | x = self.Conv(t1t2) 30 | pd = pd.repeat(1,10,1,1) # Nx1xHxW -> Nx10xHxW 31 | x = x * pd # MRF image (scaled by pd) 32 | return x 33 | 34 | def train_bloch(lr=0.01, EPOCH=50, BATCH_SIZE=500, weight_decay=1e-10, dtype=torch.cuda.FloatTensor): 35 | bloch = BlochDecoder().cuda() 36 | criterion = torch.nn.MSELoss().cuda() 37 | optimizer = torch.optim.Adam(bloch.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=weight_decay) 38 | 39 | dataloader = torch.utils.data.DataLoader(dataset=data.BlochData(), batch_size=BATCH_SIZE, shuffle=True) 40 | 41 | 42 | bloch.train() 43 | for iter in range(EPOCH): 44 | loss_epoch = [] 45 | for x, m in dataloader: 46 | x, m = x.type(dtype), m.type(dtype) 47 | x_hat = bloch(m) 48 | loss = criterion(x_hat, x) 49 | optimizer.zero_grad() 50 | loss.backward() 51 | optimizer.step() 52 | loss_epoch.append(loss.item()) 53 | logT("===> Epoch {}: Loss: {:.10f}".format(iter, np.mean(loss_epoch))) 54 | filename = './models/miccai_bloch_gen_{}.pt'.format(get_timestamp()) 55 | torch.save(bloch.state_dict(), filename) 56 | print('Saved Bloch generator to the disk: {}'.format(filename)) 57 | 58 | def BLOCH(): 59 | """ 60 | return: a pre-trained BLOCH estimator. 61 | One can directly apply this BLOCH estimator to simulate the BLOCH equation response in practice 62 | """ 63 | bloch = BlochDecoder() 64 | bloch.load_state_dict(torch.load('./models/miccai_bloch_gen_200210-180246.pt')) 65 | bloch.eval() 66 | return bloch -------------------------------------------------------------------------------- /operators.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import scipy.io as scio 5 | 6 | from utils import to_tensor, complex_matmul 7 | 8 | """ 9 | PyTorch implementation of forward/adjoint operators for compressive MR fingerprinting (CS-MRF) in the paper 10 | @inproceedings{chen2020compressive, 11 | author = {Dongdong Chen and Mike E. Davies and Mohammad Golbabaee}, 12 | title = {Compressive MR Fingerprinting reconstruction with Neural Proximal Gradient iterations}, 13 | booktitle={International Conference on Medical image computing and computer-assisted intervention (MICCAI)}, 14 | year = {2020} 15 | } 16 | """ 17 | 18 | class OperatorBatch(nn.Module): 19 | def __init__(self, C=10, H=128, W=128, sampling='S', dtype=torch.cuda.FloatTensor): 20 | super(OperatorBatch, self).__init__() 21 | self.C, self.H, self.W, self.dtype = C, H, W, dtype 22 | 23 | # subspace dimension reduction 24 | pca_dic_data = scio.loadmat('./matfile/pytorch_Ramp2D_200reps_guido_trainingset.mat') 25 | 26 | self.V = to_tensor(pca_dic_data['V']) 27 | self.V_conj = to_tensor(pca_dic_data['V_conj']) 28 | if dtype is not None: 29 | self.V, self.V_conj = self.V.type(dtype), self.V_conj.type(dtype) 30 | assert self.V.shape[1]==self.C, 'Channels Error!' 31 | 32 | # init mask 33 | mask_data = scio.loadmat('./matfile/train_dataXS11_s.mat') 34 | if sampling=='C': 35 | mask = mask_data['samplemask_s'] 36 | if sampling=='S': 37 | mask = mask_data['samplemask_s'] 38 | self.mask = np.squeeze(np.asarray(mask-1)) 39 | print('mask.shape', self.mask.shape) 40 | 41 | def forward(self, x): 42 | return self.fwd_helper(x, self.mask, self.H, self.W, self.V) 43 | 44 | def adjoint(self, y, only_real=True): 45 | return self.adj_helper(y, self.mask, self.H, self.W, self.V_conj, only_real) 46 | 47 | def fwd_helper(self, x, mask, H, W, V):#x:NCHW 48 | N = x.shape[0] 49 | x = torch.stack([x, torch.zeros(x.shape).type(self.dtype)], dim=-1) 50 | x = torch.fft(x, 2) 51 | x = x.reshape(N, -1, H*W, 2) 52 | x = complex_matmul(V, x) 53 | x = x.reshape(N, -1, W, H, 2) 54 | x = x.permute(0, 1,3,2,4) 55 | x = x.reshape(N, -1, 1,2) 56 | x = x[:,mask, :,:]/np.sqrt(H*W) 57 | return x 58 | 59 | def adj_helper(self, y, mask, H, W, V_conj, only_real): #y:NCHW2 60 | N = y.shape[0] 61 | L = V_conj.shape[1] 62 | x = torch.zeros(N, L*H*W, 1, 2).type(self.dtype) 63 | x[:,mask,:,:]=y 64 | x = x.reshape(N, L, -1, 2) 65 | x = complex_matmul(V_conj, x) 66 | x = x.reshape(N, -1, W, H, 2) 67 | x = x.permute(0, 1, 3, 2, 4) 68 | x = torch.ifft(x,2)*np.sqrt(H*W) 69 | return x[...,0] if only_real else x 70 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Proximal Gradient Descent Network (PGD-Net) for Magnetic Resonance Fingerprinting 2 | This is the PyTorch implementation of MICCAI'20 paper 3 | 4 | [Compressive MR Fingerprinting reconstruction with Neural Proximal Gradient iterations](https://arxiv.org/pdf/2006.15271.pdf). 5 | 6 | By [Dongdong Chen](http://dongdongchen.com), [Mike E. Davies](https://scholar.google.co.uk/citations?user=dwmfR3oAAAAJ&hl=en), [Mohammad Golbabaee](https://mgolbabaee.wordpress.com/). 7 | 8 | The University of Edinburgh, The University of Bath. 9 | 10 | ### Table of Contents 11 | 0. [Keywords](#Keywords) 12 | 0. [Abstract](#Abstract) 13 | 0. [Requirement](#Requirement) 14 | 0. [Usage](#Usage) 15 | 0. [Citation](#citation) 16 | 17 | ### Keywords 18 | 19 | Magnetic Resonance Fingerprinting (MRF), Physics, Proximal gradient Descent (PGD), Inverse problem, Deep learning. 20 | 21 | ### Abstract 22 | 23 | Consistency of the predictions with respect to the physical forward model is pivotal for reliably solving inverse problems. This consistency is mostly un-controlled in the current end-to-end deep learning methodologies proposed for the Magnetic Resonance Fingerprinting (MRF) problem. To address this, we propose PGD-Net, a learned proximal gradient descent framework that directly incorporates the forward acquisition and Bloch dynamic models within a recurrent learning mechanism. The PGD-Net adopts a compact neural proximal model for de-aliasing and quantitative inference, that can be flexibly trained on scarce MRF training datasets. Our numerical experiments show that the PGD-Net can achieve a superior quantitative inference accuracy, much smaller storage requirement, and a comparable runtime to the recent deep learning MRF baselines, while being much faster than the dictionary matching schemes. 24 | 25 | ### Requirement 26 | 0. PyTorch >=1.0 27 | 0. CUDA >=8.5 28 | 29 | ### Usage 30 | 0. check the demo_train() and demo_test() in [main.py](https://github.com/edongdongchen/PGD-Net/blob/master/main.py) 31 | 0. the neura network architecture of PGD-Net ('proxnet') is defined in [network_arch.py](https://github.com/edongdongchen/PGD-Net/blob/master/network_arch.py) 32 | 0. the forward and adjoint operators are implemented in [operators.py](https://github.com/edongdongchen/PGD-Net/blob/master/operators.py) 33 | 0. note: the data was from a partner company and we are restricted from sharing. Users need to specify their own dataset. Our code can be flexibly transferred or directly used on other customized MRF dataset. 34 | 35 | ### Citation 36 | 37 | If you use these models in your research, please cite: 38 | 39 | @inproceedings{chen2020compressive, 40 | author = {Dongdong Chen and Mike E. Davies and Mohammad Golbabaee}, 41 | title = {Compressive MR Fingerprinting reconstruction with Neural Proximal Gradient iterations}, 42 | booktitle={International Conference on Medical image computing and computer-assisted intervention (MICCAI)}, 43 | year = {2020} 44 | } 45 | -------------------------------------------------------------------------------- /network_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | """ 5 | Neural Network Architecture of PGD-Net for compressive MR fingerprinting (CS-MRF) in the paper 6 | @inproceedings{chen2020compressive, 7 | author = {Dongdong Chen and Mike E. Davies and Mohammad Golbabaee}, 8 | title = {Compressive MR Fingerprinting reconstruction with Neural Proximal Gradient iterations}, 9 | booktitle={International Conference on Medical image computing and computer-assisted intervention (MICCAI)}, 10 | year = {2020} 11 | } 12 | """ 13 | 14 | class ProxNet(torch.nn.Module): 15 | def __init__(self, args): 16 | super(ProxNet, self).__init__() 17 | self.args = args 18 | self.alpha = torch.autograd.Variable(torch.Tensor(args.initial_alpha).type(args.dtype), requires_grad=True) 19 | self.transformnet = ResNet(in_channels=10, out_channels=3, nRS = 1, chRS=64, MRFNETch=64) 20 | self.relu = nn.ReLU(inplace=False) 21 | 22 | def forward(self, HTy, H, HT, bloch): 23 | x = 0 24 | m_seq, x_seq = [], [] 25 | for t in range(self.args.time_step): 26 | a = self.relu(self.alpha[t]) + 1 27 | s = a*HTy if t == 0 else x - a* (HT(H(x)) - HTy) 28 | m = self.transformnet(s) 29 | x = bloch(m) 30 | 31 | m_seq.append(m) 32 | x_seq.append(x) 33 | return m_seq, x_seq 34 | 35 | 36 | class ResNet(nn.Module): 37 | def __init__(self, in_channels=10, out_channels=3, nRS = 2,chRS=120, MRFNETch=400): 38 | super(ResNet, self).__init__() 39 | self.name = 'resnet' 40 | self.rsb1 = ResidualBlock(in_channels, in_channels, chRS) 41 | self.rsb2 = ResidualBlock(in_channels, in_channels, chRS) 42 | 43 | self.mrfnet = MRFNET(in_channels,out_channels,MRFNETch) 44 | 45 | def forward(self, x): 46 | # encoding path 47 | x1 = self.rsb1(x) 48 | # or 49 | # x1 = self.rsb2(x1) 50 | xout = self.mrfnet(x1) 51 | return xout 52 | 53 | def conv3x3(in_channels, out_channels, stride=1): 54 | return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=True) 55 | 56 | class ResidualBlock(nn.Module): 57 | def __init__(self, in_channels, out_channels,chRS, stride=1, downsample=None): 58 | super(ResidualBlock, self).__init__() 59 | self.conv1 = conv3x3(in_channels, chRS, stride) 60 | self.conv2 = conv3x3(chRS, out_channels) 61 | self.relu = nn.ReLU(inplace=True) 62 | self.downsample = downsample 63 | 64 | def forward(self, x): 65 | residual = x 66 | out = self.conv1(x) 67 | out = self.relu(out) 68 | out = self.conv2(out) 69 | out += residual 70 | out = self.relu(out) 71 | return out 72 | 73 | class MRFNET(nn.Module): 74 | def __init__(self, ch_in=10, ch_out=3, MRFNETch=400): 75 | super(MRFNET, self).__init__() 76 | self.name='mrfcnn' 77 | self.cnn = nn.Sequential( 78 | nn.Conv2d(ch_in, MRFNETch, kernel_size=1, padding=0), 79 | nn.ReLU(inplace=True), 80 | nn.Conv2d(MRFNETch, MRFNETch, kernel_size=1, padding=0), 81 | nn.ReLU(inplace=True), 82 | nn.Conv2d(MRFNETch, ch_out, kernel_size=1, padding=0), 83 | ) 84 | def forward(self, x): 85 | return self.cnn(x) 86 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import csv 4 | import time 5 | import numpy as np 6 | from datetime import datetime 7 | 8 | import torch 9 | 10 | """ 11 | Utils for compressive MR fingerprinting (CS-MRF) in the paper 12 | @inproceedings{chen2020compressive, 13 | author = {Dongdong Chen and Mike E. Davies and Mohammad Golbabaee}, 14 | title = {Compressive MR Fingerprinting reconstruction with Neural Proximal Gradient iterations}, 15 | booktitle={International Conference on Medical image computing and computer-assisted intervention (MICCAI)}, 16 | year = {2020} 17 | } 18 | """ 19 | 20 | 21 | def set_gpu(gpu): 22 | print('Current GPU:{}'.format(gpu)) 23 | torch.cuda.set_device(gpu) 24 | torch.backends.cudnn.enabled = True 25 | torch.backends.cudnn.benchmark =True 26 | dtype = torch.cuda.FloatTensor 27 | return dtype 28 | 29 | def check_paths(args): 30 | try: 31 | if not os.path.exists(args.save_model_dir): 32 | os.makedirs(args.save_model_dir) 33 | if args.checkpoint_model_dir is not None and not (os.path.exists(args.checkpoint_model_dir)): 34 | os.makedirs(args.checkpoint_model_dir) 35 | except OSError as e: 36 | print(e) 37 | sys.exit(1) 38 | 39 | def prefix(args): 40 | return '{}_{}_cuda_{}_sampling_{}_iter_{}_T_{}_x_{}_y_{}_m_{}_{}_{}_bs_{}_lr_{}_wd_{}'.format( 41 | args.filename, str(time.ctime()).replace(' ', '_'), 42 | args.cuda, args.sampling, args.epochs, args.time_step, 43 | args.loss_weight['x'], args.loss_weight['y'], 44 | args.loss_weight['m'][0], args.loss_weight['m'][1], args.loss_weight['m'][2], 45 | args.batch_size, args.lr, args.weight_decay) 46 | 47 | # -------------------------------- 48 | # logger 49 | # -------------------------------- 50 | def get_timestamp(): 51 | return datetime.now().strftime('%y%m%d-%H%M%S') 52 | 53 | class LOG(object): 54 | def __init__(self, filepath, filename, field_name=['iter', 'loss_x', 'loss_m', 'loss_y', 'loss_total', 'alpha']): 55 | self.filepath = filepath 56 | self.filename = filename 57 | self.field_name = field_name 58 | 59 | self.logfile, self.logwriter = csv_log(file_name=os.path.join(filepath, filename+'.csv'), field_name=field_name) 60 | self.logwriter.writeheader() 61 | 62 | def record(self, *args): 63 | dict = {} 64 | for i in range(len(self.field_name)): 65 | dict[self.field_name[i]]=args[i] 66 | self.logwriter.writerow(dict) 67 | 68 | def close(self): 69 | self.logfile.close() 70 | 71 | def print(self, msg): 72 | logT(msg) 73 | 74 | def csv_log(file_name, field_name): 75 | assert file_name is not None 76 | assert field_name is not None 77 | logfile = open(file_name, 'w') 78 | logwriter = csv.DictWriter(logfile, fieldnames=field_name) 79 | return logfile, logwriter 80 | 81 | def logT(*args, **kwargs): 82 | print(datetime.now().strftime("%Y-%m-%d %H:%M:%S:"), *args, **kwargs) 83 | 84 | def logger(args): 85 | logfile, logwriter = csv_log(file_name=os.path.join(args.net_dir, args.net_name+'.csv'), field_name=['iter', 'loss']) 86 | logwriter.writeheader() 87 | if args.opt['loss_type']=='mse': 88 | criterion = torch.nn.MSELoss().cuda() 89 | if args.opt['loss_type']=='l1': 90 | criterion = torch.nn.L1Loss().cuda() 91 | if args.opt['val_dataloader'] is not None: 92 | val_logfile, val_logwriter = csv_log(file_name=os.path.join(args.net_dir, args.net_name+'_val.csv'), field_name=['iter', 'loss']) 93 | val_logwriter.writeheader() 94 | return logfile, logwriter, val_logfile, val_logwriter, criterion 95 | else: 96 | return logfile, logwriter, criterion 97 | 98 | 99 | # -------------------------------- 100 | # Convert data type 101 | # -------------------------------- 102 | def to_tensor(data): 103 | """ 104 | Convert numpy array to PyTorch tensor. For complex arrays, the real and imaginary parts 105 | are stacked along the last dimension. 106 | Args: 107 | data (np.array): Input numpy array 108 | Returns: 109 | torch.Tensor: PyTorch version of data 110 | """ 111 | if np.iscomplexobj(data): 112 | data = np.stack((data.real, data.imag), axis=-1) 113 | return torch.from_numpy(data) 114 | 115 | def np_to_torch(img_np): 116 | '''Converts image in numpy.array to torch.Tensor. 117 | 118 | From C x W x H [0..1] to C x W x H [0..1] 119 | ''' 120 | return torch.from_numpy(img_np)[None, :] 121 | 122 | def torch_to_np(img_var): 123 | '''Converts an image in torch.Tensor format to np.array. 124 | 125 | From 1 x C x W x H [0..1] to C x W x H [0..1] 126 | ''' 127 | return img_var.detach().cpu().numpy() 128 | 129 | 130 | # -------------------------------- 131 | # complex-valued operation 132 | # -------------------------------- 133 | def complex_matmul(A, B): # A: (dim1, dim2, 2), B:(N, dim2, dim3, 2)' (a+bj)x(c+dj) = (ac-bd) + (bc+ad)j 134 | return torch.stack([torch.matmul(A[...,0], B[...,0]) - torch.matmul(A[...,1], B[...,1]), 135 | torch.matmul(A[...,1], B[...,0]) + torch.matmul(A[...,0], B[...,1])],dim=-1) 136 | 137 | def complex_abs(data): 138 | """ 139 | Compute the absolute value of a complex valued input tensor. 140 | Args: 141 | data (torch.Tensor): A complex valued tensor, where the size of the final dimension 142 | should be 2. 143 | Returns: 144 | torch.Tensor: Absolute value of data 145 | """ 146 | assert data.size(-1) == 2 147 | return (data ** 2).sum(dim=-1).sqrt() 148 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import numpy as np 5 | 6 | import torch 7 | 8 | import data 9 | from bloch import BLOCH 10 | from network_arch import ProxNet 11 | from operators import OperatorBatch 12 | from utils import set_gpu, LOG, logT, check_paths, prefix 13 | 14 | ''' 15 | PyCharm (Python 3.6.9) 16 | PyTorch 1.3 17 | Windows 10 or Linux 18 | Dongdong Chen (d.chen@ed.ac.uk) 19 | github: https://github.com/echendongdong/PGD-Net 20 | 21 | If you have any question, please feel free to contact with me. 22 | Dongdong Chen (e-mail: d.chen@ed.ac.uk) 23 | by Dongdong Chen (01/March/2020) 24 | ''' 25 | 26 | """ 27 | # -------------------------------------------- 28 | Training/Testing code (GPU) of PGD-Net for compressive MR fingerprinting in the paper 29 | @inproceedings{chen2020compressive, 30 | author = {Dongdong Chen and Mike E. Davies and Mohammad Golbabaee}, 31 | title = {Compressive MR Fingerprinting reconstruction with Neural Proximal Gradient iterations}, 32 | booktitle={International Conference on Medical image computing and computer-assisted intervention (MICCAI)}, 33 | year = {2020} 34 | } 35 | # -------------------------------------------- 36 | Note: The data was from a partner company and we are restricted from sharing. 37 | Users need to specify their own dataset. 38 | Our code can be flexibly transferred or directly used on other customized MRF dataset. 39 | # -------------------------------------------- 40 | """ 41 | 42 | 43 | def train_proxnet(args): 44 | check_paths(args) 45 | # init GPU configuration 46 | args.dtype = set_gpu(args.cuda) 47 | 48 | # init seed 49 | np.random.seed(args.seed) 50 | torch.manual_seed(args.seed) 51 | 52 | # define training data 53 | train_dataset = data.MRFData(mod='train', sampling=args.sampling) 54 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=args.batch_size, shuffle=True) 55 | 56 | # init operators (subsampling + subspace dimension reduction + Fourier transformation) 57 | operator = OperatorBatch(sampling=args.sampling.upper()).cuda() 58 | H, HT = operator.forward, operator.adjoint 59 | bloch = BLOCH().cuda() 60 | 61 | # init PGD-Net (proxnet) 62 | proxnet = ProxNet(args).cuda() 63 | 64 | # init optimizer 65 | optimizer = torch.optim.Adam([{'params': proxnet.transformnet.parameters(), 66 | 'lr': args.lr, 'weight_decay': args.weight_decay}, 67 | {'params': proxnet.alpha, 'lr': args.lr2}]) 68 | 69 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20], gamma=0.1) 70 | 71 | # init loss 72 | mse_loss = torch.nn.MSELoss()#.cuda() 73 | 74 | # init meters 75 | log = LOG(args.save_model_dir, filename=args.filename, field_name=['iter', 'loss_m', 'loss_x', 'loss_y', 'loss_total', 'alpha']) 76 | 77 | loss_epoch = 0 78 | loss_m_epoch, loss_x_epoch, loss_y_epoch =0,0,0 79 | 80 | # start PGD-Net training 81 | print('start training...') 82 | for e in range(args.epochs): 83 | proxnet.train() 84 | loss_m_seq = [] 85 | loss_x_seq = [] 86 | loss_y_seq = [] 87 | loss_total_seq = [] 88 | 89 | for x, m, y in train_loader: 90 | # covert data type (cuda) 91 | x, m, y = x.type(args.dtype), m.type(args.dtype), y.type(args.dtype) 92 | # add noise 93 | noise = args.noise_sigam * torch.randn(y.shape).type(args.dtype) 94 | HTy = HT(y + noise).type(args.dtype) 95 | 96 | # PGD-Net computation (iteration) 97 | # output the reconstructions (sequence) of MRF image x and its tissue property map m 98 | m_seq, x_seq = proxnet(HTy, H, HT, bloch) 99 | 100 | loss_x, loss_y, loss_m = 0,0,0 101 | for t in range(args.time_step): 102 | loss_y += mse_loss(H(x_seq[t]), y)/args.time_step 103 | for i in range(3): 104 | loss_m += args.loss_weight['m'][i] * mse_loss(m_seq[-1][:,i,:,:], m[:,i,:,:]) 105 | loss_x = mse_loss(x_seq[-1], x) 106 | 107 | # compute loss 108 | loss_total = loss_m + args.loss_weight['x'] * loss_x + args.loss_weight['y']*loss_y 109 | 110 | # update gradient 111 | optimizer.zero_grad() 112 | loss_total.backward() 113 | optimizer.step() 114 | 115 | # update meters 116 | loss_m_seq.append(loss_m.item()) 117 | loss_x_seq.append(loss_x.item()) 118 | loss_y_seq.append(loss_y.item()) 119 | loss_total_seq.append(loss_total.item()) 120 | 121 | # (scheduled) update learning rate 122 | scheduler.step() 123 | 124 | # print meters 125 | loss_m_epoch = np.mean(loss_m_seq) 126 | loss_x_epoch = np.mean(loss_x_seq) 127 | loss_y_epoch = np.mean(loss_y_seq) 128 | loss_epoch = np.mean(loss_total_seq) 129 | 130 | log.record(e+1, loss_m_epoch, loss_x_epoch, loss_y_epoch, loss_epoch, proxnet.alpha.detach().cpu().numpy()) 131 | logT("==>Epoch {}\tloss_m: {:.6f}\tloss_x: {:.6f}\tloss_y: {:.6f}\tloss_total: {:.6f}\talpha: {}" 132 | .format(e + 1, loss_m_epoch, loss_x_epoch, loss_y_epoch, loss_epoch, proxnet.alpha.detach().cpu().numpy())) 133 | 134 | # save checkpoint 135 | if args.checkpoint_model_dir is not None and (e + 1) % args.checkpoint_interval == 0: 136 | proxnet.eval() 137 | ckpt = { 138 | 'epoch': e+1, 139 | 'loss_m': loss_m_epoch, 140 | 'loss_x': loss_x_epoch, 141 | 'loss_y': loss_y_epoch, 142 | 'total_loss': loss_epoch, 143 | 'net_state_dict': proxnet.state_dict(), 144 | 'optimizer_state_dict': optimizer.state_dict(), 145 | 'alpha': proxnet.alpha.detach().cpu().numpy() 146 | } 147 | torch.save(ckpt, os.path.join(args.checkpoint_model_dir, 'ckp_epoch_{}.pt'.format(e))) 148 | proxnet.train() 149 | 150 | # save model 151 | proxnet.eval() 152 | state = { 153 | 'epoch':args.epochs, 154 | 'loss_m': loss_m_epoch, 155 | 'loss_x': loss_x_epoch, 156 | 'loss_y': loss_y_epoch, 157 | 'total_loss': loss_epoch, 158 | 'alpha': proxnet.alpha.detach().cpu().numpy(), 159 | 'net_state_dict': proxnet.state_dict(), 160 | 'optimizer_state_dict': optimizer.state_dict() 161 | } 162 | save_model_path = os.path.join(args.save_model_dir, log.filename+'.pt') 163 | torch.save(state, save_model_path) 164 | print("\nDone, trained model saved at", save_model_path) 165 | 166 | def test_proxnet(args): 167 | def load_proxnet(args): 168 | ckp = torch.load(args.net_path) 169 | alpha_learned = ckp['alpha'] 170 | 171 | net = ProxNet(args).cuda() 172 | net.load_state_dict(ckp['net_state_dict']) 173 | net.alpha = torch.from_numpy(alpha_learned) 174 | net.eval() 175 | print('alpha={}'.format(net.alpha)) 176 | return net 177 | 178 | operator = OperatorBatch(sampling=args.sampling.upper()).cuda() 179 | H, HT = operator.forward, operator.adjoint 180 | bloch = BLOCH().cuda() 181 | 182 | args.dtype = set_gpu(args.cuda) 183 | net = load_proxnet(args) 184 | batch_size = 1 185 | test_loader = torch.utils.data.DataLoader(dataset=data.MRFData(mod='test', sampling=args.sampling), 186 | batch_size=batch_size, shuffle=False) 187 | 188 | rmse_m, rmse_x, rmse_y = [],[],[] 189 | rmse_torch = lambda a,b:torch.norm(a-b, 2).detach().cpu().numpy()/torch.norm(b, 2).detach().cpu().numpy()/batch_size 190 | 191 | toc = time.time() 192 | for x, m, y in test_loader: 193 | m, y = m.type(args.dtype), y.type(args.dtype) 194 | HTy = HT(y).type(args.dtype) 195 | 196 | m_seq, x_seq = net(HTy, H, HT, bloch) 197 | m_hat = m_seq[-1] 198 | 199 | rmse_m.append(rmse_torch(m_hat, m)) 200 | 201 | elapsed = time.time() - toc 202 | print('time: {}'.format(elapsed / 16)) 203 | print('m error mean:{}, max: {}, std:{}'.format(np.mean(rmse_m), np.max(rmse_m), np.std(rmse_m))) 204 | 205 | 206 | if __name__=='__main__': 207 | def demo_train(): 208 | args = argparse.ArgumentParser().parse_args() 209 | 210 | args.cuda = 0 211 | args.seed = 5213 212 | args.sampling = 'S' # 'spiral' 213 | args.filename = 'pgd_net' 214 | 215 | args.epochs = 2000 216 | args.batch_size = 4 217 | args.noise_sigam = 0.01 218 | args.weight_decay = 1e-8 219 | args.checkpoint_interval = 100 220 | 221 | # learning rate for neural network 222 | args.lr = 1e-4 223 | # learning rate for alpha 224 | args.lr2 = .05 225 | # gamma, lambda, beta 226 | args.loss_weight = {'x': 0.001, 'y': 0.01, 'm': [1, 20, 2.5]} 227 | # PGD time step (T) 228 | args.time_step = 2 229 | # init alpha 230 | args.initial_alpha = np.asarray([2] * args.time_step) 231 | # init path 232 | args.prefix = prefix(args) 233 | args.save_model_dir = os.path.join('models', args.prefix) 234 | args.checkpoint_model_dir = os.path.join('models', args.prefix, 'ckp') 235 | print(args.prefix) 236 | 237 | # start to train 238 | train_proxnet(args) 239 | 240 | def demo_test(): 241 | args = argparse.ArgumentParser().parse_args() 242 | 243 | args.cuda = 0 244 | args.sampling = 'S' 245 | args.time_step = 2 246 | args.net_path = 'models/PGD_NET_Spiral_T_2.pt' 247 | 248 | test_proxnet(args) 249 | 250 | # demo_train() 251 | # demo_test() 252 | --------------------------------------------------------------------------------