├── data └── .gitkeep ├── models ├── tag │ ├── __init__.py │ └── tag_layers.py ├── unet │ ├── __init__.py │ ├── unet_parts.py │ └── res_net.py ├── __pycache__ │ └── discriminator.cpython-311.pyc ├── sam │ ├── __pycache__ │ │ ├── __init__.cpython-311.pyc │ │ ├── build_sam.cpython-311.pyc │ │ ├── predictor.cpython-311.pyc │ │ └── automatic_mask_generator.cpython-311.pyc │ ├── utils │ │ ├── __pycache__ │ │ │ ├── amg.cpython-311.pyc │ │ │ ├── __init__.cpython-311.pyc │ │ │ └── transforms.cpython-311.pyc │ │ ├── __init__.py │ │ ├── transforms.py │ │ ├── onnx.py │ │ └── amg.py │ ├── modeling │ │ ├── __pycache__ │ │ │ ├── sam.cpython-311.pyc │ │ │ ├── common.cpython-311.pyc │ │ │ ├── __init__.cpython-311.pyc │ │ │ ├── mask_decoder.cpython-311.pyc │ │ │ ├── transformer.cpython-311.pyc │ │ │ ├── image_encoder.cpython-311.pyc │ │ │ └── prompt_encoder.cpython-311.pyc │ │ ├── __init__.py │ │ ├── common.py │ │ ├── mask_decoder.py │ │ ├── sam.py │ │ ├── transformer.py │ │ └── prompt_encoder.py │ ├── __init__.py │ ├── build_sam.py │ └── predictor.py ├── types_.py ├── vgg.py ├── squeezenet.py ├── implicitnet.py ├── discriminator.py ├── vae.py ├── senet.py ├── resnet.py └── implicitefficientnet.py ├── figs └── medsamadpt.jpeg ├── .gitignore ├── conf ├── __pycache__ │ ├── __init__.cpython-311.pyc │ └── global_settings.cpython-311.pyc ├── __init__.py └── global_settings.py ├── pytorch_ssim ├── __pycache__ │ └── __init__.cpython-311.pyc └── __init__.py ├── start_train.sh ├── start_val.sh ├── start_predict.sh ├── delete_history.sh ├── .github └── workflows │ └── issue-translator.yml ├── cfg.py ├── predict.py ├── val.py ├── train.py ├── precpt.py ├── post_processing.ipynb ├── README.md ├── dataset.py ├── environment.yml └── pre_processing.ipynb /data/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/tag/__init__.py: -------------------------------------------------------------------------------- 1 | from .tag import * -------------------------------------------------------------------------------- /models/unet/__init__.py: -------------------------------------------------------------------------------- 1 | from .unet_model import TransUNet 2 | -------------------------------------------------------------------------------- /figs/medsamadpt.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Medical-SAM-Adapter/main/figs/medsamadpt.jpeg -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /checkpoint 2 | /logs 3 | /runs 4 | pipline.sh 5 | /data/* 6 | !/data/.gitkeep 7 | __pycache__ 8 | my_weights -------------------------------------------------------------------------------- /conf/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Medical-SAM-Adapter/main/conf/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /conf/__pycache__/global_settings.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Medical-SAM-Adapter/main/conf/__pycache__/global_settings.cpython-311.pyc 
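The directory tree and .gitignore above imply that datasets, SAM checkpoints, logs, and TensorBoard runs all live in untracked local directories. Below is a minimal setup sketch for a fresh clone; the checkpoint path and download URL are the ones used by start_train.sh and models/sam/build_sam.py later in this dump, while the layout under ./data is only an assumption, since the structure expected by dataset.py is not shown here.

#!/bin/sh
# Create the untracked working directories listed in .gitignore.
mkdir -p data checkpoint/sam logs runs
# Place images and masks under ./data (hypothetical layout, adjust to whatever dataset.py expects):
#   data/train/...  data/test/...  data/predict/...
# Fetch the SAM ViT-B weights referenced by the start_*.sh scripts
# (build_sam.py can also offer to download them interactively on first run).
wget -O checkpoint/sam/sam_vit_b_01ec64.pth \
  https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth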
-------------------------------------------------------------------------------- /models/__pycache__/discriminator.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Medical-SAM-Adapter/main/models/__pycache__/discriminator.cpython-311.pyc -------------------------------------------------------------------------------- /models/sam/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Medical-SAM-Adapter/main/models/sam/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /models/sam/__pycache__/build_sam.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Medical-SAM-Adapter/main/models/sam/__pycache__/build_sam.cpython-311.pyc -------------------------------------------------------------------------------- /models/sam/__pycache__/predictor.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Medical-SAM-Adapter/main/models/sam/__pycache__/predictor.cpython-311.pyc -------------------------------------------------------------------------------- /models/sam/utils/__pycache__/amg.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Medical-SAM-Adapter/main/models/sam/utils/__pycache__/amg.cpython-311.pyc -------------------------------------------------------------------------------- /models/types_.py: -------------------------------------------------------------------------------- 1 | from typing import List, Callable, Union, Any, TypeVar, Tuple 2 | # from torch import tensor as Tensor 3 | 4 | Tensor = TypeVar('torch.tensor') -------------------------------------------------------------------------------- /pytorch_ssim/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Medical-SAM-Adapter/main/pytorch_ssim/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /models/sam/modeling/__pycache__/sam.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Medical-SAM-Adapter/main/models/sam/modeling/__pycache__/sam.cpython-311.pyc -------------------------------------------------------------------------------- /models/sam/utils/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Medical-SAM-Adapter/main/models/sam/utils/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /models/sam/modeling/__pycache__/common.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Medical-SAM-Adapter/main/models/sam/modeling/__pycache__/common.cpython-311.pyc -------------------------------------------------------------------------------- /models/sam/utils/__pycache__/transforms.cpython-311.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Medical-SAM-Adapter/main/models/sam/utils/__pycache__/transforms.cpython-311.pyc -------------------------------------------------------------------------------- /models/sam/modeling/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Medical-SAM-Adapter/main/models/sam/modeling/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /models/sam/modeling/__pycache__/mask_decoder.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Medical-SAM-Adapter/main/models/sam/modeling/__pycache__/mask_decoder.cpython-311.pyc -------------------------------------------------------------------------------- /models/sam/modeling/__pycache__/transformer.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Medical-SAM-Adapter/main/models/sam/modeling/__pycache__/transformer.cpython-311.pyc -------------------------------------------------------------------------------- /models/sam/__pycache__/automatic_mask_generator.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Medical-SAM-Adapter/main/models/sam/__pycache__/automatic_mask_generator.cpython-311.pyc -------------------------------------------------------------------------------- /models/sam/modeling/__pycache__/image_encoder.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Medical-SAM-Adapter/main/models/sam/modeling/__pycache__/image_encoder.cpython-311.pyc -------------------------------------------------------------------------------- /models/sam/modeling/__pycache__/prompt_encoder.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Medical-SAM-Adapter/main/models/sam/modeling/__pycache__/prompt_encoder.cpython-311.pyc -------------------------------------------------------------------------------- /start_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | python train.py -net sam -mod sam_adpt -exp_name mydataset -sam_ckpt ./checkpoint/sam/sam_vit_b_01ec64.pth -image_size 1024 -out_size 256 -b 1 -dataset mydataset -data_path ./data -val_freq 5 -vis 100 3 | -------------------------------------------------------------------------------- /models/sam/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /start_val.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | echo "Input model weights path: " 3 | read -r weights 4 | 5 | python val.py -net sam -mod sam_adpt -exp_name mydataset -sam_ckpt ./checkpoint/sam/sam_vit_b_01ec64.pth -image_size 1024 -out_size 256 -b 1 -dataset mydataset -data_path ./data -val_freq 1 -vis 1 -weights $weights -------------------------------------------------------------------------------- /start_predict.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | echo "Input model weights path: " 3 | read -r weights 4 | 5 | python predict.py -net sam -mod sam_adpt -exp_name mydataset -sam_ckpt ./checkpoint/sam/sam_vit_b_01ec64.pth -image_size 1024 -out_size 256 -b 1 -dataset mydataset -data_path ./data -val_freq 1 -vis 1 -weights $weights -------------------------------------------------------------------------------- /conf/__init__.py: -------------------------------------------------------------------------------- 1 | """ dynamically load settings 2 | 3 | author baiyu 4 | """ 5 | import conf.global_settings as settings 6 | 7 | class Settings: 8 | def __init__(self, settings): 9 | 10 | for attr in dir(settings): 11 | if attr.isupper(): 12 | setattr(self, attr, getattr(settings, attr)) 13 | 14 | settings = Settings(settings) -------------------------------------------------------------------------------- /delete_history.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # confirm 4 | echo "Are you sure to delete all history? [y/n]" 5 | read -r confirm 6 | if [ "$confirm" != "y" ]; then 7 | echo "Abort." 8 | exit 0 9 | fi 10 | 11 | # delete 12 | echo "Deleting checkpoint" 13 | rm -rf ./checkpoint/sam/2023* 14 | echo "Deleting logs" 15 | rm -rf ./logs/* 16 | echo "Deleting runs" 17 | rm -rf ./runs/* 18 | echo "Done." -------------------------------------------------------------------------------- /models/sam/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer 12 | -------------------------------------------------------------------------------- /models/sam/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | sam_model_registry, 13 | ) 14 | from .predictor import SamPredictor 15 | from .automatic_mask_generator import SamAutomaticMaskGenerator 16 | -------------------------------------------------------------------------------- /.github/workflows/issue-translator.yml: -------------------------------------------------------------------------------- 1 | name: 'issue-translator' 2 | on: 3 | issue_comment: 4 | types: [created] 5 | issues: 6 | types: [opened] 7 | 8 | jobs: 9 | build: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: usthe/issues-translate-action@v2.7 13 | with: 14 | IS_MODIFY_TITLE: false 15 | # not require, default false, . Decide whether to modify the issue title 16 | # if true, the robot account @Issues-translate-bot must have modification permissions, invite @Issues-translate-bot to your project or use your custom bot. 17 | CUSTOM_BOT_NOTE: Bot detected the issue body's language is not English, translate it automatically. 👯👭🏻🧑‍🤝‍🧑👫🧑🏿‍🤝‍🧑🏻👩🏾‍🤝‍👨🏿👬🏿 18 | # not require. Customize the translation robot prefix message. 19 | -------------------------------------------------------------------------------- /conf/global_settings.py: -------------------------------------------------------------------------------- 1 | """ configurations for this project 2 | 3 | author Junde 4 | """ 5 | import os 6 | from datetime import datetime 7 | 8 | #CIFAR100 dataset path (python version) 9 | #CIFAR100_PATH = '/nfs/private/cifar100/cifar-100-python' 10 | 11 | #mean and std of cifar100 dataset 12 | CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343) 13 | CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404) 14 | 15 | GLAUCOMA_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343) 16 | GLAUCOMA_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404) 17 | 18 | MASK_TRAIN_MEAN = (2.654204690220496/255) 19 | MASK_TRAIN_STD = (21.46473779720519/255) 20 | 21 | #CIFAR100_TEST_MEAN = (0.5088964127604166, 0.48739301317401956, 0.44194221124387256) 22 | #CIFAR100_TEST_STD = (0.2682515741720801, 0.2573637364478126, 0.2770957707973042) 23 | 24 | #directory to save weights file 25 | CHECKPOINT_PATH = 'checkpoint' 26 | 27 | #total training epoches 28 | EPOCH = 30000 29 | step_size = 10 30 | i = 1 31 | MILESTONES = [] 32 | while i * 5 <= EPOCH: 33 | MILESTONES.append(i* step_size) 34 | i += 1 35 | 36 | #initial learning rate 37 | #INIT_LR = 0.1 38 | 39 | #time of we run the script 40 | TIME_NOW = datetime.now().isoformat() 41 | 42 | #tensorboard log dir 43 | LOG_DIR = 'runs' 44 | 45 | #save weights file per SAVE_EPOCH epoch 46 | SAVE_EPOCH = 10 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | """vgg in pytorch 2 | 3 | 4 | [1] Karen Simonyan, Andrew Zisserman 5 | 6 | Very Deep Convolutional Networks for Large-Scale Image Recognition. 
7 | https://arxiv.org/abs/1409.1556v6 8 | """ 9 | '''VGG11/13/16/19 in Pytorch.''' 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | cfg = { 15 | 'A' : [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 16 | 'B' : [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 17 | 'D' : [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 18 | 'E' : [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'] 19 | } 20 | 21 | class VGG(nn.Module): 22 | 23 | def __init__(self, features, num_class=100): 24 | super().__init__() 25 | self.features = features 26 | 27 | self.classifier = nn.Sequential( 28 | nn.Linear(512, 4096), 29 | nn.ReLU(inplace=True), 30 | nn.Dropout(), 31 | nn.Linear(4096, 4096), 32 | nn.ReLU(inplace=True), 33 | nn.Dropout(), 34 | nn.Linear(4096, num_class) 35 | ) 36 | 37 | def forward(self, x): 38 | output = self.features(x) 39 | output = output.view(output.size()[0], -1) 40 | output = self.classifier(output) 41 | 42 | return output 43 | 44 | def make_layers(cfg, batch_norm=False): 45 | layers = [] 46 | 47 | input_channel = 3 48 | for l in cfg: 49 | if l == 'M': 50 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 51 | continue 52 | 53 | layers += [nn.Conv2d(input_channel, l, kernel_size=3, padding=1)] 54 | 55 | if batch_norm: 56 | layers += [nn.BatchNorm2d(l)] 57 | 58 | layers += [nn.ReLU(inplace=True)] 59 | input_channel = l 60 | 61 | return nn.Sequential(*layers) 62 | 63 | def vgg11_bn(): 64 | return VGG(make_layers(cfg['A'], batch_norm=True)) 65 | 66 | def vgg13_bn(): 67 | return VGG(make_layers(cfg['B'], batch_norm=True)) 68 | 69 | def vgg16_bn(): 70 | return VGG(make_layers(cfg['D'], batch_norm=True)) 71 | 72 | def vgg19_bn(): 73 | return VGG(make_layers(cfg['E'], batch_norm=True)) 74 | 75 | 76 | -------------------------------------------------------------------------------- /models/sam/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | class Adapter(nn.Module): 13 | def __init__(self, D_features, mlp_ratio=0.25, act_layer=nn.GELU, skip_connect=True): 14 | super().__init__() 15 | self.skip_connect = skip_connect 16 | D_hidden_features = int(D_features * mlp_ratio) 17 | self.act = act_layer() 18 | self.D_fc1 = nn.Linear(D_features, D_hidden_features) 19 | self.D_fc2 = nn.Linear(D_hidden_features, D_features) 20 | 21 | def forward(self, x): 22 | # x is (BT, HW+1, D) 23 | xs = self.D_fc1(x) 24 | xs = self.act(xs) 25 | xs = self.D_fc2(xs) 26 | if self.skip_connect: 27 | x = x + xs 28 | else: 29 | x = xs 30 | return x 31 | 32 | 33 | class MLPBlock(nn.Module): 34 | def __init__( 35 | self, 36 | embedding_dim: int, 37 | mlp_dim: int, 38 | act: Type[nn.Module] = nn.GELU, 39 | ) -> None: 40 | super().__init__() 41 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 42 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 43 | self.act = act() 44 | 45 | def forward(self, x: torch.Tensor) -> torch.Tensor: 46 | return self.lin2(self.act(self.lin1(x))) 47 | 48 | 49 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 50 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 51 | class LayerNorm2d(nn.Module): 52 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 53 | super().__init__() 54 | self.weight = nn.Parameter(torch.ones(num_channels)) 55 | self.bias = nn.Parameter(torch.zeros(num_channels)) 56 | self.eps = eps 57 | 58 | def forward(self, x: torch.Tensor) -> torch.Tensor: 59 | u = x.mean(1, keepdim=True) 60 | s = (x - u).pow(2).mean(1, keepdim=True) 61 | x = (x - u) / torch.sqrt(s + self.eps) 62 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 63 | return x 64 | -------------------------------------------------------------------------------- /models/squeezenet.py: -------------------------------------------------------------------------------- 1 | """squeezenet in pytorch 2 | """ 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class Fire(nn.Module): 8 | 9 | def __init__(self, in_channel, out_channel, squzee_channel): 10 | 11 | super().__init__() 12 | self.squeeze = nn.Sequential( 13 | nn.Conv2d(in_channel, squzee_channel, 1), 14 | nn.BatchNorm2d(squzee_channel), 15 | nn.ReLU(inplace=True) 16 | ) 17 | 18 | self.expand_1x1 = nn.Sequential( 19 | nn.Conv2d(squzee_channel, int(out_channel / 2), 1), 20 | nn.BatchNorm2d(int(out_channel / 2)), 21 | nn.ReLU(inplace=True) 22 | ) 23 | 24 | self.expand_3x3 = nn.Sequential( 25 | nn.Conv2d(squzee_channel, int(out_channel / 2), 3, padding=1), 26 | nn.BatchNorm2d(int(out_channel / 2)), 27 | nn.ReLU(inplace=True) 28 | ) 29 | 30 | def forward(self, x): 31 | 32 | x = self.squeeze(x) 33 | x = torch.cat([ 34 | self.expand_1x1(x), 35 | self.expand_3x3(x) 36 | ], 1) 37 | 38 | return x 39 | 40 | class SqueezeNet(nn.Module): 41 | 42 | """mobile net with simple bypass""" 43 | def __init__(self, class_num=100): 44 | 45 | super().__init__() 46 | self.stem = nn.Sequential( 47 | nn.Conv2d(3, 96, 3, padding=1), 48 | nn.BatchNorm2d(96), 49 | nn.ReLU(inplace=True), 50 | nn.MaxPool2d(2, 2) 51 | ) 52 | 53 | self.fire2 = Fire(96, 128, 16) 54 | self.fire3 = Fire(128, 128, 16) 55 | self.fire4 = Fire(128, 256, 32) 56 | self.fire5 = Fire(256, 256, 32) 57 | self.fire6 = Fire(256, 384, 48) 58 | self.fire7 = Fire(384, 384, 48) 59 | self.fire8 = Fire(384, 512, 64) 60 | 
self.fire9 = Fire(512, 512, 64) 61 | 62 | self.conv10 = nn.Conv2d(512, class_num, 1) 63 | self.avg = nn.AdaptiveAvgPool2d(1) 64 | self.maxpool = nn.MaxPool2d(2, 2) 65 | 66 | def forward(self, x): 67 | x = self.stem(x) 68 | 69 | f2 = self.fire2(x) 70 | f3 = self.fire3(f2) + f2 71 | f4 = self.fire4(f3) 72 | f4 = self.maxpool(f4) 73 | 74 | f5 = self.fire5(f4) + f4 75 | f6 = self.fire6(f5) 76 | f7 = self.fire7(f6) + f6 77 | f8 = self.fire8(f7) 78 | f8 = self.maxpool(f8) 79 | 80 | f9 = self.fire9(f8) 81 | c10 = self.conv10(f9) 82 | 83 | x = self.avg(c10) 84 | x = x.view(x.size(0), -1) 85 | 86 | return x 87 | 88 | def squeezenet(class_num=1): 89 | return SqueezeNet(class_num=class_num) 90 | -------------------------------------------------------------------------------- /models/unet/unet_parts.py: -------------------------------------------------------------------------------- 1 | """ Parts of the U-Net model """ 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class DoubleConv(nn.Module): 9 | """(convolution => [BN] => ReLU) * 2""" 10 | 11 | def __init__(self, in_channels, out_channels): 12 | super().__init__() 13 | self.double_conv = nn.Sequential( 14 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), 15 | nn.BatchNorm2d(out_channels), 16 | nn.ReLU(inplace=True), 17 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), 18 | nn.BatchNorm2d(out_channels), 19 | nn.ReLU(inplace=True) 20 | ) 21 | 22 | def forward(self, x): 23 | return self.double_conv(x) 24 | 25 | 26 | class Down(nn.Module): 27 | """Downscaling with maxpool then double conv""" 28 | 29 | def __init__(self, in_channels, out_channels): 30 | super().__init__() 31 | self.maxpool_conv = nn.Sequential( 32 | nn.MaxPool2d(2), 33 | DoubleConv(in_channels, out_channels) 34 | ) 35 | 36 | def forward(self, x): 37 | return self.maxpool_conv(x) 38 | 39 | 40 | class Up(nn.Module): 41 | """Upscaling then double conv""" 42 | 43 | def __init__(self, in_channels, out_channels, bilinear=True): 44 | super().__init__() 45 | 46 | # if bilinear, use the normal convolutions to reduce the number of channels 47 | if bilinear: 48 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 49 | else: 50 | self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2) 51 | 52 | self.conv = DoubleConv(in_channels, out_channels) 53 | 54 | def forward(self, x1, x2): 55 | x1 = self.up(x1) 56 | # input is CHW 57 | diffY = torch.tensor([x2.size()[2] - x1.size()[2]]) 58 | diffX = torch.tensor([x2.size()[3] - x1.size()[3]]) 59 | 60 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 61 | diffY // 2, diffY - diffY // 2]) 62 | # if you have padding issues, see 63 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 64 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 65 | x = torch.cat([x2, x1], dim=1) 66 | return self.conv(x) 67 | 68 | 69 | class OutConv(nn.Module): 70 | def __init__(self, in_channels, out_channels): 71 | super(OutConv, self).__init__() 72 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 73 | 74 | def forward(self, x): 75 | return self.conv(x) 76 | -------------------------------------------------------------------------------- /pytorch_ssim/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd 
import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | def gaussian(window_size, sigma): 8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 9 | return gauss/gauss.sum() 10 | 11 | def create_window(window_size, channel): 12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 15 | return window 16 | 17 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 20 | 21 | mu1_sq = mu1.pow(2) 22 | mu2_sq = mu2.pow(2) 23 | mu1_mu2 = mu1*mu2 24 | 25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 28 | 29 | C1 = 0.01**2 30 | C2 = 0.03**2 31 | 32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 33 | 34 | if size_average: 35 | return ssim_map.mean() 36 | else: 37 | return ssim_map.mean(1).mean(1).mean(1) 38 | 39 | class SSIM(torch.nn.Module): 40 | def __init__(self, window_size = 11, size_average = True): 41 | super(SSIM, self).__init__() 42 | self.window_size = window_size 43 | self.size_average = size_average 44 | self.channel = 1 45 | self.window = create_window(window_size, self.channel) 46 | 47 | def forward(self, img1, img2): 48 | (_, channel, _, _) = img1.size() 49 | 50 | if channel == self.channel and self.window.data.type() == img1.data.type(): 51 | window = self.window 52 | else: 53 | window = create_window(self.window_size, channel) 54 | 55 | if img1.is_cuda: 56 | window = window.cuda(img1.get_device()) 57 | window = window.type_as(img1) 58 | 59 | self.window = window 60 | self.channel = channel 61 | 62 | 63 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 64 | 65 | def ssim(img1, img2, window_size = 11, size_average = True): 66 | (_, channel, _, _) = img1.size() 67 | window = create_window(window_size, channel) 68 | 69 | if img1.is_cuda: 70 | window = window.cuda(img1.get_device()) 71 | window = window.type_as(img1) 72 | 73 | return _ssim(img1, img2, window, window_size, channel, size_average) 74 | -------------------------------------------------------------------------------- /models/implicitnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class LinearBottleNeck(nn.Module): 11 | 12 | def __init__(self, in_channels, out_channels, stride, t=6, class_num=1): 13 | super().__init__() 14 | 15 | self.residual = nn.Sequential( 16 | nn.Conv2d(in_channels, in_channels * t, 1), 17 | nn.BatchNorm2d(in_channels * t), 18 | nn.ReLU6(inplace=True), 19 | 20 | nn.Conv2d(in_channels * t, in_channels * t, 3, stride=stride, padding=1, groups=in_channels * t), 21 | nn.BatchNorm2d(in_channels * t), 22 | nn.ReLU6(inplace=True), 23 | 24 | nn.Conv2d(in_channels * t, out_channels, 1), 25 | nn.BatchNorm2d(out_channels) 26 | ) 27 | 28 | self.stride = stride 29 | 
self.in_channels = in_channels 30 | self.out_channels = out_channels 31 | 32 | def forward(self, x): 33 | residual = self.residual(x) 34 | 35 | if self.stride == 1 and self.in_channels == self.out_channels: 36 | residual += x 37 | 38 | return residual 39 | 40 | 41 | 42 | 43 | class ImplicitNet(nn.Module): 44 | 45 | def __init__(self, class_num=1): 46 | super().__init__() 47 | 48 | self.pre = nn.Sequential( 49 | nn.Conv2d(5, 32, 1, padding=1), 50 | nn.BatchNorm2d(32), 51 | nn.ReLU6(inplace=True) 52 | ) 53 | 54 | self.stage1 = LinearBottleNeck(32, 16, 1, 1) 55 | self.stage2 = self._make_stage(2, 16, 24, 2, 6) 56 | self.stage3 = self._make_stage(3, 24, 32, 2, 6) 57 | self.stage4 = self._make_stage(4, 32, 64, 2, 6) 58 | self.stage5 = self._make_stage(3, 64, 96, 1, 6) 59 | self.stage6 = self._make_stage(3, 96, 160, 1, 6) 60 | self.stage7 = LinearBottleNeck(160, 320, 1, 6) 61 | 62 | self.conv1 = nn.Sequential( 63 | nn.Conv2d(320, 1280, 1), 64 | nn.BatchNorm2d(1280), 65 | nn.ReLU6(inplace=True) 66 | ) 67 | 68 | self.conv2 = nn.Conv2d(1280, class_num, 1) 69 | 70 | self.sigmoid = nn.Sigmoid() 71 | 72 | def forward(self, seg, label, natural): 73 | label = label.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand(seg.size()) 74 | 75 | x = torch.cat((label,natural,seg),1) # concated input 76 | x = self.pre(x) 77 | x = self.stage1(x) 78 | x = self.stage2(x) 79 | x = self.stage3(x) 80 | x = self.stage4(x) 81 | x = self.stage5(x) 82 | x = self.stage6(x) 83 | x = self.stage7(x) 84 | x = self.conv1(x) 85 | #x = F.adaptive_avg_pool2d(x, 1) 86 | x = self.conv2(x) # (b,h/s,w/s,1) 87 | x = self.sigmoid(x) 88 | return x 89 | 90 | def _make_stage(self, repeat, in_channels, out_channels, stride, t): 91 | layers = [] 92 | layers.append(LinearBottleNeck(in_channels, out_channels, stride, t)) 93 | 94 | while repeat - 1: 95 | layers.append(LinearBottleNeck(out_channels, out_channels, 1, t)) 96 | repeat -= 1 97 | 98 | return nn.Sequential(*layers) 99 | 100 | 101 | 102 | 103 | def implicitnet(): 104 | return ImplicitNet() -------------------------------------------------------------------------------- /models/discriminator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.parallel 6 | import torch.backends.cudnn as cudnn 7 | import torch.optim as optim 8 | import torch.utils.data 9 | import torchvision.datasets as dset 10 | import torchvision.transforms as transforms 11 | import torchvision.utils as vutils 12 | import numpy as np 13 | 14 | # class Discriminator(nn.Module): 15 | # def __init__(self, ngpu, nc = 3, ndf = 64): 16 | # super(Discriminator, self).__init__() 17 | # self.ngpu = ngpu 18 | # self.main = nn.Sequential( 19 | # # input is (nc) x 64 x 64 20 | # nn.Conv2d(nc, ndf, 4, 4, 1, bias=False), 21 | # nn.LeakyReLU(0.2, inplace=True), 22 | # # state size. (ndf) x 32 x 32 23 | # nn.Conv2d(ndf, ndf * 2, 4, 4, 1, bias=False), 24 | # nn.BatchNorm2d(ndf * 2), 25 | # nn.LeakyReLU(0.2, inplace=True), 26 | # # state size. (ndf*2) x 16 x 16 27 | # nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), 28 | # nn.BatchNorm2d(ndf * 4), 29 | # nn.LeakyReLU(0.2, inplace=True), 30 | # # state size. (ndf*4) x 8 x 8 31 | # nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), 32 | # nn.BatchNorm2d(ndf * 8), 33 | # nn.LeakyReLU(0.2, inplace=True), 34 | # # state size. 
(ndf*8) x 4 x 4 35 | # nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False), 36 | # nn.Sigmoid() 37 | # ) 38 | 39 | # def forward(self, input): 40 | # return self.main(input) 41 | 42 | 43 | 44 | class Discriminator(torch.nn.Module): 45 | def __init__(self, channels): 46 | super().__init__() 47 | # Filters [256, 512, 1024] 48 | # Input_dim = channels (Cx64x64) 49 | # Output_dim = 1 50 | self.main_module = nn.Sequential( 51 | # Omitting batch normalization in critic because our new penalized training objective (WGAN with gradient penalty) is no longer valid 52 | # in this setting, since we penalize the norm of the critic's gradient with respect to each input independently and not the enitre batch. 53 | # There is not good & fast implementation of layer normalization --> using per instance normalization nn.InstanceNorm2d() 54 | # Image (Cx32x32) 55 | nn.Conv2d(in_channels=channels, out_channels=256, kernel_size=4, stride=2, padding=1), 56 | nn.InstanceNorm2d(256, affine=True), 57 | nn.LeakyReLU(0.2, inplace=True), 58 | 59 | # State (256x16x16) 60 | nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=2, padding=1), 61 | nn.InstanceNorm2d(512, affine=True), 62 | nn.LeakyReLU(0.2, inplace=True), 63 | 64 | # State (512x8x8) 65 | nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=4, stride=2, padding=1), 66 | nn.InstanceNorm2d(1024, affine=True), 67 | nn.LeakyReLU(0.2, inplace=True)) 68 | # output of main module --> State (1024x4x4) 69 | 70 | self.output = nn.Sequential( 71 | # The output of D is no longer a probability, we do not apply sigmoid at the output of D. 72 | nn.Conv2d(in_channels=1024, out_channels=1, kernel_size=4, stride=1, padding=0)) 73 | 74 | 75 | def forward(self, x): 76 | x = self.main_module(x) 77 | return self.output(x) 78 | 79 | def feature_extraction(self, x): 80 | # Use discriminator for feature extraction then flatten to vector of 16384 81 | x = self.main_module(x) 82 | return x.view(-1, 1024*4*4) 83 | 84 | 85 | -------------------------------------------------------------------------------- /cfg.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parse_args(): 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument('-net', type=str, required=True, help='net type') 6 | parser.add_argument('-baseline', type=str, default='unet', help='baseline net type') 7 | parser.add_argument('-seg_net', type=str, default='transunet', help='net type') 8 | parser.add_argument('-mod', type=str, required=True, help='mod type:seg,cls,val_ad') 9 | parser.add_argument('-exp_name', type=str, required=True, help='net type') 10 | parser.add_argument('-type', type=str, default='map', help='condition type:ave,rand,rand_map') 11 | parser.add_argument('-vis', type=int, default=None, help='visualization') 12 | parser.add_argument('-reverse', type=bool, default=False, help='adversary reverse') 13 | parser.add_argument('-pretrain', type=bool, default=False, help='adversary reverse') 14 | parser.add_argument('-val_freq',type=int,default=100,help='interval between each validation') 15 | parser.add_argument('-gpu', type=bool, default=True, help='use gpu or not') 16 | parser.add_argument('-gpu_device', type=int, default=0, help='use which gpu') 17 | parser.add_argument('-sim_gpu', type=int, default=0, help='split sim to this gpu') 18 | parser.add_argument('-epoch_ini', type=int, default=1, help='start epoch') 19 | parser.add_argument('-image_size', type=int, default=256, help='image_size') 20 | parser.add_argument('-out_size', 
type=int, default=256, help='output_size') 21 | parser.add_argument('-patch_size', type=int, default=2, help='patch_size') 22 | parser.add_argument('-dim', type=int, default=512, help='dim_size') 23 | parser.add_argument('-depth', type=int, default=1, help='depth') 24 | parser.add_argument('-heads', type=int, default=16, help='heads number') 25 | parser.add_argument('-mlp_dim', type=int, default=1024, help='mlp_dim') 26 | parser.add_argument('-w', type=int, default=4, help='number of workers for dataloader') 27 | parser.add_argument('-b', type=int, default=8, help='batch size for dataloader') 28 | parser.add_argument('-s', type=bool, default=True, help='whether shuffle the dataset') 29 | parser.add_argument('-warm', type=int, default=1, help='warm up training phase') 30 | parser.add_argument('-lr', type=float, default=1e-4, help='initial learning rate') 31 | parser.add_argument('-uinch', type=int, default=1, help='input channel of unet') 32 | parser.add_argument('-imp_lr', type=float, default=3e-4, help='implicit learning rate') 33 | parser.add_argument('-weights', type=str, default = 0, help='the weights file you want to test') 34 | parser.add_argument('-base_weights', type=str, default = 0, help='the weights baseline') 35 | parser.add_argument('-sim_weights', type=str, default = 0, help='the weights sim') 36 | parser.add_argument('-distributed', default='none' ,type=str,help='multi GPU ids to use') 37 | parser.add_argument('-dataset', default='isic' ,type=str,help='dataset name') 38 | parser.add_argument('-sam_ckpt', default=None , help='sam checkpoint address') 39 | parser.add_argument('-thd', type=bool, default=False , help='3d or not') 40 | parser.add_argument('-chunk', type=int, default=96 , help='crop volume depth') 41 | parser.add_argument('-num_sample', type=int, default=4 , help='sample pos and neg') 42 | parser.add_argument('-roi_size', type=int, default=96 , help='resolution of roi') 43 | parser.add_argument('-evl_chunk', type=int, default=None , help='evaluation chunk') 44 | parser.add_argument( 45 | '-data_path', 46 | type=str, 47 | default='../data', 48 | help='The path of segmentation data') 49 | # '../dataset/RIGA/DiscRegion' 50 | # '../dataset/ISIC' 51 | opt = parser.parse_args() 52 | 53 | return opt 54 | -------------------------------------------------------------------------------- /models/sam/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to the longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 
29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape(old_h, old_w, self.target_length) 40 | new_coords = np.empty_like(coords) 41 | new_coords[..., 0] = coords[..., 0] * (new_w / old_w) 42 | new_coords[..., 1] = coords[..., 1] * (new_h / old_h) 43 | return new_coords 44 | 45 | 46 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 47 | """ 48 | Expects a numpy array shape Bx4. Requires the original image size 49 | in (H, W) format. 50 | """ 51 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 52 | return boxes.reshape(-1, 4) 53 | 54 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 55 | """ 56 | Expects batched images with shape BxCxHxW and float format. This 57 | transformation may not exactly match apply_image. apply_image is 58 | the transformation expected by the model. 59 | """ 60 | # Expects an image in BCHW format. May not exactly match apply_image. 61 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) 62 | return F.interpolate( 63 | image, target_size, mode="bilinear", align_corners=False, antialias=True 64 | ) 65 | 66 | def apply_coords_torch( 67 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 68 | ) -> torch.Tensor: 69 | """ 70 | Expects a torch tensor with length 2 in the last dimension. Requires the 71 | original image size in (H, W) format. 72 | """ 73 | old_h, old_w = original_size 74 | new_h, new_w = self.get_preprocess_shape( 75 | original_size[0], original_size[1], self.target_length 76 | ) 77 | coords = deepcopy(coords).to(torch.float) 78 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 79 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 80 | return coords 81 | 82 | def apply_boxes_torch( 83 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 84 | ) -> torch.Tensor: 85 | """ 86 | Expects a torch tensor with shape Bx4. Requires the original image 87 | size in (H, W) format. 88 | """ 89 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 90 | return boxes.reshape(-1, 4) 91 | 92 | @staticmethod 93 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 94 | """ 95 | Compute the output size given input size and target long side length. 
96 | """ 97 | scale = long_side_length * 1.0 / max(oldh, oldw) 98 | newh, neww = oldh * scale, oldw * scale 99 | neww = int(neww + 0.5) 100 | newh = int(newh + 0.5) 101 | return (newh, neww) 102 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | # train.py 2 | #!/usr/bin/env python3 3 | 4 | """ valuate network using pytorch 5 | Junde Wu 6 | """ 7 | 8 | import os 9 | import sys 10 | import argparse 11 | from datetime import datetime 12 | from collections import OrderedDict 13 | import numpy as np 14 | import torch 15 | import torch.nn as nn 16 | import torch.optim as optim 17 | from sklearn.metrics import roc_auc_score, accuracy_score,confusion_matrix 18 | import torchvision 19 | import torchvision.transforms as transforms 20 | from skimage import io 21 | from torch.utils.data import DataLoader 22 | #from dataset import * 23 | from torch.autograd import Variable 24 | from PIL import Image 25 | from tensorboardX import SummaryWriter 26 | #from models.discriminatorlayer import discriminator 27 | from dataset import * 28 | from conf import settings 29 | import time 30 | import cfg 31 | from tqdm import tqdm 32 | from torch.utils.data import DataLoader, random_split 33 | from utils import * 34 | import function 35 | 36 | 37 | args = cfg.parse_args() 38 | if args.dataset == 'refuge' or args.dataset == 'refuge2': 39 | args.data_path = '../dataset' 40 | 41 | GPUdevice = torch.device('cuda', args.gpu_device) 42 | 43 | net = get_network(args, args.net, use_gpu=args.gpu, gpu_device=GPUdevice, distribution = args.distributed) 44 | 45 | '''load pretrained model''' 46 | assert args.weights != 0 47 | print(f'=> resuming from {args.weights}') 48 | assert os.path.exists(args.weights) 49 | checkpoint_file = os.path.join(args.weights) 50 | assert os.path.exists(checkpoint_file) 51 | loc = 'cuda:{}'.format(args.gpu_device) 52 | checkpoint = torch.load(checkpoint_file, map_location=loc) 53 | start_epoch = checkpoint['epoch'] 54 | best_tol = checkpoint['best_tol'] 55 | 56 | state_dict = checkpoint['state_dict'] 57 | if args.distributed != 'none': 58 | from collections import OrderedDict 59 | new_state_dict = OrderedDict() 60 | for k, v in state_dict.items(): 61 | # name = k[7:] # remove `module.` 62 | name = 'module.' 
+ k 63 | new_state_dict[name] = v 64 | # load params 65 | else: 66 | new_state_dict = state_dict 67 | 68 | net.load_state_dict(new_state_dict) 69 | 70 | # args.path_helper = checkpoint['path_helper'] 71 | # logger = create_logger(args.path_helper['log_path']) 72 | # print(f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})') 73 | 74 | # args.path_helper = set_log_dir('logs', args.exp_name) 75 | # logger = create_logger(args.path_helper['log_path']) 76 | # logger.info(args) 77 | 78 | args.path_helper = set_log_dir('logs', args.exp_name) 79 | logger = create_logger(args.path_helper['log_path']) 80 | logger.info(args) 81 | 82 | '''segmentation data''' 83 | transform_train = transforms.Compose([ 84 | transforms.Resize((args.image_size,args.image_size)), 85 | transforms.ToTensor(), 86 | ]) 87 | 88 | transform_train_seg = transforms.Compose([ 89 | transforms.ToTensor(), 90 | transforms.Resize((args.image_size,args.image_size)), 91 | ]) 92 | 93 | transform_test = transforms.Compose([ 94 | transforms.Resize((args.image_size, args.image_size)), 95 | transforms.ToTensor(), 96 | ]) 97 | 98 | transform_test_seg = transforms.Compose([ 99 | transforms.ToTensor(), 100 | transforms.Resize((args.image_size, args.image_size)), 101 | 102 | ]) 103 | '''data end''' 104 | if args.dataset == 'isic': 105 | print("not implemented") 106 | exit() 107 | '''isic data''' 108 | isic_train_dataset = ISIC2016(args, args.data_path, transform = transform_train, transform_msk= transform_train_seg, mode = 'Training') 109 | isic_test_dataset = ISIC2016(args, args.data_path, transform = transform_test, transform_msk= transform_test_seg, mode = 'Test') 110 | 111 | nice_train_loader = DataLoader(isic_train_dataset, batch_size=args.b, shuffle=True, num_workers=8, pin_memory=True) 112 | nice_test_loader = DataLoader(isic_test_dataset, batch_size=args.b, shuffle=False, num_workers=8, pin_memory=True) 113 | '''end''' 114 | 115 | elif args.dataset == 'decathlon': 116 | print("not implemented") 117 | exit() 118 | nice_train_loader, nice_test_loader, transform_train, transform_val, train_list, val_list =get_decath_loader(args) 119 | 120 | 121 | elif args.dataset == 'mydataset': 122 | '''mydataset''' 123 | mydata_predict_dataset = MyDataset(args, args.data_path, transform = transform_test, transform_msk= transform_test_seg, mode = 'predict') 124 | 125 | nice_predict_loader = DataLoader(mydata_predict_dataset, batch_size=args.b, shuffle=False, num_workers=8, pin_memory=True) 126 | '''end''' 127 | 128 | '''begain valuation''' 129 | best_acc = 0.0 130 | best_tol = 1e4 131 | 132 | if args.mod == 'sam_adpt': 133 | net.eval() 134 | function.predict_sam(args, nice_predict_loader, start_epoch, net) 135 | 136 | 137 | -------------------------------------------------------------------------------- /val.py: -------------------------------------------------------------------------------- 1 | # train.py 2 | #!/usr/bin/env python3 3 | 4 | """ valuate network using pytorch 5 | Junde Wu 6 | """ 7 | 8 | import os 9 | import sys 10 | import argparse 11 | from datetime import datetime 12 | from collections import OrderedDict 13 | import numpy as np 14 | import torch 15 | import torch.nn as nn 16 | import torch.optim as optim 17 | from sklearn.metrics import roc_auc_score, accuracy_score,confusion_matrix 18 | import torchvision 19 | import torchvision.transforms as transforms 20 | from skimage import io 21 | from torch.utils.data import DataLoader 22 | #from dataset import * 23 | from torch.autograd import Variable 24 | from PIL import Image 25 | 
from tensorboardX import SummaryWriter 26 | #from models.discriminatorlayer import discriminator 27 | from dataset import * 28 | from conf import settings 29 | import time 30 | import cfg 31 | from tqdm import tqdm 32 | from torch.utils.data import DataLoader, random_split 33 | from utils import * 34 | import function 35 | 36 | 37 | args = cfg.parse_args() 38 | if args.dataset == 'refuge' or args.dataset == 'refuge2': 39 | args.data_path = '../dataset' 40 | 41 | GPUdevice = torch.device('cuda', args.gpu_device) 42 | 43 | net = get_network(args, args.net, use_gpu=args.gpu, gpu_device=GPUdevice, distribution = args.distributed) 44 | 45 | '''load pretrained model''' 46 | assert args.weights != 0 47 | print(f'=> resuming from {args.weights}') 48 | assert os.path.exists(args.weights) 49 | checkpoint_file = os.path.join(args.weights) 50 | assert os.path.exists(checkpoint_file) 51 | loc = 'cuda:{}'.format(args.gpu_device) 52 | checkpoint = torch.load(checkpoint_file, map_location=loc) 53 | start_epoch = checkpoint['epoch'] 54 | best_tol = checkpoint['best_tol'] 55 | 56 | state_dict = checkpoint['state_dict'] 57 | if args.distributed != 'none': 58 | from collections import OrderedDict 59 | new_state_dict = OrderedDict() 60 | for k, v in state_dict.items(): 61 | # name = k[7:] # remove `module.` 62 | name = 'module.' + k 63 | new_state_dict[name] = v 64 | # load params 65 | else: 66 | new_state_dict = state_dict 67 | 68 | net.load_state_dict(new_state_dict) 69 | 70 | # args.path_helper = checkpoint['path_helper'] 71 | # logger = create_logger(args.path_helper['log_path']) 72 | # print(f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})') 73 | 74 | # args.path_helper = set_log_dir('logs', args.exp_name) 75 | # logger = create_logger(args.path_helper['log_path']) 76 | # logger.info(args) 77 | 78 | args.path_helper = set_log_dir('logs', args.exp_name) 79 | logger = create_logger(args.path_helper['log_path']) 80 | logger.info(args) 81 | 82 | '''segmentation data''' 83 | transform_train = transforms.Compose([ 84 | transforms.Resize((args.image_size,args.image_size)), 85 | transforms.ToTensor(), 86 | ]) 87 | 88 | transform_train_seg = transforms.Compose([ 89 | transforms.ToTensor(), 90 | transforms.Resize((args.out_size,args.out_size)), 91 | ]) 92 | 93 | transform_test = transforms.Compose([ 94 | transforms.Resize((args.image_size, args.image_size)), 95 | transforms.ToTensor(), 96 | ]) 97 | 98 | transform_test_seg = transforms.Compose([ 99 | transforms.ToTensor(), 100 | transforms.Resize((args.out_size, args.out_size)), 101 | 102 | ]) 103 | '''data end''' 104 | if args.dataset == 'isic': 105 | '''isic data''' 106 | isic_train_dataset = ISIC2016(args, args.data_path, transform = transform_train, transform_msk= transform_train_seg, mode = 'Training') 107 | isic_test_dataset = ISIC2016(args, args.data_path, transform = transform_test, transform_msk= transform_test_seg, mode = 'Test') 108 | 109 | nice_train_loader = DataLoader(isic_train_dataset, batch_size=args.b, shuffle=True, num_workers=8, pin_memory=True) 110 | nice_test_loader = DataLoader(isic_test_dataset, batch_size=args.b, shuffle=False, num_workers=8, pin_memory=True) 111 | '''end''' 112 | 113 | elif args.dataset == 'decathlon': 114 | nice_train_loader, nice_test_loader, transform_train, transform_val, train_list, val_list =get_decath_loader(args) 115 | 116 | 117 | elif args.dataset == 'mydataset': 118 | '''mydataset''' 119 | mydata_train_dataset = MyDataset(args, args.data_path, transform = transform_train, transform_msk= 
transform_train_seg, mode = 'train') 120 | mydata_test_dataset = MyDataset(args, args.data_path, transform = transform_test, transform_msk= transform_test_seg, mode = 'test') 121 | 122 | nice_train_loader = DataLoader(mydata_train_dataset, batch_size=args.b, shuffle=True, num_workers=8, pin_memory=True) 123 | nice_test_loader = DataLoader(mydata_test_dataset, batch_size=args.b, shuffle=False, num_workers=8, pin_memory=True) 124 | '''end''' 125 | 126 | '''begain valuation''' 127 | best_acc = 0.0 128 | best_tol = 1e4 129 | 130 | if args.mod == 'sam_adpt': 131 | net.eval() 132 | tol, (eiou, edice) = function.validation_sam(args, nice_test_loader, start_epoch, net) 133 | logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {start_epoch}.') 134 | -------------------------------------------------------------------------------- /models/sam/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | from functools import partial 7 | from pathlib import Path 8 | import urllib.request 9 | import torch 10 | 11 | from .modeling import ( 12 | ImageEncoderViT, 13 | MaskDecoder, 14 | PromptEncoder, 15 | Sam, 16 | TwoWayTransformer, 17 | ) 18 | 19 | 20 | def build_sam_vit_h(args = None, checkpoint=None): 21 | return _build_sam( 22 | args, 23 | encoder_embed_dim=1280, 24 | encoder_depth=32, 25 | encoder_num_heads=16, 26 | encoder_global_attn_indexes=[7, 15, 23, 31], 27 | checkpoint=checkpoint, 28 | ) 29 | 30 | 31 | build_sam = build_sam_vit_h 32 | 33 | 34 | def build_sam_vit_l(args, checkpoint=None): 35 | return _build_sam( 36 | args, 37 | encoder_embed_dim=1024, 38 | encoder_depth=24, 39 | encoder_num_heads=16, 40 | encoder_global_attn_indexes=[5, 11, 17, 23], 41 | checkpoint=checkpoint, 42 | ) 43 | 44 | 45 | def build_sam_vit_b(args, checkpoint=None): 46 | return _build_sam( 47 | args, 48 | encoder_embed_dim=768, 49 | encoder_depth=12, 50 | encoder_num_heads=12, 51 | encoder_global_attn_indexes=[2, 5, 8, 11], 52 | checkpoint=checkpoint, 53 | ) 54 | 55 | 56 | sam_model_registry = { 57 | "default": build_sam_vit_h, 58 | "vit_h": build_sam_vit_h, 59 | "vit_l": build_sam_vit_l, 60 | "vit_b": build_sam_vit_b, 61 | } 62 | 63 | 64 | def _build_sam( 65 | args, 66 | encoder_embed_dim, 67 | encoder_depth, 68 | encoder_num_heads, 69 | encoder_global_attn_indexes, 70 | checkpoint=None, 71 | ): 72 | prompt_embed_dim = 256 73 | image_size = 1024 74 | vit_patch_size = 16 75 | image_embedding_size = image_size // vit_patch_size 76 | sam = Sam( 77 | args, 78 | image_encoder=ImageEncoderViT( 79 | args = args, 80 | depth=encoder_depth, 81 | embed_dim=encoder_embed_dim, 82 | img_size=image_size, 83 | mlp_ratio=4, 84 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 85 | num_heads=encoder_num_heads, 86 | patch_size=vit_patch_size, 87 | qkv_bias=True, 88 | use_rel_pos=True, 89 | global_attn_indexes=encoder_global_attn_indexes, 90 | window_size=14, 91 | out_chans=prompt_embed_dim, 92 | ), 93 | prompt_encoder=PromptEncoder( 94 | embed_dim=prompt_embed_dim, 95 | image_embedding_size=(image_embedding_size, image_embedding_size), 96 | input_image_size=(image_size, image_size), 97 | mask_in_chans=16, 98 | ), 99 | mask_decoder=MaskDecoder( 100 | num_multimask_outputs=3, 101 | transformer=TwoWayTransformer( 102 | depth=2, 103 | embedding_dim=prompt_embed_dim, 104 
| mlp_dim=2048, 105 | num_heads=8, 106 | ), 107 | transformer_dim=prompt_embed_dim, 108 | iou_head_depth=3, 109 | iou_head_hidden_dim=256, 110 | ), 111 | pixel_mean=[123.675, 116.28, 103.53], 112 | pixel_std=[58.395, 57.12, 57.375], 113 | ) 114 | sam.eval() 115 | checkpoint = Path(checkpoint) 116 | if checkpoint.name == "sam_vit_b_01ec64.pth" and not checkpoint.exists(): 117 | cmd = input("Download sam_vit_b_01ec64.pth from facebook AI? [y]/n: ") 118 | if len(cmd) == 0 or cmd.lower() == 'y': 119 | checkpoint.parent.mkdir(parents=True, exist_ok=True) 120 | print("Downloading SAM ViT-B checkpoint...") 121 | urllib.request.urlretrieve( 122 | "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", 123 | checkpoint, 124 | ) 125 | print(checkpoint.name, " is downloaded!") 126 | elif checkpoint.name == "sam_vit_h_4b8939.pth" and not checkpoint.exists(): 127 | cmd = input("Download sam_vit_h_4b8939.pth from facebook AI? [y]/n: ") 128 | if len(cmd) == 0 or cmd.lower() == 'y': 129 | checkpoint.parent.mkdir(parents=True, exist_ok=True) 130 | print("Downloading SAM ViT-H checkpoint...") 131 | urllib.request.urlretrieve( 132 | "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", 133 | checkpoint, 134 | ) 135 | print(checkpoint.name, " is downloaded!") 136 | elif checkpoint.name == "sam_vit_l_0b3195.pth" and not checkpoint.exists(): 137 | cmd = input("Download sam_vit_l_0b3195.pth from facebook AI? [y]/n: ") 138 | if len(cmd) == 0 or cmd.lower() == 'y': 139 | checkpoint.parent.mkdir(parents=True, exist_ok=True) 140 | print("Downloading SAM ViT-L checkpoint...") 141 | urllib.request.urlretrieve( 142 | "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", 143 | checkpoint, 144 | ) 145 | print(checkpoint.name, " is downloaded!") 146 | 147 | 148 | if checkpoint is not None: 149 | with open(checkpoint, "rb") as f: 150 | state_dict = torch.load(f) 151 | sam.load_state_dict(state_dict, strict = False) 152 | return sam 153 | -------------------------------------------------------------------------------- /models/tag/tag_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops import rearrange 5 | from timm.models.layers import trunc_normal_ 6 | 7 | 8 | Norm = nn.LayerNorm 9 | 10 | 11 | def apply_pos(tensor, pos, num_heads): 12 | if pos is None: 13 | return tensor 14 | elif len(tensor.shape) != len(pos.shape): 15 | tensor = rearrange(tensor, "b n (g c) -> b n g c", g=num_heads) 16 | tensor = tensor + pos 17 | tensor = rearrange(tensor, "b n g c -> b n (g c)") 18 | else: 19 | tensor = tensor + pos 20 | 21 | return tensor 22 | 23 | 24 | class FullRelPos(nn.Module): 25 | def __init__(self, h, w, dim, drop_ratio=0.): 26 | super(FullRelPos, self).__init__() 27 | self.h, self.w = h, w 28 | self.rel_emb_h = nn.Parameter(torch.Tensor(2 * h - 1, dim // 2)) # [-(q-1), q-1] 29 | self.rel_emb_w = nn.Parameter(torch.Tensor(2 * w - 1, dim // 2)) # [-(q-1), q-1] 30 | 31 | # get relative coordinates of the q-k index table 32 | coords_h = torch.arange(h) 33 | coords_w = torch.arange(w) 34 | self.rel_idx_h = coords_h[None, :] - coords_h[:, None] 35 | self.rel_idx_w = coords_w[None, :] - coords_w[:, None] 36 | self.rel_idx_h += h - 1 37 | self.rel_idx_w += w - 1 38 | 39 | nn.init.normal_(self.rel_emb_w, std=dim ** -0.5) 40 | nn.init.normal_(self.rel_emb_h, std=dim ** -0.5) 41 | trunc_normal_(self.rel_emb_w, std=.02) 42 | trunc_normal_(self.rel_emb_h, 
std=.02) 43 | self.drop_ratio = drop_ratio 44 | 45 | def forward(self, q, attn): 46 | abs_pos_h = self.rel_emb_h[self.rel_idx_h.view(-1)] 47 | abs_pos_w = self.rel_emb_w[self.rel_idx_w.view(-1)] 48 | abs_pos_h = rearrange(abs_pos_h, "(q k) c -> q k c", q=self.h) # [qh, kh, c] 49 | abs_pos_w = rearrange(abs_pos_w, "(q k) c -> q k c", q=self.w) # [qw, kw, c] 50 | 51 | q = rearrange(q, "b (qh qw) g (n c) -> b qh qw g n c", qh=self.h, qw=self.w, n=2) 52 | logits_h = torch.einsum("b h w g c, h k c -> b h w g k", q[..., 0, :], abs_pos_h) 53 | logits_w = torch.einsum("b h w g c, w k c -> b h w g k", q[..., 1, :], abs_pos_w) 54 | logits_h = rearrange(logits_h, "b h w g k -> b (h w) g k 1") 55 | logits_w = rearrange(logits_w, "b h w g k -> b (h w) g 1 k") 56 | 57 | attn = rearrange(attn, "b q g (kh kw) -> b q g kh kw", kh=self.h, kw=self.w) 58 | attn += logits_h 59 | attn += logits_w 60 | return rearrange(attn, "b q g h w -> b q g (h w)") 61 | 62 | 63 | class SimpleReasoning(nn.Module): 64 | def __init__(self, np, dim): 65 | super(SimpleReasoning, self).__init__() 66 | self.norm = Norm(dim) 67 | self.linear = nn.Conv1d(np, np, kernel_size=1, bias=False) 68 | 69 | def forward(self, x): 70 | tokens = self.norm(x) 71 | tokens = self.linear(tokens) 72 | return x + tokens 73 | 74 | 75 | class AnyAttention(nn.Module): 76 | def __init__(self, dim, num_heads, qkv_bias=False): 77 | super(AnyAttention, self).__init__() 78 | self.norm_q, self.norm_k, self.norm_v = Norm(dim), Norm(dim), Norm(dim) 79 | self.to_q = nn.Linear(dim, dim, bias=qkv_bias) 80 | self.to_k = nn.Linear(dim, dim, bias=qkv_bias) 81 | self.to_v = nn.Linear(dim, dim, bias=qkv_bias) 82 | 83 | self.scale = (dim / num_heads) ** (-0.5) 84 | self.num_heads = num_heads 85 | self.proj = nn.Linear(dim, dim) 86 | 87 | def get_qkv(self, q, k, v, qpos, kpos): 88 | q = apply_pos(q, qpos, self.num_heads) 89 | k = apply_pos(k, kpos, self.num_heads) 90 | v = apply_pos(v, None, 0) 91 | q, k, v = self.norm_q(q), self.norm_k(k), self.norm_v(v) 92 | q, k, v = self.to_q(q), self.to_k(k), self.to_v(v) 93 | return q, k, v 94 | 95 | def forward(self, q=None, k=None, v=None, qpos=None, kpos=None, mask=None, rel_pos=None): 96 | q, k, v = self.get_qkv(q, k, v, qpos, kpos) 97 | 98 | # reshape 99 | q = rearrange(q, "b n (g c) -> b n g c", g=self.num_heads) 100 | k = rearrange(k, "b n (g c) -> b n g c", g=self.num_heads) 101 | v = rearrange(v, "b n (g c) -> b n g c", g=self.num_heads) 102 | 103 | # attn matrix calculation 104 | attn = torch.einsum("b q g c, b k g c -> b q g k", q, k) 105 | if rel_pos is not None: 106 | attn = rel_pos(q, attn) 107 | attn *= self.scale 108 | if mask is not None: 109 | attn = attn.masked_fill(mask.bool(), value=float('-inf')) 110 | attn = F.softmax(attn, dim=-1) 111 | if mask is not None: 112 | attn = attn.masked_fill(mask.bool(), value=0) 113 | out = torch.einsum("b q g k, b k g c -> b q g c", attn, v.float()) 114 | out = rearrange(out, "b q g c -> b q (g c)") 115 | out = self.proj(out) 116 | return out 117 | 118 | 119 | class Mlp(nn.Module): 120 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, 121 | norm_layer=nn.LayerNorm, drop=0.): 122 | super().__init__() 123 | out_features = out_features or in_features 124 | hidden_features = int(hidden_features) or in_features 125 | self.norm = norm_layer(in_features) 126 | self.fc1 = nn.Linear(in_features, hidden_features) 127 | self.act = act_layer() 128 | self.fc2 = nn.Linear(hidden_features, out_features) 129 | self.drop = nn.Dropout(drop) 130 | 131 | 
def forward(self, x): 132 | x = self.norm(x) 133 | x = self.fc1(x) 134 | x = self.act(x) 135 | x = self.drop(x) 136 | x = self.fc2(x) 137 | x = self.drop(x) 138 | return x -------------------------------------------------------------------------------- /models/vae.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | # from .types_ import * 7 | 8 | 9 | class VanillaVAE(nn.Module): 10 | def __init__(self,args, 11 | in_channels: int, 12 | latent_dim: int, 13 | hidden_dims = None, 14 | **kwargs) -> None: 15 | super(VanillaVAE, self).__init__() 16 | 17 | self.latent_dim = latent_dim 18 | 19 | modules = [] 20 | if hidden_dims is None: 21 | hidden_dims = [32, 64, 128, 256, 512] 22 | 23 | if latent_dim is None: 24 | latent_dim = 512 25 | 26 | # Build Encoder 27 | for h_dim in hidden_dims: 28 | modules.append( 29 | nn.Sequential( 30 | nn.Conv2d(in_channels, out_channels=h_dim, 31 | kernel_size= 3, stride= 2, padding = 1), 32 | nn.BatchNorm2d(h_dim), 33 | nn.LeakyReLU()) 34 | ) 35 | in_channels = h_dim 36 | 37 | self.encoder = nn.Sequential(*modules) 38 | self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim) 39 | self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim) 40 | 41 | 42 | # Build Decoder 43 | modules = [] 44 | 45 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4) 46 | 47 | hidden_dims.reverse() 48 | 49 | for i in range(len(hidden_dims) - 1): 50 | modules.append( 51 | nn.Sequential( 52 | nn.ConvTranspose2d(hidden_dims[i], 53 | hidden_dims[i + 1], 54 | kernel_size=3, 55 | stride = 2, 56 | padding=1, 57 | output_padding=1), 58 | nn.BatchNorm2d(hidden_dims[i + 1]), 59 | nn.LeakyReLU()) 60 | ) 61 | 62 | 63 | 64 | self.decoder = nn.Sequential(*modules) 65 | 66 | self.final_layer = nn.Sequential( 67 | nn.ConvTranspose2d(hidden_dims[-1], 68 | hidden_dims[-1], 69 | kernel_size=3, 70 | stride=2, 71 | padding=1, 72 | output_padding=1), 73 | nn.BatchNorm2d(hidden_dims[-1]), 74 | nn.LeakyReLU(), 75 | nn.Conv2d(hidden_dims[-1], out_channels= 3, 76 | kernel_size= 3, padding= 1), 77 | nn.Tanh()) 78 | 79 | def encode(self, input): 80 | """ 81 | Encodes the input by passing through the encoder network 82 | and returns the latent codes. 83 | :param input: (Tensor) Input tensor to encoder [N x C x H x W] 84 | :return: (Tensor) List of latent codes 85 | """ 86 | result = self.encoder(input) 87 | result = torch.flatten(result, start_dim=1) 88 | 89 | # Split the result into mu and var components 90 | # of the latent Gaussian distribution 91 | mu = self.fc_mu(result) 92 | # log_var = self.fc_var(result) 93 | 94 | return mu 95 | 96 | def decode(self, z): 97 | """ 98 | Maps the given latent codes 99 | onto the image space. 100 | :param z: (Tensor) [B x D] 101 | :return: (Tensor) [B x C x H x W] 102 | """ 103 | result = self.decoder_input(z) 104 | result = result.view(-1, 512, 2, 2) 105 | result = self.decoder(result) 106 | result = self.final_layer(result) 107 | return result 108 | 109 | # def reparameterize(self, mu, logvar): 110 | # """ 111 | # Reparameterization trick to sample from N(mu, var) from 112 | # N(0,1). 
113 | # :param mu: (Tensor) Mean of the latent Gaussian [B x D] 114 | # :param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D] 115 | # :return: (Tensor) [B x D] 116 | # """ 117 | # std = torch.exp(0.5 * logvar) 118 | # eps = torch.randn_like(std) 119 | # return eps * std + mu 120 | 121 | def forward(self, input, **kwargs): 122 | mu = self.encode(input) 123 | # z = self.reparameterize(mu, log_var) 124 | return self.decode(mu) 125 | 126 | def loss_function(self, 127 | *args, 128 | **kwargs) -> dict: 129 | """ 130 | Computes the VAE loss function. 131 | KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2} 132 | :param args: 133 | :param kwargs: 134 | :return: 135 | """ 136 | recons = args[0] 137 | input = args[1] 138 | # mu = args[2] 139 | # log_var = args[3] 140 | 141 | # kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset 142 | recons_loss =F.mse_loss(recons, input) 143 | 144 | 145 | # kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0) 146 | 147 | loss = recons_loss 148 | return loss 149 | # {'loss': loss, 'Reconstruction_Loss':recons_loss.detach(), 'KLD':recons_loss.detach()} 150 | 151 | 152 | def generate(self, x, **kwargs): 153 | """ 154 | Given an input image x, returns the reconstructed image 155 | :param x: (Tensor) [B x C x H x W] 156 | :return: (Tensor) [B x C x H x W] 157 | """ 158 | 159 | return self.forward(x)[0] -------------------------------------------------------------------------------- /models/senet.py: -------------------------------------------------------------------------------- 1 | """senet in pytorch 2 | 3 | 4 | 5 | [1] Jie Hu, Li Shen, Samuel Albanie, Gang Sun, Enhua Wu 6 | 7 | Squeeze-and-Excitation Networks 8 | https://arxiv.org/abs/1709.01507 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | class BasicResidualSEBlock(nn.Module): 16 | 17 | expansion = 1 18 | 19 | def __init__(self, in_channels, out_channels, stride, r=16): 20 | super().__init__() 21 | 22 | self.residual = nn.Sequential( 23 | nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1), 24 | nn.BatchNorm2d(out_channels), 25 | nn.ReLU(inplace=True), 26 | 27 | nn.Conv2d(out_channels, out_channels * self.expansion, 3, padding=1), 28 | nn.BatchNorm2d(out_channels * self.expansion), 29 | nn.ReLU(inplace=True) 30 | ) 31 | 32 | self.shortcut = nn.Sequential() 33 | if stride != 1 or in_channels != out_channels * self.expansion: 34 | self.shortcut = nn.Sequential( 35 | nn.Conv2d(in_channels, out_channels * self.expansion, 1, stride=stride), 36 | nn.BatchNorm2d(out_channels * self.expansion) 37 | ) 38 | 39 | self.squeeze = nn.AdaptiveAvgPool2d(1) 40 | self.excitation = nn.Sequential( 41 | nn.Linear(out_channels * self.expansion, out_channels * self.expansion // r), 42 | nn.ReLU(inplace=True), 43 | nn.Linear(out_channels * self.expansion // r, out_channels * self.expansion), 44 | nn.Sigmoid() 45 | ) 46 | 47 | def forward(self, x): 48 | shortcut = self.shortcut(x) 49 | residual = self.residual(x) 50 | 51 | squeeze = self.squeeze(residual) 52 | squeeze = squeeze.view(squeeze.size(0), -1) 53 | excitation = self.excitation(squeeze) 54 | excitation = excitation.view(residual.size(0), residual.size(1), 1, 1) 55 | 56 | x = residual * excitation.expand_as(residual) + shortcut 57 | 58 | return F.relu(x) 59 | 60 | class BottleneckResidualSEBlock(nn.Module): 61 | 62 | expansion = 4 63 | 64 | def __init__(self, in_channels, 
out_channels, stride, r=16): 65 | super().__init__() 66 | 67 | self.residual = nn.Sequential( 68 | nn.Conv2d(in_channels, out_channels, 1), 69 | nn.BatchNorm2d(out_channels), 70 | nn.ReLU(inplace=True), 71 | 72 | nn.Conv2d(out_channels, out_channels, 3, stride=stride, padding=1), 73 | nn.BatchNorm2d(out_channels), 74 | nn.ReLU(inplace=True), 75 | 76 | nn.Conv2d(out_channels, out_channels * self.expansion, 1), 77 | nn.BatchNorm2d(out_channels * self.expansion), 78 | nn.ReLU(inplace=True) 79 | ) 80 | 81 | self.squeeze = nn.AdaptiveAvgPool2d(1) 82 | self.excitation = nn.Sequential( 83 | nn.Linear(out_channels * self.expansion, out_channels * self.expansion // r), 84 | nn.ReLU(inplace=True), 85 | nn.Linear(out_channels * self.expansion // r, out_channels * self.expansion), 86 | nn.Sigmoid() 87 | ) 88 | 89 | self.shortcut = nn.Sequential() 90 | if stride != 1 or in_channels != out_channels * self.expansion: 91 | self.shortcut = nn.Sequential( 92 | nn.Conv2d(in_channels, out_channels * self.expansion, 1, stride=stride), 93 | nn.BatchNorm2d(out_channels * self.expansion) 94 | ) 95 | 96 | def forward(self, x): 97 | 98 | shortcut = self.shortcut(x) 99 | 100 | residual = self.residual(x) 101 | squeeze = self.squeeze(residual) 102 | squeeze = squeeze.view(squeeze.size(0), -1) 103 | excitation = self.excitation(squeeze) 104 | excitation = excitation.view(residual.size(0), residual.size(1), 1, 1) 105 | 106 | x = residual * excitation.expand_as(residual) + shortcut 107 | 108 | return F.relu(x) 109 | 110 | class SEResNet(nn.Module): 111 | 112 | def __init__(self, block, block_num, class_num=1): 113 | super().__init__() 114 | 115 | self.in_channels = 64 116 | 117 | self.pre = nn.Sequential( 118 | nn.Conv2d(3, 64, 3, padding=1), 119 | nn.BatchNorm2d(64), 120 | nn.ReLU(inplace=True) 121 | ) 122 | 123 | self.stage1 = self._make_stage(block, block_num[0], 64, 1) 124 | self.stage2 = self._make_stage(block, block_num[1], 128, 2) 125 | self.stage3 = self._make_stage(block, block_num[2], 256, 2) 126 | self.stage4 = self._make_stage(block, block_num[3], 516, 2) 127 | 128 | self.linear = nn.Linear(self.in_channels, class_num) 129 | 130 | def forward(self, x): 131 | x = self.pre(x) 132 | 133 | x = self.stage1(x) 134 | x = self.stage2(x) 135 | x = self.stage3(x) 136 | x = self.stage4(x) 137 | 138 | x = F.adaptive_avg_pool2d(x, 1) 139 | x = x.view(x.size(0), -1) 140 | 141 | x = self.linear(x) 142 | 143 | return x 144 | 145 | 146 | def _make_stage(self, block, num, out_channels, stride): 147 | 148 | layers = [] 149 | layers.append(block(self.in_channels, out_channels, stride)) 150 | self.in_channels = out_channels * block.expansion 151 | 152 | while num - 1: 153 | layers.append(block(self.in_channels, out_channels, 1)) 154 | num -= 1 155 | 156 | return nn.Sequential(*layers) 157 | 158 | def seresnet18(): 159 | return SEResNet(BasicResidualSEBlock, [2, 2, 2, 2]) 160 | 161 | def seresnet34(): 162 | return SEResNet(BasicResidualSEBlock, [3, 4, 6, 3]) 163 | 164 | def seresnet50(): 165 | return SEResNet(BottleneckResidualSEBlock, [3, 4, 6, 3]) 166 | 167 | def seresnet101(): 168 | return SEResNet(BottleneckResidualSEBlock, [3, 4, 23, 3]) 169 | 170 | def seresnet152(): 171 | return SEResNet(BottleneckResidualSEBlock, [3, 8, 36, 3]) -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | """resnet in pytorch 2 | 3 | 4 | 5 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. 
6 | 7 | Deep Residual Learning for Image Recognition 8 | https://arxiv.org/abs/1512.03385v1 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | class BasicBlock(nn.Module): 15 | """Basic Block for resnet 18 and resnet 34 16 | 17 | """ 18 | 19 | #BasicBlock and BottleNeck block 20 | #have different output size 21 | #we use class attribute expansion 22 | #to distinct 23 | expansion = 1 24 | 25 | def __init__(self, in_channels, out_channels, stride=1): 26 | super().__init__() 27 | 28 | #residual function 29 | self.residual_function = nn.Sequential( 30 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False), 31 | nn.BatchNorm2d(out_channels), 32 | nn.ReLU(inplace=True), 33 | nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False), 34 | nn.BatchNorm2d(out_channels * BasicBlock.expansion) 35 | ) 36 | 37 | #shortcut 38 | self.shortcut = nn.Sequential() 39 | 40 | #the shortcut output dimension is not the same with residual function 41 | #use 1*1 convolution to match the dimension 42 | if stride != 1 or in_channels != BasicBlock.expansion * out_channels: 43 | self.shortcut = nn.Sequential( 44 | nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False), 45 | nn.BatchNorm2d(out_channels * BasicBlock.expansion) 46 | ) 47 | 48 | def forward(self, x): 49 | return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x)) 50 | 51 | class BottleNeck(nn.Module): 52 | """Residual block for resnet over 50 layers 53 | 54 | """ 55 | expansion = 4 56 | def __init__(self, in_channels, out_channels, stride=1): 57 | super().__init__() 58 | self.residual_function = nn.Sequential( 59 | nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), 60 | nn.BatchNorm2d(out_channels), 61 | nn.ReLU(inplace=True), 62 | nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False), 63 | nn.BatchNorm2d(out_channels), 64 | nn.ReLU(inplace=True), 65 | nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False), 66 | nn.BatchNorm2d(out_channels * BottleNeck.expansion), 67 | ) 68 | 69 | self.shortcut = nn.Sequential() 70 | 71 | if stride != 1 or in_channels != out_channels * BottleNeck.expansion: 72 | self.shortcut = nn.Sequential( 73 | nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False), 74 | nn.BatchNorm2d(out_channels * BottleNeck.expansion) 75 | ) 76 | 77 | def forward(self, x): 78 | return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x)) 79 | 80 | class ResNet(nn.Module): 81 | 82 | def __init__(self, block, num_block, num_classes=1): 83 | super().__init__() 84 | 85 | self.in_channels = 64 86 | 87 | self.conv1 = nn.Sequential( 88 | nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False), 89 | nn.BatchNorm2d(64), 90 | nn.ReLU(inplace=True)) 91 | #we use a different inputsize than the original paper 92 | #so conv2_x's stride is 1 93 | self.conv2_x = self._make_layer(block, 64, num_block[0], 2) 94 | self.conv3_x = self._make_layer(block, 128, num_block[1], 2) 95 | self.conv4_x = self._make_layer(block, 256, num_block[2], 2) 96 | self.conv5_x = self._make_layer(block, 512, num_block[3], 2) 97 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 98 | self.fc = nn.Linear(512 * block.expansion, num_classes) 99 | 100 | def _make_layer(self, block, out_channels, num_blocks, stride): 101 | """make resnet layers(by layer i didnt mean this 'layer' was the 102 | same as a 
neuron netowork layer, ex. conv layer), one layer may 103 | contain more than one residual block 104 | 105 | Args: 106 | block: block type, basic block or bottle neck block 107 | out_channels: output depth channel number of this layer 108 | num_blocks: how many blocks per layer 109 | stride: the stride of the first block of this layer 110 | 111 | Return: 112 | return a resnet layer 113 | """ 114 | 115 | # we have num_block blocks per layer, the first block 116 | # could be 1 or 2, other blocks would always be 1 117 | strides = [stride] + [1] * (num_blocks - 1) 118 | layers = [] 119 | for stride in strides: 120 | layers.append(block(self.in_channels, out_channels, stride)) 121 | self.in_channels = out_channels * block.expansion 122 | 123 | return nn.Sequential(*layers) 124 | 125 | def forward(self, x): 126 | output = self.conv1(x) 127 | output = self.conv2_x(output) 128 | output = self.conv3_x(output) 129 | output = self.conv4_x(output) 130 | output = self.conv5_x(output) 131 | output = self.avg_pool(output) 132 | output = output.view(output.size(0), -1) 133 | output = self.fc(output) 134 | 135 | return output 136 | 137 | def resnet18(): 138 | """ return a ResNet 18 object 139 | """ 140 | return ResNet(BasicBlock, [2, 2, 2, 2]) 141 | 142 | def resnet34(): 143 | """ return a ResNet 34 object 144 | """ 145 | return ResNet(BasicBlock, [3, 4, 6, 3]) 146 | 147 | def resnet50(): 148 | """ return a ResNet 50 object 149 | """ 150 | return ResNet(BottleNeck, [3, 4, 6, 3]) 151 | 152 | def resnet101(): 153 | """ return a ResNet 101 object 154 | """ 155 | return ResNet(BottleNeck, [3, 4, 23, 3]) 156 | 157 | def resnet152(): 158 | """ return a ResNet 152 object 159 | """ 160 | return ResNet(BottleNeck, [3, 8, 36, 3]) 161 | 162 | 163 | 164 | -------------------------------------------------------------------------------- /models/sam/utils/onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Tuple 12 | 13 | from ..modeling import Sam 14 | from .amg import calculate_stability_score 15 | 16 | 17 | class SamOnnxModel(nn.Module): 18 | """ 19 | This model should not be called directly, but is used in ONNX export. 20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, 21 | with some functions modified to enable model tracing. Also supports extra 22 | options controlling what information. See the ONNX export script for details. 
23 | """ 24 | 25 | def __init__( 26 | self, 27 | model: Sam, 28 | return_single_mask: bool, 29 | use_stability_score: bool = False, 30 | return_extra_metrics: bool = False, 31 | ) -> None: 32 | super().__init__() 33 | self.mask_decoder = model.mask_decoder 34 | self.model = model 35 | self.img_size = model.image_encoder.img_size 36 | self.return_single_mask = return_single_mask 37 | self.use_stability_score = use_stability_score 38 | self.stability_score_offset = 1.0 39 | self.return_extra_metrics = return_extra_metrics 40 | 41 | @staticmethod 42 | def resize_longest_image_size( 43 | input_image_size: torch.Tensor, longest_side: int 44 | ) -> torch.Tensor: 45 | input_image_size = input_image_size.to(torch.float32) 46 | scale = longest_side / torch.max(input_image_size) 47 | transformed_size = scale * input_image_size 48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) 49 | return transformed_size 50 | 51 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: 52 | point_coords = point_coords + 0.5 53 | point_coords = point_coords / self.img_size 54 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) 55 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) 56 | 57 | point_embedding = point_embedding * (point_labels != -1) 58 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( 59 | point_labels == -1 60 | ) 61 | 62 | for i in range(self.model.prompt_encoder.num_point_embeddings): 63 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ 64 | i 65 | ].weight * (point_labels == i) 66 | 67 | return point_embedding 68 | 69 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: 70 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) 71 | mask_embedding = mask_embedding + ( 72 | 1 - has_mask_input 73 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) 74 | return mask_embedding 75 | 76 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: 77 | masks = F.interpolate( 78 | masks, 79 | size=(self.img_size, self.img_size), 80 | mode="bilinear", 81 | align_corners=False, 82 | ) 83 | 84 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) 85 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore 86 | 87 | orig_im_size = orig_im_size.to(torch.int64) 88 | h, w = orig_im_size[0], orig_im_size[1] 89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 90 | return masks 91 | 92 | def select_masks( 93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int 94 | ) -> Tuple[torch.Tensor, torch.Tensor]: 95 | # Determine if we should return the multiclick mask or not from the number of points. 96 | # The reweighting is used to avoid control flow. 
97 | score_reweight = torch.tensor( 98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] 99 | ).to(iou_preds.device) 100 | score = iou_preds + (num_points - 2.5) * score_reweight 101 | best_idx = torch.argmax(score, dim=1) 102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) 103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) 104 | 105 | return masks, iou_preds 106 | 107 | @torch.no_grad() 108 | def forward( 109 | self, 110 | image_embeddings: torch.Tensor, 111 | point_coords: torch.Tensor, 112 | point_labels: torch.Tensor, 113 | mask_input: torch.Tensor, 114 | has_mask_input: torch.Tensor, 115 | orig_im_size: torch.Tensor, 116 | ): 117 | sparse_embedding = self._embed_points(point_coords, point_labels) 118 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 119 | 120 | masks, scores = self.model.mask_decoder.predict_masks( 121 | image_embeddings=image_embeddings, 122 | image_pe=self.model.prompt_encoder.get_dense_pe(), 123 | sparse_prompt_embeddings=sparse_embedding, 124 | dense_prompt_embeddings=dense_embedding, 125 | ) 126 | 127 | if self.use_stability_score: 128 | scores = calculate_stability_score( 129 | masks, self.model.mask_threshold, self.stability_score_offset 130 | ) 131 | 132 | if self.return_single_mask: 133 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) 134 | 135 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size) 136 | 137 | if self.return_extra_metrics: 138 | stability_scores = calculate_stability_score( 139 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset 140 | ) 141 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 142 | return upscaled_masks, scores, stability_scores, areas, masks 143 | 144 | return upscaled_masks, scores, masks 145 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # train.py 2 | #!/usr/bin/env python3 3 | 4 | """ train network using pytorch 5 | Junde Wu 6 | """ 7 | 8 | import os 9 | from datetime import datetime 10 | from collections import OrderedDict 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | import torch.optim as optim 15 | from sklearn.metrics import roc_auc_score, accuracy_score,confusion_matrix 16 | import torchvision 17 | import torchvision.transforms as transforms 18 | from skimage import io 19 | from torch.utils.data import DataLoader 20 | #from dataset import * 21 | from torch.autograd import Variable 22 | from PIL import Image 23 | from tensorboardX import SummaryWriter 24 | #from models.discriminatorlayer import discriminator 25 | from dataset import * 26 | from conf import settings 27 | import time 28 | import cfg 29 | from tqdm import tqdm 30 | from torch.utils.data import DataLoader, random_split 31 | from utils import * 32 | import function 33 | 34 | 35 | args = cfg.parse_args() 36 | 37 | GPUdevice = torch.device('cuda', args.gpu_device) 38 | 39 | net = get_network(args, args.net, use_gpu=args.gpu, gpu_device=GPUdevice, distribution = args.distributed) 40 | if args.pretrain: 41 | weights = torch.load(args.pretrain) 42 | net.load_state_dict(weights,strict=False) 43 | 44 | optimizer = optim.Adam(net.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False) 45 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5) #learning rate decay 46 | 47 | '''load pretrained 
model''' 48 | if args.weights != 0: 49 | print(f'=> resuming from {args.weights}') 50 | assert os.path.exists(args.weights) 51 | checkpoint_file = os.path.join(args.weights) 52 | assert os.path.exists(checkpoint_file) 53 | loc = 'cuda:{}'.format(args.gpu_device) 54 | checkpoint = torch.load(checkpoint_file, map_location=loc) 55 | start_epoch = checkpoint['epoch'] 56 | best_tol = checkpoint['best_tol'] 57 | 58 | net.load_state_dict(checkpoint['state_dict'],strict=False) 59 | # optimizer.load_state_dict(checkpoint['optimizer'], strict=False) 60 | 61 | args.path_helper = checkpoint['path_helper'] 62 | logger = create_logger(args.path_helper['log_path']) 63 | print(f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})') 64 | 65 | args.path_helper = set_log_dir('logs', args.exp_name) 66 | logger = create_logger(args.path_helper['log_path']) 67 | logger.info(args) 68 | 69 | 70 | '''segmentation data''' 71 | transform_train = transforms.Compose([ 72 | transforms.Resize((args.image_size,args.image_size)), 73 | transforms.ToTensor(), 74 | ]) 75 | 76 | transform_train_seg = transforms.Compose([ 77 | transforms.Resize((args.out_size,args.out_size)), 78 | transforms.ToTensor(), 79 | ]) 80 | 81 | transform_test = transforms.Compose([ 82 | transforms.Resize((args.image_size, args.image_size)), 83 | transforms.ToTensor(), 84 | ]) 85 | 86 | transform_test_seg = transforms.Compose([ 87 | transforms.Resize((args.out_size,args.out_size)), 88 | transforms.ToTensor(), 89 | ]) 90 | 91 | 92 | if args.dataset == 'isic': 93 | '''isic data''' 94 | isic_train_dataset = ISIC2016(args, args.data_path, transform = transform_train, transform_msk= transform_train_seg, mode = 'Training') 95 | isic_test_dataset = ISIC2016(args, args.data_path, transform = transform_test, transform_msk= transform_test_seg, mode = 'Test') 96 | 97 | nice_train_loader = DataLoader(isic_train_dataset, batch_size=args.b, shuffle=True, num_workers=8, pin_memory=True) 98 | nice_test_loader = DataLoader(isic_test_dataset, batch_size=args.b, shuffle=False, num_workers=8, pin_memory=True) 99 | '''end''' 100 | 101 | elif args.dataset == 'decathlon': 102 | nice_train_loader, nice_test_loader, transform_train, transform_val, train_list, val_list =get_decath_loader(args) 103 | 104 | elif args.dataset == 'REFUGE': 105 | '''REFUGE data''' 106 | refuge_train_dataset = REFUGE(args, args.data_path, transform = transform_train, transform_msk= transform_train_seg, mode = 'Training') 107 | refuge_test_dataset = REFUGE(args, args.data_path, transform = transform_test, transform_msk= transform_test_seg, mode = 'Test') 108 | 109 | nice_train_loader = DataLoader(refuge_train_dataset, batch_size=args.b, shuffle=True, num_workers=8, pin_memory=True) 110 | nice_test_loader = DataLoader(refuge_test_dataset, batch_size=args.b, shuffle=False, num_workers=8, pin_memory=True) 111 | '''end''' 112 | 113 | elif args.dataset == 'mydataset': 114 | '''mydataset''' 115 | mydata_train_dataset = MyDataset(args, args.data_path, transform = transform_train, transform_msk= transform_train_seg, mode = 'train') 116 | mydata_test_dataset = MyDataset(args, args.data_path, transform = transform_test, transform_msk= transform_test_seg, mode = 'test') 117 | 118 | nice_train_loader = DataLoader(mydata_train_dataset, batch_size=args.b, shuffle=True, num_workers=8, pin_memory=True) 119 | nice_test_loader = DataLoader(mydata_test_dataset, batch_size=args.b, shuffle=False, num_workers=8, pin_memory=True) 120 | '''end''' 121 | 122 | '''checkpoint path and tensorboard''' 123 | # 
iter_per_epoch = len(Glaucoma_training_loader) 124 | checkpoint_path = os.path.join(settings.CHECKPOINT_PATH, args.net, settings.TIME_NOW) 125 | #use tensorboard 126 | if not os.path.exists(settings.LOG_DIR): 127 | os.mkdir(settings.LOG_DIR) 128 | writer = SummaryWriter(log_dir=os.path.join( 129 | settings.LOG_DIR, args.net, settings.TIME_NOW)) 130 | # input_tensor = torch.Tensor(args.b, 3, 256, 256).cuda(device = GPUdevice) 131 | # writer.add_graph(net, Variable(input_tensor, requires_grad=True)) 132 | 133 | #create checkpoint folder to save model 134 | if not os.path.exists(checkpoint_path): 135 | os.makedirs(checkpoint_path) 136 | checkpoint_path = os.path.join(checkpoint_path, '{net}-{epoch}-{type}.pth') 137 | 138 | '''begin training''' 139 | best_acc = 0.0 140 | best_tol = 1e4 141 | for epoch in range(settings.EPOCH): 142 | if args.mod == 'sam_adpt': 143 | net.train() 144 | time_start = time.time() 145 | loss = function.train_sam(args, net, optimizer, nice_train_loader, epoch, writer, vis = args.vis) 146 | logger.info(f'Train loss: {loss}|| @ epoch {epoch}.') 147 | time_end = time.time() 148 | print('time_for_training ', time_end - time_start) 149 | 150 | net.eval() 151 | if epoch and epoch % args.val_freq == 0 or epoch == settings.EPOCH-1: 152 | tol, (eiou, edice) = function.validation_sam(args, nice_test_loader, epoch, net, writer) 153 | logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {epoch}.') 154 | 155 | if args.distributed != 'none': 156 | sd = net.module.state_dict() 157 | else: 158 | sd = net.state_dict() 159 | 160 | if tol < best_tol: 161 | best_tol = tol 162 | is_best = True 163 | 164 | save_checkpoint({ 165 | 'epoch': epoch + 1, 166 | 'model': args.net, 167 | 'state_dict': sd, 168 | 'optimizer': optimizer.state_dict(), 169 | 'best_tol': best_tol, 170 | 'path_helper': args.path_helper, 171 | }, is_best, args.path_helper['ckpt_path'], filename="best_checkpoint") 172 | else: 173 | is_best = False 174 | 175 | writer.close() 176 | -------------------------------------------------------------------------------- /models/sam/modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import List, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class MaskDecoder(nn.Module): 17 | def __init__( 18 | self, 19 | *, 20 | transformer_dim: int, 21 | transformer: nn.Module, 22 | num_multimask_outputs: int = 3, 23 | activation: Type[nn.Module] = nn.GELU, 24 | iou_head_depth: int = 3, 25 | iou_head_hidden_dim: int = 256, 26 | ) -> None: 27 | """ 28 | Predicts masks given an image and prompt embeddings, using a 29 | transformer architecture.
30 | 31 | Arguments: 32 | transformer_dim (int): the channel dimension of the transformer 33 | transformer (nn.Module): the transformer used to predict masks 34 | num_multimask_outputs (int): the number of masks to predict 35 | when disambiguating masks 36 | activation (nn.Module): the type of activation to use when 37 | upscaling masks 38 | iou_head_depth (int): the depth of the MLP used to predict 39 | mask quality 40 | iou_head_hidden_dim (int): the hidden dimension of the MLP 41 | used to predict mask quality 42 | """ 43 | super().__init__() 44 | self.transformer_dim = transformer_dim 45 | self.transformer = transformer 46 | 47 | self.num_multimask_outputs = num_multimask_outputs 48 | 49 | self.iou_token = nn.Embedding(1, transformer_dim) 50 | self.num_mask_tokens = num_multimask_outputs + 1 51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 52 | 53 | self.output_upscaling = nn.Sequential( 54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 55 | LayerNorm2d(transformer_dim // 4), 56 | activation(), 57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 58 | activation(), 59 | ) 60 | self.output_hypernetworks_mlps = nn.ModuleList( 61 | [ 62 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 63 | for i in range(self.num_mask_tokens) 64 | ] 65 | ) 66 | 67 | self.iou_prediction_head = MLP( 68 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 69 | ) 70 | 71 | def forward( 72 | self, 73 | image_embeddings: torch.Tensor, 74 | image_pe: torch.Tensor, 75 | sparse_prompt_embeddings: torch.Tensor, 76 | dense_prompt_embeddings: torch.Tensor, 77 | multimask_output: bool, 78 | ) -> Tuple[torch.Tensor, torch.Tensor]: 79 | """ 80 | Predict masks given image and prompt embeddings. 81 | 82 | Arguments: 83 | image_embeddings (torch.Tensor): the embeddings from the image encoder 84 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 85 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 86 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 87 | multimask_output (bool): Whether to return multiple masks or a single 88 | mask. 89 | 90 | Returns: 91 | torch.Tensor: batched predicted masks 92 | torch.Tensor: batched predictions of mask quality 93 | """ 94 | masks, iou_pred = self.predict_masks( 95 | image_embeddings=image_embeddings, 96 | image_pe=image_pe, 97 | sparse_prompt_embeddings=sparse_prompt_embeddings, 98 | dense_prompt_embeddings=dense_prompt_embeddings, 99 | ) 100 | 101 | # Select the correct mask or masks for output 102 | if multimask_output: 103 | mask_slice = slice(1, None) 104 | else: 105 | mask_slice = slice(0, 1) 106 | masks = masks[:, mask_slice, :, :] 107 | iou_pred = iou_pred[:, mask_slice] 108 | 109 | # Prepare output 110 | return masks, iou_pred 111 | 112 | def predict_masks( 113 | self, 114 | image_embeddings: torch.Tensor, 115 | image_pe: torch.Tensor, 116 | sparse_prompt_embeddings: torch.Tensor, 117 | dense_prompt_embeddings: torch.Tensor, 118 | ) -> Tuple[torch.Tensor, torch.Tensor]: 119 | """Predicts masks. 
See 'forward' for more details.""" 120 | # Concatenate output tokens 121 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) 122 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) 123 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 124 | 125 | # Expand per-image data in batch direction to be per-mask 126 | if image_embeddings.shape[0] != tokens.shape[0]: 127 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 128 | else: 129 | src = image_embeddings 130 | src = src + dense_prompt_embeddings 131 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 132 | b, c, h, w = src.shape 133 | 134 | # Run the transformer 135 | hs, src = self.transformer(src, pos_src, tokens) 136 | iou_token_out = hs[:, 0, :] 137 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 138 | 139 | # Upscale mask embeddings and predict masks using the mask tokens 140 | src = src.transpose(1, 2).view(b, c, h, w) 141 | upscaled_embedding = self.output_upscaling(src) 142 | hyper_in_list: List[torch.Tensor] = [] 143 | for i in range(self.num_mask_tokens): 144 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) 145 | hyper_in = torch.stack(hyper_in_list, dim=1) 146 | b, c, h, w = upscaled_embedding.shape 147 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 148 | 149 | # Generate mask quality predictions 150 | iou_pred = self.iou_prediction_head(iou_token_out) 151 | 152 | return masks, iou_pred 153 | 154 | 155 | # Lightly adapted from 156 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 157 | class MLP(nn.Module): 158 | def __init__( 159 | self, 160 | input_dim: int, 161 | hidden_dim: int, 162 | output_dim: int, 163 | num_layers: int, 164 | sigmoid_output: bool = False, 165 | ) -> None: 166 | super().__init__() 167 | self.num_layers = num_layers 168 | h = [hidden_dim] * (num_layers - 1) 169 | self.layers = nn.ModuleList( 170 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 171 | ) 172 | self.sigmoid_output = sigmoid_output 173 | 174 | def forward(self, x): 175 | for i, layer in enumerate(self.layers): 176 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 177 | if self.sigmoid_output: 178 | x = F.sigmoid(x) 179 | return x 180 | -------------------------------------------------------------------------------- /precpt.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import numpy 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.autograd import Function 8 | from torch.optim.lr_scheduler import _LRScheduler 9 | import torchvision 10 | import torchvision.transforms as transforms 11 | import torchvision.utils as vutils 12 | from torch.utils.data import DataLoader 13 | from dataset import Dataset_FullImg, Dataset_DiscRegion 14 | import math 15 | import PIL 16 | import matplotlib.pyplot as plt 17 | import seaborn as sns 18 | 19 | import collections 20 | import logging 21 | import math 22 | import os 23 | import time 24 | from datetime import datetime 25 | 26 | import dateutil.tz 27 | 28 | from typing import Union, Optional, List, Tuple, Text, BinaryIO 29 | import pathlib 30 | import warnings 31 | import numpy as np 32 | from PIL import Image, ImageDraw, ImageFont, ImageColor 33 | from lucent.optvis.param.spatial import pixel_image, fft_image, init_image 34 | 
from lucent.optvis.param.color import to_valid_rgb 35 | from torchvision.models import vgg19 36 | import torch.nn.functional as F 37 | import cfg 38 | 39 | import warnings 40 | from collections import OrderedDict 41 | import numpy as np 42 | from tqdm import tqdm 43 | from PIL import Image 44 | import torch 45 | 46 | 47 | 48 | 49 | args = cfg.parse_args() 50 | device = torch.device('cuda', args.gpu_device) 51 | cnn = vgg19(pretrained=True).features.to(device).eval() 52 | 53 | content_layers_default = ['conv_4'] 54 | style_layers_default = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5'] 55 | 56 | cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device) 57 | cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device) 58 | 59 | class ContentLoss(nn.Module): 60 | 61 | def __init__(self, target,): 62 | super(ContentLoss, self).__init__() 63 | # we 'detach' the target content from the tree used 64 | # to dynamically compute the gradient: this is a stated value, 65 | # not a variable. Otherwise the forward method of the criterion 66 | # will throw an error. 67 | self.target = target.detach() 68 | 69 | def forward(self, input): 70 | self.loss = F.mse_loss(input, self.target) 71 | return input 72 | 73 | def gram_matrix(input): 74 | a, b, c, d = input.size() # a=batch size(=1) 75 | # b=number of feature maps 76 | # (c,d)=dimensions of a f. map (N=c*d) 77 | 78 | features = input.view(a * b, c * d) # resise F_XL into \hat F_XL 79 | 80 | G = torch.mm(features, features.t()) # compute the gram product 81 | 82 | # we 'normalize' the values of the gram matrix 83 | # by dividing by the number of element in each feature maps. 84 | return G.div(a * b * c * d) 85 | 86 | class StyleLoss(nn.Module): 87 | 88 | def __init__(self, target_feature): 89 | super(StyleLoss, self).__init__() 90 | self.target = gram_matrix(target_feature).detach() 91 | 92 | def forward(self, input): 93 | G = gram_matrix(input) 94 | self.loss = F.mse_loss(G, self.target) 95 | return input 96 | 97 | # create a module to normalize input image so we can easily put it in a 98 | # nn.Sequential 99 | class Normalization(nn.Module): 100 | def __init__(self, mean, std): 101 | super(Normalization, self).__init__() 102 | # .view the mean and std to make them [C x 1 x 1] so that they can 103 | # directly work with image Tensor of shape [B x C x H x W]. 104 | # B is batch size. C is number of channels. H is height and W is width. 
105 | self.mean = torch.tensor(mean).view(-1, 1, 1) 106 | self.std = torch.tensor(std).view(-1, 1, 1) 107 | 108 | def forward(self, img): 109 | # normalize img 110 | return (img - self.mean) / self.std 111 | 112 | def run_precpt(cnn, normalization_mean, normalization_std, 113 | content_img, style_img, input_img, 114 | style_weight=1000000, content_weight=1): 115 | model, style_losses, content_losses = precpt_loss(cnn, 116 | normalization_mean, normalization_std, style_img, content_img) 117 | 118 | # We want to optimize the input and not the model parameters so we 119 | # update all the requires_grad fields accordingly 120 | model.requires_grad_(False) 121 | input_img.requires_grad_(True) 122 | 123 | model(input_img) 124 | style_score = 0 125 | content_score = 0 126 | 127 | for sl in style_losses: 128 | style_score += sl.loss 129 | for cl in content_losses: 130 | content_score += cl.loss 131 | 132 | content_weight = 100 133 | style_weight = 100000 134 | style_score *= style_weight 135 | content_score *= content_weight 136 | 137 | loss = style_score + content_score 138 | # loss = content_score 139 | 140 | return loss 141 | 142 | 143 | def precpt_loss(cnn, normalization_mean, normalization_std, 144 | style_img, content_img, 145 | content_layers=content_layers_default, 146 | style_layers=style_layers_default): 147 | 148 | # normalization module 149 | normalization = Normalization(normalization_mean, normalization_std).to(device) 150 | 151 | # just in order to have an iterable access to or list of content/syle 152 | # losses 153 | content_losses = [] 154 | style_losses = [] 155 | # assuming that cnn is a nn.Sequential, so we make a new nn.Sequential 156 | # to put in modules that are supposed to be activated sequentially 157 | model = nn.Sequential(normalization) 158 | 159 | i = 0 # increment every time we see a conv 160 | for layer in cnn.children(): 161 | if isinstance(layer, nn.Conv2d): 162 | i += 1 163 | name = 'conv_{}'.format(i) 164 | elif isinstance(layer, nn.ReLU): 165 | name = 'relu_{}'.format(i) 166 | # The in-place version doesn't play very nicely with the ContentLoss 167 | # and StyleLoss we insert below. So we replace with out-of-place 168 | # ones here. 
169 | layer = nn.ReLU(inplace=False) 170 | elif isinstance(layer, nn.MaxPool2d): 171 | name = 'pool_{}'.format(i) 172 | elif isinstance(layer, nn.BatchNorm2d): 173 | name = 'bn_{}'.format(i) 174 | else: 175 | raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__)) 176 | 177 | model.add_module(name, layer) 178 | 179 | if name in content_layers: 180 | # add content loss: 181 | target = model(content_img).detach() 182 | content_loss = ContentLoss(target) 183 | model.add_module("content_loss_{}".format(i), content_loss) 184 | content_losses.append(content_loss) 185 | 186 | if name in style_layers: 187 | # add style loss: 188 | if style_img.size(1) == 1: 189 | style_img = style_img.expand(style_img.size(0),3, style_img.size(2),style_img.size(3)) 190 | target_feature = model(style_img).detach() 191 | style_loss = StyleLoss(target_feature) 192 | model.add_module("style_loss_{}".format(i), style_loss) 193 | style_losses.append(style_loss) 194 | 195 | # now we trim off the layers after the last content and style losses 196 | for i in range(len(model) - 1, -1, -1): 197 | if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss): 198 | break 199 | 200 | model = model[:(i + 1)] 201 | 202 | return model, style_losses, content_losses 203 | 204 | 205 | 206 | -------------------------------------------------------------------------------- /models/unet/res_net.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | import torch 5 | 6 | 7 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 8 | 'resnet152'] 9 | 10 | 11 | model_urls = { 12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 13 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 14 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 15 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 16 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 17 | } 18 | 19 | 20 | def conv3x3(in_planes, out_planes, stride=1): 21 | """3x3 convolution with padding""" 22 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 23 | padding=1, bias=False) 24 | 25 | 26 | class BasicBlock(nn.Module): 27 | expansion = 1 28 | 29 | def __init__(self, inplanes, planes, stride=1, downsample=None): 30 | super(BasicBlock, self).__init__() 31 | self.conv1 = conv3x3(inplanes, planes, stride) 32 | self.bn1 = nn.BatchNorm2d(planes) 33 | self.relu = nn.ReLU(inplace=True) 34 | self.conv2 = conv3x3(planes, planes) 35 | self.bn2 = nn.BatchNorm2d(planes) 36 | self.downsample = downsample 37 | self.stride = stride 38 | 39 | def forward(self, x): 40 | residual = x 41 | 42 | out = self.conv1(x) 43 | out = self.bn1(out) 44 | out = self.relu(out) 45 | 46 | out = self.conv2(out) 47 | out = self.bn2(out) 48 | 49 | if self.downsample is not None: 50 | residual = self.downsample(x) 51 | 52 | out += residual 53 | out = self.relu(out) 54 | 55 | return out 56 | 57 | 58 | class Bottleneck(nn.Module): 59 | expansion = 4 60 | 61 | def __init__(self, inplanes, planes, stride=1, downsample=None): 62 | super(Bottleneck, self).__init__() 63 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 64 | self.bn1 = nn.BatchNorm2d(planes) 65 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 66 | padding=1, bias=False) 67 | self.bn2 = nn.BatchNorm2d(planes) 68 | 
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 69 | self.bn3 = nn.BatchNorm2d(planes * 4) 70 | self.relu = nn.ReLU(inplace=True) 71 | self.downsample = downsample 72 | self.stride = stride 73 | 74 | def forward(self, x): 75 | residual = x 76 | 77 | out = self.conv1(x) 78 | out = self.bn1(out) 79 | out = self.relu(out) 80 | 81 | out = self.conv2(out) 82 | out = self.bn2(out) 83 | out = self.relu(out) 84 | 85 | out = self.conv3(out) 86 | out = self.bn3(out) 87 | 88 | if self.downsample is not None: 89 | residual = self.downsample(x) 90 | 91 | out += residual 92 | out = self.relu(out) 93 | 94 | return out 95 | 96 | 97 | class ResNet(nn.Module): 98 | 99 | def __init__(self, block, layers, inplanes= 3, num_classes=1000): 100 | self.inplanes = 64 101 | super(ResNet, self).__init__() 102 | self.conv1 = nn.Conv2d(inplanes, 64, kernel_size=7, stride=2, padding=3, 103 | bias=False) 104 | self.bn1 = nn.BatchNorm2d(64) 105 | self.relu = nn.ReLU(inplace=True) 106 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 107 | self.layer1 = self._make_layer(block, 64, layers[0]) 108 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 109 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 110 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 111 | self.avgpool = nn.AvgPool2d(7, stride=1) 112 | self.fc = nn.Linear(512 * block.expansion, num_classes) 113 | 114 | for m in self.modules(): 115 | if isinstance(m, nn.Conv2d): 116 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 117 | m.weight.data.normal_(0, math.sqrt(2. / n)) 118 | elif isinstance(m, nn.BatchNorm2d): 119 | m.weight.data.fill_(1) 120 | m.bias.data.zero_() 121 | 122 | def _make_layer(self, block, planes, blocks, stride=1): 123 | downsample = None 124 | if stride != 1 or self.inplanes != planes * block.expansion: 125 | downsample = nn.Sequential( 126 | nn.Conv2d(self.inplanes, planes * block.expansion, 127 | kernel_size=1, stride=stride, bias=False), 128 | nn.BatchNorm2d(planes * block.expansion), 129 | ) 130 | 131 | layers = [] 132 | layers.append(block(self.inplanes, planes, stride, downsample)) 133 | self.inplanes = planes * block.expansion 134 | for i in range(1, blocks): 135 | layers.append(block(self.inplanes, planes)) 136 | 137 | return nn.Sequential(*layers) 138 | 139 | def forward(self, x): 140 | x = self.conv1(x) 141 | x = self.bn1(x) 142 | x = self.relu(x) 143 | x = self.maxpool(x) 144 | 145 | x = self.layer1(x) 146 | x = self.layer2(x) 147 | x = self.layer3(x) 148 | x = self.layer4(x) 149 | 150 | x = self.avgpool(x) 151 | x = x.view(x.size(0), -1) 152 | x = self.fc(x) 153 | 154 | return x 155 | 156 | 157 | def resnet18(pretrained=False, **kwargs): 158 | """Constructs a ResNet-18 model. 159 | 160 | Args: 161 | pretrained (bool): If True, returns a model pre-trained on ImageNet 162 | """ 163 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 164 | if pretrained: 165 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 166 | return model 167 | 168 | 169 | def resnet34(pretrained=False, **kwargs): 170 | """Constructs a ResNet-34 model. 171 | 172 | Args: 173 | pretrained (bool): If True, returns a model pre-trained on ImageNet 174 | """ 175 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 176 | if pretrained: 177 | model.load_state_dict(torch.load('resnet34-333f7ec4.pth')) 178 | return model 179 | 180 | 181 | def resnet50(pretrained=False, **kwargs): 182 | """Constructs a ResNet-50 model. 
183 | 184 | Args: 185 | pretrained (bool): If True, returns a model pre-trained on ImageNet 186 | """ 187 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 188 | if pretrained: 189 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 190 | return model 191 | 192 | 193 | def resnet101(pretrained=False, **kwargs): 194 | """Constructs a ResNet-101 model. 195 | 196 | Args: 197 | pretrained (bool): If True, returns a model pre-trained on ImageNet 198 | """ 199 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 200 | if pretrained: 201 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 202 | return model 203 | 204 | 205 | def resnet152(pretrained=False, **kwargs): 206 | """Constructs a ResNet-152 model. 207 | 208 | Args: 209 | pretrained (bool): If True, returns a model pre-trained on ImageNet 210 | """ 211 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 212 | if pretrained: 213 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 214 | return model 215 | -------------------------------------------------------------------------------- /models/sam/modeling/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Any, Dict, List, Tuple 12 | 13 | from .image_encoder import ImageEncoderViT 14 | from .mask_decoder import MaskDecoder 15 | from .prompt_encoder import PromptEncoder 16 | 17 | 18 | class Sam(nn.Module): 19 | mask_threshold: float = 0.0 20 | image_format: str = "RGB" 21 | 22 | def __init__( 23 | self, 24 | args, 25 | image_encoder: ImageEncoderViT, 26 | prompt_encoder: PromptEncoder, 27 | mask_decoder: MaskDecoder, 28 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 29 | pixel_std: List[float] = [58.395, 57.12, 57.375], 30 | ) -> None: 31 | """ 32 | SAM predicts object masks from an image and input prompts. 33 | 34 | Arguments: 35 | image_encoder (ImageEncoderViT): The backbone used to encode the 36 | image into image embeddings that allow for efficient mask prediction. 37 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 38 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 39 | and encoded prompts. 40 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 41 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 42 | """ 43 | super().__init__() 44 | self.args = args 45 | self.image_encoder = image_encoder 46 | self.prompt_encoder = prompt_encoder 47 | self.mask_decoder = mask_decoder 48 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) 49 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 50 | 51 | @property 52 | def device(self) -> Any: 53 | return self.pixel_mean.device 54 | 55 | @torch.no_grad() 56 | def forward( 57 | self, 58 | batched_input: List[Dict[str, Any]], 59 | multimask_output: bool, 60 | ) -> List[Dict[str, torch.Tensor]]: 61 | """ 62 | Predicts masks end-to-end from provided images and prompts. 63 | If prompts are not known in advance, using SamPredictor is 64 | recommended over calling the model directly. 
65 | 66 | Arguments: 67 | batched_input (list(dict)): A list over input images, each a 68 | dictionary with the following keys. A prompt key can be 69 | excluded if it is not present. 70 | 'image': The image as a torch tensor in 3xHxW format, 71 | already transformed for input to the model. 72 | 'original_size': (tuple(int, int)) The original size of 73 | the image before transformation, as (H, W). 74 | 'point_coords': (torch.Tensor) Batched point prompts for 75 | this image, with shape BxNx2. Already transformed to the 76 | input frame of the model. 77 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 78 | with shape BxN. 79 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 80 | Already transformed to the input frame of the model. 81 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 82 | in the form Bx1xHxW. 83 | multimask_output (bool): Whether the model should predict multiple 84 | disambiguating masks, or return a single mask. 85 | 86 | Returns: 87 | (list(dict)): A list over input images, where each element is 88 | as dictionary with the following keys. 89 | 'masks': (torch.Tensor) Batched binary mask predictions, 90 | with shape BxCxHxW, where B is the number of input prompts, 91 | C is determined by multimask_output, and (H, W) is the 92 | original size of the image. 93 | 'iou_predictions': (torch.Tensor) The model's predictions 94 | of mask quality, in shape BxC. 95 | 'low_res_logits': (torch.Tensor) Low resolution logits with 96 | shape BxCxHxW, where H=W=256. Can be passed as mask input 97 | to subsequent iterations of prediction. 98 | """ 99 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) 100 | image_embeddings = self.image_encoder(input_images) 101 | 102 | outputs = [] 103 | for image_record, curr_embedding in zip(batched_input, image_embeddings): 104 | if "point_coords" in image_record: 105 | points = (image_record["point_coords"], image_record["point_labels"]) 106 | else: 107 | points = None 108 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 109 | points=points, 110 | boxes=image_record.get("boxes", None), 111 | masks=image_record.get("mask_inputs", None), 112 | ) 113 | low_res_masks, iou_predictions = self.mask_decoder( 114 | image_embeddings=curr_embedding.unsqueeze(0), 115 | image_pe=self.prompt_encoder.get_dense_pe(), 116 | sparse_prompt_embeddings=sparse_embeddings, 117 | dense_prompt_embeddings=dense_embeddings, 118 | multimask_output=multimask_output, 119 | ) 120 | masks = self.postprocess_masks( 121 | low_res_masks, 122 | input_size=image_record["image"].shape[-2:], 123 | original_size=image_record["original_size"], 124 | ) 125 | masks = masks > self.mask_threshold 126 | outputs.append( 127 | { 128 | "masks": masks, 129 | "iou_predictions": iou_predictions, 130 | "low_res_logits": low_res_masks, 131 | } 132 | ) 133 | return outputs 134 | 135 | def postprocess_masks( 136 | self, 137 | masks: torch.Tensor, 138 | input_size: Tuple[int, ...], 139 | original_size: Tuple[int, ...], 140 | ) -> torch.Tensor: 141 | """ 142 | Remove padding and upscale masks to the original image size. 143 | 144 | Arguments: 145 | masks (torch.Tensor): Batched masks from the mask_decoder, 146 | in BxCxHxW format. 147 | input_size (tuple(int, int)): The size of the image input to the 148 | model, in (H, W) format. Used to remove padding. 149 | original_size (tuple(int, int)): The original size of the image 150 | before resizing for input to the model, in (H, W) format. 
151 | 152 | Returns: 153 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 154 | is given by original_size. 155 | """ 156 | masks = F.interpolate( 157 | masks, 158 | (self.image_encoder.img_size, self.image_encoder.img_size), 159 | mode="bilinear", 160 | align_corners=False, 161 | ) 162 | masks = masks[..., : input_size[0], : input_size[1]] 163 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 164 | return masks 165 | 166 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 167 | """Normalize pixel values and pad to a square input.""" 168 | # Normalize colors 169 | x = (x - self.pixel_mean) / self.pixel_std 170 | 171 | # Pad 172 | h, w = x.shape[-2:] 173 | padh = self.image_encoder.img_size - h 174 | padw = self.image_encoder.img_size - w 175 | x = F.pad(x, (0, padw, 0, padh)) 176 | return x 177 | -------------------------------------------------------------------------------- /post_processing.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import cv2\n", 10 | "import os\n", 11 | "import numpy as np\n", 12 | "from tqdm import tqdm\n", 13 | "from PIL import Image" 14 | ] 15 | }, 16 | { 17 | "cell_type": "markdown", 18 | "metadata": {}, 19 | "source": [ 20 | "# Merge Image" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "def merge_images(input_dir, output_dir):\n", 30 | " files = os.listdir(input_dir)\n", 31 | " if not os.path.exists(output_dir):\n", 32 | " os.makedirs(output_dir)\n", 33 | " for file in tqdm(files):\n", 34 | " if file.startswith(\"predictleft_\") and file.endswith(\".jpg\"):\n", 35 | " left_image_path = os.path.join(input_dir, file)\n", 36 | " right_image_id = file.removeprefix(\"predictleft_\").removesuffix(\".jpg\")\n", 37 | " right_image_path = os.path.join(input_dir, f\"predictright_{right_image_id}.jpg\")\n", 38 | " if os.path.exists(right_image_path):\n", 39 | " left_image = Image.open(left_image_path)\n", 40 | " right_image = Image.open(right_image_path)\n", 41 | " right_image = right_image.transpose(Image.FLIP_LEFT_RIGHT)\n", 42 | " merged_image = Image.new(\"1\", (640, 320))\n", 43 | " merged_image.paste(left_image, (0, 0))\n", 44 | " merged_image.paste(right_image, (320, 0))\n", 45 | " output_path = os.path.join(output_dir, f\"{right_image_id}.png\")\n", 46 | " merged_image.save(output_path)\n", 47 | " else:\n", 48 | " print(f\"Right Image Not Found: {right_image_path}\")" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "# log = \"mydataset_2023_12_26_15_48_14\"\n", 58 | "# merge_images(f\"./logs/{log}/Samples\", f\"./logs/{log}/Samples_merged\")" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "def merge_images_overlap(input_dir, output_dir):\n", 68 | " files = os.listdir(input_dir)\n", 69 | " if not os.path.exists(output_dir):\n", 70 | " os.makedirs(output_dir)\n", 71 | " for file in tqdm(files):\n", 72 | " if file.startswith(\"predictleft_\") and file.endswith(\".jpg\"):\n", 73 | " left_image_path = os.path.join(input_dir, file)\n", 74 | " right_image_id = file.removeprefix(\"predictleft_\").removesuffix(\".jpg\")\n", 75 | " right_image_path = 
os.path.join(input_dir, f\"predictright_{right_image_id}.jpg\")\n", 76 | " if os.path.exists(right_image_path):\n", 77 | " left_image = Image.open(left_image_path)\n", 78 | " right_image = Image.open(right_image_path)\n", 79 | " right_image = right_image.transpose(Image.FLIP_LEFT_RIGHT)\n", 80 | " left_temp = Image.new(\"1\", (640, 320))\n", 81 | " right_temp = Image.new(\"1\", (640, 320))\n", 82 | " left_temp.paste(left_image, (30, 0))\n", 83 | " right_temp.paste(right_image, (290, 0))\n", 84 | " merged_image = ImageChops.logical_or(left_temp, right_temp)\n", 85 | " output_path = os.path.join(output_dir, f\"{right_image_id}.png\")\n", 86 | " merged_image.save(output_path)\n", 87 | " else:\n", 88 | " print(f\"Right Image Not Found: {right_image_path}\")" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "# log = \"mydataset_2023_12_26_15_48_14\"\n", 98 | "# merge_images_overlap(f\"./logs/{log}/Samples\", f\"./logs/{log}/Samples_merged\")" 99 | ] 100 | }, 101 | { 102 | "cell_type": "markdown", 103 | "metadata": {}, 104 | "source": [ 105 | "# Filter Small Components" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "def filter_small_connected_components(binary_image, threshold):\n", 115 | " nb_components, output, stats, _ = cv2.connectedComponentsWithStats(binary_image, connectivity=8)\n", 116 | " sizes = stats[:, -1]\n", 117 | " filtered_image = np.zeros(output.shape)\n", 118 | " for i in range(1, nb_components):\n", 119 | " if sizes[i] >= threshold:\n", 120 | " filtered_image[output == i] = 255\n", 121 | " return filtered_image\n", 122 | "\n", 123 | "\n", 124 | "\n", 125 | "def process_images_in_directory(input_dir, output_dir, threshold):\n", 126 | " os.makedirs(output_dir, exist_ok=True)\n", 127 | " for filename in tqdm(os.listdir(input_dir)):\n", 128 | " if filename.endswith((\".png\", \".jpg\", \".jpeg\")):\n", 129 | " image_path = os.path.join(input_dir, filename)\n", 130 | " binary_image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)\n", 131 | " filtered_image = filter_small_connected_components(binary_image, threshold).astype(np.bool_)\n", 132 | " pil_image = Image.fromarray(filtered_image)\n", 133 | " output_path = os.path.join(output_dir, filename)\n", 134 | " pil_image.save(output_path)" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": null, 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [ 143 | "log = \"mydataset_2023_12_26_12_47_36\"\n", 144 | "process_images_in_directory(f\"./logs/{log}/Samples_merged\", f\"./logs/{log}/Samples_merged_filtered\", 200)" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "metadata": {}, 150 | "source": [ 151 | "# Convert To Binary Image" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "def convert_to_binary(image_path, threshold=128):\n", 161 | " img = Image.open(image_path)\n", 162 | " gray_img = img.convert('L')\n", 163 | " binary_img = gray_img.point(lambda pixel: 0 if pixel < threshold else 255, \"1\")\n", 164 | " return binary_img\n", 165 | "\n", 166 | "\n", 167 | "def batch_convert_to_binary(input_directory, output_directory, threshold=128):\n", 168 | " if not os.path.exists(output_directory):\n", 169 | " os.makedirs(output_directory)\n", 170 | "\n", 171 | " for filename in 
os.listdir(input_directory):\n", 172 | "        if filename.endswith(\".jpg\"):\n", 173 | "            input_path = os.path.join(input_directory, filename)\n", 174 | "            output_path = os.path.join(output_directory, filename.replace(\"predict\", \"\").replace(\".jpg\", \".png\"))\n", 175 | "            binary_image = convert_to_binary(input_path, threshold)\n", 176 | "            binary_image.save(output_path)" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": null, 182 | "metadata": {}, 183 | "outputs": [], 184 | "source": [ 185 | "# log = \"mydataset_2023_12_26_15_48_14\"\n", 186 | "# batch_convert_to_binary(f\"./logs/{log}/Samples_merged\", f\"./logs/{log}/Samples_merged_binary\", 128)" 187 | ] 188 | } 189 | ], 190 | "metadata": { 191 | "kernelspec": { 192 | "display_name": "sam_adapt", 193 | "language": "python", 194 | "name": "python3" 195 | }, 196 | "language_info": { 197 | "codemirror_mode": { 198 | "name": "ipython", 199 | "version": 3 200 | }, 201 | "file_extension": ".py", 202 | "mimetype": "text/x-python", 203 | "name": "python", 204 | "nbconvert_exporter": "python", 205 | "pygments_lexer": "ipython3", 206 | "version": "3.10.11" 207 | } 208 | }, 209 | "nbformat": 4, 210 | "nbformat_minor": 2 211 | } 212 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Medical-SAM-Adapter 2 | 3 | Medical SAM Adapter (MSA) is a project to fine-tune [SAM](https://github.com/facebookresearch/segment-anything) using [adaptation](https://lightning.ai/pages/community/tutorial/lora-llm/) for medical imaging. 4 | The method is described in the paper [Medical SAM Adapter: Adapting Segment Anything Model for Medical Image Segmentation](https://arxiv.org/abs/2304.12620). 5 | 6 | 7 | ## A Quick Overview 8 | 9 | 10 | 11 | ## News 12 | - 23-05-10. This project is still being updated quickly 🌝. Check the TODO list to see what will be released next. 13 | - 23-05-11. GitHub Discussions opened. You can now talk, code and make friends on the playground 👨‍❤️‍👨. 14 | - 23-12-22. Released the data loader and an example case on the [REFUGE](https://refuge.grand-challenge.org/) dataset. 15 | 16 | ## Requirements 17 | 18 | Install the environment: 19 | 20 | ``conda env create -f environment.yml`` 21 | 22 | ``conda activate sam_adapt`` 23 | 24 | Then download the [SAM checkpoint](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth) and put it at ./checkpoint/sam/ 25 | 26 | You can run: 27 | 28 | ``wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth`` 29 | 30 | ``mv sam_vit_b_01ec64.pth ./checkpoint/sam`` 31 | Create the folder first if it does not exist. 32 | 33 | ## Example Cases 34 | 35 | ### Melanoma Segmentation from Skin Images (2D) 36 | 37 | 1. Download the ISIC dataset (part 1) from https://challenge.isic-archive.com/data/. Then put the CSV files in "./data/isic" under your data path. Your dataset folder under "your_data_path" should look like: 38 | 39 | ISIC/ 40 | 41 | ISBI2016_ISIC_Part1_Test_Data/... 42 | 43 | ISBI2016_ISIC_Part1_Training_Data/... 44 | 45 | ISBI2016_ISIC_Part1_Test_GroundTruth.csv 46 | 47 | ISBI2016_ISIC_Part1_Training_GroundTruth.csv 48 | 49 | 2. Begin adapting! Run: ``python train.py -net sam -mod sam_adpt -exp_name *msa_test_isic* -sam_ckpt ./checkpoint/sam/sam_vit_b_01ec64.pth -image_size 1024 -b 32 -dataset isic -data_path *../data*`` 50 | Change "data_path" and "exp_name" for your own usage; you can set "exp_name" to anything you want.
51 | 52 | You can decrease the image size (``-image_size``) or the batch size (``-b``) if you run out of memory. 53 | 54 | 3. Evaluation: The code can automatically evaluate the model on the test set during training; set "--val_freq" to control how many epochs to wait between evaluations. You can also run val.py for an independent evaluation. 55 | 56 | 4. Result visualization: You can set the "--vis" parameter to control how many epochs to wait between saving visual results during training or evaluation. 57 | 58 | By default, everything will be saved at ``./logs/``. 59 | 60 | ### REFUGE: Optic-disc Segmentation from Fundus Images (2D) 61 | The [REFUGE](https://refuge.grand-challenge.org/) dataset contains 1200 fundus images with optic disc/cup segmentations and clinical glaucoma labels. 62 | 63 | 1. Download the dataset manually from [here](https://huggingface.co/datasets/realslimman/REFUGE-MultiRater/tree/main), or use the command line: 64 | 65 | ``git lfs install`` 66 | 67 | ``git clone git@hf.co:datasets/realslimman/REFUGE-MultiRater`` 68 | 69 | Unzip and move the dataset to the target folder: 70 | 71 | ``unzip ./REFUGE-MultiRater.zip`` 72 | 73 | ``mv REFUGE-MultiRater ./data`` 74 | 75 | 2. For training the adapter, run: ``python train.py -net sam -mod sam_adpt -exp_name REFUGE-MSAdapt -sam_ckpt ./checkpoint/sam/sam_vit_b_01ec64.pth -image_size 1024 -b 32 -dataset REFUGE -data_path ./data/REFUGE-MultiRater`` 76 | You can change "exp_name" to anything you want. 77 | 78 | You can decrease the image size (``-image_size``) or the batch size (``-b``) if you run out of memory. 79 | 80 | ### Abdominal Multiple Organs Segmentation (3D) 81 | 82 | This tutorial demonstrates how MSA can adapt SAM to a 3D multi-organ segmentation task using the BTCV challenge dataset. 83 | 84 | For the BTCV dataset, under Institutional Review Board (IRB) supervision, 50 abdominal CT scans were randomly selected from a combination of an ongoing colorectal cancer chemotherapy trial and a retrospective ventral hernia study. The 50 scans were captured during the portal venous contrast phase with variable volume sizes (512 x 512 x 85 - 512 x 512 x 198) and fields of view (approx. 280 x 280 x 280 mm3 - 500 x 500 x 650 mm3). The in-plane resolution varies from 0.54 x 0.54 mm2 to 0.98 x 0.98 mm2, while the slice thickness ranges from 2.5 mm to 5.0 mm. 85 | 86 | Target: 13 abdominal organs, including 87 | Spleen 88 | Right Kidney 89 | Left Kidney 90 | Gallbladder 91 | Esophagus 92 | Liver 93 | Stomach 94 | Aorta 95 | IVC 96 | Portal and Splenic Veins 97 | Pancreas 98 | Right adrenal gland 99 | Left adrenal gland. 100 | Modality: CT 101 | Size: 30 3D volumes (24 Training + 6 Testing) 102 | Challenge: BTCV MICCAI Challenge 103 | The following figure shows image patches with the organ sub-regions that are annotated in the CT (top left) and the final labels for the whole dataset (right). 104 | 105 | 106 | 1. Prepare the BTCV dataset following the [MONAI](https://docs.monai.io/en/stable/index.html) instructions: 107 | 108 | Download the BTCV dataset from https://www.synapse.org/#!Synapse:syn3193805/wiki/217752. After you open the link, navigate to the "Files" tab, then download Abdomen/RawData.zip. 109 | 110 | After downloading the zip file, unzip it. Then put the images from RawData/Training/img in ../data/imagesTr, and the labels from RawData/Training/label in ../data/labelsTr. 111 | 112 | Download the JSON file for the data splits from this [link](https://drive.google.com/file/d/1qcGh41p-rI3H_sQ0JwOAhNiQSXriQqGi/view). Place the JSON file at ../data/dataset_0.json. 113 | 114 | 2.
For the adaptation, run: ``python train.py -net sam -mod sam_adpt -exp_name msa-3d-sam-btcv -sam_ckpt ./checkpoint/sam/sam_vit_b_01ec64.pth -image_size 1024 -b 8 -dataset decathlon -thd True -chunk 96 -data_path ../data -num_sample 4`` 115 | 116 | You can modify the following parameters to reduce memory usage: '-b' the batch size, '-chunk' the 3D depth (channels) for each sample, '-num_sample' the number of samples for [Monai.RandCropByPosNegLabeld](https://docs.monai.io/en/stable/transforms.html#randcropbyposneglabeld), and '-evl_chunk' the 3D channel split step in evaluation; decrease it if you run out of memory during evaluation. 117 | 118 | ## Run on your own dataset 119 | It is simple to run MSA on other datasets. Just write another dataset class following the examples in ``./dataset.py``. You only need to make sure it returns a dict with: 120 | 121 | 122 | { 123 | 'image': A tensor holding the images, with size [C,H,W] for 2D images and [C, H, W, D] for 3D data. 124 | D is the depth of the 3D volume, C is the number of channels of a scan/frame, which is commonly 1 for CT, MRI, US data. 125 | If processing, say, a color surgical video, D could be the number of time frames, and C will be 3 for an RGB frame. 126 | 127 | 'label': The target masks. Same size as the images except for the resolution (H and W). 128 | 129 | 'p_label': The prompt label that decides a positive/negative prompt. To simplify, you can always set it to 1 if you don't need the negative prompt function. 130 | 131 | 'pt': The prompt. Should be the same as in SAM, e.g., a click prompt should be [x of click, y of click], one click for each scan/frame if using 3D data. 132 | 133 | 'image_meta_dict': Optional. If you want to save/visualize the results, you should put the name of the image in it under the key ['filename_or_obj']. 134 | 135 | ...(others as you want) 136 | } 137 | 138 | 139 | You are welcome to open issues if you meet any problems. It would be appreciated if you could contribute your dataset extensions. Unlike natural images, medical images vary a lot depending on the task. Expanding the generalization of a method requires everyone's efforts. 140 | 141 | ### TODO LIST 142 | 143 | - [ ] Jupyter tutorials. 144 | - [x] Fix bugs in BTCV. Add BTCV example.
145 | - [ ] Release REFUGE2, BraTs dataloaders and examples 146 | - [ ] Changable Image Resolution 147 | - [ ] Fix bugs in Multi-GPU parallel 148 | - [x] Sample and Vis in training 149 | - [ ] Release general data pre-processing and post-processing 150 | - [x] Release evaluation 151 | - [ ] Deploy on HuggingFace 152 | - [x] configuration 153 | - [ ] Release SSL code 154 | 155 | ## Cite 156 | comment out temporarily as the paper is under review 157 | 158 | 159 | 160 | 161 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | """ train and test dataset 2 | 3 | author jundewu 4 | """ 5 | import os 6 | import sys 7 | import pickle 8 | import cv2 9 | from skimage import io 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | import torch 13 | from torch.utils.data import Dataset 14 | from PIL import Image 15 | import torch.nn.functional as F 16 | import torchvision.transforms as transforms 17 | import pandas as pd 18 | from skimage.transform import rotate 19 | from utils import random_click 20 | import random 21 | from monai.transforms import LoadImaged, Randomizable,LoadImage 22 | 23 | 24 | class ISIC2016(Dataset): 25 | def __init__(self, args, data_path , transform = None, transform_msk = None, mode = 'Training',prompt = 'click', plane = False): 26 | 27 | df = pd.read_csv(os.path.join(data_path, 'ISBI2016_ISIC_Part1_' + mode + '_GroundTruth.csv'), encoding='gbk') 28 | self.name_list = df.iloc[:,1].tolist() 29 | self.label_list = df.iloc[:,2].tolist() 30 | self.data_path = data_path 31 | self.mode = mode 32 | self.prompt = prompt 33 | self.img_size = args.image_size 34 | 35 | self.transform = transform 36 | self.transform_msk = transform_msk 37 | 38 | def __len__(self): 39 | return len(self.name_list) 40 | 41 | def __getitem__(self, index): 42 | # if self.mode == 'Training': 43 | # point_label = random.randint(0, 1) 44 | # inout = random.randint(0, 1) 45 | # else: 46 | # inout = 1 47 | # point_label = 1 48 | inout = 1 49 | point_label = 1 50 | 51 | """Get the images""" 52 | name = self.name_list[index] 53 | img_path = os.path.join(self.data_path, name) 54 | 55 | mask_name = self.label_list[index] 56 | msk_path = os.path.join(self.data_path, mask_name) 57 | 58 | img = Image.open(img_path).convert('RGB') 59 | mask = Image.open(msk_path).convert('L') 60 | 61 | # if self.mode == 'Training': 62 | # label = 0 if self.label_list[index] == 'benign' else 1 63 | # else: 64 | # label = int(self.label_list[index]) 65 | 66 | newsize = (self.img_size, self.img_size) 67 | mask = mask.resize(newsize) 68 | 69 | if self.prompt == 'click': 70 | pt = random_click(np.array(mask) / 255, point_label, inout) 71 | 72 | if self.transform: 73 | state = torch.get_rng_state() 74 | img = self.transform(img) 75 | torch.set_rng_state(state) 76 | 77 | 78 | if self.transform_msk: 79 | mask = self.transform_msk(mask) 80 | 81 | # if (inout == 0 and point_label == 1) or (inout == 1 and point_label == 0): 82 | # mask = 1 - mask 83 | 84 | name = name.split('/')[-1].split(".jpg")[0] 85 | image_meta_dict = {'filename_or_obj':name} 86 | return { 87 | 'image':img, 88 | 'label': mask, 89 | 'p_label':point_label, 90 | 'pt':pt, 91 | 'image_meta_dict':image_meta_dict, 92 | } 93 | 94 | class REFUGE(Dataset): 95 | def __init__(self, args, data_path , transform = None, transform_msk = None, mode = 'Training',prompt = 'click', plane = False): 96 | self.data_path = data_path 97 | self.subfolders = [f.path 
for f in os.scandir(os.path.join(data_path, mode + '-400')) if f.is_dir()] 98 | self.mode = mode 99 | self.prompt = prompt 100 | self.img_size = args.image_size 101 | self.mask_size = args.out_size 102 | 103 | self.transform = transform 104 | self.transform_msk = transform_msk 105 | 106 | def __len__(self): 107 | return len(self.subfolders) 108 | 109 | def __getitem__(self, index): 110 | inout = 1 111 | point_label = 1 112 | 113 | """Get the images""" 114 | subfolder = self.subfolders[index] 115 | name = subfolder.split('/')[-1] 116 | 117 | # raw image and raters path 118 | img_path = os.path.join(subfolder, name + '.jpg') 119 | multi_rater_cup_path = [os.path.join(subfolder, name + '_seg_cup_' + str(i) + '.png') for i in range(1, 8)] 120 | multi_rater_disc_path = [os.path.join(subfolder, name + '_seg_disc_' + str(i) + '.png') for i in range(1, 8)] 121 | 122 | # raw image and raters images 123 | img = Image.open(img_path).convert('RGB') 124 | multi_rater_cup = [Image.open(path).convert('L') for path in multi_rater_cup_path] 125 | multi_rater_disc = [Image.open(path).convert('L') for path in multi_rater_disc_path] 126 | 127 | # resize raters images for generating initial point click 128 | newsize = (self.img_size, self.img_size) 129 | multi_rater_cup_np = [np.array(single_rater.resize(newsize)) for single_rater in multi_rater_cup] 130 | multi_rater_disc_np = [np.array(single_rater.resize(newsize)) for single_rater in multi_rater_disc] 131 | 132 | # first click is the target agreement among all raters 133 | if self.prompt == 'click': 134 | pt_cup = random_click(np.array(np.mean(np.stack(multi_rater_cup_np), axis=0)) / 255, point_label, inout) 135 | pt_disc = random_click(np.array(np.mean(np.stack(multi_rater_disc_np), axis=0)) / 255, point_label, inout) 136 | 137 | if self.transform: 138 | state = torch.get_rng_state() 139 | img = self.transform(img) 140 | multi_rater_cup = [torch.as_tensor((self.transform(single_rater) >0.5).float(), dtype=torch.float32) for single_rater in multi_rater_cup] 141 | multi_rater_cup = torch.stack(multi_rater_cup, dim=0) 142 | # transform to mask size (out_size) for mask define 143 | mask_cup = F.interpolate(multi_rater_cup, size=(self.mask_size, self.mask_size), mode='bilinear', align_corners=False).mean(dim=0) 144 | 145 | multi_rater_disc = [torch.as_tensor((self.transform(single_rater) >0.5).float(), dtype=torch.float32) for single_rater in multi_rater_disc] 146 | multi_rater_disc = torch.stack(multi_rater_disc, dim=0) 147 | mask_disc = F.interpolate(multi_rater_disc, size=(self.mask_size, self.mask_size), mode='bilinear', align_corners=False).mean(dim=0) 148 | torch.set_rng_state(state) 149 | 150 | image_meta_dict = {'filename_or_obj':name} 151 | return { 152 | 'image':img, 153 | 'multi_rater_cup': multi_rater_cup, 154 | 'multi_rater_disc': multi_rater_disc, 155 | 'mask_cup': mask_cup, 156 | 'mask_disc': mask_disc, 157 | 'label': mask_disc, 158 | 'p_label':point_label, 159 | 'pt_cup':pt_cup, 160 | 'pt_disc':pt_disc, 161 | 'pt':pt_disc, 162 | 'selected_rater': torch.tensor(np.arange(7)), 163 | 'image_meta_dict':image_meta_dict, 164 | } 165 | 166 | 167 | class MyDataset(Dataset): 168 | def __init__(self, args, data_path , transform = None, transform_msk = None, mode = 'train',prompt = 'click', plane = False): 169 | 170 | df = pd.read_csv(os.path.join(data_path, mode + '.csv'), encoding='gbk') 171 | self.name_list = df.iloc[:,1].tolist() 172 | self.label_list = df.iloc[:,2].tolist() 173 | self.data_path = data_path 174 | self.mode = mode 175 | self.prompt = 
prompt 176 | self.img_size = args.image_size 177 | 178 | self.transform = transform 179 | self.transform_msk = transform_msk 180 | 181 | def __len__(self): 182 | return len(self.name_list) 183 | 184 | def __getitem__(self, index): 185 | inout = 1 186 | point_label = 1 187 | 188 | """Get the images""" 189 | name = self.name_list[index] 190 | img_path = os.path.join(self.data_path, name) 191 | 192 | mask_name = self.label_list[index] 193 | msk_path = os.path.join(self.data_path, mask_name) 194 | 195 | img = Image.open(img_path).convert('RGB') 196 | mask = Image.open(msk_path).convert('L') 197 | 198 | newsize = (self.img_size, self.img_size) 199 | mask = mask.resize(newsize) 200 | 201 | if self.prompt == 'click': 202 | pt = random_click(np.array(mask) / 255, point_label, inout) 203 | 204 | if self.transform: 205 | state = torch.get_rng_state() 206 | img = self.transform(img) 207 | torch.set_rng_state(state) 208 | 209 | 210 | if self.transform_msk: 211 | mask = self.transform_msk(mask) 212 | 213 | name = name.split('/')[-1].split(".jpg")[0] 214 | image_meta_dict = {'filename_or_obj':name} 215 | return { 216 | 'image':img, 217 | 'label': mask, 218 | 'p_label':point_label, 219 | 'pt':pt, 220 | 'image_meta_dict':image_meta_dict, 221 | } -------------------------------------------------------------------------------- /models/sam/modeling/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import Tensor, nn 9 | 10 | import math 11 | from typing import Tuple, Type 12 | 13 | from .common import MLPBlock 14 | 15 | 16 | class TwoWayTransformer(nn.Module): 17 | def __init__( 18 | self, 19 | depth: int, 20 | embedding_dim: int, 21 | num_heads: int, 22 | mlp_dim: int, 23 | activation: Type[nn.Module] = nn.ReLU, 24 | attention_downsample_rate: int = 2, 25 | ) -> None: 26 | """ 27 | A transformer decoder that attends to an input image using 28 | queries whose positional embedding is supplied. 29 | 30 | Args: 31 | depth (int): number of layers in the transformer 32 | embedding_dim (int): the channel dimension for the input embeddings 33 | num_heads (int): the number of heads for multihead attention. Must 34 | divide embedding_dim 35 | mlp_dim (int): the channel dimension internal to the MLP block 36 | activation (nn.Module): the activation to use in the MLP block 37 | """ 38 | super().__init__() 39 | self.depth = depth 40 | self.embedding_dim = embedding_dim 41 | self.num_heads = num_heads 42 | self.mlp_dim = mlp_dim 43 | self.layers = nn.ModuleList() 44 | 45 | for i in range(depth): 46 | self.layers.append( 47 | TwoWayAttentionBlock( 48 | embedding_dim=embedding_dim, 49 | num_heads=num_heads, 50 | mlp_dim=mlp_dim, 51 | activation=activation, 52 | attention_downsample_rate=attention_downsample_rate, 53 | skip_first_layer_pe=(i == 0), 54 | ) 55 | ) 56 | 57 | self.final_attn_token_to_image = Attention( 58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 59 | ) 60 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 61 | 62 | def forward( 63 | self, 64 | image_embedding: Tensor, 65 | image_pe: Tensor, 66 | point_embedding: Tensor, 67 | ) -> Tuple[Tensor, Tensor]: 68 | """ 69 | Args: 70 | image_embedding (torch.Tensor): image to attend to. 
Should be shape 71 | B x embedding_dim x h x w for any h and w. 72 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 73 | have the same shape as image_embedding. 74 | point_embedding (torch.Tensor): the embedding to add to the query points. 75 | Must have shape B x N_points x embedding_dim for any N_points. 76 | 77 | Returns: 78 | torch.Tensor: the processed point_embedding 79 | torch.Tensor: the processed image_embedding 80 | """ 81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 82 | bs, c, h, w = image_embedding.shape 83 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 84 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 85 | 86 | # Prepare queries 87 | queries = point_embedding 88 | keys = image_embedding 89 | 90 | # Apply transformer blocks and final layernorm 91 | for layer in self.layers: 92 | queries, keys = layer( 93 | queries=queries, 94 | keys=keys, 95 | query_pe=point_embedding, 96 | key_pe=image_pe, 97 | ) 98 | 99 | # Apply the final attention layer from the points to the image 100 | q = queries + point_embedding 101 | k = keys + image_pe 102 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 103 | queries = queries + attn_out 104 | queries = self.norm_final_attn(queries) 105 | 106 | return queries, keys 107 | 108 | 109 | class TwoWayAttentionBlock(nn.Module): 110 | def __init__( 111 | self, 112 | embedding_dim: int, 113 | num_heads: int, 114 | mlp_dim: int = 2048, 115 | activation: Type[nn.Module] = nn.ReLU, 116 | attention_downsample_rate: int = 2, 117 | skip_first_layer_pe: bool = False, 118 | ) -> None: 119 | """ 120 | A transformer block with four layers: (1) self-attention of sparse 121 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 122 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 123 | inputs. 
124 | 125 | Arguments: 126 | embedding_dim (int): the channel dimension of the embeddings 127 | num_heads (int): the number of heads in the attention layers 128 | mlp_dim (int): the hidden dimension of the mlp block 129 | activation (nn.Module): the activation of the mlp block 130 | skip_first_layer_pe (bool): skip the PE on the first layer 131 | """ 132 | super().__init__() 133 | self.self_attn = Attention(embedding_dim, num_heads) 134 | self.norm1 = nn.LayerNorm(embedding_dim) 135 | 136 | self.cross_attn_token_to_image = Attention( 137 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 138 | ) 139 | self.norm2 = nn.LayerNorm(embedding_dim) 140 | 141 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) 142 | self.norm3 = nn.LayerNorm(embedding_dim) 143 | 144 | self.norm4 = nn.LayerNorm(embedding_dim) 145 | self.cross_attn_image_to_token = Attention( 146 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 147 | ) 148 | 149 | self.skip_first_layer_pe = skip_first_layer_pe 150 | 151 | def forward( 152 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 153 | ) -> Tuple[Tensor, Tensor]: 154 | # Self attention block 155 | if self.skip_first_layer_pe: 156 | queries = self.self_attn(q=queries, k=queries, v=queries) 157 | else: 158 | q = queries + query_pe 159 | attn_out = self.self_attn(q=q, k=q, v=queries) 160 | queries = queries + attn_out 161 | queries = self.norm1(queries) 162 | 163 | # Cross attention block, tokens attending to image embedding 164 | q = queries + query_pe 165 | k = keys + key_pe 166 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 167 | queries = queries + attn_out 168 | queries = self.norm2(queries) 169 | 170 | # MLP block 171 | mlp_out = self.mlp(queries) 172 | queries = queries + mlp_out 173 | queries = self.norm3(queries) 174 | 175 | # Cross attention block, image embedding attending to tokens 176 | q = queries + query_pe 177 | k = keys + key_pe 178 | attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) 179 | keys = keys + attn_out 180 | keys = self.norm4(keys) 181 | 182 | return queries, keys 183 | 184 | 185 | class Attention(nn.Module): 186 | """ 187 | An attention layer that allows for downscaling the size of the embedding 188 | after projection to queries, keys, and values. 189 | """ 190 | 191 | def __init__( 192 | self, 193 | embedding_dim: int, 194 | num_heads: int, 195 | downsample_rate: int = 1, 196 | ) -> None: 197 | super().__init__() 198 | self.embedding_dim = embedding_dim 199 | self.internal_dim = embedding_dim // downsample_rate 200 | self.num_heads = num_heads 201 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 
202 | 203 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 204 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 205 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 206 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 207 | 208 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 209 | b, n, c = x.shape 210 | x = x.reshape(b, n, num_heads, c // num_heads) 211 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 212 | 213 | def _recombine_heads(self, x: Tensor) -> Tensor: 214 | b, n_heads, n_tokens, c_per_head = x.shape 215 | x = x.transpose(1, 2) 216 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 217 | 218 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 219 | # Input projections 220 | q = self.q_proj(q) 221 | k = self.k_proj(k) 222 | v = self.v_proj(v) 223 | 224 | # Separate into heads 225 | q = self._separate_heads(q, self.num_heads) 226 | k = self._separate_heads(k, self.num_heads) 227 | v = self._separate_heads(v, self.num_heads) 228 | 229 | # Attention 230 | _, _, _, c_per_head = q.shape 231 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 232 | attn = attn / math.sqrt(c_per_head) 233 | attn = torch.softmax(attn, dim=-1) 234 | 235 | # Get output 236 | out = attn @ v 237 | out = self._recombine_heads(out) 238 | out = self.out_proj(out) 239 | 240 | return out 241 | -------------------------------------------------------------------------------- /models/sam/modeling/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | from typing import Any, Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class PromptEncoder(nn.Module): 17 | def __init__( 18 | self, 19 | embed_dim: int, 20 | image_embedding_size: Tuple[int, int], 21 | input_image_size: Tuple[int, int], 22 | mask_in_chans: int, 23 | activation: Type[nn.Module] = nn.GELU, 24 | ) -> None: 25 | """ 26 | Encodes prompts for input to SAM's mask decoder. 27 | 28 | Arguments: 29 | embed_dim (int): The prompts' embedding dimension 30 | image_embedding_size (tuple(int, int)): The spatial size of the 31 | image embedding, as (H, W). 32 | input_image_size (int): The padded size of the image as input 33 | to the image encoder, as (H, W). 34 | mask_in_chans (int): The number of hidden channels used for 35 | encoding input masks. 36 | activation (nn.Module): The activation to use when encoding 37 | input masks. 
38 | """ 39 | super().__init__() 40 | self.embed_dim = embed_dim 41 | self.input_image_size = input_image_size 42 | self.image_embedding_size = image_embedding_size 43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 44 | 45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 46 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] 47 | self.point_embeddings = nn.ModuleList(point_embeddings) 48 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 49 | 50 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) 51 | self.mask_downscaling = nn.Sequential( 52 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 53 | LayerNorm2d(mask_in_chans // 4), 54 | activation(), 55 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 56 | LayerNorm2d(mask_in_chans), 57 | activation(), 58 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 59 | ) 60 | self.no_mask_embed = nn.Embedding(1, embed_dim) 61 | 62 | def get_dense_pe(self) -> torch.Tensor: 63 | """ 64 | Returns the positional encoding used to encode point prompts, 65 | applied to a dense set of points the shape of the image encoding. 66 | 67 | Returns: 68 | torch.Tensor: Positional encoding with shape 69 | 1x(embed_dim)x(embedding_h)x(embedding_w) 70 | """ 71 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 72 | 73 | def _embed_points( 74 | self, 75 | points: torch.Tensor, 76 | labels: torch.Tensor, 77 | pad: bool, 78 | ) -> torch.Tensor: 79 | """Embeds point prompts.""" 80 | points = points + 0.5 # Shift to center of pixel 81 | if pad: 82 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 83 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 84 | points = torch.cat([points, padding_point], dim=1) 85 | labels = torch.cat([labels, padding_label], dim=1) 86 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) 87 | point_embedding[labels == -1] = 0.0 88 | point_embedding[labels == -1] += self.not_a_point_embed.weight 89 | point_embedding[labels == 0] += self.point_embeddings[0].weight 90 | point_embedding[labels == 1] += self.point_embeddings[1].weight 91 | return point_embedding 92 | 93 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 94 | """Embeds box prompts.""" 95 | boxes = boxes + 0.5 # Shift to center of pixel 96 | coords = boxes.reshape(-1, 2, 2) 97 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) 98 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 99 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 100 | return corner_embedding 101 | 102 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 103 | """Embeds mask inputs.""" 104 | mask_embedding = self.mask_downscaling(masks) 105 | return mask_embedding 106 | 107 | def _get_batch_size( 108 | self, 109 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 110 | boxes: Optional[torch.Tensor], 111 | masks: Optional[torch.Tensor], 112 | ) -> int: 113 | """ 114 | Gets the batch size of the output given the batch size of the input prompts. 
115 | """ 116 | if points is not None: 117 | return points[0].shape[0] 118 | elif boxes is not None: 119 | return boxes.shape[0] 120 | elif masks is not None: 121 | return masks.shape[0] 122 | else: 123 | return 1 124 | 125 | def _get_device(self) -> torch.device: 126 | return self.point_embeddings[0].weight.device 127 | 128 | def forward( 129 | self, 130 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 131 | boxes: Optional[torch.Tensor], 132 | masks: Optional[torch.Tensor], 133 | ) -> Tuple[torch.Tensor, torch.Tensor]: 134 | """ 135 | Embeds different types of prompts, returning both sparse and dense 136 | embeddings. 137 | 138 | Arguments: 139 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 140 | and labels to embed. 141 | boxes (torch.Tensor or none): boxes to embed 142 | masks (torch.Tensor or none): masks to embed 143 | 144 | Returns: 145 | torch.Tensor: sparse embeddings for the points and boxes, with shape 146 | BxNx(embed_dim), where N is determined by the number of input points 147 | and boxes. 148 | torch.Tensor: dense embeddings for the masks, in the shape 149 | Bx(embed_dim)x(embed_H)x(embed_W) 150 | """ 151 | bs = self._get_batch_size(points, boxes, masks) 152 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) 153 | if points is not None: 154 | coords, labels = points 155 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 156 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 157 | if boxes is not None: 158 | box_embeddings = self._embed_boxes(boxes) 159 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 160 | 161 | if masks is not None: 162 | dense_embeddings = self._embed_masks(masks) 163 | else: 164 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 165 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 166 | ) 167 | 168 | return sparse_embeddings, dense_embeddings 169 | 170 | 171 | class PositionEmbeddingRandom(nn.Module): 172 | """ 173 | Positional encoding using random spatial frequencies. 174 | """ 175 | 176 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 177 | super().__init__() 178 | if scale is None or scale <= 0.0: 179 | scale = 1.0 180 | self.register_buffer( 181 | "positional_encoding_gaussian_matrix", 182 | scale * torch.randn((2, num_pos_feats)), 183 | ) 184 | 185 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 186 | """Positionally encode points that are normalized to [0,1].""" 187 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 188 | coords = 2 * coords - 1 189 | coords = coords @ self.positional_encoding_gaussian_matrix 190 | coords = 2 * np.pi * coords 191 | # outputs d_1 x ... 
x d_n x C shape 192 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 193 | 194 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 195 | """Generate positional encoding for a grid of the specified size.""" 196 | h, w = size 197 | device: Any = self.positional_encoding_gaussian_matrix.device 198 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 199 | y_embed = grid.cumsum(dim=0) - 0.5 200 | x_embed = grid.cumsum(dim=1) - 0.5 201 | y_embed = y_embed / h 202 | x_embed = x_embed / w 203 | 204 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 205 | return pe.permute(2, 0, 1) # C x H x W 206 | 207 | def forward_with_coords( 208 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 209 | ) -> torch.Tensor: 210 | """Positionally encode points that are not normalized to [0,1].""" 211 | coords = coords_input.clone() 212 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 213 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 214 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 215 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: sam_adapt 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=5.1=1_gnu 9 | - abseil-cpp=20211102.0=hd4dd3e8_0 10 | - absl-py=1.3.0=py310h06a4308_0 11 | - aiohttp=3.8.3=py310h5eee18b_0 12 | - async-timeout=4.0.2=py310h06a4308_0 13 | - attrs=22.1.0=py310h06a4308_0 14 | - blas=1.0=mkl 15 | - blosc=1.21.3=h6a678d5_0 16 | - bottleneck=1.3.5=py310ha9d4c09_0 17 | - brotli=1.0.9=h5eee18b_7 18 | - brotli-bin=1.0.9=h5eee18b_7 19 | - brotlipy=0.7.0=py310h7f8727e_1002 20 | - brunsli=0.1=h2531618_0 21 | - bzip2=1.0.8=h7b6447c_0 22 | - c-ares=1.19.0=h5eee18b_0 23 | - ca-certificates=2022.12.7=ha878542_0 24 | - cffi=1.15.1=py310h5eee18b_3 25 | - cfitsio=3.470=h5893167_7 26 | - charls=2.2.0=h2531618_0 27 | - cloudpickle=2.2.1=py310h06a4308_0 28 | - contourpy=1.0.5=py310hdb19cb5_0 29 | - cpuonly=2.0=0 30 | - cryptography=39.0.1=py310h9ce1e76_0 31 | - cudatoolkit=11.3.1=h2bc3f7f_2 32 | - cytoolz=0.12.0=py310h5eee18b_0 33 | - dask-core=2023.4.1=py310h06a4308_0 34 | - dbus=1.13.18=hb2f20db_0 35 | - expat=2.4.9=h6a678d5_0 36 | - ffmpeg=4.3=hf484d3e_0 37 | - fontconfig=2.14.1=h4c34cd2_2 38 | - freetype=2.12.1=h4a9f257_0 39 | - frozenlist=1.3.3=py310h5eee18b_0 40 | - fsspec=2023.4.0=py310h06a4308_0 41 | - giflib=5.2.1=h5eee18b_3 42 | - glib=2.69.1=he621ea3_2 43 | - gmp=6.2.1=h295c915_3 44 | - gnutls=3.6.15=he1e5248_0 45 | - grpc-cpp=1.48.2=h5bf31a4_0 46 | - grpcio=1.48.2=py310h5bf31a4_0 47 | - gst-plugins-base=1.14.1=h6a678d5_1 48 | - gstreamer=1.14.1=h5eee18b_1 49 | - icu=58.2=he6710b0_3 50 | - idna=3.4=py310h06a4308_0 51 | - imagecodecs=2021.8.26=py310h46e8fbd_2 52 | - imageio=2.26.0=py310h06a4308_0 53 | - importlib-metadata=6.0.0=py310h06a4308_0 54 | - intel-openmp=2021.4.0=h06a4308_3561 55 | - joblib=1.1.1=py310h06a4308_0 56 | - jpeg=9e=h5eee18b_1 57 | - jxrlib=1.1=h7b6447c_2 58 | - kiwisolver=1.4.4=py310h6a678d5_0 59 | - krb5=1.19.4=h568e23c_0 60 | - lame=3.100=h7b6447c_0 61 | - lazy_loader=0.1=py310h06a4308_0 62 | - lcms2=2.12=h3be6417_0 63 | - ld_impl_linux-64=2.38=h1181459_1 64 | - lerc=3.0=h295c915_0 65 | - libaec=1.0.4=he6710b0_1 66 | - libbrotlicommon=1.0.9=h5eee18b_7 67 | - libbrotlidec=1.0.9=h5eee18b_7 68 | - libbrotlienc=1.0.9=h5eee18b_7 69 | - libclang=14.0.6=default_hc6dbbc7_1 70 | - 
libclang13=14.0.6=default_he11475f_1 71 | - libcurl=7.88.1=h91b91d3_0 72 | - libdeflate=1.17=h5eee18b_0 73 | - libedit=3.1.20221030=h5eee18b_0 74 | - libev=4.33=h7f8727e_1 75 | - libevent=2.1.12=h8f2d780_0 76 | - libffi=3.4.2=h6a678d5_6 77 | - libgcc-ng=11.2.0=h1234567_1 78 | - libgfortran-ng=11.2.0=h00389a5_1 79 | - libgfortran5=11.2.0=h1234567_1 80 | - libgomp=11.2.0=h1234567_1 81 | - libiconv=1.16=h7f8727e_2 82 | - libidn2=2.3.2=h7f8727e_0 83 | - libllvm14=14.0.6=hdb19cb5_2 84 | - libnghttp2=1.46.0=hce63b2e_0 85 | - libpng=1.6.39=h5eee18b_0 86 | - libpq=12.9=h16c4e8d_3 87 | - libprotobuf=3.20.3=he621ea3_0 88 | - libssh2=1.10.0=h8f2d780_0 89 | - libstdcxx-ng=11.2.0=h1234567_1 90 | - libtasn1=4.19.0=h5eee18b_0 91 | - libtiff=4.5.0=h6a678d5_2 92 | - libunistring=0.9.10=h27cfd23_0 93 | - libuuid=1.41.5=h5eee18b_0 94 | - libwebp=1.2.4=h11a3e52_1 95 | - libwebp-base=1.2.4=h5eee18b_1 96 | - libxcb=1.15=h7f8727e_0 97 | - libxkbcommon=1.0.1=h5eee18b_1 98 | - libxml2=2.10.3=hcbfbd50_0 99 | - libxslt=1.1.37=h2085143_0 100 | - libzopfli=1.0.3=he6710b0_0 101 | - locket=1.0.0=py310h06a4308_0 102 | - lz4-c=1.9.4=h6a678d5_0 103 | - markdown=3.4.1=py310h06a4308_0 104 | - markupsafe=2.1.1=py310h7f8727e_0 105 | - matplotlib=3.7.1=py310h06a4308_1 106 | - matplotlib-base=3.7.1=py310h1128e8f_1 107 | - mkl=2021.4.0=h06a4308_640 108 | - mkl-service=2.4.0=py310h7f8727e_0 109 | - mkl_fft=1.3.1=py310hd6ae3a3_0 110 | - mkl_random=1.2.2=py310h00e6091_0 111 | - monai=1.1.0=pyhd8ed1ab_0 112 | - multidict=6.0.2=py310h5eee18b_0 113 | - ncurses=6.4=h6a678d5_0 114 | - nettle=3.7.3=hbbd107a_1 115 | - networkx=2.8.4=py310h06a4308_1 116 | - nspr=4.33=h295c915_0 117 | - nss=3.74=h0370c37_0 118 | - numexpr=2.8.4=py310h8879344_0 119 | - numpy=1.24.3=py310hd5efca6_0 120 | - numpy-base=1.24.3=py310h8e6c178_0 121 | - oauthlib=3.2.2=py310h06a4308_0 122 | - openh264=2.1.1=h4ff587b_0 123 | - openjpeg=2.4.0=h3ad879b_0 124 | - openssl=1.1.1t=h7f8727e_0 125 | - packaging=23.0=py310h06a4308_0 126 | - pandas=1.5.3=py310h1128e8f_0 127 | - pcre=8.45=h295c915_0 128 | - pillow=9.4.0=py310h6a678d5_0 129 | - pip=23.0.1=py310h06a4308_0 130 | - ply=3.11=py310h06a4308_0 131 | - protobuf=3.20.3=py310h6a678d5_0 132 | - pyjwt=2.4.0=py310h06a4308_0 133 | - pyopenssl=23.0.0=py310h06a4308_0 134 | - pyparsing=3.0.9=py310h06a4308_0 135 | - pyqt=5.15.7=py310h6a678d5_1 136 | - pysocks=1.7.1=py310h06a4308_0 137 | - python=3.10.11=h7a1cb2a_2 138 | - pytorch-mutex=1.0=cpu 139 | - pytz=2022.7=py310h06a4308_0 140 | - pywavelets=1.4.1=py310h5eee18b_0 141 | - pyyaml=6.0=py310h5eee18b_1 142 | - qt-main=5.15.2=h8373d8f_8 143 | - qt-webengine=5.15.9=hbbf29b9_6 144 | - qtwebkit=5.212=h3fafdc1_5 145 | - re2=2022.04.01=h295c915_0 146 | - readline=8.2=h5eee18b_0 147 | - requests=2.29.0=py310h06a4308_0 148 | - scikit-image=0.20.0=py310h6a678d5_0 149 | - scikit-learn=1.2.2=py310h6a678d5_0 150 | - scipy=1.10.1=py310hd5efca6_0 151 | - seaborn=0.12.2=py310h06a4308_0 152 | - setuptools=66.0.0=py310h06a4308_0 153 | - sip=6.6.2=py310h6a678d5_0 154 | - snappy=1.1.9=h295c915_0 155 | - sqlite=3.41.2=h5eee18b_0 156 | - tensorboard=2.11.0=py310h06a4308_0 157 | - tensorboard-data-server=0.6.1=py310h52d8a92_0 158 | - tensorboard-plugin-wit=1.8.1=py310h06a4308_0 159 | - tk=8.6.12=h1ccaba5_0 160 | - toolz=0.12.0=py310h06a4308_0 161 | - torchaudio=0.12.1=py310_cpu 162 | - tornado=6.2=py310h5eee18b_0 163 | - tqdm=4.65.0=py310h2f386ee_0 164 | - typing_extensions=4.5.0=py310h06a4308_0 165 | - tzdata=2023c=h04d1e81_0 166 | - urllib3=1.26.15=py310h06a4308_0 167 | - 
wheel=0.38.4=py310h06a4308_0 168 | - xz=5.4.2=h5eee18b_0 169 | - yaml=0.2.5=h7b6447c_0 170 | - yarl=1.8.1=py310h5eee18b_0 171 | - zfp=0.5.5=h295c915_6 172 | - zipp=3.11.0=py310h06a4308_0 173 | - zlib=1.2.13=h5eee18b_0 174 | - zstd=1.5.5=hc292b87_0 175 | - pip: 176 | - --extra-index-url https://download.pytorch.org/whl/cu113 177 | - aiosignal==1.2.0 178 | - alembic==1.10.4 179 | - appdirs==1.4.4 180 | - astor==0.8.1 181 | - asttokens==2.2.1 182 | - backcall==0.2.0 183 | - beautifulsoup4==4.12.2 184 | - blinker==1.6.2 185 | - cachetools==4.2.2 186 | - certifi==2022.12.7 187 | - charset-normalizer==2.0.4 188 | - click==8.1.3 189 | - cmaes==0.9.1 190 | - colorama==0.4.6 191 | - colorlog==6.7.0 192 | - contextlib2==21.6.0 193 | - coverage==6.5.0 194 | - coveralls==3.3.1 195 | - cucim==23.4.1 196 | - cycler==0.11.0 197 | - databricks-cli==0.17.7 198 | - decorator==5.1.1 199 | - docker==6.1.1 200 | - docopt==0.6.2 201 | - einops==0.6.1 202 | - entrypoints==0.4 203 | - exceptiongroup==1.1.1 204 | - executing==1.2.0 205 | - filelock==3.12.0 206 | - fire==0.5.0 207 | - flask==2.3.2 208 | - fonttools==4.25.0 209 | - future==0.18.3 210 | - gdown==4.7.1 211 | - gitdb==4.0.10 212 | - gitpython==3.1.31 213 | - google-auth==2.6.0 214 | - google-auth-oauthlib==0.4.4 215 | - greenlet==2.0.2 216 | - gunicorn==20.1.0 217 | - h5py==3.8.0 218 | - huggingface-hub==0.14.1 219 | - iniconfig==2.0.0 220 | - ipython==8.13.1 221 | - itk==5.3.0 222 | - itk-core==5.3.0 223 | - itk-filtering==5.3.0 224 | - itk-io==5.3.0 225 | - itk-numerics==5.3.0 226 | - itk-registration==5.3.0 227 | - itk-segmentation==5.3.0 228 | - itsdangerous==2.1.2 229 | - jedi==0.18.2 230 | - jinja2==3.1.2 231 | - json-tricks==3.16.1 232 | - jsonschema==4.17.3 233 | - kornia==0.4.1 234 | - lmdb==1.4.1 235 | - lucent==0.1.0 236 | - mako==1.2.4 237 | - matplotlib-inline==0.1.6 238 | - mlflow==2.3.1 239 | - munkres==1.1.4 240 | - nibabel==5.1.0 241 | - ninja==1.11.1 242 | - nni==2.10 243 | - nptyping==2.5.0 244 | - opencv-python==4.7.0.72 245 | - openslide-python==1.1.2 246 | - optuna==3.1.1 247 | - parso==0.8.3 248 | - partd==1.2.0 249 | - pexpect==4.8.0 250 | - pickleshare==0.7.5 251 | - pluggy==1.0.0 252 | - pooch==1.4.0 253 | - prettytable==3.7.0 254 | - prompt-toolkit==3.0.38 255 | - psutil==5.9.5 256 | - ptyprocess==0.7.0 257 | - pure-eval==0.2.2 258 | - pyarrow==11.0.0 259 | - pyasn1==0.4.8 260 | - pyasn1-modules==0.2.8 261 | - pycparser==2.21 262 | - pydicom==2.3.1 263 | - pygments==2.15.1 264 | - pynrrd==1.0.0 265 | - pyqt5-sip==12.11.0 266 | - pyrsistent==0.19.3 267 | - pytest==7.3.1 268 | - pytest-mock==3.10.0 269 | - python-dateutil==2.8.2 270 | - pythonwebhdfs==0.2.3 271 | - pytorch-ignite==0.4.10 272 | - querystring-parser==1.2.4 273 | - regex==2023.5.5 274 | - requests-oauthlib==1.3.0 275 | - responses==0.23.1 276 | - rsa==4.7.2 277 | - schema==0.7.5 278 | - simplejson==3.19.1 279 | - six==1.16.0 280 | - smmap==5.0.0 281 | - soupsieve==2.4.1 282 | - sqlalchemy==2.0.12 283 | - sqlparse==0.4.4 284 | - stack-data==0.6.2 285 | - tabulate==0.9.0 286 | - tensorboardx==2.2 287 | - termcolor==2.3.0 288 | - threadpoolctl==2.2.0 289 | - tifffile==2021.7.2 290 | - tokenizers==0.12.1 291 | - toml==0.10.2 292 | - tomli==2.0.1 293 | - torch==1.12.1+cu113 294 | - torch-lucent==0.1.8 295 | - torchvision==0.13.1+cu113 296 | - traitlets==5.9.0 297 | - transformers==4.21.3 298 | - typeguard==3.0.2 299 | - types-pyyaml==6.0.12.9 300 | - wcwidth==0.2.6 301 | - websocket-client==1.5.1 302 | - websockets==11.0.3 303 | - werkzeug==2.3.4 304 | 
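The CSV files written by ``generate_csv`` in ``pre_processing.ipynb`` below (columns ``#``, ``img``, ``seg``) are what ``MyDataset`` in ``./dataset.py`` reads back, and the README's "Run on your own dataset" section lists the dict keys every loader must return. The following is a minimal sketch of such a custom dataset class, not a file shipped in this repository; it mirrors ``MyDataset`` and assumes the ``utils.random_click`` helper and the ``args.image_size`` option used by the existing loaders.

```python
# Minimal sketch of a custom MSA dataset (assumes utils.random_click, args.image_size,
# and the train/test CSV layout produced by pre_processing.ipynb).
import os

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset

from utils import random_click  # same helper used by the loaders in dataset.py


class CustomDataset(Dataset):
    def __init__(self, args, data_path, transform=None, transform_msk=None,
                 mode='train', prompt='click'):
        # Each CSV row: index, "<mode>/image_splited/x.jpg", "<mode>/mask_splited/x.jpg"
        df = pd.read_csv(os.path.join(data_path, mode + '.csv'))
        self.name_list = df.iloc[:, 1].tolist()   # relative image paths
        self.label_list = df.iloc[:, 2].tolist()  # relative mask paths
        self.data_path = data_path
        self.prompt = prompt
        self.img_size = args.image_size
        self.transform = transform
        self.transform_msk = transform_msk

    def __len__(self):
        return len(self.name_list)

    def __getitem__(self, index):
        point_label = 1  # positive click only; set 0 to mark a negative prompt
        img = Image.open(os.path.join(self.data_path, self.name_list[index])).convert('RGB')
        mask = Image.open(os.path.join(self.data_path, self.label_list[index])).convert('L')
        mask = mask.resize((self.img_size, self.img_size))

        if self.prompt == 'click':
            # sample one click prompt [x, y] from the normalized mask
            pt = random_click(np.array(mask) / 255, point_label, 1)

        if self.transform:
            state = torch.get_rng_state()
            img = self.transform(img)
            torch.set_rng_state(state)
        if self.transform_msk:
            mask = self.transform_msk(mask)

        name = os.path.splitext(os.path.basename(self.name_list[index]))[0]
        return {
            'image': img,                                  # [C, H, W] tensor
            'label': mask,                                 # target mask
            'p_label': point_label,                        # positive/negative prompt label
            'pt': pt,                                      # click prompt [x, y]
            'image_meta_dict': {'filename_or_obj': name},  # used when saving/visualizing
        }
```

Hooking the class into the training script the same way ``isic`` and ``REFUGE`` are selected via the ``-dataset`` flag should then be all that is needed to train on a new 2D dataset.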
-------------------------------------------------------------------------------- /pre_processing.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import csv\n", 11 | "import random\n", 12 | "from PIL import Image\n", 13 | "from tqdm import tqdm" 14 | ] 15 | }, 16 | { 17 | "cell_type": "markdown", 18 | "metadata": {}, 19 | "source": [ 20 | "# Split Train and Test" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 2, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "def split(origin_image_path, origin_mask_path):\n", 30 | " images = os.listdir(origin_image_path)\n", 31 | " masks = os.listdir(origin_mask_path)\n", 32 | " assert len(images) == len(masks)\n", 33 | " num_data = len(images)\n", 34 | " random.shuffle(images)\n", 35 | " os.system(\"rm -rf ./data/train/*\")\n", 36 | " os.system(\"rm -rf ./data/test/*\")\n", 37 | " os.makedirs(\"./data/train/image\", exist_ok=True)\n", 38 | " os.makedirs(\"./data/train/mask\", exist_ok=True)\n", 39 | " os.makedirs(\"./data/test/image\", exist_ok=True)\n", 40 | " os.makedirs(\"./data/test/mask\", exist_ok=True)\n", 41 | " for idx, image in enumerate(tqdm(images)):\n", 42 | " image_name = image.split(\"/\")[-1]\n", 43 | " os.system(\"cp {} {}\".format(os.path.join(origin_image_path, image_name), os.path.join(f\"./data/{'test' if idx <= num_data // 10 else 'train'}/image\", image_name)))\n", 44 | " os.system(\"cp {} {}\".format(os.path.join(origin_mask_path, image_name), os.path.join(f\"./data/{'test' if idx <= num_data // 10 else 'train'}/mask\", image_name)))" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 3, 50 | "metadata": {}, 51 | "outputs": [ 52 | { 53 | "name": "stderr", 54 | "output_type": "stream", 55 | "text": [ 56 | "100%|██████████| 2900/2900 [00:07<00:00, 413.31it/s]\n" 57 | ] 58 | } 59 | ], 60 | "source": [ 61 | "split(\"./data/original/image/\", \"./data/original/mask/\")" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "metadata": {}, 67 | "source": [ 68 | "# Split Image" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 4, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "def split_images(input_dir, output_dir):\n", 78 | " if not os.path.exists(output_dir):\n", 79 | " os.makedirs(output_dir)\n", 80 | " image_files = [f for f in os.listdir(input_dir) if f.lower().endswith((\".png\", \".jpg\", \".jpeg\", \".gif\", \".bmp\"))]\n", 81 | " for image_file in tqdm(image_files):\n", 82 | " input_path = os.path.join(input_dir, image_file)\n", 83 | " original_image = Image.open(input_path)\n", 84 | " width, height = original_image.size\n", 85 | " split_point = width // 2\n", 86 | " left_image = original_image.crop((0, 0, split_point, height))\n", 87 | " right_image = original_image.crop((split_point, 0, width, height))\n", 88 | " left_output_path = os.path.join(output_dir, f\"left_{image_file}\")\n", 89 | " right_output_path = os.path.join(output_dir, f\"right_{image_file}\")\n", 90 | " right_image = right_image.transpose(Image.FLIP_LEFT_RIGHT)\n", 91 | " left_image.save(left_output_path)\n", 92 | " right_image.save(right_output_path)" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 5, 98 | "metadata": {}, 99 | "outputs": [ 100 | { 101 | "name": "stderr", 102 | "output_type": "stream", 103 | "text": [ 104 | "100%|██████████| 
2609/2609 [01:29<00:00, 29.12it/s]\n", 105 | "100%|██████████| 2609/2609 [00:06<00:00, 407.35it/s]\n", 106 | "100%|██████████| 291/291 [00:10<00:00, 28.95it/s]\n", 107 | "100%|██████████| 291/291 [00:00<00:00, 374.17it/s]\n" 108 | ] 109 | } 110 | ], 111 | "source": [ 112 | "split_images(\"./data/train/image\", \"./data/train/image_splited\")\n", 113 | "split_images(\"./data/train/mask\", \"./data/train/mask_splited\")\n", 114 | "split_images(\"./data/test/image\", \"./data/test/image_splited\")\n", 115 | "split_images(\"./data/test/mask\", \"./data/test/mask_splited\")\n", 116 | "# split_images(\"./data/predict/image\", \"./data/predict/image_splited\")" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": null, 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "def split_images_overlap(input_dir, output_dir):\n", 126 | " if not os.path.exists(output_dir):\n", 127 | " os.makedirs(output_dir)\n", 128 | " image_files = [f for f in os.listdir(input_dir) if f.lower().endswith((\".png\", \".jpg\", \".jpeg\", \".gif\", \".bmp\"))]\n", 129 | " for image_file in tqdm(image_files):\n", 130 | " input_path = os.path.join(input_dir, image_file)\n", 131 | " original_image = Image.open(input_path)\n", 132 | " width, height = original_image.size\n", 133 | " split_point = width // 2\n", 134 | " left_image = original_image.crop((30, 0, split_point + 30, height))\n", 135 | " right_image = original_image.crop((split_point - 30, 0, width - 30, height))\n", 136 | " left_output_path = os.path.join(output_dir, f\"left_{image_file}\")\n", 137 | " right_output_path = os.path.join(output_dir, f\"right_{image_file}\")\n", 138 | " right_image = right_image.transpose(Image.FLIP_LEFT_RIGHT)\n", 139 | " left_image.save(left_output_path)\n", 140 | " right_image.save(right_output_path)" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": null, 146 | "metadata": {}, 147 | "outputs": [], 148 | "source": [ 149 | "# split_images_overlap(\"./data/train/image\", \"./data/train/image_splited\")\n", 150 | "# split_images_overlap(\"./data/train/mask\", \"./data/train/mask_splited\")\n", 151 | "# split_images_overlap(\"./data/test/image\", \"./data/test/image_splited\")\n", 152 | "# split_images_overlap(\"./data/test/mask\", \"./data/test/mask_splited\")\n", 153 | "# split_images_overlap(\"./data/predict/image\", \"./data/predict/image_splited\")" 154 | ] 155 | }, 156 | { 157 | "cell_type": "markdown", 158 | "metadata": {}, 159 | "source": [ 160 | "# Filter Illegal Data" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 6, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "def is_binary_image_all_zeros(image_path):\n", 170 | " with Image.open(image_path) as img:\n", 171 | " img = img.convert(\"L\")\n", 172 | " pixels = list(img.getdata())\n", 173 | " return all(pixel == 0 for pixel in pixels)\n", 174 | " \n", 175 | "\n", 176 | "def get_illegal_data(mask_dir):\n", 177 | " illegal_data = []\n", 178 | " for filename in tqdm(os.listdir(mask_dir)):\n", 179 | " if filename.endswith((\".png\", \".jpg\", \".jpeg\", \".gif\", \".bmp\")):\n", 180 | " image_path = os.path.join(mask_dir, filename)\n", 181 | " if is_binary_image_all_zeros(image_path):\n", 182 | " illegal_data.append(filename)\n", 183 | " return illegal_data\n", 184 | "\n", 185 | "\n", 186 | "def perform_delete(iliiegal_data, image_dir, mask_dir):\n", 187 | " for filename in tqdm(iliiegal_data):\n", 188 | " os.remove(os.path.join(image_dir, filename))\n", 189 | " 
os.remove(os.path.join(mask_dir, filename))\n", 190 | "\n", 191 | "\n", 192 | "def filter_illegal_data(image_dir, mask_dir):\n", 193 | " illegal_data = get_illegal_data(mask_dir)\n", 194 | " print(f\"Found {len(illegal_data)} illegal data:\")\n", 195 | " print(illegal_data)\n", 196 | " input(\"Confirm?\")\n", 197 | " perform_delete(illegal_data, image_dir, mask_dir)" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 7, 203 | "metadata": {}, 204 | "outputs": [ 205 | { 206 | "name": "stderr", 207 | "output_type": "stream", 208 | "text": [ 209 | "100%|██████████| 5218/5218 [00:07<00:00, 677.85it/s]\n" 210 | ] 211 | }, 212 | { 213 | "name": "stdout", 214 | "output_type": "stream", 215 | "text": [ 216 | "Found 8 illegal data:\n", 217 | "['left_train_0.png', 'left_train_1.png', 'right_A-17.png', 'left_train_898.png', 'right_train_419.png', 'right_train_1.png', 'right_train_0.png', 'left_A-5.png']\n" 218 | ] 219 | }, 220 | { 221 | "name": "stderr", 222 | "output_type": "stream", 223 | "text": [ 224 | "100%|██████████| 8/8 [00:00<00:00, 13706.88it/s]\n", 225 | "100%|██████████| 582/582 [00:00<00:00, 680.74it/s]\n" 226 | ] 227 | }, 228 | { 229 | "name": "stdout", 230 | "output_type": "stream", 231 | "text": [ 232 | "Found 0 illegal data:\n", 233 | "[]\n" 234 | ] 235 | }, 236 | { 237 | "name": "stderr", 238 | "output_type": "stream", 239 | "text": [ 240 | "0it [00:00, ?it/s]\n" 241 | ] 242 | } 243 | ], 244 | "source": [ 245 | "filter_illegal_data(\"./data/train/image_splited\", \"./data/train/mask_splited\")\n", 246 | "filter_illegal_data(\"./data/test/image_splited\", \"./data/test/mask_splited\")" 247 | ] 248 | }, 249 | { 250 | "cell_type": "markdown", 251 | "metadata": {}, 252 | "source": [ 253 | "# Generate CSV" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": 10, 259 | "metadata": {}, 260 | "outputs": [], 261 | "source": [ 262 | "def generate_csv(type, image_dir, mask_dir = None):\n", 263 | " images = os.listdir(image_dir)\n", 264 | "\n", 265 | " if mask_dir is None:\n", 266 | " masks = [None] * len(images)\n", 267 | " else:\n", 268 | " masks = os.listdir(mask_dir)\n", 269 | "\n", 270 | " image_folder = os.path.basename(image_dir)\n", 271 | " mask_folder = os.path.basename(mask_dir) if mask_dir is not None else None\n", 272 | "\n", 273 | " assert len(images) == len(masks)\n", 274 | "\n", 275 | " with open(f\"./data/{type}.csv\", \"w\", newline=\"\") as csvfile:\n", 276 | " writer = csv.writer(csvfile)\n", 277 | " writer.writerow([\"#\", \"img\", \"seg\"])\n", 278 | " for i in range(len(images)):\n", 279 | " writer.writerow([i, f\"{type}/{image_folder}/{images[i]}\", f\"{type}/{mask_folder}/{masks[i]}\" if mask_dir is not None else \"./placeholder.png\"])" 280 | ] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "execution_count": 11, 285 | "metadata": {}, 286 | "outputs": [], 287 | "source": [ 288 | "generate_csv(\"train\", \"./data/train/image_splited\", \"./data/train/mask_splited\")\n", 289 | "generate_csv(\"test\", \"./data/test/image_splited\", \"./data/test/mask_splited\")\n", 290 | "# generate_csv(\"predict\", \"./data/predict/image\")" 291 | ] 292 | } 293 | ], 294 | "metadata": { 295 | "kernelspec": { 296 | "display_name": "sam_adapt", 297 | "language": "python", 298 | "name": "python3" 299 | }, 300 | "language_info": { 301 | "codemirror_mode": { 302 | "name": "ipython", 303 | "version": 3 304 | }, 305 | "file_extension": ".py", 306 | "mimetype": "text/x-python", 307 | "name": "python", 308 | "nbconvert_exporter": "python", 309 | 
"pygments_lexer": "ipython3", 310 | "version": "3.10.11" 311 | } 312 | }, 313 | "nbformat": 4, 314 | "nbformat_minor": 2 315 | } 316 | -------------------------------------------------------------------------------- /models/sam/predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from .modeling import Sam 11 | 12 | from typing import Optional, Tuple 13 | 14 | from .utils.transforms import ResizeLongestSide 15 | 16 | 17 | class SamPredictor: 18 | def __init__( 19 | self, 20 | sam_model: Sam, 21 | ) -> None: 22 | """ 23 | Uses SAM to calculate the image embedding for an image, and then 24 | allow repeated, efficient mask prediction given prompts. 25 | 26 | Arguments: 27 | sam_model (Sam): The model to use for mask prediction. 28 | """ 29 | super().__init__() 30 | self.model = sam_model 31 | self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) 32 | self.reset_image() 33 | 34 | def set_image( 35 | self, 36 | image: np.ndarray, 37 | image_format: str = "RGB", 38 | ) -> None: 39 | """ 40 | Calculates the image embeddings for the provided image, allowing 41 | masks to be predicted with the 'predict' method. 42 | 43 | Arguments: 44 | image (np.ndarray): The image for calculating masks. Expects an 45 | image in HWC uint8 format, with pixel values in [0, 255]. 46 | image_format (str): The color format of the image, in ['RGB', 'BGR']. 47 | """ 48 | assert image_format in [ 49 | "RGB", 50 | "BGR", 51 | ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." 52 | if image_format != self.model.image_format: 53 | image = image[..., ::-1] 54 | 55 | # Transform the image to the form expected by the model 56 | input_image = self.transform.apply_image(image) 57 | input_image_torch = torch.as_tensor(input_image, device=self.device) 58 | input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] 59 | 60 | self.set_torch_image(input_image_torch, image.shape[:2]) 61 | 62 | @torch.no_grad() 63 | def set_torch_image( 64 | self, 65 | transformed_image: torch.Tensor, 66 | original_image_size: Tuple[int, ...], 67 | ) -> None: 68 | """ 69 | Calculates the image embeddings for the provided image, allowing 70 | masks to be predicted with the 'predict' method. Expects the input 71 | image to be already transformed to the format expected by the model. 72 | 73 | Arguments: 74 | transformed_image (torch.Tensor): The input image, with shape 75 | 1x3xHxW, which has been transformed with ResizeLongestSide. 76 | original_image_size (tuple(int, int)): The size of the image 77 | before transformation, in (H, W) format. 78 | """ 79 | assert ( 80 | len(transformed_image.shape) == 4 81 | and transformed_image.shape[1] == 3 82 | and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size 83 | ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." 
84 | self.reset_image() 85 | 86 | self.original_size = original_image_size 87 | self.input_size = tuple(transformed_image.shape[-2:]) 88 | input_image = self.model.preprocess(transformed_image) 89 | self.features = self.model.image_encoder(input_image) 90 | self.is_image_set = True 91 | 92 | def predict( 93 | self, 94 | point_coords: Optional[np.ndarray] = None, 95 | point_labels: Optional[np.ndarray] = None, 96 | box: Optional[np.ndarray] = None, 97 | mask_input: Optional[np.ndarray] = None, 98 | multimask_output: bool = True, 99 | return_logits: bool = False, 100 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 101 | """ 102 | Predict masks for the given input prompts, using the currently set image. 103 | 104 | Arguments: 105 | point_coords (np.ndarray or None): A Nx2 array of point prompts to the 106 | model. Each point is in (X,Y) in pixels. 107 | point_labels (np.ndarray or None): A length N array of labels for the 108 | point prompts. 1 indicates a foreground point and 0 indicates a 109 | background point. 110 | box (np.ndarray or None): A length 4 array given a box prompt to the 111 | model, in XYXY format. 112 | mask_input (np.ndarray): A low resolution mask input to the model, typically 113 | coming from a previous prediction iteration. Has form 1xHxW, where 114 | for SAM, H=W=256. 115 | multimask_output (bool): If true, the model will return three masks. 116 | For ambiguous input prompts (such as a single click), this will often 117 | produce better masks than a single prediction. If only a single 118 | mask is needed, the model's predicted quality score can be used 119 | to select the best mask. For non-ambiguous prompts, such as multiple 120 | input prompts, multimask_output=False can give better results. 121 | return_logits (bool): If true, returns un-thresholded masks logits 122 | instead of a binary mask. 123 | 124 | Returns: 125 | (np.ndarray): The output masks in CxHxW format, where C is the 126 | number of masks, and (H, W) is the original image size. 127 | (np.ndarray): An array of length C containing the model's 128 | predictions for the quality of each mask. 129 | (np.ndarray): An array of shape CxHxW, where C is the number 130 | of masks and H=W=256. These low resolution logits can be passed to 131 | a subsequent iteration as mask input. 132 | """ 133 | if not self.is_image_set: 134 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 135 | 136 | # Transform input prompts 137 | coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None 138 | if point_coords is not None: 139 | assert ( 140 | point_labels is not None 141 | ), "point_labels must be supplied if point_coords is supplied." 
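            # Rescale point prompts from original-image coordinates into the resized input frame expected by the model.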
142 | point_coords = self.transform.apply_coords(point_coords, self.original_size) 143 | coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) 144 | labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) 145 | coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] 146 | if box is not None: 147 | box = self.transform.apply_boxes(box, self.original_size) 148 | box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) 149 | box_torch = box_torch[None, :] 150 | if mask_input is not None: 151 | mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) 152 | mask_input_torch = mask_input_torch[None, :, :, :] 153 | 154 | masks, iou_predictions, low_res_masks = self.predict_torch( 155 | coords_torch, 156 | labels_torch, 157 | box_torch, 158 | mask_input_torch, 159 | multimask_output, 160 | return_logits=return_logits, 161 | ) 162 | 163 | masks_np = masks[0].detach().cpu().numpy() 164 | iou_predictions_np = iou_predictions[0].detach().cpu().numpy() 165 | low_res_masks_np = low_res_masks[0].detach().cpu().numpy() 166 | return masks_np, iou_predictions_np, low_res_masks_np 167 | 168 | @torch.no_grad() 169 | def predict_torch( 170 | self, 171 | point_coords: Optional[torch.Tensor], 172 | point_labels: Optional[torch.Tensor], 173 | boxes: Optional[torch.Tensor] = None, 174 | mask_input: Optional[torch.Tensor] = None, 175 | multimask_output: bool = True, 176 | return_logits: bool = False, 177 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 178 | """ 179 | Predict masks for the given input prompts, using the currently set image. 180 | Input prompts are batched torch tensors and are expected to already be 181 | transformed to the input frame using ResizeLongestSide. 182 | 183 | Arguments: 184 | point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the 185 | model. Each point is in (X,Y) in pixels. 186 | point_labels (torch.Tensor or None): A BxN array of labels for the 187 | point prompts. 1 indicates a foreground point and 0 indicates a 188 | background point. 189 | boxes (np.ndarray or None): A Bx4 array given a box prompt to the 190 | model, in XYXY format. 191 | mask_input (np.ndarray): A low resolution mask input to the model, typically 192 | coming from a previous prediction iteration. Has form Bx1xHxW, where 193 | for SAM, H=W=256. Masks returned by a previous iteration of the 194 | predict method do not need further transformation. 195 | multimask_output (bool): If true, the model will return three masks. 196 | For ambiguous input prompts (such as a single click), this will often 197 | produce better masks than a single prediction. If only a single 198 | mask is needed, the model's predicted quality score can be used 199 | to select the best mask. For non-ambiguous prompts, such as multiple 200 | input prompts, multimask_output=False can give better results. 201 | return_logits (bool): If true, returns un-thresholded masks logits 202 | instead of a binary mask. 203 | 204 | Returns: 205 | (torch.Tensor): The output masks in BxCxHxW format, where C is the 206 | number of masks, and (H, W) is the original image size. 207 | (torch.Tensor): An array of shape BxC containing the model's 208 | predictions for the quality of each mask. 209 | (torch.Tensor): An array of shape BxCxHxW, where C is the number 210 | of masks and H=W=256. These low res logits can be passed to 211 | a subsequent iteration as mask input. 
212 | """ 213 | if not self.is_image_set: 214 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 215 | 216 | if point_coords is not None: 217 | points = (point_coords, point_labels) 218 | else: 219 | points = None 220 | 221 | # Embed prompts 222 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder( 223 | points=points, 224 | boxes=boxes, 225 | masks=mask_input, 226 | ) 227 | 228 | # Predict masks 229 | low_res_masks, iou_predictions = self.model.mask_decoder( 230 | image_embeddings=self.features, 231 | image_pe=self.model.prompt_encoder.get_dense_pe(), 232 | sparse_prompt_embeddings=sparse_embeddings, 233 | dense_prompt_embeddings=dense_embeddings, 234 | multimask_output=multimask_output, 235 | ) 236 | 237 | # Upscale the masks to the original image resolution 238 | masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) 239 | 240 | if not return_logits: 241 | masks = masks > self.model.mask_threshold 242 | 243 | return masks, iou_predictions, low_res_masks 244 | 245 | def get_image_embedding(self) -> torch.Tensor: 246 | """ 247 | Returns the image embeddings for the currently set image, with 248 | shape 1xCxHxW, where C is the embedding dimension and (H,W) are 249 | the embedding spatial dimension of SAM (typically C=256, H=W=64). 250 | """ 251 | if not self.is_image_set: 252 | raise RuntimeError( 253 | "An image must be set with .set_image(...) to generate an embedding." 254 | ) 255 | assert self.features is not None, "Features must exist if an image has been set." 256 | return self.features 257 | 258 | @property 259 | def device(self) -> torch.device: 260 | return self.model.device 261 | 262 | def reset_image(self) -> None: 263 | """Resets the currently set image.""" 264 | self.is_image_set = False 265 | self.features = None 266 | self.orig_h = None 267 | self.orig_w = None 268 | self.input_h = None 269 | self.input_w = None 270 | -------------------------------------------------------------------------------- /models/sam/utils/amg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | 10 | import math 11 | from copy import deepcopy 12 | from itertools import product 13 | from typing import Any, Dict, Generator, ItemsView, List, Tuple 14 | 15 | 16 | class MaskData: 17 | """ 18 | A structure for storing masks and their related data in batched format. 19 | Implements basic filtering and concatenation. 20 | """ 21 | 22 | def __init__(self, **kwargs) -> None: 23 | for v in kwargs.values(): 24 | assert isinstance( 25 | v, (list, np.ndarray, torch.Tensor) 26 | ), "MaskData only supports list, numpy arrays, and torch tensors." 27 | self._stats = dict(**kwargs) 28 | 29 | def __setitem__(self, key: str, item: Any) -> None: 30 | assert isinstance( 31 | item, (list, np.ndarray, torch.Tensor) 32 | ), "MaskData only supports list, numpy arrays, and torch tensors." 
33 | self._stats[key] = item 34 | 35 | def __delitem__(self, key: str) -> None: 36 | del self._stats[key] 37 | 38 | def __getitem__(self, key: str) -> Any: 39 | return self._stats[key] 40 | 41 | def items(self) -> ItemsView[str, Any]: 42 | return self._stats.items() 43 | 44 | def filter(self, keep: torch.Tensor) -> None: 45 | for k, v in self._stats.items(): 46 | if v is None: 47 | self._stats[k] = None 48 | elif isinstance(v, torch.Tensor): 49 | self._stats[k] = v[torch.as_tensor(keep, device=v.device)] 50 | elif isinstance(v, np.ndarray): 51 | self._stats[k] = v[keep.detach().cpu().numpy()] 52 | elif isinstance(v, list) and keep.dtype == torch.bool: 53 | self._stats[k] = [a for i, a in enumerate(v) if keep[i]] 54 | elif isinstance(v, list): 55 | self._stats[k] = [v[i] for i in keep] 56 | else: 57 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 58 | 59 | def cat(self, new_stats: "MaskData") -> None: 60 | for k, v in new_stats.items(): 61 | if k not in self._stats or self._stats[k] is None: 62 | self._stats[k] = deepcopy(v) 63 | elif isinstance(v, torch.Tensor): 64 | self._stats[k] = torch.cat([self._stats[k], v], dim=0) 65 | elif isinstance(v, np.ndarray): 66 | self._stats[k] = np.concatenate([self._stats[k], v], axis=0) 67 | elif isinstance(v, list): 68 | self._stats[k] = self._stats[k] + deepcopy(v) 69 | else: 70 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 71 | 72 | def to_numpy(self) -> None: 73 | for k, v in self._stats.items(): 74 | if isinstance(v, torch.Tensor): 75 | self._stats[k] = v.detach().cpu().numpy() 76 | 77 | 78 | def is_box_near_crop_edge( 79 | boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 80 | ) -> torch.Tensor: 81 | """Filter masks at the edge of a crop, but not at the edge of the original image.""" 82 | crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) 83 | orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) 84 | boxes = uncrop_boxes_xyxy(boxes, crop_box).float() 85 | near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) 86 | near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) 87 | near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) 88 | return torch.any(near_crop_edge, dim=1) 89 | 90 | 91 | def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: 92 | box_xywh = deepcopy(box_xyxy) 93 | box_xywh[2] = box_xywh[2] - box_xywh[0] 94 | box_xywh[3] = box_xywh[3] - box_xywh[1] 95 | return box_xywh 96 | 97 | 98 | def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: 99 | assert len(args) > 0 and all( 100 | len(a) == len(args[0]) for a in args 101 | ), "Batched iteration must have inputs of all the same size." 102 | n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) 103 | for b in range(n_batches): 104 | yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] 105 | 106 | 107 | def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: 108 | """ 109 | Encodes masks to an uncompressed RLE, in the format expected by 110 | pycoco tools. 
111 | """ 112 | # Put in fortran order and flatten h,w 113 | b, h, w = tensor.shape 114 | tensor = tensor.permute(0, 2, 1).flatten(1) 115 | 116 | # Compute change indices 117 | diff = tensor[:, 1:] ^ tensor[:, :-1] 118 | change_indices = diff.nonzero() 119 | 120 | # Encode run length 121 | out = [] 122 | for i in range(b): 123 | cur_idxs = change_indices[change_indices[:, 0] == i, 1] 124 | cur_idxs = torch.cat( 125 | [ 126 | torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), 127 | cur_idxs + 1, 128 | torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), 129 | ] 130 | ) 131 | btw_idxs = cur_idxs[1:] - cur_idxs[:-1] 132 | counts = [] if tensor[i, 0] == 0 else [0] 133 | counts.extend(btw_idxs.detach().cpu().tolist()) 134 | out.append({"size": [h, w], "counts": counts}) 135 | return out 136 | 137 | 138 | def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: 139 | """Compute a binary mask from an uncompressed RLE.""" 140 | h, w = rle["size"] 141 | mask = np.empty(h * w, dtype=bool) 142 | idx = 0 143 | parity = False 144 | for count in rle["counts"]: 145 | mask[idx : idx + count] = parity 146 | idx += count 147 | parity ^= True 148 | mask = mask.reshape(w, h) 149 | return mask.transpose() # Put in C order 150 | 151 | 152 | def area_from_rle(rle: Dict[str, Any]) -> int: 153 | return sum(rle["counts"][1::2]) 154 | 155 | 156 | def calculate_stability_score( 157 | masks: torch.Tensor, mask_threshold: float, threshold_offset: float 158 | ) -> torch.Tensor: 159 | """ 160 | Computes the stability score for a batch of masks. The stability 161 | score is the IoU between the binary masks obtained by thresholding 162 | the predicted mask logits at high and low values. 163 | """ 164 | # One mask is always contained inside the other. 165 | # Save memory by preventing unnecessary cast to torch.int64 166 | intersections = ( 167 | (masks > (mask_threshold + threshold_offset)) 168 | .sum(-1, dtype=torch.int16) 169 | .sum(-1, dtype=torch.int32) 170 | ) 171 | unions = ( 172 | (masks > (mask_threshold - threshold_offset)) 173 | .sum(-1, dtype=torch.int16) 174 | .sum(-1, dtype=torch.int32) 175 | ) 176 | return intersections / unions 177 | 178 | 179 | def build_point_grid(n_per_side: int) -> np.ndarray: 180 | """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" 181 | offset = 1 / (2 * n_per_side) 182 | points_one_side = np.linspace(offset, 1 - offset, n_per_side) 183 | points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) 184 | points_y = np.tile(points_one_side[:, None], (1, n_per_side)) 185 | points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) 186 | return points 187 | 188 | 189 | def build_all_layer_point_grids( 190 | n_per_side: int, n_layers: int, scale_per_layer: int 191 | ) -> List[np.ndarray]: 192 | """Generates point grids for all crop layers.""" 193 | points_by_layer = [] 194 | for i in range(n_layers + 1): 195 | n_points = int(n_per_side / (scale_per_layer**i)) 196 | points_by_layer.append(build_point_grid(n_points)) 197 | return points_by_layer 198 | 199 | 200 | def generate_crop_boxes( 201 | im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float 202 | ) -> Tuple[List[List[int]], List[int]]: 203 | """ 204 | Generates a list of crop boxes of different sizes. Each layer 205 | has (2**i)**2 boxes for the ith layer. 
206 | """ 207 | crop_boxes, layer_idxs = [], [] 208 | im_h, im_w = im_size 209 | short_side = min(im_h, im_w) 210 | 211 | # Original image 212 | crop_boxes.append([0, 0, im_w, im_h]) 213 | layer_idxs.append(0) 214 | 215 | def crop_len(orig_len, n_crops, overlap): 216 | return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) 217 | 218 | for i_layer in range(n_layers): 219 | n_crops_per_side = 2 ** (i_layer + 1) 220 | overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) 221 | 222 | crop_w = crop_len(im_w, n_crops_per_side, overlap) 223 | crop_h = crop_len(im_h, n_crops_per_side, overlap) 224 | 225 | crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] 226 | crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] 227 | 228 | # Crops in XYWH format 229 | for x0, y0 in product(crop_box_x0, crop_box_y0): 230 | box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] 231 | crop_boxes.append(box) 232 | layer_idxs.append(i_layer + 1) 233 | 234 | return crop_boxes, layer_idxs 235 | 236 | 237 | def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 238 | x0, y0, _, _ = crop_box 239 | offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) 240 | # Check if boxes has a channel dimension 241 | if len(boxes.shape) == 3: 242 | offset = offset.unsqueeze(1) 243 | return boxes + offset 244 | 245 | 246 | def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 247 | x0, y0, _, _ = crop_box 248 | offset = torch.tensor([[x0, y0]], device=points.device) 249 | # Check if points has a channel dimension 250 | if len(points.shape) == 3: 251 | offset = offset.unsqueeze(1) 252 | return points + offset 253 | 254 | 255 | def uncrop_masks( 256 | masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int 257 | ) -> torch.Tensor: 258 | x0, y0, x1, y1 = crop_box 259 | if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: 260 | return masks 261 | # Coordinate transform masks 262 | pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) 263 | pad = (x0, pad_x - x0, y0, pad_y - y0) 264 | return torch.nn.functional.pad(masks, pad, value=0) 265 | 266 | 267 | def remove_small_regions( 268 | mask: np.ndarray, area_thresh: float, mode: str 269 | ) -> Tuple[np.ndarray, bool]: 270 | """ 271 | Removes small disconnected regions and holes in a mask. Returns the 272 | mask and an indicator of if the mask has been modified. 
273 | """ 274 | import cv2 # type: ignore 275 | 276 | assert mode in ["holes", "islands"] 277 | correct_holes = mode == "holes" 278 | working_mask = (correct_holes ^ mask).astype(np.uint8) 279 | n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) 280 | sizes = stats[:, -1][1:] # Row 0 is background label 281 | small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] 282 | if len(small_regions) == 0: 283 | return mask, False 284 | fill_labels = [0] + small_regions 285 | if not correct_holes: 286 | fill_labels = [i for i in range(n_labels) if i not in fill_labels] 287 | # If every region is below threshold, keep largest 288 | if len(fill_labels) == 0: 289 | fill_labels = [int(np.argmax(sizes)) + 1] 290 | mask = np.isin(regions, fill_labels) 291 | return mask, True 292 | 293 | 294 | def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: 295 | from pycocotools import mask as mask_utils # type: ignore 296 | 297 | h, w = uncompressed_rle["size"] 298 | rle = mask_utils.frPyObjects(uncompressed_rle, h, w) 299 | rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json 300 | return rle 301 | 302 | 303 | def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: 304 | """ 305 | Calculates boxes in XYXY format around masks. Return [0,0,0,0] for 306 | an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. 307 | """ 308 | # torch.max below raises an error on empty inputs, just skip in this case 309 | if torch.numel(masks) == 0: 310 | return torch.zeros(*masks.shape[:-2], 4, device=masks.device) 311 | 312 | # Normalize shape to CxHxW 313 | shape = masks.shape 314 | h, w = shape[-2:] 315 | if len(shape) > 2: 316 | masks = masks.flatten(0, -3) 317 | else: 318 | masks = masks.unsqueeze(0) 319 | 320 | # Get top and bottom edges 321 | in_height, _ = torch.max(masks, dim=-1) 322 | in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] 323 | bottom_edges, _ = torch.max(in_height_coords, dim=-1) 324 | in_height_coords = in_height_coords + h * (~in_height) 325 | top_edges, _ = torch.min(in_height_coords, dim=-1) 326 | 327 | # Get left and right edges 328 | in_width, _ = torch.max(masks, dim=-2) 329 | in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] 330 | right_edges, _ = torch.max(in_width_coords, dim=-1) 331 | in_width_coords = in_width_coords + w * (~in_width) 332 | left_edges, _ = torch.min(in_width_coords, dim=-1) 333 | 334 | # If the mask is empty the right edge will be to the left of the left edge. 
335 | # Replace these boxes with [0, 0, 0, 0] 336 | empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) 337 | out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) 338 | out = out * (~empty_filter).unsqueeze(-1) 339 | 340 | # Return to original shape 341 | if len(shape) > 2: 342 | out = out.reshape(*shape[:-2], 4) 343 | else: 344 | out = out[0] 345 | 346 | return out 347 | -------------------------------------------------------------------------------- /models/implicitefficientnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | __version__ = "0.5.1" 6 | from .utils import ( 7 | GlobalParams, 8 | BlockArgs, 9 | BlockDecoder, 10 | efficientnet, 11 | get_model_params, 12 | ) 13 | 14 | 15 | from .utils import ( 16 | round_filters, 17 | round_repeats, 18 | drop_connect, 19 | get_same_padding_conv2d, 20 | get_model_params, 21 | efficientnet_params, 22 | load_pretrained_weights, 23 | Swish, 24 | MemoryEfficientSwish, 25 | gram_matrix, 26 | ) 27 | 28 | 29 | class MBConvBlock(nn.Module): 30 | """ 31 | Mobile Inverted Residual Bottleneck Block 32 | 33 | Args: 34 | block_args (namedtuple): BlockArgs, see above 35 | global_params (namedtuple): GlobalParam, see above 36 | 37 | Attributes: 38 | has_se (bool): Whether the block contains a Squeeze and Excitation layer. 39 | """ 40 | 41 | def __init__(self, block_args, global_params): 42 | super().__init__() 43 | self._block_args = block_args 44 | self._bn_mom = 1 - global_params.batch_norm_momentum 45 | self._bn_eps = global_params.batch_norm_epsilon 46 | self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1) 47 | self.id_skip = block_args.id_skip # skip connection and drop connect 48 | 49 | # Get static or dynamic convolution depending on image size 50 | Conv2d = get_same_padding_conv2d(image_size=global_params.image_size) 51 | 52 | # Expansion phase 53 | inp = self._block_args.input_filters # number of input channels 54 | oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels 55 | if self._block_args.expand_ratio != 1: 56 | self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False) 57 | self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) 58 | 59 | # Depthwise convolution phase 60 | k = self._block_args.kernel_size 61 | s = self._block_args.stride 62 | self._depthwise_conv = Conv2d( 63 | in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise 64 | kernel_size=k, stride=s, bias=False) 65 | self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) 66 | 67 | # Squeeze and Excitation layer, if desired 68 | if self.has_se: 69 | num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio)) 70 | self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1) 71 | self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1) 72 | 73 | # Output phase 74 | final_oup = self._block_args.output_filters 75 | self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False) 76 | self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps) 77 | self._swish = MemoryEfficientSwish() 78 | 79 | def forward(self, inputs, drop_connect_rate=None): 80 | """ 81 | :param 
inputs: input tensor 82 | :param drop_connect_rate: drop connect rate (float, between 0 and 1) 83 | :return: output of block 84 | """ 85 | 86 | # Expansion and Depthwise Convolution 87 | x = inputs 88 | if self._block_args.expand_ratio != 1: 89 | x = self._swish(self._bn0(self._expand_conv(inputs))) 90 | x = self._swish(self._bn1(self._depthwise_conv(x))) 91 | 92 | # Squeeze and Excitation 93 | if self.has_se: 94 | x_squeezed = F.adaptive_avg_pool2d(x, 1) 95 | x_squeezed = self._se_expand(self._swish(self._se_reduce(x_squeezed))) 96 | x = torch.sigmoid(x_squeezed) * x 97 | 98 | x = self._bn2(self._project_conv(x)) 99 | 100 | # Skip connection and drop connect 101 | input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters 102 | if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters: 103 | if drop_connect_rate: 104 | x = drop_connect(x, p=drop_connect_rate, training=self.training) 105 | x = x + inputs # skip connection 106 | return x 107 | 108 | def set_swish(self, memory_efficient=True): 109 | """Sets swish function as memory efficient (for training) or standard (for export)""" 110 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish() 111 | 112 | 113 | class EfficientNet(nn.Module): 114 | """ 115 | An EfficientNet model. Most easily loaded with the .from_name or .from_pretrained methods 116 | 117 | Args: 118 | blocks_args (list): A list of BlockArgs to construct blocks 119 | global_params (namedtuple): A set of GlobalParams shared between blocks 120 | 121 | Example: 122 | model = EfficientNet.from_pretrained('efficientnet-b0') 123 | 124 | """ 125 | 126 | def __init__(self, type, blocks_args=None, global_params=None): 127 | super().__init__() 128 | assert isinstance(blocks_args, list), 'blocks_args should be a list' 129 | assert len(blocks_args) > 0, 'block args must be greater than 0' 130 | self._global_params = global_params 131 | self._blocks_args = blocks_args 132 | self.type = type 133 | # Get static or dynamic convolution depending on image size 134 | Conv2d = get_same_padding_conv2d(image_size=global_params.image_size) 135 | 136 | # Batch norm parameters 137 | bn_mom = 1 - self._global_params.batch_norm_momentum 138 | bn_eps = self._global_params.batch_norm_epsilon 139 | 140 | # Stem 141 | in_channels = 5 # rgb 142 | out_channels = round_filters(32, self._global_params) # number of output channels 143 | self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) 144 | self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) 145 | 146 | # Build blocks 147 | self._blocks = nn.ModuleList([]) 148 | for block_args in self._blocks_args: 149 | 150 | # Update block input and output filters based on depth multiplier. 151 | block_args = block_args._replace( 152 | input_filters=round_filters(block_args.input_filters, self._global_params), 153 | output_filters=round_filters(block_args.output_filters, self._global_params), 154 | num_repeat=round_repeats(block_args.num_repeat, self._global_params) 155 | ) 156 | 157 | # The first block needs to take care of stride and filter size increase. 
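            # (the repeat blocks appended below reuse the output filter count as their input and run at stride 1)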
158 | self._blocks.append(MBConvBlock(block_args, self._global_params)) 159 | if block_args.num_repeat > 1: 160 | block_args = block_args._replace(input_filters=block_args.output_filters, stride=1) 161 | for _ in range(block_args.num_repeat - 1): 162 | self._blocks.append(MBConvBlock(block_args, self._global_params)) 163 | 164 | # Head 165 | in_channels = block_args.output_filters # output of final block 166 | out_channels = round_filters(1280, self._global_params) 167 | self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False) 168 | self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) 169 | 170 | # Final linear layer 171 | self._avg_pooling = nn.AdaptiveAvgPool2d(1) 172 | self._dropout = nn.Dropout(self._global_params.dropout_rate) 173 | self._fc = nn.Linear(out_channels, 1) 174 | self._swish = MemoryEfficientSwish() 175 | self.conv_reg = nn.Conv2d(1792, 1, 1) 176 | if self.type == 'big_map' or self.type == 'img': 177 | self.conv_transe1 = nn.Conv2d(1792, 448, 1) 178 | self.bn_transe1 = nn.BatchNorm2d(num_features=448, momentum=bn_mom, eps=bn_eps) 179 | self.conv_transe2 = nn.Conv2d(448, 112, 1) 180 | self.bn_transe2 = nn.BatchNorm2d(num_features=112, momentum=bn_mom, eps=bn_eps) 181 | if self.type == 'big_map': 182 | self.conv_transe_mask = nn.Conv2d(112, 1, 1) 183 | self.deconv_big = nn.ConvTranspose2d(1792, 1, 5, stride=4) ##transpose 184 | if self.type == 'img': 185 | self.conv_transe3 = nn.Conv2d(112, 3, 1) 186 | self.deconv_img = nn.ConvTranspose2d(1792, 3, 5, stride=4) ##transpose 187 | elif self.type == 'deconv_map' or self.type == 'deconv_img': 188 | self.conv_big_reg = nn.ConvTranspose2d(1792, 1, 5, stride=4) ##transpose 189 | self.conv_img = nn.ConvTranspose2d(1792, 3, 5, stride=4) ##transpose 190 | else: 191 | self.conv_reg = nn.Conv2d(1792, 1, 1) 192 | 193 | self.relu = nn.ReLU() 194 | self.up_double = nn.Upsample(scale_factor=2, mode='bilinear') 195 | self.sig = nn.Sigmoid() 196 | 197 | def set_swish(self, memory_efficient=True): 198 | """Sets swish function as memory efficient (for training) or standard (for export)""" 199 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish() 200 | for block in self._blocks: 201 | block.set_swish(memory_efficient) 202 | 203 | def extract_features(self, inputs): 204 | """ Returns output of the final convolution layer """ 205 | 206 | # Stem 207 | x = self._swish(self._bn0(self._conv_stem(inputs))) 208 | 209 | # Blocks 210 | for idx, block in enumerate(self._blocks): 211 | drop_connect_rate = self._global_params.drop_connect_rate 212 | if drop_connect_rate: 213 | drop_connect_rate *= float(idx) / len(self._blocks) 214 | x = block(x, drop_connect_rate=drop_connect_rate) 215 | 216 | # Head 217 | x = self._swish(self._bn1(self._conv_head(x))) 218 | 219 | return x 220 | 221 | def forward(self, seg, label, natural): 222 | label = label.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand(seg.size()) 223 | 224 | x = torch.cat((label, natural, seg), 1) # concated input 225 | bs = seg.size(0) 226 | # Convolution layers 227 | x = self.extract_features(x) 228 | if self.type == 'map': 229 | reg = self.conv_reg(x) 230 | reg = self.sig(reg) 231 | elif self.type == 'big_map': 232 | reg = self.up_double(x) # 12*14 233 | reg = self.relu(reg) 234 | reg = self.conv_transe1(reg) # 448 235 | reg = self.bn_transe1(reg) 236 | 237 | reg = self.up_double(reg) # 24*28 238 | reg = self.relu(reg) 239 | reg = self.conv_transe2(reg) # 112 240 | reg = self.bn_transe2(reg) 241 | 242 | reg = 
self.conv_transe_mask(reg) # 1 243 | reg = self.sig(reg) 244 | elif self.type == 'img': 245 | reg = self.up_double(x) # 12*14 246 | reg = self.relu(reg) 247 | reg = self.conv_transe1(reg) # 448 248 | reg = self.bn_transe1(reg) 249 | 250 | reg = self.up_double(reg) # 24*28 251 | reg = self.relu(reg) 252 | reg = self.conv_transe2(reg) # 112 253 | reg = self.bn_transe2(reg) 254 | 255 | reg = self.conv_transe3(reg) # 3 256 | reg = self.sig(reg) 257 | elif self.type == 'deconv_map': 258 | reg = self.conv_big_reg(x) 259 | reg = self.sig(reg) 260 | elif self.type == 'deconv_img': 261 | reg = self.conv_img(x) 262 | reg = self.sig(reg) 263 | elif self.type == 'feature': 264 | reg = gram_matrix(x - x.mean(0, True)) 265 | 266 | return reg 267 | 268 | @classmethod 269 | def from_name(cls, model_name, type, override_params=None): 270 | cls._check_model_name_is_valid(model_name) 271 | blocks_args, global_params = get_model_params(model_name, override_params) 272 | return cls(type, blocks_args, global_params) 273 | 274 | @classmethod 275 | def from_pretrained(cls, model_name, num_classes=1000, in_channels=3): 276 | model = cls.from_name(model_name, override_params={'num_classes': num_classes}) 277 | load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000)) 278 | if in_channels != 3: 279 | Conv2d = get_same_padding_conv2d(image_size=model._global_params.image_size) 280 | out_channels = round_filters(32, model._global_params) 281 | model._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) 282 | return model 283 | 284 | @classmethod 285 | def from_pretrained(cls, model_name, num_classes=1000): 286 | model = cls.from_name(model_name, override_params={'num_classes': num_classes}) 287 | load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000)) 288 | 289 | return model 290 | 291 | @classmethod 292 | def get_image_size(cls, model_name): 293 | cls._check_model_name_is_valid(model_name) 294 | _, _, res, _ = efficientnet_params(model_name) 295 | return res 296 | 297 | @classmethod 298 | def _check_model_name_is_valid(cls, model_name, also_need_pretrained_weights=False): 299 | """ Validates model name. None that pretrained weights are only available for 300 | the first four models (efficientnet-b{i} for i in 0,1,2,3) at the moment. """ 301 | num_models = 4 if also_need_pretrained_weights else 8 302 | valid_models = ['efficientnet-b' + str(i) for i in range(num_models)] 303 | if model_name not in valid_models: 304 | raise ValueError('model_name should be one of: ' + ', '.join(valid_models)) 305 | 306 | 307 | 308 | --------------------------------------------------------------------------------
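
The RLE and box helpers in models/sam/utils/amg.py are self-contained, so a small round trip is the quickest way to see the data they produce. The following is a minimal sketch, not part of the repository: it assumes it is run from the repository root (so that models resolves as a package) and uses made-up tensor values purely for illustration.

import numpy as np
import torch

from models.sam.utils.amg import (
    area_from_rle,
    batched_mask_to_box,
    mask_to_rle_pytorch,
    rle_to_mask,
)

# A tiny 1x4x5 boolean mask batch with a single 2x2 foreground block.
mask = torch.zeros(1, 4, 5, dtype=torch.bool)
mask[0, 1:3, 2:4] = True

# Encode to the uncompressed, pycocotools-style RLE described in mask_to_rle_pytorch's docstring;
# counts alternate background/foreground run lengths in Fortran (column-major) order.
rle = mask_to_rle_pytorch(mask)[0]          # {"size": [4, 5], "counts": [9, 2, 2, 2, 5]}
print(area_from_rle(rle))                   # 4 -- number of foreground pixels
print(np.array_equal(rle_to_mask(rle), mask[0].numpy()))   # True -- the round trip is lossless

# Tight XYXY box around each mask; an empty mask would yield [0, 0, 0, 0].
print(batched_mask_to_box(mask))            # tensor([[2, 1, 3, 2]])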