├── dataset
    ├── generate_dataset_json
    │   ├── __init__.py
    │   ├── __pycache__
    │   │   └── __init__.cpython-39.pyc
    │   ├── br35.py
    │   ├── brainmri.py
    │   ├── DTD.py
    │   ├── colonDB.py
    │   ├── clinicDB.py
    │   ├── head_ct.py
    │   ├── tn3k.py
    │   ├── isic.py
    │   ├── btad.py
    │   ├── visa.py
    │   ├── SDD.py
    │   ├── mpdd.py
    │   ├── endoTect.py
    │   ├── mvtec.py
    │   ├── custom_dataset.py
    │   └── DAGM.py
    └── dataset.py
├── models
    ├── __init__.py
    ├── bpe_simple_vocab_16e6.txt.gz
    ├── build_model.py
    ├── simple_tokenizer.py
    └── model_load.py
├── __init__.py
├── assets
    ├── main-fig.png
    ├── table1.png
    ├── visualization_ind.png
    ├── visualization_med.png
    └── visualization_combined.jpg
├── checkpoints
    ├── trained_on_mvtec_crane
    │   ├── epoch_1.pth
    │   ├── epoch_2.pth
    │   ├── epoch_3.pth
    │   ├── epoch_4.pth
    │   └── epoch_5.pth
    ├── trained_on_visa_crane
    │   ├── epoch_1.pth
    │   ├── epoch_2.pth
    │   ├── epoch_3.pth
    │   ├── epoch_4.pth
    │   └── epoch_5.pth
    ├── trained_on_visa_cranep
    │   ├── epoch_1.pth
    │   ├── epoch_2.pth
    │   ├── epoch_3.pth
    │   ├── epoch_4.pth
    │   └── epoch_5.pth
    └── trained_on_mvtec_cranep
    │   ├── epoch_1.pth
    │   ├── epoch_2.pth
    │   ├── epoch_3.pth
    │   ├── epoch_4.pth
    │   └── epoch_5.pth
├── segment_anything
    ├── __pycache__
    │   ├── __init__.cpython-39.pyc
    │   ├── build_sam.cpython-39.pyc
    │   ├── predictor.cpython-39.pyc
    │   ├── automatic_mask_generator.cpython-39.pyc
    │   ├── __init__.cpython-39.sync-conflict-20250510-072407-YT4HGEF.pyc
    │   └── automatic_mask_generator.cpython-39.sync-conflict-20250510-072407-YT4HGEF.pyc
    ├── utils
    │   ├── __pycache__
    │   │   ├── amg.cpython-39.pyc
    │   │   ├── __init__.cpython-39.pyc
    │   │   ├── transforms.cpython-39.pyc
    │   │   ├── amg.cpython-39.sync-conflict-20250510-072407-YT4HGEF.pyc
    │   │   ├── __init__.cpython-39.sync-conflict-20250510-072407-YT4HGEF.pyc
    │   │   └── transforms.cpython-39.sync-conflict-20250510-072407-YT4HGEF.pyc
    │   ├── __init__.py
    │   ├── transforms.py
    │   ├── onnx.py
    │   └── amg.py
    ├── modeling
    │   ├── __pycache__
    │   │   ├── sam.cpython-39.pyc
    │   │   ├── common.cpython-39.pyc
    │   │   ├── __init__.cpython-39.pyc
    │   │   ├── transformer.cpython-39.pyc
    │   │   ├── image_encoder.cpython-39.pyc
    │   │   ├── mask_decoder.cpython-39.pyc
    │   │   └── prompt_encoder.cpython-39.pyc
    │   ├── __init__.py
    │   ├── common.py
    │   ├── mask_decoder.py
    │   ├── sam.py
    │   ├── transformer.py
    │   └── prompt_encoder.py
    ├── __init__.py
    ├── build_sam.py
    └── predictor.py
├── setup.sh
├── runtime.sh
├── environment.yml
├── utils
    ├── similarity.py
    ├── transform.py
    ├── logger.py
    ├── visualization.py
    ├── loss.py
    ├── __init__.py
    └── metrics.py
├── LICENSE
├── test.sh
├── reproduce.sh
├── README.md
└── train.py
/dataset/generate_dataset_json/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .model_load import *
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | DATASETS_ROOT='/home/alireza/datasets'
--------------------------------------------------------------------------------
/assets/main-fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/assets/main-fig.png
--------------------------------------------------------------------------------
/assets/table1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/assets/table1.png -------------------------------------------------------------------------------- /assets/visualization_ind.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/assets/visualization_ind.png -------------------------------------------------------------------------------- /assets/visualization_med.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/assets/visualization_med.png -------------------------------------------------------------------------------- /assets/visualization_combined.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/assets/visualization_combined.jpg -------------------------------------------------------------------------------- /models/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/models/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /checkpoints/trained_on_mvtec_crane/epoch_1.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/checkpoints/trained_on_mvtec_crane/epoch_1.pth -------------------------------------------------------------------------------- /checkpoints/trained_on_mvtec_crane/epoch_2.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/checkpoints/trained_on_mvtec_crane/epoch_2.pth -------------------------------------------------------------------------------- /checkpoints/trained_on_mvtec_crane/epoch_3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/checkpoints/trained_on_mvtec_crane/epoch_3.pth -------------------------------------------------------------------------------- /checkpoints/trained_on_mvtec_crane/epoch_4.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/checkpoints/trained_on_mvtec_crane/epoch_4.pth -------------------------------------------------------------------------------- /checkpoints/trained_on_mvtec_crane/epoch_5.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/checkpoints/trained_on_mvtec_crane/epoch_5.pth -------------------------------------------------------------------------------- /checkpoints/trained_on_visa_crane/epoch_1.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/checkpoints/trained_on_visa_crane/epoch_1.pth -------------------------------------------------------------------------------- /checkpoints/trained_on_visa_crane/epoch_2.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/checkpoints/trained_on_visa_crane/epoch_2.pth -------------------------------------------------------------------------------- 
/checkpoints/trained_on_visa_crane/epoch_3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/checkpoints/trained_on_visa_crane/epoch_3.pth -------------------------------------------------------------------------------- /checkpoints/trained_on_visa_crane/epoch_4.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/checkpoints/trained_on_visa_crane/epoch_4.pth -------------------------------------------------------------------------------- /checkpoints/trained_on_visa_crane/epoch_5.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/checkpoints/trained_on_visa_crane/epoch_5.pth -------------------------------------------------------------------------------- /checkpoints/trained_on_visa_cranep/epoch_1.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/checkpoints/trained_on_visa_cranep/epoch_1.pth -------------------------------------------------------------------------------- /checkpoints/trained_on_visa_cranep/epoch_2.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/checkpoints/trained_on_visa_cranep/epoch_2.pth -------------------------------------------------------------------------------- /checkpoints/trained_on_visa_cranep/epoch_3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/checkpoints/trained_on_visa_cranep/epoch_3.pth -------------------------------------------------------------------------------- /checkpoints/trained_on_visa_cranep/epoch_4.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/checkpoints/trained_on_visa_cranep/epoch_4.pth -------------------------------------------------------------------------------- /checkpoints/trained_on_visa_cranep/epoch_5.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/checkpoints/trained_on_visa_cranep/epoch_5.pth -------------------------------------------------------------------------------- /checkpoints/trained_on_mvtec_cranep/epoch_1.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/checkpoints/trained_on_mvtec_cranep/epoch_1.pth -------------------------------------------------------------------------------- /checkpoints/trained_on_mvtec_cranep/epoch_2.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/checkpoints/trained_on_mvtec_cranep/epoch_2.pth -------------------------------------------------------------------------------- /checkpoints/trained_on_mvtec_cranep/epoch_3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/checkpoints/trained_on_mvtec_cranep/epoch_3.pth -------------------------------------------------------------------------------- /checkpoints/trained_on_mvtec_cranep/epoch_4.pth: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/checkpoints/trained_on_mvtec_cranep/epoch_4.pth -------------------------------------------------------------------------------- /checkpoints/trained_on_mvtec_cranep/epoch_5.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/checkpoints/trained_on_mvtec_cranep/epoch_5.pth -------------------------------------------------------------------------------- /segment_anything/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/segment_anything/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /segment_anything/__pycache__/build_sam.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/segment_anything/__pycache__/build_sam.cpython-39.pyc -------------------------------------------------------------------------------- /segment_anything/__pycache__/predictor.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/segment_anything/__pycache__/predictor.cpython-39.pyc -------------------------------------------------------------------------------- /segment_anything/utils/__pycache__/amg.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/segment_anything/utils/__pycache__/amg.cpython-39.pyc -------------------------------------------------------------------------------- /segment_anything/modeling/__pycache__/sam.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/segment_anything/modeling/__pycache__/sam.cpython-39.pyc -------------------------------------------------------------------------------- /segment_anything/modeling/__pycache__/common.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/segment_anything/modeling/__pycache__/common.cpython-39.pyc -------------------------------------------------------------------------------- /segment_anything/utils/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/segment_anything/utils/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /segment_anything/modeling/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/segment_anything/modeling/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /segment_anything/utils/__pycache__/transforms.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/segment_anything/utils/__pycache__/transforms.cpython-39.pyc 
-------------------------------------------------------------------------------- /segment_anything/modeling/__pycache__/transformer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/segment_anything/modeling/__pycache__/transformer.cpython-39.pyc -------------------------------------------------------------------------------- /dataset/generate_dataset_json/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/dataset/generate_dataset_json/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /segment_anything/modeling/__pycache__/image_encoder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/segment_anything/modeling/__pycache__/image_encoder.cpython-39.pyc -------------------------------------------------------------------------------- /segment_anything/modeling/__pycache__/mask_decoder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/segment_anything/modeling/__pycache__/mask_decoder.cpython-39.pyc -------------------------------------------------------------------------------- /segment_anything/__pycache__/automatic_mask_generator.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/segment_anything/__pycache__/automatic_mask_generator.cpython-39.pyc -------------------------------------------------------------------------------- /segment_anything/modeling/__pycache__/prompt_encoder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/segment_anything/modeling/__pycache__/prompt_encoder.cpython-39.pyc -------------------------------------------------------------------------------- /segment_anything/__pycache__/__init__.cpython-39.sync-conflict-20250510-072407-YT4HGEF.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/segment_anything/__pycache__/__init__.cpython-39.sync-conflict-20250510-072407-YT4HGEF.pyc -------------------------------------------------------------------------------- /segment_anything/utils/__pycache__/amg.cpython-39.sync-conflict-20250510-072407-YT4HGEF.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/segment_anything/utils/__pycache__/amg.cpython-39.sync-conflict-20250510-072407-YT4HGEF.pyc -------------------------------------------------------------------------------- /segment_anything/utils/__pycache__/__init__.cpython-39.sync-conflict-20250510-072407-YT4HGEF.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/segment_anything/utils/__pycache__/__init__.cpython-39.sync-conflict-20250510-072407-YT4HGEF.pyc -------------------------------------------------------------------------------- /segment_anything/utils/__init__.py: 
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
--------------------------------------------------------------------------------
/segment_anything/utils/__pycache__/transforms.cpython-39.sync-conflict-20250510-072407-YT4HGEF.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/segment_anything/utils/__pycache__/transforms.cpython-39.sync-conflict-20250510-072407-YT4HGEF.pyc
--------------------------------------------------------------------------------
/segment_anything/__pycache__/automatic_mask_generator.cpython-39.sync-conflict-20250510-072407-YT4HGEF.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlirezaSalehy/Crane/HEAD/segment_anything/__pycache__/automatic_mask_generator.cpython-39.sync-conflict-20250510-072407-YT4HGEF.pyc
--------------------------------------------------------------------------------
/setup.sh:
--------------------------------------------------------------------------------
1 | # pip install pipreqs vulture
2 | # conda env export > environment.yml
3 | 
4 | conda env create --file environment.yml
5 | 
6 | ### NOTE ### then run the commands below in a terminal to activate the environment ### NOTE ###
7 | # conda init
8 | # conda activate crane_env
9 | 
--------------------------------------------------------------------------------
/runtime.sh:
--------------------------------------------------------------------------------
1 | 
2 | # Capture start time
3 | start_time=$(date +%s)
4 | echo $start_time
5 | 
6 | python test.py --devices 0 --epochs 5 --mean_all_layers false --dino_model dinov2 --model_name trained_mvtec_default
7 | 
8 | # Capture end time
9 | end_time=$(date +%s)
10 | echo $end_time
11 | 
12 | # Calculate elapsed time
13 | elapsed=$((end_time - start_time))
14 | 
15 | # Print result
16 | echo "Execution time: $elapsed seconds"
--------------------------------------------------------------------------------
/segment_anything/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | from .sam import Sam
8 | from .image_encoder import ImageEncoderViT
9 | from .mask_decoder import MaskDecoder
10 | from .prompt_encoder import PromptEncoder
11 | from .transformer import TwoWayTransformer
12 | 
--------------------------------------------------------------------------------
/segment_anything/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | from .build_sam import (
8 |     build_sam,
9 |     build_sam_vit_h,
10 |     build_sam_vit_l,
11 |     build_sam_vit_b,
12 |     sam_model_registry,
13 | )
14 | from .predictor import SamPredictor
15 | from .automatic_mask_generator import SamAutomaticMaskGenerator
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | channels:
2 |   - pytorch
3 |   - nvidia
4 |   - conda-forge
5 |   - defaults
6 | dependencies:
7 |   - numpy=1.26.4
8 |   - pandas=2.2.3
9 |   - scikit-learn=1.5.2
10 |   - scipy=1.13.0
11 |   - matplotlib=3.7.2
12 |   - regex=2023.8.8
13 |   - tabulate=0.9.0
14 |   - pillow=10.4.0
15 |   - tqdm=4.65.2
16 |   - termcolor=2.5.0
17 |   - pytorch-cuda=11.6
18 |   - torchmetrics=1.4.0.post0
19 |   - ipykernel=6.29.5
20 |   - python=3.9.20
21 |   - pip=24.2
22 |   - pip:
23 |     - albumentations==1.4.16
24 |     - ftfy==6.2.0
25 |     - humanhash3==0.0.6
26 |     - opencv-python-headless>=4.9.0.80
27 |     - torch==1.13.1
28 |     - torchvision==0.14.1
29 | name: crane_env
30 | 
31 | 
--------------------------------------------------------------------------------
/utils/similarity.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F # Import F for direct access to functional operations like interpolate
3 | 
4 | def calc_similarity_logits(image_features, text_features, temp=0.07): #1/100, 1/25
5 |     image_features_ = image_features.unsqueeze(dim=1) if image_features.dim() == 2 else image_features
6 |     logits = (image_features_ @ text_features.permute(0, 2, 1))/temp
7 |     return logits.squeeze(dim=1) if image_features.dim() == 2 else logits
8 | 
9 | # mode=nearest (only to check reproducibility, because it is deterministic)
10 | # A workaround is to downsample the ground truth, or use a library which
11 | # supports autograd and is deterministic, or some trial approach
12 | # like the one discussed in https://github.com/open-mmlab/mmsegmentation/issues/255 # bilinear
13 | def regrid_upsample(flat_scores, size, mode='bilinear'):
14 |     h_w = int(flat_scores.shape[1] ** 0.5)
15 |     regrided = flat_scores.reshape(flat_scores.shape[0], h_w, h_w, -1).permute(0, 3, 1, 2)
16 |     upsampled = F.interpolate(regrided, (size, size), mode=mode).permute(0, 2, 3, 1)
17 |     return upsampled
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 Alireza Salehi
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in
13 | all copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /models/build_model.py: -------------------------------------------------------------------------------- 1 | from .Crane import Crane 2 | 3 | def build_model(name, state_dict: dict, design_details = None): 4 | vision_width = state_dict["visual.conv1.weight"].shape[0] 5 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 6 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 7 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 8 | image_resolution = vision_patch_size * grid_size 9 | 10 | embed_dim = state_dict["text_projection"].shape[1] 11 | context_length = state_dict["positional_embedding"].shape[0] 12 | vocab_size = state_dict["token_embedding.weight"].shape[0] 13 | transformer_width = state_dict["ln_final.weight"].shape[0] 14 | transformer_heads = transformer_width // 64 15 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 16 | 17 | model = Crane( 18 | embed_dim, 19 | image_resolution, vision_layers, vision_width, vision_patch_size, 20 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers, design_details = design_details 21 | ) 22 | 23 | for key in ["input_resolution", "context_length", "vocab_size"]: 24 | if key in state_dict: 25 | del state_dict[key] 26 | 27 | #convert_weights(model) 28 | model.load_state_dict(state_dict) 29 | return model.eval() 30 | -------------------------------------------------------------------------------- /dataset/generate_dataset_json/br35.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from __init__ import DATASETS_ROOT 5 | 6 | class Br35Solver(object): 7 | CLSNAMES = ['brain'] 8 | 9 | def __init__(self, root='data/mvtec'): 10 | self.root = root 11 | self.meta_path = f'{root}/meta.json' 12 | 13 | def run(self): 14 | info = dict(train={}, test={}) 15 | for cls_name in self.CLSNAMES: 16 | cls_dir = f'{self.root}/{cls_name}' 17 | for phase in ['test']: 18 | cls_info = [] 19 | species = os.listdir(f'{cls_dir}') 20 | for specie in species: 21 | is_abnormal = True if specie not in ['no'] else False 22 | img_names = os.listdir(f'{cls_dir}/{specie}') 23 | img_names.sort() 24 | for idx, img_name in enumerate(img_names): 25 | info_img = dict( 26 | img_path=f'{cls_dir}/{specie}/{img_name}', 27 | cls_name=cls_name, 28 | mask_path="", 29 | specie_name=specie, 30 | anomaly=1 if is_abnormal else 0, 31 | ) 32 | cls_info.append(info_img) 33 | info[phase][cls_name] = cls_info 34 | with open(self.meta_path, 'w') as f: 35 | f.write(json.dumps(info, indent=4) + "\n") 36 | 37 | if __name__ == '__main__': 38 | runner = Br35Solver(root=f'{DATASETS_ROOT}/br35h') 39 | runner.run() 40 | -------------------------------------------------------------------------------- /dataset/generate_dataset_json/brainmri.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from __init__ import DATASETS_ROOT 5 | 6 | class IsbiSolver(object): 7 | CLSNAMES = ['brain'] 8 | 9 | def __init__(self, 
root='data/mvtec'): 10 | self.root = root 11 | self.meta_path = f'{root}/meta.json' 12 | 13 | def run(self): 14 | info = dict(train={}, test={}) 15 | for cls_name in self.CLSNAMES: 16 | cls_dir = f'{self.root}/brain_tumor_dataset' 17 | for phase in ['test']: 18 | cls_info = [] 19 | species = os.listdir(f'{cls_dir}') 20 | for specie in species: 21 | is_abnormal = True if specie not in ['no'] else False 22 | img_names = os.listdir(f'{cls_dir}/{specie}') 23 | img_names.sort() 24 | for idx, img_name in enumerate(img_names): 25 | info_img = dict( 26 | img_path=f'{cls_dir}/{specie}/{img_name}', 27 | cls_name=cls_name, 28 | mask_path="", 29 | specie_name=specie, 30 | anomaly=1 if is_abnormal else 0, 31 | ) 32 | cls_info.append(info_img) 33 | info[phase][cls_name] = cls_info 34 | with open(self.meta_path, 'w') as f: 35 | f.write(json.dumps(info, indent=4) + "\n") 36 | 37 | if __name__ == '__main__': 38 | runner = IsbiSolver(root=f'{DATASETS_ROOT}/brainmri') 39 | runner.run() 40 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | # To test the current checkpoints (reported in paper) 2 | device=$1 3 | 4 | run_for_trained_on_mvtec() { 5 | local base_command="$1" 6 | shift 7 | local datasets=("$@") 8 | 9 | for dataset in "${datasets[@]}"; do 10 | local command="$base_command --dataset $dataset --model_name trained_on_mvtec_$cur_model_name" 11 | eval "$command" 12 | done 13 | } 14 | 15 | # Table 1 Training Scheme 16 | # Crane 17 | cur_model_name="crane" 18 | base_command="python test.py --devices $device --epoch 5 --dino_model none --soft_mean True --features_list 6 12 18 24 --visualize False" 19 | eval "$base_command --dataset mvtec --model_name trained_on_visa_$cur_model_name" 20 | run_for_trained_on_mvtec "$base_command" visa mpdd sdd btad dtd dagm 21 | run_for_trained_on_mvtec "$base_command" brainmri headct br35h isic tn3k cvc-colondb cvc-clinicdb 22 | 23 | # Table 1 Training Scheme 24 | # Crane+ (with D-Atten) 25 | cur_model_name="cranep" 26 | # MVTec 27 | base_command="python test.py --devices $device --epoch 5 --dino_model dinov2 --features_list 24 --visualize False" 28 | eval "$base_command --dataset mvtec --model_name trained_on_visa_$cur_model_name" 29 | # visa mpdd sdd btad dtd 30 | run_for_trained_on_mvtec "$base_command" visa mpdd sdd btad dtd 31 | eval "$base_command --dataset dagm --soft_mean True --model_name trained_on_mvtec_$cur_model_name" 32 | # DAGM 33 | base_command="python test.py --devices $device --epoch 1 --dino_model dinov2 --soft_mean True --features_list 24 --visualize False" 34 | # Medicals 35 | run_for_trained_on_mvtec "$base_command" brainmri headct br35h isic tn3k cvc-colondb cvc-clinicdb 36 | -------------------------------------------------------------------------------- /segment_anything/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /dataset/generate_dataset_json/DTD.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import os 4 | import json 5 | 6 | from __init__ import DATASETS_ROOT 7 | 8 | class DTDSolver(object): 9 | CLSNAMES = ['Woven_001', 'Woven_127', 'Woven_104', 'Stratified_154', 'Blotchy_099', 'Woven_068', 'Woven_125', 'Marbled_078', 'Perforated_037', 'Mesh_114', 'Fibrous_183', 'Matted_069'] 10 | 11 | def __init__(self, root='data/mvtec'): 12 | self.root = root 13 | self.meta_path = f'{root}/meta.json' 14 | 15 | def run(self): 16 | info = dict(train={}, test={}) 17 | for cls_name in self.CLSNAMES: 18 | cls_dir = f'{self.root}/{cls_name}' 19 | for phase in ['train', 'test']: 20 | cls_info = [] 21 | species = os.listdir(f'{cls_dir}/{phase}') 22 | for specie in species: 23 | is_abnormal = True if specie not in ['good'] else False 24 | img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') 25 | mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None 26 | img_names.sort() 27 | mask_names.sort() if mask_names is not None else None 28 | for idx, img_name in enumerate(img_names): 29 | info_img = dict( 30 | img_path=f'{cls_name}/{phase}/{specie}/{img_name}', 31 | mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '', 32 | cls_name=cls_name, 33 | specie_name=specie, 34 | anomaly=1 if is_abnormal else 0, 35 | ) 36 | cls_info.append(info_img) 37 | info[phase][cls_name] = cls_info 38 | with open(self.meta_path, 'w') as f: 39 | f.write(json.dumps(info, indent=4) + "\n") 40 | 41 | 42 | if __name__ == '__main__': 43 | runner = DTDSolver(root=f'{DATASETS_ROOT}/dtd/') 44 | runner.run() 45 | -------------------------------------------------------------------------------- /dataset/generate_dataset_json/colonDB.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pandas as pd 4 | 5 | from __init__ import DATASETS_ROOT 6 | 7 | class ClinicDBSolver(object): 8 | CLSNAMES = [ 9 | 'colon', 10 | ] 11 | 12 | def __init__(self, root='data/mvtec'): 13 | self.root = root 14 | self.meta_path = 
f'{root}/meta.json' 15 | 16 | def run(self): 17 | info = dict(train={}, test={}) 18 | anomaly_samples = 0 19 | normal_samples = 0 20 | for cls_name in self.CLSNAMES: 21 | cls_dir = f'{self.root}' 22 | for phase in ['test']: 23 | cls_info = [] 24 | # is_abnormal = True if specie not in ['good'] else False 25 | img_names = os.listdir(f'{cls_dir}/images') 26 | mask_names = os.listdir(f'{cls_dir}/masks') 27 | img_names.sort() 28 | mask_names.sort() if mask_names is not None else None 29 | for idx, img_name in enumerate(img_names): 30 | info_img = dict( 31 | img_path=f'{cls_dir}/images/{img_name}', 32 | mask_path=f'{cls_dir}/masks/{mask_names[idx]}', 33 | cls_name=cls_name, 34 | specie_name='', 35 | anomaly=1 36 | ) 37 | cls_info.append(info_img) 38 | if phase == 'test': 39 | if True: 40 | anomaly_samples = anomaly_samples + 1 41 | else: 42 | normal_samples = normal_samples + 1 43 | info[phase][cls_name] = cls_info 44 | with open(self.meta_path, 'w') as f: 45 | f.write(json.dumps(info, indent=4) + "\n") 46 | print('normal_samples', normal_samples, 'anomaly_samples', anomaly_samples) 47 | 48 | 49 | if __name__ == '__main__': 50 | runner = ClinicDBSolver(root=f'{DATASETS_ROOT}/cvc-colondb') 51 | runner.run() 52 | -------------------------------------------------------------------------------- /dataset/generate_dataset_json/clinicDB.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pandas as pd 4 | 5 | from __init__ import DATASETS_ROOT 6 | 7 | class ClinicDBSolver(object): 8 | CLSNAMES = [ 9 | 'colon', 10 | ] 11 | 12 | def __init__(self, root='data/mvtec'): 13 | self.root = root 14 | self.meta_path = f'{root}/meta.json' 15 | 16 | def run(self): 17 | info = dict(train={}, test={}) 18 | anomaly_samples = 0 19 | normal_samples = 0 20 | for cls_name in self.CLSNAMES: 21 | cls_dir = f'{self.root}' 22 | for phase in ['test']: 23 | cls_info = [] 24 | # is_abnormal = True if specie not in ['good'] else False 25 | img_names = os.listdir(f'{cls_dir}/images') 26 | mask_names = os.listdir(f'{cls_dir}/masks') 27 | img_names.sort() 28 | mask_names.sort() if mask_names is not None else None 29 | for idx, img_name in enumerate(img_names): 30 | info_img = dict( 31 | img_path=f'{cls_dir}/images/{img_name}', 32 | mask_path=f'{cls_dir}/masks/{mask_names[idx]}', 33 | cls_name=cls_name, 34 | specie_name='', 35 | anomaly=1 36 | ) 37 | cls_info.append(info_img) 38 | if phase == 'test': 39 | if True: 40 | anomaly_samples = anomaly_samples + 1 41 | else: 42 | normal_samples = normal_samples + 1 43 | info[phase][cls_name] = cls_info 44 | with open(self.meta_path, 'w') as f: 45 | f.write(json.dumps(info, indent=4) + "\n") 46 | print('normal_samples', normal_samples, 'anomaly_samples', anomaly_samples) 47 | 48 | 49 | if __name__ == '__main__': 50 | runner = ClinicDBSolver(root=f'{DATASETS_ROOT}/cvc-clinicdb') 51 | runner.run() 52 | -------------------------------------------------------------------------------- /dataset/generate_dataset_json/head_ct.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from __init__ import DATASETS_ROOT 5 | 6 | class MpddSolver(object): 7 | CLSNAMES = ['brain'] 8 | 9 | def __init__(self, root='data/mvtec'): 10 | self.root = root 11 | self.meta_path = f'{root}/meta.json' 12 | 13 | def run(self): 14 | info = dict(train={}, test={}) 15 | anomaly_samples = 0 16 | normal_samples = 0 17 | for cls_name in self.CLSNAMES: 18 | cls_dir = 
f'{self.root}/{cls_name}' 19 | for phase in ['test']: 20 | cls_info = [] 21 | species = os.listdir(f'{cls_dir}/{phase}') 22 | for specie in species: 23 | is_abnormal = True if specie not in ['good'] else False 24 | img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') 25 | 26 | img_names.sort() 27 | 28 | for idx, img_name in enumerate(img_names): 29 | info_img = dict( 30 | img_path=f'{cls_name}/{phase}/{specie}/{img_name}', 31 | mask_path="", 32 | cls_name=cls_name, 33 | specie_name=specie, 34 | anomaly=1 if is_abnormal else 0, 35 | ) 36 | cls_info.append(info_img) 37 | if phase == 'test': 38 | if is_abnormal: 39 | anomaly_samples = anomaly_samples + 1 40 | else: 41 | normal_samples = normal_samples + 1 42 | info[phase][cls_name] = cls_info 43 | with open(self.meta_path, 'w') as f: 44 | f.write(json.dumps(info, indent=4) + "\n") 45 | print('normal_samples', normal_samples, 'anomaly_samples', anomaly_samples) 46 | 47 | if __name__ == '__main__': 48 | runner = MpddSolver(root=f'{DATASETS_ROOT}/headct') 49 | runner.run() 50 | -------------------------------------------------------------------------------- /dataset/generate_dataset_json/tn3k.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pandas as pd 4 | 5 | from __init__ import DATASETS_ROOT 6 | 7 | class ClinicDBSolver(object): 8 | CLSNAMES = [ 9 | 'thyroid', 10 | ] 11 | 12 | def __init__(self, root='data/mvtec'): 13 | self.root = root 14 | self.meta_path = f'{root}/meta.json' 15 | 16 | def run(self): 17 | info = dict(train={}, test={}) 18 | anomaly_samples = 0 19 | normal_samples = 0 20 | for cls_name in self.CLSNAMES: 21 | cls_dir = f'{self.root}' 22 | for phase in ['test']: 23 | cls_info = [] 24 | # is_abnormal = True if specie not in ['good'] else False 25 | img_names = os.listdir(f'{cls_dir}/test-image') 26 | mask_names = os.listdir(f'{cls_dir}/test-mask') 27 | img_names.sort() 28 | mask_names.sort() if mask_names is not None else None 29 | for idx, img_name in enumerate(img_names): 30 | info_img = dict( 31 | img_path=f'{cls_dir}/test-image/{img_name}', 32 | mask_path=f'{cls_dir}/test-mask/{mask_names[idx]}', 33 | cls_name=cls_name, 34 | specie_name='', 35 | anomaly=1 36 | ) 37 | cls_info.append(info_img) 38 | if phase == 'test': 39 | if True: 40 | anomaly_samples = anomaly_samples + 1 41 | else: 42 | normal_samples = normal_samples + 1 43 | info[phase][cls_name] = cls_info 44 | with open(self.meta_path, 'w') as f: 45 | f.write(json.dumps(info, indent=4) + "\n") 46 | print('normal_samples', normal_samples, 'anomaly_samples', anomaly_samples) 47 | 48 | 49 | 50 | if __name__ == '__main__': 51 | runner = ClinicDBSolver(root=f'{DATASETS_ROOT}/tn3k') 52 | runner.run() 53 | -------------------------------------------------------------------------------- /utils/transform.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | import torchvision.transforms as transforms 3 | import torch 4 | 5 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) 6 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) 7 | 8 | IMAGENET_DATASET_MEAN = (0.485, 0.456, 0.406) 9 | IMAGENET_DATASET_STD = (0.229, 0.224, 0.225) 10 | 11 | def unnormalize(tensor, mean, std): 12 | mean = torch.tensor(mean).view(-1, 1, 1).to(tensor.device) 13 | std = torch.tensor(std).view(-1, 1, 1).to(tensor.device) 14 | unnormalized_tensor = tensor * std + mean 15 | return unnormalized_tensor.clamp(0, 1) 16 | 17 | def 
normalize(pred, max_value=None, min_value=None): 18 | if max_value is None or min_value is None: 19 | return (pred - pred.min()) / (pred.max() - pred.min()) 20 | else: 21 | return (pred - min_value) / (max_value - min_value) 22 | 23 | def _convert_to_rgb(image): 24 | return image.convert('RGB') 25 | 26 | def image_transform(image_size, mean, std): 27 | normalize = transforms.Normalize(mean=mean, std=std) 28 | tnsfrms = [ 29 | transforms.Resize(size=(image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC), 30 | # transforms.CenterCrop(image_size), # NOTE: No need for centercrop 31 | _convert_to_rgb, 32 | transforms.ToTensor(), 33 | normalize, 34 | ] 35 | return transforms.Compose(tnsfrms) 36 | 37 | def get_transform(args): 38 | input_transform = image_transform(args.image_size, mean = OPENAI_DATASET_MEAN, std = OPENAI_DATASET_STD) 39 | # input_transform.transforms[0] = transforms.Resize(size=(args.image_size, args.image_size), interpolation=transforms.InterpolationMode.BICUBIC, max_size=None, antialias=None) # NOTE: put antialias 40 | # input_transform.transforms[1] = transforms.CenterCrop(size=(args.image_size, args.image_size)) 41 | 42 | label_transform = transforms.Compose([ 43 | transforms.Resize((args.image_size, args.image_size)), 44 | # transforms.CenterCrop(args.image_size), # NOTE: No need for centercrop 45 | transforms.ToTensor() 46 | ]) 47 | 48 | return input_transform, label_transform 49 | -------------------------------------------------------------------------------- /dataset/generate_dataset_json/isic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from __init__ import DATASETS_ROOT 5 | 6 | class IsicSolver(object): 7 | CLSNAMES = ['skin'] 8 | 9 | def __init__(self, root='data/mvtec'): 10 | self.root = root 11 | self.meta_path = f'{root}/meta.json' 12 | 13 | def run(self): 14 | info = dict(train={}, test={}) 15 | anomaly_samples = 0 16 | normal_samples = 0 17 | for cls_name in self.CLSNAMES: 18 | cls_dir = f'{self.root}/{cls_name}' 19 | for phase in ['test']: 20 | cls_info = [] 21 | species = os.listdir(f'{cls_dir}/{phase}') 22 | for specie in species: 23 | is_abnormal = True if specie not in ['good'] else False 24 | img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') 25 | mask_names = os.listdir(f'{cls_dir}/ground_truth/') if is_abnormal else None 26 | img_names.sort() 27 | mask_names.sort() if mask_names is not None else None 28 | for idx, img_name in enumerate(img_names): 29 | info_img = dict( 30 | img_path=f'{cls_name}/{phase}/{specie}/{img_name}', 31 | mask_path=f'{cls_name}/ground_truth/{mask_names[idx]}' if is_abnormal else '', 32 | cls_name=cls_name, 33 | specie_name=specie, 34 | anomaly=1 if is_abnormal else 0, 35 | ) 36 | cls_info.append(info_img) 37 | if phase == 'test': 38 | if is_abnormal: 39 | anomaly_samples = anomaly_samples + 1 40 | else: 41 | normal_samples = normal_samples + 1 42 | info[phase][cls_name] = cls_info 43 | with open(self.meta_path, 'w') as f: 44 | f.write(json.dumps(info, indent=4) + "\n") 45 | print('normal_samples', normal_samples, 'anomaly_samples', anomaly_samples) 46 | 47 | if __name__ == '__main__': 48 | runner = IsicSolver(root=f'{DATASETS_ROOT}/isic') 49 | runner.run() 50 | -------------------------------------------------------------------------------- /dataset/generate_dataset_json/btad.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from __init__ import 
DATASETS_ROOT 5 | 6 | class BtadSolver(object): 7 | CLSNAMES = ['01', '02', '03'] 8 | 9 | def __init__(self, root='data/mvtec'): 10 | self.root = root 11 | self.meta_path = f'{root}/meta.json' 12 | 13 | def run(self): 14 | info = dict(train={}, test={}) 15 | anomaly_samples = 0 16 | normal_samples = 0 17 | for cls_name in self.CLSNAMES: 18 | cls_dir = f'{self.root}/{cls_name}' 19 | for phase in ['train', 'test']: 20 | cls_info = [] 21 | species = os.listdir(f'{cls_dir}/{phase}') 22 | for specie in species: 23 | is_abnormal = True if specie not in ['ok'] else False 24 | img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') 25 | mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None 26 | img_names.sort() 27 | mask_names.sort() if mask_names is not None else None 28 | for idx, img_name in enumerate(img_names): 29 | info_img = dict( 30 | img_path=f'{cls_name}/{phase}/{specie}/{img_name}', 31 | mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '', 32 | cls_name=cls_name, 33 | specie_name=specie, 34 | anomaly=1 if is_abnormal else 0, 35 | ) 36 | cls_info.append(info_img) 37 | if phase == 'test': 38 | if is_abnormal: 39 | anomaly_samples = anomaly_samples + 1 40 | else: 41 | normal_samples = normal_samples + 1 42 | info[phase][cls_name] = cls_info 43 | with open(self.meta_path, 'w') as f: 44 | f.write(json.dumps(info, indent=4) + "\n") 45 | print('normal_samples', normal_samples, 'anomaly_samples', anomaly_samples) 46 | 47 | if __name__ == '__main__': 48 | runner = BtadSolver(root=f'{DATASETS_ROOT}/btad/') 49 | runner.run() 50 | -------------------------------------------------------------------------------- /dataset/generate_dataset_json/visa.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import json 4 | import pandas as pd 5 | 6 | from __init__ import DATASETS_ROOT 7 | 8 | class VisASolver(object): 9 | CLSNAMES = [ 10 | 'candle', 'capsules', 'cashew', 'chewinggum', 'fryum', 11 | 'macaroni1', 'macaroni2', 'pcb1', 'pcb2', 'pcb3', 12 | 'pcb4', 'pipe_fryum', 13 | ] 14 | 15 | def __init__(self, root='data/visa'): 16 | self.root = root 17 | self.meta_path = f'{root}/meta.json' 18 | self.phases = ['train', 'test'] 19 | self.csv_data = pd.read_csv(f'{root}/split_csv/1cls.csv', header=0) 20 | 21 | def run(self): 22 | columns = self.csv_data.columns # [object, split, label, image, mask] 23 | info = {phase: {} for phase in self.phases} 24 | anomaly_samples = 0 25 | normal_samples = 0 26 | for cls_name in self.CLSNAMES: 27 | cls_data = self.csv_data[self.csv_data[columns[0]] == cls_name] 28 | for phase in self.phases: 29 | cls_info = [] 30 | cls_data_phase = cls_data[cls_data[columns[1]] == phase] 31 | cls_data_phase.index = list(range(len(cls_data_phase))) 32 | for idx in range(cls_data_phase.shape[0]): 33 | data = cls_data_phase.loc[idx] 34 | is_abnormal = True if data[2] == 'anomaly' else False 35 | info_img = dict( 36 | img_path=data[3], 37 | mask_path=data[4] if is_abnormal else '', 38 | cls_name=cls_name, 39 | specie_name='', 40 | anomaly=1 if is_abnormal else 0, 41 | ) 42 | cls_info.append(info_img) 43 | if phase == 'test': 44 | if is_abnormal: 45 | anomaly_samples = anomaly_samples + 1 46 | else: 47 | normal_samples = normal_samples + 1 48 | info[phase][cls_name] = cls_info 49 | with open(self.meta_path, 'w') as f: 50 | f.write(json.dumps(info, indent=4) + "\n") 51 | print('normal_samples', normal_samples, 'anomaly_samples', anomaly_samples) 52 | 53 | 54 | if __name__ == 
'__main__': 55 | runner = VisASolver(root=f'{DATASETS_ROOT}/visa') 56 | runner.run() 57 | -------------------------------------------------------------------------------- /dataset/generate_dataset_json/SDD.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import os 4 | import json 5 | 6 | from __init__ import DATASETS_ROOT 7 | 8 | class SDDSolver(object): 9 | CLSNAMES = [ 10 | 'electrical commutators', 11 | ] 12 | 13 | def __init__(self, root='data/mvtec'): 14 | self.root = root 15 | self.meta_path = f'{root}/meta.json' 16 | 17 | def run(self): 18 | info = dict(train={}, test={}) 19 | anomaly_samples = 0 20 | normal_samples = 0 21 | for cls_name in self.CLSNAMES: 22 | cls_dir = f'{self.root}/{cls_name}' 23 | for phase in ['train', 'test']: 24 | cls_info = [] 25 | species = os.listdir(f'{cls_dir}/{phase}') 26 | for specie in species: 27 | is_abnormal = True if specie not in ['good'] else False 28 | img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') 29 | mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None 30 | img_names.sort() 31 | mask_names.sort() if mask_names is not None else None 32 | for idx, img_name in enumerate(img_names): 33 | info_img = dict( 34 | img_path=f'{cls_name}/{phase}/{specie}/{img_name}', 35 | mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '', 36 | cls_name=cls_name, 37 | specie_name=specie, 38 | anomaly=1 if is_abnormal else 0, 39 | ) 40 | cls_info.append(info_img) 41 | if phase == 'test': 42 | if is_abnormal: 43 | anomaly_samples = anomaly_samples + 1 44 | else: 45 | normal_samples = normal_samples + 1 46 | info[phase][cls_name] = cls_info 47 | with open(self.meta_path, 'w') as f: 48 | f.write(json.dumps(info, indent=4) + "\n") 49 | print('normal_samples', normal_samples, 'anomaly_samples', anomaly_samples) 50 | 51 | 52 | if __name__ == '__main__': 53 | runner = SDDSolver(root=f'{DATASETS_ROOT}/sdd/') 54 | runner.run() 55 | -------------------------------------------------------------------------------- /dataset/generate_dataset_json/mpdd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from __init__ import DATASETS_ROOT 5 | 6 | class MpddSolver(object): 7 | CLSNAMES = ['bracket_black', 'bracket_brown', 'bracket_white', 'connector', 'metal_plate', 'tubes'] 8 | 9 | def __init__(self, root='data/mvtec'): 10 | self.root = root 11 | self.meta_path = f'{root}/meta.json' 12 | 13 | def run(self): 14 | info = dict(train={}, test={}) 15 | anomaly_samples = 0 16 | normal_samples = 0 17 | for cls_name in self.CLSNAMES: 18 | cls_dir = f'{self.root}/{cls_name}' 19 | for phase in ['train', 'test']: 20 | cls_info = [] 21 | species = os.listdir(f'{cls_dir}/{phase}') 22 | for specie in species: 23 | is_abnormal = True if specie not in ['good'] else False 24 | img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') 25 | mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None 26 | img_names.sort() 27 | mask_names.sort() if mask_names is not None else None 28 | for idx, img_name in enumerate(img_names): 29 | info_img = dict( 30 | img_path=f'{cls_name}/{phase}/{specie}/{img_name}', 31 | mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '', 32 | cls_name=cls_name, 33 | specie_name=specie, 34 | anomaly=1 if is_abnormal else 0, 35 | ) 36 | cls_info.append(info_img) 37 | if phase == 'test': 38 | if is_abnormal: 39 | anomaly_samples = 
anomaly_samples + 1 40 | else: 41 | normal_samples = normal_samples + 1 42 | info[phase][cls_name] = cls_info 43 | with open(self.meta_path, 'w') as f: 44 | f.write(json.dumps(info, indent=4) + "\n") 45 | print('normal_samples', normal_samples, 'anomaly_samples', anomaly_samples) 46 | 47 | if __name__ == '__main__': 48 | runner = MpddSolver(root=f'{DATASETS_ROOT}/mpdd') 49 | runner.run() 50 | -------------------------------------------------------------------------------- /dataset/generate_dataset_json/endoTect.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pandas as pd 4 | 5 | from __init__ import DATASETS_ROOT 6 | 7 | class HyperSolver(object): 8 | CLSNAMES = [ 9 | 'colon', 10 | ] 11 | 12 | def __init__(self, root='data/mvtec'): 13 | self.root = root 14 | self.meta_path = f'{root}/meta.json' 15 | 16 | def run(self): 17 | info = dict(train={}, test={}) 18 | anomaly_samples = 0 19 | normal_samples = 0 20 | for cls_name in self.CLSNAMES: 21 | cls_dir = f'{self.root}' 22 | for phase in ['test']: 23 | cls_info = [] 24 | species = set(os.listdir(f'{cls_dir}'))-set(['masks']) 25 | print("species", species) 26 | for specie in species: 27 | is_abnormal = True if specie not in ['good'] else False 28 | img_names = os.listdir(f'{cls_dir}/{specie}') 29 | mask_names = os.listdir(f'{cls_dir}/masks/') if is_abnormal else None 30 | img_names.sort() 31 | mask_names.sort() if mask_names is not None else None 32 | assert len(img_names) == len(mask_names) if mask_names is not None else True 33 | for idx, img_name in enumerate(img_names): 34 | info_img = dict( 35 | img_path=f'{cls_dir}/{specie}/{img_name}', 36 | mask_path=f'{cls_dir}/masks/{mask_names[idx]}' if is_abnormal else '', 37 | cls_name=cls_name, 38 | specie_name=specie, 39 | anomaly=1 if is_abnormal else 0, 40 | ) 41 | cls_info.append(info_img) 42 | if phase == 'test': 43 | if is_abnormal: 44 | anomaly_samples = anomaly_samples + 1 45 | else: 46 | normal_samples = normal_samples + 1 47 | info[phase][cls_name] = cls_info 48 | with open(self.meta_path, 'w') as f: 49 | f.write(json.dumps(info, indent=4) + "\n") 50 | print('normal_samples', normal_samples, 'anomaly_samples', anomaly_samples) 51 | 52 | 53 | 54 | if __name__ == '__main__': 55 | runner = HyperSolver(root=f'{DATASETS_ROOT}/endo') 56 | runner.run() 57 | -------------------------------------------------------------------------------- /reproduce.sh: -------------------------------------------------------------------------------- 1 | # First train a model on MVTec and a model on VisA 2 | # Then use model train on MVTec for testing on the rest of datasets and the model train on VisA to test on MVTec 3 | # for both version Crane and Crane+ 4 | model_name=$1 5 | device=$2 6 | 7 | run_for_trained_on_mvtec() { 8 | local base_command="$1" 9 | shift 10 | local datasets=("$@") 11 | 12 | for dataset in "${datasets[@]}"; do 13 | local command="$base_command --dataset $dataset --model_name trained_on_mvtec_$cur_model_name" 14 | eval "$command" 15 | done 16 | } 17 | 18 | # Table 1 Training Scheme 19 | # Crane (woD-Attn) 20 | cur_model_name="${model_name}_crane" 21 | echo "The name for base version (Crane) is: $cur_model_name" 22 | 23 | python train.py --model_name "$cur_model_name" --dataset mvtec --device "$device" --features_list 6 12 18 24 --dino_model none --why "Evaluation purpose" 24 | python train.py --model_name "$cur_model_name" --dataset visa --device "$device" --features_list 6 12 18 24 --dino_model none --why 
"Evaluation purpose" 25 | 26 | base_command="python test.py --devices $device --epoch 5 --dino_model none --soft_mean True --features_list 6 12 18 24 --visualize False" 27 | eval "$base_command --dataset mvtec --model_name trained_on_visa_$cur_model_name" 28 | run_for_trained_on_mvtec "$base_command" visa mpdd sdd btad dtd dagm 29 | run_for_trained_on_mvtec "$base_command" brainmri headct br35h isic tn3k cvc-colondb cvc-clinicdb 30 | 31 | # Table 1 Training Scheme 32 | # Crane+ 33 | cur_model_name="${model_name}_cranep" 34 | echo "The name for enhanced version (Crane+) is: $cur_model_name" 35 | 36 | python train.py --model_name "$cur_model_name" --dataset mvtec --device "$device" --features_list 24 --why "Evaluation purpose" 37 | python train.py --model_name "$cur_model_name" --dataset visa --device "$device" --features_list 24 --why "Evaluation purpose" 38 | 39 | base_command="python test.py --devices $device --epoch 5 --dino_model dinov2 --soft_mean True --features_list 24 --visualize False" 40 | eval "$base_command --dataset mvtec --model_name trained_on_visa_$cur_model_name" 41 | run_for_trained_on_mvtec "$base_command" visa mpdd sdd btad dtd 42 | eval "$base_command --dataset dagm --model_name trained_on_visa_$cur_model_name --soft_mean True " 43 | base_command="python test.py --devices $device --epoch 1 --dino_model dinov2 --soft_mean True --features_list 24 --visualize False" 44 | run_for_trained_on_mvtec "$base_command" brainmri headct br35h isic tn3k cvc-colondb cvc-clinicdb -------------------------------------------------------------------------------- /dataset/generate_dataset_json/mvtec.py: -------------------------------------------------------------------------------- 1 | 2 | # %load mvtec.py 3 | # %%writefile mvtec.py 4 | # Cell magic function starts with double %, and should be placed at the very first line of a cell 5 | 6 | import os 7 | import json 8 | 9 | from __init__ import DATASETS_ROOT 10 | 11 | class MVTecSolver(object): 12 | CLSNAMES = [ 13 | 'bottle', 'cable', 'capsule', 'carpet', 'grid', 14 | 'hazelnut', 'leather', 'metal_nut', 'pill', 'screw', 15 | 'tile', 'toothbrush', 'transistor', 'wood', 'zipper', 16 | ] 17 | 18 | def __init__(self, root='data/mvtec'): 19 | self.root = root 20 | self.meta_path = f'{root}/meta.json' 21 | 22 | def run(self): 23 | info = dict(train={}, test={}) 24 | anomaly_samples = 0 25 | normal_samples = 0 26 | for cls_name in self.CLSNAMES: 27 | cls_dir = f'{self.root}/{cls_name}' 28 | for phase in ['train', 'test']: 29 | cls_info = [] 30 | species = os.listdir(f'{cls_dir}/{phase}') 31 | for specie in species: 32 | is_abnormal = True if specie not in ['good'] else False 33 | img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') 34 | mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None 35 | img_names.sort() 36 | mask_names.sort() if mask_names is not None else None 37 | for idx, img_name in enumerate(img_names): 38 | info_img = dict( 39 | img_path=f'{cls_name}/{phase}/{specie}/{img_name}', 40 | mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '', 41 | cls_name=cls_name, 42 | specie_name=specie, 43 | anomaly=1 if is_abnormal else 0, 44 | ) 45 | cls_info.append(info_img) 46 | if phase == 'test': 47 | if is_abnormal: 48 | anomaly_samples = anomaly_samples + 1 49 | else: 50 | normal_samples = normal_samples + 1 51 | info[phase][cls_name] = cls_info 52 | with open(self.meta_path, 'w') as f: 53 | f.write(json.dumps(info, indent=4) + "\n") 54 | print('normal_samples', 
normal_samples, 'anomaly_samples', anomaly_samples) 55 | if __name__ == '__main__': 56 | runner = MVTecSolver(root=f'{DATASETS_ROOT}/mvtec') 57 | runner.run() 58 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | 2 | import logging 3 | import os 4 | import datetime 5 | import traceback 6 | 7 | def save_successful_run(args, file='run_log.txt'): 8 | run_log_path = os.path.join('./', file) 9 | os.makedirs('./', exist_ok=True) 10 | with open(run_log_path, 'a') as f: 11 | f.write("\n") 12 | f.write(f"Timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") 13 | f.write(f"Model Name: {args.dataset}_{args.model_name}\n") 14 | # f.write(f"Row Column: {args.row_column}\n") # Assuming args.row_column exists 15 | f.write("Run was successful.\n") 16 | 17 | def save_error_details_to_file(args, error): 18 | error_log_path = os.path.join(args.save_path, 'error_log.txt') 19 | os.makedirs(args.save_path, exist_ok=True) 20 | with open(error_log_path, 'a') as f: 21 | f.write("\n") 22 | f.write(f"Timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") 23 | f.write(f"Error: {str(error)}\n") 24 | f.write("Stack Trace:\n") 25 | traceback.print_exc(file=f) 26 | 27 | def save_args_to_file(args, command, log_dir=''): 28 | args_file_path = os.path.join(args.save_path, log_dir, 'args.txt') 29 | os.makedirs(os.path.dirname(args_file_path), exist_ok=True) 30 | if os.path.exists(args_file_path): 31 | print(f"Warning: The file {args_file_path} already exists and will be overwritten.") 32 | with open(args_file_path, 'a') as f: # Change 'w' to 'a' to append to the file 33 | f.write("\n") # Add new line before writing to the file 34 | f.write(f"Timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") # Add timestamp 35 | for arg, value in vars(args).items(): 36 | f.write(f"{arg}: {value}\n") 37 | f.write(f"Command arguments: {' '.join(command)}\n") # Add the command arguments to the file 38 | 39 | def get_logger(save_path): 40 | if not os.path.exists(save_path): 41 | os.makedirs(save_path) 42 | 43 | txt_path = os.path.join(save_path, 'log.txt') 44 | # logger 45 | root_logger = logging.getLogger() 46 | for handler in root_logger.handlers[:]: 47 | root_logger.removeHandler(handler) 48 | root_logger.setLevel(logging.WARNING) 49 | logger = logging.getLogger('test') 50 | formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', 51 | datefmt='%y-%m-%d %H:%M:%S') 52 | logger.setLevel(logging.INFO) 53 | file_handler = logging.FileHandler(txt_path, mode='a') 54 | file_handler.setFormatter(formatter) 55 | logger.addHandler(file_handler) 56 | console_handler = logging.StreamHandler() 57 | console_handler.setFormatter(formatter) 58 | logger.addHandler(console_handler) 59 | return logger -------------------------------------------------------------------------------- /dataset/generate_dataset_json/custom_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from pathlib import Path 4 | 5 | def generate_meta(dataset_root, image_exts=('.png', '.jpg', '.jpeg'), mask_exts=('.png', '.jpg', '.jpeg')): 6 | import os 7 | import json 8 | from pathlib import Path 9 | 10 | meta = {"train": {}, "test": {}} 11 | dataset_root = Path(dataset_root) 12 | 13 | def is_valid(file, exts): 14 | return file.suffix.lower() in exts 15 | 16 | # Train section 17 | train_dir = dataset_root 
/ "train" 18 | if train_dir.exists(): 19 | for cls_dir in train_dir.iterdir(): 20 | cls_name = cls_dir.name 21 | samples = [] 22 | good_dir = cls_dir / "good" 23 | if not good_dir.exists(): 24 | continue 25 | for img_path in sorted(good_dir.glob("*")): 26 | if not is_valid(img_path, image_exts): continue 27 | samples.append({ 28 | "img_path": str(img_path.relative_to(dataset_root)), 29 | "mask_path": "", 30 | "cls_name": cls_name, 31 | "specie_name": "good", 32 | "anomaly": 0 33 | }) 34 | meta["train"][cls_name] = samples 35 | 36 | # Test section 37 | test_dir = dataset_root / "test" 38 | if test_dir.exists(): 39 | for cls_dir in test_dir.iterdir(): 40 | cls_name = cls_dir.name 41 | samples = [] 42 | for specie_dir in cls_dir.iterdir(): 43 | if specie_dir.name == "masks" or not specie_dir.is_dir(): 44 | continue 45 | for img_path in sorted(specie_dir.glob("*")): 46 | if not is_valid(img_path, image_exts): continue 47 | anomaly = 0 if specie_dir.name.lower() == "good" else 1 48 | mask_path = "" 49 | if anomaly == 1: 50 | mask_dir = cls_dir / "masks" 51 | for ext in mask_exts: 52 | candidate = mask_dir / f"{img_path.stem}_mask{ext}" 53 | if candidate.exists(): 54 | mask_path = str(candidate.relative_to(dataset_root)) 55 | break 56 | samples.append({ 57 | "img_path": str(img_path.relative_to(dataset_root)), 58 | "mask_path": mask_path, 59 | "cls_name": cls_name, 60 | "specie_name": specie_dir.name, 61 | "anomaly": anomaly 62 | }) 63 | meta["test"][cls_name] = samples 64 | 65 | # Save 66 | out_path = dataset_root / "meta.json" 67 | with open(out_path, "w") as f: 68 | json.dump(meta, f, indent=4) 69 | print(f"✅ meta.json generated at {out_path}") 70 | -------------------------------------------------------------------------------- /dataset/generate_dataset_json/DAGM.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pandas as pd 4 | 5 | from __init__ import DATASETS_ROOT 6 | 7 | class DAGMSolver(object): 8 | CLSNAMES = [ 9 | 'Class1','Class2','Class3','Class4','Class5','Class6','Class7','Class8','Class9','Class10' 10 | ] 11 | 12 | def __init__(self, root='data/mvtec'): 13 | self.root = root 14 | self.meta_path = f'{root}/meta.json' 15 | 16 | def run(self): 17 | info = dict(train={}, test={}) 18 | anomaly_samples = 0 19 | normal_samples = 0 20 | for cls_name in self.CLSNAMES: 21 | cls_dir = f'{self.root}/{cls_name}' 22 | for phase in ['Train', 'Test']: 23 | cls_info = [] 24 | x, y, mask_names_none= [], [], [] 25 | img_dir = os.listdir(f'{cls_dir}/{phase}') 26 | 27 | mask_names = os.listdir(f'{cls_dir}/{phase}/Label') 28 | 29 | img_fpath_list = sorted([f 30 | for f in img_dir 31 | if f.endswith('.PNG')]) 32 | gt_fpath_list = sorted([f 33 | for f in mask_names 34 | if f.endswith('.PNG')]) 35 | 36 | img_exclude_list = [f.split("_")[0] + ".PNG" for f in gt_fpath_list] 37 | 38 | img_normal_fpath_list = list(set(img_fpath_list) - set(img_exclude_list)) 39 | 40 | x.extend(img_normal_fpath_list + img_exclude_list) 41 | 42 | y.extend([0] * len(img_normal_fpath_list) + [1]* len(img_exclude_list)) 43 | 44 | mask_names_none.extend([None] * len(img_normal_fpath_list) + gt_fpath_list) 45 | 46 | for idx, img_name in enumerate(x): 47 | info_img = dict( 48 | img_path=f'{cls_name}/{phase}/{img_name}', 49 | mask_path=f'{cls_name}/{phase}/Label/{mask_names_none[idx]}' if mask_names_none[idx] != None else '', 50 | cls_name=cls_name, 51 | specie_name='', 52 | anomaly=1 if y[idx] == 1 else 0, 53 | ) 54 | cls_info.append(info_img) 55 | if phase 
== 'Test': 56 | if y[idx] == 1: 57 | anomaly_samples = anomaly_samples + 1 58 | else: 59 | normal_samples = normal_samples + 1 60 | info[phase.lower()][cls_name] = cls_info 61 | with open(self.meta_path, 'w') as f: 62 | f.write(json.dumps(info, indent=4) + "\n") 63 | print('normal_samples', normal_samples, 'anomaly_samples', anomaly_samples) 64 | 65 | 66 | 67 | if __name__ == '__main__': 68 | runner = DAGMSolver(root=f'{DATASETS_ROOT}/dagm') 69 | runner.run() 70 | -------------------------------------------------------------------------------- /utils/visualization.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | from utils.transform import normalize 4 | import numpy as np 5 | import torch 6 | 7 | def normalize(pred, max_value=None, min_value=None): 8 | if max_value is None or min_value is None: 9 | return (pred - pred.min()) / (pred.max() - pred.min()) 10 | else: 11 | return (pred - min_value) / (max_value - min_value) 12 | 13 | def denormalize(tensor, mean, std): 14 | # Convert mean and std to tensors for broadcasting 15 | mean = torch.tensor(mean, device=tensor.device).view(-1, 1, 1) 16 | std = torch.tensor(std, device=tensor.device).view(-1, 1, 1) 17 | 18 | # Denormalize the tensor: (value * std) + mean 19 | denormalized_tensor = tensor * std + mean 20 | return denormalized_tensor 21 | 22 | def apply_ad_scoremap(image, scoremap, alpha=0.5): 23 | np_image = np.asarray(image, dtype=float) 24 | # Convert scoremap from a PyTorch tensor to a NumPy array 25 | if isinstance(scoremap, torch.Tensor): 26 | scoremap = scoremap.detach().cpu().numpy() # Convert tensor to NumPy array 27 | scoremap = (scoremap * 255).astype(np.uint8) 28 | scoremap = cv2.applyColorMap(scoremap, cv2.COLORMAP_JET) 29 | scoremap = cv2.cvtColor(scoremap, cv2.COLOR_BGR2RGB) 30 | return (alpha * np_image + (1 - alpha) * scoremap).astype(np.uint8) 31 | 32 | def visualizer(pathes, anomaly_map, masks, img_size, cls_name, save_path='./vis_img/', draw_contours=True): 33 | for idx, path in enumerate(pathes): 34 | cls = path.split('/')[-2] 35 | filename = path.split('/')[-1] 36 | 37 | # Save the final visualization 38 | save_vis = os.path.join(save_path, 'imgs', str(cls_name), str(cls)) 39 | os.makedirs(save_vis, exist_ok=True) 40 | 41 | # Load original image and resize 42 | vis = cv2.cvtColor(cv2.resize(cv2.imread(path), (img_size, img_size)), cv2.COLOR_BGR2RGB) 43 | filename_orig = filename.split('.')[0] + "_orig." + filename.split('.')[-1] # Append '_orig' 44 | cv2.imwrite(os.path.join(save_vis, filename_orig), vis) 45 | 46 | # Use the provided mask (it's guaranteed to be available) 47 | gt_mask = masks[idx].detach().cpu().numpy() 48 | gt_mask = (gt_mask > 0).astype(np.uint8) * 255 # Convert to binary (0 or 255) 49 | 50 | # Normalize and apply anomaly map 51 | mask = normalize(anomaly_map[idx]) 52 | vis = apply_ad_scoremap(vis, mask) 53 | 54 | # Convert back to BGR for OpenCV 55 | vis = cv2.cvtColor(vis, cv2.COLOR_RGB2BGR) 56 | filename_pred = filename.split('.')[0] + "_pred." + filename.split('.')[-1] # Append '_pred' 57 | cv2.imwrite(os.path.join(save_vis, filename_pred), vis) 58 | 59 | # Find and overlay contours (only if draw_contours is True) 60 | if draw_contours: 61 | filename_ctr = filename.split('.')[0] + "_cntr." 
+ filename.split('.')[-1] # Append '_cntr' before file extension 62 | contours, _ = cv2.findContours(gt_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 63 | cv2.drawContours(vis, contours, -1, (120, 251, 120), 2) # Pale green contours 64 | cv2.imwrite(os.path.join(save_vis, filename_ctr), vis) 65 | -------------------------------------------------------------------------------- /segment_anything/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | from functools import partial 10 | 11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer 12 | 13 | 14 | def build_sam_vit_h(checkpoint=None): 15 | return _build_sam( 16 | encoder_embed_dim=1280, 17 | encoder_depth=32, 18 | encoder_num_heads=16, 19 | encoder_global_attn_indexes=[7, 15, 23, 31], 20 | checkpoint=checkpoint, 21 | ) 22 | 23 | 24 | build_sam = build_sam_vit_h 25 | 26 | 27 | def build_sam_vit_l(checkpoint=None): 28 | return _build_sam( 29 | encoder_embed_dim=1024, 30 | encoder_depth=24, 31 | encoder_num_heads=16, 32 | encoder_global_attn_indexes=[5, 11, 17, 23], 33 | checkpoint=checkpoint, 34 | ) 35 | 36 | 37 | def build_sam_vit_b(checkpoint=None): 38 | return _build_sam( 39 | encoder_embed_dim=768, 40 | encoder_depth=12, 41 | encoder_num_heads=12, 42 | encoder_global_attn_indexes=[2, 5, 8, 11], 43 | checkpoint=checkpoint, 44 | ) 45 | 46 | 47 | sam_model_registry = { 48 | "default": build_sam_vit_h, 49 | "vit_h": build_sam_vit_h, 50 | "vit_l": build_sam_vit_l, 51 | "vit_b": build_sam_vit_b, 52 | } 53 | 54 | 55 | def _build_sam( 56 | encoder_embed_dim, 57 | encoder_depth, 58 | encoder_num_heads, 59 | encoder_global_attn_indexes, 60 | checkpoint=None, 61 | ): 62 | prompt_embed_dim = 256 63 | image_size = 1024 64 | vit_patch_size = 16 65 | image_embedding_size = image_size // vit_patch_size 66 | sam = Sam( 67 | image_encoder=ImageEncoderViT( 68 | depth=encoder_depth, 69 | embed_dim=encoder_embed_dim, 70 | img_size=image_size, 71 | mlp_ratio=4, 72 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 73 | num_heads=encoder_num_heads, 74 | patch_size=vit_patch_size, 75 | qkv_bias=True, 76 | use_rel_pos=True, 77 | global_attn_indexes=encoder_global_attn_indexes, 78 | window_size=14, 79 | out_chans=prompt_embed_dim, 80 | ), 81 | prompt_encoder=PromptEncoder( 82 | embed_dim=prompt_embed_dim, 83 | image_embedding_size=(image_embedding_size, image_embedding_size), 84 | input_image_size=(image_size, image_size), 85 | mask_in_chans=16, 86 | ), 87 | mask_decoder=MaskDecoder( 88 | num_multimask_outputs=3, 89 | transformer=TwoWayTransformer( 90 | depth=2, 91 | embedding_dim=prompt_embed_dim, 92 | mlp_dim=2048, 93 | num_heads=8, 94 | ), 95 | transformer_dim=prompt_embed_dim, 96 | iou_head_depth=3, 97 | iou_head_hidden_dim=256, 98 | ), 99 | pixel_mean=[123.675, 116.28, 103.53], 100 | pixel_std=[58.395, 57.12, 57.375], 101 | ) 102 | sam.eval() 103 | if checkpoint is not None: 104 | with open(checkpoint, "rb") as f: 105 | state_dict = torch.load(f) 106 | sam.load_state_dict(state_dict) 107 | return sam 108 | -------------------------------------------------------------------------------- /segment_anything/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta 
Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to the longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 
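For example, with original_size=(480, 640) and a target_length of 1024, the box [0, 0, 640, 480] is rescaled to [0, 0, 1024, 768].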
89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length. 97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | -------------------------------------------------------------------------------- /models/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /segment_anything/utils/onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Tuple 12 | 13 | from ..modeling import Sam 14 | from .amg import calculate_stability_score 15 | 16 | 17 | class SamOnnxModel(nn.Module): 18 | """ 19 | This model should not be called directly, but is used in ONNX export. 20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, 21 | with some functions modified to enable model tracing. Also supports extra 22 | options controlling what information. See the ONNX export script for details. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | model: Sam, 28 | return_single_mask: bool, 29 | use_stability_score: bool = False, 30 | return_extra_metrics: bool = False, 31 | ) -> None: 32 | super().__init__() 33 | self.mask_decoder = model.mask_decoder 34 | self.model = model 35 | self.img_size = model.image_encoder.img_size 36 | self.return_single_mask = return_single_mask 37 | self.use_stability_score = use_stability_score 38 | self.stability_score_offset = 1.0 39 | self.return_extra_metrics = return_extra_metrics 40 | 41 | @staticmethod 42 | def resize_longest_image_size( 43 | input_image_size: torch.Tensor, longest_side: int 44 | ) -> torch.Tensor: 45 | input_image_size = input_image_size.to(torch.float32) 46 | scale = longest_side / torch.max(input_image_size) 47 | transformed_size = scale * input_image_size 48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) 49 | return transformed_size 50 | 51 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: 52 | point_coords = point_coords + 0.5 53 | point_coords = point_coords / self.img_size 54 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) 55 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) 56 | 57 | point_embedding = point_embedding * (point_labels != -1) 58 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( 59 | point_labels == -1 60 | ) 61 | 62 | for i in range(self.model.prompt_encoder.num_point_embeddings): 63 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ 64 | i 65 | ].weight * (point_labels == i) 66 | 67 | return point_embedding 68 | 69 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: 70 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) 71 | mask_embedding = mask_embedding + ( 72 | 1 - has_mask_input 73 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) 74 | return mask_embedding 75 | 76 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: 77 | masks = F.interpolate( 78 | masks, 79 | size=(self.img_size, self.img_size), 80 | mode="bilinear", 81 | align_corners=False, 82 | ) 83 | 84 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) 85 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore 86 | 87 | orig_im_size = orig_im_size.to(torch.int64) 88 | h, w = orig_im_size[0], orig_im_size[1] 89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 90 | return masks 91 | 92 | def select_masks( 93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int 94 | ) -> Tuple[torch.Tensor, torch.Tensor]: 95 | # Determine if we should return the multiclick mask or not from the number of points. 
96 | # The reweighting is used to avoid control flow. 97 | score_reweight = torch.tensor( 98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] 99 | ).to(iou_preds.device) 100 | score = iou_preds + (num_points - 2.5) * score_reweight 101 | best_idx = torch.argmax(score, dim=1) 102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) 103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) 104 | 105 | return masks, iou_preds 106 | 107 | @torch.no_grad() 108 | def forward( 109 | self, 110 | image_embeddings: torch.Tensor, 111 | point_coords: torch.Tensor, 112 | point_labels: torch.Tensor, 113 | mask_input: torch.Tensor, 114 | has_mask_input: torch.Tensor, 115 | orig_im_size: torch.Tensor, 116 | ): 117 | sparse_embedding = self._embed_points(point_coords, point_labels) 118 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 119 | 120 | masks, scores = self.model.mask_decoder.predict_masks( 121 | image_embeddings=image_embeddings, 122 | image_pe=self.model.prompt_encoder.get_dense_pe(), 123 | sparse_prompt_embeddings=sparse_embedding, 124 | dense_prompt_embeddings=dense_embedding, 125 | ) 126 | 127 | if self.use_stability_score: 128 | scores = calculate_stability_score( 129 | masks, self.model.mask_threshold, self.stability_score_offset 130 | ) 131 | 132 | if self.return_single_mask: 133 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) 134 | 135 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size) 136 | 137 | if self.return_extra_metrics: 138 | stability_scores = calculate_stability_score( 139 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset 140 | ) 141 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 142 | return upscaled_masks, scores, stability_scores, areas, masks 143 | 144 | return upscaled_masks, scores, masks 145 | -------------------------------------------------------------------------------- /segment_anything/modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import List, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class MaskDecoder(nn.Module): 17 | def __init__( 18 | self, 19 | *, 20 | transformer_dim: int, 21 | transformer: nn.Module, 22 | num_multimask_outputs: int = 3, 23 | activation: Type[nn.Module] = nn.GELU, 24 | iou_head_depth: int = 3, 25 | iou_head_hidden_dim: int = 256, 26 | ) -> None: 27 | """ 28 | Predicts masks given an image and prompt embeddings, using a 29 | transformer architecture. 
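Internally it predicts num_multimask_outputs + 1 candidate masks plus an IoU quality estimate for each, and forward() selects either the single-mask or the multi-mask slice of that output.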
30 | 31 | Arguments: 32 | transformer_dim (int): the channel dimension of the transformer 33 | transformer (nn.Module): the transformer used to predict masks 34 | num_multimask_outputs (int): the number of masks to predict 35 | when disambiguating masks 36 | activation (nn.Module): the type of activation to use when 37 | upscaling masks 38 | iou_head_depth (int): the depth of the MLP used to predict 39 | mask quality 40 | iou_head_hidden_dim (int): the hidden dimension of the MLP 41 | used to predict mask quality 42 | """ 43 | super().__init__() 44 | self.transformer_dim = transformer_dim 45 | self.transformer = transformer 46 | 47 | self.num_multimask_outputs = num_multimask_outputs 48 | 49 | self.iou_token = nn.Embedding(1, transformer_dim) 50 | self.num_mask_tokens = num_multimask_outputs + 1 51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 52 | 53 | self.output_upscaling = nn.Sequential( 54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 55 | LayerNorm2d(transformer_dim // 4), 56 | activation(), 57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 58 | activation(), 59 | ) 60 | self.output_hypernetworks_mlps = nn.ModuleList( 61 | [ 62 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 63 | for i in range(self.num_mask_tokens) 64 | ] 65 | ) 66 | 67 | self.iou_prediction_head = MLP( 68 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 69 | ) 70 | 71 | def forward( 72 | self, 73 | image_embeddings: torch.Tensor, 74 | image_pe: torch.Tensor, 75 | sparse_prompt_embeddings: torch.Tensor, 76 | dense_prompt_embeddings: torch.Tensor, 77 | multimask_output: bool, 78 | ) -> Tuple[torch.Tensor, torch.Tensor]: 79 | """ 80 | Predict masks given image and prompt embeddings. 81 | 82 | Arguments: 83 | image_embeddings (torch.Tensor): the embeddings from the image encoder 84 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 85 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 86 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 87 | multimask_output (bool): Whether to return multiple masks or a single 88 | mask. 89 | 90 | Returns: 91 | torch.Tensor: batched predicted masks 92 | torch.Tensor: batched predictions of mask quality 93 | """ 94 | masks, iou_pred = self.predict_masks( 95 | image_embeddings=image_embeddings, 96 | image_pe=image_pe, 97 | sparse_prompt_embeddings=sparse_prompt_embeddings, 98 | dense_prompt_embeddings=dense_prompt_embeddings, 99 | ) 100 | 101 | # Select the correct mask or masks for output 102 | if multimask_output: 103 | mask_slice = slice(1, None) 104 | else: 105 | mask_slice = slice(0, 1) 106 | masks = masks[:, mask_slice, :, :] 107 | iou_pred = iou_pred[:, mask_slice] 108 | 109 | # Prepare output 110 | return masks, iou_pred 111 | 112 | def predict_masks( 113 | self, 114 | image_embeddings: torch.Tensor, 115 | image_pe: torch.Tensor, 116 | sparse_prompt_embeddings: torch.Tensor, 117 | dense_prompt_embeddings: torch.Tensor, 118 | ) -> Tuple[torch.Tensor, torch.Tensor]: 119 | """Predicts masks. 
See 'forward' for more details.""" 120 | # Concatenate output tokens 121 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) 122 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) 123 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 124 | 125 | # Expand per-image data in batch direction to be per-mask 126 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 127 | src = src + dense_prompt_embeddings 128 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 129 | b, c, h, w = src.shape 130 | 131 | # Run the transformer 132 | hs, src = self.transformer(src, pos_src, tokens) 133 | iou_token_out = hs[:, 0, :] 134 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 135 | 136 | # Upscale mask embeddings and predict masks using the mask tokens 137 | src = src.transpose(1, 2).view(b, c, h, w) 138 | upscaled_embedding = self.output_upscaling(src) 139 | hyper_in_list: List[torch.Tensor] = [] 140 | for i in range(self.num_mask_tokens): 141 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) 142 | hyper_in = torch.stack(hyper_in_list, dim=1) 143 | b, c, h, w = upscaled_embedding.shape 144 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 145 | 146 | # Generate mask quality predictions 147 | iou_pred = self.iou_prediction_head(iou_token_out) 148 | 149 | return masks, iou_pred 150 | 151 | 152 | # Lightly adapted from 153 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 154 | class MLP(nn.Module): 155 | def __init__( 156 | self, 157 | input_dim: int, 158 | hidden_dim: int, 159 | output_dim: int, 160 | num_layers: int, 161 | sigmoid_output: bool = False, 162 | ) -> None: 163 | super().__init__() 164 | self.num_layers = num_layers 165 | h = [hidden_dim] * (num_layers - 1) 166 | self.layers = nn.ModuleList( 167 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 168 | ) 169 | self.sigmoid_output = sigmoid_output 170 | 171 | def forward(self, x): 172 | for i, layer in enumerate(self.layers): 173 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 174 | if self.sigmoid_output: 175 | x = F.sigmoid(x) 176 | return x 177 | -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from math import exp 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | class BinaryFocalLoss(nn.Module): 12 | def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'): 13 | super(BinaryFocalLoss, self).__init__() 14 | self.alpha = alpha 15 | self.gamma = gamma 16 | self.reduction = reduction 17 | 18 | def forward(self, inputs, targets): 19 | targets = targets.float() 20 | bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none') 21 | probas = torch.sigmoid(inputs) 22 | pt = probas * targets + (1 - probas) * (1 - targets) 23 | focal_factor = (1 - pt).pow(self.gamma) 24 | alpha_factor = self.alpha * targets + (1 - self.alpha) * (1 - targets) 25 | loss = alpha_factor * focal_factor * bce_loss 26 | 27 | # Apply reduction 28 | if self.reduction == 'mean': 29 | return loss.mean() 30 | elif self.reduction == 'sum': 31 | return loss.sum() 32 | else: 33 
| return loss 34 | 35 | # Example usage 36 | if __name__ == "__main__": 37 | # Sample predictions (logits) and ground truth 38 | inputs = torch.tensor([0.2, -1.3, 1.2, 0.5], dtype=torch.float32) # Raw model outputs (logits) 39 | targets = torch.tensor([1, 0, 1, 0], dtype=torch.float32) # Binary ground truth labels 40 | 41 | # Instantiate the BinaryFocalLoss class 42 | focal_loss_fn = BinaryFocalLoss(alpha=0.25, gamma=2.0, reduction='mean') 43 | 44 | # Compute the focal loss 45 | loss = focal_loss_fn(inputs, targets) 46 | 47 | print(f"Focal Loss: {loss.item()}") 48 | 49 | 50 | class FocalLoss(nn.Module): 51 | """ 52 | copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py 53 | This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in 54 | 'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)' 55 | Focal_Loss= -1*alpha*(1-pt)*log(pt) 56 | :param alpha: (tensor) 3D or 4D the scalar factor for this criterion 57 | :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more 58 | focus on hard misclassified example 59 | :param smooth: (float,double) smooth value when cross entropy 60 | :param balance_index: (int) balance class index, should be specific when alpha is float 61 | :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch. 62 | """ 63 | 64 | def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True): 65 | super(FocalLoss, self).__init__() 66 | self.apply_nonlin = apply_nonlin 67 | self.alpha = alpha 68 | self.gamma = gamma 69 | self.balance_index = balance_index 70 | self.smooth = smooth 71 | self.size_average = size_average 72 | 73 | if self.smooth is not None: 74 | if self.smooth < 0 or self.smooth > 1.0: 75 | raise ValueError('smooth value should be in [0,1]') 76 | 77 | def forward(self, logit, target): 78 | if self.apply_nonlin is not None: 79 | logit = self.apply_nonlin(logit) 80 | num_class = logit.shape[1] 81 | 82 | if logit.dim() > 2: 83 | # N,C,d1,d2 -> N,C,m (m=d1*d2*...) 
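            # e.g. segmentation logits of shape (N, C, H, W) end up as (N*H*W, C), so each spatial location is treated as one C-way classification sample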
84 | logit = logit.view(logit.size(0), logit.size(1), -1) 85 | logit = logit.permute(0, 2, 1).contiguous() 86 | logit = logit.view(-1, logit.size(-1)) 87 | target = torch.squeeze(target, 1) 88 | target = target.view(-1, 1) 89 | alpha = self.alpha 90 | 91 | if alpha is None: 92 | alpha = torch.ones(num_class, 1) 93 | elif isinstance(alpha, (list, np.ndarray)): 94 | assert len(alpha) == num_class 95 | alpha = torch.FloatTensor(alpha).view(num_class, 1) 96 | alpha = alpha / alpha.sum() 97 | elif isinstance(alpha, float): 98 | alpha = torch.ones(num_class, 1) 99 | alpha = alpha * (1 - self.alpha) 100 | alpha[self.balance_index] = self.alpha 101 | 102 | else: 103 | raise TypeError('Unsupported alpha type') 104 | 105 | if alpha.device != logit.device: 106 | alpha = alpha.to(logit.device) 107 | 108 | idx = target.cpu().long() 109 | 110 | one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_() 111 | one_hot_key = one_hot_key.scatter_(1, idx, 1) 112 | if one_hot_key.device != logit.device: 113 | one_hot_key = one_hot_key.to(logit.device) 114 | 115 | if self.smooth: 116 | one_hot_key = torch.clamp( 117 | one_hot_key, self.smooth / (num_class - 1), 1.0 - self.smooth) 118 | pt = (one_hot_key * logit).sum(1) + self.smooth 119 | logpt = pt.log() 120 | 121 | gamma = self.gamma 122 | 123 | alpha = alpha[idx] 124 | alpha = torch.squeeze(alpha) 125 | loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt 126 | 127 | if self.size_average: 128 | loss = loss.mean() 129 | return loss 130 | 131 | class BinaryFocalLoss(nn.Module): 132 | def __init__(self, alpha=0.25, gamma=2): 133 | super(BinaryFocalLoss, self).__init__() 134 | self.alpha = alpha 135 | self.gamma = gamma 136 | 137 | def forward(self, inputs, targets): 138 | probs = torch.sigmoid(inputs) 139 | probs = torch.clamp(probs, min=1e-6, max=1-1e-6) 140 | bce_loss = targets * torch.log(probs) + (1 - targets) * torch.log(1 - probs) 141 | focal_loss = -(self.alpha * (1 - probs) ** self.gamma * bce_loss) 142 | return torch.mean(focal_loss) 143 | 144 | class BinaryDiceLoss(nn.Module): 145 | def __init__(self): 146 | super(BinaryDiceLoss, self).__init__() 147 | 148 | def forward(self, input, targets): 149 | # Get the batch size N 150 | N = targets.size()[0] 151 | # Smoothing term to avoid division by zero 152 | smooth = 1 153 | # Flatten the spatial dimensions (H, W) into a single dimension 154 | input_flat = input.view(N, -1) 155 | targets_flat = targets.view(N, -1) 156 | 157 | intersection = input_flat * targets_flat 158 | N_dice_eff = (2 * intersection.sum(1) + smooth) / (input_flat.sum(1) + targets_flat.sum(1) + smooth) 159 | # Average the per-image loss over the batch 160 | loss = 1 - N_dice_eff.sum() / N 161 | return loss 162 | 163 | def smooth(arr, lamda1): 164 | new_array = arr 165 | arr2 = torch.zeros_like(arr) 166 | arr2[:, :-1, :] = arr[:, 1:, :] 167 | arr2[:, -1, :] = arr[:, -1, :] 168 | 169 | new_array2 = torch.zeros_like(new_array) 170 | new_array2[:, :, :-1] = new_array[:, :, 1:] 171 | new_array2[:, :, -1] = new_array[:, :, -1] 172 | loss = (torch.sum((arr2 - arr) ** 2) + torch.sum((new_array2 - new_array) ** 2)) / 2 173 | return lamda1 * loss 174 | 175 | def sparsity(arr, target, lamda2): 176 | if target == 0: 177 | loss = torch.mean(torch.norm(arr, dim=0)) 178 | else: 179 | loss = torch.mean(torch.norm(1-arr, dim=0)) 180 | return lamda2 * loss -------------------------------------------------------------------------------- /segment_anything/modeling/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Any, Dict, List, Tuple 12 | 13 | from .image_encoder import ImageEncoderViT 14 | from .mask_decoder import MaskDecoder 15 | from .prompt_encoder import PromptEncoder 16 | 17 | 18 | class Sam(nn.Module): 19 | mask_threshold: float = 0.0 20 | image_format: str = "RGB" 21 | 22 | def __init__( 23 | self, 24 | image_encoder: ImageEncoderViT, 25 | prompt_encoder: PromptEncoder, 26 | mask_decoder: MaskDecoder, 27 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 28 | pixel_std: List[float] = [58.395, 57.12, 57.375], 29 | ) -> None: 30 | """ 31 | SAM predicts object masks from an image and input prompts. 32 | 33 | Arguments: 34 | image_encoder (ImageEncoderViT): The backbone used to encode the 35 | image into image embeddings that allow for efficient mask prediction. 36 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 37 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 38 | and encoded prompts. 39 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 40 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 41 | """ 42 | super().__init__() 43 | self.image_encoder = image_encoder 44 | self.prompt_encoder = prompt_encoder 45 | self.mask_decoder = mask_decoder 46 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) 47 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 48 | 49 | @property 50 | def device(self) -> Any: 51 | return self.pixel_mean.device 52 | 53 | @torch.no_grad() 54 | def forward( 55 | self, 56 | batched_input: List[Dict[str, Any]], 57 | multimask_output: bool, 58 | ) -> List[Dict[str, torch.Tensor]]: 59 | """ 60 | Predicts masks end-to-end from provided images and prompts. 61 | If prompts are not known in advance, using SamPredictor is 62 | recommended over calling the model directly. 63 | 64 | Arguments: 65 | batched_input (list(dict)): A list over input images, each a 66 | dictionary with the following keys. A prompt key can be 67 | excluded if it is not present. 68 | 'image': The image as a torch tensor in 3xHxW format, 69 | already transformed for input to the model. 70 | 'original_size': (tuple(int, int)) The original size of 71 | the image before transformation, as (H, W). 72 | 'point_coords': (torch.Tensor) Batched point prompts for 73 | this image, with shape BxNx2. Already transformed to the 74 | input frame of the model. 75 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 76 | with shape BxN. 77 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 78 | Already transformed to the input frame of the model. 79 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 80 | in the form Bx1xHxW. 81 | multimask_output (bool): Whether the model should predict multiple 82 | disambiguating masks, or return a single mask. 83 | 84 | Returns: 85 | (list(dict)): A list over input images, where each element is 86 | as dictionary with the following keys. 87 | 'masks': (torch.Tensor) Batched binary mask predictions, 88 | with shape BxCxHxW, where B is the number of input prompts, 89 | C is determined by multimask_output, and (H, W) is the 90 | original size of the image. 
91 | 'iou_predictions': (torch.Tensor) The model's predictions 92 | of mask quality, in shape BxC. 93 | 'low_res_logits': (torch.Tensor) Low resolution logits with 94 | shape BxCxHxW, where H=W=256. Can be passed as mask input 95 | to subsequent iterations of prediction. 96 | """ 97 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) 98 | image_embeddings = self.image_encoder(input_images) 99 | 100 | outputs = [] 101 | for image_record, curr_embedding in zip(batched_input, image_embeddings): 102 | if "point_coords" in image_record: 103 | points = (image_record["point_coords"], image_record["point_labels"]) 104 | else: 105 | points = None 106 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 107 | points=points, 108 | boxes=image_record.get("boxes", None), 109 | masks=image_record.get("mask_inputs", None), 110 | ) 111 | low_res_masks, iou_predictions = self.mask_decoder( 112 | image_embeddings=curr_embedding.unsqueeze(0), 113 | image_pe=self.prompt_encoder.get_dense_pe(), 114 | sparse_prompt_embeddings=sparse_embeddings, 115 | dense_prompt_embeddings=dense_embeddings, 116 | multimask_output=multimask_output, 117 | ) 118 | masks = self.postprocess_masks( 119 | low_res_masks, 120 | input_size=image_record["image"].shape[-2:], 121 | original_size=image_record["original_size"], 122 | ) 123 | masks = masks > self.mask_threshold 124 | outputs.append( 125 | { 126 | "masks": masks, 127 | "iou_predictions": iou_predictions, 128 | "low_res_logits": low_res_masks, 129 | } 130 | ) 131 | return outputs 132 | 133 | def postprocess_masks( 134 | self, 135 | masks: torch.Tensor, 136 | input_size: Tuple[int, ...], 137 | original_size: Tuple[int, ...], 138 | ) -> torch.Tensor: 139 | """ 140 | Remove padding and upscale masks to the original image size. 141 | 142 | Arguments: 143 | masks (torch.Tensor): Batched masks from the mask_decoder, 144 | in BxCxHxW format. 145 | input_size (tuple(int, int)): The size of the image input to the 146 | model, in (H, W) format. Used to remove padding. 147 | original_size (tuple(int, int)): The original size of the image 148 | before resizing for input to the model, in (H, W) format. 149 | 150 | Returns: 151 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 152 | is given by original_size. 
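For example, a 480x640 image that was resized to 768x1024 and padded to 1024x1024 for the encoder has its masks cropped back to 768x1024 and then resized to 480x640 here.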
153 | """ 154 | masks = F.interpolate( 155 | masks, 156 | (self.image_encoder.img_size, self.image_encoder.img_size), 157 | mode="bilinear", 158 | align_corners=False, 159 | ) 160 | masks = masks[..., : input_size[0], : input_size[1]] 161 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 162 | return masks 163 | 164 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 165 | """Normalize pixel values and pad to a square input.""" 166 | # Normalize colors 167 | x = (x - self.pixel_mean) / self.pixel_std 168 | 169 | # Pad 170 | h, w = x.shape[-2:] 171 | padh = self.image_encoder.img_size - h 172 | padw = self.image_encoder.img_size - w 173 | x = F.pad(x, (0, padw, 0, padh)) 174 | return x 175 | -------------------------------------------------------------------------------- /models/model_load.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Union, List 6 | from pkg_resources import packaging 7 | 8 | import torch 9 | from PIL import Image 10 | from torchvision.transforms import Compose, Resize, ToTensor, Normalize 11 | from tqdm import tqdm 12 | import numpy as np 13 | 14 | from .build_model import build_model 15 | from torchvision.transforms import InterpolationMode 16 | 17 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 18 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 19 | 20 | # https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/pretrained.py#L343 21 | 22 | __all__ = ["available_models", "load"] 23 | 24 | _MODELS = { 25 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 26 | } 27 | 28 | def _download( 29 | url: str, 30 | cache_dir: Union[str, None] = None, 31 | ): 32 | 33 | if not cache_dir: 34 | cache_dir = os.path.expanduser("~/.cache/clip") 35 | os.makedirs(cache_dir, exist_ok=True) 36 | filename = os.path.basename(url) 37 | 38 | if 'openaipublic' in url: 39 | expected_sha256 = url.split("/")[-2] 40 | elif 'mlfoundations' in url: 41 | expected_sha256 = os.path.splitext(filename)[0].split("-")[-1] 42 | else: 43 | expected_sha256 = '' 44 | 45 | download_target = os.path.join(cache_dir, filename) 46 | 47 | if os.path.exists(download_target) and not os.path.isfile(download_target): 48 | raise RuntimeError(f"{download_target} exists and is not a regular file") 49 | 50 | if os.path.isfile(download_target): 51 | if expected_sha256: 52 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): 53 | return download_target 54 | else: 55 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 56 | else: 57 | return download_target 58 | 59 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 60 | with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 61 | while True: 62 | buffer = source.read(8192) 63 | if not buffer: 64 | break 65 | 66 | output.write(buffer) 67 | loop.update(len(buffer)) 68 | 69 | if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): 70 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 71 | 72 | return download_target 73 | 74 | 75 | def _convert_image_to_rgb(image): 
76 | return image.convert("RGB") 77 | 78 | 79 | def _transform(n_px): 80 | return Compose([ 81 | Resize((n_px, n_px), interpolation=InterpolationMode.BICUBIC), 82 | #CenterCrop(n_px), # rm center crop to explain whole image 83 | _convert_image_to_rgb, 84 | ToTensor(), 85 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 86 | ]) 87 | 88 | 89 | def available_models() -> List[str]: 90 | """Returns the names of available CLIP models""" 91 | return list(_MODELS.keys()) 92 | 93 | 94 | def load_state_dict(checkpoint_path: str, map_location='cpu'): 95 | checkpoint = torch.load(checkpoint_path, map_location=map_location) 96 | if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: 97 | state_dict = checkpoint['state_dict'] 98 | else: 99 | state_dict = checkpoint 100 | if next(iter(state_dict.items()))[0].startswith('module'): 101 | state_dict = {k[7:]: v for k, v in state_dict.items()} 102 | return state_dict 103 | 104 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", design_details = None, jit: bool = False, download_root: str = None): 105 | """Load a CLIP model 106 | 107 | Parameters 108 | ---------- 109 | name : str 110 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 111 | 112 | device : Union[str, torch.device] 113 | The device to put the loaded model 114 | 115 | jit : bool 116 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 117 | 118 | download_root: str 119 | path to download the model files; by default, it uses "~/.cache/clip" 120 | 121 | Returns 122 | ------- 123 | model : torch.nn.Module 124 | The CLIP model 125 | 126 | preprocess : Callable[[PIL.Image], torch.Tensor] 127 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 128 | """ 129 | print("name", name) 130 | if name in _MODELS: 131 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 132 | elif os.path.isfile(name): 133 | model_path = name 134 | else: 135 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 136 | 137 | with open(model_path, 'rb') as opened_file: 138 | try: 139 | # loading JIT archive 140 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() 141 | state_dict = None 142 | except RuntimeError: 143 | # loading saved state dict 144 | if jit: 145 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 146 | jit = False 147 | state_dict = torch.load(opened_file, map_location="cpu") 148 | 149 | if not jit: 150 | model = build_model(name, state_dict or model.state_dict(), design_details).to(device) 151 | if str(device) == "cpu": 152 | model.float() 153 | return model, _transform(model.visual.input_resolution) 154 | 155 | # patch the device names 156 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 157 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 158 | 159 | def patch_device(module): 160 | try: 161 | graphs = [module.graph] if hasattr(module, "graph") else [] 162 | except RuntimeError: 163 | graphs = [] 164 | 165 | if hasattr(module, "forward1"): 166 | graphs.append(module.forward1.graph) 167 | 168 | for graph in graphs: 169 | for node in graph.findAllNodes("prim::Constant"): 170 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 171 | node.copyAttributes(device_node) 172 | 173 | model.apply(patch_device) 174 | patch_device(model.encode_image) 175 | patch_device(model.encode_text) 176 | 177 | # patch dtype to float32 on CPU 178 | if str(device) == "cpu": 179 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 180 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 181 | float_node = float_input.node() 182 | 183 | def patch_float(module): 184 | try: 185 | graphs = [module.graph] if hasattr(module, "graph") else [] 186 | except RuntimeError: 187 | graphs = [] 188 | 189 | if hasattr(module, "forward1"): 190 | graphs.append(module.forward1.graph) 191 | 192 | for graph in graphs: 193 | for node in graph.findAllNodes("aten::to"): 194 | inputs = list(node.inputs()) 195 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 196 | if inputs[i].node()["value"] == 5: 197 | inputs[i].node().copyAttributes(float_node) 198 | 199 | model.apply(patch_float) 200 | patch_float(model.encode_image) 201 | patch_float(model.encode_text) 202 | 203 | model.float() 204 | 205 | return model, _transform(model.input_resolution.item()) -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | import os 5 | import argparse 6 | import hashlib 7 | import humanhash 8 | from tqdm import tqdm 9 | from torch.utils.data import DataLoader, TensorDataset, Dataset 10 | 11 | # Utility shortcuts exposed at the package level 12 | 13 | def setup_seed(seed): 14 | os.environ['PYTHONHASHSEED'] = str(seed) 15 | os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" # For CUDA 10.2+ 16 | random.seed(seed) 17 | np.random.seed(seed) 18 | torch.manual_seed(seed) 19 | torch.cuda.manual_seed(seed) 20 | torch.cuda.manual_seed_all(seed) 21 | torch.backends.cudnn.deterministic = True 22 | torch.backends.cudnn.benchmark = False 23 | torch.use_deterministic_algorithms(True, warn_only=True) 24 | 25 | def seed_worker(worker_id): 26 | worker_seed = 111 + worker_id 27 | np.random.seed(worker_seed) 28 | random.seed(worker_seed) 29 | 30 | def turn_gradient_off(model): 31 | print("Turning off gradients in both the image and the text encoder") 32 | for _, param in model.named_parameters(): 33 | param.requires_grad_(False) 34 | 35 | enabled = set() 36 | for name, param in model.named_parameters(): 37 | if param.requires_grad: 38 
| enabled.add(name) 39 | print(f"Parameters to be updated: {enabled}") 40 | 41 | model.eval() 42 | return model 43 | 44 | def str2bool(v): 45 | if isinstance(v, bool): 46 | return v 47 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 48 | return True 49 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 50 | return False 51 | else: 52 | raise argparse.ArgumentTypeError('Boolean value expected.') 53 | 54 | def make_human_readable_name(args, exclude=['model_name', 'dataset', 'data_path', 'datasets_root_dir', 55 | 'checkpoint_path', 'training_path', 'Timestamp', 'why', 56 | 'metrics', 'devices', 'epoch', 'visualize', 'help', None]): 57 | args=vars(args) 58 | name_value_pairs = [ 59 | f"{k}_{v}" 60 | for k,v in args.items() 61 | if k not in exclude # Exclude "help" or invalid arguments 62 | ] 63 | combined = ",".join(sorted(name_value_pairs)) # Sorting ensures consistent order 64 | hash_value = hashlib.sha256(combined.encode()).hexdigest() 65 | human_hash = humanhash.humanize(hash_value, words=2) 66 | return human_hash 67 | 68 | def check_args_conformance_with_train_args(args, training_path): 69 | # Check if args.txt exists in the training_path 70 | args_file_path = os.path.join(training_path, 'args.txt') 71 | configurations_dict = {} # Dictionary to store configurations 72 | last_config = {} 73 | mismatch_descriptions = [] # List to store mismatch descriptions 74 | if os.path.exists(args_file_path): 75 | with open(args_file_path, 'r') as f: 76 | # Read the entire content of the file 77 | file_content = f.read().strip() 78 | 79 | # Split the content into different configurations based on the 'Timestamp' keyword 80 | configurations = file_content.split('Timestamp:') 81 | 82 | # Iterate over each configuration to populate the dictionary 83 | for config in configurations: 84 | if config.strip(): 85 | # Convert the configuration to a dictionary 86 | file_args_dict = {} 87 | for line in config.strip().split('\n'): 88 | if ':' in line: 89 | key, value = line.split(':', 1) 90 | file_args_dict[key.strip()] = value.strip() 91 | 92 | # Store the configuration in the dictionary with a unique key 93 | timestamp = file_args_dict.get('Timestamp', 'Unknown') 94 | configurations_dict[timestamp] = file_args_dict 95 | 96 | # Select the last configuration for comparison 97 | if configurations_dict: 98 | last_timestamp = max(configurations_dict.keys()) 99 | last_config = configurations_dict[last_timestamp] 100 | 101 | # Compare with current args 102 | print(f"Checking configuration with the most recent timestamp: {last_timestamp}") 103 | non_critical_mismatches = { 104 | 'dataset', 'device', 'log_dir', 'model_name','dataset_category', 'use_scorebase_pooling', 'aug_rate',\ 105 | 'features_list', 'train_with_img_cls_type', 'epoch', 'type', 'save_path', 'train_with_img_cls_prob', 'why', 106 | } 107 | if args.dataset != last_config['dataset']: 108 | non_critical_mismatches.update(['k_shot', 'portion']) 109 | for key, value in vars(args).items(): 110 | if key in last_config: 111 | if str(value) != last_config[key] and key not in non_critical_mismatches: 112 | description = f"Argument mismatch for {key}: {value}, but file has {last_config[key]}" 113 | print(description) 114 | mismatch_descriptions.append(description) 115 | else: 116 | description = f"Argument {key} not found in the most recent args.txt configuration" 117 | print(description) 118 | else: 119 | print("No valid configuration found in args.txt") 120 | else: 121 | print(f"No args.txt file found in {training_path}") 122 | 123 | return last_config, 
mismatch_descriptions 124 | 125 | class CustomTensorDataset(Dataset): 126 | def __init__(self, dataset_features, paths): 127 | self.dataset = TensorDataset(*dataset_features) 128 | self.img_paths = paths 129 | self.length = len(self.dataset) 130 | 131 | assert len(self.dataset) == len(self.img_paths), \ 132 | "Number of images and paths must be the same." 133 | 134 | def __getitem__(self, index): 135 | labels, cls_ids, image_features, patch_features, abnorm_masks = self.dataset[index] 136 | sample = { 137 | 'anomaly': labels, 138 | 'cls_id': cls_ids, 139 | 'image_features': image_features, 140 | 'patch_features': patch_features, 141 | 'abnorm_mask': abnorm_masks, 142 | 'img_path': self.img_paths[index] 143 | } 144 | 145 | return sample 146 | 147 | def __len__(self): 148 | return self.length 149 | 150 | def prepare_encode_image_module(model, features_list): 151 | class EncodeImageModule(torch.nn.Module): 152 | def __init__(self, model, features_list): 153 | super(EncodeImageModule, self).__init__() 154 | self.model = model 155 | self.features_list = features_list 156 | 157 | def forward(self, image): 158 | image_features, patch_features = self.model.encode_image(image, self.features_list, self_cor_attn_layers=20) 159 | # image_features = image_features / image_features.norm(dim=-1, keepdim=True) 160 | # patch_features = [patch_feature / patch_feature.norm(dim=-1, keepdim=True) for patch_feature in patch_features] 161 | patch_features = torch.stack(patch_features, dim=1) 162 | return image_features, patch_features 163 | 164 | encode_image_module = EncodeImageModule(model, features_list) 165 | encode_image_module = torch.nn.DataParallel(encode_image_module) 166 | encode_image_module.cuda() 167 | return encode_image_module 168 | 169 | def precompute_image_features(data, encode_image_module, args): 170 | batch_size = 2 if args.dino_model == 'dino' else 8 171 | batch_size *= torch.cuda.device_count() 172 | 173 | g = torch.Generator() 174 | g.manual_seed(args.seed) 175 | test_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False, num_workers=torch.cuda.device_count(),\ 176 | prefetch_factor=2, pin_memory=True, generator=g, worker_init_fn=seed_worker) 177 | print(f"Total samples to process: {len(test_dataloader) * test_dataloader.batch_size}") 178 | 179 | device = 'cuda' 180 | data_items = [[] for _ in range(5)] 181 | img_paths = [] 182 | for items in tqdm(test_dataloader): 183 | image = items['img'].to(device) 184 | label = items['anomaly'] 185 | cls_id = items['cls_id'] 186 | abnorm_mask = items['abnorm_mask'] 187 | path = items['img_path'] 188 | 189 | with torch.no_grad(): 190 | image_features, patch_features = encode_image_module(image) 191 | 192 | for index, item in enumerate((label, cls_id, image_features, patch_features, abnorm_mask)): 193 | data_items[index].append(item.cpu()) 194 | img_paths.extend(path) 195 | 196 | data_items = [torch.cat(item_list, dim=0) for item_list in data_items] 197 | return (data_items, img_paths) -------------------------------------------------------------------------------- /segment_anything/modeling/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
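# Note: TwoWayTransformer below is SAM's two-way decoder; prompt tokens (queries) and flattened
# image tokens (keys) attend to each other in both directions inside every TwoWayAttentionBlock.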
6 | 7 | import torch 8 | from torch import Tensor, nn 9 | 10 | import math 11 | from typing import Tuple, Type 12 | 13 | from .common import MLPBlock 14 | 15 | 16 | class TwoWayTransformer(nn.Module): 17 | def __init__( 18 | self, 19 | depth: int, 20 | embedding_dim: int, 21 | num_heads: int, 22 | mlp_dim: int, 23 | activation: Type[nn.Module] = nn.ReLU, 24 | attention_downsample_rate: int = 2, 25 | ) -> None: 26 | """ 27 | A transformer decoder that attends to an input image using 28 | queries whose positional embedding is supplied. 29 | 30 | Args: 31 | depth (int): number of layers in the transformer 32 | embedding_dim (int): the channel dimension for the input embeddings 33 | num_heads (int): the number of heads for multihead attention. Must 34 | divide embedding_dim 35 | mlp_dim (int): the channel dimension internal to the MLP block 36 | activation (nn.Module): the activation to use in the MLP block 37 | """ 38 | super().__init__() 39 | self.depth = depth 40 | self.embedding_dim = embedding_dim 41 | self.num_heads = num_heads 42 | self.mlp_dim = mlp_dim 43 | self.layers = nn.ModuleList() 44 | 45 | for i in range(depth): 46 | self.layers.append( 47 | TwoWayAttentionBlock( 48 | embedding_dim=embedding_dim, 49 | num_heads=num_heads, 50 | mlp_dim=mlp_dim, 51 | activation=activation, 52 | attention_downsample_rate=attention_downsample_rate, 53 | skip_first_layer_pe=(i == 0), 54 | ) 55 | ) 56 | 57 | self.final_attn_token_to_image = Attention( 58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 59 | ) 60 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 61 | 62 | def forward( 63 | self, 64 | image_embedding: Tensor, 65 | image_pe: Tensor, 66 | point_embedding: Tensor, 67 | ) -> Tuple[Tensor, Tensor]: 68 | """ 69 | Args: 70 | image_embedding (torch.Tensor): image to attend to. Should be shape 71 | B x embedding_dim x h x w for any h and w. 72 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 73 | have the same shape as image_embedding. 74 | point_embedding (torch.Tensor): the embedding to add to the query points. 75 | Must have shape B x N_points x embedding_dim for any N_points. 
76 | 77 | Returns: 78 | torch.Tensor: the processed point_embedding 79 | torch.Tensor: the processed image_embedding 80 | """ 81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 82 | bs, c, h, w = image_embedding.shape 83 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 84 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 85 | 86 | # Prepare queries 87 | queries = point_embedding 88 | keys = image_embedding 89 | 90 | # Apply transformer blocks and final layernorm 91 | for layer in self.layers: 92 | queries, keys = layer( 93 | queries=queries, 94 | keys=keys, 95 | query_pe=point_embedding, 96 | key_pe=image_pe, 97 | ) 98 | 99 | # Apply the final attention layer from the points to the image 100 | q = queries + point_embedding 101 | k = keys + image_pe 102 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 103 | queries = queries + attn_out 104 | queries = self.norm_final_attn(queries) 105 | 106 | return queries, keys 107 | 108 | 109 | class TwoWayAttentionBlock(nn.Module): 110 | def __init__( 111 | self, 112 | embedding_dim: int, 113 | num_heads: int, 114 | mlp_dim: int = 2048, 115 | activation: Type[nn.Module] = nn.ReLU, 116 | attention_downsample_rate: int = 2, 117 | skip_first_layer_pe: bool = False, 118 | ) -> None: 119 | """ 120 | A transformer block with four layers: (1) self-attention of sparse 121 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 122 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 123 | inputs. 124 | 125 | Arguments: 126 | embedding_dim (int): the channel dimension of the embeddings 127 | num_heads (int): the number of heads in the attention layers 128 | mlp_dim (int): the hidden dimension of the mlp block 129 | activation (nn.Module): the activation of the mlp block 130 | skip_first_layer_pe (bool): skip the PE on the first layer 131 | """ 132 | super().__init__() 133 | self.self_attn = Attention(embedding_dim, num_heads) 134 | self.norm1 = nn.LayerNorm(embedding_dim) 135 | 136 | self.cross_attn_token_to_image = Attention( 137 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 138 | ) 139 | self.norm2 = nn.LayerNorm(embedding_dim) 140 | 141 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) 142 | self.norm3 = nn.LayerNorm(embedding_dim) 143 | 144 | self.norm4 = nn.LayerNorm(embedding_dim) 145 | self.cross_attn_image_to_token = Attention( 146 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 147 | ) 148 | 149 | self.skip_first_layer_pe = skip_first_layer_pe 150 | 151 | def forward( 152 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 153 | ) -> Tuple[Tensor, Tensor]: 154 | # Self attention block 155 | if self.skip_first_layer_pe: 156 | queries = self.self_attn(q=queries, k=queries, v=queries) 157 | else: 158 | q = queries + query_pe 159 | attn_out = self.self_attn(q=q, k=q, v=queries) 160 | queries = queries + attn_out 161 | queries = self.norm1(queries) 162 | 163 | # Cross attention block, tokens attending to image embedding 164 | q = queries + query_pe 165 | k = keys + key_pe 166 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 167 | queries = queries + attn_out 168 | queries = self.norm2(queries) 169 | 170 | # MLP block 171 | mlp_out = self.mlp(queries) 172 | queries = queries + mlp_out 173 | queries = self.norm3(queries) 174 | 175 | # Cross attention block, image embedding attending to tokens 176 | q = queries + query_pe 177 | k = keys + key_pe 178 | attn_out = self.cross_attn_image_to_token(q=k, k=q, 
v=queries) 179 | keys = keys + attn_out 180 | keys = self.norm4(keys) 181 | 182 | return queries, keys 183 | 184 | 185 | class Attention(nn.Module): 186 | """ 187 | An attention layer that allows for downscaling the size of the embedding 188 | after projection to queries, keys, and values. 189 | """ 190 | 191 | def __init__( 192 | self, 193 | embedding_dim: int, 194 | num_heads: int, 195 | downsample_rate: int = 1, 196 | ) -> None: 197 | super().__init__() 198 | self.embedding_dim = embedding_dim 199 | self.internal_dim = embedding_dim // downsample_rate 200 | self.num_heads = num_heads 201 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 202 | 203 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 204 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 205 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 206 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 207 | 208 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 209 | b, n, c = x.shape 210 | x = x.reshape(b, n, num_heads, c // num_heads) 211 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 212 | 213 | def _recombine_heads(self, x: Tensor) -> Tensor: 214 | b, n_heads, n_tokens, c_per_head = x.shape 215 | x = x.transpose(1, 2) 216 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 217 | 218 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 219 | # Input projections 220 | q = self.q_proj(q) 221 | k = self.k_proj(k) 222 | v = self.v_proj(v) 223 | 224 | # Separate into heads 225 | q = self._separate_heads(q, self.num_heads) 226 | k = self._separate_heads(k, self.num_heads) 227 | v = self._separate_heads(v, self.num_heads) 228 | 229 | # Attention 230 | _, _, _, c_per_head = q.shape 231 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 232 | attn = attn / math.sqrt(c_per_head) 233 | attn = torch.softmax(attn, dim=-1) 234 | 235 | # Get output 236 | out = attn @ v 237 | out = self._recombine_heads(out) 238 | out = self.out_proj(out) 239 | 240 | return out -------------------------------------------------------------------------------- /segment_anything/modeling/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | from typing import Any, Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class PromptEncoder(nn.Module): 17 | def __init__( 18 | self, 19 | embed_dim: int, 20 | image_embedding_size: Tuple[int, int], 21 | input_image_size: Tuple[int, int], 22 | mask_in_chans: int, 23 | activation: Type[nn.Module] = nn.GELU, 24 | ) -> None: 25 | """ 26 | Encodes prompts for input to SAM's mask decoder. 27 | 28 | Arguments: 29 | embed_dim (int): The prompts' embedding dimension 30 | image_embedding_size (tuple(int, int)): The spatial size of the 31 | image embedding, as (H, W). 32 | input_image_size (int): The padded size of the image as input 33 | to the image encoder, as (H, W). 34 | mask_in_chans (int): The number of hidden channels used for 35 | encoding input masks. 36 | activation (nn.Module): The activation to use when encoding 37 | input masks. 
38 | """ 39 | super().__init__() 40 | self.embed_dim = embed_dim 41 | self.input_image_size = input_image_size 42 | self.image_embedding_size = image_embedding_size 43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 44 | 45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 46 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] 47 | self.point_embeddings = nn.ModuleList(point_embeddings) 48 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 49 | 50 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) 51 | self.mask_downscaling = nn.Sequential( 52 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 53 | LayerNorm2d(mask_in_chans // 4), 54 | activation(), 55 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 56 | LayerNorm2d(mask_in_chans), 57 | activation(), 58 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 59 | ) 60 | self.no_mask_embed = nn.Embedding(1, embed_dim) 61 | 62 | def get_dense_pe(self) -> torch.Tensor: 63 | """ 64 | Returns the positional encoding used to encode point prompts, 65 | applied to a dense set of points the shape of the image encoding. 66 | 67 | Returns: 68 | torch.Tensor: Positional encoding with shape 69 | 1x(embed_dim)x(embedding_h)x(embedding_w) 70 | """ 71 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 72 | 73 | def _embed_points( 74 | self, 75 | points: torch.Tensor, 76 | labels: torch.Tensor, 77 | pad: bool, 78 | ) -> torch.Tensor: 79 | """Embeds point prompts.""" 80 | points = points + 0.5 # Shift to center of pixel 81 | if pad: 82 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 83 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 84 | points = torch.cat([points, padding_point], dim=1) 85 | labels = torch.cat([labels, padding_label], dim=1) 86 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) 87 | point_embedding[labels == -1] = 0.0 88 | point_embedding[labels == -1] += self.not_a_point_embed.weight 89 | point_embedding[labels == 0] += self.point_embeddings[0].weight 90 | point_embedding[labels == 1] += self.point_embeddings[1].weight 91 | return point_embedding 92 | 93 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 94 | """Embeds box prompts.""" 95 | boxes = boxes + 0.5 # Shift to center of pixel 96 | coords = boxes.reshape(-1, 2, 2) 97 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) 98 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 99 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 100 | return corner_embedding 101 | 102 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 103 | """Embeds mask inputs.""" 104 | mask_embedding = self.mask_downscaling(masks) 105 | return mask_embedding 106 | 107 | def _get_batch_size( 108 | self, 109 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 110 | boxes: Optional[torch.Tensor], 111 | masks: Optional[torch.Tensor], 112 | ) -> int: 113 | """ 114 | Gets the batch size of the output given the batch size of the input prompts. 
115 | """ 116 | if points is not None: 117 | return points[0].shape[0] 118 | elif boxes is not None: 119 | return boxes.shape[0] 120 | elif masks is not None: 121 | return masks.shape[0] 122 | else: 123 | return 1 124 | 125 | def _get_device(self) -> torch.device: 126 | return self.point_embeddings[0].weight.device 127 | 128 | def forward( 129 | self, 130 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 131 | boxes: Optional[torch.Tensor], 132 | masks: Optional[torch.Tensor], 133 | ) -> Tuple[torch.Tensor, torch.Tensor]: 134 | """ 135 | Embeds different types of prompts, returning both sparse and dense 136 | embeddings. 137 | 138 | Arguments: 139 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 140 | and labels to embed. 141 | boxes (torch.Tensor or none): boxes to embed 142 | masks (torch.Tensor or none): masks to embed 143 | 144 | Returns: 145 | torch.Tensor: sparse embeddings for the points and boxes, with shape 146 | BxNx(embed_dim), where N is determined by the number of input points 147 | and boxes. 148 | torch.Tensor: dense embeddings for the masks, in the shape 149 | Bx(embed_dim)x(embed_H)x(embed_W) 150 | """ 151 | bs = self._get_batch_size(points, boxes, masks) 152 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) 153 | if points is not None: 154 | coords, labels = points 155 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 156 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 157 | if boxes is not None: 158 | box_embeddings = self._embed_boxes(boxes) 159 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 160 | 161 | if masks is not None: 162 | dense_embeddings = self._embed_masks(masks) 163 | else: 164 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 165 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 166 | ) 167 | 168 | return sparse_embeddings, dense_embeddings 169 | 170 | 171 | class PositionEmbeddingRandom(nn.Module): 172 | """ 173 | Positional encoding using random spatial frequencies. 174 | """ 175 | 176 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 177 | super().__init__() 178 | if scale is None or scale <= 0.0: 179 | scale = 1.0 180 | self.register_buffer( 181 | "positional_encoding_gaussian_matrix", 182 | scale * torch.randn((2, num_pos_feats)), 183 | ) 184 | 185 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 186 | """Positionally encode points that are normalized to [0,1].""" 187 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 188 | coords = 2 * coords - 1 189 | coords = coords @ self.positional_encoding_gaussian_matrix 190 | coords = 2 * np.pi * coords 191 | # outputs d_1 x ... 
x d_n x C shape 192 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 193 | 194 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 195 | """Generate positional encoding for a grid of the specified size.""" 196 | h, w = size 197 | device: Any = self.positional_encoding_gaussian_matrix.device 198 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 199 | y_embed = grid.cumsum(dim=0) - 0.5 200 | x_embed = grid.cumsum(dim=1) - 0.5 201 | y_embed = y_embed / h 202 | x_embed = x_embed / w 203 | 204 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 205 | return pe.permute(2, 0, 1) # C x H x W 206 | 207 | def forward_with_coords( 208 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 209 | ) -> torch.Tensor: 210 | """Positionally encode points that are not normalized to [0,1].""" 211 | coords = coords_input.clone() 212 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 213 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 214 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 215 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | # $Crane$; Context-Guided Prompt Learning and Attention Refinement for Zero-Shot Anomaly Detection 5 | The repository contains official code for $Crane$, a zero-shot anomaly detection framework built on CLIP. 6 | 7 | --- 8 | 9 | ## 📌 Table of Contents 10 | 11 | - [Introduction](#introduction) 12 | - [Results](#-main-results) 13 | - [Visualization](#%EF%B8%8F-visualization) 14 | - [Getting Started](#getting-started) 15 | - [Installation](#-installation) 16 | - [Datasets](#-datasets) 17 | - [Inference](#-inference) 18 | - [Training](#-training) 19 | - [Custom Dataset](#-custom-dataset) 20 | 21 | --- 22 | 23 | ## Introduction 24 | 25 | Anomaly Detection involves identifying deviations from normal data distributions and is critical in fields such as medical diagnostics and industrial defect detection. Traditional AD methods typically require the availability of normal training samples; however, this assumption is not always feasible. Recently, the rich pretraining knowledge of CLIP has shown promising zero-shot generalization in detecting anomalies 26 | without the need for training samples from target domains. However, CLIP’s coarse-grained image-text alignment limits localization and detection performance for fine-grained anomalies due to: (1) spatial misalignment, and (2) the limited sensitivity of global features to local anomalous patterns. In this paper, we propose $Crane$ which tackles both problems. First, we introduce a correlation-based attention module to retain spatial alignment more accurately. Second, to boost the model’s awareness of fine-grained anomalies, we condition the learnable prompts of the text encoder on image context extracted from the vision encoder and perform a local-to-global representation fusion. Moreover, our method can incorporate vision foundation models such as DINOv2 to further enhance spatial understanding and localization. The key insight of $Crane$ is to balance learnable adaptations for modeling anomalous concepts with non-learnable adaptations that preserve and exploit generalized pretrained knowledge, thereby minimizing in-domain overfitting and maximizing performance on unseen domains. 
Extensive evaluation across 14 diverse industrial and medical datasets demonstrates that $Crane$ consistently improves the state of the art in ZSAD by 2% to 28%, at both image and pixel levels, while remaining competitive in inference speed. 27 | 28 | 33 | 34 | 35 | ## Overview 36 | 37 | ![Architecture](assets/main-fig.png) 38 | 39 | ## 📊 Main Results 40 | 41 | ### Zero-shot evaluation on industrial & medical datasets 42 | ![Industrial](assets/table1.png) 43 | 44 | ## 🖼️ Visualization 45 | ### Samples of zero-shot anomaly localization of $Crane^+$ for both the main setting and the medical setting (discussed in Appendix E). The complete set of visualizations can be found in the Appendix of the paper. 46 | ![total](assets/visualization_combined.jpg) 47 | 48 | ## Getting Started 49 | To reproduce the results, follow the instructions below to run inference and training: 50 | 51 | ### 🧰 Installation 52 | All required libraries, including the correct PyTorch version, are specified in `environment.yml`. Running `setup.sh` will automatically create the environment and install all dependencies. 53 | 54 | ```bash 55 | git clone https://github.com/AlirezaSalehy/Crane.git && cd Crane 56 | bash setup.sh 57 | conda activate crane_env 58 | ``` 59 | The required CLIP and DINO checkpoints will be downloaded automatically by the code and stored in `~/.cache`. However, the ViT-B SAM checkpoint must be downloaded manually. 60 | Please download `sam_vit_b_01ec64.pth` from the official Segment Anything repository [here](https://github.com/facebookresearch/segment-anything) and place it in the following directory: 61 | ``` 62 | ~/.cache/sam/sam_vit_b_01ec64.pth 63 | ``` 64 | 65 | ### 📁 Datasets 66 | You can download the datasets from their official sources and use the utilities in `datasets/generate_dataset_json/` to generate a compatible meta.json. Alternatively, the [AdaCLIP repository](https://github.com/caoyunkang/AdaCLIP?tab=readme-ov-file#industrial-visual-anomaly-detection-datasets) provides the datasets in a compatible format. Place all datasets under `DATASETS_ROOT`, which is defined in [`./__init__.py`](__init__.py). 67 | 68 | ### 🔍 Inference 69 | The checkpoints for our trained "default" model are available in the [`checkpoints`](/checkpoints/) directory. After installing the required libraries, reproduce the results by running: 70 | ```bash 71 | bash test.sh "0" 72 | ``` 73 | Here, `"0"` specifies the CUDA device ID(s). 74 | 75 | ### 🔧 Training 76 | To train new checkpoints and test on the medical and industrial datasets using the default setting, simply run: 77 | 78 | ```bash 79 | bash reproduce.sh new_model 0 80 | ``` 81 | where `new_model` and `0` specify the checkpoint name and the available CUDA device ID. 82 | 83 | ## ➕ Custom Dataset 84 | 85 | You can easily use your own dataset with our model by following the instructions below: 86 | 87 | ### 1. Organize Your Data 88 | Your dataset must either include a `meta.json` file at the root directory, or be organized so that one can be automatically generated; a minimal example is sketched below.
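As a quick reference, the following minimal sketch writes such a `meta.json` with plain Python. The class name `c1`, the species `good`/`fault1`, and the file names are illustrative (borrowed from the example layout shown further below); the expected fields are described next.

```python
import json

# Illustrative meta.json: one normal training sample, plus one normal and one
# anomalous test sample for a single class "c1". Paths are relative to the
# dataset root; mask_path is left empty for normal samples.
meta = {
    "train": {
        "c1": [
            {"img_path": "train/c1/good/000.png", "mask_path": "",
             "cls_name": "c1", "specie_name": "good", "anomaly": 0},
        ]
    },
    "test": {
        "c1": [
            {"img_path": "test/c1/good/000.png", "mask_path": "",
             "cls_name": "c1", "specie_name": "good", "anomaly": 0},
            {"img_path": "test/c1/fault1/000.png", "mask_path": "test/c1/masks/000.png",
             "cls_name": "c1", "specie_name": "fault1", "anomaly": 1},
        ]
    },
}

with open("meta.json", "w") as f:
    json.dump(meta, f, indent=2)
```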
89 | 90 | The `meta.json` should follow this format: 91 | - A dictionary with `"train"` and `"test"` at the highest level 92 | - Each section contains class names mapped to a list of samples 93 | - Each sample includes: 94 | - `img_path`: path to the image, relative to the root dir 95 | - `mask_path`: path to the mask, relative to the root dir (empty for normal samples) 96 | - `cls_name`: class name 97 | - `specie_name`: subclass or condition (e.g., `"good"`, `"fault1"`) 98 | - `anomaly`: anomaly label; 0 (normal) or 1 (anomalous) 99 | 100 | If your dataset does not include the required `meta.json`, you can generate it automatically by organizing your data as shown below and running [`datasets/generate_dataset_json/custom_dataset.py`](datasets/generate_dataset_json/custom_dataset.py): 101 | 102 | ``` 103 | datasets/your_dataset/ 104 | ├── train/ 105 | │ ├── c1/ 106 | │ │ └── good/ 107 | │ │ ├── .png 108 | │ └── c2/ 109 | │ └── good/ 110 | │ ├── .png 111 | ├── test/ 112 | │ ├── c1/ 113 | │ │ ├── good/ 114 | │ │ │ ├── .png 115 | │ │ ├── fault1/ 116 | │ │ │ ├── .png 117 | │ │ ├── fault2/ 118 | │ │ │ ├── .png 119 | │ │ └── masks/ 120 | │ │ ├── .png 121 | │ └── c2/ 122 | │ ├── good/ 123 | ... ... 124 | ``` 125 | 126 | Once organized, run the script to generate a `meta.json` automatically at the dataset root. 127 | 128 | 129 | ### 2. Run Testing 130 | Place your dataset under `DATASETS_ROOT`, which is defined in [`./__init__.py`](__init__.py), and run inference: 131 | 132 | ```bash 133 | python test.py --dataset YOUR_DATASET --model_name default --epoch 5 134 | ``` 135 | ## ⚡ Efficient Implementation 136 | - For a fair inference-throughput comparison with other methods, the default setting uses a single GPU and the original AUPRO implementation. The enhancements described below can optionally be enabled. 137 | - Because the original AUPRO implementation is unusually slow and no good alternative was available, I made a few optimizations and tested them against the original. The results are available in [FasterAUPRO](https://github.com/AlirezaSalehy/FasterAUPRO). The optimized version computes AUPRO **3× to 38×** faster, saving hours of performance evaluation. 138 | - The `test.py` implementation supports multiple GPUs; specifying additional CUDA device IDs with `--devices` gives a further speedup. 139 | 140 | ## 🔒 License 141 | This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details. 142 | 143 | 144 | ## 📄 Citation 145 | If you find this project helpful for your research, please consider citing the following BibTeX entry.
146 | 147 | 148 | 149 | 150 | **BibTeX:** 151 | ```bibtex 152 | @article{salehi2025crane, 153 | title={Crane: Context-Guided Prompt Learning and Attention Refinement for Zero-Shot Anomaly Detections}, 154 | author={Salehi, Alireza and Salehi, Mohammadreza and Hosseini, Reshad and Snoek, Cees GM and Yamada, Makoto and Sabokrou, Mohammad}, 155 | journal={arXiv preprint arXiv:2504.11055}, 156 | year={2025} 157 | } 158 | ``` 159 | 160 | ## Acknowledgements 161 | This project builds upon: 162 | 163 | - [AdaCLIP](https://github.com/caoyunkang/AdaCLIP) 164 | - [VAND](https://github.com/ByChelsea/VAND-APRIL-GAN) 165 | - [AnomalyCLIP](https://github.com/zqhang/AnomalyCLIP) 166 | - [OpenAI CLIP](https://github.com/openai/CLIP) 167 | - [ProxyCLIP](https://github.com/mc-lan/ProxyCLIP) 168 | 169 | We greatly appreciate the authors for their contributions and open-source support. 170 | 171 | --- 172 | 173 | ## Contact 174 | For questions or collaborations, please contact **[alireza99salehy@gmail.com](mailto:alireza99salehy@gmail.com)**. 175 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import auc, roc_auc_score, average_precision_score, f1_score, precision_recall_curve 2 | 3 | import numpy as np 4 | from skimage import measure 5 | import torch 6 | from torchmetrics import AUROC, AveragePrecision 7 | import time 8 | 9 | def cal_pro_score(masks, amaps, max_step=200, expect_fpr=0.3): 10 | # ref: https://github.com/gudovskiy/cflow-ad/blob/master/train.py 11 | binary_amaps = np.zeros_like(amaps, dtype=bool) 12 | min_th, max_th = amaps.min(), amaps.max() 13 | delta = (max_th - min_th) / max_step 14 | pros, fprs, ths = [], [], [] 15 | for th in np.arange(min_th, max_th, delta): 16 | binary_amaps[amaps <= th], binary_amaps[amaps > th] = 0, 1 17 | pro = [] 18 | for binary_amap, mask in zip(binary_amaps, masks): 19 | for region in measure.regionprops(measure.label(mask)): 20 | tp_pixels = binary_amap[region.coords[:, 0], region.coords[:, 1]].sum() 21 | pro.append(tp_pixels / region.area) 22 | inverse_masks = 1 - masks 23 | fp_pixels = np.logical_and(inverse_masks, binary_amaps).sum() 24 | fpr = fp_pixels / inverse_masks.sum() 25 | pros.append(np.array(pro).mean()) 26 | fprs.append(fpr) 27 | ths.append(th) 28 | pros, fprs, ths = np.array(pros), np.array(fprs), np.array(ths) 29 | idxes = fprs < expect_fpr 30 | fprs = fprs[idxes] 31 | fprs = (fprs - fprs.min()) / (fprs.max() - fprs.min()) 32 | pro_auc = auc(fprs, pros[idxes]) 33 | return pro_auc 34 | 35 | def calc_f1_max(gt, pr): 36 | precisions, recalls, _ = precision_recall_curve(gt, pr) 37 | f1_scores = (2 * precisions * recalls) / (precisions + recalls) 38 | return np.max(f1_scores[np.isfinite(f1_scores)]) 39 | 40 | # without warning for division by zero 41 | # denom = precisions + recalls 42 | # f1_scores = np.zeros_like(denom) 43 | # valid = denom > 0 44 | # f1_scores[valid] = (2 * precisions[valid] * recalls[valid]) / denom[valid] 45 | 46 | def image_level_metrics(results, obj, metric): 47 | gt = results[obj]['gt_sp'] 48 | pr = results[obj]['pr_sp'] 49 | gt = np.array(gt) 50 | pr = np.array(pr) 51 | 52 | if len(np.unique(gt)) < 2: 53 | print("only one class present, can not calculate image metrics") 54 | return 0 55 | 56 | if metric == 'image-auroc': 57 | performance = roc_auc_score(gt, pr) 58 | elif metric == 'image-ap': 59 | performance = average_precision_score(gt, pr) 60 | elif metric == 
'image-f1': 61 | performance = calc_f1_max(gt, pr) 62 | # performance = f1_score(gt, pr.round()) 63 | # assert f1_max == performance 64 | return performance 65 | 66 | def pixel_level_metrics(results, obj, metric): 67 | gt = results[obj]['imgs_masks'] 68 | pr = results[obj]['anomaly_maps'] 69 | 70 | if len(np.unique(gt)) < 2: 71 | print("only one class present, can not calculate pixel metrics") 72 | return 0 73 | 74 | if metric == 'pixel-auroc': 75 | # gt = np.array(gt.cpu()); pr = np.array(pr.cpu()) 76 | # performance = roc_auc_score(gt.ravel(), pr.ravel()) 77 | performance = AUROC(task="binary")(pr, gt.to(dtype=torch.long)).item() 78 | elif metric == 'pixel-aupro': 79 | if len(gt.shape) == 4: 80 | gt = gt.squeeze(1) 81 | if len(pr.shape) == 4: 82 | pr = pr.squeeze(1) 83 | performance = cal_pro_score_gpu(gt, pr) 84 | # performance = cal_pro_score(gt, pr) 85 | elif metric == 'pixel-ap': # NOTE: The order in sklearn and torch metrics is inverse 86 | # gt = np.array(gt.cpu()); pr = np.array(pr.cpu()) 87 | # performance= average_precision_score(gt.ravel(), pr.ravel()) 88 | performance = AveragePrecision(task="binary")(pr, gt.to(dtype=torch.long)).item() 89 | 90 | elif metric == 'pixel-f1': 91 | # gt = np.array(gt.cpu()); pr = np.array(pr.cpu()) 92 | # performance = f1_score(gt.ravel(), pr.ravel().round()) 93 | performance = calc_f1_max(gt.cpu().ravel(), pr.cpu().ravel()) 94 | return performance 95 | 96 | # NEW implementation for pro using GPU and PyTorch 97 | def cal_pro_score_gpu(masks, amaps, max_step=200, expect_fpr=0.3): 98 | # GPU implementation using PyTorch 99 | device="cuda" 100 | if not torch.is_tensor(amaps): 101 | amaps = torch.tensor(amaps) 102 | if not torch.is_tensor(masks): 103 | masks = torch.tensor(masks) 104 | 105 | amaps = amaps.to(device) 106 | masks = masks.to(device) 107 | 108 | binary_amaps = torch.zeros_like(amaps, dtype=torch.bool, device=device) 109 | min_th, max_th = amaps.min().item(), amaps.max().item() 110 | delta = (max_th - min_th) / max_step 111 | pros, fprs, ths = [], [], [] 112 | 113 | regionprops_list = [measure.regionprops(measure.label(mask.cpu().numpy())) for mask in masks] 114 | coords_list = [[(region.coords[:, 0], region.coords[:, 1], len(region.coords)) for region in regionprops] for regionprops in regionprops_list] 115 | inverse_masks = 1 - masks 116 | tn_pixel = inverse_masks.sum().item() # Pixels that truly has the label of 0 117 | for th in np.arange(min_th, max_th, delta): 118 | binary_amaps[amaps <= th], binary_amaps[amaps > th] = 0, 1 119 | pro = [] 120 | 121 | for binary_amap, regions_coords in zip(binary_amaps, coords_list): 122 | for coords in regions_coords: 123 | tp_pixels = binary_amap[coords[0], coords[1]].sum().item() 124 | pro.append(tp_pixels / coords[2]) 125 | 126 | fp_pixels = torch.logical_and(inverse_masks, binary_amaps).sum().item() 127 | fpr = fp_pixels / tn_pixel 128 | pros.append(np.mean(pro)) 129 | fprs.append(fpr) 130 | ths.append(th.item()) 131 | 132 | pros, fprs, ths = np.array(pros), np.array(fprs), np.array(ths) 133 | idxes = fprs < expect_fpr 134 | fprs = fprs[idxes] 135 | pros = pros[idxes] 136 | fprs = (fprs - fprs.min()) / (fprs.max() - fprs.min()) 137 | pro_auc = auc(fprs, pros) 138 | return pro_auc 139 | 140 | # https://github.com/M-3LAB/open-iad/blob/main/metric/mvtec3d/au_pro.py#L205 141 | import numpy as np 142 | from scipy.ndimage import label 143 | from bisect import bisect 144 | __all__ = ['GroundTruthComponent', 'trapezoid', 'collect_anomaly_scores', 'compute_pro', 'calculate_au_pro'] 145 | 146 | class 
GroundTruthComponent: 147 | def __init__(self, anomaly_scores): 148 | self.anomaly_scores = anomaly_scores.copy() 149 | self.anomaly_scores.sort() 150 | self.index = 0 151 | self.last_threshold = None 152 | 153 | def compute_overlap(self, threshold): 154 | if self.last_threshold is not None: 155 | assert self.last_threshold <= threshold 156 | while self.index < len(self.anomaly_scores) and self.anomaly_scores[self.index] <= threshold: 157 | self.index += 1 158 | return 1.0 - self.index / len(self.anomaly_scores) 159 | 160 | def trapezoid(x, y, x_max=None): 161 | x = np.array(x) 162 | y = np.array(y) 163 | finite_mask = np.logical_and(np.isfinite(x), np.isfinite(y)) 164 | if not finite_mask.all(): 165 | print("""WARNING: Not all x and y values passed to trapezoid are finite. Will continue with only the finite values.""") 166 | x = x[finite_mask] 167 | y = y[finite_mask] 168 | correction = 0.0 169 | if x_max is not None: 170 | if x_max not in x: 171 | ins = bisect(x, x_max) 172 | assert 0 < ins < len(x) 173 | y_interp = y[ins - 1] + ((y[ins] - y[ins - 1]) * (x_max - x[ins - 1]) / (x[ins] - x[ins - 1])) 174 | correction = 0.5 * (y_interp + y[ins - 1]) * (x_max - x[ins - 1]) 175 | mask = x <= x_max 176 | x = x[mask] 177 | y = y[mask] 178 | return np.sum(0.5 * (y[1:] + y[:-1]) * (x[1:] - x[:-1])) + correction 179 | 180 | def collect_anomaly_scores(anomaly_maps, ground_truth_maps): 181 | assert len(anomaly_maps) == len(ground_truth_maps) 182 | ground_truth_components = [] 183 | anomaly_scores_ok_pixels = np.zeros(len(ground_truth_maps) * ground_truth_maps[0].size) 184 | structure = np.ones((3, 3), dtype=int) 185 | ok_index = 0 186 | for gt_map, prediction in zip(ground_truth_maps, anomaly_maps): 187 | labeled, n_components = label(gt_map, structure) 188 | num_ok_pixels = len(prediction[labeled == 0]) 189 | anomaly_scores_ok_pixels[ok_index:ok_index + num_ok_pixels] = prediction[labeled == 0].copy() 190 | ok_index += num_ok_pixels 191 | for k in range(n_components): 192 | component_scores = prediction[labeled == (k + 1)] 193 | ground_truth_components.append(GroundTruthComponent(component_scores)) 194 | anomaly_scores_ok_pixels = np.resize(anomaly_scores_ok_pixels, ok_index) 195 | anomaly_scores_ok_pixels.sort() 196 | return ground_truth_components, anomaly_scores_ok_pixels 197 | 198 | def compute_pro(anomaly_maps, ground_truth_maps, num_thresholds): 199 | ground_truth_components, anomaly_scores_ok_pixels = collect_anomaly_scores(anomaly_maps, ground_truth_maps) 200 | threshold_positions = np.linspace(0, len(anomaly_scores_ok_pixels) - 1, num=num_thresholds, dtype=int) 201 | fprs = [1.0] 202 | pros = [1.0] 203 | for pos in threshold_positions: 204 | threshold = anomaly_scores_ok_pixels[pos] 205 | fpr = 1.0 - (pos + 1) / len(anomaly_scores_ok_pixels) 206 | pro = 0.0 207 | for component in ground_truth_components: 208 | pro += component.compute_overlap(threshold) 209 | pro /= len(ground_truth_components) 210 | fprs.append(fpr) 211 | pros.append(pro) 212 | fprs = fprs[::-1] 213 | pros = pros[::-1] 214 | return fprs, pros 215 | 216 | def calculate_au_pro(gts, predictions, integration_limit=0.3, num_thresholds=200): 217 | # Compute the PRO curve. 218 | pro_curve = compute_pro(anomaly_maps=predictions, ground_truth_maps=gts, num_thresholds=num_thresholds) 219 | 220 | # Compute the area under the PRO curve. 221 | au_pro = trapezoid(pro_curve[0], pro_curve[1], x_max=integration_limit) 222 | au_pro /= integration_limit 223 | 224 | # Return the evaluation metrics. 
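# Dividing by the integration limit normalizes the area, so the returned au_pro lies in [0, 1].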
225 | return au_pro, pro_curve 226 | 227 | def test_pro_score(masks, amaps): 228 | start_cpu = time.time() 229 | pro_auc_cpu = cal_pro_score(masks, amaps, max_step=200, expect_fpr=0.3) 230 | end_cpu = time.time() 231 | cpu_duration = end_cpu - start_cpu 232 | print(f"CPU execution time: {cpu_duration:.4f} seconds") 233 | # start_gpu = time.time() 234 | # masks_torch = torch.tensor(masks, dtype=torch.float32, device='cuda') 235 | # amaps_torch = torch.tensor(amaps, dtype=torch.float32, device='cuda') 236 | # pro_auc_gpu = cal_pro_score_gpu(masks_torch, amaps_torch) 237 | # end_gpu = time.time() 238 | # gpu_duration = end_gpu - start_gpu 239 | # print(f"GPU execution time: {gpu_duration:.4f} seconds") 240 | start_openiad = time.time() 241 | pro_auc_openiad = calculate_au_pro(masks, amaps, integration_limit=0.3, num_thresholds=200)[0] 242 | end_openiad = time.time() 243 | openiad_duration = end_openiad - start_openiad 244 | print(f"openiad execution time: {openiad_duration:.4f} seconds") 245 | 246 | # assert np.isclose(pro_auc_cpu, pro_auc_openiad), f"Results differ: CPU={pro_auc_cpu}, GPU={pro_auc_openiad}" 247 | print(f"Test passed: CPU={pro_auc_cpu}, GPU={pro_auc_openiad}") 248 | 249 | if __name__ == "__main__": 250 | # Example usage (with small random data for testing) 251 | num_sam = 25 252 | device='7' 253 | 254 | # masks = np.random.randint(0, 2, (num_sam, 512, 512)) # Binary masks 255 | # amaps = np.random.rand(num_sam, 512, 512) # Anomaly maps 256 | 257 | # test_pro_score(masks, amaps) 258 | 259 | masks = np.random.randint(0, 2, (num_sam, 256, 256)) # Binary masks 260 | amaps = np.random.rand(num_sam, 256, 256) # Anomaly maps 261 | 262 | test_pro_score(masks, amaps) 263 | 264 | masks = np.random.randint(0, 2, (num_sam, 64, 64)) # Binary masks 265 | amaps = np.random.rand(num_sam, 64, 64) # Anomaly maps 266 | 267 | test_pro_score(masks, amaps) -------------------------------------------------------------------------------- /segment_anything/predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from segment_anything.modeling import Sam 11 | 12 | from typing import Optional, Tuple 13 | 14 | from .utils.transforms import ResizeLongestSide 15 | 16 | 17 | class SamPredictor: 18 | def __init__( 19 | self, 20 | sam_model: Sam, 21 | ) -> None: 22 | """ 23 | Uses SAM to calculate the image embedding for an image, and then 24 | allow repeated, efficient mask prediction given prompts. 25 | 26 | Arguments: 27 | sam_model (Sam): The model to use for mask prediction. 28 | """ 29 | super().__init__() 30 | self.model = sam_model 31 | self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) 32 | self.reset_image() 33 | 34 | def set_image( 35 | self, 36 | image: np.ndarray, 37 | image_format: str = "RGB", 38 | ) -> None: 39 | """ 40 | Calculates the image embeddings for the provided image, allowing 41 | masks to be predicted with the 'predict' method. 42 | 43 | Arguments: 44 | image (np.ndarray): The image for calculating masks. Expects an 45 | image in HWC uint8 format, with pixel values in [0, 255]. 46 | image_format (str): The color format of the image, in ['RGB', 'BGR']. 
47 | """ 48 | assert image_format in [ 49 | "RGB", 50 | "BGR", 51 | ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." 52 | if image_format != self.model.image_format: 53 | image = image[..., ::-1] 54 | 55 | # Transform the image to the form expected by the model 56 | input_image = self.transform.apply_image(image) 57 | input_image_torch = torch.as_tensor(input_image, device=self.device) 58 | input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] 59 | 60 | self.set_torch_image(input_image_torch, image.shape[:2]) 61 | 62 | @torch.no_grad() 63 | def set_torch_image( 64 | self, 65 | transformed_image: torch.Tensor, 66 | original_image_size: Tuple[int, ...], 67 | ) -> None: 68 | """ 69 | Calculates the image embeddings for the provided image, allowing 70 | masks to be predicted with the 'predict' method. Expects the input 71 | image to be already transformed to the format expected by the model. 72 | 73 | Arguments: 74 | transformed_image (torch.Tensor): The input image, with shape 75 | 1x3xHxW, which has been transformed with ResizeLongestSide. 76 | original_image_size (tuple(int, int)): The size of the image 77 | before transformation, in (H, W) format. 78 | """ 79 | assert ( 80 | len(transformed_image.shape) == 4 81 | and transformed_image.shape[1] == 3 82 | and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size 83 | ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." 84 | self.reset_image() 85 | 86 | self.original_size = original_image_size 87 | self.input_size = tuple(transformed_image.shape[-2:]) 88 | input_image = self.model.preprocess(transformed_image) 89 | self.features = self.model.image_encoder(input_image) 90 | self.is_image_set = True 91 | 92 | def predict( 93 | self, 94 | point_coords: Optional[np.ndarray] = None, 95 | point_labels: Optional[np.ndarray] = None, 96 | box: Optional[np.ndarray] = None, 97 | mask_input: Optional[np.ndarray] = None, 98 | multimask_output: bool = True, 99 | return_logits: bool = False, 100 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 101 | """ 102 | Predict masks for the given input prompts, using the currently set image. 103 | 104 | Arguments: 105 | point_coords (np.ndarray or None): A Nx2 array of point prompts to the 106 | model. Each point is in (X,Y) in pixels. 107 | point_labels (np.ndarray or None): A length N array of labels for the 108 | point prompts. 1 indicates a foreground point and 0 indicates a 109 | background point. 110 | box (np.ndarray or None): A length 4 array given a box prompt to the 111 | model, in XYXY format. 112 | mask_input (np.ndarray): A low resolution mask input to the model, typically 113 | coming from a previous prediction iteration. Has form 1xHxW, where 114 | for SAM, H=W=256. 115 | multimask_output (bool): If true, the model will return three masks. 116 | For ambiguous input prompts (such as a single click), this will often 117 | produce better masks than a single prediction. If only a single 118 | mask is needed, the model's predicted quality score can be used 119 | to select the best mask. For non-ambiguous prompts, such as multiple 120 | input prompts, multimask_output=False can give better results. 121 | return_logits (bool): If true, returns un-thresholded masks logits 122 | instead of a binary mask. 123 | 124 | Returns: 125 | (np.ndarray): The output masks in CxHxW format, where C is the 126 | number of masks, and (H, W) is the original image size. 
127 | (np.ndarray): An array of length C containing the model's 128 | predictions for the quality of each mask. 129 | (np.ndarray): An array of shape CxHxW, where C is the number 130 | of masks and H=W=256. These low resolution logits can be passed to 131 | a subsequent iteration as mask input. 132 | """ 133 | if not self.is_image_set: 134 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 135 | 136 | # Transform input prompts 137 | coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None 138 | if point_coords is not None: 139 | assert ( 140 | point_labels is not None 141 | ), "point_labels must be supplied if point_coords is supplied." 142 | point_coords = self.transform.apply_coords(point_coords, self.original_size) 143 | coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) 144 | labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) 145 | coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] 146 | if box is not None: 147 | box = self.transform.apply_boxes(box, self.original_size) 148 | box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) 149 | box_torch = box_torch[None, :] 150 | if mask_input is not None: 151 | mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) 152 | mask_input_torch = mask_input_torch[None, :, :, :] 153 | 154 | masks, iou_predictions, low_res_masks = self.predict_torch( 155 | coords_torch, 156 | labels_torch, 157 | box_torch, 158 | mask_input_torch, 159 | multimask_output, 160 | return_logits=return_logits, 161 | ) 162 | 163 | masks_np = masks[0].detach().cpu().numpy() 164 | iou_predictions_np = iou_predictions[0].detach().cpu().numpy() 165 | low_res_masks_np = low_res_masks[0].detach().cpu().numpy() 166 | return masks_np, iou_predictions_np, low_res_masks_np 167 | 168 | @torch.no_grad() 169 | def predict_torch( 170 | self, 171 | point_coords: Optional[torch.Tensor], 172 | point_labels: Optional[torch.Tensor], 173 | boxes: Optional[torch.Tensor] = None, 174 | mask_input: Optional[torch.Tensor] = None, 175 | multimask_output: bool = True, 176 | return_logits: bool = False, 177 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 178 | """ 179 | Predict masks for the given input prompts, using the currently set image. 180 | Input prompts are batched torch tensors and are expected to already be 181 | transformed to the input frame using ResizeLongestSide. 182 | 183 | Arguments: 184 | point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the 185 | model. Each point is in (X,Y) in pixels. 186 | point_labels (torch.Tensor or None): A BxN array of labels for the 187 | point prompts. 1 indicates a foreground point and 0 indicates a 188 | background point. 189 | boxes (np.ndarray or None): A Bx4 array given a box prompt to the 190 | model, in XYXY format. 191 | mask_input (np.ndarray): A low resolution mask input to the model, typically 192 | coming from a previous prediction iteration. Has form Bx1xHxW, where 193 | for SAM, H=W=256. Masks returned by a previous iteration of the 194 | predict method do not need further transformation. 195 | multimask_output (bool): If true, the model will return three masks. 196 | For ambiguous input prompts (such as a single click), this will often 197 | produce better masks than a single prediction. If only a single 198 | mask is needed, the model's predicted quality score can be used 199 | to select the best mask. 
For non-ambiguous prompts, such as multiple 200 | input prompts, multimask_output=False can give better results. 201 | return_logits (bool): If true, returns un-thresholded masks logits 202 | instead of a binary mask. 203 | 204 | Returns: 205 | (torch.Tensor): The output masks in BxCxHxW format, where C is the 206 | number of masks, and (H, W) is the original image size. 207 | (torch.Tensor): An array of shape BxC containing the model's 208 | predictions for the quality of each mask. 209 | (torch.Tensor): An array of shape BxCxHxW, where C is the number 210 | of masks and H=W=256. These low res logits can be passed to 211 | a subsequent iteration as mask input. 212 | """ 213 | if not self.is_image_set: 214 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 215 | 216 | if point_coords is not None: 217 | points = (point_coords, point_labels) 218 | else: 219 | points = None 220 | 221 | # Embed prompts 222 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder( 223 | points=points, 224 | boxes=boxes, 225 | masks=mask_input, 226 | ) 227 | 228 | # Predict masks 229 | low_res_masks, iou_predictions = self.model.mask_decoder( 230 | image_embeddings=self.features, 231 | image_pe=self.model.prompt_encoder.get_dense_pe(), 232 | sparse_prompt_embeddings=sparse_embeddings, 233 | dense_prompt_embeddings=dense_embeddings, 234 | multimask_output=multimask_output, 235 | ) 236 | 237 | # Upscale the masks to the original image resolution 238 | masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) 239 | 240 | if not return_logits: 241 | masks = masks > self.model.mask_threshold 242 | 243 | return masks, iou_predictions, low_res_masks 244 | 245 | def get_image_embedding(self) -> torch.Tensor: 246 | """ 247 | Returns the image embeddings for the currently set image, with 248 | shape 1xCxHxW, where C is the embedding dimension and (H,W) are 249 | the embedding spatial dimension of SAM (typically C=256, H=W=64). 250 | """ 251 | if not self.is_image_set: 252 | raise RuntimeError( 253 | "An image must be set with .set_image(...) to generate an embedding." 254 | ) 255 | assert self.features is not None, "Features must exist if an image has been set." 
256 | return self.features 257 | 258 | @property 259 | def device(self) -> torch.device: 260 | return self.model.device 261 | 262 | def reset_image(self) -> None: 263 | """Resets the currently set image.""" 264 | self.is_image_set = False 265 | self.features = None 266 | self.orig_h = None 267 | self.orig_w = None 268 | self.input_h = None 269 | self.input_w = None 270 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import models 2 | from models import Crane 3 | from models.prompt_ensemble import PromptLearner 4 | from dataset.dataset import Dataset 5 | from __init__ import DATASETS_ROOT 6 | 7 | from utils.transform import get_transform 8 | from utils.loss import FocalLoss, BinaryDiceLoss 9 | from utils.logger import get_logger 10 | from utils.similarity import calc_similarity_logits, regrid_upsample 11 | from utils import ( 12 | setup_seed, 13 | seed_worker, 14 | turn_gradient_off, 15 | str2bool, 16 | prepare_encode_image_module, 17 | precompute_image_features, 18 | CustomTensorDataset 19 | ) 20 | 21 | import sys 22 | import os 23 | import argparse 24 | import subprocess 25 | 26 | from tqdm import tqdm 27 | 28 | import torch 29 | from torch.utils.data import DataLoader 30 | import torch.nn.functional as F 31 | import numpy as np 32 | 33 | # import torch.profiler 34 | 35 | def train(args): 36 | logger = get_logger(args.save_path) 37 | 38 | preprocess, target_transform = get_transform(args) 39 | train_data = Dataset(roots=args.train_data_path, transform=preprocess, 40 | target_transform=target_transform, dataset_name=args.dataset, kwargs=args) 41 | g = torch.Generator() 42 | g.manual_seed(args.seed) 43 | # train_dataloader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True) # More basic for FPS comparison 44 | train_dataloader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True, 45 | num_workers=16, pin_memory=True, prefetch_factor=2, 46 | generator=g, worker_init_fn=seed_worker) # Faster 47 | print(f"Length of the dataset: {len(train_data)}") 48 | 49 | ########################################################################################## 50 | device = 'cuda' if torch.cuda.is_available() else "cpu" 51 | print(device) 52 | 53 | crane_parameters = {"Prompt_length": args.n_ctx, "learnabel_text_embedding_depth": args.depth, "learnabel_text_embedding_length": args.t_n_ctx, 'others': args} 54 | model, _ = models.load("ViT-L/14@336px", device=device, design_details = crane_parameters) 55 | model = turn_gradient_off(model) 56 | model.visual.replace_with_EAttn(to_layer=20, type=args.attn_type) # Replace last 20 layers 57 | if args.dino_model != 'none': 58 | model.use_DAttn(args.dino_model) 59 | 60 | prompt_learner = PromptLearner(model.to("cpu"), crane_parameters) 61 | sbp = Crane.ScoreBasePooling() 62 | 63 | model.to(device) 64 | prompt_learner.to(device) 65 | 66 | ########################################################################################## 67 | params = list(prompt_learner.parameters()) 68 | optimizer = torch.optim.Adam(params, lr=args.learning_rate, betas=(0.6, 0.999)) 69 | 70 | precompute = False 71 | if precompute: 72 | encode_image_module = prepare_encode_image_module(model, args.features_list) 73 | precompute_features, pathes = precompute_image_features(train_data, encode_image_module, args) 74 | precompute_dataset = CustomTensorDataset(precompute_features, pathes) 75 | train_dataloader = 
DataLoader(precompute_dataset, batch_size=args.batch_size, shuffle=True, pin_memory=True, 76 | generator=g, worker_init_fn=seed_worker) 77 | model.visual.to('cpu') 78 | 79 | # losses 80 | ce_loss_focal = FocalLoss() 81 | loss_dice = BinaryDiceLoss() 82 | 83 | model.eval() 84 | prompt_learner.train() 85 | for epoch in tqdm(range(args.epoch)): 86 | loss_list = [] 87 | 88 | with tqdm(train_dataloader) as batch_tqdm: 89 | for items in batch_tqdm: 90 | label = items['anomaly'].to(device) 91 | abnorm_mask = items['abnorm_mask'].squeeze().to(device) 92 | 93 | if precompute: 94 | image_features, patch_features = items['image_features'].to(device), items['patch_features'].to(device) 95 | patch_features = patch_features.permute(1, 0, *range(2, patch_features.dim())) # 4, N, L, C 96 | else: 97 | image = items['img'].to(device) 98 | image_features, patch_features = model.encode_image(image, args.features_list, self_cor_attn_layers=20) 99 | # patch_features = torch.stack(patch_features, dim=0) 100 | # image_features = F.normalize(image_features, dim=-1) 101 | # patch_features = F.normalize(patch_features, dim=-1) 102 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 103 | patch_features = [patch_feature / patch_feature.norm(dim=-1, keepdim=True) for patch_feature in patch_features] # Note 104 | patch_features = torch.stack(patch_features, dim=0) 105 | 106 | # Text Features 107 | ######################################################################### 108 | prompts, tokenized_prompts, compound_prompts_text, is_train_with_img_cls = prompt_learner(img_emb=image_features) 109 | 110 | if is_train_with_img_cls: 111 | text_features_nrm = model.encode_text_learn(prompts[0], tokenized_prompts[0], compound_prompts_text) # input dims: 2, 77 | 2, 77, 768 | 9, 4, 768 112 | text_features_anm = model.encode_text_learn(prompts[1], tokenized_prompts[1], compound_prompts_text) # input dims: 2, 77 | 2, 77, 768 | 9, 4, 768 113 | text_features = torch.stack([text_features_nrm, text_features_anm], dim=1) 114 | else: 115 | text_features = model.encode_text_learn(prompts, tokenized_prompts, compound_prompts_text).unsqueeze(dim=0) # 2, 77 | 2, 77, 768 | 9, 4, 768 116 | text_features = F.normalize(text_features, dim=-1).float() # 1, 2, 768 117 | 118 | # Similarity Map - Segmentation 119 | ######################################################################### 120 | similarity_map_list = [] 121 | for patch_feature in patch_features: 122 | pixel_logits = calc_similarity_logits(patch_feature, text_features, temp=0.07) 123 | pixel_scores = pixel_logits.softmax(dim=-1) 124 | similarity_map = regrid_upsample(pixel_scores, args.image_size, mode=args.interpolation) 125 | similarity_map_list.append((similarity_map, pixel_logits)) 126 | 127 | ce_focal_loss = 0 128 | dice_loss = 0 129 | for i in range(len(similarity_map_list)): 130 | whole_map = (1-similarity_map_list[i][0][...,0] + similarity_map_list[i][0][...,1])/2 131 | smlr_map = similarity_map_list[i][0].permute(0, 3, 1, 2) 132 | 133 | dice_loss += loss_dice(whole_map, abnorm_mask) 134 | ce_focal_loss += ce_loss_focal(smlr_map, abnorm_mask) 135 | 136 | # Similarity Score - Classification 137 | ######################################################################### 138 | if args.use_scorebase_pooling: 139 | alpha = 0.5 140 | sms = [sm_lst[1] for sm_lst in similarity_map_list] 141 | clustered_feature = sbp.forward(patch_features, sms) 142 | image_features = alpha * clustered_feature + (1 - alpha) * image_features # aggregates the class token and the 
clustered features for more comprehensive information 143 | image_features = F.normalize(image_features, dim=1) 144 | 145 | image_logits = calc_similarity_logits(image_features, text_features, temp=0.01) # batch_size, 1, 768 @ batch_size, 768, 2 or 3 146 | ce_img2txt_loss = F.cross_entropy(image_logits, label.long().to(device)) 147 | # txt2img_lbl = torch.stack([(1-label), label], dim=0)/label.sum() 148 | # ce_txt2img_loss = F.cross_entropy(image_logits.permute(1, 0), txt2img_lbl.to(device)) 149 | 150 | #loss 151 | optimizer.zero_grad() 152 | dice_loss *= 2 153 | ce_focal_loss *= 2 154 | ls = ce_focal_loss+dice_loss+0.2*ce_img2txt_loss 155 | ls.backward() 156 | optimizer.step() 157 | 158 | loss_list.append((ce_focal_loss.item(), dice_loss.item(), ce_img2txt_loss.item())) 159 | batch_tqdm.set_description(f"ce_fc_ls: {ce_focal_loss:.3f}, bcd_dice_ls: {dice_loss:.3f}, ce_img_ls: {ce_img2txt_loss:.3f}") 160 | # logs 161 | ce_focal_ls, dice_ls, ce_img_ls = np.mean(loss_list, axis=0) 162 | log_template = 'epoch [{}/{}], ce_fc_ls:{:.4f}, bdc_ls:{:.4f}, ce_img_ls:{:.4f}' 163 | logger.info(log_template.format(epoch + 1, args.epoch, ce_focal_ls, dice_ls, ce_img_ls)) 164 | 165 | # save model 166 | if (epoch + 1) % args.save_freq == 0: 167 | prmtp_ckp_path = os.path.join(args.save_path, 'epoch_' + str(epoch + 1) + '.pth') 168 | checkpoint_data = {"prompt_learner": prompt_learner.state_dict()} 169 | torch.save(checkpoint_data, prmtp_ckp_path) 170 | 171 | if __name__ == '__main__': 172 | dss = ['mvtec'] 173 | 174 | parser = argparse.ArgumentParser("Crane", add_help=True) 175 | parser.add_argument("--datasets_root_dir", type=str, default=f"{DATASETS_ROOT}") 176 | parser.add_argument("--train_data_path", type=str, nargs="+", default=[f"{DATASETS_ROOT}/{ds}/" for ds in dss]) 177 | parser.add_argument("--save_path", type=str, default='./checkpoints/') 178 | parser.add_argument("--model_name", type=str, default="default") # NOTE: The "trained_on_" will be prepended to the model name for saving checkpoints 179 | parser.add_argument("--seed", type=int, default=111) 180 | parser.add_argument("--save_freq", type=int, default=1, help="save frequency") 181 | 182 | parser.add_argument("--type", type=str, default='train') 183 | parser.add_argument("--device", type=int, default=0, help="cuda device") 184 | parser.add_argument("--epoch", type=int, default=5, help="epochs") 185 | parser.add_argument("--learning_rate", type=float, default=0.001, help="learning rate") 186 | parser.add_argument("--batch_size", type=int, default=8, help="batch size") 187 | parser.add_argument("--aug_rate", type=float, default=0.0, help="augmentation rate") 188 | 189 | parser.add_argument("--dataset", type=str, nargs="+", default=[f'{ds}' for ds in dss], help="train dataset name") 190 | parser.add_argument("--k_shot", type=int, default=0, help="samples per class for few-shot learning. 
0 means use all data.") 191 | parser.add_argument("--portion", type=float, default=1) 192 | 193 | parser.add_argument("--image_size", type=int, default=518, help="image size") 194 | parser.add_argument("--features_list", type=int, nargs="+", default=[24], help="layer features used") 195 | parser.add_argument("--interpolation", type=str, choices=['nearest', 'bilinear'], default='nearest') 196 | 197 | parser.add_argument("--depth", type=int, default=9, help="depth of the learnable text embeddings") 198 | parser.add_argument("--n_ctx", type=int, default=12, help="length of the learnable prompt context") 199 | parser.add_argument("--t_n_ctx", type=int, default=4, help="length of the learnable text embeddings") 200 | 201 | parser.add_argument("--train_with_img_cls_prob", type=float, default=1) 202 | parser.add_argument("--train_with_img_cls_type", type=str, choices=["none", "replace_prefix", "replace_suffix", "pad_prefix", "pad_suffix"], default="pad_suffix") 203 | parser.add_argument("--dino_model", type=str, choices=['none', 'dinov2', 'dino', 'sam'], default='dinov2') 204 | parser.add_argument("--both_eattn_dattn", type=str2bool, default=True) 205 | parser.add_argument("--use_scorebase_pooling", type=str2bool, default=True) 206 | parser.add_argument("--attn_type", type=str, choices=["vv", "kk", "qq", "qq+kk", "qq+kk+vv", "(q+k+v)^2"], default="qq+kk+vv") 207 | parser.add_argument("--why", type=str, default="Necessity of the experiment") 208 | 209 | args = parser.parse_args() 210 | 211 | if 'CUDA_VISIBLE_DEVICES' not in os.environ: 212 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.device) 213 | command = [sys.executable, __file__, ] + sys.argv[1:] 214 | process = subprocess.Popen(command, env=os.environ) 215 | process.wait() 216 | 217 | else: 218 | setup_seed(args.seed) 219 | 220 | # paths 221 | args.train_data_path = [f"{args.datasets_root_dir}/{ds}/" for ds in args.dataset] 222 | args.save_path = f'{args.save_path}/trained_on_{"_".join(args.dataset)}_{args.model_name}/' 223 | print(f'running {args.model_name}') 224 | 225 | train(args) -------------------------------------------------------------------------------- /dataset/dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import json 3 | import random 4 | from PIL import Image 5 | import numpy as np 6 | import os 7 | import albumentations as A 8 | 9 | def anomaly_map_guided_crop(img, img_mask): 10 | # Convert mask to numpy for bounding box calculation 11 | mask_np = np.array(img_mask) 12 | if mask_np.sum() > 0: # Check if there is any anomaly in the mask 13 | # Get the bounding box of the anomaly 14 | nonzero_coords = np.column_stack(np.nonzero(mask_np)) 15 | top_left = nonzero_coords.min(axis=0) # (y_min, x_min) 16 | bottom_right = nonzero_coords.max(axis=0) # (y_max, x_max) 17 | 18 | # Optionally, expand the bounding box to include some background 19 | padding = img.size[0]*0.1 + np.random.randint(0, int(img.size[0]-img.size[0]*0.1)) # This can be adjusted based on the context you want to include 20 | y_min, x_min = np.maximum([0, 0], top_left - padding) 21 | y_max, x_max = np.minimum([mask_np.shape[0], mask_np.shape[1]], bottom_right + padding) 22 | 23 | # Crop both image and mask 24 | img = img.crop((x_min, y_min, x_max, y_max)) 25 | img_mask = img_mask.crop((x_min, y_min, x_max, y_max)) 26 | 27 | return img, img_mask 28 | 29 | def save_selected_data_paths(data_all, folder_path, file_name='selected_data_paths.txt'): 30 | file_path = os.path.join(folder_path, file_name) 31 | 32 | # Read existing lines from the file
33 | if os.path.exists(file_path): 34 | with open(file_path, 'r') as f: 35 | existing_lines = f.readlines() 36 | else: 37 | existing_lines = [] 38 | 39 | # Ensure existing_lines has at least as many lines as data_all 40 | while len(existing_lines) < len(data_all): 41 | existing_lines.append('\n') 42 | 43 | # Update lines with new img_path data 44 | with open(file_path, 'w') as f: 45 | for i, data in enumerate(data_all): 46 | if 'img_path' in data: 47 | existing_line = existing_lines[i].rstrip('\n') 48 | updated_line = existing_line + ' ' + data['img_path'] if existing_line else data['img_path'] 49 | f.write(updated_line + '\n') 50 | else: 51 | print("Warning: 'img_path' not found in data entry.") 52 | 53 | def compare_data_with_file(data_list, folder_path, key='img_path'): 54 | # Construct the full file path 55 | file_path = os.path.join(folder_path, 'selected_data_paths.txt') 56 | 57 | # Check if the file exists 58 | if not os.path.exists(file_path): 59 | print(f"File not found: {file_path}") 60 | return 0, len(data_list) # Return 0 matches and all data as mismatches 61 | 62 | # Read the file paths from the file 63 | with open(file_path, 'r') as f: 64 | stored_file_paths = [line.strip() for line in f.readlines()] 65 | 66 | # Extract the file paths from the data list 67 | data_file_paths = [data[key] for data in data_list if key in data] 68 | 69 | # Calculate matches and mismatches 70 | matches = set(data_file_paths) & set(stored_file_paths) 71 | mismatches_data = set(data_file_paths) - matches 72 | mismatches_file = set(stored_file_paths) - matches 73 | 74 | match_count = len(matches) 75 | mismatch_count = len(mismatches_data) + len(mismatches_file) 76 | 77 | match_description = f"Matches ({match_count}): {matches}" 78 | mismatch_description = (f"Mismatches in data ({len(mismatches_data)}): {mismatches_data}, " 79 | f"Mismatches in file ({len(mismatches_file)}): {mismatches_file}") 80 | print(match_description) 81 | print(mismatch_description) 82 | 83 | return match_count, mismatch_count 84 | 85 | # Combining MVTec images to form VisA samples 86 | def combine_img(organized_data, random_defect=None): # random_defect = "abnormal" 87 | img_ls = [] 88 | mask_ls = [] 89 | for i in range(4): 90 | defect = random_defect 91 | if defect is None: # Draw a normal/abnormal status per tile when none is forced 92 | defect = random.choices(list(organized_data.keys()), weights=[0.8, 0.2], k=1)[0] # With these weights roughly half of the combined images end up fully normal (0.8**4 is about 0.41) 93 | 94 | random_sample = random.choice(organized_data[defect]) 95 | 96 | img_path = os.path.join(random_sample['root'], random_sample['img_path']) 97 | mask_path = os.path.join(random_sample['root'], random_sample['mask_path']) 98 | assert (os.path.exists(img_path)) 99 | img = Image.open(img_path) 100 | img_ls.append(img) 101 | if random_sample['anomaly'] == 0: 102 | img_mask = Image.fromarray(np.zeros((img.size[1], img.size[0])), mode='L') 103 | else: 104 | assert os.path.exists(mask_path) 105 | img_mask = np.array(Image.open(mask_path).convert('L')) > 0 106 | img_mask = Image.fromarray(img_mask.astype(np.uint8) * 255, mode='L') 107 | mask_ls.append(img_mask) 108 | 109 | # image 110 | image_width, image_height = img_ls[0].size 111 | result_image = Image.new("RGB", (2 * image_width, 2 * image_height)) 112 | for i, img in enumerate(img_ls): 113 | row = i // 2 114 | col = i % 2 115 | x = col * image_width 116 | y = row * image_height 117 | result_image.paste(img, (x, y)) 118 | 119 | # mask 120 | result_mask = Image.new("L", (2 * 
image_width, 2 * image_height)) 121 | for i, img in enumerate(mask_ls): 122 | row = i // 2 123 | col = i % 2 124 | x = col * image_width 125 | y = row * image_height 126 | result_mask.paste(img, (x, y)) 127 | 128 | return result_image, result_mask 129 | 130 | class Dataset(data.Dataset): 131 | def __init__(self, roots, transform, target_transform, dataset_name, kwargs=None): 132 | self.roots = roots 133 | self.transform = transform 134 | self.target_transform = target_transform 135 | 136 | self.aug_rate = kwargs.aug_rate 137 | pr=0.20 # 0.5 # 138 | self.img_trans = A.Compose([ 139 | A.Rotate(limit=30, p=pr), 140 | A.RandomRotate90(p=pr), 141 | A.RandomBrightnessContrast(p=pr), 142 | A.GaussNoise(p=pr), 143 | A.OneOf([ 144 | A.Blur(blur_limit=3, p=pr), 145 | A.ColorJitter(p=pr), 146 | A.GaussianBlur(p=pr), 147 | ], p=pr) 148 | ], is_check_shapes=False) 149 | 150 | meta_infos = {} 151 | dataset_split='test' 152 | for root in roots: 153 | with open(f'{root}/meta.json', 'r') as f: 154 | meta_info = json.load(f) 155 | for cls in meta_info[dataset_split]: 156 | meta_info[dataset_split][cls] = [{**s, 'root': root} for s in meta_info[dataset_split][cls]] 157 | 158 | if cls in meta_infos: 159 | meta_infos[cls].extend(meta_info[dataset_split][cls]) 160 | else: 161 | meta_infos[cls] = meta_info[dataset_split][cls] 162 | self.cls_names = list(meta_infos.keys()) 163 | 164 | self.data_all = [] 165 | for cls_name in self.cls_names: 166 | self.data_all.extend(meta_infos[cls_name]) 167 | 168 | self.dataset_name = dataset_name 169 | self.class_ids = list(range(len(self.cls_names))) 170 | self.class_name_map_class_id = {k: index for k, index in zip(self.cls_names, self.class_ids)} 171 | 172 | self.portion = kwargs.portion 173 | self.k_shot = kwargs.k_shot 174 | self.mode = kwargs.type 175 | if not (self.portion == 1.0 and self.k_shot == 0): 176 | sampled_sets = self._sample(meta_infos) 177 | 178 | if self.mode == 'train': 179 | self.data_all = sampled_sets[0] # 180 | # save_train_data_paths(data_all) 181 | 182 | # elif self.mode == 'test' and kwargs.train_dataset == self.dataset_name: 183 | # self.data_all = sampled_sets[1] 184 | # compare_data_with_file(self.data_all, kwargs.training_path) 185 | 186 | # save_selected_data_paths(self.data_all, kwargs.save_path) 187 | 188 | 189 | self.organized_data = {cls_name: {'normal': [], 'abnormal': []} for cls_name in self.cls_names} 190 | for entry in self.data_all: 191 | cls_name = entry['cls_name'] 192 | anomaly_status = 'abnormal' if entry['anomaly'] == 1 else 'normal' 193 | self.organized_data[cls_name][anomaly_status].append(entry) 194 | 195 | self.length = len(self.data_all) 196 | print(f"number of samples: {self.length}") 197 | 198 | def augment(self, img , img_mask): 199 | img_mask = np.array(img_mask) 200 | img = np.array(img) 201 | augmentations = self.img_trans(mask=img_mask, image=img) 202 | img = augmentations["image"] 203 | img_mask = augmentations["mask"] 204 | img = Image.fromarray(img) 205 | img_mask = Image.fromarray(img_mask.astype(np.uint8), mode='L') 206 | return img, img_mask 207 | 208 | # Sample same number of normal and anomalous for all classes 209 | # Or to sample proportional to their length 210 | def _sample(self, meta_info): 211 | if self.portion == 1.0 and self.k_shot == 0: 212 | return self.data_all 213 | 214 | sampled_data = [] 215 | complement_data = [] 216 | 217 | for cls_name, data_list in meta_info.items(): 218 | nrm_smpls = [item for item in data_list if item['anomaly'] == 0] 219 | anm_smpls = [item for item in data_list if 
item['anomaly'] == 1] 220 | 221 | if self.k_shot > 0: 222 | n_samples = self.k_shot 223 | n_nrm_smpls = min(int(n_samples/2), len(nrm_smpls)) 224 | n_anm_smpls = min(int(n_samples/2), len(anm_smpls)) 225 | 226 | else: 227 | n_nrm_smpls = int(len(nrm_smpls)*self.portion) 228 | n_anm_smpls = int(len(anm_smpls)*self.portion) 229 | 230 | cls_data = [] 231 | cls_data.extend(random.sample(nrm_smpls, n_nrm_smpls)) 232 | cls_data.extend(random.sample(anm_smpls, n_anm_smpls)) 233 | sampled_data.extend(cls_data) 234 | 235 | complement_class_data = [item for item in data_list if item not in cls_data] 236 | complement_data.extend(complement_class_data) 237 | 238 | if self.k_shot + self.portion > 0: 239 | print(f'num samples for cls {cls_name}, norm: {n_nrm_smpls}, anom: {n_anm_smpls}') 240 | 241 | return sampled_data, complement_data 242 | 243 | def __len__(self): 244 | return self.length 245 | 246 | def _process_image(self, data): 247 | img_path, mask_path, cls_name, specie_name, anomaly = data['img_path'], data['mask_path'], \ 248 | data['cls_name'], data['specie_name'], data['anomaly'] 249 | 250 | root = data['root'] 251 | img = Image.open(os.path.join(root, img_path)) 252 | 253 | if anomaly == 0: 254 | img_mask = Image.fromarray(np.zeros((img.size[1], img.size[0])), mode='L') 255 | else: 256 | if os.path.isdir(os.path.join(root, mask_path)): 257 | img_mask = Image.fromarray(np.zeros((img.size[1], img.size[0])), mode='L') 258 | else: 259 | img_mask = np.array(Image.open(os.path.join(root, mask_path)).convert('L')) > 0 260 | img_mask = Image.fromarray(img_mask.astype(np.uint8) * 255, mode='L') 261 | 262 | random_number = random.random() 263 | if self.mode == 'train' and self.aug_rate > random_number and \ 264 | cls_name in ['hazelnut', 'pill', 'zipper', 'bottle', 'screw', 'metal_nut', 'cable']: 265 | img, img_mask = combine_img(self.organized_data[cls_name]) 266 | anomaly = 1 if np.any(np.array(img_mask) > 0) else 0 267 | # if self.mode == "train": 268 | # self.augment(img=img, img_mask=img_mask) 269 | 270 | img = self.transform(img) if self.transform is not None else img 271 | img_mask = self.target_transform(img_mask) 272 | img_mask[img_mask > 0.5] = 1 273 | img_mask[img_mask <= 0.5] = 0 274 | 275 | result = { 276 | 'img': img, 277 | 'abnorm_mask': img_mask, 278 | 'cls_name': cls_name, 279 | 'anomaly': anomaly, 280 | 'img_path': os.path.join(root, img_path), 281 | "cls_id": self.class_name_map_class_id[cls_name] 282 | } 283 | 284 | return result 285 | 286 | def __getitem__(self, index): 287 | data = self.data_all[index] 288 | result = self._process_image(data) 289 | return result -------------------------------------------------------------------------------- /segment_anything/utils/amg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | 10 | import math 11 | from copy import deepcopy 12 | from itertools import product 13 | from typing import Any, Dict, Generator, ItemsView, List, Tuple 14 | 15 | 16 | class MaskData: 17 | """ 18 | A structure for storing masks and their related data in batched format. 19 | Implements basic filtering and concatenation. 
20 | """ 21 | 22 | def __init__(self, **kwargs) -> None: 23 | for v in kwargs.values(): 24 | assert isinstance( 25 | v, (list, np.ndarray, torch.Tensor) 26 | ), "MaskData only supports list, numpy arrays, and torch tensors." 27 | self._stats = dict(**kwargs) 28 | 29 | def __setitem__(self, key: str, item: Any) -> None: 30 | assert isinstance( 31 | item, (list, np.ndarray, torch.Tensor) 32 | ), "MaskData only supports list, numpy arrays, and torch tensors." 33 | self._stats[key] = item 34 | 35 | def __delitem__(self, key: str) -> None: 36 | del self._stats[key] 37 | 38 | def __getitem__(self, key: str) -> Any: 39 | return self._stats[key] 40 | 41 | def items(self) -> ItemsView[str, Any]: 42 | return self._stats.items() 43 | 44 | def filter(self, keep: torch.Tensor) -> None: 45 | for k, v in self._stats.items(): 46 | if v is None: 47 | self._stats[k] = None 48 | elif isinstance(v, torch.Tensor): 49 | self._stats[k] = v[torch.as_tensor(keep, device=v.device)] 50 | elif isinstance(v, np.ndarray): 51 | self._stats[k] = v[keep.detach().cpu().numpy()] 52 | elif isinstance(v, list) and keep.dtype == torch.bool: 53 | self._stats[k] = [a for i, a in enumerate(v) if keep[i]] 54 | elif isinstance(v, list): 55 | self._stats[k] = [v[i] for i in keep] 56 | else: 57 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 58 | 59 | def cat(self, new_stats: "MaskData") -> None: 60 | for k, v in new_stats.items(): 61 | if k not in self._stats or self._stats[k] is None: 62 | self._stats[k] = deepcopy(v) 63 | elif isinstance(v, torch.Tensor): 64 | self._stats[k] = torch.cat([self._stats[k], v], dim=0) 65 | elif isinstance(v, np.ndarray): 66 | self._stats[k] = np.concatenate([self._stats[k], v], axis=0) 67 | elif isinstance(v, list): 68 | self._stats[k] = self._stats[k] + deepcopy(v) 69 | else: 70 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 71 | 72 | def to_numpy(self) -> None: 73 | for k, v in self._stats.items(): 74 | if isinstance(v, torch.Tensor): 75 | self._stats[k] = v.detach().cpu().numpy() 76 | 77 | 78 | def is_box_near_crop_edge( 79 | boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 80 | ) -> torch.Tensor: 81 | """Filter masks at the edge of a crop, but not at the edge of the original image.""" 82 | crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) 83 | orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) 84 | boxes = uncrop_boxes_xyxy(boxes, crop_box).float() 85 | near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) 86 | near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) 87 | near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) 88 | return torch.any(near_crop_edge, dim=1) 89 | 90 | 91 | def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: 92 | box_xywh = deepcopy(box_xyxy) 93 | box_xywh[2] = box_xywh[2] - box_xywh[0] 94 | box_xywh[3] = box_xywh[3] - box_xywh[1] 95 | return box_xywh 96 | 97 | 98 | def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: 99 | assert len(args) > 0 and all( 100 | len(a) == len(args[0]) for a in args 101 | ), "Batched iteration must have inputs of all the same size." 
102 | n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) 103 | for b in range(n_batches): 104 | yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] 105 | 106 | 107 | def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: 108 | """ 109 | Encodes masks to an uncompressed RLE, in the format expected by 110 | pycoco tools. 111 | """ 112 | # Put in fortran order and flatten h,w 113 | b, h, w = tensor.shape 114 | tensor = tensor.permute(0, 2, 1).flatten(1) 115 | 116 | # Compute change indices 117 | diff = tensor[:, 1:] ^ tensor[:, :-1] 118 | change_indices = diff.nonzero() 119 | 120 | # Encode run length 121 | out = [] 122 | for i in range(b): 123 | cur_idxs = change_indices[change_indices[:, 0] == i, 1] 124 | cur_idxs = torch.cat( 125 | [ 126 | torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), 127 | cur_idxs + 1, 128 | torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), 129 | ] 130 | ) 131 | btw_idxs = cur_idxs[1:] - cur_idxs[:-1] 132 | counts = [] if tensor[i, 0] == 0 else [0] 133 | counts.extend(btw_idxs.detach().cpu().tolist()) 134 | out.append({"size": [h, w], "counts": counts}) 135 | return out 136 | 137 | 138 | def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: 139 | """Compute a binary mask from an uncompressed RLE.""" 140 | h, w = rle["size"] 141 | mask = np.empty(h * w, dtype=bool) 142 | idx = 0 143 | parity = False 144 | for count in rle["counts"]: 145 | mask[idx : idx + count] = parity 146 | idx += count 147 | parity ^= True 148 | mask = mask.reshape(w, h) 149 | return mask.transpose() # Put in C order 150 | 151 | 152 | def area_from_rle(rle: Dict[str, Any]) -> int: 153 | return sum(rle["counts"][1::2]) 154 | 155 | 156 | def calculate_stability_score( 157 | masks: torch.Tensor, mask_threshold: float, threshold_offset: float 158 | ) -> torch.Tensor: 159 | """ 160 | Computes the stability score for a batch of masks. The stability 161 | score is the IoU between the binary masks obtained by thresholding 162 | the predicted mask logits at high and low values. 163 | """ 164 | # One mask is always contained inside the other. 
165 | # Save memory by preventing unnecessary cast to torch.int64 166 | intersections = ( 167 | (masks > (mask_threshold + threshold_offset)) 168 | .sum(-1, dtype=torch.int16) 169 | .sum(-1, dtype=torch.int32) 170 | ) 171 | unions = ( 172 | (masks > (mask_threshold - threshold_offset)) 173 | .sum(-1, dtype=torch.int16) 174 | .sum(-1, dtype=torch.int32) 175 | ) 176 | return intersections / unions 177 | 178 | 179 | def build_point_grid(n_per_side: int) -> np.ndarray: 180 | """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" 181 | offset = 1 / (2 * n_per_side) 182 | points_one_side = np.linspace(offset, 1 - offset, n_per_side) 183 | points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) 184 | points_y = np.tile(points_one_side[:, None], (1, n_per_side)) 185 | points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) 186 | return points 187 | 188 | 189 | def build_all_layer_point_grids( 190 | n_per_side: int, n_layers: int, scale_per_layer: int 191 | ) -> List[np.ndarray]: 192 | """Generates point grids for all crop layers.""" 193 | points_by_layer = [] 194 | for i in range(n_layers + 1): 195 | n_points = int(n_per_side / (scale_per_layer**i)) 196 | points_by_layer.append(build_point_grid(n_points)) 197 | return points_by_layer 198 | 199 | 200 | def generate_crop_boxes( 201 | im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float 202 | ) -> Tuple[List[List[int]], List[int]]: 203 | """ 204 | Generates a list of crop boxes of different sizes. Each layer 205 | has (2**i)**2 boxes for the ith layer. 206 | """ 207 | crop_boxes, layer_idxs = [], [] 208 | im_h, im_w = im_size 209 | short_side = min(im_h, im_w) 210 | 211 | # Original image 212 | crop_boxes.append([0, 0, im_w, im_h]) 213 | layer_idxs.append(0) 214 | 215 | def crop_len(orig_len, n_crops, overlap): 216 | return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) 217 | 218 | for i_layer in range(n_layers): 219 | n_crops_per_side = 2 ** (i_layer + 1) 220 | overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) 221 | 222 | crop_w = crop_len(im_w, n_crops_per_side, overlap) 223 | crop_h = crop_len(im_h, n_crops_per_side, overlap) 224 | 225 | crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] 226 | crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] 227 | 228 | # Crops in XYWH format 229 | for x0, y0 in product(crop_box_x0, crop_box_y0): 230 | box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] 231 | crop_boxes.append(box) 232 | layer_idxs.append(i_layer + 1) 233 | 234 | return crop_boxes, layer_idxs 235 | 236 | 237 | def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 238 | x0, y0, _, _ = crop_box 239 | offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) 240 | # Check if boxes has a channel dimension 241 | if len(boxes.shape) == 3: 242 | offset = offset.unsqueeze(1) 243 | return boxes + offset 244 | 245 | 246 | def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 247 | x0, y0, _, _ = crop_box 248 | offset = torch.tensor([[x0, y0]], device=points.device) 249 | # Check if points has a channel dimension 250 | if len(points.shape) == 3: 251 | offset = offset.unsqueeze(1) 252 | return points + offset 253 | 254 | 255 | def uncrop_masks( 256 | masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int 257 | ) -> torch.Tensor: 258 | x0, y0, x1, y1 = crop_box 259 | if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: 260 | return masks 261 | # Coordinate 
transform masks 262 | pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) 263 | pad = (x0, pad_x - x0, y0, pad_y - y0) 264 | return torch.nn.functional.pad(masks, pad, value=0) 265 | 266 | 267 | def remove_small_regions( 268 | mask: np.ndarray, area_thresh: float, mode: str 269 | ) -> Tuple[np.ndarray, bool]: 270 | """ 271 | Removes small disconnected regions and holes in a mask. Returns the 272 | mask and an indicator of if the mask has been modified. 273 | """ 274 | import cv2 # type: ignore 275 | 276 | assert mode in ["holes", "islands"] 277 | correct_holes = mode == "holes" 278 | working_mask = (correct_holes ^ mask).astype(np.uint8) 279 | n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) 280 | sizes = stats[:, -1][1:] # Row 0 is background label 281 | small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] 282 | if len(small_regions) == 0: 283 | return mask, False 284 | fill_labels = [0] + small_regions 285 | if not correct_holes: 286 | fill_labels = [i for i in range(n_labels) if i not in fill_labels] 287 | # If every region is below threshold, keep largest 288 | if len(fill_labels) == 0: 289 | fill_labels = [int(np.argmax(sizes)) + 1] 290 | mask = np.isin(regions, fill_labels) 291 | return mask, True 292 | 293 | 294 | def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: 295 | from pycocotools import mask as mask_utils # type: ignore 296 | 297 | h, w = uncompressed_rle["size"] 298 | rle = mask_utils.frPyObjects(uncompressed_rle, h, w) 299 | rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json 300 | return rle 301 | 302 | 303 | def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: 304 | """ 305 | Calculates boxes in XYXY format around masks. Return [0,0,0,0] for 306 | an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. 307 | """ 308 | # torch.max below raises an error on empty inputs, just skip in this case 309 | if torch.numel(masks) == 0: 310 | return torch.zeros(*masks.shape[:-2], 4, device=masks.device) 311 | 312 | # Normalize shape to CxHxW 313 | shape = masks.shape 314 | h, w = shape[-2:] 315 | if len(shape) > 2: 316 | masks = masks.flatten(0, -3) 317 | else: 318 | masks = masks.unsqueeze(0) 319 | 320 | # Get top and bottom edges 321 | in_height, _ = torch.max(masks, dim=-1) 322 | in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] 323 | bottom_edges, _ = torch.max(in_height_coords, dim=-1) 324 | in_height_coords = in_height_coords + h * (~in_height) 325 | top_edges, _ = torch.min(in_height_coords, dim=-1) 326 | 327 | # Get left and right edges 328 | in_width, _ = torch.max(masks, dim=-2) 329 | in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] 330 | right_edges, _ = torch.max(in_width_coords, dim=-1) 331 | in_width_coords = in_width_coords + w * (~in_width) 332 | left_edges, _ = torch.min(in_width_coords, dim=-1) 333 | 334 | # If the mask is empty the right edge will be to the left of the left edge. 335 | # Replace these boxes with [0, 0, 0, 0] 336 | empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) 337 | out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) 338 | out = out * (~empty_filter).unsqueeze(-1) 339 | 340 | # Return to original shape 341 | if len(shape) > 2: 342 | out = out.reshape(*shape[:-2], 4) 343 | else: 344 | out = out[0] 345 | 346 | return out 347 | --------------------------------------------------------------------------------
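As a quick illustration of how the RLE helpers in segment_anything/utils/amg.py fit together, below is a minimal sketch (not part of the repository) that encodes a small boolean mask batch with mask_to_rle_pytorch, decodes it back with rle_to_mask, and reads its foreground area with area_from_rle; it assumes the repository root is on the Python path so that segment_anything is importable.

import numpy as np
import torch

from segment_anything.utils.amg import area_from_rle, mask_to_rle_pytorch, rle_to_mask

# Build a 1xHxW boolean mask batch with a 2x2 foreground square.
masks = torch.zeros((1, 4, 4), dtype=torch.bool)
masks[0, 1:3, 1:3] = True

rles = mask_to_rle_pytorch(masks)   # one uncompressed RLE dict per mask
decoded = rle_to_mask(rles[0])      # back to an HxW numpy bool array

assert np.array_equal(decoded, masks[0].numpy())
assert area_from_rle(rles[0]) == 4  # number of foreground pixels

This round-trip uses the same uncompressed RLE format ("size" plus alternating background/foreground "counts", in the layout expected by pycocotools) that the mask post-processing utilities above produce and consume.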