├── configs
│   ├── env.yml
│   ├── pretext
│   │   ├── moco_imagenet50.yml
│   │   ├── moco_imagenet100.yml
│   │   ├── moco_imagenet200.yml
│   │   ├── simclr_stl10.yml
│   │   ├── simclr_cifar10.yml
│   │   └── simclr_cifar20.yml
│   ├── scan
│   │   ├── imagenet_eval.yml
│   │   ├── scan_stl10.yml
│   │   ├── scan_cifar10.yml
│   │   ├── scan_cifar20.yml
│   │   ├── scan_imagenet_50.yml
│   │   ├── scan_imagenet_100.yml
│   │   └── scan_imagenet_200.yml
│   └── selflabel
│       ├── selflabel_stl10.yml
│       ├── selflabel_cifar10.yml
│       ├── selflabel_cifar20.yml
│       ├── selflabel_imagenet_200.yml
│       ├── selflabel_imagenet_50.yml
│       └── selflabel_imagenet_100.yml
├── images
│   ├── teaser.jpg
│   ├── pipeline.png
│   ├── prototypes_cifar10.jpg
│   └── tutorial
│       ├── prototypes_stl10.jpg
│       └── confusion_matrix_stl10.png
├── models
│   ├── resnet.py
│   ├── models.py
│   ├── resnet_cifar.py
│   └── resnet_stl.py
├── utils
│   ├── ema.py
│   ├── mypath.py
│   ├── collate.py
│   ├── config.py
│   ├── utils.py
│   ├── memory.py
│   ├── train_utils.py
│   ├── evaluate_utils.py
│   └── common_config.py
├── data
│   ├── imagenet_subsets
│   │   ├── imagenet_50.txt
│   │   ├── imagenet_100.txt
│   │   └── imagenet_200.txt
│   ├── custom_dataset.py
│   ├── imagenet.py
│   ├── augment.py
│   ├── stl.py
│   └── cifar.py
├── requirements.txt
├── tutorial_nn.py
├── moco.py
├── selflabel.py
├── losses
│   └── losses.py
├── eval.py
├── scan.py
├── TUTORIAL.md
├── simclr.py
├── README.md
└── LICENSE

/configs/env.yml:
--------------------------------------------------------------------------------
1 | root_dir: /path/where/to/store/results/
2 |
--------------------------------------------------------------------------------
/images/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wvangansbeke/Unsupervised-Classification/HEAD/images/teaser.jpg
--------------------------------------------------------------------------------
/images/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wvangansbeke/Unsupervised-Classification/HEAD/images/pipeline.png
--------------------------------------------------------------------------------
/images/prototypes_cifar10.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wvangansbeke/Unsupervised-Classification/HEAD/images/prototypes_cifar10.jpg
--------------------------------------------------------------------------------
/images/tutorial/prototypes_stl10.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wvangansbeke/Unsupervised-Classification/HEAD/images/tutorial/prototypes_stl10.jpg
--------------------------------------------------------------------------------
/images/tutorial/confusion_matrix_stl10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wvangansbeke/Unsupervised-Classification/HEAD/images/tutorial/confusion_matrix_stl10.png
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | Authors: Wouter Van Gansbeke, Simon Vandenhende
3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
4 | """
5 | import torch.nn as nn
6 | import torchvision.models as models
7 |
8 |
9 | def resnet50():
10 |     backbone = models.__dict__['resnet50']()
11 |     backbone.fc = nn.Identity()
12 |     return {'backbone': backbone, 'dim': 2048}
13 |
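The backbone factory above returns a dict rather than a bare module so that downstream wrappers can read the feature dimension without inspecting the architecture. For orientation, this is how the dict is consumed by `ContrastiveModel` in `models/models.py` (shown further down in this dump); the snippet is only an illustrative sketch of the wiring:

```python
from models.resnet import resnet50
from models.models import ContrastiveModel

backbone = resnet50()   # {'backbone': <ResNet-50 with fc replaced by Identity>, 'dim': 2048}
model = ContrastiveModel(backbone, head='mlp', features_dim=128)
# forward(): 2048-d pooled features -> MLP head -> L2-normalized 128-d embedding
```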
-------------------------------------------------------------------------------- /configs/pretext/moco_imagenet50.yml: -------------------------------------------------------------------------------- 1 | # Setup 2 | setup: moco # MoCo is used here 3 | 4 | # Model 5 | backbone: resnet50 6 | model_kwargs: 7 | head: mlp 8 | features_dim: 128 9 | 10 | # Dataset 11 | train_db_name: imagenet_50 12 | val_db_name: imagenet_50 13 | num_classes: 50 14 | temperature: 0.07 15 | 16 | # Batch size and workers 17 | batch_size: 256 18 | num_workers: 8 19 | 20 | # Transformations 21 | transformation_kwargs: 22 | crop_size: 224 23 | normalize: 24 | mean: [0.485, 0.456, 0.406] 25 | std: [0.229, 0.224, 0.225] 26 | -------------------------------------------------------------------------------- /configs/pretext/moco_imagenet100.yml: -------------------------------------------------------------------------------- 1 | # Setup 2 | setup: moco # MoCo is used here 3 | 4 | # Model 5 | backbone: resnet50 6 | model_kwargs: 7 | head: mlp 8 | features_dim: 128 9 | 10 | # Dataset 11 | train_db_name: imagenet_100 12 | val_db_name: imagenet_100 13 | num_classes: 100 14 | temperature: 0.07 15 | 16 | # Batch size and workers 17 | batch_size: 256 18 | num_workers: 8 19 | 20 | # Transformations 21 | transformation_kwargs: 22 | crop_size: 224 23 | normalize: 24 | mean: [0.485, 0.456, 0.406] 25 | std: [0.229, 0.224, 0.225] 26 | -------------------------------------------------------------------------------- /configs/pretext/moco_imagenet200.yml: -------------------------------------------------------------------------------- 1 | # Setup 2 | setup: moco # MoCo is used here 3 | 4 | # Model 5 | backbone: resnet50 6 | model_kwargs: 7 | head: mlp 8 | features_dim: 128 9 | 10 | # Dataset 11 | train_db_name: imagenet_200 12 | val_db_name: imagenet_200 13 | num_classes: 200 14 | temperature: 0.07 15 | 16 | # Batch size and workers 17 | batch_size: 256 18 | num_workers: 8 19 | 20 | # Transformations 21 | transformation_kwargs: 22 | crop_size: 224 23 | normalize: 24 | mean: [0.485, 0.456, 0.406] 25 | std: [0.229, 0.224, 0.225] 26 | -------------------------------------------------------------------------------- /utils/ema.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | 6 | class EMA(object): 7 | def __init__(self, model, alpha=0.999): 8 | self.shadow = {k: v.clone().detach() for k, v in model.state_dict().items()} 9 | self.param_keys = [k for k, _ in model.named_parameters()] 10 | self.alpha = alpha 11 | 12 | def update_params(self, model): 13 | state = model.state_dict() 14 | for name in self.param_keys: 15 | self.shadow[name].copy_(self.alpha * self.shadow[name] + (1 - self.alpha) * state[name]) 16 | 17 | def apply_shadow(self, model): 18 | model.load_state_dict(self.shadow, strict=True) 19 | -------------------------------------------------------------------------------- /utils/mypath.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import os 6 | 7 | 8 | class MyPath(object): 9 | @staticmethod 10 | def db_root_dir(database=''): 11 | db_names = {'cifar-10', 'stl-10', 'cifar-20', 'imagenet', 'imagenet_50', 'imagenet_100', 'imagenet_200'} 12 
| assert(database in db_names) 13 | 14 | if database == 'cifar-10': 15 | return '/path/to/cifar-10/' 16 | 17 | elif database == 'cifar-20': 18 | return '/path/to/cifar-20/' 19 | 20 | elif database == 'stl-10': 21 | return '/path/to/stl-10/' 22 | 23 | elif database in ['imagenet', 'imagenet_50', 'imagenet_100', 'imagenet_200']: 24 | return '/path/to/imagenet/' 25 | 26 | else: 27 | raise NotImplementedError 28 | -------------------------------------------------------------------------------- /configs/scan/imagenet_eval.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: scan 3 | 4 | # Loss 5 | criterion: scan 6 | criterion_kwargs: 7 | entropy_weight: 5.0 8 | 9 | # Model 10 | backbone: resnet50 11 | num_heads: 10 # Use multiple heads 12 | 13 | # Dataset 14 | train_db_name: imagenet 15 | val_db_name: imagenet 16 | num_classes: 1000 17 | num_neighbors: 50 18 | 19 | # Transformations 20 | augmentation_strategy: simclr 21 | augmentation_kwargs: 22 | random_resized_crop: 23 | size: 224 24 | scale: [0.2, 1.0] 25 | color_jitter_random_apply: 26 | p: 0.8 27 | color_jitter: 28 | brightness: 0.4 29 | contrast: 0.4 30 | saturation: 0.4 31 | hue: 0.1 32 | random_grayscale: 33 | p: 0.2 34 | normalize: 35 | mean: [0.485, 0.456, 0.406] 36 | std: [0.229, 0.224, 0.225] 37 | 38 | transformation_kwargs: 39 | crop_size: 224 40 | normalize: 41 | mean: [0.485, 0.456, 0.406] 42 | std: [0.229, 0.224, 0.225] 43 | 44 | num_workers: 12 45 | -------------------------------------------------------------------------------- /configs/scan/scan_stl10.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: scan 3 | 4 | # Loss 5 | criterion: scan 6 | criterion_kwargs: 7 | entropy_weight: 5.0 8 | 9 | # Weight update 10 | update_cluster_head_only: False # Update full network in SCAN 11 | num_heads: 1 # Only use one head 12 | 13 | # Model 14 | backbone: resnet18 15 | 16 | # Dataset 17 | train_db_name: stl-10 18 | val_db_name: stl-10 19 | num_classes: 10 20 | num_neighbors: 20 21 | 22 | # Transformations 23 | augmentation_strategy: ours 24 | augmentation_kwargs: 25 | crop_size: 96 26 | normalize: 27 | mean: [0.485, 0.456, 0.406] 28 | std: [0.229, 0.224, 0.225] 29 | num_strong_augs: 4 30 | cutout_kwargs: 31 | n_holes: 1 32 | length: 32 33 | random: True 34 | 35 | transformation_kwargs: 36 | crop_size: 96 37 | normalize: 38 | mean: [0.485, 0.456, 0.406] 39 | std: [0.229, 0.224, 0.225] 40 | 41 | # Hyperparameters 42 | optimizer: adam 43 | optimizer_kwargs: 44 | lr: 0.0001 45 | weight_decay: 0.0001 46 | epochs: 100 47 | batch_size: 128 48 | num_workers: 8 49 | 50 | # Scheduler 51 | scheduler: constant 52 | -------------------------------------------------------------------------------- /configs/selflabel/selflabel_stl10.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: selflabel 3 | 4 | # ema 5 | use_ema: False 6 | 7 | # Threshold 8 | confidence_threshold: 0.99 9 | 10 | # Loss 11 | criterion: confidence-cross-entropy 12 | criterion_kwargs: 13 | apply_class_balancing: True 14 | 15 | # Model 16 | backbone: resnet18 17 | num_heads: 1 18 | 19 | # Dataset 20 | train_db_name: stl-10 21 | val_db_name: stl-10 22 | num_classes: 10 23 | 24 | # Transformations 25 | augmentation_strategy: ours 26 | augmentation_kwargs: 27 | crop_size: 96 28 | normalize: 29 | mean: [0.485, 0.456, 0.406] 30 | std: [0.229, 0.224, 0.225] 31 | num_strong_augs: 4 32 | cutout_kwargs: 33 | n_holes: 1 34 | 
length: 32 35 | random: True 36 | 37 | transformation_kwargs: 38 | crop_size: 96 39 | normalize: 40 | mean: [0.485, 0.456, 0.406] 41 | std: [0.229, 0.224, 0.225] 42 | 43 | # Hyperparameters 44 | optimizer: adam 45 | optimizer_kwargs: 46 | lr: 0.0001 47 | weight_decay: 0.0001 48 | epochs: 100 49 | batch_size: 1000 50 | num_workers: 8 51 | 52 | # Scheduler 53 | scheduler: constant 54 | -------------------------------------------------------------------------------- /configs/scan/scan_cifar10.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: scan 3 | 4 | # Loss 5 | criterion: scan 6 | criterion_kwargs: 7 | entropy_weight: 5.0 8 | 9 | # Weight update 10 | update_cluster_head_only: False # Update full network in SCAN 11 | num_heads: 1 # Only use one head 12 | 13 | # Model 14 | backbone: resnet18 15 | 16 | # Dataset 17 | train_db_name: cifar-10 18 | val_db_name: cifar-10 19 | num_classes: 10 20 | num_neighbors: 20 21 | 22 | # Transformations 23 | augmentation_strategy: ours 24 | augmentation_kwargs: 25 | crop_size: 32 26 | normalize: 27 | mean: [0.4914, 0.4822, 0.4465] 28 | std: [0.2023, 0.1994, 0.2010] 29 | num_strong_augs: 4 30 | cutout_kwargs: 31 | n_holes: 1 32 | length: 16 33 | random: True 34 | 35 | transformation_kwargs: 36 | crop_size: 32 37 | normalize: 38 | mean: [0.4914, 0.4822, 0.4465] 39 | std: [0.2023, 0.1994, 0.2010] 40 | 41 | # Hyperparameters 42 | optimizer: adam 43 | optimizer_kwargs: 44 | lr: 0.0001 45 | weight_decay: 0.0001 46 | epochs: 50 47 | batch_size: 128 48 | num_workers: 8 49 | 50 | # Scheduler 51 | scheduler: constant 52 | -------------------------------------------------------------------------------- /configs/scan/scan_cifar20.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: scan 3 | 4 | # Loss 5 | criterion: scan 6 | criterion_kwargs: 7 | entropy_weight: 5.0 8 | 9 | # Weight update 10 | update_cluster_head_only: False # Update full network in SCAN 11 | num_heads: 1 # Only use one head 12 | 13 | # Model 14 | backbone: resnet18 15 | 16 | # Dataset 17 | train_db_name: cifar-20 18 | val_db_name: cifar-20 19 | num_classes: 20 20 | num_neighbors: 20 21 | 22 | # Transformations 23 | augmentation_strategy: ours 24 | augmentation_kwargs: 25 | crop_size: 32 26 | normalize: 27 | mean: [0.5071, 0.4867, 0.4408] 28 | std: [0.2675, 0.2565, 0.2761] 29 | num_strong_augs: 4 30 | cutout_kwargs: 31 | n_holes: 1 32 | length: 16 33 | random: True 34 | 35 | transformation_kwargs: 36 | crop_size: 32 37 | normalize: 38 | mean: [0.5071, 0.4867, 0.4408] 39 | std: [0.2675, 0.2565, 0.2761] 40 | 41 | # Hyperparameters 42 | optimizer: adam 43 | optimizer_kwargs: 44 | lr: 0.0001 45 | weight_decay: 0.0001 46 | epochs: 100 47 | batch_size: 512 48 | num_workers: 8 49 | 50 | # Scheduler 51 | scheduler: constant 52 | -------------------------------------------------------------------------------- /configs/selflabel/selflabel_cifar10.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: selflabel 3 | 4 | # ema 5 | use_ema: False 6 | 7 | # Threshold 8 | confidence_threshold: 0.99 9 | 10 | # Criterion 11 | criterion: confidence-cross-entropy 12 | criterion_kwargs: 13 | apply_class_balancing: True 14 | 15 | # Model 16 | backbone: resnet18 17 | num_heads: 1 18 | 19 | # Dataset 20 | train_db_name: cifar-10 21 | val_db_name: cifar-10 22 | num_classes: 10 23 | 24 | # Transformations 25 | augmentation_strategy: ours 26 | 
augmentation_kwargs: 27 | crop_size: 32 28 | normalize: 29 | mean: [0.4914, 0.4822, 0.4465] 30 | std: [0.2023, 0.1994, 0.2010] 31 | num_strong_augs: 4 32 | cutout_kwargs: 33 | n_holes: 1 34 | length: 16 35 | random: True 36 | 37 | transformation_kwargs: 38 | crop_size: 32 39 | normalize: 40 | mean: [0.4914, 0.4822, 0.4465] 41 | std: [0.2023, 0.1994, 0.2010] 42 | 43 | # Hyperparameters 44 | epochs: 200 45 | optimizer: adam 46 | optimizer_kwargs: 47 | lr: 0.0001 48 | weight_decay: 0.0001 49 | batch_size: 1000 50 | num_workers: 8 51 | 52 | # Scheduler 53 | scheduler: constant 54 | -------------------------------------------------------------------------------- /configs/selflabel/selflabel_cifar20.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: selflabel 3 | 4 | # ema 5 | use_ema: False 6 | 7 | # Threshold 8 | confidence_threshold: 0.99 9 | 10 | # Criterion 11 | criterion: confidence-cross-entropy 12 | criterion_kwargs: 13 | apply_class_balancing: True 14 | 15 | # Model 16 | backbone: resnet18 17 | num_heads: 1 18 | 19 | # Dataset 20 | train_db_name: cifar-20 21 | val_db_name: cifar-20 22 | num_classes: 20 23 | 24 | # Transformations 25 | augmentation_strategy: ours 26 | augmentation_kwargs: 27 | crop_size: 32 28 | normalize: 29 | mean: [0.5071, 0.4867, 0.4408] 30 | std: [0.2675, 0.2565, 0.2761] 31 | num_strong_augs: 4 32 | cutout_kwargs: 33 | n_holes: 1 34 | length: 16 35 | random: True 36 | 37 | transformation_kwargs: 38 | crop_size: 32 39 | normalize: 40 | mean: [0.5071, 0.4867, 0.4408] 41 | std: [0.2675, 0.2565, 0.2761] 42 | 43 | # Hyperparameters 44 | epochs: 200 45 | optimizer: adam 46 | optimizer_kwargs: 47 | lr: 0.0001 48 | weight_decay: 0.0001 49 | batch_size: 1000 50 | num_workers: 8 51 | 52 | # Scheduler 53 | scheduler: constant 54 | -------------------------------------------------------------------------------- /configs/selflabel/selflabel_imagenet_200.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: selflabel 3 | 4 | # Threshold 5 | confidence_threshold: 0.99 6 | 7 | # EMA 8 | use_ema: True 9 | ema_alpha: 0.999 10 | 11 | # Loss 12 | criterion: confidence-cross-entropy 13 | criterion_kwargs: 14 | apply_class_balancing: False 15 | 16 | # Model 17 | backbone: resnet50 18 | num_heads: 1 19 | 20 | # Dataset 21 | train_db_name: imagenet_200 22 | val_db_name: imagenet_200 23 | num_classes: 200 24 | 25 | # Transformations 26 | augmentation_strategy: ours 27 | augmentation_kwargs: 28 | crop_size: 224 29 | normalize: 30 | mean: [0.485, 0.456, 0.406] 31 | std: [0.229, 0.224, 0.225] 32 | num_strong_augs: 4 33 | cutout_kwargs: 34 | n_holes: 1 35 | length: 75 36 | random: True 37 | 38 | transformation_kwargs: 39 | crop_size: 224 40 | normalize: 41 | mean: [0.485, 0.456, 0.406] 42 | std: [0.229, 0.224, 0.225] 43 | 44 | # Hyperparameters 45 | optimizer: sgd 46 | optimizer_kwargs: 47 | lr: 0.03 48 | weight_decay: 0.0 49 | nesterov: False 50 | momentum: 0.9 51 | epochs: 25 52 | batch_size: 512 53 | num_workers: 8 54 | 55 | # Scheduler 56 | scheduler: constant 57 | -------------------------------------------------------------------------------- /configs/selflabel/selflabel_imagenet_50.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: selflabel 3 | 4 | # Threshold 5 | confidence_threshold: 0.99 6 | 7 | # EMA 8 | use_ema: True 9 | ema_alpha: 0.999 10 | 11 | # Loss 12 | criterion: confidence-cross-entropy 13 | 
criterion_kwargs: 14 | apply_class_balancing: False 15 | 16 | # Model 17 | backbone: resnet50 18 | num_heads: 1 19 | 20 | # Dataset 21 | train_db_name: imagenet_50 22 | val_db_name: imagenet_50 23 | num_classes: 50 24 | 25 | # Transformations 26 | augmentation_strategy: ours 27 | augmentation_kwargs: 28 | crop_size: 224 29 | normalize: 30 | mean: [0.485, 0.456, 0.406] 31 | std: [0.229, 0.224, 0.225] 32 | num_strong_augs: 4 33 | cutout_kwargs: 34 | n_holes: 1 35 | length: 75 36 | random: True 37 | 38 | transformation_kwargs: 39 | crop_size: 224 40 | normalize: 41 | mean: [0.485, 0.456, 0.406] 42 | std: [0.229, 0.224, 0.225] 43 | 44 | # Hyperparameters 45 | optimizer: sgd 46 | optimizer_kwargs: 47 | lr: 0.03 48 | weight_decay: 0.0 49 | nesterov: False 50 | momentum: 0.9 51 | epochs: 25 52 | batch_size: 512 53 | num_workers: 16 54 | 55 | # Scheduler 56 | scheduler: constant 57 | -------------------------------------------------------------------------------- /configs/selflabel/selflabel_imagenet_100.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: selflabel 3 | 4 | # Threshold 5 | confidence_threshold: 0.99 6 | 7 | # EMA 8 | use_ema: True 9 | ema_alpha: 0.999 10 | 11 | # Loss 12 | criterion: confidence-cross-entropy 13 | criterion_kwargs: 14 | apply_class_balancing: False 15 | 16 | # Model 17 | backbone: resnet50 18 | num_heads: 1 19 | 20 | # Dataset 21 | train_db_name: imagenet_100 22 | val_db_name: imagenet_100 23 | num_classes: 100 24 | 25 | # Transformations 26 | augmentation_strategy: ours 27 | augmentation_kwargs: 28 | crop_size: 224 29 | normalize: 30 | mean: [0.485, 0.456, 0.406] 31 | std: [0.229, 0.224, 0.225] 32 | num_strong_augs: 4 33 | cutout_kwargs: 34 | n_holes: 1 35 | length: 75 36 | random: True 37 | 38 | transformation_kwargs: 39 | crop_size: 224 40 | normalize: 41 | mean: [0.485, 0.456, 0.406] 42 | std: [0.229, 0.224, 0.225] 43 | 44 | # Hyperparameters 45 | optimizer: sgd 46 | optimizer_kwargs: 47 | lr: 0.03 48 | weight_decay: 0.0 49 | nesterov: False 50 | momentum: 0.9 51 | epochs: 25 52 | batch_size: 512 53 | num_workers: 12 54 | 55 | # Scheduler 56 | scheduler: constant 57 | -------------------------------------------------------------------------------- /configs/pretext/simclr_stl10.yml: -------------------------------------------------------------------------------- 1 | # Setup 2 | setup: simclr 3 | 4 | # Model 5 | backbone: resnet18 6 | model_kwargs: 7 | head: mlp 8 | features_dim: 128 9 | 10 | # Dataset 11 | train_db_name: stl-10 12 | val_db_name: stl-10 13 | num_classes: 10 14 | 15 | # Loss 16 | criterion: simclr 17 | criterion_kwargs: 18 | temperature: 0.1 19 | 20 | # Hyperparameters 21 | epochs: 500 22 | optimizer: sgd 23 | optimizer_kwargs: 24 | nesterov: False 25 | weight_decay: 0.0001 26 | momentum: 0.9 27 | lr: 0.4 28 | scheduler: cosine 29 | scheduler_kwargs: 30 | lr_decay_rate: 0.1 31 | batch_size: 512 32 | num_workers: 8 33 | 34 | # Transformations 35 | augmentation_strategy: simclr 36 | augmentation_kwargs: 37 | random_resized_crop: 38 | size: 96 39 | scale: [0.2, 1.0] 40 | color_jitter_random_apply: 41 | p: 0.8 42 | color_jitter: 43 | brightness: 0.4 44 | contrast: 0.4 45 | saturation: 0.4 46 | hue: 0.1 47 | random_grayscale: 48 | p: 0.2 49 | normalize: 50 | mean: [0.485, 0.456, 0.406] 51 | std: [0.229, 0.224, 0.225] 52 | 53 | transformation_kwargs: 54 | crop_size: 96 55 | normalize: 56 | mean: [0.485, 0.456, 0.406] 57 | std: [0.229, 0.224, 0.225] 58 | 
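
The `augmentation_kwargs` block in the SimCLR configs above maps one-to-one onto torchvision transforms. A minimal sketch of that mapping follows; the actual construction lives in `utils/common_config.py`, whose contents are not included in this dump, so treat the wiring (and the function name) as an assumption rather than the repo's exact code:

```python
import torchvision.transforms as transforms

def build_simclr_transform(aug):
    """aug is the 'augmentation_kwargs' dict parsed from the YAML config."""
    return transforms.Compose([
        transforms.RandomResizedCrop(**aug['random_resized_crop']),
        transforms.RandomApply([transforms.ColorJitter(**aug['color_jitter'])],
                               p=aug['color_jitter_random_apply']['p']),
        transforms.RandomGrayscale(**aug['random_grayscale']),
        transforms.ToTensor(),
        transforms.Normalize(**aug['normalize']),
    ])
```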
-------------------------------------------------------------------------------- /configs/pretext/simclr_cifar10.yml: -------------------------------------------------------------------------------- 1 | # Setup 2 | setup: simclr 3 | 4 | # Model 5 | backbone: resnet18 6 | model_kwargs: 7 | head: mlp 8 | features_dim: 128 9 | 10 | # Dataset 11 | train_db_name: cifar-10 12 | val_db_name: cifar-10 13 | num_classes: 10 14 | 15 | # Loss 16 | criterion: simclr 17 | criterion_kwargs: 18 | temperature: 0.1 19 | 20 | # Hyperparameters 21 | epochs: 500 22 | optimizer: sgd 23 | optimizer_kwargs: 24 | nesterov: False 25 | weight_decay: 0.0001 26 | momentum: 0.9 27 | lr: 0.4 28 | scheduler: cosine 29 | scheduler_kwargs: 30 | lr_decay_rate: 0.1 31 | batch_size: 512 32 | num_workers: 8 33 | 34 | # Transformations 35 | augmentation_strategy: simclr 36 | augmentation_kwargs: 37 | random_resized_crop: 38 | size: 32 39 | scale: [0.2, 1.0] 40 | color_jitter_random_apply: 41 | p: 0.8 42 | color_jitter: 43 | brightness: 0.4 44 | contrast: 0.4 45 | saturation: 0.4 46 | hue: 0.1 47 | random_grayscale: 48 | p: 0.2 49 | normalize: 50 | mean: [0.4914, 0.4822, 0.4465] 51 | std: [0.2023, 0.1994, 0.2010] 52 | 53 | transformation_kwargs: 54 | crop_size: 32 55 | normalize: 56 | mean: [0.4914, 0.4822, 0.4465] 57 | std: [0.2023, 0.1994, 0.2010] 58 | -------------------------------------------------------------------------------- /configs/pretext/simclr_cifar20.yml: -------------------------------------------------------------------------------- 1 | # Setup 2 | setup: simclr 3 | 4 | # Model 5 | backbone: resnet18 6 | model_kwargs: 7 | head: mlp 8 | features_dim: 128 9 | 10 | # Dataset 11 | train_db_name: cifar-20 12 | val_db_name: cifar-20 13 | num_classes: 20 14 | 15 | # Loss 16 | criterion: simclr 17 | criterion_kwargs: 18 | temperature: 0.1 19 | 20 | # Hyperparameters 21 | epochs: 500 22 | optimizer: sgd 23 | optimizer_kwargs: 24 | nesterov: False 25 | weight_decay: 0.0001 26 | momentum: 0.9 27 | lr: 0.4 28 | scheduler: cosine 29 | scheduler_kwargs: 30 | lr_decay_rate: 0.1 31 | batch_size: 512 32 | num_workers: 8 33 | 34 | # Transformations 35 | augmentation_strategy: simclr 36 | augmentation_kwargs: 37 | random_resized_crop: 38 | size: 32 39 | scale: [0.2, 1.0] 40 | color_jitter_random_apply: 41 | p: 0.8 42 | color_jitter: 43 | brightness: 0.4 44 | contrast: 0.4 45 | saturation: 0.4 46 | hue: 0.1 47 | random_grayscale: 48 | p: 0.2 49 | normalize: 50 | mean: [0.5071, 0.4867, 0.4408] 51 | std: [0.2675, 0.2565, 0.2761] 52 | 53 | transformation_kwargs: 54 | crop_size: 32 55 | normalize: 56 | mean: [0.5071, 0.4867, 0.4408] 57 | std: [0.2675, 0.2565, 0.2761] 58 | -------------------------------------------------------------------------------- /data/imagenet_subsets/imagenet_50.txt: -------------------------------------------------------------------------------- 1 | n01601694 Dipper 2 | n01669191 Box Turtle 3 | n01755581 Diamondback Snake 4 | n01770393 Scorpion 5 | n01855672 Goose 6 | n02018207 Water Hen 7 | n02058221 Albatross 8 | n02096177 Cairn Terrier 9 | n02097130 Giant Schnauzer 10 | n02099267 Flat-Coated Retriever 11 | n02100877 Irish Setter 12 | n02104365 Schipperke 13 | n02106030 Collie 14 | n02114855 Coyote 15 | n02125311 Mountain Lion 16 | n02133161 Black Bear 17 | n02484975 Guenon 18 | n02489166 Proboscis Monkey 19 | n02747177 Trash Bin 20 | n02906734 Broom 21 | n03124170 Cowboy Hat 22 | n03272010 Electric Guitar 23 | n03337140 File Cabinet 24 | n03483316 Hair Dryer 25 | n03498962 Hatchet 26 | n03710721 Maillot 27 | 
n03717622 Manhole Cover
28 | n03733281 Maze
29 | n03759954 Microphone
30 | n03775071 Mitten
31 | n03814639 Neck Brace
32 | n03837869 Obelisk
33 | n03838899 Oboe
34 | n03854065 Organ
35 | n03954731 Woodworking Plane
36 | n03983396 Soda Bottle
37 | n04026417 Purse
38 | n04200800 Shoe Shop
39 | n04209239 Shower Curtain
40 | n04311004 Steel Arch Bridge
41 | n04380533 Table Lamp
42 | n04428191 Threshing Machine
43 | n04443257 Tobacco Shop
44 | n04509417 Unicycle
45 | n04525305 Vending Machine
46 | n04554684 Washing Machine
47 | n04606251 Wreck
48 | n07583066 Guacamole
49 | n07711569 Mashed Potato
50 | n07753592 Banana
51 |
--------------------------------------------------------------------------------
/configs/scan/scan_imagenet_50.yml:
--------------------------------------------------------------------------------
1 | # setup
2 | setup: scan
3 |
4 | # Loss
5 | criterion: scan
6 | criterion_kwargs:
7 |    entropy_weight: 5.0
8 |
9 | # Model
10 | backbone: resnet50
11 |
12 | # Weight update
13 | update_cluster_head_only: True # Train only linear layer during SCAN
14 | num_heads: 10 # Use multiple heads
15 |
16 | # Dataset
17 | train_db_name: imagenet_50
18 | val_db_name: imagenet_50
19 | num_classes: 50
20 | num_neighbors: 50
21 |
22 | # Transformations
23 | augmentation_strategy: simclr
24 | augmentation_kwargs:
25 |    random_resized_crop:
26 |       size: 224
27 |       scale: [0.2, 1.0]
28 |    color_jitter_random_apply:
29 |       p: 0.8
30 |    color_jitter:
31 |       brightness: 0.4
32 |       contrast: 0.4
33 |       saturation: 0.4
34 |       hue: 0.1
35 |    random_grayscale:
36 |       p: 0.2
37 |    normalize:
38 |       mean: [0.485, 0.456, 0.406]
39 |       std: [0.229, 0.224, 0.225]
40 |
41 | transformation_kwargs:
42 |    crop_size: 224
43 |    normalize:
44 |       mean: [0.485, 0.456, 0.406]
45 |       std: [0.229, 0.224, 0.225]
46 |
47 | # Hyperparameters
48 | optimizer: sgd
49 | optimizer_kwargs:
50 |    lr: 5.0
51 |    weight_decay: 0.0000
52 |    nesterov: False
53 |    momentum: 0.9
54 | epochs: 100
55 | batch_size: 512
56 | num_workers: 12
57 |
58 | # Scheduler
59 | scheduler: constant
60 |
--------------------------------------------------------------------------------
/utils/collate.py:
--------------------------------------------------------------------------------
1 | """
2 | Authors: Wouter Van Gansbeke, Simon Vandenhende
3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
4 | """
5 | import torch
6 | import numpy as np
7 | import collections.abc
8 |
9 |
10 | """ Custom collate function """
11 | def collate_custom(batch):
12 |     if isinstance(batch[0], np.int64):
13 |         return np.stack(batch, 0)
14 |
15 |     if isinstance(batch[0], torch.Tensor):
16 |         return torch.stack(batch, 0)
17 |
18 |     elif isinstance(batch[0], np.ndarray):
19 |         return np.stack(batch, 0)
20 |
21 |     elif isinstance(batch[0], int):
22 |         return torch.LongTensor(batch)
23 |
24 |     elif isinstance(batch[0], float):
25 |         return torch.FloatTensor(batch)
26 |
27 |     elif isinstance(batch[0], str):  # replaces torch._six.string_classes, removed in recent PyTorch
28 |         return batch
29 |
30 |     elif isinstance(batch[0], collections.abc.Mapping):  # collections.Mapping was removed in Python 3.10
31 |         batch_modified = {key: collate_custom([d[key] for d in batch]) for key in batch[0] if key.find('idx') < 0}
32 |         return batch_modified
33 |
34 |     elif isinstance(batch[0], collections.abc.Sequence):
35 |         transposed = zip(*batch)
36 |         return [collate_custom(samples) for samples in transposed]
37 |
38 |     raise TypeError(('Type is {}'.format(type(batch[0]))))
39 |
--------------------------------------------------------------------------------
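`collate_custom` mirrors PyTorch's default collate but recurses through dict samples and silently drops any key containing 'idx'. A short usage sketch (`dataset` here is a stand-in for any of the repo's datasets that return `{'image': ..., 'target': ...}` dicts):

```python
from torch.utils.data import DataLoader
from utils.collate import collate_custom

loader = DataLoader(dataset, batch_size=128, num_workers=8, shuffle=True,
                    collate_fn=collate_custom, pin_memory=True, drop_last=True)
batch = next(iter(loader))
# batch['image'] -> (128, 3, H, W) float tensor; batch['target'] -> (128,) LongTensor
```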
/configs/scan/scan_imagenet_100.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: scan 3 | 4 | # Loss 5 | criterion: scan 6 | criterion_kwargs: 7 | entropy_weight: 5.0 8 | 9 | # Model 10 | backbone: resnet50 11 | 12 | # Weight update 13 | update_cluster_head_only: True # Train only linear layer during SCAN 14 | num_heads: 10 # Use multiple heads 15 | 16 | # Dataset 17 | train_db_name: imagenet_100 18 | val_db_name: imagenet_100 19 | num_classes: 100 20 | num_neighbors: 50 21 | 22 | # Transformations 23 | augmentation_strategy: simclr 24 | augmentation_kwargs: 25 | random_resized_crop: 26 | size: 224 27 | scale: [0.2, 1.0] 28 | color_jitter_random_apply: 29 | p: 0.8 30 | color_jitter: 31 | brightness: 0.4 32 | contrast: 0.4 33 | saturation: 0.4 34 | hue: 0.1 35 | random_grayscale: 36 | p: 0.2 37 | normalize: 38 | mean: [0.485, 0.456, 0.406] 39 | std: [0.229, 0.224, 0.225] 40 | 41 | transformation_kwargs: 42 | crop_size: 224 43 | normalize: 44 | mean: [0.485, 0.456, 0.406] 45 | std: [0.229, 0.224, 0.225] 46 | 47 | # Hyperparameters 48 | optimizer: sgd 49 | optimizer_kwargs: 50 | lr: 5.0 51 | weight_decay: 0.0000 52 | nesterov: False 53 | momentum: 0.9 54 | epochs: 100 55 | batch_size: 1024 56 | num_workers: 16 57 | 58 | # Scheduler 59 | scheduler: constant 60 | -------------------------------------------------------------------------------- /configs/scan/scan_imagenet_200.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: scan 3 | 4 | # Loss 5 | criterion: scan 6 | criterion_kwargs: 7 | entropy_weight: 5.0 8 | 9 | # Model 10 | backbone: resnet50 11 | 12 | # Weight update 13 | update_cluster_head_only: True # Train only linear layer during SCAN 14 | num_heads: 10 # Use multiple heads 15 | 16 | # Dataset 17 | train_db_name: imagenet_200 18 | val_db_name: imagenet_200 19 | num_classes: 200 20 | num_neighbors: 50 21 | 22 | # Transformations 23 | augmentation_strategy: simclr 24 | augmentation_kwargs: 25 | random_resized_crop: 26 | size: 224 27 | scale: [0.2, 1.0] 28 | color_jitter_random_apply: 29 | p: 0.8 30 | color_jitter: 31 | brightness: 0.4 32 | contrast: 0.4 33 | saturation: 0.4 34 | hue: 0.1 35 | random_grayscale: 36 | p: 0.2 37 | normalize: 38 | mean: [0.485, 0.456, 0.406] 39 | std: [0.229, 0.224, 0.225] 40 | 41 | transformation_kwargs: 42 | crop_size: 224 43 | normalize: 44 | mean: [0.485, 0.456, 0.406] 45 | std: [0.229, 0.224, 0.225] 46 | 47 | # Hyperparameters 48 | optimizer: sgd 49 | optimizer_kwargs: 50 | lr: 5.0 51 | weight_decay: 0.0000 52 | nesterov: False 53 | momentum: 0.9 54 | epochs: 100 55 | batch_size: 1024 56 | num_workers: 12 57 | 58 | # Scheduler 59 | scheduler: constant 60 | -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import os 6 | import yaml 7 | from easydict import EasyDict 8 | from utils.utils import mkdir_if_missing 9 | 10 | def create_config(config_file_env, config_file_exp): 11 | # Config for environment path 12 | with open(config_file_env, 'r') as stream: 13 | root_dir = yaml.safe_load(stream)['root_dir'] 14 | 15 | with open(config_file_exp, 'r') as stream: 16 | config = yaml.safe_load(stream) 17 | 18 | cfg = EasyDict() 19 | 20 | # Copy 21 | for k, v 
in config.items(): 22 | cfg[k] = v 23 | 24 | # Set paths for pretext task (These directories are needed in every stage) 25 | base_dir = os.path.join(root_dir, cfg['train_db_name']) 26 | pretext_dir = os.path.join(base_dir, 'pretext') 27 | mkdir_if_missing(base_dir) 28 | mkdir_if_missing(pretext_dir) 29 | cfg['pretext_dir'] = pretext_dir 30 | cfg['pretext_checkpoint'] = os.path.join(pretext_dir, 'checkpoint.pth.tar') 31 | cfg['pretext_model'] = os.path.join(pretext_dir, 'model.pth.tar') 32 | cfg['topk_neighbors_train_path'] = os.path.join(pretext_dir, 'topk-train-neighbors.npy') 33 | cfg['topk_neighbors_val_path'] = os.path.join(pretext_dir, 'topk-val-neighbors.npy') 34 | 35 | # If we perform clustering or self-labeling step we need additional paths. 36 | # We also include a run identifier to support multiple runs w/ same hyperparams. 37 | if cfg['setup'] in ['scan', 'selflabel']: 38 | base_dir = os.path.join(root_dir, cfg['train_db_name']) 39 | scan_dir = os.path.join(base_dir, 'scan') 40 | selflabel_dir = os.path.join(base_dir, 'selflabel') 41 | mkdir_if_missing(base_dir) 42 | mkdir_if_missing(scan_dir) 43 | mkdir_if_missing(selflabel_dir) 44 | cfg['scan_dir'] = scan_dir 45 | cfg['scan_checkpoint'] = os.path.join(scan_dir, 'checkpoint.pth.tar') 46 | cfg['scan_model'] = os.path.join(scan_dir, 'model.pth.tar') 47 | cfg['selflabel_dir'] = selflabel_dir 48 | cfg['selflabel_checkpoint'] = os.path.join(selflabel_dir, 'checkpoint.pth.tar') 49 | cfg['selflabel_model'] = os.path.join(selflabel_dir, 'model.pth.tar') 50 | 51 | return cfg 52 | -------------------------------------------------------------------------------- /models/models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class ContrastiveModel(nn.Module): 11 | def __init__(self, backbone, head='mlp', features_dim=128): 12 | super(ContrastiveModel, self).__init__() 13 | self.backbone = backbone['backbone'] 14 | self.backbone_dim = backbone['dim'] 15 | self.head = head 16 | 17 | if head == 'linear': 18 | self.contrastive_head = nn.Linear(self.backbone_dim, features_dim) 19 | 20 | elif head == 'mlp': 21 | self.contrastive_head = nn.Sequential( 22 | nn.Linear(self.backbone_dim, self.backbone_dim), 23 | nn.ReLU(), nn.Linear(self.backbone_dim, features_dim)) 24 | 25 | else: 26 | raise ValueError('Invalid head {}'.format(head)) 27 | 28 | def forward(self, x): 29 | features = self.contrastive_head(self.backbone(x)) 30 | features = F.normalize(features, dim = 1) 31 | return features 32 | 33 | 34 | class ClusteringModel(nn.Module): 35 | def __init__(self, backbone, nclusters, nheads=1): 36 | super(ClusteringModel, self).__init__() 37 | self.backbone = backbone['backbone'] 38 | self.backbone_dim = backbone['dim'] 39 | self.nheads = nheads 40 | assert(isinstance(self.nheads, int)) 41 | assert(self.nheads > 0) 42 | self.cluster_head = nn.ModuleList([nn.Linear(self.backbone_dim, nclusters) for _ in range(self.nheads)]) 43 | 44 | def forward(self, x, forward_pass='default'): 45 | if forward_pass == 'default': 46 | features = self.backbone(x) 47 | out = [cluster_head(features) for cluster_head in self.cluster_head] 48 | 49 | elif forward_pass == 'backbone': 50 | out = self.backbone(x) 51 | 52 | elif forward_pass == 'head': 53 | out = [cluster_head(x) for cluster_head 
in self.cluster_head] 54 | 55 | elif forward_pass == 'return_all': 56 | features = self.backbone(x) 57 | out = {'features': features, 'output': [cluster_head(features) for cluster_head in self.cluster_head]} 58 | 59 | else: 60 | raise ValueError('Invalid forward pass {}'.format(forward_pass)) 61 | 62 | return out 63 | -------------------------------------------------------------------------------- /data/custom_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import numpy as np 6 | import torch 7 | from torch.utils.data import Dataset 8 | 9 | """ 10 | AugmentedDataset 11 | Returns an image together with an augmentation. 12 | """ 13 | class AugmentedDataset(Dataset): 14 | def __init__(self, dataset): 15 | super(AugmentedDataset, self).__init__() 16 | transform = dataset.transform 17 | dataset.transform = None 18 | self.dataset = dataset 19 | 20 | if isinstance(transform, dict): 21 | self.image_transform = transform['standard'] 22 | self.augmentation_transform = transform['augment'] 23 | 24 | else: 25 | self.image_transform = transform 26 | self.augmentation_transform = transform 27 | 28 | def __len__(self): 29 | return len(self.dataset) 30 | 31 | def __getitem__(self, index): 32 | sample = self.dataset.__getitem__(index) 33 | image = sample['image'] 34 | 35 | sample['image'] = self.image_transform(image) 36 | sample['image_augmented'] = self.augmentation_transform(image) 37 | 38 | return sample 39 | 40 | 41 | """ 42 | NeighborsDataset 43 | Returns an image with one of its neighbors. 44 | """ 45 | class NeighborsDataset(Dataset): 46 | def __init__(self, dataset, indices, num_neighbors=None): 47 | super(NeighborsDataset, self).__init__() 48 | transform = dataset.transform 49 | 50 | if isinstance(transform, dict): 51 | self.anchor_transform = transform['standard'] 52 | self.neighbor_transform = transform['augment'] 53 | else: 54 | self.anchor_transform = transform 55 | self.neighbor_transform = transform 56 | 57 | dataset.transform = None 58 | self.dataset = dataset 59 | self.indices = indices # Nearest neighbor indices (np.array [len(dataset) x k]) 60 | if num_neighbors is not None: 61 | self.indices = self.indices[:, :num_neighbors+1] 62 | assert(self.indices.shape[0] == len(self.dataset)) 63 | 64 | def __len__(self): 65 | return len(self.dataset) 66 | 67 | def __getitem__(self, index): 68 | output = {} 69 | anchor = self.dataset.__getitem__(index) 70 | 71 | neighbor_index = np.random.choice(self.indices[index], 1)[0] 72 | neighbor = self.dataset.__getitem__(neighbor_index) 73 | 74 | anchor['image'] = self.anchor_transform(anchor['image']) 75 | neighbor['image'] = self.neighbor_transform(neighbor['image']) 76 | 77 | output['anchor'] = anchor['image'] 78 | output['neighbor'] = neighbor['image'] 79 | output['possible_neighbors'] = torch.from_numpy(self.indices[index]) 80 | output['target'] = anchor['target'] 81 | 82 | return output 83 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import os 6 | import torch 7 | import numpy as np 8 | import errno 9 | 10 | def mkdir_if_missing(directory): 11 | if 
not os.path.exists(directory):
12 |         try:
13 |             os.makedirs(directory)
14 |         except OSError as e:
15 |             if e.errno != errno.EEXIST:
16 |                 raise
17 |
18 |
19 | class AverageMeter(object):
20 |     def __init__(self, name, fmt=':f'):
21 |         self.name = name
22 |         self.fmt = fmt
23 |         self.reset()
24 |
25 |     def reset(self):
26 |         self.val = 0
27 |         self.avg = 0
28 |         self.sum = 0
29 |         self.count = 0
30 |
31 |     def update(self, val, n=1):
32 |         self.val = val
33 |         self.sum += val * n
34 |         self.count += n
35 |         self.avg = self.sum / self.count
36 |
37 |     def __str__(self):
38 |         fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
39 |         return fmtstr.format(**self.__dict__)
40 |
41 |
42 | class ProgressMeter(object):
43 |     def __init__(self, num_batches, meters, prefix=""):
44 |         self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
45 |         self.meters = meters
46 |         self.prefix = prefix
47 |
48 |     def display(self, batch):
49 |         entries = [self.prefix + self.batch_fmtstr.format(batch)]
50 |         entries += [str(meter) for meter in self.meters]
51 |         print('\t'.join(entries))
52 |
53 |     def _get_batch_fmtstr(self, num_batches):
54 |         num_digits = len(str(num_batches))
55 |         fmt = '{:' + str(num_digits) + 'd}'
56 |         return '[' + fmt + '/' + fmt.format(num_batches) + ']'
57 |
58 |
59 | @torch.no_grad()
60 | def fill_memory_bank(loader, model, memory_bank):
61 |     model.eval()
62 |     memory_bank.reset()
63 |
64 |     for i, batch in enumerate(loader):
65 |         images = batch['image'].cuda(non_blocking=True)
66 |         targets = batch['target'].cuda(non_blocking=True)
67 |         output = model(images)
68 |         memory_bank.update(output, targets)
69 |         if i % 100 == 0:
70 |             print('Fill Memory Bank [%d/%d]' %(i, len(loader)))
71 |
72 |
73 | def confusion_matrix(predictions, gt, class_names, output_file=None):
74 |     # Plot confusion_matrix and store result to output_file
75 |     import sklearn.metrics
76 |     import matplotlib.pyplot as plt
77 |     confusion_matrix = sklearn.metrics.confusion_matrix(gt, predictions)
78 |     confusion_matrix = confusion_matrix / np.sum(confusion_matrix, 1, keepdims=True)  # normalize each row; without keepdims the division broadcasts over the wrong axis
79 |
80 |     fig, axes = plt.subplots(1)
81 |     plt.imshow(confusion_matrix, cmap='Blues')
82 |     axes.set_xticks([i for i in range(len(class_names))])
83 |     axes.set_yticks([i for i in range(len(class_names))])
84 |     axes.set_xticklabels(class_names, ha='right', fontsize=8, rotation=40)
85 |     axes.set_yticklabels(class_names, ha='right', fontsize=8)
86 |
87 |     for (i, j), z in np.ndenumerate(confusion_matrix):
88 |         if i == j:
89 |             axes.text(j, i, '%d' %(100*z), ha='center', va='center', color='white', fontsize=6)
90 |         else:
91 |             pass
92 |
93 |     plt.tight_layout()
94 |     if output_file is None:
95 |         plt.show()
96 |     else:
97 |         plt.savefig(output_file, dpi=300, bbox_inches='tight')
98 |     plt.close()
99 |
--------------------------------------------------------------------------------
/utils/memory.py:
--------------------------------------------------------------------------------
1 | """
2 | Authors: Wouter Van Gansbeke, Simon Vandenhende
3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
4 | """
5 | import numpy as np
6 | import torch
7 |
8 |
9 | class MemoryBank(object):
10 |     def __init__(self, n, dim, num_classes, temperature):
11 |         self.n = n
12 |         self.dim = dim
13 |         self.features = torch.FloatTensor(self.n, self.dim)
14 |         self.targets = torch.LongTensor(self.n)
15 |         self.ptr = 0
16 |         self.device = 'cpu'
17 |         self.K = 100
18 |         self.temperature = temperature
19 |         self.C = num_classes
20 |
21 |     def weighted_knn(self, predictions):
22 |         # perform
weighted knn 23 | retrieval_one_hot = torch.zeros(self.K, self.C).to(self.device) 24 | batchSize = predictions.shape[0] 25 | correlation = torch.matmul(predictions, self.features.t()) 26 | yd, yi = correlation.topk(self.K, dim=1, largest=True, sorted=True) 27 | candidates = self.targets.view(1,-1).expand(batchSize, -1) 28 | retrieval = torch.gather(candidates, 1, yi) 29 | retrieval_one_hot.resize_(batchSize * self.K, self.C).zero_() 30 | retrieval_one_hot.scatter_(1, retrieval.view(-1, 1), 1) 31 | yd_transform = yd.clone().div_(self.temperature).exp_() 32 | probs = torch.sum(torch.mul(retrieval_one_hot.view(batchSize, -1 , self.C), 33 | yd_transform.view(batchSize, -1, 1)), 1) 34 | _, class_preds = probs.sort(1, True) 35 | class_pred = class_preds[:, 0] 36 | 37 | return class_pred 38 | 39 | def knn(self, predictions): 40 | # perform knn 41 | correlation = torch.matmul(predictions, self.features.t()) 42 | sample_pred = torch.argmax(correlation, dim=1) 43 | class_pred = torch.index_select(self.targets, 0, sample_pred) 44 | return class_pred 45 | 46 | def mine_nearest_neighbors(self, topk, calculate_accuracy=True): 47 | # mine the topk nearest neighbors for every sample 48 | import faiss 49 | features = self.features.cpu().numpy() 50 | n, dim = features.shape[0], features.shape[1] 51 | index = faiss.IndexFlatIP(dim) 52 | index = faiss.index_cpu_to_all_gpus(index) 53 | index.add(features) 54 | distances, indices = index.search(features, topk+1) # Sample itself is included 55 | 56 | # evaluate 57 | if calculate_accuracy: 58 | targets = self.targets.cpu().numpy() 59 | neighbor_targets = np.take(targets, indices[:,1:], axis=0) # Exclude sample itself for eval 60 | anchor_targets = np.repeat(targets.reshape(-1,1), topk, axis=1) 61 | accuracy = np.mean(neighbor_targets == anchor_targets) 62 | return indices, accuracy 63 | 64 | else: 65 | return indices 66 | 67 | def reset(self): 68 | self.ptr = 0 69 | 70 | def update(self, features, targets): 71 | b = features.size(0) 72 | 73 | assert(b + self.ptr <= self.n) 74 | 75 | self.features[self.ptr:self.ptr+b].copy_(features.detach()) 76 | self.targets[self.ptr:self.ptr+b].copy_(targets.detach()) 77 | self.ptr += b 78 | 79 | def to(self, device): 80 | self.features = self.features.to(device) 81 | self.targets = self.targets.to(device) 82 | self.device = device 83 | 84 | def cpu(self): 85 | self.to('cpu') 86 | 87 | def cuda(self): 88 | self.to('cuda:0') 89 | -------------------------------------------------------------------------------- /data/imagenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import os 6 | import torch 7 | import torchvision.datasets as datasets 8 | import torch.utils.data as data 9 | from PIL import Image 10 | from utils.mypath import MyPath 11 | from torchvision import transforms as tf 12 | from glob import glob 13 | 14 | 15 | class ImageNet(datasets.ImageFolder): 16 | def __init__(self, root=MyPath.db_root_dir('imagenet'), split='train', transform=None): 17 | super(ImageNet, self).__init__(root=os.path.join(root, 'ILSVRC2012_img_%s' %(split)), 18 | transform=None) 19 | self.transform = transform 20 | self.split = split 21 | self.resize = tf.Resize(256) 22 | 23 | def __len__(self): 24 | return len(self.imgs) 25 | 26 | def __getitem__(self, index): 27 | path, target = self.imgs[index] 28 | with open(path, 'rb') as f: 29 | img = 
Image.open(f).convert('RGB') 30 | im_size = img.size 31 | img = self.resize(img) 32 | 33 | if self.transform is not None: 34 | img = self.transform(img) 35 | 36 | out = {'image': img, 'target': target, 'meta': {'im_size': im_size, 'index': index}} 37 | 38 | return out 39 | 40 | def get_image(self, index): 41 | path, target = self.imgs[index] 42 | with open(path, 'rb') as f: 43 | img = Image.open(f).convert('RGB') 44 | img = self.resize(img) 45 | return img 46 | 47 | 48 | class ImageNetSubset(data.Dataset): 49 | def __init__(self, subset_file, root=MyPath.db_root_dir('imagenet'), split='train', 50 | transform=None): 51 | super(ImageNetSubset, self).__init__() 52 | 53 | self.root = os.path.join(root, 'ILSVRC2012_img_%s' %(split)) 54 | self.transform = transform 55 | self.split = split 56 | 57 | # Read the subset of classes to include (sorted) 58 | with open(subset_file, 'r') as f: 59 | result = f.read().splitlines() 60 | subdirs, class_names = [], [] 61 | for line in result: 62 | subdir, class_name = line.split(' ', 1) 63 | subdirs.append(subdir) 64 | class_names.append(class_name) 65 | 66 | # Gather the files (sorted) 67 | imgs = [] 68 | for i, subdir in enumerate(subdirs): 69 | subdir_path = os.path.join(self.root, subdir) 70 | files = sorted(glob(os.path.join(self.root, subdir, '*.JPEG'))) 71 | for f in files: 72 | imgs.append((f, i)) 73 | self.imgs = imgs 74 | self.classes = class_names 75 | 76 | # Resize 77 | self.resize = tf.Resize(256) 78 | 79 | def get_image(self, index): 80 | path, target = self.imgs[index] 81 | with open(path, 'rb') as f: 82 | img = Image.open(f).convert('RGB') 83 | img = self.resize(img) 84 | return img 85 | 86 | def __len__(self): 87 | return len(self.imgs) 88 | 89 | def __getitem__(self, index): 90 | path, target = self.imgs[index] 91 | with open(path, 'rb') as f: 92 | img = Image.open(f).convert('RGB') 93 | im_size = img.size 94 | img = self.resize(img) 95 | class_name = self.classes[target] 96 | 97 | if self.transform is not None: 98 | img = self.transform(img) 99 | 100 | out = {'image': img, 'target': target, 'meta': {'im_size': im_size, 'index': index, 'class_name': class_name}} 101 | 102 | return out 103 | -------------------------------------------------------------------------------- /data/imagenet_subsets/imagenet_100.txt: -------------------------------------------------------------------------------- 1 | n01558993 robin, American robin, Turdus migratorius 2 | n01601694 water ouzel, dipper 3 | n01669191 box turtle, box tortoise 4 | n01751748 sea snake 5 | n01755581 diamondback, diamondback rattlesnake, Crotalus adamanteus 6 | n01756291 sidewinder, horned rattlesnake, Crotalus cerastes 7 | n01770393 scorpion 8 | n01855672 goose 9 | n01871265 tusker 10 | n02018207 American coot, marsh hen, mud hen, water hen, Fulica americana 11 | n02037110 oystercatcher, oyster catcher 12 | n02058221 albatross, mollymawk 13 | n02087046 toy terrier 14 | n02088632 bluetick 15 | n02093256 Staffordshire bullterrier, Staffordshire bull terrier 16 | n02093754 Border terrier 17 | n02094114 Norfolk terrier 18 | n02096177 cairn, cairn terrier 19 | n02097130 giant schnauzer 20 | n02097298 Scotch terrier, Scottish terrier, Scottie 21 | n02099267 flat-coated retriever 22 | n02100877 Irish setter, red setter 23 | n02104365 schipperke 24 | n02105855 Shetland sheepdog, Shetland sheep dog, Shetland 25 | n02106030 collie 26 | n02106166 Border collie 27 | n02107142 Doberman, Doberman pinscher 28 | n02110341 dalmatian, coach dog, carriage dog 29 | n02114855 coyote, prairie wolf, brush 
wolf, Canis latrans 30 | n02120079 Arctic fox, white fox, Alopex lagopus 31 | n02120505 grey fox, gray fox, Urocyon cinereoargenteus 32 | n02125311 cougar, puma, catamount, mountain lion, painter, panther, Felis concolor 33 | n02128385 leopard, Panthera pardus 34 | n02133161 American black bear, black bear, Ursus americanus, Euarctos americanus 35 | n02277742 ringlet, ringlet butterfly 36 | n02325366 wood rabbit, cottontail, cottontail rabbit 37 | n02364673 guinea pig, Cavia cobaya 38 | n02484975 guenon, guenon monkey 39 | n02489166 proboscis monkey, Nasalis larvatus 40 | n02708093 analog clock 41 | n02747177 ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin 42 | n02835271 bicycle-built-for-two, tandem bicycle, tandem 43 | n02906734 broom 44 | n02909870 bucket, pail 45 | n03085013 computer keyboard, keypad 46 | n03124170 cowboy hat, ten-gallon hat 47 | n03127747 crash helmet 48 | n03160309 dam, dike, dyke 49 | n03255030 dumbbell 50 | n03272010 electric guitar 51 | n03291819 envelope 52 | n03337140 file, file cabinet, filing cabinet 53 | n03450230 gown 54 | n03483316 hand blower, blow dryer, blow drier, hair dryer, hair drier 55 | n03498962 hatchet 56 | n03530642 honeycomb 57 | n03623198 knee pad 58 | n03649909 lawn mower, mower 59 | n03710721 maillot, tank suit 60 | n03717622 manhole cover 61 | n03733281 maze, labyrinth 62 | n03759954 microphone, mike 63 | n03775071 mitten 64 | n03814639 neck brace 65 | n03837869 obelisk 66 | n03838899 oboe, hautboy, hautbois 67 | n03854065 organ, pipe organ 68 | n03929855 pickelhaube 69 | n03930313 picket fence, paling 70 | n03954731 plane, carpenter's plane, woodworking plane 71 | n03956157 planetarium 72 | n03983396 pop bottle, soda bottle 73 | n04004767 printer 74 | n04026417 purse 75 | n04065272 recreational vehicle, RV, R.V. 76 | n04200800 shoe shop, shoe-shop, shoe store 77 | n04209239 shower curtain 78 | n04235860 sleeping bag 79 | n04311004 steel arch bridge 80 | n04325704 stole 81 | n04336792 stretcher 82 | n04346328 stupa, tope 83 | n04380533 table lamp 84 | n04428191 thresher, thrasher, threshing machine 85 | n04443257 tobacco shop, tobacconist shop, tobacconist 86 | n04458633 totem pole 87 | n04483307 trimaran 88 | n04509417 unicycle, monocycle 89 | n04515003 upright, upright piano 90 | n04525305 vending machine 91 | n04554684 washer, automatic washer, washing machine 92 | n04591157 Windsor tie 93 | n04592741 wing 94 | n04606251 wreck 95 | n07583066 guacamole 96 | n07613480 trifle 97 | n07693725 bagel, beigel 98 | n07711569 mashed potato 99 | n07753592 banana 100 | n11879895 rapeseed 101 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | """ This file contains a list of packages and their versions that were used to produce the results. 
""" 2 | - _libgcc_mutex=0.1=main 3 | - blas=1.0=mkl 4 | - bzip2=1.0.8=h7b6447c_0 5 | - ca-certificates=2020.1.1=0 6 | - cairo=1.14.12=h8948797_3 7 | - certifi=2020.4.5.1=py37_0 8 | - cffi=1.14.0=py37h2e261b9_0 9 | - cmake=3.14.0=h52cb24c_0 10 | - cudatoolkit=10.0.130=0 11 | - cycler=0.10.0=py37_0 12 | - dbus=1.13.12=h746ee38_0 13 | - easydict=1.9=py_0 14 | - expat=2.2.6=he6710b0_0 15 | - faiss-gpu=1.6.3=py37h1a5d453_0 16 | - ffmpeg=4.0=hcdf2ecd_0 17 | - fontconfig=2.13.0=h9420a91_0 18 | - freeglut=3.0.0=hf484d3e_5 19 | - freetype=2.9.1=h8a8886c_1 20 | - glib=2.63.1=h5a9c865_0 21 | - graphite2=1.3.13=h23475e2_0 22 | - gst-plugins-base=1.14.0=hbbd80ab_1 23 | - gstreamer=1.14.0=hb453b48_1 24 | - h5py=2.8.0=py37h989c5e5_3 25 | - harfbuzz=1.8.8=hffaf4a1_0 26 | - hdf5=1.10.2=hba1933b_1 27 | - icu=58.2=h9c2bf20_1 28 | - imageio=2.8.0=py_0 29 | - intel-openmp=2020.0=166 30 | - jasper=2.0.14=h07fcdf6_1 31 | - joblib=0.14.1=py_0 32 | - jpeg=9b=h024ee3a_2 33 | - kiwisolver=1.1.0=py37he6710b0_0 34 | - krb5=1.17.1=h173b8e3_0 35 | - ld_impl_linux-64=2.33.1=h53a641e_7 36 | - libcurl=7.69.1=h20c2e04_0 37 | - libedit=3.1.20181209=hc058e9b_0 38 | - libffi=3.2.1=hd88cf55_4 39 | - libgcc-ng=9.1.0=hdf63c60_0 40 | - libgfortran-ng=7.3.0=hdf63c60_0 41 | - libglu=9.0.0=hf484d3e_1 42 | - libopencv=3.4.2=hb342d67_1 43 | - libopus=1.3.1=h7b6447c_0 44 | - libpng=1.6.37=hbc83047_0 45 | - libprotobuf=3.11.4=hd408876_0 46 | - libssh2=1.9.0=h1ba5d50_1 47 | - libstdcxx-ng=9.1.0=hdf63c60_0 48 | - libtiff=4.1.0=h2733197_0 49 | - libuuid=1.0.3=h1bed415_2 50 | - libvpx=1.7.0=h439df22_0 51 | - libxcb=1.13=h1bed415_1 52 | - libxml2=2.9.9=hea5a465_1 53 | - matplotlib=3.1.3=py37_0 54 | - matplotlib-base=3.1.3=py37hef1b27d_0 55 | - mkl=2020.0=166 56 | - mkl-service=2.3.0=py37he904b0f_0 57 | - mkl_fft=1.0.15=py37ha843d7b_0 58 | - mkl_random=1.1.0=py37hd6b4f25_0 59 | - ncurses=6.2=he6710b0_0 60 | - ninja=1.9.0=py37hfd86e86_0 61 | - numpy=1.18.1=py37h4f9e942_0 62 | - numpy-base=1.18.1=py37hde5b4d6_1 63 | - olefile=0.46=py_0 64 | - opencv=3.4.2=py37h6fd60c2_1 65 | - openssl=1.1.1g=h7b6447c_0 66 | - pcre=8.43=he6710b0_0 67 | - pillow=7.0.0=py37hb39fc2d_0 68 | - pip=20.0.2=py37_1 69 | - pixman=0.38.0=h7b6447c_0 70 | - protobuf=3.11.4=py37he6710b0_0 71 | - py-opencv=3.4.2=py37hb342d67_1 72 | - pycparser=2.20=py_0 73 | - pyparsing=2.4.6=py_0 74 | - pyqt=5.9.2=py37h05f1152_2 75 | - python=3.7.7=hcf32534_0_cpython 76 | - python-dateutil=2.8.1=py_0 77 | - pytorch=1.4.0=py3.7_cuda10.0.130_cudnn7.6.3_0 78 | - pyyaml=5.3.1=py37h7b6447c_0 79 | - qt=5.9.7=h5867ecd_1 80 | - readline=8.0=h7b6447c_0 81 | - rhash=1.3.8=h1ba5d50_0 82 | - scikit-learn=0.22.1=py37hd81dba3_0 83 | - scipy=1.4.1=py37h0b6359f_0 84 | - setuptools=46.1.3=py37_0 85 | - sip=4.19.8=py37hf484d3e_0 86 | - six=1.14.0=py37_0 87 | - sqlite=3.31.1=h7b6447c_0 88 | - swig=3.0.12=h38cdd7d_3 89 | - tensorboardx=2.0=py_0 90 | - termcolor=1.1.0=py37_1 91 | - tk=8.6.8=hbc83047_0 92 | - torchvision=0.5.0=py37_cu100 93 | - tornado=6.0.4=py37h7b6447c_1 94 | - typing=3.6.4=py37_0 95 | - wheel=0.34.2=py37_0 96 | - xz=5.2.4=h14c3975_4 97 | - yaml=0.1.7=had09818_2 98 | - zlib=1.2.11=h7b6447c_3 99 | - zstd=1.3.7=h0b5b093_0 100 | - pip: 101 | - blis==0.4.1 102 | - catalogue==1.0.0 103 | - chardet==3.0.4 104 | - cymem==2.0.3 105 | - en-core-web-sm==2.2.5 106 | - idna==2.9 107 | - importlib-metadata==1.6.0 108 | - murmurhash==1.0.2 109 | - plac==1.1.3 110 | - preshed==3.0.2 111 | - requests==2.23.0 112 | - spacy==2.2.4 113 | - srsly==1.0.2 114 | - thinc==7.4.0 115 | - tqdm==4.45.0 116 | - 
urllib3==1.25.8 117 | - wasabi==0.6.0 118 | - zipp==3.1.0 119 | -------------------------------------------------------------------------------- /tutorial_nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import argparse 6 | import os 7 | import numpy as np 8 | import torch 9 | 10 | from utils.config import create_config 11 | from utils.common_config import get_model, get_train_dataset, \ 12 | get_val_dataset, \ 13 | get_val_dataloader, \ 14 | get_val_transformations \ 15 | 16 | from utils.memory import MemoryBank 17 | from utils.train_utils import simclr_train 18 | from utils.utils import fill_memory_bank 19 | from termcolor import colored 20 | 21 | # Parser 22 | parser = argparse.ArgumentParser(description='Eval_nn') 23 | parser.add_argument('--config_env', 24 | help='Config file for the environment') 25 | parser.add_argument('--config_exp', 26 | help='Config file for the experiment') 27 | args = parser.parse_args() 28 | 29 | def main(): 30 | 31 | # Retrieve config file 32 | p = create_config(args.config_env, args.config_exp) 33 | print(colored(p, 'red')) 34 | 35 | # Model 36 | print(colored('Retrieve model', 'blue')) 37 | model = get_model(p) 38 | print('Model is {}'.format(model.__class__.__name__)) 39 | print('Model parameters: {:.2f}M'.format(sum(p.numel() for p in model.parameters()) / 1e6)) 40 | print(model) 41 | model = model.cuda() 42 | 43 | # CUDNN 44 | print(colored('Set CuDNN benchmark', 'blue')) 45 | torch.backends.cudnn.benchmark = True 46 | 47 | # Dataset 48 | val_transforms = get_val_transformations(p) 49 | print('Validation transforms:', val_transforms) 50 | val_dataset = get_val_dataset(p, val_transforms) 51 | val_dataloader = get_val_dataloader(p, val_dataset) 52 | print('Dataset contains {} val samples'.format(len(val_dataset))) 53 | 54 | # Memory Bank 55 | print(colored('Build MemoryBank', 'blue')) 56 | base_dataset = get_train_dataset(p, val_transforms, split='train') # Dataset w/o augs for knn eval 57 | base_dataloader = get_val_dataloader(p, base_dataset) 58 | memory_bank_base = MemoryBank(len(base_dataset), 59 | p['model_kwargs']['features_dim'], 60 | p['num_classes'], p['criterion_kwargs']['temperature']) 61 | memory_bank_base.cuda() 62 | memory_bank_val = MemoryBank(len(val_dataset), 63 | p['model_kwargs']['features_dim'], 64 | p['num_classes'], p['criterion_kwargs']['temperature']) 65 | memory_bank_val.cuda() 66 | 67 | # Checkpoint 68 | assert os.path.exists(p['pretext_checkpoint']) 69 | print(colored('Restart from checkpoint {}'.format(p['pretext_checkpoint']), 'blue')) 70 | checkpoint = torch.load(p['pretext_checkpoint'], map_location='cpu') 71 | model.load_state_dict(checkpoint) 72 | model.cuda() 73 | 74 | # Save model 75 | torch.save(model.state_dict(), p['pretext_model']) 76 | 77 | # Mine the topk nearest neighbors at the very end (Train) 78 | # These will be served as input to the SCAN loss. 
79 | print(colored('Fill memory bank for mining the nearest neighbors (train) ...', 'blue')) 80 | fill_memory_bank(base_dataloader, model, memory_bank_base) 81 | topk = 20 82 | print('Mine the nearest neighbors (Top-%d)' %(topk)) 83 | indices, acc = memory_bank_base.mine_nearest_neighbors(topk) 84 | print('Accuracy of top-%d nearest neighbors on train set is %.2f' %(topk, 100*acc)) 85 | np.save(p['topk_neighbors_train_path'], indices) 86 | 87 | # Mine the topk nearest neighbors at the very end (Val) 88 | # These will be used for validation. 89 | print(colored('Fill memory bank for mining the nearest neighbors (val) ...', 'blue')) 90 | fill_memory_bank(val_dataloader, model, memory_bank_val) 91 | topk = 5 92 | print('Mine the nearest neighbors (Top-%d)' %(topk)) 93 | indices, acc = memory_bank_val.mine_nearest_neighbors(topk) 94 | print('Accuracy of top-%d nearest neighbors on val set is %.2f' %(topk, 100*acc)) 95 | np.save(p['topk_neighbors_val_path'], indices) 96 | 97 | 98 | if __name__ == '__main__': 99 | main() 100 | -------------------------------------------------------------------------------- /data/augment.py: -------------------------------------------------------------------------------- 1 | # List of augmentations based on randaugment 2 | import random 3 | 4 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw 5 | import numpy as np 6 | import torch 7 | from torchvision.transforms.transforms import Compose 8 | 9 | random_mirror = True 10 | 11 | def ShearX(img, v): 12 | if random_mirror and random.random() > 0.5: 13 | v = -v 14 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 15 | 16 | def ShearY(img, v): 17 | if random_mirror and random.random() > 0.5: 18 | v = -v 19 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 20 | 21 | def Identity(img, v): 22 | return img 23 | 24 | def TranslateX(img, v): 25 | if random_mirror and random.random() > 0.5: 26 | v = -v 27 | v = v * img.size[0] 28 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 29 | 30 | def TranslateY(img, v): 31 | if random_mirror and random.random() > 0.5: 32 | v = -v 33 | v = v * img.size[1] 34 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 35 | 36 | def TranslateXAbs(img, v): 37 | if random.random() > 0.5: 38 | v = -v 39 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 40 | 41 | def TranslateYAbs(img, v): 42 | if random.random() > 0.5: 43 | v = -v 44 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 45 | 46 | def Rotate(img, v): 47 | if random_mirror and random.random() > 0.5: 48 | v = -v 49 | return img.rotate(v) 50 | 51 | def AutoContrast(img, _): 52 | return PIL.ImageOps.autocontrast(img) 53 | 54 | def Invert(img, _): 55 | return PIL.ImageOps.invert(img) 56 | 57 | def Equalize(img, _): 58 | return PIL.ImageOps.equalize(img) 59 | 60 | def Solarize(img, v): 61 | return PIL.ImageOps.solarize(img, v) 62 | 63 | def Posterize(img, v): 64 | v = int(v) 65 | return PIL.ImageOps.posterize(img, v) 66 | 67 | def Contrast(img, v): 68 | return PIL.ImageEnhance.Contrast(img).enhance(v) 69 | 70 | def Color(img, v): 71 | return PIL.ImageEnhance.Color(img).enhance(v) 72 | 73 | def Brightness(img, v): 74 | return PIL.ImageEnhance.Brightness(img).enhance(v) 75 | 76 | def Sharpness(img, v): 77 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 78 | 79 | def augment_list(): 80 | l = [ 81 | (Identity, 0, 1), 82 | (AutoContrast, 0, 1), 83 | (Equalize, 0, 1), 84 | (Rotate, -30, 30), 85 | (Solarize, 0, 
256), 86 | (Color, 0.05, 0.95), 87 | (Contrast, 0.05, 0.95), 88 | (Brightness, 0.05, 0.95), 89 | (Sharpness, 0.05, 0.95), 90 | (ShearX, -0.1, 0.1), 91 | (TranslateX, -0.1, 0.1), 92 | (TranslateY, -0.1, 0.1), 93 | (Posterize, 4, 8), 94 | (ShearY, -0.1, 0.1), 95 | ] 96 | return l 97 | 98 | 99 | augment_dict = {fn.__name__: (fn, v1, v2) for fn, v1, v2 in augment_list()} 100 | 101 | class Augment: 102 | def __init__(self, n): 103 | self.n = n 104 | self.augment_list = augment_list() 105 | 106 | def __call__(self, img): 107 | ops = random.choices(self.augment_list, k=self.n) 108 | for op, minval, maxval in ops: 109 | val = (random.random()) * float(maxval - minval) + minval 110 | img = op(img, val) 111 | 112 | return img 113 | 114 | def get_augment(name): 115 | return augment_dict[name] 116 | 117 | def apply_augment(img, name, level): 118 | augment_fn, low, high = get_augment(name) 119 | return augment_fn(img.copy(), level * (high - low) + low) 120 | 121 | class Cutout(object): 122 | def __init__(self, n_holes, length, random=False): 123 | self.n_holes = n_holes 124 | self.length = length 125 | self.random = random 126 | 127 | def __call__(self, img): 128 | h = img.size(1) 129 | w = img.size(2) 130 | length = random.randint(1, self.length) 131 | mask = np.ones((h, w), np.float32) 132 | 133 | for n in range(self.n_holes): 134 | y = np.random.randint(h) 135 | x = np.random.randint(w) 136 | 137 | y1 = np.clip(y - length // 2, 0, h) 138 | y2 = np.clip(y + length // 2, 0, h) 139 | x1 = np.clip(x - length // 2, 0, w) 140 | x2 = np.clip(x + length // 2, 0, w) 141 | 142 | mask[y1: y2, x1: x2] = 0. 143 | 144 | mask = torch.from_numpy(mask) 145 | mask = mask.expand_as(img) 146 | img = img * mask 147 | 148 | return img 149 | -------------------------------------------------------------------------------- /moco.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import argparse 6 | import os 7 | import torch 8 | import numpy as np 9 | 10 | from utils.config import create_config 11 | from utils.common_config import get_model, get_train_dataset,\ 12 | get_val_dataset, get_val_dataloader, get_val_transformations 13 | from utils.memory import MemoryBank 14 | from utils.utils import fill_memory_bank 15 | from termcolor import colored 16 | 17 | # Parser 18 | parser = argparse.ArgumentParser(description='MoCo') 19 | parser.add_argument('--config_env', 20 | help='Config file for the environment') 21 | parser.add_argument('--config_exp', 22 | help='Config file for the experiment') 23 | args = parser.parse_args() 24 | 25 | def main(): 26 | # Retrieve config file 27 | p = create_config(args.config_env, args.config_exp) 28 | print(colored(p, 'red')) 29 | 30 | 31 | # Model 32 | print(colored('Retrieve model', 'blue')) 33 | model = get_model(p) 34 | print('Model is {}'.format(model.__class__.__name__)) 35 | print(model) 36 | model = torch.nn.DataParallel(model) 37 | model = model.cuda() 38 | 39 | 40 | # CUDNN 41 | print(colored('Set CuDNN benchmark', 'blue')) 42 | torch.backends.cudnn.benchmark = True 43 | 44 | 45 | # Dataset 46 | print(colored('Retrieve dataset', 'blue')) 47 | transforms = get_val_transformations(p) 48 | train_dataset = get_train_dataset(p, transforms) 49 | val_dataset = get_val_dataset(p, transforms) 50 | train_dataloader = get_val_dataloader(p, train_dataset) 51 | val_dataloader = get_val_dataloader(p, 
val_dataset) 52 | print('Dataset contains {}/{} train/val samples'.format(len(train_dataset), len(val_dataset))) 53 | 54 | 55 | # Memory Bank 56 | print(colored('Build MemoryBank', 'blue')) 57 | memory_bank_train = MemoryBank(len(train_dataset), 2048, p['num_classes'], p['temperature']) 58 | memory_bank_train.cuda() 59 | memory_bank_val = MemoryBank(len(val_dataset), 2048, p['num_classes'], p['temperature']) 60 | memory_bank_val.cuda() 61 | 62 | 63 | # Load the official MoCoV2 checkpoint 64 | print(colored('Downloading moco v2 checkpoint', 'blue')) 65 | os.system('wget -L https://dl.fbaipublicfiles.com/moco/moco_checkpoints/moco_v2_800ep/moco_v2_800ep_pretrain.pth.tar') 66 | moco_state = torch.load('moco_v2_800ep_pretrain.pth.tar', map_location='cpu') 67 | 68 | 69 | # Transfer moco weights 70 | print(colored('Transfer MoCo weights to model', 'blue')) 71 | new_state_dict = {} 72 | state_dict = moco_state['state_dict'] 73 | for k in list(state_dict.keys()): 74 | # Copy backbone weights 75 | if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'): 76 | new_k = 'module.backbone.' + k[len('module.encoder_q.'):] 77 | new_state_dict[new_k] = state_dict[k] 78 | 79 | # Copy mlp weights 80 | elif k.startswith('module.encoder_q.fc'): 81 | new_k = 'module.contrastive_head.' + k[len('module.encoder_q.fc.'):] 82 | new_state_dict[new_k] = state_dict[k] 83 | 84 | else: 85 | raise ValueError('Unexpected key {}'.format(k)) 86 | 87 | model.load_state_dict(new_state_dict) 88 | os.system('rm -rf moco_v2_800ep_pretrain.pth.tar') 89 | 90 | 91 | # Save final model 92 | print(colored('Save pretext model', 'blue')) 93 | torch.save(model.module.state_dict(), p['pretext_model']) 94 | model.module.contrastive_head = torch.nn.Identity() # In this case, we mine the neighbors before the MLP. 95 | 96 | 97 | # Mine the topk nearest neighbors (Train) 98 | # These will be used for training with the SCAN-Loss. 99 | topk = 50 100 | print(colored('Mine the nearest neighbors (Train)(Top-%d)' %(topk), 'blue')) 101 | transforms = get_val_transformations(p) 102 | train_dataset = get_train_dataset(p, transforms) 103 | fill_memory_bank(train_dataloader, model, memory_bank_train) 104 | indices, acc = memory_bank_train.mine_nearest_neighbors(topk) 105 | print('Accuracy of top-%d nearest neighbors on train set is %.2f' %(topk, 100*acc)) 106 | np.save(p['topk_neighbors_train_path'], indices) 107 | 108 | 109 | # Mine the topk nearest neighbors (Validation) 110 | # These will be used for validation. 
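    # Note: a much larger neighborhood (topk = 50 above) is mined on the train set because
    # those neighbors feed the SCAN loss, whereas only 5 neighbors are mined on the
    # validation set, where they serve purely to monitor neighbor quality.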
111 | topk = 5 112 | print(colored('Mine the nearest neighbors (Val)(Top-%d)' %(topk), 'blue')) 113 | fill_memory_bank(val_dataloader, model, memory_bank_val) 114 | print('Mine the neighbors') 115 | indices, acc = memory_bank_val.mine_nearest_neighbors(topk) 116 | print('Accuracy of top-%d nearest neighbors on val set is %.2f' %(topk, 100*acc)) 117 | np.save(p['topk_neighbors_val_path'], indices) 118 | 119 | 120 | if __name__ == '__main__': 121 | main() 122 | -------------------------------------------------------------------------------- /selflabel.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import argparse 6 | import os 7 | import torch 8 | 9 | from utils.config import create_config 10 | from utils.common_config import get_train_dataset, get_train_transformations,\ 11 | get_val_dataset, get_val_transformations,\ 12 | get_train_dataloader, get_val_dataloader,\ 13 | get_optimizer, get_model, adjust_learning_rate,\ 14 | get_criterion 15 | from utils.ema import EMA 16 | from utils.evaluate_utils import get_predictions, hungarian_evaluate 17 | from utils.train_utils import selflabel_train 18 | from termcolor import colored 19 | 20 | # Parser 21 | parser = argparse.ArgumentParser(description='Self-labeling') 22 | parser.add_argument('--config_env', 23 | help='Config file for the environment') 24 | parser.add_argument('--config_exp', 25 | help='Config file for the experiment') 26 | args = parser.parse_args() 27 | 28 | def main(): 29 | # Retrieve config file 30 | p = create_config(args.config_env, args.config_exp) 31 | print(colored(p, 'red')) 32 | 33 | # Get model 34 | print(colored('Retrieve model', 'blue')) 35 | model = get_model(p, p['scan_model']) 36 | print(model) 37 | model = torch.nn.DataParallel(model) 38 | model = model.cuda() 39 | 40 | # Get criterion 41 | print(colored('Get loss', 'blue')) 42 | criterion = get_criterion(p) 43 | criterion.cuda() 44 | print(criterion) 45 | 46 | # CUDNN 47 | print(colored('Set CuDNN benchmark', 'blue')) 48 | torch.backends.cudnn.benchmark = True 49 | 50 | # Optimizer 51 | print(colored('Retrieve optimizer', 'blue')) 52 | optimizer = get_optimizer(p, model) 53 | print(optimizer) 54 | 55 | # Dataset 56 | print(colored('Retrieve dataset', 'blue')) 57 | 58 | # Transforms 59 | strong_transforms = get_train_transformations(p) 60 | val_transforms = get_val_transformations(p) 61 | train_dataset = get_train_dataset(p, {'standard': val_transforms, 'augment': strong_transforms}, 62 | split='train', to_augmented_dataset=True) 63 | train_dataloader = get_train_dataloader(p, train_dataset) 64 | val_dataset = get_val_dataset(p, val_transforms) 65 | val_dataloader = get_val_dataloader(p, val_dataset) 66 | print(colored('Train samples %d - Val samples %d' %(len(train_dataset), len(val_dataset)), 'yellow')) 67 | 68 | # Checkpoint 69 | if os.path.exists(p['selflabel_checkpoint']): 70 | print(colored('Restart from checkpoint {}'.format(p['selflabel_checkpoint']), 'blue')) 71 | checkpoint = torch.load(p['selflabel_checkpoint'], map_location='cpu') 72 | model.load_state_dict(checkpoint['model']) 73 | optimizer.load_state_dict(checkpoint['optimizer']) 74 | start_epoch = checkpoint['epoch'] 75 | 76 | else: 77 | print(colored('No checkpoint file at {}'.format(p['selflabel_checkpoint']), 'blue')) 78 | start_epoch = 0 79 | 80 | # EMA 81 | if p['use_ema']: 82 | ema = EMA(model, 
alpha=p['ema_alpha']) 83 | else: 84 | ema = None 85 | 86 | # Main loop 87 | print(colored('Starting main loop', 'blue')) 88 | 89 | for epoch in range(start_epoch, p['epochs']): 90 | print(colored('Epoch %d/%d' %(epoch+1, p['epochs']), 'yellow')) 91 | print(colored('-'*10, 'yellow')) 92 | 93 | # Adjust lr 94 | lr = adjust_learning_rate(p, optimizer, epoch) 95 | print('Adjusted learning rate to {:.5f}'.format(lr)) 96 | 97 | # Perform self-labeling 98 | print('Train ...') 99 | selflabel_train(train_dataloader, model, criterion, optimizer, epoch, ema=ema) 100 | 101 | # Evaluate (To monitor progress - Not for validation) 102 | print('Evaluate ...') 103 | predictions = get_predictions(p, val_dataloader, model) 104 | clustering_stats = hungarian_evaluate(0, predictions, compute_confusion_matrix=False) 105 | print(clustering_stats) 106 | 107 | # Checkpoint 108 | print('Checkpoint ...') 109 | torch.save({'optimizer': optimizer.state_dict(), 'model': model.state_dict(), 110 | 'epoch': epoch + 1}, p['selflabel_checkpoint']) 111 | torch.save(model.module.state_dict(), p['selflabel_model']) 112 | 113 | # Evaluate and save the final model 114 | print(colored('Evaluate model at the end', 'blue')) 115 | predictions = get_predictions(p, val_dataloader, model) 116 | clustering_stats = hungarian_evaluate(0, predictions, 117 | class_names=val_dataset.classes, 118 | compute_confusion_matrix=True, 119 | confusion_matrix_file=os.path.join(p['selflabel_dir'], 'confusion_matrix.png')) 120 | print(clustering_stats) 121 | torch.save(model.module.state_dict(), p['selflabel_model']) 122 | 123 | 124 | if __name__ == "__main__": 125 | main() 126 | -------------------------------------------------------------------------------- /utils/train_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import torch 6 | import numpy as np 7 | from utils.utils import AverageMeter, ProgressMeter 8 | 9 | 10 | def simclr_train(train_loader, model, criterion, optimizer, epoch): 11 | """ 12 | Train according to the scheme from SimCLR 13 | https://arxiv.org/abs/2002.05709 14 | """ 15 | losses = AverageMeter('Loss', ':.4e') 16 | progress = ProgressMeter(len(train_loader), 17 | [losses], 18 | prefix="Epoch: [{}]".format(epoch)) 19 | 20 | model.train() 21 | 22 | for i, batch in enumerate(train_loader): 23 | images = batch['image'] 24 | images_augmented = batch['image_augmented'] 25 | b, c, h, w = images.size() 26 | input_ = torch.cat([images.unsqueeze(1), images_augmented.unsqueeze(1)], dim=1) 27 | input_ = input_.view(-1, c, h, w) 28 | input_ = input_.cuda(non_blocking=True) 29 | targets = batch['target'].cuda(non_blocking=True) 30 | 31 | output = model(input_).view(b, 2, -1) 32 | loss = criterion(output) 33 | losses.update(loss.item()) 34 | 35 | optimizer.zero_grad() 36 | loss.backward() 37 | optimizer.step() 38 | 39 | if i % 25 == 0: 40 | progress.display(i) 41 | 42 | 43 | def scan_train(train_loader, model, criterion, optimizer, epoch, update_cluster_head_only=False): 44 | """ 45 | Train w/ SCAN-Loss 46 | """ 47 | total_losses = AverageMeter('Total Loss', ':.4e') 48 | consistency_losses = AverageMeter('Consistency Loss', ':.4e') 49 | entropy_losses = AverageMeter('Entropy', ':.4e') 50 | progress = ProgressMeter(len(train_loader), 51 | [total_losses, consistency_losses, entropy_losses], 52 | prefix="Epoch: [{}]".format(epoch)) 53 | 54 | 
if update_cluster_head_only: 55 | model.eval() # No need to update BN 56 | else: 57 | model.train() # Update BN 58 | 59 | for i, batch in enumerate(train_loader): 60 | # Forward pass 61 | anchors = batch['anchor'].cuda(non_blocking=True) 62 | neighbors = batch['neighbor'].cuda(non_blocking=True) 63 | 64 | if update_cluster_head_only: # Only calculate gradient for backprop of linear layer 65 | with torch.no_grad(): 66 | anchors_features = model(anchors, forward_pass='backbone') 67 | neighbors_features = model(neighbors, forward_pass='backbone') 68 | anchors_output = model(anchors_features, forward_pass='head') 69 | neighbors_output = model(neighbors_features, forward_pass='head') 70 | 71 | else: # Calculate gradient for backprop of complete network 72 | anchors_output = model(anchors) 73 | neighbors_output = model(neighbors) 74 | 75 | # Loss for every head 76 | total_loss, consistency_loss, entropy_loss = [], [], [] 77 | for anchors_output_subhead, neighbors_output_subhead in zip(anchors_output, neighbors_output): 78 | total_loss_, consistency_loss_, entropy_loss_ = criterion(anchors_output_subhead, 79 | neighbors_output_subhead) 80 | total_loss.append(total_loss_) 81 | consistency_loss.append(consistency_loss_) 82 | entropy_loss.append(entropy_loss_) 83 | 84 | # Register the mean loss and backprop the total loss to cover all subheads 85 | total_losses.update(np.mean([v.item() for v in total_loss])) 86 | consistency_losses.update(np.mean([v.item() for v in consistency_loss])) 87 | entropy_losses.update(np.mean([v.item() for v in entropy_loss])) 88 | 89 | total_loss = torch.sum(torch.stack(total_loss, dim=0)) 90 | 91 | optimizer.zero_grad() 92 | total_loss.backward() 93 | optimizer.step() 94 | 95 | if i % 25 == 0: 96 | progress.display(i) 97 | 98 | 99 | def selflabel_train(train_loader, model, criterion, optimizer, epoch, ema=None): 100 | """ 101 | Self-labeling based on confident samples 102 | """ 103 | losses = AverageMeter('Loss', ':.4e') 104 | progress = ProgressMeter(len(train_loader), [losses], 105 | prefix="Epoch: [{}]".format(epoch)) 106 | model.train() 107 | 108 | for i, batch in enumerate(train_loader): 109 | images = batch['image'].cuda(non_blocking=True) 110 | images_augmented = batch['image_augmented'].cuda(non_blocking=True) 111 | 112 | with torch.no_grad(): 113 | output = model(images)[0] 114 | output_augmented = model(images_augmented)[0] 115 | 116 | loss = criterion(output, output_augmented) 117 | losses.update(loss.item()) 118 | 119 | optimizer.zero_grad() 120 | loss.backward() 121 | optimizer.step() 122 | 123 | if ema is not None: # Apply EMA to update the weights of the network 124 | ema.update_params(model) 125 | ema.apply_shadow(model) 126 | 127 | if i % 25 == 0: 128 | progress.display(i) 129 | -------------------------------------------------------------------------------- /models/resnet_cifar.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is based on the Torchvision repository, which was licensed under the BSD 3-Clause. 
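The stem is adapted to 32x32 CIFAR inputs: a single 3x3, stride-1 convolution without
max-pooling replaces the 7x7, stride-2 stem of the standard torchvision ResNet.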
3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class BasicBlock(nn.Module): 10 | expansion = 1 11 | 12 | def __init__(self, in_planes, planes, stride=1, is_last=False): 13 | super(BasicBlock, self).__init__() 14 | self.is_last = is_last 15 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 16 | self.bn1 = nn.BatchNorm2d(planes) 17 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 18 | self.bn2 = nn.BatchNorm2d(planes) 19 | 20 | self.shortcut = nn.Sequential() 21 | if stride != 1 or in_planes != self.expansion * planes: 22 | self.shortcut = nn.Sequential( 23 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 24 | nn.BatchNorm2d(self.expansion * planes) 25 | ) 26 | 27 | def forward(self, x): 28 | out = F.relu(self.bn1(self.conv1(x))) 29 | out = self.bn2(self.conv2(out)) 30 | out += self.shortcut(x) 31 | preact = out 32 | out = F.relu(out) 33 | if self.is_last: 34 | return out, preact 35 | else: 36 | return out 37 | 38 | 39 | class Bottleneck(nn.Module): 40 | expansion = 4 41 | 42 | def __init__(self, in_planes, planes, stride=1, is_last=False): 43 | super(Bottleneck, self).__init__() 44 | self.is_last = is_last 45 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 46 | self.bn1 = nn.BatchNorm2d(planes) 47 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 48 | self.bn2 = nn.BatchNorm2d(planes) 49 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 50 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 51 | 52 | self.shortcut = nn.Sequential() 53 | if stride != 1 or in_planes != self.expansion * planes: 54 | self.shortcut = nn.Sequential( 55 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 56 | nn.BatchNorm2d(self.expansion * planes) 57 | ) 58 | 59 | def forward(self, x): 60 | out = F.relu(self.bn1(self.conv1(x))) 61 | out = F.relu(self.bn2(self.conv2(out))) 62 | out = self.bn3(self.conv3(out)) 63 | out += self.shortcut(x) 64 | preact = out 65 | out = F.relu(out) 66 | if self.is_last: 67 | return out, preact 68 | else: 69 | return out 70 | 71 | 72 | class ResNet(nn.Module): 73 | def __init__(self, block, num_blocks, in_channel=3, zero_init_residual=False): 74 | super(ResNet, self).__init__() 75 | self.in_planes = 64 76 | 77 | self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1, 78 | bias=False) 79 | self.bn1 = nn.BatchNorm2d(64) 80 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 81 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 82 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 83 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 84 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 85 | 86 | for m in self.modules(): 87 | if isinstance(m, nn.Conv2d): 88 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 89 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 90 | nn.init.constant_(m.weight, 1) 91 | nn.init.constant_(m.bias, 0) 92 | 93 | # Zero-initialize the last BN in each residual branch, 94 | # so that the residual branch starts with zeros, and each residual block behaves 95 | # like an identity. 
This improves the model by 0.2~0.3% according to: 96 | # https://arxiv.org/abs/1706.02677 97 | if zero_init_residual: 98 | for m in self.modules(): 99 | if isinstance(m, Bottleneck): 100 | nn.init.constant_(m.bn3.weight, 0) 101 | elif isinstance(m, BasicBlock): 102 | nn.init.constant_(m.bn2.weight, 0) 103 | 104 | def _make_layer(self, block, planes, num_blocks, stride): 105 | strides = [stride] + [1] * (num_blocks - 1) 106 | layers = [] 107 | for i in range(num_blocks): 108 | stride = strides[i] 109 | layers.append(block(self.in_planes, planes, stride)) 110 | self.in_planes = planes * block.expansion 111 | return nn.Sequential(*layers) 112 | 113 | def forward(self, x): 114 | out = F.relu(self.bn1(self.conv1(x))) 115 | out = self.layer1(out) 116 | out = self.layer2(out) 117 | out = self.layer3(out) 118 | out = self.layer4(out) 119 | out = self.avgpool(out) 120 | out = torch.flatten(out, 1) 121 | return out 122 | 123 | 124 | def resnet18(**kwargs): 125 | return {'backbone': ResNet(BasicBlock, [2, 2, 2, 2], **kwargs), 'dim': 512} 126 | -------------------------------------------------------------------------------- /models/resnet_stl.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is based on the Torchvision repository, which was licensed under the BSD 3-Clause. 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class BasicBlock(nn.Module): 10 | expansion = 1 11 | 12 | def __init__(self, in_planes, planes, stride=1, is_last=False): 13 | super(BasicBlock, self).__init__() 14 | self.is_last = is_last 15 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 16 | self.bn1 = nn.BatchNorm2d(planes) 17 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 18 | self.bn2 = nn.BatchNorm2d(planes) 19 | 20 | self.shortcut = nn.Sequential() 21 | if stride != 1 or in_planes != self.expansion * planes: 22 | self.shortcut = nn.Sequential( 23 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 24 | nn.BatchNorm2d(self.expansion * planes) 25 | ) 26 | 27 | def forward(self, x): 28 | out = F.relu(self.bn1(self.conv1(x))) 29 | out = self.bn2(self.conv2(out)) 30 | out += self.shortcut(x) 31 | preact = out 32 | out = F.relu(out) 33 | if self.is_last: 34 | return out, preact 35 | else: 36 | return out 37 | 38 | 39 | class Bottleneck(nn.Module): 40 | expansion = 4 41 | 42 | def __init__(self, in_planes, planes, stride=1, is_last=False): 43 | super(Bottleneck, self).__init__() 44 | self.is_last = is_last 45 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 46 | self.bn1 = nn.BatchNorm2d(planes) 47 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 48 | self.bn2 = nn.BatchNorm2d(planes) 49 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 50 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 51 | 52 | self.shortcut = nn.Sequential() 53 | if stride != 1 or in_planes != self.expansion * planes: 54 | self.shortcut = nn.Sequential( 55 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 56 | nn.BatchNorm2d(self.expansion * planes) 57 | ) 58 | 59 | def forward(self, x): 60 | out = F.relu(self.bn1(self.conv1(x))) 61 | out = F.relu(self.bn2(self.conv2(out))) 62 | out = self.bn3(self.conv3(out)) 63 | out += self.shortcut(x) 64 | preact = out 65 | out = F.relu(out) 66 | 
if self.is_last: 67 | return out, preact 68 | else: 69 | return out 70 | 71 | 72 | class ResNet(nn.Module): 73 | def __init__(self, block, num_blocks, in_channel=3, zero_init_residual=False): 74 | super(ResNet, self).__init__() 75 | self.in_planes = 64 76 | 77 | self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1, 78 | bias=False) 79 | self.bn1 = nn.BatchNorm2d(64) 80 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, padding=1) 81 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 82 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 83 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 84 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 85 | self.avgpool = nn.AvgPool2d(7, stride=1) 86 | 87 | for m in self.modules(): 88 | if isinstance(m, nn.Conv2d): 89 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 90 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 91 | nn.init.constant_(m.weight, 1) 92 | nn.init.constant_(m.bias, 0) 93 | 94 | # Zero-initialize the last BN in each residual branch, 95 | # so that the residual branch starts with zeros, and each residual block behaves 96 | # like an identity. This improves the model by 0.2~0.3% according to: 97 | # https://arxiv.org/abs/1706.02677 98 | if zero_init_residual: 99 | for m in self.modules(): 100 | if isinstance(m, Bottleneck): 101 | nn.init.constant_(m.bn3.weight, 0) 102 | elif isinstance(m, BasicBlock): 103 | nn.init.constant_(m.bn2.weight, 0) 104 | 105 | def _make_layer(self, block, planes, num_blocks, stride): 106 | strides = [stride] + [1] * (num_blocks - 1) 107 | layers = [] 108 | for i in range(num_blocks): 109 | stride = strides[i] 110 | layers.append(block(self.in_planes, planes, stride)) 111 | self.in_planes = planes * block.expansion 112 | return nn.Sequential(*layers) 113 | 114 | def forward(self, x): 115 | out = self.maxpool(F.relu(self.bn1(self.conv1(x)))) 116 | out = self.layer1(out) 117 | out = self.layer2(out) 118 | out = self.layer3(out) 119 | out = self.layer4(out) 120 | out = self.avgpool(out) 121 | out = torch.flatten(out, 1) 122 | return out 123 | 124 | 125 | def resnet18(**kwargs): 126 | return {'backbone': ResNet(BasicBlock, [2, 2, 2, 2], **kwargs), 'dim': 512} 127 | -------------------------------------------------------------------------------- /losses/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | EPS=1e-8 9 | 10 | 11 | class MaskedCrossEntropyLoss(nn.Module): 12 | def __init__(self): 13 | super(MaskedCrossEntropyLoss, self).__init__() 14 | 15 | def forward(self, input, target, mask, weight, reduction='mean'): 16 | if not (mask != 0).any(): 17 | raise ValueError('Mask in MaskedCrossEntropyLoss is all zeros.') 18 | target = torch.masked_select(target, mask) 19 | b, c = input.size() 20 | n = target.size(0) 21 | input = torch.masked_select(input, mask.view(b, 1)).view(n, c) 22 | return F.cross_entropy(input, target, weight = weight, reduction = reduction) 23 | 24 | 25 | class ConfidenceBasedCE(nn.Module): 26 | def __init__(self, threshold, apply_class_balancing): 27 | super(ConfidenceBasedCE, self).__init__() 28 | self.loss = MaskedCrossEntropyLoss() 29 | self.softmax = nn.Softmax(dim = 1) 30 | 
self.threshold = threshold 31 | self.apply_class_balancing = apply_class_balancing 32 | 33 | def forward(self, anchors_weak, anchors_strong): 34 | """ 35 | Loss function during self-labeling 36 | 37 | input: logits for original samples and for its strong augmentations 38 | output: cross entropy 39 | """ 40 | # Retrieve target and mask based on weakly augmentated anchors 41 | weak_anchors_prob = self.softmax(anchors_weak) 42 | max_prob, target = torch.max(weak_anchors_prob, dim = 1) 43 | mask = max_prob > self.threshold 44 | b, c = weak_anchors_prob.size() 45 | target_masked = torch.masked_select(target, mask.squeeze()) 46 | n = target_masked.size(0) 47 | 48 | # Inputs are strongly augmented anchors 49 | input_ = anchors_strong 50 | 51 | # Class balancing weights 52 | if self.apply_class_balancing: 53 | idx, counts = torch.unique(target_masked, return_counts = True) 54 | freq = 1/(counts.float()/n) 55 | weight = torch.ones(c).cuda() 56 | weight[idx] = freq 57 | 58 | else: 59 | weight = None 60 | 61 | # Loss 62 | loss = self.loss(input_, target, mask, weight = weight, reduction='mean') 63 | 64 | return loss 65 | 66 | 67 | def entropy(x, input_as_probabilities): 68 | """ 69 | Helper function to compute the entropy over the batch 70 | 71 | input: batch w/ shape [b, num_classes] 72 | output: entropy value [is ideally -log(num_classes)] 73 | """ 74 | 75 | if input_as_probabilities: 76 | x_ = torch.clamp(x, min = EPS) 77 | b = x_ * torch.log(x_) 78 | else: 79 | b = F.softmax(x, dim = 1) * F.log_softmax(x, dim = 1) 80 | 81 | if len(b.size()) == 2: # Sample-wise entropy 82 | return -b.sum(dim = 1).mean() 83 | elif len(b.size()) == 1: # Distribution-wise entropy 84 | return - b.sum() 85 | else: 86 | raise ValueError('Input tensor is %d-Dimensional' %(len(b.size()))) 87 | 88 | 89 | class SCANLoss(nn.Module): 90 | def __init__(self, entropy_weight = 2.0): 91 | super(SCANLoss, self).__init__() 92 | self.softmax = nn.Softmax(dim = 1) 93 | self.bce = nn.BCELoss() 94 | self.entropy_weight = entropy_weight # Default = 2.0 95 | 96 | def forward(self, anchors, neighbors): 97 | """ 98 | input: 99 | - anchors: logits for anchor images w/ shape [b, num_classes] 100 | - neighbors: logits for neighbor images w/ shape [b, num_classes] 101 | 102 | output: 103 | - Loss 104 | """ 105 | # Softmax 106 | b, n = anchors.size() 107 | anchors_prob = self.softmax(anchors) 108 | positives_prob = self.softmax(neighbors) 109 | 110 | # Similarity in output space 111 | similarity = torch.bmm(anchors_prob.view(b, 1, n), positives_prob.view(b, n, 1)).squeeze() 112 | ones = torch.ones_like(similarity) 113 | consistency_loss = self.bce(similarity, ones) 114 | 115 | # Entropy loss 116 | entropy_loss = entropy(torch.mean(anchors_prob, 0), input_as_probabilities = True) 117 | 118 | # Total loss 119 | total_loss = consistency_loss - self.entropy_weight * entropy_loss 120 | 121 | return total_loss, consistency_loss, entropy_loss 122 | 123 | 124 | class SimCLRLoss(nn.Module): 125 | # Based on the implementation of SupContrast 126 | def __init__(self, temperature): 127 | super(SimCLRLoss, self).__init__() 128 | self.temperature = temperature 129 | 130 | 131 | def forward(self, features): 132 | """ 133 | input: 134 | - features: hidden feature representation of shape [b, 2, dim] 135 | 136 | output: 137 | - loss: loss computed according to SimCLR 138 | """ 139 | 140 | b, n, dim = features.size() 141 | assert(n == 2) 142 | mask = torch.eye(b, dtype=torch.float32).cuda() 143 | 144 | contrast_features = torch.cat(torch.unbind(features, dim=1), 
dim=0) 145 | anchor = features[:, 0] 146 | 147 | # Dot product 148 | dot_product = torch.matmul(anchor, contrast_features.T) / self.temperature 149 | 150 | # Log-sum trick for numerical stability 151 | logits_max, _ = torch.max(dot_product, dim=1, keepdim=True) 152 | logits = dot_product - logits_max.detach() 153 | 154 | mask = mask.repeat(1, 2) 155 | logits_mask = torch.scatter(torch.ones_like(mask), 1, torch.arange(b).view(-1, 1).cuda(), 0) 156 | mask = mask * logits_mask 157 | 158 | # Log-softmax 159 | exp_logits = torch.exp(logits) * logits_mask 160 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 161 | 162 | # Mean log-likelihood for positive 163 | loss = - ((mask * log_prob).sum(1) / mask.sum(1)).mean() 164 | 165 | return loss 166 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import argparse 6 | import torch 7 | import yaml 8 | from termcolor import colored 9 | from utils.common_config import get_val_dataset, get_val_transformations, get_val_dataloader,\ 10 | get_model 11 | from utils.evaluate_utils import get_predictions, hungarian_evaluate 12 | from utils.memory import MemoryBank 13 | from utils.utils import fill_memory_bank 14 | from PIL import Image 15 | 16 | FLAGS = argparse.ArgumentParser(description='Evaluate models from the model zoo') 17 | FLAGS.add_argument('--config_exp', help='Location of config file') 18 | FLAGS.add_argument('--model', help='Location where model is saved') 19 | FLAGS.add_argument('--visualize_prototypes', action='store_true', 20 | help='Show the prototpye for each cluster') 21 | args = FLAGS.parse_args() 22 | 23 | def main(): 24 | 25 | # Read config file 26 | print(colored('Read config file {} ...'.format(args.config_exp), 'blue')) 27 | with open(args.config_exp, 'r') as stream: 28 | config = yaml.safe_load(stream) 29 | config['batch_size'] = 512 # To make sure we can evaluate on a single 1080ti 30 | print(config) 31 | 32 | # Get dataset 33 | print(colored('Get validation dataset ...', 'blue')) 34 | transforms = get_val_transformations(config) 35 | dataset = get_val_dataset(config, transforms) 36 | dataloader = get_val_dataloader(config, dataset) 37 | print('Number of samples: {}'.format(len(dataset))) 38 | 39 | # Get model 40 | print(colored('Get model ...', 'blue')) 41 | model = get_model(config) 42 | print(model) 43 | 44 | # Read model weights 45 | print(colored('Load model weights ...', 'blue')) 46 | state_dict = torch.load(args.model, map_location='cpu') 47 | 48 | if config['setup'] in ['simclr', 'moco', 'selflabel']: 49 | model.load_state_dict(state_dict) 50 | 51 | elif config['setup'] == 'scan': 52 | model.load_state_dict(state_dict['model']) 53 | 54 | else: 55 | raise NotImplementedError 56 | 57 | # CUDA 58 | model.cuda() 59 | 60 | # Perform evaluation 61 | if config['setup'] in ['simclr', 'moco']: 62 | print(colored('Perform evaluation of the pretext task (setup={}).'.format(config['setup']), 'blue')) 63 | print('Create Memory Bank') 64 | if config['setup'] == 'simclr': # Mine neighbors after MLP 65 | memory_bank = MemoryBank(len(dataset), config['model_kwargs']['features_dim'], 66 | config['num_classes'], config['criterion_kwargs']['temperature']) 67 | 68 | else: # Mine neighbors before MLP 69 | memory_bank = MemoryBank(len(dataset), 
config['model_kwargs']['features_dim'], 70 | config['num_classes'], config['temperature']) 71 | memory_bank.cuda() 72 | 73 | print('Fill Memory Bank') 74 | fill_memory_bank(dataloader, model, memory_bank) 75 | 76 | print('Mine the nearest neighbors') 77 | for topk in [1, 5, 20]: # Similar to Fig 2 in paper 78 | _, acc = memory_bank.mine_nearest_neighbors(topk) 79 | print('Accuracy of top-{} nearest neighbors on validation set is {:.2f}'.format(topk, 100*acc)) 80 | 81 | 82 | elif config['setup'] in ['scan', 'selflabel']: 83 | print(colored('Perform evaluation of the clustering model (setup={}).'.format(config['setup']), 'blue')) 84 | head = state_dict['head'] if config['setup'] == 'scan' else 0 85 | predictions, features = get_predictions(config, dataloader, model, return_features=True) 86 | clustering_stats = hungarian_evaluate(head, predictions, dataset.classes, 87 | compute_confusion_matrix=True) 88 | print(clustering_stats) 89 | if args.visualize_prototypes: 90 | prototype_indices = get_prototypes(config, predictions[head], features, model) 91 | visualize_indices(prototype_indices, dataset, clustering_stats['hungarian_match']) 92 | else: 93 | raise NotImplementedError 94 | 95 | @torch.no_grad() 96 | def get_prototypes(config, predictions, features, model, topk=10): 97 | import torch.nn.functional as F 98 | 99 | # Get topk most certain indices and pred labels 100 | print('Get topk') 101 | probs = predictions['probabilities'] 102 | n_classes = probs.shape[1] 103 | dims = features.shape[1] 104 | max_probs, pred_labels = torch.max(probs, dim = 1) 105 | indices = torch.zeros((n_classes, topk)) 106 | for pred_id in range(n_classes): 107 | probs_copy = max_probs.clone() 108 | mask_out = ~(pred_labels == pred_id) 109 | probs_copy[mask_out] = -1 110 | conf_vals, conf_idx = torch.topk(probs_copy, k = topk, largest = True, sorted = True) 111 | indices[pred_id, :] = conf_idx 112 | 113 | # Get corresponding features 114 | selected_features = torch.index_select(features, dim=0, index=indices.view(-1).long()) 115 | selected_features = selected_features.unsqueeze(1).view(n_classes, -1, dims) 116 | 117 | # Get mean feature per class 118 | mean_features = torch.mean(selected_features, dim=1) 119 | 120 | # Get min distance wrt to mean 121 | diff_features = selected_features - mean_features.unsqueeze(1) 122 | diff_norm = torch.norm(diff_features, 2, dim=2) 123 | 124 | # Get final indices 125 | _, best_indices = torch.min(diff_norm, dim=1) 126 | one_hot = F.one_hot(best_indices.long(), indices.size(1)).byte() 127 | proto_indices = torch.masked_select(indices.view(-1), one_hot.view(-1)) 128 | proto_indices = proto_indices.int().tolist() 129 | return proto_indices 130 | 131 | def visualize_indices(indices, dataset, hungarian_match): 132 | import matplotlib.pyplot as plt 133 | import numpy as np 134 | 135 | for idx in indices: 136 | img = np.array(dataset.get_image(idx)).astype(np.uint8) 137 | img = Image.fromarray(img) 138 | plt.figure() 139 | plt.axis('off') 140 | plt.imshow(img) 141 | plt.show() 142 | 143 | 144 | if __name__ == "__main__": 145 | main() 146 | -------------------------------------------------------------------------------- /scan.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import argparse 6 | import os 7 | import torch 8 | 9 | from termcolor import colored 10 | from utils.config import 
create_config 11 | from utils.common_config import get_train_transformations, get_val_transformations,\ 12 | get_train_dataset, get_train_dataloader,\ 13 | get_val_dataset, get_val_dataloader,\ 14 | get_optimizer, get_model, get_criterion,\ 15 | adjust_learning_rate 16 | from utils.evaluate_utils import get_predictions, scan_evaluate, hungarian_evaluate 17 | from utils.train_utils import scan_train 18 | 19 | FLAGS = argparse.ArgumentParser(description='SCAN Loss') 20 | FLAGS.add_argument('--config_env', help='Location of path config file') 21 | FLAGS.add_argument('--config_exp', help='Location of experiments config file') 22 | 23 | def main(): 24 | args = FLAGS.parse_args() 25 | p = create_config(args.config_env, args.config_exp) 26 | print(colored(p, 'red')) 27 | 28 | # CUDNN 29 | torch.backends.cudnn.benchmark = True 30 | 31 | # Data 32 | print(colored('Get dataset and dataloaders', 'blue')) 33 | train_transformations = get_train_transformations(p) 34 | val_transformations = get_val_transformations(p) 35 | train_dataset = get_train_dataset(p, train_transformations, 36 | split='train', to_neighbors_dataset = True) 37 | val_dataset = get_val_dataset(p, val_transformations, to_neighbors_dataset = True) 38 | train_dataloader = get_train_dataloader(p, train_dataset) 39 | val_dataloader = get_val_dataloader(p, val_dataset) 40 | print('Train transforms:', train_transformations) 41 | print('Validation transforms:', val_transformations) 42 | print('Train samples %d - Val samples %d' %(len(train_dataset), len(val_dataset))) 43 | 44 | # Model 45 | print(colored('Get model', 'blue')) 46 | model = get_model(p, p['pretext_model']) 47 | print(model) 48 | model = torch.nn.DataParallel(model) 49 | model = model.cuda() 50 | 51 | # Optimizer 52 | print(colored('Get optimizer', 'blue')) 53 | optimizer = get_optimizer(p, model, p['update_cluster_head_only']) 54 | print(optimizer) 55 | 56 | # Warning 57 | if p['update_cluster_head_only']: 58 | print(colored('WARNING: SCAN will only update the cluster head', 'red')) 59 | 60 | # Loss function 61 | print(colored('Get loss', 'blue')) 62 | criterion = get_criterion(p) 63 | criterion.cuda() 64 | print(criterion) 65 | 66 | # Checkpoint 67 | if os.path.exists(p['scan_checkpoint']): 68 | print(colored('Restart from checkpoint {}'.format(p['scan_checkpoint']), 'blue')) 69 | checkpoint = torch.load(p['scan_checkpoint'], map_location='cpu') 70 | model.load_state_dict(checkpoint['model']) 71 | optimizer.load_state_dict(checkpoint['optimizer']) 72 | start_epoch = checkpoint['epoch'] 73 | best_loss = checkpoint['best_loss'] 74 | best_loss_head = checkpoint['best_loss_head'] 75 | 76 | else: 77 | print(colored('No checkpoint file at {}'.format(p['scan_checkpoint']), 'blue')) 78 | start_epoch = 0 79 | best_loss = 1e4 80 | best_loss_head = None 81 | 82 | # Main loop 83 | print(colored('Starting main loop', 'blue')) 84 | 85 | for epoch in range(start_epoch, p['epochs']): 86 | print(colored('Epoch %d/%d' %(epoch+1, p['epochs']), 'yellow')) 87 | print(colored('-'*15, 'yellow')) 88 | 89 | # Adjust lr 90 | lr = adjust_learning_rate(p, optimizer, epoch) 91 | print('Adjusted learning rate to {:.5f}'.format(lr)) 92 | 93 | # Train 94 | print('Train ...') 95 | scan_train(train_dataloader, model, criterion, optimizer, epoch, p['update_cluster_head_only']) 96 | 97 | # Evaluate 98 | print('Make prediction on validation set ...') 99 | predictions = get_predictions(p, val_dataloader, model) 100 | 101 | print('Evaluate based on SCAN loss ...') 102 | scan_stats = scan_evaluate(predictions) 103 
| print(scan_stats) 104 | lowest_loss_head = scan_stats['lowest_loss_head'] 105 | lowest_loss = scan_stats['lowest_loss'] 106 | 107 | if lowest_loss < best_loss: 108 | print('New lowest loss on validation set: %.4f -> %.4f' %(best_loss, lowest_loss)) 109 | print('Lowest loss head is %d' %(lowest_loss_head)) 110 | best_loss = lowest_loss 111 | best_loss_head = lowest_loss_head 112 | torch.save({'model': model.module.state_dict(), 'head': best_loss_head}, p['scan_model']) 113 | 114 | else: 115 | print('No new lowest loss on validation set: %.4f -> %.4f' %(best_loss, lowest_loss)) 116 | print('Lowest loss head is %d' %(best_loss_head)) 117 | 118 | print('Evaluate with hungarian matching algorithm ...') 119 | clustering_stats = hungarian_evaluate(lowest_loss_head, predictions, compute_confusion_matrix=False) 120 | print(clustering_stats) 121 | 122 | # Checkpoint 123 | print('Checkpoint ...') 124 | torch.save({'optimizer': optimizer.state_dict(), 'model': model.state_dict(), 125 | 'epoch': epoch + 1, 'best_loss': best_loss, 'best_loss_head': best_loss_head}, 126 | p['scan_checkpoint']) 127 | 128 | # Evaluate and save the final model 129 | print(colored('Evaluate best model based on SCAN metric at the end', 'blue')) 130 | model_checkpoint = torch.load(p['scan_model'], map_location='cpu') 131 | model.module.load_state_dict(model_checkpoint['model']) 132 | predictions = get_predictions(p, val_dataloader, model) 133 | clustering_stats = hungarian_evaluate(model_checkpoint['head'], predictions, 134 | class_names=val_dataset.dataset.classes, 135 | compute_confusion_matrix=True, 136 | confusion_matrix_file=os.path.join(p['scan_dir'], 'confusion_matrix.png')) 137 | print(clustering_stats) 138 | 139 | if __name__ == "__main__": 140 | main() 141 | -------------------------------------------------------------------------------- /TUTORIAL.md: -------------------------------------------------------------------------------- 1 | # Tutorial: Semantic Clustering on STL-10 with SCAN 2 | 3 | You can follow this guide to obtain the semantic clusters with SCAN on the STL-10 dataset. The procedure is equivalent for the other datasets. 4 | 5 | ## Contents 6 | 1. [Preparation](#preparation) 7 | 0. [Pretext task](#pretext-task) 8 | 0. [Semantic clustering](#semantic-clustering) 9 | 0. [Visualization](#visualization) 10 | 0. [Citation](#citation) 11 | 12 | ## Preparation 13 | ### Repository 14 | Clone the repository and navigate to the directory: 15 | ```bash 16 | git clone https://github.com/wvangansbeke/Unsupervised-Classification.git 17 | cd Unsupervised-Classification 18 | ``` 19 | 20 | ### Environment 21 | Activate your python environment containing the packages in the README.md. 22 | Make sure you have a GPU available (ideally a 1080TI or better) and set $gpu_ids to your desired gpu number(s): 23 | ```bash 24 | conda activate your_anaconda_env 25 | export CUDA_VISIBLE_DEVICES=$gpu_ids 26 | ``` 27 | I will use an environment with Python 3.7, Pytorch 1.6, CUDA 10.2 and CUDNN 7.5.6 for this example. 28 | 29 | ### Paths 30 | Adapt the path in `configs/env.yml` to `repository_eccv/`, since this directory will be used in this tutorial. 31 | Make the following directories. The models will be saved there, other directories will be made on the fly if necessary. 32 | ```bash 33 | mkdir -p repository_eccv/stl-10/pretext/ 34 | ``` 35 | Set the path in `utils/mypath.py` to your dataset root path as mentioned in the README.md 36 | 37 | ## Pretext task 38 | First we will run the pretext task (i.e. 
SimCLR) on the train+unlabeled set of STL-10. 39 | Feel free to run this task with the correct config file: 40 | ``` 41 | python simclr.py --config_env configs/env.yml --config_exp configs/pretext/simclr_stl10.yml 42 | ``` 43 | 44 | In order to save time, we provide pretrained models in the README.md for all the datasets discussed in the paper. 45 | First, download the pretrained model [here](https://drive.google.com/file/d/1261NDFfXuKR2Dh4RWHYYhcicdcPag9NZ/view?usp=sharing) and save it in your experiments directory. Then, move the downloaded model to the correct location (i.e. `repository_eccv/stl-10/pretext/`) and calculate the nearest neighbors. This can be achieved by running the following commands: 46 | ```bash 47 | mv simclr_stl-10.pth.tar repository_eccv/stl-10/pretext/checkpoint.pth.tar # Move model to correct location 48 | python tutorial_nn.py --config_env configs/env.yml --config_exp configs/pretext/simclr_stl10.yml # Compute neighbors 49 | ``` 50 | 51 | You should get the following results: 52 | ``` 53 | > Restart from checkpoint repository_eccv/stl-10/pretext/checkpoint.pth.tar 54 | > Fill memory bank for mining the nearest neighbors (train) ... 55 | > Fill Memory Bank [0/10] 56 | > Mine the nearest neighbors (Top-20) 57 | > Accuracy of top-20 nearest neighbors on train set is 72.81 58 | > Fill memory bank for mining the nearest neighbors (val) ... 59 | > Fill Memory Bank [0/16] 60 | > Mine the nearest neighbors (Top-5) 61 | > Accuracy of top-5 nearest neighbors on val set is 79.85 62 | ``` 63 | Now, the model has been correctly saved for the clustering step and the nearest neighbors were computed automatically. 64 | 65 | ## Semantic clustering 66 | 67 | We will start the clustering procedure now. Simply run the command underneath. The nearest neighbors and pretext model will be loaded automatically: 68 | ```bash 69 | python scan.py --config_env configs/env.yml --config_exp configs/scan/scan_stl10.yml 70 | ``` 71 | 72 | On average, you should get around 75.5% (as reported in the paper). I get around 80% for this run. 73 | As can be seen, the best model is selected based on the lowest loss on the validation set. 74 | A complete log file is included in `logs/scan_stl10.txt`. It can be viewed in color with `cat logs/scan_stl10.txt` in your terminal. 75 | ``` 76 | > Epoch 100/100 77 | > Adjusted learning rate to 0.00010 78 | > Train ... 79 | > Epoch: [99][ 0/39] Total Loss -1.0914e+01 (-1.0914e+01) Consistency Loss 5.4953e-01 (5.4953e-01) Entropy 2.2927e+00 (2.2927e+00) 80 | > Epoch: [99][25/39] Total Loss -1.0860e+01 (-1.0824e+01) Consistency Loss 6.0039e-01 (6.2986e-01) Entropy 2.2920e+00 (2.2908e+00) 81 | > Make prediction on validation set ... 82 | > Evaluate based on SCAN loss ... 83 | > {'scan': [{'entropy': 2.298224687576294, 'consistency': 0.4744518995285034, 'total_loss': -1.8237727880477905}], 'lowest_loss_head': 0, 'lowest_loss': -1.8237727880477905} 84 | > No new lowest loss on validation set: -1.8248 -> -1.8238 85 | > Lowest loss head is 0 86 | > Evaluate with hungarian matching algorithm ... 87 | > {'ACC': 0.79, 'ARI': 0.6165977838702869, 'NMI': 0.6698521018927263, 'Top-5': 0.9895, 'hungarian_match': [(0, 7), (1, 3), (2, 1), (3, 5), (4, 2), (5, 0), (6, 8), (7, 9), (8, 4), (9, 6)]} 88 | > Checkpoint ... 
89 | > Evaluate best model based on SCAN metric at the end 90 | > {'ACC': 0.8015, 'ARI': 0.6332440325942004, 'NMI': 0.6823369373831116, 'Top-5': 0.990625, 'hungarian_match': [(0, 7), (1, 3), (2, 1), (3, 5), (4, 2), (5, 0), (6, 8), (7, 9), (8, 4), (9, 6)]} ] 91 | ``` 92 | 93 | ## Visualization 94 | Now, we can visualize the confusion matrix and the prototypes of our model. We define the prototypes as the most confident samples for each cluster. We visualize the sample which is the closest to the mean embedding of its confident samples for each cluster. Run the following command: 95 | ```bash 96 | python eval.py --config_exp configs/scan/scan_stl10.yml --model repository_eccv/stl-10/scan/model.pth.tar --visualize_prototypes 97 | ``` 98 | As can be seen from the confusion matrix, the model confuses primarily between visually similar classes (e.g. cats, dogs and monkeys). Some images are classified near perfection (e.g. ship) without the use of ground truth. 99 |

100 | 101 |

102 |

103 | 104 |

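For reference, the selection rule behind `--visualize_prototypes` (implemented by `get_prototypes` in `eval.py`) boils down to a few tensor operations. The following is a minimal sketch of that rule, using random stand-in tensors in place of real model outputs:

```python
import torch
import torch.nn.functional as F

# Stand-in tensors: b samples, C clusters, dim-dimensional features.
b, C, dim, topk = 512, 10, 128, 10
probs = torch.randn(b, C).softmax(dim=1)             # cluster probabilities per sample
features = F.normalize(torch.randn(b, dim), dim=1)   # backbone features per sample

max_probs, pred_labels = probs.max(dim=1)
prototypes = []
for c in range(C):
    # Top-k most confident samples assigned to cluster c
    scores = max_probs.clone()
    scores[pred_labels != c] = -1
    _, conf_idx = scores.topk(topk)
    # Prototype = the confident sample closest to the mean confident embedding
    cluster_feats = features[conf_idx]
    dists = (cluster_feats - cluster_feats.mean(dim=0, keepdim=True)).norm(dim=1)
    prototypes.append(conf_idx[dists.argmin()].item())

print(prototypes)  # one image index per cluster
```

Each returned index can then be passed to `dataset.get_image(idx)` for display, as `visualize_indices` does in `eval.py`.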
105 | 106 | 107 | ## Citation 108 | 109 | If you find this tutorial useful for your research, please consider citing our paper: 110 | 111 | ```bibtex 112 | @inproceedings{vangansbeke2020scan, 113 | title={Scan: Learning to classify images without labels}, 114 | author={Van Gansbeke, Wouter and Vandenhende, Simon and Georgoulis, Stamatios and Proesmans, Marc and Van Gool, Luc}, 115 | booktitle={Proceedings of the European Conference on Computer Vision}, 116 | year={2020} 117 | } 118 | ``` 119 | -------------------------------------------------------------------------------- /simclr.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import argparse 6 | import os 7 | import torch 8 | import numpy as np 9 | 10 | from utils.config import create_config 11 | from utils.common_config import get_criterion, get_model, get_train_dataset,\ 12 | get_val_dataset, get_train_dataloader,\ 13 | get_val_dataloader, get_train_transformations,\ 14 | get_val_transformations, get_optimizer,\ 15 | adjust_learning_rate 16 | from utils.evaluate_utils import contrastive_evaluate 17 | from utils.memory import MemoryBank 18 | from utils.train_utils import simclr_train 19 | from utils.utils import fill_memory_bank 20 | from termcolor import colored 21 | 22 | # Parser 23 | parser = argparse.ArgumentParser(description='SimCLR') 24 | parser.add_argument('--config_env', 25 | help='Config file for the environment') 26 | parser.add_argument('--config_exp', 27 | help='Config file for the experiment') 28 | args = parser.parse_args() 29 | 30 | def main(): 31 | 32 | # Retrieve config file 33 | p = create_config(args.config_env, args.config_exp) 34 | print(colored(p, 'red')) 35 | 36 | # Model 37 | print(colored('Retrieve model', 'blue')) 38 | model = get_model(p) 39 | print('Model is {}'.format(model.__class__.__name__)) 40 | print('Model parameters: {:.2f}M'.format(sum(p.numel() for p in model.parameters()) / 1e6)) 41 | print(model) 42 | model = model.cuda() 43 | 44 | # CUDNN 45 | print(colored('Set CuDNN benchmark', 'blue')) 46 | torch.backends.cudnn.benchmark = True 47 | 48 | # Dataset 49 | print(colored('Retrieve dataset', 'blue')) 50 | train_transforms = get_train_transformations(p) 51 | print('Train transforms:', train_transforms) 52 | val_transforms = get_val_transformations(p) 53 | print('Validation transforms:', val_transforms) 54 | train_dataset = get_train_dataset(p, train_transforms, to_augmented_dataset=True, 55 | split='train+unlabeled') # Split is for stl-10 56 | val_dataset = get_val_dataset(p, val_transforms) 57 | train_dataloader = get_train_dataloader(p, train_dataset) 58 | val_dataloader = get_val_dataloader(p, val_dataset) 59 | print('Dataset contains {}/{} train/val samples'.format(len(train_dataset), len(val_dataset))) 60 | 61 | # Memory Bank 62 | print(colored('Build MemoryBank', 'blue')) 63 | base_dataset = get_train_dataset(p, val_transforms, split='train') # Dataset w/o augs for knn eval 64 | base_dataloader = get_val_dataloader(p, base_dataset) 65 | memory_bank_base = MemoryBank(len(base_dataset), 66 | p['model_kwargs']['features_dim'], 67 | p['num_classes'], p['criterion_kwargs']['temperature']) 68 | memory_bank_base.cuda() 69 | memory_bank_val = MemoryBank(len(val_dataset), 70 | p['model_kwargs']['features_dim'], 71 | p['num_classes'], p['criterion_kwargs']['temperature']) 72 | memory_bank_val.cuda() 73 
| 74 | # Criterion 75 | print(colored('Retrieve criterion', 'blue')) 76 | criterion = get_criterion(p) 77 | print('Criterion is {}'.format(criterion.__class__.__name__)) 78 | criterion = criterion.cuda() 79 | 80 | # Optimizer and scheduler 81 | print(colored('Retrieve optimizer', 'blue')) 82 | optimizer = get_optimizer(p, model) 83 | print(optimizer) 84 | 85 | # Checkpoint 86 | if os.path.exists(p['pretext_checkpoint']): 87 | print(colored('Restart from checkpoint {}'.format(p['pretext_checkpoint']), 'blue')) 88 | checkpoint = torch.load(p['pretext_checkpoint'], map_location='cpu') 89 | optimizer.load_state_dict(checkpoint['optimizer']) 90 | model.load_state_dict(checkpoint['model']) 91 | model.cuda() 92 | start_epoch = checkpoint['epoch'] 93 | 94 | else: 95 | print(colored('No checkpoint file at {}'.format(p['pretext_checkpoint']), 'blue')) 96 | start_epoch = 0 97 | model = model.cuda() 98 | 99 | # Training 100 | print(colored('Starting main loop', 'blue')) 101 | for epoch in range(start_epoch, p['epochs']): 102 | print(colored('Epoch %d/%d' %(epoch, p['epochs']), 'yellow')) 103 | print(colored('-'*15, 'yellow')) 104 | 105 | # Adjust lr 106 | lr = adjust_learning_rate(p, optimizer, epoch) 107 | print('Adjusted learning rate to {:.5f}'.format(lr)) 108 | 109 | # Train 110 | print('Train ...') 111 | simclr_train(train_dataloader, model, criterion, optimizer, epoch) 112 | 113 | # Fill memory bank 114 | print('Fill memory bank for kNN...') 115 | fill_memory_bank(base_dataloader, model, memory_bank_base) 116 | 117 | # Evaluate (To monitor progress - Not for validation) 118 | print('Evaluate ...') 119 | top1 = contrastive_evaluate(val_dataloader, model, memory_bank_base) 120 | print('Result of kNN evaluation is %.2f' %(top1)) 121 | 122 | # Checkpoint 123 | print('Checkpoint ...') 124 | torch.save({'optimizer': optimizer.state_dict(), 'model': model.state_dict(), 125 | 'epoch': epoch + 1}, p['pretext_checkpoint']) 126 | 127 | # Save final model 128 | torch.save(model.state_dict(), p['pretext_model']) 129 | 130 | # Mine the topk nearest neighbors at the very end (Train) 131 | # These will be served as input to the SCAN loss. 132 | print(colored('Fill memory bank for mining the nearest neighbors (train) ...', 'blue')) 133 | fill_memory_bank(base_dataloader, model, memory_bank_base) 134 | topk = 20 135 | print('Mine the nearest neighbors (Top-%d)' %(topk)) 136 | indices, acc = memory_bank_base.mine_nearest_neighbors(topk) 137 | print('Accuracy of top-%d nearest neighbors on train set is %.2f' %(topk, 100*acc)) 138 | np.save(p['topk_neighbors_train_path'], indices) 139 | 140 | 141 | # Mine the topk nearest neighbors at the very end (Val) 142 | # These will be used for validation. 
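    # Note (assumption; see utils/memory.py for the exact format): each saved `indices`
    # array holds, per sample, the ids of its mined neighbors; scan.py later loads these
    # files when building the neighbors dataset (to_neighbors_dataset=True).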
143 | print(colored('Fill memory bank for mining the nearest neighbors (val) ...', 'blue')) 144 | fill_memory_bank(val_dataloader, model, memory_bank_val) 145 | topk = 5 146 | print('Mine the nearest neighbors (Top-%d)' %(topk)) 147 | indices, acc = memory_bank_val.mine_nearest_neighbors(topk) 148 | print('Accuracy of top-%d nearest neighbors on val set is %.2f' %(topk, 100*acc)) 149 | np.save(p['topk_neighbors_val_path'], indices) 150 | 151 | 152 | if __name__ == '__main__': 153 | main() 154 | -------------------------------------------------------------------------------- /data/imagenet_subsets/imagenet_200.txt: -------------------------------------------------------------------------------- 1 | n01443537 goldfish, Carassius auratus 2 | n01514668 cock 3 | n01558993 robin, American robin, Turdus migratorius 4 | n01601694 water ouzel, dipper 5 | n01669191 box turtle, box tortoise 6 | n01751748 sea snake 7 | n01755581 diamondback, diamondback rattlesnake, Crotalus adamanteus 8 | n01756291 sidewinder, horned rattlesnake, Crotalus cerastes 9 | n01770081 harvestman, daddy longlegs, Phalangium opilio 10 | n01770393 scorpion 11 | n01773797 garden spider, Aranea diademata 12 | n01775062 wolf spider, hunting spider 13 | n01806143 peacock 14 | n01819313 sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita 15 | n01855672 goose 16 | n01871265 tusker 17 | n01917289 brain coral 18 | n02002724 black stork, Ciconia nigra 19 | n02018207 American coot, marsh hen, mud hen, water hen, Fulica americana 20 | n02028035 redshank, Tringa totanus 21 | n02033041 dowitcher 22 | n02037110 oystercatcher, oyster catcher 23 | n02058221 albatross, mollymawk 24 | n02087046 toy terrier 25 | n02088094 Afghan hound, Afghan 26 | n02088632 bluetick 27 | n02090622 borzoi, Russian wolfhound 28 | n02091134 whippet 29 | n02093256 Staffordshire bullterrier, Staffordshire bull terrier 30 | n02093754 Border terrier 31 | n02093859 Kerry blue terrier 32 | n02094114 Norfolk terrier 33 | n02096177 cairn, cairn terrier 34 | n02096294 Australian terrier 35 | n02097130 giant schnauzer 36 | n02097298 Scotch terrier, Scottish terrier, Scottie 37 | n02099267 flat-coated retriever 38 | n02099849 Chesapeake Bay retriever 39 | n02100735 English setter 40 | n02100877 Irish setter, red setter 41 | n02104365 schipperke 42 | n02105251 briard 43 | n02105855 Shetland sheepdog, Shetland sheep dog, Shetland 44 | n02106030 collie 45 | n02106166 Border collie 46 | n02107142 Doberman, Doberman pinscher 47 | n02107683 Bernese mountain dog 48 | n02110185 Siberian husky 49 | n02110341 dalmatian, coach dog, carriage dog 50 | n02111889 Samoyed, Samoyede 51 | n02113799 standard poodle 52 | n02114855 coyote, prairie wolf, brush wolf, Canis latrans 53 | n02120079 Arctic fox, white fox, Alopex lagopus 54 | n02120505 grey fox, gray fox, Urocyon cinereoargenteus 55 | n02123597 Siamese cat, Siamese 56 | n02125311 cougar, puma, catamount, mountain lion, painter, panther, Felis concolor 57 | n02128385 leopard, Panthera pardus 58 | n02129165 lion, king of beasts, Panthera leo 59 | n02130308 cheetah, chetah, Acinonyx jubatus 60 | n02133161 American black bear, black bear, Ursus americanus, Euarctos americanus 61 | n02219486 ant, emmet, pismire 62 | n02229544 cricket 63 | n02277742 ringlet, ringlet butterfly 64 | n02325366 wood rabbit, cottontail, cottontail rabbit 65 | n02342885 hamster 66 | n02364673 guinea pig, Cavia cobaya 67 | n02437616 llama 68 | n02454379 armadillo 69 | n02484975 guenon, guenon monkey 70 | n02488702 colobus, colobus monkey 71 | n02489166 
proboscis monkey, Nasalis larvatus 72 | n02492660 howler monkey, howler 73 | n02510455 giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca 74 | n02536864 coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch 75 | n02643566 lionfish 76 | n02655020 puffer, pufferfish, blowfish, globefish 77 | n02666196 abacus 78 | n02708093 analog clock 79 | n02730930 apron 80 | n02747177 ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin 81 | n02769748 backpack, back pack, knapsack, packsack, rucksack, haversack 82 | n02783161 ballpoint, ballpoint pen, ballpen, Biro 83 | n02791124 barber chair 84 | n02794156 barometer 85 | n02804414 bassinet 86 | n02804610 bassoon 87 | n02835271 bicycle-built-for-two, tandem bicycle, tandem 88 | n02837789 bikini, two-piece 89 | n02865351 bolo tie, bolo, bola tie, bola 90 | n02877765 bottlecap 91 | n02906734 broom 92 | n02909870 bucket, pail 93 | n02910353 buckle 94 | n02916936 bulletproof vest 95 | n02977058 cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM 96 | n02978881 cassette 97 | n02981792 catamaran 98 | n02992211 cello, violoncello 99 | n03063599 coffee mug 100 | n03085013 computer keyboard, keypad 101 | n03124170 cowboy hat, ten-gallon hat 102 | n03127747 crash helmet 103 | n03134739 croquet ball 104 | n03160309 dam, dike, dyke 105 | n03208938 disk brake, disc brake 106 | n03255030 dumbbell 107 | n03272010 electric guitar 108 | n03291819 envelope 109 | n03337140 file, file cabinet, filing cabinet 110 | n03347037 fire screen, fireguard 111 | n03355925 flagpole, flagstaff 112 | n03445924 golfcart, golf cart 113 | n03447447 gondola 114 | n03447721 gong, tam-tam 115 | n03450230 gown 116 | n03457902 greenhouse, nursery, glasshouse 117 | n03483316 hand blower, blow dryer, blow drier, hair dryer, hair drier 118 | n03498962 hatchet 119 | n03530642 honeycomb 120 | n03599486 jinrikisha, ricksha, rickshaw 121 | n03617480 kimono 122 | n03623198 knee pad 123 | n03627232 knot 124 | n03649909 lawn mower, mower 125 | n03710721 maillot, tank suit 126 | n03717622 manhole cover 127 | n03733281 maze, labyrinth 128 | n03759954 microphone, mike 129 | n03763968 military uniform 130 | n03764736 milk can 131 | n03775071 mitten 132 | n03782006 monitor 133 | n03814639 neck brace 134 | n03837869 obelisk 135 | n03838899 oboe, hautboy, hautbois 136 | n03841143 odometer, hodometer, mileometer, milometer 137 | n03854065 organ, pipe organ 138 | n03868863 oxygen mask 139 | n03874293 paddlewheel, paddle wheel 140 | n03877845 palace 141 | n03929855 pickelhaube 142 | n03930313 picket fence, paling 143 | n03954731 plane, carpenter's plane, woodworking plane 144 | n03956157 planetarium 145 | n03976467 Polaroid camera, Polaroid Land camera 146 | n03976657 pole 147 | n03983396 pop bottle, soda bottle 148 | n04004767 printer 149 | n04026417 purse 150 | n04065272 recreational vehicle, RV, R.V. 
151 | n04146614 school bus 152 | n04200800 shoe shop, shoe-shop, shoe store 153 | n04209239 shower curtain 154 | n04235860 sleeping bag 155 | n04238763 slide rule, slipstick 156 | n04252077 snowmobile 157 | n04263257 soup bowl 158 | n04264628 space bar 159 | n04265275 space heater 160 | n04266014 space shuttle 161 | n04275548 spider web, spider's web 162 | n04311004 steel arch bridge 163 | n04325704 stole 164 | n04330267 stove 165 | n04332243 strainer 166 | n04335435 streetcar, tram, tramcar, trolley, trolley car 167 | n04336792 stretcher 168 | n04346328 stupa, tope 169 | n04350905 suit, suit of clothes 170 | n04366367 suspension bridge 171 | n04367480 swab, swob, mop 172 | n04380533 table lamp 173 | n04392985 tape player 174 | n04428191 thresher, thrasher, threshing machine 175 | n04443257 tobacco shop, tobacconist shop, tobacconist 176 | n04458633 totem pole 177 | n04479046 trench coat 178 | n04483307 trimaran 179 | n04509417 unicycle, monocycle 180 | n04515003 upright, upright piano 181 | n04525305 vending machine 182 | n04532106 vestment 183 | n04554684 washer, automatic washer, washing machine 184 | n04579145 whiskey jug 185 | n04591157 Windsor tie 186 | n04592741 wing 187 | n04604644 worm fence, snake fence, snake-rail fence, Virginia fence 188 | n04606251 wreck 189 | n07583066 guacamole 190 | n07584110 consomme 191 | n07613480 trifle 192 | n07615774 ice lolly, lolly, lollipop, popsicle 193 | n07684084 French loaf 194 | n07693725 bagel, beigel 195 | n07711569 mashed potato 196 | n07714990 broccoli 197 | n07749582 lemon 198 | n07753592 banana 199 | n09288635 geyser 200 | n11879895 rapeseed 201 | -------------------------------------------------------------------------------- /utils/evaluate_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | from utils.common_config import get_feature_dimensions_backbone 9 | from utils.utils import AverageMeter, confusion_matrix 10 | from data.custom_dataset import NeighborsDataset 11 | from sklearn import metrics 12 | from scipy.optimize import linear_sum_assignment 13 | from losses.losses import entropy 14 | 15 | 16 | @torch.no_grad() 17 | def contrastive_evaluate(val_loader, model, memory_bank): 18 | top1 = AverageMeter('Acc@1', ':6.2f') 19 | model.eval() 20 | 21 | for batch in val_loader: 22 | images = batch['image'].cuda(non_blocking=True) 23 | target = batch['target'].cuda(non_blocking=True) 24 | 25 | output = model(images) 26 | output = memory_bank.weighted_knn(output) 27 | 28 | acc1 = 100*torch.mean(torch.eq(output, target).float()) 29 | top1.update(acc1.item(), images.size(0)) 30 | 31 | return top1.avg 32 | 33 | 34 | @torch.no_grad() 35 | def get_predictions(p, dataloader, model, return_features=False): 36 | # Make predictions on a dataset with neighbors 37 | model.eval() 38 | predictions = [[] for _ in range(p['num_heads'])] 39 | probs = [[] for _ in range(p['num_heads'])] 40 | targets = [] 41 | if return_features: 42 | ft_dim = get_feature_dimensions_backbone(p) 43 | features = torch.zeros((len(dataloader.sampler), ft_dim)).cuda() 44 | 45 | if isinstance(dataloader.dataset, NeighborsDataset): # Also return the neighbors 46 | key_ = 'anchor' 47 | include_neighbors = True 48 | neighbors = [] 49 | 50 | else: 51 | key_ = 'image' 52 | include_neighbors = False 53 | 54 | 
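    # `ptr` walks a buffer preallocated over the full dataset, so features end up
    # stored in dataloader iteration order; downstream indexing therefore assumes
    # a non-shuffled dataloader (as returned by get_val_dataloader).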
ptr = 0 55 | for batch in dataloader: 56 | images = batch[key_].cuda(non_blocking=True) 57 | bs = images.shape[0] 58 | res = model(images, forward_pass='return_all') 59 | output = res['output'] 60 | if return_features: 61 | features[ptr: ptr+bs] = res['features'] 62 | ptr += bs 63 | for i, output_i in enumerate(output): 64 | predictions[i].append(torch.argmax(output_i, dim=1)) 65 | probs[i].append(F.softmax(output_i, dim=1)) 66 | targets.append(batch['target']) 67 | if include_neighbors: 68 | neighbors.append(batch['possible_neighbors']) 69 | 70 | predictions = [torch.cat(pred_, dim = 0).cpu() for pred_ in predictions] 71 | probs = [torch.cat(prob_, dim=0).cpu() for prob_ in probs] 72 | targets = torch.cat(targets, dim=0) 73 | 74 | if include_neighbors: 75 | neighbors = torch.cat(neighbors, dim=0) 76 | out = [{'predictions': pred_, 'probabilities': prob_, 'targets': targets, 'neighbors': neighbors} for pred_, prob_ in zip(predictions, probs)] 77 | 78 | else: 79 | out = [{'predictions': pred_, 'probabilities': prob_, 'targets': targets} for pred_, prob_ in zip(predictions, probs)] 80 | 81 | if return_features: 82 | return out, features.cpu() 83 | else: 84 | return out 85 | 86 | 87 | @torch.no_grad() 88 | def scan_evaluate(predictions): 89 | # Evaluate model based on SCAN loss. 90 | num_heads = len(predictions) 91 | output = [] 92 | 93 | for head in predictions: 94 | # Neighbors and anchors 95 | probs = head['probabilities'] 96 | neighbors = head['neighbors'] 97 | anchors = torch.arange(neighbors.size(0)).view(-1,1).expand_as(neighbors) 98 | 99 | # Entropy loss 100 | entropy_loss = entropy(torch.mean(probs, dim=0), input_as_probabilities=True).item() 101 | 102 | # Consistency loss 103 | similarity = torch.matmul(probs, probs.t()) 104 | neighbors = neighbors.contiguous().view(-1) 105 | anchors = anchors.contiguous().view(-1) 106 | similarity = similarity[anchors, neighbors] 107 | ones = torch.ones_like(similarity) 108 | consistency_loss = F.binary_cross_entropy(similarity, ones).item() 109 | 110 | # Total loss 111 | total_loss = - entropy_loss + consistency_loss 112 | 113 | output.append({'entropy': entropy_loss, 'consistency': consistency_loss, 'total_loss': total_loss}) 114 | 115 | total_losses = [output_['total_loss'] for output_ in output] 116 | lowest_loss_head = np.argmin(total_losses) 117 | lowest_loss = np.min(total_losses) 118 | 119 | return {'scan': output, 'lowest_loss_head': lowest_loss_head, 'lowest_loss': lowest_loss} 120 | 121 | 122 | @torch.no_grad() 123 | def hungarian_evaluate(subhead_index, all_predictions, class_names=None, 124 | compute_purity=True, compute_confusion_matrix=True, 125 | confusion_matrix_file=None): 126 | # Evaluate model based on hungarian matching between predicted cluster assignment and gt classes. 127 | # This is computed only for the passed subhead index. 
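# Worked example (illustrative): with 3 clusters, if cluster 0 mostly contains
# ground-truth class 2, cluster 1 class 0 and cluster 2 class 1, the matching
# computed below is [(0, 2), (1, 0), (2, 1)]; ACC is then measured after
# relabeling the predictions accordingly, while NMI/ARI are permutation
# invariant and use the raw predictions.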
128 | 129 | # Hungarian matching 130 | head = all_predictions[subhead_index] 131 | targets = head['targets'].cuda() 132 | predictions = head['predictions'].cuda() 133 | probs = head['probabilities'].cuda() 134 | num_classes = torch.unique(targets).numel() 135 | num_elems = targets.size(0) 136 | 137 | match = _hungarian_match(predictions, targets, preds_k=num_classes, targets_k=num_classes) 138 | reordered_preds = torch.zeros(num_elems, dtype=predictions.dtype).cuda() 139 | for pred_i, target_i in match: 140 | reordered_preds[predictions == int(pred_i)] = int(target_i) 141 | 142 | # Gather performance metrics 143 | acc = int((reordered_preds == targets).sum()) / float(num_elems) 144 | nmi = metrics.normalized_mutual_info_score(targets.cpu().numpy(), predictions.cpu().numpy()) 145 | ari = metrics.adjusted_rand_score(targets.cpu().numpy(), predictions.cpu().numpy()) 146 | 147 | _, preds_top5 = probs.topk(5, 1, largest=True) 148 | reordered_preds_top5 = torch.zeros_like(preds_top5) 149 | for pred_i, target_i in match: 150 | reordered_preds_top5[preds_top5 == int(pred_i)] = int(target_i) 151 | correct_top5_binary = reordered_preds_top5.eq(targets.view(-1,1).expand_as(reordered_preds_top5)) 152 | top5 = float(correct_top5_binary.sum()) / float(num_elems) 153 | 154 | # Compute confusion matrix 155 | if compute_confusion_matrix: 156 | confusion_matrix(reordered_preds.cpu().numpy(), targets.cpu().numpy(), 157 | class_names, confusion_matrix_file) 158 | 159 | return {'ACC': acc, 'ARI': ari, 'NMI': nmi, 'ACC Top-5': top5, 'hungarian_match': match} 160 | 161 | 162 | @torch.no_grad() 163 | def _hungarian_match(flat_preds, flat_targets, preds_k, targets_k): 164 | # Based on implementation from IIC 165 | num_samples = flat_targets.shape[0] 166 | 167 | assert (preds_k == targets_k) # one to one 168 | num_k = preds_k 169 | num_correct = np.zeros((num_k, num_k)) 170 | 171 | for c1 in range(num_k): 172 | for c2 in range(num_k): 173 | # elementwise, so each sample contributes once 174 | votes = int(((flat_preds == c1) * (flat_targets == c2)).sum()) 175 | num_correct[c1, c2] = votes 176 | 177 | # num_correct is small 178 | match = linear_sum_assignment(num_samples - num_correct) 179 | match = np.array(list(zip(*match))) 180 | 181 | # return as list of tuples, out_c to gt_c 182 | res = [] 183 | for out_c, gt_c in match: 184 | res.append((out_c, gt_c)) 185 | 186 | return res 187 | -------------------------------------------------------------------------------- /data/stl.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is based on the Torchvision repository, which was licensed under the BSD 3-Clause. 3 | """ 4 | from PIL import Image 5 | from torchvision.datasets.utils import check_integrity, download_and_extract_archive, verify_str_arg 6 | from torch.utils.data import Dataset 7 | from utils.mypath import MyPath 8 | import os 9 | import numpy as np 10 | 11 | 12 | class STL10(Dataset): 13 | """`STL10 <http://ai.stanford.edu/~acoates/stl10/>`_ Dataset. 14 | Args: 15 | root (string): Root directory of dataset where directory 16 | ``stl10_binary`` exists. 17 | split (string): One of {'train', 'test', 'unlabeled', 'train+unlabeled'}. 18 | The dataset is selected accordingly. 19 | folds (int, optional): One of {0-9} or None. 20 | For training, loads one of the 10 pre-defined folds of 1k samples for the 21 | standard evaluation procedure. If no value is passed, loads the 5k samples. 22 | transform (callable, optional): A function/transform that takes in a PIL image 23 | and returns a transformed version.
E.g., ``transforms.RandomCrop`` 24 | target_transform (callable, optional): A function/transform that takes in the 25 | target and transforms it. 26 | download (bool, optional): If true, downloads the dataset from the internet and 27 | puts it in root directory. If dataset is already downloaded, it is not 28 | downloaded again. 29 | """ 30 | base_folder = 'stl10_binary' 31 | url = "http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz" 32 | filename = "stl10_binary.tar.gz" 33 | tgz_md5 = '91f7769df0f17e558f3565bffb0c7dfb' 34 | class_names_file = 'class_names.txt' 35 | folds_list_file = 'fold_indices.txt' 36 | train_list = [ 37 | ['train_X.bin', '918c2871b30a85fa023e0c44e0bee87f'], 38 | ['train_y.bin', '5a34089d4802c674881badbb80307741'], 39 | ['unlabeled_X.bin', '5242ba1fed5e4be9e1e742405eb56ca4'] 40 | ] 41 | 42 | test_list = [ 43 | ['test_X.bin', '7f263ba9f9e0b06b93213547f721ac82'], 44 | ['test_y.bin', '36f9794fa4beb8a2c72628de14fa638e'] 45 | ] 46 | splits = ('train', 'train+unlabeled', 'unlabeled', 'test') 47 | 48 | def __init__(self, root=MyPath.db_root_dir('stl-10'), 49 | split='train', folds=None, transform=None, 50 | download=False): 51 | super(STL10, self).__init__() 52 | self.root = root 53 | self.transform = transform 54 | self.split = verify_str_arg(split, "split", self.splits) 55 | self.folds = self._verify_folds(folds) 56 | if download: 57 | self.download() 58 | elif not self._check_integrity(): 59 | raise RuntimeError( 60 | 'Dataset not found or corrupted. ' 61 | 'You can use download=True to download it') 62 | 63 | # now load the pickled numpy arrays 64 | if self.split == 'train': 65 | self.data, self.labels = self.__loadfile( 66 | self.train_list[0][0], self.train_list[1][0]) 67 | self.__load_folds(folds) 68 | 69 | elif self.split == 'train+unlabeled': 70 | self.data, self.labels = self.__loadfile( 71 | self.train_list[0][0], self.train_list[1][0]) 72 | self.__load_folds(folds) 73 | unlabeled_data, _ = self.__loadfile(self.train_list[2][0]) 74 | self.data = np.concatenate((self.data, unlabeled_data)) 75 | self.labels = np.concatenate( 76 | (self.labels, np.asarray([-1] * unlabeled_data.shape[0]))) 77 | 78 | elif self.split == 'unlabeled': 79 | self.data, _ = self.__loadfile(self.train_list[2][0]) 80 | self.labels = np.asarray([-1] * self.data.shape[0]) 81 | else: # self.split == 'test': 82 | self.data, self.labels = self.__loadfile( 83 | self.test_list[0][0], self.test_list[1][0]) 84 | 85 | class_file = os.path.join( 86 | self.root, self.base_folder, self.class_names_file) 87 | if os.path.isfile(class_file): 88 | with open(class_file) as f: 89 | self.classes = f.read().splitlines() 90 | 91 | if self.split == 'train': # Added this to be able to filter out fp from neighbors 92 | self.targets = self.labels 93 | 94 | 95 | def _verify_folds(self, folds): 96 | if folds is None: 97 | return folds 98 | elif isinstance(folds, int): 99 | if folds in range(10): 100 | return folds 101 | msg = ("Value for argument folds should be in the range [0, 10), " 102 | "but got {}.") 103 | raise ValueError(msg.format(folds)) 104 | else: 105 | msg = "Expected type None or int for argument folds, but got type {}." 106 | raise ValueError(msg.format(type(folds))) 107 | 108 | 109 | def __getitem__(self, index): 110 | """ 111 | Args: 112 | index (int): Index 113 | Returns: 114 | tuple: (image, target) where target is index of the target class.
115 | """ 116 | if self.labels is not None: 117 | img, target = self.data[index], int(self.labels[index]) 118 | class_name = self.classes[target] 119 | else: 120 | img, target = self.data[index], 255 # 255 is an ignore index 121 | class_name = 'unlabeled' 122 | 123 | # make consistent with all other datasets 124 | # return a PIL Image 125 | img = Image.fromarray(np.transpose(img, (1, 2, 0))) 126 | img_size = img.size 127 | 128 | if self.transform is not None: 129 | img = self.transform(img) 130 | 131 | out = {'image': img, 'target': target, 'meta': {'im_size': img_size, 'index': index, 'class_name': class_name}} 132 | 133 | return out 134 | 135 | def get_image(self, index): 136 | img = self.data[index] 137 | img = np.transpose(img, (1, 2, 0)) 138 | return img 139 | 140 | def __len__(self): 141 | return self.data.shape[0] 142 | 143 | def __loadfile(self, data_file, labels_file=None): 144 | labels = None 145 | if labels_file: 146 | path_to_labels = os.path.join( 147 | self.root, self.base_folder, labels_file) 148 | with open(path_to_labels, 'rb') as f: 149 | labels = np.fromfile(f, dtype=np.uint8) - 1 # 0-based 150 | 151 | path_to_data = os.path.join(self.root, self.base_folder, data_file) 152 | with open(path_to_data, 'rb') as f: 153 | # read whole file in uint8 chunks 154 | everything = np.fromfile(f, dtype=np.uint8) 155 | images = np.reshape(everything, (-1, 3, 96, 96)) 156 | images = np.transpose(images, (0, 1, 3, 2)) 157 | 158 | return images, labels 159 | 160 | def _check_integrity(self): 161 | root = self.root 162 | for fentry in (self.train_list + self.test_list): 163 | filename, md5 = fentry[0], fentry[1] 164 | fpath = os.path.join(root, self.base_folder, filename) 165 | if not check_integrity(fpath, md5): 166 | return False 167 | return True 168 | 169 | def download(self): 170 | if self._check_integrity(): 171 | print('Files already downloaded and verified') 172 | return 173 | download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5) 174 | self._check_integrity() 175 | 176 | def extra_repr(self): 177 | return "Split: {split}".format(**self.__dict__) 178 | 179 | def __load_folds(self, folds): 180 | # loads one of the folds if specified 181 | if folds is None: 182 | return 183 | path_to_folds = os.path.join( 184 | self.root, self.base_folder, self.folds_list_file) 185 | with open(path_to_folds, 'r') as f: 186 | str_idx = f.read().splitlines()[folds] 187 | list_idx = np.fromstring(str_idx, dtype=np.uint8, sep=' ') 188 | self.data, self.labels = self.data[list_idx, :, :, :], self.labels[list_idx] 189 | -------------------------------------------------------------------------------- /data/cifar.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is based on the Torchvision repository, which was licensed under the BSD 3-Clause. 3 | """ 4 | import os 5 | import pickle 6 | import sys 7 | import numpy as np 8 | import torch 9 | from PIL import Image 10 | from torch.utils.data import Dataset 11 | from utils.mypath import MyPath 12 | from torchvision.datasets.utils import check_integrity, download_and_extract_archive 13 | 14 | 15 | class CIFAR10(Dataset): 16 | """`CIFAR10 `_ Dataset. 17 | Args: 18 | root (string): Root directory of dataset where directory 19 | ``cifar-10-batches-py`` exists or will be saved to if download is set to True. 20 | train (bool, optional): If True, creates dataset from training set, otherwise 21 | creates from test set. 
22 | transform (callable, optional): A function/transform that takes in a PIL image 23 | and returns a transformed version. E.g., ``transforms.RandomCrop`` 24 | download (bool, optional): If true, downloads the dataset from the internet and 25 | puts it in root directory. If dataset is already downloaded, it is not 26 | downloaded again. 27 | """ 28 | base_folder = 'cifar-10-batches-py' 29 | url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" 30 | filename = "cifar-10-python.tar.gz" 31 | tgz_md5 = 'c58f30108f718f92721af3b95e74349a' 32 | train_list = [ 33 | ['data_batch_1', 'c99cafc152244af753f735de768cd75f'], 34 | ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'], 35 | ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'], 36 | ['data_batch_4', '634d18415352ddfa80567beed471001a'], 37 | ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'], 38 | ] 39 | 40 | test_list = [ 41 | ['test_batch', '40351d587109b95175f43aff81a1287e'], 42 | ] 43 | meta = { 44 | 'filename': 'batches.meta', 45 | 'key': 'label_names', 46 | 'md5': '5ff9c542aee3614f3951f8cda6e48888', 47 | } 48 | 49 | def __init__(self, root=MyPath.db_root_dir('cifar-10'), train=True, transform=None, 50 | download=False): 51 | 52 | super(CIFAR10, self).__init__() 53 | self.root = root 54 | self.transform = transform 55 | self.train = train # training set or test set 56 | self.classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] 57 | 58 | if download: 59 | self.download() 60 | 61 | if not self._check_integrity(): 62 | raise RuntimeError('Dataset not found or corrupted.' + 63 | ' You can use download=True to download it') 64 | 65 | if self.train: 66 | downloaded_list = self.train_list 67 | else: 68 | downloaded_list = self.test_list 69 | 70 | self.data = [] 71 | self.targets = [] 72 | 73 | # now load the pickled numpy arrays 74 | for file_name, checksum in downloaded_list: 75 | file_path = os.path.join(self.root, self.base_folder, file_name) 76 | with open(file_path, 'rb') as f: 77 | if sys.version_info[0] == 2: 78 | entry = pickle.load(f) 79 | else: 80 | entry = pickle.load(f, encoding='latin1') 81 | self.data.append(entry['data']) 82 | if 'labels' in entry: 83 | self.targets.extend(entry['labels']) 84 | else: 85 | self.targets.extend(entry['fine_labels']) 86 | 87 | self.data = np.vstack(self.data).reshape(-1, 3, 32, 32) 88 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC 89 | 90 | self._load_meta() 91 | 92 | def _load_meta(self): 93 | path = os.path.join(self.root, self.base_folder, self.meta['filename']) 94 | if not check_integrity(path, self.meta['md5']): 95 | raise RuntimeError('Dataset metadata file not found or corrupted.'
+ 96 | ' You can use download=True to download it') 97 | with open(path, 'rb') as infile: 98 | if sys.version_info[0] == 2: 99 | data = pickle.load(infile) 100 | else: 101 | data = pickle.load(infile, encoding='latin1') 102 | self.classes = data[self.meta['key']] 103 | self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)} 104 | 105 | def __getitem__(self, index): 106 | """ 107 | Args: 108 | index (int): Index 109 | Returns: 110 | dict: {'image': image, 'target': index of target class, 'meta': dict} 111 | """ 112 | img, target = self.data[index], self.targets[index] 113 | img_size = (img.shape[0], img.shape[1]) 114 | img = Image.fromarray(img) 115 | class_name = self.classes[target] 116 | 117 | if self.transform is not None: 118 | img = self.transform(img) 119 | 120 | out = {'image': img, 'target': target, 'meta': {'im_size': img_size, 'index': index, 'class_name': class_name}} 121 | 122 | return out 123 | 124 | def get_image(self, index): 125 | img = self.data[index] 126 | return img 127 | 128 | def __len__(self): 129 | return len(self.data) 130 | 131 | def _check_integrity(self): 132 | root = self.root 133 | for fentry in (self.train_list + self.test_list): 134 | filename, md5 = fentry[0], fentry[1] 135 | fpath = os.path.join(root, self.base_folder, filename) 136 | if not check_integrity(fpath, md5): 137 | return False 138 | return True 139 | 140 | def download(self): 141 | if self._check_integrity(): 142 | print('Files already downloaded and verified') 143 | return 144 | download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5) 145 | 146 | def extra_repr(self): 147 | return "Split: {}".format("Train" if self.train is True else "Test") 148 | 149 | 150 | class CIFAR20(CIFAR10): 151 | """CIFAR20 Dataset. 152 | 153 | This is a subclass of the `CIFAR10` Dataset. 
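    The 100 fine-grained CIFAR-100 labels are remapped onto the 20 coarse
    superclasses through ``_cifar100_to_cifar20`` defined below.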
154 | """ 155 | base_folder = 'cifar-100-python' 156 | url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" 157 | filename = "cifar-100-python.tar.gz" 158 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' 159 | train_list = [ 160 | ['train', '16019d7e3df5f24257cddd939b257f8d'], 161 | ] 162 | 163 | test_list = [ 164 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], 165 | ] 166 | meta = { 167 | 'filename': 'meta', 168 | 'key': 'fine_label_names', 169 | 'md5': '7973b15100ade9c7d40fb424638fde48', 170 | } 171 | def __init__(self, root=MyPath.db_root_dir('cifar-20'), train=True, transform=None, 172 | download=False): 173 | super(CIFAR20, self).__init__(root, train=train,transform=transform, 174 | download=download) 175 | # Remap classes from cifar-100 to cifar-20 176 | new_ = self.targets 177 | for idx, target in enumerate(self.targets): 178 | new_[idx] = _cifar100_to_cifar20(target) 179 | self.targets = new_ 180 | self.classes = ['aquatic mammals', 'fish', 'flowers', 'food containers', 'fruit and vegetables', 'household electrical devices', 'househould furniture', 'insects', 'large carnivores', 'large man-made outdoor things', 'large natural outdoor scenes', 'large omnivores and herbivores', 'medium-sized mammals', 'non-insect invertebrates', 'people', 'reptiles', 'small mammals', 'trees', 'vehicles 1', 'vehicles 2'] 181 | 182 | 183 | def _cifar100_to_cifar20(target): 184 | _dict = \ 185 | {0: 4, 186 | 1: 1, 187 | 2: 14, 188 | 3: 8, 189 | 4: 0, 190 | 5: 6, 191 | 6: 7, 192 | 7: 7, 193 | 8: 18, 194 | 9: 3, 195 | 10: 3, 196 | 11: 14, 197 | 12: 9, 198 | 13: 18, 199 | 14: 7, 200 | 15: 11, 201 | 16: 3, 202 | 17: 9, 203 | 18: 7, 204 | 19: 11, 205 | 20: 6, 206 | 21: 11, 207 | 22: 5, 208 | 23: 10, 209 | 24: 7, 210 | 25: 6, 211 | 26: 13, 212 | 27: 15, 213 | 28: 3, 214 | 29: 15, 215 | 30: 0, 216 | 31: 11, 217 | 32: 1, 218 | 33: 10, 219 | 34: 12, 220 | 35: 14, 221 | 36: 16, 222 | 37: 9, 223 | 38: 11, 224 | 39: 5, 225 | 40: 5, 226 | 41: 19, 227 | 42: 8, 228 | 43: 8, 229 | 44: 15, 230 | 45: 13, 231 | 46: 14, 232 | 47: 17, 233 | 48: 18, 234 | 49: 10, 235 | 50: 16, 236 | 51: 4, 237 | 52: 17, 238 | 53: 4, 239 | 54: 2, 240 | 55: 0, 241 | 56: 17, 242 | 57: 4, 243 | 58: 18, 244 | 59: 17, 245 | 60: 10, 246 | 61: 3, 247 | 62: 2, 248 | 63: 12, 249 | 64: 12, 250 | 65: 16, 251 | 66: 12, 252 | 67: 1, 253 | 68: 9, 254 | 69: 19, 255 | 70: 2, 256 | 71: 10, 257 | 72: 0, 258 | 73: 1, 259 | 74: 16, 260 | 75: 12, 261 | 76: 9, 262 | 77: 13, 263 | 78: 15, 264 | 79: 13, 265 | 80: 16, 266 | 81: 19, 267 | 82: 2, 268 | 83: 4, 269 | 84: 6, 270 | 85: 19, 271 | 86: 5, 272 | 87: 5, 273 | 88: 8, 274 | 89: 19, 275 | 90: 18, 276 | 91: 1, 277 | 92: 2, 278 | 93: 15, 279 | 94: 6, 280 | 95: 0, 281 | 96: 17, 282 | 97: 8, 283 | 98: 14, 284 | 99: 13} 285 | 286 | return _dict[target] 287 | -------------------------------------------------------------------------------- /utils/common_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import os 6 | import math 7 | import numpy as np 8 | import torch 9 | import torchvision.transforms as transforms 10 | from data.augment import Augment, Cutout 11 | from utils.collate import collate_custom 12 | 13 | 14 | def get_criterion(p): 15 | if p['criterion'] == 'simclr': 16 | from losses.losses import SimCLRLoss 17 | criterion = SimCLRLoss(**p['criterion_kwargs']) 18 | 19 | elif p['criterion'] == 
'scan': 20 | from losses.losses import SCANLoss 21 | criterion = SCANLoss(**p['criterion_kwargs']) 22 | 23 | elif p['criterion'] == 'confidence-cross-entropy': 24 | from losses.losses import ConfidenceBasedCE 25 | criterion = ConfidenceBasedCE(p['confidence_threshold'], p['criterion_kwargs']['apply_class_balancing']) 26 | 27 | else: 28 | raise ValueError('Invalid criterion {}'.format(p['criterion'])) 29 | 30 | return criterion 31 | 32 | 33 | def get_feature_dimensions_backbone(p): 34 | if p['backbone'] == 'resnet18': 35 | return 512 36 | 37 | elif p['backbone'] == 'resnet50': 38 | return 2048 39 | 40 | else: 41 | raise NotImplementedError 42 | 43 | 44 | def get_model(p, pretrain_path=None): 45 | # Get backbone 46 | if p['backbone'] == 'resnet18': 47 | if p['train_db_name'] in ['cifar-10', 'cifar-20']: 48 | from models.resnet_cifar import resnet18 49 | backbone = resnet18() 50 | 51 | elif p['train_db_name'] == 'stl-10': 52 | from models.resnet_stl import resnet18 53 | backbone = resnet18() 54 | 55 | else: 56 | raise NotImplementedError 57 | 58 | elif p['backbone'] == 'resnet50': 59 | if 'imagenet' in p['train_db_name']: 60 | from models.resnet import resnet50 61 | backbone = resnet50() 62 | 63 | else: 64 | raise NotImplementedError 65 | 66 | else: 67 | raise ValueError('Invalid backbone {}'.format(p['backbone'])) 68 | 69 | # Setup 70 | if p['setup'] in ['simclr', 'moco']: 71 | from models.models import ContrastiveModel 72 | model = ContrastiveModel(backbone, **p['model_kwargs']) 73 | 74 | elif p['setup'] in ['scan', 'selflabel']: 75 | from models.models import ClusteringModel 76 | if p['setup'] == 'selflabel': 77 | assert(p['num_heads'] == 1) 78 | model = ClusteringModel(backbone, p['num_classes'], p['num_heads']) 79 | 80 | else: 81 | raise ValueError('Invalid setup {}'.format(p['setup'])) 82 | 83 | # Load pretrained weights 84 | if pretrain_path is not None and os.path.exists(pretrain_path): 85 | state = torch.load(pretrain_path, map_location='cpu') 86 | 87 | if p['setup'] == 'scan': # Weights are supposed to be transferred from contrastive training 88 | missing = model.load_state_dict(state, strict=False) 89 | assert(set(missing[1]) == { 90 | 'contrastive_head.0.weight', 'contrastive_head.0.bias', 91 | 'contrastive_head.2.weight', 'contrastive_head.2.bias'} 92 | or set(missing[1]) == { 93 | 'contrastive_head.weight', 'contrastive_head.bias'}) 94 | 95 | elif p['setup'] == 'selflabel': # Weights are supposed to be transferred from scan 96 | # We only continue with the best head (pop all heads first, then copy back the best head) 97 | model_state = state['model'] 98 | all_heads = [k for k in model_state.keys() if 'cluster_head' in k] 99 | best_head_weight = model_state['cluster_head.%d.weight' %(state['head'])] 100 | best_head_bias = model_state['cluster_head.%d.bias' %(state['head'])] 101 | for k in all_heads: 102 | model_state.pop(k) 103 | 104 | model_state['cluster_head.0.weight'] = best_head_weight 105 | model_state['cluster_head.0.bias'] = best_head_bias 106 | missing = model.load_state_dict(model_state, strict=True) 107 | 108 | else: 109 | raise NotImplementedError 110 | 111 | elif pretrain_path is not None and not os.path.exists(pretrain_path): 112 | raise ValueError('Path with pre-trained weights does not exist {}'.format(pretrain_path)) 113 | 114 | else: 115 | pass 116 | 117 | return model 118 | 119 | 120 | def get_train_dataset(p, transform, to_augmented_dataset=False, 121 | to_neighbors_dataset=False, split=None): 122 | # Base dataset 123 | if p['train_db_name'] == 'cifar-10': 124 | 
from data.cifar import CIFAR10 125 | dataset = CIFAR10(train=True, transform=transform, download=True) 126 | 127 | elif p['train_db_name'] == 'cifar-20': 128 | from data.cifar import CIFAR20 129 | dataset = CIFAR20(train=True, transform=transform, download=True) 130 | 131 | elif p['train_db_name'] == 'stl-10': 132 | from data.stl import STL10 133 | dataset = STL10(split=split, transform=transform, download=True) 134 | 135 | elif p['train_db_name'] == 'imagenet': 136 | from data.imagenet import ImageNet 137 | dataset = ImageNet(split='train', transform=transform) 138 | 139 | elif p['train_db_name'] in ['imagenet_50', 'imagenet_100', 'imagenet_200']: 140 | from data.imagenet import ImageNetSubset 141 | subset_file = './data/imagenet_subsets/%s.txt' %(p['train_db_name']) 142 | dataset = ImageNetSubset(subset_file=subset_file, split='train', transform=transform) 143 | 144 | else: 145 | raise ValueError('Invalid train dataset {}'.format(p['train_db_name'])) 146 | 147 | # Wrap into other dataset (__getitem__ changes) 148 | if to_augmented_dataset: # Dataset returns an image and an augmentation of that image. 149 | from data.custom_dataset import AugmentedDataset 150 | dataset = AugmentedDataset(dataset) 151 | 152 | if to_neighbors_dataset: # Dataset returns an image and one of its nearest neighbors. 153 | from data.custom_dataset import NeighborsDataset 154 | indices = np.load(p['topk_neighbors_train_path']) 155 | dataset = NeighborsDataset(dataset, indices, p['num_neighbors']) 156 | 157 | return dataset 158 | 159 | 160 | def get_val_dataset(p, transform=None, to_neighbors_dataset=False): 161 | # Base dataset 162 | if p['val_db_name'] == 'cifar-10': 163 | from data.cifar import CIFAR10 164 | dataset = CIFAR10(train=False, transform=transform, download=True) 165 | 166 | elif p['val_db_name'] == 'cifar-20': 167 | from data.cifar import CIFAR20 168 | dataset = CIFAR20(train=False, transform=transform, download=True) 169 | 170 | elif p['val_db_name'] == 'stl-10': 171 | from data.stl import STL10 172 | dataset = STL10(split='test', transform=transform, download=True) 173 | 174 | elif p['val_db_name'] == 'imagenet': 175 | from data.imagenet import ImageNet 176 | dataset = ImageNet(split='val', transform=transform) 177 | 178 | elif p['val_db_name'] in ['imagenet_50', 'imagenet_100', 'imagenet_200']: 179 | from data.imagenet import ImageNetSubset 180 | subset_file = './data/imagenet_subsets/%s.txt' %(p['val_db_name']) 181 | dataset = ImageNetSubset(subset_file=subset_file, split='val', transform=transform) 182 | 183 | else: 184 | raise ValueError('Invalid validation dataset {}'.format(p['val_db_name'])) 185 | 186 | # Wrap into other dataset (__getitem__ changes) 187 | if to_neighbors_dataset: # Dataset returns an image and one of its nearest neighbors. 
188 | from data.custom_dataset import NeighborsDataset 189 | indices = np.load(p['topk_neighbors_val_path']) 190 | dataset = NeighborsDataset(dataset, indices, 5) # Only use 5 191 | 192 | return dataset 193 | 194 | 195 | def get_train_dataloader(p, dataset): 196 | return torch.utils.data.DataLoader(dataset, num_workers=p['num_workers'], 197 | batch_size=p['batch_size'], pin_memory=True, collate_fn=collate_custom, 198 | drop_last=True, shuffle=True) 199 | 200 | 201 | def get_val_dataloader(p, dataset): 202 | return torch.utils.data.DataLoader(dataset, num_workers=p['num_workers'], 203 | batch_size=p['batch_size'], pin_memory=True, collate_fn=collate_custom, 204 | drop_last=False, shuffle=False) 205 | 206 | 207 | def get_train_transformations(p): 208 | if p['augmentation_strategy'] == 'standard': 209 | # Standard augmentation strategy 210 | return transforms.Compose([ 211 | transforms.RandomResizedCrop(**p['augmentation_kwargs']['random_resized_crop']), 212 | transforms.RandomHorizontalFlip(), 213 | transforms.ToTensor(), 214 | transforms.Normalize(**p['augmentation_kwargs']['normalize']) 215 | ]) 216 | 217 | elif p['augmentation_strategy'] == 'simclr': 218 | # Augmentation strategy from the SimCLR paper 219 | return transforms.Compose([ 220 | transforms.RandomResizedCrop(**p['augmentation_kwargs']['random_resized_crop']), 221 | transforms.RandomHorizontalFlip(), 222 | transforms.RandomApply([ 223 | transforms.ColorJitter(**p['augmentation_kwargs']['color_jitter']) 224 | ], p=p['augmentation_kwargs']['color_jitter_random_apply']['p']), 225 | transforms.RandomGrayscale(**p['augmentation_kwargs']['random_grayscale']), 226 | transforms.ToTensor(), 227 | transforms.Normalize(**p['augmentation_kwargs']['normalize']) 228 | ]) 229 | 230 | elif p['augmentation_strategy'] == 'ours': 231 | # Augmentation strategy from our paper 232 | return transforms.Compose([ 233 | transforms.RandomHorizontalFlip(), 234 | transforms.RandomCrop(p['augmentation_kwargs']['crop_size']), 235 | Augment(p['augmentation_kwargs']['num_strong_augs']), 236 | transforms.ToTensor(), 237 | transforms.Normalize(**p['augmentation_kwargs']['normalize']), 238 | Cutout( 239 | n_holes = p['augmentation_kwargs']['cutout_kwargs']['n_holes'], 240 | length = p['augmentation_kwargs']['cutout_kwargs']['length'], 241 | random = p['augmentation_kwargs']['cutout_kwargs']['random'])]) 242 | 243 | else: 244 | raise ValueError('Invalid augmentation strategy {}'.format(p['augmentation_strategy'])) 245 | 246 | 247 | def get_val_transformations(p): 248 | return transforms.Compose([ 249 | transforms.CenterCrop(p['transformation_kwargs']['crop_size']), 250 | transforms.ToTensor(), 251 | transforms.Normalize(**p['transformation_kwargs']['normalize'])]) 252 | 253 | 254 | def get_optimizer(p, model, cluster_head_only=False): 255 | if cluster_head_only: # Only weights in the cluster head will be updated 256 | for name, param in model.named_parameters(): 257 | if 'cluster_head' in name: 258 | param.requires_grad = True 259 | else: 260 | param.requires_grad = False 261 | params = list(filter(lambda p: p.requires_grad, model.parameters())) 262 | assert(len(params) == 2 * p['num_heads']) 263 | 264 | else: 265 | params = model.parameters() 266 | 267 | 268 | if p['optimizer'] == 'sgd': 269 | optimizer = torch.optim.SGD(params, **p['optimizer_kwargs']) 270 | 271 | elif p['optimizer'] == 'adam': 272 | optimizer = torch.optim.Adam(params, **p['optimizer_kwargs']) 273 | 274 | else: 275 | raise ValueError('Invalid optimizer {}'.format(p['optimizer'])) 276 | 277 | 
return optimizer 278 | 279 | 280 | def adjust_learning_rate(p, optimizer, epoch): 281 | lr = p['optimizer_kwargs']['lr'] 282 | 283 | if p['scheduler'] == 'cosine': 284 | eta_min = lr * (p['scheduler_kwargs']['lr_decay_rate'] ** 3) 285 | lr = eta_min + (lr - eta_min) * (1 + math.cos(math.pi * epoch / p['epochs'])) / 2 286 | 287 | elif p['scheduler'] == 'step': 288 | steps = np.sum(epoch > np.array(p['scheduler_kwargs']['lr_decay_epochs'])) 289 | if steps > 0: 290 | lr = lr * (p['scheduler_kwargs']['lr_decay_rate'] ** steps) 291 | 292 | elif p['scheduler'] == 'constant': 293 | lr = lr 294 | 295 | else: 296 | raise ValueError('Invalid learning rate schedule {}'.format(p['scheduler'])) 297 | 298 | for param_group in optimizer.param_groups: 299 | param_group['lr'] = lr 300 | 301 | return lr 302 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning to Classify Images without Labels 2 | 3 | 4 | This repo contains the Pytorch implementation of our paper: 5 | > [**SCAN: Learning to Classify Images without Labels**](https://arxiv.org/pdf/2005.12320.pdf) 6 | > 7 | > [Wouter Van Gansbeke](https://twitter.com/WGansbeke), [Simon Vandenhende](https://twitter.com/svandenh1), [Stamatios Georgoulis](https://twitter.com/stam_g), Marc Proesmans and Luc Van Gool. 8 | 9 | - __Accepted at ECCV 2020 ([Slides](https://wvangansbeke.github.io/pdfs/unsupervised_classification.pdf)). Watch the explanation of our paper by Yannic Kilcher on [YouTube](https://www.youtube.com/watch?v=hQEnzdLkPj4).__ 10 | 11 | - 🏆 __SOTA on 4 benchmarks. Check out [Papers With Code](https://paperswithcode.com/paper/learning-to-classify-images-without-labels) for [Image Clustering](https://paperswithcode.com/task/image-clustering) or [Unsup. Classification](https://paperswithcode.com/task/unsupervised-image-classification).__ 12 | - Related works: 13 | - 🆕 __Interested in unsupervised semantic segmentation? Check out our new preprint: [MaskDistill](https://github.com/wvangansbeke/MaskDistill).__ 14 | - 🆕 __Interested in representation learning? Check out our NeurIPS'21 [paper](https://arxiv.org/abs/2106.05967) and [code](https://github.com/wvangansbeke/Revisiting-Contrastive-SSL).__ 15 | - 🆕 __More on unsupervised semantic segmentation? Check out our ICCV'21 paper: [MaskContrast](https://github.com/wvangansbeke/Unsupervised-Semantic-Segmentation).__ 16 | - 📜 __Looking for influential papers in self-supervised learning? Check out this [reading list](https://github.com/wvangansbeke/Self-Supervised-Learning-Overview).__ 17 |

18 | 19 |

20 | 21 | 22 | 23 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/learning-to-classify-images-without-labels/unsupervised-image-classification-on-imagenet)](https://paperswithcode.com/sota/unsupervised-image-classification-on-imagenet?p=learning-to-classify-images-without-labels) 24 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/learning-to-classify-images-without-labels/unsupervised-image-classification-on-cifar-10)](https://paperswithcode.com/sota/unsupervised-image-classification-on-cifar-10?p=learning-to-classify-images-without-labels) 25 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/learning-to-classify-images-without-labels/unsupervised-image-classification-on-stl-10)](https://paperswithcode.com/sota/unsupervised-image-classification-on-stl-10?p=learning-to-classify-images-without-labels) 26 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/learning-to-classify-images-without-labels/unsupervised-image-classification-on-cifar-20)](https://paperswithcode.com/sota/unsupervised-image-classification-on-cifar-20?p=learning-to-classify-images-without-labels) 27 |

28 | 29 | 30 | 31 | ## Contents 32 | 1. [Introduction](#introduction) 33 | 0. [Prior Work](#prior-work) 34 | 0. [Installation](#installation) 35 | 0. [Training](#training) 36 | 0. [Model Zoo](#model-zoo) 37 | 0. [Tutorial](#tutorial) 38 | 0. [Citation](#citation) 39 | 40 | 🆕 Tutorial section has been added, check out [TUTORIAL.md](https://github.com/wvangansbeke/Unsupervised-Classification/blob/master/TUTORIAL.md). 41 | 42 | 🆕 Prior work section has been added, check out [Prior Work](#prior-work). 43 | 44 | ## Introduction 45 | Can we automatically group images into semantically meaningful clusters when ground-truth annotations are absent? The task of unsupervised image classification remains an important and open challenge in computer vision. Several recent approaches have tried to tackle this problem in an end-to-end fashion. In this paper, we deviate from recent works and advocate a two-step approach where feature learning and clustering are decoupled. 46 | 47 | We outperform state-of-the-art methods by large margins, in particular +26.6% on CIFAR10, +25.0% on CIFAR100-20 and +21.3% on STL10 in terms of classification accuracy. Our method is the first to perform well on ImageNet (1000 classes). 48 | __Check out the benchmarks on the [Papers-with-code](https://paperswithcode.com/paper/learning-to-classify-images-without-labels) website for [Image Clustering](https://paperswithcode.com/task/image-clustering) and [Unsupervised Image Classification](https://paperswithcode.com/task/unsupervised-image-classification).__ 49 | 50 | ## Prior Work 51 | - Train set/test set: 52 | We would like to point out that most prior work in unsupervised classification uses both the train and test set during training. We believe this is bad practice and therefore propose to only train on the train set. The final numbers should be reported on the test set (see table 3 of our paper). This also allows us to directly compare with supervised and semi-supervised methods in the literature. We encourage future work to do the same. We observe around 2% improvement over the reported numbers when including the test set. 53 | 54 | - Reproducibility: 55 | We noticed that prior work is very sensitive to initialization, so we don't think reporting a single number is fair. We report our results as the mean and standard deviation over 10 runs. 56 | 57 | Please follow the instructions below to perform semantic clustering with SCAN. 58 | 59 | ## Installation 60 | The code runs with recent Pytorch versions, e.g. 1.4. 61 | Assuming [Anaconda](https://docs.anaconda.com/anaconda/install/), the most important packages can be installed as: 62 | ```shell 63 | conda install pytorch=1.4.0 torchvision=0.5.0 cudatoolkit=10.0 -c pytorch 64 | conda install matplotlib scipy scikit-learn # For evaluation and confusion matrix visualization 65 | conda install faiss-gpu # For efficient nearest neighbors search 66 | conda install pyyaml easydict # For using config files 67 | conda install termcolor # For colored print statements 68 | ``` 69 | We refer to the `requirements.txt` file for an overview of the packages in the environment we used to produce our results. 70 | 71 | ## Training 72 | 73 | ### Setup 74 | The following files need to be adapted in order to run the code on your own machine: 75 | - Change the file paths to the datasets in `utils/mypath.py`, e.g. `/path/to/cifar10` (see the sketch below). 76 | - Specify the output directory in `configs/env.yml`. All results will be stored under this directory.
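For instance, the dataset lookup can be pointed at your local copies along the following lines (a minimal sketch with placeholder paths; `utils/mypath.py` already defines `MyPath.db_root_dir`, so only the returned paths need editing rather than adding new code):

```python
# Illustrative sketch only: adapt the paths inside utils/mypath.py accordingly.
DB_ROOTS = {
    'cifar-10': '/path/to/cifar10/',
    'cifar-20': '/path/to/cifar20/',
    'stl-10': '/path/to/stl10/',
    'imagenet': '/path/to/imagenet/',  # ImageNet must be downloaded separately
}

def db_root_dir(database=''):
    # Mirrors MyPath.db_root_dir: map a dataset name to its storage location.
    return DB_ROOTS[database]
```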
77 | 78 | Our experimental evaluation includes the following datasets: CIFAR10, CIFAR100-20, STL10 and ImageNet. The ImageNet dataset should be downloaded separately and saved to the path described in `utils/mypath.py`. Other datasets will be downloaded automatically and saved to the correct path when missing. 79 | 80 | ### Train model 81 | The configuration files can be found in the `configs/` directory. The training procedure consists of the following steps: 82 | - __STEP 1__: Solve the pretext task i.e. `simclr.py` 83 | - __STEP 2__: Perform the clustering step i.e. `scan.py` 84 | - __STEP 3__: Perform the self-labeling step i.e. `selflabel.py` 85 | 86 | For example, run the following commands sequentially to perform our method on CIFAR10: 87 | ```shell 88 | python simclr.py --config_env configs/your_env.yml --config_exp configs/pretext/simclr_cifar10.yml 89 | python scan.py --config_env configs/your_env.yml --config_exp configs/scan/scan_cifar10.yml 90 | python selflabel.py --config_env configs/your_env.yml --config_exp configs/selflabel/selflabel_cifar10.yml 91 | ``` 92 | ### Remarks 93 | The provided hyperparameters are identical for CIFAR10, CIFAR100-20 and STL10. However, fine-tuning the hyperparameters can further improve the results. We list the most important hyperparameters of our method below: 94 | - Entropy weight: Can be adapted when the number of clusters changes. In general, try to avoid imbalanced clusters during training. 95 | - Confidence threshold: When every cluster contains a sufficiently large number of confident samples, it can be beneficial to increase the threshold. This generally helps to decrease the noise. The ablation can be found in the paper. 96 | - Number of neighbors in SCAN: The dependency on this hyperparameter is rather small as shown in the paper. 97 | 98 | ## Model Zoo 99 | ### Pretext tasks 100 | We perform the instance discrimination task in accordance with the scheme from [SimCLR](https://arxiv.org/abs/2002.05709) on CIFAR10, CIFAR100 and STL10. Pretrained models can be downloaded from the links listed below. On ImageNet, we use the pretrained weights provided by [MoCo](https://github.com/facebookresearch/moco) and transfer them to be compatible with our code repository. 101 | 102 | | Dataset | Download link | 103 | |------------------|---------------| 104 | |CIFAR10 | [Download](https://drive.google.com/file/d/1Cl5oAcJKoNE5FSTZsBSAKLcyA5jXGgTT/view?usp=sharing) | 105 | |CIFAR100 | [Download](https://drive.google.com/file/d/1huW-ChBVvKcx7t8HyDaWTQB5Li1Fht9x/view?usp=sharing) | 106 | |STL10 | [Download](https://drive.google.com/file/d/1261NDFfXuKR2Dh4RWHYYhcicdcPag9NZ/view?usp=sharing) | 107 | 108 | ### Clustering 109 | We provide the following pretrained models after training with the __SCAN-loss__, and after the __self-labeling__ step. The best models can be found here and we further refer to the paper for the averages and standard deviations.
110 | 111 | | Dataset | Step | ACC | NMI | ARI |Download link | 112 | |------------------|-------------------|---------------------- |-----------------|-----------|--------------| 113 | | CIFAR10 | SCAN-loss | 81.6 | 71.5 | 66.5 |[Download](https://drive.google.com/file/d/1v6b6jJY5M4-duSqWpGFmdf9e9T3dPrx0/view?usp=sharing) | 114 | | | Self-labeling | 88.3 | 79.7 | 77.2 |[Download](https://drive.google.com/file/d/18gITFzAbQsGS5vt8hyi5HjbeRDsVLihw/view?usp=sharing) | 115 | | CIFAR100 | SCAN-loss | 44.0 | 44.9 | 28.3 |[Download](https://drive.google.com/file/d/1pPCi1QG05kP_JdoX29dxEhVddIRk68Sd/view?usp=sharing) | 116 | | | Self-labeling | 50.7 | 48.6 | 33.3 |[Download](https://drive.google.com/file/d/11mEmpDMyq63pM4kmDy6ItHouI6Q__uB7/view?usp=sharing) | 117 | | STL10 | SCAN-loss | 79.2 | 67.3 | 61.8 |[Download](https://drive.google.com/file/d/1y1cnGLpeTVo80cnWhAJy-B72FYs2AjZ_/view?usp=sharing) | 118 | | | Self-labeling | 80.9 | 69.8 | 64.6 |[Download](https://drive.google.com/file/d/1uNYN9XOMIPb40hmxOzALg4PWhU_xwkEF/view?usp=sharing) | 119 | | ImageNet-50 | SCAN-loss | 75.1 | 80.5 | 63.5 |[Download](https://drive.google.com/file/d/1UdBtvCHVGd08x8SiH6Cuh6mQmqsADg0t/view?usp=sharing) | 120 | | | Self-labeling | 76.8 | 82.2 | 66.1 |[Download](https://drive.google.com/file/d/1iOE4_lQ4w7CGPLU4algBDG34nz68eN8o/view?usp=sharing) | 121 | | ImageNet-100 | SCAN-loss | 66.2 | 78.7 | 54.4 |[Download](https://drive.google.com/file/d/1tcROQ3wc_MbxmLr05qt-UvF9yVrBwBq9/view?usp=sharing) | 122 | | | Self-labeling | 68.9 | 80.8 | 57.6 |[Download](https://drive.google.com/file/d/1VVgRpJ9DJn9dNrbAKbfPer2FllTvP6Cs/view?usp=sharing) | 123 | | ImageNet-200 | SCAN-loss | 56.3 | 75.7 | 44.1 |[Download](https://drive.google.com/file/d/1oO-OCW2MiXmNC4sD6pkw8PurYScX7oVW/view?usp=sharing) | 124 | | | Self-labeling | 58.1 | 77.2 | 47.0 |[Download](https://drive.google.com/file/d/11dfobUwy6ragh7PoqFagoEns5-teWalm/view?usp=sharing) | 125 | 126 | ### Result ImageNet 127 | We also train SCAN on ImageNet for 1000 clusters. We use 10 clusterheads and finally take the head with the lowest loss. The accuracy (ACC), normalized mutual information (NMI), adjusted mutual information (AMI) and adjusted rand index (ARI) are computed: 128 | 129 | Method | ACC | NMI | AMI | ARI | Download link | 130 | |----------------|---------------------- |-----------------|------|-------|----------------| 131 | | SCAN (ResNet50) | 39.9 | 72.0 | 51.2 | 27.5 |[Download](https://drive.google.com/file/d/1PcF8ydoWoqhxARGuW55KcNarfZyJwcza/view?usp=sharing) | 132 | 133 | 134 | 135 | ### Evaluation 136 | Pretrained models from the model zoo can be evaluated using the `eval.py` script. For example, the model on cifar-10 can be evaluated as follows: 137 | ```shell 138 | python eval.py --config_exp configs/scan/scan_cifar10.yml --model $MODEL_PATH 139 | ``` 140 | Visualizing the prototype images is easily done by setting the `--visualize_prototypes` flag. For example on cifar-10: 141 |
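The prototype of a cluster is simply the sample the cluster head is most confident about. A minimal sketch of that selection (illustrative only; `eval.py` contains the reference implementation, and `probs` is assumed to hold the softmax output of one head over the whole dataset):

```python
import torch

def prototype_indices(probs: torch.Tensor) -> torch.Tensor:
    """Return, for every cluster, the index of its most confident sample.

    probs: (num_samples, num_clusters) softmax probabilities of one head.
    """
    return probs.argmax(dim=0)  # shape: (num_clusters,)
```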

142 | 143 |

144 | 145 | Similarly, you might want to have a look at the clusters found on ImageNet (as shown at the top). First download the model (link in table above) and then execute the following command: 146 | ```shell 147 | python eval.py --config_exp configs/scan/imagenet_eval.yml --model $MODEL_PATH_IMAGENET 148 | ``` 149 | 150 | 151 | 152 | 153 | ## Tutorial 154 | 155 | If you want to see another (more detailed) example for STL-10, check out [TUTORIAL.md](https://github.com/wvangansbeke/Unsupervised-Classification/blob/master/TUTORIAL.md). It provides a detailed guide and includes visualizations and log files with the training progress. 156 | 157 | ## Citation 158 | 159 | If you find this repo useful for your research, please consider citing our paper: 160 | 161 | ```bibtex 162 | @inproceedings{vangansbeke2020scan, 163 | title={Scan: Learning to classify images without labels}, 164 | author={Van Gansbeke, Wouter and Vandenhende, Simon and Georgoulis, Stamatios and Proesmans, Marc and Van Gool, Luc}, 165 | booktitle={Proceedings of the European Conference on Computer Vision}, 166 | year={2020} 167 | } 168 | 169 | ``` 170 | For any enquiries, please contact the main authors. 171 | 172 | ## License 173 | 174 | This software is released under a Creative Commons license that allows for personal and research use only. For a commercial license please contact the authors. You can view a license summary [here](http://creativecommons.org/licenses/by-nc/4.0/). 175 | 176 | ## Acknowledgements 177 | This work was supported by Toyota, and was carried out at the TRACE Lab at KU Leuven (Toyota Research on Automated Cars in Europe - Leuven). 178 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it.
30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. 
Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | 142 | Section 2 -- Scope. 143 | 144 | a. License grant. 145 | 146 | 1. Subject to the terms and conditions of this Public License, 147 | the Licensor hereby grants You a worldwide, royalty-free, 148 | non-sublicensable, non-exclusive, irrevocable license to 149 | exercise the Licensed Rights in the Licensed Material to: 150 | 151 | a. reproduce and Share the Licensed Material, in whole or 152 | in part, for NonCommercial purposes only; and 153 | 154 | b. produce, reproduce, and Share Adapted Material for 155 | NonCommercial purposes only. 156 | 157 | 2. Exceptions and Limitations. For the avoidance of doubt, where 158 | Exceptions and Limitations apply to Your use, this Public 159 | License does not apply, and You do not need to comply with 160 | its terms and conditions. 161 | 162 | 3. Term. The term of this Public License is specified in Section 163 | 6(a). 164 | 165 | 4. Media and formats; technical modifications allowed. 
The 166 | Licensor authorizes You to exercise the Licensed Rights in 167 | all media and formats whether now known or hereafter created, 168 | and to make technical modifications necessary to do so. The 169 | Licensor waives and/or agrees not to assert any right or 170 | authority to forbid You from making technical modifications 171 | necessary to exercise the Licensed Rights, including 172 | technical modifications necessary to circumvent Effective 173 | Technological Measures. For purposes of this Public License, 174 | simply making modifications authorized by this Section 2(a) 175 | (4) never produces Adapted Material. 176 | 177 | 5. Downstream recipients. 178 | 179 | a. Offer from the Licensor -- Licensed Material. Every 180 | recipient of the Licensed Material automatically 181 | receives an offer from the Licensor to exercise the 182 | Licensed Rights under the terms and conditions of this 183 | Public License. 184 | 185 | b. No downstream restrictions. You may not offer or impose 186 | any additional or different terms or conditions on, or 187 | apply any Effective Technological Measures to, the 188 | Licensed Material if doing so restricts exercise of the 189 | Licensed Rights by any recipient of the Licensed 190 | Material. 191 | 192 | 6. No endorsement. Nothing in this Public License constitutes or 193 | may be construed as permission to assert or imply that You 194 | are, or that Your use of the Licensed Material is, connected 195 | with, or sponsored, endorsed, or granted official status by, 196 | the Licensor or others designated to receive attribution as 197 | provided in Section 3(a)(1)(A)(i). 198 | 199 | b. Other rights. 200 | 201 | 1. Moral rights, such as the right of integrity, are not 202 | licensed under this Public License, nor are publicity, 203 | privacy, and/or other similar personality rights; however, to 204 | the extent possible, the Licensor waives and/or agrees not to 205 | assert any such rights held by the Licensor to the limited 206 | extent necessary to allow You to exercise the Licensed 207 | Rights, but not otherwise. 208 | 209 | 2. Patent and trademark rights are not licensed under this 210 | Public License. 211 | 212 | 3. To the extent possible, the Licensor waives any right to 213 | collect royalties from You for the exercise of the Licensed 214 | Rights, whether directly or through a collecting society 215 | under any voluntary or waivable statutory or compulsory 216 | licensing scheme. In all other cases the Licensor expressly 217 | reserves any right to collect such royalties, including when 218 | the Licensed Material is used other than for NonCommercial 219 | purposes. 220 | 221 | 222 | Section 3 -- License Conditions. 223 | 224 | Your exercise of the Licensed Rights is expressly made subject to the 225 | following conditions. 226 | 227 | a. Attribution. 228 | 229 | 1. If You Share the Licensed Material (including in modified 230 | form), You must: 231 | 232 | a. retain the following if it is supplied by the Licensor 233 | with the Licensed Material: 234 | 235 | i. identification of the creator(s) of the Licensed 236 | Material and any others designated to receive 237 | attribution, in any reasonable manner requested by 238 | the Licensor (including by pseudonym if 239 | designated); 240 | 241 | ii. a copyright notice; 242 | 243 | iii. a notice that refers to this Public License; 244 | 245 | iv. a notice that refers to the disclaimer of 246 | warranties; 247 | 248 | v. 
a URI or hyperlink to the Licensed Material to the 249 | extent reasonably practicable; 250 | 251 | b. indicate if You modified the Licensed Material and 252 | retain an indication of any previous modifications; and 253 | 254 | c. indicate the Licensed Material is licensed under this 255 | Public License, and include the text of, or the URI or 256 | hyperlink to, this Public License. 257 | 258 | 2. You may satisfy the conditions in Section 3(a)(1) in any 259 | reasonable manner based on the medium, means, and context in 260 | which You Share the Licensed Material. For example, it may be 261 | reasonable to satisfy the conditions by providing a URI or 262 | hyperlink to a resource that includes the required 263 | information. 264 | 265 | 3. If requested by the Licensor, You must remove any of the 266 | information required by Section 3(a)(1)(A) to the extent 267 | reasonably practicable. 268 | 269 | 4. If You Share Adapted Material You produce, the Adapter's 270 | License You apply must not prevent recipients of the Adapted 271 | Material from complying with this Public License. 272 | 273 | 274 | Section 4 -- Sui Generis Database Rights. 275 | 276 | Where the Licensed Rights include Sui Generis Database Rights that 277 | apply to Your use of the Licensed Material: 278 | 279 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 280 | to extract, reuse, reproduce, and Share all or a substantial 281 | portion of the contents of the database for NonCommercial purposes 282 | only; 283 | 284 | b. if You include all or a substantial portion of the database 285 | contents in a database in which You have Sui Generis Database 286 | Rights, then the database in which You have Sui Generis Database 287 | Rights (but not its individual contents) is Adapted Material; and 288 | 289 | c. You must comply with the conditions in Section 3(a) if You Share 290 | all or a substantial portion of the contents of the database. 291 | 292 | For the avoidance of doubt, this Section 4 supplements and does not 293 | replace Your obligations under this Public License where the Licensed 294 | Rights include other Copyright and Similar Rights. 295 | 296 | 297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 298 | 299 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 309 | 310 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 317 | DAMAGES. 
WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 319 | 320 | c. The disclaimer of warranties and limitation of liability provided 321 | above shall be interpreted in a manner that, to the extent 322 | possible, most closely approximates an absolute disclaimer and 323 | waiver of all liability. 324 | 325 | 326 | Section 6 -- Term and Termination. 327 | 328 | a. This Public License applies for the term of the Copyright and 329 | Similar Rights licensed here. However, if You fail to comply with 330 | this Public License, then Your rights under this Public License 331 | terminate automatically. 332 | 333 | b. Where Your right to use the Licensed Material has terminated under 334 | Section 6(a), it reinstates: 335 | 336 | 1. automatically as of the date the violation is cured, provided 337 | it is cured within 30 days of Your discovery of the 338 | violation; or 339 | 340 | 2. upon express reinstatement by the Licensor. 341 | 342 | For the avoidance of doubt, this Section 6(b) does not affect any 343 | right the Licensor may have to seek remedies for Your violations 344 | of this Public License. 345 | 346 | c. For the avoidance of doubt, the Licensor may also offer the 347 | Licensed Material under separate terms or conditions or stop 348 | distributing the Licensed Material at any time; however, doing so 349 | will not terminate this Public License. 350 | 351 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 352 | License. 353 | 354 | 355 | Section 7 -- Other Terms and Conditions. 356 | 357 | a. The Licensor shall not be bound by any additional or different 358 | terms or conditions communicated by You unless expressly agreed. 359 | 360 | b. Any arrangements, understandings, or agreements regarding the 361 | Licensed Material not stated herein are separate from and 362 | independent of the terms and conditions of this Public License. 363 | 364 | 365 | Section 8 -- Interpretation. 366 | 367 | a. For the avoidance of doubt, this Public License does not, and 368 | shall not be interpreted to, reduce, limit, restrict, or impose 369 | conditions on any use of the Licensed Material that could lawfully 370 | be made without permission under this Public License. 371 | 372 | b. To the extent possible, if any provision of this Public License is 373 | deemed unenforceable, it shall be automatically reformed to the 374 | minimum extent necessary to make it enforceable. If the provision 375 | cannot be reformed, it shall be severed from this Public License 376 | without affecting the enforceability of the remaining terms and 377 | conditions. 378 | 379 | c. No term or condition of this Public License will be waived and no 380 | failure to comply consented to unless expressly agreed to by the 381 | Licensor. 382 | 383 | d. Nothing in this Public License constitutes or may be interpreted 384 | as a limitation upon, or waiver of, any privileges and immunities 385 | that apply to the Licensor or You, including from the legal 386 | processes of any jurisdiction or authority. 387 | 388 | ======================================================================= 389 | 390 | Creative Commons is not a party to its public 391 | licenses. Notwithstanding, Creative Commons may elect to apply one of 392 | its public licenses to material it publishes and in those instances 393 | will be considered the “Licensor.” The text of the Creative Commons 394 | public licenses is dedicated to the public domain under the CC0 Public 395 | Domain Dedication. 
Except for the limited purpose of indicating that 396 | material is shared under a Creative Commons public license or as 397 | otherwise permitted by the Creative Commons policies published at 398 | creativecommons.org/policies, Creative Commons does not authorize the 399 | use of the trademark "Creative Commons" or any other trademark or logo 400 | of Creative Commons without its prior written consent including, 401 | without limitation, in connection with any unauthorized modifications 402 | to any of its public licenses or any other arrangements, 403 | understandings, or agreements concerning use of licensed material. For 404 | the avoidance of doubt, this paragraph does not form part of the 405 | public licenses. 406 | 407 | Creative Commons may be contacted at creativecommons.org. 408 | --------------------------------------------------------------------------------