├── models ├── __init__.py ├── sam2 │ ├── sam │ │ ├── __init__.py │ │ ├── prompt_encoder.py │ │ ├── transformer.py │ │ └── mask_decoder.py │ ├── utils │ │ ├── __init__.py │ │ └── transformers.py │ ├── __init__.py │ ├── backbones │ │ ├── __init__.py │ │ ├── utils.py │ │ ├── image_encoder.py │ │ └── hieradet.py │ ├── configs │ │ └── sam2 │ │ │ ├── sam2_hiera_b+.yaml │ │ │ ├── sam2_hiera_s.yaml │ │ │ ├── sam2_hiera_l.yaml │ │ │ └── sam2_hiera_t.yaml │ ├── memory_attention.py │ ├── memory_encoder.py │ ├── build_sam.py │ └── position_encoding.py ├── components │ ├── __init__.py │ ├── loss.py │ ├── lr_scheduler.py │ ├── semantic_extraction.py │ └── utils.py └── utils.py ├── datasets ├── __init__.py ├── data_load_demo.py ├── utils.py ├── data_process.py ├── train_data_prepare.py ├── data_loader.py └── dataset_info.txt ├── visualization ├── __init__.py ├── visualization_area.py ├── utils.py ├── visualization_slice.py ├── visualization_3d.py └── visualization_config.py ├── static ├── overview.png ├── visualization.png └── visualization_appendix.png ├── scripts ├── train.sh └── test.sh ├── README.md ├── test.py └── env_config.yml /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/sam2/sam/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /visualization/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/components/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/sam2/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /static/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YU-deep/CRISP_SAM2/HEAD/static/overview.png -------------------------------------------------------------------------------- /static/visualization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YU-deep/CRISP_SAM2/HEAD/static/visualization.png -------------------------------------------------------------------------------- /static/visualization_appendix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YU-deep/CRISP_SAM2/HEAD/static/visualization_appendix.png -------------------------------------------------------------------------------- /models/sam2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
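For orientation before the individual files, here is a minimal wiring sketch showing how the components in the tree above fit together at inference time. The imports and constructor arguments mirror `models/sam2/build_sam.py` and `test.py` further down in this listing; the paths are placeholders and the exact signatures are assumptions taken from that test script, not a guaranteed API.

```
# Minimal wiring sketch (paths are placeholders; argument names follow test.py below).
from models.sam2.build_sam import build_sam2
from models.crisp_sam2 import CrispSam2  # imported this way in test.py

# Build the SAM2 backbone from one of the configs under models/sam2/configs/sam2 ...
sam_model = build_sam2(
    config_file="models/sam2/configs/sam2/sam2_hiera_t.yaml",
    checkpoint="/path/to/sam2_checkpoint.pt",
    mode="eval",
)

# ... then hand its encoder and memory modules, plus the CLIP checkpoints, to CRISP-SAM2.
model = CrispSam2(
    image_encoder=sam_model.image_encoder,
    memory_attention=sam_model.memory_attention,
    memory_encoder=sam_model.memory_encoder,
    clip_text_ckpt="/path/to/clip_text_encoder_checkpoint",
    clip_image_ckpt="/path/to/clip_image_encoder_checkpoint",
).eval()
```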
-------------------------------------------------------------------------------- /models/sam2/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | export CKPT="your_path_to_checkpoint" 2 | export WORK_DIR="your_path_to_work_dir" 3 | export DATA_DIR="your_path_to_datasets" 4 | export CLIP_TEXT_CKPT="path_to_clip_text_encoder_checkpoint" 5 | export CLIP_IMAGE_CKPT="path_to_clip_image_encoder_checkpoint" 6 | export CONFIG_FILE="path_to_sam2_config_file" 7 | export SAM2_CKPT="path_to_sam2_checkpoint" 8 | # export HF_TOKEN=xxxxx 9 | # export MASTER_ADDR=xxx.xxx.xxx.xxx 10 | # export MASTER_PORT=xxxx 11 | 12 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train.py \ 13 | --resume $CKPT \ 14 | --work_dir $WORK_DIR \ 15 | --data_dir $DATA_DIR \ 16 | --clip_text_ckpt $CLIP_TEXT_CKPT \ 17 | --clip_image_ckpt $CLIP_IMAGE_CKPT \ 18 | --config_file $CONFIG_FILE \ 19 | --sam2_ckpt $SAM2_CKPT -------------------------------------------------------------------------------- /scripts/test.sh: -------------------------------------------------------------------------------- 1 | export PRETRAINED="your_path_to_checkpoint" 2 | export DATA_DIR="your_path_to_datasets" 3 | export WORK_DIR="your_path_to_work_dir" 4 | export RESULT_DIR="your_path_to_result_dir" 5 | export CLIP_TEXT_CKPT="path_to_clip_text_encoder_checkpoint" 6 | export CLIP_IMAGE_CKPT="path_to_clip_image_encoder_checkpoint" 7 | export CONFIG_FILE="path_to_sam2_config_file" 8 | export SAM2_CKPT="path_to_sam2_checkpoint" 9 | # export HF_TOKEN=xxxxx 10 | # export MASTER_ADDR=xxx.xxx.xxx.xxx 11 | # export MASTER_PORT=xxxx 12 | 13 | 14 | python test.py \ 15 | --pretrain $PRETRAINED \ 16 | --data_dir $DATA_DIR \ 17 | --work_dir $WORK_DIR \ 18 | --result_dir $RESULT_DIR \ 19 | --clip_text_ckpt $CLIP_TEXT_CKPT \ 20 | --clip_image_ckpt $CLIP_IMAGE_CKPT \ 21 | --config_file $CONFIG_FILE \ 22 | --sam2_ckpt $SAM2_CKPT 23 | -------------------------------------------------------------------------------- /visualization/visualization_area.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | 4 | def zoom_in_one_area(image_path, output_path, box, zoom_factor): 5 | image = Image.open(image_path) 6 | left, upper, right, lower = box 7 | assert 0 <= left <= right <= 1 and 0 <= upper <= lower <= 1 8 | left = left * image.size[0] 9 | upper = upper * image.size[1] 10 | right = right * image.size[0] 11 | lower = lower * image.size[1] 12 | # print(image.size) 13 | cropped_image = image.crop((left, upper, right, lower)) 14 | new_size = (int(cropped_image.width * zoom_factor), int(cropped_image.height * zoom_factor)) 15 | zoomed_image = cropped_image.resize(new_size, Image.LANCZOS) 16 | zoomed_image.save(output_path) 17 | 18 | 19 | image_path = '../output/...' 20 | output_path = '../output/...' 
21 | 22 | 23 | box = (0.6, 0.6, 0.8, 0.8) # (x_upper_left, y_upper_left, x_lower_right, y_lower_right), each normalized to [0, 1] 24 | zoom_factor = 2 25 | 26 | zoom_in_one_area(image_path, output_path, box, zoom_factor) 27 | -------------------------------------------------------------------------------- /datasets/data_load_demo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import sparse 3 | import ast 4 | import os 5 | import json 6 | 7 | """ 8 | You should download the whole dataset first at: https://huggingface.co/datasets/GoodBaiBai88/M3D-Seg 9 | Then, use this loader tool to load the dataset 10 | """ 11 | uniseg_path = '/PATH/M3D_Seg' # PATH : your path 12 | dataset_code = '0001' 13 | json_path = os.path.join(uniseg_path, dataset_code, dataset_code + '.json') 14 | with open(json_path, 'r') as f: 15 | dataset_dict = json.load(f) 16 | 17 | ct_file_path = os.path.join(uniseg_path, dataset_dict['train'][0]['image']) 18 | gt_file_path = os.path.join(uniseg_path, dataset_dict['train'][0]['label']) 19 | 20 | img_array = np.load(ct_file_path)[0] 21 | # print('img_array.shape ', img_array.shape) 22 | 23 | allmatrix_sp = sparse.load_npz(gt_file_path) 24 | gt_shape = ast.literal_eval(gt_file_path.split('.')[-2].split('_')[-1]) 25 | gt_array = allmatrix_sp.toarray().reshape(gt_shape) 26 | # print('gt_array.shape ', gt_array.shape) 27 | -------------------------------------------------------------------------------- /datasets/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import nibabel as nib 3 | 4 | 5 | def split_left_right(file_name, output_path): 6 | nifti_img = nib.load(file_name) 7 | segmentation_data = nifti_img.get_fdata() 8 | 9 | LABEL = 2 10 | LEFT_LABEL = 5 11 | RIGHT_LABEL = 6 12 | 13 | kidney_indices = np.argwhere(segmentation_data == LABEL) 14 | new_segmentation_data = np.copy(segmentation_data) 15 | 16 | for idx in kidney_indices: 17 | x, y, z = idx 18 | # determine left or right based on x axis 19 | if x < segmentation_data.shape[0] // 2: 20 | new_segmentation_data[x, y, z] = LEFT_LABEL 21 | else: 22 | new_segmentation_data[x, y, z] = RIGHT_LABEL 23 | 24 | new_nifti_img = nib.Nifti1Image(new_segmentation_data, nifti_img.affine) 25 | nib.save(new_nifti_img, output_path) 26 | 27 | 28 | def generate_point_prompt(mask): 29 | rows, cols = np.where(mask) 30 | x_min, x_max = cols.min(), cols.max() 31 | y_min, y_max = rows.min(), rows.max() 32 | x = x_max - x_min 33 | y = y_max - y_min 34 | 35 | while True: # sample a point that keeps a ~10% margin of the mask extent in all four directions 36 | i0, j0 = np.random.choice(rows), np.random.choice(cols) 37 | if (j0 + int(0.1 * x) < mask.shape[1] and mask[i0, j0 + int(0.1 * x)] and 38 | j0 - int(0.1 * x) >= 0 and mask[i0, j0 - int(0.1 * x)] and 39 | i0 + int(0.1 * y) < mask.shape[0] and mask[i0 + int(0.1 * y), j0] and 40 | i0 - int(0.1 * y) >= 0 and mask[i0 - int(0.1 * y), j0]): 41 | break 42 | return (i0, j0) # (row, col) inside the mask 43 | 44 | 45 | def generate_bbox_prompt(mask): 46 | rows, cols = np.where(mask) 47 | i1 = int(np.mean(rows)) # row (y) center 48 | j1 = int(np.mean(cols)) # column (x) center 49 | 50 | x_min, x_max = cols.min(), cols.max() 51 | y_min, y_max = rows.min(), rows.max() 52 | x = x_max - x_min 53 | y = y_max - y_min 54 | 55 | t1 = np.random.uniform(0.1, 0.3) 56 | t2 = np.random.uniform(0.1, 0.3) 57 | 58 | x1 = int(j1 - t1 * x) 59 | y1 = int(i1 - t2 * y) 60 | x2 = int(j1 + t1 * x) 61 | y2 = int(i1 + t2 * y) 62 | 63 | return (x1, y1, x2, y2) # (x_min, y_min, x_max, y_max) 64 | --------------------------------------------------------------------------------
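Before moving on to the loss, a small self-contained sketch of how the two prompt helpers in `datasets/utils.py` above can be exercised on a toy binary mask. The toy mask and the printing are illustrative only; the helper names and return conventions come from the file above, and the sketch assumes the repository root is on `PYTHONPATH` so that `datasets.utils` is importable.

```
import numpy as np

from datasets.utils import generate_bbox_prompt, generate_point_prompt

# Toy binary mask with a filled rectangle standing in for an organ on one slice.
mask = np.zeros((64, 64), dtype=bool)
mask[20:45, 25:50] = True

point = generate_point_prompt(mask)  # (row, col) sampled inside the mask with a safety margin
bbox = generate_bbox_prompt(mask)    # (x_min, y_min, x_max, y_max) jittered around the mask centroid
print("point prompt:", point)
print("bbox prompt:", bbox)
```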
/models/components/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class CrispSamLoss(nn.Module): 7 | def __init__(self, alpha=0.5, beta=0.1, gamma=0.1, omega1=20 / 21, omega2=1 / 21): 8 | super().__init__() 9 | self.alpha = nn.Parameter(torch.tensor(alpha, requires_grad=True)) 10 | self.beta = nn.Parameter(torch.tensor(beta, requires_grad=True)) 11 | self.gamma = nn.Parameter(torch.tensor(gamma, requires_grad=True)) 12 | self.omega1 = omega1 13 | self.omega2 = omega2 14 | 15 | def focal_loss(self, inputs, targets, gamma=2): 16 | BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none') 17 | pt = torch.exp(-BCE_loss) 18 | F_loss = (1 - pt) ** gamma * BCE_loss 19 | return F_loss.mean() 20 | 21 | def dice_loss(self, inputs, targets): 22 | inputs = torch.sigmoid(inputs) 23 | inputs = inputs.view(-1) 24 | targets = targets.view(-1) 25 | intersection = (inputs * targets).sum() 26 | dice = (2. * intersection) / (inputs.sum() + targets.sum()) 27 | return 1 - dice 28 | 29 | def forward(self, masks_original_pred, masks_refined_pred, masks_gt, iou_pred=None, iou_gt=None, obj_pred=None, obj_gt=None): 30 | 31 | L_focal_original = self.focal_loss(masks_original_pred, masks_gt) 32 | L_dice_original = self.dice_loss(masks_original_pred, masks_gt) 33 | L_seg_original = self.omega1 * L_focal_original + self.omega2 * L_dice_original 34 | 35 | L_focal_refined = self.focal_loss(masks_refined_pred, masks_gt) 36 | L_dice_refined = self.dice_loss(masks_refined_pred, masks_gt) 37 | L_seg_refined = self.omega1 * L_focal_refined + self.omega2 * L_dice_refined 38 | 39 | if iou_pred is not None: 40 | if iou_gt is None: 41 | iou_gt = torch.ones_like(iou_pred) 42 | L_mae = F.l1_loss(iou_pred, iou_gt) 43 | else: 44 | L_mae = torch.tensor(0.0, requires_grad=False) 45 | 46 | if obj_pred is not None: 47 | if obj_gt is None: 48 | obj_gt = torch.ones(obj_pred.size(0), dtype=torch.long, device=obj_pred.device) 49 | L_ce = F.cross_entropy(obj_pred, obj_gt) 50 | else: 51 | L_ce = torch.tensor(0.0, requires_grad=False) 52 | 53 | total_loss = self.alpha * L_seg_original + ( 54 | 1 - self.alpha) * L_seg_refined + self.beta * L_mae + self.gamma * L_ce 55 | return L_seg_original, L_seg_refined, total_loss 56 | 57 | -------------------------------------------------------------------------------- /visualization/utils.py: -------------------------------------------------------------------------------- 1 | import vtk 2 | from visualization_config import * 3 | 4 | 5 | def read_volume(file_name): 6 | if file_name.endswith(".nii.gz"): 7 | reader = vtk.vtkNIFTIImageReader() 8 | elif file_name.endswith(".nrrd"): 9 | reader = vtk.vtkNrrdReader() 10 | reader.SetFileNameSliceOffset(1) 11 | reader.SetDataByteOrderToBigEndian() 12 | reader.SetFileName(file_name) 13 | reader.Update() 14 | return reader 15 | 16 | 17 | def create_mask_extractor(reader): 18 | """ 19 | Given the output from mask (vtkNIFTIImageReader) extract it into 3D using 20 | vtkDiscreteMarchingCubes algorithm (https://www.vtk.org/doc/release/5.0/html/a01331.html). 21 | This algorithm is specialized for reading segmented volume FLARE22. 
22 | :param mask: AbdomenCT-1k vtkNIFTIImageReader volume containing the mask 23 | :return: the extracted volume from vtkDiscreteMarchingCubes 24 | """ 25 | mask_extractor = vtk.vtkDiscreteMarchingCubes() 26 | mask_extractor.SetInputConnection(reader.GetOutputPort()) 27 | return mask_extractor 28 | 29 | 30 | def create_smoother(reducer, smooth_factor): 31 | """ 32 | Reorients some points in the volume to smooth the render edges. 33 | (https://www.vtk.org/doc/nightly/html/classvtkSmoothPolyDataFilter.html) 34 | """ 35 | smoother = vtk.vtkSmoothPolyDataFilter() 36 | smoother.SetInputConnection(reducer.GetOutputPort()) 37 | smoother.SetNumberOfIterations(smooth_factor) 38 | smoother.BoundarySmoothingOn() 39 | return smoother 40 | 41 | 42 | def create_mapper(stripper): 43 | mapper = vtk.vtkPolyDataMapper() 44 | mapper.SetInputConnection(stripper.GetOutputPort()) 45 | mapper.ScalarVisibilityOff() 46 | return mapper 47 | 48 | 49 | def create_property(opacity=0.9, color=(1.0, 0.0, 0.0)): 50 | prop = vtk.vtkProperty() 51 | prop.SetColor(color[0], color[1], color[2]) 52 | prop.SetOpacity(opacity) 53 | # prop.SetRepresentationToWireframe() 54 | return prop 55 | 56 | 57 | def create_actor(mapper, prop): 58 | actor = vtk.vtkActor() 59 | actor.SetMapper(mapper) 60 | actor.SetProperty(prop) 61 | return actor 62 | 63 | 64 | def create_renderer(bg_color): 65 | renderer = vtk.vtkRenderer() 66 | renderer.SetBackground(bg_color[0], bg_color[1], bg_color[2]) 67 | return renderer 68 | 69 | 70 | def create_renderwindow(window_name=APPLICATION_TITLE, window_size=(600, 600)): 71 | render_window = vtk.vtkRenderWindow() 72 | render_window.SetWindowName(window_name) 73 | render_window.SetSize(window_size[0], window_size[1]) 74 | return render_window 75 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from sklearn.metrics.pairwise import cosine_similarity 4 | 5 | 6 | 7 | def calculate_similarity_scores(items): 8 | num_slices = items.shape[0] 9 | similarity_scores = [] 10 | for i in range(num_slices): 11 | current_slice = items[i].flatten().reshape(1, -1) 12 | other_slices = np.delete(items, i, axis=0).reshape(-1, items.shape[1] * items.shape[2]) 13 | similarities = cosine_similarity(current_slice, other_slices) 14 | score = similarities.sum() 15 | similarity_scores.append(score) 16 | return np.array(similarity_scores) 17 | 18 | 19 | # todo : rank the input 20 | def rank_slices(slices): 21 | similarity_scores = calculate_similarity_scores(slices) 22 | ranked_indices = np.argsort(similarity_scores)[::-1] 23 | ranked_slices = slices[ranked_indices] 24 | return ranked_slices 25 | 26 | 27 | def rank_cond_frames(cond_frame_outputs): 28 | features = [] 29 | for cond_frame_index in cond_frame_outputs: 30 | features.append(cond_frame_outputs[cond_frame_index]["maskmem_features"]) 31 | similarity_scores = calculate_similarity_scores(np.array(features)) 32 | ranked_indices = np.argsort(similarity_scores)[::-1] 33 | ranked_cond_frame_outputs = cond_frame_outputs[ranked_indices] 34 | return ranked_cond_frame_outputs 35 | 36 | 37 | def select_similar_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num): 38 | """ 39 | Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs` 40 | that are similar to the current frame at `frame_idx`. 
Here, we take 41 | - the similar conditioning frame `frame_idx` (if any); 42 | 43 | Outputs: 44 | - selected_outputs: selected items (keys & values) from `cond_frame_outputs`. 45 | - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`. 46 | """ 47 | if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num: 48 | selected_outputs = cond_frame_outputs 49 | unselected_outputs = {} 50 | else: 51 | assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames" 52 | selected_outputs = {} 53 | ranked_cond_frame_outputs = rank_cond_frames(cond_frame_outputs) 54 | for cond_frame_index in ranked_cond_frame_outputs: 55 | selected_outputs[cond_frame_index] = cond_frame_outputs[cond_frame_index] 56 | 57 | # add the similar conditioning frame until reaching a total 58 | # of `max_cond_frame_num` conditioning frames. 59 | num_remain = max_cond_frame_num - len(selected_outputs) 60 | inds_remain = sorted( 61 | (t for t in cond_frame_outputs if t not in selected_outputs), 62 | key=lambda x: abs(x - frame_idx), 63 | )[:num_remain] 64 | selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain) 65 | unselected_outputs = { 66 | t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs 67 | } 68 | 69 | return selected_outputs, unselected_outputs 70 | -------------------------------------------------------------------------------- /visualization/visualization_slice.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import matplotlib.pyplot as plt 4 | import nibabel as nib 5 | import nrrd 6 | import numpy as np 7 | from matplotlib.colors import ListedColormap 8 | 9 | from visualization_config import get_mask_colors 10 | 11 | 12 | def get_slice_image(image_path, label_path, label_slice_index, image_slice_index): 13 | image_nii = nib.load(image_path) 14 | image_data = image_nii.get_fdata() 15 | if label_path.endswith("nii.gz"): 16 | label_nii = nib.load(label_path) 17 | label_data = label_nii.get_fdata() 18 | else: 19 | label_data, _ = nrrd.read(label_path) 20 | 21 | unique_labels = set(label_data.flatten()) 22 | label_count = len(unique_labels) 23 | print(f"Label Count : {label_count}") 24 | 25 | # get slice 26 | image_slice = image_data[:, :, image_slice_index] 27 | label_slice = label_data[:, :, label_slice_index] 28 | image_slice = np.rot90(image_slice, k=1) 29 | label_slice = np.rot90(label_slice, k=1) 30 | return image_slice, label_slice 31 | 32 | 33 | def download_image(dataset_name, image_slice, label_slice, output_path): 34 | custom_colors = get_mask_colors(dataset_name) 35 | color_count = len(custom_colors) 36 | print(f"Color Count : {color_count}") 37 | # custom_colors = [(0, 0, 0, 0)] + custom_colors 38 | 39 | unique_labels = np.unique(label_slice) 40 | unique_labels = unique_labels.astype(int) 41 | unique_labels = unique_labels[unique_labels != 0] 42 | # print(unique_labels) 43 | 44 | actual_colors = [custom_colors[i - 1] for i in unique_labels] 45 | cmap = ListedColormap(actual_colors) 46 | 47 | new_label_slice = np.zeros_like(label_slice) 48 | for i, label in enumerate(unique_labels): 49 | new_label_slice[label_slice == label] = i + 1 50 | 51 | fig, ax = plt.subplots() 52 | ax.axis('off') 53 | plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0, wspace=0) 54 | plt.imshow(image_slice, cmap='gray', interpolation='bilinear') 55 | new_label_slice = np.ma.masked_where(new_label_slice == 0, new_label_slice) 56 | plt.imshow(new_label_slice, cmap=cmap, 
alpha=0.75) 57 | plt.axis('off') 58 | plt.savefig(output_path, bbox_inches='tight', pad_inches=0, transparent=True) 59 | plt.show() 60 | 61 | 62 | if __name__ == '__main__': 63 | dataset_name = "..." 64 | image_path = '../input/...' 65 | label_path = '../input/...' 66 | if label_path.endswith("seg.nrrd"): 67 | model_name = label_path.split("/")[-1].split(".")[0].split("_")[-1] 68 | else: 69 | model_name = "GT" 70 | output_name = image_path.split('/')[3].split('.')[0] 71 | output_path = os.path.join('../output/.../', output_name) 72 | if not os.path.exists(output_path): 73 | os.makedirs(output_path) 74 | 75 | png_path = os.path.join(output_path, model_name + ".png") 76 | print(f"output path : {png_path}") 77 | label_slice_index = 67 78 | image_slice_index = 67 79 | 80 | image_slice, label_slice = get_slice_image(image_path=image_path, label_path=label_path, 81 | image_slice_index=image_slice_index, label_slice_index=label_slice_index) 82 | download_image(dataset_name, image_slice, label_slice, png_path) 83 | -------------------------------------------------------------------------------- /models/sam2/backbones/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """Some utilities for backbones, in particular for windowing""" 8 | 9 | from typing import Tuple 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | 16 | def window_partition(x, window_size): 17 | """ 18 | Partition into non-overlapping windows with padding if needed. 19 | Args: 20 | x (tensor): input tokens with [B, H, W, C]. 21 | window_size (int): window size. 22 | Returns: 23 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. 24 | (Hp, Wp): padded height and width before partition 25 | """ 26 | B, H, W, C = x.shape 27 | 28 | pad_h = (window_size - H % window_size) % window_size 29 | pad_w = (window_size - W % window_size) % window_size 30 | if pad_h > 0 or pad_w > 0: 31 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 32 | Hp, Wp = H + pad_h, W + pad_w 33 | 34 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 35 | windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C) 36 | return windows, (Hp, Wp) 37 | 38 | 39 | def window_unpartition(windows, window_size, pad_hw, hw): 40 | """ 41 | Window unpartition into original sequences and removing padding. 42 | Args: 43 | x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 44 | window_size (int): window size. 45 | pad_hw (Tuple): padded height and width (Hp, Wp). 46 | hw (Tuple): original height and width (H, W) before padding. 47 | Returns: 48 | x: unpartitioned sequences with [B, H, W, C]. 49 | """ 50 | Hp, Wp = pad_hw 51 | H, W = hw 52 | B = windows.shape[0] // (Hp * Wp // window_size // window_size) 53 | x = windows.reshape( 54 | B, Hp // window_size, Wp // window_size, window_size, window_size, -1 55 | ) 56 | x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1) 57 | 58 | if Hp > H or Wp > W: 59 | x = x[:, :H, :W, :] 60 | return x 61 | 62 | 63 | class PatchEmbed(nn.Module): 64 | """ 65 | Image to Patch Embedding. 66 | """ 67 | 68 | def __init__( 69 | self, 70 | kernel_size: Tuple[int, ...] = (7, 7), 71 | stride: Tuple[int, ...] = (4, 4), 72 | padding: Tuple[int, ...] 
= (3, 3), 73 | in_chans: int = 3, 74 | embed_dim: int = 768, 75 | ): 76 | """ 77 | Args: 78 | kernel_size (Tuple): kernel size of the projection layer. 79 | stride (Tuple): stride of the projection layer. 80 | padding (Tuple): padding size of the projection layer. 81 | in_chans (int): Number of input image channels. 82 | embed_dim (int): embed_dim (int): Patch embedding dimension. 83 | """ 84 | super().__init__() 85 | self.proj = nn.Conv2d( 86 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding 87 | ) 88 | 89 | def forward(self, x: torch.Tensor) -> torch.Tensor: 90 | x = self.proj(x) 91 | # B C H W -> B H W C 92 | x = x.permute(0, 2, 3, 1) 93 | return x -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # (ACM MM 25) CRISP-SAM2 : SAM2 with Cross-Modal Interaction and Semantic Prompting for Multi-Organ Segmentation 2 | #### Xinlei Yu 1, Changmiao Wang 2, Hui Jin 1, Ahmed Elazab 3, Gangyong Jia 1, Xiang Wan 2, Changqing Zou 4, Ruiquan Ge 1 3 | #### 1 Hangzhou Dianzi University, 2 Shenzhen Research Institute of Big Data, 3 Shenzhen University, 4 Zhejiang University 4 | 5 | 6 | ## 🌟Overview 7 | 8 | ![overview](static/overview.png) 9 | 10 | ## 🛠️ Quick Start 11 | ## Installation 12 | It is highly recommended to employ a virtual environment with Python >= 3.10, Pytorch >= 2.5.1 and corresponding CUDA. 13 | ``` 14 | cd CRISP-SAM2 15 | conda env create -f env_config.yml 16 | conda activate CRISP_SAM2 17 | ``` 18 | 19 | 20 | ## Dataset Preparation 21 | - ### Visual Inputs 22 | | Datasets | Links | 23 | |--------------|---------------------------------------------------------------------------------------------------------------------------------------------------------| 24 | | M3D-Seg | https://github.com/BAAIDCAI/M3D/
https://huggingface.co/datasets/GoodBaiBai88/M3D-Seg/
https://www.modelscope.cn/datasets/GoodBaiBai88/M3D-Seg/ | 25 | | MSD-Spleen | http://medicaldecathlon.com/ | 26 | | Pancreas-CT | https://wiki.cancerimagingarchive.net/display/public/pancreas-ct/ | 27 | | LUNA16 | https://luna16.grand-challenge.org/Data/ | 28 | | AbdomenCT-1k | https://github.com/JunMa11/AbdomenCT-1K/ | 29 | | WORD | https://paperswithcode.com/dataset/word/ | 30 | | FLARE22 | https://flare22.grand-challenge.org/ | 31 | | AMOS22 | https://amos22.grand-challenge.org/ | 32 | 33 | - ### Textual Inputs 34 | The descriptive definitions and descriptions are stored in 'term_dictionary.json', and compared to the original M3D-Seg joint dataset, we add supplementary sentences. Here, the dictionary can be expanded arbitrarily as required. 35 | 36 | ## Train & Test 37 | - ### Training process 38 | We highly recommend that the whole training process should be conducted on at least 8 A100-80G GPUs. 39 | ``` 40 | bash scripts/train.sh 41 | ``` 42 | - ### Test process 43 | ``` 44 | bash scripts/test.sh 45 | ``` 46 | 47 | 48 | ## Visualization 49 | We provide comprehensive visualization utils, including 2D, 3D and local area visualization. 50 | 51 | ![visualization](static/visualization.png) 52 | ![visualization](static/visualization_appendix.png) 53 | 54 | ## Citation 55 | If you have any questions about this work, please feel free to contact me at: xinleiyu88@gmail.com. And if you want to cite us, please add this in your paper: 56 | ``` 57 | @article{yu2025crisp, 58 | title={CRISP-SAM2: SAM2 with Cross-Modal Interaction and Semantic Prompting for Multi-Organ Segmentation}, 59 | author={Yu, Xinlei and Wang, Changmiao and Jin, Hui and Elazab, Ahmed and Jia, Gangyong and Wan, Xiang and Zou, Changqing and Ge, Ruiquan}, 60 | journal={arXiv preprint arXiv:2506.23121}, 61 | year={2025} 62 | } 63 | ``` -------------------------------------------------------------------------------- /models/sam2/configs/sam2/sam2_hiera_b+.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 112 12 | num_heads: 2 13 | neck: 14 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 15 | position_encoding: 16 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 17 | num_pos_feats: 256 18 | normalize: true 19 | scale: null 20 | temperature: 10000 21 | d_model: 256 22 | backbone_channel_list: [896, 448, 224, 112] 23 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 24 | fpn_interp_model: nearest 25 | 26 | memory_attention: 27 | _target_: sam2.modeling.memory_attention.MemoryAttention 28 | d_model: 256 29 | pos_enc_at_input: true 30 | layer: 31 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 32 | activation: relu 33 | dim_feedforward: 2048 34 | dropout: 0.1 35 | pos_enc_at_attn: false 36 | self_attention: 37 | _target_: sam2.modeling.sam.transformer.RoPEAttention 38 | rope_theta: 10000.0 39 | feat_sizes: [64, 64] 40 | embedding_dim: 256 41 | num_heads: 1 42 | downsample_rate: 1 43 | dropout: 0.1 44 | d_model: 256 45 | pos_enc_at_cross_attn_keys: true 46 | pos_enc_at_cross_attn_queries: false 47 | cross_attention: 48 | _target_: sam2.modeling.sam.transformer.RoPEAttention 49 | rope_theta: 10000.0 50 | feat_sizes: [64, 64] 51 | rope_k_repeat: True 
52 | embedding_dim: 256 53 | num_heads: 1 54 | downsample_rate: 1 55 | dropout: 0.1 56 | kv_in_dim: 64 57 | num_layers: 4 58 | 59 | memory_encoder: 60 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 61 | out_dim: 64 62 | position_encoding: 63 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 64 | num_pos_feats: 64 65 | normalize: true 66 | scale: null 67 | temperature: 10000 68 | mask_downsampler: 69 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 70 | kernel_size: 3 71 | stride: 2 72 | padding: 1 73 | fuser: 74 | _target_: sam2.modeling.memory_encoder.Fuser 75 | layer: 76 | _target_: sam2.modeling.memory_encoder.CXBlock 77 | dim: 256 78 | kernel_size: 7 79 | padding: 3 80 | layer_scale_init_value: 1e-6 81 | use_dwconv: True # depth-wise convs 82 | num_layers: 2 83 | 84 | num_maskmem: 7 85 | image_size: 1024 86 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 87 | sigmoid_scale_for_mem_enc: 20.0 88 | sigmoid_bias_for_mem_enc: -10.0 89 | use_mask_input_as_output_without_sam: true 90 | # Memory 91 | directly_add_no_mem_embed: true 92 | # use high-resolution feature map in the SAM mask decoder 93 | use_high_res_features_in_sam: true 94 | # output 3 masks on the first click on initial conditioning frames 95 | multimask_output_in_sam: true 96 | # SAM heads 97 | iou_prediction_use_sigmoid: True 98 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 99 | use_obj_ptrs_in_encoder: true 100 | add_tpos_enc_to_obj_ptrs: false 101 | only_obj_ptrs_in_the_past_for_eval: true 102 | # object occlusion prediction 103 | pred_obj_scores: true 104 | pred_obj_scores_mlp: true 105 | fixed_no_obj_ptr: true 106 | # multimask tracking settings 107 | multimask_output_for_tracking: true 108 | use_multimask_token_for_obj_ptr: true 109 | multimask_min_pt_num: 0 110 | multimask_max_pt_num: 1 111 | use_mlp_for_obj_ptr_proj: true 112 | # Compilation flag 113 | compile_image_encoder: False 114 | -------------------------------------------------------------------------------- /models/sam2/configs/sam2/sam2_hiera_s.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 11, 2] 14 | global_att_blocks: [7, 10, 13] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 10000 24 | d_model: 256 25 | backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: sam2.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | pos_enc_at_input: true 33 | layer: 34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 10000.0 42 | feat_sizes: [64, 64] 43 | embedding_dim: 256 44 | 
num_heads: 1 45 | downsample_rate: 1 46 | dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | pos_enc_at_cross_attn_queries: false 50 | cross_attention: 51 | _target_: sam2.modeling.sam.transformer.RoPEAttention 52 | rope_theta: 10000.0 53 | feat_sizes: [64, 64] 54 | rope_k_repeat: True 55 | embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | dropout: 0.1 59 | kv_in_dim: 64 60 | num_layers: 4 61 | 62 | memory_encoder: 63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 | position_encoding: 66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 67 | num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 76 | fuser: 77 | _target_: sam2.modeling.memory_encoder.Fuser 78 | layer: 79 | _target_: sam2.modeling.memory_encoder.CXBlock 80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 86 | 87 | num_maskmem: 7 88 | image_size: 1024 89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 90 | sigmoid_scale_for_mem_enc: 20.0 91 | sigmoid_bias_for_mem_enc: -10.0 92 | use_mask_input_as_output_without_sam: true 93 | # Memory 94 | directly_add_no_mem_embed: true 95 | # use high-resolution feature map in the SAM mask decoder 96 | use_high_res_features_in_sam: true 97 | # output 3 masks on the first click on initial conditioning frames 98 | multimask_output_in_sam: true 99 | # SAM heads 100 | iou_prediction_use_sigmoid: True 101 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 102 | use_obj_ptrs_in_encoder: true 103 | add_tpos_enc_to_obj_ptrs: false 104 | only_obj_ptrs_in_the_past_for_eval: true 105 | # object occlusion prediction 106 | pred_obj_scores: true 107 | pred_obj_scores_mlp: true 108 | fixed_no_obj_ptr: true 109 | # multimask tracking settings 110 | multimask_output_for_tracking: true 111 | use_multimask_token_for_obj_ptr: true 112 | multimask_min_pt_num: 0 113 | multimask_max_pt_num: 1 114 | use_mlp_for_obj_ptr_proj: true 115 | # Compilation flag 116 | compile_image_encoder: False 117 | -------------------------------------------------------------------------------- /models/sam2/configs/sam2/sam2_hiera_l.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 144 12 | num_heads: 2 13 | stages: [2, 6, 36, 4] 14 | global_att_blocks: [23, 33, 43] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | window_spec: [8, 4, 16, 8] 17 | neck: 18 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 19 | position_encoding: 20 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 21 | num_pos_feats: 256 22 | normalize: true 23 | scale: null 24 | temperature: 10000 25 | d_model: 256 26 | backbone_channel_list: [1152, 576, 288, 144] 27 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 28 | fpn_interp_model: nearest 29 | 30 | memory_attention: 31 | _target_: sam2.modeling.memory_attention.MemoryAttention 32 | d_model: 256 33 | 
pos_enc_at_input: true 34 | layer: 35 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 36 | activation: relu 37 | dim_feedforward: 2048 38 | dropout: 0.1 39 | pos_enc_at_attn: false 40 | self_attention: 41 | _target_: sam2.modeling.sam.transformer.RoPEAttention 42 | rope_theta: 10000.0 43 | feat_sizes: [64, 64] 44 | embedding_dim: 256 45 | num_heads: 1 46 | downsample_rate: 1 47 | dropout: 0.1 48 | d_model: 256 49 | pos_enc_at_cross_attn_keys: true 50 | pos_enc_at_cross_attn_queries: false 51 | cross_attention: 52 | _target_: sam2.modeling.sam.transformer.RoPEAttention 53 | rope_theta: 10000.0 54 | feat_sizes: [64, 64] 55 | rope_k_repeat: True 56 | embedding_dim: 256 57 | num_heads: 1 58 | downsample_rate: 1 59 | dropout: 0.1 60 | kv_in_dim: 64 61 | num_layers: 4 62 | 63 | memory_encoder: 64 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 65 | out_dim: 64 66 | position_encoding: 67 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 68 | num_pos_feats: 64 69 | normalize: true 70 | scale: null 71 | temperature: 10000 72 | mask_downsampler: 73 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 74 | kernel_size: 3 75 | stride: 2 76 | padding: 1 77 | fuser: 78 | _target_: sam2.modeling.memory_encoder.Fuser 79 | layer: 80 | _target_: sam2.modeling.memory_encoder.CXBlock 81 | dim: 256 82 | kernel_size: 7 83 | padding: 3 84 | layer_scale_init_value: 1e-6 85 | use_dwconv: True # depth-wise convs 86 | num_layers: 2 87 | 88 | num_maskmem: 7 89 | image_size: 1024 90 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | # use high-resolution feature map in the SAM mask decoder 97 | use_high_res_features_in_sam: true 98 | # output 3 masks on the first click on initial conditioning frames 99 | multimask_output_in_sam: true 100 | # SAM heads 101 | iou_prediction_use_sigmoid: True 102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 103 | use_obj_ptrs_in_encoder: true 104 | add_tpos_enc_to_obj_ptrs: false 105 | only_obj_ptrs_in_the_past_for_eval: true 106 | # object occlusion prediction 107 | pred_obj_scores: true 108 | pred_obj_scores_mlp: true 109 | fixed_no_obj_ptr: true 110 | # multimask tracking settings 111 | multimask_output_for_tracking: true 112 | use_multimask_token_for_obj_ptr: true 113 | multimask_min_pt_num: 0 114 | multimask_max_pt_num: 1 115 | use_mlp_for_obj_ptr_proj: true 116 | # Compilation flag 117 | compile_image_encoder: False 118 | -------------------------------------------------------------------------------- /models/sam2/configs/sam2/sam2_hiera_t.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 7, 2] 14 | global_att_blocks: [5, 7, 9] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 
10000 24 | d_model: 256 25 | backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: sam2.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | pos_enc_at_input: true 33 | layer: 34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 10000.0 42 | feat_sizes: [64, 64] 43 | embedding_dim: 256 44 | num_heads: 1 45 | downsample_rate: 1 46 | dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | pos_enc_at_cross_attn_queries: false 50 | cross_attention: 51 | _target_: sam2.modeling.sam.transformer.RoPEAttention 52 | rope_theta: 10000.0 53 | feat_sizes: [64, 64] 54 | rope_k_repeat: True 55 | embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | dropout: 0.1 59 | kv_in_dim: 64 60 | num_layers: 4 61 | 62 | memory_encoder: 63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 | position_encoding: 66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 67 | num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 76 | fuser: 77 | _target_: sam2.modeling.memory_encoder.Fuser 78 | layer: 79 | _target_: sam2.modeling.memory_encoder.CXBlock 80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 86 | 87 | num_maskmem: 7 88 | image_size: 1024 89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 90 | # SAM decoder 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | # use high-resolution feature map in the SAM mask decoder 97 | use_high_res_features_in_sam: true 98 | # output 3 masks on the first click on initial conditioning frames 99 | multimask_output_in_sam: true 100 | # SAM heads 101 | iou_prediction_use_sigmoid: True 102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 103 | use_obj_ptrs_in_encoder: true 104 | add_tpos_enc_to_obj_ptrs: false 105 | only_obj_ptrs_in_the_past_for_eval: true 106 | # object occlusion prediction 107 | pred_obj_scores: true 108 | pred_obj_scores_mlp: true 109 | fixed_no_obj_ptr: true 110 | # multimask tracking settings 111 | multimask_output_for_tracking: true 112 | use_multimask_token_for_obj_ptr: true 113 | multimask_min_pt_num: 0 114 | multimask_max_pt_num: 1 115 | use_mlp_for_obj_ptr_proj: true 116 | # Compilation flag 117 | # HieraT does not currently support compilation, should always be set to False 118 | compile_image_encoder: False 119 | -------------------------------------------------------------------------------- /visualization/visualization_3d.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | 3 | import SimpleITK as sitk 4 | 5 | from visualization_config import * 6 | from utils import * 7 | 8 | 9 | def mhd_to_nii(file_path: str): 10 | itk_image = sitk.ReadImage(file_path) 11 | save_path = 
file_path.replace(".mhd", ".nii.gz") 12 | out_arr = sitk.GetArrayFromImage(itk_image) 13 | out = sitk.GetImageFromArray(out_arr) 14 | sitk.WriteImage(out, save_path) 15 | return save_path 16 | 17 | 18 | if __name__ == '__main__': 19 | # configs 20 | show_axes = False 21 | show_outline = False 22 | generate_outline_face = False 23 | take_snapshot = True 24 | SMOOTH_FACTOR = 400 25 | ROTATE_X = 270 26 | ROTATE_Y = 0 27 | ROTATE_Z = 0 28 | model_name = 'ours' 29 | dataset_name = '.../' 30 | MASK_COLORS = get_mask_colors(dataset_name) 31 | MASK_OPACITY = get_mask_opacity(dataset_name) 32 | file_name = '../input/...' 33 | if file_name.endswith(".mhd"): 34 | file_name = mhd_to_nii(file_name) 35 | output_name = file_name.split('/')[3].split('.')[12] 36 | else: 37 | output_name = file_name.split('/')[3].split('.')[0] 38 | output_path = '../output/' + dataset_name + output_name 39 | if not os.path.exists(output_path): 40 | os.makedirs(output_path) 41 | print(f"file path : {file_name}") 42 | print(f"output path : {output_path}") 43 | rotate_config = '_' + str(SMOOTH_FACTOR) + '_' + str(ROTATE_X) + '_' + str(ROTATE_Y) + '_' + str(ROTATE_Z) 44 | snapshot_filename = output_path + '/' + model_name + rotate_config + '.png' 45 | 46 | # reader 47 | reader = read_volume(file_name) 48 | 49 | # transform 50 | mask_transform = vtk.vtkTransform() 51 | mask_transform.PostMultiply() 52 | mask_transform.Scale(SCALE) # scale 53 | mask_transform.RotateX(ROTATE_X) # rotate 54 | mask_transform.RotateY(ROTATE_Y) 55 | mask_transform.RotateZ(ROTATE_Z) 56 | 57 | # renderer and render window 58 | renderer = create_renderer(bg_color=RENDERER_BG_COLOR) 59 | render_window = create_renderwindow() 60 | render_window.AddRenderer(renderer) 61 | # render_window.SetSize(100, 100) 62 | # renderer.SetViewport(0.1, 0.1, 0.9, 0.9) 63 | # mapper and actors for segmentation results 64 | n_labels = int(reader.GetOutput().GetScalarRange()[1]) 65 | print(n_labels) 66 | 67 | for idx in range(n_labels): 68 | extracter = create_mask_extractor(reader) # extracter 69 | extracter.SetValue(0, idx + 1) 70 | smoother = create_smoother(extracter, SMOOTH_FACTOR) # smoother 71 | mapper = create_mapper(stripper=smoother) 72 | prop = create_property(opacity=MASK_OPACITY[idx], color=MASK_COLORS[idx]) # property 73 | actor = create_actor(mapper=mapper, prop=prop) # actor 74 | actor.SetUserTransform(mask_transform) 75 | renderer.AddActor(actor) 76 | 77 | # outline of the whole image 78 | if show_outline: 79 | outline = vtk.vtkOutlineFilter() # show outline 80 | outline.SetInputConnection(reader.GetOutputPort()) 81 | if generate_outline_face: # show surface of the outline 82 | outline.GenerateFacesOn() 83 | extracter = create_mask_extractor(reader) 84 | mapper = create_mapper(stripper=outline) 85 | prop = create_property(opacity=OUTLINE_OPACITY, color=OUTLINE_COLOR) 86 | actor = create_actor(mapper=mapper, prop=prop) 87 | actor.SetUserTransform(mask_transform) 88 | renderer.AddActor(actor) 89 | 90 | # show axes for better visualization 91 | if show_axes: 92 | axes_actor = vtk.vtkAxesActor() 93 | axes_actor.SetTotalLength(TOTAL_LENGTH[0], TOTAL_LENGTH[1], TOTAL_LENGTH[2]) # set axes length 94 | axes_actor.SetScale(5, 5, 5) 95 | renderer.AddActor(axes_actor) 96 | 97 | # start render 98 | render_window.Render() 99 | 100 | # screenshot 101 | if take_snapshot: 102 | w2if = vtk.vtkWindowToImageFilter() 103 | w2if.SetInput(render_window) 104 | w2if.SetInputBufferTypeToRGB() 105 | w2if.ReadFrontBufferOff() 106 | w2if.Update() 107 | 108 | writer = vtk.vtkPNGWriter() 
109 | writer.SetFileName(snapshot_filename) 110 | writer.SetInputConnection(w2if.GetOutputPort()) 111 | writer.Write() 112 | 113 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | from datasets.data_loader import get_loader 4 | from models.crisp_sam2 import CrispSam2 5 | from models.sam2.build_sam import build_sam2 6 | import os 7 | import argparse 8 | from datetime import datetime 9 | import numpy as np 10 | from monai.metrics import compute_meandice, compute_average_surface_distance 11 | 12 | 13 | def get_test_args_parser(): 14 | parser = argparse.ArgumentParser(description='Test Configuration') 15 | 16 | parser.add_argument("--mode", type=str, default='test') 17 | parser.add_argument("--pretrain", type=str, default='/path/to/pretrained_model') 18 | parser.add_argument("--data_dir", type=str, default='/path/to/data') 19 | parser.add_argument("--dataset_codes", type=list, default=['0003']) 20 | parser.add_argument("--patch_size", default=(96, 96, 96), type=tuple) 21 | parser.add_argument("--spatial_size", default=(32, 512, 512), type=tuple) 22 | parser.add_argument("--work_dir", type=str, default='./work_dir') 23 | parser.add_argument("--config_file", type=str, default='./path/to/config_file') 24 | parser.add_argument("--sam2_ckpt", type=str, default='./path/to/sam2_ckpt') 25 | # parser.add_argument("--model_id", type=str, default='hf_id') 26 | parser.add_argument("--clip_text_ckpt", type=str, default='./path/to/clip_text_ckpt') 27 | parser.add_argument("--clip_image_ckpt", type=str, default='./path/to/clip_image_ckpt') 28 | parser.add_argument('--num_workers', type=int, default=8) 29 | parser.add_argument('--batch_size', type=int, default=1) 30 | parser.add_argument('--result_dir', type=str, default='./results') 31 | parser.add_argument('--gpu', type=int, default=0) 32 | return parser.parse_args() 33 | 34 | 35 | def test(args): 36 | sam_model = build_sam2(config_file=args.config_file, checkpoint=args.sam2_ckpt, mode="eval") 37 | model = CrispSam2( 38 | image_encoder=sam_model.image_encoder, 39 | memory_attention=sam_model.memory_attention, 40 | memory_encoder=sam_model.memory_encoder, 41 | clip_text_ckpt=args.clip_text_ckpt, 42 | clip_image_ckpt=args.clip_image_ckpt, 43 | ).cuda(args.gpu) 44 | model.eval() 45 | 46 | if args.pretrain: 47 | checkpoint = torch.load(args.pretrain, map_location=f'cuda:{args.gpu}') 48 | model.load_state_dict(checkpoint['model']) 49 | print(f"Loaded checkpoint from {args.pretrain}") 50 | 51 | test_dataloader = get_loader(args) 52 | 53 | os.makedirs(args.result_dir, exist_ok=True) 54 | log_file = os.path.join(args.result_dir, f"test_log_{datetime.now().strftime('%Y%m%d-%H%M')}.txt") 55 | 56 | total_dsc = [] 57 | total_nsd = [] 58 | with torch.no_grad(), tqdm(total=len(test_dataloader), desc="Testing") as pbar: 59 | for batch in test_dataloader: 60 | image = batch["image"].cuda(args.gpu) 61 | gt3D = batch["post_label"].cuda(args.gpu) 62 | organ_name_list = batch['organ_name_list'] 63 | text_list = batch['text'] 64 | patient_id = batch['patient_id'][0] 65 | 66 | pred_masks = [] 67 | for cls_idx in range(len(organ_name_list)): 68 | labels_cls = gt3D[:, cls_idx] 69 | organ_name = organ_name_list[cls_idx] 70 | crisp_masks = [] 71 | 72 | for frame_idx in range(image.size()[-1]): 73 | output = model(image, text_list, frame_idx) 74 | crisp_mask = output['crisp_masks'].sigmoid() > 0.5 75 | 
crisp_masks.append(crisp_mask.cpu().numpy()) 76 | 77 | mask_3d = np.stack(crisp_masks, axis=0).squeeze() 78 | pred_masks.append(mask_3d) 79 | save_path = os.path.join(args.result_dir, patient_id) 80 | os.makedirs(save_path, exist_ok=True) 81 | np.save(os.path.join(save_path, f"{organ_name}_pred.npy"), mask_3d) 82 | 83 | if gt3D is not None: 84 | gt_mask = labels_cls.squeeze().cpu().numpy() 85 | dsc = compute_meandice(mask_3d, gt_mask, include_background=False) 86 | total_dsc.append(dsc) 87 | nsd = compute_average_surface_distance(mask_3d, gt_mask, include_background=False) 88 | total_nsd.append(nsd) 89 | with open(log_file, 'a') as f: 90 | f.write(f"Patient {patient_id}, Organ {organ_name}: DSC={dsc:.4f}, NSD={nsd:.4f}\n") 91 | 92 | pbar.update(1) 93 | 94 | if total_dsc and total_nsd: 95 | mean_dsc = np.mean(total_dsc) 96 | mean_nsd = np.mean(total_nsd) 97 | print(f"Test Finished. Mean DSC: {mean_dsc:.4f}, Mean NSD: {mean_nsd:.4f}") 98 | with open(log_file, 'a') as f: 99 | f.write(f"Overall Mean DSC: {mean_dsc:.4f}, Overall Mean NSD: {mean_nsd:.4f}\n") 100 | else: 101 | print("No ground truth provided, skipped metric calculation.") 102 | 103 | 104 | def main(): 105 | args = get_test_args_parser() 106 | print("Test Arguments:", args) 107 | test(args) 108 | 109 | 110 | if __name__ == '__main__': 111 | main() 112 | -------------------------------------------------------------------------------- /models/sam2/utils/transformers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import warnings 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torchvision.transforms import Normalize, Resize, ToTensor 13 | 14 | 15 | class SAM2Transforms(nn.Module): 16 | def __init__( 17 | self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0 18 | ): 19 | """ 20 | Transforms for SAM2. 21 | """ 22 | super().__init__() 23 | self.resolution = resolution 24 | self.mask_threshold = mask_threshold 25 | self.max_hole_area = max_hole_area 26 | self.max_sprinkle_area = max_sprinkle_area 27 | self.mean = [0.485, 0.456, 0.406] 28 | self.std = [0.229, 0.224, 0.225] 29 | self.to_tensor = ToTensor() 30 | self.transforms = torch.jit.script( 31 | nn.Sequential( 32 | Resize((self.resolution, self.resolution)), 33 | Normalize(self.mean, self.std), 34 | ) 35 | ) 36 | 37 | def __call__(self, x): 38 | x = self.to_tensor(x) 39 | return self.transforms(x) 40 | 41 | def forward_batch(self, img_list): 42 | img_batch = [self.transforms(self.to_tensor(img)) for img in img_list] 43 | img_batch = torch.stack(img_batch, dim=0) 44 | return img_batch 45 | 46 | def transform_coords( 47 | self, coords: torch.Tensor, normalize=False, orig_hw=None 48 | ) -> torch.Tensor: 49 | """ 50 | Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates, 51 | If the coords are in absolute image coordinates, normalize should be set to True and original image size is required. 52 | 53 | Returns 54 | Un-normalized coordinates in the range of [0, 1] which is expected by the SAM2 model. 
55 | """ 56 | if normalize: 57 | assert orig_hw is not None 58 | h, w = orig_hw 59 | coords = coords.clone() 60 | coords[..., 0] = coords[..., 0] / w 61 | coords[..., 1] = coords[..., 1] / h 62 | 63 | coords = coords * self.resolution # unnormalize coords 64 | return coords 65 | 66 | def transform_boxes( 67 | self, boxes: torch.Tensor, normalize=False, orig_hw=None 68 | ) -> torch.Tensor: 69 | """ 70 | Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates, 71 | if the coords are in absolute image coordinates, normalize should be set to True and original image size is required. 72 | """ 73 | boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw) 74 | return boxes 75 | 76 | def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor: 77 | """ 78 | Perform PostProcessing on output masks. 79 | """ 80 | from sam2.utils.misc import get_connected_components 81 | 82 | masks = masks.float() 83 | input_masks = masks 84 | mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image 85 | try: 86 | if self.max_hole_area > 0: 87 | # Holes are those connected components in background with area <= self.fill_hole_area 88 | # (background regions are those with mask scores <= self.mask_threshold) 89 | labels, areas = get_connected_components( 90 | mask_flat <= self.mask_threshold 91 | ) 92 | is_hole = (labels > 0) & (areas <= self.max_hole_area) 93 | is_hole = is_hole.reshape_as(masks) 94 | # We fill holes with a small positive mask score (10.0) to change them to foreground. 95 | masks = torch.where(is_hole, self.mask_threshold + 10.0, masks) 96 | 97 | if self.max_sprinkle_area > 0: 98 | labels, areas = get_connected_components( 99 | mask_flat > self.mask_threshold 100 | ) 101 | is_hole = (labels > 0) & (areas <= self.max_sprinkle_area) 102 | is_hole = is_hole.reshape_as(masks) 103 | # We fill holes with negative mask score (-10.0) to change them to background. 104 | masks = torch.where(is_hole, self.mask_threshold - 10.0, masks) 105 | except Exception as e: 106 | # Skip the post-processing step if the CUDA kernel fails 107 | warnings.warn( 108 | f"{e}\n\nSkipping the post-processing step due to the error above. 
You can " 109 | "still use SAM 2 and it's OK to ignore the error above, although some post-processing " 110 | "functionality may be limited (which doesn't affect the results in most cases; see " 111 | "https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).", 112 | category=UserWarning, 113 | stacklevel=2, 114 | ) 115 | masks = input_masks 116 | 117 | masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False) 118 | return masks -------------------------------------------------------------------------------- /visualization/visualization_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | GUI config 3 | """ 4 | # main window, only for Qt GUI 5 | APPLICATION_TITLE = "3D Segmentation Visualizer" 6 | MIN_IMG_WINDOW_WIDTH = 200 7 | 8 | """ 9 | VTK config 10 | """ 11 | COLOR_CONFIG = { 12 | "liver": (220 / 256, 0 / 256, 0 / 256), 13 | "spleen": (0 / 256, 220 / 256, 0 / 256), 14 | "stomach": (0 / 256, 220 / 256, 220 / 256), 15 | "gallbladder": (220 / 256, 0 / 256, 220 / 256), 16 | "esophagus": (225 / 256, 210 / 256, 190 / 256), 17 | "pancreas": (0 / 256, 0 / 256, 250 / 256), 18 | "duodenum": (200 / 256, 130 / 256, 60 / 256), 19 | "aorta": (0 / 256, 120 / 256, 120 / 256), 20 | "bladder": (40 / 256, 140 / 256, 90 / 256), 21 | "inferior vena cava": (220 / 256, 190 / 256, 150 / 256), 22 | "left kidney": (40 / 256, 60 / 256, 80 / 256), 23 | "right kidney": (220 / 256, 220 / 256, 0 / 256), 24 | "left adrenal gland": (90 / 256, 20 / 256, 110 / 256), 25 | "right adrenal gland": (200 / 256, 60 / 256, 90 / 256), 26 | "left femur": (255 / 256, 230 / 256, 225 / 256), 27 | "right femur": (110 / 256, 90 / 256, 200 / 256), 28 | "left lung": (100 / 256, 40 / 256, 140 / 256), 29 | "right lung": (60 / 256, 100 / 256, 200 / 256), 30 | "default": (1, 1, 1) 31 | } 32 | 33 | 34 | # mask config 35 | def get_mask_colors(dataset_name: str): 36 | if "word" in dataset_name.lower(): 37 | return [ 38 | # # WORD 39 | COLOR_CONFIG["liver"], 40 | COLOR_CONFIG["spleen"], 41 | COLOR_CONFIG["left kidney"], 42 | COLOR_CONFIG["right kidney"], 43 | COLOR_CONFIG["stomach"], 44 | COLOR_CONFIG["gallbladder"], 45 | COLOR_CONFIG["esophagus"], 46 | COLOR_CONFIG["pancreas"], 47 | COLOR_CONFIG["duodenum"], 48 | COLOR_CONFIG["default"], 49 | COLOR_CONFIG["default"], 50 | COLOR_CONFIG["default"], 51 | COLOR_CONFIG["default"], 52 | COLOR_CONFIG["bladder"], 53 | COLOR_CONFIG["left femur"], 54 | COLOR_CONFIG["right femur"], 55 | ] 56 | elif "flare" in dataset_name.lower(): 57 | return [ 58 | COLOR_CONFIG["liver"], 59 | COLOR_CONFIG["right kidney"], 60 | COLOR_CONFIG["spleen"], 61 | COLOR_CONFIG["pancreas"], 62 | COLOR_CONFIG["aorta"], 63 | COLOR_CONFIG["inferior vena cava"], 64 | COLOR_CONFIG["right adrenal gland"], 65 | COLOR_CONFIG["left adrenal gland"], 66 | COLOR_CONFIG["gallbladder"], 67 | COLOR_CONFIG["esophagus"], 68 | COLOR_CONFIG["stomach"], 69 | COLOR_CONFIG["duodenum"], 70 | COLOR_CONFIG["left kidney"], 71 | ] 72 | elif "abdomen" in dataset_name.lower(): 73 | return [ 74 | COLOR_CONFIG["liver"], 75 | COLOR_CONFIG["default"], 76 | COLOR_CONFIG["spleen"], 77 | COLOR_CONFIG["pancreas"], 78 | COLOR_CONFIG["left kidney"], 79 | COLOR_CONFIG["right kidney"], 80 | ] 81 | elif "amos" in dataset_name.lower(): 82 | return [ 83 | COLOR_CONFIG["spleen"], 84 | COLOR_CONFIG["right kidney"], 85 | COLOR_CONFIG["left kidney"], 86 | COLOR_CONFIG["gallbladder"], 87 | COLOR_CONFIG["esophagus"], 88 | COLOR_CONFIG["liver"], 89 | COLOR_CONFIG["stomach"], 90 | COLOR_CONFIG["aorta"], 91 
| COLOR_CONFIG["inferior vena cava"], 92 | COLOR_CONFIG["pancreas"], 93 | COLOR_CONFIG["right adrenal gland"], 94 | COLOR_CONFIG["left adrenal gland"], 95 | COLOR_CONFIG["duodenum"], 96 | COLOR_CONFIG["bladder"], 97 | COLOR_CONFIG["default"], 98 | ] 99 | elif "luna" in dataset_name.lower(): 100 | return [ 101 | COLOR_CONFIG["default"], 102 | COLOR_CONFIG["default"], 103 | COLOR_CONFIG["right lung"], 104 | COLOR_CONFIG["left lung"], 105 | COLOR_CONFIG["default"], 106 | ] 107 | elif "spleen" in dataset_name.lower(): 108 | return [ 109 | COLOR_CONFIG["spleen"], 110 | ] 111 | elif "pancreas" in dataset_name.lower(): 112 | return [ 113 | COLOR_CONFIG["pancreas"], 114 | COLOR_CONFIG["default"], 115 | ] 116 | else: 117 | return [] 118 | 119 | 120 | def get_mask_opacity(dataset_name: str): 121 | if "word" in dataset_name.lower(): 122 | return [ 123 | # # WORD 124 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1 125 | ] 126 | elif "flare" in dataset_name.lower(): 127 | return [ 128 | # # FLARE22 129 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 130 | ] 131 | elif "abdomen" in dataset_name.lower(): 132 | return [ 133 | 1, 1, 1, 1, 1, 1 134 | ] 135 | elif "amos" in dataset_name.lower(): 136 | return [ 137 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0 138 | ] 139 | elif "luna" in dataset_name.lower(): 140 | return [ 141 | 0, 0, 1, 1, 0 142 | ] 143 | elif "spleen" in dataset_name.lower(): 144 | return [ 145 | 1 146 | ] 147 | elif "pancreas" in dataset_name.lower(): 148 | return [ 149 | 1, 0 150 | ] 151 | else: 152 | return [] 153 | 154 | 155 | SMOOTH_FACTOR = 400 156 | MAX_LABEL_LENGTH = 10 157 | 158 | # renderer 159 | COMPARE = True 160 | RENDERER_BG_COLOR = (1., 1., 1.) 161 | 162 | # outline config 163 | SHOW_OUTLINE = True 164 | OUTLINE_COLOR = (0, 1, 1) 165 | OUTLINE_OPACITY = 0.2 166 | 167 | # transform config 168 | ROTATE_X = 270 169 | ROTATE_Y = 180 170 | ROTATE_Z = 0 171 | SCALE = (100, 100, 100) 172 | """ 173 | You should set the rotate config for the best effect of visualization. 174 | """ 175 | # axes config 176 | SHOW_AXES = True 177 | TOTAL_LENGTH = (50, 50, 50) 178 | 179 | -------------------------------------------------------------------------------- /models/sam2/memory_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from typing import Optional 8 | 9 | import torch 10 | from torch import nn, Tensor 11 | 12 | from models.sam2.sam.transformer import RoPEAttention 13 | 14 | from models.sam2.sam2_utils import get_activation_fn, get_clones 15 | 16 | 17 | class MemoryAttentionLayer(nn.Module): 18 | 19 | def __init__( 20 | self, 21 | activation: str, 22 | cross_attention: nn.Module, 23 | d_model: int, 24 | dim_feedforward: int, 25 | dropout: float, 26 | pos_enc_at_attn: bool, 27 | pos_enc_at_cross_attn_keys: bool, 28 | pos_enc_at_cross_attn_queries: bool, 29 | self_attention: nn.Module, 30 | ): 31 | super().__init__() 32 | self.d_model = d_model 33 | self.dim_feedforward = dim_feedforward 34 | self.dropout_value = dropout 35 | self.self_attn = self_attention 36 | self.cross_attn_image = cross_attention 37 | 38 | # Implementation of Feedforward model 39 | self.linear1 = nn.Linear(d_model, dim_feedforward) 40 | self.dropout = nn.Dropout(dropout) 41 | self.linear2 = nn.Linear(dim_feedforward, d_model) 42 | 43 | self.norm1 = nn.LayerNorm(d_model) 44 | self.norm2 = nn.LayerNorm(d_model) 45 | self.norm3 = nn.LayerNorm(d_model) 46 | self.dropout1 = nn.Dropout(dropout) 47 | self.dropout2 = nn.Dropout(dropout) 48 | self.dropout3 = nn.Dropout(dropout) 49 | 50 | self.activation_str = activation 51 | self.activation = get_activation_fn(activation) 52 | 53 | # Where to add pos enc 54 | self.pos_enc_at_attn = pos_enc_at_attn 55 | self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries 56 | self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys 57 | 58 | def _forward_sa(self, tgt, query_pos): 59 | # Self-Attention 60 | tgt2 = self.norm1(tgt) 61 | q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2 62 | tgt2 = self.self_attn(q, k, v=tgt2) 63 | tgt = tgt + self.dropout1(tgt2) 64 | return tgt 65 | 66 | def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0): 67 | kwds = {} 68 | if num_k_exclude_rope > 0: 69 | assert isinstance(self.cross_attn_image, RoPEAttention) 70 | kwds = {"num_k_exclude_rope": num_k_exclude_rope} 71 | 72 | # Cross-Attention 73 | tgt2 = self.norm2(tgt) 74 | tgt2 = self.cross_attn_image( 75 | q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2, 76 | k=memory + pos if self.pos_enc_at_cross_attn_keys else memory, 77 | v=memory, 78 | **kwds, 79 | ) 80 | tgt = tgt + self.dropout2(tgt2) 81 | return tgt 82 | 83 | def forward( 84 | self, 85 | tgt, 86 | memory, 87 | pos: Optional[Tensor] = None, 88 | query_pos: Optional[Tensor] = None, 89 | num_k_exclude_rope: int = 0, 90 | ) -> torch.Tensor: 91 | 92 | # Self-Attn, Cross-Attn 93 | tgt = self._forward_sa(tgt, query_pos) 94 | tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope) 95 | # MLP 96 | tgt2 = self.norm3(tgt) 97 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 98 | tgt = tgt + self.dropout3(tgt2) 99 | return tgt 100 | 101 | 102 | class MemoryAttention(nn.Module): 103 | def __init__( 104 | self, 105 | d_model: int, 106 | pos_enc_at_input: bool, 107 | layer: nn.Module, 108 | num_layers: int, 109 | batch_first: bool = True, # Do layers expect batch first input? 
110 | ): 111 | super().__init__() 112 | self.d_model = d_model 113 | self.layers = get_clones(layer, num_layers) 114 | self.num_layers = num_layers 115 | self.norm = nn.LayerNorm(d_model) 116 | self.pos_enc_at_input = pos_enc_at_input 117 | self.batch_first = batch_first 118 | 119 | def forward( 120 | self, 121 | curr: torch.Tensor, # self-attention inputs 122 | memory: torch.Tensor, # cross-attention inputs 123 | curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs 124 | memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs 125 | num_obj_ptr_tokens: int = 0, # number of object pointer *tokens* 126 | ): 127 | if isinstance(curr, list): 128 | assert isinstance(curr_pos, list) 129 | assert len(curr) == len(curr_pos) == 1 130 | curr, curr_pos = ( 131 | curr[0], 132 | curr_pos[0], 133 | ) 134 | 135 | assert ( 136 | curr.shape[1] == memory.shape[1] 137 | ), "Batch size must be the same for curr and memory" 138 | 139 | output = curr 140 | if self.pos_enc_at_input and curr_pos is not None: 141 | output = output + 0.1 * curr_pos 142 | 143 | if self.batch_first: 144 | # Convert to batch first 145 | output = output.transpose(0, 1) 146 | curr_pos = curr_pos.transpose(0, 1) 147 | memory = memory.transpose(0, 1) 148 | memory_pos = memory_pos.transpose(0, 1) 149 | 150 | for layer in self.layers: 151 | kwds = {} 152 | if isinstance(layer.cross_attn_image, RoPEAttention): 153 | kwds = {"num_k_exclude_rope": num_obj_ptr_tokens} 154 | 155 | output = layer( 156 | tgt=output, 157 | memory=memory, 158 | pos=memory_pos, 159 | query_pos=curr_pos, 160 | **kwds, 161 | ) 162 | normed_output = self.norm(output) 163 | 164 | if self.batch_first: 165 | # Convert back to seq first 166 | normed_output = normed_output.transpose(0, 1) 167 | curr_pos = curr_pos.transpose(0, 1) 168 | 169 | return normed_output -------------------------------------------------------------------------------- /models/sam2/memory_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | from typing import Tuple 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | from models.sam2.sam2_utils import DropPath, get_clones, LayerNorm2d 15 | 16 | 17 | class MaskDownSampler(nn.Module): 18 | """ 19 | Progressively downsample a mask by total_stride, each time by stride. 20 | Note that LayerNorm is applied per *token*, like in ViT. 21 | 22 | With each downsample (by a factor stride**2), channel capacity increases by the same factor. 23 | In the end, we linearly project to embed_dim channels. 
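For example (worked numbers, not taken from any config in this repo): with stride=4 and total_stride=16 the loop below runs num_layers = 2 times, so the mask channels grow 1 -> 16 -> 256 before the final 1x1 projection to embed_dim.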
24 | """ 25 | 26 | def __init__( 27 | self, 28 | embed_dim=256, 29 | kernel_size=4, 30 | stride=4, 31 | padding=0, 32 | total_stride=16, 33 | activation=nn.GELU, 34 | ): 35 | super().__init__() 36 | num_layers = int(math.log2(total_stride) // math.log2(stride)) 37 | assert stride**num_layers == total_stride 38 | self.encoder = nn.Sequential() 39 | mask_in_chans, mask_out_chans = 1, 1 40 | for _ in range(num_layers): 41 | mask_out_chans = mask_in_chans * (stride**2) 42 | self.encoder.append( 43 | nn.Conv2d( 44 | mask_in_chans, 45 | mask_out_chans, 46 | kernel_size=kernel_size, 47 | stride=stride, 48 | padding=padding, 49 | ) 50 | ) 51 | self.encoder.append(LayerNorm2d(mask_out_chans)) 52 | self.encoder.append(activation()) 53 | mask_in_chans = mask_out_chans 54 | 55 | self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1)) 56 | 57 | def forward(self, x): 58 | return self.encoder(x) 59 | 60 | 61 | # Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt) 62 | class CXBlock(nn.Module): 63 | r"""ConvNeXt Block. There are two equivalent implementations: 64 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) 65 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back 66 | We use (2) as we find it slightly faster in PyTorch 67 | 68 | Args: 69 | dim (int): Number of input channels. 70 | drop_path (float): Stochastic depth rate. Default: 0.0 71 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 72 | """ 73 | 74 | def __init__( 75 | self, 76 | dim, 77 | kernel_size=7, 78 | padding=3, 79 | drop_path=0.0, 80 | layer_scale_init_value=1e-6, 81 | use_dwconv=True, 82 | ): 83 | super().__init__() 84 | self.dwconv = nn.Conv2d( 85 | dim, 86 | dim, 87 | kernel_size=kernel_size, 88 | padding=padding, 89 | groups=dim if use_dwconv else 1, 90 | ) # depthwise conv 91 | self.norm = LayerNorm2d(dim, eps=1e-6) 92 | self.pwconv1 = nn.Linear( 93 | dim, 4 * dim 94 | ) # pointwise/1x1 convs, implemented with linear layers 95 | self.act = nn.GELU() 96 | self.pwconv2 = nn.Linear(4 * dim, dim) 97 | self.gamma = ( 98 | nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) 99 | if layer_scale_init_value > 0 100 | else None 101 | ) 102 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 103 | 104 | def forward(self, x): 105 | input = x 106 | x = self.dwconv(x) 107 | x = self.norm(x) 108 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 109 | x = self.pwconv1(x) 110 | x = self.act(x) 111 | x = self.pwconv2(x) 112 | if self.gamma is not None: 113 | x = self.gamma * x 114 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 115 | 116 | x = input + self.drop_path(x) 117 | return x 118 | 119 | 120 | class Fuser(nn.Module): 121 | def __init__(self, layer, num_layers, dim=None, input_projection=False): 122 | super().__init__() 123 | self.proj = nn.Identity() 124 | self.layers = get_clones(layer, num_layers) 125 | 126 | if input_projection: 127 | assert dim is not None 128 | self.proj = nn.Conv2d(dim, dim, kernel_size=1) 129 | 130 | def forward(self, x): 131 | # normally x: (N, C, H, W) 132 | x = self.proj(x) 133 | for layer in self.layers: 134 | x = layer(x) 135 | return x 136 | 137 | 138 | class MemoryEncoder(nn.Module): 139 | def __init__( 140 | self, 141 | out_dim, 142 | mask_downsampler, 143 | fuser, 144 | position_encoding, 145 | in_dim=256, # in_dim of pix_feats 146 | ): 147 | super().__init__() 148 | 149 | 
self.mask_downsampler = mask_downsampler 150 | 151 | self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1) 152 | self.fuser = fuser 153 | self.position_encoding = position_encoding 154 | self.out_proj = nn.Identity() 155 | if out_dim != in_dim: 156 | self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1) 157 | 158 | def forward( 159 | self, 160 | pix_feat: torch.Tensor, 161 | masks: torch.Tensor, 162 | skip_mask_sigmoid: bool = False, 163 | ) -> Tuple[torch.Tensor, torch.Tensor]: 164 | ## Process masks 165 | # sigmoid, so that less domain shift from gt masks which are bool 166 | if not skip_mask_sigmoid: 167 | masks = F.sigmoid(masks) 168 | masks = self.mask_downsampler(masks) 169 | 170 | ## Fuse pix_feats and downsampled masks 171 | # in case the visual features are on CPU, cast them to CUDA 172 | pix_feat = pix_feat.to(masks.device) 173 | 174 | x = self.pix_feat_proj(pix_feat) 175 | x = x + masks 176 | x = self.fuser(x) 177 | x = self.out_proj(x) 178 | 179 | pos = self.position_encoding(x).to(x.dtype) 180 | 181 | return {"vision_features": x, "vision_pos_enc": [pos]} -------------------------------------------------------------------------------- /models/components/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 - 2021 MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | import math 13 | import warnings 14 | from typing import List 15 | 16 | from torch.optim.lr_scheduler import LambdaLR, _LRScheduler 17 | from torch.optim import Adam, Optimizer 18 | from torch.optim.lr_scheduler import _LRScheduler 19 | 20 | 21 | __all__ = ["LinearLR", "ExponentialLR"] 22 | 23 | 24 | class _LRSchedulerMONAI(_LRScheduler): 25 | """Base class for increasing the learning rate between two boundaries over a number 26 | of iterations""" 27 | 28 | def __init__(self, optimizer: Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1) -> None: 29 | """ 30 | Args: 31 | optimizer: wrapped optimizer. 32 | end_lr: the final learning rate. 33 | num_iter: the number of iterations over which the test occurs. 34 | last_epoch: the index of last epoch. 35 | Returns: 36 | None 37 | """ 38 | self.end_lr = end_lr 39 | self.num_iter = num_iter 40 | super(_LRSchedulerMONAI, self).__init__(optimizer, last_epoch) 41 | 42 | 43 | class LinearLR(_LRSchedulerMONAI): 44 | """Linearly increases the learning rate between two boundaries over a number of 45 | iterations. 46 | """ 47 | 48 | def get_lr(self): 49 | r = self.last_epoch / (self.num_iter - 1) 50 | return [base_lr + r * (self.end_lr - base_lr) for base_lr in self.base_lrs] 51 | 52 | 53 | class ExponentialLR(_LRSchedulerMONAI): 54 | """Exponentially increases the learning rate between two boundaries over a number of 55 | iterations. 
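For example (illustrative values): with base_lr=1e-6, end_lr=1e-2 and num_iter=100, iteration t uses lr = 1e-6 * (1e-2 / 1e-6) ** (t / 99), i.e. the learning rate sweeps the range geometrically rather than linearly.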
56 | """ 57 | 58 | def get_lr(self): 59 | r = self.last_epoch / (self.num_iter - 1) 60 | return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs] 61 | 62 | 63 | class WarmupCosineSchedule(LambdaLR): 64 | """Linear warmup and then cosine decay. 65 | Based on https://huggingface.co/ implementation. 66 | """ 67 | 68 | def __init__( 69 | self, optimizer: Optimizer, warmup_steps: int, t_total: int, cycles: float = 0.5, last_epoch: int = -1 70 | ) -> None: 71 | """ 72 | Args: 73 | optimizer: wrapped optimizer. 74 | warmup_steps: number of warmup iterations. 75 | t_total: total number of segment_anything_training iterations. 76 | cycles: cosine cycles parameter. 77 | last_epoch: the index of last epoch. 78 | Returns: 79 | None 80 | """ 81 | self.warmup_steps = warmup_steps 82 | self.t_total = t_total 83 | self.cycles = cycles 84 | super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch) 85 | 86 | def lr_lambda(self, step): 87 | if step < self.warmup_steps: 88 | return float(step) / float(max(1.0, self.warmup_steps)) 89 | progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps)) 90 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.cycles) * 2.0 * progress))) 91 | 92 | class LinearWarmupCosineAnnealingLR(_LRScheduler): 93 | 94 | def __init__( 95 | self, 96 | optimizer: Optimizer, 97 | warmup_epochs: int, 98 | max_epochs: int, 99 | warmup_start_lr: float = 0.0, 100 | eta_min: float = 0.0, 101 | last_epoch: int = -1, 102 | ) -> None: 103 | """ 104 | Args: 105 | optimizer (Optimizer): Wrapped optimizer. 106 | warmup_epochs (int): Maximum number of iterations for linear warmup 107 | max_epochs (int): Maximum number of iterations 108 | warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0. 109 | eta_min (float): Minimum learning rate. Default: 0. 110 | last_epoch (int): The index of last epoch. Default: -1. 111 | """ 112 | self.warmup_epochs = warmup_epochs 113 | self.max_epochs = max_epochs 114 | self.warmup_start_lr = warmup_start_lr 115 | self.eta_min = eta_min 116 | 117 | super(LinearWarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch) 118 | 119 | def get_lr(self) -> List[float]: 120 | """ 121 | Compute learning rate using chainable form of the scheduler 122 | """ 123 | if not self._get_lr_called_within_step: 124 | warnings.warn( 125 | "To get the last learning rate computed by the scheduler, " 126 | "please use `get_last_lr()`.", 127 | UserWarning, 128 | ) 129 | 130 | if self.last_epoch == 0: 131 | return [self.warmup_start_lr] * len(self.base_lrs) 132 | elif self.last_epoch < self.warmup_epochs: 133 | return [ 134 | group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) 135 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 136 | ] 137 | elif self.last_epoch == self.warmup_epochs: 138 | return self.base_lrs 139 | else: 140 | return [ 141 | self.eta_min + 0.5 * (base_lr - self.eta_min) * 142 | (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs))) 143 | for base_lr in self.base_lrs 144 | ] 145 | 146 | def _get_closed_form_lr(self) -> List[float]: 147 | """ 148 | Called when epoch is passed as a param to the `step` function of the scheduler. 
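For example (illustrative values): with warmup_start_lr=0, base_lr=1e-3 and warmup_epochs=10, epoch 5 gives lr = 0 + 5 * (1e-3 - 0) / 9 ≈ 5.6e-4; once last_epoch reaches warmup_epochs, the cosine branch below takes over.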
149 | """ 150 | if self.last_epoch < self.warmup_epochs: 151 | return [ 152 | self.warmup_start_lr + self.last_epoch * (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) 153 | for base_lr in self.base_lrs 154 | ] 155 | 156 | return [ 157 | self.eta_min + 0.5 * (base_lr - self.eta_min) * 158 | (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs))) 159 | for base_lr in self.base_lrs 160 | ] -------------------------------------------------------------------------------- /models/sam2/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | import os 9 | 10 | import torch 11 | from hydra import compose 12 | from hydra.utils import instantiate 13 | from omegaconf import OmegaConf 14 | 15 | import models.sam2 as sam2 16 | 17 | # Check if the user is running Python from the parent directory of the sam2 repo 18 | # (i.e. the directory where this repo is cloned into) -- this is not supported since 19 | # it could shadow the sam2 package and cause issues. 20 | if os.path.isdir(os.path.join(sam2.__path__[0], "sam2")): 21 | # If the user has "sam2/sam2" in their path, they are likely importing the repo itself 22 | # as "sam2" rather than importing the "sam2" python package (i.e. "sam2/sam2" directory). 23 | # This typically happens because the user is running Python from the parent directory 24 | # that contains the sam2 repo they cloned. 25 | raise RuntimeError( 26 | "You're likely running Python from the parent directory of the sam2 repository " 27 | "(i.e. the directory where https://github.com/facebookresearch/sam2 is cloned into). " 28 | "This is not supported since the `sam2` Python package could be shadowed by the " 29 | "repository name (the repository is also named `sam2` and contains the Python package " 30 | "in `sam2/sam2`). Please run Python from another directory (e.g. from the repo dir " 31 | "rather than its parent dir, or from your home directory) after installing SAM 2."
32 | ) 33 | 34 | 35 | HF_MODEL_ID_TO_FILENAMES = { 36 | "facebook/sam2-hiera-tiny": ( 37 | "configs/sam2/sam2_hiera_t.yaml", 38 | "sam2_hiera_tiny.pt", 39 | ), 40 | "facebook/sam2-hiera-small": ( 41 | "configs/sam2/sam2_hiera_s.yaml", 42 | "sam2_hiera_small.pt", 43 | ), 44 | "facebook/sam2-hiera-base-plus": ( 45 | "configs/sam2/sam2_hiera_b+.yaml", 46 | "sam2_hiera_base_plus.pt", 47 | ), 48 | "facebook/sam2-hiera-large": ( 49 | "configs/sam2/sam2_hiera_l.yaml", 50 | "sam2_hiera_large.pt", 51 | ), 52 | "facebook/sam2.1-hiera-tiny": ( 53 | "configs/sam2.1/sam2.1_hiera_t.yaml", 54 | "sam2.1_hiera_tiny.pt", 55 | ), 56 | "facebook/sam2.1-hiera-small": ( 57 | "configs/sam2.1/sam2.1_hiera_s.yaml", 58 | "sam2.1_hiera_small.pt", 59 | ), 60 | "facebook/sam2.1-hiera-base-plus": ( 61 | "configs/sam2.1/sam2.1_hiera_b+.yaml", 62 | "sam2.1_hiera_base_plus.pt", 63 | ), 64 | "facebook/sam2.1-hiera-large": ( 65 | "configs/sam2.1/sam2.1_hiera_l.yaml", 66 | "sam2.1_hiera_large.pt", 67 | ), 68 | } 69 | 70 | 71 | def build_sam2( 72 | config_file, 73 | ckpt_path=None, 74 | device="cuda", 75 | mode="eval", 76 | hydra_overrides_extra=[], 77 | apply_postprocessing=True, 78 | **kwargs, 79 | ): 80 | 81 | if apply_postprocessing: 82 | hydra_overrides_extra = hydra_overrides_extra.copy() 83 | hydra_overrides_extra += [ 84 | # dynamically fall back to multi-mask if the single mask is not stable 85 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", 86 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", 87 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", 88 | ] 89 | # Read config and init model 90 | cfg = compose(config_name=config_file, overrides=hydra_overrides_extra) 91 | OmegaConf.resolve(cfg) 92 | model = instantiate(cfg.model, _recursive_=True) 93 | _load_checkpoint(model, ckpt_path) 94 | model = model.to(device) 95 | if mode == "eval": 96 | model.eval() 97 | return model 98 | 99 | 100 | def build_sam2_video_predictor( 101 | config_file, 102 | ckpt_path=None, 103 | device="cuda", 104 | mode="eval", 105 | hydra_overrides_extra=[], 106 | apply_postprocessing=True, 107 | vos_optimized=False, 108 | **kwargs, 109 | ): 110 | hydra_overrides = [ 111 | "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor", 112 | ] 113 | if vos_optimized: 114 | hydra_overrides = [ 115 | "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictorVOS", 116 | "++model.compile_image_encoder=True", # Let sam2_base handle this 117 | ] 118 | 119 | if apply_postprocessing: 120 | hydra_overrides_extra = hydra_overrides_extra.copy() 121 | hydra_overrides_extra += [ 122 | # dynamically fall back to multi-mask if the single mask is not stable 123 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", 124 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", 125 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", 126 | # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking 127 | "++model.binarize_mask_from_pts_for_mem_enc=true", 128 | # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution) 129 | "++model.fill_hole_area=8", 130 | ] 131 | hydra_overrides.extend(hydra_overrides_extra) 132 | 133 | # Read config and init model 134 | cfg = compose(config_name=config_file, overrides=hydra_overrides) 135 | OmegaConf.resolve(cfg) 
136 | model = instantiate(cfg.model, _recursive_=True) 137 | _load_checkpoint(model, ckpt_path) 138 | model = model.to(device) 139 | if mode == "eval": 140 | model.eval() 141 | return model 142 | 143 | 144 | def _hf_download(model_id): 145 | from huggingface_hub import hf_hub_download 146 | 147 | config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id] 148 | ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name) 149 | return config_name, ckpt_path 150 | 151 | 152 | def build_sam2_hf(model_id, **kwargs): 153 | config_name, ckpt_path = _hf_download(model_id) 154 | return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs) 155 | 156 | 157 | def build_sam2_video_predictor_hf(model_id, **kwargs): 158 | config_name, ckpt_path = _hf_download(model_id) 159 | return build_sam2_video_predictor( 160 | config_file=config_name, ckpt_path=ckpt_path, **kwargs 161 | ) 162 | 163 | 164 | def _load_checkpoint(model, ckpt_path): 165 | if ckpt_path is not None: 166 | sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"] 167 | missing_keys, unexpected_keys = model.load_state_dict(sd) 168 | if missing_keys: 169 | logging.error(missing_keys) 170 | raise RuntimeError() 171 | if unexpected_keys: 172 | logging.error(unexpected_keys) 173 | raise RuntimeError() 174 | logging.info("Loaded checkpoint successfully") -------------------------------------------------------------------------------- /datasets/data_process.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import multiprocessing 4 | import argparse 5 | from scipy import sparse 6 | from sklearn.model_selection import train_test_split 7 | import json 8 | 9 | from monai.transforms import ( 10 | AddChanneld, 11 | Compose, 12 | LoadImaged, 13 | Orientationd, 14 | ) 15 | 16 | def set_parse(): 17 | # %% set up parser 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--category", default=["liver", "spleen", "stomach", "gallbladder", "esophagus", "pancreas", "duodenum", "aorta", "bladder", "inferior vena cava", "left kidney", "right kidney", "left adrenal gland", "right adrenal gland", "left femur", "right femur", "left lung", "right lung"], type=list) 20 | parser.add_argument("--image_dir", type=str, required=True) 21 | parser.add_argument("--label_dir", type=str, required=True) 22 | parser.add_argument("--dataset_code", type=str, required=True) 23 | parser.add_argument("--save_root", type=str, required=True) 24 | parser.add_argument("--test_ratio", type=float, required=True) 25 | 26 | args = parser.parse_args() 27 | return args 28 | 29 | args = set_parse() 30 | 31 | # get ct dir 32 | image_list_all = [item for item in sorted(os.listdir(args.image_dir))] 33 | label_list_all = [item for item in sorted(os.listdir(args.label_dir))] 34 | assert len(image_list_all) == len(label_list_all) 35 | print('dataset size ', len(image_list_all)) 36 | 37 | # build dataset 38 | data_path_list_all = [] 39 | for idx in range(len(image_list_all)): 40 | img_path = os.path.join(args.image_dir, image_list_all[idx]) 41 | label_path = os.path.join(args.label_dir, label_list_all[idx]) 42 | name = image_list_all[idx].split('.')[0] 43 | info = (idx, name, img_path, label_path) 44 | data_path_list_all.append(info) 45 | 46 | img_loader = Compose( 47 | [ 48 | LoadImaged(keys=['image', 'label']), 49 | AddChanneld(keys=['image', 'label']), 50 | Orientationd(keys=['image', 'label'], axcodes="RAS"), 51 | ] 52 | ) 53 | 54 | # save 55 | save_path =
os.path.join(args.save_root, args.dataset_code) 56 | ct_save_path = os.path.join(save_path, 'ct') 57 | gt_save_path = os.path.join(save_path, 'gt') 58 | if not os.path.exists(ct_save_path): 59 | os.makedirs(ct_save_path) 60 | if not os.path.exists(gt_save_path): 61 | os.makedirs(gt_save_path) 62 | 63 | # exist file: 64 | exist_file_list = os.listdir(ct_save_path) 65 | print('exist_file_list ', exist_file_list) 66 | 67 | def normalize(ct_narray): 68 | ct_voxel_ndarray = ct_narray.copy() 69 | ct_voxel_ndarray = ct_voxel_ndarray.flatten() 70 | # for all data 71 | thred = np.mean(ct_voxel_ndarray) 72 | voxel_filtered = ct_voxel_ndarray[(ct_voxel_ndarray > thred)] 73 | # for foreground data 74 | upper_bound = np.percentile(voxel_filtered, 99.95) 75 | lower_bound = np.percentile(voxel_filtered, 00.05) 76 | mean = np.mean(voxel_filtered) 77 | std = np.std(voxel_filtered) 78 | ### transform ### 79 | ct_narray = np.clip(ct_narray, lower_bound, upper_bound) 80 | ct_narray = (ct_narray - mean) / max(std, 1e-8) 81 | return ct_narray 82 | 83 | def run(info): 84 | idx, file_name, case_path, label_path = info 85 | item = {} 86 | if file_name + '.npy' in exist_file_list: 87 | print(file_name + '.npy exist, skip') 88 | return 89 | print('process ', idx, '---' ,file_name) 90 | # generate ct_voxel_ndarray 91 | item_load = { 92 | 'image' : case_path, 93 | 'label' : label_path, 94 | } 95 | item_load = img_loader(item_load) 96 | ct_voxel_ndarray = item_load['image'] 97 | gt_voxel_ndarray = item_load['label'] 98 | 99 | ct_shape = ct_voxel_ndarray.shape 100 | item['image'] = ct_voxel_ndarray 101 | 102 | # generate gt_voxel_ndarray 103 | gt_voxel_ndarray = np.array(gt_voxel_ndarray).squeeze() 104 | present_categories = np.unique(gt_voxel_ndarray) 105 | gt_masks = [] 106 | for cls_idx in range(len(args.category)): 107 | cls = cls_idx + 1 108 | if cls not in present_categories: 109 | gt_voxel_ndarray_category = np.zeros(ct_shape) 110 | gt_masks.append(gt_voxel_ndarray_category) 111 | print('case {} ==> zero category '.format(idx) + args.category[cls_idx]) 112 | print(gt_voxel_ndarray_category.shape) 113 | else: 114 | gt_voxel_ndarray_category = gt_voxel_ndarray.copy() 115 | gt_voxel_ndarray_category[gt_voxel_ndarray != cls] = 0 116 | gt_voxel_ndarray_category[gt_voxel_ndarray == cls] = 1 117 | gt_masks.append(gt_voxel_ndarray_category) 118 | gt_voxel_ndarray = np.stack(gt_masks, axis=0) 119 | 120 | assert gt_voxel_ndarray.shape[0] == len(args.category), str(gt_voxel_ndarray.shape[0]) 121 | assert gt_voxel_ndarray.shape[1:] == ct_voxel_ndarray.shape[1:] 122 | item['label'] = gt_voxel_ndarray.astype(np.int32) 123 | print(idx, ' load done!') 124 | 125 | ############################# 126 | item['image'] = normalize(item['image']) 127 | print(idx, ' transform done') 128 | 129 | ############################ 130 | print(file_name + ' ct gt <--> ', item['image'].shape, item['label'].shape) 131 | np.save(os.path.join(ct_save_path, file_name + '.npy'), item['image']) 132 | allmatrix_sp=sparse.csr_matrix(item['label'].reshape(item['label'].shape[0], -1)) 133 | sparse.save_npz(os.path.join(gt_save_path, file_name + '.' 
+ str(item['label'].shape)), allmatrix_sp) 134 | print(file_name + ' save done!') 135 | 136 | def generate_dataset_json(root_dir, output_file, test_ratio=0.2): 137 | ct_dir = os.path.join(root_dir, 'ct') 138 | gt_dir = os.path.join(root_dir, 'gt') 139 | ct_paths = sorted([os.path.join(ct_dir, f) for f in sorted(os.listdir(ct_dir))]) 140 | gt_paths = sorted([os.path.join(gt_dir, f) for f in sorted(os.listdir(gt_dir))]) 141 | 142 | data = list(zip(ct_paths, gt_paths)) 143 | train_data, val_data = train_test_split(data, test_size=test_ratio) 144 | labels = {} 145 | labels['0'] = 'background' 146 | for idx in range(len(args.category)): 147 | label_name = args.category[idx] 148 | label_id = idx + 1 149 | labels[str(label_id)] = label_name 150 | dataset = { 151 | 'name': f'{args.dataset_code} Dataset', 152 | 'description': f'{args.dataset_code} Dataset', 153 | 'tensorImageSize': '4D', 154 | 'modality': { 155 | '0': 'CT', 156 | }, 157 | 'labels': labels, 158 | 'numTraining': len(train_data), 159 | 'numTest': len(val_data), 160 | 'segment_anything_training': [{'image': ct_path, 'label': gt_path} for ct_path, gt_path in train_data], 161 | 'validation': [{'image': ct_path, 'label': gt_path} for ct_path, gt_path in val_data] 162 | } 163 | with open(output_file, 'w') as f: 164 | print(f'{output_file} dump') 165 | json.dump(dataset, f, indent=2) 166 | 167 | if __name__ == "__main__": 168 | with multiprocessing.Pool(processes=10) as pool: 169 | pool.map(run, data_path_list_all) 170 | print('Process Finished!') 171 | 172 | generate_dataset_json(root_dir=save_path, 173 | output_file=os.path.join(save_path, f'{args.dataset_code}.json'), 174 | test_ratio=args.test_ratio) 175 | print('Json Split Done!') 176 | -------------------------------------------------------------------------------- /env_config.yml: -------------------------------------------------------------------------------- 1 | name: CRISP_SAM2 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=5.1=1_gnu 8 | - bzip2=1.0.8=h5eee18b_6 9 | - ca-certificates=2024.3.11=h06a4308_0 10 | - debugpy=1.6.7=py312h6a678d5_0 11 | - importlib_metadata=8.2.0=hd8ed1ab_0 12 | - ld_impl_linux-64=2.38=h1181459_1 13 | - libffi=3.4.4=h6a678d5_0 14 | - libgcc-ng=11.2.0=h1234567_1 15 | - libgomp=11.2.0=h1234567_1 16 | - libstdcxx-ng=11.2.0=h1234567_1 17 | - ncurses=6.4=h6a678d5_0 18 | - openssl=3.0.14=h5eee18b_0 19 | - packaging=24.1=pyhd8ed1ab_0 20 | - pip=24.0=py312h06a4308_0 21 | - python=3.12.4=h5148396_1 22 | - readline=8.2=h5eee18b_0 23 | - setuptools=69.5.1=py312h06a4308_0 24 | - sqlite=3.45.3=h5eee18b_0 25 | - tk=8.6.14=h39e8969_0 26 | - wheel=0.43.0=py312h06a4308_0 27 | - xz=5.4.6=h5eee18b_1 28 | - zlib=1.2.13=h5eee18b_0 29 | - pip: 30 | - absl-py==2.1.0 31 | - accelerate==0.31.0 32 | - aiohttp==3.9.5 33 | - aiosignal==1.3.1 34 | - albumentations=0.3.2 35 | - alembic==1.13.1 36 | - aniso8601==9.0.1 37 | - annotated-types==0.7.0 38 | - appdirs==1.4.4 39 | - asciitree==0.3.3 40 | - astor==0.8.1 41 | - async-timeout==4.0.3 42 | - attrs==23.2.0 43 | - Authlib==1.3.1 44 | - bert-score==0.3.13 45 | - blinker==1.7.0 46 | - blis==0.7.11 47 | - catalogue==2.0.10 48 | - certifi==2022.12.7 49 | - cffi==1.16.0 50 | - charset-normalizer==2.1.1 51 | - clearml==1.14.5rc0 52 | - click==8.1.7 53 | - cloudpathlib==0.18.1 54 | - cloudpickle==3.0.0 55 | - cmake==3.28.4 56 | - colorama==0.4.6 57 | - coloredlogs==15.0.1 58 | - colorlog==6.8.2 59 | - confection==0.1.5 60 | - connected-components-3d==3.14.0 61 | - 
contextlib2==21.6.0 62 | - contourpy==1.2.0 63 | - crcmod==1.7 64 | - cryptography==42.0.8 65 | - cucim==23.10.0 66 | - cycler==0.12.1 67 | - cymem==2.0.8 68 | - datasets==2.19.1 69 | - Deprecated==1.2.14 70 | - dill==0.3.8 71 | - django-s3-file-field-client==1.0.1 72 | - docker==7.0.0 73 | - docutils==0.16.0 74 | - dotmap==1.3.30 75 | - dynamic-network-architectures==0.3.1 76 | - edt==2.4.0 77 | - einops==0.7.0 78 | - entrypoints==0.4 79 | - environs==11.0.0 80 | - eval_type_backport==0.2.0 81 | - fasteners==0.19 82 | - fastremap==1.14.1 83 | - filelock==3.11.0 84 | - fire==0.6.0 85 | - flake8==5.0.4 86 | - flash-attn==2.5.6 87 | - Flask==3.0.2 88 | - flatbuffers==24.3.7 89 | - fonttools==4.50.0 90 | - frozenlist==1.4.1 91 | - fsspec==2024.3.1 92 | - ftfy==6.2.3 93 | - furl==2.1.3 94 | - fvcore==0.1.5.post20221221 95 | - gdown==5.1.0 96 | - girder-cli-oauth-client==0.4.0 97 | - gitdb==4.0.11 98 | - GitPython==3.1.42 99 | - graphene==3.3 100 | - graphql-core==3.2.3 101 | - graphql-relay==3.2.0 102 | - greenlet==3.0.3 103 | - grpcio==1.62.1 104 | - gunicorn==21.2.0 105 | - h5py==3.10.0 106 | - hf_transfer==0.1.8 107 | - httpx==0.27.0 108 | - huggingface-hub==0.23.4 109 | - humanfriendly==10.0 110 | - humanize==4.9.0 111 | - idna==3.4 112 | - imagecodecs==2024.1.1 113 | - imageio==2.34.0 114 | - importlib_metadata==7.1.0 115 | - importlib_resources==6.4.0 116 | - iopath==0.1.10 117 | - isic-cli==0.0.0 118 | - isic-metadata==0.0.0 119 | - itk==5.3.0 120 | - itk-core==5.3.0 121 | - itk-filtering==5.3.0 122 | - itk-io==5.3.0 123 | - itk-numerics==5.3.0 124 | - itk-registration==5.3.0 125 | - itk-segmentation==5.3.0 126 | - itsdangerous==2.1.2 127 | - jinja2==3.1.4 128 | - joblib==1.4.2 129 | - json5==0.9.25 130 | - jsonpointer==3.0.0 131 | - jsonschema==4.23.0 132 | - jsonschema-specifications==2023.12.1 133 | - kiwisolver==1.4.5 134 | - lazy-loader==0.4 135 | - linecache2==1.0.0 136 | - markupsafe==2.1.5 137 | - matplotlib==3.9.1 138 | - mistune==3.0.2 139 | - mmcv=2.0.0rc4 140 | - mmengine=0.7.1 141 | - monai==1.3.2 142 | - mpmath==1.3.0 143 | - nbclient==0.10.0 144 | - nbconvert==7.16.4 145 | - nbformat==5.10.4 146 | - networkx==3.3 147 | - nibabel==5.2.1 148 | - nnunetv2==2.5.1 149 | - notebook-shim==0.2.4 150 | - numpy==2.0.1 151 | - nvidia-cublas-cu12==12.1.3.1 152 | - nvidia-cuda-cupti-cu12==12.1.105 153 | - nvidia-cuda-nvrtc-cu12==12.1.105 154 | - nvidia-cuda-runtime-cu12==12.1.105 155 | - nvidia-cudnn-cu12==9.1.0.70 156 | - nvidia-cufft-cu12==11.0.2.54 157 | - nvidia-curand-cu12==10.3.2.106 158 | - nvidia-cusolver-cu12==11.4.5.107 159 | - nvidia-cusparse-cu12==12.1.0.106 160 | - nvidia-nccl-cu12==2.20.5 161 | - nvidia-nvjitlink-cu12==12.5.82 162 | - nvidia-nvtx-cu12==12.1.105 163 | - omegaconf==2.3.0 164 | - opencv-python==4.10.0.84 165 | - opentelemetry-api==1.21.0 166 | - opentelemetry-exporter-otlp-proto-common==1.21.0 167 | - opentelemetry-exporter-otlp-proto-http==1.21.0 168 | - opentelemetry-proto==1.21.0 169 | - opentelemetry-sdk==1.21.0 170 | - opentelemetry-semantic-conventions==0.42b0 171 | - overrides==7.7.0 172 | - pandas==2.2.1 173 | - pandocfilters==1.5.1 174 | - pillow==10.4.0 175 | - portalocker==2.10.1 176 | - prometheus-client==0.20.0 177 | - protobuf==4.25.4 178 | - psutil==5.9.8 179 | - pycparser==2.22 180 | - pydantic==2.7.4 181 | - pydantic_core==2.18.4 182 | - pydicom==2.4.4 183 | - pyparsing==3.1.2 184 | - python-dateutil==2.9.0.post0 185 | - python-gdcm==3.0.24.1 186 | - graphviz==0.20.3 187 | - python-json-logger==2.0.7 188 | - pytorch-ignite==0.4.11 189 | - 
pytorch-lightning==2.2.5 190 | - pytz==2024.1 191 | - pyyaml==6.0.1 192 | - referencing==0.35.1 193 | - requests==2.32.3 194 | - rfc3339-validator==0.1.4 195 | - rfc3986-validator==0.1.1 196 | - rpds-py==0.19.1 197 | - scipy==1.12.0 198 | - simpleitk==2.3.1 199 | - sniffio==1.3.1 200 | - soupsieve==2.5 201 | - sphinx==4.0.2 202 | - sphinx-copybutton 203 | - sphinx_markdown_tables 204 | - sphinx_rtd_theme==0.5.2 205 | - sympy==1.13.1 206 | - synapseclient==4.4.0 207 | - tensorboard==2.16.2 208 | - tensorboard-data-server==0.7.2 209 | - tensorboardX==2.6.2.2 210 | - termcolor==2.4.0 211 | - terminado==0.18.1 212 | - threadpoolctl==3.5.0 213 | - tifffile==2024.7.24 214 | - tinycss2==1.3.0 215 | - torch==2.4.0 216 | - torchaudio==2.4.0 217 | - torchsummary==1.5.1 218 | - torchvision==0.19.0 219 | - torchmetrics==1.4.0.post0 220 | - tqdm==4.66.4 221 | - traceback2==1.4.0 222 | - triton==3.0.0 223 | - types-python-dateutil==2.9.0.20240316 224 | - tzdata==2024.1 225 | - unittest2==1.1.0 226 | - uri-template==1.3.0 227 | - urllib3==1.26.19 228 | - webcolors==24.6.0 229 | - webencodings==0.5.1 230 | - websocket-client==1.8.0 231 | - widgetsnbextension==4.0.11 232 | - wrapt==1.16.0 233 | - yacs==0.1.8 -------------------------------------------------------------------------------- /datasets/train_data_prepare.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import multiprocessing 4 | import argparse 5 | from scipy import sparse 6 | from sklearn.model_selection import train_test_split 7 | import json 8 | join = os.path.join 9 | 10 | from monai.transforms import ( 11 | AddChanneld, 12 | Compose, 13 | LoadImaged, 14 | Orientationd, 15 | ) 16 | 17 | def set_parse(): 18 | # %% set up parser 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("--category", default=["liver", "spleen", "stomach", "gallbladder", "esophagus", "pancreas", "duodenum", "aorta", "bladder", "inferior vena cava", "left kidney", "right kidney", "left adrenal gland", "right adrenal gland", "left femur", "right femur", "left lung", "right lung"], type=list) 21 | parser.add_argument("--image_dir", type=str, required=True) 22 | parser.add_argument("--label_dir", type=str, required=True) 23 | parser.add_argument("--dataset_code", type=str, required=True) 24 | parser.add_argument("--save_root", type=str, required=True) 25 | parser.add_argument("--test_ratio", type=float, required=True) 26 | 27 | args = parser.parse_args() 28 | return args 29 | 30 | args = set_parse() 31 | 32 | # get ct> dir 33 | image_list_all = [item for item in sorted(os.listdir(args.image_dir))] 34 | label_list_all = [item for item in sorted(os.listdir(args.label_dir))] 35 | assert len(image_list_all) == len(label_list_all) 36 | print('dataset size ', len(image_list_all)) 37 | 38 | # build dataset 39 | data_path_list_all = [] 40 | for idx in range(len(image_list_all)): 41 | img_path = join(args.image_dir, image_list_all[idx]) 42 | label_path = join(args.label_dir, label_list_all[idx]) 43 | name = image_list_all[idx].split('.')[0] 44 | info = (idx, name, img_path, label_path) 45 | data_path_list_all.append(info) 46 | 47 | img_loader = Compose( 48 | [ 49 | LoadImaged(keys=['image', 'label']), 50 | AddChanneld(keys=['image', 'label']), 51 | # Orientationd(keys=['image', 'label'], axcodes="RAS"), 52 | ] 53 | ) 54 | 55 | # save 56 | save_path = join(args.save_root, args.dataset_code) 57 | os.makedirs(save_path, exist_ok=True) 58 | 59 | # ct_save_path = join(save_path, 'ct') 60 | # 
gt_save_path = join(save_path, 'gt') 61 | # if not os.path.exists(ct_save_path): 62 | # os.makedirs(ct_save_path) 63 | # if not os.path.exists(gt_save_path): 64 | # os.makedirs(gt_save_path) 65 | 66 | # exist file: 67 | exist_file_list = os.listdir(save_path) 68 | print('exist_file_list ', exist_file_list) 69 | 70 | def normalize(ct_narray): 71 | ct_voxel_ndarray = ct_narray.copy() 72 | ct_voxel_ndarray = ct_voxel_ndarray.flatten() 73 | # for all data 74 | thred = np.mean(ct_voxel_ndarray) 75 | voxel_filtered = ct_voxel_ndarray[(ct_voxel_ndarray > thred)] 76 | # for foreground data 77 | upper_bound = np.percentile(voxel_filtered, 99.95) 78 | lower_bound = np.percentile(voxel_filtered, 00.05) 79 | mean = np.mean(voxel_filtered) 80 | std = np.std(voxel_filtered) 81 | ### transform ### 82 | ct_narray = np.clip(ct_narray, lower_bound, upper_bound) 83 | ct_narray = (ct_narray - mean) / max(std, 1e-8) 84 | return ct_narray 85 | 86 | def run(info): 87 | idx, file_name, case_path, label_path = info 88 | 89 | item = {} 90 | if file_name in exist_file_list: 91 | print(file_name + ' exist, skip') 92 | return 93 | print('process ', idx, '---' ,file_name) 94 | # generate ct_voxel_ndarray 95 | item_load = { 96 | 'image' : case_path, 97 | 'label' : label_path, 98 | } 99 | item_load = img_loader(item_load) 100 | ct_voxel_ndarray = item_load['image'] 101 | gt_voxel_ndarray = item_load['label'] 102 | 103 | ct_shape = ct_voxel_ndarray.shape 104 | item['image'] = ct_voxel_ndarray 105 | 106 | # generate gt_voxel_ndarray 107 | gt_voxel_ndarray = np.array(gt_voxel_ndarray).squeeze() 108 | present_categories = np.unique(gt_voxel_ndarray) 109 | gt_masks = [] 110 | for cls_idx in range(len(args.category)): 111 | cls = cls_idx + 1 112 | if cls not in present_categories: 113 | gt_voxel_ndarray_category = np.zeros(ct_shape) 114 | gt_masks.append(gt_voxel_ndarray_category) 115 | print('case {} ==> zero category '.format(idx) + args.category[cls_idx]) 116 | print(gt_voxel_ndarray_category.shape) 117 | else: 118 | gt_voxel_ndarray_category = gt_voxel_ndarray.copy() 119 | gt_voxel_ndarray_category[gt_voxel_ndarray != cls] = 0 120 | gt_voxel_ndarray_category[gt_voxel_ndarray == cls] = 1 121 | gt_masks.append(gt_voxel_ndarray_category) 122 | gt_voxel_ndarray = np.stack(gt_masks, axis=0) 123 | 124 | assert gt_voxel_ndarray.shape[0] == len(args.category), str(gt_voxel_ndarray.shape[0]) 125 | assert gt_voxel_ndarray.shape[1:] == ct_voxel_ndarray.shape[1:] 126 | item['label'] = gt_voxel_ndarray.astype(np.int32) 127 | print(idx, ' load done!') 128 | 129 | ############################# 130 | item['image'] = normalize(item['image']) 131 | print(idx, ' transform done') 132 | 133 | ############################ 134 | print(file_name + ' ct gt <--> ', item['image'].shape, item['label'].shape) 135 | case_path = join(save_path, file_name) 136 | os.makedirs(case_path, exist_ok=True) 137 | 138 | np.save(join(case_path, 'image.npy'), item['image']) 139 | allmatrix_sp=sparse.csr_matrix(item['label'].reshape(item['label'].shape[0], -1)) 140 | sparse.save_npz(join(case_path, 'mask_' + str(item['label'].shape)), allmatrix_sp) 141 | print(file_name + ' save done!') 142 | 143 | def generate_dataset_json(root_dir, output_file, test_ratio=0.2): 144 | cases = os.listdir(root_dir) 145 | ct_paths, gt_paths = [], [] 146 | for case_name in cases: 147 | case_files = sorted(os.listdir(join(root_dir, case_name))) 148 | ct_path = join(root_dir, case_name, case_files[0]) 149 | gt_path = join(root_dir, case_name, case_files[1]) 150 | ct_paths.append(ct_path) 
151 | gt_paths.append(gt_path) 152 | 153 | data = list(zip(ct_paths, gt_paths)) 154 | train_data, val_data = train_test_split(data, test_size=test_ratio) 155 | labels = {} 156 | labels['0'] = 'background' 157 | for idx in range(len(args.category)): 158 | label_name = args.category[idx] 159 | label_id = idx + 1 160 | labels[str(label_id)] = label_name 161 | dataset = { 162 | 'name': f'{args.dataset_code} Dataset', 163 | 'description': f'{args.dataset_code} Dataset', 164 | 'tensorImageSize': '4D', 165 | 'modality': { 166 | '0': 'CT', 167 | }, 168 | 'labels': labels, 169 | 'numTrain': len(train_data), 170 | 'numTest': len(val_data), 171 | 'train': [{'image': ct_path, 'label': gt_path} for ct_path, gt_path in train_data], 172 | 'test': [{'image': ct_path, 'label': gt_path} for ct_path, gt_path in val_data] 173 | } 174 | with open(output_file, 'w') as f: 175 | print(f'{output_file} dump') 176 | json.dump(dataset, f, indent=2) 177 | 178 | if __name__ == "__main__": 179 | with multiprocessing.Pool(processes=64) as pool: 180 | pool.map(run, data_path_list_all) 181 | print('Process Finished!') 182 | 183 | generate_dataset_json(root_dir=save_path, 184 | output_file=join(save_path, f'{args.dataset_code}.json'), 185 | test_ratio=args.test_ratio) 186 | print('Json Split Done!') 187 | -------------------------------------------------------------------------------- /models/sam2/sam/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Optional, Tuple, Type 8 | 9 | import torch 10 | from torch import nn 11 | 12 | from models.sam2.position_encoding import PositionEmbeddingRandom 13 | 14 | from models.sam2.sam2_utils import LayerNorm2d 15 | 16 | 17 | class PromptEncoder(nn.Module): 18 | def __init__( 19 | self, 20 | embed_dim: int, 21 | image_embedding_size: Tuple[int, int], 22 | input_image_size: Tuple[int, int], 23 | mask_in_chans: int, 24 | activation: Type[nn.Module] = nn.GELU, 25 | ) -> None: 26 | """ 27 | Encodes prompts for input to SAM's mask decoder. 28 | 29 | Arguments: 30 | embed_dim (int): The prompts' embedding dimension 31 | image_embedding_size (tuple(int, int)): The spatial size of the 32 | image embedding, as (H, W). 33 | input_image_size (int): The padded size of the image as input 34 | to the image encoder, as (H, W). 35 | mask_in_chans (int): The number of hidden channels used for 36 | encoding input masks. 37 | activation (nn.Module): The activation to use when encoding 38 | input masks. 
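Example (illustrative SAM-style values, not taken from this repo's configs): PromptEncoder(embed_dim=256, image_embedding_size=(64, 64), input_image_size=(1024, 1024), mask_in_chans=16).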
39 | """ 40 | super().__init__() 41 | self.embed_dim = embed_dim 42 | self.input_image_size = input_image_size 43 | self.image_embedding_size = image_embedding_size 44 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 45 | 46 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 47 | point_embeddings = [ 48 | nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings) 49 | ] 50 | self.point_embeddings = nn.ModuleList(point_embeddings) 51 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 52 | 53 | self.mask_input_size = ( 54 | 4 * image_embedding_size[0], 55 | 4 * image_embedding_size[1], 56 | ) 57 | self.mask_downscaling = nn.Sequential( 58 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 59 | LayerNorm2d(mask_in_chans // 4), 60 | activation(), 61 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 62 | LayerNorm2d(mask_in_chans), 63 | activation(), 64 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 65 | ) 66 | self.no_mask_embed = nn.Embedding(1, embed_dim) 67 | 68 | def get_dense_pe(self) -> torch.Tensor: 69 | """ 70 | Returns the positional encoding used to encode point prompts, 71 | applied to a dense set of points the shape of the image encoding. 72 | 73 | Returns: 74 | torch.Tensor: Positional encoding with shape 75 | 1x(embed_dim)x(embedding_h)x(embedding_w) 76 | """ 77 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 78 | 79 | def _embed_points( 80 | self, 81 | points: torch.Tensor, 82 | labels: torch.Tensor, 83 | pad: bool, 84 | ) -> torch.Tensor: 85 | """Embeds point prompts.""" 86 | points = points + 0.5 # Shift to center of pixel 87 | if pad: 88 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 89 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 90 | points = torch.cat([points, padding_point], dim=1) 91 | labels = torch.cat([labels, padding_label], dim=1) 92 | point_embedding = self.pe_layer.forward_with_coords( 93 | points, self.input_image_size 94 | ) 95 | 96 | point_embedding = torch.where( 97 | (labels == -1).unsqueeze(-1), 98 | torch.zeros_like(point_embedding) + self.not_a_point_embed.weight, 99 | point_embedding, 100 | ) 101 | point_embedding = torch.where( 102 | (labels == 0).unsqueeze(-1), 103 | point_embedding + self.point_embeddings[0].weight, 104 | point_embedding, 105 | ) 106 | point_embedding = torch.where( 107 | (labels == 1).unsqueeze(-1), 108 | point_embedding + self.point_embeddings[1].weight, 109 | point_embedding, 110 | ) 111 | point_embedding = torch.where( 112 | (labels == 2).unsqueeze(-1), 113 | point_embedding + self.point_embeddings[2].weight, 114 | point_embedding, 115 | ) 116 | point_embedding = torch.where( 117 | (labels == 3).unsqueeze(-1), 118 | point_embedding + self.point_embeddings[3].weight, 119 | point_embedding, 120 | ) 121 | return point_embedding 122 | 123 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 124 | """Embeds box prompts.""" 125 | boxes = boxes + 0.5 # Shift to center of pixel 126 | coords = boxes.reshape(-1, 2, 2) 127 | corner_embedding = self.pe_layer.forward_with_coords( 128 | coords, self.input_image_size 129 | ) 130 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 131 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 132 | return corner_embedding 133 | 134 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 135 | """Embeds mask inputs.""" 136 | mask_embedding = self.mask_downscaling(masks) 137 | return mask_embedding 138 | 139 | def 
_get_batch_size( 140 | self, 141 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 142 | boxes: Optional[torch.Tensor], 143 | masks: Optional[torch.Tensor], 144 | ) -> int: 145 | """ 146 | Gets the batch size of the output given the batch size of the input prompts. 147 | """ 148 | if points is not None: 149 | return points[0].shape[0] 150 | elif boxes is not None: 151 | return boxes.shape[0] 152 | elif masks is not None: 153 | return masks.shape[0] 154 | else: 155 | return 1 156 | 157 | def _get_device(self) -> torch.device: 158 | return self.point_embeddings[0].weight.device 159 | 160 | def forward( 161 | self, 162 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 163 | boxes: Optional[torch.Tensor], 164 | masks: Optional[torch.Tensor], 165 | ) -> Tuple[torch.Tensor, torch.Tensor]: 166 | """ 167 | Embeds different types of prompts, returning both sparse and dense 168 | embeddings. 169 | 170 | Arguments: 171 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 172 | and labels to embed. 173 | boxes (torch.Tensor or none): boxes to embed 174 | masks (torch.Tensor or none): masks to embed 175 | 176 | Returns: 177 | torch.Tensor: sparse embeddings for the points and boxes, with shape 178 | BxNx(embed_dim), where N is determined by the number of input points 179 | and boxes. 180 | torch.Tensor: dense embeddings for the masks, in the shape 181 | Bx(embed_dim)x(embed_H)x(embed_W) 182 | """ 183 | bs = self._get_batch_size(points, boxes, masks) 184 | sparse_embeddings = torch.empty( 185 | (bs, 0, self.embed_dim), device=self._get_device() 186 | ) 187 | if points is not None: 188 | coords, labels = points 189 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 190 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 191 | if boxes is not None: 192 | box_embeddings = self._embed_boxes(boxes) 193 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 194 | 195 | if masks is not None: 196 | dense_embeddings = self._embed_masks(masks) 197 | else: 198 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 199 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 200 | ) 201 | 202 | return sparse_embeddings, dense_embeddings -------------------------------------------------------------------------------- /models/components/semantic_extraction.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 4 | 5 | from models.sam2.backbones.image_encoder import ImageEncoder 6 | from models.sam2.sam2_utils import MLP 7 | from transformers import ( 8 | AutoTokenizer, 9 | CLIPTextModel, 10 | CLIPTextConfig, 11 | CLIPVisionConfig, 12 | CLIPVisionModel, 13 | AutoFeatureExtractor, 14 | ) 15 | 16 | 17 | class SemanticInteraction(nn.Module): 18 | def __init__(self, clip_text_ckpt, clip_image_ckp, image_dim, text_dim, num_heads, stem_channel, qkv_bias=False, 19 | qk_scale=None, drop=0., 20 | attn_drop=0., drop_path=0., has_mlp=False): 21 | super().__init__() 22 | self.text_encoder = ClipTextEncoder(clip_text_ckpt) 23 | self.image_encoder = ClipVisionEncoder(clip_image_ckp) 24 | self.cross_att_one_level_v = CrossAttentionBlock( 25 | image_dim, num_heads=num_heads, stem_channel=stem_channel, 26 | qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, drop_path=drop_path, has_mlp=has_mlp 27 | ) 28 | self.cross_att_one_level_t = 
CrossAttentionBlock( 29 | text_dim, num_heads=num_heads, stem_channel=stem_channel, 30 | qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, drop_path=drop_path, has_mlp=has_mlp 31 | ) 32 | 33 | self.in_dim = text_dim + text_dim 34 | self.co_conv = nn.Conv2d(self.in_dim, self.in_dim // 2, kernel_size=3, stride=1, padding=1, bias=True) 35 | 36 | self.cross_att_two_level_v = CrossAttentionBlock( 37 | image_dim, num_heads=num_heads, stem_channel=stem_channel, 38 | qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, drop_path=drop_path, has_mlp=has_mlp 39 | ) 40 | self.cross_att_two_level_t = CrossAttentionBlock( 41 | text_dim, num_heads=num_heads, stem_channel=stem_channel, 42 | qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, drop_path=drop_path, has_mlp=has_mlp 43 | ) 44 | self.linear_1 = nn.Linear(self.in_dim // 2, self.in_dim * 2) 45 | self.dropout = nn.Dropout(drop) 46 | self.relu = nn.ReLU() 47 | self.linear_2 = nn.Linear(self.in_dim * 2, self.in_dim // 2) 48 | 49 | def forward(self, image, text): 50 | f_v = self.image_encoder(image) 51 | f_t = self.text_encoder(text) 52 | 53 | f_vt = self.cross_att_one_level_v(f_v, f_t) 54 | f_tv = self.cross_att_one_level_t(f_t, f_v) 55 | 56 | f_c = self.co_conv(torch.cat([f_v, f_t], dim=1)) 57 | 58 | f_vt_dot = self.cross_att_two_level_v(f_c, f_vt) 59 | f_tv_dot = self.cross_att_two_level_t(f_c, f_tv) 60 | 61 | raw_out = f_vt_dot + f_tv_dot 62 | 63 | out = self.linear_1(raw_out) 64 | out = self.dropout(out) 65 | out = self.relu(out) 66 | out = self.linear_2(out) 67 | return raw_out, out 68 | 69 | 70 | class DepthWiseConv(nn.Module): 71 | def __init__(self, in_channels, out_channels, kernel_size=3, padding=1): 72 | super().__init__() 73 | self.depthwise = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, 74 | padding=padding, groups=in_channels) 75 | self.pointwise = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1) 76 | 77 | def forward(self, x): 78 | x = self.depthwise(x) 79 | x = self.pointwise(x) 80 | return x 81 | 82 | 83 | class CrossAttention(nn.Module): 84 | def __init__(self, dim, num_heads=8, stem_channel=16, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 85 | super().__init__() 86 | self.num_heads = num_heads 87 | head_dim = dim // num_heads 88 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 89 | self.scale = qk_scale or head_dim ** -0.5 90 | 91 | self.conv_k = DepthWiseConv(stem_channel, stem_channel, kernel_size=3, padding=1) 92 | self.conv_v = DepthWiseConv(stem_channel, stem_channel, kernel_size=3, padding=1) 93 | 94 | self.wq = nn.Linear(dim, dim, bias=qkv_bias) 95 | self.wk = nn.Linear(dim, dim, bias=qkv_bias) 96 | self.wv = nn.Linear(dim, dim, bias=qkv_bias) 97 | self.attn_drop = nn.Dropout(attn_drop) 98 | self.proj = nn.Linear(dim, dim) 99 | self.proj_drop = nn.Dropout(proj_drop) 100 | 101 | def forward(self, x, y): 102 | B, N, C = x.shape 103 | q = self.wq(x[:, 0:1, ...]).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 104 | 3) # B1C -> B1H(C/H) -> BH1(C/H) 105 | 106 | k = self.conv_k(y) 107 | k = self.wk(k).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 108 | 3) # BNC -> BNH(C/H) -> BHN(C/H) 109 | v = self.conv_v(y) 110 | v = self.wv(v).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 111 | 3) # BNC -> BNH(C/H) -> BHN(C/H) 112 | 113 | attn = (q @ k.transpose(-2, -1)) * self.scale #
BH1(C/H) @ BH(C/H)N -> BH1N 114 | attn = attn.softmax(dim=-1) 115 | attn = self.attn_drop(attn) 116 | 117 | out = (attn @ v).transpose(1, 2).reshape(B, 1, C) # (BH1N @ BHN(C/H)) -> BH1(C/H) -> B1H(C/H) -> B1C 118 | out = self.proj(out) 119 | out = self.proj_drop(out) 120 | return out 121 | 122 | 123 | class CrossAttentionBlock(nn.Module): 124 | 125 | def __init__(self, dim, num_heads, stem_channel, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 126 | drop_path=0., has_mlp=False): 127 | super().__init__() 128 | self.norm_x = nn.LayerNorm(dim) 129 | self.norm_y = nn.LayerNorm(dim) 130 | self.attn = CrossAttention( 131 | dim, num_heads=num_heads, stem_channel=stem_channel, qkv_bias=qkv_bias, qk_scale=qk_scale, 132 | attn_drop=attn_drop, proj_drop=drop) 133 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 134 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 135 | self.has_mlp = has_mlp 136 | if has_mlp: 137 | self.norm2 = nn.LayerNorm(dim) 138 | mlp_hidden_dim = int(dim * mlp_ratio) 139 | self.mlp = MLP(dim, mlp_hidden_dim, dim // 4, 3) 140 | 141 | def forward(self, x, y): 142 | out = self.attn(self.norm_x(x), self.norm_y(y)) 143 | out = self.drop_path(out) 144 | if self.has_mlp: 145 | out = self.drop_path(self.mlp(self.norm2(out))) 146 | return out 147 | 148 | 149 | class ClipTextEncoder(nn.Module): 150 | def __init__(self, clip_ckpt): 151 | """ 152 | :param clip_ckpt: the list of checkpoints of CLIP could be found at: https://huggingface.co/openai 153 | """ 154 | super().__init__() 155 | config = CLIPTextConfig() 156 | self.clip_text_model = CLIPTextModel(config) 157 | self.tokenizer = AutoTokenizer.from_pretrained(clip_ckpt) 158 | self.dim_align = nn.Linear(512, 768) 159 | # freeze text encoder 160 | for param in self.clip_text_model.parameters(): 161 | param.requires_grad = False 162 | 163 | def organ2tokens(self, organ_names): 164 | text_list = ['A computerized tomography of a {}.'.format(organ_name) for organ_name in organ_names] 165 | tokens = self.tokenizer(text_list, padding=True, return_tensors="pt") 166 | for key in tokens.keys(): 167 | tokens[key] = tokens[key].cuda() 168 | return tokens 169 | 170 | def forward(self, text): 171 | if text is None: 172 | return None 173 | if type(text) is str: 174 | text = [text] 175 | tokens = self.organ2tokens(text) 176 | clip_outputs = self.clip_text_model(**tokens) 177 | text_embedding = clip_outputs.pooler_output 178 | text_embedding = self.dim_align(text_embedding) 179 | return text_embedding 180 | 181 | 182 | class ClipVisionEncoder(nn.Module): 183 | def __init__(self, clip_ckpt): 184 | """ 185 | :param clip_ckpt: the list of checkpoints of CLIP could be found at: https://huggingface.co/openai 186 | """ 187 | super().__init__() 188 | config = CLIPVisionConfig() 189 | self.clip_vision_model = CLIPVisionModel(config) 190 | self.feature_extractor = AutoFeatureExtractor.from_pretrained(clip_ckpt) 191 | self.dim_align = nn.Linear(512, 768) 192 | # freeze text encoder 193 | for param in self.clip_vision_model.parameters(): 194 | param.requires_grad = False 195 | 196 | def images2features(self, images): 197 | features = self.feature_extractor(images=images, return_tensors="pt") 198 | for key in features.keys(): 199 | features[key] = features[key].cuda() 200 | return features 201 | 202 | def forward(self, images): 203 | if images is None: 204 | return None 205 | if type(images) is not list: 206 | images = [images] 207 | features = self.images2features(images) 208 
| clip_outputs = self.clip_vision_model(**features) 209 | vision_embedding = clip_outputs.pooler_output 210 | vision_embedding = self.dim_align(vision_embedding) 211 | return vision_embedding 212 | -------------------------------------------------------------------------------- /models/sam2/backbones/image_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import List, Optional 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from models.sam2.sam2_utils import MLP 14 | 15 | class ImageEncoder(nn.Module): 16 | def __init__( 17 | self, 18 | trunk: nn.Module, 19 | neck: nn.Module, 20 | scalp: int = 0, 21 | ): 22 | super().__init__() 23 | self.trunk = trunk 24 | self.neck = neck 25 | self.scalp = scalp 26 | assert ( 27 | self.trunk.channel_list == self.neck.backbone_channel_list 28 | ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}" 29 | 30 | def forward(self, img_sample: torch.Tensor, injection: torch.Tensor) -> torch.Tensor: 31 | # Forward through backbone 32 | features, pos = self.neck(self.trunk(img_sample, injection)) 33 | if self.scalp > 0: 34 | # Discard the lowest resolution features 35 | features, pos = features[: -self.scalp], pos[: -self.scalp] 36 | 37 | src = features[-1] 38 | output = { 39 | "vision_features": src, 40 | "vision_pos_enc": pos, 41 | "backbone_fpn": features, 42 | } 43 | return output 44 | 45 | 46 | class FpnNeckInjection(nn.Module): 47 | """ 48 | A modified variant of Feature Pyramid Network (FPN) neck with semantic injection 49 | """ 50 | 51 | def __init__( 52 | self, 53 | position_encoding: nn.Module, 54 | d_model: int, 55 | backbone_channel_list: List[int], 56 | kernel_size: int = 1, 57 | stride: int = 1, 58 | padding: int = 0, 59 | fpn_interp_model: str = "bilinear", 60 | fuse_type: str = "sum", 61 | fpn_top_down_levels: Optional[List[int]] = None, 62 | ): 63 | """Initialize the neck 64 | :param trunk: the backbone 65 | :param position_encoding: the positional encoding to use 66 | :param d_model: the dimension of the model 67 | :param neck_norm: the normalization to use 68 | """ 69 | super().__init__() 70 | self.position_encoding = position_encoding 71 | self.convs = nn.ModuleList() 72 | self.backbone_channel_list = backbone_channel_list 73 | self.d_model = d_model 74 | for dim in backbone_channel_list: 75 | current = nn.Sequential() 76 | current.add_module( 77 | "conv", 78 | nn.Conv2d( 79 | in_channels=dim, 80 | out_channels=d_model, 81 | kernel_size=kernel_size, 82 | stride=stride, 83 | padding=padding, 84 | ), 85 | ) 86 | 87 | self.convs.append(current) 88 | self.fpn_interp_model = fpn_interp_model 89 | assert fuse_type in ["sum", "avg"] 90 | self.fuse_type = fuse_type 91 | 92 | # levels to have top-down features in its outputs 93 | # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3 94 | # have top-down propagation, while outputs of level 0 and level 1 have only 95 | # lateral features from the same backbone level. 
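        # Rough sketch of the injection path wired up below (the MLPs are built a few lines
        # further down), assuming `injection` is the semantic embedding produced by the
        # SemanticInteraction module in models/components/semantic_extraction.py:
        #     shared = shared_mlp(injection) + injection      # one MLP shared across levels
        #     prompt = unshared_mlps[level](shared)           # one level-specific MLP per FPN level
        #     fused  = lateral + top_down + prompt            # prompt is only added on top-down levels
        # The shared/unshared split follows the SAM-Adapter prompt generator referenced below.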
96 |         if fpn_top_down_levels is None:
 97 |             # default is to have top-down features on all levels
 98 |             fpn_top_down_levels = range(len(self.convs))
 99 |         self.fpn_top_down_levels = list(fpn_top_down_levels)
100 | 
101 |         # see https://github.com/tianrun-chen/SAM-Adapter-PyTorch/blob/SAM2-Adapter-for-Segment-Anything-2/models/mmseg/models/sam/image_encoder.py
102 |         self.shared_mlp = MLP(d_model, d_model, d_model // 4, 2)
103 |         self.unshared_mlps = nn.ModuleList()
104 | 
105 |         for dim in backbone_channel_list:
106 |             current = MLP(d_model // 4, d_model // 4, dim, 2)
107 |             self.unshared_mlps.append(current)
108 | 
109 | 
110 |     def forward(self, xs: List[torch.Tensor], injection: torch.Tensor):
111 | 
112 |         out = [None] * len(self.convs)
113 |         pos = [None] * len(self.convs)
114 |         assert len(xs) == len(self.convs)
115 | 
116 |         shared_features = self.shared_mlp(injection) + injection
117 |         # fpn forward pass
118 |         # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
119 |         prev_features = None
120 |         # forward in top-down order (from low to high resolution)
121 |         n = len(self.convs) - 1
122 |         for i in range(n, -1, -1):
123 |             x = xs[i]
124 |             lateral_features = self.convs[n - i](x)
125 |             if i in self.fpn_top_down_levels and prev_features is not None:
126 |                 top_down_features = F.interpolate(
127 |                     prev_features.to(dtype=torch.float32),
128 |                     scale_factor=2.0,
129 |                     mode=self.fpn_interp_model,
130 |                     align_corners=(
131 |                         None if self.fpn_interp_model == "nearest" else False
132 |                     ),
133 |                     antialias=False,
134 |                 )
135 |                 injection = self.unshared_mlps[n - i](shared_features)
136 |                 prev_features = lateral_features + top_down_features
137 |                 prev_features = prev_features + injection
138 |                 if self.fuse_type == "avg":
139 |                     prev_features /= 2
140 |             else:
141 |                 prev_features = lateral_features
142 |             x_out = prev_features
143 |             out[i] = x_out
144 |             pos[i] = self.position_encoding(x_out).to(x_out.dtype)
145 | 
146 |         return out, pos
147 | 
148 | 
149 | 
150 | class FpnNeck(nn.Module):
151 |     """
152 |     A modified variant of Feature Pyramid Network (FPN) neck
153 |     (we remove output conv and also do bicubic interpolation similar to ViT
154 |     pos embed interpolation)
155 |     """
156 | 
157 |     def __init__(
158 |         self,
159 |         position_encoding: nn.Module,
160 |         d_model: int,
161 |         backbone_channel_list: List[int],
162 |         kernel_size: int = 1,
163 |         stride: int = 1,
164 |         padding: int = 0,
165 |         fpn_interp_model: str = "bilinear",
166 |         fuse_type: str = "sum",
167 |         fpn_top_down_levels: Optional[List[int]] = None,
168 |     ):
169 |         """Initialize the neck
170 |         :param trunk: the backbone
171 |         :param position_encoding: the positional encoding to use
172 |         :param d_model: the dimension of the model
173 |         :param neck_norm: the normalization to use
174 |         """
175 |         super().__init__()
176 |         self.position_encoding = position_encoding
177 |         self.convs = nn.ModuleList()
178 |         self.backbone_channel_list = backbone_channel_list
179 |         self.d_model = d_model
180 |         for dim in backbone_channel_list:
181 |             current = nn.Sequential()
182 |             current.add_module(
183 |                 "conv",
184 |                 nn.Conv2d(
185 |                     in_channels=dim,
186 |                     out_channels=d_model,
187 |                     kernel_size=kernel_size,
188 |                     stride=stride,
189 |                     padding=padding,
190 |                 ),
191 |             )
192 | 
193 |             self.convs.append(current)
194 |         self.fpn_interp_model = fpn_interp_model
195 |         assert fuse_type in ["sum", "avg"]
196 |         self.fuse_type = fuse_type
197 | 
198 |         # levels to have top-down features in its outputs
199 |         # e.g. 
if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3 200 | # have top-down propagation, while outputs of level 0 and level 1 have only 201 | # lateral features from the same backbone level. 202 | if fpn_top_down_levels is None: 203 | # default is to have top-down features on all levels 204 | fpn_top_down_levels = range(len(self.convs)) 205 | self.fpn_top_down_levels = list(fpn_top_down_levels) 206 | 207 | def forward(self, xs: List[torch.Tensor]): 208 | 209 | out = [None] * len(self.convs) 210 | pos = [None] * len(self.convs) 211 | assert len(xs) == len(self.convs) 212 | # fpn forward pass 213 | # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py 214 | prev_features = None 215 | # forward in top-down order (from low to high resolution) 216 | n = len(self.convs) - 1 217 | for i in range(n, -1, -1): 218 | x = xs[i] 219 | lateral_features = self.convs[n - i](x) 220 | if i in self.fpn_top_down_levels and prev_features is not None: 221 | top_down_features = F.interpolate( 222 | prev_features.to(dtype=torch.float32), 223 | scale_factor=2.0, 224 | mode=self.fpn_interp_model, 225 | align_corners=( 226 | None if self.fpn_interp_model == "nearest" else False 227 | ), 228 | antialias=False, 229 | ) 230 | prev_features = lateral_features + top_down_features 231 | if self.fuse_type == "avg": 232 | prev_features /= 2 233 | else: 234 | prev_features = lateral_features 235 | x_out = prev_features 236 | out[i] = x_out 237 | pos[i] = self.position_encoding(x_out).to(x_out.dtype) 238 | 239 | return out, pos -------------------------------------------------------------------------------- /models/sam2/position_encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | from typing import Any, Optional, Tuple 9 | 10 | import numpy as np 11 | 12 | import torch 13 | from torch import nn 14 | 15 | 16 | class PositionEmbeddingSine(nn.Module): 17 | """ 18 | This is a more standard version of the position embedding, very similar to the one 19 | used by the Attention Is All You Need paper, generalized to work on images. 
20 | """ 21 | 22 | def __init__( 23 | self, 24 | num_pos_feats, 25 | temperature: int = 10000, 26 | normalize: bool = True, 27 | scale: Optional[float] = None, 28 | # Following settings only relevant 29 | # for warmping up cache for compilation 30 | warmup_cache: bool = True, 31 | image_size: int = 1024, 32 | strides: Tuple[int] = (4, 8, 16, 32), 33 | ): 34 | super().__init__() 35 | assert num_pos_feats % 2 == 0, "Expecting even model width" 36 | self.num_pos_feats = num_pos_feats // 2 37 | self.temperature = temperature 38 | self.normalize = normalize 39 | if scale is not None and normalize is False: 40 | raise ValueError("normalize should be True if scale is passed") 41 | if scale is None: 42 | scale = 2 * math.pi 43 | self.scale = scale 44 | 45 | self.cache = {} 46 | if warmup_cache and torch.cuda.is_available(): 47 | # Warmup cache for cuda, to help with compilation 48 | device = torch.device("cuda") 49 | for stride in strides: 50 | cache_key = (image_size // stride, image_size // stride) 51 | self._pe(1, device, *cache_key) 52 | 53 | def _encode_xy(self, x, y): 54 | # The positions are expected to be normalized 55 | assert len(x) == len(y) and x.ndim == y.ndim == 1 56 | x_embed = x * self.scale 57 | y_embed = y * self.scale 58 | 59 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 60 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 61 | 62 | pos_x = x_embed[:, None] / dim_t 63 | pos_y = y_embed[:, None] / dim_t 64 | pos_x = torch.stack( 65 | (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2 66 | ).flatten(1) 67 | pos_y = torch.stack( 68 | (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2 69 | ).flatten(1) 70 | return pos_x, pos_y 71 | 72 | @torch.no_grad() 73 | def encode_boxes(self, x, y, w, h): 74 | pos_x, pos_y = self._encode_xy(x, y) 75 | pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1) 76 | return pos 77 | 78 | encode = encode_boxes # Backwards compatibility 79 | 80 | @torch.no_grad() 81 | def encode_points(self, x, y, labels): 82 | (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape 83 | assert bx == by and nx == ny and bx == bl and nx == nl 84 | pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten()) 85 | pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1) 86 | pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2) 87 | return pos 88 | 89 | @torch.no_grad() 90 | def _pe(self, B, device, *cache_key): 91 | H, W = cache_key 92 | if cache_key in self.cache: 93 | return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1) 94 | 95 | y_embed = ( 96 | torch.arange(1, H + 1, dtype=torch.float32, device=device) 97 | .view(1, -1, 1) 98 | .repeat(B, 1, W) 99 | ) 100 | x_embed = ( 101 | torch.arange(1, W + 1, dtype=torch.float32, device=device) 102 | .view(1, 1, -1) 103 | .repeat(B, H, 1) 104 | ) 105 | 106 | if self.normalize: 107 | eps = 1e-6 108 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 109 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 110 | 111 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device) 112 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 113 | 114 | pos_x = x_embed[:, :, :, None] / dim_t 115 | pos_y = y_embed[:, :, :, None] / dim_t 116 | pos_x = torch.stack( 117 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 118 | ).flatten(3) 119 | pos_y = torch.stack( 120 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 121 | ).flatten(3) 122 | pos = torch.cat((pos_y, 
pos_x), dim=3).permute(0, 3, 1, 2) 123 | self.cache[cache_key] = pos[0] 124 | return pos 125 | 126 | @torch.no_grad() 127 | def forward(self, x: torch.Tensor): 128 | B = x.shape[0] 129 | cache_key = (x.shape[-2], x.shape[-1]) 130 | return self._pe(B, x.device, *cache_key) 131 | 132 | 133 | class PositionEmbeddingRandom(nn.Module): 134 | """ 135 | Positional encoding using random spatial frequencies. 136 | """ 137 | 138 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 139 | super().__init__() 140 | if scale is None or scale <= 0.0: 141 | scale = 1.0 142 | self.register_buffer( 143 | "positional_encoding_gaussian_matrix", 144 | scale * torch.randn((2, num_pos_feats)), 145 | ) 146 | 147 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 148 | """Positionally encode points that are normalized to [0,1].""" 149 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 150 | coords = 2 * coords - 1 151 | coords = coords @ self.positional_encoding_gaussian_matrix 152 | coords = 2 * np.pi * coords 153 | # outputs d_1 x ... x d_n x C shape 154 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 155 | 156 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 157 | """Generate positional encoding for a grid of the specified size.""" 158 | h, w = size 159 | device: Any = self.positional_encoding_gaussian_matrix.device 160 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 161 | y_embed = grid.cumsum(dim=0) - 0.5 162 | x_embed = grid.cumsum(dim=1) - 0.5 163 | y_embed = y_embed / h 164 | x_embed = x_embed / w 165 | 166 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 167 | return pe.permute(2, 0, 1) # C x H x W 168 | 169 | def forward_with_coords( 170 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 171 | ) -> torch.Tensor: 172 | """Positionally encode points that are not normalized to [0,1].""" 173 | coords = coords_input.clone() 174 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 175 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 176 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 177 | 178 | 179 | # Rotary Positional Encoding, adapted from: 180 | # 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py 181 | # 2. https://github.com/naver-ai/rope-vit 182 | # 3. 
https://github.com/lucidrains/rotary-embedding-torch 183 | 184 | 185 | def init_t_xy(end_x: int, end_y: int): 186 | t = torch.arange(end_x * end_y, dtype=torch.float32) 187 | t_x = (t % end_x).float() 188 | t_y = torch.div(t, end_x, rounding_mode="floor").float() 189 | return t_x, t_y 190 | 191 | 192 | def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0): 193 | freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) 194 | freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) 195 | 196 | t_x, t_y = init_t_xy(end_x, end_y) 197 | freqs_x = torch.outer(t_x, freqs_x) 198 | freqs_y = torch.outer(t_y, freqs_y) 199 | freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x) 200 | freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y) 201 | return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1) 202 | 203 | 204 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): 205 | ndim = x.ndim 206 | assert 0 <= 1 < ndim 207 | assert freqs_cis.shape == (x.shape[-2], x.shape[-1]) 208 | shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)] 209 | return freqs_cis.view(*shape) 210 | 211 | 212 | def apply_rotary_enc( 213 | xq: torch.Tensor, 214 | xk: torch.Tensor, 215 | freqs_cis: torch.Tensor, 216 | repeat_freqs_k: bool = False, 217 | ): 218 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) 219 | xk_ = ( 220 | torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) 221 | if xk.shape[-2] != 0 222 | else None 223 | ) 224 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_) 225 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) 226 | if xk_ is None: 227 | # no keys to rotate, due to dropout 228 | return xq_out.type_as(xq).to(xq.device), xk 229 | # repeat freqs along seq_len dim to match k seq_len 230 | if repeat_freqs_k: 231 | r = xk_.shape[-2] // xq_.shape[-2] 232 | if freqs_cis.is_cuda: 233 | freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1) 234 | else: 235 | # torch.repeat on complex numbers may not be supported on non-CUDA devices 236 | # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten 237 | freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3) 238 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) 239 | return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device) -------------------------------------------------------------------------------- /models/sam2/backbones/hieradet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import logging 8 | from functools import partial 9 | from typing import List, Tuple, Union 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from iopath.common.file_io import g_pathmgr 15 | 16 | from models.sam2.backbones.utils import ( 17 | PatchEmbed, 18 | window_partition, 19 | window_unpartition, 20 | ) 21 | 22 | from models.sam2.sam2_utils import DropPath, MLP 23 | 24 | 25 | def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor: 26 | if pool is None: 27 | return x 28 | # (B, H, W, C) -> (B, C, H, W) 29 | x = x.permute(0, 3, 1, 2) 30 | x = pool(x) 31 | # (B, C, H', W') -> (B, H', W', C) 32 | x = x.permute(0, 2, 3, 1) 33 | if norm: 34 | x = norm(x) 35 | 36 | return x 37 | 38 | 39 | class MultiScaleAttention(nn.Module): 40 | def __init__( 41 | self, 42 | dim: int, 43 | dim_out: int, 44 | num_heads: int, 45 | q_pool: nn.Module = None, 46 | ): 47 | super().__init__() 48 | 49 | self.dim = dim 50 | self.dim_out = dim_out 51 | self.num_heads = num_heads 52 | self.q_pool = q_pool 53 | self.qkv = nn.Linear(dim, dim_out * 3) 54 | self.proj = nn.Linear(dim_out, dim_out) 55 | 56 | def forward(self, x: torch.Tensor) -> torch.Tensor: 57 | B, H, W, _ = x.shape 58 | # qkv with shape (B, H * W, 3, nHead, C) 59 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1) 60 | # q, k, v with shape (B, H * W, nheads, C) 61 | q, k, v = torch.unbind(qkv, 2) 62 | 63 | # Q pooling (for downsample at stage changes) 64 | if self.q_pool: 65 | q = do_pool(q.reshape(B, H, W, -1), self.q_pool) 66 | H, W = q.shape[1:3] # downsampled shape 67 | q = q.reshape(B, H * W, self.num_heads, -1) 68 | 69 | # Torch's SDPA expects [B, nheads, H*W, C] so we transpose 70 | x = F.scaled_dot_product_attention( 71 | q.transpose(1, 2), 72 | k.transpose(1, 2), 73 | v.transpose(1, 2), 74 | ) 75 | # Transpose back 76 | x = x.transpose(1, 2) 77 | x = x.reshape(B, H, W, -1) 78 | 79 | x = self.proj(x) 80 | 81 | return x 82 | 83 | 84 | class MultiScaleBlock(nn.Module): 85 | def __init__( 86 | self, 87 | dim: int, 88 | dim_out: int, 89 | num_heads: int, 90 | mlp_ratio: float = 4.0, 91 | drop_path: float = 0.0, 92 | norm_layer: Union[nn.Module, str] = "LayerNorm", 93 | q_stride: Tuple[int, int] = None, 94 | act_layer: nn.Module = nn.GELU, 95 | window_size: int = 0, 96 | ): 97 | super().__init__() 98 | 99 | if isinstance(norm_layer, str): 100 | norm_layer = partial(getattr(nn, norm_layer), eps=1e-6) 101 | 102 | self.dim = dim 103 | self.dim_out = dim_out 104 | self.norm1 = norm_layer(dim) 105 | 106 | self.window_size = window_size 107 | 108 | self.pool, self.q_stride = None, q_stride 109 | if self.q_stride: 110 | self.pool = nn.MaxPool2d( 111 | kernel_size=q_stride, stride=q_stride, ceil_mode=False 112 | ) 113 | 114 | self.attn = MultiScaleAttention( 115 | dim, 116 | dim_out, 117 | num_heads=num_heads, 118 | q_pool=self.pool, 119 | ) 120 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 121 | 122 | self.norm2 = norm_layer(dim_out) 123 | self.mlp = MLP( 124 | dim_out, 125 | int(dim_out * mlp_ratio), 126 | dim_out, 127 | num_layers=2, 128 | activation=act_layer, 129 | ) 130 | 131 | if dim != dim_out: 132 | self.proj = nn.Linear(dim, dim_out) 133 | 134 | def forward(self, x: torch.Tensor) -> torch.Tensor: 135 | shortcut = x # B, H, W, C 136 | x = self.norm1(x) 137 | 138 | # Skip connection 139 | if self.dim != self.dim_out: 140 | shortcut = do_pool(self.proj(x), self.pool) 141 | 142 | # Window partition 143 | window_size = self.window_size 
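        # Shape sketch with illustrative numbers: for a (B, 64, 64, C) feature map, window_size=8
        # needs no padding and yields (B * 64, 8, 8, C) windows, while window_size=14 first pads to
        # 70x70 and yields (B * 25, 14, 14, C); attention then runs independently per window and
        # window_unpartition restores (B, 64, 64, C). Global-attention blocks use window_size == 0
        # and skip the partitioning below entirely.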
144 | if window_size > 0: 145 | H, W = x.shape[1], x.shape[2] 146 | x, pad_hw = window_partition(x, window_size) 147 | 148 | # Window Attention + Q Pooling (if stage change) 149 | x = self.attn(x) 150 | if self.q_stride: 151 | # Shapes have changed due to Q pooling 152 | window_size = self.window_size // self.q_stride[0] 153 | H, W = shortcut.shape[1:3] 154 | 155 | pad_h = (window_size - H % window_size) % window_size 156 | pad_w = (window_size - W % window_size) % window_size 157 | pad_hw = (H + pad_h, W + pad_w) 158 | 159 | # Reverse window partition 160 | if self.window_size > 0: 161 | x = window_unpartition(x, window_size, pad_hw, (H, W)) 162 | 163 | x = shortcut + self.drop_path(x) 164 | # MLP 165 | x = x + self.drop_path(self.mlp(self.norm2(x))) 166 | return x 167 | 168 | 169 | class Hiera(nn.Module): 170 | """ 171 | Reference: https://arxiv.org/abs/2306.00989 172 | """ 173 | 174 | def __init__( 175 | self, 176 | embed_dim: int = 96, # initial embed dim 177 | num_heads: int = 1, # initial number of heads 178 | drop_path_rate: float = 0.0, # stochastic depth 179 | q_pool: int = 3, # number of q_pool stages 180 | q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages 181 | stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage 182 | dim_mul: float = 2.0, # dim_mul factor at stage shift 183 | head_mul: float = 2.0, # head_mul factor at stage shift 184 | window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), 185 | # window size per stage, when not using global att. 186 | window_spec: Tuple[int, ...] = ( 187 | 8, 188 | 4, 189 | 14, 190 | 7, 191 | ), 192 | # global attn in these blocks 193 | global_att_blocks: Tuple[int, ...] = ( 194 | 12, 195 | 16, 196 | 20, 197 | ), 198 | weights_path=None, 199 | return_interm_layers=True, # return feats from every stage 200 | ): 201 | super().__init__() 202 | 203 | assert len(stages) == len(window_spec) 204 | self.window_spec = window_spec 205 | 206 | depth = sum(stages) 207 | self.q_stride = q_stride 208 | self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)] 209 | assert 0 <= q_pool <= len(self.stage_ends[:-1]) 210 | self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool] 211 | self.return_interm_layers = return_interm_layers 212 | 213 | self.patch_embed = PatchEmbed( 214 | embed_dim=embed_dim, 215 | ) 216 | # Which blocks have global att? 
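        # Worked example for the defaults only: with stages (2, 3, 16, 3) the stage ends fall at
        # block indices 1, 4, 20 and 23, so global_att_blocks (12, 16, 20) all land in the long
        # third stage, which therefore mixes windowed and global attention while the other stages
        # stay fully windowed. Other configs shift these indices.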
217 | self.global_att_blocks = global_att_blocks 218 | 219 | # Windowed positional embedding (https://arxiv.org/abs/2311.05613) 220 | self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size 221 | self.pos_embed = nn.Parameter( 222 | torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size) 223 | ) 224 | self.pos_embed_window = nn.Parameter( 225 | torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]) 226 | ) 227 | 228 | dpr = [ 229 | x.item() for x in torch.linspace(0, drop_path_rate, depth) 230 | ] # stochastic depth decay rule 231 | 232 | cur_stage = 1 233 | self.blocks = nn.ModuleList() 234 | 235 | for i in range(depth): 236 | dim_out = embed_dim 237 | # lags by a block, so first block of 238 | # next stage uses an initial window size 239 | # of previous stage and final window size of current stage 240 | window_size = self.window_spec[cur_stage - 1] 241 | 242 | if self.global_att_blocks is not None: 243 | window_size = 0 if i in self.global_att_blocks else window_size 244 | 245 | if i - 1 in self.stage_ends: 246 | dim_out = int(embed_dim * dim_mul) 247 | num_heads = int(num_heads * head_mul) 248 | cur_stage += 1 249 | 250 | block = MultiScaleBlock( 251 | dim=embed_dim, 252 | dim_out=dim_out, 253 | num_heads=num_heads, 254 | drop_path=dpr[i], 255 | q_stride=self.q_stride if i in self.q_pool_blocks else None, 256 | window_size=window_size, 257 | ) 258 | 259 | embed_dim = dim_out 260 | self.blocks.append(block) 261 | 262 | self.channel_list = ( 263 | [self.blocks[i].dim_out for i in self.stage_ends[::-1]] 264 | if return_interm_layers 265 | else [self.blocks[-1].dim_out] 266 | ) 267 | 268 | if weights_path is not None: 269 | with g_pathmgr.open(weights_path, "rb") as f: 270 | chkpt = torch.load(f, map_location="cpu") 271 | logging.info("loading Hiera", self.load_state_dict(chkpt, strict=False)) 272 | 273 | def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor: 274 | h, w = hw 275 | window_embed = self.pos_embed_window 276 | pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic") 277 | pos_embed = pos_embed + window_embed.tile( 278 | [x // y for x, y in zip(pos_embed.shape, window_embed.shape)] 279 | ) 280 | pos_embed = pos_embed.permute(0, 2, 3, 1) 281 | return pos_embed 282 | 283 | def forward(self, x: torch.Tensor) -> List[torch.Tensor]: 284 | x = self.patch_embed(x) 285 | # x: (B, H, W, C) 286 | 287 | # Add pos embed 288 | x = x + self._get_pos_embed(x.shape[1:3]) 289 | 290 | outputs = [] 291 | for i, blk in enumerate(self.blocks): 292 | x = blk(x) 293 | if (i == self.stage_ends[-1]) or ( 294 | i in self.stage_ends and self.return_interm_layers 295 | ): 296 | feats = x.permute(0, 3, 1, 2) 297 | outputs.append(feats) 298 | 299 | return outputs 300 | 301 | def get_layer_id(self, layer_name): 302 | # https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 303 | num_layers = self.get_num_layers() 304 | 305 | if layer_name.find("rel_pos") != -1: 306 | return num_layers + 1 307 | elif layer_name.find("pos_embed") != -1: 308 | return 0 309 | elif layer_name.find("patch_embed") != -1: 310 | return 0 311 | elif layer_name.find("blocks") != -1: 312 | return int(layer_name.split("blocks")[1].split(".")[1]) + 1 313 | else: 314 | return num_layers + 1 315 | 316 | def get_num_layers(self) -> int: 317 | return len(self.blocks) -------------------------------------------------------------------------------- /models/sam2/sam/transformer.py: -------------------------------------------------------------------------------- 1 | 
# Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | from functools import partial 9 | from typing import Tuple, Type 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | from torch import nn, Tensor 14 | 15 | from models.sam2.position_encoding import apply_rotary_enc, compute_axial_cis 16 | from models.sam2.sam2_utils import MLP 17 | 18 | 19 | class TwoWayTransformer(nn.Module): 20 | def __init__( 21 | self, 22 | depth: int, 23 | embedding_dim: int, 24 | num_heads: int, 25 | mlp_dim: int, 26 | activation: Type[nn.Module] = nn.ReLU, 27 | attention_downsample_rate: int = 2, 28 | ) -> None: 29 | """ 30 | A transformer decoder that attends to an input image using 31 | queries whose positional embedding is supplied. 32 | 33 | Args: 34 | depth (int): number of layers in the transformer 35 | embedding_dim (int): the channel dimension for the input embeddings 36 | num_heads (int): the number of heads for multihead attention. Must 37 | divide embedding_dim 38 | mlp_dim (int): the channel dimension internal to the MLP block 39 | activation (nn.Module): the activation to use in the MLP block 40 | """ 41 | super().__init__() 42 | self.depth = depth 43 | self.embedding_dim = embedding_dim 44 | self.num_heads = num_heads 45 | self.mlp_dim = mlp_dim 46 | self.layers = nn.ModuleList() 47 | 48 | for i in range(depth): 49 | self.layers.append( 50 | TwoWayAttentionBlock( 51 | embedding_dim=embedding_dim, 52 | num_heads=num_heads, 53 | mlp_dim=mlp_dim, 54 | activation=activation, 55 | attention_downsample_rate=attention_downsample_rate, 56 | skip_first_layer_pe=(i == 0), 57 | ) 58 | ) 59 | 60 | self.final_attn_token_to_image = Attention( 61 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 62 | ) 63 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 64 | 65 | def forward( 66 | self, 67 | image_embedding: Tensor, 68 | image_pe: Tensor, 69 | point_embedding: Tensor, 70 | ) -> Tuple[Tensor, Tensor]: 71 | """ 72 | Args: 73 | image_embedding (torch.Tensor): image to attend to. Should be shape 74 | B x embedding_dim x h x w for any h and w. 75 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 76 | have the same shape as image_embedding. 77 | point_embedding (torch.Tensor): the embedding to add to the query points. 78 | Must have shape B x N_points x embedding_dim for any N_points. 
79 | 80 | Returns: 81 | torch.Tensor: the processed point_embedding 82 | torch.Tensor: the processed image_embedding 83 | """ 84 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 85 | bs, c, h, w = image_embedding.shape 86 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 87 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 88 | 89 | # Prepare queries 90 | queries = point_embedding 91 | keys = image_embedding 92 | 93 | # Apply transformer blocks and final layernorm 94 | for layer in self.layers: 95 | queries, keys = layer( 96 | queries=queries, 97 | keys=keys, 98 | query_pe=point_embedding, 99 | key_pe=image_pe, 100 | ) 101 | 102 | # Apply the final attention layer from the points to the image 103 | q = queries + point_embedding 104 | k = keys + image_pe 105 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 106 | queries = queries + attn_out 107 | queries = self.norm_final_attn(queries) 108 | 109 | return queries, keys 110 | 111 | 112 | class TwoWayAttentionBlock(nn.Module): 113 | def __init__( 114 | self, 115 | embedding_dim: int, 116 | num_heads: int, 117 | mlp_dim: int = 2048, 118 | activation: Type[nn.Module] = nn.ReLU, 119 | attention_downsample_rate: int = 2, 120 | skip_first_layer_pe: bool = False, 121 | ) -> None: 122 | """ 123 | A transformer block with four layers: (1) self-attention of sparse 124 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 125 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 126 | inputs. 127 | 128 | Arguments: 129 | embedding_dim (int): the channel dimension of the embeddings 130 | num_heads (int): the number of heads in the attention layers 131 | mlp_dim (int): the hidden dimension of the mlp block 132 | activation (nn.Module): the activation of the mlp block 133 | skip_first_layer_pe (bool): skip the PE on the first layer 134 | """ 135 | super().__init__() 136 | self.self_attn = Attention(embedding_dim, num_heads) 137 | self.norm1 = nn.LayerNorm(embedding_dim) 138 | 139 | self.cross_attn_token_to_image = Attention( 140 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 141 | ) 142 | self.norm2 = nn.LayerNorm(embedding_dim) 143 | 144 | self.mlp = MLP( 145 | embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation 146 | ) 147 | self.norm3 = nn.LayerNorm(embedding_dim) 148 | 149 | self.norm4 = nn.LayerNorm(embedding_dim) 150 | self.cross_attn_image_to_token = Attention( 151 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 152 | ) 153 | 154 | self.skip_first_layer_pe = skip_first_layer_pe 155 | 156 | def forward( 157 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 158 | ) -> Tuple[Tensor, Tensor]: 159 | # Self attention block 160 | if self.skip_first_layer_pe: 161 | queries = self.self_attn(q=queries, k=queries, v=queries) 162 | else: 163 | q = queries + query_pe 164 | attn_out = self.self_attn(q=q, k=q, v=queries) 165 | queries = queries + attn_out 166 | queries = self.norm1(queries) 167 | 168 | # Cross attention block, tokens attending to image embedding 169 | q = queries + query_pe 170 | k = keys + key_pe 171 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 172 | queries = queries + attn_out 173 | queries = self.norm2(queries) 174 | 175 | # MLP block 176 | mlp_out = self.mlp(queries) 177 | queries = queries + mlp_out 178 | queries = self.norm3(queries) 179 | 180 | # Cross attention block, image embedding attending to tokens 181 | q = queries + query_pe 182 | k = keys + key_pe 183 | 
attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) 184 | keys = keys + attn_out 185 | keys = self.norm4(keys) 186 | 187 | return queries, keys 188 | 189 | 190 | class Attention(nn.Module): 191 | """ 192 | An attention layer that allows for downscaling the size of the embedding 193 | after projection to queries, keys, and values. 194 | """ 195 | 196 | def __init__( 197 | self, 198 | embedding_dim: int, 199 | num_heads: int, 200 | downsample_rate: int = 1, 201 | dropout: float = 0.0, 202 | kv_in_dim: int = None, 203 | ) -> None: 204 | super().__init__() 205 | self.embedding_dim = embedding_dim 206 | self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim 207 | self.internal_dim = embedding_dim // downsample_rate 208 | self.num_heads = num_heads 209 | assert ( 210 | self.internal_dim % num_heads == 0 211 | ), "num_heads must divide embedding_dim." 212 | 213 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 214 | self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim) 215 | self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim) 216 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 217 | 218 | self.dropout_p = dropout 219 | 220 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 221 | b, n, c = x.shape 222 | x = x.reshape(b, n, num_heads, c // num_heads) 223 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 224 | 225 | def _recombine_heads(self, x: Tensor) -> Tensor: 226 | b, n_heads, n_tokens, c_per_head = x.shape 227 | x = x.transpose(1, 2) 228 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 229 | 230 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 231 | # Input projections 232 | q = self.q_proj(q) 233 | k = self.k_proj(k) 234 | v = self.v_proj(v) 235 | 236 | # Separate into heads 237 | q = self._separate_heads(q, self.num_heads) 238 | k = self._separate_heads(k, self.num_heads) 239 | v = self._separate_heads(v, self.num_heads) 240 | 241 | dropout_p = self.dropout_p if self.training else 0.0 242 | # Attention 243 | out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) 244 | 245 | out = self._recombine_heads(out) 246 | out = self.out_proj(out) 247 | 248 | return out 249 | 250 | 251 | class RoPEAttention(Attention): 252 | """Attention with rotary position encoding.""" 253 | 254 | def __init__( 255 | self, 256 | *args, 257 | rope_theta=10000.0, 258 | # whether to repeat q rope to match k length 259 | # this is needed for cross-attention to memories 260 | rope_k_repeat=False, 261 | feat_sizes=(64, 64), # [w, h] for stride 16 feats at 1024 resolution 262 | **kwargs, 263 | ): 264 | super().__init__(*args, **kwargs) 265 | 266 | self.compute_cis = partial( 267 | compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta 268 | ) 269 | freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1]) 270 | self.freqs_cis = ( 271 | freqs_cis.to("cuda") if torch.cuda.is_available() else freqs_cis 272 | ) 273 | self.rope_k_repeat = rope_k_repeat 274 | 275 | def forward( 276 | self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0 277 | ) -> Tensor: 278 | # Input projections 279 | q = self.q_proj(q) 280 | k = self.k_proj(k) 281 | v = self.v_proj(v) 282 | 283 | # Separate into heads 284 | q = self._separate_heads(q, self.num_heads) 285 | k = self._separate_heads(k, self.num_heads) 286 | v = self._separate_heads(v, self.num_heads) 287 | 288 | # Apply rotary position encoding 289 | w = h = math.sqrt(q.shape[-2]) 290 | self.freqs_cis = 
self.freqs_cis.to(q.device) 291 | if self.freqs_cis.shape[0] != q.shape[-2]: 292 | self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device) 293 | if q.shape[-2] != k.shape[-2]: 294 | assert self.rope_k_repeat 295 | 296 | num_k_rope = k.size(-2) - num_k_exclude_rope 297 | q, k[:, :, :num_k_rope] = apply_rotary_enc( 298 | q, 299 | k[:, :, :num_k_rope], 300 | freqs_cis=self.freqs_cis, 301 | repeat_freqs_k=self.rope_k_repeat, 302 | ) 303 | 304 | dropout_p = self.dropout_p if self.training else 0.0 305 | # Attention 306 | out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) 307 | 308 | out = self._recombine_heads(out) 309 | out = self.out_proj(out) 310 | 311 | return out -------------------------------------------------------------------------------- /datasets/data_loader.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | from monai import data, transforms 5 | import itertools 6 | from torch.utils.data.distributed import DistributedSampler 7 | from torch.utils.data import Dataset, ConcatDataset 8 | import os 9 | import ast 10 | from scipy import sparse 11 | import random 12 | import json 13 | from torch.utils.data import DataLoader 14 | from monai.transforms import ( 15 | Compose, 16 | AddChanneld, 17 | CropForegroundd, 18 | SpatialPadd, 19 | Resized, 20 | RandCropByPosNegLabeld, 21 | RandFlipd, 22 | RandScaleIntensityd, 23 | RandShiftIntensityd, 24 | ToTensord, 25 | ) 26 | 27 | def read_json_file(file_path): 28 | try: 29 | with open(file_path, 'r', encoding='utf-8') as file: 30 | data = json.load(file) 31 | return data 32 | except Exception as e: 33 | print(f"{e}") 34 | return None 35 | 36 | 37 | class UnionDataset(Dataset): 38 | def __init__(self, concat_dataset, datasets): 39 | self.datasets = datasets 40 | self.lengths = [len(d) for d in datasets] 41 | self.offsets = torch.cumsum(torch.tensor([0] + self.lengths), dim=0) 42 | self.concat_dataset = concat_dataset 43 | 44 | def __len__(self): 45 | return sum(self.lengths) 46 | 47 | def __getitem__(self, idx): 48 | return self.concat_dataset[idx] 49 | 50 | 51 | class UniversalDataset(Dataset): 52 | def __init__(self, data, transform, test_mode, organ_list): 53 | self.data = data 54 | self.transform = transform 55 | # one pos point is base set 56 | self.num_positive_extra_max = 10 57 | self.num_negative_extra_max = 10 58 | self.test_mode = test_mode 59 | self.bbox_shift = 10 if test_mode else 0 60 | print(organ_list) 61 | organ_list.remove('background') 62 | self.target_list = organ_list 63 | 64 | def __len__(self): 65 | return len(self.data) 66 | 67 | def __getitem__(self, idx): 68 | # get path 69 | item_dict = self.data[idx] 70 | ct_path, gt_path = item_dict['image'], item_dict['label'] 71 | gt_shape = ast.literal_eval(gt_path.split('.')[-2].split('_')[-1]) 72 | 73 | # load data 74 | ct_array = np.load(ct_path)[0] 75 | allmatrix_sp = sparse.load_npz(gt_path) 76 | gt_array = allmatrix_sp.toarray().reshape(gt_shape) 77 | 78 | # transform 79 | if self.test_mode: 80 | item_ori = { 81 | 'image': ct_array, 82 | 'label': gt_array, 83 | } 84 | else: 85 | item_ori = { 86 | 'image': ct_array, 87 | 'label': gt_array, 88 | } 89 | if self.transform is not None: 90 | item = self.transform(item_ori) 91 | 92 | if type(item) == list: 93 | assert len(item) == 1 94 | item = item[0] 95 | 96 | assert type(item) != list 97 | item['organ_name_list'] = self.target_list 98 | item['post_label'] = item['label'] 99 | post_item = self.std_keys(item) 100 | 
return post_item 101 | 102 | def std_keys(self, post_item): 103 | keys_to_remain = ['image', 'post_label', 'organ_name_list'] 104 | keys_to_remove = post_item.keys() - keys_to_remain 105 | for key in keys_to_remove: 106 | del post_item[key] 107 | return post_item 108 | 109 | 110 | def query_descriptions(input_name, data, min_count=4, max_count=6): 111 | input_name = input_name.strip().lower() 112 | matching_keys = [k for k in data.keys() if k.lower() == input_name] 113 | 114 | if not matching_keys: 115 | return [] 116 | 117 | descriptions = data[matching_keys[0]] 118 | count = len(descriptions) 119 | 120 | select_count = min(max_count, count) 121 | select_count = max(min_count, select_count) 122 | 123 | selected = descriptions[:select_count] 124 | 125 | return selected 126 | 127 | 128 | 129 | class BatchedDistributedSampler(DistributedSampler): 130 | def __init__(self, dataset, shuffle, batch_size, num_replicas=None, rank=None): 131 | super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) 132 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 133 | self.total_size = self.num_samples * self.num_replicas 134 | self.batch_size = batch_size 135 | 136 | def __iter__(self): 137 | print('run BatchedDistributedSampler iter') 138 | indices = list(range(len(self.dataset))) 139 | # indices += indices[:(self.total_size - len(indices))] 140 | # assert len(indices) == self.total_size 141 | 142 | indices = [indices[i:i + l] for i, l in zip(self.dataset.offsets[:-1], self.dataset.lengths)] 143 | 144 | if self.shuffle: 145 | for idx, subset_indices in enumerate(indices): 146 | random.shuffle(indices[idx]) 147 | 148 | # drop subset last 149 | for idx, subset_indices in enumerate(indices): 150 | r = len(subset_indices) % self.batch_size 151 | if r > 0: 152 | indices[idx] = indices[idx][:-r] 153 | indices = list(itertools.chain(*indices)) 154 | indices = [indices[i:i + self.batch_size] for i in range(0, len(indices), self.batch_size)] 155 | if self.shuffle: 156 | random.shuffle(indices) 157 | 158 | batch_num = len(indices) 159 | replicas_size = batch_num // self.num_replicas 160 | start = self.rank * replicas_size 161 | end = start + replicas_size if self.rank != self.num_replicas - 1 else batch_num 162 | batched_indices = list(itertools.chain(*(indices[start:end]))) 163 | ## 164 | indices = list(itertools.chain(*indices)) 165 | self.total_size = len(indices) 166 | self.num_samples = self.total_size // self.num_replicas 167 | ## 168 | return iter(batched_indices) 169 | 170 | 171 | def collate_fn(batch, dictionary_path): 172 | term_dictionary = read_json_file(dictionary_path) 173 | images = [] 174 | organ_name_list = None 175 | post_labels = [] 176 | organ_description_list = [] 177 | for sample in batch: 178 | images.append(sample['image']) 179 | assert organ_name_list is None or organ_name_list == sample['organ_name_list'] 180 | organ_name_list = sample['organ_name_list'] 181 | for organ in organ_name_list: 182 | organ_description_list.append(query_descriptions(organ, term_dictionary, 4, 6)) 183 | post_labels.append(sample['post_label']) 184 | return { 185 | 'image': torch.stack(images, dim=0), 186 | 'organ_name_list': organ_name_list, 187 | 'post_label': torch.stack(post_labels, dim=0), 188 | 'text': organ_description_list 189 | } 190 | 191 | 192 | class MinMaxNormalization(transforms.Transform): 193 | def __call__(self, data): 194 | d = dict(data) 195 | k = "image" 196 | d[k] = d[k] - d[k].min() 197 | d[k] = d[k] / np.clip(d[k].max(), a_min=1e-8, 
a_max=None) 198 | return d 199 | 200 | 201 | class DimTranspose(transforms.Transform): 202 | def __init__(self, keys): 203 | self.keys = keys 204 | 205 | def __call__(self, data): 206 | d = dict(data) 207 | for key in self.keys: 208 | d[key] = np.swapaxes(d[key], -1, -3) 209 | return d 210 | 211 | 212 | def build_concat_dataset(root_path, dataset_codes, transform): 213 | concat_dataset = [] 214 | CombinationDataset_len = 0 215 | for dataset_code in dataset_codes: 216 | datalist_json = os.path.join(root_path, dataset_code, f'{dataset_code}.json') 217 | with open(datalist_json, 'r') as f: 218 | dataset_dict = json.load(f) 219 | datalist = dataset_dict['train'] 220 | universal_ds = UniversalDataset(data=datalist, transform=transform, test_mode=False, 221 | organ_list=list(dataset_dict['labels'].values())) 222 | concat_dataset.append(universal_ds) 223 | CombinationDataset_len += len(universal_ds) 224 | print(f'CombinationDataset loaded, dataset size: {CombinationDataset_len}') 225 | return UnionDataset(ConcatDataset(concat_dataset), concat_dataset) 226 | 227 | 228 | def get_loader(args): 229 | if args.mode == 'train': 230 | train_transform = Compose( 231 | [ 232 | AddChanneld(keys=["image"]), 233 | DimTranspose(keys=["image", "label"]), 234 | MinMaxNormalization(), 235 | CropForegroundd(keys=["image", "label"], source_key="image"), 236 | SpatialPadd(keys=["image", "label"], spatial_size=args.spatial_size, 237 | mode='constant'), 238 | transforms.OneOf(transforms=[ 239 | Resized(keys=["image", "label"], spatial_size=args.spatial_size), 240 | RandCropByPosNegLabeld( 241 | keys=["image", "label"], 242 | label_key="label", 243 | spatial_size=args.spatial_size, 244 | pos=2, 245 | neg=1, 246 | num_samples=1, 247 | image_key="image", 248 | image_threshold=0, 249 | ), 250 | ], 251 | weights=[1, 1] 252 | ), 253 | RandFlipd(keys=["image", "label"], prob=args.rand_flipped_prob, spatial_axis=0), 254 | RandFlipd(keys=["image", "label"], prob=args.rand_flipped_prob, spatial_axis=1), 255 | RandFlipd(keys=["image", "label"], prob=args.rand_flipped_prob, spatial_axis=2), 256 | RandScaleIntensityd(keys="image", factors=0.1, prob=args.rand_scale_intensityd_prob), 257 | RandShiftIntensityd(keys="image", offsets=0.1, prob=args.rand_shift_intensityd_prob), 258 | Resized(keys=["image", "label"], spatial_size=args.spatial_size), 259 | ToTensord(keys=["image", "label"]), 260 | ] 261 | ) 262 | 263 | print(f'----- train on combination dataset -----') 264 | combination_train_ds = build_concat_dataset(root_path=args.data_dir, dataset_codes=args.dataset_codes, 265 | transform=train_transform) 266 | train_sampler = BatchedDistributedSampler(combination_train_ds, shuffle=True, 267 | batch_size=args.batch_size) if args.dist else None 268 | train_loader = DataLoader( 269 | combination_train_ds, 270 | batch_size=args.batch_size, 271 | shuffle=(train_sampler is None), 272 | num_workers=args.num_workers, 273 | sampler=train_sampler, 274 | pin_memory=True, 275 | persistent_workers=True, 276 | collate_fn=collate_fn, 277 | ) 278 | return train_loader 279 | elif args.mode == 'test': 280 | test_transform = Compose( 281 | [ 282 | AddChanneld(keys=["image"]), 283 | DimTranspose(keys=["image", "label"]), 284 | MinMaxNormalization(), 285 | CropForegroundd(keys=["image", "label"], source_key="image"), 286 | SpatialPadd(keys=["image", "label"], spatial_size=args.spatial_size, 287 | mode='constant'), 288 | Resized(keys=["image", "label"], spatial_size=args.spatial_size), 289 | ToTensord(keys=["image", "label"]), 290 | ] 291 | ) 292 | 293 | 
print(f'----- test on combination dataset -----') 294 | combination_test_ds = build_concat_dataset(root_path=args.data_dir, dataset_codes=args.dataset_codes, 295 | transform=test_transform) 296 | test_sampler = BatchedDistributedSampler(combination_test_ds, shuffle=False, 297 | batch_size=args.batch_size) if args.dist else None 298 | test_loader = DataLoader( 299 | combination_test_ds, 300 | batch_size=args.batch_size, 301 | shuffle=False, 302 | num_workers=args.num_workers, 303 | sampler=test_sampler, 304 | pin_memory=True, 305 | persistent_workers=True, 306 | collate_fn=collate_fn, 307 | ) 308 | return test_loader 309 | else: 310 | raise ValueError("mode should be either 'train' or 'test'") -------------------------------------------------------------------------------- /datasets/dataset_info.txt: -------------------------------------------------------------------------------- 1 | dataset_name = { 2 | '0000': 'CHAOS', 3 | '0001': 'HaN-Seg', 4 | '0002': 'AMOS22', 5 | '0003': 'AbdomenCT-1k', 6 | '0004': 'KiTS23', 7 | '0005': 'KiPA22', 8 | '0006': 'KiTS19', 9 | '0007': 'BCTV', 10 | '0008': 'Pancreas-CT', 11 | '0009': '3D-IRCADb', 12 | '0010': 'AbdomenCT-12organ', 13 | '0011': 'TotalSegmentator', 14 | '0012': 'CT-ORG', 15 | '0013': 'WORD', 16 | '0014': 'VerSe19', 17 | '0015': 'VerSe20', 18 | '0016': 'SLIVER07', 19 | '0017': 'QUBIQ', 20 | '0018': 'MSD-colon', 21 | '0019': 'MSD-hepatic_vessel', 22 | '0020': 'MSD-liver', 23 | '0021': 'MSD-lung', 24 | '0022': 'MSD-pancreas', 25 | '0023': 'MSD-spleen', 26 | '0024': 'LUNA16', 27 | } 28 | 29 | dataset_info_raw = { 30 | '0000': ['liver'], 31 | '0001': ['A_Carotid_L','A_Carotid_R','Arytenoid','Bone_Mandible','Brainstem','BuccalMucosa','Cavity_Oral','Cochlea_L','Cochlea_R','Cricopharyngeus','Esophagus_S','Eye_AL','Eye_AR','Eye_PL','Eye_PR','Glnd_Lacrimal_L','Glnd_Lacrimal_R','Glnd_Submand_L','Glnd_Submand_R','Glnd_Thyroid','Glottis','Larynx_SG','Lips','OpticChiasm','OpticNrv_L','OpticNrv_R','Parotid_L','Parotid_R','Pituitary','SpinalCord'], 32 | '0002': ["spleen", "right kidney", "left kidney", "gall bladder", "esophagus", "liver", "stomach", "arota", "postcava", "pancreas", "right adrenal gland", "left adrenal gland", "duodenum", "bladder", "prostate/uterus"], 33 | '0003': ['liver', 'kidney', 'spleen', 'pancreas'], 34 | '0004': ['kidney', 'kidney tumor', 'kidney cyst'], 35 | '0005': ['Renal vein', 'Kidney', 'Renal artery', 'Tumor'], 36 | '0006': ['kidney', 'kidney tumor'], 37 | '0007': ['spleen','right kidney','left kidney','gallbladder','esophagus','liver','stomach','aorta','inferior vena cava','portal vein and splenic vein','pancreas','right adrenal gland','left adrenal gland'], 38 | '0008': ['pancreas'], 39 | '0009': ['Stones', 'artery', 'biliarysystem', 'bladder', 'bone', 'colon', 'gallbladder', 'heart', 'kidneys', 'leftkidney', 'leftlung', 'leftsurrenalgland', 'leftsurretumor', 'liver', 'livercyst', 'liverkyst', 'liverkyste', 'livertumor', 'livertumor01', 'livertumor02', 'livertumor03', 'livertumor04', 'livertumor05', 'livertumor06', 'livertumor07', 'livertumor1', 'livertumor2', 'livertumors', 'lungs', 'metal', 'metastasectomie', 'pancreas', 'portalvein', 'portalvein1', 'rightkidney', 'rightlung', 'rightsurrenalgland', 'rightsurretumor', 'skin', 'smallintestin', 'spleen', 'stomach', 'surrenalgland', 'tumor', 'uterus', 'venacava', 'venoussystem'], 40 | '0010': ['liver', 'right kidney', 'spleen', 'pancreas', 'aorta', 'inferior vena cava', 'right adrenal gland', 'left adrenal gland', 'gallbladder', 'esophagus', 'stomach', 'duodenum', 'left 
kidney'], 41 | '0011': ['adrenal_gland_left', 'adrenal_gland_right', 'aorta', 'autochthon_left', 'autochthon_right', 'brain', 'clavicula_left', 'clavicula_right', 'colon', 'duodenum', 'esophagus', 'face', 'femur_left', 'femur_right', 'gallbladder', 'gluteus_maximus_left', 'gluteus_maximus_right', 'gluteus_medius_left', 'gluteus_medius_right', 'gluteus_minimus_left', 'gluteus_minimus_right', 'heart_atrium_left', 'heart_atrium_right', 'heart_myocardium', 'heart_ventricle_left', 'heart_ventricle_right', 'hip_left', 'hip_right', 'humerus_left', 'humerus_right', 'iliac_artery_left', 'iliac_artery_right', 'iliac_vena_left', 'iliac_vena_right', 'iliopsoas_left', 'iliopsoas_right', 'inferior_vena_cava', 'kidney_left', 'kidney_right', 'liver', 'lung_lower_lobe_left', 'lung_lower_lobe_right', 'lung_middle_lobe_right', 'lung_upper_lobe_left', 'lung_upper_lobe_right', 'pancreas', 'portal_vein_and_splenic_vein', 'pulmonary_artery', 'rib_left_1', 'rib_left_10', 'rib_left_11', 'rib_left_12', 'rib_left_2', 'rib_left_3', 'rib_left_4', 'rib_left_5', 'rib_left_6', 'rib_left_7', 'rib_left_8', 'rib_left_9', 'rib_right_1', 'rib_right_10', 'rib_right_11', 'rib_right_12', 'rib_right_2', 'rib_right_3', 'rib_right_4', 'rib_right_5', 'rib_right_6', 'rib_right_7', 'rib_right_8', 'rib_right_9', 'sacrum', 'scapula_left', 'scapula_right', 'small_bowel', 'spleen', 'stomach', 'trachea', 'urinary_bladder', 'vertebrae_C1', 'vertebrae_C2', 'vertebrae_C3', 'vertebrae_C4', 'vertebrae_C5', 'vertebrae_C6', 'vertebrae_C7', 'vertebrae_L1', 'vertebrae_L2', 'vertebrae_L3', 'vertebrae_L4', 'vertebrae_L5', 'vertebrae_T1', 'vertebrae_T10', 'vertebrae_T11', 'vertebrae_T12', 'vertebrae_T2', 'vertebrae_T3', 'vertebrae_T4', 'vertebrae_T5', 'vertebrae_T6', 'vertebrae_T7', 'vertebrae_T8', 'vertebrae_T9'], 42 | '0012': ['Liver','Bladder','Lungs','Kidneys','Bone','Brain'], 43 | '0013': ['Liver','Spleen','Kidney (L)','Kidney (R)','Stomach','Gallbladder','Esophagus','Pancreas','Duodenum','Colon','Intestine','Adrenal','Rectum','Bladder','Head of femur (L)','Head of femur (R)'], 44 | '0014': ['cervical spine C1', 'cervical spine C2', 'cervical spine C3', 'cervical spine C4', 'cervical spine C5', 'cervical spine C6', 'cervical spine C7', 'thoracic spine T1', 'thoracic spine T2', 'thoracic spine T3', 'thoracic spine T4', 'thoracic spine T5', 'thoracic spine T6', 'thoracic spine T7', 'thoracic spine T8', 'thoracic spine T9', 'thoracic spine T10', 'thoracic spine T11', 'thoracic spine T12', 'lumbar spine L1', 'lumbar spine L1', 'lumbar spine L3', 'lumbar spine L4', 'lumbar spine L5', 'lumbar spine L6', 'sacrum','cocygis','additional 13th thoracic vertebra, T13',], 45 | '0015': ['cervical spine C1', 'cervical spine C2', 'cervical spine C3', 'cervical spine C4', 'cervical spine C5', 'cervical spine C6', 'cervical spine C7', 'thoracic spine T1', 'thoracic spine T2', 'thoracic spine T3', 'thoracic spine T4', 'thoracic spine T5', 'thoracic spine T6', 'thoracic spine T7', 'thoracic spine T8', 'thoracic spine T9', 'thoracic spine T10', 'thoracic spine T11', 'thoracic spine T12', 'lumbar spine L1', 'lumbar spine L1', 'lumbar spine L3', 'lumbar spine L4', 'lumbar spine L5', 'lumbar spine L6', 'sacrum','cocygis','additional 13th thoracic vertebra, T13',], 46 | '0016': ['liver'], 47 | '0017': ['kidney', 'pancreas', 'pancreatic-lesion'], 48 | '0018': ['colon cancer'], 49 | '0019': ['hepatic vessels', 'tumour'], 50 | '0020': ['liver', 'tumour'], 51 | '0021': ['lung tumours'], 52 | '0022': ['pancreas', 'tumour'], 53 | '0023': ['spleen'], 54 | '0024': ['left lung', 
'right lung', 'trachea'], 55 | } 56 | 57 | 58 | 59 | dataset_info = { 60 | '0000': ['liver'], 61 | '0001': ['carotid artery left', 'carotid artery right', 'arytenoid', 'bone mandible', 'brainstem', 62 | 'buccal mucosa', 63 | 'oral cavity', 'cochlea left', 'cochlea right', 'cricopharyngeal inlet', 'cervical esophagus', 64 | 'anterior eyeball left', 'anterior eyeball right', 65 | 'posterior eyeball left', 'posterior eyeball right', 'lacrimal gland left', 'lacrimal gland right', 66 | 'submandibular gland left', 'submandibular gland right', 67 | 'thyroid', 'larynx glottis', 'larynx supraglottic', 'lips', 'optic chiasm', 'optic nerve left', 68 | 'optic nerve right', 69 | 'parotid gland left', 'parotid gland right', 'pituitary gland', 'spinal cord'], 70 | '0002': ["spleen", "right kidney", "left kidney", "gall bladder", "esophagus", "liver", "stomach", "aorta", 71 | "postcava", "pancreas", "right adrenal gland", "left adrenal gland", "duodenum", "bladder", 72 | "prostate or uterus"], 73 | '0003': ['liver', 'kidney', 'spleen', 'pancreas'], 74 | '0004': ['kidney', 'kidney tumor', 'kidney cyst'], 75 | '0005': ['renal vein', 'kidney', 'renal artery', 'tumor'], 76 | '0006': ['kidney', 'kidney tumor'], 77 | '0007': ['spleen', 'right kidney', 'left kidney', 'gallbladder', 'esophagus', 'liver', 'stomach', 'aorta', 78 | 'inferior vena cava', 'portal vein and splenic vein', 'pancreas', 'right adrenal gland', 79 | 'left adrenal gland'], 80 | '0008': ['pancreas'], 81 | '0009': ['stones', 'artery', 'biliary system', 'bladder', 'bone', 'colon', 'gallbladder', 'heart', 82 | 'kidneys', 83 | 'left kidney', 'left lung', 'left suprarenal gland', 'left suprarenal tumor', 'liver', 84 | 'liver cyst', 'liver kyst', 85 | 86 | 'liver kyste', 'liver tumor', 'liver tumor 01', 'liver tumor 02', 'liver tumor 03', 87 | 'liver tumor 04', 88 | 'liver tumor 05', 'liver tumor 06', 'liver tumor 07', 'liver tumor 1', 'liver tumor 2', 89 | 'liver tumors', 90 | 91 | 'lungs', 'metal', 'metastasectomie', 'pancreas', 'portal vein', 'portal vein 1', 'right kidney', 92 | 'right lung', 'right suprarenal gland', 'right suprarenal tumor', 'skin', 'small intestin', 93 | 'spleen', 'stomach', 94 | 'suprarenal gland', 'tumor', 'uterus', 'vena cava', 'venous system'], 95 | '0010': ['liver', 'right kidney', 'spleen', 'pancreas', 'aorta', 'inferior vena cava', 96 | 'right adrenal gland', 'left adrenal gland', 'gallbladder', 'esophagus', 'stomach', 'duodenum', 97 | 'left kidney'], 98 | '0011': ['adrenal gland left', 'adrenal gland right', 'aorta', 'autochthon left', 'autochthon right', 99 | 'brain', 'clavicula left', 'clavicula right', 'colon', 'duodenum', 'esophagus', 'face', 100 | 'femur left', 'femur right', 'gallbladder', 'gluteus maximus left', 'gluteus maximus right', 101 | 'gluteus medius left', 'gluteus medius right', 'gluteus minimus left', 'gluteus minimus right', 102 | 103 | 'heart atrium left', 'heart atrium right', 'heart myocardium', 'heart ventricle left', 104 | 'heart ventricle right', 'hip left', 'hip right', 'humerus left', 'humerus right', 105 | 'iliac artery left', 'iliac artery right', 'iliac vena left', 'iliac vena right', 'iliopsoas left', 106 | 107 | 'iliopsoas right', 'inferior vena cava', 'kidney left', 'kidney right', 'liver', 108 | 'lung lower lobe left', 'lung lower lobe right', 'lung middle lobe right', 'lung upper lobe left', 109 | 'lung upper lobe right', 'pancreas', 'portal vein and splenic vein', 'pulmonary artery', 110 | 111 | 'rib left 1', 'rib left 10', 'rib left 11', 'rib left 12', 'rib left 2', 'rib left 3', 
112 | 'rib left 4', 'rib left 5', 'rib left 6', 'rib left 7', 'rib left 8', 'rib left 9', 'rib right 1', 113 | 114 | 'rib right 10', 'rib right 11', 'rib right 12', 'rib right 2', 'rib right 3', 'rib right 4', 115 | 'rib right 5', 'rib right 6', 'rib right 7', 'rib right 8', 'rib right 9', 'sacrum', 116 | 117 | 'scapula left', 'scapula right', 'small bowel', 'spleen', 'stomach', 'trachea', 'urinary bladder', 118 | 'vertebrae C1', 'vertebrae C2', 'vertebrae C3', 'vertebrae C4', 'vertebrae C5', 'vertebrae C6', 119 | 120 | 'vertebrae C7', 'vertebrae L1', 'vertebrae L2', 'vertebrae L3', 'vertebrae L4', 'vertebrae L5', 121 | 'vertebrae T1', 'vertebrae T10', 'vertebrae T11', 'vertebrae T12', 'vertebrae T2', 'vertebrae T3', 122 | 'vertebrae T4', 'vertebrae T5', 'vertebrae T6', 'vertebrae T7', 'vertebrae T8', 'vertebrae T9'], 123 | '0012': ['liver', 'bladder', 'lungs', 'kidneys', 'bone', 'brain'], 124 | '0013': ['liver', 'spleen', 'kidney left', 'kidney right', 'stomach', 'gallbladder', 'esophagus', 125 | 'pancreas', 126 | 'duodenum', 'colon', 'intestine', 'adrenal', 'rectum', 'bladder', 'head of femur left', 127 | 'head of femur right'], 128 | '0014': ['cervical spine C1', 'cervical spine C2', 'cervical spine C3', 'cervical spine C4', 129 | 'cervical spine C5', 'cervical spine C6', 'cervical spine C7', 'thoracic spine T1', 130 | 131 | 'thoracic spine T2', 'thoracic spine T3', 'thoracic spine T4', 'thoracic spine T5', 132 | 'thoracic spine T6', 'thoracic spine T7', 'thoracic spine T8', 'thoracic spine T9', 133 | 134 | 'thoracic spine T10', 'thoracic spine T11', 'thoracic spine T12', 'lumbar spine L1', 135 | 'lumbar spine L1', 'lumbar spine L3', 'lumbar spine L4', 'lumbar spine L5', 'lumbar spine L6', 136 | 'sacrum', 'coccygis', 'additional 13th thoracic vertebra, T13', ], 137 | '0015': ['cervical spine C1', 'cervical spine C2', 'cervical spine C3', 'cervical spine C4', 138 | 'cervical spine C5', 'cervical spine C6', 'cervical spine C7', 'thoracic spine T1', 139 | 'thoracic spine T2', 'thoracic spine T3', 'thoracic spine T4', 'thoracic spine T5', 140 | 'thoracic spine T6', 'thoracic spine T7', 'thoracic spine T8', 'thoracic spine T9', 141 | 'thoracic spine T10', 'thoracic spine T11', 'thoracic spine T12', 'lumbar spine L1', 142 | 'lumbar spine L1', 'lumbar spine L3', 'lumbar spine L4', 'lumbar spine L5', 'lumbar spine L6', 143 | 'sacrum', 'coccygis', 'additional 13th thoracic vertebra T13', ], 144 | '0016': ['liver'], 145 | '0017': ['kidney', 'pancreas', 'pancreatic lesion'], 146 | '0018': ['colon cancer'], 147 | '0019': ['hepatic vessels', 'tumour'], 148 | '0020': ['liver', 'tumour'], 149 | '0021': ['lung tumours'], 150 | '0022': ['pancreas', 'tumour'], 151 | '0023': ['spleen'], 152 | '0024': ['left lung', 'right lung', 'trachea'], 153 | } -------------------------------------------------------------------------------- /models/components/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | # Modified from https://github.com/open-mmlab/mmdetection/blob/main/projects/EfficientDet/efficientdet/utils.py 3 | 4 | import math 5 | from typing import Tuple, Union 6 | 7 | import torch 8 | import torch.nn as nn 9 | from mmcv.cnn.bricks import Swish, build_norm_layer 10 | from torch.nn.init import _calculate_fan_in_and_fan_out, trunc_normal_ 11 | from mmdet.registry import MODELS 12 | from mmdet.utils import OptConfigType 13 | from torch.autograd import Function 14 | import torch.nn.functional as F 15 | from torch.nn.modules.utils import _triple, _pair, _single 16 | 17 | import adapool_cuda 18 | 19 | 20 | def variance_scaling_trunc(tensor, gain=1.): 21 | fan_in, _ = _calculate_fan_in_and_fan_out(tensor) 22 | gain /= max(1.0, fan_in) 23 | std = math.sqrt(gain) / .87962566103423978 24 | return trunc_normal_(tensor, 0., std) 25 | 26 | 27 | @MODELS.register_module() 28 | class Conv2dSamePadding(nn.Conv2d): 29 | 30 | def __init__(self, 31 | in_channels: int, 32 | out_channels: int, 33 | kernel_size: Union[int, Tuple[int, int]], 34 | stride: Union[int, Tuple[int, int]] = 1, 35 | padding: Union[int, Tuple[int, int]] = 0, 36 | dilation: Union[int, Tuple[int, int]] = 1, 37 | groups: int = 1, 38 | bias: bool = True): 39 | super().__init__(in_channels, out_channels, kernel_size, stride, 0, 40 | dilation, groups, bias) 41 | 42 | def forward(self, x: torch.Tensor) -> torch.Tensor: 43 | img_h, img_w = x.size()[-2:] 44 | kernel_h, kernel_w = self.weight.size()[-2:] 45 | extra_w = (math.ceil(img_w / self.stride[1]) - 46 | 1) * self.stride[1] - img_w + kernel_w 47 | extra_h = (math.ceil(img_h / self.stride[0]) - 48 | 1) * self.stride[0] - img_h + kernel_h 49 | 50 | left = extra_w // 2 51 | right = extra_w - left 52 | top = extra_h // 2 53 | bottom = extra_h - top 54 | x = F.pad(x, [left, right, top, bottom]) 55 | return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, 56 | self.dilation, self.groups) 57 | 58 | 59 | class MaxPool2dSamePadding(nn.Module): 60 | 61 | def __init__(self, 62 | kernel_size: Union[int, Tuple[int, int]] = 3, 63 | stride: Union[int, Tuple[int, int]] = 2, 64 | **kwargs): 65 | super().__init__() 66 | self.pool = nn.MaxPool2d(kernel_size, stride, **kwargs) 67 | self.stride = self.pool.stride 68 | self.kernel_size = self.pool.kernel_size 69 | 70 | if isinstance(self.stride, int): 71 | self.stride = [self.stride] * 2 72 | if isinstance(self.kernel_size, int): 73 | self.kernel_size = [self.kernel_size] * 2 74 | 75 | def forward(self, x): 76 | h, w = x.shape[-2:] 77 | 78 | extra_h = (math.ceil(w / self.stride[1]) - 79 | 1) * self.stride[1] - w + self.kernel_size[1] 80 | extra_v = (math.ceil(h / self.stride[0]) - 81 | 1) * self.stride[0] - h + self.kernel_size[0] 82 | 83 | left = extra_h // 2 84 | right = extra_h - left 85 | top = extra_v // 2 86 | bottom = extra_v - top 87 | 88 | x = F.pad(x, [left, right, top, bottom]) 89 | x = self.pool(x) 90 | 91 | return x 92 | 93 | 94 | class DepthWiseConvBlock(nn.Module): 95 | 96 | def __init__( 97 | self, 98 | in_channels: int, 99 | out_channels: int, 100 | apply_norm: bool = True, 101 | conv_bn_act_pattern: bool = False, 102 | norm_cfg: OptConfigType = dict(type='BN', momentum=1e-2, eps=1e-3) 103 | ) -> None: 104 | super(DepthWiseConvBlock, self).__init__() 105 | self.depthwise_conv = Conv2dSamePadding( 106 | in_channels, 107 | in_channels, 108 | kernel_size=3, 109 | stride=1, 110 | groups=in_channels, 111 | bias=False) 112 | self.pointwise_conv = Conv2dSamePadding( 113 | in_channels, out_channels, kernel_size=1, stride=1) 114 | 
115 | self.apply_norm = apply_norm 116 | if self.apply_norm: 117 | self.bn = build_norm_layer(norm_cfg, num_features=out_channels)[1] 118 | 119 | self.apply_activation = conv_bn_act_pattern 120 | if self.apply_activation: 121 | self.swish = Swish() 122 | 123 | def forward(self, x): 124 | x = self.depthwise_conv(x) 125 | x = self.pointwise_conv(x) 126 | if self.apply_norm: 127 | x = self.bn(x) 128 | if self.apply_activation: 129 | x = self.swish(x) 130 | 131 | return x 132 | 133 | 134 | class DownChannelBlock(nn.Module): 135 | 136 | def __init__( 137 | self, 138 | in_channels: int, 139 | out_channels: int, 140 | apply_norm: bool = True, 141 | conv_bn_act_pattern: bool = False, 142 | norm_cfg: OptConfigType = dict(type='BN', momentum=1e-2, eps=1e-3) 143 | ) -> None: 144 | super(DownChannelBlock, self).__init__() 145 | self.down_conv = Conv2dSamePadding(in_channels, out_channels, 1) 146 | self.apply_norm = apply_norm 147 | if self.apply_norm: 148 | self.bn = build_norm_layer(norm_cfg, num_features=out_channels)[1] 149 | self.apply_activation = conv_bn_act_pattern 150 | if self.apply_activation: 151 | self.swish = Swish() 152 | 153 | def forward(self, x): 154 | x = self.down_conv(x) 155 | if self.apply_norm: 156 | x = self.bn(x) 157 | if self.apply_activation: 158 | x = self.swish(x) 159 | 160 | return x 161 | 162 | 163 | # Copied from https://github.com/wenxi-yue/SurgicalSAM/blob/main/surgicalSAM/model.py 164 | class CUDA_ADAPOOL2d(Function): 165 | 166 | @staticmethod 167 | @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) 168 | def forward(ctx, input, beta, kernel=2, stride=None, return_mask=False): 169 | 170 | assert input.dtype == beta.dtype, '`input` and `beta` are not of the same dtype.' 171 | beta = torch.clamp(beta, 0., 1.) 172 | no_batch = False 173 | if len(input.size()) == 3: 174 | no_batch = True 175 | input.unsqueeze_(0) 176 | B, C, H, W = input.shape 177 | kernel = _pair(kernel) 178 | if stride is None: 179 | stride = kernel 180 | else: 181 | stride = _pair(stride) 182 | 183 | oH = (H - kernel[0]) // stride[0] + 1 184 | oW = (W - kernel[1]) // stride[1] + 1 185 | 186 | output = input.new_zeros((B, C, oH, oW)) 187 | if return_mask: 188 | mask = input.new_zeros((B, kernel[0] * oH, kernel[1] * oW)) 189 | else: 190 | mask = input.new_zeros((1)) 191 | 192 | adapool_cuda.forward_2d(input.contiguous(), beta, kernel, stride, output, return_mask, mask) 193 | ctx.save_for_backward(input, beta) 194 | ctx.kernel = kernel 195 | ctx.stride = stride 196 | if return_mask: 197 | mask_ = mask.detach().clone() 198 | mask_.requires_grad = False 199 | CUDA_ADAPOOL2d.mask = mask_ 200 | output = torch.nan_to_num(output) 201 | if no_batch: 202 | return output.squeeze_(0) 203 | return output 204 | 205 | @staticmethod 206 | @torch.cuda.amp.custom_bwd 207 | def backward(ctx, grad_output): 208 | 209 | grad_input = torch.zeros_like(ctx.saved_tensors[0]) 210 | grad_beta = torch.zeros_like(ctx.saved_tensors[1]) 211 | 212 | saved = [grad_output] + list(ctx.saved_tensors) + [ctx.kernel, ctx.stride, grad_input, grad_beta] 213 | adapool_cuda.backward_2d(*saved) 214 | 215 | return torch.nan_to_num(saved[-2]), torch.nan_to_num(saved[-1]), None, None, None 216 | 217 | 218 | class CUDA_ADAPOOL2d_EDSCW(Function): 219 | 220 | @staticmethod 221 | @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) 222 | def forward(ctx, input, kernel=2, stride=None, return_mask=False): 223 | no_batch = False 224 | if len(input.size()) == 3: 225 | no_batch = True 226 | input.unsqueeze_(0) 227 | B, C, H, W = input.shape 228 | kernel 
= _pair(kernel) 229 | if stride is None: 230 | stride = kernel 231 | else: 232 | stride = _pair(stride) 233 | 234 | oH = (H - kernel[0]) // stride[0] + 1 235 | oW = (W - kernel[1]) // stride[1] + 1 236 | 237 | output = input.new_zeros((B, C, oH, oW)) 238 | if return_mask: 239 | mask = input.new_zeros((B, kernel[0] * oH, kernel[1] * oW)) 240 | else: 241 | mask = input.new_zeros((1)) 242 | 243 | adapool_cuda.forward_2d_edscw(input.contiguous(), kernel, stride, output, return_mask, mask) 244 | ctx.save_for_backward(input) 245 | ctx.kernel = kernel 246 | ctx.stride = stride 247 | if return_mask: 248 | mask_ = mask.detach().clone() 249 | mask_.requires_grad = False 250 | CUDA_ADAPOOL2d_EDSCW.mask = mask_ 251 | output = torch.nan_to_num(output) 252 | if no_batch: 253 | return output.squeeze_(0) 254 | return output 255 | 256 | @staticmethod 257 | @torch.cuda.amp.custom_bwd 258 | def backward(ctx, grad_output): 259 | 260 | grad_input = torch.zeros_like(ctx.saved_tensors[0]) 261 | 262 | saved = [grad_output] + list(ctx.saved_tensors) + [ctx.kernel, ctx.stride, grad_input] 263 | adapool_cuda.backward_2d_edscw(*saved) 264 | 265 | return torch.nan_to_num(saved[-1]), None, None, None 266 | 267 | 268 | class CUDA_IDWPOOL2d(Function): 269 | 270 | @staticmethod 271 | @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) 272 | def forward(ctx, input, kernel=2, stride=None, return_mask=False): 273 | no_batch = False 274 | if len(input.size()) == 3: 275 | no_batch = True 276 | input.unsqueeze_(0) 277 | B, C, H, W = input.shape 278 | kernel = _pair(kernel) 279 | if stride is None: 280 | stride = kernel 281 | else: 282 | stride = _pair(stride) 283 | 284 | oH = (H - kernel[0]) // stride[0] + 1 285 | oW = (W - kernel[1]) // stride[1] + 1 286 | 287 | output = input.new_zeros((B, C, oH, oW)) 288 | if return_mask: 289 | mask = input.new_zeros((B, kernel[0] * oH, kernel[1] * oW)) 290 | else: 291 | mask = input.new_zeros((1)) 292 | 293 | adapool_cuda.forward_2d_idw(input.contiguous(), kernel, stride, output, return_mask, mask) 294 | ctx.save_for_backward(input) 295 | ctx.kernel = kernel 296 | ctx.stride = stride 297 | if return_mask: 298 | mask_ = mask.detach().clone() 299 | mask_.requires_grad = False 300 | CUDA_ADAPOOL2d_EDSCW.mask = mask_ 301 | output = torch.nan_to_num(output) 302 | if no_batch: 303 | return output.squeeze_(0) 304 | return output 305 | 306 | @staticmethod 307 | @torch.cuda.amp.custom_bwd 308 | def backward(ctx, grad_output): 309 | 310 | grad_input = torch.zeros_like(ctx.saved_tensors[0]) 311 | 312 | saved = [grad_output] + list(ctx.saved_tensors) + [ctx.kernel, ctx.stride, grad_input] 313 | adapool_cuda.backward_2d_idw(*saved) 314 | 315 | return torch.nan_to_num(saved[-1]), None, None, None 316 | 317 | 318 | class CUDA_ADAPOOL2d_EM(Function): 319 | 320 | @staticmethod 321 | @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) 322 | def forward(ctx, input, kernel=2, stride=None, return_mask=False): 323 | 324 | no_batch = False 325 | if len(input.size()) == 3: 326 | no_batch = True 327 | input.unsqueeze_(0) 328 | B, C, H, W = input.shape 329 | kernel = _pair(kernel) 330 | if stride is None: 331 | stride = kernel 332 | else: 333 | stride = _pair(stride) 334 | 335 | oH = (H - kernel[0]) // stride[0] + 1 336 | oW = (W - kernel[1]) // stride[1] + 1 337 | 338 | output = input.new_zeros((B, C, oH, oW)) 339 | if return_mask: 340 | mask = input.new_zeros((B, kernel[0] * oH, kernel[1] * oW)) 341 | else: 342 | mask = input.new_zeros((1)) 343 | 344 | adapool_cuda.forward_2d_em(input.contiguous(), kernel, 
stride, output, return_mask, mask) 345 | ctx.save_for_backward(input) 346 | ctx.kernel = kernel 347 | ctx.stride = stride 348 | if return_mask: 349 | mask_ = mask.detach().clone() 350 | mask_.requires_grad = False 351 | CUDA_ADAPOOL2d_EM.mask = mask_ 352 | output = torch.nan_to_num(output) 353 | if no_batch: 354 | return output.squeeze_(0) 355 | return output 356 | 357 | @staticmethod 358 | @torch.cuda.amp.custom_bwd 359 | def backward(ctx, grad_output): 360 | 361 | grad_input = torch.zeros_like(ctx.saved_tensors[0]) 362 | 363 | saved = [grad_output] + list(ctx.saved_tensors) + [ctx.kernel, ctx.stride, grad_input] 364 | adapool_cuda.backward_2d_em(*saved) 365 | 366 | return torch.nan_to_num(saved[-1]), None, None, None 367 | -------------------------------------------------------------------------------- /models/sam2/sam/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import List, Optional, Tuple, Type 8 | 9 | import torch 10 | from torch import nn 11 | 12 | from models.sam2.sam2_utils import LayerNorm2d, MLP 13 | 14 | 15 | class MaskDecoder(nn.Module): 16 | def __init__( 17 | self, 18 | *, 19 | transformer_dim: int, 20 | transformer: nn.Module, 21 | num_multimask_outputs: int = 3, 22 | activation: Type[nn.Module] = nn.GELU, 23 | iou_head_depth: int = 3, 24 | iou_head_hidden_dim: int = 256, 25 | use_high_res_features: bool = False, 26 | iou_prediction_use_sigmoid=False, 27 | dynamic_multimask_via_stability=False, 28 | dynamic_multimask_stability_delta=0.05, 29 | dynamic_multimask_stability_thresh=0.98, 30 | pred_obj_scores: bool = False, 31 | pred_obj_scores_mlp: bool = False, 32 | use_multimask_token_for_obj_ptr: bool = False, 33 | ) -> None: 34 | """ 35 | Predicts masks given an image and prompt embeddings, using a 36 | transformer architecture. 
37 | 38 | Arguments: 39 | transformer_dim (int): the channel dimension of the transformer 40 | transformer (nn.Module): the transformer used to predict masks 41 | num_multimask_outputs (int): the number of masks to predict 42 | when disambiguating masks 43 | activation (nn.Module): the type of activation to use when 44 | upscaling masks 45 | iou_head_depth (int): the depth of the MLP used to predict 46 | mask quality 47 | iou_head_hidden_dim (int): the hidden dimension of the MLP 48 | used to predict mask quality 49 | """ 50 | super().__init__() 51 | self.transformer_dim = transformer_dim 52 | self.transformer = transformer 53 | 54 | self.num_multimask_outputs = num_multimask_outputs 55 | 56 | self.iou_token = nn.Embedding(1, transformer_dim) 57 | self.num_mask_tokens = num_multimask_outputs + 1 58 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 59 | 60 | self.pred_obj_scores = pred_obj_scores 61 | if self.pred_obj_scores: 62 | self.obj_score_token = nn.Embedding(1, transformer_dim) 63 | self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr 64 | 65 | self.output_upscaling = nn.Sequential( 66 | nn.ConvTranspose2d( 67 | transformer_dim, transformer_dim // 4, kernel_size=2, stride=2 68 | ), 69 | LayerNorm2d(transformer_dim // 4), 70 | activation(), 71 | nn.ConvTranspose2d( 72 | transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2 73 | ), 74 | activation(), 75 | ) 76 | self.use_high_res_features = use_high_res_features 77 | if use_high_res_features: 78 | self.conv_s0 = nn.Conv2d( 79 | transformer_dim, transformer_dim // 8, kernel_size=1, stride=1 80 | ) 81 | self.conv_s1 = nn.Conv2d( 82 | transformer_dim, transformer_dim // 4, kernel_size=1, stride=1 83 | ) 84 | 85 | self.output_hypernetworks_mlps = nn.ModuleList( 86 | [ 87 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 88 | for i in range(self.num_mask_tokens) 89 | ] 90 | ) 91 | 92 | self.iou_prediction_head = MLP( 93 | transformer_dim, 94 | iou_head_hidden_dim, 95 | self.num_mask_tokens, 96 | iou_head_depth, 97 | sigmoid_output=iou_prediction_use_sigmoid, 98 | ) 99 | if self.pred_obj_scores: 100 | self.pred_obj_score_head = nn.Linear(transformer_dim, 1) 101 | if pred_obj_scores_mlp: 102 | self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3) 103 | 104 | # When outputting a single mask, optionally we can dynamically fall back to the best 105 | # multimask output token if the single mask output token gives low stability scores. 106 | self.dynamic_multimask_via_stability = dynamic_multimask_via_stability 107 | self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta 108 | self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh 109 | 110 | def forward( 111 | self, 112 | image_embeddings: torch.Tensor, 113 | image_pe: torch.Tensor, 114 | sparse_prompt_embeddings: torch.Tensor, 115 | dense_prompt_embeddings: torch.Tensor, 116 | multimask_output: bool, 117 | repeat_image: bool, 118 | high_res_features: Optional[List[torch.Tensor]] = None, 119 | ) -> Tuple[torch.Tensor, torch.Tensor]: 120 | """ 121 | Predict masks given image and prompt embeddings. 
122 | 123 | Arguments: 124 | image_embeddings (torch.Tensor): the embeddings from the image encoder 125 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 126 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 127 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 128 | multimask_output (bool): Whether to return multiple masks or a single 129 | mask. 130 | 131 | Returns: 132 | torch.Tensor: batched predicted masks 133 | torch.Tensor: batched predictions of mask quality 134 | torch.Tensor: batched sam token for mask output 135 | """ 136 | masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks( 137 | image_embeddings=image_embeddings, 138 | image_pe=image_pe, 139 | sparse_prompt_embeddings=sparse_prompt_embeddings, 140 | dense_prompt_embeddings=dense_prompt_embeddings, 141 | repeat_image=repeat_image, 142 | high_res_features=high_res_features, 143 | ) 144 | 145 | # Select the correct mask or masks for output 146 | if multimask_output: 147 | masks = masks[:, 1:, :, :] 148 | iou_pred = iou_pred[:, 1:] 149 | elif self.dynamic_multimask_via_stability and not self.training: 150 | masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred) 151 | else: 152 | masks = masks[:, 0:1, :, :] 153 | iou_pred = iou_pred[:, 0:1] 154 | 155 | if multimask_output and self.use_multimask_token_for_obj_ptr: 156 | sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape 157 | else: 158 | # Take the mask output token. Here we *always* use the token for single mask output. 159 | # At test time, even if we track after 1-click (and using multimask_output=True), 160 | # we still take the single mask token here. The rationale is that we always track 161 | # after multiple clicks during training, so the past tokens seen during training 162 | # are always the single mask token (and we'll let it be the object-memory token). 163 | sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape 164 | 165 | # Prepare output 166 | return masks, iou_pred, sam_tokens_out, object_score_logits 167 | 168 | def predict_masks( 169 | self, 170 | image_embeddings: torch.Tensor, 171 | image_pe: torch.Tensor, 172 | sparse_prompt_embeddings: torch.Tensor, 173 | dense_prompt_embeddings: torch.Tensor, 174 | repeat_image: bool, 175 | high_res_features: Optional[List[torch.Tensor]] = None, 176 | ) -> Tuple[torch.Tensor, torch.Tensor]: 177 | """Predicts masks. 
See 'forward' for more details.""" 178 | # Concatenate output tokens 179 | s = 0 180 | if self.pred_obj_scores: 181 | output_tokens = torch.cat( 182 | [ 183 | self.obj_score_token.weight, 184 | self.iou_token.weight, 185 | self.mask_tokens.weight, 186 | ], 187 | dim=0, 188 | ) 189 | s = 1 190 | else: 191 | output_tokens = torch.cat( 192 | [self.iou_token.weight, self.mask_tokens.weight], dim=0 193 | ) 194 | output_tokens = output_tokens.unsqueeze(0).expand( 195 | sparse_prompt_embeddings.size(0), -1, -1 196 | ) 197 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 198 | 199 | # Expand per-image data in batch direction to be per-mask 200 | if repeat_image: 201 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 202 | else: 203 | assert image_embeddings.shape[0] == tokens.shape[0] 204 | src = image_embeddings 205 | src = src + dense_prompt_embeddings 206 | assert ( 207 | image_pe.size(0) == 1 208 | ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)" 209 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 210 | b, c, h, w = src.shape 211 | 212 | # Run the transformer 213 | hs, src = self.transformer(src, pos_src, tokens) 214 | iou_token_out = hs[:, s, :] 215 | mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :] 216 | 217 | # Upscale mask embeddings and predict masks using the mask tokens 218 | src = src.transpose(1, 2).view(b, c, h, w) 219 | if not self.use_high_res_features: 220 | upscaled_embedding = self.output_upscaling(src) 221 | else: 222 | dc1, ln1, act1, dc2, act2 = self.output_upscaling 223 | feat_s0, feat_s1 = high_res_features 224 | upscaled_embedding = act1(ln1(dc1(src) + feat_s1)) 225 | upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0) 226 | 227 | hyper_in_list: List[torch.Tensor] = [] 228 | for i in range(self.num_mask_tokens): 229 | hyper_in_list.append( 230 | self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) 231 | ) 232 | hyper_in = torch.stack(hyper_in_list, dim=1) 233 | b, c, h, w = upscaled_embedding.shape 234 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 235 | 236 | # Generate mask quality predictions 237 | iou_pred = self.iou_prediction_head(iou_token_out) 238 | if self.pred_obj_scores: 239 | assert s == 1 240 | object_score_logits = self.pred_obj_score_head(hs[:, 0, :]) 241 | else: 242 | # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1 243 | object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1) 244 | 245 | return masks, iou_pred, mask_tokens_out, object_score_logits 246 | 247 | def _get_stability_scores(self, mask_logits): 248 | """ 249 | Compute stability scores of the mask logits based on the IoU between upper and 250 | lower thresholds. 251 | """ 252 | mask_logits = mask_logits.flatten(-2) 253 | stability_delta = self.dynamic_multimask_stability_delta 254 | area_i = torch.sum(mask_logits > stability_delta, dim=-1).float() 255 | area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float() 256 | stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0) 257 | return stability_scores 258 | 259 | def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores): 260 | """ 261 | When outputting a single mask, if the stability score from the current single-mask 262 | output (based on output token 0) falls below a threshold, we instead select from 263 | multi-mask outputs (based on output token 1~3) the mask with the highest predicted 264 | IoU score. 
This is intended to ensure a valid mask for both clicking and tracking. 265 | """ 266 | # The best mask from multimask output tokens (1~3) 267 | multimask_logits = all_mask_logits[:, 1:, :, :] 268 | multimask_iou_scores = all_iou_scores[:, 1:] 269 | best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1) 270 | batch_inds = torch.arange( 271 | multimask_iou_scores.size(0), device=all_iou_scores.device 272 | ) 273 | best_multimask_logits = multimask_logits[batch_inds, best_scores_inds] 274 | best_multimask_logits = best_multimask_logits.unsqueeze(1) 275 | best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds] 276 | best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1) 277 | 278 | # The mask from singlemask output token 0 and its stability score 279 | singlemask_logits = all_mask_logits[:, 0:1, :, :] 280 | singlemask_iou_scores = all_iou_scores[:, 0:1] 281 | stability_scores = self._get_stability_scores(singlemask_logits) 282 | is_stable = stability_scores >= self.dynamic_multimask_stability_thresh 283 | 284 | # Dynamically fall back to best multimask output upon low stability scores. 285 | mask_logits_out = torch.where( 286 | is_stable[..., None, None].expand_as(singlemask_logits), 287 | singlemask_logits, 288 | best_multimask_logits, 289 | ) 290 | iou_scores_out = torch.where( 291 | is_stable.expand_as(singlemask_iou_scores), 292 | singlemask_iou_scores, 293 | best_multimask_iou_scores, 294 | ) 295 | return mask_logits_out, iou_scores_out --------------------------------------------------------------------------------
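The dynamic selection implemented in _get_stability_scores and _dynamic_multimask_via_stability above can be read in isolation: the decoder keeps the single-mask output (token 0) only while its stability score, the IoU between the mask thresholded at +delta and -delta, reaches a threshold, and otherwise substitutes the multimask candidate with the highest predicted IoU. Below is a minimal sketch of that selection step in plain PyTorch; the function names, the toy tensor shapes, and the hard-coded delta/threshold defaults are illustrative assumptions, not part of the repository.

import torch


def stability_scores(mask_logits: torch.Tensor, delta: float = 0.05) -> torch.Tensor:
    # IoU between the masks thresholded at +delta and -delta; defaults to 1.0 when both are empty.
    flat = mask_logits.flatten(-2)
    area_i = torch.sum(flat > delta, dim=-1).float()
    area_u = torch.sum(flat > -delta, dim=-1).float()
    return torch.where(area_u > 0, area_i / area_u, torch.ones_like(area_u))


def select_single_mask(all_mask_logits: torch.Tensor,
                       all_iou_scores: torch.Tensor,
                       delta: float = 0.05,
                       thresh: float = 0.98):
    # all_mask_logits: [B, 1 + num_multimask, H, W]; token 0 is the single-mask output.
    single_logits = all_mask_logits[:, 0:1]
    single_iou = all_iou_scores[:, 0:1]

    # Best multimask candidate (tokens 1..N) by predicted IoU.
    multi_logits = all_mask_logits[:, 1:]
    multi_iou = all_iou_scores[:, 1:]
    best = torch.argmax(multi_iou, dim=-1)
    batch = torch.arange(multi_iou.size(0), device=all_iou_scores.device)
    best_logits = multi_logits[batch, best].unsqueeze(1)
    best_iou = multi_iou[batch, best].unsqueeze(1)

    # Keep the single-mask output where it is stable, otherwise fall back.
    is_stable = stability_scores(single_logits, delta) >= thresh
    masks = torch.where(is_stable[..., None, None].expand_as(single_logits),
                        single_logits, best_logits)
    ious = torch.where(is_stable.expand_as(single_iou), single_iou, best_iou)
    return masks, ious


# Toy usage: batch of 2, one single-mask token plus 3 multimask candidates.
logits = torch.randn(2, 4, 64, 64)
ious = torch.rand(2, 4)
masks, scores = select_single_mask(logits, ious)
print(masks.shape, scores.shape)  # torch.Size([2, 1, 64, 64]) torch.Size([2, 1])

In the decoder itself this fallback only runs at inference time, when dynamic_multimask_via_stability is enabled and a single mask is requested (multimask_output=False).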