├── configs
│   ├── env.yml
│   ├── pretext
│   │   ├── moco_imagenet50.yml
│   │   ├── moco_imagenet100.yml
│   │   ├── moco_imagenet200.yml
│   │   ├── simclr_stl10.yml
│   │   ├── simclr_cifar10.yml
│   │   └── simclr_cifar20.yml
│   ├── scan
│   │   ├── imagenet_eval.yml
│   │   ├── scan_stl10.yml
│   │   ├── scan_cifar10.yml
│   │   ├── scan_cifar20.yml
│   │   ├── scan_imagenet_50.yml
│   │   ├── scan_imagenet_100.yml
│   │   └── scan_imagenet_200.yml
│   └── selflabel
│       ├── selflabel_stl10.yml
│       ├── selflabel_cifar10.yml
│       ├── selflabel_cifar20.yml
│       ├── selflabel_imagenet_200.yml
│       ├── selflabel_imagenet_50.yml
│       └── selflabel_imagenet_100.yml
├── images
│   ├── teaser.jpg
│   ├── pipeline.png
│   ├── prototypes_cifar10.jpg
│   └── tutorial
│       ├── prototypes_stl10.jpg
│       └── confusion_matrix_stl10.png
├── models
│   ├── resnet.py
│   ├── models.py
│   ├── resnet_cifar.py
│   └── resnet_stl.py
├── utils
│   ├── ema.py
│   ├── mypath.py
│   ├── collate.py
│   ├── config.py
│   ├── utils.py
│   ├── memory.py
│   ├── train_utils.py
│   ├── evaluate_utils.py
│   └── common_config.py
├── data
│   ├── imagenet_subsets
│   │   ├── imagenet_50.txt
│   │   ├── imagenet_100.txt
│   │   └── imagenet_200.txt
│   ├── custom_dataset.py
│   ├── imagenet.py
│   ├── augment.py
│   ├── stl.py
│   └── cifar.py
├── requirements.txt
├── tutorial_nn.py
├── moco.py
├── selflabel.py
├── losses
│   └── losses.py
├── eval.py
├── scan.py
├── TUTORIAL.md
├── simclr.py
├── README.md
└── LICENSE
/configs/env.yml: -------------------------------------------------------------------------------- 1 | root_dir: /path/where/to/store/results/ 2 | -------------------------------------------------------------------------------- /images/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wvangansbeke/Unsupervised-Classification/HEAD/images/teaser.jpg -------------------------------------------------------------------------------- /images/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wvangansbeke/Unsupervised-Classification/HEAD/images/pipeline.png -------------------------------------------------------------------------------- /images/prototypes_cifar10.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wvangansbeke/Unsupervised-Classification/HEAD/images/prototypes_cifar10.jpg -------------------------------------------------------------------------------- /images/tutorial/prototypes_stl10.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wvangansbeke/Unsupervised-Classification/HEAD/images/tutorial/prototypes_stl10.jpg -------------------------------------------------------------------------------- /images/tutorial/confusion_matrix_stl10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wvangansbeke/Unsupervised-Classification/HEAD/images/tutorial/confusion_matrix_stl10.png -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import torch.nn as nn 6 | import torchvision.models as models 7 | 8 | 9 | def resnet50(): 10 | backbone = models.__dict__['resnet50']() 11 | backbone.fc = nn.Identity() 12 | return {'backbone': backbone, 'dim': 2048} 13 |
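For orientation, a minimal usage sketch (not part of the repository) of how the dictionary returned by resnet50() above plugs into ContrastiveModel from models/models.py, which is listed further below. It assumes the repository root is on PYTHONPATH; the random tensor merely stands in for a batch of images:

import torch
from models.resnet import resnet50
from models.models import ContrastiveModel

backbone = resnet50()                           # {'backbone': nn.Module, 'dim': 2048}
model = ContrastiveModel(backbone, head='mlp', features_dim=128)
features = model(torch.randn(2, 3, 224, 224))   # random stand-in batch
print(features.shape)                           # torch.Size([2, 128]); rows are L2-normalized

The same dictionary convention ({'backbone': ..., 'dim': ...}) is what ClusteringModel in models/models.py expects as well.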
-------------------------------------------------------------------------------- /configs/pretext/moco_imagenet50.yml: -------------------------------------------------------------------------------- 1 | # Setup 2 | setup: moco # MoCo is used here 3 | 4 | # Model 5 | backbone: resnet50 6 | model_kwargs: 7 | head: mlp 8 | features_dim: 128 9 | 10 | # Dataset 11 | train_db_name: imagenet_50 12 | val_db_name: imagenet_50 13 | num_classes: 50 14 | temperature: 0.07 15 | 16 | # Batch size and workers 17 | batch_size: 256 18 | num_workers: 8 19 | 20 | # Transformations 21 | transformation_kwargs: 22 | crop_size: 224 23 | normalize: 24 | mean: [0.485, 0.456, 0.406] 25 | std: [0.229, 0.224, 0.225] 26 | -------------------------------------------------------------------------------- /configs/pretext/moco_imagenet100.yml: -------------------------------------------------------------------------------- 1 | # Setup 2 | setup: moco # MoCo is used here 3 | 4 | # Model 5 | backbone: resnet50 6 | model_kwargs: 7 | head: mlp 8 | features_dim: 128 9 | 10 | # Dataset 11 | train_db_name: imagenet_100 12 | val_db_name: imagenet_100 13 | num_classes: 100 14 | temperature: 0.07 15 | 16 | # Batch size and workers 17 | batch_size: 256 18 | num_workers: 8 19 | 20 | # Transformations 21 | transformation_kwargs: 22 | crop_size: 224 23 | normalize: 24 | mean: [0.485, 0.456, 0.406] 25 | std: [0.229, 0.224, 0.225] 26 | -------------------------------------------------------------------------------- /configs/pretext/moco_imagenet200.yml: -------------------------------------------------------------------------------- 1 | # Setup 2 | setup: moco # MoCo is used here 3 | 4 | # Model 5 | backbone: resnet50 6 | model_kwargs: 7 | head: mlp 8 | features_dim: 128 9 | 10 | # Dataset 11 | train_db_name: imagenet_200 12 | val_db_name: imagenet_200 13 | num_classes: 200 14 | temperature: 0.07 15 | 16 | # Batch size and workers 17 | batch_size: 256 18 | num_workers: 8 19 | 20 | # Transformations 21 | transformation_kwargs: 22 | crop_size: 224 23 | normalize: 24 | mean: [0.485, 0.456, 0.406] 25 | std: [0.229, 0.224, 0.225] 26 | -------------------------------------------------------------------------------- /utils/ema.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | 6 | class EMA(object): 7 | def __init__(self, model, alpha=0.999): 8 | self.shadow = {k: v.clone().detach() for k, v in model.state_dict().items()} 9 | self.param_keys = [k for k, _ in model.named_parameters()] 10 | self.alpha = alpha 11 | 12 | def update_params(self, model): 13 | state = model.state_dict() 14 | for name in self.param_keys: 15 | self.shadow[name].copy_(self.alpha * self.shadow[name] + (1 - self.alpha) * state[name]) 16 | 17 | def apply_shadow(self, model): 18 | model.load_state_dict(self.shadow, strict=True) 19 | -------------------------------------------------------------------------------- /utils/mypath.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import os 6 | 7 | 8 | class MyPath(object): 9 | @staticmethod 10 | def db_root_dir(database=''): 11 | db_names = {'cifar-10', 'stl-10', 'cifar-20', 'imagenet', 'imagenet_50', 'imagenet_100', 'imagenet_200'} 12 
| assert(database in db_names) 13 | 14 | if database == 'cifar-10': 15 | return '/path/to/cifar-10/' 16 | 17 | elif database == 'cifar-20': 18 | return '/path/to/cifar-20/' 19 | 20 | elif database == 'stl-10': 21 | return '/path/to/stl-10/' 22 | 23 | elif database in ['imagenet', 'imagenet_50', 'imagenet_100', 'imagenet_200']: 24 | return '/path/to/imagenet/' 25 | 26 | else: 27 | raise NotImplementedError 28 | -------------------------------------------------------------------------------- /configs/scan/imagenet_eval.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: scan 3 | 4 | # Loss 5 | criterion: scan 6 | criterion_kwargs: 7 | entropy_weight: 5.0 8 | 9 | # Model 10 | backbone: resnet50 11 | num_heads: 10 # Use multiple heads 12 | 13 | # Dataset 14 | train_db_name: imagenet 15 | val_db_name: imagenet 16 | num_classes: 1000 17 | num_neighbors: 50 18 | 19 | # Transformations 20 | augmentation_strategy: simclr 21 | augmentation_kwargs: 22 | random_resized_crop: 23 | size: 224 24 | scale: [0.2, 1.0] 25 | color_jitter_random_apply: 26 | p: 0.8 27 | color_jitter: 28 | brightness: 0.4 29 | contrast: 0.4 30 | saturation: 0.4 31 | hue: 0.1 32 | random_grayscale: 33 | p: 0.2 34 | normalize: 35 | mean: [0.485, 0.456, 0.406] 36 | std: [0.229, 0.224, 0.225] 37 | 38 | transformation_kwargs: 39 | crop_size: 224 40 | normalize: 41 | mean: [0.485, 0.456, 0.406] 42 | std: [0.229, 0.224, 0.225] 43 | 44 | num_workers: 12 45 | -------------------------------------------------------------------------------- /configs/scan/scan_stl10.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: scan 3 | 4 | # Loss 5 | criterion: scan 6 | criterion_kwargs: 7 | entropy_weight: 5.0 8 | 9 | # Weight update 10 | update_cluster_head_only: False # Update full network in SCAN 11 | num_heads: 1 # Only use one head 12 | 13 | # Model 14 | backbone: resnet18 15 | 16 | # Dataset 17 | train_db_name: stl-10 18 | val_db_name: stl-10 19 | num_classes: 10 20 | num_neighbors: 20 21 | 22 | # Transformations 23 | augmentation_strategy: ours 24 | augmentation_kwargs: 25 | crop_size: 96 26 | normalize: 27 | mean: [0.485, 0.456, 0.406] 28 | std: [0.229, 0.224, 0.225] 29 | num_strong_augs: 4 30 | cutout_kwargs: 31 | n_holes: 1 32 | length: 32 33 | random: True 34 | 35 | transformation_kwargs: 36 | crop_size: 96 37 | normalize: 38 | mean: [0.485, 0.456, 0.406] 39 | std: [0.229, 0.224, 0.225] 40 | 41 | # Hyperparameters 42 | optimizer: adam 43 | optimizer_kwargs: 44 | lr: 0.0001 45 | weight_decay: 0.0001 46 | epochs: 100 47 | batch_size: 128 48 | num_workers: 8 49 | 50 | # Scheduler 51 | scheduler: constant 52 | -------------------------------------------------------------------------------- /configs/selflabel/selflabel_stl10.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: selflabel 3 | 4 | # ema 5 | use_ema: False 6 | 7 | # Threshold 8 | confidence_threshold: 0.99 9 | 10 | # Loss 11 | criterion: confidence-cross-entropy 12 | criterion_kwargs: 13 | apply_class_balancing: True 14 | 15 | # Model 16 | backbone: resnet18 17 | num_heads: 1 18 | 19 | # Dataset 20 | train_db_name: stl-10 21 | val_db_name: stl-10 22 | num_classes: 10 23 | 24 | # Transformations 25 | augmentation_strategy: ours 26 | augmentation_kwargs: 27 | crop_size: 96 28 | normalize: 29 | mean: [0.485, 0.456, 0.406] 30 | std: [0.229, 0.224, 0.225] 31 | num_strong_augs: 4 32 | cutout_kwargs: 33 | n_holes: 1 34 | 
length: 32 35 | random: True 36 | 37 | transformation_kwargs: 38 | crop_size: 96 39 | normalize: 40 | mean: [0.485, 0.456, 0.406] 41 | std: [0.229, 0.224, 0.225] 42 | 43 | # Hyperparameters 44 | optimizer: adam 45 | optimizer_kwargs: 46 | lr: 0.0001 47 | weight_decay: 0.0001 48 | epochs: 100 49 | batch_size: 1000 50 | num_workers: 8 51 | 52 | # Scheduler 53 | scheduler: constant 54 | -------------------------------------------------------------------------------- /configs/scan/scan_cifar10.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: scan 3 | 4 | # Loss 5 | criterion: scan 6 | criterion_kwargs: 7 | entropy_weight: 5.0 8 | 9 | # Weight update 10 | update_cluster_head_only: False # Update full network in SCAN 11 | num_heads: 1 # Only use one head 12 | 13 | # Model 14 | backbone: resnet18 15 | 16 | # Dataset 17 | train_db_name: cifar-10 18 | val_db_name: cifar-10 19 | num_classes: 10 20 | num_neighbors: 20 21 | 22 | # Transformations 23 | augmentation_strategy: ours 24 | augmentation_kwargs: 25 | crop_size: 32 26 | normalize: 27 | mean: [0.4914, 0.4822, 0.4465] 28 | std: [0.2023, 0.1994, 0.2010] 29 | num_strong_augs: 4 30 | cutout_kwargs: 31 | n_holes: 1 32 | length: 16 33 | random: True 34 | 35 | transformation_kwargs: 36 | crop_size: 32 37 | normalize: 38 | mean: [0.4914, 0.4822, 0.4465] 39 | std: [0.2023, 0.1994, 0.2010] 40 | 41 | # Hyperparameters 42 | optimizer: adam 43 | optimizer_kwargs: 44 | lr: 0.0001 45 | weight_decay: 0.0001 46 | epochs: 50 47 | batch_size: 128 48 | num_workers: 8 49 | 50 | # Scheduler 51 | scheduler: constant 52 | -------------------------------------------------------------------------------- /configs/scan/scan_cifar20.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: scan 3 | 4 | # Loss 5 | criterion: scan 6 | criterion_kwargs: 7 | entropy_weight: 5.0 8 | 9 | # Weight update 10 | update_cluster_head_only: False # Update full network in SCAN 11 | num_heads: 1 # Only use one head 12 | 13 | # Model 14 | backbone: resnet18 15 | 16 | # Dataset 17 | train_db_name: cifar-20 18 | val_db_name: cifar-20 19 | num_classes: 20 20 | num_neighbors: 20 21 | 22 | # Transformations 23 | augmentation_strategy: ours 24 | augmentation_kwargs: 25 | crop_size: 32 26 | normalize: 27 | mean: [0.5071, 0.4867, 0.4408] 28 | std: [0.2675, 0.2565, 0.2761] 29 | num_strong_augs: 4 30 | cutout_kwargs: 31 | n_holes: 1 32 | length: 16 33 | random: True 34 | 35 | transformation_kwargs: 36 | crop_size: 32 37 | normalize: 38 | mean: [0.5071, 0.4867, 0.4408] 39 | std: [0.2675, 0.2565, 0.2761] 40 | 41 | # Hyperparameters 42 | optimizer: adam 43 | optimizer_kwargs: 44 | lr: 0.0001 45 | weight_decay: 0.0001 46 | epochs: 100 47 | batch_size: 512 48 | num_workers: 8 49 | 50 | # Scheduler 51 | scheduler: constant 52 | -------------------------------------------------------------------------------- /configs/selflabel/selflabel_cifar10.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: selflabel 3 | 4 | # ema 5 | use_ema: False 6 | 7 | # Threshold 8 | confidence_threshold: 0.99 9 | 10 | # Criterion 11 | criterion: confidence-cross-entropy 12 | criterion_kwargs: 13 | apply_class_balancing: True 14 | 15 | # Model 16 | backbone: resnet18 17 | num_heads: 1 18 | 19 | # Dataset 20 | train_db_name: cifar-10 21 | val_db_name: cifar-10 22 | num_classes: 10 23 | 24 | # Transformations 25 | augmentation_strategy: ours 26 | 
augmentation_kwargs: 27 | crop_size: 32 28 | normalize: 29 | mean: [0.4914, 0.4822, 0.4465] 30 | std: [0.2023, 0.1994, 0.2010] 31 | num_strong_augs: 4 32 | cutout_kwargs: 33 | n_holes: 1 34 | length: 16 35 | random: True 36 | 37 | transformation_kwargs: 38 | crop_size: 32 39 | normalize: 40 | mean: [0.4914, 0.4822, 0.4465] 41 | std: [0.2023, 0.1994, 0.2010] 42 | 43 | # Hyperparameters 44 | epochs: 200 45 | optimizer: adam 46 | optimizer_kwargs: 47 | lr: 0.0001 48 | weight_decay: 0.0001 49 | batch_size: 1000 50 | num_workers: 8 51 | 52 | # Scheduler 53 | scheduler: constant 54 | -------------------------------------------------------------------------------- /configs/selflabel/selflabel_cifar20.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: selflabel 3 | 4 | # ema 5 | use_ema: False 6 | 7 | # Threshold 8 | confidence_threshold: 0.99 9 | 10 | # Criterion 11 | criterion: confidence-cross-entropy 12 | criterion_kwargs: 13 | apply_class_balancing: True 14 | 15 | # Model 16 | backbone: resnet18 17 | num_heads: 1 18 | 19 | # Dataset 20 | train_db_name: cifar-20 21 | val_db_name: cifar-20 22 | num_classes: 20 23 | 24 | # Transformations 25 | augmentation_strategy: ours 26 | augmentation_kwargs: 27 | crop_size: 32 28 | normalize: 29 | mean: [0.5071, 0.4867, 0.4408] 30 | std: [0.2675, 0.2565, 0.2761] 31 | num_strong_augs: 4 32 | cutout_kwargs: 33 | n_holes: 1 34 | length: 16 35 | random: True 36 | 37 | transformation_kwargs: 38 | crop_size: 32 39 | normalize: 40 | mean: [0.5071, 0.4867, 0.4408] 41 | std: [0.2675, 0.2565, 0.2761] 42 | 43 | # Hyperparameters 44 | epochs: 200 45 | optimizer: adam 46 | optimizer_kwargs: 47 | lr: 0.0001 48 | weight_decay: 0.0001 49 | batch_size: 1000 50 | num_workers: 8 51 | 52 | # Scheduler 53 | scheduler: constant 54 | -------------------------------------------------------------------------------- /configs/selflabel/selflabel_imagenet_200.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: selflabel 3 | 4 | # Threshold 5 | confidence_threshold: 0.99 6 | 7 | # EMA 8 | use_ema: True 9 | ema_alpha: 0.999 10 | 11 | # Loss 12 | criterion: confidence-cross-entropy 13 | criterion_kwargs: 14 | apply_class_balancing: False 15 | 16 | # Model 17 | backbone: resnet50 18 | num_heads: 1 19 | 20 | # Dataset 21 | train_db_name: imagenet_200 22 | val_db_name: imagenet_200 23 | num_classes: 200 24 | 25 | # Transformations 26 | augmentation_strategy: ours 27 | augmentation_kwargs: 28 | crop_size: 224 29 | normalize: 30 | mean: [0.485, 0.456, 0.406] 31 | std: [0.229, 0.224, 0.225] 32 | num_strong_augs: 4 33 | cutout_kwargs: 34 | n_holes: 1 35 | length: 75 36 | random: True 37 | 38 | transformation_kwargs: 39 | crop_size: 224 40 | normalize: 41 | mean: [0.485, 0.456, 0.406] 42 | std: [0.229, 0.224, 0.225] 43 | 44 | # Hyperparameters 45 | optimizer: sgd 46 | optimizer_kwargs: 47 | lr: 0.03 48 | weight_decay: 0.0 49 | nesterov: False 50 | momentum: 0.9 51 | epochs: 25 52 | batch_size: 512 53 | num_workers: 8 54 | 55 | # Scheduler 56 | scheduler: constant 57 | -------------------------------------------------------------------------------- /configs/selflabel/selflabel_imagenet_50.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: selflabel 3 | 4 | # Threshold 5 | confidence_threshold: 0.99 6 | 7 | # EMA 8 | use_ema: True 9 | ema_alpha: 0.999 10 | 11 | # Loss 12 | criterion: confidence-cross-entropy 13 | 
criterion_kwargs: 14 | apply_class_balancing: False 15 | 16 | # Model 17 | backbone: resnet50 18 | num_heads: 1 19 | 20 | # Dataset 21 | train_db_name: imagenet_50 22 | val_db_name: imagenet_50 23 | num_classes: 50 24 | 25 | # Transformations 26 | augmentation_strategy: ours 27 | augmentation_kwargs: 28 | crop_size: 224 29 | normalize: 30 | mean: [0.485, 0.456, 0.406] 31 | std: [0.229, 0.224, 0.225] 32 | num_strong_augs: 4 33 | cutout_kwargs: 34 | n_holes: 1 35 | length: 75 36 | random: True 37 | 38 | transformation_kwargs: 39 | crop_size: 224 40 | normalize: 41 | mean: [0.485, 0.456, 0.406] 42 | std: [0.229, 0.224, 0.225] 43 | 44 | # Hyperparameters 45 | optimizer: sgd 46 | optimizer_kwargs: 47 | lr: 0.03 48 | weight_decay: 0.0 49 | nesterov: False 50 | momentum: 0.9 51 | epochs: 25 52 | batch_size: 512 53 | num_workers: 16 54 | 55 | # Scheduler 56 | scheduler: constant 57 | -------------------------------------------------------------------------------- /configs/selflabel/selflabel_imagenet_100.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: selflabel 3 | 4 | # Threshold 5 | confidence_threshold: 0.99 6 | 7 | # EMA 8 | use_ema: True 9 | ema_alpha: 0.999 10 | 11 | # Loss 12 | criterion: confidence-cross-entropy 13 | criterion_kwargs: 14 | apply_class_balancing: False 15 | 16 | # Model 17 | backbone: resnet50 18 | num_heads: 1 19 | 20 | # Dataset 21 | train_db_name: imagenet_100 22 | val_db_name: imagenet_100 23 | num_classes: 100 24 | 25 | # Transformations 26 | augmentation_strategy: ours 27 | augmentation_kwargs: 28 | crop_size: 224 29 | normalize: 30 | mean: [0.485, 0.456, 0.406] 31 | std: [0.229, 0.224, 0.225] 32 | num_strong_augs: 4 33 | cutout_kwargs: 34 | n_holes: 1 35 | length: 75 36 | random: True 37 | 38 | transformation_kwargs: 39 | crop_size: 224 40 | normalize: 41 | mean: [0.485, 0.456, 0.406] 42 | std: [0.229, 0.224, 0.225] 43 | 44 | # Hyperparameters 45 | optimizer: sgd 46 | optimizer_kwargs: 47 | lr: 0.03 48 | weight_decay: 0.0 49 | nesterov: False 50 | momentum: 0.9 51 | epochs: 25 52 | batch_size: 512 53 | num_workers: 12 54 | 55 | # Scheduler 56 | scheduler: constant 57 | -------------------------------------------------------------------------------- /configs/pretext/simclr_stl10.yml: -------------------------------------------------------------------------------- 1 | # Setup 2 | setup: simclr 3 | 4 | # Model 5 | backbone: resnet18 6 | model_kwargs: 7 | head: mlp 8 | features_dim: 128 9 | 10 | # Dataset 11 | train_db_name: stl-10 12 | val_db_name: stl-10 13 | num_classes: 10 14 | 15 | # Loss 16 | criterion: simclr 17 | criterion_kwargs: 18 | temperature: 0.1 19 | 20 | # Hyperparameters 21 | epochs: 500 22 | optimizer: sgd 23 | optimizer_kwargs: 24 | nesterov: False 25 | weight_decay: 0.0001 26 | momentum: 0.9 27 | lr: 0.4 28 | scheduler: cosine 29 | scheduler_kwargs: 30 | lr_decay_rate: 0.1 31 | batch_size: 512 32 | num_workers: 8 33 | 34 | # Transformations 35 | augmentation_strategy: simclr 36 | augmentation_kwargs: 37 | random_resized_crop: 38 | size: 96 39 | scale: [0.2, 1.0] 40 | color_jitter_random_apply: 41 | p: 0.8 42 | color_jitter: 43 | brightness: 0.4 44 | contrast: 0.4 45 | saturation: 0.4 46 | hue: 0.1 47 | random_grayscale: 48 | p: 0.2 49 | normalize: 50 | mean: [0.485, 0.456, 0.406] 51 | std: [0.229, 0.224, 0.225] 52 | 53 | transformation_kwargs: 54 | crop_size: 96 55 | normalize: 56 | mean: [0.485, 0.456, 0.406] 57 | std: [0.229, 0.224, 0.225] 58 | 
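Each of these YAML files is consumed through create_config() in utils/config.py (listed further below), which merges the environment file with the experiment file and derives the output paths. A minimal sketch, assuming it is run from the repository root and that configs/env.yml points root_dir at a writable directory:

from utils.config import create_config

p = create_config('configs/env.yml', 'configs/pretext/simclr_stl10.yml')
print(p['backbone'])            # 'resnet18'
print(p['criterion_kwargs'])    # {'temperature': 0.1}
print(p['pretext_checkpoint'])  # <root_dir>/stl-10/pretext/checkpoint.pth.tar

Note that create_config() also creates the missing output directories (e.g. <root_dir>/stl-10/pretext/) on disk as a side effect.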
-------------------------------------------------------------------------------- /configs/pretext/simclr_cifar10.yml: -------------------------------------------------------------------------------- 1 | # Setup 2 | setup: simclr 3 | 4 | # Model 5 | backbone: resnet18 6 | model_kwargs: 7 | head: mlp 8 | features_dim: 128 9 | 10 | # Dataset 11 | train_db_name: cifar-10 12 | val_db_name: cifar-10 13 | num_classes: 10 14 | 15 | # Loss 16 | criterion: simclr 17 | criterion_kwargs: 18 | temperature: 0.1 19 | 20 | # Hyperparameters 21 | epochs: 500 22 | optimizer: sgd 23 | optimizer_kwargs: 24 | nesterov: False 25 | weight_decay: 0.0001 26 | momentum: 0.9 27 | lr: 0.4 28 | scheduler: cosine 29 | scheduler_kwargs: 30 | lr_decay_rate: 0.1 31 | batch_size: 512 32 | num_workers: 8 33 | 34 | # Transformations 35 | augmentation_strategy: simclr 36 | augmentation_kwargs: 37 | random_resized_crop: 38 | size: 32 39 | scale: [0.2, 1.0] 40 | color_jitter_random_apply: 41 | p: 0.8 42 | color_jitter: 43 | brightness: 0.4 44 | contrast: 0.4 45 | saturation: 0.4 46 | hue: 0.1 47 | random_grayscale: 48 | p: 0.2 49 | normalize: 50 | mean: [0.4914, 0.4822, 0.4465] 51 | std: [0.2023, 0.1994, 0.2010] 52 | 53 | transformation_kwargs: 54 | crop_size: 32 55 | normalize: 56 | mean: [0.4914, 0.4822, 0.4465] 57 | std: [0.2023, 0.1994, 0.2010] 58 | -------------------------------------------------------------------------------- /configs/pretext/simclr_cifar20.yml: -------------------------------------------------------------------------------- 1 | # Setup 2 | setup: simclr 3 | 4 | # Model 5 | backbone: resnet18 6 | model_kwargs: 7 | head: mlp 8 | features_dim: 128 9 | 10 | # Dataset 11 | train_db_name: cifar-20 12 | val_db_name: cifar-20 13 | num_classes: 20 14 | 15 | # Loss 16 | criterion: simclr 17 | criterion_kwargs: 18 | temperature: 0.1 19 | 20 | # Hyperparameters 21 | epochs: 500 22 | optimizer: sgd 23 | optimizer_kwargs: 24 | nesterov: False 25 | weight_decay: 0.0001 26 | momentum: 0.9 27 | lr: 0.4 28 | scheduler: cosine 29 | scheduler_kwargs: 30 | lr_decay_rate: 0.1 31 | batch_size: 512 32 | num_workers: 8 33 | 34 | # Transformations 35 | augmentation_strategy: simclr 36 | augmentation_kwargs: 37 | random_resized_crop: 38 | size: 32 39 | scale: [0.2, 1.0] 40 | color_jitter_random_apply: 41 | p: 0.8 42 | color_jitter: 43 | brightness: 0.4 44 | contrast: 0.4 45 | saturation: 0.4 46 | hue: 0.1 47 | random_grayscale: 48 | p: 0.2 49 | normalize: 50 | mean: [0.5071, 0.4867, 0.4408] 51 | std: [0.2675, 0.2565, 0.2761] 52 | 53 | transformation_kwargs: 54 | crop_size: 32 55 | normalize: 56 | mean: [0.5071, 0.4867, 0.4408] 57 | std: [0.2675, 0.2565, 0.2761] 58 | -------------------------------------------------------------------------------- /data/imagenet_subsets/imagenet_50.txt: -------------------------------------------------------------------------------- 1 | n01601694 Dipper 2 | n01669191 Box Turtle 3 | n01755581 Diamondback Snake 4 | n01770393 Scorpion 5 | n01855672 Goose 6 | n02018207 Water Hen 7 | n02058221 Albatross 8 | n02096177 Cairn Terrier 9 | n02097130 Giant Schnauzer 10 | n02099267 Flat-Coated Retriever 11 | n02100877 Irish Setter 12 | n02104365 Schipperke 13 | n02106030 Collie 14 | n02114855 Coyote 15 | n02125311 Mountain Lion 16 | n02133161 Black Bear 17 | n02484975 Guenon 18 | n02489166 Proboscis Monkey 19 | n02747177 Trash Bin 20 | n02906734 Broom 21 | n03124170 Cowboy Hat 22 | n03272010 Electric Guitar 23 | n03337140 File Cabinet 24 | n03483316 Hair Dryer 25 | n03498962 Hatchet 26 | n03710721 Maillot 27 | 
n03717622 Manhole Cover 28 | n03733281 Maze 29 | n03759954 Microphone 30 | n03775071 Mitten 31 | n03814639 Neck Brace 32 | n03837869 Obelisk 33 | n03838899 Oboe 34 | n03854065 Organ 35 | n03954731 Woodworking Plane 36 | n03983396 Soda Bottle 37 | n04026417 Purse 38 | n04200800 Shoe Shop 39 | n04209239 Shower Curtain 40 | n04311004 Steel Arch Bridge 41 | n04380533 Table Lamp 42 | n04428191 Threshing Machine 43 | n04443257 Tobacco Shop 44 | n04509417 Unicycle 45 | n04525305 Vending Machine 46 | n04554684 Washing Machine 47 | n04606251 Wreck 48 | n07583066 Guacamole 49 | n07711569 Mashed Potato 50 | n07753592 Banana 51 | -------------------------------------------------------------------------------- /configs/scan/scan_imagenet_50.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: scan 3 | 4 | # Loss 5 | criterion: scan 6 | criterion_kwargs: 7 | entropy_weight: 5.0 8 | 9 | # Model 10 | backbone: resnet50 11 | 12 | # Weight update 13 | update_cluster_head_only: True # Train only linear layer during SCAN 14 | num_heads: 10 # Use multiple heads 15 | 16 | # Dataset 17 | train_db_name: imagenet_50 18 | val_db_name: imagenet_50 19 | num_classes: 50 20 | num_neighbors: 50 21 | 22 | # Transformations 23 | augmentation_strategy: simclr 24 | augmentation_kwargs: 25 | random_resized_crop: 26 | size: 224 27 | scale: [0.2, 1.0] 28 | color_jitter_random_apply: 29 | p: 0.8 30 | color_jitter: 31 | brightness: 0.4 32 | contrast: 0.4 33 | saturation: 0.4 34 | hue: 0.1 35 | random_grayscale: 36 | p: 0.2 37 | normalize: 38 | mean: [0.485, 0.456, 0.406] 39 | std: [0.229, 0.224, 0.225] 40 | 41 | transformation_kwargs: 42 | crop_size: 224 43 | normalize: 44 | mean: [0.485, 0.456, 0.406] 45 | std: [0.229, 0.224, 0.225] 46 | 47 | # Hyperparameters 48 | optimizer: sgd 49 | optimizer_kwargs: 50 | lr: 5.0 51 | weight_decay: 0.0000 52 | nesterov: False 53 | momentum: 0.9 54 | epochs: 100 55 | batch_size: 512 56 | num_workers: 12 57 | 58 | # Scheduler 59 | scheduler: constant 60 | -------------------------------------------------------------------------------- /utils/collate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import torch 6 | import numpy as np 7 | import collections 8 | from torch._six import string_classes 9 | 10 | 11 | """ Custom collate function """ 12 | def collate_custom(batch): 13 | if isinstance(batch[0], np.int64): 14 | return np.stack(batch, 0) 15 | 16 | if isinstance(batch[0], torch.Tensor): 17 | return torch.stack(batch, 0) 18 | 19 | elif isinstance(batch[0], np.ndarray): 20 | return np.stack(batch, 0) 21 | 22 | elif isinstance(batch[0], int): 23 | return torch.LongTensor(batch) 24 | 25 | elif isinstance(batch[0], float): 26 | return torch.FloatTensor(batch) 27 | 28 | elif isinstance(batch[0], string_classes): 29 | return batch 30 | 31 | elif isinstance(batch[0], collections.Mapping): 32 | batch_modified = {key: collate_custom([d[key] for d in batch]) for key in batch[0] if key.find('idx') < 0} 33 | return batch_modified 34 | 35 | elif isinstance(batch[0], collections.Sequence): 36 | transposed = zip(*batch) 37 | return [collate_custom(samples) for samples in transposed] 38 | 39 | raise TypeError(('Type is {}'.format(type(batch[0])))) 40 | -------------------------------------------------------------------------------- 
/configs/scan/scan_imagenet_100.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: scan 3 | 4 | # Loss 5 | criterion: scan 6 | criterion_kwargs: 7 | entropy_weight: 5.0 8 | 9 | # Model 10 | backbone: resnet50 11 | 12 | # Weight update 13 | update_cluster_head_only: True # Train only linear layer during SCAN 14 | num_heads: 10 # Use multiple heads 15 | 16 | # Dataset 17 | train_db_name: imagenet_100 18 | val_db_name: imagenet_100 19 | num_classes: 100 20 | num_neighbors: 50 21 | 22 | # Transformations 23 | augmentation_strategy: simclr 24 | augmentation_kwargs: 25 | random_resized_crop: 26 | size: 224 27 | scale: [0.2, 1.0] 28 | color_jitter_random_apply: 29 | p: 0.8 30 | color_jitter: 31 | brightness: 0.4 32 | contrast: 0.4 33 | saturation: 0.4 34 | hue: 0.1 35 | random_grayscale: 36 | p: 0.2 37 | normalize: 38 | mean: [0.485, 0.456, 0.406] 39 | std: [0.229, 0.224, 0.225] 40 | 41 | transformation_kwargs: 42 | crop_size: 224 43 | normalize: 44 | mean: [0.485, 0.456, 0.406] 45 | std: [0.229, 0.224, 0.225] 46 | 47 | # Hyperparameters 48 | optimizer: sgd 49 | optimizer_kwargs: 50 | lr: 5.0 51 | weight_decay: 0.0000 52 | nesterov: False 53 | momentum: 0.9 54 | epochs: 100 55 | batch_size: 1024 56 | num_workers: 16 57 | 58 | # Scheduler 59 | scheduler: constant 60 | -------------------------------------------------------------------------------- /configs/scan/scan_imagenet_200.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: scan 3 | 4 | # Loss 5 | criterion: scan 6 | criterion_kwargs: 7 | entropy_weight: 5.0 8 | 9 | # Model 10 | backbone: resnet50 11 | 12 | # Weight update 13 | update_cluster_head_only: True # Train only linear layer during SCAN 14 | num_heads: 10 # Use multiple heads 15 | 16 | # Dataset 17 | train_db_name: imagenet_200 18 | val_db_name: imagenet_200 19 | num_classes: 200 20 | num_neighbors: 50 21 | 22 | # Transformations 23 | augmentation_strategy: simclr 24 | augmentation_kwargs: 25 | random_resized_crop: 26 | size: 224 27 | scale: [0.2, 1.0] 28 | color_jitter_random_apply: 29 | p: 0.8 30 | color_jitter: 31 | brightness: 0.4 32 | contrast: 0.4 33 | saturation: 0.4 34 | hue: 0.1 35 | random_grayscale: 36 | p: 0.2 37 | normalize: 38 | mean: [0.485, 0.456, 0.406] 39 | std: [0.229, 0.224, 0.225] 40 | 41 | transformation_kwargs: 42 | crop_size: 224 43 | normalize: 44 | mean: [0.485, 0.456, 0.406] 45 | std: [0.229, 0.224, 0.225] 46 | 47 | # Hyperparameters 48 | optimizer: sgd 49 | optimizer_kwargs: 50 | lr: 5.0 51 | weight_decay: 0.0000 52 | nesterov: False 53 | momentum: 0.9 54 | epochs: 100 55 | batch_size: 1024 56 | num_workers: 12 57 | 58 | # Scheduler 59 | scheduler: constant 60 | -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import os 6 | import yaml 7 | from easydict import EasyDict 8 | from utils.utils import mkdir_if_missing 9 | 10 | def create_config(config_file_env, config_file_exp): 11 | # Config for environment path 12 | with open(config_file_env, 'r') as stream: 13 | root_dir = yaml.safe_load(stream)['root_dir'] 14 | 15 | with open(config_file_exp, 'r') as stream: 16 | config = yaml.safe_load(stream) 17 | 18 | cfg = EasyDict() 19 | 20 | # Copy 21 | for k, v 
in config.items(): 22 | cfg[k] = v 23 | 24 | # Set paths for pretext task (These directories are needed in every stage) 25 | base_dir = os.path.join(root_dir, cfg['train_db_name']) 26 | pretext_dir = os.path.join(base_dir, 'pretext') 27 | mkdir_if_missing(base_dir) 28 | mkdir_if_missing(pretext_dir) 29 | cfg['pretext_dir'] = pretext_dir 30 | cfg['pretext_checkpoint'] = os.path.join(pretext_dir, 'checkpoint.pth.tar') 31 | cfg['pretext_model'] = os.path.join(pretext_dir, 'model.pth.tar') 32 | cfg['topk_neighbors_train_path'] = os.path.join(pretext_dir, 'topk-train-neighbors.npy') 33 | cfg['topk_neighbors_val_path'] = os.path.join(pretext_dir, 'topk-val-neighbors.npy') 34 | 35 | # If we perform clustering or self-labeling step we need additional paths. 36 | # We also include a run identifier to support multiple runs w/ same hyperparams. 37 | if cfg['setup'] in ['scan', 'selflabel']: 38 | base_dir = os.path.join(root_dir, cfg['train_db_name']) 39 | scan_dir = os.path.join(base_dir, 'scan') 40 | selflabel_dir = os.path.join(base_dir, 'selflabel') 41 | mkdir_if_missing(base_dir) 42 | mkdir_if_missing(scan_dir) 43 | mkdir_if_missing(selflabel_dir) 44 | cfg['scan_dir'] = scan_dir 45 | cfg['scan_checkpoint'] = os.path.join(scan_dir, 'checkpoint.pth.tar') 46 | cfg['scan_model'] = os.path.join(scan_dir, 'model.pth.tar') 47 | cfg['selflabel_dir'] = selflabel_dir 48 | cfg['selflabel_checkpoint'] = os.path.join(selflabel_dir, 'checkpoint.pth.tar') 49 | cfg['selflabel_model'] = os.path.join(selflabel_dir, 'model.pth.tar') 50 | 51 | return cfg 52 | -------------------------------------------------------------------------------- /models/models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class ContrastiveModel(nn.Module): 11 | def __init__(self, backbone, head='mlp', features_dim=128): 12 | super(ContrastiveModel, self).__init__() 13 | self.backbone = backbone['backbone'] 14 | self.backbone_dim = backbone['dim'] 15 | self.head = head 16 | 17 | if head == 'linear': 18 | self.contrastive_head = nn.Linear(self.backbone_dim, features_dim) 19 | 20 | elif head == 'mlp': 21 | self.contrastive_head = nn.Sequential( 22 | nn.Linear(self.backbone_dim, self.backbone_dim), 23 | nn.ReLU(), nn.Linear(self.backbone_dim, features_dim)) 24 | 25 | else: 26 | raise ValueError('Invalid head {}'.format(head)) 27 | 28 | def forward(self, x): 29 | features = self.contrastive_head(self.backbone(x)) 30 | features = F.normalize(features, dim = 1) 31 | return features 32 | 33 | 34 | class ClusteringModel(nn.Module): 35 | def __init__(self, backbone, nclusters, nheads=1): 36 | super(ClusteringModel, self).__init__() 37 | self.backbone = backbone['backbone'] 38 | self.backbone_dim = backbone['dim'] 39 | self.nheads = nheads 40 | assert(isinstance(self.nheads, int)) 41 | assert(self.nheads > 0) 42 | self.cluster_head = nn.ModuleList([nn.Linear(self.backbone_dim, nclusters) for _ in range(self.nheads)]) 43 | 44 | def forward(self, x, forward_pass='default'): 45 | if forward_pass == 'default': 46 | features = self.backbone(x) 47 | out = [cluster_head(features) for cluster_head in self.cluster_head] 48 | 49 | elif forward_pass == 'backbone': 50 | out = self.backbone(x) 51 | 52 | elif forward_pass == 'head': 53 | out = [cluster_head(x) for cluster_head 
in self.cluster_head] 54 | 55 | elif forward_pass == 'return_all': 56 | features = self.backbone(x) 57 | out = {'features': features, 'output': [cluster_head(features) for cluster_head in self.cluster_head]} 58 | 59 | else: 60 | raise ValueError('Invalid forward pass {}'.format(forward_pass)) 61 | 62 | return out 63 | -------------------------------------------------------------------------------- /data/custom_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import numpy as np 6 | import torch 7 | from torch.utils.data import Dataset 8 | 9 | """ 10 | AugmentedDataset 11 | Returns an image together with an augmentation. 12 | """ 13 | class AugmentedDataset(Dataset): 14 | def __init__(self, dataset): 15 | super(AugmentedDataset, self).__init__() 16 | transform = dataset.transform 17 | dataset.transform = None 18 | self.dataset = dataset 19 | 20 | if isinstance(transform, dict): 21 | self.image_transform = transform['standard'] 22 | self.augmentation_transform = transform['augment'] 23 | 24 | else: 25 | self.image_transform = transform 26 | self.augmentation_transform = transform 27 | 28 | def __len__(self): 29 | return len(self.dataset) 30 | 31 | def __getitem__(self, index): 32 | sample = self.dataset.__getitem__(index) 33 | image = sample['image'] 34 | 35 | sample['image'] = self.image_transform(image) 36 | sample['image_augmented'] = self.augmentation_transform(image) 37 | 38 | return sample 39 | 40 | 41 | """ 42 | NeighborsDataset 43 | Returns an image with one of its neighbors. 44 | """ 45 | class NeighborsDataset(Dataset): 46 | def __init__(self, dataset, indices, num_neighbors=None): 47 | super(NeighborsDataset, self).__init__() 48 | transform = dataset.transform 49 | 50 | if isinstance(transform, dict): 51 | self.anchor_transform = transform['standard'] 52 | self.neighbor_transform = transform['augment'] 53 | else: 54 | self.anchor_transform = transform 55 | self.neighbor_transform = transform 56 | 57 | dataset.transform = None 58 | self.dataset = dataset 59 | self.indices = indices # Nearest neighbor indices (np.array [len(dataset) x k]) 60 | if num_neighbors is not None: 61 | self.indices = self.indices[:, :num_neighbors+1] 62 | assert(self.indices.shape[0] == len(self.dataset)) 63 | 64 | def __len__(self): 65 | return len(self.dataset) 66 | 67 | def __getitem__(self, index): 68 | output = {} 69 | anchor = self.dataset.__getitem__(index) 70 | 71 | neighbor_index = np.random.choice(self.indices[index], 1)[0] 72 | neighbor = self.dataset.__getitem__(neighbor_index) 73 | 74 | anchor['image'] = self.anchor_transform(anchor['image']) 75 | neighbor['image'] = self.neighbor_transform(neighbor['image']) 76 | 77 | output['anchor'] = anchor['image'] 78 | output['neighbor'] = neighbor['image'] 79 | output['possible_neighbors'] = torch.from_numpy(self.indices[index]) 80 | output['target'] = anchor['target'] 81 | 82 | return output 83 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import os 6 | import torch 7 | import numpy as np 8 | import errno 9 | 10 | def mkdir_if_missing(directory): 11 | if 
not os.path.exists(directory): 12 | try: 13 | os.makedirs(directory) 14 | except OSError as e: 15 | if e.errno != errno.EEXIST: 16 | raise 17 | 18 | 19 | class AverageMeter(object): 20 | def __init__(self, name, fmt=':f'): 21 | self.name = name 22 | self.fmt = fmt 23 | self.reset() 24 | 25 | def reset(self): 26 | self.val = 0 27 | self.avg = 0 28 | self.sum = 0 29 | self.count = 0 30 | 31 | def update(self, val, n=1): 32 | self.val = val 33 | self.sum += val * n 34 | self.count += n 35 | self.avg = self.sum / self.count 36 | 37 | def __str__(self): 38 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 39 | return fmtstr.format(**self.__dict__) 40 | 41 | 42 | class ProgressMeter(object): 43 | def __init__(self, num_batches, meters, prefix=""): 44 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 45 | self.meters = meters 46 | self.prefix = prefix 47 | 48 | def display(self, batch): 49 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 50 | entries += [str(meter) for meter in self.meters] 51 | print('\t'.join(entries)) 52 | 53 | def _get_batch_fmtstr(self, num_batches): 54 | num_digits = len(str(num_batches // 1)) 55 | fmt = '{:' + str(num_digits) + 'd}' 56 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 57 | 58 | 59 | @torch.no_grad() 60 | def fill_memory_bank(loader, model, memory_bank): 61 | model.eval() 62 | memory_bank.reset() 63 | 64 | for i, batch in enumerate(loader): 65 | images = batch['image'].cuda(non_blocking=True) 66 | targets = batch['target'].cuda(non_blocking=True) 67 | output = model(images) 68 | memory_bank.update(output, targets) 69 | if i % 100 == 0: 70 | print('Fill Memory Bank [%d/%d]' %(i, len(loader))) 71 | 72 | 73 | def confusion_matrix(predictions, gt, class_names, output_file=None): 74 | # Plot confusion_matrix and store result to output_file 75 | import sklearn.metrics 76 | import matplotlib.pyplot as plt 77 | confusion_matrix = sklearn.metrics.confusion_matrix(gt, predictions) 78 | confusion_matrix = confusion_matrix / np.sum(confusion_matrix, 1) 79 | 80 | fig, axes = plt.subplots(1) 81 | plt.imshow(confusion_matrix, cmap='Blues') 82 | axes.set_xticks([i for i in range(len(class_names))]) 83 | axes.set_yticks([i for i in range(len(class_names))]) 84 | axes.set_xticklabels(class_names, ha='right', fontsize=8, rotation=40) 85 | axes.set_yticklabels(class_names, ha='right', fontsize=8) 86 | 87 | for (i, j), z in np.ndenumerate(confusion_matrix): 88 | if i == j: 89 | axes.text(j, i, '%d' %(100*z), ha='center', va='center', color='white', fontsize=6) 90 | else: 91 | pass 92 | 93 | plt.tight_layout() 94 | if output_file is None: 95 | plt.show() 96 | else: 97 | plt.savefig(output_file, dpi=300, bbox_inches='tight') 98 | plt.close() 99 | -------------------------------------------------------------------------------- /utils/memory.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import numpy as np 6 | import torch 7 | 8 | 9 | class MemoryBank(object): 10 | def __init__(self, n, dim, num_classes, temperature): 11 | self.n = n 12 | self.dim = dim 13 | self.features = torch.FloatTensor(self.n, self.dim) 14 | self.targets = torch.LongTensor(self.n) 15 | self.ptr = 0 16 | self.device = 'cpu' 17 | self.K = 100 18 | self.temperature = temperature 19 | self.C = num_classes 20 | 21 | def weighted_knn(self, predictions): 22 | # perform 
weighted knn 23 | retrieval_one_hot = torch.zeros(self.K, self.C).to(self.device) 24 | batchSize = predictions.shape[0] 25 | correlation = torch.matmul(predictions, self.features.t()) 26 | yd, yi = correlation.topk(self.K, dim=1, largest=True, sorted=True) 27 | candidates = self.targets.view(1,-1).expand(batchSize, -1) 28 | retrieval = torch.gather(candidates, 1, yi) 29 | retrieval_one_hot.resize_(batchSize * self.K, self.C).zero_() 30 | retrieval_one_hot.scatter_(1, retrieval.view(-1, 1), 1) 31 | yd_transform = yd.clone().div_(self.temperature).exp_() 32 | probs = torch.sum(torch.mul(retrieval_one_hot.view(batchSize, -1 , self.C), 33 | yd_transform.view(batchSize, -1, 1)), 1) 34 | _, class_preds = probs.sort(1, True) 35 | class_pred = class_preds[:, 0] 36 | 37 | return class_pred 38 | 39 | def knn(self, predictions): 40 | # perform knn 41 | correlation = torch.matmul(predictions, self.features.t()) 42 | sample_pred = torch.argmax(correlation, dim=1) 43 | class_pred = torch.index_select(self.targets, 0, sample_pred) 44 | return class_pred 45 | 46 | def mine_nearest_neighbors(self, topk, calculate_accuracy=True): 47 | # mine the topk nearest neighbors for every sample 48 | import faiss 49 | features = self.features.cpu().numpy() 50 | n, dim = features.shape[0], features.shape[1] 51 | index = faiss.IndexFlatIP(dim) 52 | index = faiss.index_cpu_to_all_gpus(index) 53 | index.add(features) 54 | distances, indices = index.search(features, topk+1) # Sample itself is included 55 | 56 | # evaluate 57 | if calculate_accuracy: 58 | targets = self.targets.cpu().numpy() 59 | neighbor_targets = np.take(targets, indices[:,1:], axis=0) # Exclude sample itself for eval 60 | anchor_targets = np.repeat(targets.reshape(-1,1), topk, axis=1) 61 | accuracy = np.mean(neighbor_targets == anchor_targets) 62 | return indices, accuracy 63 | 64 | else: 65 | return indices 66 | 67 | def reset(self): 68 | self.ptr = 0 69 | 70 | def update(self, features, targets): 71 | b = features.size(0) 72 | 73 | assert(b + self.ptr <= self.n) 74 | 75 | self.features[self.ptr:self.ptr+b].copy_(features.detach()) 76 | self.targets[self.ptr:self.ptr+b].copy_(targets.detach()) 77 | self.ptr += b 78 | 79 | def to(self, device): 80 | self.features = self.features.to(device) 81 | self.targets = self.targets.to(device) 82 | self.device = device 83 | 84 | def cpu(self): 85 | self.to('cpu') 86 | 87 | def cuda(self): 88 | self.to('cuda:0') 89 | -------------------------------------------------------------------------------- /data/imagenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import os 6 | import torch 7 | import torchvision.datasets as datasets 8 | import torch.utils.data as data 9 | from PIL import Image 10 | from utils.mypath import MyPath 11 | from torchvision import transforms as tf 12 | from glob import glob 13 | 14 | 15 | class ImageNet(datasets.ImageFolder): 16 | def __init__(self, root=MyPath.db_root_dir('imagenet'), split='train', transform=None): 17 | super(ImageNet, self).__init__(root=os.path.join(root, 'ILSVRC2012_img_%s' %(split)), 18 | transform=None) 19 | self.transform = transform 20 | self.split = split 21 | self.resize = tf.Resize(256) 22 | 23 | def __len__(self): 24 | return len(self.imgs) 25 | 26 | def __getitem__(self, index): 27 | path, target = self.imgs[index] 28 | with open(path, 'rb') as f: 29 | img = 
Image.open(f).convert('RGB') 30 | im_size = img.size 31 | img = self.resize(img) 32 | 33 | if self.transform is not None: 34 | img = self.transform(img) 35 | 36 | out = {'image': img, 'target': target, 'meta': {'im_size': im_size, 'index': index}} 37 | 38 | return out 39 | 40 | def get_image(self, index): 41 | path, target = self.imgs[index] 42 | with open(path, 'rb') as f: 43 | img = Image.open(f).convert('RGB') 44 | img = self.resize(img) 45 | return img 46 | 47 | 48 | class ImageNetSubset(data.Dataset): 49 | def __init__(self, subset_file, root=MyPath.db_root_dir('imagenet'), split='train', 50 | transform=None): 51 | super(ImageNetSubset, self).__init__() 52 | 53 | self.root = os.path.join(root, 'ILSVRC2012_img_%s' %(split)) 54 | self.transform = transform 55 | self.split = split 56 | 57 | # Read the subset of classes to include (sorted) 58 | with open(subset_file, 'r') as f: 59 | result = f.read().splitlines() 60 | subdirs, class_names = [], [] 61 | for line in result: 62 | subdir, class_name = line.split(' ', 1) 63 | subdirs.append(subdir) 64 | class_names.append(class_name) 65 | 66 | # Gather the files (sorted) 67 | imgs = [] 68 | for i, subdir in enumerate(subdirs): 69 | subdir_path = os.path.join(self.root, subdir) 70 | files = sorted(glob(os.path.join(self.root, subdir, '*.JPEG'))) 71 | for f in files: 72 | imgs.append((f, i)) 73 | self.imgs = imgs 74 | self.classes = class_names 75 | 76 | # Resize 77 | self.resize = tf.Resize(256) 78 | 79 | def get_image(self, index): 80 | path, target = self.imgs[index] 81 | with open(path, 'rb') as f: 82 | img = Image.open(f).convert('RGB') 83 | img = self.resize(img) 84 | return img 85 | 86 | def __len__(self): 87 | return len(self.imgs) 88 | 89 | def __getitem__(self, index): 90 | path, target = self.imgs[index] 91 | with open(path, 'rb') as f: 92 | img = Image.open(f).convert('RGB') 93 | im_size = img.size 94 | img = self.resize(img) 95 | class_name = self.classes[target] 96 | 97 | if self.transform is not None: 98 | img = self.transform(img) 99 | 100 | out = {'image': img, 'target': target, 'meta': {'im_size': im_size, 'index': index, 'class_name': class_name}} 101 | 102 | return out 103 | -------------------------------------------------------------------------------- /data/imagenet_subsets/imagenet_100.txt: -------------------------------------------------------------------------------- 1 | n01558993 robin, American robin, Turdus migratorius 2 | n01601694 water ouzel, dipper 3 | n01669191 box turtle, box tortoise 4 | n01751748 sea snake 5 | n01755581 diamondback, diamondback rattlesnake, Crotalus adamanteus 6 | n01756291 sidewinder, horned rattlesnake, Crotalus cerastes 7 | n01770393 scorpion 8 | n01855672 goose 9 | n01871265 tusker 10 | n02018207 American coot, marsh hen, mud hen, water hen, Fulica americana 11 | n02037110 oystercatcher, oyster catcher 12 | n02058221 albatross, mollymawk 13 | n02087046 toy terrier 14 | n02088632 bluetick 15 | n02093256 Staffordshire bullterrier, Staffordshire bull terrier 16 | n02093754 Border terrier 17 | n02094114 Norfolk terrier 18 | n02096177 cairn, cairn terrier 19 | n02097130 giant schnauzer 20 | n02097298 Scotch terrier, Scottish terrier, Scottie 21 | n02099267 flat-coated retriever 22 | n02100877 Irish setter, red setter 23 | n02104365 schipperke 24 | n02105855 Shetland sheepdog, Shetland sheep dog, Shetland 25 | n02106030 collie 26 | n02106166 Border collie 27 | n02107142 Doberman, Doberman pinscher 28 | n02110341 dalmatian, coach dog, carriage dog 29 | n02114855 coyote, prairie wolf, brush 
wolf, Canis latrans 30 | n02120079 Arctic fox, white fox, Alopex lagopus 31 | n02120505 grey fox, gray fox, Urocyon cinereoargenteus 32 | n02125311 cougar, puma, catamount, mountain lion, painter, panther, Felis concolor 33 | n02128385 leopard, Panthera pardus 34 | n02133161 American black bear, black bear, Ursus americanus, Euarctos americanus 35 | n02277742 ringlet, ringlet butterfly 36 | n02325366 wood rabbit, cottontail, cottontail rabbit 37 | n02364673 guinea pig, Cavia cobaya 38 | n02484975 guenon, guenon monkey 39 | n02489166 proboscis monkey, Nasalis larvatus 40 | n02708093 analog clock 41 | n02747177 ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin 42 | n02835271 bicycle-built-for-two, tandem bicycle, tandem 43 | n02906734 broom 44 | n02909870 bucket, pail 45 | n03085013 computer keyboard, keypad 46 | n03124170 cowboy hat, ten-gallon hat 47 | n03127747 crash helmet 48 | n03160309 dam, dike, dyke 49 | n03255030 dumbbell 50 | n03272010 electric guitar 51 | n03291819 envelope 52 | n03337140 file, file cabinet, filing cabinet 53 | n03450230 gown 54 | n03483316 hand blower, blow dryer, blow drier, hair dryer, hair drier 55 | n03498962 hatchet 56 | n03530642 honeycomb 57 | n03623198 knee pad 58 | n03649909 lawn mower, mower 59 | n03710721 maillot, tank suit 60 | n03717622 manhole cover 61 | n03733281 maze, labyrinth 62 | n03759954 microphone, mike 63 | n03775071 mitten 64 | n03814639 neck brace 65 | n03837869 obelisk 66 | n03838899 oboe, hautboy, hautbois 67 | n03854065 organ, pipe organ 68 | n03929855 pickelhaube 69 | n03930313 picket fence, paling 70 | n03954731 plane, carpenter's plane, woodworking plane 71 | n03956157 planetarium 72 | n03983396 pop bottle, soda bottle 73 | n04004767 printer 74 | n04026417 purse 75 | n04065272 recreational vehicle, RV, R.V. 76 | n04200800 shoe shop, shoe-shop, shoe store 77 | n04209239 shower curtain 78 | n04235860 sleeping bag 79 | n04311004 steel arch bridge 80 | n04325704 stole 81 | n04336792 stretcher 82 | n04346328 stupa, tope 83 | n04380533 table lamp 84 | n04428191 thresher, thrasher, threshing machine 85 | n04443257 tobacco shop, tobacconist shop, tobacconist 86 | n04458633 totem pole 87 | n04483307 trimaran 88 | n04509417 unicycle, monocycle 89 | n04515003 upright, upright piano 90 | n04525305 vending machine 91 | n04554684 washer, automatic washer, washing machine 92 | n04591157 Windsor tie 93 | n04592741 wing 94 | n04606251 wreck 95 | n07583066 guacamole 96 | n07613480 trifle 97 | n07693725 bagel, beigel 98 | n07711569 mashed potato 99 | n07753592 banana 100 | n11879895 rapeseed 101 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | """ This file contains a list of packages and their versions that were used to produce the results. 
""" 2 | - _libgcc_mutex=0.1=main 3 | - blas=1.0=mkl 4 | - bzip2=1.0.8=h7b6447c_0 5 | - ca-certificates=2020.1.1=0 6 | - cairo=1.14.12=h8948797_3 7 | - certifi=2020.4.5.1=py37_0 8 | - cffi=1.14.0=py37h2e261b9_0 9 | - cmake=3.14.0=h52cb24c_0 10 | - cudatoolkit=10.0.130=0 11 | - cycler=0.10.0=py37_0 12 | - dbus=1.13.12=h746ee38_0 13 | - easydict=1.9=py_0 14 | - expat=2.2.6=he6710b0_0 15 | - faiss-gpu=1.6.3=py37h1a5d453_0 16 | - ffmpeg=4.0=hcdf2ecd_0 17 | - fontconfig=2.13.0=h9420a91_0 18 | - freeglut=3.0.0=hf484d3e_5 19 | - freetype=2.9.1=h8a8886c_1 20 | - glib=2.63.1=h5a9c865_0 21 | - graphite2=1.3.13=h23475e2_0 22 | - gst-plugins-base=1.14.0=hbbd80ab_1 23 | - gstreamer=1.14.0=hb453b48_1 24 | - h5py=2.8.0=py37h989c5e5_3 25 | - harfbuzz=1.8.8=hffaf4a1_0 26 | - hdf5=1.10.2=hba1933b_1 27 | - icu=58.2=h9c2bf20_1 28 | - imageio=2.8.0=py_0 29 | - intel-openmp=2020.0=166 30 | - jasper=2.0.14=h07fcdf6_1 31 | - joblib=0.14.1=py_0 32 | - jpeg=9b=h024ee3a_2 33 | - kiwisolver=1.1.0=py37he6710b0_0 34 | - krb5=1.17.1=h173b8e3_0 35 | - ld_impl_linux-64=2.33.1=h53a641e_7 36 | - libcurl=7.69.1=h20c2e04_0 37 | - libedit=3.1.20181209=hc058e9b_0 38 | - libffi=3.2.1=hd88cf55_4 39 | - libgcc-ng=9.1.0=hdf63c60_0 40 | - libgfortran-ng=7.3.0=hdf63c60_0 41 | - libglu=9.0.0=hf484d3e_1 42 | - libopencv=3.4.2=hb342d67_1 43 | - libopus=1.3.1=h7b6447c_0 44 | - libpng=1.6.37=hbc83047_0 45 | - libprotobuf=3.11.4=hd408876_0 46 | - libssh2=1.9.0=h1ba5d50_1 47 | - libstdcxx-ng=9.1.0=hdf63c60_0 48 | - libtiff=4.1.0=h2733197_0 49 | - libuuid=1.0.3=h1bed415_2 50 | - libvpx=1.7.0=h439df22_0 51 | - libxcb=1.13=h1bed415_1 52 | - libxml2=2.9.9=hea5a465_1 53 | - matplotlib=3.1.3=py37_0 54 | - matplotlib-base=3.1.3=py37hef1b27d_0 55 | - mkl=2020.0=166 56 | - mkl-service=2.3.0=py37he904b0f_0 57 | - mkl_fft=1.0.15=py37ha843d7b_0 58 | - mkl_random=1.1.0=py37hd6b4f25_0 59 | - ncurses=6.2=he6710b0_0 60 | - ninja=1.9.0=py37hfd86e86_0 61 | - numpy=1.18.1=py37h4f9e942_0 62 | - numpy-base=1.18.1=py37hde5b4d6_1 63 | - olefile=0.46=py_0 64 | - opencv=3.4.2=py37h6fd60c2_1 65 | - openssl=1.1.1g=h7b6447c_0 66 | - pcre=8.43=he6710b0_0 67 | - pillow=7.0.0=py37hb39fc2d_0 68 | - pip=20.0.2=py37_1 69 | - pixman=0.38.0=h7b6447c_0 70 | - protobuf=3.11.4=py37he6710b0_0 71 | - py-opencv=3.4.2=py37hb342d67_1 72 | - pycparser=2.20=py_0 73 | - pyparsing=2.4.6=py_0 74 | - pyqt=5.9.2=py37h05f1152_2 75 | - python=3.7.7=hcf32534_0_cpython 76 | - python-dateutil=2.8.1=py_0 77 | - pytorch=1.4.0=py3.7_cuda10.0.130_cudnn7.6.3_0 78 | - pyyaml=5.3.1=py37h7b6447c_0 79 | - qt=5.9.7=h5867ecd_1 80 | - readline=8.0=h7b6447c_0 81 | - rhash=1.3.8=h1ba5d50_0 82 | - scikit-learn=0.22.1=py37hd81dba3_0 83 | - scipy=1.4.1=py37h0b6359f_0 84 | - setuptools=46.1.3=py37_0 85 | - sip=4.19.8=py37hf484d3e_0 86 | - six=1.14.0=py37_0 87 | - sqlite=3.31.1=h7b6447c_0 88 | - swig=3.0.12=h38cdd7d_3 89 | - tensorboardx=2.0=py_0 90 | - termcolor=1.1.0=py37_1 91 | - tk=8.6.8=hbc83047_0 92 | - torchvision=0.5.0=py37_cu100 93 | - tornado=6.0.4=py37h7b6447c_1 94 | - typing=3.6.4=py37_0 95 | - wheel=0.34.2=py37_0 96 | - xz=5.2.4=h14c3975_4 97 | - yaml=0.1.7=had09818_2 98 | - zlib=1.2.11=h7b6447c_3 99 | - zstd=1.3.7=h0b5b093_0 100 | - pip: 101 | - blis==0.4.1 102 | - catalogue==1.0.0 103 | - chardet==3.0.4 104 | - cymem==2.0.3 105 | - en-core-web-sm==2.2.5 106 | - idna==2.9 107 | - importlib-metadata==1.6.0 108 | - murmurhash==1.0.2 109 | - plac==1.1.3 110 | - preshed==3.0.2 111 | - requests==2.23.0 112 | - spacy==2.2.4 113 | - srsly==1.0.2 114 | - thinc==7.4.0 115 | - tqdm==4.45.0 116 | - 
urllib3==1.25.8 117 | - wasabi==0.6.0 118 | - zipp==3.1.0 119 | -------------------------------------------------------------------------------- /tutorial_nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import argparse 6 | import os 7 | import numpy as np 8 | import torch 9 | 10 | from utils.config import create_config 11 | from utils.common_config import get_model, get_train_dataset, \ 12 | get_val_dataset, \ 13 | get_val_dataloader, \ 14 | get_val_transformations \ 15 | 16 | from utils.memory import MemoryBank 17 | from utils.train_utils import simclr_train 18 | from utils.utils import fill_memory_bank 19 | from termcolor import colored 20 | 21 | # Parser 22 | parser = argparse.ArgumentParser(description='Eval_nn') 23 | parser.add_argument('--config_env', 24 | help='Config file for the environment') 25 | parser.add_argument('--config_exp', 26 | help='Config file for the experiment') 27 | args = parser.parse_args() 28 | 29 | def main(): 30 | 31 | # Retrieve config file 32 | p = create_config(args.config_env, args.config_exp) 33 | print(colored(p, 'red')) 34 | 35 | # Model 36 | print(colored('Retrieve model', 'blue')) 37 | model = get_model(p) 38 | print('Model is {}'.format(model.__class__.__name__)) 39 | print('Model parameters: {:.2f}M'.format(sum(p.numel() for p in model.parameters()) / 1e6)) 40 | print(model) 41 | model = model.cuda() 42 | 43 | # CUDNN 44 | print(colored('Set CuDNN benchmark', 'blue')) 45 | torch.backends.cudnn.benchmark = True 46 | 47 | # Dataset 48 | val_transforms = get_val_transformations(p) 49 | print('Validation transforms:', val_transforms) 50 | val_dataset = get_val_dataset(p, val_transforms) 51 | val_dataloader = get_val_dataloader(p, val_dataset) 52 | print('Dataset contains {} val samples'.format(len(val_dataset))) 53 | 54 | # Memory Bank 55 | print(colored('Build MemoryBank', 'blue')) 56 | base_dataset = get_train_dataset(p, val_transforms, split='train') # Dataset w/o augs for knn eval 57 | base_dataloader = get_val_dataloader(p, base_dataset) 58 | memory_bank_base = MemoryBank(len(base_dataset), 59 | p['model_kwargs']['features_dim'], 60 | p['num_classes'], p['criterion_kwargs']['temperature']) 61 | memory_bank_base.cuda() 62 | memory_bank_val = MemoryBank(len(val_dataset), 63 | p['model_kwargs']['features_dim'], 64 | p['num_classes'], p['criterion_kwargs']['temperature']) 65 | memory_bank_val.cuda() 66 | 67 | # Checkpoint 68 | assert os.path.exists(p['pretext_checkpoint']) 69 | print(colored('Restart from checkpoint {}'.format(p['pretext_checkpoint']), 'blue')) 70 | checkpoint = torch.load(p['pretext_checkpoint'], map_location='cpu') 71 | model.load_state_dict(checkpoint) 72 | model.cuda() 73 | 74 | # Save model 75 | torch.save(model.state_dict(), p['pretext_model']) 76 | 77 | # Mine the topk nearest neighbors at the very end (Train) 78 | # These will be served as input to the SCAN loss. 
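# Note: mine_nearest_neighbors() in utils/memory.py queries faiss for topk+1 neighbors,
# since every sample is returned as its own nearest neighbor. The arrays saved below
# therefore have shape [len(dataset), topk+1], with column 0 holding the sample itself.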
79 | print(colored('Fill memory bank for mining the nearest neighbors (train) ...', 'blue')) 80 | fill_memory_bank(base_dataloader, model, memory_bank_base) 81 | topk = 20 82 | print('Mine the nearest neighbors (Top-%d)' %(topk)) 83 | indices, acc = memory_bank_base.mine_nearest_neighbors(topk) 84 | print('Accuracy of top-%d nearest neighbors on train set is %.2f' %(topk, 100*acc)) 85 | np.save(p['topk_neighbors_train_path'], indices) 86 | 87 | # Mine the topk nearest neighbors at the very end (Val) 88 | # These will be used for validation. 89 | print(colored('Fill memory bank for mining the nearest neighbors (val) ...', 'blue')) 90 | fill_memory_bank(val_dataloader, model, memory_bank_val) 91 | topk = 5 92 | print('Mine the nearest neighbors (Top-%d)' %(topk)) 93 | indices, acc = memory_bank_val.mine_nearest_neighbors(topk) 94 | print('Accuracy of top-%d nearest neighbors on val set is %.2f' %(topk, 100*acc)) 95 | np.save(p['topk_neighbors_val_path'], indices) 96 | 97 | 98 | if __name__ == '__main__': 99 | main() 100 | -------------------------------------------------------------------------------- /data/augment.py: -------------------------------------------------------------------------------- 1 | # List of augmentations based on randaugment 2 | import random 3 | 4 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw 5 | import numpy as np 6 | import torch 7 | from torchvision.transforms.transforms import Compose 8 | 9 | random_mirror = True 10 | 11 | def ShearX(img, v): 12 | if random_mirror and random.random() > 0.5: 13 | v = -v 14 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 15 | 16 | def ShearY(img, v): 17 | if random_mirror and random.random() > 0.5: 18 | v = -v 19 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 20 | 21 | def Identity(img, v): 22 | return img 23 | 24 | def TranslateX(img, v): 25 | if random_mirror and random.random() > 0.5: 26 | v = -v 27 | v = v * img.size[0] 28 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 29 | 30 | def TranslateY(img, v): 31 | if random_mirror and random.random() > 0.5: 32 | v = -v 33 | v = v * img.size[1] 34 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 35 | 36 | def TranslateXAbs(img, v): 37 | if random.random() > 0.5: 38 | v = -v 39 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 40 | 41 | def TranslateYAbs(img, v): 42 | if random.random() > 0.5: 43 | v = -v 44 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 45 | 46 | def Rotate(img, v): 47 | if random_mirror and random.random() > 0.5: 48 | v = -v 49 | return img.rotate(v) 50 | 51 | def AutoContrast(img, _): 52 | return PIL.ImageOps.autocontrast(img) 53 | 54 | def Invert(img, _): 55 | return PIL.ImageOps.invert(img) 56 | 57 | def Equalize(img, _): 58 | return PIL.ImageOps.equalize(img) 59 | 60 | def Solarize(img, v): 61 | return PIL.ImageOps.solarize(img, v) 62 | 63 | def Posterize(img, v): 64 | v = int(v) 65 | return PIL.ImageOps.posterize(img, v) 66 | 67 | def Contrast(img, v): 68 | return PIL.ImageEnhance.Contrast(img).enhance(v) 69 | 70 | def Color(img, v): 71 | return PIL.ImageEnhance.Color(img).enhance(v) 72 | 73 | def Brightness(img, v): 74 | return PIL.ImageEnhance.Brightness(img).enhance(v) 75 | 76 | def Sharpness(img, v): 77 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 78 | 79 | def augment_list(): 80 | l = [ 81 | (Identity, 0, 1), 82 | (AutoContrast, 0, 1), 83 | (Equalize, 0, 1), 84 | (Rotate, -30, 30), 85 | (Solarize, 0, 
256), 86 | (Color, 0.05, 0.95), 87 | (Contrast, 0.05, 0.95), 88 | (Brightness, 0.05, 0.95), 89 | (Sharpness, 0.05, 0.95), 90 | (ShearX, -0.1, 0.1), 91 | (TranslateX, -0.1, 0.1), 92 | (TranslateY, -0.1, 0.1), 93 | (Posterize, 4, 8), 94 | (ShearY, -0.1, 0.1), 95 | ] 96 | return l 97 | 98 | 99 | augment_dict = {fn.__name__: (fn, v1, v2) for fn, v1, v2 in augment_list()} 100 | 101 | class Augment: 102 | def __init__(self, n): 103 | self.n = n 104 | self.augment_list = augment_list() 105 | 106 | def __call__(self, img): 107 | ops = random.choices(self.augment_list, k=self.n) 108 | for op, minval, maxval in ops: 109 | val = (random.random()) * float(maxval - minval) + minval 110 | img = op(img, val) 111 | 112 | return img 113 | 114 | def get_augment(name): 115 | return augment_dict[name] 116 | 117 | def apply_augment(img, name, level): 118 | augment_fn, low, high = get_augment(name) 119 | return augment_fn(img.copy(), level * (high - low) + low) 120 | 121 | class Cutout(object): 122 | def __init__(self, n_holes, length, random=False): 123 | self.n_holes = n_holes 124 | self.length = length 125 | self.random = random 126 | 127 | def __call__(self, img): 128 | h = img.size(1) 129 | w = img.size(2) 130 | length = random.randint(1, self.length) 131 | mask = np.ones((h, w), np.float32) 132 | 133 | for n in range(self.n_holes): 134 | y = np.random.randint(h) 135 | x = np.random.randint(w) 136 | 137 | y1 = np.clip(y - length // 2, 0, h) 138 | y2 = np.clip(y + length // 2, 0, h) 139 | x1 = np.clip(x - length // 2, 0, w) 140 | x2 = np.clip(x + length // 2, 0, w) 141 | 142 | mask[y1: y2, x1: x2] = 0. 143 | 144 | mask = torch.from_numpy(mask) 145 | mask = mask.expand_as(img) 146 | img = img * mask 147 | 148 | return img 149 | -------------------------------------------------------------------------------- /moco.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import argparse 6 | import os 7 | import torch 8 | import numpy as np 9 | 10 | from utils.config import create_config 11 | from utils.common_config import get_model, get_train_dataset,\ 12 | get_val_dataset, get_val_dataloader, get_val_transformations 13 | from utils.memory import MemoryBank 14 | from utils.utils import fill_memory_bank 15 | from termcolor import colored 16 | 17 | # Parser 18 | parser = argparse.ArgumentParser(description='MoCo') 19 | parser.add_argument('--config_env', 20 | help='Config file for the environment') 21 | parser.add_argument('--config_exp', 22 | help='Config file for the experiment') 23 | args = parser.parse_args() 24 | 25 | def main(): 26 | # Retrieve config file 27 | p = create_config(args.config_env, args.config_exp) 28 | print(colored(p, 'red')) 29 | 30 | 31 | # Model 32 | print(colored('Retrieve model', 'blue')) 33 | model = get_model(p) 34 | print('Model is {}'.format(model.__class__.__name__)) 35 | print(model) 36 | model = torch.nn.DataParallel(model) 37 | model = model.cuda() 38 | 39 | 40 | # CUDNN 41 | print(colored('Set CuDNN benchmark', 'blue')) 42 | torch.backends.cudnn.benchmark = True 43 | 44 | 45 | # Dataset 46 | print(colored('Retrieve dataset', 'blue')) 47 | transforms = get_val_transformations(p) 48 | train_dataset = get_train_dataset(p, transforms) 49 | val_dataset = get_val_dataset(p, transforms) 50 | train_dataloader = get_val_dataloader(p, train_dataset) 51 | val_dataloader = get_val_dataloader(p, 
val_dataset) 52 | print('Dataset contains {}/{} train/val samples'.format(len(train_dataset), len(val_dataset))) 53 | 54 | 55 | # Memory Bank 56 | print(colored('Build MemoryBank', 'blue')) 57 | memory_bank_train = MemoryBank(len(train_dataset), 2048, p['num_classes'], p['temperature']) 58 | memory_bank_train.cuda() 59 | memory_bank_val = MemoryBank(len(val_dataset), 2048, p['num_classes'], p['temperature']) 60 | memory_bank_val.cuda() 61 | 62 | 63 | # Load the official MoCoV2 checkpoint 64 | print(colored('Downloading moco v2 checkpoint', 'blue')) 65 | os.system('wget -L https://dl.fbaipublicfiles.com/moco/moco_checkpoints/moco_v2_800ep/moco_v2_800ep_pretrain.pth.tar') 66 | moco_state = torch.load('moco_v2_800ep_pretrain.pth.tar', map_location='cpu') 67 | 68 | 69 | # Transfer moco weights 70 | print(colored('Transfer MoCo weights to model', 'blue')) 71 | new_state_dict = {} 72 | state_dict = moco_state['state_dict'] 73 | for k in list(state_dict.keys()): 74 | # Copy backbone weights 75 | if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'): 76 | new_k = 'module.backbone.' + k[len('module.encoder_q.'):] 77 | new_state_dict[new_k] = state_dict[k] 78 | 79 | # Copy mlp weights 80 | elif k.startswith('module.encoder_q.fc'): 81 | new_k = 'module.contrastive_head.' + k[len('module.encoder_q.fc.'):] 82 | new_state_dict[new_k] = state_dict[k] 83 | 84 | else: 85 | raise ValueError('Unexpected key {}'.format(k)) 86 | 87 | model.load_state_dict(new_state_dict) 88 | os.system('rm -rf moco_v2_800ep_pretrain.pth.tar') 89 | 90 | 91 | # Save final model 92 | print(colored('Save pretext model', 'blue')) 93 | torch.save(model.module.state_dict(), p['pretext_model']) 94 | model.module.contrastive_head = torch.nn.Identity() # In this case, we mine the neighbors before the MLP. 95 | 96 | 97 | # Mine the topk nearest neighbors (Train) 98 | # These will be used for training with the SCAN-Loss. 99 | topk = 50 100 | print(colored('Mine the nearest neighbors (Train)(Top-%d)' %(topk), 'blue')) 101 | transforms = get_val_transformations(p) 102 | train_dataset = get_train_dataset(p, transforms) 103 | fill_memory_bank(train_dataloader, model, memory_bank_train) 104 | indices, acc = memory_bank_train.mine_nearest_neighbors(topk) 105 | print('Accuracy of top-%d nearest neighbors on train set is %.2f' %(topk, 100*acc)) 106 | np.save(p['topk_neighbors_train_path'], indices) 107 | 108 | 109 | # Mine the topk nearest neighbors (Validation) 110 | # These will be used for validation. 
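# A smaller k than for training (5 instead of 50) suffices here: these neighbors
# only build the validation pairs used to monitor the SCAN loss, and are never
# used as training pairs.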
111 | topk = 5 112 | print(colored('Mine the nearest neighbors (Val)(Top-%d)' %(topk), 'blue')) 113 | fill_memory_bank(val_dataloader, model, memory_bank_val) 114 | print('Mine the neighbors') 115 | indices, acc = memory_bank_val.mine_nearest_neighbors(topk) 116 | print('Accuracy of top-%d nearest neighbors on val set is %.2f' %(topk, 100*acc)) 117 | np.save(p['topk_neighbors_val_path'], indices) 118 | 119 | 120 | if __name__ == '__main__': 121 | main() 122 | -------------------------------------------------------------------------------- /selflabel.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import argparse 6 | import os 7 | import torch 8 | 9 | from utils.config import create_config 10 | from utils.common_config import get_train_dataset, get_train_transformations,\ 11 | get_val_dataset, get_val_transformations,\ 12 | get_train_dataloader, get_val_dataloader,\ 13 | get_optimizer, get_model, adjust_learning_rate,\ 14 | get_criterion 15 | from utils.ema import EMA 16 | from utils.evaluate_utils import get_predictions, hungarian_evaluate 17 | from utils.train_utils import selflabel_train 18 | from termcolor import colored 19 | 20 | # Parser 21 | parser = argparse.ArgumentParser(description='Self-labeling') 22 | parser.add_argument('--config_env', 23 | help='Config file for the environment') 24 | parser.add_argument('--config_exp', 25 | help='Config file for the experiment') 26 | args = parser.parse_args() 27 | 28 | def main(): 29 | # Retrieve config file 30 | p = create_config(args.config_env, args.config_exp) 31 | print(colored(p, 'red')) 32 | 33 | # Get model 34 | print(colored('Retrieve model', 'blue')) 35 | model = get_model(p, p['scan_model']) 36 | print(model) 37 | model = torch.nn.DataParallel(model) 38 | model = model.cuda() 39 | 40 | # Get criterion 41 | print(colored('Get loss', 'blue')) 42 | criterion = get_criterion(p) 43 | criterion.cuda() 44 | print(criterion) 45 | 46 | # CUDNN 47 | print(colored('Set CuDNN benchmark', 'blue')) 48 | torch.backends.cudnn.benchmark = True 49 | 50 | # Optimizer 51 | print(colored('Retrieve optimizer', 'blue')) 52 | optimizer = get_optimizer(p, model) 53 | print(optimizer) 54 | 55 | # Dataset 56 | print(colored('Retrieve dataset', 'blue')) 57 | 58 | # Transforms 59 | strong_transforms = get_train_transformations(p) 60 | val_transforms = get_val_transformations(p) 61 | train_dataset = get_train_dataset(p, {'standard': val_transforms, 'augment': strong_transforms}, 62 | split='train', to_augmented_dataset=True) 63 | train_dataloader = get_train_dataloader(p, train_dataset) 64 | val_dataset = get_val_dataset(p, val_transforms) 65 | val_dataloader = get_val_dataloader(p, val_dataset) 66 | print(colored('Train samples %d - Val samples %d' %(len(train_dataset), len(val_dataset)), 'yellow')) 67 | 68 | # Checkpoint 69 | if os.path.exists(p['selflabel_checkpoint']): 70 | print(colored('Restart from checkpoint {}'.format(p['selflabel_checkpoint']), 'blue')) 71 | checkpoint = torch.load(p['selflabel_checkpoint'], map_location='cpu') 72 | model.load_state_dict(checkpoint['model']) 73 | optimizer.load_state_dict(checkpoint['optimizer']) 74 | start_epoch = checkpoint['epoch'] 75 | 76 | else: 77 | print(colored('No checkpoint file at {}'.format(p['selflabel_checkpoint']), 'blue')) 78 | start_epoch = 0 79 | 80 | # EMA 81 | if p['use_ema']: 82 | ema = EMA(model, 
alpha=p['ema_alpha']) 83 | else: 84 | ema = None 85 | 86 | # Main loop 87 | print(colored('Starting main loop', 'blue')) 88 | 89 | for epoch in range(start_epoch, p['epochs']): 90 | print(colored('Epoch %d/%d' %(epoch+1, p['epochs']), 'yellow')) 91 | print(colored('-'*10, 'yellow')) 92 | 93 | # Adjust lr 94 | lr = adjust_learning_rate(p, optimizer, epoch) 95 | print('Adjusted learning rate to {:.5f}'.format(lr)) 96 | 97 | # Perform self-labeling 98 | print('Train ...') 99 | selflabel_train(train_dataloader, model, criterion, optimizer, epoch, ema=ema) 100 | 101 | # Evaluate (To monitor progress - Not for validation) 102 | print('Evaluate ...') 103 | predictions = get_predictions(p, val_dataloader, model) 104 | clustering_stats = hungarian_evaluate(0, predictions, compute_confusion_matrix=False) 105 | print(clustering_stats) 106 | 107 | # Checkpoint 108 | print('Checkpoint ...') 109 | torch.save({'optimizer': optimizer.state_dict(), 'model': model.state_dict(), 110 | 'epoch': epoch + 1}, p['selflabel_checkpoint']) 111 | torch.save(model.module.state_dict(), p['selflabel_model']) 112 | 113 | # Evaluate and save the final model 114 | print(colored('Evaluate model at the end', 'blue')) 115 | predictions = get_predictions(p, val_dataloader, model) 116 | clustering_stats = hungarian_evaluate(0, predictions, 117 | class_names=val_dataset.classes, 118 | compute_confusion_matrix=True, 119 | confusion_matrix_file=os.path.join(p['selflabel_dir'], 'confusion_matrix.png')) 120 | print(clustering_stats) 121 | torch.save(model.module.state_dict(), p['selflabel_model']) 122 | 123 | 124 | if __name__ == "__main__": 125 | main() 126 | -------------------------------------------------------------------------------- /utils/train_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import torch 6 | import numpy as np 7 | from utils.utils import AverageMeter, ProgressMeter 8 | 9 | 10 | def simclr_train(train_loader, model, criterion, optimizer, epoch): 11 | """ 12 | Train according to the scheme from SimCLR 13 | https://arxiv.org/abs/2002.05709 14 | """ 15 | losses = AverageMeter('Loss', ':.4e') 16 | progress = ProgressMeter(len(train_loader), 17 | [losses], 18 | prefix="Epoch: [{}]".format(epoch)) 19 | 20 | model.train() 21 | 22 | for i, batch in enumerate(train_loader): 23 | images = batch['image'] 24 | images_augmented = batch['image_augmented'] 25 | b, c, h, w = images.size() 26 | input_ = torch.cat([images.unsqueeze(1), images_augmented.unsqueeze(1)], dim=1) 27 | input_ = input_.view(-1, c, h, w) 28 | input_ = input_.cuda(non_blocking=True) 29 | targets = batch['target'].cuda(non_blocking=True) 30 | 31 | output = model(input_).view(b, 2, -1) 32 | loss = criterion(output) 33 | losses.update(loss.item()) 34 | 35 | optimizer.zero_grad() 36 | loss.backward() 37 | optimizer.step() 38 | 39 | if i % 25 == 0: 40 | progress.display(i) 41 | 42 | 43 | def scan_train(train_loader, model, criterion, optimizer, epoch, update_cluster_head_only=False): 44 | """ 45 | Train w/ SCAN-Loss 46 | """ 47 | total_losses = AverageMeter('Total Loss', ':.4e') 48 | consistency_losses = AverageMeter('Consistency Loss', ':.4e') 49 | entropy_losses = AverageMeter('Entropy', ':.4e') 50 | progress = ProgressMeter(len(train_loader), 51 | [total_losses, consistency_losses, entropy_losses], 52 | prefix="Epoch: [{}]".format(epoch)) 53 | 54 | 
if update_cluster_head_only: 55 | model.eval() # No need to update BN 56 | else: 57 | model.train() # Update BN 58 | 59 | for i, batch in enumerate(train_loader): 60 | # Forward pass 61 | anchors = batch['anchor'].cuda(non_blocking=True) 62 | neighbors = batch['neighbor'].cuda(non_blocking=True) 63 | 64 | if update_cluster_head_only: # Only calculate gradient for backprop of linear layer 65 | with torch.no_grad(): 66 | anchors_features = model(anchors, forward_pass='backbone') 67 | neighbors_features = model(neighbors, forward_pass='backbone') 68 | anchors_output = model(anchors_features, forward_pass='head') 69 | neighbors_output = model(neighbors_features, forward_pass='head') 70 | 71 | else: # Calculate gradient for backprop of complete network 72 | anchors_output = model(anchors) 73 | neighbors_output = model(neighbors) 74 | 75 | # Loss for every head 76 | total_loss, consistency_loss, entropy_loss = [], [], [] 77 | for anchors_output_subhead, neighbors_output_subhead in zip(anchors_output, neighbors_output): 78 | total_loss_, consistency_loss_, entropy_loss_ = criterion(anchors_output_subhead, 79 | neighbors_output_subhead) 80 | total_loss.append(total_loss_) 81 | consistency_loss.append(consistency_loss_) 82 | entropy_loss.append(entropy_loss_) 83 | 84 | # Register the mean loss and backprop the total loss to cover all subheads 85 | total_losses.update(np.mean([v.item() for v in total_loss])) 86 | consistency_losses.update(np.mean([v.item() for v in consistency_loss])) 87 | entropy_losses.update(np.mean([v.item() for v in entropy_loss])) 88 | 89 | total_loss = torch.sum(torch.stack(total_loss, dim=0)) 90 | 91 | optimizer.zero_grad() 92 | total_loss.backward() 93 | optimizer.step() 94 | 95 | if i % 25 == 0: 96 | progress.display(i) 97 | 98 | 99 | def selflabel_train(train_loader, model, criterion, optimizer, epoch, ema=None): 100 | """ 101 | Self-labeling based on confident samples 102 | """ 103 | losses = AverageMeter('Loss', ':.4e') 104 | progress = ProgressMeter(len(train_loader), [losses], 105 | prefix="Epoch: [{}]".format(epoch)) 106 | model.train() 107 | 108 | for i, batch in enumerate(train_loader): 109 | images = batch['image'].cuda(non_blocking=True) 110 | images_augmented = batch['image_augmented'].cuda(non_blocking=True) 111 | 112 | with torch.no_grad(): 113 | output = model(images)[0] 114 | output_augmented = model(images_augmented)[0] 115 | 116 | loss = criterion(output, output_augmented) 117 | losses.update(loss.item()) 118 | 119 | optimizer.zero_grad() 120 | loss.backward() 121 | optimizer.step() 122 | 123 | if ema is not None: # Apply EMA to update the weights of the network 124 | ema.update_params(model) 125 | ema.apply_shadow(model) 126 | 127 | if i % 25 == 0: 128 | progress.display(i) 129 | -------------------------------------------------------------------------------- /models/resnet_cifar.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is based on the Torchvision repository, which was licensed under the BSD 3-Clause. 
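This is the CIFAR variant of ResNet: the stem is a single 3x3 convolution with stride 1 and the max-pooling layer is omitted, so 32x32 inputs are not downsampled too early (compare models/resnet_stl.py, which keeps a max-pooling stem for the 96x96 STL-10 images).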
3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class BasicBlock(nn.Module): 10 | expansion = 1 11 | 12 | def __init__(self, in_planes, planes, stride=1, is_last=False): 13 | super(BasicBlock, self).__init__() 14 | self.is_last = is_last 15 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 16 | self.bn1 = nn.BatchNorm2d(planes) 17 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 18 | self.bn2 = nn.BatchNorm2d(planes) 19 | 20 | self.shortcut = nn.Sequential() 21 | if stride != 1 or in_planes != self.expansion * planes: 22 | self.shortcut = nn.Sequential( 23 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 24 | nn.BatchNorm2d(self.expansion * planes) 25 | ) 26 | 27 | def forward(self, x): 28 | out = F.relu(self.bn1(self.conv1(x))) 29 | out = self.bn2(self.conv2(out)) 30 | out += self.shortcut(x) 31 | preact = out 32 | out = F.relu(out) 33 | if self.is_last: 34 | return out, preact 35 | else: 36 | return out 37 | 38 | 39 | class Bottleneck(nn.Module): 40 | expansion = 4 41 | 42 | def __init__(self, in_planes, planes, stride=1, is_last=False): 43 | super(Bottleneck, self).__init__() 44 | self.is_last = is_last 45 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 46 | self.bn1 = nn.BatchNorm2d(planes) 47 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 48 | self.bn2 = nn.BatchNorm2d(planes) 49 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 50 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 51 | 52 | self.shortcut = nn.Sequential() 53 | if stride != 1 or in_planes != self.expansion * planes: 54 | self.shortcut = nn.Sequential( 55 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 56 | nn.BatchNorm2d(self.expansion * planes) 57 | ) 58 | 59 | def forward(self, x): 60 | out = F.relu(self.bn1(self.conv1(x))) 61 | out = F.relu(self.bn2(self.conv2(out))) 62 | out = self.bn3(self.conv3(out)) 63 | out += self.shortcut(x) 64 | preact = out 65 | out = F.relu(out) 66 | if self.is_last: 67 | return out, preact 68 | else: 69 | return out 70 | 71 | 72 | class ResNet(nn.Module): 73 | def __init__(self, block, num_blocks, in_channel=3, zero_init_residual=False): 74 | super(ResNet, self).__init__() 75 | self.in_planes = 64 76 | 77 | self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1, 78 | bias=False) 79 | self.bn1 = nn.BatchNorm2d(64) 80 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 81 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 82 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 83 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 84 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 85 | 86 | for m in self.modules(): 87 | if isinstance(m, nn.Conv2d): 88 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 89 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 90 | nn.init.constant_(m.weight, 1) 91 | nn.init.constant_(m.bias, 0) 92 | 93 | # Zero-initialize the last BN in each residual branch, 94 | # so that the residual branch starts with zeros, and each residual block behaves 95 | # like an identity. 
This improves the model by 0.2~0.3% according to: 96 | # https://arxiv.org/abs/1706.02677 97 | if zero_init_residual: 98 | for m in self.modules(): 99 | if isinstance(m, Bottleneck): 100 | nn.init.constant_(m.bn3.weight, 0) 101 | elif isinstance(m, BasicBlock): 102 | nn.init.constant_(m.bn2.weight, 0) 103 | 104 | def _make_layer(self, block, planes, num_blocks, stride): 105 | strides = [stride] + [1] * (num_blocks - 1) 106 | layers = [] 107 | for i in range(num_blocks): 108 | stride = strides[i] 109 | layers.append(block(self.in_planes, planes, stride)) 110 | self.in_planes = planes * block.expansion 111 | return nn.Sequential(*layers) 112 | 113 | def forward(self, x): 114 | out = F.relu(self.bn1(self.conv1(x))) 115 | out = self.layer1(out) 116 | out = self.layer2(out) 117 | out = self.layer3(out) 118 | out = self.layer4(out) 119 | out = self.avgpool(out) 120 | out = torch.flatten(out, 1) 121 | return out 122 | 123 | 124 | def resnet18(**kwargs): 125 | return {'backbone': ResNet(BasicBlock, [2, 2, 2, 2], **kwargs), 'dim': 512} 126 | -------------------------------------------------------------------------------- /models/resnet_stl.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is based on the Torchvision repository, which was licensed under the BSD 3-Clause. 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class BasicBlock(nn.Module): 10 | expansion = 1 11 | 12 | def __init__(self, in_planes, planes, stride=1, is_last=False): 13 | super(BasicBlock, self).__init__() 14 | self.is_last = is_last 15 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 16 | self.bn1 = nn.BatchNorm2d(planes) 17 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 18 | self.bn2 = nn.BatchNorm2d(planes) 19 | 20 | self.shortcut = nn.Sequential() 21 | if stride != 1 or in_planes != self.expansion * planes: 22 | self.shortcut = nn.Sequential( 23 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 24 | nn.BatchNorm2d(self.expansion * planes) 25 | ) 26 | 27 | def forward(self, x): 28 | out = F.relu(self.bn1(self.conv1(x))) 29 | out = self.bn2(self.conv2(out)) 30 | out += self.shortcut(x) 31 | preact = out 32 | out = F.relu(out) 33 | if self.is_last: 34 | return out, preact 35 | else: 36 | return out 37 | 38 | 39 | class Bottleneck(nn.Module): 40 | expansion = 4 41 | 42 | def __init__(self, in_planes, planes, stride=1, is_last=False): 43 | super(Bottleneck, self).__init__() 44 | self.is_last = is_last 45 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 46 | self.bn1 = nn.BatchNorm2d(planes) 47 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 48 | self.bn2 = nn.BatchNorm2d(planes) 49 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 50 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 51 | 52 | self.shortcut = nn.Sequential() 53 | if stride != 1 or in_planes != self.expansion * planes: 54 | self.shortcut = nn.Sequential( 55 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 56 | nn.BatchNorm2d(self.expansion * planes) 57 | ) 58 | 59 | def forward(self, x): 60 | out = F.relu(self.bn1(self.conv1(x))) 61 | out = F.relu(self.bn2(self.conv2(out))) 62 | out = self.bn3(self.conv3(out)) 63 | out += self.shortcut(x) 64 | preact = out 65 | out = F.relu(out) 66 | 
if self.is_last: 67 | return out, preact 68 | else: 69 | return out 70 | 71 | 72 | class ResNet(nn.Module): 73 | def __init__(self, block, num_blocks, in_channel=3, zero_init_residual=False): 74 | super(ResNet, self).__init__() 75 | self.in_planes = 64 76 | 77 | self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1, 78 | bias=False) 79 | self.bn1 = nn.BatchNorm2d(64) 80 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, padding=1) 81 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 82 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 83 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 84 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 85 | self.avgpool = nn.AvgPool2d(7, stride=1) 86 | 87 | for m in self.modules(): 88 | if isinstance(m, nn.Conv2d): 89 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 90 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 91 | nn.init.constant_(m.weight, 1) 92 | nn.init.constant_(m.bias, 0) 93 | 94 | # Zero-initialize the last BN in each residual branch, 95 | # so that the residual branch starts with zeros, and each residual block behaves 96 | # like an identity. This improves the model by 0.2~0.3% according to: 97 | # https://arxiv.org/abs/1706.02677 98 | if zero_init_residual: 99 | for m in self.modules(): 100 | if isinstance(m, Bottleneck): 101 | nn.init.constant_(m.bn3.weight, 0) 102 | elif isinstance(m, BasicBlock): 103 | nn.init.constant_(m.bn2.weight, 0) 104 | 105 | def _make_layer(self, block, planes, num_blocks, stride): 106 | strides = [stride] + [1] * (num_blocks - 1) 107 | layers = [] 108 | for i in range(num_blocks): 109 | stride = strides[i] 110 | layers.append(block(self.in_planes, planes, stride)) 111 | self.in_planes = planes * block.expansion 112 | return nn.Sequential(*layers) 113 | 114 | def forward(self, x): 115 | out = self.maxpool(F.relu(self.bn1(self.conv1(x)))) 116 | out = self.layer1(out) 117 | out = self.layer2(out) 118 | out = self.layer3(out) 119 | out = self.layer4(out) 120 | out = self.avgpool(out) 121 | out = torch.flatten(out, 1) 122 | return out 123 | 124 | 125 | def resnet18(**kwargs): 126 | return {'backbone': ResNet(BasicBlock, [2, 2, 2, 2], **kwargs), 'dim': 512} 127 | -------------------------------------------------------------------------------- /losses/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | EPS=1e-8 9 | 10 | 11 | class MaskedCrossEntropyLoss(nn.Module): 12 | def __init__(self): 13 | super(MaskedCrossEntropyLoss, self).__init__() 14 | 15 | def forward(self, input, target, mask, weight, reduction='mean'): 16 | if not (mask != 0).any(): 17 | raise ValueError('Mask in MaskedCrossEntropyLoss is all zeros.') 18 | target = torch.masked_select(target, mask) 19 | b, c = input.size() 20 | n = target.size(0) 21 | input = torch.masked_select(input, mask.view(b, 1)).view(n, c) 22 | return F.cross_entropy(input, target, weight = weight, reduction = reduction) 23 | 24 | 25 | class ConfidenceBasedCE(nn.Module): 26 | def __init__(self, threshold, apply_class_balancing): 27 | super(ConfidenceBasedCE, self).__init__() 28 | self.loss = MaskedCrossEntropyLoss() 29 | self.softmax = nn.Softmax(dim = 1) 30 | 
self.threshold = threshold 31 | self.apply_class_balancing = apply_class_balancing 32 | 33 | def forward(self, anchors_weak, anchors_strong): 34 | """ 35 | Loss function during self-labeling 36 | 37 | input: logits for the original samples and their strong augmentations 38 | output: cross entropy 39 | """ 40 | # Retrieve target and mask based on weakly augmented anchors 41 | weak_anchors_prob = self.softmax(anchors_weak) 42 | max_prob, target = torch.max(weak_anchors_prob, dim = 1) 43 | mask = max_prob > self.threshold 44 | b, c = weak_anchors_prob.size() 45 | target_masked = torch.masked_select(target, mask.squeeze()) 46 | n = target_masked.size(0) 47 | 48 | # Inputs are strongly augmented anchors 49 | input_ = anchors_strong 50 | 51 | # Class balancing weights 52 | if self.apply_class_balancing: 53 | idx, counts = torch.unique(target_masked, return_counts = True) 54 | freq = 1/(counts.float()/n) 55 | weight = torch.ones(c).cuda() 56 | weight[idx] = freq 57 | 58 | else: 59 | weight = None 60 | 61 | # Loss 62 | loss = self.loss(input_, target, mask, weight = weight, reduction='mean') 63 | 64 | return loss 65 | 66 | 67 | def entropy(x, input_as_probabilities): 68 | """ 69 | Helper function to compute the entropy over the batch 70 | 71 | input: batch w/ shape [b, num_classes] 72 | output: entropy value [ideally log(num_classes), i.e. a uniform cluster distribution] 73 | """ 74 | 75 | if input_as_probabilities: 76 | x_ = torch.clamp(x, min = EPS) 77 | b = x_ * torch.log(x_) 78 | else: 79 | b = F.softmax(x, dim = 1) * F.log_softmax(x, dim = 1) 80 | 81 | if len(b.size()) == 2: # Sample-wise entropy 82 | return -b.sum(dim = 1).mean() 83 | elif len(b.size()) == 1: # Distribution-wise entropy 84 | return - b.sum() 85 | else: 86 | raise ValueError('Input tensor is %d-Dimensional' %(len(b.size()))) 87 | 88 | 89 | class SCANLoss(nn.Module): 90 | def __init__(self, entropy_weight = 2.0): 91 | super(SCANLoss, self).__init__() 92 | self.softmax = nn.Softmax(dim = 1) 93 | self.bce = nn.BCELoss() 94 | self.entropy_weight = entropy_weight # Default = 2.0 95 | 96 | def forward(self, anchors, neighbors): 97 | """ 98 | input: 99 | - anchors: logits for anchor images w/ shape [b, num_classes] 100 | - neighbors: logits for neighbor images w/ shape [b, num_classes] 101 | 102 | output: 103 | - Loss 104 | """ 105 | # Softmax 106 | b, n = anchors.size() 107 | anchors_prob = self.softmax(anchors) 108 | positives_prob = self.softmax(neighbors) 109 | 110 | # Similarity in output space 111 | similarity = torch.bmm(anchors_prob.view(b, 1, n), positives_prob.view(b, n, 1)).squeeze() 112 | ones = torch.ones_like(similarity) 113 | consistency_loss = self.bce(similarity, ones) 114 | 115 | # Entropy loss 116 | entropy_loss = entropy(torch.mean(anchors_prob, 0), input_as_probabilities = True) 117 | 118 | # Total loss 119 | total_loss = consistency_loss - self.entropy_weight * entropy_loss 120 | 121 | return total_loss, consistency_loss, entropy_loss 122 | 123 | 124 | class SimCLRLoss(nn.Module): 125 | # Based on the implementation of SupContrast 126 | def __init__(self, temperature): 127 | super(SimCLRLoss, self).__init__() 128 | self.temperature = temperature 129 | 130 | 131 | def forward(self, features): 132 | """ 133 | input: 134 | - features: hidden feature representation of shape [b, 2, dim] 135 | 136 | output: 137 | - loss: loss computed according to SimCLR 138 | """ 139 | 140 | b, n, dim = features.size() 141 | assert(n == 2) 142 | mask = torch.eye(b, dtype=torch.float32).cuda() 143 | 144 | contrast_features = torch.cat(torch.unbind(features, dim=1), 
dim=0) 145 | anchor = features[:, 0] 146 | 147 | # Dot product 148 | dot_product = torch.matmul(anchor, contrast_features.T) / self.temperature 149 | 150 | # Log-sum-exp trick for numerical stability 151 | logits_max, _ = torch.max(dot_product, dim=1, keepdim=True) 152 | logits = dot_product - logits_max.detach() 153 | 154 | mask = mask.repeat(1, 2) 155 | logits_mask = torch.scatter(torch.ones_like(mask), 1, torch.arange(b).view(-1, 1).cuda(), 0) 156 | mask = mask * logits_mask 157 | 158 | # Log-softmax 159 | exp_logits = torch.exp(logits) * logits_mask 160 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 161 | 162 | # Mean log-likelihood for positive 163 | loss = - ((mask * log_prob).sum(1) / mask.sum(1)).mean() 164 | 165 | return loss 166 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import argparse 6 | import torch 7 | import yaml 8 | from termcolor import colored 9 | from utils.common_config import get_val_dataset, get_val_transformations, get_val_dataloader,\ 10 | get_model 11 | from utils.evaluate_utils import get_predictions, hungarian_evaluate 12 | from utils.memory import MemoryBank 13 | from utils.utils import fill_memory_bank 14 | from PIL import Image 15 | 16 | FLAGS = argparse.ArgumentParser(description='Evaluate models from the model zoo') 17 | FLAGS.add_argument('--config_exp', help='Location of config file') 18 | FLAGS.add_argument('--model', help='Location where model is saved') 19 | FLAGS.add_argument('--visualize_prototypes', action='store_true', 20 | help='Show the prototype for each cluster') 21 | args = FLAGS.parse_args() 22 | 23 | def main(): 24 | 25 | # Read config file 26 | print(colored('Read config file {} ...'.format(args.config_exp), 'blue')) 27 | with open(args.config_exp, 'r') as stream: 28 | config = yaml.safe_load(stream) 29 | config['batch_size'] = 512 # To make sure we can evaluate on a single 1080ti 30 | print(config) 31 | 32 | # Get dataset 33 | print(colored('Get validation dataset ...', 'blue')) 34 | transforms = get_val_transformations(config) 35 | dataset = get_val_dataset(config, transforms) 36 | dataloader = get_val_dataloader(config, dataset) 37 | print('Number of samples: {}'.format(len(dataset))) 38 | 39 | # Get model 40 | print(colored('Get model ...', 'blue')) 41 | model = get_model(config) 42 | print(model) 43 | 44 | # Read model weights 45 | print(colored('Load model weights ...', 'blue')) 46 | state_dict = torch.load(args.model, map_location='cpu') 47 | 48 | if config['setup'] in ['simclr', 'moco', 'selflabel']: 49 | model.load_state_dict(state_dict) 50 | 51 | elif config['setup'] == 'scan': 52 | model.load_state_dict(state_dict['model']) 53 | 54 | else: 55 | raise NotImplementedError 56 | 57 | # CUDA 58 | model.cuda() 59 | 60 | # Perform evaluation 61 | if config['setup'] in ['simclr', 'moco']: 62 | print(colored('Perform evaluation of the pretext task (setup={}).'.format(config['setup']), 'blue')) 63 | print('Create Memory Bank') 64 | if config['setup'] == 'simclr': # Mine neighbors after MLP 65 | memory_bank = MemoryBank(len(dataset), config['model_kwargs']['features_dim'], 66 | config['num_classes'], config['criterion_kwargs']['temperature']) 67 | 68 | else: # Mine neighbors before MLP 69 | memory_bank = MemoryBank(len(dataset), 
config['model_kwargs']['features_dim'], 70 | config['num_classes'], config['temperature']) 71 | memory_bank.cuda() 72 | 73 | print('Fill Memory Bank') 74 | fill_memory_bank(dataloader, model, memory_bank) 75 | 76 | print('Mine the nearest neighbors') 77 | for topk in [1, 5, 20]: # Similar to Fig 2 in paper 78 | _, acc = memory_bank.mine_nearest_neighbors(topk) 79 | print('Accuracy of top-{} nearest neighbors on validation set is {:.2f}'.format(topk, 100*acc)) 80 | 81 | 82 | elif config['setup'] in ['scan', 'selflabel']: 83 | print(colored('Perform evaluation of the clustering model (setup={}).'.format(config['setup']), 'blue')) 84 | head = state_dict['head'] if config['setup'] == 'scan' else 0 85 | predictions, features = get_predictions(config, dataloader, model, return_features=True) 86 | clustering_stats = hungarian_evaluate(head, predictions, dataset.classes, 87 | compute_confusion_matrix=True) 88 | print(clustering_stats) 89 | if args.visualize_prototypes: 90 | prototype_indices = get_prototypes(config, predictions[head], features, model) 91 | visualize_indices(prototype_indices, dataset, clustering_stats['hungarian_match']) 92 | else: 93 | raise NotImplementedError 94 | 95 | @torch.no_grad() 96 | def get_prototypes(config, predictions, features, model, topk=10): 97 | import torch.nn.functional as F 98 | 99 | # Get topk most certain indices and pred labels 100 | print('Get topk') 101 | probs = predictions['probabilities'] 102 | n_classes = probs.shape[1] 103 | dims = features.shape[1] 104 | max_probs, pred_labels = torch.max(probs, dim = 1) 105 | indices = torch.zeros((n_classes, topk)) 106 | for pred_id in range(n_classes): 107 | probs_copy = max_probs.clone() 108 | mask_out = ~(pred_labels == pred_id) 109 | probs_copy[mask_out] = -1 110 | conf_vals, conf_idx = torch.topk(probs_copy, k = topk, largest = True, sorted = True) 111 | indices[pred_id, :] = conf_idx 112 | 113 | # Get corresponding features 114 | selected_features = torch.index_select(features, dim=0, index=indices.view(-1).long()) 115 | selected_features = selected_features.unsqueeze(1).view(n_classes, -1, dims) 116 | 117 | # Get mean feature per class 118 | mean_features = torch.mean(selected_features, dim=1) 119 | 120 | # Get min distance wrt to mean 121 | diff_features = selected_features - mean_features.unsqueeze(1) 122 | diff_norm = torch.norm(diff_features, 2, dim=2) 123 | 124 | # Get final indices 125 | _, best_indices = torch.min(diff_norm, dim=1) 126 | one_hot = F.one_hot(best_indices.long(), indices.size(1)).byte() 127 | proto_indices = torch.masked_select(indices.view(-1), one_hot.view(-1)) 128 | proto_indices = proto_indices.int().tolist() 129 | return proto_indices 130 | 131 | def visualize_indices(indices, dataset, hungarian_match): 132 | import matplotlib.pyplot as plt 133 | import numpy as np 134 | 135 | for idx in indices: 136 | img = np.array(dataset.get_image(idx)).astype(np.uint8) 137 | img = Image.fromarray(img) 138 | plt.figure() 139 | plt.axis('off') 140 | plt.imshow(img) 141 | plt.show() 142 | 143 | 144 | if __name__ == "__main__": 145 | main() 146 | -------------------------------------------------------------------------------- /scan.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import argparse 6 | import os 7 | import torch 8 | 9 | from termcolor import colored 10 | from utils.config import 
create_config 11 | from utils.common_config import get_train_transformations, get_val_transformations,\ 12 | get_train_dataset, get_train_dataloader,\ 13 | get_val_dataset, get_val_dataloader,\ 14 | get_optimizer, get_model, get_criterion,\ 15 | adjust_learning_rate 16 | from utils.evaluate_utils import get_predictions, scan_evaluate, hungarian_evaluate 17 | from utils.train_utils import scan_train 18 | 19 | FLAGS = argparse.ArgumentParser(description='SCAN Loss') 20 | FLAGS.add_argument('--config_env', help='Location of path config file') 21 | FLAGS.add_argument('--config_exp', help='Location of experiments config file') 22 | 23 | def main(): 24 | args = FLAGS.parse_args() 25 | p = create_config(args.config_env, args.config_exp) 26 | print(colored(p, 'red')) 27 | 28 | # CUDNN 29 | torch.backends.cudnn.benchmark = True 30 | 31 | # Data 32 | print(colored('Get dataset and dataloaders', 'blue')) 33 | train_transformations = get_train_transformations(p) 34 | val_transformations = get_val_transformations(p) 35 | train_dataset = get_train_dataset(p, train_transformations, 36 | split='train', to_neighbors_dataset = True) 37 | val_dataset = get_val_dataset(p, val_transformations, to_neighbors_dataset = True) 38 | train_dataloader = get_train_dataloader(p, train_dataset) 39 | val_dataloader = get_val_dataloader(p, val_dataset) 40 | print('Train transforms:', train_transformations) 41 | print('Validation transforms:', val_transformations) 42 | print('Train samples %d - Val samples %d' %(len(train_dataset), len(val_dataset))) 43 | 44 | # Model 45 | print(colored('Get model', 'blue')) 46 | model = get_model(p, p['pretext_model']) 47 | print(model) 48 | model = torch.nn.DataParallel(model) 49 | model = model.cuda() 50 | 51 | # Optimizer 52 | print(colored('Get optimizer', 'blue')) 53 | optimizer = get_optimizer(p, model, p['update_cluster_head_only']) 54 | print(optimizer) 55 | 56 | # Warning 57 | if p['update_cluster_head_only']: 58 | print(colored('WARNING: SCAN will only update the cluster head', 'red')) 59 | 60 | # Loss function 61 | print(colored('Get loss', 'blue')) 62 | criterion = get_criterion(p) 63 | criterion.cuda() 64 | print(criterion) 65 | 66 | # Checkpoint 67 | if os.path.exists(p['scan_checkpoint']): 68 | print(colored('Restart from checkpoint {}'.format(p['scan_checkpoint']), 'blue')) 69 | checkpoint = torch.load(p['scan_checkpoint'], map_location='cpu') 70 | model.load_state_dict(checkpoint['model']) 71 | optimizer.load_state_dict(checkpoint['optimizer']) 72 | start_epoch = checkpoint['epoch'] 73 | best_loss = checkpoint['best_loss'] 74 | best_loss_head = checkpoint['best_loss_head'] 75 | 76 | else: 77 | print(colored('No checkpoint file at {}'.format(p['scan_checkpoint']), 'blue')) 78 | start_epoch = 0 79 | best_loss = 1e4 80 | best_loss_head = None 81 | 82 | # Main loop 83 | print(colored('Starting main loop', 'blue')) 84 | 85 | for epoch in range(start_epoch, p['epochs']): 86 | print(colored('Epoch %d/%d' %(epoch+1, p['epochs']), 'yellow')) 87 | print(colored('-'*15, 'yellow')) 88 | 89 | # Adjust lr 90 | lr = adjust_learning_rate(p, optimizer, epoch) 91 | print('Adjusted learning rate to {:.5f}'.format(lr)) 92 | 93 | # Train 94 | print('Train ...') 95 | scan_train(train_dataloader, model, criterion, optimizer, epoch, p['update_cluster_head_only']) 96 | 97 | # Evaluate 98 | print('Make prediction on validation set ...') 99 | predictions = get_predictions(p, val_dataloader, model) 100 | 101 | print('Evaluate based on SCAN loss ...') 102 | scan_stats = scan_evaluate(predictions) 103 
| print(scan_stats) 104 | lowest_loss_head = scan_stats['lowest_loss_head'] 105 | lowest_loss = scan_stats['lowest_loss'] 106 | 107 | if lowest_loss < best_loss: 108 | print('New lowest loss on validation set: %.4f -> %.4f' %(best_loss, lowest_loss)) 109 | print('Lowest loss head is %d' %(lowest_loss_head)) 110 | best_loss = lowest_loss 111 | best_loss_head = lowest_loss_head 112 | torch.save({'model': model.module.state_dict(), 'head': best_loss_head}, p['scan_model']) 113 | 114 | else: 115 | print('No new lowest loss on validation set: %.4f -> %.4f' %(best_loss, lowest_loss)) 116 | print('Lowest loss head is %d' %(best_loss_head)) 117 | 118 | print('Evaluate with hungarian matching algorithm ...') 119 | clustering_stats = hungarian_evaluate(lowest_loss_head, predictions, compute_confusion_matrix=False) 120 | print(clustering_stats) 121 | 122 | # Checkpoint 123 | print('Checkpoint ...') 124 | torch.save({'optimizer': optimizer.state_dict(), 'model': model.state_dict(), 125 | 'epoch': epoch + 1, 'best_loss': best_loss, 'best_loss_head': best_loss_head}, 126 | p['scan_checkpoint']) 127 | 128 | # Evaluate and save the final model 129 | print(colored('Evaluate best model based on SCAN metric at the end', 'blue')) 130 | model_checkpoint = torch.load(p['scan_model'], map_location='cpu') 131 | model.module.load_state_dict(model_checkpoint['model']) 132 | predictions = get_predictions(p, val_dataloader, model) 133 | clustering_stats = hungarian_evaluate(model_checkpoint['head'], predictions, 134 | class_names=val_dataset.dataset.classes, 135 | compute_confusion_matrix=True, 136 | confusion_matrix_file=os.path.join(p['scan_dir'], 'confusion_matrix.png')) 137 | print(clustering_stats) 138 | 139 | if __name__ == "__main__": 140 | main() 141 | -------------------------------------------------------------------------------- /TUTORIAL.md: -------------------------------------------------------------------------------- 1 | # Tutorial: Semantic Clustering on STL-10 with SCAN 2 | 3 | You can follow this guide to obtain the semantic clusters with SCAN on the STL-10 dataset. The procedure is equivalent for the other datasets. 4 | 5 | ## Contents 6 | 1. [Preparation](#preparation) 7 | 0. [Pretext task](#pretext-task) 8 | 0. [Semantic clustering](#semantic-clustering) 9 | 0. [Visualization](#visualization) 10 | 0. [Citation](#citation) 11 | 12 | ## Preparation 13 | ### Repository 14 | Clone the repository and navigate to the directory: 15 | ```bash 16 | git clone https://github.com/wvangansbeke/Unsupervised-Classification.git 17 | cd Unsupervised-Classification 18 | ``` 19 | 20 | ### Environment 21 | Activate your Python environment containing the packages listed in the README.md. 22 | Make sure you have a GPU available (ideally a 1080 Ti or better) and set $gpu_ids to your desired GPU number(s): 23 | ```bash 24 | conda activate your_anaconda_env 25 | export CUDA_VISIBLE_DEVICES=$gpu_ids 26 | ``` 27 | I will use an environment with Python 3.7, PyTorch 1.6, CUDA 10.2 and cuDNN 7.5.6 for this example. 28 | 29 | ### Paths 30 | Adapt the path in `configs/env.yml` to `repository_eccv/`, since this directory will be used in this tutorial. 31 | Make the following directories. The models will be saved there; other directories will be created on the fly if necessary. 32 | ```bash 33 | mkdir -p repository_eccv/stl-10/pretext/ 34 | ``` 35 | Set the path in `utils/mypath.py` to your dataset root path as mentioned in the README.md. 36 | 37 | ## Pretext task 38 | First, we will run the pretext task (i.e. 
SimCLR) on the train+unlabeled set of STL-10. 39 | Feel free to run this task with the correct config file: 40 | ``` 41 | python simclr.py --config_env configs/env.yml --config_exp configs/pretext/simclr_stl10.yml 42 | ``` 43 | 44 | In order to save time, we provide pretrained models in the README.md for all the datasets discussed in the paper. 45 | First, download the pretrained model [here](https://drive.google.com/file/d/1261NDFfXuKR2Dh4RWHYYhcicdcPag9NZ/view?usp=sharing) and save it in your experiments directory. Then, move the downloaded model to the correct location (i.e. `repository_eccv/stl-10/pretext/`) and calculate the nearest neighbors. This can be achieved by running the following commands: 46 | ```bash 47 | mv simclr_stl-10.pth.tar repository_eccv/stl-10/pretext/checkpoint.pth.tar # Move model to correct location 48 | python tutorial_nn.py --config_env configs/env.yml --config_exp configs/pretext/simclr_stl10.yml # Compute neighbors 49 | ``` 50 | 51 | You should get the following results: 52 | ``` 53 | > Restart from checkpoint repository_eccv/stl-10/pretext/checkpoint.pth.tar 54 | > Fill memory bank for mining the nearest neighbors (train) ... 55 | > Fill Memory Bank [0/10] 56 | > Mine the nearest neighbors (Top-20) 57 | > Accuracy of top-20 nearest neighbors on train set is 72.81 58 | > Fill memory bank for mining the nearest neighbors (val) ... 59 | > Fill Memory Bank [0/16] 60 | > Mine the nearest neighbors (Top-5) 61 | > Accuracy of top-5 nearest neighbors on val set is 79.85 62 | ``` 63 | Now, the model has been correctly saved for the clustering step and the nearest neighbors were computed automatically. 64 | 65 | ## Semantic clustering 66 | 67 | We will start the clustering procedure now. Simply run the command underneath. The nearest neighbors and pretext model will be loaded automatically: 68 | ```bash 69 | python scan.py --config_env configs/env.yml --config_exp configs/scan/scan_stl10.yml 70 | ``` 71 | 72 | On average, you should get around 75.5% (as reported in the paper). I get around 80% for this run. 73 | As can be seen, the best model is selected based on the lowest loss on the validation set. 74 | A complete log file is included in `logs/scan_stl10.txt`. It can be viewed in color with `cat logs/scan_stl10.txt` in your terminal. 75 | ``` 76 | > Epoch 100/100 77 | > Adjusted learning rate to 0.00010 78 | > Train ... 79 | > Epoch: [99][ 0/39] Total Loss -1.0914e+01 (-1.0914e+01) Consistency Loss 5.4953e-01 (5.4953e-01) Entropy 2.2927e+00 (2.2927e+00) 80 | > Epoch: [99][25/39] Total Loss -1.0860e+01 (-1.0824e+01) Consistency Loss 6.0039e-01 (6.2986e-01) Entropy 2.2920e+00 (2.2908e+00) 81 | > Make prediction on validation set ... 82 | > Evaluate based on SCAN loss ... 83 | > {'scan': [{'entropy': 2.298224687576294, 'consistency': 0.4744518995285034, 'total_loss': -1.8237727880477905}], 'lowest_loss_head': 0, 'lowest_loss': -1.8237727880477905} 84 | > No new lowest loss on validation set: -1.8248 -> -1.8238 85 | > Lowest loss head is 0 86 | > Evaluate with hungarian matching algorithm ... 87 | > {'ACC': 0.79, 'ARI': 0.6165977838702869, 'NMI': 0.6698521018927263, 'Top-5': 0.9895, 'hungarian_match': [(0, 7), (1, 3), (2, 1), (3, 5), (4, 2), (5, 0), (6, 8), (7, 9), (8, 4), (9, 6)]} 88 | > Checkpoint ... 
89 | > Evaluate best model based on SCAN metric at the end 90 | > {'ACC': 0.8015, 'ARI': 0.6332440325942004, 'NMI': 0.6823369373831116, 'Top-5': 0.990625, 'hungarian_match': [(0, 7), (1, 3), (2, 1), (3, 5), (4, 2), (5, 0), (6, 8), (7, 9), (8, 4), (9, 6)]} 91 | ``` The `hungarian_match` entry shows the optimal one-to-one assignment of cluster indices to ground-truth classes (e.g. cluster 0 is matched to class 7); ACC is computed after applying this mapping. 92 | 93 | ## Visualization 94 | Now, we can visualize the confusion matrix and the prototypes of our model. We define the prototypes as the most confident samples of each cluster, and we visualize, for every cluster, the sample closest to the mean embedding of these confident samples (see the sketch below). Run the following command: 95 | ```bash 96 | python eval.py --config_exp configs/scan/scan_stl10.yml --model repository_eccv/stl-10/scan/model.pth.tar --visualize_prototypes 97 | ``` 98 | As can be seen from the confusion matrix, the model mainly confuses visually similar classes (e.g. cats, dogs and monkeys). Some classes are predicted almost perfectly (e.g. ship) without the use of any ground-truth labels. 99 |
100 | [figures: cluster prototypes (images/tutorial/prototypes_stl10.jpg) and confusion matrix (images/tutorial/confusion_matrix_stl10.png)] 101 | 103 | 104 |
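For reference, the prototype selection performed by `eval.py --visualize_prototypes` boils down to the sketch below (a minimal illustration with made-up tensor names; the actual implementation is `get_prototypes` in `eval.py`):

```python
import torch

def prototype_index(probs, features, cluster_id, topk=10):
    """Pick the prototype sample of one cluster.

    probs:    [N, C] softmax probabilities over the clusters
    features: [N, D] feature embeddings of the same N samples
    """
    conf, labels = probs.max(dim=1)           # confidence and cluster assignment
    conf = conf.clone()
    conf[labels != cluster_id] = -1.0         # restrict to samples assigned to this cluster
    _, idx = torch.topk(conf, k=topk)         # the top-k most confident samples
    mean_feat = features[idx].mean(dim=0)     # mean embedding of those confident samples
    dists = (features[idx] - mean_feat).norm(dim=1)
    return idx[dists.argmin()].item()         # prototype: the sample closest to the mean
```

Taking the sample closest to the mean embedding, rather than simply the single most confident sample, presumably makes the prototype less sensitive to one overconfident outlier.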
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 18 | 19 | 20 | 21 | 22 |
23 | [](https://paperswithcode.com/sota/unsupervised-image-classification-on-imagenet?p=learning-to-classify-images-without-labels)
24 | [](https://paperswithcode.com/sota/unsupervised-image-classification-on-cifar-10?p=learning-to-classify-images-without-labels)
25 | [](https://paperswithcode.com/sota/unsupervised-image-classification-on-stl-10?p=learning-to-classify-images-without-labels)
26 | [](https://paperswithcode.com/sota/unsupervised-image-classification-on-cifar-20?p=learning-to-classify-images-without-labels)
27 |