├── .gitignore ├── LICENSE ├── README.md ├── configs ├── cifar10 │ ├── al │ │ ├── RESNET18.yaml │ │ ├── RESNET18_ENS.yaml │ │ └── RESNET18_IM.yaml │ ├── evaluate │ │ └── RESNET18.yaml │ └── train │ │ ├── RESNET18.yaml │ │ └── RESNET18_ENS.yaml ├── cifar100 │ └── al │ │ └── RESNET18.yaml ├── mnist │ └── al │ │ └── RESNET18.yaml ├── template.yaml └── tinyimagenet │ └── al │ └── RESNET18.yaml ├── docs ├── AL_results.png └── GETTING_STARTED.md ├── output ├── CIFAR10 │ └── resnet18 │ │ └── ENT_1 │ │ ├── Episodes_vs_Test Accuracy.png │ │ ├── episode_0 │ │ ├── Epochs_vs_Loss.png │ │ ├── Epochs_vs_Validation Accuracy.png │ │ ├── Iterations_vs_Loss.png │ │ ├── activeSet.npy │ │ ├── lSet.npy │ │ ├── plot_epoch_xvalues.txt │ │ ├── plot_epoch_yvalues.txt │ │ ├── plot_it_x_values.txt │ │ ├── plot_it_y_values.txt │ │ ├── uSet.npy │ │ ├── val_acc_epochs_x.txt │ │ └── val_acc_epochs_y.txt │ │ ├── episode_1 │ │ ├── Epochs_vs_Loss.png │ │ ├── Epochs_vs_Validation Accuracy.png │ │ ├── Iterations_vs_Loss.png │ │ ├── activeSet.npy │ │ ├── lSet.npy │ │ ├── plot_epoch_xvalues.txt │ │ ├── plot_epoch_yvalues.txt │ │ ├── plot_it_x_values.txt │ │ ├── plot_it_y_values.txt │ │ ├── uSet.npy │ │ ├── val_acc_epochs_x.txt │ │ └── val_acc_epochs_y.txt │ │ ├── episode_2 │ │ ├── Epochs_vs_Loss.png │ │ ├── Epochs_vs_Validation Accuracy.png │ │ ├── Iterations_vs_Loss.png │ │ ├── activeSet.npy │ │ ├── lSet.npy │ │ ├── plot_epoch_xvalues.txt │ │ ├── plot_epoch_yvalues.txt │ │ ├── plot_it_x_values.txt │ │ ├── plot_it_y_values.txt │ │ ├── uSet.npy │ │ ├── val_acc_epochs_x.txt │ │ └── val_acc_epochs_y.txt │ │ ├── episode_3 │ │ ├── Epochs_vs_Loss.png │ │ ├── Epochs_vs_Validation Accuracy.png │ │ ├── Iterations_vs_Loss.png │ │ ├── activeSet.npy │ │ ├── lSet.npy │ │ ├── plot_epoch_xvalues.txt │ │ ├── plot_epoch_yvalues.txt │ │ ├── plot_it_x_values.txt │ │ ├── plot_it_y_values.txt │ │ ├── uSet.npy │ │ ├── val_acc_epochs_x.txt │ │ └── val_acc_epochs_y.txt │ │ ├── episode_4 │ │ ├── Epochs_vs_Loss.png │ │ ├── 
Epochs_vs_Validation Accuracy.png │ │ ├── Iterations_vs_Loss.png │ │ ├── activeSet.npy │ │ ├── lSet.npy │ │ ├── plot_epoch_xvalues.txt │ │ ├── plot_epoch_yvalues.txt │ │ ├── plot_it_x_values.txt │ │ ├── plot_it_y_values.txt │ │ ├── uSet.npy │ │ ├── val_acc_epochs_x.txt │ │ └── val_acc_epochs_y.txt │ │ ├── episode_5 │ │ ├── Epochs_vs_Loss.png │ │ ├── Epochs_vs_Validation Accuracy.png │ │ ├── Iterations_vs_Loss.png │ │ ├── plot_epoch_xvalues.txt │ │ ├── plot_epoch_yvalues.txt │ │ ├── plot_it_x_values.txt │ │ ├── plot_it_y_values.txt │ │ ├── val_acc_epochs_x.txt │ │ └── val_acc_epochs_y.txt │ │ ├── lSet.npy │ │ ├── plot_episode_xvalues.txt │ │ ├── plot_episode_yvalues.txt │ │ ├── stdout.log │ │ ├── uSet.npy │ │ └── valSet.npy └── results_aggregator.ipynb ├── pycls ├── __init__.py ├── al │ ├── ActiveLearning.py │ ├── Sampling.py │ ├── __init__.py │ └── vaal_util.py ├── core │ ├── __init__.py │ ├── builders.py │ ├── config.py │ ├── losses.py │ ├── net.py │ └── optimizer.py ├── datasets │ ├── __init__.py │ ├── augment.py │ ├── custom_datasets.py │ ├── data.py │ ├── imbalanced_cifar.py │ ├── randaugment.py │ ├── sampler.py │ ├── simclr_augment.py │ ├── tiny_imagenet.py │ └── utils │ │ ├── __init__.py │ │ ├── gaussian_blur.py │ │ └── helpers.py ├── models │ ├── __init__.py │ ├── alexnet.py │ ├── resnet.py │ ├── vaal_model.py │ └── vgg.py └── utils │ ├── __init__.py │ ├── benchmark.py │ ├── checkpoint.py │ ├── distributed.py │ ├── io.py │ ├── logging.py │ ├── meters.py │ ├── metrics.py │ ├── net.py │ ├── plotting.py │ └── timer.py ├── requirements.txt └── tools ├── __init__.py ├── ensemble_al.py ├── ensemble_train.py ├── test_model.py ├── train.py └── train_al.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Shared objects 7 | *.so 8 | 9 | # Distribution / packaging 10 | build/ 11 | *.egg-info/ 12 | *.egg 13 | 14 | # Temporary 
files 15 | *.swn 16 | *.swo 17 | *.swp 18 | 19 | # PyCharm 20 | .idea/ 21 | 22 | # Mac 23 | .DS_STORE 24 | 25 | # Data symlinks 26 | pycls/datasets/data/ 27 | 28 | # Other 29 | logs/ 30 | scratch* 31 | 32 | data/ 33 | output/ 34 | *.ipynb_checkpoints/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright 2021 (c) Akshay L Chandra and Vineeth N Balasubramanian 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Active Learning Toolkit for Image Classification in PyTorch 2 | 3 | This is a code base for deep active learning for image classification written in [PyTorch](https://pytorch.org/). 
It is built on top of FAIR's [pycls](https://github.com/facebookresearch/pycls/). I want to emphasize that this is a derivative of the toolkit originally shared with me via email by Prateek Munjal _et al._, the authors of the paper _"Towards Robust and Reproducible Active Learning using Neural Networks"_, paper available [here](https://arxiv.org/abs/2002.09564). 4 | 5 | ## Introduction 6 | 7 | The goal of this repository is to provide a simple and flexible codebase for deep active learning. It is designed to support rapid implementation and evaluation of research ideas. We also provide results on CIFAR10 and CIFAR100 below. 8 | 9 | The codebase currently only supports single-machine single-GPU training. We will soon scale it to single-machine multi-GPU training, powered by the PyTorch distributed package. 10 | 11 | ## Using the Toolkit 12 | 13 | Please see [`GETTING_STARTED`](docs/GETTING_STARTED.md) for brief instructions on installation, adding new datasets, basic usage examples, etc. 14 | 15 | ## Active Learning Methods Supported 16 | * Uncertainty Sampling 17 | * Least Confidence 18 | * Min-Margin 19 | * Max-Entropy 20 | * Deep Bayesian Active Learning (DBAL) [1] 21 | * Bayesian Active Learning by Disagreement (BALD) [1] 22 | * Diversity Sampling 23 | * Coreset (greedy) [2] 24 | * Variational Adversarial Active Learning (VAAL) [3] 25 | * Query-by-Committee Sampling 26 | * Ensemble Variation Ratio (Ens-varR) [4] 27 | 28 | 29 | ## Datasets Supported 30 | * [CIFAR10/100](https://www.cs.toronto.edu/~kriz/cifar.html) 31 | * [MNIST](http://yann.lecun.com/exdb/mnist/) 32 | * [SVHN](http://ufldl.stanford.edu/housenumbers/) 33 | * [Tiny ImageNet](https://www.kaggle.com/c/tiny-imagenet) (Download the zip file [here](http://cs231n.stanford.edu/tiny-imagenet-200.zip)) 34 | * Long-Tail CIFAR-10/100 35 | 36 | Follow the instructions in [`GETTING_STARTED`](docs/GETTING_STARTED.md) to add a new dataset.
37 | 38 | ## Results on CIFAR10 and CIFAR100 39 | 40 | The following are the results on CIFAR10 and CIFAR100, trained with hyperparameters present in `configs/cifar10/al/RESNET18.yaml` and `configs/cifar100/al/RESNET18.yaml` respectively. All results were averaged over 3 runs. 41 | 42 | 43 | 44 |
45 | 46 |
47 | 48 | ### CIFAR10 at 60% 49 | ``` 50 | | AL Method | Test Accuracy | 51 | |:----------------:|:---------------------------:| 52 | | DBAL | 91.670000 +- 0.230651 | 53 | | Least Confidence | 91.510000 +- 0.087178 | 54 | | BALD | 91.470000 +- 0.293087 | 55 | | Coreset | 91.433333 +- 0.090738 | 56 | | Max-Entropy | 91.373333 +- 0.363639 | 57 | | Min-Margin | 91.333333 +- 0.234592 | 58 | | Ensemble-varR | 89.866667 +- 0.127410 | 59 | | Random | 89.803333 +- 0.230290 | 60 | | VAAL | 89.690000 +- 0.115326 | 61 | ``` 62 | 63 | ### CIFAR100 at 60% 64 | ``` 65 | | AL Method | Test Accuracy | 66 | |:----------------:|:---------------------------:| 67 | | DBAL | 55.400000 +- 1.037931 | 68 | | Coreset | 55.333333 +- 0.773714 | 69 | | Max-Entropy | 55.226667 +- 0.536128 | 70 | | BALD | 55.186667 +- 0.369639 | 71 | | Least Confidence | 55.003333 +- 0.937248 | 72 | | Min-Margin | 54.543333 +- 0.611583 | 73 | | Ensemble-varR | 54.186667 +- 0.325628 | 74 | | VAAL | 53.943333 +- 0.680686 | 75 | | Random | 53.546667 +- 0.302875 | 76 | ``` 77 | 78 | ## Citing this Repository 79 | 80 | If you find this repo helpful in your research, please consider citing us and the owners of the original toolkit: 81 | 82 | ``` 83 | @article{Chandra2021DeepAL, 84 | Author = {Akshay L Chandra and Vineeth N Balasubramanian}, 85 | Title = {Deep Active Learning Toolkit for Image Classification in PyTorch}, 86 | Journal = {https://github.com/acl21/deep-active-learning-pytorch}, 87 | Year = {2021} 88 | } 89 | 90 | @article{Munjal2020TowardsRA, 91 | title={Towards Robust and Reproducible Active Learning Using Neural Networks}, 92 | author={Prateek Munjal and N. Hayat and Munawar Hayat and J. Sourati and S. Khan}, 93 | journal={ArXiv}, 94 | year={2020}, 95 | volume={abs/2002.09564} 96 | } 97 | ``` 98 | 99 | ## License 100 | 101 | This toolkit is released under the MIT license. Please see the [LICENSE](LICENSE) file for more information. 
102 | 103 | ## References 104 | 105 | [1] Yarin Gal, Riashat Islam, and Zoubin Ghahramani. Deep bayesian active learning with image data. In Proceedings of the 34th International Conference on Machine Learning-Volume 70, pages 1183–1192. JMLR. org, 2017. 106 | 107 | [2] Ozan Sener and Silvio Savarese. Active learning for convolutional neural networks: A core-set approach. In International Conference on Learning Representations, 2018. 108 | 109 | [3] Sinha, Samarth et al. Variational Adversarial Active Learning. 2019 IEEE/CVF International Conference on Computer Vision (ICCV) (2019): 5971-5980. 110 | 111 | [4] William H. Beluch, Tim Genewein, Andreas Nürnberger, and Jan M. Köhler. The power of ensembles for active learning in image classification. 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 9368–9377, 2018. -------------------------------------------------------------------------------- /configs/cifar10/al/RESNET18.yaml: -------------------------------------------------------------------------------- 1 | # EXP_NAME: 'YOUR_EXPERIMENT_NAME' 2 | RNG_SEED: 1 3 | DATASET: 4 | NAME: CIFAR10 5 | ROOT_DIR: './data' 6 | VAL_RATIO: 0.1 7 | AUG_METHOD: 'hflip' 8 | MODEL: 9 | TYPE: resnet18 10 | NUM_CLASSES: 10 11 | OPTIM: 12 | TYPE: 'sgd' 13 | BASE_LR: 0.025 14 | LR_POLICY: cos 15 | LR_MULT: 0.1 16 | # STEPS: [0, 60, 120, 160, 200] 17 | MAX_EPOCH: 200 18 | MOMENTUM: 0.9 19 | NESTEROV: True 20 | WEIGHT_DECAY: 0.0003 21 | GAMMA: 0.1 22 | TRAIN: 23 | SPLIT: train 24 | BATCH_SIZE: 96 25 | IM_SIZE: 32 26 | EVAL_PERIOD: 2 27 | TEST: 28 | SPLIT: test 29 | BATCH_SIZE: 200 30 | IM_SIZE: 32 31 | DATA_LOADER: 32 | NUM_WORKERS: 4 33 | CUDNN: 34 | BENCHMARK: True 35 | ACTIVE_LEARNING: 36 | INIT_L_RATIO: 0.1 37 | BUDGET_SIZE: 5000 38 | # SAMPLING_FN: 'uncertainty' 39 | MAX_ITER: 5 -------------------------------------------------------------------------------- /configs/cifar10/al/RESNET18_ENS.yaml: 
-------------------------------------------------------------------------------- 1 | # EXP_NAME: 'YOUR_EXPERIMENT_NAME' 2 | RNG_SEED: 1 3 | DATASET: 4 | NAME: CIFAR10 5 | ROOT_DIR: './data' 6 | VAL_RATIO: 0.1 7 | AUG_METHOD: 'hflip' 8 | MODEL: 9 | TYPE: resnet18 10 | NUM_CLASSES: 10 11 | OPTIM: 12 | TYPE: 'sgd' 13 | BASE_LR: 0.025 14 | LR_POLICY: cos 15 | LR_MULT: 0.1 16 | # STEPS: [0, 60, 120, 160, 200] 17 | MAX_EPOCH: 200 18 | MOMENTUM: 0.9 19 | NESTEROV: True 20 | WEIGHT_DECAY: 0.0003 21 | GAMMA: 0.1 22 | TRAIN: 23 | SPLIT: train 24 | BATCH_SIZE: 96 25 | IM_SIZE: 32 26 | EVAL_PERIOD: 2 27 | TEST: 28 | SPLIT: test 29 | BATCH_SIZE: 200 30 | IM_SIZE: 32 31 | DATA_LOADER: 32 | NUM_WORKERS: 4 33 | CUDNN: 34 | BENCHMARK: True 35 | ACTIVE_LEARNING: 36 | BUDGET_SIZE: 5000 37 | SAMPLING_FN: 'ensemble_var_R' # do not change this for ensemble learning 38 | MAX_ITER: 5 39 | ENSEMBLE: 40 | NUM_MODELS: 3 41 | MODEL_TYPE: ['resnet18'] -------------------------------------------------------------------------------- /configs/cifar10/al/RESNET18_IM.yaml: -------------------------------------------------------------------------------- 1 | # EXP_NAME: 'YOUR_EXPERIMENT_NAME' 2 | RNG_SEED: 1 3 | DATASET: 4 | NAME: IMBALANCED_CIFAR10 5 | ROOT_DIR: './data' 6 | VAL_RATIO: 0.1 7 | AUG_METHOD: 'hflip' 8 | MODEL: 9 | TYPE: resnet18 10 | NUM_CLASSES: 10 11 | OPTIM: 12 | TYPE: 'sgd' 13 | BASE_LR: 0.025 14 | LR_POLICY: cos 15 | LR_MULT: 0.1 16 | # STEPS: [0, 60, 120, 160, 200] 17 | MAX_EPOCH: 200 18 | MOMENTUM: 0.9 19 | NESTEROV: True 20 | WEIGHT_DECAY: 0.0003 21 | GAMMA: 0.1 22 | TRAIN: 23 | SPLIT: train 24 | BATCH_SIZE: 96 25 | IM_SIZE: 32 26 | EVAL_PERIOD: 2 27 | TEST: 28 | SPLIT: test 29 | BATCH_SIZE: 200 30 | IM_SIZE: 32 31 | DATA_LOADER: 32 | NUM_WORKERS: 4 33 | CUDNN: 34 | BENCHMARK: True 35 | ACTIVE_LEARNING: 36 | INIT_L_RATIO: 0.1 37 | BUDGET_SIZE: 1270 38 | # SAMPLING_FN: 'uncertainty' 39 | MAX_ITER: 5
-------------------------------------------------------------------------------- /configs/cifar10/evaluate/RESNET18.yaml: -------------------------------------------------------------------------------- 1 | # EXP_NAME: 'SOME_RANDOM_NAME' 2 | RNG_SEED: 1 3 | GPU_ID: '3' 4 | DATASET: 5 | NAME: CIFAR10 6 | ROOT_DIR: 'data' 7 | MODEL: 8 | TYPE: resnet18 9 | NUM_CLASSES: 10 10 | TEST: 11 | SPLIT: test 12 | BATCH_SIZE: 200 13 | IM_SIZE: 32 14 | MODEL_PATH: 'path/to/saved/model' 15 | DATA_LOADER: 16 | NUM_WORKERS: 4 17 | CUDNN: 18 | BENCHMARK: True -------------------------------------------------------------------------------- /configs/cifar10/train/RESNET18.yaml: -------------------------------------------------------------------------------- 1 | # EXP_NAME: 'YOUR_EXPERIMENT_NAME' 2 | RNG_SEED: 1 3 | DATASET: 4 | NAME: CIFAR10 5 | ROOT_DIR: './data' 6 | VAL_RATIO: 0.1 7 | AUG_METHOD: 'hflip' 8 | MODEL: 9 | TYPE: resnet18 10 | NUM_CLASSES: 10 11 | OPTIM: 12 | TYPE: 'sgd' 13 | BASE_LR: 0.025 14 | LR_POLICY: cos 15 | LR_MULT: 0.1 16 | # STEPS: [0, 60, 120, 160, 200] 17 | MAX_EPOCH: 200 18 | MOMENTUM: 0.9 19 | NESTEROV: True 20 | WEIGHT_DECAY: 0.0003 21 | GAMMA: 0.1 22 | TRAIN: 23 | SPLIT: train 24 | BATCH_SIZE: 1 25 | IM_SIZE: 32 26 | EVAL_PERIOD: 2 27 | TEST: 28 | SPLIT: test 29 | BATCH_SIZE: 200 30 | IM_SIZE: 32 31 | DATA_LOADER: 32 | NUM_WORKERS: 4 33 | CUDNN: 34 | BENCHMARK: True -------------------------------------------------------------------------------- /configs/cifar10/train/RESNET18_ENS.yaml: -------------------------------------------------------------------------------- 1 | # EXP_NAME: 'SOME' 2 | RNG_SEED: 1 3 | DATASET: 4 | NAME: CIFAR10 5 | ROOT_DIR: './data' 6 | VAL_RATIO: 0.1 7 | AUG_METHOD: 'hflip' 8 | MODEL: 9 | TYPE: resnet18 10 | NUM_CLASSES: 10 11 | OPTIM: 12 | TYPE: 'sgd' 13 | BASE_LR: 0.025 14 | LR_POLICY: cos 15 | LR_MULT: 0.1 16 | # STEPS: [0, 60, 120, 160, 200] 17 | MAX_EPOCH: 200 18 | MOMENTUM: 0.9 19 | NESTEROV: True 20 | WEIGHT_DECAY: 0.0003 
21 | GAMMA: 0.1 22 | TRAIN: 23 | SPLIT: train 24 | BATCH_SIZE: 96 25 | IM_SIZE: 32 26 | EVAL_PERIOD: 2 27 | TEST: 28 | SPLIT: test 29 | BATCH_SIZE: 200 30 | IM_SIZE: 32 31 | DATA_LOADER: 32 | NUM_WORKERS: 4 33 | CUDNN: 34 | BENCHMARK: True 35 | ENSEMBLE: 36 | NUM_MODELS: 3 37 | MODEL_TYPE: ['resnet18'] -------------------------------------------------------------------------------- /configs/cifar100/al/RESNET18.yaml: -------------------------------------------------------------------------------- 1 | # EXP_NAME: 'YOUR_EXPERIMENT_NAME' 2 | RNG_SEED: 1 3 | DATASET: 4 | NAME: CIFAR100 5 | ROOT_DIR: './data' 6 | VAL_RATIO: 0.1 7 | AUG_METHOD: 'hflip' 8 | MODEL: 9 | TYPE: resnet18 10 | NUM_CLASSES: 100 11 | OPTIM: 12 | TYPE: 'sgd' 13 | BASE_LR: 0.025 14 | LR_POLICY: cos 15 | LR_MULT: 0.1 16 | # STEPS: [0, 60, 120, 160, 200] 17 | MAX_EPOCH: 200 18 | MOMENTUM: 0.9 19 | NESTEROV: True 20 | WEIGHT_DECAY: 0.0003 21 | GAMMA: 0.1 22 | TRAIN: 23 | SPLIT: train 24 | BATCH_SIZE: 96 25 | IM_SIZE: 32 26 | EVAL_PERIOD: 2 27 | TEST: 28 | SPLIT: test 29 | BATCH_SIZE: 200 30 | IM_SIZE: 32 31 | DATA_LOADER: 32 | NUM_WORKERS: 4 33 | CUDNN: 34 | BENCHMARK: True 35 | ACTIVE_LEARNING: 36 | INIT_L_RATIO: 0.1 37 | BUDGET_SIZE: 5000 38 | # SAMPLING_FN: 'uncertainty' 39 | MAX_ITER: 5 -------------------------------------------------------------------------------- /configs/mnist/al/RESNET18.yaml: -------------------------------------------------------------------------------- 1 | # EXP_NAME: 'YOUR_EXPERIMENT_NAME' 2 | RNG_SEED: 1 3 | DATASET: 4 | NAME: MNIST 5 | ROOT_DIR: './data' 6 | VAL_RATIO: 0.1 7 | AUG_METHOD: 'hflip' 8 | MODEL: 9 | TYPE: resnet18 10 | NUM_CLASSES: 10 11 | OPTIM: 12 | TYPE: 'adam' 13 | BASE_LR: 0.005 14 | LR_POLICY: none 15 | LR_MULT: 0.1 16 | # STEPS: [0, 60, 120, 160, 200] 17 | MAX_EPOCH: 100 18 | MOMENTUM: 0.9 19 | NESTEROV: True 20 | WEIGHT_DECAY: 0.0003 21 | GAMMA: 0.1 22 | TRAIN: 23 | SPLIT: train 24 | BATCH_SIZE: 1 25 | IM_SIZE: 32 26 | EVAL_PERIOD: 2 27 | TEST: 28 | 
SPLIT: test 29 | BATCH_SIZE: 200 30 | IM_SIZE: 32 31 | DATA_LOADER: 32 | NUM_WORKERS: 4 33 | CUDNN: 34 | BENCHMARK: True 35 | ACTIVE_LEARNING: 36 | INIT_L_RATIO: 0.001 37 | BUDGET_SIZE: 60 38 | # SAMPLING_FN: 'uncertainty' 39 | MAX_ITER: 5 -------------------------------------------------------------------------------- /configs/template.yaml: -------------------------------------------------------------------------------- 1 | # Folder name where best model logs etc are saved. "auto" creates a timestamp based folder 2 | EXP_NAME: 'SOME_RANDOM_NAME' 3 | # Note that non-determinism may still be present due to non-deterministic 4 | # operator implementations in GPU operator libraries 5 | RNG_SEED: 1 6 | # GPU ID you want to execute the process on 7 | GPU_ID: '3' 8 | DATASET: 9 | NAME: CIFAR10 # or CIFAR100, MNIST, SVHN, TinyImageNet 10 | ROOT_DIR: 'data' # Relative path where data should be downloaded 11 | # Specifies the proportion of data in train set that should be considered as the validation data 12 | VAL_RATIO: 0.1 13 | # Data augmentation methods - 'simclr', 'randaug', 'horizontalflip' 14 | AUG_METHOD: 'horizontalflip' 15 | MODEL: 16 | # Model type. 
17 | # Choose from vgg style ['vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19_bn', 'vgg19',] 18 | # or from resnet style ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 19 | # 'wide_resnet50_2', 'wide_resnet101_2'] 20 | TYPE: resnet18 21 | NUM_CLASSES: 10 22 | OPTIM: 23 | TYPE: 'sgd' # or 'adam' 24 | BASE_LR: 0.1 25 | # Learning rate policy select from {'cos', 'exp', 'steps', 'none'} 26 | LR_POLICY: steps 27 | # Steps for 'steps' policy (in epochs) 28 | STEPS: [0] #[0, 30, 60, 90] 29 | # Training Epochs 30 | MAX_EPOCH: 1 31 | # Momentum 32 | MOMENTUM: 0.9 33 | # Nesterov Momentum 34 | NESTEROV: False 35 | # L2 regularization 36 | WEIGHT_DECAY: 0.0005 37 | # Exponential decay factor 38 | GAMMA: 0.1 39 | TRAIN: 40 | SPLIT: train 41 | # Training mini-batch size 42 | BATCH_SIZE: 256 43 | # Image size 44 | IM_SIZE: 32 45 | IM_CHANNELS: 3 46 | # Evaluate model on test data every eval period epochs 47 | EVAL_PERIOD: 1 48 | TEST: 49 | SPLIT: test 50 | # Testing mini-batch size 51 | BATCH_SIZE: 200 52 | # Image size 53 | IM_SIZE: 32 54 | # Saved model to use for testing (useful when running tools/test_model.py) 55 | MODEL_PATH: '' 56 | DATA_LOADER: 57 | NUM_WORKERS: 4 58 | CUDNN: 59 | BENCHMARK: True 60 | ACTIVE_LEARNING: 61 | # Active sampling budget (at each episode) 62 | BUDGET_SIZE: 5000 63 | # Active sampling method 64 | SAMPLING_FN: 'dbal' # 'random', 'uncertainty', 'entropy', 'margin', 'bald', 'vaal', 'coreset', 'ensemble_var_R' 65 | # Initial labeled pool ratio (% of total train set that should be labeled before AL begins) 66 | INIT_L_RATIO: 0.1 67 | # Max AL episodes 68 | MAX_ITER: 1 69 | DROPOUT_ITERATIONS: 10 # Used by DBAL 70 | # Useful when running `tools/ensemble_al.py` or `tools/ensemble_train.py` 71 | ENSEMBLE: 72 | NUM_MODELS: 3 73 | MODEL_TYPE: ['resnet18'] -------------------------------------------------------------------------------- /configs/tinyimagenet/al/RESNET18.yaml:
-------------------------------------------------------------------------------- 1 | # EXP_NAME: 'YOUR_EXPERIMENT_NAME' 2 | RNG_SEED: 1 3 | DATASET: 4 | NAME: TINYIMAGENET 5 | ROOT_DIR: '/path/to/dataset/directory/' 6 | VAL_RATIO: 0.05 7 | AUG_METHOD: 'hflip' 8 | MODEL: 9 | TYPE: resnet18 10 | NUM_CLASSES: 200 11 | OPTIM: 12 | TYPE: 'adam' 13 | BASE_LR: 0.001 14 | LR_POLICY: none 15 | LR_MULT: 0.1 16 | # STEPS: [0, 60, 120, 160, 200] 17 | MAX_EPOCH: 100 18 | MOMENTUM: 0.9 19 | NESTEROV: True 20 | WEIGHT_DECAY: 0.0003 21 | GAMMA: 0.1 22 | TRAIN: 23 | SPLIT: train 24 | BATCH_SIZE: 1 25 | IM_SIZE: 64 26 | EVAL_PERIOD: 2 27 | TEST: 28 | SPLIT: test 29 | BATCH_SIZE: 200 30 | IM_SIZE: 64 31 | DATA_LOADER: 32 | NUM_WORKERS: 4 33 | CUDNN: 34 | BENCHMARK: True 35 | ACTIVE_LEARNING: 36 | INIT_L_RATIO: 0.1 37 | BUDGET_SIZE: 5000 38 | # SAMPLING_FN: 'uncertainty' 39 | MAX_ITER: 5 -------------------------------------------------------------------------------- /docs/AL_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/docs/AL_results.png -------------------------------------------------------------------------------- /docs/GETTING_STARTED.md: -------------------------------------------------------------------------------- 1 | # Getting Started 2 | 3 | ## Environment Setup 4 | 5 | Clone the repository: 6 | 7 | ``` 8 | git clone https://github.com/acl21/deep-active-learning-pytorch 9 | ``` 10 | 11 | Install dependencies: 12 | 13 | ``` 14 | pip install -r requirements.txt 15 | ``` 16 | 17 | ## Understanding the Config File 18 | ``` 19 | # Folder name where best model logs etc are saved. 
Setting EXP_NAME: "auto" creates a timestamp named folder 20 | EXP_NAME: 'YOUR_EXPERIMENT_NAME' 21 | # Note that non-determinism may still be present due to non-deterministic 22 | # operator implementations in GPU operator libraries 23 | RNG_SEED: 1 24 | # GPU ID you want to execute the process on (this option currently has no effect; select the GPU with CUDA_VISIBLE_DEVICES as shown in the commands below instead) 25 | GPU_ID: '3' 26 | DATASET: 27 | NAME: CIFAR10 # or CIFAR100, MNIST, SVHN, TINYIMAGENET, IMBALANCED_CIFAR10/100 28 | ROOT_DIR: 'data' # Relative path where data should be downloaded 29 | # Specifies the proportion of data in train set that should be considered as the validation data 30 | VAL_RATIO: 0.1 31 | # Data augmentation methods - 'simclr', 'randaug', 'hflip' 32 | AUG_METHOD: 'hflip' 33 | MODEL: 34 | # Model type. 35 | # Choose from vgg style ['vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19_bn', 'vgg19',] 36 | # or from resnet style ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 37 | # 'wide_resnet50_2', 'wide_resnet101_2'] 38 | # or `alexnet` 39 | TYPE: resnet18 40 | NUM_CLASSES: 10 41 | OPTIM: 42 | TYPE: 'sgd' # or 'adam' 43 | BASE_LR: 0.025 44 | # Learning rate policy select from {'cos', 'exp', 'steps' or 'none'} 45 | LR_POLICY: cos 46 | # Steps for 'steps' policy (in epochs) 47 | STEPS: [0] #[0, 30, 60, 90] 48 | # Training Epochs 49 | MAX_EPOCH: 1 50 | # Momentum 51 | MOMENTUM: 0.9 52 | # Nesterov Momentum 53 | NESTEROV: False 54 | # L2 regularization 55 | WEIGHT_DECAY: 0.0005 56 | # Exponential decay factor 57 | GAMMA: 0.1 58 | TRAIN: 59 | SPLIT: train 60 | # Training mini-batch size 61 | BATCH_SIZE: 96 62 | # Image size 63 | IM_SIZE: 32 64 | IM_CHANNELS: 3 65 | # Evaluate model on test data every eval period epochs 66 | EVAL_PERIOD: 2 67 | TEST: 68 | SPLIT: test 69 | # Testing mini-batch size 70 | BATCH_SIZE: 200 71 | # Image size 72 | IM_SIZE: 32 73 | # Saved model to use for testing (useful
when running `tools/test_model.py`) 74 | MODEL_PATH: '' 75 | DATA_LOADER: 76 | NUM_WORKERS: 4 77 | CUDNN: 78 | BENCHMARK: True 79 | ACTIVE_LEARNING: 80 | # Active sampling budget (at each episode) 81 | BUDGET_SIZE: 5000 82 | # Active sampling method 83 | SAMPLING_FN: 'dbal' # 'random', 'uncertainty', 'entropy', 'margin', 'bald', 'vaal', 'coreset', 'ensemble_var_R' 84 | # Initial labeled pool ratio (% of total train set that should be labeled before AL begins) 85 | INIT_L_RATIO: 0.1 86 | # Max AL episodes 87 | MAX_ITER: 5 88 | DROPOUT_ITERATIONS: 25 # Used by DBAL 89 | # Useful when running `ensemble_al.py` or `ensemble_train.py` 90 | ENSEMBLE: 91 | NUM_MODELS: 3 92 | MODEL_TYPE: ['resnet18'] 93 | ``` 94 | 95 | Please refer to `pycls/core/config.py` to configure your experiments at a deeper level. 96 | 97 | 98 | ## Execution Commands 99 | ### Active Learning 100 | Once the config file is configured appropriately, perform DBAL active learning with the following command inside the `tools` directory. 101 | 102 | ``` 103 | CUDA_VISIBLE_DEVICES=0 python train_al.py \ 104 | --cfg=../configs/cifar10/al/RESNET18.yaml --al=dbal --exp-name=YOUR_EXPERIMENT_NAME 105 | ``` 106 | 107 | ### Ensemble Active Learning 108 | 109 | Watch out for the ensemble options in the config file. By default, this setting uses _Ensemble Variation-Ratio_ as the query method. 110 | 111 | ``` 112 | CUDA_VISIBLE_DEVICES=0 python ensemble_al.py \ 113 | --cfg=../configs/cifar10/al/RESNET18_ENS.yaml --exp-name=YOUR_EXPERIMENT_NAME 114 | ``` 115 | 116 | ### Passive Learning 117 | 118 | ``` 119 | CUDA_VISIBLE_DEVICES=0 python train.py \ 120 | --cfg=../configs/cifar10/train/RESNET18.yaml --exp-name=YOUR_EXPERIMENT_NAME 121 | ``` 122 | 123 | ### Ensemble Passive Learning 124 | 125 | Watch out for the ensemble options in the config file.
126 | 127 | ``` 128 | CUDA_VISIBLE_DEVICES=0 python ensemble_train.py \ 129 | --cfg=../configs/cifar10/train/RESNET18_ENS.yaml --exp-name=YOUR_EXPERIMENT_NAME 130 | ``` 131 | 132 | ### Specific Model Evaluation 133 | 134 | This is useful if you want to evaluate a particular saved model. Pass the path to the model in the yaml file. Refer to the file inside the `configs/cifar10/evaluate` directory for an example. 135 | 136 | ``` 137 | CUDA_VISIBLE_DEVICES=0 python test_model.py \ 138 | --cfg=../configs/cifar10/evaluate/RESNET18.yaml 139 | ``` 140 | 141 | 142 | ## Add Your Own Dataset 143 | 144 | To add your own dataset, you need to do the following: 145 | 1. Write the PyTorch Dataset code for your custom dataset (or you could directly use the ones [PyTorch provides](https://pytorch.org/vision/stable/datasets.html)). 146 | 2. Create a subclass of the above Dataset with the modifications described below and add it to `pycls/datasets/custom_datasets.py`. 147 | * We add two new variables to the dataset - a boolean flag `no_aug` and `test_transform`. 148 | * We set the flag `no_aug = True` before iterating through the unlabeled and validation dataloaders so that their data doesn't get augmented. 149 | * See how we modify the `__getitem__` function to achieve that: 150 | ``` 151 | class CIFAR10(torchvision.datasets.CIFAR10): 152 | def __init__(self, root, train, transform, test_transform, download=True): 153 | super(CIFAR10, self).__init__(root, train, transform=transform, download=download) 154 | self.test_transform = test_transform 155 | self.no_aug = False 156 | 157 | def __getitem__(self, index: int): 158 | """ 159 | Args: 160 | index (int): Index 161 | 162 | Returns: 163 | tuple: (image, target) where target is index of the target class.
164 | """ 165 | img, target = self.data[index], self.targets[index] 166 | 167 | # doing this so that it is consistent with all other datasets 168 | # to return a PIL Image 169 | img = Image.fromarray(img) 170 | 171 | ########################## 172 | # set True before iterating through unlabeled or validation set 173 | if self.no_aug: 174 | if self.test_transform is not None: 175 | img = self.test_transform(img) 176 | else: 177 | if self.transform is not None: 178 | img = self.transform(img) 179 | ######################### 180 | 181 | return img, target 182 | ``` 183 | 3. Add your dataset in `pycls/datasets/data.py` 184 | * Add appropriate preprocessing steps to `getPreprocessOps` 185 | * Add the dataset call to `getDataset` 186 | 4. Create appropriate config `yaml` files and use them for training AL. 187 | 188 | 189 | ## Some Comments About Our Toolkit 190 | * Our toolkit currently only supports 'SGD' (with learning rate scheduler) and 'Adam' (no scheduler). 191 | * We log everything. Our toolkit saves the indices of the initial labeled pool, samples queried each episode, episode-wise best model, visual plots for "Iteration vs Loss", "Epoch vs Val Accuracy", "Episode vs Test Accuracy" and more. Please check an experiment's logs at `output/CIFAR10/resnet18/ENT_1/` for clarity. 192 | * We added dropout (p=0.5) to all our models just before the final fully connected layer. We do this to allow the DBAL and BALD query methods to work. 193 | * We also provide an iPython notebook that aggregates results directly from the experiment folders. You can find it at `output/results_aggregator.ipynb`. 194 | * If you add your own dataset, please make sure to create the custom version as explained in point 2 of the instructions above. Otherwise, your unlabeled data (a big red flag for AL) and validation data will be augmented. This is because we use a single dataset instance, with index-based subset dataloaders for the labeled, unlabeled, and validation splits.
195 | * We tested the toolkit only on a Linux machine with Python 3.8. 196 | * Please create an issue with appropriate details: 197 | * if you are unable to get the toolkit to work or run into any problems 198 | * if we have not provided credits correctly to the rightful owner (please attach proof) 199 | * if you notice any flaws in the implementation 200 | -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/Episodes_vs_Test Accuracy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/Episodes_vs_Test Accuracy.png -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_0/Epochs_vs_Loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_0/Epochs_vs_Loss.png -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_0/Epochs_vs_Validation Accuracy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_0/Epochs_vs_Validation Accuracy.png -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_0/Iterations_vs_Loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_0/Iterations_vs_Loss.png 
-------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_0/activeSet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_0/activeSet.npy -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_0/lSet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_0/lSet.npy -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_0/plot_epoch_xvalues.txt: -------------------------------------------------------------------------------- 1 | 1.00 2 | 2.00 3 | 3.00 4 | 4.00 5 | 5.00 6 | 6.00 7 | 7.00 8 | 8.00 9 | 9.00 10 | 10.00 11 | 11.00 12 | 12.00 13 | 13.00 14 | 14.00 15 | 15.00 16 | 16.00 17 | 17.00 18 | 18.00 19 | 19.00 20 | 20.00 21 | 21.00 22 | 22.00 23 | 23.00 24 | 24.00 25 | 25.00 26 | 26.00 27 | 27.00 28 | 28.00 29 | 29.00 30 | 30.00 31 | 31.00 32 | 32.00 33 | 33.00 34 | 34.00 35 | 35.00 36 | 36.00 37 | 37.00 38 | 38.00 39 | 39.00 40 | 40.00 41 | 41.00 42 | 42.00 43 | 43.00 44 | 44.00 45 | 45.00 46 | 46.00 47 | 47.00 48 | 48.00 49 | 49.00 50 | 50.00 51 | 51.00 52 | 52.00 53 | 53.00 54 | 54.00 55 | 55.00 56 | 56.00 57 | 57.00 58 | 58.00 59 | 59.00 60 | 60.00 61 | 61.00 62 | 62.00 63 | 63.00 64 | 64.00 65 | 65.00 66 | 66.00 67 | 67.00 68 | 68.00 69 | 69.00 70 | 70.00 71 | 71.00 72 | 72.00 73 | 73.00 74 | 74.00 75 | 75.00 76 | 76.00 77 | 77.00 78 | 78.00 79 | 79.00 80 | 80.00 81 | 81.00 82 | 82.00 83 | 83.00 84 | 84.00 85 | 85.00 86 | 86.00 87 | 87.00 88 | 88.00 89 | 89.00 90 | 90.00 91 | 91.00 92 | 92.00 93 | 93.00 94 | 94.00 95 
| 95.00 96 | 96.00 97 | 97.00 98 | 98.00 99 | 99.00 100 | 100.00 101 | 101.00 102 | 102.00 103 | 103.00 104 | 104.00 105 | 105.00 106 | 106.00 107 | 107.00 108 | 108.00 109 | 109.00 110 | 110.00 111 | 111.00 112 | 112.00 113 | 113.00 114 | 114.00 115 | 115.00 116 | 116.00 117 | 117.00 118 | 118.00 119 | 119.00 120 | 120.00 121 | 121.00 122 | 122.00 123 | 123.00 124 | 124.00 125 | 125.00 126 | 126.00 127 | 127.00 128 | 128.00 129 | 129.00 130 | 130.00 131 | 131.00 132 | 132.00 133 | 133.00 134 | 134.00 135 | 135.00 136 | 136.00 137 | 137.00 138 | 138.00 139 | 139.00 140 | 140.00 141 | 141.00 142 | 142.00 143 | 143.00 144 | 144.00 145 | 145.00 146 | 146.00 147 | 147.00 148 | 148.00 149 | 149.00 150 | 150.00 151 | 151.00 152 | 152.00 153 | 153.00 154 | 154.00 155 | 155.00 156 | 156.00 157 | 157.00 158 | 158.00 159 | 159.00 160 | 160.00 161 | 161.00 162 | 162.00 163 | 163.00 164 | 164.00 165 | 165.00 166 | 166.00 167 | 167.00 168 | 168.00 169 | 169.00 170 | 170.00 171 | 171.00 172 | 172.00 173 | 173.00 174 | 174.00 175 | 175.00 176 | 176.00 177 | 177.00 178 | 178.00 179 | 179.00 180 | 180.00 181 | 181.00 182 | 182.00 183 | 183.00 184 | 184.00 185 | 185.00 186 | 186.00 187 | 187.00 188 | 188.00 189 | 189.00 190 | 190.00 191 | 191.00 192 | 192.00 193 | 193.00 194 | 194.00 195 | 195.00 196 | 196.00 197 | 197.00 198 | 198.00 199 | 199.00 200 | 200.00 201 | -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_0/plot_epoch_yvalues.txt: -------------------------------------------------------------------------------- 1 | 2.39 2 | 2.34 3 | 2.04 4 | 0.69 5 | 2.46 6 | 2.31 7 | 2.15 8 | 1.69 9 | 1.67 10 | 1.70 11 | 1.93 12 | 1.74 13 | 1.89 14 | 2.85 15 | 1.80 16 | 0.71 17 | 2.35 18 | 1.73 19 | 1.20 20 | 0.89 21 | 1.14 22 | 1.34 23 | 1.16 24 | 0.98 25 | 0.97 26 | 1.59 27 | 0.72 28 | 0.59 29 | 0.94 30 | 1.27 31 | 1.53 32 | 0.72 33 | 0.59 34 | 1.31 35 | 1.28 36 | 1.05 37 | 1.91 38 | 1.23 39 | 0.73 40 | 0.57 41 | 1.86 
42 | 3.92 43 | 0.35 44 | 0.33 45 | 1.12 46 | 1.46 47 | 0.87 48 | 0.49 49 | 2.36 50 | 1.00 51 | 1.35 52 | 2.22 53 | 0.49 54 | 0.92 55 | 0.37 56 | 0.17 57 | 0.23 58 | 1.24 59 | 1.54 60 | 1.28 61 | 0.74 62 | 0.82 63 | 0.24 64 | 0.16 65 | 0.14 66 | 0.66 67 | 0.10 68 | 0.60 69 | 2.10 70 | 0.09 71 | 2.29 72 | 0.80 73 | 0.17 74 | 0.55 75 | 0.18 76 | 0.63 77 | 0.19 78 | 1.30 79 | 0.14 80 | 3.27 81 | 0.63 82 | 0.41 83 | 0.88 84 | 0.23 85 | 0.18 86 | 1.05 87 | 0.42 88 | 1.66 89 | 0.17 90 | 0.01 91 | 1.55 92 | 0.80 93 | 0.25 94 | 1.22 95 | 0.27 96 | 0.04 97 | 0.63 98 | 0.58 99 | 0.37 100 | 0.15 101 | 0.07 102 | 0.90 103 | 1.04 104 | 1.05 105 | 0.34 106 | 0.61 107 | 1.22 108 | 0.34 109 | 0.16 110 | 0.74 111 | 0.14 112 | 0.25 113 | 0.82 114 | 1.69 115 | 0.27 116 | 1.72 117 | 1.15 118 | 1.24 119 | 0.04 120 | 0.16 121 | 0.01 122 | 0.57 123 | 0.10 124 | 0.04 125 | 1.02 126 | 0.02 127 | 0.20 128 | 0.33 129 | 0.08 130 | 0.54 131 | 0.31 132 | 0.40 133 | 0.35 134 | 1.19 135 | 0.09 136 | 0.49 137 | 0.36 138 | 0.54 139 | 0.29 140 | 0.36 141 | 0.42 142 | 0.13 143 | 0.06 144 | 0.33 145 | 0.04 146 | 0.30 147 | 0.01 148 | 0.51 149 | 0.00 150 | 0.03 151 | 0.25 152 | 0.19 153 | 0.92 154 | 0.28 155 | 0.02 156 | 0.04 157 | 0.04 158 | 0.01 159 | 0.33 160 | 0.02 161 | 0.02 162 | 1.31 163 | 0.02 164 | 0.01 165 | 1.45 166 | 0.00 167 | 0.67 168 | 0.00 169 | 0.05 170 | 0.00 171 | 0.08 172 | 0.05 173 | 0.66 174 | 0.34 175 | 0.25 176 | 0.79 177 | 0.03 178 | 0.01 179 | 0.01 180 | 0.01 181 | 0.02 182 | 0.00 183 | 1.77 184 | 0.31 185 | 0.08 186 | 0.08 187 | 0.19 188 | 0.43 189 | 1.09 190 | 0.51 191 | 0.92 192 | 0.02 193 | 0.03 194 | 0.11 195 | 0.86 196 | 1.21 197 | 1.27 198 | 0.06 199 | 0.00 200 | 0.75 201 | -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_0/plot_it_x_values.txt: -------------------------------------------------------------------------------- 1 | 19.00 2 | 38.00 3 | 72.00 4 | 91.00 5 | 125.00 6 | 144.00 7 | 178.00 8 | 
197.00 9 | 231.00 10 | 250.00 11 | 284.00 12 | 303.00 13 | 337.00 14 | 356.00 15 | 390.00 16 | 409.00 17 | 443.00 18 | 462.00 19 | 496.00 20 | 515.00 21 | 549.00 22 | 568.00 23 | 602.00 24 | 621.00 25 | 655.00 26 | 674.00 27 | 708.00 28 | 727.00 29 | 761.00 30 | 780.00 31 | 814.00 32 | 833.00 33 | 867.00 34 | 886.00 35 | 920.00 36 | 939.00 37 | 973.00 38 | 992.00 39 | 1026.00 40 | 1045.00 41 | 1079.00 42 | 1098.00 43 | 1132.00 44 | 1151.00 45 | 1185.00 46 | 1204.00 47 | 1238.00 48 | 1257.00 49 | 1291.00 50 | 1310.00 51 | 1344.00 52 | 1363.00 53 | 1397.00 54 | 1416.00 55 | 1450.00 56 | 1469.00 57 | 1503.00 58 | 1522.00 59 | 1556.00 60 | 1575.00 61 | 1609.00 62 | 1628.00 63 | 1662.00 64 | 1681.00 65 | 1715.00 66 | 1734.00 67 | 1768.00 68 | 1787.00 69 | 1821.00 70 | 1840.00 71 | 1874.00 72 | 1893.00 73 | 1927.00 74 | 1946.00 75 | 1980.00 76 | 1999.00 77 | 2033.00 78 | 2052.00 79 | 2086.00 80 | 2105.00 81 | 2139.00 82 | 2158.00 83 | 2192.00 84 | 2211.00 85 | 2245.00 86 | 2264.00 87 | 2298.00 88 | 2317.00 89 | 2351.00 90 | 2370.00 91 | 2404.00 92 | 2423.00 93 | 2457.00 94 | 2476.00 95 | 2510.00 96 | 2529.00 97 | 2563.00 98 | 2582.00 99 | 2616.00 100 | 2635.00 101 | 2669.00 102 | 2688.00 103 | 2722.00 104 | 2741.00 105 | 2775.00 106 | 2794.00 107 | 2828.00 108 | 2847.00 109 | 2881.00 110 | 2900.00 111 | 2934.00 112 | 2953.00 113 | 2987.00 114 | 3006.00 115 | 3040.00 116 | 3059.00 117 | 3093.00 118 | 3112.00 119 | 3146.00 120 | 3165.00 121 | 3199.00 122 | 3218.00 123 | 3252.00 124 | 3271.00 125 | 3305.00 126 | 3324.00 127 | 3358.00 128 | 3377.00 129 | 3411.00 130 | 3430.00 131 | 3464.00 132 | 3483.00 133 | 3517.00 134 | 3536.00 135 | 3570.00 136 | 3589.00 137 | 3623.00 138 | 3642.00 139 | 3676.00 140 | 3695.00 141 | 3729.00 142 | 3748.00 143 | 3782.00 144 | 3801.00 145 | 3835.00 146 | 3854.00 147 | 3888.00 148 | 3907.00 149 | 3941.00 150 | 3960.00 151 | 3994.00 152 | 4013.00 153 | 4047.00 154 | 4066.00 155 | 4100.00 156 | 4119.00 157 | 4153.00 158 | 4172.00 159 | 4206.00 
160 | 4225.00 161 | 4259.00 162 | 4278.00 163 | 4312.00 164 | 4331.00 165 | 4365.00 166 | 4384.00 167 | 4418.00 168 | 4437.00 169 | 4471.00 170 | 4490.00 171 | 4524.00 172 | 4543.00 173 | 4577.00 174 | 4596.00 175 | 4630.00 176 | 4649.00 177 | 4683.00 178 | 4702.00 179 | 4736.00 180 | 4755.00 181 | 4789.00 182 | 4808.00 183 | 4842.00 184 | 4861.00 185 | 4895.00 186 | 4914.00 187 | 4948.00 188 | 4967.00 189 | 5001.00 190 | 5020.00 191 | 5054.00 192 | 5073.00 193 | 5107.00 194 | 5126.00 195 | 5160.00 196 | 5179.00 197 | 5213.00 198 | 5232.00 199 | 5266.00 200 | 5285.00 201 | 5319.00 202 | 5338.00 203 | 5372.00 204 | 5391.00 205 | 5425.00 206 | 5444.00 207 | 5478.00 208 | 5497.00 209 | 5531.00 210 | 5550.00 211 | 5584.00 212 | 5603.00 213 | 5637.00 214 | 5656.00 215 | 5690.00 216 | 5709.00 217 | 5743.00 218 | 5762.00 219 | 5796.00 220 | 5815.00 221 | 5849.00 222 | 5868.00 223 | 5902.00 224 | 5921.00 225 | 5955.00 226 | 5974.00 227 | 6008.00 228 | 6027.00 229 | 6061.00 230 | 6080.00 231 | 6114.00 232 | 6133.00 233 | 6167.00 234 | 6186.00 235 | 6220.00 236 | 6239.00 237 | 6273.00 238 | 6292.00 239 | 6326.00 240 | 6345.00 241 | 6379.00 242 | 6398.00 243 | 6432.00 244 | 6451.00 245 | 6485.00 246 | 6504.00 247 | 6538.00 248 | 6557.00 249 | 6591.00 250 | 6610.00 251 | 6644.00 252 | 6663.00 253 | 6697.00 254 | 6716.00 255 | 6750.00 256 | 6769.00 257 | 6803.00 258 | 6822.00 259 | 6856.00 260 | 6875.00 261 | 6909.00 262 | 6928.00 263 | 6962.00 264 | 6981.00 265 | 7015.00 266 | 7034.00 267 | 7068.00 268 | 7087.00 269 | 7121.00 270 | 7140.00 271 | 7174.00 272 | 7193.00 273 | 7227.00 274 | 7246.00 275 | 7280.00 276 | 7299.00 277 | 7333.00 278 | 7352.00 279 | 7386.00 280 | 7405.00 281 | 7439.00 282 | 7458.00 283 | 7492.00 284 | 7511.00 285 | 7545.00 286 | 7564.00 287 | 7598.00 288 | 7617.00 289 | 7651.00 290 | 7670.00 291 | 7704.00 292 | 7723.00 293 | 7757.00 294 | 7776.00 295 | 7810.00 296 | 7829.00 297 | 7863.00 298 | 7882.00 299 | 7916.00 300 | 7935.00 301 | 7969.00 302 | 
7988.00 303 | 8022.00 304 | 8041.00 305 | 8075.00 306 | 8094.00 307 | 8128.00 308 | 8147.00 309 | 8181.00 310 | 8200.00 311 | 8234.00 312 | 8253.00 313 | 8287.00 314 | 8306.00 315 | 8340.00 316 | 8359.00 317 | 8393.00 318 | 8412.00 319 | 8446.00 320 | 8465.00 321 | 8499.00 322 | 8518.00 323 | 8552.00 324 | 8571.00 325 | 8605.00 326 | 8624.00 327 | 8658.00 328 | 8677.00 329 | 8711.00 330 | 8730.00 331 | 8764.00 332 | 8783.00 333 | 8817.00 334 | 8836.00 335 | 8870.00 336 | 8889.00 337 | 8923.00 338 | 8942.00 339 | 8976.00 340 | 8995.00 341 | 9029.00 342 | 9048.00 343 | 9082.00 344 | 9101.00 345 | 9135.00 346 | 9154.00 347 | 9188.00 348 | 9207.00 349 | 9241.00 350 | 9260.00 351 | 9294.00 352 | 9313.00 353 | 9347.00 354 | 9366.00 355 | 9400.00 356 | 9419.00 357 | 9453.00 358 | 9472.00 359 | 9506.00 360 | 9525.00 361 | 9559.00 362 | 9578.00 363 | 9612.00 364 | 9631.00 365 | 9665.00 366 | 9684.00 367 | 9718.00 368 | 9737.00 369 | 9771.00 370 | 9790.00 371 | 9824.00 372 | 9843.00 373 | 9877.00 374 | 9896.00 375 | 9930.00 376 | 9949.00 377 | 9983.00 378 | 10002.00 379 | 10036.00 380 | 10055.00 381 | 10089.00 382 | 10108.00 383 | 10142.00 384 | 10161.00 385 | 10195.00 386 | 10214.00 387 | 10248.00 388 | 10267.00 389 | 10301.00 390 | 10320.00 391 | 10354.00 392 | 10373.00 393 | 10407.00 394 | 10426.00 395 | 10460.00 396 | 10479.00 397 | 10513.00 398 | 10532.00 399 | 10566.00 400 | 10585.00 401 | -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_0/plot_it_y_values.txt: -------------------------------------------------------------------------------- 1 | 2.29 2 | 2.20 3 | 1.98 4 | 2.13 5 | 2.03 6 | 2.19 7 | 1.74 8 | 1.64 9 | 1.80 10 | 1.83 11 | 1.62 12 | 1.42 13 | 1.61 14 | 1.51 15 | 1.56 16 | 1.46 17 | 1.45 18 | 1.39 19 | 1.34 20 | 1.23 21 | 1.33 22 | 1.24 23 | 1.35 24 | 1.38 25 | 1.20 26 | 1.24 27 | 1.14 28 | 1.13 29 | 1.39 30 | 1.25 31 | 1.28 32 | 1.08 33 | 1.40 34 | 1.15 35 | 1.44 36 | 1.29 37 | 1.03 38 | 
0.86 39 | 0.96 40 | 0.91 41 | 0.97 42 | 1.20 43 | 0.93 44 | 1.05 45 | 0.98 46 | 0.89 47 | 0.88 48 | 0.68 49 | 0.88 50 | 0.70 51 | 0.80 52 | 0.92 53 | 0.80 54 | 0.85 55 | 0.72 56 | 0.78 57 | 0.76 58 | 0.86 59 | 1.06 60 | 0.87 61 | 0.76 62 | 0.75 63 | 0.85 64 | 0.79 65 | 0.83 66 | 0.63 67 | 0.60 68 | 0.65 69 | 0.79 70 | 0.69 71 | 0.70 72 | 0.60 73 | 0.96 74 | 0.65 75 | 0.84 76 | 0.62 77 | 0.89 78 | 0.74 79 | 0.59 80 | 0.58 81 | 0.69 82 | 0.58 83 | 0.73 84 | 0.62 85 | 0.70 86 | 0.76 87 | 0.60 88 | 0.60 89 | 0.68 90 | 0.47 91 | 0.54 92 | 0.42 93 | 0.80 94 | 0.59 95 | 0.54 96 | 0.47 97 | 0.60 98 | 0.41 99 | 0.74 100 | 0.57 101 | 0.65 102 | 0.55 103 | 0.70 104 | 0.45 105 | 0.44 106 | 0.44 107 | 0.49 108 | 0.44 109 | 0.43 110 | 0.59 111 | 0.55 112 | 0.39 113 | 0.42 114 | 0.45 115 | 0.38 116 | 0.34 117 | 0.54 118 | 0.41 119 | 0.38 120 | 0.23 121 | 0.61 122 | 0.34 123 | 0.26 124 | 0.24 125 | 0.28 126 | 0.37 127 | 0.51 128 | 0.37 129 | 0.39 130 | 0.37 131 | 0.37 132 | 0.29 133 | 0.31 134 | 0.22 135 | 0.28 136 | 0.20 137 | 0.32 138 | 0.34 139 | 0.57 140 | 0.24 141 | 0.25 142 | 0.31 143 | 0.76 144 | 0.31 145 | 0.24 146 | 0.25 147 | 0.34 148 | 0.21 149 | 0.19 150 | 0.23 151 | 0.26 152 | 0.27 153 | 0.32 154 | 0.24 155 | 0.17 156 | 0.29 157 | 0.60 158 | 0.37 159 | 0.17 160 | 0.28 161 | 0.42 162 | 0.28 163 | 0.25 164 | 0.31 165 | 0.30 166 | 0.16 167 | 0.20 168 | 0.13 169 | 0.12 170 | 0.16 171 | 0.09 172 | 0.12 173 | 0.18 174 | 0.14 175 | 0.31 176 | 0.21 177 | 0.27 178 | 0.22 179 | 0.06 180 | 0.17 181 | 0.07 182 | 0.18 183 | 0.22 184 | 0.28 185 | 0.26 186 | 0.22 187 | 0.14 188 | 0.27 189 | 0.35 190 | 0.13 191 | 0.20 192 | 0.09 193 | 0.11 194 | 0.04 195 | 0.23 196 | 0.23 197 | 0.37 198 | 0.23 199 | 0.30 200 | 0.17 201 | 0.12 202 | 0.21 203 | 0.05 204 | 0.07 205 | 0.17 206 | 0.12 207 | 0.21 208 | 0.10 209 | 0.21 210 | 0.08 211 | 0.08 212 | 0.12 213 | 0.10 214 | 0.13 215 | 0.18 216 | 0.12 217 | 0.24 218 | 0.11 219 | 0.06 220 | 0.07 221 | 0.13 222 | 0.07 223 | 0.12 224 | 0.08 225 | 
0.07 226 | 0.05 227 | 0.20 228 | 0.15 229 | 0.23 230 | 0.14 231 | 0.09 232 | 0.12 233 | 0.20 234 | 0.13 235 | 0.29 236 | 0.15 237 | 0.11 238 | 0.09 239 | 0.05 240 | 0.07 241 | 0.08 242 | 0.05 243 | 0.03 244 | 0.03 245 | 0.10 246 | 0.14 247 | 0.04 248 | 0.06 249 | 0.09 250 | 0.07 251 | 0.12 252 | 0.19 253 | 0.02 254 | 0.01 255 | 0.05 256 | 0.05 257 | 0.02 258 | 0.04 259 | 0.06 260 | 0.10 261 | 0.06 262 | 0.09 263 | 0.06 264 | 0.05 265 | 0.17 266 | 0.03 267 | 0.06 268 | 0.03 269 | 0.10 270 | 0.11 271 | 0.10 272 | 0.06 273 | 0.04 274 | 0.08 275 | 0.02 276 | 0.08 277 | 0.01 278 | 0.04 279 | 0.03 280 | 0.01 281 | 0.07 282 | 0.03 283 | 0.04 284 | 0.04 285 | 0.05 286 | 0.03 287 | 0.02 288 | 0.01 289 | 0.04 290 | 0.03 291 | 0.02 292 | 0.00 293 | 0.03 294 | 0.02 295 | 0.02 296 | 0.01 297 | 0.04 298 | 0.05 299 | 0.03 300 | 0.01 301 | 0.01 302 | 0.02 303 | 0.02 304 | 0.01 305 | 0.01 306 | 0.01 307 | 0.07 308 | 0.06 309 | 0.01 310 | 0.02 311 | 0.02 312 | 0.05 313 | 0.01 314 | 0.00 315 | 0.01 316 | 0.01 317 | 0.01 318 | 0.01 319 | 0.03 320 | 0.02 321 | 0.01 322 | 0.03 323 | 0.00 324 | 0.02 325 | 0.01 326 | 0.03 327 | 0.00 328 | 0.01 329 | 0.01 330 | 0.00 331 | 0.02 332 | 0.02 333 | 0.02 334 | 0.03 335 | 0.00 336 | 0.01 337 | 0.01 338 | 0.02 339 | 0.00 340 | 0.01 341 | 0.00 342 | 0.01 343 | 0.02 344 | 0.04 345 | 0.01 346 | 0.00 347 | 0.01 348 | 0.00 349 | 0.03 350 | 0.02 351 | 0.00 352 | 0.04 353 | 0.03 354 | 0.01 355 | 0.01 356 | 0.02 357 | 0.02 358 | 0.00 359 | 0.01 360 | 0.00 361 | 0.04 362 | 0.00 363 | 0.00 364 | 0.01 365 | 0.04 366 | 0.00 367 | 0.01 368 | 0.00 369 | 0.02 370 | 0.01 371 | 0.01 372 | 0.02 373 | 0.00 374 | 0.00 375 | 0.00 376 | 0.01 377 | 0.01 378 | 0.03 379 | 0.00 380 | 0.01 381 | 0.01 382 | 0.00 383 | 0.01 384 | 0.01 385 | 0.01 386 | 0.02 387 | 0.02 388 | 0.01 389 | 0.01 390 | 0.02 391 | 0.00 392 | 0.00 393 | 0.01 394 | 0.01 395 | 0.01 396 | 0.00 397 | 0.00 398 | 0.00 399 | 0.01 400 | 0.05 401 | 
-------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_0/uSet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_0/uSet.npy -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_0/val_acc_epochs_x.txt: -------------------------------------------------------------------------------- 1 | 2.00 2 | 4.00 3 | 6.00 4 | 8.00 5 | 10.00 6 | 12.00 7 | 14.00 8 | 16.00 9 | 18.00 10 | 20.00 11 | 22.00 12 | 24.00 13 | 26.00 14 | 28.00 15 | 30.00 16 | 32.00 17 | 34.00 18 | 36.00 19 | 38.00 20 | 40.00 21 | 42.00 22 | 44.00 23 | 46.00 24 | 48.00 25 | 50.00 26 | 52.00 27 | 54.00 28 | 56.00 29 | 58.00 30 | 60.00 31 | 62.00 32 | 64.00 33 | 66.00 34 | 68.00 35 | 70.00 36 | 72.00 37 | 74.00 38 | 76.00 39 | 78.00 40 | 80.00 41 | 82.00 42 | 84.00 43 | 86.00 44 | 88.00 45 | 90.00 46 | 92.00 47 | 94.00 48 | 96.00 49 | 98.00 50 | 100.00 51 | 102.00 52 | 104.00 53 | 106.00 54 | 108.00 55 | 110.00 56 | 112.00 57 | 114.00 58 | 116.00 59 | 118.00 60 | 120.00 61 | 122.00 62 | 124.00 63 | 126.00 64 | 128.00 65 | 130.00 66 | 132.00 67 | 134.00 68 | 136.00 69 | 138.00 70 | 140.00 71 | 142.00 72 | 144.00 73 | 146.00 74 | 148.00 75 | 150.00 76 | 152.00 77 | 154.00 78 | 156.00 79 | 158.00 80 | 160.00 81 | 162.00 82 | 164.00 83 | 166.00 84 | 168.00 85 | 170.00 86 | 172.00 87 | 174.00 88 | 176.00 89 | 178.00 90 | 180.00 91 | 182.00 92 | 184.00 93 | 186.00 94 | 188.00 95 | 190.00 96 | 192.00 97 | 194.00 98 | 196.00 99 | 198.00 100 | 200.00 101 | -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_0/val_acc_epochs_y.txt: -------------------------------------------------------------------------------- 1 | 19.56 2 | 35.18 3 | 36.56 4 
| 44.68 5 | 47.56 6 | 49.50 7 | 50.68 8 | 49.66 9 | 48.96 10 | 56.78 11 | 53.22 12 | 58.02 13 | 58.88 14 | 57.86 15 | 56.76 16 | 58.34 17 | 55.04 18 | 50.74 19 | 47.78 20 | 62.92 21 | 63.94 22 | 65.36 23 | 56.06 24 | 64.24 25 | 62.10 26 | 56.70 27 | 63.36 28 | 65.98 29 | 62.76 30 | 60.18 31 | 64.94 32 | 68.96 33 | 65.54 34 | 63.70 35 | 66.26 36 | 63.90 37 | 68.60 38 | 68.60 39 | 61.42 40 | 65.40 41 | 68.52 42 | 68.28 43 | 68.88 44 | 70.22 45 | 70.58 46 | 68.96 47 | 65.02 48 | 71.36 49 | 66.46 50 | 70.48 51 | 67.36 52 | 66.72 53 | 68.16 54 | 69.12 55 | 68.24 56 | 70.72 57 | 67.44 58 | 69.26 59 | 69.18 60 | 71.04 61 | 68.94 62 | 70.64 63 | 71.64 64 | 69.16 65 | 70.86 66 | 71.38 67 | 69.28 68 | 71.44 69 | 70.98 70 | 71.54 71 | 72.78 72 | 71.52 73 | 72.22 74 | 71.96 75 | 71.92 76 | 71.94 77 | 70.78 78 | 72.50 79 | 72.64 80 | 72.78 81 | 72.32 82 | 71.52 83 | 71.84 84 | 71.88 85 | 71.82 86 | 71.84 87 | 72.10 88 | 72.26 89 | 72.42 90 | 73.06 91 | 72.02 92 | 71.98 93 | 72.28 94 | 72.50 95 | 72.28 96 | 71.66 97 | 72.40 98 | 72.56 99 | 72.18 100 | 72.52 101 | -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_1/Epochs_vs_Loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_1/Epochs_vs_Loss.png -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_1/Epochs_vs_Validation Accuracy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_1/Epochs_vs_Validation Accuracy.png -------------------------------------------------------------------------------- 
/output/CIFAR10/resnet18/ENT_1/episode_1/Iterations_vs_Loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_1/Iterations_vs_Loss.png -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_1/activeSet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_1/activeSet.npy -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_1/lSet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_1/lSet.npy -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_1/plot_epoch_xvalues.txt: -------------------------------------------------------------------------------- 1 | 1.00 2 | 2.00 3 | 3.00 4 | 4.00 5 | 5.00 6 | 6.00 7 | 7.00 8 | 8.00 9 | 9.00 10 | 10.00 11 | 11.00 12 | 12.00 13 | 13.00 14 | 14.00 15 | 15.00 16 | 16.00 17 | 17.00 18 | 18.00 19 | 19.00 20 | 20.00 21 | 21.00 22 | 22.00 23 | 23.00 24 | 24.00 25 | 25.00 26 | 26.00 27 | 27.00 28 | 28.00 29 | 29.00 30 | 30.00 31 | 31.00 32 | 32.00 33 | 33.00 34 | 34.00 35 | 35.00 36 | 36.00 37 | 37.00 38 | 38.00 39 | 39.00 40 | 40.00 41 | 41.00 42 | 42.00 43 | 43.00 44 | 44.00 45 | 45.00 46 | 46.00 47 | 47.00 48 | 48.00 49 | 49.00 50 | 50.00 51 | 51.00 52 | 52.00 53 | 53.00 54 | 54.00 55 | 55.00 56 | 56.00 57 | 57.00 58 | 58.00 59 | 59.00 60 | 60.00 61 | 61.00 62 | 62.00 63 | 63.00 64 | 64.00 65 | 65.00 66 
| 66.00 67 | 67.00 68 | 68.00 69 | 69.00 70 | 70.00 71 | 71.00 72 | 72.00 73 | 73.00 74 | 74.00 75 | 75.00 76 | 76.00 77 | 77.00 78 | 78.00 79 | 79.00 80 | 80.00 81 | 81.00 82 | 82.00 83 | 83.00 84 | 84.00 85 | 85.00 86 | 86.00 87 | 87.00 88 | 88.00 89 | 89.00 90 | 90.00 91 | 91.00 92 | 92.00 93 | 93.00 94 | 94.00 95 | 95.00 96 | 96.00 97 | 97.00 98 | 98.00 99 | 99.00 100 | 100.00 101 | 101.00 102 | 102.00 103 | 103.00 104 | 104.00 105 | 105.00 106 | 106.00 107 | 107.00 108 | 108.00 109 | 109.00 110 | 110.00 111 | 111.00 112 | 112.00 113 | 113.00 114 | 114.00 115 | 115.00 116 | 116.00 117 | 117.00 118 | 118.00 119 | 119.00 120 | 120.00 121 | 121.00 122 | 122.00 123 | 123.00 124 | 124.00 125 | 125.00 126 | 126.00 127 | 127.00 128 | 128.00 129 | 129.00 130 | 130.00 131 | 131.00 132 | 132.00 133 | 133.00 134 | 134.00 135 | 135.00 136 | 136.00 137 | 137.00 138 | 138.00 139 | 139.00 140 | 140.00 141 | 141.00 142 | 142.00 143 | 143.00 144 | 144.00 145 | 145.00 146 | 146.00 147 | 147.00 148 | 148.00 149 | 149.00 150 | 150.00 151 | 151.00 152 | 152.00 153 | 153.00 154 | 154.00 155 | 155.00 156 | 156.00 157 | 157.00 158 | 158.00 159 | 159.00 160 | 160.00 161 | 161.00 162 | 162.00 163 | 163.00 164 | 164.00 165 | 165.00 166 | 166.00 167 | 167.00 168 | 168.00 169 | 169.00 170 | 170.00 171 | 171.00 172 | 172.00 173 | 173.00 174 | 174.00 175 | 175.00 176 | 176.00 177 | 177.00 178 | 178.00 179 | 179.00 180 | 180.00 181 | 181.00 182 | 182.00 183 | 183.00 184 | 184.00 185 | 185.00 186 | 186.00 187 | 187.00 188 | 188.00 189 | 189.00 190 | 190.00 191 | 191.00 192 | 192.00 193 | 193.00 194 | 194.00 195 | 195.00 196 | 196.00 197 | 197.00 198 | 198.00 199 | 199.00 200 | 200.00 201 | -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_1/plot_epoch_yvalues.txt: -------------------------------------------------------------------------------- 1 | 1.02 2 | 1.78 3 | 1.48 4 | 1.88 5 | 0.78 6 | 0.54 7 | 0.73 8 | 1.24 9 | 0.26 
10 | 0.63 11 | 0.54 12 | 0.52 13 | 1.25 14 | 0.21 15 | 0.99 16 | 0.76 17 | 0.76 18 | 0.84 19 | 0.81 20 | 0.62 21 | 1.19 22 | 0.14 23 | 1.02 24 | 0.77 25 | 0.73 26 | 0.16 27 | 0.35 28 | 0.71 29 | 0.44 30 | 0.41 31 | 0.48 32 | 0.12 33 | 0.04 34 | 0.55 35 | 0.74 36 | 0.32 37 | 0.31 38 | 0.37 39 | 0.15 40 | 0.11 41 | 0.66 42 | 0.31 43 | 0.38 44 | 0.04 45 | 0.16 46 | 0.41 47 | 0.09 48 | 0.17 49 | 0.27 50 | 0.40 51 | 0.02 52 | 0.29 53 | 0.06 54 | 0.21 55 | 0.47 56 | 0.23 57 | 0.22 58 | 0.49 59 | 0.09 60 | 0.42 61 | 0.04 62 | 0.12 63 | 0.10 64 | 0.23 65 | 0.29 66 | 0.68 67 | 0.01 68 | 0.01 69 | 0.07 70 | 0.02 71 | 0.06 72 | 0.37 73 | 0.45 74 | 0.65 75 | 0.73 76 | 0.63 77 | 0.34 78 | 0.04 79 | 1.02 80 | 0.17 81 | 0.12 82 | 1.28 83 | 0.02 84 | 0.11 85 | 0.26 86 | 0.51 87 | 0.08 88 | 0.04 89 | 0.41 90 | 0.01 91 | 0.02 92 | 0.01 93 | 0.03 94 | 0.01 95 | 0.29 96 | 0.58 97 | 0.02 98 | 0.02 99 | 0.01 100 | 0.42 101 | 0.15 102 | 0.06 103 | 0.00 104 | 0.51 105 | 0.04 106 | 0.37 107 | 0.02 108 | 0.07 109 | 0.05 110 | 0.21 111 | 0.01 112 | 0.02 113 | 0.00 114 | 0.01 115 | 0.00 116 | 0.00 117 | 0.01 118 | 0.03 119 | 0.01 120 | 0.01 121 | 0.00 122 | 0.03 123 | 0.09 124 | 0.06 125 | 0.09 126 | 0.02 127 | 0.00 128 | 0.44 129 | 0.07 130 | 0.00 131 | 0.01 132 | 0.00 133 | 0.00 134 | 0.01 135 | 0.02 136 | 0.30 137 | 0.00 138 | 0.01 139 | 0.10 140 | 0.01 141 | 0.00 142 | 0.00 143 | 0.00 144 | 0.04 145 | 0.00 146 | 0.00 147 | 0.00 148 | 0.00 149 | 0.03 150 | 0.02 151 | 0.02 152 | 0.20 153 | 0.01 154 | 0.01 155 | 0.03 156 | 0.22 157 | 0.00 158 | 0.30 159 | 0.15 160 | 0.00 161 | 0.02 162 | 0.01 163 | 0.17 164 | 0.00 165 | 0.00 166 | 0.00 167 | 0.00 168 | 0.01 169 | 0.05 170 | 0.00 171 | 0.01 172 | 0.00 173 | 0.08 174 | 0.11 175 | 0.00 176 | 0.00 177 | 0.07 178 | 0.00 179 | 0.03 180 | 0.00 181 | 0.00 182 | 0.01 183 | 0.00 184 | 0.00 185 | 0.00 186 | 0.05 187 | 0.00 188 | 0.06 189 | 0.00 190 | 0.03 191 | 0.00 192 | 0.00 193 | 0.05 194 | 0.11 195 | 0.05 196 | 0.00 197 | 0.00 198 | 0.00 199 | 0.00 
200 | 0.00 201 | -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_1/uSet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_1/uSet.npy -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_1/val_acc_epochs_x.txt: -------------------------------------------------------------------------------- 1 | 2.00 2 | 4.00 3 | 6.00 4 | 8.00 5 | 10.00 6 | 12.00 7 | 14.00 8 | 16.00 9 | 18.00 10 | 20.00 11 | 22.00 12 | 24.00 13 | 26.00 14 | 28.00 15 | 30.00 16 | 32.00 17 | 34.00 18 | 36.00 19 | 38.00 20 | 40.00 21 | 42.00 22 | 44.00 23 | 46.00 24 | 48.00 25 | 50.00 26 | 52.00 27 | 54.00 28 | 56.00 29 | 58.00 30 | 60.00 31 | 62.00 32 | 64.00 33 | 66.00 34 | 68.00 35 | 70.00 36 | 72.00 37 | 74.00 38 | 76.00 39 | 78.00 40 | 80.00 41 | 82.00 42 | 84.00 43 | 86.00 44 | 88.00 45 | 90.00 46 | 92.00 47 | 94.00 48 | 96.00 49 | 98.00 50 | 100.00 51 | 102.00 52 | 104.00 53 | 106.00 54 | 108.00 55 | 110.00 56 | 112.00 57 | 114.00 58 | 116.00 59 | 118.00 60 | 120.00 61 | 122.00 62 | 124.00 63 | 126.00 64 | 128.00 65 | 130.00 66 | 132.00 67 | 134.00 68 | 136.00 69 | 138.00 70 | 140.00 71 | 142.00 72 | 144.00 73 | 146.00 74 | 148.00 75 | 150.00 76 | 152.00 77 | 154.00 78 | 156.00 79 | 158.00 80 | 160.00 81 | 162.00 82 | 164.00 83 | 166.00 84 | 168.00 85 | 170.00 86 | 172.00 87 | 174.00 88 | 176.00 89 | 178.00 90 | 180.00 91 | 182.00 92 | 184.00 93 | 186.00 94 | 188.00 95 | 190.00 96 | 192.00 97 | 194.00 98 | 196.00 99 | 198.00 100 | 200.00 101 | -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_1/val_acc_epochs_y.txt: -------------------------------------------------------------------------------- 1 | 70.04 2 | 
72.88 3 | 71.68 4 | 73.36 5 | 75.62 6 | 75.38 7 | 72.94 8 | 74.72 9 | 74.52 10 | 74.28 11 | 71.98 12 | 74.08 13 | 75.66 14 | 74.38 15 | 76.30 16 | 77.66 17 | 74.90 18 | 75.98 19 | 76.94 20 | 78.02 21 | 77.38 22 | 77.76 23 | 77.18 24 | 77.96 25 | 76.46 26 | 78.00 27 | 75.94 28 | 78.16 29 | 78.02 30 | 78.28 31 | 78.28 32 | 79.14 33 | 79.44 34 | 78.34 35 | 79.20 36 | 78.58 37 | 78.94 38 | 78.28 39 | 79.90 40 | 77.70 41 | 78.76 42 | 78.20 43 | 78.00 44 | 79.32 45 | 79.20 46 | 79.44 47 | 80.58 48 | 79.42 49 | 80.66 50 | 80.80 51 | 80.44 52 | 80.34 53 | 80.64 54 | 80.42 55 | 80.98 56 | 81.26 57 | 80.94 58 | 81.38 59 | 81.62 60 | 81.42 61 | 81.34 62 | 81.48 63 | 80.56 64 | 81.08 65 | 81.18 66 | 82.16 67 | 81.94 68 | 81.84 69 | 81.72 70 | 81.10 71 | 81.64 72 | 81.86 73 | 82.72 74 | 82.34 75 | 82.14 76 | 81.72 77 | 82.30 78 | 82.22 79 | 81.92 80 | 82.60 81 | 82.40 82 | 82.06 83 | 82.14 84 | 82.14 85 | 82.42 86 | 82.44 87 | 82.28 88 | 82.42 89 | 82.50 90 | 82.48 91 | 82.54 92 | 82.58 93 | 83.08 94 | 82.52 95 | 82.64 96 | 82.36 97 | 82.56 98 | 82.40 99 | 82.40 100 | 82.94 101 | -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_2/Epochs_vs_Loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_2/Epochs_vs_Loss.png -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_2/Epochs_vs_Validation Accuracy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_2/Epochs_vs_Validation Accuracy.png -------------------------------------------------------------------------------- 
/output/CIFAR10/resnet18/ENT_1/episode_2/Iterations_vs_Loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_2/Iterations_vs_Loss.png -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_2/activeSet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_2/activeSet.npy -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_2/lSet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_2/lSet.npy -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_2/plot_epoch_xvalues.txt: -------------------------------------------------------------------------------- 1 | 1.00 2 | 2.00 3 | 3.00 4 | 4.00 5 | 5.00 6 | 6.00 7 | 7.00 8 | 8.00 9 | 9.00 10 | 10.00 11 | 11.00 12 | 12.00 13 | 13.00 14 | 14.00 15 | 15.00 16 | 16.00 17 | 17.00 18 | 18.00 19 | 19.00 20 | 20.00 21 | 21.00 22 | 22.00 23 | 23.00 24 | 24.00 25 | 25.00 26 | 26.00 27 | 27.00 28 | 28.00 29 | 29.00 30 | 30.00 31 | 31.00 32 | 32.00 33 | 33.00 34 | 34.00 35 | 35.00 36 | 36.00 37 | 37.00 38 | 38.00 39 | 39.00 40 | 40.00 41 | 41.00 42 | 42.00 43 | 43.00 44 | 44.00 45 | 45.00 46 | 46.00 47 | 47.00 48 | 48.00 49 | 49.00 50 | 50.00 51 | 51.00 52 | 52.00 53 | 53.00 54 | 54.00 55 | 55.00 56 | 56.00 57 | 57.00 58 | 58.00 59 | 59.00 60 | 60.00 61 | 61.00 62 | 62.00 63 | 63.00 64 | 64.00 65 | 65.00 66 
| 66.00 67 | 67.00 68 | 68.00 69 | 69.00 70 | 70.00 71 | 71.00 72 | 72.00 73 | 73.00 74 | 74.00 75 | 75.00 76 | 76.00 77 | 77.00 78 | 78.00 79 | 79.00 80 | 80.00 81 | 81.00 82 | 82.00 83 | 83.00 84 | 84.00 85 | 85.00 86 | 86.00 87 | 87.00 88 | 88.00 89 | 89.00 90 | 90.00 91 | 91.00 92 | 92.00 93 | 93.00 94 | 94.00 95 | 95.00 96 | 96.00 97 | 97.00 98 | 98.00 99 | 99.00 100 | 100.00 101 | 101.00 102 | 102.00 103 | 103.00 104 | 104.00 105 | 105.00 106 | 106.00 107 | 107.00 108 | 108.00 109 | 109.00 110 | 110.00 111 | 111.00 112 | 112.00 113 | 113.00 114 | 114.00 115 | 115.00 116 | 116.00 117 | 117.00 118 | 118.00 119 | 119.00 120 | 120.00 121 | 121.00 122 | 122.00 123 | 123.00 124 | 124.00 125 | 125.00 126 | 126.00 127 | 127.00 128 | 128.00 129 | 129.00 130 | 130.00 131 | 131.00 132 | 132.00 133 | 133.00 134 | 134.00 135 | 135.00 136 | 136.00 137 | 137.00 138 | 138.00 139 | 139.00 140 | 140.00 141 | 141.00 142 | 142.00 143 | 143.00 144 | 144.00 145 | 145.00 146 | 146.00 147 | 147.00 148 | 148.00 149 | 149.00 150 | 150.00 151 | 151.00 152 | 152.00 153 | 153.00 154 | 154.00 155 | 155.00 156 | 156.00 157 | 157.00 158 | 158.00 159 | 159.00 160 | 160.00 161 | 161.00 162 | 162.00 163 | 163.00 164 | 164.00 165 | 165.00 166 | 166.00 167 | 167.00 168 | 168.00 169 | 169.00 170 | 170.00 171 | 171.00 172 | 172.00 173 | 173.00 174 | 174.00 175 | 175.00 176 | 176.00 177 | 177.00 178 | 178.00 179 | 179.00 180 | 180.00 181 | 181.00 182 | 182.00 183 | 183.00 184 | 184.00 185 | 185.00 186 | 186.00 187 | 187.00 188 | 188.00 189 | 189.00 190 | 190.00 191 | 191.00 192 | 192.00 193 | 193.00 194 | 194.00 195 | 195.00 196 | 196.00 197 | 197.00 198 | 198.00 199 | 199.00 200 | 200.00 201 | -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_2/plot_epoch_yvalues.txt: -------------------------------------------------------------------------------- 1 | 1.36 2 | 0.50 3 | 0.75 4 | 0.53 5 | 1.11 6 | 0.36 7 | 0.36 8 | 0.46 9 | 0.80 
10 | 0.48 11 | 0.41 12 | 0.88 13 | 0.46 14 | 0.53 15 | 0.32 16 | 0.63 17 | 0.64 18 | 0.27 19 | 0.40 20 | 0.79 21 | 0.55 22 | 0.18 23 | 0.48 24 | 0.49 25 | 0.49 26 | 0.63 27 | 0.49 28 | 0.19 29 | 0.90 30 | 0.56 31 | 0.34 32 | 0.17 33 | 0.35 34 | 0.23 35 | 0.14 36 | 0.10 37 | 0.46 38 | 0.21 39 | 0.27 40 | 0.23 41 | 0.47 42 | 0.10 43 | 0.81 44 | 0.09 45 | 0.24 46 | 0.46 47 | 0.12 48 | 0.22 49 | 0.03 50 | 0.04 51 | 0.51 52 | 0.20 53 | 0.06 54 | 0.80 55 | 0.13 56 | 0.15 57 | 0.19 58 | 0.10 59 | 0.06 60 | 0.25 61 | 0.22 62 | 0.04 63 | 0.50 64 | 0.10 65 | 0.07 66 | 0.27 67 | 0.04 68 | 0.13 69 | 0.24 70 | 0.02 71 | 0.09 72 | 0.14 73 | 0.27 74 | 0.09 75 | 0.09 76 | 0.36 77 | 0.15 78 | 0.14 79 | 0.27 80 | 0.17 81 | 0.02 82 | 0.35 83 | 0.21 84 | 0.22 85 | 0.02 86 | 0.14 87 | 0.07 88 | 0.05 89 | 0.23 90 | 0.11 91 | 0.02 92 | 0.02 93 | 0.51 94 | 0.05 95 | 0.12 96 | 0.31 97 | 0.06 98 | 0.00 99 | 0.23 100 | 0.10 101 | 0.13 102 | 0.14 103 | 0.02 104 | 0.03 105 | 0.16 106 | 0.01 107 | 0.02 108 | 0.11 109 | 0.00 110 | 0.29 111 | 0.32 112 | 0.00 113 | 0.03 114 | 0.16 115 | 0.02 116 | 0.01 117 | 0.02 118 | 0.07 119 | 0.04 120 | 0.00 121 | 0.11 122 | 0.01 123 | 0.03 124 | 0.00 125 | 0.10 126 | 0.03 127 | 0.03 128 | 0.02 129 | 0.01 130 | 0.01 131 | 0.02 132 | 0.00 133 | 0.00 134 | 0.03 135 | 0.10 136 | 0.00 137 | 0.00 138 | 0.01 139 | 0.01 140 | 0.03 141 | 0.01 142 | 0.00 143 | 0.00 144 | 0.09 145 | 0.09 146 | 0.00 147 | 0.01 148 | 0.00 149 | 0.03 150 | 0.01 151 | 0.01 152 | 0.00 153 | 0.00 154 | 0.02 155 | 0.00 156 | 0.00 157 | 0.00 158 | 0.01 159 | 0.00 160 | 0.00 161 | 0.09 162 | 0.01 163 | 0.12 164 | 0.02 165 | 0.01 166 | 0.00 167 | 0.00 168 | 0.01 169 | 0.00 170 | 0.00 171 | 0.02 172 | 0.04 173 | 0.00 174 | 0.01 175 | 0.01 176 | 0.08 177 | 0.00 178 | 0.00 179 | 0.00 180 | 0.00 181 | 0.01 182 | 0.00 183 | 0.01 184 | 0.00 185 | 0.00 186 | 0.19 187 | 0.00 188 | 0.00 189 | 0.00 190 | 0.00 191 | 0.00 192 | 0.03 193 | 0.00 194 | 0.00 195 | 0.08 196 | 0.02 197 | 0.01 198 | 0.00 199 | 0.00 
200 | 0.00 201 | -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_2/uSet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_2/uSet.npy -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_2/val_acc_epochs_x.txt: -------------------------------------------------------------------------------- 1 | 2.00 2 | 4.00 3 | 6.00 4 | 8.00 5 | 10.00 6 | 12.00 7 | 14.00 8 | 16.00 9 | 18.00 10 | 20.00 11 | 22.00 12 | 24.00 13 | 26.00 14 | 28.00 15 | 30.00 16 | 32.00 17 | 34.00 18 | 36.00 19 | 38.00 20 | 40.00 21 | 42.00 22 | 44.00 23 | 46.00 24 | 48.00 25 | 50.00 26 | 52.00 27 | 54.00 28 | 56.00 29 | 58.00 30 | 60.00 31 | 62.00 32 | 64.00 33 | 66.00 34 | 68.00 35 | 70.00 36 | 72.00 37 | 74.00 38 | 76.00 39 | 78.00 40 | 80.00 41 | 82.00 42 | 84.00 43 | 86.00 44 | 88.00 45 | 90.00 46 | 92.00 47 | 94.00 48 | 96.00 49 | 98.00 50 | 100.00 51 | 102.00 52 | 104.00 53 | 106.00 54 | 108.00 55 | 110.00 56 | 112.00 57 | 114.00 58 | 116.00 59 | 118.00 60 | 120.00 61 | 122.00 62 | 124.00 63 | 126.00 64 | 128.00 65 | 130.00 66 | 132.00 67 | 134.00 68 | 136.00 69 | 138.00 70 | 140.00 71 | 142.00 72 | 144.00 73 | 146.00 74 | 148.00 75 | 150.00 76 | 152.00 77 | 154.00 78 | 156.00 79 | 158.00 80 | 160.00 81 | 162.00 82 | 164.00 83 | 166.00 84 | 168.00 85 | 170.00 86 | 172.00 87 | 174.00 88 | 176.00 89 | 178.00 90 | 180.00 91 | 182.00 92 | 184.00 93 | 186.00 94 | 188.00 95 | 190.00 96 | 192.00 97 | 194.00 98 | 196.00 99 | 198.00 100 | 200.00 101 | -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_2/val_acc_epochs_y.txt: -------------------------------------------------------------------------------- 1 | 77.78 2 | 
80.22 3 | 78.34 4 | 80.14 5 | 79.92 6 | 80.06 7 | 80.78 8 | 77.90 9 | 80.30 10 | 80.96 11 | 82.00 12 | 81.84 13 | 82.94 14 | 81.58 15 | 82.50 16 | 82.64 17 | 81.78 18 | 83.20 19 | 81.72 20 | 81.64 21 | 82.74 22 | 82.30 23 | 83.54 24 | 81.02 25 | 82.96 26 | 82.56 27 | 83.24 28 | 83.30 29 | 80.80 30 | 82.10 31 | 84.18 32 | 83.28 33 | 83.88 34 | 82.38 35 | 83.24 36 | 83.96 37 | 81.72 38 | 83.54 39 | 84.40 40 | 84.92 41 | 84.18 42 | 84.34 43 | 83.84 44 | 83.78 45 | 83.58 46 | 85.40 47 | 83.40 48 | 84.38 49 | 84.78 50 | 85.04 51 | 85.26 52 | 85.20 53 | 84.82 54 | 85.78 55 | 85.18 56 | 85.20 57 | 85.78 58 | 85.78 59 | 85.76 60 | 84.58 61 | 85.76 62 | 86.20 63 | 86.08 64 | 86.20 65 | 86.28 66 | 86.46 67 | 86.72 68 | 86.28 69 | 85.92 70 | 86.54 71 | 87.10 72 | 86.26 73 | 87.16 74 | 86.50 75 | 87.16 76 | 87.20 77 | 87.22 78 | 87.08 79 | 87.42 80 | 87.38 81 | 87.08 82 | 86.90 83 | 87.28 84 | 87.14 85 | 87.34 86 | 87.14 87 | 87.30 88 | 87.36 89 | 87.06 90 | 86.88 91 | 87.44 92 | 87.12 93 | 87.30 94 | 87.88 95 | 87.66 96 | 87.40 97 | 86.96 98 | 87.58 99 | 87.12 100 | 87.82 101 | -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_3/Epochs_vs_Loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_3/Epochs_vs_Loss.png -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_3/Epochs_vs_Validation Accuracy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_3/Epochs_vs_Validation Accuracy.png -------------------------------------------------------------------------------- 
/output/CIFAR10/resnet18/ENT_1/episode_3/Iterations_vs_Loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_3/Iterations_vs_Loss.png -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_3/activeSet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_3/activeSet.npy -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_3/lSet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_3/lSet.npy -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_3/plot_epoch_xvalues.txt: -------------------------------------------------------------------------------- 1 | 1.00 2 | 2.00 3 | 3.00 4 | 4.00 5 | 5.00 6 | 6.00 7 | 7.00 8 | 8.00 9 | 9.00 10 | 10.00 11 | 11.00 12 | 12.00 13 | 13.00 14 | 14.00 15 | 15.00 16 | 16.00 17 | 17.00 18 | 18.00 19 | 19.00 20 | 20.00 21 | 21.00 22 | 22.00 23 | 23.00 24 | 24.00 25 | 25.00 26 | 26.00 27 | 27.00 28 | 28.00 29 | 29.00 30 | 30.00 31 | 31.00 32 | 32.00 33 | 33.00 34 | 34.00 35 | 35.00 36 | 36.00 37 | 37.00 38 | 38.00 39 | 39.00 40 | 40.00 41 | 41.00 42 | 42.00 43 | 43.00 44 | 44.00 45 | 45.00 46 | 46.00 47 | 47.00 48 | 48.00 49 | 49.00 50 | 50.00 51 | 51.00 52 | 52.00 53 | 53.00 54 | 54.00 55 | 55.00 56 | 56.00 57 | 57.00 58 | 58.00 59 | 59.00 60 | 60.00 61 | 61.00 62 | 62.00 63 | 63.00 64 | 64.00 65 | 65.00 66 
| 66.00 67 | 67.00 68 | 68.00 69 | 69.00 70 | 70.00 71 | 71.00 72 | 72.00 73 | 73.00 74 | 74.00 75 | 75.00 76 | 76.00 77 | 77.00 78 | 78.00 79 | 79.00 80 | 80.00 81 | 81.00 82 | 82.00 83 | 83.00 84 | 84.00 85 | 85.00 86 | 86.00 87 | 87.00 88 | 88.00 89 | 89.00 90 | 90.00 91 | 91.00 92 | 92.00 93 | 93.00 94 | 94.00 95 | 95.00 96 | 96.00 97 | 97.00 98 | 98.00 99 | 99.00 100 | 100.00 101 | 101.00 102 | 102.00 103 | 103.00 104 | 104.00 105 | 105.00 106 | 106.00 107 | 107.00 108 | 108.00 109 | 109.00 110 | 110.00 111 | 111.00 112 | 112.00 113 | 113.00 114 | 114.00 115 | 115.00 116 | 116.00 117 | 117.00 118 | 118.00 119 | 119.00 120 | 120.00 121 | 121.00 122 | 122.00 123 | 123.00 124 | 124.00 125 | 125.00 126 | 126.00 127 | 127.00 128 | 128.00 129 | 129.00 130 | 130.00 131 | 131.00 132 | 132.00 133 | 133.00 134 | 134.00 135 | 135.00 136 | 136.00 137 | 137.00 138 | 138.00 139 | 139.00 140 | 140.00 141 | 141.00 142 | 142.00 143 | 143.00 144 | 144.00 145 | 145.00 146 | 146.00 147 | 147.00 148 | 148.00 149 | 149.00 150 | 150.00 151 | 151.00 152 | 152.00 153 | 153.00 154 | 154.00 155 | 155.00 156 | 156.00 157 | 157.00 158 | 158.00 159 | 159.00 160 | 160.00 161 | 161.00 162 | 162.00 163 | 163.00 164 | 164.00 165 | 165.00 166 | 166.00 167 | 167.00 168 | 168.00 169 | 169.00 170 | 170.00 171 | 171.00 172 | 172.00 173 | 173.00 174 | 174.00 175 | 175.00 176 | 176.00 177 | 177.00 178 | 178.00 179 | 179.00 180 | 180.00 181 | 181.00 182 | 182.00 183 | 183.00 184 | 184.00 185 | 185.00 186 | 186.00 187 | 187.00 188 | 188.00 189 | 189.00 190 | 190.00 191 | 191.00 192 | 192.00 193 | 193.00 194 | 194.00 195 | 195.00 196 | 196.00 197 | 197.00 198 | 198.00 199 | 199.00 200 | 200.00 201 | -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_3/plot_epoch_yvalues.txt: -------------------------------------------------------------------------------- 1 | 0.43 2 | 0.54 3 | 0.60 4 | 0.36 5 | 0.66 6 | 0.30 7 | 0.21 8 | 0.75 9 | 0.14 
10 | 0.27 11 | 0.59 12 | 0.28 13 | 0.31 14 | 0.86 15 | 0.44 16 | 0.59 17 | 0.23 18 | 0.27 19 | 0.56 20 | 0.45 21 | 0.25 22 | 0.39 23 | 0.47 24 | 0.69 25 | 0.29 26 | 0.61 27 | 0.43 28 | 0.34 29 | 0.30 30 | 0.19 31 | 0.18 32 | 0.15 33 | 0.27 34 | 0.23 35 | 0.23 36 | 0.46 37 | 0.12 38 | 0.02 39 | 0.21 40 | 0.71 41 | 0.39 42 | 0.40 43 | 0.28 44 | 0.09 45 | 0.12 46 | 0.20 47 | 0.17 48 | 0.06 49 | 0.33 50 | 0.37 51 | 0.56 52 | 0.36 53 | 0.38 54 | 0.15 55 | 0.13 56 | 0.33 57 | 0.28 58 | 0.29 59 | 0.19 60 | 0.24 61 | 0.04 62 | 0.04 63 | 0.04 64 | 0.25 65 | 0.02 66 | 0.21 67 | 0.14 68 | 0.04 69 | 0.28 70 | 0.06 71 | 0.27 72 | 0.12 73 | 0.22 74 | 0.24 75 | 0.35 76 | 0.08 77 | 0.08 78 | 0.12 79 | 0.06 80 | 0.07 81 | 0.04 82 | 0.18 83 | 0.03 84 | 0.03 85 | 0.28 86 | 0.23 87 | 0.03 88 | 0.03 89 | 0.10 90 | 0.23 91 | 0.09 92 | 0.09 93 | 0.32 94 | 0.04 95 | 0.03 96 | 0.04 97 | 0.03 98 | 0.30 99 | 0.02 100 | 0.09 101 | 0.09 102 | 0.04 103 | 0.04 104 | 0.09 105 | 0.23 106 | 0.02 107 | 0.10 108 | 0.19 109 | 0.23 110 | 0.00 111 | 0.02 112 | 0.08 113 | 0.01 114 | 0.00 115 | 0.00 116 | 0.03 117 | 0.05 118 | 0.01 119 | 0.06 120 | 0.00 121 | 0.01 122 | 0.13 123 | 0.01 124 | 0.01 125 | 0.03 126 | 0.00 127 | 0.00 128 | 0.00 129 | 0.10 130 | 0.02 131 | 0.01 132 | 0.01 133 | 0.02 134 | 0.02 135 | 0.00 136 | 0.02 137 | 0.01 138 | 0.01 139 | 0.00 140 | 0.01 141 | 0.07 142 | 0.05 143 | 0.01 144 | 0.01 145 | 0.01 146 | 0.09 147 | 0.05 148 | 0.01 149 | 0.01 150 | 0.00 151 | 0.01 152 | 0.01 153 | 0.00 154 | 0.00 155 | 0.05 156 | 0.01 157 | 0.00 158 | 0.03 159 | 0.00 160 | 0.06 161 | 0.00 162 | 0.00 163 | 0.00 164 | 0.00 165 | 0.00 166 | 0.00 167 | 0.07 168 | 0.00 169 | 0.00 170 | 0.00 171 | 0.00 172 | 0.00 173 | 0.00 174 | 0.00 175 | 0.00 176 | 0.00 177 | 0.00 178 | 0.00 179 | 0.01 180 | 0.00 181 | 0.00 182 | 0.07 183 | 0.00 184 | 0.00 185 | 0.00 186 | 0.00 187 | 0.00 188 | 0.00 189 | 0.01 190 | 0.00 191 | 0.00 192 | 0.00 193 | 0.00 194 | 0.00 195 | 0.01 196 | 0.00 197 | 0.00 198 | 0.00 199 | 0.00 
200 | 0.00 201 | -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_3/uSet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_3/uSet.npy -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_3/val_acc_epochs_x.txt: -------------------------------------------------------------------------------- 1 | 2.00 2 | 4.00 3 | 6.00 4 | 8.00 5 | 10.00 6 | 12.00 7 | 14.00 8 | 16.00 9 | 18.00 10 | 20.00 11 | 22.00 12 | 24.00 13 | 26.00 14 | 28.00 15 | 30.00 16 | 32.00 17 | 34.00 18 | 36.00 19 | 38.00 20 | 40.00 21 | 42.00 22 | 44.00 23 | 46.00 24 | 48.00 25 | 50.00 26 | 52.00 27 | 54.00 28 | 56.00 29 | 58.00 30 | 60.00 31 | 62.00 32 | 64.00 33 | 66.00 34 | 68.00 35 | 70.00 36 | 72.00 37 | 74.00 38 | 76.00 39 | 78.00 40 | 80.00 41 | 82.00 42 | 84.00 43 | 86.00 44 | 88.00 45 | 90.00 46 | 92.00 47 | 94.00 48 | 96.00 49 | 98.00 50 | 100.00 51 | 102.00 52 | 104.00 53 | 106.00 54 | 108.00 55 | 110.00 56 | 112.00 57 | 114.00 58 | 116.00 59 | 118.00 60 | 120.00 61 | 122.00 62 | 124.00 63 | 126.00 64 | 128.00 65 | 130.00 66 | 132.00 67 | 134.00 68 | 136.00 69 | 138.00 70 | 140.00 71 | 142.00 72 | 144.00 73 | 146.00 74 | 148.00 75 | 150.00 76 | 152.00 77 | 154.00 78 | 156.00 79 | 158.00 80 | 160.00 81 | 162.00 82 | 164.00 83 | 166.00 84 | 168.00 85 | 170.00 86 | 172.00 87 | 174.00 88 | 176.00 89 | 178.00 90 | 180.00 91 | 182.00 92 | 184.00 93 | 186.00 94 | 188.00 95 | 190.00 96 | 192.00 97 | 194.00 98 | 196.00 99 | 198.00 100 | 200.00 101 | -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_3/val_acc_epochs_y.txt: -------------------------------------------------------------------------------- 1 | 80.46 2 | 
82.38 3 | 83.24 4 | 84.12 5 | 82.84 6 | 81.04 7 | 84.58 8 | 84.36 9 | 83.66 10 | 83.64 11 | 83.24 12 | 84.52 13 | 84.14 14 | 82.96 15 | 82.40 16 | 84.62 17 | 84.44 18 | 86.14 19 | 84.90 20 | 85.64 21 | 86.28 22 | 85.44 23 | 86.80 24 | 84.58 25 | 85.36 26 | 85.48 27 | 85.54 28 | 86.56 29 | 85.90 30 | 85.20 31 | 86.02 32 | 85.40 33 | 85.56 34 | 86.36 35 | 84.94 36 | 85.76 37 | 85.46 38 | 87.02 39 | 86.40 40 | 86.78 41 | 86.06 42 | 86.22 43 | 87.22 44 | 85.86 45 | 86.96 46 | 86.54 47 | 87.20 48 | 87.56 49 | 87.14 50 | 86.30 51 | 86.64 52 | 87.96 53 | 88.24 54 | 87.42 55 | 88.38 56 | 88.02 57 | 88.32 58 | 87.74 59 | 87.88 60 | 88.74 61 | 87.96 62 | 88.48 63 | 88.76 64 | 88.54 65 | 88.56 66 | 89.30 67 | 89.00 68 | 88.82 69 | 89.40 70 | 88.82 71 | 89.32 72 | 89.26 73 | 89.34 74 | 89.44 75 | 89.32 76 | 89.68 77 | 89.88 78 | 89.92 79 | 89.92 80 | 89.68 81 | 90.04 82 | 90.30 83 | 89.84 84 | 90.22 85 | 89.60 86 | 89.80 87 | 90.16 88 | 89.90 89 | 89.76 90 | 90.10 91 | 90.18 92 | 89.98 93 | 90.10 94 | 90.48 95 | 89.94 96 | 89.88 97 | 89.72 98 | 89.86 99 | 89.90 100 | 90.24 101 | -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_4/Epochs_vs_Loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_4/Epochs_vs_Loss.png -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_4/Epochs_vs_Validation Accuracy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_4/Epochs_vs_Validation Accuracy.png -------------------------------------------------------------------------------- 
/output/CIFAR10/resnet18/ENT_1/episode_4/Iterations_vs_Loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_4/Iterations_vs_Loss.png -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_4/activeSet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_4/activeSet.npy -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_4/lSet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_4/lSet.npy -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_4/plot_epoch_xvalues.txt: -------------------------------------------------------------------------------- 1 | 1.00 2 | 2.00 3 | 3.00 4 | 4.00 5 | 5.00 6 | 6.00 7 | 7.00 8 | 8.00 9 | 9.00 10 | 10.00 11 | 11.00 12 | 12.00 13 | 13.00 14 | 14.00 15 | 15.00 16 | 16.00 17 | 17.00 18 | 18.00 19 | 19.00 20 | 20.00 21 | 21.00 22 | 22.00 23 | 23.00 24 | 24.00 25 | 25.00 26 | 26.00 27 | 27.00 28 | 28.00 29 | 29.00 30 | 30.00 31 | 31.00 32 | 32.00 33 | 33.00 34 | 34.00 35 | 35.00 36 | 36.00 37 | 37.00 38 | 38.00 39 | 39.00 40 | 40.00 41 | 41.00 42 | 42.00 43 | 43.00 44 | 44.00 45 | 45.00 46 | 46.00 47 | 47.00 48 | 48.00 49 | 49.00 50 | 50.00 51 | 51.00 52 | 52.00 53 | 53.00 54 | 54.00 55 | 55.00 56 | 56.00 57 | 57.00 58 | 58.00 59 | 59.00 60 | 60.00 61 | 61.00 62 | 62.00 63 | 63.00 64 | 64.00 65 | 65.00 66 
| 66.00 67 | 67.00 68 | 68.00 69 | 69.00 70 | 70.00 71 | 71.00 72 | 72.00 73 | 73.00 74 | 74.00 75 | 75.00 76 | 76.00 77 | 77.00 78 | 78.00 79 | 79.00 80 | 80.00 81 | 81.00 82 | 82.00 83 | 83.00 84 | 84.00 85 | 85.00 86 | 86.00 87 | 87.00 88 | 88.00 89 | 89.00 90 | 90.00 91 | 91.00 92 | 92.00 93 | 93.00 94 | 94.00 95 | 95.00 96 | 96.00 97 | 97.00 98 | 98.00 99 | 99.00 100 | 100.00 101 | 101.00 102 | 102.00 103 | 103.00 104 | 104.00 105 | 105.00 106 | 106.00 107 | 107.00 108 | 108.00 109 | 109.00 110 | 110.00 111 | 111.00 112 | 112.00 113 | 113.00 114 | 114.00 115 | 115.00 116 | 116.00 117 | 117.00 118 | 118.00 119 | 119.00 120 | 120.00 121 | 121.00 122 | 122.00 123 | 123.00 124 | 124.00 125 | 125.00 126 | 126.00 127 | 127.00 128 | 128.00 129 | 129.00 130 | 130.00 131 | 131.00 132 | 132.00 133 | 133.00 134 | 134.00 135 | 135.00 136 | 136.00 137 | 137.00 138 | 138.00 139 | 139.00 140 | 140.00 141 | 141.00 142 | 142.00 143 | 143.00 144 | 144.00 145 | 145.00 146 | 146.00 147 | 147.00 148 | 148.00 149 | 149.00 150 | 150.00 151 | 151.00 152 | 152.00 153 | 153.00 154 | 154.00 155 | 155.00 156 | 156.00 157 | 157.00 158 | 158.00 159 | 159.00 160 | 160.00 161 | 161.00 162 | 162.00 163 | 163.00 164 | 164.00 165 | 165.00 166 | 166.00 167 | 167.00 168 | 168.00 169 | 169.00 170 | 170.00 171 | 171.00 172 | 172.00 173 | 173.00 174 | 174.00 175 | 175.00 176 | 176.00 177 | 177.00 178 | 178.00 179 | 179.00 180 | 180.00 181 | 181.00 182 | 182.00 183 | 183.00 184 | 184.00 185 | 185.00 186 | 186.00 187 | 187.00 188 | 188.00 189 | 189.00 190 | 190.00 191 | 191.00 192 | 192.00 193 | 193.00 194 | 194.00 195 | 195.00 196 | 196.00 197 | 197.00 198 | 198.00 199 | 199.00 200 | 200.00 201 | -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_4/plot_epoch_yvalues.txt: -------------------------------------------------------------------------------- 1 | 0.13 2 | 0.58 3 | 0.27 4 | 0.22 5 | 0.38 6 | 0.47 7 | 0.19 8 | 0.28 9 | 0.46 
10 | 0.45 11 | 0.23 12 | 0.38 13 | 0.23 14 | 0.43 15 | 0.28 16 | 0.60 17 | 0.17 18 | 0.31 19 | 0.32 20 | 0.21 21 | 0.13 22 | 0.26 23 | 0.27 24 | 0.25 25 | 0.18 26 | 0.20 27 | 0.20 28 | 0.07 29 | 0.13 30 | 0.18 31 | 0.12 32 | 0.20 33 | 0.30 34 | 0.17 35 | 0.15 36 | 0.27 37 | 0.11 38 | 0.12 39 | 0.16 40 | 0.07 41 | 0.13 42 | 0.13 43 | 0.30 44 | 0.08 45 | 0.21 46 | 0.16 47 | 0.39 48 | 0.32 49 | 0.29 50 | 0.10 51 | 0.28 52 | 0.18 53 | 0.05 54 | 0.17 55 | 0.12 56 | 0.23 57 | 0.17 58 | 0.14 59 | 0.17 60 | 0.30 61 | 0.26 62 | 0.06 63 | 0.21 64 | 0.20 65 | 0.26 66 | 0.10 67 | 0.08 68 | 0.22 69 | 0.16 70 | 0.15 71 | 0.02 72 | 0.10 73 | 0.18 74 | 0.25 75 | 0.24 76 | 0.07 77 | 0.03 78 | 0.21 79 | 0.17 80 | 0.03 81 | 0.24 82 | 0.16 83 | 0.16 84 | 0.28 85 | 0.17 86 | 0.13 87 | 0.02 88 | 0.04 89 | 0.17 90 | 0.05 91 | 0.37 92 | 0.02 93 | 0.18 94 | 0.20 95 | 0.03 96 | 0.20 97 | 0.08 98 | 0.03 99 | 0.02 100 | 0.13 101 | 0.13 102 | 0.21 103 | 0.06 104 | 0.08 105 | 0.01 106 | 0.29 107 | 0.01 108 | 0.04 109 | 0.09 110 | 0.07 111 | 0.01 112 | 0.14 113 | 0.11 114 | 0.04 115 | 0.17 116 | 0.01 117 | 0.08 118 | 0.02 119 | 0.09 120 | 0.06 121 | 0.01 122 | 0.09 123 | 0.02 124 | 0.00 125 | 0.01 126 | 0.03 127 | 0.00 128 | 0.08 129 | 0.06 130 | 0.00 131 | 0.13 132 | 0.00 133 | 0.01 134 | 0.03 135 | 0.00 136 | 0.00 137 | 0.00 138 | 0.05 139 | 0.00 140 | 0.00 141 | 0.02 142 | 0.01 143 | 0.00 144 | 0.00 145 | 0.00 146 | 0.00 147 | 0.04 148 | 0.00 149 | 0.01 150 | 0.00 151 | 0.00 152 | 0.02 153 | 0.00 154 | 0.02 155 | 0.01 156 | 0.00 157 | 0.00 158 | 0.04 159 | 0.00 160 | 0.00 161 | 0.01 162 | 0.00 163 | 0.00 164 | 0.01 165 | 0.00 166 | 0.01 167 | 0.00 168 | 0.00 169 | 0.03 170 | 0.00 171 | 0.00 172 | 0.00 173 | 0.00 174 | 0.00 175 | 0.00 176 | 0.00 177 | 0.00 178 | 0.00 179 | 0.00 180 | 0.00 181 | 0.00 182 | 0.00 183 | 0.01 184 | 0.00 185 | 0.03 186 | 0.00 187 | 0.00 188 | 0.00 189 | 0.00 190 | 0.00 191 | 0.00 192 | 0.00 193 | 0.00 194 | 0.01 195 | 0.00 196 | 0.00 197 | 0.00 198 | 0.00 199 | 0.00 
200 | 0.00 201 | -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_4/uSet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_4/uSet.npy -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_4/val_acc_epochs_x.txt: -------------------------------------------------------------------------------- 1 | 2.00 2 | 4.00 3 | 6.00 4 | 8.00 5 | 10.00 6 | 12.00 7 | 14.00 8 | 16.00 9 | 18.00 10 | 20.00 11 | 22.00 12 | 24.00 13 | 26.00 14 | 28.00 15 | 30.00 16 | 32.00 17 | 34.00 18 | 36.00 19 | 38.00 20 | 40.00 21 | 42.00 22 | 44.00 23 | 46.00 24 | 48.00 25 | 50.00 26 | 52.00 27 | 54.00 28 | 56.00 29 | 58.00 30 | 60.00 31 | 62.00 32 | 64.00 33 | 66.00 34 | 68.00 35 | 70.00 36 | 72.00 37 | 74.00 38 | 76.00 39 | 78.00 40 | 80.00 41 | 82.00 42 | 84.00 43 | 86.00 44 | 88.00 45 | 90.00 46 | 92.00 47 | 94.00 48 | 96.00 49 | 98.00 50 | 100.00 51 | 102.00 52 | 104.00 53 | 106.00 54 | 108.00 55 | 110.00 56 | 112.00 57 | 114.00 58 | 116.00 59 | 118.00 60 | 120.00 61 | 122.00 62 | 124.00 63 | 126.00 64 | 128.00 65 | 130.00 66 | 132.00 67 | 134.00 68 | 136.00 69 | 138.00 70 | 140.00 71 | 142.00 72 | 144.00 73 | 146.00 74 | 148.00 75 | 150.00 76 | 152.00 77 | 154.00 78 | 156.00 79 | 158.00 80 | 160.00 81 | 162.00 82 | 164.00 83 | 166.00 84 | 168.00 85 | 170.00 86 | 172.00 87 | 174.00 88 | 176.00 89 | 178.00 90 | 180.00 91 | 182.00 92 | 184.00 93 | 186.00 94 | 188.00 95 | 190.00 96 | 192.00 97 | 194.00 98 | 196.00 99 | 198.00 100 | 200.00 101 | -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_4/val_acc_epochs_y.txt: -------------------------------------------------------------------------------- 1 | 81.86 2 | 
84.10 3 | 81.82 4 | 85.30 5 | 85.54 6 | 84.24 7 | 85.78 8 | 86.04 9 | 86.44 10 | 85.42 11 | 87.06 12 | 86.06 13 | 87.04 14 | 86.12 15 | 85.04 16 | 85.08 17 | 86.92 18 | 87.82 19 | 86.64 20 | 85.22 21 | 85.92 22 | 85.52 23 | 86.30 24 | 85.42 25 | 85.64 26 | 86.56 27 | 85.56 28 | 87.52 29 | 86.88 30 | 85.96 31 | 86.52 32 | 87.48 33 | 86.80 34 | 87.54 35 | 86.28 36 | 86.88 37 | 88.10 38 | 87.30 39 | 87.62 40 | 88.20 41 | 88.22 42 | 88.72 43 | 88.20 44 | 87.90 45 | 88.06 46 | 87.84 47 | 87.64 48 | 87.98 49 | 87.62 50 | 88.58 51 | 88.00 52 | 88.76 53 | 87.86 54 | 88.48 55 | 88.84 56 | 88.56 57 | 89.58 58 | 88.88 59 | 89.50 60 | 89.96 61 | 89.38 62 | 89.18 63 | 89.74 64 | 90.00 65 | 89.96 66 | 89.44 67 | 90.40 68 | 89.88 69 | 90.08 70 | 90.48 71 | 90.22 72 | 90.06 73 | 90.40 74 | 89.94 75 | 90.52 76 | 90.60 77 | 90.24 78 | 90.68 79 | 90.88 80 | 90.82 81 | 90.34 82 | 90.56 83 | 90.98 84 | 90.62 85 | 90.98 86 | 90.46 87 | 91.04 88 | 90.92 89 | 90.88 90 | 90.92 91 | 90.72 92 | 91.04 93 | 90.76 94 | 90.72 95 | 90.74 96 | 90.48 97 | 90.74 98 | 90.76 99 | 90.72 100 | 90.80 101 | -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_5/Epochs_vs_Loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_5/Epochs_vs_Loss.png -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_5/Epochs_vs_Validation Accuracy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_5/Epochs_vs_Validation Accuracy.png -------------------------------------------------------------------------------- 
/output/CIFAR10/resnet18/ENT_1/episode_5/Iterations_vs_Loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_5/Iterations_vs_Loss.png -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_5/plot_epoch_xvalues.txt: -------------------------------------------------------------------------------- 1 | 1.00 2 | 2.00 3 | 3.00 4 | 4.00 5 | 5.00 6 | 6.00 7 | 7.00 8 | 8.00 9 | 9.00 10 | 10.00 11 | 11.00 12 | 12.00 13 | 13.00 14 | 14.00 15 | 15.00 16 | 16.00 17 | 17.00 18 | 18.00 19 | 19.00 20 | 20.00 21 | 21.00 22 | 22.00 23 | 23.00 24 | 24.00 25 | 25.00 26 | 26.00 27 | 27.00 28 | 28.00 29 | 29.00 30 | 30.00 31 | 31.00 32 | 32.00 33 | 33.00 34 | 34.00 35 | 35.00 36 | 36.00 37 | 37.00 38 | 38.00 39 | 39.00 40 | 40.00 41 | 41.00 42 | 42.00 43 | 43.00 44 | 44.00 45 | 45.00 46 | 46.00 47 | 47.00 48 | 48.00 49 | 49.00 50 | 50.00 51 | 51.00 52 | 52.00 53 | 53.00 54 | 54.00 55 | 55.00 56 | 56.00 57 | 57.00 58 | 58.00 59 | 59.00 60 | 60.00 61 | 61.00 62 | 62.00 63 | 63.00 64 | 64.00 65 | 65.00 66 | 66.00 67 | 67.00 68 | 68.00 69 | 69.00 70 | 70.00 71 | 71.00 72 | 72.00 73 | 73.00 74 | 74.00 75 | 75.00 76 | 76.00 77 | 77.00 78 | 78.00 79 | 79.00 80 | 80.00 81 | 81.00 82 | 82.00 83 | 83.00 84 | 84.00 85 | 85.00 86 | 86.00 87 | 87.00 88 | 88.00 89 | 89.00 90 | 90.00 91 | 91.00 92 | 92.00 93 | 93.00 94 | 94.00 95 | 95.00 96 | 96.00 97 | 97.00 98 | 98.00 99 | 99.00 100 | 100.00 101 | 101.00 102 | 102.00 103 | 103.00 104 | 104.00 105 | 105.00 106 | 106.00 107 | 107.00 108 | 108.00 109 | 109.00 110 | 110.00 111 | 111.00 112 | 112.00 113 | 113.00 114 | 114.00 115 | 115.00 116 | 116.00 117 | 117.00 118 | 118.00 119 | 119.00 120 | 120.00 121 | 121.00 122 | 122.00 123 | 123.00 124 | 124.00 125 | 125.00 126 | 126.00 127 | 127.00 128 | 128.00 129 | 
129.00 130 | 130.00 131 | 131.00 132 | 132.00 133 | 133.00 134 | 134.00 135 | 135.00 136 | 136.00 137 | 137.00 138 | 138.00 139 | 139.00 140 | 140.00 141 | 141.00 142 | 142.00 143 | 143.00 144 | 144.00 145 | 145.00 146 | 146.00 147 | 147.00 148 | 148.00 149 | 149.00 150 | 150.00 151 | 151.00 152 | 152.00 153 | 153.00 154 | 154.00 155 | 155.00 156 | 156.00 157 | 157.00 158 | 158.00 159 | 159.00 160 | 160.00 161 | 161.00 162 | 162.00 163 | 163.00 164 | 164.00 165 | 165.00 166 | 166.00 167 | 167.00 168 | 168.00 169 | 169.00 170 | 170.00 171 | 171.00 172 | 172.00 173 | 173.00 174 | 174.00 175 | 175.00 176 | 176.00 177 | 177.00 178 | 178.00 179 | 179.00 180 | 180.00 181 | 181.00 182 | 182.00 183 | 183.00 184 | 184.00 185 | 185.00 186 | 186.00 187 | 187.00 188 | 188.00 189 | 189.00 190 | 190.00 191 | 191.00 192 | 192.00 193 | 193.00 194 | 194.00 195 | 195.00 196 | 196.00 197 | 197.00 198 | 198.00 199 | 199.00 200 | 200.00 201 | -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_5/plot_epoch_yvalues.txt: -------------------------------------------------------------------------------- 1 | 0.21 2 | 0.36 3 | 0.20 4 | 0.17 5 | 0.21 6 | 0.42 7 | 0.46 8 | 0.28 9 | 0.08 10 | 0.32 11 | 0.21 12 | 0.30 13 | 0.29 14 | 0.24 15 | 0.22 16 | 0.10 17 | 0.28 18 | 0.29 19 | 0.28 20 | 0.17 21 | 0.14 22 | 0.21 23 | 0.29 24 | 0.18 25 | 0.08 26 | 0.16 27 | 0.22 28 | 0.26 29 | 0.20 30 | 0.14 31 | 0.07 32 | 0.13 33 | 0.16 34 | 0.22 35 | 0.16 36 | 0.41 37 | 0.09 38 | 0.22 39 | 0.14 40 | 0.17 41 | 0.21 42 | 0.11 43 | 0.22 44 | 0.16 45 | 0.22 46 | 0.06 47 | 0.29 48 | 0.29 49 | 0.08 50 | 0.20 51 | 0.09 52 | 0.25 53 | 0.15 54 | 0.12 55 | 0.12 56 | 0.09 57 | 0.11 58 | 0.12 59 | 0.14 60 | 0.10 61 | 0.23 62 | 0.12 63 | 0.30 64 | 0.45 65 | 0.08 66 | 0.07 67 | 0.12 68 | 0.08 69 | 0.11 70 | 0.16 71 | 0.02 72 | 0.14 73 | 0.19 74 | 0.07 75 | 0.14 76 | 0.28 77 | 0.06 78 | 0.15 79 | 0.04 80 | 0.17 81 | 0.14 82 | 0.31 83 | 0.13 84 | 0.12 85 | 
0.13 86 | 0.13 87 | 0.08 88 | 0.03 89 | 0.37 90 | 0.10 91 | 0.07 92 | 0.03 93 | 0.02 94 | 0.09 95 | 0.17 96 | 0.03 97 | 0.30 98 | 0.04 99 | 0.06 100 | 0.11 101 | 0.02 102 | 0.07 103 | 0.18 104 | 0.14 105 | 0.35 106 | 0.01 107 | 0.06 108 | 0.07 109 | 0.10 110 | 0.03 111 | 0.06 112 | 0.04 113 | 0.00 114 | 0.06 115 | 0.11 116 | 0.10 117 | 0.04 118 | 0.09 119 | 0.02 120 | 0.04 121 | 0.03 122 | 0.08 123 | 0.02 124 | 0.02 125 | 0.01 126 | 0.09 127 | 0.02 128 | 0.03 129 | 0.07 130 | 0.03 131 | 0.01 132 | 0.07 133 | 0.00 134 | 0.00 135 | 0.00 136 | 0.00 137 | 0.00 138 | 0.05 139 | 0.00 140 | 0.01 141 | 0.00 142 | 0.00 143 | 0.00 144 | 0.00 145 | 0.00 146 | 0.00 147 | 0.00 148 | 0.00 149 | 0.00 150 | 0.02 151 | 0.01 152 | 0.00 153 | 0.00 154 | 0.00 155 | 0.00 156 | 0.00 157 | 0.00 158 | 0.00 159 | 0.00 160 | 0.00 161 | 0.00 162 | 0.00 163 | 0.07 164 | 0.00 165 | 0.00 166 | 0.00 167 | 0.00 168 | 0.00 169 | 0.00 170 | 0.00 171 | 0.00 172 | 0.00 173 | 0.00 174 | 0.00 175 | 0.00 176 | 0.00 177 | 0.01 178 | 0.00 179 | 0.00 180 | 0.00 181 | 0.00 182 | 0.07 183 | 0.00 184 | 0.00 185 | 0.01 186 | 0.00 187 | 0.00 188 | 0.01 189 | 0.00 190 | 0.00 191 | 0.00 192 | 0.00 193 | 0.00 194 | 0.00 195 | 0.00 196 | 0.00 197 | 0.00 198 | 0.00 199 | 0.00 200 | 0.00 201 | -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_5/val_acc_epochs_x.txt: -------------------------------------------------------------------------------- 1 | 2.00 2 | 4.00 3 | 6.00 4 | 8.00 5 | 10.00 6 | 12.00 7 | 14.00 8 | 16.00 9 | 18.00 10 | 20.00 11 | 22.00 12 | 24.00 13 | 26.00 14 | 28.00 15 | 30.00 16 | 32.00 17 | 34.00 18 | 36.00 19 | 38.00 20 | 40.00 21 | 42.00 22 | 44.00 23 | 46.00 24 | 48.00 25 | 50.00 26 | 52.00 27 | 54.00 28 | 56.00 29 | 58.00 30 | 60.00 31 | 62.00 32 | 64.00 33 | 66.00 34 | 68.00 35 | 70.00 36 | 72.00 37 | 74.00 38 | 76.00 39 | 78.00 40 | 80.00 41 | 82.00 42 | 84.00 43 | 86.00 44 | 88.00 45 | 90.00 46 | 92.00 47 | 94.00 48 | 
96.00 49 | 98.00 50 | 100.00 51 | 102.00 52 | 104.00 53 | 106.00 54 | 108.00 55 | 110.00 56 | 112.00 57 | 114.00 58 | 116.00 59 | 118.00 60 | 120.00 61 | 122.00 62 | 124.00 63 | 126.00 64 | 128.00 65 | 130.00 66 | 132.00 67 | 134.00 68 | 136.00 69 | 138.00 70 | 140.00 71 | 142.00 72 | 144.00 73 | 146.00 74 | 148.00 75 | 150.00 76 | 152.00 77 | 154.00 78 | 156.00 79 | 158.00 80 | 160.00 81 | 162.00 82 | 164.00 83 | 166.00 84 | 168.00 85 | 170.00 86 | 172.00 87 | 174.00 88 | 176.00 89 | 178.00 90 | 180.00 91 | 182.00 92 | 184.00 93 | 186.00 94 | 188.00 95 | 190.00 96 | 192.00 97 | 194.00 98 | 196.00 99 | 198.00 100 | 200.00 101 | -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/episode_5/val_acc_epochs_y.txt: -------------------------------------------------------------------------------- 1 | 82.96 2 | 82.74 3 | 85.02 4 | 85.02 5 | 86.20 6 | 85.22 7 | 86.16 8 | 86.50 9 | 86.38 10 | 85.42 11 | 85.34 12 | 85.38 13 | 86.98 14 | 86.28 15 | 86.26 16 | 86.72 17 | 87.10 18 | 86.54 19 | 86.48 20 | 88.04 21 | 86.72 22 | 86.68 23 | 86.16 24 | 86.72 25 | 87.52 26 | 86.62 27 | 86.62 28 | 85.92 29 | 88.34 30 | 87.86 31 | 88.50 32 | 86.96 33 | 87.66 34 | 88.32 35 | 87.88 36 | 87.74 37 | 88.20 38 | 88.44 39 | 88.48 40 | 88.24 41 | 87.46 42 | 87.92 43 | 87.68 44 | 87.90 45 | 88.70 46 | 87.58 47 | 87.90 48 | 89.60 49 | 87.96 50 | 88.80 51 | 89.28 52 | 88.92 53 | 89.86 54 | 89.46 55 | 88.32 56 | 89.36 57 | 89.46 58 | 89.76 59 | 89.20 60 | 89.48 61 | 89.76 62 | 89.48 63 | 89.70 64 | 89.42 65 | 90.00 66 | 89.44 67 | 90.42 68 | 89.90 69 | 90.22 70 | 90.32 71 | 90.74 72 | 91.00 73 | 91.04 74 | 90.80 75 | 90.84 76 | 90.92 77 | 90.74 78 | 91.36 79 | 90.92 80 | 91.18 81 | 91.50 82 | 91.26 83 | 91.18 84 | 91.14 85 | 91.56 86 | 91.60 87 | 91.80 88 | 91.14 89 | 91.28 90 | 91.34 91 | 91.02 92 | 91.48 93 | 91.10 94 | 91.50 95 | 91.26 96 | 91.34 97 | 91.38 98 | 91.34 99 | 91.62 100 | 91.50 101 | 
-------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/lSet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/lSet.npy -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/plot_episode_xvalues.txt: -------------------------------------------------------------------------------- 1 | 0.00 2 | 1.00 3 | 2.00 4 | 3.00 5 | 4.00 6 | 5.00 7 | -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/plot_episode_yvalues.txt: -------------------------------------------------------------------------------- 1 | 71.83 2 | 82.86 3 | 87.32 4 | 90.04 5 | 90.88 6 | 91.12 7 | -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/uSet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/uSet.npy -------------------------------------------------------------------------------- /output/CIFAR10/resnet18/ENT_1/valSet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/valSet.npy -------------------------------------------------------------------------------- /pycls/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/pycls/__init__.py 
# --------------------------------------------------------------------------------
# /pycls/al/ActiveLearning.py
# --------------------------------------------------------------------------------
# This file is slightly modified from a code implementation by Prateek Munjal et al., authors of the paper https://arxiv.org/abs/2002.09564
# GitHub: https://github.com/PrateekMunjal
# ----------------------------------------------------------

from contextlib import contextmanager

import numpy as np
import torch
from .Sampling import Sampling, CoreSetMIPSampling, AdversarySampler
import pycls.utils.logging as lu
import os

logger = lu.get_logger(__name__)


@contextmanager
def _eval_mode(model):
    """Put ``model`` in eval mode for the ``with`` body, then restore its previous mode.

    The restore runs in ``finally`` so a sampler that raises does not leave the
    classifier stuck in eval mode.
    """
    was_training = model.training
    model.eval()
    try:
        yield model
    finally:
        model.train(was_training)


class ActiveLearning:
    """
    Implements standard active learning methods.

    ``sample_from_uSet`` dispatches on ``cfg.ACTIVE_LEARNING.SAMPLING_FN`` to a
    query strategy and returns the selected indices together with the updated
    unlabelled set.
    """

    # Strategies implemented as same-named ``Sampling`` methods that score the
    # unlabelled set with the classifier in eval mode.
    _EVAL_MODE_FNS = ("uncertainty", "entropy", "margin")

    def __init__(self, dataObj, cfg):
        self.dataObj = dataObj
        self.sampler = Sampling(dataObj=dataObj, cfg=cfg)
        self.cfg = cfg

    def sample_from_uSet(self, clf_model, lSet, uSet, trainDataset, supportingModels=None):
        """
        Sample from uSet using cfg.ACTIVE_LEARNING.SAMPLING_FN.

        INPUT
        ------
        clf_model: Reference of task classifier model class [Typically VGG]

        lSet: Indices of the currently labelled set.

        uSet: Indices of the currently unlabelled set (the pool queried from).

        trainDataset: Dataset that lSet/uSet index into.

        supportingModels: List of models which are used for sampling process
            (only read by the "ensemble_var_R" strategy).

        OUTPUT
        -------
        Returns activeSet, uSet

        RAISES
        -------
        AssertionError if BUDGET_SIZE is not positive or exceeds len(uSet).
        NotImplementedError if SAMPLING_FN names no known strategy.
        """
        budget = self.cfg.ACTIVE_LEARNING.BUDGET_SIZE
        fn = self.cfg.ACTIVE_LEARNING.SAMPLING_FN

        assert budget > 0, "Expected a positive budgetSize"
        # "cannot exceed" => a budget equal to len(uSet) (label the whole pool) is valid.
        assert budget <= len(uSet), "BudgetSet cannot exceed length of unlabelled set. Length of unlabelled set: {} and budgetSize: {}"\
            .format(len(uSet), budget)

        if fn == "random":
            activeSet, uSet = self.sampler.random(uSet=uSet, budgetSize=budget)

        elif fn in self._EVAL_MODE_FNS:
            # uncertainty / entropy / margin share one call signature on Sampling.
            with _eval_mode(clf_model):
                activeSet, uSet = getattr(self.sampler, fn)(budgetSize=budget, lSet=lSet, uSet=uSet,
                                                             model=clf_model, dataset=trainDataset)

        elif fn == "coreset":
            # The coreset query runs with clf_model.penultimate_active switched on
            # (presumably so the model exposes penultimate-layer features -- see the
            # model class); the flag and the train/eval mode are restored even if
            # the query raises.
            waslatent = clf_model.penultimate_active
            clf_model.penultimate_active = True
            try:
                with _eval_mode(clf_model):
                    coreSetSampler = CoreSetMIPSampling(cfg=self.cfg, dataObj=self.dataObj)
                    activeSet, uSet = coreSetSampler.query(lSet=lSet, uSet=uSet, clf_model=clf_model,
                                                           dataset=trainDataset)
            finally:
                clf_model.penultimate_active = waslatent

        elif fn in ("dbal", "DBAL"):
            activeSet, uSet = self.sampler.dbal(budgetSize=budget, uSet=uSet,
                                                clf_model=clf_model, dataset=trainDataset)

        elif fn in ("bald", "BALD"):
            activeSet, uSet = self.sampler.bald(budgetSize=budget, uSet=uSet,
                                                clf_model=clf_model, dataset=trainDataset)

        elif fn == "ensemble_var_R":
            activeSet, uSet = self.sampler.ensemble_var_R(budgetSize=budget, uSet=uSet,
                                                          clf_models=supportingModels, dataset=trainDataset)

        elif fn == "vaal":
            adv_sampler = AdversarySampler(cfg=self.cfg, dataObj=self.dataObj)

            # Train VAE and discriminator first
            vae, disc, uSet_loader = adv_sampler.vaal_perform_training(lSet=lSet, uSet=uSet, dataset=trainDataset)

            # Do active sampling
            activeSet, uSet = adv_sampler.sample_for_labeling(vae=vae, discriminator=disc,
                                                              unlabeled_dataloader=uSet_loader, uSet=uSet)
        else:
            raise NotImplementedError(
                f"{fn} is either not implemented or there is some spelling mistake.")

        return activeSet, uSet


# --------------------------------------------------------------------------------
# /pycls/al/__init__.py
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/pycls/al/__init__.py


# --------------------------------------------------------------------------------
# /pycls/al/vaal_util.py
# --------------------------------------------------------------------------------
# This file is directly taken from a code implementation shared with me by Prateek Munjal et al., authors of the paper https://arxiv.org/abs/2002.09564
# GitHub: https://github.com/PrateekMunjal
# ----------------------------------------------------------

# code modified from VAAL codebase

import os
import torch
import numpy as np
from tqdm import tqdm

from pycls.models import vaal_model as vm
import pycls.utils.logging as lu
# import pycls.datasets.loader as imagenet_loader

logger = lu.get_logger(__name__)

bce_loss = torch.nn.BCELoss().cuda()


def _unwrap(model):
    """Return ``model.module`` for (Distributed)DataParallel wrappers, else ``model`` itself."""
    wrappers = (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)
    return model.module if isinstance(model, wrappers) else model


def data_parallel_wrapper(model, cur_device, cfg):
    """Move ``model`` to ``cur_device`` and wrap it in DataParallel over all visible GPUs.

    ``cfg`` is accepted for signature compatibility and is not used.
    """
    model.cuda(cur_device)
    model = torch.nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
    return model


def distributed_wrapper(cfg, model, cur_device):
    """Move ``model`` to ``cur_device``; wrap it in DistributedDataParallel when ``cfg.NUM_GPUS > 1``."""
    # Transfer the model to the current GPU device
    model = model.cuda(device=cur_device)

    # Use multi-process data parallel model in the multi-gpu setting
    if cfg.NUM_GPUS > 1:
        # Make model replica operate on the current device
        model = torch.nn.parallel.DistributedDataParallel(
            module=model,
            device_ids=[cur_device],
            output_device=cur_device
        )
    return model


def read_data(dataloader, labels=True):
    """Yield batches from ``dataloader`` forever, restarting it whenever it is exhausted.

    Yields ``(img, label)`` when ``labels`` is True, otherwise only ``img``.
    """
    while True:
        for img, label in dataloader:
            yield (img, label) if labels else img


def vae_loss(x, recon, mu, logvar, beta):
    """VAE objective: MSE reconstruction loss plus ``beta`` times the KL term.

    ``x`` and ``recon`` are moved to the current CUDA device before comparison.
    """
    mse_loss = torch.nn.MSELoss().cuda()
    recon = recon.cuda()
    x = x.cuda()
    MSE = mse_loss(recon, x)
    # Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over all elements.
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    KLD = KLD * beta
    return MSE + KLD


def train_vae_disc_epoch(cfg, vae_model, disc_model, optim_vae, optim_disc, lSetLoader, uSetLoader, cur_epoch,
                         n_lu, curr_vae_disc_iter, max_vae_disc_iters, change_lr_iter, isDistributed=False):
    """Run one epoch of alternating VAE / discriminator updates.

    Each iteration draws one labelled and one unlabelled batch, then:
      * VAE step: reconstruction + KL loss on both batches, plus
        ``cfg.VAAL.ADVERSARY_PARAM`` times a BCE term pushing the discriminator
        to output 1 for both batches' latent means;
      * discriminator step (latents computed under no_grad): labelled -> 1,
        unlabelled -> 0.

    ``n_lu`` is the number of points the epoch covers; it runs
    ``int(n_lu / cfg.VAAL.VAE_BS)`` iterations. Every ``change_lr_iter`` global
    iterations both learning rates are multiplied by 0.9; this is skipped when
    ``change_lr_iter`` is 0. ``max_vae_disc_iters`` is accepted but not used.

    Returns (vae_model, disc_model, optim_vae, optim_disc, curr_vae_disc_iter).
    """
    if isDistributed:
        lSetLoader.sampler.set_epoch(cur_epoch)
        uSetLoader.sampler.set_epoch(cur_epoch)

    print('len(lSetLoader): {}'.format(len(lSetLoader)))
    print('len(uSetLoader): {}'.format(len(uSetLoader)))

    labeled_data = read_data(lSetLoader)
    unlabeled_data = read_data(uSetLoader, labels=False)

    vae_model.train()
    disc_model.train()

    temp_bs = int(cfg.VAAL.VAE_BS)
    train_iterations = int(n_lu / temp_bs)

    for temp_iter in range(train_iterations):

        # Step-wise LR decay. The change_lr_iter > 0 guard matters: the caller
        # derives it as max_iters // 25, which is 0 for short runs, and
        # `x % 0` raises ZeroDivisionError.
        if change_lr_iter > 0 and curr_vae_disc_iter != 0 and curr_vae_disc_iter % change_lr_iter == 0:
            for optim in (optim_vae, optim_disc):
                for param in optim.param_groups:
                    param['lr'] = param['lr'] * 0.9

        curr_vae_disc_iter += 1

        ## VAE Step
        disc_model.eval()
        vae_model.train()
        labeled_imgs, _ = next(labeled_data)
        unlabeled_imgs = next(unlabeled_data)

        # Cast to float and move to the current CUDA device.
        labeled_imgs = labeled_imgs.type(torch.cuda.FloatTensor)
        unlabeled_imgs = unlabeled_imgs.type(torch.cuda.FloatTensor)

        recon, _, mu, logvar = vae_model(labeled_imgs)
        recon = recon.view((labeled_imgs.shape[0], labeled_imgs.shape[1], labeled_imgs.shape[2], labeled_imgs.shape[3]))
        unsup_loss = vae_loss(labeled_imgs, recon, mu, logvar, cfg.VAAL.BETA)
        unlab_recon, _, unlab_mu, unlab_logvar = vae_model(unlabeled_imgs)
        unlab_recon = unlab_recon.view((unlabeled_imgs.shape[0], unlabeled_imgs.shape[1], unlabeled_imgs.shape[2], unlabeled_imgs.shape[3]))
        transductive_loss = vae_loss(unlabeled_imgs, unlab_recon, unlab_mu, unlab_logvar, cfg.VAAL.BETA)

        labeled_preds = disc_model(mu)
        unlabeled_preds = disc_model(unlab_mu)

        # Adversarial term: the VAE is rewarded when the discriminator calls
        # every latent "labelled" (target 1).
        lab_real_preds = torch.ones(labeled_imgs.size(0), 1).cuda()
        unlab_real_preds = torch.ones(unlabeled_imgs.size(0), 1).cuda()
        dsc_loss = bce_loss(labeled_preds, lab_real_preds) + bce_loss(unlabeled_preds, unlab_real_preds)

        total_vae_loss = unsup_loss + transductive_loss + cfg.VAAL.ADVERSARY_PARAM * dsc_loss

        optim_vae.zero_grad()
        total_vae_loss.backward()
        optim_vae.step()

        ## DISC STEP
        vae_model.eval()
        disc_model.train()

        with torch.no_grad():
            _, _, mu, _ = vae_model(labeled_imgs)
            _, _, unlab_mu, _ = vae_model(unlabeled_imgs)

        labeled_preds = disc_model(mu)
        unlabeled_preds = disc_model(unlab_mu)

        lab_real_preds = torch.ones(labeled_imgs.size(0), 1).cuda()
        unlab_fake_preds = torch.zeros(unlabeled_imgs.size(0), 1).cuda()

        dsc_loss = bce_loss(labeled_preds, lab_real_preds) + \
            bce_loss(unlabeled_preds, unlab_fake_preds)

        optim_disc.zero_grad()
        dsc_loss.backward()
        optim_disc.step()

        if temp_iter % 100 == 0:
            print("Epoch[{}],Iteration [{}/{}], VAE Loss: {:.3f}, Disc Loss: {:.4f}"
                  .format(cur_epoch, temp_iter, train_iterations, total_vae_loss.item(), dsc_loss.item()))

    return vae_model, disc_model, optim_vae, optim_disc, curr_vae_disc_iter


def train_vae_disc(cfg, lSet, uSet, trainDataset, dataObj, debug=False):
    """Train a fresh VAE + discriminator pair on lSet/uSet and checkpoint both.

    Builds the models (single-channel VAE for MNIST), index data loaders and
    Adam optimizers, trains for ``cfg.VAAL.VAE_EPOCHS`` epochs, and writes
    ``vae.pyth`` / ``disc.pyth`` into ``cfg.EPISODE_DIR``.

    Returns (vae_model, disc_model).
    """
    cur_device = torch.cuda.current_device()
    if cfg.DATASET.NAME == 'MNIST':
        vae_model = vm.VAE(cur_device, z_dim=cfg.VAAL.Z_DIM, nc=1)
    else:
        vae_model = vm.VAE(cur_device, z_dim=cfg.VAAL.Z_DIM)
    disc_model = vm.Discriminator(z_dim=cfg.VAAL.Z_DIM)

    # vae_model = data_parallel_wrapper(vae_model, cur_device, cfg)
    # disc_model = data_parallel_wrapper(disc_model, cur_device, cfg)

    # if cfg.TRAIN.DATASET == "IMAGENET":
    #     lSetLoader = imagenet_loader.construct_loader_no_aug(cfg, indices=lSet, isDistributed=False, isVaalSampling=True)
    #     uSetLoader = imagenet_loader.construct_loader_no_aug(cfg, indices=uSet, isDistributed=False, isVaalSampling=True)
    # else:
    lSetLoader = dataObj.getIndexesDataLoader(indexes=lSet, batch_size=int(cfg.VAAL.VAE_BS),
                                              data=trainDataset)

    uSetLoader = dataObj.getIndexesDataLoader(indexes=uSet, batch_size=int(cfg.VAAL.VAE_BS),
                                              data=trainDataset)

    print("Initializing VAE and discriminator")
    logger.info("Initializing VAE and discriminator")
    optim_vae = torch.optim.Adam(vae_model.parameters(), lr=cfg.VAAL.VAE_LR)
    print(f"VAE Optimizer ==> {optim_vae}")
    logger.info(f"VAE Optimizer ==> {optim_vae}")
    optim_disc = torch.optim.Adam(disc_model.parameters(), lr=cfg.VAAL.DISC_LR)
    print(f"Disc Optimizer ==> {optim_disc}")
    logger.info(f"Disc Optimizer ==> {optim_disc}")
    print("==================================")
    logger.info("==================================\n")

    n_lu_points = len(lSet) + len(uSet)
    max_vae_disc_iters = int(n_lu_points / cfg.VAAL.VAE_BS) * cfg.VAAL.VAE_EPOCHS
    # Decay the LR 25 times over the run; 0 for runs shorter than 25 iterations
    # (train_vae_disc_epoch then skips the decay).
    change_lr_iter = max_vae_disc_iters // 25
    curr_vae_disc_iter = 0

    vae_model = vae_model.cuda()
    disc_model = disc_model.cuda()

    for epoch in range(cfg.VAAL.VAE_EPOCHS):
        vae_model, disc_model, optim_vae, optim_disc, curr_vae_disc_iter = train_vae_disc_epoch(
            cfg, vae_model, disc_model, optim_vae, optim_disc, lSetLoader, uSetLoader, epoch,
            n_lu_points, curr_vae_disc_iter, max_vae_disc_iters, change_lr_iter)

    # Save vae and disc models. Unwrap only when the model really is a
    # (Distributed)DataParallel wrapper: the wrapping calls in this function are
    # commented out, so the models stay plain modules even when cfg.NUM_GPUS > 1.
    vae_sd = _unwrap(vae_model).state_dict()
    disc_sd = _unwrap(disc_model).state_dict()
    # Record the state
    vae_checkpoint = {
        'epoch': cfg.VAAL.VAE_EPOCHS + 1,
        'model_state': vae_sd,
        'optimizer_state': optim_vae.state_dict(),
        'cfg': cfg.dump()
    }
    disc_checkpoint = {
        'epoch': cfg.VAAL.VAE_EPOCHS + 1,
        'model_state': disc_sd,
        'optimizer_state': optim_disc.state_dict(),
        'cfg': cfg.dump()
    }
    # Write the checkpoint
    os.makedirs(cfg.EPISODE_DIR, exist_ok=True)
    vae_checkpoint_file = os.path.join(cfg.EPISODE_DIR, "vae.pyth")
    disc_checkpoint_file = os.path.join(cfg.EPISODE_DIR, "disc.pyth")
    torch.save(vae_checkpoint, vae_checkpoint_file)
    torch.save(disc_checkpoint, disc_checkpoint_file)

    if debug:
        print("Saved VAE model at {}".format(vae_checkpoint_file))
        print("Saved DISC model at {}".format(disc_checkpoint_file))

    return vae_model, disc_model


# --------------------------------------------------------------------------------
# /pycls/core/__init__.py
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/pycls/core/__init__.py


# --------------------------------------------------------------------------------
# /pycls/core/builders.py
# --------------------------------------------------------------------------------
# This file is modified from official pycls repository

"""Model and loss construction functions."""

from pycls.core.net import SoftCrossEntropyLoss
from pycls.models.resnet import *
from pycls.models.vgg import *
from pycls.models.alexnet import *

import torch


# Supported models
_models = {
    # VGG style architectures
    'vgg11': vgg11,
    'vgg11_bn': vgg11_bn,
    'vgg13': vgg13,
    'vgg13_bn': vgg13_bn,
    'vgg16': vgg16,
    'vgg16_bn': vgg16_bn,
    'vgg19': vgg19,
    'vgg19_bn': vgg19_bn,

    # ResNet style archiectures
    'resnet18': resnet18,
    'resnet34': resnet34,
    'resnet50': resnet50,
    'resnet101': resnet101,
    'resnet152': resnet152,
    'resnext50_32x4d': resnext50_32x4d,
    'resnext101_32x8d': resnext101_32x8d,
    'wide_resnet50_2': wide_resnet50_2,
    'wide_resnet101_2': wide_resnet101_2,

    # AlexNet architecture
    'alexnet': alexnet
}

# Supported loss functions
_loss_funs = {"cross_entropy": SoftCrossEntropyLoss}


def get_model(cfg):
    """Gets the model class specified in the config (``cfg.MODEL.TYPE``)."""
    err_str = "Model type '{}' not supported"
    assert cfg.MODEL.TYPE in _models, err_str.format(cfg.MODEL.TYPE)
    return _models[cfg.MODEL.TYPE]


def get_loss_fun(cfg):
    """Gets the loss function class specified in the config (``cfg.MODEL.LOSS_FUN``)."""
    err_str = "Loss function type '{}' not supported"
    # Report the value actually checked; formatting another key could itself
    # fail and mask the real error.
    assert cfg.MODEL.LOSS_FUN in _loss_funs, err_str.format(cfg.MODEL.LOSS_FUN)
    return _loss_funs[cfg.MODEL.LOSS_FUN]


def build_model(cfg):
    """Builds the model named by ``cfg.MODEL.TYPE`` with ``cfg.MODEL.NUM_CLASSES`` outputs."""
    model = get_model(cfg)(num_classes=cfg.MODEL.NUM_CLASSES, use_dropout=True)
    if cfg.DATASET.NAME == 'MNIST':
        # Swap the stem for a 1-input-channel conv. NOTE(review): this assumes the
        # architecture uses a `conv1` stem (ResNet family); for other models the
        # assignment only adds an unused attribute -- confirm.
        model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

    return model


def build_loss_fun(cfg):
    """Build the loss function."""
    return get_loss_fun(cfg)()


def register_model(name, ctor):
    """Registers a model dynamically."""
    _models[name] = ctor


def register_loss_fun(name, ctor):
    """Registers a loss function dynamically."""
    _loss_funs[name] = ctor


# --------------------------------------------------------------------------------
# /pycls/core/losses.py
# --------------------------------------------------------------------------------
"""Loss functions."""

import torch.nn as nn

from pycls.core.config import cfg

# Supported loss functions
_loss_funs = {
    'cross_entropy': nn.CrossEntropyLoss,
}


def get_loss_fun():
    """Retrieves the loss function named by the global ``cfg.MODEL.LOSS_FUN``, on CUDA."""
    assert cfg.MODEL.LOSS_FUN in _loss_funs, \
        'Loss function \'{}\' not supported'.format(cfg.MODEL.LOSS_FUN)
    return _loss_funs[cfg.MODEL.LOSS_FUN]().cuda()


def register_loss_fun(name, ctor):
    """Registers a loss function dynamically."""
    _loss_funs[name] = ctor
-------------------------------------------------------------------------------- /pycls/core/net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """Functions for manipulating networks.""" 9 | 10 | import itertools 11 | 12 | import numpy as np 13 | # import pycls.core.distributed as dist 14 | import torch 15 | from pycls.core.config import cfg 16 | 17 | 18 | def unwrap_model(model): 19 | """Remove the DistributedDataParallel wrapper if present.""" 20 | wrapped = isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel) 21 | return model.module if wrapped else model 22 | 23 | 24 | # @torch.no_grad() 25 | # def compute_precise_bn_stats(model, loader): 26 | # """Computes precise BN stats on training data.""" 27 | # # Compute the number of minibatches to use 28 | # num_iter = int(cfg.BN.NUM_SAMPLES_PRECISE / loader.batch_size / cfg.NUM_GPUS) 29 | # num_iter = min(num_iter, len(loader)) 30 | # # Retrieve the BN layers 31 | # bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)] 32 | # # Initialize BN stats storage for computing mean(mean(batch)) and mean(var(batch)) 33 | # running_means = [torch.zeros_like(bn.running_mean) for bn in bns] 34 | # running_vars = [torch.zeros_like(bn.running_var) for bn in bns] 35 | # # Remember momentum values 36 | # momentums = [bn.momentum for bn in bns] 37 | # # Set momentum to 1.0 to compute BN stats that only reflect the current batch 38 | # for bn in bns: 39 | # bn.momentum = 1.0 40 | # # Average the BN stats for each BN layer over the batches 41 | # for inputs, _labels in itertools.islice(loader, num_iter): 42 | # model(inputs.cuda()) 43 | # for i, bn in enumerate(bns): 44 | # running_means[i] += bn.running_mean / num_iter 45 | # 
running_vars[i] += bn.running_var / num_iter 46 | # # Sync BN stats across GPUs (no reduction if 1 GPU used) 47 | # running_means = dist.scaled_all_reduce(running_means) 48 | # running_vars = dist.scaled_all_reduce(running_vars) 49 | # # Set BN stats and restore original momentum values 50 | # for i, bn in enumerate(bns): 51 | # bn.running_mean = running_means[i] 52 | # bn.running_var = running_vars[i] 53 | # bn.momentum = momentums[i] 54 | 55 | 56 | def complexity(model): 57 | """Compute model complexity (model can be model instance or model class).""" 58 | size = cfg.TRAIN.IM_SIZE 59 | cx = {"h": size, "w": size, "flops": 0, "params": 0, "acts": 0} 60 | cx = unwrap_model(model).complexity(cx) 61 | return {"flops": cx["flops"], "params": cx["params"], "acts": cx["acts"]} 62 | 63 | 64 | def smooth_one_hot_labels(labels): 65 | """Convert each label to a one-hot vector.""" 66 | n_classes, label_smooth = cfg.MODEL.NUM_CLASSES, cfg.TRAIN.LABEL_SMOOTHING 67 | err_str = "Invalid input to one_hot_vector()" 68 | assert labels.ndim == 1 and labels.max() < n_classes, err_str 69 | shape = (labels.shape[0], n_classes) 70 | neg_val = label_smooth / n_classes 71 | pos_val = 1.0 - label_smooth + neg_val 72 | labels_one_hot = torch.full(shape, neg_val, dtype=torch.float, device=labels.device) 73 | labels_one_hot.scatter_(1, labels.long().view(-1, 1), pos_val) 74 | return labels_one_hot 75 | 76 | 77 | class SoftCrossEntropyLoss(torch.nn.Module): 78 | """SoftCrossEntropyLoss (useful for label smoothing and mixup). 
79 | Identical to torch.nn.CrossEntropyLoss if used with one-hot labels.""" 80 | 81 | def __init__(self): 82 | super(SoftCrossEntropyLoss, self).__init__() 83 | 84 | def forward(self, x, y): 85 | loss = -y * torch.nn.functional.log_softmax(x, -1) 86 | return torch.sum(loss) / x.shape[0] 87 | 88 | 89 | def mixup(inputs, labels): 90 | """Apply mixup to minibatch (https://arxiv.org/abs/1710.09412).""" 91 | alpha = cfg.TRAIN.MIXUP_ALPHA 92 | assert labels.shape[1] == cfg.MODEL.NUM_CLASSES, "mixup labels must be one-hot" 93 | if alpha > 0: 94 | m = np.random.beta(alpha, alpha) 95 | permutation = torch.randperm(labels.shape[0]) 96 | inputs = m * inputs + (1.0 - m) * inputs[permutation, :] 97 | labels = m * labels + (1.0 - m) * labels[permutation, :] 98 | return inputs, labels, labels.argmax(1) 99 | -------------------------------------------------------------------------------- /pycls/core/optimizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """Optimizer.""" 9 | 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | import torch 13 | 14 | 15 | def construct_optimizer(cfg, model): 16 | """Constructs the optimizer. 17 | 18 | Note that the momentum update in PyTorch differs from the one in Caffe2. 19 | In particular, 20 | 21 | Caffe2: 22 | V := mu * V + lr * g 23 | p := p - V 24 | 25 | PyTorch: 26 | V := mu * V + g 27 | p := p - lr * V 28 | 29 | where V is the velocity, mu is the momentum factor, lr is the learning rate, 30 | g is the gradient and p are the parameters. 31 | 32 | Since V is defined independently of the learning rate in PyTorch, 33 | when the learning rate is changed there is no need to perform the 34 | momentum correction by scaling V (unlike in the Caffe2 case). 
35 | """ 36 | if cfg.BN.USE_CUSTOM_WEIGHT_DECAY: 37 | # Apply different weight decay to Batchnorm and non-batchnorm parameters. 38 | p_bn = [p for n, p in model.named_parameters() if "bn" in n] 39 | p_non_bn = [p for n, p in model.named_parameters() if "bn" not in n] 40 | optim_params = [ 41 | {"params": p_bn, "weight_decay": cfg.BN.CUSTOM_WEIGHT_DECAY}, 42 | {"params": p_non_bn, "weight_decay": cfg.OPTIM.WEIGHT_DECAY}, 43 | ] 44 | else: 45 | optim_params = model.parameters() 46 | 47 | if cfg.OPTIM.TYPE == 'sgd': 48 | optimizer = torch.optim.SGD( 49 | model.parameters(), 50 | lr=cfg.OPTIM.BASE_LR, 51 | momentum=cfg.OPTIM.MOMENTUM, 52 | weight_decay=cfg.OPTIM.WEIGHT_DECAY, 53 | dampening=cfg.OPTIM.DAMPENING, 54 | nesterov=cfg.OPTIM.NESTEROV 55 | ) 56 | elif cfg.OPTIM.TYPE == 'adam': 57 | optimizer = torch.optim.Adam( 58 | model.parameters(), 59 | lr=cfg.OPTIM.BASE_LR, 60 | weight_decay=cfg.OPTIM.WEIGHT_DECAY 61 | ) 62 | else: 63 | raise NotImplementedError 64 | 65 | return optimizer 66 | 67 | 68 | def lr_fun_steps(cfg, cur_epoch): 69 | """Steps schedule (cfg.OPTIM.LR_POLICY = 'steps').""" 70 | ind = [i for i, s in enumerate(cfg.OPTIM.STEPS) if cur_epoch >= s][-1] 71 | return cfg.OPTIM.LR_MULT ** ind 72 | 73 | 74 | def lr_fun_exp(cfg, cur_epoch): 75 | """Exponential schedule (cfg.OPTIM.LR_POLICY = 'exp').""" 76 | return cfg.OPTIM.MIN_LR ** (cur_epoch / cfg.OPTIM.MAX_EPOCH) 77 | 78 | 79 | def lr_fun_cos(cfg, cur_epoch): 80 | """Cosine schedule (cfg.OPTIM.LR_POLICY = 'cos').""" 81 | lr = 0.5 * (1.0 + np.cos(np.pi * cur_epoch / cfg.OPTIM.MAX_EPOCH)) 82 | return (1.0 - cfg.OPTIM.MIN_LR) * lr + cfg.OPTIM.MIN_LR 83 | 84 | 85 | def lr_fun_lin(cfg, cur_epoch): 86 | """Linear schedule (cfg.OPTIM.LR_POLICY = 'lin').""" 87 | lr = 1.0 - cur_epoch / cfg.OPTIM.MAX_EPOCH 88 | return (1.0 - cfg.OPTIM.MIN_LR) * lr + cfg.OPTIM.MIN_LR 89 | 90 | 91 | def lr_fun_none(cfg, cur_epoch): 92 | """No schedule (cfg.OPTIM.LR_POLICY = 'none').""" 93 | return 1 94 | 95 | 96 | def get_lr_fun(cfg): 
97 | """Retrieves the specified lr policy function""" 98 | lr_fun = "lr_fun_" + cfg.OPTIM.LR_POLICY 99 | assert lr_fun in globals(), "Unknown LR policy: " + cfg.OPTIM.LR_POLICY 100 | err_str = "exp lr policy requires OPTIM.MIN_LR to be greater than 0." 101 | assert cfg.OPTIM.LR_POLICY != "exp" or cfg.OPTIM.MIN_LR > 0, err_str 102 | return globals()[lr_fun] 103 | 104 | 105 | def get_epoch_lr(cfg, cur_epoch): 106 | """Retrieves the lr for the given epoch according to the policy.""" 107 | # Get lr and scale by by BASE_LR 108 | lr = get_lr_fun(cfg)(cfg, cur_epoch) * cfg.OPTIM.BASE_LR 109 | # Linear warmup 110 | if cur_epoch < cfg.OPTIM.WARMUP_EPOCHS and 'none' not in cfg.OPTIM.LR_POLICY: 111 | alpha = cur_epoch / cfg.OPTIM.WARMUP_EPOCHS 112 | warmup_factor = cfg.OPTIM.WARMUP_FACTOR * (1.0 - alpha) + alpha 113 | lr *= warmup_factor 114 | return lr 115 | 116 | 117 | def set_lr(optimizer, new_lr): 118 | """Sets the optimizer lr to the specified value.""" 119 | for param_group in optimizer.param_groups: 120 | param_group["lr"] = new_lr 121 | 122 | 123 | def plot_lr_fun(): 124 | """Visualizes lr function.""" 125 | epochs = list(range(cfg.OPTIM.MAX_EPOCH)) 126 | lrs = [get_epoch_lr(epoch) for epoch in epochs] 127 | plt.plot(epochs, lrs, ".-") 128 | plt.title("lr_policy: {}".format(cfg.OPTIM.LR_POLICY)) 129 | plt.xlabel("epochs") 130 | plt.ylabel("learning rate") 131 | plt.ylim(bottom=0) 132 | plt.show() 133 | -------------------------------------------------------------------------------- /pycls/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/pycls/datasets/__init__.py -------------------------------------------------------------------------------- /pycls/datasets/custom_datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from 
PIL import Image 4 | 5 | class CIFAR10(torchvision.datasets.CIFAR10): 6 | def __init__(self, root, train, transform, test_transform, download=True): 7 | super(CIFAR10, self).__init__(root, train, transform=transform, download=download) 8 | self.test_transform = test_transform 9 | self.no_aug = False 10 | 11 | def __getitem__(self, index: int): 12 | """ 13 | Args: 14 | index (int): Index 15 | 16 | Returns: 17 | tuple: (image, target) where target is index of the target class. 18 | """ 19 | img, target = self.data[index], self.targets[index] 20 | 21 | # doing this so that it is consistent with all other datasets 22 | # to return a PIL Image 23 | img = Image.fromarray(img) 24 | 25 | if self.no_aug: 26 | if self.test_transform is not None: 27 | img = self.test_transform(img) 28 | else: 29 | if self.transform is not None: 30 | img = self.transform(img) 31 | 32 | 33 | return img, target 34 | 35 | 36 | class CIFAR100(torchvision.datasets.CIFAR100): 37 | def __init__(self, root, train, transform, test_transform, download=True): 38 | super(CIFAR100, self).__init__(root, train, transform=transform, download=download) 39 | self.test_transform = test_transform 40 | self.no_aug = False 41 | 42 | def __getitem__(self, index: int): 43 | """ 44 | Args: 45 | index (int): Index 46 | 47 | Returns: 48 | tuple: (image, target) where target is index of the target class. 
49 | """ 50 | img, target = self.data[index], self.targets[index] 51 | 52 | # doing this so that it is consistent with all other datasets 53 | # to return a PIL Image 54 | img = Image.fromarray(img) 55 | 56 | if self.no_aug: 57 | if self.test_transform is not None: 58 | img = self.test_transform(img) 59 | else: 60 | if self.transform is not None: 61 | img = self.transform(img) 62 | 63 | 64 | return img, target 65 | 66 | 67 | class MNIST(torchvision.datasets.MNIST): 68 | def __init__(self, root, train, transform, test_transform, download=True): 69 | super(MNIST, self).__init__(root, train, transform=transform, download=download) 70 | self.test_transform = test_transform 71 | self.no_aug = False 72 | 73 | def __getitem__(self, index: int): 74 | """ 75 | Args: 76 | index (int): Index 77 | 78 | Returns: 79 | tuple: (image, target) where target is index of the target class. 80 | """ 81 | img, target = self.data[index], int(self.targets[index]) 82 | 83 | # doing this so that it is consistent with all other datasets 84 | # to return a PIL Image 85 | img = Image.fromarray(img.numpy(), mode='L') 86 | 87 | if self.no_aug: 88 | if self.test_transform is not None: 89 | img = self.test_transform(img) 90 | else: 91 | if self.transform is not None: 92 | img = self.transform(img) 93 | 94 | 95 | return img, target 96 | 97 | 98 | class SVHN(torchvision.datasets.SVHN): 99 | def __init__(self, root, train, transform, test_transform, download=True): 100 | super(SVHN, self).__init__(root, train, transform=transform, download=download) 101 | self.test_transform = test_transform 102 | self.no_aug = False 103 | 104 | def __getitem__(self, index: int): 105 | """ 106 | Args: 107 | index (int): Index 108 | 109 | Returns: 110 | tuple: (image, target) where target is index of the target class. 
111 | """ 112 | img, target = self.data[index], self.targets[index] 113 | 114 | # doing this so that it is consistent with all other datasets 115 | # to return a PIL Image 116 | img = Image.fromarray(img) 117 | 118 | if self.no_aug: 119 | if self.test_transform is not None: 120 | img = self.test_transform(img) 121 | else: 122 | if self.transform is not None: 123 | img = self.transform(img) 124 | 125 | 126 | return img, target -------------------------------------------------------------------------------- /pycls/datasets/imbalanced_cifar.py: -------------------------------------------------------------------------------- 1 | """ 2 | Credits: Kaihua Tang 3 | Source: https://github.com/KaihuaTang/Long-Tailed-Recognition.pytorch/ 4 | """ 5 | 6 | import torchvision 7 | import torchvision.transforms as transforms 8 | import numpy as np 9 | from PIL import Image 10 | import random 11 | from pycls.datasets.custom_datasets import CIFAR10, CIFAR100 12 | 13 | class IMBALANCECIFAR10(CIFAR10): 14 | cls_num = 10 15 | np.random.seed(1) 16 | def __init__(self, root, train, transform=None, test_transform=None, imbalance_ratio=0.02, imb_type='exp'): 17 | super(IMBALANCECIFAR10, self).__init__(root, train, transform=transform, test_transform=test_transform, download=True) 18 | self.train = train 19 | self.transform = transform 20 | if self.train: 21 | img_num_list = self.get_img_num_per_cls(self.cls_num, imb_type, imbalance_ratio) 22 | self.gen_imbalanced_data(img_num_list) 23 | phase = 'Train' 24 | else: 25 | phase = 'Test' 26 | self.labels = self.targets 27 | 28 | print("{} Mode: Contain {} images".format(phase, len(self.data))) 29 | 30 | def _get_class_dict(self): 31 | class_dict = dict() 32 | for i, anno in enumerate(self.get_annotations()): 33 | cat_id = anno["category_id"] 34 | if not cat_id in class_dict: 35 | class_dict[cat_id] = [] 36 | class_dict[cat_id].append(i) 37 | return class_dict 38 | 39 | 40 | def get_img_num_per_cls(self, cls_num, imb_type, imb_factor): 41 | 
img_max = len(self.data) / cls_num 42 | img_num_per_cls = [] 43 | if imb_type == 'exp': 44 | for cls_idx in range(cls_num): 45 | num = img_max * (imb_factor**(cls_idx / (cls_num - 1.0))) 46 | img_num_per_cls.append(int(num)) 47 | elif imb_type == 'step': 48 | for cls_idx in range(cls_num // 2): 49 | img_num_per_cls.append(int(img_max)) 50 | for cls_idx in range(cls_num // 2): 51 | img_num_per_cls.append(int(img_max * imb_factor)) 52 | else: 53 | img_num_per_cls.extend([int(img_max)] * cls_num) 54 | return img_num_per_cls 55 | 56 | def gen_imbalanced_data(self, img_num_per_cls): 57 | 58 | new_data = [] 59 | new_targets = [] 60 | targets_np = np.array(self.targets, dtype=np.int64) 61 | classes = np.unique(targets_np) 62 | 63 | self.num_per_cls_dict = dict() 64 | for the_class, the_img_num in zip(classes, img_num_per_cls): 65 | self.num_per_cls_dict[the_class] = the_img_num 66 | idx = np.where(targets_np == the_class)[0] 67 | np.random.shuffle(idx) 68 | selec_idx = idx[:the_img_num] 69 | new_data.append(self.data[selec_idx, ...]) 70 | new_targets.extend([the_class, ] * the_img_num) 71 | new_data = np.vstack(new_data) 72 | self.data = new_data 73 | self.targets = new_targets 74 | 75 | def __len__(self): 76 | return len(self.labels) 77 | 78 | def get_num_classes(self): 79 | return self.cls_num 80 | 81 | def get_annotations(self): 82 | annos = [] 83 | for label in self.labels: 84 | annos.append({'category_id': int(label)}) 85 | return annos 86 | 87 | def get_cls_num_list(self): 88 | cls_num_list = [] 89 | for i in range(self.cls_num): 90 | cls_num_list.append(self.num_per_cls_dict[i]) 91 | return cls_num_list 92 | 93 | 94 | class IMBALANCECIFAR100(CIFAR100): 95 | """`CIFAR100 `_ Dataset. 96 | This is a subclass of the `CIFAR10` Dataset. 
97 | """ 98 | cls_num = 100 99 | base_folder = 'cifar-100-python' 100 | url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" 101 | filename = "cifar-100-python.tar.gz" 102 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' 103 | train_list = [ 104 | ['train', '16019d7e3df5f24257cddd939b257f8d'], 105 | ] 106 | 107 | test_list = [ 108 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], 109 | ] 110 | meta = { 111 | 'filename': 'meta', 112 | 'key': 'fine_label_names', 113 | 'md5': '7973b15100ade9c7d40fb424638fde48', 114 | } -------------------------------------------------------------------------------- /pycls/datasets/randaugment.py: -------------------------------------------------------------------------------- 1 | # This file is directly taken from code implementation by Prateek Munjal et al., authors of the paper https://arxiv.org/abs/2002.09564 2 | 3 | # ---------------------------------------------------------- 4 | #This file was modified from implementation of [Auto-Augment](https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py) to add Random Augmentations. 5 | 6 | """ AutoAugment ==[Modified to]==> RandAugment""" 7 | 8 | 9 | from PIL import Image, ImageEnhance, ImageOps 10 | import PIL.ImageDraw as ImageDraw 11 | import numpy as np 12 | import random 13 | 14 | class RandAugmentPolicy(object): 15 | """ Randomly choose one of the best 25 Sub-policies on CIFAR10. 
16 | Example: 17 | >>> policy = RandAugmentPolicy() 18 | >>> transformed = policy(image) 19 | Example as a PyTorch Transform: 20 | >>> transform=transforms.Compose([ 21 | >>> transforms.Resize(256), 22 | >>> RandAugmentPolicy(), 23 | >>> transforms.ToTensor()]) 24 | """ 25 | #I change fill color from (128, 128, 128) to (0, 0, 0) 26 | def __init__(self, fillcolor=(0,0,0), N=1, M=5): 27 | self.policies = ["invert","autocontrast","equalize","rotate","solarize","color", \ 28 | "posterize","contrast","brightness","sharpness","shearX","shearY","translateX", \ 29 | "translateY","cutout"] 30 | self.N = N 31 | self.M = M 32 | 33 | def __call__(self, img): 34 | choosen_policies = np.random.choice(self.policies, self.N) 35 | for policy in choosen_policies: 36 | subpolicy_obj = SubPolicy(operation=policy, magnitude=self.M) 37 | img = subpolicy_obj(img) 38 | 39 | return img 40 | 41 | def __repr__(self): 42 | return "RandAugment CIFAR10 Policy with Cutout" 43 | 44 | class SubPolicy(object): 45 | def __init__(self, operation, magnitude, fillcolor=(128, 128, 128), MAX_PARAM=10): 46 | ranges = { 47 | "shearX": np.linspace(0, 0.3, MAX_PARAM), 48 | "shearY": np.linspace(0, 0.3, MAX_PARAM), 49 | "translateX": np.linspace(0, 150 / 331, MAX_PARAM), 50 | "translateY": np.linspace(0, 150 / 331, MAX_PARAM), 51 | "rotate": np.linspace(0, 30, MAX_PARAM), 52 | "color": np.linspace(0.0, 0.9, MAX_PARAM), 53 | "posterize": np.round(np.linspace(8, 4, MAX_PARAM), 0).astype(np.int), 54 | "solarize": np.linspace(256, 0, MAX_PARAM), 55 | "contrast": np.linspace(0.0, 0.9, MAX_PARAM), 56 | "sharpness": np.linspace(0.0, 0.9, MAX_PARAM), 57 | "brightness": np.linspace(0.0, 0.9, MAX_PARAM), 58 | "autocontrast": [0] * MAX_PARAM, 59 | "equalize": [0] * MAX_PARAM, 60 | "invert": [0] * MAX_PARAM, 61 | "cutout": np.linspace(0.0,0.8, MAX_PARAM), 62 | } 63 | 64 | # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand 65 | def 
rotate_with_fill(img, magnitude): 66 | rot = img.convert("RGBA").rotate(magnitude) 67 | #return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode) 68 | return Image.composite(rot, Image.new("RGBA", rot.size, (0,) * 4), rot).convert(img.mode) 69 | 70 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2] 71 | assert 0.0 <= v <= 0.8 72 | if v <= 0.: 73 | return img 74 | 75 | v = v * img.size[0] 76 | 77 | return CutoutAbs(img, v) 78 | 79 | 80 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2] 81 | # assert 0 <= v <= 20 82 | if v < 0: 83 | return img 84 | w, h = img.size 85 | x0 = np.random.uniform(w) 86 | y0 = np.random.uniform(h) 87 | 88 | x0 = int(max(0, x0 - v / 2.)) 89 | y0 = int(max(0, y0 - v / 2.)) 90 | x1 = min(w, x0 + v) 91 | y1 = min(h, y0 + v) 92 | 93 | xy = (x0, y0, x1, y1) 94 | #color = (125, 123, 114) 95 | color = (0, 0, 0) 96 | img = img.copy() 97 | ImageDraw.Draw(img).rectangle(xy, color) 98 | return img 99 | 100 | func = { 101 | "shearX": lambda img, magnitude: img.transform( 102 | img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), 103 | Image.BICUBIC, fillcolor=fillcolor), 104 | "shearY": lambda img, magnitude: img.transform( 105 | img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0), 106 | Image.BICUBIC, fillcolor=fillcolor), 107 | "translateX": lambda img, magnitude: img.transform( 108 | img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0), 109 | fillcolor=fillcolor), 110 | "translateY": lambda img, magnitude: img.transform( 111 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])), 112 | fillcolor=fillcolor), 113 | "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude), 114 | # "rotate": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])), 115 | "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])), 116 | 
"posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude), 117 | "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude), 118 | "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance( 119 | 1 + magnitude * random.choice([-1, 1])), 120 | "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance( 121 | 1 + magnitude * random.choice([-1, 1])), 122 | "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance( 123 | 1 + magnitude * random.choice([-1, 1])), 124 | "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img), 125 | "equalize": lambda img, magnitude: ImageOps.equalize(img), 126 | "invert": lambda img, magnitude: ImageOps.invert(img), 127 | "cutout": lambda img, magnitude: Cutout(img, magnitude) 128 | } 129 | 130 | self.operation = func[operation] 131 | self.magnitude = ranges[operation][magnitude] 132 | 133 | 134 | def __call__(self, img): 135 | img = self.operation(img, self.magnitude) 136 | return img 137 | 138 | -------------------------------------------------------------------------------- /pycls/datasets/sampler.py: -------------------------------------------------------------------------------- 1 | # This file is directly taken from a code implementation shared with me by Prateek Munjal et al., authors of the paper https://arxiv.org/abs/2002.09564 2 | # GitHub: https://github.com/PrateekMunjal 3 | # ---------------------------------------------------------- 4 | 5 | from torch.utils.data.sampler import Sampler 6 | 7 | 8 | class IndexedSequentialSampler(Sampler): 9 | r"""Samples elements sequentially, always in the same order. 
10 | 11 | Arguments: 12 | data_idxes (Dataset indexes): dataset indexes to sample from 13 | """ 14 | 15 | def __init__(self, data_idxes, isDebug=False): 16 | if isDebug: print("========= my custom squential sampler =========") 17 | self.data_idxes = data_idxes 18 | 19 | def __iter__(self): 20 | return (self.data_idxes[i] for i in range(len(self.data_idxes))) 21 | 22 | def __len__(self): 23 | return len(self.data_idxes) 24 | 25 | # class IndexedDistributedSampler(Sampler): 26 | # """Sampler that restricts data loading to a particular index set of dataset. 27 | 28 | # It is especially useful in conjunction with 29 | # :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 30 | # process can pass a DistributedSampler instance as a DataLoader sampler, 31 | # and load a subset of the original dataset that is exclusive to it. 32 | 33 | # .. note:: 34 | # Dataset is assumed to be of constant size. 35 | 36 | # Arguments: 37 | # dataset: Dataset used for sampling. 38 | # num_replicas (optional): Number of processes participating in 39 | # distributed training. 40 | # rank (optional): Rank of the current process within num_replicas. 
41 | # """ 42 | 43 | # def __init__(self, dataset, index_set, num_replicas=None, rank=None, allowRepeat=True): 44 | # if num_replicas is None: 45 | # if not dist.is_available(): 46 | # raise RuntimeError("Requires distributed package to be available") 47 | # num_replicas = dist.get_world_size() 48 | # if rank is None: 49 | # if not dist.is_available(): 50 | # raise RuntimeError("Requires distributed package to be available") 51 | # rank = dist.get_rank() 52 | # self.dataset = dataset 53 | # self.num_replicas = num_replicas 54 | # self.rank = rank 55 | # self.epoch = 0 56 | # self.index_set = index_set 57 | # self.allowRepeat = allowRepeat 58 | # if self.allowRepeat: 59 | # self.num_samples = int(math.ceil(len(self.index_set) * 1.0 / self.num_replicas)) 60 | # self.total_size = self.num_samples * self.num_replicas 61 | # else: 62 | # self.num_samples = int(math.ceil((len(self.index_set)-self.rank) * 1.0 / self.num_replicas)) 63 | # self.total_size = len(self.index_set) 64 | 65 | # def __iter__(self): 66 | # # deterministically shuffle based on epoch 67 | # g = torch.Generator() 68 | # g.manual_seed(self.epoch) 69 | # indices = torch.randperm(len(self.index_set), generator=g).tolist() 70 | # #To access valid indices 71 | # #indices = self.index_set[indices] 72 | # # add extra samples to make it evenly divisible 73 | # if self.allowRepeat: 74 | # indices += indices[:(self.total_size - len(indices))] 75 | 76 | # assert len(indices) == self.total_size 77 | 78 | # # subsample 79 | # indices = self.index_set[indices[self.rank:self.total_size:self.num_replicas]] 80 | # assert len(indices) == self.num_samples, "len(indices): {} and self.num_samples: {}"\ 81 | # .format(len(indices), self.num_samples) 82 | 83 | # return iter(indices) 84 | 85 | # def __len__(self): 86 | # return self.num_samples 87 | 88 | # def set_epoch(self, epoch): 89 | # self.epoch = epoch 90 | -------------------------------------------------------------------------------- 
# --------------------------------------------------------------------------------
# /pycls/datasets/simclr_augment.py
# --------------------------------------------------------------------------------
# Modified from the source: https://github.com/sthalles/PyTorch-BYOL
# Previous owner of this file: Thalles Silva

from torchvision import transforms
from .utils.gaussian_blur import GaussianBlur


def get_simclr_ops(input_shape, s=1):
    """Return the SimCLR data-augmentation ops as a list of transforms.

    Args:
        input_shape: image side length in pixels (a scalar); the Gaussian blur
            kernel size is 10% of it.
        s: colour-jitter strength multiplier.
    """
    color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
    ops = [transforms.RandomHorizontalFlip(),
           transforms.RandomApply([color_jitter], p=0.8),
           transforms.RandomGrayscale(p=0.2),
           GaussianBlur(kernel_size=int(0.1 * input_shape)), ]
    return ops


# --------------------------------------------------------------------------------
# /pycls/datasets/tiny_imagenet.py
# --------------------------------------------------------------------------------
import os
import numpy as np
from PIL import Image

import torch
import torchvision.datasets as datasets

from typing import Any


class TinyImageNet(datasets.ImageFolder):
    """`Tiny ImageNet Classification Dataset.

    Args:
        root (string): Root directory of the ImageNet Dataset.
        split (string, optional): The dataset split, supports ``train``, or ``val``.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        test_transform (callable, optional): transform used instead of
            ``transform`` while ``no_aug`` is True.
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.

    Attributes:
        classes (list): Class-name string per class (comma-separated synonyms from words.txt).
        class_to_idx (dict): Dict with items (class_name, class_index), one per synonym.
        wnids (list): List of the WordNet IDs.
        wnid_to_idx (dict): Dict with items (wordnet_id, class_index).
        samples (list): List of (image path, class_index) tuples
        targets (list): The class_index value for each image in the dataset
    """
    def __init__(self, root: str, split: str = 'train', transform=None, test_transform=None, **kwargs: Any) -> None:
        self.root = root
        self.test_transform = test_transform
        self.no_aug = False

        assert self.check_root(), "Something is wrong with the Tiny ImageNet dataset path. Download the official dataset zip from http://cs231n.stanford.edu/tiny-imagenet-200.zip and unzip it inside {}.".format(self.root)
        self.split = datasets.utils.verify_str_arg(split, "split", ("train", "val"))
        wnid_to_classes = self.load_wnid_to_classes()

        super(TinyImageNet, self).__init__(self.split_folder, **kwargs)
        self.transform = transform
        self.wnids = self.classes
        self.wnid_to_idx = self.class_to_idx
        self.classes = [wnid_to_classes[wnid] for wnid in self.wnids]
        # Each words.txt entry is a comma-separated list of synonyms. Map every
        # synonym to the index (iterating the raw string would map single characters).
        self.class_to_idx = {name.strip(): idx
                             for idx, names in enumerate(self.classes)
                             for name in names.split(',')}
        # Tiny ImageNet val directory structure is not similar to that of train's
        # So a custom loading function is necessary
        if self.split == 'val':
            self.root = root
            self.imgs, self.targets = self.load_val_data()
            self.samples = list(zip(self.imgs, self.targets))
            self.root = os.path.join(self.root, 'val')

    # Split folder is used for the 'super' call. Since val directory is not structured like the train,
    # we simply use train's structure to get all classes and other stuff
    @property
    def split_folder(self) -> str:
        return os.path.join(self.root, 'train')

    def load_val_data(self):
        """Read ``val/val_annotations.txt`` and return (image paths, class indices)."""
        wnid_set = set(self.wnids)  # O(1) membership instead of scanning the list per line
        imgs, targets = [], []
        with open(os.path.join(self.root, 'val', 'val_annotations.txt'), 'r') as file:
            for line in file:
                fields = line.split('\t')
                if len(fields) >= 2 and fields[1] in wnid_set:
                    img_file, wnid = fields[:2]
                    imgs.append(os.path.join(self.root, 'val', 'images', img_file))
                    targets.append(wnid)
        targets = np.array([self.wnid_to_idx[wnid] for wnid in targets])
        return imgs, targets

    def load_wnid_to_classes(self):
        """Parse ``words.txt`` (``wnid<TAB>names``) into a ``{wnid: names}`` dict."""
        with open(os.path.join(self.root, 'words.txt'), 'r') as file:
            lines = [x.split('\t') for x in file.readlines()]
        return {x[0]: x[1].strip() for x in lines}

    def check_root(self):
        """Return True if ``self.root`` contains the Tiny ImageNet entries this class reads.

        Unrelated extra files (e.g. the downloaded zip) no longer cause a false
        failure. A missing root now returns False instead of raising.
        """
        required = ('words.txt', 'train', 'val')
        return os.path.isdir(self.root) and all(
            os.path.exists(os.path.join(self.root, name)) for name in required)

    def __getitem__(self, index: int):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.no_aug:
            if self.test_transform is not None:
                sample = self.test_transform(sample)
        else:
            if self.transform is not None:
                sample = self.transform(sample)
        # Honour a ``target_transform`` passed through **kwargs (was silently ignored).
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target


# --------------------------------------------------------------------------------
# /pycls/datasets/utils/__init__.py  (empty file)
# https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/pycls/datasets/utils/__init__.py
# --------------------------------------------------------------------------------

# --------------------------------------------------------------------------------
# /pycls/datasets/utils/gaussian_blur.py
# --------------------------------------------------------------------------------
# Owner of this file: Thalles Silva
# Source: https://github.com/sthalles/PyTorch-BYOL
import torch
from torchvision import transforms
import torch.nn as nn
import numpy as np


class GaussianBlur(object):
    """blur a single image on CPU

    Applies a separable Gaussian blur (horizontal then vertical 1-D conv) with
    a sigma drawn uniformly from [0.1, 2.0] on every call. The kernel size is
    forced to be odd: ``2 * (kernel_size // 2) + 1``.
    """

    def __init__(self, kernel_size):
        radias = kernel_size // 2
        kernel_size = radias * 2 + 1
        self.blur_h = nn.Conv2d(3, 3, kernel_size=(kernel_size, 1),
                                stride=1, padding=0, bias=False, groups=3)
        self.blur_v = nn.Conv2d(3, 3, kernel_size=(1, kernel_size),
                                stride=1, padding=0, bias=False, groups=3)
        self.k = kernel_size
        self.r = radias

        self.blur = nn.Sequential(
            nn.ReflectionPad2d(radias),
            self.blur_h,
            self.blur_v
        )

        self.pil_to_tensor = transforms.ToTensor()
        self.tensor_to_pil = transforms.ToPILImage()

    def __call__(self, img):
        img = self.pil_to_tensor(img).unsqueeze(0)

        sigma = np.random.uniform(0.1, 2.0)
        x = np.arange(-self.r, self.r + 1)
        x = np.exp(-np.power(x, 2) / (2 * sigma * sigma))
        x = x / x.sum()
        x = torch.from_numpy(x).view(1, -1).repeat(3, 1)

        self.blur_h.weight.data.copy_(x.view(3, 1, self.k, 1))
        self.blur_v.weight.data.copy_(x.view(3, 1, 1, self.k))

        with torch.no_grad():
            img = self.blur(img)
            img = img.squeeze()

        img = self.tensor_to_pil(img)

        return img


# --------------------------------------------------------------------------------
# /pycls/datasets/utils/helpers.py
# --------------------------------------------------------------------------------
import csv
import os


def load_imageset(path, set_name):
    """
    Returns the image set `set_name` present at `path` as a list.
    Keyword arguments:
    path -- path to data folder
    set_name -- image set name - labeled or unlabeled.

    Only the first column of each CSV row is returned; empty rows are skipped.
    """
    # ``os`` was used here without being imported, and the file handle was never
    # closed; the ``with`` block closes it.
    with open(os.path.join(path, set_name + '.csv'), 'rt', newline='') as f:
        return [row[0] for row in csv.reader(f) if row]


# --------------------------------------------------------------------------------
# /pycls/models/__init__.py
# --------------------------------------------------------------------------------
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
7 | 8 | """Expose model constructors.""" 9 | 10 | from pycls.models.resnet import * 11 | from pycls.models.vgg import * 12 | -------------------------------------------------------------------------------- /pycls/models/alexnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 4 | from typing import Any 5 | 6 | 7 | __all__ = ['AlexNet', 'alexnet'] 8 | 9 | 10 | model_urls = { 11 | 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth', 12 | } 13 | 14 | 15 | class AlexNet(nn.Module): 16 | ''' 17 | AlexNet modified (features) for CIFAR10. Source: https://github.com/icpm/pytorch-cifar10/blob/master/models/AlexNet.py. 18 | ''' 19 | def __init__(self, num_classes: int = 1000, use_dropout=False) -> None: 20 | super(AlexNet, self).__init__() 21 | self.use_dropout = use_dropout 22 | self.features = nn.Sequential( 23 | nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), 24 | nn.ReLU(inplace=True), 25 | nn.MaxPool2d(kernel_size=2), 26 | nn.Conv2d(64, 192, kernel_size=3, padding=1), 27 | nn.ReLU(inplace=True), 28 | nn.MaxPool2d(kernel_size=2), 29 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 30 | nn.ReLU(inplace=True), 31 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 32 | nn.ReLU(inplace=True), 33 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 34 | nn.ReLU(inplace=True), 35 | nn.MaxPool2d(kernel_size=2), 36 | ) 37 | # self.avgpool = nn.AdaptiveAvgPool2d((6, 6)) 38 | self.fc_block = nn.Sequential( 39 | nn.Linear(256 * 2 * 2, 4096, bias=False), 40 | nn.BatchNorm1d(4096), 41 | nn.ReLU(inplace=True), 42 | nn.Linear(4096, 4096, bias=False), 43 | nn.BatchNorm1d(4096), 44 | nn.ReLU(inplace=True), 45 | ) 46 | self.classifier = nn.Sequential( 47 | nn.Linear(4096, num_classes), 48 | ) 49 | self.penultimate_active = False 50 | self.drop = nn.Dropout(p=0.5) 51 | 52 | def forward(self, x: torch.Tensor) -> 
torch.Tensor: 53 | x = self.features(x) 54 | # x = self.avgpool(x) 55 | z = torch.flatten(x, 1) 56 | if self.use_dropout: 57 | x = self.drop(x) 58 | z = self.fc_block(z) 59 | x = self.classifier(z) 60 | if self.penultimate_active: 61 | return z, x 62 | return x 63 | 64 | 65 | def alexnet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> AlexNet: 66 | r"""AlexNet model architecture from the 67 | `"One weird trick..." `_ paper. 68 | Args: 69 | pretrained (bool): If True, returns a model pre-trained on ImageNet 70 | progress (bool): If True, displays a progress bar of the download to stderr 71 | """ 72 | model = AlexNet(**kwargs) 73 | if pretrained: 74 | state_dict = load_state_dict_from_url(model_urls['alexnet'], 75 | progress=progress) 76 | model.load_state_dict(state_dict) 77 | return model -------------------------------------------------------------------------------- /pycls/models/vaal_model.py: -------------------------------------------------------------------------------- 1 | # This file is modified taken from a code implementation shared with me by Prateek Munjal et al., authors of the paper https://arxiv.org/abs/2002.09564 2 | # GitHub: https://github.com/PrateekMunjal 3 | # ---------------------------------------------------------- 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.nn.init as init 9 | 10 | import pycls.utils.logging as lu 11 | 12 | logger = lu.get_logger(__name__) 13 | 14 | class View(nn.Module): 15 | def __init__(self, size): 16 | super(View, self).__init__() 17 | self.size = size 18 | 19 | def forward(self, tensor): 20 | return tensor.view(self.size) 21 | 22 | 23 | class VAE(nn.Module): 24 | """Encoder-Decoder architecture for both WAE-MMD and WAE-GAN.""" 25 | def __init__(self, device_id,z_dim=32, nc=3): 26 | super(VAE, self).__init__() 27 | print("============================") 28 | logger.info("============================") 29 | print(f"Constructing VAE MODEL with 
z_dim: {z_dim}") 30 | logger.info(f"Constructing VAE MODEL with z_dim: {z_dim}") 31 | print("============================") 32 | logger.info("============================") 33 | self.encode_shape = int(z_dim/16) 34 | if z_dim == 32: 35 | self.decode_shape = 4 36 | elif z_dim == 64: 37 | self.decode_shape = 8 38 | else: 39 | self.decode_shape = 4 40 | self.device_id = device_id 41 | self.z_dim = z_dim 42 | self.nc = nc 43 | self.encoder = nn.Sequential( 44 | nn.Conv2d(nc, 128, 4, 2, 1, bias=False), # B, 128, 32, 32 or B, 128, 64, 64 45 | nn.BatchNorm2d(128), 46 | nn.ReLU(True), 47 | nn.Conv2d(128, 256, 4, 2, 1, bias=False), # B, 256, 16, 16 or B, 256, 32, 32 48 | nn.BatchNorm2d(256), 49 | nn.ReLU(True), 50 | nn.Conv2d(256, 512, 4, 2, 1, bias=False), # B, 512, 8, 8 or B, 512, 16, 16 51 | nn.BatchNorm2d(512), 52 | nn.ReLU(True), 53 | nn.Conv2d(512, 1024, 4, 2, 1, bias=False), # B, 1024, 4, 4 or B, 1024, 8, 8 54 | nn.BatchNorm2d(1024), 55 | nn.ReLU(True), 56 | View((-1, 1024*self.encode_shape*self.encode_shape)), # B, 1024*4*4 or B, 1024, 4, 4 57 | ) 58 | 59 | self.fc_mu = nn.Linear(1024*self.encode_shape*self.encode_shape, z_dim) # B, z_dim 60 | self.fc_logvar = nn.Linear(1024*self.encode_shape*self.encode_shape, z_dim) # B, z_dim 61 | self.decoder = nn.Sequential( 62 | nn.Linear(z_dim, 1024*self.decode_shape*self.decode_shape), # B, 1024*8*8 63 | View((-1, 1024, self.decode_shape, self.decode_shape)), # B, 1024, 8, 8 64 | nn.ConvTranspose2d(1024, 512, 4, 2, 1, bias=False), # B, 512, 16, 16 65 | nn.BatchNorm2d(512), 66 | nn.ReLU(True), 67 | nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False), # B, 256, 32, 32 68 | nn.BatchNorm2d(256), 69 | nn.ReLU(True), 70 | nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False), # B, 128, 64, 64 71 | nn.BatchNorm2d(128), 72 | nn.ReLU(True), 73 | nn.ConvTranspose2d(128, nc, 1), # B, nc, 64, 64 74 | ) 75 | self.weight_init() 76 | 77 | def weight_init(self): 78 | for block in self._modules: 79 | try: 80 | for m in self._modules[block]: 81 | 
kaiming_init(m) 82 | except: 83 | kaiming_init(block) 84 | 85 | def forward(self, x): 86 | z = self._encode(x) 87 | mu, logvar = self.fc_mu(z), self.fc_logvar(z) 88 | z = self.reparameterize(mu, logvar) 89 | x_recon = self._decode(z) 90 | 91 | return x_recon, z, mu, logvar 92 | 93 | def reparameterize(self, mu, logvar): 94 | stds = (0.5 * logvar).exp() 95 | epsilon = torch.randn(*mu.size()) 96 | #if mu.is_cuda: 97 | stds, epsilon = stds.cuda(), epsilon.cuda() 98 | mu = mu.cuda() 99 | latents = epsilon * stds + mu 100 | return latents 101 | 102 | def _encode(self, x): 103 | return self.encoder(x) 104 | 105 | def _decode(self, z): 106 | return self.decoder(z) 107 | 108 | class clf_Discriminator(nn.Module): 109 | """ 110 | Model to circumvent the need of learning a separate task learner by 111 | combining the discriminator and task classifier. 112 | """ 113 | def __init__(self, z_dim=10, n_classes=10): 114 | super(clf_Discriminator, self).__init__() 115 | self.z_dim = z_dim 116 | self.n_classes = n_classes 117 | 118 | self.net = nn.Sequential( 119 | nn.Linear(z_dim, 512), 120 | nn.ReLU(True), 121 | nn.Linear(512, 512), 122 | nn.ReLU(True), 123 | nn.Linear(512, 512), 124 | nn.ReLU(True), 125 | nn.Linear(512, 512), 126 | nn.ReLU(True) 127 | ) 128 | 129 | self.disc_out = nn.Sequential( 130 | nn.Linear(512, 1), 131 | nn.Sigmoid() 132 | ) 133 | 134 | self.clf_out = nn.Sequential( 135 | nn.Linear(512, self.n_classes), 136 | ) 137 | 138 | self.weight_init() 139 | 140 | def weight_init(self): 141 | for block in self._modules: 142 | for m in self._modules[block]: 143 | kaiming_init(m) 144 | 145 | def forward(self, z): 146 | z = self.net(z) 147 | disc_out = self.disc_out(z) 148 | clf_out = self.clf_out(z) 149 | return disc_out, clf_out 150 | 151 | 152 | class Discriminator(nn.Module): 153 | """Adversary architecture(Discriminator) for WAE-GAN.""" 154 | def __init__(self, z_dim=10): 155 | super(Discriminator, self).__init__() 156 | self.z_dim = z_dim 157 | 
self.penultimate_active = False 158 | self.net = nn.Sequential( 159 | nn.Linear(z_dim, 512), 160 | nn.ReLU(True), 161 | nn.Linear(512, 512), 162 | nn.ReLU(True), 163 | nn.Linear(512, 512), 164 | nn.ReLU(True), 165 | nn.Linear(512, 512), 166 | nn.ReLU(True) 167 | ) 168 | 169 | self.out = nn.Sequential( 170 | nn.Linear(512, 1), 171 | nn.Sigmoid() 172 | ) 173 | 174 | self.weight_init() 175 | 176 | def weight_init(self): 177 | for block in self._modules: 178 | for m in self._modules[block]: 179 | kaiming_init(m) 180 | 181 | def forward(self, z): 182 | z = self.net(z) 183 | if self.penultimate_active: 184 | return z, self.out(z) 185 | return self.out(z) 186 | 187 | class WGAN_Discriminator(nn.Module): 188 | """Adversary architecture(Discriminator) for WAE-GAN.""" 189 | def __init__(self, z_dim=10): 190 | super(WGAN_Discriminator, self).__init__() 191 | self.z_dim = z_dim 192 | self.penultimate_active = False 193 | self.net = nn.Sequential( 194 | nn.Linear(z_dim, 512), 195 | nn.ReLU(True), 196 | nn.Linear(512, 512), 197 | nn.ReLU(True), 198 | nn.Linear(512, 512), 199 | nn.ReLU(True), 200 | nn.Linear(512, 512), 201 | nn.ReLU(True) 202 | ) 203 | 204 | self.out = nn.Sequential( 205 | nn.Linear(512, 1), 206 | #nn.Sigmoid() 207 | ) 208 | 209 | self.weight_init() 210 | 211 | def weight_init(self): 212 | for block in self._modules: 213 | for m in self._modules[block]: 214 | kaiming_init(m) 215 | 216 | def forward(self, z): 217 | z = self.net(z) 218 | if self.penultimate_active: 219 | return z, self.out(z) 220 | return self.out(z) 221 | 222 | 223 | def kaiming_init(m): 224 | if isinstance(m, (nn.Linear, nn.Conv2d)): 225 | init.kaiming_normal_(m.weight) 226 | if m.bias is not None: 227 | m.bias.data.fill_(0) 228 | elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)): 229 | m.weight.data.fill_(1) 230 | if m.bias is not None: 231 | m.bias.data.fill_(0) 232 | 233 | 234 | def normal_init(m, mean, std): 235 | if isinstance(m, (nn.Linear, nn.Conv2d)): 236 | m.weight.data.normal_(mean, 
std) 237 | if m.bias.data is not None: 238 | m.bias.data.zero_() 239 | elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)): 240 | m.weight.data.fill_(1) 241 | if m.bias.data is not None: 242 | m.bias.data.zero_() 243 | -------------------------------------------------------------------------------- /pycls/models/vgg.py: -------------------------------------------------------------------------------- 1 | # Original Source: https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py 2 | 3 | # This file is modified to meet the implementation by Prateek Munjal et al., authors of the paper https://arxiv.org/abs/2002.09564 4 | # GitHub: https://github.com/PrateekMunjal 5 | # ---------------------------------------------------------- 6 | 7 | 8 | import torch 9 | import torch.nn as nn 10 | # from .utils import load_state_dict_from_url 11 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 12 | from typing import Union, List, Dict, Any, cast 13 | 14 | import pycls.utils.logging as lu 15 | logger = lu.get_logger(__name__) 16 | 17 | __all__ = [ 18 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 19 | 'vgg19_bn', 'vgg19', 20 | ] 21 | 22 | 23 | model_urls = { 24 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 25 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 26 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 27 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 28 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 29 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', 30 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', 31 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', 32 | } 33 | 34 | 35 | class VGG(nn.Module): 36 | 37 | def __init__( 38 | self, 39 | features: nn.Module, 40 | num_classes: int = 1000, 41 | init_weights: bool = True, 42 | 
use_dropout = False 43 | ) -> None: 44 | super(VGG, self).__init__() 45 | self.penultimate_active = False 46 | if self.num_classes == 1000: 47 | logger.warning("This open source implementation is only suitable for small datasets like CIFAR. \ 48 | For Imagenet we recommend to use Resnet based models") 49 | self.penultimate_dim = 4096 50 | else: 51 | self.penultimate_dim = 512 52 | self.features = features 53 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 54 | self.penultimate_act = nn.Sequential( 55 | nn.Linear(512 * 7 * 7, 4096), 56 | nn.ReLU(True), 57 | nn.Dropout(), 58 | nn.Linear(4096, self.penultimate_dim), 59 | nn.ReLU(True), 60 | #nn.Dropout(), 61 | ) 62 | self.classifier = nn.Sequential( 63 | nn.Linear(self.penultimate_dim, num_classes) 64 | ) 65 | 66 | # Describe model with source code link 67 | self.description = "VGG16 model loaded from VAAL source code with penultimate dim as {}".format(self.penultimate_dim) 68 | self.source_link = "https://github.com/sinhasam/vaal/blob/master/vgg.py" 69 | 70 | if init_weights: 71 | self._initialize_weights() 72 | 73 | def forward(self, x: torch.Tensor) -> torch.Tensor: 74 | x = self.features(x) 75 | x = self.avgpool(x) 76 | x = torch.flatten(x, 1) 77 | z = self.penultimate_act(x) 78 | x = self.classifier(z) 79 | if self.penultimate_active: 80 | return z, x 81 | return x 82 | 83 | def _initialize_weights(self) -> None: 84 | for m in self.modules(): 85 | if isinstance(m, nn.Conv2d): 86 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 87 | if m.bias is not None: 88 | nn.init.constant_(m.bias, 0) 89 | elif isinstance(m, nn.BatchNorm2d): 90 | nn.init.constant_(m.weight, 1) 91 | nn.init.constant_(m.bias, 0) 92 | elif isinstance(m, nn.Linear): 93 | nn.init.normal_(m.weight, 0, 0.01) 94 | nn.init.constant_(m.bias, 0) 95 | 96 | 97 | def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential: 98 | layers: List[nn.Module] = [] 99 | in_channels = 3 100 | for v in cfg: 101 | if v 
== 'M': 102 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 103 | else: 104 | v = cast(int, v) 105 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 106 | if batch_norm: 107 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 108 | else: 109 | layers += [conv2d, nn.ReLU(inplace=True)] 110 | in_channels = v 111 | return nn.Sequential(*layers) 112 | 113 | 114 | cfgs: Dict[str, List[Union[str, int]]] = { 115 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 116 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 117 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 118 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 119 | } 120 | 121 | 122 | def _vgg(arch: str, cfg: str, batch_norm: bool, pretrained: bool, progress: bool, **kwargs: Any) -> VGG: 123 | if pretrained: 124 | kwargs['init_weights'] = False 125 | model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) 126 | if pretrained: 127 | state_dict = load_state_dict_from_url(model_urls[arch], 128 | progress=progress) 129 | model.load_state_dict(state_dict) 130 | return model 131 | 132 | 133 | def vgg11(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 134 | r"""VGG 11-layer model (configuration "A") from 135 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._ 136 | Args: 137 | pretrained (bool): If True, returns a model pre-trained on ImageNet 138 | progress (bool): If True, displays a progress bar of the download to stderr 139 | """ 140 | return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs) 141 | 142 | 143 | def vgg11_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 144 | r"""VGG 11-layer model (configuration "A") with batch normalization 145 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._ 146 | Args: 147 
| pretrained (bool): If True, returns a model pre-trained on ImageNet 148 | progress (bool): If True, displays a progress bar of the download to stderr 149 | """ 150 | return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs) 151 | 152 | 153 | def vgg13(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 154 | r"""VGG 13-layer model (configuration "B") 155 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._ 156 | Args: 157 | pretrained (bool): If True, returns a model pre-trained on ImageNet 158 | progress (bool): If True, displays a progress bar of the download to stderr 159 | """ 160 | return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs) 161 | 162 | 163 | def vgg13_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 164 | r"""VGG 13-layer model (configuration "B") with batch normalization 165 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._ 166 | Args: 167 | pretrained (bool): If True, returns a model pre-trained on ImageNet 168 | progress (bool): If True, displays a progress bar of the download to stderr 169 | """ 170 | return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs) 171 | 172 | 173 | def vgg16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 174 | r"""VGG 16-layer model (configuration "D") 175 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._ 176 | Args: 177 | pretrained (bool): If True, returns a model pre-trained on ImageNet 178 | progress (bool): If True, displays a progress bar of the download to stderr 179 | """ 180 | return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs) 181 | 182 | 183 | def vgg16_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 184 | r"""VGG 16-layer model (configuration "D") with batch normalization 185 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._ 186 | Args: 187 | pretrained (bool): If 
True, returns a model pre-trained on ImageNet 188 | progress (bool): If True, displays a progress bar of the download to stderr 189 | """ 190 | return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs) 191 | 192 | 193 | def vgg19(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 194 | r"""VGG 19-layer model (configuration "E") 195 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._ 196 | Args: 197 | pretrained (bool): If True, returns a model pre-trained on ImageNet 198 | progress (bool): If True, displays a progress bar of the download to stderr 199 | """ 200 | return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs) 201 | 202 | 203 | def vgg19_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 204 | r"""VGG 19-layer model (configuration 'E') with batch normalization 205 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._ 206 | Args: 207 | pretrained (bool): If True, returns a model pre-trained on ImageNet 208 | progress (bool): If True, displays a progress bar of the download to stderr 209 | """ 210 | return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs) -------------------------------------------------------------------------------- /pycls/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/pycls/utils/__init__.py -------------------------------------------------------------------------------- /pycls/utils/benchmark.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
"""Benchmarking functions."""

# NOTE(review): the repository tree at the top of this dump has no
# pycls/core/logging.py, pycls/core/timer.py or pycls/datasets/loader.py
# (logging.py and timer.py live in pycls/utils/), and the pycls/utils/logging.py
# in this repo defines no dump_log_data. These imports look like leftovers from
# upstream pycls and would fail to resolve here. TODO confirm before using this module.
import pycls.core.logging as logging
import pycls.core.net as net
import pycls.datasets.loader as loader
import torch
import torch.cuda.amp as amp
from pycls.core.config import cfg
from pycls.core.timer import Timer


logger = logging.get_logger(__name__)


@torch.no_grad()
def compute_time_eval(model):
    """Computes precise model forward test time using dummy data.

    Runs cfg.PREC_TIME.WARMUP_ITER + cfg.PREC_TIME.NUM_ITER forward passes on
    an all-zeros batch of shape (TEST.BATCH_SIZE / NUM_GPUS, 3, IM_SIZE,
    IM_SIZE) with autograd disabled. The timer is reset when the loop index
    reaches WARMUP_ITER, so warmup passes are excluded from the average.

    Args:
        model: network to time; it is switched to eval mode.

    Returns:
        ``Timer.average_time`` over the passes timed after the reset.
    """
    # Use eval mode
    model.eval()
    # Generate a dummy mini-batch and copy data to GPU
    # (image size comes from cfg.TRAIN.IM_SIZE even for this test-time measurement)
    im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TEST.BATCH_SIZE / cfg.NUM_GPUS)
    inputs = torch.zeros(batch_size, 3, im_size, im_size).cuda(non_blocking=False)
    # Compute precise forward pass time
    timer = Timer()
    total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
    for cur_iter in range(total_iter):
        # Reset the timers after the warmup phase
        if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
            timer.reset()
        # Forward
        timer.tic()
        model(inputs)
        # CUDA kernels run asynchronously; wait for them so toc() measures the
        # actual forward pass rather than just the kernel launches
        torch.cuda.synchronize()
        timer.toc()
    return timer.average_time


def compute_time_train(model, loss_fun):
    """Computes precise model forward + backward time using dummy data.

    Runs WARMUP_ITER + NUM_ITER forward/backward passes on a random batch in
    train mode (optionally under autocast when cfg.TRAIN.MIXED_PRECISION is
    set). The BatchNorm2d running statistics are cloned before the loop and
    put back afterwards, so the dummy batches do not alter them.

    NOTE: no zero_grad() is called, so each backward() accumulates into the
    parameters' .grad tensors; callers should clear gradients afterwards if
    they care.

    Args:
        model: network to time; it is switched to train mode.
        loss_fun: callable invoked as ``loss_fun(preds, labels_one_hot)``.

    Returns:
        ``(forward_average_time, backward_average_time)`` from the two Timers.
    """
    # Use train mode
    model.train()
    # Generate a dummy mini-batch and copy data to GPU
    im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS)
    inputs = torch.rand(batch_size, 3, im_size, im_size).cuda(non_blocking=False)
    labels = torch.zeros(batch_size, dtype=torch.int64).cuda(non_blocking=False)
    # presumably converts class indices to (smoothed) one-hot targets — TODO confirm
    labels_one_hot = net.smooth_one_hot_labels(labels)
    # Cache BatchNorm2D running stats
    bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
    bn_stats = [[bn.running_mean.clone(), bn.running_var.clone()] for bn in bns]
    # Create a GradScaler for mixed precision training
    scaler = amp.GradScaler(enabled=cfg.TRAIN.MIXED_PRECISION)
    # Compute precise forward backward pass time
    fw_timer, bw_timer = Timer(), Timer()
    total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
    for cur_iter in range(total_iter):
        # Reset the timers after the warmup phase
        if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
            fw_timer.reset()
            bw_timer.reset()
        # Forward
        fw_timer.tic()
        with amp.autocast(enabled=cfg.TRAIN.MIXED_PRECISION):
            preds = model(inputs)
            loss = loss_fun(preds, labels_one_hot)
        # Wait for the asynchronous forward kernels before stopping the timer
        torch.cuda.synchronize()
        fw_timer.toc()
        # Backward
        bw_timer.tic()
        scaler.scale(loss).backward()
        torch.cuda.synchronize()
        bw_timer.toc()
    # Restore BatchNorm2D running stats (rebind to the clones taken before the loop)
    for bn, (mean, var) in zip(bns, bn_stats):
        bn.running_mean, bn.running_var = mean, var
    return fw_timer.average_time, bw_timer.average_time


def compute_time_loader(data_loader):
    """Computes loader time.

    Fetches up to WARMUP_ITER + NUM_ITER batches, capped at len(data_loader),
    timing each next() call. If the cap leaves no more than WARMUP_ITER
    iterations, the reset never happens and warmup fetches are included in the
    average.

    Returns:
        ``Timer.average_time`` over the timed fetches.
    """
    timer = Timer()
    # presumably reshuffles / sets the sampler epoch to 0 (upstream pycls helper) — TODO confirm
    loader.shuffle(data_loader, 0)
    data_loader_iterator = iter(data_loader)
    total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
    total_iter = min(total_iter, len(data_loader))
    for cur_iter in range(total_iter):
        if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
            timer.reset()
        timer.tic()
        next(data_loader_iterator)
        timer.toc()
    return timer.average_time


def compute_time_model(model, loss_fun):
    """Times model.

    Measures test forward, train forward and train backward times (no data
    loader) and logs them as a single "iter_times" record. Returns None.
    """
    logger.info("Computing model timings only...")
    # Compute timings
    test_fw_time = compute_time_eval(model)
    train_fw_time, train_bw_time = compute_time_train(model, loss_fun)
    train_fw_bw_time = train_fw_time + train_bw_time
    # Output iter timing
    iter_times = {
        "test_fw_time": test_fw_time,
        "train_fw_time": train_fw_time,
        "train_bw_time": train_bw_time,
        "train_fw_bw_time": train_fw_bw_time,
    }
    logger.info(logging.dump_log_data(iter_times, "iter_times"))


def compute_time_full(model, loss_fun, train_loader, test_loader):
    """Times model and data loader.

    Logs three things: per-iteration timings ("iter_times"); per-epoch
    estimates ("epoch_times"), i.e. each per-iteration time multiplied by the
    length of the matching loader; and the data-loader overhead relative to
    the train forward+backward time. Returns None.

    NOTE: the overhead line divides by train_fw_bw_time and would raise
    ZeroDivisionError if that time were 0.
    """
    logger.info("Computing model and loader timings...")
    # Compute timings
    test_fw_time = compute_time_eval(model)
    train_fw_time, train_bw_time = compute_time_train(model, loss_fun)
    train_fw_bw_time = train_fw_time + train_bw_time
    train_loader_time = compute_time_loader(train_loader)
    # Output iter timing
    iter_times = {
        "test_fw_time": test_fw_time,
        "train_fw_time": train_fw_time,
        "train_bw_time": train_bw_time,
        "train_fw_bw_time": train_fw_bw_time,
        "train_loader_time": train_loader_time,
    }
    logger.info(logging.dump_log_data(iter_times, "iter_times"))
    # Output epoch timing
    epoch_times = {
        "test_fw_time": test_fw_time * len(test_loader),
        "train_fw_time": train_fw_time * len(train_loader),
        "train_bw_time": train_bw_time * len(train_loader),
        "train_fw_bw_time": train_fw_bw_time * len(train_loader),
        "train_loader_time": train_loader_time * len(train_loader),
    }
    logger.info(logging.dump_log_data(epoch_times, "epoch_times"))
    # Compute data loader overhead (assuming DATA_LOADER.NUM_WORKERS>1)
    # Only loader time in excess of compute time counts as overhead
    overhead = max(0, train_loader_time - train_fw_bw_time) / train_fw_bw_time
    logger.info("Overhead of data loader is {:.2f}%".format(overhead * 100))
# --------------------------------------------------------------------------------
# /pycls/utils/checkpoint.py:
# --------------------------------------------------------------------------------
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Functions that handle saving and loading of checkpoints."""

import os
from shutil import copyfile

# import pycls.core.distributed as dist
import torch
from pycls.core.config import cfg
from pycls.core.net import unwrap_model


# Common prefix for checkpoint file names
_NAME_PREFIX = "model_epoch_"

# Checkpoints directory name
_DIR_NAME = "checkpoints"


def get_checkpoint_dir(episode_dir):
    """Return the ``checkpoints`` sub-directory of *episode_dir*."""
    return os.path.join(episode_dir, _DIR_NAME)


def get_checkpoint(epoch, episode_dir):
    """Return the path of the untagged checkpoint file for *epoch*.

    The file is named ``model_epoch_<epoch:04d>.pyth`` and lives directly in
    *episode_dir*, not in the ``checkpoints`` sub-directory.
    """
    name = "{}{:04d}.pyth".format(_NAME_PREFIX, epoch)
    return os.path.join(episode_dir, name)


def get_checkpoint_best(episode_dir):
    """Return the path of the best-model checkpoint, ``<episode_dir>/model.pyth``."""
    return os.path.join(episode_dir, "model.pyth")


def get_last_checkpoint(episode_dir):
    """Return the checkpoint in ``<episode_dir>/checkpoints`` whose name sorts last.

    Names embed a zero-padded epoch number, so among untagged checkpoints the
    last name corresponds to the highest epoch.

    Raises:
        FileNotFoundError: if the checkpoints directory does not exist.
        IndexError: if it contains no checkpoint files.
    """
    checkpoint_dir = get_checkpoint_dir(episode_dir)
    checkpoints = [f for f in os.listdir(checkpoint_dir) if _NAME_PREFIX in f]
    last_checkpoint_name = sorted(checkpoints)[-1]
    return os.path.join(checkpoint_dir, last_checkpoint_name)


def has_checkpoint(episode_dir):
    """Return True if ``<episode_dir>/checkpoints`` holds any checkpoint file."""
    checkpoint_dir = get_checkpoint_dir(episode_dir)
    if not os.path.exists(checkpoint_dir):
        return False
    return any(_NAME_PREFIX in f for f in os.listdir(checkpoint_dir))


def save_checkpoint(info, model_state, optimizer_state, epoch, cfg):
    """Save a checkpoint and return the path it was written to.

    The file is named ``<info>_model_epoch_<epoch:04d>.pyth`` and written
    directly into ``cfg.EPISODE_DIR``, which is created if missing.

    NOTE(review): files are saved in EPISODE_DIR itself, while
    get_last_checkpoint / has_checkpoint search EPISODE_DIR/checkpoints, so
    those helpers will not find checkpoints written here.

    Args:
        info: tag prepended to the file name.
        model_state: model state dict, stored under "model_state".
        optimizer_state: optimizer state dict, stored under "optimizer_state".
        epoch: epoch number, stored in the checkpoint and used in the file name.
        cfg: config node providing EPISODE_DIR and dump().

    Returns:
        The path of the written checkpoint file.
    """
    # Save checkpoints only from the master process
    # if not dist.is_master_proc():
    #     return
    # Ensure that the checkpoint dir exists
    os.makedirs(cfg.EPISODE_DIR, exist_ok=True)

    # Record the state
    checkpoint = {
        "epoch": epoch,
        "model_state": model_state,
        "optimizer_state": optimizer_state,
        "cfg": cfg.dump(),
    }
    # Build the tagged name locally. The module-level _NAME_PREFIX used to be
    # rewritten temporarily here, and a failing torch.save left it corrupted,
    # breaking every later file name and prefix search.
    name = "{}_{}{:04d}.pyth".format(info, _NAME_PREFIX, epoch)
    checkpoint_file = os.path.join(cfg.EPISODE_DIR, name)
    torch.save(checkpoint, checkpoint_file)
    return checkpoint_file


def load_checkpoint(checkpoint_file, model, optimizer=None):
    """Load model (and optionally optimizer) state from *checkpoint_file*.

    Tensors are mapped to CPU on load. torch.load unpickles the file, so only
    load checkpoints from trusted sources.

    Returns:
        The same *model* object, with its state dict updated in place.

    Raises:
        AssertionError: if *checkpoint_file* does not exist.
    """
    err_str = "Checkpoint '{}' not found"
    assert os.path.exists(checkpoint_file), err_str.format(checkpoint_file)
    checkpoint = torch.load(checkpoint_file, map_location="cpu")
    unwrap_model(model).load_state_dict(checkpoint["model_state"])
    if optimizer:
        optimizer.load_state_dict(checkpoint["optimizer_state"])
    return model


def delete_checkpoints(checkpoint_dir=None, keep="all", episode_dir=None):
    """Delete unneeded checkpoints; *keep* can be "all", "last", or "none".

    Args:
        checkpoint_dir: directory to clean. If not given, the ``checkpoints``
            sub-directory of *episode_dir* is used.
        keep: "all" deletes nothing, "last" keeps the file whose name sorts
            last, and "none" deletes every checkpoint file.
        episode_dir: used only when *checkpoint_dir* is not given.

    Returns:
        The number of files deleted.

    Raises:
        ValueError: if a directory is needed but neither argument is given.
    """
    assert keep in ["all", "last", "none"], "Invalid keep setting: {}".format(keep)
    if keep == "all":
        return 0
    if not checkpoint_dir:
        # Previously get_checkpoint_dir() was called with no argument here,
        # which always raised TypeError.
        if episode_dir is None:
            raise ValueError("delete_checkpoints needs checkpoint_dir or episode_dir")
        checkpoint_dir = get_checkpoint_dir(episode_dir)
    if not os.path.exists(checkpoint_dir):
        return 0
    checkpoints = sorted(f for f in os.listdir(checkpoint_dir) if _NAME_PREFIX in f)
    if keep == "last":
        checkpoints = checkpoints[:-1]
    for checkpoint in checkpoints:
        os.remove(os.path.join(checkpoint_dir, checkpoint))
    return len(checkpoints)
# --------------------------------------------------------------------------------
# /pycls/utils/distributed.py:
# --------------------------------------------------------------------------------
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Distributed helpers."""

import multiprocessing
import os
import random
import signal
import threading
import traceback

import torch
from pycls.core.config import cfg


# Make work w recent PyTorch versions (https://github.com/pytorch/pytorch/issues/37377)
os.environ["MKL_THREADING_LAYER"] = "GNU"


def is_master_proc():
    """Determines if the current process is the master process.

    Master process is responsible for logging, writing and loading checkpoints. In
    the multi GPU setting, we assign the master role to the rank 0 process. When
    training using a single GPU, there is a single process which is considered master.

    NOTE: with cfg.NUM_GPUS > 1 this calls torch.distributed.get_rank(), so the
    default process group must already be initialized.
    """
    return cfg.NUM_GPUS == 1 or torch.distributed.get_rank() == 0


def init_process_group(proc_rank, world_size, port):
    """Initializes the default process group.

    Args:
        proc_rank: rank of this process; also used as its CUDA device index.
        world_size: total number of processes in the group.
        port: TCP port on cfg.HOST used for the rendezvous.
    """
    # Set the GPU to use
    torch.cuda.set_device(proc_rank)
    # Initialize the process group
    torch.distributed.init_process_group(
        backend=cfg.DIST_BACKEND,
        init_method="tcp://{}:{}".format(cfg.HOST, port),
        world_size=world_size,
        rank=proc_rank,
    )


def destroy_process_group():
    """Destroys the default process group."""
    torch.distributed.destroy_process_group()


def scaled_all_reduce(tensors):
    """Performs the scaled all_reduce operation on the provided tensors.

    The input tensors are modified in-place. Currently supports only the sum
    reduction operator. The reduced values are scaled by the inverse size of the
    process group (equivalent to cfg.NUM_GPUS).

    Returns:
        The same *tensors* sequence, after in-place reduction and scaling.
    """
    # There is no need for reduction in the single-proc case
    if cfg.NUM_GPUS == 1:
        return tensors
    # Queue every reduction before waiting on any, so they can overlap
    reductions = []
    for tensor in tensors:
        reduction = torch.distributed.all_reduce(tensor, async_op=True)
        reductions.append(reduction)
    # Wait for reductions to finish
    for reduction in reductions:
        reduction.wait()
    # Scale the results
    for tensor in tensors:
        tensor.mul_(1.0 / cfg.NUM_GPUS)
    return tensors


class ChildException(Exception):
    """Wraps an exception from a child process (carries its formatted traceback)."""

    def __init__(self, child_trace):
        super(ChildException, self).__init__(child_trace)


class ErrorHandler(object):
    """Multiprocessing error handler (based on fairseq's).

    Listens for errors in child processes and propagates the tracebacks to the parent.
    A daemon thread blocks on the shared queue. When a child reports an error,
    the thread puts it back and sends SIGUSR1 to this process. The SIGUSR1 handler
    then interrupts the children and raises ChildException with the child's traceback.
    """

    def __init__(self, error_queue):
        # Shared error queue
        self.error_queue = error_queue
        # Children processes sharing the error queue
        self.children_pids = []
        # Start a thread listening to errors
        self.error_listener = threading.Thread(target=self.listen, daemon=True)
        self.error_listener.start()
        # Register the signal handler
        signal.signal(signal.SIGUSR1, self.signal_handler)

    def add_child(self, pid):
        """Registers a child process."""
        self.children_pids.append(pid)

    def listen(self):
        """Listens for errors in the error queue."""
        # Wait until there is an error in the queue
        child_trace = self.error_queue.get()
        # Put the error back for the signal handler
        self.error_queue.put(child_trace)
        # Invoke the signal handler
        os.kill(os.getpid(), signal.SIGUSR1)

    def signal_handler(self, _sig_num, _stack_frame):
        """Signal handler: interrupt all children, then re-raise the child's error."""
        # Kill children processes
        for pid in self.children_pids:
            os.kill(pid, signal.SIGINT)
        # Propagate the error from the child process
        raise ChildException(self.error_queue.get())


def run(proc_rank, world_size, port, error_queue, fun, fun_args, fun_kwargs):
    """Runs a function from a child process.

    Joins the process group, calls ``fun(*fun_args, **fun_kwargs)``, and
    reports any exception (other than KeyboardInterrupt, which the parent
    triggers via SIGINT) to the parent through *error_queue*.
    """
    try:
        # Initialize the process group
        init_process_group(proc_rank, world_size, port)
        # Run the function
        fun(*fun_args, **fun_kwargs)
    except KeyboardInterrupt:
        # Killed by the parent process
        pass
    except Exception:
        # Propagate exception to the parent process
        error_queue.put(traceback.format_exc())
    finally:
        # Only tear down a group that was actually created. If
        # init_process_group itself failed, destroy_process_group() would raise
        # a second, unrelated error from this finally block.
        if torch.distributed.is_initialized():
            destroy_process_group()


def multi_proc_run(num_proc, fun, fun_args=(), fun_kwargs=None):
    """Runs a function in a multi-proc setting (unless num_proc == 1).

    With num_proc > 1, starts one child per rank on a shared random port from
    cfg.PORT_RANGE. Waits for all of them; a child's error is re-raised in this
    process as ChildException by the ErrorHandler.
    """
    # There is no need for multi-proc in the single-proc case
    fun_kwargs = fun_kwargs if fun_kwargs else {}
    if num_proc == 1:
        fun(*fun_args, **fun_kwargs)
        return
    # Handle errors from training subprocesses
    error_queue = multiprocessing.SimpleQueue()
    error_handler = ErrorHandler(error_queue)
    # Get a random port to use (without using global random number generator)
    port = random.Random().randint(cfg.PORT_RANGE[0], cfg.PORT_RANGE[1])
    # Run each training subprocess
    ps = []
    for i in range(num_proc):
        p_i = multiprocessing.Process(
            target=run, args=(i, num_proc, port, error_queue, fun, fun_args, fun_kwargs)
        )
        ps.append(p_i)
        p_i.start()
        error_handler.add_child(p_i.pid)
    # Wait for each subprocess to finish
    for p in ps:
        p.join()
# --------------------------------------------------------------------------------
# /pycls/utils/io.py:
# --------------------------------------------------------------------------------
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""IO utilities (adapted from Detectron)"""

import logging
import os
import re
import sys
from urllib import request as urlrequest


logger = logging.getLogger(__name__)

_PYCLS_BASE_URL = "https://dl.fbaipublicfiles.com/pycls"


def cache_url(url_or_file, cache_dir, base_url=_PYCLS_BASE_URL):
    """Download the file specified by the URL to the cache_dir and return the path to
    the cached file. If the argument is not a URL, simply return it as is.

    Only http(s) URLs are treated as URLs. The part of the URL after *base_url*
    is mirrored under *cache_dir*. An existing cached file is returned without
    downloading. The download is written to a temporary ``.part`` file and
    renamed into place only once complete, so an interrupted download is never
    mistaken for a cached copy on a later call.

    Raises:
        AssertionError: if the URL does not start with *base_url*.
    """
    is_url = re.match(r"^(?:http)s?://", url_or_file, re.IGNORECASE) is not None
    if not is_url:
        return url_or_file
    url = url_or_file
    assert url.startswith(base_url), "url must start with: {}".format(base_url)
    # Replace only the leading base URL, not any later repetition of it
    cache_file_path = url.replace(base_url, cache_dir, 1)
    if os.path.exists(cache_file_path):
        return cache_file_path
    # exist_ok avoids a race between an exists() check and makedirs()
    os.makedirs(os.path.dirname(cache_file_path), exist_ok=True)
    logger.info("Downloading remote file {} to {}".format(url, cache_file_path))
    tmp_file_path = cache_file_path + ".part"
    try:
        download_url(url, tmp_file_path)
        os.replace(tmp_file_path, cache_file_path)
    finally:
        # Remove a partial download left behind by a failure
        if os.path.exists(tmp_file_path):
            os.remove(tmp_file_path)
    return cache_file_path


def _progress_bar(count, total):
    """Report download progress as a one-line text bar on stdout.

    Credit:
    https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console/27871113
    """
    bar_len = 60
    filled_len = int(round(bar_len * count / float(total)))
    percents = round(100.0 * count / float(total), 1)
    bar = "=" * filled_len + "-" * (bar_len - filled_len)
    sys.stdout.write(
        " [{}] {}% of {:.1f}MB file \r".format(bar, percents, total / 1024 / 1024)
    )
    sys.stdout.flush()
    if count >= total:
        sys.stdout.write("\n")


def download_url(url, dst_file_path, chunk_size=8192, progress_hook=_progress_bar):
    """Download url and write it to dst_file_path; return the number of bytes read.

    *progress_hook*, if given, is called after each chunk as
    ``progress_hook(bytes_so_far, total_size)``. It is skipped when the
    response has no usable Content-Length header, because the total is unknown.

    Credit:
    https://stackoverflow.com/questions/2028517/python-urllib2-progress-hook
    """
    req = urlrequest.Request(url)
    # The context manager closes the connection even if reading/writing fails
    with urlrequest.urlopen(req) as response:
        content_length = response.info().get("Content-Length")
        # The header is optional (e.g. chunked transfer encoding)
        total_size = int(content_length.strip()) if content_length else 0
        bytes_so_far = 0
        with open(dst_file_path, "wb") as f:
            while True:
                chunk = response.read(chunk_size)
                if not chunk:
                    break
                bytes_so_far += len(chunk)
                if progress_hook and total_size > 0:
                    progress_hook(bytes_so_far, total_size)
                f.write(chunk)
    return bytes_so_far
# --------------------------------------------------------------------------------
# /pycls/utils/logging.py:
# --------------------------------------------------------------------------------
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Logging."""

import builtins
import decimal
import logging
import os
import simplejson
import sys

from pycls.core.config import cfg

# import pycls.utils.distributed as du

# Show filename and line number in logs
_FORMAT = '[%(asctime)s %(filename)s: %(lineno)3d]: %(message)s'

# Log file name (for cfg.LOG_DEST = 'file')
_LOG_FILE = 'stdout.log'

# Printed json stats lines will be tagged w/ this
_TAG = 'json_stats: '


def _suppress_print():
    """Suppresses printing from the current process by replacing builtins.print."""
    def ignore(*_objects, **_kwargs):
        # Accept and discard every print() keyword (sep, end, file, flush). The
        # old signature spelled them _sep/_end/_file/_flush, so a real call such
        # as print(x, end='') raised TypeError instead of being silenced.
        pass
    builtins.print = ignore


def setup_logging(cfg):
    """Sets up the logging.

    Replaces any root-logger handlers and configures INFO-level logging with
    _FORMAT. Output goes to stdout if cfg.LOG_DEST == 'stdout', otherwise to
    cfg.EXP_DIR/stdout.log.
    """
    # Enable logging only for the master process
    # (the master-process check is disabled, so the else branch never runs)
    # if du.is_master_proc():
    if True:
        # Clear the root logger to prevent any existing logging config
        # (e.g. set by another module) from messing with our setup
        logging.root.handlers = []
        # Construct logging configuration
        logging_config = {
            'level': logging.INFO,
            'format': _FORMAT,
            'datefmt': '%Y-%m-%d %H:%M:%S'
        }
        # Log either to stdout or to a file
        if cfg.LOG_DEST == 'stdout':
            logging_config['stream'] = sys.stdout
        else:
            logging_config['filename'] = os.path.join(cfg.EXP_DIR, _LOG_FILE)
        # Configure logging
        logging.basicConfig(**logging_config)
    else:
        _suppress_print()


def get_logger(name):
    """Retrieves the logger."""
    return logging.getLogger(name)


def log_json_stats(stats):
    """Logs *stats* as one sorted-key JSON line prefixed with _TAG."""
    # Decimal + string workaround for having fixed len float vals in logs
    stats = {
        k: decimal.Decimal('{:.12f}'.format(v)) if isinstance(v, float) else v
        for k, v in stats.items()
    }
    json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True)
    logger = get_logger(__name__)
    logger.info('{:s}{:s}'.format(_TAG, json_stats))


def load_json_stats(log_file):
    """Loads json_stats from a single log file.

    Returns:
        A list with one parsed dict per line that contains _TAG, in file order.
    """
    with open(log_file, 'r') as f:
        lines = f.readlines()
    json_lines = [line[line.find(_TAG) + len(_TAG):] for line in lines if _TAG in line]
    json_stats = [simplejson.loads(line) for line in json_lines]
    return json_stats


def parse_json_stats(log, row_type, key):
    """Extract values corresponding to row_type/key out of log.

    For 'iter' and 'epoch' keys, values formatted as 'cur/total' are reduced to
    the integer 'cur'.
    """
    vals = [row[key] for row in log if row['_type'] == row_type and key in row]
    if key == 'iter' or key == 'epoch':
        vals = [int(val.split('/')[0]) for val in vals]
    return vals


def get_log_files(log_dir, name_filter=''):
    """Get all log files in directory containing subdirs of trained models.

    Returns:
        ``(files, names)``: two parallel tuples containing the stdout.log paths
        that exist and their sub-directory names. Both are empty if there are none.
    """
    names = [n for n in sorted(os.listdir(log_dir)) if name_filter in n]
    files = [os.path.join(log_dir, n, _LOG_FILE) for n in names]
    f_n_ps = [(f, n) for (f, n) in zip(files, names) if os.path.exists(f)]
    if not f_n_ps:
        # zip(*[]) yields nothing, so unpacking it into two names would raise ValueError
        return (), ()
    files, names = zip(*f_n_ps)
    return files, names
# --------------------------------------------------------------------------------
# /pycls/utils/meters.py:
# --------------------------------------------------------------------------------
#This file is modified from official pycls repository to adapt in AL settings.
2 | """Meters.""" 3 | 4 | from collections import deque 5 | 6 | import datetime 7 | import numpy as np 8 | 9 | from pycls.core.config import cfg 10 | from pycls.utils.timer import Timer 11 | 12 | import pycls.utils.logging as lu 13 | import pycls.utils.metrics as metrics 14 | 15 | 16 | def eta_str(eta_td): 17 | """Converts an eta timedelta to a fixed-width string format.""" 18 | days = eta_td.days 19 | hrs, rem = divmod(eta_td.seconds, 3600) 20 | mins, secs = divmod(rem, 60) 21 | return '{0:02},{1:02}:{2:02}:{3:02}'.format(days, hrs, mins, secs) 22 | 23 | 24 | class ScalarMeter(object): 25 | """Measures a scalar value (adapted from Detectron).""" 26 | 27 | def __init__(self, window_size): 28 | self.deque = deque(maxlen=window_size) 29 | self.total = 0.0 30 | self.count = 0 31 | 32 | def reset(self): 33 | self.deque.clear() 34 | self.total = 0.0 35 | self.count = 0 36 | 37 | def add_value(self, value): 38 | self.deque.append(value) 39 | self.count += 1 40 | self.total += value 41 | 42 | def get_win_median(self): 43 | return np.median(self.deque) 44 | 45 | def get_win_avg(self): 46 | return np.mean(self.deque) 47 | 48 | def get_global_avg(self): 49 | return self.total / self.count 50 | 51 | 52 | class TrainMeter(object): 53 | """Measures training stats.""" 54 | 55 | def __init__(self, epoch_iters): 56 | self.epoch_iters = epoch_iters 57 | self.max_iter = cfg.OPTIM.MAX_EPOCH * epoch_iters 58 | self.iter_timer = Timer() 59 | self.loss = ScalarMeter(cfg.LOG_PERIOD) 60 | self.loss_total = 0.0 61 | self.lr = None 62 | # Current minibatch errors (smoothed over a window) 63 | self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD) 64 | # Number of misclassified examples 65 | self.num_top1_mis = 0 66 | self.num_samples = 0 67 | 68 | def reset(self, timer=False): 69 | if timer: 70 | self.iter_timer.reset() 71 | self.loss.reset() 72 | self.loss_total = 0.0 73 | self.lr = None 74 | self.mb_top1_err.reset() 75 | self.num_top1_mis = 0 76 | self.num_samples = 0 77 | 78 | def 
iter_tic(self): 79 | self.iter_timer.tic() 80 | 81 | def iter_toc(self): 82 | self.iter_timer.toc() 83 | 84 | def update_stats(self, top1_err, loss, lr, mb_size): 85 | # Current minibatch stats 86 | self.mb_top1_err.add_value(top1_err) 87 | self.loss.add_value(loss) 88 | self.lr = lr 89 | # Aggregate stats 90 | self.num_top1_mis += top1_err * mb_size 91 | self.loss_total += loss * mb_size 92 | self.num_samples += mb_size 93 | 94 | 95 | def get_iter_stats(self, cur_epoch, cur_iter): 96 | eta_sec = self.iter_timer.average_time * ( 97 | self.max_iter - (cur_epoch * self.epoch_iters + cur_iter + 1) 98 | ) 99 | eta_td = datetime.timedelta(seconds=int(eta_sec)) 100 | mem_usage = metrics.gpu_mem_usage() 101 | stats = { 102 | '_type': 'train_iter', 103 | 'epoch': '{}/{}'.format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH), 104 | 'iter': '{}/{}'.format(cur_iter + 1, self.epoch_iters), 105 | 'top1_err': self.mb_top1_err.get_win_median(), 106 | 'loss': self.loss.get_win_median(), 107 | 'lr': self.lr, 108 | } 109 | return stats 110 | 111 | def log_iter_stats(self, cur_epoch, cur_iter): 112 | if (cur_iter + 1) % cfg.LOG_PERIOD != 0: 113 | return 114 | stats = self.get_iter_stats(cur_epoch, cur_iter) 115 | lu.log_json_stats(stats) 116 | 117 | def get_epoch_stats(self, cur_epoch): 118 | eta_sec = self.iter_timer.average_time * ( 119 | self.max_iter - (cur_epoch + 1) * self.epoch_iters 120 | ) 121 | eta_td = datetime.timedelta(seconds=int(eta_sec)) 122 | mem_usage = metrics.gpu_mem_usage() 123 | top1_err = self.num_top1_mis / self.num_samples 124 | avg_loss = self.loss_total / self.num_samples 125 | stats = { 126 | '_type': 'train_epoch', 127 | 'epoch': '{}/{}'.format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH), 128 | 'top1_err': top1_err, 129 | 'loss': avg_loss, 130 | 'lr': self.lr, 131 | } 132 | return stats 133 | 134 | def log_epoch_stats(self, cur_epoch): 135 | stats = self.get_epoch_stats(cur_epoch) 136 | lu.log_json_stats(stats) 137 | 138 | 139 | class TestMeter(object): 140 | """Measures 
testing stats.""" 141 | 142 | def __init__(self, max_iter): 143 | self.max_iter = max_iter 144 | self.iter_timer = Timer() 145 | # Current minibatch errors (smoothed over a window) 146 | self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD) 147 | # Min errors (over the full test set) 148 | self.min_top1_err = 100.0 149 | # Number of misclassified examples 150 | self.num_top1_mis = 0 151 | self.num_samples = 0 152 | 153 | def reset(self, min_errs=False): 154 | if min_errs: 155 | self.min_top1_err = 100.0 156 | self.iter_timer.reset() 157 | self.mb_top1_err.reset() 158 | self.num_top1_mis = 0 159 | self.num_samples = 0 160 | 161 | def iter_tic(self): 162 | self.iter_timer.tic() 163 | 164 | def iter_toc(self): 165 | self.iter_timer.toc() 166 | 167 | def update_stats(self, top1_err, mb_size): 168 | self.mb_top1_err.add_value(top1_err) 169 | self.num_top1_mis += top1_err * mb_size 170 | self.num_samples += mb_size 171 | 172 | def get_iter_stats(self, cur_epoch, cur_iter): 173 | mem_usage = metrics.gpu_mem_usage() 174 | iter_stats = { 175 | '_type': 'test_iter', 176 | 'epoch': '{}/{}'.format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH), 177 | 'iter': '{}/{}'.format(cur_iter + 1, self.max_iter), 178 | 'top1_err': self.mb_top1_err.get_win_median(), 179 | } 180 | return iter_stats 181 | 182 | def log_iter_stats(self, cur_epoch, cur_iter): 183 | if (cur_iter + 1) % cfg.LOG_PERIOD != 0: 184 | return 185 | stats = self.get_iter_stats(cur_epoch, cur_iter) 186 | lu.log_json_stats(stats) 187 | 188 | def get_epoch_stats(self, cur_epoch): 189 | top1_err = self.num_top1_mis / self.num_samples 190 | self.min_top1_err = min(self.min_top1_err, top1_err) 191 | mem_usage = metrics.gpu_mem_usage() 192 | stats = { 193 | '_type': 'test_epoch', 194 | 'epoch': '{}/{}'.format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH), 195 | 'top1_err': top1_err, 196 | 'min_top1_err': self.min_top1_err 197 | } 198 | return stats 199 | 200 | def log_epoch_stats(self, cur_epoch): 201 | stats = self.get_epoch_stats(cur_epoch) 202 | 
lu.log_json_stats(stats) 203 | 204 | class ValMeter(object): 205 | """Measures Validation stats.""" 206 | 207 | def __init__(self, max_iter): 208 | self.max_iter = max_iter 209 | self.iter_timer = Timer() 210 | # Current minibatch errors (smoothed over a window) 211 | self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD) 212 | # Min errors (over the full Val set) 213 | self.min_top1_err = 100.0 214 | # Number of misclassified examples 215 | self.num_top1_mis = 0 216 | self.num_samples = 0 217 | 218 | def reset(self, min_errs=False): 219 | if min_errs: 220 | self.min_top1_err = 100.0 221 | self.iter_timer.reset() 222 | self.mb_top1_err.reset() 223 | self.num_top1_mis = 0 224 | self.num_samples = 0 225 | 226 | def iter_tic(self): 227 | self.iter_timer.tic() 228 | 229 | def iter_toc(self): 230 | self.iter_timer.toc() 231 | 232 | def update_stats(self, top1_err, mb_size): 233 | self.mb_top1_err.add_value(top1_err) 234 | self.num_top1_mis += top1_err * mb_size 235 | self.num_samples += mb_size 236 | 237 | def get_iter_stats(self, cur_epoch, cur_iter): 238 | mem_usage = metrics.gpu_mem_usage() 239 | iter_stats = { 240 | '_type': 'Val_iter', 241 | 'epoch': '{}/{}'.format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH), 242 | 'iter': '{}/{}'.format(cur_iter + 1, self.max_iter), 243 | 'top1_err': self.mb_top1_err.get_win_median(), 244 | } 245 | return iter_stats 246 | 247 | def log_iter_stats(self, cur_epoch, cur_iter): 248 | if (cur_iter + 1) % cfg.LOG_PERIOD != 0: 249 | return 250 | stats = self.get_iter_stats(cur_epoch, cur_iter) 251 | lu.log_json_stats(stats) 252 | 253 | def get_epoch_stats(self, cur_epoch): 254 | top1_err = self.num_top1_mis / self.num_samples 255 | self.min_top1_err = min(self.min_top1_err, top1_err) 256 | mem_usage = metrics.gpu_mem_usage() 257 | stats = { 258 | '_type': 'Val_epoch', 259 | 'epoch': '{}/{}'.format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH), 260 | 'top1_err': top1_err, 261 | 'min_top1_err': self.min_top1_err 262 | } 263 | return stats 264 | 265 | def 
log_epoch_stats(self, cur_epoch): 266 | stats = self.get_epoch_stats(cur_epoch) 267 | lu.log_json_stats(stats) -------------------------------------------------------------------------------- /pycls/utils/metrics.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """Functions for computing metrics.""" 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | 14 | from pycls.core.config import cfg 15 | 16 | # Number of bytes in a megabyte 17 | _B_IN_MB = 1024 * 1024 18 | 19 | 20 | def topks_correct(preds, labels, ks): 21 | """Computes the number of top-k correct predictions for each k.""" 22 | assert preds.size(0) == labels.size(0), \ 23 | 'Batch dim of predictions and labels must match' 24 | # Find the top max_k predictions for each sample 25 | _top_max_k_vals, top_max_k_inds = torch.topk( 26 | preds, max(ks), dim=1, largest=True, sorted=True 27 | ) 28 | # (batch_size, max_k) -> (max_k, batch_size) 29 | top_max_k_inds = top_max_k_inds.t() 30 | # (batch_size, ) -> (max_k, batch_size) 31 | rep_max_k_labels = labels.view(1, -1).expand_as(top_max_k_inds) 32 | # (i, j) = 1 if top i-th prediction for the j-th sample is correct 33 | top_max_k_correct = top_max_k_inds.eq(rep_max_k_labels) 34 | # Compute the number of topk correct predictions for each k 35 | topks_correct = [ 36 | top_max_k_correct[:k, :].reshape(-1).float().sum() for k in ks 37 | ] 38 | return topks_correct 39 | 40 | 41 | def topk_errors(preds, labels, ks): 42 | """Computes the top-k error for each k.""" 43 | num_topks_correct = topks_correct(preds, labels, ks) 44 | return [(1.0 - x / preds.size(0)) * 100.0 for x in num_topks_correct] 45 | 46 | 47 | def topk_accuracies(preds, labels, ks): 48 | """Computes the top-k accuracy 
for each k.""" 49 | num_topks_correct = topks_correct(preds, labels, ks) 50 | return [(x / preds.size(0)) * 100.0 for x in num_topks_correct] 51 | 52 | 53 | def params_count(model): 54 | """Computes the number of parameters.""" 55 | return np.sum([p.numel() for p in model.parameters()]).item() 56 | 57 | 58 | def flops_count(model): 59 | """Computes the number of flops statically.""" 60 | h, w = cfg.TRAIN.IM_SIZE, cfg.TRAIN.IM_SIZE 61 | count = 0 62 | for n, m in model.named_modules(): 63 | if isinstance(m, nn.Conv2d): 64 | if 'se.' in n: 65 | count += m.in_channels * m.out_channels + m.bias.numel() 66 | continue 67 | h_out = (h + 2 * m.padding[0] - m.kernel_size[0]) // m.stride[0] + 1 68 | w_out = (w + 2 * m.padding[1] - m.kernel_size[1]) // m.stride[1] + 1 69 | count += np.prod([ 70 | m.weight.numel(), 71 | h_out, w_out 72 | ]) 73 | if '.proj' not in n: 74 | h, w = h_out, w_out 75 | elif isinstance(m, nn.MaxPool2d): 76 | h = (h + 2 * m.padding - m.kernel_size) // m.stride + 1 77 | w = (w + 2 * m.padding - m.kernel_size) // m.stride + 1 78 | elif isinstance(m, nn.Linear): 79 | count += m.in_features * m.out_features 80 | return count.item() 81 | 82 | 83 | def gpu_mem_usage(): 84 | """Computes the GPU memory usage for the current device (MB).""" 85 | mem_usage_bytes = torch.cuda.max_memory_allocated() 86 | return mem_usage_bytes / _B_IN_MB 87 | -------------------------------------------------------------------------------- /pycls/utils/net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | """Functions for manipulating networks.""" 9 | 10 | import itertools 11 | import math 12 | import torch 13 | import torch.nn as nn 14 | 15 | from pycls.core.config import cfg 16 | 17 | 18 | def init_weights(m): 19 | """Performs ResNet-style weight initialization.""" 20 | if isinstance(m, nn.Conv2d): 21 | # Note that there is no bias due to BN 22 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 23 | m.weight.data.normal_(mean=0.0, std=math.sqrt(2.0 / fan_out)) 24 | elif isinstance(m, nn.BatchNorm2d): 25 | zero_init_gamma = ( 26 | hasattr(m, 'final_bn') and m.final_bn and 27 | cfg.BN.ZERO_INIT_FINAL_GAMMA 28 | ) 29 | m.weight.data.fill_(0.0 if zero_init_gamma else 1.0) 30 | m.bias.data.zero_() 31 | elif isinstance(m, nn.Linear): 32 | m.weight.data.normal_(mean=0.0, std=0.01) 33 | m.bias.data.zero_() 34 | 35 | 36 | @torch.no_grad() 37 | def compute_precise_bn_stats(model, loader): 38 | """Computes precise BN stats on training data.""" 39 | # Compute the number of minibatches to use 40 | num_iter = min(cfg.BN.NUM_SAMPLES_PRECISE // loader.batch_size, len(loader)) 41 | # Retrieve the BN layers 42 | bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)] 43 | # Initialize stats storage 44 | mus = [torch.zeros_like(bn.running_mean) for bn in bns] 45 | sqs = [torch.zeros_like(bn.running_var) for bn in bns] 46 | # Remember momentum values 47 | moms = [bn.momentum for bn in bns] 48 | # Disable momentum 49 | for bn in bns: 50 | bn.momentum = 1.0 51 | # Accumulate the stats across the data samples 52 | for inputs, _labels in itertools.islice(loader, num_iter): 53 | model(inputs.cuda()) 54 | # Accumulate the stats for each BN layer 55 | for i, bn in enumerate(bns): 56 | m, v = bn.running_mean, bn.running_var 57 | sqs[i] += (v + m * m) / num_iter 58 | mus[i] += m / num_iter 59 | # Set the stats and restore momentum values 60 | for i, bn in enumerate(bns): 61 | bn.running_var = sqs[i] - mus[i] * mus[i] 62 | bn.running_mean = mus[i] 63 | 
bn.momentum = moms[i] 64 | 65 | 66 | def reset_bn_stats(model): 67 | """Resets running BN stats.""" 68 | for m in model.modules(): 69 | if isinstance(m, torch.nn.BatchNorm2d): 70 | m.reset_running_stats() 71 | 72 | 73 | def drop_connect(x, drop_ratio): 74 | """Drop connect (adapted from DARTS).""" 75 | keep_ratio = 1.0 - drop_ratio 76 | mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device) 77 | mask.bernoulli_(keep_ratio) 78 | x.div_(keep_ratio) 79 | x.mul_(mask) 80 | return x 81 | 82 | 83 | def get_flat_weights(model): 84 | """Gets all model weights as a single flat vector.""" 85 | return torch.cat([p.data.view(-1, 1) for p in model.parameters()], 0) 86 | 87 | 88 | def set_flat_weights(model, flat_weights): 89 | """Sets all model weights from a single flat vector.""" 90 | k = 0 91 | for p in model.parameters(): 92 | n = p.data.numel() 93 | p.data.copy_(flat_weights[k:(k + n)].view_as(p.data)) 94 | k += n 95 | assert k == flat_weights.numel() 96 | -------------------------------------------------------------------------------- /pycls/utils/plotting.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | """Plotting functions.""" 9 | 10 | import colorlover as cl 11 | import matplotlib.pyplot as plt 12 | import plotly.graph_objs as go 13 | import plotly.offline as offline 14 | import pycls.core.logging as logging 15 | 16 | 17 | def get_plot_colors(max_colors, color_format="pyplot"): 18 | """Generate colors for plotting.""" 19 | colors = cl.scales["11"]["qual"]["Paired"] 20 | if max_colors > len(colors): 21 | colors = cl.to_rgb(cl.interp(colors, max_colors)) 22 | if color_format == "pyplot": 23 | return [[j / 255.0 for j in c] for c in cl.to_numeric(colors)] 24 | return colors 25 | 26 | 27 | def prepare_plot_data(log_files, names, metric="top1_err"): 28 | """Load logs and extract data for plotting error curves.""" 29 | plot_data = [] 30 | for file, name in zip(log_files, names): 31 | d, data = {}, logging.sort_log_data(logging.load_log_data(file)) 32 | for phase in ["train", "test"]: 33 | x = data[phase + "_epoch"]["epoch_ind"] 34 | y = data[phase + "_epoch"][metric] 35 | d["x_" + phase], d["y_" + phase] = x, y 36 | d[phase + "_label"] = "[{:5.2f}] ".format(min(y) if y else 0) + name 37 | plot_data.append(d) 38 | assert len(plot_data) > 0, "No data to plot" 39 | return plot_data 40 | 41 | 42 | def plot_error_curves_plotly(log_files, names, filename, metric="top1_err"): 43 | """Plot error curves using plotly and save to file.""" 44 | plot_data = prepare_plot_data(log_files, names, metric) 45 | colors = get_plot_colors(len(plot_data), "plotly") 46 | # Prepare data for plots (3 sets, train duplicated w and w/o legend) 47 | data = [] 48 | for i, d in enumerate(plot_data): 49 | s = str(i) 50 | line_train = {"color": colors[i], "dash": "dashdot", "width": 1.5} 51 | line_test = {"color": colors[i], "dash": "solid", "width": 1.5} 52 | data.append( 53 | go.Scatter( 54 | x=d["x_train"], 55 | y=d["y_train"], 56 | mode="lines", 57 | name=d["train_label"], 58 | line=line_train, 59 | legendgroup=s, 60 | visible=True, 61 | showlegend=False, 62 | ) 63 | ) 64 | data.append( 
65 | go.Scatter( 66 | x=d["x_test"], 67 | y=d["y_test"], 68 | mode="lines", 69 | name=d["test_label"], 70 | line=line_test, 71 | legendgroup=s, 72 | visible=True, 73 | showlegend=True, 74 | ) 75 | ) 76 | data.append( 77 | go.Scatter( 78 | x=d["x_train"], 79 | y=d["y_train"], 80 | mode="lines", 81 | name=d["train_label"], 82 | line=line_train, 83 | legendgroup=s, 84 | visible=False, 85 | showlegend=True, 86 | ) 87 | ) 88 | # Prepare layout w ability to toggle 'all', 'train', 'test' 89 | titlefont = {"size": 18, "color": "#7f7f7f"} 90 | vis = [[True, True, False], [False, False, True], [False, True, False]] 91 | buttons = zip(["all", "train", "test"], [[{"visible": v}] for v in vis]) 92 | buttons = [{"label": b, "args": v, "method": "update"} for b, v in buttons] 93 | layout = go.Layout( 94 | title=metric + " vs. epoch
[dash=train, solid=test]", 95 | xaxis={"title": "epoch", "titlefont": titlefont}, 96 | yaxis={"title": metric, "titlefont": titlefont}, 97 | showlegend=True, 98 | hoverlabel={"namelength": -1}, 99 | updatemenus=[ 100 | { 101 | "buttons": buttons, 102 | "direction": "down", 103 | "showactive": True, 104 | "x": 1.02, 105 | "xanchor": "left", 106 | "y": 1.08, 107 | "yanchor": "top", 108 | } 109 | ], 110 | ) 111 | # Create plotly plot 112 | offline.plot({"data": data, "layout": layout}, filename=filename) 113 | 114 | 115 | def plot_error_curves_pyplot(log_files, names, filename=None, metric="top1_err"): 116 | """Plot error curves using matplotlib.pyplot and save to file.""" 117 | plot_data = prepare_plot_data(log_files, names, metric) 118 | colors = get_plot_colors(len(names)) 119 | for ind, d in enumerate(plot_data): 120 | c, lbl = colors[ind], d["test_label"] 121 | plt.plot(d["x_train"], d["y_train"], "--", c=c, alpha=0.8) 122 | plt.plot(d["x_test"], d["y_test"], "-", c=c, alpha=0.8, label=lbl) 123 | plt.title(metric + " vs. epoch\n[dash=train, solid=test]", fontsize=14) 124 | plt.xlabel("epoch", fontsize=14) 125 | plt.ylabel(metric, fontsize=14) 126 | plt.grid(alpha=0.4) 127 | plt.legend() 128 | if filename: 129 | plt.savefig(filename) 130 | plt.clf() 131 | else: 132 | plt.show() 133 | -------------------------------------------------------------------------------- /pycls/utils/timer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | """Timer.""" 9 | 10 | import time 11 | 12 | 13 | class Timer(object): 14 | """A simple timer (adapted from Detectron).""" 15 | 16 | def __init__(self): 17 | self.total_time = None 18 | self.calls = None 19 | self.start_time = None 20 | self.diff = None 21 | self.average_time = None 22 | self.reset() 23 | 24 | def tic(self): 25 | # using time.time as time.clock does not normalize for multithreading 26 | self.start_time = time.time() 27 | 28 | def toc(self): 29 | self.diff = time.time() - self.start_time 30 | self.total_time += self.diff 31 | self.calls += 1 32 | self.average_time = self.total_time / self.calls 33 | 34 | def reset(self): 35 | self.total_time = 0.0 36 | self.calls = 0 37 | self.start_time = 0.0 38 | self.diff = 0.0 39 | self.average_time = 0.0 40 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | black==19.3b0 2 | flake8==3.8.4 3 | isort==4.3.21 4 | matplotlib==3.3.4 5 | numpy 6 | opencv-python==4.2.0.34 7 | torch==1.7.1 8 | torchvision==0.8.2 9 | parameterized 10 | setuptools 11 | simplejson 12 | yacs -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/tools/__init__.py -------------------------------------------------------------------------------- /tools/test_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from datetime import datetime 4 | import argparse 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | # local 13 | 14 | def add_path(path): 15 | if path not in sys.path: 16 | sys.path.insert(0, path) 17 | 
add_path(os.path.abspath('..'))

import pycls.core.builders as model_builder
from pycls.core.config import cfg, dump_cfg
from pycls.datasets.data import Data
import pycls.utils.checkpoint as cu
import pycls.utils.logging as lu
import pycls.utils.metrics as mu
import pycls.utils.net as nu
from pycls.utils.meters import TestMeter

logger = lu.get_logger(__name__)

def argparser():
    """Build the command-line parser; ``--cfg`` (config file path) is required."""
    parser = argparse.ArgumentParser(description='Passive Learning - Image Classification')
    parser.add_argument('--cfg', dest='cfg_file', help='Config file', required=True, type=str)

    return parser

def plot_arrays(x_vals, y_vals, x_name, y_name, dataset_name, out_dir, isDebug=False):
    """Plot ``y_vals`` against ``x_vals`` and save it as ``<x_name>_vs_<y_name>.png`` in ``out_dir``."""
    import matplotlib.pyplot as plt

    temp_name = "{}_vs_{}".format(x_name, y_name)
    plot_path = os.path.join(out_dir, temp_name + ".png")
    plt.xlabel(x_name)
    plt.ylabel(y_name)
    plt.title("Dataset: {}; {}".format(dataset_name, temp_name))
    plt.plot(x_vals, y_vals)

    if isDebug:
        print("plot_saved at : {}".format(plot_path))

    plt.savefig(plot_path)
    plt.close()

def save_plot_values(temp_arrays, temp_names, out_dir, isParallel=True, saveInTextFormat=False, isDebug=True):
    """Save each array in ``temp_arrays`` to ``out_dir/<name>.npy`` (or ``.txt``).

    Args:
        temp_arrays: sequence of array-likes; each is converted with ``np.array``.
            The caller's sequence is not modified.
        temp_names: file stem for each array, matched by position.
        out_dir: destination directory, created if missing.
        isParallel: unused (kept for interface compatibility).
        saveInTextFormat: write with ``np.savetxt`` instead of ``np.save``.
        isDebug: unused (kept for interface compatibility).
    """
    for i, values in enumerate(temp_arrays):
        arr = np.array(values)
        os.makedirs(out_dir, exist_ok=True)
        if saveInTextFormat:
            # "%d" is kept for integer/bool data; applied to floats it would
            # silently truncate them, so those use savetxt's default format.
            is_int = np.issubdtype(arr.dtype, np.integer) or arr.dtype == np.bool_
            fmt = "%d" if is_int else "%.18e"
            np.savetxt(os.path.join(out_dir, temp_names[i] + ".txt"), arr, fmt=fmt)
        else:
            np.save(os.path.join(out_dir, temp_names[i] + ".npy"), arr)

def is_eval_epoch(cur_epoch):
    """Determines if the model should be evaluated at the current epoch."""
    return (
        (cur_epoch + 1) % cfg.TRAIN.EVAL_PERIOD == 0 or
        (cur_epoch + 1) == cfg.OPTIM.MAX_EPOCH
    )


def main(cfg):
    """Evaluate the checkpoint at ``cfg.TEST.MODEL_PATH`` on the test split.

    Output, dataset and checkpoint paths from ``cfg`` are resolved relative to
    the parent of the working directory. The experiment directory is created,
    the config is dumped there, logging is set up, and the test accuracy is
    printed and logged.
    """
    # Getting the output directory ready (default is "/output")
    cfg.OUT_DIR = os.path.join(os.path.abspath('..'), cfg.OUT_DIR)
    os.makedirs(cfg.OUT_DIR, exist_ok=True)
    # Create "DATASET/MODEL TYPE" specific directory
    dataset_out_dir = os.path.join(cfg.OUT_DIR, cfg.DATASET.NAME, cfg.MODEL.TYPE)
    os.makedirs(dataset_out_dir, exist_ok=True)
    # Creating the experiment directory inside the dataset specific directory
    # all logs, labeled, unlabeled, validation sets are stored here
    # E.g., output/CIFAR10/resnet18/{timestamp or cfg.EXP_NAME based on arguments passed}
    if cfg.EXP_NAME == 'auto':
        now = datetime.now()
        # Zero-pad the time fields: unpadded, 1:23:45 and 12:03:45 both became "12345".
        exp_dir = f'{now.year}_{now.month}_{now.day}_{now.hour:02d}{now.minute:02d}{now.second:02d}'
    else:
        exp_dir = cfg.EXP_NAME

    exp_dir = os.path.join(dataset_out_dir, exp_dir)
    if not os.path.exists(exp_dir):
        os.mkdir(exp_dir)
        print("Experiment Directory is {}.\n".format(exp_dir))
    else:
        print("Experiment Directory Already Exists: {}. Reusing it may lead to loss of old logs in the directory.\n".format(exp_dir))
    cfg.EXP_DIR = exp_dir

    # Save the config file in EXP_DIR
    dump_cfg(cfg)

    # Setup Logger
    lu.setup_logging(cfg)

    # Dataset preparing steps
    print("\n======== PREPARING TEST DATA ========\n")
    cfg.DATASET.ROOT_DIR = os.path.join(os.path.abspath('..'), cfg.DATASET.ROOT_DIR)
    data_obj = Data(cfg)
    test_data, test_size = data_obj.getDataset(save_dir=cfg.DATASET.ROOT_DIR, isTrain=False, isDownload=True)

    print("\nDataset {} Loaded Sucessfully. Total Test Size: {}\n".format(cfg.DATASET.NAME, test_size))
    logger.info("Dataset {} Loaded Sucessfully. Total Test Size: {}\n".format(cfg.DATASET.NAME, test_size))

    # Preparing dataloaders for testing
    test_loader = data_obj.getTestLoader(data=test_data, test_batch_size=cfg.TRAIN.BATCH_SIZE, seed_id=cfg.RNG_SEED)

    print("======== TESTING ========\n")
    logger.info("======== TESTING ========\n")
    test_acc = test_model(test_loader, os.path.join(os.path.abspath('..'), cfg.TEST.MODEL_PATH), cfg)
    print("Test Accuracy: {}.\n".format(round(test_acc, 4)))
    logger.info("Test Accuracy {}.\n".format(test_acc))

    print('Check the test accuracy inside {}/stdout.log'.format(cfg.EXP_DIR))

    print("================================\n\n")
    logger.info("================================\n\n")


def test_model(test_loader, checkpoint_file, cfg, cur_episode=0):
    """Build the model from ``cfg``, load ``checkpoint_file`` and return test accuracy (%)."""
    test_meter = TestMeter(len(test_loader))

    model = model_builder.build_model(cfg)
    model = cu.load_checkpoint(checkpoint_file, model)

    test_err = test_epoch(test_loader, model, test_meter, cur_episode)
    test_acc = 100. - test_err

    return test_acc


@torch.no_grad()
def test_epoch(test_loader, model, test_meter, cur_epoch):
    """Evaluates the model on the test set.

    The model is moved to the GPU when one is available. Inputs follow the
    model's device, so CPU-only runs work as well.

    Returns:
        The sample-weighted top-1 error over ``test_loader``, in percent.
    """
    if torch.cuda.is_available():
        model.cuda()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # Enable eval mode
    model.eval()
    test_meter.iter_tic()

    # Batch sizes are scaled by the GPU count (for a multi-GPU all-reduce
    # path that is currently disabled). The scale cancels out of the final
    # ratio; clamping to 1 avoids dividing by zero when NUM_GPUS == 0.
    num_gpus = max(cfg.NUM_GPUS, 1)
    misclassifications = 0.
    totalSamples = 0.

    for cur_iter, (inputs, labels) in enumerate(tqdm(test_loader, desc="Test Data")):
        # Transfer the data to the model's device
        inputs = inputs.to(device).float()
        labels = labels.to(device, non_blocking=True)
        # Compute the predictions
        preds = model(inputs)
        # Only top-1 is used; also asking for top-5 made torch.topk fail on
        # datasets with fewer than five classes.
        top1_err = mu.topk_errors(preds, labels, [1])[0]
        # Copy the errors from GPU to CPU (sync point)
        top1_err = top1_err.item()
        mb_size = inputs.size(0) * num_gpus
        misclassifications += top1_err * mb_size
        totalSamples += mb_size
        test_meter.iter_toc()
        # Update and log stats
        test_meter.update_stats(top1_err=top1_err, mb_size=mb_size)
        test_meter.log_iter_stats(cur_epoch, cur_iter)
        test_meter.iter_tic()
    # Log epoch stats
    test_meter.log_epoch_stats(cur_epoch)
    test_meter.reset()

    return misclassifications / totalSamples


if __name__ == "__main__":
    cfg.merge_from_file(argparser().parse_args().cfg_file)
    main(cfg)