├── models
│   ├── __init__.py
│   ├── sam2
│   │   ├── sam
│   │   │   ├── __init__.py
│   │   │   ├── prompt_encoder.py
│   │   │   ├── transformer.py
│   │   │   └── mask_decoder.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   └── transformers.py
│   │   ├── __init__.py
│   │   ├── backbones
│   │   │   ├── __init__.py
│   │   │   ├── utils.py
│   │   │   ├── image_encoder.py
│   │   │   └── hieradet.py
│   │   ├── configs
│   │   │   └── sam2
│   │   │       ├── sam2_hiera_b+.yaml
│   │   │       ├── sam2_hiera_s.yaml
│   │   │       ├── sam2_hiera_l.yaml
│   │   │       └── sam2_hiera_t.yaml
│   │   ├── memory_attention.py
│   │   ├── memory_encoder.py
│   │   ├── build_sam.py
│   │   └── position_encoding.py
│   ├── components
│   │   ├── __init__.py
│   │   ├── loss.py
│   │   ├── lr_scheduler.py
│   │   ├── semantic_extraction.py
│   │   └── utils.py
│   └── utils.py
├── datasets
│   ├── __init__.py
│   ├── data_load_demo.py
│   ├── utils.py
│   ├── data_process.py
│   ├── train_data_prepare.py
│   ├── data_loader.py
│   └── dataset_info.txt
├── visualization
│   ├── __init__.py
│   ├── visualization_area.py
│   ├── utils.py
│   ├── visualization_slice.py
│   ├── visualization_3d.py
│   └── visualization_config.py
├── static
│   ├── overview.png
│   ├── visualization.png
│   └── visualization_appendix.png
├── scripts
│   ├── train.sh
│   └── test.sh
├── README.md
├── test.py
└── env_config.yml
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/sam2/sam/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/visualization/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/components/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/sam2/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/static/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YU-deep/CRISP_SAM2/HEAD/static/overview.png
--------------------------------------------------------------------------------
/static/visualization.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YU-deep/CRISP_SAM2/HEAD/static/visualization.png
--------------------------------------------------------------------------------
/static/visualization_appendix.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YU-deep/CRISP_SAM2/HEAD/static/visualization_appendix.png
--------------------------------------------------------------------------------
/models/sam2/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
--------------------------------------------------------------------------------
/models/sam2/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
1 | export CKPT="your_path_to_checkpoint"
2 | export WORK_DIR="your_path_to_work_dir"
3 | export DATA_DIR="your_path_to_datasets"
4 | export CLIP_TEXT_CKPT="path_to_clip_text_encoder_checkpoint"
5 | export CLIP_IMAGE_CKPT="path_to_clip_image_encoder_checkpoint"
6 | export CONFIG_FILE="path_to_sam2_config_file"
7 | export SAM2_CKPT="path_to_sam2_checkpoint"
8 | # export HF_TOKEN=xxxxx
9 | # export MASTER_ADDR=xxx.xxx.xxx.xxx
10 | # export MASTER_PORT=xxxx
11 |
12 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train.py \
13 | --resume $CKPT \
14 | --work_dir $WORK_DIR \
15 | --data_dir $DATA_DIR \
16 | --clip_text_ckpt $CLIP_TEXT_CKPT \
17 | --clip_image_ckpt $CLIP_IMAGE_CKPT \
18 | --config_file $CONFIG_FILE \
19 | --sam2_ckpt $SAM2_CKPT
--------------------------------------------------------------------------------
/scripts/test.sh:
--------------------------------------------------------------------------------
1 | export PRETRAINED="your_path_to_checkpoint"
2 | export DATA_DIR="your_path_to_datasets"
3 | export WORK_DIR="your_path_to_work_dir"
4 | export RESULT_DIR="your_path_to_result_dir"
5 | export CLIP_TEXT_CKPT="path_to_clip_text_encoder_checkpoint"
6 | export CLIP_IMAGE_CKPT="path_to_clip_image_encoder_checkpoint"
7 | export CONFIG_FILE="path_to_sam2_config_file"
8 | export SAM2_CKPT="path_to_sam2_checkpoint"
9 | # export HF_TOKEN=xxxxx
10 | # export MASTER_ADDR=xxx.xxx.xxx.xxx
11 | # export MASTER_PORT=xxxx
12 |
13 |
14 | python test.py \
15 | --pretrain $PRETRAINED \
16 | --data_dir $DATA_DIR \
17 | --work_dir $WORK_DIR \
18 | --result_dir $RESULT_DIR \
19 | --clip_text_ckpt $CLIP_TEXT_CKPT \
20 | --clip_image_ckpt $CLIP_IMAGE_CKPT \
21 | --config_file $CONFIG_FILE \
22 | --sam2_ckpt $SAM2_CKPT
23 |
--------------------------------------------------------------------------------
/visualization/visualization_area.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 |
3 |
4 | def zoom_in_one_area(image_path, output_path, box, zoom_factor):
5 | image = Image.open(image_path)
6 | left, upper, right, lower = box
7 | assert all(0 <= v <= 1 for v in box), "box coordinates must be normalized to [0, 1]"
8 | left = left * image.size[0]
9 | upper = upper * image.size[1]
10 | right = right * image.size[0]
11 | lower = lower * image.size[1]
12 | # print(image.size)
13 | cropped_image = image.crop((left, upper, right, lower))
14 | new_size = (int(cropped_image.width * zoom_factor), int(cropped_image.height * zoom_factor))
15 | zoomed_image = cropped_image.resize(new_size, Image.LANCZOS)
16 | zoomed_image.save(output_path)
17 |
18 |
19 | image_path = '../output/...'
20 | output_path = '../output/...'
21 |
22 |
23 | box = (0.6, 0.6, 0.8, 0.8) # (x_upper_left, y_upper_left, x_lower_right, y_lower_right) should be [0, 1]
24 | zoom_factor = 2
25 |
26 | zoom_in_one_area(image_path, output_path, box, zoom_factor)
27 |
--------------------------------------------------------------------------------
/datasets/data_load_demo.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy import sparse
3 | import ast
4 | import os
5 | import json
6 |
7 | """
8 | You should download the whole dataset first at: https://huggingface.co/datasets/GoodBaiBai88/M3D-Seg
9 | Then, use this loader tool to load the dataset.
10 | """
11 | uniseg_path = '/PATH/M3D_Seg' # PATH : your path
12 | dataset_code = '0001'
13 | json_path = os.path.join(uniseg_path, dataset_code, dataset_code + '.json')
14 | with open(json_path, 'r') as f:
15 | dataset_dict = json.load(f)
16 |
17 | ct_file_path = os.path.join(uniseg_path, dataset_dict['train'][0]['image'])
18 | gt_file_path = os.path.join(uniseg_path, dataset_dict['train'][0]['label'])
19 |
20 | img_array = np.load(ct_file_path)[0]
21 | # print('img_array.shape ', img_array.shape)
22 |
23 | allmatrix_sp = sparse.load_npz(gt_file_path)
24 | gt_shape = ast.literal_eval(gt_file_path.split('.')[-2].split('_')[-1])
25 | gt_array = allmatrix_sp.toarray().reshape(gt_shape)
26 | # print('gt_array.shape ', gt_array.shape)
27 |
--------------------------------------------------------------------------------
/datasets/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import nibabel as nib
3 |
4 |
5 | def split_left_right(file_name, output_path):
6 | nifti_img = nib.load(file_name)
7 | segmentation_data = nifti_img.get_fdata()
8 |
9 | LABEL = 2
10 | LEFT_LABEL = 5
11 | RIGHT_LABEL = 6
12 |
13 | kidney_indices = np.argwhere(segmentation_data == LABEL)
14 | new_segmentation_data = np.copy(segmentation_data)
15 |
16 | for idx in kidney_indices:
17 | x, y, z = idx
18 | # determine left or right based on x axis
19 | if x < segmentation_data.shape[0] // 2:
20 | new_segmentation_data[x, y, z] = LEFT_LABEL
21 | else:
22 | new_segmentation_data[x, y, z] = RIGHT_LABEL
23 |
24 | new_nifti_img = nib.Nifti1Image(new_segmentation_data, nifti_img.affine)
25 | nib.save(new_nifti_img, output_path)
26 |
27 |
28 | def generate_point_prompt(mask):
29 | rows, cols = np.where(mask)
30 | x_min, x_max = cols.min(), cols.max()
31 | y_min, y_max = rows.min(), rows.max()
32 | x = x_max - x_min
33 | y = y_max - y_min
34 |
35 | while True:
36 | i0, j0 = np.random.choice(rows), np.random.choice(cols)
37 | if (j0 + int(0.1 * x) < mask.shape[1] and mask[i0, j0 + int(0.1 * x)] and
38 | j0 - int(0.1 * x) >= 0 and mask[i0, j0 - int(0.1 * x)] and
39 | i0 + int(0.1 * y) < mask.shape[0] and mask[i0 + int(0.1 * y), j0] and
40 | i0 - int(0.1 * y) >= 0 and mask[i0 - int(0.1 * y), j0]):
41 | break
42 | return (i0, j0)
43 |
44 |
45 | def generate_bbox_prompt(mask):
46 | rows, cols = np.where(mask)
47 | i1 = int(np.mean(rows))
48 | j1 = int(np.mean(cols))
49 |
50 | x_min, x_max = cols.min(), cols.max()
51 | y_min, y_max = rows.min(), rows.max()
52 | x = x_max - x_min
53 | y = y_max - y_min
54 |
55 | t1 = np.random.uniform(0.1, 0.3)
56 | t2 = np.random.uniform(0.1, 0.3)
57 |
58 | x1 = int(j1 - t1 * x)  # x / width jitter around the column centre j1
59 | y1 = int(i1 - t2 * y)  # y / height jitter around the row centre i1
60 | x2 = int(j1 + t1 * x)
61 | y2 = int(i1 + t2 * y)
62 |
63 | return (x1, y1, x2, y2)
64 |
--------------------------------------------------------------------------------
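A minimal usage sketch (not part of the repository) for the prompt helpers above, assuming `datasets/utils.py` is importable as `datasets.utils` from the project root:

```python
import numpy as np

from datasets.utils import generate_bbox_prompt, generate_point_prompt

# Toy binary mask with a rectangular foreground region standing in for an organ.
mask = np.zeros((64, 64), dtype=bool)
mask[20:40, 25:45] = True

point = generate_point_prompt(mask)   # (row, col) of a point well inside the mask
bbox = generate_bbox_prompt(mask)     # jittered (x1, y1, x2, y2) box around the mask centre
print("point prompt:", point)
print("bbox prompt:", bbox)
```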
/models/components/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class CrispSamLoss(nn.Module):
7 | def __init__(self, alpha=0.5, beta=0.1, gamma=0.1, omega1=20 / 21, omega2=1 / 21):
8 | super().__init__()
9 | self.alpha = nn.Parameter(torch.tensor(alpha, requires_grad=True))
10 | self.beta = nn.Parameter(torch.tensor(beta, requires_grad=True))
11 | self.gamma = nn.Parameter(torch.tensor(gamma, requires_grad=True))
12 | self.omega1 = omega1
13 | self.omega2 = omega2
14 |
15 | def focal_loss(self, inputs, targets, gamma=2):
16 | BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
17 | pt = torch.exp(-BCE_loss)
18 | F_loss = (1 - pt) ** gamma * BCE_loss
19 | return F_loss.mean()
20 |
21 | def dice_loss(self, inputs, targets):
22 | inputs = torch.sigmoid(inputs)
23 | inputs = inputs.view(-1)
24 | targets = targets.view(-1)
25 | intersection = (inputs * targets).sum()
26 | dice = (2. * intersection + 1e-6) / (inputs.sum() + targets.sum() + 1e-6)  # smoothing avoids division by zero on empty masks
27 | return 1 - dice
28 |
29 | def forward(self, masks_original_pred, masks_refined_pred, masks_gt, iou_pred=None, iou_gt=None, obj_pred=None, obj_gt=None):
30 |
31 | L_focal_original = self.focal_loss(masks_original_pred, masks_gt)
32 | L_dice_original = self.dice_loss(masks_original_pred, masks_gt)
33 | L_seg_original = self.omega1 * L_focal_original + self.omega2 * L_dice_original
34 |
35 | L_focal_refined = self.focal_loss(masks_refined_pred, masks_gt)
36 | L_dice_refined = self.dice_loss(masks_refined_pred, masks_gt)
37 | L_seg_refined = self.omega1 * L_focal_refined + self.omega2 * L_dice_refined
38 |
39 | if iou_pred is not None:
40 | if iou_gt is None:
41 | iou_gt = torch.ones_like(iou_pred)
42 | L_mae = F.l1_loss(iou_pred, iou_gt)
43 | else:
44 | L_mae = torch.tensor(0.0, requires_grad=False)
45 |
46 | if obj_pred is not None:
47 | if obj_gt is None:
48 | obj_gt = torch.ones(obj_pred.size(0), dtype=torch.long, device=obj_pred.device)
49 | L_ce = F.cross_entropy(obj_pred, obj_gt)
50 | else:
51 | L_ce = torch.tensor(0.0, requires_grad=False)
52 |
53 | total_loss = self.alpha * L_seg_original + (
54 | 1 - self.alpha) * L_seg_refined + self.beta * L_mae + self.gamma * L_ce
55 | return L_seg_original, L_seg_refined, total_loss
56 |
57 |
--------------------------------------------------------------------------------
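A minimal sketch (not part of the repository) showing how `CrispSamLoss` can be exercised on random tensors; the shapes are illustrative only:

```python
import torch

from models.components.loss import CrispSamLoss

criterion = CrispSamLoss()

# Dummy logits and binary ground truth for a batch of 2 single-channel masks.
masks_gt = (torch.rand(2, 1, 64, 64) > 0.5).float()
masks_original_pred = torch.randn(2, 1, 64, 64, requires_grad=True)
masks_refined_pred = torch.randn(2, 1, 64, 64, requires_grad=True)
iou_pred = torch.rand(2, 1)    # predicted IoU scores
obj_pred = torch.randn(2, 2)   # object / no-object logits

l_seg_original, l_seg_refined, total_loss = criterion(
    masks_original_pred, masks_refined_pred, masks_gt,
    iou_pred=iou_pred, obj_pred=obj_pred,
)
total_loss.backward()
print(l_seg_original.item(), l_seg_refined.item(), total_loss.item())
```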
/visualization/utils.py:
--------------------------------------------------------------------------------
1 | import vtk
2 | from visualization_config import *
3 |
4 |
5 | def read_volume(file_name):
6 | if file_name.endswith(".nii.gz"):
7 | reader = vtk.vtkNIFTIImageReader()
8 | elif file_name.endswith(".nrrd"):
9 | reader = vtk.vtkNrrdReader()
10 | reader.SetFileNameSliceOffset(1)
11 | reader.SetDataByteOrderToBigEndian()
12 | reader.SetFileName(file_name)
13 | reader.Update()
14 | return reader
15 |
16 |
17 | def create_mask_extractor(reader):
18 | """
19 | Given a reader (vtkNIFTIImageReader or vtkNrrdReader) for a segmented volume,
20 | extract its surface with the vtkDiscreteMarchingCubes algorithm (https://www.vtk.org/doc/release/5.0/html/a01331.html).
21 | This works for labelled segmentation volumes such as FLARE22 or AbdomenCT-1k.
22 | :param reader: reader whose output contains the segmentation mask
23 | :return: the vtkDiscreteMarchingCubes filter connected to the reader output
24 | """
25 | mask_extractor = vtk.vtkDiscreteMarchingCubes()
26 | mask_extractor.SetInputConnection(reader.GetOutputPort())
27 | return mask_extractor
28 |
29 |
30 | def create_smoother(reducer, smooth_factor):
31 | """
32 | Reorients some points in the volume to smooth the render edges.
33 | (https://www.vtk.org/doc/nightly/html/classvtkSmoothPolyDataFilter.html)
34 | """
35 | smoother = vtk.vtkSmoothPolyDataFilter()
36 | smoother.SetInputConnection(reducer.GetOutputPort())
37 | smoother.SetNumberOfIterations(smooth_factor)
38 | smoother.BoundarySmoothingOn()
39 | return smoother
40 |
41 |
42 | def create_mapper(stripper):
43 | mapper = vtk.vtkPolyDataMapper()
44 | mapper.SetInputConnection(stripper.GetOutputPort())
45 | mapper.ScalarVisibilityOff()
46 | return mapper
47 |
48 |
49 | def create_property(opacity=0.9, color=(1.0, 0.0, 0.0)):
50 | prop = vtk.vtkProperty()
51 | prop.SetColor(color[0], color[1], color[2])
52 | prop.SetOpacity(opacity)
53 | # prop.SetRepresentationToWireframe()
54 | return prop
55 |
56 |
57 | def create_actor(mapper, prop):
58 | actor = vtk.vtkActor()
59 | actor.SetMapper(mapper)
60 | actor.SetProperty(prop)
61 | return actor
62 |
63 |
64 | def create_renderer(bg_color):
65 | renderer = vtk.vtkRenderer()
66 | renderer.SetBackground(bg_color[0], bg_color[1], bg_color[2])
67 | return renderer
68 |
69 |
70 | def create_renderwindow(window_name=APPLICATION_TITLE, window_size=(600, 600)):
71 | render_window = vtk.vtkRenderWindow()
72 | render_window.SetWindowName(window_name)
73 | render_window.SetSize(window_size[0], window_size[1])
74 | return render_window
75 |
--------------------------------------------------------------------------------
/models/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from sklearn.metrics.pairwise import cosine_similarity
4 |
5 |
6 |
7 | def calculate_similarity_scores(items):
8 | num_slices = items.shape[0]
9 | similarity_scores = []
10 | for i in range(num_slices):
11 | current_slice = items[i].flatten().reshape(1, -1)
12 | other_slices = np.delete(items, i, axis=0).reshape(-1, items.shape[1] * items.shape[2])
13 | similarities = cosine_similarity(current_slice, other_slices)
14 | score = similarities.sum()
15 | similarity_scores.append(score)
16 | return np.array(similarity_scores)
17 |
18 |
19 | # todo : rank the input
20 | def rank_slices(slices):
21 | similarity_scores = calculate_similarity_scores(slices)
22 | ranked_indices = np.argsort(similarity_scores)[::-1]
23 | ranked_slices = slices[ranked_indices]
24 | return ranked_slices
25 |
26 |
27 | def rank_cond_frames(cond_frame_outputs):
28 | features = []
29 | for cond_frame_index in cond_frame_outputs:
30 | features.append(cond_frame_outputs[cond_frame_index]["maskmem_features"])
31 | similarity_scores = calculate_similarity_scores(np.array(features))
32 | ranked_keys = [list(cond_frame_outputs)[i] for i in np.argsort(similarity_scores)[::-1]]
33 | ranked_cond_frame_outputs = {k: cond_frame_outputs[k] for k in ranked_keys}
34 | return ranked_cond_frame_outputs
35 |
36 |
37 | def select_similar_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
38 | """
39 | Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs`
40 | ranked by mutual feature similarity; any remaining slots are filled with the
41 | conditioning frames temporally closest to `frame_idx`.
42 |
43 | Outputs:
44 | - selected_outputs: selected items (keys & values) from `cond_frame_outputs`.
45 | - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`.
46 | """
47 | if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:
48 | selected_outputs = cond_frame_outputs
49 | unselected_outputs = {}
50 | else:
51 | assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"
52 | selected_outputs = {}
53 | ranked_cond_frame_outputs = rank_cond_frames(cond_frame_outputs)
54 | for cond_frame_index in list(ranked_cond_frame_outputs)[:max_cond_frame_num]:
55 | selected_outputs[cond_frame_index] = cond_frame_outputs[cond_frame_index]
56 |
57 | # add the temporally closest conditioning frames until reaching a total
58 | # of `max_cond_frame_num` conditioning frames.
59 | num_remain = max_cond_frame_num - len(selected_outputs)
60 | inds_remain = sorted(
61 | (t for t in cond_frame_outputs if t not in selected_outputs),
62 | key=lambda x: abs(x - frame_idx),
63 | )[:num_remain]
64 | selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain)
65 | unselected_outputs = {
66 | t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs
67 | }
68 |
69 | return selected_outputs, unselected_outputs
70 |
--------------------------------------------------------------------------------
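A minimal sketch (not part of the repository) of `select_similar_cond_frames` on dummy memory features; the dictionary keys stand in for frame indices and the `maskmem_features` are random arrays:

```python
import numpy as np

from models.utils import select_similar_cond_frames

rng = np.random.default_rng(0)
# Six fake conditioning frames, each with a small 2D "memory feature" map.
cond_frame_outputs = {
    t: {"maskmem_features": rng.standard_normal((4, 4))} for t in range(6)
}

selected, unselected = select_similar_cond_frames(
    frame_idx=3, cond_frame_outputs=cond_frame_outputs, max_cond_frame_num=2
)
print("selected frames:", sorted(selected))
print("unselected frames:", sorted(unselected))
```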
/visualization/visualization_slice.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import matplotlib.pyplot as plt
4 | import nibabel as nib
5 | import nrrd
6 | import numpy as np
7 | from matplotlib.colors import ListedColormap
8 |
9 | from visualization_config import get_mask_colors
10 |
11 |
12 | def get_slice_image(image_path, label_path, label_slice_index, image_slice_index):
13 | image_nii = nib.load(image_path)
14 | image_data = image_nii.get_fdata()
15 | if label_path.endswith("nii.gz"):
16 | label_nii = nib.load(label_path)
17 | label_data = label_nii.get_fdata()
18 | else:
19 | label_data, _ = nrrd.read(label_path)
20 |
21 | unique_labels = set(label_data.flatten())
22 | label_count = len(unique_labels)
23 | print(f"Label Count : {label_count}")
24 |
25 | # get slice
26 | image_slice = image_data[:, :, image_slice_index]
27 | label_slice = label_data[:, :, label_slice_index]
28 | image_slice = np.rot90(image_slice, k=1)
29 | label_slice = np.rot90(label_slice, k=1)
30 | return image_slice, label_slice
31 |
32 |
33 | def download_image(dataset_name, image_slice, label_slice, output_path):
34 | custom_colors = get_mask_colors(dataset_name)
35 | color_count = len(custom_colors)
36 | print(f"Color Count : {color_count}")
37 | # custom_colors = [(0, 0, 0, 0)] + custom_colors
38 |
39 | unique_labels = np.unique(label_slice)
40 | unique_labels = unique_labels.astype(int)
41 | unique_labels = unique_labels[unique_labels != 0]
42 | # print(unique_labels)
43 |
44 | actual_colors = [custom_colors[i - 1] for i in unique_labels]
45 | cmap = ListedColormap(actual_colors)
46 |
47 | new_label_slice = np.zeros_like(label_slice)
48 | for i, label in enumerate(unique_labels):
49 | new_label_slice[label_slice == label] = i + 1
50 |
51 | fig, ax = plt.subplots()
52 | ax.axis('off')
53 | plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0, wspace=0)
54 | plt.imshow(image_slice, cmap='gray', interpolation='bilinear')
55 | new_label_slice = np.ma.masked_where(new_label_slice == 0, new_label_slice)
56 | plt.imshow(new_label_slice, cmap=cmap, alpha=0.75)
57 | plt.axis('off')
58 | plt.savefig(output_path, bbox_inches='tight', pad_inches=0, transparent=True)
59 | plt.show()
60 |
61 |
62 | if __name__ == '__main__':
63 | dataset_name = "..."
64 | image_path = '../input/...'
65 | label_path = '../input/...'
66 | if label_path.endswith("seg.nrrd"):
67 | model_name = label_path.split("/")[-1].split(".")[0].split("_")[-1]
68 | else:
69 | model_name = "GT"
70 | output_name = image_path.split('/')[3].split('.')[0]
71 | output_path = os.path.join('../output/.../', output_name)
72 | if not os.path.exists(output_path):
73 | os.makedirs(output_path)
74 |
75 | png_path = os.path.join(output_path, model_name + ".png")
76 | print(f"output path : {png_path}")
77 | label_slice_index = 67
78 | image_slice_index = 67
79 |
80 | image_slice, label_slice = get_slice_image(image_path=image_path, label_path=label_path,
81 | image_slice_index=image_slice_index, label_slice_index=label_slice_index)
82 | download_image(dataset_name, image_slice, label_slice, png_path)
83 |
--------------------------------------------------------------------------------
/models/sam2/backbones/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """Some utilities for backbones, in particular for windowing"""
8 |
9 | from typing import Tuple
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 |
15 |
16 | def window_partition(x, window_size):
17 | """
18 | Partition into non-overlapping windows with padding if needed.
19 | Args:
20 | x (tensor): input tokens with [B, H, W, C].
21 | window_size (int): window size.
22 | Returns:
23 | windows: windows after partition with [B * num_windows, window_size, window_size, C].
24 | (Hp, Wp): padded height and width before partition
25 | """
26 | B, H, W, C = x.shape
27 |
28 | pad_h = (window_size - H % window_size) % window_size
29 | pad_w = (window_size - W % window_size) % window_size
30 | if pad_h > 0 or pad_w > 0:
31 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
32 | Hp, Wp = H + pad_h, W + pad_w
33 |
34 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
35 | windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)
36 | return windows, (Hp, Wp)
37 |
38 |
39 | def window_unpartition(windows, window_size, pad_hw, hw):
40 | """
41 | Window unpartition into original sequences and removing padding.
42 | Args:
43 | windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
44 | window_size (int): window size.
45 | pad_hw (Tuple): padded height and width (Hp, Wp).
46 | hw (Tuple): original height and width (H, W) before padding.
47 | Returns:
48 | x: unpartitioned sequences with [B, H, W, C].
49 | """
50 | Hp, Wp = pad_hw
51 | H, W = hw
52 | B = windows.shape[0] // (Hp * Wp // window_size // window_size)
53 | x = windows.reshape(
54 | B, Hp // window_size, Wp // window_size, window_size, window_size, -1
55 | )
56 | x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1)
57 |
58 | if Hp > H or Wp > W:
59 | x = x[:, :H, :W, :]
60 | return x
61 |
62 |
63 | class PatchEmbed(nn.Module):
64 | """
65 | Image to Patch Embedding.
66 | """
67 |
68 | def __init__(
69 | self,
70 | kernel_size: Tuple[int, ...] = (7, 7),
71 | stride: Tuple[int, ...] = (4, 4),
72 | padding: Tuple[int, ...] = (3, 3),
73 | in_chans: int = 3,
74 | embed_dim: int = 768,
75 | ):
76 | """
77 | Args:
78 | kernel_size (Tuple): kernel size of the projection layer.
79 | stride (Tuple): stride of the projection layer.
80 | padding (Tuple): padding size of the projection layer.
81 | in_chans (int): Number of input image channels.
82 | embed_dim (int): Patch embedding dimension.
83 | """
84 | super().__init__()
85 | self.proj = nn.Conv2d(
86 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
87 | )
88 |
89 | def forward(self, x: torch.Tensor) -> torch.Tensor:
90 | x = self.proj(x)
91 | # B C H W -> B H W C
92 | x = x.permute(0, 2, 3, 1)
93 | return x
--------------------------------------------------------------------------------
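A minimal round-trip sketch (not part of the repository) for `window_partition` / `window_unpartition`, assuming the module is importable as below:

```python
import torch

from models.sam2.backbones.utils import window_partition, window_unpartition

x = torch.randn(2, 10, 14, 32)                  # [B, H, W, C]; H and W are not multiples of the window size
windows, (Hp, Wp) = window_partition(x, window_size=4)
print(windows.shape)                            # torch.Size([24, 4, 4, 32]) = [B * num_windows, ws, ws, C]
x_restored = window_unpartition(windows, 4, (Hp, Wp), (10, 14))
assert torch.equal(x_restored, x)               # padding added for partitioning is stripped again
```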
/README.md:
--------------------------------------------------------------------------------
1 | # (ACM MM 25) CRISP-SAM2: SAM2 with Cross-Modal Interaction and Semantic Prompting for Multi-Organ Segmentation
2 | #### Xinlei Yu 1, Changmiao Wang 2, Hui Jin 1, Ahmed Elazab 3, Gangyong Jia 1, Xiang Wan 2, Changqing Zou 4, Ruiquan Ge 1
3 | #### 1 Hangzhou Dianzi University, 2 Shenzhen Research Institute of Big Data, 3 Shenzhen University, 4 Zhejiang University
4 |
5 |
6 | ## 🌟Overview
7 |
8 | 
9 |
10 | ## 🛠️ Quick Start
11 | ## Installation
12 | It is highly recommended to use a virtual environment with Python >= 3.10, PyTorch >= 2.5.1, and a matching CUDA version.
13 | ```
14 | cd CRISP-SAM2
15 | conda env create -f env_config.yml
16 | conda activate CRISP_SAM2
17 | ```
18 |
19 |
20 | ## Dataset Preparation
21 | - ### Visual Inputs
22 | | Datasets | Links |
23 | |--------------|---------------------------------------------------------------------------------------------------------------------------------------------------------|
24 | | M3D-Seg | https://github.com/BAAIDCAI/M3D/ <br> https://huggingface.co/datasets/GoodBaiBai88/M3D-Seg/ <br> https://www.modelscope.cn/datasets/GoodBaiBai88/M3D-Seg/ |
25 | | MSD-Spleen | http://medicaldecathlon.com/ |
26 | | Pancreas-CT | https://wiki.cancerimagingarchive.net/display/public/pancreas-ct/ |
27 | | LUNA16 | https://luna16.grand-challenge.org/Data/ |
28 | | AbdomenCT-1k | https://github.com/JunMa11/AbdomenCT-1K/ |
29 | | WORD | https://paperswithcode.com/dataset/word/ |
30 | | FLARE22 | https://flare22.grand-challenge.org/ |
31 | | AMOS22 | https://amos22.grand-challenge.org/ |
32 |
33 | - ### Textual Inputs
34 | The descriptive definitions and descriptions are stored in `term_dictionary.json`; compared with the original M3D-Seg joint dataset, we add supplementary sentences. The dictionary can be expanded as required.
35 |
36 | ## Train & Test
37 | - ### Training process
38 | We highly recommend conducting the whole training process on at least 8 A100-80G GPUs.
39 | ```
40 | bash scripts/train.sh
41 | ```
42 | - ### Test process
43 | ```
44 | bash scripts/test.sh
45 | ```
46 |
47 |
48 | ## Visualization
49 | We provide comprehensive visualization utilities, including 2D slice, 3D rendering, and local-area visualization.
50 |
51 | 
52 | 
53 |
54 | ## Citation
55 | If you have any questions about this work, please feel free to contact me at xinleiyu88@gmail.com. If you would like to cite our work, please use the following BibTeX entry:
56 | ```
57 | @article{yu2025crisp,
58 | title={CRISP-SAM2: SAM2 with Cross-Modal Interaction and Semantic Prompting for Multi-Organ Segmentation},
59 | author={Yu, Xinlei and Wang, Changmiao and Jin, Hui and Elazab, Ahmed and Jia, Gangyong and Wan, Xiang and Zou, Changqing and Ge, Ruiquan},
60 | journal={arXiv preprint arXiv:2506.23121},
61 | year={2025}
62 | }
63 | ```
--------------------------------------------------------------------------------
/models/sam2/configs/sam2/sam2_hiera_b+.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # Model
4 | model:
5 | _target_: sam2.modeling.sam2_base.SAM2Base
6 | image_encoder:
7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8 | scalp: 1
9 | trunk:
10 | _target_: sam2.modeling.backbones.hieradet.Hiera
11 | embed_dim: 112
12 | num_heads: 2
13 | neck:
14 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck
15 | position_encoding:
16 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
17 | num_pos_feats: 256
18 | normalize: true
19 | scale: null
20 | temperature: 10000
21 | d_model: 256
22 | backbone_channel_list: [896, 448, 224, 112]
23 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
24 | fpn_interp_model: nearest
25 |
26 | memory_attention:
27 | _target_: sam2.modeling.memory_attention.MemoryAttention
28 | d_model: 256
29 | pos_enc_at_input: true
30 | layer:
31 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
32 | activation: relu
33 | dim_feedforward: 2048
34 | dropout: 0.1
35 | pos_enc_at_attn: false
36 | self_attention:
37 | _target_: sam2.modeling.sam.transformer.RoPEAttention
38 | rope_theta: 10000.0
39 | feat_sizes: [64, 64]
40 | embedding_dim: 256
41 | num_heads: 1
42 | downsample_rate: 1
43 | dropout: 0.1
44 | d_model: 256
45 | pos_enc_at_cross_attn_keys: true
46 | pos_enc_at_cross_attn_queries: false
47 | cross_attention:
48 | _target_: sam2.modeling.sam.transformer.RoPEAttention
49 | rope_theta: 10000.0
50 | feat_sizes: [64, 64]
51 | rope_k_repeat: True
52 | embedding_dim: 256
53 | num_heads: 1
54 | downsample_rate: 1
55 | dropout: 0.1
56 | kv_in_dim: 64
57 | num_layers: 4
58 |
59 | memory_encoder:
60 | _target_: sam2.modeling.memory_encoder.MemoryEncoder
61 | out_dim: 64
62 | position_encoding:
63 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
64 | num_pos_feats: 64
65 | normalize: true
66 | scale: null
67 | temperature: 10000
68 | mask_downsampler:
69 | _target_: sam2.modeling.memory_encoder.MaskDownSampler
70 | kernel_size: 3
71 | stride: 2
72 | padding: 1
73 | fuser:
74 | _target_: sam2.modeling.memory_encoder.Fuser
75 | layer:
76 | _target_: sam2.modeling.memory_encoder.CXBlock
77 | dim: 256
78 | kernel_size: 7
79 | padding: 3
80 | layer_scale_init_value: 1e-6
81 | use_dwconv: True # depth-wise convs
82 | num_layers: 2
83 |
84 | num_maskmem: 7
85 | image_size: 1024
86 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
87 | sigmoid_scale_for_mem_enc: 20.0
88 | sigmoid_bias_for_mem_enc: -10.0
89 | use_mask_input_as_output_without_sam: true
90 | # Memory
91 | directly_add_no_mem_embed: true
92 | # use high-resolution feature map in the SAM mask decoder
93 | use_high_res_features_in_sam: true
94 | # output 3 masks on the first click on initial conditioning frames
95 | multimask_output_in_sam: true
96 | # SAM heads
97 | iou_prediction_use_sigmoid: True
98 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
99 | use_obj_ptrs_in_encoder: true
100 | add_tpos_enc_to_obj_ptrs: false
101 | only_obj_ptrs_in_the_past_for_eval: true
102 | # object occlusion prediction
103 | pred_obj_scores: true
104 | pred_obj_scores_mlp: true
105 | fixed_no_obj_ptr: true
106 | # multimask tracking settings
107 | multimask_output_for_tracking: true
108 | use_multimask_token_for_obj_ptr: true
109 | multimask_min_pt_num: 0
110 | multimask_max_pt_num: 1
111 | use_mlp_for_obj_ptr_proj: true
112 | # Compilation flag
113 | compile_image_encoder: False
114 |
--------------------------------------------------------------------------------
/models/sam2/configs/sam2/sam2_hiera_s.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # Model
4 | model:
5 | _target_: sam2.modeling.sam2_base.SAM2Base
6 | image_encoder:
7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8 | scalp: 1
9 | trunk:
10 | _target_: sam2.modeling.backbones.hieradet.Hiera
11 | embed_dim: 96
12 | num_heads: 1
13 | stages: [1, 2, 11, 2]
14 | global_att_blocks: [7, 10, 13]
15 | window_pos_embed_bkg_spatial_size: [7, 7]
16 | neck:
17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18 | position_encoding:
19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20 | num_pos_feats: 256
21 | normalize: true
22 | scale: null
23 | temperature: 10000
24 | d_model: 256
25 | backbone_channel_list: [768, 384, 192, 96]
26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27 | fpn_interp_model: nearest
28 |
29 | memory_attention:
30 | _target_: sam2.modeling.memory_attention.MemoryAttention
31 | d_model: 256
32 | pos_enc_at_input: true
33 | layer:
34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35 | activation: relu
36 | dim_feedforward: 2048
37 | dropout: 0.1
38 | pos_enc_at_attn: false
39 | self_attention:
40 | _target_: sam2.modeling.sam.transformer.RoPEAttention
41 | rope_theta: 10000.0
42 | feat_sizes: [64, 64]
43 | embedding_dim: 256
44 | num_heads: 1
45 | downsample_rate: 1
46 | dropout: 0.1
47 | d_model: 256
48 | pos_enc_at_cross_attn_keys: true
49 | pos_enc_at_cross_attn_queries: false
50 | cross_attention:
51 | _target_: sam2.modeling.sam.transformer.RoPEAttention
52 | rope_theta: 10000.0
53 | feat_sizes: [64, 64]
54 | rope_k_repeat: True
55 | embedding_dim: 256
56 | num_heads: 1
57 | downsample_rate: 1
58 | dropout: 0.1
59 | kv_in_dim: 64
60 | num_layers: 4
61 |
62 | memory_encoder:
63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder
64 | out_dim: 64
65 | position_encoding:
66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67 | num_pos_feats: 64
68 | normalize: true
69 | scale: null
70 | temperature: 10000
71 | mask_downsampler:
72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler
73 | kernel_size: 3
74 | stride: 2
75 | padding: 1
76 | fuser:
77 | _target_: sam2.modeling.memory_encoder.Fuser
78 | layer:
79 | _target_: sam2.modeling.memory_encoder.CXBlock
80 | dim: 256
81 | kernel_size: 7
82 | padding: 3
83 | layer_scale_init_value: 1e-6
84 | use_dwconv: True # depth-wise convs
85 | num_layers: 2
86 |
87 | num_maskmem: 7
88 | image_size: 1024
89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90 | sigmoid_scale_for_mem_enc: 20.0
91 | sigmoid_bias_for_mem_enc: -10.0
92 | use_mask_input_as_output_without_sam: true
93 | # Memory
94 | directly_add_no_mem_embed: true
95 | # use high-resolution feature map in the SAM mask decoder
96 | use_high_res_features_in_sam: true
97 | # output 3 masks on the first click on initial conditioning frames
98 | multimask_output_in_sam: true
99 | # SAM heads
100 | iou_prediction_use_sigmoid: True
101 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
102 | use_obj_ptrs_in_encoder: true
103 | add_tpos_enc_to_obj_ptrs: false
104 | only_obj_ptrs_in_the_past_for_eval: true
105 | # object occlusion prediction
106 | pred_obj_scores: true
107 | pred_obj_scores_mlp: true
108 | fixed_no_obj_ptr: true
109 | # multimask tracking settings
110 | multimask_output_for_tracking: true
111 | use_multimask_token_for_obj_ptr: true
112 | multimask_min_pt_num: 0
113 | multimask_max_pt_num: 1
114 | use_mlp_for_obj_ptr_proj: true
115 | # Compilation flag
116 | compile_image_encoder: False
117 |
--------------------------------------------------------------------------------
/models/sam2/configs/sam2/sam2_hiera_l.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # Model
4 | model:
5 | _target_: sam2.modeling.sam2_base.SAM2Base
6 | image_encoder:
7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8 | scalp: 1
9 | trunk:
10 | _target_: sam2.modeling.backbones.hieradet.Hiera
11 | embed_dim: 144
12 | num_heads: 2
13 | stages: [2, 6, 36, 4]
14 | global_att_blocks: [23, 33, 43]
15 | window_pos_embed_bkg_spatial_size: [7, 7]
16 | window_spec: [8, 4, 16, 8]
17 | neck:
18 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck
19 | position_encoding:
20 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
21 | num_pos_feats: 256
22 | normalize: true
23 | scale: null
24 | temperature: 10000
25 | d_model: 256
26 | backbone_channel_list: [1152, 576, 288, 144]
27 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
28 | fpn_interp_model: nearest
29 |
30 | memory_attention:
31 | _target_: sam2.modeling.memory_attention.MemoryAttention
32 | d_model: 256
33 | pos_enc_at_input: true
34 | layer:
35 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
36 | activation: relu
37 | dim_feedforward: 2048
38 | dropout: 0.1
39 | pos_enc_at_attn: false
40 | self_attention:
41 | _target_: sam2.modeling.sam.transformer.RoPEAttention
42 | rope_theta: 10000.0
43 | feat_sizes: [64, 64]
44 | embedding_dim: 256
45 | num_heads: 1
46 | downsample_rate: 1
47 | dropout: 0.1
48 | d_model: 256
49 | pos_enc_at_cross_attn_keys: true
50 | pos_enc_at_cross_attn_queries: false
51 | cross_attention:
52 | _target_: sam2.modeling.sam.transformer.RoPEAttention
53 | rope_theta: 10000.0
54 | feat_sizes: [64, 64]
55 | rope_k_repeat: True
56 | embedding_dim: 256
57 | num_heads: 1
58 | downsample_rate: 1
59 | dropout: 0.1
60 | kv_in_dim: 64
61 | num_layers: 4
62 |
63 | memory_encoder:
64 | _target_: sam2.modeling.memory_encoder.MemoryEncoder
65 | out_dim: 64
66 | position_encoding:
67 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
68 | num_pos_feats: 64
69 | normalize: true
70 | scale: null
71 | temperature: 10000
72 | mask_downsampler:
73 | _target_: sam2.modeling.memory_encoder.MaskDownSampler
74 | kernel_size: 3
75 | stride: 2
76 | padding: 1
77 | fuser:
78 | _target_: sam2.modeling.memory_encoder.Fuser
79 | layer:
80 | _target_: sam2.modeling.memory_encoder.CXBlock
81 | dim: 256
82 | kernel_size: 7
83 | padding: 3
84 | layer_scale_init_value: 1e-6
85 | use_dwconv: True # depth-wise convs
86 | num_layers: 2
87 |
88 | num_maskmem: 7
89 | image_size: 1024
90 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
91 | sigmoid_scale_for_mem_enc: 20.0
92 | sigmoid_bias_for_mem_enc: -10.0
93 | use_mask_input_as_output_without_sam: true
94 | # Memory
95 | directly_add_no_mem_embed: true
96 | # use high-resolution feature map in the SAM mask decoder
97 | use_high_res_features_in_sam: true
98 | # output 3 masks on the first click on initial conditioning frames
99 | multimask_output_in_sam: true
100 | # SAM heads
101 | iou_prediction_use_sigmoid: True
102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
103 | use_obj_ptrs_in_encoder: true
104 | add_tpos_enc_to_obj_ptrs: false
105 | only_obj_ptrs_in_the_past_for_eval: true
106 | # object occlusion prediction
107 | pred_obj_scores: true
108 | pred_obj_scores_mlp: true
109 | fixed_no_obj_ptr: true
110 | # multimask tracking settings
111 | multimask_output_for_tracking: true
112 | use_multimask_token_for_obj_ptr: true
113 | multimask_min_pt_num: 0
114 | multimask_max_pt_num: 1
115 | use_mlp_for_obj_ptr_proj: true
116 | # Compilation flag
117 | compile_image_encoder: False
118 |
--------------------------------------------------------------------------------
/models/sam2/configs/sam2/sam2_hiera_t.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # Model
4 | model:
5 | _target_: sam2.modeling.sam2_base.SAM2Base
6 | image_encoder:
7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8 | scalp: 1
9 | trunk:
10 | _target_: sam2.modeling.backbones.hieradet.Hiera
11 | embed_dim: 96
12 | num_heads: 1
13 | stages: [1, 2, 7, 2]
14 | global_att_blocks: [5, 7, 9]
15 | window_pos_embed_bkg_spatial_size: [7, 7]
16 | neck:
17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18 | position_encoding:
19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20 | num_pos_feats: 256
21 | normalize: true
22 | scale: null
23 | temperature: 10000
24 | d_model: 256
25 | backbone_channel_list: [768, 384, 192, 96]
26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27 | fpn_interp_model: nearest
28 |
29 | memory_attention:
30 | _target_: sam2.modeling.memory_attention.MemoryAttention
31 | d_model: 256
32 | pos_enc_at_input: true
33 | layer:
34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35 | activation: relu
36 | dim_feedforward: 2048
37 | dropout: 0.1
38 | pos_enc_at_attn: false
39 | self_attention:
40 | _target_: sam2.modeling.sam.transformer.RoPEAttention
41 | rope_theta: 10000.0
42 | feat_sizes: [64, 64]
43 | embedding_dim: 256
44 | num_heads: 1
45 | downsample_rate: 1
46 | dropout: 0.1
47 | d_model: 256
48 | pos_enc_at_cross_attn_keys: true
49 | pos_enc_at_cross_attn_queries: false
50 | cross_attention:
51 | _target_: sam2.modeling.sam.transformer.RoPEAttention
52 | rope_theta: 10000.0
53 | feat_sizes: [64, 64]
54 | rope_k_repeat: True
55 | embedding_dim: 256
56 | num_heads: 1
57 | downsample_rate: 1
58 | dropout: 0.1
59 | kv_in_dim: 64
60 | num_layers: 4
61 |
62 | memory_encoder:
63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder
64 | out_dim: 64
65 | position_encoding:
66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67 | num_pos_feats: 64
68 | normalize: true
69 | scale: null
70 | temperature: 10000
71 | mask_downsampler:
72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler
73 | kernel_size: 3
74 | stride: 2
75 | padding: 1
76 | fuser:
77 | _target_: sam2.modeling.memory_encoder.Fuser
78 | layer:
79 | _target_: sam2.modeling.memory_encoder.CXBlock
80 | dim: 256
81 | kernel_size: 7
82 | padding: 3
83 | layer_scale_init_value: 1e-6
84 | use_dwconv: True # depth-wise convs
85 | num_layers: 2
86 |
87 | num_maskmem: 7
88 | image_size: 1024
89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90 | # SAM decoder
91 | sigmoid_scale_for_mem_enc: 20.0
92 | sigmoid_bias_for_mem_enc: -10.0
93 | use_mask_input_as_output_without_sam: true
94 | # Memory
95 | directly_add_no_mem_embed: true
96 | # use high-resolution feature map in the SAM mask decoder
97 | use_high_res_features_in_sam: true
98 | # output 3 masks on the first click on initial conditioning frames
99 | multimask_output_in_sam: true
100 | # SAM heads
101 | iou_prediction_use_sigmoid: True
102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
103 | use_obj_ptrs_in_encoder: true
104 | add_tpos_enc_to_obj_ptrs: false
105 | only_obj_ptrs_in_the_past_for_eval: true
106 | # object occlusion prediction
107 | pred_obj_scores: true
108 | pred_obj_scores_mlp: true
109 | fixed_no_obj_ptr: true
110 | # multimask tracking settings
111 | multimask_output_for_tracking: true
112 | use_multimask_token_for_obj_ptr: true
113 | multimask_min_pt_num: 0
114 | multimask_max_pt_num: 1
115 | use_mlp_for_obj_ptr_proj: true
116 | # Compilation flag
117 | # HieraT does not currently support compilation, should always be set to False
118 | compile_image_encoder: False
119 |
--------------------------------------------------------------------------------
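The four Hiera configs above are consumed by `build_sam2` in the same way `test.py` does it. A minimal sketch (not part of the repository; paths are placeholders):

```python
from models.sam2.build_sam import build_sam2

# Placeholder paths; point these at the config and checkpoint you actually use.
sam_model = build_sam2(
    config_file="models/sam2/configs/sam2/sam2_hiera_b+.yaml",
    checkpoint="/path/to/sam2_checkpoint.pt",
    mode="eval",
)
print(type(sam_model.image_encoder), type(sam_model.memory_encoder))
```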
/visualization/visualization_3d.py:
--------------------------------------------------------------------------------
1 | import os.path
2 |
3 | import SimpleITK as sitk
4 |
5 | from visualization_config import *
6 | from utils import *
7 |
8 |
9 | def mhd_to_nii(file_path: str):
10 | itk_image = sitk.ReadImage(file_path)
11 | save_path = file_path.replace(".mhd", ".nii.gz")
12 | out_arr = sitk.GetArrayFromImage(itk_image)
13 | out = sitk.GetImageFromArray(out_arr)
14 | sitk.WriteImage(out, save_path)
15 | return save_path
16 |
17 |
18 | if __name__ == '__main__':
19 | # configs
20 | show_axes = False
21 | show_outline = False
22 | generate_outline_face = False
23 | take_snapshot = True
24 | SMOOTH_FACTOR = 400
25 | ROTATE_X = 270
26 | ROTATE_Y = 0
27 | ROTATE_Z = 0
28 | model_name = 'ours'
29 | dataset_name = '.../'
30 | MASK_COLORS = get_mask_colors(dataset_name)
31 | MASK_OPACITY = get_mask_opacity(dataset_name)
32 | file_name = '../input/...'
33 | if file_name.endswith(".mhd"):
34 | file_name = mhd_to_nii(file_name)
35 | output_name = file_name.split('/')[3].split('.')[12]
36 | else:
37 | output_name = file_name.split('/')[3].split('.')[0]
38 | output_path = '../output/' + dataset_name + output_name
39 | if not os.path.exists(output_path):
40 | os.makedirs(output_path)
41 | print(f"file path : {file_name}")
42 | print(f"output path : {output_path}")
43 | rotate_config = '_' + str(SMOOTH_FACTOR) + '_' + str(ROTATE_X) + '_' + str(ROTATE_Y) + '_' + str(ROTATE_Z)
44 | snapshot_filename = output_path + '/' + model_name + rotate_config + '.png'
45 |
46 | # reader
47 | reader = read_volume(file_name)
48 |
49 | # transform
50 | mask_transform = vtk.vtkTransform()
51 | mask_transform.PostMultiply()
52 | mask_transform.Scale(SCALE) # scale
53 | mask_transform.RotateX(ROTATE_X) # rotate
54 | mask_transform.RotateY(ROTATE_Y)
55 | mask_transform.RotateZ(ROTATE_Z)
56 |
57 | # renderer and render window
58 | renderer = create_renderer(bg_color=RENDERER_BG_COLOR)
59 | render_window = create_renderwindow()
60 | render_window.AddRenderer(renderer)
61 | # render_window.SetSize(100, 100)
62 | # renderer.SetViewport(0.1, 0.1, 0.9, 0.9)
63 | # mapper and actors for segmentation results
64 | n_labels = int(reader.GetOutput().GetScalarRange()[1])
65 | print(n_labels)
66 |
67 | for idx in range(n_labels):
68 | extracter = create_mask_extractor(reader) # extracter
69 | extracter.SetValue(0, idx + 1)
70 | smoother = create_smoother(extracter, SMOOTH_FACTOR) # smoother
71 | mapper = create_mapper(stripper=smoother)
72 | prop = create_property(opacity=MASK_OPACITY[idx], color=MASK_COLORS[idx]) # property
73 | actor = create_actor(mapper=mapper, prop=prop) # actor
74 | actor.SetUserTransform(mask_transform)
75 | renderer.AddActor(actor)
76 |
77 | # outline of the whole image
78 | if show_outline:
79 | outline = vtk.vtkOutlineFilter() # show outline
80 | outline.SetInputConnection(reader.GetOutputPort())
81 | if generate_outline_face: # show surface of the outline
82 | outline.GenerateFacesOn()
83 | extracter = create_mask_extractor(reader)
84 | mapper = create_mapper(stripper=outline)
85 | prop = create_property(opacity=OUTLINE_OPACITY, color=OUTLINE_COLOR)
86 | actor = create_actor(mapper=mapper, prop=prop)
87 | actor.SetUserTransform(mask_transform)
88 | renderer.AddActor(actor)
89 |
90 | # show axes for better visualization
91 | if show_axes:
92 | axes_actor = vtk.vtkAxesActor()
93 | axes_actor.SetTotalLength(TOTAL_LENGTH[0], TOTAL_LENGTH[1], TOTAL_LENGTH[2]) # set axes length
94 | axes_actor.SetScale(5, 5, 5)
95 | renderer.AddActor(axes_actor)
96 |
97 | # start render
98 | render_window.Render()
99 |
100 | # screenshot
101 | if take_snapshot:
102 | w2if = vtk.vtkWindowToImageFilter()
103 | w2if.SetInput(render_window)
104 | w2if.SetInputBufferTypeToRGB()
105 | w2if.ReadFrontBufferOff()
106 | w2if.Update()
107 |
108 | writer = vtk.vtkPNGWriter()
109 | writer.SetFileName(snapshot_filename)
110 | writer.SetInputConnection(w2if.GetOutputPort())
111 | writer.Write()
112 |
113 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tqdm import tqdm
3 | from datasets.data_loader import get_loader
4 | from models.crisp_sam2 import CrispSam2
5 | from models.sam2.build_sam import build_sam2
6 | import os
7 | import argparse
8 | from datetime import datetime
9 | import numpy as np
10 | from monai.metrics import compute_meandice, compute_average_surface_distance
11 |
12 |
13 | def get_test_args_parser():
14 | parser = argparse.ArgumentParser(description='Test Configuration')
15 |
16 | parser.add_argument("--mode", type=str, default='test')
17 | parser.add_argument("--pretrain", type=str, default='/path/to/pretrained_model')
18 | parser.add_argument("--data_dir", type=str, default='/path/to/data')
19 | parser.add_argument("--dataset_codes", type=list, default=['0003'])
20 | parser.add_argument("--patch_size", default=(96, 96, 96), type=tuple)
21 | parser.add_argument("--spatial_size", default=(32, 512, 512), type=tuple)
22 | parser.add_argument("--work_dir", type=str, default='./work_dir')
23 | parser.add_argument("--config_file", type=str, default='./path/to/config_file')
24 | parser.add_argument("--sam2_ckpt", type=str, default='./path/to/sam2_ckpt')
25 | # parser.add_argument("--model_id", type=str, default='hf_id')
26 | parser.add_argument("--clip_text_ckpt", type=str, default='./path/to/clip_text_ckpt')
27 | parser.add_argument("--clip_image_ckpt", type=str, default='./path/to/clip_image_ckpt')
28 | parser.add_argument('--num_workers', type=int, default=8)
29 | parser.add_argument('--batch_size', type=int, default=1)
30 | parser.add_argument('--result_dir', type=str, default='./results')
31 | parser.add_argument('--gpu', type=int, default=0)
32 | return parser.parse_args()
33 |
34 |
35 | def test(args):
36 | sam_model = build_sam2(config_file=args.config_file, checkpoint=args.sam2_ckpt, mode="eval")
37 | model = CrispSam2(
38 | image_encoder=sam_model.image_encoder,
39 | memory_attention=sam_model.memory_attention,
40 | memory_encoder=sam_model.memory_encoder,
41 | clip_text_ckpt=args.clip_text_ckpt,
42 | clip_image_ckpt=args.clip_image_ckpt,
43 | ).cuda(args.gpu)
44 | model.eval()
45 |
46 | if args.pretrain:
47 | checkpoint = torch.load(args.pretrain, map_location=f'cuda:{args.gpu}')
48 | model.load_state_dict(checkpoint['model'])
49 | print(f"Loaded checkpoint from {args.pretrain}")
50 |
51 | test_dataloader = get_loader(args)
52 |
53 | os.makedirs(args.result_dir, exist_ok=True)
54 | log_file = os.path.join(args.result_dir, f"test_log_{datetime.now().strftime('%Y%m%d-%H%M')}.txt")
55 |
56 | total_dsc = []
57 | total_nsd = []
58 | with torch.no_grad(), tqdm(total=len(test_dataloader), desc="Testing") as pbar:
59 | for batch in test_dataloader:
60 | image = batch["image"].cuda(args.gpu)
61 | gt3D = batch["post_label"].cuda(args.gpu)
62 | organ_name_list = batch['organ_name_list']
63 | text_list = batch['text']
64 | patient_id = batch['patient_id'][0]
65 |
66 | pred_masks = []
67 | for cls_idx in range(len(organ_name_list)):
68 | labels_cls = gt3D[:, cls_idx]
69 | organ_name = organ_name_list[cls_idx]
70 | crisp_masks = []
71 |
72 | for frame_idx in range(image.size()[-1]):
73 | output = model(image, text_list, frame_idx)
74 | crisp_mask = output['crisp_masks'].sigmoid() > 0.5
75 | crisp_masks.append(crisp_mask.cpu().numpy())
76 |
77 | mask_3d = np.stack(crisp_masks, axis=0).squeeze()
78 | pred_masks.append(mask_3d)
79 | save_path = os.path.join(args.result_dir, patient_id)
80 | os.makedirs(save_path, exist_ok=True)
81 | np.save(os.path.join(save_path, f"{organ_name}_pred.npy"), mask_3d)
82 |
83 | if gt3D is not None:
84 | gt_mask = labels_cls.squeeze().cpu().numpy()
85 | dsc = compute_meandice(torch.from_numpy(mask_3d)[None, None].float(), torch.from_numpy(gt_mask)[None, None].float(), include_background=True).item()  # MONAI expects [B, C, ...] tensors; the single channel is the foreground organ
86 | total_dsc.append(dsc)
87 | nsd = compute_average_surface_distance(torch.from_numpy(mask_3d)[None, None].float(), torch.from_numpy(gt_mask)[None, None].float(), include_background=True).item()
88 | total_nsd.append(nsd)
89 | with open(log_file, 'a') as f:
90 | f.write(f"Patient {patient_id}, Organ {organ_name}: DSC={dsc:.4f}, NSD={nsd:.4f}\n")
91 |
92 | pbar.update(1)
93 |
94 | if total_dsc and total_nsd:
95 | mean_dsc = np.mean(total_dsc)
96 | mean_nsd = np.mean(total_nsd)
97 | print(f"Test Finished. Mean DSC: {mean_dsc:.4f}, Mean NSD: {mean_nsd:.4f}")
98 | with open(log_file, 'a') as f:
99 | f.write(f"Overall Mean DSC: {mean_dsc:.4f}, Overall Mean NSD: {mean_nsd:.4f}\n")
100 | else:
101 | print("No ground truth provided, skipped metric calculation.")
102 |
103 |
104 | def main():
105 | args = get_test_args_parser()
106 | print("Test Arguments:", args)
107 | test(args)
108 |
109 |
110 | if __name__ == '__main__':
111 | main()
112 |
--------------------------------------------------------------------------------
/models/sam2/utils/transformers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import warnings
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from torchvision.transforms import Normalize, Resize, ToTensor
13 |
14 |
15 | class SAM2Transforms(nn.Module):
16 | def __init__(
17 | self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0
18 | ):
19 | """
20 | Transforms for SAM2.
21 | """
22 | super().__init__()
23 | self.resolution = resolution
24 | self.mask_threshold = mask_threshold
25 | self.max_hole_area = max_hole_area
26 | self.max_sprinkle_area = max_sprinkle_area
27 | self.mean = [0.485, 0.456, 0.406]
28 | self.std = [0.229, 0.224, 0.225]
29 | self.to_tensor = ToTensor()
30 | self.transforms = torch.jit.script(
31 | nn.Sequential(
32 | Resize((self.resolution, self.resolution)),
33 | Normalize(self.mean, self.std),
34 | )
35 | )
36 |
37 | def __call__(self, x):
38 | x = self.to_tensor(x)
39 | return self.transforms(x)
40 |
41 | def forward_batch(self, img_list):
42 | img_batch = [self.transforms(self.to_tensor(img)) for img in img_list]
43 | img_batch = torch.stack(img_batch, dim=0)
44 | return img_batch
45 |
46 | def transform_coords(
47 | self, coords: torch.Tensor, normalize=False, orig_hw=None
48 | ) -> torch.Tensor:
49 | """
50 | Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates;
51 | if the coords are in absolute image coordinates, normalize should be set to True and the original image size is required.
52 |
53 | Returns
54 | Coordinates scaled to the model's input resolution (multiplied by `self.resolution`), which is what the SAM2 model expects.
55 | """
56 | if normalize:
57 | assert orig_hw is not None
58 | h, w = orig_hw
59 | coords = coords.clone()
60 | coords[..., 0] = coords[..., 0] / w
61 | coords[..., 1] = coords[..., 1] / h
62 |
63 | coords = coords * self.resolution # unnormalize coords
64 | return coords
65 |
66 | def transform_boxes(
67 | self, boxes: torch.Tensor, normalize=False, orig_hw=None
68 | ) -> torch.Tensor:
69 | """
70 | Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates;
71 | if the coords are in absolute image coordinates, normalize should be set to True and the original image size is required.
72 | """
73 | boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw)
74 | return boxes
75 |
76 | def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor:
77 | """
78 | Perform PostProcessing on output masks.
79 | """
80 | from sam2.utils.misc import get_connected_components
81 |
82 | masks = masks.float()
83 | input_masks = masks
84 | mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image
85 | try:
86 | if self.max_hole_area > 0:
87 | # Holes are those connected components in background with area <= self.fill_hole_area
88 | # (background regions are those with mask scores <= self.mask_threshold)
89 | labels, areas = get_connected_components(
90 | mask_flat <= self.mask_threshold
91 | )
92 | is_hole = (labels > 0) & (areas <= self.max_hole_area)
93 | is_hole = is_hole.reshape_as(masks)
94 | # We fill holes with a small positive mask score (10.0) to change them to foreground.
95 | masks = torch.where(is_hole, self.mask_threshold + 10.0, masks)
96 |
97 | if self.max_sprinkle_area > 0:
98 | labels, areas = get_connected_components(
99 | mask_flat > self.mask_threshold
100 | )
101 | is_hole = (labels > 0) & (areas <= self.max_sprinkle_area)
102 | is_hole = is_hole.reshape_as(masks)
103 | # We fill holes with negative mask score (-10.0) to change them to background.
104 | masks = torch.where(is_hole, self.mask_threshold - 10.0, masks)
105 | except Exception as e:
106 | # Skip the post-processing step if the CUDA kernel fails
107 | warnings.warn(
108 | f"{e}\n\nSkipping the post-processing step due to the error above. You can "
109 | "still use SAM 2 and it's OK to ignore the error above, although some post-processing "
110 | "functionality may be limited (which doesn't affect the results in most cases; see "
111 | "https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).",
112 | category=UserWarning,
113 | stacklevel=2,
114 | )
115 | masks = input_masks
116 |
117 | masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False)
118 | return masks
--------------------------------------------------------------------------------
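A minimal sketch (not part of the repository) of how `SAM2Transforms` maps prompt coordinates into the model's input resolution:

```python
import torch

from models.sam2.utils.transformers import SAM2Transforms

transforms = SAM2Transforms(resolution=1024, mask_threshold=0.0)

# A single box in absolute pixel coordinates (x1, y1, x2, y2) on a 512x512 image.
boxes = torch.tensor([[100.0, 50.0, 300.0, 250.0]])
boxes_model = transforms.transform_boxes(boxes, normalize=True, orig_hw=(512, 512))
print(boxes_model)   # tensor([[[200., 100.], [600., 500.]]]), scaled to the 1024 input resolution
```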
/visualization/visualization_config.py:
--------------------------------------------------------------------------------
1 | """
2 | GUI config
3 | """
4 | # main window, only for Qt GUI
5 | APPLICATION_TITLE = "3D Segmentation Visualizer"
6 | MIN_IMG_WINDOW_WIDTH = 200
7 |
8 | """
9 | VTK config
10 | """
11 | COLOR_CONFIG = {
12 | "liver": (220 / 256, 0 / 256, 0 / 256),
13 | "spleen": (0 / 256, 220 / 256, 0 / 256),
14 | "stomach": (0 / 256, 220 / 256, 220 / 256),
15 | "gallbladder": (220 / 256, 0 / 256, 220 / 256),
16 | "esophagus": (225 / 256, 210 / 256, 190 / 256),
17 | "pancreas": (0 / 256, 0 / 256, 250 / 256),
18 | "duodenum": (200 / 256, 130 / 256, 60 / 256),
19 | "aorta": (0 / 256, 120 / 256, 120 / 256),
20 | "bladder": (40 / 256, 140 / 256, 90 / 256),
21 | "inferior vena cava": (220 / 256, 190 / 256, 150 / 256),
22 | "left kidney": (40 / 256, 60 / 256, 80 / 256),
23 | "right kidney": (220 / 256, 220 / 256, 0 / 256),
24 | "left adrenal gland": (90 / 256, 20 / 256, 110 / 256),
25 | "right adrenal gland": (200 / 256, 60 / 256, 90 / 256),
26 | "left femur": (255 / 256, 230 / 256, 225 / 256),
27 | "right femur": (110 / 256, 90 / 256, 200 / 256),
28 | "left lung": (100 / 256, 40 / 256, 140 / 256),
29 | "right lung": (60 / 256, 100 / 256, 200 / 256),
30 | "default": (1, 1, 1)
31 | }
32 |
33 |
34 | # mask config
35 | def get_mask_colors(dataset_name: str):
36 | if "word" in dataset_name.lower():
37 | return [
38 |             # WORD
39 | COLOR_CONFIG["liver"],
40 | COLOR_CONFIG["spleen"],
41 | COLOR_CONFIG["left kidney"],
42 | COLOR_CONFIG["right kidney"],
43 | COLOR_CONFIG["stomach"],
44 | COLOR_CONFIG["gallbladder"],
45 | COLOR_CONFIG["esophagus"],
46 | COLOR_CONFIG["pancreas"],
47 | COLOR_CONFIG["duodenum"],
48 | COLOR_CONFIG["default"],
49 | COLOR_CONFIG["default"],
50 | COLOR_CONFIG["default"],
51 | COLOR_CONFIG["default"],
52 | COLOR_CONFIG["bladder"],
53 | COLOR_CONFIG["left femur"],
54 | COLOR_CONFIG["right femur"],
55 | ]
56 | elif "flare" in dataset_name.lower():
57 | return [
58 | COLOR_CONFIG["liver"],
59 | COLOR_CONFIG["right kidney"],
60 | COLOR_CONFIG["spleen"],
61 | COLOR_CONFIG["pancreas"],
62 | COLOR_CONFIG["aorta"],
63 | COLOR_CONFIG["inferior vena cava"],
64 | COLOR_CONFIG["right adrenal gland"],
65 | COLOR_CONFIG["left adrenal gland"],
66 | COLOR_CONFIG["gallbladder"],
67 | COLOR_CONFIG["esophagus"],
68 | COLOR_CONFIG["stomach"],
69 | COLOR_CONFIG["duodenum"],
70 | COLOR_CONFIG["left kidney"],
71 | ]
72 | elif "abdomen" in dataset_name.lower():
73 | return [
74 | COLOR_CONFIG["liver"],
75 | COLOR_CONFIG["default"],
76 | COLOR_CONFIG["spleen"],
77 | COLOR_CONFIG["pancreas"],
78 | COLOR_CONFIG["left kidney"],
79 | COLOR_CONFIG["right kidney"],
80 | ]
81 | elif "amos" in dataset_name.lower():
82 | return [
83 | COLOR_CONFIG["spleen"],
84 | COLOR_CONFIG["right kidney"],
85 | COLOR_CONFIG["left kidney"],
86 | COLOR_CONFIG["gallbladder"],
87 | COLOR_CONFIG["esophagus"],
88 | COLOR_CONFIG["liver"],
89 | COLOR_CONFIG["stomach"],
90 | COLOR_CONFIG["aorta"],
91 | COLOR_CONFIG["inferior vena cava"],
92 | COLOR_CONFIG["pancreas"],
93 | COLOR_CONFIG["right adrenal gland"],
94 | COLOR_CONFIG["left adrenal gland"],
95 | COLOR_CONFIG["duodenum"],
96 | COLOR_CONFIG["bladder"],
97 | COLOR_CONFIG["default"],
98 | ]
99 | elif "luna" in dataset_name.lower():
100 | return [
101 | COLOR_CONFIG["default"],
102 | COLOR_CONFIG["default"],
103 | COLOR_CONFIG["right lung"],
104 | COLOR_CONFIG["left lung"],
105 | COLOR_CONFIG["default"],
106 | ]
107 | elif "spleen" in dataset_name.lower():
108 | return [
109 | COLOR_CONFIG["spleen"],
110 | ]
111 | elif "pancreas" in dataset_name.lower():
112 | return [
113 | COLOR_CONFIG["pancreas"],
114 | COLOR_CONFIG["default"],
115 | ]
116 | else:
117 | return []
118 |
119 |
120 | def get_mask_opacity(dataset_name: str):
121 | if "word" in dataset_name.lower():
122 | return [
123 |             # WORD
124 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1
125 | ]
126 | elif "flare" in dataset_name.lower():
127 | return [
128 |             # FLARE22
129 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1
130 | ]
131 | elif "abdomen" in dataset_name.lower():
132 | return [
133 | 1, 1, 1, 1, 1, 1
134 | ]
135 | elif "amos" in dataset_name.lower():
136 | return [
137 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0
138 | ]
139 | elif "luna" in dataset_name.lower():
140 | return [
141 | 0, 0, 1, 1, 0
142 | ]
143 | elif "spleen" in dataset_name.lower():
144 | return [
145 | 1
146 | ]
147 | elif "pancreas" in dataset_name.lower():
148 | return [
149 | 1, 0
150 | ]
151 | else:
152 | return []
153 |
154 |
155 | SMOOTH_FACTOR = 400
156 | MAX_LABEL_LENGTH = 10
157 |
158 | # renderer
159 | COMPARE = True
160 | RENDERER_BG_COLOR = (1., 1., 1.)
161 |
162 | # outline config
163 | SHOW_OUTLINE = True
164 | OUTLINE_COLOR = (0, 1, 1)
165 | OUTLINE_OPACITY = 0.2
166 |
167 | # transform config
168 | ROTATE_X = 270
169 | ROTATE_Y = 180
170 | ROTATE_Z = 0
171 | SCALE = (100, 100, 100)
172 | """
173 | Adjust the rotation and scale settings above for the best visualization of your data.
174 | """
175 | # axes config
176 | SHOW_AXES = True
177 | TOTAL_LENGTH = (50, 50, 50)
178 |
179 |
--------------------------------------------------------------------------------
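A small sketch of how the config above is typically consumed: foreground label i + 1 of a dataset is rendered with colors[i] at opacities[i]. Purely illustrative; the import path assumes the package layout shown in this repository.

from visualization import visualization_config as cfg

colors = cfg.get_mask_colors("WORD")
opacities = cfg.get_mask_opacity("WORD")
assert len(colors) == len(opacities)            # one entry per foreground label
for label_id, (rgb, alpha) in enumerate(zip(colors, opacities), start=1):
    print(label_id, rgb, alpha)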
/models/sam2/memory_attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Optional
8 |
9 | import torch
10 | from torch import nn, Tensor
11 |
12 | from models.sam2.sam.transformer import RoPEAttention
13 |
14 | from models.sam2.sam2_utils import get_activation_fn, get_clones
15 |
16 |
17 | class MemoryAttentionLayer(nn.Module):
18 |
19 | def __init__(
20 | self,
21 | activation: str,
22 | cross_attention: nn.Module,
23 | d_model: int,
24 | dim_feedforward: int,
25 | dropout: float,
26 | pos_enc_at_attn: bool,
27 | pos_enc_at_cross_attn_keys: bool,
28 | pos_enc_at_cross_attn_queries: bool,
29 | self_attention: nn.Module,
30 | ):
31 | super().__init__()
32 | self.d_model = d_model
33 | self.dim_feedforward = dim_feedforward
34 | self.dropout_value = dropout
35 | self.self_attn = self_attention
36 | self.cross_attn_image = cross_attention
37 |
38 | # Implementation of Feedforward model
39 | self.linear1 = nn.Linear(d_model, dim_feedforward)
40 | self.dropout = nn.Dropout(dropout)
41 | self.linear2 = nn.Linear(dim_feedforward, d_model)
42 |
43 | self.norm1 = nn.LayerNorm(d_model)
44 | self.norm2 = nn.LayerNorm(d_model)
45 | self.norm3 = nn.LayerNorm(d_model)
46 | self.dropout1 = nn.Dropout(dropout)
47 | self.dropout2 = nn.Dropout(dropout)
48 | self.dropout3 = nn.Dropout(dropout)
49 |
50 | self.activation_str = activation
51 | self.activation = get_activation_fn(activation)
52 |
53 | # Where to add pos enc
54 | self.pos_enc_at_attn = pos_enc_at_attn
55 | self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
56 | self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys
57 |
58 | def _forward_sa(self, tgt, query_pos):
59 | # Self-Attention
60 | tgt2 = self.norm1(tgt)
61 | q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
62 | tgt2 = self.self_attn(q, k, v=tgt2)
63 | tgt = tgt + self.dropout1(tgt2)
64 | return tgt
65 |
66 | def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
67 | kwds = {}
68 | if num_k_exclude_rope > 0:
69 | assert isinstance(self.cross_attn_image, RoPEAttention)
70 | kwds = {"num_k_exclude_rope": num_k_exclude_rope}
71 |
72 | # Cross-Attention
73 | tgt2 = self.norm2(tgt)
74 | tgt2 = self.cross_attn_image(
75 | q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
76 | k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
77 | v=memory,
78 | **kwds,
79 | )
80 | tgt = tgt + self.dropout2(tgt2)
81 | return tgt
82 |
83 | def forward(
84 | self,
85 | tgt,
86 | memory,
87 | pos: Optional[Tensor] = None,
88 | query_pos: Optional[Tensor] = None,
89 | num_k_exclude_rope: int = 0,
90 | ) -> torch.Tensor:
91 |
92 | # Self-Attn, Cross-Attn
93 | tgt = self._forward_sa(tgt, query_pos)
94 | tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
95 | # MLP
96 | tgt2 = self.norm3(tgt)
97 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
98 | tgt = tgt + self.dropout3(tgt2)
99 | return tgt
100 |
101 |
102 | class MemoryAttention(nn.Module):
103 | def __init__(
104 | self,
105 | d_model: int,
106 | pos_enc_at_input: bool,
107 | layer: nn.Module,
108 | num_layers: int,
109 | batch_first: bool = True, # Do layers expect batch first input?
110 | ):
111 | super().__init__()
112 | self.d_model = d_model
113 | self.layers = get_clones(layer, num_layers)
114 | self.num_layers = num_layers
115 | self.norm = nn.LayerNorm(d_model)
116 | self.pos_enc_at_input = pos_enc_at_input
117 | self.batch_first = batch_first
118 |
119 | def forward(
120 | self,
121 | curr: torch.Tensor, # self-attention inputs
122 | memory: torch.Tensor, # cross-attention inputs
123 | curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs
124 | memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs
125 | num_obj_ptr_tokens: int = 0, # number of object pointer *tokens*
126 | ):
127 | if isinstance(curr, list):
128 | assert isinstance(curr_pos, list)
129 | assert len(curr) == len(curr_pos) == 1
130 | curr, curr_pos = (
131 | curr[0],
132 | curr_pos[0],
133 | )
134 |
135 | assert (
136 | curr.shape[1] == memory.shape[1]
137 | ), "Batch size must be the same for curr and memory"
138 |
139 | output = curr
140 | if self.pos_enc_at_input and curr_pos is not None:
141 | output = output + 0.1 * curr_pos
142 |
143 | if self.batch_first:
144 | # Convert to batch first
145 | output = output.transpose(0, 1)
146 | curr_pos = curr_pos.transpose(0, 1)
147 | memory = memory.transpose(0, 1)
148 | memory_pos = memory_pos.transpose(0, 1)
149 |
150 | for layer in self.layers:
151 | kwds = {}
152 | if isinstance(layer.cross_attn_image, RoPEAttention):
153 | kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}
154 |
155 | output = layer(
156 | tgt=output,
157 | memory=memory,
158 | pos=memory_pos,
159 | query_pos=curr_pos,
160 | **kwds,
161 | )
162 | normed_output = self.norm(output)
163 |
164 | if self.batch_first:
165 | # Convert back to seq first
166 | normed_output = normed_output.transpose(0, 1)
167 | curr_pos = curr_pos.transpose(0, 1)
168 |
169 | return normed_output
--------------------------------------------------------------------------------
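A shape-level sketch of the memory-attention stack above. The real model wires in `RoPEAttention` through the Hydra config; to keep this example self-contained, a trivial stand-in attention module with the same call signature is used, so it is illustrative only and not the actual attention used by SAM 2.

import torch
from torch import nn
from models.sam2.memory_attention import MemoryAttention, MemoryAttentionLayer

class DummyAttention(nn.Module):
    """Stand-in with the same (q, k, v) call signature as the SAM 2 attention modules."""
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
    def forward(self, q, k, v, num_k_exclude_rope=0):
        # Broadcast a pooled value back to the query length -- not real attention.
        return self.proj(v.mean(dim=1, keepdim=True).expand_as(q))

d_model = 256
layer = MemoryAttentionLayer(
    activation="relu", cross_attention=DummyAttention(d_model), d_model=d_model,
    dim_feedforward=1024, dropout=0.1, pos_enc_at_attn=False,
    pos_enc_at_cross_attn_keys=True, pos_enc_at_cross_attn_queries=False,
    self_attention=DummyAttention(d_model),
)
attn = MemoryAttention(d_model=d_model, pos_enc_at_input=True, layer=layer, num_layers=2)

curr = torch.randn(1024, 1, d_model)      # (seq_len, batch, d_model) current-frame tokens
memory = torch.randn(2048, 1, d_model)    # concatenated memory tokens
out = attn(curr, memory,
           curr_pos=torch.randn_like(curr), memory_pos=torch.randn_like(memory))
print(out.shape)                          # torch.Size([1024, 1, 256])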
/models/sam2/memory_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 | from typing import Tuple
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 |
14 | from models.sam2.sam2_utils import DropPath, get_clones, LayerNorm2d
15 |
16 |
17 | class MaskDownSampler(nn.Module):
18 | """
19 | Progressively downsample a mask by total_stride, each time by stride.
20 | Note that LayerNorm is applied per *token*, like in ViT.
21 |
22 | With each downsample (by a factor stride**2), channel capacity increases by the same factor.
23 | In the end, we linearly project to embed_dim channels.
24 | """
25 |
26 | def __init__(
27 | self,
28 | embed_dim=256,
29 | kernel_size=4,
30 | stride=4,
31 | padding=0,
32 | total_stride=16,
33 | activation=nn.GELU,
34 | ):
35 | super().__init__()
36 | num_layers = int(math.log2(total_stride) // math.log2(stride))
37 | assert stride**num_layers == total_stride
38 | self.encoder = nn.Sequential()
39 | mask_in_chans, mask_out_chans = 1, 1
40 | for _ in range(num_layers):
41 | mask_out_chans = mask_in_chans * (stride**2)
42 | self.encoder.append(
43 | nn.Conv2d(
44 | mask_in_chans,
45 | mask_out_chans,
46 | kernel_size=kernel_size,
47 | stride=stride,
48 | padding=padding,
49 | )
50 | )
51 | self.encoder.append(LayerNorm2d(mask_out_chans))
52 | self.encoder.append(activation())
53 | mask_in_chans = mask_out_chans
54 |
55 | self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
56 |
57 | def forward(self, x):
58 | return self.encoder(x)
59 |
60 |
61 | # Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt)
62 | class CXBlock(nn.Module):
63 | r"""ConvNeXt Block. There are two equivalent implementations:
64 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
65 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
66 | We use (2) as we find it slightly faster in PyTorch
67 |
68 | Args:
69 | dim (int): Number of input channels.
70 | drop_path (float): Stochastic depth rate. Default: 0.0
71 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
72 | """
73 |
74 | def __init__(
75 | self,
76 | dim,
77 | kernel_size=7,
78 | padding=3,
79 | drop_path=0.0,
80 | layer_scale_init_value=1e-6,
81 | use_dwconv=True,
82 | ):
83 | super().__init__()
84 | self.dwconv = nn.Conv2d(
85 | dim,
86 | dim,
87 | kernel_size=kernel_size,
88 | padding=padding,
89 | groups=dim if use_dwconv else 1,
90 | ) # depthwise conv
91 | self.norm = LayerNorm2d(dim, eps=1e-6)
92 | self.pwconv1 = nn.Linear(
93 | dim, 4 * dim
94 | ) # pointwise/1x1 convs, implemented with linear layers
95 | self.act = nn.GELU()
96 | self.pwconv2 = nn.Linear(4 * dim, dim)
97 | self.gamma = (
98 | nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
99 | if layer_scale_init_value > 0
100 | else None
101 | )
102 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
103 |
104 | def forward(self, x):
105 | input = x
106 | x = self.dwconv(x)
107 | x = self.norm(x)
108 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
109 | x = self.pwconv1(x)
110 | x = self.act(x)
111 | x = self.pwconv2(x)
112 | if self.gamma is not None:
113 | x = self.gamma * x
114 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
115 |
116 | x = input + self.drop_path(x)
117 | return x
118 |
119 |
120 | class Fuser(nn.Module):
121 | def __init__(self, layer, num_layers, dim=None, input_projection=False):
122 | super().__init__()
123 | self.proj = nn.Identity()
124 | self.layers = get_clones(layer, num_layers)
125 |
126 | if input_projection:
127 | assert dim is not None
128 | self.proj = nn.Conv2d(dim, dim, kernel_size=1)
129 |
130 | def forward(self, x):
131 | # normally x: (N, C, H, W)
132 | x = self.proj(x)
133 | for layer in self.layers:
134 | x = layer(x)
135 | return x
136 |
137 |
138 | class MemoryEncoder(nn.Module):
139 | def __init__(
140 | self,
141 | out_dim,
142 | mask_downsampler,
143 | fuser,
144 | position_encoding,
145 | in_dim=256, # in_dim of pix_feats
146 | ):
147 | super().__init__()
148 |
149 | self.mask_downsampler = mask_downsampler
150 |
151 | self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
152 | self.fuser = fuser
153 | self.position_encoding = position_encoding
154 | self.out_proj = nn.Identity()
155 | if out_dim != in_dim:
156 | self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
157 |
158 | def forward(
159 | self,
160 | pix_feat: torch.Tensor,
161 | masks: torch.Tensor,
162 | skip_mask_sigmoid: bool = False,
163 | ) -> Tuple[torch.Tensor, torch.Tensor]:
164 | ## Process masks
165 |         # apply sigmoid so there is less domain shift from GT masks, which are boolean
166 | if not skip_mask_sigmoid:
167 | masks = F.sigmoid(masks)
168 | masks = self.mask_downsampler(masks)
169 |
170 | ## Fuse pix_feats and downsampled masks
171 | # in case the visual features are on CPU, cast them to CUDA
172 | pix_feat = pix_feat.to(masks.device)
173 |
174 | x = self.pix_feat_proj(pix_feat)
175 | x = x + masks
176 | x = self.fuser(x)
177 | x = self.out_proj(x)
178 |
179 | pos = self.position_encoding(x).to(x.dtype)
180 |
181 | return {"vision_features": x, "vision_pos_enc": [pos]}
--------------------------------------------------------------------------------
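A quick shape sanity-check for the mask-downsampling and ConvNeXt-style blocks above (the embedding dimension and mask size are illustrative values, not the model's configured ones).

import torch
from models.sam2.memory_encoder import CXBlock, MaskDownSampler

downsampler = MaskDownSampler(embed_dim=64, kernel_size=4, stride=4,
                              padding=0, total_stride=16)
masks = torch.randn(2, 1, 1024, 1024)       # (N, 1, H, W) mask logits
feats = downsampler(masks)
print(feats.shape)                          # torch.Size([2, 64, 64, 64]): H and W reduced by 16

block = CXBlock(dim=64, kernel_size=7, padding=3)
print(block(feats).shape)                   # residual block keeps torch.Size([2, 64, 64, 64])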
/models/components/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 - 2021 MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | import math
13 | import warnings
14 | from typing import List
15 |
16 | from torch.optim import Optimizer
17 | from torch.optim.lr_scheduler import LambdaLR, _LRScheduler
18 | 
19 |
20 |
21 | __all__ = ["LinearLR", "ExponentialLR"]
22 |
23 |
24 | class _LRSchedulerMONAI(_LRScheduler):
25 | """Base class for increasing the learning rate between two boundaries over a number
26 | of iterations"""
27 |
28 | def __init__(self, optimizer: Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1) -> None:
29 | """
30 | Args:
31 | optimizer: wrapped optimizer.
32 | end_lr: the final learning rate.
33 | num_iter: the number of iterations over which the test occurs.
34 | last_epoch: the index of last epoch.
35 | Returns:
36 | None
37 | """
38 | self.end_lr = end_lr
39 | self.num_iter = num_iter
40 | super(_LRSchedulerMONAI, self).__init__(optimizer, last_epoch)
41 |
42 |
43 | class LinearLR(_LRSchedulerMONAI):
44 | """Linearly increases the learning rate between two boundaries over a number of
45 | iterations.
46 | """
47 |
48 | def get_lr(self):
49 | r = self.last_epoch / (self.num_iter - 1)
50 | return [base_lr + r * (self.end_lr - base_lr) for base_lr in self.base_lrs]
51 |
52 |
53 | class ExponentialLR(_LRSchedulerMONAI):
54 | """Exponentially increases the learning rate between two boundaries over a number of
55 | iterations.
56 | """
57 |
58 | def get_lr(self):
59 | r = self.last_epoch / (self.num_iter - 1)
60 | return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs]
61 |
62 |
63 | class WarmupCosineSchedule(LambdaLR):
64 | """Linear warmup and then cosine decay.
65 | Based on https://huggingface.co/ implementation.
66 | """
67 |
68 | def __init__(
69 | self, optimizer: Optimizer, warmup_steps: int, t_total: int, cycles: float = 0.5, last_epoch: int = -1
70 | ) -> None:
71 | """
72 | Args:
73 | optimizer: wrapped optimizer.
74 | warmup_steps: number of warmup iterations.
75 |             t_total: total number of training iterations.
76 | cycles: cosine cycles parameter.
77 | last_epoch: the index of last epoch.
78 | Returns:
79 | None
80 | """
81 | self.warmup_steps = warmup_steps
82 | self.t_total = t_total
83 | self.cycles = cycles
84 | super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch)
85 |
86 | def lr_lambda(self, step):
87 | if step < self.warmup_steps:
88 | return float(step) / float(max(1.0, self.warmup_steps))
89 | progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))
90 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.cycles) * 2.0 * progress)))
91 |
92 | class LinearWarmupCosineAnnealingLR(_LRScheduler):
93 |
94 | def __init__(
95 | self,
96 | optimizer: Optimizer,
97 | warmup_epochs: int,
98 | max_epochs: int,
99 | warmup_start_lr: float = 0.0,
100 | eta_min: float = 0.0,
101 | last_epoch: int = -1,
102 | ) -> None:
103 | """
104 | Args:
105 | optimizer (Optimizer): Wrapped optimizer.
106 | warmup_epochs (int): Maximum number of iterations for linear warmup
107 | max_epochs (int): Maximum number of iterations
108 | warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0.
109 | eta_min (float): Minimum learning rate. Default: 0.
110 | last_epoch (int): The index of last epoch. Default: -1.
111 | """
112 | self.warmup_epochs = warmup_epochs
113 | self.max_epochs = max_epochs
114 | self.warmup_start_lr = warmup_start_lr
115 | self.eta_min = eta_min
116 |
117 | super(LinearWarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch)
118 |
119 | def get_lr(self) -> List[float]:
120 | """
121 | Compute learning rate using chainable form of the scheduler
122 | """
123 | if not self._get_lr_called_within_step:
124 | warnings.warn(
125 | "To get the last learning rate computed by the scheduler, "
126 | "please use `get_last_lr()`.",
127 | UserWarning,
128 | )
129 |
130 | if self.last_epoch == 0:
131 | return [self.warmup_start_lr] * len(self.base_lrs)
132 | elif self.last_epoch < self.warmup_epochs:
133 | return [
134 | group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
135 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
136 | ]
137 | elif self.last_epoch == self.warmup_epochs:
138 | return self.base_lrs
139 | else:
140 | return [
141 | self.eta_min + 0.5 * (base_lr - self.eta_min) *
142 | (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
143 | for base_lr in self.base_lrs
144 | ]
145 |
146 | def _get_closed_form_lr(self) -> List[float]:
147 | """
148 | Called when epoch is passed as a param to the `step` function of the scheduler.
149 | """
150 | if self.last_epoch < self.warmup_epochs:
151 | return [
152 | self.warmup_start_lr + self.last_epoch * (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
153 | for base_lr in self.base_lrs
154 | ]
155 |
156 | return [
157 | self.eta_min + 0.5 * (base_lr - self.eta_min) *
158 | (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
159 | for base_lr in self.base_lrs
160 | ]
--------------------------------------------------------------------------------
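A minimal sketch of how the warmup-cosine scheduler above is driven; the parameter, optimizer, and epoch counts are placeholders.

import torch
from models.components.lr_scheduler import LinearWarmupCosineAnnealingLR

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-4)
scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=100)

for epoch in range(100):
    # ... run one training epoch here, calling optimizer.step() on each batch ...
    scheduler.step()
    if epoch in (0, 9, 49, 99):
        print(epoch, scheduler.get_last_lr())   # linear warmup to 1e-4, then cosine decay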
/models/sam2/build_sam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 | import os
9 |
10 | import torch
11 | from hydra import compose
12 | from hydra.utils import instantiate
13 | from omegaconf import OmegaConf
14 |
15 | import models.sam2 as sam2
16 |
17 | # Check if the user is running Python from the parent directory of the sam2 repo
18 | # (i.e. the directory where this repo is cloned into) -- this is not supported since
19 | # it could shadow the sam2 package and cause issues.
20 | if os.path.isdir(os.path.join(sam2.__path__[0], "sam2")):
21 |     # If the user has "sam2/sam2" in their path, they are likely importing the repo itself
22 | # as "sam2" rather than importing the "sam2" python package (i.e. "sam2/sam2" directory).
23 | # This typically happens because the user is running Python from the parent directory
24 | # that contains the sam2 repo they cloned.
25 | raise RuntimeError(
26 | "You're likely running Python from the parent directory of the sam2 repository "
27 | "(i.e. the directory where https://github.com/facebookresearch/sam2 is cloned into). "
28 | "This is not supported since the `sam2` Python package could be shadowed by the "
29 | "repository name (the repository is also named `sam2` and contains the Python package "
30 | "in `sam2/sam2`). Please run Python from another directory (e.g. from the repo dir "
31 | "rather than its parent dir, or from your home directory) after installing SAM 2."
32 | )
33 |
34 |
35 | HF_MODEL_ID_TO_FILENAMES = {
36 | "facebook/sam2-hiera-tiny": (
37 | "configs/sam2/sam2_hiera_t.yaml",
38 | "sam2_hiera_tiny.pt",
39 | ),
40 | "facebook/sam2-hiera-small": (
41 | "configs/sam2/sam2_hiera_s.yaml",
42 | "sam2_hiera_small.pt",
43 | ),
44 | "facebook/sam2-hiera-base-plus": (
45 | "configs/sam2/sam2_hiera_b+.yaml",
46 | "sam2_hiera_base_plus.pt",
47 | ),
48 | "facebook/sam2-hiera-large": (
49 | "configs/sam2/sam2_hiera_l.yaml",
50 | "sam2_hiera_large.pt",
51 | ),
52 | "facebook/sam2.1-hiera-tiny": (
53 | "configs/sam2.1/sam2.1_hiera_t.yaml",
54 | "sam2.1_hiera_tiny.pt",
55 | ),
56 | "facebook/sam2.1-hiera-small": (
57 | "configs/sam2.1/sam2.1_hiera_s.yaml",
58 | "sam2.1_hiera_small.pt",
59 | ),
60 | "facebook/sam2.1-hiera-base-plus": (
61 | "configs/sam2.1/sam2.1_hiera_b+.yaml",
62 | "sam2.1_hiera_base_plus.pt",
63 | ),
64 | "facebook/sam2.1-hiera-large": (
65 | "configs/sam2.1/sam2.1_hiera_l.yaml",
66 | "sam2.1_hiera_large.pt",
67 | ),
68 | }
69 |
70 |
71 | def build_sam2(
72 | config_file,
73 | ckpt_path=None,
74 | device="cuda",
75 | mode="eval",
76 | hydra_overrides_extra=[],
77 | apply_postprocessing=True,
78 | **kwargs,
79 | ):
80 |
81 | if apply_postprocessing:
82 | hydra_overrides_extra = hydra_overrides_extra.copy()
83 | hydra_overrides_extra += [
84 | # dynamically fall back to multi-mask if the single mask is not stable
85 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
86 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
87 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
88 | ]
89 | # Read config and init model
90 | cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
91 | OmegaConf.resolve(cfg)
92 | model = instantiate(cfg.model, _recursive_=True)
93 | _load_checkpoint(model, ckpt_path)
94 | model = model.to(device)
95 | if mode == "eval":
96 | model.eval()
97 | return model
98 |
99 |
100 | def build_sam2_video_predictor(
101 | config_file,
102 | ckpt_path=None,
103 | device="cuda",
104 | mode="eval",
105 | hydra_overrides_extra=[],
106 | apply_postprocessing=True,
107 | vos_optimized=False,
108 | **kwargs,
109 | ):
110 | hydra_overrides = [
111 | "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
112 | ]
113 | if vos_optimized:
114 | hydra_overrides = [
115 | "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictorVOS",
116 | "++model.compile_image_encoder=True", # Let sam2_base handle this
117 | ]
118 |
119 | if apply_postprocessing:
120 | hydra_overrides_extra = hydra_overrides_extra.copy()
121 | hydra_overrides_extra += [
122 | # dynamically fall back to multi-mask if the single mask is not stable
123 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
124 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
125 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
126 |             # binarize the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks exactly match what users see from clicking
127 | "++model.binarize_mask_from_pts_for_mem_enc=true",
128 | # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
129 | "++model.fill_hole_area=8",
130 | ]
131 | hydra_overrides.extend(hydra_overrides_extra)
132 |
133 | # Read config and init model
134 | cfg = compose(config_name=config_file, overrides=hydra_overrides)
135 | OmegaConf.resolve(cfg)
136 | model = instantiate(cfg.model, _recursive_=True)
137 | _load_checkpoint(model, ckpt_path)
138 | model = model.to(device)
139 | if mode == "eval":
140 | model.eval()
141 | return model
142 |
143 |
144 | def _hf_download(model_id):
145 | from huggingface_hub import hf_hub_download
146 |
147 | config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id]
148 | ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
149 | return config_name, ckpt_path
150 |
151 |
152 | def build_sam2_hf(model_id, **kwargs):
153 | config_name, ckpt_path = _hf_download(model_id)
154 | return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs)
155 |
156 |
157 | def build_sam2_video_predictor_hf(model_id, **kwargs):
158 | config_name, ckpt_path = _hf_download(model_id)
159 | return build_sam2_video_predictor(
160 | config_file=config_name, ckpt_path=ckpt_path, **kwargs
161 | )
162 |
163 |
164 | def _load_checkpoint(model, ckpt_path):
165 | if ckpt_path is not None:
166 | sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
167 | missing_keys, unexpected_keys = model.load_state_dict(sd)
168 | if missing_keys:
169 | logging.error(missing_keys)
170 | raise RuntimeError()
171 | if unexpected_keys:
172 | logging.error(unexpected_keys)
173 | raise RuntimeError()
174 |     logging.info("Loaded checkpoint successfully")
--------------------------------------------------------------------------------
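A hedged usage sketch for the builders above. The config name matches the YAML files shipped under models/sam2/configs, but the checkpoint path is a placeholder, and Hydra must already be initialized on this package's config directory before compose() can resolve the name.

import torch
from models.sam2.build_sam import build_sam2

device = "cuda" if torch.cuda.is_available() else "cpu"
model = build_sam2(
    config_file="configs/sam2/sam2_hiera_t.yaml",   # shipped with this repo
    ckpt_path="checkpoints/sam2_hiera_tiny.pt",     # placeholder checkpoint path
    device=device,
    mode="eval",                                    # puts the model in eval mode
)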
/datasets/data_process.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import multiprocessing
4 | import argparse
5 | from scipy import sparse
6 | from sklearn.model_selection import train_test_split
7 | import json
8 |
9 | from monai.transforms import (
10 | AddChanneld,
11 | Compose,
12 | LoadImaged,
13 | Orientationd,
14 | )
15 |
16 | def set_parse():
17 | # %% set up parser
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument("--category", default=["liver", "spleen", "stomach", "gallbladder", "esophagus", "pancreas", "duodenum", "aorta", "bladder", "inferior vena cava", "left kidney", "right kidney", "left adrenal gland", "right adrenal gland", "left femur", "right femur", "left lung", "right lung"], type=list)
20 | parser.add_argument("--image_dir", type=str, required=True)
21 | parser.add_argument("--label_dir", type=str, required=True)
22 | parser.add_argument("--dataset_code", type=str, required=True)
23 | parser.add_argument("--save_root", type=str, required=True)
24 | parser.add_argument("--test_ratio", type=float, required=True)
25 |
26 | args = parser.parse_args()
27 | return args
28 |
29 | args = set_parse()
30 |
31 | # gather image and label file names
32 | image_list_all = [item for item in sorted(os.listdir(args.image_dir))]
33 | label_list_all = [item for item in sorted(os.listdir(args.label_dir))]
34 | assert len(image_list_all) == len(label_list_all)
35 | print('dataset size ', len(image_list_all))
36 |
37 | # build dataset
38 | data_path_list_all = []
39 | for idx in range(len(image_list_all)):
40 | img_path = os.path.join(args.image_dir, image_list_all[idx])
41 | label_path = os.path.join(args.label_dir, label_list_all[idx])
42 | name = image_list_all[idx].split('.')[0]
43 | info = (idx, name, img_path, label_path)
44 | data_path_list_all.append(info)
45 |
46 | img_loader = Compose(
47 | [
48 | LoadImaged(keys=['image', 'label']),
49 | AddChanneld(keys=['image', 'label']),
50 | Orientationd(keys=['image', 'label'], axcodes="RAS"),
51 | ]
52 | )
53 |
54 | # save
55 | save_path = os.path.join(args.save_root, args.dataset_code)
56 | ct_save_path = os.path.join(save_path, 'ct')
57 | gt_save_path = os.path.join(save_path, 'gt')
58 | if not os.path.exists(ct_save_path):
59 | os.makedirs(ct_save_path)
60 | if not os.path.exists(gt_save_path):
61 | os.makedirs(gt_save_path)
62 |
63 | # already-processed files:
64 | exist_file_list = os.listdir(ct_save_path)
65 | print('exist_file_list ', exist_file_list)
66 |
67 | def normalize(ct_narray):
68 | ct_voxel_ndarray = ct_narray.copy()
69 | ct_voxel_ndarray = ct_voxel_ndarray.flatten()
70 | # for all data
71 | thred = np.mean(ct_voxel_ndarray)
72 | voxel_filtered = ct_voxel_ndarray[(ct_voxel_ndarray > thred)]
73 | # for foreground data
74 | upper_bound = np.percentile(voxel_filtered, 99.95)
75 |     lower_bound = np.percentile(voxel_filtered, 0.05)
76 | mean = np.mean(voxel_filtered)
77 | std = np.std(voxel_filtered)
78 | ### transform ###
79 | ct_narray = np.clip(ct_narray, lower_bound, upper_bound)
80 | ct_narray = (ct_narray - mean) / max(std, 1e-8)
81 | return ct_narray
82 |
83 | def run(info):
84 | idx, file_name, case_path, label_path = info
85 | item = {}
86 | if file_name + '.npy' in exist_file_list:
87 | print(file_name + '.npy exist, skip')
88 | return
89 | print('process ', idx, '---' ,file_name)
90 | # generate ct_voxel_ndarray
91 | item_load = {
92 | 'image' : case_path,
93 | 'label' : label_path,
94 | }
95 | item_load = img_loader(item_load)
96 | ct_voxel_ndarray = item_load['image']
97 | gt_voxel_ndarray = item_load['label']
98 |
99 | ct_shape = ct_voxel_ndarray.shape
100 | item['image'] = ct_voxel_ndarray
101 |
102 | # generate gt_voxel_ndarray
103 | gt_voxel_ndarray = np.array(gt_voxel_ndarray).squeeze()
104 | present_categories = np.unique(gt_voxel_ndarray)
105 | gt_masks = []
106 | for cls_idx in range(len(args.category)):
107 | cls = cls_idx + 1
108 | if cls not in present_categories:
109 | gt_voxel_ndarray_category = np.zeros(ct_shape)
110 | gt_masks.append(gt_voxel_ndarray_category)
111 | print('case {} ==> zero category '.format(idx) + args.category[cls_idx])
112 | print(gt_voxel_ndarray_category.shape)
113 | else:
114 | gt_voxel_ndarray_category = gt_voxel_ndarray.copy()
115 | gt_voxel_ndarray_category[gt_voxel_ndarray != cls] = 0
116 | gt_voxel_ndarray_category[gt_voxel_ndarray == cls] = 1
117 | gt_masks.append(gt_voxel_ndarray_category)
118 | gt_voxel_ndarray = np.stack(gt_masks, axis=0)
119 |
120 | assert gt_voxel_ndarray.shape[0] == len(args.category), str(gt_voxel_ndarray.shape[0])
121 | assert gt_voxel_ndarray.shape[1:] == ct_voxel_ndarray.shape[1:]
122 | item['label'] = gt_voxel_ndarray.astype(np.int32)
123 | print(idx, ' load done!')
124 |
125 | #############################
126 | item['image'] = normalize(item['image'])
127 | print(idx, ' transform done')
128 |
129 | ############################
130 | print(file_name + ' ct gt <--> ', item['image'].shape, item['label'].shape)
131 | np.save(os.path.join(ct_save_path, file_name + '.npy'), item['image'])
132 | allmatrix_sp=sparse.csr_matrix(item['label'].reshape(item['label'].shape[0], -1))
133 | sparse.save_npz(os.path.join(gt_save_path, file_name + '.' + str(item['label'].shape)), allmatrix_sp)
134 | print(file_name + ' save done!')
135 |
136 | def generate_dataset_json(root_dir, output_file, test_ratio=0.2):
137 | ct_dir = os.path.join(root_dir, 'ct')
138 | gt_dir = os.path.join(root_dir, 'gt')
139 | ct_paths = sorted([os.path.join(ct_dir, f) for f in sorted(os.listdir(ct_dir))])
140 | gt_paths = sorted([os.path.join(gt_dir, f) for f in sorted(os.listdir(gt_dir))])
141 |
142 | data = list(zip(ct_paths, gt_paths))
143 | train_data, val_data = train_test_split(data, test_size=test_ratio)
144 | labels = {}
145 | labels['0'] = 'background'
146 | for idx in range(len(args.category)):
147 | label_name = args.category[idx]
148 | label_id = idx + 1
149 | labels[str(label_id)] = label_name
150 | dataset = {
151 | 'name': f'{args.dataset_code} Dataset',
152 | 'description': f'{args.dataset_code} Dataset',
153 | 'tensorImageSize': '4D',
154 | 'modality': {
155 | '0': 'CT',
156 | },
157 | 'labels': labels,
158 | 'numTraining': len(train_data),
159 | 'numTest': len(val_data),
160 | 'segment_anything_training': [{'image': ct_path, 'label': gt_path} for ct_path, gt_path in train_data],
161 | 'validation': [{'image': ct_path, 'label': gt_path} for ct_path, gt_path in val_data]
162 | }
163 | with open(output_file, 'w') as f:
164 | print(f'{output_file} dump')
165 | json.dump(dataset, f, indent=2)
166 |
167 | if __name__ == "__main__":
168 | with multiprocessing.Pool(processes=10) as pool:
169 | pool.map(run, data_path_list_all)
170 | print('Process Finished!')
171 |
172 | generate_dataset_json(root_dir=save_path,
173 | output_file=os.path.join(save_path, f'{args.dataset_code}.json'),
174 | test_ratio=args.test_ratio)
175 | print('Json Split Done!')
176 |
--------------------------------------------------------------------------------
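An illustrative follow-up showing how the JSON split written by generate_dataset_json above can be read back; the save root and dataset code are placeholders.

import json
import os

save_root, dataset_code = "./processed", "WORD"    # placeholders
with open(os.path.join(save_root, dataset_code, f"{dataset_code}.json")) as f:
    dataset = json.load(f)

print(dataset["labels"]["1"])                      # first foreground category name
print(len(dataset["segment_anything_training"]),   # number of training pairs
      len(dataset["validation"]))                  # number of validation pairs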
/env_config.yml:
--------------------------------------------------------------------------------
1 | name: CRISP_SAM2
2 | channels:
3 | - conda-forge
4 | - defaults
5 | dependencies:
6 | - _libgcc_mutex=0.1=main
7 | - _openmp_mutex=5.1=1_gnu
8 | - bzip2=1.0.8=h5eee18b_6
9 | - ca-certificates=2024.3.11=h06a4308_0
10 | - debugpy=1.6.7=py312h6a678d5_0
11 | - importlib_metadata=8.2.0=hd8ed1ab_0
12 | - ld_impl_linux-64=2.38=h1181459_1
13 | - libffi=3.4.4=h6a678d5_0
14 | - libgcc-ng=11.2.0=h1234567_1
15 | - libgomp=11.2.0=h1234567_1
16 | - libstdcxx-ng=11.2.0=h1234567_1
17 | - ncurses=6.4=h6a678d5_0
18 | - openssl=3.0.14=h5eee18b_0
19 | - packaging=24.1=pyhd8ed1ab_0
20 | - pip=24.0=py312h06a4308_0
21 | - python=3.12.4=h5148396_1
22 | - readline=8.2=h5eee18b_0
23 | - setuptools=69.5.1=py312h06a4308_0
24 | - sqlite=3.45.3=h5eee18b_0
25 | - tk=8.6.14=h39e8969_0
26 | - wheel=0.43.0=py312h06a4308_0
27 | - xz=5.4.6=h5eee18b_1
28 | - zlib=1.2.13=h5eee18b_0
29 | - pip:
30 | - absl-py==2.1.0
31 | - accelerate==0.31.0
32 | - aiohttp==3.9.5
33 | - aiosignal==1.3.1
34 |       - albumentations==0.3.2
35 | - alembic==1.13.1
36 | - aniso8601==9.0.1
37 | - annotated-types==0.7.0
38 | - appdirs==1.4.4
39 | - asciitree==0.3.3
40 | - astor==0.8.1
41 | - async-timeout==4.0.3
42 | - attrs==23.2.0
43 | - Authlib==1.3.1
44 | - bert-score==0.3.13
45 | - blinker==1.7.0
46 | - blis==0.7.11
47 | - catalogue==2.0.10
48 | - certifi==2022.12.7
49 | - cffi==1.16.0
50 | - charset-normalizer==2.1.1
51 | - clearml==1.14.5rc0
52 | - click==8.1.7
53 | - cloudpathlib==0.18.1
54 | - cloudpickle==3.0.0
55 | - cmake==3.28.4
56 | - colorama==0.4.6
57 | - coloredlogs==15.0.1
58 | - colorlog==6.8.2
59 | - confection==0.1.5
60 | - connected-components-3d==3.14.0
61 | - contextlib2==21.6.0
62 | - contourpy==1.2.0
63 | - crcmod==1.7
64 | - cryptography==42.0.8
65 | - cucim==23.10.0
66 | - cycler==0.12.1
67 | - cymem==2.0.8
68 | - datasets==2.19.1
69 | - Deprecated==1.2.14
70 | - dill==0.3.8
71 | - django-s3-file-field-client==1.0.1
72 | - docker==7.0.0
73 | - docutils==0.16.0
74 | - dotmap==1.3.30
75 | - dynamic-network-architectures==0.3.1
76 | - edt==2.4.0
77 | - einops==0.7.0
78 | - entrypoints==0.4
79 | - environs==11.0.0
80 | - eval_type_backport==0.2.0
81 | - fasteners==0.19
82 | - fastremap==1.14.1
83 | - filelock==3.11.0
84 | - fire==0.6.0
85 | - flake8==5.0.4
86 | - flash-attn==2.5.6
87 | - Flask==3.0.2
88 | - flatbuffers==24.3.7
89 | - fonttools==4.50.0
90 | - frozenlist==1.4.1
91 | - fsspec==2024.3.1
92 | - ftfy==6.2.3
93 | - furl==2.1.3
94 | - fvcore==0.1.5.post20221221
95 | - gdown==5.1.0
96 | - girder-cli-oauth-client==0.4.0
97 | - gitdb==4.0.11
98 | - GitPython==3.1.42
99 | - graphene==3.3
100 | - graphql-core==3.2.3
101 | - graphql-relay==3.2.0
102 | - greenlet==3.0.3
103 | - grpcio==1.62.1
104 | - gunicorn==21.2.0
105 | - h5py==3.10.0
106 | - hf_transfer==0.1.8
107 | - httpx==0.27.0
108 | - huggingface-hub==0.23.4
109 | - humanfriendly==10.0
110 | - humanize==4.9.0
111 | - idna==3.4
112 | - imagecodecs==2024.1.1
113 | - imageio==2.34.0
114 | - importlib_metadata==7.1.0
115 | - importlib_resources==6.4.0
116 | - iopath==0.1.10
117 | - isic-cli==0.0.0
118 | - isic-metadata==0.0.0
119 | - itk==5.3.0
120 | - itk-core==5.3.0
121 | - itk-filtering==5.3.0
122 | - itk-io==5.3.0
123 | - itk-numerics==5.3.0
124 | - itk-registration==5.3.0
125 | - itk-segmentation==5.3.0
126 | - itsdangerous==2.1.2
127 | - jinja2==3.1.4
128 | - joblib==1.4.2
129 | - json5==0.9.25
130 | - jsonpointer==3.0.0
131 | - jsonschema==4.23.0
132 | - jsonschema-specifications==2023.12.1
133 | - kiwisolver==1.4.5
134 | - lazy-loader==0.4
135 | - linecache2==1.0.0
136 | - markupsafe==2.1.5
137 | - matplotlib==3.9.1
138 | - mistune==3.0.2
139 |       - mmcv==2.0.0rc4
140 |       - mmengine==0.7.1
141 | - monai==1.3.2
142 | - mpmath==1.3.0
143 | - nbclient==0.10.0
144 | - nbconvert==7.16.4
145 | - nbformat==5.10.4
146 | - networkx==3.3
147 | - nibabel==5.2.1
148 | - nnunetv2==2.5.1
149 | - notebook-shim==0.2.4
150 | - numpy==2.0.1
151 | - nvidia-cublas-cu12==12.1.3.1
152 | - nvidia-cuda-cupti-cu12==12.1.105
153 | - nvidia-cuda-nvrtc-cu12==12.1.105
154 | - nvidia-cuda-runtime-cu12==12.1.105
155 | - nvidia-cudnn-cu12==9.1.0.70
156 | - nvidia-cufft-cu12==11.0.2.54
157 | - nvidia-curand-cu12==10.3.2.106
158 | - nvidia-cusolver-cu12==11.4.5.107
159 | - nvidia-cusparse-cu12==12.1.0.106
160 | - nvidia-nccl-cu12==2.20.5
161 | - nvidia-nvjitlink-cu12==12.5.82
162 | - nvidia-nvtx-cu12==12.1.105
163 | - omegaconf==2.3.0
164 | - opencv-python==4.10.0.84
165 | - opentelemetry-api==1.21.0
166 | - opentelemetry-exporter-otlp-proto-common==1.21.0
167 | - opentelemetry-exporter-otlp-proto-http==1.21.0
168 | - opentelemetry-proto==1.21.0
169 | - opentelemetry-sdk==1.21.0
170 | - opentelemetry-semantic-conventions==0.42b0
171 | - overrides==7.7.0
172 | - pandas==2.2.1
173 | - pandocfilters==1.5.1
174 | - pillow==10.4.0
175 | - portalocker==2.10.1
176 | - prometheus-client==0.20.0
177 | - protobuf==4.25.4
178 | - psutil==5.9.8
179 | - pycparser==2.22
180 | - pydantic==2.7.4
181 | - pydantic_core==2.18.4
182 | - pydicom==2.4.4
183 | - pyparsing==3.1.2
184 | - python-dateutil==2.9.0.post0
185 | - python-gdcm==3.0.24.1
186 | - graphviz==0.20.3
187 | - python-json-logger==2.0.7
188 | - pytorch-ignite==0.4.11
189 | - pytorch-lightning==2.2.5
190 | - pytz==2024.1
191 | - pyyaml==6.0.1
192 | - referencing==0.35.1
193 | - requests==2.32.3
194 | - rfc3339-validator==0.1.4
195 | - rfc3986-validator==0.1.1
196 | - rpds-py==0.19.1
197 | - scipy==1.12.0
198 | - simpleitk==2.3.1
199 | - sniffio==1.3.1
200 | - soupsieve==2.5
201 | - sphinx==4.0.2
202 | - sphinx-copybutton
203 | - sphinx_markdown_tables
204 | - sphinx_rtd_theme==0.5.2
205 | - sympy==1.13.1
206 | - synapseclient==4.4.0
207 | - tensorboard==2.16.2
208 | - tensorboard-data-server==0.7.2
209 | - tensorboardX==2.6.2.2
210 | - termcolor==2.4.0
211 | - terminado==0.18.1
212 | - threadpoolctl==3.5.0
213 | - tifffile==2024.7.24
214 | - tinycss2==1.3.0
215 | - torch==2.4.0
216 | - torchaudio==2.4.0
217 | - torchsummary==1.5.1
218 | - torchvision==0.19.0
219 | - torchmetrics==1.4.0.post0
220 | - tqdm==4.66.4
221 | - traceback2==1.4.0
222 | - triton==3.0.0
223 | - types-python-dateutil==2.9.0.20240316
224 | - tzdata==2024.1
225 | - unittest2==1.1.0
226 | - uri-template==1.3.0
227 | - urllib3==1.26.19
228 | - webcolors==24.6.0
229 | - webencodings==0.5.1
230 | - websocket-client==1.8.0
231 | - widgetsnbextension==4.0.11
232 | - wrapt==1.16.0
233 | - yacs==0.1.8
--------------------------------------------------------------------------------
/datasets/train_data_prepare.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import multiprocessing
4 | import argparse
5 | from scipy import sparse
6 | from sklearn.model_selection import train_test_split
7 | import json
8 | join = os.path.join
9 |
10 | from monai.transforms import (
11 | AddChanneld,
12 | Compose,
13 | LoadImaged,
14 | Orientationd,
15 | )
16 |
17 | def set_parse():
18 | # %% set up parser
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument("--category", default=["liver", "spleen", "stomach", "gallbladder", "esophagus", "pancreas", "duodenum", "aorta", "bladder", "inferior vena cava", "left kidney", "right kidney", "left adrenal gland", "right adrenal gland", "left femur", "right femur", "left lung", "right lung"], type=list)
21 | parser.add_argument("--image_dir", type=str, required=True)
22 | parser.add_argument("--label_dir", type=str, required=True)
23 | parser.add_argument("--dataset_code", type=str, required=True)
24 | parser.add_argument("--save_root", type=str, required=True)
25 | parser.add_argument("--test_ratio", type=float, required=True)
26 |
27 | args = parser.parse_args()
28 | return args
29 |
30 | args = set_parse()
31 |
32 | # gather image and label file names
33 | image_list_all = [item for item in sorted(os.listdir(args.image_dir))]
34 | label_list_all = [item for item in sorted(os.listdir(args.label_dir))]
35 | assert len(image_list_all) == len(label_list_all)
36 | print('dataset size ', len(image_list_all))
37 |
38 | # build dataset
39 | data_path_list_all = []
40 | for idx in range(len(image_list_all)):
41 | img_path = join(args.image_dir, image_list_all[idx])
42 | label_path = join(args.label_dir, label_list_all[idx])
43 | name = image_list_all[idx].split('.')[0]
44 | info = (idx, name, img_path, label_path)
45 | data_path_list_all.append(info)
46 |
47 | img_loader = Compose(
48 | [
49 | LoadImaged(keys=['image', 'label']),
50 | AddChanneld(keys=['image', 'label']),
51 | # Orientationd(keys=['image', 'label'], axcodes="RAS"),
52 | ]
53 | )
54 |
55 | # save
56 | save_path = join(args.save_root, args.dataset_code)
57 | os.makedirs(save_path, exist_ok=True)
58 |
59 | # ct_save_path = join(save_path, 'ct')
60 | # gt_save_path = join(save_path, 'gt')
61 | # if not os.path.exists(ct_save_path):
62 | # os.makedirs(ct_save_path)
63 | # if not os.path.exists(gt_save_path):
64 | # os.makedirs(gt_save_path)
65 |
66 | # already-processed files:
67 | exist_file_list = os.listdir(save_path)
68 | print('exist_file_list ', exist_file_list)
69 |
70 | def normalize(ct_narray):
71 | ct_voxel_ndarray = ct_narray.copy()
72 | ct_voxel_ndarray = ct_voxel_ndarray.flatten()
73 | # for all data
74 | thred = np.mean(ct_voxel_ndarray)
75 | voxel_filtered = ct_voxel_ndarray[(ct_voxel_ndarray > thred)]
76 | # for foreground data
77 | upper_bound = np.percentile(voxel_filtered, 99.95)
78 |     lower_bound = np.percentile(voxel_filtered, 0.05)
79 | mean = np.mean(voxel_filtered)
80 | std = np.std(voxel_filtered)
81 | ### transform ###
82 | ct_narray = np.clip(ct_narray, lower_bound, upper_bound)
83 | ct_narray = (ct_narray - mean) / max(std, 1e-8)
84 | return ct_narray
85 |
86 | def run(info):
87 | idx, file_name, case_path, label_path = info
88 |
89 | item = {}
90 | if file_name in exist_file_list:
91 | print(file_name + ' exist, skip')
92 | return
93 | print('process ', idx, '---' ,file_name)
94 | # generate ct_voxel_ndarray
95 | item_load = {
96 | 'image' : case_path,
97 | 'label' : label_path,
98 | }
99 | item_load = img_loader(item_load)
100 | ct_voxel_ndarray = item_load['image']
101 | gt_voxel_ndarray = item_load['label']
102 |
103 | ct_shape = ct_voxel_ndarray.shape
104 | item['image'] = ct_voxel_ndarray
105 |
106 | # generate gt_voxel_ndarray
107 | gt_voxel_ndarray = np.array(gt_voxel_ndarray).squeeze()
108 | present_categories = np.unique(gt_voxel_ndarray)
109 | gt_masks = []
110 | for cls_idx in range(len(args.category)):
111 | cls = cls_idx + 1
112 | if cls not in present_categories:
113 | gt_voxel_ndarray_category = np.zeros(ct_shape)
114 | gt_masks.append(gt_voxel_ndarray_category)
115 | print('case {} ==> zero category '.format(idx) + args.category[cls_idx])
116 | print(gt_voxel_ndarray_category.shape)
117 | else:
118 | gt_voxel_ndarray_category = gt_voxel_ndarray.copy()
119 | gt_voxel_ndarray_category[gt_voxel_ndarray != cls] = 0
120 | gt_voxel_ndarray_category[gt_voxel_ndarray == cls] = 1
121 | gt_masks.append(gt_voxel_ndarray_category)
122 | gt_voxel_ndarray = np.stack(gt_masks, axis=0)
123 |
124 | assert gt_voxel_ndarray.shape[0] == len(args.category), str(gt_voxel_ndarray.shape[0])
125 | assert gt_voxel_ndarray.shape[1:] == ct_voxel_ndarray.shape[1:]
126 | item['label'] = gt_voxel_ndarray.astype(np.int32)
127 | print(idx, ' load done!')
128 |
129 | #############################
130 | item['image'] = normalize(item['image'])
131 | print(idx, ' transform done')
132 |
133 | ############################
134 | print(file_name + ' ct gt <--> ', item['image'].shape, item['label'].shape)
135 | case_path = join(save_path, file_name)
136 | os.makedirs(case_path, exist_ok=True)
137 |
138 | np.save(join(case_path, 'image.npy'), item['image'])
139 | allmatrix_sp=sparse.csr_matrix(item['label'].reshape(item['label'].shape[0], -1))
140 | sparse.save_npz(join(case_path, 'mask_' + str(item['label'].shape)), allmatrix_sp)
141 | print(file_name + ' save done!')
142 |
143 | def generate_dataset_json(root_dir, output_file, test_ratio=0.2):
144 | cases = os.listdir(root_dir)
145 | ct_paths, gt_paths = [], []
146 | for case_name in cases:
147 | case_files = sorted(os.listdir(join(root_dir, case_name)))
148 | ct_path = join(root_dir, case_name, case_files[0])
149 | gt_path = join(root_dir, case_name, case_files[1])
150 | ct_paths.append(ct_path)
151 | gt_paths.append(gt_path)
152 |
153 | data = list(zip(ct_paths, gt_paths))
154 | train_data, val_data = train_test_split(data, test_size=test_ratio)
155 | labels = {}
156 | labels['0'] = 'background'
157 | for idx in range(len(args.category)):
158 | label_name = args.category[idx]
159 | label_id = idx + 1
160 | labels[str(label_id)] = label_name
161 | dataset = {
162 | 'name': f'{args.dataset_code} Dataset',
163 | 'description': f'{args.dataset_code} Dataset',
164 | 'tensorImageSize': '4D',
165 | 'modality': {
166 | '0': 'CT',
167 | },
168 | 'labels': labels,
169 | 'numTrain': len(train_data),
170 | 'numTest': len(val_data),
171 | 'train': [{'image': ct_path, 'label': gt_path} for ct_path, gt_path in train_data],
172 | 'test': [{'image': ct_path, 'label': gt_path} for ct_path, gt_path in val_data]
173 | }
174 | with open(output_file, 'w') as f:
175 | print(f'{output_file} dump')
176 | json.dump(dataset, f, indent=2)
177 |
178 | if __name__ == "__main__":
179 | with multiprocessing.Pool(processes=64) as pool:
180 | pool.map(run, data_path_list_all)
181 | print('Process Finished!')
182 |
183 | generate_dataset_json(root_dir=save_path,
184 | output_file=join(save_path, f'{args.dataset_code}.json'),
185 | test_ratio=args.test_ratio)
186 | print('Json Split Done!')
187 |
--------------------------------------------------------------------------------
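An illustrative sketch of loading one prepared case back. The case folder is a placeholder; the mask is stored as a 2-D CSR matrix whose original multi-class shape is encoded in the file name, so it is reshaped after loading.

import ast
import glob
import os
import numpy as np
from scipy import sparse

case_dir = "./processed/WORD/word_0001"            # placeholder case folder
image = np.load(os.path.join(case_dir, "image.npy"))

mask_file = glob.glob(os.path.join(case_dir, "mask_*.npz"))[0]
shape = ast.literal_eval(os.path.basename(mask_file)[len("mask_"):-len(".npz")])
gt = sparse.load_npz(mask_file).toarray().reshape(shape)
print(image.shape, gt.shape)                       # e.g. (1, H, W, D) and (num_classes, H, W, D)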
/models/sam2/sam/prompt_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Optional, Tuple, Type
8 |
9 | import torch
10 | from torch import nn
11 |
12 | from models.sam2.position_encoding import PositionEmbeddingRandom
13 |
14 | from models.sam2.sam2_utils import LayerNorm2d
15 |
16 |
17 | class PromptEncoder(nn.Module):
18 | def __init__(
19 | self,
20 | embed_dim: int,
21 | image_embedding_size: Tuple[int, int],
22 | input_image_size: Tuple[int, int],
23 | mask_in_chans: int,
24 | activation: Type[nn.Module] = nn.GELU,
25 | ) -> None:
26 | """
27 | Encodes prompts for input to SAM's mask decoder.
28 |
29 | Arguments:
30 | embed_dim (int): The prompts' embedding dimension
31 | image_embedding_size (tuple(int, int)): The spatial size of the
32 | image embedding, as (H, W).
33 |           input_image_size (tuple(int, int)): The padded size of the image as input
34 | to the image encoder, as (H, W).
35 | mask_in_chans (int): The number of hidden channels used for
36 | encoding input masks.
37 | activation (nn.Module): The activation to use when encoding
38 | input masks.
39 | """
40 | super().__init__()
41 | self.embed_dim = embed_dim
42 | self.input_image_size = input_image_size
43 | self.image_embedding_size = image_embedding_size
44 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
45 |
46 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
47 | point_embeddings = [
48 | nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)
49 | ]
50 | self.point_embeddings = nn.ModuleList(point_embeddings)
51 | self.not_a_point_embed = nn.Embedding(1, embed_dim)
52 |
53 | self.mask_input_size = (
54 | 4 * image_embedding_size[0],
55 | 4 * image_embedding_size[1],
56 | )
57 | self.mask_downscaling = nn.Sequential(
58 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
59 | LayerNorm2d(mask_in_chans // 4),
60 | activation(),
61 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
62 | LayerNorm2d(mask_in_chans),
63 | activation(),
64 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
65 | )
66 | self.no_mask_embed = nn.Embedding(1, embed_dim)
67 |
68 | def get_dense_pe(self) -> torch.Tensor:
69 | """
70 | Returns the positional encoding used to encode point prompts,
71 | applied to a dense set of points the shape of the image encoding.
72 |
73 | Returns:
74 | torch.Tensor: Positional encoding with shape
75 | 1x(embed_dim)x(embedding_h)x(embedding_w)
76 | """
77 | return self.pe_layer(self.image_embedding_size).unsqueeze(0)
78 |
79 | def _embed_points(
80 | self,
81 | points: torch.Tensor,
82 | labels: torch.Tensor,
83 | pad: bool,
84 | ) -> torch.Tensor:
85 | """Embeds point prompts."""
86 | points = points + 0.5 # Shift to center of pixel
87 | if pad:
88 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
89 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
90 | points = torch.cat([points, padding_point], dim=1)
91 | labels = torch.cat([labels, padding_label], dim=1)
92 | point_embedding = self.pe_layer.forward_with_coords(
93 | points, self.input_image_size
94 | )
95 |
96 | point_embedding = torch.where(
97 | (labels == -1).unsqueeze(-1),
98 | torch.zeros_like(point_embedding) + self.not_a_point_embed.weight,
99 | point_embedding,
100 | )
101 | point_embedding = torch.where(
102 | (labels == 0).unsqueeze(-1),
103 | point_embedding + self.point_embeddings[0].weight,
104 | point_embedding,
105 | )
106 | point_embedding = torch.where(
107 | (labels == 1).unsqueeze(-1),
108 | point_embedding + self.point_embeddings[1].weight,
109 | point_embedding,
110 | )
111 | point_embedding = torch.where(
112 | (labels == 2).unsqueeze(-1),
113 | point_embedding + self.point_embeddings[2].weight,
114 | point_embedding,
115 | )
116 | point_embedding = torch.where(
117 | (labels == 3).unsqueeze(-1),
118 | point_embedding + self.point_embeddings[3].weight,
119 | point_embedding,
120 | )
121 | return point_embedding
122 |
123 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
124 | """Embeds box prompts."""
125 | boxes = boxes + 0.5 # Shift to center of pixel
126 | coords = boxes.reshape(-1, 2, 2)
127 | corner_embedding = self.pe_layer.forward_with_coords(
128 | coords, self.input_image_size
129 | )
130 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight
131 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight
132 | return corner_embedding
133 |
134 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
135 | """Embeds mask inputs."""
136 | mask_embedding = self.mask_downscaling(masks)
137 | return mask_embedding
138 |
139 | def _get_batch_size(
140 | self,
141 | points: Optional[Tuple[torch.Tensor, torch.Tensor]],
142 | boxes: Optional[torch.Tensor],
143 | masks: Optional[torch.Tensor],
144 | ) -> int:
145 | """
146 | Gets the batch size of the output given the batch size of the input prompts.
147 | """
148 | if points is not None:
149 | return points[0].shape[0]
150 | elif boxes is not None:
151 | return boxes.shape[0]
152 | elif masks is not None:
153 | return masks.shape[0]
154 | else:
155 | return 1
156 |
157 | def _get_device(self) -> torch.device:
158 | return self.point_embeddings[0].weight.device
159 |
160 | def forward(
161 | self,
162 | points: Optional[Tuple[torch.Tensor, torch.Tensor]],
163 | boxes: Optional[torch.Tensor],
164 | masks: Optional[torch.Tensor],
165 | ) -> Tuple[torch.Tensor, torch.Tensor]:
166 | """
167 | Embeds different types of prompts, returning both sparse and dense
168 | embeddings.
169 |
170 | Arguments:
171 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
172 | and labels to embed.
173 | boxes (torch.Tensor or none): boxes to embed
174 | masks (torch.Tensor or none): masks to embed
175 |
176 | Returns:
177 | torch.Tensor: sparse embeddings for the points and boxes, with shape
178 | BxNx(embed_dim), where N is determined by the number of input points
179 | and boxes.
180 | torch.Tensor: dense embeddings for the masks, in the shape
181 | Bx(embed_dim)x(embed_H)x(embed_W)
182 | """
183 | bs = self._get_batch_size(points, boxes, masks)
184 | sparse_embeddings = torch.empty(
185 | (bs, 0, self.embed_dim), device=self._get_device()
186 | )
187 | if points is not None:
188 | coords, labels = points
189 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
190 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
191 | if boxes is not None:
192 | box_embeddings = self._embed_boxes(boxes)
193 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
194 |
195 | if masks is not None:
196 | dense_embeddings = self._embed_masks(masks)
197 | else:
198 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
199 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
200 | )
201 |
202 | return sparse_embeddings, dense_embeddings
--------------------------------------------------------------------------------
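A minimal sketch of the prompt encoder above with a single positive click; the sizes mirror common SAM settings but are illustrative here.

import torch
from models.sam2.sam.prompt_encoder import PromptEncoder

encoder = PromptEncoder(
    embed_dim=256,
    image_embedding_size=(64, 64),
    input_image_size=(1024, 1024),
    mask_in_chans=16,
)
coords = torch.tensor([[[512.0, 512.0]]])           # (B=1, N=1, 2) point in pixels
labels = torch.tensor([[1]])                        # 1 = positive click
sparse_emb, dense_emb = encoder(points=(coords, labels), boxes=None, masks=None)
print(sparse_emb.shape, dense_emb.shape)            # (1, 2, 256) and (1, 256, 64, 64)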
/models/components/semantic_extraction.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
4 |
5 | from models.sam2.backbones.image_encoder import ImageEncoder
6 | from models.sam2.sam2_utils import MLP
7 | from transformers import (
8 | AutoTokenizer,
9 | CLIPTextModel,
10 | CLIPTextConfig,
11 | CLIPVisionConfig,
12 | CLIPVisionModel,
13 | AutoFeatureExtractor,
14 | )
15 |
16 |
17 | class SemanticInteraction(nn.Module):
18 | def __init__(self, clip_text_ckpt, clip_image_ckp, image_dim, text_dim, num_heads, stem_channel, qkv_bias=False,
19 | qk_scale=None, drop=0.,
20 | attn_drop=0., drop_path=0., has_mlp=False):
21 | super().__init__()
22 | self.text_encoder = ClipTextEncoder(clip_text_ckpt)
23 | self.image_encoder = ClipVisionEncoder(clip_image_ckp)
24 | self.cross_att_one_level_v = CrossAttentionBlock(
25 | image_dim, num_heads=num_heads, stem_channel=stem_channel,
26 | qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, drop_path=drop_path, has_mlp=has_mlp
27 | )
28 | self.cross_att_one_level_t = CrossAttentionBlock(
29 | text_dim, num_heads=num_heads, stem_channel=stem_channel,
30 | qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, drop_path=drop_path, has_mlp=has_mlp
31 | )
32 |
33 | self.in_dim = text_dim + text_dim
34 | self.co_conv = nn.Conv2d(self.in_dim, self.in_dim // 2, kernel_size=3, stride=1, padding=1, bias=True)
35 |
36 | self.cross_att_two_level_v = CrossAttentionBlock(
37 | image_dim, num_heads=num_heads, stem_channel=stem_channel,
38 | qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, drop_path=drop_path, has_mlp=has_mlp
39 | )
40 | self.cross_att_two_level_t = CrossAttentionBlock(
41 | text_dim, num_heads=num_heads, stem_channel=stem_channel,
42 | qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, drop_path=drop_path, has_mlp=has_mlp
43 | )
44 | self.linear_1 = nn.Linear(self.in_dim // 2, self.in_dim * 2)
45 | self.dropout = nn.Dropout(drop)
46 | self.relu = nn.ReLU()
47 | self.linear_2 = nn.Linear(self.in_dim * 2, self.in_dim // 2)  # project back to in_dim // 2; the original divided the nn.Linear module itself, which raises at construction
48 |
49 | def forward(self, image, text):
50 | f_v = self.image_encoder(image)
51 | f_t = self.text_encoder(text)
52 |
53 | f_vt = self.cross_att_one_level_v(f_v, f_t)
54 | f_tv = self.cross_att_one_level_t(f_t, f_v)
55 |
56 | f_c = self.co_conv(torch.cat([f_v, f_t], dim=1))
57 |
58 | f_vt_dot = self.cross_att_two_level_v(f_c, f_vt)
59 | f_tv_dot = self.cross_att_two_level_t(f_c, f_tv)
60 |
61 | raw_out = f_vt_dot + f_tv_dot
62 |
63 | out = self.linear_1(raw_out)
64 | out = self.dropout(out)
65 | out = self.relu(out)
66 | out = self.linear_2(out)
67 | return raw_out, out
68 |
69 |
70 | class DepthWiseConv(nn.Module):
71 | def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
72 | super().__init__()
73 | self.depthwise = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size,
74 | padding=padding, groups=in_channels)
75 | self.pointwise = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
76 |
77 | def forward(self, x):
78 | x = self.depthwise(x)
79 | x = self.pointwise(x)
80 | return x
81 |
82 |
83 | class CrossAttention(nn.Module):
84 | def __init__(self, dim, num_heads=8, stem_channel=16, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
85 | super().__init__()
86 | self.num_heads = num_heads
87 | head_dim = dim // num_heads
88 | # NOTE: the attention scale defaults to head_dim ** -0.5; qk_scale can override it for compatibility with previously trained weights
89 | self.scale = qk_scale or head_dim ** -0.5
90 |
91 | self.conv_k = DepthWiseConv(stem_channel, stem_channel, kernel_size=3, padding=1)
92 | self.conv_v = DepthWiseConv(stem_channel, stem_channel, kernel_size=3, padding=1)
93 |
94 | self.wq = nn.Linear(dim, dim, bias=qkv_bias)
95 | self.wk = nn.Linear(dim, dim, bias=qkv_bias)
96 | self.wv = nn.Linear(dim, dim, bias=qkv_bias)
97 | self.attn_drop = nn.Dropout(attn_drop)
98 | self.proj = nn.Linear(dim, dim)
99 | self.proj_drop = nn.Dropout(proj_drop)
100 |
101 | def forward(self, x, y):
102 | B, N, C = x.shape
103 | q = self.wq(x[:, 0:1, ...]).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1,
104 | 3) # B1C -> B1H(C/H) -> BH1(C/H)
105 |
106 | k = self.conv_k(y)
107 | k = self.wk(k).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1,
108 | 3) # BNC -> BNH(C/H) -> BHN(C/H)
109 | v = self.conv_v(y)
110 | v = self.wv(v).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1,
111 | 3) # BNC -> BNH(C/H) -> BHN(C/H)
112 |
113 | attn = (q @ k.transpose(-2, -1)) * self.scale # BH1(C/H) @ BH(C/H)N -> BH1N
114 | attn = attn.softmax(dim=-1)
115 | attn = self.attn_drop(attn)
116 |
117 | out = (attn @ v).transpose(1, 2).reshape(B, 1, C) # (BH1N @ BHN(C/H)) -> BH1(C/H) -> B1H(C/H) -> B1C
118 | out = self.proj(out)
119 | out = self.proj_drop(out)
120 | return out
121 |
122 |
123 | class CrossAttentionBlock(nn.Module):
124 |
125 | def __init__(self, dim, num_heads, stem_channel, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
126 | drop_path=0., has_mlp=False):
127 | super().__init__()
128 | self.norm_x = nn.LayerNorm(dim)
129 | self.norm_y = nn.LayerNorm(dim)
130 | self.attn = CrossAttention(
131 | dim, num_heads=num_heads, stem_channel=stem_channel, qkv_bias=qkv_bias, qk_scale=qk_scale,
132 | attn_drop=attn_drop, proj_drop=drop)
133 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
134 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
135 | self.has_mlp = has_mlp
136 | if has_mlp:
137 | self.norm2 = nn.LayerNorm(dim)
138 | mlp_hidden_dim = int(dim * mlp_ratio)
139 | self.mlp = MLP(dim, mlp_hidden_dim, dim // 4, 3)
140 |
141 | def forward(self, x, y):
142 | out = self.attn(self.norm_x(x), self.norm_y(y))
143 | out = self.drop_path(out)
144 | if self.has_mlp:
145 | out = self.drop_path(self.mlp(self.norm2(out)))
146 | return out
147 |
148 |
149 | class ClipTextEncoder(nn.Module):
150 | def __init__(self, clip_ckpt):
151 | """
152 | :param clip_ckpt: name or path of the CLIP checkpoint to load; available checkpoints are listed at https://huggingface.co/openai
153 | """
154 | super().__init__()
155 | config = CLIPTextConfig()
156 | self.clip_text_model = CLIPTextModel(config)
157 | self.tokenizer = AutoTokenizer.from_pretrained(clip_ckpt)
158 | self.dim_align = nn.Linear(512, 768)
159 | # freeze text encoder
160 | for param in self.clip_text_model.parameters():
161 | param.requires_grad = False
162 |
163 | def organ2tokens(self, organ_names):
164 | text_list = ['A computerized tomography of a {}.'.format(organ_name) for organ_name in organ_names]
165 | tokens = self.tokenizer(text_list, padding=True, return_tensors="pt")
166 | for key in tokens.keys():
167 | tokens[key] = tokens[key].cuda()
168 | return tokens
169 |
170 | def forward(self, text):
171 | if text is None:
172 | return None
173 | if type(text) is str:
174 | text = [text]
175 | tokens = self.organ2tokens(text)
176 | clip_outputs = self.clip_text_model(**tokens)
177 | text_embedding = clip_outputs.pooler_output
178 | text_embedding = self.dim_align(text_embedding)
179 | return text_embedding
180 |
181 |
182 | class ClipVisionEncoder(nn.Module):
183 | def __init__(self, clip_ckpt):
184 | """
185 | :param clip_ckpt: name or path of the CLIP checkpoint to load; available checkpoints are listed at https://huggingface.co/openai
186 | """
187 | super().__init__()
188 | config = CLIPVisionConfig()
189 | self.clip_vision_model = CLIPVisionModel(config)
190 | self.feature_extractor = AutoFeatureExtractor.from_pretrained(clip_ckpt)
191 | self.dim_align = nn.Linear(config.hidden_size, 768)  # CLIPVisionConfig's default hidden_size is 768, so a hard-coded 512 would not match pooler_output
192 | # freeze vision encoder
193 | for param in self.clip_vision_model.parameters():
194 | param.requires_grad = False
195 |
196 | def images2features(self, images):
197 | features = self.feature_extractor(images=images, return_tensors="pt")
198 | for key in features.keys():
199 | features[key] = features[key].cuda()
200 | return features
201 |
202 | def forward(self, images):
203 | if images is None:
204 | return None
205 | if type(images) is not list:
206 | images = [images]
207 | features = self.images2features(images)
208 | clip_outputs = self.clip_vision_model(**features)
209 | vision_embedding = clip_outputs.pooler_output
210 | vision_embedding = self.dim_align(vision_embedding)
211 | return vision_embedding
212 |
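213 |
214 | # Notes on the data flow in SemanticInteraction.forward (descriptive only):
215 | #   f_v, f_t    - pooled CLIP vision / text embeddings, both aligned to 768 dims by dim_align
216 | #   f_vt, f_tv  - first-level cross attention (vision attending to text and vice versa)
217 | #   f_c         - co_conv over the feature-wise concatenation of f_v and f_t
218 | #   f_vt_dot, f_tv_dot - second-level cross attention of f_c with f_vt / f_tv
219 | #   raw_out     - their sum; out = raw_out passed through linear_1 -> dropout -> relu -> linear_2
220 | #
221 | # A minimal, self-contained sanity check of DepthWiseConv alone (hypothetical
222 | # sizes, unrelated to any config in this repo):
223 | #   dw = DepthWiseConv(in_channels=64, out_channels=128)
224 | #   y = dw(torch.randn(2, 64, 32, 32))  # -> (2, 128, 32, 32)
225 | # Parameter count: depthwise 64*3*3 + 64 plus pointwise 64*128 + 128 = 8,960,
226 | # versus 64*128*3*3 + 128 = 73,856 for a dense Conv2d with the same receptive field.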
--------------------------------------------------------------------------------
/models/sam2/backbones/image_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import List, Optional
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 | from models.sam2.sam2_utils import MLP
14 |
15 | class ImageEncoder(nn.Module):
16 | def __init__(
17 | self,
18 | trunk: nn.Module,
19 | neck: nn.Module,
20 | scalp: int = 0,
21 | ):
22 | super().__init__()
23 | self.trunk = trunk
24 | self.neck = neck
25 | self.scalp = scalp
26 | assert (
27 | self.trunk.channel_list == self.neck.backbone_channel_list
28 | ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}"
29 |
30 | def forward(self, img_sample: torch.Tensor, injection: torch.Tensor) -> torch.Tensor:
31 | # Forward through backbone
32 | features, pos = self.neck(self.trunk(img_sample, injection))
33 | if self.scalp > 0:
34 | # Discard the lowest resolution features
35 | features, pos = features[: -self.scalp], pos[: -self.scalp]
36 |
37 | src = features[-1]
38 | output = {
39 | "vision_features": src,
40 | "vision_pos_enc": pos,
41 | "backbone_fpn": features,
42 | }
43 | return output
44 |
45 |
46 | class FpnNeckInjection(nn.Module):
47 | """
48 | A modified variant of Feature Pyramid Network (FPN) neck with semantic injection
49 | """
50 |
51 | def __init__(
52 | self,
53 | position_encoding: nn.Module,
54 | d_model: int,
55 | backbone_channel_list: List[int],
56 | kernel_size: int = 1,
57 | stride: int = 1,
58 | padding: int = 0,
59 | fpn_interp_model: str = "bilinear",
60 | fuse_type: str = "sum",
61 | fpn_top_down_levels: Optional[List[int]] = None,
62 | ):
63 | """Initialize the neck
64 | :param trunk: the backbone
65 | :param position_encoding: the positional encoding to use
66 | :param d_model: the dimension of the model
67 | :param neck_norm: the normalization to use
68 | """
69 | super().__init__()
70 | self.position_encoding = position_encoding
71 | self.convs = nn.ModuleList()
72 | self.backbone_channel_list = backbone_channel_list
73 | self.d_model = d_model
74 | for dim in backbone_channel_list:
75 | current = nn.Sequential()
76 | current.add_module(
77 | "conv",
78 | nn.Conv2d(
79 | in_channels=dim,
80 | out_channels=d_model,
81 | kernel_size=kernel_size,
82 | stride=stride,
83 | padding=padding,
84 | ),
85 | )
86 |
87 | self.convs.append(current)
88 | self.fpn_interp_model = fpn_interp_model
89 | assert fuse_type in ["sum", "avg"]
90 | self.fuse_type = fuse_type
91 |
92 | # levels to have top-down features in its outputs
93 | # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
94 | # have top-down propagation, while outputs of level 0 and level 1 have only
95 | # lateral features from the same backbone level.
96 | if fpn_top_down_levels is None:
97 | # default is to have top-down features on all levels
98 | fpn_top_down_levels = range(len(self.convs))
99 | self.fpn_top_down_levels = list(fpn_top_down_levels)
100 |
101 | # see https://github.com/tianrun-chen/SAM-Adapter-PyTorch/blob/SAM2-Adapter-for-Segment-Anything-2/models/mmseg/models/sam/image_encoder.py
102 | self.shared_mlp = MLP(d_model, d_model, d_model // 4, 2)
103 | self.unshared_mlps = nn.ModuleList()
104 |
105 | for dim in backbone_channel_list:
106 | current = MLP(d_model // 4, d_model // 4, dim, 2)
107 | self.unshared_mlps.append(current)  # one per pyramid level; appending to self.convs here would break the length checks in forward()
108 |
109 |
110 | def forward(self, xs: List[torch.Tensor], injection: torch.Tensor):
111 |
112 | out = [None] * len(self.convs)
113 | pos = [None] * len(self.convs)
114 | assert len(xs) == len(self.convs)
115 |
116 | shared_features = self.shared_mlp(injection) + injection
117 | # fpn forward pass
118 | # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
119 | prev_features = None
120 | # forward in top-down order (from low to high resolution)
121 | n = len(self.convs) - 1
122 | for i in range(n, -1, -1):
123 | x = xs[i]
124 | lateral_features = self.convs[n - i](x)
125 | if i in self.fpn_top_down_levels and prev_features is not None:
126 | top_down_features = F.interpolate(
127 | prev_features.to(dtype=torch.float32),
128 | scale_factor=2.0,
129 | mode=self.fpn_interp_model,
130 | align_corners=(
131 | None if self.fpn_interp_model == "nearest" else False
132 | ),
133 | antialias=False,
134 | )
135 | injection = self.unshared_mlps[n - i](shared_features)
136 | prev_features = lateral_features + top_down_features
137 | prev_features = prev_features + injection
138 | if self.fuse_type == "avg":
139 | prev_features /= 2
140 | else:
141 | prev_features = lateral_features
142 | x_out = prev_features
143 | out[i] = x_out
144 | pos[i] = self.position_encoding(x_out).to(x_out.dtype)
145 |
146 | return out, pos
147 |
148 |
149 |
150 | class FpnNeck(nn.Module):
151 | """
152 | A modified variant of Feature Pyramid Network (FPN) neck
153 | (we remove output conv and also do bicubic interpolation similar to ViT
154 | pos embed interpolation)
155 | """
156 |
157 | def __init__(
158 | self,
159 | position_encoding: nn.Module,
160 | d_model: int,
161 | backbone_channel_list: List[int],
162 | kernel_size: int = 1,
163 | stride: int = 1,
164 | padding: int = 0,
165 | fpn_interp_model: str = "bilinear",
166 | fuse_type: str = "sum",
167 | fpn_top_down_levels: Optional[List[int]] = None,
168 | ):
169 | """Initialize the neck
170 | :param trunk: the backbone
171 | :param position_encoding: the positional encoding to use
172 | :param d_model: the dimension of the model
173 | :param neck_norm: the normalization to use
174 | """
175 | super().__init__()
176 | self.position_encoding = position_encoding
177 | self.convs = nn.ModuleList()
178 | self.backbone_channel_list = backbone_channel_list
179 | self.d_model = d_model
180 | for dim in backbone_channel_list:
181 | current = nn.Sequential()
182 | current.add_module(
183 | "conv",
184 | nn.Conv2d(
185 | in_channels=dim,
186 | out_channels=d_model,
187 | kernel_size=kernel_size,
188 | stride=stride,
189 | padding=padding,
190 | ),
191 | )
192 |
193 | self.convs.append(current)
194 | self.fpn_interp_model = fpn_interp_model
195 | assert fuse_type in ["sum", "avg"]
196 | self.fuse_type = fuse_type
197 |
198 | # levels to have top-down features in its outputs
199 | # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
200 | # have top-down propagation, while outputs of level 0 and level 1 have only
201 | # lateral features from the same backbone level.
202 | if fpn_top_down_levels is None:
203 | # default is to have top-down features on all levels
204 | fpn_top_down_levels = range(len(self.convs))
205 | self.fpn_top_down_levels = list(fpn_top_down_levels)
206 |
207 | def forward(self, xs: List[torch.Tensor]):
208 |
209 | out = [None] * len(self.convs)
210 | pos = [None] * len(self.convs)
211 | assert len(xs) == len(self.convs)
212 | # fpn forward pass
213 | # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
214 | prev_features = None
215 | # forward in top-down order (from low to high resolution)
216 | n = len(self.convs) - 1
217 | for i in range(n, -1, -1):
218 | x = xs[i]
219 | lateral_features = self.convs[n - i](x)
220 | if i in self.fpn_top_down_levels and prev_features is not None:
221 | top_down_features = F.interpolate(
222 | prev_features.to(dtype=torch.float32),
223 | scale_factor=2.0,
224 | mode=self.fpn_interp_model,
225 | align_corners=(
226 | None if self.fpn_interp_model == "nearest" else False
227 | ),
228 | antialias=False,
229 | )
230 | prev_features = lateral_features + top_down_features
231 | if self.fuse_type == "avg":
232 | prev_features /= 2
233 | else:
234 | prev_features = lateral_features
235 | x_out = prev_features
236 | out[i] = x_out
237 | pos[i] = self.position_encoding(x_out).to(x_out.dtype)
238 |
239 | return out, pos
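240 |
241 |
242 | # Note on feature ordering (applies to both FpnNeck and FpnNeckInjection):
243 | # backbone_channel_list is given lowest-resolution first, while xs is expected
244 | # highest-resolution first, since level i uses self.convs[n - i]. For example
245 | # (illustrative channels only), with backbone_channel_list = [768, 384, 192, 96],
246 | # xs should be [96-ch (finest), 192-ch, 384-ch, 768-ch (coarsest)]. The loop runs
247 | # i = n..0, so top-down fusion starts at the coarsest map and upsamples by 2x at
248 | # each step. In FpnNeckInjection, the injection tensor is passed through the
249 | # shared MLP (and added back to its input) and then through a per-level unshared
250 | # MLP before being added to the fused features on levels with top-down propagation.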
--------------------------------------------------------------------------------
/models/sam2/position_encoding.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 | from typing import Any, Optional, Tuple
9 |
10 | import numpy as np
11 |
12 | import torch
13 | from torch import nn
14 |
15 |
16 | class PositionEmbeddingSine(nn.Module):
17 | """
18 | This is a more standard version of the position embedding, very similar to the one
19 | used by the Attention Is All You Need paper, generalized to work on images.
20 | """
21 |
22 | def __init__(
23 | self,
24 | num_pos_feats,
25 | temperature: int = 10000,
26 | normalize: bool = True,
27 | scale: Optional[float] = None,
28 | # Following settings only relevant
29 | # for warming up cache for compilation
30 | warmup_cache: bool = True,
31 | image_size: int = 1024,
32 | strides: Tuple[int] = (4, 8, 16, 32),
33 | ):
34 | super().__init__()
35 | assert num_pos_feats % 2 == 0, "Expecting even model width"
36 | self.num_pos_feats = num_pos_feats // 2
37 | self.temperature = temperature
38 | self.normalize = normalize
39 | if scale is not None and normalize is False:
40 | raise ValueError("normalize should be True if scale is passed")
41 | if scale is None:
42 | scale = 2 * math.pi
43 | self.scale = scale
44 |
45 | self.cache = {}
46 | if warmup_cache and torch.cuda.is_available():
47 | # Warmup cache for cuda, to help with compilation
48 | device = torch.device("cuda")
49 | for stride in strides:
50 | cache_key = (image_size // stride, image_size // stride)
51 | self._pe(1, device, *cache_key)
52 |
53 | def _encode_xy(self, x, y):
54 | # The positions are expected to be normalized
55 | assert len(x) == len(y) and x.ndim == y.ndim == 1
56 | x_embed = x * self.scale
57 | y_embed = y * self.scale
58 |
59 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
60 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
61 |
62 | pos_x = x_embed[:, None] / dim_t
63 | pos_y = y_embed[:, None] / dim_t
64 | pos_x = torch.stack(
65 | (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2
66 | ).flatten(1)
67 | pos_y = torch.stack(
68 | (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2
69 | ).flatten(1)
70 | return pos_x, pos_y
71 |
72 | @torch.no_grad()
73 | def encode_boxes(self, x, y, w, h):
74 | pos_x, pos_y = self._encode_xy(x, y)
75 | pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
76 | return pos
77 |
78 | encode = encode_boxes # Backwards compatibility
79 |
80 | @torch.no_grad()
81 | def encode_points(self, x, y, labels):
82 | (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
83 | assert bx == by and nx == ny and bx == bl and nx == nl
84 | pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
85 | pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
86 | pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
87 | return pos
88 |
89 | @torch.no_grad()
90 | def _pe(self, B, device, *cache_key):
91 | H, W = cache_key
92 | if cache_key in self.cache:
93 | return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1)
94 |
95 | y_embed = (
96 | torch.arange(1, H + 1, dtype=torch.float32, device=device)
97 | .view(1, -1, 1)
98 | .repeat(B, 1, W)
99 | )
100 | x_embed = (
101 | torch.arange(1, W + 1, dtype=torch.float32, device=device)
102 | .view(1, 1, -1)
103 | .repeat(B, H, 1)
104 | )
105 |
106 | if self.normalize:
107 | eps = 1e-6
108 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
109 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
110 |
111 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device)
112 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
113 |
114 | pos_x = x_embed[:, :, :, None] / dim_t
115 | pos_y = y_embed[:, :, :, None] / dim_t
116 | pos_x = torch.stack(
117 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
118 | ).flatten(3)
119 | pos_y = torch.stack(
120 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
121 | ).flatten(3)
122 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
123 | self.cache[cache_key] = pos[0]
124 | return pos
125 |
126 | @torch.no_grad()
127 | def forward(self, x: torch.Tensor):
128 | B = x.shape[0]
129 | cache_key = (x.shape[-2], x.shape[-1])
130 | return self._pe(B, x.device, *cache_key)
131 |
132 |
133 | class PositionEmbeddingRandom(nn.Module):
134 | """
135 | Positional encoding using random spatial frequencies.
136 | """
137 |
138 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
139 | super().__init__()
140 | if scale is None or scale <= 0.0:
141 | scale = 1.0
142 | self.register_buffer(
143 | "positional_encoding_gaussian_matrix",
144 | scale * torch.randn((2, num_pos_feats)),
145 | )
146 |
147 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
148 | """Positionally encode points that are normalized to [0,1]."""
149 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
150 | coords = 2 * coords - 1
151 | coords = coords @ self.positional_encoding_gaussian_matrix
152 | coords = 2 * np.pi * coords
153 | # outputs d_1 x ... x d_n x C shape
154 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
155 |
156 | def forward(self, size: Tuple[int, int]) -> torch.Tensor:
157 | """Generate positional encoding for a grid of the specified size."""
158 | h, w = size
159 | device: Any = self.positional_encoding_gaussian_matrix.device
160 | grid = torch.ones((h, w), device=device, dtype=torch.float32)
161 | y_embed = grid.cumsum(dim=0) - 0.5
162 | x_embed = grid.cumsum(dim=1) - 0.5
163 | y_embed = y_embed / h
164 | x_embed = x_embed / w
165 |
166 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
167 | return pe.permute(2, 0, 1) # C x H x W
168 |
169 | def forward_with_coords(
170 | self, coords_input: torch.Tensor, image_size: Tuple[int, int]
171 | ) -> torch.Tensor:
172 | """Positionally encode points that are not normalized to [0,1]."""
173 | coords = coords_input.clone()
174 | coords[:, :, 0] = coords[:, :, 0] / image_size[1]
175 | coords[:, :, 1] = coords[:, :, 1] / image_size[0]
176 | return self._pe_encoding(coords.to(torch.float)) # B x N x C
177 |
178 |
179 | # Rotary Positional Encoding, adapted from:
180 | # 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py
181 | # 2. https://github.com/naver-ai/rope-vit
182 | # 3. https://github.com/lucidrains/rotary-embedding-torch
183 |
184 |
185 | def init_t_xy(end_x: int, end_y: int):
186 | t = torch.arange(end_x * end_y, dtype=torch.float32)
187 | t_x = (t % end_x).float()
188 | t_y = torch.div(t, end_x, rounding_mode="floor").float()
189 | return t_x, t_y
190 |
191 |
192 | def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
193 | freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
194 | freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
195 |
196 | t_x, t_y = init_t_xy(end_x, end_y)
197 | freqs_x = torch.outer(t_x, freqs_x)
198 | freqs_y = torch.outer(t_y, freqs_y)
199 | freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
200 | freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
201 | return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
202 |
203 |
204 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
205 | ndim = x.ndim
206 | assert 0 <= 1 < ndim
207 | assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
208 | shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
209 | return freqs_cis.view(*shape)
210 |
211 |
212 | def apply_rotary_enc(
213 | xq: torch.Tensor,
214 | xk: torch.Tensor,
215 | freqs_cis: torch.Tensor,
216 | repeat_freqs_k: bool = False,
217 | ):
218 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
219 | xk_ = (
220 | torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
221 | if xk.shape[-2] != 0
222 | else None
223 | )
224 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
225 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
226 | if xk_ is None:
227 | # no keys to rotate, due to dropout
228 | return xq_out.type_as(xq).to(xq.device), xk
229 | # repeat freqs along seq_len dim to match k seq_len
230 | if repeat_freqs_k:
231 | r = xk_.shape[-2] // xq_.shape[-2]
232 | if freqs_cis.is_cuda:
233 | freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
234 | else:
235 | # torch.repeat on complex numbers may not be supported on non-CUDA devices
236 | # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
237 | freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
238 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
239 | return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
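240 |
241 |
242 | # Shape notes for the axial RoPE helpers above (illustrative sizes):
243 | # compute_axial_cis(dim=head_dim, end_x=W, end_y=H) returns a complex tensor of
244 | # shape (H*W, head_dim // 2): head_dim // 4 frequencies for x and for y, concatenated.
245 | # In apply_rotary_enc, xq of shape (B, n_heads, H*W, head_dim) is viewed as complex
246 | # pairs of shape (B, n_heads, H*W, head_dim // 2), so reshape_for_broadcast requires
247 | # freqs_cis.shape == (H*W, head_dim // 2). E.g. head_dim=64, H=W=8 gives freqs_cis
248 | # of shape (64, 32).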
--------------------------------------------------------------------------------
/models/sam2/backbones/hieradet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 | from functools import partial
9 | from typing import List, Tuple, Union
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | from iopath.common.file_io import g_pathmgr
15 |
16 | from models.sam2.backbones.utils import (
17 | PatchEmbed,
18 | window_partition,
19 | window_unpartition,
20 | )
21 |
22 | from models.sam2.sam2_utils import DropPath, MLP
23 |
24 |
25 | def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
26 | if pool is None:
27 | return x
28 | # (B, H, W, C) -> (B, C, H, W)
29 | x = x.permute(0, 3, 1, 2)
30 | x = pool(x)
31 | # (B, C, H', W') -> (B, H', W', C)
32 | x = x.permute(0, 2, 3, 1)
33 | if norm:
34 | x = norm(x)
35 |
36 | return x
37 |
38 |
39 | class MultiScaleAttention(nn.Module):
40 | def __init__(
41 | self,
42 | dim: int,
43 | dim_out: int,
44 | num_heads: int,
45 | q_pool: nn.Module = None,
46 | ):
47 | super().__init__()
48 |
49 | self.dim = dim
50 | self.dim_out = dim_out
51 | self.num_heads = num_heads
52 | self.q_pool = q_pool
53 | self.qkv = nn.Linear(dim, dim_out * 3)
54 | self.proj = nn.Linear(dim_out, dim_out)
55 |
56 | def forward(self, x: torch.Tensor) -> torch.Tensor:
57 | B, H, W, _ = x.shape
58 | # qkv with shape (B, H * W, 3, nHead, C)
59 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
60 | # q, k, v with shape (B, H * W, nheads, C)
61 | q, k, v = torch.unbind(qkv, 2)
62 |
63 | # Q pooling (for downsample at stage changes)
64 | if self.q_pool:
65 | q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
66 | H, W = q.shape[1:3] # downsampled shape
67 | q = q.reshape(B, H * W, self.num_heads, -1)
68 |
69 | # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
70 | x = F.scaled_dot_product_attention(
71 | q.transpose(1, 2),
72 | k.transpose(1, 2),
73 | v.transpose(1, 2),
74 | )
75 | # Transpose back
76 | x = x.transpose(1, 2)
77 | x = x.reshape(B, H, W, -1)
78 |
79 | x = self.proj(x)
80 |
81 | return x
82 |
83 |
84 | class MultiScaleBlock(nn.Module):
85 | def __init__(
86 | self,
87 | dim: int,
88 | dim_out: int,
89 | num_heads: int,
90 | mlp_ratio: float = 4.0,
91 | drop_path: float = 0.0,
92 | norm_layer: Union[nn.Module, str] = "LayerNorm",
93 | q_stride: Tuple[int, int] = None,
94 | act_layer: nn.Module = nn.GELU,
95 | window_size: int = 0,
96 | ):
97 | super().__init__()
98 |
99 | if isinstance(norm_layer, str):
100 | norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
101 |
102 | self.dim = dim
103 | self.dim_out = dim_out
104 | self.norm1 = norm_layer(dim)
105 |
106 | self.window_size = window_size
107 |
108 | self.pool, self.q_stride = None, q_stride
109 | if self.q_stride:
110 | self.pool = nn.MaxPool2d(
111 | kernel_size=q_stride, stride=q_stride, ceil_mode=False
112 | )
113 |
114 | self.attn = MultiScaleAttention(
115 | dim,
116 | dim_out,
117 | num_heads=num_heads,
118 | q_pool=self.pool,
119 | )
120 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
121 |
122 | self.norm2 = norm_layer(dim_out)
123 | self.mlp = MLP(
124 | dim_out,
125 | int(dim_out * mlp_ratio),
126 | dim_out,
127 | num_layers=2,
128 | activation=act_layer,
129 | )
130 |
131 | if dim != dim_out:
132 | self.proj = nn.Linear(dim, dim_out)
133 |
134 | def forward(self, x: torch.Tensor) -> torch.Tensor:
135 | shortcut = x # B, H, W, C
136 | x = self.norm1(x)
137 |
138 | # Skip connection
139 | if self.dim != self.dim_out:
140 | shortcut = do_pool(self.proj(x), self.pool)
141 |
142 | # Window partition
143 | window_size = self.window_size
144 | if window_size > 0:
145 | H, W = x.shape[1], x.shape[2]
146 | x, pad_hw = window_partition(x, window_size)
147 |
148 | # Window Attention + Q Pooling (if stage change)
149 | x = self.attn(x)
150 | if self.q_stride:
151 | # Shapes have changed due to Q pooling
152 | window_size = self.window_size // self.q_stride[0]
153 | H, W = shortcut.shape[1:3]
154 |
155 | pad_h = (window_size - H % window_size) % window_size
156 | pad_w = (window_size - W % window_size) % window_size
157 | pad_hw = (H + pad_h, W + pad_w)
158 |
159 | # Reverse window partition
160 | if self.window_size > 0:
161 | x = window_unpartition(x, window_size, pad_hw, (H, W))
162 |
163 | x = shortcut + self.drop_path(x)
164 | # MLP
165 | x = x + self.drop_path(self.mlp(self.norm2(x)))
166 | return x
167 |
168 |
169 | class Hiera(nn.Module):
170 | """
171 | Reference: https://arxiv.org/abs/2306.00989
172 | """
173 |
174 | def __init__(
175 | self,
176 | embed_dim: int = 96, # initial embed dim
177 | num_heads: int = 1, # initial number of heads
178 | drop_path_rate: float = 0.0, # stochastic depth
179 | q_pool: int = 3, # number of q_pool stages
180 | q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages
181 | stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage
182 | dim_mul: float = 2.0, # dim_mul factor at stage shift
183 | head_mul: float = 2.0, # head_mul factor at stage shift
184 | window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
185 | # window size per stage, when not using global att.
186 | window_spec: Tuple[int, ...] = (
187 | 8,
188 | 4,
189 | 14,
190 | 7,
191 | ),
192 | # global attn in these blocks
193 | global_att_blocks: Tuple[int, ...] = (
194 | 12,
195 | 16,
196 | 20,
197 | ),
198 | weights_path=None,
199 | return_interm_layers=True, # return feats from every stage
200 | ):
201 | super().__init__()
202 |
203 | assert len(stages) == len(window_spec)
204 | self.window_spec = window_spec
205 |
206 | depth = sum(stages)
207 | self.q_stride = q_stride
208 | self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
209 | assert 0 <= q_pool <= len(self.stage_ends[:-1])
210 | self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
211 | self.return_interm_layers = return_interm_layers
212 |
213 | self.patch_embed = PatchEmbed(
214 | embed_dim=embed_dim,
215 | )
216 | # Which blocks have global att?
217 | self.global_att_blocks = global_att_blocks
218 |
219 | # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
220 | self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
221 | self.pos_embed = nn.Parameter(
222 | torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
223 | )
224 | self.pos_embed_window = nn.Parameter(
225 | torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
226 | )
227 |
228 | dpr = [
229 | x.item() for x in torch.linspace(0, drop_path_rate, depth)
230 | ] # stochastic depth decay rule
231 |
232 | cur_stage = 1
233 | self.blocks = nn.ModuleList()
234 |
235 | for i in range(depth):
236 | dim_out = embed_dim
237 | # lags by a block, so first block of
238 | # next stage uses an initial window size
239 | # of previous stage and final window size of current stage
240 | window_size = self.window_spec[cur_stage - 1]
241 |
242 | if self.global_att_blocks is not None:
243 | window_size = 0 if i in self.global_att_blocks else window_size
244 |
245 | if i - 1 in self.stage_ends:
246 | dim_out = int(embed_dim * dim_mul)
247 | num_heads = int(num_heads * head_mul)
248 | cur_stage += 1
249 |
250 | block = MultiScaleBlock(
251 | dim=embed_dim,
252 | dim_out=dim_out,
253 | num_heads=num_heads,
254 | drop_path=dpr[i],
255 | q_stride=self.q_stride if i in self.q_pool_blocks else None,
256 | window_size=window_size,
257 | )
258 |
259 | embed_dim = dim_out
260 | self.blocks.append(block)
261 |
262 | self.channel_list = (
263 | [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
264 | if return_interm_layers
265 | else [self.blocks[-1].dim_out]
266 | )
267 |
268 | if weights_path is not None:
269 | with g_pathmgr.open(weights_path, "rb") as f:
270 | chkpt = torch.load(f, map_location="cpu")
271 | logging.info("loading Hiera", self.load_state_dict(chkpt, strict=False))
272 |
273 | def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
274 | h, w = hw
275 | window_embed = self.pos_embed_window
276 | pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
277 | pos_embed = pos_embed + window_embed.tile(
278 | [x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
279 | )
280 | pos_embed = pos_embed.permute(0, 2, 3, 1)
281 | return pos_embed
282 |
283 | def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
284 | x = self.patch_embed(x)
285 | # x: (B, H, W, C)
286 |
287 | # Add pos embed
288 | x = x + self._get_pos_embed(x.shape[1:3])
289 |
290 | outputs = []
291 | for i, blk in enumerate(self.blocks):
292 | x = blk(x)
293 | if (i == self.stage_ends[-1]) or (
294 | i in self.stage_ends and self.return_interm_layers
295 | ):
296 | feats = x.permute(0, 3, 1, 2)
297 | outputs.append(feats)
298 |
299 | return outputs
300 |
301 | def get_layer_id(self, layer_name):
302 | # https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
303 | num_layers = self.get_num_layers()
304 |
305 | if layer_name.find("rel_pos") != -1:
306 | return num_layers + 1
307 | elif layer_name.find("pos_embed") != -1:
308 | return 0
309 | elif layer_name.find("patch_embed") != -1:
310 | return 0
311 | elif layer_name.find("blocks") != -1:
312 | return int(layer_name.split("blocks")[1].split(".")[1]) + 1
313 | else:
314 | return num_layers + 1
315 |
316 | def get_num_layers(self) -> int:
317 | return len(self.blocks)
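318 |
319 |
320 | # Worked example with the default arguments above (embed_dim=96, num_heads=1,
321 | # stages=(2, 3, 16, 3), dim_mul=head_mul=2.0, q_pool=3):
322 | #   depth = 24, stage_ends = [1, 4, 20, 23], q_pool_blocks = [2, 5, 21]
323 | #   block widths: blocks 0-1 -> 96, 2-4 -> 192, 5-20 -> 384, 21-23 -> 768
324 | #   heads per stage: 1, 2, 4, 8; global attention in blocks 12, 16 and 20
325 | #   channel_list (coarsest first) = [768, 384, 192, 96], which is what the FPN
326 | #   neck checks against its backbone_channel_list in ImageEncoder.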
--------------------------------------------------------------------------------
/models/sam2/sam/transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 | from functools import partial
9 | from typing import Tuple, Type
10 |
11 | import torch
12 | import torch.nn.functional as F
13 | from torch import nn, Tensor
14 |
15 | from models.sam2.position_encoding import apply_rotary_enc, compute_axial_cis
16 | from models.sam2.sam2_utils import MLP
17 |
18 |
19 | class TwoWayTransformer(nn.Module):
20 | def __init__(
21 | self,
22 | depth: int,
23 | embedding_dim: int,
24 | num_heads: int,
25 | mlp_dim: int,
26 | activation: Type[nn.Module] = nn.ReLU,
27 | attention_downsample_rate: int = 2,
28 | ) -> None:
29 | """
30 | A transformer decoder that attends to an input image using
31 | queries whose positional embedding is supplied.
32 |
33 | Args:
34 | depth (int): number of layers in the transformer
35 | embedding_dim (int): the channel dimension for the input embeddings
36 | num_heads (int): the number of heads for multihead attention. Must
37 | divide embedding_dim
38 | mlp_dim (int): the channel dimension internal to the MLP block
39 | activation (nn.Module): the activation to use in the MLP block
40 | """
41 | super().__init__()
42 | self.depth = depth
43 | self.embedding_dim = embedding_dim
44 | self.num_heads = num_heads
45 | self.mlp_dim = mlp_dim
46 | self.layers = nn.ModuleList()
47 |
48 | for i in range(depth):
49 | self.layers.append(
50 | TwoWayAttentionBlock(
51 | embedding_dim=embedding_dim,
52 | num_heads=num_heads,
53 | mlp_dim=mlp_dim,
54 | activation=activation,
55 | attention_downsample_rate=attention_downsample_rate,
56 | skip_first_layer_pe=(i == 0),
57 | )
58 | )
59 |
60 | self.final_attn_token_to_image = Attention(
61 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate
62 | )
63 | self.norm_final_attn = nn.LayerNorm(embedding_dim)
64 |
65 | def forward(
66 | self,
67 | image_embedding: Tensor,
68 | image_pe: Tensor,
69 | point_embedding: Tensor,
70 | ) -> Tuple[Tensor, Tensor]:
71 | """
72 | Args:
73 | image_embedding (torch.Tensor): image to attend to. Should be shape
74 | B x embedding_dim x h x w for any h and w.
75 | image_pe (torch.Tensor): the positional encoding to add to the image. Must
76 | have the same shape as image_embedding.
77 | point_embedding (torch.Tensor): the embedding to add to the query points.
78 | Must have shape B x N_points x embedding_dim for any N_points.
79 |
80 | Returns:
81 | torch.Tensor: the processed point_embedding
82 | torch.Tensor: the processed image_embedding
83 | """
84 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C
85 | bs, c, h, w = image_embedding.shape
86 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
87 | image_pe = image_pe.flatten(2).permute(0, 2, 1)
88 |
89 | # Prepare queries
90 | queries = point_embedding
91 | keys = image_embedding
92 |
93 | # Apply transformer blocks and final layernorm
94 | for layer in self.layers:
95 | queries, keys = layer(
96 | queries=queries,
97 | keys=keys,
98 | query_pe=point_embedding,
99 | key_pe=image_pe,
100 | )
101 |
102 | # Apply the final attention layer from the points to the image
103 | q = queries + point_embedding
104 | k = keys + image_pe
105 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
106 | queries = queries + attn_out
107 | queries = self.norm_final_attn(queries)
108 |
109 | return queries, keys
110 |
111 |
112 | class TwoWayAttentionBlock(nn.Module):
113 | def __init__(
114 | self,
115 | embedding_dim: int,
116 | num_heads: int,
117 | mlp_dim: int = 2048,
118 | activation: Type[nn.Module] = nn.ReLU,
119 | attention_downsample_rate: int = 2,
120 | skip_first_layer_pe: bool = False,
121 | ) -> None:
122 | """
123 | A transformer block with four layers: (1) self-attention of sparse
124 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
125 | block on sparse inputs, and (4) cross attention of dense inputs to sparse
126 | inputs.
127 |
128 | Arguments:
129 | embedding_dim (int): the channel dimension of the embeddings
130 | num_heads (int): the number of heads in the attention layers
131 | mlp_dim (int): the hidden dimension of the mlp block
132 | activation (nn.Module): the activation of the mlp block
133 | skip_first_layer_pe (bool): skip the PE on the first layer
134 | """
135 | super().__init__()
136 | self.self_attn = Attention(embedding_dim, num_heads)
137 | self.norm1 = nn.LayerNorm(embedding_dim)
138 |
139 | self.cross_attn_token_to_image = Attention(
140 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate
141 | )
142 | self.norm2 = nn.LayerNorm(embedding_dim)
143 |
144 | self.mlp = MLP(
145 | embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation
146 | )
147 | self.norm3 = nn.LayerNorm(embedding_dim)
148 |
149 | self.norm4 = nn.LayerNorm(embedding_dim)
150 | self.cross_attn_image_to_token = Attention(
151 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate
152 | )
153 |
154 | self.skip_first_layer_pe = skip_first_layer_pe
155 |
156 | def forward(
157 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
158 | ) -> Tuple[Tensor, Tensor]:
159 | # Self attention block
160 | if self.skip_first_layer_pe:
161 | queries = self.self_attn(q=queries, k=queries, v=queries)
162 | else:
163 | q = queries + query_pe
164 | attn_out = self.self_attn(q=q, k=q, v=queries)
165 | queries = queries + attn_out
166 | queries = self.norm1(queries)
167 |
168 | # Cross attention block, tokens attending to image embedding
169 | q = queries + query_pe
170 | k = keys + key_pe
171 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
172 | queries = queries + attn_out
173 | queries = self.norm2(queries)
174 |
175 | # MLP block
176 | mlp_out = self.mlp(queries)
177 | queries = queries + mlp_out
178 | queries = self.norm3(queries)
179 |
180 | # Cross attention block, image embedding attending to tokens
181 | q = queries + query_pe
182 | k = keys + key_pe
183 | attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
184 | keys = keys + attn_out
185 | keys = self.norm4(keys)
186 |
187 | return queries, keys
188 |
189 |
190 | class Attention(nn.Module):
191 | """
192 | An attention layer that allows for downscaling the size of the embedding
193 | after projection to queries, keys, and values.
194 | """
195 |
196 | def __init__(
197 | self,
198 | embedding_dim: int,
199 | num_heads: int,
200 | downsample_rate: int = 1,
201 | dropout: float = 0.0,
202 | kv_in_dim: int = None,
203 | ) -> None:
204 | super().__init__()
205 | self.embedding_dim = embedding_dim
206 | self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
207 | self.internal_dim = embedding_dim // downsample_rate
208 | self.num_heads = num_heads
209 | assert (
210 | self.internal_dim % num_heads == 0
211 | ), "num_heads must divide embedding_dim."
212 |
213 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
214 | self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
215 | self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
216 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
217 |
218 | self.dropout_p = dropout
219 |
220 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
221 | b, n, c = x.shape
222 | x = x.reshape(b, n, num_heads, c // num_heads)
223 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
224 |
225 | def _recombine_heads(self, x: Tensor) -> Tensor:
226 | b, n_heads, n_tokens, c_per_head = x.shape
227 | x = x.transpose(1, 2)
228 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
229 |
230 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
231 | # Input projections
232 | q = self.q_proj(q)
233 | k = self.k_proj(k)
234 | v = self.v_proj(v)
235 |
236 | # Separate into heads
237 | q = self._separate_heads(q, self.num_heads)
238 | k = self._separate_heads(k, self.num_heads)
239 | v = self._separate_heads(v, self.num_heads)
240 |
241 | dropout_p = self.dropout_p if self.training else 0.0
242 | # Attention
243 | out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
244 |
245 | out = self._recombine_heads(out)
246 | out = self.out_proj(out)
247 |
248 | return out
249 |
250 |
251 | class RoPEAttention(Attention):
252 | """Attention with rotary position encoding."""
253 |
254 | def __init__(
255 | self,
256 | *args,
257 | rope_theta=10000.0,
258 | # whether to repeat q rope to match k length
259 | # this is needed for cross-attention to memories
260 | rope_k_repeat=False,
261 | feat_sizes=(64, 64), # [w, h] for stride 16 feats at 1024 resolution
262 | **kwargs,
263 | ):
264 | super().__init__(*args, **kwargs)
265 |
266 | self.compute_cis = partial(
267 | compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
268 | )
269 | freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
270 | self.freqs_cis = (
271 | freqs_cis.to("cuda") if torch.cuda.is_available() else freqs_cis
272 | )
273 | self.rope_k_repeat = rope_k_repeat
274 |
275 | def forward(
276 | self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0
277 | ) -> Tensor:
278 | # Input projections
279 | q = self.q_proj(q)
280 | k = self.k_proj(k)
281 | v = self.v_proj(v)
282 |
283 | # Separate into heads
284 | q = self._separate_heads(q, self.num_heads)
285 | k = self._separate_heads(k, self.num_heads)
286 | v = self._separate_heads(v, self.num_heads)
287 |
288 | # Apply rotary position encoding
289 | w = h = math.sqrt(q.shape[-2])
290 | self.freqs_cis = self.freqs_cis.to(q.device)
291 | if self.freqs_cis.shape[0] != q.shape[-2]:
292 | self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
293 | if q.shape[-2] != k.shape[-2]:
294 | assert self.rope_k_repeat
295 |
296 | num_k_rope = k.size(-2) - num_k_exclude_rope
297 | q, k[:, :, :num_k_rope] = apply_rotary_enc(
298 | q,
299 | k[:, :, :num_k_rope],
300 | freqs_cis=self.freqs_cis,
301 | repeat_freqs_k=self.rope_k_repeat,
302 | )
303 |
304 | dropout_p = self.dropout_p if self.training else 0.0
305 | # Attention
306 | out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
307 |
308 | out = self._recombine_heads(out)
309 | out = self.out_proj(out)
310 |
311 | return out
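312 |
313 |
314 | # Shape sketch for TwoWayTransformer.forward (example sizes only; the actual
315 | # embedding_dim and feature resolution come from the model config):
316 | #   image_embedding: (B, 256, 64, 64) -> flattened to keys of shape (B, 4096, 256)
317 | #   point_embedding: (B, N, 256)      -> used as queries
318 | # Each TwoWayAttentionBlock runs query self-attention, token-to-image cross
319 | # attention, an MLP on the queries, and image-to-token cross attention; the
320 | # final outputs are queries (B, N, 256) and keys (B, 4096, 256).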
--------------------------------------------------------------------------------
/datasets/data_loader.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | from monai import data, transforms
5 | import itertools
6 | from torch.utils.data.distributed import DistributedSampler
7 | from torch.utils.data import Dataset, ConcatDataset
8 | import os, functools  # functools.partial can bind the description-dictionary path required by collate_fn
9 | import ast
10 | from scipy import sparse
11 | import random
12 | import json
13 | from torch.utils.data import DataLoader
14 | from monai.transforms import (
15 | Compose,
16 | AddChanneld,
17 | CropForegroundd,
18 | SpatialPadd,
19 | Resized,
20 | RandCropByPosNegLabeld,
21 | RandFlipd,
22 | RandScaleIntensityd,
23 | RandShiftIntensityd,
24 | ToTensord,
25 | )
26 |
27 | def read_json_file(file_path):
28 | try:
29 | with open(file_path, 'r', encoding='utf-8') as file:
30 | data = json.load(file)
31 | return data
32 | except Exception as e:
33 | print(f"{e}")
34 | return None
35 |
36 |
37 | class UnionDataset(Dataset):
38 | def __init__(self, concat_dataset, datasets):
39 | self.datasets = datasets
40 | self.lengths = [len(d) for d in datasets]
41 | self.offsets = torch.cumsum(torch.tensor([0] + self.lengths), dim=0)
42 | self.concat_dataset = concat_dataset
43 |
44 | def __len__(self):
45 | return sum(self.lengths)
46 |
47 | def __getitem__(self, idx):
48 | return self.concat_dataset[idx]
49 |
50 |
51 | class UniversalDataset(Dataset):
52 | def __init__(self, data, transform, test_mode, organ_list):
53 | self.data = data
54 | self.transform = transform
55 | # one pos point is base set
56 | self.num_positive_extra_max = 10
57 | self.num_negative_extra_max = 10
58 | self.test_mode = test_mode
59 | self.bbox_shift = 10 if test_mode else 0
60 | print(organ_list)
61 | organ_list.remove('background')
62 | self.target_list = organ_list
63 |
64 | def __len__(self):
65 | return len(self.data)
66 |
67 | def __getitem__(self, idx):
68 | # get path
69 | item_dict = self.data[idx]
70 | ct_path, gt_path = item_dict['image'], item_dict['label']
71 | gt_shape = ast.literal_eval(gt_path.split('.')[-2].split('_')[-1])
72 |
73 | # load data
74 | ct_array = np.load(ct_path)[0]
75 | allmatrix_sp = sparse.load_npz(gt_path)
76 | gt_array = allmatrix_sp.toarray().reshape(gt_shape)
77 |
78 | # transform
79 | if self.test_mode:
80 | item_ori = {
81 | 'image': ct_array,
82 | 'label': gt_array,
83 | }
84 | else:
85 | item_ori = {
86 | 'image': ct_array,
87 | 'label': gt_array,
88 | }
89 | # fall back to the untransformed sample if no transform is configured (avoids `item` being undefined)
90 | item = self.transform(item_ori) if self.transform is not None else item_ori
91 |
92 | if type(item) == list:
93 | assert len(item) == 1
94 | item = item[0]
95 |
96 | assert type(item) != list
97 | item['organ_name_list'] = self.target_list
98 | item['post_label'] = item['label']
99 | post_item = self.std_keys(item)
100 | return post_item
101 |
102 | def std_keys(self, post_item):
103 | keys_to_remain = ['image', 'post_label', 'organ_name_list']
104 | keys_to_remove = post_item.keys() - keys_to_remain
105 | for key in keys_to_remove:
106 | del post_item[key]
107 | return post_item
108 |
109 |
110 | def query_descriptions(input_name, data, min_count=4, max_count=6):
111 | input_name = input_name.strip().lower()
112 | matching_keys = [k for k in data.keys() if k.lower() == input_name]
113 |
114 | if not matching_keys:
115 | return []
116 |
117 | descriptions = data[matching_keys[0]]
118 | count = len(descriptions)
119 |
120 | select_count = min(max_count, count)
121 | select_count = max(min_count, select_count)
122 |
123 | selected = descriptions[:select_count]
124 |
125 | return selected
126 |
127 |
128 |
129 | class BatchedDistributedSampler(DistributedSampler):
130 | def __init__(self, dataset, shuffle, batch_size, num_replicas=None, rank=None):
131 | super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
132 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
133 | self.total_size = self.num_samples * self.num_replicas
134 | self.batch_size = batch_size
135 |
136 | def __iter__(self):
137 | print('run BatchedDistributedSampler iter')
138 | indices = list(range(len(self.dataset)))
139 | # indices += indices[:(self.total_size - len(indices))]
140 | # assert len(indices) == self.total_size
141 |
142 | indices = [indices[i:i + l] for i, l in zip(self.dataset.offsets[:-1], self.dataset.lengths)]
143 |
144 | if self.shuffle:
145 | for idx, subset_indices in enumerate(indices):
146 | random.shuffle(indices[idx])
147 |
148 | # drop subset last
149 | for idx, subset_indices in enumerate(indices):
150 | r = len(subset_indices) % self.batch_size
151 | if r > 0:
152 | indices[idx] = indices[idx][:-r]
153 | indices = list(itertools.chain(*indices))
154 | indices = [indices[i:i + self.batch_size] for i in range(0, len(indices), self.batch_size)]
155 | if self.shuffle:
156 | random.shuffle(indices)
157 |
158 | batch_num = len(indices)
159 | replicas_size = batch_num // self.num_replicas
160 | start = self.rank * replicas_size
161 | end = start + replicas_size if self.rank != self.num_replicas - 1 else batch_num
162 | batched_indices = list(itertools.chain(*(indices[start:end])))
163 | ##
164 | indices = list(itertools.chain(*indices))
165 | self.total_size = len(indices)
166 | self.num_samples = self.total_size // self.num_replicas
167 | ##
168 | return iter(batched_indices)
169 |
170 |
171 | def collate_fn(batch, dictionary_path):
172 | term_dictionary = read_json_file(dictionary_path)
173 | images = []
174 | organ_name_list = None
175 | post_labels = []
176 | organ_description_list = []
177 | for sample in batch:
178 | images.append(sample['image'])
179 | assert organ_name_list is None or organ_name_list == sample['organ_name_list']
180 | organ_name_list = sample['organ_name_list']
181 | for organ in organ_name_list:
182 | organ_description_list.append(query_descriptions(organ, term_dictionary, 4, 6))
183 | post_labels.append(sample['post_label'])
184 | return {
185 | 'image': torch.stack(images, dim=0),
186 | 'organ_name_list': organ_name_list,
187 | 'post_label': torch.stack(post_labels, dim=0),
188 | 'text': organ_description_list
189 | }
190 |
191 |
192 | class MinMaxNormalization(transforms.Transform):
193 | def __call__(self, data):
194 | d = dict(data)
195 | k = "image"
196 | d[k] = d[k] - d[k].min()
197 | d[k] = d[k] / np.clip(d[k].max(), a_min=1e-8, a_max=None)
198 | return d
199 |
200 |
201 | class DimTranspose(transforms.Transform):
202 | def __init__(self, keys):
203 | self.keys = keys
204 |
205 | def __call__(self, data):
206 | d = dict(data)
207 | for key in self.keys:
208 | d[key] = np.swapaxes(d[key], -1, -3)
209 | return d
210 |
211 |
212 | def build_concat_dataset(root_path, dataset_codes, transform):
213 | concat_dataset = []
214 | CombinationDataset_len = 0
215 | for dataset_code in dataset_codes:
216 | datalist_json = os.path.join(root_path, dataset_code, f'{dataset_code}.json')
217 | with open(datalist_json, 'r') as f:
218 | dataset_dict = json.load(f)
219 | datalist = dataset_dict['train']
220 | universal_ds = UniversalDataset(data=datalist, transform=transform, test_mode=False,
221 | organ_list=list(dataset_dict['labels'].values()))
222 | concat_dataset.append(universal_ds)
223 | CombinationDataset_len += len(universal_ds)
224 | print(f'CombinationDataset loaded, dataset size: {CombinationDataset_len}')
225 | return UnionDataset(ConcatDataset(concat_dataset), concat_dataset)
226 |
227 |
228 | def get_loader(args):
229 | if args.mode == 'train':
230 | train_transform = Compose(
231 | [
232 | AddChanneld(keys=["image"]),
233 | DimTranspose(keys=["image", "label"]),
234 | MinMaxNormalization(),
235 | CropForegroundd(keys=["image", "label"], source_key="image"),
236 | SpatialPadd(keys=["image", "label"], spatial_size=args.spatial_size,
237 | mode='constant'),
238 | transforms.OneOf(transforms=[
239 | Resized(keys=["image", "label"], spatial_size=args.spatial_size),
240 | RandCropByPosNegLabeld(
241 | keys=["image", "label"],
242 | label_key="label",
243 | spatial_size=args.spatial_size,
244 | pos=2,
245 | neg=1,
246 | num_samples=1,
247 | image_key="image",
248 | image_threshold=0,
249 | ),
250 | ],
251 | weights=[1, 1]
252 | ),
253 | RandFlipd(keys=["image", "label"], prob=args.rand_flipped_prob, spatial_axis=0),
254 | RandFlipd(keys=["image", "label"], prob=args.rand_flipped_prob, spatial_axis=1),
255 | RandFlipd(keys=["image", "label"], prob=args.rand_flipped_prob, spatial_axis=2),
256 | RandScaleIntensityd(keys="image", factors=0.1, prob=args.rand_scale_intensityd_prob),
257 | RandShiftIntensityd(keys="image", offsets=0.1, prob=args.rand_shift_intensityd_prob),
258 | Resized(keys=["image", "label"], spatial_size=args.spatial_size),
259 | ToTensord(keys=["image", "label"]),
260 | ]
261 | )
262 |
263 | print(f'----- train on combination dataset -----')
264 | combination_train_ds = build_concat_dataset(root_path=args.data_dir, dataset_codes=args.dataset_codes,
265 | transform=train_transform)
266 | train_sampler = BatchedDistributedSampler(combination_train_ds, shuffle=True,
267 | batch_size=args.batch_size) if args.dist else None
268 | train_loader = DataLoader(
269 | combination_train_ds,
270 | batch_size=args.batch_size,
271 | shuffle=(train_sampler is None),
272 | num_workers=args.num_workers,
273 | sampler=train_sampler,
274 | pin_memory=True,
275 | persistent_workers=True,
276 | collate_fn=functools.partial(collate_fn, dictionary_path=args.dictionary_path),  # collate_fn needs the term-description JSON path; the attribute name here is an assumption
277 | )
278 | return train_loader
279 | elif args.mode == 'test':
280 | test_transform = Compose(
281 | [
282 | AddChanneld(keys=["image"]),
283 | DimTranspose(keys=["image", "label"]),
284 | MinMaxNormalization(),
285 | CropForegroundd(keys=["image", "label"], source_key="image"),
286 | SpatialPadd(keys=["image", "label"], spatial_size=args.spatial_size,
287 | mode='constant'),
288 | Resized(keys=["image", "label"], spatial_size=args.spatial_size),
289 | ToTensord(keys=["image", "label"]),
290 | ]
291 | )
292 |
293 | print(f'----- test on combination dataset -----')
294 | combination_test_ds = build_concat_dataset(root_path=args.data_dir, dataset_codes=args.dataset_codes,
295 | transform=test_transform)
296 | test_sampler = BatchedDistributedSampler(combination_test_ds, shuffle=False,
297 | batch_size=args.batch_size) if args.dist else None
298 | test_loader = DataLoader(
299 | combination_test_ds,
300 | batch_size=args.batch_size,
301 | shuffle=False,
302 | num_workers=args.num_workers,
303 | sampler=test_sampler,
304 | pin_memory=True,
305 | persistent_workers=True,
306 | collate_fn=functools.partial(collate_fn, dictionary_path=args.dictionary_path),  # same assumed attribute as in the train loader
307 | )
308 | return test_loader
309 | else:
310 | raise ValueError("mode should be either 'train' or 'test'")
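311 |
312 |
313 | # Usage notes (descriptive). get_loader(args) reads the following attributes:
314 | #   args.mode ('train' or 'test'), args.data_dir, args.dataset_codes,
315 | #   args.spatial_size, args.batch_size, args.num_workers, args.dist,
316 | #   args.rand_flipped_prob, args.rand_scale_intensityd_prob, args.rand_shift_intensityd_prob
317 | # Each dataset code directory must contain '<code>.json' with 'train' and 'labels'
318 | # entries; image paths must be loadable by np.load and label paths must point to
319 | # sparse .npz matrices whose dense shape is encoded at the end of the filename.
320 | # The term-description dictionary path must be bound into collate_fn (e.g. with
321 | # functools.partial) before it is handed to the DataLoader. Every batch produced
322 | # by collate_fn is a dict with keys 'image', 'post_label', 'organ_name_list' and
323 | # 'text' (the per-organ descriptions drawn from that dictionary).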
--------------------------------------------------------------------------------
/datasets/dataset_info.txt:
--------------------------------------------------------------------------------
1 | dataset_name = {
2 | '0000': 'CHAOS',
3 | '0001': 'HaN-Seg',
4 | '0002': 'AMOS22',
5 | '0003': 'AbdomenCT-1k',
6 | '0004': 'KiTS23',
7 | '0005': 'KiPA22',
8 | '0006': 'KiTS19',
9 | '0007': 'BCTV',
10 | '0008': 'Pancreas-CT',
11 | '0009': '3D-IRCADb',
12 | '0010': 'AbdomenCT-12organ',
13 | '0011': 'TotalSegmentator',
14 | '0012': 'CT-ORG',
15 | '0013': 'WORD',
16 | '0014': 'VerSe19',
17 | '0015': 'VerSe20',
18 | '0016': 'SLIVER07',
19 | '0017': 'QUBIQ',
20 | '0018': 'MSD-colon',
21 | '0019': 'MSD-hepatic_vessel',
22 | '0020': 'MSD-liver',
23 | '0021': 'MSD-lung',
24 | '0022': 'MSD-pancreas',
25 | '0023': 'MSD-spleen',
26 | '0024': 'LUNA16',
27 | }
28 |
29 | dataset_info_raw = {
30 | '0000': ['liver'],
31 | '0001': ['A_Carotid_L','A_Carotid_R','Arytenoid','Bone_Mandible','Brainstem','BuccalMucosa','Cavity_Oral','Cochlea_L','Cochlea_R','Cricopharyngeus','Esophagus_S','Eye_AL','Eye_AR','Eye_PL','Eye_PR','Glnd_Lacrimal_L','Glnd_Lacrimal_R','Glnd_Submand_L','Glnd_Submand_R','Glnd_Thyroid','Glottis','Larynx_SG','Lips','OpticChiasm','OpticNrv_L','OpticNrv_R','Parotid_L','Parotid_R','Pituitary','SpinalCord'],
32 | '0002': ["spleen", "right kidney", "left kidney", "gall bladder", "esophagus", "liver", "stomach", "arota", "postcava", "pancreas", "right adrenal gland", "left adrenal gland", "duodenum", "bladder", "prostate/uterus"],
33 | '0003': ['liver', 'kidney', 'spleen', 'pancreas'],
34 | '0004': ['kidney', 'kidney tumor', 'kidney cyst'],
35 | '0005': ['Renal vein', 'Kidney', 'Renal artery', 'Tumor'],
36 | '0006': ['kidney', 'kidney tumor'],
37 | '0007': ['spleen','right kidney','left kidney','gallbladder','esophagus','liver','stomach','aorta','inferior vena cava','portal vein and splenic vein','pancreas','right adrenal gland','left adrenal gland'],
38 | '0008': ['pancreas'],
39 | '0009': ['Stones', 'artery', 'biliarysystem', 'bladder', 'bone', 'colon', 'gallbladder', 'heart', 'kidneys', 'leftkidney', 'leftlung', 'leftsurrenalgland', 'leftsurretumor', 'liver', 'livercyst', 'liverkyst', 'liverkyste', 'livertumor', 'livertumor01', 'livertumor02', 'livertumor03', 'livertumor04', 'livertumor05', 'livertumor06', 'livertumor07', 'livertumor1', 'livertumor2', 'livertumors', 'lungs', 'metal', 'metastasectomie', 'pancreas', 'portalvein', 'portalvein1', 'rightkidney', 'rightlung', 'rightsurrenalgland', 'rightsurretumor', 'skin', 'smallintestin', 'spleen', 'stomach', 'surrenalgland', 'tumor', 'uterus', 'venacava', 'venoussystem'],
40 | '0010': ['liver', 'right kidney', 'spleen', 'pancreas', 'aorta', 'inferior vena cava', 'right adrenal gland', 'left adrenal gland', 'gallbladder', 'esophagus', 'stomach', 'duodenum', 'left kidney'],
41 | '0011': ['adrenal_gland_left', 'adrenal_gland_right', 'aorta', 'autochthon_left', 'autochthon_right', 'brain', 'clavicula_left', 'clavicula_right', 'colon', 'duodenum', 'esophagus', 'face', 'femur_left', 'femur_right', 'gallbladder', 'gluteus_maximus_left', 'gluteus_maximus_right', 'gluteus_medius_left', 'gluteus_medius_right', 'gluteus_minimus_left', 'gluteus_minimus_right', 'heart_atrium_left', 'heart_atrium_right', 'heart_myocardium', 'heart_ventricle_left', 'heart_ventricle_right', 'hip_left', 'hip_right', 'humerus_left', 'humerus_right', 'iliac_artery_left', 'iliac_artery_right', 'iliac_vena_left', 'iliac_vena_right', 'iliopsoas_left', 'iliopsoas_right', 'inferior_vena_cava', 'kidney_left', 'kidney_right', 'liver', 'lung_lower_lobe_left', 'lung_lower_lobe_right', 'lung_middle_lobe_right', 'lung_upper_lobe_left', 'lung_upper_lobe_right', 'pancreas', 'portal_vein_and_splenic_vein', 'pulmonary_artery', 'rib_left_1', 'rib_left_10', 'rib_left_11', 'rib_left_12', 'rib_left_2', 'rib_left_3', 'rib_left_4', 'rib_left_5', 'rib_left_6', 'rib_left_7', 'rib_left_8', 'rib_left_9', 'rib_right_1', 'rib_right_10', 'rib_right_11', 'rib_right_12', 'rib_right_2', 'rib_right_3', 'rib_right_4', 'rib_right_5', 'rib_right_6', 'rib_right_7', 'rib_right_8', 'rib_right_9', 'sacrum', 'scapula_left', 'scapula_right', 'small_bowel', 'spleen', 'stomach', 'trachea', 'urinary_bladder', 'vertebrae_C1', 'vertebrae_C2', 'vertebrae_C3', 'vertebrae_C4', 'vertebrae_C5', 'vertebrae_C6', 'vertebrae_C7', 'vertebrae_L1', 'vertebrae_L2', 'vertebrae_L3', 'vertebrae_L4', 'vertebrae_L5', 'vertebrae_T1', 'vertebrae_T10', 'vertebrae_T11', 'vertebrae_T12', 'vertebrae_T2', 'vertebrae_T3', 'vertebrae_T4', 'vertebrae_T5', 'vertebrae_T6', 'vertebrae_T7', 'vertebrae_T8', 'vertebrae_T9'],
42 | '0012': ['Liver','Bladder','Lungs','Kidneys','Bone','Brain'],
43 | '0013': ['Liver','Spleen','Kidney (L)','Kidney (R)','Stomach','Gallbladder','Esophagus','Pancreas','Duodenum','Colon','Intestine','Adrenal','Rectum','Bladder','Head of femur (L)','Head of femur (R)'],
44 | '0014': ['cervical spine C1', 'cervical spine C2', 'cervical spine C3', 'cervical spine C4', 'cervical spine C5', 'cervical spine C6', 'cervical spine C7', 'thoracic spine T1', 'thoracic spine T2', 'thoracic spine T3', 'thoracic spine T4', 'thoracic spine T5', 'thoracic spine T6', 'thoracic spine T7', 'thoracic spine T8', 'thoracic spine T9', 'thoracic spine T10', 'thoracic spine T11', 'thoracic spine T12', 'lumbar spine L1', 'lumbar spine L1', 'lumbar spine L3', 'lumbar spine L4', 'lumbar spine L5', 'lumbar spine L6', 'sacrum','cocygis','additional 13th thoracic vertebra, T13',],
45 | '0015': ['cervical spine C1', 'cervical spine C2', 'cervical spine C3', 'cervical spine C4', 'cervical spine C5', 'cervical spine C6', 'cervical spine C7', 'thoracic spine T1', 'thoracic spine T2', 'thoracic spine T3', 'thoracic spine T4', 'thoracic spine T5', 'thoracic spine T6', 'thoracic spine T7', 'thoracic spine T8', 'thoracic spine T9', 'thoracic spine T10', 'thoracic spine T11', 'thoracic spine T12', 'lumbar spine L1', 'lumbar spine L1', 'lumbar spine L3', 'lumbar spine L4', 'lumbar spine L5', 'lumbar spine L6', 'sacrum','cocygis','additional 13th thoracic vertebra, T13',],
46 | '0016': ['liver'],
47 | '0017': ['kidney', 'pancreas', 'pancreatic-lesion'],
48 | '0018': ['colon cancer'],
49 | '0019': ['hepatic vessels', 'tumour'],
50 | '0020': ['liver', 'tumour'],
51 | '0021': ['lung tumours'],
52 | '0022': ['pancreas', 'tumour'],
53 | '0023': ['spleen'],
54 | '0024': ['left lung', 'right lung', 'trachea'],
55 | }
56 |
57 |
58 |
59 | dataset_info = {
60 | '0000': ['liver'],
61 | '0001': ['carotid artery left', 'carotid artery right', 'arytenoid', 'bone mandible', 'brainstem',
62 | 'buccal mucosa',
63 | 'oral cavity', 'cochlea left', 'cochlea right', 'cricopharyngeal inlet', 'cervical esophagus',
64 | 'anterior eyeball left', 'anterior eyeball right',
65 | 'posterior eyeball left', 'posterior eyeball right', 'lacrimal gland left', 'lacrimal gland right',
66 | 'submandibular gland left', 'submandibular gland right',
67 | 'thyroid', 'larynx glottis', 'larynx supraglottic', 'lips', 'optic chiasm', 'optic nerve left',
68 | 'optic nerve right',
69 | 'parotid gland left', 'parotid gland right', 'pituitary gland', 'spinal cord'],
70 | '0002': ["spleen", "right kidney", "left kidney", "gall bladder", "esophagus", "liver", "stomach", "aorta",
71 | "postcava", "pancreas", "right adrenal gland", "left adrenal gland", "duodenum", "bladder",
72 | "prostate or uterus"],
73 | '0003': ['liver', 'kidney', 'spleen', 'pancreas'],
74 | '0004': ['kidney', 'kidney tumor', 'kidney cyst'],
75 | '0005': ['renal vein', 'kidney', 'renal artery', 'tumor'],
76 | '0006': ['kidney', 'kidney tumor'],
77 | '0007': ['spleen', 'right kidney', 'left kidney', 'gallbladder', 'esophagus', 'liver', 'stomach', 'aorta',
78 | 'inferior vena cava', 'portal vein and splenic vein', 'pancreas', 'right adrenal gland',
79 | 'left adrenal gland'],
80 | '0008': ['pancreas'],
81 | '0009': ['stones', 'artery', 'biliary system', 'bladder', 'bone', 'colon', 'gallbladder', 'heart',
82 | 'kidneys',
83 | 'left kidney', 'left lung', 'left suprarenal gland', 'left suprarenal tumor', 'liver',
84 | 'liver cyst', 'liver kyst',
85 |
86 | 'liver kyste', 'liver tumor', 'liver tumor 01', 'liver tumor 02', 'liver tumor 03',
87 | 'liver tumor 04',
88 | 'liver tumor 05', 'liver tumor 06', 'liver tumor 07', 'liver tumor 1', 'liver tumor 2',
89 | 'liver tumors',
90 |
91 | 'lungs', 'metal', 'metastasectomie', 'pancreas', 'portal vein', 'portal vein 1', 'right kidney',
92 | 'right lung', 'right suprarenal gland', 'right suprarenal tumor', 'skin', 'small intestin',
93 | 'spleen', 'stomach',
94 | 'suprarenal gland', 'tumor', 'uterus', 'vena cava', 'venous system'],
95 | '0010': ['liver', 'right kidney', 'spleen', 'pancreas', 'aorta', 'inferior vena cava',
96 | 'right adrenal gland', 'left adrenal gland', 'gallbladder', 'esophagus', 'stomach', 'duodenum',
97 | 'left kidney'],
98 | '0011': ['adrenal gland left', 'adrenal gland right', 'aorta', 'autochthon left', 'autochthon right',
99 | 'brain', 'clavicula left', 'clavicula right', 'colon', 'duodenum', 'esophagus', 'face',
100 | 'femur left', 'femur right', 'gallbladder', 'gluteus maximus left', 'gluteus maximus right',
101 | 'gluteus medius left', 'gluteus medius right', 'gluteus minimus left', 'gluteus minimus right',
102 |
103 | 'heart atrium left', 'heart atrium right', 'heart myocardium', 'heart ventricle left',
104 | 'heart ventricle right', 'hip left', 'hip right', 'humerus left', 'humerus right',
105 | 'iliac artery left', 'iliac artery right', 'iliac vena left', 'iliac vena right', 'iliopsoas left',
106 |
107 | 'iliopsoas right', 'inferior vena cava', 'kidney left', 'kidney right', 'liver',
108 | 'lung lower lobe left', 'lung lower lobe right', 'lung middle lobe right', 'lung upper lobe left',
109 | 'lung upper lobe right', 'pancreas', 'portal vein and splenic vein', 'pulmonary artery',
110 |
111 | 'rib left 1', 'rib left 10', 'rib left 11', 'rib left 12', 'rib left 2', 'rib left 3',
112 | 'rib left 4', 'rib left 5', 'rib left 6', 'rib left 7', 'rib left 8', 'rib left 9', 'rib right 1',
113 |
114 | 'rib right 10', 'rib right 11', 'rib right 12', 'rib right 2', 'rib right 3', 'rib right 4',
115 | 'rib right 5', 'rib right 6', 'rib right 7', 'rib right 8', 'rib right 9', 'sacrum',
116 |
117 | 'scapula left', 'scapula right', 'small bowel', 'spleen', 'stomach', 'trachea', 'urinary bladder',
118 | 'vertebrae C1', 'vertebrae C2', 'vertebrae C3', 'vertebrae C4', 'vertebrae C5', 'vertebrae C6',
119 |
120 | 'vertebrae C7', 'vertebrae L1', 'vertebrae L2', 'vertebrae L3', 'vertebrae L4', 'vertebrae L5',
121 | 'vertebrae T1', 'vertebrae T10', 'vertebrae T11', 'vertebrae T12', 'vertebrae T2', 'vertebrae T3',
122 | 'vertebrae T4', 'vertebrae T5', 'vertebrae T6', 'vertebrae T7', 'vertebrae T8', 'vertebrae T9'],
123 | '0012': ['liver', 'bladder', 'lungs', 'kidneys', 'bone', 'brain'],
124 | '0013': ['liver', 'spleen', 'kidney left', 'kidney right', 'stomach', 'gallbladder', 'esophagus',
125 | 'pancreas',
126 | 'duodenum', 'colon', 'intestine', 'adrenal', 'rectum', 'bladder', 'head of femur left',
127 | 'head of femur right'],
128 | '0014': ['cervical spine C1', 'cervical spine C2', 'cervical spine C3', 'cervical spine C4',
129 | 'cervical spine C5', 'cervical spine C6', 'cervical spine C7', 'thoracic spine T1',
130 |
131 | 'thoracic spine T2', 'thoracic spine T3', 'thoracic spine T4', 'thoracic spine T5',
132 | 'thoracic spine T6', 'thoracic spine T7', 'thoracic spine T8', 'thoracic spine T9',
133 |
134 | 'thoracic spine T10', 'thoracic spine T11', 'thoracic spine T12', 'lumbar spine L1',
135 | 'lumbar spine L2', 'lumbar spine L3', 'lumbar spine L4', 'lumbar spine L5', 'lumbar spine L6',
136 | 'sacrum', 'coccygis', 'additional 13th thoracic vertebra, T13', ],
137 | '0015': ['cervical spine C1', 'cervical spine C2', 'cervical spine C3', 'cervical spine C4',
138 | 'cervical spine C5', 'cervical spine C6', 'cervical spine C7', 'thoracic spine T1',
139 | 'thoracic spine T2', 'thoracic spine T3', 'thoracic spine T4', 'thoracic spine T5',
140 | 'thoracic spine T6', 'thoracic spine T7', 'thoracic spine T8', 'thoracic spine T9',
141 | 'thoracic spine T10', 'thoracic spine T11', 'thoracic spine T12', 'lumbar spine L1',
142 | 'lumbar spine L2', 'lumbar spine L3', 'lumbar spine L4', 'lumbar spine L5', 'lumbar spine L6',
143 | 'sacrum', 'coccygis', 'additional 13th thoracic vertebra T13', ],
144 | '0016': ['liver'],
145 | '0017': ['kidney', 'pancreas', 'pancreatic lesion'],
146 | '0018': ['colon cancer'],
147 | '0019': ['hepatic vessels', 'tumour'],
148 | '0020': ['liver', 'tumour'],
149 | '0021': ['lung tumours'],
150 | '0022': ['pancreas', 'tumour'],
151 | '0023': ['spleen'],
152 | '0024': ['left lung', 'right lung', 'trachea'],
153 | }
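# Note: all three dictionaries above are keyed by the same four-digit dataset code, e.g.
# '0004' -> dataset_name 'KiTS23' with classes ['kidney', 'kidney tumor', 'kidney cyst'];
# dataset_info holds the human-readable counterparts of the raw label names in dataset_info_raw.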
--------------------------------------------------------------------------------
/models/components/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | # Modified from https://github.com/open-mmlab/mmdetection/blob/main/projects/EfficientDet/efficientdet/utils.py
3 |
4 | import math
5 | from typing import Tuple, Union
6 |
7 | import torch
8 | import torch.nn as nn
9 | from mmcv.cnn.bricks import Swish, build_norm_layer
10 | from torch.nn.init import _calculate_fan_in_and_fan_out, trunc_normal_
11 | from mmdet.registry import MODELS
12 | from mmdet.utils import OptConfigType
13 | from torch.autograd import Function
14 | import torch.nn.functional as F
15 | from torch.nn.modules.utils import _triple, _pair, _single
16 |
17 | import adapool_cuda
18 |
19 |
20 | def variance_scaling_trunc(tensor, gain=1.):
21 | fan_in, _ = _calculate_fan_in_and_fan_out(tensor)
22 | gain /= max(1.0, fan_in)
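# 0.87962566103423978 is the standard deviation of a standard normal truncated to
# (-2, 2); dividing by it keeps the variance of the truncated samples at ~gain / fan_in.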
23 | std = math.sqrt(gain) / .87962566103423978
24 | return trunc_normal_(tensor, 0., std)
25 |
26 |
27 | @MODELS.register_module()
28 | class Conv2dSamePadding(nn.Conv2d):
29 |
30 | def __init__(self,
31 | in_channels: int,
32 | out_channels: int,
33 | kernel_size: Union[int, Tuple[int, int]],
34 | stride: Union[int, Tuple[int, int]] = 1,
35 | padding: Union[int, Tuple[int, int]] = 0,
36 | dilation: Union[int, Tuple[int, int]] = 1,
37 | groups: int = 1,
38 | bias: bool = True):
39 | super().__init__(in_channels, out_channels, kernel_size, stride, 0,
40 | dilation, groups, bias)
41 |
42 | def forward(self, x: torch.Tensor) -> torch.Tensor:
43 | img_h, img_w = x.size()[-2:]
44 | kernel_h, kernel_w = self.weight.size()[-2:]
45 | extra_w = (math.ceil(img_w / self.stride[1]) -
46 | 1) * self.stride[1] - img_w + kernel_w
47 | extra_h = (math.ceil(img_h / self.stride[0]) -
48 | 1) * self.stride[0] - img_h + kernel_h
49 |
50 | left = extra_w // 2
51 | right = extra_w - left
52 | top = extra_h // 2
53 | bottom = extra_h - top
54 | x = F.pad(x, [left, right, top, bottom])
55 | return F.conv2d(x, self.weight, self.bias, self.stride, self.padding,
56 | self.dilation, self.groups)
57 |
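# Worked example of the "same"-padding arithmetic used by Conv2dSamePadding above (and
# MaxPool2dSamePadding below), added for illustration: with a 224x224 input, kernel 3 and
# stride 2, extra = (ceil(224 / 2) - 1) * 2 - 224 + 3 = 1, so one column/row of zeros is
# padded (0 left/top, 1 right/bottom) and the output is 112x112 = ceil(224 / 2).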
58 |
59 | class MaxPool2dSamePadding(nn.Module):
60 |
61 | def __init__(self,
62 | kernel_size: Union[int, Tuple[int, int]] = 3,
63 | stride: Union[int, Tuple[int, int]] = 2,
64 | **kwargs):
65 | super().__init__()
66 | self.pool = nn.MaxPool2d(kernel_size, stride, **kwargs)
67 | self.stride = self.pool.stride
68 | self.kernel_size = self.pool.kernel_size
69 |
70 | if isinstance(self.stride, int):
71 | self.stride = [self.stride] * 2
72 | if isinstance(self.kernel_size, int):
73 | self.kernel_size = [self.kernel_size] * 2
74 |
75 | def forward(self, x):
76 | h, w = x.shape[-2:]
77 |
78 | extra_h = (math.ceil(w / self.stride[1]) -
79 | 1) * self.stride[1] - w + self.kernel_size[1]
80 | extra_v = (math.ceil(h / self.stride[0]) -
81 | 1) * self.stride[0] - h + self.kernel_size[0]
82 |
83 | left = extra_h // 2
84 | right = extra_h - left
85 | top = extra_v // 2
86 | bottom = extra_v - top
87 |
88 | x = F.pad(x, [left, right, top, bottom])
89 | x = self.pool(x)
90 |
91 | return x
92 |
93 |
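# Depthwise-separable convolution block: a 3x3 depthwise conv (groups = in_channels)
# followed by a 1x1 pointwise conv, optionally with BatchNorm and Swish activation.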
94 | class DepthWiseConvBlock(nn.Module):
95 |
96 | def __init__(
97 | self,
98 | in_channels: int,
99 | out_channels: int,
100 | apply_norm: bool = True,
101 | conv_bn_act_pattern: bool = False,
102 | norm_cfg: OptConfigType = dict(type='BN', momentum=1e-2, eps=1e-3)
103 | ) -> None:
104 | super(DepthWiseConvBlock, self).__init__()
105 | self.depthwise_conv = Conv2dSamePadding(
106 | in_channels,
107 | in_channels,
108 | kernel_size=3,
109 | stride=1,
110 | groups=in_channels,
111 | bias=False)
112 | self.pointwise_conv = Conv2dSamePadding(
113 | in_channels, out_channels, kernel_size=1, stride=1)
114 |
115 | self.apply_norm = apply_norm
116 | if self.apply_norm:
117 | self.bn = build_norm_layer(norm_cfg, num_features=out_channels)[1]
118 |
119 | self.apply_activation = conv_bn_act_pattern
120 | if self.apply_activation:
121 | self.swish = Swish()
122 |
123 | def forward(self, x):
124 | x = self.depthwise_conv(x)
125 | x = self.pointwise_conv(x)
126 | if self.apply_norm:
127 | x = self.bn(x)
128 | if self.apply_activation:
129 | x = self.swish(x)
130 |
131 | return x
132 |
133 |
134 | class DownChannelBlock(nn.Module):
135 |
136 | def __init__(
137 | self,
138 | in_channels: int,
139 | out_channels: int,
140 | apply_norm: bool = True,
141 | conv_bn_act_pattern: bool = False,
142 | norm_cfg: OptConfigType = dict(type='BN', momentum=1e-2, eps=1e-3)
143 | ) -> None:
144 | super(DownChannelBlock, self).__init__()
145 | self.down_conv = Conv2dSamePadding(in_channels, out_channels, 1)
146 | self.apply_norm = apply_norm
147 | if self.apply_norm:
148 | self.bn = build_norm_layer(norm_cfg, num_features=out_channels)[1]
149 | self.apply_activation = conv_bn_act_pattern
150 | if self.apply_activation:
151 | self.swish = Swish()
152 |
153 | def forward(self, x):
154 | x = self.down_conv(x)
155 | if self.apply_norm:
156 | x = self.bn(x)
157 | if self.apply_activation:
158 | x = self.swish(x)
159 |
160 | return x
161 |
162 |
163 | # Copied from https://github.com/wenxi-yue/SurgicalSAM/blob/main/surgicalSAM/model.py
164 | class CUDA_ADAPOOL2d(Function):
165 |
166 | @staticmethod
167 | @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
168 | def forward(ctx, input, beta, kernel=2, stride=None, return_mask=False):
169 |
170 | assert input.dtype == beta.dtype, '`input` and `beta` are not of the same dtype.'
171 | beta = torch.clamp(beta, 0., 1.)
172 | no_batch = False
173 | if len(input.size()) == 3:
174 | no_batch = True
175 | input.unsqueeze_(0)
176 | B, C, H, W = input.shape
177 | kernel = _pair(kernel)
178 | if stride is None:
179 | stride = kernel
180 | else:
181 | stride = _pair(stride)
182 |
183 | oH = (H - kernel[0]) // stride[0] + 1
184 | oW = (W - kernel[1]) // stride[1] + 1
185 |
186 | output = input.new_zeros((B, C, oH, oW))
187 | if return_mask:
188 | mask = input.new_zeros((B, kernel[0] * oH, kernel[1] * oW))
189 | else:
190 | mask = input.new_zeros((1))
191 |
192 | adapool_cuda.forward_2d(input.contiguous(), beta, kernel, stride, output, return_mask, mask)
193 | ctx.save_for_backward(input, beta)
194 | ctx.kernel = kernel
195 | ctx.stride = stride
196 | if return_mask:
197 | mask_ = mask.detach().clone()
198 | mask_.requires_grad = False
199 | CUDA_ADAPOOL2d.mask = mask_
200 | output = torch.nan_to_num(output)
201 | if no_batch:
202 | return output.squeeze_(0)
203 | return output
204 |
205 | @staticmethod
206 | @torch.cuda.amp.custom_bwd
207 | def backward(ctx, grad_output):
208 |
209 | grad_input = torch.zeros_like(ctx.saved_tensors[0])
210 | grad_beta = torch.zeros_like(ctx.saved_tensors[1])
211 |
212 | saved = [grad_output] + list(ctx.saved_tensors) + [ctx.kernel, ctx.stride, grad_input, grad_beta]
213 | adapool_cuda.backward_2d(*saved)
214 |
215 | return torch.nan_to_num(saved[-2]), torch.nan_to_num(saved[-1]), None, None, None
216 |
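# Illustrative call (a sketch; it assumes the compiled `adapool_cuda` extension is available,
# CUDA float tensors, and that `beta` carries one value per pooling window -- treat the
# shapes as assumptions rather than a documented API):
#   x = torch.randn(2, 64, 32, 32, device='cuda')
#   beta = torch.full((16, 16), 0.5, device='cuda')
#   y = CUDA_ADAPOOL2d.apply(x, beta, 2, 2, False)  # -> (2, 64, 16, 16)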
217 |
218 | class CUDA_ADAPOOL2d_EDSCW(Function):
219 |
220 | @staticmethod
221 | @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
222 | def forward(ctx, input, kernel=2, stride=None, return_mask=False):
223 | no_batch = False
224 | if len(input.size()) == 3:
225 | no_batch = True
226 | input.unsqueeze_(0)
227 | B, C, H, W = input.shape
228 | kernel = _pair(kernel)
229 | if stride is None:
230 | stride = kernel
231 | else:
232 | stride = _pair(stride)
233 |
234 | oH = (H - kernel[0]) // stride[0] + 1
235 | oW = (W - kernel[1]) // stride[1] + 1
236 |
237 | output = input.new_zeros((B, C, oH, oW))
238 | if return_mask:
239 | mask = input.new_zeros((B, kernel[0] * oH, kernel[1] * oW))
240 | else:
241 | mask = input.new_zeros((1))
242 |
243 | adapool_cuda.forward_2d_edscw(input.contiguous(), kernel, stride, output, return_mask, mask)
244 | ctx.save_for_backward(input)
245 | ctx.kernel = kernel
246 | ctx.stride = stride
247 | if return_mask:
248 | mask_ = mask.detach().clone()
249 | mask_.requires_grad = False
250 | CUDA_ADAPOOL2d_EDSCW.mask = mask_
251 | output = torch.nan_to_num(output)
252 | if no_batch:
253 | return output.squeeze_(0)
254 | return output
255 |
256 | @staticmethod
257 | @torch.cuda.amp.custom_bwd
258 | def backward(ctx, grad_output):
259 |
260 | grad_input = torch.zeros_like(ctx.saved_tensors[0])
261 |
262 | saved = [grad_output] + list(ctx.saved_tensors) + [ctx.kernel, ctx.stride, grad_input]
263 | adapool_cuda.backward_2d_edscw(*saved)
264 |
265 | return torch.nan_to_num(saved[-1]), None, None, None
266 |
267 |
268 | class CUDA_IDWPOOL2d(Function):
269 |
270 | @staticmethod
271 | @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
272 | def forward(ctx, input, kernel=2, stride=None, return_mask=False):
273 | no_batch = False
274 | if len(input.size()) == 3:
275 | no_batch = True
276 | input.unsqueeze_(0)
277 | B, C, H, W = input.shape
278 | kernel = _pair(kernel)
279 | if stride is None:
280 | stride = kernel
281 | else:
282 | stride = _pair(stride)
283 |
284 | oH = (H - kernel[0]) // stride[0] + 1
285 | oW = (W - kernel[1]) // stride[1] + 1
286 |
287 | output = input.new_zeros((B, C, oH, oW))
288 | if return_mask:
289 | mask = input.new_zeros((B, kernel[0] * oH, kernel[1] * oW))
290 | else:
291 | mask = input.new_zeros((1))
292 |
293 | adapool_cuda.forward_2d_idw(input.contiguous(), kernel, stride, output, return_mask, mask)
294 | ctx.save_for_backward(input)
295 | ctx.kernel = kernel
296 | ctx.stride = stride
297 | if return_mask:
298 | mask_ = mask.detach().clone()
299 | mask_.requires_grad = False
300 | CUDA_IDWPOOL2d.mask = mask_
301 | output = torch.nan_to_num(output)
302 | if no_batch:
303 | return output.squeeze_(0)
304 | return output
305 |
306 | @staticmethod
307 | @torch.cuda.amp.custom_bwd
308 | def backward(ctx, grad_output):
309 |
310 | grad_input = torch.zeros_like(ctx.saved_tensors[0])
311 |
312 | saved = [grad_output] + list(ctx.saved_tensors) + [ctx.kernel, ctx.stride, grad_input]
313 | adapool_cuda.backward_2d_idw(*saved)
314 |
315 | return torch.nan_to_num(saved[-1]), None, None, None
316 |
317 |
318 | class CUDA_ADAPOOL2d_EM(Function):
319 |
320 | @staticmethod
321 | @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
322 | def forward(ctx, input, kernel=2, stride=None, return_mask=False):
323 |
324 | no_batch = False
325 | if len(input.size()) == 3:
326 | no_batch = True
327 | input.unsqueeze_(0)
328 | B, C, H, W = input.shape
329 | kernel = _pair(kernel)
330 | if stride is None:
331 | stride = kernel
332 | else:
333 | stride = _pair(stride)
334 |
335 | oH = (H - kernel[0]) // stride[0] + 1
336 | oW = (W - kernel[1]) // stride[1] + 1
337 |
338 | output = input.new_zeros((B, C, oH, oW))
339 | if return_mask:
340 | mask = input.new_zeros((B, kernel[0] * oH, kernel[1] * oW))
341 | else:
342 | mask = input.new_zeros((1))
343 |
344 | adapool_cuda.forward_2d_em(input.contiguous(), kernel, stride, output, return_mask, mask)
345 | ctx.save_for_backward(input)
346 | ctx.kernel = kernel
347 | ctx.stride = stride
348 | if return_mask:
349 | mask_ = mask.detach().clone()
350 | mask_.requires_grad = False
351 | CUDA_ADAPOOL2d_EM.mask = mask_
352 | output = torch.nan_to_num(output)
353 | if no_batch:
354 | return output.squeeze_(0)
355 | return output
356 |
357 | @staticmethod
358 | @torch.cuda.amp.custom_bwd
359 | def backward(ctx, grad_output):
360 |
361 | grad_input = torch.zeros_like(ctx.saved_tensors[0])
362 |
363 | saved = [grad_output] + list(ctx.saved_tensors) + [ctx.kernel, ctx.stride, grad_input]
364 | adapool_cuda.backward_2d_em(*saved)
365 |
366 | return torch.nan_to_num(saved[-1]), None, None, None
367 |
--------------------------------------------------------------------------------
/models/sam2/sam/mask_decoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import List, Optional, Tuple, Type
8 |
9 | import torch
10 | from torch import nn
11 |
12 | from models.sam2.sam2_utils import LayerNorm2d, MLP
13 |
14 |
15 | class MaskDecoder(nn.Module):
16 | def __init__(
17 | self,
18 | *,
19 | transformer_dim: int,
20 | transformer: nn.Module,
21 | num_multimask_outputs: int = 3,
22 | activation: Type[nn.Module] = nn.GELU,
23 | iou_head_depth: int = 3,
24 | iou_head_hidden_dim: int = 256,
25 | use_high_res_features: bool = False,
26 | iou_prediction_use_sigmoid=False,
27 | dynamic_multimask_via_stability=False,
28 | dynamic_multimask_stability_delta=0.05,
29 | dynamic_multimask_stability_thresh=0.98,
30 | pred_obj_scores: bool = False,
31 | pred_obj_scores_mlp: bool = False,
32 | use_multimask_token_for_obj_ptr: bool = False,
33 | ) -> None:
34 | """
35 | Predicts masks given an image and prompt embeddings, using a
36 | transformer architecture.
37 |
38 | Arguments:
39 | transformer_dim (int): the channel dimension of the transformer
40 | transformer (nn.Module): the transformer used to predict masks
41 | num_multimask_outputs (int): the number of masks to predict
42 | when disambiguating masks
43 | activation (nn.Module): the type of activation to use when
44 | upscaling masks
45 | iou_head_depth (int): the depth of the MLP used to predict
46 | mask quality
47 | iou_head_hidden_dim (int): the hidden dimension of the MLP
48 | used to predict mask quality
49 | """
50 | super().__init__()
51 | self.transformer_dim = transformer_dim
52 | self.transformer = transformer
53 |
54 | self.num_multimask_outputs = num_multimask_outputs
55 |
56 | self.iou_token = nn.Embedding(1, transformer_dim)
57 | self.num_mask_tokens = num_multimask_outputs + 1
58 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
59 |
60 | self.pred_obj_scores = pred_obj_scores
61 | if self.pred_obj_scores:
62 | self.obj_score_token = nn.Embedding(1, transformer_dim)
63 | self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
64 |
65 | self.output_upscaling = nn.Sequential(
66 | nn.ConvTranspose2d(
67 | transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
68 | ),
69 | LayerNorm2d(transformer_dim // 4),
70 | activation(),
71 | nn.ConvTranspose2d(
72 | transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
73 | ),
74 | activation(),
75 | )
76 | self.use_high_res_features = use_high_res_features
77 | if use_high_res_features:
78 | self.conv_s0 = nn.Conv2d(
79 | transformer_dim, transformer_dim // 8, kernel_size=1, stride=1
80 | )
81 | self.conv_s1 = nn.Conv2d(
82 | transformer_dim, transformer_dim // 4, kernel_size=1, stride=1
83 | )
84 |
85 | self.output_hypernetworks_mlps = nn.ModuleList(
86 | [
87 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
88 | for i in range(self.num_mask_tokens)
89 | ]
90 | )
91 |
92 | self.iou_prediction_head = MLP(
93 | transformer_dim,
94 | iou_head_hidden_dim,
95 | self.num_mask_tokens,
96 | iou_head_depth,
97 | sigmoid_output=iou_prediction_use_sigmoid,
98 | )
99 | if self.pred_obj_scores:
100 | self.pred_obj_score_head = nn.Linear(transformer_dim, 1)
101 | if pred_obj_scores_mlp:
102 | self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)
103 |
104 | # When outputting a single mask, optionally we can dynamically fall back to the best
105 | # multimask output token if the single mask output token gives low stability scores.
106 | self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
107 | self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
108 | self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh
109 |
110 | def forward(
111 | self,
112 | image_embeddings: torch.Tensor,
113 | image_pe: torch.Tensor,
114 | sparse_prompt_embeddings: torch.Tensor,
115 | dense_prompt_embeddings: torch.Tensor,
116 | multimask_output: bool,
117 | repeat_image: bool,
118 | high_res_features: Optional[List[torch.Tensor]] = None,
119 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
120 | """
121 | Predict masks given image and prompt embeddings.
122 |
123 | Arguments:
124 | image_embeddings (torch.Tensor): the embeddings from the image encoder
125 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
126 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
127 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
128 | multimask_output (bool): Whether to return multiple masks or a single
129 | mask.
130 |
131 | Returns:
132 | torch.Tensor: batched predicted masks
133 | torch.Tensor: batched predictions of mask quality
134 | torch.Tensor: batched SAM token for mask output; torch.Tensor: batched object score logits
135 | """
136 | masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
137 | image_embeddings=image_embeddings,
138 | image_pe=image_pe,
139 | sparse_prompt_embeddings=sparse_prompt_embeddings,
140 | dense_prompt_embeddings=dense_prompt_embeddings,
141 | repeat_image=repeat_image,
142 | high_res_features=high_res_features,
143 | )
144 |
145 | # Select the correct mask or masks for output
146 | if multimask_output:
147 | masks = masks[:, 1:, :, :]
148 | iou_pred = iou_pred[:, 1:]
149 | elif self.dynamic_multimask_via_stability and not self.training:
150 | masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
151 | else:
152 | masks = masks[:, 0:1, :, :]
153 | iou_pred = iou_pred[:, 0:1]
154 |
155 | if multimask_output and self.use_multimask_token_for_obj_ptr:
156 | sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape
157 | else:
158 | # Take the mask output token. Here we *always* use the token for single mask output.
159 | # At test time, even if we track after 1-click (and using multimask_output=True),
160 | # we still take the single mask token here. The rationale is that we always track
161 | # after multiple clicks during training, so the past tokens seen during training
162 | # are always the single mask token (and we'll let it be the object-memory token).
163 | sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape
164 |
165 | # Prepare output
166 | return masks, iou_pred, sam_tokens_out, object_score_logits
167 |
168 | def predict_masks(
169 | self,
170 | image_embeddings: torch.Tensor,
171 | image_pe: torch.Tensor,
172 | sparse_prompt_embeddings: torch.Tensor,
173 | dense_prompt_embeddings: torch.Tensor,
174 | repeat_image: bool,
175 | high_res_features: Optional[List[torch.Tensor]] = None,
176 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
177 | """Predicts masks. See 'forward' for more details."""
178 | # Concatenate output tokens
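# `s` is the index of the IoU token in the output sequence: it becomes 1 below when an
# object-score token is prepended in front of it.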
179 | s = 0
180 | if self.pred_obj_scores:
181 | output_tokens = torch.cat(
182 | [
183 | self.obj_score_token.weight,
184 | self.iou_token.weight,
185 | self.mask_tokens.weight,
186 | ],
187 | dim=0,
188 | )
189 | s = 1
190 | else:
191 | output_tokens = torch.cat(
192 | [self.iou_token.weight, self.mask_tokens.weight], dim=0
193 | )
194 | output_tokens = output_tokens.unsqueeze(0).expand(
195 | sparse_prompt_embeddings.size(0), -1, -1
196 | )
197 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
198 |
199 | # Expand per-image data in batch direction to be per-mask
200 | if repeat_image:
201 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
202 | else:
203 | assert image_embeddings.shape[0] == tokens.shape[0]
204 | src = image_embeddings
205 | src = src + dense_prompt_embeddings
206 | assert (
207 | image_pe.size(0) == 1
208 | ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
209 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
210 | b, c, h, w = src.shape
211 |
212 | # Run the transformer
213 | hs, src = self.transformer(src, pos_src, tokens)
214 | iou_token_out = hs[:, s, :]
215 | mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]
216 |
217 | # Upscale mask embeddings and predict masks using the mask tokens
218 | src = src.transpose(1, 2).view(b, c, h, w)
219 | if not self.use_high_res_features:
220 | upscaled_embedding = self.output_upscaling(src)
221 | else:
222 | dc1, ln1, act1, dc2, act2 = self.output_upscaling
223 | feat_s0, feat_s1 = high_res_features
224 | upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
225 | upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)
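# Either branch above upscales the embedding by 4x (two stride-2 transposed convs), so the
# masks computed below have spatial size (4 * h, 4 * w) relative to the image embedding.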
226 |
227 | hyper_in_list: List[torch.Tensor] = []
228 | for i in range(self.num_mask_tokens):
229 | hyper_in_list.append(
230 | self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
231 | )
232 | hyper_in = torch.stack(hyper_in_list, dim=1)
233 | b, c, h, w = upscaled_embedding.shape
234 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
235 |
236 | # Generate mask quality predictions
237 | iou_pred = self.iou_prediction_head(iou_token_out)
238 | if self.pred_obj_scores:
239 | assert s == 1
240 | object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
241 | else:
242 | # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1
243 | object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)
244 |
245 | return masks, iou_pred, mask_tokens_out, object_score_logits
246 |
247 | def _get_stability_scores(self, mask_logits):
248 | """
249 | Compute stability scores of the mask logits based on the IoU between upper and
250 | lower thresholds.
251 | """
252 | mask_logits = mask_logits.flatten(-2)
253 | stability_delta = self.dynamic_multimask_stability_delta
254 | area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
255 | area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
256 | stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
257 | return stability_scores
258 |
259 | def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
260 | """
261 | When outputting a single mask, if the stability score from the current single-mask
262 | output (based on output token 0) falls below a threshold, we instead select from
263 | multi-mask outputs (based on output token 1~3) the mask with the highest predicted
264 | IoU score. This is intended to ensure a valid mask for both clicking and tracking.
265 | """
266 | # The best mask from multimask output tokens (1~3)
267 | multimask_logits = all_mask_logits[:, 1:, :, :]
268 | multimask_iou_scores = all_iou_scores[:, 1:]
269 | best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
270 | batch_inds = torch.arange(
271 | multimask_iou_scores.size(0), device=all_iou_scores.device
272 | )
273 | best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
274 | best_multimask_logits = best_multimask_logits.unsqueeze(1)
275 | best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
276 | best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1)
277 |
278 | # The mask from singlemask output token 0 and its stability score
279 | singlemask_logits = all_mask_logits[:, 0:1, :, :]
280 | singlemask_iou_scores = all_iou_scores[:, 0:1]
281 | stability_scores = self._get_stability_scores(singlemask_logits)
282 | is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
283 |
284 | # Dynamically fall back to best multimask output upon low stability scores.
285 | mask_logits_out = torch.where(
286 | is_stable[..., None, None].expand_as(singlemask_logits),
287 | singlemask_logits,
288 | best_multimask_logits,
289 | )
290 | iou_scores_out = torch.where(
291 | is_stable.expand_as(singlemask_iou_scores),
292 | singlemask_iou_scores,
293 | best_multimask_iou_scores,
294 | )
295 | return mask_logits_out, iou_scores_out
--------------------------------------------------------------------------------