├── imgs ├── pccgan.png └── pmccgan.png ├── requirements.txt ├── lr_adj.py ├── image_pool.py ├── dataprocess ├── datamake3D.py └── datautils3d.py ├── evaluate.py ├── model ├── utils.py ├── network.py ├── context_cluster3D.py └── context_cluster3D_Multi.py ├── README.md ├── test.py └── train.py /imgs/pccgan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gluucose/PCCGAN/HEAD/imgs/pccgan.png -------------------------------------------------------------------------------- /imgs/pmccgan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gluucose/PCCGAN/HEAD/imgs/pmccgan.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.24.3 2 | pillow==9.0.1 3 | matplotlib==3.7.2 4 | timm==0.5.4 5 | eniops==0.4.1 6 | scipy==1.8.0 7 | torch==1.11.0+cu113 8 | torchvision==0.12.0+cu113 -------------------------------------------------------------------------------- /lr_adj.py: -------------------------------------------------------------------------------- 1 | # -*- coding = utf-8 -*- 2 | 3 | def update_learning_rate(optimizer_G): 4 | niter_decay = 50 5 | old_lr = optimizer_G.state_dict()['param_groups'][0]['lr'] 6 | lrd = old_lr / niter_decay 7 | 8 | lr = old_lr - lrd 9 | for param_group in optimizer_G.param_groups: 10 | param_group['lr'] = lr 11 | print('update learning rate: %f -> %f' % (old_lr, lr)) 12 | -------------------------------------------------------------------------------- /image_pool.py: -------------------------------------------------------------------------------- 1 | # -*- coding = utf-8 -*- 2 | 3 | import random 4 | 5 | import torch 6 | 7 | 8 | class ImagePool(): 9 | def __init__(self, pool_size): 10 | self.pool_size = pool_size 11 | if self.pool_size > 0: 12 | self.num_imgs = 0 13 | self.images = [] 14 | 15 | def query(self, images): 16 | # print('querying') 17 | if self.pool_size == 0: 18 | return images 19 | return_images = [] 20 | for image in images.data: 21 | image = torch.unsqueeze(image, 0) 22 | if self.num_imgs < self.pool_size: 23 | self.num_imgs = self.num_imgs + 1 24 | self.images.append(image) 25 | return_images.append(image) 26 | else: 27 | p = random.uniform(0, 1) 28 | if p > 0.5: 29 | random_id = random.randint(0, self.pool_size - 1) 30 | tmp = self.images[random_id].clone() 31 | self.images[random_id] = image 32 | return_images.append(tmp) 33 | else: 34 | return_images.append(image) 35 | return_images = torch.cat(return_images, 0) 36 | return return_images 37 | -------------------------------------------------------------------------------- /dataprocess/datamake3D.py: -------------------------------------------------------------------------------- 1 | # -*- coding = utf-8 -*- 2 | 3 | import os 4 | 5 | import SimpleITK as sitk 6 | import numpy as np 7 | from medpy.io import load 8 | 9 | 10 | def Datamake(root): 11 | all_names = [] 12 | for root, dirs, files in os.walk(root): 13 | all_names = (files) 14 | 15 | all_name = [] 16 | for i in all_names: 17 | if os.path.splitext(i)[1] == ".img": 18 | all_name.append(i) 19 | # print(all_name) 20 | 21 | # create result folder 22 | res_dir = root + '_cut' 23 | folder = os.path.exists(res_dir) 24 | if not folder: 25 | os.makedirs(res_dir) 26 | 27 | for file in all_name: 28 | image_path_mri = os.path.join(root, file) 29 | image_mri, h = 
load(image_path_mri) 30 | image_mri = np.array(image_mri) 31 | # print(image_l.shape) 32 | cut_cnt = 0 33 | for i in range(0, 5): 34 | for j in range(0, 5): 35 | for k in range(0, 5): 36 | image_cut = image_mri[16 * i:64 + 16 * i, 16 * j:64 + 16 * j, 16 * k:64 + 16 * k] 37 | savImg = sitk.GetImageFromArray(image_cut) 38 | sitk.WriteImage(savImg, res_dir + '/' + file + '_cut' + f'{cut_cnt:03d}' + '.img') 39 | cut_cnt += 1 40 | 41 | 42 | if __name__ == '__main__': 43 | Datamake('./dataset/') 44 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | # -*- coding = utf-8 -*- 2 | 3 | import SimpleITK as sitk 4 | import numpy as np 5 | from skimage.metrics import normalized_root_mse as nmse 6 | from skimage.metrics import peak_signal_noise_ratio as psnr 7 | from skimage.metrics import structural_similarity as ssim 8 | 9 | from model.utils import norm 10 | 11 | 12 | def evaluateMulti(G, valloader, device): 13 | G.eval() 14 | PSNR_vals, SSIM_vals, NMSE_vals = list(), list(), list() 15 | for image_l, image_s, mri in valloader: 16 | testl = image_l 17 | image_l = norm(testl).to(device) 18 | tests = image_s 19 | image_s = norm(tests).to(device) 20 | testmri = mri 21 | testmri = norm(testmri).to(device) 22 | image_s = np.squeeze(image_s.cpu().detach().numpy()) 23 | 24 | res = G(image_l, testmri) 25 | res = res.cpu().detach().numpy() 26 | res = np.squeeze(res) 27 | 28 | image_l = image_l.cpu().detach().numpy() 29 | image_l = np.squeeze(image_l) 30 | y = np.nonzero(image_s) 31 | image_s_1 = image_s[y] 32 | res_1 = res[y] 33 | # cal PSNR 34 | cur_psnr = (psnr(res_1, image_s_1, data_range=1)) 35 | # cal ssim 36 | cur_ssim = ssim(res, image_s, multichannel=True) 37 | # cal mrse 38 | cur_nmse = nmse(image_s, res) ** 2 39 | 40 | PSNR_vals.append(cur_psnr) 41 | SSIM_vals.append(cur_ssim) 42 | NMSE_vals.append(cur_nmse) 43 | 44 | cur_mean_PSNR_val, cur_mean_SSIM_val, cur_mean_NMSE_val = \ 45 | np.mean(PSNR_vals), np.mean(SSIM_vals), np.mean(NMSE_vals) 46 | 47 | return cur_mean_PSNR_val, cur_mean_SSIM_val, cur_mean_NMSE_val 48 | -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Some helper functions. 
3 | """ 4 | 5 | import argparse 6 | import logging 7 | import os 8 | 9 | import SimpleITK as sitk 10 | import numpy as np 11 | import torch.nn as nn 12 | from PIL import Image 13 | from thop import profile 14 | 15 | 16 | def dict2namespace(config): 17 | namespace = argparse.Namespace() 18 | for key, value in config.items(): 19 | if isinstance(value, dict): 20 | new_value = dict2namespace(value) 21 | else: 22 | new_value = value 23 | setattr(namespace, key, new_value) 24 | return namespace 25 | 26 | 27 | def readTxtLineAsList(txt_path): 28 | fi = open(txt_path, 'r') 29 | txt = fi.readlines() 30 | res_list = [] 31 | for w in txt: 32 | w = w.replace('\n', '') 33 | res_list.append(w) 34 | return res_list 35 | 36 | 37 | def save_image(image_numpy, image_path): 38 | savImg = sitk.GetImageFromArray(image_numpy[:, :, :]) 39 | sitk.WriteImage(savImg, image_path) 40 | 41 | 42 | def weights_init(m): 43 | if isinstance(m, nn.Linear): 44 | # logging.info('=> init weight of Linear from xavier uniform') 45 | nn.init.xavier_uniform_(m.weight) 46 | if m.bias is not None: 47 | # logging.info('=> init bias of Linear to zeros') 48 | nn.init.constant_(m.bias, 0) 49 | elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)): 50 | nn.init.constant_(m.bias, 0) 51 | nn.init.constant_(m.weight, 1.0) 52 | 53 | 54 | def print_model_parm_nums(model, x): 55 | flops, params = profile(model, inputs=(x,)) 56 | print(' + FLOPs: %.2fGFLOPs' % (flops / 1024 / 1024 / 1024), 57 | ' + Params: %.2fM' % (params / 1024 / 1024)) 58 | 59 | 60 | def get_logger(filename, verbosity=1, name=None): 61 | level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING} 62 | formatter = logging.Formatter( 63 | "[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s" 64 | ) 65 | logger = logging.getLogger(name) 66 | logger.setLevel(level_dict[verbosity]) 67 | 68 | fh = logging.FileHandler(filename, "w") 69 | fh.setFormatter(formatter) 70 | logger.addHandler(fh) 71 | 72 | sh = logging.StreamHandler() 73 | sh.setFormatter(formatter) 74 | logger.addHandler(sh) 75 | 76 | return logger 77 | 78 | 79 | def norm(x): 80 | X = (x - x.min()) / (x.max() - x.min()) 81 | return X 82 | 83 | 84 | def del_file(path_data): 85 | for i in os.listdir(path_data): # os.listdir(path_data)#返回一个列表,里面是当前目录下面的所有东西的相对路径 86 | file_data = path_data + "\\" + i # 当前文件夹的下面的所有东西的绝对路径 87 | if os.path.isfile(file_data) == True: # os.path.isfile判断是否为文件,如果是文件,就删除.如果是文件夹.递归给del_file. 88 | os.remove(file_data) 89 | else: 90 | del_file(file_data) 91 | 92 | 93 | def MatrixToImage(data): 94 | if (data.max() > 2): 95 | data = (data - data.min()) / (data.max() - data.min()) 96 | data = data * 255 97 | # data=np.flipud(data) 98 | new_im = Image.fromarray(data.astype(np.uint8)) 99 | return new_im 100 | 101 | 102 | if __name__ == '__main__': 103 | pass 104 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PCC-GAN 2 | 3 | Image2Points: A 3D Point-based Context Clusters GAN for High-Quality PET Image Reconstruction, ICASSP 2024. 4 | [link] 5 | 6 | 7 | 8 | The extension of this work (PMC2-GAN): 3D Point-based Multi-Modal Context Clusters GAN for Low-Dose PET Image Denoising, TCSVT 2024. 9 | [link] 10 | 11 | 12 | ## Introduction 13 | ### PCC-GAN 14 | 15 | 16 |

To obtain high-quality positron emission tomography (PET) images while minimizing radiation exposure, numerous methods have been proposed to reconstruct standard-dose PET (SPET) images from the corresponding low-dose PET (LPET) images. However, these methods rely heavily on voxel-based representations, which fall short of adequately accounting for precise structure and fine-grained context, leading to compromised reconstruction. In this paper, we propose a 3D point-based context clusters GAN, namely PCC-GAN, to reconstruct high-quality SPET images from LPET. Specifically, inspired by the geometric representation power of points, we resort to a point-based representation to enhance the explicit expression of the image structure, thus facilitating reconstruction with finer details. Moreover, a context clustering strategy is applied to explore the contextual relationships among points, which mitigates the ambiguities of small structures in the reconstructed images. Experiments on both clinical and phantom datasets demonstrate that our PCC-GAN outperforms state-of-the-art reconstruction methods both qualitatively and quantitatively.

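For readers unfamiliar with context clusters, the snippet below is a minimal, self-contained sketch of the aggregate-and-dispatch step that the `Cluster` module in `model/context_cluster3D.py` builds on. The center proposal, similarity scaling, and tensor shapes here are simplified for illustration; the actual module additionally splits points into heads and local 3D regions and uses 1×1×1 convolutions for the similarity/value/projection paths.

```python
import torch
import torch.nn.functional as F


def cluster_aggregate_dispatch(points, num_centers):
    """Simplified context-cluster step over a flat point set.

    points: [B, N, C] point features (e.g. flattened voxels + coordinates).
    Returns a tensor of the same shape where every point carries the
    aggregated feature of the single center it was assigned to.
    """
    # propose centers by average pooling along the point axis
    centers = F.adaptive_avg_pool1d(points.transpose(1, 2), num_centers).transpose(1, 2)  # [B, M, C]
    # pairwise cosine similarity between centers and points
    sim = torch.matmul(F.normalize(centers, dim=-1),
                       F.normalize(points, dim=-1).transpose(1, 2))  # [B, M, N]
    # hard-assign each point to its most similar center
    mask = torch.zeros_like(sim).scatter_(1, sim.argmax(dim=1, keepdim=True), 1.0)
    sim = sim.sigmoid() * mask
    # aggregate: similarity-weighted sum of member points plus the center itself
    agg = (torch.matmul(sim, points) + centers) / (sim.sum(dim=-1, keepdim=True) + 1.0)  # [B, M, C]
    # dispatch: send each center's aggregated feature back to its member points
    return torch.matmul(sim.transpose(1, 2), agg)  # [B, N, C]


# toy usage: 512 points with 16-dim features grouped around 8 centers
print(cluster_aggregate_dispatch(torch.randn(2, 512, 16), num_centers=8).shape)
```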
17 | 18 | ### PMC2-GAN 19 | 20 | 21 |

The extensions are elaborated in the following aspects: (1) We refined the research background to rigorously state the necessity of effectively integrating the metabolic information in PET images with the anatomical structural information in MRI images, as well as the advantages of multi-modality PET image denoising. (2) We designed Cross-CC blocks to efficiently harness the multi-modal inputs. In the Cross-CC blocks, cross-points aggregating and dispatching dynamically balance the importance of the primary PET modality and the auxiliary MRI modality when generating the target SPET images, while maintaining awareness of the structural and contextual relationships. We then integrated the Cross-CC blocks with the Self-CC blocks to construct the MultiMCC blocks, which extract and integrate information from both modalities in a structurally explicit and contextually rich manner while making full use of the knowledge in the primary PET modality.

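As a quick orientation before the installation steps, the sketch below shows how the two generators are intended to be called: the PCC-GAN generator consumes only an LPET patch, while the PMC2-GAN generator (built on `ContextClusterMulti` in `model/context_cluster3D_Multi.py`) additionally consumes an MRI patch. The hyper-parameters mirror those used in `test.py`/`train.py`, and the 64³ patch size matches the cubes produced by `dataprocess/datamake3D.py`. Depending on how you launch Python, you may need the repository root (and `model/`) on `PYTHONPATH`, since `model/network.py` imports `context_cluster3D` directly.

```python
import torch
from model.network import Generator  # PCC-GAN (single-modality) generator

# same hyper-parameters that test.py uses when rebuilding the generator
G = Generator(layers=[1, 1, 1, 1],
              embed_dims=[64, 128, 256, 512],
              mlp_ratios=[8, 8, 4, 4],
              heads=[4, 4, 8, 8], head_dim=[24, 24, 24, 24])

lpet = torch.randn(1, 1, 64, 64, 64)  # one 64^3 LPET patch
epet = G(lpet)                        # estimated SPET patch
print(epet.shape)

# PMC2-GAN variant (see the commented-out Generator in model/network.py);
# G_multi is a hypothetical generator instance wrapping ContextClusterMulti:
# mri  = torch.randn(1, 1, 64, 64, 64)
# epet = G_multi(lpet, mri)
```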
22 | 23 | ## Installation and Training 24 | Clone this repository: 25 | ```bash 26 | git clone https://github.com/gluucose/PCCGAN.git 27 | ``` 28 | Install the required libraries: 29 | ```bash 30 | pip install -r requirements.txt 31 | ``` 32 | 33 |

Before training, prepare your data in the dataset folder. Note that, in the model folder, 34 | context_cluster3D.py is the generator for PCC-GAN (single modality), 35 | whereas context_cluster3D_Multi.py is the generator for PMC2-GAN (multi-modality). 36 | The training code is for PMC2-GAN with multi-modality inputs. To train the model, you can run:

37 | 38 | ```bash 39 | python train.py 40 | ``` 41 | 42 | ## Acknowledgements 43 | We would like to thank the authors of following repositories for their great works: 44 | - [Context-Cluster (CoCs)](https://github.com/ma-xu/Context-Cluster) 45 | 46 | ## Citation 47 | If you find our work useful in your research, please consider citing our papers at: 48 | ```bash 49 | @inproceedings{cui2024image2points, 50 | title={Image2Points: A 3D Point-based Context Clusters GAN for High-Quality PET Image Reconstruction}, 51 | author={Cui, Jiaqi and Wang, Yan and Wen, Lu and Zeng, Pinxian and Wu, Xi and Zhou, Jiliu and Shen, Dinggang}, 52 | booktitle={ICASSP 2024 - 2024 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, 53 | pages={1726-1730}, 54 | year={2024}, 55 | organization={IEEE} 56 | } 57 | ``` 58 | ```bash 59 | @article{cui20243d, 60 | title={3D Point-based Multi-Modal Context Clusters GAN for Low-Dose PET Image Denoising}, 61 | author={Cui, Jiaqi and Wang, Yan and Zhou, Luping and Fei, Yuchen and Zhou, Jiliu and Shen, Dinggang}, 62 | journal={IEEE Transactions on Circuits and Systems for Video Technology}, 63 | year={2024}, 64 | publisher={IEEE} 65 | } 66 | ``` 67 | 68 | ## Contact 69 | If you have any questions or suggestions, feel free to email [Jiaqi Cui](jiaqicui2001@gmail.com). 70 | 71 | -------------------------------------------------------------------------------- /dataprocess/datautils3d.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | 5 | import numpy as np 6 | import torch 7 | from medpy.io import load 8 | from torch.utils.data import Dataset, DataLoader 9 | 10 | 11 | class MyDataset(Dataset): 12 | def __init__(self, root_l, subfolder_l, root_s, subfolder_s, prefixs, transform=None): 13 | super(MyDataset, self).__init__() 14 | self.prefixs = prefixs 15 | self.l_path = os.path.join(root_l, subfolder_l) 16 | self.s_path = os.path.join(root_s, subfolder_s) 17 | self.templ = [x for x in os.listdir(self.l_path) if os.path.splitext(x)[1] == ".img"] 18 | self.temps = [x for x in os.listdir(self.s_path) if os.path.splitext(x)[1] == ".img"] 19 | self.image_list_l = [] 20 | self.image_list_s = [] 21 | 22 | for file in self.templ: 23 | for pre in prefixs: 24 | if pre in file: 25 | self.image_list_l.append(file) 26 | 27 | for file in self.temps: 28 | for pre in prefixs: 29 | if pre in file: 30 | self.image_list_s.append(file) 31 | # print(self.image_list_l) 32 | # print(self.image_list_s) 33 | self.transform = transform 34 | 35 | def __len__(self): 36 | return len(self.image_list_l) 37 | 38 | def __getitem__(self, item): 39 | # LPET 40 | image_path_l = os.path.join(self.l_path, self.image_list_l[item]) 41 | image_l, h = load(image_path_l) 42 | image = np.array(image_l) 43 | 44 | if self.transform is not None: 45 | image = self.transform(image_l) 46 | 47 | # SPET 48 | image_path_s = os.path.join(self.s_path, self.image_list_s[item]) 49 | image_s, h2 = load(image_path_s) 50 | image_s = np.array(image_s) 51 | 52 | image_l = image_l[np.newaxis, :] 53 | image_s = image_s[np.newaxis, :] 54 | image_l = torch.Tensor(image_l) 55 | image_s = torch.Tensor(image_s) 56 | # print(image.shape) 57 | if self.transform is not None: 58 | image = self.transform(image_s) 59 | 60 | return image_l, image_s 61 | 62 | 63 | class MyMultiDataset(Dataset): 64 | def __init__(self, root_l, subfolder_l, root_s, subfolder_s, root_mri, subfolder_mri, prefixs, transform=None): 65 | super(MyMultiDataset, 
self).__init__() 66 | self.prefixs = prefixs 67 | self.l_path = os.path.join(root_l, subfolder_l) 68 | self.s_path = os.path.join(root_s, subfolder_s) 69 | self.mri_path = os.path.join(root_mri, subfolder_mri) 70 | self.templ = [x for x in os.listdir(self.l_path) if os.path.splitext(x)[1] == ".img"] 71 | self.temps = [x for x in os.listdir(self.s_path) if os.path.splitext(x)[1] == ".img"] 72 | self.temp_mri = [x for x in os.listdir(self.mri_path) if os.path.splitext(x)[1] == ".img"] 73 | self.image_list_l = [] 74 | self.image_list_s = [] 75 | self.image_list_mri = [] 76 | 77 | for file in self.templ: 78 | for pre in prefixs: 79 | if pre in file: 80 | self.image_list_l.append(file) 81 | 82 | for file in self.temps: 83 | for pre in prefixs: 84 | if pre in file: 85 | self.image_list_s.append(file) 86 | 87 | for file in self.temp_mri: 88 | for pre in prefixs: 89 | if pre in file: 90 | self.image_list_mri.append(file) 91 | # print(self.image_list_l) 92 | # print(self.image_list_s) 93 | self.transform = transform 94 | 95 | def __len__(self): 96 | return len(self.image_list_l) 97 | 98 | def __getitem__(self, item): 99 | # LPET 100 | image_path_l = os.path.join(self.l_path, self.image_list_l[item]) 101 | image_l, h = load(image_path_l) 102 | image = np.array(image_l) 103 | 104 | if self.transform is not None: 105 | image = self.transform(image_l) 106 | 107 | # SPET 108 | image_path_s = os.path.join(self.s_path, self.image_list_s[item]) 109 | image_s, h2 = load(image_path_s) 110 | image_s = np.array(image_s) 111 | # MRI 112 | image_path_mri = os.path.join(self.mri_path, self.image_list_mri[item]) 113 | image_mri, h3 = load(image_path_mri) 114 | image_mri = np.array(image_mri) 115 | 116 | image_l = image_l[np.newaxis, :] 117 | image_s = image_s[np.newaxis, :] 118 | image_mri = image_mri[np.newaxis, :] 119 | image_l = torch.Tensor(image_l) 120 | image_s = torch.Tensor(image_s) 121 | image_mri = torch.Tensor(image_mri) 122 | 123 | if self.transform is not None: 124 | image = self.transform(image_s) 125 | 126 | return image_l, image_s, image_mri 127 | 128 | 129 | def loadData(root1, subfolder1, root2, subfolder2, prefixs, batch_size, shuffle=True): 130 | transform = None 131 | dataset = MyDataset(root1, subfolder1, root2, subfolder2, prefixs, transform=transform) 132 | 133 | return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle) 134 | 135 | 136 | def loadMultiData(root1, subfolder1, root2, subfolder2, root3, subfolder3, prefixs, batch_size, shuffle=True): 137 | transform = None 138 | dataset = MyMultiDataset(root1, subfolder1, root2, subfolder2, root3, subfolder3, prefixs, transform=transform) 139 | 140 | return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle) 141 | 142 | 143 | if __name__ == '__main__': 144 | pass 145 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # -*- coding = utf-8 -*- 2 | 3 | # @time:2023/5/8 12:41 4 | 5 | # Author:Cui 6 | 7 | import os 8 | # os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1" 9 | import shutil 10 | 11 | import SimpleITK as sitk 12 | import numpy as np 13 | import torch 14 | from medpy.io import load 15 | from skimage.metrics import normalized_root_mse as nmse 16 | from skimage.metrics import peak_signal_noise_ratio as psnr 17 | from skimage.metrics import structural_similarity as ssim 18 | 19 | import dataprocess.datautils3d as util3d 20 | from model.network import Generator 21 | from model.utils import 
readTxtLineAsList, norm 22 | 23 | 24 | def predict_patches(val_imgs, valloader, pretrained_model): 25 | model = torch.load(pretrained_model, map_location='cpu') 26 | model.eval() 27 | model = Generator(layers=[1, 1, 1, 1], 28 | embed_dims=[64, 128, 256, 512], 29 | mlp_ratios=[8, 8, 4, 4], 30 | heads=[4, 4, 8, 8], head_dim=[24, 24, 24, 24]).to(device) 31 | model.load_state_dict(torch.load(pretrained_model, map_location='cpu'), False) 32 | model.eval() 33 | 34 | for i, (image_l, image_s, mri) in enumerate(valloader): 35 | image_l = norm(image_l).to(device) 36 | image_s = norm(image_s) 37 | image_s = np.squeeze(image_s.detach().numpy()) 38 | image_mri = norm(mri).to(device) 39 | 40 | res = model(image_l, image_mri) 41 | # res = image_l # for test only 42 | res = res.cpu().detach().numpy() 43 | res = np.squeeze(res) 44 | 45 | # save the predicted patches 46 | savImg = sitk.GetImageFromArray(res) 47 | filename = f'cut_{i:04d}' + '.img' 48 | savepath = './imgs/predicted_' + val_imgs[0] + '_patches' 49 | if not os.path.exists(savepath): 50 | os.makedirs(savepath) 51 | sitk.WriteImage(savImg, savepath + '/' + filename) 52 | # save the real patches 53 | savImg = sitk.GetImageFromArray(image_s) 54 | filename = f'cut_{i:04d}' + '.img' 55 | savepath = './imgs/real_' + val_imgs[0] + '_patches' 56 | if not os.path.exists(savepath): 57 | os.makedirs(savepath) 58 | sitk.WriteImage(savImg, savepath + '/' + filename) 59 | 60 | 61 | def concat(path, savedir, val_imgs): 62 | all_names = [] 63 | for root, dirs, files in os.walk(path): 64 | all_names = (files) 65 | all_name = [] 66 | for i in all_names: 67 | if os.path.splitext(i)[1] == ".img": 68 | # print(i) 69 | all_name.append(i) 70 | # all_name=all_name.sort(key=lambda k:(int(k[-7:-4]))) 71 | # print(all_name, len(all_name)) 72 | cnt = 0 73 | clips = [] 74 | stride_3d = [16, 16, 16] 75 | window_3d = [64, 64, 64] 76 | for file in all_name: 77 | # print('file: ', file) 78 | image_path = os.path.join(path, file) 79 | image, h = load(image_path) 80 | image = np.array(image) 81 | # image = np.moveaxis(image, [0, 1, 2], [2, 1, 0]) 82 | clips.append(image) 83 | if (len(clips) == 125): # 729 for 8 84 | cnt = 0 85 | s_d, s_h, s_w = stride_3d 86 | w_d, w_h, w_w = window_3d 87 | counter = np.zeros([128, 128, 128]) 88 | D, H, W = counter.shape 89 | num_d = (D - w_d) // s_d + 1 90 | num_h = (H - w_h) // s_h + 1 91 | num_w = (W - w_w) // s_w + 1 92 | res_collect = np.zeros([128, 128, 128]) 93 | # print(num_d, num_h, num_w) 94 | for i in range(num_d): 95 | for j in range(num_h): 96 | for k in range(num_w): 97 | counter[i * s_d:i * s_d + w_d, j * s_h:j * s_h + w_h, k * s_w:k * s_w + w_w] += 1 98 | x = clips[cnt] 99 | cnt += 1 100 | res_collect[i * s_d:i * s_d + w_d, j * s_h:j * s_h + w_h, k * s_w:k * s_w + w_w] += x 101 | res_collect /= counter 102 | res = res_collect 103 | cnt = 0 104 | clips = [] 105 | res = np.moveaxis(res, [0, 1, 2], [2, 1, 0]) 106 | y = np.where(res < 0.01) 107 | res[y] = 0.0 108 | 109 | if not os.path.exists(savedir): 110 | os.makedirs(savedir) 111 | savImg = sitk.GetImageFromArray(res) 112 | sitk.WriteImage(savImg, savedir + '/' + val_imgs[0] + '.img') 113 | 114 | 115 | def cal(predicted_path, real_path): 116 | res, h = load(predicted_path) 117 | image_s, h = load(real_path) 118 | if res.max() > 1: 119 | res = norm(res) 120 | if image_s.max() > 1: 121 | image_s = norm(image_s) 122 | 123 | y = np.nonzero(image_s) 124 | image_s_1 = image_s[y] 125 | res_1 = res[y] 126 | # calculate PSNR 127 | cur_psnr = psnr(image_s_1, res_1, data_range=1.) 
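    # note: PSNR is computed only over the nonzero (foreground) voxels selected above,
    # while SSIM and NMSE below are computed on the full volumes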
128 | cur_ssim = ssim(res, image_s, multichannel=True) 129 | cur_nmse = nmse(image_s, res) ** 2 130 | print("PSNR: %.6f SSIM: %.6f NMSE: %.6f" % (cur_psnr, cur_ssim, cur_nmse)) 131 | 132 | 133 | def run(): 134 | val_txt_path = file_txt_dir + r"Ex" + str(Ex_num) + r"/val.txt" 135 | val_imgs = readTxtLineAsList(val_txt_path) 136 | 137 | valloader = util3d.loadMultiData(data_l_cut_path, '', data_S_cut_path, '', 138 | data_mri_cut_path, '', prefixs=val_imgs, 139 | batch_size=1, shuffle=False) 140 | 141 | """predict patches""" 142 | print('============= predict patches ===================') 143 | predict_patches(val_imgs, valloader, pretrained_model) 144 | """concat patches""" 145 | print('=========== concat predicted patches =============') 146 | # predicted pathces 147 | predicted_path = './imgs/predicted_' + val_imgs[0] + '_patches' 148 | save_path_f = './imgs/predicted_' + val_imgs[0] 149 | concat(predicted_path, save_path_f, val_imgs) 150 | # real patches 151 | real_path = './imgs/real_' + val_imgs[0] + '_patches' 152 | save_path_r = './imgs/real_' + val_imgs[0] 153 | concat(real_path, save_path_r, val_imgs) 154 | """calculate metrics""" 155 | print('============= calculate metrics ==================') 156 | save_path_fake = save_path_f + '/' + val_imgs[0] + '.img' 157 | save_path_real = save_path_r + '/' + val_imgs[0] + '.img' 158 | cal(save_path_real, save_path_fake) 159 | 160 | shutil.rmtree(predicted_path) 161 | shutil.rmtree(real_path) 162 | shutil.rmtree(save_path_r) 163 | 164 | 165 | if __name__ == '__main__': 166 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 167 | 168 | data_l_cut_path = r'./dataset/LPET_cut' 169 | data_S_cut_path = r'./dataset/SPET_cut' 170 | data_mri_cut_path = r'./dataset/T1_cut' 171 | file_txt_dir = r'./dataset/split/' 172 | 173 | Ex_num = 1 174 | pretrained_model = r'' 175 | 176 | run() 177 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding = utf-8 -*- 2 | # os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1" 3 | import argparse 4 | import os 5 | from datetime import datetime 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | 12 | import dataprocess.datautils3d as util3d 13 | from evaluate import evaluateMulti 14 | from image_pool import ImagePool 15 | from lr_adj import update_learning_rate 16 | from model.network import Generator, Discriminator, GD_Train 17 | from model.utils import readTxtLineAsList, get_logger, weights_init, del_file, norm 18 | 19 | 20 | def parse_option(): 21 | parser = argparse.ArgumentParser('PCCGAN training and evaluation script', add_help=False) 22 | parser.add_argument('--file_txt_dir', default="./dataset/split/", type=str, 23 | help='path split txt file dir to dataset') 24 | parser.add_argument('--Ex_num', default=1, type=int, help='path split txt file dir to dataset') 25 | # dataset 26 | parser.add_argument('--data_l_cut_path', default='./dataset/LPET_cut', type=str, 27 | help='LPET cut data path') 28 | parser.add_argument('--data_s_cut_path', default='./dataset/SPET_cut', type=str, 29 | help='SPET cut data path') 30 | parser.add_argument('--data_mri_cut_path', default='./dataset/T1_cut', type=str, 31 | help='T1 cut data path') 32 | # training 33 | parser.add_argument('--epochs', default=100, type=int, metavar='N', help='number of total epochs to run') 34 | parser.add_argument('--batch_size', default=1, type=int, 
help="batch size for single GPU") 35 | parser.add_argument('--devices', default='0', type=str, help='Set the CUDA_VISIBLE_DEVICES var from this string') 36 | parser.add_argument('--lr_G', '--learning-rate-G', default=2e-4, type=float, help='initial learning rate of G') 37 | parser.add_argument('--lr_D', '--learning-rate-D', default=2e-4, type=float, help='initial learning rate of D') 38 | parser.add_argument('--lamb', default=100, type=float, dest='lamb') 39 | parser.add_argument('--beta1', default=0.9, type=float, dest='beta1') 40 | parser.add_argument('--resume', action='store_true', help='resume from checkpoint') 41 | parser.add_argument('--arc', default='Proposed', type=str, help='Network architecture') 42 | # logs 43 | parser.add_argument('--log_dir', default="./results/pretrained_model/", type=str, help='logger information dir') 44 | # validation 45 | parser.add_argument('--val_epoch_inv', default=1, type=int, help='validation interval epochs') 46 | # use checkpoint 47 | parser.add_argument('--use_checkpoint', action='store_true', help='use checkpoint during training') 48 | parser.add_argument('--checkpoint_file_path', default="./check_point.pkl", type=str, help='pretrained weight path') 49 | # use pretrained model 50 | parser.add_argument('--pretrained_model', action='store_true', help='pretrained weight path') 51 | parser.add_argument('--pretrained_model_dir', default="./results/pretrained_model", type=str, 52 | help='pretrained weight path') 53 | parser.add_argument('--temp_val_img_cut_dir', default="./temp/", type=str, help='validation image cut path') 54 | 55 | args, unparsed = parser.parse_known_args() 56 | 57 | return args 58 | 59 | 60 | def train(args): 61 | """Del and create result folders""" 62 | if not os.path.exists(args.pretrained_model_dir): 63 | os.makedirs(args.pretrained_model_dir) 64 | if not os.path.exists(args.temp_val_img_cut_dir): 65 | os.makedirs(args.temp_val_img_cut_dir) 66 | else: 67 | del_file(args.temp_val_img_cut_dir) 68 | 69 | """Dataset""" 70 | train_txt_path = args.file_txt_dir + r"Ex" + str(args.Ex_num) + r"/train.txt" 71 | val_txt_path = args.file_txt_dir + r"Ex" + str(args.Ex_num) + r"/val.txt" 72 | train_imgs = readTxtLineAsList(train_txt_path) 73 | val_imgs = readTxtLineAsList(val_txt_path) 74 | trainloader = util3d.loadMultiData(args.data_l_cut_path, '', args.data_s_cut_path, '', args.data_mri_cut_path, '', 75 | prefixs=train_imgs, batch_size=args.batch_size, shuffle=True) 76 | valloader = util3d.loadMultiData(args.data_l_cut_path, '', args.data_s_cut_path, '', args.data_mri_cut_path, '', 77 | prefixs=val_imgs, batch_size=1, shuffle=False) 78 | 79 | device = torch.device('cuda:' + args.devices if torch.cuda.is_available() else "cpu") 80 | 81 | imgpool = ImagePool(5) 82 | 83 | """Hyper-parameters""" 84 | lr_G, lr_D = args.lr_G, args.lr_D 85 | beta1 = args.beta1 86 | lamb = args.lamb # 87 | epochs = args.epochs 88 | start_epoch = 0 89 | 90 | """Network""" 91 | G = Generator(layers=[1, 1, 1, 1], # 2, 2, 2, 2 92 | embed_dims=[64, 128, 256, 512], # 64, 128, 256, 512 93 | mlp_ratios=[8, 8, 4, 4], 94 | heads=[4, 4, 8, 8], head_dim=[24, 24, 24, 24], 95 | type=args.arc).to(device) 96 | D = Discriminator().to(device) 97 | 98 | L1 = nn.L1Loss().to(device) 99 | 100 | optimizer_G = optim.Adam(G.parameters(), lr=lr_G, betas=(beta1, 0.999)) 101 | optimizer_D = optim.Adam(D.parameters(), lr=lr_D, betas=(beta1, 0.999)) 102 | 103 | """Loggings""""" 104 | log_dir = args.log_dir + r"Ex" + str(args.Ex_num) + "/" + str(args.arc) + "_lrG_" + '{:g}'.format( 105 | 
args.lr_G) + "_lamb_"'{:g}'.format(lamb) 106 | time_str = datetime.strftime(datetime.now(), '%m-%d_%H-%M-%S') 107 | if not os.path.exists(log_dir): 108 | os.makedirs(log_dir) 109 | log_full_path = log_dir + '/' + time_str + '.log' 110 | logger = get_logger(log_full_path) 111 | 112 | """Training""" 113 | D_Loss, G_Loss, Epochs = [], [], range(1, epochs + 1) 114 | PSNR_val_best, PSNR_val_epoch_best = start_epoch, start_epoch 115 | torch.cuda.empty_cache() 116 | # init 117 | G.apply(weights_init) 118 | D.apply(weights_init) 119 | 120 | for epoch in range(start_epoch, epochs): 121 | D_losses, G_losses, batch, d_l, g_l = [], [], 0, 0, 0 122 | for i, (x, y) in enumerate(trainloader): 123 | X = norm(x) # LPET 124 | Y = norm(y) # SPET 125 | 126 | d_loss, g_loss = GD_Train(D, G, X, Y, optimizer_G, optimizer_D, L1, device, imgpool, lamb=lamb) 127 | D_losses.append(d_loss) 128 | G_losses.append(g_loss) 129 | d_l, g_l = np.array(D_losses).mean(), np.array(G_losses).mean() 130 | print('[%d / %d]: batch#%d loss_d= %.6f loss_g= %.6f lr_g=%.6f lr_d=%.6f' % 131 | (epoch + 1, epochs, i, d_l, g_l, optimizer_G.state_dict()['param_groups'][0]['lr'], 132 | optimizer_D.state_dict()['param_groups'][0]['lr'])) 133 | 134 | D_Loss.append(d_l) 135 | G_Loss.append(g_l) 136 | logger.info("Train => Epoch:{} Avg.D_Loss:{} Avg.G_Loss:{}".format(epoch, d_l, g_l)) 137 | 138 | torch.save(G, os.path.join(log_dir, 'last_model.pkl')) 139 | 140 | # schedulerG.step() 141 | # schedulerD.step() 142 | if epoch > 50: 143 | update_learning_rate(optimizer_G) 144 | update_learning_rate(optimizer_D) 145 | 146 | # save and update check_point 147 | if args.use_checkpoint: 148 | checkpoint = { 149 | 'epoch': epoch, 150 | 'G': G.state_dict(), 151 | 'D': D.state_dict(), 152 | 'optimizer_G': optimizer_G.state_dict(), 153 | 'optimizer_D': optimizer_D.state_dict() 154 | } 155 | torch.save(checkpoint, os.path.join(log_dir, 'checkpoint.pkl')) 156 | 157 | # validate every val_epoch_inv 158 | if epoch % args.val_epoch_inv == 0: 159 | logger.info('Validation => Epoch {}'.format(epoch)) 160 | ave_psnr, ave_ssim, ave_mnse = evaluateMulti(G, valloader, device=device) 161 | if ave_psnr >= PSNR_val_best: 162 | PSNR_val_best, PSNR_val_epoch_best = ave_psnr, epoch 163 | torch.save(G.state_dict(), os.path.join(log_dir, str(args.Ex_num) + '_best_PSNR_model.pkl')) 164 | logger.info('Val => Model saved for better PSNR!') 165 | logger.info('Val => Epoch:{} PSNR:{:.6f} SSIM:{:.6f} MRSE:{:.6f}'.format( 166 | epoch + 1, ave_psnr, ave_ssim, ave_mnse)) 167 | logger.info('Val => Best PSNR(Epoch:{}):{:.6f}'.format(PSNR_val_epoch_best, PSNR_val_best)) 168 | else: 169 | logger.info('Val => Model NOT saved for better PSNR!') 170 | logger.info('Val => Epoch:{} PSNR:{:.6f} SSIM:{:.6f} MRSE:{:.6f}'.format( 171 | epoch + 1, ave_psnr, ave_ssim, ave_mnse)) 172 | logger.info('Val => Best PSNR(Epoch:{}):{:.6f} '.format(PSNR_val_epoch_best, PSNR_val_best)) 173 | 174 | 175 | if __name__ == '__main__': 176 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' 177 | args = parse_option() 178 | print(args) 179 | train(args) 180 | -------------------------------------------------------------------------------- /model/network.py: -------------------------------------------------------------------------------- 1 | # -*- coding = utf-8 -*- 2 | import random 3 | 4 | import torch 5 | import torch.nn as nn 6 | from timm.models.layers import to_3tuple 7 | from torch.autograd import Variable 8 | 9 | from context_cluster3D import ContextCluster, basic_blocks 10 | 11 | 12 | class 
GroupNorm(nn.GroupNorm): 13 | """ 14 | Group Normalization with 1 group. 15 | Input: tensor in shape [B, C, H, W, D] 16 | """ 17 | 18 | def __init__(self, num_channels, **kwargs): 19 | super().__init__(1, num_channels, **kwargs) 20 | 21 | 22 | class PointRecuder(nn.Module): 23 | """ 24 | Point Reducer is implemented by a layer of conv since it is mathmatically equal. 25 | Input: tensor in shape [B, in_chans, H, W, D] 26 | Output: tensor in shape [B, embed_dim, H/stride, W/stride, D/stride] 27 | """ 28 | 29 | def __init__(self, patch_size=16, stride=16, padding=0, 30 | in_chans=5, embed_dim=768, norm_layer=None): 31 | super().__init__() 32 | patch_size = to_3tuple(patch_size) 33 | stride = to_3tuple(stride) 34 | padding = to_3tuple(padding) 35 | self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, padding=padding) 36 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 37 | 38 | def forward(self, x): 39 | x = self.proj(x) 40 | x = self.norm(x) 41 | return x 42 | 43 | 44 | # generator for PCC-GAN 45 | class Generator(nn.Module): 46 | def __init__(self, 47 | layers=[1, 1, 1, 1], 48 | norm_layer=GroupNorm, 49 | embed_dims=[64, 128, 256, 512], 50 | mlp_ratios=[8, 8, 4, 4], 51 | downsamples=[True, True, True, True], 52 | proposal_w=[2, 2, 2, 2], 53 | proposal_h=[2, 2, 2, 2], 54 | proposal_d=[2, 2, 2, 2], 55 | fold_w=[1, 1, 1, 1], 56 | fold_h=[1, 1, 1, 1], 57 | fold_d=[1, 1, 1, 1], 58 | heads=[4, 4, 8, 8], 59 | head_dim=[24, 24, 24, 24], 60 | down_patch_size=3, 61 | down_pad=1 62 | ): 63 | super(Generator, self).__init__() 64 | 65 | # generator for PCC-GAN 66 | self.CoCs = ContextCluster( 67 | layers=layers, embed_dims=embed_dims, norm_layer=norm_layer, 68 | mlp_ratios=mlp_ratios, downsamples=downsamples, 69 | down_patch_size=down_patch_size, down_pad=down_pad, 70 | proposal_w=proposal_w, proposal_h=proposal_h, proposal_d=proposal_d, 71 | fold_w=fold_w, fold_h=fold_h, fold_d=fold_d, 72 | heads=heads, head_dim=head_dim, 73 | ) 74 | 75 | def forward(self, LPET): 76 | EPET = self.CoCs(LPET) 77 | 78 | return EPET 79 | 80 | 81 | # # generator for PMC2-GAN 82 | # class Generator(nn.Module): 83 | # def __init__(self, 84 | # layers=[1, 1, 1, 1], 85 | # norm_layer=GroupNorm, 86 | # embed_dims=[64, 128, 256, 512], 87 | # mlp_ratios=[8, 8, 4, 4], 88 | # downsamples=[True, True, True, True], 89 | # proposal_w=[2, 2, 2, 2], 90 | # proposal_h=[2, 2, 2, 2], 91 | # proposal_d=[2, 2, 2, 2], 92 | # fold_w=[1, 1, 1, 1], 93 | # fold_h=[1, 1, 1, 1], 94 | # fold_d=[1, 1, 1, 1], 95 | # heads=[4, 4, 8, 8], 96 | # head_dim=[24, 24, 24, 24], 97 | # down_patch_size=3, 98 | # down_pad=1 99 | # ): 100 | # super(Generator, self).__init__() 101 | # 102 | # # generator for PCC-GAN 103 | # self.CoCs = (ContextClusterMulti( 104 | # layers, embed_dims=embed_dims, norm_layer=norm_layer, 105 | # mlp_ratios=mlp_ratios, downsamples=downsamples, 106 | # down_patch_size=down_patch_size, down_pad=down_pad, 107 | # proposal_w=proposal_w, proposal_h=proposal_h, proposal_d=proposal_d, 108 | # fold_w=fold_w, fold_h=fold_h, fold_d=fold_d, 109 | # heads=heads, head_dim=head_dim, )) 110 | # 111 | # def forward(self, LPET, MRI): 112 | # EPET = self.CoCs(LPET, MRI) 113 | # 114 | # return EPET 115 | 116 | 117 | class Discriminator(nn.Module): 118 | def __init__(self, 119 | layers=[1, 1, 1, 1, 1], 120 | norm_layer=GroupNorm, 121 | embed_dims=[8, 16, 32, 64, 128], 122 | mlp_ratios=[8, 8, 4, 4, 4], 123 | proposal_w=[2, 2, 2, 2, 2], proposal_h=[2, 2, 2, 2, 2], proposal_d=[2, 2, 2, 2, 2], 124 | fold_w=[1, 1, 
1, 1, 1], fold_h=[1, 1, 1, 1, 1], fold_d=[1, 1, 1, 1, 1], 125 | heads=[4, 4, 8, 8, 8], head_dim=[24, 24, 24, 24, 24], 126 | # fixed settings 127 | down_patch_size=3, down_stride=2, down_pad=1, in_patch_size=3, in_stride=2, in_pad=1, drop_rate=0., 128 | act_layer=nn.GELU, use_layer_scale=True, layer_scale_init_value=1e-5, 129 | ): 130 | super().__init__() 131 | """ Encoder """ 132 | self.patch_embed = PointRecuder(patch_size=in_patch_size, stride=in_stride, padding=in_pad, 133 | in_chans=5, embed_dim=embed_dims[0]) 134 | # en0 135 | self.en0 = basic_blocks(embed_dims[0], 0, layers, mlp_ratio=mlp_ratios[0], act_layer=act_layer, 136 | norm_layer=norm_layer, drop_rate=drop_rate, use_layer_scale=use_layer_scale, 137 | layer_scale_init_value=layer_scale_init_value, 138 | proposal_w=proposal_w[0], proposal_h=proposal_h[0], proposal_d=proposal_d[0], 139 | fold_w=fold_w[0], fold_h=fold_h[0], fold_d=fold_d[0], 140 | heads=heads[0], head_dim=head_dim[0], return_center=False) 141 | # en1 142 | self.down1 = PointRecuder(patch_size=down_patch_size, stride=down_stride, padding=down_pad, 143 | in_chans=embed_dims[0], embed_dim=embed_dims[1]) 144 | self.en1 = basic_blocks(embed_dims[1], 1, layers, mlp_ratio=mlp_ratios[1], act_layer=act_layer, 145 | norm_layer=norm_layer, drop_rate=drop_rate, use_layer_scale=use_layer_scale, 146 | layer_scale_init_value=layer_scale_init_value, 147 | proposal_w=proposal_w[1], proposal_h=proposal_h[1], proposal_d=proposal_d[1], 148 | fold_w=fold_w[1], fold_h=fold_h[1], fold_d=fold_d[1], 149 | heads=heads[1], head_dim=head_dim[1], return_center=False) 150 | # en2 151 | self.down2 = PointRecuder(patch_size=down_patch_size, stride=down_stride, padding=down_pad, 152 | in_chans=embed_dims[1], embed_dim=embed_dims[2]) 153 | self.en2 = basic_blocks(embed_dims[2], 2, layers, mlp_ratio=mlp_ratios[2], act_layer=act_layer, 154 | norm_layer=norm_layer, drop_rate=drop_rate, use_layer_scale=use_layer_scale, 155 | layer_scale_init_value=layer_scale_init_value, 156 | proposal_w=proposal_w[2], proposal_h=proposal_h[2], proposal_d=proposal_d[2], 157 | fold_w=fold_w[2], fold_h=fold_h[2], fold_d=fold_d[2], 158 | heads=heads[2], head_dim=head_dim[2], return_center=False) 159 | # en3 160 | self.down3 = PointRecuder(patch_size=down_patch_size, stride=down_stride, padding=down_pad, 161 | in_chans=embed_dims[2], embed_dim=embed_dims[3]) 162 | self.en3 = basic_blocks(embed_dims[3], 3, layers, mlp_ratio=mlp_ratios[3], act_layer=act_layer, 163 | norm_layer=norm_layer, drop_rate=drop_rate, use_layer_scale=use_layer_scale, 164 | layer_scale_init_value=layer_scale_init_value, 165 | proposal_w=proposal_w[3], proposal_h=proposal_h[3], proposal_d=proposal_d[3], 166 | fold_w=fold_w[3], fold_h=fold_h[3], fold_d=fold_d[3], 167 | heads=heads[3], head_dim=head_dim[3], return_center=False) 168 | # en3 169 | self.down4 = PointRecuder(patch_size=down_patch_size, stride=down_stride, padding=down_pad, 170 | in_chans=embed_dims[3], embed_dim=embed_dims[4]) 171 | self.en4 = basic_blocks(embed_dims[4], 4, layers, mlp_ratio=mlp_ratios[4], act_layer=act_layer, 172 | norm_layer=norm_layer, drop_rate=drop_rate, use_layer_scale=use_layer_scale, 173 | layer_scale_init_value=layer_scale_init_value, 174 | proposal_w=proposal_w[4], proposal_h=proposal_h[4], proposal_d=proposal_d[4], 175 | fold_w=fold_w[4], fold_h=fold_h[4], fold_d=fold_d[4], 176 | heads=heads[4], head_dim=head_dim[4], return_center=False) 177 | 178 | self.sigmoid = nn.Sigmoid() 179 | 180 | def forward_embeddings(self, x): 181 | _, c, img_w, img_h, img_d = x.shape 
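        # the discriminator input is the 2-channel (LPET, real/estimated SPET) pair; three normalized
        # coordinate channels are concatenated below, matching in_chans=5 of self.patch_embed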
182 | # print(f"img size is {c} * {img_w} * {img_h}") 183 | # register positional information buffer. 184 | range_w = torch.arange(0, img_w, step=1) / (img_w - 1.0) 185 | range_h = torch.arange(0, img_h, step=1) / (img_h - 1.0) 186 | range_d = torch.arange(0, img_d, step=1) / (img_d - 1.0) 187 | fea_pos = torch.stack(torch.meshgrid(range_w, range_h, range_d), dim=-1).float() 188 | fea_pos = fea_pos.to(x.device) 189 | fea_pos = fea_pos - 0.5 190 | pos = fea_pos.permute(3, 0, 1, 2).unsqueeze(dim=0).expand(x.shape[0], -1, -1, -1, -1) 191 | x = self.patch_embed(torch.cat([x, pos], dim=1)) 192 | return x 193 | 194 | def forward(self, x): 195 | en0 = self.forward_embeddings(x) 196 | en0 = self.en0(en0) 197 | 198 | en1 = self.down1(en0) 199 | en1 = self.en1(en1) 200 | 201 | en2 = self.down2(en1) 202 | en2 = self.en2(en2) 203 | 204 | en3 = self.down3(en2) 205 | en3 = self.en3(en3) 206 | 207 | en4 = self.down4(en3) 208 | en4 = self.en4(en4) 209 | 210 | output = self.sigmoid(en4) 211 | 212 | return output 213 | 214 | 215 | class GANLoss_smooth(nn.Module): 216 | def __init__(self, device, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0, 217 | tensor=torch.FloatTensor): 218 | super(GANLoss_smooth, self).__init__() 219 | self.device = device 220 | self.real_label = target_real_label 221 | self.fake_label = target_fake_label 222 | self.real_label_var = None 223 | self.fake_label_var = None 224 | self.Tensor = tensor 225 | if use_lsgan: 226 | self.loss = nn.MSELoss() 227 | else: 228 | self.loss = nn.BCELoss() 229 | 230 | def get_target_tensor(self, input, target_is_real, smooth): 231 | if target_is_real: 232 | create_label = ((self.real_label_var is None) or (self.real_label_var.numel() != input.numel())) 233 | if create_label: 234 | real_tensor = self.Tensor(input.size()).fill_(self.real_label + smooth * 0.5 - 0.3) 235 | self.real_label_var = Variable(real_tensor, requires_grad=False) 236 | target_tensor = self.real_label_var 237 | else: 238 | create_label = ((self.fake_label_var is None) or (self.fake_label_var.numel() != input.numel())) 239 | if create_label: 240 | fake_tensor = self.Tensor(input.size()).fill_(self.fake_label + smooth * 0.3) 241 | self.fake_label_var = Variable(fake_tensor, requires_grad=False) 242 | target_tensor = self.fake_label_var 243 | return target_tensor 244 | 245 | def __call__(self, input, target_is_real): 246 | a = random.uniform(0, 1) 247 | target_tensor = self.get_target_tensor(input, target_is_real, a) 248 | return self.loss(input, target_tensor.to(self.device)) 249 | 250 | 251 | def D_train(D: Discriminator, G: Generator, LPET, SPET, optimizer_D, device): 252 | LPET = LPET.to(device) 253 | SPET = SPET.to(device) 254 | 255 | PET = torch.cat([LPET, SPET], dim=1) 256 | 257 | D.zero_grad() 258 | 259 | # real data 260 | D_output_r = D(PET).squeeze() 261 | criGAN = GANLoss_smooth(device=device) 262 | D_real_loss = criGAN(D_output_r, True) 263 | 264 | # fake data 265 | G_output = G(LPET) 266 | X_fake = torch.cat([LPET, G_output], dim=1) 267 | D_output_f = D(X_fake).squeeze() 268 | D_fake_loss = criGAN(D_output_f, False) 269 | 270 | # back prop 271 | D_loss = (D_real_loss + D_fake_loss) * 0.5 272 | 273 | D_loss.backward() 274 | optimizer_D.step() 275 | 276 | return D_loss.data.item() 277 | 278 | 279 | def G_train(D: Discriminator, G: Generator, LPET, SPET, L1, optimizer_G, device, lamb=100): 280 | LPET = LPET.to(device) 281 | SPET = SPET.to(device) 282 | 283 | G.zero_grad() 284 | 285 | # fake data 286 | G_output = G(LPET) 287 | X_fake = torch.cat([LPET, G_output], 
dim=1) 288 | D_output_f = D(X_fake).squeeze() 289 | criGAN = GANLoss_smooth(device=device) 290 | G_BCE_loss = criGAN(D_output_f, True) 291 | 292 | G_L1_Loss = L1(G_output, SPET) 293 | 294 | G_loss = G_BCE_loss + lamb * G_L1_Loss 295 | 296 | G_loss.backward() 297 | optimizer_G.step() 298 | 299 | return G_loss.data.item() 300 | 301 | 302 | def GD_Train(D: Discriminator, G: Generator, LPET, SPET, 303 | optimizer_G, optimizer_D, L1, device, imgpool, lamb=100): 304 | x = LPET.to(device) 305 | y = SPET.to(device) 306 | 307 | xy = torch.cat([x, y], dim=1) 308 | 309 | criGAN = GANLoss_smooth(device=device) 310 | G_output = G(x) 311 | 312 | X_fake = imgpool.query(torch.cat([x, G_output], dim=1)) 313 | 314 | # train_D 315 | optimizer_D.zero_grad() 316 | D_output_r = D(xy.detach()).squeeze() 317 | D_output_f = D(X_fake.detach()).squeeze() 318 | D_real_loss = criGAN(D_output_r, True) # real loss 319 | D_fake_loss = criGAN(D_output_f, False) # fake loss 320 | D_loss = (D_real_loss + D_fake_loss) * 0.5 321 | D_loss.backward() 322 | optimizer_D.step() 323 | 324 | # train_G 325 | optimizer_G.zero_grad() 326 | D_output = D(X_fake).squeeze() 327 | G_BCE_loss = criGAN(D_output, True) 328 | G_L1_Loss = L1(G_output, y) 329 | 330 | G_loss = G_BCE_loss + lamb * G_L1_Loss 331 | G_loss.backward() 332 | optimizer_G.step() 333 | 334 | return D_loss.data.item(), G_loss.data.item() 335 | 336 | 337 | if __name__ == '__main__': 338 | pass 339 | -------------------------------------------------------------------------------- /model/context_cluster3D.py: -------------------------------------------------------------------------------- 1 | # -*- coding = utf-8 -*- 2 | 3 | """ 4 | PCC-GAN generator implementation 5 | """ 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from einops import rearrange 10 | from timm.models.layers import to_3tuple 11 | from timm.models.layers import trunc_normal_ 12 | from timm.models.registry import register_model 13 | 14 | 15 | class PointRecuder(nn.Module): 16 | """ 17 | Point Reducer is implemented by a layer of conv since it is mathmatically equal. 18 | Input: tensor in shape [B, in_chans, H, W, D] 19 | Output: tensor in shape [B, embed_dim, H/stride, W/stride, D/stride] 20 | """ 21 | 22 | def __init__(self, patch_size=16, stride=16, padding=0, 23 | in_chans=3, embed_dim=768, norm_layer=None): 24 | super().__init__() 25 | patch_size = to_3tuple(patch_size) 26 | stride = to_3tuple(stride) 27 | padding = to_3tuple(padding) 28 | self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, padding=padding) 29 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 30 | 31 | def forward(self, x): 32 | x = self.proj(x) 33 | x = self.norm(x) 34 | return x 35 | 36 | 37 | class PointExpander(nn.Module): 38 | """ 39 | Point Expander is implemented by a layer of decov since it is mathmatically equal. 
40 | Input: tensor in shape [B, in_chans, H, W, D] 41 | Output: tensor in shape [B, embed_dim, H*stride, W*stride, D*stride] 42 | """ 43 | 44 | def __init__(self, patch_size=16, stride=16, padding=0, 45 | in_chans=3, embed_dim=768, norm_layer=None): 46 | super().__init__() 47 | patch_size = to_3tuple(patch_size) 48 | stride = to_3tuple(stride) 49 | padding = to_3tuple(padding) 50 | # print(in_chans, embed_dim) 51 | self.proj = nn.ConvTranspose3d(in_chans, embed_dim, kernel_size=patch_size, 52 | stride=stride, padding=padding) 53 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 54 | 55 | def forward(self, x): 56 | x = self.proj(x) 57 | x = self.norm(x) 58 | return x 59 | 60 | 61 | class GroupNorm(nn.GroupNorm): 62 | """ 63 | Group Normalization with 1 group. 64 | Input: tensor in shape [B, C, H, W, D] 65 | """ 66 | 67 | def __init__(self, num_channels, **kwargs): 68 | super().__init__(1, num_channels, **kwargs) 69 | 70 | 71 | def pairwise_cos_sim(x1: torch.Tensor, x2: torch.Tensor): 72 | """ 73 | return pair-wise similarity matrix between two tensors 74 | :param x1: [B,...,M,D] 75 | :param x2: [B,...,N,D] 76 | :return: similarity matrix [B,...,M,N] 77 | """ 78 | x1 = F.normalize(x1, dim=-1) 79 | x2 = F.normalize(x2, dim=-1) 80 | 81 | sim = torch.matmul(x1, x2.transpose(-2, -1)) 82 | return sim 83 | 84 | 85 | class Cluster(nn.Module): 86 | def __init__(self, dim, out_dim, 87 | proposal_w=2, proposal_h=2, proposal_d=2, 88 | fold_w=2, fold_h=2, fold_d=2, 89 | heads=4, head_dim=24, 90 | return_center=False): 91 | """ 92 | :param dim: channel nubmer 93 | :param out_dim: channel nubmer 94 | :param proposal_w: the sqrt(proposals) value, we can also set a different value 95 | :param proposal_h: the sqrt(proposals) value, we can also set a different value 96 | :param proposal_d: the sqrt(proposals) value, we can also set a different value 97 | :param fold_w: the sqrt(number of regions) value, we can also set a different value 98 | :param fold_h: the sqrt(number of regions) value, we can also set a different value 99 | :param fold_d: the sqrt(number of regions) value, we can also set a different value 100 | :param heads: heads number in context cluster 101 | :param head_dim: dimension of each head in context cluster 102 | :param return_center: if just return centers instead of dispatching back (deprecated). 103 | """ 104 | super().__init__() 105 | self.heads = heads 106 | self.head_dim = head_dim 107 | self.f = nn.Conv3d(dim, heads * head_dim, kernel_size=1) # for similarity 108 | self.proj = nn.Conv3d(heads * head_dim, out_dim, kernel_size=1) # for projecting channel number 109 | self.v = nn.Conv3d(dim, heads * head_dim, kernel_size=1) # for value 110 | self.sim_alpha = nn.Parameter(torch.ones(1)) 111 | self.sim_beta = nn.Parameter(torch.zeros(1)) 112 | self.centers_proposal = nn.AdaptiveAvgPool3d((proposal_w, proposal_h, proposal_d)) 113 | self.fold_w = fold_w 114 | self.fold_h = fold_h 115 | self.fold_d = fold_d 116 | self.return_center = return_center 117 | 118 | def forward(self, x): # [b,c,w,h, d] 119 | value = self.v(x) 120 | x = self.f(x) 121 | x = rearrange(x, "b (e c) w h d -> (b e) c w h d", e=self.heads) 122 | value = rearrange(value, "b (e c) w h d -> (b e) c w h d", e=self.heads) 123 | if self.fold_w > 1 and self.fold_h > 1: 124 | # split the big feature maps to small local regions to reduce computations. 
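            # e.g. with fold_w = fold_h = fold_d = 2, a [b, c, 32, 32, 32] map is split into
            # 8 independent [*, c, 16, 16, 16] regions that are clustered separately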
125 | b0, c0, w0, h0, d0 = x.shape 126 | assert w0 % self.fold_w == 0 and h0 % self.fold_h == 0 and d0 % self.fold_d == 0, \ 127 | f"Ensure the feature map size ({w0}*{h0}*{w0}) can be divided by fold " \ 128 | f"{self.fold_w}*{self.fold_h}*{self.fold_d}" 129 | x = rearrange(x, "b c (f1 w) (f2 h) (f3 d) -> (b f1 f2 f3) c w h d", f1=self.fold_w, 130 | f2=self.fold_h, f3=self.fold_d) # [bs*blocks,c,ks[0],ks[1],ks[2]] 131 | value = rearrange(value, "b c (f1 w) (f2 h) (f3 d) -> (b f1 f2) c w h d", f1=self.fold_w, 132 | f2=self.fold_h, f3=self.fold_d) 133 | b, c, w, h, d = x.shape 134 | centers = self.centers_proposal(x) # [b,c,C_W,C_H,C_D], we set M = C_W*C_H and N = w*h*d 135 | value_centers = rearrange(self.centers_proposal(value), 'b c w h d -> b (w h d) c') # [b,C_W,C_H,c] 136 | b, c, ww, hh, dd = centers.shape 137 | sim = torch.sigmoid( 138 | self.sim_beta + 139 | self.sim_alpha * pairwise_cos_sim( 140 | centers.reshape(b, c, -1).permute(0, 2, 1), 141 | x.reshape(b, c, -1).permute(0, 2, 1) 142 | ) 143 | ) # [B,M,N] 144 | # we use mask to sololy assign each point to one center 145 | sim_max, sim_max_idx = sim.max(dim=1, keepdim=True) 146 | mask = torch.zeros_like(sim) # binary #[B,M,N] 147 | mask.scatter_(1, sim_max_idx, 1.) 148 | sim = sim * mask 149 | value2 = rearrange(value, 'b c w h d -> b (w h d) c') # [B,N,D] 150 | # aggregate step, out shape [B,M,D] 151 | # a small bug: mask.sum should be sim.sum according to Eq. (1), 152 | # mask can be considered as a hard version of sim in our implementation. 153 | out = ((value2.unsqueeze(dim=1) * sim.unsqueeze(dim=-1)).sum(dim=2) + value_centers) / ( 154 | sim.sum(dim=-1, keepdim=True) + 1.0) # [B,M,D] 155 | 156 | if self.return_center: 157 | out = rearrange(out, "b (w h d) c -> b c w h d", w=ww, h=hh) # center shape 158 | else: 159 | # dispatch step, return to each point in a cluster 160 | out = (out.unsqueeze(dim=2) * sim.unsqueeze(dim=-1)).sum(dim=1) # [B,N,D] 161 | out = rearrange(out, "b (w h d) c -> b c w h d", w=w, h=h) # cluster shape 162 | 163 | if self.fold_w > 1 and self.fold_h > 1 and self.fold_d > 1: 164 | # recover the splited regions back to big feature maps if use the region partition. 165 | out = rearrange(out, "(b f1 f2 f3) c w h d -> b c (f1 w) (f2 h) (f3 d)", f1=self.fold_w, 166 | f2=self.fold_h, f3=self.fold_d) 167 | out = rearrange(out, "(b e) c w h d -> b (e c) w h d", e=self.heads) 168 | out = self.proj(out) 169 | return out 170 | 171 | 172 | class Mlp(nn.Module): 173 | """ 174 | Implementation of MLP with nn.Linear (would be slightly faster in both training and inference). 
175 | Input: tensor with shape [B, C, H, W, D] 176 | """ 177 | 178 | def __init__(self, in_features, hidden_features=None, 179 | out_features=None, act_layer=nn.GELU, drop=0.): 180 | super().__init__() 181 | out_features = out_features or in_features 182 | hidden_features = hidden_features or in_features 183 | self.fc1 = nn.Linear(in_features, hidden_features) 184 | self.act = act_layer() 185 | self.fc2 = nn.Linear(hidden_features, out_features) 186 | self.drop = nn.Dropout(drop) 187 | self.apply(self._init_weights) 188 | 189 | def _init_weights(self, m): 190 | if isinstance(m, nn.Linear): 191 | trunc_normal_(m.weight, std=.02) 192 | if m.bias is not None: 193 | nn.init.constant_(m.bias, 0) 194 | 195 | def forward(self, x): 196 | x = self.fc1(x.permute(0, 2, 3, 4, 1)) 197 | x = self.act(x) 198 | x = self.drop(x) 199 | x = self.fc2(x).permute(0, 4, 1, 2, 3) 200 | x = self.drop(x) 201 | return x 202 | 203 | 204 | class ClusterBlock(nn.Module): 205 | """ 206 | Implementation of one block. 207 | --dim: embedding dim 208 | --mlp_ratio: mlp expansion ratio 209 | --act_layer: activation 210 | --norm_layer: normalization 211 | --drop: dropout rate 212 | --use_layer_scale, --layer_scale_init_value: LayerScale, 213 | refer to https://arxiv.org/abs/2103.17239 214 | """ 215 | 216 | def __init__(self, dim, mlp_ratio=4., 217 | act_layer=nn.GELU, norm_layer=GroupNorm, drop=0., 218 | use_layer_scale=True, layer_scale_init_value=1e-5, 219 | # for context-cluster 220 | proposal_w=2, proposal_h=2, proposal_d=2, 221 | fold_w=2, fold_h=2, fold_d=2, 222 | heads=4, head_dim=24, return_center=False): 223 | 224 | super().__init__() 225 | 226 | self.norm1 = norm_layer(dim) 227 | # dim, out_dim, proposal_w=2,proposal_h=2, fold_w=2, fold_h=2, heads=4, head_dim=24, return_center=False 228 | self.token_mixer = Cluster(dim=dim, out_dim=dim, 229 | proposal_w=proposal_w, proposal_h=proposal_h, proposal_d=proposal_d, 230 | fold_w=fold_w, fold_h=fold_h, fold_d=fold_d, 231 | heads=heads, head_dim=head_dim, return_center=return_center) 232 | self.norm2 = norm_layer(dim) 233 | mlp_hidden_dim = int(dim * mlp_ratio) 234 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 235 | 236 | # The following technique is useful to train deep ContextClusters. 
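        # LayerScale: per-channel residual weights initialized to layer_scale_init_value (1e-5 by default),
        # so each token-mixer/MLP branch starts close to identity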
237 | self.use_layer_scale = use_layer_scale 238 | if use_layer_scale: 239 | self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) 240 | self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) 241 | 242 | def forward(self, x): 243 | if self.use_layer_scale: 244 | x = x + self.layer_scale_1.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * self.token_mixer(self.norm1(x)) 245 | x = x + self.layer_scale_2.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm2(x)) 246 | else: 247 | x = x + self.token_mixer(self.norm1(x)) 248 | x = x + self.mlp(self.norm2(x)) 249 | return x 250 | 251 | 252 | def basic_blocks(dim, index, layers, 253 | mlp_ratio=4., 254 | act_layer=nn.GELU, norm_layer=GroupNorm, 255 | drop_rate=.0, 256 | use_layer_scale=True, layer_scale_init_value=1e-5, 257 | # for context-cluster 258 | proposal_w=2, proposal_h=2, proposal_d=2, 259 | fold_w=2, fold_h=2, fold_d=2, 260 | heads=4, head_dim=24, return_center=False): 261 | blocks = [] 262 | for block_idx in range(layers[index]): 263 | blocks.append(ClusterBlock( 264 | dim, mlp_ratio=mlp_ratio, 265 | act_layer=act_layer, norm_layer=norm_layer, 266 | drop=drop_rate, 267 | use_layer_scale=use_layer_scale, 268 | layer_scale_init_value=layer_scale_init_value, 269 | proposal_w=proposal_w, proposal_h=proposal_h, proposal_d=proposal_d, 270 | fold_w=fold_w, fold_h=fold_h, fold_d=fold_d, 271 | heads=heads, head_dim=head_dim, return_center=return_center 272 | )) 273 | blocks = nn.Sequential(*blocks) 274 | 275 | return blocks 276 | 277 | 278 | class ContextCluster(nn.Module): 279 | """ 280 | ContextCluster, the main class of our model 281 | --layers: [x,x,x,x], number of blocks for the 4 stages 282 | --embed_dims, --mlp_ratios, the embedding dims, mlp ratios 283 | --norm_layer, --act_layer: define the types of normalization and activation 284 | --in_patch_size, --in_stride, --in_pad: specify the patch embedding 285 | for the input image 286 | --down_patch_size --down_stride --down_pad: 287 | specify the downsample (patch embed.) 
288 | """ 289 | 290 | def __init__(self, layers, embed_dims=None, 291 | mlp_ratios=None, 292 | norm_layer=nn.BatchNorm3d, act_layer=nn.GELU, 293 | in_patch_size=3, in_stride=2, in_pad=1, 294 | down_patch_size=3, down_stride=2, down_pad=1, 295 | up_patch_size=2, up_stride=2, up_pad=0, 296 | drop_rate=0., 297 | use_layer_scale=True, layer_scale_init_value=1e-5, 298 | # the parameters for context-cluster 299 | proposal_w=[2, 2, 2, 2], proposal_h=[2, 2, 2, 2], proposal_d=[2, 2, 2, 2], 300 | fold_w=[8, 4, 2, 1], fold_h=[8, 4, 2, 1], fold_d=[8, 4, 2, 1], 301 | heads=[2, 4, 6, 8], head_dim=[16, 16, 32, 32], 302 | **kwargs): 303 | super().__init__() 304 | 305 | """ Encoder """ 306 | self.patch_embed = PointRecuder(patch_size=in_patch_size, stride=in_stride, padding=in_pad, 307 | in_chans=4, embed_dim=embed_dims[0]) 308 | # en0 309 | self.en0 = basic_blocks(embed_dims[0], 0, layers, mlp_ratio=mlp_ratios[0], act_layer=act_layer, 310 | norm_layer=norm_layer, drop_rate=drop_rate, use_layer_scale=use_layer_scale, 311 | layer_scale_init_value=layer_scale_init_value, 312 | proposal_w=proposal_w[0], proposal_h=proposal_h[0], proposal_d=proposal_d[0], 313 | fold_w=fold_w[0], fold_h=fold_h[0], fold_d=fold_d[0], 314 | heads=heads[0], head_dim=head_dim[0], return_center=False) 315 | # en1 316 | self.down1 = PointRecuder(patch_size=down_patch_size, stride=down_stride, padding=down_pad, 317 | in_chans=embed_dims[0], embed_dim=embed_dims[1]) 318 | self.en1 = basic_blocks(embed_dims[1], 1, layers, mlp_ratio=mlp_ratios[1], act_layer=act_layer, 319 | norm_layer=norm_layer, drop_rate=drop_rate, use_layer_scale=use_layer_scale, 320 | layer_scale_init_value=layer_scale_init_value, 321 | proposal_w=proposal_w[1], proposal_h=proposal_h[1], proposal_d=proposal_d[1], 322 | fold_w=fold_w[1], fold_h=fold_h[1], fold_d=fold_d[1], 323 | heads=heads[1], head_dim=head_dim[1], return_center=False) 324 | # en2 325 | self.down2 = PointRecuder(patch_size=down_patch_size, stride=down_stride, padding=down_pad, 326 | in_chans=embed_dims[1], embed_dim=embed_dims[2]) 327 | self.en2 = basic_blocks(embed_dims[2], 2, layers, mlp_ratio=mlp_ratios[2], act_layer=act_layer, 328 | norm_layer=norm_layer, drop_rate=drop_rate, use_layer_scale=use_layer_scale, 329 | layer_scale_init_value=layer_scale_init_value, 330 | proposal_w=proposal_w[2], proposal_h=proposal_h[2], proposal_d=proposal_d[2], 331 | fold_w=fold_w[2], fold_h=fold_h[2], fold_d=fold_d[2], 332 | heads=heads[2], head_dim=head_dim[2], return_center=False) 333 | # en3 334 | self.down3 = PointRecuder(patch_size=down_patch_size, stride=down_stride, padding=down_pad, 335 | in_chans=embed_dims[2], embed_dim=embed_dims[3]) 336 | self.en3 = basic_blocks(embed_dims[3], 3, layers, mlp_ratio=mlp_ratios[3], act_layer=act_layer, 337 | norm_layer=norm_layer, drop_rate=drop_rate, use_layer_scale=use_layer_scale, 338 | layer_scale_init_value=layer_scale_init_value, 339 | proposal_w=proposal_w[3], proposal_h=proposal_h[3], proposal_d=proposal_d[3], 340 | fold_w=fold_w[3], fold_h=fold_h[3], fold_d=fold_d[3], 341 | heads=heads[3], head_dim=head_dim[3], return_center=False) 342 | 343 | """Decoder""" 344 | # de0 345 | self.de0 = basic_blocks(embed_dims[3], 3, layers, mlp_ratio=mlp_ratios[3], act_layer=act_layer, 346 | norm_layer=norm_layer, drop_rate=drop_rate, use_layer_scale=use_layer_scale, 347 | layer_scale_init_value=layer_scale_init_value, 348 | proposal_w=proposal_w[3], proposal_h=proposal_h[3], proposal_d=proposal_h[3], 349 | fold_w=fold_w[3], fold_h=fold_h[3], fold_d=fold_d[3], 350 | 
heads=heads[3], head_dim=head_dim[3], return_center=False) 351 | self.up0 = PointExpander(patch_size=up_patch_size, stride=up_stride, padding=up_pad, 352 | in_chans=embed_dims[3], embed_dim=embed_dims[2]) 353 | # de1 354 | self.de1 = basic_blocks(embed_dims[2], 2, layers, mlp_ratio=mlp_ratios[2], act_layer=act_layer, 355 | norm_layer=norm_layer, drop_rate=drop_rate, use_layer_scale=use_layer_scale, 356 | layer_scale_init_value=layer_scale_init_value, 357 | proposal_w=proposal_w[2], proposal_h=proposal_h[2], proposal_d=proposal_d[2], 358 | fold_w=fold_w[2], fold_h=fold_h[2], fold_d=fold_d[2], 359 | heads=heads[2], head_dim=head_dim[2], return_center=False) 360 | self.up1 = PointExpander(patch_size=up_patch_size, stride=up_stride, padding=up_pad, 361 | in_chans=embed_dims[2], embed_dim=embed_dims[1]) 362 | # de2 363 | self.de2 = basic_blocks(embed_dims[1], 1, layers, mlp_ratio=mlp_ratios[1], act_layer=act_layer, 364 | norm_layer=norm_layer, drop_rate=drop_rate, use_layer_scale=use_layer_scale, 365 | layer_scale_init_value=layer_scale_init_value, 366 | proposal_w=proposal_w[1], proposal_h=proposal_h[1], proposal_d=proposal_d[1], 367 | fold_w=fold_w[1], fold_h=fold_h[1], fold_d=fold_d[1], 368 | heads=heads[1], head_dim=head_dim[1], return_center=False) 369 | self.up2 = PointExpander(patch_size=up_patch_size, stride=up_stride, padding=up_pad, 370 | in_chans=embed_dims[1], embed_dim=embed_dims[0]) 371 | # de3 372 | self.de3 = basic_blocks(embed_dims[0], 0, layers, mlp_ratio=mlp_ratios[0], act_layer=act_layer, 373 | norm_layer=norm_layer, drop_rate=drop_rate, use_layer_scale=use_layer_scale, 374 | layer_scale_init_value=layer_scale_init_value, 375 | proposal_w=proposal_w[0], proposal_h=proposal_h[0], proposal_d=proposal_d[0], 376 | fold_w=fold_w[0], fold_h=fold_h[0], fold_d=fold_d[0], 377 | heads=heads[0], head_dim=head_dim[0], return_center=False) 378 | self.patch_expand = nn.Sequential( 379 | PointExpander(patch_size=up_patch_size, stride=up_stride, padding=up_pad, 380 | in_chans=embed_dims[0], embed_dim=3), 381 | nn.Conv3d(in_channels=3, out_channels=1, kernel_size=1, stride=1), 382 | nn.LeakyReLU(), 383 | nn.BatchNorm3d(1), 384 | ) 385 | 386 | def forward_embeddings(self, x): 387 | _, c, img_w, img_h, img_d = x.shape 388 | # print(f"img size is {c} * {img_w} * {img_h}") 389 | # register positional information buffer. 
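        # The three coordinate ranges below are normalised to [0, 1], shifted by -0.5 so they are
        # zero-centred, and stacked into a 3-channel position map. Concatenating this map with the
        # single-channel input volume is why patch_embed is constructed with in_chans=4 above
        # (1 image channel + 3 positional channels).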
390 | range_w = torch.arange(0, img_w, step=1) / (img_w - 1.0) 391 | range_h = torch.arange(0, img_h, step=1) / (img_h - 1.0) 392 | range_d = torch.arange(0, img_d, step=1) / (img_d - 1.0) 393 | fea_pos = torch.stack(torch.meshgrid(range_w, range_h, range_d), dim=-1).float() 394 | fea_pos = fea_pos.to(x.device) 395 | fea_pos = fea_pos - 0.5 396 | 397 | pos = fea_pos.permute(3, 0, 1, 2).unsqueeze(dim=0).expand(x.shape[0], -1, -1, -1, -1) 398 | 399 | x = self.patch_embed(torch.cat([x, pos], dim=1)) 400 | 401 | return x 402 | 403 | def restore_embeddings(self, x): 404 | x = self.patch_expand(x) 405 | 406 | return x 407 | 408 | def forward(self, x): 409 | en0_emb = self.forward_embeddings(x) 410 | en0 = self.en0(en0_emb) 411 | 412 | en1 = self.down1(en0) 413 | en1 = self.en1(en1) 414 | 415 | en2 = self.down2(en1) 416 | en2 = self.en2(en2) 417 | 418 | en3 = self.down3(en2) 419 | en3 = self.en3(en3) 420 | 421 | de0 = self.de0(en3) 422 | de0 = self.up0(de0) + en2 423 | 424 | de1 = self.de1(de0) 425 | de1 = self.up1(de1) + en1 426 | 427 | de2 = self.de2(de1) 428 | de2 = self.up2(de2) + en0 429 | 430 | de3 = self.de3(de2) + en0_emb 431 | 432 | de3 = self.patch_expand(de3) 433 | 434 | output = de3 435 | 436 | return output 437 | 438 | 439 | @register_model 440 | def pccgen(**kwargs): 441 | layers = [1, 1, 1, 1] 442 | norm_layer = GroupNorm 443 | embed_dims = [64, 128, 256, 512] 444 | mlp_ratios = [8, 8, 4, 4] 445 | downsamples = [True, True, True, True] 446 | proposal_w = [2, 2, 2, 2] 447 | proposal_h = [2, 2, 2, 2] 448 | proposal_d = [2, 2, 2, 2] 449 | fold_w = [1, 1, 1, 1] 450 | fold_h = [1, 1, 1, 1] 451 | fold_d = [1, 1, 1, 1] 452 | heads = [4, 4, 8, 8] 453 | head_dim = [24, 24, 24, 24] 454 | down_patch_size = 3 455 | down_pad = 1 456 | model = ContextCluster( 457 | layers, embed_dims=embed_dims, norm_layer=norm_layer, 458 | mlp_ratios=mlp_ratios, downsamples=downsamples, 459 | down_patch_size=down_patch_size, down_pad=down_pad, 460 | proposal_w=proposal_w, proposal_h=proposal_h, proposal_d=proposal_d, 461 | fold_w=fold_w, fold_h=fold_h, fold_d=fold_d, 462 | heads=heads, head_dim=head_dim, 463 | **kwargs) 464 | 465 | return model 466 | 467 | 468 | if __name__ == '__main__': 469 | input = torch.rand(2, 1, 64, 64, 64) 470 | model = pccgen() 471 | out = model(input) 472 | print(out.shape) 473 | -------------------------------------------------------------------------------- /model/context_cluster3D_Multi.py: -------------------------------------------------------------------------------- 1 | """ 2 | PMC2-GAN generator implementation 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from einops import rearrange 8 | from timm.models.layers import to_3tuple 9 | from timm.models.layers import trunc_normal_ 10 | from timm.models.registry import register_model 11 | 12 | 13 | class PointRecuder(nn.Module): 14 | """ 15 | Point Reducer is implemented by a layer of conv since it is mathmatically equal. 
16 | Input: tensor in shape [B, in_chans, H, W, D] 17 | Output: tensor in shape [B, embed_dim, H/stride, W/stride, D/stride] 18 | """ 19 | 20 | def __init__(self, patch_size=16, stride=16, padding=0, 21 | in_chans=3, embed_dim=768, norm_layer=None): 22 | super().__init__() 23 | patch_size = to_3tuple(patch_size) 24 | stride = to_3tuple(stride) 25 | padding = to_3tuple(padding) 26 | self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, padding=padding) 27 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 28 | 29 | def forward(self, x): 30 | x = self.proj(x) 31 | x = self.norm(x) 32 | return x 33 | 34 | 35 | class PointExpander(nn.Module): 36 | """ 37 | Point Expander is implemented by a layer of decov since it is mathmatically equal. 38 | Input: tensor in shape [B, in_chans, H, W, D] 39 | Output: tensor in shape [B, embed_dim, H*stride, W*stride, D*stride] 40 | """ 41 | 42 | def __init__(self, patch_size=16, stride=16, padding=0, 43 | in_chans=3, embed_dim=768, norm_layer=None): 44 | super().__init__() 45 | patch_size = to_3tuple(patch_size) 46 | stride = to_3tuple(stride) 47 | padding = to_3tuple(padding) 48 | # print(in_chans, embed_dim) 49 | self.proj = nn.ConvTranspose3d(in_chans, embed_dim, kernel_size=patch_size, 50 | stride=stride, padding=padding) 51 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 52 | 53 | def forward(self, x): 54 | x = self.proj(x) 55 | x = self.norm(x) 56 | return x 57 | 58 | 59 | class GroupNorm(nn.GroupNorm): 60 | """ 61 | Group Normalization with 1 group. 62 | Input: tensor in shape [B, C, H, W, D] 63 | """ 64 | 65 | def __init__(self, num_channels, **kwargs): 66 | super().__init__(1, num_channels, **kwargs) 67 | 68 | 69 | def pairwise_cos_sim(x1: torch.Tensor, x2: torch.Tensor): 70 | """ 71 | return pair-wise similarity matrix between two tensors 72 | :param x1: [B,...,M,D] 73 | :param x2: [B,...,N,D] 74 | :return: similarity matrix [B,...,M,N] 75 | """ 76 | x1 = F.normalize(x1, dim=-1) 77 | x2 = F.normalize(x2, dim=-1) 78 | 79 | sim = torch.matmul(x1, x2.transpose(-2, -1)) 80 | return sim 81 | 82 | 83 | class Cluster(nn.Module): 84 | def __init__(self, dim, out_dim, 85 | proposal_w=2, proposal_h=2, proposal_d=2, 86 | fold_w=2, fold_h=2, fold_d=2, 87 | heads=4, head_dim=24, 88 | return_center=False): 89 | """ 90 | :param dim: channel nubmer 91 | :param out_dim: channel nubmer 92 | :param proposal_w: the sqrt(proposals) value, we can also set a different value 93 | :param proposal_h: the sqrt(proposals) value, we can also set a different value 94 | :param proposal_d: the sqrt(proposals) value, we can also set a different value 95 | :param fold_w: the sqrt(number of regions) value, we can also set a different value 96 | :param fold_h: the sqrt(number of regions) value, we can also set a different value 97 | :param fold_d: the sqrt(number of regions) value, we can also set a different value 98 | :param heads: heads number in context cluster 99 | :param head_dim: dimension of each head in context cluster 100 | :param return_center: if just return centers instead of dispatching back (deprecated). 
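        Illustrative example (values chosen here for clarity, not taken from a specific config):
        with proposal_w = proposal_h = proposal_d = 2 the adaptive pooling proposes 2*2*2 = 8
        cluster centers per region, and with fold_w = fold_h = fold_d = 2 the feature map is first
        split into 2*2*2 = 8 local regions that are clustered independently.

            >>> layer = Cluster(dim=96, out_dim=96, heads=4, head_dim=24,
            ...                 proposal_w=2, proposal_h=2, proposal_d=2,
            ...                 fold_w=1, fold_h=1, fold_d=1)
            >>> layer(torch.rand(1, 96, 8, 8, 8)).shape
            torch.Size([1, 96, 8, 8, 8])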
101 | """ 102 | super().__init__() 103 | self.heads = heads 104 | self.head_dim = head_dim 105 | self.f = nn.Conv3d(dim, heads * head_dim, kernel_size=1) # for similarity 106 | self.proj = nn.Conv3d(heads * head_dim, out_dim, kernel_size=1) # for projecting channel number 107 | self.v = nn.Conv3d(dim, heads * head_dim, kernel_size=1) # for value 108 | self.sim_alpha = nn.Parameter(torch.ones(1)) 109 | self.sim_beta = nn.Parameter(torch.zeros(1)) 110 | self.centers_proposal = nn.AdaptiveAvgPool3d((proposal_w, proposal_h, proposal_d)) 111 | self.fold_w = fold_w 112 | self.fold_h = fold_h 113 | self.fold_d = fold_d 114 | self.return_center = return_center 115 | 116 | def forward(self, x): # [b,c,w,h, d] 117 | value = self.v(x) 118 | x = self.f(x) 119 | x = rearrange(x, "b (e c) w h d -> (b e) c w h d", e=self.heads) 120 | value = rearrange(value, "b (e c) w h d -> (b e) c w h d", e=self.heads) 121 | if self.fold_w > 1 and self.fold_h > 1: 122 | # split the big feature maps to small local regions to reduce computations. 123 | b0, c0, w0, h0, d0 = x.shape 124 | assert w0 % self.fold_w == 0 and h0 % self.fold_h == 0 and d0 % self.fold_d == 0, \ 125 | f"Ensure the feature map size ({w0}*{h0}*{w0}) can be divided by fold " \ 126 | f"{self.fold_w}*{self.fold_h}*{self.fold_d}" 127 | x = rearrange(x, "b c (f1 w) (f2 h) (f3 d) -> (b f1 f2 f3) c w h d", f1=self.fold_w, 128 | f2=self.fold_h, f3=self.fold_d) # [bs*blocks,c,ks[0],ks[1],ks[2]] 129 | value = rearrange(value, "b c (f1 w) (f2 h) (f3 d) -> (b f1 f2) c w h d", f1=self.fold_w, 130 | f2=self.fold_h, f3=self.fold_d) 131 | b, c, w, h, d = x.shape 132 | centers = self.centers_proposal(x) # [b,c,C_W,C_H,C_D], we set M = C_W*C_H and N = w*h*d 133 | value_centers = rearrange(self.centers_proposal(value), 'b c w h d -> b (w h d) c') # [b,C_W,C_H,c] 134 | b, c, ww, hh, dd = centers.shape 135 | sim = torch.sigmoid( 136 | self.sim_beta + 137 | self.sim_alpha * pairwise_cos_sim( 138 | centers.reshape(b, c, -1).permute(0, 2, 1), 139 | x.reshape(b, c, -1).permute(0, 2, 1) 140 | ) 141 | ) # [B,M,N] 142 | # we use mask to sololy assign each point to one center 143 | sim_max, sim_max_idx = sim.max(dim=1, keepdim=True) 144 | mask = torch.zeros_like(sim) # binary #[B,M,N] 145 | mask.scatter_(1, sim_max_idx, 1.) 146 | sim = sim * mask 147 | value2 = rearrange(value, 'b c w h d -> b (w h d) c') # [B,N,D] 148 | # aggregate step, out shape [B,M,D] 149 | # a small bug: mask.sum should be sim.sum according to Eq. (1), 150 | # mask can be considered as a hard version of sim in our implementation. 151 | out = ((value2.unsqueeze(dim=1) * sim.unsqueeze(dim=-1)).sum(dim=2) + value_centers) / ( 152 | sim.sum(dim=-1, keepdim=True) + 1.0) # [B,M,D] 153 | 154 | if self.return_center: 155 | out = rearrange(out, "b (w h d) c -> b c w h d", w=ww, h=hh) # center shape 156 | else: 157 | # dispatch step, return to each point in a cluster 158 | out = (out.unsqueeze(dim=2) * sim.unsqueeze(dim=-1)).sum(dim=1) # [B,N,D] 159 | out = rearrange(out, "b (w h d) c -> b c w h d", w=w, h=h) # cluster shape 160 | 161 | if self.fold_w > 1 and self.fold_h > 1 and self.fold_d > 1: 162 | # recover the splited regions back to big feature maps if use the region partition. 
163 | out = rearrange(out, "(b f1 f2 f3) c w h d -> b c (f1 w) (f2 h) (f3 d)", f1=self.fold_w, 164 | f2=self.fold_h, f3=self.fold_d) 165 | out = rearrange(out, "(b e) c w h d -> b (e c) w h d", e=self.heads) 166 | out = self.proj(out) 167 | return out 168 | 169 | 170 | class CrossCluster(nn.Module): 171 | def __init__(self, dim, out_dim, 172 | proposal_w=2, proposal_h=2, proposal_d=2, 173 | fold_w=2, fold_h=2, fold_d=2, 174 | heads=4, head_dim=24, 175 | return_center=False): 176 | """ 177 | :param dim: channel nubmer 178 | :param out_dim: channel nubmer 179 | :param proposal_w: the sqrt(proposals) value, we can also set a different value 180 | :param proposal_h: the sqrt(proposals) value, we can also set a different value 181 | :param proposal_d: the sqrt(proposals) value, we can also set a different value 182 | :param fold_w: the sqrt(number of regions) value, we can also set a different value 183 | :param fold_h: the sqrt(number of regions) value, we can also set a different value 184 | :param fold_d: the sqrt(number of regions) value, we can also set a different value 185 | :param heads: heads number in context cluster 186 | :param head_dim: dimension of each head in context cluster 187 | :param return_center: if just return centers instead of dispatching back (deprecated). 188 | """ 189 | super().__init__() 190 | self.heads = heads 191 | self.head_dim = head_dim 192 | self.f_PET = nn.Conv3d(dim, heads * head_dim, kernel_size=1) # for similarity 193 | self.f_MRI = nn.Conv3d(dim, heads * head_dim, kernel_size=1) # for similarity 194 | self.proj = nn.Conv3d(heads * head_dim, out_dim, kernel_size=1) # for projecting channel number 195 | self.v_PET = nn.Conv3d(dim, heads * head_dim, kernel_size=1) # for value 196 | self.v_MRI = nn.Conv3d(dim, heads * head_dim, kernel_size=1) # for value 197 | self.sim_alpha = nn.Parameter(torch.ones(1)) 198 | self.sim_beta = nn.Parameter(torch.zeros(1)) 199 | self.centers_proposal_PET = nn.AdaptiveAvgPool3d((proposal_w, proposal_h, proposal_d)) 200 | self.fold_w = fold_w 201 | self.fold_h = fold_h 202 | self.fold_d = fold_d 203 | self.return_center = return_center 204 | 205 | def forward(self, PET, MRI): # [b,c,w,h, d] 206 | # calculate the center of PET 207 | value_PET, value_MRI = self.v_PET(PET), self.v_MRI(MRI) 208 | PET, MRI = self.f_PET(PET), self.f_MRI(MRI) 209 | PET = rearrange(PET, "b (e c) w h d -> (b e) c w h d", e=self.heads) 210 | MRI = rearrange(MRI, "b (e c) w h d -> (b e) c w h d", e=self.heads) 211 | value_PET = rearrange(value_PET, "b (e c) w h d -> (b e) c w h d", e=self.heads) 212 | value_MRI = rearrange(value_MRI, "b (e c) w h d -> (b e) c w h d", e=self.heads) 213 | 214 | if self.fold_w > 1 and self.fold_h > 1 and self.fold_d > 1: 215 | # split the big feature maps to small local regions to reduce computations. 
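            # Overview of the cross-clustering that follows: the rearranges below apply the same
            # region split to the PET features, the MRI features and both value tensors; cluster
            # centers are then proposed from the PET stream, every MRI point is assigned to its
            # most similar PET center through the sigmoid-gated cosine similarity, the MRI values
            # are aggregated onto those PET centers, and the result is dispatched back to every point.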
216 |             b0, c0, w0, h0, d0 = PET.shape
217 |             assert w0 % self.fold_w == 0 and h0 % self.fold_h == 0 and d0 % self.fold_d == 0, \
218 |                 f"Ensure the feature map size ({w0}*{h0}*{d0}) can be divided by fold " \
219 |                 f"{self.fold_w}*{self.fold_h}*{self.fold_d}"
220 |             PET = rearrange(PET, "b c (f1 w) (f2 h) (f3 d) -> (b f1 f2 f3) c w h d", f1=self.fold_w,
221 |                             f2=self.fold_h, f3=self.fold_d)  # [bs*blocks,c,ks[0],ks[1],ks[2]]
222 |             MRI = rearrange(MRI, "b c (f1 w) (f2 h) (f3 d) -> (b f1 f2 f3) c w h d", f1=self.fold_w,
223 |                             f2=self.fold_h, f3=self.fold_d)  # [bs*blocks,c,ks[0],ks[1],ks[2]]
224 |             value_PET = rearrange(value_PET, "b c (f1 w) (f2 h) (f3 d) -> (b f1 f2 f3) c w h d", f1=self.fold_w,
225 |                                   f2=self.fold_h, f3=self.fold_d)
226 |             value_MRI = rearrange(value_MRI, "b c (f1 w) (f2 h) (f3 d) -> (b f1 f2 f3) c w h d", f1=self.fold_w,
227 |                                   f2=self.fold_h, f3=self.fold_d)
228 |         assert PET.shape == MRI.shape, "Ensure the size of PET is equal to that of MRI"
229 |         b, c, w, h, d = PET.shape
230 |         centers_PET = self.centers_proposal_PET(PET)  # [b,c,C_W,C_H,C_D], we set M = C_W*C_H*C_D and N = w*h*d
231 |         value_centers_PET = rearrange(self.centers_proposal_PET(value_PET), 'b c w h d -> b (w h d) c')  # [b,M,c]
232 |         b, c, ww, hh, dd = centers_PET.shape
233 |         sim = torch.sigmoid(
234 |             self.sim_beta +
235 |             self.sim_alpha * pairwise_cos_sim(
236 |                 centers_PET.reshape(b, c, -1).permute(0, 2, 1),
237 |                 MRI.reshape(b, c, -1).permute(0, 2, 1)
238 |             )
239 |         )  # [B,M,N]
240 |         # we use a mask to solely assign each point to one center
241 |         sim_max, sim_max_idx = sim.max(dim=1, keepdim=True)
242 |         mask = torch.zeros_like(sim)  # binary #[B,M,N]
243 |         mask.scatter_(1, sim_max_idx, 1.)
244 |         sim = sim * mask
245 |         value2 = rearrange(value_MRI, 'b c w h d -> b (w h d) c')  # [B,N,D]
246 |         # aggregate step, out shape [B,M,D]
247 |         out = ((value2.unsqueeze(dim=1) * sim.unsqueeze(dim=-1)).sum(dim=2) + value_centers_PET) / (
248 |                 sim.sum(dim=-1, keepdim=True) + 1.0)  # [B,M,D]
249 | 
250 |         if self.return_center:
251 |             out = rearrange(out, "b (w h d) c -> b c w h d", w=ww, h=hh)  # center shape
252 |         else:
253 |             # dispatch step, return to each point in a cluster
254 |             out = (out.unsqueeze(dim=2) * sim.unsqueeze(dim=-1)).sum(dim=1)  # [B,N,D]
255 |             out = rearrange(out, "b (w h d) c -> b c w h d", w=w, h=h)  # cluster shape
256 | 
257 |         if self.fold_w > 1 and self.fold_h > 1 and self.fold_d > 1:
258 |             # recover the split regions back to big feature maps if the region partition is used.
259 |             out = rearrange(out, "(b f1 f2 f3) c w h d -> b c (f1 w) (f2 h) (f3 d)", f1=self.fold_w,
260 |                             f2=self.fold_h, f3=self.fold_d)
261 |         out = rearrange(out, "(b e) c w h d -> b (e c) w h d", e=self.heads)
262 |         out = self.proj(out)
263 |         return out
264 | 
265 | 
266 | class Mlp(nn.Module):
267 |     """
268 |     Implementation of MLP with nn.Linear (would be slightly faster in both training and inference).
269 | Input: tensor with shape [B, C, H, W, D] 270 | """ 271 | 272 | def __init__(self, in_features, hidden_features=None, 273 | out_features=None, act_layer=nn.GELU, drop=0.): 274 | super().__init__() 275 | out_features = out_features or in_features 276 | hidden_features = hidden_features or in_features 277 | self.fc1 = nn.Linear(in_features, hidden_features) 278 | self.act = act_layer() 279 | self.fc2 = nn.Linear(hidden_features, out_features) 280 | self.drop = nn.Dropout(drop) 281 | self.apply(self._init_weights) 282 | 283 | def _init_weights(self, m): 284 | if isinstance(m, nn.Linear): 285 | trunc_normal_(m.weight, std=.02) 286 | if m.bias is not None: 287 | nn.init.constant_(m.bias, 0) 288 | 289 | def forward(self, x): 290 | x = self.fc1(x.permute(0, 2, 3, 4, 1)) 291 | x = self.act(x) 292 | x = self.drop(x) 293 | x = self.fc2(x).permute(0, 4, 1, 2, 3) 294 | x = self.drop(x) 295 | return x 296 | 297 | 298 | class ClusterBlock(nn.Module): 299 | """ 300 | Implementation of one sinlge context cluster block. 301 | --dim: embedding dim 302 | --mlp_ratio: mlp expansion ratio 303 | --act_layer: activation 304 | --norm_layer: normalization 305 | --drop: dropout rate 306 | --use_layer_scale, --layer_scale_init_value: LayerScale, 307 | refer to https://arxiv.org/abs/2103.17239 308 | """ 309 | 310 | def __init__(self, dim, mlp_ratio=4., 311 | act_layer=nn.GELU, norm_layer=GroupNorm, drop=0., 312 | use_layer_scale=True, layer_scale_init_value=1e-5, 313 | # for context-cluster 314 | proposal_w=2, proposal_h=2, proposal_d=2, 315 | fold_w=2, fold_h=2, fold_d=2, 316 | heads=4, head_dim=24, return_center=False): 317 | 318 | super().__init__() 319 | 320 | self.norm1 = norm_layer(dim) 321 | # dim, out_dim, proposal_w=2,proposal_h=2, fold_w=2, fold_h=2, heads=4, head_dim=24, return_center=False 322 | self.token_mixer = Cluster(dim=dim, out_dim=dim, 323 | proposal_w=proposal_w, proposal_h=proposal_h, proposal_d=proposal_d, 324 | fold_w=fold_w, fold_h=fold_h, fold_d=fold_d, 325 | heads=heads, head_dim=head_dim, return_center=return_center) 326 | self.norm2 = norm_layer(dim) 327 | mlp_hidden_dim = int(dim * mlp_ratio) 328 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 329 | 330 | # The following technique is useful to train deep ContextClusters. 331 | self.use_layer_scale = use_layer_scale 332 | if use_layer_scale: 333 | self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) 334 | self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) 335 | 336 | def forward(self, x): 337 | if self.use_layer_scale: 338 | x = x + self.layer_scale_1.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * self.token_mixer(self.norm1(x)) 339 | x = x + self.layer_scale_2.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm2(x)) 340 | else: 341 | x = x + self.token_mixer(self.norm1(x)) 342 | x = x + self.mlp(self.norm2(x)) 343 | return x 344 | 345 | 346 | class CrossClusterBlock(nn.Module): 347 | """ 348 | Implementation of one cross contextcluster block. 
349 | --dim: embedding dim 350 | --mlp_ratio: mlp expansion ratio 351 | --act_layer: activation 352 | --norm_layer: normalization 353 | --drop: dropout rate 354 | --use_layer_scale, --layer_scale_init_value: LayerScale, 355 | refer to https://arxiv.org/abs/2103.17239 356 | """ 357 | 358 | def __init__(self, dim, mlp_ratio=4., 359 | act_layer=nn.GELU, norm_layer=GroupNorm, drop=0., 360 | use_layer_scale=True, layer_scale_init_value=1e-5, 361 | # for cross context-cluster 362 | proposal_w=2, proposal_h=2, proposal_d=2, 363 | fold_w=2, fold_h=2, fold_d=2, 364 | heads=4, head_dim=24, return_center=False): 365 | 366 | super().__init__() 367 | 368 | self.norm1_PET = norm_layer(dim) 369 | self.norm1_MRI = norm_layer(dim) 370 | self.norm1_MIX = norm_layer(dim) 371 | self.token_mixer_PET = Cluster(dim=dim, out_dim=dim, 372 | proposal_w=proposal_w, proposal_h=proposal_h, proposal_d=proposal_d, 373 | fold_w=fold_w, fold_h=fold_h, fold_d=fold_d, 374 | heads=heads, head_dim=head_dim, return_center=return_center) 375 | self.token_mixer_MRI = Cluster(dim=dim, out_dim=dim, 376 | proposal_w=proposal_w, proposal_h=proposal_h, proposal_d=proposal_d, 377 | fold_w=fold_w, fold_h=fold_h, fold_d=fold_d, 378 | heads=heads, head_dim=head_dim, return_center=return_center) 379 | self.token_mixer_MIX = CrossCluster(dim=dim, out_dim=dim, 380 | proposal_w=proposal_w, proposal_h=proposal_h, proposal_d=proposal_d, 381 | fold_w=fold_w, fold_h=fold_h, fold_d=fold_d, 382 | heads=heads, head_dim=head_dim, return_center=return_center) 383 | self.norm2_PET = norm_layer(dim) 384 | self.norm2_MRI = norm_layer(dim) 385 | self.norm2_MIX = norm_layer(dim) 386 | mlp_hidden_dim = int(dim * mlp_ratio) 387 | self.mlp_PET = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 388 | self.mlp_MRI = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 389 | self.mlp_MIX = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 390 | 391 | # The following technique is useful to train deep ContextClusters. 
392 | self.use_layer_scale = use_layer_scale 393 | if use_layer_scale: 394 | self.layer_scale_1_PET = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) 395 | self.layer_scale_1_MRI = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) 396 | self.layer_scale_1_MIX = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) 397 | self.layer_scale_2_PET = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) 398 | self.layer_scale_2_MRI = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) 399 | self.layer_scale_2_MIX = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) 400 | 401 | def forward(self, x): 402 | PET, MRI = x 403 | if self.use_layer_scale: 404 | PET = PET + self.layer_scale_1_PET.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * \ 405 | self.token_mixer_PET(self.norm1_PET(PET)) 406 | PET = PET + self.layer_scale_2_PET.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * \ 407 | self.mlp_PET(self.norm2_PET(PET)) 408 | 409 | MRI = MRI + self.layer_scale_1_MRI.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * \ 410 | self.token_mixer_MRI(self.norm1_MRI(MRI)) 411 | MRI = MRI + self.layer_scale_2_MRI.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * \ 412 | self.mlp_MRI(self.norm2_MRI(MRI)) 413 | 414 | MIX = self.layer_scale_1_MIX.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * \ 415 | self.token_mixer_MIX(self.norm1_MIX(PET), self.norm1_MIX(MRI)) 416 | MIX = self.layer_scale_2_MIX.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * \ 417 | self.mlp_MIX(self.norm2_MIX(MIX)) 418 | else: 419 | PET = PET + self.token_mixer_PET(self.norm1_PET(PET)) 420 | PET = PET + self.mlp_PET(self.norm2_PET(PET)) 421 | 422 | MRI = MRI + self.token_mixer_MRI(self.norm1_MRI(MRI)) 423 | MRI = MRI + self.mlp_MRI(self.norm2_MRI(MRI)) 424 | 425 | MIX = self.token_mixer_MIX(self.norm1_MIX(PET), self.norm1_MIX(MRI)) 426 | MIX = self.mlp_MIX(self.norm2_MIX(MIX)) 427 | 428 | return PET, MIX, MRI 429 | 430 | 431 | def basic_blocks(dim, index, layers, 432 | mlp_ratio=4., 433 | act_layer=nn.GELU, norm_layer=GroupNorm, 434 | drop_rate=.0, 435 | use_layer_scale=True, layer_scale_init_value=1e-5, 436 | # for context-cluster 437 | proposal_w=2, proposal_h=2, proposal_d=2, 438 | fold_w=2, fold_h=2, fold_d=2, 439 | heads=4, head_dim=24, return_center=False): 440 | blocks = [] 441 | for block_idx in range(layers[index]): 442 | blocks.append(ClusterBlock( 443 | dim, mlp_ratio=mlp_ratio, 444 | act_layer=act_layer, norm_layer=norm_layer, 445 | drop=drop_rate, 446 | use_layer_scale=use_layer_scale, 447 | layer_scale_init_value=layer_scale_init_value, 448 | proposal_w=proposal_w, proposal_h=proposal_h, proposal_d=proposal_d, 449 | fold_w=fold_w, fold_h=fold_h, fold_d=fold_d, 450 | heads=heads, head_dim=head_dim, return_center=return_center 451 | )) 452 | blocks = nn.Sequential(*blocks) 453 | 454 | # print(blocks) 455 | 456 | return blocks 457 | 458 | 459 | def basic_blocks_Multi(dim, index, layers, 460 | mlp_ratio=4., 461 | act_layer=nn.GELU, norm_layer=GroupNorm, 462 | drop_rate=.0, 463 | use_layer_scale=True, layer_scale_init_value=1e-5, 464 | # for context-cluster 465 | proposal_w=2, proposal_h=2, proposal_d=2, 466 | fold_w=2, fold_h=2, fold_d=2, 467 | heads=4, head_dim=24, return_center=False): 468 | blocks = [] 469 | for block_idx in range(layers[index]): 470 | blocks.append(CrossClusterBlock( 471 | dim, mlp_ratio=mlp_ratio, 472 | act_layer=act_layer, norm_layer=norm_layer, 473 | drop=drop_rate, 474 | 
use_layer_scale=use_layer_scale, 475 | layer_scale_init_value=layer_scale_init_value, 476 | proposal_w=proposal_w, proposal_h=proposal_h, proposal_d=proposal_d, 477 | fold_w=fold_w, fold_h=fold_h, fold_d=fold_d, 478 | heads=heads, head_dim=head_dim, return_center=return_center 479 | )) 480 | blocks = nn.Sequential(*blocks) 481 | 482 | # print(blocks) 483 | 484 | return blocks 485 | 486 | 487 | class ContextClusterMulti(nn.Module): 488 | """ 489 | ContextCluster, the main class of our model 490 | --layers: [x,x,x,x], number of blocks for the 4 stages 491 | --embed_dims, --mlp_ratios, the embedding dims, mlp ratios 492 | --norm_layer, --act_layer: define the types of normalization and activation 493 | --in_patch_size, --in_stride, --in_pad: specify the patch embedding 494 | for the input image 495 | --down_patch_size --down_stride --down_pad: 496 | specify the downsample (patch embed.) 497 | """ 498 | 499 | def __init__(self, layers, embed_dims=None, 500 | mlp_ratios=None, 501 | norm_layer=nn.BatchNorm3d, act_layer=nn.GELU, 502 | in_patch_size=3, in_stride=2, in_pad=1, 503 | down_patch_size=3, down_stride=2, down_pad=1, 504 | up_patch_size=2, up_stride=2, up_pad=0, 505 | drop_rate=0., 506 | use_layer_scale=True, layer_scale_init_value=1e-5, 507 | # the parameters for context-cluster 508 | proposal_w=[2, 2, 2, 2], proposal_h=[2, 2, 2, 2], proposal_d=[2, 2, 2, 2], 509 | fold_w=[8, 4, 2, 1], fold_h=[8, 4, 2, 1], fold_d=[8, 4, 2, 1], 510 | heads=[2, 4, 6, 8], head_dim=[16, 16, 32, 32], 511 | **kwargs): 512 | super().__init__() 513 | 514 | """ Encoder """ 515 | self.patch_embed = PointRecuder(patch_size=in_patch_size, stride=in_stride, padding=in_pad, 516 | in_chans=4, embed_dim=embed_dims[0]) 517 | # en0 518 | self.en0 = basic_blocks_Multi(embed_dims[0], 0, layers, mlp_ratio=mlp_ratios[0], act_layer=act_layer, 519 | norm_layer=norm_layer, drop_rate=drop_rate, use_layer_scale=use_layer_scale, 520 | layer_scale_init_value=layer_scale_init_value, 521 | proposal_w=proposal_w[0], proposal_h=proposal_h[0], proposal_d=proposal_d[0], 522 | fold_w=fold_w[0], fold_h=fold_h[0], fold_d=fold_d[0], 523 | heads=heads[0], head_dim=head_dim[0], return_center=False) 524 | # en1 525 | self.down1_PET = PointRecuder(patch_size=down_patch_size, stride=down_stride, padding=down_pad, 526 | in_chans=embed_dims[0], embed_dim=embed_dims[1]) 527 | self.down1_MRI = PointRecuder(patch_size=down_patch_size, stride=down_stride, padding=down_pad, 528 | in_chans=embed_dims[0], embed_dim=embed_dims[1]) 529 | self.en1 = basic_blocks_Multi(embed_dims[1], 1, layers, mlp_ratio=mlp_ratios[1], act_layer=act_layer, 530 | norm_layer=norm_layer, drop_rate=drop_rate, use_layer_scale=use_layer_scale, 531 | layer_scale_init_value=layer_scale_init_value, 532 | proposal_w=proposal_w[1], proposal_h=proposal_h[1], proposal_d=proposal_d[1], 533 | fold_w=fold_w[1], fold_h=fold_h[1], fold_d=fold_d[1], 534 | heads=heads[1], head_dim=head_dim[1], return_center=False) 535 | # en2 536 | self.down2_PET = PointRecuder(patch_size=down_patch_size, stride=down_stride, padding=down_pad, 537 | in_chans=embed_dims[1], embed_dim=embed_dims[2]) 538 | self.down2_MRI = PointRecuder(patch_size=down_patch_size, stride=down_stride, padding=down_pad, 539 | in_chans=embed_dims[1], embed_dim=embed_dims[2]) 540 | self.en2 = basic_blocks_Multi(embed_dims[2], 2, layers, mlp_ratio=mlp_ratios[2], act_layer=act_layer, 541 | norm_layer=norm_layer, drop_rate=drop_rate, use_layer_scale=use_layer_scale, 542 | layer_scale_init_value=layer_scale_init_value, 543 | 
                                      proposal_w=proposal_w[2], proposal_h=proposal_h[2], proposal_d=proposal_d[2],
544 |                                       fold_w=fold_w[2], fold_h=fold_h[2], fold_d=fold_d[2],
545 |                                       heads=heads[2], head_dim=head_dim[2], return_center=False)
546 |         # en3
547 |         self.down3_PET = PointRecuder(patch_size=down_patch_size, stride=down_stride, padding=down_pad,
548 |                                       in_chans=embed_dims[2], embed_dim=embed_dims[3])
549 |         self.down3_MRI = PointRecuder(patch_size=down_patch_size, stride=down_stride, padding=down_pad,
550 |                                       in_chans=embed_dims[2], embed_dim=embed_dims[3])
551 |         self.en3 = basic_blocks_Multi(embed_dims[3], 3, layers, mlp_ratio=mlp_ratios[3], act_layer=act_layer,
552 |                                       norm_layer=norm_layer, drop_rate=drop_rate, use_layer_scale=use_layer_scale,
553 |                                       layer_scale_init_value=layer_scale_init_value,
554 |                                       proposal_w=proposal_w[3], proposal_h=proposal_h[3], proposal_d=proposal_d[3],
555 |                                       fold_w=fold_w[3], fold_h=fold_h[3], fold_d=fold_d[3],
556 |                                       heads=heads[3], head_dim=head_dim[3], return_center=False)
557 | 
558 |         """Bottleneck"""
559 |         self.bot = basic_blocks_Multi(embed_dims[3], 3, layers, mlp_ratio=mlp_ratios[3], act_layer=act_layer,
560 |                                       norm_layer=norm_layer, drop_rate=drop_rate, use_layer_scale=use_layer_scale,
561 |                                       layer_scale_init_value=layer_scale_init_value,
562 |                                       proposal_w=proposal_w[3], proposal_h=proposal_h[3], proposal_d=proposal_d[3],
563 |                                       fold_w=fold_w[3], fold_h=fold_h[3], fold_d=fold_d[3],
564 |                                       heads=heads[3], head_dim=head_dim[3], return_center=False)
565 |         """Decoder"""
566 |         # de0
567 |         self.de0 = basic_blocks(embed_dims[3], 3, layers, mlp_ratio=mlp_ratios[3], act_layer=act_layer,
568 |                                 norm_layer=norm_layer, drop_rate=drop_rate, use_layer_scale=use_layer_scale,
569 |                                 layer_scale_init_value=layer_scale_init_value,
570 |                                 proposal_w=proposal_w[3], proposal_h=proposal_h[3], proposal_d=proposal_d[3],
571 |                                 fold_w=fold_w[3], fold_h=fold_h[3], fold_d=fold_d[3],
572 |                                 heads=heads[3], head_dim=head_dim[3], return_center=False)
573 |         self.up0 = PointExpander(patch_size=up_patch_size, stride=up_stride, padding=up_pad,
574 |                                  in_chans=embed_dims[3], embed_dim=embed_dims[2])
575 |         # de1
576 |         self.de1 = basic_blocks(embed_dims[2], 2, layers, mlp_ratio=mlp_ratios[2], act_layer=act_layer,
577 |                                 norm_layer=norm_layer, drop_rate=drop_rate, use_layer_scale=use_layer_scale,
578 |                                 layer_scale_init_value=layer_scale_init_value,
579 |                                 proposal_w=proposal_w[2], proposal_h=proposal_h[2], proposal_d=proposal_d[2],
580 |                                 fold_w=fold_w[2], fold_h=fold_h[2], fold_d=fold_d[2],
581 |                                 heads=heads[2], head_dim=head_dim[2], return_center=False)
582 |         self.up1 = PointExpander(patch_size=up_patch_size, stride=up_stride, padding=up_pad,
583 |                                  in_chans=embed_dims[2], embed_dim=embed_dims[1])
584 |         # de2
585 |         self.de2 = basic_blocks(embed_dims[1], 1, layers, mlp_ratio=mlp_ratios[1], act_layer=act_layer,
586 |                                 norm_layer=norm_layer, drop_rate=drop_rate, use_layer_scale=use_layer_scale,
587 |                                 layer_scale_init_value=layer_scale_init_value,
588 |                                 proposal_w=proposal_w[1], proposal_h=proposal_h[1], proposal_d=proposal_d[1],
589 |                                 fold_w=fold_w[1], fold_h=fold_h[1], fold_d=fold_d[1],
590 |                                 heads=heads[1], head_dim=head_dim[1], return_center=False)
591 |         self.up2 = PointExpander(patch_size=up_patch_size, stride=up_stride, padding=up_pad,
592 |                                  in_chans=embed_dims[1], embed_dim=embed_dims[0])
593 |         # de3
594 |         self.de3 = basic_blocks(embed_dims[0], 0, layers, mlp_ratio=mlp_ratios[0], act_layer=act_layer,
595 |                                 norm_layer=norm_layer, drop_rate=drop_rate, use_layer_scale=use_layer_scale,
596 |                                 layer_scale_init_value=layer_scale_init_value,
597 | 
proposal_w=proposal_w[0], proposal_h=proposal_h[0], proposal_d=proposal_d[0], 598 | fold_w=fold_w[0], fold_h=fold_h[0], fold_d=fold_d[0], 599 | heads=heads[0], head_dim=head_dim[0], return_center=False) 600 | self.patch_expand = nn.Sequential( 601 | PointExpander(patch_size=up_patch_size, stride=up_stride, padding=up_pad, 602 | in_chans=embed_dims[0], embed_dim=3), 603 | nn.Conv3d(in_channels=3, out_channels=1, kernel_size=1, stride=1), 604 | nn.LeakyReLU(), 605 | nn.BatchNorm3d(1), 606 | ) 607 | 608 | # # add a norm layer for each output 609 | # self.out_indices = [0, 2, 4, 6] 610 | # for i_emb, i_layer in enumerate(self.out_indices): 611 | # if i_emb == 0 and os.environ.get('FORK_LAST3', None): 612 | # """For RetinaNet, `start_level=1`. The first norm layer will not used. 613 | # """ 614 | # layer = nn.Identity() 615 | # else: 616 | # layer = norm_layer(embed_dims[i_emb]) 617 | # layer_name = f'norm{i_layer}' 618 | # self.add_module(layer_name, layer) 619 | 620 | def forward_embeddings(self, x): 621 | _, c, img_w, img_h, img_d = x.shape 622 | # print(f"img size is {c} * {img_w} * {img_h}") 623 | # register positional information buffer. 624 | range_w = torch.arange(0, img_w, step=1) / (img_w - 1.0) 625 | range_h = torch.arange(0, img_h, step=1) / (img_h - 1.0) 626 | range_d = torch.arange(0, img_d, step=1) / (img_d - 1.0) 627 | fea_pos = torch.stack(torch.meshgrid(range_w, range_h, range_d), dim=-1).float() 628 | fea_pos = fea_pos.to(x.device) 629 | fea_pos = fea_pos - 0.5 630 | # print('fea_pos ', fea_pos.shape) 631 | pos = fea_pos.permute(3, 0, 1, 2).unsqueeze(dim=0).expand(x.shape[0], -1, -1, -1, -1) 632 | # print('pos ', pos.shape) 633 | x = self.patch_embed(torch.cat([x, pos], dim=1)) 634 | # print('x ', x.shape) 635 | return x 636 | 637 | def restore_embeddings(self, x): 638 | x = self.patch_expand(x) 639 | 640 | return x 641 | 642 | def forward(self, PET, MRI): 643 | # encoder 644 | en0_PET = self.forward_embeddings(PET) 645 | en0_MRI = self.forward_embeddings(MRI) 646 | en1_PET, en1_MIX, en1_MRI = self.en0((en0_PET, en0_MRI)) 647 | en1_PET = en1_PET + en1_MIX 648 | 649 | en1_PET = self.down1_PET(en1_PET) 650 | en1_MRI = self.down1_MRI(en1_MRI) 651 | en2_PET, en2_MIX, en2_MRI = self.en1((en1_PET, en1_MRI)) 652 | en2_PET = en2_PET + en2_MIX 653 | 654 | en2_PET = self.down2_PET(en2_PET) 655 | en2_MRI = self.down2_MRI(en2_MRI) 656 | en3_PET, en3_MIX, en3_MRI = self.en2((en2_PET, en2_MRI)) 657 | en3_PET = en3_PET + en3_MIX 658 | 659 | en3_PET = self.down3_PET(en3_PET) 660 | en3_MRI = self.down3_MRI(en3_MRI) 661 | en4_PET, en4_MIX, en4_MRI = self.en3((en3_PET, en3_MRI)) 662 | en4_PET = en4_PET + en4_MIX 663 | 664 | # bottleneck 665 | _, en_bot, _ = self.bot((en4_PET, en4_MRI)) 666 | 667 | # decoder 668 | de0 = self.de0(en_bot) 669 | de0 = self.up0(de0) + en2_PET 670 | 671 | de1 = self.de1(de0) 672 | de1 = self.up1(de1) + en1_PET 673 | 674 | de2 = self.de2(de1) 675 | de2 = self.up2(de2) + en0_PET 676 | 677 | de3 = self.de3(de2) 678 | 679 | # output 680 | de3 = de3 + en0_PET 681 | de3 = self.patch_expand(de3) 682 | 683 | # output = de3 + PET 684 | output = de3 685 | 686 | return output 687 | 688 | 689 | @register_model 690 | def pmccgen(**kwargs): 691 | # sharing same parameters as coc_tiny, without region partition. 
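    # Configuration notes: fold_w = fold_h = fold_d = 1 keeps the whole volume as a single region
    # (no region partition), the 2*2*2 proposals give 8 cluster centers per stage, and the embedding
    # dims are half of the pccgen values (the original 64/128/256/512 are kept in the inline comment below).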
692 | layers = [1, 1, 1, 1] 693 | norm_layer = GroupNorm 694 | embed_dims = [32, 64, 128, 256] # 64, 128, 256, 512 695 | mlp_ratios = [8, 8, 4, 4] 696 | downsamples = [True, True, True, True] 697 | proposal_w = [2, 2, 2, 2] 698 | proposal_h = [2, 2, 2, 2] 699 | proposal_d = [2, 2, 2, 2] 700 | fold_w = [1, 1, 1, 1] 701 | fold_h = [1, 1, 1, 1] 702 | fold_d = [1, 1, 1, 1] 703 | heads = [4, 4, 8, 8] 704 | head_dim = [24, 24, 24, 24] 705 | down_patch_size = 3 706 | down_pad = 1 707 | model = ContextClusterMulti( 708 | layers, embed_dims=embed_dims, norm_layer=norm_layer, 709 | mlp_ratios=mlp_ratios, downsamples=downsamples, 710 | down_patch_size=down_patch_size, down_pad=down_pad, 711 | proposal_w=proposal_w, proposal_h=proposal_h, proposal_d=proposal_d, 712 | fold_w=fold_w, fold_h=fold_h, fold_d=fold_d, 713 | heads=heads, head_dim=head_dim, 714 | **kwargs) 715 | 716 | return model 717 | 718 | 719 | if __name__ == '__main__': 720 | input = torch.rand(2, 1, 64, 64, 64) 721 | model = pmccgen() 722 | out = model(input, input) 723 | print(out.shape) 724 | --------------------------------------------------------------------------------