├── modeling
│   ├── README.md
│   ├── __init__.py
│   ├── common.py
│   ├── mask_decoder.py
│   ├── sam.py
│   ├── transformer.py
│   ├── prompt_encoder.py
│   ├── image_encoder.py
│   └── tiny_vit_sam.py
├── utils
│   ├── README.md
│   └── transforms.py
├── loss.py
├── data_split.json
├── README.md
├── sppnet.py
├── dataloader.py
├── point.py
├── eval.py
├── train.py
├── train_adapter.py
└── LICENSE
/modeling/README.md:
--------------------------------------------------------------------------------
1 | 2 |
--------------------------------------------------------------------------------
/utils/README.md:
--------------------------------------------------------------------------------
1 | 2 |
--------------------------------------------------------------------------------
/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer 12 |
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class DiceLoss(nn.Module): 8 | def __init__(self, weight=None, size_average=True): 9 | super(DiceLoss, self).__init__() 10 | 11 | def forward(self, inputs, targets, smooth=1): 12 | 13 | inputs = torch.sigmoid(inputs) 14 | 15 | intersection = (inputs * targets).sum() 16 | dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth) 17 | 18 | return 1-dice 19 | 20 | class IoU(nn.Module): 21 | def __init__(self, weight=None, size_average=True): 22 | super(IoU, self).__init__() 23 | 24 | def forward(self, inputs, targets, smooth=1): 25 | 26 | inputs = torch.sigmoid(inputs) 27 | 28 | intersection = (inputs * targets).sum() 29 | total = (inputs + targets).sum() 30 | union = total - intersection 31 | 32 | IoU = (intersection + smooth) / (union + smooth) 33 | 34 | return IoU 35 |
--------------------------------------------------------------------------------
/data_split.json:
--------------------------------------------------------------------------------
1 | {"train": ["TCGA-50-5931-01Z-00-DX1.tif", "TCGA-XS-A8TJ-01Z-00-DX1.tif", "TCGA-B0-5711-01Z-00-DX1.tif", "TCGA-G9-6362-01Z-00-DX1.tif", "TCGA-MH-A561-01Z-00-DX1.tif", "TCGA-BC-A217-01Z-00-DX1.tif", "TCGA-49-4488-01Z-00-DX1.tif", "TCGA-HE-7130-01Z-00-DX1.tif", "TCGA-HE-7129-01Z-00-DX1.tif", "TCGA-38-6178-01Z-00-DX1.tif", "TCGA-UZ-A9PJ-01Z-00-DX1.tif", "TCGA-RD-A8N9-01A-01-TS1.tif", "TCGA-DK-A2I6-01A-01-TS1.tif", "TCGA-AR-A1AK-01Z-00-DX1.tif", "TCGA-F9-A8NY-01Z-00-DX1.tif", "TCGA-HE-7128-01Z-00-DX1.tif", "TCGA-G9-6356-01Z-00-DX1.tif", "TCGA-G9-6348-01Z-00-DX1.tif", "TCGA-E2-A1B5-01Z-00-DX1.tif", "TCGA-21-5784-01Z-00-DX1.tif", "TCGA-18-5592-01Z-00-DX1.tif", "TCGA-AY-A8YK-01A-01-TS1.tif", "TCGA-G9-6363-01Z-00-DX1.tif"], "valid": ["TCGA-21-5786-01Z-00-DX1.tif", "TCGA-CH-5767-01Z-00-DX1.tif", "TCGA-G9-6336-01Z-00-DX1.tif", "TCGA-E2-A14V-01Z-00-DX1.tif", "TCGA-B0-5710-01Z-00-DX1.tif", "TCGA-A7-A13E-01Z-00-DX1.tif", "TCGA-KB-A93J-01A-01-TS1.tif", "TCGA-B0-5698-01Z-00-DX1.tif", "TCGA-G2-A2EK-01A-02-TSB.tif", "TCGA-A7-A13F-01Z-00-DX1.tif", "TCGA-UZ-A9PN-01Z-00-DX1.tif", "TCGA-FG-A87N-01Z-00-DX1.tif", "TCGA-AR-A1AS-01Z-00-DX1.tif", "TCGA-NH-A8F7-01A-01-TS1.tif"], "test": ["TCGA-EJ-A46H-01A-03-TSC.tif", "TCGA-IZ-8196-01A-01-BS1.tif", "TCGA-44-2665-01B-06-BS6.tif", "TCGA-2Z-A9J9-01A-01-TS1.tif", "TCGA-AC-A2FO-01A-01-TS1.tif", "TCGA-A6-6782-01A-01-BS1.tif", "TCGA-CU-A0YN-01A-02-BSB.tif", "TCGA-HC-7209-01A-01-TS1.tif", "TCGA-HT-8564-01Z-00-DX1.tif", "TCGA-GL-6846-01A-01-BS1.tif", "TCGA-ZF-A9R5-01A-01-TS1.tif", "TCGA-AO-A0J2-01A-01-BSA.tif", "TCGA-69-7764-01A-01-TS1.tif", "TCGA-FG-A4MU-01B-01-TS1.tif"]}
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SPPNet: A Single-Point Prompt Network for Nuclei Image Segmentation (Boost SAM) 2 | 3 | ## News 4 | 2023.07.14: The SPPNet model and training code have been submitted. The paper will be updated later. 5 | 6 | 2023.08.24: The paper has been accepted by [MICCAI-MLMI 2023](https://sites.google.com/view/mlmi2023). The preprint is available at [arXiv](https://arxiv.org/abs/2308.12231). 7 | 8 | 2023.09.27: Released a new beta version for users who want to fine-tune the SAM pre-trained image encoder. We add an adapter based on [Medical-SAM-Adapter](https://github.com/WuJunde/Medical-SAM-Adapter). 9 | 10 | ## Requirements 11 | 1. pytorch==1.10.0 12 | 2. pytorch-lightning==1.1.0 13 | 3. albumentations==0.3.2 14 | 4. seaborn 15 | 5. scikit-learn 16 | 17 | ## Environment 18 | NVIDIA RTX 2080 Ti Tensor Core GPU, 4-core CPU, and 28 GB RAM 19 | 20 | ## Evaluation on MoNuSeg-2018 21 | 22 | | Method | mIoU(%) | DSC(%) | Params(M) | FLOPs | FPS | 23 | | ---- | ---- | ---- | ---- | ---- | ---- | 24 | | SAM (Fine-tuned) | 60.18±8.15 | 74.76±7.00 | 635.93 | 2736.63 | 1.39 | 25 | | SPPNet | 66.43±4.32 | 79.77±3.11 | 9.79 | 39.90 | 22.61 | 26 | 27 | ## Dataset 28 | To apply the model to a custom dataset, the data tree should be constructed as follows (the dataloader reads RGB images from `images/`, binary ground-truth masks as `.png` files from `masks/`, and instance-label `.npy` arrays for point sampling from `npy/`): 29 | ``` 30 | ├── data 31 |     ├── images 32 |         ├── image_1.png 33 |         ├── image_2.png 34 |         ├── image_n.png 35 |     ├── masks 36 |         ├── image_1.png 37 |         ├── image_2.png 38 |         ├── image_n.png 39 |     ├── npy 40 |         ├── image_1.npy 41 |         ├── image_2.npy 42 |         ├── image_n.npy 43 | ``` 44 | ## Train 45 | ``` 46 | python train.py --dataset your/data/path --jsonfile your/json/path --loss dice --batch 16 --lr 0.001 --epoch 50 47 | ``` 48 | ## Evaluation 49 | ``` 50 | python eval.py --dataset your/data/path --jsonfile your/json/path --model save_models/model_best.pth --debug True 51 | ``` 52 | ## Acknowledgement 53 | The codes are modified from [SAM](https://github.com/facebookresearch/segment-anything) and [MobileSAM](https://github.com/ChaoningZhang/MobileSAM). 54 | 55 |
--------------------------------------------------------------------------------
/modeling/common.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | class Adapter(nn.Module): 13 | def __init__(self, D_features, mlp_ratio=0.25, skip_connect=True): 14 | super().__init__() 15 | self.skip_connect = skip_connect 16 | D_hidden_features = int(D_features * mlp_ratio) 17 | self.act = nn.GELU() 18 | self.D_fc1 = nn.Linear(D_features, D_hidden_features) 19 | self.D_fc2 = nn.Linear(D_hidden_features, D_features) 20 | 21 | def forward(self, x): 22 | # x is (BT, HW+1, D) 23 | xs = self.D_fc1(x) 24 | xs = self.act(xs) 25 | xs = self.D_fc2(xs) 26 | if self.skip_connect: 27 | x = x + xs 28 | else: 29 | x = xs 30 | return x 31 | 32 | class MLPBlock(nn.Module): 33 | def __init__( 34 | self, 35 | embedding_dim: int, 36 | mlp_dim: int, 37 | act: Type[nn.Module] = nn.GELU, 38 | ) -> None: 39 | super().__init__() 40 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 41 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 42 | self.act = act() 43 | 44 | def forward(self, x: torch.Tensor) -> torch.Tensor: 45 | return self.lin2(self.act(self.lin1(x))) 46 | 47 | 48 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 49 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 50 | class LayerNorm2d(nn.Module): 51 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 52 | super().__init__() 53 | self.weight = nn.Parameter(torch.ones(num_channels)) 54 | self.bias = nn.Parameter(torch.zeros(num_channels)) 55 | self.eps = eps 56 | 57 | def forward(self, x: torch.Tensor) -> torch.Tensor: 58 | u = x.mean(1, keepdim=True) 59 | s = (x - u).pow(2).mean(1, keepdim=True) 60 | x = (x - u) / torch.sqrt(s + self.eps) 61 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 62 | return x 63 | -------------------------------------------------------------------------------- /utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to the longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 
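        Illustrative example, assuming target_length=1024: with
        original_size=(1000, 1000), a point at (x, y) = (500, 200) is scaled
        by 1024/1000 to approximately (512.0, 204.8).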
37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length. 
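        Illustrative example: oldh=1000, oldw=500, long_side_length=1024
        gives scale=1.024 and an output size of (1024, 512).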
97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | -------------------------------------------------------------------------------- /sppnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | from modeling.mask_decoder import MaskDecoder 5 | from modeling.prompt_encoder import PromptEncoder 6 | from modeling.transformer import TwoWayTransformer 7 | from modeling.tiny_vit_sam import TinyViT 8 | from utils.transforms import ResizeLongestSide 9 | from modeling.image_encoder import ImageEncoderViT 10 | from functools import partial 11 | 12 | class LLSIE(nn.Module): 13 | def __init__(self,in_channels, out_channels, kernel_size=3): 14 | super(LLSIE, self).__init__() 15 | self.input_layer = nn.Sequential( 16 | nn.Conv2d(in_channels, out_channels, kernel_size, padding= kernel_size // 2), 17 | nn.ReLU(inplace=True), 18 | nn.BatchNorm2d(out_channels)) 19 | self.depthwise = nn.Sequential( 20 | nn.Conv2d(out_channels, out_channels, kernel_size, groups=out_channels, padding= kernel_size // 2), 21 | nn.ReLU(inplace=True), 22 | nn.BatchNorm2d(out_channels)) 23 | self.pointwise = nn.Sequential( 24 | nn.Conv2d(out_channels, out_channels, kernel_size=1), 25 | nn.ReLU(inplace=True), 26 | nn.BatchNorm2d(out_channels)) 27 | def forward(self, x): 28 | x = self.input_layer(x) 29 | residual = x 30 | x = self.depthwise(x) 31 | x += residual 32 | x = self.pointwise(x) 33 | return x 34 | 35 | class Model(nn.Module): 36 | def __init__(self, image_encoder): 37 | super(Model, self).__init__() 38 | self.image_encoder = image_encoder 39 | # self.image_encoder = TinyViT(img_size=1024, in_chans=3, num_classes=1000, 40 | # embed_dims=[64, 128, 160, 320], 41 | # depths=[2, 2, 6, 2], 42 | # num_heads=[2, 4, 5, 10], 43 | # window_sizes=[7, 7, 14, 7], 44 | # mlp_ratio=4., 45 | # drop_rate=0., 46 | # drop_path_rate=0.0, 47 | # use_checkpoint=True, 48 | # mbconv_expand_ratio=4.0, 49 | # local_conv_size=3, 50 | # layer_lr_decay=0.8 51 | # ) 52 | self.prompt_encoder = PromptEncoder( 53 | embed_dim=256, 54 | image_embedding_size=(64, 64), # 1024 // 16 55 | input_image_size=(1024, 1024), 56 | mask_in_chans=16, 57 | ) 58 | self.mask_decoder = MaskDecoder( 59 | num_multimask_outputs=3, 60 | transformer=TwoWayTransformer( 61 | depth=2, 62 | embedding_dim=256, 63 | mlp_dim=2048, 64 | num_heads=8, 65 | ), 66 | transformer_dim=256, 67 | iou_head_depth=3, 68 | iou_head_hidden_dim=256, 69 | ) 70 | self.transform = ResizeLongestSide(1024) 71 | 72 | self.conv1 = LLSIE(3, 32) 73 | # self.maxpool = nn.MaxPool2d(kernel_size=2) 74 | 75 | def forward(self, x_resized, point_coords, point_labels, x, img_shape): 76 | 77 | low_level_infos = self.conv1(x_resized) 78 | 79 | image_embeddings = self.image_encoder(x) 80 | 81 | transformed_coords = self.transform.apply_coords_torch(point_coords, img_shape) 82 | 83 | outputs = [] 84 | 85 | for one_coords, one_label, one_x, lli in zip(transformed_coords, point_labels, image_embeddings, low_level_infos): 86 | # for one_coords, one_label, one_x in zip(transformed_coords, point_labels, image_embeddings): 87 | 88 | one_coords = one_coords.unsqueeze(0) 89 | one_label = one_label.unsqueeze(0) 90 | 91 | points = (one_coords, one_label) 92 | 93 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 94 | points=points, 95 | boxes=None, 96 | masks=None, 97 
| ) 98 | 99 | low_res_masks, iou_predictions = self.mask_decoder( 100 | image_embeddings=one_x.unsqueeze(0), 101 | image_pe=self.prompt_encoder.get_dense_pe(), 102 | sparse_prompt_embeddings=sparse_embeddings, 103 | dense_prompt_embeddings=dense_embeddings, 104 | multimask_output=False, 105 | low_level_info=lli, 106 | ) 107 | 108 | outputs.append(low_res_masks.squeeze(0)) 109 | 110 | return torch.stack(outputs, dim=0) 111 |
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | import os 2 | from skimage import io, transform, color,img_as_ubyte 3 | import numpy as np 4 | from torch.utils.data import Dataset 5 | import cv2 6 | import torch 7 | import albumentations as A 8 | from albumentations.pytorch import ToTensor 9 | from point import RandomExtractor, CentreExtractor, CNPS 10 | import torchvision.transforms as transforms 11 | import torch.nn.functional as F 12 | 13 | class sam_inputer(Dataset): 14 | def __init__(self,path,data, transform=None, pixel_mean=[123.675, 116.280, 103.530], pixel_std=[58.395, 57.12, 57.375]): 15 | self.path = path 16 | self.folders = data 17 | self.transforms = transform 18 | self.to_tesnor = transforms.Compose([transforms.ToTensor(), ]) 19 | self.pixel_mean = torch.Tensor(pixel_mean).view(-1, 1, 1) 20 | self.pixel_std = torch.Tensor(pixel_std).view(-1, 1, 1) 21 | self.img_size = 1024 22 | 23 | def __len__(self): 24 | return len(self.folders) 25 | 26 | 27 | def __getitem__(self,idx): 28 | image_id = list(self.folders[idx].split('.'))[0] 29 | image_path = os.path.join(self.path,'images/',self.folders[idx]) 30 | mask_path = os.path.join(self.path,'masks/',image_id) 31 | npy_path = os.path.join(self.path,'npy/',image_id) + '.npy' 32 | 33 | point_coord, point_class = CNPS(npy_path) 34 | point_coord = torch.tensor(point_coord) 35 | point_class = torch.tensor(point_class) 36 | 37 | img = io.imread(image_path)[:,:,:3].astype('float32') 38 | mask = io.imread(mask_path+'.png', as_gray=True) 39 | 40 | img_vit = self.to_tesnor(img) 41 | img_vit, h, w = self.preprocess(img_vit) 42 | 43 | augmented = self.transforms(image=img, mask=mask) 44 | img = augmented['image'] 45 | mask = augmented['mask'] 46 | 47 | return (img, point_coord, point_class, img_vit, mask, image_id, h, w) 48 | 49 | def preprocess(self, x): 50 | """Normalize pixel values and pad to a square input.""" 51 | # Normalize colors 52 | x = (x - self.pixel_mean) / self.pixel_std 53 | 54 | # Pad 55 | h, w = x.shape[-2:] 56 | padh = self.img_size - h 57 | padw = self.img_size - w 58 | x = F.pad(x, (0, padw, 0, padh)) 59 | 60 | return x, h, w 61 | 62 | # class sam_inputer(Dataset): 63 | # def __init__(self,path,data, transform=None): 64 | # self.path = path 65 | # self.folders = data 66 | # self.transforms = transform 67 | # self.to_tesnor = transforms.Compose([transforms.ToTensor(), ]) 68 | # self.pixel_mean = (123.675, 116.280, 103.530) 69 | # self.pixel_std = (58.395, 57.12, 57.375) 70 | # self.img_size = 1024 71 | 72 | # def __len__(self): 73 | # return len(self.folders) 74 | 75 | 76 | # def __getitem__(self,idx): 77 | # image_id = list(self.folders[idx].split('.'))[0] 78 | # image_path = os.path.join(self.path,'images/',self.folders[idx]) 79 | # pt_path = os.path.join(self.path,'features/',image_id) 80 | # mask_path = os.path.join(self.path,'masks/',image_id) 81 | # npy_path = os.path.join(self.path,'npy/',image_id) + '.npy' 82 | 83 | # point_coord, point_class = CNPS(npy_path) 84 | # 
point_coord = torch.tensor(point_coord) 85 | # point_class = torch.tensor(point_class) 86 | 87 | # img = io.imread(image_path)[:,:,:3].astype('float32') 88 | # mask = io.imread(mask_path+'.png', as_gray=True) 89 | 90 | 91 | # self.pixel_mean = np.array(self.pixel_mean, dtype=np.float32) 92 | # self.pixel_std = np.array(self.pixel_std, dtype=np.float32) 93 | # img_vit = (img - self.pixel_mean) / self.pixel_std 94 | 95 | # h, w = img_vit.shape[:-1] 96 | # padh = self.img_size - h 97 | # padw = self.img_size - w 98 | # img_vit = self.to_tesnor(img_vit) 99 | # img_vit = F.pad(img_vit, (0, padw, 0, padh)) 100 | 101 | # augmented = self.transforms(image=img, mask=mask) 102 | # img = augmented['image'] 103 | # mask = augmented['mask'] 104 | 105 | # # return (img, sam_feature, mask, image_id) 106 | # return (img, point_coord, point_class, img_vit, mask, image_id, h, w) 107 | -------------------------------------------------------------------------------- /point.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from scipy.ndimage import gaussian_filter 4 | 5 | 6 | 7 | def RandomExtractor(npy_path, pos_num = 1, neg_num = 1): 8 | # np.random.seed(42) 9 | point_coord = [] 10 | point_class = [] 11 | 12 | npy = np.load(npy_path) 13 | mask = npy[:,:,0] 14 | classes = npy[:,:,1] 15 | num_cell = list(set(mask.flatten()))[1:] 16 | 17 | if pos_num == -1: 18 | pos_num = len(num_cell) 19 | neg_num = len(num_cell) 20 | 21 | if pos_num == 0: 22 | cell_num = len(num_cell) 23 | pos_num = np.random.randint(cell_num+1,size=1)[0] 24 | neg_num = pos_num 25 | 26 | neg_coord = np.argwhere(mask == 0) 27 | neg_ids = np.random.randint(len(neg_coord+1),size=neg_num) 28 | for idx in neg_ids: 29 | row, col = neg_coord[idx][0], neg_coord[idx][1] 30 | point_coord.append([row, col]) 31 | point_class.append(classes[row][col]) 32 | 33 | pos_ids = np.random.choice(num_cell, pos_num, replace=False) 34 | for idx in pos_ids: 35 | pos_coord = np.argwhere(mask == idx) 36 | coord_idx = np.random.randint(len(pos_coord+1),size=1)[0] 37 | row, col = pos_coord[coord_idx][0], pos_coord[coord_idx][1] 38 | point_coord.append([col, row]) 39 | point_class.append(classes[row][col]) 40 | 41 | return point_coord, point_class 42 | 43 | 44 | def CentreExtractor(npy_path, pos_num = 1, neg_num = 1): 45 | np.random.seed(42) 46 | point_coord = [] 47 | point_class = [] 48 | 49 | npy = np.load(npy_path) 50 | mask = npy[:,:,0] 51 | classes = npy[:,:,1] 52 | num_cell = list(set(mask.flatten()))[1:] 53 | 54 | if pos_num == -1: 55 | pos_num = len(num_cell) 56 | neg_num = len(num_cell) 57 | 58 | if pos_num == 0: 59 | cell_num = len(num_cell) 60 | pos_num = np.random.randint(cell_num+1,size=1)[0] 61 | neg_num = pos_num 62 | 63 | neg_coord = np.argwhere(mask == 0) 64 | neg_ids = np.random.randint(len(neg_coord+1),size=neg_num) 65 | for idx in neg_ids: 66 | row, col = neg_coord[idx][0], neg_coord[idx][1] 67 | point_coord.append([col, row]) 68 | point_class.append(classes[row][col]) 69 | 70 | pos_ids = np.random.choice(num_cell, pos_num, replace=False) 71 | for idx in pos_ids: 72 | src = np.zeros(mask.shape, np.uint8) 73 | src[mask == idx] = 255 74 | dist = cv2.distanceTransform(src, cv2.DIST_L1, 3) 75 | pos_coord = np.argwhere(dist == dist.max()) 76 | row, col = pos_coord[0][0], pos_coord[0][1] 77 | point_coord.append([col, row]) 78 | point_class.append(classes[row][col]) 79 | 80 | 81 | return point_coord, point_class 82 | 83 | def CNPS(npy_path, pos_num = 1, neg_num = 1, is_center = 
False): 84 | np.random.seed(42) 85 | point_coord = [] 86 | point_class = [] 87 | 88 | npy = np.load(npy_path) 89 | mask = npy[:,:,0] 90 | classes = npy[:,:,1] 91 | num_cell = list(set(mask.flatten()))[1:] 92 | 93 | if pos_num == -1: 94 | pos_num = len(num_cell) 95 | neg_num = len(num_cell) 96 | 97 | if pos_num == 0: 98 | cell_num = len(num_cell) 99 | pos_num = np.random.randint(cell_num+1,size=1)[0] 100 | neg_num = pos_num 101 | 102 | neg_coord = np.argwhere(mask == 0) 103 | neg_ids = np.random.randint(len(neg_coord+1),size=neg_num) 104 | for idx in neg_ids: 105 | row, col = neg_coord[idx][0], neg_coord[idx][1] 106 | point_coord.append([col, row]) 107 | point_class.append(classes[row][col]) 108 | 109 | pos_ids = np.random.choice(num_cell, pos_num, replace=True) 110 | for idx in pos_ids: 111 | src = np.zeros(mask.shape, np.uint8) 112 | src[mask == idx] = 255 113 | dist = cv2.distanceTransform(src, cv2.DIST_L1, 3) 114 | pos_coord = np.argwhere(dist == dist.max()) 115 | row, col = pos_coord[0][0], pos_coord[0][1] 116 | 117 | if is_center: 118 | point_coord.append([col, row]) 119 | point_class.append(classes[row][col]) 120 | else: 121 | target = np.zeros(mask.shape, dtype=np.uint8) 122 | target[row][col] = 255 123 | target = 100.0 * (target[:,:] > 0) 124 | target = gaussian_filter(target, sigma=(1, 1), mode='nearest', radius=2) 125 | #numpy 1.21.5, scipy 1.7,1 126 | coords = np.argwhere(target != 0) 127 | 128 | normalised = (target - target.min()) / (target.max() - target.min()) 129 | normalised *= 128 130 | normalised[normalised !=0] += 128 131 | 132 | coords_randm = np.random.randint(len(coords+1),size=len(coords)) 133 | for coord_id in coords_randm: 134 | x, y = coords[coord_id] 135 | # print(row, col) 136 | if src[x][y] != 0: 137 | point_coord.append([y, x]) 138 | point_class.append(classes[x][y]) 139 | break 140 | 141 | # return normalised, point_coord, point_class 142 | # return (col, row), point_coord, point_class 143 | return point_coord, point_class 144 | 145 | 146 | 147 | 148 | 149 | if __name__ == '__main__': 150 | npy_path = 'monuseg/data/npy/TCGA-2Z-A9J9-01A-01-TS1.npy' 151 | mask_path = 'monuseg/data/masks/TCGA-2Z-A9J9-01A-01-TS1.png' 152 | point_coord, point_class = CentreExtractor(npy_path) 153 | 154 | print(point_coord) 155 | print(point_class) 156 | 157 | mask = cv2.imread(mask_path, 0) 158 | print(mask.shape) 159 | 160 | for coord in point_coord: 161 | cv2.circle(mask, coord, 1, 128, 1) 162 | 163 | cv2.imwrite('point_visual.png', mask) 164 | 165 | 166 | tmp, point_coord, point_class = CNPS(npy_path) 167 | 168 | print(point_coord) 169 | print(point_class) 170 | 171 | mask = cv2.imread(mask_path, 0) 172 | # print(mask.shape) 173 | # print(tmp.astype(np.uint8)) 174 | mask += tmp.astype(np.uint8) 175 | # print(set(tmp.flatten())) 176 | # for coord in point_coord: 177 | # cv2.circle(mask, coord, 1, 128, 2) 178 | 179 | # cv2.circle(mask, tmp, 1, 128, 1) 180 | 181 | cv2.imwrite('point_visual2.png', mask) 182 | 183 | 184 | 185 | 186 | 187 | 188 | -------------------------------------------------------------------------------- /modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
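# Note: relative to the original SAM mask decoder, this copy is modified for
# SPPNet: forward() and predict_masks() take an extra `low_level_info` tensor
# (the output of the LLSIE block in sppnet.py), which is added to the upscaled
# mask embedding before the hypernetwork MLPs produce the masks.
# modeling/sam.py still calls the decoder without this argument, so only the
# sppnet.py path exercises it.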
6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import List, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class MaskDecoder(nn.Module): 17 | def __init__( 18 | self, 19 | *, 20 | transformer_dim: int, 21 | transformer: nn.Module, 22 | num_multimask_outputs: int = 3, 23 | activation: Type[nn.Module] = nn.GELU, 24 | iou_head_depth: int = 3, 25 | iou_head_hidden_dim: int = 256, 26 | ) -> None: 27 | """ 28 | Predicts masks given an image and prompt embeddings, using a 29 | transformer architecture. 30 | 31 | Arguments: 32 | transformer_dim (int): the channel dimension of the transformer 33 | transformer (nn.Module): the transformer used to predict masks 34 | num_multimask_outputs (int): the number of masks to predict 35 | when disambiguating masks 36 | activation (nn.Module): the type of activation to use when 37 | upscaling masks 38 | iou_head_depth (int): the depth of the MLP used to predict 39 | mask quality 40 | iou_head_hidden_dim (int): the hidden dimension of the MLP 41 | used to predict mask quality 42 | """ 43 | super().__init__() 44 | self.transformer_dim = transformer_dim 45 | self.transformer = transformer 46 | 47 | self.num_multimask_outputs = num_multimask_outputs 48 | 49 | self.iou_token = nn.Embedding(1, transformer_dim) 50 | self.num_mask_tokens = num_multimask_outputs + 1 51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 52 | 53 | self.output_upscaling = nn.Sequential( 54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 55 | LayerNorm2d(transformer_dim // 4), 56 | activation(), 57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 58 | activation(), 59 | ) 60 | self.output_hypernetworks_mlps = nn.ModuleList( 61 | [ 62 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 63 | for i in range(self.num_mask_tokens) 64 | ] 65 | ) 66 | 67 | self.iou_prediction_head = MLP( 68 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 69 | ) 70 | 71 | def forward( 72 | self, 73 | image_embeddings: torch.Tensor, 74 | image_pe: torch.Tensor, 75 | sparse_prompt_embeddings: torch.Tensor, 76 | dense_prompt_embeddings: torch.Tensor, 77 | multimask_output: bool, 78 | low_level_info: torch.Tensor, 79 | ) -> Tuple[torch.Tensor, torch.Tensor]: 80 | """ 81 | Predict masks given image and prompt embeddings. 82 | 83 | Arguments: 84 | image_embeddings (torch.Tensor): the embeddings from the image encoder 85 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 86 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 87 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 88 | multimask_output (bool): Whether to return multiple masks or a single 89 | mask. 
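          low_level_info (torch.Tensor): low-level feature map produced by the
            LLSIE block in sppnet.py; it is added element-wise to the upscaled
            mask embedding, so its shape must broadcast with that embedding
            (an SPPNet addition to the original SAM decoder).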
90 | 91 | Returns: 92 | torch.Tensor: batched predicted masks 93 | torch.Tensor: batched predictions of mask quality 94 | """ 95 | masks, iou_pred = self.predict_masks( 96 | image_embeddings=image_embeddings, 97 | image_pe=image_pe, 98 | sparse_prompt_embeddings=sparse_prompt_embeddings, 99 | dense_prompt_embeddings=dense_prompt_embeddings, 100 | low_level_info=low_level_info, 101 | ) 102 | 103 | # Select the correct mask or masks for output 104 | if multimask_output: 105 | mask_slice = slice(1, None) 106 | else: 107 | mask_slice = slice(0, 1) 108 | masks = masks[:, mask_slice, :, :] 109 | iou_pred = iou_pred[:, mask_slice] 110 | 111 | # Prepare output 112 | return masks, iou_pred 113 | 114 | def predict_masks( 115 | self, 116 | image_embeddings: torch.Tensor, 117 | image_pe: torch.Tensor, 118 | sparse_prompt_embeddings: torch.Tensor, 119 | dense_prompt_embeddings: torch.Tensor, 120 | low_level_info: torch.Tensor, 121 | ) -> Tuple[torch.Tensor, torch.Tensor]: 122 | """Predicts masks. See 'forward' for more details.""" 123 | # Concatenate output tokens 124 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) 125 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) 126 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 127 | 128 | # Expand per-image data in batch direction to be per-mask 129 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 130 | src = src + dense_prompt_embeddings 131 | 132 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 133 | b, c, h, w = src.shape 134 | 135 | # Run the transformer 136 | hs, src = self.transformer(src, pos_src, tokens) 137 | iou_token_out = hs[:, 0, :] 138 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 139 | 140 | # Upscale mask embeddings and predict masks using the mask tokens 141 | src = src.transpose(1, 2).view(b, c, h, w) 142 | 143 | upscaled_embedding = self.output_upscaling(src) 144 | upscaled_embedding += low_level_info 145 | # print(upscaled_embedding.shape) 146 | hyper_in_list: List[torch.Tensor] = [] 147 | for i in range(self.num_mask_tokens): 148 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) 149 | hyper_in = torch.stack(hyper_in_list, dim=1) 150 | b, c, h, w = upscaled_embedding.shape 151 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 152 | # print(masks.shape) 153 | 154 | # Generate mask quality predictions 155 | iou_pred = self.iou_prediction_head(iou_token_out) 156 | 157 | return masks, iou_pred 158 | 159 | 160 | # Lightly adapted from 161 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 162 | class MLP(nn.Module): 163 | def __init__( 164 | self, 165 | input_dim: int, 166 | hidden_dim: int, 167 | output_dim: int, 168 | num_layers: int, 169 | sigmoid_output: bool = False, 170 | ) -> None: 171 | super().__init__() 172 | self.num_layers = num_layers 173 | h = [hidden_dim] * (num_layers - 1) 174 | self.layers = nn.ModuleList( 175 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 176 | ) 177 | self.sigmoid_output = sigmoid_output 178 | 179 | def forward(self, x): 180 | for i, layer in enumerate(self.layers): 181 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 182 | if self.sigmoid_output: 183 | x = F.sigmoid(x) 184 | return x 185 | -------------------------------------------------------------------------------- 
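Below is a minimal, hypothetical shape-check sketch of this decoder as sppnet.py configures it; random tensors stand in for the real image embedding, prompt embeddings, and LLSIE features, and the two sparse tokens assume a single point prompt plus a padding token:
```python
import torch
from modeling.mask_decoder import MaskDecoder
from modeling.transformer import TwoWayTransformer

decoder = MaskDecoder(
    num_multimask_outputs=3,
    transformer=TwoWayTransformer(depth=2, embedding_dim=256, mlp_dim=2048, num_heads=8),
    transformer_dim=256,
)

# Shapes follow sppnet.py: one image, a 256-channel 64x64 image embedding,
# and a 32x256x256 low-level feature map from the LLSIE block.
image_embeddings = torch.randn(1, 256, 64, 64)
image_pe = torch.randn(1, 256, 64, 64)          # normally prompt_encoder.get_dense_pe()
sparse_prompts = torch.randn(1, 2, 256)         # normally from the prompt encoder
dense_prompts = torch.randn(1, 256, 64, 64)
low_level_info = torch.randn(1, 32, 256, 256)   # transformer_dim // 8 channels

masks, iou_pred = decoder(
    image_embeddings=image_embeddings,
    image_pe=image_pe,
    sparse_prompt_embeddings=sparse_prompts,
    dense_prompt_embeddings=dense_prompts,
    multimask_output=False,
    low_level_info=low_level_info,
)
print(masks.shape, iou_pred.shape)  # torch.Size([1, 1, 256, 256]) torch.Size([1, 1])
```
The 1x1x256x256 output is what sppnet.py squeezes and stacks into the batch returned to the training loop.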
/modeling/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Any, Dict, List, Tuple 12 | 13 | from .image_encoder import ImageEncoderViT 14 | from .mask_decoder import MaskDecoder 15 | from .prompt_encoder import PromptEncoder 16 | 17 | 18 | class Sam(nn.Module): 19 | mask_threshold: float = 0.0 20 | image_format: str = "RGB" 21 | 22 | def __init__( 23 | self, 24 | image_encoder: ImageEncoderViT, 25 | prompt_encoder: PromptEncoder, 26 | mask_decoder: MaskDecoder, 27 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 28 | pixel_std: List[float] = [58.395, 57.12, 57.375], 29 | ) -> None: 30 | """ 31 | SAM predicts object masks from an image and input prompts. 32 | 33 | Arguments: 34 | image_encoder (ImageEncoderViT): The backbone used to encode the 35 | image into image embeddings that allow for efficient mask prediction. 36 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 37 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 38 | and encoded prompts. 39 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 40 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 41 | """ 42 | super().__init__() 43 | self.image_encoder = image_encoder 44 | self.prompt_encoder = prompt_encoder 45 | self.mask_decoder = mask_decoder 46 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) 47 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 48 | 49 | @property 50 | def device(self) -> Any: 51 | return self.pixel_mean.device 52 | 53 | @torch.no_grad() 54 | def forward( 55 | self, 56 | batched_input: List[Dict[str, Any]], 57 | multimask_output: bool, 58 | ) -> List[Dict[str, torch.Tensor]]: 59 | """ 60 | Predicts masks end-to-end from provided images and prompts. 61 | If prompts are not known in advance, using SamPredictor is 62 | recommended over calling the model directly. 63 | 64 | Arguments: 65 | batched_input (list(dict)): A list over input images, each a 66 | dictionary with the following keys. A prompt key can be 67 | excluded if it is not present. 68 | 'image': The image as a torch tensor in 3xHxW format, 69 | already transformed for input to the model. 70 | 'original_size': (tuple(int, int)) The original size of 71 | the image before transformation, as (H, W). 72 | 'point_coords': (torch.Tensor) Batched point prompts for 73 | this image, with shape BxNx2. Already transformed to the 74 | input frame of the model. 75 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 76 | with shape BxN. 77 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 78 | Already transformed to the input frame of the model. 79 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 80 | in the form Bx1xHxW. 81 | multimask_output (bool): Whether the model should predict multiple 82 | disambiguating masks, or return a single mask. 83 | 84 | Returns: 85 | (list(dict)): A list over input images, where each element is 86 | as dictionary with the following keys. 
87 | 'masks': (torch.Tensor) Batched binary mask predictions, 88 | with shape BxCxHxW, where B is the number of input prompts, 89 | C is determined by multimask_output, and (H, W) is the 90 | original size of the image. 91 | 'iou_predictions': (torch.Tensor) The model's predictions 92 | of mask quality, in shape BxC. 93 | 'low_res_logits': (torch.Tensor) Low resolution logits with 94 | shape BxCxHxW, where H=W=256. Can be passed as mask input 95 | to subsequent iterations of prediction. 96 | """ 97 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) 98 | image_embeddings = self.image_encoder(input_images) 99 | 100 | outputs = [] 101 | for image_record, curr_embedding in zip(batched_input, image_embeddings): 102 | if "point_coords" in image_record: 103 | points = (image_record["point_coords"], image_record["point_labels"]) 104 | else: 105 | points = None 106 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 107 | points=points, 108 | boxes=image_record.get("boxes", None), 109 | masks=image_record.get("mask_inputs", None), 110 | ) 111 | low_res_masks, iou_predictions = self.mask_decoder( 112 | image_embeddings=curr_embedding.unsqueeze(0), 113 | image_pe=self.prompt_encoder.get_dense_pe(), 114 | sparse_prompt_embeddings=sparse_embeddings, 115 | dense_prompt_embeddings=dense_embeddings, 116 | multimask_output=multimask_output, 117 | ) 118 | masks = self.postprocess_masks( 119 | low_res_masks, 120 | input_size=image_record["image"].shape[-2:], 121 | original_size=image_record["original_size"], 122 | ) 123 | masks = masks > self.mask_threshold 124 | outputs.append( 125 | { 126 | "masks": masks, 127 | "iou_predictions": iou_predictions, 128 | "low_res_logits": low_res_masks, 129 | } 130 | ) 131 | return outputs 132 | 133 | def postprocess_masks( 134 | self, 135 | masks: torch.Tensor, 136 | input_size: Tuple[int, ...], 137 | original_size: Tuple[int, ...], 138 | ) -> torch.Tensor: 139 | """ 140 | Remove padding and upscale masks to the original image size. 141 | 142 | Arguments: 143 | masks (torch.Tensor): Batched masks from the mask_decoder, 144 | in BxCxHxW format. 145 | input_size (tuple(int, int)): The size of the image input to the 146 | model, in (H, W) format. Used to remove padding. 147 | original_size (tuple(int, int)): The original size of the image 148 | before resizing for input to the model, in (H, W) format. 149 | 150 | Returns: 151 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 152 | is given by original_size. 
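        Illustrative example, assuming an image encoder img_size of 1024:
        masks of shape 1x1x256x256 for an original 1500x1000 image
        (input_size (1024, 683) after ResizeLongestSide) are upsampled to
        1024x1024, cropped to 1024x683 to drop the padding, then resized to
        1500x1000.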
153 | """ 154 | masks = F.interpolate( 155 | masks, 156 | (self.image_encoder.img_size, self.image_encoder.img_size), 157 | mode="bilinear", 158 | align_corners=False, 159 | ) 160 | masks = masks[..., : input_size[0], : input_size[1]] 161 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 162 | return masks 163 | 164 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 165 | """Normalize pixel values and pad to a square input.""" 166 | # Normalize colors 167 | x = (x - self.pixel_mean) / self.pixel_std 168 | 169 | # Pad 170 | h, w = x.shape[-2:] 171 | padh = self.image_encoder.img_size - h 172 | padw = self.image_encoder.img_size - w 173 | x = F.pad(x, (0, padw, 0, padh)) 174 | return x 175 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.autograd import Variable 4 | import torch.nn as nn 5 | from dataloader import sam_inputer 6 | import albumentations as A 7 | from albumentations.pytorch import ToTensor 8 | from pytorch_lightning.metrics import Accuracy, Precision, Recall, F1 9 | import argparse 10 | import time 11 | import pandas as pd 12 | import cv2 13 | import os 14 | from skimage import io, transform 15 | from PIL import Image 16 | import json 17 | from tqdm import tqdm 18 | import sppnet 19 | import torch.nn.functional as F 20 | from typing import Any, Dict, List, Tuple 21 | 22 | def postprocess_masks( 23 | masks: torch.Tensor, 24 | input_size: (Tuple[int, ...]), 25 | original_size: (Tuple[int, ...]), 26 | ) -> torch.Tensor: 27 | """ 28 | Remove padding and upscale masks to the original image size. 29 | 30 | Arguments: 31 | masks (torch.Tensor): Batched masks from the mask_decoder, 32 | in BxCxHxW format. 33 | input_size (tuple(int, int)): The size of the image input to the 34 | model, in (H, W) format. Used to remove padding. 35 | original_size (tuple(int, int)): The original size of the image 36 | before resizing for input to the model, in (H, W) format. 37 | 38 | Returns: 39 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 40 | is given by original_size. 
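    Note: in this script the encoder resolution is hard-coded to 1024 and the
    MoNuSeg tiles are 1000x1000, so the call below passes
    input_size = original_size = (1000, 1000) and the function effectively
    just upsamples the 256x256 logits.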
41 | """ 42 | masks = F.interpolate( 43 | masks, 44 | (1024, 1024), 45 | mode="bilinear", 46 | align_corners=False, 47 | ) 48 | masks = masks[..., : input_size[0], : input_size[1]] 49 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 50 | return masks 51 | 52 | class IoU(nn.Module): 53 | def __init__(self, weight=None, size_average=True): 54 | super(IoU, self).__init__() 55 | 56 | def forward(self, inputs, targets, smooth=1): 57 | 58 | inputs = inputs.view(-1) 59 | targets = targets.view(-1) 60 | 61 | intersection = (inputs * targets).sum() 62 | total = (inputs + targets).sum() 63 | union = total - intersection 64 | 65 | IoU = (intersection + smooth)/(union + smooth) 66 | 67 | return IoU 68 | 69 | class Dice(nn.Module): 70 | def __init__(self, weight=None, size_average=True): 71 | super(Dice, self).__init__() 72 | 73 | def forward(self, inputs, targets, smooth=1): 74 | 75 | intersection = (inputs * targets).sum() 76 | dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth) 77 | 78 | return dice 79 | 80 | def get_transform(): 81 | return A.Compose( 82 | [ 83 | A.Resize(256, 256), 84 | A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 85 | ToTensor() 86 | ]) 87 | if __name__ == '__main__': 88 | parser = argparse.ArgumentParser() 89 | parser.add_argument('--dataset', default='monuseg/images/',type=str, help='the path of dataset') 90 | parser.add_argument('--jsonfile', default='data_split.json',type=str, help='') 91 | parser.add_argument('--model',default='save_models/model_best.pth', type=str, help='the path of model') 92 | parser.add_argument('--debug',default=True, type=bool, help='plot mask') 93 | args = parser.parse_args() 94 | 95 | os.makedirs('debug/',exist_ok=True) 96 | 97 | with open(args.jsonfile, 'r') as f: 98 | df = json.load(f) 99 | 100 | test_files = df['test'] 101 | 102 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 103 | 104 | test_dataset = sam_inputer(args.dataset,test_files, get_transform()) 105 | 106 | model = sppnet.Model() 107 | model.load_state_dict(torch.load(args.model)) 108 | 109 | model = model.cuda() 110 | 111 | acc_eval = Accuracy() 112 | pre_eval = Precision() 113 | dice_eval = Dice() 114 | recall_eval = Recall() 115 | f1_eval = F1(2) 116 | iou_eval = IoU() 117 | iou_score = [] 118 | acc_score = [] 119 | pre_score = [] 120 | recall_score = [] 121 | f1_score = [] 122 | dice_score = [] 123 | time_cost = [] 124 | id_lists = [] 125 | 126 | since = time.time() 127 | if args.debug: 128 | for image_id in test_files: 129 | img = cv2.imread(f'{args.dataset}images/{image_id}') 130 | img = cv2.resize(img, ((256,256))) 131 | img_id = list(image_id.split('.'))[0] 132 | cv2.imwrite(f'debug/{img_id}.png',img) 133 | 134 | with torch.no_grad(): 135 | for img, point_coord, point_class, img_vit, mask, img_id, h, w in tqdm(test_dataset): 136 | 137 | point_coord = Variable(torch.unsqueeze(point_coord, dim=0), requires_grad=False).cuda() 138 | point_class = Variable(torch.unsqueeze(point_class, dim=0), requires_grad=False).cuda() 139 | img_vit = Variable(torch.unsqueeze(img_vit, dim=0), requires_grad=False).cuda() 140 | img = Variable(torch.unsqueeze(img, dim=0), requires_grad=False).cuda() 141 | mask = Variable(torch.unsqueeze(mask, dim=0), requires_grad=False).cuda() 142 | 143 | torch.cuda.synchronize() 144 | start = time.time() 145 | pred = model(img, point_coord, point_class, img_vit, (h, w)) 146 | torch.cuda.synchronize() 147 | end = time.time() 148 | time_cost.append(end-start) 149 | 
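            # pred_tmp upsamples the 256x256 logits to full resolution but is not
            # used further; the metrics below are computed on the low-resolution
            # prediction after sigmoid and 0.5 thresholding.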
150 | pred_tmp = postprocess_masks(pred, (1000, 1000), (1000, 1000)) 151 | # print(pred_tmp) 152 | 153 | pred = torch.sigmoid(pred) 154 | 155 | pred[pred >= 0.5] = 1 156 | pred[pred < 0.5] = 0 157 | 158 | 159 | pred_draw = pred.clone().detach() 160 | mask_draw = mask.clone().detach() 161 | 162 | 163 | if args.debug: 164 | img_id = list(img_id.split('.'))[0] 165 | img_numpy = pred_draw.cpu().detach().numpy()[0][0] 166 | img_numpy[img_numpy==1] = 255 167 | cv2.imwrite(f'debug/{img_id}_pred.png',img_numpy) 168 | 169 | mask_numpy = mask_draw.cpu().detach().numpy()[0][0] 170 | mask_numpy[mask_numpy==1] = 255 171 | cv2.imwrite(f'debug/{img_id}_gt.png',mask_numpy) 172 | iouscore = iou_eval(pred,mask) 173 | dicescore = dice_eval(pred,mask) 174 | pred = pred.view(-1) 175 | mask = mask.view(-1) 176 | 177 | accscore = acc_eval(pred.cpu(),mask.cpu()) 178 | prescore = pre_eval(pred.cpu(),mask.cpu()) 179 | recallscore = recall_eval(pred.cpu(),mask.cpu()) 180 | f1score = f1_eval(pred.cpu(),mask.cpu()) 181 | iou_score.append(iouscore.cpu().detach().numpy()) 182 | dice_score.append(dicescore.cpu().detach().numpy()) 183 | acc_score.append(accscore.cpu().detach().numpy()) 184 | pre_score.append(prescore.cpu().detach().numpy()) 185 | recall_score.append(recallscore.cpu().detach().numpy()) 186 | f1_score.append(f1score.cpu().detach().numpy()) 187 | id_lists.append(img_id) 188 | torch.cuda.empty_cache() 189 | 190 | time_elapsed = time.time() - since 191 | 192 | result_dict = {'image_id':id_lists, 'miou':iou_score, 'dice':dice_score} 193 | result_df = pd.DataFrame(result_dict) 194 | result_df.to_csv('best.csv',index=False) 195 | 196 | print('Evaluation complete in {:.0f}m {:.0f}s'.format( 197 | time_elapsed // 60, time_elapsed % 60)) 198 | print('FPS: {:.2f}'.format(1.0/(sum(time_cost)/len(time_cost)))) 199 | print('mean IoU:',round(np.mean(iou_score),4),round(np.std(iou_score),4)) 200 | print('mean accuracy:',round(np.mean(acc_score),4),round(np.std(acc_score),4)) 201 | print('mean precsion:',round(np.mean(pre_score),4),round(np.std(pre_score),4)) 202 | print('mean recall:',round(np.mean(recall_score),4),round(np.std(recall_score),4)) 203 | print('mean F1-score:',round(np.mean(f1_score),4),round(np.std(f1_score),4)) 204 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import torch 4 | from torch.autograd import Variable 5 | import torch.nn as nn 6 | from torch import optim 7 | import time 8 | import albumentations as A 9 | from albumentations.pytorch import ToTensor 10 | from torch.utils.data import random_split 11 | from torch.optim import lr_scheduler 12 | import seaborn as sns 13 | import pandas as pd 14 | import argparse 15 | import os 16 | from dataloader import sam_inputer 17 | from sklearn.model_selection import GroupKFold 18 | from loss import * 19 | from tqdm import tqdm 20 | import json 21 | import sppnet 22 | from modeling.tiny_vit_sam import TinyViT 23 | 24 | 25 | def get_train_transform(): 26 | return A.Compose( 27 | [ 28 | A.Resize(256, 256), 29 | # A.HorizontalFlip(p=0.25), 30 | # A.RandomBrightness(p=0.25), 31 | # A.ShiftScaleRotate(shift_limit=0,p=0.25), 32 | # A.CoarseDropout(), 33 | A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 34 | ToTensor() 35 | ]) 36 | 37 | def get_valid_transform(): 38 | return A.Compose( 39 | [ 40 | A.Resize(256, 256), 41 | A.Normalize(mean=(0.485, 0.456, 0.406), 
std=(0.229, 0.224, 0.225)), 42 | ToTensor() 43 | ]) 44 | 45 | 46 | def train_model(model, criterion, optimizer, scheduler, num_epochs=5): 47 | since = time.time() 48 | 49 | Loss_list = {'train': [], 'valid': []} 50 | Accuracy_list = {'train': [], 'valid': []} 51 | 52 | best_model_wts = model.state_dict() 53 | 54 | best_loss = float('inf') 55 | counter = 0 56 | 57 | for epoch in range(num_epochs): 58 | print('Epoch {}/{}'.format(epoch, num_epochs - 1)) 59 | print('-' * 10) 60 | 61 | # Each epoch has a training and validation phase 62 | for phase in ['train', 'valid']: 63 | if phase == 'train': 64 | model.train(True) 65 | 66 | else: 67 | model.train(False) 68 | 69 | running_loss = [] 70 | running_corrects = [] 71 | 72 | # Iterate over data 73 | #for inputs,labels,label_for_ce,image_id in dataloaders[phase]: 74 | for img, point_coord, point_class, img_vit, labels, _, h, w in tqdm(dataloaders[phase]): 75 | # wrap them in Variable 76 | if torch.cuda.is_available(): 77 | 78 | point_coord = Variable(point_coord.cuda()) 79 | point_class = Variable(point_class.cuda()) 80 | img_vit = Variable(img_vit.cuda()) 81 | img = Variable(img.cuda()) 82 | labels = Variable(labels.cuda()) 83 | #label_for_ce = Variable(label_for_ce.cuda()) 84 | else: 85 | img, point_coord, point_class, img_vit, labels = Variable(img), Variable(point_coord), Variable(point_class), Variable(img_vit), Variable(labels) 86 | 87 | # zero the parameter gradients 88 | optimizer.zero_grad() 89 | #label_for_ce = label_for_ce.long() 90 | # forward 91 | outputs = model(img, point_coord, point_class, img_vit, (h[0].item(), w[0].item())) 92 | # print(outputs) 93 | 94 | loss = criterion(outputs, labels) 95 | score = accuracy_metric(outputs,labels) 96 | 97 | if phase == 'train': 98 | loss.backward() 99 | optimizer.step() 100 | 101 | # calculate loss and IoU 102 | running_loss.append(loss.item()) 103 | running_corrects.append(score.item()) 104 | 105 | 106 | epoch_loss = np.mean(running_loss) 107 | epoch_acc = np.mean(running_corrects) 108 | 109 | print('{} Loss: {:.4f} IoU: {:.4f}'.format( 110 | phase, epoch_loss, epoch_acc)) 111 | 112 | Loss_list[phase].append(epoch_loss) 113 | Accuracy_list[phase].append(epoch_acc) 114 | 115 | # save parameters 116 | if phase == 'valid' and epoch_loss <= best_loss: 117 | best_loss = epoch_loss 118 | best_model_wts = model.state_dict() 119 | counter = 0 120 | elif phase == 'valid' and epoch_loss > best_loss: 121 | counter += 1 122 | if phase == 'train': 123 | scheduler.step() 124 | 125 | print() 126 | 127 | time_elapsed = time.time() - since 128 | print('Training complete in {:.0f}m {:.0f}s'.format( 129 | time_elapsed // 60, time_elapsed % 60)) 130 | print('Best val loss: {:4f}'.format(best_loss)) 131 | 132 | torch.save(best_model_wts, 'save_models/model_best.pth') 133 | 134 | return Loss_list, Accuracy_list 135 | 136 | if __name__ == '__main__': 137 | parser = argparse.ArgumentParser() 138 | parser.add_argument('--dataset', type=str,default='monuseg/images', help='the path of images') 139 | parser.add_argument('--prompt', type=str,default='sam_vit_h_4b8939.pth', help='') 140 | parser.add_argument('--encoder', type=str,default='mobile_sam.pt', help='') 141 | parser.add_argument('--jsonfile', type=str,default='data_split.json', help='') 142 | parser.add_argument('--loss', default='dice', help='loss type') 143 | parser.add_argument('--batch', type=int, default=4, help='batch size') 144 | parser.add_argument('--lr', type=float, default=0.0005, help='learning rate') 145 | parser.add_argument('--epoch', type=int, 
default=50, help='epoches') 146 | args = parser.parse_args() 147 | 148 | os.makedirs(f'save_models/',exist_ok=True) 149 | 150 | with open(args.jsonfile, 'r') as f: 151 | df = json.load(f) 152 | 153 | val_files = df['valid'] 154 | train_files = df['train'] 155 | 156 | train_dataset = sam_inputer(args.dataset,train_files,get_train_transform()) 157 | val_dataset = sam_inputer(args.dataset,val_files,get_valid_transform()) 158 | 159 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=args.batch, shuffle=True,drop_last=True) 160 | val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=1 ,drop_last=True) 161 | 162 | dataloaders = {'train':train_loader,'valid':val_loader} 163 | 164 | vit_encoder = TinyViT(img_size=1024, in_chans=3, num_classes=1000, 165 | embed_dims=[64, 128, 160, 320], 166 | depths=[2, 2, 6, 2], 167 | num_heads=[2, 4, 5, 10], 168 | window_sizes=[7, 7, 14, 7], 169 | mlp_ratio=4., 170 | drop_rate=0., 171 | drop_path_rate=0.0, 172 | use_checkpoint=True, 173 | mbconv_expand_ratio=4.0, 174 | local_conv_size=3, 175 | layer_lr_decay=0.8 176 | ) 177 | 178 | model_ft = sppnet.Model(image_encoder=vit_encoder) 179 | 180 | encoder_dict = torch.load(args.encoder) 181 | pre_dict = {k: v for k, v in encoder_dict.items() if list(k.split('.'))[0] == 'image_encoder'} 182 | model_ft.load_state_dict(pre_dict, strict=False) 183 | 184 | prompt_dict = torch.load(args.prompt) 185 | pre_dict = {k: v for k, v in prompt_dict.items() if list(k.split('.'))[0] != 'image_encoder'} 186 | model_ft.load_state_dict(pre_dict, strict=False) 187 | 188 | if torch.cuda.is_available(): 189 | model_ft = model_ft.cuda() 190 | 191 | # Loss, IoU and Optimizer 192 | if args.loss == 'ce': 193 | criterion = nn.BCELoss() 194 | if args.loss == 'dice': 195 | criterion = DiceLoss() 196 | 197 | accuracy_metric = IoU() 198 | optimizer_ft = optim.Adam(model_ft.parameters(),lr = args.lr) 199 | exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=100, gamma=0.8) 200 | #exp_lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer_ft, patience=5, factor=0.1,min_lr=1e-6) 201 | Loss_list, Accuracy_list = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, 202 | num_epochs=args.epoch) 203 | 204 | plt.title('Validation loss and IoU',) 205 | valid_data = pd.DataFrame({'Loss':Loss_list["valid"], 'IoU':Accuracy_list["valid"]}) 206 | valid_data.to_csv(f'valid_data.csv') 207 | sns.lineplot(data=valid_data,dashes=False) 208 | plt.ylabel('Value') 209 | plt.xlabel('Epochs') 210 | plt.savefig('valid.png') 211 | 212 | plt.figure() 213 | plt.title('Training loss and IoU',) 214 | valid_data = pd.DataFrame({'Loss':Loss_list["train"],'IoU':Accuracy_list["train"]}) 215 | valid_data.to_csv(f'train_data.csv') 216 | sns.lineplot(data=valid_data,dashes=False) 217 | plt.ylabel('Value') 218 | plt.xlabel('Epochs') 219 | plt.savefig('train.png') 220 | 221 | -------------------------------------------------------------------------------- /modeling/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
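# The two-way transformer used inside the SAM mask decoder: each block applies
# self-attention over the prompt/output tokens, token-to-image cross-attention,
# an MLP, and image-to-token cross-attention. SPPNet instantiates it with
# depth=2, embedding_dim=256, num_heads=8 and mlp_dim=2048 (see sppnet.py).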
6 | 7 | import torch 8 | from torch import Tensor, nn 9 | 10 | import math 11 | from typing import Tuple, Type 12 | 13 | from .common import MLPBlock 14 | 15 | 16 | class TwoWayTransformer(nn.Module): 17 | def __init__( 18 | self, 19 | depth: int, 20 | embedding_dim: int, 21 | num_heads: int, 22 | mlp_dim: int, 23 | activation: Type[nn.Module] = nn.ReLU, 24 | attention_downsample_rate: int = 2, 25 | ) -> None: 26 | """ 27 | A transformer decoder that attends to an input image using 28 | queries whose positional embedding is supplied. 29 | 30 | Args: 31 | depth (int): number of layers in the transformer 32 | embedding_dim (int): the channel dimension for the input embeddings 33 | num_heads (int): the number of heads for multihead attention. Must 34 | divide embedding_dim 35 | mlp_dim (int): the channel dimension internal to the MLP block 36 | activation (nn.Module): the activation to use in the MLP block 37 | """ 38 | super().__init__() 39 | self.depth = depth 40 | self.embedding_dim = embedding_dim 41 | self.num_heads = num_heads 42 | self.mlp_dim = mlp_dim 43 | self.layers = nn.ModuleList() 44 | 45 | for i in range(depth): 46 | self.layers.append( 47 | TwoWayAttentionBlock( 48 | embedding_dim=embedding_dim, 49 | num_heads=num_heads, 50 | mlp_dim=mlp_dim, 51 | activation=activation, 52 | attention_downsample_rate=attention_downsample_rate, 53 | skip_first_layer_pe=(i == 0), 54 | ) 55 | ) 56 | 57 | self.final_attn_token_to_image = Attention( 58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 59 | ) 60 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 61 | 62 | def forward( 63 | self, 64 | image_embedding: Tensor, 65 | image_pe: Tensor, 66 | point_embedding: Tensor, 67 | ) -> Tuple[Tensor, Tensor]: 68 | """ 69 | Args: 70 | image_embedding (torch.Tensor): image to attend to. Should be shape 71 | B x embedding_dim x h x w for any h and w. 72 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 73 | have the same shape as image_embedding. 74 | point_embedding (torch.Tensor): the embedding to add to the query points. 75 | Must have shape B x N_points x embedding_dim for any N_points. 
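        Illustrative shapes for the SPPNet configuration: a 1x256x64x64
        image_embedding is flattened to 1x4096x256 before the attention
        blocks, and point_embedding holds the decoder's output tokens
        concatenated with the sparse prompt embeddings.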
76 | 77 | Returns: 78 | torch.Tensor: the processed point_embedding 79 | torch.Tensor: the processed image_embedding 80 | """ 81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 82 | bs, c, h, w = image_embedding.shape 83 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 84 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 85 | 86 | # Prepare queries 87 | queries = point_embedding 88 | keys = image_embedding 89 | 90 | # Apply transformer blocks and final layernorm 91 | for layer in self.layers: 92 | queries, keys = layer( 93 | queries=queries, 94 | keys=keys, 95 | query_pe=point_embedding, 96 | key_pe=image_pe, 97 | ) 98 | 99 | # Apply the final attention layer from the points to the image 100 | q = queries + point_embedding 101 | k = keys + image_pe 102 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 103 | queries = queries + attn_out 104 | queries = self.norm_final_attn(queries) 105 | 106 | return queries, keys 107 | 108 | 109 | class TwoWayAttentionBlock(nn.Module): 110 | def __init__( 111 | self, 112 | embedding_dim: int, 113 | num_heads: int, 114 | mlp_dim: int = 2048, 115 | activation: Type[nn.Module] = nn.ReLU, 116 | attention_downsample_rate: int = 2, 117 | skip_first_layer_pe: bool = False, 118 | ) -> None: 119 | """ 120 | A transformer block with four layers: (1) self-attention of sparse 121 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 122 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 123 | inputs. 124 | 125 | Arguments: 126 | embedding_dim (int): the channel dimension of the embeddings 127 | num_heads (int): the number of heads in the attention layers 128 | mlp_dim (int): the hidden dimension of the mlp block 129 | activation (nn.Module): the activation of the mlp block 130 | skip_first_layer_pe (bool): skip the PE on the first layer 131 | """ 132 | super().__init__() 133 | self.self_attn = Attention(embedding_dim, num_heads) 134 | self.norm1 = nn.LayerNorm(embedding_dim) 135 | 136 | self.cross_attn_token_to_image = Attention( 137 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 138 | ) 139 | self.norm2 = nn.LayerNorm(embedding_dim) 140 | 141 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) 142 | self.norm3 = nn.LayerNorm(embedding_dim) 143 | 144 | self.norm4 = nn.LayerNorm(embedding_dim) 145 | self.cross_attn_image_to_token = Attention( 146 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 147 | ) 148 | 149 | self.skip_first_layer_pe = skip_first_layer_pe 150 | 151 | def forward( 152 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 153 | ) -> Tuple[Tensor, Tensor]: 154 | # Self attention block 155 | if self.skip_first_layer_pe: 156 | queries = self.self_attn(q=queries, k=queries, v=queries) 157 | else: 158 | q = queries + query_pe 159 | attn_out = self.self_attn(q=q, k=q, v=queries) 160 | queries = queries + attn_out 161 | queries = self.norm1(queries) 162 | 163 | # Cross attention block, tokens attending to image embedding 164 | q = queries + query_pe 165 | k = keys + key_pe 166 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 167 | queries = queries + attn_out 168 | queries = self.norm2(queries) 169 | 170 | # MLP block 171 | mlp_out = self.mlp(queries) 172 | queries = queries + mlp_out 173 | queries = self.norm3(queries) 174 | 175 | # Cross attention block, image embedding attending to tokens 176 | q = queries + query_pe 177 | k = keys + key_pe 178 | attn_out = self.cross_attn_image_to_token(q=k, k=q, 
v=queries) 179 | keys = keys + attn_out 180 | keys = self.norm4(keys) 181 | 182 | return queries, keys 183 | 184 | 185 | class Attention(nn.Module): 186 | """ 187 | An attention layer that allows for downscaling the size of the embedding 188 | after projection to queries, keys, and values. 189 | """ 190 | 191 | def __init__( 192 | self, 193 | embedding_dim: int, 194 | num_heads: int, 195 | downsample_rate: int = 1, 196 | ) -> None: 197 | super().__init__() 198 | self.embedding_dim = embedding_dim 199 | self.internal_dim = embedding_dim // downsample_rate 200 | self.num_heads = num_heads 201 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 202 | 203 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 204 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 205 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 206 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 207 | 208 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 209 | b, n, c = x.shape 210 | x = x.reshape(b, n, num_heads, c // num_heads) 211 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 212 | 213 | def _recombine_heads(self, x: Tensor) -> Tensor: 214 | b, n_heads, n_tokens, c_per_head = x.shape 215 | x = x.transpose(1, 2) 216 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 217 | 218 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 219 | # Input projections 220 | q = self.q_proj(q) 221 | k = self.k_proj(k) 222 | v = self.v_proj(v) 223 | 224 | # Separate into heads 225 | q = self._separate_heads(q, self.num_heads) 226 | k = self._separate_heads(k, self.num_heads) 227 | v = self._separate_heads(v, self.num_heads) 228 | 229 | # Attention 230 | _, _, _, c_per_head = q.shape 231 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 232 | attn = attn / math.sqrt(c_per_head) 233 | attn = torch.softmax(attn, dim=-1) 234 | 235 | # Get output 236 | out = attn @ v 237 | out = self._recombine_heads(out) 238 | out = self.out_proj(out) 239 | 240 | return out 241 | -------------------------------------------------------------------------------- /modeling/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | from typing import Any, Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class PromptEncoder(nn.Module): 17 | def __init__( 18 | self, 19 | embed_dim: int, 20 | image_embedding_size: Tuple[int, int], 21 | input_image_size: Tuple[int, int], 22 | mask_in_chans: int, 23 | activation: Type[nn.Module] = nn.GELU, 24 | ) -> None: 25 | """ 26 | Encodes prompts for input to SAM's mask decoder. 27 | 28 | Arguments: 29 | embed_dim (int): The prompts' embedding dimension 30 | image_embedding_size (tuple(int, int)): The spatial size of the 31 | image embedding, as (H, W). 32 | input_image_size (int): The padded size of the image as input 33 | to the image encoder, as (H, W). 34 | mask_in_chans (int): The number of hidden channels used for 35 | encoding input masks. 36 | activation (nn.Module): The activation to use when encoding 37 | input masks. 
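        Illustrative sketch (a minimal example; the values are assumptions
        matching SAM's usual 256-dim embeddings, 64x64 embedding grid and
        1024x1024 padded input):

            prompt_encoder = PromptEncoder(
                embed_dim=256,
                image_embedding_size=(64, 64),
                input_image_size=(1024, 1024),
                mask_in_chans=16,
            )
            # coords: B x 1 x 2, labels: B x 1 for a single point prompt
            sparse, dense = prompt_encoder(points=(coords, labels), boxes=None, masks=None)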
38 | """ 39 | super().__init__() 40 | self.embed_dim = embed_dim 41 | self.input_image_size = input_image_size 42 | self.image_embedding_size = image_embedding_size 43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 44 | 45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 46 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] 47 | self.point_embeddings = nn.ModuleList(point_embeddings) 48 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 49 | 50 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) 51 | self.mask_downscaling = nn.Sequential( 52 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 53 | LayerNorm2d(mask_in_chans // 4), 54 | activation(), 55 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 56 | LayerNorm2d(mask_in_chans), 57 | activation(), 58 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 59 | ) 60 | self.no_mask_embed = nn.Embedding(1, embed_dim) 61 | 62 | def get_dense_pe(self) -> torch.Tensor: 63 | """ 64 | Returns the positional encoding used to encode point prompts, 65 | applied to a dense set of points the shape of the image encoding. 66 | 67 | Returns: 68 | torch.Tensor: Positional encoding with shape 69 | 1x(embed_dim)x(embedding_h)x(embedding_w) 70 | """ 71 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 72 | 73 | def _embed_points( 74 | self, 75 | points: torch.Tensor, 76 | labels: torch.Tensor, 77 | pad: bool, 78 | ) -> torch.Tensor: 79 | """Embeds point prompts.""" 80 | points = points + 0.5 # Shift to center of pixel 81 | if pad: 82 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 83 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 84 | points = torch.cat([points, padding_point], dim=1) 85 | labels = torch.cat([labels, padding_label], dim=1) 86 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) 87 | point_embedding[labels == -1] = 0.0 88 | point_embedding[labels == -1] += self.not_a_point_embed.weight 89 | point_embedding[labels == 0] += self.point_embeddings[0].weight 90 | point_embedding[labels == 1] += self.point_embeddings[1].weight 91 | return point_embedding 92 | 93 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 94 | """Embeds box prompts.""" 95 | boxes = boxes + 0.5 # Shift to center of pixel 96 | coords = boxes.reshape(-1, 2, 2) 97 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) 98 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 99 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 100 | return corner_embedding 101 | 102 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 103 | """Embeds mask inputs.""" 104 | mask_embedding = self.mask_downscaling(masks) 105 | return mask_embedding 106 | 107 | def _get_batch_size( 108 | self, 109 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 110 | boxes: Optional[torch.Tensor], 111 | masks: Optional[torch.Tensor], 112 | ) -> int: 113 | """ 114 | Gets the batch size of the output given the batch size of the input prompts. 
115 | """ 116 | if points is not None: 117 | return points[0].shape[0] 118 | elif boxes is not None: 119 | return boxes.shape[0] 120 | elif masks is not None: 121 | return masks.shape[0] 122 | else: 123 | return 1 124 | 125 | def _get_device(self) -> torch.device: 126 | return self.point_embeddings[0].weight.device 127 | 128 | def forward( 129 | self, 130 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 131 | boxes: Optional[torch.Tensor], 132 | masks: Optional[torch.Tensor], 133 | ) -> Tuple[torch.Tensor, torch.Tensor]: 134 | """ 135 | Embeds different types of prompts, returning both sparse and dense 136 | embeddings. 137 | 138 | Arguments: 139 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 140 | and labels to embed. 141 | boxes (torch.Tensor or none): boxes to embed 142 | masks (torch.Tensor or none): masks to embed 143 | 144 | Returns: 145 | torch.Tensor: sparse embeddings for the points and boxes, with shape 146 | BxNx(embed_dim), where N is determined by the number of input points 147 | and boxes. 148 | torch.Tensor: dense embeddings for the masks, in the shape 149 | Bx(embed_dim)x(embed_H)x(embed_W) 150 | """ 151 | bs = self._get_batch_size(points, boxes, masks) 152 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) 153 | if points is not None: 154 | coords, labels = points 155 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 156 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 157 | if boxes is not None: 158 | box_embeddings = self._embed_boxes(boxes) 159 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 160 | 161 | if masks is not None: 162 | dense_embeddings = self._embed_masks(masks) 163 | else: 164 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 165 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 166 | ) 167 | 168 | return sparse_embeddings, dense_embeddings 169 | 170 | 171 | class PositionEmbeddingRandom(nn.Module): 172 | """ 173 | Positional encoding using random spatial frequencies. 174 | """ 175 | 176 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 177 | super().__init__() 178 | if scale is None or scale <= 0.0: 179 | scale = 1.0 180 | self.register_buffer( 181 | "positional_encoding_gaussian_matrix", 182 | scale * torch.randn((2, num_pos_feats)), 183 | ) 184 | 185 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 186 | """Positionally encode points that are normalized to [0,1].""" 187 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 188 | coords = 2 * coords - 1 189 | coords = coords @ self.positional_encoding_gaussian_matrix 190 | coords = 2 * np.pi * coords 191 | # outputs d_1 x ... 
x d_n x C shape 192 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 193 | 194 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 195 | """Generate positional encoding for a grid of the specified size.""" 196 | h, w = size 197 | device: Any = self.positional_encoding_gaussian_matrix.device 198 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 199 | y_embed = grid.cumsum(dim=0) - 0.5 200 | x_embed = grid.cumsum(dim=1) - 0.5 201 | y_embed = y_embed / h 202 | x_embed = x_embed / w 203 | 204 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 205 | return pe.permute(2, 0, 1) # C x H x W 206 | 207 | def forward_with_coords( 208 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 209 | ) -> torch.Tensor: 210 | """Positionally encode points that are not normalized to [0,1].""" 211 | coords = coords_input.clone() 212 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 213 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 214 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 215 | -------------------------------------------------------------------------------- /train_adapter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import torch 4 | from torch.autograd import Variable 5 | import torch.nn as nn 6 | from torch import optim 7 | import time 8 | import albumentations as A 9 | from albumentations.pytorch import ToTensor 10 | from torch.utils.data import random_split 11 | from torch.optim import lr_scheduler 12 | import seaborn as sns 13 | import pandas as pd 14 | import argparse 15 | import os 16 | from dataloader import sam_inputer 17 | from sklearn.model_selection import GroupKFold 18 | from loss import * 19 | from tqdm import tqdm 20 | import json 21 | import sppnet 22 | from modeling.image_encoder import ImageEncoderViT 23 | from modeling.tiny_vit_sam import TinyViT 24 | from functools import partial 25 | 26 | 27 | def get_train_transform(): 28 | return A.Compose( 29 | [ 30 | A.Resize(256, 256), 31 | # A.HorizontalFlip(p=0.25), 32 | # A.RandomBrightness(p=0.25), 33 | # A.ShiftScaleRotate(shift_limit=0,p=0.25), 34 | # A.CoarseDropout(), 35 | # A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 36 | ToTensor() 37 | ]) 38 | 39 | def get_valid_transform(): 40 | return A.Compose( 41 | [ 42 | A.Resize(256, 256), 43 | # A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 44 | ToTensor() 45 | ]) 46 | 47 | 48 | def train_model(model, criterion, optimizer, scheduler, num_epochs=5): 49 | since = time.time() 50 | 51 | Loss_list = {'train': [], 'valid': []} 52 | Accuracy_list = {'train': [], 'valid': []} 53 | 54 | best_model_wts = model.state_dict() 55 | 56 | best_loss = float('inf') 57 | counter = 0 58 | 59 | for epoch in range(num_epochs): 60 | print('Epoch {}/{}'.format(epoch, num_epochs - 1)) 61 | print('-' * 10) 62 | 63 | # Each epoch has a training and validation phase 64 | for phase in ['train', 'valid']: 65 | if phase == 'train': 66 | model.train(True) 67 | 68 | else: 69 | model.train(False) 70 | 71 | running_loss = [] 72 | running_corrects = [] 73 | 74 | # Iterate over data 75 | #for inputs,labels,label_for_ce,image_id in dataloaders[phase]: 76 | for img, point_coord, point_class, img_vit, labels, _, h, w in tqdm(dataloaders[phase]): 77 | # wrap them in Variable 78 | if torch.cuda.is_available(): 79 | 80 | point_coord = Variable(point_coord.cuda()) 81 | point_class = 
Variable(point_class.cuda()) 82 | img_vit = Variable(img_vit.cuda()) 83 | img = Variable(img.cuda()) 84 | labels = Variable(labels.cuda()) 85 | 86 | else: 87 | img, point_coord, point_class, img_vit, labels = Variable(img), Variable(point_coord), Variable(point_class), Variable(img_vit), Variable(labels) 88 | 89 | # zero the parameter gradients 90 | optimizer.zero_grad() 91 | 92 | # forward 93 | if phase == 'train': 94 | if torch.cuda.device_count() > 1: 95 | for n, value in model.module.image_encoder.named_parameters(): 96 | if "Adapter" not in n: 97 | value.requires_grad = False 98 | else: 99 | for n, value in model.image_encoder.named_parameters(): 100 | if "Adapter" not in n: 101 | value.requires_grad = False 102 | 103 | outputs = model(img, point_coord, point_class, img_vit, (h[0].item(), w[0].item())) 104 | 105 | loss = criterion(outputs, labels) 106 | score = accuracy_metric(outputs,labels) 107 | 108 | if phase == 'train': 109 | loss.backward() 110 | optimizer.step() 111 | 112 | # calculate loss and IoU 113 | running_loss.append(loss.item()) 114 | running_corrects.append(score.item()) 115 | 116 | 117 | epoch_loss = np.mean(running_loss) 118 | epoch_acc = np.mean(running_corrects) 119 | 120 | print('{} Loss: {:.4f} IoU: {:.4f}'.format( 121 | phase, epoch_loss, epoch_acc)) 122 | 123 | Loss_list[phase].append(epoch_loss) 124 | Accuracy_list[phase].append(epoch_acc) 125 | 126 | # save parameters 127 | if phase == 'valid' and epoch_loss <= best_loss: 128 | best_loss = epoch_loss 129 | best_model_wts = model.state_dict() 130 | counter = 0 131 | elif phase == 'valid' and epoch_loss > best_loss: 132 | counter += 1 133 | if phase == 'train': 134 | scheduler.step() 135 | 136 | print() 137 | 138 | time_elapsed = time.time() - since 139 | print('Training complete in {:.0f}m {:.0f}s'.format( 140 | time_elapsed // 60, time_elapsed % 60)) 141 | print('Best val loss: {:4f}'.format(best_loss)) 142 | 143 | torch.save(best_model_wts, 'save_models/model_best.pth') 144 | 145 | return Loss_list, Accuracy_list 146 | 147 | if __name__ == '__main__': 148 | parser = argparse.ArgumentParser() 149 | parser.add_argument('--dataset', type=str,default='datasets/', help='the path of images') 150 | parser.add_argument('--prompt', type=str,default='pre_weights/sam_vit_b_01ec64.pth', help='') 151 | parser.add_argument('--encoder', type=str,default='pre_weights/sam_vit_b_01ec64.pth', help='') 152 | parser.add_argument('--jsonfile', type=str,default='datasets/monuseg/data_split.json', help='') 153 | parser.add_argument('--loss', default='dice', help='loss type') 154 | parser.add_argument('--batch', type=int, default=2, help='batch size') 155 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate') 156 | parser.add_argument('--epoch', type=int, default=20, help='epoches') 157 | args = parser.parse_args() 158 | 159 | os.makedirs(f'save_models/',exist_ok=True) 160 | 161 | with open(args.jsonfile, 'r') as f: 162 | df = json.load(f) 163 | 164 | val_files = df['valid'] 165 | train_files = df['train'] 166 | 167 | train_dataset = sam_inputer(args.dataset,train_files,get_train_transform()) 168 | val_dataset = sam_inputer(args.dataset,val_files,get_valid_transform()) 169 | 170 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=args.batch, shuffle=True,drop_last=True) 171 | val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=1 ,drop_last=True) 172 | 173 | dataloaders = {'train':train_loader,'valid':val_loader} 174 | 175 | # vit_encoder = TinyViT(img_size=1024, 
in_chans=3, num_classes=1000, 176 | # embed_dims=[64, 128, 160, 320], 177 | # depths=[2, 2, 6, 2], 178 | # num_heads=[2, 4, 5, 10], 179 | # window_sizes=[7, 7, 14, 7], 180 | # mlp_ratio=4., 181 | # drop_rate=0., 182 | # drop_path_rate=0.0, 183 | # use_checkpoint=True, 184 | # mbconv_expand_ratio=4.0, 185 | # local_conv_size=3, 186 | # layer_lr_decay=0.8 187 | # ) 188 | 189 | vit_encoder = ImageEncoderViT( 190 | depth=12, 191 | embed_dim=768, 192 | img_size=1024, 193 | mlp_ratio=4, 194 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 195 | num_heads=12, 196 | patch_size=16, 197 | qkv_bias=True, 198 | use_rel_pos=True, 199 | global_attn_indexes=[2, 5, 8, 11], 200 | window_size=14, 201 | out_chans=256, 202 | ) 203 | 204 | # vit_encoder = ImageEncoderViT( 205 | # depth=32, 206 | # embed_dim=1280, 207 | # img_size=1024, 208 | # mlp_ratio=4, 209 | # norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 210 | # num_heads=16, 211 | # patch_size=16, 212 | # qkv_bias=True, 213 | # use_rel_pos=True, 214 | # global_attn_indexes=[7, 15, 23, 31], 215 | # window_size=14, 216 | # out_chans=256, 217 | # ) 218 | 219 | model_ft = sppnet.Model(image_encoder=vit_encoder) 220 | # model_ft.load_state_dict(torch.load(args.encoder), strict=False) 221 | 222 | # pretrain_dict = torch.load(args.encoder) 223 | # model_dict = model_ft.state_dict() 224 | # image_encoder_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict} 225 | # print(image_encoder_dict.keys()) 226 | # model_dict.update(image_encoder_dict) 227 | # model_ft.load_state_dict(model_dict, strict=True) 228 | 229 | encoder_dict = torch.load(args.encoder) 230 | pre_dict = {k: v for k, v in encoder_dict.items() if list(k.split('.'))[0] == 'image_encoder'} 231 | model_ft.load_state_dict(pre_dict, strict=False) 232 | 233 | pre_dict = {k: v for k, v in encoder_dict.items() if list(k.split('.'))[0] == 'prompt_encoder'} 234 | model_ft.load_state_dict(pre_dict, strict=False) 235 | 236 | pre_dict = {k: v for k, v in encoder_dict.items() if list(k.split('.'))[0] == 'mask_decoder'} 237 | model_ft.load_state_dict(pre_dict, strict=False) 238 | 239 | if torch.cuda.device_count() > 1: 240 | print("Let's use", torch.cuda.device_count(), "GPUs!") 241 | 242 | model_ft = nn.DataParallel(model_ft) 243 | 244 | if torch.cuda.is_available(): 245 | model_ft = model_ft.cuda() 246 | 247 | # Loss, IoU and Optimizer 248 | if args.loss == 'ce': 249 | criterion = nn.BCELoss() 250 | if args.loss == 'dice': 251 | criterion = DiceLoss() 252 | 253 | accuracy_metric = IoU() 254 | optimizer_ft = optim.Adam(model_ft.parameters(),lr = args.lr) 255 | exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=100, gamma=0.8) 256 | #exp_lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer_ft, patience=5, factor=0.1,min_lr=1e-6) 257 | Loss_list, Accuracy_list = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, 258 | num_epochs=args.epoch) 259 | 260 | plt.title('Validation loss and IoU',) 261 | valid_data = pd.DataFrame({'Loss':Loss_list["valid"], 'IoU':Accuracy_list["valid"]}) 262 | valid_data.to_csv(f'valid_data.csv') 263 | sns.lineplot(data=valid_data,dashes=False) 264 | plt.ylabel('Value') 265 | plt.xlabel('Epochs') 266 | plt.savefig('valid.png') 267 | 268 | plt.figure() 269 | plt.title('Training loss and IoU',) 270 | valid_data = pd.DataFrame({'Loss':Loss_list["train"],'IoU':Accuracy_list["train"]}) 271 | valid_data.to_csv(f'train_data.csv') 272 | sns.lineplot(data=valid_data,dashes=False) 273 | plt.ylabel('Value') 274 | plt.xlabel('Epochs') 275 | 
plt.savefig('train.png') -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /modeling/image_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from typing import Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d, MLPBlock, Adapter 14 | 15 | 16 | # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa 17 | class ImageEncoderViT(nn.Module): 18 | def __init__( 19 | self, 20 | img_size: int = 1024, 21 | patch_size: int = 16, 22 | in_chans: int = 3, 23 | embed_dim: int = 768, 24 | depth: int = 12, 25 | num_heads: int = 12, 26 | mlp_ratio: float = 4.0, 27 | out_chans: int = 256, 28 | qkv_bias: bool = True, 29 | norm_layer: Type[nn.Module] = nn.LayerNorm, 30 | act_layer: Type[nn.Module] = nn.GELU, 31 | use_abs_pos: bool = True, 32 | use_rel_pos: bool = False, 33 | rel_pos_zero_init: bool = True, 34 | window_size: int = 0, 35 | global_attn_indexes: Tuple[int, ...] = (), 36 | ) -> None: 37 | """ 38 | Args: 39 | img_size (int): Input image size. 40 | patch_size (int): Patch size. 41 | in_chans (int): Number of input image channels. 42 | embed_dim (int): Patch embedding dimension. 43 | depth (int): Depth of ViT. 44 | num_heads (int): Number of attention heads in each ViT block. 45 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 46 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 47 | norm_layer (nn.Module): Normalization layer. 48 | act_layer (nn.Module): Activation layer. 49 | use_abs_pos (bool): If True, use absolute positional embeddings. 50 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 51 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 52 | window_size (int): Window size for window attention blocks. 53 | global_attn_indexes (list): Indexes for blocks using global attention. 
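        Illustrative sketch (the ViT-B settings that train_adapter.py in this
        repository passes to ImageEncoderViT; for a 1024x1024 input the output
        is B x out_chans x 64 x 64):

            encoder = ImageEncoderViT(
                depth=12, embed_dim=768, img_size=1024, mlp_ratio=4,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                num_heads=12, patch_size=16, qkv_bias=True, use_rel_pos=True,
                global_attn_indexes=[2, 5, 8, 11], window_size=14, out_chans=256,
            )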
54 | """ 55 | super().__init__() 56 | self.img_size = img_size 57 | 58 | self.patch_embed = PatchEmbed( 59 | kernel_size=(patch_size, patch_size), 60 | stride=(patch_size, patch_size), 61 | in_chans=in_chans, 62 | embed_dim=embed_dim, 63 | ) 64 | 65 | self.pos_embed: Optional[nn.Parameter] = None 66 | if use_abs_pos: 67 | # Initialize absolute positional embedding with pretrain image size. 68 | self.pos_embed = nn.Parameter( 69 | torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) 70 | ) 71 | 72 | self.blocks = nn.ModuleList() 73 | for i in range(depth): 74 | block = Block( 75 | dim=embed_dim, 76 | num_heads=num_heads, 77 | mlp_ratio=mlp_ratio, 78 | qkv_bias=qkv_bias, 79 | norm_layer=norm_layer, 80 | act_layer=act_layer, 81 | use_rel_pos=use_rel_pos, 82 | rel_pos_zero_init=rel_pos_zero_init, 83 | window_size=window_size if i not in global_attn_indexes else 0, 84 | input_size=(img_size // patch_size, img_size // patch_size), 85 | ) 86 | self.blocks.append(block) 87 | 88 | self.neck = nn.Sequential( 89 | nn.Conv2d( 90 | embed_dim, 91 | out_chans, 92 | kernel_size=1, 93 | bias=False, 94 | ), 95 | LayerNorm2d(out_chans), 96 | nn.Conv2d( 97 | out_chans, 98 | out_chans, 99 | kernel_size=3, 100 | padding=1, 101 | bias=False, 102 | ), 103 | LayerNorm2d(out_chans), 104 | ) 105 | 106 | def forward(self, x: torch.Tensor) -> torch.Tensor: 107 | x = self.patch_embed(x) 108 | if self.pos_embed is not None: 109 | x = x + self.pos_embed 110 | 111 | for blk in self.blocks: 112 | x = blk(x) 113 | 114 | x = self.neck(x.permute(0, 3, 1, 2)) 115 | 116 | return x 117 | 118 | 119 | class Block(nn.Module): 120 | """Transformer blocks with support of window attention and residual propagation blocks""" 121 | 122 | def __init__( 123 | self, 124 | dim: int, 125 | num_heads: int, 126 | mlp_ratio: float = 4.0, 127 | qkv_bias: bool = True, 128 | norm_layer: Type[nn.Module] = nn.LayerNorm, 129 | act_layer: Type[nn.Module] = nn.GELU, 130 | use_rel_pos: bool = False, 131 | rel_pos_zero_init: bool = True, 132 | window_size: int = 0, 133 | input_size: Optional[Tuple[int, int]] = None, 134 | ) -> None: 135 | """ 136 | Args: 137 | dim (int): Number of input channels. 138 | num_heads (int): Number of attention heads in each ViT block. 139 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 140 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 141 | norm_layer (nn.Module): Normalization layer. 142 | act_layer (nn.Module): Activation layer. 143 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 144 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 145 | window_size (int): Window size for window attention blocks. If it equals 0, then 146 | use global attention. 147 | input_size (tuple(int, int) or None): Input resolution for calculating the relative 148 | positional parameter size. 
149 | """ 150 | super().__init__() 151 | self.norm1 = norm_layer(dim) 152 | 153 | self.attn = Attention( 154 | dim, 155 | num_heads=num_heads, 156 | qkv_bias=qkv_bias, 157 | use_rel_pos=use_rel_pos, 158 | rel_pos_zero_init=rel_pos_zero_init, 159 | input_size=input_size if window_size == 0 else (window_size, window_size), 160 | ) 161 | 162 | self.norm2 = norm_layer(dim) 163 | self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) 164 | 165 | self.window_size = window_size 166 | 167 | #----------------------------------------------- 168 | 169 | self.ft = Adapter(dim) 170 | self.MLP_Adapter = Adapter(dim, skip_connect=False) 171 | 172 | #----------------------------------------------- 173 | 174 | def forward(self, x: torch.Tensor) -> torch.Tensor: 175 | shortcut = x 176 | x = self.norm1(x) 177 | # Window partition 178 | if self.window_size > 0: 179 | H, W = x.shape[1], x.shape[2] 180 | x, pad_hw = window_partition(x, self.window_size) 181 | 182 | x = self.attn(x) 183 | #----------------------------------------------- 184 | 185 | x = self.ft(x) 186 | 187 | #----------------------------------------------- 188 | 189 | # Reverse window partition 190 | if self.window_size > 0: 191 | x = window_unpartition(x, self.window_size, pad_hw, (H, W)) 192 | 193 | x = shortcut + x 194 | 195 | #----------------------------------------------- 196 | xn = self.norm2(x) 197 | x = x + self.mlp(xn) + 0.5 * self.MLP_Adapter(xn) 198 | 199 | # x = x + self.mlp(self.norm2(x)) 200 | #----------------------------------------------- 201 | 202 | return x 203 | 204 | 205 | class Attention(nn.Module): 206 | """Multi-head Attention block with relative position embeddings.""" 207 | 208 | def __init__( 209 | self, 210 | dim: int, 211 | num_heads: int = 8, 212 | qkv_bias: bool = True, 213 | use_rel_pos: bool = False, 214 | rel_pos_zero_init: bool = True, 215 | input_size: Optional[Tuple[int, int]] = None, 216 | ) -> None: 217 | """ 218 | Args: 219 | dim (int): Number of input channels. 220 | num_heads (int): Number of attention heads. 221 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 222 | rel_pos (bool): If True, add relative positional embeddings to the attention map. 223 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 224 | input_size (tuple(int, int) or None): Input resolution for calculating the relative 225 | positional parameter size. 226 | """ 227 | super().__init__() 228 | self.num_heads = num_heads 229 | head_dim = dim // num_heads 230 | self.scale = head_dim**-0.5 231 | 232 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 233 | self.proj = nn.Linear(dim, dim) 234 | 235 | self.use_rel_pos = use_rel_pos 236 | if self.use_rel_pos: 237 | assert ( 238 | input_size is not None 239 | ), "Input size must be provided if using relative positional encoding." 
240 | # initialize relative positional embeddings 241 | self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) 242 | self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) 243 | 244 | def forward(self, x: torch.Tensor) -> torch.Tensor: 245 | B, H, W, _ = x.shape 246 | # qkv with shape (3, B, nHead, H * W, C) 247 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 248 | # q, k, v with shape (B * nHead, H * W, C) 249 | q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) 250 | 251 | attn = (q * self.scale) @ k.transpose(-2, -1) 252 | 253 | if self.use_rel_pos: 254 | attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) 255 | 256 | attn = attn.softmax(dim=-1) 257 | x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) 258 | x = self.proj(x) 259 | 260 | return x 261 | 262 | 263 | def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: 264 | """ 265 | Partition into non-overlapping windows with padding if needed. 266 | Args: 267 | x (tensor): input tokens with [B, H, W, C]. 268 | window_size (int): window size. 269 | 270 | Returns: 271 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. 272 | (Hp, Wp): padded height and width before partition 273 | """ 274 | B, H, W, C = x.shape 275 | 276 | pad_h = (window_size - H % window_size) % window_size 277 | pad_w = (window_size - W % window_size) % window_size 278 | if pad_h > 0 or pad_w > 0: 279 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 280 | Hp, Wp = H + pad_h, W + pad_w 281 | 282 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 283 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 284 | return windows, (Hp, Wp) 285 | 286 | 287 | def window_unpartition( 288 | windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] 289 | ) -> torch.Tensor: 290 | """ 291 | Window unpartition into original sequences and removing padding. 292 | Args: 293 | windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 294 | window_size (int): window size. 295 | pad_hw (Tuple): padded height and width (Hp, Wp). 296 | hw (Tuple): original height and width (H, W) before padding. 297 | 298 | Returns: 299 | x: unpartitioned sequences with [B, H, W, C]. 300 | """ 301 | Hp, Wp = pad_hw 302 | H, W = hw 303 | B = windows.shape[0] // (Hp * Wp // window_size // window_size) 304 | x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) 305 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) 306 | 307 | if Hp > H or Wp > W: 308 | x = x[:, :H, :W, :].contiguous() 309 | return x 310 | 311 | 312 | def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: 313 | """ 314 | Get relative positional embeddings according to the relative positions of 315 | query and key sizes. 316 | Args: 317 | q_size (int): size of query q. 318 | k_size (int): size of key k. 319 | rel_pos (Tensor): relative position embeddings (L, C). 320 | 321 | Returns: 322 | Extracted positional embeddings according to relative positions. 323 | """ 324 | max_rel_dist = int(2 * max(q_size, k_size) - 1) 325 | # Interpolate rel pos if needed. 326 | if rel_pos.shape[0] != max_rel_dist: 327 | # Interpolate rel pos. 
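        # e.g. tables trained for a 14x14 window have 2*14-1 = 27 rows and are
        # resized to 2*64-1 = 127 rows when q_size = k_size = 64 (global attention).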
328 | rel_pos_resized = F.interpolate( 329 | rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), 330 | size=max_rel_dist, 331 | mode="linear", 332 | ) 333 | rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) 334 | else: 335 | rel_pos_resized = rel_pos 336 | 337 | # Scale the coords with short length if shapes for q and k are different. 338 | q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) 339 | k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) 340 | relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) 341 | 342 | return rel_pos_resized[relative_coords.long()] 343 | 344 | 345 | def add_decomposed_rel_pos( 346 | attn: torch.Tensor, 347 | q: torch.Tensor, 348 | rel_pos_h: torch.Tensor, 349 | rel_pos_w: torch.Tensor, 350 | q_size: Tuple[int, int], 351 | k_size: Tuple[int, int], 352 | ) -> torch.Tensor: 353 | """ 354 | Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. 355 | https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 356 | Args: 357 | attn (Tensor): attention map. 358 | q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). 359 | rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. 360 | rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. 361 | q_size (Tuple): spatial sequence size of query q with (q_h, q_w). 362 | k_size (Tuple): spatial sequence size of key k with (k_h, k_w). 363 | 364 | Returns: 365 | attn (Tensor): attention map with added relative positional embeddings. 366 | """ 367 | q_h, q_w = q_size 368 | k_h, k_w = k_size 369 | Rh = get_rel_pos(q_h, k_h, rel_pos_h) 370 | Rw = get_rel_pos(q_w, k_w, rel_pos_w) 371 | 372 | B, _, dim = q.shape 373 | r_q = q.reshape(B, q_h, q_w, dim) 374 | rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) 375 | rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) 376 | 377 | attn = ( 378 | attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] 379 | ).view(B, q_h * q_w, k_h * k_w) 380 | 381 | return attn 382 | 383 | 384 | class PatchEmbed(nn.Module): 385 | """ 386 | Image to Patch Embedding. 387 | """ 388 | 389 | def __init__( 390 | self, 391 | kernel_size: Tuple[int, int] = (16, 16), 392 | stride: Tuple[int, int] = (16, 16), 393 | padding: Tuple[int, int] = (0, 0), 394 | in_chans: int = 3, 395 | embed_dim: int = 768, 396 | ) -> None: 397 | """ 398 | Args: 399 | kernel_size (Tuple): kernel size of the projection layer. 400 | stride (Tuple): stride of the projection layer. 401 | padding (Tuple): padding size of the projection layer. 402 | in_chans (int): Number of input image channels. 403 | embed_dim (int): Patch embedding dimension. 
404 | """ 405 | super().__init__() 406 | 407 | self.proj = nn.Conv2d( 408 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding 409 | ) 410 | 411 | def forward(self, x: torch.Tensor) -> torch.Tensor: 412 | x = self.proj(x) 413 | # B C H W -> B H W C 414 | x = x.permute(0, 2, 3, 1) 415 | return x 416 | -------------------------------------------------------------------------------- /modeling/tiny_vit_sam.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # TinyViT Model Architecture 3 | # Copyright (c) 2022 Microsoft 4 | # Adapted from LeViT and Swin Transformer 5 | # LeViT: (https://github.com/facebookresearch/levit) 6 | # Swin: (https://github.com/microsoft/swin-transformer) 7 | # Build the TinyViT Model 8 | # -------------------------------------------------------- 9 | 10 | import itertools 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import torch.utils.checkpoint as checkpoint 15 | from timm.models.layers import DropPath as TimmDropPath,\ 16 | to_2tuple, trunc_normal_ 17 | from timm.models.registry import register_model 18 | from typing import Tuple 19 | #import loralib as lora 20 | 21 | 22 | class Conv2d_BN(torch.nn.Sequential): 23 | def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, 24 | groups=1, bn_weight_init=1): 25 | super().__init__() 26 | self.add_module('c', torch.nn.Conv2d( 27 | a, b, ks, stride, pad, dilation, groups, bias=False)) 28 | bn = torch.nn.BatchNorm2d(b) 29 | torch.nn.init.constant_(bn.weight, bn_weight_init) 30 | torch.nn.init.constant_(bn.bias, 0) 31 | self.add_module('bn', bn) 32 | 33 | @torch.no_grad() 34 | def fuse(self): 35 | c, bn = self._modules.values() 36 | w = bn.weight / (bn.running_var + bn.eps)**0.5 37 | w = c.weight * w[:, None, None, None] 38 | b = bn.bias - bn.running_mean * bn.weight / \ 39 | (bn.running_var + bn.eps)**0.5 40 | m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size( 41 | 0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) 42 | m.weight.data.copy_(w) 43 | m.bias.data.copy_(b) 44 | return m 45 | 46 | 47 | class DropPath(TimmDropPath): 48 | def __init__(self, drop_prob=None): 49 | super().__init__(drop_prob=drop_prob) 50 | self.drop_prob = drop_prob 51 | 52 | def __repr__(self): 53 | msg = super().__repr__() 54 | msg += f'(drop_prob={self.drop_prob})' 55 | return msg 56 | 57 | 58 | class PatchEmbed(nn.Module): 59 | def __init__(self, in_chans, embed_dim, resolution, activation): 60 | super().__init__() 61 | img_size: Tuple[int, int] = to_2tuple(resolution) 62 | self.patches_resolution = (img_size[0] // 4, img_size[1] // 4) 63 | self.num_patches = self.patches_resolution[0] * \ 64 | self.patches_resolution[1] 65 | self.in_chans = in_chans 66 | self.embed_dim = embed_dim 67 | n = embed_dim 68 | self.seq = nn.Sequential( 69 | Conv2d_BN(in_chans, n // 2, 3, 2, 1), 70 | activation(), 71 | Conv2d_BN(n // 2, n, 3, 2, 1), 72 | ) 73 | 74 | def forward(self, x): 75 | return self.seq(x) 76 | 77 | 78 | class MBConv(nn.Module): 79 | def __init__(self, in_chans, out_chans, expand_ratio, 80 | activation, drop_path): 81 | super().__init__() 82 | self.in_chans = in_chans 83 | self.hidden_chans = int(in_chans * expand_ratio) 84 | self.out_chans = out_chans 85 | 86 | self.conv1 = Conv2d_BN(in_chans, self.hidden_chans, ks=1) 87 | self.act1 = activation() 88 | 89 | self.conv2 = Conv2d_BN(self.hidden_chans, self.hidden_chans, 
90 | ks=3, stride=1, pad=1, groups=self.hidden_chans) 91 | self.act2 = activation() 92 | 93 | self.conv3 = Conv2d_BN( 94 | self.hidden_chans, out_chans, ks=1, bn_weight_init=0.0) 95 | self.act3 = activation() 96 | 97 | self.drop_path = DropPath( 98 | drop_path) if drop_path > 0. else nn.Identity() 99 | 100 | def forward(self, x): 101 | shortcut = x 102 | 103 | x = self.conv1(x) 104 | x = self.act1(x) 105 | 106 | x = self.conv2(x) 107 | x = self.act2(x) 108 | 109 | x = self.conv3(x) 110 | 111 | x = self.drop_path(x) 112 | 113 | x += shortcut 114 | x = self.act3(x) 115 | 116 | return x 117 | 118 | 119 | class PatchMerging(nn.Module): 120 | def __init__(self, input_resolution, dim, out_dim, activation): 121 | super().__init__() 122 | 123 | self.input_resolution = input_resolution 124 | self.dim = dim 125 | self.out_dim = out_dim 126 | self.act = activation() 127 | self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0) 128 | stride_c=2 129 | if(out_dim==320 or out_dim==448 or out_dim==576): 130 | stride_c=1 131 | self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim) 132 | self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0) 133 | 134 | def forward(self, x): 135 | if x.ndim == 3: 136 | H, W = self.input_resolution 137 | B = len(x) 138 | # (B, C, H, W) 139 | x = x.view(B, H, W, -1).permute(0, 3, 1, 2) 140 | 141 | x = self.conv1(x) 142 | x = self.act(x) 143 | 144 | x = self.conv2(x) 145 | x = self.act(x) 146 | x = self.conv3(x) 147 | x = x.flatten(2).transpose(1, 2) 148 | return x 149 | 150 | 151 | class ConvLayer(nn.Module): 152 | def __init__(self, dim, input_resolution, depth, 153 | activation, 154 | drop_path=0., downsample=None, use_checkpoint=False, 155 | out_dim=None, 156 | conv_expand_ratio=4., 157 | ): 158 | 159 | super().__init__() 160 | self.dim = dim 161 | self.input_resolution = input_resolution 162 | self.depth = depth 163 | self.use_checkpoint = use_checkpoint 164 | 165 | # build blocks 166 | self.blocks = nn.ModuleList([ 167 | MBConv(dim, dim, conv_expand_ratio, activation, 168 | drop_path[i] if isinstance(drop_path, list) else drop_path, 169 | ) 170 | for i in range(depth)]) 171 | 172 | # patch merging layer 173 | if downsample is not None: 174 | self.downsample = downsample( 175 | input_resolution, dim=dim, out_dim=out_dim, activation=activation) 176 | else: 177 | self.downsample = None 178 | 179 | def forward(self, x): 180 | for blk in self.blocks: 181 | if self.use_checkpoint: 182 | x = checkpoint.checkpoint(blk, x) 183 | else: 184 | x = blk(x) 185 | if self.downsample is not None: 186 | x = self.downsample(x) 187 | return x 188 | 189 | 190 | class Mlp(nn.Module): 191 | def __init__(self, in_features, hidden_features=None, 192 | out_features=None, act_layer=nn.GELU, drop=0.): 193 | super().__init__() 194 | out_features = out_features or in_features 195 | hidden_features = hidden_features or in_features 196 | self.norm = nn.LayerNorm(in_features) 197 | self.fc1 = nn.Linear(in_features, hidden_features) 198 | self.fc2 = nn.Linear(hidden_features, out_features) 199 | self.act = act_layer() 200 | self.drop = nn.Dropout(drop) 201 | 202 | def forward(self, x): 203 | x = self.norm(x) 204 | 205 | x = self.fc1(x) 206 | x = self.act(x) 207 | x = self.drop(x) 208 | x = self.fc2(x) 209 | x = self.drop(x) 210 | return x 211 | 212 | 213 | class Attention(torch.nn.Module): 214 | def __init__(self, dim, key_dim, num_heads=8, 215 | attn_ratio=4, 216 | resolution=(14, 14), 217 | ): 218 | super().__init__() 219 | # (h, w) 220 | assert isinstance(resolution, tuple) and len(resolution) == 2 
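        # Per head, queries/keys use key_dim channels and values use
        # attn_ratio * key_dim; a learned bias table indexed by relative
        # (dy, dx) offsets within `resolution` is built further below.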
221 | self.num_heads = num_heads 222 | self.scale = key_dim ** -0.5 223 | self.key_dim = key_dim 224 | self.nh_kd = nh_kd = key_dim * num_heads 225 | self.d = int(attn_ratio * key_dim) 226 | self.dh = int(attn_ratio * key_dim) * num_heads 227 | self.attn_ratio = attn_ratio 228 | h = self.dh + nh_kd * 2 229 | 230 | self.norm = nn.LayerNorm(dim) 231 | # self.qkv = lora.Linear(dim, h, r=16) 232 | # self.proj = lora.Linear(self.dh, dim, r=16) 233 | self.qkv = nn.Linear(dim, h) 234 | self.proj = nn.Linear(self.dh, dim) 235 | 236 | points = list(itertools.product( 237 | range(resolution[0]), range(resolution[1]))) 238 | N = len(points) 239 | attention_offsets = {} 240 | idxs = [] 241 | for p1 in points: 242 | for p2 in points: 243 | offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) 244 | if offset not in attention_offsets: 245 | attention_offsets[offset] = len(attention_offsets) 246 | idxs.append(attention_offsets[offset]) 247 | self.attention_biases = torch.nn.Parameter( 248 | torch.zeros(num_heads, len(attention_offsets))) 249 | self.register_buffer('attention_bias_idxs', 250 | torch.LongTensor(idxs).view(N, N), 251 | persistent=False) 252 | 253 | @torch.no_grad() 254 | def train(self, mode=True): 255 | super().train(mode) 256 | if mode and hasattr(self, 'ab'): 257 | del self.ab 258 | else: 259 | self.ab = self.attention_biases[:, self.attention_bias_idxs] 260 | 261 | def forward(self, x): # x (B,N,C) 262 | B, N, _ = x.shape 263 | 264 | # Normalization 265 | x = self.norm(x) 266 | 267 | qkv = self.qkv(x) 268 | # (B, N, num_heads, d) 269 | q, k, v = qkv.view(B, N, self.num_heads, - 270 | 1).split([self.key_dim, self.key_dim, self.d], dim=3) 271 | # (B, num_heads, N, d) 272 | q = q.permute(0, 2, 1, 3) 273 | k = k.permute(0, 2, 1, 3) 274 | v = v.permute(0, 2, 1, 3) 275 | 276 | attn = ( 277 | (q @ k.transpose(-2, -1)) * self.scale 278 | + 279 | (self.attention_biases[:, self.attention_bias_idxs] 280 | if self.training else self.ab) 281 | ) 282 | attn = attn.softmax(dim=-1) 283 | x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) 284 | x = self.proj(x) 285 | return x 286 | 287 | 288 | class TinyViTBlock(nn.Module): 289 | r""" TinyViT Block. 290 | 291 | Args: 292 | dim (int): Number of input channels. 293 | input_resolution (tuple[int, int]): Input resulotion. 294 | num_heads (int): Number of attention heads. 295 | window_size (int): Window size. 296 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 297 | drop (float, optional): Dropout rate. Default: 0.0 298 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 299 | local_conv_size (int): the kernel size of the convolution between 300 | Attention and MLP. Default: 3 301 | activation: the activation function. Default: nn.GELU 302 | """ 303 | 304 | def __init__(self, dim, input_resolution, num_heads, window_size=7, 305 | mlp_ratio=4., drop=0., drop_path=0., 306 | local_conv_size=3, 307 | activation=nn.GELU, 308 | ): 309 | super().__init__() 310 | self.dim = dim 311 | self.input_resolution = input_resolution 312 | self.num_heads = num_heads 313 | assert window_size > 0, 'window_size must be greater than 0' 314 | self.window_size = window_size 315 | self.mlp_ratio = mlp_ratio 316 | 317 | self.drop_path = DropPath( 318 | drop_path) if drop_path > 0. 
else nn.Identity() 319 | 320 | assert dim % num_heads == 0, 'dim must be divisible by num_heads' 321 | head_dim = dim // num_heads 322 | 323 | window_resolution = (window_size, window_size) 324 | self.attn = Attention(dim, head_dim, num_heads, 325 | attn_ratio=1, resolution=window_resolution) 326 | 327 | mlp_hidden_dim = int(dim * mlp_ratio) 328 | mlp_activation = activation 329 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, 330 | act_layer=mlp_activation, drop=drop) 331 | 332 | pad = local_conv_size // 2 333 | self.local_conv = Conv2d_BN( 334 | dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim) 335 | 336 | def forward(self, x): 337 | H, W = self.input_resolution 338 | B, L, C = x.shape 339 | assert L == H * W, "input feature has wrong size" 340 | res_x = x 341 | if H == self.window_size and W == self.window_size: 342 | x = self.attn(x) 343 | else: 344 | x = x.view(B, H, W, C) 345 | pad_b = (self.window_size - H % 346 | self.window_size) % self.window_size 347 | pad_r = (self.window_size - W % 348 | self.window_size) % self.window_size 349 | padding = pad_b > 0 or pad_r > 0 350 | 351 | if padding: 352 | x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b)) 353 | 354 | pH, pW = H + pad_b, W + pad_r 355 | nH = pH // self.window_size 356 | nW = pW // self.window_size 357 | # window partition 358 | x = x.view(B, nH, self.window_size, nW, self.window_size, C).transpose(2, 3).reshape( 359 | B * nH * nW, self.window_size * self.window_size, C) 360 | x = self.attn(x) 361 | # window reverse 362 | x = x.view(B, nH, nW, self.window_size, self.window_size, 363 | C).transpose(2, 3).reshape(B, pH, pW, C) 364 | 365 | if padding: 366 | x = x[:, :H, :W].contiguous() 367 | 368 | x = x.view(B, L, C) 369 | 370 | x = res_x + self.drop_path(x) 371 | 372 | x = x.transpose(1, 2).reshape(B, C, H, W) 373 | x = self.local_conv(x) 374 | x = x.view(B, C, L).transpose(1, 2) 375 | 376 | x = x + self.drop_path(self.mlp(x)) 377 | return x 378 | 379 | def extra_repr(self) -> str: 380 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 381 | f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}" 382 | 383 | 384 | class BasicLayer(nn.Module): 385 | """ A basic TinyViT layer for one stage. 386 | 387 | Args: 388 | dim (int): Number of input channels. 389 | input_resolution (tuple[int]): Input resolution. 390 | depth (int): Number of blocks. 391 | num_heads (int): Number of attention heads. 392 | window_size (int): Local window size. 393 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 394 | drop (float, optional): Dropout rate. Default: 0.0 395 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 396 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 397 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 398 | local_conv_size: the kernel size of the depthwise convolution between attention and MLP. Default: 3 399 | activation: the activation function. Default: nn.GELU 400 | out_dim: the output dimension of the layer. 
Default: dim 401 | """ 402 | 403 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 404 | mlp_ratio=4., drop=0., 405 | drop_path=0., downsample=None, use_checkpoint=False, 406 | local_conv_size=3, 407 | activation=nn.GELU, 408 | out_dim=None, 409 | ): 410 | 411 | super().__init__() 412 | self.dim = dim 413 | self.input_resolution = input_resolution 414 | self.depth = depth 415 | self.use_checkpoint = use_checkpoint 416 | 417 | # build blocks 418 | self.blocks = nn.ModuleList([ 419 | TinyViTBlock(dim=dim, input_resolution=input_resolution, 420 | num_heads=num_heads, window_size=window_size, 421 | mlp_ratio=mlp_ratio, 422 | drop=drop, 423 | drop_path=drop_path[i] if isinstance( 424 | drop_path, list) else drop_path, 425 | local_conv_size=local_conv_size, 426 | activation=activation, 427 | ) 428 | for i in range(depth)]) 429 | 430 | # patch merging layer 431 | if downsample is not None: 432 | self.downsample = downsample( 433 | input_resolution, dim=dim, out_dim=out_dim, activation=activation) 434 | else: 435 | self.downsample = None 436 | 437 | def forward(self, x): 438 | for blk in self.blocks: 439 | if self.use_checkpoint: 440 | x = checkpoint.checkpoint(blk, x) 441 | else: 442 | x = blk(x) 443 | if self.downsample is not None: 444 | x = self.downsample(x) 445 | return x 446 | 447 | def extra_repr(self) -> str: 448 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 449 | 450 | class LayerNorm2d(nn.Module): 451 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 452 | super().__init__() 453 | self.weight = nn.Parameter(torch.ones(num_channels)) 454 | self.bias = nn.Parameter(torch.zeros(num_channels)) 455 | self.eps = eps 456 | 457 | def forward(self, x: torch.Tensor) -> torch.Tensor: 458 | u = x.mean(1, keepdim=True) 459 | s = (x - u).pow(2).mean(1, keepdim=True) 460 | x = (x - u) / torch.sqrt(s + self.eps) 461 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 462 | return x 463 | class TinyViT(nn.Module): 464 | def __init__(self, img_size=224, in_chans=3, num_classes=1000, 465 | embed_dims=[96, 192, 384, 768], depths=[2, 2, 6, 2], 466 | num_heads=[3, 6, 12, 24], 467 | window_sizes=[7, 7, 14, 7], 468 | mlp_ratio=4., 469 | drop_rate=0., 470 | drop_path_rate=0.1, 471 | use_checkpoint=False, 472 | mbconv_expand_ratio=4.0, 473 | local_conv_size=3, 474 | layer_lr_decay=1.0, 475 | ): 476 | super().__init__() 477 | self.img_size=img_size 478 | self.num_classes = num_classes 479 | self.depths = depths 480 | self.num_layers = len(depths) 481 | self.mlp_ratio = mlp_ratio 482 | 483 | activation = nn.GELU 484 | 485 | self.patch_embed = PatchEmbed(in_chans=in_chans, 486 | embed_dim=embed_dims[0], 487 | resolution=img_size, 488 | activation=activation) 489 | 490 | patches_resolution = self.patch_embed.patches_resolution 491 | self.patches_resolution = patches_resolution 492 | 493 | # stochastic depth 494 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, 495 | sum(depths))] # stochastic depth decay rule 496 | 497 | # build layers 498 | self.layers = nn.ModuleList() 499 | for i_layer in range(self.num_layers): 500 | kwargs = dict(dim=embed_dims[i_layer], 501 | input_resolution=(patches_resolution[0] // (2 ** (i_layer-1 if i_layer == 3 else i_layer)), 502 | patches_resolution[1] // (2 ** (i_layer-1 if i_layer == 3 else i_layer))), 503 | # input_resolution=(patches_resolution[0] // (2 ** i_layer), 504 | # patches_resolution[1] // (2 ** i_layer)), 505 | depth=depths[i_layer], 506 | 
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 507 | downsample=PatchMerging if ( 508 | i_layer < self.num_layers - 1) else None, 509 | use_checkpoint=use_checkpoint, 510 | out_dim=embed_dims[min( 511 | i_layer + 1, len(embed_dims) - 1)], 512 | activation=activation, 513 | ) 514 | if i_layer == 0: 515 | layer = ConvLayer( 516 | conv_expand_ratio=mbconv_expand_ratio, 517 | **kwargs, 518 | ) 519 | else: 520 | layer = BasicLayer( 521 | num_heads=num_heads[i_layer], 522 | window_size=window_sizes[i_layer], 523 | mlp_ratio=self.mlp_ratio, 524 | drop=drop_rate, 525 | local_conv_size=local_conv_size, 526 | **kwargs) 527 | self.layers.append(layer) 528 | 529 | # Classifier head 530 | self.norm_head = nn.LayerNorm(embed_dims[-1]) 531 | self.head = nn.Linear( 532 | embed_dims[-1], num_classes) if num_classes > 0 else torch.nn.Identity() 533 | # self.head = lora.Linear( 534 | # embed_dims[-1], num_classes, r=16) if num_classes > 0 else torch.nn.Identity() 535 | 536 | # init weights 537 | self.apply(self._init_weights) 538 | self.set_layer_lr_decay(layer_lr_decay) 539 | self.neck = nn.Sequential( 540 | nn.Conv2d( 541 | embed_dims[-1], 542 | 256, 543 | kernel_size=1, 544 | bias=False, 545 | ), 546 | LayerNorm2d(256), 547 | nn.Conv2d( 548 | 256, 549 | 256, 550 | kernel_size=3, 551 | padding=1, 552 | bias=False, 553 | ), 554 | LayerNorm2d(256), 555 | ) 556 | def set_layer_lr_decay(self, layer_lr_decay): 557 | decay_rate = layer_lr_decay 558 | 559 | # layers -> blocks (depth) 560 | depth = sum(self.depths) 561 | lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)] 562 | print("LR SCALES:", lr_scales) 563 | 564 | def _set_lr_scale(m, scale): 565 | for p in m.parameters(): 566 | p.lr_scale = scale 567 | 568 | self.patch_embed.apply(lambda x: _set_lr_scale(x, lr_scales[0])) 569 | i = 0 570 | for layer in self.layers: 571 | for block in layer.blocks: 572 | block.apply(lambda x: _set_lr_scale(x, lr_scales[i])) 573 | i += 1 574 | if layer.downsample is not None: 575 | layer.downsample.apply( 576 | lambda x: _set_lr_scale(x, lr_scales[i - 1])) 577 | assert i == depth 578 | for m in [self.norm_head, self.head]: 579 | m.apply(lambda x: _set_lr_scale(x, lr_scales[-1])) 580 | 581 | for k, p in self.named_parameters(): 582 | p.param_name = k 583 | 584 | def _check_lr_scale(m): 585 | for p in m.parameters(): 586 | assert hasattr(p, 'lr_scale'), p.param_name 587 | 588 | self.apply(_check_lr_scale) 589 | 590 | def _init_weights(self, m): 591 | if isinstance(m, nn.Linear): 592 | trunc_normal_(m.weight, std=.02) 593 | if isinstance(m, nn.Linear) and m.bias is not None: 594 | nn.init.constant_(m.bias, 0) 595 | elif isinstance(m, nn.LayerNorm): 596 | nn.init.constant_(m.bias, 0) 597 | nn.init.constant_(m.weight, 1.0) 598 | 599 | @torch.jit.ignore 600 | def no_weight_decay_keywords(self): 601 | return {'attention_biases'} 602 | 603 | def forward_features(self, x): 604 | # x: (N, C, H, W) 605 | x = self.patch_embed(x) 606 | 607 | x = self.layers[0](x) 608 | start_i = 1 609 | 610 | for i in range(start_i, len(self.layers)): 611 | layer = self.layers[i] 612 | x = layer(x) 613 | B,_,C=x.size() 614 | x = x.view(B, 64, 64, C) 615 | x=x.permute(0, 3, 1, 2) 616 | x=self.neck(x) 617 | return x 618 | 619 | def forward(self, x): 620 | x = self.forward_features(x) 621 | #x = self.norm_head(x) 622 | #x = self.head(x) 623 | return x 624 | 625 | 626 | _checkpoint_url_format = \ 627 | 'https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/{}.pth' 628 | _provided_checkpoints = { 629 | 
'tiny_vit_5m_224': 'tiny_vit_5m_22kto1k_distill', 630 | 'tiny_vit_11m_224': 'tiny_vit_11m_22kto1k_distill', 631 | 'tiny_vit_21m_224': 'tiny_vit_21m_22kto1k_distill', 632 | 'tiny_vit_21m_384': 'tiny_vit_21m_22kto1k_384_distill', 633 | 'tiny_vit_21m_512': 'tiny_vit_21m_22kto1k_512_distill', 634 | } 635 | 636 | 637 | def register_tiny_vit_model(fn): 638 | '''Register a TinyViT model 639 | It is a wrapper of `register_model` that also loads the pretrained checkpoint. 640 | ''' 641 | def fn_wrapper(pretrained=False, **kwargs): 642 | model = fn() # the factory is called with its defaults; extra kwargs (e.g. num_classes) are not forwarded 643 | if pretrained: 644 | model_name = fn.__name__ 645 | assert model_name in _provided_checkpoints, \ 646 | f'The checkpoint `{model_name}` is not provided yet.' 647 | url = _checkpoint_url_format.format( 648 | _provided_checkpoints[model_name]) 649 | checkpoint = torch.hub.load_state_dict_from_url( 650 | url=url, 651 | map_location='cpu', check_hash=False, 652 | ) 653 | model.load_state_dict(checkpoint['model']) 654 | 655 | return model 656 | 657 | # expose the wrapped function under the original model name 658 | fn_wrapper.__name__ = fn.__name__ 659 | return register_model(fn_wrapper) 660 | 661 | 662 | @register_tiny_vit_model 663 | def tiny_vit_5m_224(pretrained=False, num_classes=1000, drop_path_rate=0.0): 664 | return TinyViT( 665 | num_classes=num_classes, 666 | embed_dims=[64, 128, 160, 320], 667 | depths=[2, 2, 6, 2], 668 | num_heads=[2, 4, 5, 10], 669 | window_sizes=[7, 7, 14, 7], 670 | drop_path_rate=drop_path_rate, 671 | ) 672 | 673 | 674 | @register_tiny_vit_model 675 | def tiny_vit_11m_224(pretrained=False, num_classes=1000, drop_path_rate=0.1): 676 | return TinyViT( 677 | num_classes=num_classes, 678 | embed_dims=[64, 128, 256, 448], 679 | depths=[2, 2, 6, 2], 680 | num_heads=[2, 4, 8, 14], 681 | window_sizes=[7, 7, 14, 7], 682 | drop_path_rate=drop_path_rate, 683 | ) 684 | 685 | 686 | @register_tiny_vit_model 687 | def tiny_vit_21m_224(pretrained=False, num_classes=1000, drop_path_rate=0.2): 688 | return TinyViT( 689 | num_classes=num_classes, 690 | embed_dims=[96, 192, 384, 576], 691 | depths=[2, 2, 6, 2], 692 | num_heads=[3, 6, 12, 18], 693 | window_sizes=[7, 7, 14, 7], 694 | drop_path_rate=drop_path_rate, 695 | ) 696 | 697 | 698 | @register_tiny_vit_model 699 | def tiny_vit_21m_384(pretrained=False, num_classes=1000, drop_path_rate=0.1): 700 | return TinyViT( 701 | img_size=384, 702 | num_classes=num_classes, 703 | embed_dims=[96, 192, 384, 576], 704 | depths=[2, 2, 6, 2], 705 | num_heads=[3, 6, 12, 18], 706 | window_sizes=[12, 12, 24, 12], 707 | drop_path_rate=drop_path_rate, 708 | ) 709 | 710 | 711 | @register_tiny_vit_model 712 | def tiny_vit_21m_512(pretrained=False, num_classes=1000, drop_path_rate=0.1): 713 | return TinyViT( 714 | img_size=512, 715 | num_classes=num_classes, 716 | embed_dims=[96, 192, 384, 576], 717 | depths=[2, 2, 6, 2], 718 | num_heads=[3, 6, 12, 18], 719 | window_sizes=[16, 16, 32, 16], 720 | drop_path_rate=drop_path_rate, 721 | ) 722 | --------------------------------------------------------------------------------
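The classification heads above are not what matters for segmentation: `forward()` returns the `neck` output of `forward_features()`, and the hard-coded `view(B, 64, 64, C)` means the model expects a 64x64 token grid. Below is a minimal usage sketch, not part of the repository, assuming the file is importable as `modeling.tiny_vit_sam` with the repository root on `PYTHONPATH`, that the patch embedding downsamples by 4 (so 1024x1024 inputs give the required grid after two stride-2 merges and one stride-1 merge), and the 5M configuration from `tiny_vit_5m_224`.
```
# Sketch (not part of the repository): build the TinyViT-5M variant as a
# SAM-style image encoder and check the embedding shape fed to a mask decoder.
import torch

from modeling.tiny_vit_sam import TinyViT  # assumed import path (repo root on PYTHONPATH)

encoder = TinyViT(
    img_size=1024,                   # assumed: 1024 / 4 = 256 tokens per side after patch embedding
    in_chans=3,
    embed_dims=[64, 128, 160, 320],  # 5M configuration, as in tiny_vit_5m_224
    depths=[2, 2, 6, 2],
    num_heads=[2, 4, 5, 10],
    window_sizes=[7, 7, 14, 7],
    drop_path_rate=0.0,
)
encoder.eval()  # also caches the attention-bias tables via the overridden train()

with torch.no_grad():
    x = torch.randn(1, 3, 1024, 1024)
    feats = encoder(x)  # forward() returns the neck output of forward_features()

# Two stride-2 patch mergings (256 -> 128 -> 64) and one stride-1 merging keep a
# 64x64 grid, which forward_features reshapes before the 256-channel neck.
print(feats.shape)  # expected: torch.Size([1, 256, 64, 64])
```
If the token grid does not come out to 64x64 (e.g. with the default `img_size=224`), the `view(B, 64, 64, C)` in `forward_features` will raise, so the input size is effectively fixed by that reshape.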