├── .gitignore
├── 2d_selection_gif.gif
├── LICENSE
├── README.md
├── USAGE.md
├── cifar_selection.png
├── dcom_delta_updating.gif
├── dcom_semi.png
├── dcom_supervised.png
├── deep-al
│   ├── LICENSE
│   ├── README.md
│   ├── configs
│   │   ├── cifar10
│   │   │   ├── al
│   │   │   │   ├── RESNET18.yaml
│   │   │   │   ├── RESNET18_ENS.yaml
│   │   │   │   └── RESNET18_IM.yaml
│   │   │   ├── evaluate
│   │   │   │   └── RESNET18.yaml
│   │   │   └── train
│   │   │       ├── RESNET18.yaml
│   │   │       └── RESNET18_ENS.yaml
│   │   ├── cifar100
│   │   │   └── al
│   │   │       └── RESNET18.yaml
│   │   ├── template.yaml
│   │   └── tinyimagenet
│   │       └── al
│   │           └── RESNET18.yaml
│   ├── docs
│   │   ├── AL_results.png
│   │   └── GETTING_STARTED.md
│   ├── pycls
│   │   ├── __init__.py
│   │   ├── al
│   │   │   ├── ActiveLearning.py
│   │   │   ├── DCoM.py
│   │   │   ├── Sampling.py
│   │   │   ├── __init__.py
│   │   │   ├── prob_cover.py
│   │   │   ├── typiclust.py
│   │   │   └── vaal_util.py
│   │   ├── core
│   │   │   ├── __init__.py
│   │   │   ├── builders.py
│   │   │   ├── config.py
│   │   │   ├── losses.py
│   │   │   ├── net.py
│   │   │   └── optimizer.py
│   │   ├── datasets
│   │   │   ├── __init__.py
│   │   │   ├── augment.py
│   │   │   ├── custom_datasets.py
│   │   │   ├── data.py
│   │   │   ├── imbalanced_cifar.py
│   │   │   ├── randaugment.py
│   │   │   ├── sampler.py
│   │   │   ├── simclr_augment.py
│   │   │   ├── tiny_imagenet.py
│   │   │   └── utils
│   │   │       ├── __init__.py
│   │   │       ├── features.py
│   │   │       ├── gaussian_blur.py
│   │   │       └── helpers.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── alexnet.py
│   │   │   ├── resnet.py
│   │   │   ├── vaal_model.py
│   │   │   └── vgg.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── benchmark.py
│   │       ├── checkpoint.py
│   │       ├── distributed.py
│   │       ├── io.py
│   │       ├── logging.py
│   │       ├── meters.py
│   │       ├── metrics.py
│   │       ├── net.py
│   │       ├── plotting.py
│   │       └── timer.py
│   ├── requirements.txt
│   └── tools
│       ├── __init__.py
│       ├── ensemble_al.py
│       ├── ensemble_train.py
│       ├── test_model.py
│       ├── train.py
│       └── train_al.py
├── probcover_selection.gif
├── probcover_semi.png
├── results.png
├── scan
│   ├── LICENSE
│   ├── README.md
│   ├── TUTORIAL.md
│   ├── configs
│   │   ├── env.yml
│   │   ├── pretext
│   │   │   ├── moco_imagenet100.yml
│   │   │   ├── moco_imagenet200.yml
│   │   │   ├── moco_imagenet50.yml
│   │   │   ├── simclr_cifar10.yml
│   │   │   ├── simclr_cifar100.yml
│   │   │   ├── simclr_stl10.yml
│   │   │   └── simclr_tinyimagenet.yml
│   │   ├── scan
│   │   │   ├── imagenet_eval.yml
│   │   │   ├── scan_cifar10.yml
│   │   │   ├── scan_cifar100.yml
│   │   │   ├── scan_imagenet_100.yml
│   │   │   ├── scan_imagenet_200.yml
│   │   │   ├── scan_imagenet_50.yml
│   │   │   ├── scan_stl10.yml
│   │   │   └── scan_tinyimagenet.yml
│   │   └── selflabel
│   │       ├── selflabel_cifar10.yml
│   │       ├── selflabel_cifar100.yml
│   │       ├── selflabel_imagenet_100.yml
│   │       ├── selflabel_imagenet_200.yml
│   │       ├── selflabel_imagenet_50.yml
│   │       └── selflabel_stl10.yml
│   ├── data
│   │   ├── augment.py
│   │   ├── cifar.py
│   │   ├── custom_dataset.py
│   │   ├── imagenet.py
│   │   ├── imagenet_subsets
│   │   │   ├── imagenet_100.txt
│   │   │   ├── imagenet_200.txt
│   │   │   └── imagenet_50.txt
│   │   ├── stl.py
│   │   └── tinyimagenet.py
│   ├── eval.py
│   ├── images
│   │   ├── pipeline.png
│   │   ├── prototypes_cifar10.jpg
│   │   ├── teaser.jpg
│   │   └── tutorial
│   │       ├── confusion_matrix_stl10.png
│   │       └── prototypes_stl10.jpg
│   ├── losses
│   │   └── losses.py
│   ├── moco.py
│   ├── models
│   │   ├── models.py
│   │   ├── resnet.py
│   │   ├── resnet_cifar.py
│   │   ├── resnet_stl.py
│   │   └── resnet_tinyimagenet.py
│   ├── requirements.txt
│   ├── scan.py
│   ├── selflabel.py
│   ├── simclr.py
│   ├── tutorial_nn.py
│   └── utils
│       ├── collate.py
│       ├── common_config.py
│       ├── config.py
│       ├── ema.py
│       ├── evaluate_utils.py
│       ├── memory.py
│       ├── mypath.py
│       ├── train_utils.py
│       └── utils.py
└── typiclust-env.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # Shared objects
7 | *.so
8 |
9 | # Distribution / packaging
10 | build/
11 | *.egg-info/
12 | *.egg
13 |
14 | # Temporary files
15 | *.swn
16 | *.swo
17 | *.swp
18 |
19 | # PyCharm
20 | ../.idea/
21 |
22 | # Mac
23 | .DS_STORE
24 |
25 | # Data symlinks
26 | pycls/datasets/data/
27 |
28 | # Other
29 | logs/
30 | scratch*
31 | *.out
32 |
33 | output/
34 | *.ipynb_checkpoints/
35 | data/
36 | scan/results/
--------------------------------------------------------------------------------
/2d_selection_gif.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/2d_selection_gif.gif
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright 2022 (c) Avihu Dekel and Guy Hacohen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Typiclust, ProbCover & DCoM Official Code Repository
2 |
3 |
4 | This is the official implementation for the papers **Active Learning on a Budget - Opposite Strategies Suit High and Low Budgets**, **Active Learning Through a Covering Lens**, and **DCoM: Active Learning for All Learners**.
5 |
6 | This code implements TypiClust, ProbCover and DCoM - simple and effective low-budget active learning methods.
7 | ## Typiclust
8 |
9 | [**Arxiv link**](https://arxiv.org/abs/2202.02794),
10 | [**Twitter Post link**](https://twitter.com/AvihuDkl/status/1529385835694637058),
11 | [**Blog Post link**](https://avihu111.github.io/Active-Learning/)
12 |
13 |
14 | TypiClust first employs a representation learning method, then clusters the data into K clusters, and selects the most Typical (Dense) sample from every cluster. In other words, TypiClust selects samples from dense and diverse regions of the data distribution.
15 |
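For intuition, here is a minimal, illustrative sketch of this selection rule. It is not the repository's implementation (see `deep-al/pycls/al/typiclust.py`, which uses faiss and MiniBatchKMeans); `features`, `budget` and `k_nn` are assumed inputs:

```python
# Illustrative sketch of TypiClust-style selection: cluster, then pick the
# most typical (densest) point per cluster. `features` is an (N, d) array of
# self-supervised embeddings; `budget` is the number of samples to label.
import numpy as np
from sklearn.cluster import KMeans
from sklearn.neighbors import NearestNeighbors

def typiclust_sketch(features, budget, k_nn=20):
    labels = KMeans(n_clusters=budget).fit_predict(features)  # K = budget clusters
    selected = []
    for c in range(budget):
        idx = np.flatnonzero(labels == c)
        if len(idx) == 1:            # degenerate single-point cluster
            selected.append(int(idx[0]))
            continue
        k = min(k_nn + 1, len(idx))  # +1 because each point is its own nearest neighbor
        dists, _ = NearestNeighbors(n_neighbors=k).fit(features[idx]).kneighbors(features[idx])
        typicality = 1.0 / (dists[:, 1:].mean(axis=1) + 1e-5)  # low mean kNN distance = dense = typical
        selected.append(int(idx[typicality.argmax()]))
    return np.array(selected)
```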
16 | Selection of 30 samples on CIFAR-10:
17 |
18 |
19 |
20 | Selection of 10 samples from a GMM:
21 |
22 |
23 |
24 | TypiClust Results summary
25 |
26 |
27 |
28 | ## Probability Cover
29 |
30 | [**Arxiv link**](https://arxiv.org/abs/2205.11320),
31 | [**Twitter Post link**](https://mobile.twitter.com/AvihuDkl/status/1579497337650839553),
32 | [**Blog Post link**](https://avihu111.github.io/Covering-Lens/)
33 |
34 | ProbCover also uses a representation learning method. A ball of radius $\delta$ is then placed around every point, and a subset of $b$ (budget) balls that covers as many points as possible is selected; the ball centers are chosen as the samples to be labeled.
35 |
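A naive sketch of this greedy max-coverage rule (O(N^2) memory, for illustration only; the efficient batched GPU implementation is in `deep-al/pycls/al/prob_cover.py`):

```python
# Naive sketch of ProbCover's greedy selection. `features` is (N, d),
# `delta` is the ball radius, `budget` is the number of points to label.
import numpy as np

def probcover_sketch(features, budget, delta):
    dists = np.linalg.norm(features[:, None] - features[None, :], axis=-1)
    covers = dists < delta                      # covers[i, j]: the ball of i covers j
    covered = np.zeros(len(features), dtype=bool)
    selected = []
    for _ in range(budget):
        gain = (covers & ~covered).sum(axis=1)  # newly covered points per candidate
        best = int(gain.argmax())               # greedy: maximize marginal coverage
        covered |= covers[best]
        selected.append(best)                   # (zero-gain ties not handled; sketch only)
    return selected
```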
36 | Unfolding selection of ProbCover
37 |
38 |
39 |
40 | ProbCover results in the Semi-Supervised training framework
41 |
42 |
43 |
44 | ## DCoM
45 | [**Arxiv link**](https://arxiv.org/abs/2407.01804)
46 |
47 | DCoM employs a representation learning approach. Initially, a $\Delta_{\text{avg}}$-radius ball is placed around each point; the $\Delta$ list then assigns an individual radius to every labeled example. From these, a subset of $b$ balls covering as many points as possible is chosen, and their centers are selected as the samples to be labeled. After training the model, the $\Delta$ list is updated according to the purity of the balls, yielding more accurate radii and coverage. DCoM uses this coverage to compute the competence score, which balances typicality and uncertainty.
48 |
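The exact competence formula is specified in the paper and implemented in `deep-al/pycls/al/DCoM.py`. Purely as a hedged illustration of such a coverage-driven trade-off, not DCoM's actual formula, a logistic ramp with hypothetical parameters `a` and `k` (cf. the `--a_logistic`/`--k_logistic` flags in USAGE.md) could look like:

```python
# Hypothetical sketch only, not DCoM's exact formula. It shows the general
# shape of a coverage-driven trade-off: low coverage keeps the weight near 0
# (favoring typicality-style exploration), high coverage pushes it toward 1
# (favoring uncertainty).
import numpy as np

def competence_weight(coverage, a=0.8, k=20.0):
    """Logistic ramp in coverage; `a` is the midpoint, `k` the steepness."""
    return 1.0 / (1.0 + np.exp(-k * (coverage - a)))
```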
49 | Illustration of DCoM's $\Delta$ updating
50 |
51 |
52 |
53 | DCoM results in the Supervised training framework
54 |
55 |
56 |
57 | DCoM results in the Semi-Supervised training framework
58 |
59 |
60 |
61 | ## Usage
62 |
63 | Please see [`USAGE`](USAGE.md) for brief instructions on installation and basic usage examples.
64 |
65 | ## Citing this Repository
66 | This repository makes use of two other repositories: [SCAN](https://github.com/wvangansbeke/Unsupervised-Classification) and [Deep-AL](https://github.com/acl21/deep-active-learning-pytorch).
67 | Please consider citing their work and ours:
68 | ```
69 | @article{hacohen2022active,
70 | title={Active learning on a budget: Opposite strategies suit high and low budgets},
71 | author={Hacohen, Guy and Dekel, Avihu and Weinshall, Daphna},
72 | journal={arXiv preprint arXiv:2202.02794},
73 | year={2022}
74 | }
75 |
76 | @article{yehudaActiveLearningCovering2022,
77 | title = {Active {{Learning Through}} a {{Covering Lens}}},
78 | author = {Yehuda, Ofer and Dekel, Avihu and Hacohen, Guy and Weinshall, Daphna},
79 | journal={arXiv preprint arXiv:2205.11320},
80 | year={2022}
81 | }
82 |
83 | @article{mishal2024dcom,
84 | title={DCoM: Active Learning for All Learners},
85 | author={Mishal, Inbal and Weinshall, Daphna},
86 | journal={arXiv preprint arXiv:2407.01804},
87 | year={2024}
88 | }
89 | ```
90 |
91 | ## License
92 | This toolkit is released under the MIT license. Please see the [LICENSE](LICENSE) file for more information.
93 |
--------------------------------------------------------------------------------
/USAGE.md:
--------------------------------------------------------------------------------
1 | # Getting Started
2 |
3 | ## Setup
4 |
5 | Clone the repository
6 | ```
7 | git clone https://github.com/avihu111/TypiClust
8 | cd TypiClust
9 | ```
10 |
11 | Create an environment using
12 | ```
13 | conda create --name typiclust --file typiclust-env.txt
14 | conda activate typiclust
15 | pip install pyyaml easydict termcolor tqdm simplejson yacs
16 | ```
17 |
18 | If that fails, it might be due to an incompatible CUDA version.
19 | In that case, try installing with
20 | ```
21 | conda create --name typiclust python=3.7
22 | conda activate typiclust
23 | conda install pytorch torchvision torchaudio cudatoolkit=<YOUR_CUDA_VERSION> -c pytorch
24 | conda install matplotlib scipy scikit-learn pandas
25 | conda install -c conda-forge faiss-gpu
26 | pip install pyyaml easydict termcolor tqdm simplejson yacs
27 | ```
28 | Select the GPU to be used by running
29 | ```
30 | export CUDA_VISIBLE_DEVICES=0
31 | ```
32 |
33 | ## Representation Learning
34 | Both TypiClust variants and ProbCover rely on representation learning.
35 | To train SimCLR on CIFAR-10, please run
36 | ```
37 | cd scan
38 | python simclr.py --config_env configs/env.yml --config_exp configs/pretext/simclr_cifar10.yml
39 | cd ..
40 | ```
41 | When this finishes, the file `./results/cifar-10/pretext/features_seed1.npy` should exist.
42 |
43 | To save time, you can download the features of CIFAR-10/100 from here:
44 |
45 | | Dataset | Download link |
46 | |------------------|---------------|
47 | |CIFAR10 | [Download](https://drive.google.com/file/d/1Le1ZuZOpfxBfxL3nnNahZcCt-lLWLQSB/view?usp=sharing) |
48 | |CIFAR100 | [Download](https://drive.google.com/file/d/1o2nz_SKLdcaTCB9XVA44qCTVSUSmktUb/view?usp=sharing) |
49 |
50 |
51 | and place the files at `./results/cifar-10/pretext/features_seed1.npy` (and the corresponding path for CIFAR-100).
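A quick sanity check that the features are in place (the expected row count assumes the 50,000-image CIFAR-10 train set):

```python
# Sanity check for the downloaded/trained pretext features.
import numpy as np

feats = np.load('./results/cifar-10/pretext/features_seed1.npy')
print(feats.shape)  # expected (50000, feature_dim) for the CIFAR-10 train set
```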
52 |
53 | ## TypiClust - K-Means variant
54 | To select samples according to TypiClust (K-Means) with `initial_size=0` and `budget=100`, please run
55 | ```
56 | cd deep-al/tools
57 | python train_al.py --cfg ../configs/cifar10/al/RESNET18.yaml --al typiclust_rp --exp-name auto --initial_size 0 --budget 100
58 | cd ../../
59 | ```
60 |
61 |
62 | ## TypiClust - SCAN variant
63 | In this section we select `budget=10` samples without an initial set.
64 | We must first run the SCAN clustering algorithm, as TypiClust uses its features and cluster assignments.
65 | Please run SCAN once for every number of clusters `k` in `{10, 20, 30, 40, 50, 60}`, e.g. with the loop below
66 |
67 | ```
68 | cd scan
69 | for k in 10 20 30 40 50 60; do python scan.py --config_env configs/env.yml --config_exp configs/scan/scan_cifar10.yml --num_clusters $k; done
70 | cd ..
71 | ```
72 | To use other representations, change the features path in `deep-al/pycls/datasets/utils/features.py`.
73 | Then, you can run the active learning experiment by running
74 | ```
75 | cd deep-al/tools
76 | python train_al.py --cfg ../configs/cifar10/al/RESNET18.yaml --al typiclust_dc --exp-name auto --initial_size 0 --budget 10
77 | cd ../../
78 | ```
79 |
80 | ## ProbCover
81 | Example ProbCover script
82 | ```
83 | cd deep-al/tools
84 | python train_al.py --cfg ../configs/cifar10/al/RESNET18.yaml --al probcover --exp-name auto --initial_size 0 --budget 10 --delta 0.75
85 | cd ../../
86 | ```
87 |
88 | ## DCoM
89 | Example DCoM script:
90 | ```
91 | cd deep-al/tools
92 | python train_al.py --cfg ../configs/cifar10/al/RESNET18.yaml --al dcom --exp-name auto --initial_size 0 --budget 10 --initial_delta 0.75
93 | cd ../../
94 | ```
95 |
96 | You can set the `a_logistic` and `k_logistic` parameters of the run via the `--a_logistic` and `--k_logistic` flags.
97 |
--------------------------------------------------------------------------------
/cifar_selection.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/cifar_selection.png
--------------------------------------------------------------------------------
/dcom_delta_updating.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/dcom_delta_updating.gif
--------------------------------------------------------------------------------
/dcom_semi.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/dcom_semi.png
--------------------------------------------------------------------------------
/dcom_supervised.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/dcom_supervised.png
--------------------------------------------------------------------------------
/deep-al/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright 2022 (c) Akshay L Chandra and Vineeth N Balasubramanian
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/deep-al/README.md:
--------------------------------------------------------------------------------
1 | # Deep Active Learning Toolkit for Image Classification in PyTorch
2 |
3 | This is a code base for deep active learning for image classification written in [PyTorch](https://pytorch.org/). It is built on top of FAIR's [pycls](https://github.com/facebookresearch/pycls/). I want to emphasize that this is a derivative of the toolkit originally shared with me via email by Prateek Munjal _et al._, the authors of the paper _"Towards Robust and Reproducible Active Learning using Neural Networks"_, available [here](https://arxiv.org/abs/2002.09564).
4 |
5 | ## Introduction
6 |
7 | The goal of this repository is to provide a simple and flexible codebase for deep active learning. It is designed to support rapid implementation and evaluation of research ideas. We also provide results on CIFAR10 and CIFAR100 below.
8 |
9 | The codebase currently only supports single-machine single-gpu training. We will soon scale it to single-machine multi-gpu training, powered by the PyTorch distributed package.
10 |
11 | ## Using the Toolkit
12 |
13 | Please see [`GETTING_STARTED`](docs/GETTING_STARTED.md) for brief instructions on installation, adding new datasets, basic usage examples, etc.
14 |
15 | ## Active Learning Methods Supported
16 | * Uncertainty Sampling
17 | * Least Confidence
18 | * Min-Margin
19 | * Max-Entropy
20 | * Deep Bayesian Active Learning (DBAL) [1]
21 | * Bayesian Active Learning by Disagreement (BALD) [1]
22 | * Diversity Sampling
23 | * Coreset (greedy) [2]
24 | * Variational Adversarial Active Learning (VAAL) [3]
25 | * Query-by-Committee Sampling
26 | * Ensemble Variation Ratio (Ens-varR) [4]
27 |
28 |
29 | ## Datasets Supported
30 | * [CIFAR10/100](https://www.cs.toronto.edu/~kriz/cifar.html)
31 | * [MNIST](http://yann.lecun.com/exdb/mnist/)
32 | * [SVHN](http://ufldl.stanford.edu/housenumbers/)
33 | * [Tiny ImageNet](https://www.kaggle.com/c/tiny-imagenet) (Download the zip file [here](http://cs231n.stanford.edu/tiny-imagenet-200.zip))
34 | * Long-Tail CIFAR-10/100
35 |
36 | Follow the instructions in [`GETTING_STARTED`](docs/GETTING_STARTED.md) to add a new dataset.
37 |
38 | ## Results on CIFAR10 and CIFAR100
39 |
40 | The following are the results on CIFAR10 and CIFAR100, trained with the hyperparameters in `configs/cifar10/al/RESNET18.yaml` and `configs/cifar100/al/RESNET18.yaml` respectively. All results were averaged over 3 runs.
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 | ### CIFAR10 at 60%
49 | ```
50 | | AL Method | Test Accuracy |
51 | |:----------------:|:---------------------------:|
52 | | DBAL | 91.670000 +- 0.230651 |
53 | | Least Confidence | 91.510000 +- 0.087178 |
54 | | BALD | 91.470000 +- 0.293087 |
55 | | Coreset | 91.433333 +- 0.090738 |
56 | | Max-Entropy | 91.373333 +- 0.363639 |
57 | | Min-Margin | 91.333333 +- 0.234592 |
58 | | Ensemble-varR | 89.866667 +- 0.127410 |
59 | | Random | 89.803333 +- 0.230290 |
60 | | VAAL | 89.690000 +- 0.115326 |
61 | ```
62 |
63 | ### CIFAR100 at 60%
64 | ```
65 | | AL Method | Test Accuracy |
66 | |:----------------:|:---------------------------:|
67 | | DBAL | 55.400000 +- 1.037931 |
68 | | Coreset | 55.333333 +- 0.773714 |
69 | | Max-Entropy | 55.226667 +- 0.536128 |
70 | | BALD | 55.186667 +- 0.369639 |
71 | | Least Confidence | 55.003333 +- 0.937248 |
72 | | Min-Margin | 54.543333 +- 0.611583 |
73 | | Ensemble-varR | 54.186667 +- 0.325628 |
74 | | VAAL | 53.943333 +- 0.680686 |
75 | | Random | 53.546667 +- 0.302875 |
76 | ```
77 |
78 | ## Citing this Repository
79 |
80 | If you find this repo helpful in your research, please consider citing us and the owners of the original toolkit:
81 |
82 | ```
83 | @article{Chandra2021DeepAL,
84 | Author = {Akshay L Chandra and Vineeth N Balasubramanian},
85 | Title = {Deep Active Learning Toolkit for Image Classification in PyTorch},
86 | Journal = {https://github.com/acl21/deep-active-learning-pytorch},
87 | Year = {2021}
88 | }
89 |
90 | @article{Munjal2020TowardsRA,
91 | title={Towards Robust and Reproducible Active Learning Using Neural Networks},
92 | author={Prateek Munjal and N. Hayat and Munawar Hayat and J. Sourati and S. Khan},
93 | journal={ArXiv},
94 | year={2020},
95 | volume={abs/2002.09564}
96 | }
97 | ```
98 |
99 | ## License
100 |
101 | This toolkit is released under the MIT license. Please see the [LICENSE](LICENSE) file for more information.
102 |
103 | ## References
104 |
105 | [1] Yarin Gal, Riashat Islam, and Zoubin Ghahramani. Deep bayesian active learning with image data. In Proceedings of the 34th International Conference on Machine Learning-Volume 70, pages 1183–1192. JMLR. org, 2017.
106 |
107 | [2] Ozan Sener and Silvio Savarese. Active learning for convolutional neural networks: A core-set approach. In International Conference on Learning Representations, 2018.
108 |
109 | [3] Sinha, Samarth et al. Variational Adversarial Active Learning. 2019 IEEE/CVF International Conference on Computer Vision (ICCV) (2019): 5971-5980.
110 |
111 | [4] William H. Beluch, Tim Genewein, Andreas Nürnberger, and Jan M. Köhler. The power of ensembles for active learning in image classification. 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 9368–9377, 2018.
--------------------------------------------------------------------------------
/deep-al/configs/cifar10/al/RESNET18.yaml:
--------------------------------------------------------------------------------
1 | # EXP_NAME: 'YOUR_EXPERIMENT_NAME'
2 | RNG_SEED: 1
3 | DATASET:
4 | NAME: CIFAR10
5 | ROOT_DIR: './data'
6 | VAL_RATIO: 0.1
7 | AUG_METHOD: 'hflip'
8 | MODEL:
9 | TYPE: resnet18
10 | NUM_CLASSES: 10
11 | OPTIM:
12 | TYPE: 'sgd'
13 | BASE_LR: 0.025
14 | LR_POLICY: cos
15 | LR_MULT: 0.1
16 | # STEPS: [0, 60, 120, 160, 200]
17 | MAX_EPOCH: 200
18 | MOMENTUM: 0.9
19 | NESTEROV: True
20 | WEIGHT_DECAY: 0.0003
21 | GAMMA: 0.1
22 | TRAIN:
23 | SPLIT: train
24 | BATCH_SIZE: 100
25 | IM_SIZE: 32
26 | EVAL_PERIOD: 2
27 | TEST:
28 | SPLIT: test
29 | BATCH_SIZE: 200
30 | IM_SIZE: 32
31 | DATA_LOADER:
32 | NUM_WORKERS: 4
33 | CUDNN:
34 | BENCHMARK: True
35 | ACTIVE_LEARNING:
36 | INIT_L_RATIO: 0.
37 | # BUDGET_SIZE: 10
38 | # SAMPLING_FN: 'uncertainty'
39 | MAX_ITER: 5
40 | FINE_TUNE: False
41 | DELTA_RESOLUTION: 0.05
42 | MAX_DELTA: 1.1
--------------------------------------------------------------------------------
/deep-al/configs/cifar10/al/RESNET18_ENS.yaml:
--------------------------------------------------------------------------------
1 | # EXP_NAME: 'YOUR_EXPERIMENT_NAME'
2 | RNG_SEED: 1
3 | DATASET:
4 | NAME: CIFAR10
5 | ROOT_DIR: './data'
6 | VAL_RATIO: 0.1
7 | AUG_METHOD: 'hflip'
8 | MODEL:
9 | TYPE: resnet18
10 | NUM_CLASSES: 10
11 | OPTIM:
12 | TYPE: 'sgd'
13 | BASE_LR: 0.025
14 | LR_POLICY: cos
15 | LR_MULT: 0.1
16 | # STEPS: [0, 60, 120, 160, 200]
17 | MAX_EPOCH: 200
18 | MOMENTUM: 0.9
19 | NESTEROV: True
20 | WEIGHT_DECAY: 0.0003
21 | GAMMA: 0.1
22 | TRAIN:
23 | SPLIT: train
24 | BATCH_SIZE: 96
25 | IM_SIZE: 32
26 | EVAL_PERIOD: 2
27 | TEST:
28 | SPLIT: test
29 | BATCH_SIZE: 200
30 | IM_SIZE: 32
31 | DATA_LOADER:
32 | NUM_WORKERS: 4
33 | CUDNN:
34 | BENCHMARK: True
35 | ACTIVE_LEARNING:
36 | BUDGET_SIZE: 5000
37 | SAMPLING_FN: 'ensemble_var_R' # do not change this for ensemble learning
38 | MAX_ITER: 5
39 | ENSEMBLE:
40 | NUM_MODELS: 3
41 | MODEL_TYPE: ['resnet18']
--------------------------------------------------------------------------------
/deep-al/configs/cifar10/al/RESNET18_IM.yaml:
--------------------------------------------------------------------------------
1 | # EXP_NAME: 'YOUR_EXPERIMENT_NAME'
2 | RNG_SEED: 1
3 | DATASET:
4 | NAME: IMBALANCED_CIFAR10
5 | ROOT_DIR: './data'
6 | VAL_RATIO: 0.1
7 | AUG_METHOD: 'hflip'
8 | MODEL:
9 | TYPE: resnet18
10 | NUM_CLASSES: 10
11 | OPTIM:
12 | TYPE: 'sgd'
13 | BASE_LR: 0.025
14 | LR_POLICY: cos
15 | LR_MULT: 0.1
16 | # STEPS: [0, 60, 120, 160, 200]
17 | MAX_EPOCH: 200
18 | MOMENTUM: 0.9
19 | NESTEROV: True
20 | WEIGHT_DECAY: 0.0003
21 | GAMMA: 0.1
22 | TRAIN:
23 | SPLIT: train
24 | BATCH_SIZE: 96
25 | IM_SIZE: 32
26 | EVAL_PERIOD: 2
27 | TEST:
28 | SPLIT: test
29 | BATCH_SIZE: 200
30 | IM_SIZE: 32
31 | DATA_LOADER:
32 | NUM_WORKERS: 4
33 | CUDNN:
34 | BENCHMARK: True
35 | ACTIVE_LEARNING:
36 | INITIAL_L_RATIO: 0.1
37 | BUDGET_SIZE: 1270
38 | # SAMPLING_FN: 'uncertainty'
39 | MAX_ITER: 5
--------------------------------------------------------------------------------
/deep-al/configs/cifar10/evaluate/RESNET18.yaml:
--------------------------------------------------------------------------------
1 | # EXP_NAME: 'SOME_RANDOM_NAME'
2 | RNG_SEED: 1
3 | GPU_ID: '3'
4 | DATASET:
5 | NAME: CIFAR10
6 | ROOT_DIR: 'data'
7 | MODEL:
8 | TYPE: resnet18
9 | NUM_CLASSES: 10
10 | TEST:
11 | SPLIT: test
12 | BATCH_SIZE: 200
13 | IM_SIZE: 32
14 | MODEL_PATH: 'path/to/saved/model'
15 | DATA_LOADER:
16 | NUM_WORKERS: 4
17 | CUDNN:
18 | BENCHMARK: True
--------------------------------------------------------------------------------
/deep-al/configs/cifar10/train/RESNET18.yaml:
--------------------------------------------------------------------------------
1 | # EXP_NAME: 'YOUR_EXPERIMENT_NAME'
2 | RNG_SEED: 1
3 | DATASET:
4 | NAME: CIFAR10
5 | ROOT_DIR: './data'
6 | VAL_RATIO: 0.1
7 | AUG_METHOD: 'hflip'
8 | MODEL:
9 | TYPE: resnet18
10 | NUM_CLASSES: 10
11 | OPTIM:
12 | TYPE: 'sgd'
13 | BASE_LR: 0.025
14 | LR_POLICY: cos
15 | LR_MULT: 0.1
16 | # STEPS: [0, 60, 120, 160, 200]
17 | MAX_EPOCH: 200
18 | MOMENTUM: 0.9
19 | NESTEROV: True
20 | WEIGHT_DECAY: 0.0003
21 | GAMMA: 0.1
22 | TRAIN:
23 | SPLIT: train
24 | BATCH_SIZE: 1
25 | IM_SIZE: 32
26 | EVAL_PERIOD: 2
27 | TEST:
28 | SPLIT: test
29 | BATCH_SIZE: 200
30 | IM_SIZE: 32
31 | DATA_LOADER:
32 | NUM_WORKERS: 4
33 | CUDNN:
34 | BENCHMARK: True
--------------------------------------------------------------------------------
/deep-al/configs/cifar10/train/RESNET18_ENS.yaml:
--------------------------------------------------------------------------------
1 | # EXP_NAME: 'SOME'
2 | RNG_SEED: 1
3 | DATASET:
4 | NAME: CIFAR10
5 | ROOT_DIR: './data'
6 | VAL_RATIO: 0.1
7 | AUG_METHOD: 'hflip'
8 | MODEL:
9 | TYPE: resnet18
10 | NUM_CLASSES: 10
11 | OPTIM:
12 | TYPE: 'sgd'
13 | BASE_LR: 0.025
14 | LR_POLICY: cos
15 | LR_MULT: 0.1
16 | # STEPS: [0, 60, 120, 160, 200]
17 | MAX_EPOCH: 200
18 | MOMENTUM: 0.9
19 | NESTEROV: True
20 | WEIGHT_DECAY: 0.0003
21 | GAMMA: 0.1
22 | TRAIN:
23 | SPLIT: train
24 | BATCH_SIZE: 96
25 | IM_SIZE: 32
26 | EVAL_PERIOD: 2
27 | TEST:
28 | SPLIT: test
29 | BATCH_SIZE: 200
30 | IM_SIZE: 32
31 | DATA_LOADER:
32 | NUM_WORKERS: 4
33 | CUDNN:
34 | BENCHMARK: True
35 | ENSEMBLE:
36 | NUM_MODELS: 3
37 | MODEL_TYPE: ['resnet18']
--------------------------------------------------------------------------------
/deep-al/configs/cifar100/al/RESNET18.yaml:
--------------------------------------------------------------------------------
1 | # EXP_NAME: 'YOUR_EXPERIMENT_NAME'
2 | RNG_SEED: 1
3 | DATASET:
4 | NAME: CIFAR100
5 | ROOT_DIR: './data'
6 | VAL_RATIO: 0.1
7 | AUG_METHOD: 'hflip'
8 | MODEL:
9 | TYPE: resnet18
10 | NUM_CLASSES: 100
11 | OPTIM:
12 | TYPE: 'sgd'
13 | BASE_LR: 0.025
14 | LR_POLICY: cos
15 | LR_MULT: 0.1
16 | # STEPS: [0, 60, 120, 160, 200]
17 | MAX_EPOCH: 200
18 | MOMENTUM: 0.9
19 | NESTEROV: True
20 | WEIGHT_DECAY: 0.0003
21 | GAMMA: 0.1
22 | TRAIN:
23 | SPLIT: train
24 | BATCH_SIZE: 100
25 | IM_SIZE: 32
26 | EVAL_PERIOD: 2
27 | TEST:
28 | SPLIT: test
29 | BATCH_SIZE: 200
30 | IM_SIZE: 32
31 | DATA_LOADER:
32 | NUM_WORKERS: 8
33 | CUDNN:
34 | BENCHMARK: True
35 | ACTIVE_LEARNING:
36 | INIT_L_RATIO: 0.1
37 | BUDGET_SIZE: 5000
38 | # SAMPLING_FN: 'uncertainty'
39 | MAX_ITER: 5
40 | DELTA_RESOLUTION: 0.05
41 | MAX_DELTA: 1.1
--------------------------------------------------------------------------------
/deep-al/configs/template.yaml:
--------------------------------------------------------------------------------
1 | # Folder name where best model logs etc are saved. "auto" creates a timestamp based folder
2 | EXP_NAME: 'SOME_RANDOM_NAME'
3 | # Note that non-determinism may still be present due to non-deterministic
4 | # operator implementations in GPU operator libraries
5 | RNG_SEED: 1
6 | # GPU ID you want to execute the process on
7 | GPU_ID: '3'
8 | DATASET:
9 | NAME: CIFAR10 # or CIFAR100, MNIST, SVHN, TinyImageNet
10 | ROOT_DIR: 'data' # Relative path where data should be downloaded
11 | # Specifies the proportion of data in train set that should be considered as the validation data
12 | VAL_RATIO: 0.1
13 | # Data augmentation methods - 'simclr', 'randaug', 'horizontalflip'
14 | AUG_METHOD: 'horizontalflip'
15 | MODEL:
16 | # Model type.
17 | # Choose from vgg style ['vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19_bn', 'vgg19',]
18 | # or from resnet style ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
19 | # 'wide_resnet50_2', 'wide_resnet101_2']
20 | TYPE: resnet18
21 | NUM_CLASSES: 10
22 | OPTIM:
23 | TYPE: 'sgd' # or 'adam'
24 | BASE_LR: 0.1
25 | # Learning rate policy select from {'cos', 'exp', 'steps'}
26 | LR_POLICY: steps
27 | # Steps for 'steps' policy (in epochs)
28 | STEPS: [0] #[0, 30, 60, 90]
29 | # Training Epochs
30 | MAX_EPOCH: 1
31 | # Momentum
32 | MOMENTUM: 0.9
33 | # Nesterov Momentum
34 | NESTEROV: False
35 | # L2 regularization
36 | WEIGHT_DECAY: 0.0005
37 | # Exponential decay factor
38 | GAMMA: 0.1
39 | TRAIN:
40 | SPLIT: train
41 | # Training mini-batch size
42 | BATCH_SIZE: 256
43 | # Image size
44 | IM_SIZE: 32
45 | IM_CHANNELS: 3
46 | # Evaluate model on test data every eval period epochs
47 | EVAL_PERIOD: 1
48 | TEST:
49 | SPLIT: test
50 | # Testing mini-batch size
51 | BATCH_SIZE: 200
52 | # Image size
53 | IM_SIZE: 32
54 | # Saved model to use for testing (useful when running tools/test_model.py)
55 | MODEL_PATH: ''
56 | DATA_LOADER:
57 | NUM_WORKERS: 4
58 | CUDNN:
59 | BENCHMARK: True
60 | ACTIVE_LEARNING:
61 | # Active sampling budget (at each episode)
62 | BUDGET_SIZE: 5000
63 | # Active sampling method
64 | SAMPLING_FN: 'dbal' # 'random', 'uncertainty', 'entropy', 'margin', 'bald', 'vaal', 'coreset', 'ensemble_var_R'
65 | # Initial labeled pool ratio (% of total train set that should be labeled before AL begins)
66 | INIT_L_RATIO: 0.1
67 | # Max AL episodes
68 | MAX_ITER: 1
69 | DROPOUT_ITERATIONS: 10 # Used by DBAL
70 | # Useful when running `tools/ensemble_al.py` or `tools/ensemble_train.py`
71 | ENSEMBLE:
72 | NUM_MODELS: 3
73 | MODEL_TYPE: ['resnet18']
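The configs are yacs-style YAML, parsed by `deep-al/pycls/core/config.py`; as a small sketch (using the stock yacs API rather than the repo's own loader), the template can be loaded and inspected like this:

```python
# Minimal sketch of reading the template config with yacs; the repository's
# own loading logic lives in deep-al/pycls/core/config.py.
from yacs.config import CfgNode

with open('deep-al/configs/template.yaml') as f:
    cfg = CfgNode.load_cfg(f)
print(cfg.OPTIM.BASE_LR)                # 0.1
print(cfg.ACTIVE_LEARNING.SAMPLING_FN)  # 'dbal'
```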
--------------------------------------------------------------------------------
/deep-al/configs/tinyimagenet/al/RESNET18.yaml:
--------------------------------------------------------------------------------
1 | # EXP_NAME: 'YOUR_EXPERIMENT_NAME'
2 | RNG_SEED: 1
3 | DATASET:
4 | NAME: TINYIMAGENET
5 | ROOT_DIR: '/path/to/dataset/directory/'
6 | VAL_RATIO: 0.05
7 | AUG_METHOD: 'hflip'
8 | MODEL:
9 | TYPE: resnet18
10 | NUM_CLASSES: 200
11 | OPTIM:
12 | TYPE: 'adam'
13 | BASE_LR: 0.001
14 | LR_POLICY: none
15 | LR_MULT: 0.1
16 | # STEPS: [0, 60, 120, 160, 200]
17 | MAX_EPOCH: 100
18 | MOMENTUM: 0.9
19 | NESTEROV: True
20 | WEIGHT_DECAY: 0.0003
21 | GAMMA: 0.1
22 | TRAIN:
23 | SPLIT: train
24 | BATCH_SIZE: 100
25 | IM_SIZE: 64
26 | EVAL_PERIOD: 2
27 | TEST:
28 | SPLIT: test
29 | BATCH_SIZE: 200
30 | IM_SIZE: 64
31 | DATA_LOADER:
32 | NUM_WORKERS: 4
33 | CUDNN:
34 | BENCHMARK: True
35 | ACTIVE_LEARNING:
36 | INIT_L_RATIO: 0.1
37 | BUDGET_SIZE: 5000
38 | # SAMPLING_FN: 'uncertainty'
39 | MAX_ITER: 5
40 | DELTA_RESOLUTION: 0.05
41 | MAX_DELTA: 1.1
--------------------------------------------------------------------------------
/deep-al/docs/AL_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/deep-al/docs/AL_results.png
--------------------------------------------------------------------------------
/deep-al/pycls/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/deep-al/pycls/__init__.py
--------------------------------------------------------------------------------
/deep-al/pycls/al/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/deep-al/pycls/al/__init__.py
--------------------------------------------------------------------------------
/deep-al/pycls/al/prob_cover.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import torch
4 | import pycls.datasets.utils as ds_utils
5 |
6 | class ProbCover:
7 | def __init__(self, cfg, lSet, uSet, budgetSize, delta):
8 | self.cfg = cfg
9 | self.ds_name = self.cfg['DATASET']['NAME']
10 | self.seed = self.cfg['RNG_SEED']
11 | self.all_features = ds_utils.load_features(self.ds_name, self.seed)
12 | self.lSet = lSet
13 | self.uSet = uSet
14 | self.budgetSize = budgetSize
15 | self.delta = delta
16 | self.relevant_indices = np.concatenate([self.lSet, self.uSet]).astype(int)
17 | self.rel_features = self.all_features[self.relevant_indices]
18 | self.graph_df = self.construct_graph()
19 |
20 | def construct_graph(self, batch_size=500):
21 | """
22 | creates a directed graph where:
23 | x->y iff l2(x,y) < delta.
24 |
25 | represented by a list of edges (a sparse matrix).
26 | stored in a dataframe
27 | """
28 | xs, ys, ds = [], [], []
29 | print(f'Start constructing graph using delta={self.delta}')
30 | # distance computations are done in GPU
31 | cuda_feats = torch.tensor(self.rel_features).cuda()
32 | for i in range(int(np.ceil(len(self.rel_features) / batch_size))):  # ceil: include the final partial batch
33 | # distance comparisons are done in batches to reduce memory consumption
34 | cur_feats = cuda_feats[i * batch_size: (i + 1) * batch_size]
35 | dist = torch.cdist(cur_feats, cuda_feats)
36 | mask = dist < self.delta
37 | # saving edges using indices list - saves memory.
38 | x, y = mask.nonzero().T
39 | xs.append(x.cpu() + batch_size * i)
40 | ys.append(y.cpu())
41 | ds.append(dist[mask].cpu())
42 |
43 | xs = torch.cat(xs).numpy()
44 | ys = torch.cat(ys).numpy()
45 | ds = torch.cat(ds).numpy()
46 |
47 | df = pd.DataFrame({'x': xs, 'y': ys, 'd': ds})
48 | print(f'Finished constructing graph using delta={self.delta}')
49 | print(f'Graph contains {len(df)} edges.')
50 | return df
51 |
52 | def select_samples(self):
53 | """
54 | selecting samples using the greedy algorithm.
55 | iteratively:
56 | - removes incoming edges to all covered samples
57 | - selects the sample with the highest out degree (covers the most new samples)
58 |
59 | """
60 | print(f'Start selecting {self.budgetSize} samples.')
61 | selected = []
62 | # removing incoming edges to all covered samples from the existing labeled set
63 | edge_from_seen = np.isin(self.graph_df.x, np.arange(len(self.lSet)))
64 | covered_samples = self.graph_df.y[edge_from_seen].unique()
65 | cur_df = self.graph_df[(~np.isin(self.graph_df.y, covered_samples))]
66 | for i in range(self.budgetSize):
67 | coverage = len(covered_samples) / len(self.relevant_indices)
68 | # selecting the sample with the highest degree
69 | degrees = np.bincount(cur_df.x, minlength=len(self.relevant_indices))
70 | print(f'Iteration is {i}.\tGraph has {len(cur_df)} edges.\tMax degree is {degrees.max()}.\tCoverage is {coverage:.3f}')
71 | cur = degrees.argmax()
72 | # cur = np.random.choice(degrees.argsort()[::-1][:5]) # the paper randomizes selection
73 |
74 | # removing incoming edges to newly covered samples
75 | new_covered_samples = cur_df.y[(cur_df.x == cur)].values
76 | assert len(np.intersect1d(covered_samples, new_covered_samples)) == 0, 'all samples should be new'
77 | cur_df = cur_df[(~np.isin(cur_df.y, new_covered_samples))]
78 |
79 | covered_samples = np.concatenate([covered_samples, new_covered_samples])
80 | selected.append(cur)
81 |
82 | assert len(selected) == self.budgetSize, 'added a different number of samples'
83 | activeSet = self.relevant_indices[selected]
84 | remainSet = np.array(sorted(list(set(self.uSet) - set(activeSet))))
85 |
86 | print(f'Finished the selection of {len(activeSet)} samples.')
87 | print(f'Active set is {activeSet}')
88 | return activeSet, remainSet
89 |
--------------------------------------------------------------------------------
/deep-al/pycls/al/typiclust.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import faiss
4 | from sklearn.cluster import MiniBatchKMeans, KMeans
5 | import pycls.datasets.utils as ds_utils
6 |
7 | def get_nn(features, num_neighbors):
8 | # calculates nearest neighbors on GPU
9 | d = features.shape[1]
10 | features = features.astype(np.float32)
11 | cpu_index = faiss.IndexFlatL2(d)
12 | gpu_index = faiss.index_cpu_to_all_gpus(cpu_index)
13 | gpu_index.add(features) # add vectors to the index
14 | distances, indices = gpu_index.search(features, num_neighbors + 1)
15 | # 0 index is the same sample, dropping it
16 | return distances[:, 1:], indices[:, 1:]
17 |
18 |
19 | def get_mean_nn_dist(features, num_neighbors, return_indices=False):
20 | distances, indices = get_nn(features, num_neighbors)
21 | mean_distance = distances.mean(axis=1)
22 | if return_indices:
23 | return mean_distance, indices
24 | return mean_distance
25 |
26 |
27 | def calculate_typicality(features, num_neighbors):
28 | mean_distance = get_mean_nn_dist(features, num_neighbors)
29 | # low distance to NN is high density
30 | typicality = 1 / (mean_distance + 1e-5)
31 | return typicality
32 |
33 |
34 | def kmeans(features, num_clusters):
35 | if num_clusters <= 50:
36 | km = KMeans(n_clusters=num_clusters)
37 | km.fit_predict(features)
38 | else:
39 | km = MiniBatchKMeans(n_clusters=num_clusters, batch_size=5000)
40 | km.fit_predict(features)
41 | return km.labels_
42 |
43 |
44 | class TypiClust:
45 | MIN_CLUSTER_SIZE = 5
46 | MAX_NUM_CLUSTERS = 500
47 | K_NN = 20
48 |
49 | def __init__(self, cfg, lSet, uSet, budgetSize, is_scan=False):
50 | self.cfg = cfg
51 | self.ds_name = self.cfg['DATASET']['NAME']
52 | self.seed = self.cfg['RNG_SEED']
53 | self.features = None
54 | self.clusters = None
55 | self.lSet = lSet
56 | self.uSet = uSet
57 | self.budgetSize = budgetSize
58 | self.init_features_and_clusters(is_scan)
59 |
60 | def init_features_and_clusters(self, is_scan):
61 | num_clusters = min(len(self.lSet) + self.budgetSize, self.MAX_NUM_CLUSTERS)
62 | print(f'Clustering into {num_clusters} clusters. Scan clustering: {is_scan}')
63 | if is_scan:
64 | fname_dict = {'CIFAR10': f'../../scan/results/cifar-10/scan/features_seed{self.seed}_clusters{num_clusters}.npy',
65 | 'CIFAR100': f'../../scan/results/cifar-100/scan/features_seed{self.seed}_clusters{num_clusters}.npy',
66 | 'TINYIMAGENET': f'../../scan/results/tiny-imagenet/scan/features_seed{self.seed}_clusters{num_clusters}.npy',
67 | }
68 | fname = fname_dict[self.ds_name]
69 | self.features = np.load(fname)
70 | self.clusters = np.load(fname.replace('features', 'probs')).argmax(axis=-1)
71 | else:
72 | self.features = ds_utils.load_features(self.ds_name, self.seed)
73 | self.clusters = kmeans(self.features, num_clusters=num_clusters)
74 | print(f'Finished clustering into {num_clusters} clusters.')
75 |
76 | def select_samples(self):
77 | # using only labeled+unlabeled indices, without validation set.
78 | relevant_indices = np.concatenate([self.lSet, self.uSet]).astype(int)
79 | features = self.features[relevant_indices]
80 | labels = np.copy(self.clusters[relevant_indices])
81 | existing_indices = np.arange(len(self.lSet))
82 | # counting cluster sizes and number of labeled samples per cluster
83 | cluster_ids, cluster_sizes = np.unique(labels, return_counts=True)
84 | cluster_labeled_counts = np.bincount(labels[existing_indices], minlength=len(cluster_ids))
85 | clusters_df = pd.DataFrame({'cluster_id': cluster_ids, 'cluster_size': cluster_sizes, 'existing_count': cluster_labeled_counts,
86 | 'neg_cluster_size': -1 * cluster_sizes})
87 | # drop too small clusters
88 | clusters_df = clusters_df[clusters_df.cluster_size > self.MIN_CLUSTER_SIZE]
89 | # sort clusters by lowest number of existing samples, and then by cluster sizes (large to small)
90 | clusters_df = clusters_df.sort_values(['existing_count', 'neg_cluster_size'])
91 | labels[existing_indices] = -1
92 |
93 | selected = []
94 |
95 | for i in range(self.budgetSize):
96 | cluster = clusters_df.iloc[i % len(clusters_df)].cluster_id
97 | indices = (labels == cluster).nonzero()[0]
98 | rel_feats = features[indices]
99 | # if the cluster is small, compute typicality using at most half of its points as neighbors
100 | typicality = calculate_typicality(rel_feats, min(self.K_NN, len(indices) // 2))
101 | idx = indices[typicality.argmax()]
102 | selected.append(idx)
103 | labels[idx] = -1
104 |
105 | selected = np.array(selected)
106 | assert len(selected) == self.budgetSize, 'added a different number of samples'
107 | assert len(np.intersect1d(selected, existing_indices)) == 0, 'should be new samples'
108 | activeSet = relevant_indices[selected]
109 | remainSet = np.array(sorted(list(set(self.uSet) - set(activeSet))))
110 |
111 | print(f'Finished the selection of {len(activeSet)} samples.')
112 | print(f'Active set is {activeSet}')
113 | return activeSet, remainSet
114 |
--------------------------------------------------------------------------------
/deep-al/pycls/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/deep-al/pycls/core/__init__.py
--------------------------------------------------------------------------------
/deep-al/pycls/core/builders.py:
--------------------------------------------------------------------------------
1 | # This file is modified from official pycls repository
2 |
3 | """Model and loss construction functions."""
4 |
5 | from pycls.core.net import SoftCrossEntropyLoss
6 | from pycls.models.resnet import *
7 | from pycls.models.vgg import *
8 | from pycls.models.alexnet import *
9 |
10 | import torch
11 | from torch import nn
12 | from torch.nn import functional as F
13 | # Supported models
14 | _models = {
15 | # VGG style architectures
16 | 'vgg11': vgg11,
17 | 'vgg11_bn': vgg11_bn,
18 | 'vgg13': vgg13,
19 | 'vgg13_bn': vgg13_bn,
20 | 'vgg16': vgg16,
21 | 'vgg16_bn': vgg16_bn,
22 | 'vgg19': vgg19,
23 | 'vgg19_bn': vgg19_bn,
24 |
25 | # ResNet style architectures
26 | 'resnet18': resnet18,
27 | 'resnet34': resnet34,
28 | 'resnet50': resnet50,
29 | 'resnet101': resnet101,
30 | 'resnet152': resnet152,
31 | 'resnext50_32x4d': resnext50_32x4d,
32 | 'resnext101_32x8d': resnext101_32x8d,
33 | 'wide_resnet50_2': wide_resnet50_2,
34 | 'wide_resnet101_2': wide_resnet101_2,
35 |
36 | # AlexNet architecture
37 | 'alexnet': alexnet
38 | }
39 |
40 | # Supported loss functions
41 | _loss_funs = {"cross_entropy": SoftCrossEntropyLoss}
42 |
43 |
44 | class FeaturesNet(nn.Module):
45 | def __init__(self, in_layers, out_layers, use_mlp=False, penultimate_active=False):
46 | super().__init__()
47 | self.use_mlp = use_mlp
48 | self.penultimate_active = penultimate_active
49 | self.lin1 = nn.Linear(in_layers, in_layers)
50 | self.lin2 = nn.Linear(in_layers, in_layers)
51 | self.final = nn.Linear(in_layers, out_layers)
52 |
53 | def forward(self, x):
54 | feats = x
55 | if self.use_mlp:
56 | x = F.relu(self.lin1(x))
57 | x = F.relu((self.lin2(x)))
58 | out = self.final(x)
59 | if self.penultimate_active:
60 | return feats, out
61 | return out
62 |
63 |
64 | def get_model(cfg):
65 | """Gets the model class specified in the config."""
66 | err_str = "Model type '{}' not supported"
67 | assert cfg.MODEL.TYPE in _models.keys(), err_str.format(cfg.MODEL.TYPE)
68 | return _models[cfg.MODEL.TYPE]
69 |
70 |
71 | def get_loss_fun(cfg):
72 | """Gets the loss function class specified in the config."""
73 | err_str = "Loss function type '{}' not supported"
74 | assert cfg.MODEL.LOSS_FUN in _loss_funs.keys(), err_str.format(cfg.MODEL.LOSS_FUN)
75 | return _loss_funs[cfg.MODEL.LOSS_FUN]
76 |
77 |
78 | def build_model(cfg):
79 | """Builds the model."""
80 | if cfg.MODEL.LINEAR_FROM_FEATURES:
81 | num_features = 384 if cfg.DATASET.NAME in ['IMAGENET50', 'IMAGENET100', 'IMAGENET200'] else 512
82 | return FeaturesNet(num_features, cfg.MODEL.NUM_CLASSES)
83 |
84 | model = get_model(cfg)(num_classes=cfg.MODEL.NUM_CLASSES, use_dropout=True)
85 | if cfg.DATASET.NAME == 'MNIST':
86 | model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
87 |
88 | return model
89 |
90 |
91 | def build_loss_fun(cfg):
92 | """Build the loss function."""
93 | return get_loss_fun(cfg)()
94 |
95 |
96 | def register_model(name, ctor):
97 | """Registers a model dynamically."""
98 | _models[name] = ctor
99 |
100 |
101 | def register_loss_fun(name, ctor):
102 | """Registers a loss function dynamically."""
103 | _loss_funs[name] = ctor
104 |
--------------------------------------------------------------------------------
/deep-al/pycls/core/losses.py:
--------------------------------------------------------------------------------
1 | """Loss functions."""
2 |
3 | import torch.nn as nn
4 |
5 | from pycls.core.config import cfg
6 |
7 | # Supported loss functions
8 | _loss_funs = {
9 | 'cross_entropy': nn.CrossEntropyLoss,
10 | }
11 |
12 | def get_loss_fun():
13 | """Retrieves the loss function."""
14 | assert cfg.MODEL.LOSS_FUN in _loss_funs.keys(), \
15 | 'Loss function \'{}\' not supported'.format(cfg.MODEL.LOSS_FUN)
16 | return _loss_funs[cfg.MODEL.LOSS_FUN]().cuda()
17 |
18 |
19 | def register_loss_fun(name, ctor):
20 | """Registers a loss function dynamically."""
21 | _loss_funs[name] = ctor
22 |
--------------------------------------------------------------------------------
/deep-al/pycls/core/net.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Functions for manipulating networks."""
9 |
10 | import numpy as np
11 | # import pycls.core.distributed as dist
12 | import torch
13 | from pycls.core.config import cfg
14 |
15 |
16 | def unwrap_model(model):
17 | """Remove the DistributedDataParallel wrapper if present."""
18 | wrapped = isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel)
19 | return model.module if wrapped else model
20 |
21 |
22 | # @torch.no_grad()
23 | # def compute_precise_bn_stats(model, loader):
24 | # """Computes precise BN stats on training data."""
25 | # # Compute the number of minibatches to use
26 | # num_iter = int(cfg.BN.NUM_SAMPLES_PRECISE / loader.batch_size / cfg.NUM_GPUS)
27 | # num_iter = min(num_iter, len(loader))
28 | # # Retrieve the BN layers
29 | # bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
30 | # # Initialize BN stats storage for computing mean(mean(batch)) and mean(var(batch))
31 | # running_means = [torch.zeros_like(bn.running_mean) for bn in bns]
32 | # running_vars = [torch.zeros_like(bn.running_var) for bn in bns]
33 | # # Remember momentum values
34 | # momentums = [bn.momentum for bn in bns]
35 | # # Set momentum to 1.0 to compute BN stats that only reflect the current batch
36 | # for bn in bns:
37 | # bn.momentum = 1.0
38 | # # Average the BN stats for each BN layer over the batches
39 | # for inputs, _labels in itertools.islice(loader, num_iter):
40 | # model(inputs.cuda())
41 | # for i, bn in enumerate(bns):
42 | # running_means[i] += bn.running_mean / num_iter
43 | # running_vars[i] += bn.running_var / num_iter
44 | # # Sync BN stats across GPUs (no reduction if 1 GPU used)
45 | # running_means = dist.scaled_all_reduce(running_means)
46 | # running_vars = dist.scaled_all_reduce(running_vars)
47 | # # Set BN stats and restore original momentum values
48 | # for i, bn in enumerate(bns):
49 | # bn.running_mean = running_means[i]
50 | # bn.running_var = running_vars[i]
51 | # bn.momentum = momentums[i]
52 |
53 |
54 | def complexity(model):
55 | """Compute model complexity (model can be model instance or model class)."""
56 | size = cfg.TRAIN.IM_SIZE
57 | cx = {"h": size, "w": size, "flops": 0, "params": 0, "acts": 0}
58 | cx = unwrap_model(model).complexity(cx)
59 | return {"flops": cx["flops"], "params": cx["params"], "acts": cx["acts"]}
60 |
61 |
62 | def smooth_one_hot_labels(labels):
63 | """Convert each label to a one-hot vector."""
64 | n_classes, label_smooth = cfg.MODEL.NUM_CLASSES, cfg.TRAIN.LABEL_SMOOTHING
65 | err_str = "Invalid input to one_hot_vector()"
66 | assert labels.ndim == 1 and labels.max() < n_classes, err_str
67 | shape = (labels.shape[0], n_classes)
68 | neg_val = label_smooth / n_classes
69 | pos_val = 1.0 - label_smooth + neg_val
70 | labels_one_hot = torch.full(shape, neg_val, dtype=torch.float, device=labels.device)
71 | labels_one_hot.scatter_(1, labels.long().view(-1, 1), pos_val)
72 | return labels_one_hot
73 |
74 |
75 | class SoftCrossEntropyLoss(torch.nn.Module):
76 | """SoftCrossEntropyLoss (useful for label smoothing and mixup).
77 | Identical to torch.nn.CrossEntropyLoss if used with one-hot labels."""
78 |
79 | def __init__(self):
80 | super(SoftCrossEntropyLoss, self).__init__()
81 |
82 | def forward(self, x, y):
83 | loss = -y * torch.nn.functional.log_softmax(x, -1)
84 | return torch.sum(loss) / x.shape[0]
85 |
86 |
87 | def mixup(inputs, labels):
88 | """Apply mixup to minibatch (https://arxiv.org/abs/1710.09412)."""
89 | alpha = cfg.TRAIN.MIXUP_ALPHA
90 | assert labels.shape[1] == cfg.MODEL.NUM_CLASSES, "mixup labels must be one-hot"
91 | if alpha > 0:
92 | m = np.random.beta(alpha, alpha)
93 | permutation = torch.randperm(labels.shape[0])
94 | inputs = m * inputs + (1.0 - m) * inputs[permutation, :]
95 | labels = m * labels + (1.0 - m) * labels[permutation, :]
96 | return inputs, labels, labels.argmax(1)
97 |
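As a quick check (not part of the repo), the docstring's claim that `SoftCrossEntropyLoss` matches `torch.nn.CrossEntropyLoss` on one-hot labels can be verified:

```python
# Verify that the soft cross-entropy above equals nn.CrossEntropyLoss
# when the targets are exact one-hot vectors.
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)
targets = torch.randint(0, 10, (4,))
one_hot = F.one_hot(targets, num_classes=10).float()
soft = torch.sum(-one_hot * F.log_softmax(logits, -1)) / logits.shape[0]
assert torch.allclose(soft, F.cross_entropy(logits, targets), atol=1e-6)
```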
--------------------------------------------------------------------------------
/deep-al/pycls/core/optimizer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Optimizer."""
9 |
10 | import matplotlib.pyplot as plt
11 | import numpy as np
12 | import torch
13 |
14 |
15 | def construct_optimizer(cfg, model):
16 | """Constructs the optimizer.
17 |
18 | Note that the momentum update in PyTorch differs from the one in Caffe2.
19 | In particular,
20 |
21 | Caffe2:
22 | V := mu * V + lr * g
23 | p := p - V
24 |
25 | PyTorch:
26 | V := mu * V + g
27 | p := p - lr * V
28 |
29 | where V is the velocity, mu is the momentum factor, lr is the learning rate,
30 | g is the gradient and p are the parameters.
31 |
32 | Since V is defined independently of the learning rate in PyTorch,
33 | when the learning rate is changed there is no need to perform the
34 | momentum correction by scaling V (unlike in the Caffe2 case).
35 | """
36 | if cfg.BN.USE_CUSTOM_WEIGHT_DECAY:
37 | # Apply different weight decay to Batchnorm and non-batchnorm parameters.
38 | p_bn = [p for n, p in model.named_parameters() if "bn" in n]
39 | p_non_bn = [p for n, p in model.named_parameters() if "bn" not in n]
40 | optim_params = [
41 | {"params": p_bn, "weight_decay": cfg.BN.CUSTOM_WEIGHT_DECAY},
42 | {"params": p_non_bn, "weight_decay": cfg.OPTIM.WEIGHT_DECAY},
43 | ]
44 | else:
45 | optim_params = model.parameters()
46 |
47 | if cfg.OPTIM.TYPE == 'sgd':
48 | optimizer = torch.optim.SGD(
49 | optim_params,
50 | lr=cfg.OPTIM.BASE_LR,
51 | momentum=cfg.OPTIM.MOMENTUM,
52 | weight_decay=cfg.OPTIM.WEIGHT_DECAY,
53 | dampening=cfg.OPTIM.DAMPENING,
54 | nesterov=cfg.OPTIM.NESTEROV
55 | )
56 | elif cfg.OPTIM.TYPE == 'adam':
57 | optimizer = torch.optim.Adam(
58 | optim_params,
59 | lr=cfg.OPTIM.BASE_LR,
60 | weight_decay=cfg.OPTIM.WEIGHT_DECAY
61 | )
62 | else:
63 | raise NotImplementedError
64 |
65 | return optimizer
66 |
67 |
68 | def lr_fun_steps(cfg, cur_epoch):
69 | """Steps schedule (cfg.OPTIM.LR_POLICY = 'steps')."""
70 | ind = [i for i, s in enumerate(cfg.OPTIM.STEPS) if cur_epoch >= s][-1]
71 | return cfg.OPTIM.LR_MULT ** ind
72 |
73 |
74 | def lr_fun_exp(cfg, cur_epoch):
75 | """Exponential schedule (cfg.OPTIM.LR_POLICY = 'exp')."""
76 | return cfg.OPTIM.MIN_LR ** (cur_epoch / cfg.OPTIM.MAX_EPOCH)
77 |
78 |
79 | def lr_fun_cos(cfg, cur_epoch):
80 | """Cosine schedule (cfg.OPTIM.LR_POLICY = 'cos')."""
81 | lr = 0.5 * (1.0 + np.cos(np.pi * cur_epoch / cfg.OPTIM.MAX_EPOCH))
82 | return (1.0 - cfg.OPTIM.MIN_LR) * lr + cfg.OPTIM.MIN_LR
83 |
84 |
85 | def lr_fun_lin(cfg, cur_epoch):
86 | """Linear schedule (cfg.OPTIM.LR_POLICY = 'lin')."""
87 | lr = 1.0 - cur_epoch / cfg.OPTIM.MAX_EPOCH
88 | return (1.0 - cfg.OPTIM.MIN_LR) * lr + cfg.OPTIM.MIN_LR
89 |
90 |
91 | def lr_fun_none(cfg, cur_epoch):
92 | """No schedule (cfg.OPTIM.LR_POLICY = 'none')."""
93 | return 1
94 |
95 |
96 | def get_lr_fun(cfg):
97 | """Retrieves the specified lr policy function"""
98 | lr_fun = "lr_fun_" + cfg.OPTIM.LR_POLICY
99 | assert lr_fun in globals(), "Unknown LR policy: " + cfg.OPTIM.LR_POLICY
100 | err_str = "exp lr policy requires OPTIM.MIN_LR to be greater than 0."
101 | assert cfg.OPTIM.LR_POLICY != "exp" or cfg.OPTIM.MIN_LR > 0, err_str
102 | return globals()[lr_fun]
103 |
104 |
105 | def get_epoch_lr(cfg, cur_epoch):
106 | """Retrieves the lr for the given epoch according to the policy."""
107 | # Get lr and scale by by BASE_LR
108 | lr = get_lr_fun(cfg)(cfg, cur_epoch) * cfg.OPTIM.BASE_LR
109 | # Linear warmup
110 | if cur_epoch < cfg.OPTIM.WARMUP_EPOCHS and 'none' not in cfg.OPTIM.LR_POLICY:
111 | alpha = cur_epoch / cfg.OPTIM.WARMUP_EPOCHS
112 | warmup_factor = cfg.OPTIM.WARMUP_FACTOR * (1.0 - alpha) + alpha
113 | lr *= warmup_factor
114 | return lr
115 |
116 |
117 | def set_lr(optimizer, new_lr):
118 | """Sets the optimizer lr to the specified value."""
119 | for param_group in optimizer.param_groups:
120 | param_group["lr"] = new_lr
121 |
122 |
123 | def plot_lr_fun(cfg):
124 | """Visualizes lr function."""
125 | epochs = list(range(cfg.OPTIM.MAX_EPOCH))
126 | lrs = [get_epoch_lr(cfg, epoch) for epoch in epochs]
127 | plt.plot(epochs, lrs, ".-")
128 | plt.title("lr_policy: {}".format(cfg.OPTIM.LR_POLICY))
129 | plt.xlabel("epochs")
130 | plt.ylabel("learning rate")
131 | plt.ylim(bottom=0)
132 | plt.show()
133 |
--------------------------------------------------------------------------------
/deep-al/pycls/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/deep-al/pycls/datasets/__init__.py
--------------------------------------------------------------------------------
/deep-al/pycls/datasets/custom_datasets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | from PIL import Image
4 | import numpy as np
5 | import pycls.datasets.utils as ds_utils
6 |
7 |
8 | class CIFAR10(torchvision.datasets.CIFAR10):
9 | def __init__(self, root, train, transform, test_transform, download=True, only_features= False):
10 | super(CIFAR10, self).__init__(root, train, transform=transform, download=download)
11 | self.test_transform = test_transform
12 | self.no_aug = False
13 | self.only_features = only_features
14 | self.features = ds_utils.load_features("CIFAR10", train=train, normalized=False)
15 |
16 |
17 | def __getitem__(self, index: int):
18 | """
19 | Args:
20 | index (int): Index
21 |
22 | Returns:
23 | tuple: (image, target) where target is index of the target class.
24 | """
25 | img, target = self.data[index], self.targets[index]
26 |
27 | # doing this so that it is consistent with all other datasets
28 | # to return a PIL Image
29 | img = Image.fromarray(img)
30 | if self.only_features:
31 | img = self.features[index]
32 | else:
33 | if self.no_aug:
34 | if self.test_transform is not None:
35 | img = self.test_transform(img)
36 | else:
37 | if self.transform is not None:
38 | img = self.transform(img)
39 |
40 |
41 | return img, target
42 |
43 |
44 | class CIFAR100(torchvision.datasets.CIFAR100):
45 | def __init__(self, root, train, transform, test_transform, download=True, only_features= False):
46 | super(CIFAR100, self).__init__(root, train, transform=transform, download=download)
47 | self.test_transform = test_transform
48 | self.no_aug = False
49 | self.only_features = only_features
50 | self.features = ds_utils.load_features("CIFAR100", train=train, normalized=False)
51 |
52 | def __getitem__(self, index: int):
53 | """
54 | Args:
55 | index (int): Index
56 |
57 | Returns:
58 | tuple: (image, target) where target is index of the target class.
59 | """
60 | img, target = self.data[index], self.targets[index]
61 |
62 | # doing this so that it is consistent with all other datasets
63 | # to return a PIL Image
64 | img = Image.fromarray(img)
65 | if self.only_features:
66 | img = self.features[index]
67 | else:
68 | if self.no_aug:
69 | if self.test_transform is not None:
70 | img = self.test_transform(img)
71 | else:
72 | if self.transform is not None:
73 | img = self.transform(img)
74 |
75 | return img, target
76 |
77 |
78 | class STL10(torchvision.datasets.STL10):
79 | def __init__(self, root, train, transform, test_transform, download=True):
80 | super(STL10, self).__init__(root, train, transform=transform, download=download)
81 | self.test_transform = test_transform
82 | self.no_aug = False
83 | self.targets = self.labels
84 |
85 | def __getitem__(self, index: int):
86 | """
87 | Args:
88 | index (int): Index
89 |
90 | Returns:
91 | tuple: (image, target) where target is index of the target class.
92 | """
93 | img, target = self.data[index], int(self.targets[index])
94 |
95 | # doing this so that it is consistent with all other datasets
96 | # to return a PIL Image
97 | img = Image.fromarray(img.transpose(1,2,0))
98 |
99 | if self.no_aug:
100 | if self.test_transform is not None:
101 | img = self.test_transform(img)
102 | else:
103 | if self.transform is not None:
104 | img = self.transform(img)
105 |
106 | return img, target
107 |
108 |
109 | class MNIST(torchvision.datasets.MNIST):
110 | def __init__(self, root, train, transform, test_transform, download=True):
111 | super(MNIST, self).__init__(root, train, transform=transform, download=download)
112 | self.test_transform = test_transform
113 | self.no_aug = False
114 |
115 | def __getitem__(self, index: int):
116 | """
117 | Args:
118 | index (int): Index
119 |
120 | Returns:
121 | tuple: (image, target) where target is index of the target class.
122 | """
123 | img, target = self.data[index], int(self.targets[index])
124 |
125 | # doing this so that it is consistent with all other datasets
126 | # to return a PIL Image
127 | img = Image.fromarray(img.numpy(), mode='L')
128 |
129 | if self.no_aug:
130 | if self.test_transform is not None:
131 | img = self.test_transform(img)
132 | else:
133 | if self.transform is not None:
134 | img = self.transform(img)
135 |
136 |
137 | return img, target
138 |
139 |
140 | class SVHN(torchvision.datasets.SVHN):
141 | def __init__(self, root, train, transform, test_transform, download=True):
142 | super(SVHN, self).__init__(root, train, transform=transform, download=download)
143 | self.test_transform = test_transform
144 | self.no_aug = False
145 |
146 | def __getitem__(self, index: int):
147 | """
148 | Args:
149 | index (int): Index
150 |
151 | Returns:
152 | tuple: (image, target) where target is index of the target class.
153 | """
154 |         img, target = self.data[index], int(self.labels[index])  # torchvision SVHN stores targets in `labels`
155 |
156 |         # SVHN stores images as (C, H, W); transpose to (H, W, C) so that it is
157 |         # consistent with all other datasets and returns a PIL Image
158 |         img = Image.fromarray(np.transpose(img, (1, 2, 0)))
159 |
160 | if self.no_aug:
161 | if self.test_transform is not None:
162 | img = self.test_transform(img)
163 | else:
164 | if self.transform is not None:
165 | img = self.transform(img)
166 |
167 |
168 | return img, target
--------------------------------------------------------------------------------
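Usage sketch for the dataset wrappers above (illustrative; only_features=True requires the SimCLR pretext features produced by the scan stage to exist under ../../scan/results/cifar-10/pretext/, since CIFAR10.__init__ loads them eagerly):

    import torchvision.transforms as T
    from pycls.datasets.custom_datasets import CIFAR10

    train_tf = T.Compose([T.RandomCrop(32, padding=4), T.RandomHorizontalFlip(), T.ToTensor()])
    ds = CIFAR10('./data', train=True, transform=train_tf, test_transform=T.ToTensor())

    img, target = ds[0]      # augmented via `transform`
    ds.no_aug = True
    img, target = ds[0]      # deterministic `test_transform` instead
    ds.no_aug = False
    ds.only_features = True
    feat, target = ds[0]     # precomputed embedding instead of an image

--------------------------------------------------------------------------------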
/deep-al/pycls/datasets/imbalanced_cifar.py:
--------------------------------------------------------------------------------
1 | """
2 | Credits: Kaihua Tang
3 | Source: https://github.com/KaihuaTang/Long-Tailed-Recognition.pytorch/
4 | """
5 |
6 | import numpy as np
7 | from pycls.datasets.custom_datasets import CIFAR10, CIFAR100
8 |
9 | class IMBALANCECIFAR10(CIFAR10):
10 | cls_num = 10
11 | np.random.seed(1)
12 | def __init__(self, root, train, transform=None, test_transform=None, imbalance_ratio=0.02, imb_type='exp'):
13 | super(IMBALANCECIFAR10, self).__init__(root, train, transform=transform, test_transform=test_transform, download=True)
14 | self.train = train
15 | self.transform = transform
16 | if self.train:
17 | img_num_list = self.get_img_num_per_cls(self.cls_num, imb_type, imbalance_ratio)
18 | self.gen_imbalanced_data(img_num_list)
19 | phase = 'Train'
20 | else:
21 | phase = 'Test'
22 | self.labels = self.targets
23 |
24 | print("{} Mode: Contain {} images".format(phase, len(self.data)))
25 |
26 | def _get_class_dict(self):
27 | class_dict = dict()
28 | for i, anno in enumerate(self.get_annotations()):
29 | cat_id = anno["category_id"]
30 | if not cat_id in class_dict:
31 | class_dict[cat_id] = []
32 | class_dict[cat_id].append(i)
33 | return class_dict
34 |
35 |
36 | def get_img_num_per_cls(self, cls_num, imb_type, imb_factor):
37 | img_max = len(self.data) / cls_num
38 | img_num_per_cls = []
39 | if imb_type == 'exp':
40 | for cls_idx in range(cls_num):
41 | num = img_max * (imb_factor**(cls_idx / (cls_num - 1.0)))
42 | img_num_per_cls.append(int(num))
43 | elif imb_type == 'step':
44 | for cls_idx in range(cls_num // 2):
45 | img_num_per_cls.append(int(img_max))
46 | for cls_idx in range(cls_num // 2):
47 | img_num_per_cls.append(int(img_max * imb_factor))
48 | else:
49 | img_num_per_cls.extend([int(img_max)] * cls_num)
50 | return img_num_per_cls
51 |
52 | def gen_imbalanced_data(self, img_num_per_cls):
53 |
54 | new_data = []
55 | new_targets = []
56 | targets_np = np.array(self.targets, dtype=np.int64)
57 | classes = np.unique(targets_np)
58 |
59 | self.num_per_cls_dict = dict()
60 | for the_class, the_img_num in zip(classes, img_num_per_cls):
61 | self.num_per_cls_dict[the_class] = the_img_num
62 | idx = np.where(targets_np == the_class)[0]
63 | np.random.shuffle(idx)
64 | selec_idx = idx[:the_img_num]
65 | new_data.append(self.data[selec_idx, ...])
66 | new_targets.extend([the_class, ] * the_img_num)
67 | new_data = np.vstack(new_data)
68 | self.data = new_data
69 | self.targets = new_targets
70 |
71 | def __len__(self):
72 | return len(self.labels)
73 |
74 | def get_num_classes(self):
75 | return self.cls_num
76 |
77 | def get_annotations(self):
78 | annos = []
79 | for label in self.labels:
80 | annos.append({'category_id': int(label)})
81 | return annos
82 |
83 | def get_cls_num_list(self):
84 | cls_num_list = []
85 | for i in range(self.cls_num):
86 | cls_num_list.append(self.num_per_cls_dict[i])
87 | return cls_num_list
88 |
89 |
90 | class IMBALANCECIFAR100(CIFAR100):
91 |     """`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
92 |     This is a subclass of the `CIFAR10` Dataset.
93 |     """
94 | cls_num = 100
95 | base_folder = 'cifar-100-python'
96 | url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
97 | filename = "cifar-100-python.tar.gz"
98 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
99 | train_list = [
100 | ['train', '16019d7e3df5f24257cddd939b257f8d'],
101 | ]
102 |
103 | test_list = [
104 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
105 | ]
106 | meta = {
107 | 'filename': 'meta',
108 | 'key': 'fine_label_names',
109 | 'md5': '7973b15100ade9c7d40fb424638fde48',
110 | }
--------------------------------------------------------------------------------
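Worked example: with the default imb_type='exp' and imbalance_ratio=0.02, get_img_num_per_cls decays the per-class count geometrically from img_max (5000 for the full CIFAR-10 train split) down to img_max * 0.02. A standalone check of the formula above:

    img_max, cls_num, imb_factor = 5000, 10, 0.02
    counts = [int(img_max * imb_factor ** (i / (cls_num - 1.0))) for i in range(cls_num)]
    print(counts)  # [5000, 3237, 2096, 1357, 878, 568, 368, 238, 154, 100]

--------------------------------------------------------------------------------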
/deep-al/pycls/datasets/sampler.py:
--------------------------------------------------------------------------------
1 | # This file is directly taken from a code implementation shared with me by Prateek Munjal et al., authors of the paper https://arxiv.org/abs/2002.09564
2 | # GitHub: https://github.com/PrateekMunjal
3 | # ----------------------------------------------------------
4 |
5 | from torch.utils.data.sampler import Sampler
6 |
7 |
8 | class IndexedSequentialSampler(Sampler):
9 | r"""Samples elements sequentially, always in the same order.
10 |
11 | Arguments:
12 | data_idxes (Dataset indexes): dataset indexes to sample from
13 | """
14 |
15 | def __init__(self, data_idxes, isDebug=False):
16 |         if isDebug: print("========= my custom sequential sampler =========")
17 | self.data_idxes = data_idxes
18 |
19 | def __iter__(self):
20 | return (self.data_idxes[i] for i in range(len(self.data_idxes)))
21 |
22 | def __len__(self):
23 | return len(self.data_idxes)
24 |
25 | # class IndexedDistributedSampler(Sampler):
26 | # """Sampler that restricts data loading to a particular index set of dataset.
27 |
28 | # It is especially useful in conjunction with
29 | # :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
30 | # process can pass a DistributedSampler instance as a DataLoader sampler,
31 | # and load a subset of the original dataset that is exclusive to it.
32 |
33 | # .. note::
34 | # Dataset is assumed to be of constant size.
35 |
36 | # Arguments:
37 | # dataset: Dataset used for sampling.
38 | # num_replicas (optional): Number of processes participating in
39 | # distributed training.
40 | # rank (optional): Rank of the current process within num_replicas.
41 | # """
42 |
43 | # def __init__(self, dataset, index_set, num_replicas=None, rank=None, allowRepeat=True):
44 | # if num_replicas is None:
45 | # if not dist.is_available():
46 | # raise RuntimeError("Requires distributed package to be available")
47 | # num_replicas = dist.get_world_size()
48 | # if rank is None:
49 | # if not dist.is_available():
50 | # raise RuntimeError("Requires distributed package to be available")
51 | # rank = dist.get_rank()
52 | # self.dataset = dataset
53 | # self.num_replicas = num_replicas
54 | # self.rank = rank
55 | # self.epoch = 0
56 | # self.index_set = index_set
57 | # self.allowRepeat = allowRepeat
58 | # if self.allowRepeat:
59 | # self.num_samples = int(math.ceil(len(self.index_set) * 1.0 / self.num_replicas))
60 | # self.total_size = self.num_samples * self.num_replicas
61 | # else:
62 | # self.num_samples = int(math.ceil((len(self.index_set)-self.rank) * 1.0 / self.num_replicas))
63 | # self.total_size = len(self.index_set)
64 |
65 | # def __iter__(self):
66 | # # deterministically shuffle based on epoch
67 | # g = torch.Generator()
68 | # g.manual_seed(self.epoch)
69 | # indices = torch.randperm(len(self.index_set), generator=g).tolist()
70 | # #To access valid indices
71 | # #indices = self.index_set[indices]
72 | # # add extra samples to make it evenly divisible
73 | # if self.allowRepeat:
74 | # indices += indices[:(self.total_size - len(indices))]
75 |
76 | # assert len(indices) == self.total_size
77 |
78 | # # subsample
79 | # indices = self.index_set[indices[self.rank:self.total_size:self.num_replicas]]
80 | # assert len(indices) == self.num_samples, "len(indices): {} and self.num_samples: {}"\
81 | # .format(len(indices), self.num_samples)
82 |
83 | # return iter(indices)
84 |
85 | # def __len__(self):
86 | # return self.num_samples
87 |
88 | # def set_epoch(self, epoch):
89 | # self.epoch = epoch
90 |
--------------------------------------------------------------------------------
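Usage sketch: IndexedSequentialSampler is handed to a DataLoader so that only a chosen index set (e.g. the currently labeled pool) is iterated, in fixed order:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from pycls.datasets.sampler import IndexedSequentialSampler

    ds = TensorDataset(torch.arange(100).float().unsqueeze(1))
    labeled_idx = [3, 7, 42]  # hypothetical labeled-set indexes
    loader = DataLoader(ds, batch_size=2, sampler=IndexedSequentialSampler(labeled_idx))
    for batch, in loader:
        print(batch.squeeze(1))  # tensor([3., 7.]) then tensor([42.])

--------------------------------------------------------------------------------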
/deep-al/pycls/datasets/simclr_augment.py:
--------------------------------------------------------------------------------
1 | # Modified from the source: https://github.com/sthalles/PyTorch-BYOL
2 | # Previous owner of this file: Thalles Silva
3 |
4 | from torchvision import transforms
5 | from .utils.gaussian_blur import GaussianBlur
6 |
7 | def get_simclr_ops(input_shape, s=1):
8 | # get a set of data augmentation transformations as described in the SimCLR paper.
9 | color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
10 | ops = [transforms.RandomHorizontalFlip(),
11 | transforms.RandomApply([color_jitter], p=0.8),
12 | transforms.RandomGrayscale(p=0.2),
13 | GaussianBlur(kernel_size=int(0.1 * input_shape)),]
14 | return ops
15 |
16 |
17 |
--------------------------------------------------------------------------------
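Usage sketch: the returned ops are plain torchvision transforms operating on PIL images, so they compose directly with ToTensor and any normalization:

    from torchvision import transforms
    from pycls.datasets.simclr_augment import get_simclr_ops

    ops = get_simclr_ops(input_shape=32)  # CIFAR-sized inputs -> 3px blur kernel
    simclr_transform = transforms.Compose(ops + [transforms.ToTensor()])

--------------------------------------------------------------------------------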
/deep-al/pycls/datasets/tiny_imagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from PIL import Image
4 |
5 | import torch
6 | import torchvision.datasets as datasets
7 |
8 | from typing import Any
9 |
10 |
11 | class TinyImageNet(datasets.ImageFolder):
12 |     """Tiny ImageNet Classification Dataset.
13 |
14 | Args:
15 | root (string): Root directory of the ImageNet Dataset.
16 | split (string, optional): The dataset split, supports ``train``, or ``val``.
17 | transform (callable, optional): A function/transform that takes in an PIL image
18 | and returns a transformed version. E.g, ``transforms.RandomCrop``
19 | target_transform (callable, optional): A function/transform that takes in the
20 | target and transforms it.
21 | loader (callable, optional): A function to load an image given its path.
22 |
23 | Attributes:
24 | classes (list): List of the class name tuples.
25 | class_to_idx (dict): Dict with items (class_name, class_index).
26 | wnids (list): List of the WordNet IDs.
27 | wnid_to_idx (dict): Dict with items (wordnet_id, class_index).
28 | samples (list): List of (image path, class_index) tuples
29 | targets (list): The class_index value for each image in the dataset
30 | """
31 | def __init__(self, root: str, split: str = 'train', transform=None, test_transform=None, only_features=False, **kwargs: Any) -> None:
32 | self.root = root
33 | self.test_transform = test_transform
34 | self.no_aug = False
35 |
36 | assert self.check_root(), "Something is wrong with the Tiny ImageNet dataset path. Download the official dataset zip from http://cs231n.stanford.edu/tiny-imagenet-200.zip and unzip it inside {}.".format(self.root)
37 | self.split = datasets.utils.verify_str_arg(split, "split", ("train", "val"))
38 | wnid_to_classes = self.load_wnid_to_classes()
39 |
40 | super(TinyImageNet, self).__init__(self.split_folder, **kwargs)
41 | self.transform = transform
42 | self.wnids = self.classes
43 | self.wnid_to_idx = self.class_to_idx
44 | self.classes = [wnid_to_classes[wnid] for wnid in self.wnids]
45 | self.class_to_idx = {cls: idx
46 | for idx, clss in enumerate(self.classes)
47 | for cls in clss}
48 |         # The Tiny ImageNet val directory is structured differently from train,
49 |         # so a custom loading function is necessary
50 | self.only_features = only_features
51 | if self.split == 'train':
52 | self.features = np.load('../../scan/results/tiny-imagenet/pretext/features_seed1.npy')
53 | elif self.split == 'val':
54 | self.features = np.load('../../scan/results/tiny-imagenet/pretext/test_features_seed1.npy')
55 | self.root = root
56 | self.imgs, self.targets = self.load_val_data()
57 | self.samples = [(self.imgs[idx], self.targets[idx]) for idx in range(len(self.imgs))]
58 | self.root = os.path.join(self.root, 'val')
59 |
60 |
61 |
62 |     # split_folder is used for the 'super' call. Since the val directory is not structured like train,
63 |     # we use train's structure to recover the class list and related metadata
64 | @property
65 | def split_folder(self) -> str:
66 | return os.path.join(self.root, 'train')
67 |
68 |
69 | def load_val_data(self):
70 | imgs, targets = [], []
71 | with open(os.path.join(self.root, 'val', 'val_annotations.txt'), 'r') as file:
72 | for line in file:
73 | if line.split()[1] in self.wnids:
74 | img_file, wnid = line.split('\t')[:2]
75 | imgs.append(os.path.join(self.root, 'val', 'images', img_file))
76 | targets.append(wnid)
77 | targets = np.array([self.wnid_to_idx[wnid] for wnid in targets])
78 | return imgs, targets
79 |
80 |
81 | def load_wnid_to_classes(self):
82 | wnid_to_classes = {}
83 | with open(os.path.join(self.root, 'words.txt'), 'r') as file:
84 | lines = file.readlines()
85 | lines = [x.split('\t') for x in lines]
86 | wnid_to_classes = {x[0]:x[1].strip() for x in lines}
87 | return wnid_to_classes
88 |
89 | def check_root(self):
90 | tinyim_set = ['words.txt', 'wnids.txt', 'train', 'val', 'test']
91 | for x in os.scandir(self.root):
92 | if x.name not in tinyim_set:
93 | return False
94 | return True
95 |
96 | def __getitem__(self, index: int):
97 | """
98 | Args:
99 | index (int): Index
100 |
101 | Returns:
102 | tuple: (sample, target) where target is class_index of the target class.
103 | """
104 | path, target = self.samples[index]
105 | sample = self.loader(path)
106 |
107 | if self.only_features:
108 | sample = self.features[index]
109 | else:
110 | if self.no_aug:
111 | if self.test_transform is not None:
112 | sample = self.test_transform(sample)
113 | else:
114 | if self.transform is not None:
115 | sample = self.transform(sample)
116 |
117 | return sample, target
--------------------------------------------------------------------------------
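Usage sketch (illustrative; the constructor also eagerly loads the pretext features from ../../scan/results/tiny-imagenet/pretext/, so those .npy files must exist, and /path/to/tiny-imagenet-200 is a placeholder for the unzipped dataset root):

    import torchvision.transforms as T
    from pycls.datasets.tiny_imagenet import TinyImageNet

    train_ds = TinyImageNet('/path/to/tiny-imagenet-200', split='train',
                            transform=T.ToTensor(), test_transform=T.ToTensor())
    img, target = train_ds[0]

--------------------------------------------------------------------------------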
/deep-al/pycls/datasets/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .features import load_features
--------------------------------------------------------------------------------
/deep-al/pycls/datasets/utils/features.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | 
4 | DATASET_FEATURES_DICT = {
5 |     'train':
6 |     {
7 |         'CIFAR10':'../../scan/results/cifar-10/pretext/features_seed{seed}.npy',
8 |         'CIFAR100':'../../scan/results/cifar-100/pretext/features_seed{seed}.npy',
9 |         'TINYIMAGENET': '../../scan/results/tiny-imagenet/pretext/features_seed{seed}.npy',
10 |         'IMAGENET50': '../../dino/runs/trainfeat.pth',
11 |         'IMAGENET100': '../../dino/runs/trainfeat.pth',
12 |         'IMAGENET200': '../../dino/runs/trainfeat.pth',
13 |     },
14 |     'test':
15 |     {
16 |         'CIFAR10': '../../scan/results/cifar-10/pretext/test_features_seed{seed}.npy',
17 |         'CIFAR100': '../../scan/results/cifar-100/pretext/test_features_seed{seed}.npy',
18 |         'TINYIMAGENET': '../../scan/results/tiny-imagenet/pretext/test_features_seed{seed}.npy',
19 |         'IMAGENET50': '../../dino/runs/testfeat.pth',
20 |         'IMAGENET100': '../../dino/runs/testfeat.pth',
21 |         'IMAGENET200': '../../dino/runs/testfeat.pth',
22 |     }
23 | }
24 | 
25 | def load_features(ds_name, seed=1, train=True, normalized=True):
26 |     """Load pretrained features for a dataset."""
27 |     split = "train" if train else "test"
28 |     fname = DATASET_FEATURES_DICT[split][ds_name].format(seed=seed)
29 |     if fname.endswith('.npy'):
30 |         features = np.load(fname)
31 |     elif fname.endswith('.pth'):
32 |         features = torch.load(fname)
33 |     else:
34 |         raise Exception("Unsupported filetype")
35 |     if normalized:
36 |         features = features / np.linalg.norm(features, axis=1, keepdims=True)
37 |     return features
--------------------------------------------------------------------------------
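Usage sketch: with normalized=True every row is scaled to unit L2 norm, so cosine similarity between two samples reduces to a dot product (assumes the corresponding feature file has already been produced by the pretext stage):

    import numpy as np
    from pycls.datasets.utils import load_features

    feats = load_features('CIFAR10', seed=1, train=True, normalized=True)
    print(feats.shape, np.linalg.norm(feats[0]))  # (50000, D), 1.0

--------------------------------------------------------------------------------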
/deep-al/pycls/datasets/utils/gaussian_blur.py:
--------------------------------------------------------------------------------
1 | # Owner of this file: Thalles Silva
2 | # Source: https://github.com/sthalles/PyTorch-BYOL
3 | import torch
4 | from torchvision import transforms
5 | import torch.nn as nn
6 | import numpy as np
7 |
8 |
9 | class GaussianBlur(object):
10 |     """Blurs a single image on CPU."""
11 | 
12 |     def __init__(self, kernel_size):
13 |         radius = kernel_size // 2
14 |         kernel_size = radius * 2 + 1
15 |         self.blur_h = nn.Conv2d(3, 3, kernel_size=(kernel_size, 1),
16 |                                 stride=1, padding=0, bias=False, groups=3)
17 |         self.blur_v = nn.Conv2d(3, 3, kernel_size=(1, kernel_size),
18 |                                 stride=1, padding=0, bias=False, groups=3)
19 |         self.k = kernel_size
20 |         self.r = radius
21 | 
22 |         self.blur = nn.Sequential(
23 |             nn.ReflectionPad2d(radius),
24 | self.blur_h,
25 | self.blur_v
26 | )
27 |
28 | self.pil_to_tensor = transforms.ToTensor()
29 | self.tensor_to_pil = transforms.ToPILImage()
30 |
31 | def __call__(self, img):
32 | img = self.pil_to_tensor(img).unsqueeze(0)
33 |
34 | sigma = np.random.uniform(0.1, 2.0)
35 | x = np.arange(-self.r, self.r + 1)
36 | x = np.exp(-np.power(x, 2) / (2 * sigma * sigma))
37 | x = x / x.sum()
38 | x = torch.from_numpy(x).view(1, -1).repeat(3, 1)
39 |
40 | self.blur_h.weight.data.copy_(x.view(3, 1, self.k, 1))
41 | self.blur_v.weight.data.copy_(x.view(3, 1, 1, self.k))
42 |
43 | with torch.no_grad():
44 | img = self.blur(img)
45 | img = img.squeeze()
46 |
47 | img = self.tensor_to_pil(img)
48 |
49 | return img
--------------------------------------------------------------------------------
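Usage sketch: the transform takes and returns a PIL image, resampling sigma uniformly from [0.1, 2.0] on every call, so each epoch sees a different blur strength:

    from PIL import Image
    from pycls.datasets.utils.gaussian_blur import GaussianBlur

    blur = GaussianBlur(kernel_size=int(0.1 * 96))  # e.g. STL-10 sized inputs
    img = Image.new('RGB', (96, 96), color=(128, 64, 32))
    blurred = blur(img)  # still a 96x96 PIL image

--------------------------------------------------------------------------------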
/deep-al/pycls/datasets/utils/helpers.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import os
3 | 
4 | def load_imageset(path, set_name):
5 |     """
6 |     Returns the image set `set_name` present at `path` as a list.
7 |     Keyword arguments:
8 |     path -- path to data folder
9 |     set_name -- image set name - labeled or unlabeled.
10 |     """
11 |     with open(os.path.join(path, set_name + '.csv'), 'rt') as f:
12 |         return [r[0] for r in csv.reader(f)]
--------------------------------------------------------------------------------
/deep-al/pycls/models/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Expose model constructors."""
9 |
10 |
--------------------------------------------------------------------------------
/deep-al/pycls/models/alexnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
4 | from typing import Any
5 |
6 |
7 | __all__ = ['AlexNet', 'alexnet']
8 |
9 |
10 | model_urls = {
11 | 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
12 | }
13 |
14 |
15 | class AlexNet(nn.Module):
16 | '''
17 | AlexNet modified (features) for CIFAR10. Source: https://github.com/icpm/pytorch-cifar10/blob/master/models/AlexNet.py.
18 | '''
19 | def __init__(self, num_classes: int = 1000, use_dropout=False) -> None:
20 | super(AlexNet, self).__init__()
21 | self.use_dropout = use_dropout
22 | self.features = nn.Sequential(
23 | nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
24 | nn.ReLU(inplace=True),
25 | nn.MaxPool2d(kernel_size=2),
26 | nn.Conv2d(64, 192, kernel_size=3, padding=1),
27 | nn.ReLU(inplace=True),
28 | nn.MaxPool2d(kernel_size=2),
29 | nn.Conv2d(192, 384, kernel_size=3, padding=1),
30 | nn.ReLU(inplace=True),
31 | nn.Conv2d(384, 256, kernel_size=3, padding=1),
32 | nn.ReLU(inplace=True),
33 | nn.Conv2d(256, 256, kernel_size=3, padding=1),
34 | nn.ReLU(inplace=True),
35 | nn.MaxPool2d(kernel_size=2),
36 | )
37 | # self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
38 | self.fc_block = nn.Sequential(
39 | nn.Linear(256 * 2 * 2, 4096, bias=False),
40 | nn.BatchNorm1d(4096),
41 | nn.ReLU(inplace=True),
42 | nn.Linear(4096, 4096, bias=False),
43 | nn.BatchNorm1d(4096),
44 | nn.ReLU(inplace=True),
45 | )
46 | self.classifier = nn.Sequential(
47 | nn.Linear(4096, num_classes),
48 | )
49 | self.penultimate_active = False
50 | self.drop = nn.Dropout(p=0.5)
51 |
52 | def forward(self, x: torch.Tensor) -> torch.Tensor:
53 | x = self.features(x)
54 | # x = self.avgpool(x)
55 | z = torch.flatten(x, 1)
56 |         if self.use_dropout:
57 |             z = self.drop(z)  # apply dropout to the flattened features that feed fc_block
58 | z = self.fc_block(z)
59 | x = self.classifier(z)
60 | if self.penultimate_active:
61 | return z, x
62 | return x
63 |
64 |
65 | def alexnet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> AlexNet:
66 | r"""AlexNet model architecture from the
67 |     `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.
68 | Args:
69 | pretrained (bool): If True, returns a model pre-trained on ImageNet
70 | progress (bool): If True, displays a progress bar of the download to stderr
71 | """
72 | model = AlexNet(**kwargs)
73 | if pretrained:
74 | state_dict = load_state_dict_from_url(model_urls['alexnet'],
75 | progress=progress)
76 | model.load_state_dict(state_dict)
77 | return model
--------------------------------------------------------------------------------
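Usage sketch: setting penultimate_active makes forward() return the 4096-d embedding alongside the logits, which can serve as a representation for the AL code:

    import torch
    from pycls.models.alexnet import alexnet

    model = alexnet(num_classes=10)
    model.penultimate_active = True
    x = torch.randn(8, 3, 32, 32)          # CIFAR-sized input -> 256*2*2 conv features
    z, logits = model(x)
    print(z.shape, logits.shape)           # torch.Size([8, 4096]) torch.Size([8, 10])

--------------------------------------------------------------------------------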
/deep-al/pycls/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/deep-al/pycls/utils/__init__.py
--------------------------------------------------------------------------------
/deep-al/pycls/utils/benchmark.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Benchmarking functions."""
9 |
10 | import pycls.core.logging as logging
11 | import pycls.core.net as net
12 | import pycls.datasets.loader as loader
13 | import torch
14 | import torch.cuda.amp as amp
15 | from pycls.core.config import cfg
16 | from pycls.core.timer import Timer
17 |
18 |
19 | logger = logging.get_logger(__name__)
20 |
21 |
22 | @torch.no_grad()
23 | def compute_time_eval(model):
24 | """Computes precise model forward test time using dummy data."""
25 | # Use eval mode
26 | model.eval()
27 | # Generate a dummy mini-batch and copy data to GPU
28 | im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TEST.BATCH_SIZE / cfg.NUM_GPUS)
29 | inputs = torch.zeros(batch_size, 3, im_size, im_size).cuda(non_blocking=False)
30 | # Compute precise forward pass time
31 | timer = Timer()
32 | total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
33 | for cur_iter in range(total_iter):
34 | # Reset the timers after the warmup phase
35 | if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
36 | timer.reset()
37 | # Forward
38 | timer.tic()
39 | model(inputs)
40 | torch.cuda.synchronize()
41 | timer.toc()
42 | return timer.average_time
43 |
44 |
45 | def compute_time_train(model, loss_fun):
46 | """Computes precise model forward + backward time using dummy data."""
47 | # Use train mode
48 | model.train()
49 | # Generate a dummy mini-batch and copy data to GPU
50 | im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS)
51 | inputs = torch.rand(batch_size, 3, im_size, im_size).cuda(non_blocking=False)
52 | labels = torch.zeros(batch_size, dtype=torch.int64).cuda(non_blocking=False)
53 | labels_one_hot = net.smooth_one_hot_labels(labels)
54 | # Cache BatchNorm2D running stats
55 | bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
56 | bn_stats = [[bn.running_mean.clone(), bn.running_var.clone()] for bn in bns]
57 | # Create a GradScaler for mixed precision training
58 | scaler = amp.GradScaler(enabled=cfg.TRAIN.MIXED_PRECISION)
59 | # Compute precise forward backward pass time
60 | fw_timer, bw_timer = Timer(), Timer()
61 | total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
62 | for cur_iter in range(total_iter):
63 | # Reset the timers after the warmup phase
64 | if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
65 | fw_timer.reset()
66 | bw_timer.reset()
67 | # Forward
68 | fw_timer.tic()
69 | with amp.autocast(enabled=cfg.TRAIN.MIXED_PRECISION):
70 | preds = model(inputs)
71 | loss = loss_fun(preds, labels_one_hot)
72 | torch.cuda.synchronize()
73 | fw_timer.toc()
74 | # Backward
75 | bw_timer.tic()
76 | scaler.scale(loss).backward()
77 | torch.cuda.synchronize()
78 | bw_timer.toc()
79 | # Restore BatchNorm2D running stats
80 | for bn, (mean, var) in zip(bns, bn_stats):
81 | bn.running_mean, bn.running_var = mean, var
82 | return fw_timer.average_time, bw_timer.average_time
83 |
84 |
85 | def compute_time_loader(data_loader):
86 | """Computes loader time."""
87 | timer = Timer()
88 | loader.shuffle(data_loader, 0)
89 | data_loader_iterator = iter(data_loader)
90 | total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
91 | total_iter = min(total_iter, len(data_loader))
92 | for cur_iter in range(total_iter):
93 | if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
94 | timer.reset()
95 | timer.tic()
96 | next(data_loader_iterator)
97 | timer.toc()
98 | return timer.average_time
99 |
100 |
101 | def compute_time_model(model, loss_fun):
102 | """Times model."""
103 | logger.info("Computing model timings only...")
104 | # Compute timings
105 | test_fw_time = compute_time_eval(model)
106 | train_fw_time, train_bw_time = compute_time_train(model, loss_fun)
107 | train_fw_bw_time = train_fw_time + train_bw_time
108 | # Output iter timing
109 | iter_times = {
110 | "test_fw_time": test_fw_time,
111 | "train_fw_time": train_fw_time,
112 | "train_bw_time": train_bw_time,
113 | "train_fw_bw_time": train_fw_bw_time,
114 | }
115 | logger.info(logging.dump_log_data(iter_times, "iter_times"))
116 |
117 |
118 | def compute_time_full(model, loss_fun, train_loader, test_loader):
119 | """Times model and data loader."""
120 | logger.info("Computing model and loader timings...")
121 | # Compute timings
122 | test_fw_time = compute_time_eval(model)
123 | train_fw_time, train_bw_time = compute_time_train(model, loss_fun)
124 | train_fw_bw_time = train_fw_time + train_bw_time
125 | train_loader_time = compute_time_loader(train_loader)
126 | # Output iter timing
127 | iter_times = {
128 | "test_fw_time": test_fw_time,
129 | "train_fw_time": train_fw_time,
130 | "train_bw_time": train_bw_time,
131 | "train_fw_bw_time": train_fw_bw_time,
132 | "train_loader_time": train_loader_time,
133 | }
134 | logger.info(logging.dump_log_data(iter_times, "iter_times"))
135 | # Output epoch timing
136 | epoch_times = {
137 | "test_fw_time": test_fw_time * len(test_loader),
138 | "train_fw_time": train_fw_time * len(train_loader),
139 | "train_bw_time": train_bw_time * len(train_loader),
140 | "train_fw_bw_time": train_fw_bw_time * len(train_loader),
141 | "train_loader_time": train_loader_time * len(train_loader),
142 | }
143 | logger.info(logging.dump_log_data(epoch_times, "epoch_times"))
144 | # Compute data loader overhead (assuming DATA_LOADER.NUM_WORKERS>1)
145 | overhead = max(0, train_loader_time - train_fw_bw_time) / train_fw_bw_time
146 | logger.info("Overhead of data loader is {:.2f}%".format(overhead * 100))
147 |
--------------------------------------------------------------------------------
/deep-al/pycls/utils/checkpoint.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Functions that handle saving and loading of checkpoints."""
9 |
10 | import os
11 |
12 | # import pycls.core.distributed as dist
13 | import torch
14 | from pycls.core.net import unwrap_model
15 |
16 |
17 | # Common prefix for checkpoint file names
18 | _NAME_PREFIX = "model_epoch_"
19 |
20 | # Checkpoints directory name
21 | _DIR_NAME = "checkpoints"
22 |
23 |
24 | def get_checkpoint_dir(episode_dir):
25 | """Retrieves the location for storing checkpoints."""
26 | return os.path.join(episode_dir, _DIR_NAME)
27 |
28 |
29 | def get_checkpoint(epoch, episode_dir):
30 | """Retrieves the path to a checkpoint file."""
31 | name = "{}{:04d}.pyth".format(_NAME_PREFIX, epoch)
32 | # return os.path.join(get_checkpoint_dir(), name)
33 | return os.path.join(episode_dir, name)
34 |
35 |
36 | def get_checkpoint_best(episode_dir):
37 | """Retrieves the path to the best checkpoint file."""
38 | return os.path.join(episode_dir, "model.pyth")
39 |
40 |
41 | def get_last_checkpoint(episode_dir):
42 | """Retrieves the most recent checkpoint (highest epoch number)."""
43 | checkpoint_dir = get_checkpoint_dir(episode_dir)
44 | checkpoints = [f for f in os.listdir(checkpoint_dir) if _NAME_PREFIX in f]
45 | last_checkpoint_name = sorted(checkpoints)[-1]
46 | return os.path.join(checkpoint_dir, last_checkpoint_name)
47 |
48 |
49 | def has_checkpoint(episode_dir):
50 | """Determines if there are checkpoints available."""
51 | checkpoint_dir = get_checkpoint_dir(episode_dir)
52 | if not os.path.exists(checkpoint_dir):
53 | return False
54 | return any(_NAME_PREFIX in f for f in os.listdir(checkpoint_dir))
55 |
56 |
57 | def save_checkpoint(info, model_state, optimizer_state, epoch, cfg):
58 |     """Saves a checkpoint."""
59 | 
60 | # Save checkpoints only from the master process
61 | # if not dist.is_master_proc():
62 | # return
63 | # Ensure that the checkpoint dir exists
64 | os.makedirs(cfg.EPISODE_DIR, exist_ok=True)
65 |
66 | # Record the state
67 | checkpoint = {
68 | "epoch": epoch,
69 | "model_state": model_state,
70 | "optimizer_state": optimizer_state,
71 | "cfg": cfg.dump(),
72 | }
73 | global _NAME_PREFIX
74 | _NAME_PREFIX = info + '_' + _NAME_PREFIX
75 |
76 | # Write the checkpoint
77 | checkpoint_file = get_checkpoint(epoch, cfg.EPISODE_DIR)
78 | torch.save(checkpoint, checkpoint_file)
79 | # print("Model checkpoint saved at path: {}".format(checkpoint_file))
80 |
81 | _NAME_PREFIX = 'model_epoch_'
82 | return checkpoint_file
83 |
84 |
85 | def load_checkpoint(checkpoint_file, model, optimizer=None):
86 | """Loads the checkpoint from the given file."""
87 | err_str = "Checkpoint '{}' not found"
88 | assert os.path.exists(checkpoint_file), err_str.format(checkpoint_file)
89 | checkpoint = torch.load(checkpoint_file, map_location="cpu")
90 | unwrap_model(model).load_state_dict(checkpoint["model_state"])
91 |     if optimizer: optimizer.load_state_dict(checkpoint["optimizer_state"])
92 | return model
93 |
94 |
95 | def delete_checkpoints(checkpoint_dir=None, keep="all"):
96 | """Deletes unneeded checkpoints, keep can be "all", "last", or "none"."""
97 | assert keep in ["all", "last", "none"], "Invalid keep setting: {}".format(keep)
98 | checkpoint_dir = checkpoint_dir if checkpoint_dir else get_checkpoint_dir()
99 | if keep == "all" or not os.path.exists(checkpoint_dir):
100 | return 0
101 | checkpoints = [f for f in os.listdir(checkpoint_dir) if _NAME_PREFIX in f]
102 | checkpoints = sorted(checkpoints)[:-1] if keep == "last" else checkpoints
103 | [os.remove(os.path.join(checkpoint_dir, checkpoint)) for checkpoint in checkpoints]
104 | return len(checkpoints)
105 |
--------------------------------------------------------------------------------
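Usage sketch for the save/load round trip (illustrative values; `info` is the prefix prepended to the checkpoint name, and cfg.EPISODE_DIR is assumed to be set per AL episode):

    import torch
    import torch.nn as nn
    from pycls.core.config import cfg
    import pycls.utils.checkpoint as cu

    model = nn.Linear(4, 2)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    cfg.EPISODE_DIR = './episode_0'  # hypothetical episode directory
    path = cu.save_checkpoint('best', model.state_dict(), opt.state_dict(), 10, cfg)
    model = cu.load_checkpoint(path, model, opt)

--------------------------------------------------------------------------------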
/deep-al/pycls/utils/distributed.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Distributed helpers."""
9 |
10 | import multiprocessing
11 | import os
12 | import random
13 | import signal
14 | import threading
15 | import traceback
16 |
17 | import torch
18 | from pycls.core.config import cfg
19 |
20 |
21 | # Make work w recent PyTorch versions (https://github.com/pytorch/pytorch/issues/37377)
22 | os.environ["MKL_THREADING_LAYER"] = "GNU"
23 |
24 |
25 | def is_master_proc():
26 | """Determines if the current process is the master process.
27 |
28 | Master process is responsible for logging, writing and loading checkpoints. In
29 | the multi GPU setting, we assign the master role to the rank 0 process. When
30 | training using a single GPU, there is a single process which is considered master.
31 | """
32 | return cfg.NUM_GPUS == 1 or torch.distributed.get_rank() == 0
33 |
34 |
35 | def init_process_group(proc_rank, world_size, port):
36 | """Initializes the default process group."""
37 | # Set the GPU to use
38 | torch.cuda.set_device(proc_rank)
39 | # Initialize the process group
40 | torch.distributed.init_process_group(
41 | backend=cfg.DIST_BACKEND,
42 | init_method="tcp://{}:{}".format(cfg.HOST, port),
43 | world_size=world_size,
44 | rank=proc_rank,
45 | )
46 |
47 |
48 | def destroy_process_group():
49 | """Destroys the default process group."""
50 | torch.distributed.destroy_process_group()
51 |
52 |
53 | def scaled_all_reduce(tensors):
54 | """Performs the scaled all_reduce operation on the provided tensors.
55 |
56 | The input tensors are modified in-place. Currently supports only the sum
57 | reduction operator. The reduced values are scaled by the inverse size of the
58 | process group (equivalent to cfg.NUM_GPUS).
59 | """
60 | # There is no need for reduction in the single-proc case
61 | if cfg.NUM_GPUS == 1:
62 | return tensors
63 | # Queue the reductions
64 | reductions = []
65 | for tensor in tensors:
66 | reduction = torch.distributed.all_reduce(tensor, async_op=True)
67 | reductions.append(reduction)
68 | # Wait for reductions to finish
69 | for reduction in reductions:
70 | reduction.wait()
71 | # Scale the results
72 | for tensor in tensors:
73 | tensor.mul_(1.0 / cfg.NUM_GPUS)
74 | return tensors
75 |
76 |
77 | class ChildException(Exception):
78 | """Wraps an exception from a child process."""
79 |
80 | def __init__(self, child_trace):
81 | super(ChildException, self).__init__(child_trace)
82 |
83 |
84 | class ErrorHandler(object):
85 | """Multiprocessing error handler (based on fairseq's).
86 |
87 | Listens for errors in child processes and propagates the tracebacks to the parent.
88 | """
89 |
90 | def __init__(self, error_queue):
91 | # Shared error queue
92 | self.error_queue = error_queue
93 | # Children processes sharing the error queue
94 | self.children_pids = []
95 | # Start a thread listening to errors
96 | self.error_listener = threading.Thread(target=self.listen, daemon=True)
97 | self.error_listener.start()
98 | # Register the signal handler
99 | signal.signal(signal.SIGUSR1, self.signal_handler)
100 |
101 | def add_child(self, pid):
102 | """Registers a child process."""
103 | self.children_pids.append(pid)
104 |
105 | def listen(self):
106 | """Listens for errors in the error queue."""
107 | # Wait until there is an error in the queue
108 | child_trace = self.error_queue.get()
109 | # Put the error back for the signal handler
110 | self.error_queue.put(child_trace)
111 | # Invoke the signal handler
112 | os.kill(os.getpid(), signal.SIGUSR1)
113 |
114 | def signal_handler(self, _sig_num, _stack_frame):
115 | """Signal handler."""
116 | # Kill children processes
117 | for pid in self.children_pids:
118 | os.kill(pid, signal.SIGINT)
119 | # Propagate the error from the child process
120 | raise ChildException(self.error_queue.get())
121 |
122 |
123 | def run(proc_rank, world_size, port, error_queue, fun, fun_args, fun_kwargs):
124 | """Runs a function from a child process."""
125 | try:
126 | # Initialize the process group
127 | init_process_group(proc_rank, world_size, port)
128 | # Run the function
129 | fun(*fun_args, **fun_kwargs)
130 | except KeyboardInterrupt:
131 | # Killed by the parent process
132 | pass
133 | except Exception:
134 | # Propagate exception to the parent process
135 | error_queue.put(traceback.format_exc())
136 | finally:
137 | # Destroy the process group
138 | destroy_process_group()
139 |
140 |
141 | def multi_proc_run(num_proc, fun, fun_args=(), fun_kwargs=None):
142 | """Runs a function in a multi-proc setting (unless num_proc == 1)."""
143 | # There is no need for multi-proc in the single-proc case
144 | fun_kwargs = fun_kwargs if fun_kwargs else {}
145 | if num_proc == 1:
146 | fun(*fun_args, **fun_kwargs)
147 | return
148 | # Handle errors from training subprocesses
149 | error_queue = multiprocessing.SimpleQueue()
150 | error_handler = ErrorHandler(error_queue)
151 | # Get a random port to use (without using global random number generator)
152 | port = random.Random().randint(cfg.PORT_RANGE[0], cfg.PORT_RANGE[1])
153 | # Run each training subprocess
154 | ps = []
155 | for i in range(num_proc):
156 | p_i = multiprocessing.Process(
157 | target=run, args=(i, num_proc, port, error_queue, fun, fun_args, fun_kwargs)
158 | )
159 | ps.append(p_i)
160 | p_i.start()
161 | error_handler.add_child(p_i.pid)
162 | # Wait for each subprocess to finish
163 | for p in ps:
164 | p.join()
165 |
--------------------------------------------------------------------------------
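Usage sketch: multi_proc_run is the entry point that spawns one process per GPU and wires up the error handler; with num_proc == 1 it simply calls fun inline:

    import pycls.utils.distributed as du

    def train_fn(msg):
        print(msg)

    du.multi_proc_run(num_proc=1, fun=train_fn, fun_args=('hello',))

--------------------------------------------------------------------------------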
/deep-al/pycls/utils/io.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """IO utilities (adapted from Detectron)"""
9 |
10 | import logging
11 | import os
12 | import re
13 | import sys
14 | from urllib import request as urlrequest
15 |
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 | _PYCLS_BASE_URL = "https://dl.fbaipublicfiles.com/pycls"
20 |
21 |
22 | def cache_url(url_or_file, cache_dir, base_url=_PYCLS_BASE_URL):
23 | """Download the file specified by the URL to the cache_dir and return the path to
24 | the cached file. If the argument is not a URL, simply return it as is.
25 | """
26 | is_url = re.match(r"^(?:http)s?://", url_or_file, re.IGNORECASE) is not None
27 | if not is_url:
28 | return url_or_file
29 | url = url_or_file
30 | assert url.startswith(base_url), "url must start with: {}".format(base_url)
31 | cache_file_path = url.replace(base_url, cache_dir)
32 | if os.path.exists(cache_file_path):
33 | return cache_file_path
34 | cache_file_dir = os.path.dirname(cache_file_path)
35 | if not os.path.exists(cache_file_dir):
36 | os.makedirs(cache_file_dir)
37 | logger.info("Downloading remote file {} to {}".format(url, cache_file_path))
38 | download_url(url, cache_file_path)
39 | return cache_file_path
40 |
41 |
42 | def _progress_bar(count, total):
43 | """Report download progress. Credit:
44 | https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console/27871113
45 | """
46 | bar_len = 60
47 | filled_len = int(round(bar_len * count / float(total)))
48 | percents = round(100.0 * count / float(total), 1)
49 | bar = "=" * filled_len + "-" * (bar_len - filled_len)
50 | sys.stdout.write(
51 | " [{}] {}% of {:.1f}MB file \r".format(bar, percents, total / 1024 / 1024)
52 | )
53 | sys.stdout.flush()
54 | if count >= total:
55 | sys.stdout.write("\n")
56 |
57 |
58 | def download_url(url, dst_file_path, chunk_size=8192, progress_hook=_progress_bar):
59 | """Download url and write it to dst_file_path. Credit:
60 | https://stackoverflow.com/questions/2028517/python-urllib2-progress-hook
61 | """
62 | req = urlrequest.Request(url)
63 | response = urlrequest.urlopen(req)
64 | total_size = response.info().get("Content-Length").strip()
65 | total_size = int(total_size)
66 | bytes_so_far = 0
67 | with open(dst_file_path, "wb") as f:
68 | while 1:
69 | chunk = response.read(chunk_size)
70 | bytes_so_far += len(chunk)
71 | if not chunk:
72 | break
73 | if progress_hook:
74 | progress_hook(bytes_so_far, total_size)
75 | f.write(chunk)
76 | return bytes_so_far
77 |
--------------------------------------------------------------------------------
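Usage sketch: cache_url is a no-op for local paths and otherwise mirrors a pycls URL into the cache directory (the remote model path below is hypothetical):

    from pycls.utils.io import cache_url

    local = cache_url('/tmp/model.pyth', '/tmp/pycls-cache')       # returned as-is
    # cached = cache_url('https://dl.fbaipublicfiles.com/pycls/<model>.pyth',
    #                    '/tmp/pycls-cache')                       # downloads once

--------------------------------------------------------------------------------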
/deep-al/pycls/utils/logging.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Logging."""
9 |
10 | import builtins
11 | import decimal
12 | import logging
13 | import os
14 | import simplejson
15 | import sys
16 |
17 | # import pycls.utils.distributed as du
18 |
19 | # Show filename and line number in logs
20 | _FORMAT = '[%(asctime)s %(filename)s: %(lineno)3d]: %(message)s'
21 |
22 | # Log file name (for cfg.LOG_DEST = 'file')
23 | _LOG_FILE = 'stdout.log'
24 |
25 | # Printed json stats lines will be tagged w/ this
26 | _TAG = 'json_stats: '
27 |
28 |
29 | def _suppress_print():
30 | """Suppresses printing from the current process."""
31 | def ignore(*_objects, _sep=' ', _end='\n', _file=sys.stdout, _flush=False):
32 | pass
33 | builtins.print = ignore
34 |
35 |
36 | def setup_logging(cfg):
37 | """Sets up the logging."""
38 | # Enable logging only for the master process
39 | # if du.is_master_proc():
40 | if True:
41 | # Clear the root logger to prevent any existing logging config
42 | # (e.g. set by another module) from messing with our setup
43 | logging.root.handlers = []
44 | # Construct logging configuration
45 | logging_config = {
46 | 'level': logging.INFO,
47 | 'format': _FORMAT,
48 | 'datefmt': '%Y-%m-%d %H:%M:%S'
49 | }
50 | # Log either to stdout or to a file
51 | if cfg.LOG_DEST == 'stdout':
52 | logging_config['stream'] = sys.stdout
53 | else:
54 | logging_config['filename'] = os.path.join(cfg.EXP_DIR, _LOG_FILE)
55 | # Configure logging
56 | logging.basicConfig(**logging_config)
57 | else:
58 | _suppress_print()
59 |
60 |
61 | def get_logger(name):
62 | """Retrieves the logger."""
63 | return logging.getLogger(name)
64 |
65 |
66 | def log_json_stats(stats):
67 | """Logs json stats."""
68 | # Decimal + string workaround for having fixed len float vals in logs
69 | stats = {
70 | k: decimal.Decimal('{:.12f}'.format(v)) if isinstance(v, float) else v
71 | for k, v in stats.items()
72 | }
73 | json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True)
74 | logger = get_logger(__name__)
75 | logger.info('{:s}{:s}'.format(_TAG, json_stats))
76 |
77 |
78 | def load_json_stats(log_file):
79 | """Loads json_stats from a single log file."""
80 | with open(log_file, 'r') as f:
81 | lines = f.readlines()
82 | json_lines = [l[l.find(_TAG) + len(_TAG):] for l in lines if _TAG in l]
83 | json_stats = [simplejson.loads(l) for l in json_lines]
84 | return json_stats
85 |
86 |
87 | def parse_json_stats(log, row_type, key):
88 | """Extract values corresponding to row_type/key out of log."""
89 | vals = [row[key] for row in log if row['_type'] == row_type and key in row]
90 | if key == 'iter' or key == 'epoch':
91 | vals = [int(val.split('/')[0]) for val in vals]
92 | return vals
93 |
94 |
95 | def get_log_files(log_dir, name_filter=''):
96 | """Get all log files in directory containing subdirs of trained models."""
97 | names = [n for n in sorted(os.listdir(log_dir)) if name_filter in n]
98 | files = [os.path.join(log_dir, n, _LOG_FILE) for n in names]
99 | f_n_ps = [(f, n) for (f, n) in zip(files, names) if os.path.exists(f)]
100 | files, names = zip(*f_n_ps)
101 | return files, names
102 |
--------------------------------------------------------------------------------
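Usage sketch for pulling metrics back out of a log (the log path and the 'test_epoch'/'top1_err' keys follow the json_stats rows the trainer emits; treat them as assumptions if your logs differ):

    import pycls.utils.logging as lu

    stats = lu.load_json_stats('./output/stdout.log')  # hypothetical log file
    errs = lu.parse_json_stats(stats, row_type='test_epoch', key='top1_err')
    epochs = lu.parse_json_stats(stats, row_type='test_epoch', key='epoch')

--------------------------------------------------------------------------------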
/deep-al/pycls/utils/metrics.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Functions for computing metrics."""
9 |
10 | import numpy as np
11 | import torch
12 | import torch.nn as nn
13 |
14 | from pycls.core.config import cfg
15 |
16 | # Number of bytes in a megabyte
17 | _B_IN_MB = 1024 * 1024
18 |
19 |
20 | def topks_correct(preds, labels, ks):
21 | """Computes the number of top-k correct predictions for each k."""
22 | assert preds.size(0) == labels.size(0), \
23 | 'Batch dim of predictions and labels must match'
24 | # Find the top max_k predictions for each sample
25 | _top_max_k_vals, top_max_k_inds = torch.topk(
26 | preds, max(ks), dim=1, largest=True, sorted=True
27 | )
28 | # (batch_size, max_k) -> (max_k, batch_size)
29 | top_max_k_inds = top_max_k_inds.t()
30 | # (batch_size, ) -> (max_k, batch_size)
31 | rep_max_k_labels = labels.view(1, -1).expand_as(top_max_k_inds)
32 | # (i, j) = 1 if top i-th prediction for the j-th sample is correct
33 | top_max_k_correct = top_max_k_inds.eq(rep_max_k_labels)
34 | # Compute the number of topk correct predictions for each k
35 | topks_correct = [
36 | top_max_k_correct[:k, :].reshape(-1).float().sum() for k in ks
37 | ]
38 | return topks_correct
39 |
40 |
41 | def topk_errors(preds, labels, ks):
42 | """Computes the top-k error for each k."""
43 | num_topks_correct = topks_correct(preds, labels, ks)
44 | return [(1.0 - x / preds.size(0)) * 100.0 for x in num_topks_correct]
45 |
46 |
47 | def topk_accuracies(preds, labels, ks):
48 | """Computes the top-k accuracy for each k."""
49 | num_topks_correct = topks_correct(preds, labels, ks)
50 | return [(x / preds.size(0)) * 100.0 for x in num_topks_correct]
51 |
52 |
53 | def params_count(model):
54 | """Computes the number of parameters."""
55 | return np.sum([p.numel() for p in model.parameters()]).item()
56 |
57 |
58 | def flops_count(model):
59 | """Computes the number of flops statically."""
60 | h, w = cfg.TRAIN.IM_SIZE, cfg.TRAIN.IM_SIZE
61 | count = 0
62 | for n, m in model.named_modules():
63 | if isinstance(m, nn.Conv2d):
64 | if 'se.' in n:
65 | count += m.in_channels * m.out_channels + m.bias.numel()
66 | continue
67 | h_out = (h + 2 * m.padding[0] - m.kernel_size[0]) // m.stride[0] + 1
68 | w_out = (w + 2 * m.padding[1] - m.kernel_size[1]) // m.stride[1] + 1
69 | count += np.prod([
70 | m.weight.numel(),
71 | h_out, w_out
72 | ])
73 | if '.proj' not in n:
74 | h, w = h_out, w_out
75 | elif isinstance(m, nn.MaxPool2d):
76 | h = (h + 2 * m.padding - m.kernel_size) // m.stride + 1
77 | w = (w + 2 * m.padding - m.kernel_size) // m.stride + 1
78 | elif isinstance(m, nn.Linear):
79 | count += m.in_features * m.out_features
80 | return count.item()
81 |
82 |
83 | def gpu_mem_usage():
84 | """Computes the GPU memory usage for the current device (MB)."""
85 | mem_usage_bytes = torch.cuda.max_memory_allocated()
86 | return mem_usage_bytes / _B_IN_MB
87 |
--------------------------------------------------------------------------------
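Worked example: topk_accuracies counts a sample as top-k correct when the true label appears among the k highest-scoring classes:

    import torch
    from pycls.utils.metrics import topk_accuracies

    preds = torch.tensor([[0.1, 0.7, 0.2],
                          [0.6, 0.3, 0.1]])
    labels = torch.tensor([2, 0])
    top1, top2 = topk_accuracies(preds, labels, ks=[1, 2])
    print(top1.item(), top2.item())  # 50.0 100.0

--------------------------------------------------------------------------------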
/deep-al/pycls/utils/net.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Functions for manipulating networks."""
9 |
10 | import itertools
11 | import math
12 | import torch
13 | import torch.nn as nn
14 |
15 | from pycls.core.config import cfg
16 |
17 |
18 | def init_weights(m):
19 | """Performs ResNet-style weight initialization."""
20 | if isinstance(m, nn.Conv2d):
21 | # Note that there is no bias due to BN
22 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
23 | m.weight.data.normal_(mean=0.0, std=math.sqrt(2.0 / fan_out))
24 | elif isinstance(m, nn.BatchNorm2d):
25 | zero_init_gamma = (
26 | hasattr(m, 'final_bn') and m.final_bn and
27 | cfg.BN.ZERO_INIT_FINAL_GAMMA
28 | )
29 | m.weight.data.fill_(0.0 if zero_init_gamma else 1.0)
30 | m.bias.data.zero_()
31 | elif isinstance(m, nn.Linear):
32 | m.weight.data.normal_(mean=0.0, std=0.01)
33 | m.bias.data.zero_()
34 |
35 |
36 | @torch.no_grad()
37 | def compute_precise_bn_stats(model, loader):
38 | """Computes precise BN stats on training data."""
39 | # Compute the number of minibatches to use
40 | num_iter = min(cfg.BN.NUM_SAMPLES_PRECISE // loader.batch_size, len(loader))
41 | # Retrieve the BN layers
42 | bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
43 | # Initialize stats storage
44 | mus = [torch.zeros_like(bn.running_mean) for bn in bns]
45 | sqs = [torch.zeros_like(bn.running_var) for bn in bns]
46 | # Remember momentum values
47 | moms = [bn.momentum for bn in bns]
48 | # Disable momentum
49 | for bn in bns:
50 | bn.momentum = 1.0
51 | # Accumulate the stats across the data samples
52 | for inputs, _labels in itertools.islice(loader, num_iter):
53 | model(inputs.cuda())
54 | # Accumulate the stats for each BN layer
55 | for i, bn in enumerate(bns):
56 | m, v = bn.running_mean, bn.running_var
57 | sqs[i] += (v + m * m) / num_iter
58 | mus[i] += m / num_iter
59 | # Set the stats and restore momentum values
60 | for i, bn in enumerate(bns):
61 | bn.running_var = sqs[i] - mus[i] * mus[i]
62 | bn.running_mean = mus[i]
63 | bn.momentum = moms[i]
64 |
65 |
66 | def reset_bn_stats(model):
67 | """Resets running BN stats."""
68 | for m in model.modules():
69 | if isinstance(m, torch.nn.BatchNorm2d):
70 | m.reset_running_stats()
71 |
72 |
73 | def drop_connect(x, drop_ratio):
74 | """Drop connect (adapted from DARTS)."""
75 | keep_ratio = 1.0 - drop_ratio
76 | mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device)
77 | mask.bernoulli_(keep_ratio)
78 | x.div_(keep_ratio)
79 | x.mul_(mask)
80 | return x
81 |
82 |
83 | def get_flat_weights(model):
84 | """Gets all model weights as a single flat vector."""
85 | return torch.cat([p.data.view(-1, 1) for p in model.parameters()], 0)
86 |
87 |
88 | def set_flat_weights(model, flat_weights):
89 | """Sets all model weights from a single flat vector."""
90 | k = 0
91 | for p in model.parameters():
92 | n = p.data.numel()
93 | p.data.copy_(flat_weights[k:(k + n)].view_as(p.data))
94 | k += n
95 | assert k == flat_weights.numel()
96 |
--------------------------------------------------------------------------------
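Usage sketch: get_flat_weights / set_flat_weights form an exact round trip, which is handy for weight-space arithmetic such as parameter averaging:

    import torch.nn as nn
    import pycls.utils.net as nu

    model = nn.Linear(3, 2)
    w = nu.get_flat_weights(model)   # shape (8, 1): 6 weights + 2 biases
    nu.set_flat_weights(model, w)    # restores the identical parameters

--------------------------------------------------------------------------------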
/deep-al/pycls/utils/plotting.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Plotting functions."""
9 |
10 | import colorlover as cl
11 | import matplotlib.pyplot as plt
12 | import plotly.graph_objs as go
13 | import plotly.offline as offline
14 | import pycls.core.logging as logging
15 |
16 |
17 | def get_plot_colors(max_colors, color_format="pyplot"):
18 | """Generate colors for plotting."""
19 | colors = cl.scales["11"]["qual"]["Paired"]
20 | if max_colors > len(colors):
21 | colors = cl.to_rgb(cl.interp(colors, max_colors))
22 | if color_format == "pyplot":
23 | return [[j / 255.0 for j in c] for c in cl.to_numeric(colors)]
24 | return colors
25 |
26 |
27 | def prepare_plot_data(log_files, names, metric="top1_err"):
28 | """Load logs and extract data for plotting error curves."""
29 | plot_data = []
30 | for file, name in zip(log_files, names):
31 | d, data = {}, logging.sort_log_data(logging.load_log_data(file))
32 | for phase in ["train", "test"]:
33 | x = data[phase + "_epoch"]["epoch_ind"]
34 | y = data[phase + "_epoch"][metric]
35 | d["x_" + phase], d["y_" + phase] = x, y
36 | d[phase + "_label"] = "[{:5.2f}] ".format(min(y) if y else 0) + name
37 | plot_data.append(d)
38 | assert len(plot_data) > 0, "No data to plot"
39 | return plot_data
40 |
41 |
42 | def plot_error_curves_plotly(log_files, names, filename, metric="top1_err"):
43 | """Plot error curves using plotly and save to file."""
44 | plot_data = prepare_plot_data(log_files, names, metric)
45 | colors = get_plot_colors(len(plot_data), "plotly")
46 | # Prepare data for plots (3 sets, train duplicated w and w/o legend)
47 | data = []
48 | for i, d in enumerate(plot_data):
49 | s = str(i)
50 | line_train = {"color": colors[i], "dash": "dashdot", "width": 1.5}
51 | line_test = {"color": colors[i], "dash": "solid", "width": 1.5}
52 | data.append(
53 | go.Scatter(
54 | x=d["x_train"],
55 | y=d["y_train"],
56 | mode="lines",
57 | name=d["train_label"],
58 | line=line_train,
59 | legendgroup=s,
60 | visible=True,
61 | showlegend=False,
62 | )
63 | )
64 | data.append(
65 | go.Scatter(
66 | x=d["x_test"],
67 | y=d["y_test"],
68 | mode="lines",
69 | name=d["test_label"],
70 | line=line_test,
71 | legendgroup=s,
72 | visible=True,
73 | showlegend=True,
74 | )
75 | )
76 | data.append(
77 | go.Scatter(
78 | x=d["x_train"],
79 | y=d["y_train"],
80 | mode="lines",
81 | name=d["train_label"],
82 | line=line_train,
83 | legendgroup=s,
84 | visible=False,
85 | showlegend=True,
86 | )
87 | )
88 | # Prepare layout w ability to toggle 'all', 'train', 'test'
89 | titlefont = {"size": 18, "color": "#7f7f7f"}
90 | vis = [[True, True, False], [False, False, True], [False, True, False]]
91 | buttons = zip(["all", "train", "test"], [[{"visible": v}] for v in vis])
92 | buttons = [{"label": b, "args": v, "method": "update"} for b, v in buttons]
93 | layout = go.Layout(
94 |         title=metric + " vs. epoch<br>[dash=train, solid=test]",
95 | xaxis={"title": "epoch", "titlefont": titlefont},
96 | yaxis={"title": metric, "titlefont": titlefont},
97 | showlegend=True,
98 | hoverlabel={"namelength": -1},
99 | updatemenus=[
100 | {
101 | "buttons": buttons,
102 | "direction": "down",
103 | "showactive": True,
104 | "x": 1.02,
105 | "xanchor": "left",
106 | "y": 1.08,
107 | "yanchor": "top",
108 | }
109 | ],
110 | )
111 | # Create plotly plot
112 | offline.plot({"data": data, "layout": layout}, filename=filename)
113 |
114 |
115 | def plot_error_curves_pyplot(log_files, names, filename=None, metric="top1_err"):
116 | """Plot error curves using matplotlib.pyplot and save to file."""
117 | plot_data = prepare_plot_data(log_files, names, metric)
118 | colors = get_plot_colors(len(names))
119 | for ind, d in enumerate(plot_data):
120 | c, lbl = colors[ind], d["test_label"]
121 | plt.plot(d["x_train"], d["y_train"], "--", c=c, alpha=0.8)
122 | plt.plot(d["x_test"], d["y_test"], "-", c=c, alpha=0.8, label=lbl)
123 | plt.title(metric + " vs. epoch\n[dash=train, solid=test]", fontsize=14)
124 | plt.xlabel("epoch", fontsize=14)
125 | plt.ylabel(metric, fontsize=14)
126 | plt.grid(alpha=0.4)
127 | plt.legend()
128 | if filename:
129 | plt.savefig(filename)
130 | plt.clf()
131 | else:
132 | plt.show()
133 |
--------------------------------------------------------------------------------
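A minimal usage sketch for the plotting helpers above, assuming `deep-al/` is on `sys.path` so that `pycls` is importable; the log paths and run names are placeholders, not files from this repo:

```python
from pycls.utils.plotting import plot_error_curves_pyplot

# Compare two runs; the logs must be pycls-style stdout logs that
# logging.load_log_data() can parse (paths here are placeholders).
plot_error_curves_pyplot(
    log_files=["runs/baseline/stdout.log", "runs/typiclust/stdout.log"],
    names=["baseline", "typiclust"],
    filename="error_curves.png",  # omit to display the figure interactively
    metric="top1_err",
)
```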
/deep-al/pycls/utils/timer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Timer."""
9 |
10 | import time
11 |
12 |
13 | class Timer(object):
14 | """A simple timer (adapted from Detectron)."""
15 |
16 | def __init__(self):
17 | self.total_time = None
18 | self.calls = None
19 | self.start_time = None
20 | self.diff = None
21 | self.average_time = None
22 | self.reset()
23 |
24 | def tic(self):
25 | # using time.time as time.clock does not normalize for multithreading
26 | self.start_time = time.time()
27 |
28 | def toc(self):
29 | self.diff = time.time() - self.start_time
30 | self.total_time += self.diff
31 | self.calls += 1
32 | self.average_time = self.total_time / self.calls
33 |
34 | def reset(self):
35 | self.total_time = 0.0
36 | self.calls = 0
37 | self.start_time = 0.0
38 | self.diff = 0.0
39 | self.average_time = 0.0
40 |
--------------------------------------------------------------------------------
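A quick sketch of the `Timer` above in use, timing a stand-in workload (assumes `deep-al/` is on `sys.path`):

```python
import time

from pycls.utils.timer import Timer

timer = Timer()
for _ in range(3):
    timer.tic()       # start of one iteration
    time.sleep(0.01)  # stand-in for real work
    timer.toc()       # updates total_time, calls, average_time
print(timer.calls, timer.total_time, timer.average_time)
```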
/deep-al/requirements.txt:
--------------------------------------------------------------------------------
1 | black==19.3b0
2 | flake8==3.8.4
3 | isort==4.3.21
4 | matplotlib==3.3.4
5 | numpy
6 | opencv-python==4.2.0.34
7 | torch==1.7.1
8 | torchvision==0.8.2
9 | parameterized
10 | setuptools
11 | simplejson
12 | yacs
--------------------------------------------------------------------------------
/deep-al/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/deep-al/tools/__init__.py
--------------------------------------------------------------------------------
/probcover_selection.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/probcover_selection.gif
--------------------------------------------------------------------------------
/probcover_semi.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/probcover_semi.png
--------------------------------------------------------------------------------
/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/results.png
--------------------------------------------------------------------------------
/scan/configs/env.yml:
--------------------------------------------------------------------------------
1 | root_dir: ./results/
2 |
--------------------------------------------------------------------------------
/scan/configs/pretext/moco_imagenet100.yml:
--------------------------------------------------------------------------------
1 | # Setup
2 | setup: moco # MoCo is used here
3 |
4 | # Model
5 | backbone: resnet50
6 | model_kwargs:
7 | head: mlp
8 | features_dim: 128
9 |
10 | # Dataset
11 | train_db_name: imagenet_100
12 | val_db_name: imagenet_100
13 | num_classes: 100
14 | temperature: 0.07
15 |
16 | # Batch size and workers
17 | batch_size: 256
18 | num_workers: 8
19 |
20 | # Transformations
21 | transformation_kwargs:
22 | crop_size: 224
23 | normalize:
24 | mean: [0.485, 0.456, 0.406]
25 | std: [0.229, 0.224, 0.225]
26 |
--------------------------------------------------------------------------------
/scan/configs/pretext/moco_imagenet200.yml:
--------------------------------------------------------------------------------
1 | # Setup
2 | setup: moco # MoCo is used here
3 |
4 | # Model
5 | backbone: resnet50
6 | model_kwargs:
7 | head: mlp
8 | features_dim: 128
9 |
10 | # Dataset
11 | train_db_name: imagenet_200
12 | val_db_name: imagenet_200
13 | num_classes: 200
14 | temperature: 0.07
15 |
16 | # Batch size and workers
17 | batch_size: 256
18 | num_workers: 8
19 |
20 | # Transformations
21 | transformation_kwargs:
22 | crop_size: 224
23 | normalize:
24 | mean: [0.485, 0.456, 0.406]
25 | std: [0.229, 0.224, 0.225]
26 |
--------------------------------------------------------------------------------
/scan/configs/pretext/moco_imagenet50.yml:
--------------------------------------------------------------------------------
1 | # Setup
2 | setup: moco # MoCo is used here
3 |
4 | # Model
5 | backbone: resnet50
6 | model_kwargs:
7 | head: mlp
8 | features_dim: 128
9 |
10 | # Dataset
11 | train_db_name: imagenet_50
12 | val_db_name: imagenet_50
13 | num_classes: 50
14 | temperature: 0.07
15 |
16 | # Batch size and workers
17 | batch_size: 256
18 | num_workers: 8
19 |
20 | # Transformations
21 | transformation_kwargs:
22 | crop_size: 224
23 | normalize:
24 | mean: [0.485, 0.456, 0.406]
25 | std: [0.229, 0.224, 0.225]
26 |
--------------------------------------------------------------------------------
/scan/configs/pretext/simclr_cifar10.yml:
--------------------------------------------------------------------------------
1 | # Setup
2 | setup: simclr
3 |
4 | # Model
5 | backbone: resnet18
6 | model_kwargs:
7 | head: mlp
8 | features_dim: 128
9 |
10 | # Dataset
11 | train_db_name: cifar-10
12 | val_db_name: cifar-10
13 | num_classes: 10
14 |
15 | # Loss
16 | criterion: simclr
17 | criterion_kwargs:
18 | temperature: 0.1
19 |
20 | # Hyperparameters
21 | epochs: 500
22 | optimizer: sgd
23 | optimizer_kwargs:
24 | nesterov: False
25 | weight_decay: 0.0001
26 | momentum: 0.9
27 | lr: 0.4
28 | scheduler: cosine
29 | scheduler_kwargs:
30 | lr_decay_rate: 0.1
31 | batch_size: 512
32 | num_workers: 8
33 |
34 | # Transformations
35 | augmentation_strategy: simclr
36 | augmentation_kwargs:
37 | random_resized_crop:
38 | size: 32
39 | scale: [0.2, 1.0]
40 | color_jitter_random_apply:
41 | p: 0.8
42 | color_jitter:
43 | brightness: 0.4
44 | contrast: 0.4
45 | saturation: 0.4
46 | hue: 0.1
47 | random_grayscale:
48 | p: 0.2
49 | normalize:
50 | mean: [0.4914, 0.4822, 0.4465]
51 | std: [0.2023, 0.1994, 0.2010]
52 |
53 | transformation_kwargs:
54 | crop_size: 32
55 | normalize:
56 | mean: [0.4914, 0.4822, 0.4465]
57 | std: [0.2023, 0.1994, 0.2010]
58 |
--------------------------------------------------------------------------------
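These pretext configs are plain YAML; the repo's real entry point merges them with `configs/env.yml` via `utils.config.create_config` (as `moco.py` does below), but a quick sanity check can load one directly. A sketch, assuming PyYAML is installed and the script is run from the repository root:

```python
import yaml

# Load a pretext config directly, bypassing the repo's create_config merging.
with open("scan/configs/pretext/simclr_cifar10.yml") as f:
    cfg = yaml.safe_load(f)

print(cfg["backbone"])                          # resnet18
print(cfg["criterion_kwargs"]["temperature"])   # 0.1
print(cfg["augmentation_kwargs"]["normalize"]["mean"])
```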
/scan/configs/pretext/simclr_cifar100.yml:
--------------------------------------------------------------------------------
1 | # Setup
2 | setup: simclr
3 |
4 | # Model
5 | backbone: resnet18
6 | model_kwargs:
7 | head: mlp
8 | features_dim: 128
9 |
10 | # Dataset
11 | train_db_name: cifar-100
12 | val_db_name: cifar-100
13 | num_classes: 100
14 |
15 | # Loss
16 | criterion: simclr
17 | criterion_kwargs:
18 | temperature: 0.1
19 |
20 | # Hyperparameters
21 | epochs: 500
22 | optimizer: sgd
23 | optimizer_kwargs:
24 | nesterov: False
25 | weight_decay: 0.0001
26 | momentum: 0.9
27 | lr: 0.4
28 | scheduler: cosine
29 | scheduler_kwargs:
30 | lr_decay_rate: 0.1
31 | batch_size: 512
32 | num_workers: 8
33 |
34 | # Transformations
35 | augmentation_strategy: simclr
36 | augmentation_kwargs:
37 | random_resized_crop:
38 | size: 32
39 | scale: [0.2, 1.0]
40 | color_jitter_random_apply:
41 | p: 0.8
42 | color_jitter:
43 | brightness: 0.4
44 | contrast: 0.4
45 | saturation: 0.4
46 | hue: 0.1
47 | random_grayscale:
48 | p: 0.2
49 | normalize:
50 | mean: [0.5071, 0.4867, 0.4408]
51 | std: [0.2675, 0.2565, 0.2761]
52 |
53 | transformation_kwargs:
54 | crop_size: 32
55 | normalize:
56 | mean: [0.5071, 0.4867, 0.4408]
57 | std: [0.2675, 0.2565, 0.2761]
58 |
--------------------------------------------------------------------------------
/scan/configs/pretext/simclr_stl10.yml:
--------------------------------------------------------------------------------
1 | # Setup
2 | setup: simclr
3 |
4 | # Model
5 | backbone: resnet18
6 | model_kwargs:
7 | head: mlp
8 | features_dim: 128
9 |
10 | # Dataset
11 | train_db_name: stl-10
12 | val_db_name: stl-10
13 | num_classes: 10
14 |
15 | # Loss
16 | criterion: simclr
17 | criterion_kwargs:
18 | temperature: 0.1
19 |
20 | # Hyperparameters
21 | epochs: 500
22 | optimizer: sgd
23 | optimizer_kwargs:
24 | nesterov: False
25 | weight_decay: 0.0001
26 | momentum: 0.9
27 | lr: 0.4
28 | scheduler: cosine
29 | scheduler_kwargs:
30 | lr_decay_rate: 0.1
31 | batch_size: 512
32 | num_workers: 8
33 |
34 | # Transformations
35 | augmentation_strategy: simclr
36 | augmentation_kwargs:
37 | random_resized_crop:
38 | size: 96
39 | scale: [0.2, 1.0]
40 | color_jitter_random_apply:
41 | p: 0.8
42 | color_jitter:
43 | brightness: 0.4
44 | contrast: 0.4
45 | saturation: 0.4
46 | hue: 0.1
47 | random_grayscale:
48 | p: 0.2
49 | normalize:
50 | mean: [0.485, 0.456, 0.406]
51 | std: [0.229, 0.224, 0.225]
52 |
53 | transformation_kwargs:
54 | crop_size: 96
55 | normalize:
56 | mean: [0.485, 0.456, 0.406]
57 | std: [0.229, 0.224, 0.225]
58 |
--------------------------------------------------------------------------------
/scan/configs/pretext/simclr_tinyimagenet.yml:
--------------------------------------------------------------------------------
1 | # Setup
2 | setup: simclr
3 |
4 | # Model
5 | backbone: resnet18
6 | model_kwargs:
7 | head: mlp
8 | features_dim: 128
9 |
10 | # Dataset
11 | train_db_name: tiny-imagenet
12 | val_db_name: tiny-imagenet
13 | num_classes: 200
14 |
15 | # Loss
16 | criterion: simclr
17 | criterion_kwargs:
18 | temperature: 0.1
19 |
20 | # Hyperparameters
21 | epochs: 500
22 | optimizer: sgd
23 | optimizer_kwargs:
24 | nesterov: False
25 | weight_decay: 0.0001
26 | momentum: 0.9
27 | lr: 0.4
28 | scheduler: cosine
29 | scheduler_kwargs:
30 | lr_decay_rate: 0.1
31 | batch_size: 512
32 | num_workers: 8
33 |
34 | # Transformations
35 | augmentation_strategy: simclr
36 | augmentation_kwargs:
37 | random_resized_crop:
38 | size: 64
39 | scale: [0.2, 1.0]
40 | color_jitter_random_apply:
41 | p: 0.8
42 | color_jitter:
43 | brightness: 0.4
44 | contrast: 0.4
45 | saturation: 0.4
46 | hue: 0.1
47 | random_grayscale:
48 | p: 0.2
49 | normalize:
50 | mean: [0.5071, 0.4867, 0.4408]
51 | std: [0.2675, 0.2565, 0.2761]
52 |
53 | transformation_kwargs:
54 | crop_size: 64
55 | normalize:
56 | mean: [0.5071, 0.4867, 0.4408]
57 | std: [0.2675, 0.2565, 0.2761]
58 |
--------------------------------------------------------------------------------
/scan/configs/scan/imagenet_eval.yml:
--------------------------------------------------------------------------------
1 | # setup
2 | setup: scan
3 |
4 | # Loss
5 | criterion: scan
6 | criterion_kwargs:
7 | entropy_weight: 5.0
8 |
9 | # Model
10 | backbone: resnet50
11 | num_heads: 10 # Use multiple heads
12 |
13 | # Dataset
14 | train_db_name: imagenet
15 | val_db_name: imagenet
16 | num_classes: 1000
17 | num_neighbors: 50
18 |
19 | # Transformations
20 | augmentation_strategy: simclr
21 | augmentation_kwargs:
22 | random_resized_crop:
23 | size: 224
24 | scale: [0.2, 1.0]
25 | color_jitter_random_apply:
26 | p: 0.8
27 | color_jitter:
28 | brightness: 0.4
29 | contrast: 0.4
30 | saturation: 0.4
31 | hue: 0.1
32 | random_grayscale:
33 | p: 0.2
34 | normalize:
35 | mean: [0.485, 0.456, 0.406]
36 | std: [0.229, 0.224, 0.225]
37 |
38 | transformation_kwargs:
39 | crop_size: 224
40 | normalize:
41 | mean: [0.485, 0.456, 0.406]
42 | std: [0.229, 0.224, 0.225]
43 |
44 | num_workers: 12
45 |
--------------------------------------------------------------------------------
/scan/configs/scan/scan_cifar10.yml:
--------------------------------------------------------------------------------
1 | # setup
2 | setup: scan
3 |
4 | # Loss
5 | criterion: scan
6 | criterion_kwargs:
7 | entropy_weight: 5.0
8 |
9 | # Weight update
10 | update_cluster_head_only: False # Update full network in SCAN
11 | num_heads: 1 # Only use one head
12 |
13 | # Model
14 | backbone: resnet18
15 |
16 | # Dataset
17 | train_db_name: cifar-10
18 | val_db_name: cifar-10
19 | num_classes: 10
20 | num_neighbors: 20
21 |
22 | # Transformations
23 | augmentation_strategy: ours
24 | augmentation_kwargs:
25 | crop_size: 32
26 | normalize:
27 | mean: [0.4914, 0.4822, 0.4465]
28 | std: [0.2023, 0.1994, 0.2010]
29 | num_strong_augs: 4
30 | cutout_kwargs:
31 | n_holes: 1
32 | length: 16
33 | random: True
34 |
35 | transformation_kwargs:
36 | crop_size: 32
37 | normalize:
38 | mean: [0.4914, 0.4822, 0.4465]
39 | std: [0.2023, 0.1994, 0.2010]
40 |
41 | # Hyperparameters
42 | optimizer: adam
43 | optimizer_kwargs:
44 | lr: 0.0001
45 | weight_decay: 0.0001
46 | epochs: 50
47 | batch_size: 128
48 | num_workers: 8
49 |
50 | # Scheduler
51 | scheduler: constant
52 |
--------------------------------------------------------------------------------
/scan/configs/scan/scan_cifar100.yml:
--------------------------------------------------------------------------------
1 | # setup
2 | setup: scan
3 |
4 | # Loss
5 | criterion: scan
6 | criterion_kwargs:
7 | entropy_weight: 5.0
8 |
9 | # Weight update
10 | update_cluster_head_only: False # Update full network in SCAN
11 | num_heads: 1 # Only use one head
12 |
13 | # Model
14 | backbone: resnet18
15 |
16 | # Dataset
17 | train_db_name: cifar-100
18 | val_db_name: cifar-100
19 | num_classes: 100
20 | num_neighbors: 20
21 |
22 | # Transformations
23 | augmentation_strategy: ours
24 | augmentation_kwargs:
25 | crop_size: 32
26 | normalize:
27 | mean: [0.5071, 0.4867, 0.4408]
28 | std: [0.2675, 0.2565, 0.2761]
29 | num_strong_augs: 4
30 | cutout_kwargs:
31 | n_holes: 1
32 | length: 16
33 | random: True
34 |
35 | transformation_kwargs:
36 | crop_size: 32
37 | normalize:
38 | mean: [0.5071, 0.4867, 0.4408]
39 | std: [0.2675, 0.2565, 0.2761]
40 |
41 | # Hyperparameters
42 | optimizer: adam
43 | optimizer_kwargs:
44 | lr: 0.0001
45 | weight_decay: 0.0001
46 | epochs: 100
47 | batch_size: 512
48 | num_workers: 8
49 |
50 | # Scheduler
51 | scheduler: constant
52 |
--------------------------------------------------------------------------------
/scan/configs/scan/scan_imagenet_100.yml:
--------------------------------------------------------------------------------
1 | # setup
2 | setup: scan
3 |
4 | # Loss
5 | criterion: scan
6 | criterion_kwargs:
7 | entropy_weight: 5.0
8 |
9 | # Model
10 | backbone: resnet50
11 |
12 | # Weight update
13 | update_cluster_head_only: True # Train only linear layer during SCAN
14 | num_heads: 10 # Use multiple heads
15 |
16 | # Dataset
17 | train_db_name: imagenet_100
18 | val_db_name: imagenet_100
19 | num_classes: 100
20 | num_neighbors: 50
21 |
22 | # Transformations
23 | augmentation_strategy: simclr
24 | augmentation_kwargs:
25 | random_resized_crop:
26 | size: 224
27 | scale: [0.2, 1.0]
28 | color_jitter_random_apply:
29 | p: 0.8
30 | color_jitter:
31 | brightness: 0.4
32 | contrast: 0.4
33 | saturation: 0.4
34 | hue: 0.1
35 | random_grayscale:
36 | p: 0.2
37 | normalize:
38 | mean: [0.485, 0.456, 0.406]
39 | std: [0.229, 0.224, 0.225]
40 |
41 | transformation_kwargs:
42 | crop_size: 224
43 | normalize:
44 | mean: [0.485, 0.456, 0.406]
45 | std: [0.229, 0.224, 0.225]
46 |
47 | # Hyperparameters
48 | optimizer: sgd
49 | optimizer_kwargs:
50 | lr: 5.0
51 | weight_decay: 0.0000
52 | nesterov: False
53 | momentum: 0.9
54 | epochs: 100
55 | batch_size: 1024
56 | num_workers: 16
57 |
58 | # Scheduler
59 | scheduler: constant
60 |
--------------------------------------------------------------------------------
/scan/configs/scan/scan_imagenet_200.yml:
--------------------------------------------------------------------------------
1 | # setup
2 | setup: scan
3 |
4 | # Loss
5 | criterion: scan
6 | criterion_kwargs:
7 | entropy_weight: 5.0
8 |
9 | # Model
10 | backbone: resnet50
11 |
12 | # Weight update
13 | update_cluster_head_only: True # Train only linear layer during SCAN
14 | num_heads: 10 # Use multiple heads
15 |
16 | # Dataset
17 | train_db_name: imagenet_200
18 | val_db_name: imagenet_200
19 | num_classes: 200
20 | num_neighbors: 50
21 |
22 | # Transformations
23 | augmentation_strategy: simclr
24 | augmentation_kwargs:
25 | random_resized_crop:
26 | size: 224
27 | scale: [0.2, 1.0]
28 | color_jitter_random_apply:
29 | p: 0.8
30 | color_jitter:
31 | brightness: 0.4
32 | contrast: 0.4
33 | saturation: 0.4
34 | hue: 0.1
35 | random_grayscale:
36 | p: 0.2
37 | normalize:
38 | mean: [0.485, 0.456, 0.406]
39 | std: [0.229, 0.224, 0.225]
40 |
41 | transformation_kwargs:
42 | crop_size: 224
43 | normalize:
44 | mean: [0.485, 0.456, 0.406]
45 | std: [0.229, 0.224, 0.225]
46 |
47 | # Hyperparameters
48 | optimizer: sgd
49 | optimizer_kwargs:
50 | lr: 5.0
51 | weight_decay: 0.0000
52 | nesterov: False
53 | momentum: 0.9
54 | epochs: 100
55 | batch_size: 1024
56 | num_workers: 12
57 |
58 | # Scheduler
59 | scheduler: constant
60 |
--------------------------------------------------------------------------------
/scan/configs/scan/scan_imagenet_50.yml:
--------------------------------------------------------------------------------
1 | # setup
2 | setup: scan
3 |
4 | # Loss
5 | criterion: scan
6 | criterion_kwargs:
7 | entropy_weight: 5.0
8 |
9 | # Model
10 | backbone: resnet50
11 |
12 | # Weight update
13 | update_cluster_head_only: True # Train only linear layer during SCAN
14 | num_heads: 10 # Use multiple heads
15 |
16 | # Dataset
17 | train_db_name: imagenet_50
18 | val_db_name: imagenet_50
19 | num_classes: 50
20 | num_neighbors: 50
21 |
22 | # Transformations
23 | augmentation_strategy: simclr
24 | augmentation_kwargs:
25 | random_resized_crop:
26 | size: 224
27 | scale: [0.2, 1.0]
28 | color_jitter_random_apply:
29 | p: 0.8
30 | color_jitter:
31 | brightness: 0.4
32 | contrast: 0.4
33 | saturation: 0.4
34 | hue: 0.1
35 | random_grayscale:
36 | p: 0.2
37 | normalize:
38 | mean: [0.485, 0.456, 0.406]
39 | std: [0.229, 0.224, 0.225]
40 |
41 | transformation_kwargs:
42 | crop_size: 224
43 | normalize:
44 | mean: [0.485, 0.456, 0.406]
45 | std: [0.229, 0.224, 0.225]
46 |
47 | # Hyperparameters
48 | optimizer: sgd
49 | optimizer_kwargs:
50 | lr: 5.0
51 | weight_decay: 0.0000
52 | nesterov: False
53 | momentum: 0.9
54 | epochs: 100
55 | batch_size: 512
56 | num_workers: 12
57 |
58 | # Scheduler
59 | scheduler: constant
60 |
--------------------------------------------------------------------------------
/scan/configs/scan/scan_stl10.yml:
--------------------------------------------------------------------------------
1 | # setup
2 | setup: scan
3 |
4 | # Loss
5 | criterion: scan
6 | criterion_kwargs:
7 | entropy_weight: 5.0
8 |
9 | # Weight update
10 | update_cluster_head_only: False # Update full network in SCAN
11 | num_heads: 1 # Only use one head
12 |
13 | # Model
14 | backbone: resnet18
15 |
16 | # Dataset
17 | train_db_name: stl-10
18 | val_db_name: stl-10
19 | num_classes: 10
20 | num_neighbors: 20
21 |
22 | # Transformations
23 | augmentation_strategy: ours
24 | augmentation_kwargs:
25 | crop_size: 96
26 | normalize:
27 | mean: [0.485, 0.456, 0.406]
28 | std: [0.229, 0.224, 0.225]
29 | num_strong_augs: 4
30 | cutout_kwargs:
31 | n_holes: 1
32 | length: 32
33 | random: True
34 |
35 | transformation_kwargs:
36 | crop_size: 96
37 | normalize:
38 | mean: [0.485, 0.456, 0.406]
39 | std: [0.229, 0.224, 0.225]
40 |
41 | # Hyperparameters
42 | optimizer: adam
43 | optimizer_kwargs:
44 | lr: 0.0001
45 | weight_decay: 0.0001
46 | epochs: 100
47 | batch_size: 128
48 | num_workers: 8
49 |
50 | # Scheduler
51 | scheduler: constant
52 |
--------------------------------------------------------------------------------
/scan/configs/scan/scan_tinyimagenet.yml:
--------------------------------------------------------------------------------
1 | # setup
2 | setup: scan
3 |
4 | # Loss
5 | criterion: scan
6 | criterion_kwargs:
7 | entropy_weight: 5.0
8 |
9 | # Weight update
10 | update_cluster_head_only: False # Update full network in SCAN
11 | num_heads: 1 # Only use one head
12 |
13 | # Model
14 | backbone: resnet18
15 |
16 | # Dataset
17 | train_db_name: tiny-imagenet
18 | val_db_name: tiny-imagenet
19 | num_classes: 200
20 | num_neighbors: 20
21 |
22 | # Transformations
23 | augmentation_strategy: ours
24 | augmentation_kwargs:
25 | crop_size: 64
26 | normalize:
27 | mean: [0.5071, 0.4867, 0.4408]
28 | std: [0.2675, 0.2565, 0.2761]
29 | num_strong_augs: 4
30 | cutout_kwargs:
31 | n_holes: 1
32 | length: 16
33 | random: True
34 |
35 | transformation_kwargs:
36 | crop_size: 64
37 | normalize:
38 | mean: [0.5071, 0.4867, 0.4408]
39 | std: [0.2675, 0.2565, 0.2761]
40 |
41 | # Hyperparameters
42 | optimizer: adam
43 | optimizer_kwargs:
44 | lr: 0.0001
45 | weight_decay: 0.0001
46 | epochs: 100
47 | batch_size: 512
48 |
49 | num_workers: 8
50 |
51 | # Scheduler
52 | scheduler: constant
53 |
--------------------------------------------------------------------------------
/scan/configs/selflabel/selflabel_cifar10.yml:
--------------------------------------------------------------------------------
1 | # setup
2 | setup: selflabel
3 |
4 | # ema
5 | use_ema: False
6 |
7 | # Threshold
8 | confidence_threshold: 0.99
9 |
10 | # Criterion
11 | criterion: confidence-cross-entropy
12 | criterion_kwargs:
13 | apply_class_balancing: True
14 |
15 | # Model
16 | backbone: resnet18
17 | num_heads: 1
18 |
19 | # Dataset
20 | train_db_name: cifar-10
21 | val_db_name: cifar-10
22 | num_classes: 10
23 |
24 | # Transformations
25 | augmentation_strategy: ours
26 | augmentation_kwargs:
27 | crop_size: 32
28 | normalize:
29 | mean: [0.4914, 0.4822, 0.4465]
30 | std: [0.2023, 0.1994, 0.2010]
31 | num_strong_augs: 4
32 | cutout_kwargs:
33 | n_holes: 1
34 | length: 16
35 | random: True
36 |
37 | transformation_kwargs:
38 | crop_size: 32
39 | normalize:
40 | mean: [0.4914, 0.4822, 0.4465]
41 | std: [0.2023, 0.1994, 0.2010]
42 |
43 | # Hyperparameters
44 | epochs: 200
45 | optimizer: adam
46 | optimizer_kwargs:
47 | lr: 0.0001
48 | weight_decay: 0.0001
49 | batch_size: 1000
50 | num_workers: 8
51 |
52 | # Scheduler
53 | scheduler: constant
54 |
--------------------------------------------------------------------------------
/scan/configs/selflabel/selflabel_cifar100.yml:
--------------------------------------------------------------------------------
1 | # setup
2 | setup: selflabel
3 |
4 | # ema
5 | use_ema: False
6 |
7 | # Threshold
8 | confidence_threshold: 0.99
9 |
10 | # Criterion
11 | criterion: confidence-cross-entropy
12 | criterion_kwargs:
13 | apply_class_balancing: True
14 |
15 | # Model
16 | backbone: resnet18
17 | num_heads: 1
18 |
19 | # Dataset
20 | train_db_name: cifar-20
21 | val_db_name: cifar-20
22 | num_classes: 20
23 |
24 | # Transformations
25 | augmentation_strategy: ours
26 | augmentation_kwargs:
27 | crop_size: 32
28 | normalize:
29 | mean: [0.5071, 0.4867, 0.4408]
30 | std: [0.2675, 0.2565, 0.2761]
31 | num_strong_augs: 4
32 | cutout_kwargs:
33 | n_holes: 1
34 | length: 16
35 | random: True
36 |
37 | transformation_kwargs:
38 | crop_size: 32
39 | normalize:
40 | mean: [0.5071, 0.4867, 0.4408]
41 | std: [0.2675, 0.2565, 0.2761]
42 |
43 | # Hyperparameters
44 | epochs: 200
45 | optimizer: adam
46 | optimizer_kwargs:
47 | lr: 0.0001
48 | weight_decay: 0.0001
49 | batch_size: 1000
50 | num_workers: 8
51 |
52 | # Scheduler
53 | scheduler: constant
54 |
--------------------------------------------------------------------------------
/scan/configs/selflabel/selflabel_imagenet_100.yml:
--------------------------------------------------------------------------------
1 | # setup
2 | setup: selflabel
3 |
4 | # Threshold
5 | confidence_threshold: 0.99
6 |
7 | # EMA
8 | use_ema: True
9 | ema_alpha: 0.999
10 |
11 | # Loss
12 | criterion: confidence-cross-entropy
13 | criterion_kwargs:
14 | apply_class_balancing: False
15 |
16 | # Model
17 | backbone: resnet50
18 | num_heads: 1
19 |
20 | # Dataset
21 | train_db_name: imagenet_100
22 | val_db_name: imagenet_100
23 | num_classes: 100
24 |
25 | # Transformations
26 | augmentation_strategy: ours
27 | augmentation_kwargs:
28 | crop_size: 224
29 | normalize:
30 | mean: [0.485, 0.456, 0.406]
31 | std: [0.229, 0.224, 0.225]
32 | num_strong_augs: 4
33 | cutout_kwargs:
34 | n_holes: 1
35 | length: 75
36 | random: True
37 |
38 | transformation_kwargs:
39 | crop_size: 224
40 | normalize:
41 | mean: [0.485, 0.456, 0.406]
42 | std: [0.229, 0.224, 0.225]
43 |
44 | # Hyperparameters
45 | optimizer: sgd
46 | optimizer_kwargs:
47 | lr: 0.03
48 | weight_decay: 0.0
49 | nesterov: False
50 | momentum: 0.9
51 | epochs: 25
52 | batch_size: 512
53 | num_workers: 12
54 |
55 | # Scheduler
56 | scheduler: constant
57 |
--------------------------------------------------------------------------------
/scan/configs/selflabel/selflabel_imagenet_200.yml:
--------------------------------------------------------------------------------
1 | # setup
2 | setup: selflabel
3 |
4 | # Threshold
5 | confidence_threshold: 0.99
6 |
7 | # EMA
8 | use_ema: True
9 | ema_alpha: 0.999
10 |
11 | # Loss
12 | criterion: confidence-cross-entropy
13 | criterion_kwargs:
14 | apply_class_balancing: False
15 |
16 | # Model
17 | backbone: resnet50
18 | num_heads: 1
19 |
20 | # Dataset
21 | train_db_name: imagenet_200
22 | val_db_name: imagenet_200
23 | num_classes: 200
24 |
25 | # Transformations
26 | augmentation_strategy: ours
27 | augmentation_kwargs:
28 | crop_size: 224
29 | normalize:
30 | mean: [0.485, 0.456, 0.406]
31 | std: [0.229, 0.224, 0.225]
32 | num_strong_augs: 4
33 | cutout_kwargs:
34 | n_holes: 1
35 | length: 75
36 | random: True
37 |
38 | transformation_kwargs:
39 | crop_size: 224
40 | normalize:
41 | mean: [0.485, 0.456, 0.406]
42 | std: [0.229, 0.224, 0.225]
43 |
44 | # Hyperparameters
45 | optimizer: sgd
46 | optimizer_kwargs:
47 | lr: 0.03
48 | weight_decay: 0.0
49 | nesterov: False
50 | momentum: 0.9
51 | epochs: 25
52 | batch_size: 512
53 | num_workers: 8
54 |
55 | # Scheduler
56 | scheduler: constant
57 |
--------------------------------------------------------------------------------
/scan/configs/selflabel/selflabel_imagenet_50.yml:
--------------------------------------------------------------------------------
1 | # setup
2 | setup: selflabel
3 |
4 | # Threshold
5 | confidence_threshold: 0.99
6 |
7 | # EMA
8 | use_ema: True
9 | ema_alpha: 0.999
10 |
11 | # Loss
12 | criterion: confidence-cross-entropy
13 | criterion_kwargs:
14 | apply_class_balancing: False
15 |
16 | # Model
17 | backbone: resnet50
18 | num_heads: 1
19 |
20 | # Dataset
21 | train_db_name: imagenet_50
22 | val_db_name: imagenet_50
23 | num_classes: 50
24 |
25 | # Transformations
26 | augmentation_strategy: ours
27 | augmentation_kwargs:
28 | crop_size: 224
29 | normalize:
30 | mean: [0.485, 0.456, 0.406]
31 | std: [0.229, 0.224, 0.225]
32 | num_strong_augs: 4
33 | cutout_kwargs:
34 | n_holes: 1
35 | length: 75
36 | random: True
37 |
38 | transformation_kwargs:
39 | crop_size: 224
40 | normalize:
41 | mean: [0.485, 0.456, 0.406]
42 | std: [0.229, 0.224, 0.225]
43 |
44 | # Hyperparameters
45 | optimizer: sgd
46 | optimizer_kwargs:
47 | lr: 0.03
48 | weight_decay: 0.0
49 | nesterov: False
50 | momentum: 0.9
51 | epochs: 25
52 | batch_size: 512
53 | num_workers: 16
54 |
55 | # Scheduler
56 | scheduler: constant
57 |
--------------------------------------------------------------------------------
/scan/configs/selflabel/selflabel_stl10.yml:
--------------------------------------------------------------------------------
1 | # setup
2 | setup: selflabel
3 |
4 | # ema
5 | use_ema: False
6 |
7 | # Threshold
8 | confidence_threshold: 0.99
9 |
10 | # Loss
11 | criterion: confidence-cross-entropy
12 | criterion_kwargs:
13 | apply_class_balancing: True
14 |
15 | # Model
16 | backbone: resnet18
17 | num_heads: 1
18 |
19 | # Dataset
20 | train_db_name: stl-10
21 | val_db_name: stl-10
22 | num_classes: 10
23 |
24 | # Transformations
25 | augmentation_strategy: ours
26 | augmentation_kwargs:
27 | crop_size: 96
28 | normalize:
29 | mean: [0.485, 0.456, 0.406]
30 | std: [0.229, 0.224, 0.225]
31 | num_strong_augs: 4
32 | cutout_kwargs:
33 | n_holes: 1
34 | length: 32
35 | random: True
36 |
37 | transformation_kwargs:
38 | crop_size: 96
39 | normalize:
40 | mean: [0.485, 0.456, 0.406]
41 | std: [0.229, 0.224, 0.225]
42 |
43 | # Hyperparameters
44 | optimizer: adam
45 | optimizer_kwargs:
46 | lr: 0.0001
47 | weight_decay: 0.0001
48 | epochs: 100
49 | batch_size: 1000
50 | num_workers: 8
51 |
52 | # Scheduler
53 | scheduler: constant
54 |
--------------------------------------------------------------------------------
/scan/data/augment.py:
--------------------------------------------------------------------------------
1 | # List of augmentations based on randaugment
2 | import random
3 |
4 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw
5 | import numpy as np
6 | import torch
7 | from torchvision.transforms.transforms import Compose
8 |
9 | random_mirror = True
10 |
11 | def ShearX(img, v):
12 | if random_mirror and random.random() > 0.5:
13 | v = -v
14 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
15 |
16 | def ShearY(img, v):
17 | if random_mirror and random.random() > 0.5:
18 | v = -v
19 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
20 |
21 | def Identity(img, v):
22 | return img
23 |
24 | def TranslateX(img, v):
25 | if random_mirror and random.random() > 0.5:
26 | v = -v
27 | v = v * img.size[0]
28 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
29 |
30 | def TranslateY(img, v):
31 | if random_mirror and random.random() > 0.5:
32 | v = -v
33 | v = v * img.size[1]
34 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
35 |
36 | def TranslateXAbs(img, v):
37 | if random.random() > 0.5:
38 | v = -v
39 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
40 |
41 | def TranslateYAbs(img, v):
42 | if random.random() > 0.5:
43 | v = -v
44 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
45 |
46 | def Rotate(img, v):
47 | if random_mirror and random.random() > 0.5:
48 | v = -v
49 | return img.rotate(v)
50 |
51 | def AutoContrast(img, _):
52 | return PIL.ImageOps.autocontrast(img)
53 |
54 | def Invert(img, _):
55 | return PIL.ImageOps.invert(img)
56 |
57 | def Equalize(img, _):
58 | return PIL.ImageOps.equalize(img)
59 |
60 | def Solarize(img, v):
61 | return PIL.ImageOps.solarize(img, v)
62 |
63 | def Posterize(img, v):
64 | v = int(v)
65 | return PIL.ImageOps.posterize(img, v)
66 |
67 | def Contrast(img, v):
68 | return PIL.ImageEnhance.Contrast(img).enhance(v)
69 |
70 | def Color(img, v):
71 | return PIL.ImageEnhance.Color(img).enhance(v)
72 |
73 | def Brightness(img, v):
74 | return PIL.ImageEnhance.Brightness(img).enhance(v)
75 |
76 | def Sharpness(img, v):
77 | return PIL.ImageEnhance.Sharpness(img).enhance(v)
78 |
79 | def augment_list():
80 | l = [
81 | (Identity, 0, 1),
82 | (AutoContrast, 0, 1),
83 | (Equalize, 0, 1),
84 | (Rotate, -30, 30),
85 | (Solarize, 0, 256),
86 | (Color, 0.05, 0.95),
87 | (Contrast, 0.05, 0.95),
88 | (Brightness, 0.05, 0.95),
89 | (Sharpness, 0.05, 0.95),
90 | (ShearX, -0.1, 0.1),
91 | (TranslateX, -0.1, 0.1),
92 | (TranslateY, -0.1, 0.1),
93 | (Posterize, 4, 8),
94 | (ShearY, -0.1, 0.1),
95 | ]
96 | return l
97 |
98 |
99 | augment_dict = {fn.__name__: (fn, v1, v2) for fn, v1, v2 in augment_list()}
100 |
101 | class Augment:
102 | def __init__(self, n):
103 | self.n = n
104 | self.augment_list = augment_list()
105 |
106 | def __call__(self, img):
107 | ops = random.choices(self.augment_list, k=self.n)
108 | for op, minval, maxval in ops:
109 | val = (random.random()) * float(maxval - minval) + minval
110 | img = op(img, val)
111 |
112 | return img
113 |
114 | def get_augment(name):
115 | return augment_dict[name]
116 |
117 | def apply_augment(img, name, level):
118 | augment_fn, low, high = get_augment(name)
119 | return augment_fn(img.copy(), level * (high - low) + low)
120 |
121 | class Cutout(object):
122 | def __init__(self, n_holes, length, random=False):
123 | self.n_holes = n_holes
124 | self.length = length
125 | self.random = random
126 |
127 | def __call__(self, img):
128 | h = img.size(1)
129 | w = img.size(2)
130 |         length = random.randint(1, self.length) if self.random else self.length  # honor the 'random' flag
131 | mask = np.ones((h, w), np.float32)
132 |
133 | for n in range(self.n_holes):
134 | y = np.random.randint(h)
135 | x = np.random.randint(w)
136 |
137 | y1 = np.clip(y - length // 2, 0, h)
138 | y2 = np.clip(y + length // 2, 0, h)
139 | x1 = np.clip(x - length // 2, 0, w)
140 | x2 = np.clip(x + length // 2, 0, w)
141 |
142 | mask[y1: y2, x1: x2] = 0.
143 |
144 | mask = torch.from_numpy(mask)
145 | mask = mask.expand_as(img)
146 | img = img * mask
147 |
148 | return img
149 |
--------------------------------------------------------------------------------
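A small sketch of the `Augment` pipeline above applied to a dummy PIL image; `Cutout`, by contrast, expects a `(C, H, W)` tensor and is applied after `ToTensor()`. This assumes `scan/` is on `sys.path`:

```python
import PIL.Image

from data.augment import Augment

img = PIL.Image.new("RGB", (32, 32), color=(128, 64, 32))
strong = Augment(n=4)(img)  # samples 4 ops from augment_list(), each at a random strength
print(strong.size)          # still (32, 32); the ops recolor/warp but keep the size
```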
/scan/data/custom_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Wouter Van Gansbeke, Simon Vandenhende
3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
4 | """
5 | import numpy as np
6 | import torch
7 | from torch.utils.data import Dataset
8 |
9 | """
10 | AugmentedDataset
11 | Returns an image together with an augmentation.
12 | """
13 | class AugmentedDataset(Dataset):
14 | def __init__(self, dataset):
15 | super(AugmentedDataset, self).__init__()
16 | transform = dataset.transform
17 | dataset.transform = None
18 | self.dataset = dataset
19 |
20 | if isinstance(transform, dict):
21 | self.image_transform = transform['standard']
22 | self.augmentation_transform = transform['augment']
23 |
24 | else:
25 | self.image_transform = transform
26 | self.augmentation_transform = transform
27 |
28 | def __len__(self):
29 | return len(self.dataset)
30 |
31 | def __getitem__(self, index):
32 | sample = self.dataset.__getitem__(index)
33 | image = sample['image']
34 |
35 | sample['image'] = self.image_transform(image)
36 | sample['image_augmented'] = self.augmentation_transform(image)
37 |
38 | return sample
39 |
40 |
41 | """
42 | NeighborsDataset
43 | Returns an image with one of its neighbors.
44 | """
45 | class NeighborsDataset(Dataset):
46 | def __init__(self, dataset, indices, num_neighbors=None):
47 | super(NeighborsDataset, self).__init__()
48 | transform = dataset.transform
49 |
50 | if isinstance(transform, dict):
51 | self.anchor_transform = transform['standard']
52 | self.neighbor_transform = transform['augment']
53 | else:
54 | self.anchor_transform = transform
55 | self.neighbor_transform = transform
56 |
57 | dataset.transform = None
58 | self.dataset = dataset
59 | self.indices = indices # Nearest neighbor indices (np.array [len(dataset) x k])
60 | if num_neighbors is not None:
61 | self.indices = self.indices[:, :num_neighbors+1]
62 | assert(self.indices.shape[0] == len(self.dataset))
63 |
64 | def __len__(self):
65 | return len(self.dataset)
66 |
67 | def __getitem__(self, index):
68 | output = {}
69 | anchor = self.dataset.__getitem__(index)
70 |
71 | neighbor_index = np.random.choice(self.indices[index], 1)[0]
72 | neighbor = self.dataset.__getitem__(neighbor_index)
73 |
74 | anchor['image'] = self.anchor_transform(anchor['image'])
75 | neighbor['image'] = self.neighbor_transform(neighbor['image'])
76 |
77 | output['anchor'] = anchor['image']
78 | output['neighbor'] = neighbor['image']
79 | output['possible_neighbors'] = torch.from_numpy(self.indices[index])
80 | output['target'] = anchor['target']
81 |
82 | return output
83 |
--------------------------------------------------------------------------------
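A toy wiring sketch for `NeighborsDataset`: in the real pipeline the `indices` array comes from the memory bank's mined nearest neighbors, but any `[len(dataset) x k]` integer array works. Assumes `scan/` is on `sys.path`:

```python
import numpy as np
import torch
from torch.utils.data import Dataset

from data.custom_dataset import NeighborsDataset

class ToyDataset(Dataset):
    """Stands in for the repo's datasets: dict samples plus a .transform attribute."""
    def __init__(self):
        self.transform = lambda x: x  # identity in place of real transforms
    def __len__(self):
        return 8
    def __getitem__(self, index):
        return {"image": torch.randn(3, 32, 32), "target": index % 2}

indices = np.stack([np.arange(8)] * 6, axis=1)  # dummy [8 x 6] neighbor table
ds = NeighborsDataset(ToyDataset(), indices, num_neighbors=5)
sample = ds[0]
print(sample["anchor"].shape, sample["neighbor"].shape, sample["target"])
```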
/scan/data/imagenet.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Wouter Van Gansbeke, Simon Vandenhende
3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
4 | """
5 | import os
6 | import torch
7 | import torchvision.datasets as datasets
8 | import torch.utils.data as data
9 | from PIL import Image
10 | from utils.mypath import MyPath
11 | from torchvision import transforms as tf
12 | from glob import glob
13 |
14 |
15 | class ImageNet(datasets.ImageFolder):
16 | def __init__(self, root=MyPath.db_root_dir('imagenet'), split='train', transform=None):
17 | super(ImageNet, self).__init__(root=os.path.join(root, 'ILSVRC2012_img_%s' %(split)),
18 | transform=None)
19 | self.transform = transform
20 | self.split = split
21 | self.resize = tf.Resize(256)
22 |
23 | def __len__(self):
24 | return len(self.imgs)
25 |
26 | def __getitem__(self, index):
27 | path, target = self.imgs[index]
28 | with open(path, 'rb') as f:
29 | img = Image.open(f).convert('RGB')
30 | im_size = img.size
31 | img = self.resize(img)
32 |
33 | if self.transform is not None:
34 | img = self.transform(img)
35 |
36 | out = {'image': img, 'target': target, 'meta': {'im_size': im_size, 'index': index}}
37 |
38 | return out
39 |
40 | def get_image(self, index):
41 | path, target = self.imgs[index]
42 | with open(path, 'rb') as f:
43 | img = Image.open(f).convert('RGB')
44 | img = self.resize(img)
45 | return img
46 |
47 |
48 | class ImageNetSubset(data.Dataset):
49 | def __init__(self, subset_file, root=MyPath.db_root_dir('imagenet'), split='train',
50 | transform=None):
51 | super(ImageNetSubset, self).__init__()
52 |
53 | self.root = os.path.join(root, 'ILSVRC2012_img_%s' %(split))
54 | self.transform = transform
55 | self.split = split
56 |
57 | # Read the subset of classes to include (sorted)
58 | with open(subset_file, 'r') as f:
59 | result = f.read().splitlines()
60 | subdirs, class_names = [], []
61 | for line in result:
62 | subdir, class_name = line.split(' ', 1)
63 | subdirs.append(subdir)
64 | class_names.append(class_name)
65 |
66 | # Gather the files (sorted)
67 | imgs = []
68 | for i, subdir in enumerate(subdirs):
69 | subdir_path = os.path.join(self.root, subdir)
70 | files = sorted(glob(os.path.join(self.root, subdir, '*.JPEG')))
71 | for f in files:
72 | imgs.append((f, i))
73 | self.imgs = imgs
74 | self.classes = class_names
75 |
76 | # Resize
77 | self.resize = tf.Resize(256)
78 |
79 | def get_image(self, index):
80 | path, target = self.imgs[index]
81 | with open(path, 'rb') as f:
82 | img = Image.open(f).convert('RGB')
83 | img = self.resize(img)
84 | return img
85 |
86 | def __len__(self):
87 | return len(self.imgs)
88 |
89 | def __getitem__(self, index):
90 | path, target = self.imgs[index]
91 | with open(path, 'rb') as f:
92 | img = Image.open(f).convert('RGB')
93 | im_size = img.size
94 | img = self.resize(img)
95 | class_name = self.classes[target]
96 |
97 | if self.transform is not None:
98 | img = self.transform(img)
99 |
100 | out = {'image': img, 'target': target, 'meta': {'im_size': im_size, 'index': index, 'class_name': class_name}}
101 |
102 | return out
103 |
--------------------------------------------------------------------------------
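The `imagenet_subsets/*.txt` files that follow hold one class per line in the form `wnid<space>human-readable name`, which is exactly what `ImageNetSubset` parses above with `line.split(' ', 1)`.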
/scan/data/imagenet_subsets/imagenet_100.txt:
--------------------------------------------------------------------------------
1 | n01558993 robin, American robin, Turdus migratorius
2 | n01601694 water ouzel, dipper
3 | n01669191 box turtle, box tortoise
4 | n01751748 sea snake
5 | n01755581 diamondback, diamondback rattlesnake, Crotalus adamanteus
6 | n01756291 sidewinder, horned rattlesnake, Crotalus cerastes
7 | n01770393 scorpion
8 | n01855672 goose
9 | n01871265 tusker
10 | n02018207 American coot, marsh hen, mud hen, water hen, Fulica americana
11 | n02037110 oystercatcher, oyster catcher
12 | n02058221 albatross, mollymawk
13 | n02087046 toy terrier
14 | n02088632 bluetick
15 | n02093256 Staffordshire bullterrier, Staffordshire bull terrier
16 | n02093754 Border terrier
17 | n02094114 Norfolk terrier
18 | n02096177 cairn, cairn terrier
19 | n02097130 giant schnauzer
20 | n02097298 Scotch terrier, Scottish terrier, Scottie
21 | n02099267 flat-coated retriever
22 | n02100877 Irish setter, red setter
23 | n02104365 schipperke
24 | n02105855 Shetland sheepdog, Shetland sheep dog, Shetland
25 | n02106030 collie
26 | n02106166 Border collie
27 | n02107142 Doberman, Doberman pinscher
28 | n02110341 dalmatian, coach dog, carriage dog
29 | n02114855 coyote, prairie wolf, brush wolf, Canis latrans
30 | n02120079 Arctic fox, white fox, Alopex lagopus
31 | n02120505 grey fox, gray fox, Urocyon cinereoargenteus
32 | n02125311 cougar, puma, catamount, mountain lion, painter, panther, Felis concolor
33 | n02128385 leopard, Panthera pardus
34 | n02133161 American black bear, black bear, Ursus americanus, Euarctos americanus
35 | n02277742 ringlet, ringlet butterfly
36 | n02325366 wood rabbit, cottontail, cottontail rabbit
37 | n02364673 guinea pig, Cavia cobaya
38 | n02484975 guenon, guenon monkey
39 | n02489166 proboscis monkey, Nasalis larvatus
40 | n02708093 analog clock
41 | n02747177 ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin
42 | n02835271 bicycle-built-for-two, tandem bicycle, tandem
43 | n02906734 broom
44 | n02909870 bucket, pail
45 | n03085013 computer keyboard, keypad
46 | n03124170 cowboy hat, ten-gallon hat
47 | n03127747 crash helmet
48 | n03160309 dam, dike, dyke
49 | n03255030 dumbbell
50 | n03272010 electric guitar
51 | n03291819 envelope
52 | n03337140 file, file cabinet, filing cabinet
53 | n03450230 gown
54 | n03483316 hand blower, blow dryer, blow drier, hair dryer, hair drier
55 | n03498962 hatchet
56 | n03530642 honeycomb
57 | n03623198 knee pad
58 | n03649909 lawn mower, mower
59 | n03710721 maillot, tank suit
60 | n03717622 manhole cover
61 | n03733281 maze, labyrinth
62 | n03759954 microphone, mike
63 | n03775071 mitten
64 | n03814639 neck brace
65 | n03837869 obelisk
66 | n03838899 oboe, hautboy, hautbois
67 | n03854065 organ, pipe organ
68 | n03929855 pickelhaube
69 | n03930313 picket fence, paling
70 | n03954731 plane, carpenter's plane, woodworking plane
71 | n03956157 planetarium
72 | n03983396 pop bottle, soda bottle
73 | n04004767 printer
74 | n04026417 purse
75 | n04065272 recreational vehicle, RV, R.V.
76 | n04200800 shoe shop, shoe-shop, shoe store
77 | n04209239 shower curtain
78 | n04235860 sleeping bag
79 | n04311004 steel arch bridge
80 | n04325704 stole
81 | n04336792 stretcher
82 | n04346328 stupa, tope
83 | n04380533 table lamp
84 | n04428191 thresher, thrasher, threshing machine
85 | n04443257 tobacco shop, tobacconist shop, tobacconist
86 | n04458633 totem pole
87 | n04483307 trimaran
88 | n04509417 unicycle, monocycle
89 | n04515003 upright, upright piano
90 | n04525305 vending machine
91 | n04554684 washer, automatic washer, washing machine
92 | n04591157 Windsor tie
93 | n04592741 wing
94 | n04606251 wreck
95 | n07583066 guacamole
96 | n07613480 trifle
97 | n07693725 bagel, beigel
98 | n07711569 mashed potato
99 | n07753592 banana
100 | n11879895 rapeseed
101 |
--------------------------------------------------------------------------------
/scan/data/imagenet_subsets/imagenet_50.txt:
--------------------------------------------------------------------------------
1 | n01601694 Dipper
2 | n01669191 Box Turtle
3 | n01755581 Diamondback Snake
4 | n01770393 Scorpion
5 | n01855672 Goose
6 | n02018207 Water Hen
7 | n02058221 Albatross
8 | n02096177 Cairn Terrier
9 | n02097130 Giant Schnauzer
10 | n02099267 Flat-Coated Retriever
11 | n02100877 Irish Setter
12 | n02104365 Schipperke
13 | n02106030 Collie
14 | n02114855 Coyote
15 | n02125311 Mountain Lion
16 | n02133161 Black Bear
17 | n02484975 Guenon
18 | n02489166 Proboscis Monkey
19 | n02747177 Trash Bin
20 | n02906734 Broom
21 | n03124170 Cowboy Hat
22 | n03272010 Electric Guitar
23 | n03337140 File Cabinet
24 | n03483316 Hair Dryer
25 | n03498962 Hatchet
26 | n03710721 Maillot
27 | n03717622 Manhole Cover
28 | n03733281 Maze
29 | n03759954 Microphone
30 | n03775071 Mitten
31 | n03814639 Neck Brace
32 | n03837869 Obelisk
33 | n03838899 Oboe
34 | n03854065 Organ
35 | n03954731 Woodworking Plane
36 | n03983396 Soda Bottle
37 | n04026417 Purse
38 | n04200800 Shoe Shop
39 | n04209239 Shower Curtain
40 | n04311004 Steel Arch Bridge
41 | n04380533 Table Lamp
42 | n04428191 Threshing Machine
43 | n04443257 Tobacco Shop
44 | n04509417 Unicycle
45 | n04525305 Vending Machine
46 | n04554684 Washing Machine
47 | n04606251 Wreck
48 | n07583066 Guacamole
49 | n07711569 Mashed Potato
50 | n07753592 Banana
51 |
--------------------------------------------------------------------------------
/scan/data/tinyimagenet.py:
--------------------------------------------------------------------------------
1 | """
2 | This code is based on the Torchvision repository, which was licensed under the BSD 3-Clause.
3 | """
4 | import os
5 | import pickle
6 | from PIL import Image
7 | from torchvision.datasets.utils import check_integrity
8 | from torchvision import datasets
9 | from typing import Any
10 |
11 |
12 | def unpickle_object(path):
13 | with open(path, 'rb+') as file_pi:
14 | res = pickle.load(file_pi)
15 | return res
16 |
17 |
18 | class TinyImageNet(datasets.VisionDataset):
19 | """`Tiny ImageNet Classification Dataset.
20 |
21 | Args:
22 | root (string): Root directory of the ImageNet Dataset.
23 | split (string, optional): The dataset split, supports ``train``, or ``val``.
24 | transform (callable, optional): A function/transform that takes in an PIL image
25 | and returns a transformed version. E.g, ``transforms.RandomCrop``
26 | target_transform (callable, optional): A function/transform that takes in the
27 | target and transforms it.
28 | loader (callable, optional): A function to load an image given its path.
29 |
30 | Attributes:
31 | classes (list): List of the class name tuples.
32 | class_to_idx (dict): Dict with items (class_name, class_index).
33 | wnids (list): List of the WordNet IDs.
34 | wnid_to_idx (dict): Dict with items (wordnet_id, class_index).
35 | samples (list): List of (image path, class_index) tuples
36 | targets (list): The class_index value for each image in the dataset
37 | """
38 |
39 | def __init__(self, root: str, split: str = 'train', transform=None, **kwargs: Any) -> None:
40 | self.root = root
41 | if split == 'train+unlabeled':
42 | split = 'train'
43 | self.split = datasets.utils.verify_str_arg(split, "split", ("train", "val"))
44 |
45 | if self.split == 'train':
46 | self.images, self.targets, self.cls_to_id = unpickle_object('../../../daphna/data/tiny_imagenet/tiny-imagenet-200/train.pkl')
47 | elif self.split == 'val':
48 | self.images, self.targets, self.cls_to_id = unpickle_object('../../../daphna/data/tiny_imagenet/tiny-imagenet-200/val.pkl')
49 | else:
50 | raise NotImplementedError('unknown split')
51 | self.targets = self.targets.astype(int)
52 | self.classes = list(self.cls_to_id.keys())
53 | super(TinyImageNet, self).__init__(root, **kwargs)
54 | self.transform = transform
55 |
56 |     # split_folder is used by the 'super' call. Since the val directory is not structured
57 |     # like the train directory, we simply reuse the train structure to recover the class list.
58 | @property
59 | def split_folder(self) -> str:
60 | return os.path.join(self.root, 'train')
61 |
62 | def __getitem__(self, index: int):
63 | """
64 | Args:
65 | index (int): Index
66 |
67 | Returns:
68 | tuple: (sample, target) where target is class_index of the target class.
69 | """
70 | sample = Image.fromarray(self.images[index])
71 | target = int(self.targets[index])
72 |
73 | if self.transform is not None:
74 | sample = self.transform(sample)
75 |
76 | out = {'image': sample, 'target': target, 'meta': {'im_size': 64, 'index': index, 'class_name': target}}
77 | return out
78 |
79 | def __len__(self):
80 | return len(self.targets)
81 |
--------------------------------------------------------------------------------
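Note that `TinyImageNet` reads pre-pickled `(images, targets, cls_to_id)` tuples from hardcoded relative paths (`.../tiny_imagenet/tiny-imagenet-200/{train,val}.pkl`); running it on another machine presumably requires regenerating those pickles or editing the paths.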
/scan/images/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/scan/images/pipeline.png
--------------------------------------------------------------------------------
/scan/images/prototypes_cifar10.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/scan/images/prototypes_cifar10.jpg
--------------------------------------------------------------------------------
/scan/images/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/scan/images/teaser.jpg
--------------------------------------------------------------------------------
/scan/images/tutorial/confusion_matrix_stl10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/scan/images/tutorial/confusion_matrix_stl10.png
--------------------------------------------------------------------------------
/scan/images/tutorial/prototypes_stl10.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avihu111/TypiClust/4097a71c348f60492ab22be0c4c9da224e637af6/scan/images/tutorial/prototypes_stl10.jpg
--------------------------------------------------------------------------------
/scan/losses/losses.py:
--------------------------------------------------------------------------------
1 | """
2 | Authors: Wouter Van Gansbeke, Simon Vandenhende
3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
4 | """
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | EPS=1e-8
9 |
10 |
11 | class MaskedCrossEntropyLoss(nn.Module):
12 | def __init__(self):
13 | super(MaskedCrossEntropyLoss, self).__init__()
14 |
15 | def forward(self, input, target, mask, weight, reduction='mean'):
16 | if not (mask != 0).any():
17 | raise ValueError('Mask in MaskedCrossEntropyLoss is all zeros.')
18 | target = torch.masked_select(target, mask)
19 | b, c = input.size()
20 | n = target.size(0)
21 | input = torch.masked_select(input, mask.view(b, 1)).view(n, c)
22 | return F.cross_entropy(input, target, weight = weight, reduction = reduction)
23 |
24 |
25 | class ConfidenceBasedCE(nn.Module):
26 | def __init__(self, threshold, apply_class_balancing):
27 | super(ConfidenceBasedCE, self).__init__()
28 | self.loss = MaskedCrossEntropyLoss()
29 | self.softmax = nn.Softmax(dim = 1)
30 | self.threshold = threshold
31 | self.apply_class_balancing = apply_class_balancing
32 |
33 | def forward(self, anchors_weak, anchors_strong):
34 | """
35 | Loss function during self-labeling
36 |
37 |         input: logits for the original samples and for their strong augmentations
38 | output: cross entropy
39 | """
40 |         # Retrieve target and mask based on weakly augmented anchors
41 | weak_anchors_prob = self.softmax(anchors_weak)
42 | max_prob, target = torch.max(weak_anchors_prob, dim = 1)
43 | mask = max_prob > self.threshold
44 | b, c = weak_anchors_prob.size()
45 | target_masked = torch.masked_select(target, mask.squeeze())
46 | n = target_masked.size(0)
47 |
48 | # Inputs are strongly augmented anchors
49 | input_ = anchors_strong
50 |
51 | # Class balancing weights
52 | if self.apply_class_balancing:
53 | idx, counts = torch.unique(target_masked, return_counts = True)
54 | freq = 1/(counts.float()/n)
55 | weight = torch.ones(c).cuda()
56 | weight[idx] = freq
57 |
58 | else:
59 | weight = None
60 |
61 | # Loss
62 | loss = self.loss(input_, target, mask, weight = weight, reduction='mean')
63 |
64 | return loss
65 |
66 |
67 | def entropy(x, input_as_probabilities):
68 | """
69 | Helper function to compute the entropy over the batch
70 |
71 | input: batch w/ shape [b, num_classes]
72 |     output: entropy value [ideally log(num_classes), attained by a uniform distribution]
73 | """
74 |
75 | if input_as_probabilities:
76 | x_ = torch.clamp(x, min = EPS)
77 | b = x_ * torch.log(x_)
78 | else:
79 | b = F.softmax(x, dim = 1) * F.log_softmax(x, dim = 1)
80 |
81 | if len(b.size()) == 2: # Sample-wise entropy
82 | return -b.sum(dim = 1).mean()
83 | elif len(b.size()) == 1: # Distribution-wise entropy
84 | return - b.sum()
85 | else:
86 | raise ValueError('Input tensor is %d-Dimensional' %(len(b.size())))
87 |
88 |
89 | class SCANLoss(nn.Module):
90 | def __init__(self, entropy_weight = 2.0):
91 | super(SCANLoss, self).__init__()
92 | self.softmax = nn.Softmax(dim = 1)
93 | self.bce = nn.BCELoss()
94 | self.entropy_weight = entropy_weight # Default = 2.0
95 |
96 | def forward(self, anchors, neighbors):
97 | """
98 | input:
99 | - anchors: logits for anchor images w/ shape [b, num_classes]
100 | - neighbors: logits for neighbor images w/ shape [b, num_classes]
101 |
102 | output:
103 | - Loss
104 | """
105 | # Softmax
106 | b, n = anchors.size()
107 | anchors_prob = self.softmax(anchors)
108 | positives_prob = self.softmax(neighbors)
109 |
110 | # Similarity in output space
111 | similarity = torch.bmm(anchors_prob.view(b, 1, n), positives_prob.view(b, n, 1)).squeeze()
112 | ones = torch.ones_like(similarity)
113 | consistency_loss = self.bce(similarity, ones)
114 |
115 | # Entropy loss
116 | entropy_loss = entropy(torch.mean(anchors_prob, 0), input_as_probabilities = True)
117 |
118 | # Total loss
119 | total_loss = consistency_loss - self.entropy_weight * entropy_loss
120 |
121 | return total_loss, consistency_loss, entropy_loss
122 |
123 |
124 | class SimCLRLoss(nn.Module):
125 | # Based on the implementation of SupContrast
126 | def __init__(self, temperature):
127 | super(SimCLRLoss, self).__init__()
128 | self.temperature = temperature
129 |
130 |
131 | def forward(self, features):
132 | """
133 | input:
134 | - features: hidden feature representation of shape [b, 2, dim]
135 |
136 | output:
137 | - loss: loss computed according to SimCLR
138 | """
139 |
140 | b, n, dim = features.size()
141 | assert(n == 2)
142 | mask = torch.eye(b, dtype=torch.float32).cuda()
143 |
144 | contrast_features = torch.cat(torch.unbind(features, dim=1), dim=0)
145 | anchor = features[:, 0]
146 |
147 | # Dot product
148 | dot_product = torch.matmul(anchor, contrast_features.T) / self.temperature
149 |
150 | # Log-sum trick for numerical stability
151 | logits_max, _ = torch.max(dot_product, dim=1, keepdim=True)
152 | logits = dot_product - logits_max.detach()
153 |
154 | mask = mask.repeat(1, 2)
155 | logits_mask = torch.scatter(torch.ones_like(mask), 1, torch.arange(b).view(-1, 1).cuda(), 0)
156 | mask = mask * logits_mask
157 |
158 | # Log-softmax
159 | exp_logits = torch.exp(logits) * logits_mask
160 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
161 |
162 | # Mean log-likelihood for positive
163 | loss = - ((mask * log_prob).sum(1) / mask.sum(1)).mean()
164 |
165 | return loss
166 |
--------------------------------------------------------------------------------
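A CPU-friendly sketch of `SCANLoss` on random logits: the consistency term pulls each anchor and its mined neighbor toward the same cluster, while the entropy term (weighted by `entropy_weight`) spreads assignments across clusters. Assumes `scan/` is on `sys.path`:

```python
import torch

from losses.losses import SCANLoss

criterion = SCANLoss(entropy_weight=5.0)
anchors = torch.randn(16, 10)    # [batch, num_clusters] logits for anchors
neighbors = torch.randn(16, 10)  # logits for their mined neighbors
total, consistency, entropy_loss = criterion(anchors, neighbors)
print(total.item(), consistency.item(), entropy_loss.item())
```

(`ConfidenceBasedCE` and `SimCLRLoss` call `.cuda()` internally, so unlike `SCANLoss` they need a GPU as written.)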
/scan/moco.py:
--------------------------------------------------------------------------------
1 | """
2 | Authors: Wouter Van Gansbeke, Simon Vandenhende
3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
4 | """
5 | import argparse
6 | import os
7 | import torch
8 | import numpy as np
9 |
10 | from utils.config import create_config
11 | from utils.common_config import get_model, get_train_dataset,\
12 | get_val_dataset, get_val_dataloader, get_val_transformations
13 | from utils.memory import MemoryBank
14 | from utils.utils import fill_memory_bank
15 | from termcolor import colored
16 |
17 | # Parser
18 | parser = argparse.ArgumentParser(description='MoCo')
19 | parser.add_argument('--config_env',
20 | help='Config file for the environment')
21 | parser.add_argument('--config_exp',
22 | help='Config file for the experiment')
23 | args = parser.parse_args()
24 |
25 | def main():
26 | # Retrieve config file
27 | p = create_config(args.config_env, args.config_exp)
28 | print(colored(p, 'red'))
29 |
30 |
31 | # Model
32 | print(colored('Retrieve model', 'blue'))
33 | model = get_model(p)
34 | print('Model is {}'.format(model.__class__.__name__))
35 | print(model)
36 | model = torch.nn.DataParallel(model)
37 | model = model.cuda()
38 |
39 |
40 | # CUDNN
41 | print(colored('Set CuDNN benchmark', 'blue'))
42 | torch.backends.cudnn.benchmark = True
43 |
44 |
45 | # Dataset
46 | print(colored('Retrieve dataset', 'blue'))
47 | transforms = get_val_transformations(p)
48 | train_dataset = get_train_dataset(p, transforms)
49 | val_dataset = get_val_dataset(p, transforms)
50 | train_dataloader = get_val_dataloader(p, train_dataset)
51 | val_dataloader = get_val_dataloader(p, val_dataset)
52 | print('Dataset contains {}/{} train/val samples'.format(len(train_dataset), len(val_dataset)))
53 |
54 |
55 | # Memory Bank
56 | print(colored('Build MemoryBank', 'blue'))
57 |     memory_bank_train = MemoryBank(len(train_dataset), 2048, p['num_classes'], p['temperature'], feature_dim=2048)  # pre_lasts are 2048-d for ResNet-50
58 |     memory_bank_train.cuda()
59 |     memory_bank_val = MemoryBank(len(val_dataset), 2048, p['num_classes'], p['temperature'], feature_dim=2048)
60 |     memory_bank_val.cuda()
61 |
62 |
63 | # Load the official MoCoV2 checkpoint
64 | print(colored('Downloading moco v2 checkpoint', 'blue'))
65 | os.system('wget -L https://dl.fbaipublicfiles.com/moco/moco_checkpoints/moco_v2_800ep/moco_v2_800ep_pretrain.pth.tar')
66 | moco_state = torch.load('moco_v2_800ep_pretrain.pth.tar', map_location='cpu')
67 |
68 |
69 | # Transfer moco weights
70 | print(colored('Transfer MoCo weights to model', 'blue'))
71 | new_state_dict = {}
72 | state_dict = moco_state['state_dict']
73 | for k in list(state_dict.keys()):
74 | # Copy backbone weights
75 | if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
76 | new_k = 'module.backbone.' + k[len('module.encoder_q.'):]
77 | new_state_dict[new_k] = state_dict[k]
78 |
79 | # Copy mlp weights
80 | elif k.startswith('module.encoder_q.fc'):
81 | new_k = 'module.contrastive_head.' + k[len('module.encoder_q.fc.'):]
82 | new_state_dict[new_k] = state_dict[k]
83 |
84 | else:
85 | raise ValueError('Unexpected key {}'.format(k))
86 |
87 | model.load_state_dict(new_state_dict)
88 | os.system('rm -rf moco_v2_800ep_pretrain.pth.tar')
89 |
90 |
91 | # Save final model
92 | print(colored('Save pretext model', 'blue'))
93 | torch.save(model.module.state_dict(), p['pretext_model'])
94 | model.module.contrastive_head = torch.nn.Identity() # In this case, we mine the neighbors before the MLP.
95 |
96 |
97 | # Mine the topk nearest neighbors (Train)
98 | # These will be used for training with the SCAN-Loss.
99 | topk = 50
100 | print(colored('Mine the nearest neighbors (Train)(Top-%d)' %(topk), 'blue'))
101 | transforms = get_val_transformations(p)
102 | train_dataset = get_train_dataset(p, transforms)
103 | fill_memory_bank(train_dataloader, model, memory_bank_train)
104 | indices, acc = memory_bank_train.mine_nearest_neighbors(topk)
105 | print('Accuracy of top-%d nearest neighbors on train set is %.2f' %(topk, 100*acc))
106 | np.save(p['topk_neighbors_train_path'], indices)
107 |
108 |
109 | # Mine the topk nearest neighbors (Validation)
110 | # These will be used for validation.
111 | topk = 5
112 | print(colored('Mine the nearest neighbors (Val)(Top-%d)' %(topk), 'blue'))
113 | fill_memory_bank(val_dataloader, model, memory_bank_val)
114 | print('Mine the neighbors')
115 | indices, acc = memory_bank_val.mine_nearest_neighbors(topk)
116 | print('Accuracy of top-%d nearest neighbors on val set is %.2f' %(topk, 100*acc))
117 | np.save(p['topk_neighbors_val_path'], indices)
118 |
119 |
120 | if __name__ == '__main__':
121 | main()
122 |
--------------------------------------------------------------------------------
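A self-contained check of the key-renaming scheme in moco.py above, using a few fake MoCo keys (the real checkpoint's 'module.encoder_q.*' entries are mapped the same way): backbone weights move under 'module.backbone.', the projection MLP under 'module.contrastive_head.'.

fake_keys = [
    'module.encoder_q.conv1.weight',
    'module.encoder_q.layer1.0.bn1.running_mean',
    'module.encoder_q.fc.0.weight',
    'module.encoder_q.fc.2.bias',
]
for k in fake_keys:
    if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
        print(k, '->', 'module.backbone.' + k[len('module.encoder_q.'):])
    elif k.startswith('module.encoder_q.fc'):
        print(k, '->', 'module.contrastive_head.' + k[len('module.encoder_q.fc.'):])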
/scan/models/models.py:
--------------------------------------------------------------------------------
1 | """
2 | Authors: Wouter Van Gansbeke, Simon Vandenhende
3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
4 | """
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
10 | class ContrastiveModel(nn.Module):
11 | def __init__(self, backbone, head='mlp', features_dim=128):
12 | super(ContrastiveModel, self).__init__()
13 | self.backbone = backbone['backbone']
14 | self.backbone_dim = backbone['dim']
15 | self.head = head
16 |
17 | if head == 'linear':
18 | self.contrastive_head = nn.Linear(self.backbone_dim, features_dim)
19 |
20 | elif head == 'mlp':
21 | self.contrastive_head = nn.Sequential(
22 | nn.Linear(self.backbone_dim, self.backbone_dim),
23 | nn.ReLU(), nn.Linear(self.backbone_dim, features_dim))
24 |
25 | else:
26 | raise ValueError('Invalid head {}'.format(head))
27 |
28 | def forward(self, x, return_pre_last=False):
29 | pre_last = self.backbone(x)
30 | features = self.contrastive_head(pre_last)
31 |         features = F.normalize(features, dim=1)
32 | if return_pre_last:
33 | return features, pre_last
34 | return features
35 |
36 |
37 | class ClusteringModel(nn.Module):
38 | def __init__(self, backbone, nclusters, nheads=1):
39 | super(ClusteringModel, self).__init__()
40 | self.backbone = backbone['backbone']
41 | self.backbone_dim = backbone['dim']
42 | self.nheads = nheads
43 | assert(isinstance(self.nheads, int))
44 | assert(self.nheads > 0)
45 | self.cluster_head = nn.ModuleList([nn.Linear(self.backbone_dim, nclusters) for _ in range(self.nheads)])
46 |
47 | def forward(self, x, forward_pass='default'):
48 | if forward_pass == 'default':
49 | features = self.backbone(x)
50 | out = [cluster_head(features) for cluster_head in self.cluster_head]
51 |
52 | elif forward_pass == 'backbone':
53 | out = self.backbone(x)
54 |
55 | elif forward_pass == 'head':
56 | out = [cluster_head(x) for cluster_head in self.cluster_head]
57 |
58 | elif forward_pass == 'return_all':
59 | features = self.backbone(x)
60 | out = {'features': features, 'output': [cluster_head(features) for cluster_head in self.cluster_head]}
61 |
62 | else:
63 | raise ValueError('Invalid forward pass {}'.format(forward_pass))
64 |
65 | return out
66 |
--------------------------------------------------------------------------------
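A quick shape check of the forward_pass modes of ClusteringModel above. The backbone dict here is a toy stand-in for the repo's ResNet factories, which return the same {'backbone': ..., 'dim': ...} contract (assumes models.models is importable, i.e. run from scan/).

import torch
import torch.nn as nn
from models.models import ClusteringModel

toy = {'backbone': nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 512)), 'dim': 512}
model = ClusteringModel(toy, nclusters=10, nheads=2)
x = torch.randn(4, 3, 32, 32)

out = model(x)                                   # default: one logit tensor per head
print(len(out), out[0].shape)                    # 2, torch.Size([4, 10])
print(model(x, forward_pass='backbone').shape)   # torch.Size([4, 512])
ra = model(x, forward_pass='return_all')
print(ra['features'].shape, len(ra['output']))   # torch.Size([4, 512]), 2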
/scan/models/resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | Authors: Wouter Van Gansbeke, Simon Vandenhende
3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
4 | """
5 | import torch.nn as nn
6 | import torchvision.models as models
7 |
8 |
9 | def resnet50():
10 | backbone = models.__dict__['resnet50']()
11 | backbone.fc = nn.Identity()
12 | return {'backbone': backbone, 'dim': 2048}
13 |
--------------------------------------------------------------------------------
/scan/models/resnet_cifar.py:
--------------------------------------------------------------------------------
1 | """
2 | This code is based on the Torchvision repository, which is licensed under the BSD 3-Clause license.
3 | """
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class BasicBlock(nn.Module):
10 | expansion = 1
11 |
12 | def __init__(self, in_planes, planes, stride=1, is_last=False):
13 | super(BasicBlock, self).__init__()
14 | self.is_last = is_last
15 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
16 | self.bn1 = nn.BatchNorm2d(planes)
17 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
18 | self.bn2 = nn.BatchNorm2d(planes)
19 |
20 | self.shortcut = nn.Sequential()
21 | if stride != 1 or in_planes != self.expansion * planes:
22 | self.shortcut = nn.Sequential(
23 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
24 | nn.BatchNorm2d(self.expansion * planes)
25 | )
26 |
27 | def forward(self, x):
28 | out = F.relu(self.bn1(self.conv1(x)))
29 | out = self.bn2(self.conv2(out))
30 | out += self.shortcut(x)
31 | preact = out
32 | out = F.relu(out)
33 | if self.is_last:
34 | return out, preact
35 | else:
36 | return out
37 |
38 |
39 | class Bottleneck(nn.Module):
40 | expansion = 4
41 |
42 | def __init__(self, in_planes, planes, stride=1, is_last=False):
43 | super(Bottleneck, self).__init__()
44 | self.is_last = is_last
45 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
46 | self.bn1 = nn.BatchNorm2d(planes)
47 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
48 | self.bn2 = nn.BatchNorm2d(planes)
49 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
50 | self.bn3 = nn.BatchNorm2d(self.expansion * planes)
51 |
52 | self.shortcut = nn.Sequential()
53 | if stride != 1 or in_planes != self.expansion * planes:
54 | self.shortcut = nn.Sequential(
55 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
56 | nn.BatchNorm2d(self.expansion * planes)
57 | )
58 |
59 | def forward(self, x):
60 | out = F.relu(self.bn1(self.conv1(x)))
61 | out = F.relu(self.bn2(self.conv2(out)))
62 | out = self.bn3(self.conv3(out))
63 | out += self.shortcut(x)
64 | preact = out
65 | out = F.relu(out)
66 | if self.is_last:
67 | return out, preact
68 | else:
69 | return out
70 |
71 |
72 | class ResNet(nn.Module):
73 | def __init__(self, block, num_blocks, in_channel=3, zero_init_residual=False):
74 | super(ResNet, self).__init__()
75 | self.in_planes = 64
76 |
77 | self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1,
78 | bias=False)
79 | self.bn1 = nn.BatchNorm2d(64)
80 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
81 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
82 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
83 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
84 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
85 |
86 | for m in self.modules():
87 | if isinstance(m, nn.Conv2d):
88 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
89 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
90 | nn.init.constant_(m.weight, 1)
91 | nn.init.constant_(m.bias, 0)
92 |
93 | # Zero-initialize the last BN in each residual branch,
94 | # so that the residual branch starts with zeros, and each residual block behaves
95 | # like an identity. This improves the model by 0.2~0.3% according to:
96 | # https://arxiv.org/abs/1706.02677
97 | if zero_init_residual:
98 | for m in self.modules():
99 | if isinstance(m, Bottleneck):
100 | nn.init.constant_(m.bn3.weight, 0)
101 | elif isinstance(m, BasicBlock):
102 | nn.init.constant_(m.bn2.weight, 0)
103 |
104 | def _make_layer(self, block, planes, num_blocks, stride):
105 | strides = [stride] + [1] * (num_blocks - 1)
106 | layers = []
107 | for i in range(num_blocks):
108 | stride = strides[i]
109 | layers.append(block(self.in_planes, planes, stride))
110 | self.in_planes = planes * block.expansion
111 | return nn.Sequential(*layers)
112 |
113 | def forward(self, x):
114 | out = F.relu(self.bn1(self.conv1(x)))
115 | out = self.layer1(out)
116 | out = self.layer2(out)
117 | out = self.layer3(out)
118 | out = self.layer4(out)
119 | out = self.avgpool(out)
120 | out = torch.flatten(out, 1)
121 | return out
122 |
123 |
124 | def resnet18(**kwargs):
125 | return {'backbone': ResNet(BasicBlock, [2, 2, 2, 2], **kwargs), 'dim': 512}
126 |
--------------------------------------------------------------------------------
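A minimal sanity check of the CIFAR variant above (assumes it runs from the scan/ directory): a 32x32 input is downsampled to 4x4 before the adaptive average pool, and the factory's 'dim' matches the flattened feature size.

import torch
from models.resnet_cifar import resnet18

d = resnet18()
feats = d['backbone'](torch.randn(2, 3, 32, 32))
print(feats.shape, d['dim'])  # torch.Size([2, 512]) 512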
/scan/models/resnet_stl.py:
--------------------------------------------------------------------------------
1 | """
2 | This code is based on the Torchvision repository, which is licensed under the BSD 3-Clause license.
3 | """
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class BasicBlock(nn.Module):
10 | expansion = 1
11 |
12 | def __init__(self, in_planes, planes, stride=1, is_last=False):
13 | super(BasicBlock, self).__init__()
14 | self.is_last = is_last
15 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
16 | self.bn1 = nn.BatchNorm2d(planes)
17 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
18 | self.bn2 = nn.BatchNorm2d(planes)
19 |
20 | self.shortcut = nn.Sequential()
21 | if stride != 1 or in_planes != self.expansion * planes:
22 | self.shortcut = nn.Sequential(
23 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
24 | nn.BatchNorm2d(self.expansion * planes)
25 | )
26 |
27 | def forward(self, x):
28 | out = F.relu(self.bn1(self.conv1(x)))
29 | out = self.bn2(self.conv2(out))
30 | out += self.shortcut(x)
31 | preact = out
32 | out = F.relu(out)
33 | if self.is_last:
34 | return out, preact
35 | else:
36 | return out
37 |
38 |
39 | class Bottleneck(nn.Module):
40 | expansion = 4
41 |
42 | def __init__(self, in_planes, planes, stride=1, is_last=False):
43 | super(Bottleneck, self).__init__()
44 | self.is_last = is_last
45 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
46 | self.bn1 = nn.BatchNorm2d(planes)
47 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
48 | self.bn2 = nn.BatchNorm2d(planes)
49 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
50 | self.bn3 = nn.BatchNorm2d(self.expansion * planes)
51 |
52 | self.shortcut = nn.Sequential()
53 | if stride != 1 or in_planes != self.expansion * planes:
54 | self.shortcut = nn.Sequential(
55 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
56 | nn.BatchNorm2d(self.expansion * planes)
57 | )
58 |
59 | def forward(self, x):
60 | out = F.relu(self.bn1(self.conv1(x)))
61 | out = F.relu(self.bn2(self.conv2(out)))
62 | out = self.bn3(self.conv3(out))
63 | out += self.shortcut(x)
64 | preact = out
65 | out = F.relu(out)
66 | if self.is_last:
67 | return out, preact
68 | else:
69 | return out
70 |
71 |
72 | class ResNet(nn.Module):
73 | def __init__(self, block, num_blocks, in_channel=3, zero_init_residual=False):
74 | super(ResNet, self).__init__()
75 | self.in_planes = 64
76 |
77 | self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1,
78 | bias=False)
79 | self.bn1 = nn.BatchNorm2d(64)
80 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, padding=1)
81 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
82 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
83 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
84 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
85 | self.avgpool = nn.AvgPool2d(7, stride=1)
86 |
87 | for m in self.modules():
88 | if isinstance(m, nn.Conv2d):
89 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
90 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
91 | nn.init.constant_(m.weight, 1)
92 | nn.init.constant_(m.bias, 0)
93 |
94 | # Zero-initialize the last BN in each residual branch,
95 | # so that the residual branch starts with zeros, and each residual block behaves
96 | # like an identity. This improves the model by 0.2~0.3% according to:
97 | # https://arxiv.org/abs/1706.02677
98 | if zero_init_residual:
99 | for m in self.modules():
100 | if isinstance(m, Bottleneck):
101 | nn.init.constant_(m.bn3.weight, 0)
102 | elif isinstance(m, BasicBlock):
103 | nn.init.constant_(m.bn2.weight, 0)
104 |
105 | def _make_layer(self, block, planes, num_blocks, stride):
106 | strides = [stride] + [1] * (num_blocks - 1)
107 | layers = []
108 | for i in range(num_blocks):
109 | stride = strides[i]
110 | layers.append(block(self.in_planes, planes, stride))
111 | self.in_planes = planes * block.expansion
112 | return nn.Sequential(*layers)
113 |
114 | def forward(self, x):
115 | out = self.maxpool(F.relu(self.bn1(self.conv1(x))))
116 | out = self.layer1(out)
117 | out = self.layer2(out)
118 | out = self.layer3(out)
119 | out = self.layer4(out)
120 | out = self.avgpool(out)
121 | out = torch.flatten(out, 1)
122 | return out
123 |
124 |
125 | def resnet18(**kwargs):
126 | return {'backbone': ResNet(BasicBlock, [2, 2, 2, 2], **kwargs), 'dim': 512}
127 |
--------------------------------------------------------------------------------
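The STL-10 variant above differs from the CIFAR one by a stride-2 max pool after conv1 and a fixed AvgPool2d(7). A quick check (assumes scan/ is the working directory) that the spatial arithmetic works out for 96x96 inputs: 96 -> 49 (max pool) -> 49 -> 25 -> 13 -> 7 (layers 1-4) -> 1 (average pool).

import torch
from models.resnet_stl import resnet18

net = resnet18()['backbone']
print(net(torch.randn(2, 3, 96, 96)).shape)  # torch.Size([2, 512])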
/scan/models/resnet_tinyimagenet.py:
--------------------------------------------------------------------------------
1 | """
2 | This code is based on the Torchvision repository, which is licensed under the BSD 3-Clause license.
3 | """
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class BasicBlock(nn.Module):
10 | expansion = 1
11 |
12 | def __init__(self, in_planes, planes, stride=1, is_last=False):
13 | super(BasicBlock, self).__init__()
14 | self.is_last = is_last
15 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
16 | self.bn1 = nn.BatchNorm2d(planes)
17 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
18 | self.bn2 = nn.BatchNorm2d(planes)
19 |
20 | self.shortcut = nn.Sequential()
21 | if stride != 1 or in_planes != self.expansion * planes:
22 | self.shortcut = nn.Sequential(
23 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
24 | nn.BatchNorm2d(self.expansion * planes)
25 | )
26 |
27 | def forward(self, x):
28 | out = F.relu(self.bn1(self.conv1(x)))
29 | out = self.bn2(self.conv2(out))
30 | out += self.shortcut(x)
31 | preact = out
32 | out = F.relu(out)
33 | if self.is_last:
34 | return out, preact
35 | else:
36 | return out
37 |
38 |
39 | class Bottleneck(nn.Module):
40 | expansion = 4
41 |
42 | def __init__(self, in_planes, planes, stride=1, is_last=False):
43 | super(Bottleneck, self).__init__()
44 | self.is_last = is_last
45 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
46 | self.bn1 = nn.BatchNorm2d(planes)
47 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
48 | self.bn2 = nn.BatchNorm2d(planes)
49 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
50 | self.bn3 = nn.BatchNorm2d(self.expansion * planes)
51 |
52 | self.shortcut = nn.Sequential()
53 | if stride != 1 or in_planes != self.expansion * planes:
54 | self.shortcut = nn.Sequential(
55 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
56 | nn.BatchNorm2d(self.expansion * planes)
57 | )
58 |
59 | def forward(self, x):
60 | out = F.relu(self.bn1(self.conv1(x)))
61 | out = F.relu(self.bn2(self.conv2(out)))
62 | out = self.bn3(self.conv3(out))
63 | out += self.shortcut(x)
64 | preact = out
65 | out = F.relu(out)
66 | if self.is_last:
67 | return out, preact
68 | else:
69 | return out
70 |
71 |
72 | class ResNet(nn.Module):
73 | def __init__(self, block, num_blocks, in_channel=3, zero_init_residual=False):
74 | super(ResNet, self).__init__()
75 | self.in_planes = 64
76 |
77 | self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1,
78 | bias=False)
79 | self.bn1 = nn.BatchNorm2d(64)
80 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=2)
81 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
82 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
83 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
84 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
85 |
86 | for m in self.modules():
87 | if isinstance(m, nn.Conv2d):
88 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
89 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
90 | nn.init.constant_(m.weight, 1)
91 | nn.init.constant_(m.bias, 0)
92 |
93 | # Zero-initialize the last BN in each residual branch,
94 | # so that the residual branch starts with zeros, and each residual block behaves
95 | # like an identity. This improves the model by 0.2~0.3% according to:
96 | # https://arxiv.org/abs/1706.02677
97 | if zero_init_residual:
98 | for m in self.modules():
99 | if isinstance(m, Bottleneck):
100 | nn.init.constant_(m.bn3.weight, 0)
101 | elif isinstance(m, BasicBlock):
102 | nn.init.constant_(m.bn2.weight, 0)
103 |
104 | def _make_layer(self, block, planes, num_blocks, stride):
105 | strides = [stride] + [1] * (num_blocks - 1)
106 | layers = []
107 | for i in range(num_blocks):
108 | stride = strides[i]
109 | layers.append(block(self.in_planes, planes, stride))
110 | self.in_planes = planes * block.expansion
111 | return nn.Sequential(*layers)
112 |
113 | def forward(self, x):
114 | out = F.relu(self.bn1(self.conv1(x)))
115 | out = self.layer1(out)
116 | out = self.layer2(out)
117 | out = self.layer3(out)
118 | out = self.layer4(out)
119 | out = self.avgpool(out)
120 | out = torch.flatten(out, 1)
121 | return out
122 |
123 |
124 | def resnet18(**kwargs):
125 | return {'backbone': ResNet(BasicBlock, [2, 2, 2, 2], **kwargs), 'dim': 512}
126 |
--------------------------------------------------------------------------------
/scan/requirements.txt:
--------------------------------------------------------------------------------
1 | """ This file contains a list of packages and their versions that were used to produce the results. """
2 | - _libgcc_mutex=0.1=main
3 | - blas=1.0=mkl
4 | - bzip2=1.0.8=h7b6447c_0
5 | - ca-certificates=2020.1.1=0
6 | - cairo=1.14.12=h8948797_3
7 | - certifi=2020.4.5.1=py37_0
8 | - cffi=1.14.0=py37h2e261b9_0
9 | - cmake=3.14.0=h52cb24c_0
10 | - cudatoolkit=10.0.130=0
11 | - cycler=0.10.0=py37_0
12 | - dbus=1.13.12=h746ee38_0
13 | - easydict=1.9=py_0
14 | - expat=2.2.6=he6710b0_0
15 | - faiss-gpu=1.6.3=py37h1a5d453_0
16 | - ffmpeg=4.0=hcdf2ecd_0
17 | - fontconfig=2.13.0=h9420a91_0
18 | - freeglut=3.0.0=hf484d3e_5
19 | - freetype=2.9.1=h8a8886c_1
20 | - glib=2.63.1=h5a9c865_0
21 | - graphite2=1.3.13=h23475e2_0
22 | - gst-plugins-base=1.14.0=hbbd80ab_1
23 | - gstreamer=1.14.0=hb453b48_1
24 | - h5py=2.8.0=py37h989c5e5_3
25 | - harfbuzz=1.8.8=hffaf4a1_0
26 | - hdf5=1.10.2=hba1933b_1
27 | - icu=58.2=h9c2bf20_1
28 | - imageio=2.8.0=py_0
29 | - intel-openmp=2020.0=166
30 | - jasper=2.0.14=h07fcdf6_1
31 | - joblib=0.14.1=py_0
32 | - jpeg=9b=h024ee3a_2
33 | - kiwisolver=1.1.0=py37he6710b0_0
34 | - krb5=1.17.1=h173b8e3_0
35 | - ld_impl_linux-64=2.33.1=h53a641e_7
36 | - libcurl=7.69.1=h20c2e04_0
37 | - libedit=3.1.20181209=hc058e9b_0
38 | - libffi=3.2.1=hd88cf55_4
39 | - libgcc-ng=9.1.0=hdf63c60_0
40 | - libgfortran-ng=7.3.0=hdf63c60_0
41 | - libglu=9.0.0=hf484d3e_1
42 | - libopencv=3.4.2=hb342d67_1
43 | - libopus=1.3.1=h7b6447c_0
44 | - libpng=1.6.37=hbc83047_0
45 | - libprotobuf=3.11.4=hd408876_0
46 | - libssh2=1.9.0=h1ba5d50_1
47 | - libstdcxx-ng=9.1.0=hdf63c60_0
48 | - libtiff=4.1.0=h2733197_0
49 | - libuuid=1.0.3=h1bed415_2
50 | - libvpx=1.7.0=h439df22_0
51 | - libxcb=1.13=h1bed415_1
52 | - libxml2=2.9.9=hea5a465_1
53 | - matplotlib=3.1.3=py37_0
54 | - matplotlib-base=3.1.3=py37hef1b27d_0
55 | - mkl=2020.0=166
56 | - mkl-service=2.3.0=py37he904b0f_0
57 | - mkl_fft=1.0.15=py37ha843d7b_0
58 | - mkl_random=1.1.0=py37hd6b4f25_0
59 | - ncurses=6.2=he6710b0_0
60 | - ninja=1.9.0=py37hfd86e86_0
61 | - numpy=1.18.1=py37h4f9e942_0
62 | - numpy-base=1.18.1=py37hde5b4d6_1
63 | - olefile=0.46=py_0
64 | - opencv=3.4.2=py37h6fd60c2_1
65 | - openssl=1.1.1g=h7b6447c_0
66 | - pcre=8.43=he6710b0_0
67 | - pillow=7.0.0=py37hb39fc2d_0
68 | - pip=20.0.2=py37_1
69 | - pixman=0.38.0=h7b6447c_0
70 | - protobuf=3.11.4=py37he6710b0_0
71 | - py-opencv=3.4.2=py37hb342d67_1
72 | - pycparser=2.20=py_0
73 | - pyparsing=2.4.6=py_0
74 | - pyqt=5.9.2=py37h05f1152_2
75 | - python=3.7.7=hcf32534_0_cpython
76 | - python-dateutil=2.8.1=py_0
77 | - pytorch=1.4.0=py3.7_cuda10.0.130_cudnn7.6.3_0
78 | - pyyaml=5.3.1=py37h7b6447c_0
79 | - qt=5.9.7=h5867ecd_1
80 | - readline=8.0=h7b6447c_0
81 | - rhash=1.3.8=h1ba5d50_0
82 | - scikit-learn=0.22.1=py37hd81dba3_0
83 | - scipy=1.4.1=py37h0b6359f_0
84 | - setuptools=46.1.3=py37_0
85 | - sip=4.19.8=py37hf484d3e_0
86 | - six=1.14.0=py37_0
87 | - sqlite=3.31.1=h7b6447c_0
88 | - swig=3.0.12=h38cdd7d_3
89 | - tensorboardx=2.0=py_0
90 | - termcolor=1.1.0=py37_1
91 | - tk=8.6.8=hbc83047_0
92 | - torchvision=0.5.0=py37_cu100
93 | - tornado=6.0.4=py37h7b6447c_1
94 | - typing=3.6.4=py37_0
95 | - wheel=0.34.2=py37_0
96 | - xz=5.2.4=h14c3975_4
97 | - yaml=0.1.7=had09818_2
98 | - zlib=1.2.11=h7b6447c_3
99 | - zstd=1.3.7=h0b5b093_0
100 | - pip:
101 | - blis==0.4.1
102 | - catalogue==1.0.0
103 | - chardet==3.0.4
104 | - cymem==2.0.3
105 | - en-core-web-sm==2.2.5
106 | - idna==2.9
107 | - importlib-metadata==1.6.0
108 | - murmurhash==1.0.2
109 | - plac==1.1.3
110 | - preshed==3.0.2
111 | - requests==2.23.0
112 | - spacy==2.2.4
113 | - srsly==1.0.2
114 | - thinc==7.4.0
115 | - tqdm==4.45.0
116 | - urllib3==1.25.8
117 | - wasabi==0.6.0
118 | - zipp==3.1.0
119 |
--------------------------------------------------------------------------------
/scan/selflabel.py:
--------------------------------------------------------------------------------
1 | """
2 | Authors: Wouter Van Gansbeke, Simon Vandenhende
3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
4 | """
5 | import argparse
6 | import os
7 | import torch
8 |
9 | from utils.config import create_config
10 | from utils.common_config import get_train_dataset, get_train_transformations,\
11 | get_val_dataset, get_val_transformations,\
12 | get_train_dataloader, get_val_dataloader,\
13 | get_optimizer, get_model, adjust_learning_rate,\
14 | get_criterion
15 | from utils.ema import EMA
16 | from utils.evaluate_utils import get_predictions, hungarian_evaluate
17 | from utils.train_utils import selflabel_train
18 | from termcolor import colored
19 |
20 | # Parser
21 | parser = argparse.ArgumentParser(description='Self-labeling')
22 | parser.add_argument('--config_env', help='Config file for the environment')
23 | parser.add_argument('--config_exp', help='Config file for the experiment')
24 | parser.add_argument('--seed', default=0, type=int,
25 |                     help='Seed used in output filenames (create_config requires it)')
26 | args = parser.parse_args()
27 |
28 | def main():
29 | # Retrieve config file
30 |     p = create_config(args.config_env, args.config_exp, args.seed)
31 | print(colored(p, 'red'))
32 |
33 | # Get model
34 | print(colored('Retrieve model', 'blue'))
35 | model = get_model(p, p['scan_model'])
36 | print(model)
37 | model = torch.nn.DataParallel(model)
38 | model = model.cuda()
39 |
40 | # Get criterion
41 | print(colored('Get loss', 'blue'))
42 | criterion = get_criterion(p)
43 | criterion.cuda()
44 | print(criterion)
45 |
46 | # CUDNN
47 | print(colored('Set CuDNN benchmark', 'blue'))
48 | torch.backends.cudnn.benchmark = True
49 |
50 | # Optimizer
51 | print(colored('Retrieve optimizer', 'blue'))
52 | optimizer = get_optimizer(p, model)
53 | print(optimizer)
54 |
55 | # Dataset
56 | print(colored('Retrieve dataset', 'blue'))
57 |
58 | # Transforms
59 | strong_transforms = get_train_transformations(p)
60 | val_transforms = get_val_transformations(p)
61 | train_dataset = get_train_dataset(p, {'standard': val_transforms, 'augment': strong_transforms},
62 | split='train', to_augmented_dataset=True)
63 | train_dataloader = get_train_dataloader(p, train_dataset)
64 | val_dataset = get_val_dataset(p, val_transforms)
65 | val_dataloader = get_val_dataloader(p, val_dataset)
66 | print(colored('Train samples %d - Val samples %d' %(len(train_dataset), len(val_dataset)), 'yellow'))
67 |
68 | # Checkpoint
69 | if os.path.exists(p['selflabel_checkpoint']):
70 | print(colored('Restart from checkpoint {}'.format(p['selflabel_checkpoint']), 'blue'))
71 | checkpoint = torch.load(p['selflabel_checkpoint'], map_location='cpu')
72 | model.load_state_dict(checkpoint['model'])
73 | optimizer.load_state_dict(checkpoint['optimizer'])
74 | start_epoch = checkpoint['epoch']
75 |
76 | else:
77 | print(colored('No checkpoint file at {}'.format(p['selflabel_checkpoint']), 'blue'))
78 | start_epoch = 0
79 |
80 | # EMA
81 | if p['use_ema']:
82 | ema = EMA(model, alpha=p['ema_alpha'])
83 | else:
84 | ema = None
85 |
86 | # Main loop
87 | print(colored('Starting main loop', 'blue'))
88 |
89 | for epoch in range(start_epoch, p['epochs']):
90 | print(colored('Epoch %d/%d' %(epoch+1, p['epochs']), 'yellow'))
91 | print(colored('-'*10, 'yellow'))
92 |
93 | # Adjust lr
94 | lr = adjust_learning_rate(p, optimizer, epoch)
95 | print('Adjusted learning rate to {:.5f}'.format(lr))
96 |
97 | # Perform self-labeling
98 | print('Train ...')
99 | selflabel_train(train_dataloader, model, criterion, optimizer, epoch, ema=ema)
100 |
101 | # Evaluate (To monitor progress - Not for validation)
102 | print('Evaluate ...')
103 | predictions = get_predictions(p, val_dataloader, model)
104 | clustering_stats = hungarian_evaluate(0, predictions, compute_confusion_matrix=False)
105 | print(clustering_stats)
106 |
107 | # Checkpoint
108 | print('Checkpoint ...')
109 | torch.save({'optimizer': optimizer.state_dict(), 'model': model.state_dict(),
110 | 'epoch': epoch + 1}, p['selflabel_checkpoint'])
111 | torch.save(model.module.state_dict(), p['selflabel_model'])
112 |
113 | # Evaluate and save the final model
114 | print(colored('Evaluate model at the end', 'blue'))
115 | predictions = get_predictions(p, val_dataloader, model)
116 | clustering_stats = hungarian_evaluate(0, predictions,
117 | class_names=val_dataset.classes,
118 | compute_confusion_matrix=True,
119 | confusion_matrix_file=os.path.join(p['selflabel_dir'], 'confusion_matrix.png'))
120 | print(clustering_stats)
121 | torch.save(model.module.state_dict(), p['selflabel_model'])
122 |
123 |
124 | if __name__ == "__main__":
125 | main()
126 |
--------------------------------------------------------------------------------
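The criterion fetched by get_criterion in selflabel.py is the confidence-based cross-entropy from the SCAN paper: predictions on the standard view that clear a confidence threshold become hard pseudo-labels, and the network is trained to reproduce them on the strongly augmented view. A minimal sketch of that idea (the threshold value here is illustrative; the repo's actual criterion in losses.py also supports class balancing):

import torch
import torch.nn.functional as F

def confident_selflabel_loss(output, output_augmented, threshold=0.99):
    # Hard pseudo-labels from the standard view (computed under no_grad in the training loop)
    probs = F.softmax(output, dim=1)
    max_prob, target = probs.max(dim=1)
    mask = max_prob > threshold                 # keep only confident samples
    if mask.sum() == 0:
        return output_augmented.sum() * 0.0     # no confident samples in this batch
    # Cross-entropy on the strongly augmented view, restricted to confident samples
    return F.cross_entropy(output_augmented[mask], target[mask])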
/scan/tutorial_nn.py:
--------------------------------------------------------------------------------
1 | """
2 | Authors: Wouter Van Gansbeke
3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
4 | """
5 | import argparse
6 | import os
7 | import numpy as np
8 | import torch
9 |
10 | from utils.config import create_config
11 | from utils.common_config import get_model, get_train_dataset, \
12 | get_val_dataset, \
13 | get_val_dataloader, \
14 |                                 get_val_transformations
15 |
16 | from utils.memory import MemoryBank
17 | from utils.train_utils import simclr_train
18 | from utils.utils import fill_memory_bank
19 | from termcolor import colored
20 |
21 | # Parser
22 | parser = argparse.ArgumentParser(description='Eval_nn')
23 | parser.add_argument('--config_env', help='Config file for the environment')
24 | parser.add_argument('--config_exp', help='Config file for the experiment')
25 | parser.add_argument('--seed', default=0, type=int,
26 |                     help='Seed used in output filenames (create_config requires it)')
27 | args = parser.parse_args()
28 |
29 | def main():
30 |
31 | # Retrieve config file
32 |     p = create_config(args.config_env, args.config_exp, args.seed)
33 | print(colored(p, 'red'))
34 |
35 | # Model
36 | print(colored('Retrieve model', 'blue'))
37 | model = get_model(p)
38 | print('Model is {}'.format(model.__class__.__name__))
39 |     print('Model parameters: {:.2f}M'.format(sum(param.numel() for param in model.parameters()) / 1e6))
40 | print(model)
41 | model = model.cuda()
42 |
43 | # CUDNN
44 | print(colored('Set CuDNN benchmark', 'blue'))
45 | torch.backends.cudnn.benchmark = True
46 |
47 | # Dataset
48 | val_transforms = get_val_transformations(p)
49 | print('Validation transforms:', val_transforms)
50 | val_dataset = get_val_dataset(p, val_transforms)
51 | val_dataloader = get_val_dataloader(p, val_dataset)
52 | print('Dataset contains {} val samples'.format(len(val_dataset)))
53 |
54 | # Memory Bank
55 | print(colored('Build MemoryBank', 'blue'))
56 | base_dataset = get_train_dataset(p, val_transforms, split='train') # Dataset w/o augs for knn eval
57 | base_dataloader = get_val_dataloader(p, base_dataset)
58 | memory_bank_base = MemoryBank(len(base_dataset),
59 | p['model_kwargs']['features_dim'],
60 | p['num_classes'], p['criterion_kwargs']['temperature'])
61 | memory_bank_base.cuda()
62 | memory_bank_val = MemoryBank(len(val_dataset),
63 | p['model_kwargs']['features_dim'],
64 | p['num_classes'], p['criterion_kwargs']['temperature'])
65 | memory_bank_val.cuda()
66 |
67 | # Checkpoint
68 | assert os.path.exists(p['pretext_checkpoint'])
69 | print(colored('Restart from checkpoint {}'.format(p['pretext_checkpoint']), 'blue'))
70 | checkpoint = torch.load(p['pretext_checkpoint'], map_location='cpu')
71 | model.load_state_dict(checkpoint)
72 | model.cuda()
73 |
74 | # Save model
75 | torch.save(model.state_dict(), p['pretext_model'])
76 |
77 | # Mine the topk nearest neighbors at the very end (Train)
78 | # These will be served as input to the SCAN loss.
79 | print(colored('Fill memory bank for mining the nearest neighbors (train) ...', 'blue'))
80 | fill_memory_bank(base_dataloader, model, memory_bank_base)
81 | topk = 20
82 | print('Mine the nearest neighbors (Top-%d)' %(topk))
83 | indices, acc = memory_bank_base.mine_nearest_neighbors(topk)
84 | print('Accuracy of top-%d nearest neighbors on train set is %.2f' %(topk, 100*acc))
85 | np.save(p['topk_neighbors_train_path'], indices)
86 |
87 | # Mine the topk nearest neighbors at the very end (Val)
88 | # These will be used for validation.
89 | print(colored('Fill memory bank for mining the nearest neighbors (val) ...', 'blue'))
90 | fill_memory_bank(val_dataloader, model, memory_bank_val)
91 | topk = 5
92 | print('Mine the nearest neighbors (Top-%d)' %(topk))
93 | indices, acc = memory_bank_val.mine_nearest_neighbors(topk)
94 | print('Accuracy of top-%d nearest neighbors on val set is %.2f' %(topk, 100*acc))
95 | np.save(p['topk_neighbors_val_path'], indices)
96 |
97 |
98 | if __name__ == '__main__':
99 | main()
100 |
--------------------------------------------------------------------------------
/scan/utils/collate.py:
--------------------------------------------------------------------------------
1 | """
2 | Authors: Wouter Van Gansbeke, Simon Vandenhende
3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
4 | """
5 | import torch
6 | import numpy as np
7 | import collections
8 | import collections.abc
9 |
10 |
11 | """ Custom collate function """
12 | def collate_custom(batch):
13 | if isinstance(batch[0], np.int64):
14 | return np.stack(batch, 0)
15 |
16 | if isinstance(batch[0], torch.Tensor):
17 | return torch.stack(batch, 0)
18 |
19 | elif isinstance(batch[0], np.ndarray):
20 | return np.stack(batch, 0)
21 |
22 | elif isinstance(batch[0], int):
23 | return torch.LongTensor(batch)
24 |
25 | elif isinstance(batch[0], float):
26 | return torch.FloatTensor(batch)
27 |
28 |     elif isinstance(batch[0], str):
29 | return batch
30 |
31 |     elif isinstance(batch[0], collections.abc.Mapping):
32 | batch_modified = {key: collate_custom([d[key] for d in batch]) for key in batch[0] if key.find('idx') < 0}
33 | return batch_modified
34 |
35 |     elif isinstance(batch[0], collections.abc.Sequence):
36 | transposed = zip(*batch)
37 | return [collate_custom(samples) for samples in transposed]
38 |
39 |     raise TypeError('Type is {}'.format(type(batch[0])))
40 |
--------------------------------------------------------------------------------
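A small usage sketch (assumes utils.collate is importable): dict batches are collated key by key, and keys containing 'idx' are dropped, so bookkeeping indices never reach the model.

import torch
from utils.collate import collate_custom

batch = [{'image': torch.randn(3, 32, 32), 'target': 1, 'meta_idx': 0},
         {'image': torch.randn(3, 32, 32), 'target': 7, 'meta_idx': 1}]
out = collate_custom(batch)
print(sorted(out.keys()))   # ['image', 'target']  ('meta_idx' is dropped)
print(out['image'].shape)   # torch.Size([2, 3, 32, 32])
print(out['target'])        # tensor([1, 7])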
/scan/utils/config.py:
--------------------------------------------------------------------------------
1 | """
2 | Authors: Wouter Van Gansbeke, Simon Vandenhende
3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
4 | """
5 | import os
6 | import yaml
7 | from easydict import EasyDict
8 | from utils.utils import mkdir_if_missing
9 |
10 | def create_config(config_file_env, config_file_exp, seed, num_clusters=None):
11 | # Config for environment path
12 | with open(config_file_env, 'r') as stream:
13 | root_dir = yaml.safe_load(stream)['root_dir']
14 |
15 | with open(config_file_exp, 'r') as stream:
16 | config = yaml.safe_load(stream)
17 |
18 | config['seed'] = seed
19 | if num_clusters is not None:
20 | config['num_classes'] = num_clusters
21 |
22 | cfg = EasyDict()
23 |
24 | # Copy
25 | for k, v in config.items():
26 | cfg[k] = v
27 |
28 | # Set paths for pretext task (These directories are needed in every stage)
29 | base_dir = os.path.join(root_dir, cfg['train_db_name'])
30 | pretext_dir = os.path.join(base_dir, 'pretext')
31 | mkdir_if_missing(base_dir)
32 | mkdir_if_missing(pretext_dir)
33 | cfg['pretext_dir'] = pretext_dir
34 | cfg['pretext_checkpoint'] = os.path.join(pretext_dir, f'checkpoint_seed{seed}.pth.tar')
35 | cfg['pretext_model'] = os.path.join(pretext_dir, f'model_seed{seed}.pth.tar')
36 | cfg['pretext_features'] = os.path.join(pretext_dir, f'features_seed{seed}.npy')
37 | cfg['topk_neighbors_train_path'] = os.path.join(pretext_dir, f'topk-train-neighbors_seed{seed}.npy')
38 | cfg['topk_neighbors_val_path'] = os.path.join(pretext_dir, f'topk-val-neighbors_seed{seed}.npy')
39 |
40 | # If we perform clustering or self-labeling step we need additional paths.
41 | # We also include a run identifier to support multiple runs w/ same hyperparams.
42 | if cfg['setup'] in ['scan', 'selflabel']:
43 | base_dir = os.path.join(root_dir, cfg['train_db_name'])
44 | scan_dir = os.path.join(base_dir, 'scan')
45 | selflabel_dir = os.path.join(base_dir, 'selflabel')
46 | mkdir_if_missing(base_dir)
47 | mkdir_if_missing(scan_dir)
48 | mkdir_if_missing(selflabel_dir)
49 | cfg['scan_dir'] = scan_dir
50 | cfg['scan_checkpoint'] = os.path.join(scan_dir, f'checkpoint_seed{seed}_clusters{num_clusters}.pth.tar')
51 | cfg['scan_model'] = os.path.join(scan_dir, f'model_seed{seed}_clusters{num_clusters}.pth.tar')
52 | cfg['scan_features'] = os.path.join(scan_dir, f'features_seed{seed}_clusters{num_clusters}.npy')
53 | cfg['selflabel_dir'] = selflabel_dir
54 | cfg['selflabel_checkpoint'] = os.path.join(selflabel_dir, f'checkpoint_seed{seed}_clusters{num_clusters}.pth.tar')
55 | cfg['selflabel_model'] = os.path.join(selflabel_dir, f'model_seed{seed}_clusters{num_clusters}.pth.tar')
56 |
57 | return cfg
58 |
--------------------------------------------------------------------------------
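For reference, the artifact paths this produces when called with the shipped CIFAR-10 pretext config (run from scan/; the prefix depends on root_dir in configs/env.yml, and the call also creates the directories):

from utils.config import create_config

p = create_config('configs/env.yml', 'configs/pretext/simclr_cifar10.yml', seed=1)
print(p['pretext_checkpoint'])         # <root_dir>/<train_db_name>/pretext/checkpoint_seed1.pth.tar
print(p['topk_neighbors_train_path'])  # <root_dir>/<train_db_name>/pretext/topk-train-neighbors_seed1.npy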
/scan/utils/ema.py:
--------------------------------------------------------------------------------
1 | """
2 | Authors: Wouter Van Gansbeke, Simon Vandenhende
3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
4 | """
5 |
6 | class EMA(object):
7 | def __init__(self, model, alpha=0.999):
8 | self.shadow = {k: v.clone().detach() for k, v in model.state_dict().items()}
9 | self.param_keys = [k for k, _ in model.named_parameters()]
10 | self.alpha = alpha
11 |
12 | def update_params(self, model):
13 | state = model.state_dict()
14 | for name in self.param_keys:
15 | self.shadow[name].copy_(self.alpha * self.shadow[name] + (1 - self.alpha) * state[name])
16 |
17 | def apply_shadow(self, model):
18 | model.load_state_dict(self.shadow, strict=True)
19 |
--------------------------------------------------------------------------------
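A tiny numeric check of the EMA update above (assumes utils.ema is importable): after one update_params call, the shadow weights move a (1 - alpha) fraction toward the live weights.

import torch
import torch.nn as nn
from utils.ema import EMA

model = nn.Linear(2, 2, bias=False)
ema = EMA(model, alpha=0.9)
with torch.no_grad():
    model.weight += 1.0                 # perturb the live weights
before = ema.shadow['weight'].clone()
ema.update_params(model)
expected = 0.9 * before + 0.1 * model.weight.detach()
print(torch.allclose(ema.shadow['weight'], expected))  # True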
/scan/utils/memory.py:
--------------------------------------------------------------------------------
1 | """
2 | Authors: Wouter Van Gansbeke, Simon Vandenhende
3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
4 | """
5 | import numpy as np
6 | import torch
7 |
8 |
9 | class MemoryBank(object):
10 | def __init__(self, n, dim, num_classes, temperature, feature_dim=512):
11 | self.n = n
12 | self.dim = dim
13 | self.features = torch.FloatTensor(self.n, self.dim)
14 | self.pre_lasts = torch.FloatTensor(self.n, feature_dim)
15 | self.targets = torch.LongTensor(self.n)
16 | self.ptr = 0
17 | self.device = 'cpu'
18 | self.K = 100
19 | self.temperature = temperature
20 | self.C = num_classes
21 |
22 | def weighted_knn(self, predictions):
23 | # perform weighted knn
24 | retrieval_one_hot = torch.zeros(self.K, self.C).to(self.device)
25 | batchSize = predictions.shape[0]
26 | correlation = torch.matmul(predictions, self.features.t())
27 | yd, yi = correlation.topk(self.K, dim=1, largest=True, sorted=True)
28 | candidates = self.targets.view(1,-1).expand(batchSize, -1)
29 | retrieval = torch.gather(candidates, 1, yi)
30 | retrieval_one_hot.resize_(batchSize * self.K, self.C).zero_()
31 | retrieval_one_hot.scatter_(1, retrieval.view(-1, 1), 1)
32 | yd_transform = yd.clone().div_(self.temperature).exp_()
33 | probs = torch.sum(torch.mul(retrieval_one_hot.view(batchSize, -1 , self.C),
34 | yd_transform.view(batchSize, -1, 1)), 1)
35 | _, class_preds = probs.sort(1, True)
36 | class_pred = class_preds[:, 0]
37 |
38 | return class_pred
39 |
40 | def knn(self, predictions):
41 | # perform knn
42 | correlation = torch.matmul(predictions, self.features.t())
43 | sample_pred = torch.argmax(correlation, dim=1)
44 | class_pred = torch.index_select(self.targets, 0, sample_pred)
45 | return class_pred
46 |
47 | def mine_nearest_neighbors(self, topk, calculate_accuracy=True):
48 | # mine the topk nearest neighbors for every sample
49 | import faiss
50 | features = self.features.cpu().numpy()
51 | n, dim = features.shape[0], features.shape[1]
52 | index = faiss.IndexFlatIP(dim)
53 | index = faiss.index_cpu_to_all_gpus(index)
54 | index.add(features)
55 | distances, indices = index.search(features, topk+1) # Sample itself is included
56 |
57 | # evaluate
58 | if calculate_accuracy:
59 | targets = self.targets.cpu().numpy()
60 | neighbor_targets = np.take(targets, indices[:,1:], axis=0) # Exclude sample itself for eval
61 | anchor_targets = np.repeat(targets.reshape(-1,1), topk, axis=1)
62 | accuracy = np.mean(neighbor_targets == anchor_targets)
63 | return indices, accuracy
64 |
65 | else:
66 | return indices
67 |
68 | def reset(self):
69 | self.ptr = 0
70 |
71 | def update(self, features, pre_last, targets):
72 | b = features.size(0)
73 |
74 | assert(b + self.ptr <= self.n)
75 |
76 | self.features[self.ptr:self.ptr+b].copy_(features.detach())
77 | self.pre_lasts[self.ptr:self.ptr+b].copy_(pre_last.detach())
78 | self.targets[self.ptr:self.ptr+b].copy_(targets.detach())
79 | self.ptr += b
80 |
81 | def to(self, device):
82 | self.features = self.features.to(device)
83 | self.pre_lasts = self.pre_lasts.to(device)
84 | self.targets = self.targets.to(device)
85 | self.device = device
86 |
87 | def cpu(self):
88 | self.to('cpu')
89 |
90 | def cuda(self):
91 | self.to('cuda:0')
92 |
--------------------------------------------------------------------------------
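A CPU-only sketch of the plain kNN classifier above (mine_nearest_neighbors additionally needs faiss; assumes utils.memory is importable): fill the bank with a few L2-normalized features, then classify queries by their most correlated stored feature.

import torch
import torch.nn.functional as F
from utils.memory import MemoryBank

mb = MemoryBank(n=8, dim=16, num_classes=3, temperature=0.1, feature_dim=16)
feats = F.normalize(torch.randn(8, 16), dim=1)
mb.update(feats, feats, torch.randint(0, 3, (8,)))   # features, pre_lasts, targets
queries = F.normalize(torch.randn(2, 16), dim=1)
print(mb.knn(queries))                               # tensor with 2 predicted labels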
/scan/utils/mypath.py:
--------------------------------------------------------------------------------
1 | """
2 | Authors: Wouter Van Gansbeke, Simon Vandenhende
3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
4 | """
5 | import os
6 |
7 |
8 | class MyPath(object):
9 | @staticmethod
10 | def db_root_dir(database=''):
11 | db_names = {'cifar-10', 'stl-10', 'cifar-100', 'imagenet', 'imagenet_50', 'imagenet_100', 'imagenet_200'}
12 | assert(database in db_names)
13 |
14 | if database == 'cifar-10':
15 | return './datasets/cifar-10/'
16 |
17 | elif database == 'cifar-100':
18 | return './datasets/cifar-100/'
19 |
20 | elif database == 'stl-10':
21 | return './datasets/stl-10/'
22 |
23 | elif database in ['imagenet', 'imagenet_50', 'imagenet_100', 'imagenet_200']:
24 | return './datasets/imagenet/'
25 |
26 | else:
27 | raise NotImplementedError
28 |
--------------------------------------------------------------------------------
/scan/utils/train_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Authors: Wouter Van Gansbeke, Simon Vandenhende
3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
4 | """
5 | import torch
6 | import numpy as np
7 | from utils.utils import AverageMeter, ProgressMeter
8 |
9 |
10 | def simclr_train(train_loader, model, criterion, optimizer, epoch):
11 | """
12 | Train according to the scheme from SimCLR
13 | https://arxiv.org/abs/2002.05709
14 | """
15 | losses = AverageMeter('Loss', ':.4e')
16 | progress = ProgressMeter(len(train_loader),
17 | [losses],
18 | prefix="Epoch: [{}]".format(epoch))
19 |
20 | model.train()
21 |
22 | for i, batch in enumerate(train_loader):
23 | images = batch['image']
24 | images_augmented = batch['image_augmented']
25 | b, c, h, w = images.size()
26 | input_ = torch.cat([images.unsqueeze(1), images_augmented.unsqueeze(1)], dim=1)
27 | input_ = input_.view(-1, c, h, w)
28 | input_ = input_.cuda(non_blocking=True)
29 |         targets = batch['target'].cuda(non_blocking=True)  # Unused by the SimCLR objective; labels are not needed here
30 |
31 | output = model(input_).view(b, 2, -1)
32 | loss = criterion(output)
33 | losses.update(loss.item())
34 |
35 | optimizer.zero_grad()
36 | loss.backward()
37 | optimizer.step()
38 |
39 | if i % 25 == 0:
40 | progress.display(i)
41 |
42 |
43 | def scan_train(train_loader, model, criterion, optimizer, epoch, update_cluster_head_only=False):
44 | """
45 | Train w/ SCAN-Loss
46 | """
47 | total_losses = AverageMeter('Total Loss', ':.4e')
48 | consistency_losses = AverageMeter('Consistency Loss', ':.4e')
49 | entropy_losses = AverageMeter('Entropy', ':.4e')
50 | progress = ProgressMeter(len(train_loader),
51 | [total_losses, consistency_losses, entropy_losses],
52 | prefix="Epoch: [{}]".format(epoch))
53 |
54 | if update_cluster_head_only:
55 | model.eval() # No need to update BN
56 | else:
57 | model.train() # Update BN
58 |
59 | for i, batch in enumerate(train_loader):
60 | # Forward pass
61 | anchors = batch['anchor'].cuda(non_blocking=True)
62 | neighbors = batch['neighbor'].cuda(non_blocking=True)
63 |
64 | if update_cluster_head_only: # Only calculate gradient for backprop of linear layer
65 | with torch.no_grad():
66 | anchors_features = model(anchors, forward_pass='backbone')
67 | neighbors_features = model(neighbors, forward_pass='backbone')
68 | anchors_output = model(anchors_features, forward_pass='head')
69 | neighbors_output = model(neighbors_features, forward_pass='head')
70 |
71 | else: # Calculate gradient for backprop of complete network
72 | anchors_output = model(anchors)
73 | neighbors_output = model(neighbors)
74 |
75 | # Loss for every head
76 | total_loss, consistency_loss, entropy_loss = [], [], []
77 | for anchors_output_subhead, neighbors_output_subhead in zip(anchors_output, neighbors_output):
78 | total_loss_, consistency_loss_, entropy_loss_ = criterion(anchors_output_subhead,
79 | neighbors_output_subhead)
80 | total_loss.append(total_loss_)
81 | consistency_loss.append(consistency_loss_)
82 | entropy_loss.append(entropy_loss_)
83 |
84 | # Register the mean loss and backprop the total loss to cover all subheads
85 | total_losses.update(np.mean([v.item() for v in total_loss]))
86 | consistency_losses.update(np.mean([v.item() for v in consistency_loss]))
87 | entropy_losses.update(np.mean([v.item() for v in entropy_loss]))
88 |
89 | total_loss = torch.sum(torch.stack(total_loss, dim=0))
90 |
91 | optimizer.zero_grad()
92 | total_loss.backward()
93 | optimizer.step()
94 |
95 | if i % 25 == 0:
96 | progress.display(i)
97 |
98 |
99 | def selflabel_train(train_loader, model, criterion, optimizer, epoch, ema=None):
100 | """
101 | Self-labeling based on confident samples
102 | """
103 | losses = AverageMeter('Loss', ':.4e')
104 | progress = ProgressMeter(len(train_loader), [losses],
105 | prefix="Epoch: [{}]".format(epoch))
106 | model.train()
107 |
108 | for i, batch in enumerate(train_loader):
109 | images = batch['image'].cuda(non_blocking=True)
110 | images_augmented = batch['image_augmented'].cuda(non_blocking=True)
111 |
112 |         with torch.no_grad():
113 |             output = model(images)[0]  # Pseudo-label source: computed without gradients
114 |         output_augmented = model(images_augmented)[0]  # Needs gradients for the backward pass below
115 |
116 | loss = criterion(output, output_augmented)
117 | losses.update(loss.item())
118 |
119 | optimizer.zero_grad()
120 | loss.backward()
121 | optimizer.step()
122 |
123 | if ema is not None: # Apply EMA to update the weights of the network
124 | ema.update_params(model)
125 | ema.apply_shadow(model)
126 |
127 | if i % 25 == 0:
128 | progress.display(i)
129 |
--------------------------------------------------------------------------------
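A shape-only sketch of the two-view batching used by simclr_train above: the standard and augmented views are interleaved into one forward pass, then the output is reshaped back to [b, 2, dim] for the SimCLR criterion (the flatten-and-slice line is a stand-in for the model; any [2b, dim] output works).

import torch

b, c, h, w = 4, 3, 32, 32
images, images_augmented = torch.randn(b, c, h, w), torch.randn(b, c, h, w)
input_ = torch.cat([images.unsqueeze(1), images_augmented.unsqueeze(1)], dim=1)
input_ = input_.view(-1, c, h, w)        # [2b, c, h, w]; each image stays adjacent to its augmentation
output = input_.flatten(1)[:, :128]      # stand-in for model(input_)
print(output.view(b, 2, -1).shape)       # torch.Size([4, 2, 128])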
/scan/utils/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Authors: Wouter Van Gansbeke, Simon Vandenhende
3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
4 | """
5 | import os
6 | import torch
7 | import numpy as np
8 | import errno
9 |
10 | def mkdir_if_missing(directory):
11 | if not os.path.exists(directory):
12 | try:
13 | os.makedirs(directory)
14 | except OSError as e:
15 | if e.errno != errno.EEXIST:
16 | raise
17 |
18 |
19 | class AverageMeter(object):
20 | def __init__(self, name, fmt=':f'):
21 | self.name = name
22 | self.fmt = fmt
23 | self.reset()
24 |
25 | def reset(self):
26 | self.val = 0
27 | self.avg = 0
28 | self.sum = 0
29 | self.count = 0
30 |
31 | def update(self, val, n=1):
32 | self.val = val
33 | self.sum += val * n
34 | self.count += n
35 | self.avg = self.sum / self.count
36 |
37 | def __str__(self):
38 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
39 | return fmtstr.format(**self.__dict__)
40 |
41 |
42 | class ProgressMeter(object):
43 | def __init__(self, num_batches, meters, prefix=""):
44 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
45 | self.meters = meters
46 | self.prefix = prefix
47 |
48 | def display(self, batch):
49 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
50 | entries += [str(meter) for meter in self.meters]
51 | print('\t'.join(entries))
52 |
53 | def _get_batch_fmtstr(self, num_batches):
54 |         num_digits = len(str(num_batches))
55 | fmt = '{:' + str(num_digits) + 'd}'
56 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
57 |
58 |
59 | @torch.no_grad()
60 | def fill_memory_bank(loader, model, memory_bank):
61 | model.eval()
62 | memory_bank.reset()
63 |
64 | for i, batch in enumerate(loader):
65 | images = batch['image'].cuda(non_blocking=True)
66 | targets = batch['target'].cuda(non_blocking=True)
67 | output, pre_last = model(images, return_pre_last=True)
68 | memory_bank.update(output, pre_last, targets)
69 | if i % 100 == 0:
70 | print('Fill Memory Bank [%d/%d]' %(i, len(loader)))
71 |
72 |
73 | def confusion_matrix(predictions, gt, class_names, output_file=None):
74 | # Plot confusion_matrix and store result to output_file
75 | import sklearn.metrics
76 | import matplotlib.pyplot as plt
77 | confusion_matrix = sklearn.metrics.confusion_matrix(gt, predictions)
78 |     confusion_matrix = confusion_matrix / np.sum(confusion_matrix, 1, keepdims=True)  # Row-normalize so each row sums to 1
79 |
80 | fig, axes = plt.subplots(1)
81 | plt.imshow(confusion_matrix, cmap='Blues')
82 | axes.set_xticks([i for i in range(len(class_names))])
83 | axes.set_yticks([i for i in range(len(class_names))])
84 | axes.set_xticklabels(class_names, ha='right', fontsize=8, rotation=40)
85 | axes.set_yticklabels(class_names, ha='right', fontsize=8)
86 |
87 | for (i, j), z in np.ndenumerate(confusion_matrix):
88 | if i == j:
89 | axes.text(j, i, '%d' %(100*z), ha='center', va='center', color='white', fontsize=6)
90 | else:
91 | pass
92 |
93 | plt.tight_layout()
94 | if output_file is None:
95 | plt.show()
96 | else:
97 | plt.savefig(output_file, dpi=300, bbox_inches='tight')
98 | plt.close()
99 |
--------------------------------------------------------------------------------