├── MOODv2
│   ├── src
│   │   ├── __init__.py
│   │   ├── list_dataset.py
│   │   ├── extract_feature_vit.py
│   │   └── demo.py
│   ├── imgs
│   │   ├── framework.png
│   │   ├── distribution.png
│   │   ├── moodv2_table.png
│   │   ├── performance.png
│   │   ├── DTD_cracked_0004.jpg
│   │   └── ImageNet_ILSVRC2012_val_00000293.JPEG
│   ├── configs
│   │   ├── vit-base-p16_384px.py
│   │   ├── vit-base-p14_224px.py
│   │   ├── vit-base-p14_518px.py
│   │   ├── vit-base-p16_224px.py
│   │   ├── vit-huge-p14_448px.py
│   │   ├── vit-huge-p14_224px.py
│   │   ├── swin-base-w7_224px.py
│   │   ├── beit-large-p14_224px.py
│   │   ├── beit-base-p16_224px.py
│   │   ├── beit-base-p16_384px.py
│   │   ├── pre-beit-base-p16_224px.py
│   │   ├── pre-beitv2-base-p16_224px.py
│   │   └── pre-mocov3-base-p16_224px.py
│   ├── .gitignore
│   └── README.md
├── MOODv1
│   ├── imgs
│   │   └── moodv1_performance.png
│   ├── .gitignore
│   ├── requirements.txt
│   ├── deepspeed_config.json
│   ├── dall_e
│   │   ├── __init__.py
│   │   ├── utils.py
│   │   ├── encoder.py
│   │   └── decoder.py
│   ├── datasets
│   │   └── imagenet30.py
│   ├── masking_generator.py
│   ├── engine_for_pretraining.py
│   ├── checkpoint.py
│   ├── ood_utils.py
│   ├── transforms.py
│   ├── modeling_pretrain.py
│   ├── optim_factory.py
│   ├── modeling_discrete_vae.py
│   ├── engine_for_finetuning.py
│   ├── README.md
│   ├── run_beit_pretraining.py
│   ├── datasets.py
│   ├── eval_with_features.py
│   ├── dataset_folder.py
│   └── eval_with_logits.py
└── README.md
/MOODv2/src/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/MOODv2/imgs/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/MOOD/HEAD/MOODv2/imgs/framework.png
--------------------------------------------------------------------------------
/MOODv2/imgs/distribution.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/MOOD/HEAD/MOODv2/imgs/distribution.png
--------------------------------------------------------------------------------
/MOODv2/imgs/moodv2_table.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/MOOD/HEAD/MOODv2/imgs/moodv2_table.png
--------------------------------------------------------------------------------
/MOODv2/imgs/performance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/MOOD/HEAD/MOODv2/imgs/performance.png
--------------------------------------------------------------------------------
/MOODv1/imgs/moodv1_performance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/MOOD/HEAD/MOODv1/imgs/moodv1_performance.png
--------------------------------------------------------------------------------
/MOODv2/imgs/DTD_cracked_0004.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/MOOD/HEAD/MOODv2/imgs/DTD_cracked_0004.jpg
--------------------------------------------------------------------------------
/MOODv1/.gitignore:
--------------------------------------------------------------------------------
1 | **/__pycache__/**
2 | checkpoints/
3 | run.md
4 | output
5 | pretrained/
6 | tokenizer/
7 | *.sh
8 | eval_results/
--------------------------------------------------------------------------------
/MOODv2/imgs/ImageNet_ILSVRC2012_val_00000293.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/MOOD/HEAD/MOODv2/imgs/ImageNet_ILSVRC2012_val_00000293.JPEG
--------------------------------------------------------------------------------
/MOODv1/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.7.1
2 | torchvision==0.8.2
3 | timm==0.3.2
4 | Pillow
5 | blobfile
6 | mypy
7 | numpy
8 | pytest
9 | requests
10 | einops
11 | tensorboardX
12 | deepspeed==0.4.0
13 | scipy
14 |
--------------------------------------------------------------------------------
/MOODv2/configs/vit-base-p16_384px.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='ImageClassifier',
3 | backbone=dict(
4 | type='VisionTransformer',
5 | arch='b',
6 | img_size=384,
7 | patch_size=16,
8 | drop_rate=0.1,),
9 | neck=None,
10 | head=dict(
11 | type='VisionTransformerClsHead',
12 | num_classes=1000,
13 | in_channels=768,
14 | ))
15 |
16 |
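These configs are plain mmpretrain model definitions; a minimal loading sketch, mirroring how MOODv2/src/extract_feature_vit.py consumes them (the checkpoint path below is a hypothetical placeholder):

import mmengine
from mmpretrain.apis import init_model

cfg = mmengine.Config.fromfile('configs/vit-base-p16_384px.py')
# Third argument is the GPU device id, as used in extract_feature_vit.py.
model = init_model(cfg, 'checkpoints/vit-base-p16_384.pth', 0).eval()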
--------------------------------------------------------------------------------
/MOODv2/configs/vit-base-p14_224px.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='ImageClassifier',
3 | backbone=dict(
4 | type='VisionTransformer',
5 | arch='b',
6 | img_size=224,
7 | patch_size=14,
8 | drop_rate=0.1,),
9 | neck=None,
10 | head=dict(
11 | type='VisionTransformerClsHead',
12 | num_classes=1000,
13 | in_channels=768,
14 | ))
15 |
16 |
17 |
--------------------------------------------------------------------------------
/MOODv2/configs/vit-base-p14_518px.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='ImageClassifier',
3 | backbone=dict(
4 | type='VisionTransformer',
5 | arch='b',
6 | img_size=518,
7 | patch_size=14,
8 | drop_rate=0.1,),
9 | neck=None,
10 | head=dict(
11 | type='VisionTransformerClsHead',
12 | num_classes=1000,
13 | in_channels=768,
14 | ))
15 |
16 |
17 |
--------------------------------------------------------------------------------
/MOODv2/configs/vit-base-p16_224px.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='ImageClassifier',
3 | backbone=dict(
4 | type='VisionTransformer',
5 | arch='b',
6 | img_size=224,
7 | patch_size=16,
8 | drop_rate=0.1,),
9 | neck=None,
10 | head=dict(
11 | type='VisionTransformerClsHead',
12 | num_classes=1000,
13 | in_channels=768,
14 | ))
15 |
16 |
17 |
--------------------------------------------------------------------------------
/MOODv2/configs/vit-huge-p14_448px.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='ImageClassifier',
3 | backbone=dict(
4 | type='VisionTransformer',
5 | arch='huge',
6 | img_size=448,
7 | patch_size=14,
8 | drop_path_rate=0.3, # set to 0.3
9 | out_type='avg_featmap',
10 | final_norm=False,),
11 | neck=None,
12 | head=dict(
13 | type='LinearClsHead',
14 | num_classes=1000,
15 | in_channels=1280,
16 | )
17 | )
--------------------------------------------------------------------------------
/MOODv2/configs/vit-huge-p14_224px.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='ImageClassifier',
3 | backbone=dict(
4 | type='VisionTransformer',
5 | arch='huge',
6 | img_size=224,
7 | patch_size=14,
8 | drop_path_rate=0.3, # set to 0.3
9 | out_type='avg_featmap',
10 | final_norm=False,),
11 | neck=None,
12 | head=dict(
13 | type='LinearClsHead',
14 | num_classes=1000,
15 | in_channels=1280,
16 | )
17 | )
18 |
--------------------------------------------------------------------------------
/MOODv2/configs/swin-base-w7_224px.py:
--------------------------------------------------------------------------------
1 | # model settings
2 | model = dict(
3 | type='ImageClassifier',
4 | backbone=dict(
5 | type='SwinTransformer',
6 | arch='base',
7 | img_size=224,
8 | stage_cfgs=dict(block_cfgs=dict(window_size=7)),
9 | ),
10 | neck=dict(type='GlobalAveragePooling'),
11 | head=dict(
12 | type='LinearClsHead',
13 | num_classes=1000,
14 | in_channels=1024,
15 | cal_acc=False
16 | ),
17 | )
18 |
--------------------------------------------------------------------------------
/MOODv2/configs/beit-large-p14_224px.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='ImageClassifier',
3 | backbone=dict(
4 | type='BEiTViT',
5 | arch='l',
6 | img_size=224,
7 | patch_size=14,
8 | out_type='avg_featmap',
9 | use_abs_pos_emb=False,
10 | use_rel_pos_bias=True,
11 | use_shared_rel_pos_bias=False,
12 | ),
13 | neck=None,
14 | head=dict(
15 | type='LinearClsHead',
16 | num_classes=1000,
17 | in_channels=1024,
18 | ),
19 | )
--------------------------------------------------------------------------------
/MOODv2/configs/beit-base-p16_224px.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='ImageClassifier',
3 | backbone=dict(
4 | type='BEiTViT',
5 | arch='base',
6 | img_size=224,
7 | patch_size=16,
8 | out_type='avg_featmap',
9 | use_abs_pos_emb=False,
10 | use_rel_pos_bias=True,
11 | use_shared_rel_pos_bias=False,
12 | ),
13 | neck=None,
14 | head=dict(
15 | type='LinearClsHead',
16 | num_classes=1000,
17 | in_channels=768,
18 | ),
19 | )
20 |
--------------------------------------------------------------------------------
/MOODv2/configs/beit-base-p16_384px.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='ImageClassifier',
3 | backbone=dict(
4 | type='BEiTViT',
5 | arch='base',
6 | img_size=384,
7 | patch_size=16,
8 | out_type='avg_featmap',
9 | use_abs_pos_emb=False,
10 | use_rel_pos_bias=True,
11 | use_shared_rel_pos_bias=False,
12 | ),
13 | neck=None,
14 | head=dict(
15 | type='LinearClsHead',
16 | num_classes=1000,
17 | in_channels=768,
18 | ),
19 | )
20 |
--------------------------------------------------------------------------------
/MOODv2/configs/pre-beit-base-p16_224px.py:
--------------------------------------------------------------------------------
1 | # model settings
2 | model = dict(
3 | type='BEiT',
4 | backbone=dict(
5 | type='BEiTPretrainViT',
6 | arch='base',
7 | patch_size=16,
8 | drop_path_rate=0.1,
9 | final_norm=True,
10 | out_type='raw',
11 | layer_scale_init_value=0.1,
12 | ),
13 | neck=None,
14 | head=dict(
15 | type='BEiTV1Head',
16 | embed_dims=768,
17 | num_embed=8192,
18 | loss=dict(type='CrossEntropyLoss'),
19 | )
20 | )
--------------------------------------------------------------------------------
/MOODv1/deepspeed_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_batch_size": 8,
3 | "train_micro_batch_size_per_gpu": 8,
4 | "steps_per_print": 1000,
5 | "optimizer": {
6 | "type": "Adam",
7 | "adam_w_mode": true,
8 | "params": {
9 | "lr": 0.004,
10 | "weight_decay": 0.05,
11 | "bias_correction": true,
12 | "betas": [
13 | 0.9,
14 | 0.999
15 | ],
16 | "eps": 1e-08
17 | }
18 | },
19 | "fp16": {
20 | "enabled": true,
21 | "loss_scale": 0,
22 | "initial_scale_power": 7,
23 | "loss_scale_window": 128
24 | }
25 | }
--------------------------------------------------------------------------------
/MOODv2/configs/pre-beitv2-base-p16_224px.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='BEiT',
3 | backbone=dict(
4 | type='BEiTPretrainViT',
5 | arch='base',
6 | patch_size=16,
7 | out_indices=[-4, -1],
8 | final_norm=False,
9 | out_type='raw',),
10 | neck=dict(
11 | type='BEiTV2Neck',
12 | num_layers=2,
13 | early_layers=9,
14 | backbone_arch='base',
15 | ),
16 | head=dict(
17 | type='BEiTV2Head',
18 | embed_dims=768,
19 | num_embed=8192,
20 | loss=dict(type='CrossEntropyLoss')),
21 | )
22 |
23 |
--------------------------------------------------------------------------------
/MOODv1/dall_e/__init__.py:
--------------------------------------------------------------------------------
1 | import io, requests
2 | import torch
3 | import torch.nn as nn
4 |
5 | from dall_e.encoder import Encoder
6 | from dall_e.decoder import Decoder
7 | from dall_e.utils import map_pixels, unmap_pixels
8 |
9 | def load_model(path: str, device: torch.device = None) -> nn.Module:
10 | if path.startswith('http://') or path.startswith('https://'):
11 | resp = requests.get(path)
12 | resp.raise_for_status()
13 |
14 | with io.BytesIO(resp.content) as buf:
15 | return torch.load(buf, map_location=device)
16 | else:
17 | with open(path, 'rb') as f:
18 | return torch.load(f, map_location=device)
19 |
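A minimal usage sketch for load_model above; the tokenizer paths are hypothetical placeholders (the MOODv1 .gitignore expects a local tokenizer/ directory):

import torch
from dall_e import load_model

device = torch.device('cpu')
encoder = load_model('tokenizer/encoder.pkl', device)  # hypothetical local weight files
decoder = load_model('tokenizer/decoder.pkl', device)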
--------------------------------------------------------------------------------
/MOODv2/configs/pre-mocov3-base-p16_224px.py:
--------------------------------------------------------------------------------
1 | # model settings
2 | temperature = 0.2
3 | model = dict(
4 | type='MoCoV3',
5 | base_momentum=0.01,
6 | backbone=dict(
7 | type='MoCoV3ViT',
8 | arch='base', # embed_dim = 768
9 | img_size=224,
10 | patch_size=16,
11 | stop_grad_conv1=True),
12 | neck=dict(
13 | type='NonLinearNeck',
14 | in_channels=768,
15 | hid_channels=4096,
16 | out_channels=256,
17 | num_layers=3,
18 | with_bias=False,
19 | with_last_bn=True,
20 | with_last_bn_affine=False,
21 | with_last_bias=False,
22 | with_avg_pool=False),
23 | head=dict(
24 | type='MoCoV3Head',
25 | predictor=dict(
26 | type='NonLinearNeck',
27 | in_channels=256,
28 | hid_channels=4096,
29 | out_channels=256,
30 | num_layers=2,
31 | with_bias=False,
32 | with_last_bn=True,
33 | with_last_bn_affine=False,
34 | with_last_bias=False,
35 | with_avg_pool=False),
36 | loss=dict(type='CrossEntropyLoss', loss_weight=2 * temperature),
37 | temperature=temperature,))
38 |
39 |
--------------------------------------------------------------------------------
/MOODv2/src/list_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 |
4 | import torch.utils.data as data
5 | from PIL import Image
6 |
7 |
8 | def default_loader(path):
9 | return Image.open(path).convert('RGB')
10 |
11 |
12 | def default_flist_reader(flist):
13 |     """flist format: one `impath label` pair per line (the label is optional and defaults to 0)."""
14 | imlist = []
15 | with open(flist, 'r') as rf:
16 | for line in rf.readlines():
17 | data = line.strip().rsplit(maxsplit=1)
18 | if len(data) == 2:
19 | impath, imlabel = data
20 | else:
21 | impath, imlabel = data[0], 0
22 | imlist.append((impath, int(imlabel)))
23 |
24 | return imlist
25 |
26 |
27 | class ImageFilelist(data.Dataset):
28 |
29 | def __init__(self,
30 | root,
31 | flist,
32 | transform=None,
33 | target_transform=None,
34 | flist_reader=default_flist_reader,
35 | loader=default_loader):
36 | self.root = root
37 | self.imlist = flist_reader(flist)
38 | self.transform = transform
39 | self.target_transform = target_transform
40 | self.loader = loader
41 |
42 | def __getitem__(self, index):
43 | impath, target = self.imlist[index]
44 | img = self.loader(os.path.join(self.root, impath))
45 | if self.transform is not None:
46 | img = self.transform(img)
47 | if self.target_transform is not None:
48 | target = self.target_transform(target)
49 |
50 | return img, target
51 |
52 | def __len__(self):
53 | return len(self.imlist)
54 |
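A brief usage sketch for ImageFilelist; the root and list-file paths are hypothetical, and each line of the list file is an image path optionally followed by an integer label, matching default_flist_reader above:

import torchvision.transforms as T
from torch.utils.data import DataLoader

from list_dataset import ImageFilelist

transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
dataset = ImageFilelist(root='data/imagenet', flist='data/val_list.txt', transform=transform)
loader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=4)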
--------------------------------------------------------------------------------
/MOODv1/dall_e/utils.py:
--------------------------------------------------------------------------------
1 | import attr
2 | import math
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | logit_laplace_eps: float = 0.1
9 |
10 | @attr.s(eq=False)
11 | class Conv2d(nn.Module):
12 | n_in: int = attr.ib(validator=lambda i, a, x: x >= 1)
13 | n_out: int = attr.ib(validator=lambda i, a, x: x >= 1)
14 | kw: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 2 == 1)
15 |
16 | use_float16: bool = attr.ib(default=True)
17 | device: torch.device = attr.ib(default=torch.device('cpu'))
18 | requires_grad: bool = attr.ib(default=False)
19 |
20 | def __attrs_post_init__(self) -> None:
21 | super().__init__()
22 |
23 | w = torch.empty((self.n_out, self.n_in, self.kw, self.kw), dtype=torch.float32,
24 | device=self.device, requires_grad=self.requires_grad)
25 | w.normal_(std=1 / math.sqrt(self.n_in * self.kw ** 2))
26 |
27 | b = torch.zeros((self.n_out,), dtype=torch.float32, device=self.device,
28 | requires_grad=self.requires_grad)
29 | self.w, self.b = nn.Parameter(w), nn.Parameter(b)
30 |
31 | def forward(self, x: torch.Tensor) -> torch.Tensor:
32 | if self.use_float16 and 'cuda' in self.w.device.type:
33 | if x.dtype != torch.float16:
34 | x = x.half()
35 |
36 | w, b = self.w.half(), self.b.half()
37 | else:
38 | if x.dtype != torch.float32:
39 | x = x.float()
40 |
41 | w, b = self.w, self.b
42 |
43 | return F.conv2d(x, w, b, padding=(self.kw - 1) // 2)
44 |
45 | def map_pixels(x: torch.Tensor) -> torch.Tensor:
46 | if x.dtype != torch.float:
47 | raise ValueError('expected input to have type float')
48 |
49 | return (1 - 2 * logit_laplace_eps) * x + logit_laplace_eps
50 |
51 | def unmap_pixels(x: torch.Tensor) -> torch.Tensor:
52 | if len(x.shape) != 4:
53 | raise ValueError('expected input to be 4d')
54 | if x.dtype != torch.float:
55 | raise ValueError('expected input to have type float')
56 |
57 | return torch.clamp((x - logit_laplace_eps) / (1 - 2 * logit_laplace_eps), 0, 1)
58 |
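A quick round-trip check for the logit-Laplace pixel mapping defined above (the tensor shape is illustrative): map_pixels squeezes [0, 1] pixels into [eps, 1 - eps], and unmap_pixels inverts it.

import torch
from dall_e.utils import map_pixels, unmap_pixels

x = torch.rand(1, 3, 256, 256)           # float32 pixels in [0, 1]
y = unmap_pixels(map_pixels(x))          # recovers x up to float error
assert torch.allclose(x, y, atol=1e-5)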
--------------------------------------------------------------------------------
/MOODv1/datasets/imagenet30.py:
--------------------------------------------------------------------------------
1 | '''
2 | To Do:
3 | make symlink for ImageNet30
4 | '''
5 |
6 | import os, errno
7 |
8 | def symlink_force(target, link_name):
9 | print('{}->{}'.format(target, link_name))
10 | try:
11 | os.symlink(target, link_name)
12 |     except FileExistsError:
13 | os.remove(link_name)
14 | os.symlink(target, link_name)
15 |
16 |
17 | def mkdir(directory):
18 | if not os.path.exists(directory):
19 | os.makedirs(directory)
20 |
21 | pairs = [
22 | ('acorn', 'n12267677'),
23 | ('airliner', 'n02690373'),
24 | ('ambulance', 'n02701002'),
25 | ('american_alligator', 'n01698640'),
26 | ('banjo', 'n02787622'),
27 | ('barn', 'n02793495'),
28 | ('bikini', 'n02837789'),
29 | ('digital_clock', 'n03196217'),
30 | ('dragonfly', 'n02268443'),
31 | ('dumbbell', 'n03255030'),
32 | ('forklift', 'n03384352'),
33 | ('goblet', 'n03443371'),
34 | ('grand_piano', 'n03452741'),
35 | ('hotdog', 'n07697537'),
36 | ('hourglass', 'n03544143'),
37 | ('manhole_cover', 'n03717622'),
38 | ('mosque', 'n03788195'),
39 | ('nail', 'n03804744'),
40 | ('parking_meter', 'n03891332'),
41 | ('pillow', 'n03938244'),
42 | ('revolver', 'n04086273'),
43 | ('rotary_dial_telephone', 'n03187595'),
44 | ('schooner', 'n04147183'),
45 | ('snowmobile', 'n04252077'),
46 | ('soccer_ball', 'n04254680'),
47 | ('stingray', 'n01498041'),
48 | ('strawberry', 'n07745940'),
49 | ('tank', 'n04389033'),
50 | ('toaster', 'n04442312'),
51 | ('volcano', 'n09472597')
52 | ]
53 |
54 | # set source and target paths here
55 | ori_imagenet1k_path = ''
56 | target_imagenet30_path = ''
57 |
58 | train_path = os.path.join(target_imagenet30_path, 'train')
59 | test_path = os.path.join(target_imagenet30_path, 'test')
60 | mkdir(train_path)
61 | mkdir(test_path)
62 |
63 | for pair in pairs:
64 | symlink_force(os.path.join(ori_imagenet1k_path, 'train', pair[1]), os.path.join(train_path, pair[0]))
65 | symlink_force(os.path.join(ori_imagenet1k_path, 'val', pair[1]), os.path.join(test_path, pair[0]))
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MOOD
2 |
3 | • 🤗 Model
4 | • 🐱 Code
5 | • 📃 MOODv1
6 | • 📃 MOODv2
7 |
8 |
9 |
10 |
11 |
12 |
13 | ## MOODv1: Rethinking Out-of-Distribution Detection: Masked Image Modeling is All You Need (CVPR2023)
14 | The core of out-of-distribution (OOD) detection is to learn an in-distribution (ID) representation that is distinguishable from OOD samples. Previous work applied recognition-based methods to learn ID features, which tend to learn shortcuts instead of comprehensive representations. In this work, we find, surprisingly, that simply using reconstruction-based methods can significantly boost OOD detection performance. We explore the main contributors to OOD detection and find that reconstruction-based pretext tasks provide a generally applicable and effective prior that helps the model learn the intrinsic data distribution of the ID dataset. Specifically, we adopt Masked Image Modeling as the pretext task for our OOD detection framework (MOOD). Without bells and whistles, MOOD outperforms the previous SOTA by 5.7% on one-class OOD detection, 3.0% on multi-class OOD detection, and 2.1% on near-distribution OOD detection. It even beats 10-shot-per-class outlier-exposure OOD detection, even though we include no OOD samples in our detection.
15 |
16 |
17 |
18 |
19 | ## MOODv2: Masked Image Modeling for Out-of-Distribution Detection (TPAMI2024)
20 | The crux of effective out-of-distribution (OOD) detection lies in acquiring a robust in-distribution (ID) representation that is distinct from OOD samples. Previous methods predominantly relied on recognition-based techniques for this purpose, which often results in shortcut learning and a lack of comprehensive representations. In this study, we conduct a comprehensive analysis of distinct pretraining tasks combined with various OOD score functions. The results show that feature representations pretrained through reconstruction yield a notable improvement and narrow the performance gap among score functions, suggesting that even simple score functions can rival complex ones when built on reconstruction-based pretext tasks. Because reconstruction-based pretext tasks adapt well to a wide range of score functions, they hold promising potential for further extension. Our OOD detection framework, MOODv2, employs the masked image modeling pretext task. Without bells and whistles, MOODv2 improves AUROC by 14.30% to 95.68% on ImageNet and achieves 99.98% on CIFAR-10.
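For intuition, the simple score functions referred to above operate directly on frozen backbone features. Below is a minimal sketch of a Mahalanobis-style ID score with placeholder arrays (not the official evaluation pipeline; in practice the features come from a pretrained backbone, e.g. via `MOODv2/src/extract_feature_vit.py`):

```python
import numpy as np

rng = np.random.default_rng(0)
id_train = rng.normal(size=(1000, 768))   # ID training features (placeholder)
test = rng.normal(size=(200, 768))        # features of samples to score (placeholder)

mean = id_train.mean(axis=0, keepdims=True)
prec = np.linalg.pinv(np.cov(id_train.T, bias=True))   # precision matrix of ID features

# Smaller Mahalanobis distance -> more likely in-distribution.
scores = np.sum((test - mean) * (prec @ (test - mean).T).T, axis=-1)
```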
21 |
22 |
23 |
24 |
--------------------------------------------------------------------------------
/MOODv2/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | scripts
3 | outputs
4 | models
5 | data
6 | pretrain
7 | tools
8 |
9 | # Byte-compiled / optimized / DLL files
10 | __pycache__/
11 | *.py[cod]
12 | *$py.class
13 | **/*.pyc
14 | core.*
15 |
16 | # C extensions
17 | *.so
18 |
19 | # Distribution / packaging
20 | .Python
21 | build/
22 | develop-eggs/
23 | dist/
24 | downloads/
25 | eggs/
26 | .eggs/
27 | lib/
28 | !**/ilayout/lib/
29 | lib64/
30 | parts/
31 | sdist/
32 | var/
33 | wheels/
34 | *.egg-info/
35 | .installed.cfg
36 | *.egg
37 | MANIFEST
38 |
39 | # PyInstaller
40 | # Usually these files are written by a python script from a template
41 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
42 | *.manifest
43 | *.spec
44 |
45 | # Installer logs
46 | pip-log.txt
47 | pip-delete-this-directory.txt
48 |
49 | # Unit test / coverage reports
50 | htmlcov/
51 | .tox/
52 | .coverage
53 | .coverage.*
54 | .cache
55 | nosetests.xml
56 | coverage.xml
57 | *.cover
58 | .hypothesis/
59 | .pytest_cache/
60 |
61 | # Translations
62 | *.mo
63 | *.pot
64 |
65 | # Django stuff:
66 | *.log
67 | local_settings.py
68 | db.sqlite3
69 |
70 | # Flask stuff:
71 | instance/
72 | .webassets-cache
73 |
74 | # Scrapy stuff:
75 | .scrapy
76 |
77 | # Sphinx documentation
78 | docs/_build/
79 |
80 | # PyBuilder
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # celery beat schedule file
90 | celerybeat-schedule
91 |
92 | # SageMath parsed files
93 | *.sage.py
94 |
95 | # Environments
96 | .env
97 | .venv
98 | env/
99 | venv/
100 | ENV/
101 | env.bak/
102 | venv.bak/
103 |
104 | # Spyder project settings
105 | .spyderproject
106 | .spyproject
107 |
108 | # Rope project settings
109 | .ropeproject
110 |
111 | # mkdocs documentation
112 | /site
113 |
114 | # mypy
115 | .mypy_cache/
116 |
117 | # custom
118 | /models
119 | /data
120 | .vscode
121 | .idea
122 | *.pkl
123 | *.pkl.json
124 | *.log.json
125 | *.npy
126 | work_dirs/
127 | work_dirs
128 | work_dir
129 | workdirs/
130 | results
131 | checkpoints
132 | logs
133 | log
134 | projects/*/data/
135 | projects/*/data
136 |
137 | # Pytorch
138 | *.pth
139 |
140 | *.DS_Store
141 |
142 | *.tar
143 | *.TTF
144 |
145 | # web demo
146 | projects/ocrserver/ocrserver/static/results
147 | projects/ocrserver/ocrserver/static/uploads
148 | projects/ocrserver/inference_results
149 | projects/ocrserver/logs
150 | projects/autotrain/*/data/
151 | projects/autotrain/*/data
152 | tmp/
153 | debug/
154 |
155 | # medius shell
156 | run_medius*.sh
157 | run_visualize*.sh
158 | projects/videoocr/run_sdk*.sh
159 | projects/videoocr/run_compare*.sh
160 |
161 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
162 |
163 | # dependencies
164 | node_modules
165 | /.pnp
166 | .pnp.js
167 |
168 | # testing
169 | /coverage
170 |
171 | # production
172 | /dist
173 |
174 | # misc
175 | .DS_Store
176 | .env.local
177 | .env.development.local
178 | .env.test.local
179 | .env.production.local
180 |
181 | npm-debug.log*
182 | yarn-debug.log*
183 | yarn-error.log*
184 |
185 | **/public/fonts/
186 | **/temp/
187 | **/tmp/
188 | projects/ilayout/public/img/texture
189 |
190 | outputs/
191 | checkpoints/
--------------------------------------------------------------------------------
/MOODv1/masking_generator.py:
--------------------------------------------------------------------------------
1 | """
2 | Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0
3 | Copyright Zhun Zhong & Liang Zheng
4 |
5 | Hacked together by / Copyright 2020 Ross Wightman
6 |
7 | Modified by Hangbo Bao, for generating the masked position for visual image transformer
8 | """
9 | # --------------------------------------------------------
10 | # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
11 | # Github source: https://github.com/microsoft/unilm/tree/master/beit
12 | # Copyright (c) 2021 Microsoft
13 | # Licensed under The MIT License [see LICENSE for details]
14 | # By Hangbo Bao
15 | # Based on timm, DINO and DeiT code bases
16 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm
17 | # Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0
18 | # Copyright Zhun Zhong & Liang Zheng
19 | #
20 | # Hacked together by / Copyright 2020 Ross Wightman
21 | #
22 | # Modified by Hangbo Bao, for generating the masked position for visual image transformer
23 | # --------------------------------------------------------'
24 | import random
25 | import math
26 | import numpy as np
27 |
28 |
29 | class MaskingGenerator:
30 | def __init__(
31 | self, input_size, num_masking_patches, min_num_patches=4, max_num_patches=None,
32 | min_aspect=0.3, max_aspect=None):
33 | if not isinstance(input_size, tuple):
34 | input_size = (input_size, ) * 2
35 | self.height, self.width = input_size
36 |
37 | self.num_patches = self.height * self.width
38 | self.num_masking_patches = num_masking_patches
39 |
40 | self.min_num_patches = min_num_patches
41 | self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches
42 |
43 | max_aspect = max_aspect or 1 / min_aspect
44 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
45 |
46 | def __repr__(self):
47 | repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
48 | self.height, self.width, self.min_num_patches, self.max_num_patches,
49 | self.num_masking_patches, self.log_aspect_ratio[0], self.log_aspect_ratio[1])
50 | return repr_str
51 |
52 | def get_shape(self):
53 | return self.height, self.width
54 |
55 | def _mask(self, mask, max_mask_patches):
56 | delta = 0
57 | for attempt in range(10):
58 | target_area = random.uniform(self.min_num_patches, max_mask_patches)
59 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
60 | h = int(round(math.sqrt(target_area * aspect_ratio)))
61 | w = int(round(math.sqrt(target_area / aspect_ratio)))
62 | if w < self.width and h < self.height:
63 | top = random.randint(0, self.height - h)
64 | left = random.randint(0, self.width - w)
65 |
66 | num_masked = mask[top: top + h, left: left + w].sum()
67 | # Overlap
68 | if 0 < h * w - num_masked <= max_mask_patches:
69 | for i in range(top, top + h):
70 | for j in range(left, left + w):
71 | if mask[i, j] == 0:
72 | mask[i, j] = 1
73 | delta += 1
74 |
75 | if delta > 0:
76 | break
77 | return delta
78 |
79 | def __call__(self):
80 |         mask = np.zeros(shape=self.get_shape(), dtype=int)
81 | mask_count = 0
82 | while mask_count < self.num_masking_patches:
83 | max_mask_patches = self.num_masking_patches - mask_count
84 | max_mask_patches = min(max_mask_patches, self.max_num_patches)
85 |
86 | delta = self._mask(mask, max_mask_patches)
87 | if delta == 0:
88 | break
89 | else:
90 | mask_count += delta
91 |
92 | return mask
93 |
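A minimal usage sketch for MaskingGenerator (the values are illustrative, corresponding to a 224px image with 16px patches and a 75-patch mask budget):

from masking_generator import MaskingGenerator

generator = MaskingGenerator(input_size=14, num_masking_patches=75)
mask = generator()     # (14, 14) array; 1 marks a masked patch
print(mask.sum())      # number of masked patches (at most 75)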
--------------------------------------------------------------------------------
/MOODv1/dall_e/encoder.py:
--------------------------------------------------------------------------------
1 | import attr
2 | import numpy as np
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from collections import OrderedDict
9 | from functools import partial
10 | from dall_e.utils import Conv2d
11 |
12 | @attr.s(eq=False, repr=False)
13 | class EncoderBlock(nn.Module):
14 | n_in: int = attr.ib(validator=lambda i, a, x: x >= 1)
15 | n_out: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 4 ==0)
16 | n_layers: int = attr.ib(validator=lambda i, a, x: x >= 1)
17 |
18 | device: torch.device = attr.ib(default=None)
19 | requires_grad: bool = attr.ib(default=False)
20 |
21 | def __attrs_post_init__(self) -> None:
22 | super().__init__()
23 | self.n_hid = self.n_out // 4
24 | self.post_gain = 1 / (self.n_layers ** 2)
25 |
26 | make_conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad)
27 | self.id_path = make_conv(self.n_in, self.n_out, 1) if self.n_in != self.n_out else nn.Identity()
28 | self.res_path = nn.Sequential(OrderedDict([
29 | ('relu_1', nn.ReLU()),
30 | ('conv_1', make_conv(self.n_in, self.n_hid, 3)),
31 | ('relu_2', nn.ReLU()),
32 | ('conv_2', make_conv(self.n_hid, self.n_hid, 3)),
33 | ('relu_3', nn.ReLU()),
34 | ('conv_3', make_conv(self.n_hid, self.n_hid, 3)),
35 | ('relu_4', nn.ReLU()),
36 | ('conv_4', make_conv(self.n_hid, self.n_out, 1)),]))
37 |
38 | def forward(self, x: torch.Tensor) -> torch.Tensor:
39 | return self.id_path(x) + self.post_gain * self.res_path(x)
40 |
41 | @attr.s(eq=False, repr=False)
42 | class Encoder(nn.Module):
43 | group_count: int = 4
44 | n_hid: int = attr.ib(default=256, validator=lambda i, a, x: x >= 64)
45 | n_blk_per_group: int = attr.ib(default=2, validator=lambda i, a, x: x >= 1)
46 | input_channels: int = attr.ib(default=3, validator=lambda i, a, x: x >= 1)
47 | vocab_size: int = attr.ib(default=8192, validator=lambda i, a, x: x >= 512)
48 |
49 | device: torch.device = attr.ib(default=torch.device('cpu'))
50 | requires_grad: bool = attr.ib(default=False)
51 | use_mixed_precision: bool = attr.ib(default=True)
52 |
53 | def __attrs_post_init__(self) -> None:
54 | super().__init__()
55 |
56 | blk_range = range(self.n_blk_per_group)
57 | n_layers = self.group_count * self.n_blk_per_group
58 | make_conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad)
59 | make_blk = partial(EncoderBlock, n_layers=n_layers, device=self.device,
60 | requires_grad=self.requires_grad)
61 |
62 | self.blocks = nn.Sequential(OrderedDict([
63 | ('input', make_conv(self.input_channels, 1 * self.n_hid, 7)),
64 | ('group_1', nn.Sequential(OrderedDict([
65 | *[(f'block_{i + 1}', make_blk(1 * self.n_hid, 1 * self.n_hid)) for i in blk_range],
66 | ('pool', nn.MaxPool2d(kernel_size=2)),
67 | ]))),
68 | ('group_2', nn.Sequential(OrderedDict([
69 | *[(f'block_{i + 1}', make_blk(1 * self.n_hid if i == 0 else 2 * self.n_hid, 2 * self.n_hid)) for i in blk_range],
70 | ('pool', nn.MaxPool2d(kernel_size=2)),
71 | ]))),
72 | ('group_3', nn.Sequential(OrderedDict([
73 | *[(f'block_{i + 1}', make_blk(2 * self.n_hid if i == 0 else 4 * self.n_hid, 4 * self.n_hid)) for i in blk_range],
74 | ('pool', nn.MaxPool2d(kernel_size=2)),
75 | ]))),
76 | ('group_4', nn.Sequential(OrderedDict([
77 | *[(f'block_{i + 1}', make_blk(4 * self.n_hid if i == 0 else 8 * self.n_hid, 8 * self.n_hid)) for i in blk_range],
78 | ]))),
79 | ('output', nn.Sequential(OrderedDict([
80 | ('relu', nn.ReLU()),
81 | ('conv', make_conv(8 * self.n_hid, self.vocab_size, 1, use_float16=False)),
82 | ]))),
83 | ]))
84 |
85 | def forward(self, x: torch.Tensor) -> torch.Tensor:
86 | if len(x.shape) != 4:
87 | raise ValueError(f'input shape {x.shape} is not 4d')
88 | if x.shape[1] != self.input_channels:
89 | raise ValueError(f'input has {x.shape[1]} channels but model built for {self.input_channels}')
90 | if x.dtype != torch.float32:
91 | raise ValueError('input must have dtype torch.float32')
92 |
93 | return self.blocks(x)
94 |
--------------------------------------------------------------------------------
/MOODv1/dall_e/decoder.py:
--------------------------------------------------------------------------------
1 | import attr
2 | import numpy as np
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from collections import OrderedDict
9 | from functools import partial
10 | from dall_e.utils import Conv2d
11 |
12 | @attr.s(eq=False, repr=False)
13 | class DecoderBlock(nn.Module):
14 | n_in: int = attr.ib(validator=lambda i, a, x: x >= 1)
15 | n_out: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 4 ==0)
16 | n_layers: int = attr.ib(validator=lambda i, a, x: x >= 1)
17 |
18 | device: torch.device = attr.ib(default=None)
19 | requires_grad: bool = attr.ib(default=False)
20 |
21 | def __attrs_post_init__(self) -> None:
22 | super().__init__()
23 | self.n_hid = self.n_out // 4
24 | self.post_gain = 1 / (self.n_layers ** 2)
25 |
26 | make_conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad)
27 | self.id_path = make_conv(self.n_in, self.n_out, 1) if self.n_in != self.n_out else nn.Identity()
28 | self.res_path = nn.Sequential(OrderedDict([
29 | ('relu_1', nn.ReLU()),
30 | ('conv_1', make_conv(self.n_in, self.n_hid, 1)),
31 | ('relu_2', nn.ReLU()),
32 | ('conv_2', make_conv(self.n_hid, self.n_hid, 3)),
33 | ('relu_3', nn.ReLU()),
34 | ('conv_3', make_conv(self.n_hid, self.n_hid, 3)),
35 | ('relu_4', nn.ReLU()),
36 | ('conv_4', make_conv(self.n_hid, self.n_out, 3)),]))
37 |
38 | def forward(self, x: torch.Tensor) -> torch.Tensor:
39 | return self.id_path(x) + self.post_gain * self.res_path(x)
40 |
41 | @attr.s(eq=False, repr=False)
42 | class Decoder(nn.Module):
43 | group_count: int = 4
44 | n_init: int = attr.ib(default=128, validator=lambda i, a, x: x >= 8)
45 | n_hid: int = attr.ib(default=256, validator=lambda i, a, x: x >= 64)
46 | n_blk_per_group: int = attr.ib(default=2, validator=lambda i, a, x: x >= 1)
47 | output_channels: int = attr.ib(default=3, validator=lambda i, a, x: x >= 1)
48 | vocab_size: int = attr.ib(default=8192, validator=lambda i, a, x: x >= 512)
49 |
50 | device: torch.device = attr.ib(default=torch.device('cpu'))
51 | requires_grad: bool = attr.ib(default=False)
52 | use_mixed_precision: bool = attr.ib(default=True)
53 |
54 | def __attrs_post_init__(self) -> None:
55 | super().__init__()
56 |
57 | blk_range = range(self.n_blk_per_group)
58 | n_layers = self.group_count * self.n_blk_per_group
59 | make_conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad)
60 | make_blk = partial(DecoderBlock, n_layers=n_layers, device=self.device,
61 | requires_grad=self.requires_grad)
62 |
63 | self.blocks = nn.Sequential(OrderedDict([
64 | ('input', make_conv(self.vocab_size, self.n_init, 1, use_float16=False)),
65 | ('group_1', nn.Sequential(OrderedDict([
66 | *[(f'block_{i + 1}', make_blk(self.n_init if i == 0 else 8 * self.n_hid, 8 * self.n_hid)) for i in blk_range],
67 | ('upsample', nn.Upsample(scale_factor=2, mode='nearest')),
68 | ]))),
69 | ('group_2', nn.Sequential(OrderedDict([
70 | *[(f'block_{i + 1}', make_blk(8 * self.n_hid if i == 0 else 4 * self.n_hid, 4 * self.n_hid)) for i in blk_range],
71 | ('upsample', nn.Upsample(scale_factor=2, mode='nearest')),
72 | ]))),
73 | ('group_3', nn.Sequential(OrderedDict([
74 | *[(f'block_{i + 1}', make_blk(4 * self.n_hid if i == 0 else 2 * self.n_hid, 2 * self.n_hid)) for i in blk_range],
75 | ('upsample', nn.Upsample(scale_factor=2, mode='nearest')),
76 | ]))),
77 | ('group_4', nn.Sequential(OrderedDict([
78 | *[(f'block_{i + 1}', make_blk(2 * self.n_hid if i == 0 else 1 * self.n_hid, 1 * self.n_hid)) for i in blk_range],
79 | ]))),
80 | ('output', nn.Sequential(OrderedDict([
81 | ('relu', nn.ReLU()),
82 | ('conv', make_conv(1 * self.n_hid, 2 * self.output_channels, 1)),
83 | ]))),
84 | ]))
85 |
86 | def forward(self, x: torch.Tensor) -> torch.Tensor:
87 | if len(x.shape) != 4:
88 | raise ValueError(f'input shape {x.shape} is not 4d')
89 | if x.shape[1] != self.vocab_size:
90 | raise ValueError(f'input has {x.shape[1]} channels but model built for {self.vocab_size}')
91 | if x.dtype != torch.float32:
92 | raise ValueError('input must have dtype torch.float32')
93 |
94 | return self.blocks(x)
95 |
--------------------------------------------------------------------------------
/MOODv2/src/extract_feature_vit.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import argparse
3 | import pickle
4 | from os.path import dirname
5 | import os
6 | import mmengine
7 | import numpy as np
8 | import torch
9 | import torchvision as tv
10 | from tqdm import tqdm
11 |
12 | from list_dataset import ImageFilelist
13 | from mmpretrain.apis import init_model
14 |
15 |
16 | def parse_args():
17 |     parser = argparse.ArgumentParser(description='Extract features with a pretrained backbone')
18 | parser.add_argument('data_root', help='Path to data')
19 | parser.add_argument('--out_file', help='Path to output file')
20 | parser.add_argument(
21 |         '--cfg', default='configs/vit-base-p16_384px.py', help='Path to config')
22 | parser.add_argument(
23 | '--checkpoint',
24 | default='checkpoints/vit-base-p16_in21k-pre-3rdparty_ft-64xb64_in1k-'
25 | '384_20210928-98e8652b.pth',
26 | help='Path to checkpoint')
27 | parser.add_argument('--img_list', help='Path to image list')
28 | parser.add_argument('--batch', type=int, default=256, help='batch size')
29 | parser.add_argument('--workers', type=int, default=4, help='num of workers')
30 | parser.add_argument('--fc_save_path', default=None, help='Path to save fc')
31 | return parser.parse_args()
32 |
33 |
34 | def main():
35 | args = parse_args()
36 |
37 | torch.backends.cudnn.benchmark = True
38 |
39 | cfg = mmengine.Config.fromfile(args.cfg)
40 | model = init_model(cfg, args.checkpoint, 0).cuda().eval()
41 |
42 | if args.fc_save_path is not None:
43 | if os.path.exists(args.fc_save_path):
44 | print(f'{args.fc_save_path} exists.')
45 | return
46 | mmengine.mkdir_or_exist(dirname(args.fc_save_path))
47 | if cfg.model.head.type == 'VisionTransformerClsHead':
48 | fc = model.head.layers.head
49 | elif cfg.model.head.type == 'LinearClsHead':
50 | fc = model.head.fc
51 | elif cfg.model.head.type in ['BEiTV1Head', 'BEiTV2Head']:
52 | fc = model.head.cls_head
53 | elif cfg.model.head.type in ['MoCoV3Head']:
54 |             print(f'{cfg.model.head.type} uses a NonLinearNeck predictor, which cannot be represented by a single weight and bias')
55 |             raise NotImplementedError(cfg.model.head.type)
56 | else:
57 | print(cfg.model.head.type)
58 | print(model.head)
59 | import pdb;pdb.set_trace()
60 |             raise NotImplementedError(cfg.model.head.type)
61 | w = fc.weight.cpu().detach().numpy()
62 | b = fc.bias.cpu().detach().numpy()
63 | with open(args.fc_save_path, 'wb') as f:
64 | pickle.dump([w, b], f)
65 | return
66 |
67 | # if os.path.exists(out_file):
68 | # print(f'{out_file} exists.')
69 | # return
70 |
71 | if hasattr(cfg.model.backbone, 'img_size'):
72 | img_size = cfg.model.backbone.img_size
73 | else:
74 | img_size = 224
75 |
76 | transform = tv.transforms.Compose([
77 | tv.transforms.Resize((img_size, img_size)),
78 | tv.transforms.ToTensor(),
79 | tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
80 | ])
81 |
82 | if args.img_list not in [None, 'None']:
83 | dataset = ImageFilelist(args.data_root, args.img_list, transform)
84 | else:
85 | dataset = tv.datasets.ImageFolder(args.data_root, transform)
86 |     print(f'length of dataset: {len(dataset)}')
87 |
88 | dataloader = torch.utils.data.DataLoader(
89 | dataset,
90 | batch_size=args.batch,
91 | shuffle=False,
92 | num_workers=args.workers,
93 | pin_memory=True,
94 | drop_last=False)
95 |
96 | features = []
97 | with torch.no_grad():
98 | for i, (x, _) in enumerate(tqdm(dataloader)):
99 | x = x.cuda()
100 | if cfg.model.backbone.type == 'BEiTPretrainViT':
101 | # (B, L, C) -> (B, C)
102 | feat_batch = model.backbone(
103 | x, mask=None)[0].mean(1)
104 | elif cfg.model.backbone.type == 'SwinTransformer':
105 | # (B, C, H, W) -> (B, C)
106 | feat_batch = model.backbone(x)[0]
107 | B, C, H, W = feat_batch.shape
108 | feat_batch = feat_batch.reshape(B, C, -1).mean(-1)
109 | else:
110 | # (B, C)
111 | feat_batch = model.backbone(x)[0]
112 | assert len(feat_batch.shape) == 2
113 | features.append(feat_batch.cpu().numpy())
114 |
115 | features = np.concatenate(features, axis=0)
116 | print(f'features: {features.shape}')
117 | mmengine.mkdir_or_exist(dirname(args.out_file))
118 | with open(args.out_file, 'wb') as f:
119 | pickle.dump(features, f)
120 |
121 | if __name__ == '__main__':
122 | main()
123 |
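Example invocations, with hypothetical paths; the flags match parse_args above (passing --fc_save_path dumps the classifier weight and bias instead of extracting features):

# Extract backbone features for a dataset:
#   python src/extract_feature_vit.py data/imagenet/val \
#       --cfg configs/vit-base-p16_384px.py \
#       --checkpoint checkpoints/vit-base-p16_384.pth \
#       --out_file outputs/imagenet_val_features.pkl
#
# Save the classification head's weight/bias:
#   python src/extract_feature_vit.py data/imagenet/val \
#       --cfg configs/vit-base-p16_384px.py \
#       --checkpoint checkpoints/vit-base-p16_384.pth \
#       --fc_save_path outputs/fc.pkl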
--------------------------------------------------------------------------------
/MOODv1/engine_for_pretraining.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
3 | # Github source: https://github.com/microsoft/unilm/tree/master/beit
4 | # Copyright (c) 2021 Microsoft
5 | # Licensed under The MIT License [see LICENSE for details]
6 | # By Hangbo Bao
7 | # Based on timm, DINO and DeiT code bases
8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm
9 | # https://github.com/facebookresearch/deit/
10 | # https://github.com/facebookresearch/dino
11 | # --------------------------------------------------------'
12 | import math
13 | import sys
14 | from typing import Iterable
15 |
16 | import torch
17 | import torch.nn as nn
18 |
19 | import utils
20 | print_freq = 1000
21 |
22 | def train_one_epoch(model: torch.nn.Module, d_vae: torch.nn.Module,
23 | data_loader: Iterable, optimizer: torch.optim.Optimizer,
24 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
25 | log_writer=None, lr_scheduler=None, start_steps=None,
26 | lr_schedule_values=None, wd_schedule_values=None):
27 | model.train()
28 | metric_logger = utils.MetricLogger(delimiter=" ")
29 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
30 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
31 | header = 'Epoch: [{}]'.format(epoch)
32 |
33 | for step, (batch, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
34 | # assign learning rate & weight decay for each step
35 | it = start_steps + step # global training iteration
36 | if lr_schedule_values is not None or wd_schedule_values is not None:
37 | for i, param_group in enumerate(optimizer.param_groups):
38 | if lr_schedule_values is not None:
39 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
40 | if wd_schedule_values is not None and param_group["weight_decay"] > 0:
41 | param_group["weight_decay"] = wd_schedule_values[it]
42 |
43 | samples, images, bool_masked_pos = batch
44 | images = images.to(device, non_blocking=True)
45 | samples = samples.to(device, non_blocking=True)
46 | bool_masked_pos = bool_masked_pos.to(device, non_blocking=True)
47 |
48 | with torch.no_grad():
49 | input_ids = d_vae.get_codebook_indices(images).flatten(1)
50 | bool_masked_pos = bool_masked_pos.flatten(1).to(torch.bool)
51 | labels = input_ids[bool_masked_pos]
52 |
53 | with torch.cuda.amp.autocast():
54 | outputs = model(samples, bool_masked_pos=bool_masked_pos, return_all_tokens=False) # torch.Size([74, 8192])
55 | loss = nn.CrossEntropyLoss()(input=outputs, target=labels) # torch.Size([74])
56 |
57 | loss_value = loss.item()
58 |
59 | if not math.isfinite(loss_value):
60 | print("Loss is {}, stopping training".format(loss_value))
61 | sys.exit(1)
62 |
63 | optimizer.zero_grad()
64 | # this attribute is added by timm on one optimizer (adahessian)
65 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
66 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
67 | parameters=model.parameters(), create_graph=is_second_order)
68 | loss_scale_value = loss_scaler.state_dict()["scale"]
69 |
70 | torch.cuda.synchronize()
71 |
72 | mlm_acc = (outputs.max(-1)[1] == labels).float().mean().item()
73 |
74 | metric_logger.update(mlm_acc=mlm_acc)
75 | if log_writer is not None:
76 | log_writer.update(mlm_acc=mlm_acc, head="loss")
77 |
78 | metric_logger.update(loss=loss_value)
79 | metric_logger.update(loss_scale=loss_scale_value)
80 | min_lr = 10.
81 | max_lr = 0.
82 | for group in optimizer.param_groups:
83 | min_lr = min(min_lr, group["lr"])
84 | max_lr = max(max_lr, group["lr"])
85 |
86 | metric_logger.update(lr=max_lr)
87 | metric_logger.update(min_lr=min_lr)
88 | weight_decay_value = None
89 | for group in optimizer.param_groups:
90 | if group["weight_decay"] > 0:
91 | weight_decay_value = group["weight_decay"]
92 | metric_logger.update(weight_decay=weight_decay_value)
93 | metric_logger.update(grad_norm=grad_norm)
94 |
95 | if log_writer is not None:
96 | log_writer.update(loss=loss_value, head="loss")
97 | log_writer.update(loss_scale=loss_scale_value, head="opt")
98 | log_writer.update(lr=max_lr, head="opt")
99 | log_writer.update(min_lr=min_lr, head="opt")
100 | log_writer.update(weight_decay=weight_decay_value, head="opt")
101 | log_writer.update(grad_norm=grad_norm, head="opt")
102 |
103 | log_writer.set_step()
104 |
105 | if lr_scheduler is not None:
106 | lr_scheduler.step_update(start_steps + step)
107 | # gather the stats from all processes
108 | metric_logger.synchronize_between_processes()
109 | print("Averaged stats:", metric_logger)
110 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
111 |
--------------------------------------------------------------------------------
/MOODv1/checkpoint.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | #from tensorflow.io import gfile
4 | import numpy as np
5 |
6 | def get_l16_config(config):
7 | """ ViT-L/16 configuration """
8 | config.patch_size = 16
9 | config.emb_dim = 1024
10 | config.mlp_dim = 4096
11 | config.num_heads = 16
12 | config.num_layers = 24
13 | config.attn_dropout_rate = 0.0
14 | config.dropout_rate = 0.1
15 | return config
16 |
17 | def get_b16_config(config):
18 | """ ViT-B/16 configuration """
19 | config.patch_size = 16
20 | config.emb_dim = 768
21 | config.mlp_dim = 3072
22 | config.num_heads = 12
23 | config.num_layers = 12
24 | config.attn_dropout_rate = 0.0
25 | config.dropout_rate = 0.1
26 | return config
27 |
28 | def load_checkpoint(path, new_img=384, patch=16, emb_dim=768, layers=12):
29 | """ Load weights from a given checkpoint path in npz/pth """
30 | if path.endswith('npz'):
31 | keys, values = load_jax(path)
32 | state_dict = convert_jax_pytorch(keys, values)
33 | elif path.endswith('pth') or path.endswith('pt'):
34 | state_dict = torch.load(path, map_location=torch.device("cpu"))
35 | else:
36 | raise ValueError("checkpoint format {} not supported yet!".format(path.split('.')[-1]))
37 |
38 | if 'model' in state_dict.keys():
39 | state_dict = state_dict['model']
40 | elif 'state_dict' in state_dict.keys():
41 | state_dict = state_dict['state_dict']
42 |
43 | # if 'pos_embed' in state_dict.keys():
44 | # # Deal with class token
45 | # posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
46 | # model_grid_seq = new_img//patch
47 | # ckpt_grid_seq = int(np.sqrt(posemb_grid.shape[0]))
48 |
49 | # if model_grid_seq!=ckpt_grid_seq:
50 | # # Get old and new grid sizes
51 | # posemb_grid = posemb_grid.reshape(ckpt_grid_seq, ckpt_grid_seq, -1)
52 |
53 | # posemb_grid = torch.unsqueeze(posemb_grid.permute(2, 0, 1), dim=0)
54 | # posemb_grid = torch.nn.functional.interpolate(posemb_grid, size=(model_grid_seq, model_grid_seq), mode='bicubic', align_corners=False)
55 | # posemb_grid = posemb_grid.permute(0, 2, 3, 1).flatten(1, 2)
56 |
57 | # # Deal with class token and return
58 | # posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
59 | # # if 'jx' in path:
60 | # # state_dict['pos_embed'] = posemb
61 | # # else:
62 | # state_dict['transformer.pos_embedding.pos_embedding'] = posemb
63 | # print('Resized positional embedding from (%d,%d) to (%d,%d)'%(ckpt_grid_seq,ckpt_grid_seq,model_grid_seq,model_grid_seq))
64 | return state_dict
65 |
66 |
67 | def load_jax(path):
68 |     """ Loads params from an npz checkpoint previously stored with `save()` in the jax implementation """
69 | ckpt_dict = np.load(path, allow_pickle=False)
70 | keys, values = zip(*list(ckpt_dict.items()))
71 | # with gfile.GFile(path, 'rb') as f:
72 | # ckpt_dict = np.load(f, allow_pickle=False)
73 | # keys, values = zip(*list(ckpt_dict.items()))
74 | return keys, values
75 |
76 |
77 | def save_jax_to_pytorch(jax_path, save_path):
78 | model_name = jax_path.split('/')[-1].split('.')[0]
79 | keys, values = load_jax(jax_path)
80 | state_dict = convert_jax_pytorch(keys, values)
81 | checkpoint = {'state_dict': state_dict}
82 | torch.save(checkpoint, os.path.join(save_path, model_name + '.pth'))
83 |
84 |
85 | def replace_names(names):
86 | """ Replace jax model names with pytorch model names """
87 | new_names = []
88 | for name in names:
89 | if name == 'Transformer':
90 | new_names.append('transformer')
91 | elif name == 'encoder_norm':
92 | new_names.append('norm')
93 | elif 'encoderblock' in name:
94 | num = name.split('_')[-1]
95 | new_names.append('encoder_layers')
96 | new_names.append(num)
97 | elif 'LayerNorm' in name:
98 | num = name.split('_')[-1]
99 | if num == '0':
100 | new_names.append('norm{}'.format(1))
101 | elif num == '2':
102 | new_names.append('norm{}'.format(2))
103 | elif 'MlpBlock' in name:
104 | new_names.append('mlp')
105 | elif 'Dense' in name:
106 | num = name.split('_')[-1]
107 | new_names.append('fc{}'.format(int(num) + 1))
108 | elif 'MultiHeadDotProductAttention' in name:
109 | new_names.append('attn')
110 | elif name == 'kernel' or name == 'scale':
111 | new_names.append('weight')
112 | elif name == 'bias':
113 | new_names.append(name)
114 | elif name == 'posembed_input':
115 | new_names.append('pos_embedding')
116 | elif name == 'pos_embedding':
117 | new_names.append('pos_embedding')
118 | elif name == 'embedding':
119 | new_names.append('embedding')
120 | elif name == 'head':
121 | new_names.append('classifier')
122 | elif name == 'cls':
123 | new_names.append('cls_token')
124 | else:
125 | new_names.append(name)
126 | return new_names
127 |
128 |
129 | def convert_jax_pytorch(keys, values):
130 | """ Convert jax model parameters with pytorch model parameters """
131 | state_dict = {}
132 | for key, value in zip(keys, values):
133 |
134 | # convert name to torch names
135 | names = key.split('/')
136 | torch_names = replace_names(names)
137 | torch_key = '.'.join(w for w in torch_names)
138 |
139 | # convert values to tensor and check shapes
140 | tensor_value = torch.tensor(value, dtype=torch.float)
141 | # check shape
142 | num_dim = len(tensor_value.shape)
143 |
144 | if num_dim == 1:
145 | tensor_value = tensor_value.squeeze()
146 | elif num_dim == 2 and torch_names[-1] == 'weight':
147 | # for normal weight, transpose it
148 | tensor_value = tensor_value.T
149 | elif num_dim == 3 and torch_names[-1] == 'weight' and torch_names[-2] in ['query', 'key', 'value']:
150 | feat_dim, num_heads, head_dim = tensor_value.shape
151 | # for multi head attention q/k/v weight
152 | tensor_value = tensor_value
153 | elif num_dim == 2 and torch_names[-1] == 'bias' and torch_names[-2] in ['query', 'key', 'value']:
154 | # for multi head attention q/k/v bias
155 | tensor_value = tensor_value
156 | elif num_dim == 3 and torch_names[-1] == 'weight' and torch_names[-2] == 'out':
157 | # for multi head attention out weight
158 | tensor_value = tensor_value
159 | elif num_dim == 4 and torch_names[-1] == 'weight':
160 | tensor_value = tensor_value.permute(3, 2, 0, 1)
161 |
162 | # print("{}: {}".format(torch_key, tensor_value.shape))
163 | state_dict[torch_key] = tensor_value
164 | return state_dict
165 |
166 |
167 | if __name__ == '__main__':
168 | save_jax_to_pytorch('/Users/leon/Downloads/jax/imagenet21k+imagenet2012_ViT-L_16-224.npz', '/Users/leon/Downloads/pytorch')
169 |
170 |
171 |
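A brief usage sketch for load_checkpoint above; the checkpoint path is a hypothetical placeholder, and `model` stands for whichever ViT is being finetuned:

from checkpoint import load_checkpoint

state_dict = load_checkpoint('pretrained/vit-b16_384.pth', new_img=384, patch=16)
# model.load_state_dict(state_dict, strict=False)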
--------------------------------------------------------------------------------
/MOODv1/ood_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import sklearn.metrics as skm
4 | from torch.utils.data.dataset import Subset
5 | from scipy import linalg
6 | from sklearn.metrics import roc_curve, auc
7 | import matplotlib.pyplot as plt
8 | from PIL import Image, ImageFile
9 | import cv2
10 | ImageFile.LOAD_TRUNCATED_IMAGES = True
11 |
12 | ## utils ##
13 | def mkdir(path):
14 | if not os.path.exists(path):
15 | os.makedirs(path)
16 | return path
17 |
18 |
19 | ### dataset ###
20 | def get_subclass_dataset(dataset, classes):
21 | if not isinstance(classes, list):
22 | classes = [classes]
23 | indices = []
24 | for idx, tgt in enumerate(dataset.targets):
25 | if tgt in classes:
26 | indices.append(idx)
27 | subdataset = Subset(dataset, indices)
28 | return subdataset
29 |
30 |
31 | def get_superclass_list(dataset):
32 | CIFAR10_SUPERCLASS = list(range(10)) # one class
33 | IMAGENET_SUPERCLASS = list(range(30)) # one class
34 | CIFAR100_SUPERCLASS = [
35 | [4, 31, 55, 72, 95],
36 | [1, 33, 67, 73, 91],
37 | [54, 62, 70, 82, 92],
38 | [9, 10, 16, 29, 61],
39 | [0, 51, 53, 57, 83],
40 | [22, 25, 40, 86, 87],
41 | [5, 20, 26, 84, 94],
42 | [6, 7, 14, 18, 24],
43 | [3, 42, 43, 88, 97],
44 | [12, 17, 38, 68, 76],
45 | [23, 34, 49, 60, 71],
46 | [15, 19, 21, 32, 39],
47 | [35, 63, 64, 66, 75],
48 | [27, 45, 77, 79, 99],
49 | [2, 11, 36, 46, 98],
50 | [28, 30, 44, 78, 93],
51 | [37, 50, 65, 74, 80],
52 | [47, 52, 56, 59, 96],
53 | [8, 13, 48, 58, 90],
54 | [41, 69, 81, 85, 89],
55 | ]
56 | if dataset.lower() == 'cifar10':
57 | return CIFAR10_SUPERCLASS
58 | elif dataset.lower() == 'cifar100':
59 | return CIFAR100_SUPERCLASS
60 | elif dataset.lower() == 'imagenet' or dataset.lower() == 'imagenet30':
61 | return IMAGENET_SUPERCLASS
62 | else:
63 | raise NotImplementedError()
64 |
65 | def get_scores_one_cluster(ftrain, ftest, food, args):
66 | methods = {
67 | 'mahalanobis': mahalanobis, # Mahalanobis Distance
68 | 'cos': cosine_similarity, # cosine similarity
69 | 'projection': projection, # projection distance
70 | 'gauss': gauss_distribution, # distribution percentage of gauss distribution
71 | 'kmeans': kmeans, # the distance to the nearest cluster
72 | 'euclidean': euclidean_distance, # euclidean distance
73 | 'minkowski': minkowski_distance, # minkowski distance
74 | 'chebyshev': chebyshev_distance, # chebyshev distance
75 | }
76 | din = methods[args.metric](ftrain, ftest, args)
77 | dood = methods[args.metric](ftrain, food, args)
78 | label = [0] * len(din) + [1] * len(dood)
79 | return din, dood, label
80 |
81 |
82 | def mahalanobis(ftrain, ftest, args):
83 | cov = lambda x: np.cov(x.T, bias=True)
84 | if args.cc or args.avgcc:
85 | dtest = [[] for _ in range(args.nb_classes)]
86 | mean = [[] for _ in range(args.nb_classes)]
87 | std = [[] for _ in range(args.nb_classes)]
88 | for i in range(args.nb_classes):
89 | std[i] = np.linalg.pinv(cov(ftrain[i]))
90 | mean[i] = np.mean(ftrain[i], axis=0, keepdims=True)
91 | dtest[i] = np.sum((ftest - mean[i])* (std[i].dot((ftest - mean[i]).T)).T, axis=-1,)
92 | if args.cc:
93 | return np.min(dtest, axis=0)
94 | else:
95 | return np.mean(dtest, axis=0)
96 |
97 | else:
98 | std = np.linalg.pinv(cov(ftrain))
99 | mean = np.mean(ftrain, axis=0, keepdims=True)
100 | dtest = np.sum((ftest - mean)* (std.dot((ftest - mean).T)).T, axis=-1,)
101 | return dtest
102 |
103 | ## get features ###
104 | def get_features(model, dataloader, name, args, is_train=False, max_num=1e10):
105 | model.eval()
106 | features = []
107 | for index, (img, label) in enumerate(dataloader):
108 | if index >= max_num:
109 | break
110 | img, label = img.cuda(), label.cuda()
111 | feature = model.forward_features(img)
112 | features += list(feature.data.cpu().numpy())
113 | if args.class_idx is None and (index + 1) % 100 == 0:
114 | shape = np.array(features).shape
115 | print('{}: ({}, {})/({}, {})'.format(name, index+1, shape[-1], len(dataloader), shape[-1]), end='\r')
116 | print('\n')
117 | features = np.array(features)
118 | return features
119 |
120 |
121 | #### OOD detection ####
122 | def get_roc_sklearn(xin, xood, labels):
123 | data = np.concatenate((xin, xood))
124 | auroc = skm.roc_auc_score(labels, data)
125 | # import pdb;pdb.set_trace()
126 | return auroc
127 |
128 |
129 | def get_pr_sklearn(xin, xood, labels=None):
130 | data = np.concatenate((xin, xood))
131 | aupr = skm.average_precision_score(labels, data)
132 | return aupr
133 |
134 |
135 | def get_fpr(xin, xood, labels):
136 | if labels == [0] * len(xin) + [1] * len(xood):
137 | return np.sum(xood < np.percentile(xin, 95)) / len(xood)
138 | elif labels == [1] * len(xin) + [0] * len(xood):
139 | return np.sum(xood > np.percentile(xin, 95)) / len(xood)
140 |     else:
141 |         raise ValueError('labels must be [0]*len(xin) + [1]*len(xood) or the reverse')
142 |
143 |
144 | def projection(ftrain, ftest, args=None):
145 | from sklearn.metrics.pairwise import cosine_similarity
146 | matrix_in = cosine_similarity(ftrain, ftest)
147 | mod_in = np.linalg.norm(ftest)**2
148 | din = np.max(matrix_in, axis=1)*mod_in
149 | return din
150 |
151 |
152 | def gauss_distribution(ftrain, ftest, args=None):
153 | shrunkcov = True
154 | if shrunkcov:
155 | from sklearn.covariance import ledoit_wolf
156 | print("Using ledoit-wolf covariance estimator.")
157 | cov = lambda x: ledoit_wolf(x)[0]
158 | else:
159 | cov = lambda x: np.cov(x.T, bias=True)
160 |
161 | std = np.linalg.pinv(cov(ftrain))
162 | mean = np.mean(ftrain, axis=0, keepdims=True)
163 | D = len(ftrain[0])
164 |
165 | dtest = np.sum((ftest - mean)* (std.dot((ftest - mean).T)).T, axis=-1,)
166 | k = 1 / ((2*np.pi)**(D/2) * np.linalg.norm(std)**0.5)
167 | ptest = k * np.exp(-0.5 * dtest)
168 | return ptest
169 |
170 | def kmeans(ftrain, ftest, args=None, nclusters=10):  # nclusters default chosen arbitrarily; adjust as needed
171 | from sklearn.cluster import KMeans
172 | kMeansModel = KMeans(init='k-means++', n_clusters=nclusters, max_iter=100000)
173 | kMeansModel.fit(ftrain)
174 |
175 | distances = kMeansModel.transform(ftest)
176 | inDtC = np.min(distances, axis=1)
177 |
178 | return inDtC
179 |
180 | def cosine_similarity(ftrain, ftest, args=None):
181 | from sklearn.metrics.pairwise import cosine_similarity
182 | matrix_in = cosine_similarity(ftrain, ftest)
183 | din = np.max(matrix_in, axis=1)
184 | return din
185 |
186 | def euclidean_distance(ftrain, ftest, args=None):
187 | mean = np.mean(ftrain, axis=0, keepdims=True)
188 | dtest = np.sqrt(np.sum(np.power(ftest - mean, 2), axis=-1,))
189 | return dtest
190 |
191 | def minkowski_distance(ftrain, ftest, args=None):
192 | mean = np.mean(ftrain, axis=0, keepdims=True)
193 | dtest = np.sum(np.abs(ftest - mean), axis=-1,)
194 | return dtest
195 |
196 | def chebyshev_distance(ftrain, ftest, args=None):
197 | mean = np.mean(ftrain, axis=0, keepdims=True)
198 | dtest = np.max(np.abs(ftest - mean), axis=-1,)
199 | return dtest
--------------------------------------------------------------------------------
/MOODv1/transforms.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
3 | # Github source: https://github.com/microsoft/unilm/tree/master/beit
4 | # Copyright (c) 2021 Microsoft
5 | # Licensed under The MIT License [see LICENSE for details]
6 | # By Hangbo Bao
7 | # Based on timm code bases
8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm
9 | # --------------------------------------------------------'
10 | import torch
11 | import torchvision.transforms.functional as F
12 | from PIL import Image
13 | import warnings
14 | import math
15 | import random
16 | import numpy as np
17 |
18 |
19 | class ToNumpy:
20 |
21 | def __call__(self, pil_img):
22 | np_img = np.array(pil_img, dtype=np.uint8)
23 | if np_img.ndim < 3:
24 | np_img = np.expand_dims(np_img, axis=-1)
25 | np_img = np.rollaxis(np_img, 2) # HWC to CHW
26 | return np_img
27 |
28 |
29 | class ToTensor:
30 |
31 | def __init__(self, dtype=torch.float32):
32 | self.dtype = dtype
33 |
34 | def __call__(self, pil_img):
35 | np_img = np.array(pil_img, dtype=np.uint8)
36 | if np_img.ndim < 3:
37 | np_img = np.expand_dims(np_img, axis=-1)
38 | np_img = np.rollaxis(np_img, 2) # HWC to CHW
39 | return torch.from_numpy(np_img).to(dtype=self.dtype)
40 |
41 |
42 | _pil_interpolation_to_str = {
43 | Image.NEAREST: 'PIL.Image.NEAREST',
44 | Image.BILINEAR: 'PIL.Image.BILINEAR',
45 | Image.BICUBIC: 'PIL.Image.BICUBIC',
46 | Image.LANCZOS: 'PIL.Image.LANCZOS',
47 | Image.HAMMING: 'PIL.Image.HAMMING',
48 | Image.BOX: 'PIL.Image.BOX',
49 | }
50 |
51 |
52 | def _pil_interp(method):
53 | if method == 'bicubic':
54 | return Image.BICUBIC
55 | elif method == 'lanczos':
56 | return Image.LANCZOS
57 | elif method == 'hamming':
58 | return Image.HAMMING
59 | else:
60 | # default bilinear, do we want to allow nearest?
61 | return Image.BILINEAR
62 |
63 |
64 | _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
65 |
66 |
67 | class RandomResizedCropAndInterpolationWithTwoPic:
68 | """Crop the given PIL Image to random size and aspect ratio with random interpolation.
69 |
70 | A crop of random size (default: of 0.08 to 1.0) of the original size and a random
71 | aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
72 | is finally resized to given size.
73 | This is popularly used to train the Inception networks.
74 |
75 | Args:
76 | size: expected output size of each edge
77 | scale: range of size of the origin size cropped
78 | ratio: range of aspect ratio of the origin aspect ratio cropped
79 | interpolation: Default: PIL.Image.BILINEAR
80 | """
81 |
82 | def __init__(self, size, second_size=None, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
83 | interpolation='bilinear', second_interpolation='lanczos'):
84 | if isinstance(size, tuple):
85 | self.size = size
86 | else:
87 | self.size = (size, size)
88 | if second_size is not None:
89 | if isinstance(second_size, tuple):
90 | self.second_size = second_size
91 | else:
92 | self.second_size = (second_size, second_size)
93 | else:
94 | self.second_size = None
95 | if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
96 | warnings.warn("range should be of kind (min, max)")
97 |
98 | if interpolation == 'random':
99 | self.interpolation = _RANDOM_INTERPOLATION
100 | else:
101 | self.interpolation = _pil_interp(interpolation)
102 | self.second_interpolation = _pil_interp(second_interpolation)
103 | self.scale = scale
104 | self.ratio = ratio
105 |
106 | @staticmethod
107 | def get_params(img, scale, ratio):
108 | """Get parameters for ``crop`` for a random sized crop.
109 |
110 | Args:
111 | img (PIL Image): Image to be cropped.
112 | scale (tuple): range of size of the origin size cropped
113 | ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
114 |
115 | Returns:
116 | tuple: params (i, j, h, w) to be passed to ``crop`` for a random
117 | sized crop.
118 | """
119 | area = img.size[0] * img.size[1]
120 |
121 | for attempt in range(10):
122 | target_area = random.uniform(*scale) * area
123 | log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
124 | aspect_ratio = math.exp(random.uniform(*log_ratio))
125 |
126 | w = int(round(math.sqrt(target_area * aspect_ratio)))
127 | h = int(round(math.sqrt(target_area / aspect_ratio)))
128 |
129 | if w <= img.size[0] and h <= img.size[1]:
130 | i = random.randint(0, img.size[1] - h)
131 | j = random.randint(0, img.size[0] - w)
132 | return i, j, h, w
133 |
134 | # Fallback to central crop
135 | in_ratio = img.size[0] / img.size[1]
136 | if in_ratio < min(ratio):
137 | w = img.size[0]
138 | h = int(round(w / min(ratio)))
139 | elif in_ratio > max(ratio):
140 | h = img.size[1]
141 | w = int(round(h * max(ratio)))
142 | else: # whole image
143 | w = img.size[0]
144 | h = img.size[1]
145 | i = (img.size[1] - h) // 2
146 | j = (img.size[0] - w) // 2
147 | return i, j, h, w
148 |
149 | def __call__(self, img):
150 | """
151 | Args:
152 | img (PIL Image): Image to be cropped and resized.
153 |
154 | Returns:
155 | PIL Image: Randomly cropped and resized image.
156 | """
157 | i, j, h, w = self.get_params(img, self.scale, self.ratio)
158 | if isinstance(self.interpolation, (tuple, list)):
159 | interpolation = random.choice(self.interpolation)
160 | else:
161 | interpolation = self.interpolation
162 | if self.second_size is None:
163 | return F.resized_crop(img, i, j, h, w, self.size, interpolation)
164 | else:
165 | return F.resized_crop(img, i, j, h, w, self.size, interpolation), \
166 | F.resized_crop(img, i, j, h, w, self.second_size, self.second_interpolation)
167 |
168 | def __repr__(self):
169 | if isinstance(self.interpolation, (tuple, list)):
170 | interpolate_str = ' '.join([_pil_interpolation_to_str[x] for x in self.interpolation])
171 | else:
172 | interpolate_str = _pil_interpolation_to_str[self.interpolation]
173 | format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
174 | format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
175 | format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
176 | format_string += ', interpolation={0}'.format(interpolate_str)
177 | if self.second_size is not None:
178 | format_string += ', second_size={0}'.format(self.second_size)
179 | format_string += ', second_interpolation={0}'.format(_pil_interpolation_to_str[self.second_interpolation])
180 | format_string += ')'
181 | return format_string
182 |
--------------------------------------------------------------------------------
/MOODv1/modeling_pretrain.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
3 | # Github source: https://github.com/microsoft/unilm/tree/master/beit
4 | # Copyright (c) 2021 Microsoft
5 | # Licensed under The MIT License [see LICENSE for details]
6 | # By Hangbo Bao
7 | # Based on timm and DeiT code bases
8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm
9 | # https://github.com/facebookresearch/deit/
10 | # --------------------------------------------------------'
11 | import math
12 | import torch
13 | import torch.nn as nn
14 | from functools import partial
15 | import matplotlib.pyplot as plt
16 | from modeling_finetune import Block, _cfg, PatchEmbed, RelativePositionBias
17 | from timm.models.registry import register_model
18 | from timm.models.layers import trunc_normal_ as __call_trunc_normal_
19 | import numpy as np
20 |
21 | def trunc_normal_(tensor, mean=0., std=1.):
22 | __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)
23 |
24 |
25 | __all__ = [
26 | 'beit_base_patch16_224_8k_vocab',
27 | 'beit_large_patch16_224_8k_vocab',
28 | ]
29 |
30 |
31 | class VisionTransformerForMaskedImageModeling(nn.Module):
32 | def __init__(self, img_size=224, patch_size=16, in_chans=3, vocab_size=8192, embed_dim=768, depth=12,
33 | num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
34 | drop_path_rate=0., norm_layer=None, init_values=None, attn_head_dim=None,
35 | use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, init_std=0.02, **kwargs):
36 | super().__init__()
37 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
38 |
39 | self.patch_embed = PatchEmbed(
40 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
41 | num_patches = self.patch_embed.num_patches
42 |
43 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
44 | self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
45 | if use_abs_pos_emb:
46 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
47 | else:
48 | self.pos_embed = None
49 | self.pos_drop = nn.Dropout(p=drop_rate)
50 |
51 | if use_shared_rel_pos_bias:
52 | self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
53 | else:
54 | self.rel_pos_bias = None
55 |
56 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
57 | self.blocks = nn.ModuleList([
58 | Block(
59 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
60 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
61 | init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
62 | attn_head_dim=attn_head_dim,
63 | )
64 | for i in range(depth)])
65 | self.norm = norm_layer(embed_dim)
66 |
67 | self.init_std = init_std
68 | self.lm_head = nn.Linear(embed_dim, vocab_size)
69 |
70 | if self.pos_embed is not None:
71 | trunc_normal_(self.pos_embed, std=self.init_std)
72 | trunc_normal_(self.cls_token, std=self.init_std)
73 | trunc_normal_(self.mask_token, std=self.init_std)
74 | trunc_normal_(self.lm_head.weight, std=self.init_std)
75 | self.apply(self._init_weights)
76 | self.fix_init_weight()
77 |
78 | def fix_init_weight(self):
79 | def rescale(param, layer_id):
80 | param.div_(math.sqrt(2.0 * layer_id))
81 |
82 | for layer_id, layer in enumerate(self.blocks):
83 | rescale(layer.attn.proj.weight.data, layer_id + 1)
84 | rescale(layer.mlp.fc2.weight.data, layer_id + 1)
85 |
86 | def _init_weights(self, m):
87 | if isinstance(m, nn.Linear):
88 | trunc_normal_(m.weight, std=self.init_std)
89 | if isinstance(m, nn.Linear) and m.bias is not None:
90 | nn.init.constant_(m.bias, 0)
91 | elif isinstance(m, nn.LayerNorm):
92 | nn.init.constant_(m.bias, 0)
93 | nn.init.constant_(m.weight, 1.0)
94 | elif isinstance(m, nn.Conv2d):
95 | trunc_normal_(m.weight, std=self.init_std)
96 | if m.bias is not None:
97 | nn.init.constant_(m.bias, 0)
98 |
99 | @torch.jit.ignore
100 | def no_weight_decay(self):
101 | return {'pos_embed', 'cls_token'}
102 |
103 | def get_num_layers(self):
104 | return len(self.blocks)
105 |
106 |
107 | def forward_features(self, x, bool_masked_pos=None):
108 | x = self.patch_embed(x, bool_masked_pos=bool_masked_pos)
109 | batch_size, seq_len, _ = x.size()
110 |
111 | cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
112 | mask_token = self.mask_token.expand(batch_size, seq_len, -1)
113 |
114 | # replace the masked visual tokens by mask_token
115 | if bool_masked_pos is not None:
116 | w = bool_masked_pos.unsqueeze(-1).type_as(mask_token)
117 | x = x * (1 - w) + mask_token * w
118 |
119 | x = torch.cat((cls_tokens, x), dim=1)
120 | if self.pos_embed is not None:
121 | x = x + self.pos_embed
122 | x = self.pos_drop(x)
123 |
124 | rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
125 | for blk in self.blocks:
126 | x = blk(x, rel_pos_bias=rel_pos_bias)
127 |
128 | return self.norm(x)
129 |
130 | def forward(self, x, bool_masked_pos=None, return_all_tokens=False):
131 | x = self.forward_features(x, bool_masked_pos=bool_masked_pos)
132 | x = x[:, 1:]
133 | # import pdb;pdb.set_trace()
134 | if return_all_tokens:
135 | return self.lm_head(x)
136 | else:
137 | # return the masked tokens
138 | return self.lm_head(x[bool_masked_pos])
139 |
140 |
141 | @register_model
142 | def beit_base_patch16_224_8k_vocab(pretrained=False, **kwargs):
143 | model = VisionTransformerForMaskedImageModeling(
144 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
145 | norm_layer=partial(nn.LayerNorm, eps=1e-6), vocab_size=8192, **kwargs)
146 | model.default_cfg = _cfg()
147 | if pretrained:
148 | checkpoint = torch.load(
149 | kwargs["init_ckpt"], map_location="cpu"
150 | )
151 | model.load_state_dict(checkpoint["model"])
152 | return model
153 |
154 |
155 | @register_model
156 | def beit_large_patch16_224_8k_vocab(pretrained=False, **kwargs):
157 | model = VisionTransformerForMaskedImageModeling(
158 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
159 | norm_layer=partial(nn.LayerNorm, eps=1e-6), vocab_size=8192, **kwargs)
160 | model.default_cfg = _cfg()
161 | if pretrained:
162 | checkpoint = torch.load(
163 | kwargs["init_ckpt"], map_location="cpu"
164 | )
165 | model.load_state_dict(checkpoint["model"])
166 | return model
167 |
--------------------------------------------------------------------------------
/MOODv1/optim_factory.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
3 | # Github source: https://github.com/microsoft/unilm/tree/master/beit
4 | # Copyright (c) 2021 Microsoft
5 | # Licensed under The MIT License [see LICENSE for details]
6 | # By Hangbo Bao
7 | # Based on timm code bases
8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm
9 | # --------------------------------------------------------'
10 | import torch
11 | from torch import optim as optim
12 |
13 | import json
14 |
15 | try:
16 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
17 | has_apex = True
18 | except ImportError:
19 | has_apex = False
20 |
21 |
22 | def get_num_layer_for_vit(var_name, num_max_layer):
23 | if var_name in ("cls_token", "mask_token", "pos_embed"):
24 | return 0
25 | elif var_name.startswith("patch_embed"):
26 | return 0
27 | elif var_name.startswith("rel_pos_bias"):
28 | return num_max_layer - 1
29 | elif var_name.startswith("blocks"):
30 | layer_id = int(var_name.split('.')[1])
31 | return layer_id + 1
32 | else:
33 | return num_max_layer - 1
34 |
35 |
36 | class LayerDecayValueAssigner(object):
37 | def __init__(self, values):
38 | self.values = values
39 |
40 | def get_scale(self, layer_id):
41 | return self.values[layer_id]
42 |
43 | def get_layer_id(self, var_name):
44 | return get_num_layer_for_vit(var_name, len(self.values))
45 |
46 |
47 | def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None):
48 | parameter_group_names = {}
49 | parameter_group_vars = {}
50 |
51 | for name, param in model.named_parameters():
52 | if not param.requires_grad:
53 | continue # frozen weights
54 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
55 | group_name = "no_decay"
56 | this_weight_decay = 0.
57 | else:
58 | group_name = "decay"
59 | this_weight_decay = weight_decay
60 | if get_num_layer is not None:
61 | layer_id = get_num_layer(name)
62 | group_name = "layer_%d_%s" % (layer_id, group_name)
63 | else:
64 | layer_id = None
65 |
66 | if group_name not in parameter_group_names:
67 | if get_layer_scale is not None:
68 | scale = get_layer_scale(layer_id)
69 | else:
70 | scale = 1.
71 |
72 | parameter_group_names[group_name] = {
73 | "weight_decay": this_weight_decay,
74 | "params": [],
75 | "lr_scale": scale
76 | }
77 | parameter_group_vars[group_name] = {
78 | "weight_decay": this_weight_decay,
79 | "params": [],
80 | "lr_scale": scale
81 | }
82 |
83 | parameter_group_vars[group_name]["params"].append(param)
84 | parameter_group_names[group_name]["params"].append(name)
85 | # print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
86 | return list(parameter_group_vars.values())
87 |
88 |
89 | def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None):
90 | opt_lower = args.opt.lower()
91 | weight_decay = args.weight_decay
92 | if weight_decay and filter_bias_and_bn:
93 | skip = {}
94 | if skip_list is not None:
95 | skip = skip_list
96 | elif hasattr(model, 'no_weight_decay'):
97 | skip = model.no_weight_decay()
98 | parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale)
99 | weight_decay = 0.
100 | else:
101 | parameters = model.parameters()
102 |
103 | if 'fused' in opt_lower:
104 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
105 |
106 | opt_args = dict(lr=args.lr, weight_decay=weight_decay)
107 | if hasattr(args, 'opt_eps') and args.opt_eps is not None:
108 | opt_args['eps'] = args.opt_eps
109 | if hasattr(args, 'opt_betas') and args.opt_betas is not None:
110 | opt_args['betas'] = args.opt_betas
111 |
112 | opt_split = opt_lower.split('_')
113 | opt_lower = opt_split[-1]
114 | if opt_lower == 'sgd' or opt_lower == 'nesterov':
115 | opt_args.pop('eps', None)
116 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
117 | elif opt_lower == 'momentum':
118 | opt_args.pop('eps', None)
119 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
120 | elif opt_lower == 'adam':
121 | optimizer = optim.Adam(parameters, **opt_args)
122 | elif opt_lower == 'adamw':
123 | optimizer = optim.AdamW(parameters, **opt_args)
124 | elif opt_lower == 'nadam':
125 | from timm.optim.nadam import Nadam
126 | optimizer = Nadam(parameters, **opt_args)
127 | elif opt_lower == 'radam':
128 | from timm.optim.radam import RAdam
129 | optimizer = RAdam(parameters, **opt_args)
130 | elif opt_lower == 'adamp':
131 | from timm.optim.adamp import AdamP
132 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
133 | elif opt_lower == 'sgdp':
134 | from timm.optim.sgdp import SGDP
135 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)
136 | elif opt_lower == 'adadelta':
137 | optimizer = optim.Adadelta(parameters, **opt_args)
138 | elif opt_lower == 'adafactor':
139 | if not args.lr:
140 | opt_args['lr'] = None
141 | from timm.optim.adafactor import Adafactor
142 | optimizer = Adafactor(parameters, **opt_args)
143 | elif opt_lower == 'adahessian':
144 | from timm.optim.adahessian import Adahessian
145 | optimizer = Adahessian(parameters, **opt_args)
146 | elif opt_lower == 'rmsprop':
147 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
148 | elif opt_lower == 'rmsproptf':
149 | from timm.optim.rmsprop_tf import RMSpropTF
150 | optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
151 | elif opt_lower == 'novograd':
152 | from timm.optim.novograd import NovoGrad
153 | optimizer = NovoGrad(parameters, **opt_args)
154 | elif opt_lower == 'nvnovograd':
155 | from timm.optim.nvnovograd import NvNovoGrad
156 | optimizer = NvNovoGrad(parameters, **opt_args)
157 | elif opt_lower == 'fusedsgd':
158 | opt_args.pop('eps', None)
159 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
160 | elif opt_lower == 'fusedmomentum':
161 | opt_args.pop('eps', None)
162 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
163 | elif opt_lower == 'fusedadam':
164 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
165 | elif opt_lower == 'fusedadamw':
166 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
167 | elif opt_lower == 'fusedlamb':
168 | optimizer = FusedLAMB(parameters, **opt_args)
169 | elif opt_lower == 'fusednovograd':
170 | opt_args.setdefault('betas', (0.95, 0.98))
171 | optimizer = FusedNovoGrad(parameters, **opt_args)
172 | else:
173 |         # unknown optimizer name
174 |         raise ValueError('Invalid optimizer: {}'.format(opt_lower))
175 |
176 | if len(opt_split) > 1:
177 | if opt_split[0] == 'lookahead':
178 | from timm.optim.lookahead import Lookahead
179 | optimizer = Lookahead(optimizer)
180 |
181 | return optimizer
182 |
--------------------------------------------------------------------------------
/MOODv1/modeling_discrete_vae.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
3 | # Github source: https://github.com/microsoft/unilm/tree/master/beit
4 | # Copyright (c) 2021 Microsoft
5 | # Licensed under The MIT License [see LICENSE for details]
6 | # By Hangbo Bao
7 | # Based on OpenAI DALL-E and lucidrains' DALLE-pytorch code bases
8 | # https://github.com/openai/DALL-E
9 | # https://github.com/lucidrains/DALLE-pytorch
10 | # --------------------------------------------------------'
11 | from math import sqrt
12 | import os
13 | import torch
14 | from torch import nn, einsum
15 | import torch.nn.functional as F
16 | from einops import rearrange
17 |
18 |
19 | def top_k(logits, thres = 0.5):
20 | num_logits = logits.shape[-1]
21 | k = max(int((1 - thres) * num_logits), 1)
22 | val, ind = torch.topk(logits, k)
23 | probs = torch.full_like(logits, float('-inf'))
24 | probs.scatter_(1, ind, val)
25 | return probs
26 |
27 |
28 | def exists(val):
29 | return val is not None
30 |
31 |
32 | def default(val, d):
33 | return val if exists(val) else d
34 |
35 |
36 | def eval_decorator(fn):
37 | def inner(model, *args, **kwargs):
38 | was_training = model.training
39 | model.eval()
40 | out = fn(model, *args, **kwargs)
41 | model.train(was_training)
42 | return out
43 | return inner
44 |
45 |
46 | class BasicVAE(nn.Module):
47 |
48 | def get_codebook_indices(self, images):
49 | raise NotImplementedError()
50 |
51 | def decode(self, img_seq):
52 | raise NotImplementedError()
53 |
54 | def get_codebook_probs(self, img_seq):
55 | raise NotImplementedError()
56 |
57 | def get_image_tokens_size(self):
58 | pass
59 |
60 | def get_image_size(self):
61 | pass
62 |
63 |
64 | class ResBlock(nn.Module):
65 | def __init__(self, chan_in, hidden_size, chan_out):
66 | super().__init__()
67 | self.net = nn.Sequential(
68 | nn.Conv2d(chan_in, hidden_size, 3, padding=1),
69 | nn.ReLU(),
70 | nn.Conv2d(hidden_size, hidden_size, 3, padding=1),
71 | nn.ReLU(),
72 | nn.Conv2d(hidden_size, chan_out, 1)
73 | )
74 |
75 | def forward(self, x):
76 | return self.net(x) + x
77 |
78 |
79 | class DiscreteVAE(BasicVAE):
80 | def __init__(
81 | self,
82 | image_size = 256,
83 | num_tokens = 512,
84 | codebook_dim = 512,
85 | num_layers = 3,
86 | hidden_dim = 64,
87 | channels = 3,
88 | smooth_l1_loss = False,
89 | temperature = 0.9,
90 | straight_through = False,
91 | kl_div_loss_weight = 0.
92 | ):
93 | super().__init__()
94 | # assert log2(image_size).is_integer(), 'image size must be a power of 2'
95 | assert num_layers >= 1, 'number of layers must be greater than or equal to 1'
96 |
97 | self.image_size = image_size
98 | self.num_tokens = num_tokens
99 | self.num_layers = num_layers
100 | self.temperature = temperature
101 | self.straight_through = straight_through
102 | self.codebook = nn.Embedding(num_tokens, codebook_dim)
103 |
104 | enc_layers = []
105 | dec_layers = []
106 |
107 | enc_in = channels
108 | dec_in = codebook_dim
109 |
110 | for layer_id in range(num_layers):
111 | enc_layers.append(nn.Sequential(nn.Conv2d(enc_in, hidden_dim, 4, stride=2, padding=1), nn.ReLU()))
112 | enc_layers.append(ResBlock(chan_in=hidden_dim, hidden_size=hidden_dim, chan_out=hidden_dim))
113 | enc_in = hidden_dim
114 | dec_layers.append(nn.Sequential(nn.ConvTranspose2d(dec_in, hidden_dim, 4, stride=2, padding=1), nn.ReLU()))
115 | dec_layers.append(ResBlock(chan_in=hidden_dim, hidden_size=hidden_dim, chan_out=hidden_dim))
116 | dec_in = hidden_dim
117 |
118 | enc_layers.append(nn.Conv2d(hidden_dim, num_tokens, 1))
119 | dec_layers.append(nn.Conv2d(hidden_dim, channels, 1))
120 |
121 | self.encoder = nn.Sequential(*enc_layers)
122 | self.decoder = nn.Sequential(*dec_layers)
123 |
124 | self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss
125 | self.kl_div_loss_weight = kl_div_loss_weight
126 |
127 | def get_image_size(self):
128 | return self.image_size
129 |
130 | def get_image_tokens_size(self):
131 | return self.image_size // 8
132 |
133 | @torch.no_grad()
134 | @eval_decorator
135 | def get_codebook_indices(self, images):
136 | logits = self.forward(images, return_logits = True)
137 | codebook_indices = logits.argmax(dim = 1)
138 | return codebook_indices
139 |
140 | @torch.no_grad()
141 | @eval_decorator
142 | def get_codebook_probs(self, images):
143 | logits = self.forward(images, return_logits = True)
144 | return nn.Softmax(dim=1)(logits)
145 |
146 | def decode(
147 | self,
148 | img_seq
149 | ):
150 | image_embeds = self.codebook(img_seq)
151 | b, n, d = image_embeds.shape
152 | h = w = int(sqrt(n))
153 |
154 | image_embeds = rearrange(image_embeds, 'b (h w) d -> b d h w', h = h, w = w)
155 | images = self.decoder(image_embeds)
156 | return images
157 |
158 | def forward(
159 | self,
160 | img,
161 | return_loss = False,
162 | return_recons = False,
163 | return_logits = False,
164 | temp = None
165 | ):
166 | device, num_tokens, image_size, kl_div_loss_weight = img.device, self.num_tokens, self.image_size, self.kl_div_loss_weight
167 | assert img.shape[-1] == image_size and img.shape[-2] == image_size, f'input must have the correct image size {image_size}'
168 |
169 | logits = self.encoder(img)
170 |
171 | if return_logits:
172 | return logits # return logits for getting hard image indices for DALL-E training
173 |
174 | temp = default(temp, self.temperature)
175 | soft_one_hot = F.gumbel_softmax(logits, tau = temp, dim = 1, hard = self.straight_through)
176 | sampled = einsum('b n h w, n d -> b d h w', soft_one_hot, self.codebook.weight)
177 | out = self.decoder(sampled)
178 |
179 | if not return_loss:
180 | return out
181 |
182 | # reconstruction loss
183 |
184 | recon_loss = self.loss_fn(img, out)
185 |
186 | # kl divergence
187 |
188 | logits = rearrange(logits, 'b n h w -> b (h w) n')
189 | qy = F.softmax(logits, dim = -1)
190 |
191 | log_qy = torch.log(qy + 1e-10)
192 | log_uniform = torch.log(torch.tensor([1. / num_tokens], device = device))
193 | kl_div = F.kl_div(log_uniform, log_qy, None, None, 'batchmean', log_target = True)
194 |
195 | loss = recon_loss + (kl_div * kl_div_loss_weight)
196 |
197 | if not return_recons:
198 | return loss
199 |
200 | return loss, out
201 |
202 |
203 | from dall_e import load_model
204 |
205 |
206 | class Dalle_VAE(BasicVAE):
207 | def __init__(self, window_size):
208 | super().__init__()
209 | self.encoder = None
210 | self.decoder = None
211 | self.window_size = window_size
212 |
213 | def load_model(self, model_dir, device):
214 | self.encoder = load_model(os.path.join(model_dir, "encoder.pkl"), device)
215 | self.decoder = load_model(os.path.join(model_dir, "decoder.pkl"), device)
216 |
217 | def decode(self, img_seq):
218 | bsz = img_seq.size()[0]
219 | img_seq = img_seq.view(bsz, self.window_size, self.window_size)
220 | z = F.one_hot(img_seq, num_classes=self.encoder.vocab_size).permute(0, 3, 1, 2).float()
221 | return self.decoder(z).float()
222 |
223 | def get_codebook_indices(self, images):
224 | z_logits = self.encoder(images)
225 | return torch.argmax(z_logits, axis=1)
226 |
227 | def get_codebook_probs(self, images):
228 | z_logits = self.encoder(images)
229 | return nn.Softmax(dim=1)(z_logits)
230 |
231 | def forward(self, img_seq_prob, no_process=False):
232 | if no_process:
233 | return self.decoder(img_seq_prob.float()).float()
234 | else:
235 | bsz, seq_len, num_class = img_seq_prob.size()
236 | z = img_seq_prob.view(bsz, self.window_size, self.window_size, self.encoder.vocab_size)
237 | return self.decoder(z.permute(0, 3, 1, 2).float()).float()
238 |
--------------------------------------------------------------------------------
/MOODv1/engine_for_finetuning.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
3 | # Github source: https://github.com/microsoft/unilm/tree/master/beit
4 | # Copyright (c) 2021 Microsoft
5 | # Licensed under The MIT License [see LICENSE for details]
6 | # By Hangbo Bao
7 | # Based on timm, DINO and DeiT code bases
8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm
9 | # https://github.com/facebookresearch/deit/
10 | # https://github.com/facebookresearch/dino
11 | # --------------------------------------------------------'
12 | import math
13 | import sys
14 | from typing import Iterable, Optional
15 |
16 | import torch
17 |
18 | from timm.data import Mixup
19 | from timm.utils import accuracy, ModelEma
20 |
21 | import utils
22 |
23 |
24 | def train_class_batch(model, samples, target, criterion,
25 | pretrain_model=None, l2_alpha=0.0, l2_beta=0.0):
26 | outputs = model(samples)
27 | loss = criterion(outputs, target)
28 | # print(f'ce loss {loss}', end=', ')
29 | if pretrain_model is not None:
30 | reg_loss = 0
31 | for params, pretrain_params in zip(model.parameters(), pretrain_model.parameters()):
32 | if params.size() == pretrain_params.size():
33 | # parts that share the architecture
34 | delta_param = params - pretrain_params
35 | reg_loss += 0.5 * l2_alpha * delta_param.norm(2)**2
36 | else: # parts that differ
37 | reg_loss += 0.5 * l2_beta * params.norm(2)**2
38 | loss += reg_loss
39 | # print(f'reg loss {reg_loss}, total loss {loss}')
40 | return loss, outputs
41 |
42 |
43 | def get_loss_scale_for_deepspeed(model):
44 | optimizer = model.optimizer
45 | return optimizer.loss_scale if hasattr(optimizer, "loss_scale") else optimizer.cur_scale
46 |
47 |
48 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
49 | data_loader: Iterable, optimizer: torch.optim.Optimizer,
50 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
51 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None,
52 | start_steps=None, lr_schedule_values=None, wd_schedule_values=None,
53 | num_training_steps_per_epoch=None, update_freq=None,
54 | pretrain_model=None, l2_alpha=0.0, l2_beta=0.0,):
55 | model.train(True)
56 | metric_logger = utils.MetricLogger(delimiter=" ")
57 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
58 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
59 | header = 'Epoch: [{}]'.format(epoch)
60 | print_freq = 10
61 |
62 | if loss_scaler is None:
63 | model.zero_grad()
64 | model.micro_steps = 0
65 | else:
66 | optimizer.zero_grad()
67 |
68 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
69 | step = data_iter_step // update_freq
70 | if step >= num_training_steps_per_epoch:
71 | continue
72 | it = start_steps + step # global training iteration
73 | # Update LR & WD for the first acc
74 |         if (lr_schedule_values is not None or wd_schedule_values is not None) and data_iter_step % update_freq == 0:
75 | for i, param_group in enumerate(optimizer.param_groups):
76 | if lr_schedule_values is not None:
77 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
78 | if wd_schedule_values is not None and param_group["weight_decay"] > 0:
79 | param_group["weight_decay"] = wd_schedule_values[it]
80 |
81 | samples = samples.to(device, non_blocking=True)
82 | targets = targets.to(device, non_blocking=True)
83 |
84 | if mixup_fn is not None:
85 | samples, targets = mixup_fn(samples, targets)
86 |
87 | if loss_scaler is None:
88 | samples = samples.half()
89 | loss, output = train_class_batch(
90 | model, samples, targets, criterion,
91 | pretrain_model=pretrain_model, l2_alpha=l2_alpha, l2_beta=l2_beta,)
92 | else:
93 | with torch.cuda.amp.autocast():
94 | loss, output = train_class_batch(
95 | model, samples, targets, criterion,
96 | pretrain_model=pretrain_model, l2_alpha=l2_alpha, l2_beta=l2_beta,)
97 |
98 | loss_value = loss.item()
99 |
100 | if not math.isfinite(loss_value):
101 | print("Loss is {}, stopping training".format(loss_value))
102 | sys.exit(1)
103 |
104 | if loss_scaler is None:
105 | loss /= update_freq
106 | model.backward(loss)
107 | model.step()
108 |
109 | if (data_iter_step + 1) % update_freq == 0:
110 | # model.zero_grad()
111 | # Deepspeed will call step() & model.zero_grad() automatic
112 | if model_ema is not None:
113 | model_ema.update(model)
114 | grad_norm = None
115 | loss_scale_value = get_loss_scale_for_deepspeed(model)
116 | else:
117 | # this attribute is added by timm on one optimizer (adahessian)
118 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
119 | loss /= update_freq
120 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
121 | parameters=model.parameters(), create_graph=is_second_order,
122 | update_grad=(data_iter_step + 1) % update_freq == 0)
123 | if (data_iter_step + 1) % update_freq == 0:
124 | optimizer.zero_grad()
125 | if model_ema is not None:
126 | model_ema.update(model)
127 | loss_scale_value = loss_scaler.state_dict()["scale"]
128 |
129 | torch.cuda.synchronize()
130 |
131 | if mixup_fn is None:
132 | class_acc = (output.max(-1)[-1] == targets).float().mean()
133 | else:
134 | class_acc = None
135 | metric_logger.update(loss=loss_value)
136 | metric_logger.update(class_acc=class_acc)
137 | metric_logger.update(loss_scale=loss_scale_value)
138 | min_lr = 10.
139 | max_lr = 0.
140 | for group in optimizer.param_groups:
141 | min_lr = min(min_lr, group["lr"])
142 | max_lr = max(max_lr, group["lr"])
143 |
144 | metric_logger.update(lr=max_lr)
145 | metric_logger.update(min_lr=min_lr)
146 | weight_decay_value = None
147 | for group in optimizer.param_groups:
148 | if group["weight_decay"] > 0:
149 | weight_decay_value = group["weight_decay"]
150 | metric_logger.update(weight_decay=weight_decay_value)
151 | metric_logger.update(grad_norm=grad_norm)
152 |
153 | if log_writer is not None:
154 | log_writer.update(loss=loss_value, head="loss")
155 | log_writer.update(class_acc=class_acc, head="loss")
156 | log_writer.update(loss_scale=loss_scale_value, head="opt")
157 | log_writer.update(lr=max_lr, head="opt")
158 | log_writer.update(min_lr=min_lr, head="opt")
159 | log_writer.update(weight_decay=weight_decay_value, head="opt")
160 | log_writer.update(grad_norm=grad_norm, head="opt")
161 |
162 | log_writer.set_step()
163 |
164 | # gather the stats from all processes
165 | metric_logger.synchronize_between_processes()
166 | print("Averaged stats:", metric_logger)
167 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
168 |
169 |
170 | @torch.no_grad()
171 | def evaluate(data_loader, model, device):
172 | criterion = torch.nn.CrossEntropyLoss()
173 |
174 | metric_logger = utils.MetricLogger(delimiter=" ")
175 | header = 'Test:'
176 |
177 | # switch to evaluation mode
178 | model.eval()
179 |
180 | for batch in metric_logger.log_every(data_loader, 10, header):
181 | images = batch[0]
182 | target = batch[-1]
183 | images = images.to(device, non_blocking=True)
184 | target = target.to(device, non_blocking=True)
185 |
186 | # compute output
187 | with torch.cuda.amp.autocast():
188 | output = model(images)
189 | loss = criterion(output, target)
190 |
191 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
192 |
193 | batch_size = images.shape[0]
194 | metric_logger.update(loss=loss.item())
195 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
196 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
197 | # gather the stats from all processes
198 | metric_logger.synchronize_between_processes()
199 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
200 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
201 |
202 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
203 |
--------------------------------------------------------------------------------
/MOODv1/README.md:
--------------------------------------------------------------------------------
1 | # MOODv1
2 |
3 |
4 | • 🤗 Model
5 | • 🐱 Code
6 | • 📃 MOODv1
7 | • 📃 MOODv2
8 |
9 | ## Abstract
10 | The core of out-of-distribution (OOD) detection is to learn the in-distribution (ID) representation, which is distinguishable from OOD samples. Previous work applied recognition-based methods to learn the ID features, which tend to learn shortcuts instead of comprehensive representations. In this work, we find surprisingly that simply using reconstruction-based methods could boost the performance of OOD detection significantly. We deeply explore the main contributors of OOD detection and find that reconstruction-based pretext tasks have the potential to provide a generally applicable and efficacious prior, which benefits the model in learning intrinsic data distributions of the ID dataset. Specifically, we take Masked Image Modeling as a pretext task for our OOD detection framework (MOOD). Without bells and whistles, MOOD outperforms previous SOTA of one-class OOD detection by 5.7%, multi-class OOD detection by 3.0%, and near-distribution OOD detection by 2.1%. It even defeats the 10-shot-per-class outlier exposure OOD detection, although we do not include any OOD samples for our detection.
11 |
12 | ## Setup
13 | Follow the official [BEiT](https://github.com/microsoft/unilm/tree/master/beit) repository for setup.
14 |
15 | ## Datasets
16 | We suggest organizing the datasets as follows:
17 | ```bash
18 | - MOOD
19 | - data
20 | - cifar10
21 | - cifar-10-batches-py
22 | - cifar100
23 | - cifar-100-python
24 | - imagenet30
25 | - test
26 | - train
27 | - val
28 | - imagenet1k
29 | - test
30 | - train
31 | - val
32 | - $OOD_DATASET
33 | - images
34 | ...
35 | ```
36 | For example, to train on CIFAR-10, set `--data_path ./data/cifar10 --data_set cifar10`.
37 |
38 | We provide `datasets/imagenet30.py` to create soft links for `imagenet30`.
39 |
40 | ## Pretrained models
41 |
42 | Follow [BEiT](https://github.com/microsoft/unilm/tree/master/beit) to pre-train the model, or directly use the officially released weights pretrained on ImageNet-22k. The models were pretrained at 224x224 resolution.
43 | - `BEiT-base`: #layer=12; hidden=768; FFN factor=4x; #head=12; patch=16x16 (#parameters: 86M)
44 | - `BEiT-large`: #layer=24; hidden=1024; FFN factor=4x; #head=16; patch=16x16 (#parameters: 304M)
45 |
46 | Download checkpoints that are **self-supervised pretrained and then intermediate fine-tuned** on ImageNet-22k (recommended):
47 | - BEiT-base: [beit_base_patch16_224_pt22k_ft22k](https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22k.pth)
48 | - BEiT-large: [beit_large_patch16_224_pt22k_ft22k](https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth)
49 |
50 | Download checkpoints that are **self-supervised pretrained** on ImageNet-22k:
51 | - BEiT-base: [beit_base_patch16_224_pt22k](https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k.pth)
52 | - BEiT-large: [beit_large_patch16_224_pt22k](https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k.pth)
53 |
54 | ## Fine-tuning on In-Distribution Dataset
55 | ### Multi-Class Fine-tuning
56 | For ViT-large,
57 | ```bash
58 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 run_class_finetuning.py \
59 | --model beit_large_patch16_224 --data_path $ID_DATA_PATH --data_set $ID_DATASET \
60 | --finetune https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k.pth \
61 | --batch_size 8 --lr 2e-5 --update_freq 2 \
62 | --warmup_epochs 5 --epochs 100 --layer_decay 0.9 --drop_path 0.4 \
63 | --weight_decay 1e-8 --enable_deepspeed
64 | ```
65 | The hyper-parameters are the same as in the official [BEiT](https://github.com/microsoft/unilm/tree/master/beit).
66 |
67 | ### One-class Fine-tuning
68 | For one-class fine-tuning, assign one class as in-distribution by adding the flag `--class_idx $CLASS_IDX`; the remaining classes are treated as out-of-distribution. We support three in-distribution datasets: `cifar10`, `cifar100`, and `imagenet30`. Note that only one-class ImageNet-30 was fine-tuned in the original paper. A sketch of how the one-class subset is formed follows the command below.
69 | For ViT-large,
70 | ```bash
71 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 run_class_finetuning.py \
72 | --model beit_large_patch16_224 --data_path $ID_DATA_PATH --data_set $ID_DATASET \
73 | --finetune https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k.pth \
74 | --batch_size 8 --lr 2e-5 --update_freq 2 \
75 | --warmup_epochs 5 --epochs 100 --layer_decay 0.9 --drop_path 0.4 \
76 | --weight_decay 1e-8 --enable_deepspeed --class_idx $CLASS_IDX
77 | ```
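
Internally, the one-class setting simply restricts the dataset to the chosen class, as in `get_subclass_dataset` from `ood_utils.py`. A minimal sketch (the dataset name in the usage comment is a placeholder):
```python
from torch.utils.data import Subset

def get_one_class_subset(dataset, class_idx):
    """Keep only the samples whose target equals `class_idx`; all other classes act as OOD."""
    indices = [i for i, tgt in enumerate(dataset.targets) if tgt == class_idx]
    return Subset(dataset, indices)

# e.g. id_train = get_one_class_subset(cifar10_train, class_idx=3)
```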
78 |
79 | ## OOD detection
80 | ### Multi-Class OOD Detection
81 | For OOD detection metrics computed from **features**, we support `['mahalanobis', 'cos', 'projection', 'gauss', 'kmeans', 'euclidean', 'minkowski', 'chebyshev']`; run the following command
82 | ```bash
83 | python eval_with_features.py --ckpt $CKPT_PATH --data_set $ID_DATASET --ood_dataset $OOD_DATASET --ood_data_path $OOD_DATA_PATH --metric $OOD_METRIC
84 | ```
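
As a reference for what the feature-based metrics compute, below is a minimal NumPy sketch of the single-cluster Mahalanobis score, mirroring the `mahalanobis` function in `ood_utils.py`; the feature arrays here are random placeholders.
```python
import numpy as np

def mahalanobis_score(ftrain, ftest):
    """Squared Mahalanobis distance of each test feature to the ID training features."""
    cov = np.cov(ftrain.T, bias=True)              # (D, D) covariance of ID training features
    prec = np.linalg.pinv(cov)                     # pseudo-inverse precision matrix
    mean = ftrain.mean(axis=0, keepdims=True)      # (1, D) ID mean
    diff = ftest - mean                            # (N, D)
    return np.sum(diff * diff.dot(prec), axis=-1)  # (N,), larger = farther from ID

# toy usage with random placeholder features
ftrain = np.random.randn(1000, 64)
ftest = np.random.randn(5, 64)
print(mahalanobis_score(ftrain, ftest))
```
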
85 | For OOD detection metrics computed from **logits**, we support `['softmax', 'entropy', 'energy', 'gradnorm']`; run the following command
86 | ```bash
87 | python eval_with_logits.py --ckpt $CKPT_PATH --data_set $ID_DATASET --ood_dataset $OOD_DATASET --ood_data_path $OOD_DATA_PATH --metric $OOD_METRIC
88 | ```
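
The logits-based metrics are simple functions of the classifier output. The sketch below shows the usual definitions of the softmax, entropy, and energy scores with placeholder logits; it is not necessarily the exact code in `eval_with_logits.py`.
```python
import numpy as np
from scipy.special import logsumexp, softmax

def softmax_score(logits):
    """Maximum softmax probability; higher means more in-distribution."""
    return softmax(logits, axis=-1).max(axis=-1)

def entropy_score(logits):
    """Predictive entropy; higher means more uncertain, i.e. more likely OOD."""
    p = softmax(logits, axis=-1)
    return -np.sum(p * np.log(p + 1e-12), axis=-1)

def energy_score(logits):
    """Energy score (logsumexp of logits); higher means more in-distribution."""
    return logsumexp(logits, axis=-1)

logits = np.random.randn(4, 1000)  # placeholder logits: 4 images, 1000 classes
print(softmax_score(logits), entropy_score(logits), energy_score(logits))
```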
89 |
90 | ### One-Class OOD Detection
91 | For one-class OOD detection, assign one class as in-distribution by adding the flag `--class_idx $CLASS_IDX`; the remaining classes are treated as out-of-distribution. We support three in-distribution datasets: `cifar10`, `cifar100`, and `imagenet30`.
92 | For OOD detection metrics computed from **features**, we support `['mahalanobis', 'cos', 'projection', 'gauss', 'kmeans', 'euclidean', 'minkowski', 'chebyshev']`; run the following command
93 | ```bash
94 | python eval_with_features.py --ckpt $CKPT_PATH --data_set $ID_DATASET --metric $OOD_METRIC --class_idx $CLASS_IDX
95 | ```
96 | For OOD detection metrics computed from **logits**, we support `['softmax', 'entropy', 'energy', 'gradnorm']`; run the following command
97 | ```bash
98 | python eval_with_logits.py --ckpt $CKPT_PATH --data_set $ID_DATASET --metric $OOD_METRIC --class_idx $CLASS_IDX
99 | ```
100 |
101 | ## Results
102 | ### Multi-class OOD detection
103 | For CIFAR-10,
104 | | CIFAR-10 | SVHN | CIFAR-100 | LSUN | Avg |
105 | |:--------: |:--------: |:---------: |:--------: |:-----: |
106 | | [ckpt](https://drive.google.com/file/d/1b_uWi2bty3tyspxEEM4jtyCh3WR9FpYm/view?usp=share_link), [distances](https://drive.google.com/drive/folders/1MeoEHArSeHc7D35-9vGEt782GKphU2PM?usp=share_link) | 99.8 | 99.4 | 99.9 | 99.7 |
107 |
108 | For CIFAR-100,
109 | | CIFAR-100 | SVHN | CIFAR-10 | LSUN | Avg |
110 | |:---------: |:--------: |:--------: |:--------: |:-----: |
111 | | [ckpt](https://drive.google.com/file/d/1MCPUTnz5DjNmR8gyWGMAl7qW811cH13X/view?usp=share_link), [distances](https://drive.google.com/drive/folders/1CV4kpb3OKiCj9uN9fMj1reG59RPDcF_X?usp=share_link) | 96.5 | 98.3 | 96.3 | 97.0 |
112 |
113 | For ImageNet-30,
114 | | ImageNet30 | Dogs | Places365 | Flowers102 | Pets | Food | Caltech256 | Dtd | Avg |
115 | |:----------: |:--------: |:---------: |:----------: |:--------: |:--------: |:----------: |:--------: |:-----: |
116 | | [ckpt](https://drive.google.com/file/d/1nTOimKRcNHlT_hKNfWJejDkYyDs4xd63/view?usp=share_link), [distances](https://drive.google.com/drive/folders/1CH3TjnohalbUIWgNcKzM3Swzy5BIo--f?usp=share_link) | 99.40 | 98.90 | 100.00 | 99.10 | 96.60 | 99.50 | 98.9 | 98.9 |
117 |
118 | For ImageNet-1k,
119 | | ImageNet1k | iNaturalist | SUN | Places | Textures | Average |
120 | |:----------: |:-----------: |:--------: |:--------: |:--------: |:-------: |
121 | | [ckpt](https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft1k.pth), [distances](https://drive.google.com/drive/folders/1-JT_81-a8mRMc_jikeuVKJYDbLCJz_mr?usp=share_link) | 86.9 | 89.8 | 88.5 | 91.3 | 89.1 |
122 |
123 | ### One-class OOD detection
124 | For CIFAR-10,
125 | | Method | Airplane | Automobile | Bird | Cat | Deer | Dog | Frog | Horse | Ship | Truck |
126 | |:---------: |:--------: |:----------: |:--------: |:--------: |:--------: |:--------: |:--------: |:--------: |:--------: |:--------: |
127 | | ours | 98.6 | 99.3 | 94.3 | 93.2 | 98.1 | 96.5 | 99.3 | 99.0 | 98.8 | 97.8 |
128 |
129 | For CIFAR-100,
130 | | Method | AUROC |
131 | |:---------: |:--------: |
132 | | ours | 96.4 |
133 |
134 | For ImageNet-30,
135 | | Method | AUROC |
136 | |:---------------------: |:--------: |
137 | | ours | 92.0 |
138 |
139 | ## Acknowledgement
140 |
141 | This repository is built using the [beit](https://github.com/microsoft/unilm/tree/master/beit) library and the [SSD](https://github.com/inspire-group/SSD) repository.
142 |
143 | ## Citation
144 | If you find our research helpful, kindly cite
145 | ```
146 | @inproceedings{li2023rethinking,
147 | title={Rethinking Out-of-distribution (OOD) Detection: Masked Image Modeling is All You Need},
148 | author={Li, Jingyao and Chen, Pengguang and He, Zexin and Yu, Shaozuo and Liu, Shu and Jia, Jiaya},
149 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
150 | pages={11578--11589},
151 | year={2023}
152 | }
153 | ```
--------------------------------------------------------------------------------
/MOODv2/README.md:
--------------------------------------------------------------------------------
1 | # Official code for MOODv2: Masked Image Modeling for Out-of-Distribution Detection
2 |
3 |
4 | • 🤗 Model
5 | • 🐱 Code
6 | • 📃 MOODv1
7 | • 📃 MOODv2
8 |
9 |
10 |
11 |
12 |
13 |
14 | ## Abstract
15 | The crux of effective out-of-distribution (OOD) detection lies in acquiring a robust in-distribution (ID) representation, distinct from OOD samples. While previous methods predominantly leaned on recognition-based techniques for this purpose, they often resulted in shortcut learning, lacking comprehensive representations. In our study, we conducted a comprehensive analysis, exploring distinct pretraining tasks and employing various OOD score functions. The results highlight that the feature representations pre-trained through reconstruction yield a notable enhancement and narrow the performance gap among various score functions. This suggests that even simple score functions can rival complex ones when leveraging reconstruction-based pretext tasks. Reconstruction-based pretext tasks adapt well to various score functions and, as such, hold promising potential for further expansion. Our OOD detection framework, MOODv2, employs the masked image modeling pretext task. Without bells and whistles, MOODv2 improves the AUROC by 14.30%, reaching 95.68% on ImageNet, and achieves 99.98% on CIFAR-10.
16 |
17 |
18 | ## Performance
19 |
20 |
21 |
22 |
23 |
24 | ## Datasets
25 | The datasets can be downloaded from the sources listed below.
26 | - [ImageNet](https://www.image-net.org/). The ILSVRC 2012 dataset serves as the in-distribution (ID) dataset. The training subset we use is [this file](datalists/imagenet2012_train_random_200k.txt).
27 | - [OpenImage-O](https://github.com/openimages/dataset/blob/main/READMEV3.md). The OpenImage-O dataset is a subset of the OpenImage-V3 testing set. The filelist is [here](datalists/openimage_o.txt).
28 | - [Texture](https://www.robots.ox.ac.uk/~vgg/data/dtd/). The filelist, which rules out four classes that coincide with ImageNet, is [here](datalists/texture.txt).
29 | - [iNaturalist](https://arxiv.org/pdf/1707.06642.pdf). Follow the instructions in the [link](https://github.com/deeplearning-wisc/large_scale_ood) to prepare the iNaturalist OOD dataset.
30 | - [ImageNet-O](https://github.com/hendrycks/natural-adv-examples). Follow the guide to download the ImageNet-O OOD dataset.
31 |
32 | ```bash
33 | mkdir data
34 | cd data
35 | ln -s /path/to/imagenet imagenet
36 | ln -s /path/to/openimage_o openimage_o
37 | ln -s /path/to/texture texture
38 | ln -s /path/to/inaturalist inaturalist
39 | ln -s /path/to/imagenet_o imagenet_o
40 | cd ..
41 | ```
42 |
43 | ## Environment
44 | Please follow the instructions in [mmpretrain](https://github.com/open-mmlab/mmpretrain) for environment preparation.
45 |
46 | ## Demo
47 | To predict whether an input image is in-distribution or out-of-distribution, we support the following OOD detection methods (a sketch of one of these scores follows the list):
48 | - `MSP`
49 | - `MaxLogit`
50 | - `Energy`
51 | - `Energy+React`
52 | - `ViM`
53 | - `Residual`
54 | - `GradNorm`
55 | - `Mahalanobis`
56 | - `KL-Matching`
57 |
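As a flavour of how these scores are computed, here is a minimal NumPy sketch of `Energy+React` following the common ReAct formulation (clip penultimate features at an ID-derived percentile, then take the energy of the recomputed logits). The arrays `w`, `b`, and the features are random placeholders; this is a sketch, not the exact implementation in `src/demo.py`.
```python
import numpy as np
from scipy.special import logsumexp

def energy_react_score(features, id_train_features, w, b, clip_quantile=0.99):
    """Energy score after ReAct-style clipping of penultimate features.

    features:          (N, D) features of the images to score
    id_train_features: (M, D) ID training features, used only to pick the clip threshold
    w, b:              fc-layer weight (C, D) and bias (C,)
    """
    clip = np.quantile(id_train_features, clip_quantile)       # scalar activation threshold
    logits = np.clip(features, a_min=None, a_max=clip) @ w.T + b
    return logsumexp(logits, axis=-1)                           # higher = more in-distribution

# toy usage with random placeholders
rng = np.random.default_rng(0)
print(energy_react_score(rng.normal(size=(4, 64)), rng.normal(size=(100, 64)),
                         rng.normal(size=(10, 64)), rng.normal(size=10)))
```
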
58 | ### Example Usage 1
59 |
60 | **Step 1: Download the features and logits**
61 | ```bash
62 | git clone https://huggingface.co/JingyaoLi/MOODv2
63 | cd MOODv2
64 | git lfs pull
65 | ```
66 |
67 | **Step 2: Detect your image**
68 | ```bash
69 | python src/demo.py \
70 | --img_path imgs/DTD_cracked_0004.jpg \
71 | --cfg configs/beit-base-p16_224px.py \
72 | --checkpoint pretrain/beitv2-base.pth \
73 | --fc data/fc.pkl \
74 | --id_train_feature data/imagenet_train.pkl \
75 | --id_val_feature data/imagenet_test.pkl \
76 | --methods MSP MaxLogit Energy Energy+React ViM Residual GradNorm Mahalanobis
77 | ```
78 |
79 | For the example OOD image `imgs/DTD_cracked_0004.jpg`, you should get:
80 | ```
81 | MSP evaluation: out-of-distribution
82 | MaxLogit evaluation: out-of-distribution
83 | Energy evaluation: out-of-distribution
84 | Energy+React evaluation: out-of-distribution
85 | ViM evaluation: out-of-distribution
86 | Residual evaluation: out-of-distribution
87 | GradNorm evaluation: out-of-distribution
88 | Mahalanobis evaluation: out-of-distribution
89 | ```
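
One common way to turn a score into the binary decision printed above is to threshold it against the ID validation scores so that a fixed fraction of ID images is kept. The sketch below shows such a rule under the assumption that larger scores mean more in-distribution; it is not necessarily the exact decision rule used in `src/demo.py`.
```python
import numpy as np

def decide(score, id_val_scores, tpr=0.95):
    """Return 'in-distribution' if `score` clears the threshold keeping `tpr` of ID validation images."""
    threshold = np.percentile(id_val_scores, 100 * (1 - tpr))
    return 'in-distribution' if score >= threshold else 'out-of-distribution'

# e.g. decide(energy_of_test_image, energy_scores_of_imagenet_val)
```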
90 |
91 | ### Example Usage 2
92 | If you prefer to extract the features and logits from the vision encoder yourself instead of downloading our preprocessed ones:
93 |
94 | **Step 1: Download the checkpoint of the vision encoder**
95 | ```bash
96 | mkdir pretrain && cd pretrain
97 | wget https://huggingface.co/JingyaoLi/MOODv2/resolve/main/pretrain/beitv2-base.pth
98 | cd ..
99 | ```
100 |
101 | **Step 2: Extract the features and logits from the vision encoder**
102 | ```bash
103 | # ID train features
104 | python src/extract_feature_vit.py $IMAGENET_PATH \
105 | --out_file outputs/imagenet_train.pkl \
106 | --cfg configs/beit-base-p16_224px.py \
107 | --checkpoint pretrain/beitv2-base.pth \
108 | --img_list datalists/imagenet2012_train_random_200k.txt
109 |
110 | # ID test features
111 | python src/extract_feature_vit.py $IMAGENET_PATH \
112 | --out_file outputs/imagenet_test.pkl \
113 | --cfg configs/beit-base-p16_224px.py \
114 | --checkpoint pretrain/beitv2-base.pth \
115 | --img_list datalists/imagenet2012_val_list.txt
116 |
117 | # Logits
118 | python src/extract_feature_vit.py $IMAGENET_PATH \
119 | --cfg configs/beit-base-p16_224px.py \
120 | --checkpoint pretrain/beitv2-base.pth \
121 |     --fc outputs/fc.pkl
122 | ```
123 |
124 | **Step 3: Detect your image**
125 | ```bash
126 | python src/demo.py \
127 | --img_path imgs/DTD_cracked_0004.jpg \
128 | --cfg configs/beit-base-p16_224px.py \
129 | --checkpoint pretrain/beitv2-base.pth \
130 | --fc outputs/fc.pkl \
131 | --id_train_feature outputs/imagenet_train.pkl \
132 | --id_val_feature outputs/imagenet_test.pkl \
133 | --methods MSP MaxLogit Energy Energy+React ViM Residual GradNorm Mahalanobis
134 | ```
135 |
136 | ## OOD Detection Benchmark
137 | **Step 1: Download the checkpoint of the vision encoder**
138 | | Name | Paper | Config | Checkpoint | Train/Test Command |
139 | |:------:|:-------:|:-------:|:-------:|:-------:|
140 | | BEiT | [paper](https://arxiv.org/abs/2106.08254) | [config](configs/beit-base-p16_224px.py) | [ckpt](https://download.openmmlab.com/mmclassification/v0/beit/beit-base_3rdparty_in1k_20221114-c0a4df23.pth) | [README](https://github.com/open-mmlab/mmpretrain/tree/main/configs/beit) |
141 | | BEiTv2 | [paper](https://arxiv.org/abs/2208.06366) | [config](configs/beit-base-p16_224px.py) | [ckpt](https://download.openmmlab.com/mmclassification/v0/beit/beitv2-base_3rdparty_in1k_20221114-73e11905.pth) | [README](https://github.com/open-mmlab/mmpretrain/tree/main/configs/beitv2) |
142 | | ViT | [paper](https://arxiv.org/abs/2010.11929) | [config](configs/vit-base-p16_224px.py) | [ckpt](https://download.openmmlab.com/mmclassification/v0/vit/vit-base-p16_pt-32xb128-mae_in1k_20220623-4c544545.pth) | [README](https://github.com/open-mmlab/mmpretrain/tree/main/configs/vision_transformer) |
143 | | MoCov3 | [paper](https://arxiv.org/abs/2104.02057) | [config](configs/vit-base-p16_224px.py) | [ckpt](https://download.openmmlab.com/mmselfsup/1.x/mocov3/mocov3_vit-base-p16_16xb256-amp-coslr-300e_in1k/vit-base-p16_ft-8xb64-coslr-150e_in1k/vit-base-p16_ft-8xb64-coslr-150e_in1k_20220826-f1e6c442.pth) | [README](https://github.com/open-mmlab/mmpretrain/tree/main/configs/mocov3) |
144 | | DINOv2 | [paper](https://arxiv.org/abs/2304.07193) | [config](configs/vit-base-p14_224px.py) | [ckpt](https://download.openmmlab.com/mmpretrain/v1.0/dinov2/vit-base-p14_dinov2-pre_3rdparty_20230426-ba246503.pth) | [README](https://github.com/open-mmlab/mmpretrain/tree/main/configs/dinov2) |
145 |
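For example, to place the BEiT checkpoint from the table into a local `pretrain/` directory (a minimal sketch; the file name simply follows the URL, and any other checkpoint listed above can be substituted):
```bash
mkdir -p pretrain
wget -P pretrain https://download.openmmlab.com/mmclassification/v0/beit/beit-base_3rdparty_in1k_20221114-c0a4df23.pth
```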
146 |
147 | **Step 2: Extract the features and logits from the vision encoder**
148 |
149 | Step 2.1 Extract features
150 | ```bash
151 | python src/extract_feature_vit.py $DATA_ROOT $OUT_FILE --cfg $CFG --checkpoint $CHECKPOINT --img_list $IMG_LIST
152 | ```
153 | e.g.
154 | ```bash
155 | python src/extract_feature_vit.py data/imagenet outputs/vit_imagenet_val.pkl --cfg $CFG --checkpoint $CHECKPOINT --img_list datalists/imagenet2012_val_list.txt
156 | python src/extract_feature_vit.py data/imagenet outputs/vit_train_200k.pkl --cfg $CFG --checkpoint $CHECKPOINT --img_list datalists/imagenet2012_train_random_200k.txt
157 | python src/extract_feature_vit.py data/openimage_o outputs/vit_openimage_o.pkl --cfg $CFG --checkpoint $CHECKPOINT --img_list datalists/openimage_o.txt
158 | python src/extract_feature_vit.py data/texture outputs/vit_texture.pkl --cfg $CFG --checkpoint $CHECKPOINT --img_list datalists/texture.txt
159 | python src/extract_feature_vit.py data/inaturalist outputs/vit_inaturalist.pkl --cfg $CFG --checkpoint $CHECKPOINT
160 | python src/extract_feature_vit.py data/imagenet_o outputs/vit_imagenet_o.pkl --cfg $CFG --checkpoint $CHECKPOINT
161 | python src/extract_feature_vit.py data/cifar10 outputs/vit_cifar10_train.pkl --cfg $CFG --checkpoint $CHECKPOINT --img_list datalists/cifar10_train.txt
162 | python src/extract_feature_vit.py data/cifar10 outputs/vit_cifar10_test.pkl --cfg $CFG --checkpoint $CHECKPOINT --img_list datalists/cifar10_test.txt
163 | ```
164 |
165 | Step 2.2 Extract weights and bias in fc
166 | ```bash
167 | python src/extract_feature_vit.py $DATA_ROOT $OUT_FILE --cfg $CFG --checkpoint $CHECKPOINT --fc_save_path $FC_SAVE_PATH
168 | ```
169 | e.g.
170 | ```bash
171 | python src/extract_feature_vit.py $DATA_ROOT $OUT_FILE --cfg $CFG --checkpoint $CHECKPOINT --fc_save_path outputs/vit_fc.pkl
172 | ```
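The saved `fc.pkl` stores the classification head's weight and bias; `src/demo.py`, for instance, unpacks it as `w, b = mmengine.load(...)`, and the same file is passed to the benchmark script in Step 3.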
173 |
174 | **Step 3: Evaluation**
175 | ```bash
176 | python src/benchmark.py $FC_SAVE_PATH $ID_DATA $ID_TRAIN_FEATURE $ID_VAL_FEATURE $OOD_FEATURE
177 | ```
178 | e.g.
179 | ```bash
180 | python src/benchmark.py outputs/vit_fc.pkl outputs/vit_train_200k.pkl outputs/vit_imagenet_val.pkl outputs/vit_openimage_o.pkl outputs/vit_texture.pkl outputs/vit_inaturalist.pkl outputs/vit_imagenet_o.pkl
181 | python src/benchmark.py outputs/vit_fc.pkl outputs/vit_cifar10_train.pkl outputs/vit_cifar10_test.pkl outputs/vit_openimage_o.pkl outputs/vit_texture.pkl outputs/vit_inaturalist.pkl outputs/vit_imagenet_o.pkl
182 | ```
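The first command benchmarks with ImageNet as the ID dataset, the second with CIFAR-10; both are scored against the same set of OOD feature files.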
183 |
184 | ## Acknowledgement
185 | Part of the code is adapted from the [ViM](https://github.com/haoqiwang/vim) repository.
186 |
187 |
--------------------------------------------------------------------------------
/MOODv2/src/demo.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import argparse
3 | import json
4 | from os.path import basename, splitext
5 | import os
6 | import mmengine
7 | import numpy as np
8 | import pandas as pd
9 | import torch
10 | from numpy.linalg import norm, pinv
11 | from scipy.special import logsumexp, softmax
12 | from sklearn import metrics
13 | from sklearn.covariance import EmpiricalCovariance
14 | from sklearn.metrics import pairwise_distances_argmin_min
15 | from tqdm import tqdm
16 | import pickle
17 | from os.path import dirname
18 | import torchvision as tv
19 | from PIL import Image
20 | from mmpretrain.apis import init_model
21 |
22 | def parse_args():
23 | parser = argparse.ArgumentParser(description='Detect an image')
24 | parser.add_argument(
25 | '--cfg', help='Path to config',
26 | default='/dataset/jingyaoli/AD/MOOD_/MOODv2/configs/beit-base-p16_224px.py')
27 | parser.add_argument('--ood_feature',
28 | default=None, help='Path to ood feature file')
29 | parser.add_argument(
30 | '--checkpoint', help='Path to checkpoint',
31 | default='/dataset/jingyaoli/AD/MOODv2/pretrain/beit-base_3rdparty_in1k_20221114-c0a4df23.pth',)
32 | parser.add_argument('--img_path', help='Path to image',
33 | default='/dataset/jingyaoli/AD/MOOD_/MOODv2/imgs/DTD_cracked_0004.jpg')
34 | parser.add_argument('--fc',
35 | default='/dataset/jingyaoli/AD/MOODv2/outputs/beit-224px/fc.pkl', help='Path to fc path')
36 | parser.add_argument('--id_data', default='imagenet', help='id data name')
37 | parser.add_argument('--id_train_feature',
38 | default='/dataset/jingyaoli/AD/MOODv2/outputs/beit-224px/imagenet_train.pkl', help='Path to data')
39 | parser.add_argument('--id_val_feature',
40 | default='/dataset/jingyaoli/AD/MOODv2/outputs/beit-224px/imagenet_test.pkl', help='Path to output file')
41 | parser.add_argument('--ood_features',
42 | default=None, nargs='+', help='Path to ood features')
43 | parser.add_argument(
44 | '--methods', nargs='+',
45 | default=['MSP', 'MaxLogit', 'Energy', 'Energy+React', 'ViM', 'Residual', 'GradNorm', 'Mahalanobis', ], # 'KL-Matching'
46 | help='methods')
47 | parser.add_argument(
48 | '--train_label',
49 | default='datalists/imagenet2012_train_random_200k.txt',
50 | help='Path to train labels')
51 | parser.add_argument(
52 |         '--clip_quantile', type=float, default=0.99, help='Clip quantile for React feature clipping')
53 | parser.add_argument(
54 |         '--fpr', type=float, default=95, help='target false positive rate (%)')
55 | return parser.parse_args()
56 |
57 | def evaluate(method, score_id, score_ood, target_fpr):
58 |     threshold = np.percentile(score_id, 100 - target_fpr)  # target_fpr% of ID samples score above this value
59 |     if score_ood >= threshold:
60 | print('\033[94m', method, '\033[0m', 'evaluation:', '\033[92m', 'in-distribution', '\033[0m')
61 | else:
62 | print('\033[94m', method, '\033[0m', 'evaluation:', '\033[91m', 'out-of-distribution', '\033[0m')
63 |
64 | def kl(p, q):
65 | return np.sum(np.where(p != 0, p * np.log(p / q), 0))
66 |
67 | def gradnorm(x, w, b, num_cls):  # GradNorm score: L1 norm of the fc-weight gradient under an all-ones-target loss
68 | fc = torch.nn.Linear(*w.shape[::-1])
69 | fc.weight.data[...] = torch.from_numpy(w)
70 | fc.bias.data[...] = torch.from_numpy(b)
71 | fc.cuda()
72 |
73 | x = torch.from_numpy(x).float().cuda()
74 | logsoftmax = torch.nn.LogSoftmax(dim=-1).cuda()
75 |
76 | confs = []
77 |
78 | for i in tqdm(x, desc='Computing Gradnorm ID/OOD score'):
79 | targets = torch.ones((1, num_cls)).cuda()
80 | fc.zero_grad()
81 | loss = torch.mean(
82 | torch.sum(-targets * logsoftmax(fc(i[None])), dim=-1))
83 | loss.backward()
84 | layer_grad_norm = torch.sum(torch.abs(
85 | fc.weight.grad.data)).cpu().numpy()
86 | confs.append(layer_grad_norm)
87 |
88 | return np.array(confs)
89 |
90 | def extract_image_feature(args):
91 | torch.backends.cudnn.benchmark = True
92 |
93 | print('=> Loading model')
94 | cfg = mmengine.Config.fromfile(args.cfg)
95 | model = init_model(cfg, args.checkpoint, 0).cuda().eval()
96 |
97 | print('=> Loading image')
98 | if hasattr(cfg.model.backbone, 'img_size'):
99 | img_size = cfg.model.backbone.img_size
100 | else:
101 | img_size = 224
102 |
103 | transform = tv.transforms.Compose([
104 | tv.transforms.Resize((img_size, img_size)),
105 | tv.transforms.ToTensor(),
106 | tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
107 | ])
108 |
109 | x = transform(Image.open(args.img_path).convert('RGB')).unsqueeze(0)
110 |
111 | print('=> Extracting feature')
112 | with torch.no_grad():
113 | x = x.cuda()
114 | if cfg.model.backbone.type == 'BEiTPretrainViT':
115 | # (B, L, C) -> (B, C)
116 | feat_batch = model.backbone(
117 | x, mask=None)[0].mean(1)
118 | elif cfg.model.backbone.type == 'SwinTransformer':
119 | # (B, C, H, W) -> (B, C)
120 | feat_batch = model.backbone(x)[0]
121 | B, C, H, W = feat_batch.shape
122 | feat_batch = feat_batch.reshape(B, C, -1).mean(-1)
123 | else:
124 | # (B, C)
125 | feat_batch = model.backbone(x)[0]
126 | assert len(feat_batch.shape) == 2
127 | feature = feat_batch.cpu().numpy()
128 |
129 | print(f'Extracted Feature: {feature.shape}')
130 | return feature
131 |
132 | def main():
133 | args = parse_args()
134 | if args.ood_feature and os.path.exists(args.ood_feature):
135 | feature_ood = mmengine.load(args.ood_feature)
136 | else:
137 | feature_ood = extract_image_feature(args)
138 |
139 | if os.path.exists(args.fc):
140 | w, b = mmengine.load(args.fc)
141 | print(f'{w.shape=}, {b.shape=}')
142 | num_cls = len(b)
143 |
144 | train_labels = np.array([
145 | int(line.rsplit(' ', 1)[-1])
146 | for line in mmengine.list_from_file(args.train_label)
147 | ], dtype=int)
148 |
149 | print(f'image path: {args.img_path}')
150 |
151 | print('=> Loading features')
152 | feature_id_train = mmengine.load(args.id_train_feature).squeeze()
153 | feature_id_val = mmengine.load(args.id_val_feature).squeeze()
154 |
155 | print(f'{feature_id_train.shape=}, {feature_id_val.shape=}')
156 |
157 | if os.path.exists(args.fc):
158 | print('=> Computing logits...')
159 | logit_id_train = feature_id_train @ w.T + b
160 | logit_id_val = feature_id_val @ w.T + b
161 | logit_ood = feature_ood @ w.T + b
162 |
163 | print('=> Computing softmax...')
164 | softmax_id_train = softmax(logit_id_train, axis=-1)
165 | softmax_id_val = softmax(logit_id_val, axis=-1)
166 | softmax_ood = softmax(logit_ood, axis=-1)
167 |
168 | u = -np.matmul(pinv(w), b)
169 |
170 | # ---------------------------------------
171 | method = 'MSP'
172 | if method in args.methods:
173 | score_id = softmax_id_val.max(axis=-1)
174 | score_ood = softmax_ood.max(axis=-1)
175 | result = evaluate(method, score_id, score_ood, args.fpr)
176 |
177 | # ---------------------------------------
178 | method = 'MaxLogit'
179 | if method in args.methods:
180 | score_id = logit_id_val.max(axis=-1)
181 | score_ood = logit_ood.max(axis=-1)
182 | result = evaluate(method, score_id, score_ood, args.fpr)
183 |
184 | # ---------------------------------------
185 | method = 'Energy'
186 | if method in args.methods:
187 | score_id = logsumexp(logit_id_val, axis=-1)
188 | score_ood = logsumexp(logit_ood, axis=-1)
189 | result = evaluate(method, score_id, score_ood, args.fpr)
190 |
191 | # ---------------------------------------
192 | method = 'Energy+React'
193 | if method in args.methods:
194 | clip = np.quantile(feature_id_train, args.clip_quantile)
195 | logit_id_val_clip = np.clip(
196 | feature_id_val, a_min=None, a_max=clip) @ w.T + b
197 | score_id = logsumexp(logit_id_val_clip, axis=-1)
198 |
199 | logit_ood_clip = np.clip(feature_ood, a_min=None, a_max=clip) @ w.T + b
200 | score_ood = logsumexp(logit_ood_clip, axis=-1)
201 | result = evaluate(method, score_id, score_ood, args.fpr)
202 |
203 | # ---------------------------------------
204 |     method = 'ViM'  # Virtual-logit Matching: energy score minus the scaled residual norm outside the principal subspace
205 | if method in args.methods:
206 | if feature_id_val.shape[-1] >= 2048:
207 | DIM = num_cls
208 | elif feature_id_val.shape[-1] >= 768:
209 | DIM = 512
210 | else:
211 | DIM = feature_id_val.shape[-1] // 2
212 |
213 | ec = EmpiricalCovariance(assume_centered=True)
214 | ec.fit(feature_id_train - u)
215 | eig_vals, eigen_vectors = np.linalg.eig(ec.covariance_)
216 | NS = np.ascontiguousarray(
217 | (eigen_vectors.T[np.argsort(eig_vals * -1)[DIM:]]).T)
218 | vlogit_id_train = norm(np.matmul(feature_id_train - u, NS), axis=-1)
219 | alpha = logit_id_train.max(axis=-1).mean() / vlogit_id_train.mean()
220 |
221 | vlogit_id_val = norm(np.matmul(feature_id_val - u, NS), axis=-1) * alpha
222 | energy_id_val = logsumexp(logit_id_val, axis=-1)
223 | score_id = -vlogit_id_val + energy_id_val
224 |
225 | energy_ood = logsumexp(logit_ood, axis=-1)
226 | vlogit_ood = norm(np.matmul(feature_ood - u, NS), axis=-1) * alpha
227 | score_ood = -vlogit_ood + energy_ood
228 | result = evaluate(method, score_id, score_ood, args.fpr)
229 |
230 | # ---------------------------------------
231 | method = 'Residual'
232 | if method in args.methods:
233 | if feature_id_val.shape[-1] >= 2048:
234 | DIM = 1000
235 | elif feature_id_val.shape[-1] >= 768:
236 | DIM = 512
237 | else:
238 | DIM = feature_id_val.shape[-1] // 2
239 | ec = EmpiricalCovariance(assume_centered=True)
240 | ec.fit(feature_id_train - u)
241 | eig_vals, eigen_vectors = np.linalg.eig(ec.covariance_)
242 | NS = np.ascontiguousarray(
243 | (eigen_vectors.T[np.argsort(eig_vals * -1)[DIM:]]).T)
244 |
245 | score_id = -norm(np.matmul(feature_id_val - u, NS), axis=-1)
246 |
247 | score_ood = -norm(np.matmul(feature_ood - u, NS), axis=-1)
248 | result = evaluate(method, score_id, score_ood, args.fpr)
249 |
250 | # ---------------------------------------
251 | method = 'GradNorm'
252 | if method in args.methods:
253 | score_ood = gradnorm(feature_ood, w, b, num_cls)
254 | score_id = gradnorm(feature_id_val, w, b, num_cls)
255 | result = evaluate(method, score_id, score_ood, args.fpr)
256 |
257 | # ---------------------------------------
258 | method = 'Mahalanobis'
259 | if method in args.methods:
260 | train_means = []
261 | train_feat_centered = []
262 | for i in tqdm(range(train_labels.max() + 1), desc='Computing classwise mean feature'):
263 | fs = feature_id_train[train_labels == i]
264 | _m = fs.mean(axis=0)
265 | train_means.append(_m)
266 | train_feat_centered.extend(fs - _m)
267 |
268 | ec = EmpiricalCovariance(assume_centered=True)
269 | ec.fit(np.array(train_feat_centered).astype(np.float64))
270 |
271 | mean = torch.from_numpy(np.array(train_means)).cuda().float()
272 | prec = torch.from_numpy(ec.precision_).cuda().float()
273 |
274 | score_id = -np.array(
275 | [(((f - mean) @ prec) * (f - mean)).sum(axis=-1).min().cpu().item()
276 | for f in tqdm(torch.from_numpy(feature_id_val).cuda().float(), desc='Computing Mahalanobis ID score')])
277 |
278 | score_ood = -np.array([
279 | (((f - mean) @ prec) * (f - mean)).sum(axis=-1).min().cpu().item()
280 | for f in tqdm(torch.from_numpy(feature_ood).cuda().float(), desc='Computing Mahalanobis OOD score')
281 | ])
282 | result = evaluate(method, score_id, score_ood, args.fpr)
283 |
284 | # ---------------------------------------
285 | method = 'KL-Matching'
286 | if method in args.methods:
287 |
288 | pred_labels_train = np.argmax(softmax_id_train, axis=-1)
289 | mean_softmax_train = []
290 | for i in tqdm(range(num_cls), desc='Computing classwise mean softmax'):
291 | mean_softmax = softmax_id_train[pred_labels_train == i]
292 | if mean_softmax.shape[0] == 0:
293 | mean_softmax_train.append(np.zeros((num_cls)))
294 | else:
295 | mean_softmax_train.append(np.mean(mean_softmax, axis=0))
296 |
297 | score_id = -pairwise_distances_argmin_min(
298 | softmax_id_val, np.array(mean_softmax_train), metric=kl)[1]
299 |
300 | score_ood = -pairwise_distances_argmin_min(
301 | softmax_ood, np.array(mean_softmax_train), metric=kl)[1]
302 | result = evaluate(method, score_id, score_ood, args.fpr)
303 |
304 | if __name__ == '__main__':
305 | main()
306 |
--------------------------------------------------------------------------------
/MOODv1/run_beit_pretraining.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
3 | # Github source: https://github.com/microsoft/unilm/tree/master/beit
4 | # Copyright (c) 2021 Microsoft
5 | # Licensed under The MIT License [see LICENSE for details]
6 | # By Hangbo Bao
7 | # Based on timm, DINO and DeiT code bases
8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm
9 | # https://github.com/facebookresearch/deit
10 | # https://github.com/facebookresearch/dino
11 | # --------------------------------------------------------'
12 | import argparse
13 | import datetime
14 | import numpy as np
15 | import time
16 | import torch
17 | import torch.backends.cudnn as cudnn
18 | import json
19 | import os
20 |
21 | from pathlib import Path
22 |
23 | from timm.models import create_model
24 | from optim_factory import create_optimizer
25 |
26 | from datasets import build_beit_pretraining_dataset
27 | from engine_for_pretraining import train_one_epoch
28 | from utils import NativeScalerWithGradNormCount as NativeScaler
29 | import utils
30 | import modeling_pretrain
31 |
32 |
33 | def get_args():
34 | parser = argparse.ArgumentParser('BEiT pre-training script', add_help=False)
35 | parser.add_argument('--pretrain', default=None, type=str)
36 | parser.add_argument('--batch_size', default=64, type=int)
37 | parser.add_argument('--epochs', default=300, type=int)
38 | parser.add_argument('--save_ckpt_freq', default=20, type=int)
39 | parser.add_argument("--discrete_vae_weight_path", type=str)
40 | parser.add_argument("--discrete_vae_type", type=str, default="dall-e")
41 | parser.add_argument('--data_set', default='cifar10', choices=['cifar10', 'cifar100', 'imagenet30'],
42 |                         type=str, help='name of the pre-training dataset')
43 | parser.add_argument('--class_idx', help='None: multi-class, Not None: one-class', default=None, type=int)
44 |
45 | # Model parameters
46 | parser.add_argument('--model', default='beit_base_patch16_224_8k_vocab', type=str, metavar='MODEL',
47 | help='Name of model to train')
48 | parser.add_argument('--rel_pos_bias', action='store_true')
49 | parser.add_argument('--disable_rel_pos_bias', action='store_false', dest='rel_pos_bias')
50 | parser.set_defaults(rel_pos_bias=True)
51 | parser.add_argument('--abs_pos_emb', action='store_true')
52 | parser.set_defaults(abs_pos_emb=False)
53 | parser.add_argument('--layer_scale_init_value', default=0.1, type=float,
54 | help="0.1 for base, 1e-5 for large. set 0 to disable layer scale")
55 |
56 | parser.add_argument('--num_mask_patches', default=75, type=int,
57 | help='number of the visual tokens/patches need be masked')
58 | parser.add_argument('--max_mask_patches_per_block', type=int, default=None)
59 | parser.add_argument('--min_mask_patches_per_block', type=int, default=16)
60 |
61 | parser.add_argument('--input_size', default=224, type=int,
62 | help='images input size for backbone')
63 | parser.add_argument('--second_input_size', default=112, type=int,
64 | help='images input size for discrete vae')
65 |
66 | parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',
67 | help='Drop path rate (default: 0.1)')
68 |
69 | # Optimizer parameters
70 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
71 | help='Optimizer (default: "adamw"')
72 | parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
73 | help='Optimizer Epsilon (default: 1e-8)')
74 | parser.add_argument('--opt_betas', default=[0.9, 0.999], type=float, nargs='+', metavar='BETA',
75 | help='Optimizer Betas (default: 0.9, 0.999, use opt default)')
76 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
77 | help='Clip gradient norm (default: None, no clipping)')
78 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
79 | help='SGD momentum (default: 0.9)')
80 | parser.add_argument('--weight_decay', type=float, default=0.05,
81 | help='weight decay (default: 0.05)')
82 | parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the
83 | weight decay. We use a cosine schedule for WD.
84 | (Set the same value with args.weight_decay to keep weight decay no change)""")
85 |
86 | parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
87 | help='learning rate (default: 5e-4)')
88 | parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR',
89 | help='warmup learning rate (default: 1e-6)')
90 | parser.add_argument('--min_lr', type=float, default=1e-5, metavar='LR',
91 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
92 |
93 | parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
94 | help='epochs to warmup LR, if scheduler supports')
95 | parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',
96 | help='epochs to warmup LR, if scheduler supports')
97 |
98 | # Augmentation parameters
99 | parser.add_argument('--train_interpolation', type=str, default='bicubic',
100 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
101 | parser.add_argument('--second_interpolation', type=str, default='lanczos',
102 | help='Interpolation for discrete vae (random, bilinear, bicubic default: "lanczos")')
103 |
104 | # Dataset parameters
105 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
106 | help='dataset path')
107 | parser.add_argument('--imagenet_default_mean_and_std', default=False, action='store_true')
108 |
109 | parser.add_argument('--output_dir', default='',
110 | help='path where to save, empty for no saving')
111 | parser.add_argument('--log_dir', default=None,
112 | help='path where to tensorboard log')
113 | parser.add_argument('--device', default='cuda',
114 | help='device to use for training / testing')
115 | parser.add_argument('--seed', default=0, type=int)
116 | parser.add_argument('--resume', default='', help='resume from checkpoint')
117 | parser.add_argument('--auto_resume', action='store_true')
118 | parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume')
119 | parser.set_defaults(auto_resume=True)
120 |
121 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
122 | help='start epoch')
123 | parser.add_argument('--num_workers', default=10, type=int)
124 | parser.add_argument('--pin_mem', action='store_true',
125 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
126 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem',
127 | help='')
128 | parser.set_defaults(pin_mem=True)
129 |
130 | # distributed training parameters
131 | parser.add_argument('--world_size', default=1, type=int,
132 | help='number of distributed processes')
133 | parser.add_argument('--local_rank', default=-1, type=int)
134 | parser.add_argument('--dist_on_itp', action='store_true')
135 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
136 |
137 | return parser.parse_args()
138 |
139 |
140 | def get_model(args):
141 | print(f"Creating model: {args.model}")
142 | model = create_model(
143 | args.model,
144 | pretrained=False,
145 | drop_path_rate=args.drop_path,
146 | drop_block_rate=None,
147 | use_shared_rel_pos_bias=args.rel_pos_bias,
148 | use_abs_pos_emb=args.abs_pos_emb,
149 | init_values=args.layer_scale_init_value,
150 | )
151 |
152 | return model
153 |
154 |
155 | def main(args):
156 | utils.init_distributed_mode(args)
157 |
158 | print(args)
159 |
160 | device = torch.device(args.device)
161 |
162 | # fix the seed for reproducibility
163 | seed = args.seed + utils.get_rank()
164 | torch.manual_seed(seed)
165 | np.random.seed(seed)
166 | # random.seed(seed)
167 |
168 | cudnn.benchmark = True
169 |
170 | model = get_model(args)
171 | patch_size = model.patch_embed.patch_size
172 | print("Patch size = %s" % str(patch_size))
173 | args.window_size = (args.input_size // patch_size[0], args.input_size // patch_size[1])
174 | args.patch_size = patch_size
175 |
176 | # get dataset
177 | dataset_train = build_beit_pretraining_dataset(args, args.data_set)
178 |
179 | # prepare discrete vae
180 | d_vae = utils.create_d_vae(
181 | weight_path=args.discrete_vae_weight_path, d_vae_type=args.discrete_vae_type,
182 | device=device, image_size=args.second_input_size)
183 |
184 | if True: # args.distributed:
185 | num_tasks = utils.get_world_size()
186 | global_rank = utils.get_rank()
187 | sampler_rank = global_rank
188 | num_training_steps_per_epoch = len(dataset_train) // args.batch_size // num_tasks
189 |
190 | sampler_train = torch.utils.data.DistributedSampler(
191 | dataset_train, num_replicas=num_tasks, rank=sampler_rank, shuffle=True
192 | )
193 | print("Sampler_train = %s" % str(sampler_train))
194 | else:
195 | sampler_train = torch.utils.data.RandomSampler(dataset_train)
196 |
197 | if global_rank == 0 and args.log_dir is not None:
198 | os.makedirs(args.log_dir, exist_ok=True)
199 | log_writer = utils.TensorboardLogger(log_dir=args.log_dir)
200 | else:
201 | log_writer = None
202 |
203 | data_loader_train = torch.utils.data.DataLoader(
204 | dataset_train, sampler=sampler_train,
205 | batch_size=args.batch_size,
206 | num_workers=args.num_workers,
207 | pin_memory=args.pin_mem,
208 | drop_last=True,
209 | )
210 |
211 |     print('=> Loading checkpoint {}'.format(args.pretrain))
212 | if args.pretrain:
213 | from utils import load_state_dict
214 | state_dict = torch.load(args.pretrain, map_location='cpu')
215 | load_state_dict(model, state_dict)
216 |
217 | model.to(device)
218 | model_without_ddp = model
219 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
220 |
221 | print("Model = %s" % str(model_without_ddp))
222 | print('number of params:', n_parameters)
223 |
224 | total_batch_size = args.batch_size * utils.get_world_size()
225 | print("LR = %.8f" % args.lr)
226 | print("Batch size = %d" % total_batch_size)
227 | print("Number of training steps = %d" % num_training_steps_per_epoch)
228 | print("Number of training examples per epoch = %d" % (total_batch_size * num_training_steps_per_epoch))
229 |
230 | if args.distributed:
231 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
232 | model_without_ddp = model.module
233 |
234 | optimizer = create_optimizer(
235 | args, model_without_ddp)
236 | loss_scaler = NativeScaler()
237 |
238 | print("Use step level LR & WD scheduler!")
239 | lr_schedule_values = utils.cosine_scheduler(
240 | args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch,
241 | warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,
242 | )
243 | if args.weight_decay_end is None:
244 | args.weight_decay_end = args.weight_decay
245 | wd_schedule_values = utils.cosine_scheduler(
246 | args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch)
247 | print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values)))
248 |
249 | utils.auto_load_model(
250 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
251 |
252 | print(f"Start training for {args.epochs} epochs")
253 | start_time = time.time()
254 | for epoch in range(args.start_epoch, args.epochs):
255 | if args.distributed:
256 | data_loader_train.sampler.set_epoch(epoch)
257 | if log_writer is not None:
258 | log_writer.set_step(epoch * num_training_steps_per_epoch)
259 | train_stats = train_one_epoch(
260 | model, d_vae, data_loader_train,
261 | optimizer, device, epoch, loss_scaler,
262 | args.clip_grad, log_writer=log_writer,
263 | start_steps=epoch * num_training_steps_per_epoch,
264 | lr_schedule_values=lr_schedule_values,
265 | wd_schedule_values=wd_schedule_values,
266 | )
267 | if args.output_dir:
268 | if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:
269 | utils.save_model(
270 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
271 | loss_scaler=loss_scaler, epoch=epoch)
272 |
273 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
274 | 'epoch': epoch, 'n_parameters': n_parameters}
275 |
276 | if args.output_dir and utils.is_main_process():
277 | if log_writer is not None:
278 | log_writer.flush()
279 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
280 | f.write(json.dumps(log_stats) + "\n")
281 |
282 | total_time = time.time() - start_time
283 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
284 | print('Training time {}'.format(total_time_str))
285 |
286 |
287 | if __name__ == '__main__':
288 | args = get_args()
289 |     if args.output_dir:
290 |         Path(args.output_dir).mkdir(parents=True, exist_ok=True)
291 | main(args)
292 |
--------------------------------------------------------------------------------
/MOODv1/datasets.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
3 | # Github source: https://github.com/microsoft/unilm/tree/master/beit
4 | # Copyright (c) 2021 Microsoft
5 | # Licensed under The MIT License [see LICENSE for details]
6 | # By Hangbo Bao
7 | # Based on timm, DINO and DeiT code bases
8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm
9 | # https://github.com/facebookresearch/deit/
10 | # https://github.com/facebookresearch/dino
11 | # --------------------------------------------------------'
12 | import os
13 | import torch
14 |
15 | from torchvision import datasets, transforms
16 | from dataset_folder import ImageFolder, SegmentationDataset
17 | import blobfile as bf
18 | from timm.data.constants import \
19 | IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
20 | from transforms import RandomResizedCropAndInterpolationWithTwoPic
21 | from timm.data import create_transform
22 |
23 |
24 | from typing import TypeVar, Generic
25 | from dall_e.utils import map_pixels
26 | from masking_generator import MaskingGenerator
27 | import ood_utils
28 | from PIL import Image
29 | import numpy as np
30 |
31 | T_co = TypeVar('T_co', covariant=True)
32 |
33 | class Dataset(Generic[T_co]):
34 | def __getitem__(self, index):
35 | raise NotImplementedError
36 |
37 | class DataAugmentationForBEiT(object):
38 | def __init__(self, args):
39 | imagenet_default_mean_and_std = args.imagenet_default_mean_and_std
40 | mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN
41 | std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD
42 |
43 | self.common_transform = transforms.Compose([
44 | transforms.ColorJitter(0.4, 0.4, 0.4),
45 | transforms.RandomHorizontalFlip(p=0.5),
46 | RandomResizedCropAndInterpolationWithTwoPic(
47 | size=args.input_size, second_size=args.second_input_size,
48 | interpolation=args.train_interpolation, second_interpolation=args.second_interpolation,
49 | ),
50 | ])
51 |
52 | self.patch_transform = transforms.Compose([
53 | transforms.ToTensor(),
54 | transforms.Normalize(
55 | mean=torch.tensor(mean),
56 | std=torch.tensor(std))
57 | ])
58 |
59 | if args.discrete_vae_type == "dall-e":
60 | self.visual_token_transform = transforms.Compose([
61 | transforms.ToTensor(),
62 | map_pixels,
63 | ])
64 | elif args.discrete_vae_type == "customized":
65 | self.visual_token_transform = transforms.Compose([
66 | transforms.ToTensor(),
67 | transforms.Normalize(
68 | mean=IMAGENET_INCEPTION_MEAN,
69 | std=IMAGENET_INCEPTION_STD,
70 | ),
71 | ])
72 | else:
73 | raise NotImplementedError()
74 |
75 | self.masked_position_generator = MaskingGenerator(
76 | args.window_size, num_masking_patches=args.num_mask_patches,
77 | max_num_patches=args.max_mask_patches_per_block,
78 | min_num_patches=args.min_mask_patches_per_block,
79 | )
80 |
81 | def __call__(self, image):
82 | for_patches, for_visual_tokens = self.common_transform(image)
83 | return \
84 | self.patch_transform(for_patches), self.visual_token_transform(for_visual_tokens), \
85 | self.masked_position_generator()
86 |
87 | def __repr__(self):
88 | repr = "(DataAugmentationForBEiT,\n"
89 | repr += " common_transform = %s,\n" % str(self.common_transform)
90 | repr += " patch_transform = %s,\n" % str(self.patch_transform)
91 | repr += " visual_tokens_transform = %s,\n" % str(self.visual_token_transform)
92 | repr += " Masked position generator = %s,\n" % str(self.masked_position_generator)
93 | repr += ")"
94 | return repr
95 |
96 |
97 | def build_beit_pretraining_dataset(args, data_set=None):
98 | '''train set for beit'''
99 |     if data_set is None:
100 | data_set = args.data_set
101 | transform = DataAugmentationForBEiT(args)
102 |
103 | # print("Data Aug = %s" % str(transform))
104 | data_path = args.data_path
105 | if data_set == 'cifar100':
106 | dataset = datasets.CIFAR100(data_path, train=True, transform=transform)
107 | elif data_set == 'cifar10':
108 | dataset = datasets.CIFAR10(data_path, train=True, transform=transform)
109 | elif data_set == 'imagenet30':
110 | dataset = ImageFolder(os.path.join(data_path, 'train'), transform=transform)
111 | else:
112 | dataset = ImageFolder(data_path, transform=transform)
113 | return dataset
114 |
115 |
116 | class Subset(Dataset[T_co]):
117 | def __init__(self, dataset, indices):
118 | self.dataset = dataset
119 | self.indices = np.array(indices)
120 | self.data = np.array(dataset.data)[self.indices, :]
121 | self.targets = np.array(dataset.targets)[self.indices]
122 |
123 | def __getitem__(self, idx):
124 | return self.dataset[self.indices[idx]]
125 |
126 | def __len__(self):
127 | return len(self.indices)
128 |
129 |
130 | class Mixup(Dataset[T_co]):
131 | def __init__(self, dataset, transform, alpha=0.2):
132 | self.dataset = dataset
133 | self.targets = dataset.targets
134 | self.data = dataset.data
135 | self.alpha = alpha
136 | self.baselenth = len(self.data)
137 | self.transform = transform
138 |
139 |
140 | def __getitem__(self, idx):
141 | if idx < self.baselenth:
142 | img, target = self.data[idx], self.targets[idx]
143 | img = Image.fromarray(img)
144 | img = self.transform(img)
145 | return img, target
146 | else:
147 | img, target = self.data[idx-self.baselenth], self.targets[idx-self.baselenth]
148 | img = Image.fromarray(img)
149 | img = self.transform(img)
150 |
151 | lam = np.random.beta(self.alpha, self.alpha)
152 | mix_index = np.random.randint(0, self.baselenth-1)
153 | mix_img, mix_target = self.data[mix_index], self.targets[mix_index]
154 | mix_img = Image.fromarray(mix_img)
155 | mix_img = self.transform(mix_img)
156 |
157 | img = lam * img + (1 - lam) * mix_img
158 | target = int(lam * target + (1 - lam) * mix_target)
159 | return img, target
160 |
161 | def __len__(self):
162 | return self.baselenth*2
163 |
164 |
165 | class Rotation(Dataset[T_co]):
166 | def __init__(self, dataset, transform):
167 | self.dataset = dataset
168 | self.targets = dataset.targets
169 | self.data = dataset.data
170 | self.baselenth = len(self.data)
171 | self.transform = transform
172 |
173 | def __getitem__(self, idx):
174 | rot_angle = idx // self.baselenth # range: 0-3
175 | image_idx = idx % self.baselenth # range: 0-baselenth
176 | img, target = self.data[image_idx], self.targets[image_idx]
177 | img = np.rot90(img, rot_angle)
178 | img = Image.fromarray(img)
179 | img = self.transform(img)
180 | return img, target
181 |
182 | def __len__(self):
183 | return self.baselenth*4
184 |
185 |
186 | class Transform(Dataset[T_co]):
187 | def __init__(self, dataset, transform):
188 | self.dataset = dataset
189 | self.data = dataset.data
190 | self.targets = dataset.targets
191 | self.transform = transform
192 |
193 | def __getitem__(self, idx):
194 | img, target = self.data[idx], self.targets[idx]
195 | img = Image.fromarray(img)
196 | img = self.transform(img)
197 | return img, target
198 |
199 | def __len__(self):
200 | return len(self.dataset)
201 |
202 |
203 | def build_dataset(is_train, args, data_set=None, ood=False, is_trans=True, ood_data_path=None):
204 |     if data_set is None:
205 | data_set = args.data_set
206 |
207 | if not is_trans:
208 | # normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
209 | transform = transforms.Compose([
210 | transforms.Resize(256),
211 | transforms.CenterCrop(224),
212 | transforms.ToTensor(),
213 | # normalize,
214 | ])
215 | else:
216 | transform = build_transform(is_train, args)
217 |
218 | if ood_data_path is None:
219 | data_path = args.data_path
220 | else:
221 | data_path = ood_data_path
222 |
223 | if data_set == 'cifar100':
224 | dataset = datasets.CIFAR100(data_path, train=is_train, transform=transform, download=False)
225 | nb_classes = 100
226 | elif data_set == 'cifar10':
227 | dataset = datasets.CIFAR10(data_path, train=is_train, transform=transform, download=False)
228 | nb_classes = 10
229 | elif data_set == 'svhn':
230 | dataset = datasets.SVHN(data_path, split='train' if is_train else 'test', transform=transform, download=False)
231 | nb_classes = 10
232 | elif data_set == 'imagenet30':
233 | split = 'train' if is_train else 'test'
234 | nb_classes = 30
235 | dataset = ImageFolder(os.path.join(data_path, split), transform=transform)
236 | elif data_set == 'caltech256':
237 | split = 'train' if is_train else 'test'
238 | nb_classes = 256
239 | dataset = ImageFolder(os.path.join(data_path, split), transform=transform)
240 | elif data_set == 'imagenet1k':
241 | root = os.path.join(data_path, 'train' if is_train else 'val')
242 | dataset = ImageFolder(root, transform=transform)
243 | nb_classes = 1000
244 | elif data_set == "image_folder":
245 | root = data_path if is_train else args.eval_data_path
246 | dataset = ImageFolder(root, transform=transform)
247 | nb_classes = args.nb_classes
248 | assert len(dataset.class_to_idx) == nb_classes
249 | else:
250 | nb_classes = None
251 | dataset = ImageFolder(data_path, transform=transform)
252 |
253 | if isinstance(args.class_idx, int) and not ood:
254 | print('Using one class idx (class idx:{})'.format(args.class_idx))
255 | cls_list = get_superclass_list(data_set)
256 | dataset = get_subclass_dataset(dataset, classes=cls_list[args.class_idx])
257 |
258 | return dataset, nb_classes
259 |
260 |
261 | def _list_image_files_recursively(data_dir):
262 | results = []
263 | for entry in sorted(bf.listdir(data_dir)):
264 | full_path = bf.join(data_dir, entry)
265 | ext = entry.split(".")[-1]
266 | if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]:
267 | results.append(full_path)
268 | elif bf.isdir(full_path):
269 | results.extend(_list_image_files_recursively(full_path))
270 | return results
271 |
272 |
273 |
274 | def get_superclass_list(dataset):
275 | CIFAR10_SUPERCLASS = list(range(10)) # one class
276 | IMAGENET_SUPERCLASS = list(range(30)) # one class
277 | CIFAR100_SUPERCLASS = [
278 | [4, 31, 55, 72, 95],
279 | [1, 33, 67, 73, 91],
280 | [54, 62, 70, 82, 92],
281 | [9, 10, 16, 29, 61],
282 | [0, 51, 53, 57, 83],
283 | [22, 25, 40, 86, 87],
284 | [5, 20, 26, 84, 94],
285 | [6, 7, 14, 18, 24],
286 | [3, 42, 43, 88, 97],
287 | [12, 17, 38, 68, 76],
288 | [23, 34, 49, 60, 71],
289 | [15, 19, 21, 32, 39],
290 | [35, 63, 64, 66, 75],
291 | [27, 45, 77, 79, 99],
292 | [2, 11, 36, 46, 98],
293 | [28, 30, 44, 78, 93],
294 | [37, 50, 65, 74, 80],
295 | [47, 52, 56, 59, 96],
296 | [8, 13, 48, 58, 90],
297 | [41, 69, 81, 85, 89],
298 | ]
299 | if dataset == 'cifar10':
300 | return CIFAR10_SUPERCLASS
301 | elif dataset == 'cifar100':
302 | return CIFAR100_SUPERCLASS
303 | elif dataset.lower() == 'imagenet' or dataset.lower() == 'imagenet30':
304 | return IMAGENET_SUPERCLASS
305 | else:
306 | raise NotImplementedError()
307 |
308 |
309 | def get_subclass_dataset(dataset, classes):
310 | if not isinstance(classes, list):
311 | classes = [classes]
312 | indices = []
313 | for idx, tgt in enumerate(dataset.targets):
314 | if tgt in classes:
315 | indices.append(idx)
316 | dataset = Subset(dataset, indices)
317 | return dataset
318 |
319 |
320 | def build_transform(is_train, args):
321 | resize_im = args.input_size > 32
322 | imagenet_default_mean_and_std = args.imagenet_default_mean_and_std
323 | mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN
324 | std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD
325 |
326 | if is_train:
327 | # this should always dispatch to transforms_imagenet_train
328 | transform = create_transform(
329 | input_size=args.input_size,
330 | is_training=True,
331 | color_jitter=args.color_jitter,
332 | auto_augment=args.aa,
333 | interpolation=args.train_interpolation,
334 | re_prob=args.reprob,
335 | re_mode=args.remode,
336 | re_count=args.recount,
337 | mean=mean,
338 | std=std,
339 | )
340 | if not resize_im:
341 | # replace RandomResizedCropAndInterpolation with
342 | # RandomCrop
343 | transform.transforms[0] = transforms.RandomCrop(
344 | args.input_size, padding=4)
345 | return transform
346 |
347 | t = []
348 | if resize_im:
349 | if args.crop_pct is None:
350 | if args.input_size < 384:
351 | args.crop_pct = 224 / 256
352 | else:
353 | args.crop_pct = 1.0
354 | size = int(args.input_size / args.crop_pct)
355 | t.append(
356 | transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 224 images
357 | )
358 | t.append(transforms.CenterCrop(args.input_size))
359 |
360 | t.append(transforms.ToTensor())
361 | t.append(transforms.Normalize(mean, std))
362 | return transforms.Compose(t)
363 |
364 |
365 |
--------------------------------------------------------------------------------
/MOODv1/eval_with_features.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from __future__ import absolute_import
3 |
4 | import os
5 | from pydoc import classname
6 | import numpy as np
7 | import time
8 | import logging
9 | import argparse
10 | from collections import OrderedDict
11 |
12 | import torch
13 | import torch.nn as nn
14 | from torch.utils.data import DataLoader
15 |
16 | from datasets import build_dataset
17 | from timm.models import create_model
18 | import ood_utils
19 | import utils
20 |
21 | import modeling_finetune
22 | # os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
23 |
24 | def get_scores(ftrain, ftest, food, args):
25 | return ood_utils.get_scores_one_cluster(ftrain, ftest, food, args)
26 |
27 |
28 | def get_eval_results(ftrain, ftest, food, args):
29 | if ftrain is not None:
30 | if args.cc or args.avgcc:
31 | for i in range(args.nb_classes):
32 | ftrain[i] /= np.linalg.norm(ftrain[i], axis=-1, keepdims=True) + 1e-10
33 | else:
34 | ftrain /= np.linalg.norm(ftrain, axis=-1, keepdims=True) + 1e-10
35 |
36 | if ftest is not None:
37 | ftest /= np.linalg.norm(ftest, axis=-1, keepdims=True) + 1e-10
38 |
39 | if food is not None:
40 | food /= np.linalg.norm(food, axis=-1, keepdims=True) + 1e-10
41 |
42 | dtest, dood, labels = get_scores(ftrain, ftest, food, args)
43 |
44 | fpr95 = ood_utils.get_fpr(dtest, dood, labels)
45 | auroc, aupr = ood_utils.get_roc_sklearn(dtest, dood, labels), ood_utils.get_pr_sklearn(dtest, dood, labels)
46 | return fpr95, auroc, aupr
47 |
48 |
49 | def compute(model, ood_set, features_train, features_test, args):
50 | ood_sampler = torch.utils.data.SequentialSampler(ood_set)
51 |
52 | ood_loader = torch.utils.data.DataLoader(ood_set, sampler=ood_sampler,
53 | batch_size=args.batch_size, num_workers=args.num_workers,
54 | pin_memory=args.pin_mem, drop_last=False)
55 |
56 | features_ood = ood_utils.get_features(model, ood_loader, args.ood_dataset, args)
57 | fpr95, auroc, aupr = get_eval_results(features_train, features_test, features_ood, args,)
58 | return fpr95, auroc, aupr
59 |
60 |
61 | def get_args():
62 | parser = argparse.ArgumentParser(description="SSD evaluation")
63 | parser.add_argument('--ood_dataset', help='name of ood dataset', default=None, type=str)
64 | parser.add_argument('--ood_data_path', help='path of ood dataset', default=None, type=str)
65 |
66 | parser.add_argument('--cc', help='class-conditioned distance', default=False, action='store_true')
67 | parser.add_argument('--avgcc', help='average of class-conditioned distance', default=False, action='store_true')
68 | parser.add_argument('--max_num', default=1e10, help='maximum number of test samples', type=int)
69 |
70 | parser.add_argument("--results-dir", type=str, default="./eval_results")
71 | parser.add_argument('--class_idx', help='One-class OOD: the idx of the ID class number. Multi-class OOD: None', default=None, type=int)
72 |
73 | parser.add_argument("--batch_size", type=int, default=8)
74 | parser.add_argument("--ckpt", type=str, default=None, help="checkpoint path", required=True)
75 | parser.add_argument("--metric", type=str, default='mahalanobis', help='OOD detection metric of the features',
76 | choices=['mahalanobis', 'cos', 'projection', 'gauss', 'kmeans', 'euclidean', 'minkowski', 'chebyshev'])
77 |
78 | parser.set_defaults(eval_ood=True)
79 | # Model parameters
80 | parser.add_argument('--model', default='beit_large_patch16_224', type=str, metavar='MODEL',
81 | help='Name of model to train')
82 | parser.add_argument('--rel_pos_bias', action='store_true')
83 | parser.add_argument('--disable_rel_pos_bias', action='store_false', dest='rel_pos_bias')
84 | parser.set_defaults(rel_pos_bias=True)
85 | parser.add_argument('--abs_pos_emb', action='store_true')
86 | parser.set_defaults(abs_pos_emb=False)
87 | parser.add_argument('--layer_scale_init_value', default=0.1, type=float,
88 | help="0.1 for base, 1e-5 for large. set 0 to disable layer scale")
89 |
90 | parser.add_argument('--input_size', default=224, type=int, help='images input size')
91 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
92 | help='Dropout rate (default: 0.)')
93 | parser.add_argument('--attn_drop_rate', type=float, default=0.0, metavar='PCT',
94 | help='Attention dropout rate (default: 0.)')
95 | parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',
96 | help='Drop path rate (default: 0.1)')
97 |
98 | parser.add_argument('--disable_eval_during_finetuning', action='store_true', default=False)
99 |
100 | parser.add_argument('--model_ema', action='store_true', default=False)
101 | parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='')
102 | parser.add_argument('--model_ema_force_cpu', action='store_true', default=False, help='')
103 |
104 | # Augmentation parameters
105 | parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT',
106 | help='Color jitter factor (default: 0.4)')
107 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
108 |                         help='Use AutoAugment policy. "v0" or "original" (default: rand-m9-mstd0.5-inc1)')
109 | parser.add_argument('--smoothing', type=float, default=0.1,
110 | help='Label smoothing (default: 0.1)')
111 | parser.add_argument('--train_interpolation', type=str, default='bicubic',
112 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
113 |
114 | # Evaluation parameters
115 | parser.add_argument('--crop_pct', type=float, default=None)
116 |
117 | # * Random Erase params
118 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
119 | help='Random erase prob (default: 0.25)')
120 | parser.add_argument('--remode', type=str, default='pixel',
121 | help='Random erase mode (default: "pixel")')
122 | parser.add_argument('--recount', type=int, default=1,
123 | help='Random erase count (default: 1)')
124 | parser.add_argument('--resplit', action='store_true', default=False,
125 | help='Do not random erase first (clean) augmentation split')
126 |
127 | # * Mixup params
128 | parser.add_argument('--mixup', type=float, default=0,
129 | help='mixup alpha, mixup enabled if > 0.')
130 | parser.add_argument('--cutmix', type=float, default=0,
131 | help='cutmix alpha, cutmix enabled if > 0.')
132 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
133 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
134 | parser.add_argument('--mixup_prob', type=float, default=1.0,
135 | help='Probability of performing mixup or cutmix when either/both is enabled')
136 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
137 | help='Probability of switching to cutmix when both mixup and cutmix enabled')
138 | parser.add_argument('--mixup_mode', type=str, default='batch',
139 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
140 |
141 | # * Finetuning params
142 | parser.add_argument('--finetune', default='',
143 | help='finetune from checkpoint')
144 | parser.add_argument('--model_key', default='model|module', type=str)
145 | parser.add_argument('--model_prefix', default='', type=str)
146 | parser.add_argument('--init_scale', default=0.001, type=float)
147 | parser.add_argument('--use_mean_pooling', action='store_true')
148 | parser.set_defaults(use_mean_pooling=True)
149 | parser.add_argument('--use_cls', action='store_false', dest='use_mean_pooling')
150 | parser.add_argument('--disable_weight_decay_on_rel_pos_bias', action='store_true', default=False)
151 |
152 | # Dataset parameters
153 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
154 | help='dataset path')
155 | parser.add_argument('--eval_data_path', default=None, type=str,
156 | help='dataset path for evaluation')
157 |
158 | parser.add_argument('--nb_classes', default=0, type=int,
159 | help='number of the classification types')
160 | parser.add_argument('--imagenet_default_mean_and_std', default=False, action='store_true')
161 |
162 |     parser.add_argument('--data_set', default='imagenet1k', choices=['cifar10', 'cifar100', 'imagenet1k', 'imagenet30'],
163 |                         type=str, help='name of the in-distribution dataset')
164 | parser.add_argument('--output_dir', default='',
165 | help='path where to save, empty for no saving')
166 | parser.add_argument('--log_dir', default=None,
167 | help='path where to tensorboard log')
168 | parser.add_argument('--device', default='cuda',
169 | help='device to use for training / testing')
170 | parser.add_argument('--seed', default=0, type=int)
171 | parser.add_argument('--resume', default='',
172 | help='resume from checkpoint')
173 | parser.add_argument('--auto_resume', action='store_true')
174 | parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume')
175 | parser.set_defaults(auto_resume=True)
176 |
177 | parser.add_argument('--save_ckpt', action='store_true')
178 | parser.add_argument('--no_save_ckpt', action='store_false', dest='save_ckpt')
179 | parser.set_defaults(save_ckpt=True)
180 |
181 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
182 | help='start epoch')
183 | parser.add_argument('--eval', action='store_true',
184 | help='Perform evaluation only')
185 | parser.add_argument('--dist_eval', action='store_true', default=False,
186 | help='Enabling distributed evaluation')
187 | parser.add_argument('--num_workers', default=10, type=int)
188 | parser.add_argument('--pin_mem', action='store_true',
189 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
190 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
191 | parser.set_defaults(pin_mem=True)
192 |
193 | args = parser.parse_args()
194 | return args
195 |
196 |
197 | def main():
198 | # check ckpt
199 |     assert os.path.exists(args.ckpt), 'Cannot find {}'.format(args.ckpt)
200 | print('loading from {}'.format(args.ckpt))
201 |
202 | # load checkpoint
203 | global model
204 | if args.ckpt.startswith('https'):
205 | state_dict = torch.hub.load_state_dict_from_url(
206 | args.ckpt, map_location='cpu', check_hash=True)
207 | else:
208 | state_dict = torch.load(args.ckpt, map_location='cpu')
209 |
210 | if 'model' in state_dict.keys():
211 | state_dict = state_dict['model']
212 | elif 'state_dict' in state_dict.keys():
213 | state_dict = state_dict['state_dict']
214 | elif 'module' in state_dict.keys():
215 | state_dict = state_dict['module']
216 |
217 | for k in list(state_dict.keys()):
218 | state_dict[k.replace('module.', '')] = state_dict.pop(k)
219 |
220 | if 'pt22k' in args.ckpt:
221 | model = utils.set_checkpoint_for_finetune(model, args, args.ckpt)
222 | else:
223 | utils.load_state_dict(model, state_dict, prefix=args.model_prefix)
224 |
225 | # dataloaders
226 | train_set, args.nb_classes = build_dataset(is_train=True, data_set=args.data_set, args=args)
227 | test_set, _ = build_dataset(is_train=False, data_set=args.data_set, args=args)
228 |
229 | train_sampler = torch.utils.data.RandomSampler(train_set)
230 | test_sampler = torch.utils.data.SequentialSampler(test_set)
231 |
232 | train_loader = torch.utils.data.DataLoader(train_set, sampler=train_sampler,
233 | batch_size=args.batch_size, num_workers=args.num_workers,
234 | pin_memory=args.pin_mem, drop_last=True,)
235 |
236 | test_loader = torch.utils.data.DataLoader(test_set, sampler=test_sampler,
237 | batch_size=args.batch_size, num_workers=args.num_workers,
238 | pin_memory=args.pin_mem, drop_last=False)
239 |
240 | features_train = ood_utils.get_features(model, train_loader, args.data_set, args, is_train=True)
241 | features_test = ood_utils.get_features(model, test_loader, args.data_set, args)
242 |
243 | if args.class_idx is None:
244 | ood_set, _ = build_dataset(is_train=False, data_set=args.ood_dataset, args=args, ood=True, ood_data_path=args.ood_data_path)
245 | fpr95, auroc, aupr = compute(model, ood_set, features_train, features_test, args)
246 |
247 | ood = "%-15s\t" % (args.ood_dataset)
248 | logger.info("{}\tIn-data: {}\tOOD: {}\tFPR95: {:.2f}\tAUROC: {:.2f}\tAUPR: {:.2f}".format(
249 | args.metric, args.data_set, ood, fpr95*100, auroc*100, aupr*100))
250 |         aurocs = [auroc * 100]  # kept as a list so the summary line below prints uniformly
251 |         print('\t'.join('{:.1f}'.format(a) for a in aurocs))
252 |
253 |
254 | else:
255 | ood_multi_class_set, _ = build_dataset(is_train=False, data_set=args.data_set, args=args, ood=True, ood_data_path=args.ood_data_path)
256 |
257 | fpr95s, aurocs, auprs = [], [], []
258 | for d in range(len(cls_list)):
259 | if d == args.class_idx:
260 | continue
261 | ood_set = ood_utils.get_subclass_dataset(ood_multi_class_set, classes=cls_list[d])
262 |
263 | args.ood_dataset = str(d)
264 | fpr95, auroc, aupr = compute(model, ood_set, features_train, features_test, args)
265 | logger.info("{}\tDataset: {}\tID: {}\tOOD: {}\tFPR95: {:.2f}\tAUROC: {:.2f}\tAUPR: {:.2f}".format(
266 | args.metric, args.data_set, args.class_idx, d, fpr95*100, auroc*100, aupr*100))
267 |
268 | fpr95s.append(fpr95)
269 | aurocs.append(auroc)
270 | auprs.append(aupr)
271 |
272 | fpr95 = np.mean(fpr95s)*100
273 | auroc = np.mean(aurocs)*100
274 | aupr = np.mean(auprs)*100
275 |
276 | results = "{}\tDataset: {}\tID: {}\tOOD: {}\tFPR95: {:.2f}\tAUROC: {:.2f}\tAUPR: {:.2f}".format(
277 | args.metric, args.data_set, args.class_idx, 'all', fpr95, auroc, aupr)
278 | logger.info(results)
279 | return fpr95, auroc, aupr
280 |
281 |
282 | def create_logger(results_dir):
283 | # create logger
284 | logging.basicConfig(level=logging.INFO, format="%(message)s")
285 | logger = logging.getLogger()
286 |
287 | # create results dir
288 | if results_dir is not None:
289 | ood_utils.mkdir(results_dir)
290 | results_file = os.path.join(results_dir, 'eval_results.txt')
291 | logger.addHandler(logging.FileHandler(results_file, "a"))
292 | print('=> Saving to {}'.format(results_file))
293 | return logger
294 |
295 |
296 | if __name__ == "__main__":
297 | args = get_args()
298 | device = "cuda:0"
299 |
300 | torch.manual_seed(args.seed)
301 | torch.cuda.manual_seed(args.seed)
302 | torch.cuda.manual_seed_all(args.seed)
303 |
304 | # create model
305 | _, args.nb_classes = build_dataset(is_train=False, data_set=args.data_set, args=args)
306 | model = create_model(args.model, pretrained=False, num_classes=args.nb_classes,
307 | drop_rate=args.drop, drop_path_rate=args.drop_path, attn_drop_rate=args.attn_drop_rate,
308 | drop_block_rate=None, use_mean_pooling=args.use_mean_pooling, init_scale=args.init_scale,
309 | use_rel_pos_bias=args.rel_pos_bias, use_abs_pos_emb=args.abs_pos_emb,
310 | init_values=args.layer_scale_init_value,).cuda()
311 |
312 | # ood
313 | logger = create_logger(args.results_dir)
314 | if args.class_idx is not None:
315 | cls_list = ood_utils.get_superclass_list(args.data_set)
316 | main()
317 |
318 |
319 |
320 |
321 |
322 |
--------------------------------------------------------------------------------
/MOODv1/dataset_folder.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
3 | # Github source: https://github.com/microsoft/unilm/tree/master/beit
4 | # Copyright (c) 2021 Microsoft
5 | # Licensed under The MIT License [see LICENSE for details]
6 | # By Hangbo Bao
7 | # Modified on torchvision code bases
8 | # https://github.com/pytorch/vision
9 | # --------------------------------------------------------'
10 | from torchvision.datasets.vision import VisionDataset
11 | import torch
12 | from PIL import Image
13 | from torch.utils.data import DataLoader, Dataset
14 | import os
15 | import os.path
16 | import random
17 | import blobfile as bf
18 | from typing import Any, Callable, cast, Dict, List, Optional, Tuple
19 | import numpy as np
20 |
21 | def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
22 | """Checks if a file is an allowed extension.
23 |
24 | Args:
25 | filename (string): path to a file
26 | extensions (tuple of strings): extensions to consider (lowercase)
27 |
28 | Returns:
29 |         bool: True if the filename ends with one of the given extensions
30 | """
31 | return filename.lower().endswith(extensions)
32 |
33 |
34 | def is_image_file(filename: str) -> bool:
35 | """Checks if a file is an allowed image extension.
36 |
37 | Args:
38 | filename (string): path to a file
39 |
40 | Returns:
41 | bool: True if the filename ends with a known image extension
42 | """
43 | return has_file_allowed_extension(filename, IMG_EXTENSIONS)
44 |
45 |
46 | def make_dataset(
47 | directory: str,
48 | class_to_idx: Dict[str, int],
49 | extensions: Optional[Tuple[str, ...]] = None,
50 | is_valid_file: Optional[Callable[[str], bool]] = None,
51 | ) -> List[Tuple[str, int]]:
52 | instances = []
53 | directory = os.path.expanduser(directory)
54 | both_none = extensions is None and is_valid_file is None
55 | both_something = extensions is not None and is_valid_file is not None
56 | if both_none or both_something:
57 | raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
58 | if extensions is not None:
59 | def is_valid_file(x: str) -> bool:
60 | return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))
61 | is_valid_file = cast(Callable[[str], bool], is_valid_file)
62 | for target_class in sorted(class_to_idx.keys()):
63 | class_index = class_to_idx[target_class]
64 | target_dir = os.path.join(directory, target_class)
65 | if not os.path.isdir(target_dir):
66 | continue
67 | for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
68 | for fname in sorted(fnames):
69 | path = os.path.join(root, fname)
70 | if is_valid_file(path):
71 | item = path, class_index
72 | instances.append(item)
73 | return instances
74 |
75 |
76 | class DatasetFolder(VisionDataset):
77 | """A generic data loader where the samples are arranged in this way: ::
78 |
79 | root/class_x/xxx.ext
80 | root/class_x/xxy.ext
81 | root/class_x/xxz.ext
82 |
83 | root/class_y/123.ext
84 | root/class_y/nsdf3.ext
85 | root/class_y/asd932_.ext
86 |
87 | Args:
88 | root (string): Root directory path.
89 | loader (callable): A function to load a sample given its path.
90 | extensions (tuple[string]): A list of allowed extensions.
91 | both extensions and is_valid_file should not be passed.
92 | transform (callable, optional): A function/transform that takes in
93 | a sample and returns a transformed version.
94 | E.g, ``transforms.RandomCrop`` for images.
95 | target_transform (callable, optional): A function/transform that takes
96 | in the target and transforms it.
97 |         is_valid_file (callable, optional): A function that takes the path of a file
98 |             and checks if the file is a valid file (used to filter out corrupt files).
99 |             Both extensions and is_valid_file should not be passed.
100 |
101 | Attributes:
102 | classes (list): List of the class names sorted alphabetically.
103 | class_to_idx (dict): Dict with items (class_name, class_index).
104 | samples (list): List of (sample path, class_index) tuples
105 | targets (list): The class_index value for each image in the dataset
106 | """
107 |
108 | def __init__(
109 | self,
110 | root: str,
111 | loader: Callable[[str], Any],
112 | extensions: Optional[Tuple[str, ...]] = None,
113 | transform: Optional[Callable] = None,
114 | target_transform: Optional[Callable] = None,
115 | is_valid_file: Optional[Callable[[str], bool]] = None,
116 | ) -> None:
117 | super(DatasetFolder, self).__init__(root, transform=transform,
118 | target_transform=target_transform)
119 | classes, class_to_idx = self._find_classes(self.root)
120 | samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)
121 | if len(samples) == 0:
122 | msg = "Found 0 files in subfolders of: {}\n".format(self.root)
123 | if extensions is not None:
124 | msg += "Supported extensions are: {}".format(",".join(extensions))
125 | raise RuntimeError(msg)
126 |
127 | self.loader = loader
128 | self.extensions = extensions
129 |
130 | self.classes = classes
131 | self.class_to_idx = class_to_idx
132 | self.samples = samples
133 | self.targets = [s[1] for s in samples]
134 |
135 | def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]:
136 | """
137 | Finds the class folders in a dataset.
138 |
139 | Args:
140 | dir (string): Root directory path.
141 |
142 | Returns:
143 | tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
144 |
145 | Ensures:
146 | No class is a subdirectory of another.
147 | """
148 | classes = [d.name for d in os.scandir(dir) if d.is_dir()]
149 | classes.sort()
150 | class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
151 | return classes, class_to_idx
152 |
153 | def __getitem__(self, index: int) -> Tuple[Any, Any]:
154 | """
155 | Args:
156 | index (int): Index
157 |
158 | Returns:
159 | tuple: (sample, target) where target is class_index of the target class.
160 | """
161 | while True:
162 | try:
163 | path, target = self.samples[index]
164 | sample = self.loader(path)
165 | break
166 | except Exception as e:
167 | print(e)
168 | index = random.randint(0, len(self.samples) - 1)
169 |
170 | if self.transform is not None:
171 | sample = self.transform(sample)
172 | if self.target_transform is not None:
173 | target = self.target_transform(target)
174 |
175 | return sample, target
176 |
177 | def __len__(self) -> int:
178 | return len(self.samples)
179 |
180 |
181 | IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
182 |
183 |
184 | def pil_loader(path: str) -> Image.Image:
185 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
186 | with open(path, 'rb') as f:
187 | img = Image.open(f)
188 | return img.convert('RGB')
189 |
190 |
191 | # TODO: specify the return type
192 | def accimage_loader(path: str) -> Any:
193 | import accimage
194 | try:
195 | return accimage.Image(path)
196 | except IOError:
197 | # Potentially a decoding problem, fall back to PIL.Image
198 | return pil_loader(path)
199 |
200 |
201 | def default_loader(path: str) -> Any:
202 | from torchvision import get_image_backend
203 | if get_image_backend() == 'accimage':
204 | return accimage_loader(path)
205 | else:
206 | return pil_loader(path)
207 |
208 |
209 | class ImageFolder(DatasetFolder):
210 | """A generic data loader where the images are arranged in this way: ::
211 |
212 | root/dog/xxx.png
213 | root/dog/xxy.png
214 | root/dog/xxz.png
215 |
216 | root/cat/123.png
217 | root/cat/nsdf3.png
218 | root/cat/asd932_.png
219 |
220 | Args:
221 | root (string): Root directory path.
222 |         transform (callable, optional): A function/transform that takes in a PIL image
223 |             and returns a transformed version. E.g., ``transforms.RandomCrop``
224 |         target_transform (callable, optional): A function/transform that takes in the
225 |             target and transforms it.
226 |         loader (callable, optional): A function to load an image given its path.
227 |         is_valid_file (callable, optional): A function that takes the path of an image file
228 |             and checks if the file is a valid file (used to filter out corrupt files)
229 |
230 | Attributes:
231 | classes (list): List of the class names sorted alphabetically.
232 | class_to_idx (dict): Dict with items (class_name, class_index).
233 | imgs (list): List of (image path, class_index) tuples
234 | """
235 |
236 | def __init__(
237 | self,
238 | root: str,
239 | transform: Optional[Callable] = None,
240 | target_transform: Optional[Callable] = None,
241 | loader: Callable[[str], Any] = default_loader,
242 | is_valid_file: Optional[Callable[[str], bool]] = None,
243 | ):
244 | super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None,
245 | transform=transform,
246 | target_transform=target_transform,
247 | is_valid_file=is_valid_file)
248 | self.imgs = self.samples
249 | self.data = self.samples
250 |
251 |
252 | class SegmentationDataset(Dataset):
253 | def __init__(
254 | self,
255 | image_paths,
256 | label_paths=None,
257 | resolution=224,
258 | random_crop=False,
259 | random_flip=True,
260 | ):
261 | super().__init__()
262 | self.resolution = resolution
263 | self.local_images = image_paths
264 | self.local_labels = label_paths
265 | self.random_crop = random_crop
266 | self.random_flip = random_flip
267 |
268 | def __len__(self):
269 | return len(self.local_images)
270 |
271 | def __getitem__(self, idx):
272 | data = {}
273 |
274 | image_path = self.local_images[idx]
275 | with bf.BlobFile(image_path, "rb") as f:
276 | pil_image = Image.open(f)
277 | pil_image.load()
278 | pil_image = pil_image.convert("RGB")
279 |
280 | label_path = None
281 | if self.local_labels is not None:
282 | label_path = self.local_labels[idx]
283 | if label_path is not None:
284 | with bf.BlobFile(label_path, "rb") as f:
285 | pil_label = Image.open(f)
286 | pil_label.load()
287 | pil_label = pil_label.convert("L")
288 |
289 | if self.random_crop:
290 | if label_path is not None:
291 | arr_image, arr_label = random_crop_two_arr(pil_image, pil_label, self.resolution)
292 | else:
293 |                 arr_image = random_crop_arr(pil_image, self.resolution)
294 | else:
295 | arr_image = center_crop_arr(pil_image, self.resolution)
296 | if label_path is not None:
297 | arr_label = center_crop_arr(pil_label, self.resolution)
298 |
299 | def save_an_image(image, save_name):
300 | im = Image.fromarray(image)
301 | im.save(save_name)
302 |
303 | if self.random_flip and random.random() < 0.5:
304 | arr_image = arr_image[:, ::-1, :]
305 | if label_path is not None:
306 | arr_label = arr_label[:, ::-1]
307 |
308 | arr_image = arr_image.astype(np.float32) / 255 # / 127.5 - 1
309 | # np.ascontiguousarray to ensure the memory continuity
310 | image = torch.tensor(np.ascontiguousarray(np.transpose(arr_image, [2, 0, 1])))
311 |
312 | if label_path is not None:
313 | arr_label = arr_label.astype(np.float32) / 255 # / 127.5 - 1
314 | # label = np.ascontiguousarray(np.transpose(data['label'], [2, 0, 1]))
315 | label = torch.tensor(np.ascontiguousarray(np.expand_dims(arr_label, axis=0)))
316 | # (batch_size, 3, image_size, image_size), (batch_size, image_size, image_size)
317 | else:
318 | label = torch.ones((1, self.resolution, self.resolution))
319 |
320 | return image, label
321 |
322 | def center_crop_arr(pil_image, image_size):
323 | # We are not on a new enough PIL to support the `reducing_gap`
324 | # argument, which uses BOX downsampling at powers of two first.
325 | # Thus, we do it by hand to improve downsample quality.
326 | while min(*pil_image.size) >= 2 * image_size:
327 | pil_image = pil_image.resize(
328 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX
329 | )
330 |
331 | scale = image_size / min(*pil_image.size)
332 | pil_image = pil_image.resize(
333 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
334 | )
335 |
336 | arr = np.array(pil_image)
337 | crop_y = (arr.shape[0] - image_size) // 2
338 | crop_x = (arr.shape[1] - image_size) // 2
339 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
340 |
341 |
342 | def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
343 | min_smaller_dim_size = math.ceil(image_size / max_crop_frac)
344 | max_smaller_dim_size = math.ceil(image_size / min_crop_frac)
345 | smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1)
346 |
347 | # We are not on a new enough PIL to support the `reducing_gap`
348 | # argument, which uses BOX downsampling at powers of two first.
349 | # Thus, we do it by hand to improve downsample quality.
350 | while min(*pil_image.size) >= 2 * smaller_dim_size:
351 | pil_image = pil_image.resize(
352 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX
353 | )
354 |
355 | scale = smaller_dim_size / min(*pil_image.size)
356 | pil_image = pil_image.resize(
357 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
358 | )
359 |
360 | arr = np.array(pil_image)
361 | crop_y = random.randrange(arr.shape[0] - image_size + 1)
362 | crop_x = random.randrange(arr.shape[1] - image_size + 1)
363 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
364 |
365 |
366 | def random_crop_two_arr(pil_image, pil_label, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
367 |     assert pil_image.size == pil_label.size  # PIL images expose .size (width, height), not .shape
368 | min_smaller_dim_size = math.ceil(image_size / max_crop_frac)
369 | max_smaller_dim_size = math.ceil(image_size / min_crop_frac)
370 | smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1)
371 |
372 | # We are not on a new enough PIL to support the `reducing_gap`
373 | # argument, which uses BOX downsampling at powers of two first.
374 | # Thus, we do it by hand to improve downsample quality.
375 | while min(*pil_image.size) >= 2 * smaller_dim_size:
376 | pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
377 | pil_label = pil_label.resize(tuple(x // 2 for x in pil_label.size), resample=Image.BOX)
378 |
379 | scale = smaller_dim_size / min(*pil_image.size)
380 | pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
381 | pil_label = pil_label.resize(tuple(round(x * scale) for x in pil_label.size), resample=Image.BICUBIC)
382 |
383 | arr_image = np.array(pil_image)
384 | arr_label = np.array(pil_label)
385 | crop_y = random.randrange(arr_image.shape[0] - image_size + 1)
386 | crop_x = random.randrange(arr_image.shape[1] - image_size + 1)
387 |
388 | image = arr_image[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
389 | label = arr_label[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
390 | return image, label
391 |
--------------------------------------------------------------------------------
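
A minimal usage sketch for the two loaders defined in dataset_folder.py above. The directory and file paths are placeholders for illustration; they are not part of the repository.

from torchvision import transforms
from dataset_folder import ImageFolder, SegmentationDataset

# ImageFolder expects one sub-directory per class, e.g. data/train/dog/xxx.png (hypothetical path)
train_set = ImageFolder(
    root='data/train',
    transform=transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.ToTensor(),
    ]),
)
img, label = train_set[0]          # img: (3, 224, 224) tensor, label: int class index

# SegmentationDataset takes parallel lists of image and mask paths; when label_paths is
# omitted, __getitem__ returns an all-ones dummy mask of shape (1, resolution, resolution).
seg_set = SegmentationDataset(
    image_paths=['data/images/0001.jpg'],   # hypothetical paths
    label_paths=['data/labels/0001.png'],
    resolution=224,
    random_crop=True,
)
image, mask = seg_set[0]           # image: (3, 224, 224), mask: (1, 224, 224)
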
/MOODv1/eval_with_logits.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from __future__ import absolute_import
3 |
4 | import os
5 | from pydoc import classname
6 | import numpy as np
7 | import time
8 | import logging
9 | import argparse
10 | from collections import OrderedDict
11 |
12 | import torch
13 | import torch.nn as nn
14 | from torch.utils.data import DataLoader
15 | import torch.nn.functional as F
16 |
17 | from datasets import build_dataset
18 | from timm.models import create_model
19 | import ood_utils
20 | import utils
21 |
22 | from torch.autograd import Variable
23 | import modeling_finetune
24 | # os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
25 |
26 | get_metric = None
27 |
28 | def get_eval_results(dtest, dood, args):
29 | labels = [1] * len(dtest) + [0] * len(dood) if args.metric in ['softmax', 'gradnorm'] else [0] * len(dtest) + [1] * len(dood)
30 | fpr95 = ood_utils.get_fpr(dtest, dood, labels)
31 | auroc, aupr = ood_utils.get_roc_sklearn(dtest, dood, labels), ood_utils.get_pr_sklearn(dtest, dood, labels)
32 | return fpr95, auroc, aupr
33 |
34 |
35 | def compute(model, ood_set, softmax_test, args):
36 | ood_sampler = torch.utils.data.SequentialSampler(ood_set)
37 | ood_loader = torch.utils.data.DataLoader(ood_set, sampler=ood_sampler,
38 | batch_size=args.batch_size, num_workers=args.num_workers,
39 | pin_memory=args.pin_mem, drop_last=False)
40 | softmax_ood = get_metric(model, ood_loader, args.ood_dataset, args)
41 | fpr95, auroc, aupr = get_eval_results(softmax_test, softmax_ood, args,)
42 | return fpr95, auroc, aupr
43 |
44 |
45 | def get_args():
46 |     parser = argparse.ArgumentParser(description="MOOD OOD evaluation with logits")
47 |
48 | parser.add_argument('--metric', default='softmax', help='OOD detection metric of the logits',
49 | type=str, choices=['softmax', 'entropy', 'energy', 'gradnorm'])
50 | parser.add_argument('--ood_dataset', help='name of ood dataset', default=None, type=str)
51 | parser.add_argument('--ood_data_path', help='path of ood dataset', default=None, type=str)
52 | parser.add_argument('--cc', help='class-conditioned distance', default=False, action='store_true')
53 | parser.add_argument('--avgcc', help='average of class-conditioned distance', default=False, action='store_true')
54 |
55 | parser.add_argument("--results-dir", type=str, default="./eval_results")
56 | parser.add_argument('--class_idx', help='One-class OOD: the idx of the ID class number. Multi-class OOD: None', default=None, type=int)
57 |
58 | parser.add_argument("--batch_size", type=int, default=16)
59 | parser.add_argument("--ckpt", type=str, help="checkpoint path", required=True)
60 | parser.add_argument("--method", type=str, default='MOOD', help='name of logger')
61 |
62 | parser.set_defaults(eval_ood=True)
63 | # Model parameters
64 | parser.add_argument('--model', default='beit_large_patch16_224', type=str, metavar='MODEL',
65 | help='Name of model to train')
66 | parser.add_argument('--rel_pos_bias', action='store_true')
67 | parser.add_argument('--disable_rel_pos_bias', action='store_false', dest='rel_pos_bias')
68 | parser.set_defaults(rel_pos_bias=True)
69 | parser.add_argument('--abs_pos_emb', action='store_true')
70 | parser.set_defaults(abs_pos_emb=False)
71 | parser.add_argument('--layer_scale_init_value', default=0.1, type=float,
72 | help="0.1 for base, 1e-5 for large. set 0 to disable layer scale")
73 |
74 | parser.add_argument('--input_size', default=224, type=int, help='images input size')
75 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
76 | help='Dropout rate (default: 0.)')
77 | parser.add_argument('--attn_drop_rate', type=float, default=0.0, metavar='PCT',
78 | help='Attention dropout rate (default: 0.)')
79 | parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',
80 | help='Drop path rate (default: 0.1)')
81 |
82 | parser.add_argument('--disable_eval_during_finetuning', action='store_true', default=False)
83 |
84 | parser.add_argument('--model_ema', action='store_true', default=False)
85 | parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='')
86 | parser.add_argument('--model_ema_force_cpu', action='store_true', default=False, help='')
87 |
88 | # Augmentation parameters
89 | parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT',
90 | help='Color jitter factor (default: 0.4)')
91 |     parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
92 |                         help='Use AutoAugment policy. "v0" or "original" (default: rand-m9-mstd0.5-inc1)')
93 | parser.add_argument('--smoothing', type=float, default=0.1,
94 | help='Label smoothing (default: 0.1)')
95 | parser.add_argument('--train_interpolation', type=str, default='bicubic',
96 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
97 |
98 | # Evaluation parameters
99 | parser.add_argument('--crop_pct', type=float, default=None)
100 |
101 | # * Random Erase params
102 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
103 | help='Random erase prob (default: 0.25)')
104 | parser.add_argument('--remode', type=str, default='pixel',
105 | help='Random erase mode (default: "pixel")')
106 | parser.add_argument('--recount', type=int, default=1,
107 | help='Random erase count (default: 1)')
108 | parser.add_argument('--resplit', action='store_true', default=False,
109 | help='Do not random erase first (clean) augmentation split')
110 |
111 | # * Mixup params
112 | parser.add_argument('--mixup', type=float, default=0,
113 | help='mixup alpha, mixup enabled if > 0.')
114 | parser.add_argument('--cutmix', type=float, default=0,
115 | help='cutmix alpha, cutmix enabled if > 0.')
116 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
117 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
118 | parser.add_argument('--mixup_prob', type=float, default=1.0,
119 | help='Probability of performing mixup or cutmix when either/both is enabled')
120 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
121 | help='Probability of switching to cutmix when both mixup and cutmix enabled')
122 | parser.add_argument('--mixup_mode', type=str, default='batch',
123 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
124 |
125 | # * Finetuning params
126 | parser.add_argument('--finetune', default='',
127 | help='finetune from checkpoint')
128 | parser.add_argument('--imagenet30_key', default='model|module', type=str)
129 | parser.add_argument('--model_prefix', default='', type=str)
130 | parser.add_argument('--init_scale', default=0.001, type=float)
131 | parser.add_argument('--use_mean_pooling', action='store_true')
132 | parser.set_defaults(use_mean_pooling=True)
133 | parser.add_argument('--use_cls', action='store_false', dest='use_mean_pooling')
134 | parser.add_argument('--disable_weight_decay_on_rel_pos_bias', action='store_true', default=False)
135 |
136 | # Dataset parameters
137 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
138 | help='dataset path')
139 | parser.add_argument('--eval_data_path', default=None, type=str,
140 | help='dataset path for evaluation')
141 | parser.add_argument('--nb_classes', default=0, type=int,
142 | help='number of the classification types')
143 | parser.add_argument('--imagenet_default_mean_and_std', default=False, action='store_true')
144 |
145 |     parser.add_argument('--data_set', default='IMNET',
146 |                         type=str, help='name of the in-distribution dataset')
147 | parser.add_argument('--output_dir', default='',
148 | help='path where to save, empty for no saving')
149 | parser.add_argument('--log_dir', default=None,
150 | help='path where to tensorboard log')
151 | parser.add_argument('--device', default='cuda',
152 | help='device to use for training / testing')
153 | parser.add_argument('--seed', default=0, type=int)
154 | parser.add_argument('--resume', default='',
155 | help='resume from checkpoint')
156 | parser.add_argument('--auto_resume', action='store_true')
157 | parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume')
158 | parser.set_defaults(auto_resume=True)
159 |
160 | parser.add_argument('--save_ckpt', action='store_true')
161 | parser.add_argument('--no_save_ckpt', action='store_false', dest='save_ckpt')
162 | parser.set_defaults(save_ckpt=True)
163 |
164 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
165 | help='start epoch')
166 | parser.add_argument('--eval', action='store_true',
167 | help='Perform evaluation only')
168 | parser.add_argument('--dist_eval', action='store_true', default=False,
169 | help='Enabling distributed evaluation')
170 | parser.add_argument('--num_workers', default=10, type=int)
171 | parser.add_argument('--pin_mem', action='store_true',
172 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
173 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
174 | parser.set_defaults(pin_mem=True)
175 |
176 | args = parser.parse_args()
177 | return args
178 |
179 | ## get energy ###
180 | def get_energy(model, dataloader, name, args, is_train=False, max_num=1e10):
181 | model.eval()
182 | energy = []
183 | with torch.no_grad():
184 | for index, (img, label) in enumerate(dataloader):
185 | if index >= max_num:
186 | break
187 | img, label = img.cuda(), label.cuda()
188 | outputs = model(img)
189 |             # energy score E(x) = -logsumexp_c f_c(x); torch.logsumexp is the numerically stable form
190 |             e = -torch.logsumexp(outputs, dim=1)
191 | energy += list(e.view(e.size(0)).cpu().numpy())
192 | if args.class_idx is None and (index + 1) % 100 == 0:
193 | print('{}: {}/{}'.format(name, index+1, len(dataloader)), end='\r')
194 | energy = np.array(energy)
195 | print('\nenergy shape, ', energy.shape)
196 | return energy
197 |
198 | ## get entropy ###
199 | def get_entropy(model, dataloader, name, args, is_train=False, max_num=1e10):
200 | model.eval()
201 | entropys = []
202 | with torch.no_grad():
203 | for index, (img, label) in enumerate(dataloader):
204 | if index >= max_num:
205 | break
206 | img, label = img.cuda(), label.cuda()
207 | outputs = model(img)
208 | e = -1.0 * torch.sum(F.softmax(outputs, dim=1) * F.log_softmax(outputs, dim=1), dim=1)
209 | entropys += list(e.view(e.size(0)).cpu().numpy())
210 | if args.class_idx is None and (index + 1) % 100 == 0:
211 | print('{}: {}/{}'.format(name, index+1, len(dataloader)), end='\r')
212 | entropys = np.array(entropys)
213 | print('\nentropys shape, ', entropys.shape)
214 | return entropys
215 |
216 | ## get gradnorm ###
217 | def get_gradnorm(model, dataloader, name, args, is_train=False, max_num=1e10):
218 | confs = []
219 | logsoftmax = torch.nn.LogSoftmax(dim=-1).cuda()
220 |
221 | for index, (img, label) in enumerate(dataloader):
222 | if index >= max_num:
223 | break
224 | inputs = Variable(img.cuda(), requires_grad=True)
225 | model.zero_grad()
226 |
227 | img, label = img.cuda(), label.cuda()
228 | outputs = model(img)
229 |
230 | targets = torch.ones((inputs.shape[0], args.nb_classes)).cuda()
231 | loss = torch.mean(torch.sum(-targets * logsoftmax(outputs), dim=-1))
232 |
233 | loss.backward()
234 | layer_grad = model.head.weight.grad.data
235 | layer_grad_norm = torch.sum(torch.abs(layer_grad)).cpu().numpy()
236 | confs.append(layer_grad_norm)
237 |
238 | if args.class_idx is None and (index + 1) % 100 == 0:
239 | print('{}: {}/{}'.format(name, index+1, len(dataloader)), end='\r')
240 |
241 | return confs
242 |
243 |
244 | ## get softmax ###
245 | def get_softmax(model, dataloader, name, args, is_train=False, max_num=1e10):
246 | model.eval()
247 | probs = []
248 | with torch.no_grad():
249 | for index, (img, label) in enumerate(dataloader):
250 | if index >= max_num:
251 | break
252 |             img, label = img.cuda(), label.cuda()
253 |             # maximum softmax probability (MSP): apply softmax so the score matches the 'softmax' metric name
254 |             prob = torch.max(F.softmax(model(img), dim=-1), dim=-1).values
255 | probs += list(prob.cpu().numpy())
256 | if args.class_idx is None and (index + 1) % 100 == 0:
257 | print('{}: {}/{}'.format(name, index+1, len(dataloader)), end='\r')
258 | probs = np.array(probs)
259 | print('\n')
260 | return probs
261 |
262 |
263 | def main():
264 | # check ckpt
265 |     assert args.ckpt.startswith('https') or os.path.exists(args.ckpt), 'Cannot find {}'.format(args.ckpt)
266 | print('loading from {}'.format(args.ckpt))
267 |
268 | # load checkpoint
269 | global model
270 | if args.ckpt.startswith('https'):
271 | state_dict = torch.hub.load_state_dict_from_url(
272 | args.ckpt, map_location='cpu', check_hash=True)
273 | else:
274 | state_dict = torch.load(args.ckpt, map_location='cpu')
275 |
276 | if 'model' in state_dict.keys():
277 | state_dict = state_dict['model']
278 | elif 'state_dict' in state_dict.keys():
279 | state_dict = state_dict['state_dict']
280 | elif 'module' in state_dict.keys():
281 | state_dict = state_dict['module']
282 |
283 | for k in list(state_dict.keys()):
284 | state_dict[k.replace('module.', '')] = state_dict.pop(k)
285 |
286 | if 'pt22k' in args.ckpt: # for pretrained ckpt
287 | model = utils.set_checkpoint_for_finetune(model, args, args.ckpt)
288 | else: # for fine-tuned ckpt
289 | utils.load_state_dict(model, state_dict, prefix=args.model_prefix)
290 |
291 | # test dataloader
292 | test_set, _ = build_dataset(is_train=False, data_set=args.data_set, args=args)
293 | test_sampler = torch.utils.data.SequentialSampler(test_set)
294 | test_loader = torch.utils.data.DataLoader(test_set, sampler=test_sampler,
295 | batch_size=args.batch_size, num_workers=args.num_workers,
296 | pin_memory=args.pin_mem, drop_last=False)
297 |
298 | global get_metric
299 | metric_dict = {
300 | 'softmax': get_softmax,
301 | 'entropy': get_entropy,
302 | 'energy': get_energy,
303 | 'gradnorm': get_gradnorm,
304 | }
305 | get_metric = metric_dict[args.metric]
306 | softmax_test = get_metric(model, test_loader, args.data_set, args)
307 |
308 | if args.class_idx is None:
309 | ood_set, _ = build_dataset(is_train=False, data_set=args.ood_dataset, args=args, ood=True, ood_data_path=args.ood_data_path)
310 | fpr95, auroc, aupr = compute(model, ood_set, softmax_test, args)
311 | ood = "%-15s\t" % (args.ood_dataset)
312 | logger.info("{}\tIn-data: {}\tOOD: {}\tFPR95: {:.2f}\tAUROC: {:.2f}\tAUPR: {:.2f}".format(
313 | args.metric, args.data_set, ood, fpr95*100, auroc*100, aupr*100))
314 |
315 | else:
316 | ood_multi_class_set, _ = build_dataset(is_train=False, data_set=args.data_set, args=args, ood=True)
317 |
318 | fpr95s, aurocs, auprs = [], [], []
319 | for d in range(len(cls_list)):
320 | if d == args.class_idx:
321 | continue
322 | ood_set = ood_utils.get_subclass_dataset(ood_multi_class_set, classes=cls_list[d])
323 |
324 | args.ood_dataset = str(d)
325 | fpr95, auroc, aupr = compute(model, ood_set, softmax_test, args)
326 | logger.info("MOOD\tDataset: {}\tID: {}\tOOD: {}\tFPR95: {:.2f}\tAUROC: {:.2f}\tAUPR: {:.2f}".format(
327 | args.data_set, args.class_idx, d, fpr95*100, auroc*100, aupr*100))
328 |
329 | fpr95s.append(fpr95)
330 | aurocs.append(auroc)
331 | auprs.append(aupr)
332 |
333 | fpr95 = np.mean(fpr95s)*100
334 | auroc = np.mean(aurocs)*100
335 | aupr = np.mean(auprs)*100
336 |
337 | results = "MOOD\tDataset: {}\tID: {}\tOOD: {}\tFPR95: {:.2f}\tAUROC: {:.2f}\tAUPR: {:.2f}".format(
338 | args.data_set, args.class_idx, 'all', fpr95, auroc, aupr)
339 | logger.info(results)
340 | return fpr95, auroc, aupr
341 |
342 |
343 | def create_logger(results_dir):
344 | # create logger
345 | logging.basicConfig(level=logging.INFO, format="%(message)s")
346 | logger = logging.getLogger()
347 |
348 | # create results dir
349 | if results_dir is not None:
350 | ood_utils.mkdir(results_dir)
351 | results_file = os.path.join(results_dir, 'eval_results.txt')
352 | logger.addHandler(logging.FileHandler(results_file, "a"))
353 | print('=> Saving to {}'.format(results_file))
354 | return logger
355 |
356 |
357 | if __name__ == "__main__":
358 | args = get_args()
359 | device = "cuda:0"
360 |
361 | torch.manual_seed(args.seed)
362 | torch.cuda.manual_seed(args.seed)
363 | torch.cuda.manual_seed_all(args.seed)
364 |
365 | # create model
366 | _, args.nb_classes = build_dataset(is_train=False, data_set=args.data_set, args=args)
367 | model = create_model(args.model, pretrained=False, num_classes=args.nb_classes,
368 | drop_rate=args.drop, drop_path_rate=args.drop_path, attn_drop_rate=args.attn_drop_rate,
369 | drop_block_rate=None, use_mean_pooling=args.use_mean_pooling, init_scale=args.init_scale,
370 | use_rel_pos_bias=args.rel_pos_bias, use_abs_pos_emb=args.abs_pos_emb,
371 | init_values=args.layer_scale_init_value).cuda()
372 |
373 | # ood
374 | logger = create_logger(args.results_dir)
375 | if args.class_idx is not None:
376 | cls_list = ood_utils.get_superclass_list(args.data_set)
377 | main()
378 |
379 |
380 |
381 |
382 |
--------------------------------------------------------------------------------
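
For orientation, the scores returned by get_softmax, get_entropy, and get_energy in eval_with_logits.py reduce to the per-sample expressions below. This toy snippet uses random logits purely for illustration and is not part of the repository; gradnorm is omitted because it needs a backward pass through the classifier head (see get_gradnorm above).

import torch
import torch.nn.functional as F

logits = torch.randn(4, 1000)   # stand-in for model(img) on a batch of 4 images

msp     = F.softmax(logits, dim=-1).max(dim=-1).values                               # 'softmax': higher => more ID-like
entropy = -(F.softmax(logits, dim=-1) * F.log_softmax(logits, dim=-1)).sum(dim=-1)   # 'entropy': higher => more OOD-like
energy  = -torch.logsumexp(logits, dim=-1)                                           # 'energy':  higher => more OOD-like

print(msp.shape, entropy.shape, energy.shape)   # each is torch.Size([4])
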