├── .gitignore ├── 2d_selection_gif.gif ├── LICENSE ├── README.md ├── USAGE.md ├── cifar_selection.png ├── dcom_delta_updating.gif ├── dcom_semi.png ├── dcom_supervised.png ├── deep-al ├── LICENSE ├── README.md ├── configs │ ├── cifar10 │ │ ├── al │ │ │ ├── RESNET18.yaml │ │ │ ├── RESNET18_ENS.yaml │ │ │ └── RESNET18_IM.yaml │ │ ├── evaluate │ │ │ └── RESNET18.yaml │ │ └── train │ │ │ ├── RESNET18.yaml │ │ │ └── RESNET18_ENS.yaml │ ├── cifar100 │ │ └── al │ │ │ └── RESNET18.yaml │ ├── template.yaml │ └── tinyimagenet │ │ └── al │ │ └── RESNET18.yaml ├── docs │ ├── AL_results.png │ └── GETTING_STARTED.md ├── pycls │ ├── __init__.py │ ├── al │ │ ├── ActiveLearning.py │ │ ├── DCoM.py │ │ ├── Sampling.py │ │ ├── __init__.py │ │ ├── prob_cover.py │ │ ├── typiclust.py │ │ └── vaal_util.py │ ├── core │ │ ├── __init__.py │ │ ├── builders.py │ │ ├── config.py │ │ ├── losses.py │ │ ├── net.py │ │ └── optimizer.py │ ├── datasets │ │ ├── __init__.py │ │ ├── augment.py │ │ ├── custom_datasets.py │ │ ├── data.py │ │ ├── imbalanced_cifar.py │ │ ├── randaugment.py │ │ ├── sampler.py │ │ ├── simclr_augment.py │ │ ├── tiny_imagenet.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── features.py │ │ │ ├── gaussian_blur.py │ │ │ └── helpers.py │ ├── models │ │ ├── __init__.py │ │ ├── alexnet.py │ │ ├── resnet.py │ │ ├── vaal_model.py │ │ └── vgg.py │ └── utils │ │ ├── __init__.py │ │ ├── benchmark.py │ │ ├── checkpoint.py │ │ ├── distributed.py │ │ ├── io.py │ │ ├── logging.py │ │ ├── meters.py │ │ ├── metrics.py │ │ ├── net.py │ │ ├── plotting.py │ │ └── timer.py ├── requirements.txt └── tools │ ├── __init__.py │ ├── ensemble_al.py │ ├── ensemble_train.py │ ├── test_model.py │ ├── train.py │ └── train_al.py ├── probcover_selection.gif ├── probcover_semi.png ├── results.png ├── scan ├── LICENSE ├── README.md ├── TUTORIAL.md ├── configs │ ├── env.yml │ ├── pretext │ │ ├── moco_imagenet100.yml │ │ ├── moco_imagenet200.yml │ │ ├── moco_imagenet50.yml │ │ ├── simclr_cifar10.yml │ │ ├── simclr_cifar100.yml │ │ ├── simclr_stl10.yml │ │ └── simclr_tinyimagenet.yml │ ├── scan │ │ ├── imagenet_eval.yml │ │ ├── scan_cifar10.yml │ │ ├── scan_cifar100.yml │ │ ├── scan_imagenet_100.yml │ │ ├── scan_imagenet_200.yml │ │ ├── scan_imagenet_50.yml │ │ ├── scan_stl10.yml │ │ └── scan_tinyimagenet.yml │ └── selflabel │ │ ├── selflabel_cifar10.yml │ │ ├── selflabel_cifar100.yml │ │ ├── selflabel_imagenet_100.yml │ │ ├── selflabel_imagenet_200.yml │ │ ├── selflabel_imagenet_50.yml │ │ └── selflabel_stl10.yml ├── data │ ├── augment.py │ ├── cifar.py │ ├── custom_dataset.py │ ├── imagenet.py │ ├── imagenet_subsets │ │ ├── imagenet_100.txt │ │ ├── imagenet_200.txt │ │ └── imagenet_50.txt │ ├── stl.py │ └── tinyimagenet.py ├── eval.py ├── images │ ├── pipeline.png │ ├── prototypes_cifar10.jpg │ ├── teaser.jpg │ └── tutorial │ │ ├── confusion_matrix_stl10.png │ │ └── prototypes_stl10.jpg ├── losses │ └── losses.py ├── moco.py ├── models │ ├── models.py │ ├── resnet.py │ ├── resnet_cifar.py │ ├── resnet_stl.py │ └── resnet_tinyimagenet.py ├── requirements.txt ├── scan.py ├── selflabel.py ├── simclr.py ├── tutorial_nn.py └── utils │ ├── collate.py │ ├── common_config.py │ ├── config.py │ ├── ema.py │ ├── evaluate_utils.py │ ├── memory.py │ ├── mypath.py │ ├── train_utils.py │ └── utils.py └── typiclust-env.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Shared objects 7 | *.so 8 | 9 | # 
Distribution / packaging 10 | build/ 11 | *.egg-info/ 12 | *.egg 13 | 14 | # Temporary files 15 | *.swn 16 | *.swo 17 | *.swp 18 | 19 | # PyCharm 20 | ../.idea/ 21 | 22 | # Mac 23 | .DS_STORE 24 | 25 | # Data symlinks 26 | pycls/datasets/data/ 27 | 28 | # Other 29 | logs/ 30 | scratch* 31 | *.out 32 | 33 | output/ 34 | *.ipynb_checkpoints/ 35 | data/ 36 | scan/results/ -------------------------------------------------------------------------------- /2d_selection_gif.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/2d_selection_gif.gif -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright 2022 (c) Avihu Dekel and Guy Hacohen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TypiClust, ProbCover & DCoM Official Code Repository 2 | 3 | 4 | This is the official implementation for the papers **Active Learning on a Budget - Opposite Strategies Suit High and Low Budgets**, **Active Learning Through a Covering Lens**, and **DCoM: Active Learning for All Learners**. 5 | 6 | This code implements TypiClust, ProbCover and DCoM - simple and effective low-budget active learning methods. 7 | ## TypiClust 8 | 9 | [**Arxiv link**](https://arxiv.org/abs/2202.02794), 10 | [**Twitter Post link**](https://twitter.com/AvihuDkl/status/1529385835694637058), 11 | [**Blog Post link**](https://avihu111.github.io/Active-Learning/) 12 | 13 | 14 | TypiClust first employs a representation learning method, then clusters the data into K clusters, and selects the most typical (dense) sample from every cluster. In other words, TypiClust selects samples from dense and diverse regions of the data distribution.
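A minimal sketch of that selection rule on toy 2D data (an illustration only, not the repository's implementation — the real one, including the `K_NN = 20` neighborhood size assumed here, lives in `deep-al/pycls/al/typiclust.py`):

```python
# Hedged sketch: cluster the features, then pick the densest (most typical) point per cluster.
import numpy as np
from sklearn.cluster import KMeans
from sklearn.neighbors import NearestNeighbors

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 2))      # stand-in for learned representations
budget = 10

km = KMeans(n_clusters=budget, n_init=10).fit(X)
dists, _ = NearestNeighbors(n_neighbors=21).fit(X).kneighbors(X)  # 20 neighbors + the point itself
typicality = 1.0 / (dists[:, 1:].mean(axis=1) + 1e-5)             # high density = typical

selected = [np.flatnonzero(km.labels_ == c)[typicality[km.labels_ == c].argmax()]
            for c in range(budget)]
print(selected)  # one dense representative per cluster -> dense AND diverse picks
```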
15 | 16 | Selection of 30 samples on CIFAR-10: 17 | 18 | 19 | 20 | Selection of 10 samples from a GMM: 21 | 22 | 23 | 24 | TypiClust Results summary 25 | 26 | 27 | 28 | ## Probability Cover 29 | 30 | [**Arxiv link**](https://arxiv.org/abs/2205.11320), 31 | [**Twitter Post link**](https://mobile.twitter.com/AvihuDkl/status/1579497337650839553), 32 | [**Blog Post link**](https://avihu111.github.io/Covering-Lens/) 33 | 34 | ProbCover also uses a representation learning method. A $\delta$-radius ball is then placed around every point, and the subset of $b$ (budget) balls that covers the most points is selected, with the balls' centers chosen as the samples to be labeled (a toy sketch of this greedy coverage step appears at the end of this README). 35 | 36 | Unfolding selection of ProbCover 37 | 38 | 39 | 40 | ProbCover results in the Semi-Supervised training framework 41 | 42 | 43 | 44 | ## DCoM 45 | [**Arxiv link**](https://arxiv.org/abs/2407.01804) 46 | 47 | DCoM employs a representation learning approach. Initially, a $\Delta_{\text{avg}}$-radius ball is placed around each point; the $\Delta$ list then assigns a specific radius to each labeled example individually. From these, a subset of $b$ balls is chosen based on covering the most points, with the centers of these balls selected as the samples to be labeled. After training the model, the $\Delta$ list is updated according to the purity of the balls to achieve more accurate radii and coverage. DCoM utilizes this coverage to determine the competence score, which balances typicality and uncertainty. 48 | 49 | Illustration of DCoM's $\Delta$ updating 50 | 51 | 52 | 53 | DCoM results in the Supervised training framework 54 | 55 | 56 | 57 | DCoM results in the Semi-Supervised training framework 58 | 59 | 60 | 61 | ## Usage 62 | 63 | Please see [`USAGE`](USAGE.md) for brief instructions on installation and basic usage examples. 64 | 65 | ## Citing this Repository 66 | This repository makes use of two other repositories ([SCAN](https://github.com/wvangansbeke/Unsupervised-Classification) and [Deep-AL](https://github.com/acl21/deep-active-learning-pytorch)). 67 | Please consider citing their work and ours: 68 | ``` 69 | @article{hacohen2022active, 70 | title={Active learning on a budget: Opposite strategies suit high and low budgets}, 71 | author={Hacohen, Guy and Dekel, Avihu and Weinshall, Daphna}, 72 | journal={arXiv preprint arXiv:2202.02794}, 73 | year={2022} 74 | } 75 | 76 | @article{yehudaActiveLearningCovering2022, 77 | title={Active Learning Through a Covering Lens}, 78 | author={Yehuda, Ofer and Dekel, Avihu and Hacohen, Guy and Weinshall, Daphna}, 79 | journal={arXiv preprint arXiv:2205.11320}, 80 | year={2022} 81 | } 82 | 83 | @article{mishal2024dcom, 84 | title={DCoM: Active Learning for All Learners}, 85 | author={Mishal, Inbal and Weinshall, Daphna}, 86 | journal={arXiv preprint arXiv:2407.01804}, 87 | year={2024} 88 | } 89 | ``` 90 | 91 | ## License 92 | This toolkit is released under the MIT license. Please see the [LICENSE](LICENSE) file for more information.
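As promised above, a hedged, self-contained sketch of the greedy max-coverage step behind ProbCover (and DCoM's selection), on toy random data with a hand-picked radius — an illustration only; the repository's actual implementation is `deep-al/pycls/al/prob_cover.py`:

```python
# Hedged sketch of greedy max-coverage selection with delta-balls.
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(300, 2))   # stand-in for normalized features
delta, budget = 0.5, 5

# adjacency[i, j] == True iff point j lies inside the delta-ball around point i
adjacency = np.linalg.norm(X[:, None] - X[None, :], axis=-1) < delta

covered = np.zeros(len(X), dtype=bool)
selected = []
for _ in range(budget):
    gains = (adjacency & ~covered).sum(axis=1)  # new points each ball would cover
    best = int(gains.argmax())
    selected.append(best)
    covered |= adjacency[best]                  # mark the ball's points as covered
print(selected, f'coverage={covered.mean():.2f}')
```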
93 | -------------------------------------------------------------------------------- /USAGE.md: -------------------------------------------------------------------------------- 1 | # Getting Started 2 | 3 | ## Setup 4 | 5 | Clone the repository 6 | ``` 7 | git clone https://github.com/avihu111/TypiClust 8 | cd TypiClust 9 | ``` 10 | 11 | Create an environment using 12 | ``` 13 | conda create --name typiclust --file typiclust-env.txt 14 | conda activate typiclust 15 | pip install pyyaml easydict termcolor tqdm simplejson yacs 16 | ``` 17 | 18 | If that fails, it might be due to an incompatible CUDA version. 19 | In that case, try installing with 20 | ``` 21 | conda create --name typiclust python=3.7 22 | conda activate typiclust 23 | conda install pytorch torchvision torchaudio cudatoolkit=<YOUR_CUDA_VERSION> -c pytorch 24 | conda install matplotlib scipy scikit-learn pandas 25 | conda install -c conda-forge faiss-gpu 26 | pip install pyyaml easydict termcolor tqdm simplejson yacs 27 | ``` 28 | Select the GPU to be used by running 29 | ``` 30 | export CUDA_VISIBLE_DEVICES=0 31 | ``` 32 | 33 | ## Representation Learning 34 | Both TypiClust variants and ProbCover rely on representation learning. 35 | To train SimCLR on CIFAR-10, please run 36 | ``` 37 | cd scan 38 | python simclr.py --config_env configs/env.yml --config_exp configs/pretext/simclr_cifar10.yml 39 | cd .. 40 | ``` 41 | When this finishes, the file `./results/cifar-10/pretext/features_seed1.npy` should exist. 42 | 43 | To save time, you can download the features of CIFAR-10/100 from here: 44 | 45 | | Dataset | Download link | 46 | |------------------|---------------| 47 | | CIFAR10 | [Download](https://drive.google.com/file/d/1Le1ZuZOpfxBfxL3nnNahZcCt-lLWLQSB/view?usp=sharing) | 48 | | CIFAR100 | [Download](https://drive.google.com/file/d/1o2nz_SKLdcaTCB9XVA44qCTVSUSmktUb/view?usp=sharing) | 49 | 50 | 51 | and place the files here: `./results/cifar-10/pretext/features_seed1.npy`. 52 | 53 | ## TypiClust - K-Means variant 54 | To select samples according to TypiClust (K-Means) with `initial_size=0` and `budget=100`, please run 55 | ``` 56 | cd deep-al/tools 57 | python train_al.py --cfg ../configs/cifar10/al/RESNET18.yaml --al typiclust_rp --exp-name auto --initial_size 0 --budget 100 58 | cd ../../ 59 | ``` 60 | 61 | 62 | ## TypiClust - SCAN variant 63 | In this section we select `budget=10` samples without an initial set. 64 | We first must run the SCAN clustering algorithm, as TypiClust uses its features and cluster assignments. 65 | Please repeat the following command `for k in [10, 20, 30, 40, 50, 60]` 66 | 67 | ``` 68 | cd scan 69 | python scan.py --config_env configs/env.yml --config_exp configs/scan/scan_cifar10.yml --num_clusters k 70 | cd .. 71 | ``` 72 | You can use other representations by changing the path in the file `deep-al/pycls/datasets/utils/features.py`.
73 | Then, you can run the active learning experiment by running 74 | ``` 75 | cd deep-al/tools 76 | python train_al.py --cfg ../configs/cifar10/al/RESNET18.yaml --al typiclust_dc --exp-name auto --initial_size 0 --budget 10 77 | cd ../../ 78 | ``` 79 | 80 | ## ProbCover 81 | Example ProbCover script 82 | ``` 83 | cd deep-al/tools 84 | python train_al.py --cfg ../configs/cifar10/al/RESNET18.yaml --al probcover --exp-name auto --initial_size 0 --budget 10 --delta 0.75 85 | cd ../../ 86 | ``` 87 | 88 | ## DCoM 89 | Example DCoM script: 90 | ``` 91 | cd deep-al/tools 92 | python train_al.py --cfg ../configs/cifar10/al/RESNET18.yaml --al dcom --exp-name auto --initial_size 0 --budget 10 --initial_delta 0.75 93 | cd ../../ 94 | ``` 95 | 96 | You can add the `a_logistic` and `k_logistic` parameters to the run using `--a_logistic` and `--k_logistic`. 97 | -------------------------------------------------------------------------------- /cifar_selection.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/cifar_selection.png -------------------------------------------------------------------------------- /dcom_delta_updating.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/dcom_delta_updating.gif -------------------------------------------------------------------------------- /dcom_semi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/dcom_semi.png -------------------------------------------------------------------------------- /dcom_supervised.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/dcom_supervised.png -------------------------------------------------------------------------------- /deep-al/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright 2022 (c) Akshay L Chandra and Vineeth N Balasubramanian 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /deep-al/README.md: -------------------------------------------------------------------------------- 1 | # Deep Active Learning Toolkit for Image Classification in PyTorch 2 | 3 | This is a code base for deep active learning for image classification written in [PyTorch](https://pytorch.org/). It is built on top of FAIR's [pycls](https://github.com/facebookresearch/pycls/). I want to emphasize that this is a derivative of the toolkit originally shared with me via email by Prateek Munjal _et al._, the authors of the paper _"Towards Robust and Reproducible Active Learning using Neural Networks"_, paper available [here](https://arxiv.org/abs/2002.09564). 4 | 5 | ## Introduction 6 | 7 | The goal of this repository is to provide a simple and flexible codebase for deep active learning. It is designed to support rapid implementation and evaluation of research ideas. We also provide results on CIFAR10 below. 8 | 9 | The codebase currently only supports single-machine single-gpu training. We will soon scale it to single-machine multi-gpu training, powered by the PyTorch distributed package. 10 | 11 | ## Using the Toolkit 12 | 13 | Please see [`GETTING_STARTED`](docs/GETTING_STARTED.md) for brief instructions on installation, adding new datasets, basic usage examples, etc. 14 | 15 | ## Active Learning Methods Supported 16 | * Uncertainty Sampling (see the score sketch at the end of this README) 17 | * Least Confidence 18 | * Min-Margin 19 | * Max-Entropy 20 | * Deep Bayesian Active Learning (DBAL) [1] 21 | * Bayesian Active Learning by Disagreement (BALD) [1] 22 | * Diversity Sampling 23 | * Coreset (greedy) [2] 24 | * Variational Adversarial Active Learning (VAAL) [3] 25 | * Query-by-Committee Sampling 26 | * Ensemble Variation Ratio (Ens-varR) [4] 27 | 28 | 29 | ## Datasets Supported 30 | * [CIFAR10/100](https://www.cs.toronto.edu/~kriz/cifar.html) 31 | * [MNIST](http://yann.lecun.com/exdb/mnist/) 32 | * [SVHN](http://ufldl.stanford.edu/housenumbers/) 33 | * [Tiny ImageNet](https://www.kaggle.com/c/tiny-imagenet) (Download the zip file [here](http://cs231n.stanford.edu/tiny-imagenet-200.zip)) 34 | * Long-Tail CIFAR-10/100 35 | 36 | Follow the instructions in [`GETTING_STARTED`](docs/GETTING_STARTED.md) to add a new dataset. 37 | 38 | ## Results on CIFAR10 and CIFAR100 39 | 40 | The following are the results on CIFAR10 and CIFAR100, trained with the hyperparameters in `configs/cifar10/al/RESNET18.yaml` and `configs/cifar100/al/RESNET18.yaml` respectively. All results were averaged over 3 runs. 41 | 42 | 43 | 44 |
45 | 46 |
47 | 48 | ### CIFAR10 at 60% 49 | ``` 50 | | AL Method | Test Accuracy | 51 | |:----------------:|:---------------------------:| 52 | | DBAL | 91.670000 +- 0.230651 | 53 | | Least Confidence | 91.510000 +- 0.087178 | 54 | | BALD | 91.470000 +- 0.293087 | 55 | | Coreset | 91.433333 +- 0.090738 | 56 | | Max-Entropy | 91.373333 +- 0.363639 | 57 | | Min-Margin | 91.333333 +- 0.234592 | 58 | | Ensemble-varR | 89.866667 +- 0.127410 | 59 | | Random | 89.803333 +- 0.230290 | 60 | | VAAL | 89.690000 +- 0.115326 | 61 | ``` 62 | 63 | ### CIFAR100 at 60% 64 | ``` 65 | | AL Method | Test Accuracy | 66 | |:----------------:|:---------------------------:| 67 | | DBAL | 55.400000 +- 1.037931 | 68 | | Coreset | 55.333333 +- 0.773714 | 69 | | Max-Entropy | 55.226667 +- 0.536128 | 70 | | BALD | 55.186667 +- 0.369639 | 71 | | Least Confidence | 55.003333 +- 0.937248 | 72 | | Min-Margin | 54.543333 +- 0.611583 | 73 | | Ensemble-varR | 54.186667 +- 0.325628 | 74 | | VAAL | 53.943333 +- 0.680686 | 75 | | Random | 53.546667 +- 0.302875 | 76 | ``` 77 | 78 | ## Citing this Repository 79 | 80 | If you find this repo helpful in your research, please consider citing us and the owners of the original toolkit: 81 | 82 | ``` 83 | @article{Chandra2021DeepAL, 84 | Author = {Akshay L Chandra and Vineeth N Balasubramanian}, 85 | Title = {Deep Active Learning Toolkit for Image Classification in PyTorch}, 86 | Journal = {https://github.com/acl21/deep-active-learning-pytorch}, 87 | Year = {2021} 88 | } 89 | 90 | @article{Munjal2020TowardsRA, 91 | title={Towards Robust and Reproducible Active Learning Using Neural Networks}, 92 | author={Prateek Munjal and N. Hayat and Munawar Hayat and J. Sourati and S. Khan}, 93 | journal={ArXiv}, 94 | year={2020}, 95 | volume={abs/2002.09564} 96 | } 97 | ``` 98 | 99 | ## License 100 | 101 | This toolkit is released under the MIT license. Please see the [LICENSE](LICENSE) file for more information. 102 | 103 | ## References 104 | 105 | [1] Yarin Gal, Riashat Islam, and Zoubin Ghahramani. Deep bayesian active learning with image data. In Proceedings of the 34th International Conference on Machine Learning-Volume 70, pages 1183–1192. JMLR. org, 2017. 106 | 107 | [2] Ozan Sener and Silvio Savarese. Active learning for convolutional neural networks: A core-set approach. In International Conference on Learning Representations, 2018. 108 | 109 | [3] Sinha, Samarth et al. Variational Adversarial Active Learning. 2019 IEEE/CVF International Conference on Computer Vision (ICCV) (2019): 5971-5980. 110 | 111 | [4] William H. Beluch, Tim Genewein, Andreas Nürnberger, and Jan M. Köhler. The power of ensembles for active learning in image classification. 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 9368–9377, 2018. 
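As referenced in the methods list above, the three classic uncertainty scores are all simple functions of the softmax output. A hedged sketch, assuming an `(N, C)` array of predicted class probabilities (illustration only; the toolkit's own implementations live under `deep-al/pycls/al/`):

```python
# Hedged sketch: least confidence, min-margin and max-entropy scores
# (higher score = more uncertain = more likely to be queried).
import numpy as np

def uncertainty_scores(probs: np.ndarray) -> dict:
    """probs: (N, C) softmax probabilities for N unlabeled samples."""
    top2 = np.sort(probs, axis=1)[:, -2:]          # [second largest, largest] per row
    return {
        'least_confidence': 1.0 - top2[:, 1],      # low top-1 confidence
        'min_margin': -(top2[:, 1] - top2[:, 0]),  # small top-1/top-2 gap
        'max_entropy': -(probs * np.log(probs + 1e-12)).sum(axis=1),
    }

probs = np.array([[0.90, 0.05, 0.05], [0.40, 0.35, 0.25]])
for name, score in uncertainty_scores(probs).items():
    print(name, score.argsort()[::-1])  # a budget would take the top-k indices
```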
-------------------------------------------------------------------------------- /deep-al/configs/cifar10/al/RESNET18.yaml: -------------------------------------------------------------------------------- 1 | # EXP_NAME: 'YOUR_EXPERIMENT_NAME' 2 | RNG_SEED: 1 3 | DATASET: 4 | NAME: CIFAR10 5 | ROOT_DIR: './data' 6 | VAL_RATIO: 0.1 7 | AUG_METHOD: 'hflip' 8 | MODEL: 9 | TYPE: resnet18 10 | NUM_CLASSES: 10 11 | OPTIM: 12 | TYPE: 'sgd' 13 | BASE_LR: 0.025 14 | LR_POLICY: cos 15 | LR_MULT: 0.1 16 | # STEPS: [0, 60, 120, 160, 200] 17 | MAX_EPOCH: 200 18 | MOMENTUM: 0.9 19 | NESTEROV: True 20 | WEIGHT_DECAY: 0.0003 21 | GAMMA: 0.1 22 | TRAIN: 23 | SPLIT: train 24 | BATCH_SIZE: 100 25 | IM_SIZE: 32 26 | EVAL_PERIOD: 2 27 | TEST: 28 | SPLIT: test 29 | BATCH_SIZE: 200 30 | IM_SIZE: 32 31 | DATA_LOADER: 32 | NUM_WORKERS: 4 33 | CUDNN: 34 | BENCHMARK: True 35 | ACTIVE_LEARNING: 36 | INIT_L_RATIO: 0. 37 | # BUDGET_SIZE: 10 38 | # SAMPLING_FN: 'uncertainty' 39 | MAX_ITER: 5 40 | FINE_TUNE: False 41 | DELTA_RESOLUTION: 0.05 42 | MAX_DELTA: 1.1 -------------------------------------------------------------------------------- /deep-al/configs/cifar10/al/RESNET18_ENS.yaml: -------------------------------------------------------------------------------- 1 | # EXP_NAME: 'YOUR_EXPERIMENT_NAME' 2 | RNG_SEED: 1 3 | DATASET: 4 | NAME: CIFAR10 5 | ROOT_DIR: './data' 6 | VAL_RATIO: 0.1 7 | AUG_METHOD: 'hflip' 8 | MODEL: 9 | TYPE: resnet18 10 | NUM_CLASSES: 10 11 | OPTIM: 12 | TYPE: 'sgd' 13 | BASE_LR: 0.025 14 | LR_POLICY: cos 15 | LR_MULT: 0.1 16 | # STEPS: [0, 60, 120, 160, 200] 17 | MAX_EPOCH: 200 18 | MOMENTUM: 0.9 19 | NESTEROV: True 20 | WEIGHT_DECAY: 0.0003 21 | GAMMA: 0.1 22 | TRAIN: 23 | SPLIT: train 24 | BATCH_SIZE: 96 25 | IM_SIZE: 32 26 | EVAL_PERIOD: 2 27 | TEST: 28 | SPLIT: test 29 | BATCH_SIZE: 200 30 | IM_SIZE: 32 31 | DATA_LOADER: 32 | NUM_WORKERS: 4 33 | CUDNN: 34 | BENCHMARK: True 35 | ACTIVE_LEARNING: 36 | BUDGET_SIZE: 5000 37 | SAMPLING_FN: 'ensemble_var_R' # do not change this for ensemble learning 38 | MAX_ITER: 5 39 | ENSEMBLE: 40 | NUM_MODELS: 3 41 | MODEL_TYPE: ['resnet18'] -------------------------------------------------------------------------------- /deep-al/configs/cifar10/al/RESNET18_IM.yaml: -------------------------------------------------------------------------------- 1 | # EXP_NAME: 'YOUR_EXPERIMENT_NAME' 2 | RNG_SEED: 1 3 | DATASET: 4 | NAME: IMBALANCED_CIFAR10 5 | ROOT_DIR: './data' 6 | VAL_RATIO: 0.1 7 | AUG_METHOD: 'hflip' 8 | MODEL: 9 | TYPE: resnet18 10 | NUM_CLASSES: 10 11 | OPTIM: 12 | TYPE: 'sgd' 13 | BASE_LR: 0.025 14 | LR_POLICY: cos 15 | LR_MULT: 0.1 16 | # STEPS: [0, 60, 120, 160, 200] 17 | MAX_EPOCH: 200 18 | MOMENTUM: 0.9 19 | NESTEROV: True 20 | WEIGHT_DECAY: 0.0003 21 | GAMMA: 0.1 22 | TRAIN: 23 | SPLIT: train 24 | BATCH_SIZE: 96 25 | IM_SIZE: 32 26 | EVAL_PERIOD: 2 27 | TEST: 28 | SPLIT: test 29 | BATCH_SIZE: 200 30 | IM_SIZE: 32 31 | DATA_LOADER: 32 | NUM_WORKERS: 4 33 | CUDNN: 34 | BENCHMARK: True 35 | ACTIVE_LEARNING: 36 | INITIAL_L_RATIO: 0.1 37 | BUDGET_SIZE: 1270 38 | # SAMPLING_FN: 'uncertainty' 39 | MAX_ITER: 5 -------------------------------------------------------------------------------- /deep-al/configs/cifar10/evaluate/RESNET18.yaml: -------------------------------------------------------------------------------- 1 | # EXP_NAME: 'SOME_RANDOM_NAME' 2 | RNG_SEED: 1 3 | GPU_ID: '3' 4 | DATASET: 5 | NAME: CIFAR10 6 | ROOT_DIR: 'data' 7 | MODEL: 8 | TYPE: resnet18 9 | NUM_CLASSES: 10 10 | TEST: 11 | SPLIT: test 12 | BATCH_SIZE: 200 13 | IM_SIZE: 32 14 | 
MODEL_PATH: 'path/to/saved/model' 15 | DATA_LOADER: 16 | NUM_WORKERS: 4 17 | CUDNN: 18 | BENCHMARK: True -------------------------------------------------------------------------------- /deep-al/configs/cifar10/train/RESNET18.yaml: -------------------------------------------------------------------------------- 1 | # EXP_NAME: 'YOUR_EXPERIMENT_NAME' 2 | RNG_SEED: 1 3 | DATASET: 4 | NAME: CIFAR10 5 | ROOT_DIR: './data' 6 | VAL_RATIO: 0.1 7 | AUG_METHOD: 'hflip' 8 | MODEL: 9 | TYPE: resnet18 10 | NUM_CLASSES: 10 11 | OPTIM: 12 | TYPE: 'sgd' 13 | BASE_LR: 0.025 14 | LR_POLICY: cos 15 | LR_MULT: 0.1 16 | # STEPS: [0, 60, 120, 160, 200] 17 | MAX_EPOCH: 200 18 | MOMENTUM: 0.9 19 | NESTEROV: True 20 | WEIGHT_DECAY: 0.0003 21 | GAMMA: 0.1 22 | TRAIN: 23 | SPLIT: train 24 | BATCH_SIZE: 1 25 | IM_SIZE: 32 26 | EVAL_PERIOD: 2 27 | TEST: 28 | SPLIT: test 29 | BATCH_SIZE: 200 30 | IM_SIZE: 32 31 | DATA_LOADER: 32 | NUM_WORKERS: 4 33 | CUDNN: 34 | BENCHMARK: True -------------------------------------------------------------------------------- /deep-al/configs/cifar10/train/RESNET18_ENS.yaml: -------------------------------------------------------------------------------- 1 | # EXP_NAME: 'SOME' 2 | RNG_SEED: 1 3 | DATASET: 4 | NAME: CIFAR10 5 | ROOT_DIR: './data' 6 | VAL_RATIO: 0.1 7 | AUG_METHOD: 'hflip' 8 | MODEL: 9 | TYPE: resnet18 10 | NUM_CLASSES: 10 11 | OPTIM: 12 | TYPE: 'sgd' 13 | BASE_LR: 0.025 14 | LR_POLICY: cos 15 | LR_MULT: 0.1 16 | # STEPS: [0, 60, 120, 160, 200] 17 | MAX_EPOCH: 200 18 | MOMENTUM: 0.9 19 | NESTEROV: True 20 | WEIGHT_DECAY: 0.0003 21 | GAMMA: 0.1 22 | TRAIN: 23 | SPLIT: train 24 | BATCH_SIZE: 96 25 | IM_SIZE: 32 26 | EVAL_PERIOD: 2 27 | TEST: 28 | SPLIT: test 29 | BATCH_SIZE: 200 30 | IM_SIZE: 32 31 | DATA_LOADER: 32 | NUM_WORKERS: 4 33 | CUDNN: 34 | BENCHMARK: True 35 | ENSEMBLE: 36 | NUM_MODELS: 3 37 | MODEL_TYPE: ['resnet18'] -------------------------------------------------------------------------------- /deep-al/configs/cifar100/al/RESNET18.yaml: -------------------------------------------------------------------------------- 1 | # EXP_NAME: 'YOUR_EXPERIMENT_NAME' 2 | RNG_SEED: 1 3 | DATASET: 4 | NAME: CIFAR100 5 | ROOT_DIR: './data' 6 | VAL_RATIO: 0.1 7 | AUG_METHOD: 'hflip' 8 | MODEL: 9 | TYPE: resnet18 10 | NUM_CLASSES: 100 11 | OPTIM: 12 | TYPE: 'sgd' 13 | BASE_LR: 0.025 14 | LR_POLICY: cos 15 | LR_MULT: 0.1 16 | # STEPS: [0, 60, 120, 160, 200] 17 | MAX_EPOCH: 200 18 | MOMENTUM: 0.9 19 | NESTEROV: True 20 | WEIGHT_DECAY: 0.0003 21 | GAMMA: 0.1 22 | TRAIN: 23 | SPLIT: train 24 | BATCH_SIZE: 100 25 | IM_SIZE: 32 26 | EVAL_PERIOD: 2 27 | TEST: 28 | SPLIT: test 29 | BATCH_SIZE: 200 30 | IM_SIZE: 32 31 | DATA_LOADER: 32 | NUM_WORKERS: 8 33 | CUDNN: 34 | BENCHMARK: True 35 | ACTIVE_LEARNING: 36 | INIT_L_RATIO: 0.1 37 | BUDGET_SIZE: 5000 38 | # SAMPLING_FN: 'uncertainty' 39 | MAX_ITER: 5 40 | DELTA_RESOLUTION: 0.05 41 | MAX_DELTA: 1.1 -------------------------------------------------------------------------------- /deep-al/configs/template.yaml: -------------------------------------------------------------------------------- 1 | # Folder name where best model logs etc are saved. 
"auto" creates a timestamp based folder 2 | EXP_NAME: 'SOME_RANDOM_NAME' 3 | # Note that non-determinism may still be present due to non-deterministic 4 | # operator implementations in GPU operator libraries 5 | RNG_SEED: 1 6 | # GPU ID you want to execute the process on 7 | GPU_ID: '3' 8 | DATASET: 9 | NAME: CIFAR10 # or CIFAR100, MNIST, SVHN, TinyImageNet 10 | ROOT_DIR: 'data' # Relative path where data should be downloaded 11 | # Specifies the proportion of data in train set that should be considered as the validation data 12 | VAL_RATIO: 0.1 13 | # Data augmentation methods - 'simclr', 'randaug', 'horizontalflip' 14 | AUG_METHOD: 'horizontalflip' 15 | MODEL: 16 | # Model type. 17 | # Choose from vgg style ['vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19_bn', 'vgg19',] 18 | # or from resnet style ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 19 | # 'wide_resnet50_2', 'wide_resnet101_2'] 20 | TYPE: resnet18 21 | NUM_CLASSES: 10 22 | OPTIM: 23 | TYPE: 'sgd' # or 'adam' 24 | BASE_LR: 0.1 25 | # Learning rate policy select from {'cos', 'exp', 'steps'} 26 | LR_POLICY: steps 27 | # Steps for 'steps' policy (in epochs) 28 | STEPS: [0] #[0, 30, 60, 90] 29 | # Training Epochs 30 | MAX_EPOCH: 1 31 | # Momentum 32 | MOMENTUM: 0.9 33 | # Nesterov Momentum 34 | NESTEROV: False 35 | # L2 regularization 36 | WEIGHT_DECAY: 0.0005 37 | # Exponential decay factor 38 | GAMMA: 0.1 39 | TRAIN: 40 | SPLIT: train 41 | # Training mini-batch size 42 | BATCH_SIZE: 256 43 | # Image size 44 | IM_SIZE: 32 45 | IM_CHANNELS = 3 46 | # Evaluate model on test data every eval period epochs 47 | EVAL_PERIOD: 1 48 | TEST: 49 | SPLIT: test 50 | # Testing mini-batch size 51 | BATCH_SIZE: 200 52 | # Image size 53 | IM_SIZE: 32 54 | # Saved model to use for testing (useful when running tools/test_model.py) 55 | MODEL_PATH: '' 56 | DATA_LOADER: 57 | NUM_WORKERS: 4 58 | CUDNN: 59 | BENCHMARK: True 60 | ACTIVE_LEARNING: 61 | # Active sampling budget (at each episode) 62 | BUDGET_SIZE: 5000 63 | # Active sampling method 64 | SAMPLING_FN: 'dbal' # 'random', 'uncertainty', 'entropy', 'margin', 'bald', 'vaal', 'coreset', 'ensemble_var_R' 65 | # Initial labeled pool ratio (% of total train set that should be labeled before AL begins) 66 | INIT_L_RATIO: 0.1 67 | # Max AL episodes 68 | MAX_ITER: 1 69 | DROPOUT_ITERATIONS: 10 # Used by DBAL 70 | # Useful when running `tools/ensemble_al.py` or `tools/ensemble_train.py` 71 | ENSEMBLE: 72 | NUM_MODELS: 3 73 | MODEL_TYPE: ['resnet18'] -------------------------------------------------------------------------------- /deep-al/configs/tinyimagenet/al/RESNET18.yaml: -------------------------------------------------------------------------------- 1 | # EXP_NAME: 'YOUR_EXPERIMENT_NAME' 2 | RNG_SEED: 1 3 | DATASET: 4 | NAME: TINYIMAGENET 5 | ROOT_DIR: '/path/to/dataset/directory/' 6 | VAL_RATIO: 0.05 7 | AUG_METHOD: 'hflip' 8 | MODEL: 9 | TYPE: resnet18 10 | NUM_CLASSES: 200 11 | OPTIM: 12 | TYPE: 'adam' 13 | BASE_LR: 0.001 14 | LR_POLICY: none 15 | LR_MULT: 0.1 16 | # STEPS: [0, 60, 120, 160, 200] 17 | MAX_EPOCH: 100 18 | MOMENTUM: 0.9 19 | NESTEROV: True 20 | WEIGHT_DECAY: 0.0003 21 | GAMMA: 0.1 22 | TRAIN: 23 | SPLIT: train 24 | BATCH_SIZE: 100 25 | IM_SIZE: 64 26 | EVAL_PERIOD: 2 27 | TEST: 28 | SPLIT: test 29 | BATCH_SIZE: 200 30 | IM_SIZE: 64 31 | DATA_LOADER: 32 | NUM_WORKERS: 4 33 | CUDNN: 34 | BENCHMARK: True 35 | ACTIVE_LEARNING: 36 | INIT_L_RATIO: 0.1 37 | BUDGET_SIZE: 5000 38 | # SAMPLING_FN: 'uncertainty' 39 
| MAX_ITER: 5 40 | DELTA_RESOLUTION: 0.05 41 | MAX_DELTA: 1.1 -------------------------------------------------------------------------------- /deep-al/docs/AL_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/deep-al/docs/AL_results.png -------------------------------------------------------------------------------- /deep-al/pycls/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/deep-al/pycls/__init__.py -------------------------------------------------------------------------------- /deep-al/pycls/al/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/deep-al/pycls/al/__init__.py -------------------------------------------------------------------------------- /deep-al/pycls/al/prob_cover.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import torch 4 | import pycls.datasets.utils as ds_utils 5 | 6 | class ProbCover: 7 | def __init__(self, cfg, lSet, uSet, budgetSize, delta): 8 | self.cfg = cfg 9 | self.ds_name = self.cfg['DATASET']['NAME'] 10 | self.seed = self.cfg['RNG_SEED'] 11 | self.all_features = ds_utils.load_features(self.ds_name, self.seed) 12 | self.lSet = lSet 13 | self.uSet = uSet 14 | self.budgetSize = budgetSize 15 | self.delta = delta 16 | self.relevant_indices = np.concatenate([self.lSet, self.uSet]).astype(int) 17 | self.rel_features = self.all_features[self.relevant_indices] 18 | self.graph_df = self.construct_graph() 19 | 20 | def construct_graph(self, batch_size=500): 21 | """ 22 | creates a directed graph where: 23 | x->y iff l2(x,y) < delta. 24 | 25 | represented by a list of edges (a sparse matrix). 26 | stored in a dataframe 27 | """ 28 | xs, ys, ds = [], [], [] 29 | print(f'Start constructing graph using delta={self.delta}') 30 | # distance computations are done in GPU 31 | cuda_feats = torch.tensor(self.rel_features).cuda() 32 | for i in range(len(self.rel_features) // batch_size): 33 | # distance comparisons are done in batches to reduce memory consumption 34 | cur_feats = cuda_feats[i * batch_size: (i + 1) * batch_size] 35 | dist = torch.cdist(cur_feats, cuda_feats) 36 | mask = dist < self.delta 37 | # saving edges using indices list - saves memory. 38 | x, y = mask.nonzero().T 39 | xs.append(x.cpu() + batch_size * i) 40 | ys.append(y.cpu()) 41 | ds.append(dist[mask].cpu()) 42 | 43 | xs = torch.cat(xs).numpy() 44 | ys = torch.cat(ys).numpy() 45 | ds = torch.cat(ds).numpy() 46 | 47 | df = pd.DataFrame({'x': xs, 'y': ys, 'd': ds}) 48 | print(f'Finished constructing graph using delta={self.delta}') 49 | print(f'Graph contains {len(df)} edges.') 50 | return df 51 | 52 | def select_samples(self): 53 | """ 54 | selecting samples using the greedy algorithm. 
55 | iteratively: 56 | - removes incoming edges to all covered samples 57 | - selects the sample with the highest out degree (covers most new samples) 58 | 59 | """ 60 | print(f'Start selecting {self.budgetSize} samples.') 61 | selected = [] 62 | # removing incoming edges to all covered samples from the existing labeled set 63 | edge_from_seen = np.isin(self.graph_df.x, np.arange(len(self.lSet))) 64 | covered_samples = self.graph_df.y[edge_from_seen].unique() 65 | cur_df = self.graph_df[(~np.isin(self.graph_df.y, covered_samples))] 66 | for i in range(self.budgetSize): 67 | coverage = len(covered_samples) / len(self.relevant_indices) 68 | # selecting the sample with the highest degree 69 | degrees = np.bincount(cur_df.x, minlength=len(self.relevant_indices)) 70 | print(f'Iteration is {i}.\tGraph has {len(cur_df)} edges.\tMax degree is {degrees.max()}.\tCoverage is {coverage:.3f}') 71 | cur = degrees.argmax() 72 | # cur = np.random.choice(degrees.argsort()[::-1][:5]) # the paper randomizes selection 73 | 74 | # removing incoming edges to newly covered samples 75 | new_covered_samples = cur_df.y[(cur_df.x == cur)].values 76 | assert len(np.intersect1d(covered_samples, new_covered_samples)) == 0, 'all samples should be new' 77 | cur_df = cur_df[(~np.isin(cur_df.y, new_covered_samples))] 78 | 79 | covered_samples = np.concatenate([covered_samples, new_covered_samples]) 80 | selected.append(cur) 81 | 82 | assert len(selected) == self.budgetSize, 'added a different number of samples' 83 | activeSet = self.relevant_indices[selected] 84 | remainSet = np.array(sorted(list(set(self.uSet) - set(activeSet)))) 85 | 86 | print(f'Finished the selection of {len(activeSet)} samples.') 87 | print(f'Active set is {activeSet}') 88 | return activeSet, remainSet 89 | -------------------------------------------------------------------------------- /deep-al/pycls/al/typiclust.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import faiss 4 | from sklearn.cluster import MiniBatchKMeans, KMeans 5 | import pycls.datasets.utils as ds_utils 6 | 7 | def get_nn(features, num_neighbors): 8 | # calculates nearest neighbors on GPU 9 | d = features.shape[1] 10 | features = features.astype(np.float32) 11 | cpu_index = faiss.IndexFlatL2(d) 12 | gpu_index = faiss.index_cpu_to_all_gpus(cpu_index) 13 | gpu_index.add(features) # add vectors to the index 14 | distances, indices = gpu_index.search(features, num_neighbors + 1) 15 | # 0 index is the same sample, dropping it 16 | return distances[:, 1:], indices[:, 1:] 17 | 18 | 19 | def get_mean_nn_dist(features, num_neighbors, return_indices=False): 20 | distances, indices = get_nn(features, num_neighbors) 21 | mean_distance = distances.mean(axis=1) 22 | if return_indices: 23 | return mean_distance, indices 24 | return mean_distance 25 | 26 | 27 | def calculate_typicality(features, num_neighbors): 28 | mean_distance = get_mean_nn_dist(features, num_neighbors) 29 | # low distance to NN is high density 30 | typicality = 1 / (mean_distance + 1e-5) 31 | return typicality 32 | 33 | 34 | def kmeans(features, num_clusters): 35 | if num_clusters <= 50: 36 | km = KMeans(n_clusters=num_clusters) 37 | km.fit_predict(features) 38 | else: 39 | km = MiniBatchKMeans(n_clusters=num_clusters, batch_size=5000) 40 | km.fit_predict(features) 41 | return km.labels_ 42 | 43 | 44 | class TypiClust: 45 | MIN_CLUSTER_SIZE = 5 46 | MAX_NUM_CLUSTERS = 500 47 | K_NN = 20 48 | 49 | def __init__(self, cfg, lSet, uSet,
budgetSize, is_scan=False): 50 | self.cfg = cfg 51 | self.ds_name = self.cfg['DATASET']['NAME'] 52 | self.seed = self.cfg['RNG_SEED'] 53 | self.features = None 54 | self.clusters = None 55 | self.lSet = lSet 56 | self.uSet = uSet 57 | self.budgetSize = budgetSize 58 | self.init_features_and_clusters(is_scan) 59 | 60 | def init_features_and_clusters(self, is_scan): 61 | num_clusters = min(len(self.lSet) + self.budgetSize, self.MAX_NUM_CLUSTERS) 62 | print(f'Clustering into {num_clusters} clusters. Scan clustering: {is_scan}') 63 | if is_scan: 64 | fname_dict = {'CIFAR10': f'../../scan/results/cifar-10/scan/features_seed{self.seed}_clusters{num_clusters}.npy', 65 | 'CIFAR100': f'../../scan/results/cifar-100/scan/features_seed{self.seed}_clusters{num_clusters}.npy', 66 | 'TINYIMAGENET': f'../../scan/results/tiny-imagenet/scan/features_seed{self.seed}_clusters{num_clusters}.npy', 67 | } 68 | fname = fname_dict[self.ds_name] 69 | self.features = np.load(fname) 70 | self.clusters = np.load(fname.replace('features', 'probs')).argmax(axis=-1) 71 | else: 72 | self.features = ds_utils.load_features(self.ds_name, self.seed) 73 | self.clusters = kmeans(self.features, num_clusters=num_clusters) 74 | print(f'Finished clustering into {num_clusters} clusters.') 75 | 76 | def select_samples(self): 77 | # using only labeled+unlabeled indices, without validation set. 78 | relevant_indices = np.concatenate([self.lSet, self.uSet]).astype(int) 79 | features = self.features[relevant_indices] 80 | labels = np.copy(self.clusters[relevant_indices]) 81 | existing_indices = np.arange(len(self.lSet)) 82 | # counting cluster sizes and number of labeled samples per cluster 83 | cluster_ids, cluster_sizes = np.unique(labels, return_counts=True) 84 | cluster_labeled_counts = np.bincount(labels[existing_indices], minlength=len(cluster_ids)) 85 | clusters_df = pd.DataFrame({'cluster_id': cluster_ids, 'cluster_size': cluster_sizes, 'existing_count': cluster_labeled_counts, 86 | 'neg_cluster_size': -1 * cluster_sizes}) 87 | # drop too small clusters 88 | clusters_df = clusters_df[clusters_df.cluster_size > self.MIN_CLUSTER_SIZE] 89 | # sort clusters by lowest number of existing samples, and then by cluster sizes (large to small) 90 | clusters_df = clusters_df.sort_values(['existing_count', 'neg_cluster_size']) 91 | labels[existing_indices] = -1 92 | 93 | selected = [] 94 | 95 | for i in range(self.budgetSize): 96 | cluster = clusters_df.iloc[i % len(clusters_df)].cluster_id 97 | indices = (labels == cluster).nonzero()[0] 98 | rel_feats = features[indices] 99 | # in case we have a too-small cluster, calculate density among half of the cluster 100 | typicality = calculate_typicality(rel_feats, min(self.K_NN, len(indices) // 2)) 101 | idx = indices[typicality.argmax()] 102 | selected.append(idx) 103 | labels[idx] = -1 104 | 105 | selected = np.array(selected) 106 | assert len(selected) == self.budgetSize, 'added a different number of samples' 107 | assert len(np.intersect1d(selected, existing_indices)) == 0, 'should be new samples' 108 | activeSet = relevant_indices[selected] 109 | remainSet = np.array(sorted(list(set(self.uSet) - set(activeSet)))) 110 | 111 | print(f'Finished the selection of {len(activeSet)} samples.') 112 | print(f'Active set is {activeSet}') 113 | return activeSet, remainSet 114 | -------------------------------------------------------------------------------- /deep-al/pycls/core/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/deep-al/pycls/core/__init__.py -------------------------------------------------------------------------------- /deep-al/pycls/core/builders.py: -------------------------------------------------------------------------------- 1 | # This file is modified from the official pycls repository 2 | 3 | """Model and loss construction functions.""" 4 | 5 | from pycls.core.net import SoftCrossEntropyLoss 6 | from pycls.models.resnet import * 7 | from pycls.models.vgg import * 8 | from pycls.models.alexnet import * 9 | 10 | import torch 11 | from torch import nn 12 | from torch.nn import functional as F 13 | # Supported models 14 | _models = { 15 | # VGG style architectures 16 | 'vgg11': vgg11, 17 | 'vgg11_bn': vgg11_bn, 18 | 'vgg13': vgg13, 19 | 'vgg13_bn': vgg13_bn, 20 | 'vgg16': vgg16, 21 | 'vgg16_bn': vgg16_bn, 22 | 'vgg19': vgg19, 23 | 'vgg19_bn': vgg19_bn, 24 | 25 | # ResNet style architectures 26 | 'resnet18': resnet18, 27 | 'resnet34': resnet34, 28 | 'resnet50': resnet50, 29 | 'resnet101': resnet101, 30 | 'resnet152': resnet152, 31 | 'resnext50_32x4d': resnext50_32x4d, 32 | 'resnext101_32x8d': resnext101_32x8d, 33 | 'wide_resnet50_2': wide_resnet50_2, 34 | 'wide_resnet101_2': wide_resnet101_2, 35 | 36 | # AlexNet architecture 37 | 'alexnet': alexnet 38 | } 39 | 40 | # Supported loss functions 41 | _loss_funs = {"cross_entropy": SoftCrossEntropyLoss} 42 | 43 | 44 | class FeaturesNet(nn.Module): 45 | def __init__(self, in_layers, out_layers, use_mlp=False, penultimate_active=False): 46 | super().__init__() 47 | self.use_mlp = use_mlp 48 | self.penultimate_active = penultimate_active 49 | self.lin1 = nn.Linear(in_layers, in_layers) 50 | self.lin2 = nn.Linear(in_layers, in_layers) 51 | self.final = nn.Linear(in_layers, out_layers) 52 | 53 | def forward(self, x): 54 | feats = x 55 | if self.use_mlp: 56 | x = F.relu(self.lin1(x)) 57 | x = F.relu((self.lin2(x))) 58 | out = self.final(x) 59 | if self.penultimate_active: 60 | return feats, out 61 | return out 62 | 63 | 64 | def get_model(cfg): 65 | """Gets the model class specified in the config.""" 66 | err_str = "Model type '{}' not supported" 67 | assert cfg.MODEL.TYPE in _models.keys(), err_str.format(cfg.MODEL.TYPE) 68 | return _models[cfg.MODEL.TYPE] 69 | 70 | 71 | def get_loss_fun(cfg): 72 | """Gets the loss function class specified in the config.""" 73 | err_str = "Loss function type '{}' not supported" 74 | assert cfg.MODEL.LOSS_FUN in _loss_funs.keys(), err_str.format(cfg.MODEL.LOSS_FUN) 75 | return _loss_funs[cfg.MODEL.LOSS_FUN] 76 | 77 | 78 | def build_model(cfg): 79 | """Builds the model.""" 80 | if cfg.MODEL.LINEAR_FROM_FEATURES: 81 | num_features = 384 if cfg.DATASET.NAME in ['IMAGENET50', 'IMAGENET100', 'IMAGENET200'] else 512 82 | return FeaturesNet(num_features, cfg.MODEL.NUM_CLASSES) 83 | 84 | model = get_model(cfg)(num_classes=cfg.MODEL.NUM_CLASSES, use_dropout=True) 85 | if cfg.DATASET.NAME == 'MNIST': 86 | model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) 87 | 88 | return model 89 | 90 | 91 | def build_loss_fun(cfg): 92 | """Build the loss function.""" 93 | return get_loss_fun(cfg)() 94 | 95 | 96 | def register_model(name, ctor): 97 | """Registers a model dynamically.""" 98 | _models[name] = ctor 99 | 100 | 101 | def register_loss_fun(name, ctor): 102 | """Registers a loss function dynamically.""" 103 | _loss_funs[name] = ctor 104 |
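A short, hedged usage sketch of the registry pattern above (the `MyNet` model and the config assignments are hypothetical; `build_loss_fun` assumes the default config defines `MODEL.LOSS_FUN`):

```python
# Hedged sketch: registering a custom architecture and building it from the config.
import torch.nn as nn
import pycls.core.builders as builders
from pycls.core.config import cfg

class MyNet(nn.Module):  # hypothetical toy model, not part of the repo
    def __init__(self, num_classes=10, use_dropout=False):
        super().__init__()
        self.head = nn.Linear(512, num_classes)

    def forward(self, x):
        return self.head(x)

builders.register_model('mynet', MyNet)   # 'mynet' becomes a valid MODEL.TYPE
cfg.MODEL.TYPE = 'mynet'
cfg.MODEL.NUM_CLASSES = 10
model = builders.get_model(cfg)(num_classes=cfg.MODEL.NUM_CLASSES)
loss_fun = builders.build_loss_fun(cfg)   # SoftCrossEntropyLoss by default
```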
-------------------------------------------------------------------------------- /deep-al/pycls/core/losses.py: -------------------------------------------------------------------------------- 1 | """Loss functions.""" 2 | 3 | import torch.nn as nn 4 | 5 | from pycls.core.config import cfg 6 | 7 | # Supported loss functions 8 | _loss_funs = { 9 | 'cross_entropy': nn.CrossEntropyLoss, 10 | } 11 | 12 | def get_loss_fun(): 13 | """Retrieves the loss function.""" 14 | assert cfg.MODEL.LOSS_FUN in _loss_funs.keys(), \ 15 | 'Loss function \'{}\' not supported'.format(cfg.MODEL.LOSS_FUN) 16 | return _loss_funs[cfg.MODEL.LOSS_FUN]().cuda() 17 | 18 | 19 | def register_loss_fun(name, ctor): 20 | """Registers a loss function dynamically.""" 21 | _loss_funs[name] = ctor 22 | -------------------------------------------------------------------------------- /deep-al/pycls/core/net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """Functions for manipulating networks.""" 9 | 10 | import numpy as np 11 | # import pycls.core.distributed as dist 12 | import torch 13 | from pycls.core.config import cfg 14 | 15 | 16 | def unwrap_model(model): 17 | """Remove the DistributedDataParallel wrapper if present.""" 18 | wrapped = isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel) 19 | return model.module if wrapped else model 20 | 21 | 22 | # @torch.no_grad() 23 | # def compute_precise_bn_stats(model, loader): 24 | # """Computes precise BN stats on training data.""" 25 | # # Compute the number of minibatches to use 26 | # num_iter = int(cfg.BN.NUM_SAMPLES_PRECISE / loader.batch_size / cfg.NUM_GPUS) 27 | # num_iter = min(num_iter, len(loader)) 28 | # # Retrieve the BN layers 29 | # bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)] 30 | # # Initialize BN stats storage for computing mean(mean(batch)) and mean(var(batch)) 31 | # running_means = [torch.zeros_like(bn.running_mean) for bn in bns] 32 | # running_vars = [torch.zeros_like(bn.running_var) for bn in bns] 33 | # # Remember momentum values 34 | # momentums = [bn.momentum for bn in bns] 35 | # # Set momentum to 1.0 to compute BN stats that only reflect the current batch 36 | # for bn in bns: 37 | # bn.momentum = 1.0 38 | # # Average the BN stats for each BN layer over the batches 39 | # for inputs, _labels in itertools.islice(loader, num_iter): 40 | # model(inputs.cuda()) 41 | # for i, bn in enumerate(bns): 42 | # running_means[i] += bn.running_mean / num_iter 43 | # running_vars[i] += bn.running_var / num_iter 44 | # # Sync BN stats across GPUs (no reduction if 1 GPU used) 45 | # running_means = dist.scaled_all_reduce(running_means) 46 | # running_vars = dist.scaled_all_reduce(running_vars) 47 | # # Set BN stats and restore original momentum values 48 | # for i, bn in enumerate(bns): 49 | # bn.running_mean = running_means[i] 50 | # bn.running_var = running_vars[i] 51 | # bn.momentum = momentums[i] 52 | 53 | 54 | def complexity(model): 55 | """Compute model complexity (model can be model instance or model class).""" 56 | size = cfg.TRAIN.IM_SIZE 57 | cx = {"h": size, "w": size, "flops": 0, "params": 0, "acts": 0} 58 | cx = unwrap_model(model).complexity(cx) 59 | return {"flops": cx["flops"], "params": cx["params"], "acts": cx["acts"]} 60 | 61 | 62 | def
smooth_one_hot_labels(labels): 63 | """Convert each label to a one-hot vector.""" 64 | n_classes, label_smooth = cfg.MODEL.NUM_CLASSES, cfg.TRAIN.LABEL_SMOOTHING 65 | err_str = "Invalid input to one_hot_vector()" 66 | assert labels.ndim == 1 and labels.max() < n_classes, err_str 67 | shape = (labels.shape[0], n_classes) 68 | neg_val = label_smooth / n_classes 69 | pos_val = 1.0 - label_smooth + neg_val 70 | labels_one_hot = torch.full(shape, neg_val, dtype=torch.float, device=labels.device) 71 | labels_one_hot.scatter_(1, labels.long().view(-1, 1), pos_val) 72 | return labels_one_hot 73 | 74 | 75 | class SoftCrossEntropyLoss(torch.nn.Module): 76 | """SoftCrossEntropyLoss (useful for label smoothing and mixup). 77 | Identical to torch.nn.CrossEntropyLoss if used with one-hot labels.""" 78 | 79 | def __init__(self): 80 | super(SoftCrossEntropyLoss, self).__init__() 81 | 82 | def forward(self, x, y): 83 | loss = -y * torch.nn.functional.log_softmax(x, -1) 84 | return torch.sum(loss) / x.shape[0] 85 | 86 | 87 | def mixup(inputs, labels): 88 | """Apply mixup to minibatch (https://arxiv.org/abs/1710.09412).""" 89 | alpha = cfg.TRAIN.MIXUP_ALPHA 90 | assert labels.shape[1] == cfg.MODEL.NUM_CLASSES, "mixup labels must be one-hot" 91 | if alpha > 0: 92 | m = np.random.beta(alpha, alpha) 93 | permutation = torch.randperm(labels.shape[0]) 94 | inputs = m * inputs + (1.0 - m) * inputs[permutation, :] 95 | labels = m * labels + (1.0 - m) * labels[permutation, :] 96 | return inputs, labels, labels.argmax(1) 97 | -------------------------------------------------------------------------------- /deep-al/pycls/core/optimizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """Optimizer.""" 9 | 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | import torch 13 | 14 | 15 | def construct_optimizer(cfg, model): 16 | """Constructs the optimizer. 17 | 18 | Note that the momentum update in PyTorch differs from the one in Caffe2. 19 | In particular, 20 | 21 | Caffe2: 22 | V := mu * V + lr * g 23 | p := p - V 24 | 25 | PyTorch: 26 | V := mu * V + g 27 | p := p - lr * V 28 | 29 | where V is the velocity, mu is the momentum factor, lr is the learning rate, 30 | g is the gradient and p are the parameters. 31 | 32 | Since V is defined independently of the learning rate in PyTorch, 33 | when the learning rate is changed there is no need to perform the 34 | momentum correction by scaling V (unlike in the Caffe2 case). 35 | """ 36 | if cfg.BN.USE_CUSTOM_WEIGHT_DECAY: 37 | # Apply different weight decay to Batchnorm and non-batchnorm parameters. 
38 | p_bn = [p for n, p in model.named_parameters() if "bn" in n] 39 | p_non_bn = [p for n, p in model.named_parameters() if "bn" not in n] 40 | optim_params = [ 41 | {"params": p_bn, "weight_decay": cfg.BN.CUSTOM_WEIGHT_DECAY}, 42 | {"params": p_non_bn, "weight_decay": cfg.OPTIM.WEIGHT_DECAY}, 43 | ] 44 | else: 45 | optim_params = model.parameters() 46 | 47 | if cfg.OPTIM.TYPE == 'sgd': 48 | optimizer = torch.optim.SGD( 49 | optim_params,  # use the parameter groups built above (previously model.parameters(), which ignored them) 50 | lr=cfg.OPTIM.BASE_LR, 51 | momentum=cfg.OPTIM.MOMENTUM, 52 | weight_decay=cfg.OPTIM.WEIGHT_DECAY, 53 | dampening=cfg.OPTIM.DAMPENING, 54 | nesterov=cfg.OPTIM.NESTEROV 55 | ) 56 | elif cfg.OPTIM.TYPE == 'adam': 57 | optimizer = torch.optim.Adam( 58 | optim_params,  # same fix as above 59 | lr=cfg.OPTIM.BASE_LR, 60 | weight_decay=cfg.OPTIM.WEIGHT_DECAY 61 | ) 62 | else: 63 | raise NotImplementedError 64 | 65 | return optimizer 66 | 67 | 68 | def lr_fun_steps(cfg, cur_epoch): 69 | """Steps schedule (cfg.OPTIM.LR_POLICY = 'steps').""" 70 | ind = [i for i, s in enumerate(cfg.OPTIM.STEPS) if cur_epoch >= s][-1] 71 | return cfg.OPTIM.LR_MULT ** ind 72 | 73 | 74 | def lr_fun_exp(cfg, cur_epoch): 75 | """Exponential schedule (cfg.OPTIM.LR_POLICY = 'exp').""" 76 | return cfg.OPTIM.MIN_LR ** (cur_epoch / cfg.OPTIM.MAX_EPOCH) 77 | 78 | 79 | def lr_fun_cos(cfg, cur_epoch): 80 | """Cosine schedule (cfg.OPTIM.LR_POLICY = 'cos').""" 81 | lr = 0.5 * (1.0 + np.cos(np.pi * cur_epoch / cfg.OPTIM.MAX_EPOCH)) 82 | return (1.0 - cfg.OPTIM.MIN_LR) * lr + cfg.OPTIM.MIN_LR 83 | 84 | 85 | def lr_fun_lin(cfg, cur_epoch): 86 | """Linear schedule (cfg.OPTIM.LR_POLICY = 'lin').""" 87 | lr = 1.0 - cur_epoch / cfg.OPTIM.MAX_EPOCH 88 | return (1.0 - cfg.OPTIM.MIN_LR) * lr + cfg.OPTIM.MIN_LR 89 | 90 | 91 | def lr_fun_none(cfg, cur_epoch): 92 | """No schedule (cfg.OPTIM.LR_POLICY = 'none').""" 93 | return 1 94 | 95 | 96 | def get_lr_fun(cfg): 97 | """Retrieves the specified lr policy function""" 98 | lr_fun = "lr_fun_" + cfg.OPTIM.LR_POLICY 99 | assert lr_fun in globals(), "Unknown LR policy: " + cfg.OPTIM.LR_POLICY 100 | err_str = "exp lr policy requires OPTIM.MIN_LR to be greater than 0."
101 | assert cfg.OPTIM.LR_POLICY != "exp" or cfg.OPTIM.MIN_LR > 0, err_str 102 | return globals()[lr_fun] 103 | 104 | 105 | def get_epoch_lr(cfg, cur_epoch): 106 | """Retrieves the lr for the given epoch according to the policy.""" 107 | # Get lr and scale by BASE_LR 108 | lr = get_lr_fun(cfg)(cfg, cur_epoch) * cfg.OPTIM.BASE_LR 109 | # Linear warmup 110 | if cur_epoch < cfg.OPTIM.WARMUP_EPOCHS and 'none' not in cfg.OPTIM.LR_POLICY: 111 | alpha = cur_epoch / cfg.OPTIM.WARMUP_EPOCHS 112 | warmup_factor = cfg.OPTIM.WARMUP_FACTOR * (1.0 - alpha) + alpha 113 | lr *= warmup_factor 114 | return lr 115 | 116 | 117 | def set_lr(optimizer, new_lr): 118 | """Sets the optimizer lr to the specified value.""" 119 | for param_group in optimizer.param_groups: 120 | param_group["lr"] = new_lr 121 | 122 | 123 | def plot_lr_fun(cfg): 124 | """Visualizes lr function.""" 125 | epochs = list(range(cfg.OPTIM.MAX_EPOCH)) 126 | lrs = [get_epoch_lr(cfg, epoch) for epoch in epochs] 127 | plt.plot(epochs, lrs, ".-") 128 | plt.title("lr_policy: {}".format(cfg.OPTIM.LR_POLICY)) 129 | plt.xlabel("epochs") 130 | plt.ylabel("learning rate") 131 | plt.ylim(bottom=0) 132 | plt.show() 133 | -------------------------------------------------------------------------------- /deep-al/pycls/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/deep-al/pycls/datasets/__init__.py -------------------------------------------------------------------------------- /deep-al/pycls/datasets/custom_datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from PIL import Image 4 | import numpy as np 5 | import pycls.datasets.utils as ds_utils 6 | 7 | 8 | class CIFAR10(torchvision.datasets.CIFAR10): 9 | def __init__(self, root, train, transform, test_transform, download=True, only_features= False): 10 | super(CIFAR10, self).__init__(root, train, transform=transform, download=download) 11 | self.test_transform = test_transform 12 | self.no_aug = False 13 | self.only_features = only_features 14 | self.features = ds_utils.load_features("CIFAR10", train=train, normalized=False) 15 | 16 | 17 | def __getitem__(self, index: int): 18 | """ 19 | Args: 20 | index (int): Index 21 | 22 | Returns: 23 | tuple: (image, target) where target is index of the target class.
24 | """ 25 | img, target = self.data[index], self.targets[index] 26 | 27 | # doing this so that it is consistent with all other datasets 28 | # to return a PIL Image 29 | img = Image.fromarray(img) 30 | if self.only_features: 31 | img = self.features[index] 32 | else: 33 | if self.no_aug: 34 | if self.test_transform is not None: 35 | img = self.test_transform(img) 36 | else: 37 | if self.transform is not None: 38 | img = self.transform(img) 39 | 40 | 41 | return img, target 42 | 43 | 44 | class CIFAR100(torchvision.datasets.CIFAR100): 45 | def __init__(self, root, train, transform, test_transform, download=True, only_features= False): 46 | super(CIFAR100, self).__init__(root, train, transform=transform, download=download) 47 | self.test_transform = test_transform 48 | self.no_aug = False 49 | self.only_features = only_features 50 | self.features = ds_utils.load_features("CIFAR100", train=train, normalized=False) 51 | 52 | def __getitem__(self, index: int): 53 | """ 54 | Args: 55 | index (int): Index 56 | 57 | Returns: 58 | tuple: (image, target) where target is index of the target class. 59 | """ 60 | img, target = self.data[index], self.targets[index] 61 | 62 | # doing this so that it is consistent with all other datasets 63 | # to return a PIL Image 64 | img = Image.fromarray(img) 65 | if self.only_features: 66 | img = self.features[index] 67 | else: 68 | if self.no_aug: 69 | if self.test_transform is not None: 70 | img = self.test_transform(img) 71 | else: 72 | if self.transform is not None: 73 | img = self.transform(img) 74 | 75 | return img, target 76 | 77 | 78 | class STL10(torchvision.datasets.STL10): 79 | def __init__(self, root, train, transform, test_transform, download=True): 80 | super(STL10, self).__init__(root, train, transform=transform, download=download) 81 | self.test_transform = test_transform 82 | self.no_aug = False 83 | self.targets = self.labels 84 | 85 | def __getitem__(self, index: int): 86 | """ 87 | Args: 88 | index (int): Index 89 | 90 | Returns: 91 | tuple: (image, target) where target is index of the target class. 92 | """ 93 | img, target = self.data[index], int(self.targets[index]) 94 | 95 | # doing this so that it is consistent with all other datasets 96 | # to return a PIL Image 97 | img = Image.fromarray(img.transpose(1,2,0)) 98 | 99 | if self.no_aug: 100 | if self.test_transform is not None: 101 | img = self.test_transform(img) 102 | else: 103 | if self.transform is not None: 104 | img = self.transform(img) 105 | 106 | return img, target 107 | 108 | 109 | class MNIST(torchvision.datasets.MNIST): 110 | def __init__(self, root, train, transform, test_transform, download=True): 111 | super(MNIST, self).__init__(root, train, transform=transform, download=download) 112 | self.test_transform = test_transform 113 | self.no_aug = False 114 | 115 | def __getitem__(self, index: int): 116 | """ 117 | Args: 118 | index (int): Index 119 | 120 | Returns: 121 | tuple: (image, target) where target is index of the target class. 
122 | """ 123 | img, target = self.data[index], int(self.targets[index]) 124 | 125 | # doing this so that it is consistent with all other datasets 126 | # to return a PIL Image 127 | img = Image.fromarray(img.numpy(), mode='L') 128 | 129 | if self.no_aug: 130 | if self.test_transform is not None: 131 | img = self.test_transform(img) 132 | else: 133 | if self.transform is not None: 134 | img = self.transform(img) 135 | 136 | 137 | return img, target 138 | 139 | 140 | class SVHN(torchvision.datasets.SVHN): 141 | def __init__(self, root, train, transform, test_transform, download=True): 142 | super(SVHN, self).__init__(root, train, transform=transform, download=download) 143 | self.test_transform = test_transform 144 | self.no_aug = False 145 | 146 | def __getitem__(self, index: int): 147 | """ 148 | Args: 149 | index (int): Index 150 | 151 | Returns: 152 | tuple: (image, target) where target is index of the target class. 153 | """ 154 | img, target = self.data[index], self.targets[index] 155 | 156 | # doing this so that it is consistent with all other datasets 157 | # to return a PIL Image 158 | img = Image.fromarray(img) 159 | 160 | if self.no_aug: 161 | if self.test_transform is not None: 162 | img = self.test_transform(img) 163 | else: 164 | if self.transform is not None: 165 | img = self.transform(img) 166 | 167 | 168 | return img, target -------------------------------------------------------------------------------- /deep-al/pycls/datasets/imbalanced_cifar.py: -------------------------------------------------------------------------------- 1 | """ 2 | Credits: Kaihua Tang 3 | Source: https://github.com/KaihuaTang/Long-Tailed-Recognition.pytorch/ 4 | """ 5 | 6 | import numpy as np 7 | from pycls.datasets.custom_datasets import CIFAR10, CIFAR100 8 | 9 | class IMBALANCECIFAR10(CIFAR10): 10 | cls_num = 10 11 | np.random.seed(1) 12 | def __init__(self, root, train, transform=None, test_transform=None, imbalance_ratio=0.02, imb_type='exp'): 13 | super(IMBALANCECIFAR10, self).__init__(root, train, transform=transform, test_transform=test_transform, download=True) 14 | self.train = train 15 | self.transform = transform 16 | if self.train: 17 | img_num_list = self.get_img_num_per_cls(self.cls_num, imb_type, imbalance_ratio) 18 | self.gen_imbalanced_data(img_num_list) 19 | phase = 'Train' 20 | else: 21 | phase = 'Test' 22 | self.labels = self.targets 23 | 24 | print("{} Mode: Contain {} images".format(phase, len(self.data))) 25 | 26 | def _get_class_dict(self): 27 | class_dict = dict() 28 | for i, anno in enumerate(self.get_annotations()): 29 | cat_id = anno["category_id"] 30 | if not cat_id in class_dict: 31 | class_dict[cat_id] = [] 32 | class_dict[cat_id].append(i) 33 | return class_dict 34 | 35 | 36 | def get_img_num_per_cls(self, cls_num, imb_type, imb_factor): 37 | img_max = len(self.data) / cls_num 38 | img_num_per_cls = [] 39 | if imb_type == 'exp': 40 | for cls_idx in range(cls_num): 41 | num = img_max * (imb_factor**(cls_idx / (cls_num - 1.0))) 42 | img_num_per_cls.append(int(num)) 43 | elif imb_type == 'step': 44 | for cls_idx in range(cls_num // 2): 45 | img_num_per_cls.append(int(img_max)) 46 | for cls_idx in range(cls_num // 2): 47 | img_num_per_cls.append(int(img_max * imb_factor)) 48 | else: 49 | img_num_per_cls.extend([int(img_max)] * cls_num) 50 | return img_num_per_cls 51 | 52 | def gen_imbalanced_data(self, img_num_per_cls): 53 | 54 | new_data = [] 55 | new_targets = [] 56 | targets_np = np.array(self.targets, dtype=np.int64) 57 | classes = np.unique(targets_np) 58 | 59 | 
self.num_per_cls_dict = dict() 60 | for the_class, the_img_num in zip(classes, img_num_per_cls): 61 | self.num_per_cls_dict[the_class] = the_img_num 62 | idx = np.where(targets_np == the_class)[0] 63 | np.random.shuffle(idx) 64 | selec_idx = idx[:the_img_num] 65 | new_data.append(self.data[selec_idx, ...]) 66 | new_targets.extend([the_class, ] * the_img_num) 67 | new_data = np.vstack(new_data) 68 | self.data = new_data 69 | self.targets = new_targets 70 | 71 | def __len__(self): 72 | return len(self.labels) 73 | 74 | def get_num_classes(self): 75 | return self.cls_num 76 | 77 | def get_annotations(self): 78 | annos = [] 79 | for label in self.labels: 80 | annos.append({'category_id': int(label)}) 81 | return annos 82 | 83 | def get_cls_num_list(self): 84 | cls_num_list = [] 85 | for i in range(self.cls_num): 86 | cls_num_list.append(self.num_per_cls_dict[i]) 87 | return cls_num_list 88 | 89 | 90 | class IMBALANCECIFAR100(CIFAR100): 91 | """`CIFAR100 `_ Dataset. 92 | This is a subclass of the `CIFAR10` Dataset. 93 | """ 94 | cls_num = 100 95 | base_folder = 'cifar-100-python' 96 | url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" 97 | filename = "cifar-100-python.tar.gz" 98 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' 99 | train_list = [ 100 | ['train', '16019d7e3df5f24257cddd939b257f8d'], 101 | ] 102 | 103 | test_list = [ 104 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], 105 | ] 106 | meta = { 107 | 'filename': 'meta', 108 | 'key': 'fine_label_names', 109 | 'md5': '7973b15100ade9c7d40fb424638fde48', 110 | } -------------------------------------------------------------------------------- /deep-al/pycls/datasets/sampler.py: -------------------------------------------------------------------------------- 1 | # This file is directly taken from a code implementation shared with me by Prateek Munjal et al., authors of the paper https://arxiv.org/abs/2002.09564 2 | # GitHub: https://github.com/PrateekMunjal 3 | # ---------------------------------------------------------- 4 | 5 | from torch.utils.data.sampler import Sampler 6 | 7 | 8 | class IndexedSequentialSampler(Sampler): 9 | r"""Samples elements sequentially, always in the same order. 10 | 11 | Arguments: 12 | data_idxes (Dataset indexes): dataset indexes to sample from 13 | """ 14 | 15 | def __init__(self, data_idxes, isDebug=False): 16 | if isDebug: print("========= my custom squential sampler =========") 17 | self.data_idxes = data_idxes 18 | 19 | def __iter__(self): 20 | return (self.data_idxes[i] for i in range(len(self.data_idxes))) 21 | 22 | def __len__(self): 23 | return len(self.data_idxes) 24 | 25 | # class IndexedDistributedSampler(Sampler): 26 | # """Sampler that restricts data loading to a particular index set of dataset. 27 | 28 | # It is especially useful in conjunction with 29 | # :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 30 | # process can pass a DistributedSampler instance as a DataLoader sampler, 31 | # and load a subset of the original dataset that is exclusive to it. 32 | 33 | # .. note:: 34 | # Dataset is assumed to be of constant size. 35 | 36 | # Arguments: 37 | # dataset: Dataset used for sampling. 38 | # num_replicas (optional): Number of processes participating in 39 | # distributed training. 40 | # rank (optional): Rank of the current process within num_replicas. 
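# --- Editor's sketch (illustrative, not part of the repository): using the
# IndexedSequentialSampler above to iterate a fixed index set (e.g. the
# current labeled pool) in a deterministic order. The dataset is a stand-in.
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from pycls.datasets.sampler import IndexedSequentialSampler

dataset = TensorDataset(torch.randn(100, 3), torch.zeros(100))
labeled_idx = np.array([3, 7, 42])  # indices an AL method has selected
loader = DataLoader(dataset, batch_size=2, sampler=IndexedSequentialSampler(labeled_idx))
for x, y in loader:  # visits exactly rows 3, 7 and 42, in that order
    pass
# --- end sketch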
41 | # """ 42 | 43 | # def __init__(self, dataset, index_set, num_replicas=None, rank=None, allowRepeat=True): 44 | # if num_replicas is None: 45 | # if not dist.is_available(): 46 | # raise RuntimeError("Requires distributed package to be available") 47 | # num_replicas = dist.get_world_size() 48 | # if rank is None: 49 | # if not dist.is_available(): 50 | # raise RuntimeError("Requires distributed package to be available") 51 | # rank = dist.get_rank() 52 | # self.dataset = dataset 53 | # self.num_replicas = num_replicas 54 | # self.rank = rank 55 | # self.epoch = 0 56 | # self.index_set = index_set 57 | # self.allowRepeat = allowRepeat 58 | # if self.allowRepeat: 59 | # self.num_samples = int(math.ceil(len(self.index_set) * 1.0 / self.num_replicas)) 60 | # self.total_size = self.num_samples * self.num_replicas 61 | # else: 62 | # self.num_samples = int(math.ceil((len(self.index_set)-self.rank) * 1.0 / self.num_replicas)) 63 | # self.total_size = len(self.index_set) 64 | 65 | # def __iter__(self): 66 | # # deterministically shuffle based on epoch 67 | # g = torch.Generator() 68 | # g.manual_seed(self.epoch) 69 | # indices = torch.randperm(len(self.index_set), generator=g).tolist() 70 | # #To access valid indices 71 | # #indices = self.index_set[indices] 72 | # # add extra samples to make it evenly divisible 73 | # if self.allowRepeat: 74 | # indices += indices[:(self.total_size - len(indices))] 75 | 76 | # assert len(indices) == self.total_size 77 | 78 | # # subsample 79 | # indices = self.index_set[indices[self.rank:self.total_size:self.num_replicas]] 80 | # assert len(indices) == self.num_samples, "len(indices): {} and self.num_samples: {}"\ 81 | # .format(len(indices), self.num_samples) 82 | 83 | # return iter(indices) 84 | 85 | # def __len__(self): 86 | # return self.num_samples 87 | 88 | # def set_epoch(self, epoch): 89 | # self.epoch = epoch 90 | -------------------------------------------------------------------------------- /deep-al/pycls/datasets/simclr_augment.py: -------------------------------------------------------------------------------- 1 | # Modified from the source: https://github.com/sthalles/PyTorch-BYOL 2 | # Previous owner of this file: Thalles Silva 3 | 4 | from torchvision import transforms 5 | from .utils.gaussian_blur import GaussianBlur 6 | 7 | def get_simclr_ops(input_shape, s=1): 8 | # get a set of data augmentation transformations as described in the SimCLR paper. 9 | color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s) 10 | ops = [transforms.RandomHorizontalFlip(), 11 | transforms.RandomApply([color_jitter], p=0.8), 12 | transforms.RandomGrayscale(p=0.2), 13 | GaussianBlur(kernel_size=int(0.1 * input_shape)),] 14 | return ops 15 | 16 | 17 | -------------------------------------------------------------------------------- /deep-al/pycls/datasets/tiny_imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | 5 | import torch 6 | import torchvision.datasets as datasets 7 | 8 | from typing import Any 9 | 10 | 11 | class TinyImageNet(datasets.ImageFolder): 12 | """`Tiny ImageNet Classification Dataset. 13 | 14 | Args: 15 | root (string): Root directory of the ImageNet Dataset. 16 | split (string, optional): The dataset split, supports ``train``, or ``val``. 17 | transform (callable, optional): A function/transform that takes in an PIL image 18 | and returns a transformed version. 
E.g, ``transforms.RandomCrop`` 19 | target_transform (callable, optional): A function/transform that takes in the 20 | target and transforms it. 21 | loader (callable, optional): A function to load an image given its path. 22 | 23 | Attributes: 24 | classes (list): List of the class name tuples. 25 | class_to_idx (dict): Dict with items (class_name, class_index). 26 | wnids (list): List of the WordNet IDs. 27 | wnid_to_idx (dict): Dict with items (wordnet_id, class_index). 28 | samples (list): List of (image path, class_index) tuples 29 | targets (list): The class_index value for each image in the dataset 30 | """ 31 | def __init__(self, root: str, split: str = 'train', transform=None, test_transform=None, only_features=False, **kwargs: Any) -> None: 32 | self.root = root 33 | self.test_transform = test_transform 34 | self.no_aug = False 35 | 36 | assert self.check_root(), "Something is wrong with the Tiny ImageNet dataset path. Download the official dataset zip from http://cs231n.stanford.edu/tiny-imagenet-200.zip and unzip it inside {}.".format(self.root) 37 | self.split = datasets.utils.verify_str_arg(split, "split", ("train", "val")) 38 | wnid_to_classes = self.load_wnid_to_classes() 39 | 40 | super(TinyImageNet, self).__init__(self.split_folder, **kwargs) 41 | self.transform = transform 42 | self.wnids = self.classes 43 | self.wnid_to_idx = self.class_to_idx 44 | self.classes = [wnid_to_classes[wnid] for wnid in self.wnids] 45 | self.class_to_idx = {cls: idx 46 | for idx, clss in enumerate(self.classes) 47 | for cls in clss} 48 | # Tiny ImageNet val directory structure is not similar to that of train's 49 | # So a custom loading function is necessary 50 | self.only_features = only_features 51 | if self.split == 'train': 52 | self.features = np.load('../../scan/results/tiny-imagenet/pretext/features_seed1.npy') 53 | elif self.split == 'val': 54 | self.features = np.load('../../scan/results/tiny-imagenet/pretext/test_features_seed1.npy') 55 | self.root = root 56 | self.imgs, self.targets = self.load_val_data() 57 | self.samples = [(self.imgs[idx], self.targets[idx]) for idx in range(len(self.imgs))] 58 | self.root = os.path.join(self.root, 'val') 59 | 60 | 61 | 62 | # Split folder is used for the 'super' call. 
Since val directory is not structured like the train, 63 | # we simply use train's structure to get all classes and other stuff 64 | @property 65 | def split_folder(self) -> str: 66 | return os.path.join(self.root, 'train') 67 | 68 | 69 | def load_val_data(self): 70 | imgs, targets = [], [] 71 | with open(os.path.join(self.root, 'val', 'val_annotations.txt'), 'r') as file: 72 | for line in file: 73 | if line.split()[1] in self.wnids: 74 | img_file, wnid = line.split('\t')[:2] 75 | imgs.append(os.path.join(self.root, 'val', 'images', img_file)) 76 | targets.append(wnid) 77 | targets = np.array([self.wnid_to_idx[wnid] for wnid in targets]) 78 | return imgs, targets 79 | 80 | 81 | def load_wnid_to_classes(self): 82 | wnid_to_classes = {} 83 | with open(os.path.join(self.root, 'words.txt'), 'r') as file: 84 | lines = file.readlines() 85 | lines = [x.split('\t') for x in lines] 86 | wnid_to_classes = {x[0]:x[1].strip() for x in lines} 87 | return wnid_to_classes 88 | 89 | def check_root(self): 90 | tinyim_set = ['words.txt', 'wnids.txt', 'train', 'val', 'test'] 91 | for x in os.scandir(self.root): 92 | if x.name not in tinyim_set: 93 | return False 94 | return True 95 | 96 | def __getitem__(self, index: int): 97 | """ 98 | Args: 99 | index (int): Index 100 | 101 | Returns: 102 | tuple: (sample, target) where target is class_index of the target class. 103 | """ 104 | path, target = self.samples[index] 105 | sample = self.loader(path) 106 | 107 | if self.only_features: 108 | sample = self.features[index] 109 | else: 110 | if self.no_aug: 111 | if self.test_transform is not None: 112 | sample = self.test_transform(sample) 113 | else: 114 | if self.transform is not None: 115 | sample = self.transform(sample) 116 | 117 | return sample, target -------------------------------------------------------------------------------- /deep-al/pycls/datasets/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .features import load_features -------------------------------------------------------------------------------- /deep-al/pycls/datasets/utils/features.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | DATASET_FEATURES_DICT = { 3 | 'train': 4 | { 5 | 'CIFAR10':'../../scan/results/cifar-10/pretext/features_seed{seed}.npy', 6 | 'CIFAR100':'../../scan/results/cifar-100/pretext/features_seed{seed}.npy', 7 | 'TINYIMAGENET': '../../scan/results/tiny-imagenet/pretext/features_seed{seed}.npy', 8 | 'IMAGENET50': '../../dino/runs/trainfeat.pth', 9 | 'IMAGENET100': '../../dino/runs/trainfeat.pth', 10 | 'IMAGENET200': '../../dino/runs/trainfeat.pth', 11 | }, 12 | 'test': 13 | { 14 | 'CIFAR10': '../../scan/results/cifar-10/pretext/test_features_seed{seed}.npy', 15 | 'CIFAR100': '../../scan/results/cifar-100/pretext/test_features_seed{seed}.npy', 16 | 'TINYIMAGENET': '../../scan/results/tiny-imagenet/pretext/test_features_seed{seed}.npy', 17 | 'IMAGENET50': '../../dino/runs/testfeat.pth', 18 | 'IMAGENET100': '../../dino/runs/testfeat.pth', 19 | 'IMAGENET200': '../../dino/runs/testfeat.pth', 20 | } 21 | } 22 | import torch  # needed below when loading the .pth (DINO) feature files 23 | def load_features(ds_name, seed=1, train=True, normalized=True): 24 | """Load pretrained features for a dataset.""" 25 | split = "train" if train else "test" 26 | fname = DATASET_FEATURES_DICT[split][ds_name].format(seed=seed) 27 | if fname.endswith('.npy'): 28 | features = np.load(fname) 29 | elif fname.endswith('.pth'): 30 | features = torch.load(fname) 31 | else: 32 | raise
Exception("Unsupported filetype") 33 | if normalized: 34 | features = features / np.linalg.norm(features, axis=1, keepdims=True) 35 | return features -------------------------------------------------------------------------------- /deep-al/pycls/datasets/utils/gaussian_blur.py: -------------------------------------------------------------------------------- 1 | # Owner of this file: Thalles Silva 2 | # Source: https://github.com/sthalles/PyTorch-BYOL 3 | import torch 4 | from torchvision import transforms 5 | import torch.nn as nn 6 | import numpy as np 7 | 8 | 9 | class GaussianBlur(object): 10 | """blur a single image on CPU""" 11 | 12 | def __init__(self, kernel_size): 13 | radias = kernel_size // 2 14 | kernel_size = radias * 2 + 1 15 | self.blur_h = nn.Conv2d(3, 3, kernel_size=(kernel_size, 1), 16 | stride=1, padding=0, bias=False, groups=3) 17 | self.blur_v = nn.Conv2d(3, 3, kernel_size=(1, kernel_size), 18 | stride=1, padding=0, bias=False, groups=3) 19 | self.k = kernel_size 20 | self.r = radias 21 | 22 | self.blur = nn.Sequential( 23 | nn.ReflectionPad2d(radias), 24 | self.blur_h, 25 | self.blur_v 26 | ) 27 | 28 | self.pil_to_tensor = transforms.ToTensor() 29 | self.tensor_to_pil = transforms.ToPILImage() 30 | 31 | def __call__(self, img): 32 | img = self.pil_to_tensor(img).unsqueeze(0) 33 | 34 | sigma = np.random.uniform(0.1, 2.0) 35 | x = np.arange(-self.r, self.r + 1) 36 | x = np.exp(-np.power(x, 2) / (2 * sigma * sigma)) 37 | x = x / x.sum() 38 | x = torch.from_numpy(x).view(1, -1).repeat(3, 1) 39 | 40 | self.blur_h.weight.data.copy_(x.view(3, 1, self.k, 1)) 41 | self.blur_v.weight.data.copy_(x.view(3, 1, 1, self.k)) 42 | 43 | with torch.no_grad(): 44 | img = self.blur(img) 45 | img = img.squeeze() 46 | 47 | img = self.tensor_to_pil(img) 48 | 49 | return img -------------------------------------------------------------------------------- /deep-al/pycls/datasets/utils/helpers.py: -------------------------------------------------------------------------------- 1 | import csv 2 | 3 | def load_imageset(path, set_name): 4 | """ 5 | Returns the image set `set_name` present at `path` as a list. 6 | Keyword arguments: 7 | path -- path to data folder 8 | set_name -- image set name - labeled or unlabeled. 9 | """ 10 | reader = csv.reader(open(os.path.join(path, set_name+'.csv'), 'rt')) 11 | reader = [r[0] for r in reader] 12 | return reader -------------------------------------------------------------------------------- /deep-al/pycls/models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """Expose model constructors.""" 9 | 10 | -------------------------------------------------------------------------------- /deep-al/pycls/models/alexnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 4 | from typing import Any 5 | 6 | 7 | __all__ = ['AlexNet', 'alexnet'] 8 | 9 | 10 | model_urls = { 11 | 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth', 12 | } 13 | 14 | 15 | class AlexNet(nn.Module): 16 | ''' 17 | AlexNet modified (features) for CIFAR10. Source: https://github.com/icpm/pytorch-cifar10/blob/master/models/AlexNet.py. 
18 | ''' 19 | def __init__(self, num_classes: int = 1000, use_dropout=False) -> None: 20 | super(AlexNet, self).__init__() 21 | self.use_dropout = use_dropout 22 | self.features = nn.Sequential( 23 | nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), 24 | nn.ReLU(inplace=True), 25 | nn.MaxPool2d(kernel_size=2), 26 | nn.Conv2d(64, 192, kernel_size=3, padding=1), 27 | nn.ReLU(inplace=True), 28 | nn.MaxPool2d(kernel_size=2), 29 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 30 | nn.ReLU(inplace=True), 31 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 32 | nn.ReLU(inplace=True), 33 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 34 | nn.ReLU(inplace=True), 35 | nn.MaxPool2d(kernel_size=2), 36 | ) 37 | # self.avgpool = nn.AdaptiveAvgPool2d((6, 6)) 38 | self.fc_block = nn.Sequential( 39 | nn.Linear(256 * 2 * 2, 4096, bias=False), 40 | nn.BatchNorm1d(4096), 41 | nn.ReLU(inplace=True), 42 | nn.Linear(4096, 4096, bias=False), 43 | nn.BatchNorm1d(4096), 44 | nn.ReLU(inplace=True), 45 | ) 46 | self.classifier = nn.Sequential( 47 | nn.Linear(4096, num_classes), 48 | ) 49 | self.penultimate_active = False 50 | self.drop = nn.Dropout(p=0.5) 51 | 52 | def forward(self, x: torch.Tensor) -> torch.Tensor: 53 | x = self.features(x) 54 | # x = self.avgpool(x) 55 | z = torch.flatten(x, 1) 56 | if self.use_dropout: 57 | z = self.drop(z)  # apply dropout to the flattened features that feed fc_block 58 | z = self.fc_block(z) 59 | x = self.classifier(z) 60 | if self.penultimate_active: 61 | return z, x 62 | return x 63 | 64 | 65 | def alexnet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> AlexNet: 66 | r"""AlexNet model architecture from the 67 | `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper. 68 | Args: 69 | pretrained (bool): If True, returns a model pre-trained on ImageNet 70 | progress (bool): If True, displays a progress bar of the download to stderr 71 | """ 72 | model = AlexNet(**kwargs) 73 | if pretrained: 74 | state_dict = load_state_dict_from_url(model_urls['alexnet'], 75 | progress=progress) 76 | model.load_state_dict(state_dict) 77 | return model -------------------------------------------------------------------------------- /deep-al/pycls/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/deep-al/pycls/utils/__init__.py -------------------------------------------------------------------------------- /deep-al/pycls/utils/benchmark.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree.
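# --- Editor's sketch (illustrative, not part of the repository): extracting
# penultimate embeddings from the AlexNet above. Setting `penultimate_active`
# makes forward() return (embedding, logits) instead of logits alone.
import torch
from pycls.models.alexnet import AlexNet

net = AlexNet(num_classes=10).eval()
net.penultimate_active = True
x = torch.randn(4, 3, 32, 32)       # CIFAR-sized batch
with torch.no_grad():
    z, logits = net(x)              # z: (4, 4096), logits: (4, 10)
# --- end sketch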
7 | 8 | """Benchmarking functions.""" 9 | 10 | import pycls.core.logging as logging 11 | import pycls.core.net as net 12 | import pycls.datasets.loader as loader 13 | import torch 14 | import torch.cuda.amp as amp 15 | from pycls.core.config import cfg 16 | from pycls.core.timer import Timer 17 | 18 | 19 | logger = logging.get_logger(__name__) 20 | 21 | 22 | @torch.no_grad() 23 | def compute_time_eval(model): 24 | """Computes precise model forward test time using dummy data.""" 25 | # Use eval mode 26 | model.eval() 27 | # Generate a dummy mini-batch and copy data to GPU 28 | im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TEST.BATCH_SIZE / cfg.NUM_GPUS) 29 | inputs = torch.zeros(batch_size, 3, im_size, im_size).cuda(non_blocking=False) 30 | # Compute precise forward pass time 31 | timer = Timer() 32 | total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER 33 | for cur_iter in range(total_iter): 34 | # Reset the timers after the warmup phase 35 | if cur_iter == cfg.PREC_TIME.WARMUP_ITER: 36 | timer.reset() 37 | # Forward 38 | timer.tic() 39 | model(inputs) 40 | torch.cuda.synchronize() 41 | timer.toc() 42 | return timer.average_time 43 | 44 | 45 | def compute_time_train(model, loss_fun): 46 | """Computes precise model forward + backward time using dummy data.""" 47 | # Use train mode 48 | model.train() 49 | # Generate a dummy mini-batch and copy data to GPU 50 | im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS) 51 | inputs = torch.rand(batch_size, 3, im_size, im_size).cuda(non_blocking=False) 52 | labels = torch.zeros(batch_size, dtype=torch.int64).cuda(non_blocking=False) 53 | labels_one_hot = net.smooth_one_hot_labels(labels) 54 | # Cache BatchNorm2D running stats 55 | bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)] 56 | bn_stats = [[bn.running_mean.clone(), bn.running_var.clone()] for bn in bns] 57 | # Create a GradScaler for mixed precision training 58 | scaler = amp.GradScaler(enabled=cfg.TRAIN.MIXED_PRECISION) 59 | # Compute precise forward backward pass time 60 | fw_timer, bw_timer = Timer(), Timer() 61 | total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER 62 | for cur_iter in range(total_iter): 63 | # Reset the timers after the warmup phase 64 | if cur_iter == cfg.PREC_TIME.WARMUP_ITER: 65 | fw_timer.reset() 66 | bw_timer.reset() 67 | # Forward 68 | fw_timer.tic() 69 | with amp.autocast(enabled=cfg.TRAIN.MIXED_PRECISION): 70 | preds = model(inputs) 71 | loss = loss_fun(preds, labels_one_hot) 72 | torch.cuda.synchronize() 73 | fw_timer.toc() 74 | # Backward 75 | bw_timer.tic() 76 | scaler.scale(loss).backward() 77 | torch.cuda.synchronize() 78 | bw_timer.toc() 79 | # Restore BatchNorm2D running stats 80 | for bn, (mean, var) in zip(bns, bn_stats): 81 | bn.running_mean, bn.running_var = mean, var 82 | return fw_timer.average_time, bw_timer.average_time 83 | 84 | 85 | def compute_time_loader(data_loader): 86 | """Computes loader time.""" 87 | timer = Timer() 88 | loader.shuffle(data_loader, 0) 89 | data_loader_iterator = iter(data_loader) 90 | total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER 91 | total_iter = min(total_iter, len(data_loader)) 92 | for cur_iter in range(total_iter): 93 | if cur_iter == cfg.PREC_TIME.WARMUP_ITER: 94 | timer.reset() 95 | timer.tic() 96 | next(data_loader_iterator) 97 | timer.toc() 98 | return timer.average_time 99 | 100 | 101 | def compute_time_model(model, loss_fun): 102 | """Times model.""" 103 | logger.info("Computing model timings only...") 104 | # Compute 
timings 105 | test_fw_time = compute_time_eval(model) 106 | train_fw_time, train_bw_time = compute_time_train(model, loss_fun) 107 | train_fw_bw_time = train_fw_time + train_bw_time 108 | # Output iter timing 109 | iter_times = { 110 | "test_fw_time": test_fw_time, 111 | "train_fw_time": train_fw_time, 112 | "train_bw_time": train_bw_time, 113 | "train_fw_bw_time": train_fw_bw_time, 114 | } 115 | logger.info(logging.dump_log_data(iter_times, "iter_times")) 116 | 117 | 118 | def compute_time_full(model, loss_fun, train_loader, test_loader): 119 | """Times model and data loader.""" 120 | logger.info("Computing model and loader timings...") 121 | # Compute timings 122 | test_fw_time = compute_time_eval(model) 123 | train_fw_time, train_bw_time = compute_time_train(model, loss_fun) 124 | train_fw_bw_time = train_fw_time + train_bw_time 125 | train_loader_time = compute_time_loader(train_loader) 126 | # Output iter timing 127 | iter_times = { 128 | "test_fw_time": test_fw_time, 129 | "train_fw_time": train_fw_time, 130 | "train_bw_time": train_bw_time, 131 | "train_fw_bw_time": train_fw_bw_time, 132 | "train_loader_time": train_loader_time, 133 | } 134 | logger.info(logging.dump_log_data(iter_times, "iter_times")) 135 | # Output epoch timing 136 | epoch_times = { 137 | "test_fw_time": test_fw_time * len(test_loader), 138 | "train_fw_time": train_fw_time * len(train_loader), 139 | "train_bw_time": train_bw_time * len(train_loader), 140 | "train_fw_bw_time": train_fw_bw_time * len(train_loader), 141 | "train_loader_time": train_loader_time * len(train_loader), 142 | } 143 | logger.info(logging.dump_log_data(epoch_times, "epoch_times")) 144 | # Compute data loader overhead (assuming DATA_LOADER.NUM_WORKERS>1) 145 | overhead = max(0, train_loader_time - train_fw_bw_time) / train_fw_bw_time 146 | logger.info("Overhead of data loader is {:.2f}%".format(overhead * 100)) 147 | -------------------------------------------------------------------------------- /deep-al/pycls/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
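# --- Editor's sketch (illustrative, not part of the repository): the
# warm-up-then-reset timing pattern that benchmark.py above applies to forward
# and backward passes and to data loading, shown standalone so it runs without
# a GPU or a pycls config.
import time

WARMUP_ITER, NUM_ITER = 3, 10
total = 0.0
for cur_iter in range(WARMUP_ITER + NUM_ITER):
    if cur_iter == WARMUP_ITER:
        total = 0.0                       # discard the warm-up iterations
    start = time.perf_counter()
    sum(i * i for i in range(10000))      # stand-in for model(inputs)
    total += time.perf_counter() - start
avg_time = total / NUM_ITER
# --- end sketch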
7 | 8 | """Functions that handle saving and loading of checkpoints.""" 9 | 10 | import os 11 | 12 | # import pycls.core.distributed as dist 13 | import torch 14 | from pycls.core.net import unwrap_model 15 | 16 | 17 | # Common prefix for checkpoint file names 18 | _NAME_PREFIX = "model_epoch_" 19 | 20 | # Checkpoints directory name 21 | _DIR_NAME = "checkpoints" 22 | 23 | 24 | def get_checkpoint_dir(episode_dir): 25 | """Retrieves the location for storing checkpoints.""" 26 | return os.path.join(episode_dir, _DIR_NAME) 27 | 28 | 29 | def get_checkpoint(epoch, episode_dir): 30 | """Retrieves the path to a checkpoint file.""" 31 | name = "{}{:04d}.pyth".format(_NAME_PREFIX, epoch) 32 | # return os.path.join(get_checkpoint_dir(), name) 33 | return os.path.join(episode_dir, name) 34 | 35 | 36 | def get_checkpoint_best(episode_dir): 37 | """Retrieves the path to the best checkpoint file.""" 38 | return os.path.join(episode_dir, "model.pyth") 39 | 40 | 41 | def get_last_checkpoint(episode_dir): 42 | """Retrieves the most recent checkpoint (highest epoch number).""" 43 | checkpoint_dir = get_checkpoint_dir(episode_dir) 44 | checkpoints = [f for f in os.listdir(checkpoint_dir) if _NAME_PREFIX in f] 45 | last_checkpoint_name = sorted(checkpoints)[-1] 46 | return os.path.join(checkpoint_dir, last_checkpoint_name) 47 | 48 | 49 | def has_checkpoint(episode_dir): 50 | """Determines if there are checkpoints available.""" 51 | checkpoint_dir = get_checkpoint_dir(episode_dir) 52 | if not os.path.exists(checkpoint_dir): 53 | return False 54 | return any(_NAME_PREFIX in f for f in os.listdir(checkpoint_dir)) 55 | 56 | 57 | def save_checkpoint(info, model_state, optimizer_state, epoch, cfg): 58 | 59 | """Saves a checkpoint.""" 60 | # Save checkpoints only from the master process 61 | # if not dist.is_master_proc(): 62 | # return 63 | # Ensure that the checkpoint dir exists 64 | os.makedirs(cfg.EPISODE_DIR, exist_ok=True) 65 | 66 | # Record the state 67 | checkpoint = { 68 | "epoch": epoch, 69 | "model_state": model_state, 70 | "optimizer_state": optimizer_state, 71 | "cfg": cfg.dump(), 72 | } 73 | global _NAME_PREFIX 74 | _NAME_PREFIX = info + '_' + _NAME_PREFIX 75 | 76 | # Write the checkpoint 77 | checkpoint_file = get_checkpoint(epoch, cfg.EPISODE_DIR) 78 | torch.save(checkpoint, checkpoint_file) 79 | # print("Model checkpoint saved at path: {}".format(checkpoint_file)) 80 | 81 | _NAME_PREFIX = 'model_epoch_' 82 | return checkpoint_file 83 | 84 | 85 | def load_checkpoint(checkpoint_file, model, optimizer=None): 86 | """Loads the checkpoint from the given file.""" 87 | err_str = "Checkpoint '{}' not found" 88 | assert os.path.exists(checkpoint_file), err_str.format(checkpoint_file) 89 | checkpoint = torch.load(checkpoint_file, map_location="cpu") 90 | unwrap_model(model).load_state_dict(checkpoint["model_state"]) 91 | optimizer.load_state_dict(checkpoint["optimizer_state"]) if optimizer else () 92 | return model 93 | 94 | 95 | def delete_checkpoints(checkpoint_dir=None, keep="all"): 96 | """Deletes unneeded checkpoints, keep can be "all", "last", or "none".""" 97 | assert keep in ["all", "last", "none"], "Invalid keep setting: {}".format(keep) 98 | checkpoint_dir = checkpoint_dir if checkpoint_dir else get_checkpoint_dir() 99 | if keep == "all" or not os.path.exists(checkpoint_dir): 100 | return 0 101 | checkpoints = [f for f in os.listdir(checkpoint_dir) if _NAME_PREFIX in f] 102 | checkpoints = sorted(checkpoints)[:-1] if keep == "last" else checkpoints 103 | [os.remove(os.path.join(checkpoint_dir, 
checkpoint)) for checkpoint in checkpoints] 104 | return len(checkpoints) 105 | -------------------------------------------------------------------------------- /deep-al/pycls/utils/distributed.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """Distributed helpers.""" 9 | 10 | import multiprocessing 11 | import os 12 | import random 13 | import signal 14 | import threading 15 | import traceback 16 | 17 | import torch 18 | from pycls.core.config import cfg 19 | 20 | 21 | # Make work w recent PyTorch versions (https://github.com/pytorch/pytorch/issues/37377) 22 | os.environ["MKL_THREADING_LAYER"] = "GNU" 23 | 24 | 25 | def is_master_proc(): 26 | """Determines if the current process is the master process. 27 | 28 | Master process is responsible for logging, writing and loading checkpoints. In 29 | the multi GPU setting, we assign the master role to the rank 0 process. When 30 | training using a single GPU, there is a single process which is considered master. 31 | """ 32 | return cfg.NUM_GPUS == 1 or torch.distributed.get_rank() == 0 33 | 34 | 35 | def init_process_group(proc_rank, world_size, port): 36 | """Initializes the default process group.""" 37 | # Set the GPU to use 38 | torch.cuda.set_device(proc_rank) 39 | # Initialize the process group 40 | torch.distributed.init_process_group( 41 | backend=cfg.DIST_BACKEND, 42 | init_method="tcp://{}:{}".format(cfg.HOST, port), 43 | world_size=world_size, 44 | rank=proc_rank, 45 | ) 46 | 47 | 48 | def destroy_process_group(): 49 | """Destroys the default process group.""" 50 | torch.distributed.destroy_process_group() 51 | 52 | 53 | def scaled_all_reduce(tensors): 54 | """Performs the scaled all_reduce operation on the provided tensors. 55 | 56 | The input tensors are modified in-place. Currently supports only the sum 57 | reduction operator. The reduced values are scaled by the inverse size of the 58 | process group (equivalent to cfg.NUM_GPUS). 59 | """ 60 | # There is no need for reduction in the single-proc case 61 | if cfg.NUM_GPUS == 1: 62 | return tensors 63 | # Queue the reductions 64 | reductions = [] 65 | for tensor in tensors: 66 | reduction = torch.distributed.all_reduce(tensor, async_op=True) 67 | reductions.append(reduction) 68 | # Wait for reductions to finish 69 | for reduction in reductions: 70 | reduction.wait() 71 | # Scale the results 72 | for tensor in tensors: 73 | tensor.mul_(1.0 / cfg.NUM_GPUS) 74 | return tensors 75 | 76 | 77 | class ChildException(Exception): 78 | """Wraps an exception from a child process.""" 79 | 80 | def __init__(self, child_trace): 81 | super(ChildException, self).__init__(child_trace) 82 | 83 | 84 | class ErrorHandler(object): 85 | """Multiprocessing error handler (based on fairseq's). 86 | 87 | Listens for errors in child processes and propagates the tracebacks to the parent. 
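# --- Editor's sketch (illustrative, not part of the repository): round-tripping
# a model through the checkpoint helpers above. Assumes `cfg` defines
# EPISODE_DIR (set per AL episode elsewhere in deep-al); the directory and the
# `info` prefix below are hypothetical.
import torch
import torch.nn as nn
from pycls.core.config import cfg
import pycls.utils.checkpoint as cu

model = nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
cfg.EPISODE_DIR = './episode_0'   # hypothetical episode output directory
path = cu.save_checkpoint('best', model.state_dict(), opt.state_dict(), 1, cfg)
model = cu.load_checkpoint(path, model, opt)  # restores model + optimizer state
# --- end sketch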
88 | """ 89 | 90 | def __init__(self, error_queue): 91 | # Shared error queue 92 | self.error_queue = error_queue 93 | # Children processes sharing the error queue 94 | self.children_pids = [] 95 | # Start a thread listening to errors 96 | self.error_listener = threading.Thread(target=self.listen, daemon=True) 97 | self.error_listener.start() 98 | # Register the signal handler 99 | signal.signal(signal.SIGUSR1, self.signal_handler) 100 | 101 | def add_child(self, pid): 102 | """Registers a child process.""" 103 | self.children_pids.append(pid) 104 | 105 | def listen(self): 106 | """Listens for errors in the error queue.""" 107 | # Wait until there is an error in the queue 108 | child_trace = self.error_queue.get() 109 | # Put the error back for the signal handler 110 | self.error_queue.put(child_trace) 111 | # Invoke the signal handler 112 | os.kill(os.getpid(), signal.SIGUSR1) 113 | 114 | def signal_handler(self, _sig_num, _stack_frame): 115 | """Signal handler.""" 116 | # Kill children processes 117 | for pid in self.children_pids: 118 | os.kill(pid, signal.SIGINT) 119 | # Propagate the error from the child process 120 | raise ChildException(self.error_queue.get()) 121 | 122 | 123 | def run(proc_rank, world_size, port, error_queue, fun, fun_args, fun_kwargs): 124 | """Runs a function from a child process.""" 125 | try: 126 | # Initialize the process group 127 | init_process_group(proc_rank, world_size, port) 128 | # Run the function 129 | fun(*fun_args, **fun_kwargs) 130 | except KeyboardInterrupt: 131 | # Killed by the parent process 132 | pass 133 | except Exception: 134 | # Propagate exception to the parent process 135 | error_queue.put(traceback.format_exc()) 136 | finally: 137 | # Destroy the process group 138 | destroy_process_group() 139 | 140 | 141 | def multi_proc_run(num_proc, fun, fun_args=(), fun_kwargs=None): 142 | """Runs a function in a multi-proc setting (unless num_proc == 1).""" 143 | # There is no need for multi-proc in the single-proc case 144 | fun_kwargs = fun_kwargs if fun_kwargs else {} 145 | if num_proc == 1: 146 | fun(*fun_args, **fun_kwargs) 147 | return 148 | # Handle errors from training subprocesses 149 | error_queue = multiprocessing.SimpleQueue() 150 | error_handler = ErrorHandler(error_queue) 151 | # Get a random port to use (without using global random number generator) 152 | port = random.Random().randint(cfg.PORT_RANGE[0], cfg.PORT_RANGE[1]) 153 | # Run each training subprocess 154 | ps = [] 155 | for i in range(num_proc): 156 | p_i = multiprocessing.Process( 157 | target=run, args=(i, num_proc, port, error_queue, fun, fun_args, fun_kwargs) 158 | ) 159 | ps.append(p_i) 160 | p_i.start() 161 | error_handler.add_child(p_i.pid) 162 | # Wait for each subprocess to finish 163 | for p in ps: 164 | p.join() 165 | -------------------------------------------------------------------------------- /deep-al/pycls/utils/io.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | """IO utilities (adapted from Detectron)""" 9 | 10 | import logging 11 | import os 12 | import re 13 | import sys 14 | from urllib import request as urlrequest 15 | 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | _PYCLS_BASE_URL = "https://dl.fbaipublicfiles.com/pycls" 20 | 21 | 22 | def cache_url(url_or_file, cache_dir, base_url=_PYCLS_BASE_URL): 23 | """Download the file specified by the URL to the cache_dir and return the path to 24 | the cached file. If the argument is not a URL, simply return it as is. 25 | """ 26 | is_url = re.match(r"^(?:http)s?://", url_or_file, re.IGNORECASE) is not None 27 | if not is_url: 28 | return url_or_file 29 | url = url_or_file 30 | assert url.startswith(base_url), "url must start with: {}".format(base_url) 31 | cache_file_path = url.replace(base_url, cache_dir) 32 | if os.path.exists(cache_file_path): 33 | return cache_file_path 34 | cache_file_dir = os.path.dirname(cache_file_path) 35 | if not os.path.exists(cache_file_dir): 36 | os.makedirs(cache_file_dir) 37 | logger.info("Downloading remote file {} to {}".format(url, cache_file_path)) 38 | download_url(url, cache_file_path) 39 | return cache_file_path 40 | 41 | 42 | def _progress_bar(count, total): 43 | """Report download progress. Credit: 44 | https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console/27871113 45 | """ 46 | bar_len = 60 47 | filled_len = int(round(bar_len * count / float(total))) 48 | percents = round(100.0 * count / float(total), 1) 49 | bar = "=" * filled_len + "-" * (bar_len - filled_len) 50 | sys.stdout.write( 51 | " [{}] {}% of {:.1f}MB file \r".format(bar, percents, total / 1024 / 1024) 52 | ) 53 | sys.stdout.flush() 54 | if count >= total: 55 | sys.stdout.write("\n") 56 | 57 | 58 | def download_url(url, dst_file_path, chunk_size=8192, progress_hook=_progress_bar): 59 | """Download url and write it to dst_file_path. Credit: 60 | https://stackoverflow.com/questions/2028517/python-urllib2-progress-hook 61 | """ 62 | req = urlrequest.Request(url) 63 | response = urlrequest.urlopen(req) 64 | total_size = response.info().get("Content-Length").strip() 65 | total_size = int(total_size) 66 | bytes_so_far = 0 67 | with open(dst_file_path, "wb") as f: 68 | while 1: 69 | chunk = response.read(chunk_size) 70 | bytes_so_far += len(chunk) 71 | if not chunk: 72 | break 73 | if progress_hook: 74 | progress_hook(bytes_so_far, total_size) 75 | f.write(chunk) 76 | return bytes_so_far 77 | -------------------------------------------------------------------------------- /deep-al/pycls/utils/logging.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | """Logging.""" 9 | 10 | import builtins 11 | import decimal 12 | import logging 13 | import os 14 | import simplejson 15 | import sys 16 | 17 | # import pycls.utils.distributed as du 18 | 19 | # Show filename and line number in logs 20 | _FORMAT = '[%(asctime)s %(filename)s: %(lineno)3d]: %(message)s' 21 | 22 | # Log file name (for cfg.LOG_DEST = 'file') 23 | _LOG_FILE = 'stdout.log' 24 | 25 | # Printed json stats lines will be tagged w/ this 26 | _TAG = 'json_stats: ' 27 | 28 | 29 | def _suppress_print(): 30 | """Suppresses printing from the current process.""" 31 | def ignore(*_objects, _sep=' ', _end='\n', _file=sys.stdout, _flush=False): 32 | pass 33 | builtins.print = ignore 34 | 35 | 36 | def setup_logging(cfg): 37 | """Sets up the logging.""" 38 | # Enable logging only for the master process 39 | # if du.is_master_proc(): 40 | if True: 41 | # Clear the root logger to prevent any existing logging config 42 | # (e.g. set by another module) from messing with our setup 43 | logging.root.handlers = [] 44 | # Construct logging configuration 45 | logging_config = { 46 | 'level': logging.INFO, 47 | 'format': _FORMAT, 48 | 'datefmt': '%Y-%m-%d %H:%M:%S' 49 | } 50 | # Log either to stdout or to a file 51 | if cfg.LOG_DEST == 'stdout': 52 | logging_config['stream'] = sys.stdout 53 | else: 54 | logging_config['filename'] = os.path.join(cfg.EXP_DIR, _LOG_FILE) 55 | # Configure logging 56 | logging.basicConfig(**logging_config) 57 | else: 58 | _suppress_print() 59 | 60 | 61 | def get_logger(name): 62 | """Retrieves the logger.""" 63 | return logging.getLogger(name) 64 | 65 | 66 | def log_json_stats(stats): 67 | """Logs json stats.""" 68 | # Decimal + string workaround for having fixed len float vals in logs 69 | stats = { 70 | k: decimal.Decimal('{:.12f}'.format(v)) if isinstance(v, float) else v 71 | for k, v in stats.items() 72 | } 73 | json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True) 74 | logger = get_logger(__name__) 75 | logger.info('{:s}{:s}'.format(_TAG, json_stats)) 76 | 77 | 78 | def load_json_stats(log_file): 79 | """Loads json_stats from a single log file.""" 80 | with open(log_file, 'r') as f: 81 | lines = f.readlines() 82 | json_lines = [l[l.find(_TAG) + len(_TAG):] for l in lines if _TAG in l] 83 | json_stats = [simplejson.loads(l) for l in json_lines] 84 | return json_stats 85 | 86 | 87 | def parse_json_stats(log, row_type, key): 88 | """Extract values corresponding to row_type/key out of log.""" 89 | vals = [row[key] for row in log if row['_type'] == row_type and key in row] 90 | if key == 'iter' or key == 'epoch': 91 | vals = [int(val.split('/')[0]) for val in vals] 92 | return vals 93 | 94 | 95 | def get_log_files(log_dir, name_filter=''): 96 | """Get all log files in directory containing subdirs of trained models.""" 97 | names = [n for n in sorted(os.listdir(log_dir)) if name_filter in n] 98 | files = [os.path.join(log_dir, n, _LOG_FILE) for n in names] 99 | f_n_ps = [(f, n) for (f, n) in zip(files, names) if os.path.exists(f)] 100 | files, names = zip(*f_n_ps) 101 | return files, names 102 | -------------------------------------------------------------------------------- /deep-al/pycls/utils/metrics.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | """Functions for computing metrics.""" 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | 14 | from pycls.core.config import cfg 15 | 16 | # Number of bytes in a megabyte 17 | _B_IN_MB = 1024 * 1024 18 | 19 | 20 | def topks_correct(preds, labels, ks): 21 | """Computes the number of top-k correct predictions for each k.""" 22 | assert preds.size(0) == labels.size(0), \ 23 | 'Batch dim of predictions and labels must match' 24 | # Find the top max_k predictions for each sample 25 | _top_max_k_vals, top_max_k_inds = torch.topk( 26 | preds, max(ks), dim=1, largest=True, sorted=True 27 | ) 28 | # (batch_size, max_k) -> (max_k, batch_size) 29 | top_max_k_inds = top_max_k_inds.t() 30 | # (batch_size, ) -> (max_k, batch_size) 31 | rep_max_k_labels = labels.view(1, -1).expand_as(top_max_k_inds) 32 | # (i, j) = 1 if top i-th prediction for the j-th sample is correct 33 | top_max_k_correct = top_max_k_inds.eq(rep_max_k_labels) 34 | # Compute the number of topk correct predictions for each k 35 | topks_correct = [ 36 | top_max_k_correct[:k, :].reshape(-1).float().sum() for k in ks 37 | ] 38 | return topks_correct 39 | 40 | 41 | def topk_errors(preds, labels, ks): 42 | """Computes the top-k error for each k.""" 43 | num_topks_correct = topks_correct(preds, labels, ks) 44 | return [(1.0 - x / preds.size(0)) * 100.0 for x in num_topks_correct] 45 | 46 | 47 | def topk_accuracies(preds, labels, ks): 48 | """Computes the top-k accuracy for each k.""" 49 | num_topks_correct = topks_correct(preds, labels, ks) 50 | return [(x / preds.size(0)) * 100.0 for x in num_topks_correct] 51 | 52 | 53 | def params_count(model): 54 | """Computes the number of parameters.""" 55 | return np.sum([p.numel() for p in model.parameters()]).item() 56 | 57 | 58 | def flops_count(model): 59 | """Computes the number of flops statically.""" 60 | h, w = cfg.TRAIN.IM_SIZE, cfg.TRAIN.IM_SIZE 61 | count = 0 62 | for n, m in model.named_modules(): 63 | if isinstance(m, nn.Conv2d): 64 | if 'se.' in n: 65 | count += m.in_channels * m.out_channels + m.bias.numel() 66 | continue 67 | h_out = (h + 2 * m.padding[0] - m.kernel_size[0]) // m.stride[0] + 1 68 | w_out = (w + 2 * m.padding[1] - m.kernel_size[1]) // m.stride[1] + 1 69 | count += np.prod([ 70 | m.weight.numel(), 71 | h_out, w_out 72 | ]) 73 | if '.proj' not in n: 74 | h, w = h_out, w_out 75 | elif isinstance(m, nn.MaxPool2d): 76 | h = (h + 2 * m.padding - m.kernel_size) // m.stride + 1 77 | w = (w + 2 * m.padding - m.kernel_size) // m.stride + 1 78 | elif isinstance(m, nn.Linear): 79 | count += m.in_features * m.out_features 80 | return count.item() 81 | 82 | 83 | def gpu_mem_usage(): 84 | """Computes the GPU memory usage for the current device (MB).""" 85 | mem_usage_bytes = torch.cuda.max_memory_allocated() 86 | return mem_usage_bytes / _B_IN_MB 87 | -------------------------------------------------------------------------------- /deep-al/pycls/utils/net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | """Functions for manipulating networks.""" 9 | 10 | import itertools 11 | import math 12 | import torch 13 | import torch.nn as nn 14 | 15 | from pycls.core.config import cfg 16 | 17 | 18 | def init_weights(m): 19 | """Performs ResNet-style weight initialization.""" 20 | if isinstance(m, nn.Conv2d): 21 | # Note that there is no bias due to BN 22 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 23 | m.weight.data.normal_(mean=0.0, std=math.sqrt(2.0 / fan_out)) 24 | elif isinstance(m, nn.BatchNorm2d): 25 | zero_init_gamma = ( 26 | hasattr(m, 'final_bn') and m.final_bn and 27 | cfg.BN.ZERO_INIT_FINAL_GAMMA 28 | ) 29 | m.weight.data.fill_(0.0 if zero_init_gamma else 1.0) 30 | m.bias.data.zero_() 31 | elif isinstance(m, nn.Linear): 32 | m.weight.data.normal_(mean=0.0, std=0.01) 33 | m.bias.data.zero_() 34 | 35 | 36 | @torch.no_grad() 37 | def compute_precise_bn_stats(model, loader): 38 | """Computes precise BN stats on training data.""" 39 | # Compute the number of minibatches to use 40 | num_iter = min(cfg.BN.NUM_SAMPLES_PRECISE // loader.batch_size, len(loader)) 41 | # Retrieve the BN layers 42 | bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)] 43 | # Initialize stats storage 44 | mus = [torch.zeros_like(bn.running_mean) for bn in bns] 45 | sqs = [torch.zeros_like(bn.running_var) for bn in bns] 46 | # Remember momentum values 47 | moms = [bn.momentum for bn in bns] 48 | # Disable momentum 49 | for bn in bns: 50 | bn.momentum = 1.0 51 | # Accumulate the stats across the data samples 52 | for inputs, _labels in itertools.islice(loader, num_iter): 53 | model(inputs.cuda()) 54 | # Accumulate the stats for each BN layer 55 | for i, bn in enumerate(bns): 56 | m, v = bn.running_mean, bn.running_var 57 | sqs[i] += (v + m * m) / num_iter 58 | mus[i] += m / num_iter 59 | # Set the stats and restore momentum values 60 | for i, bn in enumerate(bns): 61 | bn.running_var = sqs[i] - mus[i] * mus[i] 62 | bn.running_mean = mus[i] 63 | bn.momentum = moms[i] 64 | 65 | 66 | def reset_bn_stats(model): 67 | """Resets running BN stats.""" 68 | for m in model.modules(): 69 | if isinstance(m, torch.nn.BatchNorm2d): 70 | m.reset_running_stats() 71 | 72 | 73 | def drop_connect(x, drop_ratio): 74 | """Drop connect (adapted from DARTS).""" 75 | keep_ratio = 1.0 - drop_ratio 76 | mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device) 77 | mask.bernoulli_(keep_ratio) 78 | x.div_(keep_ratio) 79 | x.mul_(mask) 80 | return x 81 | 82 | 83 | def get_flat_weights(model): 84 | """Gets all model weights as a single flat vector.""" 85 | return torch.cat([p.data.view(-1, 1) for p in model.parameters()], 0) 86 | 87 | 88 | def set_flat_weights(model, flat_weights): 89 | """Sets all model weights from a single flat vector.""" 90 | k = 0 91 | for p in model.parameters(): 92 | n = p.data.numel() 93 | p.data.copy_(flat_weights[k:(k + n)].view_as(p.data)) 94 | k += n 95 | assert k == flat_weights.numel() 96 | -------------------------------------------------------------------------------- /deep-al/pycls/utils/plotting.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | """Plotting functions.""" 9 | 10 | import colorlover as cl 11 | import matplotlib.pyplot as plt 12 | import plotly.graph_objs as go 13 | import plotly.offline as offline 14 | import pycls.core.logging as logging 15 | 16 | 17 | def get_plot_colors(max_colors, color_format="pyplot"): 18 | """Generate colors for plotting.""" 19 | colors = cl.scales["11"]["qual"]["Paired"] 20 | if max_colors > len(colors): 21 | colors = cl.to_rgb(cl.interp(colors, max_colors)) 22 | if color_format == "pyplot": 23 | return [[j / 255.0 for j in c] for c in cl.to_numeric(colors)] 24 | return colors 25 | 26 | 27 | def prepare_plot_data(log_files, names, metric="top1_err"): 28 | """Load logs and extract data for plotting error curves.""" 29 | plot_data = [] 30 | for file, name in zip(log_files, names): 31 | d, data = {}, logging.sort_log_data(logging.load_log_data(file)) 32 | for phase in ["train", "test"]: 33 | x = data[phase + "_epoch"]["epoch_ind"] 34 | y = data[phase + "_epoch"][metric] 35 | d["x_" + phase], d["y_" + phase] = x, y 36 | d[phase + "_label"] = "[{:5.2f}] ".format(min(y) if y else 0) + name 37 | plot_data.append(d) 38 | assert len(plot_data) > 0, "No data to plot" 39 | return plot_data 40 | 41 | 42 | def plot_error_curves_plotly(log_files, names, filename, metric="top1_err"): 43 | """Plot error curves using plotly and save to file.""" 44 | plot_data = prepare_plot_data(log_files, names, metric) 45 | colors = get_plot_colors(len(plot_data), "plotly") 46 | # Prepare data for plots (3 sets, train duplicated w and w/o legend) 47 | data = [] 48 | for i, d in enumerate(plot_data): 49 | s = str(i) 50 | line_train = {"color": colors[i], "dash": "dashdot", "width": 1.5} 51 | line_test = {"color": colors[i], "dash": "solid", "width": 1.5} 52 | data.append( 53 | go.Scatter( 54 | x=d["x_train"], 55 | y=d["y_train"], 56 | mode="lines", 57 | name=d["train_label"], 58 | line=line_train, 59 | legendgroup=s, 60 | visible=True, 61 | showlegend=False, 62 | ) 63 | ) 64 | data.append( 65 | go.Scatter( 66 | x=d["x_test"], 67 | y=d["y_test"], 68 | mode="lines", 69 | name=d["test_label"], 70 | line=line_test, 71 | legendgroup=s, 72 | visible=True, 73 | showlegend=True, 74 | ) 75 | ) 76 | data.append( 77 | go.Scatter( 78 | x=d["x_train"], 79 | y=d["y_train"], 80 | mode="lines", 81 | name=d["train_label"], 82 | line=line_train, 83 | legendgroup=s, 84 | visible=False, 85 | showlegend=True, 86 | ) 87 | ) 88 | # Prepare layout w ability to toggle 'all', 'train', 'test' 89 | titlefont = {"size": 18, "color": "#7f7f7f"} 90 | vis = [[True, True, False], [False, False, True], [False, True, False]] 91 | buttons = zip(["all", "train", "test"], [[{"visible": v}] for v in vis]) 92 | buttons = [{"label": b, "args": v, "method": "update"} for b, v in buttons] 93 | layout = go.Layout( 94 | title=metric + " vs. epoch
[dash=train, solid=test]", 95 | xaxis={"title": "epoch", "titlefont": titlefont}, 96 | yaxis={"title": metric, "titlefont": titlefont}, 97 | showlegend=True, 98 | hoverlabel={"namelength": -1}, 99 | updatemenus=[ 100 | { 101 | "buttons": buttons, 102 | "direction": "down", 103 | "showactive": True, 104 | "x": 1.02, 105 | "xanchor": "left", 106 | "y": 1.08, 107 | "yanchor": "top", 108 | } 109 | ], 110 | ) 111 | # Create plotly plot 112 | offline.plot({"data": data, "layout": layout}, filename=filename) 113 | 114 | 115 | def plot_error_curves_pyplot(log_files, names, filename=None, metric="top1_err"): 116 | """Plot error curves using matplotlib.pyplot and save to file.""" 117 | plot_data = prepare_plot_data(log_files, names, metric) 118 | colors = get_plot_colors(len(names)) 119 | for ind, d in enumerate(plot_data): 120 | c, lbl = colors[ind], d["test_label"] 121 | plt.plot(d["x_train"], d["y_train"], "--", c=c, alpha=0.8) 122 | plt.plot(d["x_test"], d["y_test"], "-", c=c, alpha=0.8, label=lbl) 123 | plt.title(metric + " vs. epoch\n[dash=train, solid=test]", fontsize=14) 124 | plt.xlabel("epoch", fontsize=14) 125 | plt.ylabel(metric, fontsize=14) 126 | plt.grid(alpha=0.4) 127 | plt.legend() 128 | if filename: 129 | plt.savefig(filename) 130 | plt.clf() 131 | else: 132 | plt.show() 133 | -------------------------------------------------------------------------------- /deep-al/pycls/utils/timer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """Timer.""" 9 | 10 | import time 11 | 12 | 13 | class Timer(object): 14 | """A simple timer (adapted from Detectron).""" 15 | 16 | def __init__(self): 17 | self.total_time = None 18 | self.calls = None 19 | self.start_time = None 20 | self.diff = None 21 | self.average_time = None 22 | self.reset() 23 | 24 | def tic(self): 25 | # using time.time as time.clock does not normalize for multithreading 26 | self.start_time = time.time() 27 | 28 | def toc(self): 29 | self.diff = time.time() - self.start_time 30 | self.total_time += self.diff 31 | self.calls += 1 32 | self.average_time = self.total_time / self.calls 33 | 34 | def reset(self): 35 | self.total_time = 0.0 36 | self.calls = 0 37 | self.start_time = 0.0 38 | self.diff = 0.0 39 | self.average_time = 0.0 40 | -------------------------------------------------------------------------------- /deep-al/requirements.txt: -------------------------------------------------------------------------------- 1 | black==19.3b0 2 | flake8==3.8.4 3 | isort==4.3.21 4 | matplotlib==3.3.4 5 | numpy 6 | opencv-python==4.2.0.34 7 | torch==1.7.1 8 | torchvision==0.8.2 9 | parameterized 10 | setuptools 11 | simplejson 12 | yacs -------------------------------------------------------------------------------- /deep-al/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/deep-al/tools/__init__.py -------------------------------------------------------------------------------- /probcover_selection.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/probcover_selection.gif 
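Editor's note: a minimal, illustrative use of the Timer above (a sketch, not repository code); average_time accumulates over tic/toc pairs, and the measured value is machine-dependent.

import time
from pycls.utils.timer import Timer

timer = Timer()
for _ in range(3):
    timer.tic()
    time.sleep(0.01)          # stand-in for the operation being timed
    timer.toc()
print(timer.average_time)     # roughly 0.01 seconds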
-------------------------------------------------------------------------------- /probcover_semi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/probcover_semi.png -------------------------------------------------------------------------------- /results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/results.png -------------------------------------------------------------------------------- /scan/configs/env.yml: -------------------------------------------------------------------------------- 1 | root_dir: ./results/ 2 | -------------------------------------------------------------------------------- /scan/configs/pretext/moco_imagenet100.yml: -------------------------------------------------------------------------------- 1 | # Setup 2 | setup: moco # MoCo is used here 3 | 4 | # Model 5 | backbone: resnet50 6 | model_kwargs: 7 | head: mlp 8 | features_dim: 128 9 | 10 | # Dataset 11 | train_db_name: imagenet_100 12 | val_db_name: imagenet_100 13 | num_classes: 100 14 | temperature: 0.07 15 | 16 | # Batch size and workers 17 | batch_size: 256 18 | num_workers: 8 19 | 20 | # Transformations 21 | transformation_kwargs: 22 | crop_size: 224 23 | normalize: 24 | mean: [0.485, 0.456, 0.406] 25 | std: [0.229, 0.224, 0.225] 26 | -------------------------------------------------------------------------------- /scan/configs/pretext/moco_imagenet200.yml: -------------------------------------------------------------------------------- 1 | # Setup 2 | setup: moco # MoCo is used here 3 | 4 | # Model 5 | backbone: resnet50 6 | model_kwargs: 7 | head: mlp 8 | features_dim: 128 9 | 10 | # Dataset 11 | train_db_name: imagenet_200 12 | val_db_name: imagenet_200 13 | num_classes: 200 14 | temperature: 0.07 15 | 16 | # Batch size and workers 17 | batch_size: 256 18 | num_workers: 8 19 | 20 | # Transformations 21 | transformation_kwargs: 22 | crop_size: 224 23 | normalize: 24 | mean: [0.485, 0.456, 0.406] 25 | std: [0.229, 0.224, 0.225] 26 | -------------------------------------------------------------------------------- /scan/configs/pretext/moco_imagenet50.yml: -------------------------------------------------------------------------------- 1 | # Setup 2 | setup: moco # MoCo is used here 3 | 4 | # Model 5 | backbone: resnet50 6 | model_kwargs: 7 | head: mlp 8 | features_dim: 128 9 | 10 | # Dataset 11 | train_db_name: imagenet_50 12 | val_db_name: imagenet_50 13 | num_classes: 50 14 | temperature: 0.07 15 | 16 | # Batch size and workers 17 | batch_size: 256 18 | num_workers: 8 19 | 20 | # Transformations 21 | transformation_kwargs: 22 | crop_size: 224 23 | normalize: 24 | mean: [0.485, 0.456, 0.406] 25 | std: [0.229, 0.224, 0.225] 26 | -------------------------------------------------------------------------------- /scan/configs/pretext/simclr_cifar10.yml: -------------------------------------------------------------------------------- 1 | # Setup 2 | setup: simclr 3 | 4 | # Model 5 | backbone: resnet18 6 | model_kwargs: 7 | head: mlp 8 | features_dim: 128 9 | 10 | # Dataset 11 | train_db_name: cifar-10 12 | val_db_name: cifar-10 13 | num_classes: 10 14 | 15 | # Loss 16 | criterion: simclr 17 | criterion_kwargs: 18 | temperature: 0.1 19 | 20 | # Hyperparameters 21 | epochs: 500 22 | optimizer: sgd 23 | optimizer_kwargs: 24 | nesterov: False 25 | 
weight_decay: 0.0001 26 | momentum: 0.9 27 | lr: 0.4 28 | scheduler: cosine 29 | scheduler_kwargs: 30 | lr_decay_rate: 0.1 31 | batch_size: 512 32 | num_workers: 8 33 | 34 | # Transformations 35 | augmentation_strategy: simclr 36 | augmentation_kwargs: 37 | random_resized_crop: 38 | size: 32 39 | scale: [0.2, 1.0] 40 | color_jitter_random_apply: 41 | p: 0.8 42 | color_jitter: 43 | brightness: 0.4 44 | contrast: 0.4 45 | saturation: 0.4 46 | hue: 0.1 47 | random_grayscale: 48 | p: 0.2 49 | normalize: 50 | mean: [0.4914, 0.4822, 0.4465] 51 | std: [0.2023, 0.1994, 0.2010] 52 | 53 | transformation_kwargs: 54 | crop_size: 32 55 | normalize: 56 | mean: [0.4914, 0.4822, 0.4465] 57 | std: [0.2023, 0.1994, 0.2010] 58 | -------------------------------------------------------------------------------- /scan/configs/pretext/simclr_cifar100.yml: -------------------------------------------------------------------------------- 1 | # Setup 2 | setup: simclr 3 | 4 | # Model 5 | backbone: resnet18 6 | model_kwargs: 7 | head: mlp 8 | features_dim: 128 9 | 10 | # Dataset 11 | train_db_name: cifar-100 12 | val_db_name: cifar-100 13 | num_classes: 100 14 | 15 | # Loss 16 | criterion: simclr 17 | criterion_kwargs: 18 | temperature: 0.1 19 | 20 | # Hyperparameters 21 | epochs: 500 22 | optimizer: sgd 23 | optimizer_kwargs: 24 | nesterov: False 25 | weight_decay: 0.0001 26 | momentum: 0.9 27 | lr: 0.4 28 | scheduler: cosine 29 | scheduler_kwargs: 30 | lr_decay_rate: 0.1 31 | batch_size: 512 32 | num_workers: 8 33 | 34 | # Transformations 35 | augmentation_strategy: simclr 36 | augmentation_kwargs: 37 | random_resized_crop: 38 | size: 32 39 | scale: [0.2, 1.0] 40 | color_jitter_random_apply: 41 | p: 0.8 42 | color_jitter: 43 | brightness: 0.4 44 | contrast: 0.4 45 | saturation: 0.4 46 | hue: 0.1 47 | random_grayscale: 48 | p: 0.2 49 | normalize: 50 | mean: [0.5071, 0.4867, 0.4408] 51 | std: [0.2675, 0.2565, 0.2761] 52 | 53 | transformation_kwargs: 54 | crop_size: 32 55 | normalize: 56 | mean: [0.5071, 0.4867, 0.4408] 57 | std: [0.2675, 0.2565, 0.2761] 58 | -------------------------------------------------------------------------------- /scan/configs/pretext/simclr_stl10.yml: -------------------------------------------------------------------------------- 1 | # Setup 2 | setup: simclr 3 | 4 | # Model 5 | backbone: resnet18 6 | model_kwargs: 7 | head: mlp 8 | features_dim: 128 9 | 10 | # Dataset 11 | train_db_name: stl-10 12 | val_db_name: stl-10 13 | num_classes: 10 14 | 15 | # Loss 16 | criterion: simclr 17 | criterion_kwargs: 18 | temperature: 0.1 19 | 20 | # Hyperparameters 21 | epochs: 500 22 | optimizer: sgd 23 | optimizer_kwargs: 24 | nesterov: False 25 | weight_decay: 0.0001 26 | momentum: 0.9 27 | lr: 0.4 28 | scheduler: cosine 29 | scheduler_kwargs: 30 | lr_decay_rate: 0.1 31 | batch_size: 512 32 | num_workers: 8 33 | 34 | # Transformations 35 | augmentation_strategy: simclr 36 | augmentation_kwargs: 37 | random_resized_crop: 38 | size: 96 39 | scale: [0.2, 1.0] 40 | color_jitter_random_apply: 41 | p: 0.8 42 | color_jitter: 43 | brightness: 0.4 44 | contrast: 0.4 45 | saturation: 0.4 46 | hue: 0.1 47 | random_grayscale: 48 | p: 0.2 49 | normalize: 50 | mean: [0.485, 0.456, 0.406] 51 | std: [0.229, 0.224, 0.225] 52 | 53 | transformation_kwargs: 54 | crop_size: 96 55 | normalize: 56 | mean: [0.485, 0.456, 0.406] 57 | std: [0.229, 0.224, 0.225] 58 | -------------------------------------------------------------------------------- /scan/configs/pretext/simclr_tinyimagenet.yml: 
-------------------------------------------------------------------------------- 1 | # Setup 2 | setup: simclr 3 | 4 | # Model 5 | backbone: resnet18 6 | model_kwargs: 7 | head: mlp 8 | features_dim: 128 9 | 10 | # Dataset 11 | train_db_name: tiny-imagenet 12 | val_db_name: tiny-imagenet 13 | num_classes: 200 14 | 15 | # Loss 16 | criterion: simclr 17 | criterion_kwargs: 18 | temperature: 0.1 19 | 20 | # Hyperparameters 21 | epochs: 500 22 | optimizer: sgd 23 | optimizer_kwargs: 24 | nesterov: False 25 | weight_decay: 0.0001 26 | momentum: 0.9 27 | lr: 0.4 28 | scheduler: cosine 29 | scheduler_kwargs: 30 | lr_decay_rate: 0.1 31 | batch_size: 512 32 | num_workers: 8 33 | 34 | # Transformations 35 | augmentation_strategy: simclr 36 | augmentation_kwargs: 37 | random_resized_crop: 38 | size: 64 39 | scale: [0.2, 1.0] 40 | color_jitter_random_apply: 41 | p: 0.8 42 | color_jitter: 43 | brightness: 0.4 44 | contrast: 0.4 45 | saturation: 0.4 46 | hue: 0.1 47 | random_grayscale: 48 | p: 0.2 49 | normalize: 50 | mean: [0.5071, 0.4867, 0.4408] 51 | std: [0.2675, 0.2565, 0.2761] 52 | 53 | transformation_kwargs: 54 | crop_size: 64 55 | normalize: 56 | mean: [0.5071, 0.4867, 0.4408] 57 | std: [0.2675, 0.2565, 0.2761] 58 | -------------------------------------------------------------------------------- /scan/configs/scan/imagenet_eval.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: scan 3 | 4 | # Loss 5 | criterion: scan 6 | criterion_kwargs: 7 | entropy_weight: 5.0 8 | 9 | # Model 10 | backbone: resnet50 11 | num_heads: 10 # Use multiple heads 12 | 13 | # Dataset 14 | train_db_name: imagenet 15 | val_db_name: imagenet 16 | num_classes: 1000 17 | num_neighbors: 50 18 | 19 | # Transformations 20 | augmentation_strategy: simclr 21 | augmentation_kwargs: 22 | random_resized_crop: 23 | size: 224 24 | scale: [0.2, 1.0] 25 | color_jitter_random_apply: 26 | p: 0.8 27 | color_jitter: 28 | brightness: 0.4 29 | contrast: 0.4 30 | saturation: 0.4 31 | hue: 0.1 32 | random_grayscale: 33 | p: 0.2 34 | normalize: 35 | mean: [0.485, 0.456, 0.406] 36 | std: [0.229, 0.224, 0.225] 37 | 38 | transformation_kwargs: 39 | crop_size: 224 40 | normalize: 41 | mean: [0.485, 0.456, 0.406] 42 | std: [0.229, 0.224, 0.225] 43 | 44 | num_workers: 12 45 | -------------------------------------------------------------------------------- /scan/configs/scan/scan_cifar10.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: scan 3 | 4 | # Loss 5 | criterion: scan 6 | criterion_kwargs: 7 | entropy_weight: 5.0 8 | 9 | # Weight update 10 | update_cluster_head_only: False # Update full network in SCAN 11 | num_heads: 1 # Only use one head 12 | 13 | # Model 14 | backbone: resnet18 15 | 16 | # Dataset 17 | train_db_name: cifar-10 18 | val_db_name: cifar-10 19 | num_classes: 10 20 | num_neighbors: 20 21 | 22 | # Transformations 23 | augmentation_strategy: ours 24 | augmentation_kwargs: 25 | crop_size: 32 26 | normalize: 27 | mean: [0.4914, 0.4822, 0.4465] 28 | std: [0.2023, 0.1994, 0.2010] 29 | num_strong_augs: 4 30 | cutout_kwargs: 31 | n_holes: 1 32 | length: 16 33 | random: True 34 | 35 | transformation_kwargs: 36 | crop_size: 32 37 | normalize: 38 | mean: [0.4914, 0.4822, 0.4465] 39 | std: [0.2023, 0.1994, 0.2010] 40 | 41 | # Hyperparameters 42 | optimizer: adam 43 | optimizer_kwargs: 44 | lr: 0.0001 45 | weight_decay: 0.0001 46 | epochs: 50 47 | batch_size: 128 48 | num_workers: 8 49 | 50 | # Scheduler 51 | scheduler: 
constant 52 | -------------------------------------------------------------------------------- /scan/configs/scan/scan_cifar100.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: scan 3 | 4 | # Loss 5 | criterion: scan 6 | criterion_kwargs: 7 | entropy_weight: 5.0 8 | 9 | # Weight update 10 | update_cluster_head_only: False # Update full network in SCAN 11 | num_heads: 1 # Only use one head 12 | 13 | # Model 14 | backbone: resnet18 15 | 16 | # Dataset 17 | train_db_name: cifar-100 18 | val_db_name: cifar-100 19 | num_classes: 100 20 | num_neighbors: 20 21 | 22 | # Transformations 23 | augmentation_strategy: ours 24 | augmentation_kwargs: 25 | crop_size: 32 26 | normalize: 27 | mean: [0.5071, 0.4867, 0.4408] 28 | std: [0.2675, 0.2565, 0.2761] 29 | num_strong_augs: 4 30 | cutout_kwargs: 31 | n_holes: 1 32 | length: 16 33 | random: True 34 | 35 | transformation_kwargs: 36 | crop_size: 32 37 | normalize: 38 | mean: [0.5071, 0.4867, 0.4408] 39 | std: [0.2675, 0.2565, 0.2761] 40 | 41 | # Hyperparameters 42 | optimizer: adam 43 | optimizer_kwargs: 44 | lr: 0.0001 45 | weight_decay: 0.0001 46 | epochs: 100 47 | batch_size: 512 48 | num_workers: 8 49 | 50 | # Scheduler 51 | scheduler: constant 52 | -------------------------------------------------------------------------------- /scan/configs/scan/scan_imagenet_100.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: scan 3 | 4 | # Loss 5 | criterion: scan 6 | criterion_kwargs: 7 | entropy_weight: 5.0 8 | 9 | # Model 10 | backbone: resnet50 11 | 12 | # Weight update 13 | update_cluster_head_only: True # Train only linear layer during SCAN 14 | num_heads: 10 # Use multiple heads 15 | 16 | # Dataset 17 | train_db_name: imagenet_100 18 | val_db_name: imagenet_100 19 | num_classes: 100 20 | num_neighbors: 50 21 | 22 | # Transformations 23 | augmentation_strategy: simclr 24 | augmentation_kwargs: 25 | random_resized_crop: 26 | size: 224 27 | scale: [0.2, 1.0] 28 | color_jitter_random_apply: 29 | p: 0.8 30 | color_jitter: 31 | brightness: 0.4 32 | contrast: 0.4 33 | saturation: 0.4 34 | hue: 0.1 35 | random_grayscale: 36 | p: 0.2 37 | normalize: 38 | mean: [0.485, 0.456, 0.406] 39 | std: [0.229, 0.224, 0.225] 40 | 41 | transformation_kwargs: 42 | crop_size: 224 43 | normalize: 44 | mean: [0.485, 0.456, 0.406] 45 | std: [0.229, 0.224, 0.225] 46 | 47 | # Hyperparameters 48 | optimizer: sgd 49 | optimizer_kwargs: 50 | lr: 5.0 51 | weight_decay: 0.0000 52 | nesterov: False 53 | momentum: 0.9 54 | epochs: 100 55 | batch_size: 1024 56 | num_workers: 16 57 | 58 | # Scheduler 59 | scheduler: constant 60 | -------------------------------------------------------------------------------- /scan/configs/scan/scan_imagenet_200.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: scan 3 | 4 | # Loss 5 | criterion: scan 6 | criterion_kwargs: 7 | entropy_weight: 5.0 8 | 9 | # Model 10 | backbone: resnet50 11 | 12 | # Weight update 13 | update_cluster_head_only: True # Train only linear layer during SCAN 14 | num_heads: 10 # Use multiple heads 15 | 16 | # Dataset 17 | train_db_name: imagenet_200 18 | val_db_name: imagenet_200 19 | num_classes: 200 20 | num_neighbors: 50 21 | 22 | # Transformations 23 | augmentation_strategy: simclr 24 | augmentation_kwargs: 25 | random_resized_crop: 26 | size: 224 27 | scale: [0.2, 1.0] 28 | color_jitter_random_apply: 29 | p: 0.8 30 | color_jitter: 31 | brightness: 0.4 32 
| contrast: 0.4 33 | saturation: 0.4 34 | hue: 0.1 35 | random_grayscale: 36 | p: 0.2 37 | normalize: 38 | mean: [0.485, 0.456, 0.406] 39 | std: [0.229, 0.224, 0.225] 40 | 41 | transformation_kwargs: 42 | crop_size: 224 43 | normalize: 44 | mean: [0.485, 0.456, 0.406] 45 | std: [0.229, 0.224, 0.225] 46 | 47 | # Hyperparameters 48 | optimizer: sgd 49 | optimizer_kwargs: 50 | lr: 5.0 51 | weight_decay: 0.0000 52 | nesterov: False 53 | momentum: 0.9 54 | epochs: 100 55 | batch_size: 1024 56 | num_workers: 12 57 | 58 | # Scheduler 59 | scheduler: constant 60 | -------------------------------------------------------------------------------- /scan/configs/scan/scan_imagenet_50.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: scan 3 | 4 | # Loss 5 | criterion: scan 6 | criterion_kwargs: 7 | entropy_weight: 5.0 8 | 9 | # Model 10 | backbone: resnet50 11 | 12 | # Weight update 13 | update_cluster_head_only: True # Train only linear layer during SCAN 14 | num_heads: 10 # Use multiple heads 15 | 16 | # Dataset 17 | train_db_name: imagenet_50 18 | val_db_name: imagenet_50 19 | num_classes: 50 20 | num_neighbors: 50 21 | 22 | # Transformations 23 | augmentation_strategy: simclr 24 | augmentation_kwargs: 25 | random_resized_crop: 26 | size: 224 27 | scale: [0.2, 1.0] 28 | color_jitter_random_apply: 29 | p: 0.8 30 | color_jitter: 31 | brightness: 0.4 32 | contrast: 0.4 33 | saturation: 0.4 34 | hue: 0.1 35 | random_grayscale: 36 | p: 0.2 37 | normalize: 38 | mean: [0.485, 0.456, 0.406] 39 | std: [0.229, 0.224, 0.225] 40 | 41 | transformation_kwargs: 42 | crop_size: 224 43 | normalize: 44 | mean: [0.485, 0.456, 0.406] 45 | std: [0.229, 0.224, 0.225] 46 | 47 | # Hyperparameters 48 | optimizer: sgd 49 | optimizer_kwargs: 50 | lr: 5.0 51 | weight_decay: 0.0000 52 | nesterov: False 53 | momentum: 0.9 54 | epochs: 100 55 | batch_size: 512 56 | num_workers: 12 57 | 58 | # Scheduler 59 | scheduler: constant 60 | -------------------------------------------------------------------------------- /scan/configs/scan/scan_stl10.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: scan 3 | 4 | # Loss 5 | criterion: scan 6 | criterion_kwargs: 7 | entropy_weight: 5.0 8 | 9 | # Weight update 10 | update_cluster_head_only: False # Update full network in SCAN 11 | num_heads: 1 # Only use one head 12 | 13 | # Model 14 | backbone: resnet18 15 | 16 | # Dataset 17 | train_db_name: stl-10 18 | val_db_name: stl-10 19 | num_classes: 10 20 | num_neighbors: 20 21 | 22 | # Transformations 23 | augmentation_strategy: ours 24 | augmentation_kwargs: 25 | crop_size: 96 26 | normalize: 27 | mean: [0.485, 0.456, 0.406] 28 | std: [0.229, 0.224, 0.225] 29 | num_strong_augs: 4 30 | cutout_kwargs: 31 | n_holes: 1 32 | length: 32 33 | random: True 34 | 35 | transformation_kwargs: 36 | crop_size: 96 37 | normalize: 38 | mean: [0.485, 0.456, 0.406] 39 | std: [0.229, 0.224, 0.225] 40 | 41 | # Hyperparameters 42 | optimizer: adam 43 | optimizer_kwargs: 44 | lr: 0.0001 45 | weight_decay: 0.0001 46 | epochs: 100 47 | batch_size: 128 48 | num_workers: 8 49 | 50 | # Scheduler 51 | scheduler: constant 52 | -------------------------------------------------------------------------------- /scan/configs/scan/scan_tinyimagenet.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: scan 3 | 4 | # Loss 5 | criterion: scan 6 | criterion_kwargs: 7 | entropy_weight: 5.0 8 | 9 | # Weight update 10 
| update_cluster_head_only: False # Update full network in SCAN 11 | num_heads: 1 # Only use one head 12 | 13 | # Model 14 | backbone: resnet18 15 | 16 | # Dataset 17 | train_db_name: tiny-imagenet 18 | val_db_name: tiny-imagenet 19 | num_classes: 200 20 | num_neighbors: 20 21 | 22 | # Transformations 23 | augmentation_strategy: ours 24 | augmentation_kwargs: 25 | crop_size: 64 26 | normalize: 27 | mean: [0.5071, 0.4867, 0.4408] 28 | std: [0.2675, 0.2565, 0.2761] 29 | num_strong_augs: 4 30 | cutout_kwargs: 31 | n_holes: 1 32 | length: 16 33 | random: True 34 | 35 | transformation_kwargs: 36 | crop_size: 64 37 | normalize: 38 | mean: [0.5071, 0.4867, 0.4408] 39 | std: [0.2675, 0.2565, 0.2761] 40 | 41 | # Hyperparameters 42 | optimizer: adam 43 | optimizer_kwargs: 44 | lr: 0.0001 45 | weight_decay: 0.0001 46 | epochs: 100 47 | batch_size: 512 48 | 49 | num_workers: 8 50 | 51 | # Scheduler 52 | scheduler: constant 53 | -------------------------------------------------------------------------------- /scan/configs/selflabel/selflabel_cifar10.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: selflabel 3 | 4 | # ema 5 | use_ema: False 6 | 7 | # Threshold 8 | confidence_threshold: 0.99 9 | 10 | # Criterion 11 | criterion: confidence-cross-entropy 12 | criterion_kwargs: 13 | apply_class_balancing: True 14 | 15 | # Model 16 | backbone: resnet18 17 | num_heads: 1 18 | 19 | # Dataset 20 | train_db_name: cifar-10 21 | val_db_name: cifar-10 22 | num_classes: 10 23 | 24 | # Transformations 25 | augmentation_strategy: ours 26 | augmentation_kwargs: 27 | crop_size: 32 28 | normalize: 29 | mean: [0.4914, 0.4822, 0.4465] 30 | std: [0.2023, 0.1994, 0.2010] 31 | num_strong_augs: 4 32 | cutout_kwargs: 33 | n_holes: 1 34 | length: 16 35 | random: True 36 | 37 | transformation_kwargs: 38 | crop_size: 32 39 | normalize: 40 | mean: [0.4914, 0.4822, 0.4465] 41 | std: [0.2023, 0.1994, 0.2010] 42 | 43 | # Hyperparameters 44 | epochs: 200 45 | optimizer: adam 46 | optimizer_kwargs: 47 | lr: 0.0001 48 | weight_decay: 0.0001 49 | batch_size: 1000 50 | num_workers: 8 51 | 52 | # Scheduler 53 | scheduler: constant 54 | -------------------------------------------------------------------------------- /scan/configs/selflabel/selflabel_cifar100.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: selflabel 3 | 4 | # ema 5 | use_ema: False 6 | 7 | # Threshold 8 | confidence_threshold: 0.99 9 | 10 | # Criterion 11 | criterion: confidence-cross-entropy 12 | criterion_kwargs: 13 | apply_class_balancing: True 14 | 15 | # Model 16 | backbone: resnet18 17 | num_heads: 1 18 | 19 | # Dataset 20 | train_db_name: cifar-20 21 | val_db_name: cifar-20 22 | num_classes: 20 23 | 24 | # Transformations 25 | augmentation_strategy: ours 26 | augmentation_kwargs: 27 | crop_size: 32 28 | normalize: 29 | mean: [0.5071, 0.4867, 0.4408] 30 | std: [0.2675, 0.2565, 0.2761] 31 | num_strong_augs: 4 32 | cutout_kwargs: 33 | n_holes: 1 34 | length: 16 35 | random: True 36 | 37 | transformation_kwargs: 38 | crop_size: 32 39 | normalize: 40 | mean: [0.5071, 0.4867, 0.4408] 41 | std: [0.2675, 0.2565, 0.2761] 42 | 43 | # Hyperparameters 44 | epochs: 200 45 | optimizer: adam 46 | optimizer_kwargs: 47 | lr: 0.0001 48 | weight_decay: 0.0001 49 | batch_size: 1000 50 | num_workers: 8 51 | 52 | # Scheduler 53 | scheduler: constant 54 | -------------------------------------------------------------------------------- 
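Reading note (added): the selflabel configs in this directory are plain YAML; in the repo they are merged with scan/configs/env.yml by create_config from scan/utils/config.py (see the call site in moco.py further down). A minimal sketch of pulling out the two fields that drive self-labeling, assuming only PyYAML and the path shown:

import yaml

# Hypothetical snippet, not repo code: load one selflabel config directly.
with open('scan/configs/selflabel/selflabel_cifar10.yml') as f:
    p = yaml.safe_load(f)

# confidence_threshold gates which weak-view predictions become pseudo-labels;
# apply_class_balancing toggles inverse-frequency weights in the CE loss.
print(p['confidence_threshold'])                        # 0.99
print(p['criterion_kwargs']['apply_class_balancing'])   # True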
/scan/configs/selflabel/selflabel_imagenet_100.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: selflabel 3 | 4 | # Threshold 5 | confidence_threshold: 0.99 6 | 7 | # EMA 8 | use_ema: True 9 | ema_alpha: 0.999 10 | 11 | # Loss 12 | criterion: confidence-cross-entropy 13 | criterion_kwargs: 14 | apply_class_balancing: False 15 | 16 | # Model 17 | backbone: resnet50 18 | num_heads: 1 19 | 20 | # Dataset 21 | train_db_name: imagenet_100 22 | val_db_name: imagenet_100 23 | num_classes: 100 24 | 25 | # Transformations 26 | augmentation_strategy: ours 27 | augmentation_kwargs: 28 | crop_size: 224 29 | normalize: 30 | mean: [0.485, 0.456, 0.406] 31 | std: [0.229, 0.224, 0.225] 32 | num_strong_augs: 4 33 | cutout_kwargs: 34 | n_holes: 1 35 | length: 75 36 | random: True 37 | 38 | transformation_kwargs: 39 | crop_size: 224 40 | normalize: 41 | mean: [0.485, 0.456, 0.406] 42 | std: [0.229, 0.224, 0.225] 43 | 44 | # Hyperparameters 45 | optimizer: sgd 46 | optimizer_kwargs: 47 | lr: 0.03 48 | weight_decay: 0.0 49 | nesterov: False 50 | momentum: 0.9 51 | epochs: 25 52 | batch_size: 512 53 | num_workers: 12 54 | 55 | # Scheduler 56 | scheduler: constant 57 | -------------------------------------------------------------------------------- /scan/configs/selflabel/selflabel_imagenet_200.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: selflabel 3 | 4 | # Threshold 5 | confidence_threshold: 0.99 6 | 7 | # EMA 8 | use_ema: True 9 | ema_alpha: 0.999 10 | 11 | # Loss 12 | criterion: confidence-cross-entropy 13 | criterion_kwargs: 14 | apply_class_balancing: False 15 | 16 | # Model 17 | backbone: resnet50 18 | num_heads: 1 19 | 20 | # Dataset 21 | train_db_name: imagenet_200 22 | val_db_name: imagenet_200 23 | num_classes: 200 24 | 25 | # Transformations 26 | augmentation_strategy: ours 27 | augmentation_kwargs: 28 | crop_size: 224 29 | normalize: 30 | mean: [0.485, 0.456, 0.406] 31 | std: [0.229, 0.224, 0.225] 32 | num_strong_augs: 4 33 | cutout_kwargs: 34 | n_holes: 1 35 | length: 75 36 | random: True 37 | 38 | transformation_kwargs: 39 | crop_size: 224 40 | normalize: 41 | mean: [0.485, 0.456, 0.406] 42 | std: [0.229, 0.224, 0.225] 43 | 44 | # Hyperparameters 45 | optimizer: sgd 46 | optimizer_kwargs: 47 | lr: 0.03 48 | weight_decay: 0.0 49 | nesterov: False 50 | momentum: 0.9 51 | epochs: 25 52 | batch_size: 512 53 | num_workers: 8 54 | 55 | # Scheduler 56 | scheduler: constant 57 | -------------------------------------------------------------------------------- /scan/configs/selflabel/selflabel_imagenet_50.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: selflabel 3 | 4 | # Threshold 5 | confidence_threshold: 0.99 6 | 7 | # EMA 8 | use_ema: True 9 | ema_alpha: 0.999 10 | 11 | # Loss 12 | criterion: confidence-cross-entropy 13 | criterion_kwargs: 14 | apply_class_balancing: False 15 | 16 | # Model 17 | backbone: resnet50 18 | num_heads: 1 19 | 20 | # Dataset 21 | train_db_name: imagenet_50 22 | val_db_name: imagenet_50 23 | num_classes: 50 24 | 25 | # Transformations 26 | augmentation_strategy: ours 27 | augmentation_kwargs: 28 | crop_size: 224 29 | normalize: 30 | mean: [0.485, 0.456, 0.406] 31 | std: [0.229, 0.224, 0.225] 32 | num_strong_augs: 4 33 | cutout_kwargs: 34 | n_holes: 1 35 | length: 75 36 | random: True 37 | 38 | transformation_kwargs: 39 | crop_size: 224 40 | normalize: 41 | mean: [0.485, 0.456, 0.406] 42 | std: 
[0.229, 0.224, 0.225] 43 | 44 | # Hyperparameters 45 | optimizer: sgd 46 | optimizer_kwargs: 47 | lr: 0.03 48 | weight_decay: 0.0 49 | nesterov: False 50 | momentum: 0.9 51 | epochs: 25 52 | batch_size: 512 53 | num_workers: 16 54 | 55 | # Scheduler 56 | scheduler: constant 57 | -------------------------------------------------------------------------------- /scan/configs/selflabel/selflabel_stl10.yml: -------------------------------------------------------------------------------- 1 | # setup 2 | setup: selflabel 3 | 4 | # ema 5 | use_ema: False 6 | 7 | # Threshold 8 | confidence_threshold: 0.99 9 | 10 | # Loss 11 | criterion: confidence-cross-entropy 12 | criterion_kwargs: 13 | apply_class_balancing: True 14 | 15 | # Model 16 | backbone: resnet18 17 | num_heads: 1 18 | 19 | # Dataset 20 | train_db_name: stl-10 21 | val_db_name: stl-10 22 | num_classes: 10 23 | 24 | # Transformations 25 | augmentation_strategy: ours 26 | augmentation_kwargs: 27 | crop_size: 96 28 | normalize: 29 | mean: [0.485, 0.456, 0.406] 30 | std: [0.229, 0.224, 0.225] 31 | num_strong_augs: 4 32 | cutout_kwargs: 33 | n_holes: 1 34 | length: 32 35 | random: True 36 | 37 | transformation_kwargs: 38 | crop_size: 96 39 | normalize: 40 | mean: [0.485, 0.456, 0.406] 41 | std: [0.229, 0.224, 0.225] 42 | 43 | # Hyperparameters 44 | optimizer: adam 45 | optimizer_kwargs: 46 | lr: 0.0001 47 | weight_decay: 0.0001 48 | epochs: 100 49 | batch_size: 1000 50 | num_workers: 8 51 | 52 | # Scheduler 53 | scheduler: constant 54 | -------------------------------------------------------------------------------- /scan/data/augment.py: -------------------------------------------------------------------------------- 1 | # List of augmentations based on randaugment 2 | import random 3 | 4 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw 5 | import numpy as np 6 | import torch 7 | from torchvision.transforms.transforms import Compose 8 | 9 | random_mirror = True 10 | 11 | def ShearX(img, v): 12 | if random_mirror and random.random() > 0.5: 13 | v = -v 14 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 15 | 16 | def ShearY(img, v): 17 | if random_mirror and random.random() > 0.5: 18 | v = -v 19 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 20 | 21 | def Identity(img, v): 22 | return img 23 | 24 | def TranslateX(img, v): 25 | if random_mirror and random.random() > 0.5: 26 | v = -v 27 | v = v * img.size[0] 28 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 29 | 30 | def TranslateY(img, v): 31 | if random_mirror and random.random() > 0.5: 32 | v = -v 33 | v = v * img.size[1] 34 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 35 | 36 | def TranslateXAbs(img, v): 37 | if random.random() > 0.5: 38 | v = -v 39 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 40 | 41 | def TranslateYAbs(img, v): 42 | if random.random() > 0.5: 43 | v = -v 44 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 45 | 46 | def Rotate(img, v): 47 | if random_mirror and random.random() > 0.5: 48 | v = -v 49 | return img.rotate(v) 50 | 51 | def AutoContrast(img, _): 52 | return PIL.ImageOps.autocontrast(img) 53 | 54 | def Invert(img, _): 55 | return PIL.ImageOps.invert(img) 56 | 57 | def Equalize(img, _): 58 | return PIL.ImageOps.equalize(img) 59 | 60 | def Solarize(img, v): 61 | return PIL.ImageOps.solarize(img, v) 62 | 63 | def Posterize(img, v): 64 | v = int(v) 65 | return PIL.ImageOps.posterize(img, v) 66 | 67 | def 
Contrast(img, v): 68 | return PIL.ImageEnhance.Contrast(img).enhance(v) 69 | 70 | def Color(img, v): 71 | return PIL.ImageEnhance.Color(img).enhance(v) 72 | 73 | def Brightness(img, v): 74 | return PIL.ImageEnhance.Brightness(img).enhance(v) 75 | 76 | def Sharpness(img, v): 77 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 78 | 79 | def augment_list(): 80 | l = [ 81 | (Identity, 0, 1), 82 | (AutoContrast, 0, 1), 83 | (Equalize, 0, 1), 84 | (Rotate, -30, 30), 85 | (Solarize, 0, 256), 86 | (Color, 0.05, 0.95), 87 | (Contrast, 0.05, 0.95), 88 | (Brightness, 0.05, 0.95), 89 | (Sharpness, 0.05, 0.95), 90 | (ShearX, -0.1, 0.1), 91 | (TranslateX, -0.1, 0.1), 92 | (TranslateY, -0.1, 0.1), 93 | (Posterize, 4, 8), 94 | (ShearY, -0.1, 0.1), 95 | ] 96 | return l 97 | 98 | 99 | augment_dict = {fn.__name__: (fn, v1, v2) for fn, v1, v2 in augment_list()} 100 | 101 | class Augment: 102 | def __init__(self, n): 103 | self.n = n 104 | self.augment_list = augment_list() 105 | 106 | def __call__(self, img): 107 | ops = random.choices(self.augment_list, k=self.n) 108 | for op, minval, maxval in ops: 109 | val = random.random() * float(maxval - minval) + minval 110 | img = op(img, val) 111 | 112 | return img 113 | 114 | def get_augment(name): 115 | return augment_dict[name] 116 | 117 | def apply_augment(img, name, level): 118 | augment_fn, low, high = get_augment(name) 119 | return augment_fn(img.copy(), level * (high - low) + low) 120 | 121 | class Cutout(object): 122 | def __init__(self, n_holes, length, random=False): 123 | self.n_holes = n_holes 124 | self.length = length 125 | self.random = random 126 | 127 | def __call__(self, img): 128 | h = img.size(1) 129 | w = img.size(2) 130 | length = random.randint(1, self.length) if self.random else self.length  # honor the 'random' flag rather than ignoring it 131 | mask = np.ones((h, w), np.float32) 132 | 133 | for n in range(self.n_holes): 134 | y = np.random.randint(h) 135 | x = np.random.randint(w) 136 | 137 | y1 = np.clip(y - length // 2, 0, h) 138 | y2 = np.clip(y + length // 2, 0, h) 139 | x1 = np.clip(x - length // 2, 0, w) 140 | x2 = np.clip(x + length // 2, 0, w) 141 | 142 | mask[y1: y2, x1: x2] = 0. 143 | 144 | mask = torch.from_numpy(mask) 145 | mask = mask.expand_as(img) 146 | img = img * mask 147 | 148 | return img 149 | -------------------------------------------------------------------------------- /scan/data/custom_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import numpy as np 6 | import torch 7 | from torch.utils.data import Dataset 8 | 9 | """ 10 | AugmentedDataset 11 | Returns an image together with an augmentation.
12 | """ 13 | class AugmentedDataset(Dataset): 14 | def __init__(self, dataset): 15 | super(AugmentedDataset, self).__init__() 16 | transform = dataset.transform 17 | dataset.transform = None 18 | self.dataset = dataset 19 | 20 | if isinstance(transform, dict): 21 | self.image_transform = transform['standard'] 22 | self.augmentation_transform = transform['augment'] 23 | 24 | else: 25 | self.image_transform = transform 26 | self.augmentation_transform = transform 27 | 28 | def __len__(self): 29 | return len(self.dataset) 30 | 31 | def __getitem__(self, index): 32 | sample = self.dataset.__getitem__(index) 33 | image = sample['image'] 34 | 35 | sample['image'] = self.image_transform(image) 36 | sample['image_augmented'] = self.augmentation_transform(image) 37 | 38 | return sample 39 | 40 | 41 | """ 42 | NeighborsDataset 43 | Returns an image with one of its neighbors. 44 | """ 45 | class NeighborsDataset(Dataset): 46 | def __init__(self, dataset, indices, num_neighbors=None): 47 | super(NeighborsDataset, self).__init__() 48 | transform = dataset.transform 49 | 50 | if isinstance(transform, dict): 51 | self.anchor_transform = transform['standard'] 52 | self.neighbor_transform = transform['augment'] 53 | else: 54 | self.anchor_transform = transform 55 | self.neighbor_transform = transform 56 | 57 | dataset.transform = None 58 | self.dataset = dataset 59 | self.indices = indices # Nearest neighbor indices (np.array [len(dataset) x k]) 60 | if num_neighbors is not None: 61 | self.indices = self.indices[:, :num_neighbors+1] 62 | assert(self.indices.shape[0] == len(self.dataset)) 63 | 64 | def __len__(self): 65 | return len(self.dataset) 66 | 67 | def __getitem__(self, index): 68 | output = {} 69 | anchor = self.dataset.__getitem__(index) 70 | 71 | neighbor_index = np.random.choice(self.indices[index], 1)[0] 72 | neighbor = self.dataset.__getitem__(neighbor_index) 73 | 74 | anchor['image'] = self.anchor_transform(anchor['image']) 75 | neighbor['image'] = self.neighbor_transform(neighbor['image']) 76 | 77 | output['anchor'] = anchor['image'] 78 | output['neighbor'] = neighbor['image'] 79 | output['possible_neighbors'] = torch.from_numpy(self.indices[index]) 80 | output['target'] = anchor['target'] 81 | 82 | return output 83 | -------------------------------------------------------------------------------- /scan/data/imagenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import os 6 | import torch 7 | import torchvision.datasets as datasets 8 | import torch.utils.data as data 9 | from PIL import Image 10 | from utils.mypath import MyPath 11 | from torchvision import transforms as tf 12 | from glob import glob 13 | 14 | 15 | class ImageNet(datasets.ImageFolder): 16 | def __init__(self, root=MyPath.db_root_dir('imagenet'), split='train', transform=None): 17 | super(ImageNet, self).__init__(root=os.path.join(root, 'ILSVRC2012_img_%s' %(split)), 18 | transform=None) 19 | self.transform = transform 20 | self.split = split 21 | self.resize = tf.Resize(256) 22 | 23 | def __len__(self): 24 | return len(self.imgs) 25 | 26 | def __getitem__(self, index): 27 | path, target = self.imgs[index] 28 | with open(path, 'rb') as f: 29 | img = Image.open(f).convert('RGB') 30 | im_size = img.size 31 | img = self.resize(img) 32 | 33 | if self.transform is not None: 34 | img = self.transform(img) 35 | 36 | out = {'image': 
img, 'target': target, 'meta': {'im_size': im_size, 'index': index}} 37 | 38 | return out 39 | 40 | def get_image(self, index): 41 | path, target = self.imgs[index] 42 | with open(path, 'rb') as f: 43 | img = Image.open(f).convert('RGB') 44 | img = self.resize(img) 45 | return img 46 | 47 | 48 | class ImageNetSubset(data.Dataset): 49 | def __init__(self, subset_file, root=MyPath.db_root_dir('imagenet'), split='train', 50 | transform=None): 51 | super(ImageNetSubset, self).__init__() 52 | 53 | self.root = os.path.join(root, 'ILSVRC2012_img_%s' %(split)) 54 | self.transform = transform 55 | self.split = split 56 | 57 | # Read the subset of classes to include (sorted) 58 | with open(subset_file, 'r') as f: 59 | result = f.read().splitlines() 60 | subdirs, class_names = [], [] 61 | for line in result: 62 | subdir, class_name = line.split(' ', 1) 63 | subdirs.append(subdir) 64 | class_names.append(class_name) 65 | 66 | # Gather the files (sorted) 67 | imgs = [] 68 | for i, subdir in enumerate(subdirs): 69 | subdir_path = os.path.join(self.root, subdir) 70 | files = sorted(glob(os.path.join(self.root, subdir, '*.JPEG'))) 71 | for f in files: 72 | imgs.append((f, i)) 73 | self.imgs = imgs 74 | self.classes = class_names 75 | 76 | # Resize 77 | self.resize = tf.Resize(256) 78 | 79 | def get_image(self, index): 80 | path, target = self.imgs[index] 81 | with open(path, 'rb') as f: 82 | img = Image.open(f).convert('RGB') 83 | img = self.resize(img) 84 | return img 85 | 86 | def __len__(self): 87 | return len(self.imgs) 88 | 89 | def __getitem__(self, index): 90 | path, target = self.imgs[index] 91 | with open(path, 'rb') as f: 92 | img = Image.open(f).convert('RGB') 93 | im_size = img.size 94 | img = self.resize(img) 95 | class_name = self.classes[target] 96 | 97 | if self.transform is not None: 98 | img = self.transform(img) 99 | 100 | out = {'image': img, 'target': target, 'meta': {'im_size': im_size, 'index': index, 'class_name': class_name}} 101 | 102 | return out 103 | -------------------------------------------------------------------------------- /scan/data/imagenet_subsets/imagenet_100.txt: -------------------------------------------------------------------------------- 1 | n01558993 robin, American robin, Turdus migratorius 2 | n01601694 water ouzel, dipper 3 | n01669191 box turtle, box tortoise 4 | n01751748 sea snake 5 | n01755581 diamondback, diamondback rattlesnake, Crotalus adamanteus 6 | n01756291 sidewinder, horned rattlesnake, Crotalus cerastes 7 | n01770393 scorpion 8 | n01855672 goose 9 | n01871265 tusker 10 | n02018207 American coot, marsh hen, mud hen, water hen, Fulica americana 11 | n02037110 oystercatcher, oyster catcher 12 | n02058221 albatross, mollymawk 13 | n02087046 toy terrier 14 | n02088632 bluetick 15 | n02093256 Staffordshire bullterrier, Staffordshire bull terrier 16 | n02093754 Border terrier 17 | n02094114 Norfolk terrier 18 | n02096177 cairn, cairn terrier 19 | n02097130 giant schnauzer 20 | n02097298 Scotch terrier, Scottish terrier, Scottie 21 | n02099267 flat-coated retriever 22 | n02100877 Irish setter, red setter 23 | n02104365 schipperke 24 | n02105855 Shetland sheepdog, Shetland sheep dog, Shetland 25 | n02106030 collie 26 | n02106166 Border collie 27 | n02107142 Doberman, Doberman pinscher 28 | n02110341 dalmatian, coach dog, carriage dog 29 | n02114855 coyote, prairie wolf, brush wolf, Canis latrans 30 | n02120079 Arctic fox, white fox, Alopex lagopus 31 | n02120505 grey fox, gray fox, Urocyon cinereoargenteus 32 | n02125311 cougar, puma, catamount, 
mountain lion, painter, panther, Felis concolor 33 | n02128385 leopard, Panthera pardus 34 | n02133161 American black bear, black bear, Ursus americanus, Euarctos americanus 35 | n02277742 ringlet, ringlet butterfly 36 | n02325366 wood rabbit, cottontail, cottontail rabbit 37 | n02364673 guinea pig, Cavia cobaya 38 | n02484975 guenon, guenon monkey 39 | n02489166 proboscis monkey, Nasalis larvatus 40 | n02708093 analog clock 41 | n02747177 ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin 42 | n02835271 bicycle-built-for-two, tandem bicycle, tandem 43 | n02906734 broom 44 | n02909870 bucket, pail 45 | n03085013 computer keyboard, keypad 46 | n03124170 cowboy hat, ten-gallon hat 47 | n03127747 crash helmet 48 | n03160309 dam, dike, dyke 49 | n03255030 dumbbell 50 | n03272010 electric guitar 51 | n03291819 envelope 52 | n03337140 file, file cabinet, filing cabinet 53 | n03450230 gown 54 | n03483316 hand blower, blow dryer, blow drier, hair dryer, hair drier 55 | n03498962 hatchet 56 | n03530642 honeycomb 57 | n03623198 knee pad 58 | n03649909 lawn mower, mower 59 | n03710721 maillot, tank suit 60 | n03717622 manhole cover 61 | n03733281 maze, labyrinth 62 | n03759954 microphone, mike 63 | n03775071 mitten 64 | n03814639 neck brace 65 | n03837869 obelisk 66 | n03838899 oboe, hautboy, hautbois 67 | n03854065 organ, pipe organ 68 | n03929855 pickelhaube 69 | n03930313 picket fence, paling 70 | n03954731 plane, carpenter's plane, woodworking plane 71 | n03956157 planetarium 72 | n03983396 pop bottle, soda bottle 73 | n04004767 printer 74 | n04026417 purse 75 | n04065272 recreational vehicle, RV, R.V. 76 | n04200800 shoe shop, shoe-shop, shoe store 77 | n04209239 shower curtain 78 | n04235860 sleeping bag 79 | n04311004 steel arch bridge 80 | n04325704 stole 81 | n04336792 stretcher 82 | n04346328 stupa, tope 83 | n04380533 table lamp 84 | n04428191 thresher, thrasher, threshing machine 85 | n04443257 tobacco shop, tobacconist shop, tobacconist 86 | n04458633 totem pole 87 | n04483307 trimaran 88 | n04509417 unicycle, monocycle 89 | n04515003 upright, upright piano 90 | n04525305 vending machine 91 | n04554684 washer, automatic washer, washing machine 92 | n04591157 Windsor tie 93 | n04592741 wing 94 | n04606251 wreck 95 | n07583066 guacamole 96 | n07613480 trifle 97 | n07693725 bagel, beigel 98 | n07711569 mashed potato 99 | n07753592 banana 100 | n11879895 rapeseed 101 | -------------------------------------------------------------------------------- /scan/data/imagenet_subsets/imagenet_50.txt: -------------------------------------------------------------------------------- 1 | n01601694 Dipper 2 | n01669191 Box Turtle 3 | n01755581 Diamondback Snake 4 | n01770393 Scorpion 5 | n01855672 Goose 6 | n02018207 Water Hen 7 | n02058221 Albatross 8 | n02096177 Cairn Terrier 9 | n02097130 Giant Schnauzer 10 | n02099267 Flat-Coated Retriever 11 | n02100877 Irish Setter 12 | n02104365 Schipperke 13 | n02106030 Collie 14 | n02114855 Coyote 15 | n02125311 Mountain Lion 16 | n02133161 Black Bear 17 | n02484975 Guenon 18 | n02489166 Proboscis Monkey 19 | n02747177 Trash Bin 20 | n02906734 Broom 21 | n03124170 Cowboy Hat 22 | n03272010 Electric Guitar 23 | n03337140 File Cabinet 24 | n03483316 Hair Dryer 25 | n03498962 Hatchet 26 | n03710721 Maillot 27 | n03717622 Manhole Cover 28 | n03733281 Maze 29 | n03759954 Microphone 30 | n03775071 Mitten 31 | n03814639 Neck Brace 32 | n03837869 Obelisk 33 | n03838899 Oboe 34 | n03854065 Organ 35 | n03954731 Woodworking 
Plane 36 | n03983396 Soda Bottle 37 | n04026417 Purse 38 | n04200800 Shoe Shop 39 | n04209239 Shower Curtain 40 | n04311004 Steel Arch Bridge 41 | n04380533 Table Lamp 42 | n04428191 Threshing Machine 43 | n04443257 Tobacco Shop 44 | n04509417 Unicycle 45 | n04525305 Vending Machine 46 | n04554684 Washing Machine 47 | n04606251 Wreck 48 | n07583066 Guacamole 49 | n07711569 Mashed Potato 50 | n07753592 Banana 51 | -------------------------------------------------------------------------------- /scan/data/tinyimagenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is based on the Torchvision repository, which was licensed under the BSD 3-Clause. 3 | """ 4 | import os 5 | import pickle 6 | from PIL import Image 7 | from torchvision.datasets.utils import check_integrity 8 | from torchvision import datasets 9 | from typing import Any 10 | 11 | 12 | def unpickle_object(path): 13 | with open(path, 'rb+') as file_pi: 14 | res = pickle.load(file_pi) 15 | return res 16 | 17 | 18 | class TinyImageNet(datasets.VisionDataset): 19 | """`Tiny ImageNet Classification Dataset. 20 | 21 | Args: 22 | root (string): Root directory of the ImageNet Dataset. 23 | split (string, optional): The dataset split, supports ``train``, or ``val``. 24 | transform (callable, optional): A function/transform that takes in an PIL image 25 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 26 | target_transform (callable, optional): A function/transform that takes in the 27 | target and transforms it. 28 | loader (callable, optional): A function to load an image given its path. 29 | 30 | Attributes: 31 | classes (list): List of the class name tuples. 32 | class_to_idx (dict): Dict with items (class_name, class_index). 33 | wnids (list): List of the WordNet IDs. 34 | wnid_to_idx (dict): Dict with items (wordnet_id, class_index). 35 | samples (list): List of (image path, class_index) tuples 36 | targets (list): The class_index value for each image in the dataset 37 | """ 38 | 39 | def __init__(self, root: str, split: str = 'train', transform=None, **kwargs: Any) -> None: 40 | self.root = root 41 | if split == 'train+unlabeled': 42 | split = 'train' 43 | self.split = datasets.utils.verify_str_arg(split, "split", ("train", "val")) 44 | 45 | if self.split == 'train': 46 | self.images, self.targets, self.cls_to_id = unpickle_object('../../../daphna/data/tiny_imagenet/tiny-imagenet-200/train.pkl') 47 | elif self.split == 'val': 48 | self.images, self.targets, self.cls_to_id = unpickle_object('../../../daphna/data/tiny_imagenet/tiny-imagenet-200/val.pkl') 49 | else: 50 | raise NotImplementedError('unknown split') 51 | self.targets = self.targets.astype(int) 52 | self.classes = list(self.cls_to_id.keys()) 53 | super(TinyImageNet, self).__init__(root, **kwargs) 54 | self.transform = transform 55 | 56 | # Split folder is used for the 'super' call. Since val directory is not structured like the train, 57 | # we simply use train's structure to get all classes and other stuff 58 | @property 59 | def split_folder(self) -> str: 60 | return os.path.join(self.root, 'train') 61 | 62 | def __getitem__(self, index: int): 63 | """ 64 | Args: 65 | index (int): Index 66 | 67 | Returns: 68 | tuple: (sample, target) where target is class_index of the target class. 
69 | """ 70 | sample = Image.fromarray(self.images[index]) 71 | target = int(self.targets[index]) 72 | 73 | if self.transform is not None: 74 | sample = self.transform(sample) 75 | 76 | out = {'image': sample, 'target': target, 'meta': {'im_size': 64, 'index': index, 'class_name': target}} 77 | return out 78 | 79 | def __len__(self): 80 | return len(self.targets) 81 | -------------------------------------------------------------------------------- /scan/images/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/scan/images/pipeline.png -------------------------------------------------------------------------------- /scan/images/prototypes_cifar10.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/scan/images/prototypes_cifar10.jpg -------------------------------------------------------------------------------- /scan/images/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/scan/images/teaser.jpg -------------------------------------------------------------------------------- /scan/images/tutorial/confusion_matrix_stl10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/scan/images/tutorial/confusion_matrix_stl10.png -------------------------------------------------------------------------------- /scan/images/tutorial/prototypes_stl10.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/scan/images/tutorial/prototypes_stl10.jpg -------------------------------------------------------------------------------- /scan/losses/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | EPS=1e-8 9 | 10 | 11 | class MaskedCrossEntropyLoss(nn.Module): 12 | def __init__(self): 13 | super(MaskedCrossEntropyLoss, self).__init__() 14 | 15 | def forward(self, input, target, mask, weight, reduction='mean'): 16 | if not (mask != 0).any(): 17 | raise ValueError('Mask in MaskedCrossEntropyLoss is all zeros.') 18 | target = torch.masked_select(target, mask) 19 | b, c = input.size() 20 | n = target.size(0) 21 | input = torch.masked_select(input, mask.view(b, 1)).view(n, c) 22 | return F.cross_entropy(input, target, weight = weight, reduction = reduction) 23 | 24 | 25 | class ConfidenceBasedCE(nn.Module): 26 | def __init__(self, threshold, apply_class_balancing): 27 | super(ConfidenceBasedCE, self).__init__() 28 | self.loss = MaskedCrossEntropyLoss() 29 | self.softmax = nn.Softmax(dim = 1) 30 | self.threshold = threshold 31 | self.apply_class_balancing = apply_class_balancing 32 | 33 | def forward(self, anchors_weak, anchors_strong): 34 | """ 35 | Loss function during self-labeling 36 | 37 | input: logits for original samples and for its strong augmentations 38 | output: cross entropy 39 | """ 
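# (added annotation) anchors_weak / anchors_strong are [b, c] logits for weak and
# strong views of the same batch. The weak view supplies pseudo-labels; only
# predictions whose max softmax probability exceeds self.threshold (set via
# confidence_threshold, e.g. 0.99 in the selflabel configs above) contribute to
# the masked cross-entropy computed below.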
40 | # Retrieve target and mask based on weakly augmented anchors 41 | weak_anchors_prob = self.softmax(anchors_weak) 42 | max_prob, target = torch.max(weak_anchors_prob, dim = 1) 43 | mask = max_prob > self.threshold 44 | b, c = weak_anchors_prob.size() 45 | target_masked = torch.masked_select(target, mask.squeeze()) 46 | n = target_masked.size(0) 47 | 48 | # Inputs are strongly augmented anchors 49 | input_ = anchors_strong 50 | 51 | # Class balancing weights 52 | if self.apply_class_balancing: 53 | idx, counts = torch.unique(target_masked, return_counts = True) 54 | freq = 1/(counts.float()/n) 55 | weight = torch.ones(c).cuda() 56 | weight[idx] = freq 57 | 58 | else: 59 | weight = None 60 | 61 | # Loss 62 | loss = self.loss(input_, target, mask, weight = weight, reduction='mean') 63 | 64 | return loss 65 | 66 | 67 | def entropy(x, input_as_probabilities): 68 | """ 69 | Helper function to compute the entropy over the batch 70 | 71 | input: batch w/ shape [b, num_classes] 72 | output: entropy value [ideally log(num_classes), i.e. a uniform distribution] 73 | """ 74 | 75 | if input_as_probabilities: 76 | x_ = torch.clamp(x, min = EPS) 77 | b = x_ * torch.log(x_) 78 | else: 79 | b = F.softmax(x, dim = 1) * F.log_softmax(x, dim = 1) 80 | 81 | if len(b.size()) == 2: # Sample-wise entropy 82 | return -b.sum(dim = 1).mean() 83 | elif len(b.size()) == 1: # Distribution-wise entropy 84 | return - b.sum() 85 | else: 86 | raise ValueError('Input tensor is %d-Dimensional' %(len(b.size()))) 87 | 88 | 89 | class SCANLoss(nn.Module): 90 | def __init__(self, entropy_weight = 2.0): 91 | super(SCANLoss, self).__init__() 92 | self.softmax = nn.Softmax(dim = 1) 93 | self.bce = nn.BCELoss() 94 | self.entropy_weight = entropy_weight # Default = 2.0 95 | 96 | def forward(self, anchors, neighbors): 97 | """ 98 | input: 99 | - anchors: logits for anchor images w/ shape [b, num_classes] 100 | - neighbors: logits for neighbor images w/ shape [b, num_classes] 101 | 102 | output: 103 | - Loss 104 | """ 105 | # Softmax 106 | b, n = anchors.size() 107 | anchors_prob = self.softmax(anchors) 108 | positives_prob = self.softmax(neighbors) 109 | 110 | # Similarity in output space 111 | similarity = torch.bmm(anchors_prob.view(b, 1, n), positives_prob.view(b, n, 1)).squeeze() 112 | ones = torch.ones_like(similarity) 113 | consistency_loss = self.bce(similarity, ones) 114 | 115 | # Entropy loss 116 | entropy_loss = entropy(torch.mean(anchors_prob, 0), input_as_probabilities = True) 117 | 118 | # Total loss 119 | total_loss = consistency_loss - self.entropy_weight * entropy_loss 120 | 121 | return total_loss, consistency_loss, entropy_loss 122 | 123 | 124 | class SimCLRLoss(nn.Module): 125 | # Based on the implementation of SupContrast 126 | def __init__(self, temperature): 127 | super(SimCLRLoss, self).__init__() 128 | self.temperature = temperature 129 | 130 | 131 | def forward(self, features): 132 | """ 133 | input: 134 | - features: hidden feature representation of shape [b, 2, dim] 135 | 136 | output: 137 | - loss: loss computed according to SimCLR 138 | """ 139 | 140 | b, n, dim = features.size() 141 | assert(n == 2) 142 | mask = torch.eye(b, dtype=torch.float32).cuda() 143 | 144 | contrast_features = torch.cat(torch.unbind(features, dim=1), dim=0) 145 | anchor = features[:, 0] 146 | 147 | # Dot product 148 | dot_product = torch.matmul(anchor, contrast_features.T) / self.temperature 149 | 150 | # Log-sum-exp trick for numerical stability 151 | logits_max, _ = torch.max(dot_product, dim=1, keepdim=True) 152 | logits = dot_product -
logits_max.detach() 153 | 154 | mask = mask.repeat(1, 2) 155 | logits_mask = torch.scatter(torch.ones_like(mask), 1, torch.arange(b).view(-1, 1).cuda(), 0) 156 | mask = mask * logits_mask 157 | 158 | # Log-softmax 159 | exp_logits = torch.exp(logits) * logits_mask 160 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 161 | 162 | # Mean log-likelihood for positive 163 | loss = - ((mask * log_prob).sum(1) / mask.sum(1)).mean() 164 | 165 | return loss 166 | -------------------------------------------------------------------------------- /scan/moco.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import argparse 6 | import os 7 | import torch 8 | import numpy as np 9 | 10 | from utils.config import create_config 11 | from utils.common_config import get_model, get_train_dataset,\ 12 | get_val_dataset, get_val_dataloader, get_val_transformations 13 | from utils.memory import MemoryBank 14 | from utils.utils import fill_memory_bank 15 | from termcolor import colored 16 | 17 | # Parser 18 | parser = argparse.ArgumentParser(description='MoCo') 19 | parser.add_argument('--config_env', 20 | help='Config file for the environment') 21 | parser.add_argument('--config_exp', 22 | help='Config file for the experiment') 23 | args = parser.parse_args() 24 | 25 | def main(): 26 | # Retrieve config file 27 | p = create_config(args.config_env, args.config_exp) 28 | print(colored(p, 'red')) 29 | 30 | 31 | # Model 32 | print(colored('Retrieve model', 'blue')) 33 | model = get_model(p) 34 | print('Model is {}'.format(model.__class__.__name__)) 35 | print(model) 36 | model = torch.nn.DataParallel(model) 37 | model = model.cuda() 38 | 39 | 40 | # CUDNN 41 | print(colored('Set CuDNN benchmark', 'blue')) 42 | torch.backends.cudnn.benchmark = True 43 | 44 | 45 | # Dataset 46 | print(colored('Retrieve dataset', 'blue')) 47 | transforms = get_val_transformations(p) 48 | train_dataset = get_train_dataset(p, transforms) 49 | val_dataset = get_val_dataset(p, transforms) 50 | train_dataloader = get_val_dataloader(p, train_dataset) 51 | val_dataloader = get_val_dataloader(p, val_dataset) 52 | print('Dataset contains {}/{} train/val samples'.format(len(train_dataset), len(val_dataset))) 53 | 54 | 55 | # Memory Bank 56 | print(colored('Build MemoryBank', 'blue')) 57 | memory_bank_train = MemoryBank(len(train_dataset), 2048, p['num_classes'], p['temperature']) 58 | memory_bank_train.cuda() 59 | memory_bank_val = MemoryBank(len(val_dataset), 2048, p['num_classes'], p['temperature']) 60 | memory_bank_val.cuda() 61 | 62 | 63 | # Load the official MoCoV2 checkpoint 64 | print(colored('Downloading moco v2 checkpoint', 'blue')) 65 | os.system('wget -L https://dl.fbaipublicfiles.com/moco/moco_checkpoints/moco_v2_800ep/moco_v2_800ep_pretrain.pth.tar') 66 | moco_state = torch.load('moco_v2_800ep_pretrain.pth.tar', map_location='cpu') 67 | 68 | 69 | # Transfer moco weights 70 | print(colored('Transfer MoCo weights to model', 'blue')) 71 | new_state_dict = {} 72 | state_dict = moco_state['state_dict'] 73 | for k in list(state_dict.keys()): 74 | # Copy backbone weights 75 | if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'): 76 | new_k = 'module.backbone.' 
+ k[len('module.encoder_q.'):] 77 | new_state_dict[new_k] = state_dict[k] 78 | 79 | # Copy mlp weights 80 | elif k.startswith('module.encoder_q.fc'): 81 | new_k = 'module.contrastive_head.' + k[len('module.encoder_q.fc.'):] 82 | new_state_dict[new_k] = state_dict[k] 83 | 84 | else: 85 | raise ValueError('Unexpected key {}'.format(k)) 86 | 87 | model.load_state_dict(new_state_dict) 88 | os.system('rm -rf moco_v2_800ep_pretrain.pth.tar') 89 | 90 | 91 | # Save final model 92 | print(colored('Save pretext model', 'blue')) 93 | torch.save(model.module.state_dict(), p['pretext_model']) 94 | model.module.contrastive_head = torch.nn.Identity() # In this case, we mine the neighbors before the MLP. 95 | 96 | 97 | # Mine the topk nearest neighbors (Train) 98 | # These will be used for training with the SCAN-Loss. 99 | topk = 50 100 | print(colored('Mine the nearest neighbors (Train)(Top-%d)' %(topk), 'blue')) 101 | transforms = get_val_transformations(p) 102 | train_dataset = get_train_dataset(p, transforms) 103 | fill_memory_bank(train_dataloader, model, memory_bank_train) 104 | indices, acc = memory_bank_train.mine_nearest_neighbors(topk) 105 | print('Accuracy of top-%d nearest neighbors on train set is %.2f' %(topk, 100*acc)) 106 | np.save(p['topk_neighbors_train_path'], indices) 107 | 108 | 109 | # Mine the topk nearest neighbors (Validation) 110 | # These will be used for validation. 111 | topk = 5 112 | print(colored('Mine the nearest neighbors (Val)(Top-%d)' %(topk), 'blue')) 113 | fill_memory_bank(val_dataloader, model, memory_bank_val) 114 | print('Mine the neighbors') 115 | indices, acc = memory_bank_val.mine_nearest_neighbors(topk) 116 | print('Accuracy of top-%d nearest neighbors on val set is %.2f' %(topk, 100*acc)) 117 | np.save(p['topk_neighbors_val_path'], indices) 118 | 119 | 120 | if __name__ == '__main__': 121 | main() 122 | -------------------------------------------------------------------------------- /scan/models/models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class ContrastiveModel(nn.Module): 11 | def __init__(self, backbone, head='mlp', features_dim=128): 12 | super(ContrastiveModel, self).__init__() 13 | self.backbone = backbone['backbone'] 14 | self.backbone_dim = backbone['dim'] 15 | self.head = head 16 | 17 | if head == 'linear': 18 | self.contrastive_head = nn.Linear(self.backbone_dim, features_dim) 19 | 20 | elif head == 'mlp': 21 | self.contrastive_head = nn.Sequential( 22 | nn.Linear(self.backbone_dim, self.backbone_dim), 23 | nn.ReLU(), nn.Linear(self.backbone_dim, features_dim)) 24 | 25 | else: 26 | raise ValueError('Invalid head {}'.format(head)) 27 | 28 | def forward(self, x, return_pre_last=False): 29 | pre_last = self.backbone(x) 30 | features = self.contrastive_head(pre_last) 31 | features = F.normalize(features, dim = 1) 32 | if return_pre_last: 33 | return features, pre_last 34 | return features 35 | 36 | 37 | class ClusteringModel(nn.Module): 38 | def __init__(self, backbone, nclusters, nheads=1): 39 | super(ClusteringModel, self).__init__() 40 | self.backbone = backbone['backbone'] 41 | self.backbone_dim = backbone['dim'] 42 | self.nheads = nheads 43 | assert(isinstance(self.nheads, int)) 44 | assert(self.nheads > 0) 45 | self.cluster_head = 
nn.ModuleList([nn.Linear(self.backbone_dim, nclusters) for _ in range(self.nheads)]) 46 | 47 | def forward(self, x, forward_pass='default'): 48 | if forward_pass == 'default': 49 | features = self.backbone(x) 50 | out = [cluster_head(features) for cluster_head in self.cluster_head] 51 | 52 | elif forward_pass == 'backbone': 53 | out = self.backbone(x) 54 | 55 | elif forward_pass == 'head': 56 | out = [cluster_head(x) for cluster_head in self.cluster_head] 57 | 58 | elif forward_pass == 'return_all': 59 | features = self.backbone(x) 60 | out = {'features': features, 'output': [cluster_head(features) for cluster_head in self.cluster_head]} 61 | 62 | else: 63 | raise ValueError('Invalid forward pass {}'.format(forward_pass)) 64 | 65 | return out 66 | -------------------------------------------------------------------------------- /scan/models/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import torch.nn as nn 6 | import torchvision.models as models 7 | 8 | 9 | def resnet50(): 10 | backbone = models.__dict__['resnet50']() 11 | backbone.fc = nn.Identity() 12 | return {'backbone': backbone, 'dim': 2048} 13 | -------------------------------------------------------------------------------- /scan/models/resnet_cifar.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is based on the Torchvision repository, which was licensed under the BSD 3-Clause. 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class BasicBlock(nn.Module): 10 | expansion = 1 11 | 12 | def __init__(self, in_planes, planes, stride=1, is_last=False): 13 | super(BasicBlock, self).__init__() 14 | self.is_last = is_last 15 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 16 | self.bn1 = nn.BatchNorm2d(planes) 17 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 18 | self.bn2 = nn.BatchNorm2d(planes) 19 | 20 | self.shortcut = nn.Sequential() 21 | if stride != 1 or in_planes != self.expansion * planes: 22 | self.shortcut = nn.Sequential( 23 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 24 | nn.BatchNorm2d(self.expansion * planes) 25 | ) 26 | 27 | def forward(self, x): 28 | out = F.relu(self.bn1(self.conv1(x))) 29 | out = self.bn2(self.conv2(out)) 30 | out += self.shortcut(x) 31 | preact = out 32 | out = F.relu(out) 33 | if self.is_last: 34 | return out, preact 35 | else: 36 | return out 37 | 38 | 39 | class Bottleneck(nn.Module): 40 | expansion = 4 41 | 42 | def __init__(self, in_planes, planes, stride=1, is_last=False): 43 | super(Bottleneck, self).__init__() 44 | self.is_last = is_last 45 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 46 | self.bn1 = nn.BatchNorm2d(planes) 47 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 48 | self.bn2 = nn.BatchNorm2d(planes) 49 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 50 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 51 | 52 | self.shortcut = nn.Sequential() 53 | if stride != 1 or in_planes != self.expansion * planes: 54 | self.shortcut = nn.Sequential( 55 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, 
bias=False), 56 | nn.BatchNorm2d(self.expansion * planes) 57 | ) 58 | 59 | def forward(self, x): 60 | out = F.relu(self.bn1(self.conv1(x))) 61 | out = F.relu(self.bn2(self.conv2(out))) 62 | out = self.bn3(self.conv3(out)) 63 | out += self.shortcut(x) 64 | preact = out 65 | out = F.relu(out) 66 | if self.is_last: 67 | return out, preact 68 | else: 69 | return out 70 | 71 | 72 | class ResNet(nn.Module): 73 | def __init__(self, block, num_blocks, in_channel=3, zero_init_residual=False): 74 | super(ResNet, self).__init__() 75 | self.in_planes = 64 76 | 77 | self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1, 78 | bias=False) 79 | self.bn1 = nn.BatchNorm2d(64) 80 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 81 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 82 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 83 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 84 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 85 | 86 | for m in self.modules(): 87 | if isinstance(m, nn.Conv2d): 88 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 89 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 90 | nn.init.constant_(m.weight, 1) 91 | nn.init.constant_(m.bias, 0) 92 | 93 | # Zero-initialize the last BN in each residual branch, 94 | # so that the residual branch starts with zeros, and each residual block behaves 95 | # like an identity. This improves the model by 0.2~0.3% according to: 96 | # https://arxiv.org/abs/1706.02677 97 | if zero_init_residual: 98 | for m in self.modules(): 99 | if isinstance(m, Bottleneck): 100 | nn.init.constant_(m.bn3.weight, 0) 101 | elif isinstance(m, BasicBlock): 102 | nn.init.constant_(m.bn2.weight, 0) 103 | 104 | def _make_layer(self, block, planes, num_blocks, stride): 105 | strides = [stride] + [1] * (num_blocks - 1) 106 | layers = [] 107 | for i in range(num_blocks): 108 | stride = strides[i] 109 | layers.append(block(self.in_planes, planes, stride)) 110 | self.in_planes = planes * block.expansion 111 | return nn.Sequential(*layers) 112 | 113 | def forward(self, x): 114 | out = F.relu(self.bn1(self.conv1(x))) 115 | out = self.layer1(out) 116 | out = self.layer2(out) 117 | out = self.layer3(out) 118 | out = self.layer4(out) 119 | out = self.avgpool(out) 120 | out = torch.flatten(out, 1) 121 | return out 122 | 123 | 124 | def resnet18(**kwargs): 125 | return {'backbone': ResNet(BasicBlock, [2, 2, 2, 2], **kwargs), 'dim': 512} 126 | -------------------------------------------------------------------------------- /scan/models/resnet_stl.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is based on the Torchvision repository, which was licensed under the BSD 3-Clause. 
3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class BasicBlock(nn.Module): 10 | expansion = 1 11 | 12 | def __init__(self, in_planes, planes, stride=1, is_last=False): 13 | super(BasicBlock, self).__init__() 14 | self.is_last = is_last 15 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 16 | self.bn1 = nn.BatchNorm2d(planes) 17 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 18 | self.bn2 = nn.BatchNorm2d(planes) 19 | 20 | self.shortcut = nn.Sequential() 21 | if stride != 1 or in_planes != self.expansion * planes: 22 | self.shortcut = nn.Sequential( 23 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 24 | nn.BatchNorm2d(self.expansion * planes) 25 | ) 26 | 27 | def forward(self, x): 28 | out = F.relu(self.bn1(self.conv1(x))) 29 | out = self.bn2(self.conv2(out)) 30 | out += self.shortcut(x) 31 | preact = out 32 | out = F.relu(out) 33 | if self.is_last: 34 | return out, preact 35 | else: 36 | return out 37 | 38 | 39 | class Bottleneck(nn.Module): 40 | expansion = 4 41 | 42 | def __init__(self, in_planes, planes, stride=1, is_last=False): 43 | super(Bottleneck, self).__init__() 44 | self.is_last = is_last 45 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 46 | self.bn1 = nn.BatchNorm2d(planes) 47 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 48 | self.bn2 = nn.BatchNorm2d(planes) 49 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 50 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 51 | 52 | self.shortcut = nn.Sequential() 53 | if stride != 1 or in_planes != self.expansion * planes: 54 | self.shortcut = nn.Sequential( 55 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 56 | nn.BatchNorm2d(self.expansion * planes) 57 | ) 58 | 59 | def forward(self, x): 60 | out = F.relu(self.bn1(self.conv1(x))) 61 | out = F.relu(self.bn2(self.conv2(out))) 62 | out = self.bn3(self.conv3(out)) 63 | out += self.shortcut(x) 64 | preact = out 65 | out = F.relu(out) 66 | if self.is_last: 67 | return out, preact 68 | else: 69 | return out 70 | 71 | 72 | class ResNet(nn.Module): 73 | def __init__(self, block, num_blocks, in_channel=3, zero_init_residual=False): 74 | super(ResNet, self).__init__() 75 | self.in_planes = 64 76 | 77 | self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1, 78 | bias=False) 79 | self.bn1 = nn.BatchNorm2d(64) 80 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, padding=1) 81 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 82 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 83 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 84 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 85 | self.avgpool = nn.AvgPool2d(7, stride=1) 86 | 87 | for m in self.modules(): 88 | if isinstance(m, nn.Conv2d): 89 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 90 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 91 | nn.init.constant_(m.weight, 1) 92 | nn.init.constant_(m.bias, 0) 93 | 94 | # Zero-initialize the last BN in each residual branch, 95 | # so that the residual branch starts with zeros, and each residual block behaves 96 | # like an identity. 
This improves the model by 0.2~0.3% according to: 97 | # https://arxiv.org/abs/1706.02677 98 | if zero_init_residual: 99 | for m in self.modules(): 100 | if isinstance(m, Bottleneck): 101 | nn.init.constant_(m.bn3.weight, 0) 102 | elif isinstance(m, BasicBlock): 103 | nn.init.constant_(m.bn2.weight, 0) 104 | 105 | def _make_layer(self, block, planes, num_blocks, stride): 106 | strides = [stride] + [1] * (num_blocks - 1) 107 | layers = [] 108 | for i in range(num_blocks): 109 | stride = strides[i] 110 | layers.append(block(self.in_planes, planes, stride)) 111 | self.in_planes = planes * block.expansion 112 | return nn.Sequential(*layers) 113 | 114 | def forward(self, x): 115 | out = self.maxpool(F.relu(self.bn1(self.conv1(x)))) 116 | out = self.layer1(out) 117 | out = self.layer2(out) 118 | out = self.layer3(out) 119 | out = self.layer4(out) 120 | out = self.avgpool(out) 121 | out = torch.flatten(out, 1) 122 | return out 123 | 124 | 125 | def resnet18(**kwargs): 126 | return {'backbone': ResNet(BasicBlock, [2, 2, 2, 2], **kwargs), 'dim': 512} 127 | -------------------------------------------------------------------------------- /scan/models/resnet_tinyimagenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is based on the Torchvision repository, which was licensed under the BSD 3-Clause. 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class BasicBlock(nn.Module): 10 | expansion = 1 11 | 12 | def __init__(self, in_planes, planes, stride=1, is_last=False): 13 | super(BasicBlock, self).__init__() 14 | self.is_last = is_last 15 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 16 | self.bn1 = nn.BatchNorm2d(planes) 17 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 18 | self.bn2 = nn.BatchNorm2d(planes) 19 | 20 | self.shortcut = nn.Sequential() 21 | if stride != 1 or in_planes != self.expansion * planes: 22 | self.shortcut = nn.Sequential( 23 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 24 | nn.BatchNorm2d(self.expansion * planes) 25 | ) 26 | 27 | def forward(self, x): 28 | out = F.relu(self.bn1(self.conv1(x))) 29 | out = self.bn2(self.conv2(out)) 30 | out += self.shortcut(x) 31 | preact = out 32 | out = F.relu(out) 33 | if self.is_last: 34 | return out, preact 35 | else: 36 | return out 37 | 38 | 39 | class Bottleneck(nn.Module): 40 | expansion = 4 41 | 42 | def __init__(self, in_planes, planes, stride=1, is_last=False): 43 | super(Bottleneck, self).__init__() 44 | self.is_last = is_last 45 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 46 | self.bn1 = nn.BatchNorm2d(planes) 47 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 48 | self.bn2 = nn.BatchNorm2d(planes) 49 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 50 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 51 | 52 | self.shortcut = nn.Sequential() 53 | if stride != 1 or in_planes != self.expansion * planes: 54 | self.shortcut = nn.Sequential( 55 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 56 | nn.BatchNorm2d(self.expansion * planes) 57 | ) 58 | 59 | def forward(self, x): 60 | out = F.relu(self.bn1(self.conv1(x))) 61 | out = F.relu(self.bn2(self.conv2(out))) 62 | out = self.bn3(self.conv3(out)) 63 | out += self.shortcut(x) 64 | preact = out 
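# Note: `preact` keeps the pre-activation output around; a block constructed with
# is_last=True returns it alongside the activated output (mirroring BasicBlock above),
# so callers can read features before the final ReLU.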
65 | out = F.relu(out) 66 | if self.is_last: 67 | return out, preact 68 | else: 69 | return out 70 | 71 | 72 | class ResNet(nn.Module): 73 | def __init__(self, block, num_blocks, in_channel=3, zero_init_residual=False): 74 | super(ResNet, self).__init__() 75 | self.in_planes = 64 76 | 77 | self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1, 78 | bias=False) 79 | self.bn1 = nn.BatchNorm2d(64) 80 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=2) 81 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 82 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 83 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 84 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 85 | 86 | for m in self.modules(): 87 | if isinstance(m, nn.Conv2d): 88 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 89 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 90 | nn.init.constant_(m.weight, 1) 91 | nn.init.constant_(m.bias, 0) 92 | 93 | # Zero-initialize the last BN in each residual branch, 94 | # so that the residual branch starts with zeros, and each residual block behaves 95 | # like an identity. This improves the model by 0.2~0.3% according to: 96 | # https://arxiv.org/abs/1706.02677 97 | if zero_init_residual: 98 | for m in self.modules(): 99 | if isinstance(m, Bottleneck): 100 | nn.init.constant_(m.bn3.weight, 0) 101 | elif isinstance(m, BasicBlock): 102 | nn.init.constant_(m.bn2.weight, 0) 103 | 104 | def _make_layer(self, block, planes, num_blocks, stride): 105 | strides = [stride] + [1] * (num_blocks - 1) 106 | layers = [] 107 | for i in range(num_blocks): 108 | stride = strides[i] 109 | layers.append(block(self.in_planes, planes, stride)) 110 | self.in_planes = planes * block.expansion 111 | return nn.Sequential(*layers) 112 | 113 | def forward(self, x): 114 | out = F.relu(self.bn1(self.conv1(x))) 115 | out = self.layer1(out) 116 | out = self.layer2(out) 117 | out = self.layer3(out) 118 | out = self.layer4(out) 119 | out = self.avgpool(out) 120 | out = torch.flatten(out, 1) 121 | return out 122 | 123 | 124 | def resnet18(**kwargs): 125 | return {'backbone': ResNet(BasicBlock, [2, 2, 2, 2], **kwargs), 'dim': 512} 126 | -------------------------------------------------------------------------------- /scan/requirements.txt: -------------------------------------------------------------------------------- 1 | """ This file contains a list of packages and their versions that were used to produce the results. 
""" 2 | - _libgcc_mutex=0.1=main 3 | - blas=1.0=mkl 4 | - bzip2=1.0.8=h7b6447c_0 5 | - ca-certificates=2020.1.1=0 6 | - cairo=1.14.12=h8948797_3 7 | - certifi=2020.4.5.1=py37_0 8 | - cffi=1.14.0=py37h2e261b9_0 9 | - cmake=3.14.0=h52cb24c_0 10 | - cudatoolkit=10.0.130=0 11 | - cycler=0.10.0=py37_0 12 | - dbus=1.13.12=h746ee38_0 13 | - easydict=1.9=py_0 14 | - expat=2.2.6=he6710b0_0 15 | - faiss-gpu=1.6.3=py37h1a5d453_0 16 | - ffmpeg=4.0=hcdf2ecd_0 17 | - fontconfig=2.13.0=h9420a91_0 18 | - freeglut=3.0.0=hf484d3e_5 19 | - freetype=2.9.1=h8a8886c_1 20 | - glib=2.63.1=h5a9c865_0 21 | - graphite2=1.3.13=h23475e2_0 22 | - gst-plugins-base=1.14.0=hbbd80ab_1 23 | - gstreamer=1.14.0=hb453b48_1 24 | - h5py=2.8.0=py37h989c5e5_3 25 | - harfbuzz=1.8.8=hffaf4a1_0 26 | - hdf5=1.10.2=hba1933b_1 27 | - icu=58.2=h9c2bf20_1 28 | - imageio=2.8.0=py_0 29 | - intel-openmp=2020.0=166 30 | - jasper=2.0.14=h07fcdf6_1 31 | - joblib=0.14.1=py_0 32 | - jpeg=9b=h024ee3a_2 33 | - kiwisolver=1.1.0=py37he6710b0_0 34 | - krb5=1.17.1=h173b8e3_0 35 | - ld_impl_linux-64=2.33.1=h53a641e_7 36 | - libcurl=7.69.1=h20c2e04_0 37 | - libedit=3.1.20181209=hc058e9b_0 38 | - libffi=3.2.1=hd88cf55_4 39 | - libgcc-ng=9.1.0=hdf63c60_0 40 | - libgfortran-ng=7.3.0=hdf63c60_0 41 | - libglu=9.0.0=hf484d3e_1 42 | - libopencv=3.4.2=hb342d67_1 43 | - libopus=1.3.1=h7b6447c_0 44 | - libpng=1.6.37=hbc83047_0 45 | - libprotobuf=3.11.4=hd408876_0 46 | - libssh2=1.9.0=h1ba5d50_1 47 | - libstdcxx-ng=9.1.0=hdf63c60_0 48 | - libtiff=4.1.0=h2733197_0 49 | - libuuid=1.0.3=h1bed415_2 50 | - libvpx=1.7.0=h439df22_0 51 | - libxcb=1.13=h1bed415_1 52 | - libxml2=2.9.9=hea5a465_1 53 | - matplotlib=3.1.3=py37_0 54 | - matplotlib-base=3.1.3=py37hef1b27d_0 55 | - mkl=2020.0=166 56 | - mkl-service=2.3.0=py37he904b0f_0 57 | - mkl_fft=1.0.15=py37ha843d7b_0 58 | - mkl_random=1.1.0=py37hd6b4f25_0 59 | - ncurses=6.2=he6710b0_0 60 | - ninja=1.9.0=py37hfd86e86_0 61 | - numpy=1.18.1=py37h4f9e942_0 62 | - numpy-base=1.18.1=py37hde5b4d6_1 63 | - olefile=0.46=py_0 64 | - opencv=3.4.2=py37h6fd60c2_1 65 | - openssl=1.1.1g=h7b6447c_0 66 | - pcre=8.43=he6710b0_0 67 | - pillow=7.0.0=py37hb39fc2d_0 68 | - pip=20.0.2=py37_1 69 | - pixman=0.38.0=h7b6447c_0 70 | - protobuf=3.11.4=py37he6710b0_0 71 | - py-opencv=3.4.2=py37hb342d67_1 72 | - pycparser=2.20=py_0 73 | - pyparsing=2.4.6=py_0 74 | - pyqt=5.9.2=py37h05f1152_2 75 | - python=3.7.7=hcf32534_0_cpython 76 | - python-dateutil=2.8.1=py_0 77 | - pytorch=1.4.0=py3.7_cuda10.0.130_cudnn7.6.3_0 78 | - pyyaml=5.3.1=py37h7b6447c_0 79 | - qt=5.9.7=h5867ecd_1 80 | - readline=8.0=h7b6447c_0 81 | - rhash=1.3.8=h1ba5d50_0 82 | - scikit-learn=0.22.1=py37hd81dba3_0 83 | - scipy=1.4.1=py37h0b6359f_0 84 | - setuptools=46.1.3=py37_0 85 | - sip=4.19.8=py37hf484d3e_0 86 | - six=1.14.0=py37_0 87 | - sqlite=3.31.1=h7b6447c_0 88 | - swig=3.0.12=h38cdd7d_3 89 | - tensorboardx=2.0=py_0 90 | - termcolor=1.1.0=py37_1 91 | - tk=8.6.8=hbc83047_0 92 | - torchvision=0.5.0=py37_cu100 93 | - tornado=6.0.4=py37h7b6447c_1 94 | - typing=3.6.4=py37_0 95 | - wheel=0.34.2=py37_0 96 | - xz=5.2.4=h14c3975_4 97 | - yaml=0.1.7=had09818_2 98 | - zlib=1.2.11=h7b6447c_3 99 | - zstd=1.3.7=h0b5b093_0 100 | - pip: 101 | - blis==0.4.1 102 | - catalogue==1.0.0 103 | - chardet==3.0.4 104 | - cymem==2.0.3 105 | - en-core-web-sm==2.2.5 106 | - idna==2.9 107 | - importlib-metadata==1.6.0 108 | - murmurhash==1.0.2 109 | - plac==1.1.3 110 | - preshed==3.0.2 111 | - requests==2.23.0 112 | - spacy==2.2.4 113 | - srsly==1.0.2 114 | - thinc==7.4.0 115 | - tqdm==4.45.0 116 | - 
urllib3==1.25.8 117 | - wasabi==0.6.0 118 | - zipp==3.1.0 119 | -------------------------------------------------------------------------------- /scan/selflabel.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import argparse 6 | import os 7 | import torch 8 | 9 | from utils.config import create_config 10 | from utils.common_config import get_train_dataset, get_train_transformations,\ 11 | get_val_dataset, get_val_transformations,\ 12 | get_train_dataloader, get_val_dataloader,\ 13 | get_optimizer, get_model, adjust_learning_rate,\ 14 | get_criterion 15 | from utils.ema import EMA 16 | from utils.evaluate_utils import get_predictions, hungarian_evaluate 17 | from utils.train_utils import selflabel_train 18 | from termcolor import colored 19 | 20 | # Parser 21 | parser = argparse.ArgumentParser(description='Self-labeling') 22 | parser.add_argument('--config_env', 23 | help='Config file for the environment') 24 | parser.add_argument('--config_exp', 25 | help='Config file for the experiment') 26 | args = parser.parse_args() 27 | 28 | def main(): 29 | # Retrieve config file 30 | p = create_config(args.config_env, args.config_exp) 31 | print(colored(p, 'red')) 32 | 33 | # Get model 34 | print(colored('Retrieve model', 'blue')) 35 | model = get_model(p, p['scan_model']) 36 | print(model) 37 | model = torch.nn.DataParallel(model) 38 | model = model.cuda() 39 | 40 | # Get criterion 41 | print(colored('Get loss', 'blue')) 42 | criterion = get_criterion(p) 43 | criterion.cuda() 44 | print(criterion) 45 | 46 | # CUDNN 47 | print(colored('Set CuDNN benchmark', 'blue')) 48 | torch.backends.cudnn.benchmark = True 49 | 50 | # Optimizer 51 | print(colored('Retrieve optimizer', 'blue')) 52 | optimizer = get_optimizer(p, model) 53 | print(optimizer) 54 | 55 | # Dataset 56 | print(colored('Retrieve dataset', 'blue')) 57 | 58 | # Transforms 59 | strong_transforms = get_train_transformations(p) 60 | val_transforms = get_val_transformations(p) 61 | train_dataset = get_train_dataset(p, {'standard': val_transforms, 'augment': strong_transforms}, 62 | split='train', to_augmented_dataset=True) 63 | train_dataloader = get_train_dataloader(p, train_dataset) 64 | val_dataset = get_val_dataset(p, val_transforms) 65 | val_dataloader = get_val_dataloader(p, val_dataset) 66 | print(colored('Train samples %d - Val samples %d' %(len(train_dataset), len(val_dataset)), 'yellow')) 67 | 68 | # Checkpoint 69 | if os.path.exists(p['selflabel_checkpoint']): 70 | print(colored('Restart from checkpoint {}'.format(p['selflabel_checkpoint']), 'blue')) 71 | checkpoint = torch.load(p['selflabel_checkpoint'], map_location='cpu') 72 | model.load_state_dict(checkpoint['model']) 73 | optimizer.load_state_dict(checkpoint['optimizer']) 74 | start_epoch = checkpoint['epoch'] 75 | 76 | else: 77 | print(colored('No checkpoint file at {}'.format(p['selflabel_checkpoint']), 'blue')) 78 | start_epoch = 0 79 | 80 | # EMA 81 | if p['use_ema']: 82 | ema = EMA(model, alpha=p['ema_alpha']) 83 | else: 84 | ema = None 85 | 86 | # Main loop 87 | print(colored('Starting main loop', 'blue')) 88 | 89 | for epoch in range(start_epoch, p['epochs']): 90 | print(colored('Epoch %d/%d' %(epoch+1, p['epochs']), 'yellow')) 91 | print(colored('-'*10, 'yellow')) 92 | 93 | # Adjust lr 94 | lr = adjust_learning_rate(p, optimizer, epoch) 95 | print('Adjusted learning rate 
to {:.5f}'.format(lr)) 96 | 97 | # Perform self-labeling 98 | print('Train ...') 99 | selflabel_train(train_dataloader, model, criterion, optimizer, epoch, ema=ema) 100 | 101 | # Evaluate (To monitor progress - Not for validation) 102 | print('Evaluate ...') 103 | predictions = get_predictions(p, val_dataloader, model) 104 | clustering_stats = hungarian_evaluate(0, predictions, compute_confusion_matrix=False) 105 | print(clustering_stats) 106 | 107 | # Checkpoint 108 | print('Checkpoint ...') 109 | torch.save({'optimizer': optimizer.state_dict(), 'model': model.state_dict(), 110 | 'epoch': epoch + 1}, p['selflabel_checkpoint']) 111 | torch.save(model.module.state_dict(), p['selflabel_model']) 112 | 113 | # Evaluate and save the final model 114 | print(colored('Evaluate model at the end', 'blue')) 115 | predictions = get_predictions(p, val_dataloader, model) 116 | clustering_stats = hungarian_evaluate(0, predictions, 117 | class_names=val_dataset.classes, 118 | compute_confusion_matrix=True, 119 | confusion_matrix_file=os.path.join(p['selflabel_dir'], 'confusion_matrix.png')) 120 | print(clustering_stats) 121 | torch.save(model.module.state_dict(), p['selflabel_model']) 122 | 123 | 124 | if __name__ == "__main__": 125 | main() 126 | -------------------------------------------------------------------------------- /scan/tutorial_nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import argparse 6 | import os 7 | import numpy as np 8 | import torch 9 | 10 | from utils.config import create_config 11 | from utils.common_config import get_model, get_train_dataset, \ 12 | get_val_dataset, \ 13 | get_val_dataloader, \ 14 | get_val_transformations \ 15 | 16 | from utils.memory import MemoryBank 17 | from utils.train_utils import simclr_train 18 | from utils.utils import fill_memory_bank 19 | from termcolor import colored 20 | 21 | # Parser 22 | parser = argparse.ArgumentParser(description='Eval_nn') 23 | parser.add_argument('--config_env', 24 | help='Config file for the environment') 25 | parser.add_argument('--config_exp', 26 | help='Config file for the experiment') 27 | args = parser.parse_args() 28 | 29 | def main(): 30 | 31 | # Retrieve config file 32 | p = create_config(args.config_env, args.config_exp) 33 | print(colored(p, 'red')) 34 | 35 | # Model 36 | print(colored('Retrieve model', 'blue')) 37 | model = get_model(p) 38 | print('Model is {}'.format(model.__class__.__name__)) 39 | print('Model parameters: {:.2f}M'.format(sum(p.numel() for p in model.parameters()) / 1e6)) 40 | print(model) 41 | model = model.cuda() 42 | 43 | # CUDNN 44 | print(colored('Set CuDNN benchmark', 'blue')) 45 | torch.backends.cudnn.benchmark = True 46 | 47 | # Dataset 48 | val_transforms = get_val_transformations(p) 49 | print('Validation transforms:', val_transforms) 50 | val_dataset = get_val_dataset(p, val_transforms) 51 | val_dataloader = get_val_dataloader(p, val_dataset) 52 | print('Dataset contains {} val samples'.format(len(val_dataset))) 53 | 54 | # Memory Bank 55 | print(colored('Build MemoryBank', 'blue')) 56 | base_dataset = get_train_dataset(p, val_transforms, split='train') # Dataset w/o augs for knn eval 57 | base_dataloader = get_val_dataloader(p, base_dataset) 58 | memory_bank_base = MemoryBank(len(base_dataset), 59 | p['model_kwargs']['features_dim'], 60 | p['num_classes'], p['criterion_kwargs']['temperature']) 61 | 
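# Two banks are kept: `memory_bank_base` indexes the un-augmented train split and is
# used below to mine the top-k neighbors that feed the SCAN loss, while
# `memory_bank_val` only monitors neighbor quality on the validation split.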
memory_bank_base.cuda() 62 | memory_bank_val = MemoryBank(len(val_dataset), 63 | p['model_kwargs']['features_dim'], 64 | p['num_classes'], p['criterion_kwargs']['temperature']) 65 | memory_bank_val.cuda() 66 | 67 | # Checkpoint 68 | assert os.path.exists(p['pretext_checkpoint']) 69 | print(colored('Restart from checkpoint {}'.format(p['pretext_checkpoint']), 'blue')) 70 | checkpoint = torch.load(p['pretext_checkpoint'], map_location='cpu') 71 | model.load_state_dict(checkpoint) 72 | model.cuda() 73 | 74 | # Save model 75 | torch.save(model.state_dict(), p['pretext_model']) 76 | 77 | # Mine the topk nearest neighbors at the very end (Train) 78 | # These will be served as input to the SCAN loss. 79 | print(colored('Fill memory bank for mining the nearest neighbors (train) ...', 'blue')) 80 | fill_memory_bank(base_dataloader, model, memory_bank_base) 81 | topk = 20 82 | print('Mine the nearest neighbors (Top-%d)' %(topk)) 83 | indices, acc = memory_bank_base.mine_nearest_neighbors(topk) 84 | print('Accuracy of top-%d nearest neighbors on train set is %.2f' %(topk, 100*acc)) 85 | np.save(p['topk_neighbors_train_path'], indices) 86 | 87 | # Mine the topk nearest neighbors at the very end (Val) 88 | # These will be used for validation. 89 | print(colored('Fill memory bank for mining the nearest neighbors (val) ...', 'blue')) 90 | fill_memory_bank(val_dataloader, model, memory_bank_val) 91 | topk = 5 92 | print('Mine the nearest neighbors (Top-%d)' %(topk)) 93 | indices, acc = memory_bank_val.mine_nearest_neighbors(topk) 94 | print('Accuracy of top-%d nearest neighbors on val set is %.2f' %(topk, 100*acc)) 95 | np.save(p['topk_neighbors_val_path'], indices) 96 | 97 | 98 | if __name__ == '__main__': 99 | main() 100 | -------------------------------------------------------------------------------- /scan/utils/collate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import torch 6 | import numpy as np 7 | import collections 8 | from torch._six import string_classes 9 | 10 | 11 | """ Custom collate function """ 12 | def collate_custom(batch): 13 | if isinstance(batch[0], np.int64): 14 | return np.stack(batch, 0) 15 | 16 | if isinstance(batch[0], torch.Tensor): 17 | return torch.stack(batch, 0) 18 | 19 | elif isinstance(batch[0], np.ndarray): 20 | return np.stack(batch, 0) 21 | 22 | elif isinstance(batch[0], int): 23 | return torch.LongTensor(batch) 24 | 25 | elif isinstance(batch[0], float): 26 | return torch.FloatTensor(batch) 27 | 28 | elif isinstance(batch[0], string_classes): 29 | return batch 30 | 31 | elif isinstance(batch[0], collections.Mapping): 32 | batch_modified = {key: collate_custom([d[key] for d in batch]) for key in batch[0] if key.find('idx') < 0} 33 | return batch_modified 34 | 35 | elif isinstance(batch[0], collections.Sequence): 36 | transposed = zip(*batch) 37 | return [collate_custom(samples) for samples in transposed] 38 | 39 | raise TypeError(('Type is {}'.format(type(batch[0])))) 40 | -------------------------------------------------------------------------------- /scan/utils/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import os 6 | import yaml 7 | from easydict 
import EasyDict 8 | from utils.utils import mkdir_if_missing 9 | 10 | def create_config(config_file_env, config_file_exp, seed=0, num_clusters=None):  # seed defaults to 0 so legacy entry points (moco.py, selflabel.py, tutorial_nn.py) that call create_config(env, exp) still run 11 | # Config for environment path 12 | with open(config_file_env, 'r') as stream: 13 | root_dir = yaml.safe_load(stream)['root_dir'] 14 | 15 | with open(config_file_exp, 'r') as stream: 16 | config = yaml.safe_load(stream) 17 | 18 | config['seed'] = seed 19 | if num_clusters is not None: 20 | config['num_classes'] = num_clusters 21 | 22 | cfg = EasyDict() 23 | 24 | # Copy 25 | for k, v in config.items(): 26 | cfg[k] = v 27 | 28 | # Set paths for pretext task (These directories are needed in every stage) 29 | base_dir = os.path.join(root_dir, cfg['train_db_name']) 30 | pretext_dir = os.path.join(base_dir, 'pretext') 31 | mkdir_if_missing(base_dir) 32 | mkdir_if_missing(pretext_dir) 33 | cfg['pretext_dir'] = pretext_dir 34 | cfg['pretext_checkpoint'] = os.path.join(pretext_dir, f'checkpoint_seed{seed}.pth.tar') 35 | cfg['pretext_model'] = os.path.join(pretext_dir, f'model_seed{seed}.pth.tar') 36 | cfg['pretext_features'] = os.path.join(pretext_dir, f'features_seed{seed}.npy') 37 | cfg['topk_neighbors_train_path'] = os.path.join(pretext_dir, f'topk-train-neighbors_seed{seed}.npy') 38 | cfg['topk_neighbors_val_path'] = os.path.join(pretext_dir, f'topk-val-neighbors_seed{seed}.npy') 39 | 40 | # If we perform clustering or self-labeling step we need additional paths. 41 | # We also include a run identifier to support multiple runs w/ same hyperparams. 42 | if cfg['setup'] in ['scan', 'selflabel']: 43 | base_dir = os.path.join(root_dir, cfg['train_db_name']) 44 | scan_dir = os.path.join(base_dir, 'scan') 45 | selflabel_dir = os.path.join(base_dir, 'selflabel') 46 | mkdir_if_missing(base_dir) 47 | mkdir_if_missing(scan_dir) 48 | mkdir_if_missing(selflabel_dir) 49 | cfg['scan_dir'] = scan_dir 50 | cfg['scan_checkpoint'] = os.path.join(scan_dir, f'checkpoint_seed{seed}_clusters{num_clusters}.pth.tar') 51 | cfg['scan_model'] = os.path.join(scan_dir, f'model_seed{seed}_clusters{num_clusters}.pth.tar') 52 | cfg['scan_features'] = os.path.join(scan_dir, f'features_seed{seed}_clusters{num_clusters}.npy') 53 | cfg['selflabel_dir'] = selflabel_dir 54 | cfg['selflabel_checkpoint'] = os.path.join(selflabel_dir, f'checkpoint_seed{seed}_clusters{num_clusters}.pth.tar') 55 | cfg['selflabel_model'] = os.path.join(selflabel_dir, f'model_seed{seed}_clusters{num_clusters}.pth.tar') 56 | 57 | return cfg 58 | -------------------------------------------------------------------------------- /scan/utils/ema.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | 6 | class EMA(object): 7 | def __init__(self, model, alpha=0.999): 8 | self.shadow = {k: v.clone().detach() for k, v in model.state_dict().items()} 9 | self.param_keys = [k for k, _ in model.named_parameters()] 10 | self.alpha = alpha 11 | 12 | def update_params(self, model): 13 | state = model.state_dict() 14 | for name in self.param_keys: 15 | self.shadow[name].copy_(self.alpha * self.shadow[name] + (1 - self.alpha) * state[name]) 16 | 17 | def apply_shadow(self, model): 18 | model.load_state_dict(self.shadow, strict=True) 19 | -------------------------------------------------------------------------------- /scan/utils/memory.py: -------------------------------------------------------------------------------- 1 | """ 2 | 
Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import numpy as np 6 | import torch 7 | 8 | 9 | class MemoryBank(object): 10 | def __init__(self, n, dim, num_classes, temperature, feature_dim=512): 11 | self.n = n 12 | self.dim = dim 13 | self.features = torch.FloatTensor(self.n, self.dim) 14 | self.pre_lasts = torch.FloatTensor(self.n, feature_dim) 15 | self.targets = torch.LongTensor(self.n) 16 | self.ptr = 0 17 | self.device = 'cpu' 18 | self.K = 100 19 | self.temperature = temperature 20 | self.C = num_classes 21 | 22 | def weighted_knn(self, predictions): 23 | # perform weighted knn 24 | retrieval_one_hot = torch.zeros(self.K, self.C).to(self.device) 25 | batchSize = predictions.shape[0] 26 | correlation = torch.matmul(predictions, self.features.t()) 27 | yd, yi = correlation.topk(self.K, dim=1, largest=True, sorted=True) 28 | candidates = self.targets.view(1,-1).expand(batchSize, -1) 29 | retrieval = torch.gather(candidates, 1, yi) 30 | retrieval_one_hot.resize_(batchSize * self.K, self.C).zero_() 31 | retrieval_one_hot.scatter_(1, retrieval.view(-1, 1), 1) 32 | yd_transform = yd.clone().div_(self.temperature).exp_() 33 | probs = torch.sum(torch.mul(retrieval_one_hot.view(batchSize, -1 , self.C), 34 | yd_transform.view(batchSize, -1, 1)), 1) 35 | _, class_preds = probs.sort(1, True) 36 | class_pred = class_preds[:, 0] 37 | 38 | return class_pred 39 | 40 | def knn(self, predictions): 41 | # perform knn 42 | correlation = torch.matmul(predictions, self.features.t()) 43 | sample_pred = torch.argmax(correlation, dim=1) 44 | class_pred = torch.index_select(self.targets, 0, sample_pred) 45 | return class_pred 46 | 47 | def mine_nearest_neighbors(self, topk, calculate_accuracy=True): 48 | # mine the topk nearest neighbors for every sample 49 | import faiss 50 | features = self.features.cpu().numpy() 51 | n, dim = features.shape[0], features.shape[1] 52 | index = faiss.IndexFlatIP(dim) 53 | index = faiss.index_cpu_to_all_gpus(index) 54 | index.add(features) 55 | distances, indices = index.search(features, topk+1) # Sample itself is included 56 | 57 | # evaluate 58 | if calculate_accuracy: 59 | targets = self.targets.cpu().numpy() 60 | neighbor_targets = np.take(targets, indices[:,1:], axis=0) # Exclude sample itself for eval 61 | anchor_targets = np.repeat(targets.reshape(-1,1), topk, axis=1) 62 | accuracy = np.mean(neighbor_targets == anchor_targets) 63 | return indices, accuracy 64 | 65 | else: 66 | return indices 67 | 68 | def reset(self): 69 | self.ptr = 0 70 | 71 | def update(self, features, pre_last, targets): 72 | b = features.size(0) 73 | 74 | assert(b + self.ptr <= self.n) 75 | 76 | self.features[self.ptr:self.ptr+b].copy_(features.detach()) 77 | self.pre_lasts[self.ptr:self.ptr+b].copy_(pre_last.detach()) 78 | self.targets[self.ptr:self.ptr+b].copy_(targets.detach()) 79 | self.ptr += b 80 | 81 | def to(self, device): 82 | self.features = self.features.to(device) 83 | self.pre_lasts = self.pre_lasts.to(device) 84 | self.targets = self.targets.to(device) 85 | self.device = device 86 | 87 | def cpu(self): 88 | self.to('cpu') 89 | 90 | def cuda(self): 91 | self.to('cuda:0') 92 | -------------------------------------------------------------------------------- /scan/utils/mypath.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license 
(https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import os 6 | 7 | 8 | class MyPath(object): 9 | @staticmethod 10 | def db_root_dir(database=''): 11 | db_names = {'cifar-10', 'stl-10', 'cifar-100', 'imagenet', 'imagenet_50', 'imagenet_100', 'imagenet_200'} 12 | assert(database in db_names) 13 | 14 | if database == 'cifar-10': 15 | return './datasets/cifar-10/' 16 | 17 | elif database == 'cifar-100': 18 | return './datasets/cifar-100/' 19 | 20 | elif database == 'stl-10': 21 | return './datasets/stl-10/' 22 | 23 | elif database in ['imagenet', 'imagenet_50', 'imagenet_100', 'imagenet_200']: 24 | return './datasets/imagenet/' 25 | 26 | else: 27 | raise NotImplementedError 28 | -------------------------------------------------------------------------------- /scan/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import torch 6 | import numpy as np 7 | from utils.utils import AverageMeter, ProgressMeter 8 | 9 | 10 | def simclr_train(train_loader, model, criterion, optimizer, epoch): 11 | """ 12 | Train according to the scheme from SimCLR 13 | https://arxiv.org/abs/2002.05709 14 | """ 15 | losses = AverageMeter('Loss', ':.4e') 16 | progress = ProgressMeter(len(train_loader), 17 | [losses], 18 | prefix="Epoch: [{}]".format(epoch)) 19 | 20 | model.train() 21 | 22 | for i, batch in enumerate(train_loader): 23 | images = batch['image'] 24 | images_augmented = batch['image_augmented'] 25 | b, c, h, w = images.size() 26 | input_ = torch.cat([images.unsqueeze(1), images_augmented.unsqueeze(1)], dim=1) 27 | input_ = input_.view(-1, c, h, w) 28 | input_ = input_.cuda(non_blocking=True) 29 | targets = batch['target'].cuda(non_blocking=True) 30 | 31 | output = model(input_).view(b, 2, -1) 32 | loss = criterion(output) 33 | losses.update(loss.item()) 34 | 35 | optimizer.zero_grad() 36 | loss.backward() 37 | optimizer.step() 38 | 39 | if i % 25 == 0: 40 | progress.display(i) 41 | 42 | 43 | def scan_train(train_loader, model, criterion, optimizer, epoch, update_cluster_head_only=False): 44 | """ 45 | Train w/ SCAN-Loss 46 | """ 47 | total_losses = AverageMeter('Total Loss', ':.4e') 48 | consistency_losses = AverageMeter('Consistency Loss', ':.4e') 49 | entropy_losses = AverageMeter('Entropy', ':.4e') 50 | progress = ProgressMeter(len(train_loader), 51 | [total_losses, consistency_losses, entropy_losses], 52 | prefix="Epoch: [{}]".format(epoch)) 53 | 54 | if update_cluster_head_only: 55 | model.eval() # No need to update BN 56 | else: 57 | model.train() # Update BN 58 | 59 | for i, batch in enumerate(train_loader): 60 | # Forward pass 61 | anchors = batch['anchor'].cuda(non_blocking=True) 62 | neighbors = batch['neighbor'].cuda(non_blocking=True) 63 | 64 | if update_cluster_head_only: # Only calculate gradient for backprop of linear layer 65 | with torch.no_grad(): 66 | anchors_features = model(anchors, forward_pass='backbone') 67 | neighbors_features = model(neighbors, forward_pass='backbone') 68 | anchors_output = model(anchors_features, forward_pass='head') 69 | neighbors_output = model(neighbors_features, forward_pass='head') 70 | 71 | else: # Calculate gradient for backprop of complete network 72 | anchors_output = model(anchors) 73 | neighbors_output = model(neighbors) 74 | 75 | # Loss for every head 76 | total_loss, consistency_loss, entropy_loss = [], [], [] 77 | for 
anchors_output_subhead, neighbors_output_subhead in zip(anchors_output, neighbors_output): 78 | total_loss_, consistency_loss_, entropy_loss_ = criterion(anchors_output_subhead, 79 | neighbors_output_subhead) 80 | total_loss.append(total_loss_) 81 | consistency_loss.append(consistency_loss_) 82 | entropy_loss.append(entropy_loss_) 83 | 84 | # Register the mean loss and backprop the total loss to cover all subheads 85 | total_losses.update(np.mean([v.item() for v in total_loss])) 86 | consistency_losses.update(np.mean([v.item() for v in consistency_loss])) 87 | entropy_losses.update(np.mean([v.item() for v in entropy_loss])) 88 | 89 | total_loss = torch.sum(torch.stack(total_loss, dim=0)) 90 | 91 | optimizer.zero_grad() 92 | total_loss.backward() 93 | optimizer.step() 94 | 95 | if i % 25 == 0: 96 | progress.display(i) 97 | 98 | 99 | def selflabel_train(train_loader, model, criterion, optimizer, epoch, ema=None): 100 | """ 101 | Self-labeling based on confident samples 102 | """ 103 | losses = AverageMeter('Loss', ':.4e') 104 | progress = ProgressMeter(len(train_loader), [losses], 105 | prefix="Epoch: [{}]".format(epoch)) 106 | model.train() 107 | 108 | for i, batch in enumerate(train_loader): 109 | images = batch['image'].cuda(non_blocking=True) 110 | images_augmented = batch['image_augmented'].cuda(non_blocking=True) 111 | 112 | with torch.no_grad(): 113 | output = model(images)[0] 114 | output_augmented = model(images_augmented)[0] 115 | 116 | loss = criterion(output, output_augmented) 117 | losses.update(loss.item()) 118 | 119 | optimizer.zero_grad() 120 | loss.backward() 121 | optimizer.step() 122 | 123 | if ema is not None: # Apply EMA to update the weights of the network 124 | ema.update_params(model) 125 | ema.apply_shadow(model) 126 | 127 | if i % 25 == 0: 128 | progress.display(i) 129 | -------------------------------------------------------------------------------- /scan/utils/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import os 6 | import torch 7 | import numpy as np 8 | import errno 9 | 10 | def mkdir_if_missing(directory): 11 | if not os.path.exists(directory): 12 | try: 13 | os.makedirs(directory) 14 | except OSError as e: 15 | if e.errno != errno.EEXIST: 16 | raise 17 | 18 | 19 | class AverageMeter(object): 20 | def __init__(self, name, fmt=':f'): 21 | self.name = name 22 | self.fmt = fmt 23 | self.reset() 24 | 25 | def reset(self): 26 | self.val = 0 27 | self.avg = 0 28 | self.sum = 0 29 | self.count = 0 30 | 31 | def update(self, val, n=1): 32 | self.val = val 33 | self.sum += val * n 34 | self.count += n 35 | self.avg = self.sum / self.count 36 | 37 | def __str__(self): 38 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 39 | return fmtstr.format(**self.__dict__) 40 | 41 | 42 | class ProgressMeter(object): 43 | def __init__(self, num_batches, meters, prefix=""): 44 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 45 | self.meters = meters 46 | self.prefix = prefix 47 | 48 | def display(self, batch): 49 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 50 | entries += [str(meter) for meter in self.meters] 51 | print('\t'.join(entries)) 52 | 53 | def _get_batch_fmtstr(self, num_batches): 54 | num_digits = len(str(num_batches // 1)) 55 | fmt = '{:' + str(num_digits) + 'd}' 56 | return '[' + fmt + '/' + fmt.format(num_batches) + 
']' 57 | 58 | 59 | @torch.no_grad() 60 | def fill_memory_bank(loader, model, memory_bank): 61 | model.eval() 62 | memory_bank.reset() 63 | 64 | for i, batch in enumerate(loader): 65 | images = batch['image'].cuda(non_blocking=True) 66 | targets = batch['target'].cuda(non_blocking=True) 67 | output, pre_last = model(images, return_pre_last=True) 68 | memory_bank.update(output, pre_last, targets) 69 | if i % 100 == 0: 70 | print('Fill Memory Bank [%d/%d]' %(i, len(loader))) 71 | 72 | 73 | def confusion_matrix(predictions, gt, class_names, output_file=None): 74 | # Plot confusion_matrix and store result to output_file 75 | import sklearn.metrics 76 | import matplotlib.pyplot as plt 77 | confusion_matrix = sklearn.metrics.confusion_matrix(gt, predictions) 78 | confusion_matrix = confusion_matrix / np.sum(confusion_matrix, 1, keepdims=True)  # keepdims=True so each row is divided by its own sum (proper row-normalization) 79 | 80 | fig, axes = plt.subplots(1) 81 | plt.imshow(confusion_matrix, cmap='Blues') 82 | axes.set_xticks([i for i in range(len(class_names))]) 83 | axes.set_yticks([i for i in range(len(class_names))]) 84 | axes.set_xticklabels(class_names, ha='right', fontsize=8, rotation=40) 85 | axes.set_yticklabels(class_names, ha='right', fontsize=8) 86 | 87 | for (i, j), z in np.ndenumerate(confusion_matrix): 88 | if i == j: 89 | axes.text(j, i, '%d' %(100*z), ha='center', va='center', color='white', fontsize=6) 90 | else: 91 | pass 92 | 93 | plt.tight_layout() 94 | if output_file is None: 95 | plt.show() 96 | else: 97 | plt.savefig(output_file, dpi=300, bbox_inches='tight') 98 | plt.close() 99 | --------------------------------------------------------------------------------
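For orientation, the following is a minimal sketch of the neighbor-mining step that tutorial_nn.py and moco.py both perform, wiring together ContrastiveModel, MemoryBank and fill_memory_bank. It uses a synthetic ToyDataset as a stand-in for the real dict-style datasets in data/; ToyDataset itself, the batch size, feature dimension, temperature, class count and output filename are illustrative assumptions, and a CUDA device plus faiss-gpu (see requirements.txt) are assumed. Run it from the scan/ directory so the imports resolve.

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

from models.resnet_cifar import resnet18
from models.models import ContrastiveModel
from utils.memory import MemoryBank
from utils.utils import fill_memory_bank


class ToyDataset(Dataset):
    """Stand-in for the dict-style datasets in data/ (keys: 'image', 'target')."""
    def __init__(self, n=256, num_classes=10):
        self.images = torch.randn(n, 3, 32, 32)
        self.targets = torch.randint(0, num_classes, (n,))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i):
        return {'image': self.images[i], 'target': self.targets[i]}


loader = DataLoader(ToyDataset(), batch_size=64)

# 128-d contrastive head on a CIFAR-style ResNet-18 backbone (512-d pre-last features)
model = ContrastiveModel(resnet18(), head='mlp', features_dim=128).cuda()
model.eval()

bank = MemoryBank(len(loader.dataset), 128, num_classes=10, temperature=0.1)
bank.cuda()

# Run the model over the loader and store (feature, pre_last, target) per sample
fill_memory_bank(loader, model, bank)

# Mine each sample's top-20 neighbors with faiss; acc is the fraction of mined
# neighbors that share the anchor's ground-truth label
indices, acc = bank.mine_nearest_neighbors(topk=20)
print('top-20 neighbor label agreement: %.2f%%' % (100 * acc))
np.save('topk-train-neighbors.npy', indices)

In the real pipeline the loader comes from get_train_dataset/get_val_dataloader with validation transforms (no augmentation), and the saved indices are written to cfg['topk_neighbors_train_path'] so scan.py can consume them.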