├── medsyn ├── __init__.py ├── utils.py ├── labelling.py ├── task_shape.py └── tasks.py ├── models ├── __init__.py ├── Adapter.py ├── MapMaker.py ├── Necker.py └── CoOp.py ├── utils ├── __init__.py ├── misc_helper.py └── losses.py ├── datasets ├── __init__.py └── dataset.py ├── assets ├── vis.jpg └── pipeline.jpg ├── open_clip ├── bpe_simple_vocab_16e6.txt.gz ├── constants.py ├── model_configs │ ├── ViT-B-16.json │ ├── ViT-B-32.json │ ├── ViT-M-16.json │ ├── ViT-M-32.json │ ├── ViT-S-16.json │ ├── ViT-S-32.json │ ├── ViT-B-16-plus.json │ ├── ViT-L-14-280.json │ ├── ViT-L-14-336.json │ ├── ViT-L-14.json │ ├── ViT-L-16-320.json │ ├── ViT-L-16.json │ ├── ViT-M-32-alt.json │ ├── ViT-S-16-alt.json │ ├── ViT-S-32-alt.json │ ├── ViT-B-16-plus-240.json │ ├── ViT-B-32-plus-256.json │ ├── ViT-H-14.json │ ├── ViT-H-16.json │ ├── ViT-B-32-quickgelu.json │ ├── ViT-M-16-alt.json │ ├── mt5-base-ViT-B-32.json │ ├── xlm-roberta-base-ViT-B-32.json │ ├── roberta-ViT-B-32.json │ ├── ViT-e-14.json │ ├── ViT-g-14.json │ ├── mt5-xl-ViT-H-14.json │ ├── ViT-bigG-14.json │ ├── xlm-roberta-large-ViT-H-14.json │ ├── RN50.json │ ├── RN101.json │ ├── RN50x16.json │ ├── RN50x4.json │ ├── RN50x64.json │ ├── vit_medium_patch16_gap_256.json │ ├── swin_base_patch4_window7_224.json │ ├── vit_relpos_medium_patch16_cls_224.json │ ├── RN101-quickgelu.json │ ├── RN50-quickgelu.json │ ├── convnext_base.json │ ├── convnext_base_w.json │ ├── convnext_large.json │ ├── convnext_large_d.json │ ├── convnext_small.json │ ├── convnext_tiny.json │ ├── convnext_base_w_320.json │ ├── convnext_large_d_320.json │ ├── convnext_xlarge.json │ ├── convnext_xxlarge.json │ └── convnext_xxlarge_320.json ├── __init__.py ├── transform.py ├── utils.py ├── openai.py ├── factory.py ├── tokenizer.py ├── modified_resnet.py ├── pretrained.py └── model.py ├── requirements.txt ├── LICENSE ├── config ├── busi.yaml ├── brainmri.yaml └── chexpert.yaml ├── data └── README.md ├── README.md ├── test.py └── train.py /medsyn/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/vis.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnulab/MediCLIP/HEAD/assets/vis.jpg -------------------------------------------------------------------------------- /assets/pipeline.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnulab/MediCLIP/HEAD/assets/pipeline.jpg -------------------------------------------------------------------------------- /open_clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnulab/MediCLIP/HEAD/open_clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /open_clip/constants.py: 
-------------------------------------------------------------------------------- 1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) 2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.11.0+cu113 2 | torchvision==0.12.0+cu113 3 | Pillow==9.1.1 4 | scikit-image==0.19.3 5 | scikit-learn==1.1.2 6 | sklearn==0.0 7 | opencv-python==4.6.0.66 8 | grad_cam==1.4.3 9 | tqdm==4.61.2 10 | PyYAML==6.0 11 | easydict==1.9 12 | ftfy==6.1.3 13 | regex==2023.12.25 14 | imgaug==0.4.0 15 | numpy==1.22.4 16 | -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-M-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-M-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-S-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 384, 13 | "heads": 6, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-S-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 384, 13 | "heads": 6, 14 | "layers": 12 
15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-B-16-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-L-14-280.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 280, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-L-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-L-16-320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 320, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-L-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-M-32-alt.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 384, 13 | "heads": 6, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-S-16-alt.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 256, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 256, 13 | "heads": 4, 14 | "layers": 10 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-S-32-alt.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 256, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 256, 13 | "heads": 4, 14 | "layers": 10 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-B-16-plus-240.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 240, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-B-32-plus-256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 256, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-H-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-H-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 16 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-B-32-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 12, 7 | "width": 768, 8 | "patch_size": 32 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-M-16-alt.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 
384, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 16, 8 | "ls_init_value": 1e-4 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 384, 14 | "heads": 6, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip/model_configs/mt5-base-ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "hf_model_name": "google/mt5-base", 11 | "hf_tokenizer_name": "google/mt5-base", 12 | "proj": "mlp", 13 | "pooler_type": "mean_pooler" 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /open_clip/model_configs/xlm-roberta-base-ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "hf_model_name": "xlm-roberta-base", 11 | "hf_tokenizer_name": "xlm-roberta-base", 12 | "proj": "mlp", 13 | "pooler_type": "mean_pooler" 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /open_clip/model_configs/roberta-ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 12, 7 | "width": 768, 8 | "patch_size": 32 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "roberta-base", 12 | "hf_tokenizer_name": "roberta-base", 13 | "proj": "mlp", 14 | "pooler_type": "mean_pooler" 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-e-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 56, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.5715, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1280, 15 | "heads": 20, 16 | "layers": 36 17 | } 18 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-g-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1024, 15 | "heads": 16, 16 | "layers": 24 17 | } 18 | } -------------------------------------------------------------------------------- /open_clip/model_configs/mt5-xl-ViT-H-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "google/mt5-xl", 12 | "hf_tokenizer_name": "google/mt5-xl", 13 | "proj": "mlp", 14 | "pooler_type": "mean_pooler" 15 | } 16 | } 17 | 
-------------------------------------------------------------------------------- /open_clip/model_configs/ViT-bigG-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 48, 6 | "width": 1664, 7 | "head_width": 104, 8 | "mlp_ratio": 4.9231, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1280, 15 | "heads": 20, 16 | "layers": 32 17 | } 18 | } -------------------------------------------------------------------------------- /open_clip/model_configs/xlm-roberta-large-ViT-H-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "xlm-roberta-large", 12 | "hf_tokenizer_name": "xlm-roberta-large", 13 | "proj": "mlp", 14 | "pooler_type": "mean_pooler" 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /open_clip/model_configs/RN50.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 6, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /open_clip/model_configs/RN101.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 23, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /open_clip/model_configs/RN50x16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 384, 5 | "layers": [ 6 | 6, 7 | 8, 8 | 18, 9 | 8 10 | ], 11 | "width": 96, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 768, 18 | "heads": 12, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /open_clip/model_configs/RN50x4.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 288, 5 | "layers": [ 6 | 4, 7 | 6, 8 | 10, 9 | 6 10 | ], 11 | "width": 80, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 640, 18 | "heads": 10, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /open_clip/model_configs/RN50x64.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 448, 5 | "layers": [ 6 | 3, 7 | 15, 8 | 36, 9 | 10 10 | ], 11 | "width": 128, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | 
"context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 1024, 18 | "heads": 16, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /open_clip/model_configs/vit_medium_patch16_gap_256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "vit_medium_patch16_gap_256", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 256 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip/model_configs/swin_base_patch4_window7_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "timm_model_name": "swin_base_patch4_window7_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 640, 14 | "heads": 10, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "vit_relpos_medium_patch16_cls_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip/model_configs/RN101-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 23, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } -------------------------------------------------------------------------------- /open_clip/model_configs/RN50-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 6, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /open_clip/model_configs/convnext_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_base", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 224 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | 
} -------------------------------------------------------------------------------- /open_clip/model_configs/convnext_base_w.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_base", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 256 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 640, 16 | "heads": 10, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip/model_configs/convnext_large.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_large", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 224 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 768, 16 | "heads": 12, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip/model_configs/convnext_large_d.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_large", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "mlp", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 256 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 768, 16 | "heads": 12, 17 | "layers": 16 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip/model_configs/convnext_small.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_small", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 224 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip/model_configs/convnext_tiny.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_tiny", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 224 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip/model_configs/convnext_base_w_320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_base", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 320 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | 
"width": 640, 16 | "heads": 10, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip/model_configs/convnext_large_d_320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_large", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "mlp", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 320 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 768, 16 | "heads": 12, 17 | "layers": 16 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip/model_configs/convnext_xlarge.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_xlarge", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 256 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 1024, 16 | "heads": 16, 17 | "layers": 20 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip/model_configs/convnext_xxlarge.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_xxlarge", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 256 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 1024, 16 | "heads": 16, 17 | "layers": 24 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip/model_configs/convnext_xxlarge_320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_xxlarge", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 320 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 1024, 16 | "heads": 16, 17 | "layers": 24 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 2 | from .factory import create_model, create_model_and_transforms, get_tokenizer 3 | from .factory import list_models, add_model_config, get_model_config, load_checkpoint 4 | from .model import CLIP, CLIPTextCfg, CLIPVisionCfg, \ 5 | convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype 6 | from .openai import load_openai_model, list_openai_models 7 | from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \ 8 | get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained 9 | from .tokenizer import SimpleTokenizer, tokenize, decode 10 | from .transform import image_transform 11 | -------------------------------------------------------------------------------- /models/Adapter.py: 
-------------------------------------------------------------------------------- 1 | 2 | import torch.nn as nn 3 | 4 | class Adapter(nn.Module): 5 | 6 | def __init__(self, 7 | clip_model, 8 | target, 9 | ): 10 | 11 | super(Adapter, self).__init__() 12 | 13 | input_sizes = clip_model.token_c 14 | 15 | for i,input_size in enumerate(input_sizes): 16 | self.add_module("{}_adapter".format(i), nn.Sequential(nn.Conv2d(input_size, target, 1, 1))) 17 | 18 | 19 | def forward(self, tokens): 20 | vision_features=[] 21 | for i,token in enumerate(tokens): 22 | vision_feature=getattr(self,'{}_adapter'.format(i))(token).contiguous().permute(0, 2, 3, 1) 23 | vision_feature = vision_feature / vision_feature.norm(dim=-1, keepdim=True) 24 | vision_features.append(vision_feature) 25 | return vision_features -------------------------------------------------------------------------------- /models/MapMaker.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | import torch.nn.functional as F 6 | 7 | 8 | class MapMaker(nn.Module): 9 | 10 | def __init__(self,image_size): 11 | 12 | super(MapMaker, self).__init__() 13 | self.image_size = image_size 14 | 15 | 16 | def forward(self, vision_adapter_features,propmt_adapter_features): 17 | anomaly_maps=[] 18 | 19 | for i,vision_adapter_feature in enumerate(vision_adapter_features): 20 | B, H, W, C = vision_adapter_feature.shape 21 | anomaly_map = (vision_adapter_feature.view((B, H * W, C)) @ propmt_adapter_features).contiguous().view( 22 | (B, H, W, -1)).permute(0, 3, 1, 2) 23 | 24 | anomaly_maps.append(anomaly_map) 25 | 26 | anomaly_map = torch.stack(anomaly_maps, dim=0).mean(dim=0) 27 | anomaly_map = F.interpolate(anomaly_map, (self.image_size, self.image_size), mode='bilinear', align_corners=True) 28 | return torch.softmax(anomaly_map, dim=1) -------------------------------------------------------------------------------- /models/Necker.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | 6 | 7 | class Necker(nn.Module): 8 | 9 | def __init__(self, 10 | clip_model 11 | ): 12 | super(Necker, self).__init__() 13 | self.clip_model=clip_model 14 | target = max(self.clip_model.token_size) 15 | for i,size in enumerate(self.clip_model.token_size): 16 | self.add_module("{}_upsample".format(i), 17 | nn.UpsamplingBilinear2d(scale_factor=target/size)) 18 | 19 | 20 | @torch.no_grad() 21 | def forward(self, tokens): 22 | align_features=[] 23 | for i,token in enumerate(tokens): 24 | if len(token.shape) == 3: 25 | B, N, C=token.shape 26 | token = token[:, 1:, :] 27 | token=token.view((B,int(math.sqrt(N-1)),int(math.sqrt(N-1)),C)).permute(0, 3, 1, 2) 28 | align_features.append(getattr(self, "{}_upsample".format(i))(token)) 29 | return align_features 30 | 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /open_clip/transform.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from dataclasses import dataclass, asdict 3 | from typing import Any, Dict, Optional, Sequence, Tuple, Union 4 | 5 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ 6 | CenterCrop 7 | 8 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 9 | 10 | 11 | def _convert_to_rgb(image): 12 | return image.convert('RGB') 13 | 14 | 15 | 16 | def 
image_transform( 17 | mean: Optional[Tuple[float, ...]] = None, 18 | std: Optional[Tuple[float, ...]] = None, 19 | ): 20 | 21 | mean = mean or OPENAI_DATASET_MEAN 22 | if not isinstance(mean, (list, tuple)): 23 | mean = (mean,) * 3 24 | 25 | std = std or OPENAI_DATASET_STD 26 | if not isinstance(std, (list, tuple)): 27 | std = (std,) * 3 28 | 29 | normalize = Normalize(mean=mean, std=std) 30 | 31 | transforms = [ 32 | _convert_to_rgb, 33 | ToTensor(), 34 | normalize, 35 | ] 36 | 37 | return Compose(transforms) 38 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 cnulab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /config/busi.yaml: -------------------------------------------------------------------------------- 1 | version: v1.0.0 2 | random_seed: 100 3 | 4 | data_root: data 5 | 6 | train_dataset: busi 7 | test_datasets: [busi] 8 | 9 | epoch: 200 10 | batch_size: 8 11 | 12 | print_freq_step: 10 13 | val_freq_epoch: 1 14 | 15 | image_size: 224 16 | model_name: ViT-L-14 17 | layers_out: [12,18,24] 18 | 19 | anomaly_tasks: 20 | CutpasteTask: 0.25 21 | GaussIntensityChangeTask: 0.25 22 | SourceTask: 0.25 23 | IdentityTask: 0.25 24 | 25 | prompt_maker: coop 26 | n_learnable_token: 8 27 | CSC: True 28 | class_token_positions: [end] 29 | 30 | save_root: results 31 | 32 | prompts: 33 | normal: [ 34 | normal, 35 | healthy, 36 | negative, 37 | unremarkable, 38 | clear, 39 | asymptomatic, 40 | normal findings, 41 | no findings, 42 | in good health, 43 | no evidence of disease 44 | ] 45 | 46 | abnormal: [ 47 | abnormal, 48 | positive, 49 | symptomatic, 50 | disease, 51 | lesion, 52 | pathological, 53 | impaired, 54 | evidence of disease, 55 | abnormal finding, 56 | pathological condition, 57 | pathological abnormality 58 | ] -------------------------------------------------------------------------------- /config/brainmri.yaml: -------------------------------------------------------------------------------- 1 | version: v1.0.0 2 | random_seed: 100 3 | 4 | data_root: data 5 | train_dataset: brainmri 6 | test_datasets: [brainmri] 7 | 8 | epoch: 200 9 | batch_size: 8 10 | 11 | print_freq_step: 10 12 | val_freq_epoch: 1 13 | 14 | image_size: 224 15 | model_name: ViT-L-14 16 | layers_out: [12,18,24] 17 | 18 | anomaly_tasks: 19 | CutpasteTask: 0.25 20 | GaussIntensityChangeTask: 0.25 21 | SourceTask: 0.25 22 | IdentityTask: 0.25 23 | 24 | 25 | prompt_maker: coop 26 | n_learnable_token: 8 27 | CSC: True 28 | class_token_positions: [end] 29 | 30 | save_root: results 31 | 32 | prompts: 33 | normal: [ 34 | normal, 35 | healthy, 36 | negative, 37 | unremarkable, 38 | clear, 39 | asymptomatic, 40 | normal findings, 41 | no findings, 42 | in good health, 43 | no evidence of disease 44 | ] 45 | 46 | abnormal: [ 47 | abnormal, 48 | positive, 49 | symptomatic, 50 | disease, 51 | lesion, 52 | pathological, 53 | impaired, 54 | evidence of disease, 55 | abnormal finding, 56 | pathological condition, 57 | pathological abnormality 58 | ] -------------------------------------------------------------------------------- /config/chexpert.yaml: -------------------------------------------------------------------------------- 1 | version: v1.0.0 2 | random_seed: 100 3 | 4 | data_root: data 5 | 6 | train_dataset: chexpert 7 | test_datasets: [chexpert] 8 | 9 | epoch: 200 10 | batch_size: 8 11 | 12 | print_freq_step: 10 13 | val_freq_epoch: 1 14 | 15 | image_size: 224 16 | model_name: ViT-L-14 17 | layers_out: [12,18,24] 18 | 19 | anomaly_tasks: 20 | CutpasteTask: 0.25 21 | GaussIntensityChangeTask: 0.25 22 | SourceTask: 0.25 23 | IdentityTask: 0.25 24 | 25 | 26 | prompt_maker: coop 27 | n_learnable_token: 8 28 | CSC: True 29 | class_token_positions: [end] 30 | 31 | 32 | save_root: results 33 | 34 | prompts: 35 | normal: [ 36 | normal, 37 | healthy, 38 | negative, 39 | unremarkable, 40 | clear, 41 | asymptomatic, 42 | normal findings, 43 | no findings, 44 | in good health, 45 | no evidence of disease 46 | ] 47 | 48 | abnormal: [ 49 | abnormal, 50 | positive, 51 | symptomatic, 52 | disease, 53 | lesion, 54 | pathological, 55 | impaired, 56 | evidence of 
disease, 57 | abnormal finding, 58 | pathological condition, 59 | pathological abnormality 60 | ] -------------------------------------------------------------------------------- /medsyn/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import numpy as np 3 | import numpy.typing as npt 4 | 5 | 6 | def accumulate_rotation(init_matrix: npt.NDArray, angle: float, axes: Tuple[int, int]) -> npt.NDArray: 7 | """ 8 | Calculates rotation matrix and multiplies it with current transformation matrix. 9 | :param init_matrix: Current transformation matrix. 10 | :param angle: Angle of rotation in radians. 11 | :param axes: Tuple of axes to be rotated within. 12 | """ 13 | 14 | cos_ang = np.cos(angle) 15 | sin_ang = np.sin(angle) 16 | 17 | rot_mat = np.identity(len(init_matrix)) 18 | 19 | # Note: as the axes are always in ascending order, that does mean that for N>2, some rotations may be backwards 20 | # (due to the sine/-sine always being in wrong places) 21 | # BUT because we're sampling it symmetrically around 0 the distribution doesn't change. 22 | a1, a2 = axes 23 | rot_mat[a1, a1] = rot_mat[a2, a2] = cos_ang 24 | rot_mat[a1, a2] = -sin_ang 25 | rot_mat[a2, a1] = sin_ang 26 | return rot_mat @ init_matrix 27 | 28 | 29 | def accumulate_scaling(init_matrix: npt.NDArray, scale: float) -> npt.NDArray: 30 | """ 31 | Calculates scaling matrix and multiplies it with current transformation matrix. 32 | Assumes cartesian coordinates (not homogeneous). 33 | :param init_matrix: Current transformation matrix 34 | :param scale: Factor to scale by. 35 | """ 36 | scale_mat = np.identity(len(init_matrix)) * scale 37 | 38 | return scale_mat @ init_matrix 39 | 40 | 41 | def get_patch_slices(patch_corner: np.ndarray, patch_shape: Tuple[int]) -> Tuple[slice]: 42 | return tuple([slice(c, c + d) for (c, d) in zip(patch_corner, patch_shape)]) 43 | 44 | 45 | # Same as above, but with additional slice at beginning to include all image channels. 46 | def get_patch_image_slices(patch_corner: np.ndarray, patch_shape: Tuple[int]) -> Tuple[slice]: 47 | return tuple([slice(None)] + list(get_patch_slices(patch_corner, patch_shape))) 48 | 49 | 50 | def nsa_sample_dimension(lb, ub, img_d): 51 | gamma_lb = 0.03 52 | gamma_shape = 2 53 | gamma_scale = 0.1 54 | 55 | gamma_sample = (gamma_lb + np.random.gamma(gamma_shape, gamma_scale)) * img_d 56 | 57 | return int(np.clip(gamma_sample, lb, ub)) 58 | -------------------------------------------------------------------------------- /open_clip/utils.py: -------------------------------------------------------------------------------- 1 | from itertools import repeat 2 | import collections.abc 3 | 4 | from torch import nn as nn 5 | from torchvision.ops.misc import FrozenBatchNorm2d 6 | 7 | 8 | def freeze_batch_norm_2d(module, module_match={}, name=''): 9 | """ 10 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is 11 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and 12 | returned. Otherwise, the module is walked recursively and submodules are converted in place. 13 | 14 | Args: 15 | module (torch.nn.Module): Any PyTorch module. 
16 | module_match (dict): Dictionary of full module names to freeze (all if empty) 17 | name (str): Full module name (prefix) 18 | 19 | Returns: 20 | torch.nn.Module: Resulting module 21 | 22 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 23 | """ 24 | res = module 25 | is_match = True 26 | if module_match: 27 | is_match = name in module_match 28 | if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): 29 | res = FrozenBatchNorm2d(module.num_features) 30 | res.num_features = module.num_features 31 | res.affine = module.affine 32 | if module.affine: 33 | res.weight.data = module.weight.data.clone().detach() 34 | res.bias.data = module.bias.data.clone().detach() 35 | res.running_mean.data = module.running_mean.data 36 | res.running_var.data = module.running_var.data 37 | res.eps = module.eps 38 | else: 39 | for child_name, child in module.named_children(): 40 | full_child_name = '.'.join([name, child_name]) if name else child_name 41 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name) 42 | if new_child is not child: 43 | res.add_module(child_name, new_child) 44 | return res 45 | 46 | 47 | # From PyTorch internals 48 | def _ntuple(n): 49 | def parse(x): 50 | if isinstance(x, collections.abc.Iterable): 51 | return x 52 | return tuple(repeat(x, n)) 53 | return parse 54 | 55 | 56 | to_1tuple = _ntuple(1) 57 | to_2tuple = _ntuple(2) 58 | to_3tuple = _ntuple(3) 59 | to_4tuple = _ntuple(4) 60 | to_ntuple = lambda n, x: _ntuple(n)(x) 61 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | ### Datasets Links 2 | - **BUSI [[Baidu Cloud (pass8866)]](https://pan.baidu.com/s/1EVt96fExiqrvMQslPDRRRg?pwd=8866) [[Google Drive]](https://drive.google.com/file/d/1PyvMXdNEVY86BY1PV8yKhPVS30TAmS6X/view?usp=drive_link) [[Official Link]](https://scholar.cu.edu.eg/?q=afahmy/pages/dataset)** 3 | - **BrainMRI [[Baidu Cloud (pass8866)]](https://pan.baidu.com/s/1--5vPMN-eTqePPYjpKTwvA?pwd=8866) [[Google Drive]](https://drive.google.com/file/d/1kldE-5_wXaN-JR_8Y_mRCKQ6VZiyv3km/view?usp=drive_link) [[Official Link]](https://www.kaggle.com/datasets/navoneel/brain-mri-images-for-brain-tumor-detection)** 4 | - **CheXpert [[Baidu Cloud (pass8866)]](https://pan.baidu.com/s/15-V5wobA_7ICvZAXBraDGA?pwd=8866) [[Google Drive]](https://drive.google.com/file/d/1pVYRipGC2VqjYP-wHdDFR-lLf7itLiUi/view?usp=drive_link) [[Official Link]](https://stanfordmlgroup.github.io/competitions/chexpert/)** 5 | 6 | ### The complete directory structure is as follows: 7 | ``` 8 | |--data 9 | |--brainmri 10 | |--samples 11 | |--train.json 12 | |--test.json 13 | |--images 14 | |--test 15 | |--abnormal 16 | |--image_97.jpg 17 | |--... 18 | |--normal 19 | |--image_32.jpg 20 | |--... 21 | |--train 22 | |--normal 23 | |--image_0.jpg 24 | |--... 25 | |--busi 26 | |--samples 27 | |--train.json 28 | |--test.json 29 | |--images 30 | |--test 31 | |--abnormal 32 | |--benign_0.jpg 33 | |--... 34 | |--normal 35 | |--normal_32.jpg 36 | |--... 37 | |--ground_true 38 | |--benign_mask_0.jpg 39 | |--... 40 | |--train 41 | |--normal 42 | |--normal_0.jpg 43 | |--... 44 | |--chexpert 45 | |--samples 46 | |--train.json 47 | |--test.json 48 | |--images 49 | |--test 50 | |--abnormal 51 | |--00002.jpg 52 | |--... 53 | |--normal 54 | |--00960.jpg 55 | |--... 
56 | |--train 57 | |--normal 58 | |--00000.jpg 59 | |--... 60 | ``` 61 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MediCLIP 2 | 3 | **💡 This is the official implementation of the paper "MediCLIP: Adapting CLIP for Few-shot Medical Image Anomaly Detection"(MICCAI 2024) [[arxiv]](https://arxiv.org/abs/2405.11315)**. 4 | 5 | MediCLIP is an efficient few-shot medical image anomaly detection method, demonstrating SOTA anomaly detection performance with very few normal medical images. MediCLIP integrates learnable prompts, adapters, and realistic medical image anomaly synthesis tasks. 6 | 7 | 8 |
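For orientation, the sketch below shows one way the modules under `models/` (Necker, Adapter, MapMaker) fit together. It is an illustrative stand-in, not the actual pipeline in `train.py`: the token shapes assume a ViT-L-14 backbone (16×16 patch tokens of width 1024), the `clip_stub` only mimics the `token_size` / `token_c` attributes those modules read, and the random prompt features stand in for the learned CoOp text embeddings.

```
import torch
from types import SimpleNamespace
from models.Necker import Necker
from models.Adapter import Adapter
from models.MapMaker import MapMaker

# Stand-in for the (modified) CLIP visual encoder:
# Necker reads .token_size and Adapter reads .token_c.
clip_stub = SimpleNamespace(token_size=[16, 16], token_c=[1024, 1024])

necker = Necker(clip_stub)                 # upsamples every token map to a common resolution
adapter = Adapter(clip_stub, target=768)   # 1x1 convs + L2 norm into a shared feature dimension
map_maker = MapMaker(image_size=224)       # feature-prompt similarity -> softmaxed anomaly map

# Two layers of ViT tokens for a batch of 2 images: (B, 1 + H*W, C), CLS token included
tokens = [torch.randn(2, 1 + 16 * 16, 1024) for _ in range(2)]
prompt_features = torch.randn(768, 2)      # placeholder for the normal/abnormal prompt embeddings

aligned = necker(tokens)                               # list of (2, 1024, 16, 16)
vision_features = adapter(aligned)                     # list of (2, 16, 16, 768), L2-normalised
anomaly_map = map_maker(vision_features, prompt_features)
print(anomaly_map.shape)                               # torch.Size([2, 2, 224, 224])
```

The two output channels correspond to the two prompt groups (normal / abnormal) defined in the config files.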
9 | 10 | ## 🔧 Installation 11 | 12 | To run experiments, first clone the repository and install the dependencies listed in `requirements.txt`. 13 | 14 | ``` 15 | $ git clone https://github.com/cnulab/MediCLIP.git 16 | $ cd MediCLIP 17 | $ pip install -r requirements.txt 18 | ``` 19 | ### Data preparation 20 | Download the following datasets: 21 | - **BUSI [[Baidu Cloud (pwd8866)]](https://pan.baidu.com/s/1EVt96fExiqrvMQslPDRRRg?pwd=8866) [[Google Drive]](https://drive.google.com/file/d/1PyvMXdNEVY86BY1PV8yKhPVS30TAmS6X/view?usp=drive_link) [[Official Link]](https://scholar.cu.edu.eg/?q=afahmy/pages/dataset)** 22 | - **BrainMRI [[Baidu Cloud (pwd8866)]](https://pan.baidu.com/s/1--5vPMN-eTqePPYjpKTwvA?pwd=8866) [[Google Drive]](https://drive.google.com/file/d/1kldE-5_wXaN-JR_8Y_mRCKQ6VZiyv3km/view?usp=drive_link) [[Official Link]](https://www.kaggle.com/datasets/navoneel/brain-mri-images-for-brain-tumor-detection)** 23 | - **CheXpert [[Baidu Cloud (pwd8866)]](https://pan.baidu.com/s/15-V5wobA_7ICvZAXBraDGA?pwd=8866) [[Google Drive]](https://drive.google.com/file/d/1pVYRipGC2VqjYP-wHdDFR-lLf7itLiUi/view?usp=drive_link) [[Official Link]](https://stanfordmlgroup.github.io/competitions/chexpert/)** 24 | 25 | Unzip them into the `data` directory. Please refer to [data/README](data/README.md). 26 | 27 | ## 🚀 Experiments 28 | 29 | To train MediCLIP on the BrainMRI dataset with a support set size of 16: 30 | ``` 31 | $ python train.py --config_path config/brainmri.yaml --k_shot 16 32 | ``` 33 | 34 | To test MediCLIP on the BrainMRI dataset: 35 | ``` 36 | $ python test.py --config_path config/brainmri.yaml --checkpoint_path xxx.pkl 37 | ``` 38 | Replace ``xxx.pkl`` with your checkpoint path. 39 |
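Both scripts are driven by the YAML files under `config/`. Below is a minimal sketch of reading one of them, assuming the standard PyYAML + easydict combination listed in `requirements.txt`; the exact loading code inside `train.py` may differ.

```
import yaml
from easydict import EasyDict

# Parse a training config into an attribute-accessible dict.
with open("config/brainmri.yaml") as f:
    config = EasyDict(yaml.safe_load(f))

print(config.model_name)          # ViT-L-14
print(config.layers_out)          # [12, 18, 24]
print(config.anomaly_tasks)       # anomaly synthesis task -> sampling probability
print(config.prompts.normal[:3])  # ['normal', 'healthy', 'negative']
```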
40 | 41 | --- 42 | Code reference: **[[CLIP]](https://github.com/OpenAI/CLIP)** **[[CoOp]](https://github.com/KaiyangZhou/CoOp)** **[[Many-Tasks-Make-Light-Work]](https://github.com/matt-baugh/many-tasks-make-light-work)**. 43 | 44 | 45 | ## 🔗 Citation 46 | 47 | If this work is helpful to you, please cite it as: 48 | ``` 49 | @inproceedings{zhang2024mediclip, 50 | title={MediCLIP: Adapting CLIP for Few-shot Medical Image Anomaly Detection}, 51 | author={Ximiao Zhang, Min Xu, Dehui Qiu, Ruixin Yan, Ning Lang, and Xiuzhuang Zhou}, 52 | year={2024}, 53 | eprint={2405.11315}, 54 | archivePrefix={arXiv}, 55 | primaryClass={cs.CV} 56 | } 57 | ``` 58 | -------------------------------------------------------------------------------- /open_clip/openai.py: -------------------------------------------------------------------------------- 1 | """ OpenAI pretrained model functions 2 | 3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | 6 | import os 7 | import warnings 8 | from typing import List, Optional, Union 9 | 10 | import torch 11 | 12 | from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype 13 | from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url 14 | 15 | __all__ = ["list_openai_models", "load_openai_model"] 16 | 17 | 18 | def list_openai_models() -> List[str]: 19 | """Returns the names of available CLIP models""" 20 | return list_pretrained_models_by_tag('openai') 21 | 22 | 23 | def load_openai_model( 24 | name: str, 25 | precision: Optional[str] = None, 26 | device: Optional[Union[str, torch.device]] = None, 27 | 28 | ): 29 | """Load a CLIP model 30 | 31 | Parameters 32 | ---------- 33 | name : str 34 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 35 | precision: str 36 | Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. 37 | device : Union[str, torch.device] 38 | The device to put the loaded model 39 | jit : bool 40 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 
41 | cache_dir : Optional[str] 42 | The directory to cache the downloaded model weights 43 | 44 | Returns 45 | ------- 46 | model : torch.nn.Module 47 | The CLIP model 48 | preprocess : Callable[[PIL.Image], torch.Tensor] 49 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 50 | """ 51 | 52 | if device is None: 53 | device = "cuda" if torch.cuda.is_available() else "cpu" 54 | if precision is None: 55 | precision = 'fp32' if device == 'cpu' else 'fp16' 56 | 57 | if get_pretrained_url(name, 'openai'): 58 | model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai')) 59 | elif os.path.isfile(name): 60 | model_path = name 61 | else: 62 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") 63 | 64 | try: 65 | # loading JIT archive 66 | model = torch.jit.load(model_path, map_location="cpu").eval() 67 | state_dict = None 68 | 69 | except RuntimeError: 70 | # loading saved state dict 71 | state_dict = torch.load(model_path, map_location="cpu") 72 | 73 | cast_dtype = get_cast_dtype(precision) 74 | 75 | try: 76 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) 77 | except KeyError: 78 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 79 | model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) 80 | 81 | # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use 82 | model = model.to(device) 83 | 84 | if precision.startswith('amp') or precision == 'fp32': 85 | model.float() 86 | elif precision == 'bf16': 87 | convert_weights_to_lp(model, dtype=torch.bfloat16) 88 | 89 | return model 90 | -------------------------------------------------------------------------------- /medsyn/labelling.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Union 3 | import numpy as np 4 | import numpy.typing as npt 5 | from scipy import ndimage 6 | from skimage import morphology 7 | 8 | 9 | class AnomalyLabeller(ABC): 10 | @abstractmethod 11 | def label(self, aug_img: npt.NDArray[float], orig_img: npt.NDArray[float], mask: npt.NDArray[bool]) \ 12 | -> npt.NDArray[float]: 13 | """ 14 | :param aug_img: Image with anomaly augmentation applied. 15 | :param orig_img: Original image, prior to anomalies. 16 | :param mask: Mask of where the image has been altered. 17 | """ 18 | pass 19 | 20 | def __call__(self, aug_img: npt.NDArray[float], orig_img: npt.NDArray[float], mask: npt.NDArray[bool]) \ 21 | -> npt.NDArray[float]: 22 | 23 | return self.label(aug_img, orig_img, mask) 24 | 25 | 26 | class IntensityDiffLabeller(AnomalyLabeller, ABC): 27 | 28 | def __init__(self): 29 | super().__init__() 30 | self.binary_structures = {} 31 | 32 | @abstractmethod 33 | def label_fn(self, x: Union[npt.NDArray[float], float]) -> Union[npt.NDArray[float], float]: 34 | pass 35 | 36 | def label(self, 37 | aug_img: npt.NDArray[float], 38 | orig_img: npt.NDArray[float], 39 | mask: npt.NDArray[bool]) \ 40 | -> npt.NDArray[float]: 41 | """ 42 | :param aug_img: Image with patches blended within it. 43 | :param orig_img: Original image, prior to anomalies. 44 | :param mask: Mask of where the image has been altered. 
45 | """ 46 | 47 | avg_diff = np.mean(mask * np.abs(aug_img - orig_img), axis=0) 48 | 49 | scaled_diff = self.label_fn(avg_diff) 50 | 51 | assert np.all(scaled_diff >= 0) 52 | 53 | num_dims = len(mask.shape) 54 | if num_dims not in self.binary_structures: 55 | self.binary_structures[num_dims] = ndimage.generate_binary_structure(num_dims, 2) 56 | 57 | bin_structure = self.binary_structures[num_dims] 58 | 59 | for anom_slice in ndimage.find_objects(ndimage.label(mask)[0]): 60 | anom_region_label = ndimage.grey_closing(scaled_diff[anom_slice], footprint=bin_structure) 61 | 62 | recon_seed = np.copy(anom_region_label) 63 | recon_seed[num_dims * (slice(1, -1),)] = anom_region_label.max() 64 | scaled_diff[anom_slice] = morphology.reconstruction(recon_seed, anom_region_label, 65 | method='erosion', 66 | ) 67 | return scaled_diff 68 | 69 | 70 | 71 | class SaturatingLabeller(IntensityDiffLabeller): 72 | 73 | def __init__(self, a: float, c: float): 74 | """ 75 | Labeller using transformed sigmoid function: (1 + c) / (1 + e^(-ax+b)) - c 76 | Function range is [-c, 1] 77 | """ 78 | super().__init__() 79 | self.a = a 80 | self.c = c 81 | 82 | def label_fn(self, x: Union[npt.NDArray[float], float]) -> Union[npt.NDArray[float], float]: 83 | return (1 + self.c) / (1 + np.exp(-self.a * x) / self.c) - self.c 84 | 85 | def __call__(self, aug_img: npt.NDArray[float], orig_img: npt.NDArray[float], mask: npt.NDArray[bool]) \ 86 | -> npt.NDArray[float]: 87 | return self.label(aug_img, orig_img, mask) 88 | 89 | 90 | class FlippedGaussianLabeller(IntensityDiffLabeller): 91 | def __init__(self, std: float): 92 | super(FlippedGaussianLabeller, self).__init__() 93 | self.std = std 94 | 95 | def label_fn(self, x: Union[npt.NDArray[float], float]) -> Union[npt.NDArray[float], float]: 96 | return 1 - np.exp(-x**2 / (2 * self.std**2)) 97 | -------------------------------------------------------------------------------- /utils/misc_helper.py: -------------------------------------------------------------------------------- 1 | 2 | import logging 3 | from datetime import datetime 4 | import random 5 | import os 6 | import torch 7 | import numpy as np 8 | from collections.abc import Mapping 9 | import shutil 10 | 11 | from sklearn import metrics 12 | 13 | 14 | def map_func(storage, location): 15 | return storage.cuda() 16 | 17 | 18 | def create_logger(name, log_file, level=logging.INFO): 19 | log = logging.getLogger(name) 20 | formatter = logging.Formatter( 21 | "[%(asctime)s][%(filename)15s][line:%(lineno)4d][%(levelname)8s] %(message)s" 22 | ) 23 | fh = logging.FileHandler(log_file) 24 | fh.setFormatter(formatter) 25 | sh = logging.StreamHandler() 26 | sh.setFormatter(formatter) 27 | log.setLevel(level) 28 | log.addHandler(fh) 29 | log.addHandler(sh) 30 | return log 31 | 32 | 33 | 34 | def set_seed(seed): 35 | os.environ['PYTHONHASHSEED'] = str(seed) 36 | np.random.seed(seed) 37 | torch.manual_seed(seed) 38 | torch.cuda.manual_seed_all(seed) 39 | torch.backends.cudnn.deterministic = True 40 | random.seed(seed) 41 | 42 | 43 | 44 | def get_current_time(): 45 | current_time = datetime.now().strftime("%Y_%m_%d_%H_%M_%S") 46 | return current_time 47 | 48 | 49 | 50 | class AverageMeter(object): 51 | """Computes and stores the average and current value""" 52 | 53 | def __init__(self, length=0): 54 | self.length = length 55 | self.reset() 56 | 57 | def reset(self): 58 | if self.length > 0: 59 | self.history = [] 60 | else: 61 | self.count = 0 62 | self.sum = 0.0 63 | self.val = 0.0 64 | self.avg = 0.0 65 | 66 | def update(self, 
val, num=1): 67 | if self.length > 0: 68 | # currently assert num==1 to avoid bad usage, refine when there are some explict requirements 69 | assert num == 1 70 | self.history.append(val) 71 | if len(self.history) > self.length: 72 | del self.history[0] 73 | 74 | self.val = self.history[-1] 75 | self.avg = np.mean(self.history) 76 | else: 77 | self.val = val 78 | self.sum += val * num 79 | self.count += num 80 | self.avg = self.sum / self.count 81 | 82 | 83 | def compute_imagewise_metrics( 84 | anomaly_prediction, 85 | anomaly_ground_truth_labels 86 | ): 87 | """ 88 | Computes retrieval statistics (AUROC, FPR, TPR). 89 | 90 | Args: 91 | anomaly_prediction: [np.array or list] [N] Assignment weights 92 | per image. Higher indicates higher 93 | probability of being an anomaly. 94 | anomaly_ground_truth_labels: [np.array or list] [N] Binary labels - 1 95 | if image is an anomaly, 0 if not. 96 | """ 97 | auroc = metrics.roc_auc_score( 98 | anomaly_ground_truth_labels, anomaly_prediction 99 | ) 100 | 101 | return {"image-auroc": auroc} 102 | 103 | 104 | def compute_pixelwise_metrics( 105 | pixel_prediction, 106 | pixel_ground_truth_labels 107 | ): 108 | """ 109 | Computes retrieval statistics (AUROC, FPR, TPR). 110 | 111 | Args: 112 | anomaly_prediction: [np.array or list] [N] Assignment weights 113 | per image. Higher indicates higher 114 | probability of being an anomaly. 115 | anomaly_ground_truth_labels: [np.array or list] [N] Binary labels - 1 116 | if image is an anomaly, 0 if not. 117 | """ 118 | pixel_prediction = np.concatenate( 119 | [pred.flatten() for pred in pixel_prediction], axis=0 120 | ) 121 | 122 | pixel_ground_truth_labels = np.concatenate( 123 | [label.flatten() for label in pixel_ground_truth_labels], axis=0 124 | ) 125 | 126 | pixel_ground_truth_labels[pixel_ground_truth_labels > 0] = 1 127 | 128 | pixel_auroc = metrics.roc_auc_score( 129 | pixel_ground_truth_labels, pixel_prediction 130 | ) 131 | 132 | return {"pixel-auroc": pixel_auroc} -------------------------------------------------------------------------------- /utils/losses.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from math import exp 7 | 8 | 9 | class FocalLoss(nn.Module): 10 | """ 11 | copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py 12 | This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in 13 | 'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)' 14 | Focal_Loss= -1*alpha*(1-pt)*log(pt) 15 | :param alpha: (tensor) 3D or 4D the scalar factor for this criterion 16 | :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more 17 | focus on hard misclassified example 18 | :param smooth: (float,double) smooth value when cross entropy 19 | :param balance_index: (int) balance class index, should be specific when alpha is float 20 | :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch. 
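    Worked example: with the default gamma=2, a well-classified pixel with pt=0.9
    contributes alpha * (1-0.9)^2 * (-log 0.9) ≈ 0.0011 * alpha, versus
    alpha * (-log 0.9) ≈ 0.105 * alpha under plain weighted cross entropy,
    i.e. easy examples are down-weighted by roughly a factor of 100.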
21 | """ 22 | 23 | def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True): 24 | super(FocalLoss, self).__init__() 25 | self.apply_nonlin = apply_nonlin 26 | self.alpha = alpha 27 | self.gamma = gamma 28 | self.balance_index = balance_index 29 | self.smooth = smooth 30 | self.size_average = size_average 31 | 32 | if self.smooth is not None: 33 | if self.smooth < 0 or self.smooth > 1.0: 34 | raise ValueError('smooth value should be in [0,1]') 35 | 36 | def forward(self, logit, target): 37 | if self.apply_nonlin is not None: 38 | logit = self.apply_nonlin(logit) 39 | num_class = logit.shape[1] 40 | 41 | if logit.dim() > 2: 42 | # N,C,d1,d2 -> N,C,m (m=d1*d2*...) 43 | logit = logit.view(logit.size(0), logit.size(1), -1) 44 | logit = logit.permute(0, 2, 1).contiguous() 45 | logit = logit.view(-1, logit.size(-1)) 46 | target = torch.squeeze(target, 1) 47 | target = target.view(-1, 1) 48 | alpha = self.alpha 49 | 50 | if alpha is None: 51 | alpha = torch.ones(num_class, 1) 52 | elif isinstance(alpha, (list, np.ndarray)): 53 | assert len(alpha) == num_class 54 | alpha = torch.FloatTensor(alpha).view(num_class, 1) 55 | alpha = alpha / alpha.sum() 56 | elif isinstance(alpha, float): 57 | alpha = torch.ones(num_class, 1) 58 | alpha = alpha * (1 - self.alpha) 59 | alpha[self.balance_index] = self.alpha 60 | 61 | else: 62 | raise TypeError('Not support alpha type') 63 | 64 | if alpha.device != logit.device: 65 | alpha = alpha.to(logit.device) 66 | 67 | idx = target.cpu().long() 68 | 69 | one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_() 70 | one_hot_key = one_hot_key.scatter_(1, idx, 1) 71 | if one_hot_key.device != logit.device: 72 | one_hot_key = one_hot_key.to(logit.device) 73 | 74 | if self.smooth: 75 | one_hot_key = torch.clamp( 76 | one_hot_key, self.smooth / (num_class - 1), 1.0 - self.smooth) 77 | pt = (one_hot_key * logit).sum(1) + self.smooth 78 | logpt = pt.log() 79 | 80 | gamma = self.gamma 81 | 82 | alpha = alpha[idx] 83 | alpha = torch.squeeze(alpha) 84 | loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt 85 | 86 | if self.size_average: 87 | loss = loss.mean() 88 | return loss 89 | 90 | 91 | class BinaryDiceLoss(nn.Module): 92 | def __init__(self): 93 | super(BinaryDiceLoss, self).__init__() 94 | 95 | def forward(self, input, targets): 96 | N = targets.size()[0] 97 | smooth = 1 98 | input_flat = input.view(N, -1) 99 | targets_flat = targets.view(N, -1) 100 | 101 | intersection = input_flat * targets_flat 102 | N_dice_eff = (2 * intersection.sum(1) + smooth) / (input_flat.sum(1) + targets_flat.sum(1) + smooth) 103 | loss = 1 - N_dice_eff.sum() / N 104 | return loss 105 | 106 | 107 | class CrossEntropyLoss(nn.Module): 108 | 109 | def __init__(self): 110 | super().__init__() 111 | self.criterion = nn.CrossEntropyLoss() 112 | 113 | def forward(self, logit,target): 114 | 115 | bsz,_,h,w=logit.size() 116 | logit = logit.view(bsz, 2, -1) 117 | gt_mask = target.view(bsz, -1).long() 118 | return self.criterion(logit,gt_mask) 119 | -------------------------------------------------------------------------------- /open_clip/factory.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import pathlib 5 | import re 6 | import numpy as np 7 | from copy import deepcopy 8 | from pathlib import Path 9 | from typing import Any, Dict, Optional, Tuple, Union 10 | import torch 11 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 12 | from 
.model import CLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\ 13 | resize_pos_embed, get_cast_dtype 14 | from .openai import load_openai_model 15 | from .transform import image_transform 16 | from .tokenizer import tokenize 17 | 18 | 19 | _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] 20 | 21 | _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs 22 | 23 | 24 | def _natural_key(string_): 25 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] 26 | 27 | def _rescan_model_configs(): 28 | global _MODEL_CONFIGS 29 | 30 | config_ext = ('.json',) 31 | config_files = [] 32 | for config_path in _MODEL_CONFIG_PATHS: 33 | if config_path.is_file() and config_path.suffix in config_ext: 34 | config_files.append(config_path) 35 | elif config_path.is_dir(): 36 | for ext in config_ext: 37 | config_files.extend(config_path.glob(f'*{ext}')) 38 | 39 | for cf in config_files: 40 | with open(cf, 'r') as f: 41 | model_cfg = json.load(f) 42 | if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')): 43 | _MODEL_CONFIGS[cf.stem] = model_cfg 44 | 45 | _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))} 46 | 47 | 48 | _rescan_model_configs() # initial populate of model config registry 49 | 50 | 51 | def list_models(): 52 | """ enumerate available model architectures based on config files """ 53 | return list(_MODEL_CONFIGS.keys()) 54 | 55 | 56 | def add_model_config(path): 57 | """ add model config path or file and update registry """ 58 | if not isinstance(path, Path): 59 | path = Path(path) 60 | _MODEL_CONFIG_PATHS.append(path) 61 | _rescan_model_configs() 62 | 63 | def get_model_config(model_name): 64 | if model_name in _MODEL_CONFIGS: 65 | return deepcopy(_MODEL_CONFIGS[model_name]) 66 | else: 67 | return None 68 | 69 | def get_tokenizer(): 70 | return tokenize 71 | 72 | 73 | def load_state_dict(checkpoint_path: str, map_location='cpu'): 74 | checkpoint = torch.load(checkpoint_path, map_location=map_location) 75 | if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: 76 | state_dict = checkpoint['state_dict'] 77 | else: 78 | state_dict = checkpoint 79 | if next(iter(state_dict.items()))[0].startswith('module'): 80 | state_dict = {k[7:]: v for k, v in state_dict.items()} 81 | return state_dict 82 | 83 | 84 | def load_checkpoint(model, checkpoint_path, strict=True): 85 | state_dict = load_state_dict(checkpoint_path) 86 | # detect old format and make compatible with new format 87 | if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'): 88 | state_dict = convert_to_custom_text_state_dict(state_dict) 89 | resize_pos_embed(state_dict, model) 90 | incompatible_keys = model.load_state_dict(state_dict, strict=strict) 91 | return incompatible_keys 92 | 93 | 94 | 95 | def create_model( 96 | model_name: str, 97 | img_size: int, 98 | precision: str = 'fp32', 99 | device: Union[str, torch.device] = 'cpu', 100 | ): 101 | 102 | model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names 103 | pretrained_cfg = {} 104 | model_cfg = None 105 | 106 | if isinstance(device, str): 107 | device = torch.device(device) 108 | 109 | model_cfg = model_cfg or get_model_config(model_name) 110 | 111 | if model_cfg['vision_cfg']['image_size'] != img_size: 112 | model_cfg['vision_cfg']['image_size'] = img_size 113 | 114 | cast_dtype = get_cast_dtype(precision) 115 | 116 | model_pre = load_openai_model( 117 | 
model_name, 118 | precision=precision, 119 | device=device, 120 | ) 121 | 122 | state_dict = model_pre.state_dict() 123 | 124 | model = CLIP(**model_cfg, cast_dtype=cast_dtype) 125 | 126 | ### for resnet 127 | if not hasattr(model.visual, 'grid_size'): 128 | model.visual.grid_size = int(np.sqrt(model.visual.attnpool.positional_embedding.shape[0] - 1)) 129 | 130 | resize_pos_embed(state_dict, model) 131 | incompatible_keys = model.load_state_dict(state_dict, strict=True) 132 | 133 | model.to(device=device) 134 | 135 | if precision in ("fp16", "bf16"): 136 | convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16) 137 | 138 | # set image / mean metadata from pretrained_cfg if available, or use default 139 | model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN 140 | model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD 141 | 142 | else: 143 | model = load_openai_model( 144 | model_name, 145 | precision=precision, 146 | device=device, 147 | ) 148 | 149 | model.device=device 150 | 151 | return model,model_cfg 152 | 153 | 154 | 155 | def create_model_and_transforms( 156 | model_name: str, 157 | img_size: int, 158 | precision: str = 'fp32', 159 | device: Union[str, torch.device] = 'cpu', 160 | ): 161 | 162 | model,model_cfg = create_model( 163 | model_name, 164 | img_size, 165 | precision=precision, 166 | device=device, 167 | ) 168 | 169 | image_mean = OPENAI_DATASET_MEAN 170 | image_std = OPENAI_DATASET_STD 171 | 172 | preprocess = image_transform( 173 | mean=image_mean, 174 | std=image_std, 175 | ) 176 | 177 | return model, preprocess, model_cfg -------------------------------------------------------------------------------- /open_clip/tokenizer.py: -------------------------------------------------------------------------------- 1 | """ CLIP tokenizer 2 | 3 | Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | import gzip 6 | import html 7 | import os 8 | from functools import lru_cache 9 | from typing import Union, List 10 | 11 | import ftfy 12 | import regex as re 13 | import torch 14 | 15 | # https://stackoverflow.com/q/62691279 16 | import os 17 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 18 | 19 | 20 | @lru_cache() 21 | def default_bpe(): 22 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 23 | 24 | 25 | @lru_cache() 26 | def bytes_to_unicode(): 27 | """ 28 | Returns list of utf-8 byte and a corresponding list of unicode strings. 29 | The reversible bpe codes work on unicode strings. 30 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 31 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 32 | This is a significant percentage of your normal, say, 32K bpe vocab. 33 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 34 | And avoids mapping to whitespace/control characters the bpe code barfs on. 35 | """ 36 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 37 | cs = bs[:] 38 | n = 0 39 | for b in range(2**8): 40 | if b not in bs: 41 | bs.append(b) 42 | cs.append(2**8+n) 43 | n += 1 44 | cs = [chr(n) for n in cs] 45 | return dict(zip(bs, cs)) 46 | 47 | 48 | def get_pairs(word): 49 | """Return set of symbol pairs in a word. 50 | Word is represented as tuple of symbols (symbols being variable-length strings). 
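    For example, get_pairs(('l', 'o', 'w')) returns {('l', 'o'), ('o', 'w')}.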
51 | """ 52 | pairs = set() 53 | prev_char = word[0] 54 | for char in word[1:]: 55 | pairs.add((prev_char, char)) 56 | prev_char = char 57 | return pairs 58 | 59 | 60 | def basic_clean(text): 61 | text = ftfy.fix_text(text) 62 | text = html.unescape(html.unescape(text)) 63 | return text.strip() 64 | 65 | def whitespace_clean(text): 66 | text = re.sub(r'\s+', ' ', text) 67 | text = text.strip() 68 | return text 69 | 70 | 71 | class SimpleTokenizer(object): 72 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): 73 | self.byte_encoder = bytes_to_unicode() 74 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 75 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 76 | merges = merges[1:49152-256-2+1] 77 | merges = [tuple(merge.split()) for merge in merges] 78 | vocab = list(bytes_to_unicode().values()) 79 | 80 | vocab = vocab + [v+'' for v in vocab] 81 | for merge in merges: 82 | vocab.append(''.join(merge)) 83 | 84 | if not special_tokens: 85 | special_tokens = ['', ''] 86 | else: 87 | special_tokens = ['', ''] + special_tokens 88 | vocab.extend(special_tokens) 89 | self.encoder = dict(zip(vocab, range(len(vocab)))) 90 | self.decoder = {v: k for k, v in self.encoder.items()} 91 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 92 | self.cache = {t:t for t in special_tokens} 93 | special = "|".join(special_tokens) 94 | self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 95 | 96 | self.vocab_size = len(self.encoder) 97 | self.all_special_ids = [self.encoder[t] for t in special_tokens] 98 | 99 | def bpe(self, token): 100 | if token in self.cache: 101 | return self.cache[token] 102 | word = tuple(token[:-1]) + ( token[-1] + '',) 103 | pairs = get_pairs(word) 104 | 105 | if not pairs: 106 | return token+'' 107 | 108 | while True: 109 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 110 | if bigram not in self.bpe_ranks: 111 | break 112 | first, second = bigram 113 | new_word = [] 114 | i = 0 115 | while i < len(word): 116 | try: 117 | j = word.index(first, i) 118 | new_word.extend(word[i:j]) 119 | i = j 120 | except: 121 | new_word.extend(word[i:]) 122 | break 123 | 124 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 125 | new_word.append(first+second) 126 | i += 2 127 | else: 128 | new_word.append(word[i]) 129 | i += 1 130 | new_word = tuple(new_word) 131 | word = new_word 132 | if len(word) == 1: 133 | break 134 | else: 135 | pairs = get_pairs(word) 136 | word = ' '.join(word) 137 | self.cache[token] = word 138 | return word 139 | 140 | def encode(self, text): 141 | bpe_tokens = [] 142 | text = whitespace_clean(basic_clean(text)).lower() 143 | for token in re.findall(self.pat, text): 144 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 145 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 146 | return bpe_tokens 147 | 148 | def decode(self, tokens): 149 | text = ''.join([self.decoder[token] for token in tokens]) 150 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 151 | return text 152 | 153 | 154 | _tokenizer = SimpleTokenizer() 155 | 156 | def decode(output_ids: torch.Tensor): 157 | output_ids = output_ids.cpu().numpy() 158 | return _tokenizer.decode(output_ids) 159 | 160 | 161 | 162 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: 163 | """ 164 | Returns the 
tokenized representation of given input string(s) 165 | 166 | Parameters 167 | ---------- 168 | texts : Union[str, List[str]] 169 | An input string or a list of input strings to tokenize 170 | context_length : int 171 | The context length to use; all CLIP models use 77 as the context length 172 | 173 | Returns 174 | ------- 175 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 176 | """ 177 | if isinstance(texts, str): 178 | texts = [texts] 179 | 180 | sot_token = _tokenizer.encoder[""] 181 | eot_token = _tokenizer.encoder[""] 182 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 183 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 184 | 185 | for i, tokens in enumerate(all_tokens): 186 | if len(tokens) > context_length: 187 | tokens = tokens[:context_length] # Truncate 188 | tokens[-1] = eot_token 189 | result[i, :len(tokens)] = torch.tensor(tokens) 190 | return result 191 | 192 | 193 | -------------------------------------------------------------------------------- /open_clip/modified_resnet.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from open_clip.utils import freeze_batch_norm_2d 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.act1 = nn.ReLU(inplace=True) 20 | 21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.act2 = nn.ReLU(inplace=True) 24 | 25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 26 | 27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 29 | self.act3 = nn.ReLU(inplace=True) 30 | 31 | self.downsample = None 32 | self.stride = stride 33 | 34 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 36 | self.downsample = nn.Sequential(OrderedDict([ 37 | ("-1", nn.AvgPool2d(stride)), 38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 39 | ("1", nn.BatchNorm2d(planes * self.expansion)) 40 | ])) 41 | 42 | def forward(self, x: torch.Tensor): 43 | identity = x 44 | 45 | out = self.act1(self.bn1(self.conv1(x))) 46 | out = self.act2(self.bn2(self.conv2(out))) 47 | out = self.avgpool(out) 48 | out = self.bn3(self.conv3(out)) 49 | 50 | if self.downsample is not None: 51 | identity = self.downsample(x) 52 | 53 | out += identity 54 | out = self.act3(out) 55 | return out 56 | 57 | 58 | class AttentionPool2d(nn.Module): 59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 60 | super().__init__() 61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 62 | self.k_proj = nn.Linear(embed_dim, embed_dim) 63 | self.q_proj = nn.Linear(embed_dim, embed_dim) 64 | self.v_proj = nn.Linear(embed_dim, embed_dim) 65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 66 | 
self.num_heads = num_heads 67 | 68 | def forward(self, x): 69 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 72 | x, _ = F.multi_head_attention_forward( 73 | query=x, key=x, value=x, 74 | embed_dim_to_check=x.shape[-1], 75 | num_heads=self.num_heads, 76 | q_proj_weight=self.q_proj.weight, 77 | k_proj_weight=self.k_proj.weight, 78 | v_proj_weight=self.v_proj.weight, 79 | in_proj_weight=None, 80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 81 | bias_k=None, 82 | bias_v=None, 83 | add_zero_attn=False, 84 | dropout_p=0., 85 | out_proj_weight=self.c_proj.weight, 86 | out_proj_bias=self.c_proj.bias, 87 | use_separate_proj_weight=True, 88 | training=self.training, 89 | need_weights=False 90 | ) 91 | 92 | return x[0] 93 | 94 | 95 | class ModifiedResNet(nn.Module): 96 | """ 97 | A ResNet class that is similar to torchvision's but contains the following changes: 98 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 99 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 100 | - The final pooling layer is a QKV attention instead of an average pool 101 | """ 102 | 103 | def __init__(self, layers, output_dim, heads, image_size=224, width=64): 104 | super().__init__() 105 | self.output_dim = output_dim 106 | self.image_size = image_size 107 | 108 | # the 3-layer stem 109 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 110 | self.bn1 = nn.BatchNorm2d(width // 2) 111 | self.act1 = nn.ReLU(inplace=True) 112 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 113 | self.bn2 = nn.BatchNorm2d(width // 2) 114 | self.act2 = nn.ReLU(inplace=True) 115 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 116 | self.bn3 = nn.BatchNorm2d(width) 117 | self.act3 = nn.ReLU(inplace=True) 118 | self.avgpool = nn.AvgPool2d(2) 119 | 120 | # residual layers 121 | self._inplanes = width # this is a *mutable* variable used during construction 122 | self.layer1 = self._make_layer(width, layers[0]) 123 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 124 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 125 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 126 | 127 | embed_dim = width * 32 # the ResNet feature dimension 128 | self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) 129 | 130 | self.init_parameters() 131 | 132 | def _make_layer(self, planes, blocks, stride=1): 133 | layers = [Bottleneck(self._inplanes, planes, stride)] 134 | 135 | self._inplanes = planes * Bottleneck.expansion 136 | for _ in range(1, blocks): 137 | layers.append(Bottleneck(self._inplanes, planes)) 138 | 139 | return nn.Sequential(*layers) 140 | 141 | def init_parameters(self): 142 | if self.attnpool is not None: 143 | std = self.attnpool.c_proj.in_features ** -0.5 144 | nn.init.normal_(self.attnpool.q_proj.weight, std=std) 145 | nn.init.normal_(self.attnpool.k_proj.weight, std=std) 146 | nn.init.normal_(self.attnpool.v_proj.weight, std=std) 147 | nn.init.normal_(self.attnpool.c_proj.weight, std=std) 148 | 149 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: 150 | for name, param in 
resnet_block.named_parameters(): 151 | if name.endswith("bn3.weight"): 152 | nn.init.zeros_(param) 153 | 154 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 155 | assert unlocked_groups == 0, 'partial locking not currently supported for this model' 156 | for param in self.parameters(): 157 | param.requires_grad = False 158 | if freeze_bn_stats: 159 | freeze_batch_norm_2d(self) 160 | 161 | @torch.jit.ignore 162 | def set_grad_checkpointing(self, enable=True): 163 | # FIXME support for non-transformer 164 | pass 165 | 166 | def stem(self, x): 167 | x = self.act1(self.bn1(self.conv1(x))) 168 | x = self.act2(self.bn2(self.conv2(x))) 169 | x = self.act3(self.bn3(self.conv3(x))) 170 | x = self.avgpool(x) 171 | return x 172 | 173 | def forward(self, x, out_blocks): 174 | x = self.stem(x) 175 | x_1 = self.layer1(x) 176 | x_2 = self.layer2(x_1) 177 | x_3 = self.layer3(x_2) 178 | x_4 = self.layer4(x_3) 179 | x = self.attnpool(x_4) 180 | 181 | out_tokens = [] 182 | x_blocks = [x_1, x_2, x_3, x_4] 183 | for i in out_blocks: 184 | out_tokens.append(x_blocks[i - 1]) 185 | 186 | return x, out_tokens 187 | -------------------------------------------------------------------------------- /models/CoOp.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from torch.nn import functional as F 7 | from torch.cuda.amp import GradScaler, autocast 8 | from open_clip.tokenizer import SimpleTokenizer,tokenize 9 | 10 | 11 | class TextEncoder(nn.Module): 12 | def __init__(self, clip_model): 13 | 14 | super().__init__() 15 | 16 | self.transformer = clip_model.transformer 17 | self.positional_embedding = clip_model.positional_embedding 18 | self.ln_final = clip_model.ln_final 19 | self.text_projection = clip_model.text_projection 20 | 21 | 22 | def forward(self, prompts, tokenized_prompts): 23 | 24 | x = prompts + self.positional_embedding 25 | x = x.permute(1, 0, 2) # NLD -> LND 26 | x,_,_ = self.transformer(x) 27 | x = x.permute(1, 0, 2) # LND -> NLD 28 | x = self.ln_final(x) 29 | x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection 30 | return x 31 | 32 | 33 | 34 | class PromptLearner(nn.Module): 35 | def __init__(self, 36 | prompts, 37 | n_ctx, # prompt max len 38 | CSC, # True or False multi prompt 39 | class_token_position, # cls position 40 | clip_model): 41 | 42 | super().__init__() 43 | 44 | ctx_dim = clip_model.ln_final.weight.shape[0] # 45 | 46 | self.ctx={} 47 | 48 | for cls in prompts: 49 | for position in class_token_position: 50 | if CSC: 51 | ctx_vectors = torch.empty(len(prompts[cls]), n_ctx, ctx_dim).to(clip_model.device) 52 | else: 53 | ctx_vectors = torch.empty(n_ctx, ctx_dim).to(clip_model.device) 54 | nn.init.normal_(ctx_vectors, std=0.02) 55 | self.ctx['{}_{}'.format(cls,position)]=nn.Parameter(ctx_vectors,requires_grad=True) 56 | 57 | self.ctx = nn.ParameterDict(self.ctx) # to be optimized 58 | 59 | prompt_prefix = " ".join(["X"] * n_ctx) 60 | 61 | _tokenizer = SimpleTokenizer() 62 | 63 | prompts_split={cls: [prompt.replace("_", " ") for prompt in prompts[cls]] for cls in prompts} 64 | 65 | prompts_lens= {cls: [ len(_tokenizer.encode(prompt)) for prompt in prompts_split[cls]] for cls in prompts_split} 66 | 67 | prompts_learnable_tokens = {cls:[prompt_prefix + " " + prompt + "." 
for prompt in prompts_split[cls]] for cls in prompts_split} 68 | 69 | tokenized_prompts = {cls:torch.cat([tokenize(prompt) for prompt in prompts_learnable_tokens[cls]]).to(clip_model.device) for cls in prompts_learnable_tokens} 70 | 71 | with torch.no_grad(): 72 | embeddings = {cls:clip_model.token_embedding(tokenized_prompts[cls]) for cls in tokenized_prompts} 73 | 74 | self.register_embeddings={} 75 | 76 | for cls in embeddings: 77 | self.register_embeddings['{}_token_prefix'.format(cls)]=embeddings[cls][:, :1, :] 78 | self.register_embeddings['{}_token_suffix'.format(cls)]=embeddings[cls][:, 1 + n_ctx :, :] 79 | 80 | self.n_ctx = n_ctx 81 | self.tokenized_prompts = tokenized_prompts 82 | self.prompts_lens = prompts_lens 83 | self.class_token_position = class_token_position 84 | 85 | 86 | def forward(self): 87 | cls_prompts={} 88 | 89 | for cls in self.tokenized_prompts: 90 | 91 | prefix = self.register_embeddings['{}_token_prefix'.format(cls)] 92 | suffix = self.register_embeddings['{}_token_suffix'.format(cls)] 93 | 94 | cls_prompts[cls]=[] 95 | 96 | for position in self.class_token_position: 97 | 98 | ctx = self.ctx['{}_{}'.format(cls,position)] 99 | if ctx.dim() == 2: 100 | ctx = ctx.unsqueeze(0).expand(len(self.prompts_lens[cls]), -1, -1) 101 | 102 | if position == "end": 103 | prompts = torch.cat( 104 | [ 105 | prefix, # (n_cls, 1, dim) 106 | ctx, # (n_cls, n_ctx, dim) 107 | suffix, # (n_cls, *, dim) 108 | ], 109 | dim=1, 110 | ) 111 | 112 | elif position == "middle": 113 | 114 | half_n_ctx = self.n_ctx // 2 115 | prompts = [] 116 | 117 | for i in range(len(self.prompts_lens[cls])): 118 | p_len = self.prompts_lens[cls][i] 119 | 120 | prefix_i = prefix[i : i + 1, :, :] 121 | class_i = suffix[i : i + 1, :p_len, :] 122 | suffix_i = suffix[i : i + 1, p_len:, :] 123 | ctx_i_half1 = ctx[i : i + 1, :half_n_ctx, :] 124 | ctx_i_half2 = ctx[i : i + 1, half_n_ctx:, :] 125 | 126 | prompt = torch.cat( 127 | [ 128 | prefix_i, # (1, 1, dim) 129 | ctx_i_half1, # (1, n_ctx//2, dim) 130 | class_i, # (1, name_len, dim) 131 | ctx_i_half2, # (1, n_ctx//2, dim) 132 | suffix_i, # (1, *, dim) 133 | ], 134 | dim=1, 135 | ) 136 | prompts.append(prompt) 137 | prompts = torch.cat(prompts, dim=0) 138 | 139 | else : 140 | assert position == "front" 141 | prompts = [] 142 | 143 | for i in range(len(self.prompts_lens[cls])): 144 | p_len = self.prompts_lens[cls][i] 145 | 146 | prefix_i = prefix[i : i + 1, :, :] 147 | class_i = suffix[i : i + 1, :p_len, :] 148 | suffix_i = suffix[i : i + 1, p_len:, :] 149 | ctx_i = ctx[i : i + 1, :, :] 150 | prompt = torch.cat( 151 | [ 152 | prefix_i, # (1, 1, dim) 153 | class_i, # (1, name_len, dim) 154 | ctx_i, # (1, n_ctx, dim) 155 | suffix_i, # (1, *, dim) 156 | ], 157 | dim=1, 158 | ) 159 | prompts.append(prompt) 160 | 161 | prompts = torch.cat(prompts, dim=0) 162 | 163 | cls_prompts[cls].append(prompts) 164 | cls_prompts[cls]=torch.cat(cls_prompts[cls],dim=0) 165 | return cls_prompts 166 | 167 | 168 | class PromptMaker(nn.Module): 169 | 170 | def __init__(self, 171 | prompts, 172 | clip_model, 173 | n_ctx: int=8, # prompt max len 174 | CSC: bool= True, # True or False multi prompt 175 | class_token_position: list=['end'], # cls position 176 | ): 177 | 178 | super().__init__() 179 | assert 'normal' in prompts and 'abnormal' in prompts 180 | 181 | for position in class_token_position: 182 | assert position in ['end','middle','front'] 183 | 184 | self.prompt_learner = PromptLearner(prompts, n_ctx, CSC, class_token_position, clip_model) 185 | self.tokenized_prompts = 
self.prompt_learner.tokenized_prompts 186 | 187 | self.class_token_position = class_token_position 188 | self.text_encoder = TextEncoder(clip_model) 189 | 190 | def forward(self, image_features): 191 | prompts = self.prompt_learner() 192 | tokenized_prompts = self.tokenized_prompts 193 | text_features=[] 194 | 195 | for cls in prompts: 196 | class_embedding = self.text_encoder(prompts[cls], tokenized_prompts[cls].repeat(len(self.class_token_position),1)) 197 | class_embedding = class_embedding.mean(dim=0) 198 | class_embedding = class_embedding / class_embedding.norm() 199 | text_features.append(class_embedding) 200 | text_features = torch.stack(text_features, dim=1) 201 | return text_features -------------------------------------------------------------------------------- /datasets/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from enum import Enum 4 | import PIL 5 | import torch 6 | from torchvision import transforms 7 | import json 8 | from PIL import Image 9 | import numpy as np 10 | 11 | from medsyn.tasks import CutPastePatchBlender,\ 12 | SmoothIntensityChangeTask,\ 13 | GaussIntensityChangeTask,\ 14 | SinkDeformationTask,\ 15 | SourceDeformationTask,\ 16 | IdentityTask 17 | 18 | 19 | class TrainDataset(torch.utils.data.Dataset): 20 | 21 | def __init__( 22 | self, 23 | args, 24 | source, 25 | preprocess, 26 | k_shot = -1, 27 | **kwargs, 28 | ): 29 | 30 | super().__init__() 31 | self.args = args 32 | self.source = source 33 | self.k_shot = k_shot 34 | self.transform_img = preprocess 35 | self.data_to_iterate = self.get_image_data() 36 | self.augs,self.augs_pro = self.load_anomaly_syn() 37 | assert sum(self.augs_pro)==1.0 38 | 39 | def __getitem__(self, idx): 40 | info = self.data_to_iterate[idx] 41 | image_path=os.path.join(self.source,'images',info['filename']) 42 | image = self.read_image(image_path) 43 | choice_aug = np.random.choice(a=[aug for aug in self.augs], 44 | p = [pro for pro in self.augs_pro], 45 | size=(1,), replace=False)[0] 46 | image, mask = choice_aug(image) 47 | image = Image.fromarray(image.astype(np.uint8)).convert('RGB') 48 | image = self.transform_img(image) 49 | mask = torch.from_numpy(mask) 50 | return { 51 | "image": image, 52 | "mask" : mask, 53 | } 54 | 55 | def __len__(self): 56 | return len(self.data_to_iterate) 57 | 58 | def read_image(self,path): 59 | image = PIL.Image.open(path).resize((self.args.image_size,self.args.image_size), 60 | PIL.Image.Resampling.BILINEAR).convert("L") 61 | image = np.array(image).astype(np.uint8) 62 | return image 63 | 64 | def get_image_data(self): 65 | data_to_iterate = [] 66 | with open(os.path.join(self.source,'samples',"train.json"), "r") as f_r: 67 | for line in f_r: 68 | meta = json.loads(line) 69 | data_to_iterate.append(meta) 70 | if self.k_shot != -1: 71 | data_to_iterate = random.sample( 72 | data_to_iterate, self.k_shot 73 | ) 74 | return data_to_iterate 75 | 76 | 77 | def load_anomaly_syn(self): 78 | tasks = [] 79 | task_probability = [] 80 | for task_name in self.args.anomaly_tasks.keys(): 81 | if task_name =='CutpasteTask': 82 | support_images = [self.read_image(os.path.join(self.source,'images',data['filename'])) for data in self.data_to_iterate] 83 | task = CutPastePatchBlender(support_images) 84 | elif task_name == 'SmoothIntensityTask': 85 | task = SmoothIntensityChangeTask(30.0) 86 | elif task_name == 'GaussIntensityChangeTask': 87 | task = GaussIntensityChangeTask() 88 | elif task_name == 'SinkTask': 89 | task = 
SinkDeformationTask() 90 | elif task_name == 'SourceTask': 91 | task = SourceDeformationTask() 92 | elif task_name == 'IdentityTask': 93 | task = IdentityTask() 94 | else: 95 | raise NotImplementedError("task must in [CutpasteTask, " 96 | "SmoothIntensityTask, " 97 | "GaussIntensityChangeTask," 98 | "SinkTask, SourceTask, IdentityTask]") 99 | 100 | tasks.append(task) 101 | task_probability.append(self.args.anomaly_tasks[task_name]) 102 | return tasks, task_probability 103 | 104 | 105 | 106 | class ChexpertTestDataset(torch.utils.data.Dataset): 107 | 108 | def __init__( 109 | self, 110 | args, 111 | source, 112 | preprocess, 113 | **kwargs, 114 | ): 115 | super().__init__() 116 | self.args = args 117 | self.source = source 118 | 119 | self.transform_img = preprocess 120 | self.data_to_iterate = self.get_image_data() 121 | 122 | 123 | def __getitem__(self, idx): 124 | info = self.data_to_iterate[idx] 125 | image_path = os.path.join(self.source,'images',info['filename']) 126 | image = PIL.Image.open(image_path).convert("RGB").resize((self.args.image_size,self.args.image_size),PIL.Image.Resampling.BILINEAR) 127 | mask = np.zeros((self.args.image_size,self.args.image_size)).astype(np.float) 128 | image = self.transform_img(image) 129 | mask = torch.from_numpy(mask) 130 | 131 | return { 132 | "image": image, 133 | "mask" : mask, 134 | "classname": info['clsname'], 135 | "is_anomaly": info['label'], 136 | "image_path": image_path, 137 | } 138 | 139 | def __len__(self): 140 | return len(self.data_to_iterate) 141 | 142 | def get_image_data(self): 143 | data_to_iterate = [] 144 | with open(os.path.join(self.source,'samples',"test.json"), "r") as f_r: 145 | for line in f_r: 146 | meta = json.loads(line) 147 | data_to_iterate.append(meta) 148 | return data_to_iterate 149 | 150 | 151 | class BrainMRITestDataset(torch.utils.data.Dataset): 152 | 153 | def __init__( 154 | self, 155 | args, 156 | source, 157 | preprocess, 158 | **kwargs, 159 | ): 160 | 161 | super().__init__() 162 | self.args = args 163 | self.source = source 164 | self.transform_img = preprocess 165 | self.data_to_iterate = self.get_image_data() 166 | 167 | 168 | def __getitem__(self, idx): 169 | info = self.data_to_iterate[idx] 170 | image_path = os.path.join(self.source,'images',info['filename']) 171 | image = PIL.Image.open(image_path).convert("RGB").resize((self.args.image_size,self.args.image_size),PIL.Image.Resampling.BILINEAR) 172 | mask = np.zeros((self.args.image_size,self.args.image_size)).astype(np.float) 173 | image = self.transform_img(image) 174 | mask = torch.from_numpy(mask) 175 | 176 | return { 177 | "image": image, 178 | "mask" : mask, 179 | "classname": info['clsname'], 180 | "is_anomaly": info['label'], 181 | "image_path": image_path, 182 | } 183 | 184 | def __len__(self): 185 | return len(self.data_to_iterate) 186 | 187 | 188 | def get_image_data(self): 189 | data_to_iterate = [] 190 | with open(os.path.join(self.source,'samples',"test.json"), "r") as f_r: 191 | for line in f_r: 192 | meta = json.loads(line) 193 | data_to_iterate.append(meta) 194 | return data_to_iterate 195 | 196 | 197 | 198 | class BusiTestDataset(torch.utils.data.Dataset): 199 | 200 | def __init__( 201 | self, 202 | args, 203 | source, 204 | preprocess, 205 | **kwargs, 206 | ): 207 | 208 | super().__init__() 209 | self.args = args 210 | self.source = source 211 | self.transform_img = preprocess 212 | self.data_to_iterate = self.get_image_data() 213 | 214 | 215 | def __getitem__(self, idx): 216 | info = self.data_to_iterate[idx] 217 | image_path = 
os.path.join(self.source,'images',info['filename']) 218 | image = PIL.Image.open(image_path).convert("RGB").resize((self.args.image_size,self.args.image_size),PIL.Image.Resampling.BILINEAR) 219 | 220 | if info.get("mask", None): 221 | mask = os.path.join(self.source,'images',info['mask']) 222 | mask = PIL.Image.open(mask).convert("L").resize((self.args.image_size,self.args.image_size),PIL.Image.Resampling.NEAREST) 223 | mask = np.array(mask).astype(np.float)/255.0 224 | mask [mask!=0.0] = 1.0 225 | else: 226 | mask = np.zeros((self.args.image_size,self.args.image_size)).astype(np.float) 227 | 228 | image = self.transform_img(image) 229 | mask = torch.from_numpy(mask) 230 | 231 | return { 232 | "image": image, 233 | "mask": mask, 234 | "classname": info['clsname'], 235 | "is_anomaly": info['label'], 236 | "image_path": image_path, 237 | } 238 | 239 | def __len__(self): 240 | return len(self.data_to_iterate) 241 | 242 | def get_image_data(self): 243 | data_to_iterate = [] 244 | with open(os.path.join(self.source,'samples',"test.json"), "r") as f_r: 245 | for line in f_r: 246 | meta = json.loads(line) 247 | data_to_iterate.append(meta) 248 | return data_to_iterate -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import open_clip 3 | import torch 4 | import yaml 5 | from easydict import EasyDict 6 | from models.Necker import Necker 7 | from models.Adapter import Adapter 8 | import math 9 | import argparse 10 | import warnings 11 | from utils.misc_helper import * 12 | from datasets.dataset import ChexpertTestDataset,BusiTestDataset,BrainMRITestDataset 13 | from torch.utils.data import DataLoader 14 | from models.MapMaker import MapMaker 15 | import pprint 16 | import torch.nn.functional as F 17 | from pytorch_grad_cam.utils.image import show_cam_on_image 18 | from tqdm import tqdm 19 | from PIL import Image 20 | import cv2 21 | 22 | warnings.filterwarnings('ignore') 23 | 24 | def normalization(segmentations, image_size, avgpool_size = 128): 25 | 26 | segmentations = torch.tensor(segmentations[:, None, ...]).cuda() # N x 1 x H x W 27 | segmentations = F.interpolate(segmentations,(image_size, image_size), mode='bilinear', align_corners=True) 28 | 29 | segmentations_ = F.avg_pool2d(segmentations, (avgpool_size,avgpool_size), stride=1).cpu().numpy() 30 | 31 | min_scores = ( 32 | segmentations_.reshape(-1).min(axis=-1).reshape(1) 33 | ) 34 | 35 | max_scores = ( 36 | segmentations_.reshape(-1).max(axis=-1).reshape(1) 37 | ) 38 | 39 | segmentations = segmentations.squeeze(1).cpu().numpy() 40 | segmentations = (segmentations - min_scores) / (max_scores - min_scores) 41 | segmentations = np.clip(segmentations,a_min=0,a_max=1) 42 | 43 | segmentations = cv2.GaussianBlur(segmentations, (5, 5), 0) 44 | return segmentations 45 | 46 | 47 | @torch.no_grad() 48 | def make_vision_takens_info(model,model_cfg,layers_out): 49 | 50 | img = torch.ones((1,3,model_cfg['vision_cfg']['image_size'], 51 | model_cfg['vision_cfg']['image_size'])).to(model.device) 52 | 53 | img_feature,tokens = model.encode_image(img,layers_out) 54 | 55 | if len(tokens[0].shape)==3: 56 | model.token_size= [int(math.sqrt(token.shape[1]-1)) for token in tokens] 57 | model.token_c= [token.shape[-1] for token in tokens] 58 | else: 59 | model.token_size = [token.shape[2] for token in tokens] 60 | model.token_c = [token.shape[1] for token in tokens] 61 | 62 | model.embed_dim = model_cfg['embed_dim'] 
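    # For instance, a ViT-B-16 backbone at 224x224 input gives 14x14 patch grids
    # (196 patch tokens + 1 class token per selected layer) with 768-dim tokens,
    # so token_size / token_c hold one 14 / 768 entry per layer in layers_out,
    # while embed_dim is the shared image-text embedding width (512 for ViT-B-16).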
63 | print("model token size is {}".format(model.token_size)," model token dim is {}".format(model.token_c)) 64 | 65 | 66 | @torch.no_grad() 67 | def main(args): 68 | 69 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 70 | 71 | with open(args.config_path) as f: 72 | args.config = EasyDict(yaml.load(f, Loader=yaml.FullLoader)) 73 | 74 | set_seed(seed=args.config.random_seed) 75 | 76 | model, preprocess, model_cfg = open_clip.create_model_and_transforms(args.config.model_name, args.config.image_size, device=device) 77 | 78 | for param in model.parameters(): 79 | param.requires_grad_(False) 80 | 81 | args.config.model_cfg = model_cfg 82 | 83 | make_vision_takens_info(model, 84 | args.config.model_cfg, 85 | args.config.layers_out) 86 | 87 | necker = Necker(clip_model=model).to(model.device) 88 | adapter = Adapter(clip_model=model,target=args.config.model_cfg['embed_dim']).to(model.device) 89 | 90 | if args.config.prompt_maker=='coop': 91 | from models.CoOp import PromptMaker 92 | else: 93 | raise NotImplementedError("type of prompt must in ['coop']") 94 | 95 | prompt_maker = PromptMaker( 96 | prompts=args.config.prompts, 97 | clip_model=model, 98 | n_ctx= args.config.n_learnable_token, 99 | CSC = args.config.CSC, 100 | class_token_position=args.config.class_token_positions, 101 | ).to(model.device) 102 | 103 | map_maker = MapMaker(image_size=args.config.image_size).to(model.device) 104 | 105 | 106 | checkpoints = torch.load(args.checkpoint_path,map_location=map_func) 107 | adapter.load_state_dict(checkpoints['adapter_state_dict']) 108 | prompt_maker.prompt_learner.load_state_dict(checkpoints['prompt_state_dict']) 109 | prompt_maker.prompt_learner.eval() 110 | adapter.eval() 111 | 112 | for test_dataset_name in args.config.test_datasets: 113 | 114 | if test_dataset_name == 'chexpert': 115 | 116 | test_dataset = ChexpertTestDataset( args=args.config, 117 | source=os.path.join(args.config.data_root,test_dataset_name), 118 | preprocess=preprocess, 119 | ) 120 | 121 | elif test_dataset_name =='brainmri': 122 | 123 | test_dataset = BrainMRITestDataset( 124 | args=args.config, 125 | source=os.path.join(args.config.data_root,test_dataset_name), 126 | preprocess=preprocess, 127 | ) 128 | elif test_dataset_name =='busi': 129 | 130 | test_dataset = BusiTestDataset( 131 | args=args.config, 132 | source=os.path.join(args.config.data_root,test_dataset_name), 133 | preprocess=preprocess) 134 | else: 135 | raise NotImplementedError("dataset must in ['chexpert','busi','brainmri'] ") 136 | 137 | test_dataloader = DataLoader(test_dataset, batch_size=args.config.batch_size,num_workers=2) 138 | results = validate(args,test_dataset_name,test_dataloader,model,necker,adapter,prompt_maker,map_maker) 139 | 140 | if test_dataset_name!='busi': 141 | print("{}, image auroc: {:.4f}".format(test_dataset_name, results["image-auroc"])) 142 | else: 143 | print("{}, image auroc: {:.4f}, pixel_auroc: {:.4f}".format(test_dataset_name, results["image-auroc"],results['pixel-auroc'])) 144 | 145 | 146 | def validate(args, dataset_name, test_dataloader, clip_model, necker, adapter, prompt_maker, map_maker): 147 | 148 | image_preds = [] 149 | image_gts= [] 150 | 151 | pixel_preds = [] 152 | pixel_gts = [] 153 | 154 | image_paths = [] 155 | 156 | for i, input in enumerate(test_dataloader): 157 | 158 | images = input['image'].to(clip_model.device) 159 | image_paths.extend(input['image_path']) 160 | 161 | _, image_tokens = clip_model.encode_image(images, out_layers=args.config.layers_out) 162 | image_features = 
necker(image_tokens) 163 | vision_adapter_features = adapter(image_features) 164 | propmt_adapter_features = prompt_maker(vision_adapter_features) 165 | anomaly_map = map_maker(vision_adapter_features, propmt_adapter_features) 166 | 167 | B, _, H, W = anomaly_map.shape 168 | anomaly_map = anomaly_map[:,1,:,:] 169 | 170 | pixel_preds.append(anomaly_map) 171 | anomaly_score,_ =torch.max(anomaly_map.view((B,H*W)), dim=-1) 172 | 173 | image_preds.extend(anomaly_score.cpu().numpy().tolist()) 174 | image_gts.extend(input['is_anomaly'].cpu().numpy().tolist()) 175 | 176 | if dataset_name=='busi': 177 | pixel_gts.append(input['mask'].cpu().numpy()) 178 | 179 | pixel_preds_np = [pixel_pred.cpu().numpy() for pixel_pred in pixel_preds] 180 | pixel_preds = normalization(torch.cat(pixel_preds,dim=0), args.config.image_size) 181 | 182 | if dataset_name == 'busi': 183 | pixel_gts = np.concatenate(pixel_gts,axis=0) 184 | 185 | save_images_root = os.path.join(args.vis_save_root,"{}".format(dataset_name)) 186 | os.makedirs(save_images_root,exist_ok=True) 187 | 188 | if dataset_name=='busi': 189 | iter= tqdm( 190 | zip(image_paths, image_gts, pixel_preds, pixel_gts), 191 | total=len(image_paths), 192 | desc="Generating Segmentation Images...", 193 | leave=False, 194 | ) 195 | else: 196 | iter= tqdm( 197 | zip(image_paths, image_gts, pixel_preds), 198 | total=len(image_paths), 199 | desc="Generating Segmentation Images...", 200 | leave=False, 201 | ) 202 | 203 | for i, data in enumerate(iter): 204 | if dataset_name=='busi': 205 | image_path, image_gt, pixel_pred, pixel_gt = data 206 | else: 207 | image_path, image_gt, pixel_pred = data 208 | 209 | _, image_name = os.path.split(image_path) 210 | 211 | image = Image.open(image_path).convert("RGB") 212 | image = image.resize((args.config.image_size,args.config.image_size)) 213 | image = np.array(image).astype(np.uint8) 214 | 215 | heat = show_cam_on_image( image / 255, pixel_pred, use_rgb=True) 216 | 217 | label_= "normal" if image_gt==0 else "abnormal" 218 | 219 | merge = [image,heat] 220 | 221 | if dataset_name == 'busi': 222 | pixel_gt = np.repeat(np.expand_dims(pixel_gt,axis=-1),3,axis=-1) 223 | merge.append(pixel_gt*255) 224 | 225 | Image.fromarray(np.concatenate(merge,axis=1).astype(np.uint8)).save(os.path.join(save_images_root,"{}_{}_{}".format(i,label_,image_name))) 226 | 227 | metric = compute_imagewise_metrics(image_preds,image_gts) 228 | if dataset_name == 'busi': 229 | metric.update(compute_pixelwise_metrics(pixel_preds_np, pixel_gts)) 230 | 231 | return metric 232 | 233 | 234 | if __name__ == '__main__': 235 | parser = argparse.ArgumentParser(description="Test MediCLIP") 236 | parser.add_argument("--config_path", type=str, help="model configs") 237 | parser.add_argument("--checkpoint_path", type=str, help='the checkpoint path') 238 | parser.add_argument("--vis_save_root", type=str, default='vis_results') 239 | args = parser.parse_args() 240 | torch.multiprocessing.set_start_method("spawn") 241 | main(args) 242 | 243 | -------------------------------------------------------------------------------- /open_clip/pretrained.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from functools import partial 6 | from typing import Dict, Union 7 | 8 | from tqdm import tqdm 9 | 10 | 11 | def _pcfg(url='', hf_hub='', mean=None, std=None): 12 | return dict( 13 | url=url, 14 | hf_hub=hf_hub, 15 | mean=mean, 16 | std=std, 17 | ) 18 | 19 | 20 | _RN50 = dict( 
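    # Each pretrain tag below maps to a _pcfg entry; get_pretrained_url('RN50', 'openai')
    # returns that entry's 'url', and download_pretrained_from_url() later verifies the
    # downloaded file against the sha256 prefix embedded in openaipublic URL paths.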
21 | openai=_pcfg( 22 | "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"), 23 | yfcc15m=_pcfg( 24 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"), 25 | cc12m=_pcfg( 26 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"), 27 | ) 28 | 29 | _RN50_quickgelu = dict( 30 | openai=_pcfg( 31 | "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"), 32 | yfcc15m=_pcfg( 33 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"), 34 | cc12m=_pcfg( 35 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"), 36 | ) 37 | 38 | _RN101 = dict( 39 | openai=_pcfg( 40 | "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"), 41 | yfcc15m=_pcfg( 42 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"), 43 | ) 44 | 45 | _RN101_quickgelu = dict( 46 | openai=_pcfg( 47 | "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"), 48 | yfcc15m=_pcfg( 49 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"), 50 | ) 51 | 52 | _RN50x4 = dict( 53 | openai=_pcfg( 54 | "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"), 55 | ) 56 | 57 | _RN50x16 = dict( 58 | openai=_pcfg( 59 | "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"), 60 | ) 61 | 62 | _RN50x64 = dict( 63 | openai=_pcfg( 64 | "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"), 65 | ) 66 | 67 | _VITB32 = dict( 68 | openai=_pcfg( 69 | "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), 70 | laion400m_e31=_pcfg( 71 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), 72 | laion400m_e32=_pcfg( 73 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), 74 | laion2b_e16=_pcfg( 75 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"), 76 | laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/') 77 | ) 78 | 79 | _VITB32_quickgelu = dict( 80 | openai=_pcfg( 81 | "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), 82 | laion400m_e31=_pcfg( 83 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), 84 | laion400m_e32=_pcfg( 85 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), 86 | ) 87 | 88 | _VITB16 = dict( 89 | openai=_pcfg( 90 | "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"), 91 | laion400m_e31=_pcfg( 92 | 
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"), 93 | laion400m_e32=_pcfg( 94 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"), 95 | # laion400m_32k=_pcfg( 96 | # url="", 97 | # mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 98 | # laion400m_64k=_pcfg( 99 | # url="", 100 | # mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 101 | laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'), 102 | ) 103 | 104 | _VITB16_PLUS_240 = dict( 105 | laion400m_e31=_pcfg( 106 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"), 107 | laion400m_e32=_pcfg( 108 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"), 109 | ) 110 | 111 | _VITL14 = dict( 112 | openai=_pcfg( 113 | "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"), 114 | laion400m_e31=_pcfg( 115 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"), 116 | laion400m_e32=_pcfg( 117 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"), 118 | laion2b_s32b_b82k=_pcfg( 119 | hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/', 120 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 121 | ) 122 | 123 | _VITL14_336 = dict( 124 | openai=_pcfg( 125 | "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"), 126 | ) 127 | 128 | _VITH14 = dict( 129 | laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'), 130 | ) 131 | 132 | _VITg14 = dict( 133 | laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'), 134 | laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'), 135 | ) 136 | 137 | _VITbigG14 = dict( 138 | laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'), 139 | ) 140 | 141 | _robertaViTB32 = dict( 142 | laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'), 143 | ) 144 | 145 | _xlmRobertaBaseViTB32 = dict( 146 | laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'), 147 | ) 148 | 149 | _xlmRobertaLargeFrozenViTH14 = dict( 150 | frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'), 151 | ) 152 | 153 | _convnext_base = dict( 154 | laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'), 155 | ) 156 | 157 | _convnext_base_w = dict( 158 | laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'), 159 | laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'), 160 | laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'), 161 | ) 162 | 163 | _convnext_base_w_320 = dict( 164 | laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'), 165 | laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'), 166 | ) 167 | 168 | _convnext_large_d = dict( 169 | laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'), 170 | ) 171 | 172 | _convnext_large_d_320 = dict( 173 | 
laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'), 174 | laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'), 175 | ) 176 | 177 | _convnext_xxlarge = dict( 178 | laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'), 179 | laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'), 180 | laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'), 181 | ) 182 | 183 | _coca_VITB32 = dict( 184 | laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'), 185 | mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/') 186 | ) 187 | 188 | _coca_VITL14 = dict( 189 | laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'), 190 | mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/') 191 | ) 192 | 193 | 194 | _PRETRAINED = { 195 | "RN50": _RN50, 196 | "RN50-quickgelu": _RN50_quickgelu, 197 | "RN101": _RN101, 198 | "RN101-quickgelu": _RN101_quickgelu, 199 | "RN50x4": _RN50x4, 200 | "RN50x16": _RN50x16, 201 | "RN50x64": _RN50x64, 202 | "ViT-B-32": _VITB32, 203 | "ViT-B-32-quickgelu": _VITB32_quickgelu, 204 | "ViT-B-16": _VITB16, 205 | "ViT-B-16-plus-240": _VITB16_PLUS_240, 206 | "ViT-L-14": _VITL14, 207 | "ViT-L-14-336": _VITL14_336, 208 | "ViT-H-14": _VITH14, 209 | "ViT-g-14": _VITg14, 210 | "ViT-bigG-14": _VITbigG14, 211 | "roberta-ViT-B-32": _robertaViTB32, 212 | "xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32, 213 | "xlm-roberta-large-ViT-H-14": _xlmRobertaLargeFrozenViTH14, 214 | "convnext_base": _convnext_base, 215 | "convnext_base_w": _convnext_base_w, 216 | "convnext_base_w_320": _convnext_base_w_320, 217 | "convnext_large_d": _convnext_large_d, 218 | "convnext_large_d_320": _convnext_large_d_320, 219 | "convnext_xxlarge": _convnext_xxlarge, 220 | "coca_ViT-B-32": _coca_VITB32, 221 | "coca_ViT-L-14": _coca_VITL14, 222 | } 223 | 224 | 225 | def _clean_tag(tag: str): 226 | # normalize pretrained tags 227 | return tag.lower().replace('-', '_') 228 | 229 | 230 | def list_pretrained(as_str: bool = False): 231 | """ returns list of pretrained models 232 | Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True 233 | """ 234 | return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()] 235 | 236 | 237 | def list_pretrained_models_by_tag(tag: str): 238 | """ return all models having the specified pretrain tag """ 239 | models = [] 240 | tag = _clean_tag(tag) 241 | for k in _PRETRAINED.keys(): 242 | if tag in _PRETRAINED[k]: 243 | models.append(k) 244 | return models 245 | 246 | 247 | def list_pretrained_tags_by_model(model: str): 248 | """ return all pretrain tags for the specified model architecture """ 249 | tags = [] 250 | if model in _PRETRAINED: 251 | tags.extend(_PRETRAINED[model].keys()) 252 | return tags 253 | 254 | 255 | def is_pretrained_cfg(model: str, tag: str): 256 | if model not in _PRETRAINED: 257 | return False 258 | return _clean_tag(tag) in _PRETRAINED[model] 259 | 260 | 261 | def get_pretrained_cfg(model: str, tag: str): 262 | if model not in _PRETRAINED: 263 | return {} 264 | model_pretrained = _PRETRAINED[model] 265 | return model_pretrained.get(_clean_tag(tag), {}) 266 | 267 | 268 | def get_pretrained_url(model: str, tag: str): 269 | cfg = 
get_pretrained_cfg(model, _clean_tag(tag)) 270 | return cfg.get('url', '') 271 | 272 | 273 | def download_pretrained_from_url( 274 | url: str, 275 | cache_dir: Union[str, None] = None, 276 | ): 277 | if not cache_dir: 278 | cache_dir = os.path.expanduser("~/.cache/clip") 279 | os.makedirs(cache_dir, exist_ok=True) 280 | filename = os.path.basename(url) 281 | 282 | if 'openaipublic' in url: 283 | expected_sha256 = url.split("/")[-2] 284 | elif 'mlfoundations' in url: 285 | expected_sha256 = os.path.splitext(filename)[0].split("-")[-1] 286 | else: 287 | expected_sha256 = '' 288 | 289 | download_target = os.path.join(cache_dir, filename) 290 | 291 | if os.path.exists(download_target) and not os.path.isfile(download_target): 292 | raise RuntimeError(f"{download_target} exists and is not a regular file") 293 | 294 | if os.path.isfile(download_target): 295 | if expected_sha256: 296 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): 297 | return download_target 298 | else: 299 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 300 | else: 301 | return download_target 302 | 303 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 304 | with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 305 | while True: 306 | buffer = source.read(8192) 307 | if not buffer: 308 | break 309 | 310 | output.write(buffer) 311 | loop.update(len(buffer)) 312 | 313 | if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): 314 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 315 | 316 | return download_target 317 | 318 | 319 | def download_pretrained( 320 | cfg: Dict, 321 | cache_dir: Union[str, None] = None, 322 | ): 323 | target = '' 324 | if not cfg: 325 | return target 326 | 327 | download_url = cfg.get('url', '') 328 | 329 | target = download_pretrained_from_url(download_url, cache_dir=cache_dir) 330 | 331 | return target 332 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import open_clip 2 | import torch 3 | import yaml 4 | from easydict import EasyDict 5 | from models.Necker import Necker 6 | from models.Adapter import Adapter 7 | import math 8 | import argparse 9 | import warnings 10 | from utils.misc_helper import * 11 | from torch.utils.data import DataLoader 12 | from models.MapMaker import MapMaker 13 | from utils.losses import FocalLoss,BinaryDiceLoss 14 | from datasets.dataset import TrainDataset,\ 15 | ChexpertTestDataset,\ 16 | BusiTestDataset,\ 17 | BrainMRITestDataset 18 | import pprint 19 | from tqdm import tqdm 20 | warnings.filterwarnings('ignore') 21 | 22 | 23 | @torch.no_grad() 24 | def make_vision_takens_info(model,model_cfg,layers_out): 25 | 26 | img = torch.ones((1,3,model_cfg['vision_cfg']['image_size'], 27 | model_cfg['vision_cfg']['image_size'])).to(model.device) 28 | 29 | img_feature,tokens = model.encode_image(img,layers_out) 30 | 31 | if len(tokens[0].shape)==3: 32 | model.token_size= [int(math.sqrt(token.shape[1]-1)) for token in tokens] 33 | model.token_c= [token.shape[-1] for token in tokens] 34 | else: 35 | model.token_size = [token.shape[2] for token in tokens] 36 | model.token_c = [token.shape[1] for token in tokens] 37 | 38 | model.embed_dim = 
model_cfg['embed_dim'] 39 | print("model token size is {}".format(model.token_size)," model token dim is {}".format(model.token_c)) 40 | 41 | 42 | def main(args): 43 | 44 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 45 | 46 | with open(args.config_path) as f: 47 | args.config = EasyDict(yaml.load(f, Loader=yaml.FullLoader)) 48 | 49 | model, preprocess, model_cfg = open_clip.create_model_and_transforms(args.config.model_name, args.config.image_size, device=device) 50 | 51 | for param in model.parameters(): 52 | param.requires_grad_(False) 53 | 54 | args.config.model_cfg = model_cfg 55 | 56 | make_vision_takens_info(model, 57 | args.config.model_cfg, 58 | args.config.layers_out) 59 | 60 | current_time = get_current_time() 61 | args.config.save_root=os.path.join(args.config.save_root,current_time) 62 | 63 | if not os.path.exists(args.config.save_root): 64 | os.makedirs(args.config.save_root) 65 | 66 | logger = create_logger("logger",os.path.join(args.config.save_root,'logger.log')) 67 | logger.info("config: {}".format(pprint.pformat(args))) 68 | 69 | necker = Necker(clip_model=model).to(model.device) 70 | adapter = Adapter(clip_model=model,target=args.config.model_cfg['embed_dim']).to(model.device) 71 | 72 | if args.config.prompt_maker=='coop': 73 | from models.CoOp import PromptMaker 74 | logger.info("load CoOp") 75 | else: 76 | raise NotImplementedError("type of prompt must in ['coop']") 77 | 78 | prompt_maker = PromptMaker( 79 | prompts=args.config.prompts, 80 | clip_model=model, 81 | n_ctx= args.config.n_learnable_token, 82 | CSC = args.config.CSC, 83 | class_token_position=args.config.class_token_positions, 84 | ).to(model.device) 85 | 86 | map_maker = MapMaker(image_size=args.config.image_size).to(model.device) 87 | 88 | optimizer = torch.optim.Adam([ 89 | {'params': prompt_maker.prompt_learner.parameters(),'lr': 0.001}, 90 | {'params': adapter.parameters(),"lr":0.001}, 91 | ], lr=0.001, betas=(0.5, 0.999)) 92 | 93 | train_dataset = TrainDataset(args=args.config, 94 | source=os.path.join(args.config.data_root,args.config.train_dataset), 95 | preprocess=preprocess, 96 | k_shot=args.k_shot) 97 | 98 | train_dataloader = DataLoader(train_dataset, batch_size=args.config.batch_size, shuffle=True, num_workers=2) 99 | 100 | test_dataloaders = {} 101 | best_record = {} 102 | 103 | for test_dataset_name in args.config.test_datasets: 104 | 105 | if test_dataset_name == 'chexpert': 106 | test_dataset = ChexpertTestDataset( args=args.config, 107 | source=os.path.join(args.config.data_root,test_dataset_name), 108 | preprocess=preprocess, 109 | ) 110 | 111 | elif test_dataset_name =='brainmri': 112 | 113 | test_dataset = BrainMRITestDataset( 114 | args=args.config, 115 | source=os.path.join(args.config.data_root,test_dataset_name), 116 | preprocess=preprocess, 117 | ) 118 | elif test_dataset_name =='busi': 119 | 120 | test_dataset = BusiTestDataset( 121 | args=args.config, 122 | source=os.path.join(args.config.data_root,test_dataset_name), 123 | preprocess=preprocess) 124 | else: 125 | raise NotImplementedError("dataset must in ['chexpert','busi','brainmri'] ") 126 | 127 | test_dataloader = DataLoader(test_dataset, batch_size=args.config.batch_size,num_workers=2) 128 | test_dataloaders[test_dataset_name]=test_dataloader 129 | best_record[test_dataset_name]=None 130 | 131 | logger.info("train data ({}) len {}".format(args.config.train_dataset,len(train_dataset))) 132 | 133 | for test_dataset_name in test_dataloaders: 134 | logger.info("test data ({}) len {}".format(test_dataset_name, 
len(test_dataloaders[test_dataset_name].dataset))) 135 | 136 | for task_name in args.config.anomaly_tasks: 137 | logger.info("anomaly syn task is {}, sampling probability is {}".format(task_name,args.config.anomaly_tasks[task_name])) 138 | 139 | for epoch in range(0, args.config.epoch): 140 | last_iter = epoch * len(train_dataloader) 141 | 142 | train_one_epoch( 143 | args, 144 | train_dataloader, 145 | optimizer, 146 | epoch, 147 | last_iter, 148 | logger, 149 | model, 150 | necker, 151 | adapter, 152 | prompt_maker, 153 | map_maker, 154 | ) 155 | 156 | if (epoch+1) % args.config.val_freq_epoch == 0: 157 | 158 | results = validate(args,test_dataloaders, epoch,model, necker,adapter,prompt_maker,map_maker) 159 | save_flag = False 160 | 161 | for test_dataset_name in results: 162 | if best_record[test_dataset_name] is None: 163 | if test_dataset_name=='busi': 164 | best_record[test_dataset_name] = [results[test_dataset_name]["image-auroc"], 165 | results[test_dataset_name]['pixel-auroc']] 166 | else: 167 | best_record[test_dataset_name] = [results[test_dataset_name]["image-auroc"]] 168 | 169 | save_flag=True 170 | else: 171 | if np.mean([results[test_dataset_name][key] for key in results[test_dataset_name]]) > np.mean(best_record[test_dataset_name]): 172 | if test_dataset_name == 'busi': 173 | best_record[test_dataset_name] = [results[test_dataset_name]["image-auroc"], 174 | results[test_dataset_name]['pixel-auroc']] 175 | else: 176 | best_record[test_dataset_name] = [results[test_dataset_name]["image-auroc"]] 177 | save_flag=True 178 | 179 | 180 | if test_dataset_name=='busi': 181 | logger.info("({}): Epoch: {}, image auroc: {:.4f}, pixel_auroc: {:.4f},".format(test_dataset_name, 182 | epoch+1, 183 | results[test_dataset_name]["image-auroc"], 184 | results[test_dataset_name]['pixel-auroc'])) 185 | else: 186 | logger.info("({}): Epoch: {}, image auroc: {:.4f},".format( 187 | test_dataset_name, 188 | epoch+1, 189 | results[test_dataset_name]["image-auroc"], 190 | )) 191 | 192 | for test_dataset_name in results: 193 | if test_dataset_name == 'busi': 194 | logger.info( 195 | "({} best): image auroc: {:.4f}, pixel auroc: {:.4f},".format( 196 | test_dataset_name, 197 | best_record[test_dataset_name][0], 198 | best_record[test_dataset_name][1], 199 | )) 200 | else: 201 | logger.info( 202 | "({} best): image auroc: {:.4f},".format( 203 | test_dataset_name, 204 | best_record[test_dataset_name][0], 205 | )) 206 | 207 | if save_flag: 208 | logger.info("save checkpoints in epoch: {}".format(epoch+1)) 209 | torch.save({ 210 | "adapter_state_dict": adapter.state_dict(), 211 | "prompt_state_dict": prompt_maker.prompt_learner.state_dict(), 212 | }, os.path.join(args.config.save_root, 'checkpoints_{}.pkl'.format(epoch + 1))) 213 | 214 | 215 | def train_one_epoch( 216 | args, 217 | train_dataloader, 218 | optimizer, 219 | epoch, 220 | start_iter, 221 | logger, 222 | clip_model, 223 | necker, 224 | adapter, 225 | prompt_maker, 226 | map_maker, 227 | ): 228 | 229 | loss_meter = AverageMeter(args.config.print_freq_step) 230 | 231 | focal_criterion = FocalLoss() 232 | dice_criterion = BinaryDiceLoss() 233 | 234 | adapter.train() 235 | prompt_maker.train() 236 | 237 | for i, input in enumerate(train_dataloader): 238 | curr_step = start_iter + i 239 | 240 | images = input['image'].to(clip_model.device) 241 | gt_mask = input['mask'].to(clip_model.device) 242 | 243 | with torch.no_grad(): 244 | _, image_tokens = clip_model.encode_image(images,out_layers=args.config.layers_out) 245 | image_features = 
necker(image_tokens) 246 | 247 | vision_adapter_features = adapter(image_features) 248 | propmt_adapter_features = prompt_maker(vision_adapter_features) 249 | anomaly_map = map_maker(vision_adapter_features,propmt_adapter_features) 250 | 251 | loss = [] 252 | 253 | loss.append(focal_criterion(anomaly_map,gt_mask)) 254 | loss.append(dice_criterion(anomaly_map[:, 1, :, :],gt_mask)) 255 | 256 | loss = torch.sum(torch.stack(loss)) 257 | loss_meter.update(loss.item()) 258 | 259 | optimizer.zero_grad() 260 | loss.backward() 261 | optimizer.step() 262 | 263 | if (curr_step + 1) % args.config.print_freq_step == 0: 264 | logger.info( 265 | "Epoch: [{0}/{1}]\t" 266 | "Iter: [{2}/{3}]\t" 267 | "Loss {loss.val:.5f} ({loss.avg:.5f})\t" 268 | .format( 269 | epoch+1 , 270 | args.config.epoch, 271 | curr_step + 1, 272 | len(train_dataloader) * args.config.epoch, 273 | loss=loss_meter, 274 | ) 275 | ) 276 | 277 | 278 | def validate(args, test_dataloaders, epoch, clip_model, necker, adapter, prompt_maker, map_maker): 279 | 280 | adapter.eval() 281 | prompt_maker.eval() 282 | results = {} 283 | 284 | for test_dataset_name in test_dataloaders: 285 | test_dataloader = test_dataloaders[test_dataset_name] 286 | 287 | anomaly_maps = [] 288 | anomaly_gts = [] 289 | 290 | image_scores = [] 291 | image_labels = [] 292 | 293 | with torch.no_grad(): 294 | for i, input in enumerate(tqdm(test_dataloader,desc=test_dataset_name)): 295 | 296 | images = input['image'].to(clip_model.device) 297 | 298 | _, image_tokens = clip_model.encode_image(images, out_layers=args.config.layers_out) 299 | image_features = necker(image_tokens) 300 | vision_adapter_features = adapter(image_features) 301 | propmt_adapter_features = prompt_maker(vision_adapter_features) 302 | anomaly_map = map_maker(vision_adapter_features, propmt_adapter_features) 303 | 304 | B,_,H,W = anomaly_map.shape 305 | 306 | anomaly_map = anomaly_map[:,1,:,:] 307 | anomaly_gt = input['mask'] 308 | 309 | anomaly_maps.append(anomaly_map.cpu().numpy()) 310 | anomaly_gts.append(anomaly_gt.cpu().numpy()) 311 | 312 | anomaly_score,_ = torch.max(anomaly_map.view((B,H*W)), dim=-1) 313 | 314 | image_scores.extend(anomaly_score.cpu().numpy().tolist()) 315 | image_labels.extend(input['is_anomaly'].cpu().numpy().tolist()) 316 | 317 | metric = compute_imagewise_metrics(image_scores,image_labels) 318 | 319 | if test_dataset_name=='busi': 320 | metric.update(compute_pixelwise_metrics(anomaly_maps,anomaly_gts)) 321 | 322 | results[test_dataset_name] = metric 323 | return results 324 | 325 | 326 | if __name__ == '__main__': 327 | parser = argparse.ArgumentParser(description="Train MediCLIP") 328 | parser.add_argument("--config_path", type=str, default='config/brainmri.yaml', help="model configs") 329 | parser.add_argument("--k_shot", type=int, default=16, help="normal image number") 330 | args = parser.parse_args() 331 | torch.multiprocessing.set_start_method("spawn") 332 | main(args) -------------------------------------------------------------------------------- /open_clip/model.py: -------------------------------------------------------------------------------- 1 | """ CLIP Model 2 | 3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 
4 | """ 5 | from dataclasses import dataclass 6 | import logging 7 | import math 8 | from typing import Optional, Tuple, Union 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn.functional as F 13 | from torch import nn 14 | from torch.utils.checkpoint import checkpoint 15 | 16 | from .modified_resnet import ModifiedResNet 17 | from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer 18 | from .utils import to_2tuple 19 | 20 | 21 | @dataclass 22 | class CLIPVisionCfg: 23 | layers: Union[Tuple[int, int, int, int], int] = 12 24 | width: int = 768 25 | head_width: int = 64 26 | mlp_ratio: float = 4.0 27 | patch_size: int = 16 28 | image_size: Union[Tuple[int, int], int] = 224 29 | ls_init_value: Optional[float] = None # layer scale initial value 30 | patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results 31 | input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design 32 | global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580) 33 | attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer 34 | n_queries: int = 256 # n_queries for attentional pooler 35 | attn_pooler_heads: int = 8 # n heads for attentional_pooling 36 | timm_model_name: str = None # a valid model name overrides layers, width, patch_size 37 | timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model 38 | timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') 39 | timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') 40 | timm_proj_bias: bool = False # enable bias final projection 41 | timm_drop: float = 0. # head dropout 42 | timm_drop_path: Optional[float] = None # backbone stochastic depth 43 | output_tokens: bool = False 44 | 45 | 46 | @dataclass 47 | class CLIPTextCfg: 48 | context_length: int = 77 49 | vocab_size: int = 49408 50 | width: int = 512 51 | heads: int = 8 52 | layers: int = 12 53 | ls_init_value: Optional[float] = None # layer scale initial value 54 | hf_model_name: str = None 55 | hf_tokenizer_name: str = None 56 | hf_model_pretrained: bool = True 57 | proj: str = 'mlp' 58 | pooler_type: str = 'mean_pooler' 59 | embed_cls: bool = False 60 | pad_id: int = 0 61 | output_tokens: bool = False 62 | 63 | 64 | def get_cast_dtype(precision: str): 65 | cast_dtype = None 66 | if precision == 'bf16': 67 | cast_dtype = torch.bfloat16 68 | elif precision == 'fp16': 69 | cast_dtype = torch.float16 70 | return cast_dtype 71 | 72 | 73 | def _build_vision_tower( 74 | embed_dim: int, 75 | vision_cfg: CLIPVisionCfg, 76 | quick_gelu: bool = False, 77 | cast_dtype: Optional[torch.dtype] = None 78 | ): 79 | if isinstance(vision_cfg, dict): 80 | vision_cfg = CLIPVisionCfg(**vision_cfg) 81 | 82 | # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more 83 | # memory efficient in recent PyTorch releases (>= 1.10). 84 | # NOTE: timm models always use native GELU regardless of quick_gelu flag. 
85 | act_layer = QuickGELU if quick_gelu else nn.GELU 86 | 87 | if isinstance(vision_cfg.layers, (tuple, list)): 88 | vision_heads = vision_cfg.width * 32 // vision_cfg.head_width 89 | visual = ModifiedResNet( 90 | layers=vision_cfg.layers, 91 | output_dim=embed_dim, 92 | heads=vision_heads, 93 | image_size=vision_cfg.image_size, 94 | width=vision_cfg.width, 95 | ) 96 | else: 97 | vision_heads = vision_cfg.width // vision_cfg.head_width 98 | norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm 99 | visual = VisionTransformer( 100 | image_size=vision_cfg.image_size, 101 | patch_size=vision_cfg.patch_size, 102 | width=vision_cfg.width, 103 | layers=vision_cfg.layers, 104 | heads=vision_heads, 105 | mlp_ratio=vision_cfg.mlp_ratio, 106 | ls_init_value=vision_cfg.ls_init_value, 107 | patch_dropout=vision_cfg.patch_dropout, 108 | input_patchnorm=vision_cfg.input_patchnorm, 109 | global_average_pool=vision_cfg.global_average_pool, 110 | attentional_pool=vision_cfg.attentional_pool, 111 | n_queries=vision_cfg.n_queries, 112 | attn_pooler_heads=vision_cfg.attn_pooler_heads, 113 | output_tokens=vision_cfg.output_tokens, 114 | output_dim=embed_dim, 115 | act_layer=act_layer, 116 | norm_layer=norm_layer, 117 | ) 118 | 119 | return visual 120 | 121 | 122 | def _build_text_tower( 123 | embed_dim: int, 124 | text_cfg: CLIPTextCfg, 125 | quick_gelu: bool = False, 126 | cast_dtype: Optional[torch.dtype] = None, 127 | ): 128 | if isinstance(text_cfg, dict): 129 | text_cfg = CLIPTextCfg(**text_cfg) 130 | 131 | 132 | act_layer = QuickGELU if quick_gelu else nn.GELU 133 | norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm 134 | 135 | text = TextTransformer( 136 | context_length=text_cfg.context_length, 137 | vocab_size=text_cfg.vocab_size, 138 | width=text_cfg.width, 139 | heads=text_cfg.heads, 140 | layers=text_cfg.layers, 141 | ls_init_value=text_cfg.ls_init_value, 142 | output_dim=embed_dim, 143 | embed_cls=text_cfg.embed_cls, 144 | output_tokens=text_cfg.output_tokens, 145 | pad_id=text_cfg.pad_id, 146 | act_layer=act_layer, 147 | norm_layer=norm_layer, 148 | ) 149 | 150 | return text 151 | 152 | 153 | class CLIP(nn.Module): 154 | output_dict: torch.jit.Final[bool] 155 | 156 | def __init__( 157 | self, 158 | embed_dim: int, 159 | vision_cfg: CLIPVisionCfg, 160 | text_cfg: CLIPTextCfg, 161 | quick_gelu: bool = False, 162 | cast_dtype: Optional[torch.dtype] = None, 163 | output_dict: bool = False, 164 | ): 165 | super().__init__() 166 | self.output_dict = output_dict 167 | self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) 168 | 169 | text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) 170 | 171 | self.transformer = text.transformer 172 | 173 | self.vocab_size = text.vocab_size 174 | self.token_embedding = text.token_embedding 175 | 176 | self.positional_embedding = text.positional_embedding 177 | self.ln_final = text.ln_final 178 | self.text_projection = text.text_projection 179 | self.register_buffer('attn_mask', text.attn_mask, persistent=False) 180 | 181 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 182 | 183 | def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): 184 | self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) 185 | 186 | @torch.jit.ignore 187 | def set_grad_checkpointing(self, enable=True): 188 | self.visual.set_grad_checkpointing(enable) 189 | self.transformer.grad_checkpointing = enable 190 | 191 | 
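# --- Annotation (illustrative aside, not part of open_clip/model.py) --------
# Unlike upstream open_clip, encode_image below also takes `out_layers` and
# returns intermediate patch tokens, which train.py's make_vision_takens_info
# turns into per-layer grid sizes and channel counts. The sketch uses dummy
# ViT-B-16-like tensors (1 CLS token + 14x14 patch tokens, width 768) purely
# for the shape arithmetic; no CLIP weights are involved.
import math
import torch

dummy_tokens = [torch.zeros(1, 197, 768) for _ in range(4)]          # (B, 1 + H*W, C)
token_size = [int(math.sqrt(t.shape[1] - 1)) for t in dummy_tokens]  # drop CLS token
token_c = [t.shape[-1] for t in dummy_tokens]
assert token_size == [14, 14, 14, 14] and token_c == [768, 768, 768, 768]
# -----------------------------------------------------------------------------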
def encode_image(self, image, out_layers): 192 | features = self.visual(image, out_layers) 193 | return features 194 | 195 | 196 | def encode_text(self, text, normalize: bool = False): 197 | cast_dtype = self.transformer.get_cast_dtype() 198 | 199 | x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] 200 | 201 | x = x + self.positional_embedding.to(cast_dtype) 202 | x = x.permute(1, 0, 2) # NLD -> LND 203 | x, attn, tokens = self.transformer(x, attn_mask=self.attn_mask) 204 | x = x.permute(1, 0, 2) # LND -> NLD 205 | x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] 206 | # take features from the eot embedding (eot_token is the highest number in each sequence) 207 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 208 | 209 | return F.normalize(x, dim=-1) if normalize else x 210 | 211 | 212 | 213 | 214 | def convert_weights_to_lp(model: nn.Module, dtype=torch.float16): 215 | """Convert applicable model parameters to low-precision (bf16 or fp16)""" 216 | 217 | def _convert_weights(l): 218 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 219 | l.weight.data = l.weight.data.to(dtype) 220 | if l.bias is not None: 221 | l.bias.data = l.bias.data.to(dtype) 222 | 223 | if isinstance(l, (nn.MultiheadAttention, Attention)): 224 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 225 | tensor = getattr(l, attr) 226 | if tensor is not None: 227 | tensor.data = tensor.data.to(dtype) 228 | 229 | for name in ["text_projection", "proj"]: 230 | if hasattr(l, name): 231 | attr = getattr(l, name) 232 | if attr is not None: 233 | attr.data = attr.data.to(dtype) 234 | 235 | model.apply(_convert_weights) 236 | 237 | 238 | convert_weights_to_fp16 = convert_weights_to_lp # backwards compat 239 | 240 | 241 | # used to maintain checkpoint compatibility 242 | def convert_to_custom_text_state_dict(state_dict: dict): 243 | if 'text_projection' in state_dict: 244 | # old format state_dict, move text tower -> .text 245 | new_state_dict = {} 246 | for k, v in state_dict.items(): 247 | if any(k.startswith(p) for p in ( 248 | 'text_projection', 249 | 'positional_embedding', 250 | 'token_embedding', 251 | 'transformer', 252 | 'ln_final', 253 | )): 254 | k = 'text.' 
+ k 255 | new_state_dict[k] = v 256 | return new_state_dict 257 | return state_dict 258 | 259 | 260 | def build_model_from_openai_state_dict( 261 | state_dict: dict, 262 | quick_gelu=True, 263 | cast_dtype=torch.float16, 264 | ): 265 | vit = "visual.proj" in state_dict 266 | 267 | if vit: 268 | vision_width = state_dict["visual.conv1.weight"].shape[0] 269 | vision_layers = len( 270 | [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 271 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 272 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 273 | image_size = vision_patch_size * grid_size 274 | else: 275 | counts: list = [ 276 | len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 277 | vision_layers = tuple(counts) 278 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 279 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 280 | vision_patch_size = None 281 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 282 | image_size = output_width * 32 283 | 284 | embed_dim = state_dict["text_projection"].shape[1] 285 | context_length = state_dict["positional_embedding"].shape[0] 286 | vocab_size = state_dict["token_embedding.weight"].shape[0] 287 | transformer_width = state_dict["ln_final.weight"].shape[0] 288 | transformer_heads = transformer_width // 64 289 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 290 | 291 | vision_cfg = CLIPVisionCfg( 292 | layers=vision_layers, 293 | width=vision_width, 294 | patch_size=vision_patch_size, 295 | image_size=image_size, 296 | ) 297 | text_cfg = CLIPTextCfg( 298 | context_length=context_length, 299 | vocab_size=vocab_size, 300 | width=transformer_width, 301 | heads=transformer_heads, 302 | layers=transformer_layers, 303 | ) 304 | 305 | model = CLIP( 306 | embed_dim, 307 | vision_cfg=vision_cfg, 308 | text_cfg=text_cfg, 309 | quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU 310 | cast_dtype=cast_dtype, 311 | ) 312 | 313 | for key in ["input_resolution", "context_length", "vocab_size"]: 314 | state_dict.pop(key, None) 315 | 316 | convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16 317 | model.load_state_dict(state_dict) 318 | 319 | return model.eval() 320 | 321 | 322 | def trace_model(model, batch_size=256, device=torch.device('cpu')): 323 | model.eval() 324 | image_size = model.visual.image_size 325 | example_images = torch.ones((batch_size, 3, image_size, image_size), device=device) 326 | example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device) 327 | model = torch.jit.trace_module( 328 | model, 329 | inputs=dict( 330 | forward=(example_images, example_text), 331 | encode_text=(example_text,), 332 | encode_image=(example_images,) 333 | )) 334 | model.visual.image_size = image_size 335 | return model 336 | 337 | 338 | 339 | def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True): 340 | # Rescale the grid of position embeddings when loading from state_dict 341 | flag = 1 342 | old_pos_embed = state_dict.get('visual.positional_embedding', None) 343 | if old_pos_embed is None: 344 | flag = 0 345 | old_pos_embed = state_dict.get('visual.attnpool.positional_embedding', None) 346 | if old_pos_embed is None or not 
hasattr(model.visual, 'grid_size'): 347 | return 348 | grid_size = to_2tuple(model.visual.grid_size) 349 | extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) 350 | new_seq_len = grid_size[0] * grid_size[1] + extra_tokens 351 | if new_seq_len == old_pos_embed.shape[0]: 352 | return 353 | 354 | if extra_tokens: 355 | pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] 356 | else: 357 | pos_emb_tok, pos_emb_img = None, old_pos_embed 358 | old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img)))) 359 | 360 | logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) 361 | pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) 362 | pos_emb_img = F.interpolate( 363 | pos_emb_img, 364 | size=grid_size, 365 | mode=interpolation, 366 | antialias=antialias, 367 | align_corners=False, 368 | ) 369 | pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] 370 | if pos_emb_tok is not None: 371 | new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) 372 | else: 373 | new_pos_embed = pos_emb_img 374 | if flag: 375 | state_dict['visual.positional_embedding'] = new_pos_embed 376 | else: 377 | state_dict['visual.attnpool.positional_embedding'] = new_pos_embed 378 | -------------------------------------------------------------------------------- /medsyn/task_shape.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod, ABC 2 | import functools 3 | import itertools 4 | from typing import Callable, List, Tuple 5 | 6 | import numpy as np 7 | import numpy.typing as npt 8 | from skimage.morphology import convex_hull_image 9 | 10 | from .utils import accumulate_rotation, accumulate_scaling 11 | import math 12 | import imgaug.augmenters as iaa 13 | 14 | 15 | class PatchShapeMaker(ABC): 16 | 17 | @abstractmethod 18 | def get_patch_mask(self, dim_bounds: List[Tuple[int, int]], img_dims: np.ndarray) -> np.ndarray: 19 | """ 20 | :param dim_bounds: Tuples giving lower and upper bounds for patch size in each dimension 21 | :param img_dims: Image dimensions, can be used as scaling factor. 22 | Creates a patch mask to be used in the self-supervised task. 23 | Mask must have length(dim_bounds) dimensions. 
24 | """ 25 | pass 26 | 27 | def __call__(self, dim_bounds: List[Tuple[int, int]], img_dims: np.ndarray) -> np.ndarray: 28 | return self.get_patch_mask(dim_bounds, img_dims) 29 | 30 | 31 | class PerlinPatchMaker(PatchShapeMaker): 32 | 33 | def __init__(self, 34 | min_perlin_scale=0, 35 | perlin_scale=1, 36 | perlin_noise_threshold = 0.3, 37 | perlin_min_size = 0.2 38 | ): 39 | 40 | self.min_perlin_scale = min_perlin_scale 41 | self.perlin_scale = perlin_scale 42 | self.perlin_noise_threshold = perlin_noise_threshold 43 | self.perlin_min_size = perlin_min_size 44 | 45 | def lerp_np(self, x, y, w): 46 | fin_out = (y - x) * w + x 47 | return fin_out 48 | 49 | def rand_perlin_2d_np(self, shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3): 50 | delta = (res[0] / shape[0], res[1] / shape[1]) 51 | d = (shape[0] // res[0], shape[1] // res[1]) 52 | grid = np.mgrid[0:res[0]:delta[0], 0:res[1]:delta[1]].transpose(1, 2, 0) % 1 53 | 54 | angles = 2 * math.pi * np.random.rand(res[0] + 1, res[1] + 1) 55 | gradients = np.stack((np.cos(angles), np.sin(angles)), axis=-1) 56 | tt = np.repeat(np.repeat(gradients, d[0], axis=0), d[1], axis=1) 57 | 58 | tile_grads = lambda slice1, slice2: np.repeat( 59 | np.repeat(gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]], d[0], axis=0), d[1], axis=1) 60 | dot = lambda grad, shift: ( 61 | np.stack((grid[:shape[0], :shape[1], 0] + shift[0], grid[:shape[0], :shape[1], 1] + shift[1]), 62 | axis=-1) * grad[:shape[0], :shape[1]]).sum(axis=-1) 63 | 64 | n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]) 65 | n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]) 66 | n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]) 67 | n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]) 68 | t = fade(grid[:shape[0], :shape[1]]) 69 | return math.sqrt(2) * self.lerp_np(self.lerp_np(n00, n10, t[..., 0]), self.lerp_np(n01, n11, t[..., 0]), t[..., 1]) 70 | 71 | 72 | def get_patch_mask_and_intersect_fn(self, 73 | dim_bounds: List[Tuple[int, int]], 74 | img_dims: np.ndarray) \ 75 | -> Tuple[npt.NDArray[bool], Callable[[npt.NDArray[float], npt.NDArray[float]], npt.NDArray[float]]]: 76 | 77 | perlin_scalex = 2 ** (np.random.randint(low = self.min_perlin_scale, high = self.perlin_scale, size=(1,))[0]) 78 | perlin_scaley = 2 ** (np.random.randint(low = self.min_perlin_scale, high = self.perlin_scale, size=(1,))[0]) 79 | 80 | noise_size_x = np.random.randint(low=dim_bounds[0][0],high=dim_bounds[0][1]) 81 | noise_size_y = np.random.randint(low=dim_bounds[1][0],high=dim_bounds[1][1]) 82 | 83 | while True: 84 | perlin_noise = self.rand_perlin_2d_np((noise_size_x, noise_size_y), (perlin_scalex, perlin_scaley)) 85 | 86 | # apply affine transform 87 | rot = iaa.Affine(rotate=(-90, 90)) 88 | perlin_noise = rot(image=perlin_noise) 89 | 90 | # make a mask by applying threshold 91 | mask_noise = np.where( 92 | perlin_noise > self.perlin_noise_threshold, 93 | np.ones_like(perlin_noise), 94 | np.zeros_like(perlin_noise) 95 | ) 96 | mask_noise[mask_noise != 0] = 1.0 97 | 98 | if np.mean(mask_noise) >= self.perlin_min_size: 99 | break 100 | return mask_noise.astype(np.bool), None 101 | 102 | 103 | def get_patch_mask(self, dim_bounds: List[Tuple[int, int]], img_dims: np.ndarray) -> npt.NDArray[bool]: 104 | return self.get_patch_mask_and_intersect_fn(dim_bounds, img_dims)[0] 105 | 106 | 107 | 108 | class DeformedHypershapePatchMaker(PatchShapeMaker): 109 | 110 | def __init__(self, aligned_distance_to_edge_fn: Callable[[npt.NDArray[int], npt.NDArray[float], npt.NDArray[float]], 111 | npt.NDArray[float]], 
112 | sample_dist=lambda lb, ub, _: np.random.randint(lb, ub)): 113 | # Instead of mask function, need inclusion test function 114 | # takes shape + point as parameters 115 | self.aligned_distance_to_edge_fn = aligned_distance_to_edge_fn 116 | self.sample_dist = sample_dist 117 | self.rng = np.random.default_rng() 118 | 119 | @abstractmethod 120 | def within_aligned_shape(self, array_of_coords: npt.NDArray[float], shape_size: npt.NDArray[int]) \ 121 | -> npt.NDArray[bool]: 122 | """ 123 | Calculates whether a point is within an aligned shape with dimensions shape_size. 124 | :param array_of_coords: 125 | :param shape_size: 126 | """ 127 | 128 | def make_shape_intersect_function(self, 129 | inv_trans_matrix: npt.NDArray[float], 130 | trans_matrix: npt.NDArray[float], 131 | shape_size: npt.NDArray[int]) \ 132 | -> Callable[[npt.NDArray[float], npt.NDArray[float]], npt.NDArray[float]]: 133 | 134 | return lambda orig, direction: trans_matrix @ self.aligned_distance_to_edge_fn(shape_size, 135 | inv_trans_matrix @ orig, 136 | inv_trans_matrix @ direction) 137 | 138 | def get_patch_mask_and_intersect_fn(self, 139 | dim_bounds: List[Tuple[int, int]], 140 | img_dims: np.ndarray) \ 141 | -> Tuple[npt.NDArray[bool], Callable[[npt.NDArray[float], npt.NDArray[float]], npt.NDArray[float]]]: 142 | 143 | shape = np.array([self.sample_dist(lb, ub, d) for ((lb, ub), d) in zip(dim_bounds, img_dims)]) 144 | 145 | num_dims = len(dim_bounds) 146 | 147 | trans_matrix = np.identity(num_dims) 148 | 149 | # Instead of transforming mask, accumulate transformation matrix 150 | for d in range(num_dims): 151 | other_dim = self.rng.choice([i for i in range(num_dims) if i != d]) 152 | shear_factor = self.rng.normal(scale=0.2) 153 | trans_matrix[d, other_dim] = shear_factor 154 | 155 | # Rotate mask, using all possible access combinations 156 | trans_matrix = functools.reduce(lambda m, ds: accumulate_rotation(m, 157 | self.rng.uniform(-np.pi / 2, np.pi / 2), 158 | ds), 159 | itertools.combinations(range(num_dims), 2), 160 | trans_matrix) 161 | 162 | # Using corner points, calculate size of resulting shape 163 | shape_width = (shape - 1) / 2 164 | corner_coord_grid = np.array(np.meshgrid(*np.stack([shape_width, -shape_width], axis=-1), indexing='ij')) 165 | corner_coords = corner_coord_grid.reshape(num_dims, 2 ** num_dims) 166 | 167 | trans_corner_coords = trans_matrix @ corner_coords 168 | min_trans_coords = np.floor(np.min(trans_corner_coords, axis=1)) 169 | max_trans_coords = np.ceil(np.max(trans_corner_coords, axis=1)) 170 | final_grid_shape = max_trans_coords + 1 - min_trans_coords 171 | 172 | # Check if transformations have made patch too big 173 | ub_shape = np.array([ub for _, ub in dim_bounds]) 174 | if np.any(final_grid_shape > ub_shape): 175 | # If so, scale down to be within limits. 
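# --- Annotation (illustrative aside, not part of medsyn/task_shape.py) ------
# The bounding-box bookkeeping above (and repeated just below after rescaling)
# reduces to transforming the 2**D corner points of the axis-aligned patch and
# taking floor/ceil of the extremes. A 2-D worked sketch with a 45-degree
# rotation; the 10x4 patch size is an arbitrary example value.
import numpy as np

shape = np.array([10, 4])
shape_width = (shape - 1) / 2
corners = np.array(np.meshgrid(*np.stack([shape_width, -shape_width], axis=-1),
                               indexing='ij')).reshape(2, 4)
theta = np.pi / 4
rot = np.array([[np.cos(theta), -np.sin(theta)],
                [np.sin(theta),  np.cos(theta)]])
trans = rot @ corners
grid_shape = np.ceil(trans.max(axis=1)) + 1 - np.floor(trans.min(axis=1))
print(grid_shape)  # [11. 11.]: the rotated 10x4 patch needs roughly an 11x11 grid
# -----------------------------------------------------------------------------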
176 | max_scale_diff = np.max(final_grid_shape / ub_shape) 177 | trans_matrix = accumulate_scaling(trans_matrix, 1 / max_scale_diff) 178 | 179 | # Repeat calculations with new transformation matrix 180 | trans_corner_coords = trans_matrix @ corner_coords 181 | min_trans_coords = np.floor(np.min(trans_corner_coords, axis=1)) 182 | max_trans_coords = np.ceil(np.max(trans_corner_coords, axis=1)) 183 | final_grid_shape = max_trans_coords + 1 - min_trans_coords 184 | 185 | # Create meshgrid of coords of resulting shape 186 | coord_ranges = [np.arange(lb, ub + 1) for lb, ub in zip(min_trans_coords, max_trans_coords)] 187 | coord_grid = np.array(np.meshgrid(*coord_ranges, indexing='ij')) 188 | 189 | # Apply inverse transformation matrix, to compute sampling points 190 | inv_trans_matrix = np.linalg.inv(trans_matrix) 191 | inv_coord_grid_f = inv_trans_matrix @ np.reshape(coord_grid, (num_dims, -1)) 192 | 193 | # Apply inclusion test function, giving an array containing a boolean for each coordinate. 194 | inv_result_grid_f = self.within_aligned_shape(inv_coord_grid_f, shape) 195 | 196 | return np.reshape(inv_result_grid_f, final_grid_shape.astype(int)), \ 197 | self.make_shape_intersect_function(inv_trans_matrix, trans_matrix, shape) 198 | 199 | 200 | def get_patch_mask(self, dim_bounds: List[Tuple[int, int]], img_dims: np.ndarray) -> npt.NDArray[bool]: 201 | return self.get_patch_mask_and_intersect_fn(dim_bounds, img_dims)[0] 202 | 203 | 204 | 205 | 206 | 207 | def intersect_to_aligned_hyperrectangle_edge(hyperrectangle_shape: npt.NDArray[int], 208 | origin: npt.NDArray[float], 209 | direction: npt.NDArray[float]) \ 210 | -> npt.NDArray[float]: 211 | """ 212 | 213 | :param hyperrectangle_shape: Numpy array of hyperrectangle shape, (D) 214 | :param origin: Numpy array of origin coordinates (D, N) 215 | :param direction: Numpy array of normalised direction vectors (D, N) 216 | :return: 217 | """ 218 | rect_width = np.reshape(hyperrectangle_shape / 2, (-1, 1)) 219 | # Normalise direction magnitudes 220 | direction = direction / np.linalg.norm(direction, axis=0) 221 | 222 | max_dir = rect_width - origin 223 | min_dir = - rect_width - origin 224 | 225 | # Shape (2D, N) 226 | all_edge_distances = np.concatenate([max_dir / direction, 227 | min_dir / direction]) 228 | 229 | # Set any behind distances to inf, so we don't choose them 230 | all_edge_distances[all_edge_distances < 0] = np.inf 231 | 232 | # Shape (N,) 233 | min_distances = np.min(all_edge_distances, axis=0) 234 | 235 | min_is_inf = min_distances == np.inf 236 | assert not np.any(min_is_inf), f'Some lines are outside bounding box:\n' \ 237 | f'rectangle shape: {hyperrectangle_shape},' \ 238 | f'origins - {origin[min_is_inf]},\n' \ 239 | f'directions - {direction[min_is_inf]}' 240 | 241 | return origin + direction * min_distances 242 | 243 | 244 | 245 | class DeformedHyperrectanglePatchMaker(DeformedHypershapePatchMaker): 246 | 247 | def __init__(self, sample_dist=lambda lb, ub, _: np.random.randint(lb, ub)): 248 | super().__init__(intersect_to_aligned_hyperrectangle_edge, sample_dist) 249 | 250 | def within_aligned_shape(self, array_of_coords: npt.NDArray[float], shape_size: npt.NDArray[int]) \ 251 | -> npt.NDArray[bool]: 252 | rect_width = np.reshape(shape_size / 2, (-1, 1)) 253 | return np.all(-rect_width <= array_of_coords, axis=0) & np.all(array_of_coords <= rect_width, axis=0) 254 | 255 | 256 | 257 | def intersect_aligned_hyperellipse_edge(hyperellipse_shape: npt.NDArray[int], 258 | origin: npt.NDArray[float], 259 | direction: 
npt.NDArray[float]) \ 260 | -> npt.NDArray[float]: 261 | ellipse_radii_sq = np.reshape((hyperellipse_shape / 2) ** 2, (-1, 1)) 262 | # Normalise direction magnitudes 263 | direction = direction / np.linalg.norm(direction, axis=0) 264 | 265 | # Compute quadratic coefficients, all shape (N) 266 | a = np.sum(direction ** 2 / ellipse_radii_sq, axis=0) 267 | b = np.sum(2 * origin * direction / ellipse_radii_sq, axis=0) 268 | c = np.sum(origin ** 2 / ellipse_radii_sq, axis=0) - 1 269 | 270 | # Solve quadratic, (N) 271 | det = b ** 2 - 4 * a * c 272 | 273 | det_is_negative = det < 0 274 | assert not np.any(det_is_negative), f'Some lines never intersect ellipse:\n' \ 275 | f'Ellipse shape: {hyperellipse_shape}\n' \ 276 | f'origins: {origin[det_is_negative]}' \ 277 | f'directions: {direction[det_is_negative]}' 278 | 279 | solutions = (-b + np.array([[1], [-1]]) * np.sqrt(det)) / (2 * a) 280 | 281 | # Make any negative solutions (behind origin) infinity so we don't choose them 282 | solutions[solutions < 0] = np.inf 283 | 284 | min_solutions = np.min(solutions, axis=0) 285 | min_is_inf = min_solutions == np.inf 286 | assert not np.any(min_is_inf), f'Some lines are outside ellipse:\n' \ 287 | f'ellipse shape: {hyperellipse_shape},' \ 288 | f'origins - {origin[min_is_inf]},\n' \ 289 | f'directions - {direction[min_is_inf]}' 290 | return origin + direction * min_solutions 291 | 292 | 293 | class DeformedHyperellipsePatchMaker(DeformedHypershapePatchMaker): 294 | 295 | def __init__(self, sample_dist=lambda lb, ub, _: np.random.randint(lb, ub)): 296 | super().__init__(intersect_aligned_hyperellipse_edge, sample_dist) 297 | 298 | def within_aligned_shape(self, array_of_coords: npt.NDArray[float], shape_size: npt.NDArray[int]) \ 299 | -> npt.NDArray[bool]: 300 | ellipse_radii = np.reshape(shape_size / 2, (-1, 1)) 301 | return np.sum(array_of_coords ** 2 / ellipse_radii ** 2, axis=0) <= 1 302 | 303 | 304 | class CombinedDeformedHypershapePatchMaker(PatchShapeMaker): 305 | 306 | def __init__(self, sample_dist=lambda lb, ub, _: np.random.randint(lb, ub)): 307 | 308 | self.rect_maker = DeformedHyperrectanglePatchMaker(sample_dist) 309 | self.ellip_maker = DeformedHyperellipsePatchMaker(sample_dist) 310 | self.last_choice = None 311 | 312 | def get_patch_mask(self, dim_bounds: List[Tuple[int, int]], img_dims: np.ndarray) -> npt.NDArray[bool]: 313 | 314 | mode = np.random.choice(['rect', 'ellipse', 'comb']) 315 | self.last_choice = mode 316 | 317 | if mode == 'rect': 318 | return self.rect_maker(dim_bounds, img_dims) 319 | 320 | elif mode == 'ellipse': 321 | return self.ellip_maker(dim_bounds, img_dims) 322 | 323 | elif mode == 'comb': 324 | 325 | rect_mask = self.rect_maker(dim_bounds, img_dims) 326 | ellip_mask = self.ellip_maker(dim_bounds, img_dims) 327 | 328 | rect_size = np.sum(rect_mask) 329 | ellip_size = np.sum(ellip_mask) 330 | 331 | big_m, small_m = (rect_mask, ellip_mask) if rect_size >= ellip_size else (ellip_mask, rect_mask) 332 | 333 | # Choose point on big mask to put small mask 334 | mask_coords = np.nonzero(big_m) 335 | 336 | point_ind = self.rect_maker.rng.integers(len(mask_coords[0])) 337 | 338 | chosen_coord = np.array([m_cs[point_ind] for m_cs in mask_coords]) 339 | 340 | small_shape = np.array(small_m.shape) 341 | lower_coord = chosen_coord - small_shape // 2 342 | 343 | if np.any(lower_coord < 0): 344 | to_pad_below = np.maximum(-lower_coord, 0) 345 | big_m = np.pad(big_m, [(p, 0) for p in to_pad_below]) 346 | lower_coord += to_pad_below 347 | 348 | big_shape = np.array(big_m.shape) 349 | 
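# --- Annotation (illustrative aside, not part of medsyn/task_shape.py) ------
# The 'comb' branch pastes the smaller mask into the larger one (padding the
# canvas when it sticks out) and returns the convex hull of the union. A tiny
# self-contained sketch of the same idea; the mask contents and paste location
# are arbitrary, and the "paste point must lie on the big mask" and pad-below
# cases handled above are omitted for brevity.
import numpy as np
from skimage.morphology import convex_hull_image

big = np.zeros((7, 7), dtype=bool)
big[2:5, 2:5] = True                                  # 3x3 square
small = np.zeros((3, 3), dtype=bool)
small[1, :] = True                                    # horizontal bar
lower = np.array([0, 5])                              # paste near the right edge
upper = lower + np.array(small.shape)
pad_above = np.maximum(upper - np.array(big.shape), 0)
canvas = np.pad(big, [(0, p) for p in pad_above])     # grow canvas where needed
canvas[tuple(slice(l, u) for l, u in zip(lower, upper))] |= small
combined = convex_hull_image(canvas)                  # final combined patch mask
# -----------------------------------------------------------------------------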
350 | upper_coord = lower_coord + small_shape 351 | if np.any(upper_coord > big_shape): 352 | to_pad_above = np.maximum(upper_coord - big_shape, 0) 353 | big_m = np.pad(big_m, [(0, p) for p in to_pad_above]) 354 | 355 | big_m[tuple([slice(lb, ub) for lb, ub in zip(lower_coord, upper_coord)])] |= small_m 356 | 357 | return convex_hull_image(big_m) 358 | 359 | else: 360 | raise Exception('Invalid mask option') 361 | 362 | 363 | class EitherDeformedHypershapePatchMaker(PatchShapeMaker): 364 | 365 | def __init__(self, 366 | sample_dist=lambda lb, ub, _: np.random.randint(lb, ub)): 367 | 368 | self.rect_maker = DeformedHyperrectanglePatchMaker(sample_dist) 369 | self.ellip_maker = DeformedHyperellipsePatchMaker(sample_dist) 370 | 371 | def get_patch_mask_and_intersect_fn(self, 372 | dim_bounds: List[Tuple[int, int]], 373 | img_dims: np.ndarray) \ 374 | -> Tuple[npt.NDArray[bool], Callable[[npt.NDArray[float], npt.NDArray[float]], npt.NDArray[float]]]: 375 | 376 | chosen_task = np.random.choice([self.rect_maker, self.ellip_maker]) 377 | return chosen_task.get_patch_mask_and_intersect_fn(dim_bounds, img_dims) 378 | 379 | 380 | def get_patch_mask(self, dim_bounds: List[Tuple[int, int]], img_dims: np.ndarray) -> npt.NDArray[bool]: 381 | return self.get_patch_mask_and_intersect_fn(dim_bounds, img_dims)[0] 382 | -------------------------------------------------------------------------------- /medsyn/tasks.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Callable, Optional, Tuple, Union 3 | import functools 4 | import itertools 5 | import numpy as np 6 | import numpy.typing as npt 7 | from scipy.ndimage import affine_transform,distance_transform_edt 8 | from numpy.linalg import norm 9 | import random 10 | from scipy import ndimage 11 | from scipy.ndimage import gaussian_filter 12 | 13 | from .labelling import FlippedGaussianLabeller,AnomalyLabeller 14 | from .task_shape import EitherDeformedHypershapePatchMaker,PerlinPatchMaker 15 | from .utils import * 16 | 17 | def cut_paste(sample: npt.NDArray[float], 18 | source_to_blend: npt.NDArray[float], 19 | anomaly_corner: npt.NDArray[int], 20 | anomaly_mask: npt.NDArray[bool]) -> npt.NDArray[float]: 21 | 22 | repeated_mask = np.broadcast_to(anomaly_mask, source_to_blend.shape) 23 | 24 | sample_slices = get_patch_image_slices(anomaly_corner, tuple(anomaly_mask.shape)) 25 | 26 | aug_sample = sample.copy() 27 | aug_sample[sample_slices][repeated_mask] = source_to_blend[repeated_mask] 28 | 29 | return aug_sample 30 | 31 | 32 | 33 | class BaseTask(ABC): 34 | def __init__(self, 35 | sample_labeller: Optional[AnomalyLabeller] = None, 36 | **all_kwargs): 37 | 38 | self.sample_labeller = sample_labeller 39 | self.rng = np.random.default_rng() 40 | 41 | self.min_anom_prop=0.3 42 | self.max_anom_prop=0.8 43 | 44 | self.anomaly_shape_maker = EitherDeformedHypershapePatchMaker(nsa_sample_dimension) 45 | self.all_kwargs = all_kwargs 46 | 47 | 48 | def apply(self, 49 | sample: npt.NDArray[float], 50 | *args, **kwargs)\ 51 | -> Tuple[npt.NDArray[float], npt.NDArray[float]]: 52 | """ 53 | Apply the self-supervised task to the single data sample. 54 | :param sample: Normal sample to be augmented 55 | :param sample_mask: Object mask of sample. 56 | :return: sample with task applied and label map. 
57 | """ 58 | 59 | aug_sample = sample.copy() 60 | 61 | sample_shape = np.array(sample.shape[1:]) 62 | anomaly_mask = np.zeros(sample_shape, dtype=bool) 63 | 64 | min_anom_prop = self.min_anom_prop 65 | max_anom_prop = self.max_anom_prop 66 | 67 | min_dim_lens = (min_anom_prop * sample_shape).round().astype(int) 68 | max_dim_lens = (max_anom_prop * sample_shape).round().astype(int) 69 | # print(min_dim_lens,max_dim_lens) [15,15],[205,205] 70 | 71 | dim_bounds = list(zip(min_dim_lens, max_dim_lens)) #[(15, 205), (15, 205)] 72 | 73 | # For random number of times 74 | sample_mask = None 75 | 76 | for i in range(2): 77 | 78 | # Compute anomaly mask 79 | curr_anomaly_mask, intersect_fn = self.anomaly_shape_maker.get_patch_mask_and_intersect_fn(dim_bounds, 80 | sample_shape) 81 | 82 | # Choose anomaly location 83 | anomaly_corner = self.find_valid_anomaly_location(curr_anomaly_mask, sample_mask, sample_shape) 84 | 85 | # Apply self-supervised task 86 | 87 | aug_sample = self.augment_sample(aug_sample, sample_mask, anomaly_corner, curr_anomaly_mask, intersect_fn) 88 | 89 | anomaly_mask[get_patch_slices(anomaly_corner, curr_anomaly_mask.shape)] |= curr_anomaly_mask 90 | 91 | # Randomly brake at end of loop, ensuring we get at least 1 anomaly 92 | if self.rng.random() > 0.5: 93 | break 94 | 95 | if self.sample_labeller is not None: 96 | return aug_sample, self.sample_labeller(aug_sample, sample, anomaly_mask) 97 | else: 98 | # If no labeller is provided, we are probably in a calibration process 99 | return aug_sample, np.expand_dims(anomaly_mask, 0) 100 | 101 | 102 | def find_valid_anomaly_location(self, 103 | curr_anomaly_mask: npt.NDArray[bool], 104 | sample_mask: Optional[npt.NDArray[bool]], 105 | sample_shape: npt.NDArray[int]): 106 | 107 | curr_anomaly_shape = np.array(curr_anomaly_mask.shape) 108 | min_corner = np.zeros(len(sample_shape)) 109 | max_corner = sample_shape - curr_anomaly_shape 110 | 111 | # - Apply anomaly at location 112 | while True: 113 | anomaly_corner = self.rng.integers(min_corner, max_corner, endpoint=True) 114 | 115 | # If the sample mask is None, any location within the bounds is valid 116 | if sample_mask is None: 117 | break 118 | # Otherwise, we need to check that the intersection of the anomaly mask and the sample mask is at least 50% 119 | target_patch_obj_mask = sample_mask[get_patch_slices(anomaly_corner, curr_anomaly_mask.shape)] 120 | if (np.sum(target_patch_obj_mask & curr_anomaly_mask) / np.sum(curr_anomaly_mask)) >= 0.5: 121 | break 122 | 123 | return anomaly_corner 124 | 125 | 126 | def __call__(self, 127 | sample: npt.NDArray[float], 128 | *args, 129 | **kwargs)\ 130 | -> Tuple[npt.NDArray[float], npt.NDArray[float]]: 131 | """ 132 | Apply the self-supervised task to the single data sample. 133 | :param sample: Normal sample to be augmented 134 | :param sample_mask: Object mask of sample. 135 | :param **kwargs: 136 | * *sample_path*: Path to source image 137 | :return: sample with task applied and label map. 
138 | """ 139 | if len(sample.shape)==2: 140 | sample = np.expand_dims(sample,axis=0) 141 | 142 | aug_sample, aug_mask = self.apply(sample, *args, **kwargs) 143 | 144 | if len(aug_sample.shape)==3 and aug_sample.shape[0]==1: 145 | aug_sample = aug_sample.squeeze(0) 146 | 147 | if len(aug_mask.shape)==3 and aug_mask.shape[0]==1: 148 | aug_mask = aug_mask.squeeze(0) 149 | 150 | return aug_sample,aug_mask.astype(np.float) 151 | 152 | 153 | @abstractmethod 154 | def augment_sample(self, sample: npt.NDArray[float], sample_mask: Optional[npt.NDArray[bool]], 155 | anomaly_corner: npt.NDArray[int], anomaly_mask: npt.NDArray[bool], 156 | anomaly_intersect_fn: Callable[[npt.NDArray[float], npt.NDArray[float]], npt.NDArray[float]]) \ 157 | -> npt.NDArray[float]: 158 | """ 159 | Apply self-supervised task to region at anomaly_corner covered by anomaly_mask 160 | :param sample: Sample to be augmented. 161 | :param sample_mask: Object mask of sample. 162 | :param anomaly_corner: Index of anomaly corner. 163 | :param anomaly_mask: Mask 164 | :param anomaly_intersect_fn: Function which, given a line's origin and direction, finds its intersection with 165 | the edge of the anomaly mask 166 | :return: 167 | """ 168 | 169 | 170 | class BasePatchBlendingTask(BaseTask): 171 | 172 | def __init__(self, 173 | sample_labeller: Optional[AnomalyLabeller], 174 | source_samples: list, 175 | blend_images: Callable[[npt.NDArray[float], npt.NDArray[float], npt.NDArray[int], npt.NDArray[bool]], 176 | npt.NDArray[float]], 177 | 178 | **all_kwargs): 179 | super().__init__(sample_labeller, **all_kwargs) 180 | self.source_samples = source_samples 181 | self.blend_images = blend_images 182 | 183 | 184 | def augment_sample(self, 185 | sample: npt.NDArray[float], # aug sample 186 | sample_mask: Optional[npt.NDArray[bool]], # None 187 | anomaly_corner: npt.NDArray[int], # center 188 | anomaly_mask: npt.NDArray[bool], # small anomaly mask 189 | anomaly_intersect_fn: Callable[[npt.NDArray[float], npt.NDArray[float]], npt.NDArray[float]]) \ 190 | -> npt.NDArray[float]: 191 | 192 | num_channels = sample.shape[0] # 1 193 | num_dims = len(sample.shape[1:]) #2 194 | 195 | # Sample source to blend into current sample 196 | source_sample = random.choice(self.source_samples) 197 | 198 | source_sample_shape = np.array(source_sample.shape[1:]) #(256,256) 199 | 200 | 201 | assert len(source_sample_shape) == num_dims, 'Source and target have different number of spatial dimensions: ' \ 202 | f's-{len(source_sample_shape)}, t-{num_dims}' 203 | 204 | assert source_sample.shape[0] == num_channels, \ 205 | f'Source and target have different number of channels: s-{source_sample.shape[0]}, t-{num_channels}' 206 | 207 | # Compute INVERSE transformation matrix for parameters (rotation, resizing) 208 | # This is the backwards operation (final source region -> initial source region). 
209 | 210 | trans_matrix = functools.reduce(lambda m, ds: accumulate_rotation(m, 211 | self.rng.uniform(-np.pi / 4, np.pi / 4), 212 | ds), 213 | itertools.combinations(range(num_dims), 2), 214 | np.identity(num_dims)) 215 | 216 | # Compute effect on corner coords 217 | target_anomaly_shape = np.array(anomaly_mask.shape) 218 | 219 | corner_coords = np.array(np.meshgrid(*np.stack([np.zeros(num_dims), target_anomaly_shape], axis=-1), 220 | indexing='ij')).reshape(num_dims, 2 ** num_dims) 221 | 222 | trans_corner_coords = trans_matrix @ corner_coords 223 | min_trans_coords = np.floor(np.min(trans_corner_coords, axis=1)) 224 | max_trans_coords = np.ceil(np.max(trans_corner_coords, axis=1)) 225 | init_grid_shape = max_trans_coords - min_trans_coords 226 | 227 | # Sample scale and clip so that source region isn't too big 228 | max_scale = np.min(0.8 * source_sample_shape / init_grid_shape) 229 | 230 | # Compute final transformation matrix 231 | scale_change = 1 + self.rng.exponential(scale=0.1) 232 | scale_raw = self.rng.choice([scale_change, 1 / scale_change]) 233 | scale = np.minimum(scale_raw, max_scale) 234 | 235 | trans_matrix = accumulate_scaling(trans_matrix, scale) 236 | 237 | # Recompute effect on corner coord 238 | trans_corner_coords = trans_matrix @ corner_coords 239 | min_trans_coords = np.floor(np.min(trans_corner_coords, axis=1)) 240 | max_trans_coords = np.ceil(np.max(trans_corner_coords, axis=1)) 241 | final_init_grid_shape = max_trans_coords - min_trans_coords 242 | 243 | # Choose anomaly source location 244 | final_init_grid_shape = final_init_grid_shape.astype(int) 245 | min_corner = np.zeros(len(source_sample_shape)) 246 | max_corner = source_sample_shape - final_init_grid_shape 247 | 248 | source_corner = self.rng.integers(min_corner, max_corner, endpoint=True) 249 | 250 | # Extract source 251 | source_orig = source_sample[get_patch_image_slices(source_corner, tuple(final_init_grid_shape))] 252 | 253 | 254 | # Because we computed the backwards transformation we don't need to inverse the matrix 255 | source_to_blend = np.stack([affine_transform(chan, trans_matrix, offset=-min_trans_coords, 256 | output_shape=tuple(target_anomaly_shape)) 257 | for chan in source_orig]) 258 | 259 | spatial_axis = tuple(range(1, len(source_sample.shape))) 260 | # Spline interpolation can make values fall outside domain, so clip to the original range 261 | source_to_blend = np.clip(source_to_blend, 262 | source_sample.min(axis=spatial_axis, keepdims=True), 263 | source_sample.max(axis=spatial_axis, keepdims=True)) 264 | 265 | 266 | # As the blending can alter areas outside the mask, update the mask with any effected areas 267 | 268 | aug_sample = self.blend_images(sample, source_to_blend, anomaly_corner, anomaly_mask) 269 | 270 | sample_slices = get_patch_image_slices(anomaly_corner, tuple(anomaly_mask.shape)) 271 | sample_diff = np.mean(np.abs(sample[sample_slices] - aug_sample[sample_slices]), axis=0) 272 | 273 | anomaly_mask[sample_diff > 0.001] = True 274 | # Return sample with source blended into it 275 | return aug_sample 276 | 277 | 278 | 279 | class BaseDeformationTask(BaseTask): 280 | 281 | @abstractmethod 282 | def compute_mapping(self, 283 | sample: npt.NDArray[float], 284 | sample_mask: Optional[npt.NDArray[bool]], 285 | anomaly_corner: npt.NDArray[int], anomaly_mask: npt.NDArray[bool], 286 | anomaly_intersect_fn: Callable[[npt.NDArray[float], npt.NDArray[float]], npt.NDArray[float]]) \ 287 | -> npt.NDArray[float]: 288 | """ 289 | Returns array of size (*anomaly_mask.shape, 
len(anomaly_mask.shape)). 290 | Probably don't need entire sample, but including in for generality. 291 | :param sample: 292 | :param sample_mask: 293 | :param anomaly_corner: 294 | :param anomaly_mask: 295 | :param anomaly_intersect_fn: 296 | :return: 297 | """ 298 | 299 | def augment_sample(self, 300 | sample: npt.NDArray[float], 301 | sample_mask: Optional[npt.NDArray[bool]], 302 | anomaly_corner: npt.NDArray[int], 303 | anomaly_mask: npt.NDArray[bool], 304 | anomaly_intersect_fn: Callable[[npt.NDArray[float], npt.NDArray[float]], npt.NDArray[float]]) \ 305 | -> npt.NDArray[float]: 306 | 307 | num_channels = sample.shape[0] 308 | mapping = self.compute_mapping(sample, sample_mask, anomaly_corner, anomaly_mask, anomaly_intersect_fn) 309 | sample_slices = get_patch_slices(anomaly_corner, tuple(anomaly_mask.shape)) 310 | 311 | for chan in range(num_channels): 312 | sample[chan][sample_slices] = ndimage.map_coordinates(sample[chan][sample_slices], 313 | mapping, 314 | mode='nearest') 315 | return sample 316 | 317 | 318 | 319 | class RadialDeformationTask(BaseDeformationTask): 320 | 321 | def __init__(self, 322 | sample_labeller: Optional[AnomalyLabeller] = None, 323 | deform_factor: Optional[float] = None, 324 | deform_centre: Optional[npt.NDArray] = None, **kwargs): 325 | 326 | super().__init__(sample_labeller, **kwargs) 327 | self.deform_factor = deform_factor 328 | self.deform_centre = deform_centre 329 | self.max_anom_prop = 0.6 330 | self.min_anom_prop = 0.2 331 | 332 | def get_deform_factor(self, def_centre: npt.NDArray[int], anomaly_mask: npt.NDArray[bool]): 333 | return self.deform_factor if self.deform_factor is not None else 2 ** self.rng.uniform(0.5, 2) 334 | 335 | @abstractmethod 336 | def compute_new_distance(self, curr_distance: float, max_distance: float, factor: float) -> float: 337 | """ 338 | Compute new distance for point to be sampled from 339 | :param curr_distance: 340 | :param max_distance: 341 | :param factor: 342 | """ 343 | 344 | def compute_mapping(self, 345 | sample: npt.NDArray[float], 346 | sample_mask: Optional[npt.NDArray[bool]], 347 | anomaly_corner: npt.NDArray[int], 348 | anomaly_mask: npt.NDArray[bool], 349 | anomaly_intersect_fn: Callable[[npt.NDArray[float], npt.NDArray[float]], npt.NDArray[float]]) \ 350 | -> npt.NDArray[float]: 351 | # NOTE: This assumes that the shape is convex, will make discontinuities if it's not. 
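# --- Annotation (illustrative aside, not part of medsyn/tasks.py) -----------
# The mapping built below is consumed by ndimage.map_coordinates in
# augment_sample: output pixel (i, j) is filled with the patch value found at
# mapping[:, i, j]. A minimal radial "pinch" on a small ramp patch, using an
# ad-hoc 1.3x distance rescaling rather than the repo's compute_new_distance
# subclasses:
import numpy as np
from scipy import ndimage

patch = np.arange(49, dtype=float).reshape(7, 7)
centre = np.array([3.0, 3.0]).reshape(2, 1, 1)
coords = np.stack(np.meshgrid(np.arange(7.0), np.arange(7.0), indexing='ij'))
mapping = centre + 1.3 * (coords - centre)   # sample from further out => pull inward
warped = ndimage.map_coordinates(patch, mapping, mode='nearest')
print(warped.shape)                          # (7, 7): same grid, deformed contents
# -----------------------------------------------------------------------------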

        anomaly_shape = np.array(anomaly_mask.shape)
        num_dims = len(anomaly_shape)

        # Expand so it can later be broadcast with (D, N)
        mask_centre = (anomaly_shape - 1) / 2
        exp_mask_centre = np.reshape(mask_centre, (-1, 1))
        # Shape (D, N)
        poss_centre_coords = np.stack(np.nonzero(anomaly_mask))
        def_centre = self.deform_centre if self.deform_centre is not None else \
            poss_centre_coords[:, np.random.randint(poss_centre_coords.shape[1])]

        assert anomaly_mask[tuple(def_centre.round().astype(int))], f'Centre is not within anomaly: {def_centre}'

        # exp_ = expanded
        exp_def_centre = np.reshape(def_centre, (-1, 1))

        # (D, *anomaly_shape)
        mapping = np.stack(np.meshgrid(*[np.arange(s, dtype=float) for s in anomaly_shape], indexing='ij'), axis=0)

        # Ignore pixels on the edge of the bounding box
        mask_inner_slice = tuple([slice(1, -1)] * num_dims)
        map_inner_slice = tuple([slice(None)] + list(mask_inner_slice))
        # Get all coords and transpose so the coord index is the last dimension (D, N)
        anomaly_coords = mapping[map_inner_slice][(slice(None), anomaly_mask[mask_inner_slice])]

        all_coords_to_centre = anomaly_coords - exp_def_centre
        all_coords_distance = norm(all_coords_to_centre, axis=0)
        # Ignore zero divided by zero, as we correct it before the mapping is returned
        with np.errstate(invalid='ignore'):
            all_coords_norm_dirs = all_coords_to_centre / all_coords_distance

        mask_edge_intersections = anomaly_intersect_fn(exp_def_centre - exp_mask_centre, all_coords_norm_dirs) + exp_mask_centre

        mask_edge_distances = norm(mask_edge_intersections - exp_def_centre, axis=0)

        # Get the factor once, so it is the same for all pixels
        def_factor = self.get_deform_factor(def_centre, anomaly_mask)
        new_coord_distances = self.compute_new_distance(all_coords_distance, mask_edge_distances, def_factor)
        # (D, N)
        new_coords = exp_def_centre + new_coord_distances * all_coords_norm_dirs

        mapping[map_inner_slice][(slice(None), anomaly_mask[mask_inner_slice])] = new_coords

        # Revert the centre coordinate, as it will be nan due to its zero-magnitude direction vector
        mapping[(slice(None), *def_centre)] = def_centre
        return mapping


class CutPastePatchBlender(BasePatchBlendingTask):

    def __init__(self,
                 source_images: list,
                 labeller_std: float = 0.2,
                 **kwargs):
        sample_labeller = FlippedGaussianLabeller(labeller_std)
        source_images = [np.expand_dims(image, axis=0) if len(image.shape) == 2 else image
                         for image in source_images]
        super().__init__(sample_labeller, source_images, cut_paste)
        self.max_anom_prop = 0.6
        self.min_anom_prop = 0.1
        self.anomaly_shape_maker = PerlinPatchMaker()


class SmoothIntensityChangeTask(BaseTask):

    def __init__(self,
                 intensity_task_scale: float,
                 sample_labeller: Optional[AnomalyLabeller] = None,
                 **all_kwargs):

        super().__init__(sample_labeller, **all_kwargs)
        self.intensity_task_scale = intensity_task_scale
        self.max_anom_prop = 0.8
        self.min_anom_prop = 0.3
        self.anomaly_shape_maker = PerlinPatchMaker()

    def augment_sample(self,
                       sample: npt.NDArray[float],  # aug_sample
                       sample_mask: Optional[npt.NDArray[bool]],  # None
                       anomaly_corner: npt.NDArray[int],  # anomaly corner
                       anomaly_mask: npt.NDArray[bool],  # small anomaly mask
                       anomaly_intersect_fn: Callable[[npt.NDArray[float], npt.NDArray[float]], npt.NDArray[float]]) \
            -> npt.NDArray[float]:

        num_chans = sample.shape[0]  # 1
        sample_shape = sample.shape[1:]  # (256, 256)
        num_dims = len(sample_shape)  # 2

        dist_map = distance_transform_edt(anomaly_mask)
        min_shape_dim = np.min(sample_shape)  # 256

        smooth_dist = np.minimum(min_shape_dim * (0.02 + np.random.gamma(3, 0.01)), np.max(dist_map))
        smooth_dist_map = dist_map / smooth_dist
        smooth_dist_map[smooth_dist_map > 1] = 1
        # smooth_dist_map = 1

        anomaly_patch_slices = get_patch_image_slices(anomaly_corner, anomaly_mask.shape)

        # anomaly_pixel_stds = np.array([np.std(c[anomaly_mask]) for c in sample[anomaly_patch_slices]])
        # Randomly negate, so some intensity changes are subtractions
        intensity_changes = (self.intensity_task_scale / 2 + np.random.gamma(3, self.intensity_task_scale)) \
            * np.random.choice([1, -1], size=num_chans)

        intensity_change_map = smooth_dist_map * np.reshape(intensity_changes, [-1] + [1] * num_dims)

        new_patch = sample[anomaly_patch_slices] + intensity_change_map

        spatial_axis = tuple(range(1, len(sample.shape)))

        sample[anomaly_patch_slices] = np.clip(new_patch,
                                               sample.min(axis=spatial_axis, keepdims=True),
                                               sample.max(axis=spatial_axis, keepdims=True))

        return sample


class GaussIntensityChangeTask(BaseTask):

    def __init__(self,
                 sample_labeller: Optional[AnomalyLabeller] = None,
                 **all_kwargs):

        super().__init__(sample_labeller, **all_kwargs)
        self.max_anom_prop = 0.8
        self.min_anom_prop = 0.3
        self.sigma_bs = [4, 7]
        self.positive_range = [0.4, 0.6]
        self.negative_range = [-0.6, -0.4]
        self.anomaly_shape_maker = PerlinPatchMaker()

    def get_predefined_texture(self,
                               mask_shape,
                               sigma_b,
                               positive_range=None,
                               negative_range=None,
                               ):

        assert (positive_range is not None) or (negative_range is not None)

        random_sample = np.random.randn(mask_shape[0], mask_shape[1])

        random_sample = (random_sample >= 0.0).astype(float)  # cast to float, as an int array can't be Gaussian filtered properly

        random_sample = gaussian_filter(random_sample, sigma_b)

        random_sample = (random_sample - np.min(random_sample)) / (np.max(random_sample) - np.min(random_sample))

        if np.random.uniform(0, 1) <= 0.5:
            u_0 = np.random.uniform(positive_range[0], positive_range[1])
        else:
            if negative_range is not None:
                u_0 = np.random.uniform(negative_range[0], negative_range[1])
            else:
                u_0 = np.random.uniform(-positive_range[1], -positive_range[0])

        Bj = np.clip(u_0 * random_sample, -1, 1)
        return Bj

    def create_texture(self, sizes):
        texture = self.get_predefined_texture(sizes,
                                              random.choice(self.sigma_bs),
                                              self.positive_range,
                                              self.negative_range)
        return texture

    def augment_sample(self,
                       sample: npt.NDArray[float],  # aug_sample
                       sample_mask: Optional[npt.NDArray[bool]],  # None
                       anomaly_corner: npt.NDArray[int],  # anomaly corner
                       anomaly_mask: npt.NDArray[bool],  # small anomaly mask
                       anomaly_intersect_fn: Callable[[npt.NDArray[float], npt.NDArray[float]], npt.NDArray[float]]) \
            -> npt.NDArray[float]:

        anomaly_mask_copy = anomaly_mask.astype(float)
        anomaly_patch_slices = get_patch_image_slices(anomaly_corner, anomaly_mask_copy.shape)

        texture = self.create_texture(sample.shape[1:])

        while True:
            if len(texture.shape)
                -> npt.NDArray[float]:
        anomaly_mask[:,:] = False
        return sample


class SinkDeformationTask(RadialDeformationTask):
    # y = 1 - (1 - x)^3 (between 0 and 1)
    # -> y = max_d * (1 - (1 - curr / max_d) ^ factor)
    # -> y = max_d - (max_d - curr) ^ factor / max_d ^ (factor - 1)

    def compute_new_distance(self, curr_distance: Union[float, npt.NDArray[float]],
                             max_distance: Union[float, npt.NDArray[float]],
                             factor: Union[float, npt.NDArray[float]]) -> Union[float, npt.NDArray[float]]:

        return max_distance - (max_distance - curr_distance) ** factor / max_distance ** (factor - 1)


class SourceDeformationTask(RadialDeformationTask):

    def compute_new_distance(self, curr_distance: Union[float, npt.NDArray[float]],
                             max_distance: Union[float, npt.NDArray[float]],
                             factor: Union[float, npt.NDArray[float]]) -> Union[float, npt.NDArray[float]]:
        # y = x^3 (between 0 and 1)
        # -> y = max_d * (curr / max_d) ^ factor
        #      = curr ^ factor / max_d ^ (factor - 1), written this way to avoid FP errors
        return curr_distance ** factor / max_distance ** (factor - 1)
--------------------------------------------------------------------------------
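A minimal usage sketch for the intensity task defined above, driving SmoothIntensityChangeTask.augment_sample directly. It assumes these classes are importable from medsyn/tasks.py (as the repository layout suggests) and that BaseTask.__init__ only needs the optional labeller forwarded by the subclass constructors; the intensity scale of 0.3, the 64x64 patch, and the circular mask are arbitrary illustrative choices, not values taken from the repository configs.

import numpy as np

from medsyn.tasks import SmoothIntensityChangeTask

rng = np.random.default_rng(0)
normal_image = rng.random((1, 256, 256))    # (C, H, W) image scaled to [0, 1]

# Hypothetical scale value for illustration only
task = SmoothIntensityChangeTask(intensity_task_scale=0.3)

anomaly_corner = np.array([96, 96])         # top-left corner of a 64x64 anomaly patch
yy, xx = np.ogrid[:64, :64]
anomaly_mask = (yy - 32) ** 2 + (xx - 32) ** 2 < 24 ** 2    # circular boolean mask inside the patch

# anomaly_intersect_fn is not used by this task's augment_sample, so a dummy callable is enough
augmented = task.augment_sample(normal_image.copy(), None, anomaly_corner, anomaly_mask,
                                anomaly_intersect_fn=lambda centre, dirs: dirs)

The deformation tasks above share the same augment_sample signature but do rely on anomaly_intersect_fn (to find where a ray from the deformation centre meets the mask edge), so a real intersection function would be required for them.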