├── README.md
├── demos
│   ├── AHPNet.py
│   ├── AdaptiveNet.py
│   ├── AirNet.py
│   ├── CNNRPGD.py
│   ├── DBP.py
│   ├── DSigNet.py
│   ├── FBPConvNet.py
│   ├── FistaNet.py
│   ├── HDNet.py
│   ├── KSAE.py
│   ├── LEARN.py
│   ├── LPD.py
│   ├── MAGIC.py
│   ├── MetaInvNet.py
│   ├── MomentumNet.py
│   ├── RED_CNN.py
│   ├── VVBPTensorNet.py
│   ├── __init__.py
│   ├── config.json
│   ├── iCTNet.py
│   ├── iRadonMap.py
│   └── vis_tools.py
├── recon
│   ├── __init__.py
│   └── models
│       ├── AHPNet.py
│       ├── AdaptiveNet.py
│       ├── AirNet.py
│       ├── DBP.py
│       ├── DSigNet.py
│       ├── FBPConvNet.py
│       ├── FistaNet.py
│       ├── FramingUNet.py
│       ├── HDNet.py
│       ├── KSAE.py
│       ├── LEARN.py
│       ├── LEARN_FBP.py
│       ├── LPD.py
│       ├── MAGIC.py
│       ├── MetaInvNet.py
│       ├── MomentumNet.py
│       ├── RED_CNN.py
│       ├── UNet.py
│       ├── VVBPTensorNet.py
│       ├── __init__.py
│       ├── iCTNet.py
│       └── iRadonMap.py
└── requirements.txt
/README.md: -------------------------------------------------------------------------------- 1 | # Synergizing-Physics-Model-based-and-Data-driven-Methods-for-Low-Dose-CT 2 | 3 | This is the code for the physics/model-based data-driven methods for low-dose CT. The methods are reproduced by us based on the content of the corresponding papers. If you find anything wrong with the code, please contact us. 4 | 5 | 6 | Before using this code, please install our library for CT reconstruction: 7 | 8 | https://github.com/xwj01/CTLIB 9 | 10 | Then type: 11 | 12 | ``` 13 | pip install -r requirements.txt 14 | python setup.py install 15 | 16 | ``` 17 | 18 | The dataset can be accessed at: 19 | https://drive.google.com/file/d/1WocZQo7f4Zx8wlrHywMNlUMtiQUCjPfA/view?usp=sharing
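Each demo script reads `demos/config.json`, and `setup_parser` turns every entry into a command-line flag with the given type, help string, default, and optional `nargs`; note that the `type` field is a string passed through `eval` (e.g. `"int"`, `"float"`). The file itself is not reproduced here, but an illustrative entry would look like this (the key names follow the scripts; the defaults shown are placeholders, not the shipped values):

```
{
    "views": {"type": "int", "help": "number of projection views", "default": 512},
    "lr": {"type": "float", "help": "initial learning rate", "default": 1e-4}
}
```

Any of these fields can then be overridden on the command line, e.g. `python AHPNet.py --lr 2e-4`.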
-------------------------------------------------------------------------------- /demos/AHPNet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | import argparse 5 | import scipy.io as scio 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | import pytorch_lightning as pl 10 | from torch.utils.data import Dataset 11 | from torch.utils.data import DataLoader 12 | from pytorch_lightning.callbacks import ModelCheckpoint 13 | from recon.models import AHPNet 14 | from vis_tools import Visualizer 15 | 16 | def setup_parser(arguments, title): 17 | parser = argparse.ArgumentParser(description=title, 18 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 19 | for key, val in arguments.items(): 20 | parser.add_argument('--%s' % key, 21 | type=eval(val["type"]), 22 | help=val["help"], 23 | default=val["default"], 24 | nargs=val["nargs"] if "nargs" in val else None) 25 | return parser 26 | 27 | def get_parameters(title=None): 28 | with open("config.json") as data_file: 29 | data = json.load(data_file) 30 | parser = setup_parser(data, title) 31 | parameters = parser.parse_args() 32 | return parameters 33 | 34 | class net(pl.LightningModule): 35 | def __init__(self, args): 36 | super().__init__() 37 | options = torch.tensor([args.views, args.dets, args.width, args.height, 38 | args.dImg, args.dDet, args.Ang0, args.dAng, 39 | args.s2r, args.d2r, args.binshift, args.scan_type]) 40 | self.model = AHPNet(options) 41 | self.epochs = args.epochs 42 | self.lr = args.lr 43 | self.is_vis_show = args.is_vis_show 44 | self.show_win = args.show_win 45 | self.is_res_save = args.is_res_save 46 | self.res_dir = args.res_dir 47 | self.options = options.cuda() 48 | if self.is_vis_show: 49 | self.vis = Visualizer(env='AHPNet') 50 | 51 | def forward(self, x, p): 52 | out = self.model(x, p) 53 | return out 54 | 55 | def training_step(self, batch, batch_idx): 56 | x, p, y = batch 57 | out = self(x, p) 58 | layer = out.size(1) 59 | loss = 0.0 60 | for i in range(0, layer-1): 61 | loss = loss + F.mse_loss(out[:,[i],:,:], y) * 0.8 62 | loss = loss + F.mse_loss(out[:,[-1],:,:], y) 63 | self.log("train_loss", loss, on_step=True, on_epoch=True) 64 | if self.is_vis_show: 65 | self.vis_show(loss.detach(), x.detach(), y.detach(), out.detach()) 66 | return loss 67 | 68 | def validation_step(self, batch, batch_idx): 69 | x, p, y = batch 70 | out = self(x, p) 71 | loss = F.mse_loss(out[:,[-1],:,:], y) 72 | self.log("val_loss", loss, on_step=True, on_epoch=True) 73 | 74 | def test_step(self, batch, batch_idx): 75 | x, p, y, res_name = batch 76 | out = self(x, p) 77 | if self.is_res_save: 78 | self.res_save(out[:,[-1],:,:], res_name) 79 | 80 | def configure_optimizers(self): 81 | optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr) 82 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.epochs) 83 | return [optimizer], [scheduler] 84 | 85 | def show_win_norm(self, y): 86 | x = y.clone() 87 | x[x<self.show_win[0]] = self.show_win[0] 88 | x[x>self.show_win[1]] = self.show_win[1] 89 | x = (x - self.show_win[0]) / (self.show_win[1] - self.show_win[0]) * 255 90 | return x 91 | 92 | def vis_show(self, loss, x, y, out, mode='Train'): 93 | self.vis.plot(mode + ' Loss', loss.item()) 94 | self.vis.img(mode + ' Ground Truth', self.show_win_norm(y).cpu()) 95 | self.vis.img(mode + ' Input', self.show_win_norm(x).cpu()) 96 | self.vis.img(mode + ' Result', self.show_win_norm(out).cpu()) 97 | 98 | def res_save(self, out, res_name): 99 | res = out.cpu().numpy() 100 | if not os.path.exists(self.res_dir): 101 | os.mkdir(self.res_dir) 102 | for i in range(res.shape[0]): 103 | scio.savemat(self.res_dir + '/' + res_name[i], {'data':res[i].squeeze()}) 104 | 105 | class data_loader(Dataset): 106 | def __init__(self, root, dose, mode): 107 | self.x_dir_name = 'input_' + dose 108 | self.x_path = os.path.join(root, mode, self.x_dir_name) 109 | self.mode = mode 110 | self.files_x = np.array(sorted(glob.glob(os.path.join(self.x_path, 'data') + '*.mat'))) 111 | 112 | def __getitem__(self, index): 113 | file_x = self.files_x[index] 114 | file_p = file_x.replace('input', 'projection') 115 | file_y = file_x.replace(self.x_dir_name, 'label') 116 | input_data = scio.loadmat(file_x)['data'] 117 | prj_data = scio.loadmat(file_p)['data'] 118 | label_data = scio.loadmat(file_y)['data'] 119 | input_data = torch.FloatTensor(input_data).unsqueeze_(0) 120 | prj_data = torch.FloatTensor(prj_data).unsqueeze_(0) 121 | label_data = torch.FloatTensor(label_data).unsqueeze_(0) 122 | if self.mode == 'train' or self.mode == 'vali': 123 | return input_data, prj_data, label_data 124 | elif self.mode == 'test': 125 | res_name = file_x[-13:] 126 | return input_data, prj_data, label_data, res_name 127 | 128 | def __len__(self): 129 | return len(self.files_x) 130 | 131 | if __name__ == "__main__": 132 | args = get_parameters() 133 | network = net(args) 134 | checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_last=True, save_top_k=3, mode="min") 135 | trainer = pl.Trainer(gpus=args.gpu_ids if args.is_specified_gpus else args.gpus, log_every_n_steps=1, max_epochs=args.epochs, callbacks=[checkpoint_callback], strategy="ddp") 136 | train_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'train'), batch_size=args.batch_size, shuffle=True, num_workers=args.cpus) 137 | vali_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'vali'), batch_size=args.batch_size, shuffle=False, num_workers=args.cpus) 138 |
test_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'test'), batch_size=args.batch_size, shuffle=False, num_workers=args.cpus) 139 | trainer.fit(network, train_loader, vali_loader) 140 | # trainer.fit(network, train_loader, vali_loader, ckpt_path='lightning_logs/version_1/checkpoints/last.ckpt') 141 | trainer.test(network, test_loader, ckpt_path='best') 142 |
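All of the demos share the `data_loader` above, which implies the following dataset layout (directory names are inferred from the code; `<dose>` is the value of the `--dose` argument, and every `.mat` file stores its array under the key `data`):

```
<data_root_dir>/
├── train
│   ├── input_<dose>       # low-dose images x: data*.mat
│   ├── projection_<dose>  # low-dose sinograms p: data*.mat
│   └── label              # normal-dose images y: data*.mat
├── vali                   # same sub-directories as train
└── test                   # same sub-directories as train
```

The FBPConvNet demo is the one exception: it trains on pre-extracted patches read from `input_patch_<dose>` and `label_patch`.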
-------------------------------------------------------------------------------- /demos/AdaptiveNet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | import argparse 5 | import scipy.io as scio 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | import pytorch_lightning as pl 10 | from torch.utils.data import Dataset 11 | from torch.utils.data import DataLoader 12 | from pytorch_lightning.callbacks import ModelCheckpoint 13 | from recon.models import AdaptiveNet 14 | from vis_tools import Visualizer 15 | 16 | def setup_parser(arguments, title): 17 | parser = argparse.ArgumentParser(description=title, 18 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 19 | for key, val in arguments.items(): 20 | parser.add_argument('--%s' % key, 21 | type=eval(val["type"]), 22 | help=val["help"], 23 | default=val["default"], 24 | nargs=val["nargs"] if "nargs" in val else None) 25 | return parser 26 | 27 | def get_parameters(title=None): 28 | with open("config.json") as data_file: 29 | data = json.load(data_file) 30 | parser = setup_parser(data, title) 31 | parameters = parser.parse_args() 32 | return parameters 33 | 34 | class net(pl.LightningModule): 35 | def __init__(self, args): 36 | super().__init__() 37 | options = torch.tensor([args.views, args.dets, args.width, args.height, 38 | args.dImg, args.dDet, args.Ang0, args.dAng, 39 | args.s2r, args.d2r, args.binshift, args.scan_type]) 40 | self.model = AdaptiveNet(options) 41 | self.epochs = args.epochs 42 | self.lr = args.lr 43 | self.is_vis_show = args.is_vis_show 44 | self.show_win = args.show_win 45 | self.is_res_save = args.is_res_save 46 | self.res_dir = args.res_dir 47 | if self.is_vis_show: 48 | self.vis = Visualizer(env='AdaptiveNet') 49 | 50 | def forward(self, p): 51 | out = self.model(p) 52 | return out 53 | 54 | def training_step(self, batch, batch_idx): 55 | x, p, y = batch 56 | out = self(p) 57 | loss = F.mse_loss(out, y) 58 | self.log("train_loss", loss, on_step=True, on_epoch=True) 59 | if self.is_vis_show: 60 | self.vis_show(loss.detach(), x.detach(), y.detach(), out.detach()) 61 | return loss 62 | 63 | def validation_step(self, batch, batch_idx): 64 | x, p, y = batch 65 | out = self(p) 66 | loss = F.mse_loss(out, y) 67 | self.log("val_loss", loss) 68 | 69 | def test_step(self, batch, batch_idx): 70 | x, p, y, res_name = batch 71 | out = self(p) 72 | if self.is_res_save: 73 | self.res_save(out, res_name) 74 | 75 | def configure_optimizers(self): 76 | optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr) 77 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000) 78 | return [optimizer], [scheduler] 79 | 80 | def show_win_norm(self, y): 81 | x = y.clone() 82 | x[x<self.show_win[0]] = self.show_win[0] 83 | x[x>self.show_win[1]] = self.show_win[1] 84 | x = (x - self.show_win[0]) / (self.show_win[1] - self.show_win[0]) * 255 85 | return x 86 | 87 | def vis_show(self, loss, x, y, out, mode='Train'): 88 | self.vis.plot(mode + ' Loss', loss.item()) 89 | self.vis.img(mode + ' Ground Truth', self.show_win_norm(y).cpu()) 90 | self.vis.img(mode + ' Input', self.show_win_norm(x).cpu()) 91 | self.vis.img(mode + ' Result', self.show_win_norm(out).cpu()) 92 | 93 | def res_save(self, out, res_name): 94 | res = out.cpu().numpy() 95 | if not os.path.exists(self.res_dir): 96 | os.mkdir(self.res_dir) 97 | for i in range(res.shape[0]): 98 | scio.savemat(self.res_dir + '/' + res_name[i], {'data':res[i].squeeze()}) 99 | 100 | class data_loader(Dataset): 101 | def __init__(self, root, dose, mode): 102 | self.x_dir_name = 'input_' + dose 103 | self.x_path = os.path.join(root, mode, self.x_dir_name) 104 | self.mode = mode 105 | self.files_x = np.array(sorted(glob.glob(os.path.join(self.x_path, 'data') + '*.mat'))) 106 | 107 | def __getitem__(self, index): 108 | file_x = self.files_x[index] 109 | file_p = file_x.replace('input', 'projection') 110 | file_y = file_x.replace(self.x_dir_name, 'label') 111 | input_data = scio.loadmat(file_x)['data'] 112 | prj_data = scio.loadmat(file_p)['data'] 113 | label_data = scio.loadmat(file_y)['data'] 114 | input_data = torch.FloatTensor(input_data).unsqueeze_(0) 115 | prj_data = torch.FloatTensor(prj_data).unsqueeze_(0) 116 | label_data = torch.FloatTensor(label_data).unsqueeze_(0) 117 | if self.mode == 'train' or self.mode == 'vali': 118 | return input_data, prj_data, label_data 119 | elif self.mode == 'test': 120 | res_name = file_x[-13:] 121 | return input_data, prj_data, label_data, res_name 122 | 123 | def __len__(self): 124 | return len(self.files_x) 125 | 126 | if __name__ == "__main__": 127 | args = get_parameters() 128 | network = net(args) 129 | checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_last=True, save_top_k=3, mode="min") 130 | trainer = pl.Trainer(gpus=args.gpu_ids if args.is_specified_gpus else args.gpus, log_every_n_steps=1, max_epochs=args.epochs, callbacks=[checkpoint_callback]) 131 | train_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'train'), batch_size=args.batch_size, shuffle=True, num_workers=args.cpus) 132 | vali_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'vali'), batch_size=args.batch_size, shuffle=False, num_workers=args.cpus) 133 | test_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'test'), batch_size=args.batch_size, shuffle=False, num_workers=args.cpus) 134 | trainer.fit(network, train_loader, vali_loader) 135 | # trainer.fit(network, train_loader, vali_loader, ckpt_path='lightning_logs/version_0/checkpoints/last.ckpt') 136 | trainer.test(network, test_loader, ckpt_path='best') 137 | -------------------------------------------------------------------------------- /demos/AirNet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | import argparse 5 | import scipy.io as scio 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | import pytorch_lightning as pl 10 | from torch.utils.data import Dataset 11 | from torch.utils.data import DataLoader 12 | from pytorch_lightning.callbacks import ModelCheckpoint 13 | from recon.models import AirNet 14 | from vis_tools import Visualizer 15 | 16 | def setup_parser(arguments, title): 17 | parser = argparse.ArgumentParser(description=title, 18 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 19 | for key, val in arguments.items(): 20 | parser.add_argument('--%s' % key, 21 | type=eval(val["type"]), 22 | help=val["help"], 23 | default=val["default"], 24 | nargs=val["nargs"] if "nargs" in val else None) 25 | return parser 26 | 27 | def get_parameters(title=None): 28 | with
open("config.json") as data_file: 29 | data = json.load(data_file) 30 | parser = setup_parser(data, title) 31 | parameters = parser.parse_args() 32 | return parameters 33 | 34 | class net(pl.LightningModule): 35 | def __init__(self, args): 36 | super().__init__() 37 | options = torch.tensor([args.views, args.dets, args.width, args.height, 38 | args.dImg, args.dDet, args.Ang0, args.dAng, 39 | args.s2r, args.d2r, args.binshift, args.scan_type]) 40 | self.model = AirNet(options) 41 | self.epochs = args.epochs 42 | self.lr = args.lr 43 | self.is_vis_show = args.is_vis_show 44 | self.show_win = args.show_win 45 | self.is_res_save = args.is_res_save 46 | self.res_dir = args.res_dir 47 | if self.is_vis_show: 48 | self.vis = Visualizer(env='AirNet') 49 | 50 | def forward(self, x, p): 51 | out = self.model(x, p) 52 | return out 53 | 54 | def training_step(self, batch, batch_idx): 55 | x, p, y = batch 56 | out = self(x, p) 57 | loss = F.mse_loss(out, y) 58 | self.log("train_loss", loss, on_step=True, on_epoch=True) 59 | if self.is_vis_show: 60 | self.vis_show(loss.detach(), x.detach(), y.detach(), out.detach()) 61 | return loss 62 | 63 | def validation_step(self, batch, batch_idx): 64 | x, p, y = batch 65 | out = self(x, p) 66 | loss = F.mse_loss(out, y) 67 | self.log("val_loss", loss) 68 | 69 | def test_step(self, batch, batch_idx): 70 | x, p, y, res_name = batch 71 | out = self(x, p) 72 | if self.is_res_save: 73 | self.res_save(out, res_name) 74 | 75 | def configure_optimizers(self): 76 | optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr) 77 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200) 78 | return [optimizer], [scheduler] 79 | 80 | def show_win_norm(self, y): 81 | x = y.clone() 82 | x[xself.show_win[1]] = self.show_win[1] 84 | x = (x - self.show_win[0]) / (self.show_win[1] - self.show_win[0]) * 255 85 | return x 86 | 87 | def vis_show(self, loss, x, y, out, mode='Train'): 88 | self.vis.plot(mode + ' Loss', loss.item()) 89 | self.vis.img(mode + ' Ground Truth', self.show_win_norm(y).cpu()) 90 | self.vis.img(mode + ' Input', self.show_win_norm(x).cpu()) 91 | self.vis.img(mode + ' Result', self.show_win_norm(out).cpu()) 92 | 93 | def res_save(self, out, res_name): 94 | res = out.cpu().numpy() 95 | if not os.path.exists(self.res_dir): 96 | os.mkdir(self.res_dir) 97 | for i in range(res.shape[0]): 98 | scio.savemat(self.res_dir + '/' + res_name[i], {'data':res[i].squeeze()}) 99 | 100 | class data_loader(Dataset): 101 | def __init__(self, root, dose, mode): 102 | self.x_dir_name = 'input_' + dose 103 | self.x_path = os.path.join(root, mode, self.x_dir_name) 104 | self.mode = mode 105 | self.files_x = np.array(sorted(glob.glob(os.path.join(self.x_path, 'data') + '*.mat'))) 106 | 107 | def __getitem__(self, index): 108 | file_x = self.files_x[index] 109 | file_p = file_x.replace('input', 'projection') 110 | file_y = file_x.replace(self.x_dir_name, 'label') 111 | input_data = scio.loadmat(file_x)['data'] 112 | prj_data = scio.loadmat(file_p)['data'] 113 | label_data = scio.loadmat(file_y)['data'] 114 | input_data = torch.FloatTensor(input_data).unsqueeze_(0) 115 | prj_data = torch.FloatTensor(prj_data).unsqueeze_(0) 116 | label_data = torch.FloatTensor(label_data).unsqueeze_(0) 117 | if self.mode == 'train' or self.mode == 'vali': 118 | return input_data, prj_data, label_data 119 | elif self.mode == 'test': 120 | res_name = file_x[-13:] 121 | return input_data, prj_data, label_data, res_name 122 | 123 | def __len__(self): 124 | return len(self.files_x) 125 | 126 
| if __name__ == "__main__": 127 | args = get_parameters() 128 | network = net(args) 129 | checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_last=True, save_top_k=3, mode="min") 130 | trainer = pl.Trainer(gpus=args.gpu_ids if args.is_specified_gpus else args.gpus, log_every_n_steps=1, max_epochs=args.epochs, callbacks=[checkpoint_callback], strategy="ddp") 131 | train_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'train'), batch_size=args.batch_size, shuffle=True, num_workers=args.cpus) 132 | vali_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'vali'), batch_size=args.batch_size, shuffle=False, num_workers=args.cpus) 133 | test_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'test'), batch_size=args.batch_size, shuffle=False, num_workers=args.cpus) 134 | trainer.fit(network, train_loader, vali_loader) 135 | # trainer.fit(network, train_loader, vali_loader, ckpt_path='lightning_logs/version_0/checkpoints/last.ckpt') 136 | trainer.test(network, test_loader, ckpt_path='best') 137 | -------------------------------------------------------------------------------- /demos/CNNRPGD.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | import argparse 5 | import scipy.io as scio 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | import pytorch_lightning as pl 10 | from torch.utils.data import Dataset 11 | from torch.utils.data import DataLoader 12 | from pytorch_lightning.callbacks import ModelCheckpoint 13 | from recon.models import FBPConvNet 14 | from vis_tools import Visualizer 15 | import ctlib 16 | 17 | def setup_parser(arguments, title): 18 | parser = argparse.ArgumentParser(description=title, 19 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 20 | for key, val in arguments.items(): 21 | parser.add_argument('--%s' % key, 22 | type=eval(val["type"]), 23 | help=val["help"], 24 | default=val["default"], 25 | nargs=val["nargs"] if "nargs" in val else None) 26 | return parser 27 | 28 | def get_parameters(title=None): 29 | with open("config.json") as data_file: 30 | data = json.load(data_file) 31 | parser = setup_parser(data, title) 32 | parameters = parser.parse_args() 33 | return parameters 34 | 35 | class net(pl.LightningModule): 36 | def __init__(self, args): 37 | super().__init__() 38 | options = torch.tensor([args.views, args.dets, args.width, args.height, 39 | args.dImg, args.dDet, args.Ang0, args.dAng, 40 | args.s2r, args.d2r, args.binshift, args.scan_type]) 41 | self.model = FBPConvNet() 42 | self.epochs = args.epochs 43 | self.T1 = int(self.epochs / 3) 44 | self.T2 = self.T1 * 2 45 | self.lr = args.lr 46 | self.gamma = args.gamma 47 | self.is_vis_show = args.is_vis_show 48 | self.show_win = args.show_win 49 | self.is_res_save = args.is_res_save 50 | self.res_dir = args.res_dir 51 | self.options = options 52 | if self.is_vis_show: 53 | self.vis = Visualizer(env='CNNRPGD') 54 | 55 | def forward(self, x): 56 | out = self.model(x) 57 | return out 58 | 59 | def training_step(self, batch, batch_idx): 60 | current_epoch = self.trainer.current_epoch 61 | x, p, y = batch 62 | if current_epoch >= self.T1: 63 | with torch.no_grad(): 64 | x_t = self(x).detach() 65 | if current_epoch >= self.T1 and current_epoch < self.T2: 66 | x = torch.cat((x, x_t), dim=0) 67 | y = torch.cat((y, y), dim=0) 68 | elif current_epoch >= self.T2: 69 | x = torch.cat((x, x_t, y), dim=0) 70 | y = torch.cat((y, y, y), dim=0) 71 | out = self(x) 72 
| loss = F.mse_loss(out, y) 73 | self.log("train_loss", loss, on_step=True, on_epoch=True) 74 | if self.is_vis_show: 75 | self.vis_show(loss.detach(), x.detach(), y.detach(), out.detach()) 76 | return loss 77 | 78 | def validation_step(self, batch, batch_idx): 79 | current_epoch = self.trainer.current_epoch 80 | x, p, y = batch 81 | if current_epoch >= self.T1: 82 | with torch.no_grad(): 83 | x_t = self(x).detach() 84 | if current_epoch >= self.T1 and current_epoch < self.T2: 85 | x = torch.cat((x, x_t), dim=0) 86 | y = torch.cat((y, y), dim=0) 87 | elif current_epoch >= self.T2: 88 | x = torch.cat((x, x_t, y), dim=0) 89 | y = torch.cat((y, y, y), dim=0) 90 | out = self(x) 91 | loss = F.mse_loss(out, y) 92 | self.log("val_loss", loss) 93 | 94 | def test_step(self, batch, batch_idx): 95 | x, p, y, res_name = batch 96 | options = self.options.to(x.device) 97 | alpha = torch.ones(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) 98 | # c_set = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95] 99 | # gamma_set = [0.1, 0.05, 0.01, 0.005, 0.001, 0.0005, 0.0001] 100 | # c_set = [0.5] 101 | # gamma_set = [0.5, 0.4, 0.3, 0.2, 0.09, 0.08, 0.07, 0.06] 102 | c_set = [0.5] 103 | gamma_set = [0.1] 104 | for c in c_set: 105 | for gamma in gamma_set: 106 | x_k = x.clone() 107 | for i in range(100): 108 | p_tmp = ctlib.projection(x_k.contiguous(), options) 109 | p_error = p - p_tmp 110 | x_error = ctlib.projection_t(p_error.contiguous(), options) 111 | z_k = self(x_k + gamma * x_error) 112 | norm_new = torch.linalg.vector_norm(z_k - x_k, dim=(2,3), keepdim=True) 113 | if i > 0: 114 | mask = norm_new > c * norm_old 115 | alpha = (c * norm_old / (norm_new + 1e-8) * mask + ~mask) * alpha 116 | norm_old = norm_new 117 | x_k = (1 - alpha) * x_k + alpha * z_k 118 | out = x_k 119 | if self.is_res_save: 120 | self.res_dir = 'result_c_' + str(c) + '_gamma_' + str(gamma) 121 | self.res_save(out, res_name) 122 | 123 | def configure_optimizers(self): 124 | optimizer = torch.optim.SGD(self.parameters(), lr=self.lr, momentum=0.99) 125 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[self.T1], gamma=0.1) 126 | return [optimizer], [scheduler] 127 | 128 | def show_win_norm(self, y): 129 | x = y.clone() 130 | x[x<self.show_win[0]] = self.show_win[0] 131 | x[x>self.show_win[1]] = self.show_win[1] 132 | x = (x - self.show_win[0]) / (self.show_win[1] - self.show_win[0]) * 255 133 | return x 134 | 135 | def vis_show(self, loss, x, y, out, mode='Train'): 136 | self.vis.plot(mode + ' Loss', loss.item()) 137 | self.vis.img(mode + ' Ground Truth', self.show_win_norm(y).cpu()) 138 | self.vis.img(mode + ' Input', self.show_win_norm(x).cpu()) 139 | self.vis.img(mode + ' Result', self.show_win_norm(out).cpu()) 140 | 141 | def res_save(self, out, res_name): 142 | res = out.cpu().numpy() 143 | if not os.path.exists(self.res_dir): 144 | os.mkdir(self.res_dir) 145 | for i in range(res.shape[0]): 146 | scio.savemat(self.res_dir + '/' + res_name[i], {'data':res[i].squeeze()}) 147 | 148 | class data_loader(Dataset): 149 | def __init__(self, root, dose, mode): 150 | self.x_dir_name = 'input_' + dose 151 | self.x_path = os.path.join(root, mode, self.x_dir_name) 152 | self.mode = mode 153 | self.files_x = np.array(sorted(glob.glob(os.path.join(self.x_path, 'data') + '*.mat'))) 154 | 155 | def __getitem__(self, index): 156 | file_x = self.files_x[index] 157 | file_p = file_x.replace('input', 'projection') 158 | file_y = file_x.replace(self.x_dir_name, 'label') 159 | input_data = scio.loadmat(file_x)['data'] 160 | prj_data = scio.loadmat(file_p)['data'] 161 |
label_data = scio.loadmat(file_y)['data'] 162 | input_data = torch.FloatTensor(input_data).unsqueeze_(0) 163 | prj_data = torch.FloatTensor(prj_data).unsqueeze_(0) 164 | label_data = torch.FloatTensor(label_data).unsqueeze_(0) 165 | if self.mode == 'train' or self.mode == 'vali': 166 | return input_data, prj_data, label_data 167 | elif self.mode == 'test': 168 | res_name = file_x[-13:] 169 | return input_data, prj_data, label_data, res_name 170 | 171 | def __len__(self): 172 | return len(self.files_x) 173 | 174 | if __name__ == "__main__": 175 | args = get_parameters() 176 | network = net(args) 177 | checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_last=True, save_top_k=3, mode="min") 178 | trainer = pl.Trainer(gpus=args.gpu_ids if args.is_specified_gpus else args.gpus, log_every_n_steps=1, max_epochs=args.epochs, callbacks=[checkpoint_callback], strategy="ddp") 179 | train_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'train'), batch_size=args.batch_size, shuffle=True, num_workers=args.cpus) 180 | vali_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'vali'), batch_size=args.batch_size, shuffle=False, num_workers=args.cpus) 181 | test_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'test'), batch_size=args.batch_size, shuffle=False, num_workers=args.cpus) 182 | trainer.fit(network, train_loader, vali_loader) 183 | # trainer.fit(network, train_loader, vali_loader, ckpt_path='lightning_logs/version_3/checkpoints/last.ckpt') 184 | trainer.test(network, test_loader, ckpt_path='best') 185 |
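For reference, the `test_step` above implements the relaxed projected gradient descent of CNN-RPGD. Writing A and Aᵀ for `ctlib.projection` and `ctlib.projection_t`, f for the trained FBPConvNet denoiser, γ for the step size, and c ∈ (0, 1) for the relaxation constant, each of the 100 iterations computes

```
z_k     = f(x_k + γ · Aᵀ(p − A x_k))
α_k     = α_{k−1} · c‖z_{k−1} − x_{k−1}‖ / ‖z_k − x_k‖   if ‖z_k − x_k‖ > c‖z_{k−1} − x_{k−1}‖, else α_{k−1}
x_{k+1} = (1 − α_k) x_k + α_k z_k
```

so the relaxation α shrinks whenever successive updates stop contracting, which is the damping the CNN-RPGD paper uses to enforce convergence.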
-------------------------------------------------------------------------------- /demos/DBP.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | import argparse 5 | import scipy.io as scio 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | import pytorch_lightning as pl 10 | from torch.utils.data import Dataset 11 | from torch.utils.data import DataLoader 12 | from pytorch_lightning.callbacks import ModelCheckpoint 13 | from recon.models import DBP 14 | from vis_tools import Visualizer 15 | 16 | def setup_parser(arguments, title): 17 | parser = argparse.ArgumentParser(description=title, 18 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 19 | for key, val in arguments.items(): 20 | parser.add_argument('--%s' % key, 21 | type=eval(val["type"]), 22 | help=val["help"], 23 | default=val["default"], 24 | nargs=val["nargs"] if "nargs" in val else None) 25 | return parser 26 | 27 | def get_parameters(title=None): 28 | with open("config.json") as data_file: 29 | data = json.load(data_file) 30 | parser = setup_parser(data, title) 31 | parameters = parser.parse_args() 32 | return parameters 33 | 34 | class net(pl.LightningModule): 35 | def __init__(self, args): 36 | super().__init__() 37 | options = torch.tensor([args.views, args.dets, args.width, args.height, 38 | args.dImg, args.dDet, args.Ang0, args.dAng, 39 | args.s2r, args.d2r, args.binshift, args.scan_type]) 40 | self.model = DBP(options) 41 | self.epochs = args.epochs 42 | self.lr = args.lr 43 | self.is_vis_show = args.is_vis_show 44 | self.show_win = args.show_win 45 | self.is_res_save = args.is_res_save 46 | self.res_dir = args.res_dir 47 | if self.is_vis_show: 48 | self.vis = Visualizer(env='DBP') 49 | 50 | def forward(self, p): 51 | out = self.model(p) 52 | return out 53 | 54 | def training_step(self, batch, batch_idx): 55 | x, p, y = batch 56 | out = self(p) 57 | loss = F.mse_loss(out, y) 58 | self.log("train_loss", loss, on_step=True, on_epoch=True) 59 | if self.is_vis_show: 60 | self.vis_show(loss.detach(), x.detach(), y.detach(), out.detach()) 61 | return loss 62 | 63 | def validation_step(self, batch, batch_idx): 64 | x, p, y = batch 65 | out = self(p) 66 | loss = F.mse_loss(out, y) 67 | self.log("val_loss", loss) 68 | 69 | def test_step(self, batch, batch_idx): 70 | x, p, y, res_name = batch 71 | out = self(p) 72 | if self.is_res_save: 73 | self.res_save(out, res_name) 74 | 75 | def configure_optimizers(self): 76 | optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr) 77 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.epochs) 78 | return [optimizer], [scheduler] 79 | 80 | def show_win_norm(self, y): 81 | x = y.clone() 82 | x[x<self.show_win[0]] = self.show_win[0] 83 | x[x>self.show_win[1]] = self.show_win[1] 84 | x = (x - self.show_win[0]) / (self.show_win[1] - self.show_win[0]) * 255 85 | return x 86 | 87 | def vis_show(self, loss, x, y, out, mode='Train'): 88 | self.vis.plot(mode + ' Loss', loss.item()) 89 | self.vis.img(mode + ' Ground Truth', self.show_win_norm(y).cpu()) 90 | self.vis.img(mode + ' Input', self.show_win_norm(x).cpu()) 91 | self.vis.img(mode + ' Result', self.show_win_norm(out).cpu()) 92 | 93 | def res_save(self, out, res_name): 94 | res = out.cpu().numpy() 95 | if not os.path.exists(self.res_dir): 96 | os.mkdir(self.res_dir) 97 | for i in range(res.shape[0]): 98 | scio.savemat(self.res_dir + '/' + res_name[i], {'data':res[i].squeeze()}) 99 | 100 | class data_loader(Dataset): 101 | def __init__(self, root, dose, mode): 102 | self.x_dir_name = 'input_' + dose 103 | self.x_path = os.path.join(root, mode, self.x_dir_name) 104 | self.mode = mode 105 | self.files_x = np.array(sorted(glob.glob(os.path.join(self.x_path, 'data') + '*.mat'))) 106 | 107 | def __getitem__(self, index): 108 | file_x = self.files_x[index] 109 | file_p = file_x.replace('input', 'projection') 110 | file_y = file_x.replace(self.x_dir_name, 'label') 111 | input_data = scio.loadmat(file_x)['data'] 112 | prj_data = scio.loadmat(file_p)['data'] 113 | label_data = scio.loadmat(file_y)['data'] 114 | input_data = torch.FloatTensor(input_data).unsqueeze_(0) 115 | prj_data = torch.FloatTensor(prj_data).unsqueeze_(0) 116 | label_data = torch.FloatTensor(label_data).unsqueeze_(0) 117 | if self.mode == 'train' or self.mode == 'vali': 118 | return input_data, prj_data, label_data 119 | elif self.mode == 'test': 120 | res_name = file_x[-13:] 121 | return input_data, prj_data, label_data, res_name 122 | 123 | def __len__(self): 124 | return len(self.files_x) 125 | 126 | if __name__ == "__main__": 127 | args = get_parameters() 128 | network = net(args) 129 | checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_last=True, save_top_k=3, mode="min") 130 | trainer = pl.Trainer(gpus=args.gpu_ids if args.is_specified_gpus else args.gpus, log_every_n_steps=1, max_epochs=args.epochs, callbacks=[checkpoint_callback]) 131 | train_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'train'), batch_size=args.batch_size, shuffle=True, num_workers=args.cpus) 132 | vali_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'vali'), batch_size=args.batch_size, shuffle=False, num_workers=args.cpus) 133 | test_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'test'), batch_size=args.batch_size, shuffle=False, num_workers=args.cpus) 134 | trainer.fit(network, train_loader, vali_loader) 135 | # trainer.fit(network, train_loader, vali_loader,
ckpt_path='lightning_logs/version_1/checkpoints/last.ckpt') 136 | trainer.test(network, test_loader, ckpt_path='best') 137 | -------------------------------------------------------------------------------- /demos/DSigNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import numpy as np 5 | from torch.autograd import Function 6 | from .iRadonMap import PixelIndexCal_cuda 7 | from .iRadonMap import BackProjNet 8 | 9 | def PixelIndexCal_DownSampling(length, width, lds, wds): 10 | length, width = int(length/lds), int(width/wds) 11 | ds_indices = torch.zeros(lds*wds, width*length).type(torch.LongTensor) 12 | for x in range(lds): 13 | for y in range(wds): 14 | k = x*width*wds+y 15 | for z in range(length): 16 | i, j = z*width, x*wds+y 17 | st = k+z*width*wds*lds 18 | ds_indices[j, i:i+width] = torch.tensor(range(st,st+width*wds, wds)) 19 | return ds_indices.view(-1) 20 | 21 | 22 | def PixelIndexCal_UpSampling(index, length, width): 23 | index = index.view(-1) 24 | _, ups_indices = index.sort(dim=0, descending=False) 25 | return ups_indices.view(-1) 26 | 27 | class DownSamplingBlock(nn.Module): 28 | def __init__(self, planes=8, length=512, width=736, lds=2, wds=2): 29 | super(DownSamplingBlock, self).__init__() 30 | self.length = int(length/lds) 31 | self.width = int(width/wds) 32 | self.extra_channel = lds*wds 33 | self.ds_index = nn.Parameter(PixelIndexCal_DownSampling(length, width, lds, wds), requires_grad=False) 34 | self.filter = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=True) 35 | self.ln = nn.GroupNorm(num_channels=planes, num_groups=1, affine=False) 36 | self.leakyrelu = nn.LeakyReLU(0.2, True) 37 | 38 | def forward(self, input): 39 | _, channel, length, width = input.size() 40 | output = torch.index_select(input.view(-1, channel, length*width), 2, self.ds_index) 41 | output = output.view(-1, channel*self.extra_channel, self.length, self.width) 42 | output = self.leakyrelu(self.ln(self.filter(output))) 43 | 44 | return output 45 | 46 | 47 | class UpSamplingBlock(nn.Module): 48 | def __init__(self, planes=8, length=64, width=64, lups=2, wups=2): 49 | super(UpSamplingBlock, self).__init__() 50 | 51 | self.length = length*lups 52 | self.width = width*wups 53 | self.extra_channel = lups*wups 54 | ds_index = PixelIndexCal_DownSampling(self.length, self.width, lups, wups) 55 | self.ups_index = nn.Parameter(PixelIndexCal_UpSampling(ds_index, self.length, self.width), requires_grad=False) 56 | self.filter = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=True) 57 | self.ln = nn.GroupNorm(num_channels=planes, num_groups=1, affine=False) 58 | self.leakyrelu = nn.LeakyReLU(0.2, True) 59 | 60 | def forward(self, input): 61 | _, channel, length, width = input.size() 62 | channel = int(channel/self.extra_channel) 63 | output = torch.index_select(input.view(-1, channel, self.extra_channel*length*width), 2, self.ups_index) 64 | output = output.view(-1, channel, self.length, self.width) 65 | output = self.leakyrelu(self.ln(self.filter(output))) 66 | 67 | return output 68 | 69 | 70 | class ResidualBlock(nn.Module): 71 | def __init__(self, planes): 72 | super(ResidualBlock, self).__init__() 73 | 74 | self.filter1 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=True) 75 | self.ln1 = nn.GroupNorm(num_channels=planes, num_groups=1, affine=False) 76 | self.leakyrelu1 = nn.LeakyReLU(0.2, True) 77 | self.filter2 = nn.Conv2d(planes, planes, 
kernel_size=3, stride=1, padding=1, bias=True) 78 | self.ln2 = nn.GroupNorm(num_channels=planes, num_groups=1, affine=False) 79 | self.leakyrelu2 = nn.LeakyReLU(0.2, True) 80 | 81 | def forward(self, input): 82 | output = self.leakyrelu1(self.ln1(self.filter1(input))) 83 | output = self.ln2(self.filter2(output)) 84 | output += input 85 | output = self.leakyrelu2(output) 86 | 87 | return output 88 | 89 | 90 | class SinoNet(nn.Module): 91 | def __init__(self, bp_channel, num_filters): 92 | super(SinoNet, self).__init__() 93 | 94 | model_list = [nn.Conv2d(1, num_filters, kernel_size=3, stride=1, padding=1, bias=True), nn.GroupNorm(num_channels=num_filters, num_groups=1, affine=False), nn.LeakyReLU(0.2, True)] 95 | model_list += [DownSamplingBlock(planes=num_filters*4, length=512, width=736, lds=2, wds=2)] 96 | model_list += [ResidualBlock(planes=num_filters*4), ResidualBlock(planes=num_filters*4), ResidualBlock(planes=num_filters*4)] 97 | model_list += [ResidualBlock(planes=num_filters*4), ResidualBlock(planes=num_filters*4), ResidualBlock(planes=num_filters*4)] 98 | model_list += [ResidualBlock(planes=num_filters*4), ResidualBlock(planes=num_filters*4), ResidualBlock(planes=num_filters*4)] 99 | model_list += [ResidualBlock(planes=num_filters*4), ResidualBlock(planes=num_filters*4), ResidualBlock(planes=num_filters*4)] 100 | 101 | model_list += [nn.Conv2d(num_filters*4, bp_channel, kernel_size=1, stride=1, padding=0, bias=True)] 102 | self.model = nn.Sequential(*model_list) 103 | 104 | def forward(self, input): 105 | 106 | return self.model(input) 107 | 108 | class SpatialNet(nn.Module): 109 | def __init__(self, bp_channel, num_filters): 110 | super(SpatialNet, self).__init__() 111 | 112 | model_list = [nn.Conv2d(bp_channel, num_filters*4, kernel_size=3, stride=1, padding=1, bias=True), nn.GroupNorm(num_channels=num_filters*4, num_groups=1, affine=False), nn.LeakyReLU(0.2, True)] 113 | model_list += [ResidualBlock(planes=num_filters*4), ResidualBlock(planes=num_filters*4), ResidualBlock(planes=num_filters*4)] 114 | model_list += [ResidualBlock(planes=num_filters*4), ResidualBlock(planes=num_filters*4), ResidualBlock(planes=num_filters*4)] 115 | model_list += [ResidualBlock(planes=num_filters*4), ResidualBlock(planes=num_filters*4), ResidualBlock(planes=num_filters*4)] 116 | model_list += [ResidualBlock(planes=num_filters*4), ResidualBlock(planes=num_filters*4), ResidualBlock(planes=num_filters*4)] 117 | model_list += [UpSamplingBlock(planes=num_filters, length=256, width=256, lups=2, wups=2)] 118 | 119 | model_list += [nn.Conv2d(num_filters, 1, kernel_size=1, stride=1, padding=0, bias=True)] 120 | self.model = nn.Sequential(*model_list) 121 | 122 | def forward(self, input): 123 | 124 | return self.model(input) 125 | 126 | class DSigNet(nn.Module): 127 | def __init__(self, options, bp_channel=4, num_filters=16, scale_factor=2) -> None: 128 | super().__init__() 129 | geo_real = {'nVoxelX': int(options[2]), 'sVoxelX': float(options[4]) * int(options[2]), 'dVoxelX': float(options[4]), 130 | 'nVoxelY': int(options[3]), 'sVoxelY': float(options[4]) * int(options[3]), 'dVoxelY': float(options[4]), 131 | 'nDetecU': int(options[1]), 'sDetecU': float(options[5]) * int(options[1]), 'dDetecU': float(options[5]), 132 | 'offOriginX': 0.0, 'offOriginY': 0.0, 133 | 'views': int(options[0]), 'slices': 1, 134 | 'DSD': float(options[8]) + float(options[9]), 'DSO': float(options[8]), 'DOD': float(options[9]), 135 | 'start_angle': 0.0, 'end_angle': float(options[7]) * int(options[0]), 136 | 'mode': 'fanflat', 
'extent': 1, # currently extent supports 1, 2, or 3. 137 | } 138 | geo_virtual = dict() 139 | geo_virtual.update({x: int(geo_real[x]/scale_factor) for x in ['views']}) 140 | geo_virtual.update({x: int(geo_real[x]/scale_factor) for x in ['nVoxelX', 'nVoxelY', 'nDetecU']}) 141 | geo_virtual.update({x: geo_real[x]/scale_factor for x in ['sVoxelX', 'sVoxelY', 'sDetecU', 'DSD', 'DSO', 'DOD', 'offOriginX', 'offOriginY']}) 142 | geo_virtual.update({x: geo_real[x] for x in ['dVoxelX', 'dVoxelY', 'dDetecU', 'slices', 'start_angle', 'end_angle', 'mode', 'extent']}) 143 | geo_virtual['indices'], geo_virtual['weights'] = PixelIndexCal_cuda(geo_virtual) 144 | self.SinoNet = SinoNet(bp_channel, num_filters) 145 | self.BackProjNet = BackProjNet(geo_virtual, bp_channel) 146 | self.SpatialNet = SpatialNet(bp_channel, num_filters) 147 | for module in self.modules(): 148 | if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear): 149 | nn.init.xavier_uniform_(module.weight) 150 | if module.bias is not None: 151 | module.bias.data.zero_() 152 | 153 | def forward(self, input): 154 | output = self.SinoNet(input) 155 | output = self.BackProjNet(output) 156 | output = self.SpatialNet(output) 157 | return output 158 | 159 | -------------------------------------------------------------------------------- /demos/FBPConvNet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | import argparse 5 | import scipy.io as scio 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | import pytorch_lightning as pl 10 | from torch.utils.data import Dataset 11 | from torch.utils.data import DataLoader 12 | from pytorch_lightning.callbacks import ModelCheckpoint 13 | from recon.models import FBPConvNet 14 | from vis_tools import Visualizer 15 | 16 | def setup_parser(arguments, title): 17 | parser = argparse.ArgumentParser(description=title, 18 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 19 | for key, val in arguments.items(): 20 | parser.add_argument('--%s' % key, 21 | type=eval(val["type"]), 22 | help=val["help"], 23 | default=val["default"], 24 | nargs=val["nargs"] if "nargs" in val else None) 25 | return parser 26 | 27 | def get_parameters(title=None): 28 | with open("config.json") as data_file: 29 | data = json.load(data_file) 30 | parser = setup_parser(data, title) 31 | parameters = parser.parse_args() 32 | return parameters 33 | 34 | class net(pl.LightningModule): 35 | def __init__(self, args): 36 | super().__init__() 37 | self.model = FBPConvNet() 38 | self.epochs = args.epochs 39 | self.lr = args.lr 40 | self.is_vis_show = args.is_vis_show 41 | self.show_win = args.show_win 42 | self.is_res_save = args.is_res_save 43 | self.res_dir = args.res_dir 44 | if self.is_vis_show: 45 | self.vis = Visualizer(env='FBPConvNet') 46 | 47 | def forward(self, x): 48 | out = self.model(x) 49 | return out 50 | 51 | def training_step(self, batch, batch_idx): 52 | x, y = batch 53 | out = self(x) 54 | loss = F.mse_loss(out, y) 55 | self.log("train_loss", loss, on_step=True, on_epoch=True) 56 | if self.is_vis_show: 57 | self.vis_show(loss.detach(), x.detach(), y.detach(), out.detach()) 58 | return loss 59 | 60 | def validation_step(self, batch, batch_idx): 61 | x, y = batch 62 | out = self(x) 63 | loss = F.mse_loss(out, y) 64 | self.log("val_loss", loss) 65 | 66 | def test_step(self, batch, batch_idx): 67 | x, y, res_name = batch 68 | out = self(x) 69 | if self.is_res_save: 70 | self.res_save(out, res_name) 71 | 72 | 
def configure_optimizers(self): 73 | optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr) 74 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.epochs) 75 | return [optimizer], [scheduler] 76 | 77 | def show_win_norm(self, y): 78 | x = y.clone() 79 | x[x<self.show_win[0]] = self.show_win[0] 80 | x[x>self.show_win[1]] = self.show_win[1] 81 | x = (x - self.show_win[0]) / (self.show_win[1] - self.show_win[0]) * 255 82 | return x 83 | 84 | def vis_show(self, loss, x, y, out, mode='Train'): 85 | self.vis.plot(mode + ' Loss', loss.item()) 86 | self.vis.img(mode + ' Ground Truth', self.show_win_norm(y).cpu()) 87 | self.vis.img(mode + ' Input', self.show_win_norm(x).cpu()) 88 | self.vis.img(mode + ' Result', self.show_win_norm(out).cpu()) 89 | 90 | def res_save(self, out, res_name): 91 | res = out.cpu().numpy() 92 | if not os.path.exists(self.res_dir): 93 | os.mkdir(self.res_dir) 94 | for i in range(res.shape[0]): 95 | scio.savemat(self.res_dir + '/' + res_name[i], {'data':res[i].squeeze()}) 96 | 97 | class data_loader(Dataset): 98 | def __init__(self, root, dose, mode): 99 | self.x_dir_name = 'input_patch_' + dose if mode == 'train' else 'input_' + dose 100 | self.x_path = os.path.join(root, mode, self.x_dir_name) 101 | self.mode = mode 102 | self.files_x = np.array(sorted(glob.glob(os.path.join(self.x_path, 'data') + '*.mat'))) 103 | 104 | def __getitem__(self, index): 105 | file_x = self.files_x[index] 106 | file_y = file_x.replace(self.x_dir_name, 'label_patch') if self.mode == 'train' else file_x.replace(self.x_dir_name, 'label') 107 | input_data = scio.loadmat(file_x)['data'] 108 | label_data = scio.loadmat(file_y)['data'] 109 | input_data = torch.FloatTensor(input_data).unsqueeze_(0) 110 | label_data = torch.FloatTensor(label_data).unsqueeze_(0) 111 | if self.mode == 'train' or self.mode == 'vali': 112 | return input_data, label_data 113 | elif self.mode == 'test': 114 | res_name = file_x[-13:] 115 | return input_data, label_data, res_name 116 | 117 | def __len__(self): 118 | return len(self.files_x) 119 | 120 | if __name__ == "__main__": 121 | args = get_parameters() 122 | network = net(args) 123 | checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_last=True, save_top_k=3, mode="min") 124 | trainer = pl.Trainer(gpus=args.gpu_ids if args.is_specified_gpus else args.gpus, log_every_n_steps=1, max_epochs=args.epochs, callbacks=[checkpoint_callback]) 125 | train_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'train'), batch_size=args.batch_size*64, shuffle=True, num_workers=args.cpus) 126 | vali_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'vali'), batch_size=args.batch_size, shuffle=False, num_workers=args.cpus) 127 | test_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'test'), batch_size=args.batch_size, shuffle=False, num_workers=args.cpus) 128 | trainer.fit(network, train_loader, vali_loader) 129 | # trainer.fit(network, train_loader, vali_loader, ckpt_path='lightning_logs/version_1/checkpoints/last.ckpt') 130 | trainer.test(network, test_loader, ckpt_path='best') 131 | -------------------------------------------------------------------------------- /demos/FistaNet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | import argparse 5 | import scipy.io as scio 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | import pytorch_lightning as pl 10 | from torch.utils.data import Dataset 11 | from torch.utils.data import DataLoader 12 |
from pytorch_lightning.callbacks import ModelCheckpoint 13 | from recon.models import FistaNet 14 | from vis_tools import Visualizer 15 | 16 | def setup_parser(arguments, title): 17 | parser = argparse.ArgumentParser(description=title, 18 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 19 | for key, val in arguments.items(): 20 | parser.add_argument('--%s' % key, 21 | type=eval(val["type"]), 22 | help=val["help"], 23 | default=val["default"], 24 | nargs=val["nargs"] if "nargs" in val else None) 25 | return parser 26 | 27 | def get_parameters(title=None): 28 | with open("config.json") as data_file: 29 | data = json.load(data_file) 30 | parser = setup_parser(data, title) 31 | parameters = parser.parse_args() 32 | return parameters 33 | 34 | class net(pl.LightningModule): 35 | def __init__(self, args): 36 | super().__init__() 37 | options = torch.tensor([args.views, args.dets, args.width, args.height, 38 | args.dImg, args.dDet, args.Ang0, args.dAng, 39 | args.s2r, args.d2r, args.binshift, args.scan_type]) 40 | self.model = FistaNet(7, options) 41 | self.epochs = args.epochs 42 | self.lr = args.lr 43 | self.is_vis_show = args.is_vis_show 44 | self.show_win = args.show_win 45 | self.is_res_save = args.is_res_save 46 | self.res_dir = args.res_dir 47 | if self.is_vis_show: 48 | self.vis = Visualizer(env='FistaNet') 49 | 50 | def forward(self, x, p): 51 | out, loss_layers_sym, loss_st = self.model(x, p) 52 | return out, loss_layers_sym, loss_st 53 | 54 | def training_step(self, batch, batch_idx): 55 | x, p, y = batch 56 | out, loss_layers_sym, loss_st = self(x, p) 57 | loss_discrepancy = F.mse_loss(out, y) 58 | loss_constraint = torch.mean(loss_layers_sym ** 2) 59 | sparsity_constraint = torch.mean(loss_st.abs()) 60 | loss = loss_discrepancy + 0.01 * loss_constraint + 0.001 * sparsity_constraint 61 | self.log("train_loss", loss, on_step=True, on_epoch=True) 62 | if self.is_vis_show: 63 | self.vis_show(loss.detach(), x.detach(), y.detach(), out.detach()) 64 | return loss 65 | 66 | def validation_step(self, batch, batch_idx): 67 | x, p, y = batch 68 | out, _, _ = self(x, p) 69 | loss = F.mse_loss(out, y) 70 | self.log("val_loss", loss) 71 | 72 | def test_step(self, batch, batch_idx): 73 | x, p, y, res_name = batch 74 | out, _, _ = self(x, p) 75 | if self.is_res_save: 76 | self.res_save(out, res_name) 77 | 78 | def configure_optimizers(self): 79 | optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr) 80 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.epochs) 81 | return [optimizer], [scheduler] 82 | 83 | def show_win_norm(self, y): 84 | x = y.clone() 85 | x[x<self.show_win[0]] = self.show_win[0] 86 | x[x>self.show_win[1]] = self.show_win[1] 87 | x = (x - self.show_win[0]) / (self.show_win[1] - self.show_win[0]) * 255 88 | return x 89 | 90 | def vis_show(self, loss, x, y, out, mode='Train'): 91 | self.vis.plot(mode + ' Loss', loss.item()) 92 | self.vis.img(mode + ' Ground Truth', self.show_win_norm(y).cpu()) 93 | self.vis.img(mode + ' Input', self.show_win_norm(x).cpu()) 94 | self.vis.img(mode + ' Result', self.show_win_norm(out).cpu()) 95 | 96 | def res_save(self, out, res_name): 97 | res = out.cpu().numpy() 98 | if not os.path.exists(self.res_dir): 99 | os.mkdir(self.res_dir) 100 | for i in range(res.shape[0]): 101 | scio.savemat(self.res_dir + '/' + res_name[i], {'data':res[i].squeeze()}) 102 | 103 | class data_loader(Dataset): 104 | def __init__(self, root, dose, mode): 105 | self.x_dir_name = 'input_' + dose 106 | self.x_path = os.path.join(root, mode, self.x_dir_name) 107 | self.mode = mode 108
| self.files_x = np.array(sorted(glob.glob(os.path.join(self.x_path, 'data') + '*.mat'))) 109 | 110 | def __getitem__(self, index): 111 | file_x = self.files_x[index] 112 | file_p = file_x.replace('input', 'projection') 113 | file_y = file_x.replace(self.x_dir_name, 'label') 114 | input_data = scio.loadmat(file_x)['data'] 115 | prj_data = scio.loadmat(file_p)['data'] 116 | label_data = scio.loadmat(file_y)['data'] 117 | input_data = torch.FloatTensor(input_data).unsqueeze_(0) 118 | prj_data = torch.FloatTensor(prj_data).unsqueeze_(0) 119 | label_data = torch.FloatTensor(label_data).unsqueeze_(0) 120 | if self.mode == 'train' or self.mode == 'vali': 121 | return input_data, prj_data, label_data 122 | elif self.mode == 'test': 123 | res_name = file_x[-13:] 124 | return input_data, prj_data, label_data, res_name 125 | 126 | def __len__(self): 127 | return len(self.files_x) 128 | 129 | if __name__ == "__main__": 130 | args = get_parameters() 131 | network = net(args) 132 | checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_last=True, save_top_k=3, mode="min") 133 | trainer = pl.Trainer(gpus=args.gpu_ids if args.is_specified_gpus else args.gpus, log_every_n_steps=1, max_epochs=args.epochs, callbacks=[checkpoint_callback]) 134 | train_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'train'), batch_size=args.batch_size, shuffle=True, num_workers=args.cpus) 135 | vali_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'vali'), batch_size=args.batch_size, shuffle=False, num_workers=args.cpus) 136 | test_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'test'), batch_size=args.batch_size, shuffle=False, num_workers=args.cpus) 137 | trainer.fit(network, train_loader, vali_loader) 138 | # trainer.fit(network, train_loader, vali_loader, ckpt_path='lightning_logs/version_1/checkpoints/last.ckpt') 139 | trainer.test(network, test_loader, ckpt_path='best') 140 | -------------------------------------------------------------------------------- /demos/HDNet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | import argparse 5 | import scipy.io as scio 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | import pytorch_lightning as pl 10 | from torch.utils.data import Dataset 11 | from torch.utils.data import DataLoader 12 | from pytorch_lightning.callbacks import ModelCheckpoint 13 | from recon.models import HDNet_2d 14 | from vis_tools import Visualizer 15 | 16 | def setup_parser(arguments, title): 17 | parser = argparse.ArgumentParser(description=title, 18 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 19 | for key, val in arguments.items(): 20 | parser.add_argument('--%s' % key, 21 | type=eval(val["type"]), 22 | help=val["help"], 23 | default=val["default"], 24 | nargs=val["nargs"] if "nargs" in val else None) 25 | return parser 26 | 27 | def get_parameters(title=None): 28 | with open("config.json") as data_file: 29 | data = json.load(data_file) 30 | parser = setup_parser(data, title) 31 | parameters = parser.parse_args() 32 | return parameters 33 | 34 | class net(pl.LightningModule): 35 | def __init__(self, args): 36 | super().__init__() 37 | options = torch.tensor([args.views, args.dets, args.width, args.height, 38 | args.dImg, args.dDet, args.Ang0, args.dAng, 39 | args.s2r, args.d2r, args.binshift, args.scan_type]) 40 | self.model = HDNet_2d(options) 41 | self.epochs = args.epochs 42 | self.lr = args.lr 43 | self.is_vis_show = 
args.is_vis_show 44 | self.show_win = args.show_win 45 | self.is_res_save = args.is_res_save 46 | self.res_dir = args.res_dir 47 | if self.is_vis_show: 48 | self.vis = Visualizer(env='HDNet_2d') 49 | 50 | def forward(self, p): 51 | out = self.model(p) 52 | return out 53 | 54 | def training_step(self, batch, batch_idx): 55 | x, p, y = batch 56 | out = self(p) 57 | loss = F.mse_loss(out, y) 58 | self.log("train_loss", loss, on_step=True, on_epoch=True) 59 | if self.is_vis_show: 60 | self.vis_show(loss.detach(), x.detach(), y.detach(), out.detach()) 61 | return loss 62 | 63 | def validation_step(self, batch, batch_idx): 64 | x, p, y = batch 65 | out = self(p) 66 | loss = F.mse_loss(out, y) 67 | self.log("val_loss", loss) 68 | 69 | def test_step(self, batch, batch_idx): 70 | x, p, y, res_name = batch 71 | out = self(p) 72 | if self.is_res_save: 73 | self.res_save(out, res_name) 74 | 75 | def configure_optimizers(self): 76 | optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr) 77 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000) 78 | return [optimizer], [scheduler] 79 | 80 | def show_win_norm(self, y): 81 | x = y.clone() 82 | x[x<self.show_win[0]] = self.show_win[0] 83 | x[x>self.show_win[1]] = self.show_win[1] 84 | x = (x - self.show_win[0]) / (self.show_win[1] - self.show_win[0]) * 255 85 | return x 86 | 87 | def vis_show(self, loss, x, y, out, mode='Train'): 88 | self.vis.plot(mode + ' Loss', loss.item()) 89 | self.vis.img(mode + ' Ground Truth', self.show_win_norm(y).cpu()) 90 | self.vis.img(mode + ' Input', self.show_win_norm(x).cpu()) 91 | self.vis.img(mode + ' Result', self.show_win_norm(out).cpu()) 92 | 93 | def res_save(self, out, res_name): 94 | res = out.cpu().numpy() 95 | if not os.path.exists(self.res_dir): 96 | os.mkdir(self.res_dir) 97 | for i in range(res.shape[0]): 98 | scio.savemat(self.res_dir + '/' + res_name[i], {'data':res[i].squeeze()}) 99 | 100 | class data_loader(Dataset): 101 | def __init__(self, root, dose, mode): 102 | self.x_dir_name = 'input_' + dose 103 | self.x_path = os.path.join(root, mode, self.x_dir_name) 104 | self.mode = mode 105 | self.files_x = np.array(sorted(glob.glob(os.path.join(self.x_path, 'data') + '*.mat'))) 106 | 107 | def __getitem__(self, index): 108 | file_x = self.files_x[index] 109 | file_p = file_x.replace('input', 'projection') 110 | file_y = file_x.replace(self.x_dir_name, 'label') 111 | input_data = scio.loadmat(file_x)['data'] 112 | prj_data = scio.loadmat(file_p)['data'] 113 | label_data = scio.loadmat(file_y)['data'] 114 | input_data = torch.FloatTensor(input_data).unsqueeze_(0) 115 | prj_data = torch.FloatTensor(prj_data).unsqueeze_(0) 116 | label_data = torch.FloatTensor(label_data).unsqueeze_(0) 117 | if self.mode == 'train' or self.mode == 'vali': 118 | return input_data, prj_data, label_data 119 | elif self.mode == 'test': 120 | res_name = file_x[-13:] 121 | return input_data, prj_data, label_data, res_name 122 | 123 | def __len__(self): 124 | return len(self.files_x) 125 | 126 | if __name__ == "__main__": 127 | args = get_parameters() 128 | network = net(args) 129 | checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_last=True, save_top_k=3, mode="min") 130 | trainer = pl.Trainer(gpus=args.gpu_ids if args.is_specified_gpus else args.gpus, log_every_n_steps=1, max_epochs=args.epochs, callbacks=[checkpoint_callback]) 131 | train_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'train'), batch_size=args.batch_size, shuffle=True, num_workers=args.cpus) 132 | vali_loader = DataLoader(data_loader(args.data_root_dir,
args.dose, 'vali'), batch_size=args.batch_size, shuffle=False, num_workers=args.cpus) 133 | test_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'test'), batch_size=args.batch_size, shuffle=False, num_workers=args.cpus) 134 | trainer.fit(network, train_loader, vali_loader) 135 | # trainer.fit(network, train_loader, vali_loader, ckpt_path='lightning_logs/version_0/checkpoints/last.ckpt') 136 | trainer.test(network, test_loader, ckpt_path='best') 137 | -------------------------------------------------------------------------------- /demos/KSAE.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | import argparse 5 | import scipy.io as scio 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | import pytorch_lightning as pl 10 | from torch.utils.data import Dataset 11 | from torch.utils.data import DataLoader 12 | from pytorch_lightning.callbacks import ModelCheckpoint 13 | from recon.models import KSAE 14 | from vis_tools import Visualizer 15 | import ctlib 16 | 17 | def setup_parser(arguments, title): 18 | parser = argparse.ArgumentParser(description=title, 19 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 20 | for key, val in arguments.items(): 21 | parser.add_argument('--%s' % key, 22 | type=eval(val["type"]), 23 | help=val["help"], 24 | default=val["default"], 25 | nargs=val["nargs"] if "nargs" in val else None) 26 | return parser 27 | 28 | def get_parameters(title=None): 29 | with open("config.json") as data_file: 30 | data = json.load(data_file) 31 | parser = setup_parser(data, title) 32 | parameters = parser.parse_args() 33 | return parameters 34 | 35 | class net(pl.LightningModule): 36 | def __init__(self, args): 37 | super().__init__() 38 | options = torch.tensor([args.views, args.dets, args.width, args.height, 39 | args.dImg, args.dDet, args.Ang0, args.dAng, 40 | args.s2r, args.d2r, args.binshift, args.scan_type]) 41 | self.model = KSAE() 42 | self.epochs = args.epochs 43 | self.alpha = 0.05 44 | self.gamma = 0.5 45 | self.beta = 0.005 46 | self.Nsd = 5 47 | self.iteration = 50 48 | self.lr = args.lr 49 | self.is_vis_show = args.is_vis_show 50 | self.show_win = args.show_win 51 | self.is_res_save = args.is_res_save 52 | self.res_dir = args.res_dir 53 | self.options = options 54 | if self.is_vis_show: 55 | self.vis = Visualizer(env='KSAE') 56 | 57 | def forward(self, x): 58 | out = self.model(x) 59 | return out 60 | 61 | def patch_sample(self, x, patch_size=16, stride=12, h_ind=None, w_ind=None): 62 | B, C, H, W = x.shape 63 | Ph = H-patch_size+1 64 | Pw = W-patch_size+1 65 | if h_ind is None: 66 | h_ind = list(range(0, Ph, stride)) 67 | h_ind.append(Ph-1) 68 | h_ind = np.asarray(h_ind) 69 | h_ind[1:-1] += np.random.randint((stride - patch_size + 1) / 2, (patch_size - stride) / 2, [len(h_ind)-2]) 70 | h_ind[h_ind > Ph-1] = Ph - 1 71 | if w_ind is None: 72 | w_ind = list(range(0, Pw, stride)) 73 | w_ind.append(Pw-1) 74 | w_ind = np.asarray(w_ind) 75 | w_ind[1:-1] += np.random.randint((stride - patch_size + 1) / 2, (patch_size - stride) / 2, [len(w_ind)-2]) 76 | w_ind[w_ind > Pw-1] = Pw - 1 77 | 78 | y = torch.empty(B, C, len(h_ind), len(w_ind), patch_size, patch_size, device=x.device) 79 | for i in range(len(h_ind)): 80 | for j in range(len(w_ind)): 81 | y[:,:,i,j,:,:] = x[:,:,h_ind[i]:h_ind[i]+patch_size,w_ind[j]:w_ind[j]+patch_size] 82 | 83 | return y, h_ind, w_ind 84 | 85 | def patch_put(self, y, H, W, h_ind, w_ind, patch_size=16): 86 | 87 | x = 
torch.zeros(y.size(0), y.size(1), H, W, device=y.device) 88 | for i in range(len(h_ind)): 89 | for j in range(len(w_ind)): 90 | x[:,:,h_ind[i]:h_ind[i]+patch_size,w_ind[j]:w_ind[j]+patch_size] += y[:,:,i,j,:,:] 91 | 92 | return x 93 | 94 | def training_step(self, batch, batch_idx): 95 | x, p, y = batch 96 | x, h_ind, w_ind = self.patch_sample(x) 97 | y, _, _, = self.patch_sample(y, h_ind=h_ind, w_ind=w_ind) 98 | out = self(x) 99 | loss = F.mse_loss(out, y) 100 | self.log("train_loss", loss, on_step=True, on_epoch=True) 101 | if self.is_vis_show: 102 | self.vis_show(loss.detach(), x.detach(), y.detach(), out.detach()) 103 | return loss 104 | 105 | def validation_step(self, batch, batch_idx): 106 | x, p, y = batch 107 | options = self.options.to(x.device) 108 | W = torch.exp(-p) 109 | M_tmp = ctlib.projection(torch.ones_like(x).contiguous(), options) * W 110 | M = ctlib.projection_t(M_tmp.contiguous(), options) 111 | x_t = x.clone() 112 | z_t = x.clone() 113 | x_old = x.clone() 114 | for i in range(self.iteration): 115 | x_patch, h_ind, w_ind = self.patch_sample(x_t) 116 | overlap_mask = self.patch_put(torch.ones_like(x_patch), x.size(2), x.size(3), h_ind, w_ind) 117 | with torch.no_grad(): 118 | y_patch = self(x_patch) 119 | y_patch.requires_grad = True 120 | for k in range(self.Nsd): 121 | y_patch_res = self(y_patch) 122 | patch_loss = F.mse_loss(y_patch_res, x_patch, reduction='sum') 123 | grad = torch.autograd.grad(patch_loss, y_patch)[0] 124 | grad_norm = (grad ** 2).sum((-2, -1), keepdim=True).sqrt() 125 | y_patch = y_patch - self.alpha * grad / grad_norm 126 | y_error = (ctlib.projection(x_t.contiguous(), options) - p) * W 127 | grad_xt_1 = ctlib.projection_t(y_error.contiguous(), options) 128 | with torch.no_grad(): 129 | y_patch_res = self(y_patch) 130 | grad_xt_2 = self.beta * self.patch_put(x_patch - y_patch_res, x.size(2), x.size(3), h_ind, w_ind) 131 | grad_xt = (grad_xt_1 + grad_xt_2) / (M + self.beta * overlap_mask) 132 | x_t = z_t - grad_xt 133 | z_t = x_t + self.gamma * (x_t - x_old) 134 | x_old = x_t 135 | loss = F.mse_loss(x_t, y) 136 | self.log("val_loss", loss) 137 | 138 | def test_step(self, batch, batch_idx): 139 | x, p, y, res_name = batch 140 | options = self.options.to(x.device) 141 | W = torch.exp(-p) 142 | M_tmp = ctlib.projection(torch.ones_like(x).contiguous(), options) * W 143 | M = ctlib.projection_t(M_tmp.contiguous(), options) 144 | x_t = x.clone() 145 | z_t = x.clone() 146 | x_old = x.clone() 147 | for i in range(self.iteration): 148 | x_patch, h_ind, w_ind = self.patch_sample(x_t) 149 | overlap_mask = self.patch_put(torch.ones_like(x_patch), x.size(2), x.size(3), h_ind, w_ind) 150 | with torch.no_grad(): 151 | y_patch = self(x_patch) 152 | y_patch.requires_grad = True 153 | for k in range(self.Nsd): 154 | y_patch_res = self(y_patch) 155 | patch_loss = F.mse_loss(y_patch_res, x_patch, reduction='sum') 156 | grad = torch.autograd.grad(patch_loss, y_patch)[0] 157 | grad_norm = (grad ** 2).sum((-2, -1), keepdim=True).sqrt() 158 | y_patch = y_patch - self.alpha * grad / grad_norm 159 | y_error = (ctlib.projection(x_t.contiguous(), options) - p) * W 160 | grad_xt_1 = ctlib.projection_t(y_error.contiguous(), options) 161 | with torch.no_grad(): 162 | y_patch_res = self(y_patch) 163 | grad_xt_2 = self.beta * self.patch_put(x_patch - y_patch_res, x.size(2), x.size(3), h_ind, w_ind) 164 | grad_xt = (grad_xt_1 + grad_xt_2) / (M + self.beta * overlap_mask) 165 | x_t = z_t - grad_xt 166 | z_t = x_t + self.gamma * (x_t - x_old) 167 | x_old = x_t 168 | out = x_t 169 | if 
self.is_res_save: 170 | self.res_save(out, res_name) 171 | 172 | def on_validation_model_eval(self, *args, **kwargs): 173 | super().on_validation_model_eval(*args, **kwargs) 174 | torch.set_grad_enabled(True) 175 | 176 | def on_test_model_eval(self, *args, **kwargs): 177 | super().on_test_model_eval(*args, **kwargs) 178 | torch.set_grad_enabled(True) 179 | 180 | def configure_optimizers(self): 181 | optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr) 182 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.epochs) 183 | return [optimizer], [scheduler] 184 | 185 | def show_win_norm(self, y): 186 | x = y.clone() 187 | x[x<self.show_win[0]] = self.show_win[0] 188 | x[x>self.show_win[1]] = self.show_win[1] 189 | x = (x - self.show_win[0]) / (self.show_win[1] - self.show_win[0]) * 255 190 | return x 191 | 192 | def vis_show(self, loss, x, y, out, mode='Train'): 193 | self.vis.plot(mode + ' Loss', loss.item()) 194 | self.vis.img(mode + ' Ground Truth', self.show_win_norm(y).cpu()) 195 | self.vis.img(mode + ' Input', self.show_win_norm(x).cpu()) 196 | self.vis.img(mode + ' Result', self.show_win_norm(out).cpu()) 197 | 198 | def res_save(self, out, res_name): 199 | res = out.cpu().numpy() 200 | if not os.path.exists(self.res_dir): 201 | os.mkdir(self.res_dir) 202 | for i in range(res.shape[0]): 203 | scio.savemat(self.res_dir + '/' + res_name[i], {'data':res[i].squeeze()}) 204 | 205 | class data_loader(Dataset): 206 | def __init__(self, root, dose, mode): 207 | self.x_dir_name = 'input_' + dose 208 | self.x_path = os.path.join(root, mode, self.x_dir_name) 209 | self.mode = mode 210 | self.files_x = np.array(sorted(glob.glob(os.path.join(self.x_path, 'data') + '*.mat'))) 211 | 212 | def __getitem__(self, index): 213 | file_x = self.files_x[index] 214 | file_p = file_x.replace('input', 'projection') 215 | file_y = file_x.replace(self.x_dir_name, 'label') 216 | input_data = scio.loadmat(file_x)['data'] 217 | prj_data = scio.loadmat(file_p)['data'] 218 | label_data = scio.loadmat(file_y)['data'] 219 | input_data = torch.FloatTensor(input_data).unsqueeze_(0) 220 | prj_data = torch.FloatTensor(prj_data).unsqueeze_(0) 221 | label_data = torch.FloatTensor(label_data).unsqueeze_(0) 222 | if self.mode == 'train' or self.mode == 'vali': 223 | return input_data, prj_data, label_data 224 | elif self.mode == 'test': 225 | res_name = file_x[-13:] 226 | return input_data, prj_data, label_data, res_name 227 | 228 | def __len__(self): 229 | return len(self.files_x) 230 | 231 | if __name__ == "__main__": 232 | args = get_parameters() 233 | network = net(args) 234 | checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_last=True, save_top_k=3, mode="min") 235 | trainer = pl.Trainer(gpus=args.gpu_ids if args.is_specified_gpus else args.gpus, log_every_n_steps=1, max_epochs=args.epochs, callbacks=[checkpoint_callback]) 236 | train_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'train'), batch_size=args.batch_size, shuffle=True, num_workers=args.cpus) 237 | vali_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'vali'), batch_size=args.batch_size*8, shuffle=False, num_workers=args.cpus) 238 | test_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'test'), batch_size=args.batch_size*8, shuffle=False, num_workers=args.cpus) 239 | trainer.fit(network, train_loader, vali_loader) 240 | # trainer.fit(network, train_loader, vali_loader, ckpt_path='lightning_logs/version_1/checkpoints/last.ckpt') 241 | trainer.test(network, test_loader, ckpt_path='best') 242 |
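A note on the overlapping-patch bookkeeping in KSAE.py above: `patch_put` sums overlapping patches back onto the image grid, so each pixel is counted once per covering patch, and the demo compensates by dividing by an overlap mask built from all-ones patches pushed through the same routine. Below is a minimal, self-contained sketch of that idea; the helper name `put` and the index lists `h_ind`/`w_ind` are illustrative stand-ins, not part of the repo.

```python
import torch

# Hypothetical stand-in for patch_put: sum patches back at their top-left indices.
def put(patches, H, W, h_ind, w_ind, ps):
    x = torch.zeros(patches.shape[:2] + (H, W))
    for i, hi in enumerate(h_ind):
        for j, wj in enumerate(w_ind):
            x[:, :, hi:hi + ps, wj:wj + ps] += patches[:, :, i, j]
    return x

B, C, H, W, ps = 1, 1, 8, 8, 4
h_ind = w_ind = [0, 2, 4]  # stride 2 with 4x4 patches -> neighbors overlap
img = torch.arange(float(H * W)).reshape(B, C, H, W)

# Gather patches, mirroring the loop in patch_sample.
patches = torch.empty(B, C, len(h_ind), len(w_ind), ps, ps)
for i, hi in enumerate(h_ind):
    for j, wj in enumerate(w_ind):
        patches[:, :, i, j] = img[:, :, hi:hi + ps, wj:wj + ps]

# Dividing the summed patches by the per-pixel overlap count undoes the
# multiple counting; this is the role overlap_mask plays in the demo's
# (M + self.beta * overlap_mask) denominator.
overlap = put(torch.ones_like(patches), H, W, h_ind, w_ind, ps)
assert torch.allclose(put(patches, H, W, h_ind, w_ind, ps) / overlap, img)
```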
-------------------------------------------------------------------------------- /demos/LEARN.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | import argparse 5 | import scipy.io as scio 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | import pytorch_lightning as pl 10 | from torch.utils.data import Dataset 11 | from torch.utils.data import DataLoader 12 | from pytorch_lightning.callbacks import ModelCheckpoint 13 | from recon.models import LEARN 14 | from vis_tools import Visualizer 15 | 16 | def setup_parser(arguments, title): 17 | parser = argparse.ArgumentParser(description=title, 18 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 19 | for key, val in arguments.items(): 20 | parser.add_argument('--%s' % key, 21 | type=eval(val["type"]), 22 | help=val["help"], 23 | default=val["default"], 24 | nargs=val["nargs"] if "nargs" in val else None) 25 | return parser 26 | 27 | def get_parameters(title=None): 28 | with open("config.json") as data_file: 29 | data = json.load(data_file) 30 | parser = setup_parser(data, title) 31 | parameters = parser.parse_args() 32 | return parameters 33 | 34 | class net(pl.LightningModule): 35 | def __init__(self, args): 36 | super().__init__() 37 | options = torch.tensor([args.views, args.dets, args.width, args.height, 38 | args.dImg, args.dDet, args.Ang0, args.dAng, 39 | args.s2r, args.d2r, args.binshift, args.scan_type]) 40 | self.model = LEARN(options) 41 | self.epochs = args.epochs 42 | self.lr = args.lr 43 | self.is_vis_show = args.is_vis_show 44 | self.show_win = args.show_win 45 | self.is_res_save = args.is_res_save 46 | self.res_dir = args.res_dir 47 | if self.is_vis_show: 48 | self.vis = Visualizer(env='LEARN') 49 | 50 | def forward(self, x, p): 51 | out = self.model(x, p) 52 | return out 53 | 54 | def training_step(self, batch, batch_idx): 55 | x, p, y = batch 56 | out = self(x, p) 57 | loss = F.mse_loss(out, y) 58 | self.log("train_loss", loss, on_step=True, on_epoch=True) 59 | if self.is_vis_show: 60 | self.vis_show(loss.detach(), x.detach(), y.detach(), out.detach()) 61 | return loss 62 | 63 | def validation_step(self, batch, batch_idx): 64 | x, p, y = batch 65 | out = self(x, p) 66 | loss = F.mse_loss(out, y) 67 | self.log("val_loss", loss) 68 | 69 | def test_step(self, batch, batch_idx): 70 | x, p, y, res_name = batch 71 | out = self(x, p) 72 | if self.is_res_save: 73 | self.res_save(out, res_name) 74 | 75 | def configure_optimizers(self): 76 | optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr) 77 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200) 78 | return [optimizer], [scheduler] 79 | 80 | def show_win_norm(self, y): 81 | x = y.clone() 82 | x[x<self.show_win[0]] = self.show_win[0] 83 | x[x>self.show_win[1]] = self.show_win[1] 84 | x = (x - self.show_win[0]) / (self.show_win[1] - self.show_win[0]) * 255 85 | return x 86 | 87 | def vis_show(self, loss, x, y, out, mode='Train'): 88 | self.vis.plot(mode + ' Loss', loss.item()) 89 | self.vis.img(mode + ' Ground Truth', self.show_win_norm(y).cpu()) 90 | self.vis.img(mode + ' Input', self.show_win_norm(x).cpu()) 91 | self.vis.img(mode + ' Result', self.show_win_norm(out).cpu()) 92 | 93 | def res_save(self, out, res_name): 94 | res = out.cpu().numpy() 95 | if not os.path.exists(self.res_dir): 96 | os.mkdir(self.res_dir) 97 | for i in range(res.shape[0]): 98 | scio.savemat(self.res_dir + '/' + res_name[i], {'data':res[i].squeeze()}) 99 | 100 | class data_loader(Dataset): 101 | def
__init__(self, root, dose, mode): 102 | self.x_dir_name = 'input_' + dose 103 | self.x_path = os.path.join(root, mode, self.x_dir_name) 104 | self.mode = mode 105 | self.files_x = np.array(sorted(glob.glob(os.path.join(self.x_path, 'data') + '*.mat'))) 106 | 107 | def __getitem__(self, index): 108 | file_x = self.files_x[index] 109 | file_p = file_x.replace('input', 'projection') 110 | file_y = file_x.replace(self.x_dir_name, 'label') 111 | input_data = scio.loadmat(file_x)['data'] 112 | prj_data = scio.loadmat(file_p)['data'] 113 | label_data = scio.loadmat(file_y)['data'] 114 | input_data = torch.FloatTensor(input_data).unsqueeze_(0) 115 | prj_data = torch.FloatTensor(prj_data).unsqueeze_(0) 116 | label_data = torch.FloatTensor(label_data).unsqueeze_(0) 117 | if self.mode == 'train' or self.mode == 'vali': 118 | return input_data, prj_data, label_data 119 | elif self.mode == 'test': 120 | res_name = file_x[-13:] 121 | return input_data, prj_data, label_data, res_name 122 | 123 | def __len__(self): 124 | return len(self.files_x) 125 | 126 | if __name__ == "__main__": 127 | args = get_parameters() 128 | network = net(args) 129 | checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_last=True, save_top_k=3, mode="min") 130 | trainer = pl.Trainer(gpus=args.gpu_ids if args.is_specified_gpus else args.gpus, log_every_n_steps=1, max_epochs=args.epochs, callbacks=[checkpoint_callback]) 131 | train_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'train'), batch_size=args.batch_size, shuffle=True, num_workers=args.cpus) 132 | vali_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'vali'), batch_size=args.batch_size, shuffle=False, num_workers=args.cpus) 133 | test_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'test'), batch_size=args.batch_size, shuffle=False, num_workers=args.cpus) 134 | trainer.fit(network, train_loader, vali_loader) 135 | # trainer.fit(network, train_loader, vali_loader, ckpt_path='lightning_logs/version_0/checkpoints/last.ckpt') 136 | trainer.test(network, test_loader, ckpt_path='best') 137 | -------------------------------------------------------------------------------- /demos/LPD.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | import argparse 5 | import scipy.io as scio 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | import pytorch_lightning as pl 10 | from torch.utils.data import Dataset 11 | from torch.utils.data import DataLoader 12 | from pytorch_lightning.callbacks import ModelCheckpoint 13 | from recon.models import Learned_primal_dual 14 | from vis_tools import Visualizer 15 | 16 | def setup_parser(arguments, title): 17 | parser = argparse.ArgumentParser(description=title, 18 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 19 | for key, val in arguments.items(): 20 | parser.add_argument('--%s' % key, 21 | type=eval(val["type"]), 22 | help=val["help"], 23 | default=val["default"], 24 | nargs=val["nargs"] if "nargs" in val else None) 25 | return parser 26 | 27 | def get_parameters(title=None): 28 | with open("config.json") as data_file: 29 | data = json.load(data_file) 30 | parser = setup_parser(data, title) 31 | parameters = parser.parse_args() 32 | return parameters 33 | 34 | class net(pl.LightningModule): 35 | def __init__(self, args): 36 | super().__init__() 37 | options = torch.tensor([args.views, args.dets, args.width, args.height, 38 | args.dImg, args.dDet, args.Ang0, args.dAng, 39 
| args.s2r, args.d2r, args.binshift, args.scan_type]) 40 | self.model = Learned_primal_dual(options) 41 | self.epochs = args.epochs 42 | self.lr = args.lr 43 | self.is_vis_show = args.is_vis_show 44 | self.show_win = args.show_win 45 | self.is_res_save = args.is_res_save 46 | self.res_dir = args.res_dir 47 | if self.is_vis_show: 48 | self.vis = Visualizer(env='LPD') 49 | 50 | def forward(self, x, p): 51 | out = self.model(x, p) 52 | return out 53 | 54 | def training_step(self, batch, batch_idx): 55 | x, p, y = batch 56 | out = self(x, p) 57 | loss = F.mse_loss(out, y) 58 | self.log("train_loss", loss, on_step=True, on_epoch=True) 59 | if self.is_vis_show: 60 | self.vis_show(loss.detach(), x.detach(), y.detach(), out.detach()) 61 | return loss 62 | 63 | def validation_step(self, batch, batch_idx): 64 | x, p, y = batch 65 | out = self(x, p) 66 | loss = F.mse_loss(out, y) 67 | self.log("val_loss", loss) 68 | 69 | def test_step(self, batch, batch_idx): 70 | x, p, y, res_name = batch 71 | out = self(x, p) 72 | if self.is_res_save: 73 | self.res_save(out, res_name) 74 | 75 | def configure_optimizers(self): 76 | optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr) 77 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200) 78 | return [optimizer], [scheduler] 79 | 80 | def show_win_norm(self, y): 81 | x = y.clone() 82 | x[x<self.show_win[0]] = self.show_win[0] 83 | x[x>self.show_win[1]] = self.show_win[1] 84 | x = (x - self.show_win[0]) / (self.show_win[1] - self.show_win[0]) * 255 85 | return x 86 | 87 | def vis_show(self, loss, x, y, out, mode='Train'): 88 | self.vis.plot(mode + ' Loss', loss.item()) 89 | self.vis.img(mode + ' Ground Truth', self.show_win_norm(y).cpu()) 90 | self.vis.img(mode + ' Input', self.show_win_norm(x).cpu()) 91 | self.vis.img(mode + ' Result', self.show_win_norm(out).cpu()) 92 | 93 | def res_save(self, out, res_name): 94 | res = out.cpu().numpy() 95 | if not os.path.exists(self.res_dir): 96 | os.mkdir(self.res_dir) 97 | for i in range(res.shape[0]): 98 | scio.savemat(self.res_dir + '/' + res_name[i], {'data':res[i].squeeze()}) 99 | 100 | class data_loader(Dataset): 101 | def __init__(self, root, dose, mode): 102 | self.x_dir_name = 'input_' + dose 103 | self.x_path = os.path.join(root, mode, self.x_dir_name) 104 | self.mode = mode 105 | self.files_x = np.array(sorted(glob.glob(os.path.join(self.x_path, 'data') + '*.mat'))) 106 | 107 | def __getitem__(self, index): 108 | file_x = self.files_x[index] 109 | file_p = file_x.replace('input', 'projection') 110 | file_y = file_x.replace(self.x_dir_name, 'label') 111 | input_data = scio.loadmat(file_x)['data'] 112 | prj_data = scio.loadmat(file_p)['data'] 113 | label_data = scio.loadmat(file_y)['data'] 114 | input_data = torch.FloatTensor(input_data).unsqueeze_(0) 115 | prj_data = torch.FloatTensor(prj_data).unsqueeze_(0) 116 | label_data = torch.FloatTensor(label_data).unsqueeze_(0) 117 | if self.mode == 'train' or self.mode == 'vali': 118 | return input_data, prj_data, label_data 119 | elif self.mode == 'test': 120 | res_name = file_x[-13:] 121 | return input_data, prj_data, label_data, res_name 122 | 123 | def __len__(self): 124 | return len(self.files_x) 125 | 126 | if __name__ == "__main__": 127 | args = get_parameters() 128 | network = net(args) 129 | checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_last=True, save_top_k=3, mode="min") 130 | trainer = pl.Trainer(gpus=args.gpu_ids if args.is_specified_gpus else args.gpus, log_every_n_steps=1, max_epochs=args.epochs, callbacks=[checkpoint_callback]) 131 | train_loader =
DataLoader(data_loader(args.data_root_dir, args.dose, 'train'), batch_size=args.batch_size, shuffle=True, num_workers=args.cpus) 132 | vali_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'vali'), batch_size=args.batch_size, shuffle=False, num_workers=args.cpus) 133 | test_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'test'), batch_size=args.batch_size, shuffle=False, num_workers=args.cpus) 134 | trainer.fit(network, train_loader, vali_loader) 135 | # trainer.fit(network, train_loader, vali_loader, ckpt_path='lightning_logs/version_0/checkpoints/last.ckpt') 136 | trainer.test(network, test_loader, ckpt_path='best') 137 | -------------------------------------------------------------------------------- /demos/MAGIC.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | import argparse 5 | import scipy.io as scio 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | import pytorch_lightning as pl 10 | from torch.utils.data import Dataset 11 | from torch.utils.data import DataLoader 12 | from pytorch_lightning.callbacks import ModelCheckpoint 13 | from recon.models import MAGIC 14 | from vis_tools import Visualizer 15 | 16 | def setup_parser(arguments, title): 17 | parser = argparse.ArgumentParser(description=title, 18 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 19 | for key, val in arguments.items(): 20 | parser.add_argument('--%s' % key, 21 | type=eval(val["type"]), 22 | help=val["help"], 23 | default=val["default"], 24 | nargs=val["nargs"] if "nargs" in val else None) 25 | return parser 26 | 27 | def get_parameters(title=None): 28 | with open("config.json") as data_file: 29 | data = json.load(data_file) 30 | parser = setup_parser(data, title) 31 | parameters = parser.parse_args() 32 | return parameters 33 | 34 | class net(pl.LightningModule): 35 | def __init__(self, args): 36 | super().__init__() 37 | options = torch.tensor([args.views, args.dets, args.width, args.height, 38 | args.dImg, args.dDet, args.Ang0, args.dAng, 39 | args.s2r, args.d2r, args.binshift, args.scan_type]) 40 | self.model = MAGIC(options, img_size=512, p_size=12, stride=4, gcn_hid_ch=144) 41 | self.epochs = args.epochs 42 | self.lr = args.lr 43 | self.is_vis_show = args.is_vis_show 44 | self.show_win = args.show_win 45 | self.is_res_save = args.is_res_save 46 | self.res_dir = args.res_dir 47 | if self.is_vis_show: 48 | self.vis = Visualizer(env='MAGIC') 49 | 50 | def forward(self, x, p): 51 | out = self.model(x, p) 52 | return out 53 | 54 | def training_step(self, batch, batch_idx): 55 | x, p, y = batch 56 | out = self(x, p) 57 | loss = F.mse_loss(out, y) 58 | self.log("train_loss", loss, on_step=True, on_epoch=True) 59 | if self.is_vis_show: 60 | self.vis_show(loss.detach(), x.detach(), y.detach(), out.detach()) 61 | return loss 62 | 63 | def validation_step(self, batch, batch_idx): 64 | x, p, y = batch 65 | out = self(x, p) 66 | loss = F.mse_loss(out, y) 67 | self.log("val_loss", loss) 68 | 69 | def test_step(self, batch, batch_idx): 70 | x, p, y, res_name = batch 71 | out = self(x, p) 72 | if self.is_res_save: 73 | self.res_save(out, res_name) 74 | 75 | def configure_optimizers(self): 76 | optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr) 77 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200) 78 | return [optimizer], [scheduler] 79 | 80 | def show_win_norm(self, y): 81 | x = y.clone() 82 | x[x<self.show_win[0]] = self.show_win[0] 83 | x[x>self.show_win[1]] = self.show_win[1]
84 | x = (x - self.show_win[0]) / (self.show_win[1] - self.show_win[0]) * 255 85 | return x 86 | 87 | def vis_show(self, loss, x, y, out, mode='Train'): 88 | self.vis.plot(mode + ' Loss', loss.item()) 89 | self.vis.img(mode + ' Ground Truth', self.show_win_norm(y).cpu()) 90 | self.vis.img(mode + ' Input', self.show_win_norm(x).cpu()) 91 | self.vis.img(mode + ' Result', self.show_win_norm(out).cpu()) 92 | 93 | def res_save(self, out, res_name): 94 | res = out.cpu().numpy() 95 | if not os.path.exists(self.res_dir): 96 | os.mkdir(self.res_dir) 97 | for i in range(res.shape[0]): 98 | scio.savemat(self.res_dir + '/' + res_name[i], {'data':res[i].squeeze()}) 99 | 100 | class data_loader(Dataset): 101 | def __init__(self, root, dose, mode): 102 | self.x_dir_name = 'input_' + dose 103 | self.x_path = os.path.join(root, mode, self.x_dir_name) 104 | self.mode = mode 105 | self.files_x = np.array(sorted(glob.glob(os.path.join(self.x_path, 'data') + '*.mat'))) 106 | 107 | def __getitem__(self, index): 108 | file_x = self.files_x[index] 109 | file_p = file_x.replace('input', 'projection') 110 | file_y = file_x.replace(self.x_dir_name, 'label') 111 | input_data = scio.loadmat(file_x)['data'] 112 | prj_data = scio.loadmat(file_p)['data'] 113 | label_data = scio.loadmat(file_y)['data'] 114 | input_data = torch.FloatTensor(input_data).unsqueeze_(0) 115 | prj_data = torch.FloatTensor(prj_data).unsqueeze_(0) 116 | label_data = torch.FloatTensor(label_data).unsqueeze_(0) 117 | if self.mode == 'train' or self.mode == 'vali': 118 | return input_data, prj_data, label_data 119 | elif self.mode == 'test': 120 | res_name = file_x[-13:] 121 | return input_data, prj_data, label_data, res_name 122 | 123 | def __len__(self): 124 | return len(self.files_x) 125 | 126 | if __name__ == "__main__": 127 | args = get_parameters() 128 | network = net(args) 129 | checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_last=True, save_top_k=3, mode="min") 130 | trainer = pl.Trainer(gpus=args.gpu_ids if args.is_specified_gpus else args.gpus, log_every_n_steps=1, max_epochs=args.epochs, callbacks=[checkpoint_callback], strategy="ddp") 131 | train_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'train'), batch_size=args.batch_size, shuffle=True, num_workers=args.cpus) 132 | vali_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'vali'), batch_size=args.batch_size, shuffle=False, num_workers=args.cpus) 133 | test_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'test'), batch_size=args.batch_size, shuffle=False, num_workers=args.cpus) 134 | trainer.fit(network, train_loader, vali_loader) 135 | # trainer.fit(network, train_loader, vali_loader, ckpt_path='lightning_logs/version_0/checkpoints/last.ckpt') 136 | trainer.test(network, test_loader, ckpt_path='best') 137 | -------------------------------------------------------------------------------- /demos/MetaInvNet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | import argparse 5 | import scipy.io as scio 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | import pytorch_lightning as pl 10 | from torch.utils.data import Dataset 11 | from torch.utils.data import DataLoader 12 | from pytorch_lightning.callbacks import ModelCheckpoint 13 | from recon.models import MetaInvNet_H 14 | from vis_tools import Visualizer 15 | 16 | def setup_parser(arguments, title): 17 | parser = argparse.ArgumentParser(description=title, 18 | 
formatter_class=argparse.ArgumentDefaultsHelpFormatter) 19 | for key, val in arguments.items(): 20 | parser.add_argument('--%s' % key, 21 | type=eval(val["type"]), 22 | help=val["help"], 23 | default=val["default"], 24 | nargs=val["nargs"] if "nargs" in val else None) 25 | return parser 26 | 27 | def get_parameters(title=None): 28 | with open("config.json") as data_file: 29 | data = json.load(data_file) 30 | parser = setup_parser(data, title) 31 | parameters = parser.parse_args() 32 | return parameters 33 | 34 | class net(pl.LightningModule): 35 | def __init__(self, args): 36 | super().__init__() 37 | options = torch.tensor([args.views, args.dets, args.width, args.height, 38 | args.dImg, args.dDet, args.Ang0, args.dAng, 39 | args.s2r, args.d2r, args.binshift, args.scan_type]) 40 | self.model = MetaInvNet_H(options) 41 | self.epochs = args.epochs 42 | self.lr = args.lr 43 | self.is_vis_show = args.is_vis_show 44 | self.show_win = args.show_win 45 | self.is_res_save = args.is_res_save 46 | self.res_dir = args.res_dir 47 | self.options = options.cuda() 48 | if self.is_vis_show: 49 | self.vis = Visualizer(env='MetaInvNet') 50 | 51 | def forward(self, x, p): 52 | out = self.model(x, p) 53 | return out 54 | 55 | def training_step(self, batch, batch_idx): 56 | x, p, y = batch 57 | out = self(x, p) 58 | layer = len(out) 59 | loss = 0.0 60 | for i in range(0, layer): 61 | loss = loss + F.mse_loss(out[i], y) * 1.1**i 62 | self.log("train_loss", loss, on_step=True, on_epoch=True) 63 | if self.is_vis_show: 64 | self.vis_show(loss.detach(), x.detach(), y.detach(), out[-1].detach()) # out is a list of per-iteration reconstructions; visualize the last one 65 | return loss 66 | 67 | def validation_step(self, batch, batch_idx): 68 | x, p, y = batch 69 | out = self(x, p) 70 | loss = F.mse_loss(out[-1], y) 71 | self.log("val_loss", loss, on_step=True, on_epoch=True) 72 | 73 | def test_step(self, batch, batch_idx): 74 | x, p, y, res_name = batch 75 | out = self(x, p) 76 | if self.is_res_save: 77 | self.res_save(out[-1], res_name) 78 | 79 | def configure_optimizers(self): 80 | optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) 81 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200) 82 | return [optimizer], [scheduler] 83 | 84 | def show_win_norm(self, y): 85 | x = y.clone() 86 | x[x<self.show_win[0]] = self.show_win[0] 87 | x[x>self.show_win[1]] = self.show_win[1] 88 | x = (x - self.show_win[0]) / (self.show_win[1] - self.show_win[0]) * 255 89 | return x 90 | 91 | def vis_show(self, loss, x, y, out, mode='Train'): 92 | self.vis.plot(mode + ' Loss', loss.item()) 93 | self.vis.img(mode + ' Ground Truth', self.show_win_norm(y).cpu()) 94 | self.vis.img(mode + ' Input', self.show_win_norm(x).cpu()) 95 | self.vis.img(mode + ' Result', self.show_win_norm(out).cpu()) 96 | 97 | def res_save(self, out, res_name): 98 | res = out.cpu().numpy() 99 | if not os.path.exists(self.res_dir): 100 | os.mkdir(self.res_dir) 101 | for i in range(res.shape[0]): 102 | scio.savemat(self.res_dir + '/' + res_name[i], {'data':res[i].squeeze()}) 103 | 104 | class data_loader(Dataset): 105 | def __init__(self, root, dose, mode): 106 | self.x_dir_name = 'input_' + dose 107 | self.x_path = os.path.join(root, mode, self.x_dir_name) 108 | self.mode = mode 109 | self.files_x = np.array(sorted(glob.glob(os.path.join(self.x_path, 'data') + '*.mat'))) 110 | 111 | def __getitem__(self, index): 112 | file_x = self.files_x[index] 113 | file_p = file_x.replace('input', 'projection') 114 | file_y = file_x.replace(self.x_dir_name, 'label') 115 | input_data = scio.loadmat(file_x)['data'] 116 | prj_data = scio.loadmat(file_p)['data'] 117 |
label_data = scio.loadmat(file_y)['data'] 118 | input_data = torch.FloatTensor(input_data).unsqueeze_(0) 119 | prj_data = torch.FloatTensor(prj_data).unsqueeze_(0) 120 | label_data = torch.FloatTensor(label_data).unsqueeze_(0) 121 | if self.mode == 'train' or self.mode == 'vali': 122 | return input_data, prj_data, label_data 123 | elif self.mode == 'test': 124 | res_name = file_x[-13:] 125 | return input_data, prj_data, label_data, res_name 126 | 127 | def __len__(self): 128 | return len(self.files_x) 129 | 130 | if __name__ == "__main__": 131 | args = get_parameters() 132 | network = net(args) 133 | checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_last=True, save_top_k=3, mode="min") 134 | trainer = pl.Trainer(gpus=args.gpu_ids if args.is_specified_gpus else args.gpus, log_every_n_steps=1, max_epochs=args.epochs, callbacks=[checkpoint_callback], strategy="ddp") 135 | train_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'train'), batch_size=args.batch_size, shuffle=True, num_workers=args.cpus) 136 | vali_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'vali'), batch_size=args.batch_size, shuffle=False, num_workers=args.cpus) 137 | test_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'test'), batch_size=args.batch_size, shuffle=False, num_workers=args.cpus) 138 | trainer.fit(network, train_loader, vali_loader) 139 | # trainer.fit(network, train_loader, vali_loader, ckpt_path='lightning_logs/version_0/checkpoints/last.ckpt') 140 | trainer.test(network, test_loader, ckpt_path='best') 141 | -------------------------------------------------------------------------------- /demos/RED_CNN.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | import argparse 5 | import scipy.io as scio 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | import pytorch_lightning as pl 10 | from torch.utils.data import Dataset 11 | from torch.utils.data import DataLoader 12 | from pytorch_lightning.callbacks import ModelCheckpoint 13 | from recon.models import RED_CNN 14 | from vis_tools import Visualizer 15 | 16 | def setup_parser(arguments, title): 17 | parser = argparse.ArgumentParser(description=title, 18 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 19 | for key, val in arguments.items(): 20 | parser.add_argument('--%s' % key, 21 | type=eval(val["type"]), 22 | help=val["help"], 23 | default=val["default"], 24 | nargs=val["nargs"] if "nargs" in val else None) 25 | return parser 26 | 27 | def get_parameters(title=None): 28 | with open("config.json") as data_file: 29 | data = json.load(data_file) 30 | parser = setup_parser(data, title) 31 | parameters = parser.parse_args() 32 | return parameters 33 | 34 | class net(pl.LightningModule): 35 | def __init__(self, args): 36 | super().__init__() 37 | self.model = RED_CNN() 38 | self.epochs = args.epochs 39 | self.lr = args.lr 40 | self.is_vis_show = args.is_vis_show 41 | self.show_win = args.show_win 42 | self.is_res_save = args.is_res_save 43 | self.res_dir = args.res_dir 44 | if self.is_vis_show: 45 | self.vis = Visualizer(env='RED_CNN') 46 | 47 | def forward(self, x): 48 | out = self.model(x) 49 | return out 50 | 51 | def training_step(self, batch, batch_idx): 52 | x, y = batch 53 | out = self(x) 54 | loss = F.mse_loss(out, y) 55 | self.log("train_loss", loss, on_step=True, on_epoch=True) 56 | if self.is_vis_show: 57 | self.vis_show(loss.detach(), x.detach(), y.detach(), out.detach()) 58 | 
return loss 59 | 60 | def validation_step(self, batch, batch_idx): 61 | x, y = batch 62 | out = self(x) 63 | loss = F.mse_loss(out, y) 64 | self.log("val_loss", loss) 65 | 66 | def test_step(self, batch, batch_idx): 67 | x, y, res_name = batch 68 | out = self(x) 69 | if self.is_res_save: 70 | self.res_save(out, res_name) 71 | 72 | def configure_optimizers(self): 73 | optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr) 74 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200) 75 | return [optimizer], [scheduler] 76 | 77 | def show_win_norm(self, y): 78 | x = y.clone() 79 | x[x<self.show_win[0]] = self.show_win[0] 80 | x[x>self.show_win[1]] = self.show_win[1] 81 | x = (x - self.show_win[0]) / (self.show_win[1] - self.show_win[0]) * 255 82 | return x 83 | 84 | def vis_show(self, loss, x, y, out, mode='Train'): 85 | self.vis.plot(mode + ' Loss', loss.item()) 86 | self.vis.img(mode + ' Ground Truth', self.show_win_norm(y).cpu()) 87 | self.vis.img(mode + ' Input', self.show_win_norm(x).cpu()) 88 | self.vis.img(mode + ' Result', self.show_win_norm(out).cpu()) 89 | 90 | def res_save(self, out, res_name): 91 | res = out.cpu().numpy() 92 | if not os.path.exists(self.res_dir): 93 | os.mkdir(self.res_dir) 94 | for i in range(res.shape[0]): 95 | scio.savemat(self.res_dir + '/' + res_name[i], {'data':res[i].squeeze()}) 96 | 97 | class data_loader(Dataset): 98 | def __init__(self, root, dose, mode): 99 | self.x_dir_name = 'input_patch_' + dose if mode == 'train' else 'input_' + dose 100 | self.x_path = os.path.join(root, mode, self.x_dir_name) 101 | self.mode = mode 102 | self.files_x = np.array(sorted(glob.glob(os.path.join(self.x_path, 'data') + '*.mat'))) 103 | 104 | def __getitem__(self, index): 105 | file_x = self.files_x[index] 106 | file_y = file_x.replace(self.x_dir_name, 'label_patch') if self.mode == 'train' else file_x.replace(self.x_dir_name, 'label') 107 | input_data = scio.loadmat(file_x)['data'] 108 | label_data = scio.loadmat(file_y)['data'] 109 | input_data = torch.FloatTensor(input_data).unsqueeze_(0) 110 | label_data = torch.FloatTensor(label_data).unsqueeze_(0) 111 | if self.mode == 'train' or self.mode == 'vali': 112 | return input_data, label_data 113 | elif self.mode == 'test': 114 | res_name = file_x[-13:] 115 | return input_data, label_data, res_name 116 | 117 | def __len__(self): 118 | return len(self.files_x) 119 | 120 | if __name__ == "__main__": 121 | args = get_parameters() 122 | network = net(args) 123 | checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_last=True, save_top_k=3, mode="min") 124 | trainer = pl.Trainer(gpus=args.gpu_ids if args.is_specified_gpus else args.gpus, log_every_n_steps=1, max_epochs=args.epochs, callbacks=[checkpoint_callback]) 125 | train_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'train'), batch_size=args.batch_size*64, shuffle=True, num_workers=args.cpus) 126 | vali_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'vali'), batch_size=args.batch_size, shuffle=False, num_workers=args.cpus) 127 | test_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'test'), batch_size=args.batch_size, shuffle=False, num_workers=args.cpus) 128 | trainer.fit(network, train_loader, vali_loader) 129 | # trainer.fit(network, train_loader, vali_loader, ckpt_path='lightning_logs/version_0/checkpoints/last.ckpt') 130 | trainer.test(network, test_loader, ckpt_path='best') 131 | -------------------------------------------------------------------------------- /demos/VVBPTensorNet.py:
-------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | import argparse 5 | import scipy.io as scio 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | import pytorch_lightning as pl 10 | from torch.utils.data import Dataset 11 | from torch.utils.data import DataLoader 12 | from pytorch_lightning.callbacks import ModelCheckpoint 13 | from recon.models import VVBPTensorNet 14 | from vis_tools import Visualizer 15 | 16 | def setup_parser(arguments, title): 17 | parser = argparse.ArgumentParser(description=title, 18 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 19 | for key, val in arguments.items(): 20 | parser.add_argument('--%s' % key, 21 | type=eval(val["type"]), 22 | help=val["help"], 23 | default=val["default"], 24 | nargs=val["nargs"] if "nargs" in val else None) 25 | return parser 26 | 27 | def get_parameters(title=None): 28 | with open("config.json") as data_file: 29 | data = json.load(data_file) 30 | parser = setup_parser(data, title) 31 | parameters = parser.parse_args() 32 | return parameters 33 | 34 | class net(pl.LightningModule): 35 | def __init__(self, args): 36 | super().__init__() 37 | options = torch.tensor([args.views, args.dets, args.width, args.height, 38 | args.dImg, args.dDet, args.Ang0, args.dAng, 39 | args.s2r, args.d2r, args.binshift, args.scan_type]) 40 | self.model = VVBPTensorNet(options) 41 | self.epochs = args.epochs 42 | self.lr = args.lr 43 | self.is_vis_show = args.is_vis_show 44 | self.show_win = args.show_win 45 | self.is_res_save = args.is_res_save 46 | self.res_dir = args.res_dir 47 | if self.is_vis_show: 48 | self.vis = Visualizer(env='VVBPTensorNet') 49 | 50 | def forward(self, p): 51 | out = self.model(p) 52 | return out 53 | 54 | def training_step(self, batch, batch_idx): 55 | x, p, y = batch 56 | out = self(p) 57 | loss = F.mse_loss(out, y) 58 | self.log("train_loss", loss, on_step=True, on_epoch=True) 59 | if self.is_vis_show: 60 | self.vis_show(loss.detach(), x.detach(), y.detach(), out.detach()) 61 | return loss 62 | 63 | def validation_step(self, batch, batch_idx): 64 | x, p, y = batch 65 | out = self(p) 66 | loss = F.mse_loss(out, y) 67 | self.log("val_loss", loss) 68 | 69 | def test_step(self, batch, batch_idx): 70 | x, p, y, res_name = batch 71 | out = self(p) 72 | if self.is_res_save: 73 | self.res_save(out, res_name) 74 | 75 | def configure_optimizers(self): 76 | optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr) 77 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.epochs) 78 | return [optimizer], [scheduler] 79 | 80 | def show_win_norm(self, y): 81 | x = y.clone() 82 | x[x<self.show_win[0]] = self.show_win[0] 83 | x[x>self.show_win[1]] = self.show_win[1] 84 | x = (x - self.show_win[0]) / (self.show_win[1] - self.show_win[0]) * 255 85 | return x 86 | 87 | def vis_show(self, loss, x, y, out, mode='Train'): 88 | self.vis.plot(mode + ' Loss', loss.item()) 89 | self.vis.img(mode + ' Ground Truth', self.show_win_norm(y).cpu()) 90 | self.vis.img(mode + ' Input', self.show_win_norm(x).cpu()) 91 | self.vis.img(mode + ' Result', self.show_win_norm(out).cpu()) 92 | 93 | def res_save(self, out, res_name): 94 | res = out.cpu().numpy() 95 | if not os.path.exists(self.res_dir): 96 | os.mkdir(self.res_dir) 97 | for i in range(res.shape[0]): 98 | scio.savemat(self.res_dir + '/' + res_name[i], {'data':res[i].squeeze()}) 99 | 100 | class data_loader(Dataset): 101 | def __init__(self, root, dose, mode): 102 | self.x_dir_name = 'input_' + dose 103 |
self.x_path = os.path.join(root, mode, self.x_dir_name) 104 | self.mode = mode 105 | self.files_x = np.array(sorted(glob.glob(os.path.join(self.x_path, 'data') + '*.mat'))) 106 | 107 | def __getitem__(self, index): 108 | file_x = self.files_x[index] 109 | file_p = file_x.replace('input', 'projection') 110 | file_y = file_x.replace(self.x_dir_name, 'label') 111 | input_data = scio.loadmat(file_x)['data'] 112 | prj_data = scio.loadmat(file_p)['data'] 113 | label_data = scio.loadmat(file_y)['data'] 114 | input_data = torch.FloatTensor(input_data).unsqueeze_(0) 115 | prj_data = torch.FloatTensor(prj_data).unsqueeze_(0) 116 | label_data = torch.FloatTensor(label_data).unsqueeze_(0) 117 | if self.mode == 'train' or self.mode == 'vali': 118 | return input_data, prj_data, label_data 119 | elif self.mode == 'test': 120 | res_name = file_x[-13:] 121 | return input_data, prj_data, label_data, res_name 122 | 123 | def __len__(self): 124 | return len(self.files_x) 125 | 126 | if __name__ == "__main__": 127 | args = get_parameters() 128 | network = net(args) 129 | checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_last=True, save_top_k=3, mode="min") 130 | trainer = pl.Trainer(gpus=args.gpu_ids if args.is_specified_gpus else args.gpus, log_every_n_steps=1, max_epochs=args.epochs, callbacks=[checkpoint_callback]) 131 | train_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'train'), batch_size=args.batch_size, shuffle=True, num_workers=args.cpus) 132 | vali_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'vali'), batch_size=args.batch_size, shuffle=False, num_workers=args.cpus) 133 | test_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'test'), batch_size=args.batch_size, shuffle=False, num_workers=args.cpus) 134 | trainer.fit(network, train_loader, vali_loader) 135 | # trainer.fit(network, train_loader, vali_loader, ckpt_path='lightning_logs/version_1/checkpoints/last.ckpt') 136 | trainer.test(network, test_loader, ckpt_path='best') 137 | -------------------------------------------------------------------------------- /demos/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deep-Imaging-Group/Physics-Model-Data-Driven-Review/88e3155c41fb24b23832a9fdcb6dffb4b7c204c8/demos/__init__.py -------------------------------------------------------------------------------- /demos/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "epochs": { 3 | "type": "int", 4 | "help": "Max training epochs", 5 | "default": 200}, 6 | "batch_size": { 7 | "type": "int", 8 | "help": "Batch size of training data", 9 | "default": 4}, 10 | "lr": { 11 | "type": "float", 12 | "help": "Learning rate", 13 | "default": 1e-4}, 14 | "gpus": { 15 | "type": "int", 16 | "help": "Number of GPUs", 17 | "default": 1}, 18 | "gpu_ids": { 19 | "type": "int", 20 | "help": "GPU IDs", 21 | "nargs": "+", 22 | "default": [0]}, 23 | "is_specified_gpus": { 24 | "type": "bool", 25 | "help": "Whether to use the specified GPUs", 26 | "default": false}, 27 | "cpus": { 28 | "type": "int", 29 | "help": "Number of CPUs", 30 | "default": 4}, 31 | "data_root_dir": { 32 | "type": "str", 33 | "help": "Root dir of data", 34 | "default": "/mnt/nfs-data-storage/xwj/dataset/mayo_low_dose_512"}, 35 | "dose": { 36 | "type": "str", 37 | "help": "Dose of low-dose CT", 38 | "default": "25%"}, 39 | "is_vis_show": { 40 | "type": "bool", 41 | "help": "Whether to show the loss curve and results with visdom",
42 | "default": false}, 43 | "show_win": { 44 | "type": "float", 45 | "help": "Display window of the results", 46 | "nargs": "+", 47 | "default": [1.6128, 2.3808]}, 48 | "is_res_save": { 49 | "type": "bool", 50 | "help": "Whether to save the test results", 51 | "default": true}, 52 | "res_dir": { 53 | "type": "str", 54 | "help": "The dir to save the test results", 55 | "default": "result"}, 56 | "views": { 57 | "type": "int", 58 | "help": "Number of projection views", 59 | "default": 512}, 60 | "dets": { 61 | "type": "int", 62 | "help": "Number of detector elements", 63 | "default": 736}, 64 | "width": { 65 | "type": "int", 66 | "help": "Width of image", 67 | "default": 512}, 68 | "height": { 69 | "type": "int", 70 | "help": "Height of image", 71 | "default": 512}, 72 | "dImg": { 73 | "type": "float", 74 | "help": "Physical size of a pixel", 75 | "default": 0.006641}, 76 | "dDet": { 77 | "type": "float", 78 | "help": "Physical size of a detector element", 79 | "default": 0.012858}, 80 | "Ang0": { 81 | "type": "float", 82 | "help": "Start angle", 83 | "default": 0}, 84 | "dAng": { 85 | "type": "float", 86 | "help": "Radian interval between two views", 87 | "default": 0.012268}, 88 | "s2r": { 89 | "type": "float", 90 | "help": "Distance between x-ray source and rotation center", 91 | "default": 5.95}, 92 | "d2r": { 93 | "type": "float", 94 | "help": "Distance between detector and rotation center", 95 | "default": 4.906}, 96 | "binshift": { 97 | "type": "float", 98 | "help": "Shift of detector", 99 | "default": 0}, 100 | "scan_type": { 101 | "type": "int", 102 | "help": "Scanning type, 0: equal distance fan beam, 1: equal angle fan beam, 2: parallel beam", 103 | "default": 0} 104 | } -------------------------------------------------------------------------------- /demos/iCTNet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | import argparse 5 | import scipy.io as scio 6 | import numpy as np 7 | import math 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import pytorch_lightning as pl 12 | from torch.utils.data import Dataset 13 | from torch.utils.data import DataLoader 14 | from pytorch_lightning.callbacks import ModelCheckpoint 15 | from recon.models import iCTNet 16 | from vis_tools import Visualizer 17 | 18 | def setup_parser(arguments, title): 19 | parser = argparse.ArgumentParser(description=title, 20 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 21 | for key, val in arguments.items(): 22 | parser.add_argument('--%s' % key, 23 | type=eval(val["type"]), 24 | help=val["help"], 25 | default=val["default"], 26 | nargs=val["nargs"] if "nargs" in val else None) 27 | return parser 28 | 29 | def get_parameters(title=None): 30 | with open("config.json") as data_file: 31 | data = json.load(data_file) 32 | parser = setup_parser(data, title) 33 | parameters = parser.parse_args() 34 | return parameters 35 | 36 | class net(pl.LightningModule): 37 | def __init__(self, args): 38 | super().__init__() 39 | self.model = iCTNet(args.views, args.dets, args.width, args.height, args.dAng) 40 | self.epochs = args.epochs 41 | self.lr = args.lr 42 | self.is_vis_show = args.is_vis_show 43 | self.show_win = args.show_win 44 | self.is_res_save = args.is_res_save 45 | self.res_dir = args.res_dir 46 | self.automatic_optimization = False 47 | if self.is_vis_show: 48 | self.vis = Visualizer(env='iCTNet') 49 | dets = args.dets 50 | dDet = args.dDet 51 | s2r = args.s2r 52 | d2r = args.d2r 53 | virdet =
dDet * s2r / (s2r + d2r) 54 | filter = torch.empty(2 * dets - 1) 55 | pi = torch.acos(torch.tensor(-1.0)) 56 | for i in range(filter.size(0)): 57 | x = i - dets + 1 58 | if abs(x) % 2 == 1: 59 | filter[i] = -1 / (pi * pi * x * x * virdet * virdet) 60 | elif x == 0: 61 | filter[i] = 1 / (4 * virdet * virdet) 62 | else: 63 | filter[i] = 0 64 | filter = filter.view(1,1,1,-1) 65 | self.filter = nn.Parameter(filter, requires_grad=False) 66 | self.dets = dets 67 | 68 | def forward(self, p): 69 | out = self.model(p) 70 | return out 71 | 72 | def training_step(self, batch, batch_idx): 73 | current_epoch = self.trainer.current_epoch 74 | x, p, y, z = batch 75 | if current_epoch < 200: 76 | out = self.model.segment1(p) 77 | loss = F.mse_loss(out, z) 78 | elif current_epoch < 400: 79 | out = self.model.segment2(z) 80 | loss = F.mse_loss(out, z) 81 | elif current_epoch < 600: 82 | pf = torch.nn.functional.conv2d(z, self.filter, padding=(0,self.dets-1)) 83 | out = self.model.segment3(z) 84 | loss = F.mse_loss(out, pf) 85 | elif current_epoch < 800: 86 | pf = torch.nn.functional.conv2d(z, self.filter, padding=(0,self.dets-1)) 87 | out = self.model.segment4(pf) 88 | loss = F.mse_loss(out, y) 89 | else: 90 | out = self(p) 91 | loss = F.mse_loss(out, y) 92 | self.log("train_loss", loss, on_step=True, on_epoch=True) 93 | opt = self.optimizers(); opt.zero_grad(); self.manual_backward(loss); opt.step() # manual update needed because automatic_optimization is False 94 | def validation_step(self, batch, batch_idx): 95 | current_epoch = self.trainer.current_epoch 96 | x, p, y, z = batch 97 | if current_epoch >= 800: 98 | out = self(p) 99 | loss = F.mse_loss(out, y) 100 | else: 101 | loss = 1e5 102 | self.log("val_loss", loss) 103 | 104 | def test_step(self, batch, batch_idx): 105 | x, p, y, z, res_name = batch 106 | out = self(p) 107 | if self.is_res_save: 108 | self.res_save(out, res_name) 109 | 110 | def configure_optimizers(self): 111 | optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr) 112 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[920, 960], gamma=0.1) 113 | return [optimizer], [scheduler] 114 | 115 | def show_win_norm(self, y): 116 | x = y.clone() 117 | x[x<self.show_win[0]] = self.show_win[0] 118 | x[x>self.show_win[1]] = self.show_win[1] 119 | x = (x - self.show_win[0]) / (self.show_win[1] - self.show_win[0]) * 255 120 | return x 121 | 122 | def vis_show(self, loss, x, y, out, mode='Train'): 123 | self.vis.plot(mode + ' Loss', loss.item()) 124 | self.vis.img(mode + ' Ground Truth', self.show_win_norm(y).cpu()) 125 | self.vis.img(mode + ' Input', self.show_win_norm(x).cpu()) 126 | self.vis.img(mode + ' Result', self.show_win_norm(out).cpu()) 127 | 128 | def res_save(self, out, res_name): 129 | res = out.cpu().numpy() 130 | if not os.path.exists(self.res_dir): 131 | os.mkdir(self.res_dir) 132 | for i in range(res.shape[0]): 133 | scio.savemat(self.res_dir + '/' + res_name[i], {'data':res[i].squeeze()}) 134 | 135 | class data_loader(Dataset): 136 | def __init__(self, root, dose, mode): 137 | self.x_dir_name = 'input_' + dose 138 | self.x_path = os.path.join(root, mode, self.x_dir_name) 139 | self.mode = mode 140 | self.files_x = np.array(sorted(glob.glob(os.path.join(self.x_path, 'data') + '*.mat'))) 141 | 142 | def __getitem__(self, index): 143 | file_x = self.files_x[index] 144 | file_p = file_x.replace('input', 'projection') 145 | file_y = file_x.replace(self.x_dir_name, 'label') 146 | file_py = file_x.replace(self.x_dir_name, 'projection') 147 | input_data = scio.loadmat(file_x)['data'] 148 | prj_data = scio.loadmat(file_p)['data'] 149 | label_data = scio.loadmat(file_y)['data'] 150 | prj_label = scio.loadmat(file_py)['data'] 151 | input_data =
torch.FloatTensor(input_data).unsqueeze_(0) 152 | prj_data = torch.FloatTensor(prj_data).unsqueeze_(0) 153 | label_data = torch.FloatTensor(label_data).unsqueeze_(0) 154 | prj_label = torch.FloatTensor(prj_label).unsqueeze_(0) 155 | if self.mode == 'train' or self.mode == 'vali': 156 | return input_data, prj_data, label_data, prj_label 157 | elif self.mode == 'test': 158 | res_name = file_x[-13:] 159 | return input_data, prj_data, label_data, prj_label, res_name 160 | 161 | def __len__(self): 162 | return len(self.files_x) 163 | 164 | if __name__ == "__main__": 165 | args = get_parameters() 166 | network = net(args) 167 | checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_last=True, save_top_k=3, mode="min") 168 | trainer = pl.Trainer(gpus=args.gpu_ids if args.is_specified_gpus else args.gpus, log_every_n_steps=1, max_epochs=args.epochs, callbacks=[checkpoint_callback], strategy="ddp") 169 | train_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'train'), batch_size=args.batch_size, shuffle=True, num_workers=args.cpus) 170 | vali_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'vali'), batch_size=args.batch_size, shuffle=False, num_workers=args.cpus) 171 | test_loader = DataLoader(data_loader(args.data_root_dir, args.dose, 'test'), batch_size=args.batch_size, shuffle=False, num_workers=args.cpus) 172 | trainer.fit(network, train_loader, vali_loader) 173 | # trainer.fit(network, train_loader, vali_loader, ckpt_path='lightning_logs/version_1/checkpoints/last.ckpt') 174 | trainer.test(network, test_loader, ckpt_path='best') 175 | -------------------------------------------------------------------------------- /demos/vis_tools.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import torch as t 4 | import visdom 5 | from matplotlib import pyplot as plot 6 | 7 | class Visualizer(object): 8 | """ 9 | wrapper for visdom 10 | you can still access native visdom functions via 11 | self.line, self.scatter, self._send, etc. 12 | due to the implementation of `__getattr__` 13 | """ 14 | 15 | def __init__(self, env='default', **kwargs): 16 | self.vis = visdom.Visdom(env=env, use_incoming_socket=False, **kwargs) 17 | self._vis_kw = kwargs 18 | 19 | # e.g. ('loss', 23) means the 23rd value of loss 20 | self.index = {} 21 | self.log_text = '' 22 | 23 | def reinit(self, env='default', **kwargs): 24 | """ 25 | change the config of visdom 26 | """ 27 | self.vis = visdom.Visdom(env=env, **kwargs) 28 | return self 29 | 30 | def plot_many(self, d): 31 | """ 32 | plot multiple values 33 | @params d: dict (name, value) i.e. ('loss', 0.11) 34 | """ 35 | for k, v in d.items(): 36 | if v is not None: 37 | self.plot(k, v) 38 | 39 | def img_many(self, d): 40 | for k, v in d.items(): 41 | self.img(k, v) 42 | 43 | def plot(self, name, y, **kwargs): 44 | """ 45 | self.plot('loss',1.00) 46 | """ 47 | x = self.index.get(name, 0) 48 | self.vis.line(Y=np.array([y]), X=np.array([x]), 49 | win=name, 50 | opts=dict(title=name), 51 | update=None if x == 0 else 'append', 52 | **kwargs 53 | ) 54 | self.index[name] = x + 1 55 | 56 | def img(self, name, img_, **kwargs): 57 | """ 58 | self.img('input_img',t.Tensor(64,64)) 59 | self.img('input_imgs',t.Tensor(3,64,64)) 60 | self.img('input_imgs',t.Tensor(100,1,64,64)) 61 | self.img('input_imgs',t.Tensor(100,3,64,64),nrows=10) 62 | !!don't ~~self.img('input_imgs',t.Tensor(100,64,64),nrows=10)~~!!
63 | """ 64 | self.vis.images(t.Tensor(img_).numpy(), 65 | win=name, 66 | opts=dict(title=name), 67 | **kwargs 68 | ) 69 | 70 | def log(self, info, win='log_text'): 71 | """ 72 | self.log({'loss':1,'lr':0.0001}) 73 | """ 74 | self.log_text += ('[{time}] {info}
<br>'.format( 75 | time=time.strftime('%m%d_%H%M%S'), \ 76 | info=info)) 77 | self.vis.text(self.log_text, win) 78 | 79 | def __getattr__(self, name): 80 | return getattr(self.vis, name) 81 | 82 | def state_dict(self): 83 | return { 84 | 'index': self.index, 85 | 'vis_kw': self._vis_kw, 86 | 'log_text': self.log_text, 87 | 'env': self.vis.env 88 | } 89 | 90 | def load_state_dict(self, d): 91 | self.vis = visdom.Visdom(env=d.get('env', self.vis.env), **(d.get('vis_kw'))) 92 | self.log_text = d.get('log_text', '') 93 | self.index = d.get('index', dict()) 94 | return self -------------------------------------------------------------------------------- /recon/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deep-Imaging-Group/Physics-Model-Data-Driven-Review/88e3155c41fb24b23832a9fdcb6dffb4b7c204c8/recon/__init__.py -------------------------------------------------------------------------------- /recon/models/AHPNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import math 6 | import ctlib 7 | 8 | def filter_gen(): 9 | h0 = torch.tensor([1/4, 1/2, 1/4]).unsqueeze(-1) 10 | h1 = torch.tensor([-1/4, 1/2, 1/4]).unsqueeze(-1) 11 | h2 = torch.tensor([math.sqrt(2)/4, 0, -math.sqrt(2)/4]).unsqueeze(-1) 12 | h = [h0, h1, h2] 13 | filter = [] 14 | for i in range(3): 15 | for j in range(3): 16 | if i == 0 and j == 0: 17 | continue 18 | f = h[i] @ h[j].t() 19 | f = f.view(1,1,3,3) 20 | filter.append(f) 21 | filter = torch.cat(filter, dim=0) 22 | return filter 23 | 24 | class MLP(nn.Module): 25 | def __init__(self, options) -> None: 26 | super().__init__() 27 | self.model = nn.Sequential( 28 | nn.Linear(9,9), 29 | nn.ReLU(True), 30 | nn.Linear(9,9), 31 | nn.ReLU(True), 32 | nn.Linear(9,8), 33 | nn.ReLU(True) 34 | ) 35 | self.filter = nn.Parameter(filter_gen(),requires_grad=False) 36 | self.options = nn.Parameter(options,requires_grad=False) 37 | 38 | def forward(self, x, z, p): 39 | r0 = p - ctlib.projection(x, self.options) 40 | rk = z - F.conv2d(x, self.filter, stride=1, padding=1) 41 | r0_norm = (r0 ** 2).sum(dim=(2,3)) 42 | rk_norm = (rk ** 2).sum(dim=(2,3)) 43 | r_norm = torch.cat((r0_norm, rk_norm), dim=1) 44 | beta = self.model(r_norm) 45 | beta = beta.unsqueeze(-1).unsqueeze(-1) 46 | return beta 47 | 48 | class CNN(nn.Module): 49 | def __init__(self, k) -> None: 50 | super().__init__() 51 | layers = [] 52 | layers.append(nn.Conv2d(k, 64, 3, 1, 1)) 53 | layers.append(nn.ReLU(True)) 54 | for i in range(17): 55 | layers.append(nn.Conv2d(64, 64, 3, 1, 1)) 56 | layers.append(nn.BatchNorm2d(64)) 57 | layers.append(nn.ReLU(True)) 58 | layers.append(nn.Conv2d(64, 1, 3, 1, 1)) 59 | layers.append(nn.ReLU(True)) 60 | self.model = nn.Sequential(*layers) 61 | self.filter = nn.Parameter(filter_gen(),requires_grad=False) 62 | 63 | def forward(self, x): 64 | x_tilde = self.model(x) 65 | z = F.conv2d(x_tilde, self.filter, stride=1, padding=1) 66 | return z 67 | 68 | class CGModule(nn.Module): 69 | def __init__(self, options): 70 | super().__init__() 71 | self.options = nn.Parameter(options, requires_grad=False) 72 | self.filter = nn.Parameter(filter_gen(),requires_grad=False) 73 | self.filter_t = nn.Parameter(self.filter.flip((2,3)),requires_grad=False) 74 | 75 | def AWx(self,img,mu): 76 | Ax = ctlib.projection(img, self.options) 77 | AtAx = ctlib.projection_t(Ax, self.options) 78 | Ax0 = AtAx
+ self.Ft(self.F(img), mu) 79 | return Ax0 80 | 81 | def F(self,x): 82 | return F.conv2d(x, self.filter, stride=1, padding=1) 83 | 84 | def Ft(self,y, mu): 85 | Ft = F.conv2d(y, self.filter_t, stride=1, padding=1, groups=8) * mu 86 | return Ft.sum(dim=1, keepdim=True) 87 | 88 | def pATAp(self,img): 89 | Ap=ctlib.projection(img, self.options) 90 | pATApNorm=torch.sum(Ap**2,dim=(1,2,3), keepdim=True) 91 | return pATApNorm 92 | 93 | def pWTWp(self,img,mu): 94 | Wp=self.F(img) 95 | mu_Wp=mu*(Wp**2) 96 | pWTWpNorm=torch.sum(mu_Wp,dim=(1,2,3), keepdim=True) 97 | return pWTWpNorm 98 | 99 | def CG_alg(self,x,mu,y,z,CGiter=20): 100 | Aty = ctlib.projection_t(y, self.options) 101 | Ftz = self.Ft(z, mu) 102 | res = Aty + Ftz 103 | r=res 104 | p=-res 105 | for k in range(CGiter): 106 | pATApNorm = self.pATAp(p) 107 | mu_pWtWpNorm=self.pWTWp(p,mu) 108 | rTr=torch.sum(r**2,dim=(1,2,3), keepdim=True) 109 | alphak = rTr / (mu_pWtWpNorm+pATApNorm) 110 | x = x+alphak*p 111 | r = r+alphak*self.AWx(p,mu) 112 | betak = torch.sum(r**2,dim=(1,2,3), keepdim=True)/ rTr 113 | p=-r+betak*p 114 | 115 | pATApNorm = self.pATAp(p) 116 | mu_pWtWpNorm=self.pWTWp(p,mu) 117 | rTr=torch.sum(r**2,dim=(1,2,3), keepdim=True) 118 | alphak = rTr/(mu_pWtWpNorm+pATApNorm) 119 | x = x+alphak*p 120 | return x 121 | 122 | class IterBlock(nn.Module): 123 | def __init__(self, k, options) -> None: 124 | super().__init__() 125 | self.Dcnn = CNN(k) 126 | self.Pmlp = MLP(options) 127 | self.CGModule = CGModule(options) 128 | self.filter = nn.Parameter(filter_gen(),requires_grad=False) 129 | 130 | def forward(self, x, p): 131 | z = self.Dcnn(x) 132 | beta = self.Pmlp(x[:,[-1],:,:].detach(), z.detach(), p) 133 | x_t = self.CGModule.CG_alg(x[:,[-1],:,:], beta, p, z, CGiter=5) 134 | return x_t 135 | 136 | 137 | class AHPNet(nn.Module): 138 | def __init__(self, options, layers=3): 139 | super(AHPNet,self).__init__() 140 | self.layers = layers 141 | self.model = nn.ModuleList([IterBlock(i+1, options) for i in range(self.layers)]) 142 | 143 | def forward(self, x, p): 144 | res = x.clone() 145 | for model in self.model: 146 | x_t = model(res, p) 147 | res = torch.cat((res, x_t), dim=1) 148 | return res -------------------------------------------------------------------------------- /recon/models/AdaptiveNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .LEARN_FBP import FBP 4 | 5 | class Conv2d(nn.Module): 6 | def __init__(self, in_ch, out_ch): 7 | super(Conv2d, self).__init__() 8 | self.model = nn.Sequential( 9 | nn.Conv2d(in_ch, out_ch, 5, 1, 2), 10 | nn.BatchNorm2d(out_ch), 11 | nn.ReLU(inplace=True) 12 | ) 13 | 14 | def forward(self, x): 15 | return self.model(x) 16 | 17 | class SubNet(nn.Module): 18 | def __init__(self, layers): 19 | super(SubNet, self).__init__() 20 | self.conv_first = nn.Sequential(nn.Conv2d(1, 64, 5, 1, 2), nn.ReLU(inplace=True)) 21 | self.conv = nn.ModuleList([Conv2d(64, 64) for i in range(layers)]) 22 | self.conv_last = nn.Conv2d(64, 1, 5, 1, 2) 23 | self.layers = layers 24 | 25 | def forward(self, x): 26 | y = x.clone() 27 | y = self.conv_first(y) 28 | z = y.clone() 29 | for layer in self.conv: 30 | y = layer(y) 31 | z += y 32 | z = z / (self.layers + 1) 33 | z = self.conv_last(z) 34 | out = z + x 35 | return out 36 | 37 | 38 | class AdaptiveNet(nn.Module): 39 | def __init__(self, options): 40 | super(AdaptiveNet, self).__init__() 41 | self.model = nn.Sequential(SubNet(3), FBP(options), SubNet(5)) 42 | for module in self.modules(): 
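        # (editorial note) The loop below initializes every Conv2d /
        # ConvTranspose2d weight from N(0, 0.01^2) with zero bias and sets
        # BatchNorm2d to the identity (weight 1, bias 0) -- the same
        # initialization scheme used by the other networks in this package.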
43 | if isinstance(module, nn.Conv2d): 44 | nn.init.normal_(module.weight, mean=0, std=0.01) 45 | if module.bias is not None: 46 | module.bias.data.zero_() 47 | if isinstance(module, nn.ConvTranspose2d): 48 | nn.init.normal_(module.weight, mean=0, std=0.01) 49 | if module.bias is not None: 50 | module.bias.data.zero_() 51 | if isinstance(module, nn.BatchNorm2d): 52 | module.weight.data.fill_(1) 53 | module.bias.data.zero_() 54 | 55 | def forward(self, y): 56 | out = self.model(y) 57 | return out 58 | -------------------------------------------------------------------------------- /recon/models/AirNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Function 4 | from .LEARN import projector 5 | from .LEARN_FBP import fidelity_module 6 | 7 | class Iter_block(nn.Module): 8 | def __init__(self, hid_channels, kernel_size, padding, options, idx): 9 | super(Iter_block, self).__init__() 10 | self.block1 = fidelity_module(options) 11 | self.block2 = nn.Sequential( 12 | nn.Conv2d(idx + 1, hid_channels, kernel_size=kernel_size, padding=padding), 13 | nn.ReLU(inplace=True), 14 | nn.Conv2d(hid_channels, hid_channels, kernel_size=kernel_size, padding=padding), 15 | nn.ReLU(inplace=True), 16 | nn.Conv2d(hid_channels, 1, kernel_size=kernel_size, padding=padding) 17 | ) 18 | self.relu = nn.ReLU(inplace=True) 19 | 20 | def forward(self, input_data, proj, iter_res): 21 | mid_res = self.block1(input_data, proj) 22 | if iter_res is None: 23 | deep_res = mid_res.clone() 24 | else: 25 | deep_res = torch.cat((mid_res, iter_res), dim=1) 26 | out = self.block2(deep_res) + mid_res 27 | return out, deep_res 28 | 29 | class AirNet(nn.Module): 30 | def __init__(self, options, block_num=50, hid_channels=48, kernel_size=3, padding=1): 31 | super(AirNet, self).__init__() 32 | self.model = nn.ModuleList([Iter_block(hid_channels, kernel_size, padding, options, i) for i in range(block_num)]) 33 | for module in self.modules(): 34 | if isinstance(module, fidelity_module): 35 | module.weight.data.zero_() 36 | if isinstance(module, nn.Conv2d): 37 | nn.init.normal_(module.weight, mean=0, std=0.01) 38 | if module.bias is not None: 39 | module.bias.data.zero_() 40 | 41 | def forward(self, input_data, proj): 42 | x = input_data 43 | iter_res = None 44 | for index, module in enumerate(self.model): 45 | x, iter_res = module(x, proj, iter_res) 46 | return x -------------------------------------------------------------------------------- /recon/models/DBP.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Function 4 | import ctlib 5 | 6 | class bprj_sv_fun(Function): 7 | @staticmethod 8 | def forward(self, proj, options): 9 | self.save_for_backward(options) 10 | return ctlib.backprojection_sv(proj, options) 11 | 12 | @staticmethod 13 | def backward(self, grad_output): 14 | options = self.saved_tensors[0] 15 | temp = grad_output.sum(1, keepdim=True) 16 | grad_input = ctlib.backprojection_t(temp.contiguous(), options) 17 | return grad_input, None 18 | 19 | class backprojector_sv(nn.Module): 20 | def __init__(self): 21 | super(backprojector_sv, self).__init__() 22 | 23 | def forward(self, proj, options): 24 | return bprj_sv_fun.apply(proj, options) 25 | 26 | class DBP(nn.Module): 27 | def __init__(self, options) -> None: 28 | super().__init__() 29 | self.options = nn.Parameter(options, requires_grad=False) 30 | layers = [] 31 
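# (editorial note) ctlib.backprojection_sv returns one single-view
# backprojection per channel, so the first convolution below expects `views`
# input channels; the hard-coded 512 assumes a 512-view geometry
# (options[0] == 512) and has to be changed together with the scan protocol.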
| layers.append(nn.Conv2d(512, 64, 3, 1, 1)) 32 | layers.append(nn.ReLU(inplace=True)) 33 | for i in range(15): 34 | layers.append(nn.Conv2d(64, 64, 3, 1, 1)) 35 | layers.append(nn.BatchNorm2d(64)) 36 | layers.append(nn.ReLU(inplace=True)) 37 | layers.append(nn.Conv2d(64, 1, 3, 1, 1)) 38 | self.model = nn.Sequential(*layers) 39 | self.backprojector = backprojector_sv() 40 | 41 | def forward(self, p): 42 | x = self.backprojector(p, self.options) 43 | out = self.model(x) 44 | return out -------------------------------------------------------------------------------- /recon/models/DSigNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import numpy as np 5 | from torch.autograd import Function 6 | from .iRadonMap import PixelIndexCal_cuda 7 | from .iRadonMap import BackProjNet 8 | 9 | def PixelIndexCal_DownSampling(length, width, lds, wds): 10 | length, width = int(length/lds), int(width/wds) 11 | ds_indices = torch.zeros(lds*wds, width*length).type(torch.LongTensor) 12 | for x in range(lds): 13 | for y in range(wds): 14 | k = x*width*wds+y 15 | for z in range(length): 16 | i, j = z*width, x*wds+y 17 | st = k+z*width*wds*lds 18 | ds_indices[j, i:i+width] = torch.tensor(range(st,st+width*wds, wds)) 19 | return ds_indices.view(-1) 20 | 21 | 22 | def PixelIndexCal_UpSampling(index, length, width): 23 | index = index.view(-1) 24 | _, ups_indices = index.sort(dim=0, descending=False) 25 | return ups_indices.view(-1) 26 | 27 | class DownSamplingBlock(nn.Module): 28 | def __init__(self, planes=8, length=512, width=736, lds=2, wds=2): 29 | super(DownSamplingBlock, self).__init__() 30 | self.length = int(length/lds) 31 | self.width = int(width/wds) 32 | self.extra_channel = lds*wds 33 | self.ds_index = nn.Parameter(PixelIndexCal_DownSampling(length, width, lds, wds), requires_grad=False) 34 | self.filter = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=True) 35 | self.ln = nn.GroupNorm(num_channels=planes, num_groups=1, affine=False) 36 | self.leakyrelu = nn.LeakyReLU(0.2, True) 37 | 38 | def forward(self, input): 39 | _, channel, length, width = input.size() 40 | output = torch.index_select(input.view(-1, channel, length*width), 2, self.ds_index) 41 | output = output.view(-1, channel*self.extra_channel, self.length, self.width) 42 | output = self.leakyrelu(self.ln(self.filter(output))) 43 | 44 | return output 45 | 46 | 47 | class UpSamplingBlock(nn.Module): 48 | def __init__(self, planes=8, length=64, width=64, lups=2, wups=2): 49 | super(UpSamplingBlock, self).__init__() 50 | 51 | self.length = length*lups 52 | self.width = width*wups 53 | self.extra_channel = lups*wups 54 | ds_index = PixelIndexCal_DownSampling(self.length, self.width, lups, wups) 55 | self.ups_index = nn.Parameter(PixelIndexCal_UpSampling(ds_index, self.length, self.width), requires_grad=False) 56 | self.filter = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=True) 57 | self.ln = nn.GroupNorm(num_channels=planes, num_groups=1, affine=False) 58 | self.leakyrelu = nn.LeakyReLU(0.2, True) 59 | 60 | def forward(self, input): 61 | _, channel, length, width = input.size() 62 | channel = int(channel/self.extra_channel) 63 | output = torch.index_select(input.view(-1, channel, self.extra_channel*length*width), 2, self.ups_index) 64 | output = output.view(-1, channel, self.length, self.width) 65 | output = self.leakyrelu(self.ln(self.filter(output))) 66 | 67 | return output 68 | 69 | 70 | class 
ResidualBlock(nn.Module): 71 | def __init__(self, planes): 72 | super(ResidualBlock, self).__init__() 73 | 74 | self.filter1 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=True) 75 | self.ln1 = nn.GroupNorm(num_channels=planes, num_groups=1, affine=False) 76 | self.leakyrelu1 = nn.LeakyReLU(0.2, True) 77 | self.filter2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=True) 78 | self.ln2 = nn.GroupNorm(num_channels=planes, num_groups=1, affine=False) 79 | self.leakyrelu2 = nn.LeakyReLU(0.2, True) 80 | 81 | def forward(self, input): 82 | output = self.leakyrelu1(self.ln1(self.filter1(input))) 83 | output = self.ln2(self.filter2(output)) 84 | output += input 85 | output = self.leakyrelu2(output) 86 | 87 | return output 88 | 89 | 90 | class SinoNet(nn.Module): 91 | def __init__(self, bp_channel, num_filters): 92 | super(SinoNet, self).__init__() 93 | 94 | model_list = [nn.Conv2d(1, num_filters, kernel_size=3, stride=1, padding=1, bias=True), nn.GroupNorm(num_channels=num_filters, num_groups=1, affine=False), nn.LeakyReLU(0.2, True)] 95 | model_list += [DownSamplingBlock(planes=num_filters*4, length=512, width=736, lds=2, wds=2)] 96 | model_list += [ResidualBlock(planes=num_filters*4), ResidualBlock(planes=num_filters*4), ResidualBlock(planes=num_filters*4)] 97 | model_list += [ResidualBlock(planes=num_filters*4), ResidualBlock(planes=num_filters*4), ResidualBlock(planes=num_filters*4)] 98 | model_list += [ResidualBlock(planes=num_filters*4), ResidualBlock(planes=num_filters*4), ResidualBlock(planes=num_filters*4)] 99 | model_list += [ResidualBlock(planes=num_filters*4), ResidualBlock(planes=num_filters*4), ResidualBlock(planes=num_filters*4)] 100 | 101 | model_list += [nn.Conv2d(num_filters*4, bp_channel, kernel_size=1, stride=1, padding=0, bias=True)] 102 | self.model = nn.Sequential(*model_list) 103 | 104 | def forward(self, input): 105 | 106 | return self.model(input) 107 | 108 | class SpatialNet(nn.Module): 109 | def __init__(self, bp_channel, num_filters): 110 | super(SpatialNet, self).__init__() 111 | 112 | model_list = [nn.Conv2d(bp_channel, num_filters*4, kernel_size=3, stride=1, padding=1, bias=True), nn.GroupNorm(num_channels=num_filters*4, num_groups=1, affine=False), nn.LeakyReLU(0.2, True)] 113 | model_list += [ResidualBlock(planes=num_filters*4), ResidualBlock(planes=num_filters*4), ResidualBlock(planes=num_filters*4)] 114 | model_list += [ResidualBlock(planes=num_filters*4), ResidualBlock(planes=num_filters*4), ResidualBlock(planes=num_filters*4)] 115 | model_list += [ResidualBlock(planes=num_filters*4), ResidualBlock(planes=num_filters*4), ResidualBlock(planes=num_filters*4)] 116 | model_list += [ResidualBlock(planes=num_filters*4), ResidualBlock(planes=num_filters*4), ResidualBlock(planes=num_filters*4)] 117 | model_list += [UpSamplingBlock(planes=num_filters, length=256, width=256, lups=2, wups=2)] 118 | 119 | model_list += [nn.Conv2d(num_filters, 1, kernel_size=1, stride=1, padding=0, bias=True)] 120 | self.model = nn.Sequential(*model_list) 121 | 122 | def forward(self, input): 123 | 124 | return self.model(input) 125 | 126 | class DSigNet(nn.Module): 127 | def __init__(self, options, bp_channel=4, num_filters=16, scale_factor=2) -> None: 128 | super().__init__() 129 | geo_real = {'nVoxelX': int(options[2]), 'sVoxelX': float(options[4]) * int(options[2]), 'dVoxelX': float(options[4]), 130 | 'nVoxelY': int(options[3]), 'sVoxelY': float(options[4]) * int(options[3]), 'dVoxelY': float(options[4]), 131 | 'nDetecU': int(options[1]), 
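                    # (editorial note, inferred from how options is indexed
                    # across this repo) options layout: [0]=views, [1]=dets,
                    # [2]/[3]=image width/height, [4]=pixel size dImg,
                    # [5]=detector pitch dDet, [7]=angular step dAng,
                    # [8]=source-to-rotation s2r, [9]=detector-to-rotation d2r.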
'sDetecU': float(options[5]) * int(options[1]), 'dDetecU': float(options[5]), 132 | 'offOriginX': 0.0, 'offOriginY': 0.0, 133 | 'views': int(options[0]), 'slices': 1, 134 | 'DSD': float(options[8]) + float(options[9]), 'DSO': float(options[8]), 'DOD': float(options[9]), 135 | 'start_angle': 0.0, 'end_angle': float(options[7]) * int(options[0]), 136 | 'mode': 'fanflat', 'extent': 1, # currently extent supports 1, 2, or 3. 137 | } 138 | geo_virtual = dict() 139 | geo_virtual.update({x: int(geo_real[x]/scale_factor) for x in ['views']}) 140 | geo_virtual.update({x: int(geo_real[x]/scale_factor) for x in ['nVoxelX', 'nVoxelY', 'nDetecU']}) 141 | geo_virtual.update({x: geo_real[x]/scale_factor for x in ['sVoxelX', 'sVoxelY', 'sDetecU', 'DSD', 'DSO', 'DOD', 'offOriginX', 'offOriginY']}) 142 | geo_virtual.update({x: geo_real[x] for x in ['dVoxelX', 'dVoxelY', 'dDetecU', 'slices', 'start_angle', 'end_angle', 'mode', 'extent']}) 143 | geo_virtual['indices'], geo_virtual['weights'] = PixelIndexCal_cuda(geo_virtual) 144 | self.SinoNet = SinoNet(bp_channel, num_filters) 145 | self.BackProjNet = BackProjNet(geo_virtual, bp_channel) 146 | self.SpatialNet = SpatialNet(bp_channel, num_filters) 147 | for module in self.modules(): 148 | if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear): 149 | nn.init.xavier_uniform_(module.weight) 150 | if module.bias is not None: 151 | module.bias.data.zero_() 152 | 153 | def forward(self, input): 154 | output = self.SinoNet(input) 155 | output = self.BackProjNet(output) 156 | output = self.SpatialNet(output) 157 | return output 158 | 159 | -------------------------------------------------------------------------------- /recon/models/FBPConvNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class DownBlock(nn.Module): 5 | def __init__(self, in_ch, out_ch, first_block=False): 6 | super(DownBlock, self).__init__() 7 | self.model = nn.Sequential( 8 | nn.Sequential(nn.Conv2d(1, in_ch, 3, 1, 1), nn.BatchNorm2d(in_ch), nn.ReLU(inplace=True)) if first_block else nn.MaxPool2d(2), 9 | nn.Conv2d(in_ch, out_ch, 3, 1, 1), 10 | nn.BatchNorm2d(out_ch), 11 | nn.ReLU(inplace=True), 12 | nn.Conv2d(out_ch, out_ch, 3, 1, 1), 13 | nn.BatchNorm2d(out_ch), 14 | nn.ReLU(inplace=True) 15 | ) 16 | 17 | def forward(self, x): 18 | out = self.model(x) 19 | return out 20 | 21 | class UpBlock(nn.Module): 22 | def __init__(self, in_ch, out_ch): 23 | super(UpBlock, self).__init__() 24 | self.pool = nn.Sequential( 25 | nn.ConvTranspose2d(in_ch, out_ch, 3, 2, 1, 1), 26 | nn.BatchNorm2d(out_ch), 27 | nn.ReLU(inplace=True), 28 | ) 29 | self.model = nn.Sequential( 30 | nn.Conv2d(in_ch, out_ch, 3, 1, 1), 31 | nn.BatchNorm2d(out_ch), 32 | nn.ReLU(inplace=True), 33 | nn.Conv2d(out_ch, out_ch, 3, 1, 1), 34 | nn.BatchNorm2d(out_ch), 35 | nn.ReLU(inplace=True) 36 | ) 37 | 38 | def forward(self, x, y): 39 | x_t = self.pool(x) 40 | x_in = torch.cat((x_t, y), dim=1) 41 | out = self.model(x_in) 42 | return out 43 | 44 | class FBPConvNet(nn.Module): 45 | def __init__(self): 46 | super(FBPConvNet, self).__init__() 47 | self.conv1 = DownBlock(64, 64, True) 48 | self.conv2 = DownBlock(64, 128) 49 | self.conv3 = DownBlock(128, 256) 50 | self.conv4 = DownBlock(256, 512) 51 | self.conv5 = DownBlock(512, 1024) 52 | self.conv4_t = UpBlock(1024, 512) 53 | self.conv3_t = UpBlock(512, 256) 54 | self.conv2_t = UpBlock(256, 128) 55 | self.conv1_t = UpBlock(128, 64) 56 | self.conv_last = nn.Conv2d(64, 1, 1, 1, 0) 57 | for module 
in self.modules(): 58 | if isinstance(module, nn.Conv2d): 59 | nn.init.normal_(module.weight, mean=0, std=0.01) 60 | if module.bias is not None: 61 | module.bias.data.zero_() 62 | if isinstance(module, nn.ConvTranspose2d): 63 | nn.init.normal_(module.weight, mean=0, std=0.01) 64 | if module.bias is not None: 65 | module.bias.data.zero_() 66 | if isinstance(module, nn.BatchNorm2d): 67 | module.weight.data.fill_(1) 68 | module.bias.data.zero_() 69 | 70 | def forward(self, x): 71 | # encoder 72 | x_1 = self.conv1(x) 73 | x_2 = self.conv2(x_1) 74 | x_3 = self.conv3(x_2) 75 | x_4 = self.conv4(x_3) 76 | x_5 = self.conv5(x_4) 77 | 78 | # decoder 79 | y_4 = self.conv4_t(x_5, x_4) 80 | y_3 = self.conv3_t(y_4, x_3) 81 | y_2 = self.conv2_t(y_3, x_2) 82 | y_1 = self.conv1_t(y_2, x_1) 83 | y = self.conv_last(y_1) 84 | out = y + x 85 | return out 86 | -------------------------------------------------------------------------------- /recon/models/FistaNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .LEARN import projector 5 | from .LEARN import projector_t 6 | 7 | 8 | class BasicBlock(nn.Module): 9 | """docstring for BasicBlock""" 10 | 11 | def __init__(self, options, features=32): 12 | super(BasicBlock, self).__init__() 13 | self.Sp = nn.Softplus() 14 | self.conv_D = nn.Conv2d(1, features, (3,3), stride=1, padding=1) 15 | self.conv_forward = nn.Sequential( 16 | nn.Conv2d(features, features, (3,3), stride=1, padding=1), 17 | nn.ReLU(inplace=True), 18 | nn.Conv2d(features, features, (3,3), stride=1, padding=1), 19 | nn.ReLU(inplace=True), 20 | nn.Conv2d(features, features, (3,3), stride=1, padding=1), 21 | nn.ReLU(inplace=True), 22 | nn.Conv2d(features, features, (3,3), stride=1, padding=1) 23 | ) 24 | self.conv_backward = nn.Sequential( 25 | nn.Conv2d(features, features, (3,3), stride=1, padding=1), 26 | nn.ReLU(inplace=True), 27 | nn.Conv2d(features, features, (3,3), stride=1, padding=1), 28 | nn.ReLU(inplace=True), 29 | nn.Conv2d(features, features, (3,3), stride=1, padding=1), 30 | nn.ReLU(inplace=True), 31 | nn.Conv2d(features, features, (3,3), stride=1, padding=1) 32 | ) 33 | self.conv_G = nn.Conv2d(features, 1, (3,3), stride=1, padding=1) 34 | self.options = nn.Parameter(options, requires_grad=False) 35 | self.projector = projector() 36 | self.projector_t = projector_t() 37 | 38 | def forward(self, x, y, W_inv, lambda_step, soft_thr): 39 | p_error = self.projector(x, self.options) - y 40 | x_error = self.projector_t(p_error, self.options) * W_inv 41 | x = x - self.Sp(lambda_step) * x_error 42 | x_input = x.clone() 43 | x_D = self.conv_D(x_input) 44 | x_forward = self.conv_forward(x_D) 45 | x_st = torch.mul(torch.sign(x_forward), F.relu(torch.abs(x_forward) - self.Sp(soft_thr))) 46 | x_backward = self.conv_backward(x_st) 47 | x_G = self.conv_G(x_backward) 48 | x_pred = F.relu(x_input + x_G) 49 | x_D_est = self.conv_backward(x_forward) 50 | symloss = x_D_est - x_D 51 | return x_pred, symloss, x_st 52 | 53 | class FistaNet(nn.Module): 54 | def __init__(self, LayerNo, options): 55 | super(FistaNet, self).__init__() 56 | self.LayerNo = LayerNo 57 | W_inv = self.normlize_weight(options) 58 | self.W_inv = nn.Parameter(W_inv, requires_grad=False) 59 | self.fcs = nn.ModuleList([BasicBlock(options, features=32) for i in range(LayerNo)]) 60 | for module in self.modules(): 61 | if isinstance(module, nn.Conv2d): 62 | nn.init.xavier_normal_(module.weight) 63 | if module.bias is not 
None: 64 | nn.init.constant_(module.bias, 0) 65 | elif isinstance(module, nn.BatchNorm2d): 66 | nn.init.constant_(module.weight, 1) 67 | nn.init.constant_(module.bias, 0) 68 | elif isinstance(module, nn.Linear): 69 | nn.init.normal_(module.weight, 0, 0.01) 70 | nn.init.constant_(module.bias, 0) 71 | # thresholding value 72 | self.w_theta = nn.Parameter(torch.Tensor([-0.5])) 73 | self.b_theta = nn.Parameter(torch.Tensor([-2])) 74 | # gradient step 75 | self.w_mu = nn.Parameter(torch.Tensor([-0.2])) 76 | self.b_mu = nn.Parameter(torch.Tensor([0.1])) 77 | # two-step update weight 78 | self.w_rho = nn.Parameter(torch.Tensor([0.5])) 79 | self.b_rho = nn.Parameter(torch.Tensor([0])) 80 | self.Sp = nn.Softplus() 81 | 82 | def normlize_weight(self, options): 83 | height = options[2].int().item() 84 | width = options[3].int().item() 85 | x0 = torch.ones(1, 1, height, width) 86 | p = projector()(x0.double().cuda().contiguous(), options.double().cuda()) 87 | W = projector_t()(p, options.double().cuda()) 88 | W_inv = 1 / (W + 1e-6) 89 | W_inv[W==0] = 0 90 | W_inv = W_inv.cpu().float() 91 | return W_inv 92 | 93 | def forward(self, x0, b): 94 | xold = x0 95 | y = xold 96 | layers_sym = [] 97 | layers_st = [] 98 | xnews = [] 99 | xnews.append(xold) 100 | for i in range(self.LayerNo): 101 | theta_ = self.w_theta * i + self.b_theta 102 | mu_ = self.w_mu * i + self.b_mu 103 | xnew, layer_sym, layer_st = self.fcs[i](y, b, self.W_inv, mu_, theta_) 104 | rho_ = (self.Sp(self.w_rho * i + self.b_rho) - self.Sp(self.b_rho)) / self.Sp(self.w_rho * i + self.b_rho) 105 | y = xnew + rho_ * (xnew - xold) 106 | xold = xnew 107 | xnews.append(xnew) 108 | layers_sym.append(layer_sym) 109 | layers_st.append(layer_st) 110 | layers_sym = torch.cat(layers_sym, dim=0) 111 | layers_st = torch.cat(layers_st, dim=0) 112 | return xnew, layers_sym, layers_st -------------------------------------------------------------------------------- /recon/models/FramingUNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Harr_wav(nn.Module): 5 | def __init__(self): 6 | super(Harr_wav, self).__init__() 7 | filter = [[[ 0.5, 0.5], [ 0.5, 0.5]], 8 | [[-0.5, 0.5], [-0.5, 0.5]], 9 | [[-0.5, -0.5], [ 0.5, 0.5]], 10 | [[ 0.5, -0.5], [-0.5, 0.5]]] 11 | weight = torch.tensor(filter).view(4, 1, 1, 2, 2) 12 | self.weight = nn.Parameter(weight, requires_grad=False) 13 | 14 | def forward(self, x): 15 | return nn.functional.conv3d(x.unsqueeze(1), self.weight, stride=(1,2,2)) 16 | 17 | class Harr_iwav_cat(nn.Module): 18 | def __init__(self): 19 | super(Harr_iwav_cat, self).__init__() 20 | filter_LL = [[ 0.5, 0.5], [ 0.5, 0.5]] 21 | filter_LH = [[-0.5, 0.5], [-0.5, 0.5]] 22 | filter_HL = [[-0.5, -0.5], [ 0.5, 0.5]] 23 | filter_HH= [[ 0.5, -0.5], [-0.5, 0.5]] 24 | weight_LL = torch.tensor(filter_LL).view(1,1,1,2,2) 25 | weight_LH = torch.tensor(filter_LH).view(1,1,1,2,2) 26 | weight_HL = torch.tensor(filter_HL).view(1,1,1,2,2) 27 | weight_HH = torch.tensor(filter_HH).view(1,1,1,2,2) 28 | self.weight_LL = nn.Parameter(weight_LL, requires_grad=False) 29 | self.weight_LH = nn.Parameter(weight_LH, requires_grad=False) 30 | self.weight_HL = nn.Parameter(weight_HL, requires_grad=False) 31 | self.weight_HH = nn.Parameter(weight_HH, requires_grad=False) 32 | 33 | def forward(self, x_LL, x, y): 34 | LL = nn.functional.conv_transpose3d(x_LL.unsqueeze(1), self.weight_LL, stride=(1,2,2)).squeeze(1) 35 | LH = nn.functional.conv_transpose3d(x[:,[1],...], self.weight_LH, 
stride=(1,2,2)).squeeze(1) 36 | HL = nn.functional.conv_transpose3d(x[:,[2],...], self.weight_HL, stride=(1,2,2)).squeeze(1) 37 | HH = nn.functional.conv_transpose3d(x[:,[3],...], self.weight_HH, stride=(1,2,2)).squeeze(1) 38 | out = torch.cat((LL,LH,HL,HH, y), dim=1) 39 | return out 40 | 41 | class ConvBlock(nn.Module): 42 | def __init__(self, in_ch, hid_ch, out_ch): 43 | super(ConvBlock, self).__init__() 44 | self.model = nn.Sequential( 45 | nn.Conv2d(in_ch, hid_ch, 3, 1, 1), 46 | nn.BatchNorm2d(hid_ch), 47 | nn.ReLU(inplace=True), 48 | nn.Conv2d(hid_ch, out_ch, 3, 1, 1), 49 | nn.BatchNorm2d(out_ch), 50 | nn.ReLU(inplace=True) 51 | ) 52 | 53 | def forward(self, x): 54 | return self.model(x) 55 | 56 | class FramingUNet(nn.Module): 57 | def __init__(self): 58 | super(FramingUNet, self).__init__() 59 | self.conv1 = ConvBlock(1, 64, 64) 60 | self.conv2 = ConvBlock(64, 128, 128) 61 | self.conv3 = ConvBlock(128, 256, 256) 62 | self.conv4 = ConvBlock(256, 512, 512) 63 | self.conv5 = ConvBlock(512, 1024, 512) 64 | self.conv4_t = ConvBlock(2560, 512, 256) 65 | self.conv3_t = ConvBlock(1280, 256, 128) 66 | self.conv2_t = ConvBlock(640, 128, 64) 67 | self.conv1_t = ConvBlock(320, 64, 64) 68 | self.conv_last = nn.Conv2d(64, 1, 3, 1, 1) 69 | self.downsample = Harr_wav() 70 | self.upsample = Harr_iwav_cat() 71 | for module in self.modules(): 72 | if isinstance(module, nn.Conv2d): 73 | nn.init.normal_(module.weight, mean=0, std=0.01) 74 | if module.bias is not None: 75 | module.bias.data.zero_() 76 | if isinstance(module, nn.BatchNorm2d): 77 | module.weight.data.fill_(1) 78 | module.bias.data.zero_() 79 | 80 | def forward(self, x0): 81 | x1 = self.conv1(x0) 82 | wav1 = self.downsample(x1) 83 | x2 = self.conv2(wav1[:,0,...]) 84 | wav2 = self.downsample(x2) 85 | x3 = self.conv3(wav2[:,0,...]) 86 | wav3 = self.downsample(x3) 87 | x4 = self.conv4(wav3[:,0,...]) 88 | wav4 = self.downsample(x4) 89 | x5 = self.conv5(wav4[:,0,...]) 90 | iwav4 = self.upsample(x5, wav4, x4) 91 | x4_t = self.conv4_t(iwav4) 92 | iwav3 = self.upsample(x4_t, wav3, x3) 93 | x3_t = self.conv3_t(iwav3) 94 | iwav2 = self.upsample(x3_t, wav2, x2) 95 | x2_t = self.conv2_t(iwav2) 96 | iwav1 = self.upsample(x2_t, wav1, x1) 97 | x1_t = self.conv1_t(iwav1) 98 | out = self.conv_last(x1_t) 99 | return out -------------------------------------------------------------------------------- /recon/models/HDNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .LEARN_FBP import FBP 4 | 5 | class DownBlock_2d(nn.Module): 6 | def __init__(self, in_ch, out_ch, first_block=False): 7 | super(DownBlock_2d, self).__init__() 8 | self.model = nn.Sequential( 9 | nn.Identity() if first_block else nn.Conv2d(in_ch, in_ch, 3, 2, 1), 10 | nn.Conv2d(in_ch, out_ch, 3, 1, 1), 11 | nn.BatchNorm2d(out_ch), 12 | nn.LeakyReLU(0.2, inplace=True), 13 | nn.Conv2d(out_ch, out_ch, 3, 1, 1), 14 | nn.BatchNorm2d(out_ch), 15 | nn.LeakyReLU(0.2, inplace=True) 16 | ) 17 | 18 | def forward(self, x): 19 | out = self.model(x) 20 | return out 21 | 22 | class UpBlock_2d(nn.Module): 23 | def __init__(self, in_ch, out_ch): 24 | super(UpBlock_2d, self).__init__() 25 | self.pool = nn.Sequential( 26 | nn.ConvTranspose2d(in_ch, out_ch, 3, 2, 1, 1), 27 | nn.BatchNorm2d(out_ch), 28 | nn.LeakyReLU(0.2, inplace=True) 29 | ) 30 | self.model = nn.Sequential( 31 | nn.Conv2d(in_ch, out_ch, 3, 1, 1), 32 | nn.BatchNorm2d(out_ch), 33 | nn.LeakyReLU(0.2, inplace=True), 34 | nn.Conv2d(out_ch, out_ch, 3, 1, 1), 35 | 
nn.BatchNorm2d(out_ch), 36 | nn.LeakyReLU(0.2, inplace=True) 37 | ) 38 | 39 | def forward(self, x, y): 40 | x_t = self.pool(x) 41 | x_in = torch.cat((x_t, y), dim=1) 42 | out = self.model(x_in) 43 | return out 44 | 45 | class UNet_2d(nn.Module): 46 | def __init__(self): 47 | super(UNet_2d, self).__init__() 48 | self.conv1 = DownBlock_2d(1, 64, True) 49 | self.conv2 = DownBlock_2d(64, 128) 50 | self.conv3 = DownBlock_2d(128, 256) 51 | self.conv4 = DownBlock_2d(256, 512) 52 | self.conv5 = DownBlock_2d(512, 1024) 53 | self.conv4_t = UpBlock_2d(1024, 512) 54 | self.conv3_t = UpBlock_2d(512, 256) 55 | self.conv2_t = UpBlock_2d(256, 128) 56 | self.conv1_t = UpBlock_2d(128, 64) 57 | self.conv_last = nn.Conv2d(64, 1, 1, 1, 0) 58 | 59 | def forward(self, x): 60 | # encoder 61 | x_1 = self.conv1(x) 62 | x_2 = self.conv2(x_1) 63 | x_3 = self.conv3(x_2) 64 | x_4 = self.conv4(x_3) 65 | x_5 = self.conv5(x_4) 66 | 67 | # decoder 68 | y_4 = self.conv4_t(x_5, x_4) 69 | y_3 = self.conv3_t(y_4, x_3) 70 | y_2 = self.conv2_t(y_3, x_2) 71 | y_1 = self.conv1_t(y_2, x_1) 72 | y = self.conv_last(y_1) 73 | out = y + x 74 | return out 75 | 76 | class HDNet_2d(nn.Module): 77 | def __init__(self, options): 78 | super(HDNet_2d, self).__init__() 79 | self.model = nn.Sequential(UNet_2d(), FBP(options), UNet_2d()) 80 | for module in self.modules(): 81 | if isinstance(module, nn.Conv2d): 82 | nn.init.normal_(module.weight, mean=0, std=0.01) 83 | if module.bias is not None: 84 | module.bias.data.zero_() 85 | if isinstance(module, nn.ConvTranspose2d): 86 | nn.init.normal_(module.weight, mean=0, std=0.01) 87 | if module.bias is not None: 88 | module.bias.data.zero_() 89 | if isinstance(module, nn.BatchNorm2d): 90 | module.weight.data.fill_(1) 91 | module.bias.data.zero_() 92 | 93 | def forward(self, y): 94 | out = self.model(y) 95 | return out 96 | -------------------------------------------------------------------------------- /recon/models/KSAE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class KSAE(nn.Module): 5 | def __init__(self, imgsize=256, hid_ch=1024, sparsity=100) -> None: 6 | super().__init__() 7 | self.encoder = nn.Sequential( 8 | nn.Linear(imgsize, hid_ch), 9 | nn.ReLU(inplace=True), 10 | nn.Linear(hid_ch, hid_ch), 11 | nn.ReLU(inplace=True), 12 | nn.Linear(hid_ch, hid_ch), 13 | nn.ReLU(inplace=True) 14 | ) 15 | self.decoder = nn.Sequential( 16 | nn.Linear(hid_ch, hid_ch), 17 | nn.ReLU(inplace=True), 18 | nn.Linear(hid_ch, hid_ch), 19 | nn.ReLU(inplace=True), 20 | nn.Linear(hid_ch, imgsize), 21 | ) 22 | self.sparsity = sparsity 23 | 24 | def forward(self, x): 25 | B, C, Ph, Pw, H, W = x.shape 26 | x_in = x.view(B*C*Ph*Pw, H*W).contiguous() 27 | feature = self.encoder(x_in) 28 | mask = torch.zeros_like(feature) 29 | _, indices = torch.topk(feature.detach(), self.sparsity, dim=-1, sorted=False) 30 | mask.scatter_(1, indices, 1.0) 31 | feature = feature * mask 32 | res = self.decoder(feature) 33 | out = res.view(B, C, Ph, Pw, H, W).contiguous() 34 | return out 35 | -------------------------------------------------------------------------------- /recon/models/LEARN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Function 4 | import ctlib 5 | 6 | class prj_fun(Function): 7 | @staticmethod 8 | def forward(self, image, options): 9 | self.save_for_backward(options) 10 | return ctlib.projection(image, options) 11 | 12 | 
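# (editorial note) forward/backward implement a matched projector/adjoint pair:
# forward computes p = A x via ctlib.projection, and backward maps dL/dp to
# A^T (dL/dp) via ctlib.projection_t, which is exactly the gradient of a linear
# operator. A minimal, hypothetical data-fidelity check built on this pair:
#
#     x = torch.zeros(1, 1, 256, 256, device='cuda', requires_grad=True)
#     p = projector()(x, options)          # p = A x
#     loss = 0.5 * ((p - y) ** 2).sum()    # y: measured sinogram
#     loss.backward()                      # x.grad == A^T (A x - y)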
@staticmethod 13 | def backward(self, grad_output): 14 | options = self.saved_tensors[0] 15 | grad_input = ctlib.projection_t(grad_output.contiguous(), options) 16 | return grad_input, None 17 | 18 | class prj_t_fun(Function): 19 | @staticmethod 20 | def forward(self, proj, options): 21 | self.save_for_backward(options) 22 | return ctlib.projection_t(proj, options) 23 | 24 | @staticmethod 25 | def backward(self, grad_output): 26 | options = self.saved_tensors[0] 27 | grad_input = ctlib.projection(grad_output.contiguous(), options) 28 | return grad_input, None 29 | 30 | class projector(nn.Module): 31 | def __init__(self): 32 | super(projector, self).__init__() 33 | 34 | def forward(self, image, options): 35 | return prj_fun.apply(image, options) 36 | 37 | class projector_t(nn.Module): 38 | def __init__(self): 39 | super(projector_t, self).__init__() 40 | 41 | def forward(self, proj, options): 42 | return prj_t_fun.apply(proj, options) 43 | 44 | class fidelity_module(nn.Module): 45 | def __init__(self, options): 46 | super(fidelity_module, self).__init__() 47 | self.options = nn.Parameter(options, requires_grad=False) 48 | self.weight = nn.Parameter(torch.Tensor(1).squeeze()) 49 | self.projector = projector() 50 | self.projector_t = projector_t() 51 | 52 | def forward(self, input_data, proj): 53 | temp = self.projector(input_data, self.options) - proj 54 | intervening_res = self.projector_t(temp, self.options) 55 | out = input_data - self.weight * intervening_res 56 | return out 57 | 58 | class Iter_block(nn.Module): 59 | def __init__(self, hid_channels, kernel_size, padding, options): 60 | super(Iter_block, self).__init__() 61 | self.block1 = fidelity_module(options) 62 | self.block2 = nn.Sequential( 63 | nn.Conv2d(1, hid_channels, kernel_size=kernel_size, padding=padding), 64 | nn.ReLU(inplace=True), 65 | nn.Conv2d(hid_channels, hid_channels, kernel_size=kernel_size, padding=padding), 66 | nn.ReLU(inplace=True), 67 | nn.Conv2d(hid_channels, 1, kernel_size=kernel_size, padding=padding) 68 | ) 69 | self.relu = nn.ReLU(inplace=True) 70 | 71 | def forward(self, input_data, proj): 72 | tmp1 = self.block1(input_data, proj) 73 | tmp2 = self.block2(input_data) 74 | output = tmp1 + tmp2 75 | output = self.relu(output) 76 | return output 77 | 78 | class LEARN(nn.Module): 79 | def __init__(self, options, block_num=50, hid_channels=48, kernel_size=5, padding=2): 80 | super(LEARN, self).__init__() 81 | self.model = nn.ModuleList([Iter_block(hid_channels, kernel_size, padding, options) for i in range(block_num)]) 82 | for module in self.modules(): 83 | if isinstance(module, fidelity_module): 84 | module.weight.data.zero_() 85 | if isinstance(module, nn.Conv2d): 86 | nn.init.normal_(module.weight, mean=0, std=0.01) 87 | if module.bias is not None: 88 | module.bias.data.zero_() 89 | 90 | def forward(self, input_data, proj): 91 | x = input_data 92 | for index, module in enumerate(self.model): 93 | x = module(x, proj) 94 | return x 95 | -------------------------------------------------------------------------------- /recon/models/LEARN_FBP.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Function 4 | import ctlib 5 | from .LEARN import projector 6 | 7 | class bprj_fun(Function): 8 | @staticmethod 9 | def forward(self, proj, options): 10 | self.save_for_backward(options) 11 | return ctlib.backprojection(proj, options) 12 | 13 | @staticmethod 14 | def backward(self, grad_output): 15 | options = 
self.saved_tensors[0] 16 | grad_input = ctlib.backprojection_t(grad_output.contiguous(), options) 17 | return grad_input, None 18 | 19 | class backprojector(nn.Module): 20 | def __init__(self): 21 | super(backprojector, self).__init__() 22 | 23 | def forward(self, image, options): 24 | return bprj_fun.apply(image, options) 25 | 26 | class FBP(nn.Module): 27 | def __init__(self, options): 28 | super(FBP, self).__init__() 29 | dets = int(options[1]) 30 | dDet = options[5] 31 | s2r = options[8] 32 | d2r = options[9] 33 | virdet = dDet * s2r / (s2r + d2r) 34 | filter = torch.empty(2 * dets - 1) 35 | pi = torch.acos(torch.tensor(-1.0)) 36 | for i in range(filter.size(0)): 37 | x = i - dets + 1 38 | if abs(x) % 2 == 1: 39 | filter[i] = -1 / (pi * pi * x * x * virdet * virdet) 40 | elif x == 0: 41 | filter[i] = 1 / (4 * virdet * virdet) 42 | else: 43 | filter[i] = 0 44 | filter = filter.view(1,1,1,-1) 45 | w = torch.arange((-dets / 2 + 0.5) * virdet, dets / 2 * virdet, virdet) 46 | w = s2r / torch.sqrt(s2r ** 2 + w ** 2) 47 | w = w.view(1,1,1,-1) * virdet 48 | self.w = nn.Parameter(w, requires_grad=False) 49 | self.filter = nn.Parameter(filter, requires_grad=False) 50 | self.options = nn.Parameter(options, requires_grad=False) 51 | self.backprojector = backprojector() 52 | self.dets = dets 53 | self.coef = pi / options[0] 54 | 55 | def forward(self, projection): 56 | p = projection * self.w 57 | p = torch.nn.functional.conv2d(p, self.filter, padding=(0,self.dets-1)) 58 | recon = self.backprojector(p, self.options) 59 | recon = recon * self.coef 60 | return recon 61 | 62 | class fidelity_module(nn.Module): 63 | def __init__(self, options): 64 | super(fidelity_module, self).__init__() 65 | self.options = nn.Parameter(options, requires_grad=False) 66 | self.weight = nn.Parameter(torch.Tensor(1).squeeze()) 67 | self.projector = projector() 68 | self.fbp = FBP(options) 69 | 70 | def forward(self, input_data, proj): 71 | p_tmp = self.projector(input_data, self.options) 72 | y_error = proj - p_tmp 73 | x_error = self.fbp(y_error) 74 | out = self.weight * x_error + input_data 75 | return out 76 | 77 | class Iter_block(nn.Module): 78 | def __init__(self, hid_channels, kernel_size, padding, options): 79 | super(Iter_block, self).__init__() 80 | self.block1 = fidelity_module(options) 81 | self.block2 = nn.Sequential( 82 | nn.Conv2d(1, hid_channels, kernel_size=kernel_size, padding=padding), 83 | nn.ReLU(inplace=True), 84 | nn.Conv2d(hid_channels, hid_channels, kernel_size=kernel_size, padding=padding), 85 | nn.ReLU(inplace=True), 86 | nn.Conv2d(hid_channels, 1, kernel_size=kernel_size, padding=padding) 87 | ) 88 | self.relu = nn.ReLU(inplace=True) 89 | 90 | def forward(self, input_data, proj): 91 | tmp1 = self.block1(input_data, proj) 92 | tmp2 = self.block2(input_data) 93 | output = tmp1 + tmp2 94 | output = self.relu(output) 95 | return output 96 | 97 | class LEARN_FBP(nn.Module): 98 | def __init__(self, options, block_num=50, hid_channels=48, kernel_size=5, padding=2): 99 | super(LEARN_FBP, self).__init__() 100 | self.model = nn.ModuleList([Iter_block(hid_channels, kernel_size, padding, options) for i in range(block_num)]) 101 | for module in self.modules(): 102 | if isinstance(module, fidelity_module): 103 | module.weight.data.zero_() 104 | if isinstance(module, nn.Conv2d): 105 | nn.init.normal_(module.weight, mean=0, std=0.01) 106 | if module.bias is not None: 107 | module.bias.data.zero_() 108 | 109 | def forward(self, input_data, proj): 110 | x = input_data 111 | for index, module in 
enumerate(self.model): 112 | x = module(x, proj) 113 | return x -------------------------------------------------------------------------------- /recon/models/LPD.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .LEARN import projector 4 | from .LEARN import projector_t 5 | 6 | class primal_module(nn.Module): 7 | def __init__(self, n_primal, hid_channels, kernel_size, padding, options): 8 | super(primal_module, self).__init__() 9 | self.model = nn.Sequential( 10 | nn.Conv2d(n_primal+1, hid_channels, kernel_size=kernel_size, padding=padding), 11 | nn.PReLU(), 12 | nn.Conv2d(hid_channels, hid_channels, kernel_size=kernel_size, padding=padding), 13 | nn.PReLU(), 14 | nn.Conv2d(hid_channels, n_primal, kernel_size=kernel_size, padding=padding), 15 | ) 16 | self.options = nn.Parameter(options, requires_grad=False) 17 | self.projector_t = projector_t() 18 | 19 | def forward(self, x, h): 20 | t = self.projector_t(h, self.options) 21 | inputs = torch.cat((x, t), dim=1) 22 | return x + self.model(inputs) 23 | 24 | class dual_module(nn.Module): 25 | def __init__(self, n_dual, hid_channels, kernel_size, padding, options): 26 | super(dual_module, self).__init__() 27 | self.model = nn.Sequential( 28 | nn.Conv2d(n_dual+2, hid_channels, kernel_size=kernel_size, padding=padding), 29 | nn.PReLU(), 30 | nn.Conv2d(hid_channels, hid_channels, kernel_size=kernel_size, padding=padding), 31 | nn.PReLU(), 32 | nn.Conv2d(hid_channels, n_dual, kernel_size=kernel_size, padding=padding), 33 | ) 34 | self.options = nn.Parameter(options, requires_grad=False) 35 | self.projector = projector() 36 | 37 | def forward(self, x, y, h): 38 | t = self.projector(x, self.options) 39 | inputs = torch.cat((h,t,y), dim=1) 40 | return h + self.model(inputs) 41 | 42 | class Learned_primal_dual(nn.Module): 43 | def __init__(self, options, n_iter=10, n_primal=5, n_dual=5, hid_channels=32, kernel_size=3, padding=1): 44 | super(Learned_primal_dual, self).__init__() 45 | self.primal_models = nn.ModuleList([primal_module(n_primal, hid_channels, kernel_size, padding, options) for i in range(n_iter)]) 46 | self.dual_models = nn.ModuleList([dual_module(n_dual, hid_channels, kernel_size, padding, options) for i in range(n_iter)]) 47 | self.n_iter = n_iter 48 | self.n_primal = n_primal 49 | self.n_dual = n_dual 50 | for module in self.modules(): 51 | if isinstance(module, nn.Conv2d): 52 | nn.init.normal_(module.weight, mean=0, std=0.01) 53 | if module.bias is not None: 54 | module.bias.data.zero_() 55 | 56 | def forward(self, x0, y): 57 | h0 = torch.zeros(y.size(0), self.n_dual, y.size(2), y.size(3), device=y.device) 58 | x0 = x0.expand(x0.size(0), self.n_primal, x0.size(2), x0.size(3)) 59 | for i in range(self.n_iter): 60 | h = self.dual_models[i](x0[:,[1],:,:], y, h0) 61 | x = self.primal_models[i](x0, h[:,[0],:,:]) 62 | x0 = x 63 | h0 = h 64 | return x[:,[0],:,:] -------------------------------------------------------------------------------- /recon/models/MAGIC.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Function 4 | import ctlib 5 | from .LEARN import fidelity_module 6 | 7 | class adj_weight(nn.Module): 8 | def __init__(self, k): 9 | super(adj_weight, self).__init__() 10 | self.k = k 11 | 12 | def forward(self, x): 13 | return ctlib.laplacian(x, self.k) 14 | 15 | def img2patch(x, patch_size, stride): 16 | x_size = x.size() 17 | Ph = 
x_size[-2]-patch_size+1
 18 |     Pw = x_size[-1]-patch_size+1
 19 |     y = torch.empty(*x_size[:-2], Ph, Pw, patch_size, patch_size, device=x.device)
 20 |     for i in range(patch_size):
 21 |         for j in range(patch_size):
 22 |             y[...,i,j] = x[...,i:i+Ph,j:j+Pw]
 23 |     return y[...,::stride,::stride,:,:]
 24 | 
 25 | def patch2img(y, patch_size, stride, x_size):
 26 |     Ph = x_size[-2]-patch_size+1
 27 |     Pw = x_size[-1]-patch_size+1
 28 |     y_tmp = torch.zeros(*x_size[:-2], Ph, Pw, patch_size, patch_size, device=y.device)
 29 |     y_tmp[...,::stride,::stride,:,:] = y
 30 |     x = torch.zeros(*x_size, device=y.device)
 31 |     for i in range(patch_size):
 32 |         for j in range(patch_size):
 33 |             x[...,i:i+Ph,j:j+Pw] += y_tmp[...,i,j]
 34 |     return x
 35 | 
 36 | class img2patch_fun(Function):
 37 | 
 38 |     @staticmethod
 39 |     def forward(self, x, size):
 40 |         self.save_for_backward(size)
 41 |         patch_size = size[0]
 42 |         stride = size[1]
 43 |         p_size = size[5:]
 44 |         y = img2patch(x, patch_size, stride)
 45 |         out = y.reshape(y.size(0), p_size[1]*p_size[2], p_size[3]*p_size[4])
 46 |         return out
 47 | 
 48 |     @staticmethod
 49 |     def backward(self, grad_output):
 50 |         size = self.saved_tensors[0]
 51 |         patch_size = size[0]
 52 |         stride = size[1]
 53 |         x_size = size[2:5]
 54 |         p_size = size[5:]
 55 |         y = grad_output.view(grad_output.size(0), *p_size)
 56 |         grad_input = patch2img(y, patch_size, stride, (grad_output.size(0), *x_size))
 57 |         return grad_input, None
 58 | 
 59 | class patch2img_fun(Function):
 60 | 
 61 |     @staticmethod
 62 |     def forward(self, x, size):
 63 |         self.save_for_backward(size)
 64 |         patch_size = size[0]
 65 |         stride = size[1]
 66 |         x_size = size[2:5]
 67 |         p_size = size[5:]
 68 |         y = x.view(x.size(0), *p_size)
 69 |         out = patch2img(y, patch_size, stride, (x.size(0), *x_size))
 70 |         return out
 71 | 
 72 |     @staticmethod
 73 |     def backward(self, grad_output):
 74 |         size = self.saved_tensors[0]
 75 |         patch_size = size[0]
 76 |         stride = size[1]
 77 |         p_size = size[5:]
 78 |         y = img2patch(grad_output, patch_size, stride)
 79 |         grad_input = y.reshape(grad_output.size(0), p_size[1]*p_size[2], p_size[3]*p_size[4])
 80 |         return grad_input, None
 81 | 
 82 | class Im2Patch(nn.Module):
 83 |     def __init__(self, patch_size, stride, img_size) -> None:
 84 |         super(Im2Patch, self).__init__()
 85 |         Ph = (img_size-patch_size) // stride + 1
 86 |         Pw = (img_size-patch_size) // stride + 1
 87 |         self.size = torch.LongTensor([patch_size, stride, 1, img_size, img_size, 1, Ph, Pw, patch_size, patch_size])
 88 | 
 89 |     def forward(self, x):
 90 |         return img2patch_fun.apply(x, self.size)
 91 | 
 92 | class Patch2Im(nn.Module):
 93 |     def __init__(self, patch_size, stride, img_size) -> None:
 94 |         super(Patch2Im, self).__init__()
 95 |         Ph = (img_size-patch_size) // stride + 1
 96 |         Pw = (img_size-patch_size) // stride + 1
 97 |         self.size = torch.LongTensor([patch_size, stride, 1, img_size, img_size, 1, Ph, Pw, patch_size, patch_size])
 98 |         m = torch.ones(1, Ph * Pw, patch_size ** 2)
 99 |         mask = patch2img_fun.apply(m, self.size)
100 |         self.mask = nn.Parameter(mask, requires_grad=False)
101 | 
102 |     def forward(self, x):
103 |         y = patch2img_fun.apply(x, self.size)
104 |         out = y / self.mask
105 |         return out
106 | 
107 | class GCN(nn.Module):
108 |     def __init__(self, in_channels, out_channels):
109 |         super(GCN, self).__init__()
110 |         self.weight = nn.Parameter(torch.FloatTensor(in_channels, out_channels))
111 |         self.bias = nn.Parameter(torch.FloatTensor(out_channels))
112 | 
113 |     def forward(self, x, adj):
114 |         t = x.view(-1, x.size(2))
115 |         support = torch.mm(t, self.weight)
116 |         support = support.view(x.size(0), 
x.size(1), -1) 117 | out = torch.zeros_like(support) 118 | for i in range(x.size(0)): 119 | out[i] = torch.mm(adj[i], support[i]) 120 | out = out + self.bias 121 | return out 122 | 123 | 124 | class Iter_block(nn.Module): 125 | def __init__(self, hid_channels, kernel_size, padding, img_size, p_size, stride, gcn_hid_ch, options): 126 | super(Iter_block, self).__init__() 127 | self.block1 = fidelity_module(options) 128 | self.block2 = nn.Sequential( 129 | nn.Conv2d(1, hid_channels, kernel_size=kernel_size, padding=padding), 130 | nn.ReLU(inplace=True), 131 | nn.Conv2d(hid_channels, hid_channels, kernel_size=kernel_size, padding=padding), 132 | nn.ReLU(inplace=True), 133 | nn.Conv2d(hid_channels, 1, kernel_size=kernel_size, padding=padding) 134 | ) 135 | self.block3 = GCN(p_size**2, gcn_hid_ch) 136 | self.block4 = GCN(gcn_hid_ch, p_size**2) 137 | self.image2patch = Im2Patch(p_size, stride, img_size) 138 | self.patch2image = Patch2Im(p_size, stride, img_size) 139 | self.relu = nn.ReLU(inplace=True) 140 | 141 | def forward(self, input_data, proj, adj): 142 | tmp1 = self.block1(input_data, proj) 143 | tmp2 = self.block2(input_data) 144 | patch = self.image2patch(input_data) 145 | tmp3 = self.relu(self.block3(patch, adj)) 146 | tmp3 = self.block4(tmp3, adj) 147 | tmp3 = self.patch2image(tmp3) 148 | output = tmp1 + tmp2 + tmp3 149 | output = self.relu(output) 150 | return output 151 | 152 | class MAGIC(nn.Module): 153 | def __init__(self, options, block_num=50, hid_channels=64, kernel_size=5, padding=2, img_size=256, p_size=6, stride=2, gcn_hid_ch=64, k=9): 154 | super(MAGIC, self).__init__() 155 | self.block1 = nn.ModuleList([Iter_block(hid_channels, kernel_size, padding, img_size, p_size, stride, gcn_hid_ch, options) for i in range(block_num//2)]) 156 | self.block2 = nn.ModuleList([Iter_block(hid_channels, kernel_size, padding, img_size, p_size, stride, gcn_hid_ch, options) for i in range(block_num//2)]) 157 | self.adj_weight = adj_weight(k) 158 | self.image2patch = Im2Patch(p_size, stride, img_size) 159 | for module in self.modules(): 160 | if isinstance(module, fidelity_module): 161 | module.weight.data.zero_() 162 | if isinstance(module, nn.Conv2d): 163 | nn.init.normal_(module.weight, mean=0, std=0.01) 164 | if module.bias is not None: 165 | module.bias.data.zero_() 166 | if isinstance(module, GCN): 167 | nn.init.normal_(module.weight, mean=0, std=0.01) 168 | module.bias.data.zero_() 169 | 170 | def forward(self, input_data, proj): 171 | x = input_data 172 | patch1 = self.image2patch(x) 173 | adj1 = [] 174 | for i in range(input_data.size(0)): 175 | adj1.append(self.adj_weight(patch1[i])) 176 | for index, module in enumerate(self.block1): 177 | x = module(x, proj, adj1) 178 | adj2 = [] 179 | patch2 = self.image2patch(x) 180 | for i in range(input_data.size(0)): 181 | adj2.append(self.adj_weight(patch2[i])) 182 | for index, module in enumerate(self.block2): 183 | x = module(x, proj, adj2) 184 | return x 185 | -------------------------------------------------------------------------------- /recon/models/MetaInvNet.py: -------------------------------------------------------------------------------- 1 | 2 | # sub-parts of the U-Net model 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import numpy as np 7 | import ctlib 8 | 9 | class double_conv(nn.Module): 10 | '''(conv => BN => ReLU) * 2''' 11 | def __init__(self, in_ch, out_ch): 12 | super(double_conv, self).__init__() 13 | self.conv = nn.Sequential( 14 | nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=True), 15 
| nn.BatchNorm2d(out_ch, eps=0.0001, momentum = 0.95, track_running_stats=False), 16 | nn.ReLU(inplace=True), 17 | nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=True), 18 | nn.BatchNorm2d(out_ch, eps=0.0001, momentum = 0.95, track_running_stats=False), 19 | nn.ReLU(inplace=True)) 20 | 21 | def forward(self, x): 22 | x = self.conv(x) 23 | return x 24 | 25 | 26 | class inconv(nn.Module): 27 | def __init__(self, in_ch, out_ch): 28 | super(inconv, self).__init__() 29 | self.conv = double_conv(in_ch, out_ch) 30 | 31 | def forward(self, x): 32 | x = self.conv(x) 33 | return x 34 | 35 | 36 | class down(nn.Module): 37 | def __init__(self, in_ch, out_ch): 38 | super(down, self).__init__() 39 | self.mpconv = nn.Sequential( 40 | nn.MaxPool2d(2), 41 | double_conv(in_ch, out_ch)) 42 | 43 | def forward(self, x): 44 | x = self.mpconv(x) 45 | return x 46 | 47 | 48 | class up(nn.Module): 49 | def __init__(self, in_ch, out_ch, bilinear=True): 50 | super(up, self).__init__() 51 | if bilinear: 52 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 53 | else: 54 | self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2) 55 | 56 | self.conv = double_conv(in_ch, out_ch) 57 | 58 | def forward(self, x1, x2): 59 | x1 = self.up(x1) 60 | 61 | # input is CHW 62 | diffY = x2.size()[2] - x1.size()[2] 63 | diffX = x2.size()[3] - x1.size()[3] 64 | 65 | x1 = F.pad(x1, (diffX // 2, diffX - diffX//2, diffY // 2, diffY - diffY//2)) 66 | 67 | x = torch.cat([x2, x1], dim=1) 68 | x = self.conv(x) 69 | return x 70 | 71 | class outconv(nn.Module): 72 | def __init__(self, in_ch, out_ch): 73 | super(outconv, self).__init__() 74 | self.conv = nn.Conv2d(in_ch, out_ch, 1) 75 | 76 | def forward(self, x): 77 | x = self.conv(x) 78 | return x 79 | 80 | class UNet(nn.Module): 81 | def __init__(self, n_channels, n_classes): 82 | super(UNet, self).__init__() 83 | self.inc = inconv(n_channels, 64) 84 | self.down1 = down(64, 128) 85 | self.down2 = down(128, 256) 86 | self.down3 = down(256, 512) 87 | self.down4 = down(512, 512) 88 | self.up1 = up(1024, 256) 89 | self.up2 = up(512, 128) 90 | self.up3 = up(256, 64) 91 | self.up4 = up(128, 64) 92 | self.outc = outconv(64, n_classes) 93 | 94 | def forward(self, x_in): 95 | x1 = self.inc(x_in) 96 | x2 = self.down1(x1) 97 | x3 = self.down2(x2) 98 | x4 = self.down3(x3) 99 | x5 = self.down4(x4) 100 | x = self.up1(x5, x4) 101 | x = self.up2(x, x3) 102 | x = self.up3(x, x2) 103 | x = self.up4(x, x1) 104 | x = x_in+self.outc(x) 105 | return x 106 | 107 | class Wavelet(nn.Module): 108 | def __init__(self) -> None: 109 | super().__init__() 110 | D,R=self.GenerateFrameletFilter(frame=1) 111 | D_tmp=torch.zeros(3,1,3,1) 112 | for ll in range(3): 113 | D_tmp[ll,]=torch.from_numpy(np.reshape(D[ll],(-1,1))) 114 | W=D_tmp 115 | W2=W.permute(0,1,3,2) 116 | kernel_dec=np.kron(W.numpy(),W2.numpy()) 117 | kernel_dec=torch.tensor(kernel_dec,requires_grad=False,dtype=torch.float32) 118 | R_tmp=torch.zeros(3,1,1,3) 119 | for ll in range(3): 120 | R_tmp[ll,]=torch.from_numpy(np.reshape(R[ll],(1,-1))) 121 | R=R_tmp 122 | R2=R_tmp.permute(0,1,3,2) 123 | kernel_rec=np.kron(R2.numpy(),R.numpy()) 124 | kernel_rec=torch.tensor(kernel_rec,requires_grad=False,dtype=torch.float32).view(1,9,3,3) 125 | self.kernel_dec = nn.Parameter(kernel_dec, requires_grad=False) 126 | self.kernel_rec = nn.Parameter(kernel_rec, requires_grad=False) 127 | 128 | def GenerateFrameletFilter(self, frame): 129 | # Haar Wavelet 130 | if frame==0: 131 | D1=np.array([0.0, 1.0, 1.0] )/2 132 | D2=np.array([0.0, 1, -1])/2 133 | 
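# (editorial note) the string entries ('cc', 'ccc', 'ccccc') closing each
# filter list are boundary-handling placeholders carried over from the original
# framelet code and are never used numerically; Wavelet.__init__ above calls
# GenerateFrameletFilter(frame=1) (piecewise-linear framelet) and only reads
# the numeric filters D[0:3] / R[0:3] to build 2-D kernels as Kronecker
# products of the 1-D filters.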
D3=('cc') 134 | R1=np.array([1 , 1 ,0])/2 135 | R2=np.array([-1, 1, 0])/2 136 | R3=('cc') 137 | D=[D1,D2,D3] 138 | R=[R1,R2,R3] 139 | # Piecewise Linear Framelet 140 | elif frame==1: 141 | D1=np.array([1.0, 2, 1])/4 142 | D2=np.array([1, 0, -1])/4*np.sqrt(2) 143 | D3=np.array([-1 ,2 ,-1])/4 144 | D4='ccc' 145 | R1=np.array([1, 2, 1])/4 146 | R2=np.array([-1, 0, 1])/4*np.sqrt(2) 147 | R3=np.array([-1, 2 ,-1])/4 148 | R4='ccc' 149 | D=[D1,D2,D3,D4] 150 | R=[R1,R2,R3,R4] 151 | # Piecewise Cubic Framelet 152 | elif frame==3: 153 | D1=np.array([1, 4 ,6, 4, 1])/16 154 | D2=np.array([1 ,2 ,0 ,-2, -1])/8 155 | D3=np.array([-1, 0 ,2 ,0, -1])/16*np.sqrt(6) 156 | D4=np.array([-1 ,2 ,0, -2, 1])/8 157 | D5=np.array([1, -4 ,6, -4, 1])/16 158 | D6='ccccc' 159 | R1=np.array([1 ,4, 6, 4 ,1])/16 160 | R2=np.array([-1, -2, 0, 2, 1])/8 161 | R3=np.array([-1, 0 ,2, 0, -1])/16*np.sqrt(6) 162 | R4=np.array([1 ,-2, 0, 2, -1])/8 163 | R5=np.array([1, -4, 6, -4 ,1])/16 164 | R6='ccccc' 165 | D=[D1,D2,D3,D4,D5,D6] 166 | R=[R1,R2,R3,R4,R5,R6] 167 | return D,R 168 | 169 | def W(self, img): 170 | Dec_coeff=F.conv2d(F.pad(img, (1,1,1,1), mode='circular'), self.kernel_dec[1:,...]) 171 | return Dec_coeff 172 | 173 | def Wt(self, Dec_coeff): 174 | kernel_rec=self.kernel_rec.view(9,1,3,3) 175 | tem_coeff=F.conv2d(F.pad(Dec_coeff, (1,1,1,1), mode='circular'), kernel_rec[1:,:,...],groups=8) 176 | rec_img=torch.sum(tem_coeff,dim=1,keepdim=True) 177 | return rec_img 178 | 179 | class MetaInvH(nn.Module): 180 | '''MetaInvNet with heavy weight CG-Init''' 181 | def __init__(self, options): 182 | super(MetaInvH, self).__init__() 183 | self.CGModule = CGClass(options) 184 | 185 | def forward(self, x, sino, laam, miu, CGInitCNN): 186 | Wu=self.CGModule.W(x) 187 | dnz=F.relu(Wu-laam)-F.relu(-Wu-laam) 188 | PtY=ctlib.projection_t(sino, self.CGModule.options) 189 | muWtV=self.CGModule.Wt(dnz) 190 | rhs=PtY+muWtV*miu 191 | 192 | uk0=CGInitCNN(x) 193 | Ax0=self.CGModule.AWx(uk0,miu) 194 | res=Ax0-rhs 195 | img=self.CGModule.CG_alg(uk0, miu, res, CGiter=5) 196 | return img 197 | 198 | class CGClass(nn.Module): 199 | def __init__(self, options): 200 | super().__init__() 201 | self.options = nn.Parameter(options, requires_grad=False) 202 | self.Wavelet = Wavelet() 203 | 204 | def AWx(self,img,mu): 205 | Ax = ctlib.projection(img, self.options) 206 | AtAx = ctlib.projection_t(Ax, self.options) 207 | Ax0 = AtAx + self.Wt(self.W(img))*mu 208 | return Ax0 209 | 210 | def W(self,img): 211 | return self.Wavelet.W(img) 212 | 213 | def Wt(self,Wu): 214 | return self.Wavelet.Wt(Wu) 215 | 216 | def pATAp(self,img): 217 | Ap=ctlib.projection(img, self.options) 218 | pATApNorm=torch.sum(Ap**2,dim=(1,2,3), keepdim=True) 219 | return pATApNorm 220 | 221 | def pWTWp(self,img,mu): 222 | Wp=self.W(img) 223 | mu_Wp=mu*(Wp**2) 224 | pWTWpNorm=torch.sum(mu_Wp,dim=(1,2,3), keepdim=True) 225 | return pWTWpNorm 226 | 227 | def CG_alg(self,x,mu,res,CGiter=20): 228 | r=res 229 | p=-res 230 | for k in range(CGiter): 231 | pATApNorm = self.pATAp(p) 232 | mu_pWtWpNorm=self.pWTWp(p,mu) 233 | rTr=torch.sum(r**2,dim=(1,2,3), keepdim=True) 234 | alphak = rTr / (mu_pWtWpNorm+pATApNorm) 235 | x = x+alphak*p 236 | r = r+alphak*self.AWx(p,mu) 237 | betak = torch.sum(r**2,dim=(1,2,3), keepdim=True)/ rTr 238 | p=-r+betak*p 239 | 240 | pATApNorm = self.pATAp(p) 241 | mu_pWtWpNorm=self.pWTWp(p,mu) 242 | rTr=torch.sum(r**2,dim=(1,2,3), keepdim=True) 243 | alphak = rTr/(mu_pWtWpNorm+pATApNorm) 244 | x = x+alphak*p 245 | return x 246 | 247 | class MetaInvNet_H(nn.Module): 248 | def 
__init__(self, options, layers = 3, InitNet = MetaInvH): 249 | super(MetaInvNet_H,self).__init__() 250 | self.layers = layers 251 | self.net = nn.ModuleList([InitNet(options) for i in range(self.layers+1)]) 252 | self.CGInitCNN=UNet(n_channels=1, n_classes=1) 253 | 254 | def forward(self, fbpu, sino): 255 | img_list = [None] * (self.layers + 1) 256 | laam=0.05 257 | miu=0.01 258 | img_list[0] = self.net[0](fbpu.detach(), sino.detach(), laam, miu, self.CGInitCNN) 259 | inc_lam, inc_miu=0.0008, 0.02 260 | for i in range(self.layers): 261 | laam=laam-inc_lam 262 | miu=miu+inc_miu 263 | img_list[i+1] = self.net[i+1](img_list[i], sino.detach(), laam, miu, self.CGInitCNN) 264 | return img_list -------------------------------------------------------------------------------- /recon/models/MomentumNet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class dCNN(nn.Module): 4 | def __init__(self): 5 | super(dCNN, self).__init__() 6 | self.model = nn.Sequential( 7 | nn.Conv2d(1, 64, 3, 1, 1), 8 | nn.ReLU(inplace=True), 9 | nn.Conv2d(64, 64, 3, 1, 1), 10 | nn.ReLU(inplace=True), 11 | nn.Conv2d(64, 64, 3, 1, 1), 12 | nn.ReLU(inplace=True), 13 | nn.Conv2d(64, 1, 3, 1, 1), 14 | ) 15 | for module in self.modules(): 16 | if isinstance(module, nn.Conv2d): 17 | nn.init.normal_(module.weight, 0, 0.01) 18 | if module.bias is not None: 19 | module.bias.data.zero_() 20 | 21 | def forward(self, x): 22 | return x + self.model(x) 23 | -------------------------------------------------------------------------------- /recon/models/RED_CNN.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class RED_CNN(nn.Module): 4 | def __init__(self, hid_channels=48, kernel_size=5, padding=2): 5 | super(RED_CNN, self).__init__() 6 | self.conv_1 = nn.Conv2d(1, hid_channels, kernel_size=kernel_size, padding=padding) 7 | self.conv_2 = nn.Conv2d(hid_channels, hid_channels, kernel_size=kernel_size, padding=padding) 8 | self.conv_3 = nn.Conv2d(hid_channels, hid_channels, kernel_size=kernel_size, padding=padding) 9 | self.conv_4 = nn.Conv2d(hid_channels, hid_channels, kernel_size=kernel_size, padding=padding) 10 | self.conv_5 = nn.Conv2d(hid_channels, hid_channels, kernel_size=kernel_size, padding=padding) 11 | self.conv_t_1 = nn.ConvTranspose2d(hid_channels, hid_channels, kernel_size=kernel_size, padding=padding) 12 | self.conv_t_2 = nn.ConvTranspose2d(hid_channels, hid_channels, kernel_size=kernel_size, padding=padding) 13 | self.conv_t_3 = nn.ConvTranspose2d(hid_channels, hid_channels, kernel_size=kernel_size, padding=padding) 14 | self.conv_t_4 = nn.ConvTranspose2d(hid_channels, hid_channels, kernel_size=kernel_size, padding=padding) 15 | self.conv_t_5= nn.ConvTranspose2d(hid_channels, 1, kernel_size=kernel_size, padding=padding) 16 | self.relu = nn.ReLU(inplace=True) 17 | for module in self.modules(): 18 | if isinstance(module, nn.Conv2d): 19 | nn.init.normal_(module.weight, mean=0, std=0.01) 20 | if module.bias is not None: 21 | module.bias.data.zero_() 22 | if isinstance(module, nn.ConvTranspose2d): 23 | nn.init.normal_(module.weight, mean=0, std=0.01) 24 | if module.bias is not None: 25 | module.bias.data.zero_() 26 | 27 | def forward(self, x): 28 | # encoder 29 | residual_1 = x.clone() 30 | out = self.relu(self.conv_1(x)) 31 | out = self.relu(self.conv_2(out)) 32 | residual_2 = out.clone() 33 | out = self.relu(self.conv_3(out)) 34 | out = self.relu(self.conv_4(out)) 35 | residual_3 = out.clone() 
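# (editorial note) residual_1 is the raw input and residual_2 / residual_3 are
# snapshots taken after conv_2 and conv_4; the transposed-conv decoder below
# adds them back in reverse order (residual_3 first, the input last), with a
# ReLU after each sum, so RED-CNN learns residual corrections at three depths.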
36 |         out = self.relu(self.conv_5(out))
37 | 
38 |         # decoder
39 |         out = self.conv_t_1(out)
40 |         out = out + residual_3
41 |         out = self.conv_t_2(self.relu(out))
42 |         out = self.conv_t_3(self.relu(out))
43 |         out = out + residual_2
44 |         out = self.conv_t_4(self.relu(out))
45 |         out = self.conv_t_5(self.relu(out))
46 |         out = out + residual_1
47 |         out = self.relu(out)
48 |         return out
49 | 
--------------------------------------------------------------------------------
/recon/models/UNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | class DownBlock(nn.Module):
5 |     def __init__(self, in_ch, out_ch, first_block=False):
6 |         super(DownBlock, self).__init__()
7 |         self.model = nn.Sequential(
8 |             nn.Conv2d(1, in_ch, 3, 1, 1) if first_block else nn.MaxPool2d(2),  # first block lifts the 1-channel input; later blocks downsample
9 |             nn.Conv2d(in_ch, out_ch, 3, 1, 1),
10 |             nn.BatchNorm2d(out_ch),
11 |             nn.ReLU(inplace=True),
12 |             nn.Conv2d(out_ch, out_ch, 3, 1, 1),
13 |             nn.BatchNorm2d(out_ch),
14 |             nn.ReLU(inplace=True)
15 |         )
16 | 
17 |     def forward(self, x):
18 |         out = self.model(x)
19 |         return out
20 | 
21 | class UpBlock(nn.Module):
22 |     def __init__(self, in_ch, out_ch):
23 |         super(UpBlock, self).__init__()
24 |         self.up = nn.Sequential(   # 2x upsampling by transposed convolution
25 |             nn.ConvTranspose2d(in_ch, out_ch, 3, 2, 1, 1),
26 |             nn.BatchNorm2d(out_ch),
27 |             nn.ReLU(inplace=True),
28 |         )
29 |         self.model = nn.Sequential(
30 |             nn.Conv2d(in_ch, out_ch, 3, 1, 1),
31 |             nn.BatchNorm2d(out_ch),
32 |             nn.ReLU(inplace=True),
33 |             nn.Conv2d(out_ch, out_ch, 3, 1, 1),
34 |             nn.BatchNorm2d(out_ch),
35 |             nn.ReLU(inplace=True)
36 |         )
37 | 
38 |     def forward(self, x, y):
39 |         x_t = self.up(x)
40 |         x_in = torch.cat((x_t, y), dim=1)   # concatenate the encoder skip connection
41 |         out = self.model(x_in)
42 |         return out
43 | 
44 | class UNet(nn.Module):
45 |     def __init__(self):
46 |         super(UNet, self).__init__()
47 |         self.conv1 = DownBlock(64, 64, True)
48 |         self.conv2 = DownBlock(64, 128)
49 |         self.conv3 = DownBlock(128, 256)
50 |         self.conv4 = DownBlock(256, 512)
51 |         self.conv5 = DownBlock(512, 1024)
52 |         self.conv4_t = UpBlock(1024, 512)
53 |         self.conv3_t = UpBlock(512, 256)
54 |         self.conv2_t = UpBlock(256, 128)
55 |         self.conv1_t = UpBlock(128, 64)
56 |         self.conv_last = nn.Conv2d(64, 1, 1, 1, 0)
57 |         for module in self.modules():
58 |             if isinstance(module, nn.Conv2d):
59 |                 nn.init.normal_(module.weight, mean=0, std=0.01)
60 |                 if module.bias is not None:
61 |                     module.bias.data.zero_()
62 |             if isinstance(module, nn.ConvTranspose2d):
63 |                 nn.init.normal_(module.weight, mean=0, std=0.01)
64 |                 if module.bias is not None:
65 |                     module.bias.data.zero_()
66 |             if isinstance(module, nn.BatchNorm2d):
67 |                 module.weight.data.fill_(1)
68 |                 module.bias.data.zero_()
69 | 
70 |     def forward(self, x):
71 |         # encoder
72 |         x_1 = self.conv1(x)
73 |         x_2 = self.conv2(x_1)
74 |         x_3 = self.conv3(x_2)
75 |         x_4 = self.conv4(x_3)
76 |         x_5 = self.conv5(x_4)
77 | 
78 |         # decoder
79 |         y_4 = self.conv4_t(x_5, x_4)
80 |         y_3 = self.conv3_t(y_4, x_3)
81 |         y_2 = self.conv2_t(y_3, x_2)
82 |         y_1 = self.conv1_t(y_2, x_1)
83 |         out = self.conv_last(y_1)
84 |         return out
85 | 
--------------------------------------------------------------------------------
/recon/models/VVBPTensorNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Function
4 | import ctlib
5 | from .RED_CNN import RED_CNN
6 | 
7 | class bprj_sv_fun(Function):
8 |     @staticmethod
9 |     def forward(ctx, proj, options):
10 |         ctx.save_for_backward(options)
11 |         return ctlib.backprojection_sv(proj, options)   # view-by-view backprojection: one channel per projection view
12 | 
13 |     @staticmethod
14 |     def backward(ctx, grad_output):
15 |         options = ctx.saved_tensors[0]
16 |         temp = grad_output.sum(1, keepdim=True)
17 |         grad_input = ctlib.backprojection_t(temp.contiguous(), options)
18 |         return grad_input, None
19 | 
20 | class backprojector_sv(nn.Module):
21 |     def __init__(self):
22 |         super(backprojector_sv, self).__init__()
23 | 
24 |     def forward(self, proj, options):
25 |         return bprj_sv_fun.apply(proj, options)
26 | 
27 | 
28 | class FBP_sv(nn.Module):
29 |     def __init__(self, options) -> None:
30 |         super().__init__()
31 |         dets = int(options[1])
32 |         dDet = options[5]
33 |         s2r = options[7]
34 |         d2r = options[8]
35 |         virdet = dDet * s2r / (s2r + d2r)
36 |         filter = torch.empty(2 * dets - 1)
37 |         pi = torch.acos(torch.tensor(-1.0))
38 |         for i in range(filter.size(0)):   # discrete ramp (Ram-Lak) filter kernel in the spatial domain
39 |             x = i - dets + 1
40 |             if abs(x) % 2 == 1:
41 |                 filter[i] = -1 / (pi * pi * x * x * virdet * virdet)
42 |             elif x == 0:
43 |                 filter[i] = 1 / (4 * virdet * virdet)
44 |             else:
45 |                 filter[i] = 0
46 |         filter = filter.view(1, 1, 1, -1)
47 |         w = torch.arange((-dets / 2 + 0.5) * virdet, dets / 2 * virdet, virdet)
48 |         w = s2r / torch.sqrt(s2r ** 2 + w ** 2)   # fan-beam cosine weighting
49 |         w = w.view(1, 1, 1, -1) * virdet * pi / options[0]
50 |         self.w = nn.Parameter(w, requires_grad=False)
51 |         self.filter = nn.Parameter(filter, requires_grad=False)
52 |         self.options = nn.Parameter(options, requires_grad=False)
53 |         self.backprojector = backprojector_sv()
54 |         self.dets = dets
55 | 
56 |     def forward(self, projection):
57 |         p = projection * self.w
58 |         p = torch.nn.functional.conv2d(p, self.filter, padding=(0, self.dets - 1))
59 |         recon = self.backprojector(p, self.options)
60 |         return recon
61 | 
62 | class VVBPTensorNet(nn.Module):
63 |     def __init__(self, options) -> None:
64 |         super().__init__()
65 |         self.backprojector = FBP_sv(options)
66 |         self.model = RED_CNN()
67 |         self.conv = nn.Conv2d(512, 1, 3, 1, 1)   # fuses the 512 per-view channels into a single image
68 | 
69 |     def forward(self, p):
70 |         x = self.backprojector(p)
71 |         x_in, _ = torch.sort(x, dim=1)   # sort the single-view backprojections along the view dimension
72 |         x_in = self.conv(x_in)
73 |         out = self.model(x_in)
74 |         return out
--------------------------------------------------------------------------------
/recon/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .LEARN import LEARN
2 | from .LEARN_FBP import LEARN_FBP
3 | from .LPD import Learned_primal_dual
4 | from .RED_CNN import RED_CNN
5 | from .FramingUNet import FramingUNet
6 | from .MAGIC import MAGIC
7 | from .FBPConvNet import FBPConvNet
8 | from .MomentumNet import dCNN
9 | from .FistaNet import FistaNet
10 | from .iCTNet import iCTNet
11 | from .AirNet import AirNet
12 | from .HDNet import HDNet_2d
13 | from .AdaptiveNet import AdaptiveNet
14 | from .KSAE import KSAE
15 | from .DBP import DBP
16 | from .VVBPTensorNet import VVBPTensorNet
17 | from .MetaInvNet import MetaInvNet_H
18 | from .AHPNet import AHPNet
19 | from .iRadonMap import iRadonMap
20 | from .DSigNet import DSigNet
--------------------------------------------------------------------------------
/recon/models/iCTNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Function
4 | import math
5 | import torchvision.transforms as transforms
6 | 
7 | class rotation(nn.Module):
8 |     def __init__(self, dAng, height, width):
9 |         super(rotation, self).__init__()
10 |         self.dAng = dAng * 180 / math.pi   # radians to degrees for torchvision's rotate
11 |         self.height = height
12 |         self.width = width
13 | 
14 |     def forward(self, x):
15 |         B, N2, Nv, L = x.shape
16 |         y = x.transpose(1, 3).view(B, Nv, self.height, self.width).contiguous()
17 |         out = torch.empty_like(y)
18 |         for i in range(Nv):
19 |             ang = self.dAng * i
20 |             out[:, i] = transforms.functional.rotate(y[:, i], ang, transforms.InterpolationMode.BILINEAR)
21 |         out = out.view(B, Nv, N2, 1).contiguous()
22 |         return out
23 | 
24 | 
25 | class iCTNet(nn.Module):
26 |     def __init__(self, N_v, N_c, height, width, dAng, beta=5, alpha_1=1, alpha_2=1):
27 |         super(iCTNet, self).__init__()
28 |         self.L1 = nn.Sequential(
29 |             nn.Conv2d(1, 64, (1, 3), padding=(0, 1)),
30 |             nn.Hardshrink(1e-5)
31 |         )
32 |         self.L2 = nn.Sequential(
33 |             nn.Conv2d(64, 64, (1, 3), padding=(0, 1)),
34 |             nn.Hardshrink(1e-5)
35 |         )
36 |         self.L3 = nn.Sequential(
37 |             nn.Conv2d(129, 1, (1, 3), padding=(0, 1)),   # 129 = 1 + 64 + 64 channels from the concatenation in forward
38 |             nn.Hardshrink(1e-5)
39 |         )
40 |         self.L4 = nn.Sequential(
41 |             nn.Conv2d(N_v, N_v * alpha_1, (1, 1), padding=(0, 0)),
42 |             nn.Hardshrink(1e-8)
43 |         )
44 |         self.L5 = nn.Sequential(
45 |             nn.Conv2d(N_v * alpha_1, N_v * alpha_2, (1, 1), padding=(0, 0)),
46 |             nn.Hardshrink(1e-8)
47 |         )
48 |         self.L6 = nn.Sequential(
49 |             nn.Conv2d(1, 1, (N_v * alpha_2, N_c), padding='same', padding_mode='circular', bias=False),
50 |             nn.Identity()
51 |         )
52 |         self.L7 = nn.Sequential(
53 |             nn.Conv2d(1, 16, (1, beta), padding=(0, (beta-1)//2)),
54 |             nn.Tanh()
55 |         )
56 |         self.L8 = nn.Sequential(
57 |             nn.Conv2d(16, 1, (1, beta), padding=(0, (beta-1)//2)),
58 |             nn.Tanh()
59 |         )
60 |         self.L9 = nn.Sequential(
61 |             nn.Conv2d(N_c, N_c, (1, 1), padding=(0, 0)),
62 |             nn.Tanh()
63 |         )
64 |         self.L10 = nn.Sequential(
65 |             nn.Conv2d(N_c, height * width, (1, 1), padding=(0, 0), bias=False),
66 |             nn.Identity()
67 |         )
68 |         self.L11 = rotation(dAng, height, width)
69 |         self.L12 = nn.Sequential(
70 |             nn.Conv2d(N_v * alpha_2, 1, (1, 1), padding=(0, 0), bias=False),
71 |             nn.Identity()
72 |         )
73 |         self.height = height
74 |         self.width = width
75 |         for module in self.modules():
76 |             if isinstance(module, nn.Conv2d):
77 |                 nn.init.xavier_uniform_(module.weight)
78 |                 if module.bias is not None:
79 |                     module.bias.data.zero_()
80 | 
81 |     def forward(self, x):
82 |         x1 = self.L1(x)
83 |         x2 = self.L2(x1)
84 |         x3_in = torch.cat((x, x1, x2), dim=1)
85 |         x3 = self.L3(x3_in)
86 |         x4_in = x3.transpose(1, 2)
87 |         x4 = self.L4(x4_in)
88 |         x5 = self.L5(x4)
89 |         x6_in = x5.transpose(1, 2)
90 |         x6 = self.L6(x6_in)
91 |         x7 = self.L7(x6)
92 |         x8 = self.L8(x7)
93 |         x9_in = x8.transpose(1, 3)
94 |         x9 = self.L9(x9_in)
95 |         x10 = self.L10(x9)
96 |         x11 = self.L11(x10)
97 |         x12 = self.L12(x11)
98 |         out = x12.view(x12.size(0), 1, self.height, self.width)
99 |         return out
100 | 
101 |     def segment1(self, x):
102 |         x1 = self.L1(x)
103 |         x2 = self.L2(x1)
104 |         x3_in = torch.cat((x, x1, x2), dim=1)
105 |         out = self.L3(x3_in)
106 |         return out
107 | 
108 |     def segment2(self, x):
109 |         x4_in = x.transpose(1, 2)
110 |         x4 = self.L4(x4_in)
111 |         x5 = self.L5(x4)
112 |         out = x5.transpose(1, 2)
113 |         return out
114 | 
115 |     def segment3(self, x):
116 |         x6 = self.L6(x)
117 |         x7 = self.L7(x6)
118 |         x8 = self.L8(x7)
119 |         x9_in = x8.transpose(1, 3)
120 |         x9 = self.L9(x9_in)
121 |         out = x9.transpose(1, 3)
122 |         return out
123 | 
124 |     def segment4(self, x):
125 |         x = x.transpose(1, 3)
126 |         x10 = self.L10(x)
127 |         x11 = self.L11(x10)
128 |         x12 = self.L12(x11)
129 |         out = x12.view(x12.size(0), 1, self.height, self.width)
130 |         return out
131 | 
--------------------------------------------------------------------------------
/recon/models/iRadonMap.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | import numpy as np
5 | from torch.autograd import Function
6 | import pycuda.autoinit
7 | import pycuda.driver as drv
8 | from pycuda.compiler import SourceModule
9 | 
10 | def computeDeltasCube(geo, alpha):
11 |     # Get coords of Img(0,0)
12 |     P0 = {'x': -(geo['sVoxelX']/2 - geo['dVoxelX']/2) + geo['offOriginX'],
13 |           'y': (geo['sVoxelY']/2 - geo['dVoxelY']/2) + geo['offOriginY']}
14 | 
15 |     # Get coords of the next voxel in each direction
16 |     Px0 = {'x': P0['x'] + geo['dVoxelX'], 'y': P0['y']}
17 |     Py0 = {'x': P0['x'], 'y': P0['y'] - geo['dVoxelY']}
18 | 
19 |     P = {'x': P0['x'] * math.cos(alpha) + P0['y'] * math.sin(alpha),
20 |          'y': -P0['x'] * math.sin(alpha) + P0['y'] * math.cos(alpha)}
21 | 
22 |     Px = {'x': Px0['x'] * math.cos(alpha) + Px0['y'] * math.sin(alpha),
23 |           'y': -Px0['x'] * math.sin(alpha) + Px0['y'] * math.cos(alpha)}
24 | 
25 |     Py = {'x': Py0['x'] * math.cos(alpha) + Py0['y'] * math.sin(alpha),
26 |           'y': -Py0['x'] * math.sin(alpha) + Py0['y'] * math.cos(alpha)}
27 | 
28 |     # Scale coords so detector pixels are 1x1
29 |     Px['y'] = Px['y'] / geo['dDetecU']
30 |     P['y'] = P['y'] / geo['dDetecU']
31 |     Py['y'] = Py['y'] / geo['dDetecU']
32 | 
33 |     # Compute unit vector of change between voxels
34 |     deltaX = {'x': Px['x'] - P['x'], 'y': Px['y'] - P['y']}
35 |     deltaY = {'x': Py['x'] - P['x'], 'y': Py['y'] - P['y']}
36 | 
37 |     return P, deltaX, deltaY
38 | 
39 | def PixelIndexCal_cuda(geo):
40 |     mod = SourceModule("""
41 |     __global__ void KernelPixelIndexCal_cuda(
42 |         const float *geo, const float *xyzorigin,
43 |         const float *deltaX, const float *deltaY,
44 |         const float *alpha, const float *angle,
45 |         const int *mode, float *u, float *w)
46 |     {
47 |         const int indX = blockIdx.x * blockDim.x + threadIdx.x;
48 |         const int indY = blockIdx.y * blockDim.y + threadIdx.y;
49 | 
50 |         int extent = (int)geo[3], nVoxelX = (int)geo[8], nVoxelY = (int)geo[9];
51 | 
52 |         unsigned long idx = indX*nVoxelY+indY;
53 | 
54 |         if ((indX>=nVoxelX) || (indY>=nVoxelY))
55 |             return;
56 | 
57 |         float DSD = geo[0], DSO = geo[1], nDetecU = geo[2], sVoxelX = geo[4], sVoxelY = geo[5], dVoxelX = geo[6], dVoxelY = geo[7];
58 | 
59 |         float P_x = xyzorigin[0] + indX * deltaX[0] + indY * deltaY[0];
60 |         float P_y = xyzorigin[1] + indX * deltaX[1] + indY * deltaY[1];
61 | 
62 |         float S_x = DSO;
63 |         float S_y;
64 | 
65 |         if (mode[0] == 0)
66 |             S_y = P_y;
67 |         else if (mode[0] == 1)
68 |             S_y = 0.0;
69 | 
70 |         float vectX = P_x - S_x;
71 |         float vectY = P_y - S_y;
72 | 
73 |         float t = (DSO - DSD - S_x) / vectX;
74 |         float y = vectY * t + S_y;
75 | 
76 |         float detindx = y + nDetecU / 2;
77 | 
78 |         float realx = -1*sVoxelX/2 + dVoxelX/2 + indX * dVoxelX;
79 |         float realy = -1*sVoxelY/2 + dVoxelY/2 + indY * dVoxelY;
80 | 
81 |         float weight = (DSO + realy * sin(alpha[0]) - realx * cos(alpha[0])) / DSO;
82 | 
83 |         weight = 1 / (weight * weight);
84 | 
85 |         if (detindx > (nDetecU-2))
86 |             detindx = nDetecU-2;
87 |         if (detindx < 1)
88 |             detindx = 1;
89 | 
90 |         float tmp_index = detindx + nDetecU * angle[0];
91 | 
92 |         if (extent == 1)
93 |         {
94 |             u[idx] = tmp_index;
95 |             w[idx] = weight;
96 |         }
97 |         else if (extent == 2)
98 |         {
99 |             if ((detindx - (int)detindx) > 0.5)
100 |             {
101 |                 u[idx*extent+0] = tmp_index - 1.0;
102 |             }
103 |             else if ((detindx - (int)detindx) < 0.5)
104 |             {
105 |                 u[idx*extent+0] = tmp_index + 1.0;
106 |             }
107 | 
108 |             u[idx*extent+1] = tmp_index;
109 | 
110 |             w[idx*extent+0] = weight/2;
111 |             w[idx*extent+1] = weight/2;
112 |         }
113 |         else if (extent == 3)
114 |         {
115 |             u[idx*extent+0] = tmp_index - 1.0;
116 |             u[idx*extent+1] = tmp_index;
117 |             u[idx*extent+2] = tmp_index + 1.0;
118 | 
119 |             w[idx*extent+0] = weight/3;
120 |             w[idx*extent+1] = weight/3;
121 |             w[idx*extent+2] = weight/3;
122 |         }
123 |     }
124 |     """)
125 | 
126 |     KernelPixelIndexCal_cuda = mod.get_function("KernelPixelIndexCal_cuda")
127 | 
128 | 
129 |     nThreads = 32  # at most 32: one block is (nThreads, nThreads, 1), and nThreads*nThreads*1 must not exceed the 1024-thread block limit
130 |     nBlocks = (geo['nVoxelX'] + nThreads - 1) // nThreads
131 | 
132 |     alphas = np.linspace(geo['start_angle'], geo['end_angle'], geo['views'], False)
133 |     sino_indices = torch.zeros(geo['nVoxelX']*geo['nVoxelY']*geo['extent'], geo['views']).type(torch.LongTensor)
134 |     sino_weights = torch.zeros(geo['nVoxelX']*geo['nVoxelY']*geo['extent'], geo['views']).type(torch.FloatTensor)
135 | 
136 |     for angle in range(geo['views']):
137 |         alpha = -alphas[angle]
138 |         xyzorigin_dic, deltaX_dic, deltaY_dic = computeDeltasCube(geo, alpha)
139 |         indices = np.zeros(geo['nVoxelX']*geo['nVoxelY']*geo['extent'], dtype=np.float32)
140 |         weights = np.zeros(geo['nVoxelX']*geo['nVoxelY']*geo['extent'], dtype=np.float32)
141 |         tmp_geo = np.array([geo[i] for i in ['DSD', 'DSO', 'nDetecU', 'extent', 'sVoxelX', 'sVoxelY', 'dVoxelX', 'dVoxelY', 'nVoxelX', 'nVoxelY']], dtype=np.float32)
142 |         xyzorigin = np.array(list(xyzorigin_dic.values()), dtype=np.float32)
143 |         deltaX = np.array(list(deltaX_dic.values()), dtype=np.float32)
144 |         deltaY = np.array(list(deltaY_dic.values()), dtype=np.float32)
145 |         tmp_angle = np.array([angle], dtype=np.float32)
146 |         tmp_mode = np.array([0 if geo['mode'] == 'parallel' else 1])
147 |         tmp_alpha = np.array([alpha], dtype=np.float32)
148 | 
149 |         KernelPixelIndexCal_cuda(drv.In(tmp_geo), drv.In(xyzorigin), drv.In(deltaX), drv.In(deltaY),
150 |                                  drv.In(tmp_alpha), drv.In(tmp_angle), drv.In(tmp_mode), drv.InOut(indices),
151 |                                  drv.InOut(weights), block=(nThreads, nThreads, 1), grid=(nBlocks, nBlocks))
152 | 
153 |         sino_indices[:, angle] = torch.from_numpy(indices)
154 |         sino_weights[:, angle] = torch.from_numpy(weights)
155 | 
156 |     return sino_indices.view(-1), sino_weights.view(-1)
157 | 
158 | class DotProduct(Function):
159 |     @staticmethod
160 |     def forward(ctx, input, weight):
161 |         ctx.save_for_backward(input, weight)
162 |         output = input*weight.unsqueeze(0).expand_as(input)
163 |         return output
164 | 
165 |     @staticmethod
166 |     def backward(ctx, grad_output):
167 |         input, weight = ctx.saved_tensors
168 |         grad_input = grad_weight = None
169 | 
170 |         if ctx.needs_input_grad[0]:
171 |             grad_input = grad_output*weight.unsqueeze(0).expand_as(grad_output)
172 |         if ctx.needs_input_grad[1]:
173 |             grad_weight = grad_output*input
174 |             grad_weight = grad_weight.sum(0).squeeze(0)
175 | 
176 |         return grad_input, grad_weight
177 | 
178 | class BackProjNet(nn.Module):
179 |     def __init__(self, geo, channel=1, learn=False):
180 |         super(BackProjNet, self).__init__()
181 |         self.geo = geo
182 |         self.learn = learn
183 |         self.channel = channel
184 |         self.indices = nn.Parameter(self.geo['indices'], requires_grad=False)
185 | 
186 |         if self.learn:
187 |             self.weight = nn.Parameter(self.geo['weights'])
188 |             self.bias = nn.Parameter(torch.Tensor(self.geo['nVoxelX']*self.geo['nVoxelY']))
189 |         else:
190 |             self.register_parameter('weight', None)
191 |             self.register_parameter('bias', None)
192 | 
193 |     def forward(self, input):
194 |         input = input.reshape(-1, self.channel,
self.geo['views']*self.geo['nDetecU']) 195 | output = torch.index_select(input, 2, self.indices) 196 | if self.learn: 197 | output = DotProduct.apply(output, self.weight) 198 | output = output.view(-1, self.channel, self.geo['nVoxelX']*self.geo['nVoxelY'], self.geo['views']*self.geo['extent']) 199 | output = torch.sum(output, 3) * (self.geo['end_angle']-self.geo['start_angle']) / (2*self.geo['views']*self.geo['extent']) 200 | if self.learn: 201 | output += self.bias.unsqueeze(0).expand_as(output) 202 | output = output.view(-1, self.channel, self.geo['nVoxelX'], self.geo['nVoxelY']) 203 | output = output.flip((2,3)) 204 | return output 205 | 206 | class block(nn.Module): 207 | def __init__(self) -> None: 208 | super().__init__() 209 | self.l1 = nn.Sequential( 210 | nn.Conv2d(64, 64, 3, 1, 1), 211 | nn.ReLU(inplace=True), 212 | nn.Conv2d(64, 64, 3, 1, 1) 213 | ) 214 | self.l2 = nn.ReLU(inplace=True) 215 | 216 | def forward(self, x): 217 | y = self.l1(x) 218 | out = self.l2(x+y) 219 | return out 220 | 221 | class rCNN(nn.Module): 222 | def __init__(self) -> None: 223 | super().__init__() 224 | layers = [] 225 | layers.append(nn.Conv2d(1, 64, 3, 1, 1)) 226 | layers.append(nn.GroupNorm(num_channels=64, num_groups=1, affine=False)) 227 | layers.append(nn.ReLU()) 228 | for i in range(11): 229 | layers.append(block()) 230 | layers.append(nn.Conv2d(64, 1, 3, 1, 1)) 231 | self.layers = nn.Sequential(*layers) 232 | 233 | def forward(self, x): 234 | out = self.layers(x) 235 | return out 236 | 237 | class iRadonMap(nn.Module): 238 | def __init__(self, options) -> None: 239 | super().__init__() 240 | geo_real = {'nVoxelX': int(options[2]), 'sVoxelX': float(options[4]) * int(options[2]), 'dVoxelX': float(options[4]), 241 | 'nVoxelY': int(options[3]), 'sVoxelY': float(options[4]) * int(options[3]), 'dVoxelY': float(options[4]), 242 | 'nDetecU': int(options[1]), 'sDetecU': float(options[5]) * int(options[1]), 'dDetecU': float(options[5]), 243 | 'offOriginX': 0.0, 'offOriginY': 0.0, 244 | 'views': int(options[0]), 'slices': 1, 245 | 'DSD': float(options[8]) + float(options[9]), 'DSO': float(options[8]), 'DOD': float(options[9]), 246 | 'start_angle': 0.0, 'end_angle': float(options[7]) * int(options[0]), 247 | 'mode': 'fanflat', 'extent': 1, # currently extent supports 1, 2, or 3. 
248 | } 249 | geo_virtual = dict() 250 | geo_virtual.update({x: int(geo_real[x]/1) for x in ['views']}) 251 | geo_virtual.update({x: int(geo_real[x]/1) for x in ['nVoxelX', 'nVoxelY', 'nDetecU']}) 252 | geo_virtual.update({x: geo_real[x]/1 for x in ['sVoxelX', 'sVoxelY', 'sDetecU', 'DSD', 'DSO', 'DOD', 'offOriginX', 'offOriginY']}) 253 | geo_virtual.update({x: geo_real[x] for x in ['dVoxelX', 'dVoxelY', 'dDetecU', 'slices', 'start_angle', 'end_angle', 'mode', 'extent']}) 254 | geo_virtual['indices'], geo_virtual['weights'] = PixelIndexCal_cuda(geo_virtual) 255 | self.convf = nn.Sequential( 256 | nn.Linear(geo_virtual['nDetecU'], geo_virtual['nDetecU']), 257 | nn.ReLU(), 258 | nn.Linear(geo_virtual['nDetecU'], geo_virtual['nDetecU']), 259 | nn.ReLU() 260 | ) 261 | # self.convf = nn.Sequential( 262 | # nn.Conv2d(1,64, (1,3),padding=(0,1)), 263 | # nn.ReLU(True), 264 | # nn.Conv2d(64,1, (1,3),padding=(0,1)) 265 | # ) 266 | self.spbp = BackProjNet(geo_virtual, learn=False) 267 | self.rcnn = rCNN() 268 | for module in self.modules(): 269 | if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear): 270 | nn.init.xavier_uniform_(module.weight) 271 | if module.bias is not None: 272 | module.bias.data.zero_() 273 | 274 | def forward(self, p): 275 | convp = self.convf(p) 276 | x = self.spbp(convp) 277 | out = self.rcnn(x) 278 | return out 279 | 280 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pycuda 2 | scipy 3 | numpy 4 | torch 5 | pytorch-lightning 6 | visdom 7 | --------------------------------------------------------------------------------
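As a quick sanity check of the reproduced models, the two purely convolutional image-domain networks above (`RED_CNN` and the `UNet` in `recon/models/UNet.py`) can be exercised without CTLIB, PyCUDA, or a GPU. The snippet below is an illustrative sketch, not part of the repository: the 512x512 image size and batch size are assumptions, `UNet` is imported from its own module because `recon/models/__init__.py` does not re-export it, and the networks are untrained, so the outputs are meaningless until the demo scripts have been run.

```python
# Minimal smoke test for the image-domain denoisers (illustrative sketch, not repo code).
# Assumptions: 512x512 single-channel inputs; both models instantiated with their defaults.
import torch
from recon.models import RED_CNN
from recon.models.UNet import UNet

x = torch.randn(1, 1, 512, 512)   # (batch, channel, height, width) stand-in for a low-dose CT image

red_cnn = RED_CNN().eval()        # residual encoder-decoder denoiser
unet = UNet().eval()              # 4-level U-Net with encoder-decoder skip connections

with torch.no_grad():
    y1 = red_cnn(x)
    y2 = unet(x)

# Both networks preserve the spatial size of the input.
print(y1.shape, y2.shape)         # torch.Size([1, 1, 512, 512]) twice
```

Note that the U-Net downsamples four times with 2x max pooling and upsamples with stride-2 transposed convolutions, so the input height and width should be multiples of 16 for the skip concatenations to line up; RED_CNN has no resolution constraint.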