├── medsyn
│   ├── __init__.py
│   ├── utils.py
│   ├── labelling.py
│   ├── task_shape.py
│   └── tasks.py
├── models
│   ├── __init__.py
│   ├── Adapter.py
│   ├── MapMaker.py
│   ├── Necker.py
│   └── CoOp.py
├── utils
│   ├── __init__.py
│   ├── misc_helper.py
│   └── losses.py
├── datasets
│   ├── __init__.py
│   └── dataset.py
├── assets
│   ├── vis.jpg
│   └── pipeline.jpg
├── open_clip
│   ├── bpe_simple_vocab_16e6.txt.gz
│   ├── constants.py
│   ├── model_configs
│   │   ├── ViT-B-16.json
│   │   ├── ViT-B-32.json
│   │   ├── ViT-M-16.json
│   │   ├── ViT-M-32.json
│   │   ├── ViT-S-16.json
│   │   ├── ViT-S-32.json
│   │   ├── ViT-B-16-plus.json
│   │   ├── ViT-L-14-280.json
│   │   ├── ViT-L-14-336.json
│   │   ├── ViT-L-14.json
│   │   ├── ViT-L-16-320.json
│   │   ├── ViT-L-16.json
│   │   ├── ViT-M-32-alt.json
│   │   ├── ViT-S-16-alt.json
│   │   ├── ViT-S-32-alt.json
│   │   ├── ViT-B-16-plus-240.json
│   │   ├── ViT-B-32-plus-256.json
│   │   ├── ViT-H-14.json
│   │   ├── ViT-H-16.json
│   │   ├── ViT-B-32-quickgelu.json
│   │   ├── ViT-M-16-alt.json
│   │   ├── mt5-base-ViT-B-32.json
│   │   ├── xlm-roberta-base-ViT-B-32.json
│   │   ├── roberta-ViT-B-32.json
│   │   ├── ViT-e-14.json
│   │   ├── ViT-g-14.json
│   │   ├── mt5-xl-ViT-H-14.json
│   │   ├── ViT-bigG-14.json
│   │   ├── xlm-roberta-large-ViT-H-14.json
│   │   ├── RN50.json
│   │   ├── RN101.json
│   │   ├── RN50x16.json
│   │   ├── RN50x4.json
│   │   ├── RN50x64.json
│   │   ├── vit_medium_patch16_gap_256.json
│   │   ├── swin_base_patch4_window7_224.json
│   │   ├── vit_relpos_medium_patch16_cls_224.json
│   │   ├── RN101-quickgelu.json
│   │   ├── RN50-quickgelu.json
│   │   ├── convnext_base.json
│   │   ├── convnext_base_w.json
│   │   ├── convnext_large.json
│   │   ├── convnext_large_d.json
│   │   ├── convnext_small.json
│   │   ├── convnext_tiny.json
│   │   ├── convnext_base_w_320.json
│   │   ├── convnext_large_d_320.json
│   │   ├── convnext_xlarge.json
│   │   ├── convnext_xxlarge.json
│   │   └── convnext_xxlarge_320.json
│   ├── __init__.py
│   ├── transform.py
│   ├── utils.py
│   ├── openai.py
│   ├── factory.py
│   ├── tokenizer.py
│   ├── modified_resnet.py
│   ├── pretrained.py
│   └── model.py
├── requirements.txt
├── LICENSE
├── config
│   ├── busi.yaml
│   ├── brainmri.yaml
│   └── chexpert.yaml
├── data
│   └── README.md
├── README.md
├── test.py
└── train.py
/medsyn/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/vis.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cnulab/MediCLIP/HEAD/assets/vis.jpg
--------------------------------------------------------------------------------
/assets/pipeline.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cnulab/MediCLIP/HEAD/assets/pipeline.jpg
--------------------------------------------------------------------------------
/open_clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cnulab/MediCLIP/HEAD/open_clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/open_clip/constants.py:
--------------------------------------------------------------------------------
1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
3 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.11.0+cu113
2 | torchvision==0.12.0+cu113
3 | Pillow==9.1.1
4 | scikit-image==0.19.3
5 | scikit-learn==1.1.2
6 | sklearn==0.0
7 | opencv-python==4.6.0.66
8 | grad_cam==1.4.3
9 | tqdm==4.61.2
10 | PyYAML==6.0
11 | easydict==1.9
12 | ftfy==6.1.3
13 | regex==2023.12.25
14 | imgaug==0.4.0
15 | numpy==1.22.4
16 |
--------------------------------------------------------------------------------
/open_clip/model_configs/ViT-B-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "patch_size": 16
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 512,
13 | "heads": 8,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
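Each JSON under `model_configs` pairs a vision tower with a text tower and is presumably looked up by model name through `factory.py` (not shown in this dump). A minimal sketch of inspecting one of these configs directly, using only the standard library:

```python
import json

# Inspect the ViT-B-16 config shown above; key names follow the file contents.
with open("open_clip/model_configs/ViT-B-16.json") as f:
    cfg = json.load(f)

print(cfg["embed_dim"])                   # 512: shared image/text embedding size
print(cfg["vision_cfg"]["patch_size"])    # 16: ViT patch size
print(cfg["text_cfg"]["context_length"])  # 77: max tokenized prompt length
```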
/open_clip/model_configs/ViT-B-32.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "patch_size": 32
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 512,
13 | "heads": 8,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/ViT-M-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 512,
7 | "patch_size": 16
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 512,
13 | "heads": 8,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/ViT-M-32.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 512,
7 | "patch_size": 32
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 512,
13 | "heads": 8,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/ViT-S-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 384,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 384,
7 | "patch_size": 16
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 384,
13 | "heads": 6,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/ViT-S-32.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 384,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 384,
7 | "patch_size": 32
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 384,
13 | "heads": 6,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/ViT-B-16-plus.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 640,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 896,
7 | "patch_size": 16
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 640,
13 | "heads": 10,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/ViT-L-14-280.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 280,
5 | "layers": 24,
6 | "width": 1024,
7 | "patch_size": 14
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 768,
13 | "heads": 12,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/ViT-L-14-336.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 336,
5 | "layers": 24,
6 | "width": 1024,
7 | "patch_size": 14
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 768,
13 | "heads": 12,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/ViT-L-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 24,
6 | "width": 1024,
7 | "patch_size": 14
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 768,
13 | "heads": 12,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/ViT-L-16-320.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 320,
5 | "layers": 24,
6 | "width": 1024,
7 | "patch_size": 16
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 768,
13 | "heads": 12,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/ViT-L-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 24,
6 | "width": 1024,
7 | "patch_size": 16
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 768,
13 | "heads": 12,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/ViT-M-32-alt.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 384,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 512,
7 | "patch_size": 32
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 384,
13 | "heads": 6,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/ViT-S-16-alt.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 256,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 384,
7 | "patch_size": 16
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 256,
13 | "heads": 4,
14 | "layers": 10
15 | }
16 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/ViT-S-32-alt.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 256,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 384,
7 | "patch_size": 32
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 256,
13 | "heads": 4,
14 | "layers": 10
15 | }
16 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/ViT-B-16-plus-240.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 640,
3 | "vision_cfg": {
4 | "image_size": 240,
5 | "layers": 12,
6 | "width": 896,
7 | "patch_size": 16
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 640,
13 | "heads": 10,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/ViT-B-32-plus-256.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 640,
3 | "vision_cfg": {
4 | "image_size": 256,
5 | "layers": 12,
6 | "width": 896,
7 | "patch_size": 32
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 640,
13 | "heads": 10,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/ViT-H-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 32,
6 | "width": 1280,
7 | "head_width": 80,
8 | "patch_size": 14
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 1024,
14 | "heads": 16,
15 | "layers": 24
16 | }
17 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/ViT-H-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 32,
6 | "width": 1280,
7 | "head_width": 80,
8 | "patch_size": 16
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 1024,
14 | "heads": 16,
15 | "layers": 24
16 | }
17 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/ViT-B-32-quickgelu.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "quick_gelu": true,
4 | "vision_cfg": {
5 | "image_size": 224,
6 | "layers": 12,
7 | "width": 768,
8 | "patch_size": 32
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 512,
14 | "heads": 8,
15 | "layers": 12
16 | }
17 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/ViT-M-16-alt.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 384,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 512,
7 | "patch_size": 16,
8 | "ls_init_value": 1e-4
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 384,
14 | "heads": 6,
15 | "layers": 12
16 | }
17 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/mt5-base-ViT-B-32.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "patch_size": 32
8 | },
9 | "text_cfg": {
10 | "hf_model_name": "google/mt5-base",
11 | "hf_tokenizer_name": "google/mt5-base",
12 | "proj": "mlp",
13 | "pooler_type": "mean_pooler"
14 | }
15 | }
16 |
--------------------------------------------------------------------------------
/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "patch_size": 32
8 | },
9 | "text_cfg": {
10 | "hf_model_name": "xlm-roberta-base",
11 | "hf_tokenizer_name": "xlm-roberta-base",
12 | "proj": "mlp",
13 | "pooler_type": "mean_pooler"
14 | }
15 | }
16 |
--------------------------------------------------------------------------------
/open_clip/model_configs/roberta-ViT-B-32.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "quick_gelu": true,
4 | "vision_cfg": {
5 | "image_size": 224,
6 | "layers": 12,
7 | "width": 768,
8 | "patch_size": 32
9 | },
10 | "text_cfg": {
11 | "hf_model_name": "roberta-base",
12 | "hf_tokenizer_name": "roberta-base",
13 | "proj": "mlp",
14 | "pooler_type": "mean_pooler"
15 | }
16 | }
17 |
--------------------------------------------------------------------------------
/open_clip/model_configs/ViT-e-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1280,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 56,
6 | "width": 1792,
7 | "head_width": 112,
8 | "mlp_ratio": 8.5715,
9 | "patch_size": 14
10 | },
11 | "text_cfg": {
12 | "context_length": 77,
13 | "vocab_size": 49408,
14 | "width": 1280,
15 | "heads": 20,
16 | "layers": 36
17 | }
18 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/ViT-g-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 40,
6 | "width": 1408,
7 | "head_width": 88,
8 | "mlp_ratio": 4.3637,
9 | "patch_size": 14
10 | },
11 | "text_cfg": {
12 | "context_length": 77,
13 | "vocab_size": 49408,
14 | "width": 1024,
15 | "heads": 16,
16 | "layers": 24
17 | }
18 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/mt5-xl-ViT-H-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 32,
6 | "width": 1280,
7 | "head_width": 80,
8 | "patch_size": 14
9 | },
10 | "text_cfg": {
11 | "hf_model_name": "google/mt5-xl",
12 | "hf_tokenizer_name": "google/mt5-xl",
13 | "proj": "mlp",
14 | "pooler_type": "mean_pooler"
15 | }
16 | }
17 |
--------------------------------------------------------------------------------
/open_clip/model_configs/ViT-bigG-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1280,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 48,
6 | "width": 1664,
7 | "head_width": 104,
8 | "mlp_ratio": 4.9231,
9 | "patch_size": 14
10 | },
11 | "text_cfg": {
12 | "context_length": 77,
13 | "vocab_size": 49408,
14 | "width": 1280,
15 | "heads": 20,
16 | "layers": 32
17 | }
18 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/xlm-roberta-large-ViT-H-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 32,
6 | "width": 1280,
7 | "head_width": 80,
8 | "patch_size": 14
9 | },
10 | "text_cfg": {
11 | "hf_model_name": "xlm-roberta-large",
12 | "hf_tokenizer_name": "xlm-roberta-large",
13 | "proj": "mlp",
14 | "pooler_type": "mean_pooler"
15 | }
16 | }
17 |
--------------------------------------------------------------------------------
/open_clip/model_configs/RN50.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": [
6 | 3,
7 | 4,
8 | 6,
9 | 3
10 | ],
11 | "width": 64,
12 | "patch_size": null
13 | },
14 | "text_cfg": {
15 | "context_length": 77,
16 | "vocab_size": 49408,
17 | "width": 512,
18 | "heads": 8,
19 | "layers": 12
20 | }
21 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/RN101.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": [
6 | 3,
7 | 4,
8 | 23,
9 | 3
10 | ],
11 | "width": 64,
12 | "patch_size": null
13 | },
14 | "text_cfg": {
15 | "context_length": 77,
16 | "vocab_size": 49408,
17 | "width": 512,
18 | "heads": 8,
19 | "layers": 12
20 | }
21 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/RN50x16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 384,
5 | "layers": [
6 | 6,
7 | 8,
8 | 18,
9 | 8
10 | ],
11 | "width": 96,
12 | "patch_size": null
13 | },
14 | "text_cfg": {
15 | "context_length": 77,
16 | "vocab_size": 49408,
17 | "width": 768,
18 | "heads": 12,
19 | "layers": 12
20 | }
21 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/RN50x4.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 640,
3 | "vision_cfg": {
4 | "image_size": 288,
5 | "layers": [
6 | 4,
7 | 6,
8 | 10,
9 | 6
10 | ],
11 | "width": 80,
12 | "patch_size": null
13 | },
14 | "text_cfg": {
15 | "context_length": 77,
16 | "vocab_size": 49408,
17 | "width": 640,
18 | "heads": 10,
19 | "layers": 12
20 | }
21 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/RN50x64.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 448,
5 | "layers": [
6 | 3,
7 | 15,
8 | 36,
9 | 10
10 | ],
11 | "width": 128,
12 | "patch_size": null
13 | },
14 | "text_cfg": {
15 | "context_length": 77,
16 | "vocab_size": 49408,
17 | "width": 1024,
18 | "heads": 16,
19 | "layers": 12
20 | }
21 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/vit_medium_patch16_gap_256.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "timm_model_name": "vit_medium_patch16_gap_256",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "linear",
8 | "image_size": 256
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 512,
14 | "heads": 8,
15 | "layers": 12
16 | }
17 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/swin_base_patch4_window7_224.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 640,
3 | "vision_cfg": {
4 | "timm_model_name": "swin_base_patch4_window7_224",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "linear",
8 | "image_size": 224
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 640,
14 | "heads": 10,
15 | "layers": 12
16 | }
17 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "timm_model_name": "vit_relpos_medium_patch16_cls_224",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "linear",
8 | "image_size": 224
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 512,
14 | "heads": 8,
15 | "layers": 12
16 | }
17 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/RN101-quickgelu.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "quick_gelu": true,
4 | "vision_cfg": {
5 | "image_size": 224,
6 | "layers": [
7 | 3,
8 | 4,
9 | 23,
10 | 3
11 | ],
12 | "width": 64,
13 | "patch_size": null
14 | },
15 | "text_cfg": {
16 | "context_length": 77,
17 | "vocab_size": 49408,
18 | "width": 512,
19 | "heads": 8,
20 | "layers": 12
21 | }
22 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/RN50-quickgelu.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "quick_gelu": true,
4 | "vision_cfg": {
5 | "image_size": 224,
6 | "layers": [
7 | 3,
8 | 4,
9 | 6,
10 | 3
11 | ],
12 | "width": 64,
13 | "patch_size": null
14 | },
15 | "text_cfg": {
16 | "context_length": 77,
17 | "vocab_size": 49408,
18 | "width": 512,
19 | "heads": 8,
20 | "layers": 12
21 | }
22 | }
23 |
--------------------------------------------------------------------------------
/open_clip/model_configs/convnext_base.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "timm_model_name": "convnext_base",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "linear",
8 | "timm_drop": 0.0,
9 | "timm_drop_path": 0.1,
10 | "image_size": 224
11 | },
12 | "text_cfg": {
13 | "context_length": 77,
14 | "vocab_size": 49408,
15 | "width": 512,
16 | "heads": 8,
17 | "layers": 12
18 | }
19 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/convnext_base_w.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 640,
3 | "vision_cfg": {
4 | "timm_model_name": "convnext_base",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "linear",
8 | "timm_drop": 0.0,
9 | "timm_drop_path": 0.1,
10 | "image_size": 256
11 | },
12 | "text_cfg": {
13 | "context_length": 77,
14 | "vocab_size": 49408,
15 | "width": 640,
16 | "heads": 10,
17 | "layers": 12
18 | }
19 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/convnext_large.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "timm_model_name": "convnext_large",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "linear",
8 | "timm_drop": 0.0,
9 | "timm_drop_path": 0.1,
10 | "image_size": 224
11 | },
12 | "text_cfg": {
13 | "context_length": 77,
14 | "vocab_size": 49408,
15 | "width": 768,
16 | "heads": 12,
17 | "layers": 12
18 | }
19 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/convnext_large_d.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "timm_model_name": "convnext_large",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "mlp",
8 | "timm_drop": 0.0,
9 | "timm_drop_path": 0.1,
10 | "image_size": 256
11 | },
12 | "text_cfg": {
13 | "context_length": 77,
14 | "vocab_size": 49408,
15 | "width": 768,
16 | "heads": 12,
17 | "layers": 16
18 | }
19 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/convnext_small.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "timm_model_name": "convnext_small",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "linear",
8 | "timm_drop": 0.0,
9 | "timm_drop_path": 0.1,
10 | "image_size": 224
11 | },
12 | "text_cfg": {
13 | "context_length": 77,
14 | "vocab_size": 49408,
15 | "width": 512,
16 | "heads": 8,
17 | "layers": 12
18 | }
19 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/convnext_tiny.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "timm_model_name": "convnext_tiny",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "linear",
8 | "timm_drop": 0.0,
9 | "timm_drop_path": 0.1,
10 | "image_size": 224
11 | },
12 | "text_cfg": {
13 | "context_length": 77,
14 | "vocab_size": 49408,
15 | "width": 512,
16 | "heads": 8,
17 | "layers": 12
18 | }
19 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/convnext_base_w_320.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 640,
3 | "vision_cfg": {
4 | "timm_model_name": "convnext_base",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "linear",
8 | "timm_drop": 0.0,
9 | "timm_drop_path": 0.1,
10 | "image_size": 320
11 | },
12 | "text_cfg": {
13 | "context_length": 77,
14 | "vocab_size": 49408,
15 | "width": 640,
16 | "heads": 10,
17 | "layers": 12
18 | }
19 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/convnext_large_d_320.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "timm_model_name": "convnext_large",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "mlp",
8 | "timm_drop": 0.0,
9 | "timm_drop_path": 0.1,
10 | "image_size": 320
11 | },
12 | "text_cfg": {
13 | "context_length": 77,
14 | "vocab_size": 49408,
15 | "width": 768,
16 | "heads": 12,
17 | "layers": 16
18 | }
19 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/convnext_xlarge.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "timm_model_name": "convnext_xlarge",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "linear",
8 | "timm_drop": 0.0,
9 | "timm_drop_path": 0.1,
10 | "image_size": 256
11 | },
12 | "text_cfg": {
13 | "context_length": 77,
14 | "vocab_size": 49408,
15 | "width": 1024,
16 | "heads": 16,
17 | "layers": 20
18 | }
19 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/convnext_xxlarge.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "timm_model_name": "convnext_xxlarge",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "linear",
8 | "timm_drop": 0.0,
9 | "timm_drop_path": 0.1,
10 | "image_size": 256
11 | },
12 | "text_cfg": {
13 | "context_length": 77,
14 | "vocab_size": 49408,
15 | "width": 1024,
16 | "heads": 16,
17 | "layers": 24
18 | }
19 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/convnext_xxlarge_320.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "timm_model_name": "convnext_xxlarge",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "linear",
8 | "timm_drop": 0.0,
9 | "timm_drop_path": 0.1,
10 | "image_size": 320
11 | },
12 | "text_cfg": {
13 | "context_length": 77,
14 | "vocab_size": 49408,
15 | "width": 1024,
16 | "heads": 16,
17 | "layers": 24
18 | }
19 | }
--------------------------------------------------------------------------------
/open_clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
2 | from .factory import create_model, create_model_and_transforms, get_tokenizer
3 | from .factory import list_models, add_model_config, get_model_config, load_checkpoint
4 | from .model import CLIP, CLIPTextCfg, CLIPVisionCfg, \
5 | convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
6 | from .openai import load_openai_model, list_openai_models
7 | from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
8 | get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
9 | from .tokenizer import SimpleTokenizer, tokenize, decode
10 | from .transform import image_transform
11 |
--------------------------------------------------------------------------------
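The exports above make this vendored package usable much like upstream `open_clip`. A small, hedged example of the exported tokenizer, whose behaviour is assumed to match the standard CLIP tokenizer:

```python
from open_clip import tokenize

# Tokenize a few of the prompt words used in the configs below; each row is
# padded or truncated to CLIP's 77-token context length.
tokens = tokenize(["normal", "abnormal", "evidence of disease"])
print(tokens.shape)  # expected: torch.Size([3, 77])
```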
/models/Adapter.py:
--------------------------------------------------------------------------------
1 |
2 | import torch.nn as nn
3 |
4 | class Adapter(nn.Module):
5 |
6 | def __init__(self,
7 | clip_model,
8 | target,
9 | ):
10 |
11 | super(Adapter, self).__init__()
12 |
13 | input_sizes = clip_model.token_c
14 |
15 | for i,input_size in enumerate(input_sizes):
16 | self.add_module("{}_adapter".format(i), nn.Sequential(nn.Conv2d(input_size, target, 1, 1)))
17 |
18 |
19 | def forward(self, tokens):
20 | vision_features=[]
21 | for i,token in enumerate(tokens):
22 | vision_feature=getattr(self,'{}_adapter'.format(i))(token).contiguous().permute(0, 2, 3, 1)
23 | vision_feature = vision_feature / vision_feature.norm(dim=-1, keepdim=True)
24 | vision_features.append(vision_feature)
25 | return vision_features
--------------------------------------------------------------------------------
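A minimal sketch of exercising `Adapter` in isolation. The `clip_model` stand-in and the channel widths below are hypothetical; the real wrapper that provides `token_c` is not shown in this dump.

```python
import torch
from types import SimpleNamespace
from models.Adapter import Adapter

# Hypothetical stand-in exposing token_c: the channel width of each selected layer.
fake_clip = SimpleNamespace(token_c=[1024, 1024, 1024])

adapter = Adapter(fake_clip, target=768)
tokens = [torch.randn(2, c, 16, 16) for c in fake_clip.token_c]  # B x C x H x W
feats = adapter(tokens)
print([f.shape for f in feats])  # each: torch.Size([2, 16, 16, 768]), L2-normalized
```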
/models/MapMaker.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import math
5 | import torch.nn.functional as F
6 |
7 |
8 | class MapMaker(nn.Module):
9 |
10 | def __init__(self,image_size):
11 |
12 | super(MapMaker, self).__init__()
13 | self.image_size = image_size
14 |
15 |
16 |     def forward(self, vision_adapter_features, prompt_adapter_features):
17 | anomaly_maps=[]
18 |
19 | for i,vision_adapter_feature in enumerate(vision_adapter_features):
20 | B, H, W, C = vision_adapter_feature.shape
21 |             anomaly_map = (vision_adapter_feature.view((B, H * W, C)) @ prompt_adapter_features).contiguous().view(
22 | (B, H, W, -1)).permute(0, 3, 1, 2)
23 |
24 | anomaly_maps.append(anomaly_map)
25 |
26 | anomaly_map = torch.stack(anomaly_maps, dim=0).mean(dim=0)
27 | anomaly_map = F.interpolate(anomaly_map, (self.image_size, self.image_size), mode='bilinear', align_corners=True)
28 | return torch.softmax(anomaly_map, dim=1)
--------------------------------------------------------------------------------
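A shape-only sketch of `MapMaker`: adapted vision features of shape (B, H, W, C) are projected onto per-class prompt features of shape (C, 2), averaged over levels, upsampled to the image size, and softmaxed over the class dimension. The tensors below are random placeholders; two prompt classes are assumed, matching the normal/abnormal prompts in the configs.

```python
import torch
from models.MapMaker import MapMaker

map_maker = MapMaker(image_size=224)

# Two feature levels (B, H, W, C) and prompt features (C, 2).
vision_feats = [torch.randn(2, 16, 16, 768) for _ in range(2)]
prompt_feats = torch.randn(768, 2)

anomaly_map = map_maker(vision_feats, prompt_feats)
print(anomaly_map.shape)  # torch.Size([2, 2, 224, 224]), softmax over the two classes
```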
/models/Necker.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import math
5 |
6 |
7 | class Necker(nn.Module):
8 |
9 | def __init__(self,
10 | clip_model
11 | ):
12 | super(Necker, self).__init__()
13 | self.clip_model=clip_model
14 | target = max(self.clip_model.token_size)
15 | for i,size in enumerate(self.clip_model.token_size):
16 | self.add_module("{}_upsample".format(i),
17 | nn.UpsamplingBilinear2d(scale_factor=target/size))
18 |
19 |
20 | @torch.no_grad()
21 | def forward(self, tokens):
22 | align_features=[]
23 | for i,token in enumerate(tokens):
24 | if len(token.shape) == 3:
25 | B, N, C=token.shape
26 | token = token[:, 1:, :]
27 | token=token.view((B,int(math.sqrt(N-1)),int(math.sqrt(N-1)),C)).permute(0, 3, 1, 2)
28 | align_features.append(getattr(self, "{}_upsample".format(i))(token))
29 | return align_features
30 |
31 |
32 |
33 |
34 |
--------------------------------------------------------------------------------
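A sketch of `Necker` on ViT-style tokens: it drops the class token, reshapes `(B, 1 + H*W, C)` into `(B, C, H, W)`, and bilinearly upsamples every level to the largest grid size. The `token_size` stand-in below is hypothetical.

```python
import torch
from types import SimpleNamespace
from models.Necker import Necker

# Hypothetical stand-in exposing token_size: the spatial grid size of each layer.
fake_clip = SimpleNamespace(token_size=[16, 16, 16])

necker = Necker(fake_clip)
# ViT tokens: (B, 1 + H*W, C); the leading token is the class token.
tokens = [torch.randn(2, 1 + 16 * 16, 1024) for _ in fake_clip.token_size]
aligned = necker(tokens)
print([f.shape for f in aligned])  # each: torch.Size([2, 1024, 16, 16])
```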
/open_clip/transform.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from dataclasses import dataclass, asdict
3 | from typing import Any, Dict, Optional, Sequence, Tuple, Union
4 |
5 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
6 | CenterCrop
7 |
8 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
9 |
10 |
11 | def _convert_to_rgb(image):
12 | return image.convert('RGB')
13 |
14 |
15 |
16 | def image_transform(
17 | mean: Optional[Tuple[float, ...]] = None,
18 | std: Optional[Tuple[float, ...]] = None,
19 | ):
20 |
21 | mean = mean or OPENAI_DATASET_MEAN
22 | if not isinstance(mean, (list, tuple)):
23 | mean = (mean,) * 3
24 |
25 | std = std or OPENAI_DATASET_STD
26 | if not isinstance(std, (list, tuple)):
27 | std = (std,) * 3
28 |
29 | normalize = Normalize(mean=mean, std=std)
30 |
31 | transforms = [
32 | _convert_to_rgb,
33 | ToTensor(),
34 | normalize,
35 | ]
36 |
37 | return Compose(transforms)
38 |
--------------------------------------------------------------------------------
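Note that this vendored `image_transform` only converts to RGB, tensorizes, and normalizes; resizing to the model's input size evidently happens elsewhere (e.g. in the dataset code, which is not shown here). A small usage sketch with a hypothetical image path taken from the layout in `data/README.md`:

```python
from PIL import Image
from open_clip.transform import image_transform

preprocess = image_transform()  # defaults to the OpenAI CLIP mean/std constants

# Hypothetical path following the directory layout described in data/README.md.
img = Image.open("data/brainmri/images/train/normal/image_0.jpg")
tensor = preprocess(img)  # float tensor of shape (3, H, W), normalized
```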
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 cnulab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/config/busi.yaml:
--------------------------------------------------------------------------------
1 | version: v1.0.0
2 | random_seed: 100
3 |
4 | data_root: data
5 |
6 | train_dataset: busi
7 | test_datasets: [busi]
8 |
9 | epoch: 200
10 | batch_size: 8
11 |
12 | print_freq_step: 10
13 | val_freq_epoch: 1
14 |
15 | image_size: 224
16 | model_name: ViT-L-14
17 | layers_out: [12,18,24]
18 |
19 | anomaly_tasks:
20 | CutpasteTask: 0.25
21 | GaussIntensityChangeTask: 0.25
22 | SourceTask: 0.25
23 | IdentityTask: 0.25
24 |
25 | prompt_maker: coop
26 | n_learnable_token: 8
27 | CSC: True
28 | class_token_positions: [end]
29 |
30 | save_root: results
31 |
32 | prompts:
33 | normal: [
34 | normal,
35 | healthy,
36 | negative,
37 | unremarkable,
38 | clear,
39 | asymptomatic,
40 | normal findings,
41 | no findings,
42 | in good health,
43 | no evidence of disease
44 | ]
45 |
46 | abnormal: [
47 | abnormal,
48 | positive,
49 | symptomatic,
50 | disease,
51 | lesion,
52 | pathological,
53 | impaired,
54 | evidence of disease,
55 | abnormal finding,
56 | pathological condition,
57 | pathological abnormality
58 | ]
--------------------------------------------------------------------------------
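A hedged sketch of how a config like this could be loaded. The exact loading code in `train.py` is not shown in this dump, so treat this as illustrative only; `PyYAML` and `easydict` are both pinned in `requirements.txt`.

```python
import yaml
from easydict import EasyDict

with open("config/busi.yaml") as f:
    cfg = EasyDict(yaml.safe_load(f))

print(cfg.model_name)     # ViT-L-14
print(cfg.layers_out)     # [12, 18, 24]
print(cfg.anomaly_tasks)  # task name -> weight (presumably a sampling probability)
```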
/config/brainmri.yaml:
--------------------------------------------------------------------------------
1 | version: v1.0.0
2 | random_seed: 100
3 |
4 | data_root: data
5 | train_dataset: brainmri
6 | test_datasets: [brainmri]
7 |
8 | epoch: 200
9 | batch_size: 8
10 |
11 | print_freq_step: 10
12 | val_freq_epoch: 1
13 |
14 | image_size: 224
15 | model_name: ViT-L-14
16 | layers_out: [12,18,24]
17 |
18 | anomaly_tasks:
19 | CutpasteTask: 0.25
20 | GaussIntensityChangeTask: 0.25
21 | SourceTask: 0.25
22 | IdentityTask: 0.25
23 |
24 |
25 | prompt_maker: coop
26 | n_learnable_token: 8
27 | CSC: True
28 | class_token_positions: [end]
29 |
30 | save_root: results
31 |
32 | prompts:
33 | normal: [
34 | normal,
35 | healthy,
36 | negative,
37 | unremarkable,
38 | clear,
39 | asymptomatic,
40 | normal findings,
41 | no findings,
42 | in good health,
43 | no evidence of disease
44 | ]
45 |
46 | abnormal: [
47 | abnormal,
48 | positive,
49 | symptomatic,
50 | disease,
51 | lesion,
52 | pathological,
53 | impaired,
54 | evidence of disease,
55 | abnormal finding,
56 | pathological condition,
57 | pathological abnormality
58 | ]
--------------------------------------------------------------------------------
/config/chexpert.yaml:
--------------------------------------------------------------------------------
1 | version: v1.0.0
2 | random_seed: 100
3 |
4 | data_root: data
5 |
6 | train_dataset: chexpert
7 | test_datasets: [chexpert]
8 |
9 | epoch: 200
10 | batch_size: 8
11 |
12 | print_freq_step: 10
13 | val_freq_epoch: 1
14 |
15 | image_size: 224
16 | model_name: ViT-L-14
17 | layers_out: [12,18,24]
18 |
19 | anomaly_tasks:
20 | CutpasteTask: 0.25
21 | GaussIntensityChangeTask: 0.25
22 | SourceTask: 0.25
23 | IdentityTask: 0.25
24 |
25 |
26 | prompt_maker: coop
27 | n_learnable_token: 8
28 | CSC: True
29 | class_token_positions: [end]
30 |
31 |
32 | save_root: results
33 |
34 | prompts:
35 | normal: [
36 | normal,
37 | healthy,
38 | negative,
39 | unremarkable,
40 | clear,
41 | asymptomatic,
42 | normal findings,
43 | no findings,
44 | in good health,
45 | no evidence of disease
46 | ]
47 |
48 | abnormal: [
49 | abnormal,
50 | positive,
51 | symptomatic,
52 | disease,
53 | lesion,
54 | pathological,
55 | impaired,
56 | evidence of disease,
57 | abnormal finding,
58 | pathological condition,
59 | pathological abnormality
60 | ]
--------------------------------------------------------------------------------
/medsyn/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | import numpy as np
3 | import numpy.typing as npt
4 |
5 |
6 | def accumulate_rotation(init_matrix: npt.NDArray, angle: float, axes: Tuple[int, int]) -> npt.NDArray:
7 | """
8 | Calculates rotation matrix and multiplies it with current transformation matrix.
9 | :param init_matrix: Current transformation matrix.
10 | :param angle: Angle of rotation in radians.
11 | :param axes: Tuple of axes to be rotated within.
12 | """
13 |
14 | cos_ang = np.cos(angle)
15 | sin_ang = np.sin(angle)
16 |
17 | rot_mat = np.identity(len(init_matrix))
18 |
19 | # Note: as the axes are always in ascending order, that does mean that for N>2, some rotations may be backwards
20 | # (due to the sine/-sine always being in wrong places)
21 | # BUT because we're sampling it symmetrically around 0 the distribution doesn't change.
22 | a1, a2 = axes
23 | rot_mat[a1, a1] = rot_mat[a2, a2] = cos_ang
24 | rot_mat[a1, a2] = -sin_ang
25 | rot_mat[a2, a1] = sin_ang
26 | return rot_mat @ init_matrix
27 |
28 |
29 | def accumulate_scaling(init_matrix: npt.NDArray, scale: float) -> npt.NDArray:
30 | """
31 | Calculates scaling matrix and multiplies it with current transformation matrix.
32 | Assumes cartesian coordinates (not homogeneous).
33 | :param init_matrix: Current transformation matrix
34 | :param scale: Factor to scale by.
35 | """
36 | scale_mat = np.identity(len(init_matrix)) * scale
37 |
38 | return scale_mat @ init_matrix
39 |
40 |
41 | def get_patch_slices(patch_corner: np.ndarray, patch_shape: Tuple[int]) -> Tuple[slice]:
42 | return tuple([slice(c, c + d) for (c, d) in zip(patch_corner, patch_shape)])
43 |
44 |
45 | # Same as above, but with additional slice at beginning to include all image channels.
46 | def get_patch_image_slices(patch_corner: np.ndarray, patch_shape: Tuple[int]) -> Tuple[slice]:
47 | return tuple([slice(None)] + list(get_patch_slices(patch_corner, patch_shape)))
48 |
49 |
50 | def nsa_sample_dimension(lb, ub, img_d):
51 | gamma_lb = 0.03
52 | gamma_shape = 2
53 | gamma_scale = 0.1
54 |
55 | gamma_sample = (gamma_lb + np.random.gamma(gamma_shape, gamma_scale)) * img_d
56 |
57 | return int(np.clip(gamma_sample, lb, ub))
58 |
--------------------------------------------------------------------------------
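A small sketch of the helpers above: composing a 2-D rotation and scaling into one transformation matrix, and drawing an NSA-style patch dimension. The concrete numbers are illustrative only.

```python
import numpy as np
from medsyn.utils import accumulate_rotation, accumulate_scaling, nsa_sample_dimension

# Compose a 2-D transform: rotate by ~15 degrees in the (0, 1) plane, then scale by 1.2.
transform = np.identity(2)
transform = accumulate_rotation(transform, np.deg2rad(15.0), axes=(0, 1))
transform = accumulate_scaling(transform, 1.2)
print(transform)

# Sample a patch side length between 10 and 120 px for a 224-px image dimension.
side = nsa_sample_dimension(lb=10, ub=120, img_d=224)
print(side)
```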
/open_clip/utils.py:
--------------------------------------------------------------------------------
1 | from itertools import repeat
2 | import collections.abc
3 |
4 | from torch import nn as nn
5 | from torchvision.ops.misc import FrozenBatchNorm2d
6 |
7 |
8 | def freeze_batch_norm_2d(module, module_match={}, name=''):
9 | """
10 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
11 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
12 | returned. Otherwise, the module is walked recursively and submodules are converted in place.
13 |
14 | Args:
15 | module (torch.nn.Module): Any PyTorch module.
16 | module_match (dict): Dictionary of full module names to freeze (all if empty)
17 | name (str): Full module name (prefix)
18 |
19 | Returns:
20 | torch.nn.Module: Resulting module
21 |
22 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
23 | """
24 | res = module
25 | is_match = True
26 | if module_match:
27 | is_match = name in module_match
28 | if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
29 | res = FrozenBatchNorm2d(module.num_features)
30 | res.num_features = module.num_features
31 | res.affine = module.affine
32 | if module.affine:
33 | res.weight.data = module.weight.data.clone().detach()
34 | res.bias.data = module.bias.data.clone().detach()
35 | res.running_mean.data = module.running_mean.data
36 | res.running_var.data = module.running_var.data
37 | res.eps = module.eps
38 | else:
39 | for child_name, child in module.named_children():
40 | full_child_name = '.'.join([name, child_name]) if name else child_name
41 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
42 | if new_child is not child:
43 | res.add_module(child_name, new_child)
44 | return res
45 |
46 |
47 | # From PyTorch internals
48 | def _ntuple(n):
49 | def parse(x):
50 | if isinstance(x, collections.abc.Iterable):
51 | return x
52 | return tuple(repeat(x, n))
53 | return parse
54 |
55 |
56 | to_1tuple = _ntuple(1)
57 | to_2tuple = _ntuple(2)
58 | to_3tuple = _ntuple(3)
59 | to_4tuple = _ntuple(4)
60 | to_ntuple = lambda n, x: _ntuple(n)(x)
61 |
--------------------------------------------------------------------------------
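A short sketch of `freeze_batch_norm_2d` applied to an off-the-shelf torchvision ResNet, chosen only for illustration (the repo itself ships a `modified_resnet.py`, which is not shown here):

```python
import torchvision
from open_clip.utils import freeze_batch_norm_2d

# Recursively swap every BatchNorm2d/SyncBatchNorm for FrozenBatchNorm2d so the
# statistics and affine parameters stay fixed during fine-tuning.
resnet = torchvision.models.resnet18()
frozen = freeze_batch_norm_2d(resnet)
print(type(frozen.bn1).__name__)  # FrozenBatchNorm2d
```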
/data/README.md:
--------------------------------------------------------------------------------
1 | ### Datasets Links
2 | - **BUSI [[Baidu Cloud (pass8866)]](https://pan.baidu.com/s/1EVt96fExiqrvMQslPDRRRg?pwd=8866) [[Google Drive]](https://drive.google.com/file/d/1PyvMXdNEVY86BY1PV8yKhPVS30TAmS6X/view?usp=drive_link) [[Official Link]](https://scholar.cu.edu.eg/?q=afahmy/pages/dataset)**
3 | - **BrainMRI [[Baidu Cloud (pass8866)]](https://pan.baidu.com/s/1--5vPMN-eTqePPYjpKTwvA?pwd=8866) [[Google Drive]](https://drive.google.com/file/d/1kldE-5_wXaN-JR_8Y_mRCKQ6VZiyv3km/view?usp=drive_link) [[Official Link]](https://www.kaggle.com/datasets/navoneel/brain-mri-images-for-brain-tumor-detection)**
4 | - **CheXpert [[Baidu Cloud (pass8866)]](https://pan.baidu.com/s/15-V5wobA_7ICvZAXBraDGA?pwd=8866) [[Google Drive]](https://drive.google.com/file/d/1pVYRipGC2VqjYP-wHdDFR-lLf7itLiUi/view?usp=drive_link) [[Official Link]](https://stanfordmlgroup.github.io/competitions/chexpert/)**
5 |
6 | ### The complete directory structure is as follows:
7 | ```
8 | |--data
9 | |--brainmri
10 | |--samples
11 | |--train.json
12 | |--test.json
13 | |--images
14 | |--test
15 | |--abnormal
16 | |--image_97.jpg
17 | |--...
18 | |--normal
19 | |--image_32.jpg
20 | |--...
21 | |--train
22 | |--normal
23 | |--image_0.jpg
24 | |--...
25 | |--busi
26 | |--samples
27 | |--train.json
28 | |--test.json
29 | |--images
30 | |--test
31 | |--abnormal
32 | |--benign_0.jpg
33 | |--...
34 | |--normal
35 | |--normal_32.jpg
36 | |--...
37 | |--ground_true
38 | |--benign_mask_0.jpg
39 | |--...
40 | |--train
41 | |--normal
42 | |--normal_0.jpg
43 | |--...
44 | |--chexpert
45 | |--samples
46 | |--train.json
47 | |--test.json
48 | |--images
49 | |--test
50 | |--abnormal
51 | |--00002.jpg
52 | |--...
53 | |--normal
54 | |--00960.jpg
55 | |--...
56 | |--train
57 | |--normal
58 | |--00000.jpg
59 | |--...
60 | ```
61 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MediCLIP
2 |
3 | **💡 This is the official implementation of the paper "MediCLIP: Adapting CLIP for Few-shot Medical Image Anomaly Detection" (MICCAI 2024) [[arXiv]](https://arxiv.org/abs/2405.11315)**.
4 |
5 | MediCLIP is an efficient few-shot medical image anomaly detection method that achieves SOTA performance with only a small number of normal medical images. It integrates learnable prompts, adapters, and realistic medical image anomaly synthesis tasks.
6 |
7 |
8 |

9 |
10 | ## 🔧 Installation
11 |
12 | To run experiments, first clone the repository and install the dependencies listed in `requirements.txt`:
13 |
14 | ```
15 | $ git clone https://github.com/cnulab/MediCLIP.git
16 | $ cd MediCLIP
17 | $ pip install -r requirements.txt
18 | ```
19 | ### Data preparation
20 | Download the following datasets:
21 | - **BUSI [[Baidu Cloud (pwd8866)]](https://pan.baidu.com/s/1EVt96fExiqrvMQslPDRRRg?pwd=8866) [[Google Drive]](https://drive.google.com/file/d/1PyvMXdNEVY86BY1PV8yKhPVS30TAmS6X/view?usp=drive_link) [[Official Link]](https://scholar.cu.edu.eg/?q=afahmy/pages/dataset)**
22 | - **BrainMRI [[Baidu Cloud (pwd8866)]](https://pan.baidu.com/s/1--5vPMN-eTqePPYjpKTwvA?pwd=8866) [[Google Drive]](https://drive.google.com/file/d/1kldE-5_wXaN-JR_8Y_mRCKQ6VZiyv3km/view?usp=drive_link) [[Official Link]](https://www.kaggle.com/datasets/navoneel/brain-mri-images-for-brain-tumor-detection)**
23 | - **CheXpert [[Baidu Cloud (pwd8866)]](https://pan.baidu.com/s/15-V5wobA_7ICvZAXBraDGA?pwd=8866) [[Google Drive]](https://drive.google.com/file/d/1pVYRipGC2VqjYP-wHdDFR-lLf7itLiUi/view?usp=drive_link) [[Official Link]](https://stanfordmlgroup.github.io/competitions/chexpert/)**
24 |
25 | Unzip them into the `data` directory. Please refer to [data/README](data/README.md) for the expected layout.
26 |
27 | ## 🚀 Experiments
28 |
29 | To train MediCLIP on the BrainMRI dataset with a support set size of 16:
30 | ```
31 | $ python train.py --config_path config/brainmri.yaml --k_shot 16
32 | ```
33 |
34 | To test MediCLIP on the BrainMRI dataset:
35 | ```
36 | $ python test.py --config_path config/brainmri.yaml --checkpoint_path xxx.pkl
37 | ```
38 | Replace ``xxx.pkl`` with your checkpoint path.
39 | 
40 |
41 | ---
42 | Code reference: **[[CLIP]](https://github.com/OpenAI/CLIP)** **[[CoOp]](https://github.com/KaiyangZhou/CoOp)** **[[Many-Tasks-Make-Light-Work]](https://github.com/matt-baugh/many-tasks-make-light-work)**.
43 |
44 |
45 | ## 🔗 Citation
46 |
47 | If this work is helpful to you, please cite it as:
48 | ```
49 | @inproceedings{zhang2024mediclip,
50 | title={MediCLIP: Adapting CLIP for Few-shot Medical Image Anomaly Detection},
51 |   author={Zhang, Ximiao and Xu, Min and Qiu, Dehui and Yan, Ruixin and Lang, Ning and Zhou, Xiuzhuang},
52 | year={2024},
53 | eprint={2405.11315},
54 | archivePrefix={arXiv},
55 | primaryClass={cs.CV}
56 | }
57 | ```
58 |
--------------------------------------------------------------------------------
/open_clip/openai.py:
--------------------------------------------------------------------------------
1 | """ OpenAI pretrained model functions
2 |
3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4 | """
5 |
6 | import os
7 | import warnings
8 | from typing import List, Optional, Union
9 |
10 | import torch
11 |
12 | from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
13 | from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url
14 |
15 | __all__ = ["list_openai_models", "load_openai_model"]
16 |
17 |
18 | def list_openai_models() -> List[str]:
19 | """Returns the names of available CLIP models"""
20 | return list_pretrained_models_by_tag('openai')
21 |
22 |
23 | def load_openai_model(
24 | name: str,
25 | precision: Optional[str] = None,
26 | device: Optional[Union[str, torch.device]] = None,
27 |
28 | ):
29 | """Load a CLIP model
30 |
31 | Parameters
32 | ----------
33 | name : str
34 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
35 | precision: str
36 | Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
37 | device : Union[str, torch.device]
38 | The device to put the loaded model
39 | 
40 |     Returns
41 |     -------
42 |     model : torch.nn.Module
43 |         The CLIP model
44 | 
45 |     Notes
46 |     -----
47 |     Unlike the original OpenAI loader, this trimmed-down version always builds a
48 |     non-JIT model, uses the default download cache, and does not return a
49 |     preprocessing transform (use `image_transform` from `.transform` instead).
50 |     """
51 |
52 | if device is None:
53 | device = "cuda" if torch.cuda.is_available() else "cpu"
54 | if precision is None:
55 | precision = 'fp32' if device == 'cpu' else 'fp16'
56 |
57 | if get_pretrained_url(name, 'openai'):
58 | model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'))
59 | elif os.path.isfile(name):
60 | model_path = name
61 | else:
62 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
63 |
64 | try:
65 | # loading JIT archive
66 | model = torch.jit.load(model_path, map_location="cpu").eval()
67 | state_dict = None
68 |
69 | except RuntimeError:
70 | # loading saved state dict
71 | state_dict = torch.load(model_path, map_location="cpu")
72 |
73 | cast_dtype = get_cast_dtype(precision)
74 |
75 | try:
76 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
77 | except KeyError:
78 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
79 | model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)
80 |
81 | # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
82 | model = model.to(device)
83 |
84 | if precision.startswith('amp') or precision == 'fp32':
85 | model.float()
86 | elif precision == 'bf16':
87 | convert_weights_to_lp(model, dtype=torch.bfloat16)
88 |
89 | return model
90 |
--------------------------------------------------------------------------------
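A hedged usage sketch of the loader above, assuming the vendored `pretrained.py` registry exposes the same OpenAI tags as upstream open_clip (the configs in this repo use `ViT-L-14`):

```python
import torch
from open_clip.openai import load_openai_model

device = "cuda" if torch.cuda.is_available() else "cpu"
# fp16 on GPU, fp32 on CPU by default, per the precision handling above.
model = load_openai_model("ViT-L-14", device=device)
model.eval()
```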
/medsyn/labelling.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Union
3 | import numpy as np
4 | import numpy.typing as npt
5 | from scipy import ndimage
6 | from skimage import morphology
7 |
8 |
9 | class AnomalyLabeller(ABC):
10 | @abstractmethod
11 | def label(self, aug_img: npt.NDArray[float], orig_img: npt.NDArray[float], mask: npt.NDArray[bool]) \
12 | -> npt.NDArray[float]:
13 | """
14 | :param aug_img: Image with anomaly augmentation applied.
15 | :param orig_img: Original image, prior to anomalies.
16 | :param mask: Mask of where the image has been altered.
17 | """
18 | pass
19 |
20 | def __call__(self, aug_img: npt.NDArray[float], orig_img: npt.NDArray[float], mask: npt.NDArray[bool]) \
21 | -> npt.NDArray[float]:
22 |
23 | return self.label(aug_img, orig_img, mask)
24 |
25 |
26 | class IntensityDiffLabeller(AnomalyLabeller, ABC):
27 |
28 | def __init__(self):
29 | super().__init__()
30 | self.binary_structures = {}
31 |
32 | @abstractmethod
33 | def label_fn(self, x: Union[npt.NDArray[float], float]) -> Union[npt.NDArray[float], float]:
34 | pass
35 |
36 | def label(self,
37 | aug_img: npt.NDArray[float],
38 | orig_img: npt.NDArray[float],
39 | mask: npt.NDArray[bool]) \
40 | -> npt.NDArray[float]:
41 | """
42 | :param aug_img: Image with patches blended within it.
43 | :param orig_img: Original image, prior to anomalies.
44 | :param mask: Mask of where the image has been altered.
45 | """
46 |
47 | avg_diff = np.mean(mask * np.abs(aug_img - orig_img), axis=0)
48 |
49 | scaled_diff = self.label_fn(avg_diff)
50 |
51 | assert np.all(scaled_diff >= 0)
52 |
53 | num_dims = len(mask.shape)
54 | if num_dims not in self.binary_structures:
55 | self.binary_structures[num_dims] = ndimage.generate_binary_structure(num_dims, 2)
56 |
57 | bin_structure = self.binary_structures[num_dims]
58 |
59 | for anom_slice in ndimage.find_objects(ndimage.label(mask)[0]):
60 | anom_region_label = ndimage.grey_closing(scaled_diff[anom_slice], footprint=bin_structure)
61 |
62 | recon_seed = np.copy(anom_region_label)
63 | recon_seed[num_dims * (slice(1, -1),)] = anom_region_label.max()
64 | scaled_diff[anom_slice] = morphology.reconstruction(recon_seed, anom_region_label,
65 | method='erosion',
66 | )
67 | return scaled_diff
68 |
69 |
70 |
71 | class SaturatingLabeller(IntensityDiffLabeller):
72 |
73 | def __init__(self, a: float, c: float):
74 | """
75 |         Labeller using the transformed sigmoid function: (1 + c) / (1 + e^(-ax) / c) - c
76 |         Function range is (-c, 1), with value 0 at x = 0
77 | """
78 | super().__init__()
79 | self.a = a
80 | self.c = c
81 |
82 | def label_fn(self, x: Union[npt.NDArray[float], float]) -> Union[npt.NDArray[float], float]:
83 | return (1 + self.c) / (1 + np.exp(-self.a * x) / self.c) - self.c
84 |
85 | def __call__(self, aug_img: npt.NDArray[float], orig_img: npt.NDArray[float], mask: npt.NDArray[bool]) \
86 | -> npt.NDArray[float]:
87 | return self.label(aug_img, orig_img, mask)
88 |
89 |
90 | class FlippedGaussianLabeller(IntensityDiffLabeller):
91 | def __init__(self, std: float):
92 | super(FlippedGaussianLabeller, self).__init__()
93 | self.std = std
94 |
95 | def label_fn(self, x: Union[npt.NDArray[float], float]) -> Union[npt.NDArray[float], float]:
96 | return 1 - np.exp(-x**2 / (2 * self.std**2))
97 |
--------------------------------------------------------------------------------
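Usage note for medsyn/labelling.py above (a minimal, hedged sketch with synthetic arrays; the shapes and the std value are illustrative, not values used by the repository):

    import numpy as np
    from medsyn.labelling import FlippedGaussianLabeller

    labeller = FlippedGaussianLabeller(std=0.2)
    orig = np.random.rand(1, 64, 64)          # original image, channel-first (C, H, W)
    aug = orig.copy()
    aug[:, 20:30, 20:30] += 0.5               # synthetic intensity change
    mask = np.zeros((64, 64), dtype=bool)
    mask[20:30, 20:30] = True                 # region where the image was altered
    label_map = labeller(aug, orig, mask)     # soft anomaly labels in [0, 1), shape (64, 64)
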
/utils/misc_helper.py:
--------------------------------------------------------------------------------
1 |
2 | import logging
3 | from datetime import datetime
4 | import random
5 | import os
6 | import torch
7 | import numpy as np
8 | from collections.abc import Mapping
9 | import shutil
10 |
11 | from sklearn import metrics
12 |
13 |
14 | def map_func(storage, location):
15 | return storage.cuda()
16 |
17 |
18 | def create_logger(name, log_file, level=logging.INFO):
19 | log = logging.getLogger(name)
20 | formatter = logging.Formatter(
21 | "[%(asctime)s][%(filename)15s][line:%(lineno)4d][%(levelname)8s] %(message)s"
22 | )
23 | fh = logging.FileHandler(log_file)
24 | fh.setFormatter(formatter)
25 | sh = logging.StreamHandler()
26 | sh.setFormatter(formatter)
27 | log.setLevel(level)
28 | log.addHandler(fh)
29 | log.addHandler(sh)
30 | return log
31 |
32 |
33 |
34 | def set_seed(seed):
35 | os.environ['PYTHONHASHSEED'] = str(seed)
36 | np.random.seed(seed)
37 | torch.manual_seed(seed)
38 | torch.cuda.manual_seed_all(seed)
39 | torch.backends.cudnn.deterministic = True
40 | random.seed(seed)
41 |
42 |
43 |
44 | def get_current_time():
45 | current_time = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
46 | return current_time
47 |
48 |
49 |
50 | class AverageMeter(object):
51 | """Computes and stores the average and current value"""
52 |
53 | def __init__(self, length=0):
54 | self.length = length
55 | self.reset()
56 |
57 | def reset(self):
58 | if self.length > 0:
59 | self.history = []
60 | else:
61 | self.count = 0
62 | self.sum = 0.0
63 | self.val = 0.0
64 | self.avg = 0.0
65 |
66 | def update(self, val, num=1):
67 | if self.length > 0:
 68 |             # currently assert num == 1 to avoid misuse; refine when there is an explicit requirement
69 | assert num == 1
70 | self.history.append(val)
71 | if len(self.history) > self.length:
72 | del self.history[0]
73 |
74 | self.val = self.history[-1]
75 | self.avg = np.mean(self.history)
76 | else:
77 | self.val = val
78 | self.sum += val * num
79 | self.count += num
80 | self.avg = self.sum / self.count
81 |
82 |
83 | def compute_imagewise_metrics(
84 | anomaly_prediction,
85 | anomaly_ground_truth_labels
86 | ):
87 | """
 88 |     Computes the image-level anomaly detection AUROC.
89 |
90 | Args:
91 | anomaly_prediction: [np.array or list] [N] Assignment weights
92 | per image. Higher indicates higher
93 | probability of being an anomaly.
94 | anomaly_ground_truth_labels: [np.array or list] [N] Binary labels - 1
95 | if image is an anomaly, 0 if not.
96 | """
97 | auroc = metrics.roc_auc_score(
98 | anomaly_ground_truth_labels, anomaly_prediction
99 | )
100 |
101 | return {"image-auroc": auroc}
102 |
103 |
104 | def compute_pixelwise_metrics(
105 | pixel_prediction,
106 | pixel_ground_truth_labels
107 | ):
108 | """
109 |     Computes the pixel-level anomaly segmentation AUROC.
110 | 
111 |     Args:
112 |         pixel_prediction: [list of np.array] per-image anomaly maps.
113 |                           Higher values indicate higher probability
114 |                           of a pixel being anomalous.
115 |         pixel_ground_truth_labels: [list of np.array] per-image ground-truth
116 |                           masks; any value > 0 is treated as anomalous.
117 | """
118 | pixel_prediction = np.concatenate(
119 | [pred.flatten() for pred in pixel_prediction], axis=0
120 | )
121 |
122 | pixel_ground_truth_labels = np.concatenate(
123 | [label.flatten() for label in pixel_ground_truth_labels], axis=0
124 | )
125 |
126 | pixel_ground_truth_labels[pixel_ground_truth_labels > 0] = 1
127 |
128 | pixel_auroc = metrics.roc_auc_score(
129 | pixel_ground_truth_labels, pixel_prediction
130 | )
131 |
132 | return {"pixel-auroc": pixel_auroc}
--------------------------------------------------------------------------------
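Usage note for utils/misc_helper.py above (a hedged sketch with dummy data; the array sizes and values are illustrative):

    import numpy as np
    from utils.misc_helper import AverageMeter, compute_imagewise_metrics, compute_pixelwise_metrics

    meter = AverageMeter(length=10)                     # sliding window over the last 10 values
    for _ in range(25):
        meter.update(float(np.random.rand()))
    print(meter.val, meter.avg)

    scores = np.random.rand(8)                          # higher score = more anomalous
    labels = np.array([0, 0, 1, 1, 0, 1, 0, 1])
    print(compute_imagewise_metrics(scores, labels))    # {'image-auroc': ...}

    pixel_preds = [np.random.rand(64, 64) for _ in range(4)]           # per-image anomaly maps
    pixel_gts = [np.random.randint(0, 2, (64, 64)) for _ in range(4)]  # per-image masks
    print(compute_pixelwise_metrics(pixel_preds, pixel_gts))           # {'pixel-auroc': ...}
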
/utils/losses.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from math import exp
7 |
8 |
9 | class FocalLoss(nn.Module):
10 | """
11 | copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py
 12 |     This is an implementation of Focal Loss with label smoothing, as proposed in
 13 |     'Focal Loss for Dense Object Detection' (https://arxiv.org/abs/1708.02002):
 14 |     Focal_Loss = -1 * alpha * (1 - pt)^gamma * log(pt)
 15 |     :param alpha: (tensor, list, or float) class-weighting factor for this criterion
 16 |     :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p > 0.5),
 17 |                   putting more focus on hard, misclassified examples
 18 |     :param smooth: (float,double) label-smoothing value applied to the one-hot targets
 19 |     :param balance_index: (int) balance class index; should be specified when alpha is a float
 20 |     :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
21 | """
22 |
23 | def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True):
24 | super(FocalLoss, self).__init__()
25 | self.apply_nonlin = apply_nonlin
26 | self.alpha = alpha
27 | self.gamma = gamma
28 | self.balance_index = balance_index
29 | self.smooth = smooth
30 | self.size_average = size_average
31 |
32 | if self.smooth is not None:
33 | if self.smooth < 0 or self.smooth > 1.0:
34 | raise ValueError('smooth value should be in [0,1]')
35 |
36 | def forward(self, logit, target):
37 | if self.apply_nonlin is not None:
38 | logit = self.apply_nonlin(logit)
39 | num_class = logit.shape[1]
40 |
41 | if logit.dim() > 2:
42 | # N,C,d1,d2 -> N,C,m (m=d1*d2*...)
43 | logit = logit.view(logit.size(0), logit.size(1), -1)
44 | logit = logit.permute(0, 2, 1).contiguous()
45 | logit = logit.view(-1, logit.size(-1))
46 | target = torch.squeeze(target, 1)
47 | target = target.view(-1, 1)
48 | alpha = self.alpha
49 |
50 | if alpha is None:
51 | alpha = torch.ones(num_class, 1)
52 | elif isinstance(alpha, (list, np.ndarray)):
53 | assert len(alpha) == num_class
54 | alpha = torch.FloatTensor(alpha).view(num_class, 1)
55 | alpha = alpha / alpha.sum()
56 | elif isinstance(alpha, float):
57 | alpha = torch.ones(num_class, 1)
58 | alpha = alpha * (1 - self.alpha)
59 | alpha[self.balance_index] = self.alpha
60 |
61 | else:
62 | raise TypeError('Not support alpha type')
63 |
64 | if alpha.device != logit.device:
65 | alpha = alpha.to(logit.device)
66 |
67 | idx = target.cpu().long()
68 |
69 | one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_()
70 | one_hot_key = one_hot_key.scatter_(1, idx, 1)
71 | if one_hot_key.device != logit.device:
72 | one_hot_key = one_hot_key.to(logit.device)
73 |
74 | if self.smooth:
75 | one_hot_key = torch.clamp(
76 | one_hot_key, self.smooth / (num_class - 1), 1.0 - self.smooth)
77 | pt = (one_hot_key * logit).sum(1) + self.smooth
78 | logpt = pt.log()
79 |
80 | gamma = self.gamma
81 |
82 | alpha = alpha[idx]
83 | alpha = torch.squeeze(alpha)
84 | loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt
85 |
86 | if self.size_average:
87 | loss = loss.mean()
88 | return loss
89 |
90 |
91 | class BinaryDiceLoss(nn.Module):
92 | def __init__(self):
93 | super(BinaryDiceLoss, self).__init__()
94 |
95 | def forward(self, input, targets):
96 | N = targets.size()[0]
97 | smooth = 1
98 | input_flat = input.view(N, -1)
99 | targets_flat = targets.view(N, -1)
100 |
101 | intersection = input_flat * targets_flat
102 | N_dice_eff = (2 * intersection.sum(1) + smooth) / (input_flat.sum(1) + targets_flat.sum(1) + smooth)
103 | loss = 1 - N_dice_eff.sum() / N
104 | return loss
105 |
106 |
107 | class CrossEntropyLoss(nn.Module):
108 |
109 | def __init__(self):
110 | super().__init__()
111 | self.criterion = nn.CrossEntropyLoss()
112 |
113 |     def forward(self, logit, target):
114 | 
115 |         bsz, _, h, w = logit.size()
116 |         logit = logit.view(bsz, 2, -1)
117 |         gt_mask = target.view(bsz, -1).long()
118 |         return self.criterion(logit, gt_mask)
119 |
--------------------------------------------------------------------------------
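Usage note for utils/losses.py above (a hedged sketch with dummy tensors; the shapes and the way the two losses are summed are illustrative, not necessarily how train.py combines them):

    import torch
    from utils.losses import FocalLoss, BinaryDiceLoss

    logits = torch.randn(2, 2, 32, 32)                   # N x 2 x H x W anomaly logits
    probs = torch.softmax(logits, dim=1)                 # FocalLoss below is fed probabilities
    mask = torch.randint(0, 2, (2, 1, 32, 32)).float()   # binary ground-truth mask

    focal = FocalLoss()
    dice = BinaryDiceLoss()

    # focal loss over both classes, dice loss over the "abnormal" channel
    loss = focal(probs, mask.long()) + dice(probs[:, 1:2, :, :], mask)
    print(loss.item())
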
/open_clip/factory.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | import pathlib
5 | import re
6 | import numpy as np
7 | from copy import deepcopy
8 | from pathlib import Path
9 | from typing import Any, Dict, Optional, Tuple, Union
10 | import torch
11 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
12 | from .model import CLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
13 | resize_pos_embed, get_cast_dtype
14 | from .openai import load_openai_model
15 | from .transform import image_transform
16 | from .tokenizer import tokenize
17 |
18 |
 19 | _MODEL_CONFIG_PATHS = [Path(__file__).parent / "model_configs/"]
 20 | 
 21 | _MODEL_CONFIGS = {}  # dictionary of (model_name: config) for model architecture configs
22 |
23 |
24 | def _natural_key(string_):
25 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
26 |
27 | def _rescan_model_configs():
28 | global _MODEL_CONFIGS
29 |
30 | config_ext = ('.json',)
31 | config_files = []
32 | for config_path in _MODEL_CONFIG_PATHS:
33 | if config_path.is_file() and config_path.suffix in config_ext:
34 | config_files.append(config_path)
35 | elif config_path.is_dir():
36 | for ext in config_ext:
37 | config_files.extend(config_path.glob(f'*{ext}'))
38 |
39 | for cf in config_files:
40 | with open(cf, 'r') as f:
41 | model_cfg = json.load(f)
42 | if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
43 | _MODEL_CONFIGS[cf.stem] = model_cfg
44 |
45 | _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}
46 |
47 |
48 | _rescan_model_configs() # initial populate of model config registry
49 |
50 |
51 | def list_models():
52 | """ enumerate available model architectures based on config files """
53 | return list(_MODEL_CONFIGS.keys())
54 |
55 |
56 | def add_model_config(path):
57 | """ add model config path or file and update registry """
58 | if not isinstance(path, Path):
59 | path = Path(path)
60 | _MODEL_CONFIG_PATHS.append(path)
61 | _rescan_model_configs()
62 |
63 | def get_model_config(model_name):
64 | if model_name in _MODEL_CONFIGS:
65 | return deepcopy(_MODEL_CONFIGS[model_name])
66 | else:
67 | return None
68 |
69 | def get_tokenizer():
70 | return tokenize
71 |
72 |
73 | def load_state_dict(checkpoint_path: str, map_location='cpu'):
74 | checkpoint = torch.load(checkpoint_path, map_location=map_location)
75 | if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
76 | state_dict = checkpoint['state_dict']
77 | else:
78 | state_dict = checkpoint
79 | if next(iter(state_dict.items()))[0].startswith('module'):
80 | state_dict = {k[7:]: v for k, v in state_dict.items()}
81 | return state_dict
82 |
83 |
84 | def load_checkpoint(model, checkpoint_path, strict=True):
85 | state_dict = load_state_dict(checkpoint_path)
86 | # detect old format and make compatible with new format
87 | if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
88 | state_dict = convert_to_custom_text_state_dict(state_dict)
89 | resize_pos_embed(state_dict, model)
90 | incompatible_keys = model.load_state_dict(state_dict, strict=strict)
91 | return incompatible_keys
92 |
93 |
94 |
95 | def create_model(
96 | model_name: str,
97 | img_size: int,
98 | precision: str = 'fp32',
99 | device: Union[str, torch.device] = 'cpu',
100 | ):
101 |
102 | model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
103 | pretrained_cfg = {}
104 | model_cfg = None
105 |
106 | if isinstance(device, str):
107 | device = torch.device(device)
108 |
109 | model_cfg = model_cfg or get_model_config(model_name)
110 |
111 | if model_cfg['vision_cfg']['image_size'] != img_size:
112 | model_cfg['vision_cfg']['image_size'] = img_size
113 |
114 | cast_dtype = get_cast_dtype(precision)
115 |
116 | model_pre = load_openai_model(
117 | model_name,
118 | precision=precision,
119 | device=device,
120 | )
121 |
122 | state_dict = model_pre.state_dict()
123 |
124 | model = CLIP(**model_cfg, cast_dtype=cast_dtype)
125 |
126 | ### for resnet
127 | if not hasattr(model.visual, 'grid_size'):
128 | model.visual.grid_size = int(np.sqrt(model.visual.attnpool.positional_embedding.shape[0] - 1))
129 |
130 | resize_pos_embed(state_dict, model)
131 | incompatible_keys = model.load_state_dict(state_dict, strict=True)
132 |
133 | model.to(device=device)
134 |
135 | if precision in ("fp16", "bf16"):
136 | convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16)
137 |
138 | # set image / mean metadata from pretrained_cfg if available, or use default
139 | model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
140 | model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD
141 |
142 | else:
143 | model = load_openai_model(
144 | model_name,
145 | precision=precision,
146 | device=device,
147 | )
148 |
149 | model.device=device
150 |
151 | return model,model_cfg
152 |
153 |
154 |
155 | def create_model_and_transforms(
156 | model_name: str,
157 | img_size: int,
158 | precision: str = 'fp32',
159 | device: Union[str, torch.device] = 'cpu',
160 | ):
161 |
162 | model,model_cfg = create_model(
163 | model_name,
164 | img_size,
165 | precision=precision,
166 | device=device,
167 | )
168 |
169 | image_mean = OPENAI_DATASET_MEAN
170 | image_std = OPENAI_DATASET_STD
171 |
172 | preprocess = image_transform(
173 | mean=image_mean,
174 | std=image_std,
175 | )
176 |
177 | return model, preprocess, model_cfg
--------------------------------------------------------------------------------
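Usage note for open_clip/factory.py above (a hedged sketch mirroring how test.py calls the factory; the model name and image size are illustrative, and the OpenAI weights are assumed to be downloadable or cached):

    import torch
    import open_clip

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model, preprocess, model_cfg = open_clip.create_model_and_transforms(
        'ViT-B-16',        # any architecture with a config under open_clip/model_configs
        224,               # img_size; the vision config is patched if it differs
        precision='fp32',
        device=device,
    )
    print(model_cfg['embed_dim'], model_cfg['vision_cfg']['image_size'])
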
/open_clip/tokenizer.py:
--------------------------------------------------------------------------------
1 | """ CLIP tokenizer
2 |
3 | Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4 | """
5 | import gzip
6 | import html
7 | import os
8 | from functools import lru_cache
9 | from typing import Union, List
10 |
11 | import ftfy
12 | import regex as re
13 | import torch
14 |
15 | # https://stackoverflow.com/q/62691279
16 | import os
17 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
18 |
19 |
20 | @lru_cache()
21 | def default_bpe():
22 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
23 |
24 |
25 | @lru_cache()
26 | def bytes_to_unicode():
27 | """
28 | Returns list of utf-8 byte and a corresponding list of unicode strings.
29 | The reversible bpe codes work on unicode strings.
30 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
31 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
32 | This is a significant percentage of your normal, say, 32K bpe vocab.
33 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
34 | And avoids mapping to whitespace/control characters the bpe code barfs on.
35 | """
36 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
37 | cs = bs[:]
38 | n = 0
39 | for b in range(2**8):
40 | if b not in bs:
41 | bs.append(b)
42 | cs.append(2**8+n)
43 | n += 1
44 | cs = [chr(n) for n in cs]
45 | return dict(zip(bs, cs))
46 |
47 |
48 | def get_pairs(word):
49 | """Return set of symbol pairs in a word.
50 | Word is represented as tuple of symbols (symbols being variable-length strings).
51 | """
52 | pairs = set()
53 | prev_char = word[0]
54 | for char in word[1:]:
55 | pairs.add((prev_char, char))
56 | prev_char = char
57 | return pairs
58 |
59 |
60 | def basic_clean(text):
61 | text = ftfy.fix_text(text)
62 | text = html.unescape(html.unescape(text))
63 | return text.strip()
64 |
65 | def whitespace_clean(text):
66 | text = re.sub(r'\s+', ' ', text)
67 | text = text.strip()
68 | return text
69 |
70 |
71 | class SimpleTokenizer(object):
72 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
73 | self.byte_encoder = bytes_to_unicode()
74 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
75 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
76 | merges = merges[1:49152-256-2+1]
77 | merges = [tuple(merge.split()) for merge in merges]
78 | vocab = list(bytes_to_unicode().values())
79 |
 80 |         vocab = vocab + [v+'</w>' for v in vocab]
81 | for merge in merges:
82 | vocab.append(''.join(merge))
83 |
 84 |         if not special_tokens:
 85 |             special_tokens = ['<start_of_text>', '<end_of_text>']
 86 |         else:
 87 |             special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
88 | vocab.extend(special_tokens)
89 | self.encoder = dict(zip(vocab, range(len(vocab))))
90 | self.decoder = {v: k for k, v in self.encoder.items()}
91 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
92 | self.cache = {t:t for t in special_tokens}
93 | special = "|".join(special_tokens)
94 | self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
95 |
96 | self.vocab_size = len(self.encoder)
97 | self.all_special_ids = [self.encoder[t] for t in special_tokens]
98 |
99 | def bpe(self, token):
100 | if token in self.cache:
101 | return self.cache[token]
102 |         word = tuple(token[:-1]) + ( token[-1] + '</w>',)
103 | pairs = get_pairs(word)
104 |
105 | if not pairs:
106 |             return token+'</w>'
107 |
108 | while True:
109 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
110 | if bigram not in self.bpe_ranks:
111 | break
112 | first, second = bigram
113 | new_word = []
114 | i = 0
115 | while i < len(word):
116 | try:
117 | j = word.index(first, i)
118 | new_word.extend(word[i:j])
119 | i = j
120 | except:
121 | new_word.extend(word[i:])
122 | break
123 |
124 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
125 | new_word.append(first+second)
126 | i += 2
127 | else:
128 | new_word.append(word[i])
129 | i += 1
130 | new_word = tuple(new_word)
131 | word = new_word
132 | if len(word) == 1:
133 | break
134 | else:
135 | pairs = get_pairs(word)
136 | word = ' '.join(word)
137 | self.cache[token] = word
138 | return word
139 |
140 | def encode(self, text):
141 | bpe_tokens = []
142 | text = whitespace_clean(basic_clean(text)).lower()
143 | for token in re.findall(self.pat, text):
144 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
145 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
146 | return bpe_tokens
147 |
148 | def decode(self, tokens):
149 | text = ''.join([self.decoder[token] for token in tokens])
150 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
151 | return text
152 |
153 |
154 | _tokenizer = SimpleTokenizer()
155 |
156 | def decode(output_ids: torch.Tensor):
157 | output_ids = output_ids.cpu().numpy()
158 | return _tokenizer.decode(output_ids)
159 |
160 |
161 |
162 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
163 | """
164 | Returns the tokenized representation of given input string(s)
165 |
166 | Parameters
167 | ----------
168 | texts : Union[str, List[str]]
169 | An input string or a list of input strings to tokenize
170 | context_length : int
171 | The context length to use; all CLIP models use 77 as the context length
172 |
173 | Returns
174 | -------
175 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
176 | """
177 | if isinstance(texts, str):
178 | texts = [texts]
179 |
180 |     sot_token = _tokenizer.encoder["<start_of_text>"]
181 |     eot_token = _tokenizer.encoder["<end_of_text>"]
182 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
183 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
184 |
185 | for i, tokens in enumerate(all_tokens):
186 | if len(tokens) > context_length:
187 | tokens = tokens[:context_length] # Truncate
188 | tokens[-1] = eot_token
189 | result[i, :len(tokens)] = torch.tensor(tokens)
190 | return result
191 |
192 |
193 |
--------------------------------------------------------------------------------
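Usage note for open_clip/tokenizer.py above (a hedged sketch; the prompt strings are illustrative):

    from open_clip.tokenizer import tokenize, decode

    tokens = tokenize(['a photo of a normal chest x-ray',
                       'a photo of an abnormal chest x-ray'])
    print(tokens.shape)        # (2, 77): one row of token ids per prompt
    print(decode(tokens[0]))   # decodes back to text; zero padding decodes to filler characters
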
/open_clip/modified_resnet.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 |
7 | from open_clip.utils import freeze_batch_norm_2d
8 |
9 |
10 | class Bottleneck(nn.Module):
11 | expansion = 4
12 |
13 | def __init__(self, inplanes, planes, stride=1):
14 | super().__init__()
15 |
16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18 | self.bn1 = nn.BatchNorm2d(planes)
19 | self.act1 = nn.ReLU(inplace=True)
20 |
21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
22 | self.bn2 = nn.BatchNorm2d(planes)
23 | self.act2 = nn.ReLU(inplace=True)
24 |
25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
26 |
27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
29 | self.act3 = nn.ReLU(inplace=True)
30 |
31 | self.downsample = None
32 | self.stride = stride
33 |
34 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
36 | self.downsample = nn.Sequential(OrderedDict([
37 | ("-1", nn.AvgPool2d(stride)),
38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
39 | ("1", nn.BatchNorm2d(planes * self.expansion))
40 | ]))
41 |
42 | def forward(self, x: torch.Tensor):
43 | identity = x
44 |
45 | out = self.act1(self.bn1(self.conv1(x)))
46 | out = self.act2(self.bn2(self.conv2(out)))
47 | out = self.avgpool(out)
48 | out = self.bn3(self.conv3(out))
49 |
50 | if self.downsample is not None:
51 | identity = self.downsample(x)
52 |
53 | out += identity
54 | out = self.act3(out)
55 | return out
56 |
57 |
58 | class AttentionPool2d(nn.Module):
59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
60 | super().__init__()
61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
62 | self.k_proj = nn.Linear(embed_dim, embed_dim)
63 | self.q_proj = nn.Linear(embed_dim, embed_dim)
64 | self.v_proj = nn.Linear(embed_dim, embed_dim)
65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
66 | self.num_heads = num_heads
67 |
68 | def forward(self, x):
69 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
72 | x, _ = F.multi_head_attention_forward(
73 | query=x, key=x, value=x,
74 | embed_dim_to_check=x.shape[-1],
75 | num_heads=self.num_heads,
76 | q_proj_weight=self.q_proj.weight,
77 | k_proj_weight=self.k_proj.weight,
78 | v_proj_weight=self.v_proj.weight,
79 | in_proj_weight=None,
80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
81 | bias_k=None,
82 | bias_v=None,
83 | add_zero_attn=False,
84 | dropout_p=0.,
85 | out_proj_weight=self.c_proj.weight,
86 | out_proj_bias=self.c_proj.bias,
87 | use_separate_proj_weight=True,
88 | training=self.training,
89 | need_weights=False
90 | )
91 |
92 | return x[0]
93 |
94 |
95 | class ModifiedResNet(nn.Module):
96 | """
97 | A ResNet class that is similar to torchvision's but contains the following changes:
98 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
99 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
100 | - The final pooling layer is a QKV attention instead of an average pool
101 | """
102 |
103 | def __init__(self, layers, output_dim, heads, image_size=224, width=64):
104 | super().__init__()
105 | self.output_dim = output_dim
106 | self.image_size = image_size
107 |
108 | # the 3-layer stem
109 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
110 | self.bn1 = nn.BatchNorm2d(width // 2)
111 | self.act1 = nn.ReLU(inplace=True)
112 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
113 | self.bn2 = nn.BatchNorm2d(width // 2)
114 | self.act2 = nn.ReLU(inplace=True)
115 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
116 | self.bn3 = nn.BatchNorm2d(width)
117 | self.act3 = nn.ReLU(inplace=True)
118 | self.avgpool = nn.AvgPool2d(2)
119 |
120 | # residual layers
121 | self._inplanes = width # this is a *mutable* variable used during construction
122 | self.layer1 = self._make_layer(width, layers[0])
123 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
124 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
125 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
126 |
127 | embed_dim = width * 32 # the ResNet feature dimension
128 | self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
129 |
130 | self.init_parameters()
131 |
132 | def _make_layer(self, planes, blocks, stride=1):
133 | layers = [Bottleneck(self._inplanes, planes, stride)]
134 |
135 | self._inplanes = planes * Bottleneck.expansion
136 | for _ in range(1, blocks):
137 | layers.append(Bottleneck(self._inplanes, planes))
138 |
139 | return nn.Sequential(*layers)
140 |
141 | def init_parameters(self):
142 | if self.attnpool is not None:
143 | std = self.attnpool.c_proj.in_features ** -0.5
144 | nn.init.normal_(self.attnpool.q_proj.weight, std=std)
145 | nn.init.normal_(self.attnpool.k_proj.weight, std=std)
146 | nn.init.normal_(self.attnpool.v_proj.weight, std=std)
147 | nn.init.normal_(self.attnpool.c_proj.weight, std=std)
148 |
149 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
150 | for name, param in resnet_block.named_parameters():
151 | if name.endswith("bn3.weight"):
152 | nn.init.zeros_(param)
153 |
154 | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
155 | assert unlocked_groups == 0, 'partial locking not currently supported for this model'
156 | for param in self.parameters():
157 | param.requires_grad = False
158 | if freeze_bn_stats:
159 | freeze_batch_norm_2d(self)
160 |
161 | @torch.jit.ignore
162 | def set_grad_checkpointing(self, enable=True):
163 | # FIXME support for non-transformer
164 | pass
165 |
166 | def stem(self, x):
167 | x = self.act1(self.bn1(self.conv1(x)))
168 | x = self.act2(self.bn2(self.conv2(x)))
169 | x = self.act3(self.bn3(self.conv3(x)))
170 | x = self.avgpool(x)
171 | return x
172 |
173 | def forward(self, x, out_blocks):
174 | x = self.stem(x)
175 | x_1 = self.layer1(x)
176 | x_2 = self.layer2(x_1)
177 | x_3 = self.layer3(x_2)
178 | x_4 = self.layer4(x_3)
179 | x = self.attnpool(x_4)
180 |
181 | out_tokens = []
182 | x_blocks = [x_1, x_2, x_3, x_4]
183 | for i in out_blocks:
184 | out_tokens.append(x_blocks[i - 1])
185 |
186 | return x, out_tokens
187 |
--------------------------------------------------------------------------------
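Usage note for open_clip/modified_resnet.py above (a hedged sketch; the layer counts follow the RN50 configuration and the input shape is illustrative):

    import torch
    from open_clip.modified_resnet import ModifiedResNet

    resnet = ModifiedResNet(layers=(3, 4, 6, 3), output_dim=1024, heads=32,
                            image_size=224, width=64)
    x = torch.randn(1, 3, 224, 224)
    pooled, block_tokens = resnet(x, out_blocks=[1, 2, 3, 4])
    print(pooled.shape)                       # (1, 1024): attention-pooled image embedding
    print([t.shape for t in block_tokens])    # feature maps from layer1..layer4
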
/models/CoOp.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from torch.nn import functional as F
7 | from torch.cuda.amp import GradScaler, autocast
8 | from open_clip.tokenizer import SimpleTokenizer,tokenize
9 |
10 |
11 | class TextEncoder(nn.Module):
12 | def __init__(self, clip_model):
13 |
14 | super().__init__()
15 |
16 | self.transformer = clip_model.transformer
17 | self.positional_embedding = clip_model.positional_embedding
18 | self.ln_final = clip_model.ln_final
19 | self.text_projection = clip_model.text_projection
20 |
21 |
22 | def forward(self, prompts, tokenized_prompts):
23 |
24 | x = prompts + self.positional_embedding
25 | x = x.permute(1, 0, 2) # NLD -> LND
26 | x,_,_ = self.transformer(x)
27 | x = x.permute(1, 0, 2) # LND -> NLD
28 | x = self.ln_final(x)
29 | x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
30 | return x
31 |
32 |
33 |
34 | class PromptLearner(nn.Module):
35 | def __init__(self,
36 | prompts,
 37 |                  n_ctx, # number of learnable context tokens
 38 |                  CSC, # class-specific context: one set of context vectors per prompt if True
 39 |                  class_token_position, # position of the class tokens: 'end', 'middle', or 'front'
40 | clip_model):
41 |
42 | super().__init__()
43 |
 44 |         ctx_dim = clip_model.ln_final.weight.shape[0]  # width of the text transformer
45 |
46 | self.ctx={}
47 |
48 | for cls in prompts:
49 | for position in class_token_position:
50 | if CSC:
51 | ctx_vectors = torch.empty(len(prompts[cls]), n_ctx, ctx_dim).to(clip_model.device)
52 | else:
53 | ctx_vectors = torch.empty(n_ctx, ctx_dim).to(clip_model.device)
54 | nn.init.normal_(ctx_vectors, std=0.02)
55 | self.ctx['{}_{}'.format(cls,position)]=nn.Parameter(ctx_vectors,requires_grad=True)
56 |
57 | self.ctx = nn.ParameterDict(self.ctx) # to be optimized
58 |
59 | prompt_prefix = " ".join(["X"] * n_ctx)
60 |
61 | _tokenizer = SimpleTokenizer()
62 |
63 | prompts_split={cls: [prompt.replace("_", " ") for prompt in prompts[cls]] for cls in prompts}
64 |
65 | prompts_lens= {cls: [ len(_tokenizer.encode(prompt)) for prompt in prompts_split[cls]] for cls in prompts_split}
66 |
67 | prompts_learnable_tokens = {cls:[prompt_prefix + " " + prompt + "." for prompt in prompts_split[cls]] for cls in prompts_split}
68 |
69 | tokenized_prompts = {cls:torch.cat([tokenize(prompt) for prompt in prompts_learnable_tokens[cls]]).to(clip_model.device) for cls in prompts_learnable_tokens}
70 |
71 | with torch.no_grad():
72 | embeddings = {cls:clip_model.token_embedding(tokenized_prompts[cls]) for cls in tokenized_prompts}
73 |
74 | self.register_embeddings={}
75 |
76 | for cls in embeddings:
77 | self.register_embeddings['{}_token_prefix'.format(cls)]=embeddings[cls][:, :1, :]
78 | self.register_embeddings['{}_token_suffix'.format(cls)]=embeddings[cls][:, 1 + n_ctx :, :]
79 |
80 | self.n_ctx = n_ctx
81 | self.tokenized_prompts = tokenized_prompts
82 | self.prompts_lens = prompts_lens
83 | self.class_token_position = class_token_position
84 |
85 |
86 | def forward(self):
87 | cls_prompts={}
88 |
89 | for cls in self.tokenized_prompts:
90 |
91 | prefix = self.register_embeddings['{}_token_prefix'.format(cls)]
92 | suffix = self.register_embeddings['{}_token_suffix'.format(cls)]
93 |
94 | cls_prompts[cls]=[]
95 |
96 | for position in self.class_token_position:
97 |
98 | ctx = self.ctx['{}_{}'.format(cls,position)]
99 | if ctx.dim() == 2:
100 | ctx = ctx.unsqueeze(0).expand(len(self.prompts_lens[cls]), -1, -1)
101 |
102 | if position == "end":
103 | prompts = torch.cat(
104 | [
105 | prefix, # (n_cls, 1, dim)
106 | ctx, # (n_cls, n_ctx, dim)
107 | suffix, # (n_cls, *, dim)
108 | ],
109 | dim=1,
110 | )
111 |
112 | elif position == "middle":
113 |
114 | half_n_ctx = self.n_ctx // 2
115 | prompts = []
116 |
117 | for i in range(len(self.prompts_lens[cls])):
118 | p_len = self.prompts_lens[cls][i]
119 |
120 | prefix_i = prefix[i : i + 1, :, :]
121 | class_i = suffix[i : i + 1, :p_len, :]
122 | suffix_i = suffix[i : i + 1, p_len:, :]
123 | ctx_i_half1 = ctx[i : i + 1, :half_n_ctx, :]
124 | ctx_i_half2 = ctx[i : i + 1, half_n_ctx:, :]
125 |
126 | prompt = torch.cat(
127 | [
128 | prefix_i, # (1, 1, dim)
129 | ctx_i_half1, # (1, n_ctx//2, dim)
130 | class_i, # (1, name_len, dim)
131 | ctx_i_half2, # (1, n_ctx//2, dim)
132 | suffix_i, # (1, *, dim)
133 | ],
134 | dim=1,
135 | )
136 | prompts.append(prompt)
137 | prompts = torch.cat(prompts, dim=0)
138 |
139 | else :
140 | assert position == "front"
141 | prompts = []
142 |
143 | for i in range(len(self.prompts_lens[cls])):
144 | p_len = self.prompts_lens[cls][i]
145 |
146 | prefix_i = prefix[i : i + 1, :, :]
147 | class_i = suffix[i : i + 1, :p_len, :]
148 | suffix_i = suffix[i : i + 1, p_len:, :]
149 | ctx_i = ctx[i : i + 1, :, :]
150 | prompt = torch.cat(
151 | [
152 | prefix_i, # (1, 1, dim)
153 | class_i, # (1, name_len, dim)
154 | ctx_i, # (1, n_ctx, dim)
155 | suffix_i, # (1, *, dim)
156 | ],
157 | dim=1,
158 | )
159 | prompts.append(prompt)
160 |
161 | prompts = torch.cat(prompts, dim=0)
162 |
163 | cls_prompts[cls].append(prompts)
164 | cls_prompts[cls]=torch.cat(cls_prompts[cls],dim=0)
165 | return cls_prompts
166 |
167 |
168 | class PromptMaker(nn.Module):
169 |
170 | def __init__(self,
171 | prompts,
172 | clip_model,
173 |                  n_ctx: int=8, # number of learnable context tokens
174 |                  CSC: bool= True, # class-specific context: one set of context vectors per prompt if True
175 |                  class_token_position: list=['end'], # position of the class tokens: 'end', 'middle', or 'front'
176 | ):
177 |
178 | super().__init__()
179 | assert 'normal' in prompts and 'abnormal' in prompts
180 |
181 | for position in class_token_position:
182 | assert position in ['end','middle','front']
183 |
184 | self.prompt_learner = PromptLearner(prompts, n_ctx, CSC, class_token_position, clip_model)
185 | self.tokenized_prompts = self.prompt_learner.tokenized_prompts
186 |
187 | self.class_token_position = class_token_position
188 | self.text_encoder = TextEncoder(clip_model)
189 |
190 | def forward(self, image_features):
191 | prompts = self.prompt_learner()
192 | tokenized_prompts = self.tokenized_prompts
193 | text_features=[]
194 |
195 | for cls in prompts:
196 | class_embedding = self.text_encoder(prompts[cls], tokenized_prompts[cls].repeat(len(self.class_token_position),1))
197 | class_embedding = class_embedding.mean(dim=0)
198 | class_embedding = class_embedding / class_embedding.norm()
199 | text_features.append(class_embedding)
200 | text_features = torch.stack(text_features, dim=1)
201 | return text_features
--------------------------------------------------------------------------------
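Usage note for models/CoOp.py above (a hedged sketch; clip_model is assumed to be the fp32 CLIP model returned by open_clip.create_model_and_transforms in this repo, and the prompt texts are purely illustrative):

    from models.CoOp import PromptMaker

    prompts = {
        'normal':   ['a clean medical image', 'a healthy scan'],
        'abnormal': ['a medical image with a lesion', 'a scan with an abnormality'],
    }
    prompt_maker = PromptMaker(prompts=prompts, clip_model=clip_model,
                               n_ctx=8, CSC=True,
                               class_token_position=['end']).to(clip_model.device)
    text_features = prompt_maker(None)    # the image-feature argument is unused by this forward
    print(text_features.shape)            # (embed_dim, 2): normal vs. abnormal text directions
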
/datasets/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | from enum import Enum
4 | import PIL
5 | import torch
6 | from torchvision import transforms
7 | import json
8 | from PIL import Image
9 | import numpy as np
10 |
11 | from medsyn.tasks import CutPastePatchBlender,\
12 | SmoothIntensityChangeTask,\
13 | GaussIntensityChangeTask,\
14 | SinkDeformationTask,\
15 | SourceDeformationTask,\
16 | IdentityTask
17 |
18 |
19 | class TrainDataset(torch.utils.data.Dataset):
20 |
21 | def __init__(
22 | self,
23 | args,
24 | source,
25 | preprocess,
26 | k_shot = -1,
27 | **kwargs,
28 | ):
29 |
30 | super().__init__()
31 | self.args = args
32 | self.source = source
33 | self.k_shot = k_shot
34 | self.transform_img = preprocess
35 | self.data_to_iterate = self.get_image_data()
36 | self.augs,self.augs_pro = self.load_anomaly_syn()
37 | assert sum(self.augs_pro)==1.0
38 |
39 | def __getitem__(self, idx):
40 | info = self.data_to_iterate[idx]
41 | image_path=os.path.join(self.source,'images',info['filename'])
42 | image = self.read_image(image_path)
43 | choice_aug = np.random.choice(a=[aug for aug in self.augs],
44 | p = [pro for pro in self.augs_pro],
45 | size=(1,), replace=False)[0]
46 | image, mask = choice_aug(image)
47 | image = Image.fromarray(image.astype(np.uint8)).convert('RGB')
48 | image = self.transform_img(image)
49 | mask = torch.from_numpy(mask)
50 | return {
51 | "image": image,
52 | "mask" : mask,
53 | }
54 |
55 | def __len__(self):
56 | return len(self.data_to_iterate)
57 |
58 | def read_image(self,path):
59 | image = PIL.Image.open(path).resize((self.args.image_size,self.args.image_size),
60 | PIL.Image.Resampling.BILINEAR).convert("L")
61 | image = np.array(image).astype(np.uint8)
62 | return image
63 |
64 | def get_image_data(self):
65 | data_to_iterate = []
66 | with open(os.path.join(self.source,'samples',"train.json"), "r") as f_r:
67 | for line in f_r:
68 | meta = json.loads(line)
69 | data_to_iterate.append(meta)
70 | if self.k_shot != -1:
71 | data_to_iterate = random.sample(
72 | data_to_iterate, self.k_shot
73 | )
74 | return data_to_iterate
75 |
76 |
77 | def load_anomaly_syn(self):
78 | tasks = []
79 | task_probability = []
80 | for task_name in self.args.anomaly_tasks.keys():
81 | if task_name =='CutpasteTask':
82 | support_images = [self.read_image(os.path.join(self.source,'images',data['filename'])) for data in self.data_to_iterate]
83 | task = CutPastePatchBlender(support_images)
84 | elif task_name == 'SmoothIntensityTask':
85 | task = SmoothIntensityChangeTask(30.0)
86 | elif task_name == 'GaussIntensityChangeTask':
87 | task = GaussIntensityChangeTask()
88 | elif task_name == 'SinkTask':
89 | task = SinkDeformationTask()
90 | elif task_name == 'SourceTask':
91 | task = SourceDeformationTask()
92 | elif task_name == 'IdentityTask':
93 | task = IdentityTask()
94 | else:
 95 |                 raise NotImplementedError("task must be in [CutpasteTask, "
 96 |                                           "SmoothIntensityTask, "
 97 |                                           "GaussIntensityChangeTask, "
98 | "SinkTask, SourceTask, IdentityTask]")
99 |
100 | tasks.append(task)
101 | task_probability.append(self.args.anomaly_tasks[task_name])
102 | return tasks, task_probability
103 |
104 |
105 |
106 | class ChexpertTestDataset(torch.utils.data.Dataset):
107 |
108 | def __init__(
109 | self,
110 | args,
111 | source,
112 | preprocess,
113 | **kwargs,
114 | ):
115 | super().__init__()
116 | self.args = args
117 | self.source = source
118 |
119 | self.transform_img = preprocess
120 | self.data_to_iterate = self.get_image_data()
121 |
122 |
123 | def __getitem__(self, idx):
124 | info = self.data_to_iterate[idx]
125 | image_path = os.path.join(self.source,'images',info['filename'])
126 | image = PIL.Image.open(image_path).convert("RGB").resize((self.args.image_size,self.args.image_size),PIL.Image.Resampling.BILINEAR)
127 |         mask = np.zeros((self.args.image_size,self.args.image_size)).astype(float)
128 | image = self.transform_img(image)
129 | mask = torch.from_numpy(mask)
130 |
131 | return {
132 | "image": image,
133 | "mask" : mask,
134 | "classname": info['clsname'],
135 | "is_anomaly": info['label'],
136 | "image_path": image_path,
137 | }
138 |
139 | def __len__(self):
140 | return len(self.data_to_iterate)
141 |
142 | def get_image_data(self):
143 | data_to_iterate = []
144 | with open(os.path.join(self.source,'samples',"test.json"), "r") as f_r:
145 | for line in f_r:
146 | meta = json.loads(line)
147 | data_to_iterate.append(meta)
148 | return data_to_iterate
149 |
150 |
151 | class BrainMRITestDataset(torch.utils.data.Dataset):
152 |
153 | def __init__(
154 | self,
155 | args,
156 | source,
157 | preprocess,
158 | **kwargs,
159 | ):
160 |
161 | super().__init__()
162 | self.args = args
163 | self.source = source
164 | self.transform_img = preprocess
165 | self.data_to_iterate = self.get_image_data()
166 |
167 |
168 | def __getitem__(self, idx):
169 | info = self.data_to_iterate[idx]
170 | image_path = os.path.join(self.source,'images',info['filename'])
171 | image = PIL.Image.open(image_path).convert("RGB").resize((self.args.image_size,self.args.image_size),PIL.Image.Resampling.BILINEAR)
172 |         mask = np.zeros((self.args.image_size,self.args.image_size)).astype(float)
173 | image = self.transform_img(image)
174 | mask = torch.from_numpy(mask)
175 |
176 | return {
177 | "image": image,
178 | "mask" : mask,
179 | "classname": info['clsname'],
180 | "is_anomaly": info['label'],
181 | "image_path": image_path,
182 | }
183 |
184 | def __len__(self):
185 | return len(self.data_to_iterate)
186 |
187 |
188 | def get_image_data(self):
189 | data_to_iterate = []
190 | with open(os.path.join(self.source,'samples',"test.json"), "r") as f_r:
191 | for line in f_r:
192 | meta = json.loads(line)
193 | data_to_iterate.append(meta)
194 | return data_to_iterate
195 |
196 |
197 |
198 | class BusiTestDataset(torch.utils.data.Dataset):
199 |
200 | def __init__(
201 | self,
202 | args,
203 | source,
204 | preprocess,
205 | **kwargs,
206 | ):
207 |
208 | super().__init__()
209 | self.args = args
210 | self.source = source
211 | self.transform_img = preprocess
212 | self.data_to_iterate = self.get_image_data()
213 |
214 |
215 | def __getitem__(self, idx):
216 | info = self.data_to_iterate[idx]
217 | image_path = os.path.join(self.source,'images',info['filename'])
218 | image = PIL.Image.open(image_path).convert("RGB").resize((self.args.image_size,self.args.image_size),PIL.Image.Resampling.BILINEAR)
219 |
220 | if info.get("mask", None):
221 | mask = os.path.join(self.source,'images',info['mask'])
222 | mask = PIL.Image.open(mask).convert("L").resize((self.args.image_size,self.args.image_size),PIL.Image.Resampling.NEAREST)
223 |             mask = np.array(mask).astype(float)/255.0
224 |             mask[mask != 0.0] = 1.0
225 | else:
226 |             mask = np.zeros((self.args.image_size,self.args.image_size)).astype(float)
227 |
228 | image = self.transform_img(image)
229 | mask = torch.from_numpy(mask)
230 |
231 | return {
232 | "image": image,
233 | "mask": mask,
234 | "classname": info['clsname'],
235 | "is_anomaly": info['label'],
236 | "image_path": image_path,
237 | }
238 |
239 | def __len__(self):
240 | return len(self.data_to_iterate)
241 |
242 | def get_image_data(self):
243 | data_to_iterate = []
244 | with open(os.path.join(self.source,'samples',"test.json"), "r") as f_r:
245 | for line in f_r:
246 | meta = json.loads(line)
247 | data_to_iterate.append(meta)
248 | return data_to_iterate
--------------------------------------------------------------------------------
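Usage note for datasets/dataset.py above (a hedged sketch of the fields TrainDataset reads from its args; the values are illustrative and would normally come from a config/*.yaml file, and the source directory is expected to contain images/ and samples/train.json as used in get_image_data):

    from easydict import EasyDict
    from datasets.dataset import TrainDataset

    args = EasyDict({
        'image_size': 224,
        # task name -> sampling probability; the probabilities must sum to 1.0
        'anomaly_tasks': {
            'CutpasteTask': 0.5,
            'GaussIntensityChangeTask': 0.5,
        },
    })
    # preprocess: the image transform returned by open_clip.create_model_and_transforms
    # dataset = TrainDataset(args=args, source='data/brainmri', preprocess=preprocess, k_shot=16)
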
/test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import open_clip
3 | import torch
4 | import yaml
5 | from easydict import EasyDict
6 | from models.Necker import Necker
7 | from models.Adapter import Adapter
8 | import math
9 | import argparse
10 | import warnings
11 | from utils.misc_helper import *
12 | from datasets.dataset import ChexpertTestDataset,BusiTestDataset,BrainMRITestDataset
13 | from torch.utils.data import DataLoader
14 | from models.MapMaker import MapMaker
15 | import pprint
16 | import torch.nn.functional as F
17 | from pytorch_grad_cam.utils.image import show_cam_on_image
18 | from tqdm import tqdm
19 | from PIL import Image
20 | import cv2
21 |
22 | warnings.filterwarnings('ignore')
23 |
24 | def normalization(segmentations, image_size, avgpool_size = 128):
25 |
26 | segmentations = torch.tensor(segmentations[:, None, ...]).cuda() # N x 1 x H x W
27 | segmentations = F.interpolate(segmentations,(image_size, image_size), mode='bilinear', align_corners=True)
28 |
29 | segmentations_ = F.avg_pool2d(segmentations, (avgpool_size,avgpool_size), stride=1).cpu().numpy()
30 |
31 | min_scores = (
32 | segmentations_.reshape(-1).min(axis=-1).reshape(1)
33 | )
34 |
35 | max_scores = (
36 | segmentations_.reshape(-1).max(axis=-1).reshape(1)
37 | )
38 |
39 | segmentations = segmentations.squeeze(1).cpu().numpy()
40 | segmentations = (segmentations - min_scores) / (max_scores - min_scores)
41 | segmentations = np.clip(segmentations,a_min=0,a_max=1)
42 |
43 | segmentations = cv2.GaussianBlur(segmentations, (5, 5), 0)
44 | return segmentations
45 |
46 |
47 | @torch.no_grad()
 48 | def make_vision_tokens_info(model,model_cfg,layers_out):
49 |
50 | img = torch.ones((1,3,model_cfg['vision_cfg']['image_size'],
51 | model_cfg['vision_cfg']['image_size'])).to(model.device)
52 |
53 | img_feature,tokens = model.encode_image(img,layers_out)
54 |
55 | if len(tokens[0].shape)==3:
56 | model.token_size= [int(math.sqrt(token.shape[1]-1)) for token in tokens]
57 | model.token_c= [token.shape[-1] for token in tokens]
58 | else:
59 | model.token_size = [token.shape[2] for token in tokens]
60 | model.token_c = [token.shape[1] for token in tokens]
61 |
62 | model.embed_dim = model_cfg['embed_dim']
63 | print("model token size is {}".format(model.token_size)," model token dim is {}".format(model.token_c))
64 |
65 |
66 | @torch.no_grad()
67 | def main(args):
68 |
69 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
70 |
71 | with open(args.config_path) as f:
72 | args.config = EasyDict(yaml.load(f, Loader=yaml.FullLoader))
73 |
74 | set_seed(seed=args.config.random_seed)
75 |
76 | model, preprocess, model_cfg = open_clip.create_model_and_transforms(args.config.model_name, args.config.image_size, device=device)
77 |
78 | for param in model.parameters():
79 | param.requires_grad_(False)
80 |
81 | args.config.model_cfg = model_cfg
82 |
 83 |     make_vision_tokens_info(model,
84 | args.config.model_cfg,
85 | args.config.layers_out)
86 |
87 | necker = Necker(clip_model=model).to(model.device)
88 | adapter = Adapter(clip_model=model,target=args.config.model_cfg['embed_dim']).to(model.device)
89 |
90 | if args.config.prompt_maker=='coop':
91 | from models.CoOp import PromptMaker
92 | else:
 93 |         raise NotImplementedError("type of prompt must be in ['coop']")
94 |
95 | prompt_maker = PromptMaker(
96 | prompts=args.config.prompts,
97 | clip_model=model,
98 | n_ctx= args.config.n_learnable_token,
99 | CSC = args.config.CSC,
100 | class_token_position=args.config.class_token_positions,
101 | ).to(model.device)
102 |
103 | map_maker = MapMaker(image_size=args.config.image_size).to(model.device)
104 |
105 |
106 | checkpoints = torch.load(args.checkpoint_path,map_location=map_func)
107 | adapter.load_state_dict(checkpoints['adapter_state_dict'])
108 | prompt_maker.prompt_learner.load_state_dict(checkpoints['prompt_state_dict'])
109 | prompt_maker.prompt_learner.eval()
110 | adapter.eval()
111 |
112 | for test_dataset_name in args.config.test_datasets:
113 |
114 | if test_dataset_name == 'chexpert':
115 |
116 | test_dataset = ChexpertTestDataset( args=args.config,
117 | source=os.path.join(args.config.data_root,test_dataset_name),
118 | preprocess=preprocess,
119 | )
120 |
121 | elif test_dataset_name =='brainmri':
122 |
123 | test_dataset = BrainMRITestDataset(
124 | args=args.config,
125 | source=os.path.join(args.config.data_root,test_dataset_name),
126 | preprocess=preprocess,
127 | )
128 | elif test_dataset_name =='busi':
129 |
130 | test_dataset = BusiTestDataset(
131 | args=args.config,
132 | source=os.path.join(args.config.data_root,test_dataset_name),
133 | preprocess=preprocess)
134 | else:
135 |             raise NotImplementedError("dataset must be in ['chexpert','busi','brainmri']")
136 |
137 | test_dataloader = DataLoader(test_dataset, batch_size=args.config.batch_size,num_workers=2)
138 | results = validate(args,test_dataset_name,test_dataloader,model,necker,adapter,prompt_maker,map_maker)
139 |
140 | if test_dataset_name!='busi':
141 | print("{}, image auroc: {:.4f}".format(test_dataset_name, results["image-auroc"]))
142 | else:
143 | print("{}, image auroc: {:.4f}, pixel_auroc: {:.4f}".format(test_dataset_name, results["image-auroc"],results['pixel-auroc']))
144 |
145 |
146 | def validate(args, dataset_name, test_dataloader, clip_model, necker, adapter, prompt_maker, map_maker):
147 |
148 | image_preds = []
149 | image_gts= []
150 |
151 | pixel_preds = []
152 | pixel_gts = []
153 |
154 | image_paths = []
155 |
156 | for i, input in enumerate(test_dataloader):
157 |
158 | images = input['image'].to(clip_model.device)
159 | image_paths.extend(input['image_path'])
160 |
161 | _, image_tokens = clip_model.encode_image(images, out_layers=args.config.layers_out)
162 | image_features = necker(image_tokens)
163 | vision_adapter_features = adapter(image_features)
164 |         prompt_adapter_features = prompt_maker(vision_adapter_features)
165 |         anomaly_map = map_maker(vision_adapter_features, prompt_adapter_features)
166 |
167 | B, _, H, W = anomaly_map.shape
168 | anomaly_map = anomaly_map[:,1,:,:]
169 |
170 | pixel_preds.append(anomaly_map)
171 | anomaly_score,_ =torch.max(anomaly_map.view((B,H*W)), dim=-1)
172 |
173 | image_preds.extend(anomaly_score.cpu().numpy().tolist())
174 | image_gts.extend(input['is_anomaly'].cpu().numpy().tolist())
175 |
176 | if dataset_name=='busi':
177 | pixel_gts.append(input['mask'].cpu().numpy())
178 |
179 | pixel_preds_np = [pixel_pred.cpu().numpy() for pixel_pred in pixel_preds]
180 | pixel_preds = normalization(torch.cat(pixel_preds,dim=0), args.config.image_size)
181 |
182 | if dataset_name == 'busi':
183 | pixel_gts = np.concatenate(pixel_gts,axis=0)
184 |
185 | save_images_root = os.path.join(args.vis_save_root,"{}".format(dataset_name))
186 | os.makedirs(save_images_root,exist_ok=True)
187 |
188 | if dataset_name=='busi':
189 | iter= tqdm(
190 | zip(image_paths, image_gts, pixel_preds, pixel_gts),
191 | total=len(image_paths),
192 | desc="Generating Segmentation Images...",
193 | leave=False,
194 | )
195 | else:
196 | iter= tqdm(
197 | zip(image_paths, image_gts, pixel_preds),
198 | total=len(image_paths),
199 | desc="Generating Segmentation Images...",
200 | leave=False,
201 | )
202 |
203 | for i, data in enumerate(iter):
204 | if dataset_name=='busi':
205 | image_path, image_gt, pixel_pred, pixel_gt = data
206 | else:
207 | image_path, image_gt, pixel_pred = data
208 |
209 | _, image_name = os.path.split(image_path)
210 |
211 | image = Image.open(image_path).convert("RGB")
212 | image = image.resize((args.config.image_size,args.config.image_size))
213 | image = np.array(image).astype(np.uint8)
214 |
215 | heat = show_cam_on_image( image / 255, pixel_pred, use_rgb=True)
216 |
217 | label_= "normal" if image_gt==0 else "abnormal"
218 |
219 | merge = [image,heat]
220 |
221 | if dataset_name == 'busi':
222 | pixel_gt = np.repeat(np.expand_dims(pixel_gt,axis=-1),3,axis=-1)
223 | merge.append(pixel_gt*255)
224 |
225 | Image.fromarray(np.concatenate(merge,axis=1).astype(np.uint8)).save(os.path.join(save_images_root,"{}_{}_{}".format(i,label_,image_name)))
226 |
227 | metric = compute_imagewise_metrics(image_preds,image_gts)
228 | if dataset_name == 'busi':
229 | metric.update(compute_pixelwise_metrics(pixel_preds_np, pixel_gts))
230 |
231 | return metric
232 |
233 |
234 | if __name__ == '__main__':
235 | parser = argparse.ArgumentParser(description="Test MediCLIP")
236 | parser.add_argument("--config_path", type=str, help="model configs")
237 | parser.add_argument("--checkpoint_path", type=str, help='the checkpoint path')
238 | parser.add_argument("--vis_save_root", type=str, default='vis_results')
239 | args = parser.parse_args()
240 | torch.multiprocessing.set_start_method("spawn")
241 | main(args)
242 |
243 |
--------------------------------------------------------------------------------
/open_clip/pretrained.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import urllib
4 | import warnings
5 | from functools import partial
6 | from typing import Dict, Union
7 |
8 | from tqdm import tqdm
9 |
10 |
11 | def _pcfg(url='', hf_hub='', mean=None, std=None):
12 | return dict(
13 | url=url,
14 | hf_hub=hf_hub,
15 | mean=mean,
16 | std=std,
17 | )
18 |
19 |
20 | _RN50 = dict(
21 | openai=_pcfg(
22 | "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
23 | yfcc15m=_pcfg(
24 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"),
25 | cc12m=_pcfg(
26 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"),
27 | )
28 |
29 | _RN50_quickgelu = dict(
30 | openai=_pcfg(
31 | "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
32 | yfcc15m=_pcfg(
33 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"),
34 | cc12m=_pcfg(
35 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"),
36 | )
37 |
38 | _RN101 = dict(
39 | openai=_pcfg(
40 | "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"),
41 | yfcc15m=_pcfg(
42 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"),
43 | )
44 |
45 | _RN101_quickgelu = dict(
46 | openai=_pcfg(
47 | "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"),
48 | yfcc15m=_pcfg(
49 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"),
50 | )
51 |
52 | _RN50x4 = dict(
53 | openai=_pcfg(
54 | "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"),
55 | )
56 |
57 | _RN50x16 = dict(
58 | openai=_pcfg(
59 | "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"),
60 | )
61 |
62 | _RN50x64 = dict(
63 | openai=_pcfg(
64 | "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"),
65 | )
66 |
67 | _VITB32 = dict(
68 | openai=_pcfg(
69 | "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
70 | laion400m_e31=_pcfg(
71 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
72 | laion400m_e32=_pcfg(
73 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
74 | laion2b_e16=_pcfg(
75 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"),
76 | laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/')
77 | )
78 |
79 | _VITB32_quickgelu = dict(
80 | openai=_pcfg(
81 | "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
82 | laion400m_e31=_pcfg(
83 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
84 | laion400m_e32=_pcfg(
85 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
86 | )
87 |
88 | _VITB16 = dict(
89 | openai=_pcfg(
90 | "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"),
91 | laion400m_e31=_pcfg(
92 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"),
93 | laion400m_e32=_pcfg(
94 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"),
95 | # laion400m_32k=_pcfg(
96 | # url="",
97 | # mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
98 | # laion400m_64k=_pcfg(
99 | # url="",
100 | # mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
101 | laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'),
102 | )
103 |
104 | _VITB16_PLUS_240 = dict(
105 | laion400m_e31=_pcfg(
106 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"),
107 | laion400m_e32=_pcfg(
108 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"),
109 | )
110 |
111 | _VITL14 = dict(
112 | openai=_pcfg(
113 | "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"),
114 | laion400m_e31=_pcfg(
115 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"),
116 | laion400m_e32=_pcfg(
117 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"),
118 | laion2b_s32b_b82k=_pcfg(
119 | hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/',
120 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
121 | )
122 |
123 | _VITL14_336 = dict(
124 | openai=_pcfg(
125 | "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"),
126 | )
127 |
128 | _VITH14 = dict(
129 | laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),
130 | )
131 |
132 | _VITg14 = dict(
133 | laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'),
134 | laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'),
135 | )
136 |
137 | _VITbigG14 = dict(
138 | laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'),
139 | )
140 |
141 | _robertaViTB32 = dict(
142 | laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'),
143 | )
144 |
145 | _xlmRobertaBaseViTB32 = dict(
146 | laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'),
147 | )
148 |
149 | _xlmRobertaLargeFrozenViTH14 = dict(
150 | frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'),
151 | )
152 |
153 | _convnext_base = dict(
154 | laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'),
155 | )
156 |
157 | _convnext_base_w = dict(
158 | laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'),
159 | laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'),
160 | laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'),
161 | )
162 |
163 | _convnext_base_w_320 = dict(
164 | laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'),
165 | laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'),
166 | )
167 |
168 | _convnext_large_d = dict(
169 | laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'),
170 | )
171 |
172 | _convnext_large_d_320 = dict(
173 | laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'),
174 | laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'),
175 | )
176 |
177 | _convnext_xxlarge = dict(
178 | laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'),
179 | laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'),
180 | laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'),
181 | )
182 |
183 | _coca_VITB32 = dict(
184 | laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'),
185 | mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/')
186 | )
187 |
188 | _coca_VITL14 = dict(
189 | laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'),
190 | mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/')
191 | )
192 |
193 |
194 | _PRETRAINED = {
195 | "RN50": _RN50,
196 | "RN50-quickgelu": _RN50_quickgelu,
197 | "RN101": _RN101,
198 | "RN101-quickgelu": _RN101_quickgelu,
199 | "RN50x4": _RN50x4,
200 | "RN50x16": _RN50x16,
201 | "RN50x64": _RN50x64,
202 | "ViT-B-32": _VITB32,
203 | "ViT-B-32-quickgelu": _VITB32_quickgelu,
204 | "ViT-B-16": _VITB16,
205 | "ViT-B-16-plus-240": _VITB16_PLUS_240,
206 | "ViT-L-14": _VITL14,
207 | "ViT-L-14-336": _VITL14_336,
208 | "ViT-H-14": _VITH14,
209 | "ViT-g-14": _VITg14,
210 | "ViT-bigG-14": _VITbigG14,
211 | "roberta-ViT-B-32": _robertaViTB32,
212 | "xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32,
213 | "xlm-roberta-large-ViT-H-14": _xlmRobertaLargeFrozenViTH14,
214 | "convnext_base": _convnext_base,
215 | "convnext_base_w": _convnext_base_w,
216 | "convnext_base_w_320": _convnext_base_w_320,
217 | "convnext_large_d": _convnext_large_d,
218 | "convnext_large_d_320": _convnext_large_d_320,
219 | "convnext_xxlarge": _convnext_xxlarge,
220 | "coca_ViT-B-32": _coca_VITB32,
221 | "coca_ViT-L-14": _coca_VITL14,
222 | }
223 |
224 |
225 | def _clean_tag(tag: str):
226 | # normalize pretrained tags
227 | return tag.lower().replace('-', '_')
228 |
229 |
230 | def list_pretrained(as_str: bool = False):
231 |     """ Returns the list of available pretrained model configs.
232 |     Each entry is a (model_name, pretrain_tag) tuple by default, or a 'name:tag' string if as_str == True.
233 |     """
234 | return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]
235 |
236 |
237 | def list_pretrained_models_by_tag(tag: str):
238 | """ return all models having the specified pretrain tag """
239 | models = []
240 | tag = _clean_tag(tag)
241 | for k in _PRETRAINED.keys():
242 | if tag in _PRETRAINED[k]:
243 | models.append(k)
244 | return models
245 |
246 |
247 | def list_pretrained_tags_by_model(model: str):
248 | """ return all pretrain tags for the specified model architecture """
249 | tags = []
250 | if model in _PRETRAINED:
251 | tags.extend(_PRETRAINED[model].keys())
252 | return tags
253 |
254 |
255 | def is_pretrained_cfg(model: str, tag: str):
256 | if model not in _PRETRAINED:
257 | return False
258 | return _clean_tag(tag) in _PRETRAINED[model]
259 |
260 |
261 | def get_pretrained_cfg(model: str, tag: str):
262 | if model not in _PRETRAINED:
263 | return {}
264 | model_pretrained = _PRETRAINED[model]
265 | return model_pretrained.get(_clean_tag(tag), {})
266 |
267 |
268 | def get_pretrained_url(model: str, tag: str):
269 | cfg = get_pretrained_cfg(model, _clean_tag(tag))
270 | return cfg.get('url', '')
271 |
272 |
273 | def download_pretrained_from_url(
274 | url: str,
275 | cache_dir: Union[str, None] = None,
276 | ):
277 | if not cache_dir:
278 | cache_dir = os.path.expanduser("~/.cache/clip")
279 | os.makedirs(cache_dir, exist_ok=True)
280 | filename = os.path.basename(url)
281 |
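    # The expected SHA256 is encoded in the URL itself: OpenAI checkpoints carry the full hash as a
    # path component, while open_clip release assets carry a short hash suffix in the filename.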
282 | if 'openaipublic' in url:
283 | expected_sha256 = url.split("/")[-2]
284 | elif 'mlfoundations' in url:
285 | expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
286 | else:
287 | expected_sha256 = ''
288 |
289 | download_target = os.path.join(cache_dir, filename)
290 |
291 | if os.path.exists(download_target) and not os.path.isfile(download_target):
292 | raise RuntimeError(f"{download_target} exists and is not a regular file")
293 |
294 | if os.path.isfile(download_target):
295 | if expected_sha256:
296 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
297 | return download_target
298 | else:
299 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
300 | else:
301 | return download_target
302 |
303 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
304 | with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
305 | while True:
306 | buffer = source.read(8192)
307 | if not buffer:
308 | break
309 |
310 | output.write(buffer)
311 | loop.update(len(buffer))
312 |
313 | if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
314 |         raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
315 |
316 | return download_target
317 |
318 |
319 | def download_pretrained(
320 | cfg: Dict,
321 | cache_dir: Union[str, None] = None,
322 | ):
323 | target = ''
324 | if not cfg:
325 | return target
326 |
327 | download_url = cfg.get('url', '')
328 |
329 | target = download_pretrained_from_url(download_url, cache_dir=cache_dir)
330 |
331 | return target
332 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import open_clip
2 | import torch
3 | import yaml
4 | from easydict import EasyDict
5 | from models.Necker import Necker
6 | from models.Adapter import Adapter
7 | import math
8 | import argparse
9 | import warnings
10 | from utils.misc_helper import *
11 | from torch.utils.data import DataLoader
12 | from models.MapMaker import MapMaker
13 | from utils.losses import FocalLoss,BinaryDiceLoss
14 | from datasets.dataset import TrainDataset,\
15 | ChexpertTestDataset,\
16 | BusiTestDataset,\
17 | BrainMRITestDataset
18 | import pprint
19 | from tqdm import tqdm
20 | warnings.filterwarnings('ignore')
21 |
22 |
23 | @torch.no_grad()
24 | def make_vision_tokens_info(model, model_cfg, layers_out):
25 |
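    # Probe the frozen CLIP image encoder with a dummy input so the spatial token grid size
    # and channel dimension of each requested layer can be recorded on the model object.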
26 | img = torch.ones((1,3,model_cfg['vision_cfg']['image_size'],
27 | model_cfg['vision_cfg']['image_size'])).to(model.device)
28 |
29 | img_feature,tokens = model.encode_image(img,layers_out)
30 |
31 | if len(tokens[0].shape)==3:
32 | model.token_size= [int(math.sqrt(token.shape[1]-1)) for token in tokens]
33 | model.token_c= [token.shape[-1] for token in tokens]
34 | else:
35 | model.token_size = [token.shape[2] for token in tokens]
36 | model.token_c = [token.shape[1] for token in tokens]
37 |
38 | model.embed_dim = model_cfg['embed_dim']
39 |     print("model token size is {}, model token dim is {}".format(model.token_size, model.token_c))
40 |
41 |
42 | def main(args):
43 |
44 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
45 |
46 | with open(args.config_path) as f:
47 | args.config = EasyDict(yaml.load(f, Loader=yaml.FullLoader))
48 |
49 | model, preprocess, model_cfg = open_clip.create_model_and_transforms(args.config.model_name, args.config.image_size, device=device)
50 |
51 | for param in model.parameters():
52 | param.requires_grad_(False)
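    # The CLIP backbone stays frozen; only the adapter and the learnable prompt tokens
    # created below are updated during training.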
53 |
54 | args.config.model_cfg = model_cfg
55 |
56 |     make_vision_tokens_info(model,
57 | args.config.model_cfg,
58 | args.config.layers_out)
59 |
60 | current_time = get_current_time()
61 | args.config.save_root=os.path.join(args.config.save_root,current_time)
62 |
63 | if not os.path.exists(args.config.save_root):
64 | os.makedirs(args.config.save_root)
65 |
66 | logger = create_logger("logger",os.path.join(args.config.save_root,'logger.log'))
67 | logger.info("config: {}".format(pprint.pformat(args)))
68 |
69 | necker = Necker(clip_model=model).to(model.device)
70 | adapter = Adapter(clip_model=model,target=args.config.model_cfg['embed_dim']).to(model.device)
71 |
72 | if args.config.prompt_maker=='coop':
73 | from models.CoOp import PromptMaker
74 | logger.info("load CoOp")
75 | else:
76 |         raise NotImplementedError("type of prompt maker must be in ['coop']")
77 |
78 | prompt_maker = PromptMaker(
79 | prompts=args.config.prompts,
80 | clip_model=model,
81 | n_ctx= args.config.n_learnable_token,
82 | CSC = args.config.CSC,
83 | class_token_position=args.config.class_token_positions,
84 | ).to(model.device)
85 |
86 | map_maker = MapMaker(image_size=args.config.image_size).to(model.device)
87 |
88 | optimizer = torch.optim.Adam([
89 | {'params': prompt_maker.prompt_learner.parameters(),'lr': 0.001},
90 | {'params': adapter.parameters(),"lr":0.001},
91 | ], lr=0.001, betas=(0.5, 0.999))
92 |
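    # Only k_shot normal images are sampled for training; anomalies and their masks are
    # synthesised on the fly according to the tasks listed in config.anomaly_tasks.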
93 | train_dataset = TrainDataset(args=args.config,
94 | source=os.path.join(args.config.data_root,args.config.train_dataset),
95 | preprocess=preprocess,
96 | k_shot=args.k_shot)
97 |
98 | train_dataloader = DataLoader(train_dataset, batch_size=args.config.batch_size, shuffle=True, num_workers=2)
99 |
100 | test_dataloaders = {}
101 | best_record = {}
102 |
103 | for test_dataset_name in args.config.test_datasets:
104 |
105 | if test_dataset_name == 'chexpert':
106 | test_dataset = ChexpertTestDataset( args=args.config,
107 | source=os.path.join(args.config.data_root,test_dataset_name),
108 | preprocess=preprocess,
109 | )
110 |
111 | elif test_dataset_name =='brainmri':
112 |
113 | test_dataset = BrainMRITestDataset(
114 | args=args.config,
115 | source=os.path.join(args.config.data_root,test_dataset_name),
116 | preprocess=preprocess,
117 | )
118 | elif test_dataset_name =='busi':
119 |
120 | test_dataset = BusiTestDataset(
121 | args=args.config,
122 | source=os.path.join(args.config.data_root,test_dataset_name),
123 | preprocess=preprocess)
124 | else:
125 |             raise NotImplementedError("dataset must be in ['chexpert', 'busi', 'brainmri']")
126 |
127 | test_dataloader = DataLoader(test_dataset, batch_size=args.config.batch_size,num_workers=2)
128 | test_dataloaders[test_dataset_name]=test_dataloader
129 | best_record[test_dataset_name]=None
130 |
131 | logger.info("train data ({}) len {}".format(args.config.train_dataset,len(train_dataset)))
132 |
133 | for test_dataset_name in test_dataloaders:
134 | logger.info("test data ({}) len {}".format(test_dataset_name, len(test_dataloaders[test_dataset_name].dataset)))
135 |
136 | for task_name in args.config.anomaly_tasks:
137 | logger.info("anomaly syn task is {}, sampling probability is {}".format(task_name,args.config.anomaly_tasks[task_name]))
138 |
139 | for epoch in range(0, args.config.epoch):
140 | last_iter = epoch * len(train_dataloader)
141 |
142 | train_one_epoch(
143 | args,
144 | train_dataloader,
145 | optimizer,
146 | epoch,
147 | last_iter,
148 | logger,
149 | model,
150 | necker,
151 | adapter,
152 | prompt_maker,
153 | map_maker,
154 | )
155 |
156 | if (epoch+1) % args.config.val_freq_epoch == 0:
157 |
158 | results = validate(args,test_dataloaders, epoch,model, necker,adapter,prompt_maker,map_maker)
159 | save_flag = False
160 |
161 | for test_dataset_name in results:
162 | if best_record[test_dataset_name] is None:
163 | if test_dataset_name=='busi':
164 | best_record[test_dataset_name] = [results[test_dataset_name]["image-auroc"],
165 | results[test_dataset_name]['pixel-auroc']]
166 | else:
167 | best_record[test_dataset_name] = [results[test_dataset_name]["image-auroc"]]
168 |
169 | save_flag=True
170 | else:
171 | if np.mean([results[test_dataset_name][key] for key in results[test_dataset_name]]) > np.mean(best_record[test_dataset_name]):
172 | if test_dataset_name == 'busi':
173 | best_record[test_dataset_name] = [results[test_dataset_name]["image-auroc"],
174 | results[test_dataset_name]['pixel-auroc']]
175 | else:
176 | best_record[test_dataset_name] = [results[test_dataset_name]["image-auroc"]]
177 | save_flag=True
178 |
179 |
180 | if test_dataset_name=='busi':
181 |                     logger.info("({}): Epoch: {}, image auroc: {:.4f}, pixel auroc: {:.4f}".format(test_dataset_name,
182 |                                                                                      epoch+1,
183 |                                                                                      results[test_dataset_name]["image-auroc"],
184 |                                                                                      results[test_dataset_name]['pixel-auroc']))
185 |                 else:
186 |                     logger.info("({}): Epoch: {}, image auroc: {:.4f}".format(
187 |                         test_dataset_name,
188 |                         epoch+1,
189 |                         results[test_dataset_name]["image-auroc"],
190 |                     ))
191 |
192 | for test_dataset_name in results:
193 | if test_dataset_name == 'busi':
194 | logger.info(
195 | "({} best): image auroc: {:.4f}, pixel auroc: {:.4f},".format(
196 | test_dataset_name,
197 | best_record[test_dataset_name][0],
198 | best_record[test_dataset_name][1],
199 | ))
200 | else:
201 | logger.info(
202 |                         "({} best): image auroc: {:.4f}".format(
203 | test_dataset_name,
204 | best_record[test_dataset_name][0],
205 | ))
206 |
207 | if save_flag:
208 | logger.info("save checkpoints in epoch: {}".format(epoch+1))
209 | torch.save({
210 | "adapter_state_dict": adapter.state_dict(),
211 | "prompt_state_dict": prompt_maker.prompt_learner.state_dict(),
212 | }, os.path.join(args.config.save_root, 'checkpoints_{}.pkl'.format(epoch + 1)))
213 |
214 |
215 | def train_one_epoch(
216 | args,
217 | train_dataloader,
218 | optimizer,
219 | epoch,
220 | start_iter,
221 | logger,
222 | clip_model,
223 | necker,
224 | adapter,
225 | prompt_maker,
226 | map_maker,
227 | ):
228 |
229 | loss_meter = AverageMeter(args.config.print_freq_step)
230 |
231 | focal_criterion = FocalLoss()
232 | dice_criterion = BinaryDiceLoss()
233 |
234 | adapter.train()
235 | prompt_maker.train()
236 |
237 | for i, input in enumerate(train_dataloader):
238 | curr_step = start_iter + i
239 |
240 | images = input['image'].to(clip_model.device)
241 | gt_mask = input['mask'].to(clip_model.device)
242 |
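        # CLIP token extraction and the necker fusion run under no_grad; only the adapter
        # and prompt maker outputs below take part in the backward pass.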
243 | with torch.no_grad():
244 | _, image_tokens = clip_model.encode_image(images,out_layers=args.config.layers_out)
245 | image_features = necker(image_tokens)
246 |
247 | vision_adapter_features = adapter(image_features)
248 |         prompt_adapter_features = prompt_maker(vision_adapter_features)
249 |         anomaly_map = map_maker(vision_adapter_features, prompt_adapter_features)
250 |
251 | loss = []
252 |
253 | loss.append(focal_criterion(anomaly_map,gt_mask))
254 | loss.append(dice_criterion(anomaly_map[:, 1, :, :],gt_mask))
255 |
256 | loss = torch.sum(torch.stack(loss))
257 | loss_meter.update(loss.item())
258 |
259 | optimizer.zero_grad()
260 | loss.backward()
261 | optimizer.step()
262 |
263 | if (curr_step + 1) % args.config.print_freq_step == 0:
264 | logger.info(
265 | "Epoch: [{0}/{1}]\t"
266 | "Iter: [{2}/{3}]\t"
267 | "Loss {loss.val:.5f} ({loss.avg:.5f})\t"
268 | .format(
269 | epoch+1 ,
270 | args.config.epoch,
271 | curr_step + 1,
272 | len(train_dataloader) * args.config.epoch,
273 | loss=loss_meter,
274 | )
275 | )
276 |
277 |
278 | def validate(args, test_dataloaders, epoch, clip_model, necker, adapter, prompt_maker, map_maker):
279 |
280 | adapter.eval()
281 | prompt_maker.eval()
282 | results = {}
283 |
284 | for test_dataset_name in test_dataloaders:
285 | test_dataloader = test_dataloaders[test_dataset_name]
286 |
287 | anomaly_maps = []
288 | anomaly_gts = []
289 |
290 | image_scores = []
291 | image_labels = []
292 |
293 | with torch.no_grad():
294 | for i, input in enumerate(tqdm(test_dataloader,desc=test_dataset_name)):
295 |
296 | images = input['image'].to(clip_model.device)
297 |
298 | _, image_tokens = clip_model.encode_image(images, out_layers=args.config.layers_out)
299 | image_features = necker(image_tokens)
300 | vision_adapter_features = adapter(image_features)
301 |                 prompt_adapter_features = prompt_maker(vision_adapter_features)
302 |                 anomaly_map = map_maker(vision_adapter_features, prompt_adapter_features)
303 |
304 | B,_,H,W = anomaly_map.shape
305 |
306 | anomaly_map = anomaly_map[:,1,:,:]
307 | anomaly_gt = input['mask']
308 |
309 | anomaly_maps.append(anomaly_map.cpu().numpy())
310 | anomaly_gts.append(anomaly_gt.cpu().numpy())
311 |
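                # Image-level anomaly score = maximum of the pixel-level anomaly map.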
312 | anomaly_score,_ = torch.max(anomaly_map.view((B,H*W)), dim=-1)
313 |
314 | image_scores.extend(anomaly_score.cpu().numpy().tolist())
315 | image_labels.extend(input['is_anomaly'].cpu().numpy().tolist())
316 |
317 | metric = compute_imagewise_metrics(image_scores,image_labels)
318 |
319 | if test_dataset_name=='busi':
320 | metric.update(compute_pixelwise_metrics(anomaly_maps,anomaly_gts))
321 |
322 | results[test_dataset_name] = metric
323 | return results
324 |
325 |
326 | if __name__ == '__main__':
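    # Example invocation (using the defaults below):
    #   python train.py --config_path config/brainmri.yaml --k_shot 16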
327 | parser = argparse.ArgumentParser(description="Train MediCLIP")
328 | parser.add_argument("--config_path", type=str, default='config/brainmri.yaml', help="model configs")
329 |     parser.add_argument("--k_shot", type=int, default=16, help="number of normal training images (k-shot)")
330 | args = parser.parse_args()
331 | torch.multiprocessing.set_start_method("spawn")
332 | main(args)
--------------------------------------------------------------------------------
/open_clip/model.py:
--------------------------------------------------------------------------------
1 | """ CLIP Model
2 |
3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4 | """
5 | from dataclasses import dataclass
6 | import logging
7 | import math
8 | from typing import Optional, Tuple, Union
9 |
10 | import numpy as np
11 | import torch
12 | import torch.nn.functional as F
13 | from torch import nn
14 | from torch.utils.checkpoint import checkpoint
15 |
16 | from .modified_resnet import ModifiedResNet
17 | from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
18 | from .utils import to_2tuple
19 |
20 |
21 | @dataclass
22 | class CLIPVisionCfg:
23 | layers: Union[Tuple[int, int, int, int], int] = 12
24 | width: int = 768
25 | head_width: int = 64
26 | mlp_ratio: float = 4.0
27 | patch_size: int = 16
28 | image_size: Union[Tuple[int, int], int] = 224
29 | ls_init_value: Optional[float] = None # layer scale initial value
30 | patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
31 | input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design
32 | global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
33 | attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer
34 | n_queries: int = 256 # n_queries for attentional pooler
35 | attn_pooler_heads: int = 8 # n heads for attentional_pooling
36 | timm_model_name: str = None # a valid model name overrides layers, width, patch_size
37 | timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
38 | timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
39 | timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
40 | timm_proj_bias: bool = False # enable bias final projection
41 | timm_drop: float = 0. # head dropout
42 | timm_drop_path: Optional[float] = None # backbone stochastic depth
43 | output_tokens: bool = False
44 |
45 |
46 | @dataclass
47 | class CLIPTextCfg:
48 | context_length: int = 77
49 | vocab_size: int = 49408
50 | width: int = 512
51 | heads: int = 8
52 | layers: int = 12
53 | ls_init_value: Optional[float] = None # layer scale initial value
54 | hf_model_name: str = None
55 | hf_tokenizer_name: str = None
56 | hf_model_pretrained: bool = True
57 | proj: str = 'mlp'
58 | pooler_type: str = 'mean_pooler'
59 | embed_cls: bool = False
60 | pad_id: int = 0
61 | output_tokens: bool = False
62 |
63 |
64 | def get_cast_dtype(precision: str):
65 | cast_dtype = None
66 | if precision == 'bf16':
67 | cast_dtype = torch.bfloat16
68 | elif precision == 'fp16':
69 | cast_dtype = torch.float16
70 | return cast_dtype
71 |
72 |
73 | def _build_vision_tower(
74 | embed_dim: int,
75 | vision_cfg: CLIPVisionCfg,
76 | quick_gelu: bool = False,
77 | cast_dtype: Optional[torch.dtype] = None
78 | ):
79 | if isinstance(vision_cfg, dict):
80 | vision_cfg = CLIPVisionCfg(**vision_cfg)
81 |
82 | # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
83 | # memory efficient in recent PyTorch releases (>= 1.10).
84 | # NOTE: timm models always use native GELU regardless of quick_gelu flag.
85 | act_layer = QuickGELU if quick_gelu else nn.GELU
86 |
87 | if isinstance(vision_cfg.layers, (tuple, list)):
88 | vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
89 | visual = ModifiedResNet(
90 | layers=vision_cfg.layers,
91 | output_dim=embed_dim,
92 | heads=vision_heads,
93 | image_size=vision_cfg.image_size,
94 | width=vision_cfg.width,
95 | )
96 | else:
97 | vision_heads = vision_cfg.width // vision_cfg.head_width
98 | norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
99 | visual = VisionTransformer(
100 | image_size=vision_cfg.image_size,
101 | patch_size=vision_cfg.patch_size,
102 | width=vision_cfg.width,
103 | layers=vision_cfg.layers,
104 | heads=vision_heads,
105 | mlp_ratio=vision_cfg.mlp_ratio,
106 | ls_init_value=vision_cfg.ls_init_value,
107 | patch_dropout=vision_cfg.patch_dropout,
108 | input_patchnorm=vision_cfg.input_patchnorm,
109 | global_average_pool=vision_cfg.global_average_pool,
110 | attentional_pool=vision_cfg.attentional_pool,
111 | n_queries=vision_cfg.n_queries,
112 | attn_pooler_heads=vision_cfg.attn_pooler_heads,
113 | output_tokens=vision_cfg.output_tokens,
114 | output_dim=embed_dim,
115 | act_layer=act_layer,
116 | norm_layer=norm_layer,
117 | )
118 |
119 | return visual
120 |
121 |
122 | def _build_text_tower(
123 | embed_dim: int,
124 | text_cfg: CLIPTextCfg,
125 | quick_gelu: bool = False,
126 | cast_dtype: Optional[torch.dtype] = None,
127 | ):
128 | if isinstance(text_cfg, dict):
129 | text_cfg = CLIPTextCfg(**text_cfg)
130 |
131 |
132 | act_layer = QuickGELU if quick_gelu else nn.GELU
133 | norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
134 |
135 | text = TextTransformer(
136 | context_length=text_cfg.context_length,
137 | vocab_size=text_cfg.vocab_size,
138 | width=text_cfg.width,
139 | heads=text_cfg.heads,
140 | layers=text_cfg.layers,
141 | ls_init_value=text_cfg.ls_init_value,
142 | output_dim=embed_dim,
143 | embed_cls=text_cfg.embed_cls,
144 | output_tokens=text_cfg.output_tokens,
145 | pad_id=text_cfg.pad_id,
146 | act_layer=act_layer,
147 | norm_layer=norm_layer,
148 | )
149 |
150 | return text
151 |
152 |
153 | class CLIP(nn.Module):
154 | output_dict: torch.jit.Final[bool]
155 |
156 | def __init__(
157 | self,
158 | embed_dim: int,
159 | vision_cfg: CLIPVisionCfg,
160 | text_cfg: CLIPTextCfg,
161 | quick_gelu: bool = False,
162 | cast_dtype: Optional[torch.dtype] = None,
163 | output_dict: bool = False,
164 | ):
165 | super().__init__()
166 | self.output_dict = output_dict
167 | self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
168 |
169 | text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
170 |
171 | self.transformer = text.transformer
172 |
173 | self.vocab_size = text.vocab_size
174 | self.token_embedding = text.token_embedding
175 |
176 | self.positional_embedding = text.positional_embedding
177 | self.ln_final = text.ln_final
178 | self.text_projection = text.text_projection
179 | self.register_buffer('attn_mask', text.attn_mask, persistent=False)
180 |
181 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
182 |
183 | def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
184 | self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
185 |
186 | @torch.jit.ignore
187 | def set_grad_checkpointing(self, enable=True):
188 | self.visual.set_grad_checkpointing(enable)
189 | self.transformer.grad_checkpointing = enable
190 |
191 | def encode_image(self, image, out_layers):
192 | features = self.visual(image, out_layers)
193 | return features
194 |
195 |
196 | def encode_text(self, text, normalize: bool = False):
197 | cast_dtype = self.transformer.get_cast_dtype()
198 |
199 | x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
200 |
201 | x = x + self.positional_embedding.to(cast_dtype)
202 | x = x.permute(1, 0, 2) # NLD -> LND
203 | x, attn, tokens = self.transformer(x, attn_mask=self.attn_mask)
204 | x = x.permute(1, 0, 2) # LND -> NLD
205 | x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
206 | # take features from the eot embedding (eot_token is the highest number in each sequence)
207 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
208 |
209 | return F.normalize(x, dim=-1) if normalize else x
210 |
211 |
212 |
213 |
214 | def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
215 | """Convert applicable model parameters to low-precision (bf16 or fp16)"""
216 |
217 | def _convert_weights(l):
218 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
219 | l.weight.data = l.weight.data.to(dtype)
220 | if l.bias is not None:
221 | l.bias.data = l.bias.data.to(dtype)
222 |
223 | if isinstance(l, (nn.MultiheadAttention, Attention)):
224 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
225 | tensor = getattr(l, attr)
226 | if tensor is not None:
227 | tensor.data = tensor.data.to(dtype)
228 |
229 | for name in ["text_projection", "proj"]:
230 | if hasattr(l, name):
231 | attr = getattr(l, name)
232 | if attr is not None:
233 | attr.data = attr.data.to(dtype)
234 |
235 | model.apply(_convert_weights)
236 |
237 |
238 | convert_weights_to_fp16 = convert_weights_to_lp # backwards compat
239 |
240 |
241 | # used to maintain checkpoint compatibility
242 | def convert_to_custom_text_state_dict(state_dict: dict):
243 | if 'text_projection' in state_dict:
244 | # old format state_dict, move text tower -> .text
245 | new_state_dict = {}
246 | for k, v in state_dict.items():
247 | if any(k.startswith(p) for p in (
248 | 'text_projection',
249 | 'positional_embedding',
250 | 'token_embedding',
251 | 'transformer',
252 | 'ln_final',
253 | )):
254 | k = 'text.' + k
255 | new_state_dict[k] = v
256 | return new_state_dict
257 | return state_dict
258 |
259 |
260 | def build_model_from_openai_state_dict(
261 | state_dict: dict,
262 | quick_gelu=True,
263 | cast_dtype=torch.float16,
264 | ):
265 | vit = "visual.proj" in state_dict
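    # 'visual.proj' only exists in ViT checkpoints; its absence indicates a ModifiedResNet tower.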
266 |
267 | if vit:
268 | vision_width = state_dict["visual.conv1.weight"].shape[0]
269 | vision_layers = len(
270 | [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
271 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
272 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
273 | image_size = vision_patch_size * grid_size
274 | else:
275 | counts: list = [
276 | len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
277 | vision_layers = tuple(counts)
278 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
279 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
280 | vision_patch_size = None
281 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
282 | image_size = output_width * 32
283 |
284 | embed_dim = state_dict["text_projection"].shape[1]
285 | context_length = state_dict["positional_embedding"].shape[0]
286 | vocab_size = state_dict["token_embedding.weight"].shape[0]
287 | transformer_width = state_dict["ln_final.weight"].shape[0]
288 | transformer_heads = transformer_width // 64
289 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
290 |
291 | vision_cfg = CLIPVisionCfg(
292 | layers=vision_layers,
293 | width=vision_width,
294 | patch_size=vision_patch_size,
295 | image_size=image_size,
296 | )
297 | text_cfg = CLIPTextCfg(
298 | context_length=context_length,
299 | vocab_size=vocab_size,
300 | width=transformer_width,
301 | heads=transformer_heads,
302 | layers=transformer_layers,
303 | )
304 |
305 | model = CLIP(
306 | embed_dim,
307 | vision_cfg=vision_cfg,
308 | text_cfg=text_cfg,
309 | quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU
310 | cast_dtype=cast_dtype,
311 | )
312 |
313 | for key in ["input_resolution", "context_length", "vocab_size"]:
314 | state_dict.pop(key, None)
315 |
316 | convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16
317 | model.load_state_dict(state_dict)
318 |
319 | return model.eval()
320 |
321 |
322 | def trace_model(model, batch_size=256, device=torch.device('cpu')):
323 | model.eval()
324 | image_size = model.visual.image_size
325 | example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
326 | example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
327 | model = torch.jit.trace_module(
328 | model,
329 | inputs=dict(
330 | forward=(example_images, example_text),
331 | encode_text=(example_text,),
332 | encode_image=(example_images,)
333 | ))
334 | model.visual.image_size = image_size
335 | return model
336 |
337 |
338 |
339 | def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
340 | # Rescale the grid of position embeddings when loading from state_dict
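    # flag == 1 -> ViT-style 'visual.positional_embedding'; flag == 0 -> ResNet attnpool embedding.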
341 | flag = 1
342 | old_pos_embed = state_dict.get('visual.positional_embedding', None)
343 | if old_pos_embed is None:
344 | flag = 0
345 | old_pos_embed = state_dict.get('visual.attnpool.positional_embedding', None)
346 | if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
347 | return
348 | grid_size = to_2tuple(model.visual.grid_size)
349 | extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
350 | new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
351 | if new_seq_len == old_pos_embed.shape[0]:
352 | return
353 |
354 | if extra_tokens:
355 | pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
356 | else:
357 | pos_emb_tok, pos_emb_img = None, old_pos_embed
358 | old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
359 |
360 | logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
361 | pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
362 | pos_emb_img = F.interpolate(
363 | pos_emb_img,
364 | size=grid_size,
365 | mode=interpolation,
366 | antialias=antialias,
367 | align_corners=False,
368 | )
369 | pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
370 | if pos_emb_tok is not None:
371 | new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
372 | else:
373 | new_pos_embed = pos_emb_img
374 | if flag:
375 | state_dict['visual.positional_embedding'] = new_pos_embed
376 | else:
377 | state_dict['visual.attnpool.positional_embedding'] = new_pos_embed
378 |
--------------------------------------------------------------------------------
/medsyn/task_shape.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod, ABC
2 | import functools
3 | import itertools
4 | from typing import Callable, List, Tuple
5 |
6 | import numpy as np
7 | import numpy.typing as npt
8 | from skimage.morphology import convex_hull_image
9 |
10 | from .utils import accumulate_rotation, accumulate_scaling
11 | import math
12 | import imgaug.augmenters as iaa
13 |
14 |
15 | class PatchShapeMaker(ABC):
16 |
17 | @abstractmethod
18 | def get_patch_mask(self, dim_bounds: List[Tuple[int, int]], img_dims: np.ndarray) -> np.ndarray:
19 | """
20 | :param dim_bounds: Tuples giving lower and upper bounds for patch size in each dimension
21 | :param img_dims: Image dimensions, can be used as scaling factor.
22 | Creates a patch mask to be used in the self-supervised task.
23 | Mask must have length(dim_bounds) dimensions.
24 | """
25 | pass
26 |
27 | def __call__(self, dim_bounds: List[Tuple[int, int]], img_dims: np.ndarray) -> np.ndarray:
28 | return self.get_patch_mask(dim_bounds, img_dims)
29 |
30 |
31 | class PerlinPatchMaker(PatchShapeMaker):
32 |
33 | def __init__(self,
34 | min_perlin_scale=0,
35 | perlin_scale=1,
36 | perlin_noise_threshold = 0.3,
37 | perlin_min_size = 0.2
38 | ):
39 |
40 | self.min_perlin_scale = min_perlin_scale
41 | self.perlin_scale = perlin_scale
42 | self.perlin_noise_threshold = perlin_noise_threshold
43 | self.perlin_min_size = perlin_min_size
44 |
45 | def lerp_np(self, x, y, w):
46 | fin_out = (y - x) * w + x
47 | return fin_out
48 |
49 | def rand_perlin_2d_np(self, shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3):
50 | delta = (res[0] / shape[0], res[1] / shape[1])
51 | d = (shape[0] // res[0], shape[1] // res[1])
52 | grid = np.mgrid[0:res[0]:delta[0], 0:res[1]:delta[1]].transpose(1, 2, 0) % 1
53 |
54 | angles = 2 * math.pi * np.random.rand(res[0] + 1, res[1] + 1)
55 | gradients = np.stack((np.cos(angles), np.sin(angles)), axis=-1)
56 | tt = np.repeat(np.repeat(gradients, d[0], axis=0), d[1], axis=1)
57 |
58 | tile_grads = lambda slice1, slice2: np.repeat(
59 | np.repeat(gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]], d[0], axis=0), d[1], axis=1)
60 | dot = lambda grad, shift: (
61 | np.stack((grid[:shape[0], :shape[1], 0] + shift[0], grid[:shape[0], :shape[1], 1] + shift[1]),
62 | axis=-1) * grad[:shape[0], :shape[1]]).sum(axis=-1)
63 |
64 | n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
65 | n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
66 | n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
67 | n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
68 | t = fade(grid[:shape[0], :shape[1]])
69 | return math.sqrt(2) * self.lerp_np(self.lerp_np(n00, n10, t[..., 0]), self.lerp_np(n01, n11, t[..., 0]), t[..., 1])
70 |
71 |
72 | def get_patch_mask_and_intersect_fn(self,
73 | dim_bounds: List[Tuple[int, int]],
74 | img_dims: np.ndarray) \
75 | -> Tuple[npt.NDArray[bool], Callable[[npt.NDArray[float], npt.NDArray[float]], npt.NDArray[float]]]:
76 |
77 | perlin_scalex = 2 ** (np.random.randint(low = self.min_perlin_scale, high = self.perlin_scale, size=(1,))[0])
78 | perlin_scaley = 2 ** (np.random.randint(low = self.min_perlin_scale, high = self.perlin_scale, size=(1,))[0])
79 |
80 | noise_size_x = np.random.randint(low=dim_bounds[0][0],high=dim_bounds[0][1])
81 | noise_size_y = np.random.randint(low=dim_bounds[1][0],high=dim_bounds[1][1])
82 |
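        # Resample until the thresholded Perlin mask covers at least perlin_min_size of the patch.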
83 | while True:
84 | perlin_noise = self.rand_perlin_2d_np((noise_size_x, noise_size_y), (perlin_scalex, perlin_scaley))
85 |
86 | # apply affine transform
87 | rot = iaa.Affine(rotate=(-90, 90))
88 | perlin_noise = rot(image=perlin_noise)
89 |
90 | # make a mask by applying threshold
91 | mask_noise = np.where(
92 | perlin_noise > self.perlin_noise_threshold,
93 | np.ones_like(perlin_noise),
94 | np.zeros_like(perlin_noise)
95 | )
96 | mask_noise[mask_noise != 0] = 1.0
97 |
98 | if np.mean(mask_noise) >= self.perlin_min_size:
99 | break
100 |         return mask_noise.astype(bool), None
101 |
102 |
103 | def get_patch_mask(self, dim_bounds: List[Tuple[int, int]], img_dims: np.ndarray) -> npt.NDArray[bool]:
104 | return self.get_patch_mask_and_intersect_fn(dim_bounds, img_dims)[0]
105 |
106 |
107 |
108 | class DeformedHypershapePatchMaker(PatchShapeMaker):
109 |
110 | def __init__(self, aligned_distance_to_edge_fn: Callable[[npt.NDArray[int], npt.NDArray[float], npt.NDArray[float]],
111 | npt.NDArray[float]],
112 | sample_dist=lambda lb, ub, _: np.random.randint(lb, ub)):
113 | # Instead of mask function, need inclusion test function
114 | # takes shape + point as parameters
115 | self.aligned_distance_to_edge_fn = aligned_distance_to_edge_fn
116 | self.sample_dist = sample_dist
117 | self.rng = np.random.default_rng()
118 |
119 | @abstractmethod
120 | def within_aligned_shape(self, array_of_coords: npt.NDArray[float], shape_size: npt.NDArray[int]) \
121 | -> npt.NDArray[bool]:
122 | """
123 | Calculates whether a point is within an aligned shape with dimensions shape_size.
124 | :param array_of_coords:
125 | :param shape_size:
126 | """
127 |
128 | def make_shape_intersect_function(self,
129 | inv_trans_matrix: npt.NDArray[float],
130 | trans_matrix: npt.NDArray[float],
131 | shape_size: npt.NDArray[int]) \
132 | -> Callable[[npt.NDArray[float], npt.NDArray[float]], npt.NDArray[float]]:
133 |
134 | return lambda orig, direction: trans_matrix @ self.aligned_distance_to_edge_fn(shape_size,
135 | inv_trans_matrix @ orig,
136 | inv_trans_matrix @ direction)
137 |
138 | def get_patch_mask_and_intersect_fn(self,
139 | dim_bounds: List[Tuple[int, int]],
140 | img_dims: np.ndarray) \
141 | -> Tuple[npt.NDArray[bool], Callable[[npt.NDArray[float], npt.NDArray[float]], npt.NDArray[float]]]:
142 |
143 | shape = np.array([self.sample_dist(lb, ub, d) for ((lb, ub), d) in zip(dim_bounds, img_dims)])
144 |
145 | num_dims = len(dim_bounds)
146 |
147 | trans_matrix = np.identity(num_dims)
148 |
149 | # Instead of transforming mask, accumulate transformation matrix
150 | for d in range(num_dims):
151 | other_dim = self.rng.choice([i for i in range(num_dims) if i != d])
152 | shear_factor = self.rng.normal(scale=0.2)
153 | trans_matrix[d, other_dim] = shear_factor
154 |
155 |         # Rotate mask, using all possible axis combinations
156 | trans_matrix = functools.reduce(lambda m, ds: accumulate_rotation(m,
157 | self.rng.uniform(-np.pi / 2, np.pi / 2),
158 | ds),
159 | itertools.combinations(range(num_dims), 2),
160 | trans_matrix)
161 |
162 | # Using corner points, calculate size of resulting shape
163 | shape_width = (shape - 1) / 2
164 | corner_coord_grid = np.array(np.meshgrid(*np.stack([shape_width, -shape_width], axis=-1), indexing='ij'))
165 | corner_coords = corner_coord_grid.reshape(num_dims, 2 ** num_dims)
166 |
167 | trans_corner_coords = trans_matrix @ corner_coords
168 | min_trans_coords = np.floor(np.min(trans_corner_coords, axis=1))
169 | max_trans_coords = np.ceil(np.max(trans_corner_coords, axis=1))
170 | final_grid_shape = max_trans_coords + 1 - min_trans_coords
171 |
172 | # Check if transformations have made patch too big
173 | ub_shape = np.array([ub for _, ub in dim_bounds])
174 | if np.any(final_grid_shape > ub_shape):
175 | # If so, scale down to be within limits.
176 | max_scale_diff = np.max(final_grid_shape / ub_shape)
177 | trans_matrix = accumulate_scaling(trans_matrix, 1 / max_scale_diff)
178 |
179 | # Repeat calculations with new transformation matrix
180 | trans_corner_coords = trans_matrix @ corner_coords
181 | min_trans_coords = np.floor(np.min(trans_corner_coords, axis=1))
182 | max_trans_coords = np.ceil(np.max(trans_corner_coords, axis=1))
183 | final_grid_shape = max_trans_coords + 1 - min_trans_coords
184 |
185 | # Create meshgrid of coords of resulting shape
186 | coord_ranges = [np.arange(lb, ub + 1) for lb, ub in zip(min_trans_coords, max_trans_coords)]
187 | coord_grid = np.array(np.meshgrid(*coord_ranges, indexing='ij'))
188 |
189 | # Apply inverse transformation matrix, to compute sampling points
190 | inv_trans_matrix = np.linalg.inv(trans_matrix)
191 | inv_coord_grid_f = inv_trans_matrix @ np.reshape(coord_grid, (num_dims, -1))
192 |
193 | # Apply inclusion test function, giving an array containing a boolean for each coordinate.
194 | inv_result_grid_f = self.within_aligned_shape(inv_coord_grid_f, shape)
195 |
196 | return np.reshape(inv_result_grid_f, final_grid_shape.astype(int)), \
197 | self.make_shape_intersect_function(inv_trans_matrix, trans_matrix, shape)
198 |
199 |
200 | def get_patch_mask(self, dim_bounds: List[Tuple[int, int]], img_dims: np.ndarray) -> npt.NDArray[bool]:
201 | return self.get_patch_mask_and_intersect_fn(dim_bounds, img_dims)[0]
202 |
203 |
204 |
205 |
206 |
207 | def intersect_to_aligned_hyperrectangle_edge(hyperrectangle_shape: npt.NDArray[int],
208 | origin: npt.NDArray[float],
209 | direction: npt.NDArray[float]) \
210 | -> npt.NDArray[float]:
211 | """
212 |
213 | :param hyperrectangle_shape: Numpy array of hyperrectangle shape, (D)
214 | :param origin: Numpy array of origin coordinates (D, N)
215 | :param direction: Numpy array of normalised direction vectors (D, N)
216 | :return:
217 | """
218 | rect_width = np.reshape(hyperrectangle_shape / 2, (-1, 1))
219 | # Normalise direction magnitudes
220 | direction = direction / np.linalg.norm(direction, axis=0)
221 |
222 | max_dir = rect_width - origin
223 | min_dir = - rect_width - origin
224 |
225 | # Shape (2D, N)
226 | all_edge_distances = np.concatenate([max_dir / direction,
227 | min_dir / direction])
228 |
229 | # Set any behind distances to inf, so we don't choose them
230 | all_edge_distances[all_edge_distances < 0] = np.inf
231 |
232 | # Shape (N,)
233 | min_distances = np.min(all_edge_distances, axis=0)
234 |
235 | min_is_inf = min_distances == np.inf
236 | assert not np.any(min_is_inf), f'Some lines are outside bounding box:\n' \
237 | f'rectangle shape: {hyperrectangle_shape},' \
238 | f'origins - {origin[min_is_inf]},\n' \
239 | f'directions - {direction[min_is_inf]}'
240 |
241 | return origin + direction * min_distances
242 |
243 |
244 |
245 | class DeformedHyperrectanglePatchMaker(DeformedHypershapePatchMaker):
246 |
247 | def __init__(self, sample_dist=lambda lb, ub, _: np.random.randint(lb, ub)):
248 | super().__init__(intersect_to_aligned_hyperrectangle_edge, sample_dist)
249 |
250 | def within_aligned_shape(self, array_of_coords: npt.NDArray[float], shape_size: npt.NDArray[int]) \
251 | -> npt.NDArray[bool]:
252 | rect_width = np.reshape(shape_size / 2, (-1, 1))
253 | return np.all(-rect_width <= array_of_coords, axis=0) & np.all(array_of_coords <= rect_width, axis=0)
254 |
255 |
256 |
257 | def intersect_aligned_hyperellipse_edge(hyperellipse_shape: npt.NDArray[int],
258 | origin: npt.NDArray[float],
259 | direction: npt.NDArray[float]) \
260 | -> npt.NDArray[float]:
261 | ellipse_radii_sq = np.reshape((hyperellipse_shape / 2) ** 2, (-1, 1))
262 | # Normalise direction magnitudes
263 | direction = direction / np.linalg.norm(direction, axis=0)
264 |
265 | # Compute quadratic coefficients, all shape (N)
266 | a = np.sum(direction ** 2 / ellipse_radii_sq, axis=0)
267 | b = np.sum(2 * origin * direction / ellipse_radii_sq, axis=0)
268 | c = np.sum(origin ** 2 / ellipse_radii_sq, axis=0) - 1
269 |
270 | # Solve quadratic, (N)
271 | det = b ** 2 - 4 * a * c
272 |
273 | det_is_negative = det < 0
274 | assert not np.any(det_is_negative), f'Some lines never intersect ellipse:\n' \
275 | f'Ellipse shape: {hyperellipse_shape}\n' \
276 |                                         f'origins: {origin[det_is_negative]},\n' \
277 | f'directions: {direction[det_is_negative]}'
278 |
279 | solutions = (-b + np.array([[1], [-1]]) * np.sqrt(det)) / (2 * a)
280 |
281 | # Make any negative solutions (behind origin) infinity so we don't choose them
282 | solutions[solutions < 0] = np.inf
283 |
284 | min_solutions = np.min(solutions, axis=0)
285 | min_is_inf = min_solutions == np.inf
286 | assert not np.any(min_is_inf), f'Some lines are outside ellipse:\n' \
287 | f'ellipse shape: {hyperellipse_shape},' \
288 | f'origins - {origin[min_is_inf]},\n' \
289 | f'directions - {direction[min_is_inf]}'
290 | return origin + direction * min_solutions
291 |
292 |
293 | class DeformedHyperellipsePatchMaker(DeformedHypershapePatchMaker):
294 |
295 | def __init__(self, sample_dist=lambda lb, ub, _: np.random.randint(lb, ub)):
296 | super().__init__(intersect_aligned_hyperellipse_edge, sample_dist)
297 |
298 | def within_aligned_shape(self, array_of_coords: npt.NDArray[float], shape_size: npt.NDArray[int]) \
299 | -> npt.NDArray[bool]:
300 | ellipse_radii = np.reshape(shape_size / 2, (-1, 1))
301 | return np.sum(array_of_coords ** 2 / ellipse_radii ** 2, axis=0) <= 1
302 |
303 |
304 | class CombinedDeformedHypershapePatchMaker(PatchShapeMaker):
305 |
306 | def __init__(self, sample_dist=lambda lb, ub, _: np.random.randint(lb, ub)):
307 |
308 | self.rect_maker = DeformedHyperrectanglePatchMaker(sample_dist)
309 | self.ellip_maker = DeformedHyperellipsePatchMaker(sample_dist)
310 | self.last_choice = None
311 |
312 | def get_patch_mask(self, dim_bounds: List[Tuple[int, int]], img_dims: np.ndarray) -> npt.NDArray[bool]:
313 |
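        # 'rect' and 'ellipse' return a single deformed shape; 'comb' pastes the smaller of the two
        # masks onto a random point of the larger one and returns the convex hull of the union.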
314 | mode = np.random.choice(['rect', 'ellipse', 'comb'])
315 | self.last_choice = mode
316 |
317 | if mode == 'rect':
318 | return self.rect_maker(dim_bounds, img_dims)
319 |
320 | elif mode == 'ellipse':
321 | return self.ellip_maker(dim_bounds, img_dims)
322 |
323 | elif mode == 'comb':
324 |
325 | rect_mask = self.rect_maker(dim_bounds, img_dims)
326 | ellip_mask = self.ellip_maker(dim_bounds, img_dims)
327 |
328 | rect_size = np.sum(rect_mask)
329 | ellip_size = np.sum(ellip_mask)
330 |
331 | big_m, small_m = (rect_mask, ellip_mask) if rect_size >= ellip_size else (ellip_mask, rect_mask)
332 |
333 | # Choose point on big mask to put small mask
334 | mask_coords = np.nonzero(big_m)
335 |
336 | point_ind = self.rect_maker.rng.integers(len(mask_coords[0]))
337 |
338 | chosen_coord = np.array([m_cs[point_ind] for m_cs in mask_coords])
339 |
340 | small_shape = np.array(small_m.shape)
341 | lower_coord = chosen_coord - small_shape // 2
342 |
343 | if np.any(lower_coord < 0):
344 | to_pad_below = np.maximum(-lower_coord, 0)
345 | big_m = np.pad(big_m, [(p, 0) for p in to_pad_below])
346 | lower_coord += to_pad_below
347 |
348 | big_shape = np.array(big_m.shape)
349 |
350 | upper_coord = lower_coord + small_shape
351 | if np.any(upper_coord > big_shape):
352 | to_pad_above = np.maximum(upper_coord - big_shape, 0)
353 | big_m = np.pad(big_m, [(0, p) for p in to_pad_above])
354 |
355 | big_m[tuple([slice(lb, ub) for lb, ub in zip(lower_coord, upper_coord)])] |= small_m
356 |
357 | return convex_hull_image(big_m)
358 |
359 | else:
360 | raise Exception('Invalid mask option')
361 |
362 |
363 | class EitherDeformedHypershapePatchMaker(PatchShapeMaker):
364 |
365 | def __init__(self,
366 | sample_dist=lambda lb, ub, _: np.random.randint(lb, ub)):
367 |
368 | self.rect_maker = DeformedHyperrectanglePatchMaker(sample_dist)
369 | self.ellip_maker = DeformedHyperellipsePatchMaker(sample_dist)
370 |
371 | def get_patch_mask_and_intersect_fn(self,
372 | dim_bounds: List[Tuple[int, int]],
373 | img_dims: np.ndarray) \
374 | -> Tuple[npt.NDArray[bool], Callable[[npt.NDArray[float], npt.NDArray[float]], npt.NDArray[float]]]:
375 |
376 | chosen_task = np.random.choice([self.rect_maker, self.ellip_maker])
377 | return chosen_task.get_patch_mask_and_intersect_fn(dim_bounds, img_dims)
378 |
379 |
380 | def get_patch_mask(self, dim_bounds: List[Tuple[int, int]], img_dims: np.ndarray) -> npt.NDArray[bool]:
381 | return self.get_patch_mask_and_intersect_fn(dim_bounds, img_dims)[0]
382 |
--------------------------------------------------------------------------------
/medsyn/tasks.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Callable, Optional, Tuple, Union
3 | import functools
4 | import itertools
5 | import numpy as np
6 | import numpy.typing as npt
7 | from scipy.ndimage import affine_transform,distance_transform_edt
8 | from numpy.linalg import norm
9 | import random
10 | from scipy import ndimage
11 | from scipy.ndimage import gaussian_filter
12 |
13 | from .labelling import FlippedGaussianLabeller,AnomalyLabeller
14 | from .task_shape import EitherDeformedHypershapePatchMaker,PerlinPatchMaker
15 | from .utils import *
16 |
17 | def cut_paste(sample: npt.NDArray[float],
18 | source_to_blend: npt.NDArray[float],
19 | anomaly_corner: npt.NDArray[int],
20 | anomaly_mask: npt.NDArray[bool]) -> npt.NDArray[float]:
21 |
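    # Hard cut-paste: copy the masked region of the transformed source patch directly onto the
    # target sample, with no intensity blending.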
22 | repeated_mask = np.broadcast_to(anomaly_mask, source_to_blend.shape)
23 |
24 | sample_slices = get_patch_image_slices(anomaly_corner, tuple(anomaly_mask.shape))
25 |
26 | aug_sample = sample.copy()
27 | aug_sample[sample_slices][repeated_mask] = source_to_blend[repeated_mask]
28 |
29 | return aug_sample
30 |
31 |
32 |
33 | class BaseTask(ABC):
34 | def __init__(self,
35 | sample_labeller: Optional[AnomalyLabeller] = None,
36 | **all_kwargs):
37 |
38 | self.sample_labeller = sample_labeller
39 | self.rng = np.random.default_rng()
40 |
41 | self.min_anom_prop=0.3
42 | self.max_anom_prop=0.8
43 |
44 | self.anomaly_shape_maker = EitherDeformedHypershapePatchMaker(nsa_sample_dimension)
45 | self.all_kwargs = all_kwargs
46 |
47 |
48 | def apply(self,
49 | sample: npt.NDArray[float],
50 | *args, **kwargs)\
51 | -> Tuple[npt.NDArray[float], npt.NDArray[float]]:
52 | """
53 | Apply the self-supervised task to the single data sample.
54 | :param sample: Normal sample to be augmented
55 | :param sample_mask: Object mask of sample.
56 | :return: sample with task applied and label map.
57 | """
58 |
59 | aug_sample = sample.copy()
60 |
61 | sample_shape = np.array(sample.shape[1:])
62 | anomaly_mask = np.zeros(sample_shape, dtype=bool)
63 |
64 | min_anom_prop = self.min_anom_prop
65 | max_anom_prop = self.max_anom_prop
66 |
67 | min_dim_lens = (min_anom_prop * sample_shape).round().astype(int)
68 | max_dim_lens = (max_anom_prop * sample_shape).round().astype(int)
69 | # print(min_dim_lens,max_dim_lens) [15,15],[205,205]
70 |
71 | dim_bounds = list(zip(min_dim_lens, max_dim_lens)) #[(15, 205), (15, 205)]
72 |
73 | # For random number of times
74 | sample_mask = None
75 |
76 | for i in range(2):
77 |
78 | # Compute anomaly mask
79 | curr_anomaly_mask, intersect_fn = self.anomaly_shape_maker.get_patch_mask_and_intersect_fn(dim_bounds,
80 | sample_shape)
81 |
82 | # Choose anomaly location
83 | anomaly_corner = self.find_valid_anomaly_location(curr_anomaly_mask, sample_mask, sample_shape)
84 |
85 | # Apply self-supervised task
86 |
87 | aug_sample = self.augment_sample(aug_sample, sample_mask, anomaly_corner, curr_anomaly_mask, intersect_fn)
88 |
89 | anomaly_mask[get_patch_slices(anomaly_corner, curr_anomaly_mask.shape)] |= curr_anomaly_mask
90 |
91 |             # Randomly break at the end of the loop, ensuring we get at least 1 anomaly
92 | if self.rng.random() > 0.5:
93 | break
94 |
95 | if self.sample_labeller is not None:
96 | return aug_sample, self.sample_labeller(aug_sample, sample, anomaly_mask)
97 | else:
98 | # If no labeller is provided, we are probably in a calibration process
99 | return aug_sample, np.expand_dims(anomaly_mask, 0)
100 |
101 |
102 | def find_valid_anomaly_location(self,
103 | curr_anomaly_mask: npt.NDArray[bool],
104 | sample_mask: Optional[npt.NDArray[bool]],
105 | sample_shape: npt.NDArray[int]):
106 |
107 | curr_anomaly_shape = np.array(curr_anomaly_mask.shape)
108 | min_corner = np.zeros(len(sample_shape))
109 | max_corner = sample_shape - curr_anomaly_shape
110 |
111 | # - Apply anomaly at location
112 | while True:
113 | anomaly_corner = self.rng.integers(min_corner, max_corner, endpoint=True)
114 |
115 | # If the sample mask is None, any location within the bounds is valid
116 | if sample_mask is None:
117 | break
118 | # Otherwise, we need to check that the intersection of the anomaly mask and the sample mask is at least 50%
119 | target_patch_obj_mask = sample_mask[get_patch_slices(anomaly_corner, curr_anomaly_mask.shape)]
120 | if (np.sum(target_patch_obj_mask & curr_anomaly_mask) / np.sum(curr_anomaly_mask)) >= 0.5:
121 | break
122 |
123 | return anomaly_corner
124 |
125 |
126 | def __call__(self,
127 | sample: npt.NDArray[float],
128 | *args,
129 | **kwargs)\
130 | -> Tuple[npt.NDArray[float], npt.NDArray[float]]:
131 | """
132 | Apply the self-supervised task to the single data sample.
133 | :param sample: Normal sample to be augmented
134 | :param sample_mask: Object mask of sample.
135 | :param **kwargs:
136 | * *sample_path*: Path to source image
137 | :return: sample with task applied and label map.
138 | """
139 | if len(sample.shape)==2:
140 | sample = np.expand_dims(sample,axis=0)
141 |
142 | aug_sample, aug_mask = self.apply(sample, *args, **kwargs)
143 |
144 | if len(aug_sample.shape)==3 and aug_sample.shape[0]==1:
145 | aug_sample = aug_sample.squeeze(0)
146 |
147 | if len(aug_mask.shape)==3 and aug_mask.shape[0]==1:
148 | aug_mask = aug_mask.squeeze(0)
149 |
150 |         return aug_sample, aug_mask.astype(float)
151 |
152 |
153 | @abstractmethod
154 | def augment_sample(self, sample: npt.NDArray[float], sample_mask: Optional[npt.NDArray[bool]],
155 | anomaly_corner: npt.NDArray[int], anomaly_mask: npt.NDArray[bool],
156 | anomaly_intersect_fn: Callable[[npt.NDArray[float], npt.NDArray[float]], npt.NDArray[float]]) \
157 | -> npt.NDArray[float]:
158 | """
159 | Apply self-supervised task to region at anomaly_corner covered by anomaly_mask
160 | :param sample: Sample to be augmented.
161 | :param sample_mask: Object mask of sample.
162 | :param anomaly_corner: Index of anomaly corner.
163 | :param anomaly_mask: Mask
164 | :param anomaly_intersect_fn: Function which, given a line's origin and direction, finds its intersection with
165 | the edge of the anomaly mask
166 | :return:
167 | """
168 |
169 |
170 | class BasePatchBlendingTask(BaseTask):
171 |
172 | def __init__(self,
173 | sample_labeller: Optional[AnomalyLabeller],
174 | source_samples: list,
175 | blend_images: Callable[[npt.NDArray[float], npt.NDArray[float], npt.NDArray[int], npt.NDArray[bool]],
176 | npt.NDArray[float]],
177 |
178 | **all_kwargs):
179 | super().__init__(sample_labeller, **all_kwargs)
180 | self.source_samples = source_samples
181 | self.blend_images = blend_images
182 |
183 |
184 | def augment_sample(self,
185 | sample: npt.NDArray[float], # aug sample
186 | sample_mask: Optional[npt.NDArray[bool]], # None
187 | anomaly_corner: npt.NDArray[int], # center
188 | anomaly_mask: npt.NDArray[bool], # small anomaly mask
189 | anomaly_intersect_fn: Callable[[npt.NDArray[float], npt.NDArray[float]], npt.NDArray[float]]) \
190 | -> npt.NDArray[float]:
191 |
192 | num_channels = sample.shape[0] # 1
193 | num_dims = len(sample.shape[1:]) #2
194 |
195 | # Sample source to blend into current sample
196 | source_sample = random.choice(self.source_samples)
197 |
198 | source_sample_shape = np.array(source_sample.shape[1:]) #(256,256)
199 |
200 |
201 | assert len(source_sample_shape) == num_dims, 'Source and target have different number of spatial dimensions: ' \
202 | f's-{len(source_sample_shape)}, t-{num_dims}'
203 |
204 | assert source_sample.shape[0] == num_channels, \
205 | f'Source and target have different number of channels: s-{source_sample.shape[0]}, t-{num_channels}'
206 |
207 | # Compute INVERSE transformation matrix for parameters (rotation, resizing)
208 | # This is the backwards operation (final source region -> initial source region).
209 |
210 | trans_matrix = functools.reduce(lambda m, ds: accumulate_rotation(m,
211 | self.rng.uniform(-np.pi / 4, np.pi / 4),
212 | ds),
213 | itertools.combinations(range(num_dims), 2),
214 | np.identity(num_dims))
215 |
216 | # Compute effect on corner coords
217 | target_anomaly_shape = np.array(anomaly_mask.shape)
218 |
219 | corner_coords = np.array(np.meshgrid(*np.stack([np.zeros(num_dims), target_anomaly_shape], axis=-1),
220 | indexing='ij')).reshape(num_dims, 2 ** num_dims)
221 |
222 | trans_corner_coords = trans_matrix @ corner_coords
223 | min_trans_coords = np.floor(np.min(trans_corner_coords, axis=1))
224 | max_trans_coords = np.ceil(np.max(trans_corner_coords, axis=1))
225 | init_grid_shape = max_trans_coords - min_trans_coords
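# Rough worked example (illustrative numbers only): a (10, 10) anomaly patch rotated by
# t = 30 degrees has transformed corners spanning about 10*|cos t| + 10*|sin t| ~= 13.7 per
# axis, so after the floor/ceil above init_grid_shape comes out around (14, 14): the source
# region must be cut slightly larger than the target patch before it is rotated/scaled into
# place.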
226 |
227 | # Sample scale and clip so that source region isn't too big
228 | max_scale = np.min(0.8 * source_sample_shape / init_grid_shape)
229 |
230 | # Compute final transformation matrix
231 | scale_change = 1 + self.rng.exponential(scale=0.1)
232 | scale_raw = self.rng.choice([scale_change, 1 / scale_change])
233 | scale = np.minimum(scale_raw, max_scale)
234 |
235 | trans_matrix = accumulate_scaling(trans_matrix, scale)
236 |
237 | # Recompute effect on corner coords
238 | trans_corner_coords = trans_matrix @ corner_coords
239 | min_trans_coords = np.floor(np.min(trans_corner_coords, axis=1))
240 | max_trans_coords = np.ceil(np.max(trans_corner_coords, axis=1))
241 | final_init_grid_shape = max_trans_coords - min_trans_coords
242 |
243 | # Choose anomaly source location
244 | final_init_grid_shape = final_init_grid_shape.astype(int)
245 | min_corner = np.zeros(len(source_sample_shape))
246 | max_corner = source_sample_shape - final_init_grid_shape
247 |
248 | source_corner = self.rng.integers(min_corner, max_corner, endpoint=True)
249 |
250 | # Extract source
251 | source_orig = source_sample[get_patch_image_slices(source_corner, tuple(final_init_grid_shape))]
252 |
253 |
254 | # Because we computed the backwards transformation we don't need to invert the matrix
255 | source_to_blend = np.stack([affine_transform(chan, trans_matrix, offset=-min_trans_coords,
256 | output_shape=tuple(target_anomaly_shape))
257 | for chan in source_orig])
258 |
259 | spatial_axis = tuple(range(1, len(source_sample.shape)))
260 | # Spline interpolation can make values fall outside domain, so clip to the original range
261 | source_to_blend = np.clip(source_to_blend,
262 | source_sample.min(axis=spatial_axis, keepdims=True),
263 | source_sample.max(axis=spatial_axis, keepdims=True))
264 |
265 |
266 | # As the blending can alter areas outside the mask, update the mask with any affected areas
267 |
268 | aug_sample = self.blend_images(sample, source_to_blend, anomaly_corner, anomaly_mask)
269 |
270 | sample_slices = get_patch_image_slices(anomaly_corner, tuple(anomaly_mask.shape))
271 | sample_diff = np.mean(np.abs(sample[sample_slices] - aug_sample[sample_slices]), axis=0)
272 |
273 | anomaly_mask[sample_diff > 0.001] = True
274 | # Return sample with source blended into it
275 | return aug_sample
276 |
277 |
278 |
279 | class BaseDeformationTask(BaseTask):
280 |
281 | @abstractmethod
282 | def compute_mapping(self,
283 | sample: npt.NDArray[float],
284 | sample_mask: Optional[npt.NDArray[bool]],
285 | anomaly_corner: npt.NDArray[int], anomaly_mask: npt.NDArray[bool],
286 | anomaly_intersect_fn: Callable[[npt.NDArray[float], npt.NDArray[float]], npt.NDArray[float]]) \
287 | -> npt.NDArray[float]:
288 | """
289 | Returns array of size (*anomaly_mask.shape, len(anomaly_mask.shape)).
290 | Probably don't need entire sample, but including in for generality.
291 | :param sample:
292 | :param sample_mask:
293 | :param anomaly_corner:
294 | :param anomaly_mask:
295 | :param anomaly_intersect_fn:
296 | :return:
297 | """
298 |
299 | def augment_sample(self,
300 | sample: npt.NDArray[float],
301 | sample_mask: Optional[npt.NDArray[bool]],
302 | anomaly_corner: npt.NDArray[int],
303 | anomaly_mask: npt.NDArray[bool],
304 | anomaly_intersect_fn: Callable[[npt.NDArray[float], npt.NDArray[float]], npt.NDArray[float]]) \
305 | -> npt.NDArray[float]:
306 |
307 | num_channels = sample.shape[0]
308 | mapping = self.compute_mapping(sample, sample_mask, anomaly_corner, anomaly_mask, anomaly_intersect_fn)
309 | sample_slices = get_patch_slices(anomaly_corner, tuple(anomaly_mask.shape))
310 |
311 | for chan in range(num_channels):
312 | sample[chan][sample_slices] = ndimage.map_coordinates(sample[chan][sample_slices],
313 | mapping,
314 | mode='nearest')
315 | return sample
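# Mapping convention: scipy.ndimage.map_coordinates treats mapping[:, i, j] as the input
# coordinate to sample for output pixel (i, j), so compute_mapping must return an array of
# shape (len(anomaly_mask.shape), *anomaly_mask.shape); coordinates left at their identity
# values produce no visible deformation.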
316 |
317 |
318 |
319 | class RadialDeformationTask(BaseDeformationTask):
320 |
321 | def __init__(self,
322 | sample_labeller: Optional[AnomalyLabeller] = None,
323 | deform_factor: Optional[float] = None,
324 | deform_centre: Optional[npt.NDArray] = None, **kwargs):
325 |
326 | super().__init__(sample_labeller, **kwargs)
327 | self.deform_factor = deform_factor
328 | self.deform_centre = deform_centre
329 | self.max_anom_prop = 0.6
330 | self.min_anom_prop = 0.2
331 |
332 | def get_deform_factor(self, def_centre: npt.NDArray[int], anomaly_mask: npt.NDArray[bool]):
333 | return self.deform_factor if self.deform_factor is not None else 2 ** self.rng.uniform(0.5, 2)
334 |
335 | @abstractmethod
336 | def compute_new_distance(self, curr_distance: float, max_distance: float, factor: float) -> float:
337 | """
338 | Compute new distance for point to be sampled from
339 | :param curr_distance:
340 | :param max_distance:
341 | :param factor:
342 | """
343 |
344 | def compute_mapping(self,
345 | sample: npt.NDArray[float],
346 | sample_mask: Optional[npt.NDArray[bool]],
347 | anomaly_corner: npt.NDArray[int],
348 | anomaly_mask: npt.NDArray[bool],
349 | anomaly_intersect_fn: Callable[[npt.NDArray[float], npt.NDArray[float]], npt.NDArray[float]]) \
350 | -> npt.NDArray[float]:
351 | # NOTE: This assumes that the anomaly shape is convex; it will create discontinuities if it is not.
352 |
353 | anomaly_shape = np.array(anomaly_mask.shape)
354 | num_dims = len(anomaly_shape)
355 |
356 | # Expand so can later be broadcast with (D, N)
357 | mask_centre = (anomaly_shape - 1) / 2
358 | exp_mask_centre = np.reshape(mask_centre, (-1, 1))
359 | # Shape (D, N)
360 | poss_centre_coords = np.stack(np.nonzero(anomaly_mask))
361 | def_centre = self.deform_centre if self.deform_centre is not None else \
362 | poss_centre_coords[:, np.random.randint(poss_centre_coords.shape[1])]
363 |
364 | assert anomaly_mask[tuple(def_centre.round().astype(int))], f'Centre is not within anomaly: {def_centre}'
365 |
366 | # exp_ = expanded
367 | exp_def_centre = np.reshape(def_centre, (-1, 1))
368 |
369 | # (D, *anomaly_shape)
370 | mapping = np.stack(np.meshgrid(*[np.arange(s, dtype=float) for s in anomaly_shape], indexing='ij'), axis=0)
371 |
372 | # Ignore pixels on edge of bounding box
373 | mask_inner_slice = tuple([slice(1, -1)] * num_dims)
374 | map_inner_slice = tuple([slice(None)] + list(mask_inner_slice))
375 | # Get the coordinates of all in-mask pixels as a (D, N) array (coordinate axis first)
376 | anomaly_coords = mapping[map_inner_slice][(slice(None), anomaly_mask[mask_inner_slice])]
377 |
378 | all_coords_to_centre = anomaly_coords - exp_def_centre
379 | all_coords_distance = norm(all_coords_to_centre, axis=0)
380 | # Ignore zero divided by zero, as we correct it before mapping is returned
381 | with np.errstate(invalid='ignore'):
382 | all_coords_norm_dirs = all_coords_to_centre / all_coords_distance
383 |
384 | mask_edge_intersections = anomaly_intersect_fn(exp_def_centre - exp_mask_centre, all_coords_norm_dirs) + exp_mask_centre
385 |
386 | mask_edge_distances = norm(mask_edge_intersections - exp_def_centre, axis=0)
387 |
388 | # Get factor once, so is same for all pixels
389 | def_factor = self.get_deform_factor(def_centre, anomaly_mask)
390 | new_coord_distances = self.compute_new_distance(all_coords_distance, mask_edge_distances, def_factor)
391 | # (D, N)
392 | new_coords = exp_def_centre + new_coord_distances * all_coords_norm_dirs
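# Geometric summary: each in-mask pixel keeps its direction from the deformation centre but
# has its distance remapped by compute_new_distance, with max_distance being the
# centre-to-mask-edge distance along that direction; for the Sink/Source subclasses below
# this keeps the resampled coordinates inside the (assumed convex) anomaly mask.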
393 |
394 | mapping[map_inner_slice][(slice(None), anomaly_mask[mask_inner_slice])] = new_coords
395 |
396 | # Revert centre coordinate, as it will be nan due to the zero magnitude direction vector
397 | mapping[(slice(None), *def_centre)] = def_centre
398 | return mapping
399 |
400 |
401 |
402 | class CutPastePatchBlender(BasePatchBlendingTask):
403 |
404 | def __init__(self,
405 | source_images: list,
406 | labeller_std: float = 0.2,
407 | **kwargs):
408 | sample_labeller = FlippedGaussianLabeller(labeller_std)
409 | source_images = [np.expand_dims(image, axis=0) if len(image.shape) == 2 else image for image in source_images]
410 | super().__init__(sample_labeller, source_images, cut_paste)
411 | self.max_anom_prop = 0.6
412 | self.min_anom_prop = 0.1
413 | self.anomaly_shape_maker = PerlinPatchMaker()
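# Usage sketch (hypothetical names; normal_images is assumed to be a list of (H, W) or
# (C, H, W) float arrays of normal samples):
#     task = CutPastePatchBlender(source_images=normal_images)
#     aug_image, label_map = task(normal_images[0])
# BaseTask.__call__ above returns the augmented sample together with a float label map of
# matching spatial size.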
414 |
415 |
416 |
417 | class SmoothIntensityChangeTask(BaseTask):
418 |
419 | def __init__(self,
420 | intensity_task_scale: float,
421 | sample_labeller: Optional[AnomalyLabeller] = None,
422 | **all_kwargs):
423 |
424 | super().__init__(sample_labeller, **all_kwargs)
425 | self.intensity_task_scale = intensity_task_scale
426 | self.max_anom_prop = 0.8
427 | self.min_anom_prop = 0.3
428 | self.anomaly_shape_maker = PerlinPatchMaker()
429 |
430 | def augment_sample(self,
431 | sample: npt.NDArray[float], # aug_sample
432 | sample_mask: Optional[npt.NDArray[bool]],# None
433 | anomaly_corner: npt.NDArray[int], # anomaly center
434 | anomaly_mask: npt.NDArray[bool], # small anomaly mask
435 | anomaly_intersect_fn: Callable[[npt.NDArray[float], npt.NDArray[float]], npt.NDArray[float]]) \
436 | -> npt.NDArray[float]:
437 |
438 | num_chans = sample.shape[0] # 1
439 | sample_shape = sample.shape[1:] #(256,256)
440 | num_dims = len(sample_shape) # 2
441 |
442 | dist_map = distance_transform_edt(anomaly_mask)
443 | min_shape_dim = np.min(sample_shape) # 256
444 |
445 | smooth_dist = np.minimum(min_shape_dim * (0.02 + np.random.gamma(3, 0.01)), np.max(dist_map))
446 | smooth_dist_map = dist_map / smooth_dist
447 | smooth_dist_map[smooth_dist_map > 1] = 1
448 | # smooth_dist_map = 1
449 |
450 | anomaly_patch_slices = get_patch_image_slices(anomaly_corner, anomaly_mask.shape)
451 |
452 | # anomaly_pixel_stds = np.array([np.std(c[anomaly_mask]) for c in sample[anomaly_patch_slices]])
453 | # Randomly negate, so some intensity changes are subtractions
454 |
455 | intensity_changes = (self.intensity_task_scale / 2 + np.random.gamma(3, self.intensity_task_scale)) \
456 | * np.random.choice([1, -1], size=num_chans)
457 |
458 | intensity_change_map = smooth_dist_map * np.reshape(intensity_changes, [-1] + [1] * num_dims)
459 |
460 | new_patch = sample[anomaly_patch_slices] + intensity_change_map
461 |
462 | spatial_axis = tuple(range(1, len(sample.shape)))
463 |
464 | sample[anomaly_patch_slices] = np.clip(new_patch,
465 | sample.min(axis=spatial_axis, keepdims=True),
466 | sample.max(axis=spatial_axis, keepdims=True))
467 |
468 | return sample
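# Magnitude sketch: the per-channel offset is intensity_task_scale / 2 + Gamma(shape=3,
# scale=intensity_task_scale), i.e. a mean magnitude of 3.5 * intensity_task_scale, with a
# random sign per channel; it is modulated by the clipped distance map so it fades out
# towards the anomaly boundary, and the result is clipped back to the per-channel intensity
# range of the original sample.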
469 |
470 |
471 |
472 | class GaussIntensityChangeTask(BaseTask):
473 |
474 | def __init__(self,
475 | sample_labeller: Optional[AnomalyLabeller] = None,
476 | **all_kwargs):
477 |
478 | super().__init__(sample_labeller, **all_kwargs)
479 | self.max_anom_prop = 0.8
480 | self.min_anom_prop = 0.3
481 | self.sigma_bs = [4, 7]
482 | self.positive_range = [0.4, 0.6]
483 | self.negative_range = [-0.6, -0.4]
484 | self.anomaly_shape_maker = PerlinPatchMaker()
485 | def get_predefined_texture(self,
486 | mask_shape,
487 | sigma_b,
488 | positive_range=None,
489 | negative_range=None,
490 | ):
491 |
492 | assert (positive_range is not None) or (negative_range is not None)
493 |
494 | random_sample = np.random.randn(mask_shape[0], mask_shape[1])
495 |
496 | random_sample = (random_sample >= 0.0).astype(float)  # cast to float: filtering an int/bool mask would not give a smooth result
497 |
498 | random_sample = gaussian_filter(random_sample, sigma_b)
499 |
500 | random_sample = (random_sample - np.min(random_sample)) / (np.max(random_sample) - np.min(random_sample))
501 |
502 | if np.random.uniform(0, 1) <= 0.5:
503 | u_0 = np.random.uniform(positive_range[0], positive_range[1])
504 | else:
505 | if negative_range is not None:
506 | u_0 = np.random.uniform(negative_range[0], negative_range[1])
507 | else:
508 | u_0 = np.random.uniform(-positive_range[1], -positive_range[0])
509 |
510 | Bj = np.clip(u_0 * random_sample, -1, 1)
511 | return Bj
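# Texture summary: the block above thresholds Gaussian noise at zero, blurs the binary mask
# with sigma_b, rescales it to [0, 1], and multiplies by an amplitude u_0 drawn from
# positive_range or negative_range, so the returned texture lies between 0 and u_0
# (within [-0.6, 0.6] for the defaults set in __init__).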
512 |
513 | def create_texture(self,sizes):
514 | texture = self.get_predefined_texture(sizes,
515 | random.choice(self.sigma_bs),
516 | self.positive_range,
517 | self.negative_range)
518 | return texture
519 |
520 |
521 | def augment_sample(self,
522 | sample: npt.NDArray[float], # aug_sample
523 | sample_mask: Optional[npt.NDArray[bool]],# None
524 | anomaly_corner: npt.NDArray[int], # anomaly center
525 | anomaly_mask: npt.NDArray[bool], # small anomaly mask
526 | anomaly_intersect_fn: Callable[[npt.NDArray[float], npt.NDArray[float]], npt.NDArray[float]]) \
527 | -> npt.NDArray[float]:
528 |
529 | anomaly_mask_copy = anomaly_mask.astype(float)  # np.float was removed in NumPy 1.24; use the builtin float
530 | anomaly_patch_slices = get_patch_image_slices(anomaly_corner, anomaly_mask_copy.shape)
531 |
532 | texture = self.create_texture(sample.shape[1:])
533 |
534 | while True:
535 | if len(texture.shape) npt.NDArray[float]:
569 | anomaly_mask[:,:] = False
570 | return sample
571 |
572 |
573 |
574 | class SinkDeformationTask(RadialDeformationTask):
575 | # y = 1 - (1 - x)^3 (between 0 and 1)
576 | # -> y = max_d * (1 - (1 - curr / max_d) ^ factor)
577 | # -> y = max_d - (max_d - curr) ^ factor / max_d ^ (factor - 1)
578 |
579 | def compute_new_distance(self, curr_distance: Union[float, npt.NDArray[float]],
580 | max_distance: Union[float, npt.NDArray[float]],
581 | factor: Union[float, npt.NDArray[float]]) -> Union[float, npt.NDArray[float]]:
582 |
583 | return max_distance - (max_distance - curr_distance) ** factor / max_distance ** (factor - 1)
584 |
585 |
586 |
587 | class SourceDeformationTask(RadialDeformationTask):
588 |
589 | def compute_new_distance(self, curr_distance: Union[float, npt.NDArray[float]],
590 | max_distance: Union[float, npt.NDArray[float]],
591 | factor: Union[float, npt.NDArray[float]]) -> Union[float, npt.NDArray[float]]:
592 | # y = x^3 (between 0 and 1)
593 | # -> y = max_d * (curr / max) ^ factor
594 | # -> y = curr ^ factor / max_d ^ (factor - 1) to avoid FP errors
595 | return curr_distance ** factor / max_distance ** (factor - 1)
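# Worked check for both deformations (factor = 3, max_distance = 10):
#   SinkDeformationTask:   curr = 5 -> 10 - (10 - 5)**3 / 10**2 = 8.75
#   SourceDeformationTask: curr = 5 -> 5**3 / 10**2 = 1.25
# Both mappings fix the endpoints (0 -> 0 and max_distance -> max_distance), so the
# deformation blends continuously into the untouched region at the edge of the anomaly mask.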
--------------------------------------------------------------------------------