├── CSANet ├── datasets │ └── dataset_CSANet.py ├── loss_function.py ├── networks │ ├── __pycache__ │ │ ├── vit_seg_configs.cpython-39.pyc │ │ ├── vit_seg_modeling.cpython-39.pyc │ │ ├── vit_seg_modeling32.cpython-39.pyc │ │ ├── vit_seg_modeling4.cpython-39.pyc │ │ ├── vit_seg_modeling4_wm.cpython-39.pyc │ │ ├── vit_seg_modeling8.cpython-39.pyc │ │ ├── vit_seg_modeling_SA.cpython-39.pyc │ │ ├── vit_seg_modeling_og.cpython-39.pyc │ │ ├── vit_seg_modeling_resnet_skip.cpython-39.pyc │ │ └── vit_seg_modeling_wcs.cpython-39.pyc │ ├── vit_seg_configs.py │ ├── vit_seg_modeling.py │ └── vit_seg_modeling_resnet_skip.py ├── test.py ├── train.py ├── trainer.py ├── utils.py └── visualization.py ├── LICENSE ├── README.md ├── data └── README.md ├── model └── vit_checkpoint │ └── imagenet21k │ └── README.md ├── preprocessing.py ├── requirements.txt └── utils └── lists ├── test_vol.txt ├── train_image.txt └── train_mask.txt /CSANet/datasets/dataset_CSANet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import h5py 4 | import numpy as np 5 | import torch 6 | from scipy import ndimage 7 | from scipy.ndimage.interpolation import zoom 8 | from torch.utils.data import Dataset 9 | import SimpleITK as sitk 10 | import os 11 | import SimpleITK as sitk 12 | from PIL import Image 13 | import numpy as np 14 | import cv2 15 | 16 | def random_horizontal_flip(image,next_image, prev_image, segmentation): 17 | # Generate a random number to decide whether to flip or not 18 | flip = random.choice([True, False]) 19 | 20 | # Perform horizontal flipping if flip is True 21 | if flip: 22 | flipped_image = np.fliplr(image) 23 | flipped_next_image = np.fliplr(next_image) 24 | flipped_prev_image = np.fliplr(prev_image) 25 | flipped_segmentation = np.fliplr(segmentation) 26 | else: 27 | flipped_image = image 28 | flipped_next_image = next_image 29 | flipped_prev_image = prev_image 30 | flipped_segmentation = segmentation 31 | 32 | return flipped_image,flipped_next_image,flipped_prev_image,flipped_segmentation 33 | 34 | 35 | 36 | class RandomGenerator(object): 37 | """ 38 | Applies random transformations to a sample including horizontal flips and resizing to a target size. 39 | 40 | Parameters: 41 | output_size (tuple): Desired output dimensions (height, width) for the images and labels. 42 | """ 43 | def __init__(self, output_size): 44 | self.output_size = output_size 45 | 46 | def __call__(self, sample): 47 | # Unpack the sample dictionary to individual components 48 | image, label = sample['image'], sample['label'] 49 | next_image, prev_image = sample['next_image'], sample['prev_image'] 50 | 51 | # Apply a random horizontal flip to the images and label 52 | image,next_image, prev_image, label = random_horizontal_flip(image, next_image, prev_image, label) 53 | # Check if the current size matches the desired output size 54 | x, y = image.shape 55 | if x != self.output_size[0] or y != self.output_size[1]: 56 | # Rescale images to match the specified output size using cubic interpolation 57 | image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y), order=3) # why not 3? 
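# Note: order=3 gives cubic interpolation, which suits the grayscale intensity
# slices; the label further below is resized with order=0 (nearest neighbour)
# so that no fractional class indices are created by the resampling.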
58 | next_image = zoom(next_image, (self.output_size[0] / x, self.output_size[1] / y), order=3) 59 | prev_image = zoom(prev_image, (self.output_size[0] / x, self.output_size[1] / y), order=3) 60 | # Rescale the label using nearest neighbor interpolation (order=0) to avoid creating new labels 61 | label = zoom(label, (self.output_size[0] / x, self.output_size[1] / y), order=0) 62 | 63 | # Convert numpy arrays to PyTorch tensors and add a channel dimension to images 64 | image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0) 65 | label = torch.from_numpy(label.astype(np.float32)) 66 | next_image = torch.from_numpy(next_image.astype(np.float32)).unsqueeze(0) 67 | prev_image = torch.from_numpy(prev_image.astype(np.float32)).unsqueeze(0) 68 | # Return the modified sample as a dictionary 69 | sample = {'image': image, 'next_image': next_image, 'prev_image': prev_image, 'label': label.long()} 70 | return sample 71 | 72 | 73 | def extract_and_increase_number(file_name): 74 | """ 75 | Generates the filenames for the next and previous sequence by incrementing and decrementing the numerical part of a given filename. 76 | 77 | Parameters: 78 | file_name (str): The original filename from which to derive the next and previous filenames. 79 | The filename must end with a numerical value preceded by an underscore. 80 | 81 | Returns: 82 | tuple: Contains two strings, the first being the next filename in sequence and the second 83 | the previous filename in sequence. If the original number is 0, the previous filename 84 | will also use 0 to avoid negative numbering. 85 | """ 86 | parts = file_name.rsplit("_", 1) 87 | parts_next = parts[0] 88 | parts_prev = parts[0] 89 | number = int(parts[1]) 90 | 91 | next_number = number + 1 92 | prev_number = number - 1 93 | if prev_number== -1: 94 | pre_number = 0 95 | 96 | next_numbers = str(next_number) 97 | prev_numbers = str(prev_number) 98 | next_file_name = parts_next+"_"+str(next_numbers) 99 | prev_file_name = parts_prev+"_"+str(prev_numbers) 100 | 101 | return next_file_name,prev_file_name 102 | 103 | 104 | 105 | def check_and_create_file(file_name, image_name, folder_path): 106 | file_path = os.path.join(folder_path, "trainingImages", file_name+'.npy') 107 | if os.path.exists(file_path): 108 | return file_name 109 | else: 110 | available_name = image_name 111 | return available_name 112 | 113 | 114 | class CSANet_dataset(Dataset): 115 | """ 116 | Dataset handler for CSANet, designed to manage image and mask data for training and testing phases. 117 | 118 | Attributes: 119 | base_dir (str): Directory where image and mask data are stored. 120 | list_dir (str): Directory where the lists of data splits are located. 121 | split (str): The current dataset split, indicating training or testing phase. 122 | transform (callable, optional): A function/transform to apply to the samples. 123 | 124 | Note: 125 | This class expects directory structures and file naming conventions that match the specifics 126 | given in the initialization arguments. 127 | """ 128 | 129 | def __init__(self, base_dir, list_dir, split, transform=None): 130 | self.transform = transform # using transform in torch! 
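# Illustrative usage of this dataset (paths, list names, and sizes are assumptions
# mirroring the defaults in test.py, not a verbatim excerpt of trainer.py):
#   db_train = CSANet_dataset(base_dir='../data/train_npz', list_dir='./lists',
#                             split='train_image',
#                             transform=transforms.Compose(
#                                 [RandomGenerator(output_size=[224, 224])]))
#   loader = DataLoader(db_train, batch_size=16, shuffle=True)
# Each sample is a dict with 'image', 'next_image', 'prev_image', and 'label'.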
131 | self.split = split 132 | self.sample_list = open(os.path.join(list_dir, self.split+'.txt')).readlines() 133 | self.image_sample_list = open(os.path.join(list_dir, 'train_image.txt')).readlines() 134 | self.mask_sample_list = open(os.path.join(list_dir, 'train_mask.txt')).readlines() 135 | self.data_dir = base_dir 136 | 137 | def __len__(self): 138 | return len(self.sample_list) 139 | 140 | def __getitem__(self, idx): 141 | if self.split == "train_image" or self.split == "train_image_train" or self.split == "train_image_test": 142 | 143 | slice_name = self.image_sample_list[idx].strip('\n') 144 | image_data_path = os.path.join(self.data_dir, "trainingImages", slice_name+'.npy') 145 | image = np.load(image_data_path) 146 | #print("##################################### image path = ", image_data_path) 147 | # Manage sequence continuity by fetching adjacent slices 148 | next_file_name, prev_file_name = extract_and_increase_number(slice_name) 149 | 150 | next_file_name = check_and_create_file (next_file_name, slice_name, self.data_dir) 151 | prev_file_name = check_and_create_file (prev_file_name, slice_name, self.data_dir) 152 | 153 | 154 | next_image_path = os.path.join(self.data_dir, "trainingImages", next_file_name +'.npy') 155 | prev_image_path = os.path.join(self.data_dir, "trainingImages", prev_file_name +'.npy') 156 | 157 | next_image = np.load(next_image_path) 158 | prev_image = np.load(prev_image_path) 159 | 160 | 161 | mask_name = self.mask_sample_list[idx].strip('\n') 162 | label_data_path = os.path.join(self.data_dir, "trainingMasks", mask_name+'.npy') 163 | #print("############################################# label path = ", label_data_path) 164 | label = np.load(label_data_path) 165 | 166 | sample = {'image': image, 'next_image': next_image, 'prev_image': prev_image, 'label': label} 167 | 168 | if self.transform: 169 | sample = self.transform(sample) # Apply transformations if specified 170 | sample['case_name'] = self.sample_list[idx].strip('\n') 171 | return sample 172 | else: 173 | # Handling testing data, assuming single volume processing 174 | vol_name = self.sample_list[idx].strip('\n') 175 | image_data_path = os.path.join(self.data_dir, "testVol", vol_name) 176 | label_data_path = os.path.join(self.data_dir, "testMask", vol_name) 177 | 178 | image_new = sitk.ReadImage(image_data_path) 179 | img = sitk.GetArrayFromImage(image_new) 180 | 181 | 182 | next_image = sitk.GetArrayFromImage(image_new).astype(np.float64) 183 | prev_image = sitk.GetArrayFromImage(image_new).astype(np.float64) 184 | 185 | # Preprocess image data for testing phase 186 | combined_slices = sitk.GetArrayFromImage(image_new).astype(np.float64) 187 | 188 | 189 | for i in range(img.shape[0]): 190 | img_array = img[i, :, :].astype(np.uint8) 191 | p1 = np.percentile(img_array, 1) 192 | p99 = np.percentile(img_array, 99) 193 | 194 | normalized_img = (img_array - p1) / (p99 - p1) 195 | normalized_img = np.clip(normalized_img, 0, 1) 196 | 197 | combined_slices[i,:,:] = normalized_img 198 | 199 | if i-1 > -1 : 200 | next_image[i-1,:,:] = combined_slices[i,:,:] 201 | 202 | if i-1<0: 203 | prev_image[i,:,:] = combined_slices[i,:,:] 204 | else : 205 | prev_image[i,:,:] = combined_slices[i-1,:,:] 206 | 207 | next_image[img.shape[0]-1,:,:] = combined_slices[img.shape[0]-1,:,:] 208 | 209 | segmentation = sitk.ReadImage(label_data_path) 210 | label = sitk.GetArrayFromImage(segmentation) 211 | sample = {'image': combined_slices, 'next_image': next_image, 'prev_image': prev_image, 'label': label} 212 | if 
self.transform: 213 | sample = self.transform(sample) # Apply transformations if specified 214 | num_string = self.sample_list[idx].strip('\n') 215 | case_num = num_string.split('.')[0] 216 | sample['case_name'] = case_num 217 | return sample 218 | 219 | 220 | -------------------------------------------------------------------------------- /CSANet/loss_function.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from medpy import metric 4 | from scipy.ndimage import zoom 5 | import torch.nn as nn 6 | import SimpleITK as sitk 7 | import os 8 | import nibabel as nib 9 | from skimage.measure import label, regionprops 10 | import scipy.ndimage as ndi 11 | import math 12 | 13 | class DiceLoss(nn.Module): 14 | """ 15 | Implements a Dice loss for evaluating segmentation performance, where Dice loss is a measure of overlap 16 | between two samples and can be used as a loss function for training deep learning models for segmentation tasks. 17 | 18 | Attributes: 19 | - n_classes (int): Number of classes for segmentation. 20 | """ 21 | def __init__(self, n_classes): 22 | """ 23 | Initializes the DiceLoss module with the number of classes. 24 | 25 | Parameters: 26 | - n_classes (int): Number of segmentation classes. 27 | """ 28 | super(DiceLoss, self).__init__() 29 | self.n_classes = n_classes 30 | 31 | def _one_hot_encoder(self, input_tensor): 32 | """ 33 | Converts a tensor of indices of a categorical variable into a one-hot encoded format. 34 | 35 | Parameters: 36 | - input_tensor (torch.Tensor): Tensor containing indices that will be one-hot encoded. 37 | 38 | Returns: 39 | - torch.Tensor: One-hot encoded tensor. 40 | """ 41 | tensor_list = [] 42 | for i in range(self.n_classes): 43 | temp_prob = input_tensor == i 44 | tensor_list.append(temp_prob.unsqueeze(1)) 45 | output_tensor = torch.cat(tensor_list, dim=1) 46 | return output_tensor.float() 47 | 48 | def _dice_loss(self, score, target): 49 | """ 50 | Computes the Dice loss between the predicted scores and the one-hot encoded target. 51 | 52 | Parameters: 53 | - score (torch.Tensor): Predicted scores for each class. 54 | - target (torch.Tensor): One-hot encoded true labels. 55 | 56 | Returns: 57 | - float: Dice loss value. 58 | """ 59 | target = target.float() 60 | smooth = 1e-5 61 | intersect = torch.sum(score * target) 62 | y_sum = torch.sum(target * target) 63 | z_sum = torch.sum(score * score) 64 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) 65 | loss = 1 - loss 66 | return loss 67 | 68 | def forward(self, inputs, target, weight=None, softmax=False): 69 | """ 70 | Forward pass for calculating Dice loss for multiple classes. 71 | 72 | Parameters: 73 | - inputs (torch.Tensor): Input logits or softmax predictions. 74 | - target (torch.Tensor): Ground truth labels. 75 | - weight (list of float, optional): Class weights. 76 | - softmax (bool, optional): Whether to apply softmax to inputs before calculating loss. 77 | 78 | Returns: 79 | - float: Mean Dice loss across all classes. 
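        Example (illustrative; assumes raw logits of shape [B, n_classes, H, W] and
        integer labels of shape [B, H, W]):
            dice_loss = DiceLoss(n_classes=5)
            loss = dice_loss(logits, labels, softmax=True)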
80 | """ 81 | if softmax: 82 | inputs = torch.softmax(inputs, dim=1) 83 | target = self._one_hot_encoder(target) 84 | if weight is None: 85 | weight = [1] * self.n_classes 86 | assert inputs.size() == target.size(), 'predict {} & target {} shape do not match'.format(inputs.size(), target.size()) 87 | class_wise_dice = [] 88 | loss = 0.0 89 | for i in range(0, self.n_classes): 90 | dice = self._dice_loss(inputs[:, i], target[:, i]) 91 | class_wise_dice.append(1.0 - dice.item()) 92 | loss += dice 93 | return loss / (self.n_classes - 1) -------------------------------------------------------------------------------- /CSANet/networks/__pycache__/vit_seg_configs.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mirthAI/CSA-Net/9be2dbe8d2247ab91d03f18bd8af92448a675ff9/CSANet/networks/__pycache__/vit_seg_configs.cpython-39.pyc -------------------------------------------------------------------------------- /CSANet/networks/__pycache__/vit_seg_modeling.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mirthAI/CSA-Net/9be2dbe8d2247ab91d03f18bd8af92448a675ff9/CSANet/networks/__pycache__/vit_seg_modeling.cpython-39.pyc -------------------------------------------------------------------------------- /CSANet/networks/__pycache__/vit_seg_modeling32.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mirthAI/CSA-Net/9be2dbe8d2247ab91d03f18bd8af92448a675ff9/CSANet/networks/__pycache__/vit_seg_modeling32.cpython-39.pyc -------------------------------------------------------------------------------- /CSANet/networks/__pycache__/vit_seg_modeling4.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mirthAI/CSA-Net/9be2dbe8d2247ab91d03f18bd8af92448a675ff9/CSANet/networks/__pycache__/vit_seg_modeling4.cpython-39.pyc -------------------------------------------------------------------------------- /CSANet/networks/__pycache__/vit_seg_modeling4_wm.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mirthAI/CSA-Net/9be2dbe8d2247ab91d03f18bd8af92448a675ff9/CSANet/networks/__pycache__/vit_seg_modeling4_wm.cpython-39.pyc -------------------------------------------------------------------------------- /CSANet/networks/__pycache__/vit_seg_modeling8.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mirthAI/CSA-Net/9be2dbe8d2247ab91d03f18bd8af92448a675ff9/CSANet/networks/__pycache__/vit_seg_modeling8.cpython-39.pyc -------------------------------------------------------------------------------- /CSANet/networks/__pycache__/vit_seg_modeling_SA.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mirthAI/CSA-Net/9be2dbe8d2247ab91d03f18bd8af92448a675ff9/CSANet/networks/__pycache__/vit_seg_modeling_SA.cpython-39.pyc -------------------------------------------------------------------------------- /CSANet/networks/__pycache__/vit_seg_modeling_og.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mirthAI/CSA-Net/9be2dbe8d2247ab91d03f18bd8af92448a675ff9/CSANet/networks/__pycache__/vit_seg_modeling_og.cpython-39.pyc -------------------------------------------------------------------------------- /CSANet/networks/__pycache__/vit_seg_modeling_resnet_skip.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mirthAI/CSA-Net/9be2dbe8d2247ab91d03f18bd8af92448a675ff9/CSANet/networks/__pycache__/vit_seg_modeling_resnet_skip.cpython-39.pyc -------------------------------------------------------------------------------- /CSANet/networks/__pycache__/vit_seg_modeling_wcs.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mirthAI/CSA-Net/9be2dbe8d2247ab91d03f18bd8af92448a675ff9/CSANet/networks/__pycache__/vit_seg_modeling_wcs.cpython-39.pyc -------------------------------------------------------------------------------- /CSANet/networks/vit_seg_configs.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | def get_b16_config(): 4 | """Returns the ViT-B/16 configuration.""" 5 | config = ml_collections.ConfigDict() 6 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 7 | config.hidden_size = 768 8 | config.transformer = ml_collections.ConfigDict() 9 | config.transformer.mlp_dim = 3072 10 | config.transformer.num_heads = 12 11 | config.transformer.num_layers = 12 12 | config.transformer.attention_dropout_rate = 0.0 13 | config.transformer.dropout_rate = 0.1 14 | 15 | config.classifier = 'seg' 16 | config.representation_size = None 17 | config.resnet_pretrained_path = None 18 | config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_16.npz' 19 | config.patch_size = 16 20 | 21 | config.decoder_channels = (256, 128, 64, 16) 22 | config.n_classes = 2 23 | config.activation = 'softmax' 24 | return config 25 | 26 | 27 | def get_testing(): 28 | """Returns a minimal configuration for testing.""" 29 | config = ml_collections.ConfigDict() 30 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 31 | config.hidden_size = 1 32 | config.transformer = ml_collections.ConfigDict() 33 | config.transformer.mlp_dim = 1 34 | config.transformer.num_heads = 1 35 | config.transformer.num_layers = 1 36 | config.transformer.attention_dropout_rate = 0.0 37 | config.transformer.dropout_rate = 0.1 38 | config.classifier = 'token' 39 | config.representation_size = None 40 | return config 41 | 42 | def get_r50_b16_config(): 43 | """Returns the Resnet50 + ViT-B/16 configuration.""" 44 | config = get_b16_config() 45 | config.patches.grid = (16, 16) 46 | config.resnet = ml_collections.ConfigDict() 47 | config.resnet.num_layers = (3, 4, 9) 48 | config.resnet.width_factor = 1 49 | 50 | config.classifier = 'seg' 51 | config.pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz' 52 | config.decoder_channels = (256, 128, 64, 16) 53 | config.skip_channels = [512, 256, 64, 16] 54 | config.n_classes = 2 55 | config.n_skip = 3 56 | config.activation = 'softmax' 57 | 58 | return config 59 | 60 | 61 | def get_b32_config(): 62 | """Returns the ViT-B/32 configuration.""" 63 | config = get_b16_config() 64 | config.patches.size = (32, 32) 65 | config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_32.npz' 66 | return config 67 | 68 | 69 | def get_l16_config(): 70 | """Returns the ViT-L/16 configuration.""" 71 | config = 
ml_collections.ConfigDict() 72 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 73 | config.hidden_size = 1024 74 | config.transformer = ml_collections.ConfigDict() 75 | config.transformer.mlp_dim = 4096 76 | config.transformer.num_heads = 16 77 | config.transformer.num_layers = 24 78 | config.transformer.attention_dropout_rate = 0.0 79 | config.transformer.dropout_rate = 0.1 80 | config.representation_size = None 81 | 82 | # custom 83 | config.classifier = 'seg' 84 | config.resnet_pretrained_path = None 85 | config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-L_16.npz' 86 | config.decoder_channels = (256, 128, 64, 16) 87 | config.n_classes = 2 88 | config.activation = 'softmax' 89 | return config 90 | 91 | 92 | def get_r50_l16_config(): 93 | """Returns the Resnet50 + ViT-L/16 configuration. customized """ 94 | config = get_l16_config() 95 | config.patches.grid = (16, 16) 96 | config.resnet = ml_collections.ConfigDict() 97 | config.resnet.num_layers = (3, 4, 9) 98 | config.resnet.width_factor = 1 99 | 100 | config.classifier = 'seg' 101 | config.resnet_pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz' 102 | config.decoder_channels = (256, 128, 64, 16) 103 | config.skip_channels = [512, 256, 64, 16] 104 | config.n_classes = 2 105 | config.activation = 'softmax' 106 | return config 107 | 108 | 109 | def get_l32_config(): 110 | """Returns the ViT-L/32 configuration.""" 111 | config = get_l16_config() 112 | config.patches.size = (32, 32) 113 | return config 114 | 115 | 116 | def get_h14_config(): 117 | """Returns the ViT-L/16 configuration.""" 118 | config = ml_collections.ConfigDict() 119 | config.patches = ml_collections.ConfigDict({'size': (14, 14)}) 120 | config.hidden_size = 1280 121 | config.transformer = ml_collections.ConfigDict() 122 | config.transformer.mlp_dim = 5120 123 | config.transformer.num_heads = 16 124 | config.transformer.num_layers = 32 125 | config.transformer.attention_dropout_rate = 0.0 126 | config.transformer.dropout_rate = 0.1 127 | config.classifier = 'token' 128 | config.representation_size = None 129 | 130 | return config 131 | -------------------------------------------------------------------------------- /CSANet/networks/vit_seg_modeling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import copy 7 | import logging 8 | import math 9 | import torch.nn.functional as F 10 | from os.path import join as pjoin 11 | import torch 12 | import torch.nn as nn 13 | import numpy as np 14 | from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm 15 | from torch.nn.modules.utils import _pair 16 | from scipy import ndimage 17 | from . 
import vit_seg_configs as configs 18 | from .vit_seg_modeling_resnet_skip import ResNetV2 19 | 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | ATTENTION_Q = "MultiHeadDotProductAttention_1/query" 25 | ATTENTION_K = "MultiHeadDotProductAttention_1/key" 26 | ATTENTION_V = "MultiHeadDotProductAttention_1/value" 27 | ATTENTION_OUT = "MultiHeadDotProductAttention_1/out" 28 | FC_0 = "MlpBlock_3/Dense_0" 29 | FC_1 = "MlpBlock_3/Dense_1" 30 | ATTENTION_NORM = "LayerNorm_0" 31 | MLP_NORM = "LayerNorm_2" 32 | 33 | 34 | def np2th(weights, conv=False): 35 | """Possibly convert HWIO to OIHW.""" 36 | if conv: 37 | weights = weights.transpose([3, 2, 0, 1]) 38 | return torch.from_numpy(weights) 39 | 40 | 41 | def swish(x): 42 | return x * torch.sigmoid(x) 43 | 44 | 45 | ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish} 46 | 47 | # Attention module definition 48 | class Attention(nn.Module): 49 | def __init__(self, config, vis): 50 | super(Attention, self).__init__() 51 | self.vis = vis 52 | self.num_attention_heads = config.transformer["num_heads"] 53 | self.attention_head_size = int(config.hidden_size / self.num_attention_heads) 54 | self.all_head_size = self.num_attention_heads * self.attention_head_size 55 | 56 | self.query = Linear(config.hidden_size, self.all_head_size) 57 | self.key = Linear(config.hidden_size, self.all_head_size) 58 | self.value = Linear(config.hidden_size, self.all_head_size) 59 | 60 | self.out = Linear(config.hidden_size, config.hidden_size) 61 | self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"]) 62 | self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"]) 63 | 64 | self.softmax = Softmax(dim=-1) 65 | 66 | def transpose_for_scores(self, x): 67 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 68 | x = x.view(*new_x_shape) 69 | return x.permute(0, 2, 1, 3) 70 | 71 | def forward(self, hidden_states): 72 | mixed_query_layer = self.query(hidden_states) 73 | mixed_key_layer = self.key(hidden_states) 74 | mixed_value_layer = self.value(hidden_states) 75 | 76 | query_layer = self.transpose_for_scores(mixed_query_layer) 77 | key_layer = self.transpose_for_scores(mixed_key_layer) 78 | value_layer = self.transpose_for_scores(mixed_value_layer) 79 | 80 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 81 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 82 | attention_probs = self.softmax(attention_scores) 83 | weights = attention_probs if self.vis else None 84 | attention_probs = self.attn_dropout(attention_probs) 85 | 86 | context_layer = torch.matmul(attention_probs, value_layer) 87 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 88 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 89 | context_layer = context_layer.view(*new_context_layer_shape) 90 | attention_output = self.out(context_layer) 91 | attention_output = self.proj_dropout(attention_output) 92 | return attention_output, weights 93 | 94 | # Multi-Layer Perceptron (MLP) module definitio 95 | class Mlp(nn.Module): 96 | def __init__(self, config): 97 | super(Mlp, self).__init__() 98 | self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"]) 99 | self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size) 100 | self.act_fn = ACT2FN["gelu"] 101 | self.dropout = Dropout(config.transformer["dropout_rate"]) 102 | 103 | self._init_weights() 104 | 105 | def _init_weights(self): 106 | 
nn.init.xavier_uniform_(self.fc1.weight) 107 | nn.init.xavier_uniform_(self.fc2.weight) 108 | nn.init.normal_(self.fc1.bias, std=1e-6) 109 | nn.init.normal_(self.fc2.bias, std=1e-6) 110 | 111 | def forward(self, x): 112 | x = self.fc1(x) 113 | x = self.act_fn(x) 114 | x = self.dropout(x) 115 | x = self.fc2(x) 116 | x = self.dropout(x) 117 | return x 118 | 119 | # Non-Local Block for multi-cross attention 120 | class NLBlockND_multicross_block(nn.Module): 121 | """ 122 | Non-Local Block for multi-cross attention. 123 | 124 | Args: 125 | in_channels (int): Number of input channels. 126 | inter_channels (int, optional): Number of intermediate channels. Defaults to None. 127 | 128 | Attributes: 129 | in_channels (int): Number of input channels. 130 | inter_channels (int): Number of intermediate channels. 131 | g (nn.Conv2d): Convolutional layer for the 'g' branch. 132 | final (nn.Conv2d): Final convolutional layer. 133 | W_z (nn.Sequential): Sequential block containing a convolutional layer followed by batch normalization for weight 'z'. 134 | theta (nn.Conv2d): Convolutional layer for the 'theta' branch. 135 | phi (nn.Conv2d): Convolutional layer for the 'phi' branch. 136 | 137 | Methods: 138 | forward(x_thisBranch, x_otherBranch): Forward pass of the non-local block. 139 | 140 | """ 141 | def __init__(self, in_channels, inter_channels=None): 142 | super(NLBlockND_multicross_block, self).__init__() 143 | self.in_channels = in_channels 144 | self.inter_channels = inter_channels 145 | 146 | if self.inter_channels is None: 147 | self.inter_channels = in_channels // 2 148 | if self.inter_channels == 0: 149 | self.inter_channels = 1 150 | 151 | conv_nd = nn.Conv2d 152 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) 153 | bn = nn.BatchNorm2d 154 | 155 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1) 156 | self.final = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1) 157 | self.W_z = nn.Sequential( 158 | conv_nd(in_channels=self.inter_channels, out_channels=self.inter_channels, kernel_size=1), 159 | bn(self.inter_channels) 160 | ) 161 | 162 | nn.init.constant_(self.W_z[1].weight, 0) 163 | nn.init.constant_(self.W_z[1].bias, 0) 164 | 165 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1) 166 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1) 167 | 168 | def forward(self, x_thisBranch, x_otherBranch): 169 | batch_size = x_thisBranch.size(0) 170 | g_x = self.g(x_thisBranch).view(batch_size, self.inter_channels, -1) 171 | g_x = g_x.permute(0, 2, 1) 172 | 173 | theta_x = self.theta(x_thisBranch).view(batch_size, self.inter_channels, -1) 174 | phi_x = self.phi(x_otherBranch).view(batch_size, self.inter_channels, -1) 175 | phi_x = phi_x.permute(0, 2, 1) 176 | 177 | f = torch.matmul(phi_x, theta_x) 178 | f_div_C = F.softmax(f, dim=-1) 179 | 180 | y = torch.matmul(f_div_C, g_x) 181 | y = y.permute(0, 2, 1).contiguous() 182 | y = y.view(batch_size, self.inter_channels, *x_thisBranch.size()[2:]) 183 | 184 | z = self.W_z(y) 185 | return z 186 | 187 | # Multi-Cross Attention Block 188 | class NLBlockND_multicross(nn.Module): 189 | 190 | def __init__(self, in_channels, inter_channels=None): 191 | super(NLBlockND_multicross, self).__init__() 192 | self.in_channels = in_channels 193 | self.inter_channels = inter_channels 194 | 195 | if self.inter_channels is None: 196 | self.inter_channels = in_channels // 2 197 | if 
self.inter_channels == 0: 198 | self.inter_channels = 1 199 | self.cross_attention = NLBlockND_multicross_block(in_channels=1024, inter_channels=64) 200 | def forward(self, x_thisBranch, x_otherBranch): 201 | outputs = [] 202 | for i in range(16): 203 | cross_attention = NLBlockND_multicross_block(in_channels=1024, inter_channels=64) 204 | cross_attention = cross_attention.to('cuda') 205 | output = cross_attention(x_thisBranch,x_otherBranch) 206 | 207 | outputs.append(output) 208 | final_output = torch.cat(outputs, dim=1) 209 | #final_output = final_output + x_thisBranch #Changed 210 | return final_output 211 | 212 | 213 | class DoubleConv(nn.Module): 214 | def __init__(self, in_channels, out_channels, mid_channels=None): 215 | super().__init__() 216 | if not mid_channels: 217 | mid_channels = out_channels 218 | self.double_conv = nn.Sequential( 219 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), 220 | nn.BatchNorm2d(mid_channels), 221 | nn.ReLU(inplace=True), 222 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), 223 | nn.BatchNorm2d(out_channels), 224 | nn.ReLU(inplace=True) 225 | ) 226 | 227 | def forward(self, x): 228 | return self.double_conv(x) 229 | 230 | 231 | 232 | class DownCross(nn.Module): 233 | 234 | 235 | def __init__(self, in_channels, out_channels): 236 | super().__init__() 237 | self.maxpool_conv = nn.Sequential( 238 | DoubleConv(in_channels, out_channels) 239 | ) 240 | 241 | def forward(self, x): 242 | return self.maxpool_conv(x) 243 | 244 | 245 | # Embeddings module for constructing embeddings from patches and position embeddings 246 | class Embeddings(nn.Module): 247 | """Construct the embeddings from patch, position embeddings. 248 | """ 249 | def __init__(self, config, img_size, in_channels=3): 250 | super(Embeddings, self).__init__() 251 | self.hybrid = None 252 | self.hybrid_prev = None 253 | self.hybrid_next = None 254 | self.config = config 255 | img_size = _pair(img_size) 256 | 257 | self.cross_attention_multi_1 = NLBlockND_multicross(in_channels=1024, inter_channels=512) 258 | self.cross_attention_multi_2 = NLBlockND_multicross(in_channels=1024, inter_channels=512) 259 | self.cross_attention_multi_3 = NLBlockND_multicross(in_channels=1024, inter_channels=512) 260 | self.downcross_three = (DownCross(3072, 1024)) 261 | 262 | 263 | if config.patches.get("grid") is not None: # ResNet 264 | grid_size = config.patches["grid"] 265 | patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1]) 266 | patch_size_real = (patch_size[0] * 16, patch_size[1] * 16) 267 | n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1]) 268 | self.hybrid = True 269 | else: 270 | patch_size = _pair(config.patches["size"]) 271 | n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) 272 | self.hybrid = False 273 | self.hybrid_next = False 274 | self.hybrid_prev = False 275 | if self.hybrid: 276 | self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers, width_factor=config.resnet.width_factor) 277 | self.hybrid_model_prev = ResNetV2(block_units=config.resnet.num_layers, width_factor=config.resnet.width_factor) 278 | self.hybrid_model_next = ResNetV2(block_units=config.resnet.num_layers, width_factor=config.resnet.width_factor) 279 | in_channels = self.hybrid_model.width * 16 280 | self.patch_embeddings = Conv2d(in_channels=in_channels, 281 | out_channels=config.hidden_size, 282 | kernel_size=patch_size, 283 | stride=patch_size) 284 | 
self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size)) 285 | 286 | self.dropout = Dropout(config.transformer["dropout_rate"]) 287 | 288 | 289 | def forward(self, x_prev,x,x_next): 290 | if self.hybrid: 291 | 292 | x, features = self.hybrid_model(x) 293 | x_prev, features1 = self.hybrid_model(x_prev) 294 | x_next, features2 = self.hybrid_model(x_next) 295 | else: 296 | features = None 297 | 298 | 299 | xt1 = self.cross_attention_multi_1(x,x_next) 300 | xt2 = self.cross_attention_multi_2(x,x_prev) 301 | xt3 = self.cross_attention_multi_3(x,x) 302 | 303 | xt = torch.cat([xt1,xt3,xt2], dim=1) 304 | x = self.downcross_three(xt) 305 | 306 | x = self.patch_embeddings(x) # (B, hidden. n_patches^(1/2), n_patches^(1/2)) 307 | x = x.flatten(2) 308 | x = x.transpose(-1, -2) # (B, n_patches, hidden) 309 | 310 | embeddings = x + self.position_embeddings 311 | embeddings = self.dropout(embeddings) 312 | return embeddings, features 313 | 314 | 315 | # Transformer Block 316 | class Block(nn.Module): 317 | def __init__(self, config, vis): 318 | super(Block, self).__init__() 319 | self.hidden_size = config.hidden_size 320 | self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6) 321 | self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6) 322 | self.ffn = Mlp(config) 323 | self.attn = Attention(config, vis) 324 | 325 | def forward(self, x): 326 | h = x 327 | x = self.attention_norm(x) 328 | x, weights = self.attn(x) 329 | x = x + h 330 | 331 | h = x 332 | x = self.ffn_norm(x) 333 | x = self.ffn(x) 334 | x = x + h 335 | return x, weights 336 | 337 | def load_from(self, weights, n_block): 338 | ROOT = f"Transformer/encoderblock_{n_block}" 339 | with torch.no_grad(): 340 | query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t() 341 | key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t() 342 | value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t() 343 | out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t() 344 | 345 | query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1) 346 | key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1) 347 | value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1) 348 | out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1) 349 | 350 | self.attn.query.weight.copy_(query_weight) 351 | self.attn.key.weight.copy_(key_weight) 352 | self.attn.value.weight.copy_(value_weight) 353 | self.attn.out.weight.copy_(out_weight) 354 | self.attn.query.bias.copy_(query_bias) 355 | self.attn.key.bias.copy_(key_bias) 356 | self.attn.value.bias.copy_(value_bias) 357 | self.attn.out.bias.copy_(out_bias) 358 | 359 | mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t() 360 | mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t() 361 | mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t() 362 | mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t() 363 | 364 | self.ffn.fc1.weight.copy_(mlp_weight_0) 365 | self.ffn.fc2.weight.copy_(mlp_weight_1) 366 | self.ffn.fc1.bias.copy_(mlp_bias_0) 367 | self.ffn.fc2.bias.copy_(mlp_bias_1) 368 | 369 | self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")])) 370 | self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")])) 371 | 
self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")])) 372 | self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")])) 373 | 374 | 375 | # Transformer Encoder 376 | class Encoder(nn.Module): 377 | def __init__(self, config, vis): 378 | super(Encoder, self).__init__() 379 | self.vis = vis 380 | self.layer = nn.ModuleList() 381 | self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6) 382 | for _ in range(config.transformer["num_layers"]): 383 | layer = Block(config, vis) 384 | self.layer.append(copy.deepcopy(layer)) 385 | 386 | def forward(self, hidden_states): 387 | attn_weights = [] 388 | for layer_block in self.layer: 389 | hidden_states, weights = layer_block(hidden_states) 390 | if self.vis: 391 | attn_weights.append(weights) 392 | encoded = self.encoder_norm(hidden_states) 393 | return encoded, attn_weights 394 | 395 | # Transformer architecture 396 | class Transformer(nn.Module): 397 | def __init__(self, config, img_size, vis): 398 | super(Transformer, self).__init__() 399 | self.embeddings = Embeddings(config, img_size=img_size) 400 | self.encoder = Encoder(config, vis) 401 | 402 | def forward(self, x_prev,x,x_next): 403 | embedding_output, features = self.embeddings(x_prev,x,x_next) 404 | encoded, attn_weights = self.encoder(embedding_output) 405 | return encoded, attn_weights, features 406 | 407 | # Conv2dReLU module 408 | class Conv2dReLU(nn.Sequential): 409 | def __init__( 410 | self, 411 | in_channels, 412 | out_channels, 413 | kernel_size, 414 | padding=0, 415 | stride=1, 416 | use_batchnorm=True, 417 | ): 418 | conv = nn.Conv2d( 419 | in_channels, 420 | out_channels, 421 | kernel_size, 422 | stride=stride, 423 | padding=padding, 424 | bias=not (use_batchnorm), 425 | ) 426 | relu = nn.ReLU(inplace=True) 427 | 428 | bn = nn.BatchNorm2d(out_channels) 429 | 430 | super(Conv2dReLU, self).__init__(conv, bn, relu) 431 | 432 | # Decoder block for the segmentation head 433 | class DecoderBlock(nn.Module): 434 | def __init__( 435 | self, 436 | in_channels, 437 | out_channels, 438 | skip_channels=0, 439 | use_batchnorm=True, 440 | ): 441 | super().__init__() 442 | self.conv1 = Conv2dReLU( 443 | in_channels + skip_channels, 444 | out_channels, 445 | kernel_size=3, 446 | padding=1, 447 | use_batchnorm=use_batchnorm, 448 | ) 449 | self.conv2 = Conv2dReLU( 450 | out_channels, 451 | out_channels, 452 | kernel_size=3, 453 | padding=1, 454 | use_batchnorm=use_batchnorm, 455 | ) 456 | self.up = nn.UpsamplingBilinear2d(scale_factor=2) 457 | 458 | def forward(self, x, skip=None): 459 | x = self.up(x) 460 | if skip is not None: 461 | x = torch.cat([x, skip], dim=1) 462 | x = self.conv1(x) 463 | x = self.conv2(x) 464 | return x 465 | 466 | # Segmentation head module 467 | class SegmentationHead(nn.Sequential): 468 | 469 | def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1): 470 | conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2) 471 | upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity() 472 | super().__init__(conv2d, upsampling) 473 | 474 | # DecoderCup module 475 | class DecoderCup(nn.Module): 476 | def __init__(self, config): 477 | super().__init__() 478 | self.config = config 479 | head_channels = 512 480 | self.conv_more = Conv2dReLU( 481 | config.hidden_size, 482 | head_channels, 483 | kernel_size=3, 484 | padding=1, 485 | use_batchnorm=True, 486 | ) 487 | decoder_channels = config.decoder_channels 488 | in_channels = [head_channels] + 
list(decoder_channels[:-1]) 489 | out_channels = decoder_channels 490 | 491 | if self.config.n_skip != 0: 492 | skip_channels = self.config.skip_channels 493 | for i in range(4-self.config.n_skip): # re-select the skip channels according to n_skip 494 | skip_channels[3-i]=0 495 | 496 | else: 497 | skip_channels=[0,0,0,0] 498 | 499 | blocks = [ 500 | DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels) 501 | ] 502 | self.blocks = nn.ModuleList(blocks) 503 | 504 | def forward(self, hidden_states, features=None): 505 | B, n_patch, hidden = hidden_states.size() # reshape from (B, n_patch, hidden) to (B, h, w, hidden) 506 | h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch)) 507 | x = hidden_states.permute(0, 2, 1) 508 | x = x.contiguous().view(B, hidden, h, w) 509 | x = self.conv_more(x) 510 | for i, decoder_block in enumerate(self.blocks): 511 | if features is not None: 512 | skip = features[i] if (i < self.config.n_skip) else None 513 | else: 514 | skip = None 515 | x = decoder_block(x, skip=skip) 516 | return x 517 | 518 | 519 | 520 | # Vision Transformer model 521 | class VisionTransformer(nn.Module): 522 | def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False): 523 | super(VisionTransformer, self).__init__() 524 | self.num_classes = num_classes 525 | self.zero_head = zero_head 526 | self.classifier = config.classifier 527 | self.transformer = Transformer(config, img_size, vis) 528 | self.decoder = DecoderCup(config) 529 | self.segmentation_head = SegmentationHead( 530 | in_channels=config['decoder_channels'][-1], 531 | out_channels=config['n_classes'], 532 | kernel_size=3, 533 | ) 534 | self.config = config 535 | 536 | def forward(self, x_prev,x,x_next): 537 | if x.size()[1] == 1: 538 | x = x.repeat(1,3,1,1) 539 | x_prev = x_prev.repeat(1,3,1,1) 540 | x_next = x_next.repeat(1,3,1,1) 541 | x, attn_weights, features = self.transformer(x_prev,x,x_next) # (B, n_patch, hidden) 542 | x = self.decoder(x, features) 543 | logits = self.segmentation_head(x) 544 | return logits 545 | 546 | def load_from(self, weights): 547 | with torch.no_grad(): 548 | 549 | res_weight = weights 550 | self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True)) 551 | self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"])) 552 | 553 | self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"])) 554 | self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"])) 555 | 556 | posemb = np2th(weights["Transformer/posembed_input/pos_embedding"]) 557 | 558 | posemb_new = self.transformer.embeddings.position_embeddings 559 | if posemb.size() == posemb_new.size(): 560 | self.transformer.embeddings.position_embeddings.copy_(posemb) 561 | elif posemb.size()[1]-1 == posemb_new.size()[1]: 562 | posemb = posemb[:, 1:] 563 | self.transformer.embeddings.position_embeddings.copy_(posemb) 564 | else: 565 | logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size())) 566 | ntok_new = posemb_new.size(1) 567 | if self.classifier == "seg": 568 | _, posemb_grid = posemb[:, :1], posemb[0, 1:] 569 | gs_old = int(np.sqrt(len(posemb_grid))) 570 | gs_new = int(np.sqrt(ntok_new)) 571 | print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new)) 572 | posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1) 573 | zoom = (gs_new / gs_old, gs_new / gs_old, 1) 574 | 
posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1) # th2np 575 | posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1) 576 | posemb = posemb_grid 577 | self.transformer.embeddings.position_embeddings.copy_(np2th(posemb)) 578 | 579 | # Encoder whole 580 | for bname, block in self.transformer.encoder.named_children(): 581 | for uname, unit in block.named_children(): 582 | unit.load_from(weights, n_block=uname) 583 | 584 | if self.transformer.embeddings.hybrid: 585 | self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(res_weight["conv_root/kernel"], conv=True)) 586 | gn_weight = np2th(res_weight["gn_root/scale"]).view(-1) 587 | gn_bias = np2th(res_weight["gn_root/bias"]).view(-1) 588 | self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight) 589 | self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias) 590 | 591 | for bname, block in self.transformer.embeddings.hybrid_model.body.named_children(): 592 | for uname, unit in block.named_children(): 593 | unit.load_from(res_weight, n_block=bname, n_unit=uname) 594 | 595 | 596 | # Configuration dictionary for different Vision Transformer variants 597 | CONFIGS = { 598 | 'ViT-B_16': configs.get_b16_config(), 599 | 'ViT-B_32': configs.get_b32_config(), 600 | 'ViT-L_16': configs.get_l16_config(), 601 | 'ViT-L_32': configs.get_l32_config(), 602 | 'ViT-H_14': configs.get_h14_config(), 603 | 'R50-ViT-B_16': configs.get_r50_b16_config(), 604 | 'R50-ViT-L_16': configs.get_r50_l16_config(), 605 | 'testing': configs.get_testing(), 606 | } 607 | 608 | 609 | -------------------------------------------------------------------------------- /CSANet/networks/vit_seg_modeling_resnet_skip.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from os.path import join as pjoin 4 | from collections import OrderedDict 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | def np2th(weights, conv=False): 12 | """Possibly convert HWIO to OIHW.""" 13 | if conv: 14 | weights = weights.transpose([3, 2, 0, 1]) 15 | return torch.from_numpy(weights) 16 | 17 | 18 | class StdConv2d(nn.Conv2d): 19 | 20 | def forward(self, x): 21 | w = self.weight 22 | v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False) 23 | w = (w - m) / torch.sqrt(v + 1e-5) 24 | return F.conv2d(x, w, self.bias, self.stride, self.padding, 25 | self.dilation, self.groups) 26 | 27 | 28 | def conv3x3(cin, cout, stride=1, groups=1, bias=False): 29 | return StdConv2d(cin, cout, kernel_size=3, stride=stride, 30 | padding=1, bias=bias, groups=groups) 31 | 32 | 33 | def conv1x1(cin, cout, stride=1, bias=False): 34 | return StdConv2d(cin, cout, kernel_size=1, stride=stride, 35 | padding=0, bias=bias) 36 | 37 | 38 | class PreActBottleneck(nn.Module): 39 | """Pre-activation (v2) bottleneck block. 40 | """ 41 | 42 | def __init__(self, cin, cout=None, cmid=None, stride=1): 43 | super().__init__() 44 | cout = cout or cin 45 | cmid = cmid or cout//4 46 | 47 | self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6) 48 | self.conv1 = conv1x1(cin, cmid, bias=False) 49 | self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6) 50 | self.conv2 = conv3x3(cmid, cmid, stride, bias=False) # Original code has it on conv1!! 51 | self.gn3 = nn.GroupNorm(32, cout, eps=1e-6) 52 | self.conv3 = conv1x1(cmid, cout, bias=False) 53 | self.relu = nn.ReLU(inplace=True) 54 | 55 | if (stride != 1 or cin != cout): 56 | # Projection also with pre-activation according to paper. 
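            # The projection shortcut is a strided 1x1 convolution followed by GroupNorm,
            # so the residual matches the main branch in channel count and spatial size
            # whenever stride != 1 or cin != cout.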
57 | self.downsample = conv1x1(cin, cout, stride, bias=False) 58 | self.gn_proj = nn.GroupNorm(cout, cout) 59 | 60 | def forward(self, x): 61 | 62 | # Residual branch 63 | residual = x 64 | if hasattr(self, 'downsample'): 65 | residual = self.downsample(x) 66 | residual = self.gn_proj(residual) 67 | 68 | # Unit's branch 69 | y = self.relu(self.gn1(self.conv1(x))) 70 | y = self.relu(self.gn2(self.conv2(y))) 71 | y = self.gn3(self.conv3(y)) 72 | 73 | y = self.relu(residual + y) 74 | return y 75 | 76 | def load_from(self, weights, n_block, n_unit): 77 | conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True) 78 | conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True) 79 | conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True) 80 | 81 | gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")]) 82 | gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")]) 83 | 84 | gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")]) 85 | gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")]) 86 | 87 | gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")]) 88 | gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")]) 89 | 90 | self.conv1.weight.copy_(conv1_weight) 91 | self.conv2.weight.copy_(conv2_weight) 92 | self.conv3.weight.copy_(conv3_weight) 93 | 94 | self.gn1.weight.copy_(gn1_weight.view(-1)) 95 | self.gn1.bias.copy_(gn1_bias.view(-1)) 96 | 97 | self.gn2.weight.copy_(gn2_weight.view(-1)) 98 | self.gn2.bias.copy_(gn2_bias.view(-1)) 99 | 100 | self.gn3.weight.copy_(gn3_weight.view(-1)) 101 | self.gn3.bias.copy_(gn3_bias.view(-1)) 102 | 103 | if hasattr(self, 'downsample'): 104 | proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True) 105 | proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")]) 106 | proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")]) 107 | 108 | self.downsample.weight.copy_(proj_conv_weight) 109 | self.gn_proj.weight.copy_(proj_gn_weight.view(-1)) 110 | self.gn_proj.bias.copy_(proj_gn_bias.view(-1)) 111 | 112 | class ResNetV2(nn.Module): 113 | """Implementation of Pre-activation (v2) ResNet mode.""" 114 | 115 | def __init__(self, block_units, width_factor): 116 | super().__init__() 117 | width = int(64 * width_factor) 118 | self.width = width 119 | 120 | self.root = nn.Sequential(OrderedDict([ 121 | ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)), 122 | ('gn', nn.GroupNorm(32, width, eps=1e-6)), 123 | ('relu', nn.ReLU(inplace=True)), 124 | # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)) 125 | ])) 126 | 127 | self.body = nn.Sequential(OrderedDict([ 128 | ('block1', nn.Sequential(OrderedDict( 129 | [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] + 130 | [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)], 131 | ))), 132 | ('block2', nn.Sequential(OrderedDict( 133 | [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] + 134 | [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)], 135 | ))), 136 | ('block3', nn.Sequential(OrderedDict( 137 | [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] + 138 | [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)], 139 | ))), 140 | ])) 141 | 142 | def forward(self, x): 
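        # Returns the deepest feature map together with the intermediate feature maps
        # (highest resolution first) that the decoder consumes as skip connections;
        # feature maps that come out slightly undersized are zero-padded below.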
143 | features = [] 144 | b, c, in_size, _ = x.size() 145 | x = self.root(x) 146 | features.append(x) 147 | x = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x) 148 | for i in range(len(self.body)-1): 149 | x = self.body[i](x) 150 | right_size = int(in_size / 4 / (i+1)) 151 | if x.size()[2] != right_size: 152 | pad = right_size - x.size()[2] 153 | assert pad < 3 and pad > 0, "x {} should {}".format(x.size(), right_size) 154 | feat = torch.zeros((b, x.size()[1], right_size, right_size), device=x.device) 155 | feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:] 156 | else: 157 | feat = x 158 | features.append(feat) 159 | x = self.body[-1](x) 160 | return x, features[::-1] 161 | -------------------------------------------------------------------------------- /CSANet/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import random 5 | import numpy as np 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | import sys 9 | import time 10 | import torch.nn as nn 11 | import torch.optim as optim 12 | from networks.vit_seg_modeling import VisionTransformer as ViT_seg 13 | from networks.vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg 14 | from datasets.dataset_CSANet import CSANet_dataset 15 | from tensorboardX import SummaryWriter 16 | from torch.utils.data import DataLoader 17 | from tqdm import tqdm 18 | from torchvision import transforms 19 | from utils import test_single_volume 20 | from datasets.dataset_CSANet import CSANet_dataset, RandomGenerator 21 | 22 | 23 | """ 24 | This script configures and initializes training for the CSANet segmentation model using Vision Transformers. It handles command-line arguments for various training parameters, sets up deterministic options for reproducibility, and initializes the model with specified configurations. 25 | 26 | Parameters: 27 | - volume_path: Directory for validation volume data. 28 | - dataset: Name of the dataset or experiment. 29 | - num_classes: Number of output classes for segmentation. 30 | - list_dir: Directory containing lists of data samples. 31 | - max_iterations: Maximum number of iterations to train. 32 | - max_epochs: Maximum number of epochs to train. 33 | - batch_size: Number of samples per batch. 34 | - seed: Seed for random number generators for reproducibility. 35 | - n_gpu: Number of GPUs to use. 36 | - img_size: Size of the input images. 37 | - base_lr: Base learning rate for the optimizer. 38 | - deterministic: Flag to set training as deterministic. 39 | - n_skip: Number of skip connections in the model. 40 | - vit_name: Name of the Vision Transformer model configuration. 41 | - vit_patches_size: Size of patches for the ViT model. 42 | 43 | The script also loads and validates the model from a saved state if available and performs inference to evaluate the model on a test dataset. 
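
Example invocation (illustrative; the values shown are the argparse defaults defined below):
    python test.py --dataset CSANet --vit_name R50-ViT-B_16 --img_size 224 \
                   --batch_size 16 --max_epochs 40 --num_classes 5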
44 | """ 45 | 46 | # Setup command-line argument parsing 47 | parser = argparse.ArgumentParser() 48 | parser.add_argument('--volume_path', type=str, 49 | default='../data', help='root dir for validation volume data') 50 | parser.add_argument('--dataset', type=str, 51 | default='CSANet', help='experiment_name') 52 | parser.add_argument('--num_classes', type=int, 53 | default=5, help='output channel of network') 54 | parser.add_argument('--list_dir', type=str, 55 | default='./lists', help='list dir') 56 | parser.add_argument('--max_iterations', type=int, 57 | default=300000, help='maximum epoch number to train') 58 | parser.add_argument('--max_epochs', type=int, 59 | default=40, help='maximum epoch number to train') 60 | parser.add_argument('--batch_size', type=int, 61 | default=16, help='batch_size per gpu') 62 | parser.add_argument('--seed', type=int, 63 | default=1234, help='random seed') 64 | parser.add_argument('--n_gpu', type=int, default=1, help='total gpu') 65 | parser.add_argument('--img_size', type=int, 66 | default=224, help='input patch size of network input') 67 | parser.add_argument('--base_lr', type=float, default=0.001, 68 | help='segmentation network learning rate') 69 | parser.add_argument('--deterministic', type=int, default=1, 70 | help='whether use deterministic training') 71 | parser.add_argument('--n_skip', type=int, 72 | default=3, help='using number of skip-connect, default is num') 73 | parser.add_argument('--vit_name', type=str, 74 | default='R50-ViT-B_16', help='select one vit model') 75 | parser.add_argument('--vit_patches_size', type=int, 76 | default=16, help='vit_patches_size, default is 16') 77 | 78 | args = parser.parse_args() 79 | 80 | 81 | 82 | 83 | def vol_inference(args, model, test_save_path=None, validation=False): 84 | """ 85 | Performs inference on a test dataset, computes performance metrics such as Dice coefficients and distances, 86 | and can operate in validation mode or test mode based on a flag. 87 | 88 | Parameters: 89 | - args (Namespace): Contains all the necessary settings such as dataset paths, number of classes, image size, etc. 90 | - model (torch.nn.Module): The trained model to be evaluated. 91 | - test_save_path (str, optional): Path where test outputs (such as images) can be saved. 92 | - validation (bool, optional): If True, function returns average Dice coefficient for validation purposes. 93 | If False, returns a string message indicating test completion. 94 | 95 | Returns: 96 | - float: If validation is True, returns the average Dice coefficient. 97 | - str: If validation is False, returns a completion message "Testing Finished!" 98 | 99 | The function logs the number of test iterations, processes each test sample to compute Dice coefficients and distances, 100 | and aggregates these metrics across the dataset for reporting or validation. 
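
    Example (illustrative):
        avg_dice = vol_inference(args, model, validation=True)   # returns the mean Dice as a float
        vol_inference(args, model, validation=False)              # prints metrics, returns "Testing Finished!"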
101 | """ 102 | # Load the test dataset 103 | db_test = args.Dataset(base_dir=args.volume_path, split="test_vol", list_dir=args.list_dir) 104 | testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1) 105 | num = len(testloader) 106 | logging.info("{} test iterations per epoch".format(len(testloader))) 107 | model.eval() 108 | metric_list = 0.0 109 | 110 | # Initialize metrics storage 111 | total_dice_coeff1, total_dice_coeff2, total_dice_coeff3, total_dice_coeff4 = 0, 0, 0, 0 112 | total_dist1, total_dist2, total_dist3, total_dist4 = 0, 0, 0, 0 113 | 114 | # Process each batch in the test loader 115 | for i_batch, sampled_batch in tqdm(enumerate(testloader)): 116 | # Retrieve image and label from the batch 117 | image, label, case_name = sampled_batch["image"], sampled_batch["label"], sampled_batch['case_name'] 118 | image_next, image_prev = sampled_batch['next_image'], sampled_batch['prev_image'] 119 | 120 | dice1, dice2, dice3, dice4, dist1, dist2,dist3, dist4 = test_single_volume(image_next, image, image_prev, label, model, classes=args.num_classes, patch_size=[args.img_size, args.img_size], 121 | test_save_path=test_save_path, case=case_name) 122 | # Output the metrics for monitoring 123 | print("dice1 = ",dice1, " dice2 = ", dice2, "dice3= ",dice3,"dice4= ", dice4) 124 | total_dice_coeff1 = total_dice_coeff1 + dice1 125 | total_dice_coeff2 = total_dice_coeff2 + dice2 126 | total_dice_coeff3 = total_dice_coeff3 + dice3 127 | total_dice_coeff4 = total_dice_coeff4 + dice4 128 | total_dist1 = total_dist1 + dist1 129 | total_dist2 = total_dist2 + dist2 130 | total_dist3 = total_dist3 + dist3 131 | total_dist4 = total_dist4 + dist4 132 | 133 | # Calculate average metrics for all cases 134 | print(f"dice1={total_dice_coeff1/num}, dice2={total_dice_coeff2/num}, dice3={total_dice_coeff3/num}, dice4={total_dice_coeff4/num}, hd1={total_dist1/num}, hd2={total_dist2/num}, hd3={total_dist3/num}, hd4={total_dist4/num}") 135 | avg_dice = (total_dice_coeff1 + total_dice_coeff2 + total_dice_coeff3 + total_dice_coeff4) / (4*num) 136 | print("avg_dice = ",avg_dice) 137 | # Return the appropriate result based on the validation flag 138 | if validation: 139 | return avg_dice 140 | else: 141 | return "Testing Finished!" 
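# The per-class Dice and distance values come from utils.test_single_volume, which is
# not shown in this file. As a rough, illustrative sketch (an assumption, not the
# project's actual implementation), one class's Dice and 95% Hausdorff distance could
# be computed from predicted and ground-truth label volumes with medpy:
#
#     from medpy import metric
#
#     def class_metrics(pred, gt, class_id):
#         pred_bin = (pred == class_id)
#         gt_bin = (gt == class_id)
#         if pred_bin.sum() == 0 or gt_bin.sum() == 0:
#             return 0.0, 0.0   # metrics are undefined for an empty mask
#         return metric.binary.dc(pred_bin, gt_bin), metric.binary.hd95(pred_bin, gt_bin)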
142 | 143 | 144 | 145 | 146 | if __name__ == "__main__": 147 | # Setup GPU/CPU seeds for reproducibility if deterministic mode is enabled 148 | if not args.deterministic: 149 | cudnn.benchmark = True 150 | cudnn.deterministic = False 151 | else: 152 | cudnn.benchmark = False 153 | cudnn.deterministic = True 154 | random.seed(args.seed) 155 | np.random.seed(args.seed) 156 | torch.manual_seed(args.seed) 157 | torch.cuda.manual_seed(args.seed) 158 | 159 | dataset_name = args.dataset 160 | # Load dataset configuration based on the provided dataset name 161 | dataset_config = { 162 | 'CSANet': { 163 | 'Dataset': CSANet_dataset, 164 | 'root_path': '../data/train_npz', 165 | 'volume_path': '../data/', 166 | 'list_dir': './lists', 167 | 'num_classes': 5, 168 | 'z_spacing': 1, 169 | }, 170 | } 171 | 172 | 173 | if args.batch_size != 24 and args.batch_size % 6 == 0: 174 | args.base_lr *= args.batch_size / 24 175 | 176 | args.num_classes = dataset_config[dataset_name]['num_classes'] 177 | args.list_dir = dataset_config[dataset_name]['list_dir'] 178 | args.is_pretrain = True 179 | args.Dataset = dataset_config[dataset_name]['Dataset'] 180 | args.exp = 'CSANet_' + str(args.img_size) 181 | 182 | 183 | 184 | snapshot_path = "../model/{}/{}".format(args.exp, 'TU') 185 | snapshot_path += '_' + args.vit_name 186 | snapshot_path = snapshot_path + '_skip' + str(args.n_skip) 187 | snapshot_path = snapshot_path + '_vitpatch' + str(args.vit_patches_size) if args.vit_patches_size!=16 else snapshot_path 188 | snapshot_path = snapshot_path+'_'+str(args.max_iterations)[0:2]+'k' if args.max_iterations != 30000 else snapshot_path 189 | snapshot_path = snapshot_path + '_epo' +str(args.max_epochs) if args.max_epochs != 30 else snapshot_path 190 | snapshot_path = snapshot_path+'_bs'+str(args.batch_size) 191 | snapshot_path = snapshot_path + '_lr' + str(args.base_lr) if args.base_lr != 0.01 else snapshot_path 192 | 193 | # Initialize and load the ViT model from the specified configuration and saved state 194 | config_vit = CONFIGS_ViT_seg[args.vit_name] 195 | config_vit.n_classes = args.num_classes 196 | config_vit.n_skip = args.n_skip 197 | config_vit.patches.size = (args.vit_patches_size, args.vit_patches_size) 198 | if args.vit_name.find('R50') !=-1: 199 | config_vit.patches.grid = (int(args.img_size/args.vit_patches_size), int(args.img_size/args.vit_patches_size)) 200 | net = ViT_seg(config_vit, img_size=args.img_size, num_classes=config_vit.n_classes).cuda() 201 | 202 | snapshot = os.path.join(snapshot_path, 'best_model.pth') 203 | if not os.path.exists(snapshot): snapshot = snapshot.replace('best_model', 'epoch_'+str(args.max_epochs-1)) 204 | net.load_state_dict(torch.load(snapshot)) 205 | snapshot_name = snapshot_path.split('/')[-1] 206 | log_folder = './test_log/test_log_' + args.exp 207 | 208 | os.makedirs(log_folder, exist_ok=True) 209 | # Setup logging and initiate volume inference 210 | logging.basicConfig(filename=log_folder + '/'+snapshot_name+".txt", level=logging.INFO, format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') 211 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 212 | logging.info(str(args)) 213 | logging.info(snapshot_name) 214 | 215 | vol_inference(args, net, validation=False) 216 | 217 | -------------------------------------------------------------------------------- /CSANet/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import random 5 | import numpy as np 6 
| import torch 7 | import torch.backends.cudnn as cudnn 8 | from networks.vit_seg_modeling import VisionTransformer as ViT_seg 9 | from networks.vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg 10 | from trainer import trainer_CSANet 11 | from datasets.dataset_CSANet import CSANet_dataset 12 | 13 | 14 | 15 | """ 16 | This script initializes and runs training for the CSANet segmentation model using Vision Transformer (ViT) architecture. 17 | It configures the training environment, sets up the data loading for a medical imaging dataset, and initializes the model 18 | with predefined or specified hyperparameters. The script is designed to be run with command-line arguments that allow 19 | customization of various parameters including data paths, model specifics, and training settings. 20 | 21 | Command-Line Arguments: 22 | - root_path: Directory containing training data. 23 | - dataset: Identifier for the dataset used, affecting certain preset configurations. 24 | - list_dir: Directory containing lists of training data specifics. 25 | - num_classes: Number of classes for segmentation. 26 | - volume_path: Path to validation data for model evaluation. 27 | - max_iterations: Total number of iterations the training should run. 28 | - max_epochs: Maximum number of epochs for which the model trains. 29 | - batch_size: Number of samples in each batch. 30 | - n_gpu: Number of GPUs available for training. 31 | - deterministic: Flag to ensure deterministic results, useful for reproducibility. 32 | - base_lr: Base learning rate for the optimizer. 33 | - img_size: Dimensions of the input images for the model. 34 | - seed: Random seed for initialization to ensure reproducibility. 35 | - n_skip: Number of skip connections in the ViT model. 36 | - vit_name: Name of the Vision Transformer configuration to be used. 37 | - vit_patches_size: Size of patches used in the ViT model. 38 | 39 | The script supports customization of the training process through these parameters and uses a pre-defined configuration 40 | for setting up the model, dataset, and training operations based on the provided dataset name. 
41 | """ 42 | 43 | # Setup command-line interface 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument('--root_path', type=str, 46 | default='../data/train_npz', help='root dir for data') 47 | parser.add_argument('--dataset', type=str, 48 | default='CSANet', help='experiment_name') 49 | parser.add_argument('--list_dir', type=str, 50 | default='./lists', help='list dir') 51 | parser.add_argument('--num_classes', type=int, 52 | default=5, help='output channel of network') # class change ----------------- 53 | parser.add_argument('--volume_path', type=str, 54 | default='../data', help='root dir for validation volume data') 55 | parser.add_argument('--max_iterations', type=int, 56 | default=300000, help='maximum epoch number to train') 57 | parser.add_argument('--max_epochs', type=int, 58 | default=40, help='maximum epoch number to train') 59 | parser.add_argument('--batch_size', type=int, 60 | default=16, help='batch_size per gpu') 61 | parser.add_argument('--n_gpu', type=int, default=1, help='total gpu') 62 | parser.add_argument('--deterministic', type=int, default=1, 63 | help='whether use deterministic training') 64 | parser.add_argument('--base_lr', type=float, default=0.001, 65 | help='segmentation network learning rate') 66 | parser.add_argument('--img_size', type=int, 67 | default=224, help='input patch size of network input') 68 | parser.add_argument('--seed', type=int, 69 | default=1234, help='random seed') 70 | parser.add_argument('--n_skip', type=int, 71 | default=3, help='using number of skip-connect, default is num') 72 | parser.add_argument('--vit_name', type=str, 73 | default='R50-ViT-B_16', help='select one vit model') 74 | parser.add_argument('--vit_patches_size', type=int, 75 | default=16, help='vit_patches_size, default is 16') 76 | 77 | args = parser.parse_args() 78 | 79 | 80 | if __name__ == "__main__": 81 | # Configure deterministic behavior for reproducibility if specified 82 | if not args.deterministic: 83 | cudnn.benchmark = True 84 | cudnn.deterministic = False 85 | else: 86 | cudnn.benchmark = False 87 | cudnn.deterministic = True 88 | 89 | random.seed(args.seed) 90 | np.random.seed(args.seed) 91 | torch.manual_seed(args.seed) 92 | torch.cuda.manual_seed(args.seed) 93 | dataset_name = args.dataset 94 | dataset_config = { 95 | 'CSANet': { 96 | 'Dataset': CSANet_dataset, 97 | 'root_path': '../data/train_npz', 98 | 'volume_path': '../data', 99 | 'list_dir': './lists', 100 | 'num_classes': 5, 101 | 'z_spacing': 1, 102 | }, 103 | } 104 | if args.batch_size != 24 and args.batch_size % 6 == 0: 105 | args.base_lr *= args.batch_size / 24 106 | args.num_classes = dataset_config[dataset_name]['num_classes'] 107 | 108 | args.root_path = dataset_config[dataset_name]['root_path'] 109 | args.list_dir = dataset_config[dataset_name]['list_dir'] 110 | args.is_pretrain = True 111 | args.Dataset = dataset_config[dataset_name]['Dataset'] 112 | args.exp = 'CSANet_'+ str(args.img_size) 113 | 114 | # Build snapshot path based on the configuration and command-line arguments 115 | snapshot_path = "../model/{}/{}".format(args.exp, 'TU') 116 | snapshot_path += '_' + args.vit_name 117 | snapshot_path = snapshot_path + '_skip' + str(args.n_skip) 118 | snapshot_path = snapshot_path + '_vitpatch' + str(args.vit_patches_size) if args.vit_patches_size!=16 else snapshot_path 119 | snapshot_path = snapshot_path+'_'+str(args.max_iterations)[0:2]+'k' if args.max_iterations != 30000 else snapshot_path 120 | snapshot_path = snapshot_path + '_epo' +str(args.max_epochs) if args.max_epochs != 30 
else snapshot_path 121 | snapshot_path = snapshot_path+'_bs'+str(args.batch_size) 122 | snapshot_path = snapshot_path + '_lr' + str(args.base_lr) if args.base_lr != 0.01 else snapshot_path 123 | 124 | print("snapshot path = ", snapshot_path) 125 | if not os.path.exists(snapshot_path): 126 | os.makedirs(snapshot_path) 127 | 128 | # Load Vision Transformer with the specific configuration 129 | config_vit = CONFIGS_ViT_seg[args.vit_name] 130 | config_vit.n_classes = args.num_classes 131 | config_vit.n_skip = args.n_skip 132 | if args.vit_name.find('R50') != -1: 133 | config_vit.patches.grid = (int(args.img_size / args.vit_patches_size), int(args.img_size / args.vit_patches_size)) 134 | net = ViT_seg(config_vit, img_size=args.img_size, num_classes=config_vit.n_classes).cuda() 135 | # Load initial weights if pretrained path is provided 136 | net.load_from(weights=np.load(config_vit.pretrained_path)) 137 | # Start training using the specified trainer for the dataset 138 | trainer = {'CSANet': trainer_CSANet} 139 | trainer[dataset_name](args, net, snapshot_path) 140 | -------------------------------------------------------------------------------- /CSANet/trainer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import random 5 | import sys 6 | import time 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | from tensorboardX import SummaryWriter 12 | from torch.nn.modules.loss import CrossEntropyLoss 13 | from torch.utils.data import DataLoader 14 | from tqdm import tqdm 15 | from loss_function import DiceLoss 16 | from torchvision import transforms 17 | from PIL import Image 18 | import matplotlib.pyplot as plt 19 | from utils import test_single_volume 20 | from visualization import save_visualization 21 | import cv2 22 | import torch.backends.cudnn as cudnn 23 | from datasets.dataset_CSANet import CSANet_dataset, RandomGenerator 24 | from test import vol_inference 25 | 26 | 27 | 28 | 29 | 30 | 31 | def trainer_CSANet(args, model, snapshot_path): 32 | 33 | """ 34 | Trains the CSANet model with the specified parameters and dataset, performing evaluations and saving the model state based on performance metrics. 35 | 36 | Parameters: 37 | - args (Namespace): Configuration containing all settings for the training process, such as dataset paths, learning rates, batch sizes, and more. 38 | - model (torch.nn.Module): The neural network model to be trained. 39 | - snapshot_path (str): Directory path where training snapshots (model states and logs) will be saved. 40 | 41 | Returns: 42 | - str: A message indicating that training has finished. 43 | 44 | The function initializes training setup, logs configurations, and enters a training loop where it continually feeds data through the model, computes losses, updates the model's weights, and logs the results. It also evaluates the model periodically using the `vol_inference` function and saves the model state when performance improves. The summary of training progress is saved using TensorBoard. 
45 | """ 46 | 47 | # Configure logging 48 | logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO, 49 | format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') 50 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 51 | logging.info(str(args)) 52 | 53 | # Set training parameters from args 54 | base_lr = args.base_lr 55 | num_classes = args.num_classes 56 | batch_size = args.batch_size * args.n_gpu 57 | 58 | # Initialize dataset and dataloader 59 | db_train = CSANet_dataset(base_dir=args.root_path, list_dir=args.list_dir, split="train_image", 60 | transform=transforms.Compose( 61 | [RandomGenerator(output_size=[args.img_size, args.img_size])])) 62 | print("The length of train set is: {}".format(len(db_train))) 63 | 64 | def worker_init_fn(worker_id): 65 | # Seed each worker for reproducibility 66 | random.seed(args.seed + worker_id) 67 | 68 | trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True, 69 | worker_init_fn=worker_init_fn) 70 | # Use DataParallel for multi-GPU training 71 | if args.n_gpu > 1: 72 | model = nn.DataParallel(model) 73 | model.train() 74 | 75 | # Define loss functions and optimizer 76 | ce_loss = CrossEntropyLoss() 77 | dice_loss = DiceLoss(num_classes) 78 | optimizer = optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001) 79 | 80 | # Initialize TensorBoard writer 81 | writer = SummaryWriter(snapshot_path + '/log') 82 | iter_num = 0 83 | max_epoch = args.max_epochs 84 | max_iterations = args.max_epochs * len(trainloader) 85 | logging.info("{} iterations per epoch. {} max iterations ".format(len(trainloader), max_iterations)) 86 | best_performance = 0.0 87 | folder_path = "./training_result" 88 | if not os.path.exists(folder_path): 89 | os.makedirs(folder_path) 90 | 91 | #vol_inference(args, model, validation=False) 92 | # Training loop 93 | iterator = tqdm(range(max_epoch), ncols=70) 94 | for epoch_num in iterator: 95 | model.train() 96 | for i_batch, sampled_batch in enumerate(trainloader): 97 | image_batch, label_batch = sampled_batch['image'], sampled_batch['label'] 98 | image_next, image_prev = sampled_batch['next_image'], sampled_batch['prev_image'] 99 | 100 | # Ensure all tensors are on the same device 101 | image_batch, label_batch = image_batch.cuda(), label_batch.cuda() 102 | image_next_batch, image_prev_batch = image_next.cuda(), image_prev.cuda() 103 | 104 | # Forward pass 105 | outputs = model(image_prev_batch, image_batch, image_next_batch) 106 | 107 | # Calculate loss 108 | loss_ce = ce_loss(outputs, label_batch[:].long()) 109 | loss_dice = dice_loss(outputs, label_batch, softmax=True) 110 | loss = 0.5 * loss_ce + 0.5 * loss_dice 111 | 112 | # Backpropagation 113 | optimizer.zero_grad() 114 | loss.backward() 115 | optimizer.step() 116 | 117 | # Visualization and logging 118 | save_visualization(outputs, label_batch, epoch_num, i_batch) 119 | 120 | # Learning rate adjustment 121 | lr_ = base_lr * (1.0 - iter_num / max_iterations) ** 0.9 122 | for param_group in optimizer.param_groups: 123 | param_group['lr'] = lr_ 124 | 125 | iter_num = iter_num + 1 126 | writer.add_scalar('info/lr', lr_, iter_num) 127 | writer.add_scalar('info/total_loss', loss, iter_num) 128 | writer.add_scalar('info/loss_ce', loss_ce, iter_num) 129 | 130 | logging.info('iteration %d : loss : %f, loss_ce: %f' % (iter_num, loss.item(), loss_ce.item())) 131 | 132 | # End-of-epoch validation and checkpointing 133 | if epoch_num > 10 and (epoch_num % 5 == 0 or epoch_num == 
max_epoch - 1): 134 | avg_dice = vol_inference(args, model, validation=True) 135 | if avg_dice > best_performance: 136 | best_performance = avg_dice 137 | save_mode_path = os.path.join(snapshot_path, 'best_model.pth') 138 | torch.save(model.state_dict(), save_mode_path) 139 | logging.info(f"Saved new best model to {save_mode_path}") 140 | 141 | vol_inference(args, model, validation=False) 142 | writer.close() 143 | return "Training Finished!" 144 | -------------------------------------------------------------------------------- /CSANet/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from medpy import metric 4 | from scipy.ndimage import zoom 5 | import torch.nn as nn 6 | import SimpleITK as sitk 7 | import os 8 | import nibabel as nib 9 | from skimage.measure import label, regionprops 10 | import scipy.ndimage as ndi 11 | import math 12 | 13 | 14 | def computing_COM_distance(mask_array, pred_array, US_indicies, spacing): 15 | 16 | """ 17 | Computes the average physical distance between the centers of mass (COM) of corresponding predicted and ground truth masks, discarding distances above the 95th percentile to suppress outliers. 18 | 19 | Parameters: 20 | - mask_array (np.array): Array of ground truth masks. 21 | - pred_array (np.array): Array of predicted masks. 22 | - US_indicies (list of int): Indices of the mask slices to be processed. 23 | - spacing (tuple of float): Physical spacing between pixels in the masks. 24 | 25 | Returns: 26 | - float: Mean of the per-slice center-of-mass distances at or below their 95th percentile, in physical units (0 if no slices are processed). 27 | """ 28 | 29 | dist = [] 30 | for num, US_num in enumerate(US_indicies): 31 | predicted_mask = pred_array[US_num].astype('uint8') 32 | ground_truth_mask = mask_array[US_num].astype('uint8') 33 | cy_hist, cx_hist = ndi.center_of_mass(predicted_mask) 34 | cy_us, cx_us = ndi.center_of_mass(ground_truth_mask) 35 | temp = math.dist([cx_hist, cy_hist], [cx_us, cy_us]) 36 | phy_temp = temp * spacing[0] 37 | dist.append(phy_temp) 38 | distances = np.array(dist) 39 | 40 | if distances.size > 0: 41 | percentile_95 = np.percentile(distances, 95) 42 | mean_of_95th_percentile = np.mean(distances[distances <= percentile_95]) 43 | else: 44 | percentile_95 = 0 45 | mean_of_95th_percentile = 0 46 | return mean_of_95th_percentile 47 | 48 | 49 | 50 | 51 | def Dice_cal(image1, image2): 52 | 53 | """ 54 | Calculates Dice coefficients for multiple class labels between two images. 55 | 56 | Parameters: 57 | - image1 (SimpleITK.Image): First image for comparison. 58 | - image2 (SimpleITK.Image): Second image for comparison. 59 | 60 | Returns: 61 | - tuple: Dice coefficients for each class label (1, 2, 3, 4).
62 | """ 63 | 64 | class_labels = [1 , 2, 3, 4] 65 | for num_labels in class_labels: 66 | # Create binary masks for each class 67 | mask1 = sitk.Cast(image1 == num_labels, sitk.sitkInt32) 68 | mask2 = sitk.Cast(image2 == num_labels, sitk.sitkInt32) 69 | 70 | overlap_filter = sitk.LabelOverlapMeasuresImageFilter() 71 | overlap_filter.Execute(mask1, mask2) 72 | 73 | if num_labels == 1: 74 | dice_coeff_1 = overlap_filter.GetDiceCoefficient() 75 | elif num_labels == 2: 76 | dice_coeff_2 = overlap_filter.GetDiceCoefficient() 77 | elif num_labels == 3: 78 | dice_coeff_3 = overlap_filter.GetDiceCoefficient() 79 | elif num_labels == 4: 80 | dice_coeff_4 = overlap_filter.GetDiceCoefficient() 81 | 82 | return dice_coeff_1, dice_coeff_2, dice_coeff_3, dice_coeff_4 83 | 84 | def compute_class_hausdorff(labels, outputs, class_index, spacing): 85 | 86 | """ 87 | Computes the Hausdorff distance for a specific class based on its segmentation masks. 88 | 89 | Parameters: 90 | - labels (np.array): Array of ground truth labels for all classes. 91 | - outputs (np.array): Array of predicted labels for all classes. 92 | - class_index (int): Index of the class for which to compute the distance. 93 | - spacing (tuple of float): Physical spacing of the images. 94 | 95 | Returns: 96 | - float: Computed Hausdorff distance for the specified class. 97 | """ 98 | 99 | US_indicies = [] 100 | new_labels = labels[:,:,:,class_index] 101 | new_outputs = outputs[:,:,:,class_index] 102 | 103 | for z in range(new_labels.shape[0]): 104 | if np.sum(new_labels[z]) > 0 and np.sum(new_outputs[z]) > 0: 105 | US_indicies.append(z) 106 | 107 | hausdorff_dist = computing_COM_distance(new_labels, new_outputs, US_indicies, spacing) 108 | return hausdorff_dist 109 | 110 | 111 | def test_single_volume(image_next, image, image_prev, label, net, classes, patch_size=[256, 256], test_save_path=None, case=None): 112 | 113 | """ 114 | Tests a single volume for segmentation using a deep learning model, and evaluates segmentation accuracy using Dice and Hausdorff distances. 115 | 116 | Parameters: 117 | - image_next, image_prev, image (np.array): Current, previous, and next slices of the image volume. 118 | - label (np.array): Ground truth labels for the current image slice. 119 | - net (torch.nn.Module): Neural network model used for segmentation. 120 | - classes (int): Number of segmentation classes. 121 | - patch_size (list of int): Size of the patches processed by the network. 122 | - test_save_path (str, optional): Path to save the segmentation results. 123 | - case (str, optional): Identifier for the case being tested. 124 | 125 | Returns: 126 | - tuple: Dice coefficients and Hausdorff distances for each class. 
127 | """ 128 | # Convert tensors to numpy arrays for processing 129 | image, label = image.squeeze(0).cpu().detach().numpy(), label.squeeze(0).cpu().detach().numpy() 130 | image_next, image_prev = image_next.squeeze(0).cpu().detach().numpy(), image_prev.squeeze(0).cpu().detach().numpy() 131 | 132 | # Initialize prediction array 133 | if len(image.shape) == 3: 134 | prediction = np.zeros_like(label) 135 | for ind in range(image.shape[0]): 136 | # Resize slices if necessary to match the network's expected input size 137 | slice = image[ind, :, :] 138 | slice_prev = image_prev[ind, :, :] 139 | slice_next = image_next[ind, :, :] 140 | 141 | x, y = slice.shape[0], slice.shape[1] 142 | if x != patch_size[0] or y != patch_size[1]: 143 | slice = zoom(slice, (patch_size[0] / x, patch_size[1] / y), order=3) # previous using 0 144 | slice_prev = zoom(slice_prev, (patch_size[0] / x, patch_size[1] / y), order=3) 145 | slice_next = zoom(slice_next, (patch_size[0] / x, patch_size[1] / y), order=3) 146 | 147 | # Convert slices to tensors and run through the network 148 | input_curr = torch.from_numpy(slice).unsqueeze(0).unsqueeze(0).float().cuda() 149 | input_prev = torch.from_numpy(slice_prev).unsqueeze(0).unsqueeze(0).float().cuda() 150 | input_next = torch.from_numpy(slice_next).unsqueeze(0).unsqueeze(0).float().cuda() 151 | 152 | net.eval() 153 | 154 | with torch.no_grad(): 155 | outputs = net(input_prev, input_curr , input_next) 156 | out = torch.argmax(torch.softmax(outputs, dim=1), dim=1).squeeze(0) 157 | out = out.cpu().detach().numpy() 158 | if x != patch_size[0] or y != patch_size[1]: 159 | pred = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0) 160 | else: 161 | pred = out 162 | prediction[ind] = pred 163 | else: 164 | # Handle single slice case 165 | input = torch.from_numpy(image).unsqueeze( 166 | 0).unsqueeze(0).float().cuda() 167 | net.eval() 168 | with torch.no_grad(): 169 | out = torch.argmax(torch.softmax(net(input), dim=1), dim=1).squeeze(0) 170 | prediction = out.cpu().detach().numpy() 171 | 172 | 173 | 174 | 175 | # Prepare data for analysis and visualization 176 | Result_path = "./Result" 177 | if not os.path.exists(Result_path): 178 | os.makedirs(Result_path) 179 | num_case = case[0] 180 | test_vol_path = "../data/testVol" + "/" + num_case + ".nii.gz" 181 | vol_image = sitk.ReadImage(test_vol_path) 182 | 183 | # Create SimpleITK images from numpy arrays for evaluation 184 | pred_image = sitk.GetImageFromArray(prediction) 185 | pred_image.SetSpacing(vol_image.GetSpacing()) 186 | pred_image.SetDirection(vol_image.GetDirection()) 187 | pred_image.SetOrigin(vol_image.GetOrigin()) 188 | pred_path = './Result/' + num_case +'_segmentation.nii.gz' 189 | sitk.WriteImage(pred_image, pred_path) 190 | 191 | label_image = sitk.GetImageFromArray(label) 192 | label_image.SetSpacing(vol_image.GetSpacing()) 193 | label_image.SetDirection(vol_image.GetDirection()) 194 | label_image.SetOrigin(vol_image.GetOrigin()) 195 | label_path = './Result/' + num_case +'_label_segmentation.nii.gz' 196 | sitk.WriteImage(label_image, label_path) 197 | 198 | # Load ground truth mask for evaluation 199 | mask_path = "../data/testMask/"+num_case+".nii.gz" 200 | mask_img = sitk.ReadImage(mask_path) 201 | 202 | image1 = pred_image 203 | image2 = mask_img 204 | dice_coeff_1, dice_coeff_2, dice_coeff_3, dice_coeff_4 = 0.0, 0.0, 0.0, 0.0 205 | 206 | # Dice Coefficient Calculation 207 | dice_coeff_1, dice_coeff_2, dice_coeff_3, dice_coeff_4= Dice_cal(image1, image2) 208 | 209 | 210 | labels = 
np.eye(classes)[label] 211 | outputs = np.eye(classes)[prediction] 212 | spacing = vol_image.GetSpacing() 213 | 214 | # Hausdroff Distance Calculation 215 | hausdorff_dist_1 = compute_class_hausdorff(labels, outputs, 1, spacing) 216 | hausdorff_dist_2 = compute_class_hausdorff(labels, outputs, 2, spacing) 217 | hausdorff_dist_3 = compute_class_hausdorff(labels, outputs, 3, spacing) 218 | hausdorff_dist_4 = compute_class_hausdorff(labels, outputs, 4, spacing) 219 | 220 | return dice_coeff_1, dice_coeff_2, dice_coeff_3, dice_coeff_4, hausdorff_dist_1,hausdorff_dist_2, hausdorff_dist_3, hausdorff_dist_4 221 | -------------------------------------------------------------------------------- /CSANet/visualization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | 6 | def save_visualization(outputs, label_batch, epoch_num, i_batch): 7 | """ 8 | Processes the outputs and label batch, and saves visualization of the predictions and labels. 9 | 10 | Parameters: 11 | - outputs (torch.Tensor): The output predictions from the model. 12 | - label_batch (torch.Tensor): The batch of ground truth labels. 13 | - epoch_num (int): Current epoch number. 14 | - i_batch (int): Current batch index. 15 | 16 | Saves images to disk showing the predicted segmentation and actual labels. 17 | """ 18 | outputs = torch.softmax(outputs, dim=1) 19 | outputs = torch.argmax(outputs, dim=1).squeeze(dim=1) 20 | rand_slice_out = outputs[0,:,:] 21 | rand_slice_out = rand_slice_out.cpu().detach().numpy() 22 | plt.imshow(rand_slice_out) 23 | plt.colorbar() 24 | path1 = f"./training_result/{epoch_num}_{i_batch}pred_image.png" 25 | plt.savefig(path1) 26 | plt.close() 27 | 28 | rand_label_slice = label_batch[0,:,:] 29 | rand_label_slice = rand_label_slice.cpu().detach().numpy() 30 | plt.imshow(rand_label_slice) 31 | plt.colorbar() 32 | path2 = f"./training_result/{epoch_num}_{i_batch}label_image.png" 33 | plt.savefig(path2) 34 | plt.close() -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 mirth AI lab at UF 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CSA-Net: A Flexible 2.5D Medical Image Segmentation Approach with In-Slice and Cross-Slice Attention 2 | 3 |
4 |
5 |
6 | Figure 1: Visual representation of the CSA-Net architecture.
7 |
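CSA-Net is a 2.5D model: to segment a slice it also looks at the slices directly above and below it, combining in-slice and cross-slice attention. Below is a minimal sketch of that three-slice forward pass, mirroring how the training and testing scripts build and call the network; it assumes their default configuration (`R50-ViT-B_16`, 224x224 inputs, 5 classes), is run from the `CSANet/` directory, feeds random tensors instead of real slices, requires a CUDA-capable GPU, and omits pretrained-weight loading.

```python
import torch
from networks.vit_seg_modeling import VisionTransformer as ViT_seg
from networks.vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg

# Configure the R50-ViT-B_16 hybrid backbone the same way train.py/test.py do
# (n_classes = 4 foreground structures + background).
config = CONFIGS_ViT_seg['R50-ViT-B_16']
config.n_classes = 5
config.n_skip = 3
config.patches.grid = (224 // 16, 224 // 16)

net = ViT_seg(config, img_size=224, num_classes=config.n_classes).cuda().eval()

# Three adjacent slices, each shaped [batch, 1, H, W]; random placeholders here.
prev_slice = torch.randn(1, 1, 224, 224).cuda()
curr_slice = torch.randn(1, 1, 224, 224).cuda()
next_slice = torch.randn(1, 1, 224, 224).cuda()

with torch.no_grad():
    logits = net(prev_slice, curr_slice, next_slice)          # [1, 5, 224, 224]
    pred = torch.argmax(torch.softmax(logits, dim=1), dim=1)  # [1, 224, 224] label map
```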