├── README.md
├── SAM
│   ├── LICENSE.txt
│   ├── README.md
│   ├── crop_image.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── dataset_cancer.py
│   │   └── dataset_synapse.py
│   ├── requirements.txt
│   ├── sam_lora_image_encoder.py
│   ├── segment_anything
│   │   ├── __init__.py
│   │   ├── automatic_mask_generator.py
│   │   ├── build_sam.py
│   │   ├── modeling
│   │   │   ├── __init__.py
│   │   │   ├── common.py
│   │   │   ├── image_encoder.py
│   │   │   ├── mask_decoder.py
│   │   │   ├── prompt_encoder.py
│   │   │   ├── sam.py
│   │   │   └── transformer.py
│   │   ├── predictor.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── amg.py
│   │       ├── onnx.py
│   │       └── transforms.py
│   ├── test.py
│   ├── train.py
│   ├── trainer.py
│   └── utils.py
├── cbam.py
├── dataset.py
├── model.py
├── train.py
└── val.py

/README.md:
--------------------------------------------------------------------------------
1 | # SAM-FNet
2 | 
3 | This repository contains the implementation of the following paper:
4 | 
5 | SAM-FNet: SAM-Guided Fusion Network for Laryngo-Pharyngeal Tumor Detection
6 | 
7 | 
8 | 
9 | ## Fine-tune SAM with LoRA
10 | 
11 | To fine-tune SAM using LoRA, we recommend following the guidelines provided in the original repository: [SAMed](https://github.com/hitachinsk/SAMed/tree/main).
12 | 
13 | ### Steps:
14 | 
15 | 1. **Fine-tune SAM:**
16 |    - Follow the instructions in the [SAMed repository](https://github.com/hitachinsk/SAMed/tree/main) to fine-tune SAM with LoRA.
17 | 2. **Generate Local Images:**
18 |    - After fine-tuning, modify the `crop_image.py` script to suit your requirements.
19 |    - Run the script to generate local images:
20 | 
21 | ```bash
22 | python crop_image.py
23 | ```
24 | 
25 | 
26 | 
27 | ## Dataset
28 | 
29 | Organize your datasets in the following manner:
30 | 
31 | ```
32 | datasets/
33 | ├── dataset1/
34 | │   ├── global/
35 | │   │   ├── train/
36 | │   │   │   ├── benign/
37 | │   │   │   ├── normal/
38 | │   │   │   └── tumor/
39 | │   │   ├── val/
40 | │   │   │   ├── benign/
41 | │   │   │   ├── normal/
42 | │   │   │   └── tumor/
43 | │   │   └── test/
44 | │   │       ├── benign/
45 | │   │       ├── normal/
46 | │   │       └── tumor/
47 | │   └── local_seg/
48 | │       ├── train/
49 | │       │   ├── benign/
50 | │       │   ├── normal/
51 | │       │   └── tumor/
52 | │       ├── val/
53 | │       │   ├── benign/
54 | │       │   ├── normal/
55 | │       │   └── tumor/
56 | │       └── test/
57 | │           ├── benign/
58 | │           ├── normal/
59 | │           └── tumor/
60 | ├── dataset6/
61 | │   └── ...
62 | ```
63 | 
64 | (A minimal, hypothetical loading sketch for this layout is given at the end of this README.)
65 | 
66 | ## Training
67 | 
68 | 1. Modify the `class_labels` variable in the `dataset.py` file to reflect the classes in your dataset.
69 | 2. Run this command to train SAM-FNet:
70 | 
71 | ```bash
72 | python train.py --data_dir <DATA_DIR> --save_path <SAVE_PATH> --num_classes <NUM_CLASSES> --pretrained True --encoder ResNet50
73 | ```
74 | 
75 | - Replace `<DATA_DIR>` with the path to your dataset.
76 | - Replace `<SAVE_PATH>` with the directory where you want to save the model checkpoints.
77 | - Replace `<NUM_CLASSES>` with the number of classes in your classification task.
78 | 
79 | 
80 | 
81 | ## Testing
82 | 
83 | 1. Update the `classes` variable in `val.py` to match your dataset.
84 | 2. Run this command to test:
85 | 
86 | ```bash
87 | python val.py --model_path <MODEL_PATH> --encoder ResNet50 --dataset <DATA_DIR> --save_path <SAVE_PATH>
88 | ```
89 | 
90 | 
91 | 
92 | ## Acknowledgement
93 | 
94 | The code of SAM-FNet is built upon [SAMed](https://github.com/hitachinsk/SAMed/tree/main) and [DLGNet](https://github.com/soleilssss/DLGNet), and we express our gratitude to these awesome projects.
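The root-level `dataset.py`, `model.py`, and `train.py` are not included in this excerpt, so the snippet below is only a minimal, hypothetical sketch of how the `global`/`local_seg` layout above could be read as paired (global image, local image, label) samples. The `PairedGlobalLocalDataset` name, the `class_labels` order, and the pairing-by-filename convention are illustrative assumptions rather than the repository's actual implementation; the one grounded detail is that `crop_image.py` saves each local crop under the same file name as its global source.

```python
import os
from PIL import Image
from torch.utils.data import Dataset

# Assumed label order; keep this in sync with `class_labels` in dataset.py.
class_labels = ['benign', 'normal', 'tumor']

class PairedGlobalLocalDataset(Dataset):
    """Hypothetical loader pairing each global image with its SAM-cropped local counterpart."""

    def __init__(self, root, phase='train', transform=None):
        self.items = []
        for label, cls in enumerate(class_labels):
            global_dir = os.path.join(root, 'global', phase, cls)
            local_dir = os.path.join(root, 'local_seg', phase, cls)
            for name in sorted(os.listdir(global_dir)):
                # crop_image.py writes the local crop under the same file name as the global image.
                self.items.append((os.path.join(global_dir, name),
                                   os.path.join(local_dir, name), label))
        self.transform = transform

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        g_path, l_path, label = self.items[idx]
        g_img = Image.open(g_path).convert('RGB')
        l_img = Image.open(l_path).convert('RGB')
        if self.transform is not None:
            g_img = self.transform(g_img)
            l_img = self.transform(l_img)
        return g_img, l_img, label
```

Whatever loader you use, the class folder names listed in `class_labels` must match the directory layout shown in the Dataset section.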
95 | 
--------------------------------------------------------------------------------
/SAM/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 Kaidong Zhang
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/SAM/README.md:
--------------------------------------------------------------------------------
1 | # SAMed_h
2 | 
3 | ## Prerequisites
4 | - Linux (We tested our codes on Ubuntu 18.04)
5 | - Anaconda
6 | - Python 3.10.11
7 | - PyTorch 2.0.0 **(PyTorch 2+ is necessary)**
8 | 
9 | To get started, first please clone the repo
10 | ```
11 | git clone https://github.com/hitachinsk/SAMed.git
12 | cd SAMed_h
13 | ```
14 | Then, please run the following commands:
15 | ```
16 | conda create -n SAMed_h python=3.10.11
17 | conda activate SAMed_h
18 | pip install -r requirements.txt
19 | ```
20 | 
21 | ## Quick start
22 | All the steps are the same as [SAMed](https://github.com/hitachinsk/SAMed). But you need to prepare the [vit_h version of SAM](https://github.com/facebookresearch/segment-anything#model-checkpoints) and [our pretrained checkpoint](https://drive.google.com/file/d/1Kx_vx9bcxJaiMYWAgljNtwtHcooUsq8m/view?usp=sharing).
23 | 
24 | ## Training
25 | We adopt one A100 (80G) for training.
26 | 1. Please download the processed [training set](https://drive.google.com/file/d/1zuOQRyfo0QYgjcU_uZs0X3LdCnAC2m3G/view?usp=share_link), whose resolution is `224x224`, and put it in `<DATA_DIR>`. Then, unzip and delete this file. We also provide the [training set](https://drive.google.com/file/d/1F42WMa80UpH98Pw95oAzYDmxAAO2ApYg/view?usp=share_link) with resolution `512x512` for reference; the `224x224` training set is downsampled from the `512x512` version.
27 | 2. Run this command to train SAMed.
28 | ```bash
29 | python train.py --root_path <DATA_DIR> --output <OUTPUT_DIR> --warmup --AdamW --tf32 --compile --use_amp --lr_exp 7 --max_epochs 400 --stop_epoch 300
30 | ```
31 | Check the results in `<OUTPUT_DIR>`; the training process will consume about 70G of GPU memory.
32 | 
33 | ## Difference between SAMed_h and SAMed
34 | - SAMed_h adopts the `vit_h` version of SAM as the base model.
35 | - SAMed_h needs more training iterations. Therefore, we set the max epoch to 400 and the early stop to 300 for better performance.
36 | - Too large a learning rate causes training instability for SAMed_h.
  Therefore, we increase the exponent of the exponential decay from 0.9 to 7, which greatly reduces the training instability.
37 | - For faster training speed and less memory consumption, SAMed_h adopts auto mixed-precision, tensor-float 32, and the `compile` technology in PyTorch 2.0. Therefore, PyTorch 2+ is necessary for training this model.
38 | 
--------------------------------------------------------------------------------
/SAM/crop_image.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import gzip
3 | import os
4 | import pickle
5 | import sys
6 | 
7 | from PIL import Image
8 | from tqdm import tqdm
9 | import logging
10 | import random
11 | import numpy as np
12 | import torch
13 | import torch.backends.cudnn as cudnn
14 | 
15 | from scipy.ndimage.interpolation import zoom
16 | 
17 | from pathlib import Path
18 | import cv2
19 | from scipy.ndimage import label
20 | 
21 | from segment_anything import sam_model_registry
22 | from sam_lora_image_encoder import LoRA_Sam
23 | 
24 | def generate_cropped_image():
25 |     def do_crop(img, prd, crop_size):
26 |         h, w = img.shape[:2]
27 |         masked_img = img.copy()
28 |         if np.max(prd) == 0:
29 |             # compute the center position
30 |             center_row = h // 2
31 |             center_col = w // 2
32 |             # compute the start and end positions of the crop
33 |             min_row = max(0, center_row - crop_size[0] // 2)
34 |             min_col = max(0, center_col - crop_size[1] // 2)
35 |             max_row = min(h, center_row + crop_size[0] // 2)
36 |             max_col = min(w, center_col + crop_size[1] // 2)
37 | 
38 |         else:
39 |             masked_img[prd != 255] = 0
40 | 
41 |             rows, cols = np.where(prd == 255)
42 |             min_row, max_row, min_col, max_col = min(rows), max(rows), min(cols), max(cols)
43 |             rect_width = max_col - min_col + 1
44 |             rect_height = max_row - min_row + 1
45 | 
46 |             if rect_width < crop_size[0] or rect_height < crop_size[1]:
47 |                 # compute the boundaries of the crop region
48 |                 crop_min_row = max(0, min_row - max(0, (crop_size[0] - rect_height) // 2))
49 |                 crop_max_row = min(prd.shape[0], crop_min_row + max(crop_size[0], rect_height))
50 | 
51 |                 crop_min_col = max(0, min_col - max(0, (crop_size[1] - rect_width) // 2))
52 |                 crop_max_col = min(prd.shape[1], crop_min_col + max(crop_size[1], rect_width))
53 |                 min_row, max_row, min_col, max_col = crop_min_row, crop_max_row, crop_min_col, crop_max_col
54 | 
55 |         # Crop the corresponding region from the original image
56 |         cropped_img = Image.fromarray(masked_img[min_row:max_row, min_col:max_col])
57 | 
58 |         return cropped_img
59 | 
60 |     os.environ["CUDA_VISIBLE_DEVICES"] = "1"
61 |     root = r"../datasets/dataset1/global"
62 |     classes = ['benign', 'tumor', 'normal']
63 |     phases = ['test']
64 |     source = root
65 |     target = root.replace("global", "local_seg")
66 |     input_size = 224
67 |     crop_size = (256, 256)
68 |     cudnn.benchmark = False
69 |     cudnn.deterministic = True
70 |     seed = 1234
71 |     random.seed(seed)
72 |     np.random.seed(seed)
73 |     torch.manual_seed(seed)
74 |     torch.cuda.manual_seed(seed)
75 |     rank = 4
76 |     lora_ckpt = r"./exp4/epoch_199.pth"
77 |     ckpt = r"./checkpoints/sam_vit_b_01ec64.pth"
78 |     sam, img_embedding_size = sam_model_registry['vit_b'](image_size=input_size,
79 |                                                           num_classes=1,
80 |                                                           checkpoint=ckpt,
81 |                                                           pixel_mean=[0, 0, 0],
82 |                                                           pixel_std=[1, 1, 1])
83 | 
84 |     net = LoRA_Sam(sam, rank).cuda()
85 |     net.load_lora_parameters(lora_ckpt)
86 | 
87 |     net.eval()
88 |     for phase in phases:
89 |         for cls in classes:
90 |             imgs = os.listdir(os.path.join(source, phase, cls))
91 |             for img in tqdm(imgs):
92 |                 torch.cuda.empty_cache()
93 |                 img_path = os.path.join(source, phase, cls, img)
94 |                 image = cv2.imread(img_path)
95 
| image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 96 | origin_image = copy.deepcopy(image) 97 | x, y = image.shape[0:2] 98 | if x != input_size or y != input_size: 99 | image = zoom(image, (input_size / x, input_size / y, 1.0), order=3) 100 | inputs = torch.from_numpy(image.astype(np.float32) / 255.0) 101 | inputs = inputs.permute(2, 0, 1) 102 | inputs = inputs.unsqueeze(0).cuda() 103 | with torch.no_grad(): 104 | outputs = net(inputs, False, input_size) 105 | output_masks = outputs['masks'] 106 | out = torch.argmax(torch.softmax(output_masks, dim=1), dim=1).squeeze(0) 107 | prediction = out.cpu().detach().numpy() 108 | if x != input_size or y != input_size: 109 | prediction = zoom(prediction, (x / input_size, y / input_size), order=0) 110 | cropped_image = do_crop(img=origin_image.astype(np.uint8), 111 | prd=(prediction * 255).astype(np.uint8), 112 | crop_size=crop_size) 113 | output_path = os.path.join(target, phase, cls, img) 114 | if not os.path.exists(os.path.join(target, phase, cls)): 115 | os.makedirs(os.path.join(target, phase, cls)) 116 | cropped_image.save(output_path) 117 | 118 | if __name__ == "__main__": 119 | generate_cropped_image() 120 | -------------------------------------------------------------------------------- /SAM/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VVJia/SAM-FNet/a30c11d471cb7f03b40eb406fe36fe8dedc89c6f/SAM/datasets/__init__.py -------------------------------------------------------------------------------- /SAM/datasets/dataset_cancer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import torch 5 | import torchvision.datasets 6 | from scipy import ndimage 7 | from scipy.ndimage.interpolation import zoom 8 | from torch.utils.data import Dataset 9 | from torchvision import transforms 10 | from pycocotools import mask as coco_mask 11 | import os 12 | 13 | mean = [0.485, 0.456, 0.406] 14 | std = [0.229, 0.224, 0.225] 15 | 16 | def random_rot_flip(image, label): 17 | k = np.random.randint(0, 4) 18 | image = np.rot90(image, k) 19 | label = np.rot90(label, k) 20 | axis = np.random.randint(0, 2) 21 | image = np.flip(image, axis=axis).copy() 22 | label = np.flip(label, axis=axis).copy() 23 | return image, label 24 | 25 | 26 | def random_rotate(image, label): 27 | angle = np.random.randint(-20, 20) 28 | image = ndimage.rotate(image, angle, order=0, reshape=False) 29 | label = ndimage.rotate(label, angle, order=0, reshape=False) 30 | return image, label 31 | 32 | 33 | class RandomGenerator(object): 34 | def __init__(self, output_size, low_res, phase): 35 | self.output_size = output_size 36 | self.low_res = low_res 37 | self.phase = phase 38 | 39 | def __call__(self, sample): 40 | image, label = sample['image'], sample['label'] 41 | 42 | if self.phase == "train": 43 | if random.random() > 0.5: 44 | image, label = random_rot_flip(image, label) 45 | elif random.random() > 0.5: 46 | image, label = random_rotate(image, label) 47 | x, y = image.shape[0:2] 48 | if x != self.output_size[0] or y != self.output_size[1]: 49 | image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y, 1.0), order=3) # why not 3? 
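            # The image is resized with cubic interpolation (order=3) to keep intensities smooth;
            # the label below uses nearest-neighbour (order=0) so integer class ids are never blended.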
50 | label = zoom(label, (self.output_size[0] / x, self.output_size[1] / y), order=0) 51 | label_h, label_w = label.shape 52 | low_res_label = zoom(label, (self.low_res[0] / label_h, self.low_res[1] / label_w), order=0) 53 | image = torch.from_numpy(image.astype(np.float32) / 255.0) 54 | # image = (image - torch.FloatTensor(mean)) / torch.FloatTensor(std) 55 | image = image.permute(2, 0, 1) 56 | label = torch.from_numpy(label.astype(np.float32)) 57 | low_res_label = torch.from_numpy(low_res_label.astype(np.float32)) 58 | sample = {'image': image, 'label': label.long(), 'low_res_label': low_res_label.long(), 'case_name': sample['case_name']} 59 | return sample 60 | 61 | def convert_coco_poly_to_mask(segmentations, height, width): 62 | masks = [] 63 | for polygons in segmentations: 64 | rles = coco_mask.frPyObjects(polygons, height, width) 65 | mask = coco_mask.decode(rles) 66 | if len(mask.shape) < 3: 67 | mask = mask[..., None] 68 | # mask = torch.as_tensor(mask, dtype=torch.uint8) 69 | mask = np.any(mask, axis=2) 70 | masks.append(mask) 71 | 72 | merged_mask = np.zeros((height, width), dtype=np.uint8) 73 | if masks: 74 | for mask in masks: 75 | merged_mask = merged_mask | mask 76 | 77 | return merged_mask 78 | 79 | class COCO_dataset(torchvision.datasets.CocoDetection): 80 | def __init__(self, img_folder, ann_file, split=None, transform=None): 81 | super(COCO_dataset, self).__init__(img_folder, ann_file) 82 | self.split = split 83 | self.transform = transform 84 | 85 | def __len__(self): 86 | return super(COCO_dataset, self).__len__() 87 | 88 | def __getitem__(self, idx): 89 | img, target = super(COCO_dataset, self).__getitem__(idx) 90 | 91 | # get filename 92 | image_info = self.coco.loadImgs(self.ids[idx])[0] 93 | filename = image_info['file_name'] 94 | 95 | # generate masks 96 | w, h = img.size 97 | segmentations = [obj['segmentation'] for obj in target] 98 | masks = convert_coco_poly_to_mask(segmentations, h, w) 99 | 100 | label_value = target[0]['category_id'] + 1 101 | masks[masks == 1] = label_value 102 | 103 | img = np.array(img) 104 | 105 | sample = {'image': img, 'label': masks} 106 | 107 | if self.transform: 108 | sample = self.transform(sample) 109 | 110 | sample['case_name'] = os.path.splitext(filename)[0] 111 | 112 | return sample 113 | 114 | class Cancer_dataset(Dataset): 115 | def __init__(self, data_dir, txt_dir, transform=None): 116 | # train or val or test 117 | phase = os.path.splitext(os.path.basename(txt_dir))[0] 118 | file_path = os.path.join(data_dir, phase) 119 | 120 | self.data = [os.path.join(file_path, file) for file in os.listdir(file_path)] 121 | self.transform = transform # using transform in torch! 
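        # Note: __len__ is driven by the txt file read below (self.sample_list), while __getitem__
        # indexes self.data built from os.listdir above, so both are expected to list the same cases.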
122 | self.sample_list = open(txt_dir).readlines() 123 | 124 | def __len__(self): 125 | return len(self.sample_list) 126 | 127 | def __getitem__(self, idx): 128 | data_path = self.data[idx] 129 | data_dic = np.load(data_path) 130 | image, label = data_dic['image'], data_dic['label'] 131 | name = os.path.splitext(os.path.basename(data_path))[0] 132 | sample = {'image': image, 'label': label, 'case_name': name} 133 | if self.transform: 134 | sample = self.transform(sample) 135 | 136 | return sample 137 | 138 | -------------------------------------------------------------------------------- /SAM/datasets/dataset_synapse.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import h5py 4 | import numpy as np 5 | import torch 6 | import torchvision.datasets 7 | from scipy import ndimage 8 | from scipy.ndimage.interpolation import zoom 9 | from torch.utils.data import Dataset 10 | from einops import repeat 11 | from icecream import ic 12 | from torchvision import transforms 13 | from pycocotools import mask as coco_mask 14 | import os 15 | 16 | mean = [0.485, 0.456, 0.406] 17 | std = [0.229, 0.224, 0.225] 18 | 19 | def random_rot_flip(image, label): 20 | k = np.random.randint(0, 4) 21 | image = np.rot90(image, k) 22 | label = np.rot90(label, k) 23 | axis = np.random.randint(0, 2) 24 | image = np.flip(image, axis=axis).copy() 25 | label = np.flip(label, axis=axis).copy() 26 | return image, label 27 | 28 | 29 | def random_rotate(image, label): 30 | angle = np.random.randint(-20, 20) 31 | image = ndimage.rotate(image, angle, order=0, reshape=False) 32 | label = ndimage.rotate(label, angle, order=0, reshape=False) 33 | return image, label 34 | 35 | 36 | class RandomGenerator(object): 37 | def __init__(self, output_size, low_res): 38 | self.output_size = output_size 39 | self.low_res = low_res 40 | 41 | def __call__(self, sample): 42 | image, label = sample['image'], sample['label'] 43 | 44 | if random.random() > 0.5: 45 | image, label = random_rot_flip(image, label) 46 | elif random.random() > 0.5: 47 | image, label = random_rotate(image, label) 48 | x, y = image.shape[0:2] 49 | if x != self.output_size[0] or y != self.output_size[1]: 50 | image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y, 1.0), order=3) # why not 3? 
51 | label = zoom(label, (self.output_size[0] / x, self.output_size[1] / y), order=0) 52 | label_h, label_w = label.shape 53 | low_res_label = zoom(label, (self.low_res[0] / label_h, self.low_res[1] / label_w), order=0) 54 | image = torch.from_numpy(image.astype(np.float32)) 55 | image = (image - torch.FloatTensor(mean)) / torch.FloatTensor(std) 56 | image = image.permute(2, 0, 1) 57 | label = torch.from_numpy(label.astype(np.float32)) 58 | low_res_label = torch.from_numpy(low_res_label.astype(np.float32)) 59 | sample = {'image': image, 'label': label.long(), 'low_res_label': low_res_label.long(), 'case_name': sample['case_name']} 60 | return sample 61 | 62 | def convert_coco_poly_to_mask(segmentations, height, width): 63 | masks = [] 64 | for polygons in segmentations: 65 | rles = coco_mask.frPyObjects(polygons, height, width) 66 | mask = coco_mask.decode(rles) 67 | if len(mask.shape) < 3: 68 | mask = mask[..., None] 69 | # mask = torch.as_tensor(mask, dtype=torch.uint8) 70 | mask = np.any(mask, axis=2) 71 | masks.append(mask) 72 | 73 | merged_mask = np.zeros((height, width), dtype=np.uint8) 74 | if masks: 75 | for mask in masks: 76 | merged_mask = merged_mask | mask 77 | 78 | return merged_mask 79 | 80 | class COCO_dataset(torchvision.datasets.CocoDetection): 81 | def __init__(self, img_folder, ann_file, split=None, transform=None): 82 | super(COCO_dataset, self).__init__(img_folder, ann_file) 83 | self.split = split 84 | self.transform = transform 85 | 86 | def __len__(self): 87 | return super(COCO_dataset, self).__len__() 88 | 89 | def __getitem__(self, idx): 90 | img, target = super(COCO_dataset, self).__getitem__(idx) 91 | 92 | # get filename 93 | image_info = self.coco.loadImgs(self.ids[idx])[0] 94 | filename = image_info['file_name'] 95 | 96 | # generate masks 97 | w, h = img.size 98 | segmentations = [obj['segmentation'] for obj in target] 99 | masks = convert_coco_poly_to_mask(segmentations, h, w) 100 | 101 | label_value = target[0]['category_id'] + 1 102 | masks[masks == 1] = label_value 103 | 104 | img = np.array(img) 105 | 106 | sample = {'image': img, 'label': masks} 107 | 108 | if self.transform: 109 | sample = self.transform(sample) 110 | 111 | sample['case_name'] = os.path.splitext(filename)[0] 112 | 113 | return sample 114 | 115 | class Cancer_dataset(Dataset): 116 | def __init__(self, data_dir, txt_dir, transform=None): 117 | # train or val or test 118 | phase = os.path.splitext(os.path.basename(txt_dir))[0] 119 | file_path = os.path.join(data_dir, phase) 120 | 121 | self.data = [os.path.join(file_path, file) for file in os.listdir(file_path)] 122 | self.transform = transform # using transform in torch! 
123 | self.sample_list = open(txt_dir).readlines() 124 | 125 | def __len__(self): 126 | return len(self.sample_list) 127 | 128 | def __getitem__(self, idx): 129 | data_path = self.data[idx] 130 | data_dic = np.load(data_path) 131 | image, label = data_dic['image'], data_dic['label'] 132 | name = os.path.splitext(os.path.basename(data_path))[0] 133 | sample = {'image': image, 'label': label, 'case_name': name} 134 | if self.transform: 135 | sample = self.transform(sample) 136 | 137 | return sample 138 | 139 | 140 | if __name__ == '__main__': 141 | # dataset = Synapse_dataset('/home/fanxm/pro/SAMed/datasets/train_npz_new_224/', '/home/fanxm/pro/SAMed/lists/lists_Synapse/', 'train', 142 | # transform=transforms.Compose( 143 | # [RandomGenerator(output_size=[224, 224], low_res=[56, 56])])) 144 | 145 | # dataset = Synapse_dataset('/home/fanxm/pro/SAMed/datasets/train_npz_new_224/', 146 | # '/home/fanxm/pro/SAMed/lists/lists_Synapse/', 'train') 147 | 148 | 149 | # dataset = COCO_dataset('/home/fanxm/pro/datasets/coco/val/', '/home/fanxm/pro/datasets/coco/annotations/cancer_val.json', 150 | # split="val", transform=transforms.Compose( 151 | # [RandomGenerator(output_size=[224, 224], low_res=[56, 56])])) 152 | 153 | dataset = Cancer_dataset("/home/user/pro/DLGNet/datasets/segment", "/home/user/pro/SAM/SAMed_h/datasets/val.txt", 154 | transform=transforms.Compose( 155 | [RandomGenerator(output_size=[224, 224], low_res=[56, 56])])) 156 | print(dataset[1]['image'].shape) 157 | print(dataset[1]['label'].shape) 158 | print(dataset[1]['case_name']) 159 | 160 | # # 161 | sample = dataset[1] 162 | img = (sample['image']) 163 | mask = (sample['label']) 164 | # for i in range(100): 165 | # img, target = dataset[i] 166 | # if len(target) != 1: 167 | # print(i) 168 | 169 | import matplotlib.pyplot as plt 170 | 171 | plt.figure() 172 | plt.imshow(img) 173 | plt.show() 174 | plt.figure() 175 | plt.imshow(mask, cmap='gray') # 使用灰度颜色映射 176 | plt.show() 177 | 178 | # print(dataset[100]['image'].shape) 179 | # print(dataset[100]['low_res_label'].shape) 180 | # plt.imshow(dataset[100]['low_res_label']) 181 | # plt.show() 182 | # plt.imshow(np.rollaxis(dataset[100]['image'].cpu().numpy(), 0, 3)) 183 | # plt.show() 184 | # 185 | # print(dataset[3]['image']) 186 | # data_item = iter(dataset) 187 | # for item in data_item: 188 | # print(item) -------------------------------------------------------------------------------- /SAM/requirements.txt: -------------------------------------------------------------------------------- 1 | einops==0.6.1 2 | h5py==3.8.0 3 | icecream==2.1.3 4 | imageio==2.28.1 5 | MedPy==0.4.0 6 | nibabel==5.1.0 7 | monai==1.1.0 8 | numpy==1.24.3 9 | opencv_python==4.7.0.72 10 | pycocotools==2.0.6 11 | safetensors==0.3.1 12 | scipy==1.10.1 13 | SimpleITK==2.2.1 14 | tensorboardX==2.6 15 | torch==2.0.0 16 | torchvision==0.15.1 17 | tqdm==4.65.0 18 | ml-collections==0.1.1 19 | pycocotools==2.0.6 20 | onnx==1.14.0 21 | onnxruntime==1.14.1 22 | -------------------------------------------------------------------------------- /SAM/sam_lora_image_encoder.py: -------------------------------------------------------------------------------- 1 | from segment_anything import build_sam, SamPredictor 2 | from segment_anything import sam_model_registry 3 | 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch import Tensor 9 | from torch.nn.parameter import Parameter 10 | from segment_anything.modeling import Sam 11 | from safetensors import safe_open 12 | 
from safetensors.torch import save_file 13 | 14 | from icecream import ic 15 | 16 | 17 | class _LoRA_qkv(nn.Module): 18 | """In Sam it is implemented as 19 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 20 | B, N, C = x.shape 21 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) 22 | q, k, v = qkv.unbind(0) 23 | """ 24 | 25 | def __init__( 26 | self, 27 | qkv: nn.Module, 28 | linear_a_q: nn.Module, 29 | linear_b_q: nn.Module, 30 | linear_a_v: nn.Module, 31 | linear_b_v: nn.Module, 32 | ): 33 | super().__init__() 34 | self.qkv = qkv 35 | self.linear_a_q = linear_a_q 36 | self.linear_b_q = linear_b_q 37 | self.linear_a_v = linear_a_v 38 | self.linear_b_v = linear_b_v 39 | self.dim = qkv.in_features 40 | self.w_identity = torch.eye(qkv.in_features) 41 | 42 | def forward(self, x): 43 | qkv = self.qkv(x) # B,N,N,3*org_C 44 | new_q = self.linear_b_q(self.linear_a_q(x)) 45 | new_v = self.linear_b_v(self.linear_a_v(x)) 46 | qkv[:, :, :, : self.dim] += new_q 47 | qkv[:, :, :, -self.dim:] += new_v 48 | return qkv 49 | 50 | 51 | class LoRA_Sam(nn.Module): 52 | """Applies low-rank adaptation to a Sam model's image encoder. 53 | 54 | Args: 55 | sam_model: a vision transformer model, see base_vit.py 56 | r: rank of LoRA 57 | num_classes: how many classes the model output, default to the vit model 58 | lora_layer: which layer we apply LoRA. 59 | 60 | Examples:: 61 | >>> model = ViT('B_16_imagenet1k') 62 | >>> lora_model = LoRA_ViT(model, r=4) 63 | >>> preds = lora_model(img) 64 | >>> print(preds.shape) 65 | torch.Size([1, 1000]) 66 | """ 67 | 68 | def __init__(self, sam_model: Sam, r: int, lora_layer=None): 69 | super(LoRA_Sam, self).__init__() 70 | 71 | assert r > 0 72 | # base_vit_dim = sam_model.image_encoder.patch_embed.proj.out_channels 73 | # dim = base_vit_dim 74 | if lora_layer: 75 | self.lora_layer = lora_layer 76 | else: 77 | self.lora_layer = list( 78 | range(len(sam_model.image_encoder.blocks))) # Only apply lora to the image encoder by default 79 | # create for storage, then we can init them or load weights 80 | self.w_As = [] # These are linear layers 81 | self.w_Bs = [] 82 | 83 | # lets freeze first 84 | for param in sam_model.image_encoder.parameters(): 85 | param.requires_grad = False 86 | 87 | # Here, we do the surgery 88 | for t_layer_i, blk in enumerate(sam_model.image_encoder.blocks): 89 | # If we only want few lora layer instead of all 90 | if t_layer_i not in self.lora_layer: 91 | continue 92 | w_qkv_linear = blk.attn.qkv 93 | self.dim = w_qkv_linear.in_features 94 | w_a_linear_q = nn.Linear(self.dim, r, bias=False) 95 | w_b_linear_q = nn.Linear(r, self.dim, bias=False) 96 | w_a_linear_v = nn.Linear(self.dim, r, bias=False) 97 | w_b_linear_v = nn.Linear(r, self.dim, bias=False) 98 | self.w_As.append(w_a_linear_q) 99 | self.w_Bs.append(w_b_linear_q) 100 | self.w_As.append(w_a_linear_v) 101 | self.w_Bs.append(w_b_linear_v) 102 | blk.attn.qkv = _LoRA_qkv( 103 | w_qkv_linear, 104 | w_a_linear_q, 105 | w_b_linear_q, 106 | w_a_linear_v, 107 | w_b_linear_v, 108 | ) 109 | self.reset_parameters() 110 | self.sam = sam_model 111 | 112 | def save_lora_parameters(self, filename: str) -> None: 113 | r"""Only safetensors is supported now. 114 | 115 | pip install safetensor if you do not have one installed yet. 116 | 117 | save both lora and fc parameters. 
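        Note: this version gathers the LoRA A/B weights plus the prompt-encoder and
        mask-decoder state into a single dict and writes it with torch.save, so the
        filename must end in .pt or .pth.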
118 | """ 119 | 120 | assert filename.endswith(".pt") or filename.endswith('.pth') 121 | 122 | num_layer = len(self.w_As) # actually, it is half 123 | a_tensors = {f"w_a_{i:03d}": self.w_As[i].weight for i in range(num_layer)} 124 | b_tensors = {f"w_b_{i:03d}": self.w_Bs[i].weight for i in range(num_layer)} 125 | prompt_encoder_tensors = {} 126 | mask_decoder_tensors = {} 127 | 128 | # save prompt encoder, only `state_dict`, the `named_parameter` is not permitted 129 | if isinstance(self.sam, torch.nn.DataParallel) or isinstance(self.sam, torch.nn.parallel.DistributedDataParallel): 130 | state_dict = self.sam.module.state_dict() 131 | else: 132 | state_dict = self.sam.state_dict() 133 | for key, value in state_dict.items(): 134 | if 'prompt_encoder' in key: 135 | prompt_encoder_tensors[key] = value 136 | if 'mask_decoder' in key: 137 | mask_decoder_tensors[key] = value 138 | 139 | merged_dict = {**a_tensors, **b_tensors, **prompt_encoder_tensors, **mask_decoder_tensors} 140 | torch.save(merged_dict, filename) 141 | 142 | def load_lora_parameters(self, filename: str) -> None: 143 | r"""Only safetensors is supported now. 144 | 145 | pip install safetensor if you do not have one installed yet.\ 146 | 147 | load both lora and fc parameters. 148 | """ 149 | 150 | assert filename.endswith(".pt") or filename.endswith('.pth') 151 | 152 | state_dict = torch.load(filename) 153 | 154 | for i, w_A_linear in enumerate(self.w_As): 155 | saved_key = f"w_a_{i:03d}" 156 | saved_tensor = state_dict[saved_key] 157 | w_A_linear.weight = Parameter(saved_tensor) 158 | 159 | for i, w_B_linear in enumerate(self.w_Bs): 160 | saved_key = f"w_b_{i:03d}" 161 | saved_tensor = state_dict[saved_key] 162 | w_B_linear.weight = Parameter(saved_tensor) 163 | 164 | sam_dict = self.sam.state_dict() 165 | sam_keys = sam_dict.keys() 166 | 167 | # load prompt encoder 168 | prompt_encoder_keys = [k for k in sam_keys if 'prompt_encoder' in k] 169 | prompt_encoder_values = [state_dict[k] for k in prompt_encoder_keys] 170 | prompt_encoder_new_state_dict = {k: v for k, v in zip(prompt_encoder_keys, prompt_encoder_values)} 171 | sam_dict.update(prompt_encoder_new_state_dict) 172 | 173 | # load mask decoder 174 | mask_decoder_keys = [k for k in sam_keys if 'mask_decoder' in k] 175 | mask_decoder_values = [state_dict[k] for k in mask_decoder_keys] 176 | mask_decoder_new_state_dict = {k: v for k, v in zip(mask_decoder_keys, mask_decoder_values)} 177 | sam_dict.update(mask_decoder_new_state_dict) 178 | self.sam.load_state_dict(sam_dict) 179 | 180 | def reset_parameters(self) -> None: 181 | for w_A in self.w_As: 182 | nn.init.kaiming_uniform_(w_A.weight, a=math.sqrt(5)) 183 | for w_B in self.w_Bs: 184 | nn.init.zeros_(w_B.weight) 185 | 186 | def forward(self, batched_input, multimask_output, image_size): 187 | return self.sam(batched_input, multimask_output, image_size) 188 | 189 | 190 | # def forward(self, x: Tensor) -> Tensor: 191 | # return self.lora_vit(x) 192 | 193 | 194 | if __name__ == "__main__": 195 | sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth") 196 | lora_sam = LoRA_Sam(sam, 4) 197 | lora_sam.sam.image_encoder(torch.rand(size=(1, 3, 1024, 1024))) 198 | -------------------------------------------------------------------------------- /SAM/segment_anything/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | sam_model_registry, 13 | ) 14 | from .predictor import SamPredictor 15 | from .automatic_mask_generator import SamAutomaticMaskGenerator 16 | -------------------------------------------------------------------------------- /SAM/segment_anything/automatic_mask_generator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torchvision.ops.boxes import batched_nms, box_area # type: ignore 10 | 11 | from typing import Any, Dict, List, Optional, Tuple 12 | 13 | from .modeling import Sam 14 | from .predictor import SamPredictor 15 | from .utils.amg import ( 16 | MaskData, 17 | area_from_rle, 18 | batch_iterator, 19 | batched_mask_to_box, 20 | box_xyxy_to_xywh, 21 | build_all_layer_point_grids, 22 | calculate_stability_score, 23 | coco_encode_rle, 24 | generate_crop_boxes, 25 | is_box_near_crop_edge, 26 | mask_to_rle_pytorch, 27 | remove_small_regions, 28 | rle_to_mask, 29 | uncrop_boxes_xyxy, 30 | uncrop_masks, 31 | uncrop_points, 32 | ) 33 | 34 | 35 | class SamAutomaticMaskGenerator: 36 | def __init__( 37 | self, 38 | model: Sam, 39 | points_per_side: Optional[int] = 32, 40 | points_per_batch: int = 64, 41 | pred_iou_thresh: float = 0.88, 42 | stability_score_thresh: float = 0.95, 43 | stability_score_offset: float = 1.0, 44 | box_nms_thresh: float = 0.7, 45 | crop_n_layers: int = 0, 46 | crop_nms_thresh: float = 0.7, 47 | crop_overlap_ratio: float = 512 / 1500, 48 | crop_n_points_downscale_factor: int = 1, 49 | point_grids: Optional[List[np.ndarray]] = None, 50 | min_mask_region_area: int = 0, 51 | output_mode: str = "binary_mask", 52 | ) -> None: 53 | """ 54 | Using a SAM model, generates masks for the entire image. 55 | Generates a grid of point prompts over the image, then filters 56 | low quality and duplicate masks. The default settings are chosen 57 | for SAM with a ViT-H backbone. 58 | 59 | Arguments: 60 | model (Sam): The SAM model to use for mask prediction. 61 | points_per_side (int or None): The number of points to be sampled 62 | along one side of the image. The total number of points is 63 | points_per_side**2. If None, 'point_grids' must provide explicit 64 | point sampling. 65 | points_per_batch (int): Sets the number of points run simultaneously 66 | by the model. Higher numbers may be faster but use more GPU memory. 67 | pred_iou_thresh (float): A filtering threshold in [0,1], using the 68 | model's predicted mask quality. 69 | stability_score_thresh (float): A filtering threshold in [0,1], using 70 | the stability of the mask under changes to the cutoff used to binarize 71 | the model's mask predictions. 72 | stability_score_offset (float): The amount to shift the cutoff when 73 | calculated the stability score. 74 | box_nms_thresh (float): The box IoU cutoff used by non-maximal 75 | suppression to filter duplicate masks. 76 | crops_n_layers (int): If >0, mask prediction will be run again on 77 | crops of the image. Sets the number of layers to run, where each 78 | layer has 2**i_layer number of image crops. 
79 | crops_nms_thresh (float): The box IoU cutoff used by non-maximal 80 | suppression to filter duplicate masks between different crops. 81 | crop_overlap_ratio (float): Sets the degree to which crops overlap. 82 | In the first crop layer, crops will overlap by this fraction of 83 | the image length. Later layers with more crops scale down this overlap. 84 | crop_n_points_downscale_factor (int): The number of points-per-side 85 | sampled in layer n is scaled down by crop_n_points_downscale_factor**n. 86 | point_grids (list(np.ndarray) or None): A list over explicit grids 87 | of points used for sampling, normalized to [0,1]. The nth grid in the 88 | list is used in the nth crop layer. Exclusive with points_per_side. 89 | min_mask_region_area (int): If >0, postprocessing will be applied 90 | to remove disconnected regions and holes in masks with area smaller 91 | than min_mask_region_area. Requires opencv. 92 | output_mode (str): The form masks are returned in. Can be 'binary_mask', 93 | 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. 94 | For large resolutions, 'binary_mask' may consume large amounts of 95 | memory. 96 | """ 97 | 98 | assert (points_per_side is None) != ( 99 | point_grids is None 100 | ), "Exactly one of points_per_side or point_grid must be provided." 101 | if points_per_side is not None: 102 | self.point_grids = build_all_layer_point_grids( 103 | points_per_side, 104 | crop_n_layers, 105 | crop_n_points_downscale_factor, 106 | ) 107 | elif point_grids is not None: 108 | self.point_grids = point_grids 109 | else: 110 | raise ValueError("Can't have both points_per_side and point_grid be None.") 111 | 112 | assert output_mode in [ 113 | "binary_mask", 114 | "uncompressed_rle", 115 | "coco_rle", 116 | ], f"Unknown output_mode {output_mode}." 117 | if output_mode == "coco_rle": 118 | from pycocotools import mask as mask_utils # type: ignore # noqa: F401 119 | 120 | if min_mask_region_area > 0: 121 | import cv2 # type: ignore # noqa: F401 122 | 123 | self.predictor = SamPredictor(model) 124 | self.points_per_batch = points_per_batch 125 | self.pred_iou_thresh = pred_iou_thresh 126 | self.stability_score_thresh = stability_score_thresh 127 | self.stability_score_offset = stability_score_offset 128 | self.box_nms_thresh = box_nms_thresh 129 | self.crop_n_layers = crop_n_layers 130 | self.crop_nms_thresh = crop_nms_thresh 131 | self.crop_overlap_ratio = crop_overlap_ratio 132 | self.crop_n_points_downscale_factor = crop_n_points_downscale_factor 133 | self.min_mask_region_area = min_mask_region_area 134 | self.output_mode = output_mode 135 | 136 | @torch.no_grad() 137 | def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: 138 | """ 139 | Generates masks for the given image. 140 | 141 | Arguments: 142 | image (np.ndarray): The image to generate masks for, in HWC uint8 format. 143 | 144 | Returns: 145 | list(dict(str, any)): A list over records for masks. Each record is 146 | a dict containing the following keys: 147 | segmentation (dict(str, any) or np.ndarray): The mask. If 148 | output_mode='binary_mask', is an array of shape HW. Otherwise, 149 | is a dictionary containing the RLE. 150 | bbox (list(float)): The box around the mask, in XYWH format. 151 | area (int): The area in pixels of the mask. 152 | predicted_iou (float): The model's own prediction of the mask's 153 | quality. This is filtered by the pred_iou_thresh parameter. 154 | point_coords (list(list(float))): The point coordinates input 155 | to the model to generate this mask. 
156 | stability_score (float): A measure of the mask's quality. This 157 | is filtered on using the stability_score_thresh parameter. 158 | crop_box (list(float)): The crop of the image used to generate 159 | the mask, given in XYWH format. 160 | """ 161 | 162 | # Generate masks 163 | mask_data = self._generate_masks(image) 164 | 165 | # Filter small disconnected regions and holes in masks 166 | if self.min_mask_region_area > 0: 167 | mask_data = self.postprocess_small_regions( 168 | mask_data, 169 | self.min_mask_region_area, 170 | max(self.box_nms_thresh, self.crop_nms_thresh), 171 | ) 172 | 173 | # Encode masks 174 | if self.output_mode == "coco_rle": 175 | mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]] 176 | elif self.output_mode == "binary_mask": 177 | mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] 178 | else: 179 | mask_data["segmentations"] = mask_data["rles"] 180 | 181 | # Write mask records 182 | curr_anns = [] 183 | for idx in range(len(mask_data["segmentations"])): 184 | ann = { 185 | "segmentation": mask_data["segmentations"][idx], 186 | "area": area_from_rle(mask_data["rles"][idx]), 187 | "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), 188 | "predicted_iou": mask_data["iou_preds"][idx].item(), 189 | "point_coords": [mask_data["points"][idx].tolist()], 190 | "stability_score": mask_data["stability_score"][idx].item(), 191 | "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), 192 | } 193 | curr_anns.append(ann) 194 | 195 | return curr_anns 196 | 197 | def _generate_masks(self, image: np.ndarray) -> MaskData: 198 | orig_size = image.shape[:2] 199 | crop_boxes, layer_idxs = generate_crop_boxes( 200 | orig_size, self.crop_n_layers, self.crop_overlap_ratio 201 | ) 202 | 203 | # Iterate over image crops 204 | data = MaskData() 205 | for crop_box, layer_idx in zip(crop_boxes, layer_idxs): 206 | crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) 207 | data.cat(crop_data) 208 | 209 | # Remove duplicate masks between crops 210 | if len(crop_boxes) > 1: 211 | # Prefer masks from smaller crops 212 | scores = 1 / box_area(data["crop_boxes"]) 213 | scores = scores.to(data["boxes"].device) 214 | keep_by_nms = batched_nms( 215 | data["boxes"].float(), 216 | scores, 217 | torch.zeros(len(data["boxes"])), # categories 218 | iou_threshold=self.crop_nms_thresh, 219 | ) 220 | data.filter(keep_by_nms) 221 | 222 | data.to_numpy() 223 | return data 224 | 225 | def _process_crop( 226 | self, 227 | image: np.ndarray, 228 | crop_box: List[int], 229 | crop_layer_idx: int, 230 | orig_size: Tuple[int, ...], 231 | ) -> MaskData: 232 | # Crop the image and calculate embeddings 233 | x0, y0, x1, y1 = crop_box 234 | cropped_im = image[y0:y1, x0:x1, :] 235 | cropped_im_size = cropped_im.shape[:2] 236 | self.predictor.set_image(cropped_im) 237 | 238 | # Get points for this crop 239 | points_scale = np.array(cropped_im_size)[None, ::-1] 240 | points_for_image = self.point_grids[crop_layer_idx] * points_scale 241 | 242 | # Generate masks for this crop in batches 243 | data = MaskData() 244 | for (points,) in batch_iterator(self.points_per_batch, points_for_image): 245 | batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size) 246 | data.cat(batch_data) 247 | del batch_data 248 | self.predictor.reset_image() 249 | 250 | # Remove duplicates within this crop. 
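        # batched_nms with all-zero category ids acts as plain class-agnostic NMS: masks are ranked
        # by their predicted IoU and overlapping boxes above box_nms_thresh are suppressed.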
251 | keep_by_nms = batched_nms( 252 | data["boxes"].float(), 253 | data["iou_preds"], 254 | torch.zeros(len(data["boxes"])), # categories 255 | iou_threshold=self.box_nms_thresh, 256 | ) 257 | data.filter(keep_by_nms) 258 | 259 | # Return to the original image frame 260 | data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) 261 | data["points"] = uncrop_points(data["points"], crop_box) 262 | data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) 263 | 264 | return data 265 | 266 | def _process_batch( 267 | self, 268 | points: np.ndarray, 269 | im_size: Tuple[int, ...], 270 | crop_box: List[int], 271 | orig_size: Tuple[int, ...], 272 | ) -> MaskData: 273 | orig_h, orig_w = orig_size 274 | 275 | # Run model on this batch 276 | transformed_points = self.predictor.transform.apply_coords(points, im_size) 277 | in_points = torch.as_tensor(transformed_points, device=self.predictor.device) 278 | in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) 279 | masks, iou_preds, _ = self.predictor.predict_torch( 280 | in_points[:, None, :], 281 | in_labels[:, None], 282 | multimask_output=True, 283 | return_logits=True, 284 | ) 285 | 286 | # Serialize predictions and store in MaskData 287 | data = MaskData( 288 | masks=masks.flatten(0, 1), 289 | iou_preds=iou_preds.flatten(0, 1), 290 | points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)), 291 | ) 292 | del masks 293 | 294 | # Filter by predicted IoU 295 | if self.pred_iou_thresh > 0.0: 296 | keep_mask = data["iou_preds"] > self.pred_iou_thresh 297 | data.filter(keep_mask) 298 | 299 | # Calculate stability score 300 | data["stability_score"] = calculate_stability_score( 301 | data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset 302 | ) 303 | if self.stability_score_thresh > 0.0: 304 | keep_mask = data["stability_score"] >= self.stability_score_thresh 305 | data.filter(keep_mask) 306 | 307 | # Threshold masks and calculate boxes 308 | data["masks"] = data["masks"] > self.predictor.model.mask_threshold 309 | data["boxes"] = batched_mask_to_box(data["masks"]) 310 | 311 | # Filter boxes that touch crop boundaries 312 | keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h]) 313 | if not torch.all(keep_mask): 314 | data.filter(keep_mask) 315 | 316 | # Compress to RLE 317 | data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) 318 | data["rles"] = mask_to_rle_pytorch(data["masks"]) 319 | del data["masks"] 320 | 321 | return data 322 | 323 | @staticmethod 324 | def postprocess_small_regions( 325 | mask_data: MaskData, min_area: int, nms_thresh: float 326 | ) -> MaskData: 327 | """ 328 | Removes small disconnected regions and holes in masks, then reruns 329 | box NMS to remove any new duplicates. 330 | 331 | Edits mask_data in place. 332 | 333 | Requires open-cv as a dependency. 
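        Cleaned-up masks that actually changed are given score 0 so that NMS prefers the
        untouched originals, and only the changed masks have their RLEs and boxes recomputed.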
334 | """ 335 | if len(mask_data["rles"]) == 0: 336 | return mask_data 337 | 338 | # Filter small disconnected regions and holes 339 | new_masks = [] 340 | scores = [] 341 | for rle in mask_data["rles"]: 342 | mask = rle_to_mask(rle) 343 | 344 | mask, changed = remove_small_regions(mask, min_area, mode="holes") 345 | unchanged = not changed 346 | mask, changed = remove_small_regions(mask, min_area, mode="islands") 347 | unchanged = unchanged and not changed 348 | 349 | new_masks.append(torch.as_tensor(mask).unsqueeze(0)) 350 | # Give score=0 to changed masks and score=1 to unchanged masks 351 | # so NMS will prefer ones that didn't need postprocessing 352 | scores.append(float(unchanged)) 353 | 354 | # Recalculate boxes and remove any new duplicates 355 | masks = torch.cat(new_masks, dim=0) 356 | boxes = batched_mask_to_box(masks) 357 | keep_by_nms = batched_nms( 358 | boxes.float(), 359 | torch.as_tensor(scores), 360 | torch.zeros(len(boxes)), # categories 361 | iou_threshold=nms_thresh, 362 | ) 363 | 364 | # Only recalculate RLEs for masks that have changed 365 | for i_mask in keep_by_nms: 366 | if scores[i_mask] == 0.0: 367 | mask_torch = masks[i_mask].unsqueeze(0) 368 | mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] 369 | mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly 370 | mask_data.filter(keep_by_nms) 371 | 372 | return mask_data 373 | -------------------------------------------------------------------------------- /SAM/segment_anything/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | from torch.nn import functional as F 9 | from icecream import ic 10 | 11 | from functools import partial 12 | 13 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer 14 | 15 | 16 | def build_sam_vit_h(image_size, num_classes, pixel_mean=[123.675, 116.28, 103.53], pixel_std=[58.395, 57.12, 57.375], 17 | checkpoint=None): 18 | return _build_sam( 19 | encoder_embed_dim=1280, 20 | encoder_depth=32, 21 | encoder_num_heads=16, 22 | encoder_global_attn_indexes=[7, 15, 23, 31], 23 | checkpoint=checkpoint, 24 | num_classes=num_classes, 25 | image_size=image_size, 26 | pixel_mean=pixel_mean, 27 | pixel_std=pixel_std 28 | ) 29 | 30 | 31 | build_sam = build_sam_vit_h 32 | 33 | 34 | def build_sam_vit_l(image_size, num_classes, pixel_mean=[123.675, 116.28, 103.53], pixel_std=[58.395, 57.12, 57.375], 35 | checkpoint=None): 36 | return _build_sam( 37 | encoder_embed_dim=1024, 38 | encoder_depth=24, 39 | encoder_num_heads=16, 40 | encoder_global_attn_indexes=[5, 11, 17, 23], 41 | checkpoint=checkpoint, 42 | num_classes=num_classes, 43 | image_size=image_size, 44 | pixel_mean=pixel_mean, 45 | pixel_std=pixel_std 46 | ) 47 | 48 | 49 | def build_sam_vit_b(image_size, num_classes, pixel_mean=[123.675, 116.28, 103.53], pixel_std=[58.395, 57.12, 57.375], 50 | checkpoint=None): 51 | return _build_sam( 52 | encoder_embed_dim=768, 53 | encoder_depth=12, 54 | encoder_num_heads=12, 55 | encoder_global_attn_indexes=[2, 5, 8, 11], 56 | # adopt global attention at [3, 6, 9, 12] transform layer, else window attention layer 57 | checkpoint=checkpoint, 58 | num_classes=num_classes, 59 | image_size=image_size, 60 | pixel_mean=pixel_mean, 61 | pixel_std=pixel_std 62 | ) 63 | 64 | 65 | sam_model_registry = { 66 | "default": build_sam_vit_h, 67 | "vit_h": build_sam_vit_h, 68 | "vit_l": build_sam_vit_l, 69 | "vit_b": build_sam_vit_b, 70 | } 71 | 72 | 73 | def _build_sam( 74 | encoder_embed_dim, 75 | encoder_depth, 76 | encoder_num_heads, 77 | encoder_global_attn_indexes, 78 | num_classes, 79 | image_size, 80 | pixel_mean, 81 | pixel_std, 82 | checkpoint=None, 83 | ): 84 | prompt_embed_dim = 256 85 | image_size = image_size 86 | vit_patch_size = 16 87 | image_embedding_size = image_size // vit_patch_size # Divide by 16 here 88 | sam = Sam( 89 | image_encoder=ImageEncoderViT( 90 | depth=encoder_depth, 91 | embed_dim=encoder_embed_dim, 92 | img_size=image_size, 93 | mlp_ratio=4, 94 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 95 | num_heads=encoder_num_heads, 96 | patch_size=vit_patch_size, 97 | qkv_bias=True, 98 | use_rel_pos=True, 99 | global_attn_indexes=encoder_global_attn_indexes, 100 | window_size=14, 101 | out_chans=prompt_embed_dim, 102 | ), 103 | prompt_encoder=PromptEncoder( 104 | embed_dim=prompt_embed_dim, 105 | image_embedding_size=(image_embedding_size, image_embedding_size), 106 | input_image_size=(image_size, image_size), 107 | mask_in_chans=16, 108 | ), 109 | mask_decoder=MaskDecoder( 110 | # num_multimask_outputs=3, 111 | num_multimask_outputs=num_classes, 112 | transformer=TwoWayTransformer( 113 | depth=2, 114 | embedding_dim=prompt_embed_dim, 115 | mlp_dim=2048, 116 | num_heads=8, 117 | ), 118 | transformer_dim=prompt_embed_dim, 119 | iou_head_depth=3, 120 | iou_head_hidden_dim=256, 121 | ), 122 | # pixel_mean=[123.675, 116.28, 103.53], 123 | # pixel_std=[58.395, 57.12, 57.375], 124 | pixel_mean=pixel_mean, 125 | pixel_std=pixel_std 126 | ) 127 | # sam.eval() 128 | sam.train() 129 | if checkpoint is not None: 130 | with open(checkpoint, 
"rb") as f: 131 | state_dict = torch.load(f) 132 | try: 133 | sam.load_state_dict(state_dict) 134 | except: 135 | new_state_dict = load_from(sam, state_dict, image_size, vit_patch_size, encoder_global_attn_indexes) 136 | sam.load_state_dict(new_state_dict) 137 | return sam, image_embedding_size 138 | 139 | 140 | def load_from(sam, state_dict, image_size, vit_patch_size, encoder_global_attn_indexes): 141 | ega = encoder_global_attn_indexes 142 | sam_dict = sam.state_dict() 143 | except_keys = ['mask_tokens', 'output_hypernetworks_mlps', 'iou_prediction_head'] 144 | new_state_dict = {k: v for k, v in state_dict.items() if 145 | k in sam_dict.keys() and except_keys[0] not in k and except_keys[1] not in k and except_keys[2] not in k} 146 | pos_embed = new_state_dict['image_encoder.pos_embed'] 147 | token_size = int(image_size // vit_patch_size) 148 | if pos_embed.shape[1] != token_size: 149 | # resize pos embedding, which may sacrifice the performance, but I have no better idea 150 | pos_embed = pos_embed.permute(0, 3, 1, 2) # [b, c, h, w] 151 | pos_embed = F.interpolate(pos_embed, (token_size, token_size), mode='bilinear', align_corners=False) 152 | pos_embed = pos_embed.permute(0, 2, 3, 1) # [b, h, w, c] 153 | new_state_dict['image_encoder.pos_embed'] = pos_embed 154 | rel_pos_keys = [k for k in sam_dict.keys() if 'rel_pos' in k] 155 | global_rel_pos_keys = [] 156 | for rel_pos_key in rel_pos_keys: 157 | num = int(rel_pos_key.split('.')[2]) 158 | if num in encoder_global_attn_indexes: 159 | global_rel_pos_keys.append(rel_pos_key) 160 | # global_rel_pos_keys = [k for k in rel_pos_keys if '2' in k or '5' in k or '8' in k or '11' in k] 161 | for k in global_rel_pos_keys: 162 | rel_pos_params = new_state_dict[k] 163 | h, w = rel_pos_params.shape 164 | rel_pos_params = rel_pos_params.unsqueeze(0).unsqueeze(0) 165 | rel_pos_params = F.interpolate(rel_pos_params, (token_size * 2 - 1, w), mode='bilinear', align_corners=False) 166 | new_state_dict[k] = rel_pos_params[0, 0, ...] 167 | sam_dict.update(new_state_dict) 168 | return sam_dict 169 | -------------------------------------------------------------------------------- /SAM/segment_anything/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer 12 | -------------------------------------------------------------------------------- /SAM/segment_anything/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /SAM/segment_anything/modeling/image_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from icecream import ic 11 | 12 | from typing import Optional, Tuple, Type 13 | 14 | from .common import LayerNorm2d, MLPBlock 15 | 16 | 17 | # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa 18 | class ImageEncoderViT(nn.Module): 19 | def __init__( 20 | self, 21 | img_size: int = 1024, 22 | patch_size: int = 16, 23 | in_chans: int = 3, 24 | embed_dim: int = 768, 25 | depth: int = 12, 26 | num_heads: int = 12, 27 | mlp_ratio: float = 4.0, 28 | out_chans: int = 256, 29 | qkv_bias: bool = True, 30 | norm_layer: Type[nn.Module] = nn.LayerNorm, 31 | act_layer: Type[nn.Module] = nn.GELU, 32 | use_abs_pos: bool = True, 33 | use_rel_pos: bool = False, 34 | rel_pos_zero_init: bool = True, 35 | window_size: int = 0, 36 | global_attn_indexes: Tuple[int, ...] = (), 37 | ) -> None: 38 | """ 39 | Args: 40 | img_size (int): Input image size. 41 | patch_size (int): Patch size. 42 | in_chans (int): Number of input image channels. 43 | embed_dim (int): Patch embedding dimension. 44 | depth (int): Depth of ViT. 45 | num_heads (int): Number of attention heads in each ViT block. 46 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 47 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 48 | norm_layer (nn.Module): Normalization layer. 49 | act_layer (nn.Module): Activation layer. 50 | use_abs_pos (bool): If True, use absolute positional embeddings. 51 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 52 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 
53 | window_size (int): Window size for window attention blocks. 54 | global_attn_indexes (list): Indexes for blocks using global attention. 55 | """ 56 | super().__init__() 57 | self.img_size = img_size 58 | 59 | self.patch_embed = PatchEmbed( 60 | kernel_size=(patch_size, patch_size), 61 | stride=(patch_size, patch_size), 62 | in_chans=in_chans, 63 | embed_dim=embed_dim, 64 | ) 65 | 66 | self.pos_embed: Optional[nn.Parameter] = None 67 | if use_abs_pos: 68 | # Initialize absolute positional embedding with pretrain image size. 69 | self.pos_embed = nn.Parameter( 70 | torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) 71 | ) 72 | 73 | self.blocks = nn.ModuleList() 74 | for i in range(depth): 75 | block = Block( 76 | dim=embed_dim, 77 | num_heads=num_heads, 78 | mlp_ratio=mlp_ratio, 79 | qkv_bias=qkv_bias, 80 | norm_layer=norm_layer, 81 | act_layer=act_layer, 82 | use_rel_pos=use_rel_pos, 83 | rel_pos_zero_init=rel_pos_zero_init, 84 | window_size=window_size if i not in global_attn_indexes else 0, 85 | input_size=(img_size // patch_size, img_size // patch_size), 86 | ) 87 | self.blocks.append(block) 88 | 89 | self.neck = nn.Sequential( 90 | nn.Conv2d( 91 | embed_dim, 92 | out_chans, 93 | kernel_size=1, 94 | bias=False, 95 | ), 96 | LayerNorm2d(out_chans), 97 | nn.Conv2d( 98 | out_chans, 99 | out_chans, 100 | kernel_size=3, 101 | padding=1, 102 | bias=False, 103 | ), 104 | LayerNorm2d(out_chans), 105 | ) 106 | 107 | def forward(self, x: torch.Tensor) -> torch.Tensor: 108 | x = self.patch_embed(x) # pre embed: [1, 3, 1024, 1024], post embed: [1, 64, 64, 768] 109 | if self.pos_embed is not None: 110 | x = x + self.pos_embed 111 | 112 | for blk in self.blocks: 113 | x = blk(x) 114 | 115 | x = self.neck(x.permute(0, 3, 1, 2)) # [b, c, h, w], [1, 256, 64, 64] 116 | 117 | return x 118 | 119 | 120 | class Block(nn.Module): 121 | """Transformer blocks with support of window attention and residual propagation blocks""" 122 | 123 | def __init__( 124 | self, 125 | dim: int, 126 | num_heads: int, 127 | mlp_ratio: float = 4.0, 128 | qkv_bias: bool = True, 129 | norm_layer: Type[nn.Module] = nn.LayerNorm, 130 | act_layer: Type[nn.Module] = nn.GELU, 131 | use_rel_pos: bool = False, 132 | rel_pos_zero_init: bool = True, 133 | window_size: int = 0, 134 | input_size: Optional[Tuple[int, int]] = None, 135 | ) -> None: 136 | """ 137 | Args: 138 | dim (int): Number of input channels. 139 | num_heads (int): Number of attention heads in each ViT block. 140 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 141 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 142 | norm_layer (nn.Module): Normalization layer. 143 | act_layer (nn.Module): Activation layer. 144 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 145 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 146 | window_size (int): Window size for window attention blocks. If it equals 0, then 147 | use global attention. 148 | input_size (int or None): Input resolution for calculating the relative positional 149 | parameter size. 
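        Example (editor's illustrative sketch; the ViT-B sizes below are assumptions):
            blk = Block(dim=768, num_heads=12, window_size=14, input_size=(64, 64))
            x = torch.randn(1, 64, 64, 768)   # tokens kept in [B, H, W, C] layout
            y = blk(x)                        # same shape; attention runs in padded 14x14 windows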
150 | """ 151 | super().__init__() 152 | self.norm1 = norm_layer(dim) 153 | self.attn = Attention( 154 | dim, 155 | num_heads=num_heads, 156 | qkv_bias=qkv_bias, 157 | use_rel_pos=use_rel_pos, 158 | rel_pos_zero_init=rel_pos_zero_init, 159 | input_size=input_size if window_size == 0 else (window_size, window_size), 160 | ) 161 | 162 | self.norm2 = norm_layer(dim) 163 | self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) 164 | 165 | self.window_size = window_size 166 | 167 | def forward(self, x: torch.Tensor) -> torch.Tensor: 168 | shortcut = x 169 | x = self.norm1(x) 170 | # Window partition 171 | if self.window_size > 0: 172 | H, W = x.shape[1], x.shape[2] 173 | x, pad_hw = window_partition(x, self.window_size) # [B * num_windows, window_size, window_size, C] 174 | 175 | x = self.attn(x) 176 | # Reverse window partition 177 | if self.window_size > 0: 178 | x = window_unpartition(x, self.window_size, pad_hw, (H, W)) 179 | 180 | x = shortcut + x 181 | x = x + self.mlp(self.norm2(x)) 182 | 183 | return x 184 | 185 | 186 | class Attention(nn.Module): 187 | """Multi-head Attention block with relative position embeddings.""" 188 | 189 | def __init__( 190 | self, 191 | dim: int, 192 | num_heads: int = 8, 193 | qkv_bias: bool = True, 194 | use_rel_pos: bool = False, 195 | rel_pos_zero_init: bool = True, 196 | input_size: Optional[Tuple[int, int]] = None, 197 | ) -> None: 198 | """ 199 | Args: 200 | dim (int): Number of input channels. 201 | num_heads (int): Number of attention heads. 202 | qkv_bias (bool: If True, add a learnable bias to query, key, value. 203 | rel_pos (bool): If True, add relative positional embeddings to the attention map. 204 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 205 | input_size (int or None): Input resolution for calculating the relative positional 206 | parameter size. 207 | """ 208 | super().__init__() 209 | self.num_heads = num_heads 210 | head_dim = dim // num_heads 211 | self.scale = head_dim**-0.5 212 | 213 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 214 | self.proj = nn.Linear(dim, dim) 215 | 216 | self.use_rel_pos = use_rel_pos 217 | if self.use_rel_pos: 218 | assert ( 219 | input_size is not None 220 | ), "Input size must be provided if using relative positional encoding." 221 | # initialize relative positional embeddings 222 | self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) 223 | self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) 224 | 225 | def forward(self, x: torch.Tensor) -> torch.Tensor: 226 | B, H, W, _ = x.shape 227 | # qkv with shape (3, B, nHead, H * W, C) 228 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 229 | # q, k, v with shape (B * nHead, H * W, C) 230 | q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) 231 | 232 | attn = (q * self.scale) @ k.transpose(-2, -1) 233 | 234 | if self.use_rel_pos: 235 | attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) 236 | 237 | attn = attn.softmax(dim=-1) 238 | x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) 239 | x = self.proj(x) 240 | 241 | return x 242 | 243 | 244 | def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: 245 | """ 246 | Partition into non-overlapping windows with padding if needed. 247 | Args: 248 | x (tensor): input tokens with [B, H, W, C]. 249 | window_size (int): window size. 
250 | 251 | Returns: 252 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. 253 | (Hp, Wp): padded height and width before partition 254 | """ 255 | B, H, W, C = x.shape 256 | 257 | pad_h = (window_size - H % window_size) % window_size 258 | pad_w = (window_size - W % window_size) % window_size 259 | if pad_h > 0 or pad_w > 0: 260 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 261 | Hp, Wp = H + pad_h, W + pad_w 262 | 263 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 264 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 265 | return windows, (Hp, Wp) 266 | 267 | 268 | def window_unpartition( 269 | windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] 270 | ) -> torch.Tensor: 271 | """ 272 | Window unpartition into original sequences and removing padding. 273 | Args: 274 | x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 275 | window_size (int): window size. 276 | pad_hw (Tuple): padded height and width (Hp, Wp). 277 | hw (Tuple): original height and width (H, W) before padding. 278 | 279 | Returns: 280 | x: unpartitioned sequences with [B, H, W, C]. 281 | """ 282 | Hp, Wp = pad_hw 283 | H, W = hw 284 | B = windows.shape[0] // (Hp * Wp // window_size // window_size) 285 | x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) 286 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) 287 | 288 | if Hp > H or Wp > W: 289 | x = x[:, :H, :W, :].contiguous() 290 | return x 291 | 292 | 293 | def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: 294 | """ 295 | Get relative positional embeddings according to the relative positions of 296 | query and key sizes. 297 | Args: 298 | q_size (int): size of query q. 299 | k_size (int): size of key k. 300 | rel_pos (Tensor): relative position embeddings (L, C). 301 | 302 | Returns: 303 | Extracted positional embeddings according to relative positions. 304 | """ 305 | max_rel_dist = int(2 * max(q_size, k_size) - 1) 306 | # Interpolate rel pos if needed. 307 | if rel_pos.shape[0] != max_rel_dist: 308 | # Interpolate rel pos. 309 | rel_pos_resized = F.interpolate( 310 | rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), 311 | size=max_rel_dist, 312 | mode="linear", 313 | ) 314 | rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) 315 | else: 316 | rel_pos_resized = rel_pos 317 | 318 | # Scale the coords with short length if shapes for q and k are different. 319 | q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) 320 | k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) 321 | relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) 322 | 323 | return rel_pos_resized[relative_coords.long()] 324 | 325 | 326 | def add_decomposed_rel_pos( 327 | attn: torch.Tensor, 328 | q: torch.Tensor, 329 | rel_pos_h: torch.Tensor, 330 | rel_pos_w: torch.Tensor, 331 | q_size: Tuple[int, int], 332 | k_size: Tuple[int, int], 333 | ) -> torch.Tensor: 334 | """ 335 | Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. 336 | https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 337 | Args: 338 | attn (Tensor): attention map. 339 | q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). 
340 | rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. 341 | rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. 342 | q_size (Tuple): spatial sequence size of query q with (q_h, q_w). 343 | k_size (Tuple): spatial sequence size of key k with (k_h, k_w). 344 | 345 | Returns: 346 | attn (Tensor): attention map with added relative positional embeddings. 347 | """ 348 | q_h, q_w = q_size 349 | k_h, k_w = k_size 350 | Rh = get_rel_pos(q_h, k_h, rel_pos_h) 351 | Rw = get_rel_pos(q_w, k_w, rel_pos_w) 352 | 353 | B, _, dim = q.shape 354 | r_q = q.reshape(B, q_h, q_w, dim) 355 | rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) 356 | rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) 357 | 358 | attn = ( 359 | attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] 360 | ).view(B, q_h * q_w, k_h * k_w) 361 | 362 | return attn 363 | 364 | 365 | class PatchEmbed(nn.Module): 366 | """ 367 | Image to Patch Embedding. 368 | """ 369 | 370 | def __init__( 371 | self, 372 | kernel_size: Tuple[int, int] = (16, 16), 373 | stride: Tuple[int, int] = (16, 16), 374 | padding: Tuple[int, int] = (0, 0), 375 | in_chans: int = 3, 376 | embed_dim: int = 768, 377 | ) -> None: 378 | """ 379 | Args: 380 | kernel_size (Tuple): kernel size of the projection layer. 381 | stride (Tuple): stride of the projection layer. 382 | padding (Tuple): padding size of the projection layer. 383 | in_chans (int): Number of input image channels. 384 | embed_dim (int): embed_dim (int): Patch embedding dimension. 385 | """ 386 | super().__init__() 387 | 388 | self.proj = nn.Conv2d( 389 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding 390 | ) 391 | 392 | def forward(self, x: torch.Tensor) -> torch.Tensor: 393 | x = self.proj(x) 394 | # B C H W -> B H W C 395 | x = x.permute(0, 2, 3, 1) 396 | return x 397 | -------------------------------------------------------------------------------- /SAM/segment_anything/modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | from icecream import ic 11 | 12 | from typing import List, Tuple, Type 13 | 14 | from .common import LayerNorm2d 15 | 16 | 17 | class MaskDecoder(nn.Module): 18 | def __init__( 19 | self, 20 | *, 21 | transformer_dim: int, 22 | transformer: nn.Module, 23 | num_multimask_outputs: int = 3, 24 | activation: Type[nn.Module] = nn.GELU, 25 | iou_head_depth: int = 3, 26 | iou_head_hidden_dim: int = 256, 27 | ) -> None: 28 | """ 29 | Predicts masks given an image and prompt embeddings, using a 30 | tranformer architecture. 
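        Example (editor's illustrative sketch; the 256-dim / depth-2 values mirror the
        configuration commonly passed in by SAM's build functions and should be treated
        as assumptions; TwoWayTransformer lives in transformer.py):
            decoder = MaskDecoder(
                transformer_dim=256,
                transformer=TwoWayTransformer(
                    depth=2, embedding_dim=256, mlp_dim=2048, num_heads=8
                ),
                num_multimask_outputs=3,
            )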
31 | 32 | Arguments: 33 | transformer_dim (int): the channel dimension of the transformer 34 | transformer (nn.Module): the transformer used to predict masks 35 | num_multimask_outputs (int): the number of masks to predict 36 | when disambiguating masks 37 | activation (nn.Module): the type of activation to use when 38 | upscaling masks 39 | iou_head_depth (int): the depth of the MLP used to predict 40 | mask quality 41 | iou_head_hidden_dim (int): the hidden dimension of the MLP 42 | used to predict mask quality 43 | """ 44 | super().__init__() 45 | self.transformer_dim = transformer_dim 46 | self.transformer = transformer 47 | 48 | self.num_multimask_outputs = num_multimask_outputs 49 | 50 | self.iou_token = nn.Embedding(1, transformer_dim) 51 | self.num_mask_tokens = num_multimask_outputs + 1 52 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 53 | 54 | self.output_upscaling = nn.Sequential( 55 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 56 | LayerNorm2d(transformer_dim // 4), 57 | activation(), 58 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 59 | activation(), 60 | ) 61 | self.output_hypernetworks_mlps = nn.ModuleList( 62 | [ 63 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 64 | for i in range(self.num_mask_tokens) 65 | ] 66 | ) 67 | 68 | self.iou_prediction_head = MLP( 69 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 70 | ) 71 | 72 | def forward( 73 | self, 74 | image_embeddings: torch.Tensor, 75 | image_pe: torch.Tensor, 76 | sparse_prompt_embeddings: torch.Tensor, 77 | dense_prompt_embeddings: torch.Tensor, 78 | multimask_output: bool, 79 | ) -> Tuple[torch.Tensor, torch.Tensor]: 80 | """ 81 | Predict masks given image and prompt embeddings. 82 | 83 | Arguments: 84 | image_embeddings (torch.Tensor): the embeddings from the image encoder 85 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 86 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 87 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 88 | multimask_output (bool): Whether to return multiple masks or a single 89 | mask. 90 | 91 | Returns: 92 | torch.Tensor: batched predicted masks 93 | torch.Tensor: batched predictions of mask quality 94 | """ 95 | masks, iou_pred = self.predict_masks( 96 | image_embeddings=image_embeddings, 97 | image_pe=image_pe, 98 | sparse_prompt_embeddings=sparse_prompt_embeddings, 99 | dense_prompt_embeddings=dense_prompt_embeddings, 100 | ) 101 | 102 | # Select the correct mask or masks for output 103 | # if multimask_output: 104 | # mask_slice = slice(1, None) 105 | # else: 106 | # mask_slice = slice(0, 1) 107 | # masks = masks[:, mask_slice, :, :] 108 | # iou_pred = iou_pred[:, mask_slice] 109 | 110 | # Prepare output 111 | return masks, iou_pred 112 | 113 | def predict_masks( 114 | self, 115 | image_embeddings: torch.Tensor, 116 | image_pe: torch.Tensor, 117 | sparse_prompt_embeddings: torch.Tensor, 118 | dense_prompt_embeddings: torch.Tensor, 119 | ) -> Tuple[torch.Tensor, torch.Tensor]: 120 | """Predicts masks. 
See 'forward' for more details.""" 121 | # Concatenate output tokens 122 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) 123 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) 124 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 125 | 126 | # Expand per-image data in batch direction to be per-mask 127 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 128 | src = src + dense_prompt_embeddings 129 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 130 | b, c, h, w = src.shape 131 | 132 | # Run the transformer 133 | hs, src = self.transformer(src, pos_src, tokens) 134 | iou_token_out = hs[:, 0, :] 135 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 136 | 137 | # Upscale mask embeddings and predict masks using the mask tokens 138 | src = src.transpose(1, 2).view(b, c, h, w) 139 | upscaled_embedding = self.output_upscaling(src) 140 | hyper_in_list: List[torch.Tensor] = [] 141 | for i in range(self.num_mask_tokens): 142 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) 143 | hyper_in = torch.stack(hyper_in_list, dim=1) # [b, c, token_num] 144 | 145 | b, c, h, w = upscaled_embedding.shape # [h, token_num, h, w] 146 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) # [1, 4, 256, 256], 256 = 4 * 64, the size of image embeddings 147 | 148 | # Generate mask quality predictions 149 | iou_pred = self.iou_prediction_head(iou_token_out) 150 | 151 | return masks, iou_pred 152 | 153 | 154 | # Lightly adapted from 155 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 156 | class MLP(nn.Module): 157 | def __init__( 158 | self, 159 | input_dim: int, 160 | hidden_dim: int, 161 | output_dim: int, 162 | num_layers: int, 163 | sigmoid_output: bool = False, 164 | ) -> None: 165 | super().__init__() 166 | self.num_layers = num_layers 167 | h = [hidden_dim] * (num_layers - 1) 168 | self.layers = nn.ModuleList( 169 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 170 | ) 171 | self.sigmoid_output = sigmoid_output 172 | 173 | def forward(self, x): 174 | for i, layer in enumerate(self.layers): 175 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 176 | if self.sigmoid_output: 177 | x = F.sigmoid(x) 178 | return x 179 | -------------------------------------------------------------------------------- /SAM/segment_anything/modeling/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | from typing import Any, Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class PromptEncoder(nn.Module): 17 | def __init__( 18 | self, 19 | embed_dim: int, 20 | image_embedding_size: Tuple[int, int], 21 | input_image_size: Tuple[int, int], 22 | mask_in_chans: int, 23 | activation: Type[nn.Module] = nn.GELU, 24 | ) -> None: 25 | """ 26 | Encodes prompts for input to SAM's mask decoder. 
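        Example (editor's illustrative sketch; the 1024-pixel input and 64x64 embedding
        grid are assumptions matching the image encoder defaults):
            pe = PromptEncoder(embed_dim=256, image_embedding_size=(64, 64),
                               input_image_size=(1024, 1024), mask_in_chans=16)
            sparse, dense = pe(points=None,
                               boxes=torch.tensor([[64.0, 64.0, 512.0, 512.0]]),
                               masks=None)
            # sparse: [1, 2, 256] (two box corners); dense: [1, 256, 64, 64]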
27 | 28 | Arguments: 29 | embed_dim (int): The prompts' embedding dimension 30 | image_embedding_size (tuple(int, int)): The spatial size of the 31 | image embedding, as (H, W). 32 | input_image_size (int): The padded size of the image as input 33 | to the image encoder, as (H, W). 34 | mask_in_chans (int): The number of hidden channels used for 35 | encoding input masks. 36 | activation (nn.Module): The activation to use when encoding 37 | input masks. 38 | """ 39 | super().__init__() 40 | self.embed_dim = embed_dim 41 | self.input_image_size = input_image_size 42 | self.image_embedding_size = image_embedding_size 43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 44 | 45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 46 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] 47 | self.point_embeddings = nn.ModuleList(point_embeddings) 48 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 49 | 50 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) 51 | self.mask_downscaling = nn.Sequential( 52 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 53 | LayerNorm2d(mask_in_chans // 4), 54 | activation(), 55 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 56 | LayerNorm2d(mask_in_chans), 57 | activation(), 58 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 59 | ) # downsample to 1/4 60 | self.no_mask_embed = nn.Embedding(1, embed_dim) 61 | 62 | def get_dense_pe(self) -> torch.Tensor: 63 | """ 64 | Returns the positional encoding used to encode point prompts, 65 | applied to a dense set of points the shape of the image encoding. 66 | 67 | Returns: 68 | torch.Tensor: Positional encoding with shape 69 | 1x(embed_dim)x(embedding_h)x(embedding_w) 70 | """ 71 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 72 | 73 | def _embed_points( 74 | self, 75 | points: torch.Tensor, 76 | labels: torch.Tensor, 77 | pad: bool, 78 | ) -> torch.Tensor: 79 | """Embeds point prompts.""" 80 | points = points + 0.5 # Shift to center of pixel 81 | if pad: 82 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 83 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 84 | points = torch.cat([points, padding_point], dim=1) 85 | labels = torch.cat([labels, padding_label], dim=1) 86 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) 87 | point_embedding[labels == -1] = 0.0 88 | point_embedding[labels == -1] += self.not_a_point_embed.weight 89 | point_embedding[labels == 0] += self.point_embeddings[0].weight 90 | point_embedding[labels == 1] += self.point_embeddings[1].weight 91 | return point_embedding 92 | 93 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 94 | """Embeds box prompts.""" 95 | boxes = boxes + 0.5 # Shift to center of pixel 96 | coords = boxes.reshape(-1, 2, 2) 97 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) 98 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 99 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 100 | return corner_embedding 101 | 102 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 103 | """Embeds mask inputs.""" 104 | mask_embedding = self.mask_downscaling(masks) 105 | return mask_embedding 106 | 107 | def _get_batch_size( 108 | self, 109 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 110 | boxes: Optional[torch.Tensor], 111 | masks: 
Optional[torch.Tensor], 112 | ) -> int: 113 | """ 114 | Gets the batch size of the output given the batch size of the input prompts. 115 | """ 116 | if points is not None: 117 | return points[0].shape[0] 118 | elif boxes is not None: 119 | return boxes.shape[0] 120 | elif masks is not None: 121 | return masks.shape[0] 122 | else: 123 | return 1 124 | 125 | def _get_device(self) -> torch.device: 126 | return self.point_embeddings[0].weight.device 127 | 128 | def forward( 129 | self, 130 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 131 | boxes: Optional[torch.Tensor], 132 | masks: Optional[torch.Tensor], 133 | ) -> Tuple[torch.Tensor, torch.Tensor]: 134 | """ 135 | Embeds different types of prompts, returning both sparse and dense 136 | embeddings. 137 | 138 | Arguments: 139 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 140 | and labels to embed. 141 | boxes (torch.Tensor or none): boxes to embed 142 | masks (torch.Tensor or none): masks to embed 143 | 144 | Returns: 145 | torch.Tensor: sparse embeddings for the points and boxes, with shape 146 | BxNx(embed_dim), where N is determined by the number of input points 147 | and boxes. 148 | torch.Tensor: dense embeddings for the masks, in the shape 149 | Bx(embed_dim)x(embed_H)x(embed_W) 150 | """ 151 | bs = self._get_batch_size(points, boxes, masks) 152 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) 153 | if points is not None: 154 | coords, labels = points 155 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 156 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 157 | if boxes is not None: 158 | box_embeddings = self._embed_boxes(boxes) 159 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 160 | 161 | if masks is not None: 162 | dense_embeddings = self._embed_masks(masks) 163 | else: 164 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 165 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 166 | ) 167 | 168 | return sparse_embeddings, dense_embeddings 169 | 170 | 171 | class PositionEmbeddingRandom(nn.Module): 172 | """ 173 | Positional encoding using random spatial frequencies. 174 | """ 175 | 176 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 177 | super().__init__() 178 | if scale is None or scale <= 0.0: 179 | scale = 1.0 180 | self.register_buffer( 181 | "positional_encoding_gaussian_matrix", 182 | scale * torch.randn((2, num_pos_feats)), 183 | ) 184 | 185 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 186 | """Positionally encode points that are normalized to [0,1].""" 187 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 188 | coords = 2 * coords - 1 189 | coords = coords @ self.positional_encoding_gaussian_matrix 190 | coords = 2 * np.pi * coords 191 | # outputs d_1 x ... 
x d_n x C shape 192 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 193 | 194 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 195 | """Generate positional encoding for a grid of the specified size.""" 196 | h, w = size 197 | device: Any = self.positional_encoding_gaussian_matrix.device 198 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 199 | y_embed = grid.cumsum(dim=0) - 0.5 200 | x_embed = grid.cumsum(dim=1) - 0.5 201 | y_embed = y_embed / h 202 | x_embed = x_embed / w 203 | 204 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 205 | return pe.permute(2, 0, 1) # C x H x W 206 | 207 | def forward_with_coords( 208 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 209 | ) -> torch.Tensor: 210 | """Positionally encode points that are not normalized to [0,1].""" 211 | coords = coords_input.clone() 212 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 213 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 214 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 215 | -------------------------------------------------------------------------------- /SAM/segment_anything/modeling/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | from icecream import ic 11 | 12 | from typing import Any, Dict, List, Tuple 13 | 14 | from .image_encoder import ImageEncoderViT 15 | from .mask_decoder import MaskDecoder 16 | from .prompt_encoder import PromptEncoder 17 | 18 | 19 | class Sam(nn.Module): 20 | mask_threshold: float = 0.0 21 | image_format: str = "RGB" 22 | 23 | def __init__( 24 | self, 25 | image_encoder: ImageEncoderViT, 26 | prompt_encoder: PromptEncoder, 27 | mask_decoder: MaskDecoder, 28 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 29 | pixel_std: List[float] = [58.395, 57.12, 57.375], 30 | ) -> None: 31 | """ 32 | SAM predicts object masks from an image and input prompts. 33 | 34 | Arguments: 35 | image_encoder (ImageEncoderViT): The backbone used to encode the 36 | image into image embeddings that allow for efficient mask prediction. 37 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 38 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 39 | and encoded prompts. 40 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 41 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 
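        Example (editor's illustrative sketch of the list-style input accepted by
        forward_test; the 720x1280 frame and its 576x1024 resized version are
        assumptions):
            batched_input = [{
                "image": torch.randn(3, 576, 1024),   # already resized, long side 1024
                "original_size": (720, 1280),
                "boxes": torch.tensor([[100.0, 100.0, 400.0, 400.0]]),
            }]
            outputs = sam(batched_input, multimask_output=True, image_size=1024)
            # outputs[0]["masks"]: boolean masks at the original 720x1280 resolution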
42 | """ 43 | super().__init__() 44 | self.image_encoder = image_encoder 45 | self.prompt_encoder = prompt_encoder 46 | self.mask_decoder = mask_decoder 47 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) 48 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 49 | 50 | @property 51 | def device(self) -> Any: 52 | return self.pixel_mean.device 53 | 54 | def forward(self, batched_input, multimask_output, image_size): 55 | if isinstance(batched_input, list): 56 | outputs = self.forward_test(batched_input, multimask_output) 57 | else: 58 | outputs = self.forward_train(batched_input, multimask_output, image_size) 59 | return outputs 60 | 61 | def forward_train(self, batched_input, multimask_output, image_size): 62 | input_images = self.preprocess(batched_input) 63 | image_embeddings = self.image_encoder(input_images) 64 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 65 | points=None, boxes=None, masks=None 66 | ) 67 | low_res_masks, iou_predictions = self.mask_decoder( 68 | image_embeddings=image_embeddings, 69 | image_pe=self.prompt_encoder.get_dense_pe(), 70 | sparse_prompt_embeddings=sparse_embeddings, 71 | dense_prompt_embeddings=dense_embeddings, 72 | multimask_output=multimask_output 73 | ) 74 | masks = self.postprocess_masks( 75 | low_res_masks, 76 | input_size=(image_size, image_size), 77 | original_size=(image_size, image_size) 78 | ) 79 | outputs = { 80 | 'masks': masks, 81 | 'iou_predictions': iou_predictions, 82 | 'low_res_logits': low_res_masks 83 | } 84 | return outputs 85 | 86 | @torch.no_grad() 87 | def forward_test( 88 | self, 89 | batched_input: List[Dict[str, Any]], 90 | multimask_output: bool, 91 | ) -> List[Dict[str, torch.Tensor]]: 92 | """ 93 | Predicts masks end-to-end from provided images and prompts. 94 | If prompts are not known in advance, using SamPredictor is 95 | recommended over calling the model directly. 96 | 97 | Arguments: 98 | batched_input (list(dict)): A list over input images, each a 99 | dictionary with the following keys. A prompt key can be 100 | excluded if it is not present. 101 | 'image': The image as a torch tensor in 3xHxW format, 102 | already transformed for input to the model. 103 | 'original_size': (tuple(int, int)) The original size of 104 | the image before transformation, as (H, W). 105 | 'point_coords': (torch.Tensor) Batched point prompts for 106 | this image, with shape BxNx2. Already transformed to the 107 | input frame of the model. 108 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 109 | with shape BxN. 110 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 111 | Already transformed to the input frame of the model. 112 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 113 | in the form Bx1xHxW. 114 | multimask_output (bool): Whether the model should predict multiple 115 | disambiguating masks, or return a single mask. 116 | 117 | Returns: 118 | (list(dict)): A list over input images, where each element is 119 | as dictionary with the following keys. 120 | 'masks': (torch.Tensor) Batched binary mask predictions, 121 | with shape BxCxHxW, where B is the number of input promts, 122 | C is determiend by multimask_output, and (H, W) is the 123 | original size of the image. 124 | 'iou_predictions': (torch.Tensor) The model's predictions 125 | of mask quality, in shape BxC. 126 | 'low_res_logits': (torch.Tensor) Low resolution logits with 127 | shape BxCxHxW, where H=W=256. 
Can be passed as mask input 128 | to subsequent iterations of prediction. 129 | """ 130 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) 131 | image_embeddings = self.image_encoder(input_images) 132 | 133 | outputs = [] 134 | for image_record, curr_embedding in zip(batched_input, image_embeddings): 135 | if "point_coords" in image_record: 136 | points = (image_record["point_coords"], image_record["point_labels"]) 137 | else: 138 | points = None 139 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 140 | points=points, 141 | boxes=image_record.get("boxes", None), 142 | masks=image_record.get("mask_inputs", None), 143 | ) 144 | low_res_masks, iou_predictions = self.mask_decoder( 145 | image_embeddings=curr_embedding.unsqueeze(0), 146 | image_pe=self.prompt_encoder.get_dense_pe(), 147 | sparse_prompt_embeddings=sparse_embeddings, 148 | dense_prompt_embeddings=dense_embeddings, 149 | multimask_output=multimask_output, 150 | ) 151 | masks = self.postprocess_masks( 152 | low_res_masks, 153 | input_size=image_record["image"].shape[-2:], 154 | original_size=image_record["original_size"], 155 | ) 156 | masks = masks > self.mask_threshold 157 | outputs.append( 158 | { 159 | "masks": masks, 160 | "iou_predictions": iou_predictions, 161 | "low_res_logits": low_res_masks, 162 | } 163 | ) 164 | return outputs 165 | 166 | def postprocess_masks( 167 | self, 168 | masks: torch.Tensor, 169 | input_size: Tuple[int, ...], 170 | original_size: Tuple[int, ...], 171 | ) -> torch.Tensor: 172 | """ 173 | Remove padding and upscale masks to the original image size. 174 | 175 | Arguments: 176 | masks (torch.Tensor): Batched masks from the mask_decoder, 177 | in BxCxHxW format. 178 | input_size (tuple(int, int)): The size of the image input to the 179 | model, in (H, W) format. Used to remove padding. 180 | original_size (tuple(int, int)): The original size of the image 181 | before resizing for input to the model, in (H, W) format. 182 | 183 | Returns: 184 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 185 | is given by original_size. 186 | """ 187 | masks = F.interpolate( 188 | masks, 189 | (self.image_encoder.img_size, self.image_encoder.img_size), 190 | mode="bilinear", 191 | align_corners=False, 192 | ) 193 | masks = masks[..., : input_size[0], : input_size[1]] 194 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 195 | return masks 196 | 197 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 198 | """Normalize pixel values and pad to a square input.""" 199 | # Normalize colors 200 | x = (x - self.pixel_mean) / self.pixel_std 201 | 202 | # Pad 203 | h, w = x.shape[-2:] 204 | padh = self.image_encoder.img_size - h 205 | padw = self.image_encoder.img_size - w 206 | x = F.pad(x, (0, padw, 0, padh)) 207 | return x 208 | 209 | -------------------------------------------------------------------------------- /SAM/segment_anything/modeling/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
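# --- Editor's note: illustrative usage sketch, not part of the upstream file. ---
# Shows the shapes handled by the downsampling Attention layer defined below; the
# 256-dim / 64x64-token sizes are assumptions matching SAM's mask decoder setup.
def _demo_downsampled_attention():
    import torch

    attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)  # internal dim 128
    q = torch.randn(1, 6, 256)           # B x N_tokens x C (e.g. IoU + mask + prompt tokens)
    k = v = torch.randn(1, 4096, 256)    # flattened 64x64 image embedding
    out = attn(q=q, k=k, v=v)
    assert out.shape == (1, 6, 256)      # projected back up to embedding_dim
# ---------------------------------------------------------------------------------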
6 | 7 | import torch 8 | from torch import Tensor, nn 9 | 10 | import math 11 | from typing import Tuple, Type 12 | 13 | from .common import MLPBlock 14 | 15 | 16 | class TwoWayTransformer(nn.Module): 17 | def __init__( 18 | self, 19 | depth: int, 20 | embedding_dim: int, 21 | num_heads: int, 22 | mlp_dim: int, 23 | activation: Type[nn.Module] = nn.ReLU, 24 | attention_downsample_rate: int = 2, 25 | ) -> None: 26 | """ 27 | A transformer decoder that attends to an input image using 28 | queries whose positional embedding is supplied. 29 | 30 | Args: 31 | depth (int): number of layers in the transformer 32 | embedding_dim (int): the channel dimension for the input embeddings 33 | num_heads (int): the number of heads for multihead attention. Must 34 | divide embedding_dim 35 | mlp_dim (int): the channel dimension internal to the MLP block 36 | activation (nn.Module): the activation to use in the MLP block 37 | """ 38 | super().__init__() 39 | self.depth = depth 40 | self.embedding_dim = embedding_dim 41 | self.num_heads = num_heads 42 | self.mlp_dim = mlp_dim 43 | self.layers = nn.ModuleList() 44 | 45 | for i in range(depth): 46 | self.layers.append( 47 | TwoWayAttentionBlock( 48 | embedding_dim=embedding_dim, 49 | num_heads=num_heads, 50 | mlp_dim=mlp_dim, 51 | activation=activation, 52 | attention_downsample_rate=attention_downsample_rate, 53 | skip_first_layer_pe=(i == 0), 54 | ) 55 | ) 56 | 57 | self.final_attn_token_to_image = Attention( 58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 59 | ) 60 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 61 | 62 | def forward( 63 | self, 64 | image_embedding: Tensor, 65 | image_pe: Tensor, 66 | point_embedding: Tensor, 67 | ) -> Tuple[Tensor, Tensor]: 68 | """ 69 | Args: 70 | image_embedding (torch.Tensor): image to attend to. Should be shape 71 | B x embedding_dim x h x w for any h and w. 72 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 73 | have the same shape as image_embedding. 74 | point_embedding (torch.Tensor): the embedding to add to the query points. 75 | Must have shape B x N_points x embedding_dim for any N_points. 
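        Example (editor's illustrative sketch, assuming embedding_dim=256, a 64x64
        image grid, and 6 prompt tokens):
            queries, keys = transformer(
                image_embedding=torch.randn(1, 256, 64, 64),
                image_pe=torch.randn(1, 256, 64, 64),
                point_embedding=torch.randn(1, 6, 256),
            )
            # queries: [1, 6, 256]; keys: [1, 4096, 256] (flattened image tokens)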
76 | 77 | Returns: 78 | torch.Tensor: the processed point_embedding 79 | torch.Tensor: the processed image_embedding 80 | """ 81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 82 | bs, c, h, w = image_embedding.shape 83 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 84 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 85 | 86 | # Prepare queries 87 | queries = point_embedding 88 | keys = image_embedding 89 | 90 | # Apply transformer blocks and final layernorm 91 | for layer in self.layers: 92 | queries, keys = layer( 93 | queries=queries, 94 | keys=keys, 95 | query_pe=point_embedding, 96 | key_pe=image_pe, 97 | ) 98 | 99 | # Apply the final attenion layer from the points to the image 100 | q = queries + point_embedding 101 | k = keys + image_pe 102 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 103 | queries = queries + attn_out 104 | queries = self.norm_final_attn(queries) 105 | 106 | return queries, keys 107 | 108 | 109 | class TwoWayAttentionBlock(nn.Module): 110 | def __init__( 111 | self, 112 | embedding_dim: int, 113 | num_heads: int, 114 | mlp_dim: int = 2048, 115 | activation: Type[nn.Module] = nn.ReLU, 116 | attention_downsample_rate: int = 2, 117 | skip_first_layer_pe: bool = False, 118 | ) -> None: 119 | """ 120 | A transformer block with four layers: (1) self-attention of sparse 121 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 122 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 123 | inputs. 124 | 125 | Arguments: 126 | embedding_dim (int): the channel dimension of the embeddings 127 | num_heads (int): the number of heads in the attention layers 128 | mlp_dim (int): the hidden dimension of the mlp block 129 | activation (nn.Module): the activation of the mlp block 130 | skip_first_layer_pe (bool): skip the PE on the first layer 131 | """ 132 | super().__init__() 133 | self.self_attn = Attention(embedding_dim, num_heads) 134 | self.norm1 = nn.LayerNorm(embedding_dim) 135 | 136 | self.cross_attn_token_to_image = Attention( 137 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 138 | ) 139 | self.norm2 = nn.LayerNorm(embedding_dim) 140 | 141 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) 142 | self.norm3 = nn.LayerNorm(embedding_dim) 143 | 144 | self.norm4 = nn.LayerNorm(embedding_dim) 145 | self.cross_attn_image_to_token = Attention( 146 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 147 | ) 148 | 149 | self.skip_first_layer_pe = skip_first_layer_pe 150 | 151 | def forward( 152 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 153 | ) -> Tuple[Tensor, Tensor]: 154 | # Self attention block 155 | if self.skip_first_layer_pe: 156 | queries = self.self_attn(q=queries, k=queries, v=queries) 157 | else: 158 | q = queries + query_pe 159 | attn_out = self.self_attn(q=q, k=q, v=queries) 160 | queries = queries + attn_out 161 | queries = self.norm1(queries) 162 | 163 | # Cross attention block, tokens attending to image embedding 164 | q = queries + query_pe 165 | k = keys + key_pe 166 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 167 | queries = queries + attn_out 168 | queries = self.norm2(queries) 169 | 170 | # MLP block 171 | mlp_out = self.mlp(queries) 172 | queries = queries + mlp_out 173 | queries = self.norm3(queries) 174 | 175 | # Cross attention block, image embedding attending to tokens 176 | q = queries + query_pe 177 | k = keys + key_pe 178 | attn_out = self.cross_attn_image_to_token(q=k, k=q, 
v=queries) 179 | keys = keys + attn_out 180 | keys = self.norm4(keys) 181 | 182 | return queries, keys 183 | 184 | 185 | class Attention(nn.Module): 186 | """ 187 | An attention layer that allows for downscaling the size of the embedding 188 | after projection to queries, keys, and values. 189 | """ 190 | 191 | def __init__( 192 | self, 193 | embedding_dim: int, 194 | num_heads: int, 195 | downsample_rate: int = 1, 196 | ) -> None: 197 | super().__init__() 198 | self.embedding_dim = embedding_dim 199 | self.internal_dim = embedding_dim // downsample_rate 200 | self.num_heads = num_heads 201 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 202 | 203 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 204 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 205 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 206 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 207 | 208 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 209 | b, n, c = x.shape 210 | x = x.reshape(b, n, num_heads, c // num_heads) 211 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 212 | 213 | def _recombine_heads(self, x: Tensor) -> Tensor: 214 | b, n_heads, n_tokens, c_per_head = x.shape 215 | x = x.transpose(1, 2) 216 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 217 | 218 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 219 | # Input projections 220 | q = self.q_proj(q) 221 | k = self.k_proj(k) 222 | v = self.v_proj(v) 223 | 224 | # Separate into heads 225 | q = self._separate_heads(q, self.num_heads) 226 | k = self._separate_heads(k, self.num_heads) 227 | v = self._separate_heads(v, self.num_heads) 228 | 229 | # Attention 230 | _, _, _, c_per_head = q.shape 231 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 232 | attn = attn / math.sqrt(c_per_head) 233 | attn = torch.softmax(attn, dim=-1) 234 | 235 | # Get output 236 | out = attn @ v 237 | out = self._recombine_heads(out) 238 | out = self.out_proj(out) 239 | 240 | return out 241 | -------------------------------------------------------------------------------- /SAM/segment_anything/predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from segment_anything.modeling import Sam 11 | 12 | from typing import Optional, Tuple 13 | 14 | from .utils.transforms import ResizeLongestSide 15 | 16 | 17 | class SamPredictor: 18 | def __init__( 19 | self, 20 | sam_model: Sam, 21 | ) -> None: 22 | """ 23 | Uses SAM to calculate the image embedding for an image, and then 24 | allow repeated, efficient mask prediction given prompts. 25 | 26 | Arguments: 27 | sam_model (Sam): The model to use for mask prediction. 28 | """ 29 | super().__init__() 30 | self.model = sam_model 31 | self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) 32 | self.reset_image() 33 | 34 | def set_image( 35 | self, 36 | image: np.ndarray, 37 | image_format: str = "RGB", 38 | ) -> None: 39 | """ 40 | Calculates the image embeddings for the provided image, allowing 41 | masks to be predicted with the 'predict' method. 42 | 43 | Arguments: 44 | image (np.ndarray): The image for calculating masks. 
Expects an 45 | image in HWC uint8 format, with pixel values in [0, 255]. 46 | image_format (str): The color format of the image, in ['RGB', 'BGR']. 47 | """ 48 | assert image_format in [ 49 | "RGB", 50 | "BGR", 51 | ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." 52 | if image_format != self.model.image_format: 53 | image = image[..., ::-1] 54 | 55 | # Transform the image to the form expected by the model 56 | input_image = self.transform.apply_image(image) 57 | input_image_torch = torch.as_tensor(input_image, device=self.device) 58 | input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] 59 | 60 | self.set_torch_image(input_image_torch, image.shape[:2]) 61 | 62 | @torch.no_grad() 63 | def set_torch_image( 64 | self, 65 | transformed_image: torch.Tensor, 66 | original_image_size: Tuple[int, ...], 67 | ) -> None: 68 | """ 69 | Calculates the image embeddings for the provided image, allowing 70 | masks to be predicted with the 'predict' method. Expects the input 71 | image to be already transformed to the format expected by the model. 72 | 73 | Arguments: 74 | transformed_image (torch.Tensor): The input image, with shape 75 | 1x3xHxW, which has been transformed with ResizeLongestSide. 76 | original_image_size (tuple(int, int)): The size of the image 77 | before transformation, in (H, W) format. 78 | """ 79 | assert ( 80 | len(transformed_image.shape) == 4 81 | and transformed_image.shape[1] == 3 82 | and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size 83 | ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." 84 | self.reset_image() 85 | 86 | self.original_size = original_image_size 87 | self.input_size = tuple(transformed_image.shape[-2:]) 88 | input_image = self.model.preprocess(transformed_image) 89 | self.features = self.model.image_encoder(input_image) 90 | self.is_image_set = True 91 | 92 | def predict( 93 | self, 94 | point_coords: Optional[np.ndarray] = None, 95 | point_labels: Optional[np.ndarray] = None, 96 | box: Optional[np.ndarray] = None, 97 | mask_input: Optional[np.ndarray] = None, 98 | multimask_output: bool = True, 99 | return_logits: bool = False, 100 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 101 | """ 102 | Predict masks for the given input prompts, using the currently set image. 103 | 104 | Arguments: 105 | point_coords (np.ndarray or None): A Nx2 array of point prompts to the 106 | model. Each point is in (X,Y) in pixels. 107 | point_labels (np.ndarray or None): A length N array of labels for the 108 | point prompts. 1 indicates a foreground point and 0 indicates a 109 | background point. 110 | box (np.ndarray or None): A length 4 array given a box prompt to the 111 | model, in XYXY format. 112 | mask_input (np.ndarray): A low resolution mask input to the model, typically 113 | coming from a previous prediction iteration. Has form 1xHxW, where 114 | for SAM, H=W=256. 115 | multimask_output (bool): If true, the model will return three masks. 116 | For ambiguous input prompts (such as a single click), this will often 117 | produce better masks than a single prediction. If only a single 118 | mask is needed, the model's predicted quality score can be used 119 | to select the best mask. For non-ambiguous prompts, such as multiple 120 | input prompts, multimask_output=False can give better results. 121 | return_logits (bool): If true, returns un-thresholded masks logits 122 | instead of a binary mask. 
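        Example (editor's illustrative sketch; `sam_model`, `image_rgb` and the click
        location are placeholders):
            predictor = SamPredictor(sam_model)
            predictor.set_image(image_rgb)                 # HWC uint8, RGB
            masks, scores, low_res = predictor.predict(
                point_coords=np.array([[320, 240]]),       # one foreground click, (X, Y)
                point_labels=np.array([1]),
                multimask_output=True,
            )
            # masks: CxHxW boolean array at the original image size; scores: (C,)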
123 | 124 | Returns: 125 | (np.ndarray): The output masks in CxHxW format, where C is the 126 | number of masks, and (H, W) is the original image size. 127 | (np.ndarray): An array of length C containing the model's 128 | predictions for the quality of each mask. 129 | (np.ndarray): An array of shape CxHxW, where C is the number 130 | of masks and H=W=256. These low resolution logits can be passed to 131 | a subsequent iteration as mask input. 132 | """ 133 | if not self.is_image_set: 134 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 135 | 136 | # Transform input prompts 137 | coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None 138 | if point_coords is not None: 139 | assert ( 140 | point_labels is not None 141 | ), "point_labels must be supplied if point_coords is supplied." 142 | point_coords = self.transform.apply_coords(point_coords, self.original_size) 143 | coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) 144 | labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) 145 | coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] 146 | if box is not None: 147 | box = self.transform.apply_boxes(box, self.original_size) 148 | box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) 149 | box_torch = box_torch[None, :] 150 | if mask_input is not None: 151 | mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) 152 | mask_input_torch = mask_input_torch[None, :, :, :] 153 | 154 | masks, iou_predictions, low_res_masks = self.predict_torch( 155 | coords_torch, 156 | labels_torch, 157 | box_torch, 158 | mask_input_torch, 159 | multimask_output, 160 | return_logits=return_logits, 161 | ) 162 | 163 | masks = masks[0].detach().cpu().numpy() 164 | iou_predictions = iou_predictions[0].detach().cpu().numpy() 165 | low_res_masks = low_res_masks[0].detach().cpu().numpy() 166 | return masks, iou_predictions, low_res_masks 167 | 168 | @torch.no_grad() 169 | def predict_torch( 170 | self, 171 | point_coords: Optional[torch.Tensor], 172 | point_labels: Optional[torch.Tensor], 173 | boxes: Optional[torch.Tensor] = None, 174 | mask_input: Optional[torch.Tensor] = None, 175 | multimask_output: bool = True, 176 | return_logits: bool = False, 177 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 178 | """ 179 | Predict masks for the given input prompts, using the currently set image. 180 | Input prompts are batched torch tensors and are expected to already be 181 | transformed to the input frame using ResizeLongestSide. 182 | 183 | Arguments: 184 | point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the 185 | model. Each point is in (X,Y) in pixels. 186 | point_labels (torch.Tensor or None): A BxN array of labels for the 187 | point prompts. 1 indicates a foreground point and 0 indicates a 188 | background point. 189 | box (np.ndarray or None): A Bx4 array given a box prompt to the 190 | model, in XYXY format. 191 | mask_input (np.ndarray): A low resolution mask input to the model, typically 192 | coming from a previous prediction iteration. Has form Bx1xHxW, where 193 | for SAM, H=W=256. Masks returned by a previous iteration of the 194 | predict method do not need further transformation. 195 | multimask_output (bool): If true, the model will return three masks. 
196 | For ambiguous input prompts (such as a single click), this will often 197 | produce better masks than a single prediction. If only a single 198 | mask is needed, the model's predicted quality score can be used 199 | to select the best mask. For non-ambiguous prompts, such as multiple 200 | input prompts, multimask_output=False can give better results. 201 | return_logits (bool): If true, returns un-thresholded masks logits 202 | instead of a binary mask. 203 | 204 | Returns: 205 | (torch.Tensor): The output masks in BxCxHxW format, where C is the 206 | number of masks, and (H, W) is the original image size. 207 | (torch.Tensor): An array of shape BxC containing the model's 208 | predictions for the quality of each mask. 209 | (torch.Tensor): An array of shape BxCxHxW, where C is the number 210 | of masks and H=W=256. These low res logits can be passed to 211 | a subsequent iteration as mask input. 212 | """ 213 | if not self.is_image_set: 214 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 215 | 216 | if point_coords is not None: 217 | points = (point_coords, point_labels) 218 | else: 219 | points = None 220 | 221 | # Embed prompts 222 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder( 223 | points=points, 224 | boxes=boxes, 225 | masks=mask_input, 226 | ) 227 | 228 | # Predict masks 229 | low_res_masks, iou_predictions = self.model.mask_decoder( 230 | image_embeddings=self.features, 231 | image_pe=self.model.prompt_encoder.get_dense_pe(), 232 | sparse_prompt_embeddings=sparse_embeddings, 233 | dense_prompt_embeddings=dense_embeddings, 234 | multimask_output=multimask_output, 235 | ) 236 | 237 | # Upscale the masks to the original image resolution 238 | masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) 239 | 240 | if not return_logits: 241 | masks = masks > self.model.mask_threshold 242 | 243 | return masks, iou_predictions, low_res_masks 244 | 245 | def get_image_embedding(self) -> torch.Tensor: 246 | """ 247 | Returns the image embeddings for the currently set image, with 248 | shape 1xCxHxW, where C is the embedding dimension and (H,W) are 249 | the embedding spatial dimension of SAM (typically C=256, H=W=64). 250 | """ 251 | if not self.is_image_set: 252 | raise RuntimeError( 253 | "An image must be set with .set_image(...) to generate an embedding." 254 | ) 255 | assert self.features is not None, "Features must exist if an image has been set." 256 | return self.features 257 | 258 | @property 259 | def device(self) -> torch.device: 260 | return self.model.device 261 | 262 | def reset_image(self) -> None: 263 | """Resets the currently set image.""" 264 | self.is_image_set = False 265 | self.features = None 266 | self.orig_h = None 267 | self.orig_w = None 268 | self.input_h = None 269 | self.input_w = None 270 | -------------------------------------------------------------------------------- /SAM/segment_anything/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /SAM/segment_anything/utils/amg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | 10 | import math 11 | from copy import deepcopy 12 | from itertools import product 13 | from typing import Any, Dict, Generator, ItemsView, List, Tuple 14 | 15 | 16 | class MaskData: 17 | """ 18 | A structure for storing masks and their related data in batched format. 19 | Implements basic filtering and concatenation. 20 | """ 21 | 22 | def __init__(self, **kwargs) -> None: 23 | for v in kwargs.values(): 24 | assert isinstance( 25 | v, (list, np.ndarray, torch.Tensor) 26 | ), "MaskData only supports list, numpy arrays, and torch tensors." 27 | self._stats = dict(**kwargs) 28 | 29 | def __setitem__(self, key: str, item: Any) -> None: 30 | assert isinstance( 31 | item, (list, np.ndarray, torch.Tensor) 32 | ), "MaskData only supports list, numpy arrays, and torch tensors." 33 | self._stats[key] = item 34 | 35 | def __delitem__(self, key: str) -> None: 36 | del self._stats[key] 37 | 38 | def __getitem__(self, key: str) -> Any: 39 | return self._stats[key] 40 | 41 | def items(self) -> ItemsView[str, Any]: 42 | return self._stats.items() 43 | 44 | def filter(self, keep: torch.Tensor) -> None: 45 | for k, v in self._stats.items(): 46 | if v is None: 47 | self._stats[k] = None 48 | elif isinstance(v, torch.Tensor): 49 | self._stats[k] = v[torch.as_tensor(keep, device=v.device)] 50 | elif isinstance(v, np.ndarray): 51 | self._stats[k] = v[keep.detach().cpu().numpy()] 52 | elif isinstance(v, list) and keep.dtype == torch.bool: 53 | self._stats[k] = [a for i, a in enumerate(v) if keep[i]] 54 | elif isinstance(v, list): 55 | self._stats[k] = [v[i] for i in keep] 56 | else: 57 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 58 | 59 | def cat(self, new_stats: "MaskData") -> None: 60 | for k, v in new_stats.items(): 61 | if k not in self._stats or self._stats[k] is None: 62 | self._stats[k] = deepcopy(v) 63 | elif isinstance(v, torch.Tensor): 64 | self._stats[k] = torch.cat([self._stats[k], v], dim=0) 65 | elif isinstance(v, np.ndarray): 66 | self._stats[k] = np.concatenate([self._stats[k], v], axis=0) 67 | elif isinstance(v, list): 68 | self._stats[k] = self._stats[k] + deepcopy(v) 69 | else: 70 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 71 | 72 | def to_numpy(self) -> None: 73 | for k, v in self._stats.items(): 74 | if isinstance(v, torch.Tensor): 75 | self._stats[k] = v.detach().cpu().numpy() 76 | 77 | 78 | def is_box_near_crop_edge( 79 | boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 80 | ) -> torch.Tensor: 81 | """Filter masks at the edge of a crop, but not at the edge of the original image.""" 82 | crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) 83 | orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) 84 | boxes = uncrop_boxes_xyxy(boxes, crop_box).float() 85 | near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) 86 | near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) 87 | near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) 88 | return torch.any(near_crop_edge, dim=1) 89 | 90 | 91 | def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: 92 | box_xywh = deepcopy(box_xyxy) 93 | box_xywh[2] = box_xywh[2] - box_xywh[0] 94 | box_xywh[3] = box_xywh[3] - 
box_xywh[1] 95 | return box_xywh 96 | 97 | 98 | def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: 99 | assert len(args) > 0 and all( 100 | len(a) == len(args[0]) for a in args 101 | ), "Batched iteration must have inputs of all the same size." 102 | n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) 103 | for b in range(n_batches): 104 | yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] 105 | 106 | 107 | def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: 108 | """ 109 | Encodes masks to an uncompressed RLE, in the format expected by 110 | pycoco tools. 111 | """ 112 | # Put in fortran order and flatten h,w 113 | b, h, w = tensor.shape 114 | tensor = tensor.permute(0, 2, 1).flatten(1) 115 | 116 | # Compute change indices 117 | diff = tensor[:, 1:] ^ tensor[:, :-1] 118 | change_indices = diff.nonzero() 119 | 120 | # Encode run length 121 | out = [] 122 | for i in range(b): 123 | cur_idxs = change_indices[change_indices[:, 0] == i, 1] 124 | cur_idxs = torch.cat( 125 | [ 126 | torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), 127 | cur_idxs + 1, 128 | torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), 129 | ] 130 | ) 131 | btw_idxs = cur_idxs[1:] - cur_idxs[:-1] 132 | counts = [] if tensor[i, 0] == 0 else [0] 133 | counts.extend(btw_idxs.detach().cpu().tolist()) 134 | out.append({"size": [h, w], "counts": counts}) 135 | return out 136 | 137 | 138 | def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: 139 | """Compute a binary mask from an uncompressed RLE.""" 140 | h, w = rle["size"] 141 | mask = np.empty(h * w, dtype=bool) 142 | idx = 0 143 | parity = False 144 | for count in rle["counts"]: 145 | mask[idx : idx + count] = parity 146 | idx += count 147 | parity ^= True 148 | mask = mask.reshape(w, h) 149 | return mask.transpose() # Put in C order 150 | 151 | 152 | def area_from_rle(rle: Dict[str, Any]) -> int: 153 | return sum(rle["counts"][1::2]) 154 | 155 | 156 | def calculate_stability_score( 157 | masks: torch.Tensor, mask_threshold: float, threshold_offset: float 158 | ) -> torch.Tensor: 159 | """ 160 | Computes the stability score for a batch of masks. The stability 161 | score is the IoU between the binary masks obtained by thresholding 162 | the predicted mask logits at high and low values. 163 | """ 164 | # One mask is always contained inside the other. 
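    # In other words, the value computed below is
    #     stability = |logits > mask_threshold + threshold_offset| / |logits > mask_threshold - threshold_offset|,
    # which is exactly the IoU of the two binarizations (the high-threshold mask
    # is contained in the low-threshold one). Scores near 1.0 mean the mask is
    # insensitive to perturbing the threshold by +/- threshold_offset.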
165 | # Save memory by preventing unnecesary cast to torch.int64 166 | intersections = ( 167 | (masks > (mask_threshold + threshold_offset)) 168 | .sum(-1, dtype=torch.int16) 169 | .sum(-1, dtype=torch.int32) 170 | ) 171 | unions = ( 172 | (masks > (mask_threshold - threshold_offset)) 173 | .sum(-1, dtype=torch.int16) 174 | .sum(-1, dtype=torch.int32) 175 | ) 176 | return intersections / unions 177 | 178 | 179 | def build_point_grid(n_per_side: int) -> np.ndarray: 180 | """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" 181 | offset = 1 / (2 * n_per_side) 182 | points_one_side = np.linspace(offset, 1 - offset, n_per_side) 183 | points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) 184 | points_y = np.tile(points_one_side[:, None], (1, n_per_side)) 185 | points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) 186 | return points 187 | 188 | 189 | def build_all_layer_point_grids( 190 | n_per_side: int, n_layers: int, scale_per_layer: int 191 | ) -> List[np.ndarray]: 192 | """Generates point grids for all crop layers.""" 193 | points_by_layer = [] 194 | for i in range(n_layers + 1): 195 | n_points = int(n_per_side / (scale_per_layer**i)) 196 | points_by_layer.append(build_point_grid(n_points)) 197 | return points_by_layer 198 | 199 | 200 | def generate_crop_boxes( 201 | im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float 202 | ) -> Tuple[List[List[int]], List[int]]: 203 | """ 204 | Generates a list of crop boxes of different sizes. Each layer 205 | has (2**i)**2 boxes for the ith layer. 206 | """ 207 | crop_boxes, layer_idxs = [], [] 208 | im_h, im_w = im_size 209 | short_side = min(im_h, im_w) 210 | 211 | # Original image 212 | crop_boxes.append([0, 0, im_w, im_h]) 213 | layer_idxs.append(0) 214 | 215 | def crop_len(orig_len, n_crops, overlap): 216 | return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) 217 | 218 | for i_layer in range(n_layers): 219 | n_crops_per_side = 2 ** (i_layer + 1) 220 | overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) 221 | 222 | crop_w = crop_len(im_w, n_crops_per_side, overlap) 223 | crop_h = crop_len(im_h, n_crops_per_side, overlap) 224 | 225 | crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] 226 | crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] 227 | 228 | # Crops in XYWH format 229 | for x0, y0 in product(crop_box_x0, crop_box_y0): 230 | box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] 231 | crop_boxes.append(box) 232 | layer_idxs.append(i_layer + 1) 233 | 234 | return crop_boxes, layer_idxs 235 | 236 | 237 | def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 238 | x0, y0, _, _ = crop_box 239 | offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) 240 | # Check if boxes has a channel dimension 241 | if len(boxes.shape) == 3: 242 | offset = offset.unsqueeze(1) 243 | return boxes + offset 244 | 245 | 246 | def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 247 | x0, y0, _, _ = crop_box 248 | offset = torch.tensor([[x0, y0]], device=points.device) 249 | # Check if points has a channel dimension 250 | if len(points.shape) == 3: 251 | offset = offset.unsqueeze(1) 252 | return points + offset 253 | 254 | 255 | def uncrop_masks( 256 | masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int 257 | ) -> torch.Tensor: 258 | x0, y0, x1, y1 = crop_box 259 | if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: 260 | return masks 261 | # Coordinate 
transform masks 262 | pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) 263 | pad = (x0, pad_x - x0, y0, pad_y - y0) 264 | return torch.nn.functional.pad(masks, pad, value=0) 265 | 266 | 267 | def remove_small_regions( 268 | mask: np.ndarray, area_thresh: float, mode: str 269 | ) -> Tuple[np.ndarray, bool]: 270 | """ 271 | Removes small disconnected regions and holes in a mask. Returns the 272 | mask and an indicator of if the mask has been modified. 273 | """ 274 | import cv2 # type: ignore 275 | 276 | assert mode in ["holes", "islands"] 277 | correct_holes = mode == "holes" 278 | working_mask = (correct_holes ^ mask).astype(np.uint8) 279 | n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) 280 | sizes = stats[:, -1][1:] # Row 0 is background label 281 | small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] 282 | if len(small_regions) == 0: 283 | return mask, False 284 | fill_labels = [0] + small_regions 285 | if not correct_holes: 286 | fill_labels = [i for i in range(n_labels) if i not in fill_labels] 287 | # If every region is below threshold, keep largest 288 | if len(fill_labels) == 0: 289 | fill_labels = [int(np.argmax(sizes)) + 1] 290 | mask = np.isin(regions, fill_labels) 291 | return mask, True 292 | 293 | 294 | def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: 295 | from pycocotools import mask as mask_utils # type: ignore 296 | 297 | h, w = uncompressed_rle["size"] 298 | rle = mask_utils.frPyObjects(uncompressed_rle, h, w) 299 | rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json 300 | return rle 301 | 302 | 303 | def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: 304 | """ 305 | Calculates boxes in XYXY format around masks. Return [0,0,0,0] for 306 | an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. 307 | """ 308 | # torch.max below raises an error on empty inputs, just skip in this case 309 | if torch.numel(masks) == 0: 310 | return torch.zeros(*masks.shape[:-2], 4, device=masks.device) 311 | 312 | # Normalize shape to CxHxW 313 | shape = masks.shape 314 | h, w = shape[-2:] 315 | if len(shape) > 2: 316 | masks = masks.flatten(0, -3) 317 | else: 318 | masks = masks.unsqueeze(0) 319 | 320 | # Get top and bottom edges 321 | in_height, _ = torch.max(masks, dim=-1) 322 | in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] 323 | bottom_edges, _ = torch.max(in_height_coords, dim=-1) 324 | in_height_coords = in_height_coords + h * (~in_height) 325 | top_edges, _ = torch.min(in_height_coords, dim=-1) 326 | 327 | # Get left and right edges 328 | in_width, _ = torch.max(masks, dim=-2) 329 | in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] 330 | right_edges, _ = torch.max(in_width_coords, dim=-1) 331 | in_width_coords = in_width_coords + w * (~in_width) 332 | left_edges, _ = torch.min(in_width_coords, dim=-1) 333 | 334 | # If the mask is empty the right edge will be to the left of the left edge. 
335 | # Replace these boxes with [0, 0, 0, 0] 336 | empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) 337 | out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) 338 | out = out * (~empty_filter).unsqueeze(-1) 339 | 340 | # Return to original shape 341 | if len(shape) > 2: 342 | out = out.reshape(*shape[:-2], 4) 343 | else: 344 | out = out[0] 345 | 346 | return out 347 | -------------------------------------------------------------------------------- /SAM/segment_anything/utils/onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Tuple 12 | 13 | from ..modeling import Sam 14 | from .amg import calculate_stability_score 15 | 16 | 17 | class SamOnnxModel(nn.Module): 18 | """ 19 | This model should not be called directly, but is used in ONNX export. 20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, 21 | with some functions modified to enable model tracing. Also supports extra 22 | options controlling what information. See the ONNX export script for details. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | model: Sam, 28 | return_single_mask: bool, 29 | use_stability_score: bool = False, 30 | return_extra_metrics: bool = False, 31 | ) -> None: 32 | super().__init__() 33 | self.mask_decoder = model.mask_decoder 34 | self.model = model 35 | self.img_size = model.image_encoder.img_size 36 | self.return_single_mask = return_single_mask 37 | self.use_stability_score = use_stability_score 38 | self.stability_score_offset = 1.0 39 | self.return_extra_metrics = return_extra_metrics 40 | 41 | @staticmethod 42 | def resize_longest_image_size( 43 | input_image_size: torch.Tensor, longest_side: int 44 | ) -> torch.Tensor: 45 | input_image_size = input_image_size.to(torch.float32) 46 | scale = longest_side / torch.max(input_image_size) 47 | transformed_size = scale * input_image_size 48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) 49 | return transformed_size 50 | 51 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: 52 | point_coords = point_coords + 0.5 53 | point_coords = point_coords / self.img_size 54 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) 55 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) 56 | 57 | point_embedding = point_embedding * (point_labels != -1) 58 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( 59 | point_labels == -1 60 | ) 61 | 62 | for i in range(self.model.prompt_encoder.num_point_embeddings): 63 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ 64 | i 65 | ].weight * (point_labels == i) 66 | 67 | return point_embedding 68 | 69 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: 70 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) 71 | mask_embedding = mask_embedding + ( 72 | 1 - has_mask_input 73 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) 74 | return mask_embedding 75 | 76 | def mask_postprocessing(self, masks: 
torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: 77 | masks = F.interpolate( 78 | masks, 79 | size=(self.img_size, self.img_size), 80 | mode="bilinear", 81 | align_corners=False, 82 | ) 83 | 84 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size) 85 | masks = masks[..., : int(prepadded_size[0]), : int(prepadded_size[1])] 86 | 87 | orig_im_size = orig_im_size.to(torch.int64) 88 | h, w = orig_im_size[0], orig_im_size[1] 89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 90 | return masks 91 | 92 | def select_masks( 93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int 94 | ) -> Tuple[torch.Tensor, torch.Tensor]: 95 | # Determine if we should return the multiclick mask or not from the number of points. 96 | # The reweighting is used to avoid control flow. 97 | score_reweight = torch.tensor( 98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] 99 | ).to(iou_preds.device) 100 | score = iou_preds + (num_points - 2.5) * score_reweight 101 | best_idx = torch.argmax(score, dim=1) 102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) 103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) 104 | 105 | return masks, iou_preds 106 | 107 | @torch.no_grad() 108 | def forward( 109 | self, 110 | image_embeddings: torch.Tensor, 111 | point_coords: torch.Tensor, 112 | point_labels: torch.Tensor, 113 | mask_input: torch.Tensor, 114 | has_mask_input: torch.Tensor, 115 | orig_im_size: torch.Tensor, 116 | ): 117 | sparse_embedding = self._embed_points(point_coords, point_labels) 118 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 119 | 120 | masks, scores = self.model.mask_decoder.predict_masks( 121 | image_embeddings=image_embeddings, 122 | image_pe=self.model.prompt_encoder.get_dense_pe(), 123 | sparse_prompt_embeddings=sparse_embedding, 124 | dense_prompt_embeddings=dense_embedding, 125 | ) 126 | 127 | if self.use_stability_score: 128 | scores = calculate_stability_score( 129 | masks, self.model.mask_threshold, self.stability_score_offset 130 | ) 131 | 132 | if self.return_single_mask: 133 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) 134 | 135 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size) 136 | 137 | if self.return_extra_metrics: 138 | stability_scores = calculate_stability_score( 139 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset 140 | ) 141 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 142 | return upscaled_masks, scores, stability_scores, areas, masks 143 | 144 | return upscaled_masks, scores, masks 145 | -------------------------------------------------------------------------------- /SAM/segment_anything/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. 
Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length. 
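        For example, oldh=480, oldw=640, long_side_length=1024 gives scale=1.6
        and returns (newh, neww) = (768, 1024).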
97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | -------------------------------------------------------------------------------- /SAM/test.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import os 3 | import pickle 4 | import sys 5 | from tqdm import tqdm 6 | import logging 7 | import argparse 8 | import random 9 | import numpy as np 10 | import torch 11 | from torch.utils.data import DataLoader 12 | import torch.backends.cudnn as cudnn 13 | from utils import test_single_volume 14 | 15 | from segment_anything import sam_model_registry 16 | from datasets.dataset_cancer import Cancer_dataset 17 | from sam_lora_image_encoder import LoRA_Sam 18 | 19 | 20 | def inference(args, multimask_output, model, test_save_path=None): 21 | # db_test = db_config['Dataset'](base_dir=args.volume_path, split='val') 22 | db_test = Cancer_dataset(data_dir=args.data_dir, txt_dir=args.txt_dir) 23 | testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=8) 24 | logging.info(f'{len(testloader)} test iterations per epoch') 25 | model.eval() 26 | metric_benign = 0.0 27 | metric_tumor = 0.0 28 | metric_list = 0.0 29 | n_b, n_t = 0, 0 30 | results = {} 31 | for i_batch, sampled_batch in tqdm(enumerate(testloader)): 32 | h, w = sampled_batch['image'].shape[2:] 33 | image, label, case_name = sampled_batch['image'], sampled_batch['label'], sampled_batch['case_name'][0] 34 | metric_i = test_single_volume(image, label, model, classes=args.num_classes, multimask_output=multimask_output, 35 | patch_size=[args.img_size, args.img_size], input_size=[args.input_size, args.input_size], 36 | test_save_path=test_save_path, case=case_name, results=results) 37 | if "benign" in case_name: 38 | metric_benign += np.array(metric_i) 39 | n_b += 1 40 | else: 41 | metric_tumor += np.array(metric_i) 42 | n_t += 1 43 | metric_list += np.array(metric_i) 44 | logging.info('idx %d case %s mean_dice %f' % ( 45 | i_batch, case_name, metric_i[0])) 46 | metric_list = metric_list / len(db_test) 47 | metric_benign = metric_benign / n_b 48 | metric_tumor = metric_tumor / n_t 49 | 50 | logging.info('benign mean_dice %f' % (metric_benign)) 51 | logging.info('tumor mean_dice %f' % (metric_tumor)) 52 | logging.info('Testing performance in best val model: mean_dice : %f' % (metric_list)) 53 | logging.info("Testing Finished!") 54 | 55 | test_result = os.path.join(test_save_path, "test.gz") 56 | with gzip.open(test_result, "wb") as f: 57 | pickle.dump(results, f) 58 | logging.info("Saving results at %s" % (test_result)) 59 | 60 | return 1 61 | 62 | 63 | def config_to_dict(config): 64 | items_dict = {} 65 | with open(config, 'r') as f: 66 | items = f.readlines() 67 | for i in range(len(items)): 68 | key, value = items[i].strip().split(': ') 69 | items_dict[key] = value 70 | return items_dict 71 | 72 | 73 | if __name__ == '__main__': 74 | parser = argparse.ArgumentParser() 75 | parser.add_argument('--config', type=str, default=None, help='The config file provided by the trained model') 76 | parser.add_argument('--data_dir', type=str, 77 | default='../datasets/segment', help='root dir for data') 78 | parser.add_argument('--output_dir', type=str, default='./exp/output') 79 | parser.add_argument('--txt_dir', type=str, 80 | default='./datasets/test.txt', help='list dir') 81 | parser.add_argument('--num_classes', type=int, default=1) 82 | 
parser.add_argument('--img_size', type=int, default=224, help='Input image size of the network') 83 | parser.add_argument('--input_size', type=int, default=224, help='The input size for training SAM model') 84 | parser.add_argument('--seed', type=int, 85 | default=1234, help='random seed') 86 | parser.add_argument('--is_savenii', action='store_false', help='Whether to save results during inference') 87 | parser.add_argument('--deterministic', type=int, default=1, help='whether use deterministic training') 88 | parser.add_argument('--ckpt', type=str, default='./checkpoints/sam_vit_b_01ec64.pth', 89 | help='Pretrained checkpoint') 90 | parser.add_argument('--lora_ckpt', type=str, 91 | default='./exp/epoch_199.pth', help='The checkpoint from LoRA') 92 | parser.add_argument('--vit_name', type=str, default='vit_b', help='Select one vit model') 93 | parser.add_argument('--rank', type=int, default=4, help='Rank for LoRA adaptation') 94 | 95 | args = parser.parse_args() 96 | 97 | if args.config is not None: 98 | # overwtite default configurations with config file\ 99 | config_dict = config_to_dict(args.config) 100 | for key in config_dict: 101 | setattr(args, key, config_dict[key]) 102 | 103 | if not args.deterministic: 104 | cudnn.benchmark = True 105 | cudnn.deterministic = False 106 | else: 107 | cudnn.benchmark = False 108 | cudnn.deterministic = True 109 | random.seed(args.seed) 110 | np.random.seed(args.seed) 111 | torch.manual_seed(args.seed) 112 | torch.cuda.manual_seed(args.seed) 113 | 114 | if not os.path.exists(args.output_dir): 115 | os.makedirs(args.output_dir) 116 | 117 | # register model 118 | sam, img_embedding_size = sam_model_registry[args.vit_name](image_size=args.img_size, 119 | num_classes=args.num_classes, 120 | checkpoint=args.ckpt, pixel_mean=[0, 0, 0], 121 | pixel_std=[1, 1, 1]) 122 | 123 | #pkg = import_module(args.module) 124 | net = LoRA_Sam(sam, args.rank).cuda() 125 | 126 | assert args.lora_ckpt is not None 127 | net.load_lora_parameters(args.lora_ckpt) 128 | 129 | multimask_output = False 130 | 131 | # initialize log 132 | log_folder = os.path.join(args.output_dir, 'test_log') 133 | os.makedirs(log_folder, exist_ok=True) 134 | logging.basicConfig(filename=log_folder + '/' + 'log.txt', level=logging.INFO, 135 | format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') 136 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 137 | logging.info(str(args)) 138 | 139 | if args.is_savenii: 140 | test_save_path = os.path.join(args.output_dir, 'predictions') 141 | os.makedirs(test_save_path, exist_ok=True) 142 | else: 143 | test_save_path = None 144 | inference(args, multimask_output, net, test_save_path) 145 | -------------------------------------------------------------------------------- /SAM/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import numpy as np 5 | import torch 6 | import torch.backends.cudnn as cudnn 7 | 8 | from sam_lora_image_encoder import LoRA_Sam 9 | from segment_anything import sam_model_registry 10 | 11 | from trainer import trainer_cancer 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--data_dir', type=str, 15 | default='../datasets/segment', help='root dir for data') 16 | parser.add_argument('--output', type=str, default='./exp') 17 | parser.add_argument('--txt_dir', type=str, 18 | default='./datasets/train.txt', help='list dir') 19 | parser.add_argument('--num_classes', type=int, 20 | default=1, 
help='output channel of network') 21 | # parser.add_argument('--max_iterations', type=int, 22 | # default=30000, help='maximum epoch number to train') 23 | parser.add_argument('--max_epochs', type=int, 24 | default=200, help='maximum epoch number to train') 25 | # parser.add_argument('--stop_epoch', type=int, 26 | # default=200, help='maximum epoch number to train') 27 | parser.add_argument('--batch_size', type=int, 28 | default=128, help='batch_size per gpu') 29 | parser.add_argument('--n_gpu', type=int, default=2, help='total gpu') 30 | parser.add_argument('--deterministic', type=int, default=1, 31 | help='whether use deterministic training') 32 | parser.add_argument('--base_lr', type=float, default=0.005, 33 | help='segmentation network learning rate') 34 | parser.add_argument('--img_size', type=int, 35 | default=224, help='input patch size of network input') 36 | parser.add_argument('--seed', type=int, 37 | default=1234, help='random seed') 38 | parser.add_argument('--vit_name', type=str, 39 | default='vit_b', help='select one vit model') 40 | parser.add_argument('--ckpt', type=str, default='./checkpoints/sam_vit_b_01ec64.pth', 41 | help='Pretrained checkpoint') 42 | parser.add_argument('--lora_ckpt', type=str, default=None, help='Finetuned lora checkpoint') 43 | parser.add_argument('--rank', type=int, default=4, help='Rank for LoRA adaptation') 44 | parser.add_argument('--warmup', action='store_true', help='If activated, warp up the learning from a lower lr to the base_lr') 45 | parser.add_argument('--warmup_period', type=int, default=200, 46 | help='Warp up iterations, only valid whrn warmup is activated') 47 | parser.add_argument('--AdamW', action='store_true', help='If activated, use AdamW to finetune SAM model') 48 | parser.add_argument('--dice_param', type=float, default=0.8) 49 | 50 | parser.add_argument('--lr_exp', type=float, default=0.9, help='The learning rate decay expotential') 51 | 52 | # acceleration choices 53 | parser.add_argument('--tf32', action='store_true', help='If activated, use tf32 to accelerate the training process') 54 | parser.add_argument('--compile', action='store_true', help='If activated, compile the training model for acceleration') 55 | parser.add_argument('--use_amp', action='store_true', help='If activated, adopt mixed precision for acceleration') 56 | 57 | args = parser.parse_args() 58 | 59 | def worker_init_fn(worker_id): 60 | random.seed(args.seed + worker_id) 61 | 62 | if __name__ == "__main__": 63 | if args.tf32: 64 | torch.backends.cuda.matmul.allow_tf32 = True 65 | torch.backends.cudnn.allow_tf32 = True 66 | if not args.deterministic: 67 | cudnn.benchmark = True 68 | cudnn.deterministic = False 69 | else: 70 | cudnn.benchmark = False 71 | cudnn.deterministic = True 72 | 73 | random.seed(args.seed) 74 | np.random.seed(args.seed) 75 | torch.manual_seed(args.seed) 76 | torch.cuda.manual_seed(args.seed) 77 | 78 | if not os.path.exists(args.output): 79 | os.makedirs(args.output) 80 | 81 | # register model 82 | sam, img_embedding_size = sam_model_registry[args.vit_name](image_size=args.img_size, 83 | num_classes=args.num_classes, 84 | checkpoint=args.ckpt, pixel_mean=[0, 0, 0], 85 | pixel_std=[1, 1, 1]) 86 | 87 | net = LoRA_Sam(sam, args.rank).cuda() 88 | if args.compile: 89 | net = torch.compile(net) 90 | 91 | if args.lora_ckpt is not None: 92 | net.load_lora_parameters(args.lora_ckpt) 93 | 94 | multimask_output = False 95 | 96 | low_res = img_embedding_size * 4 97 | 98 | config_file = os.path.join(args.output, 'config.txt') 99 | config_items = [] 
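    # Write every parsed argument out as a "key: value" line; SAM/test.py can
    # later restore this run's configuration through its --config option
    # (parsed back by config_to_dict).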
100 | for key, value in args.__dict__.items(): 101 | config_items.append(f'{key}: {value}\n') 102 | 103 | with open(config_file, 'w') as f: 104 | f.writelines(config_items) 105 | 106 | trainer_cancer(args, net, args.output, multimask_output, low_res) -------------------------------------------------------------------------------- /SAM/trainer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import random 5 | import sys 6 | import time 7 | import math 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.optim as optim 12 | from tensorboardX import SummaryWriter 13 | from torch.nn.modules.loss import CrossEntropyLoss 14 | from torch.utils.data import DataLoader 15 | import torch.nn.functional as F 16 | from tqdm import tqdm 17 | from utils import DiceLoss 18 | from torchvision import transforms 19 | from icecream import ic 20 | 21 | 22 | def calc_loss(outputs, low_res_label_batch, ce_loss, dice_loss, dice_weight:float=0.8): 23 | low_res_logits = outputs['low_res_logits'] 24 | loss_ce = ce_loss(low_res_logits, low_res_label_batch[:].long()) 25 | loss_dice = dice_loss(low_res_logits, low_res_label_batch, softmax=True) 26 | loss = (1 - dice_weight) * loss_ce + dice_weight * loss_dice 27 | return loss, loss_ce, loss_dice 28 | 29 | 30 | def trainer_cancer(args, model, snapshot_path, multimask_output, low_res): 31 | from datasets.dataset_cancer import Cancer_dataset, RandomGenerator 32 | logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO, 33 | format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') 34 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 35 | logging.info(str(args)) 36 | base_lr = args.base_lr 37 | num_classes = args.num_classes 38 | batch_size = args.batch_size * args.n_gpu 39 | # max_iterations = args.max_iterations 40 | db_train = Cancer_dataset(data_dir=args.data_dir, txt_dir=args.txt_dir, 41 | transform=transforms.Compose( 42 | [RandomGenerator(output_size=[args.img_size, args.img_size], 43 | low_res=[low_res, low_res], phase="train")])) 44 | db_test = Cancer_dataset(data_dir=args.data_dir.replace('train', 'val'), txt_dir=args.txt_dir.replace('train', 'val'), 45 | transform=transforms.Compose( 46 | [RandomGenerator(output_size=[args.img_size, args.img_size], 47 | low_res=[low_res, low_res], phase="val")])) 48 | print("The length of train set is: {}".format(len(db_train))) 49 | print("The length of val set is: {}".format(len(db_test))) 50 | 51 | def worker_init_fn(worker_id): 52 | random.seed(args.seed + worker_id) 53 | 54 | trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=12, pin_memory=True, 55 | worker_init_fn=worker_init_fn) 56 | testloader = DataLoader(db_test, batch_size=batch_size*2, shuffle=False, num_workers=12, pin_memory=True, 57 | worker_init_fn=worker_init_fn) 58 | if args.n_gpu > 1: 59 | model = nn.DataParallel(model) 60 | model.train() 61 | ce_loss = CrossEntropyLoss() 62 | dice_loss = DiceLoss(num_classes + 1) 63 | if args.warmup: 64 | b_lr = base_lr / args.warmup_period 65 | else: 66 | b_lr = base_lr 67 | if args.AdamW: 68 | optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=b_lr, betas=(0.9, 0.999), weight_decay=0.1) 69 | else: 70 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=b_lr, momentum=0.9, weight_decay=0.0001) # Even pass the model.parameters(), the `requires_grad=False` layers 
will not update 71 | if args.use_amp: 72 | scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp) 73 | writer = SummaryWriter(snapshot_path + '/log/train') 74 | writer_ = SummaryWriter(snapshot_path + '/log/val') 75 | iter_num = 0 76 | max_epoch = args.max_epochs 77 | # stop_epoch = args.stop_epoch 78 | max_iterations = args.max_epochs * len(trainloader) # max_epoch = max_iterations // len(trainloader) + 1 79 | logging.info("{} iterations per epoch. {} max iterations ".format(len(trainloader), max_iterations)) 80 | best_performance = 0.0 81 | iterator = tqdm(range(max_epoch), ncols=70) 82 | for epoch_num in iterator: 83 | epoch_loss_total = 0 84 | epoch_loss_dice = 0 85 | epoch_loss_ce = 0 86 | for i_batch, sampled_batch in enumerate(trainloader): 87 | image_batch, label_batch = sampled_batch['image'], sampled_batch['label'] # [b, c, h, w], [b, h, w] 88 | low_res_label_batch = sampled_batch['low_res_label'] 89 | image_batch, label_batch = image_batch.cuda(), label_batch.cuda() 90 | low_res_label_batch = low_res_label_batch.cuda() 91 | assert image_batch.max() <= 3, f'image_batch max: {image_batch.max()}' 92 | if args.use_amp: 93 | with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=args.use_amp): 94 | outputs = model(image_batch, multimask_output, args.img_size) 95 | loss, loss_ce, loss_dice = calc_loss(outputs, low_res_label_batch, ce_loss, dice_loss, args.dice_param) 96 | scaler.scale(loss).backward() 97 | scaler.step(optimizer) 98 | scaler.update() 99 | optimizer.zero_grad() 100 | else: 101 | outputs = model(image_batch, multimask_output, args.img_size) 102 | loss, loss_ce, loss_dice = calc_loss(outputs, low_res_label_batch, ce_loss, dice_loss, args.dice_param) 103 | optimizer.zero_grad() 104 | loss.backward() 105 | # if args.skip_hard and iter_num > 3 * args.warmup_period and loss.item() > 0.4: 106 | # skip_hard_nums += 1 107 | # print(f'Skip hard nums: {skip_hard_nums}') 108 | # continue 109 | optimizer.step() 110 | 111 | epoch_loss_total += loss.item() 112 | epoch_loss_dice += loss_dice.item() 113 | epoch_loss_ce += loss_ce.item() 114 | 115 | if args.warmup and iter_num < args.warmup_period: 116 | lr_ = base_lr * ((iter_num + 1) / args.warmup_period) 117 | for param_group in optimizer.param_groups: 118 | param_group['lr'] = lr_ 119 | else: 120 | if args.warmup: 121 | shift_iter = iter_num - args.warmup_period 122 | assert shift_iter >= 0, f'Shift iter is {shift_iter}, smaller than zero' 123 | else: 124 | shift_iter = iter_num 125 | lr_ = base_lr * (1.0 - shift_iter / max_iterations) ** args.lr_exp 126 | for param_group in optimizer.param_groups: 127 | param_group['lr'] = lr_ 128 | 129 | iter_num = iter_num + 1 130 | writer.add_scalar('info/lr', lr_, iter_num) 131 | # writer.add_scalar('info/total_loss', loss, iter_num) 132 | # writer.add_scalar('info/loss_ce', loss_ce, iter_num) 133 | # writer.add_scalar('info/loss_dice', loss_dice, iter_num) 134 | 135 | if iter_num % 20 == 0: 136 | image = image_batch[1, 0:1, :, :] 137 | image = (image - image.min()) / (image.max() - image.min()) 138 | writer.add_image('train/Image', image, iter_num) 139 | output_masks = outputs['masks'] 140 | output_masks = torch.argmax(torch.softmax(output_masks, dim=1), dim=1, keepdim=True) 141 | writer.add_image('train/Prediction', output_masks[1, ...] 
* 50, iter_num) 142 | labs = label_batch[1, ...].unsqueeze(0) * 50 143 | writer.add_image('train/GroundTruth', labs, iter_num) 144 | 145 | epoch_loss_total /= len(trainloader) 146 | epoch_loss_dice /= len(trainloader) 147 | epoch_loss_ce /= len(trainloader) 148 | 149 | logging.info('Train: epoch %d : loss : %f, loss_ce: %f, loss_dice: %f' % 150 | (epoch_num+1, epoch_loss_total, epoch_loss_ce, epoch_loss_dice)) 151 | 152 | writer.add_scalar('info/total_loss', epoch_loss_total, epoch_num+1) 153 | writer.add_scalar('info/loss_ce', epoch_loss_ce, epoch_num+1) 154 | writer.add_scalar('info/loss_dice', epoch_loss_dice, epoch_num+1) 155 | 156 | model.eval() 157 | epoch_loss_total = 0 158 | epoch_loss_dice = 0 159 | epoch_loss_ce = 0 160 | for sampled_batch in testloader: 161 | with torch.no_grad(): 162 | image_batch, label_batch = sampled_batch['image'], sampled_batch['label'] # [b, c, h, w], [b, h, w] 163 | low_res_label_batch = sampled_batch['low_res_label'] 164 | image_batch, label_batch = image_batch.cuda(), label_batch.cuda() 165 | low_res_label_batch = low_res_label_batch.cuda() 166 | outputs = model(image_batch, multimask_output, args.img_size) 167 | loss, loss_ce, loss_dice = calc_loss(outputs, low_res_label_batch, ce_loss, dice_loss, args.dice_param) 168 | 169 | epoch_loss_total += loss.item() 170 | epoch_loss_dice += loss_dice.item() 171 | epoch_loss_ce += loss_ce.item() 172 | 173 | epoch_loss_total /= len(testloader) 174 | epoch_loss_dice /= len(testloader) 175 | epoch_loss_ce /= len(testloader) 176 | 177 | logging.info('Test: epoch %d : loss : %f, loss_ce: %f, loss_dice: %f' % 178 | (epoch_num + 1, epoch_loss_total, epoch_loss_ce, epoch_loss_dice)) 179 | 180 | writer_.add_scalar('info/total_loss', epoch_loss_total, epoch_num + 1) 181 | writer_.add_scalar('info/loss_ce', epoch_loss_ce, epoch_num + 1) 182 | writer_.add_scalar('info/loss_dice', epoch_loss_dice, epoch_num + 1) 183 | 184 | save_interval = 20 # int(max_epoch/6) 185 | if (epoch_num + 1) % save_interval == 0: 186 | save_mode_path = os.path.join(snapshot_path, 'epoch_' + str(epoch_num) + '.pth') 187 | try: 188 | model.save_lora_parameters(save_mode_path) 189 | except: 190 | model.module.save_lora_parameters(save_mode_path) 191 | logging.info("save model to {}".format(save_mode_path)) 192 | 193 | if epoch_num >= max_epoch - 1: 194 | save_mode_path = os.path.join(snapshot_path, 'epoch_' + str(epoch_num) + '.pth') 195 | try: 196 | model.save_lora_parameters(save_mode_path) 197 | except: 198 | model.module.save_lora_parameters(save_mode_path) 199 | logging.info("save model to {}".format(save_mode_path)) 200 | iterator.close() 201 | break 202 | 203 | writer.close() 204 | return "Training Finished!" 
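
# Learning-rate schedule used above: an optional linear warmup from
# base_lr / warmup_period up to base_lr over `warmup_period` iterations,
# followed by polynomial decay
#     lr = base_lr * (1 - shift_iter / max_iterations) ** lr_exp,
# where shift_iter counts iterations after the warmup phase ends.
#
# Minimal invocation sketch (mirrors how SAM/train.py calls this trainer;
# the concrete values are illustrative only):
#
#     net = LoRA_Sam(sam, args.rank).cuda()
#     trainer_cancer(args, net, snapshot_path=args.output,
#                    multimask_output=False, low_res=img_embedding_size * 4)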
205 | -------------------------------------------------------------------------------- /SAM/utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import numpy as np 4 | import torch 5 | from medpy import metric 6 | from scipy.ndimage import zoom 7 | import torch.nn as nn 8 | import SimpleITK as sitk 9 | import torch.nn.functional as F 10 | import imageio 11 | from einops import repeat 12 | from icecream import ic 13 | from PIL import Image 14 | import gzip 15 | import pickle 16 | 17 | 18 | class Focal_loss(nn.Module): 19 | def __init__(self, alpha=0.25, gamma=2, num_classes=3, size_average=True): 20 | super(Focal_loss, self).__init__() 21 | self.size_average = size_average 22 | if isinstance(alpha, list): 23 | assert len(alpha) == num_classes 24 | print(f'Focal loss alpha={alpha}, will assign alpha values for each class') 25 | self.alpha = torch.Tensor(alpha) 26 | else: 27 | assert alpha < 1 28 | print(f'Focal loss alpha={alpha}, will shrink the impact in background') 29 | self.alpha = torch.zeros(num_classes) 30 | self.alpha[0] = alpha 31 | self.alpha[1:] = 1 - alpha 32 | self.gamma = gamma 33 | self.num_classes = num_classes 34 | 35 | def forward(self, preds, labels): 36 | """ 37 | Calc focal loss 38 | :param preds: size: [B, N, C] or [B, C], corresponds to detection and classification tasks [B, C, H, W]: segmentation 39 | :param labels: size: [B, N] or [B] [B, H, W]: segmentation 40 | :return: 41 | """ 42 | self.alpha = self.alpha.to(preds.device) 43 | preds = preds.permute(0, 2, 3, 1).contiguous() 44 | preds = preds.view(-1, preds.size(-1)) 45 | B, H, W = labels.shape 46 | assert B * H * W == preds.shape[0] 47 | assert preds.shape[-1] == self.num_classes 48 | preds_logsoft = F.log_softmax(preds, dim=1) # log softmax 49 | preds_softmax = torch.exp(preds_logsoft) # softmax 50 | 51 | preds_softmax = preds_softmax.gather(1, labels.view(-1, 1)) 52 | preds_logsoft = preds_logsoft.gather(1, labels.view(-1, 1)) 53 | alpha = self.alpha.gather(0, labels.view(-1)) 54 | loss = -torch.mul(torch.pow((1 - preds_softmax), self.gamma), 55 | preds_logsoft) # torch.low(1 - preds_softmax) == (1 - pt) ** r 56 | 57 | loss = torch.mul(alpha, loss.t()) 58 | if self.size_average: 59 | loss = loss.mean() 60 | else: 61 | loss = loss.sum() 62 | return loss 63 | 64 | 65 | class DiceLoss(nn.Module): 66 | def __init__(self, n_classes): 67 | super(DiceLoss, self).__init__() 68 | self.n_classes = n_classes 69 | 70 | def _one_hot_encoder(self, input_tensor): 71 | tensor_list = [] 72 | for i in range(self.n_classes): 73 | temp_prob = input_tensor == i # * torch.ones_like(input_tensor) 74 | tensor_list.append(temp_prob.unsqueeze(1)) 75 | output_tensor = torch.cat(tensor_list, dim=1) 76 | return output_tensor.float() 77 | 78 | def _dice_loss(self, score, target): 79 | target = target.float() 80 | smooth = 1e-5 81 | intersect = torch.sum(score * target) 82 | y_sum = torch.sum(target * target) 83 | z_sum = torch.sum(score * score) 84 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) 85 | loss = 1 - loss 86 | return loss 87 | 88 | def forward(self, inputs, target, weight=None, softmax=False): 89 | if softmax: 90 | inputs = torch.softmax(inputs, dim=1) 91 | target = self._one_hot_encoder(target) 92 | if weight is None: 93 | weight = [1] * self.n_classes 94 | assert inputs.size() == target.size(), 'predict {} & target {} shape do not match'.format(inputs.size(), 95 | target.size()) 96 | class_wise_dice = [] 97 | loss = 0.0 98 | for i in range(0, 
self.n_classes): 99 | dice = self._dice_loss(inputs[:, i], target[:, i]) 100 | class_wise_dice.append(1.0 - dice.item()) 101 | loss += dice * weight[i] 102 | return loss / self.n_classes 103 | 104 | 105 | def calculate_metric_percase(pred, gt): 106 | pred[pred > 0] = 1 107 | gt[gt > 0] = 1 108 | if pred.sum() > 0 and gt.sum() > 0: 109 | dice = metric.binary.dc(pred, gt) 110 | # hd95 = metric.binary.hd95(pred, gt) 111 | return dice 112 | elif pred.sum() > 0 and gt.sum() == 0: 113 | return 1 114 | else: 115 | return 0 116 | 117 | 118 | def test_single_volume(image, label, net, classes, multimask_output, patch_size=[256, 256], input_size=[224, 224], 119 | test_save_path=None, case=None, results=None): 120 | image, label = image.squeeze(0).cpu().detach().numpy(), label.squeeze(0).cpu().detach().numpy() 121 | orign_image = copy.deepcopy(image) 122 | x, y = image.shape[0:2] 123 | if x != input_size[0] or y != input_size[1]: 124 | image = zoom(image, (input_size[0] / x, input_size[1] / y, 1.0), order=3) # previous using 0 125 | new_x, new_y = image.shape[0], image.shape[1] # [input_size[0], input_size[1]] 126 | if new_x != patch_size[0] or new_y != patch_size[1]: 127 | image = zoom(image, (patch_size[0] / new_x, patch_size[1] / new_y), order=3) 128 | inputs = torch.from_numpy(image.astype(np.float32) / 255.0) 129 | inputs = inputs.permute(2, 0, 1) 130 | inputs = inputs.unsqueeze(0).cuda() 131 | net.eval() 132 | with torch.no_grad(): 133 | outputs = net(inputs, multimask_output, patch_size[0]) 134 | output_masks = outputs['masks'] 135 | out = torch.argmax(torch.softmax(output_masks, dim=1), dim=1).squeeze(0) 136 | prediction = out.cpu().detach().numpy() 137 | if x != patch_size[0] or y != patch_size[1]: 138 | prediction = zoom(prediction, (x / patch_size[0], y / patch_size[1]), order=0) 139 | 140 | metric_list = [] 141 | # for i in range(1, classes + 1): 142 | # metric_list.append(calculate_metric_percase(prediction == i, label == i)) 143 | 144 | if test_save_path is not None: 145 | img = Image.fromarray(orign_image.astype(np.uint8)) 146 | prd = Image.fromarray((prediction*255).astype(np.uint8)) 147 | lab = Image.fromarray((label*255).astype(np.uint8)) 148 | 149 | # 可视化结果 150 | # img.save(os.path.join(test_save_path, case + "_ori.jpg")) 151 | # prd.save(os.path.join(test_save_path, case + "_prd.jpg")) 152 | # lab.save(os.path.join(test_save_path, case + "_gt.jpg")) 153 | 154 | results[case] = {"image": img, "pred": prd, "label": lab} 155 | 156 | return metric_list 157 | -------------------------------------------------------------------------------- /cbam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class BasicConv(nn.Module): 6 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False): 7 | super(BasicConv, self).__init__() 8 | self.out_channels = out_planes 9 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) 10 | self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None 11 | self.relu = nn.ReLU() if relu else None 12 | 13 | def forward(self, x): 14 | x = self.conv(x) 15 | if self.bn is not None: 16 | x = self.bn(x) 17 | if self.relu is not None: 18 | x = self.relu(x) 19 | return x 20 | class Flatten(nn.Module): 21 | def forward(self, x): 22 | return x.view(x.size(0), -1) 
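
# ChannelGate below is the channel-attention branch of CBAM: a shared bottleneck
# MLP (C -> C // reduction_ratio -> C) is applied to the average- and max-pooled
# descriptors, the results are summed, and
#     scale = sigmoid(MLP(AvgPool(x)) + MLP(MaxPool(x)))
# reweights the input channel-wise as x * scale (scale broadcast to B x C x H x W).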
23 | 24 | class ChannelGate(nn.Module): 25 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']): 26 | super(ChannelGate, self).__init__() 27 | self.gate_channels = gate_channels 28 | self.mlp = nn.Sequential( 29 | Flatten(), 30 | nn.Linear(gate_channels, gate_channels // reduction_ratio), 31 | nn.ReLU(), 32 | nn.Linear(gate_channels // reduction_ratio, gate_channels) 33 | ) 34 | self.pool_types = pool_types 35 | def forward(self, x): 36 | channel_att_sum = None 37 | for pool_type in self.pool_types: 38 | if pool_type=='avg': 39 | avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 40 | channel_att_raw = self.mlp( avg_pool ) 41 | elif pool_type=='max': 42 | max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 43 | channel_att_raw = self.mlp( max_pool ) 44 | elif pool_type=='lp': 45 | lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 46 | channel_att_raw = self.mlp( lp_pool ) 47 | elif pool_type=='lse': 48 | # LSE pool only 49 | lse_pool = logsumexp_2d(x) 50 | channel_att_raw = self.mlp( lse_pool ) 51 | if channel_att_sum is None: 52 | channel_att_sum = channel_att_raw 53 | else: 54 | channel_att_sum = channel_att_sum + channel_att_raw 55 | 56 | scale = torch.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x) 57 | 58 | return x * scale 59 | 60 | 61 | def logsumexp_2d(tensor): 62 | tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1) 63 | s, _ = torch.max(tensor_flatten, dim=2, keepdim=True) 64 | outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log() 65 | return outputs 66 | 67 | class ChannelPool(nn.Module): 68 | def forward(self, x): 69 | return torch.cat((torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1) 70 | 71 | class SpatialGate(nn.Module): 72 | def __init__(self): 73 | super(SpatialGate, self).__init__() 74 | kernel_size = 7 75 | self.compress = ChannelPool() 76 | self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False) 77 | def forward(self, x): 78 | x_compress = self.compress(x) 79 | x_out = self.spatial(x_compress) 80 | scale = torch.sigmoid(x_out) # broadcasting 81 | 82 | return x * scale, scale 83 | 84 | class CBAM(nn.Module): 85 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False): 86 | super(CBAM, self).__init__() 87 | self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types) 88 | self.no_spatial = no_spatial 89 | if not no_spatial: 90 | self.SpatialGate = SpatialGate() 91 | def forward(self, x): 92 | x_out = self.ChannelGate(x) 93 | if not self.no_spatial: 94 | x_out, scale = self.SpatialGate(x_out) 95 | return x_out -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import PIL.Image as Image 2 | import os 3 | 4 | import torch.utils.data as data 5 | from torch.utils.data import DataLoader 6 | 7 | class_labels = ['normal', 'benign', 'tumor'] 8 | 9 | def make_dataset(path): 10 | imgs = [] 11 | for i in range(len(class_labels)): 12 | img_dir_path = os.path.join(path, class_labels[i]) 13 | img_name = os.listdir(img_dir_path) 14 | for name in img_name: 15 | img_path = os.path.join(img_dir_path, name) 16 | crop_path = img_path.replace('global', 'local_seg') 17 | imgs.append((img_path, crop_path, i, name)) 18 | 19 | return imgs 20 | 21 | 22 | class 
LPCDataset(data.Dataset): 23 | def __init__(self, root, transform=None): 24 | self.imgs = make_dataset(root) 25 | self.transform = transform 26 | 27 | def __getitem__(self, index): 28 | img_path, crop_path, label, name = self.imgs[index] 29 | 30 | img_x = Image.open(img_path) 31 | crop_x = Image.open(crop_path) 32 | if self.transform is not None: 33 | img_x = self.transform(img_x) 34 | crop_x = self.transform(crop_x) 35 | 36 | # o for global features, 1 for local features 37 | return img_x, crop_x, label, 0, 1 38 | 39 | def __len__(self): 40 | return len(self.imgs) 41 | 42 | 43 | 44 | if __name__ == '__main__': 45 | from torchvision.transforms import * 46 | transforms_train = Compose([ 47 | RandomAffine(degrees=10, scale=(0.9, 1.1), translate=(0.1, 0.1), shear=5.729578), 48 | RandomHorizontalFlip(), 49 | ColorJitter(0.4, 0.4, 0.4, 0), 50 | Resize(256), 51 | CenterCrop(224), 52 | ToTensor(), 53 | Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 54 | ]) 55 | 56 | train_dataset = LPCDataset(root=r'./datasets/global/train', transform=transforms_train) 57 | train_dataloaders = DataLoader(train_dataset, batch_size=1, shuffle=False, num_workers=1) 58 | case = next(iter(train_dataloaders)) 59 | 60 | print(case) 61 | 62 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision.models import resnet18, resnet34, resnet50 5 | from cbam import CBAM 6 | 7 | class ResNet(nn.Module): 8 | def __init__(self, resnet_model, pretrained=True): 9 | super().__init__() 10 | if pretrained: 11 | self.resnet = resnet_model(pretrained=True) 12 | else: 13 | self.resnet = resnet_model(pretrained=False) 14 | 15 | self.conv1 = nn.Sequential(self.resnet.conv1, self.resnet.bn1, self.resnet.relu, self.resnet.maxpool) 16 | self.layers1 = self.resnet.layer1 17 | self.layers2 = self.resnet.layer2 18 | self.layers3 = self.resnet.layer3 19 | self.layers4 = self.resnet.layer4 20 | 21 | def forward(self, x): 22 | x = self.conv1(x) 23 | c2 = self.layers1(x) 24 | c3 = self.layers2(c2) 25 | c4 = self.layers3(c3) 26 | c5 = self.layers4(c4) 27 | return c2, c3, c4, c5 28 | 29 | class FPNBlock(nn.Module): 30 | def __init__(self, in_channels, out_channels=256, is_highest=False, is_lowest=False): 31 | super().__init__() 32 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 33 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 34 | self.is_highest = is_highest 35 | self.is_lowest = is_lowest 36 | 37 | def forward(self, x, y): 38 | x = self.conv1(x) 39 | if not self.is_highest: 40 | target_height = x.shape[2] 41 | target_width = x.shape[3] 42 | x += F.interpolate(y, size=(target_height, target_width), mode="bilinear", align_corners=True) 43 | if self.is_lowest: 44 | x = self.conv2(x) 45 | return x 46 | 47 | class FPN(nn.Module): 48 | def __init__(self, expansion=1, in_channels_list=[64, 128, 256, 512], out_channels=256): 49 | super().__init__() 50 | self.P2 = FPNBlock(in_channels_list[0]*expansion, out_channels=out_channels, is_lowest=True) 51 | self.P3 = FPNBlock(in_channels_list[1]*expansion, out_channels=out_channels) 52 | self.P4 = FPNBlock(in_channels_list[2]*expansion, out_channels=out_channels) 53 | self.P5 = FPNBlock(in_channels_list[3]*expansion, out_channels=out_channels, is_highest=True) 54 | 55 | def forward(self, C2, C3, C4, C5): 56 | x 
= self.P5(C5, None) 57 | x = self.P4(C4, x) 58 | x = self.P3(C3, x) 59 | P2 = self.P2(C2, x) 60 | return P2 61 | 62 | class ResNetFPN(nn.Module): 63 | def __init__(self, resnet_type='resnet50', pretrained=True): 64 | super().__init__() 65 | 66 | resnet_models = { 67 | 'resnet18': resnet18, 68 | 'resnet34': resnet34, 69 | 'resnet50': resnet50, 70 | } 71 | 72 | if resnet_type not in resnet_models: 73 | raise ValueError(f"Unsupported ResNet type: {resnet_type}. Choose from {list(resnet_models.keys())}") 74 | 75 | resnet_model = resnet_models[resnet_type] 76 | 77 | self.resnet1 = ResNet(resnet_model, pretrained=pretrained) 78 | self.resnet2 = ResNet(resnet_model, pretrained=pretrained) 79 | 80 | self.FPN1 = FPN(expansion=4) 81 | self.FPN2 = FPN(expansion=4) 82 | 83 | def forward(self, x, y): 84 | C2, C3, C4, C5 = self.resnet1(x) 85 | c2, c3, c4, c5 = self.resnet2(y) 86 | 87 | P2 = self.FPN1(C2, C3, C4, C5) 88 | p2 = self.FPN2(c2, c3, c4, c5) 89 | 90 | return P2, p2 91 | 92 | class SAM_FNet(nn.Module): 93 | def __init__(self, num_classes=3, num_features=2, resnet_type='resnet50', pretrained=True): 94 | super().__init__() 95 | self.resnet_fpn = ResNetFPN(resnet_type=resnet_type, pretrained=pretrained) 96 | 97 | self.cbam5_x = CBAM(256) 98 | self.cbam5_x1 = CBAM(256) 99 | 100 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 101 | 102 | self.drop_fusion = nn.Dropout(p=0.5, inplace=True) 103 | 104 | self.fc1_dense = nn.Linear(256, num_classes) 105 | self.fc2_dense = nn.Linear(256, num_classes) 106 | self.fc_dense = nn.Linear(256 * 2, num_classes) 107 | 108 | self.modal_dense = nn.Linear(256, num_features) 109 | 110 | def forward(self, img1, img2, target=None): 111 | feats = {} 112 | 113 | x, y = self.resnet_fpn(img1, img2) 114 | 115 | output1 = self.cbam5_x(x) # (1, 256, 56, 56) 116 | output2 = self.cbam5_x1(y) 117 | 118 | c1 = self.avg_pool(output1) 119 | c1 = c1.view(c1.size(0), -1) 120 | 121 | c2 = self.avg_pool(output2) 122 | c2 = c2.view(c2.size(0), -1) 123 | 124 | feats['global'] = F.normalize(c1, dim=-1, p=2) 125 | feats['local'] = F.normalize(c2, dim=-1, p=2) 126 | 127 | c1_cls = self.fc1_dense(c1) 128 | c2_cls = self.fc2_dense(c2) 129 | c1_mdl = self.modal_dense(c1) 130 | c2_mdl = self.modal_dense(c2) 131 | 132 | output = torch.cat((c1, c2), dim=1) 133 | output = self.drop_fusion(output) 134 | output = self.fc_dense(output) 135 | 136 | return c1_cls, c2_cls, output, c1_mdl, c2_mdl, feats 137 | 138 | 139 | def SAM_FNet18(num_classes=3, num_features=2, pretrained=True): 140 | return SAM_FNet(num_classes=num_classes, num_features=num_features, resnet_type='resnet18', pretrained=pretrained) 141 | 142 | def SAM_FNet34(num_classes=3, num_features=2, pretrained=True): 143 | return SAM_FNet(num_classes=num_classes, num_features=num_features, resnet_type='resnet34', pretrained=pretrained) 144 | 145 | def SAM_FNet50(num_classes=3, num_features=2, pretrained=True): 146 | return SAM_FNet(num_classes=num_classes, num_features=num_features, resnet_type='resnet50', pretrained=pretrained) 147 | 148 | 149 | if __name__ == "__main__": 150 | net = SAM_FNet50(3, 2, pretrained=True) 151 | x = torch.randn((1, 3, 256, 256), dtype=torch.float32) 152 | y = torch.randn((1, 3, 256, 256), dtype=torch.float32) 153 | ### Inference 154 | _, _, output1, _, _, _ = net.forward(x, y) 155 | ### Output 156 | print(output1) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | 
multiprocessing.set_start_method("spawn", True) 3 | 4 | import argparse 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | import random 11 | from torch import optim 12 | from torch.utils.data import DataLoader 13 | from dataset import LPCDataset 14 | from model import SAM_FNet18, SAM_FNet34, SAM_FNet50 15 | import time 16 | from tqdm import tqdm 17 | from torch.nn.parallel import DataParallel 18 | from torchvision import transforms 19 | import torch.cuda.amp as amp 20 | from pathlib import Path 21 | from torch.utils.tensorboard import SummaryWriter 22 | 23 | import os 24 | 25 | def classacc(predicted, label, num_classes=3): 26 | """ 27 | Calculate the accuracy for each class. 28 | 29 | Args: 30 | predicted (torch.Tensor): The predicted class labels. 31 | label (torch.Tensor): The ground truth class labels. 32 | num_classes (int): The number of classes. Default is 3 (Normal, Benign, Tumor). 33 | 34 | Returns: 35 | List[torch.Tensor]: A list containing the accuracy for each class. 36 | """ 37 | acc = [] 38 | for cls in range(num_classes): 39 | acc_class = ((label == cls) & (predicted == cls)).sum().float() 40 | acc.append(acc_class) 41 | 42 | return acc 43 | 44 | def cal_loss(args, criterion, m1, m2, c1, c2, outputs, features, labels, target1, target2, epoch): 45 | # GAN-like Loss 46 | if (epoch > args.gan_epoch) and args.gan_opt: 47 | idx = labels != 0 48 | loss_flag = torch.zeros_like(labels).float().cuda() 49 | loss_flag[idx] = 1 50 | loss1 = criterion['cs'](features['global'], features['local'], loss_flag) 51 | 52 | target1 = target1.unsqueeze(1) 53 | target2 = target2.unsqueeze(1) 54 | target1 = torch.cat([1.0 - target1, target1], dim=1) 55 | target2 = torch.cat([1.0 - target2, target2], dim=1) 56 | 57 | loss_be_m1 = F.binary_cross_entropy_with_logits(m1, target1, reduction='none') 58 | loss_be_m2 = F.binary_cross_entropy_with_logits(m2, target2, reduction='none') 59 | loss_be_m1 = loss_be_m1 * loss_flag.unsqueeze(1) 60 | loss_be_m2 = loss_be_m2 * loss_flag.unsqueeze(1) 61 | 62 | loss_flag_sum = loss_flag.sum() 63 | if loss_flag_sum == 0: 64 | loss_flag_sum = torch.tensor(1e-8).to(loss_flag.device) 65 | loss_be_m1 = loss_be_m1.sum() / loss_flag_sum 66 | loss_be_m2 = loss_be_m2.sum() / loss_flag_sum 67 | loss2 = loss_be_m1 + loss_be_m2 68 | else: 69 | loss1 = loss2 = 0 70 | 71 | # cross-entropy loss - global, local, and fusion 72 | loss3 = criterion['ce'](outputs, labels) 73 | loss4 = criterion['ce'](c1, labels) 74 | loss5 = criterion['ce'](c2, labels) 75 | 76 | # total loss 77 | loss = args.gan_weight * loss1 + args.gan_weight * loss2 + \ 78 | args.fusion_weight * loss3 + \ 79 | args.global_weight * loss4 + \ 80 | args.local_weight * loss5 81 | 82 | return loss1, loss2, loss3, loss4, loss5, loss 83 | 84 | def train_model(args, model, criterion, train_dataloaders, val_dataloaders, num_epochs, model_path, writer): 85 | if args.warmup: 86 | b_lr = args.lr / args.warmup_period 87 | else: 88 | b_lr = args.lr 89 | if args.AdamW: 90 | optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=b_lr, betas=(0.9, 0.999), weight_decay=0.1) 91 | else: 92 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=b_lr, momentum=0.9, weight_decay=0.0005) 93 | 94 | iter_num = 0 95 | max_iterations = num_epochs * len(train_dataloaders) 96 | print("{} iterations per epoch. 
{} max iterations ".format(len(train_dataloaders), max_iterations)) 97 | 98 | scaler = amp.GradScaler() 99 | best_val_acc = 0 100 | for epoch in np.arange(0, num_epochs) + 1: 101 | model.train() 102 | print("=======Epoch:{}=======".format(epoch)) 103 | epoch_start_time = time.time() 104 | 105 | epoch_loss = 0 106 | epoch_cos_loss = 0.0 107 | epoch_dis_loss = 0.0 108 | epoch_local_loss = 0.0 109 | epoch_global_loss = 0.0 110 | epoch_fusion_loss = 0.0 111 | 112 | # Initialize dictionaries for correct predictions and total counts 113 | correct = {cls: 0.0 for cls in range(args.num_classes)} 114 | total = {cls: 0.0 for cls in range(args.num_classes)} 115 | 116 | for idx, (input1, input2, labels, target1, target2) in tqdm(enumerate(train_dataloaders), total=len(train_dataloaders)): # 这边一次取出一个batchsize的东西 117 | input1, input2, labels, target1, target2 = \ 118 | input1.cuda(), input2.cuda(), labels.cuda(), target1.cuda(), target2.cuda() 119 | with amp.autocast(): 120 | c1, c2, outputs, m1, m2, features = model(input1, input2, labels) 121 | loss1, loss2, loss3, loss4, loss5, loss = cal_loss(args, criterion, m1, m2, c1, c2, 122 | outputs, features, labels, target1, target2, epoch) 123 | current_batchsize = outputs.size()[0] 124 | optimizer.zero_grad() 125 | scaler.scale(loss).backward() 126 | scaler.step(optimizer) 127 | scaler.update() 128 | epoch_loss += loss.item()*current_batchsize 129 | if (epoch > args.gan_epoch) and args.gan_opt: 130 | epoch_cos_loss += loss1.item()*current_batchsize 131 | epoch_dis_loss += loss2.item()*current_batchsize 132 | epoch_global_loss += loss4.item()*current_batchsize 133 | epoch_local_loss += loss5.item()*current_batchsize 134 | epoch_fusion_loss += loss3.item()*current_batchsize 135 | 136 | _, predicted = torch.max(outputs.data, 1) 137 | acc_temp = classacc(predicted, labels, args.num_classes) 138 | for cls in range(args.num_classes): 139 | correct[cls] += acc_temp[cls] 140 | total[cls] += (labels == cls).sum().float() 141 | 142 | if args.warmup and iter_num < args.warmup_period: 143 | lr_ = args.lr * ((iter_num + 1) / args.warmup_period) 144 | for param_group in optimizer.param_groups: 145 | param_group['lr'] = lr_ 146 | else: 147 | if args.warmup: 148 | shift_iter = iter_num - args.warmup_period 149 | assert shift_iter >= 0, f'Shift iter is {shift_iter}, smaller than zero' 150 | else: 151 | shift_iter = iter_num 152 | lr_ = args.lr * (1.0 - shift_iter / max_iterations) ** 0.965 153 | for param_group in optimizer.param_groups: 154 | param_group['lr'] = lr_ 155 | 156 | iter_num = iter_num + 1 157 | 158 | correct_total = sum(correct.values()) 159 | sum_total = sum(total.values()) 160 | 161 | epochmean = epoch_loss/sum_total 162 | if epoch > args.gan_epoch and args.gan_opt: 163 | epoch_cos_mean = epoch_cos_loss/sum_total 164 | epoch_dis_mean = epoch_dis_loss/sum_total 165 | epoch_global_mean = epoch_global_loss/sum_total 166 | epoch_local_mean = epoch_local_loss/sum_total 167 | epoch_fusion_mean = epoch_fusion_loss/sum_total 168 | 169 | recall_classes = {cls: correct[cls]/total[cls] for cls in range(args.num_classes)} 170 | recall_mean = sum(recall_classes.values())/3.0 171 | acc_mean = correct_total/sum_total 172 | 173 | print(f"train_loss_mean_{epoch}: {epochmean:.4f}") 174 | if (epoch > args.gan_epoch) and args.gan_opt: 175 | print(f"train_cos_loss_mean_{epoch}: {epoch_cos_mean:.4f}") 176 | print(f"train_dis_loss_mean_{epoch}: {epoch_dis_mean:.4f}") 177 | 178 | print(f"train_global_loss_{epoch}: {epoch_global_mean:.4f}") 179 | print(f"train_local_loss_{epoch}: 
{epoch_local_mean:.4f}") 180 | print(f"train_fusion_loss_{epoch}: {epoch_fusion_mean:.4f}") 181 | 182 | for cls in range(args.num_classes): 183 | print(f"train_recall_class_{cls}_{epoch}: {recall_classes[cls]:.4f}") 184 | print(f"train_recall_class_mean_{epoch}: {recall_mean:.4f}") 185 | print(f"train_acc_{epoch}: {acc_mean:.4f}") 186 | 187 | ## tensorboard 188 | writer.add_scalar("train/loss", epochmean, epoch) 189 | if epoch > args.gan_epoch and args.gan_opt: 190 | writer.add_scalar("train/cos_loss", epoch_cos_mean, epoch) 191 | writer.add_scalar("train/dis_loss", epoch_dis_mean, epoch) 192 | writer.add_scalar("train/global_loss", epoch_global_mean, epoch) 193 | writer.add_scalar("train/local_loss", epoch_local_mean, epoch) 194 | writer.add_scalar("train/fusion_loss", epoch_fusion_mean, epoch) 195 | 196 | for cls in range(args.num_classes): 197 | writer.add_scalar(f"train/recall_{cls}", recall_classes[cls], epoch) 198 | writer.add_scalar("train/recall_mean", recall_mean, epoch) 199 | writer.add_scalar("train/acc", acc_mean, epoch) 200 | 201 | # ---------------------- validating -------------------------- 202 | model.eval() 203 | with torch.no_grad(): 204 | epoch_loss_val = 0 205 | epoch_glabol_loss_val = 0.0 206 | epoch_local_loss_val = 0.0 207 | epoch_cos_loss_val = 0.0 208 | epoch_dis_loss_val = 0.0 209 | epoch_fusion_loss_val = 0.0 210 | 211 | correct = {cls: 0.0 for cls in range(args.num_classes)} 212 | total = {cls: 0.0 for cls in range(args.num_classes)} 213 | 214 | for idx, (input1, input2, labels, target1, target2) in tqdm(enumerate(val_dataloaders), total=len(val_dataloaders)): 215 | input1, input2, labels, target1, target2 = \ 216 | input1.cuda(), input2.cuda(), labels.cuda(), target1.cuda(), target2.cuda() 217 | c1, c2, outputs, m1, m2, features = model(input1, input2, labels) 218 | loss1, loss2, loss3, loss4, loss5, loss = cal_loss(args, criterion, m1, m2, c1, c2, 219 | outputs, features, labels, target1, target2, epoch) 220 | current_batchsize = outputs.size()[0] 221 | epoch_loss_val += loss.item()*current_batchsize 222 | if (epoch > args.gan_epoch) and args.gan_opt: 223 | epoch_cos_loss_val += loss1.item()*current_batchsize 224 | epoch_dis_loss_val += loss2.item()*current_batchsize 225 | epoch_glabol_loss_val += loss4.item()*current_batchsize 226 | epoch_local_loss_val += loss5.item()*current_batchsize 227 | epoch_fusion_loss_val += loss3.item()*current_batchsize 228 | 229 | _, predicted = torch.max(outputs.data, 1) 230 | acc_temp = classacc(predicted, labels, args.num_classes) 231 | for cls in range(args.num_classes): 232 | correct[cls] += acc_temp[cls] 233 | total[cls] += (labels == cls).sum().float() 234 | 235 | correct_total = sum(correct.values()) 236 | sum_total = sum(total.values()) 237 | 238 | epochmean_val = epoch_loss_val/sum_total 239 | if epoch > args.gan_epoch and args.gan_opt: 240 | epochmean_cos_val = epoch_cos_loss_val/sum_total 241 | epochmean_dis_val = epoch_dis_loss_val/sum_total 242 | epochmean_global_val = epoch_glabol_loss_val / sum_total 243 | epochmean_local_val = epoch_local_loss_val / sum_total 244 | epochmean_fusion_val = epoch_fusion_loss_val / sum_total 245 | 246 | recall_classes = {cls: correct[cls] / total[cls] for cls in range(args.num_classes)} 247 | recall_mean = sum(recall_classes.values()) / 3.0 248 | acc_mean = correct_total / sum_total 249 | 250 | print(f"val_loss_mean_{epoch}: {epochmean_val:.4f}") 251 | if (epoch > args.gan_epoch) and args.gan_opt: 252 | print(f"val_cos_loss_mean_{epoch}: {epochmean_cos_val:.4f}") 253 | 
print(f"val_dis_loss_mean_{epoch}: {epochmean_dis_val:.4f}") 254 | print(f"val_global_loss_{epoch}: {epochmean_global_val:.4f}") 255 | print(f"val_local_loss_{epoch}: {epochmean_local_val:.4f}") 256 | print(f"val_fusion_loss_{epoch}: {epochmean_fusion_val:.4f}") 257 | 258 | for cls in range(args.num_classes): 259 | print(f"val_recall_class_{cls}_{epoch}: {recall_classes[cls]:.4f}") 260 | print(f"val_recall_class_mean_{epoch}: {recall_mean:.4f}") 261 | print(f"val_acc_{epoch}: {acc_mean:.4f}") 262 | 263 | ## tensorboard 264 | writer.add_scalar("val/loss", epochmean_val, epoch) 265 | if epoch > args.gan_epoch and args.gan_opt: 266 | writer.add_scalar("val/cos_loss", epochmean_cos_val, epoch) 267 | writer.add_scalar("val/dis_loss", epochmean_dis_val, epoch) 268 | writer.add_scalar("val/global_loss", epochmean_global_val, epoch) 269 | writer.add_scalar("val/local_loss", epochmean_local_val, epoch) 270 | writer.add_scalar("val/fusion_loss", epochmean_fusion_val, epoch) 271 | 272 | for cls in range(args.num_classes): 273 | writer.add_scalar(f"val/recall_{cls}", recall_classes[cls], epoch) 274 | writer.add_scalar("val/recall_mean", recall_mean, epoch) 275 | writer.add_scalar("val/acc", acc_mean, epoch) 276 | 277 | if acc_mean*0.1 + recall_mean*0.9 > best_val_acc: 278 | best_val_acc = acc_mean*0.1 + recall_mean*0.9 279 | 280 | # If using DP 281 | if torch.cuda.device_count() > 1: 282 | torch.save(model.module.state_dict(), model_path / '{}_{:.4f}.pth'.format(epoch, best_val_acc)) 283 | else: 284 | torch.save(model.state_dict(), model_path / '{}_{:.4f}.pth'.format(epoch, best_val_acc)) 285 | 286 | print("%2.2f sec(s)"%(time.time() - epoch_start_time)) 287 | 288 | # interval of saving model 289 | if (epoch % args.interval == 0 or epoch == num_epochs - 1): 290 | torch.save(model.module.state_dict(), model_path / '{}.pth'.format(epoch)) 291 | 292 | # Fix random seed for reproducibility 293 | def same_seeds(seed): 294 | torch.manual_seed(seed) 295 | if torch.cuda.is_available(): 296 | torch.cuda.manual_seed(seed) 297 | torch.cuda.manual_seed_all(seed) 298 | np.random.seed(seed) 299 | random.seed(seed) 300 | torch.backends.cudnn.benchmark = False 301 | torch.backends.cudnn.deterministic = True 302 | 303 | 304 | def main(): 305 | parser = argparse.ArgumentParser() 306 | parser.add_argument('--data_dir', type=str, default='./datasets/dataset1/global/train') 307 | parser.add_argument('--batch_size', type=int, default=256) 308 | parser.add_argument('--img_size', type=int, default=256) 309 | parser.add_argument('--epoch', type=int, default=60) 310 | parser.add_argument("--num_classes", type=int, default=3) 311 | 312 | parser.add_argument("--gan_opt", type=bool, default=False) 313 | parser.add_argument('--gan_epoch', type=int, default=10, help='epoch to start calculating GAN-like loss') 314 | parser.add_argument('--gan_weight', type=float, default=0.01) 315 | 316 | parser.add_argument('--AdamW', action='store_true') 317 | parser.add_argument('--warmup', action='store_true') 318 | parser.add_argument('--warmup_period', type=int, default=200) 319 | parser.add_argument('--lr', type=float, default=0.003) 320 | 321 | parser.add_argument('--global_weight', type=float, default=1.0) 322 | parser.add_argument('--local_weight', type=float, default=0.3) 323 | parser.add_argument('--fusion_weight', type=float, default=1.0) 324 | 325 | parser.add_argument("--pretrained", type=bool, default=True, help="whether to use pretrained models") 326 | parser.add_argument('--encoder', type=str, default='ResNet50', help="encoder 
name", 327 | choices=['ResNet18', 'ResNet34', 'ResNet50']) 328 | 329 | parser.add_argument('--save_path', type=str, default='./model_ours') 330 | parser.add_argument('--seed', type=int, default=42) 331 | parser.add_argument('--interval', type=int, default=5, help='interval of saving model during training') 332 | parser.add_argument('--devices', type=str, default="0, 1") 333 | 334 | args = parser.parse_args() 335 | 336 | save = Path(args.save_path) 337 | save.mkdir(parents=True, exist_ok=True) 338 | config_file = save / 'config.txt' 339 | config_items = [] 340 | for key, value in args.__dict__.items(): 341 | config_items.append(f'{key}: {value}\n') 342 | 343 | with open(config_file, 'w') as f: 344 | f.writelines(config_items) 345 | 346 | f.close() 347 | 348 | os.environ['CUDA_VISIBLE_DEVICES'] = args.devices 349 | 350 | same_seeds(args.seed) 351 | 352 | # num_features is fixed, representing the number of branches (global and local) 353 | if args.encoder == 'ResNet18': 354 | model = SAM_FNet18(num_classes=args.num_classes, num_features=2, pretrained=args.pretrained) 355 | elif args.encoder == 'ResNet34': 356 | model = SAM_FNet34(num_classes=args.num_classes, num_features=2, pretrained=args.pretrained) 357 | elif args.encoder == 'ResNet50': 358 | model = SAM_FNet50(num_classes=args.num_classes, num_features=2, pretrained=args.pretrained) 359 | model = model.cuda() 360 | if torch.cuda.device_count() > 1: 361 | model = DataParallel(model) 362 | 363 | criterion = { 364 | # classification loss for global, local, and fusion 365 | 'ce': nn.CrossEntropyLoss(), 366 | 367 | # GAN-like loss 368 | 'cs': nn.CosineEmbeddingLoss(), 369 | 'be': nn.BCEWithLogitsLoss(reduction='none') 370 | } 371 | 372 | data_transform = { 373 | 'train': transforms.Compose([ 374 | transforms.RandomAffine(degrees=10, scale=(0.9, 1.1), translate=(0.1, 0.1), shear=5.729578), 375 | transforms.RandomHorizontalFlip(p=0.5), 376 | transforms.ColorJitter(0.4, 0.4, 0.4, 0), 377 | transforms.Resize(args.img_size), 378 | transforms.CenterCrop(args.img_size), 379 | transforms.ToTensor(), 380 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]), 381 | 382 | 'val': transforms.Compose([ 383 | transforms.Resize(args.img_size), 384 | transforms.CenterCrop(args.img_size), 385 | transforms.ToTensor(), 386 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 387 | } 388 | 389 | 390 | train_dataset = LPCDataset(root = args.data_dir, transform = data_transform['train']) 391 | print('The length of training dataset:', len(train_dataset)) 392 | train_dataloaders = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=8, pin_memory=True) 393 | val_dataset = LPCDataset(root = args.data_dir.replace('train', 'val'), transform = data_transform['val']) 394 | print('The length of validating dataset:', len(val_dataset)) 395 | val_dataloaders = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=8, pin_memory=True) 396 | 397 | record_path = save / 'runs' 398 | model_path = save / 'weights' 399 | record_path.mkdir(parents=True, exist_ok=True) 400 | model_path.mkdir(parents=True, exist_ok=True) 401 | writer = SummaryWriter(record_path.as_posix()) 402 | train_model(args, model, criterion, train_dataloaders, val_dataloaders, args.epoch, model_path, writer) 403 | 404 | if __name__ == '__main__': 405 | main() -------------------------------------------------------------------------------- /val.py: -------------------------------------------------------------------------------- 1 
| import os 2 | import argparse 3 | import multiprocessing 4 | multiprocessing.set_start_method("spawn", True) 5 | 6 | from pathlib import Path 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.utils.data import DataLoader 11 | 12 | from model import SAM_FNet50, SAM_FNet18, SAM_FNet34 13 | from torchvision import transforms 14 | from dataset import LPCDataset 15 | from tqdm import tqdm 16 | import numpy as np 17 | from sklearn.metrics import classification_report 18 | 19 | import pandas as pd 20 | 21 | classes = { 22 | 0: 'normal', 23 | 1: 'benign', 24 | 2: 'tumor' 25 | } 26 | 27 | def count_metrics(plist, tlist, save_path): 28 | pred_np = np.array(plist) 29 | targets_np = np.array(tlist) 30 | 31 | report = classification_report(targets_np, pred_np, digits=4) 32 | print(report) 33 | 34 | # Save the classification report string to a file 35 | with open(save_path / 'classification_report.txt', 'w') as file: 36 | file.write(report) 37 | 38 | file.close() 39 | 40 | 41 | def count_pred(data): 42 | data_soft = F.softmax(data, dim=1) 43 | _, predicted = torch.max(data_soft.data, 1) 44 | return data_soft, predicted 45 | 46 | 47 | def test(args, model, val_dataloaders, save_path): 48 | model.eval() 49 | preds = [] 50 | targets = [] 51 | output_scores_list = [] 52 | 53 | with torch.no_grad(): 54 | for idx, (input1, input2, labels, _, _) in tqdm(enumerate(val_dataloaders), total=len(val_dataloaders)): 55 | input1, input2, labels = input1.cuda(), input2.cuda(), labels.cuda() 56 | o_g, o_l, o_f, _, _, _ = model(input1, input2, labels) 57 | 58 | if args.ensemble: 59 | output = (o_g + o_l + o_f) / 3.0 60 | else: 61 | output = o_f 62 | 63 | output, predicted = count_pred(output) 64 | preds.extend(predicted.cpu().numpy()) 65 | targets.extend(labels.cpu().numpy()) 66 | output_scores_list.extend(output.cpu().numpy()) 67 | 68 | preds = [classes[x] for x in preds] 69 | targets = [classes[x] for x in targets] 70 | results = pd.DataFrame({'preds': preds, 'targets': targets}) 71 | for i in range(3): 72 | results[f'class_{i}_score'] = [output_scores[i] for output_scores in output_scores_list] 73 | results.to_csv(save_path / "results.csv", index=False, header=True) 74 | 75 | count_metrics(preds, targets, save_path) 76 | 77 | if __name__ == '__main__': 78 | parser = argparse.ArgumentParser() 79 | parser.add_argument('--model_path', type=str, default='./model_ours/weights/46_0.9646.pth') 80 | parser.add_argument('--encoder', type=str, default='ResNet50', help="encoder name", choices=['ResNet18', 'ResNet34', 'ResNet50']) 81 | parser.add_argument('--num_classes', type=int, default=3, help='number of classes; must match the trained checkpoint') 82 | parser.add_argument('--dataset', type=str, default='dataset1') 83 | parser.add_argument('--batch_size', type=int, default=256) 84 | parser.add_argument('--img_size', type=int, default=256) 85 | parser.add_argument('--ensemble', type=bool, default=True) 86 | parser.add_argument('--save_path', type=str, default='./model_ours/') 87 | parser.add_argument('--devices', type=str, default='0,1') 88 | 89 | args = parser.parse_args() 90 | 91 | os.environ['CUDA_VISIBLE_DEVICES'] = args.devices 92 | 93 | model_path = args.model_path 94 | save_path = Path(args.save_path) 95 | save_path = save_path / args.dataset 96 | if not save_path.exists(): 97 | save_path.mkdir(parents=True, exist_ok=True) 98 | 99 | if args.encoder == 'ResNet18': 100 | model = SAM_FNet18(num_classes=args.num_classes, num_features=2, pretrained=False) 101 | elif args.encoder == 'ResNet34': 102 | model = SAM_FNet34(num_classes=args.num_classes, num_features=2, pretrained=False) 103 | elif args.encoder ==
'ResNet50': 104 | model = SAM_FNet50(num_classes=args.num_classes, num_features=2, pretrained=False) 105 | model.load_state_dict(torch.load(model_path)) 106 | model = model.cuda() 107 | 108 | transforms_val = transforms.Compose([ 109 | transforms.Resize(args.img_size), 110 | transforms.CenterCrop(args.img_size), 111 | transforms.ToTensor(), 112 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 113 | ]) 114 | 115 | test_dataset = LPCDataset(root=f'./datasets/{args.dataset}/global/test', transform=transforms_val) 116 | print('The length of testing dataset', len(test_dataset)) 117 | test_dataloaders = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=8) 118 | test(args, model, test_dataloaders, save_path) 119 | --------------------------------------------------------------------------------
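For a quick sanity check outside of `val.py`, the sketch below runs a single global/local image pair through a trained SAM-FNet checkpoint and averages the three heads, mirroring the `--ensemble` branch of `test()`. It is a minimal sketch: the checkpoint and image paths are placeholders, so point them at your own files.

```python
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

from model import SAM_FNet50

# Placeholder paths -- substitute your own checkpoint and global/local image pair.
CKPT = './model_ours/weights/46_0.9646.pth'
GLOBAL_IMG = './datasets/dataset1/global/test/tumor/example.jpg'
LOCAL_IMG = './datasets/dataset1/local_seg/test/tumor/example.jpg'

classes = {0: 'normal', 1: 'benign', 2: 'tumor'}

# Same preprocessing as the validation/test pipeline (img_size defaults to 256).
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SAM_FNet50(num_classes=3, num_features=2, pretrained=False)
model.load_state_dict(torch.load(CKPT, map_location=device))
model.to(device).eval()

# The model expects a (global, local) image pair; batch dimension of 1 here.
x_g = preprocess(Image.open(GLOBAL_IMG).convert('RGB')).unsqueeze(0).to(device)
x_l = preprocess(Image.open(LOCAL_IMG).convert('RGB')).unsqueeze(0).to(device)

with torch.no_grad():
    o_g, o_l, o_f, _, _, _ = model(x_g, x_l)
    logits = (o_g + o_l + o_f) / 3.0   # ensemble of global, local and fused heads
    probs = F.softmax(logits, dim=1)
    pred = probs.argmax(dim=1).item()

print(f'prediction: {classes[pred]}, scores: {probs.squeeze(0).tolist()}')
```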