├── MOODv2 ├── src │ ├── __init__.py │ ├── list_dataset.py │ ├── extract_feature_vit.py │ └── demo.py ├── imgs │ ├── framework.png │ ├── distribution.png │ ├── moodv2_table.png │ ├── performance.png │ ├── DTD_cracked_0004.jpg │ └── ImageNet_ILSVRC2012_val_00000293.JPEG ├── configs │ ├── vit-base-p16_384px.py │ ├── vit-base-p14_224px.py │ ├── vit-base-p14_518px.py │ ├── vit-base-p16_224px.py │ ├── vit-huge-p14_448px.py │ ├── vit-huge-p14_224px.py │ ├── swin-base-w7_224px.py │ ├── beit-large-p14_224px.py │ ├── beit-base-p16_224px.py │ ├── beit-base-p16_384px.py │ ├── pre-beit-base-p16_224px.py │ ├── pre-beitv2-base-p16_224px.py │ └── pre-mocov3-base-p16_224px.py ├── .gitignore └── README.md ├── MOODv1 ├── imgs │ └── moodv1_performance.png ├── .gitignore ├── requirements.txt ├── deepspeed_config.json ├── dall_e │ ├── __init__.py │ ├── utils.py │ ├── encoder.py │ └── decoder.py ├── datasets │ └── imagenet30.py ├── masking_generator.py ├── engine_for_pretraining.py ├── checkpoint.py ├── ood_utils.py ├── transforms.py ├── modeling_pretrain.py ├── optim_factory.py ├── modeling_discrete_vae.py ├── engine_for_finetuning.py ├── README.md ├── run_beit_pretraining.py ├── datasets.py ├── eval_with_features.py ├── dataset_folder.py └── eval_with_logits.py └── README.md /MOODv2/src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /MOODv2/imgs/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/MOOD/HEAD/MOODv2/imgs/framework.png -------------------------------------------------------------------------------- /MOODv2/imgs/distribution.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/MOOD/HEAD/MOODv2/imgs/distribution.png -------------------------------------------------------------------------------- /MOODv2/imgs/moodv2_table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/MOOD/HEAD/MOODv2/imgs/moodv2_table.png -------------------------------------------------------------------------------- /MOODv2/imgs/performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/MOOD/HEAD/MOODv2/imgs/performance.png -------------------------------------------------------------------------------- /MOODv1/imgs/moodv1_performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/MOOD/HEAD/MOODv1/imgs/moodv1_performance.png -------------------------------------------------------------------------------- /MOODv2/imgs/DTD_cracked_0004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/MOOD/HEAD/MOODv2/imgs/DTD_cracked_0004.jpg -------------------------------------------------------------------------------- /MOODv1/.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/** 2 | checkpoints/ 3 | run.md 4 | output 5 | pretrained/ 6 | tokenizer/ 7 | *.sh 8 | eval_results/ -------------------------------------------------------------------------------- /MOODv2/imgs/ImageNet_ILSVRC2012_val_00000293.JPEG: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/MOOD/HEAD/MOODv2/imgs/ImageNet_ILSVRC2012_val_00000293.JPEG -------------------------------------------------------------------------------- /MOODv1/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.7.1 2 | torchvision==0.8.2 3 | timm==0.3.2 4 | Pillow 5 | blobfile 6 | mypy 7 | numpy 8 | pytest 9 | requests 10 | einops 11 | tensorboardX 12 | deepspeed==0.4.0 13 | scipy 14 | -------------------------------------------------------------------------------- /MOODv2/configs/vit-base-p16_384px.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='ImageClassifier', 3 | backbone=dict( 4 | type='VisionTransformer', 5 | arch='b', 6 | img_size=384, 7 | patch_size=16, 8 | drop_rate=0.1,), 9 | neck=None, 10 | head=dict( 11 | type='VisionTransformerClsHead', 12 | num_classes=1000, 13 | in_channels=768, 14 | )) 15 | 16 | -------------------------------------------------------------------------------- /MOODv2/configs/vit-base-p14_224px.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='ImageClassifier', 3 | backbone=dict( 4 | type='VisionTransformer', 5 | arch='b', 6 | img_size=224, 7 | patch_size=14, 8 | drop_rate=0.1,), 9 | neck=None, 10 | head=dict( 11 | type='VisionTransformerClsHead', 12 | num_classes=1000, 13 | in_channels=768, 14 | )) 15 | 16 | 17 | -------------------------------------------------------------------------------- /MOODv2/configs/vit-base-p14_518px.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='ImageClassifier', 3 | backbone=dict( 4 | type='VisionTransformer', 5 | arch='b', 6 | img_size=518, 7 | patch_size=14, 8 | drop_rate=0.1,), 9 | neck=None, 10 | head=dict( 11 | type='VisionTransformerClsHead', 12 | num_classes=1000, 13 | in_channels=768, 14 | )) 15 | 16 | 17 | -------------------------------------------------------------------------------- /MOODv2/configs/vit-base-p16_224px.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='ImageClassifier', 3 | backbone=dict( 4 | type='VisionTransformer', 5 | arch='b', 6 | img_size=224, 7 | patch_size=16, 8 | drop_rate=0.1,), 9 | neck=None, 10 | head=dict( 11 | type='VisionTransformerClsHead', 12 | num_classes=1000, 13 | in_channels=768, 14 | )) 15 | 16 | 17 | -------------------------------------------------------------------------------- /MOODv2/configs/vit-huge-p14_448px.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='ImageClassifier', 3 | backbone=dict( 4 | type='VisionTransformer', 5 | arch='huge', 6 | img_size=448, 7 | patch_size=14, 8 | drop_path_rate=0.3, # set to 0.3 9 | out_type='avg_featmap', 10 | final_norm=False,), 11 | neck=None, 12 | head=dict( 13 | type='LinearClsHead', 14 | num_classes=1000, 15 | in_channels=1280, 16 | ) 17 | ) -------------------------------------------------------------------------------- /MOODv2/configs/vit-huge-p14_224px.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='ImageClassifier', 3 | backbone=dict( 4 | type='VisionTransformer', 5 | arch='huge', 6 | img_size=224, 7 | patch_size=14, 8 | drop_path_rate=0.3, # set to 0.3 9 | 
out_type='avg_featmap', 10 | final_norm=False,), 11 | neck=None, 12 | head=dict( 13 | type='LinearClsHead', 14 | num_classes=1000, 15 | in_channels=1280, 16 | ) 17 | ) 18 | -------------------------------------------------------------------------------- /MOODv2/configs/swin-base-w7_224px.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | model = dict( 3 | type='ImageClassifier', 4 | backbone=dict( 5 | type='SwinTransformer', 6 | arch='base', 7 | img_size=224, 8 | stage_cfgs=dict(block_cfgs=dict(window_size=7)), 9 | ), 10 | neck=dict(type='GlobalAveragePooling'), 11 | head=dict( 12 | type='LinearClsHead', 13 | num_classes=1000, 14 | in_channels=1024, 15 | cal_acc=False 16 | ), 17 | ) 18 | -------------------------------------------------------------------------------- /MOODv2/configs/beit-large-p14_224px.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='ImageClassifier', 3 | backbone=dict( 4 | type='BEiTViT', 5 | arch='l', 6 | img_size=224, 7 | patch_size=14, 8 | out_type='avg_featmap', 9 | use_abs_pos_emb=False, 10 | use_rel_pos_bias=True, 11 | use_shared_rel_pos_bias=False, 12 | ), 13 | neck=None, 14 | head=dict( 15 | type='LinearClsHead', 16 | num_classes=1000, 17 | in_channels=1024, 18 | ), 19 | ) -------------------------------------------------------------------------------- /MOODv2/configs/beit-base-p16_224px.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='ImageClassifier', 3 | backbone=dict( 4 | type='BEiTViT', 5 | arch='base', 6 | img_size=224, 7 | patch_size=16, 8 | out_type='avg_featmap', 9 | use_abs_pos_emb=False, 10 | use_rel_pos_bias=True, 11 | use_shared_rel_pos_bias=False, 12 | ), 13 | neck=None, 14 | head=dict( 15 | type='LinearClsHead', 16 | num_classes=1000, 17 | in_channels=768, 18 | ), 19 | ) 20 | -------------------------------------------------------------------------------- /MOODv2/configs/beit-base-p16_384px.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='ImageClassifier', 3 | backbone=dict( 4 | type='BEiTViT', 5 | arch='base', 6 | img_size=384, 7 | patch_size=16, 8 | out_type='avg_featmap', 9 | use_abs_pos_emb=False, 10 | use_rel_pos_bias=True, 11 | use_shared_rel_pos_bias=False, 12 | ), 13 | neck=None, 14 | head=dict( 15 | type='LinearClsHead', 16 | num_classes=1000, 17 | in_channels=768, 18 | ), 19 | ) 20 | -------------------------------------------------------------------------------- /MOODv2/configs/pre-beit-base-p16_224px.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | model = dict( 3 | type='BEiT', 4 | backbone=dict( 5 | type='BEiTPretrainViT', 6 | arch='base', 7 | patch_size=16, 8 | drop_path_rate=0.1, 9 | final_norm=True, 10 | out_type='raw', 11 | layer_scale_init_value=0.1, 12 | ), 13 | neck=None, 14 | head=dict( 15 | type='BEiTV1Head', 16 | embed_dims=768, 17 | num_embed=8192, 18 | loss=dict(type='CrossEntropyLoss'), 19 | ) 20 | ) -------------------------------------------------------------------------------- /MOODv1/deepspeed_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": 8, 3 | "train_micro_batch_size_per_gpu": 8, 4 | "steps_per_print": 1000, 5 | "optimizer": { 6 | "type": "Adam", 7 | "adam_w_mode": true, 8 | "params": { 9 | "lr": 0.004, 10 | "weight_decay": 0.05, 11 
| "bias_correction": true, 12 | "betas": [ 13 | 0.9, 14 | 0.999 15 | ], 16 | "eps": 1e-08 17 | } 18 | }, 19 | "fp16": { 20 | "enabled": true, 21 | "loss_scale": 0, 22 | "initial_scale_power": 7, 23 | "loss_scale_window": 128 24 | } 25 | } -------------------------------------------------------------------------------- /MOODv2/configs/pre-beitv2-base-p16_224px.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='BEiT', 3 | backbone=dict( 4 | type='BEiTPretrainViT', 5 | arch='base', 6 | patch_size=16, 7 | out_indices=[-4, -1], 8 | final_norm=False, 9 | out_type='raw',), 10 | neck=dict( 11 | type='BEiTV2Neck', 12 | num_layers=2, 13 | early_layers=9, 14 | backbone_arch='base', 15 | ), 16 | head=dict( 17 | type='BEiTV2Head', 18 | embed_dims=768, 19 | num_embed=8192, 20 | loss=dict(type='CrossEntropyLoss')), 21 | ) 22 | 23 | -------------------------------------------------------------------------------- /MOODv1/dall_e/__init__.py: -------------------------------------------------------------------------------- 1 | import io, requests 2 | import torch 3 | import torch.nn as nn 4 | 5 | from dall_e.encoder import Encoder 6 | from dall_e.decoder import Decoder 7 | from dall_e.utils import map_pixels, unmap_pixels 8 | 9 | def load_model(path: str, device: torch.device = None) -> nn.Module: 10 | if path.startswith('http://') or path.startswith('https://'): 11 | resp = requests.get(path) 12 | resp.raise_for_status() 13 | 14 | with io.BytesIO(resp.content) as buf: 15 | return torch.load(buf, map_location=device) 16 | else: 17 | with open(path, 'rb') as f: 18 | return torch.load(f, map_location=device) 19 | -------------------------------------------------------------------------------- /MOODv2/configs/pre-mocov3-base-p16_224px.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | temperature = 0.2 3 | model = dict( 4 | type='MoCoV3', 5 | base_momentum=0.01, 6 | backbone=dict( 7 | type='MoCoV3ViT', 8 | arch='base', # embed_dim = 768 9 | img_size=224, 10 | patch_size=16, 11 | stop_grad_conv1=True), 12 | neck=dict( 13 | type='NonLinearNeck', 14 | in_channels=768, 15 | hid_channels=4096, 16 | out_channels=256, 17 | num_layers=3, 18 | with_bias=False, 19 | with_last_bn=True, 20 | with_last_bn_affine=False, 21 | with_last_bias=False, 22 | with_avg_pool=False), 23 | head=dict( 24 | type='MoCoV3Head', 25 | predictor=dict( 26 | type='NonLinearNeck', 27 | in_channels=256, 28 | hid_channels=4096, 29 | out_channels=256, 30 | num_layers=2, 31 | with_bias=False, 32 | with_last_bn=True, 33 | with_last_bn_affine=False, 34 | with_last_bias=False, 35 | with_avg_pool=False), 36 | loss=dict(type='CrossEntropyLoss', loss_weight=2 * temperature), 37 | temperature=temperature,)) 38 | 39 | -------------------------------------------------------------------------------- /MOODv2/src/list_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | 4 | import torch.utils.data as data 5 | from PIL import Image 6 | 7 | 8 | def default_loader(path): 9 | return Image.open(path).convert('RGB') 10 | 11 | 12 | def default_flist_reader(flist): 13 | """flist format: impath label\nimpath label\n.""" 14 | imlist = [] 15 | with open(flist, 'r') as rf: 16 | for line in rf.readlines(): 17 | data = line.strip().rsplit(maxsplit=1) 18 | if len(data) == 2: 19 | impath, imlabel = data 20 | else: 21 | impath, imlabel = data[0], 0 22 | imlist.append((impath, 
int(imlabel))) 23 | 24 | return imlist 25 | 26 | 27 | class ImageFilelist(data.Dataset): 28 | 29 | def __init__(self, 30 | root, 31 | flist, 32 | transform=None, 33 | target_transform=None, 34 | flist_reader=default_flist_reader, 35 | loader=default_loader): 36 | self.root = root 37 | self.imlist = flist_reader(flist) 38 | self.transform = transform 39 | self.target_transform = target_transform 40 | self.loader = loader 41 | 42 | def __getitem__(self, index): 43 | impath, target = self.imlist[index] 44 | img = self.loader(os.path.join(self.root, impath)) 45 | if self.transform is not None: 46 | img = self.transform(img) 47 | if self.target_transform is not None: 48 | target = self.target_transform(target) 49 | 50 | return img, target 51 | 52 | def __len__(self): 53 | return len(self.imlist) 54 | -------------------------------------------------------------------------------- /MOODv1/dall_e/utils.py: -------------------------------------------------------------------------------- 1 | import attr 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | logit_laplace_eps: float = 0.1 9 | 10 | @attr.s(eq=False) 11 | class Conv2d(nn.Module): 12 | n_in: int = attr.ib(validator=lambda i, a, x: x >= 1) 13 | n_out: int = attr.ib(validator=lambda i, a, x: x >= 1) 14 | kw: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 2 == 1) 15 | 16 | use_float16: bool = attr.ib(default=True) 17 | device: torch.device = attr.ib(default=torch.device('cpu')) 18 | requires_grad: bool = attr.ib(default=False) 19 | 20 | def __attrs_post_init__(self) -> None: 21 | super().__init__() 22 | 23 | w = torch.empty((self.n_out, self.n_in, self.kw, self.kw), dtype=torch.float32, 24 | device=self.device, requires_grad=self.requires_grad) 25 | w.normal_(std=1 / math.sqrt(self.n_in * self.kw ** 2)) 26 | 27 | b = torch.zeros((self.n_out,), dtype=torch.float32, device=self.device, 28 | requires_grad=self.requires_grad) 29 | self.w, self.b = nn.Parameter(w), nn.Parameter(b) 30 | 31 | def forward(self, x: torch.Tensor) -> torch.Tensor: 32 | if self.use_float16 and 'cuda' in self.w.device.type: 33 | if x.dtype != torch.float16: 34 | x = x.half() 35 | 36 | w, b = self.w.half(), self.b.half() 37 | else: 38 | if x.dtype != torch.float32: 39 | x = x.float() 40 | 41 | w, b = self.w, self.b 42 | 43 | return F.conv2d(x, w, b, padding=(self.kw - 1) // 2) 44 | 45 | def map_pixels(x: torch.Tensor) -> torch.Tensor: 46 | if x.dtype != torch.float: 47 | raise ValueError('expected input to have type float') 48 | 49 | return (1 - 2 * logit_laplace_eps) * x + logit_laplace_eps 50 | 51 | def unmap_pixels(x: torch.Tensor) -> torch.Tensor: 52 | if len(x.shape) != 4: 53 | raise ValueError('expected input to be 4d') 54 | if x.dtype != torch.float: 55 | raise ValueError('expected input to have type float') 56 | 57 | return torch.clamp((x - logit_laplace_eps) / (1 - 2 * logit_laplace_eps), 0, 1) 58 | -------------------------------------------------------------------------------- /MOODv1/datasets/imagenet30.py: -------------------------------------------------------------------------------- 1 | ''' 2 | To Do: 3 | make symlink for ImageNet30 4 | ''' 5 | 6 | import os, errno 7 | 8 | def symlink_force(target, link_name): 9 | print('{}->{}'.format(target, link_name)) 10 | try: 11 | os.symlink(target, link_name) 12 | except: 13 | os.remove(link_name) 14 | os.symlink(target, link_name) 15 | 16 | 17 | def mkdir(directory): 18 | if not os.path.exists(directory): 19 | os.makedirs(directory) 20 | 21 | pairs = [ 
22 | ('acorn', 'n12267677'), 23 | ('airliner', 'n02690373'), 24 | ('ambulance', 'n02701002'), 25 | ('american_alligator', 'n01698640'), 26 | ('banjo', 'n02787622'), 27 | ('barn', 'n02793495'), 28 | ('bikini', 'n02837789'), 29 | ('digital_clock', 'n03196217'), 30 | ('dragonfly', 'n02268443'), 31 | ('dumbbell', 'n03255030'), 32 | ('forklift', 'n03384352'), 33 | ('goblet', 'n03443371'), 34 | ('grand_piano', 'n03452741'), 35 | ('hotdog', 'n07697537'), 36 | ('hourglass', 'n03544143'), 37 | ('manhole_cover', 'n03717622'), 38 | ('mosque', 'n03788195'), 39 | ('nail', 'n03804744'), 40 | ('parking_meter', 'n03891332'), 41 | ('pillow', 'n03938244'), 42 | ('revolver', 'n04086273'), 43 | ('rotary_dial_telephone', 'n03187595'), 44 | ('schooner', 'n04147183'), 45 | ('snowmobile', 'n04252077'), 46 | ('soccer_ball', 'n04254680'), 47 | ('stingray', 'n01498041'), 48 | ('strawberry', 'n07745940'), 49 | ('tank', 'n04389033'), 50 | ('toaster', 'n04442312'), 51 | ('volcano', 'n09472597') 52 | ] 53 | 54 | # set source and target paths here 55 | ori_imagenet1k_path = '' 56 | target_imagenet30_path = '' 57 | 58 | train_path = os.path.join(target_imagenet30_path, 'train') 59 | test_path = os.path.join(target_imagenet30_path, 'test') 60 | mkdir(train_path) 61 | mkdir(test_path) 62 | 63 | for pair in pairs: 64 | symlink_force(os.path.join(ori_imagenet1k_path, 'train', pair[1]), os.path.join(train_path, pair[0])) 65 | symlink_force(os.path.join(ori_imagenet1k_path, 'val', pair[1]), os.path.join(test_path, pair[0])) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MOOD 2 |

3 | • 🤗 Model 4 | • 🐱 Code 5 | • 📃 MOODv1 6 | • 📃 MOODv2
7 |

8 | 9 |

10 | framework 11 |

12 | 13 | ## MOODv1: Rethinking Out-of-Distribution Detection: Masked Image Modeling is All You Need (CVPR2023) 14 | The core of out-of-distribution (OOD) detection is to learn an in-distribution (ID) representation that is distinguishable from OOD samples. Previous work applied recognition-based methods to learn the ID features, which tend to learn shortcuts instead of comprehensive representations. In this work, we find, surprisingly, that simply using reconstruction-based methods can significantly boost OOD detection performance. We explore the main contributors to OOD detection in depth and find that reconstruction-based pretext tasks can provide a generally applicable and efficacious prior that helps the model learn the intrinsic data distribution of the ID dataset. Specifically, we take Masked Image Modeling as the pretext task for our OOD detection framework (MOOD). Without bells and whistles, MOOD outperforms the previous SOTA on one-class OOD detection by 5.7%, on multi-class OOD detection by 3.0%, and on near-distribution OOD detection by 2.1%. It even beats 10-shot-per-class outlier-exposure OOD detection, despite using no OOD samples for detection. 15 |

16 | moodv1 17 |
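As a rough sketch of the feature-distance scoring idea (mirroring the `mahalanobis` scorer in `MOODv1/ood_utils.py`), the snippet below rates each test sample by its Mahalanobis distance to the ID training features. The random arrays stand in for real backbone features and are not part of the repository.

```python
import numpy as np
from sklearn.metrics import roc_auc_score


def mahalanobis_score(ftrain, ftest):
    """Squared Mahalanobis distance of each test feature to the ID feature distribution."""
    mean = ftrain.mean(axis=0, keepdims=True)
    prec = np.linalg.pinv(np.cov(ftrain.T, bias=True))  # pseudo-inverse of the ID covariance
    diff = ftest - mean
    return np.sum(diff * (prec @ diff.T).T, axis=-1)    # larger distance -> more likely OOD


# Stand-in features; in practice these come from the masked-image-modeling backbone.
rng = np.random.default_rng(0)
ftrain_id = rng.normal(size=(500, 64))        # ID training features
ftest_id = rng.normal(size=(100, 64))         # ID test features
ftest_ood = rng.normal(size=(100, 64)) + 2.0  # shifted stand-in for OOD test features

din = mahalanobis_score(ftrain_id, ftest_id)
dood = mahalanobis_score(ftrain_id, ftest_ood)
labels = [0] * len(din) + [1] * len(dood)     # 1 marks OOD, as in ood_utils.get_scores_one_cluster
print('AUROC:', roc_auc_score(labels, np.concatenate([din, dood])))
```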

18 | 19 | ## MOODv2: Masked Image Modeling for Out-of-Distribution Detection (TPAMI2024) 20 | The crux of effective out-of-distribution (OOD) detection lies in acquiring a robust in-distribution (ID) representation that is distinct from OOD samples. Previous methods predominantly leaned on recognition-based techniques for this purpose, which often results in shortcut learning and a lack of comprehensive representations. In our study, we conduct a comprehensive analysis that explores distinct pretraining tasks and employs various OOD score functions. The results highlight that feature representations pretrained through reconstruction yield a notable improvement and narrow the performance gap among the score functions, suggesting that even simple score functions can rival complex ones when built on reconstruction-based pretext tasks. Because reconstruction-based pretext tasks adapt well to a wide range of score functions, this approach holds promising potential for further extension. Our OOD detection framework, MOODv2, employs the masked image modeling pretext task. Without bells and whistles, MOODv2 improves AUROC by 14.30% to reach 95.68% on ImageNet and achieves 99.98% on CIFAR-10. 21 |

22 | table 23 |
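For MOODv2, `MOODv2/src/extract_feature_vit.py` saves each split's backbone features as a pickled `(N, C)` NumPy array via `--out_file`. The sketch below, with placeholder `.pkl` paths, shows how such files could be scored with plain cosine similarity, one of the simple score functions referred to above (the same scorer appears as `cosine_similarity` in `MOODv1/ood_utils.py`).

```python
import pickle

import numpy as np
from sklearn.metrics import roc_auc_score
from sklearn.metrics.pairwise import cosine_similarity


def load_feats(path):
    """Load an (N, C) feature array pickled by extract_feature_vit.py."""
    with open(path, 'rb') as f:
        return pickle.load(f)


ftrain = load_feats('outputs/id_train.pkl')  # ID training features (placeholder path)
fid = load_feats('outputs/id_test.pkl')      # ID test features (placeholder path)
food = load_feats('outputs/ood_test.pkl')    # OOD test features (placeholder path)

# Score each test sample by its best cosine match against the ID feature bank;
# ID samples should obtain higher similarity than OOD samples.
score_id = cosine_similarity(fid, ftrain).max(axis=1)
score_ood = cosine_similarity(food, ftrain).max(axis=1)

labels = np.concatenate([np.ones_like(score_id), np.zeros_like(score_ood)])
scores = np.concatenate([score_id, score_ood])
print('AUROC:', roc_auc_score(labels, scores))
```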

24 | -------------------------------------------------------------------------------- /MOODv2/.gitignore: -------------------------------------------------------------------------------- 1 | 2 | scripts 3 | outputs 4 | models 5 | data 6 | pretrain 7 | tools 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | **/*.pyc 14 | core.* 15 | 16 | # C extensions 17 | *.so 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | !**/ilayout/lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # celery beat schedule file 90 | celerybeat-schedule 91 | 92 | # SageMath parsed files 93 | *.sage.py 94 | 95 | # Environments 96 | .env 97 | .venv 98 | env/ 99 | venv/ 100 | ENV/ 101 | env.bak/ 102 | venv.bak/ 103 | 104 | # Spyder project settings 105 | .spyderproject 106 | .spyproject 107 | 108 | # Rope project settings 109 | .ropeproject 110 | 111 | # mkdocs documentation 112 | /site 113 | 114 | # mypy 115 | .mypy_cache/ 116 | 117 | # custom 118 | /models 119 | /data 120 | .vscode 121 | .idea 122 | *.pkl 123 | *.pkl.json 124 | *.log.json 125 | *.npy 126 | work_dirs/ 127 | work_dirs 128 | work_dir 129 | workdirs/ 130 | results 131 | checkpoints 132 | logs 133 | log 134 | projects/*/data/ 135 | projects/*/data 136 | 137 | # Pytorch 138 | *.pth 139 | 140 | *.DS_Store 141 | 142 | *.tar 143 | *.TTF 144 | 145 | # web demo 146 | projects/ocrserver/ocrserver/static/results 147 | projects/ocrserver/ocrserver/static/uploads 148 | projects/ocrserver/inference_results 149 | projects/ocrserver/logs 150 | projects/autotrain/*/data/ 151 | projects/autotrain/*/data 152 | tmp/ 153 | debug/ 154 | 155 | # medius shell 156 | run_medius*.sh 157 | run_visualize*.sh 158 | projects/videoocr/run_sdk*.sh 159 | projects/videoocr/run_compare*.sh 160 | 161 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files. 
162 | 163 | # dependencies 164 | node_modules 165 | /.pnp 166 | .pnp.js 167 | 168 | # testing 169 | /coverage 170 | 171 | # production 172 | /dist 173 | 174 | # misc 175 | .DS_Store 176 | .env.local 177 | .env.development.local 178 | .env.test.local 179 | .env.production.local 180 | 181 | npm-debug.log* 182 | yarn-debug.log* 183 | yarn-error.log* 184 | 185 | **/public/fonts/ 186 | **/temp/ 187 | **/tmp/ 188 | projects/ilayout/public/img/texture 189 | 190 | outputs/ 191 | checkpoints/ -------------------------------------------------------------------------------- /MOODv1/masking_generator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0 3 | Copyright Zhun Zhong & Liang Zheng 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | 7 | Modified by Hangbo Bao, for generating the masked position for visual image transformer 8 | """ 9 | # -------------------------------------------------------- 10 | # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) 11 | # Github source: https://github.com/microsoft/unilm/tree/master/beit 12 | # Copyright (c) 2021 Microsoft 13 | # Licensed under The MIT License [see LICENSE for details] 14 | # By Hangbo Bao 15 | # Based on timm, DINO and DeiT code bases 16 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm 17 | # Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0 18 | # Copyright Zhun Zhong & Liang Zheng 19 | # 20 | # Hacked together by / Copyright 2020 Ross Wightman 21 | # 22 | # Modified by Hangbo Bao, for generating the masked position for visual image transformer 23 | # --------------------------------------------------------' 24 | import random 25 | import math 26 | import numpy as np 27 | 28 | 29 | class MaskingGenerator: 30 | def __init__( 31 | self, input_size, num_masking_patches, min_num_patches=4, max_num_patches=None, 32 | min_aspect=0.3, max_aspect=None): 33 | if not isinstance(input_size, tuple): 34 | input_size = (input_size, ) * 2 35 | self.height, self.width = input_size 36 | 37 | self.num_patches = self.height * self.width 38 | self.num_masking_patches = num_masking_patches 39 | 40 | self.min_num_patches = min_num_patches 41 | self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches 42 | 43 | max_aspect = max_aspect or 1 / min_aspect 44 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) 45 | 46 | def __repr__(self): 47 | repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % ( 48 | self.height, self.width, self.min_num_patches, self.max_num_patches, 49 | self.num_masking_patches, self.log_aspect_ratio[0], self.log_aspect_ratio[1]) 50 | return repr_str 51 | 52 | def get_shape(self): 53 | return self.height, self.width 54 | 55 | def _mask(self, mask, max_mask_patches): 56 | delta = 0 57 | for attempt in range(10): 58 | target_area = random.uniform(self.min_num_patches, max_mask_patches) 59 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 60 | h = int(round(math.sqrt(target_area * aspect_ratio))) 61 | w = int(round(math.sqrt(target_area / aspect_ratio))) 62 | if w < self.width and h < self.height: 63 | top = random.randint(0, self.height - h) 64 | left = random.randint(0, self.width - w) 65 | 66 | num_masked = mask[top: top + h, left: left + w].sum() 67 | # Overlap 68 | if 0 < h * w - num_masked <= max_mask_patches: 69 | for i in 
range(top, top + h): 70 | for j in range(left, left + w): 71 | if mask[i, j] == 0: 72 | mask[i, j] = 1 73 | delta += 1 74 | 75 | if delta > 0: 76 | break 77 | return delta 78 | 79 | def __call__(self): 80 | mask = np.zeros(shape=self.get_shape(), dtype=np.int) 81 | mask_count = 0 82 | while mask_count < self.num_masking_patches: 83 | max_mask_patches = self.num_masking_patches - mask_count 84 | max_mask_patches = min(max_mask_patches, self.max_num_patches) 85 | 86 | delta = self._mask(mask, max_mask_patches) 87 | if delta == 0: 88 | break 89 | else: 90 | mask_count += delta 91 | 92 | return mask 93 | -------------------------------------------------------------------------------- /MOODv1/dall_e/encoder.py: -------------------------------------------------------------------------------- 1 | import attr 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from collections import OrderedDict 9 | from functools import partial 10 | from dall_e.utils import Conv2d 11 | 12 | @attr.s(eq=False, repr=False) 13 | class EncoderBlock(nn.Module): 14 | n_in: int = attr.ib(validator=lambda i, a, x: x >= 1) 15 | n_out: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 4 ==0) 16 | n_layers: int = attr.ib(validator=lambda i, a, x: x >= 1) 17 | 18 | device: torch.device = attr.ib(default=None) 19 | requires_grad: bool = attr.ib(default=False) 20 | 21 | def __attrs_post_init__(self) -> None: 22 | super().__init__() 23 | self.n_hid = self.n_out // 4 24 | self.post_gain = 1 / (self.n_layers ** 2) 25 | 26 | make_conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad) 27 | self.id_path = make_conv(self.n_in, self.n_out, 1) if self.n_in != self.n_out else nn.Identity() 28 | self.res_path = nn.Sequential(OrderedDict([ 29 | ('relu_1', nn.ReLU()), 30 | ('conv_1', make_conv(self.n_in, self.n_hid, 3)), 31 | ('relu_2', nn.ReLU()), 32 | ('conv_2', make_conv(self.n_hid, self.n_hid, 3)), 33 | ('relu_3', nn.ReLU()), 34 | ('conv_3', make_conv(self.n_hid, self.n_hid, 3)), 35 | ('relu_4', nn.ReLU()), 36 | ('conv_4', make_conv(self.n_hid, self.n_out, 1)),])) 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | return self.id_path(x) + self.post_gain * self.res_path(x) 40 | 41 | @attr.s(eq=False, repr=False) 42 | class Encoder(nn.Module): 43 | group_count: int = 4 44 | n_hid: int = attr.ib(default=256, validator=lambda i, a, x: x >= 64) 45 | n_blk_per_group: int = attr.ib(default=2, validator=lambda i, a, x: x >= 1) 46 | input_channels: int = attr.ib(default=3, validator=lambda i, a, x: x >= 1) 47 | vocab_size: int = attr.ib(default=8192, validator=lambda i, a, x: x >= 512) 48 | 49 | device: torch.device = attr.ib(default=torch.device('cpu')) 50 | requires_grad: bool = attr.ib(default=False) 51 | use_mixed_precision: bool = attr.ib(default=True) 52 | 53 | def __attrs_post_init__(self) -> None: 54 | super().__init__() 55 | 56 | blk_range = range(self.n_blk_per_group) 57 | n_layers = self.group_count * self.n_blk_per_group 58 | make_conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad) 59 | make_blk = partial(EncoderBlock, n_layers=n_layers, device=self.device, 60 | requires_grad=self.requires_grad) 61 | 62 | self.blocks = nn.Sequential(OrderedDict([ 63 | ('input', make_conv(self.input_channels, 1 * self.n_hid, 7)), 64 | ('group_1', nn.Sequential(OrderedDict([ 65 | *[(f'block_{i + 1}', make_blk(1 * self.n_hid, 1 * self.n_hid)) for i in blk_range], 66 | ('pool', nn.MaxPool2d(kernel_size=2)), 67 | ]))), 68 | 
('group_2', nn.Sequential(OrderedDict([ 69 | *[(f'block_{i + 1}', make_blk(1 * self.n_hid if i == 0 else 2 * self.n_hid, 2 * self.n_hid)) for i in blk_range], 70 | ('pool', nn.MaxPool2d(kernel_size=2)), 71 | ]))), 72 | ('group_3', nn.Sequential(OrderedDict([ 73 | *[(f'block_{i + 1}', make_blk(2 * self.n_hid if i == 0 else 4 * self.n_hid, 4 * self.n_hid)) for i in blk_range], 74 | ('pool', nn.MaxPool2d(kernel_size=2)), 75 | ]))), 76 | ('group_4', nn.Sequential(OrderedDict([ 77 | *[(f'block_{i + 1}', make_blk(4 * self.n_hid if i == 0 else 8 * self.n_hid, 8 * self.n_hid)) for i in blk_range], 78 | ]))), 79 | ('output', nn.Sequential(OrderedDict([ 80 | ('relu', nn.ReLU()), 81 | ('conv', make_conv(8 * self.n_hid, self.vocab_size, 1, use_float16=False)), 82 | ]))), 83 | ])) 84 | 85 | def forward(self, x: torch.Tensor) -> torch.Tensor: 86 | if len(x.shape) != 4: 87 | raise ValueError(f'input shape {x.shape} is not 4d') 88 | if x.shape[1] != self.input_channels: 89 | raise ValueError(f'input has {x.shape[1]} channels but model built for {self.input_channels}') 90 | if x.dtype != torch.float32: 91 | raise ValueError('input must have dtype torch.float32') 92 | 93 | return self.blocks(x) 94 | -------------------------------------------------------------------------------- /MOODv1/dall_e/decoder.py: -------------------------------------------------------------------------------- 1 | import attr 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from collections import OrderedDict 9 | from functools import partial 10 | from dall_e.utils import Conv2d 11 | 12 | @attr.s(eq=False, repr=False) 13 | class DecoderBlock(nn.Module): 14 | n_in: int = attr.ib(validator=lambda i, a, x: x >= 1) 15 | n_out: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 4 ==0) 16 | n_layers: int = attr.ib(validator=lambda i, a, x: x >= 1) 17 | 18 | device: torch.device = attr.ib(default=None) 19 | requires_grad: bool = attr.ib(default=False) 20 | 21 | def __attrs_post_init__(self) -> None: 22 | super().__init__() 23 | self.n_hid = self.n_out // 4 24 | self.post_gain = 1 / (self.n_layers ** 2) 25 | 26 | make_conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad) 27 | self.id_path = make_conv(self.n_in, self.n_out, 1) if self.n_in != self.n_out else nn.Identity() 28 | self.res_path = nn.Sequential(OrderedDict([ 29 | ('relu_1', nn.ReLU()), 30 | ('conv_1', make_conv(self.n_in, self.n_hid, 1)), 31 | ('relu_2', nn.ReLU()), 32 | ('conv_2', make_conv(self.n_hid, self.n_hid, 3)), 33 | ('relu_3', nn.ReLU()), 34 | ('conv_3', make_conv(self.n_hid, self.n_hid, 3)), 35 | ('relu_4', nn.ReLU()), 36 | ('conv_4', make_conv(self.n_hid, self.n_out, 3)),])) 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | return self.id_path(x) + self.post_gain * self.res_path(x) 40 | 41 | @attr.s(eq=False, repr=False) 42 | class Decoder(nn.Module): 43 | group_count: int = 4 44 | n_init: int = attr.ib(default=128, validator=lambda i, a, x: x >= 8) 45 | n_hid: int = attr.ib(default=256, validator=lambda i, a, x: x >= 64) 46 | n_blk_per_group: int = attr.ib(default=2, validator=lambda i, a, x: x >= 1) 47 | output_channels: int = attr.ib(default=3, validator=lambda i, a, x: x >= 1) 48 | vocab_size: int = attr.ib(default=8192, validator=lambda i, a, x: x >= 512) 49 | 50 | device: torch.device = attr.ib(default=torch.device('cpu')) 51 | requires_grad: bool = attr.ib(default=False) 52 | use_mixed_precision: bool = attr.ib(default=True) 53 | 54 | def 
__attrs_post_init__(self) -> None: 55 | super().__init__() 56 | 57 | blk_range = range(self.n_blk_per_group) 58 | n_layers = self.group_count * self.n_blk_per_group 59 | make_conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad) 60 | make_blk = partial(DecoderBlock, n_layers=n_layers, device=self.device, 61 | requires_grad=self.requires_grad) 62 | 63 | self.blocks = nn.Sequential(OrderedDict([ 64 | ('input', make_conv(self.vocab_size, self.n_init, 1, use_float16=False)), 65 | ('group_1', nn.Sequential(OrderedDict([ 66 | *[(f'block_{i + 1}', make_blk(self.n_init if i == 0 else 8 * self.n_hid, 8 * self.n_hid)) for i in blk_range], 67 | ('upsample', nn.Upsample(scale_factor=2, mode='nearest')), 68 | ]))), 69 | ('group_2', nn.Sequential(OrderedDict([ 70 | *[(f'block_{i + 1}', make_blk(8 * self.n_hid if i == 0 else 4 * self.n_hid, 4 * self.n_hid)) for i in blk_range], 71 | ('upsample', nn.Upsample(scale_factor=2, mode='nearest')), 72 | ]))), 73 | ('group_3', nn.Sequential(OrderedDict([ 74 | *[(f'block_{i + 1}', make_blk(4 * self.n_hid if i == 0 else 2 * self.n_hid, 2 * self.n_hid)) for i in blk_range], 75 | ('upsample', nn.Upsample(scale_factor=2, mode='nearest')), 76 | ]))), 77 | ('group_4', nn.Sequential(OrderedDict([ 78 | *[(f'block_{i + 1}', make_blk(2 * self.n_hid if i == 0 else 1 * self.n_hid, 1 * self.n_hid)) for i in blk_range], 79 | ]))), 80 | ('output', nn.Sequential(OrderedDict([ 81 | ('relu', nn.ReLU()), 82 | ('conv', make_conv(1 * self.n_hid, 2 * self.output_channels, 1)), 83 | ]))), 84 | ])) 85 | 86 | def forward(self, x: torch.Tensor) -> torch.Tensor: 87 | if len(x.shape) != 4: 88 | raise ValueError(f'input shape {x.shape} is not 4d') 89 | if x.shape[1] != self.vocab_size: 90 | raise ValueError(f'input has {x.shape[1]} channels but model built for {self.vocab_size}') 91 | if x.dtype != torch.float32: 92 | raise ValueError('input must have dtype torch.float32') 93 | 94 | return self.blocks(x) 95 | -------------------------------------------------------------------------------- /MOODv2/src/extract_feature_vit.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import pickle 4 | from os.path import dirname 5 | import os 6 | import mmengine 7 | import numpy as np 8 | import torch 9 | import torchvision as tv 10 | from tqdm import tqdm 11 | 12 | from list_dataset import ImageFilelist 13 | from mmpretrain.apis import init_model 14 | 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser(description='Say hello') 18 | parser.add_argument('data_root', help='Path to data') 19 | parser.add_argument('--out_file', help='Path to output file') 20 | parser.add_argument( 21 | '--cfg', default='configs/vit-base-p16-384.py', help='Path to config') 22 | parser.add_argument( 23 | '--checkpoint', 24 | default='checkpoints/vit-base-p16_in21k-pre-3rdparty_ft-64xb64_in1k-' 25 | '384_20210928-98e8652b.pth', 26 | help='Path to checkpoint') 27 | parser.add_argument('--img_list', help='Path to image list') 28 | parser.add_argument('--batch', type=int, default=256, help='batch size') 29 | parser.add_argument('--workers', type=int, default=4, help='num of workers') 30 | parser.add_argument('--fc_save_path', default=None, help='Path to save fc') 31 | return parser.parse_args() 32 | 33 | 34 | def main(): 35 | args = parse_args() 36 | 37 | torch.backends.cudnn.benchmark = True 38 | 39 | cfg = mmengine.Config.fromfile(args.cfg) 40 | model = init_model(cfg, args.checkpoint, 0).cuda().eval() 41 | 42 | 
if args.fc_save_path is not None: 43 | if os.path.exists(args.fc_save_path): 44 | print(f'{args.fc_save_path} exists.') 45 | return 46 | mmengine.mkdir_or_exist(dirname(args.fc_save_path)) 47 | if cfg.model.head.type == 'VisionTransformerClsHead': 48 | fc = model.head.layers.head 49 | elif cfg.model.head.type == 'LinearClsHead': 50 | fc = model.head.fc 51 | elif cfg.model.head.type in ['BEiTV1Head', 'BEiTV2Head']: 52 | fc = model.head.cls_head 53 | elif cfg.model.head.type in ['MoCoV3Head']: 54 | print(f'{cfg.model.head.type} utilize NonLinearNetwork which cannot be represented by a weight and a bias') 55 | raise 56 | else: 57 | print(cfg.model.head.type) 58 | print(model.head) 59 | import pdb;pdb.set_trace() 60 | raise NotImplementedError(cfg.model.backbone.type) 61 | w = fc.weight.cpu().detach().numpy() 62 | b = fc.bias.cpu().detach().numpy() 63 | with open(args.fc_save_path, 'wb') as f: 64 | pickle.dump([w, b], f) 65 | return 66 | 67 | # if os.path.exists(out_file): 68 | # print(f'{out_file} exists.') 69 | # return 70 | 71 | if hasattr(cfg.model.backbone, 'img_size'): 72 | img_size = cfg.model.backbone.img_size 73 | else: 74 | img_size = 224 75 | 76 | transform = tv.transforms.Compose([ 77 | tv.transforms.Resize((img_size, img_size)), 78 | tv.transforms.ToTensor(), 79 | tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 80 | ]) 81 | 82 | if args.img_list not in [None, 'None']: 83 | dataset = ImageFilelist(args.data_root, args.img_list, transform) 84 | else: 85 | dataset = tv.datasets.ImageFolder(args.data_root, transform) 86 | print(f'lenth of dataset: {len(dataset)}') 87 | 88 | dataloader = torch.utils.data.DataLoader( 89 | dataset, 90 | batch_size=args.batch, 91 | shuffle=False, 92 | num_workers=args.workers, 93 | pin_memory=True, 94 | drop_last=False) 95 | 96 | features = [] 97 | with torch.no_grad(): 98 | for i, (x, _) in enumerate(tqdm(dataloader)): 99 | x = x.cuda() 100 | if cfg.model.backbone.type == 'BEiTPretrainViT': 101 | # (B, L, C) -> (B, C) 102 | feat_batch = model.backbone( 103 | x, mask=None)[0].mean(1) 104 | elif cfg.model.backbone.type == 'SwinTransformer': 105 | # (B, C, H, W) -> (B, C) 106 | feat_batch = model.backbone(x)[0] 107 | B, C, H, W = feat_batch.shape 108 | feat_batch = feat_batch.reshape(B, C, -1).mean(-1) 109 | else: 110 | # (B, C) 111 | feat_batch = model.backbone(x)[0] 112 | assert len(feat_batch.shape) == 2 113 | features.append(feat_batch.cpu().numpy()) 114 | 115 | features = np.concatenate(features, axis=0) 116 | print(f'features: {features.shape}') 117 | mmengine.mkdir_or_exist(dirname(args.out_file)) 118 | with open(args.out_file, 'wb') as f: 119 | pickle.dump(features, f) 120 | 121 | if __name__ == '__main__': 122 | main() 123 | -------------------------------------------------------------------------------- /MOODv1/engine_for_pretraining.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) 3 | # Github source: https://github.com/microsoft/unilm/tree/master/beit 4 | # Copyright (c) 2021 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # By Hangbo Bao 7 | # Based on timm, DINO and DeiT code bases 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # https://github.com/facebookresearch/deit/ 10 | # https://github.com/facebookresearch/dino 11 | # --------------------------------------------------------' 12 | 
import math 13 | import sys 14 | from typing import Iterable 15 | 16 | import torch 17 | import torch.nn as nn 18 | 19 | import utils 20 | print_freq = 1000 21 | 22 | def train_one_epoch(model: torch.nn.Module, d_vae: torch.nn.Module, 23 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 24 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 25 | log_writer=None, lr_scheduler=None, start_steps=None, 26 | lr_schedule_values=None, wd_schedule_values=None): 27 | model.train() 28 | metric_logger = utils.MetricLogger(delimiter=" ") 29 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 30 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 31 | header = 'Epoch: [{}]'.format(epoch) 32 | 33 | for step, (batch, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 34 | # assign learning rate & weight decay for each step 35 | it = start_steps + step # global training iteration 36 | if lr_schedule_values is not None or wd_schedule_values is not None: 37 | for i, param_group in enumerate(optimizer.param_groups): 38 | if lr_schedule_values is not None: 39 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"] 40 | if wd_schedule_values is not None and param_group["weight_decay"] > 0: 41 | param_group["weight_decay"] = wd_schedule_values[it] 42 | 43 | samples, images, bool_masked_pos = batch 44 | images = images.to(device, non_blocking=True) 45 | samples = samples.to(device, non_blocking=True) 46 | bool_masked_pos = bool_masked_pos.to(device, non_blocking=True) 47 | 48 | with torch.no_grad(): 49 | input_ids = d_vae.get_codebook_indices(images).flatten(1) 50 | bool_masked_pos = bool_masked_pos.flatten(1).to(torch.bool) 51 | labels = input_ids[bool_masked_pos] 52 | 53 | with torch.cuda.amp.autocast(): 54 | outputs = model(samples, bool_masked_pos=bool_masked_pos, return_all_tokens=False) # torch.Size([74, 8192]) 55 | loss = nn.CrossEntropyLoss()(input=outputs, target=labels) # torch.Size([74]) 56 | 57 | loss_value = loss.item() 58 | 59 | if not math.isfinite(loss_value): 60 | print("Loss is {}, stopping training".format(loss_value)) 61 | sys.exit(1) 62 | 63 | optimizer.zero_grad() 64 | # this attribute is added by timm on one optimizer (adahessian) 65 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 66 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm, 67 | parameters=model.parameters(), create_graph=is_second_order) 68 | loss_scale_value = loss_scaler.state_dict()["scale"] 69 | 70 | torch.cuda.synchronize() 71 | 72 | mlm_acc = (outputs.max(-1)[1] == labels).float().mean().item() 73 | 74 | metric_logger.update(mlm_acc=mlm_acc) 75 | if log_writer is not None: 76 | log_writer.update(mlm_acc=mlm_acc, head="loss") 77 | 78 | metric_logger.update(loss=loss_value) 79 | metric_logger.update(loss_scale=loss_scale_value) 80 | min_lr = 10. 81 | max_lr = 0. 
82 | for group in optimizer.param_groups: 83 | min_lr = min(min_lr, group["lr"]) 84 | max_lr = max(max_lr, group["lr"]) 85 | 86 | metric_logger.update(lr=max_lr) 87 | metric_logger.update(min_lr=min_lr) 88 | weight_decay_value = None 89 | for group in optimizer.param_groups: 90 | if group["weight_decay"] > 0: 91 | weight_decay_value = group["weight_decay"] 92 | metric_logger.update(weight_decay=weight_decay_value) 93 | metric_logger.update(grad_norm=grad_norm) 94 | 95 | if log_writer is not None: 96 | log_writer.update(loss=loss_value, head="loss") 97 | log_writer.update(loss_scale=loss_scale_value, head="opt") 98 | log_writer.update(lr=max_lr, head="opt") 99 | log_writer.update(min_lr=min_lr, head="opt") 100 | log_writer.update(weight_decay=weight_decay_value, head="opt") 101 | log_writer.update(grad_norm=grad_norm, head="opt") 102 | 103 | log_writer.set_step() 104 | 105 | if lr_scheduler is not None: 106 | lr_scheduler.step_update(start_steps + step) 107 | # gather the stats from all processes 108 | metric_logger.synchronize_between_processes() 109 | print("Averaged stats:", metric_logger) 110 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 111 | -------------------------------------------------------------------------------- /MOODv1/checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | #from tensorflow.io import gfile 4 | import numpy as np 5 | 6 | def get_l16_config(config): 7 | """ ViT-L/16 configuration """ 8 | config.patch_size = 16 9 | config.emb_dim = 1024 10 | config.mlp_dim = 4096 11 | config.num_heads = 16 12 | config.num_layers = 24 13 | config.attn_dropout_rate = 0.0 14 | config.dropout_rate = 0.1 15 | return config 16 | 17 | def get_b16_config(config): 18 | """ ViT-B/16 configuration """ 19 | config.patch_size = 16 20 | config.emb_dim = 768 21 | config.mlp_dim = 3072 22 | config.num_heads = 12 23 | config.num_layers = 12 24 | config.attn_dropout_rate = 0.0 25 | config.dropout_rate = 0.1 26 | return config 27 | 28 | def load_checkpoint(path, new_img=384, patch=16, emb_dim=768, layers=12): 29 | """ Load weights from a given checkpoint path in npz/pth """ 30 | if path.endswith('npz'): 31 | keys, values = load_jax(path) 32 | state_dict = convert_jax_pytorch(keys, values) 33 | elif path.endswith('pth') or path.endswith('pt'): 34 | state_dict = torch.load(path, map_location=torch.device("cpu")) 35 | else: 36 | raise ValueError("checkpoint format {} not supported yet!".format(path.split('.')[-1])) 37 | 38 | if 'model' in state_dict.keys(): 39 | state_dict = state_dict['model'] 40 | elif 'state_dict' in state_dict.keys(): 41 | state_dict = state_dict['state_dict'] 42 | 43 | # if 'pos_embed' in state_dict.keys(): 44 | # # Deal with class token 45 | # posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:] 46 | # model_grid_seq = new_img//patch 47 | # ckpt_grid_seq = int(np.sqrt(posemb_grid.shape[0])) 48 | 49 | # if model_grid_seq!=ckpt_grid_seq: 50 | # # Get old and new grid sizes 51 | # posemb_grid = posemb_grid.reshape(ckpt_grid_seq, ckpt_grid_seq, -1) 52 | 53 | # posemb_grid = torch.unsqueeze(posemb_grid.permute(2, 0, 1), dim=0) 54 | # posemb_grid = torch.nn.functional.interpolate(posemb_grid, size=(model_grid_seq, model_grid_seq), mode='bicubic', align_corners=False) 55 | # posemb_grid = posemb_grid.permute(0, 2, 3, 1).flatten(1, 2) 56 | 57 | # # Deal with class token and return 58 | # posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 59 | # # if 'jx' in path: 60 | # # 
state_dict['pos_embed'] = posemb 61 | # # else: 62 | # state_dict['transformer.pos_embedding.pos_embedding'] = posemb 63 | # print('Resized positional embedding from (%d,%d) to (%d,%d)'%(ckpt_grid_seq,ckpt_grid_seq,model_grid_seq,model_grid_seq)) 64 | return state_dict 65 | 66 | 67 | def load_jax(path): 68 | """ Loads params from a npz checkpoint previously stored with `save()` in jax implemetation """ 69 | ckpt_dict = np.load(path, allow_pickle=False) 70 | keys, values = zip(*list(ckpt_dict.items())) 71 | # with gfile.GFile(path, 'rb') as f: 72 | # ckpt_dict = np.load(f, allow_pickle=False) 73 | # keys, values = zip(*list(ckpt_dict.items())) 74 | return keys, values 75 | 76 | 77 | def save_jax_to_pytorch(jax_path, save_path): 78 | model_name = jax_path.split('/')[-1].split('.')[0] 79 | keys, values = load_jax(jax_path) 80 | state_dict = convert_jax_pytorch(keys, values) 81 | checkpoint = {'state_dict': state_dict} 82 | torch.save(checkpoint, os.path.join(save_path, model_name + '.pth')) 83 | 84 | 85 | def replace_names(names): 86 | """ Replace jax model names with pytorch model names """ 87 | new_names = [] 88 | for name in names: 89 | if name == 'Transformer': 90 | new_names.append('transformer') 91 | elif name == 'encoder_norm': 92 | new_names.append('norm') 93 | elif 'encoderblock' in name: 94 | num = name.split('_')[-1] 95 | new_names.append('encoder_layers') 96 | new_names.append(num) 97 | elif 'LayerNorm' in name: 98 | num = name.split('_')[-1] 99 | if num == '0': 100 | new_names.append('norm{}'.format(1)) 101 | elif num == '2': 102 | new_names.append('norm{}'.format(2)) 103 | elif 'MlpBlock' in name: 104 | new_names.append('mlp') 105 | elif 'Dense' in name: 106 | num = name.split('_')[-1] 107 | new_names.append('fc{}'.format(int(num) + 1)) 108 | elif 'MultiHeadDotProductAttention' in name: 109 | new_names.append('attn') 110 | elif name == 'kernel' or name == 'scale': 111 | new_names.append('weight') 112 | elif name == 'bias': 113 | new_names.append(name) 114 | elif name == 'posembed_input': 115 | new_names.append('pos_embedding') 116 | elif name == 'pos_embedding': 117 | new_names.append('pos_embedding') 118 | elif name == 'embedding': 119 | new_names.append('embedding') 120 | elif name == 'head': 121 | new_names.append('classifier') 122 | elif name == 'cls': 123 | new_names.append('cls_token') 124 | else: 125 | new_names.append(name) 126 | return new_names 127 | 128 | 129 | def convert_jax_pytorch(keys, values): 130 | """ Convert jax model parameters with pytorch model parameters """ 131 | state_dict = {} 132 | for key, value in zip(keys, values): 133 | 134 | # convert name to torch names 135 | names = key.split('/') 136 | torch_names = replace_names(names) 137 | torch_key = '.'.join(w for w in torch_names) 138 | 139 | # convert values to tensor and check shapes 140 | tensor_value = torch.tensor(value, dtype=torch.float) 141 | # check shape 142 | num_dim = len(tensor_value.shape) 143 | 144 | if num_dim == 1: 145 | tensor_value = tensor_value.squeeze() 146 | elif num_dim == 2 and torch_names[-1] == 'weight': 147 | # for normal weight, transpose it 148 | tensor_value = tensor_value.T 149 | elif num_dim == 3 and torch_names[-1] == 'weight' and torch_names[-2] in ['query', 'key', 'value']: 150 | feat_dim, num_heads, head_dim = tensor_value.shape 151 | # for multi head attention q/k/v weight 152 | tensor_value = tensor_value 153 | elif num_dim == 2 and torch_names[-1] == 'bias' and torch_names[-2] in ['query', 'key', 'value']: 154 | # for multi head attention q/k/v bias 155 | 
tensor_value = tensor_value 156 | elif num_dim == 3 and torch_names[-1] == 'weight' and torch_names[-2] == 'out': 157 | # for multi head attention out weight 158 | tensor_value = tensor_value 159 | elif num_dim == 4 and torch_names[-1] == 'weight': 160 | tensor_value = tensor_value.permute(3, 2, 0, 1) 161 | 162 | # print("{}: {}".format(torch_key, tensor_value.shape)) 163 | state_dict[torch_key] = tensor_value 164 | return state_dict 165 | 166 | 167 | if __name__ == '__main__': 168 | save_jax_to_pytorch('/Users/leon/Downloads/jax/imagenet21k+imagenet2012_ViT-L_16-224.npz', '/Users/leon/Downloads/pytorch') 169 | 170 | 171 | -------------------------------------------------------------------------------- /MOODv1/ood_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import sklearn.metrics as skm 4 | from torch.utils.data.dataset import Subset 5 | from scipy import linalg 6 | from sklearn.metrics import roc_curve, auc 7 | import matplotlib.pyplot as plt 8 | from PIL import Image, ImageFile 9 | import cv2 10 | ImageFile.LOAD_TRUNCATED_IMAGES = True 11 | 12 | ## utils ## 13 | def mkdir(path): 14 | if not os.path.exists(path): 15 | os.makedirs(path) 16 | return path 17 | 18 | 19 | ### dataset ### 20 | def get_subclass_dataset(dataset, classes): 21 | if not isinstance(classes, list): 22 | classes = [classes] 23 | indices = [] 24 | for idx, tgt in enumerate(dataset.targets): 25 | if tgt in classes: 26 | indices.append(idx) 27 | subdataset = Subset(dataset, indices) 28 | return subdataset 29 | 30 | 31 | def get_superclass_list(dataset): 32 | CIFAR10_SUPERCLASS = list(range(10)) # one class 33 | IMAGENET_SUPERCLASS = list(range(30)) # one class 34 | CIFAR100_SUPERCLASS = [ 35 | [4, 31, 55, 72, 95], 36 | [1, 33, 67, 73, 91], 37 | [54, 62, 70, 82, 92], 38 | [9, 10, 16, 29, 61], 39 | [0, 51, 53, 57, 83], 40 | [22, 25, 40, 86, 87], 41 | [5, 20, 26, 84, 94], 42 | [6, 7, 14, 18, 24], 43 | [3, 42, 43, 88, 97], 44 | [12, 17, 38, 68, 76], 45 | [23, 34, 49, 60, 71], 46 | [15, 19, 21, 32, 39], 47 | [35, 63, 64, 66, 75], 48 | [27, 45, 77, 79, 99], 49 | [2, 11, 36, 46, 98], 50 | [28, 30, 44, 78, 93], 51 | [37, 50, 65, 74, 80], 52 | [47, 52, 56, 59, 96], 53 | [8, 13, 48, 58, 90], 54 | [41, 69, 81, 85, 89], 55 | ] 56 | if dataset.lower() == 'cifar10': 57 | return CIFAR10_SUPERCLASS 58 | elif dataset.lower() == 'cifar100': 59 | return CIFAR100_SUPERCLASS 60 | elif dataset.lower() == 'imagenet' or dataset.lower() == 'imagenet30': 61 | return IMAGENET_SUPERCLASS 62 | else: 63 | raise NotImplementedError() 64 | 65 | def get_scores_one_cluster(ftrain, ftest, food, args): 66 | methods = { 67 | 'mahalanobis': mahalanobis, # Mahalanobis Distance 68 | 'cos': cosine_similarity, # cosine similarity 69 | 'projection': projection, # projection distance 70 | 'gauss': gauss_distribution, # distribution percentage of gauss distribution 71 | 'kmeans': kmeans, # the distance to the nearest cluster 72 | 'euclidean': euclidean_distance, # euclidean distance 73 | 'minkowski': minkowski_distance, # minkowski distance 74 | 'chebyshev': chebyshev_distance, # chebyshev distance 75 | } 76 | din = methods[args.metric](ftrain, ftest, args) 77 | dood = methods[args.metric](ftrain, food, args) 78 | label = [0] * len(din) + [1] * len(dood) 79 | return din, dood, label 80 | 81 | 82 | def mahalanobis(ftrain, ftest, args): 83 | cov = lambda x: np.cov(x.T, bias=True) 84 | if args.cc or args.avgcc: 85 | dtest = [[] for _ in range(args.nb_classes)] 86 | mean = [[] for _ 
in range(args.nb_classes)] 87 | std = [[] for _ in range(args.nb_classes)] 88 | for i in range(args.nb_classes): 89 | std[i] = np.linalg.pinv(cov(ftrain[i])) 90 | mean[i] = np.mean(ftrain[i], axis=0, keepdims=True) 91 | dtest[i] = np.sum((ftest - mean[i])* (std[i].dot((ftest - mean[i]).T)).T, axis=-1,) 92 | if args.cc: 93 | return np.min(dtest, axis=0) 94 | else: 95 | return np.mean(dtest, axis=0) 96 | 97 | else: 98 | std = np.linalg.pinv(cov(ftrain)) 99 | mean = np.mean(ftrain, axis=0, keepdims=True) 100 | dtest = np.sum((ftest - mean)* (std.dot((ftest - mean).T)).T, axis=-1,) 101 | return dtest 102 | 103 | ## get features ### 104 | def get_features(model, dataloader, name, args, is_train=False, max_num=1e10): 105 | model.eval() 106 | features = [] 107 | for index, (img, label) in enumerate(dataloader): 108 | if index >= max_num: 109 | break 110 | img, label = img.cuda(), label.cuda() 111 | feature = model.forward_features(img) 112 | features += list(feature.data.cpu().numpy()) 113 | if args.class_idx is None and (index + 1) % 100 == 0: 114 | shape = np.array(features).shape 115 | print('{}: ({}, {})/({}, {})'.format(name, index+1, shape[-1], len(dataloader), shape[-1]), end='\r') 116 | print('\n') 117 | features = np.array(features) 118 | return features 119 | 120 | 121 | #### OOD detection #### 122 | def get_roc_sklearn(xin, xood, labels): 123 | data = np.concatenate((xin, xood)) 124 | auroc = skm.roc_auc_score(labels, data) 125 | # import pdb;pdb.set_trace() 126 | return auroc 127 | 128 | 129 | def get_pr_sklearn(xin, xood, labels=None): 130 | data = np.concatenate((xin, xood)) 131 | aupr = skm.average_precision_score(labels, data) 132 | return aupr 133 | 134 | 135 | def get_fpr(xin, xood, labels): 136 | if labels == [0] * len(xin) + [1] * len(xood): 137 | return np.sum(xood < np.percentile(xin, 95)) / len(xood) 138 | elif labels == [1] * len(xin) + [0] * len(xood): 139 | return np.sum(xood > np.percentile(xin, 95)) / len(xood) 140 | else: 141 | raise 142 | 143 | 144 | def projection(ftrain, ftest): 145 | from sklearn.metrics.pairwise import cosine_similarity 146 | matrix_in = cosine_similarity(ftrain, ftest) 147 | mod_in = np.linalg.norm(ftest)**2 148 | din = np.max(matrix_in, axis=1)*mod_in 149 | return din 150 | 151 | 152 | def gauss_distribution(ftrain, ftest): 153 | shrunkcov = True 154 | if shrunkcov: 155 | from sklearn.covariance import ledoit_wolf 156 | print("Using ledoit-wolf covariance estimator.") 157 | cov = lambda x: ledoit_wolf(x)[0] 158 | else: 159 | cov = lambda x: np.cov(x.T, bias=True) 160 | 161 | std = np.linalg.pinv(cov(ftrain)) 162 | mean = np.mean(ftrain, axis=0, keepdims=True) 163 | D = len(ftrain[0]) 164 | 165 | dtest = np.sum((ftest - mean)* (std.dot((ftest - mean).T)).T, axis=-1,) 166 | k = 1 / ((2*np.pi)**(D/2) * np.linalg.norm(std)**0.5) 167 | ptest = k * np.exp(-0.5 * dtest) 168 | return ptest 169 | 170 | def kmeans(ftrain, ftest, ypred, nclusters): 171 | from sklearn.cluster import KMeans 172 | kMeansModel = KMeans(init='k-means++', n_clusters=nclusters, max_iter=100000) 173 | kMeansModel.fit(ftrain) 174 | 175 | distances = kMeansModel.transform(ftest) 176 | inDtC = np.min(distances, axis=1) 177 | 178 | return inDtC 179 | 180 | def cosine_similarity(ftrain, ftest): 181 | from sklearn.metrics.pairwise import cosine_similarity 182 | matrix_in = cosine_similarity(ftrain, ftest) 183 | din = np.max(matrix_in, axis=1) 184 | return din 185 | 186 | def euclidean_distance(ftrain, ftest): 187 | mean = np.mean(ftrain, axis=0, keepdims=True) 188 | dtest = 
np.sqrt(np.sum(np.power(ftest - mean, 2), axis=-1,)) 189 | return dtest 190 | 191 | def minkowski_distance(ftrain, ftest): 192 | mean = np.mean(ftrain, axis=0, keepdims=True) 193 | dtest = np.sum(np.abs(ftest - mean), axis=-1,) 194 | return dtest 195 | 196 | def chebyshev_distance(ftrain, ftest): 197 | mean = np.mean(ftrain, axis=0, keepdims=True) 198 | dtest = np.max(np.abs(ftest - mean), axis=-1,) 199 | return dtest -------------------------------------------------------------------------------- /MOODv1/transforms.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) 3 | # Github source: https://github.com/microsoft/unilm/tree/master/beit 4 | # Copyright (c) 2021 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # By Hangbo Bao 7 | # Based on timm code bases 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # --------------------------------------------------------' 10 | import torch 11 | import torchvision.transforms.functional as F 12 | from PIL import Image 13 | import warnings 14 | import math 15 | import random 16 | import numpy as np 17 | 18 | 19 | class ToNumpy: 20 | 21 | def __call__(self, pil_img): 22 | np_img = np.array(pil_img, dtype=np.uint8) 23 | if np_img.ndim < 3: 24 | np_img = np.expand_dims(np_img, axis=-1) 25 | np_img = np.rollaxis(np_img, 2) # HWC to CHW 26 | return np_img 27 | 28 | 29 | class ToTensor: 30 | 31 | def __init__(self, dtype=torch.float32): 32 | self.dtype = dtype 33 | 34 | def __call__(self, pil_img): 35 | np_img = np.array(pil_img, dtype=np.uint8) 36 | if np_img.ndim < 3: 37 | np_img = np.expand_dims(np_img, axis=-1) 38 | np_img = np.rollaxis(np_img, 2) # HWC to CHW 39 | return torch.from_numpy(np_img).to(dtype=self.dtype) 40 | 41 | 42 | _pil_interpolation_to_str = { 43 | Image.NEAREST: 'PIL.Image.NEAREST', 44 | Image.BILINEAR: 'PIL.Image.BILINEAR', 45 | Image.BICUBIC: 'PIL.Image.BICUBIC', 46 | Image.LANCZOS: 'PIL.Image.LANCZOS', 47 | Image.HAMMING: 'PIL.Image.HAMMING', 48 | Image.BOX: 'PIL.Image.BOX', 49 | } 50 | 51 | 52 | def _pil_interp(method): 53 | if method == 'bicubic': 54 | return Image.BICUBIC 55 | elif method == 'lanczos': 56 | return Image.LANCZOS 57 | elif method == 'hamming': 58 | return Image.HAMMING 59 | else: 60 | # default bilinear, do we want to allow nearest? 61 | return Image.BILINEAR 62 | 63 | 64 | _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) 65 | 66 | 67 | class RandomResizedCropAndInterpolationWithTwoPic: 68 | """Crop the given PIL Image to random size and aspect ratio with random interpolation. 69 | 70 | A crop of random size (default: of 0.08 to 1.0) of the original size and a random 71 | aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop 72 | is finally resized to given size. 73 | This is popularly used to train the Inception networks. 74 | 75 | Args: 76 | size: expected output size of each edge 77 | scale: range of size of the origin size cropped 78 | ratio: range of aspect ratio of the origin aspect ratio cropped 79 | interpolation: Default: PIL.Image.BILINEAR 80 | """ 81 | 82 | def __init__(self, size, second_size=None, scale=(0.08, 1.0), ratio=(3. / 4., 4. 
/ 3.), 83 | interpolation='bilinear', second_interpolation='lanczos'): 84 | if isinstance(size, tuple): 85 | self.size = size 86 | else: 87 | self.size = (size, size) 88 | if second_size is not None: 89 | if isinstance(second_size, tuple): 90 | self.second_size = second_size 91 | else: 92 | self.second_size = (second_size, second_size) 93 | else: 94 | self.second_size = None 95 | if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): 96 | warnings.warn("range should be of kind (min, max)") 97 | 98 | if interpolation == 'random': 99 | self.interpolation = _RANDOM_INTERPOLATION 100 | else: 101 | self.interpolation = _pil_interp(interpolation) 102 | self.second_interpolation = _pil_interp(second_interpolation) 103 | self.scale = scale 104 | self.ratio = ratio 105 | 106 | @staticmethod 107 | def get_params(img, scale, ratio): 108 | """Get parameters for ``crop`` for a random sized crop. 109 | 110 | Args: 111 | img (PIL Image): Image to be cropped. 112 | scale (tuple): range of size of the origin size cropped 113 | ratio (tuple): range of aspect ratio of the origin aspect ratio cropped 114 | 115 | Returns: 116 | tuple: params (i, j, h, w) to be passed to ``crop`` for a random 117 | sized crop. 118 | """ 119 | area = img.size[0] * img.size[1] 120 | 121 | for attempt in range(10): 122 | target_area = random.uniform(*scale) * area 123 | log_ratio = (math.log(ratio[0]), math.log(ratio[1])) 124 | aspect_ratio = math.exp(random.uniform(*log_ratio)) 125 | 126 | w = int(round(math.sqrt(target_area * aspect_ratio))) 127 | h = int(round(math.sqrt(target_area / aspect_ratio))) 128 | 129 | if w <= img.size[0] and h <= img.size[1]: 130 | i = random.randint(0, img.size[1] - h) 131 | j = random.randint(0, img.size[0] - w) 132 | return i, j, h, w 133 | 134 | # Fallback to central crop 135 | in_ratio = img.size[0] / img.size[1] 136 | if in_ratio < min(ratio): 137 | w = img.size[0] 138 | h = int(round(w / min(ratio))) 139 | elif in_ratio > max(ratio): 140 | h = img.size[1] 141 | w = int(round(h * max(ratio))) 142 | else: # whole image 143 | w = img.size[0] 144 | h = img.size[1] 145 | i = (img.size[1] - h) // 2 146 | j = (img.size[0] - w) // 2 147 | return i, j, h, w 148 | 149 | def __call__(self, img): 150 | """ 151 | Args: 152 | img (PIL Image): Image to be cropped and resized. 153 | 154 | Returns: 155 | PIL Image: Randomly cropped and resized image. 
156 | """ 157 | i, j, h, w = self.get_params(img, self.scale, self.ratio) 158 | if isinstance(self.interpolation, (tuple, list)): 159 | interpolation = random.choice(self.interpolation) 160 | else: 161 | interpolation = self.interpolation 162 | if self.second_size is None: 163 | return F.resized_crop(img, i, j, h, w, self.size, interpolation) 164 | else: 165 | return F.resized_crop(img, i, j, h, w, self.size, interpolation), \ 166 | F.resized_crop(img, i, j, h, w, self.second_size, self.second_interpolation) 167 | 168 | def __repr__(self): 169 | if isinstance(self.interpolation, (tuple, list)): 170 | interpolate_str = ' '.join([_pil_interpolation_to_str[x] for x in self.interpolation]) 171 | else: 172 | interpolate_str = _pil_interpolation_to_str[self.interpolation] 173 | format_string = self.__class__.__name__ + '(size={0}'.format(self.size) 174 | format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)) 175 | format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) 176 | format_string += ', interpolation={0}'.format(interpolate_str) 177 | if self.second_size is not None: 178 | format_string += ', second_size={0}'.format(self.second_size) 179 | format_string += ', second_interpolation={0}'.format(_pil_interpolation_to_str[self.second_interpolation]) 180 | format_string += ')' 181 | return format_string 182 | -------------------------------------------------------------------------------- /MOODv1/modeling_pretrain.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) 3 | # Github source: https://github.com/microsoft/unilm/tree/master/beit 4 | # Copyright (c) 2021 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # By Hangbo Bao 7 | # Based on timm and DeiT code bases 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # https://github.com/facebookresearch/deit/ 10 | # --------------------------------------------------------' 11 | import math 12 | import torch 13 | import torch.nn as nn 14 | from functools import partial 15 | import matplotlib.pyplot as plt 16 | from modeling_finetune import Block, _cfg, PatchEmbed, RelativePositionBias 17 | from timm.models.registry import register_model 18 | from timm.models.layers import trunc_normal_ as __call_trunc_normal_ 19 | import numpy as np 20 | 21 | def trunc_normal_(tensor, mean=0., std=1.): 22 | __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std) 23 | 24 | 25 | __all__ = [ 26 | 'beit_base_patch16_224_8k_vocab', 27 | 'beit_large_patch16_224_8k_vocab', 28 | ] 29 | 30 | 31 | class VisionTransformerForMaskedImageModeling(nn.Module): 32 | def __init__(self, img_size=224, patch_size=16, in_chans=3, vocab_size=8192, embed_dim=768, depth=12, 33 | num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., 34 | drop_path_rate=0., norm_layer=None, init_values=None, attn_head_dim=None, 35 | use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, init_std=0.02, **kwargs): 36 | super().__init__() 37 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 38 | 39 | self.patch_embed = PatchEmbed( 40 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 41 | num_patches = self.patch_embed.num_patches 42 | 43 | self.cls_token = nn.Parameter(torch.zeros(1, 1, 
embed_dim)) 44 | self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 45 | if use_abs_pos_emb: 46 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 47 | else: 48 | self.pos_embed = None 49 | self.pos_drop = nn.Dropout(p=drop_rate) 50 | 51 | if use_shared_rel_pos_bias: 52 | self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads) 53 | else: 54 | self.rel_pos_bias = None 55 | 56 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 57 | self.blocks = nn.ModuleList([ 58 | Block( 59 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 60 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 61 | init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None, 62 | attn_head_dim=attn_head_dim, 63 | ) 64 | for i in range(depth)]) 65 | self.norm = norm_layer(embed_dim) 66 | 67 | self.init_std = init_std 68 | self.lm_head = nn.Linear(embed_dim, vocab_size) 69 | 70 | if self.pos_embed is not None: 71 | trunc_normal_(self.pos_embed, std=self.init_std) 72 | trunc_normal_(self.cls_token, std=self.init_std) 73 | trunc_normal_(self.mask_token, std=self.init_std) 74 | trunc_normal_(self.lm_head.weight, std=self.init_std) 75 | self.apply(self._init_weights) 76 | self.fix_init_weight() 77 | 78 | def fix_init_weight(self): 79 | def rescale(param, layer_id): 80 | param.div_(math.sqrt(2.0 * layer_id)) 81 | 82 | for layer_id, layer in enumerate(self.blocks): 83 | rescale(layer.attn.proj.weight.data, layer_id + 1) 84 | rescale(layer.mlp.fc2.weight.data, layer_id + 1) 85 | 86 | def _init_weights(self, m): 87 | if isinstance(m, nn.Linear): 88 | trunc_normal_(m.weight, std=self.init_std) 89 | if isinstance(m, nn.Linear) and m.bias is not None: 90 | nn.init.constant_(m.bias, 0) 91 | elif isinstance(m, nn.LayerNorm): 92 | nn.init.constant_(m.bias, 0) 93 | nn.init.constant_(m.weight, 1.0) 94 | elif isinstance(m, nn.Conv2d): 95 | trunc_normal_(m.weight, std=self.init_std) 96 | if m.bias is not None: 97 | nn.init.constant_(m.bias, 0) 98 | 99 | @torch.jit.ignore 100 | def no_weight_decay(self): 101 | return {'pos_embed', 'cls_token'} 102 | 103 | def get_num_layers(self): 104 | return len(self.blocks) 105 | 106 | 107 | def forward_features(self, x, bool_masked_pos=None): 108 | x = self.patch_embed(x, bool_masked_pos=bool_masked_pos) 109 | batch_size, seq_len, _ = x.size() 110 | 111 | cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 112 | mask_token = self.mask_token.expand(batch_size, seq_len, -1) 113 | 114 | # replace the masked visual tokens by mask_token 115 | if bool_masked_pos is not None: 116 | w = bool_masked_pos.unsqueeze(-1).type_as(mask_token) 117 | x = x * (1 - w) + mask_token * w 118 | 119 | x = torch.cat((cls_tokens, x), dim=1) 120 | if self.pos_embed is not None: 121 | x = x + self.pos_embed 122 | x = self.pos_drop(x) 123 | 124 | rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None 125 | for blk in self.blocks: 126 | x = blk(x, rel_pos_bias=rel_pos_bias) 127 | 128 | return self.norm(x) 129 | 130 | def forward(self, x, bool_masked_pos=None, return_all_tokens=False): 131 | x = self.forward_features(x, bool_masked_pos=bool_masked_pos) 132 | x = x[:, 1:] 133 | # import pdb;pdb.set_trace() 134 | if return_all_tokens: 135 | return self.lm_head(x) 136 | else: 137 | # return the masked tokens 138 | return 
self.lm_head(x[bool_masked_pos]) 139 | 140 | 141 | @register_model 142 | def beit_base_patch16_224_8k_vocab(pretrained=False, **kwargs): 143 | model = VisionTransformerForMaskedImageModeling( 144 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 145 | norm_layer=partial(nn.LayerNorm, eps=1e-6), vocab_size=8192, **kwargs) 146 | model.default_cfg = _cfg() 147 | if pretrained: 148 | checkpoint = torch.load( 149 | kwargs["init_ckpt"], map_location="cpu" 150 | ) 151 | model.load_state_dict(checkpoint["model"]) 152 | return model 153 | 154 | 155 | @register_model 156 | def beit_large_patch16_224_8k_vocab(pretrained=False, **kwargs): 157 | model = VisionTransformerForMaskedImageModeling( 158 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 159 | norm_layer=partial(nn.LayerNorm, eps=1e-6), vocab_size=8192, **kwargs) 160 | model.default_cfg = _cfg() 161 | if pretrained: 162 | checkpoint = torch.load( 163 | kwargs["init_ckpt"], map_location="cpu" 164 | ) 165 | model.load_state_dict(checkpoint["model"]) 166 | return model 167 | -------------------------------------------------------------------------------- /MOODv1/optim_factory.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) 3 | # Github source: https://github.com/microsoft/unilm/tree/master/beit 4 | # Copyright (c) 2021 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # By Hangbo Bao 7 | # Based on timm code bases 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # --------------------------------------------------------' 10 | import torch 11 | from torch import optim as optim 12 | 13 | import json 14 | 15 | try: 16 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD 17 | has_apex = True 18 | except ImportError: 19 | has_apex = False 20 | 21 | 22 | def get_num_layer_for_vit(var_name, num_max_layer): 23 | if var_name in ("cls_token", "mask_token", "pos_embed"): 24 | return 0 25 | elif var_name.startswith("patch_embed"): 26 | return 0 27 | elif var_name.startswith("rel_pos_bias"): 28 | return num_max_layer - 1 29 | elif var_name.startswith("blocks"): 30 | layer_id = int(var_name.split('.')[1]) 31 | return layer_id + 1 32 | else: 33 | return num_max_layer - 1 34 | 35 | 36 | class LayerDecayValueAssigner(object): 37 | def __init__(self, values): 38 | self.values = values 39 | 40 | def get_scale(self, layer_id): 41 | return self.values[layer_id] 42 | 43 | def get_layer_id(self, var_name): 44 | return get_num_layer_for_vit(var_name, len(self.values)) 45 | 46 | 47 | def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None): 48 | parameter_group_names = {} 49 | parameter_group_vars = {} 50 | 51 | for name, param in model.named_parameters(): 52 | if not param.requires_grad: 53 | continue # frozen weights 54 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: 55 | group_name = "no_decay" 56 | this_weight_decay = 0. 
57 | else: 58 | group_name = "decay" 59 | this_weight_decay = weight_decay 60 | if get_num_layer is not None: 61 | layer_id = get_num_layer(name) 62 | group_name = "layer_%d_%s" % (layer_id, group_name) 63 | else: 64 | layer_id = None 65 | 66 | if group_name not in parameter_group_names: 67 | if get_layer_scale is not None: 68 | scale = get_layer_scale(layer_id) 69 | else: 70 | scale = 1. 71 | 72 | parameter_group_names[group_name] = { 73 | "weight_decay": this_weight_decay, 74 | "params": [], 75 | "lr_scale": scale 76 | } 77 | parameter_group_vars[group_name] = { 78 | "weight_decay": this_weight_decay, 79 | "params": [], 80 | "lr_scale": scale 81 | } 82 | 83 | parameter_group_vars[group_name]["params"].append(param) 84 | parameter_group_names[group_name]["params"].append(name) 85 | # print("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) 86 | return list(parameter_group_vars.values()) 87 | 88 | 89 | def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None): 90 | opt_lower = args.opt.lower() 91 | weight_decay = args.weight_decay 92 | if weight_decay and filter_bias_and_bn: 93 | skip = {} 94 | if skip_list is not None: 95 | skip = skip_list 96 | elif hasattr(model, 'no_weight_decay'): 97 | skip = model.no_weight_decay() 98 | parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale) 99 | weight_decay = 0. 100 | else: 101 | parameters = model.parameters() 102 | 103 | if 'fused' in opt_lower: 104 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' 105 | 106 | opt_args = dict(lr=args.lr, weight_decay=weight_decay) 107 | if hasattr(args, 'opt_eps') and args.opt_eps is not None: 108 | opt_args['eps'] = args.opt_eps 109 | if hasattr(args, 'opt_betas') and args.opt_betas is not None: 110 | opt_args['betas'] = args.opt_betas 111 | 112 | opt_split = opt_lower.split('_') 113 | opt_lower = opt_split[-1] 114 | if opt_lower == 'sgd' or opt_lower == 'nesterov': 115 | opt_args.pop('eps', None) 116 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 117 | elif opt_lower == 'momentum': 118 | opt_args.pop('eps', None) 119 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 120 | elif opt_lower == 'adam': 121 | optimizer = optim.Adam(parameters, **opt_args) 122 | elif opt_lower == 'adamw': 123 | optimizer = optim.AdamW(parameters, **opt_args) 124 | elif opt_lower == 'nadam': 125 | from timm.optim.nadam import Nadam 126 | optimizer = Nadam(parameters, **opt_args) 127 | elif opt_lower == 'radam': 128 | from timm.optim.radam import RAdam 129 | optimizer = RAdam(parameters, **opt_args) 130 | elif opt_lower == 'adamp': 131 | from timm.optim.adamp import AdamP 132 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args) 133 | elif opt_lower == 'sgdp': 134 | from timm.optim.sgdp import SGDP 135 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args) 136 | elif opt_lower == 'adadelta': 137 | optimizer = optim.Adadelta(parameters, **opt_args) 138 | elif opt_lower == 'adafactor': 139 | if not args.lr: 140 | opt_args['lr'] = None 141 | from timm.optim.adafactor import Adafactor 142 | optimizer = Adafactor(parameters, **opt_args) 143 | elif opt_lower == 'adahessian': 144 | from timm.optim.adahessian import Adahessian 145 | optimizer = Adahessian(parameters, **opt_args) 146 | elif opt_lower == 'rmsprop': 147 | optimizer = optim.RMSprop(parameters, alpha=0.9, 
momentum=args.momentum, **opt_args) 148 | elif opt_lower == 'rmsproptf': 149 | from timm.optim.rmsprop_tf import RMSpropTF 150 | optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 151 | elif opt_lower == 'novograd': 152 | from timm.optim.novograd import NovoGrad 153 | optimizer = NovoGrad(parameters, **opt_args) 154 | elif opt_lower == 'nvnovograd': 155 | from timm.optim.nvnovograd import NvNovoGrad 156 | optimizer = NvNovoGrad(parameters, **opt_args) 157 | elif opt_lower == 'fusedsgd': 158 | opt_args.pop('eps', None) 159 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 160 | elif opt_lower == 'fusedmomentum': 161 | opt_args.pop('eps', None) 162 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 163 | elif opt_lower == 'fusedadam': 164 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args) 165 | elif opt_lower == 'fusedadamw': 166 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args) 167 | elif opt_lower == 'fusedlamb': 168 | optimizer = FusedLAMB(parameters, **opt_args) 169 | elif opt_lower == 'fusednovograd': 170 | opt_args.setdefault('betas', (0.95, 0.98)) 171 | optimizer = FusedNovoGrad(parameters, **opt_args) 172 | else: 173 | assert False and "Invalid optimizer" 174 | raise ValueError 175 | 176 | if len(opt_split) > 1: 177 | if opt_split[0] == 'lookahead': 178 | from timm.optim.lookahead import Lookahead 179 | optimizer = Lookahead(optimizer) 180 | 181 | return optimizer 182 | -------------------------------------------------------------------------------- /MOODv1/modeling_discrete_vae.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) 3 | # Github source: https://github.com/microsoft/unilm/tree/master/beit 4 | # Copyright (c) 2021 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # By Hangbo Bao 7 | # Based on OpenAI DALL-E and lucidrains' DALLE-pytorch code bases 8 | # https://github.com/openai/DALL-E 9 | # https://github.com/lucidrains/DALLE-pytorch 10 | # --------------------------------------------------------' 11 | from math import sqrt 12 | import os 13 | import torch 14 | from torch import nn, einsum 15 | import torch.nn.functional as F 16 | from einops import rearrange 17 | 18 | 19 | def top_k(logits, thres = 0.5): 20 | num_logits = logits.shape[-1] 21 | k = max(int((1 - thres) * num_logits), 1) 22 | val, ind = torch.topk(logits, k) 23 | probs = torch.full_like(logits, float('-inf')) 24 | probs.scatter_(1, ind, val) 25 | return probs 26 | 27 | 28 | def exists(val): 29 | return val is not None 30 | 31 | 32 | def default(val, d): 33 | return val if exists(val) else d 34 | 35 | 36 | def eval_decorator(fn): 37 | def inner(model, *args, **kwargs): 38 | was_training = model.training 39 | model.eval() 40 | out = fn(model, *args, **kwargs) 41 | model.train(was_training) 42 | return out 43 | return inner 44 | 45 | 46 | class BasicVAE(nn.Module): 47 | 48 | def get_codebook_indices(self, images): 49 | raise NotImplementedError() 50 | 51 | def decode(self, img_seq): 52 | raise NotImplementedError() 53 | 54 | def get_codebook_probs(self, img_seq): 55 | raise NotImplementedError() 56 | 57 | def get_image_tokens_size(self): 58 | pass 59 | 60 | def get_image_size(self): 61 | pass 62 | 63 | 64 | class ResBlock(nn.Module): 65 | def __init__(self, chan_in, 
hidden_size, chan_out): 66 | super().__init__() 67 | self.net = nn.Sequential( 68 | nn.Conv2d(chan_in, hidden_size, 3, padding=1), 69 | nn.ReLU(), 70 | nn.Conv2d(hidden_size, hidden_size, 3, padding=1), 71 | nn.ReLU(), 72 | nn.Conv2d(hidden_size, chan_out, 1) 73 | ) 74 | 75 | def forward(self, x): 76 | return self.net(x) + x 77 | 78 | 79 | class DiscreteVAE(BasicVAE): 80 | def __init__( 81 | self, 82 | image_size = 256, 83 | num_tokens = 512, 84 | codebook_dim = 512, 85 | num_layers = 3, 86 | hidden_dim = 64, 87 | channels = 3, 88 | smooth_l1_loss = False, 89 | temperature = 0.9, 90 | straight_through = False, 91 | kl_div_loss_weight = 0. 92 | ): 93 | super().__init__() 94 | # assert log2(image_size).is_integer(), 'image size must be a power of 2' 95 | assert num_layers >= 1, 'number of layers must be greater than or equal to 1' 96 | 97 | self.image_size = image_size 98 | self.num_tokens = num_tokens 99 | self.num_layers = num_layers 100 | self.temperature = temperature 101 | self.straight_through = straight_through 102 | self.codebook = nn.Embedding(num_tokens, codebook_dim) 103 | 104 | enc_layers = [] 105 | dec_layers = [] 106 | 107 | enc_in = channels 108 | dec_in = codebook_dim 109 | 110 | for layer_id in range(num_layers): 111 | enc_layers.append(nn.Sequential(nn.Conv2d(enc_in, hidden_dim, 4, stride=2, padding=1), nn.ReLU())) 112 | enc_layers.append(ResBlock(chan_in=hidden_dim, hidden_size=hidden_dim, chan_out=hidden_dim)) 113 | enc_in = hidden_dim 114 | dec_layers.append(nn.Sequential(nn.ConvTranspose2d(dec_in, hidden_dim, 4, stride=2, padding=1), nn.ReLU())) 115 | dec_layers.append(ResBlock(chan_in=hidden_dim, hidden_size=hidden_dim, chan_out=hidden_dim)) 116 | dec_in = hidden_dim 117 | 118 | enc_layers.append(nn.Conv2d(hidden_dim, num_tokens, 1)) 119 | dec_layers.append(nn.Conv2d(hidden_dim, channels, 1)) 120 | 121 | self.encoder = nn.Sequential(*enc_layers) 122 | self.decoder = nn.Sequential(*dec_layers) 123 | 124 | self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss 125 | self.kl_div_loss_weight = kl_div_loss_weight 126 | 127 | def get_image_size(self): 128 | return self.image_size 129 | 130 | def get_image_tokens_size(self): 131 | return self.image_size // 8 132 | 133 | @torch.no_grad() 134 | @eval_decorator 135 | def get_codebook_indices(self, images): 136 | logits = self.forward(images, return_logits = True) 137 | codebook_indices = logits.argmax(dim = 1) 138 | return codebook_indices 139 | 140 | @torch.no_grad() 141 | @eval_decorator 142 | def get_codebook_probs(self, images): 143 | logits = self.forward(images, return_logits = True) 144 | return nn.Softmax(dim=1)(logits) 145 | 146 | def decode( 147 | self, 148 | img_seq 149 | ): 150 | image_embeds = self.codebook(img_seq) 151 | b, n, d = image_embeds.shape 152 | h = w = int(sqrt(n)) 153 | 154 | image_embeds = rearrange(image_embeds, 'b (h w) d -> b d h w', h = h, w = w) 155 | images = self.decoder(image_embeds) 156 | return images 157 | 158 | def forward( 159 | self, 160 | img, 161 | return_loss = False, 162 | return_recons = False, 163 | return_logits = False, 164 | temp = None 165 | ): 166 | device, num_tokens, image_size, kl_div_loss_weight = img.device, self.num_tokens, self.image_size, self.kl_div_loss_weight 167 | assert img.shape[-1] == image_size and img.shape[-2] == image_size, f'input must have the correct image size {image_size}' 168 | 169 | logits = self.encoder(img) 170 | 171 | if return_logits: 172 | return logits # return logits for getting hard image indices for DALL-E training 173 | 174 | temp 
= default(temp, self.temperature) 175 | soft_one_hot = F.gumbel_softmax(logits, tau = temp, dim = 1, hard = self.straight_through) 176 | sampled = einsum('b n h w, n d -> b d h w', soft_one_hot, self.codebook.weight) 177 | out = self.decoder(sampled) 178 | 179 | if not return_loss: 180 | return out 181 | 182 | # reconstruction loss 183 | 184 | recon_loss = self.loss_fn(img, out) 185 | 186 | # kl divergence 187 | 188 | logits = rearrange(logits, 'b n h w -> b (h w) n') 189 | qy = F.softmax(logits, dim = -1) 190 | 191 | log_qy = torch.log(qy + 1e-10) 192 | log_uniform = torch.log(torch.tensor([1. / num_tokens], device = device)) 193 | kl_div = F.kl_div(log_uniform, log_qy, None, None, 'batchmean', log_target = True) 194 | 195 | loss = recon_loss + (kl_div * kl_div_loss_weight) 196 | 197 | if not return_recons: 198 | return loss 199 | 200 | return loss, out 201 | 202 | 203 | from dall_e import load_model 204 | 205 | 206 | class Dalle_VAE(BasicVAE): 207 | def __init__(self, window_size): 208 | super().__init__() 209 | self.encoder = None 210 | self.decoder = None 211 | self.window_size = window_size 212 | 213 | def load_model(self, model_dir, device): 214 | self.encoder = load_model(os.path.join(model_dir, "encoder.pkl"), device) 215 | self.decoder = load_model(os.path.join(model_dir, "decoder.pkl"), device) 216 | 217 | def decode(self, img_seq): 218 | bsz = img_seq.size()[0] 219 | img_seq = img_seq.view(bsz, self.window_size, self.window_size) 220 | z = F.one_hot(img_seq, num_classes=self.encoder.vocab_size).permute(0, 3, 1, 2).float() 221 | return self.decoder(z).float() 222 | 223 | def get_codebook_indices(self, images): 224 | z_logits = self.encoder(images) 225 | return torch.argmax(z_logits, axis=1) 226 | 227 | def get_codebook_probs(self, images): 228 | z_logits = self.encoder(images) 229 | return nn.Softmax(dim=1)(z_logits) 230 | 231 | def forward(self, img_seq_prob, no_process=False): 232 | if no_process: 233 | return self.decoder(img_seq_prob.float()).float() 234 | else: 235 | bsz, seq_len, num_class = img_seq_prob.size() 236 | z = img_seq_prob.view(bsz, self.window_size, self.window_size, self.encoder.vocab_size) 237 | return self.decoder(z.permute(0, 3, 1, 2).float()).float() 238 | -------------------------------------------------------------------------------- /MOODv1/engine_for_finetuning.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) 3 | # Github source: https://github.com/microsoft/unilm/tree/master/beit 4 | # Copyright (c) 2021 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # By Hangbo Bao 7 | # Based on timm, DINO and DeiT code bases 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # https://github.com/facebookresearch/deit/ 10 | # https://github.com/facebookresearch/dino 11 | # --------------------------------------------------------' 12 | import math 13 | import sys 14 | from typing import Iterable, Optional 15 | 16 | import torch 17 | 18 | from timm.data import Mixup 19 | from timm.utils import accuracy, ModelEma 20 | 21 | import utils 22 | 23 | 24 | def train_class_batch(model, samples, target, criterion, 25 | pretrain_model=None, l2_alpha=0.0, l2_beta=0.0): 26 | outputs = model(samples) 27 | loss = criterion(outputs, target) 28 | # print(f'ce loss {loss}', end=', ') 29 | if pretrain_model is not None: 30 | reg_loss = 0 31 | for 
params, pretrain_params in zip(model.parameters(), pretrain_model.parameters()): 32 | if params.size() == pretrain_params.size(): 33 | # parts that share the architecture 34 | delta_param = params - pretrain_params 35 | reg_loss += 0.5 * l2_alpha * delta_param.norm(2)**2 36 | else: # parts that differ 37 | reg_loss += 0.5 * l2_beta * params.norm(2)**2 38 | loss += reg_loss 39 | # print(f'reg loss {reg_loss}, total loss {loss}') 40 | return loss, outputs 41 | 42 | 43 | def get_loss_scale_for_deepspeed(model): 44 | optimizer = model.optimizer 45 | return optimizer.loss_scale if hasattr(optimizer, "loss_scale") else optimizer.cur_scale 46 | 47 | 48 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 49 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 50 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 51 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None, 52 | start_steps=None, lr_schedule_values=None, wd_schedule_values=None, 53 | num_training_steps_per_epoch=None, update_freq=None, 54 | pretrain_model=None, l2_alpha=0.0, l2_beta=0.0,): 55 | model.train(True) 56 | metric_logger = utils.MetricLogger(delimiter=" ") 57 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 58 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 59 | header = 'Epoch: [{}]'.format(epoch) 60 | print_freq = 10 61 | 62 | if loss_scaler is None: 63 | model.zero_grad() 64 | model.micro_steps = 0 65 | else: 66 | optimizer.zero_grad() 67 | 68 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 69 | step = data_iter_step // update_freq 70 | if step >= num_training_steps_per_epoch: 71 | continue 72 | it = start_steps + step # global training iteration 73 | # Update LR & WD for the first acc 74 | if lr_schedule_values is not None or wd_schedule_values is not None and data_iter_step % update_freq == 0: 75 | for i, param_group in enumerate(optimizer.param_groups): 76 | if lr_schedule_values is not None: 77 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"] 78 | if wd_schedule_values is not None and param_group["weight_decay"] > 0: 79 | param_group["weight_decay"] = wd_schedule_values[it] 80 | 81 | samples = samples.to(device, non_blocking=True) 82 | targets = targets.to(device, non_blocking=True) 83 | 84 | if mixup_fn is not None: 85 | samples, targets = mixup_fn(samples, targets) 86 | 87 | if loss_scaler is None: 88 | samples = samples.half() 89 | loss, output = train_class_batch( 90 | model, samples, targets, criterion, 91 | pretrain_model=pretrain_model, l2_alpha=l2_alpha, l2_beta=l2_beta,) 92 | else: 93 | with torch.cuda.amp.autocast(): 94 | loss, output = train_class_batch( 95 | model, samples, targets, criterion, 96 | pretrain_model=pretrain_model, l2_alpha=l2_alpha, l2_beta=l2_beta,) 97 | 98 | loss_value = loss.item() 99 | 100 | if not math.isfinite(loss_value): 101 | print("Loss is {}, stopping training".format(loss_value)) 102 | sys.exit(1) 103 | 104 | if loss_scaler is None: 105 | loss /= update_freq 106 | model.backward(loss) 107 | model.step() 108 | 109 | if (data_iter_step + 1) % update_freq == 0: 110 | # model.zero_grad() 111 | # Deepspeed will call step() & model.zero_grad() automatic 112 | if model_ema is not None: 113 | model_ema.update(model) 114 | grad_norm = None 115 | loss_scale_value = get_loss_scale_for_deepspeed(model) 116 | else: 117 | # this attribute is added 
by timm on one optimizer (adahessian) 118 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 119 | loss /= update_freq 120 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm, 121 | parameters=model.parameters(), create_graph=is_second_order, 122 | update_grad=(data_iter_step + 1) % update_freq == 0) 123 | if (data_iter_step + 1) % update_freq == 0: 124 | optimizer.zero_grad() 125 | if model_ema is not None: 126 | model_ema.update(model) 127 | loss_scale_value = loss_scaler.state_dict()["scale"] 128 | 129 | torch.cuda.synchronize() 130 | 131 | if mixup_fn is None: 132 | class_acc = (output.max(-1)[-1] == targets).float().mean() 133 | else: 134 | class_acc = None 135 | metric_logger.update(loss=loss_value) 136 | metric_logger.update(class_acc=class_acc) 137 | metric_logger.update(loss_scale=loss_scale_value) 138 | min_lr = 10. 139 | max_lr = 0. 140 | for group in optimizer.param_groups: 141 | min_lr = min(min_lr, group["lr"]) 142 | max_lr = max(max_lr, group["lr"]) 143 | 144 | metric_logger.update(lr=max_lr) 145 | metric_logger.update(min_lr=min_lr) 146 | weight_decay_value = None 147 | for group in optimizer.param_groups: 148 | if group["weight_decay"] > 0: 149 | weight_decay_value = group["weight_decay"] 150 | metric_logger.update(weight_decay=weight_decay_value) 151 | metric_logger.update(grad_norm=grad_norm) 152 | 153 | if log_writer is not None: 154 | log_writer.update(loss=loss_value, head="loss") 155 | log_writer.update(class_acc=class_acc, head="loss") 156 | log_writer.update(loss_scale=loss_scale_value, head="opt") 157 | log_writer.update(lr=max_lr, head="opt") 158 | log_writer.update(min_lr=min_lr, head="opt") 159 | log_writer.update(weight_decay=weight_decay_value, head="opt") 160 | log_writer.update(grad_norm=grad_norm, head="opt") 161 | 162 | log_writer.set_step() 163 | 164 | # gather the stats from all processes 165 | metric_logger.synchronize_between_processes() 166 | print("Averaged stats:", metric_logger) 167 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 168 | 169 | 170 | @torch.no_grad() 171 | def evaluate(data_loader, model, device): 172 | criterion = torch.nn.CrossEntropyLoss() 173 | 174 | metric_logger = utils.MetricLogger(delimiter=" ") 175 | header = 'Test:' 176 | 177 | # switch to evaluation mode 178 | model.eval() 179 | 180 | for batch in metric_logger.log_every(data_loader, 10, header): 181 | images = batch[0] 182 | target = batch[-1] 183 | images = images.to(device, non_blocking=True) 184 | target = target.to(device, non_blocking=True) 185 | 186 | # compute output 187 | with torch.cuda.amp.autocast(): 188 | output = model(images) 189 | loss = criterion(output, target) 190 | 191 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 192 | 193 | batch_size = images.shape[0] 194 | metric_logger.update(loss=loss.item()) 195 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 196 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 197 | # gather the stats from all processes 198 | metric_logger.synchronize_between_processes() 199 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 200 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 201 | 202 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 203 | -------------------------------------------------------------------------------- /MOODv1/README.md: 
-------------------------------------------------------------------------------- 1 | # MOODv1 2 | 3 |

4 | • 🤗 Model 5 | • 🐱 Code 6 | • 📃 MOODv1 7 | • 📃 MOODv2
8 |

9 | ## Abstract 10 | The core of out-of-distribution (OOD) detection is to learn the in-distribution (ID) representation, which is distinguishable from OOD samples. Previous work applied recognition-based methods to learn the ID features, which tend to learn shortcuts instead of comprehensive representations. In this work, we find surprisingly that simply using reconstruction-based methods could boost the performance of OOD detection significantly. We deeply explore the main contributors of OOD detection and find that reconstruction-based pretext tasks have the potential to provide a generally applicable and efficacious prior, which benefits the model in learning intrinsic data distributions of the ID dataset. Specifically, we take Masked Image Modeling as a pretext task for our OOD detection framework (MOOD). Without bells and whistles, MOOD outperforms previous SOTA of one-class OOD detection by 5.7%, multi-class OOD detection by 3.0%, and near-distribution OOD detection by 2.1%. It even defeats the 10-shot-per-class outlier exposure OOD detection, although we do not include any OOD samples for our detection. 11 | 12 | ## Setup 13 | Follow official [BEiT](https://github.com/microsoft/unilm/tree/master/beit) to setup. 14 | 15 | ## Datasets 16 | We suggest to organize datasets as following 17 | ```bash 18 | - MOOD 19 | - data 20 | - cifar10 21 | - cifar-10-batches-py 22 | - cifar100 23 | - cifar-100-python 24 | - imagenet30 25 | - test 26 | - train 27 | - val 28 | - imagenet1k 29 | - test 30 | - train 31 | - val 32 | - $OOD_DATASET 33 | - images 34 | ... 35 | ``` 36 | In this case, for example, if you want to train on CIFAR-10, set the parameters `-- data_path ./data/cifar10 --data_set cifar10`. 37 | 38 | We provide `datasets/imagenet30.py` for you to create soft link for `imagenet30`. 39 | 40 | ## Pretrained models 41 | 42 | Follow [BEiT](https://github.com/microsoft/unilm/tree/master/beit) to pre-train the model or directly utilize the official released weights pretrained on ImageNet-22k. The models were pretrained with 224x224 resolution. 
43 | - `BEiT-base`: #layer=12; hidden=768; FFN factor=4x; #head=12; patch=16x16 (#parameters: 86M) 44 | - `BEiT-large`: #layer=24; hidden=1024; FFN factor=4x; #head=16; patch=16x16 (#parameters: 304M) 45 | 46 | Download checkpoints that are **self-supervised pretrained and then intermediate fine-tuned** on ImageNet-22k (recommended): 47 | - BEiT-base: [beit_base_patch16_224_pt22k_ft22k](https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22k.pth) 48 | - BEiT-large: [beit_large_patch16_224_pt22k_ft22k](https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth) 49 | 50 | Download checkpoints that are **self-supervised pretrained** on ImageNet-22k: 51 | - BEiT-base: [beit_base_patch16_224_pt22k](https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k.pth) 52 | - BEiT-large: [beit_large_patch16_224_pt22k](https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k.pth) 53 | 54 | ## Fine-tuning on In-Distribution Dataset 55 | ### Multi-Class Fine-tuning 56 | For ViT-large, 57 | ```bash 58 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 run_class_finetuning.py \ 59 | --model beit_large_patch16_224 --data_path $ID_DATA_PATH --data_set $ID_DATASET \ 60 | --finetune https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k.pth \ 61 | --batch_size 8 --lr 2e-5 --update_freq 2 \ 62 | --warmup_epochs 5 --epochs 100 --layer_decay 0.9 --drop_path 0.4 \ 63 | --weight_decay 1e-8 --enable_deepspeed 64 | ``` 65 | The hyper-parameters are the same with the official [BEiT](https://github.com/microsoft/unilm/tree/master/beit). 66 | 67 | ### One-class Fine-tuning 68 | For one-class fine-tuning, please assign a class as in-distribution by adding command '--class_idx $CLASS_IDX'. Others are out-of-distribution. We support three in-distribution datasets, including `['cifar100', 'cifar10' and 'imagenet30']`. Noted that we only fine-tuned one-class imagenet30 in the original paper. 
69 | For ViT-large, 70 | ```bash 71 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 run_class_finetuning.py \ 72 | --model beit_large_patch16_224 --data_path $ID_DATA_PATH --data_set $ID_DATASET \ 73 | --finetune https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k.pth \ 74 | --batch_size 8 --lr 2e-5 --update_freq 2 \ 75 | --warmup_epochs 5 --epochs 100 --layer_decay 0.9 --drop_path 0.4 \ 76 | --weight_decay 1e-8 --enable_deepspeed --class_idx $CLASS_IDX 77 | ``` 78 | 79 | ## OOD detection 80 | ### Multi-Class OOD Detection 81 | With OOD detection metric using **features**, we support `['mahalanobis', 'cos', 'projection', 'gauss', 'kmeans', 'euclidean', 'minkowski', 'chebyshev']` with the following command 82 | ```bash 83 | python eval_with_features.py --ckpt $CKPT_PATH --data_set $ID_DATASET --ood_dataset $OOD_DATASET --ood_data_path $OOD_DATA_PATH --metric $OOD_METRIC 84 | ``` 85 | With OOD detection metric using **logits**, we support `['softmax', 'entropy', 'energy', 'gradnorm']` with the following command 86 | ```bash 87 | python eval_with_logits.py --ckpt $CKPT_PATH --data_set $ID_DATASET --ood_dataset $OOD_DATASET --ood_data_path $OOD_DATA_PATH --metric $OOD_METRIC 88 | ``` 89 | 90 | ### One-Class OOD Detection 91 | For one-class OOD detection, please assign a class as in-distribution by adding command '--class_idx $CLASS_IDX'. Others are out-of-distribution. We support three in-distribution datasets, including `['cifar100', 'cifar10' and 'imagenet30']`. 92 | With OOD detection metric using **features**, we support `['mahalanobis', 'cos', 'projection', 'gauss', 'kmeans', 'euclidean', 'minkowski', 'chebyshev']` with the following command 93 | ```bash 94 | python eval_with_features.py --ckpt $CKPT_PATH --data_set $ID_DATASET --metric $OOD_METRIC --class_idx $CLASS_IDX 95 | ``` 96 | With OOD detection metric using **logits**, we support `['softmax', 'entropy', 'energy', 'gradnorm']` with the following command 97 | ```bash 98 | python eval_with_logits.py --ckpt $CKPT_PATH --data_set $ID_DATASET --metric $OOD_METRIC --class_idx $CLASS_IDX 99 | ``` 100 | 101 | ## Results 102 | ### Multi-class OOD detection 103 | For CIFAR-10, 104 | | CIFAR-10 | SVHN | CIFAR-100 | LSUN | Avg | 105 | |:--------: |:--------: |:---------: |:--------: |:-----: | 106 | | [ckpt](https://drive.google.com/file/d/1b_uWi2bty3tyspxEEM4jtyCh3WR9FpYm/view?usp=share_link), [distances](https://drive.google.com/drive/folders/1MeoEHArSeHc7D35-9vGEt782GKphU2PM?usp=share_link) | 99.8 | 99.4 | 99.9 | 99.7 | 107 | 108 | For CIFAR-100, 109 | | CIFAR-100 | SVHN | CIFAR-10 | LSUN | Avg | 110 | |:---------: |:--------: |:--------: |:--------: |:-----: | 111 | | [ckpt](https://drive.google.com/file/d/1MCPUTnz5DjNmR8gyWGMAl7qW811cH13X/view?usp=share_link), [distances](https://drive.google.com/drive/folders/1CV4kpb3OKiCj9uN9fMj1reG59RPDcF_X?usp=share_link) | 96.5 | 98.3 | 96.3 | 97.0 | 112 | 113 | For ImageNet-30, 114 | | ImageNet30 | Dogs | Places365 | Flowers102 | Pets | Food | Caltech256 | Dtd | Avg | 115 | |:----------: |:--------: |:---------: |:----------: |:--------: |:--------: |:----------: |:--------: |:-----: | 116 | | [ckpt](https://drive.google.com/file/d/1nTOimKRcNHlT_hKNfWJejDkYyDs4xd63/view?usp=share_link), [distances](https://drive.google.com/drive/folders/1CH3TjnohalbUIWgNcKzM3Swzy5BIo--f?usp=share_link) | 99.40 | 98.90 | 100.00 | 99.10 | 96.60 | 99.50 | 98.9 | 98.9 | 117 | 118 | For ImageNet-1k, 119 | | ImageNet1k | iNaturalist | SUN | Places | 
Textures | Average | 120 | |:----------: |:-----------: |:--------: |:--------: |:--------: |:-------: | 121 | | [ckpt](https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft1k.pth), [distances](https://drive.google.com/drive/folders/1-JT_81-a8mRMc_jikeuVKJYDbLCJz_mr?usp=share_link) | 86.9 | 89.8 | 88.5 | 91.3 | 89.1 | 122 | 123 | ### One-class OOD detection 124 | For CIFAR-10, 125 | | Method | Airplane | Automobile | Bird | Cat | Dear | Dog | Frog | Horse | Ship | Truck | 126 | |:---------: |:--------: |:----------: |:--------: |:--------: |:--------: |:--------: |:--------: |:--------: |:--------: |:--------: | 127 | | ours | 98.6 | 99.3 | 94.3 | 93.2 | 98.1 | 96.5 | 99.3 | 99.0 | 98.8 | 97.8 | 128 | 129 | For CIFAR-100, 130 | | Method | AUROC | 131 | |:---------: |:--------: | 132 | | ours | 96.4 | 133 | 134 | For ImageNet-30, 135 | | Method | AUROC | 136 | |:---------------------: |:--------: | 137 | | ours | 92.0 | 138 | 139 | ## Acknowledgement 140 | 141 | This repository is built using the [beit](https://github.com/microsoft/unilm/tree/master/beit) library and the [SSD](https://github.com/inspire-group/SSD) repository. 142 | 143 | ## Citation 144 | If you find our research helpful, kindly cite 145 | ``` 146 | @inproceedings{li2023rethinking, 147 | title={Rethinking Out-of-distribution (OOD) Detection: Masked Image Modeling is All You Need}, 148 | author={Li, Jingyao and Chen, Pengguang and He, Zexin and Yu, Shaozuo and Liu, Shu and Jia, Jiaya}, 149 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 150 | pages={11578--11589}, 151 | year={2023} 152 | } 153 | ``` -------------------------------------------------------------------------------- /MOODv2/README.md: -------------------------------------------------------------------------------- 1 | # Official code for MOODv2: Masked Image Modeling for Out-of-Distribution Detection 2 | 3 |

4 | • 🤗 Model 5 | • 🐱 Code 6 | • 📃 MOODv1 7 | • 📃 MOODv2
8 |

9 | 10 |

11 | ![framework](imgs/framework.png) 12 |

13 | 14 | ## Abstract 15 | The crux of effective out-of-distribution (OOD) detection lies in acquiring a robust in-distribution (ID) representation, distinct from OOD samples. While previous methods predominantly leaned on recognition-based techniques for this purpose, they often resulted in shortcut learning, lacking comprehensive representations. In our study, we conducted a comprehensive analysis, exploring distinct pretraining tasks and employing various OOD score functions. The results highlight that the feature representations pre-trained through reconstruction yield a notable enhancement and narrow the performance gap among various score functions. This suggests that even simple score functions can rival complex ones when leveraging reconstruction-based pretext tasks. Reconstruction-based pretext tasks adapt well to various score functions. As such, it holds promising potential for further expansion. Our OOD detection framework, MOODv2, employs the masked image modeling pretext task. Without bells and whistles, MOODv2 impressively enhances 14.30% AUROC to 95.68% on ImageNet and achieves 99.98% on CIFAR-10. 16 | 17 | 18 | ## Performance 19 |

20 | ![table](imgs/moodv2_table.png) 21 |
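The abstract above argues that, with reconstruction-based pretraining, even simple score functions become competitive with more elaborate ones. As a concrete reference, below is a minimal, illustrative sketch (not part of this repository) of two of the simplest scores supported by the demo further down, MSP and Energy, computed directly from classifier logits; the function names and the placeholder `logits` array are illustrative and do not appear in the codebase.

```python
# Illustrative sketch only (not repository code): two simple OOD score
# functions computed from classifier logits. For both, a higher score means
# the sample looks more in-distribution.
import numpy as np
from scipy.special import softmax, logsumexp


def msp_score(logits: np.ndarray) -> np.ndarray:
    """Maximum Softmax Probability: confidence of the predicted class."""
    return softmax(logits, axis=-1).max(axis=-1)


def energy_score(logits: np.ndarray) -> np.ndarray:
    """Energy score: log-sum-exp over the class logits."""
    return logsumexp(logits, axis=-1)


if __name__ == '__main__':
    rng = np.random.default_rng(0)
    logits = rng.normal(size=(4, 1000))  # stand-in for real model outputs
    print(msp_score(logits))
    print(energy_score(logits))
```

With this sign convention both scores increase for in-distribution samples, which matches how the demo below compares a score against a percentile of the ID validation scores.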

22 | 23 | 24 | ## Datasets 25 | Dataset source can be downloaded here. 26 | - [ImageNet](https://www.image-net.org/). The ILSVRC 2012 dataset as In-distribution (ID) dataset. The training subset is [this file](datalists/imagenet2012_train_random_200k.txt). 27 | - [OpenImage-O](https://github.com/openimages/dataset/blob/main/READMEV3.md). The OpenImage-O dataset is a subset of the OpenImage-V3 testing set. The filelist is [here](datalists/openimage_o.txt). 28 | - [Texture](https://www.robots.ox.ac.uk/~vgg/data/dtd/). The filelist ruled out four classes that coincides with ImageNet is [here](datalists/texture.txt). 29 | - [iNaturalist](https://arxiv.org/pdf/1707.06642.pdf). Follow the instructions in the [link](https://github.com/deeplearning-wisc/large_scale_ood) to prepare the iNaturalist OOD dataset. 30 | - [ImageNet-O](https://github.com/hendrycks/natural-adv-examples). Follow the guide to download the ImageNet-O OOD dataset. 31 | 32 | ```bash 33 | mkdir data 34 | cd data 35 | ln -s /path/to/imagenet imagenet 36 | ln -s /path/to/openimage_o openimage_o 37 | ln -s /path/to/texture texture 38 | ln -s /path/to/inaturalist inaturalist 39 | ln -s /path/to/imagenet_o imagenet_o 40 | cd .. 41 | ``` 42 | 43 | ## Environment 44 | Please follow the instruction in [mmpretrain](https://github.com/open-mmlab/mmpretrain) for environment preparation. 45 | 46 | ## Demo 47 | To predict an input image is in-distribution or out-of-distribution, we support the following OOD detection methods: 48 | - `MSP` 49 | - `MaxLogit` 50 | - `Energy` 51 | - `Energy+React` 52 | - `ViM` 53 | - `Residual` 54 | - `GradNorm` 55 | - `Mahalanobis` 56 | - `KL-Matching` 57 | 58 | ### Example Usage 1 59 | 60 | **Step 1: Download the features and logits** 61 | ```bash 62 | git clone https://huggingface.co/JingyaoLi/MOODv2 63 | cd MOODv2 64 | git lfs pull 65 | ``` 66 | 67 | **Step 2: Detect your image** 68 | ```bash 69 | python src/demo.py \ 70 | --img_path imgs/DTD_cracked_0004.jpg \ 71 | --cfg configs/beit-base-p16_224px.py \ 72 | --checkpoint pretrain/beitv2-base.pth \ 73 | --fc data/fc.pkl \ 74 | --id_train_feature data/imagenet_train.pkl \ 75 | --id_val_feature data/imagenet_test.pkl \ 76 | --methods MSP MaxLogit Energy Energy+React ViM Residual GradNorm Mahalanobis 77 | ``` 78 | 79 | For the example OOD image `imgs/DTD_cracked_0004.jpg`, you are supposed to get: 80 | ``` 81 | MSP evaluation: out-of-distribution 82 | MaxLogit evaluation: out-of-distribution 83 | Energy evaluation: out-of-distribution 84 | Energy+React evaluation: out-of-distribution 85 | ViM evaluation: out-of-distribution 86 | Residual evaluation: out-of-distribution 87 | GradNorm evaluation: out-of-distribution 88 | Mahalanobis evaluation: out-of-distribution 89 | ``` 90 | 91 | ### Example Usage 2 92 | In case you want to extract the features and logits from the vision encoder without downloading our preprocessed features and logits: 93 | 94 | **Step 1: Download the checkpoint of the vision encoder** 95 | ```bash 96 | mkdir pretrain & cd pretrain 97 | wget https://huggingface.co/JingyaoLi/MOODv2/resolve/main/pretrain/beitv2-base.pth 98 | cd .. 
99 | ``` 100 | 101 | **Step 2: Extract the features and logits from the vision encoder** 102 | ```bash 103 | # ID train features 104 | python src/extract_feature_vit.py $IMAGENET_PATH \ 105 | --out_file outputs/imagenet_train.pkl \ 106 | --cfg configs/beit-base-p16_224px.py \ 107 | --checkpoint pretrain/beitv2-base.pth \ 108 | --img_list datalists/imagenet2012_train_random_200k.txt 109 | 110 | # ID test features 111 | python src/extract_feature_vit.py $IMAGENET_PATH \ 112 | --out_file outputs/imagenet_test.pkl \ 113 | --cfg configs/beit-base-p16_224px.py \ 114 | --checkpoint pretrain/beitv2-base.pth \ 115 | --img_list datalists/imagenet2012_val_list.txt 116 | 117 | # Logits 118 | python src/extract_feature_vit.py $IMAGENET_PATH \ 119 | --cfg configs/beit-base-p16_224px.py \ 120 | --checkpoint pretrain/beitv2-base.pth \ 121 | --fc outputs/fc.pkl \ 122 | ``` 123 | 124 | **Step 3: Detect your image** 125 | ```bash 126 | python src/demo.py \ 127 | --img_path imgs/DTD_cracked_0004.jpg \ 128 | --cfg configs/beit-base-p16_224px.py \ 129 | --checkpoint pretrain/beitv2-base.pth \ 130 | --fc outputs/fc.pkl \ 131 | --id_train_feature outputs/imagenet_train.pkl \ 132 | --id_val_feature outputs/imagenet_test.pkl \ 133 | --methods MSP MaxLogit Energy Energy+React ViM Residual GradNorm Mahalanobis 134 | ``` 135 | 136 | ## OOD Detection Benchmark 137 | **Step 1: Download the checkpoint of the vision encoder** 138 | | Name | Paper | Config | Checkpoint | Train/Test Command | 139 | |:------:|:-------:|:-------:|:-------:|:-------:| 140 | | BEiT | [paper](https://arxiv.org/abs/2106.08254) | [config](configs/beit-base-p16_224px.py) | [ckpt](https://download.openmmlab.com/mmclassification/v0/beit/beit-base_3rdparty_in1k_20221114-c0a4df23.pth) | [README](https://github.com/open-mmlab/mmpretrain/tree/main/configs/beit) | 141 | | BEiTv2 | [paper](https://arxiv.org/abs/2208.06366) | [config](configs/beit-base-p16_224px.py) | [ckpt](https://download.openmmlab.com/mmclassification/v0/beit/beitv2-base_3rdparty_in1k_20221114-73e11905.pth) | [README](https://github.com/open-mmlab/mmpretrain/tree/main/configs/beitv2) | 142 | | ViT | [paper](https://arxiv.org/abs/2010.11929) | [config](configs/vit-base-p16_224px.py) | [ckpt](https://download.openmmlab.com/mmclassification/v0/vit/vit-base-p16_pt-32xb128-mae_in1k_20220623-4c544545.pth) | [README](https://github.com/open-mmlab/mmpretrain/tree/main/configs/vision_transformer) | 143 | | MoCov3 | [paper](https://arxiv.org/abs/2104.02057) | [config](configs/vit-base-p16_224px.py) | [ckpt](https://download.openmmlab.com/mmselfsup/1.x/mocov3/mocov3_vit-base-p16_16xb256-amp-coslr-300e_in1k/vit-base-p16_ft-8xb64-coslr-150e_in1k/vit-base-p16_ft-8xb64-coslr-150e_in1k_20220826-f1e6c442.pth) | [README](https://github.com/open-mmlab/mmpretrain/tree/main/configs/mocov3) | 144 | | DINOv2 | [paper](https://arxiv.org/abs/2304.07193) | [config](configs/vit-base-p14_224px.py) | [ckpt](https://download.openmmlab.com/mmpretrain/v1.0/dinov2/vit-base-p14_dinov2-pre_3rdparty_20230426-ba246503.pth) | [README](https://github.com/open-mmlab/mmpretrain/tree/main/configs/dinov2) | 145 | 146 | 147 | **Step 2: Extract the features and logits from the vision encoder** 148 | 149 | Step 2.1 Extract features 150 | ```bash 151 | python src/extract_feature_vit.py $DATA_ROOT $OUT_FILE --cfg $CFG --checkpoint $CHECKPOINT --img_list $IMG_LIST 152 | ``` 153 | e.g. 
154 | ```bash 155 | python extract_feature_vit.py data/imagenet outputs/vit_imagenet_val.pkl --cfg $CFG --checkpoint $CHECKPOINT --img_list datalists/imagenet2012_val_list.txt 156 | python extract_feature_vit.py data/imagenet outputs/vit_train_200k.pkl --cfg $CFG --checkpoint $CHECKPOINT --img_list datalists/imagenet2012_train_random_200k.txt 157 | python extract_feature_vit.py data/openimage_o outputs/vit_openimage_o.pkl --cfg $CFG --checkpoint $CHECKPOINT --img_list datalists/openimage_o.txt 158 | python extract_feature_vit.py data/texture outputs/vit_texture.pkl --cfg $CFG --checkpoint $CHECKPOINT --img_list datalists/texture.txt 159 | python extract_feature_vit.py data/inaturalist outputs/vit_inaturalist.pkl --cfg $CFG --checkpoint $CHECKPOINT 160 | python extract_feature_vit.py data/imagenet_o outputs/vit_imagenet_o.pkl --cfg $CFG --checkpoint $CHECKPOINT 161 | python extract_feature_vit.py data/cifar10 outputs/vit_cifar10_train.pkl --cfg $CFG --checkpoint $CHECKPOINT --img_list datalists/cifar10_train.txt 162 | python extract_feature_vit.py data/cifar10 outputs/vit_cifar10_test.pkl --cfg $CFG --checkpoint $CHECKPOINT --img_list datalists/cifar10_test.txt 163 | ``` 164 | 165 | Step 2.2 Extract weights and bias in fc 166 | ```bash 167 | python src/extract_feature_vit.py $DATA_ROOT $OUT_FILE --cfg $CFG --checkpoint $CHECKPOINT --fc_save_path $FC_SAVE_PATH 168 | ``` 169 | e.g. 170 | ```bash 171 | python src/extract_feature_vit.py $DATA_ROOT $OUT_FILE --cfg $CFG --checkpoint $CHECKPOINT --fc_save_path outputs/vit_fc.pkl 172 | ``` 173 | 174 | **Step 3: Evaluation** 175 | ```bash 176 | python src/benchmark.py $FC_SAVE_PATH $ID_DATA $ID_TRAIN_FEATURE $ID_VAL_FEATURE $OOD_FEATURE 177 | ``` 178 | e.g. 179 | ```bash 180 | python src/benchmark.py outputs/vit_fc.pkl outputs/vit_train_200k.pkl outputs/vit_imagenet_val.pkl outputs/vit_openimage_o.pkl outputs/vit_texture.pkl outputs/vit_inaturalist.pkl outputs/vit_imagenet_o.pkl 181 | python src/benchmark.py outputs/vit_fc.pkl outputs/vit_cifar10_train.pkl outputs/vit_cifar10_test.pkl outputs/vit_openimage_o.pkl outputs/vit_texture.pkl outputs/vit_inaturalist.pkl outputs/vit_imagenet_o.pkl 182 | ``` 183 | 184 | ## Acknowledgement 185 | Part of the code is modified from [ViM](https://github.com/haoqiwang/vim) repo. 
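As a reference for interpreting the evaluation in Step 3, the following is a minimal sketch, not the repository's `src/benchmark.py`, of how the standard AUROC and FPR95 metrics can be computed from per-sample ID/OOD scores, assuming higher scores indicate in-distribution; the helper names and the placeholder score arrays are illustrative only.

```python
# Illustrative sketch only (not src/benchmark.py): standard OOD detection
# metrics, given per-sample scores where higher = more in-distribution.
import numpy as np
from sklearn import metrics


def auroc(score_id: np.ndarray, score_ood: np.ndarray) -> float:
    """Area under the ROC curve, treating ID samples as the positive class."""
    labels = np.concatenate([np.ones(len(score_id)), np.zeros(len(score_ood))])
    scores = np.concatenate([score_id, score_ood])
    return metrics.roc_auc_score(labels, scores)


def fpr_at_95_tpr(score_id: np.ndarray, score_ood: np.ndarray) -> float:
    """Fraction of OOD samples still accepted when 95% of ID samples are kept."""
    threshold = np.percentile(score_id, 5)  # 95% of ID scores lie above this
    return float(np.mean(score_ood >= threshold))


if __name__ == '__main__':
    rng = np.random.default_rng(0)
    score_id = rng.normal(loc=1.0, size=1000)   # placeholder ID scores
    score_ood = rng.normal(loc=0.0, size=1000)  # placeholder OOD scores
    print(f'AUROC: {auroc(score_id, score_ood):.4f}')
    print(f'FPR95: {fpr_at_95_tpr(score_id, score_ood):.4f}')
```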
186 | 187 | -------------------------------------------------------------------------------- /MOODv2/src/demo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import json 4 | from os.path import basename, splitext 5 | import os 6 | import mmengine 7 | import numpy as np 8 | import pandas as pd 9 | import torch 10 | from numpy.linalg import norm, pinv 11 | from scipy.special import logsumexp, softmax 12 | from sklearn import metrics 13 | from sklearn.covariance import EmpiricalCovariance 14 | from sklearn.metrics import pairwise_distances_argmin_min 15 | from tqdm import tqdm 16 | import pickle 17 | from os.path import dirname 18 | import torchvision as tv 19 | from PIL import Image 20 | from mmpretrain.apis import init_model 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser(description='Detect an image') 24 | parser.add_argument( 25 | '--cfg', help='Path to config', 26 | default='/dataset/jingyaoli/AD/MOOD_/MOODv2/configs/beit-base-p16_224px.py') 27 | parser.add_argument('--ood_feature', 28 | default=None, help='Path to ood feature file') 29 | parser.add_argument( 30 | '--checkpoint', help='Path to checkpoint', 31 | default='/dataset/jingyaoli/AD/MOODv2/pretrain/beit-base_3rdparty_in1k_20221114-c0a4df23.pth',) 32 | parser.add_argument('--img_path', help='Path to image', 33 | default='/dataset/jingyaoli/AD/MOOD_/MOODv2/imgs/DTD_cracked_0004.jpg') 34 | parser.add_argument('--fc', 35 | default='/dataset/jingyaoli/AD/MOODv2/outputs/beit-224px/fc.pkl', help='Path to fc path') 36 | parser.add_argument('--id_data', default='imagenet', help='id data name') 37 | parser.add_argument('--id_train_feature', 38 | default='/dataset/jingyaoli/AD/MOODv2/outputs/beit-224px/imagenet_train.pkl', help='Path to data') 39 | parser.add_argument('--id_val_feature', 40 | default='/dataset/jingyaoli/AD/MOODv2/outputs/beit-224px/imagenet_test.pkl', help='Path to output file') 41 | parser.add_argument('--ood_features', 42 | default=None, nargs='+', help='Path to ood features') 43 | parser.add_argument( 44 | '--methods', nargs='+', 45 | default=['MSP', 'MaxLogit', 'Energy', 'Energy+React', 'ViM', 'Residual', 'GradNorm', 'Mahalanobis', ], # 'KL-Matching' 46 | help='methods') 47 | parser.add_argument( 48 | '--train_label', 49 | default='datalists/imagenet2012_train_random_200k.txt', 50 | help='Path to train labels') 51 | parser.add_argument( 52 | '--clip_quantile', default=0.99, help='Clip quantile to react') 53 | parser.add_argument( 54 | '--fpr', default=95, help='False Positive Rate') 55 | return parser.parse_args() 56 | 57 | def evaluate(method, score_id, score_ood, target_fpr): 58 | threhold = np.percentile(score_id, 100 - target_fpr) 59 | if score_ood >= threhold: 60 | print('\033[94m', method, '\033[0m', 'evaluation:', '\033[92m', 'in-distribution', '\033[0m') 61 | else: 62 | print('\033[94m', method, '\033[0m', 'evaluation:', '\033[91m', 'out-of-distribution', '\033[0m') 63 | 64 | def kl(p, q): 65 | return np.sum(np.where(p != 0, p * np.log(p / q), 0)) 66 | 67 | def gradnorm(x, w, b, num_cls): 68 | fc = torch.nn.Linear(*w.shape[::-1]) 69 | fc.weight.data[...] = torch.from_numpy(w) 70 | fc.bias.data[...] 
= torch.from_numpy(b) 71 | fc.cuda() 72 | 73 | x = torch.from_numpy(x).float().cuda() 74 | logsoftmax = torch.nn.LogSoftmax(dim=-1).cuda() 75 | 76 | confs = [] 77 | 78 | for i in tqdm(x, desc='Computing Gradnorm ID/OOD score'): 79 | targets = torch.ones((1, num_cls)).cuda() 80 | fc.zero_grad() 81 | loss = torch.mean( 82 | torch.sum(-targets * logsoftmax(fc(i[None])), dim=-1)) 83 | loss.backward() 84 | layer_grad_norm = torch.sum(torch.abs( 85 | fc.weight.grad.data)).cpu().numpy() 86 | confs.append(layer_grad_norm) 87 | 88 | return np.array(confs) 89 | 90 | def extract_image_feature(args): 91 | torch.backends.cudnn.benchmark = True 92 | 93 | print('=> Loading model') 94 | cfg = mmengine.Config.fromfile(args.cfg) 95 | model = init_model(cfg, args.checkpoint, 0).cuda().eval() 96 | 97 | print('=> Loading image') 98 | if hasattr(cfg.model.backbone, 'img_size'): 99 | img_size = cfg.model.backbone.img_size 100 | else: 101 | img_size = 224 102 | 103 | transform = tv.transforms.Compose([ 104 | tv.transforms.Resize((img_size, img_size)), 105 | tv.transforms.ToTensor(), 106 | tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 107 | ]) 108 | 109 | x = transform(Image.open(args.img_path).convert('RGB')).unsqueeze(0) 110 | 111 | print('=> Extracting feature') 112 | with torch.no_grad(): 113 | x = x.cuda() 114 | if cfg.model.backbone.type == 'BEiTPretrainViT': 115 | # (B, L, C) -> (B, C) 116 | feat_batch = model.backbone( 117 | x, mask=None)[0].mean(1) 118 | elif cfg.model.backbone.type == 'SwinTransformer': 119 | # (B, C, H, W) -> (B, C) 120 | feat_batch = model.backbone(x)[0] 121 | B, C, H, W = feat_batch.shape 122 | feat_batch = feat_batch.reshape(B, C, -1).mean(-1) 123 | else: 124 | # (B, C) 125 | feat_batch = model.backbone(x)[0] 126 | assert len(feat_batch.shape) == 2 127 | feature = feat_batch.cpu().numpy() 128 | 129 | print(f'Extracted Feature: {feature.shape}') 130 | return feature 131 | 132 | def main(): 133 | args = parse_args() 134 | if args.ood_feature and os.path.exists(args.ood_feature): 135 | feature_ood = mmengine.load(args.ood_feature) 136 | else: 137 | feature_ood = extract_image_feature(args) 138 | 139 | if os.path.exists(args.fc): 140 | w, b = mmengine.load(args.fc) 141 | print(f'{w.shape=}, {b.shape=}') 142 | num_cls = len(b) 143 | 144 | train_labels = np.array([ 145 | int(line.rsplit(' ', 1)[-1]) 146 | for line in mmengine.list_from_file(args.train_label) 147 | ], dtype=int) 148 | 149 | print(f'image path: {args.img_path}') 150 | 151 | print('=> Loading features') 152 | feature_id_train = mmengine.load(args.id_train_feature).squeeze() 153 | feature_id_val = mmengine.load(args.id_val_feature).squeeze() 154 | 155 | print(f'{feature_id_train.shape=}, {feature_id_val.shape=}') 156 | 157 | if os.path.exists(args.fc): 158 | print('=> Computing logits...') 159 | logit_id_train = feature_id_train @ w.T + b 160 | logit_id_val = feature_id_val @ w.T + b 161 | logit_ood = feature_ood @ w.T + b 162 | 163 | print('=> Computing softmax...') 164 | softmax_id_train = softmax(logit_id_train, axis=-1) 165 | softmax_id_val = softmax(logit_id_val, axis=-1) 166 | softmax_ood = softmax(logit_ood, axis=-1) 167 | 168 | u = -np.matmul(pinv(w), b) 169 | 170 | # --------------------------------------- 171 | method = 'MSP' 172 | if method in args.methods: 173 | score_id = softmax_id_val.max(axis=-1) 174 | score_ood = softmax_ood.max(axis=-1) 175 | result = evaluate(method, score_id, score_ood, args.fpr) 176 | 177 | # --------------------------------------- 178 | method = 'MaxLogit' 179 | if method in 
args.methods: 180 | score_id = logit_id_val.max(axis=-1) 181 | score_ood = logit_ood.max(axis=-1) 182 | result = evaluate(method, score_id, score_ood, args.fpr) 183 | 184 | # --------------------------------------- 185 | method = 'Energy' 186 | if method in args.methods: 187 | score_id = logsumexp(logit_id_val, axis=-1) 188 | score_ood = logsumexp(logit_ood, axis=-1) 189 | result = evaluate(method, score_id, score_ood, args.fpr) 190 | 191 | # --------------------------------------- 192 | method = 'Energy+React' 193 | if method in args.methods: 194 | clip = np.quantile(feature_id_train, args.clip_quantile) 195 | logit_id_val_clip = np.clip( 196 | feature_id_val, a_min=None, a_max=clip) @ w.T + b 197 | score_id = logsumexp(logit_id_val_clip, axis=-1) 198 | 199 | logit_ood_clip = np.clip(feature_ood, a_min=None, a_max=clip) @ w.T + b 200 | score_ood = logsumexp(logit_ood_clip, axis=-1) 201 | result = evaluate(method, score_id, score_ood, args.fpr) 202 | 203 | # --------------------------------------- 204 | method = 'ViM' 205 | if method in args.methods: 206 | if feature_id_val.shape[-1] >= 2048: 207 | DIM = num_cls 208 | elif feature_id_val.shape[-1] >= 768: 209 | DIM = 512 210 | else: 211 | DIM = feature_id_val.shape[-1] // 2 212 | 213 | ec = EmpiricalCovariance(assume_centered=True) 214 | ec.fit(feature_id_train - u) 215 | eig_vals, eigen_vectors = np.linalg.eig(ec.covariance_) 216 | NS = np.ascontiguousarray( 217 | (eigen_vectors.T[np.argsort(eig_vals * -1)[DIM:]]).T) 218 | vlogit_id_train = norm(np.matmul(feature_id_train - u, NS), axis=-1) 219 | alpha = logit_id_train.max(axis=-1).mean() / vlogit_id_train.mean() 220 | 221 | vlogit_id_val = norm(np.matmul(feature_id_val - u, NS), axis=-1) * alpha 222 | energy_id_val = logsumexp(logit_id_val, axis=-1) 223 | score_id = -vlogit_id_val + energy_id_val 224 | 225 | energy_ood = logsumexp(logit_ood, axis=-1) 226 | vlogit_ood = norm(np.matmul(feature_ood - u, NS), axis=-1) * alpha 227 | score_ood = -vlogit_ood + energy_ood 228 | result = evaluate(method, score_id, score_ood, args.fpr) 229 | 230 | # --------------------------------------- 231 | method = 'Residual' 232 | if method in args.methods: 233 | if feature_id_val.shape[-1] >= 2048: 234 | DIM = 1000 235 | elif feature_id_val.shape[-1] >= 768: 236 | DIM = 512 237 | else: 238 | DIM = feature_id_val.shape[-1] // 2 239 | ec = EmpiricalCovariance(assume_centered=True) 240 | ec.fit(feature_id_train - u) 241 | eig_vals, eigen_vectors = np.linalg.eig(ec.covariance_) 242 | NS = np.ascontiguousarray( 243 | (eigen_vectors.T[np.argsort(eig_vals * -1)[DIM:]]).T) 244 | 245 | score_id = -norm(np.matmul(feature_id_val - u, NS), axis=-1) 246 | 247 | score_ood = -norm(np.matmul(feature_ood - u, NS), axis=-1) 248 | result = evaluate(method, score_id, score_ood, args.fpr) 249 | 250 | # --------------------------------------- 251 | method = 'GradNorm' 252 | if method in args.methods: 253 | score_ood = gradnorm(feature_ood, w, b, num_cls) 254 | score_id = gradnorm(feature_id_val, w, b, num_cls) 255 | result = evaluate(method, score_id, score_ood, args.fpr) 256 | 257 | # --------------------------------------- 258 | method = 'Mahalanobis' 259 | if method in args.methods: 260 | train_means = [] 261 | train_feat_centered = [] 262 | for i in tqdm(range(train_labels.max() + 1), desc='Computing classwise mean feature'): 263 | fs = feature_id_train[train_labels == i] 264 | _m = fs.mean(axis=0) 265 | train_means.append(_m) 266 | train_feat_centered.extend(fs - _m) 267 | 268 | ec = 
EmpiricalCovariance(assume_centered=True) 269 | ec.fit(np.array(train_feat_centered).astype(np.float64)) 270 | 271 | mean = torch.from_numpy(np.array(train_means)).cuda().float() 272 | prec = torch.from_numpy(ec.precision_).cuda().float() 273 | 274 | score_id = -np.array( 275 | [(((f - mean) @ prec) * (f - mean)).sum(axis=-1).min().cpu().item() 276 | for f in tqdm(torch.from_numpy(feature_id_val).cuda().float(), desc='Computing Mahalanobis ID score')]) 277 | 278 | score_ood = -np.array([ 279 | (((f - mean) @ prec) * (f - mean)).sum(axis=-1).min().cpu().item() 280 | for f in tqdm(torch.from_numpy(feature_ood).cuda().float(), desc='Computing Mahalanobis OOD score') 281 | ]) 282 | result = evaluate(method, score_id, score_ood, args.fpr) 283 | 284 | # --------------------------------------- 285 | method = 'KL-Matching' 286 | if method in args.methods: 287 | 288 | pred_labels_train = np.argmax(softmax_id_train, axis=-1) 289 | mean_softmax_train = [] 290 | for i in tqdm(range(num_cls), desc='Computing classwise mean softmax'): 291 | mean_softmax = softmax_id_train[pred_labels_train == i] 292 | if mean_softmax.shape[0] == 0: 293 | mean_softmax_train.append(np.zeros((num_cls))) 294 | else: 295 | mean_softmax_train.append(np.mean(mean_softmax, axis=0)) 296 | 297 | score_id = -pairwise_distances_argmin_min( 298 | softmax_id_val, np.array(mean_softmax_train), metric=kl)[1] 299 | 300 | score_ood = -pairwise_distances_argmin_min( 301 | softmax_ood, np.array(mean_softmax_train), metric=kl)[1] 302 | result = evaluate(method, score_id, score_ood, args.fpr) 303 | 304 | if __name__ == '__main__': 305 | main() 306 | -------------------------------------------------------------------------------- /MOODv1/run_beit_pretraining.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) 3 | # Github source: https://github.com/microsoft/unilm/tree/master/beit 4 | # Copyright (c) 2021 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # By Hangbo Bao 7 | # Based on timm, DINO and DeiT code bases 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # https://github.com/facebookresearch/deit 10 | # https://github.com/facebookresearch/dino 11 | # --------------------------------------------------------' 12 | import argparse 13 | import datetime 14 | import numpy as np 15 | import time 16 | import torch 17 | import torch.backends.cudnn as cudnn 18 | import json 19 | import os 20 | 21 | from pathlib import Path 22 | 23 | from timm.models import create_model 24 | from optim_factory import create_optimizer 25 | 26 | from datasets import build_beit_pretraining_dataset 27 | from engine_for_pretraining import train_one_epoch 28 | from utils import NativeScalerWithGradNormCount as NativeScaler 29 | import utils 30 | import modeling_pretrain 31 | 32 | 33 | def get_args(): 34 | parser = argparse.ArgumentParser('BEiT pre-training script', add_help=False) 35 | parser.add_argument('--pretrain', default=None, type=str) 36 | parser.add_argument('--batch_size', default=64, type=int) 37 | parser.add_argument('--epochs', default=300, type=int) 38 | parser.add_argument('--save_ckpt_freq', default=20, type=int) 39 | parser.add_argument("--discrete_vae_weight_path", type=str) 40 | parser.add_argument("--discrete_vae_type", type=str, default="dall-e") 41 | parser.add_argument('--data_set', default='cifar10', 
choices=['cifar10', 'cifar100', 'imagenet30'], 42 | type=str, help='ImageNet dataset path') 43 | parser.add_argument('--class_idx', help='None: multi-class, Not None: one-class', default=None, type=int) 44 | 45 | # Model parameters 46 | parser.add_argument('--model', default='beit_base_patch16_224_8k_vocab', type=str, metavar='MODEL', 47 | help='Name of model to train') 48 | parser.add_argument('--rel_pos_bias', action='store_true') 49 | parser.add_argument('--disable_rel_pos_bias', action='store_false', dest='rel_pos_bias') 50 | parser.set_defaults(rel_pos_bias=True) 51 | parser.add_argument('--abs_pos_emb', action='store_true') 52 | parser.set_defaults(abs_pos_emb=False) 53 | parser.add_argument('--layer_scale_init_value', default=0.1, type=float, 54 | help="0.1 for base, 1e-5 for large. set 0 to disable layer scale") 55 | 56 | parser.add_argument('--num_mask_patches', default=75, type=int, 57 | help='number of the visual tokens/patches need be masked') 58 | parser.add_argument('--max_mask_patches_per_block', type=int, default=None) 59 | parser.add_argument('--min_mask_patches_per_block', type=int, default=16) 60 | 61 | parser.add_argument('--input_size', default=224, type=int, 62 | help='images input size for backbone') 63 | parser.add_argument('--second_input_size', default=112, type=int, 64 | help='images input size for discrete vae') 65 | 66 | parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT', 67 | help='Drop path rate (default: 0.1)') 68 | 69 | # Optimizer parameters 70 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 71 | help='Optimizer (default: "adamw"') 72 | parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON', 73 | help='Optimizer Epsilon (default: 1e-8)') 74 | parser.add_argument('--opt_betas', default=[0.9, 0.999], type=float, nargs='+', metavar='BETA', 75 | help='Optimizer Betas (default: 0.9, 0.999, use opt default)') 76 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', 77 | help='Clip gradient norm (default: None, no clipping)') 78 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 79 | help='SGD momentum (default: 0.9)') 80 | parser.add_argument('--weight_decay', type=float, default=0.05, 81 | help='weight decay (default: 0.05)') 82 | parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the 83 | weight decay. We use a cosine schedule for WD. 
84 | (Set the same value with args.weight_decay to keep weight decay no change)""") 85 | 86 | parser.add_argument('--lr', type=float, default=5e-4, metavar='LR', 87 | help='learning rate (default: 5e-4)') 88 | parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR', 89 | help='warmup learning rate (default: 1e-6)') 90 | parser.add_argument('--min_lr', type=float, default=1e-5, metavar='LR', 91 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') 92 | 93 | parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N', 94 | help='epochs to warmup LR, if scheduler supports') 95 | parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N', 96 | help='epochs to warmup LR, if scheduler supports') 97 | 98 | # Augmentation parameters 99 | parser.add_argument('--train_interpolation', type=str, default='bicubic', 100 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")') 101 | parser.add_argument('--second_interpolation', type=str, default='lanczos', 102 | help='Interpolation for discrete vae (random, bilinear, bicubic default: "lanczos")') 103 | 104 | # Dataset parameters 105 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, 106 | help='dataset path') 107 | parser.add_argument('--imagenet_default_mean_and_std', default=False, action='store_true') 108 | 109 | parser.add_argument('--output_dir', default='', 110 | help='path where to save, empty for no saving') 111 | parser.add_argument('--log_dir', default=None, 112 | help='path where to tensorboard log') 113 | parser.add_argument('--device', default='cuda', 114 | help='device to use for training / testing') 115 | parser.add_argument('--seed', default=0, type=int) 116 | parser.add_argument('--resume', default='', help='resume from checkpoint') 117 | parser.add_argument('--auto_resume', action='store_true') 118 | parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume') 119 | parser.set_defaults(auto_resume=True) 120 | 121 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 122 | help='start epoch') 123 | parser.add_argument('--num_workers', default=10, type=int) 124 | parser.add_argument('--pin_mem', action='store_true', 125 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 126 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem', 127 | help='') 128 | parser.set_defaults(pin_mem=True) 129 | 130 | # distributed training parameters 131 | parser.add_argument('--world_size', default=1, type=int, 132 | help='number of distributed processes') 133 | parser.add_argument('--local_rank', default=-1, type=int) 134 | parser.add_argument('--dist_on_itp', action='store_true') 135 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 136 | 137 | return parser.parse_args() 138 | 139 | 140 | def get_model(args): 141 | print(f"Creating model: {args.model}") 142 | model = create_model( 143 | args.model, 144 | pretrained=False, 145 | drop_path_rate=args.drop_path, 146 | drop_block_rate=None, 147 | use_shared_rel_pos_bias=args.rel_pos_bias, 148 | use_abs_pos_emb=args.abs_pos_emb, 149 | init_values=args.layer_scale_init_value, 150 | ) 151 | 152 | return model 153 | 154 | 155 | def main(args): 156 | utils.init_distributed_mode(args) 157 | 158 | print(args) 159 | 160 | device = torch.device(args.device) 161 | 162 | # fix the seed for reproducibility 163 | seed = args.seed + utils.get_rank() 164 | 
torch.manual_seed(seed) 165 | np.random.seed(seed) 166 | # random.seed(seed) 167 | 168 | cudnn.benchmark = True 169 | 170 | model = get_model(args) 171 | patch_size = model.patch_embed.patch_size 172 | print("Patch size = %s" % str(patch_size)) 173 | args.window_size = (args.input_size // patch_size[0], args.input_size // patch_size[1]) 174 | args.patch_size = patch_size 175 | 176 | # get dataset 177 | dataset_train = build_beit_pretraining_dataset(args, args.data_set) 178 | 179 | # prepare discrete vae 180 | d_vae = utils.create_d_vae( 181 | weight_path=args.discrete_vae_weight_path, d_vae_type=args.discrete_vae_type, 182 | device=device, image_size=args.second_input_size) 183 | 184 | if True: # args.distributed: 185 | num_tasks = utils.get_world_size() 186 | global_rank = utils.get_rank() 187 | sampler_rank = global_rank 188 | num_training_steps_per_epoch = len(dataset_train) // args.batch_size // num_tasks 189 | 190 | sampler_train = torch.utils.data.DistributedSampler( 191 | dataset_train, num_replicas=num_tasks, rank=sampler_rank, shuffle=True 192 | ) 193 | print("Sampler_train = %s" % str(sampler_train)) 194 | else: 195 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 196 | 197 | if global_rank == 0 and args.log_dir is not None: 198 | os.makedirs(args.log_dir, exist_ok=True) 199 | log_writer = utils.TensorboardLogger(log_dir=args.log_dir) 200 | else: 201 | log_writer = None 202 | 203 | data_loader_train = torch.utils.data.DataLoader( 204 | dataset_train, sampler=sampler_train, 205 | batch_size=args.batch_size, 206 | num_workers=args.num_workers, 207 | pin_memory=args.pin_mem, 208 | drop_last=True, 209 | ) 210 | 211 | print('=> Loading checkpoint'.format(args.pretrain)) 212 | if args.pretrain: 213 | from utils import load_state_dict 214 | state_dict = torch.load(args.pretrain, map_location='cpu') 215 | load_state_dict(model, state_dict) 216 | 217 | model.to(device) 218 | model_without_ddp = model 219 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 220 | 221 | print("Model = %s" % str(model_without_ddp)) 222 | print('number of params:', n_parameters) 223 | 224 | total_batch_size = args.batch_size * utils.get_world_size() 225 | print("LR = %.8f" % args.lr) 226 | print("Batch size = %d" % total_batch_size) 227 | print("Number of training steps = %d" % num_training_steps_per_epoch) 228 | print("Number of training examples per epoch = %d" % (total_batch_size * num_training_steps_per_epoch)) 229 | 230 | if args.distributed: 231 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) 232 | model_without_ddp = model.module 233 | 234 | optimizer = create_optimizer( 235 | args, model_without_ddp) 236 | loss_scaler = NativeScaler() 237 | 238 | print("Use step level LR & WD scheduler!") 239 | lr_schedule_values = utils.cosine_scheduler( 240 | args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch, 241 | warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps, 242 | ) 243 | if args.weight_decay_end is None: 244 | args.weight_decay_end = args.weight_decay 245 | wd_schedule_values = utils.cosine_scheduler( 246 | args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch) 247 | print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values))) 248 | 249 | utils.auto_load_model( 250 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) 251 | 252 | print(f"Start training for 
{args.epochs} epochs") 253 | start_time = time.time() 254 | for epoch in range(args.start_epoch, args.epochs): 255 | if args.distributed: 256 | data_loader_train.sampler.set_epoch(epoch) 257 | if log_writer is not None: 258 | log_writer.set_step(epoch * num_training_steps_per_epoch) 259 | train_stats = train_one_epoch( 260 | model, d_vae, data_loader_train, 261 | optimizer, device, epoch, loss_scaler, 262 | args.clip_grad, log_writer=log_writer, 263 | start_steps=epoch * num_training_steps_per_epoch, 264 | lr_schedule_values=lr_schedule_values, 265 | wd_schedule_values=wd_schedule_values, 266 | ) 267 | if args.output_dir: 268 | if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs: 269 | utils.save_model( 270 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 271 | loss_scaler=loss_scaler, epoch=epoch) 272 | 273 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 274 | 'epoch': epoch, 'n_parameters': n_parameters} 275 | 276 | if args.output_dir and utils.is_main_process(): 277 | if log_writer is not None: 278 | log_writer.flush() 279 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 280 | f.write(json.dumps(log_stats) + "\n") 281 | 282 | total_time = time.time() - start_time 283 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 284 | print('Training time {}'.format(total_time_str)) 285 | 286 | 287 | if __name__ == '__main__': 288 | args = get_args() 289 | args.output_dir = '/dataset/jingyaoli/AD/temp/{}/{}/trial:{}'.format(args.data_set, args.model, args.trial) 290 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 291 | main(args) 292 | -------------------------------------------------------------------------------- /MOODv1/datasets.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) 3 | # Github source: https://github.com/microsoft/unilm/tree/master/beit 4 | # Copyright (c) 2021 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # By Hangbo Bao 7 | # Based on timm, DINO and DeiT code bases 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # https://github.com/facebookresearch/deit/ 10 | # https://github.com/facebookresearch/dino 11 | # --------------------------------------------------------' 12 | import os 13 | import torch 14 | 15 | from torchvision import datasets, transforms 16 | from dataset_folder import ImageFolder, SegmentationDataset 17 | import blobfile as bf 18 | from timm.data.constants import \ 19 | IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 20 | from transforms import RandomResizedCropAndInterpolationWithTwoPic 21 | from timm.data import create_transform 22 | 23 | 24 | from typing import TypeVar, Generic 25 | from dall_e.utils import map_pixels 26 | from masking_generator import MaskingGenerator 27 | import ood_utils 28 | from PIL import Image 29 | import numpy as np 30 | 31 | T_co = TypeVar('T_co', covariant=True) 32 | 33 | class Dataset(Generic[T_co]): 34 | def __getitem__(self, index): 35 | raise NotImplementedError 36 | 37 | class DataAugmentationForBEiT(object): 38 | def __init__(self, args): 39 | imagenet_default_mean_and_std = args.imagenet_default_mean_and_std 40 | mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN 
41 | std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD 42 | 43 | self.common_transform = transforms.Compose([ 44 | transforms.ColorJitter(0.4, 0.4, 0.4), 45 | transforms.RandomHorizontalFlip(p=0.5), 46 | RandomResizedCropAndInterpolationWithTwoPic( 47 | size=args.input_size, second_size=args.second_input_size, 48 | interpolation=args.train_interpolation, second_interpolation=args.second_interpolation, 49 | ), 50 | ]) 51 | 52 | self.patch_transform = transforms.Compose([ 53 | transforms.ToTensor(), 54 | transforms.Normalize( 55 | mean=torch.tensor(mean), 56 | std=torch.tensor(std)) 57 | ]) 58 | 59 | if args.discrete_vae_type == "dall-e": 60 | self.visual_token_transform = transforms.Compose([ 61 | transforms.ToTensor(), 62 | map_pixels, 63 | ]) 64 | elif args.discrete_vae_type == "customized": 65 | self.visual_token_transform = transforms.Compose([ 66 | transforms.ToTensor(), 67 | transforms.Normalize( 68 | mean=IMAGENET_INCEPTION_MEAN, 69 | std=IMAGENET_INCEPTION_STD, 70 | ), 71 | ]) 72 | else: 73 | raise NotImplementedError() 74 | 75 | self.masked_position_generator = MaskingGenerator( 76 | args.window_size, num_masking_patches=args.num_mask_patches, 77 | max_num_patches=args.max_mask_patches_per_block, 78 | min_num_patches=args.min_mask_patches_per_block, 79 | ) 80 | 81 | def __call__(self, image): 82 | for_patches, for_visual_tokens = self.common_transform(image) 83 | return \ 84 | self.patch_transform(for_patches), self.visual_token_transform(for_visual_tokens), \ 85 | self.masked_position_generator() 86 | 87 | def __repr__(self): 88 | repr = "(DataAugmentationForBEiT,\n" 89 | repr += " common_transform = %s,\n" % str(self.common_transform) 90 | repr += " patch_transform = %s,\n" % str(self.patch_transform) 91 | repr += " visual_tokens_transform = %s,\n" % str(self.visual_token_transform) 92 | repr += " Masked position generator = %s,\n" % str(self.masked_position_generator) 93 | repr += ")" 94 | return repr 95 | 96 | 97 | def build_beit_pretraining_dataset(args, data_set=None): 98 | '''train set for beit''' 99 | if data_set == None: 100 | data_set = args.data_set 101 | transform = DataAugmentationForBEiT(args) 102 | 103 | # print("Data Aug = %s" % str(transform)) 104 | data_path = args.data_path 105 | if data_set == 'cifar100': 106 | dataset = datasets.CIFAR100(data_path, train=True, transform=transform) 107 | elif data_set == 'cifar10': 108 | dataset = datasets.CIFAR10(data_path, train=True, transform=transform) 109 | elif data_set == 'imagenet30': 110 | dataset = ImageFolder(os.path.join(data_path, 'train'), transform=transform) 111 | else: 112 | dataset = ImageFolder(data_path, transform=transform) 113 | return dataset 114 | 115 | 116 | class Subset(Dataset[T_co]): 117 | def __init__(self, dataset, indices): 118 | self.dataset = dataset 119 | self.indices = np.array(indices) 120 | self.data = np.array(dataset.data)[self.indices, :] 121 | self.targets = np.array(dataset.targets)[self.indices] 122 | 123 | def __getitem__(self, idx): 124 | return self.dataset[self.indices[idx]] 125 | 126 | def __len__(self): 127 | return len(self.indices) 128 | 129 | 130 | class Mixup(Dataset[T_co]): 131 | def __init__(self, dataset, transform, alpha=0.2): 132 | self.dataset = dataset 133 | self.targets = dataset.targets 134 | self.data = dataset.data 135 | self.alpha = alpha 136 | self.baselenth = len(self.data) 137 | self.transform = transform 138 | 139 | 140 | def __getitem__(self, idx): 141 | if idx < self.baselenth: 142 | img, target = self.data[idx], 
self.targets[idx] 143 | img = Image.fromarray(img) 144 | img = self.transform(img) 145 | return img, target 146 | else: 147 | img, target = self.data[idx-self.baselenth], self.targets[idx-self.baselenth] 148 | img = Image.fromarray(img) 149 | img = self.transform(img) 150 | 151 | lam = np.random.beta(self.alpha, self.alpha) 152 | mix_index = np.random.randint(0, self.baselenth-1) 153 | mix_img, mix_target = self.data[mix_index], self.targets[mix_index] 154 | mix_img = Image.fromarray(mix_img) 155 | mix_img = self.transform(mix_img) 156 | 157 | img = lam * img + (1 - lam) * mix_img 158 | target = int(lam * target + (1 - lam) * mix_target) 159 | return img, target 160 | 161 | def __len__(self): 162 | return self.baselenth*2 163 | 164 | 165 | class Rotation(Dataset[T_co]): 166 | def __init__(self, dataset, transform): 167 | self.dataset = dataset 168 | self.targets = dataset.targets 169 | self.data = dataset.data 170 | self.baselenth = len(self.data) 171 | self.transform = transform 172 | 173 | def __getitem__(self, idx): 174 | rot_angle = idx // self.baselenth # range: 0-3 175 | image_idx = idx % self.baselenth # range: 0-baselenth 176 | img, target = self.data[image_idx], self.targets[image_idx] 177 | img = np.rot90(img, rot_angle) 178 | img = Image.fromarray(img) 179 | img = self.transform(img) 180 | return img, target 181 | 182 | def __len__(self): 183 | return self.baselenth*4 184 | 185 | 186 | class Transform(Dataset[T_co]): 187 | def __init__(self, dataset, transform): 188 | self.dataset = dataset 189 | self.data = dataset.data 190 | self.targets = dataset.targets 191 | self.transform = transform 192 | 193 | def __getitem__(self, idx): 194 | img, target = self.data[idx], self.targets[idx] 195 | img = Image.fromarray(img) 196 | img = self.transform(img) 197 | return img, target 198 | 199 | def __len__(self): 200 | return len(self.dataset) 201 | 202 | 203 | def build_dataset(is_train, args, data_set=None, ood=False, is_trans=True, ood_data_path=None): 204 | if data_set == None: 205 | data_set = args.data_set 206 | 207 | if not is_trans: 208 | # normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 209 | transform = transforms.Compose([ 210 | transforms.Resize(256), 211 | transforms.CenterCrop(224), 212 | transforms.ToTensor(), 213 | # normalize, 214 | ]) 215 | else: 216 | transform = build_transform(is_train, args) 217 | 218 | if ood_data_path is None: 219 | data_path = args.data_path 220 | else: 221 | data_path = ood_data_path 222 | 223 | if data_set == 'cifar100': 224 | dataset = datasets.CIFAR100(data_path, train=is_train, transform=transform, download=False) 225 | nb_classes = 100 226 | elif data_set == 'cifar10': 227 | dataset = datasets.CIFAR10(data_path, train=is_train, transform=transform, download=False) 228 | nb_classes = 10 229 | elif data_set == 'svhn': 230 | dataset = datasets.SVHN(data_path, split='train' if is_train else 'test', transform=transform, download=False) 231 | nb_classes = 10 232 | elif data_set == 'imagenet30': 233 | split = 'train' if is_train else 'test' 234 | nb_classes = 30 235 | dataset = ImageFolder(os.path.join(data_path, split), transform=transform) 236 | elif data_set == 'caltech256': 237 | split = 'train' if is_train else 'test' 238 | nb_classes = 256 239 | dataset = ImageFolder(os.path.join(data_path, split), transform=transform) 240 | elif data_set == 'imagenet1k': 241 | root = os.path.join(data_path, 'train' if is_train else 'val') 242 | dataset = ImageFolder(root, transform=transform) 243 | nb_classes = 1000 244 | elif 
data_set == "image_folder": 245 | root = data_path if is_train else args.eval_data_path 246 | dataset = ImageFolder(root, transform=transform) 247 | nb_classes = args.nb_classes 248 | assert len(dataset.class_to_idx) == nb_classes 249 | else: 250 | nb_classes = None 251 | dataset = ImageFolder(data_path, transform=transform) 252 | 253 | if isinstance(args.class_idx, int) and not ood: 254 | print('Using one class idx (class idx:{})'.format(args.class_idx)) 255 | cls_list = get_superclass_list(data_set) 256 | dataset = get_subclass_dataset(dataset, classes=cls_list[args.class_idx]) 257 | 258 | return dataset, nb_classes 259 | 260 | 261 | def _list_image_files_recursively(data_dir): 262 | results = [] 263 | for entry in sorted(bf.listdir(data_dir)): 264 | full_path = bf.join(data_dir, entry) 265 | ext = entry.split(".")[-1] 266 | if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]: 267 | results.append(full_path) 268 | elif bf.isdir(full_path): 269 | results.extend(_list_image_files_recursively(full_path)) 270 | return results 271 | 272 | 273 | 274 | def get_superclass_list(dataset): 275 | CIFAR10_SUPERCLASS = list(range(10)) # one class 276 | IMAGENET_SUPERCLASS = list(range(30)) # one class 277 | CIFAR100_SUPERCLASS = [ 278 | [4, 31, 55, 72, 95], 279 | [1, 33, 67, 73, 91], 280 | [54, 62, 70, 82, 92], 281 | [9, 10, 16, 29, 61], 282 | [0, 51, 53, 57, 83], 283 | [22, 25, 40, 86, 87], 284 | [5, 20, 26, 84, 94], 285 | [6, 7, 14, 18, 24], 286 | [3, 42, 43, 88, 97], 287 | [12, 17, 38, 68, 76], 288 | [23, 34, 49, 60, 71], 289 | [15, 19, 21, 32, 39], 290 | [35, 63, 64, 66, 75], 291 | [27, 45, 77, 79, 99], 292 | [2, 11, 36, 46, 98], 293 | [28, 30, 44, 78, 93], 294 | [37, 50, 65, 74, 80], 295 | [47, 52, 56, 59, 96], 296 | [8, 13, 48, 58, 90], 297 | [41, 69, 81, 85, 89], 298 | ] 299 | if dataset == 'cifar10': 300 | return CIFAR10_SUPERCLASS 301 | elif dataset == 'cifar100': 302 | return CIFAR100_SUPERCLASS 303 | elif dataset.lower() == 'imagenet' or dataset.lower() == 'imagenet30': 304 | return IMAGENET_SUPERCLASS 305 | else: 306 | raise NotImplementedError() 307 | 308 | 309 | def get_subclass_dataset(dataset, classes): 310 | if not isinstance(classes, list): 311 | classes = [classes] 312 | indices = [] 313 | for idx, tgt in enumerate(dataset.targets): 314 | if tgt in classes: 315 | indices.append(idx) 316 | dataset = Subset(dataset, indices) 317 | return dataset 318 | 319 | 320 | def build_transform(is_train, args): 321 | resize_im = args.input_size > 32 322 | imagenet_default_mean_and_std = args.imagenet_default_mean_and_std 323 | mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN 324 | std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD 325 | 326 | if is_train: 327 | # this should always dispatch to transforms_imagenet_train 328 | transform = create_transform( 329 | input_size=args.input_size, 330 | is_training=True, 331 | color_jitter=args.color_jitter, 332 | auto_augment=args.aa, 333 | interpolation=args.train_interpolation, 334 | re_prob=args.reprob, 335 | re_mode=args.remode, 336 | re_count=args.recount, 337 | mean=mean, 338 | std=std, 339 | ) 340 | if not resize_im: 341 | # replace RandomResizedCropAndInterpolation with 342 | # RandomCrop 343 | transform.transforms[0] = transforms.RandomCrop( 344 | args.input_size, padding=4) 345 | return transform 346 | 347 | t = [] 348 | if resize_im: 349 | if args.crop_pct is None: 350 | if args.input_size < 384: 351 | args.crop_pct = 224 / 256 352 | else: 353 | 
args.crop_pct = 1.0 354 | size = int(args.input_size / args.crop_pct) 355 | t.append( 356 | transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 224 images 357 | ) 358 | t.append(transforms.CenterCrop(args.input_size)) 359 | 360 | t.append(transforms.ToTensor()) 361 | t.append(transforms.Normalize(mean, std)) 362 | return transforms.Compose(t) 363 | 364 | 365 | -------------------------------------------------------------------------------- /MOODv1/eval_with_features.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | 4 | import os 5 | from pydoc import classname 6 | import numpy as np 7 | import time 8 | import logging 9 | import argparse 10 | from collections import OrderedDict 11 | 12 | import torch 13 | import torch.nn as nn 14 | from torch.utils.data import DataLoader 15 | 16 | from datasets import build_dataset 17 | from timm.models import create_model 18 | import ood_utils 19 | import utils 20 | 21 | import modeling_finetune 22 | # os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 23 | 24 | def get_scores(ftrain, ftest, food, args): 25 | return ood_utils.get_scores_one_cluster(ftrain, ftest, food, args) 26 | 27 | 28 | def get_eval_results(ftrain, ftest, food, args): 29 | if ftrain is not None: 30 | if args.cc or args.avgcc: 31 | for i in range(args.nb_classes): 32 | ftrain[i] /= np.linalg.norm(ftrain[i], axis=-1, keepdims=True) + 1e-10 33 | else: 34 | ftrain /= np.linalg.norm(ftrain, axis=-1, keepdims=True) + 1e-10 35 | 36 | if ftest is not None: 37 | ftest /= np.linalg.norm(ftest, axis=-1, keepdims=True) + 1e-10 38 | 39 | if food is not None: 40 | food /= np.linalg.norm(food, axis=-1, keepdims=True) + 1e-10 41 | 42 | dtest, dood, labels = get_scores(ftrain, ftest, food, args) 43 | 44 | fpr95 = ood_utils.get_fpr(dtest, dood, labels) 45 | auroc, aupr = ood_utils.get_roc_sklearn(dtest, dood, labels), ood_utils.get_pr_sklearn(dtest, dood, labels) 46 | return fpr95, auroc, aupr 47 | 48 | 49 | def compute(model, ood_set, features_train, features_test, args): 50 | ood_sampler = torch.utils.data.SequentialSampler(ood_set) 51 | 52 | ood_loader = torch.utils.data.DataLoader(ood_set, sampler=ood_sampler, 53 | batch_size=args.batch_size, num_workers=args.num_workers, 54 | pin_memory=args.pin_mem, drop_last=False) 55 | 56 | features_ood = ood_utils.get_features(model, ood_loader, args.ood_dataset, args) 57 | fpr95, auroc, aupr = get_eval_results(features_train, features_test, features_ood, args,) 58 | return fpr95, auroc, aupr 59 | 60 | 61 | def get_args(): 62 | parser = argparse.ArgumentParser(description="SSD evaluation") 63 | parser.add_argument('--ood_dataset', help='name of ood dataset', default=None, type=str) 64 | parser.add_argument('--ood_data_path', help='path of ood dataset', default=None, type=str) 65 | 66 | parser.add_argument('--cc', help='class-conditioned distance', default=False, action='store_true') 67 | parser.add_argument('--avgcc', help='average of class-conditioned distance', default=False, action='store_true') 68 | parser.add_argument('--max_num', default=1e10, help='maximum number of test samples', type=int) 69 | 70 | parser.add_argument("--results-dir", type=str, default="./eval_results") 71 | parser.add_argument('--class_idx', help='One-class OOD: the idx of the ID class number. 
Multi-class OOD: None', default=None, type=int) 72 | 73 | parser.add_argument("--batch_size", type=int, default=8) 74 | parser.add_argument("--ckpt", type=str, default=None, help="checkpoint path", required=True) 75 | parser.add_argument("--metric", type=str, default='mahalanobis', help='OOD detection metric of the features', 76 | choices=['mahalanobis', 'cos', 'projection', 'gauss', 'kmeans', 'euclidean', 'minkowski', 'chebyshev']) 77 | 78 | parser.set_defaults(eval_ood=True) 79 | # Model parameters 80 | parser.add_argument('--model', default='beit_large_patch16_224', type=str, metavar='MODEL', 81 | help='Name of model to train') 82 | parser.add_argument('--rel_pos_bias', action='store_true') 83 | parser.add_argument('--disable_rel_pos_bias', action='store_false', dest='rel_pos_bias') 84 | parser.set_defaults(rel_pos_bias=True) 85 | parser.add_argument('--abs_pos_emb', action='store_true') 86 | parser.set_defaults(abs_pos_emb=False) 87 | parser.add_argument('--layer_scale_init_value', default=0.1, type=float, 88 | help="0.1 for base, 1e-5 for large. set 0 to disable layer scale") 89 | 90 | parser.add_argument('--input_size', default=224, type=int, help='images input size') 91 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', 92 | help='Dropout rate (default: 0.)') 93 | parser.add_argument('--attn_drop_rate', type=float, default=0.0, metavar='PCT', 94 | help='Attention dropout rate (default: 0.)') 95 | parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT', 96 | help='Drop path rate (default: 0.1)') 97 | 98 | parser.add_argument('--disable_eval_during_finetuning', action='store_true', default=False) 99 | 100 | parser.add_argument('--model_ema', action='store_true', default=False) 101 | parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='') 102 | parser.add_argument('--model_ema_force_cpu', action='store_true', default=False, help='') 103 | 104 | # Augmentation parameters 105 | parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT', 106 | help='Color jitter factor (default: 0.4)') 107 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 108 | help='Use AutoAugment policy. "v0" or "original". 
" + "(default: rand-m9-mstd0.5-inc1)'), 109 | parser.add_argument('--smoothing', type=float, default=0.1, 110 | help='Label smoothing (default: 0.1)') 111 | parser.add_argument('--train_interpolation', type=str, default='bicubic', 112 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")') 113 | 114 | # Evaluation parameters 115 | parser.add_argument('--crop_pct', type=float, default=None) 116 | 117 | # * Random Erase params 118 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 119 | help='Random erase prob (default: 0.25)') 120 | parser.add_argument('--remode', type=str, default='pixel', 121 | help='Random erase mode (default: "pixel")') 122 | parser.add_argument('--recount', type=int, default=1, 123 | help='Random erase count (default: 1)') 124 | parser.add_argument('--resplit', action='store_true', default=False, 125 | help='Do not random erase first (clean) augmentation split') 126 | 127 | # * Mixup params 128 | parser.add_argument('--mixup', type=float, default=0, 129 | help='mixup alpha, mixup enabled if > 0.') 130 | parser.add_argument('--cutmix', type=float, default=0, 131 | help='cutmix alpha, cutmix enabled if > 0.') 132 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, 133 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 134 | parser.add_argument('--mixup_prob', type=float, default=1.0, 135 | help='Probability of performing mixup or cutmix when either/both is enabled') 136 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5, 137 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 138 | parser.add_argument('--mixup_mode', type=str, default='batch', 139 | help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') 140 | 141 | # * Finetuning params 142 | parser.add_argument('--finetune', default='', 143 | help='finetune from checkpoint') 144 | parser.add_argument('--model_key', default='model|module', type=str) 145 | parser.add_argument('--model_prefix', default='', type=str) 146 | parser.add_argument('--init_scale', default=0.001, type=float) 147 | parser.add_argument('--use_mean_pooling', action='store_true') 148 | parser.set_defaults(use_mean_pooling=True) 149 | parser.add_argument('--use_cls', action='store_false', dest='use_mean_pooling') 150 | parser.add_argument('--disable_weight_decay_on_rel_pos_bias', action='store_true', default=False) 151 | 152 | # Dataset parameters 153 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, 154 | help='dataset path') 155 | parser.add_argument('--eval_data_path', default=None, type=str, 156 | help='dataset path for evaluation') 157 | 158 | parser.add_argument('--nb_classes', default=0, type=int, 159 | help='number of the classification types') 160 | parser.add_argument('--imagenet_default_mean_and_std', default=False, action='store_true') 161 | 162 | parser.add_argument('--data_set', default='IMNET', choices=['cifar10', 'cifar100', 'imagenet1k', 'imagenet30'], 163 | type=str, help='ImageNet dataset path') 164 | parser.add_argument('--output_dir', default='', 165 | help='path where to save, empty for no saving') 166 | parser.add_argument('--log_dir', default=None, 167 | help='path where to tensorboard log') 168 | parser.add_argument('--device', default='cuda', 169 | help='device to use for training / testing') 170 | parser.add_argument('--seed', default=0, type=int) 171 | parser.add_argument('--resume', default='', 172 | help='resume from checkpoint') 173 | parser.add_argument('--auto_resume', action='store_true') 174 | parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume') 175 | parser.set_defaults(auto_resume=True) 176 | 177 | parser.add_argument('--save_ckpt', action='store_true') 178 | parser.add_argument('--no_save_ckpt', action='store_false', dest='save_ckpt') 179 | parser.set_defaults(save_ckpt=True) 180 | 181 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 182 | help='start epoch') 183 | parser.add_argument('--eval', action='store_true', 184 | help='Perform evaluation only') 185 | parser.add_argument('--dist_eval', action='store_true', default=False, 186 | help='Enabling distributed evaluation') 187 | parser.add_argument('--num_workers', default=10, type=int) 188 | parser.add_argument('--pin_mem', action='store_true', 189 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 190 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 191 | parser.set_defaults(pin_mem=True) 192 | 193 | args = parser.parse_args() 194 | return args 195 | 196 | 197 | def main(): 198 | # check ckpt 199 | assert os.path.exists(args.ckpt), 'Not find {}'.format(args.ckpt) 200 | print('loading from {}'.format(args.ckpt)) 201 | 202 | # load checkpoint 203 | global model 204 | if args.ckpt.startswith('https'): 205 | state_dict = torch.hub.load_state_dict_from_url( 206 | args.ckpt, map_location='cpu', check_hash=True) 207 | else: 208 | state_dict = torch.load(args.ckpt, map_location='cpu') 209 | 210 | if 'model' in state_dict.keys(): 211 | state_dict = state_dict['model'] 212 | elif 'state_dict' in state_dict.keys(): 213 | state_dict = state_dict['state_dict'] 214 | elif 'module' in state_dict.keys(): 
215 | state_dict = state_dict['module'] 216 | 217 | for k in list(state_dict.keys()): 218 | state_dict[k.replace('module.', '')] = state_dict.pop(k) 219 | 220 | if 'pt22k' in args.ckpt: 221 | model = utils.set_checkpoint_for_finetune(model, args, args.ckpt) 222 | else: 223 | utils.load_state_dict(model, state_dict, prefix=args.model_prefix) 224 | 225 | # dataloaders 226 | train_set, args.nb_classes = build_dataset(is_train=True, data_set=args.data_set, args=args) 227 | test_set, _ = build_dataset(is_train=False, data_set=args.data_set, args=args) 228 | 229 | train_sampler = torch.utils.data.RandomSampler(train_set) 230 | test_sampler = torch.utils.data.SequentialSampler(test_set) 231 | 232 | train_loader = torch.utils.data.DataLoader(train_set, sampler=train_sampler, 233 | batch_size=args.batch_size, num_workers=args.num_workers, 234 | pin_memory=args.pin_mem, drop_last=True,) 235 | 236 | test_loader = torch.utils.data.DataLoader(test_set, sampler=test_sampler, 237 | batch_size=args.batch_size, num_workers=args.num_workers, 238 | pin_memory=args.pin_mem, drop_last=False) 239 | 240 | features_train = ood_utils.get_features(model, train_loader, args.data_set, args, is_train=True) 241 | features_test = ood_utils.get_features(model, test_loader, args.data_set, args) 242 | 243 | if args.class_idx is None: 244 | ood_set, _ = build_dataset(is_train=False, data_set=args.ood_dataset, args=args, ood=True, ood_data_path=args.ood_data_path) 245 | fpr95, auroc, aupr = compute(model, ood_set, features_train, features_test, args) 246 | 247 | ood = "%-15s\t" % (args.ood_dataset) 248 | logger.info("{}\tIn-data: {}\tOOD: {}\tFPR95: {:.2f}\tAUROC: {:.2f}\tAUPR: {:.2f}".format( 249 | args.metric, args.data_set, ood, fpr95*100, auroc*100, aupr*100)) 250 | aurocs = [auroc*100]  # collect AUROC for the summary print below 251 | print('\t'.join('{:.1f}'.format(auroc) for auroc in aurocs)) 252 | 253 | 254 | else: 255 | ood_multi_class_set, _ = build_dataset(is_train=False, data_set=args.data_set, args=args, ood=True, ood_data_path=args.ood_data_path) 256 | 257 | fpr95s, aurocs, auprs = [], [], [] 258 | for d in range(len(cls_list)): 259 | if d == args.class_idx: 260 | continue 261 | ood_set = ood_utils.get_subclass_dataset(ood_multi_class_set, classes=cls_list[d]) 262 | 263 | args.ood_dataset = str(d) 264 | fpr95, auroc, aupr = compute(model, ood_set, features_train, features_test, args) 265 | logger.info("{}\tDataset: {}\tID: {}\tOOD: {}\tFPR95: {:.2f}\tAUROC: {:.2f}\tAUPR: {:.2f}".format( 266 | args.metric, args.data_set, args.class_idx, d, fpr95*100, auroc*100, aupr*100)) 267 | 268 | fpr95s.append(fpr95) 269 | aurocs.append(auroc) 270 | auprs.append(aupr) 271 | 272 | fpr95 = np.mean(fpr95s)*100 273 | auroc = np.mean(aurocs)*100 274 | aupr = np.mean(auprs)*100 275 | 276 | results = "{}\tDataset: {}\tID: {}\tOOD: {}\tFPR95: {:.2f}\tAUROC: {:.2f}\tAUPR: {:.2f}".format( 277 | args.metric, args.data_set, args.class_idx, 'all', fpr95, auroc, aupr) 278 | logger.info(results) 279 | return fpr95, auroc, aupr 280 | 281 | 282 | def create_logger(results_dir): 283 | # create logger 284 | logging.basicConfig(level=logging.INFO, format="%(message)s") 285 | logger = logging.getLogger() 286 | 287 | # create results dir 288 | if results_dir is not None: 289 | ood_utils.mkdir(results_dir) 290 | results_file = os.path.join(results_dir, 'eval_results.txt') 291 | logger.addHandler(logging.FileHandler(results_file, "a")) 292 | print('=> Saving to {}'.format(results_file)) 293 | return logger 294 | 295 | 296 | if __name__ == "__main__": 297 | args = get_args() 298 | device
= "cuda:0" 299 | 300 | torch.manual_seed(args.seed) 301 | torch.cuda.manual_seed(args.seed) 302 | torch.cuda.manual_seed_all(args.seed) 303 | 304 | # create model 305 | _, args.nb_classes = build_dataset(is_train=False, data_set=args.data_set, args=args) 306 | model = create_model(args.model, pretrained=False, num_classes=args.nb_classes, 307 | drop_rate=args.drop, drop_path_rate=args.drop_path, attn_drop_rate=args.attn_drop_rate, 308 | drop_block_rate=None, use_mean_pooling=args.use_mean_pooling, init_scale=args.init_scale, 309 | use_rel_pos_bias=args.rel_pos_bias, use_abs_pos_emb=args.abs_pos_emb, 310 | init_values=args.layer_scale_init_value,).cuda() 311 | 312 | # ood 313 | logger = create_logger(args.results_dir) 314 | if args.class_idx is not None: 315 | cls_list = ood_utils.get_superclass_list(args.data_set) 316 | main() 317 | 318 | 319 | 320 | 321 | 322 | -------------------------------------------------------------------------------- /MOODv1/dataset_folder.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) 3 | # Github source: https://github.com/microsoft/unilm/tree/master/beit 4 | # Copyright (c) 2021 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # By Hangbo Bao 7 | # Modified on torchvision code bases 8 | # https://github.com/pytorch/vision 9 | # --------------------------------------------------------' 10 | from torchvision.datasets.vision import VisionDataset 11 | import torch 12 | from PIL import Image 13 | from torch.utils.data import DataLoader, Dataset 14 | import os 15 | import os.path 16 | import random 17 | import blobfile as bf 18 | from typing import Any, Callable, cast, Dict, List, Optional, Tuple 19 | import numpy as np 20 | 21 | def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool: 22 | """Checks if a file is an allowed extension. 23 | 24 | Args: 25 | filename (string): path to a file 26 | extensions (tuple of strings): extensions to consider (lowercase) 27 | 28 | Returns: 29 | bool: True if the filename ends with one of given extensions 30 | """ 31 | return filename.lower().endswith(extensions) 32 | 33 | 34 | def is_image_file(filename: str) -> bool: 35 | """Checks if a file is an allowed image extension. 
36 | 37 | Args: 38 | filename (string): path to a file 39 | 40 | Returns: 41 | bool: True if the filename ends with a known image extension 42 | """ 43 | return has_file_allowed_extension(filename, IMG_EXTENSIONS) 44 | 45 | 46 | def make_dataset( 47 | directory: str, 48 | class_to_idx: Dict[str, int], 49 | extensions: Optional[Tuple[str, ...]] = None, 50 | is_valid_file: Optional[Callable[[str], bool]] = None, 51 | ) -> List[Tuple[str, int]]: 52 | instances = [] 53 | directory = os.path.expanduser(directory) 54 | both_none = extensions is None and is_valid_file is None 55 | both_something = extensions is not None and is_valid_file is not None 56 | if both_none or both_something: 57 | raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time") 58 | if extensions is not None: 59 | def is_valid_file(x: str) -> bool: 60 | return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions)) 61 | is_valid_file = cast(Callable[[str], bool], is_valid_file) 62 | for target_class in sorted(class_to_idx.keys()): 63 | class_index = class_to_idx[target_class] 64 | target_dir = os.path.join(directory, target_class) 65 | if not os.path.isdir(target_dir): 66 | continue 67 | for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)): 68 | for fname in sorted(fnames): 69 | path = os.path.join(root, fname) 70 | if is_valid_file(path): 71 | item = path, class_index 72 | instances.append(item) 73 | return instances 74 | 75 | 76 | class DatasetFolder(VisionDataset): 77 | """A generic data loader where the samples are arranged in this way: :: 78 | 79 | root/class_x/xxx.ext 80 | root/class_x/xxy.ext 81 | root/class_x/xxz.ext 82 | 83 | root/class_y/123.ext 84 | root/class_y/nsdf3.ext 85 | root/class_y/asd932_.ext 86 | 87 | Args: 88 | root (string): Root directory path. 89 | loader (callable): A function to load a sample given its path. 90 | extensions (tuple[string]): A list of allowed extensions. 91 | both extensions and is_valid_file should not be passed. 92 | transform (callable, optional): A function/transform that takes in 93 | a sample and returns a transformed version. 94 | E.g, ``transforms.RandomCrop`` for images. 95 | target_transform (callable, optional): A function/transform that takes 96 | in the target and transforms it. 97 | is_valid_file (callable, optional): A function that takes path of a file 98 | and check if the file is a valid file (used to check of corrupt files) 99 | both extensions and is_valid_file should not be passed. 100 | 101 | Attributes: 102 | classes (list): List of the class names sorted alphabetically. 103 | class_to_idx (dict): Dict with items (class_name, class_index). 
104 | samples (list): List of (sample path, class_index) tuples 105 | targets (list): The class_index value for each image in the dataset 106 | """ 107 | 108 | def __init__( 109 | self, 110 | root: str, 111 | loader: Callable[[str], Any], 112 | extensions: Optional[Tuple[str, ...]] = None, 113 | transform: Optional[Callable] = None, 114 | target_transform: Optional[Callable] = None, 115 | is_valid_file: Optional[Callable[[str], bool]] = None, 116 | ) -> None: 117 | super(DatasetFolder, self).__init__(root, transform=transform, 118 | target_transform=target_transform) 119 | classes, class_to_idx = self._find_classes(self.root) 120 | samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file) 121 | if len(samples) == 0: 122 | msg = "Found 0 files in subfolders of: {}\n".format(self.root) 123 | if extensions is not None: 124 | msg += "Supported extensions are: {}".format(",".join(extensions)) 125 | raise RuntimeError(msg) 126 | 127 | self.loader = loader 128 | self.extensions = extensions 129 | 130 | self.classes = classes 131 | self.class_to_idx = class_to_idx 132 | self.samples = samples 133 | self.targets = [s[1] for s in samples] 134 | 135 | def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]: 136 | """ 137 | Finds the class folders in a dataset. 138 | 139 | Args: 140 | dir (string): Root directory path. 141 | 142 | Returns: 143 | tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary. 144 | 145 | Ensures: 146 | No class is a subdirectory of another. 147 | """ 148 | classes = [d.name for d in os.scandir(dir) if d.is_dir()] 149 | classes.sort() 150 | class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} 151 | return classes, class_to_idx 152 | 153 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 154 | """ 155 | Args: 156 | index (int): Index 157 | 158 | Returns: 159 | tuple: (sample, target) where target is class_index of the target class. 
160 | """ 161 | while True: 162 | try: 163 | path, target = self.samples[index] 164 | sample = self.loader(path) 165 | break 166 | except Exception as e: 167 | print(e) 168 | index = random.randint(0, len(self.samples) - 1) 169 | 170 | if self.transform is not None: 171 | sample = self.transform(sample) 172 | if self.target_transform is not None: 173 | target = self.target_transform(target) 174 | 175 | return sample, target 176 | 177 | def __len__(self) -> int: 178 | return len(self.samples) 179 | 180 | 181 | IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp') 182 | 183 | 184 | def pil_loader(path: str) -> Image.Image: 185 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 186 | with open(path, 'rb') as f: 187 | img = Image.open(f) 188 | return img.convert('RGB') 189 | 190 | 191 | # TODO: specify the return type 192 | def accimage_loader(path: str) -> Any: 193 | import accimage 194 | try: 195 | return accimage.Image(path) 196 | except IOError: 197 | # Potentially a decoding problem, fall back to PIL.Image 198 | return pil_loader(path) 199 | 200 | 201 | def default_loader(path: str) -> Any: 202 | from torchvision import get_image_backend 203 | if get_image_backend() == 'accimage': 204 | return accimage_loader(path) 205 | else: 206 | return pil_loader(path) 207 | 208 | 209 | class ImageFolder(DatasetFolder): 210 | """A generic data loader where the images are arranged in this way: :: 211 | 212 | root/dog/xxx.png 213 | root/dog/xxy.png 214 | root/dog/xxz.png 215 | 216 | root/cat/123.png 217 | root/cat/nsdf3.png 218 | root/cat/asd932_.png 219 | 220 | Args: 221 | root (string): Root directory path. 222 | transform (callable, optional): A function/transform that takes in an PIL image 223 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 224 | target_transform (callable, optional): A function/transform that takes in the 225 | target and transforms it. 226 | loader (callable, optional): A function to load an image given its path. 227 | is_valid_file (callable, optional): A function that takes path of an Image file 228 | and check if the file is a valid file (used to check of corrupt files) 229 | 230 | Attributes: 231 | classes (list): List of the class names sorted alphabetically. 232 | class_to_idx (dict): Dict with items (class_name, class_index). 
233 |         imgs (list): List of (image path, class_index) tuples
234 |     """
235 | 
236 |     def __init__(
237 |             self,
238 |             root: str,
239 |             transform: Optional[Callable] = None,
240 |             target_transform: Optional[Callable] = None,
241 |             loader: Callable[[str], Any] = default_loader,
242 |             is_valid_file: Optional[Callable[[str], bool]] = None,
243 |     ):
244 |         super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None,
245 |                                           transform=transform,
246 |                                           target_transform=target_transform,
247 |                                           is_valid_file=is_valid_file)
248 |         self.imgs = self.samples
249 |         self.data = self.samples
250 | 
251 | 
252 | class SegmentationDataset(Dataset):
253 |     def __init__(
254 |             self,
255 |             image_paths,
256 |             label_paths=None,
257 |             resolution=224,
258 |             random_crop=False,
259 |             random_flip=True,
260 |     ):
261 |         super().__init__()
262 |         self.resolution = resolution
263 |         self.local_images = image_paths
264 |         self.local_labels = label_paths
265 |         self.random_crop = random_crop
266 |         self.random_flip = random_flip
267 | 
268 |     def __len__(self):
269 |         return len(self.local_images)
270 | 
271 |     def __getitem__(self, idx):
272 |         data = {}
273 | 
274 |         image_path = self.local_images[idx]
275 |         with bf.BlobFile(image_path, "rb") as f:
276 |             pil_image = Image.open(f)
277 |             pil_image.load()
278 |         pil_image = pil_image.convert("RGB")
279 | 
280 |         label_path = None
281 |         if self.local_labels is not None:
282 |             label_path = self.local_labels[idx]
283 |         if label_path is not None:
284 |             with bf.BlobFile(label_path, "rb") as f:
285 |                 pil_label = Image.open(f)
286 |                 pil_label.load()
287 |             pil_label = pil_label.convert("L")
288 | 
289 |         if self.random_crop:
290 |             if label_path is not None:
291 |                 arr_image, arr_label = random_crop_two_arr(pil_image, pil_label, self.resolution)
292 |             else:
293 |                 arr_image = random_crop_arr(pil_image, self.resolution)  # image-only crop; no label to align
294 |         else:
295 |             arr_image = center_crop_arr(pil_image, self.resolution)
296 |             if label_path is not None:
297 |                 arr_label = center_crop_arr(pil_label, self.resolution)
298 | 
299 |         def save_an_image(image, save_name):
300 |             im = Image.fromarray(image)
301 |             im.save(save_name)
302 | 
303 |         if self.random_flip and random.random() < 0.5:
304 |             arr_image = arr_image[:, ::-1, :]
305 |             if label_path is not None:
306 |                 arr_label = arr_label[:, ::-1]
307 | 
308 |         arr_image = arr_image.astype(np.float32) / 255  # / 127.5 - 1
309 |         # np.ascontiguousarray to ensure the memory continuity
310 |         image = torch.tensor(np.ascontiguousarray(np.transpose(arr_image, [2, 0, 1])))
311 | 
312 |         if label_path is not None:
313 |             arr_label = arr_label.astype(np.float32) / 255  # / 127.5 - 1
314 |             # label = np.ascontiguousarray(np.transpose(data['label'], [2, 0, 1]))
315 |             label = torch.tensor(np.ascontiguousarray(np.expand_dims(arr_label, axis=0)))
316 |             # (batch_size, 3, image_size, image_size), (batch_size, image_size, image_size)
317 |         else:
318 |             label = torch.ones((1, self.resolution, self.resolution))
319 | 
320 |         return image, label
321 | 
322 | def center_crop_arr(pil_image, image_size):
323 |     # We are not on a new enough PIL to support the `reducing_gap`
324 |     # argument, which uses BOX downsampling at powers of two first.
325 |     # Thus, we do it by hand to improve downsample quality.
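    # Concretely: halve the image with BOX filtering until its short side drops below
    # 2 * image_size, finish with one BICUBIC resize so the short side equals image_size,
    # then take the centered image_size x image_size window from the resulting array.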
326 |     while min(*pil_image.size) >= 2 * image_size:
327 |         pil_image = pil_image.resize(
328 |             tuple(x // 2 for x in pil_image.size), resample=Image.BOX
329 |         )
330 | 
331 |     scale = image_size / min(*pil_image.size)
332 |     pil_image = pil_image.resize(
333 |         tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
334 |     )
335 | 
336 |     arr = np.array(pil_image)
337 |     crop_y = (arr.shape[0] - image_size) // 2
338 |     crop_x = (arr.shape[1] - image_size) // 2
339 |     return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
340 | 
341 | 
342 | def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
343 |     min_smaller_dim_size = math.ceil(image_size / max_crop_frac)
344 |     max_smaller_dim_size = math.ceil(image_size / min_crop_frac)
345 |     smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1)
346 | 
347 |     # We are not on a new enough PIL to support the `reducing_gap`
348 |     # argument, which uses BOX downsampling at powers of two first.
349 |     # Thus, we do it by hand to improve downsample quality.
350 |     while min(*pil_image.size) >= 2 * smaller_dim_size:
351 |         pil_image = pil_image.resize(
352 |             tuple(x // 2 for x in pil_image.size), resample=Image.BOX
353 |         )
354 | 
355 |     scale = smaller_dim_size / min(*pil_image.size)
356 |     pil_image = pil_image.resize(
357 |         tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
358 |     )
359 | 
360 |     arr = np.array(pil_image)
361 |     crop_y = random.randrange(arr.shape[0] - image_size + 1)
362 |     crop_x = random.randrange(arr.shape[1] - image_size + 1)
363 |     return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
364 | 
365 | 
366 | def random_crop_two_arr(pil_image, pil_label, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
367 |     assert pil_image.size == pil_label.size  # PIL images expose .size (W, H), not .shape
368 |     min_smaller_dim_size = math.ceil(image_size / max_crop_frac)
369 |     max_smaller_dim_size = math.ceil(image_size / min_crop_frac)
370 |     smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1)
371 | 
372 |     # We are not on a new enough PIL to support the `reducing_gap`
373 |     # argument, which uses BOX downsampling at powers of two first.
374 |     # Thus, we do it by hand to improve downsample quality.
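    # Same two-stage BOX-then-BICUBIC resize as random_crop_arr above, except every resize
    # and the crop offsets below are applied identically to the image and its label so the
    # pair stays pixel-aligned.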
375 | while min(*pil_image.size) >= 2 * smaller_dim_size: 376 | pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX) 377 | pil_label = pil_label.resize(tuple(x // 2 for x in pil_label.size), resample=Image.BOX) 378 | 379 | scale = smaller_dim_size / min(*pil_image.size) 380 | pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC) 381 | pil_label = pil_label.resize(tuple(round(x * scale) for x in pil_label.size), resample=Image.BICUBIC) 382 | 383 | arr_image = np.array(pil_image) 384 | arr_label = np.array(pil_label) 385 | crop_y = random.randrange(arr_image.shape[0] - image_size + 1) 386 | crop_x = random.randrange(arr_image.shape[1] - image_size + 1) 387 | 388 | image = arr_image[crop_y : crop_y + image_size, crop_x : crop_x + image_size] 389 | label = arr_label[crop_y : crop_y + image_size, crop_x : crop_x + image_size] 390 | return image, label 391 | -------------------------------------------------------------------------------- /MOODv1/eval_with_logits.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | 4 | import os 5 | from pydoc import classname 6 | import numpy as np 7 | import time 8 | import logging 9 | import argparse 10 | from collections import OrderedDict 11 | 12 | import torch 13 | import torch.nn as nn 14 | from torch.utils.data import DataLoader 15 | import torch.nn.functional as F 16 | 17 | from datasets import build_dataset 18 | from timm.models import create_model 19 | import ood_utils 20 | import utils 21 | 22 | from torch.autograd import Variable 23 | import modeling_finetune 24 | # os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 25 | 26 | get_metric = None 27 | 28 | def get_eval_results(dtest, dood, args): 29 | labels = [1] * len(dtest) + [0] * len(dood) if args.metric in ['softmax', 'gradnorm'] else [0] * len(dtest) + [1] * len(dood) 30 | fpr95 = ood_utils.get_fpr(dtest, dood, labels) 31 | auroc, aupr = ood_utils.get_roc_sklearn(dtest, dood, labels), ood_utils.get_pr_sklearn(dtest, dood, labels) 32 | return fpr95, auroc, aupr 33 | 34 | 35 | def compute(model, ood_set, softmax_test, args): 36 | ood_sampler = torch.utils.data.SequentialSampler(ood_set) 37 | ood_loader = torch.utils.data.DataLoader(ood_set, sampler=ood_sampler, 38 | batch_size=args.batch_size, num_workers=args.num_workers, 39 | pin_memory=args.pin_mem, drop_last=False) 40 | softmax_ood = get_metric(model, ood_loader, args.ood_dataset, args) 41 | fpr95, auroc, aupr = get_eval_results(softmax_test, softmax_ood, args,) 42 | return fpr95, auroc, aupr 43 | 44 | 45 | def get_args(): 46 | parser = argparse.ArgumentParser(description="SSD evaluation") 47 | 48 | parser.add_argument('--metric', default='softmax', help='OOD detection metric of the logits', 49 | type=str, choices=['softmax', 'entropy', 'energy', 'gradnorm']) 50 | parser.add_argument('--ood_dataset', help='name of ood dataset', default=None, type=str) 51 | parser.add_argument('--ood_data_path', help='path of ood dataset', default=None, type=str) 52 | parser.add_argument('--cc', help='class-conditioned distance', default=False, action='store_true') 53 | parser.add_argument('--avgcc', help='average of class-conditioned distance', default=False, action='store_true') 54 | 55 | parser.add_argument("--results-dir", type=str, default="./eval_results") 56 | parser.add_argument('--class_idx', help='One-class OOD: the idx of the ID class number. 
Multi-class OOD: None', default=None, type=int) 57 | 58 | parser.add_argument("--batch_size", type=int, default=16) 59 | parser.add_argument("--ckpt", type=str, help="checkpoint path", required=True) 60 | parser.add_argument("--method", type=str, default='MOOD', help='name of logger') 61 | 62 | parser.set_defaults(eval_ood=True) 63 | # Model parameters 64 | parser.add_argument('--model', default='beit_large_patch16_224', type=str, metavar='MODEL', 65 | help='Name of model to train') 66 | parser.add_argument('--rel_pos_bias', action='store_true') 67 | parser.add_argument('--disable_rel_pos_bias', action='store_false', dest='rel_pos_bias') 68 | parser.set_defaults(rel_pos_bias=True) 69 | parser.add_argument('--abs_pos_emb', action='store_true') 70 | parser.set_defaults(abs_pos_emb=False) 71 | parser.add_argument('--layer_scale_init_value', default=0.1, type=float, 72 | help="0.1 for base, 1e-5 for large. set 0 to disable layer scale") 73 | 74 | parser.add_argument('--input_size', default=224, type=int, help='images input size') 75 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', 76 | help='Dropout rate (default: 0.)') 77 | parser.add_argument('--attn_drop_rate', type=float, default=0.0, metavar='PCT', 78 | help='Attention dropout rate (default: 0.)') 79 | parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT', 80 | help='Drop path rate (default: 0.1)') 81 | 82 | parser.add_argument('--disable_eval_during_finetuning', action='store_true', default=False) 83 | 84 | parser.add_argument('--model_ema', action='store_true', default=False) 85 | parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='') 86 | parser.add_argument('--model_ema_force_cpu', action='store_true', default=False, help='') 87 | 88 | # Augmentation parameters 89 | parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT', 90 | help='Color jitter factor (default: 0.4)') 91 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 92 | help='Use AutoAugment policy. "v0" or "original". 
" + "(default: rand-m9-mstd0.5-inc1)'), 93 | parser.add_argument('--smoothing', type=float, default=0.1, 94 | help='Label smoothing (default: 0.1)') 95 | parser.add_argument('--train_interpolation', type=str, default='bicubic', 96 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")') 97 | 98 | # Evaluation parameters 99 | parser.add_argument('--crop_pct', type=float, default=None) 100 | 101 | # * Random Erase params 102 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 103 | help='Random erase prob (default: 0.25)') 104 | parser.add_argument('--remode', type=str, default='pixel', 105 | help='Random erase mode (default: "pixel")') 106 | parser.add_argument('--recount', type=int, default=1, 107 | help='Random erase count (default: 1)') 108 | parser.add_argument('--resplit', action='store_true', default=False, 109 | help='Do not random erase first (clean) augmentation split') 110 | 111 | # * Mixup params 112 | parser.add_argument('--mixup', type=float, default=0, 113 | help='mixup alpha, mixup enabled if > 0.') 114 | parser.add_argument('--cutmix', type=float, default=0, 115 | help='cutmix alpha, cutmix enabled if > 0.') 116 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, 117 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 118 | parser.add_argument('--mixup_prob', type=float, default=1.0, 119 | help='Probability of performing mixup or cutmix when either/both is enabled') 120 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5, 121 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 122 | parser.add_argument('--mixup_mode', type=str, default='batch', 123 | help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') 124 | 125 | # * Finetuning params 126 | parser.add_argument('--finetune', default='', 127 | help='finetune from checkpoint') 128 | parser.add_argument('--imagenet30_key', default='model|module', type=str) 129 | parser.add_argument('--model_prefix', default='', type=str) 130 | parser.add_argument('--init_scale', default=0.001, type=float) 131 | parser.add_argument('--use_mean_pooling', action='store_true') 132 | parser.set_defaults(use_mean_pooling=True) 133 | parser.add_argument('--use_cls', action='store_false', dest='use_mean_pooling') 134 | parser.add_argument('--disable_weight_decay_on_rel_pos_bias', action='store_true', default=False) 135 | 136 | # Dataset parameters 137 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, 138 | help='dataset path') 139 | parser.add_argument('--eval_data_path', default=None, type=str, 140 | help='dataset path for evaluation') 141 | parser.add_argument('--nb_classes', default=0, type=int, 142 | help='number of the classification types') 143 | parser.add_argument('--imagenet_default_mean_and_std', default=False, action='store_true') 144 | 145 | parser.add_argument('--data_set', default='IMNET', 146 | type=str, help='ImageNet dataset path') 147 | parser.add_argument('--output_dir', default='', 148 | help='path where to save, empty for no saving') 149 | parser.add_argument('--log_dir', default=None, 150 | help='path where to tensorboard log') 151 | parser.add_argument('--device', default='cuda', 152 | help='device to use for training / testing') 153 | parser.add_argument('--seed', default=0, type=int) 154 | parser.add_argument('--resume', default='', 155 | help='resume from checkpoint') 156 | parser.add_argument('--auto_resume', action='store_true') 157 | parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume') 158 | parser.set_defaults(auto_resume=True) 159 | 160 | parser.add_argument('--save_ckpt', action='store_true') 161 | parser.add_argument('--no_save_ckpt', action='store_false', dest='save_ckpt') 162 | parser.set_defaults(save_ckpt=True) 163 | 164 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 165 | help='start epoch') 166 | parser.add_argument('--eval', action='store_true', 167 | help='Perform evaluation only') 168 | parser.add_argument('--dist_eval', action='store_true', default=False, 169 | help='Enabling distributed evaluation') 170 | parser.add_argument('--num_workers', default=10, type=int) 171 | parser.add_argument('--pin_mem', action='store_true', 172 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 173 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 174 | parser.set_defaults(pin_mem=True) 175 | 176 | args = parser.parse_args() 177 | return args 178 | 179 | ## get energy ### 180 | def get_energy(model, dataloader, name, args, is_train=False, max_num=1e10): 181 | model.eval() 182 | energy = [] 183 | with torch.no_grad(): 184 | for index, (img, label) in enumerate(dataloader): 185 | if index >= max_num: 186 | break 187 | img, label = img.cuda(), label.cuda() 188 | outputs = model(img) 189 | # e = torch.logsumexp(outputs, dim=1) 190 | e = -torch.log(torch.sum(torch.exp(outputs), dim=1)) 191 | energy += list(e.view(e.size(0)).cpu().numpy()) 192 | if args.class_idx is None and (index + 1) % 100 == 0: 193 | print('{}: {}/{}'.format(name, index+1, len(dataloader)), end='\r') 194 | energy = np.array(energy) 195 | print('\nenergy shape, ', energy.shape) 196 | 
return energy 197 | 198 | ## get energy ### 199 | def get_entropy(model, dataloader, name, args, is_train=False, max_num=1e10): 200 | model.eval() 201 | entropys = [] 202 | with torch.no_grad(): 203 | for index, (img, label) in enumerate(dataloader): 204 | if index >= max_num: 205 | break 206 | img, label = img.cuda(), label.cuda() 207 | outputs = model(img) 208 | e = -1.0 * torch.sum(F.softmax(outputs, dim=1) * F.log_softmax(outputs, dim=1), dim=1) 209 | entropys += list(e.view(e.size(0)).cpu().numpy()) 210 | if args.class_idx is None and (index + 1) % 100 == 0: 211 | print('{}: {}/{}'.format(name, index+1, len(dataloader)), end='\r') 212 | entropys = np.array(entropys) 213 | print('\nentropys shape, ', entropys.shape) 214 | return entropys 215 | 216 | ## get gradnorm ### 217 | def get_gradnorm(model, dataloader, name, args, is_train=False, max_num=1e10): 218 | confs = [] 219 | logsoftmax = torch.nn.LogSoftmax(dim=-1).cuda() 220 | 221 | for index, (img, label) in enumerate(dataloader): 222 | if index >= max_num: 223 | break 224 | inputs = Variable(img.cuda(), requires_grad=True) 225 | model.zero_grad() 226 | 227 | img, label = img.cuda(), label.cuda() 228 | outputs = model(img) 229 | 230 | targets = torch.ones((inputs.shape[0], args.nb_classes)).cuda() 231 | loss = torch.mean(torch.sum(-targets * logsoftmax(outputs), dim=-1)) 232 | 233 | loss.backward() 234 | layer_grad = model.head.weight.grad.data 235 | layer_grad_norm = torch.sum(torch.abs(layer_grad)).cpu().numpy() 236 | confs.append(layer_grad_norm) 237 | 238 | if args.class_idx is None and (index + 1) % 100 == 0: 239 | print('{}: {}/{}'.format(name, index+1, len(dataloader)), end='\r') 240 | 241 | return confs 242 | 243 | 244 | ## get softmax ### 245 | def get_softmax(model, dataloader, name, args, is_train=False, max_num=1e10): 246 | model.eval() 247 | probs = [] 248 | with torch.no_grad(): 249 | for index, (img, label) in enumerate(dataloader): 250 | if index >= max_num: 251 | break 252 | img, label = img.cuda(), label.cuda() 253 | # import pdb;pdb.set_trace() 254 | prob = torch.max(model(img), axis=-1).values 255 | probs += list(prob.cpu().numpy()) 256 | if args.class_idx is None and (index + 1) % 100 == 0: 257 | print('{}: {}/{}'.format(name, index+1, len(dataloader)), end='\r') 258 | probs = np.array(probs) 259 | print('\n') 260 | return probs 261 | 262 | 263 | def main(): 264 | # check ckpt 265 | assert os.path.exists(args.ckpt), 'Not find {}'.format(args.ckpt) 266 | print('loading from {}'.format(args.ckpt)) 267 | 268 | # load checkpoint 269 | global model 270 | if args.ckpt.startswith('https'): 271 | state_dict = torch.hub.load_state_dict_from_url( 272 | args.ckpt, map_location='cpu', check_hash=True) 273 | else: 274 | state_dict = torch.load(args.ckpt, map_location='cpu') 275 | 276 | if 'model' in state_dict.keys(): 277 | state_dict = state_dict['model'] 278 | elif 'state_dict' in state_dict.keys(): 279 | state_dict = state_dict['state_dict'] 280 | elif 'module' in state_dict.keys(): 281 | state_dict = state_dict['module'] 282 | 283 | for k in list(state_dict.keys()): 284 | state_dict[k.replace('module.', '')] = state_dict.pop(k) 285 | 286 | if 'pt22k' in args.ckpt: # for pretrained ckpt 287 | model = utils.set_checkpoint_for_finetune(model, args, args.ckpt) 288 | else: # for fine-tuned ckpt 289 | utils.load_state_dict(model, state_dict, prefix=args.model_prefix) 290 | 291 | # test dataloader 292 | test_set, _ = build_dataset(is_train=False, data_set=args.data_set, args=args) 293 | test_sampler = 
torch.utils.data.SequentialSampler(test_set) 294 | test_loader = torch.utils.data.DataLoader(test_set, sampler=test_sampler, 295 | batch_size=args.batch_size, num_workers=args.num_workers, 296 | pin_memory=args.pin_mem, drop_last=False) 297 | 298 | global get_metric 299 | metric_dict = { 300 | 'softmax': get_softmax, 301 | 'entropy': get_entropy, 302 | 'energy': get_energy, 303 | 'gradnorm': get_gradnorm, 304 | } 305 | get_metric = metric_dict[args.metric] 306 | softmax_test = get_metric(model, test_loader, args.data_set, args) 307 | 308 | if args.class_idx is None: 309 | ood_set, _ = build_dataset(is_train=False, data_set=args.ood_dataset, args=args, ood=True, ood_data_path=args.ood_data_path) 310 | fpr95, auroc, aupr = compute(model, ood_set, softmax_test, args) 311 | ood = "%-15s\t" % (args.ood_dataset) 312 | logger.info("{}\tIn-data: {}\tOOD: {}\tFPR95: {:.2f}\tAUROC: {:.2f}\tAUPR: {:.2f}".format( 313 | args.metric, args.data_set, ood, fpr95*100, auroc*100, aupr*100)) 314 | 315 | else: 316 | ood_multi_class_set, _ = build_dataset(is_train=False, data_set=args.data_set, args=args, ood=True) 317 | 318 | fpr95s, aurocs, auprs = [], [], [] 319 | for d in range(len(cls_list)): 320 | if d == args.class_idx: 321 | continue 322 | ood_set = ood_utils.get_subclass_dataset(ood_multi_class_set, classes=cls_list[d]) 323 | 324 | args.ood_dataset = str(d) 325 | fpr95, auroc, aupr = compute(model, ood_set, softmax_test, args) 326 | logger.info("MOOD\tDataset: {}\tID: {}\tOOD: {}\tFPR95: {:.2f}\tAUROC: {:.2f}\tAUPR: {:.2f}".format( 327 | args.data_set, args.class_idx, d, fpr95*100, auroc*100, aupr*100)) 328 | 329 | fpr95s.append(fpr95) 330 | aurocs.append(auroc) 331 | auprs.append(aupr) 332 | 333 | fpr95 = np.mean(fpr95s)*100 334 | auroc = np.mean(aurocs)*100 335 | aupr = np.mean(auprs)*100 336 | 337 | results = "MOOD\tDataset: {}\tID: {}\tOOD: {}\tFPR95: {:.2f}\tAUROC: {:.2f}\tAUPR: {:.2f}".format( 338 | args.data_set, args.class_idx, 'all', fpr95, auroc, aupr) 339 | logger.info(results) 340 | return fpr95, auroc, aupr 341 | 342 | 343 | def create_logger(results_dir): 344 | # create logger 345 | logging.basicConfig(level=logging.INFO, format="%(message)s") 346 | logger = logging.getLogger() 347 | 348 | # create results dir 349 | if results_dir is not None: 350 | ood_utils.mkdir(results_dir) 351 | results_file = os.path.join(results_dir, 'eval_results.txt') 352 | logger.addHandler(logging.FileHandler(results_file, "a")) 353 | print('=> Saving to {}'.format(results_file)) 354 | return logger 355 | 356 | 357 | if __name__ == "__main__": 358 | args = get_args() 359 | device = "cuda:0" 360 | 361 | torch.manual_seed(args.seed) 362 | torch.cuda.manual_seed(args.seed) 363 | torch.cuda.manual_seed_all(args.seed) 364 | 365 | # create model 366 | _, args.nb_classes = build_dataset(is_train=False, data_set=args.data_set, args=args) 367 | model = create_model(args.model, pretrained=False, num_classes=args.nb_classes, 368 | drop_rate=args.drop, drop_path_rate=args.drop_path, attn_drop_rate=args.attn_drop_rate, 369 | drop_block_rate=None, use_mean_pooling=args.use_mean_pooling, init_scale=args.init_scale, 370 | use_rel_pos_bias=args.rel_pos_bias, use_abs_pos_emb=args.abs_pos_emb, 371 | init_values=args.layer_scale_init_value).cuda() 372 | 373 | # ood 374 | logger = create_logger(args.results_dir) 375 | if args.class_idx is not None: 376 | cls_list = ood_utils.get_superclass_list(args.data_set) 377 | main() 378 | 379 | 380 | 381 | 382 | --------------------------------------------------------------------------------
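For reference, the logit-based scores that eval_with_logits.py computes (maximum logit for `--metric softmax`, predictive entropy, and energy) can be reproduced in isolation. The sketch below is illustrative only and is not part of the repository: the `score_*` helper names and the dummy `logits` tensor are assumptions, and the energy score uses `torch.logsumexp`, which is mathematically equivalent to the `-log(sum(exp(...)))` form in `get_energy` but avoids overflow for large logits.

import torch
import torch.nn.functional as F


def score_max_logit(logits: torch.Tensor) -> torch.Tensor:
    # Mirrors get_softmax(): the maximum raw logit per sample (higher => more in-distribution).
    return logits.max(dim=-1).values


def score_entropy(logits: torch.Tensor) -> torch.Tensor:
    # Mirrors get_entropy(): Shannon entropy of the softmax distribution (higher => more OOD).
    return -(F.softmax(logits, dim=-1) * F.log_softmax(logits, dim=-1)).sum(dim=-1)


def score_energy(logits: torch.Tensor) -> torch.Tensor:
    # Mirrors get_energy(): E(x) = -log sum_c exp(f_c(x)) (higher => more OOD).
    return -torch.logsumexp(logits, dim=-1)


if __name__ == "__main__":
    logits = torch.randn(8, 1000)  # dummy batch of ImageNet-style logits (assumption)
    print(score_max_logit(logits)[:3])
    print(score_entropy(logits)[:3])
    print(score_energy(logits)[:3])

This difference in orientation is why get_eval_results assigns label 1 to the in-distribution scores for `softmax` and `gradnorm` (those scores increase for ID data) but label 1 to the OOD scores for `entropy` and `energy` (those scores increase for OOD data) before computing FPR95, AUROC, and AUPR.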