├── README.md
└── modelCode
    ├── ChannelAttention.py
    ├── LoadData.py
    ├── PrjModule.py
    ├── RDSVD.py
    ├── SOUL_Net.py
    └── train.py

/README.md:
--------------------------------------------------------------------------------

# SOUL-Net: A Sparse and Low-Rank Unrolling Network for Spectral CT Image Reconstruction

This code requires installing ctlib for projection/backprojection acceleration, implemented by W. Xia (code: https://github.com/xwj01/CTLIB).
Many thanks to W. Xia for providing that code.

If you use this code, please cite "SOUL-Net: A Sparse and Low-Rank Unrolling Network for Spectral CT Image Reconstruction".

# Environment
Windows 10, Python 3.9, CUDA 11.2, PyTorch 1.9.0

--------------------------------------------------------------------------------
/modelCode/ChannelAttention.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class channelAtten(nn.Module):
    """Squeeze-and-excitation style channel attention (assumes 256x256 feature maps)."""

    def __init__(self, hidden=256):
        super(channelAtten, self).__init__()
        self.pool1 = nn.AvgPool2d(256)  # global average pooling over 256x256 maps
        self.seconv1 = nn.Conv2d(hidden, hidden // 16, kernel_size=1)  # squeeze
        self.seconv2 = nn.Conv2d(hidden // 16, hidden, kernel_size=1)  # excite

    def forward(self, x):
        avgpoolres = self.pool1(x)   # (b, c, 1, 1) channel descriptors
        res = self.seconv1(avgpoolres)
        res = torch.relu(res)
        res = self.seconv2(res)
        res = torch.sigmoid(res)     # per-channel weights in (0, 1)
        res = res * x                # rescale the input channels
        return res

--------------------------------------------------------------------------------
/modelCode/LoadData.py:
--------------------------------------------------------------------------------
import os

import scipy.io as sio
import torch

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


class LData(torch.utils.data.Dataset):
    """Loads .mat samples holding the label image, noisy sinogram, and FBP result."""

    def __init__(self, **kwargs):
        super(LData, self).__init__()
        self.path = kwargs['datapath']
        self.train = kwargs['train']
        self.trainimgnum = kwargs['TrainImgNum']
        self.testimgnum = kwargs['TestImgNum']
        trainSetPath = []
        testSetPath = []
        if self.train:
            for j in range(self.trainimgnum):
                traindir = os.path.join(self.path, "train/IMG")
                fname = traindir + str(j + 1) + ".mat"  # e.g. <datapath>/train/IMG1.mat
                trainSetPath.append(fname)
            self.trainPath = trainSetPath
        else:
            for j in range(self.testimgnum):
                testdir = os.path.join(self.path, "test/IMG")
                fname = testdir + str(j + 1) + ".mat"   # e.g. <datapath>/test/IMG1.mat
                testSetPath.append(fname)
            self.testPath = testSetPath

    def __getitem__(self, index):
        imgpath = self.trainPath[index] if self.train else self.testPath[index]
        mat = sio.loadmat(imgpath)
        img = mat['Label']
        NoiseSino = mat['NoiseSino']
        fbpres = mat['FbpRes']
        return fbpres, NoiseSino, img

    def __len__(self):
        return self.trainimgnum if self.train else self.testimgnum
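
# Usage sketch (illustrative, not part of the repository): the datapath below is
# hypothetical and must contain train/IMG1.mat, IMG2.mat, ... with the 'Label',
# 'NoiseSino', and 'FbpRes' keys read above.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = LData(datapath="path/to/genedatas/", train=True,
                    TrainImgNum=400, TestImgNum=100)
    loader = DataLoader(dataset, batch_size=1, shuffle=True)
    fbpres, NoiseSino, img = next(iter(loader))  # FBP init, noisy sinogram, label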
--------------------------------------------------------------------------------
/modelCode/PrjModule.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.autograd import Function

import ctlib


class prj_module(nn.Module):
    """Gradient-descent data-fidelity step: x - weight * A^T (A x - p)."""

    def __init__(self, weight, options):
        super(prj_module, self).__init__()
        self.weight = weight  # learnable step size, shared with the parent Block
        self.options = nn.Parameter(options, requires_grad=False)  # scan geometry

    def forward(self, input_data, proj):
        return prj_fun.apply(input_data, self.weight, proj, self.options)


class prj_fun(Function):
    @staticmethod
    def forward(self, input_data, weight, proj, options):
        input_data = input_data.contiguous()
        proj = proj.contiguous()
        # A^T (A x - p): forward-project, take the sinogram residual, backproject
        temp = ctlib.projection(input_data, options, 0) - proj
        intervening_res = ctlib.backprojection(temp, options, 0)
        self.save_for_backward(intervening_res, weight, options)
        out = input_data - weight * intervening_res
        return out

    @staticmethod
    def backward(self, grad_output):
        intervening_res, weight, options = self.saved_tensors
        grad_output = grad_output.contiguous()
        # d out / d x = I - weight * A^T A, applied to the incoming gradient
        temp = ctlib.projection(grad_output, options, 0)
        temp = ctlib.backprojection(temp.contiguous(), options, 0)
        grad_input = grad_output - weight * temp
        # d out / d weight = -A^T (A x - p), reduced against the incoming gradient
        grad_weight = -(intervening_res * grad_output).sum().view(-1)
        return grad_input, grad_weight, None, None
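
# Sanity-check sketch for the hand-written backward above (illustrative, not
# part of the repository): torch.autograd.gradcheck on a toy version of
# prj_fun, with a small dense matrix A standing in for the ctlib
# projection/backprojection pair.
if __name__ == "__main__":
    from torch.autograd import gradcheck

    class toy_prj_fun(Function):
        """Same structure as prj_fun, with a dense A in place of ctlib."""
        @staticmethod
        def forward(ctx, x, weight, proj, A):
            intervening_res = A.t() @ (A @ x - proj)  # A^T (A x - p)
            ctx.save_for_backward(intervening_res, weight, A)
            return x - weight * intervening_res

        @staticmethod
        def backward(ctx, grad_output):
            intervening_res, weight, A = ctx.saved_tensors
            grad_input = grad_output - weight * (A.t() @ (A @ grad_output))
            grad_weight = -(intervening_res * grad_output).sum().view(-1)
            return grad_input, grad_weight, None, None

    A = torch.randn(12, 8, dtype=torch.double)
    x = torch.randn(8, 1, dtype=torch.double, requires_grad=True)
    w = torch.randn(1, dtype=torch.double, requires_grad=True)
    p = torch.randn(12, 1, dtype=torch.double)
    print(gradcheck(toy_prj_fun.apply, (x, w, p, A)))  # expected: True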
--------------------------------------------------------------------------------
/modelCode/RDSVD.py:
--------------------------------------------------------------------------------
import torch


def geometric_approximation_1(s):
    """Geometric-series approximation of the 1 / (s_i^2 - s_j^2) coefficients
    in the SVD gradient, kept finite when singular values nearly coincide."""
    ba, chan = s.shape
    dtype = s.dtype

    I = torch.ones([ba, chan], device=s.device).type(dtype)
    s = s + 1e-8 * I  # keep singular values away from zero
    I = torch.diag_embed(I)
    # p_ij = min(s_i / s_j, s_j / s_i) < 1 off the diagonal
    p = s.unsqueeze(-1) / s.unsqueeze(-2) - I
    p = torch.where(p < 1., p, 1. / p)

    a1 = s.unsqueeze(-1).repeat(1, 1, chan).permute(0, 2, 1)
    a1_t = a1.permute(0, 2, 1)
    # 1 / (s_i + s_j); no need to subtract I because the diagonal of a1 is zeroed below
    lamiPluslamj = 1. / (s.unsqueeze(-1) + s.unsqueeze(-2))

    # signed 1 / max(s_i, s_j)
    a1 = 1. / torch.where(a1 >= a1_t, a1, -a1_t)
    a1 *= torch.ones_like(I, device=s.device).type(dtype) - I  # zero the diagonal
    # truncated geometric series 1 + p + ... + p^9, approximating 1 / (1 - p)
    p_app = torch.ones_like(p)
    p_hat = torch.ones_like(p)
    for i in range(9):
        p_hat = p_hat * p
        p_app += p_hat
    a1 = lamiPluslamj * a1 * p_app

    return a1


class svdv2_1(torch.autograd.Function):
    """SVD with a numerically robust custom backward pass."""

    @staticmethod
    def forward(ctx, M):
        try:
            U, S, V = torch.svd(M, some=True, compute_uv=True)
        except RuntimeError:
            # SVD can fail to converge when cond(M) is too large;
            # retry on a slightly perturbed matrix
            U, S, V = torch.svd(M + 1e-3 * M.mean() * torch.rand_like(M),
                                some=True, compute_uv=True)
        dtype = M.dtype
        S[S <= torch.finfo(dtype).eps] = torch.finfo(dtype).eps
        ctx.save_for_backward(M, U, S, V)
        return U, S, V

    @staticmethod
    def backward(ctx, dL_du, dL_ds, dL_dv):
        M, U, S, V = ctx.saved_tensors
        k = geometric_approximation_1(S)
        # replace any remaining inf/nan coefficients with finite values
        k[k == float('inf')] = k[k != float('inf')].max()
        k[k == float('-inf')] = k[k != float('-inf')].min()
        k[k != k] = k.max()
        K_t = k.permute(0, 2, 1)
        diag_s = torch.diag_embed(S)
        VT = torch.permute(V, [0, 2, 1])
        # gradient propagated through S and V (dL_du is not used here)
        tt = 2 * torch.matmul(diag_s, K_t * torch.matmul(VT, dL_dv))
        grad_input = tt + torch.diag_embed(dL_ds)
        US = torch.matmul(U, grad_input)
        grad_input = torch.matmul(US, VT)
        return grad_input
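
# Quick check of the custom SVD (illustrative, not part of the repository):
# the forward matches torch.svd up to eps-clamping of tiny singular values,
# and the approximate backward stays finite. Shapes mirror SOUL_Net.py, which
# matricizes (b, c, h, w) images into (b, h*w, c) matrices.
if __name__ == "__main__":
    M = torch.randn(2, 16, 5, requires_grad=True)
    U, S, V = svdv2_1.apply(M)
    print(torch.dist(torch.matmul(U * S.unsqueeze(1), V.transpose(1, 2)), M))  # ~0
    loss = S.sum() + V.sum()
    loss.backward()
    print(torch.isfinite(M.grad).all())  # expected: True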
--------------------------------------------------------------------------------
/modelCode/SOUL_Net.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

from ChannelAttention import channelAtten
from PrjModule import prj_module
from RDSVD import svdv2_1


class Block(nn.Module):
    """One unrolled SOUL-Net iteration: data-fidelity step, low-rank step via
    singular value thresholding, CNN sparse prior, and dual-variable update."""

    def __init__(self, options):
        super(Block, self).__init__()
        self.weight = nn.Parameter(torch.zeros(1))  # fidelity step size
        self.block1 = prj_module(self.weight, options)
        self.convmodel2 = nn.Sequential(
            nn.Conv2d(5, 256, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(256, 256, 3, 1, 1),
        )
        self.convmodel3 = nn.Sequential(
            nn.Conv2d(256, 256, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(256, 5, 3, 1, 1),
        )
        self.thre = nn.Parameter(torch.zeros([1, 5]))  # not used in forward
        self.thre1 = nn.Parameter(torch.zeros(1))      # learned singular value threshold
        self.chanatten = channelAtten(256)
        self.rho = nn.Parameter(torch.zeros(1) + 1)    # penalty parameter

    def func1(self, x):
        # singular value soft-thresholding with a threshold learned relative
        # to the largest singular value of each sample
        U, S, V = svdv2_1.apply(x)
        VT = V.permute(0, 2, 1)
        mythre = torch.sigmoid(self.thre1) * S[:, 0]
        mythre = torch.unsqueeze(mythre, -1)
        S = torch.relu(S - mythre)
        S = torch.diag_embed(S)
        US = torch.matmul(U, S)
        USV = torch.matmul(US, VT)
        return USV, 0

    def RX(self, X):
        # matricize: (b, c, h, w) -> (b, h*w, c), one column per energy channel
        b, c, h, w = X.shape
        X_0 = torch.reshape(X, [b, c, h * w])
        X_0 = torch.permute(X_0, [0, 2, 1])
        return X_0

    def RTX(self, X, shape):
        # inverse of RX: (b, h*w, c) -> (b, c, h, w)
        b, c, h, w = shape
        X_0 = torch.permute(X, [0, 2, 1])
        X_0 = torch.reshape(X_0, [b, c, h, w])
        return X_0

    def lowRankSparse(self, X, proj, BB):
        # low-rank step on the matricized image (ADMM-style splitting)
        X_0 = self.RX(X)
        X_1 = X_0 - BB
        Z, _ = self.func1(X_1)
        tmp = X_0 - BB - Z
        temp = self.RTX(tmp, X.shape)
        # data-fidelity step plus the low-rank penalty gradient
        r = self.fidelity(X, proj) - self.weight * self.rho * temp
        # CNN sparse prior with channel attention and a residual connection
        S_k = self.convmodel2(r)
        S_k = self.chanatten(S_k)
        S_k = self.convmodel3(S_k)
        S_k = S_k + r
        # dual-variable update
        RXn = self.RX(S_k)
        BB = BB + Z - RXn
        return S_k, BB

    def fidelity(self, input, proj):
        # project each of the b*c energy-channel images independently
        b, c, wimg, himg = input.shape
        _, _, wsino, hsino = proj.shape
        tmp = torch.reshape(input, [b * c, 1, wimg, himg])
        projtmp = torch.reshape(proj, [b * c, 1, wsino, hsino])
        tmp1 = self.block1(tmp, projtmp)
        res = torch.reshape(tmp1, [b, c, wimg, himg])
        return res

    def forward(self, myinput, proj, BB):
        res, BB = self.lowRankSparse(myinput, proj, BB)
        return res, BB


class nBlock(nn.Module):
    """SOUL-Net: a cascade of blocknum unrolled Blocks sharing one scan geometry."""

    def __init__(self, **kwargs):
        super(nBlock, self).__init__()
        self.iternum = kwargs['blocknum']
        views = kwargs['views']
        dets = kwargs['dets']
        width = kwargs['width']
        height = kwargs['height']
        dImg = kwargs['dImg']
        dDet = kwargs['dDet']
        dAng = kwargs['dAng']
        s2r = kwargs['s2r']
        d2r = kwargs['d2r']
        binshift = kwargs['binshift']
        options = torch.Tensor([views, dets, width, height, dImg, dDet, dAng, s2r, d2r, binshift])
        Blocklist = []
        for i in range(0, self.iternum):
            Blocklist.append(Block(options))
        self.Blocklist = nn.ModuleList(Blocklist)

    def forward(self, input1, proj):
        outputlist = [input1]
        b, c, h, w = input1.shape
        # dual variable, initialized to zero on the same device as the input
        BB = torch.zeros([b, h * w, c], device=input1.device)
        for layer_idx in range(self.iternum):
            output, BB = self.Blocklist[layer_idx](outputlist[layer_idx], proj, BB)
            outputlist.append(output)
        return outputlist[-1]
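
# Standalone sketch of the singular value soft-thresholding used in func1
# (illustrative, not part of the repository; a fixed threshold tau replaces
# the learned, spectrum-relative one above):
if __name__ == "__main__":
    def svt(X, tau):
        U, S, V = torch.svd(X, some=True)
        S = torch.relu(S - tau)  # shrink the spectrum
        return torch.matmul(U * S.unsqueeze(1), V.transpose(1, 2))

    # energy channels that are near-scalings of one another -> nearly rank one
    base = torch.randn(1, 4096, 1)
    X = base * torch.tensor([1.0, 0.8, 0.6, 0.5, 0.4]) + 0.01 * torch.randn(1, 4096, 5)
    Z = svt(X, tau=1.0)
    print(torch.linalg.matrix_rank(Z))  # typically 1: noise directions are shrunk away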
--------------------------------------------------------------------------------
/modelCode/train.py:
--------------------------------------------------------------------------------
import argparse
import time

import numpy as np
import torch
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from torch.utils.data import DataLoader

import LoadData
from SOUL_Net import nBlock

parser = argparse.ArgumentParser()
parser.add_argument("--epochs", type=int, default=300, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=1, help="size of the batches")
parser.add_argument("--lr", type=float, default=1e-4, help="adam: learning rate")
parser.add_argument("--num_block", type=int, default=10)
parser.add_argument("--model_save_path", type=str, default="saved_models/1st")
parser.add_argument("--data_path", type=str, default="../../genedatas/ctlib_64_1024_72_58_98178_3.5_3/")
parser.add_argument("--models_path", type=str, default="models/")
parser.add_argument("--testres_path", type=str, default="testres/")
parser.add_argument('--checkpoint_interval', type=int, default=1)
parser.add_argument('--initmethod', type=str, default="Fbp")
parser.add_argument('--method', type=str, default="SOUL_Net_ctlib_64_1024_72_58_98178_3.5_3_2ci")
parser.add_argument('--TrainImgNum', type=int, default=400)
parser.add_argument('--TestImgNum', type=int, default=100)
parser.add_argument('--Spectrumlen', type=int, default=5)
opt = parser.parse_args()
useCuda = torch.cuda.is_available()


def Train():
    if useCuda:
        print("use cuda")
        torch.cuda.set_device(0)
    else:
        print("Do not use cuda")
    LoadDatas = LoadData.LData(TrainImgNum=opt.TrainImgNum, TestImgNum=opt.TestImgNum,
                               datapath=opt.data_path, train=True)
    LoadTestDatas = LoadData.LData(TrainImgNum=opt.TrainImgNum, TestImgNum=opt.TestImgNum,
                                   datapath=opt.data_path, train=False)
    train_loader = DataLoader(dataset=LoadDatas, batch_size=opt.batch_size, shuffle=True)
    test_loader = DataLoader(dataset=LoadTestDatas, batch_size=opt.batch_size, shuffle=False)
    # the scan geometry must match the one used to generate the data
    net = nBlock(blocknum=opt.num_block, views=64, dets=1024, width=256, height=256,
                 dImg=0.0072, dDet=0.0058, dAng=0.098178, s2r=3.5, d2r=3, binshift=0)
    if useCuda:
        net = net.cuda()

    optimizer = torch.optim.Adam(net.parameters(), lr=opt.lr)
    l1criterion = torch.nn.L1Loss()
    Loss = []
    for epoch in range(opt.epochs + 1):
        starttime = time.time()
        traLoss = 0
        net.train()
        for i, (x, Y, img) in enumerate(train_loader):
            batch_Sino = Y.float()
            batch_x = x.float()
            if useCuda:
                batch_Sino = batch_Sino.cuda()
                batch_x = batch_x.cuda()
            batch_Label = img.float()
            out = net(batch_x, batch_Sino)
            if useCuda:
                out = out.cpu()
            finalLoss = l1criterion(out, batch_Label)
            optimizer.zero_grad()
            finalLoss.backward()
            optimizer.step()
            traLoss += finalLoss.item()
        print("epoch = ", epoch, " loss = ", traLoss)
        Loss.append(traLoss)
        # checkpoint name encodes method, init method, block number, batch size and lr
        torch.save({'state_dict': net.state_dict(), 'Loss': Loss,
                    'optimizer': optimizer.state_dict()},
                   "../" + opt.models_path + opt.method + '_' + str(opt.initmethod) + '_'
                   + str(opt.num_block) + '_' + str(opt.batch_size) + '_' + str(opt.lr) + '.pth')
        endtime = time.time()
        print('cost time= ', endtime - starttime, ' s')
        totalSSIM = []
        totalPSNR = []
        if epoch % 10 == 9 or epoch == 0:
            SsimRes = [0] * 5
            PredImg = []
            PSNR = [0] * 5
            with torch.no_grad():
                net.eval()
                for j, (testx, testY, testimg) in enumerate(test_loader):
                    testx = testx.float()
                    testY = testY.float()
                    if useCuda:
                        testx = testx.cuda()
                        testY = testY.cuda()
                    testout = net(testx, testY).cpu()
                    PredImg.append(testout)
                    # accumulate per-spectrum SSIM and PSNR over the test set
                    for bat in range(opt.batch_size):
                        for spe in range(opt.Spectrumlen):
                            SsimRes[spe] += ssim(testout[bat, spe, :, :].numpy(),
                                                 testimg[bat, spe, :, :].numpy(), data_range=1)
                            PSNR[spe] += psnr(testout[bat, spe, :, :].numpy(),
                                              testimg[bat, spe, :, :].numpy(), data_range=1)
                SsimRes = np.array(SsimRes) / opt.TestImgNum
                PSNR = np.array(PSNR) / opt.TestImgNum
                totalSSIM.append(SsimRes)
                totalPSNR.append(PSNR)
                torch.save({'ssim': totalSSIM,
                            'psnr': totalPSNR,
                            'predimg': PredImg},
                           "../" + opt.testres_path + opt.method + '_' + str(opt.initmethod) + '_'
                           + str(opt.num_block) + '_' + str(opt.batch_size) + '_' + str(opt.lr) + '.pth')
                print("test ssim= ", SsimRes)
                print("test PSNR= ", PSNR)


if __name__ == "__main__":
    Train()
--------------------------------------------------------------------------------
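
A minimal inference sketch after training (illustrative: the checkpoint name follows the pattern train.py saves under its default arguments, and fbp_init / noisy_sino are placeholder tensors shaped like the dataset's FbpRes and NoiseSino entries, assumed (1, 5, 256, 256) and (1, 5, 64, 1024) for this geometry):

    import torch
    from SOUL_Net import nBlock

    net = nBlock(blocknum=10, views=64, dets=1024, width=256, height=256,
                 dImg=0.0072, dDet=0.0058, dAng=0.098178, s2r=3.5, d2r=3, binshift=0).cuda()
    ckpt = torch.load("../models/SOUL_Net_ctlib_64_1024_72_58_98178_3.5_3_2ci_Fbp_10_1_0.0001.pth")
    net.load_state_dict(ckpt['state_dict'])
    net.eval()
    with torch.no_grad():
        recon = net(fbp_init.cuda(), noisy_sino.cuda())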