├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── cifar10
│   │   ├── al
│   │   │   ├── RESNET18.yaml
│   │   │   ├── RESNET18_ENS.yaml
│   │   │   └── RESNET18_IM.yaml
│   │   ├── evaluate
│   │   │   └── RESNET18.yaml
│   │   └── train
│   │       ├── RESNET18.yaml
│   │       └── RESNET18_ENS.yaml
│   ├── cifar100
│   │   └── al
│   │       └── RESNET18.yaml
│   ├── mnist
│   │   └── al
│   │       └── RESNET18.yaml
│   ├── template.yaml
│   └── tinyimagenet
│       └── al
│           └── RESNET18.yaml
├── docs
│   ├── AL_results.png
│   └── GETTING_STARTED.md
├── output
│   ├── CIFAR10
│   │   └── resnet18
│   │       └── ENT_1
│   │           ├── Episodes_vs_Test Accuracy.png
│   │           ├── episode_0
│   │           │   ├── Epochs_vs_Loss.png
│   │           │   ├── Epochs_vs_Validation Accuracy.png
│   │           │   ├── Iterations_vs_Loss.png
│   │           │   ├── activeSet.npy
│   │           │   ├── lSet.npy
│   │           │   ├── plot_epoch_xvalues.txt
│   │           │   ├── plot_epoch_yvalues.txt
│   │           │   ├── plot_it_x_values.txt
│   │           │   ├── plot_it_y_values.txt
│   │           │   ├── uSet.npy
│   │           │   ├── val_acc_epochs_x.txt
│   │           │   └── val_acc_epochs_y.txt
│   │           ├── episode_1
│   │           │   ├── Epochs_vs_Loss.png
│   │           │   ├── Epochs_vs_Validation Accuracy.png
│   │           │   ├── Iterations_vs_Loss.png
│   │           │   ├── activeSet.npy
│   │           │   ├── lSet.npy
│   │           │   ├── plot_epoch_xvalues.txt
│   │           │   ├── plot_epoch_yvalues.txt
│   │           │   ├── plot_it_x_values.txt
│   │           │   ├── plot_it_y_values.txt
│   │           │   ├── uSet.npy
│   │           │   ├── val_acc_epochs_x.txt
│   │           │   └── val_acc_epochs_y.txt
│   │           ├── episode_2
│   │           │   ├── Epochs_vs_Loss.png
│   │           │   ├── Epochs_vs_Validation Accuracy.png
│   │           │   ├── Iterations_vs_Loss.png
│   │           │   ├── activeSet.npy
│   │           │   ├── lSet.npy
│   │           │   ├── plot_epoch_xvalues.txt
│   │           │   ├── plot_epoch_yvalues.txt
│   │           │   ├── plot_it_x_values.txt
│   │           │   ├── plot_it_y_values.txt
│   │           │   ├── uSet.npy
│   │           │   ├── val_acc_epochs_x.txt
│   │           │   └── val_acc_epochs_y.txt
│   │           ├── episode_3
│   │           │   ├── Epochs_vs_Loss.png
│   │           │   ├── Epochs_vs_Validation Accuracy.png
│   │           │   ├── Iterations_vs_Loss.png
│   │           │   ├── activeSet.npy
│   │           │   ├── lSet.npy
│   │           │   ├── plot_epoch_xvalues.txt
│   │           │   ├── plot_epoch_yvalues.txt
│   │           │   ├── plot_it_x_values.txt
│   │           │   ├── plot_it_y_values.txt
│   │           │   ├── uSet.npy
│   │           │   ├── val_acc_epochs_x.txt
│   │           │   └── val_acc_epochs_y.txt
│   │           ├── episode_4
│   │           │   ├── Epochs_vs_Loss.png
│   │           │   ├── Epochs_vs_Validation Accuracy.png
│   │           │   ├── Iterations_vs_Loss.png
│   │           │   ├── activeSet.npy
│   │           │   ├── lSet.npy
│   │           │   ├── plot_epoch_xvalues.txt
│   │           │   ├── plot_epoch_yvalues.txt
│   │           │   ├── plot_it_x_values.txt
│   │           │   ├── plot_it_y_values.txt
│   │           │   ├── uSet.npy
│   │           │   ├── val_acc_epochs_x.txt
│   │           │   └── val_acc_epochs_y.txt
│   │           ├── episode_5
│   │           │   ├── Epochs_vs_Loss.png
│   │           │   ├── Epochs_vs_Validation Accuracy.png
│   │           │   ├── Iterations_vs_Loss.png
│   │           │   ├── plot_epoch_xvalues.txt
│   │           │   ├── plot_epoch_yvalues.txt
│   │           │   ├── plot_it_x_values.txt
│   │           │   ├── plot_it_y_values.txt
│   │           │   ├── val_acc_epochs_x.txt
│   │           │   └── val_acc_epochs_y.txt
│   │           ├── lSet.npy
│   │           ├── plot_episode_xvalues.txt
│   │           ├── plot_episode_yvalues.txt
│   │           ├── stdout.log
│   │           ├── uSet.npy
│   │           └── valSet.npy
│   └── results_aggregator.ipynb
├── pycls
│   ├── __init__.py
│   ├── al
│   │   ├── ActiveLearning.py
│   │   ├── Sampling.py
│   │   ├── __init__.py
│   │   └── vaal_util.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── builders.py
│   │   ├── config.py
│   │   ├── losses.py
│   │   ├── net.py
│   │   └── optimizer.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── augment.py
│   │   ├── custom_datasets.py
│   │   ├── data.py
│   │   ├── imbalanced_cifar.py
│   │   ├── randaugment.py
│   │   ├── sampler.py
│   │   ├── simclr_augment.py
│   │   ├── tiny_imagenet.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── gaussian_blur.py
│   │       └── helpers.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── alexnet.py
│   │   ├── resnet.py
│   │   ├── vaal_model.py
│   │   └── vgg.py
│   └── utils
│       ├── __init__.py
│       ├── benchmark.py
│       ├── checkpoint.py
│       ├── distributed.py
│       ├── io.py
│       ├── logging.py
│       ├── meters.py
│       ├── metrics.py
│       ├── net.py
│       ├── plotting.py
│       └── timer.py
├── requirements.txt
└── tools
    ├── __init__.py
    ├── ensemble_al.py
    ├── ensemble_train.py
    ├── test_model.py
    ├── train.py
    └── train_al.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # Shared objects
7 | *.so
8 |
9 | # Distribution / packaging
10 | build/
11 | *.egg-info/
12 | *.egg
13 |
14 | # Temporary files
15 | *.swn
16 | *.swo
17 | *.swp
18 |
19 | # PyCharm
20 | .idea/
21 |
22 | # Mac
23 | .DS_Store
24 |
25 | # Data symlinks
26 | pycls/datasets/data/
27 |
28 | # Other
29 | logs/
30 | scratch*
31 |
32 | data/
33 | output/
34 | *.ipynb_checkpoints/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright 2021 (c) Akshay L Chandra and Vineeth N Balasubramanian
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep Active Learning Toolkit for Image Classification in PyTorch
2 |
3 | This is a code base for deep active learning for image classification written in [PyTorch](https://pytorch.org/). It is built on top of FAIR's [pycls](https://github.com/facebookresearch/pycls/). I want to emphasize that this is a derivative of the toolkit originally shared with me via email by Prateek Munjal _et al._, the authors of the paper _"Towards Robust and Reproducible Active Learning using Neural Networks"_, available [here](https://arxiv.org/abs/2002.09564).
4 |
5 | ## Introduction
6 |
7 | The goal of this repository is to provide a simple and flexible codebase for deep active learning. It is designed to support rapid implementation and evaluation of research ideas. We also provide results on CIFAR10 and CIFAR100 below.
8 |
9 | The codebase currently only supports single-machine, single-GPU training. We will soon scale it to single-machine, multi-GPU training, powered by the PyTorch distributed package.
10 |
11 | ## Using the Toolkit
12 |
13 | Please see [`GETTING_STARTED`](docs/GETTING_STARTED.md) for brief instructions on installation, adding new datasets, basic usage examples, etc.
14 |
15 | ## Active Learning Methods Supported
16 | * Uncertainty Sampling
17 | * Least Confidence
18 | * Min-Margin
19 | * Max-Entropy
20 | * Deep Bayesian Active Learning (DBAL) [1]
21 | * Bayesian Active Learning by Disagreement (BALD) [1]
22 | * Diversity Sampling
23 | * Coreset (greedy) [2]
24 | * Variational Adversarial Active Learning (VAAL) [3]
25 | * Query-by-Committee Sampling
26 | * Ensemble Variation Ratio (Ens-varR) [4]
27 |
28 |
29 | ## Datasets Supported
30 | * [CIFAR10/100](https://www.cs.toronto.edu/~kriz/cifar.html)
31 | * [MNIST](http://yann.lecun.com/exdb/mnist/)
32 | * [SVHN](http://ufldl.stanford.edu/housenumbers/)
33 | * [Tiny ImageNet](https://www.kaggle.com/c/tiny-imagenet) (Download the zip file [here](http://cs231n.stanford.edu/tiny-imagenet-200.zip))
34 | * Long-Tail CIFAR-10/100
35 |
36 | Follow the instructions in [`GETTING_STARTED`](docs/GETTING_STARTED.md) to add a new dataset.
37 |
38 | ## Results on CIFAR10 and CIFAR100
39 |
40 | The following are the results on CIFAR10 and CIFAR100, trained with the hyperparameters in `configs/cifar10/al/RESNET18.yaml` and `configs/cifar100/al/RESNET18.yaml` respectively. All results were averaged over 3 runs.
41 |
42 |
43 | ![Active learning results on CIFAR10 and CIFAR100](docs/AL_results.png)
44 |
45 |
46 |
47 |
48 | ### CIFAR10 at 60%
49 | ```
50 | | AL Method | Test Accuracy |
51 | |:----------------:|:---------------------------:|
52 | | DBAL | 91.670000 +- 0.230651 |
53 | | Least Confidence | 91.510000 +- 0.087178 |
54 | | BALD | 91.470000 +- 0.293087 |
55 | | Coreset | 91.433333 +- 0.090738 |
56 | | Max-Entropy | 91.373333 +- 0.363639 |
57 | | Min-Margin | 91.333333 +- 0.234592 |
58 | | Ensemble-varR | 89.866667 +- 0.127410 |
59 | | Random | 89.803333 +- 0.230290 |
60 | | VAAL | 89.690000 +- 0.115326 |
61 | ```
62 |
63 | ### CIFAR100 at 60%
64 | ```
65 | | AL Method | Test Accuracy |
66 | |:----------------:|:---------------------------:|
67 | | DBAL | 55.400000 +- 1.037931 |
68 | | Coreset | 55.333333 +- 0.773714 |
69 | | Max-Entropy | 55.226667 +- 0.536128 |
70 | | BALD | 55.186667 +- 0.369639 |
71 | | Least Confidence | 55.003333 +- 0.937248 |
72 | | Min-Margin | 54.543333 +- 0.611583 |
73 | | Ensemble-varR | 54.186667 +- 0.325628 |
74 | | VAAL | 53.943333 +- 0.680686 |
75 | | Random | 53.546667 +- 0.302875 |
76 | ```
77 |
78 | ## Citing this Repository
79 |
80 | If you find this repo helpful in your research, please consider citing us and the owners of the original toolkit:
81 |
82 | ```
83 | @article{Chandra2021DeepAL,
84 | Author = {Akshay L Chandra and Vineeth N Balasubramanian},
85 | Title = {Deep Active Learning Toolkit for Image Classification in PyTorch},
86 | Journal = {https://github.com/acl21/deep-active-learning-pytorch},
87 | Year = {2021}
88 | }
89 |
90 | @article{Munjal2020TowardsRA,
91 | title={Towards Robust and Reproducible Active Learning Using Neural Networks},
92 |   author={Prateek Munjal and Nasir Hayat and Munawar Hayat and Jamshid Sourati and Shadab Khan},
93 | journal={ArXiv},
94 | year={2020},
95 | volume={abs/2002.09564}
96 | }
97 | ```
98 |
99 | ## License
100 |
101 | This toolkit is released under the MIT license. Please see the [LICENSE](LICENSE) file for more information.
102 |
103 | ## References
104 |
105 | [1] Yarin Gal, Riashat Islam, and Zoubin Ghahramani. Deep bayesian active learning with image data. In Proceedings of the 34th International Conference on Machine Learning-Volume 70, pages 1183–1192. JMLR. org, 2017.
106 |
107 | [2] Ozan Sener and Silvio Savarese. Active learning for convolutional neural networks: A core-set approach. In International Conference on Learning Representations, 2018.
108 |
109 | [3] Samarth Sinha, Sayna Ebrahimi, and Trevor Darrell. Variational adversarial active learning. 2019 IEEE/CVF International Conference on Computer Vision (ICCV), pages 5971-5980, 2019.
110 |
111 | [4] William H. Beluch, Tim Genewein, Andreas Nürnberger, and Jan M. Köhler. The power of ensembles for active learning in image classification. 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 9368–9377, 2018.
--------------------------------------------------------------------------------
/configs/cifar10/al/RESNET18.yaml:
--------------------------------------------------------------------------------
1 | # EXP_NAME: 'YOUR_EXPERIMENT_NAME'
2 | RNG_SEED: 1
3 | DATASET:
4 |   NAME: CIFAR10
5 |   ROOT_DIR: './data'
6 |   VAL_RATIO: 0.1
7 |   AUG_METHOD: 'hflip'
8 | MODEL:
9 |   TYPE: resnet18
10 |   NUM_CLASSES: 10
11 | OPTIM:
12 |   TYPE: 'sgd'
13 |   BASE_LR: 0.025
14 |   LR_POLICY: cos
15 |   LR_MULT: 0.1
16 |   # STEPS: [0, 60, 120, 160, 200]
17 |   MAX_EPOCH: 200
18 |   MOMENTUM: 0.9
19 |   NESTEROV: True
20 |   WEIGHT_DECAY: 0.0003
21 |   GAMMA: 0.1
22 | TRAIN:
23 |   SPLIT: train
24 |   BATCH_SIZE: 96
25 |   IM_SIZE: 32
26 |   EVAL_PERIOD: 2
27 | TEST:
28 |   SPLIT: test
29 |   BATCH_SIZE: 200
30 |   IM_SIZE: 32
31 | DATA_LOADER:
32 |   NUM_WORKERS: 4
33 | CUDNN:
34 |   BENCHMARK: True
35 | ACTIVE_LEARNING:
36 |   INIT_L_RATIO: 0.1
37 |   BUDGET_SIZE: 5000
38 |   # SAMPLING_FN: 'uncertainty'
39 |   MAX_ITER: 5
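40 |
41 | # Note: with BUDGET_SIZE 5000 and MAX_ITER 5, active learning adds 25k labels on
42 | # top of the initial 10% (INIT_L_RATIO) of CIFAR10's 50k train images, which is
43 | # how the labeled pool reaches the "60%" mark reported in the README results.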
--------------------------------------------------------------------------------
/configs/cifar10/al/RESNET18_ENS.yaml:
--------------------------------------------------------------------------------
1 | # EXP_NAME: 'YOUR_EXPERIMENT_NAME'
2 | RNG_SEED: 1
3 | DATASET:
4 |   NAME: CIFAR10
5 |   ROOT_DIR: './data'
6 |   VAL_RATIO: 0.1
7 |   AUG_METHOD: 'hflip'
8 | MODEL:
9 |   TYPE: resnet18
10 |   NUM_CLASSES: 10
11 | OPTIM:
12 |   TYPE: 'sgd'
13 |   BASE_LR: 0.025
14 |   LR_POLICY: cos
15 |   LR_MULT: 0.1
16 |   # STEPS: [0, 60, 120, 160, 200]
17 |   MAX_EPOCH: 200
18 |   MOMENTUM: 0.9
19 |   NESTEROV: True
20 |   WEIGHT_DECAY: 0.0003
21 |   GAMMA: 0.1
22 | TRAIN:
23 |   SPLIT: train
24 |   BATCH_SIZE: 96
25 |   IM_SIZE: 32
26 |   EVAL_PERIOD: 2
27 | TEST:
28 |   SPLIT: test
29 |   BATCH_SIZE: 200
30 |   IM_SIZE: 32
31 | DATA_LOADER:
32 |   NUM_WORKERS: 4
33 | CUDNN:
34 |   BENCHMARK: True
35 | ACTIVE_LEARNING:
36 |   BUDGET_SIZE: 5000
37 |   SAMPLING_FN: 'ensemble_var_R' # do not change this for ensemble learning
38 |   MAX_ITER: 5
39 | ENSEMBLE:
40 |   NUM_MODELS: 3
41 |   MODEL_TYPE: ['resnet18']
--------------------------------------------------------------------------------
/configs/cifar10/al/RESNET18_IM.yaml:
--------------------------------------------------------------------------------
1 | # EXP_NAME: 'YOUR_EXPERIMENT_NAME'
2 | RNG_SEED: 1
3 | DATASET:
4 |   NAME: IMBALANCED_CIFAR10
5 |   ROOT_DIR: './data'
6 |   VAL_RATIO: 0.1
7 |   AUG_METHOD: 'hflip'
8 | MODEL:
9 |   TYPE: resnet18
10 |   NUM_CLASSES: 10
11 | OPTIM:
12 |   TYPE: 'sgd'
13 |   BASE_LR: 0.025
14 |   LR_POLICY: cos
15 |   LR_MULT: 0.1
16 |   # STEPS: [0, 60, 120, 160, 200]
17 |   MAX_EPOCH: 200
18 |   MOMENTUM: 0.9
19 |   NESTEROV: True
20 |   WEIGHT_DECAY: 0.0003
21 |   GAMMA: 0.1
22 | TRAIN:
23 |   SPLIT: train
24 |   BATCH_SIZE: 96
25 |   IM_SIZE: 32
26 |   EVAL_PERIOD: 2
27 | TEST:
28 |   SPLIT: test
29 |   BATCH_SIZE: 200
30 |   IM_SIZE: 32
31 | DATA_LOADER:
32 |   NUM_WORKERS: 4
33 | CUDNN:
34 |   BENCHMARK: True
35 | ACTIVE_LEARNING:
36 |   INIT_L_RATIO: 0.1
37 |   BUDGET_SIZE: 1270
38 |   # SAMPLING_FN: 'uncertainty'
39 |   MAX_ITER: 5
--------------------------------------------------------------------------------
/configs/cifar10/evaluate/RESNET18.yaml:
--------------------------------------------------------------------------------
1 | # EXP_NAME: 'SOME_RANDOM_NAME'
2 | RNG_SEED: 1
3 | GPU_ID: '3'
4 | DATASET:
5 |   NAME: CIFAR10
6 |   ROOT_DIR: 'data'
7 | MODEL:
8 |   TYPE: resnet18
9 |   NUM_CLASSES: 10
10 | TEST:
11 |   SPLIT: test
12 |   BATCH_SIZE: 200
13 |   IM_SIZE: 32
14 |   MODEL_PATH: 'path/to/saved/model'
15 | DATA_LOADER:
16 |   NUM_WORKERS: 4
17 | CUDNN:
18 |   BENCHMARK: True
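19 |
20 | # Used by tools/test_model.py; see "Specific Model Evaluation" in docs/GETTING_STARTED.md.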
--------------------------------------------------------------------------------
/configs/cifar10/train/RESNET18.yaml:
--------------------------------------------------------------------------------
1 | # EXP_NAME: 'YOUR_EXPERIMENT_NAME'
2 | RNG_SEED: 1
3 | DATASET:
4 |   NAME: CIFAR10
5 |   ROOT_DIR: './data'
6 |   VAL_RATIO: 0.1
7 |   AUG_METHOD: 'hflip'
8 | MODEL:
9 |   TYPE: resnet18
10 |   NUM_CLASSES: 10
11 | OPTIM:
12 |   TYPE: 'sgd'
13 |   BASE_LR: 0.025
14 |   LR_POLICY: cos
15 |   LR_MULT: 0.1
16 |   # STEPS: [0, 60, 120, 160, 200]
17 |   MAX_EPOCH: 200
18 |   MOMENTUM: 0.9
19 |   NESTEROV: True
20 |   WEIGHT_DECAY: 0.0003
21 |   GAMMA: 0.1
22 | TRAIN:
23 |   SPLIT: train
24 |   BATCH_SIZE: 1
25 |   IM_SIZE: 32
26 |   EVAL_PERIOD: 2
27 | TEST:
28 |   SPLIT: test
29 |   BATCH_SIZE: 200
30 |   IM_SIZE: 32
31 | DATA_LOADER:
32 |   NUM_WORKERS: 4
33 | CUDNN:
34 |   BENCHMARK: True
--------------------------------------------------------------------------------
/configs/cifar10/train/RESNET18_ENS.yaml:
--------------------------------------------------------------------------------
1 | # EXP_NAME: 'SOME'
2 | RNG_SEED: 1
3 | DATASET:
4 |   NAME: CIFAR10
5 |   ROOT_DIR: './data'
6 |   VAL_RATIO: 0.1
7 |   AUG_METHOD: 'hflip'
8 | MODEL:
9 |   TYPE: resnet18
10 |   NUM_CLASSES: 10
11 | OPTIM:
12 |   TYPE: 'sgd'
13 |   BASE_LR: 0.025
14 |   LR_POLICY: cos
15 |   LR_MULT: 0.1
16 |   # STEPS: [0, 60, 120, 160, 200]
17 |   MAX_EPOCH: 200
18 |   MOMENTUM: 0.9
19 |   NESTEROV: True
20 |   WEIGHT_DECAY: 0.0003
21 |   GAMMA: 0.1
22 | TRAIN:
23 |   SPLIT: train
24 |   BATCH_SIZE: 96
25 |   IM_SIZE: 32
26 |   EVAL_PERIOD: 2
27 | TEST:
28 |   SPLIT: test
29 |   BATCH_SIZE: 200
30 |   IM_SIZE: 32
31 | DATA_LOADER:
32 |   NUM_WORKERS: 4
33 | CUDNN:
34 |   BENCHMARK: True
35 | ENSEMBLE:
36 |   NUM_MODELS: 3
37 |   MODEL_TYPE: ['resnet18']
--------------------------------------------------------------------------------
/configs/cifar100/al/RESNET18.yaml:
--------------------------------------------------------------------------------
1 | # EXP_NAME: 'YOUR_EXPERIMENT_NAME'
2 | RNG_SEED: 1
3 | DATASET:
4 |   NAME: CIFAR100
5 |   ROOT_DIR: './data'
6 |   VAL_RATIO: 0.1
7 |   AUG_METHOD: 'hflip'
8 | MODEL:
9 |   TYPE: resnet18
10 |   NUM_CLASSES: 100
11 | OPTIM:
12 |   TYPE: 'sgd'
13 |   BASE_LR: 0.025
14 |   LR_POLICY: cos
15 |   LR_MULT: 0.1
16 |   # STEPS: [0, 60, 120, 160, 200]
17 |   MAX_EPOCH: 200
18 |   MOMENTUM: 0.9
19 |   NESTEROV: True
20 |   WEIGHT_DECAY: 0.0003
21 |   GAMMA: 0.1
22 | TRAIN:
23 |   SPLIT: train
24 |   BATCH_SIZE: 96
25 |   IM_SIZE: 32
26 |   EVAL_PERIOD: 2
27 | TEST:
28 |   SPLIT: test
29 |   BATCH_SIZE: 200
30 |   IM_SIZE: 32
31 | DATA_LOADER:
32 |   NUM_WORKERS: 4
33 | CUDNN:
34 |   BENCHMARK: True
35 | ACTIVE_LEARNING:
36 |   INIT_L_RATIO: 0.1
37 |   BUDGET_SIZE: 5000
38 |   # SAMPLING_FN: 'uncertainty'
39 |   MAX_ITER: 5
--------------------------------------------------------------------------------
/configs/mnist/al/RESNET18.yaml:
--------------------------------------------------------------------------------
1 | # EXP_NAME: 'YOUR_EXPERIMENT_NAME'
2 | RNG_SEED: 1
3 | DATASET:
4 |   NAME: MNIST
5 |   ROOT_DIR: './data'
6 |   VAL_RATIO: 0.1
7 |   AUG_METHOD: 'hflip'
8 | MODEL:
9 |   TYPE: resnet18
10 |   NUM_CLASSES: 10
11 | OPTIM:
12 |   TYPE: 'adam'
13 |   BASE_LR: 0.005
14 |   LR_POLICY: none
15 |   LR_MULT: 0.1
16 |   # STEPS: [0, 60, 120, 160, 200]
17 |   MAX_EPOCH: 100
18 |   MOMENTUM: 0.9
19 |   NESTEROV: True
20 |   WEIGHT_DECAY: 0.0003
21 |   GAMMA: 0.1
22 | TRAIN:
23 |   SPLIT: train
24 |   BATCH_SIZE: 1
25 |   IM_SIZE: 32
26 |   EVAL_PERIOD: 2
27 | TEST:
28 |   SPLIT: test
29 |   BATCH_SIZE: 200
30 |   IM_SIZE: 32
31 | DATA_LOADER:
32 |   NUM_WORKERS: 4
33 | CUDNN:
34 |   BENCHMARK: True
35 | ACTIVE_LEARNING:
36 |   INIT_L_RATIO: 0.001
37 |   BUDGET_SIZE: 60
38 |   # SAMPLING_FN: 'uncertainty'
39 |   MAX_ITER: 5
--------------------------------------------------------------------------------
/configs/template.yaml:
--------------------------------------------------------------------------------
1 | # Folder where the best model, logs, etc. are saved. Setting EXP_NAME: 'auto' creates a timestamped folder
2 | EXP_NAME: 'SOME_RANDOM_NAME'
3 | # Note that non-determinism may still be present due to non-deterministic
4 | # operator implementations in GPU operator libraries
5 | RNG_SEED: 1
6 | # GPU ID you want to execute the process on
7 | GPU_ID: '3'
8 | DATASET:
9 |   NAME: CIFAR10 # or CIFAR100, MNIST, SVHN, TINYIMAGENET
10 |   ROOT_DIR: 'data' # Relative path where data should be downloaded
11 |   # Proportion of the train set to hold out as validation data
12 |   VAL_RATIO: 0.1
13 |   # Data augmentation methods - 'simclr', 'randaug', 'hflip'
14 |   AUG_METHOD: 'hflip'
15 | MODEL:
16 |   # Model type.
17 |   # Choose from vgg style ['vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19_bn', 'vgg19']
18 |   # or from resnet style ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
19 |   # 'wide_resnet50_2', 'wide_resnet101_2']
20 |   TYPE: resnet18
21 |   NUM_CLASSES: 10
22 | OPTIM:
23 |   TYPE: 'sgd' # or 'adam'
24 |   BASE_LR: 0.1
25 |   # Learning rate policy, select from {'cos', 'exp', 'steps', 'none'}
26 |   LR_POLICY: steps
27 |   # Steps for 'steps' policy (in epochs)
28 |   STEPS: [0] #[0, 30, 60, 90]
29 |   # Training Epochs
30 |   MAX_EPOCH: 1
31 |   # Momentum
32 |   MOMENTUM: 0.9
33 |   # Nesterov Momentum
34 |   NESTEROV: False
35 |   # L2 regularization
36 |   WEIGHT_DECAY: 0.0005
37 |   # Exponential decay factor
38 |   GAMMA: 0.1
39 | TRAIN:
40 |   SPLIT: train
41 |   # Training mini-batch size
42 |   BATCH_SIZE: 256
43 |   # Image size
44 |   IM_SIZE: 32
45 |   IM_CHANNELS: 3
46 |   # Evaluate model on test data every eval period epochs
47 |   EVAL_PERIOD: 1
48 | TEST:
49 |   SPLIT: test
50 |   # Testing mini-batch size
51 |   BATCH_SIZE: 200
52 |   # Image size
53 |   IM_SIZE: 32
54 |   # Saved model to use for testing (useful when running tools/test_model.py)
55 |   MODEL_PATH: ''
56 | DATA_LOADER:
57 |   NUM_WORKERS: 4
58 | CUDNN:
59 |   BENCHMARK: True
60 | ACTIVE_LEARNING:
61 |   # Active sampling budget (at each episode)
62 |   BUDGET_SIZE: 5000
63 |   # Active sampling method
64 |   SAMPLING_FN: 'dbal' # 'random', 'uncertainty', 'entropy', 'margin', 'bald', 'vaal', 'coreset', 'ensemble_var_R'
65 |   # Initial labeled pool ratio (fraction of the train set labeled before AL begins)
66 |   INIT_L_RATIO: 0.1
67 |   # Max AL episodes
68 |   MAX_ITER: 1
69 |   DROPOUT_ITERATIONS: 10 # Used by DBAL
70 | # Useful when running `tools/ensemble_al.py` or `tools/ensemble_train.py`
71 | ENSEMBLE:
72 |   NUM_MODELS: 3
73 |   MODEL_TYPE: ['resnet18']
--------------------------------------------------------------------------------
/configs/tinyimagenet/al/RESNET18.yaml:
--------------------------------------------------------------------------------
1 | # EXP_NAME: 'YOUR_EXPERIMENT_NAME'
2 | RNG_SEED: 1
3 | DATASET:
4 |   NAME: TINYIMAGENET
5 |   ROOT_DIR: '/path/to/dataset/directory/'
6 |   VAL_RATIO: 0.05
7 |   AUG_METHOD: 'hflip'
8 | MODEL:
9 |   TYPE: resnet18
10 |   NUM_CLASSES: 200
11 | OPTIM:
12 |   TYPE: 'adam'
13 |   BASE_LR: 0.001
14 |   LR_POLICY: none
15 |   LR_MULT: 0.1
16 |   # STEPS: [0, 60, 120, 160, 200]
17 |   MAX_EPOCH: 100
18 |   MOMENTUM: 0.9
19 |   NESTEROV: True
20 |   WEIGHT_DECAY: 0.0003
21 |   GAMMA: 0.1
22 | TRAIN:
23 |   SPLIT: train
24 |   BATCH_SIZE: 1
25 |   IM_SIZE: 64
26 |   EVAL_PERIOD: 2
27 | TEST:
28 |   SPLIT: test
29 |   BATCH_SIZE: 200
30 |   IM_SIZE: 64
31 | DATA_LOADER:
32 |   NUM_WORKERS: 4
33 | CUDNN:
34 |   BENCHMARK: True
35 | ACTIVE_LEARNING:
36 |   INIT_L_RATIO: 0.1
37 |   BUDGET_SIZE: 5000
38 |   # SAMPLING_FN: 'uncertainty'
39 |   MAX_ITER: 5
--------------------------------------------------------------------------------
/docs/AL_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/docs/AL_results.png
--------------------------------------------------------------------------------
/docs/GETTING_STARTED.md:
--------------------------------------------------------------------------------
1 | # Getting Started
2 |
3 | ## Environment Setup
4 |
5 | Clone the repository:
6 |
7 | ```
8 | git clone https://github.com/acl21/deep-active-learning-pytorch
9 | ```
10 |
11 | Install dependencies:
12 |
13 | ```
14 | pip install -r requirements.txt
15 | ```
16 |
17 | ## Understanding the Config File
18 | ```
19 | # Folder where the best model, logs, etc. are saved. Setting EXP_NAME: 'auto' creates a timestamped folder
20 | EXP_NAME: 'YOUR_EXPERIMENT_NAME'
21 | # Note that non-determinism may still be present due to non-deterministic
22 | # operator implementations in GPU operator libraries
23 | RNG_SEED: 1
24 | # GPU ID you want to execute the process on (this feature isn't working as of now; prefix the commands shown below with CUDA_VISIBLE_DEVICES instead)
25 | GPU_ID: '3'
26 | DATASET:
27 |   NAME: CIFAR10 # or CIFAR100, MNIST, SVHN, TINYIMAGENET, IMBALANCED_CIFAR10/100
28 |   ROOT_DIR: 'data' # Relative path where data should be downloaded
29 |   # Proportion of the train set to hold out as validation data
30 |   VAL_RATIO: 0.1
31 |   # Data augmentation methods - 'simclr', 'randaug', 'hflip'
32 |   AUG_METHOD: 'hflip'
33 | MODEL:
34 |   # Model type.
35 |   # Choose from vgg style ['vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19_bn', 'vgg19']
36 |   # or from resnet style ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
37 |   # 'wide_resnet50_2', 'wide_resnet101_2']
38 |   # or `alexnet`
39 |   TYPE: resnet18
40 |   NUM_CLASSES: 10
41 | OPTIM:
42 |   TYPE: 'sgd' # or 'adam'
43 |   BASE_LR: 0.025
44 |   # Learning rate policy, select from {'cos', 'exp', 'steps', 'none'}
45 |   LR_POLICY: cos
46 |   # Steps for 'steps' policy (in epochs)
47 |   STEPS: [0] #[0, 30, 60, 90]
48 |   # Training Epochs
49 |   MAX_EPOCH: 1
50 |   # Momentum
51 |   MOMENTUM: 0.9
52 |   # Nesterov Momentum
53 |   NESTEROV: False
54 |   # L2 regularization
55 |   WEIGHT_DECAY: 0.0005
56 |   # Exponential decay factor
57 |   GAMMA: 0.1
58 | TRAIN:
59 |   SPLIT: train
60 |   # Training mini-batch size
61 |   BATCH_SIZE: 96
62 |   # Image size
63 |   IM_SIZE: 32
64 |   IM_CHANNELS: 3
65 |   # Evaluate model on test data every eval period epochs
66 |   EVAL_PERIOD: 2
67 | TEST:
68 |   SPLIT: test
69 |   # Testing mini-batch size
70 |   BATCH_SIZE: 200
71 |   # Image size
72 |   IM_SIZE: 32
73 |   # Saved model to use for testing (useful when running `tools/test_model.py`)
74 |   MODEL_PATH: ''
75 | DATA_LOADER:
76 |   NUM_WORKERS: 4
77 | CUDNN:
78 |   BENCHMARK: True
79 | ACTIVE_LEARNING:
80 |   # Active sampling budget (at each episode)
81 |   BUDGET_SIZE: 5000
82 |   # Active sampling method
83 |   SAMPLING_FN: 'dbal' # 'random', 'uncertainty', 'entropy', 'margin', 'bald', 'vaal', 'coreset', 'ensemble_var_R'
84 |   # Initial labeled pool ratio (fraction of the train set labeled before AL begins)
85 |   INIT_L_RATIO: 0.1
86 |   # Max AL episodes
87 |   MAX_ITER: 5
88 |   DROPOUT_ITERATIONS: 25 # Used by DBAL
89 | # Useful when running `tools/ensemble_al.py` or `tools/ensemble_train.py`
90 | ENSEMBLE:
91 |   NUM_MODELS: 3
92 |   MODEL_TYPE: ['resnet18']
93 | ```
94 |
95 | Please refer to `pycls/core/config.py` to configure your experiments at a deeper level.
96 |
97 |
98 | ## Execution Commands
99 | ### Active Learning
100 | Once the config file is set up appropriately, run the following command from inside the `tools` directory to perform DBAL active learning (the `--al` flag selects the query method).
101 |
102 | ```
103 | CUDA_VISIBLE_DEVICES=0 python train_al.py \
104 | --cfg=../configs/cifar10/al/RESNET18.yaml --al=dbal --exp-name=YOUR_EXPERIMENT_NAME
105 | ```
106 |
107 | ### Ensemble Active Learning
108 |
109 | Watch out for the ensemble options in the config file. This setting uses _Ensemble Variation-Ratio_ as the query method by default; a sketch of the idea follows the command.
110 |
111 | ```
112 | CUDA_VISIBLE_DEVICES=0 python ensemble_al.py \
113 | --cfg=../configs/cifar10/al/RESNET18_ENS.yaml --exp-name=YOUR_EXPERIMENT_NAME
114 | ```
115 |
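116 | The idea behind the variation-ratio query can be sketched in a few lines: every ensemble member votes with its predicted class, and the samples with the most disagreement are queried first. The snippet below is a minimal illustration in plain PyTorch, not the toolkit's own implementation (which lives under `pycls/al/`):
117 | ```
118 | import torch
119 |
120 | def ensemble_variation_ratio(logits_list):
121 |     """logits_list: one (N, C) logits tensor per ensemble member."""
122 |     # each member votes with its argmax class
123 |     votes = torch.stack([l.argmax(dim=1) for l in logits_list])  # (M, N)
124 |     modal_class = votes.mode(dim=0).values                       # (N,)
125 |     # fraction of members that agree with the modal class
126 |     agreement = (votes == modal_class).float().mean(dim=0)
127 |     return 1.0 - agreement  # higher = more disagreement, queried first
128 | ```
129 |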
130 | ### Passive Learning
131 |
132 | ```
133 | CUDA_VISIBLE_DEVICES=0 python train.py \
134 | --cfg=../configs/cifar10/train/RESNET18.yaml --exp-name=YOUR_EXPERIMENT_NAME
135 | ```
136 |
137 | ### Ensemble Passive Learning
138 |
139 | Watch out for the ensemble options in the config file.
140 |
141 | ```
142 | CUDA_VISIBLE_DEVICES=0 python ensemble_train.py \
143 | --cfg=../configs/cifar10/train/RESNET18_ENS.yaml --exp-name=YOUR_EXPERIMENT_NAME
144 | ```
145 |
146 | ### Specific Model Evaluation
147 |
148 | This is useful if you want to evaluate a particular saved model. Pass the path to the saved model via `TEST.MODEL_PATH` in the yaml file. Refer to the file inside the `configs/cifar10/evaluate` directory for clarity.
149 |
150 | ```
151 | CUDA_VISIBLE_DEVICES=0 python test_model.py \
152 | --cfg=../configs/cifar10/evaluate/RESNET18.yaml
153 | ```
154 |
155 |
156 | ## Add Your Own Dataset
157 |
158 | To add your own dataset, you need to do the following:
159 | 1. Write the PyTorch Dataset code for your custom dataset (or directly use the ones [PyTorch provides](https://pytorch.org/vision/stable/datasets.html)).
160 | 2. Create a subclass of the above Dataset with the modifications described below and add it to `pycls/datasets/custom_datasets.py`.
161 |    * We add two new variables to the dataset - a boolean flag `no_aug` and `test_transform`.
162 |    * We set the flag `no_aug = True` before iterating through the unlabeled and validation dataloaders so that the data doesn't get augmented.
163 |    * See how we modify the `__getitem__` function to achieve that:
164 | ```
165 | import torchvision
166 | from PIL import Image
167 |
168 | class CIFAR10(torchvision.datasets.CIFAR10):
169 |     def __init__(self, root, train, transform, test_transform, download=True):
170 |         super(CIFAR10, self).__init__(root, train, transform=transform, download=download)
171 |         self.test_transform = test_transform
172 |         self.no_aug = False
173 |
174 |     def __getitem__(self, index: int):
175 |         """
176 |         Args:
177 |             index (int): Index
178 |
179 |         Returns:
180 |             tuple: (image, target) where target is index of the target class.
181 |         """
182 |         img, target = self.data[index], self.targets[index]
183 |
184 |         # doing this so that it is consistent with all other datasets
185 |         # to return a PIL Image
186 |         img = Image.fromarray(img)
187 |
188 |         ##########################
189 |         # set True before iterating through unlabeled or validation set
190 |         if self.no_aug:
191 |             if self.test_transform is not None:
192 |                 img = self.test_transform(img)
193 |         else:
194 |             if self.transform is not None:
195 |                 img = self.transform(img)
196 |         #########################
197 |
198 |         return img, target
199 | ```
200 | 3. Add your dataset in `pycls/datasets/data.py` (a minimal sketch follows this list)
201 |    * Add appropriate preprocessing steps to `getPreprocessOps`
202 |    * Add the dataset call to `getDataset`
203 | 4. Create appropriate config `yaml` files and use them for AL training.
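204 |
205 | For illustration, here is a minimal sketch of the step 3 hook. The dispatch logic and the signature below are assumptions made for this example; check the actual `getDataset` in the toolkit before copying:
206 | ```
207 | # Hypothetical sketch of the dataset dispatch in pycls/datasets/data.py.
208 | from pycls.datasets.custom_datasets import CIFAR10  # the subclass from step 2
209 |
210 | def getDataset(name, root, train, transform, test_transform):
211 |     """Return an instantiated dataset for the given name."""
212 |     if name == 'CIFAR10':
213 |         return CIFAR10(root, train, transform, test_transform)
214 |     raise NotImplementedError(f'Dataset {name} is not supported.')
215 | ```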
216 |
217 |
218 | ## Some Comments About Our Toolkit
219 | * Our toolkit currently only supports 'SGD' (with a learning rate scheduler) and 'Adam' (no scheduler).
220 | * We log everything. The toolkit saves the indices of the initial labeled pool, the samples queried in each episode, the episode-wise best model, and visual plots such as "Iterations vs Loss", "Epochs vs Validation Accuracy" and "Episodes vs Test Accuracy". Please check an experiment's logs at `output/CIFAR10/resnet18/ENT_1/` for clarity.
221 | * We added dropout (p=0.5) to all our models just before the final fully connected layer, so that the DBAL and BALD query methods can work (a sketch of this Monte-Carlo dropout scoring follows this list).
222 | * We also provide an IPython notebook that aggregates results directly from the experiment folders. You can find it at `output/results_aggregator.ipynb`.
223 | * If you add your own dataset, please make sure to create the custom version as explained in point 2 of the instructions above. Failing to do so means your unlabeled data (a big red flag for AL) and validation data will be augmented, because we use a single dataset instance with subset- and index-based dataloaders.
224 | * We tested the toolkit only on a Linux machine with Python 3.8.
225 | * Please create an issue with appropriate details:
226 |   * if you are unable to get the toolkit to work or run into any problems
227 |   * if we have not provided credits correctly to the rightful owner (please attach proof)
228 |   * if you notice any flaws in the implementation
229 |
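230 | Below is a minimal, self-contained sketch of the Monte-Carlo dropout scoring mentioned above (max-entropy over averaged stochastic forward passes, in the spirit of DBAL). It is an illustration in plain PyTorch, not the toolkit's own implementation; `n_passes` mirrors the `DROPOUT_ITERATIONS` config option:
231 | ```
232 | import torch
233 |
234 | def mc_dropout_entropy(model, x, n_passes=25):
235 |     """Score inputs by predictive entropy under Monte-Carlo dropout."""
236 |     model.eval()
237 |     for m in model.modules():
238 |         if isinstance(m, torch.nn.Dropout):
239 |             m.train()  # keep dropout stochastic while batch-norm stays frozen
240 |     with torch.no_grad():
241 |         probs = torch.stack(
242 |             [torch.softmax(model(x), dim=1) for _ in range(n_passes)]
243 |         ).mean(dim=0)  # mean class probabilities over the stochastic passes
244 |     # entropy of the averaged prediction; the most uncertain samples are queried
245 |     return -(probs * probs.clamp_min(1e-12).log()).sum(dim=1)
246 | ```
247 |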
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/Episodes_vs_Test Accuracy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/Episodes_vs_Test Accuracy.png
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_0/Epochs_vs_Loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_0/Epochs_vs_Loss.png
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_0/Epochs_vs_Validation Accuracy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_0/Epochs_vs_Validation Accuracy.png
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_0/Iterations_vs_Loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_0/Iterations_vs_Loss.png
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_0/activeSet.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_0/activeSet.npy
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_0/lSet.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_0/lSet.npy
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_0/plot_epoch_xvalues.txt:
--------------------------------------------------------------------------------
1 | 1.00
2 | 2.00
3 | 3.00
4 | 4.00
5 | 5.00
6 | 6.00
7 | 7.00
8 | 8.00
9 | 9.00
10 | 10.00
11 | 11.00
12 | 12.00
13 | 13.00
14 | 14.00
15 | 15.00
16 | 16.00
17 | 17.00
18 | 18.00
19 | 19.00
20 | 20.00
21 | 21.00
22 | 22.00
23 | 23.00
24 | 24.00
25 | 25.00
26 | 26.00
27 | 27.00
28 | 28.00
29 | 29.00
30 | 30.00
31 | 31.00
32 | 32.00
33 | 33.00
34 | 34.00
35 | 35.00
36 | 36.00
37 | 37.00
38 | 38.00
39 | 39.00
40 | 40.00
41 | 41.00
42 | 42.00
43 | 43.00
44 | 44.00
45 | 45.00
46 | 46.00
47 | 47.00
48 | 48.00
49 | 49.00
50 | 50.00
51 | 51.00
52 | 52.00
53 | 53.00
54 | 54.00
55 | 55.00
56 | 56.00
57 | 57.00
58 | 58.00
59 | 59.00
60 | 60.00
61 | 61.00
62 | 62.00
63 | 63.00
64 | 64.00
65 | 65.00
66 | 66.00
67 | 67.00
68 | 68.00
69 | 69.00
70 | 70.00
71 | 71.00
72 | 72.00
73 | 73.00
74 | 74.00
75 | 75.00
76 | 76.00
77 | 77.00
78 | 78.00
79 | 79.00
80 | 80.00
81 | 81.00
82 | 82.00
83 | 83.00
84 | 84.00
85 | 85.00
86 | 86.00
87 | 87.00
88 | 88.00
89 | 89.00
90 | 90.00
91 | 91.00
92 | 92.00
93 | 93.00
94 | 94.00
95 | 95.00
96 | 96.00
97 | 97.00
98 | 98.00
99 | 99.00
100 | 100.00
101 | 101.00
102 | 102.00
103 | 103.00
104 | 104.00
105 | 105.00
106 | 106.00
107 | 107.00
108 | 108.00
109 | 109.00
110 | 110.00
111 | 111.00
112 | 112.00
113 | 113.00
114 | 114.00
115 | 115.00
116 | 116.00
117 | 117.00
118 | 118.00
119 | 119.00
120 | 120.00
121 | 121.00
122 | 122.00
123 | 123.00
124 | 124.00
125 | 125.00
126 | 126.00
127 | 127.00
128 | 128.00
129 | 129.00
130 | 130.00
131 | 131.00
132 | 132.00
133 | 133.00
134 | 134.00
135 | 135.00
136 | 136.00
137 | 137.00
138 | 138.00
139 | 139.00
140 | 140.00
141 | 141.00
142 | 142.00
143 | 143.00
144 | 144.00
145 | 145.00
146 | 146.00
147 | 147.00
148 | 148.00
149 | 149.00
150 | 150.00
151 | 151.00
152 | 152.00
153 | 153.00
154 | 154.00
155 | 155.00
156 | 156.00
157 | 157.00
158 | 158.00
159 | 159.00
160 | 160.00
161 | 161.00
162 | 162.00
163 | 163.00
164 | 164.00
165 | 165.00
166 | 166.00
167 | 167.00
168 | 168.00
169 | 169.00
170 | 170.00
171 | 171.00
172 | 172.00
173 | 173.00
174 | 174.00
175 | 175.00
176 | 176.00
177 | 177.00
178 | 178.00
179 | 179.00
180 | 180.00
181 | 181.00
182 | 182.00
183 | 183.00
184 | 184.00
185 | 185.00
186 | 186.00
187 | 187.00
188 | 188.00
189 | 189.00
190 | 190.00
191 | 191.00
192 | 192.00
193 | 193.00
194 | 194.00
195 | 195.00
196 | 196.00
197 | 197.00
198 | 198.00
199 | 199.00
200 | 200.00
201 |
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_0/plot_epoch_yvalues.txt:
--------------------------------------------------------------------------------
1 | 2.39
2 | 2.34
3 | 2.04
4 | 0.69
5 | 2.46
6 | 2.31
7 | 2.15
8 | 1.69
9 | 1.67
10 | 1.70
11 | 1.93
12 | 1.74
13 | 1.89
14 | 2.85
15 | 1.80
16 | 0.71
17 | 2.35
18 | 1.73
19 | 1.20
20 | 0.89
21 | 1.14
22 | 1.34
23 | 1.16
24 | 0.98
25 | 0.97
26 | 1.59
27 | 0.72
28 | 0.59
29 | 0.94
30 | 1.27
31 | 1.53
32 | 0.72
33 | 0.59
34 | 1.31
35 | 1.28
36 | 1.05
37 | 1.91
38 | 1.23
39 | 0.73
40 | 0.57
41 | 1.86
42 | 3.92
43 | 0.35
44 | 0.33
45 | 1.12
46 | 1.46
47 | 0.87
48 | 0.49
49 | 2.36
50 | 1.00
51 | 1.35
52 | 2.22
53 | 0.49
54 | 0.92
55 | 0.37
56 | 0.17
57 | 0.23
58 | 1.24
59 | 1.54
60 | 1.28
61 | 0.74
62 | 0.82
63 | 0.24
64 | 0.16
65 | 0.14
66 | 0.66
67 | 0.10
68 | 0.60
69 | 2.10
70 | 0.09
71 | 2.29
72 | 0.80
73 | 0.17
74 | 0.55
75 | 0.18
76 | 0.63
77 | 0.19
78 | 1.30
79 | 0.14
80 | 3.27
81 | 0.63
82 | 0.41
83 | 0.88
84 | 0.23
85 | 0.18
86 | 1.05
87 | 0.42
88 | 1.66
89 | 0.17
90 | 0.01
91 | 1.55
92 | 0.80
93 | 0.25
94 | 1.22
95 | 0.27
96 | 0.04
97 | 0.63
98 | 0.58
99 | 0.37
100 | 0.15
101 | 0.07
102 | 0.90
103 | 1.04
104 | 1.05
105 | 0.34
106 | 0.61
107 | 1.22
108 | 0.34
109 | 0.16
110 | 0.74
111 | 0.14
112 | 0.25
113 | 0.82
114 | 1.69
115 | 0.27
116 | 1.72
117 | 1.15
118 | 1.24
119 | 0.04
120 | 0.16
121 | 0.01
122 | 0.57
123 | 0.10
124 | 0.04
125 | 1.02
126 | 0.02
127 | 0.20
128 | 0.33
129 | 0.08
130 | 0.54
131 | 0.31
132 | 0.40
133 | 0.35
134 | 1.19
135 | 0.09
136 | 0.49
137 | 0.36
138 | 0.54
139 | 0.29
140 | 0.36
141 | 0.42
142 | 0.13
143 | 0.06
144 | 0.33
145 | 0.04
146 | 0.30
147 | 0.01
148 | 0.51
149 | 0.00
150 | 0.03
151 | 0.25
152 | 0.19
153 | 0.92
154 | 0.28
155 | 0.02
156 | 0.04
157 | 0.04
158 | 0.01
159 | 0.33
160 | 0.02
161 | 0.02
162 | 1.31
163 | 0.02
164 | 0.01
165 | 1.45
166 | 0.00
167 | 0.67
168 | 0.00
169 | 0.05
170 | 0.00
171 | 0.08
172 | 0.05
173 | 0.66
174 | 0.34
175 | 0.25
176 | 0.79
177 | 0.03
178 | 0.01
179 | 0.01
180 | 0.01
181 | 0.02
182 | 0.00
183 | 1.77
184 | 0.31
185 | 0.08
186 | 0.08
187 | 0.19
188 | 0.43
189 | 1.09
190 | 0.51
191 | 0.92
192 | 0.02
193 | 0.03
194 | 0.11
195 | 0.86
196 | 1.21
197 | 1.27
198 | 0.06
199 | 0.00
200 | 0.75
201 |
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_0/plot_it_x_values.txt:
--------------------------------------------------------------------------------
1 | 19.00
2 | 38.00
3 | 72.00
4 | 91.00
5 | 125.00
6 | 144.00
7 | 178.00
8 | 197.00
9 | 231.00
10 | 250.00
11 | 284.00
12 | 303.00
13 | 337.00
14 | 356.00
15 | 390.00
16 | 409.00
17 | 443.00
18 | 462.00
19 | 496.00
20 | 515.00
21 | 549.00
22 | 568.00
23 | 602.00
24 | 621.00
25 | 655.00
26 | 674.00
27 | 708.00
28 | 727.00
29 | 761.00
30 | 780.00
31 | 814.00
32 | 833.00
33 | 867.00
34 | 886.00
35 | 920.00
36 | 939.00
37 | 973.00
38 | 992.00
39 | 1026.00
40 | 1045.00
41 | 1079.00
42 | 1098.00
43 | 1132.00
44 | 1151.00
45 | 1185.00
46 | 1204.00
47 | 1238.00
48 | 1257.00
49 | 1291.00
50 | 1310.00
51 | 1344.00
52 | 1363.00
53 | 1397.00
54 | 1416.00
55 | 1450.00
56 | 1469.00
57 | 1503.00
58 | 1522.00
59 | 1556.00
60 | 1575.00
61 | 1609.00
62 | 1628.00
63 | 1662.00
64 | 1681.00
65 | 1715.00
66 | 1734.00
67 | 1768.00
68 | 1787.00
69 | 1821.00
70 | 1840.00
71 | 1874.00
72 | 1893.00
73 | 1927.00
74 | 1946.00
75 | 1980.00
76 | 1999.00
77 | 2033.00
78 | 2052.00
79 | 2086.00
80 | 2105.00
81 | 2139.00
82 | 2158.00
83 | 2192.00
84 | 2211.00
85 | 2245.00
86 | 2264.00
87 | 2298.00
88 | 2317.00
89 | 2351.00
90 | 2370.00
91 | 2404.00
92 | 2423.00
93 | 2457.00
94 | 2476.00
95 | 2510.00
96 | 2529.00
97 | 2563.00
98 | 2582.00
99 | 2616.00
100 | 2635.00
101 | 2669.00
102 | 2688.00
103 | 2722.00
104 | 2741.00
105 | 2775.00
106 | 2794.00
107 | 2828.00
108 | 2847.00
109 | 2881.00
110 | 2900.00
111 | 2934.00
112 | 2953.00
113 | 2987.00
114 | 3006.00
115 | 3040.00
116 | 3059.00
117 | 3093.00
118 | 3112.00
119 | 3146.00
120 | 3165.00
121 | 3199.00
122 | 3218.00
123 | 3252.00
124 | 3271.00
125 | 3305.00
126 | 3324.00
127 | 3358.00
128 | 3377.00
129 | 3411.00
130 | 3430.00
131 | 3464.00
132 | 3483.00
133 | 3517.00
134 | 3536.00
135 | 3570.00
136 | 3589.00
137 | 3623.00
138 | 3642.00
139 | 3676.00
140 | 3695.00
141 | 3729.00
142 | 3748.00
143 | 3782.00
144 | 3801.00
145 | 3835.00
146 | 3854.00
147 | 3888.00
148 | 3907.00
149 | 3941.00
150 | 3960.00
151 | 3994.00
152 | 4013.00
153 | 4047.00
154 | 4066.00
155 | 4100.00
156 | 4119.00
157 | 4153.00
158 | 4172.00
159 | 4206.00
160 | 4225.00
161 | 4259.00
162 | 4278.00
163 | 4312.00
164 | 4331.00
165 | 4365.00
166 | 4384.00
167 | 4418.00
168 | 4437.00
169 | 4471.00
170 | 4490.00
171 | 4524.00
172 | 4543.00
173 | 4577.00
174 | 4596.00
175 | 4630.00
176 | 4649.00
177 | 4683.00
178 | 4702.00
179 | 4736.00
180 | 4755.00
181 | 4789.00
182 | 4808.00
183 | 4842.00
184 | 4861.00
185 | 4895.00
186 | 4914.00
187 | 4948.00
188 | 4967.00
189 | 5001.00
190 | 5020.00
191 | 5054.00
192 | 5073.00
193 | 5107.00
194 | 5126.00
195 | 5160.00
196 | 5179.00
197 | 5213.00
198 | 5232.00
199 | 5266.00
200 | 5285.00
201 | 5319.00
202 | 5338.00
203 | 5372.00
204 | 5391.00
205 | 5425.00
206 | 5444.00
207 | 5478.00
208 | 5497.00
209 | 5531.00
210 | 5550.00
211 | 5584.00
212 | 5603.00
213 | 5637.00
214 | 5656.00
215 | 5690.00
216 | 5709.00
217 | 5743.00
218 | 5762.00
219 | 5796.00
220 | 5815.00
221 | 5849.00
222 | 5868.00
223 | 5902.00
224 | 5921.00
225 | 5955.00
226 | 5974.00
227 | 6008.00
228 | 6027.00
229 | 6061.00
230 | 6080.00
231 | 6114.00
232 | 6133.00
233 | 6167.00
234 | 6186.00
235 | 6220.00
236 | 6239.00
237 | 6273.00
238 | 6292.00
239 | 6326.00
240 | 6345.00
241 | 6379.00
242 | 6398.00
243 | 6432.00
244 | 6451.00
245 | 6485.00
246 | 6504.00
247 | 6538.00
248 | 6557.00
249 | 6591.00
250 | 6610.00
251 | 6644.00
252 | 6663.00
253 | 6697.00
254 | 6716.00
255 | 6750.00
256 | 6769.00
257 | 6803.00
258 | 6822.00
259 | 6856.00
260 | 6875.00
261 | 6909.00
262 | 6928.00
263 | 6962.00
264 | 6981.00
265 | 7015.00
266 | 7034.00
267 | 7068.00
268 | 7087.00
269 | 7121.00
270 | 7140.00
271 | 7174.00
272 | 7193.00
273 | 7227.00
274 | 7246.00
275 | 7280.00
276 | 7299.00
277 | 7333.00
278 | 7352.00
279 | 7386.00
280 | 7405.00
281 | 7439.00
282 | 7458.00
283 | 7492.00
284 | 7511.00
285 | 7545.00
286 | 7564.00
287 | 7598.00
288 | 7617.00
289 | 7651.00
290 | 7670.00
291 | 7704.00
292 | 7723.00
293 | 7757.00
294 | 7776.00
295 | 7810.00
296 | 7829.00
297 | 7863.00
298 | 7882.00
299 | 7916.00
300 | 7935.00
301 | 7969.00
302 | 7988.00
303 | 8022.00
304 | 8041.00
305 | 8075.00
306 | 8094.00
307 | 8128.00
308 | 8147.00
309 | 8181.00
310 | 8200.00
311 | 8234.00
312 | 8253.00
313 | 8287.00
314 | 8306.00
315 | 8340.00
316 | 8359.00
317 | 8393.00
318 | 8412.00
319 | 8446.00
320 | 8465.00
321 | 8499.00
322 | 8518.00
323 | 8552.00
324 | 8571.00
325 | 8605.00
326 | 8624.00
327 | 8658.00
328 | 8677.00
329 | 8711.00
330 | 8730.00
331 | 8764.00
332 | 8783.00
333 | 8817.00
334 | 8836.00
335 | 8870.00
336 | 8889.00
337 | 8923.00
338 | 8942.00
339 | 8976.00
340 | 8995.00
341 | 9029.00
342 | 9048.00
343 | 9082.00
344 | 9101.00
345 | 9135.00
346 | 9154.00
347 | 9188.00
348 | 9207.00
349 | 9241.00
350 | 9260.00
351 | 9294.00
352 | 9313.00
353 | 9347.00
354 | 9366.00
355 | 9400.00
356 | 9419.00
357 | 9453.00
358 | 9472.00
359 | 9506.00
360 | 9525.00
361 | 9559.00
362 | 9578.00
363 | 9612.00
364 | 9631.00
365 | 9665.00
366 | 9684.00
367 | 9718.00
368 | 9737.00
369 | 9771.00
370 | 9790.00
371 | 9824.00
372 | 9843.00
373 | 9877.00
374 | 9896.00
375 | 9930.00
376 | 9949.00
377 | 9983.00
378 | 10002.00
379 | 10036.00
380 | 10055.00
381 | 10089.00
382 | 10108.00
383 | 10142.00
384 | 10161.00
385 | 10195.00
386 | 10214.00
387 | 10248.00
388 | 10267.00
389 | 10301.00
390 | 10320.00
391 | 10354.00
392 | 10373.00
393 | 10407.00
394 | 10426.00
395 | 10460.00
396 | 10479.00
397 | 10513.00
398 | 10532.00
399 | 10566.00
400 | 10585.00
401 |
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_0/plot_it_y_values.txt:
--------------------------------------------------------------------------------
1 | 2.29
2 | 2.20
3 | 1.98
4 | 2.13
5 | 2.03
6 | 2.19
7 | 1.74
8 | 1.64
9 | 1.80
10 | 1.83
11 | 1.62
12 | 1.42
13 | 1.61
14 | 1.51
15 | 1.56
16 | 1.46
17 | 1.45
18 | 1.39
19 | 1.34
20 | 1.23
21 | 1.33
22 | 1.24
23 | 1.35
24 | 1.38
25 | 1.20
26 | 1.24
27 | 1.14
28 | 1.13
29 | 1.39
30 | 1.25
31 | 1.28
32 | 1.08
33 | 1.40
34 | 1.15
35 | 1.44
36 | 1.29
37 | 1.03
38 | 0.86
39 | 0.96
40 | 0.91
41 | 0.97
42 | 1.20
43 | 0.93
44 | 1.05
45 | 0.98
46 | 0.89
47 | 0.88
48 | 0.68
49 | 0.88
50 | 0.70
51 | 0.80
52 | 0.92
53 | 0.80
54 | 0.85
55 | 0.72
56 | 0.78
57 | 0.76
58 | 0.86
59 | 1.06
60 | 0.87
61 | 0.76
62 | 0.75
63 | 0.85
64 | 0.79
65 | 0.83
66 | 0.63
67 | 0.60
68 | 0.65
69 | 0.79
70 | 0.69
71 | 0.70
72 | 0.60
73 | 0.96
74 | 0.65
75 | 0.84
76 | 0.62
77 | 0.89
78 | 0.74
79 | 0.59
80 | 0.58
81 | 0.69
82 | 0.58
83 | 0.73
84 | 0.62
85 | 0.70
86 | 0.76
87 | 0.60
88 | 0.60
89 | 0.68
90 | 0.47
91 | 0.54
92 | 0.42
93 | 0.80
94 | 0.59
95 | 0.54
96 | 0.47
97 | 0.60
98 | 0.41
99 | 0.74
100 | 0.57
101 | 0.65
102 | 0.55
103 | 0.70
104 | 0.45
105 | 0.44
106 | 0.44
107 | 0.49
108 | 0.44
109 | 0.43
110 | 0.59
111 | 0.55
112 | 0.39
113 | 0.42
114 | 0.45
115 | 0.38
116 | 0.34
117 | 0.54
118 | 0.41
119 | 0.38
120 | 0.23
121 | 0.61
122 | 0.34
123 | 0.26
124 | 0.24
125 | 0.28
126 | 0.37
127 | 0.51
128 | 0.37
129 | 0.39
130 | 0.37
131 | 0.37
132 | 0.29
133 | 0.31
134 | 0.22
135 | 0.28
136 | 0.20
137 | 0.32
138 | 0.34
139 | 0.57
140 | 0.24
141 | 0.25
142 | 0.31
143 | 0.76
144 | 0.31
145 | 0.24
146 | 0.25
147 | 0.34
148 | 0.21
149 | 0.19
150 | 0.23
151 | 0.26
152 | 0.27
153 | 0.32
154 | 0.24
155 | 0.17
156 | 0.29
157 | 0.60
158 | 0.37
159 | 0.17
160 | 0.28
161 | 0.42
162 | 0.28
163 | 0.25
164 | 0.31
165 | 0.30
166 | 0.16
167 | 0.20
168 | 0.13
169 | 0.12
170 | 0.16
171 | 0.09
172 | 0.12
173 | 0.18
174 | 0.14
175 | 0.31
176 | 0.21
177 | 0.27
178 | 0.22
179 | 0.06
180 | 0.17
181 | 0.07
182 | 0.18
183 | 0.22
184 | 0.28
185 | 0.26
186 | 0.22
187 | 0.14
188 | 0.27
189 | 0.35
190 | 0.13
191 | 0.20
192 | 0.09
193 | 0.11
194 | 0.04
195 | 0.23
196 | 0.23
197 | 0.37
198 | 0.23
199 | 0.30
200 | 0.17
201 | 0.12
202 | 0.21
203 | 0.05
204 | 0.07
205 | 0.17
206 | 0.12
207 | 0.21
208 | 0.10
209 | 0.21
210 | 0.08
211 | 0.08
212 | 0.12
213 | 0.10
214 | 0.13
215 | 0.18
216 | 0.12
217 | 0.24
218 | 0.11
219 | 0.06
220 | 0.07
221 | 0.13
222 | 0.07
223 | 0.12
224 | 0.08
225 | 0.07
226 | 0.05
227 | 0.20
228 | 0.15
229 | 0.23
230 | 0.14
231 | 0.09
232 | 0.12
233 | 0.20
234 | 0.13
235 | 0.29
236 | 0.15
237 | 0.11
238 | 0.09
239 | 0.05
240 | 0.07
241 | 0.08
242 | 0.05
243 | 0.03
244 | 0.03
245 | 0.10
246 | 0.14
247 | 0.04
248 | 0.06
249 | 0.09
250 | 0.07
251 | 0.12
252 | 0.19
253 | 0.02
254 | 0.01
255 | 0.05
256 | 0.05
257 | 0.02
258 | 0.04
259 | 0.06
260 | 0.10
261 | 0.06
262 | 0.09
263 | 0.06
264 | 0.05
265 | 0.17
266 | 0.03
267 | 0.06
268 | 0.03
269 | 0.10
270 | 0.11
271 | 0.10
272 | 0.06
273 | 0.04
274 | 0.08
275 | 0.02
276 | 0.08
277 | 0.01
278 | 0.04
279 | 0.03
280 | 0.01
281 | 0.07
282 | 0.03
283 | 0.04
284 | 0.04
285 | 0.05
286 | 0.03
287 | 0.02
288 | 0.01
289 | 0.04
290 | 0.03
291 | 0.02
292 | 0.00
293 | 0.03
294 | 0.02
295 | 0.02
296 | 0.01
297 | 0.04
298 | 0.05
299 | 0.03
300 | 0.01
301 | 0.01
302 | 0.02
303 | 0.02
304 | 0.01
305 | 0.01
306 | 0.01
307 | 0.07
308 | 0.06
309 | 0.01
310 | 0.02
311 | 0.02
312 | 0.05
313 | 0.01
314 | 0.00
315 | 0.01
316 | 0.01
317 | 0.01
318 | 0.01
319 | 0.03
320 | 0.02
321 | 0.01
322 | 0.03
323 | 0.00
324 | 0.02
325 | 0.01
326 | 0.03
327 | 0.00
328 | 0.01
329 | 0.01
330 | 0.00
331 | 0.02
332 | 0.02
333 | 0.02
334 | 0.03
335 | 0.00
336 | 0.01
337 | 0.01
338 | 0.02
339 | 0.00
340 | 0.01
341 | 0.00
342 | 0.01
343 | 0.02
344 | 0.04
345 | 0.01
346 | 0.00
347 | 0.01
348 | 0.00
349 | 0.03
350 | 0.02
351 | 0.00
352 | 0.04
353 | 0.03
354 | 0.01
355 | 0.01
356 | 0.02
357 | 0.02
358 | 0.00
359 | 0.01
360 | 0.00
361 | 0.04
362 | 0.00
363 | 0.00
364 | 0.01
365 | 0.04
366 | 0.00
367 | 0.01
368 | 0.00
369 | 0.02
370 | 0.01
371 | 0.01
372 | 0.02
373 | 0.00
374 | 0.00
375 | 0.00
376 | 0.01
377 | 0.01
378 | 0.03
379 | 0.00
380 | 0.01
381 | 0.01
382 | 0.00
383 | 0.01
384 | 0.01
385 | 0.01
386 | 0.02
387 | 0.02
388 | 0.01
389 | 0.01
390 | 0.02
391 | 0.00
392 | 0.00
393 | 0.01
394 | 0.01
395 | 0.01
396 | 0.00
397 | 0.00
398 | 0.00
399 | 0.01
400 | 0.05
401 |
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_0/uSet.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_0/uSet.npy
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_0/val_acc_epochs_x.txt:
--------------------------------------------------------------------------------
1 | 2.00
2 | 4.00
3 | 6.00
4 | 8.00
5 | 10.00
6 | 12.00
7 | 14.00
8 | 16.00
9 | 18.00
10 | 20.00
11 | 22.00
12 | 24.00
13 | 26.00
14 | 28.00
15 | 30.00
16 | 32.00
17 | 34.00
18 | 36.00
19 | 38.00
20 | 40.00
21 | 42.00
22 | 44.00
23 | 46.00
24 | 48.00
25 | 50.00
26 | 52.00
27 | 54.00
28 | 56.00
29 | 58.00
30 | 60.00
31 | 62.00
32 | 64.00
33 | 66.00
34 | 68.00
35 | 70.00
36 | 72.00
37 | 74.00
38 | 76.00
39 | 78.00
40 | 80.00
41 | 82.00
42 | 84.00
43 | 86.00
44 | 88.00
45 | 90.00
46 | 92.00
47 | 94.00
48 | 96.00
49 | 98.00
50 | 100.00
51 | 102.00
52 | 104.00
53 | 106.00
54 | 108.00
55 | 110.00
56 | 112.00
57 | 114.00
58 | 116.00
59 | 118.00
60 | 120.00
61 | 122.00
62 | 124.00
63 | 126.00
64 | 128.00
65 | 130.00
66 | 132.00
67 | 134.00
68 | 136.00
69 | 138.00
70 | 140.00
71 | 142.00
72 | 144.00
73 | 146.00
74 | 148.00
75 | 150.00
76 | 152.00
77 | 154.00
78 | 156.00
79 | 158.00
80 | 160.00
81 | 162.00
82 | 164.00
83 | 166.00
84 | 168.00
85 | 170.00
86 | 172.00
87 | 174.00
88 | 176.00
89 | 178.00
90 | 180.00
91 | 182.00
92 | 184.00
93 | 186.00
94 | 188.00
95 | 190.00
96 | 192.00
97 | 194.00
98 | 196.00
99 | 198.00
100 | 200.00
101 |
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_0/val_acc_epochs_y.txt:
--------------------------------------------------------------------------------
1 | 19.56
2 | 35.18
3 | 36.56
4 | 44.68
5 | 47.56
6 | 49.50
7 | 50.68
8 | 49.66
9 | 48.96
10 | 56.78
11 | 53.22
12 | 58.02
13 | 58.88
14 | 57.86
15 | 56.76
16 | 58.34
17 | 55.04
18 | 50.74
19 | 47.78
20 | 62.92
21 | 63.94
22 | 65.36
23 | 56.06
24 | 64.24
25 | 62.10
26 | 56.70
27 | 63.36
28 | 65.98
29 | 62.76
30 | 60.18
31 | 64.94
32 | 68.96
33 | 65.54
34 | 63.70
35 | 66.26
36 | 63.90
37 | 68.60
38 | 68.60
39 | 61.42
40 | 65.40
41 | 68.52
42 | 68.28
43 | 68.88
44 | 70.22
45 | 70.58
46 | 68.96
47 | 65.02
48 | 71.36
49 | 66.46
50 | 70.48
51 | 67.36
52 | 66.72
53 | 68.16
54 | 69.12
55 | 68.24
56 | 70.72
57 | 67.44
58 | 69.26
59 | 69.18
60 | 71.04
61 | 68.94
62 | 70.64
63 | 71.64
64 | 69.16
65 | 70.86
66 | 71.38
67 | 69.28
68 | 71.44
69 | 70.98
70 | 71.54
71 | 72.78
72 | 71.52
73 | 72.22
74 | 71.96
75 | 71.92
76 | 71.94
77 | 70.78
78 | 72.50
79 | 72.64
80 | 72.78
81 | 72.32
82 | 71.52
83 | 71.84
84 | 71.88
85 | 71.82
86 | 71.84
87 | 72.10
88 | 72.26
89 | 72.42
90 | 73.06
91 | 72.02
92 | 71.98
93 | 72.28
94 | 72.50
95 | 72.28
96 | 71.66
97 | 72.40
98 | 72.56
99 | 72.18
100 | 72.52
101 |
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_1/Epochs_vs_Loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_1/Epochs_vs_Loss.png
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_1/Epochs_vs_Validation Accuracy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_1/Epochs_vs_Validation Accuracy.png
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_1/Iterations_vs_Loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_1/Iterations_vs_Loss.png
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_1/activeSet.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_1/activeSet.npy
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_1/lSet.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_1/lSet.npy
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_1/plot_epoch_xvalues.txt:
--------------------------------------------------------------------------------
1 | 1.00
2 | 2.00
3 | 3.00
4 | 4.00
5 | 5.00
6 | 6.00
7 | 7.00
8 | 8.00
9 | 9.00
10 | 10.00
11 | 11.00
12 | 12.00
13 | 13.00
14 | 14.00
15 | 15.00
16 | 16.00
17 | 17.00
18 | 18.00
19 | 19.00
20 | 20.00
21 | 21.00
22 | 22.00
23 | 23.00
24 | 24.00
25 | 25.00
26 | 26.00
27 | 27.00
28 | 28.00
29 | 29.00
30 | 30.00
31 | 31.00
32 | 32.00
33 | 33.00
34 | 34.00
35 | 35.00
36 | 36.00
37 | 37.00
38 | 38.00
39 | 39.00
40 | 40.00
41 | 41.00
42 | 42.00
43 | 43.00
44 | 44.00
45 | 45.00
46 | 46.00
47 | 47.00
48 | 48.00
49 | 49.00
50 | 50.00
51 | 51.00
52 | 52.00
53 | 53.00
54 | 54.00
55 | 55.00
56 | 56.00
57 | 57.00
58 | 58.00
59 | 59.00
60 | 60.00
61 | 61.00
62 | 62.00
63 | 63.00
64 | 64.00
65 | 65.00
66 | 66.00
67 | 67.00
68 | 68.00
69 | 69.00
70 | 70.00
71 | 71.00
72 | 72.00
73 | 73.00
74 | 74.00
75 | 75.00
76 | 76.00
77 | 77.00
78 | 78.00
79 | 79.00
80 | 80.00
81 | 81.00
82 | 82.00
83 | 83.00
84 | 84.00
85 | 85.00
86 | 86.00
87 | 87.00
88 | 88.00
89 | 89.00
90 | 90.00
91 | 91.00
92 | 92.00
93 | 93.00
94 | 94.00
95 | 95.00
96 | 96.00
97 | 97.00
98 | 98.00
99 | 99.00
100 | 100.00
101 | 101.00
102 | 102.00
103 | 103.00
104 | 104.00
105 | 105.00
106 | 106.00
107 | 107.00
108 | 108.00
109 | 109.00
110 | 110.00
111 | 111.00
112 | 112.00
113 | 113.00
114 | 114.00
115 | 115.00
116 | 116.00
117 | 117.00
118 | 118.00
119 | 119.00
120 | 120.00
121 | 121.00
122 | 122.00
123 | 123.00
124 | 124.00
125 | 125.00
126 | 126.00
127 | 127.00
128 | 128.00
129 | 129.00
130 | 130.00
131 | 131.00
132 | 132.00
133 | 133.00
134 | 134.00
135 | 135.00
136 | 136.00
137 | 137.00
138 | 138.00
139 | 139.00
140 | 140.00
141 | 141.00
142 | 142.00
143 | 143.00
144 | 144.00
145 | 145.00
146 | 146.00
147 | 147.00
148 | 148.00
149 | 149.00
150 | 150.00
151 | 151.00
152 | 152.00
153 | 153.00
154 | 154.00
155 | 155.00
156 | 156.00
157 | 157.00
158 | 158.00
159 | 159.00
160 | 160.00
161 | 161.00
162 | 162.00
163 | 163.00
164 | 164.00
165 | 165.00
166 | 166.00
167 | 167.00
168 | 168.00
169 | 169.00
170 | 170.00
171 | 171.00
172 | 172.00
173 | 173.00
174 | 174.00
175 | 175.00
176 | 176.00
177 | 177.00
178 | 178.00
179 | 179.00
180 | 180.00
181 | 181.00
182 | 182.00
183 | 183.00
184 | 184.00
185 | 185.00
186 | 186.00
187 | 187.00
188 | 188.00
189 | 189.00
190 | 190.00
191 | 191.00
192 | 192.00
193 | 193.00
194 | 194.00
195 | 195.00
196 | 196.00
197 | 197.00
198 | 198.00
199 | 199.00
200 | 200.00
201 |
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_1/plot_epoch_yvalues.txt:
--------------------------------------------------------------------------------
(per-epoch training loss: 200 values, decaying noisily from 1.02 at epoch 1 to 0.00 by epoch 200)
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_1/uSet.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_1/uSet.npy
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_1/val_acc_epochs_x.txt:
--------------------------------------------------------------------------------
(validation checkpoint epochs: 2.00 through 200.00 in steps of 2, one value per line)
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_1/val_acc_epochs_y.txt:
--------------------------------------------------------------------------------
(validation accuracy at each checkpoint: rises from 70.04 to 82.94, peaking at 83.08)
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_2/Epochs_vs_Loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_2/Epochs_vs_Loss.png
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_2/Epochs_vs_Validation Accuracy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_2/Epochs_vs_Validation Accuracy.png
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_2/Iterations_vs_Loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_2/Iterations_vs_Loss.png
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_2/activeSet.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_2/activeSet.npy
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_2/lSet.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_2/lSet.npy
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_2/plot_epoch_xvalues.txt:
--------------------------------------------------------------------------------
(epoch indices for the loss curve: 1.00 through 200.00, one value per line)
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_2/plot_epoch_yvalues.txt:
--------------------------------------------------------------------------------
(per-epoch training loss: 200 values, decaying noisily from 1.36 at epoch 1 to 0.00 by epoch 200)
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_2/uSet.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_2/uSet.npy
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_2/val_acc_epochs_x.txt:
--------------------------------------------------------------------------------
(validation checkpoint epochs: 2.00 through 200.00 in steps of 2, one value per line)
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_2/val_acc_epochs_y.txt:
--------------------------------------------------------------------------------
(validation accuracy at each checkpoint: rises from 77.78 to 87.82, peaking at 87.88)
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_3/Epochs_vs_Loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_3/Epochs_vs_Loss.png
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_3/Epochs_vs_Validation Accuracy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_3/Epochs_vs_Validation Accuracy.png
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_3/Iterations_vs_Loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_3/Iterations_vs_Loss.png
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_3/activeSet.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_3/activeSet.npy
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_3/lSet.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_3/lSet.npy
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_3/plot_epoch_xvalues.txt:
--------------------------------------------------------------------------------
(epoch indices for the loss curve: 1.00 through 200.00, one value per line)
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_3/plot_epoch_yvalues.txt:
--------------------------------------------------------------------------------
(per-epoch training loss: 200 values, decaying noisily from 0.43 at epoch 1 to 0.00 by epoch 200)
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_3/uSet.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_3/uSet.npy
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_3/val_acc_epochs_x.txt:
--------------------------------------------------------------------------------
(validation checkpoint epochs: 2.00 through 200.00 in steps of 2, one value per line)
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_3/val_acc_epochs_y.txt:
--------------------------------------------------------------------------------
(validation accuracy at each checkpoint: rises from 80.46 to 90.24, peaking at 90.48)
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_4/Epochs_vs_Loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_4/Epochs_vs_Loss.png
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_4/Epochs_vs_Validation Accuracy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_4/Epochs_vs_Validation Accuracy.png
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_4/Iterations_vs_Loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_4/Iterations_vs_Loss.png
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_4/activeSet.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_4/activeSet.npy
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_4/lSet.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_4/lSet.npy
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_4/plot_epoch_xvalues.txt:
--------------------------------------------------------------------------------
(epoch indices for the loss curve: 1.00 through 200.00, one value per line)
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_4/plot_epoch_yvalues.txt:
--------------------------------------------------------------------------------
(per-epoch training loss: 200 values, decaying noisily from 0.13 at epoch 1 to 0.00 by epoch 200)
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_4/uSet.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_4/uSet.npy
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_4/val_acc_epochs_x.txt:
--------------------------------------------------------------------------------
(validation checkpoint epochs: 2.00 through 200.00 in steps of 2, one value per line)
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_4/val_acc_epochs_y.txt:
--------------------------------------------------------------------------------
(validation accuracy at each checkpoint: rises from 81.86 to 90.80, peaking at 91.04)
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_5/Epochs_vs_Loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_5/Epochs_vs_Loss.png
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_5/Epochs_vs_Validation Accuracy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_5/Epochs_vs_Validation Accuracy.png
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_5/Iterations_vs_Loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/episode_5/Iterations_vs_Loss.png
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_5/plot_epoch_xvalues.txt:
--------------------------------------------------------------------------------
(epoch indices for the loss curve: 1.00 through 200.00, one value per line)
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_5/plot_epoch_yvalues.txt:
--------------------------------------------------------------------------------
(per-epoch training loss: 200 values, decaying noisily from 0.21 at epoch 1 to 0.00 by epoch 200)
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_5/val_acc_epochs_x.txt:
--------------------------------------------------------------------------------
(validation checkpoint epochs: 2.00 through 200.00 in steps of 2, one value per line)
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/episode_5/val_acc_epochs_y.txt:
--------------------------------------------------------------------------------
(validation accuracy at each checkpoint: rises from 82.96 to 91.50, peaking at 91.80)
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/lSet.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/lSet.npy
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/plot_episode_xvalues.txt:
--------------------------------------------------------------------------------
1 | 0.00
2 | 1.00
3 | 2.00
4 | 3.00
5 | 4.00
6 | 5.00
7 |
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/plot_episode_yvalues.txt:
--------------------------------------------------------------------------------
1 | 71.83
2 | 82.86
3 | 87.32
4 | 90.04
5 | 90.88
6 | 91.12
7 |
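
plot_episode_xvalues.txt and plot_episode_yvalues.txt together trace test accuracy across the six active-learning episodes of this run (71.83 at episode 0, the seed labelled set, up to 91.12 after episode 5). A minimal sketch of how the pair could be loaded and re-plotted with numpy and matplotlib; the run directory is this experiment's, everything else is illustrative and not part of the repository's API:

    import numpy as np
    import matplotlib.pyplot as plt

    run_dir = "output/CIFAR10/resnet18/ENT_1"
    # One float per line; np.loadtxt skips the trailing blank line.
    episodes = np.loadtxt(f"{run_dir}/plot_episode_xvalues.txt")
    test_acc = np.loadtxt(f"{run_dir}/plot_episode_yvalues.txt")

    plt.plot(episodes, test_acc, marker="o")
    plt.xlabel("Episode")
    plt.ylabel("Test Accuracy (%)")
    plt.savefig("episodes_vs_test_accuracy.png")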
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/uSet.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/uSet.npy
--------------------------------------------------------------------------------
/output/CIFAR10/resnet18/ENT_1/valSet.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/output/CIFAR10/resnet18/ENT_1/valSet.npy
--------------------------------------------------------------------------------
/pycls/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/pycls/__init__.py
--------------------------------------------------------------------------------
/pycls/al/ActiveLearning.py:
--------------------------------------------------------------------------------
1 | # This file is slightly modified from a code implementation by Prateek Munjal et al., authors of the paper https://arxiv.org/abs/2002.09564
2 | # GitHub: https://github.com/PrateekMunjal
3 | # ----------------------------------------------------------
4 |
5 | import numpy as np
6 | import torch
7 | from .Sampling import Sampling, CoreSetMIPSampling, AdversarySampler
8 | import pycls.utils.logging as lu
9 | import os
10 |
11 | logger = lu.get_logger(__name__)
12 |
13 | class ActiveLearning:
14 | """
15 | Implements standard active learning methods.
16 | """
17 |
18 | def __init__(self, dataObj, cfg):
19 | self.dataObj = dataObj
20 | self.sampler = Sampling(dataObj=dataObj,cfg=cfg)
21 | self.cfg = cfg
22 |
23 | def sample_from_uSet(self, clf_model, lSet, uSet, trainDataset, supportingModels=None):
24 | """
25 | Sample from uSet using cfg.ACTIVE_LEARNING.SAMPLING_FN.
26 |
27 | INPUT
28 | ------
29 | clf_model: Reference to the task classifier model [typically VGG]
30 |
31 | supportingModels: List of models used in the sampling process.
32 |
33 | OUTPUT
34 | -------
35 | Returns activeSet, uSet
36 | """
37 | assert self.cfg.ACTIVE_LEARNING.BUDGET_SIZE > 0, "Expected a positive budgetSize"
38 | assert self.cfg.ACTIVE_LEARNING.BUDGET_SIZE < len(uSet), "Budget size cannot exceed the length of the unlabelled set. Length of unlabelled set: {} and budgetSize: {}"\
39 | .format(len(uSet), self.cfg.ACTIVE_LEARNING.BUDGET_SIZE)
40 |
41 | if self.cfg.ACTIVE_LEARNING.SAMPLING_FN == "random":
42 |
43 | activeSet, uSet = self.sampler.random(uSet=uSet, budgetSize=self.cfg.ACTIVE_LEARNING.BUDGET_SIZE)
44 |
45 | elif self.cfg.ACTIVE_LEARNING.SAMPLING_FN == "uncertainty":
46 | oldmode = clf_model.training
47 | clf_model.eval()
48 | activeSet, uSet = self.sampler.uncertainty(budgetSize=self.cfg.ACTIVE_LEARNING.BUDGET_SIZE,lSet=lSet,uSet=uSet \
49 | ,model=clf_model,dataset=trainDataset)
50 | clf_model.train(oldmode)
51 |
52 | elif self.cfg.ACTIVE_LEARNING.SAMPLING_FN == "entropy":
53 | oldmode = clf_model.training
54 | clf_model.eval()
55 | activeSet, uSet = self.sampler.entropy(budgetSize=self.cfg.ACTIVE_LEARNING.BUDGET_SIZE,lSet=lSet,uSet=uSet \
56 | ,model=clf_model,dataset=trainDataset)
57 | clf_model.train(oldmode)
58 |
59 | elif self.cfg.ACTIVE_LEARNING.SAMPLING_FN == "margin":
60 | oldmode = clf_model.training
61 | clf_model.eval()
62 | activeSet, uSet = self.sampler.margin(budgetSize=self.cfg.ACTIVE_LEARNING.BUDGET_SIZE,lSet=lSet,uSet=uSet \
63 | ,model=clf_model,dataset=trainDataset)
64 | clf_model.train(oldmode)
65 |
66 | elif self.cfg.ACTIVE_LEARNING.SAMPLING_FN == "coreset":
67 | waslatent = clf_model.penultimate_active
68 | wastrain = clf_model.training
69 | clf_model.penultimate_active = True
70 | # if self.cfg.TRAIN.DATASET == "IMAGENET":
71 | # clf_model.cuda(0)
72 | clf_model.eval()
73 | coreSetSampler = CoreSetMIPSampling(cfg=self.cfg, dataObj=self.dataObj)
74 | activeSet, uSet = coreSetSampler.query(lSet=lSet, uSet=uSet, clf_model=clf_model, dataset=trainDataset)
75 |
76 | clf_model.penultimate_active = waslatent
77 | clf_model.train(wastrain)
78 |
79 | elif self.cfg.ACTIVE_LEARNING.SAMPLING_FN == "dbal" or self.cfg.ACTIVE_LEARNING.SAMPLING_FN == "DBAL":
80 | activeSet, uSet = self.sampler.dbal(budgetSize=self.cfg.ACTIVE_LEARNING.BUDGET_SIZE, \
81 | uSet=uSet, clf_model=clf_model,dataset=trainDataset)
82 |
83 | elif self.cfg.ACTIVE_LEARNING.SAMPLING_FN == "bald" or self.cfg.ACTIVE_LEARNING.SAMPLING_FN == "BALD":
84 | activeSet, uSet = self.sampler.bald(budgetSize=self.cfg.ACTIVE_LEARNING.BUDGET_SIZE, uSet=uSet, clf_model=clf_model, dataset=trainDataset)
85 |
86 | elif self.cfg.ACTIVE_LEARNING.SAMPLING_FN == "ensemble_var_R":
87 | activeSet, uSet = self.sampler.ensemble_var_R(budgetSize=self.cfg.ACTIVE_LEARNING.BUDGET_SIZE, uSet=uSet, clf_models=supportingModels, dataset=trainDataset)
88 |
89 | elif self.cfg.ACTIVE_LEARNING.SAMPLING_FN == "vaal":
90 | adv_sampler = AdversarySampler(cfg=self.cfg, dataObj=self.dataObj)
91 |
92 | # Train VAE and discriminator first
93 | vae, disc, uSet_loader = adv_sampler.vaal_perform_training(lSet=lSet, uSet=uSet, dataset=trainDataset)
94 |
95 | # Do active sampling
96 | activeSet, uSet = adv_sampler.sample_for_labeling(vae=vae, discriminator=disc, \
97 | unlabeled_dataloader=uSet_loader, uSet=uSet)
98 | else:
99 | print(f"{self.cfg.ACTIVE_LEARNING.SAMPLING_FN} is either not implemented or there is some spelling mistake.")
100 | raise NotImplementedError
101 |
102 | return activeSet, uSet
103 |
104 |
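
For orientation, ActiveLearning.sample_from_uSet is a dispatcher over cfg.ACTIVE_LEARNING.SAMPLING_FN that returns BUDGET_SIZE newly selected indices (activeSet) plus the shrunken unlabelled pool. A minimal sketch of one acquisition step, assuming a trained clf_model, a populated cfg, and the data_obj/train_dataset pair produced by the repository's data pipeline (those four names are placeholders here):

    import numpy as np
    from pycls.al.ActiveLearning import ActiveLearning

    # Index arrays saved by a previous episode of this run.
    l_set = np.load("output/CIFAR10/resnet18/ENT_1/lSet.npy")
    u_set = np.load("output/CIFAR10/resnet18/ENT_1/uSet.npy")

    al = ActiveLearning(dataObj=data_obj, cfg=cfg)
    active_set, u_set = al.sample_from_uSet(clf_model, l_set, u_set, train_dataset)

    # The selected indices join the labelled pool for the next episode.
    l_set = np.append(l_set, active_set)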
--------------------------------------------------------------------------------
/pycls/al/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/pycls/al/__init__.py
--------------------------------------------------------------------------------
/pycls/al/vaal_util.py:
--------------------------------------------------------------------------------
1 | # This file is directly taken from a code implementation shared with me by Prateek Munjal et al., authors of the paper https://arxiv.org/abs/2002.09564
2 | # GitHub: https://github.com/PrateekMunjal
3 | # ----------------------------------------------------------
4 |
5 | # code modified from VAAL codebase
6 |
7 | import os
8 | import torch
9 | import numpy as np
10 | from tqdm import tqdm
11 |
12 | from pycls.models import vaal_model as vm
13 | import pycls.utils.logging as lu
14 | # import pycls.datasets.loader as imagenet_loader
15 |
16 | logger = lu.get_logger(__name__)
17 |
18 | bce_loss = torch.nn.BCELoss().cuda()
19 |
20 | def data_parallel_wrapper(model, cur_device, cfg):
21 | model.cuda(cur_device)
22 | model = torch.nn.DataParallel(model, device_ids = [i for i in range(torch.cuda.device_count())])
23 | return model
24 |
25 | def distributed_wrapper(cfg, model, cur_device):
26 | # Transfer the model to the current GPU device
27 | model = model.cuda(device=cur_device)
28 |
29 | # Use multi-process data parallel model in the multi-gpu setting
30 | if cfg.NUM_GPUS > 1:
31 | # Make model replica operate on the current device
32 | model = torch.nn.parallel.DistributedDataParallel(
33 | module=model,
34 | device_ids=[cur_device],
35 | output_device=cur_device
36 | )
37 | return model
38 |
39 | def read_data(dataloader, labels=True):
40 | if labels:
41 | while True:
42 | for img, label in dataloader:
43 | yield img, label
44 | else:
45 | while True:
46 | for img, _ in dataloader:
47 | yield img
48 |
49 | def vae_loss(x, recon, mu, logvar, beta):
50 | mse_loss = torch.nn.MSELoss().cuda()
51 | recon = recon.cuda()
52 | x = x.cuda()
53 | MSE = mse_loss(recon, x)
54 | KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
55 | KLD = KLD * beta
56 | return MSE + KLD
57 |
58 | def train_vae_disc_epoch(cfg, vae_model, disc_model, optim_vae, optim_disc, lSetLoader, uSetLoader, cur_epoch, \
59 | n_lu, curr_vae_disc_iter, max_vae_disc_iters, change_lr_iter,isDistributed=False):
60 |
61 | if isDistributed:
62 | lSetLoader.sampler.set_epoch(cur_epoch)
63 | uSetLoader.sampler.set_epoch(cur_epoch)
64 |
65 | print('len(lSetLoader): {}'.format(len(lSetLoader)))
66 | print('len(uSetLoader): {}'.format(len(uSetLoader)))
67 |
68 | labeled_data = read_data(lSetLoader)
69 | unlabeled_data = read_data(uSetLoader, labels=False)
70 |
71 | vae_model.train()
72 | disc_model.train()
73 |
74 | temp_bs = int(cfg.VAAL.VAE_BS)
75 | train_iterations = int(n_lu/temp_bs)
76 |
77 | for temp_iter in range(train_iterations):
78 |
79 | if curr_vae_disc_iter != 0 and curr_vae_disc_iter % change_lr_iter == 0:
80 | #print("Changing LR ---- ))__((---- ")
81 | for param in optim_vae.param_groups:
82 | param['lr'] = param['lr'] * 0.9
83 |
84 | for param in optim_disc.param_groups:
85 | param['lr'] = param['lr'] * 0.9
86 |
87 | curr_vae_disc_iter += 1
88 |
89 | ## VAE Step
90 | disc_model.eval()
91 | vae_model.train()
92 | labeled_imgs, labels = next(labeled_data)
93 | unlabeled_imgs = next(unlabeled_data)
94 |
95 | labeled_imgs = labeled_imgs.type(torch.cuda.FloatTensor)
96 | unlabeled_imgs = unlabeled_imgs.type(torch.cuda.FloatTensor)
97 |
98 | labeled_imgs = labeled_imgs.cuda()
99 | unlabeled_imgs = unlabeled_imgs.cuda()
100 |
101 | recon, z, mu, logvar = vae_model(labeled_imgs)
102 | recon = recon.view((labeled_imgs.shape[0], labeled_imgs.shape[1], labeled_imgs.shape[2], labeled_imgs.shape[3]))
103 | unsup_loss = vae_loss(labeled_imgs, recon, mu, logvar, cfg.VAAL.BETA)
104 | unlab_recon, unlab_z, unlab_mu, unlab_logvar = vae_model(unlabeled_imgs)
105 | unlab_recon = unlab_recon.view((unlabeled_imgs.shape[0], unlabeled_imgs.shape[1], unlabeled_imgs.shape[2], unlabeled_imgs.shape[3]))
106 | transductive_loss = vae_loss(unlabeled_imgs, unlab_recon, unlab_mu, unlab_logvar, cfg.VAAL.BETA)
107 |
108 | labeled_preds = disc_model(mu)
109 | unlabeled_preds = disc_model(unlab_mu)
110 |
111 | lab_real_preds = torch.ones(labeled_imgs.size(0),1).cuda()
112 | unlab_real_preds = torch.ones(unlabeled_imgs.size(0),1).cuda()
113 | dsc_loss = bce_loss(labeled_preds, lab_real_preds) + bce_loss(unlabeled_preds, unlab_real_preds)
114 |
115 | total_vae_loss = unsup_loss + transductive_loss + cfg.VAAL.ADVERSARY_PARAM * dsc_loss
116 |
117 | optim_vae.zero_grad()
118 | total_vae_loss.backward()
119 | optim_vae.step()
120 |
121 | ## Disc Step
122 | vae_model.eval()
123 | disc_model.train()
124 |
125 | with torch.no_grad():
126 | _, _, mu, _ = vae_model(labeled_imgs)
127 | _, _, unlab_mu, _ = vae_model(unlabeled_imgs)
128 |
129 | labeled_preds = disc_model(mu)
130 | unlabeled_preds = disc_model(unlab_mu)
131 |
132 | lab_real_preds = torch.ones(labeled_imgs.size(0),1).cuda()
133 | unlab_fake_preds = torch.zeros(unlabeled_imgs.size(0),1).cuda()
134 |
135 | dsc_loss = bce_loss(labeled_preds, lab_real_preds) + \
136 | bce_loss(unlabeled_preds, unlab_fake_preds)
137 |
138 | optim_disc.zero_grad()
139 | dsc_loss.backward()
140 | optim_disc.step()
141 |
142 |
143 | if temp_iter%100 == 0:
144 | print("Epoch[{}],Iteration [{}/{}], VAE Loss: {:.3f}, Disc Loss: {:.4f}"\
145 | .format(cur_epoch,temp_iter, train_iterations, total_vae_loss.item(), dsc_loss.item()))
146 |
147 | return vae_model, disc_model, optim_vae, optim_disc, curr_vae_disc_iter
148 |
149 | def train_vae_disc(cfg, lSet, uSet, trainDataset, dataObj, debug=False):
150 |
151 | cur_device = torch.cuda.current_device()
152 | if cfg.DATASET.NAME == 'MNIST':
153 | vae_model = vm.VAE(cur_device, z_dim=cfg.VAAL.Z_DIM, nc=1)
154 | disc_model = vm.Discriminator(z_dim=cfg.VAAL.Z_DIM)
155 | else:
156 | vae_model = vm.VAE(cur_device, z_dim=cfg.VAAL.Z_DIM)
157 | disc_model = vm.Discriminator(z_dim=cfg.VAAL.Z_DIM)
158 |
159 |
160 | # vae_model = data_parallel_wrapper(vae_model, cur_device, cfg)
161 | # disc_model = data_parallel_wrapper(disc_model, cur_device, cfg)
162 |
163 | # if cfg.TRAIN.DATASET == "IMAGENET":
164 | # lSetLoader = imagenet_loader.construct_loader_no_aug(cfg, indices=lSet, isDistributed=False, isVaalSampling=True)
165 | # uSetLoader = imagenet_loader.construct_loader_no_aug(cfg, indices=uSet, isDistributed=False, isVaalSampling=True)
166 | # else:
167 | lSetLoader = dataObj.getIndexesDataLoader(indexes=lSet, batch_size=int(cfg.VAAL.VAE_BS) \
168 | ,data=trainDataset)
169 |
170 | uSetLoader = dataObj.getIndexesDataLoader(indexes=uSet, batch_size=int(cfg.VAAL.VAE_BS) \
171 | ,data=trainDataset)
172 |
173 | print("Initializing VAE and discriminator")
174 | logger.info("Initializing VAE and discriminator")
175 | optim_vae = torch.optim.Adam(vae_model.parameters(), lr=cfg.VAAL.VAE_LR)
176 | print(f"VAE Optimizer ==> {optim_vae}")
177 | logger.info(f"VAE Optimizer ==> {optim_vae}")
178 | optim_disc = torch.optim.Adam(disc_model.parameters(), lr=cfg.VAAL.DISC_LR)
179 | print(f"Disc Optimizer ==> {optim_disc}")
180 | logger.info(f"Disc Optimizer ==> {optim_disc}")
181 | print("==================================")
182 | logger.info("==================================\n")
183 |
184 | n_lu_points = len(lSet)+len(uSet)
185 | max_vae_disc_iters = int(n_lu_points/cfg.VAAL.VAE_BS)*cfg.VAAL.VAE_EPOCHS
186 | change_lr_iter = max_vae_disc_iters // 25
187 | curr_vae_disc_iter = 0
188 |
189 | vae_model = vae_model.cuda()
190 | disc_model = disc_model.cuda()
191 |
192 | for epoch in range(cfg.VAAL.VAE_EPOCHS):
193 | vae_model, disc_model, optim_vae, optim_disc, curr_vae_disc_iter = train_vae_disc_epoch(cfg, vae_model, disc_model, optim_vae, \
194 | optim_disc, lSetLoader, uSetLoader, epoch, n_lu_points, curr_vae_disc_iter, max_vae_disc_iters, change_lr_iter)
195 |
196 | #Save vae and disc models
197 | vae_sd = vae_model.module.state_dict() if cfg.NUM_GPUS > 1 else vae_model.state_dict()
198 | disc_sd = disc_model.module.state_dict() if cfg.NUM_GPUS > 1 else disc_model.state_dict()
199 | # Record the state
200 | vae_checkpoint = {
201 | 'epoch': cfg.VAAL.VAE_EPOCHS + 1,
202 | 'model_state': vae_sd,
203 | 'optimizer_state': optim_vae.state_dict(),
204 | 'cfg': cfg.dump()
205 | }
206 | disc_checkpoint = {
207 | 'epoch': cfg.VAAL.VAE_EPOCHS + 1,
208 | 'model_state': disc_sd,
209 | 'optimizer_state': optim_disc.state_dict(),
210 | 'cfg': cfg.dump()
211 | }
212 | # Write the checkpoint
213 | os.makedirs(cfg.EPISODE_DIR, exist_ok=True)
214 | vae_checkpoint_file = os.path.join(cfg.EPISODE_DIR, "vae.pyth")
215 | disc_checkpoint_file = os.path.join(cfg.EPISODE_DIR, "disc.pyth")
216 | torch.save(vae_checkpoint, vae_checkpoint_file)
217 | torch.save(disc_checkpoint, disc_checkpoint_file)
218 |
219 | if debug: print("Saved VAE model at {}".format(vae_checkpoint_file))
220 | if debug: print("Saved DISC model at {}".format(disc_checkpoint_file))
221 |
222 | return vae_model, disc_model
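
A minimal usage sketch (not part of this file; `vaal_select` and `budget` are hypothetical names): wiring train_vae_disc into a selection step, assuming CUDA is available and that getIndexesDataLoader yields image batches or (image, target) tuples.

import numpy as np
import torch

def vaal_select(cfg, lSet, uSet, trainDataset, dataObj, budget):
    vae_model, disc_model = train_vae_disc(cfg, lSet, uSet, trainDataset, dataObj)
    vae_model.eval()
    disc_model.eval()
    uSetLoader = dataObj.getIndexesDataLoader(
        indexes=uSet, batch_size=int(cfg.VAAL.VAE_BS), data=trainDataset)
    scores = []
    with torch.no_grad():
        for batch in uSetLoader:
            imgs = batch[0] if isinstance(batch, (list, tuple)) else batch
            _, _, mu, _ = vae_model(imgs.float().cuda())
            # The discriminator outputs P(labeled); low values mean the sample
            # "looks unlabeled", which VAAL treats as informative.
            scores.append(disc_model(mu).squeeze(1).cpu())
    order = np.argsort(torch.cat(scores).numpy())  # most unlabeled-like first
    activeSet = np.asarray(uSet)[order[:budget]]
    remainSet = np.asarray(uSet)[order[budget:]]
    return activeSet, remainSet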
--------------------------------------------------------------------------------
/pycls/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/pycls/core/__init__.py
--------------------------------------------------------------------------------
/pycls/core/builders.py:
--------------------------------------------------------------------------------
1 | # This file is modified from official pycls repository
2 |
3 | """Model and loss construction functions."""
4 |
5 | from pycls.core.net import SoftCrossEntropyLoss
6 | from pycls.models.resnet import *
7 | from pycls.models.vgg import *
8 | from pycls.models.alexnet import *
9 |
10 | import torch
11 |
12 |
13 | # Supported models
14 | _models = {
15 | # VGG style architectures
16 | 'vgg11': vgg11,
17 | 'vgg11_bn': vgg11_bn,
18 | 'vgg13': vgg13,
19 | 'vgg13_bn': vgg13_bn,
20 | 'vgg16': vgg16,
21 | 'vgg16_bn': vgg16_bn,
22 | 'vgg19': vgg19,
23 | 'vgg19_bn': vgg19_bn,
24 |
25 | # ResNet style architectures
26 | 'resnet18': resnet18,
27 | 'resnet34': resnet34,
28 | 'resnet50': resnet50,
29 | 'resnet101': resnet101,
30 | 'resnet152': resnet152,
31 | 'resnext50_32x4d': resnext50_32x4d,
32 | 'resnext101_32x8d': resnext101_32x8d,
33 | 'wide_resnet50_2': wide_resnet50_2,
34 | 'wide_resnet101_2': wide_resnet101_2,
35 |
36 | # AlexNet architecture
37 | 'alexnet': alexnet
38 | }
39 |
40 | # Supported loss functions
41 | _loss_funs = {"cross_entropy": SoftCrossEntropyLoss}
42 |
43 |
44 | def get_model(cfg):
45 | """Gets the model class specified in the config."""
46 | err_str = "Model type '{}' not supported"
47 | assert cfg.MODEL.TYPE in _models.keys(), err_str.format(cfg.MODEL.TYPE)
48 | return _models[cfg.MODEL.TYPE]
49 |
50 |
51 | def get_loss_fun(cfg):
52 | """Gets the loss function class specified in the config."""
53 | err_str = "Loss function type '{}' not supported"
54 | assert cfg.MODEL.LOSS_FUN in _loss_funs.keys(), err_str.format(cfg.MODEL.LOSS_FUN)
55 | return _loss_funs[cfg.MODEL.LOSS_FUN]
56 |
57 |
58 | def build_model(cfg):
59 | """Builds the model."""
60 | model = get_model(cfg)(num_classes=cfg.MODEL.NUM_CLASSES, use_dropout=True)
61 | if cfg.DATASET.NAME == 'MNIST':
62 | model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
63 |
64 | return model
65 |
66 |
67 | def build_loss_fun(cfg):
68 | """Build the loss function."""
69 | return get_loss_fun(cfg)()
70 |
71 |
72 | def register_model(name, ctor):
73 | """Registers a model dynamically."""
74 | _models[name] = ctor
75 |
76 |
77 | def register_loss_fun(name, ctor):
78 | """Registers a loss function dynamically."""
79 | _loss_funs[name] = ctor
80 |
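
A small sketch of the registry in use; `tiny_net` and the cfg values below are illustrative, not from this repo. Since build_model passes num_classes and use_dropout, a registered constructor must accept both.

import torch.nn as nn

def tiny_net(num_classes=10, use_dropout=False):
    # Trivial linear classifier over flattened 3x32x32 inputs.
    return nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, num_classes))

register_model('tiny_net', tiny_net)
# With cfg.MODEL.TYPE = 'tiny_net' and cfg.MODEL.NUM_CLASSES = 10,
# build_model(cfg) now dispatches through _models['tiny_net'].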
--------------------------------------------------------------------------------
/pycls/core/losses.py:
--------------------------------------------------------------------------------
1 | """Loss functions."""
2 |
3 | import torch.nn as nn
4 |
5 | from pycls.core.config import cfg
6 |
7 | # Supported loss functions
8 | _loss_funs = {
9 | 'cross_entropy': nn.CrossEntropyLoss,
10 | }
11 |
12 | def get_loss_fun():
13 | """Retrieves the loss function."""
14 | assert cfg.MODEL.LOSS_FUN in _loss_funs.keys(), \
15 | 'Loss function \'{}\' not supported'.format(cfg.MODEL.LOSS_FUN)
16 | return _loss_funs[cfg.MODEL.LOSS_FUN]().cuda()
17 |
18 |
19 | def register_loss_fun(name, ctor):
20 | """Registers a loss function dynamically."""
21 | _loss_funs[name] = ctor
22 |
--------------------------------------------------------------------------------
/pycls/core/net.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Functions for manipulating networks."""
9 |
10 | import itertools
11 |
12 | import numpy as np
13 | # import pycls.core.distributed as dist
14 | import torch
15 | from pycls.core.config import cfg
16 |
17 |
18 | def unwrap_model(model):
19 | """Remove the DistributedDataParallel wrapper if present."""
20 | wrapped = isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel)
21 | return model.module if wrapped else model
22 |
23 |
24 | # @torch.no_grad()
25 | # def compute_precise_bn_stats(model, loader):
26 | # """Computes precise BN stats on training data."""
27 | # # Compute the number of minibatches to use
28 | # num_iter = int(cfg.BN.NUM_SAMPLES_PRECISE / loader.batch_size / cfg.NUM_GPUS)
29 | # num_iter = min(num_iter, len(loader))
30 | # # Retrieve the BN layers
31 | # bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
32 | # # Initialize BN stats storage for computing mean(mean(batch)) and mean(var(batch))
33 | # running_means = [torch.zeros_like(bn.running_mean) for bn in bns]
34 | # running_vars = [torch.zeros_like(bn.running_var) for bn in bns]
35 | # # Remember momentum values
36 | # momentums = [bn.momentum for bn in bns]
37 | # # Set momentum to 1.0 to compute BN stats that only reflect the current batch
38 | # for bn in bns:
39 | # bn.momentum = 1.0
40 | # # Average the BN stats for each BN layer over the batches
41 | # for inputs, _labels in itertools.islice(loader, num_iter):
42 | # model(inputs.cuda())
43 | # for i, bn in enumerate(bns):
44 | # running_means[i] += bn.running_mean / num_iter
45 | # running_vars[i] += bn.running_var / num_iter
46 | # # Sync BN stats across GPUs (no reduction if 1 GPU used)
47 | # running_means = dist.scaled_all_reduce(running_means)
48 | # running_vars = dist.scaled_all_reduce(running_vars)
49 | # # Set BN stats and restore original momentum values
50 | # for i, bn in enumerate(bns):
51 | # bn.running_mean = running_means[i]
52 | # bn.running_var = running_vars[i]
53 | # bn.momentum = momentums[i]
54 |
55 |
56 | def complexity(model):
57 | """Compute model complexity (model can be model instance or model class)."""
58 | size = cfg.TRAIN.IM_SIZE
59 | cx = {"h": size, "w": size, "flops": 0, "params": 0, "acts": 0}
60 | cx = unwrap_model(model).complexity(cx)
61 | return {"flops": cx["flops"], "params": cx["params"], "acts": cx["acts"]}
62 |
63 |
64 | def smooth_one_hot_labels(labels):
65 | """Convert each label to a one-hot vector."""
66 | n_classes, label_smooth = cfg.MODEL.NUM_CLASSES, cfg.TRAIN.LABEL_SMOOTHING
67 | err_str = "Invalid input to smooth_one_hot_labels()"
68 | assert labels.ndim == 1 and labels.max() < n_classes, err_str
69 | shape = (labels.shape[0], n_classes)
70 | neg_val = label_smooth / n_classes
71 | pos_val = 1.0 - label_smooth + neg_val
72 | labels_one_hot = torch.full(shape, neg_val, dtype=torch.float, device=labels.device)
73 | labels_one_hot.scatter_(1, labels.long().view(-1, 1), pos_val)
74 | return labels_one_hot
75 |
76 |
77 | class SoftCrossEntropyLoss(torch.nn.Module):
78 | """SoftCrossEntropyLoss (useful for label smoothing and mixup).
79 | Identical to torch.nn.CrossEntropyLoss if used with one-hot labels."""
80 |
81 | def __init__(self):
82 | super(SoftCrossEntropyLoss, self).__init__()
83 |
84 | def forward(self, x, y):
85 | loss = -y * torch.nn.functional.log_softmax(x, -1)
86 | return torch.sum(loss) / x.shape[0]
87 |
88 |
89 | def mixup(inputs, labels):
90 | """Apply mixup to minibatch (https://arxiv.org/abs/1710.09412)."""
91 | alpha = cfg.TRAIN.MIXUP_ALPHA
92 | assert labels.shape[1] == cfg.MODEL.NUM_CLASSES, "mixup labels must be one-hot"
93 | if alpha > 0:
94 | m = np.random.beta(alpha, alpha)
95 | permutation = torch.randperm(labels.shape[0])
96 | inputs = m * inputs + (1.0 - m) * inputs[permutation, :]
97 | labels = m * labels + (1.0 - m) * labels[permutation, :]
98 | return inputs, labels, labels.argmax(1)
99 |
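
A quick sanity check, written against the class above, of the docstring's claim that SoftCrossEntropyLoss matches torch.nn.CrossEntropyLoss on one-hot labels:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)
labels = torch.randint(0, 10, (4,))
one_hot = F.one_hot(labels, num_classes=10).float()

soft = SoftCrossEntropyLoss()(logits, one_hot)
hard = torch.nn.CrossEntropyLoss()(logits, labels)
assert torch.allclose(soft, hard, atol=1e-6)  # identical up to float error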
--------------------------------------------------------------------------------
/pycls/core/optimizer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Optimizer."""
9 |
10 | import matplotlib.pyplot as plt
11 | import numpy as np
12 | import torch
13 |
14 |
15 | def construct_optimizer(cfg, model):
16 | """Constructs the optimizer.
17 |
18 | Note that the momentum update in PyTorch differs from the one in Caffe2.
19 | In particular,
20 |
21 | Caffe2:
22 | V := mu * V + lr * g
23 | p := p - V
24 |
25 | PyTorch:
26 | V := mu * V + g
27 | p := p - lr * V
28 |
29 | where V is the velocity, mu is the momentum factor, lr is the learning rate,
30 | g is the gradient and p are the parameters.
31 |
32 | Since V is defined independently of the learning rate in PyTorch,
33 | when the learning rate is changed there is no need to perform the
34 | momentum correction by scaling V (unlike in the Caffe2 case).
35 | """
36 | if cfg.BN.USE_CUSTOM_WEIGHT_DECAY:
37 | # Apply different weight decay to Batchnorm and non-batchnorm parameters.
38 | p_bn = [p for n, p in model.named_parameters() if "bn" in n]
39 | p_non_bn = [p for n, p in model.named_parameters() if "bn" not in n]
40 | optim_params = [
41 | {"params": p_bn, "weight_decay": cfg.BN.CUSTOM_WEIGHT_DECAY},
42 | {"params": p_non_bn, "weight_decay": cfg.OPTIM.WEIGHT_DECAY},
43 | ]
44 | else:
45 | optim_params = model.parameters()
46 |
47 | if cfg.OPTIM.TYPE == 'sgd':
48 | optimizer = torch.optim.SGD(
49 | optim_params,
50 | lr=cfg.OPTIM.BASE_LR,
51 | momentum=cfg.OPTIM.MOMENTUM,
52 | weight_decay=cfg.OPTIM.WEIGHT_DECAY,
53 | dampening=cfg.OPTIM.DAMPENING,
54 | nesterov=cfg.OPTIM.NESTEROV
55 | )
56 | elif cfg.OPTIM.TYPE == 'adam':
57 | optimizer = torch.optim.Adam(
58 | optim_params,
59 | lr=cfg.OPTIM.BASE_LR,
60 | weight_decay=cfg.OPTIM.WEIGHT_DECAY
61 | )
62 | else:
63 | raise NotImplementedError
64 |
65 | return optimizer
66 |
67 |
68 | def lr_fun_steps(cfg, cur_epoch):
69 | """Steps schedule (cfg.OPTIM.LR_POLICY = 'steps')."""
70 | ind = [i for i, s in enumerate(cfg.OPTIM.STEPS) if cur_epoch >= s][-1]
71 | return cfg.OPTIM.LR_MULT ** ind
72 |
73 |
74 | def lr_fun_exp(cfg, cur_epoch):
75 | """Exponential schedule (cfg.OPTIM.LR_POLICY = 'exp')."""
76 | return cfg.OPTIM.MIN_LR ** (cur_epoch / cfg.OPTIM.MAX_EPOCH)
77 |
78 |
79 | def lr_fun_cos(cfg, cur_epoch):
80 | """Cosine schedule (cfg.OPTIM.LR_POLICY = 'cos')."""
81 | lr = 0.5 * (1.0 + np.cos(np.pi * cur_epoch / cfg.OPTIM.MAX_EPOCH))
82 | return (1.0 - cfg.OPTIM.MIN_LR) * lr + cfg.OPTIM.MIN_LR
83 |
84 |
85 | def lr_fun_lin(cfg, cur_epoch):
86 | """Linear schedule (cfg.OPTIM.LR_POLICY = 'lin')."""
87 | lr = 1.0 - cur_epoch / cfg.OPTIM.MAX_EPOCH
88 | return (1.0 - cfg.OPTIM.MIN_LR) * lr + cfg.OPTIM.MIN_LR
89 |
90 |
91 | def lr_fun_none(cfg, cur_epoch):
92 | """No schedule (cfg.OPTIM.LR_POLICY = 'none')."""
93 | return 1
94 |
95 |
96 | def get_lr_fun(cfg):
97 | """Retrieves the specified lr policy function"""
98 | lr_fun = "lr_fun_" + cfg.OPTIM.LR_POLICY
99 | assert lr_fun in globals(), "Unknown LR policy: " + cfg.OPTIM.LR_POLICY
100 | err_str = "exp lr policy requires OPTIM.MIN_LR to be greater than 0."
101 | assert cfg.OPTIM.LR_POLICY != "exp" or cfg.OPTIM.MIN_LR > 0, err_str
102 | return globals()[lr_fun]
103 |
104 |
105 | def get_epoch_lr(cfg, cur_epoch):
106 | """Retrieves the lr for the given epoch according to the policy."""
107 | # Get lr and scale by by BASE_LR
108 | lr = get_lr_fun(cfg)(cfg, cur_epoch) * cfg.OPTIM.BASE_LR
109 | # Linear warmup
110 | if cur_epoch < cfg.OPTIM.WARMUP_EPOCHS and 'none' not in cfg.OPTIM.LR_POLICY:
111 | alpha = cur_epoch / cfg.OPTIM.WARMUP_EPOCHS
112 | warmup_factor = cfg.OPTIM.WARMUP_FACTOR * (1.0 - alpha) + alpha
113 | lr *= warmup_factor
114 | return lr
115 |
116 |
117 | def set_lr(optimizer, new_lr):
118 | """Sets the optimizer lr to the specified value."""
119 | for param_group in optimizer.param_groups:
120 | param_group["lr"] = new_lr
121 |
122 |
123 | def plot_lr_fun(cfg):
124 | """Visualizes lr function."""
125 | epochs = list(range(cfg.OPTIM.MAX_EPOCH))
126 | lrs = [get_epoch_lr(cfg, epoch) for epoch in epochs]
127 | plt.plot(epochs, lrs, ".-")
128 | plt.title("lr_policy: {}".format(cfg.OPTIM.LR_POLICY))
129 | plt.xlabel("epochs")
130 | plt.ylabel("learning rate")
131 | plt.ylim(bottom=0)
132 | plt.show()
133 |
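
To make the schedule arithmetic concrete, a self-contained worked example (the hyperparameter values are assumed, not taken from any config here) that mirrors get_epoch_lr for the cosine policy with linear warmup:

import numpy as np

def cos_lr(cur_epoch, base_lr=0.1, min_lr=0.0, max_epoch=100,
           warmup_epochs=5, warmup_factor=0.1):
    # Cosine decay scaled by BASE_LR, as in lr_fun_cos + get_epoch_lr above.
    lr = 0.5 * (1.0 + np.cos(np.pi * cur_epoch / max_epoch))
    lr = ((1.0 - min_lr) * lr + min_lr) * base_lr
    if cur_epoch < warmup_epochs:
        alpha = cur_epoch / warmup_epochs
        lr *= warmup_factor * (1.0 - alpha) + alpha
    return lr

# epoch 0: cosine term 1.0 -> 0.1, warmup scales by 0.1 -> 0.01
assert abs(cos_lr(0) - 0.01) < 1e-9
# epoch 50: cosine term 0.5 -> 0.05, past warmup so no scaling
assert abs(cos_lr(50) - 0.05) < 1e-9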
--------------------------------------------------------------------------------
/pycls/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/pycls/datasets/__init__.py
--------------------------------------------------------------------------------
/pycls/datasets/custom_datasets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | from PIL import Image
4 | import numpy as np
5 | class CIFAR10(torchvision.datasets.CIFAR10):
6 | def __init__(self, root, train, transform, test_transform, download=True):
7 | super(CIFAR10, self).__init__(root, train, transform=transform, download=download)
8 | self.test_transform = test_transform
9 | self.no_aug = False
10 |
11 | def __getitem__(self, index: int):
12 | """
13 | Args:
14 | index (int): Index
15 |
16 | Returns:
17 | tuple: (image, target) where target is index of the target class.
18 | """
19 | img, target = self.data[index], self.targets[index]
20 |
21 | # doing this so that it is consistent with all other datasets
22 | # to return a PIL Image
23 | img = Image.fromarray(img)
24 |
25 | if self.no_aug:
26 | if self.test_transform is not None:
27 | img = self.test_transform(img)
28 | else:
29 | if self.transform is not None:
30 | img = self.transform(img)
31 |
32 |
33 | return img, target
34 |
35 |
36 | class CIFAR100(torchvision.datasets.CIFAR100):
37 | def __init__(self, root, train, transform, test_transform, download=True):
38 | super(CIFAR100, self).__init__(root, train, transform=transform, download=download)
39 | self.test_transform = test_transform
40 | self.no_aug = False
41 |
42 | def __getitem__(self, index: int):
43 | """
44 | Args:
45 | index (int): Index
46 |
47 | Returns:
48 | tuple: (image, target) where target is index of the target class.
49 | """
50 | img, target = self.data[index], self.targets[index]
51 |
52 | # doing this so that it is consistent with all other datasets
53 | # to return a PIL Image
54 | img = Image.fromarray(img)
55 |
56 | if self.no_aug:
57 | if self.test_transform is not None:
58 | img = self.test_transform(img)
59 | else:
60 | if self.transform is not None:
61 | img = self.transform(img)
62 |
63 |
64 | return img, target
65 |
66 |
67 | class MNIST(torchvision.datasets.MNIST):
68 | def __init__(self, root, train, transform, test_transform, download=True):
69 | super(MNIST, self).__init__(root, train, transform=transform, download=download)
70 | self.test_transform = test_transform
71 | self.no_aug = False
72 |
73 | def __getitem__(self, index: int):
74 | """
75 | Args:
76 | index (int): Index
77 |
78 | Returns:
79 | tuple: (image, target) where target is index of the target class.
80 | """
81 | img, target = self.data[index], int(self.targets[index])
82 |
83 | # doing this so that it is consistent with all other datasets
84 | # to return a PIL Image
85 | img = Image.fromarray(img.numpy(), mode='L')
86 |
87 | if self.no_aug:
88 | if self.test_transform is not None:
89 | img = self.test_transform(img)
90 | else:
91 | if self.transform is not None:
92 | img = self.transform(img)
93 |
94 |
95 | return img, target
96 |
97 |
98 | class SVHN(torchvision.datasets.SVHN):
99 | def __init__(self, root, train, transform, test_transform, download=True):
100 | super(SVHN, self).__init__(root, split='train' if train else 'test', transform=transform, download=download)
101 | self.test_transform = test_transform
102 | self.no_aug = False
103 |
104 | def __getitem__(self, index: int):
105 | """
106 | Args:
107 | index (int): Index
108 |
109 | Returns:
110 | tuple: (image, target) where target is index of the target class.
111 | """
112 | img, target = self.data[index], int(self.labels[index])
113 |
114 | # doing this so that it is consistent with all other datasets
115 | # to return a PIL Image; SVHN stores images as CHW, so transpose to HWC
116 | img = Image.fromarray(np.transpose(img, (1, 2, 0)))
117 |
118 | if self.no_aug:
119 | if self.test_transform is not None:
120 | img = self.test_transform(img)
121 | else:
122 | if self.transform is not None:
123 | img = self.transform(img)
124 |
125 |
126 | return img, target
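
A usage sketch of the no_aug switch these wrappers share; the transforms below are illustrative. Flipping the flag lets one dataset object serve augmented batches for training and clean batches for evaluation or AL scoring:

from torchvision import transforms

train_tf = transforms.Compose([transforms.RandomCrop(32, padding=4),
                               transforms.RandomHorizontalFlip(),
                               transforms.ToTensor()])
test_tf = transforms.Compose([transforms.ToTensor()])

ds = CIFAR10('./data', train=True, transform=train_tf, test_transform=test_tf)
img_aug, _ = ds[0]    # augmented view via `transform`
ds.no_aug = True
img_clean, _ = ds[0]  # clean view via `test_transform`
ds.no_aug = False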
--------------------------------------------------------------------------------
/pycls/datasets/imbalanced_cifar.py:
--------------------------------------------------------------------------------
1 | """
2 | Credits: Kaihua Tang
3 | Source: https://github.com/KaihuaTang/Long-Tailed-Recognition.pytorch/
4 | """
5 |
6 | import torchvision
7 | import torchvision.transforms as transforms
8 | import numpy as np
9 | from PIL import Image
10 | import random
11 | from pycls.datasets.custom_datasets import CIFAR10, CIFAR100
12 |
13 | class IMBALANCECIFAR10(CIFAR10):
14 | cls_num = 10
15 | np.random.seed(1)
16 | def __init__(self, root, train, transform=None, test_transform=None, imbalance_ratio=0.02, imb_type='exp'):
17 | super(IMBALANCECIFAR10, self).__init__(root, train, transform=transform, test_transform=test_transform, download=True)
18 | self.train = train
19 | self.transform = transform
20 | if self.train:
21 | img_num_list = self.get_img_num_per_cls(self.cls_num, imb_type, imbalance_ratio)
22 | self.gen_imbalanced_data(img_num_list)
23 | phase = 'Train'
24 | else:
25 | phase = 'Test'
26 | self.labels = self.targets
27 |
28 | print("{} Mode: Contain {} images".format(phase, len(self.data)))
29 |
30 | def _get_class_dict(self):
31 | class_dict = dict()
32 | for i, anno in enumerate(self.get_annotations()):
33 | cat_id = anno["category_id"]
34 | if not cat_id in class_dict:
35 | class_dict[cat_id] = []
36 | class_dict[cat_id].append(i)
37 | return class_dict
38 |
39 |
40 | def get_img_num_per_cls(self, cls_num, imb_type, imb_factor):
41 | img_max = len(self.data) / cls_num
42 | img_num_per_cls = []
43 | if imb_type == 'exp':
44 | for cls_idx in range(cls_num):
45 | num = img_max * (imb_factor**(cls_idx / (cls_num - 1.0)))
46 | img_num_per_cls.append(int(num))
47 | elif imb_type == 'step':
48 | for cls_idx in range(cls_num // 2):
49 | img_num_per_cls.append(int(img_max))
50 | for cls_idx in range(cls_num // 2):
51 | img_num_per_cls.append(int(img_max * imb_factor))
52 | else:
53 | img_num_per_cls.extend([int(img_max)] * cls_num)
54 | return img_num_per_cls
55 |
56 | def gen_imbalanced_data(self, img_num_per_cls):
57 |
58 | new_data = []
59 | new_targets = []
60 | targets_np = np.array(self.targets, dtype=np.int64)
61 | classes = np.unique(targets_np)
62 |
63 | self.num_per_cls_dict = dict()
64 | for the_class, the_img_num in zip(classes, img_num_per_cls):
65 | self.num_per_cls_dict[the_class] = the_img_num
66 | idx = np.where(targets_np == the_class)[0]
67 | np.random.shuffle(idx)
68 | selec_idx = idx[:the_img_num]
69 | new_data.append(self.data[selec_idx, ...])
70 | new_targets.extend([the_class, ] * the_img_num)
71 | new_data = np.vstack(new_data)
72 | self.data = new_data
73 | self.targets = new_targets
74 |
75 | def __len__(self):
76 | return len(self.labels)
77 |
78 | def get_num_classes(self):
79 | return self.cls_num
80 |
81 | def get_annotations(self):
82 | annos = []
83 | for label in self.labels:
84 | annos.append({'category_id': int(label)})
85 | return annos
86 |
87 | def get_cls_num_list(self):
88 | cls_num_list = []
89 | for i in range(self.cls_num):
90 | cls_num_list.append(self.num_per_cls_dict[i])
91 | return cls_num_list
92 |
93 |
94 | class IMBALANCECIFAR100(CIFAR100):
95 | """`CIFAR100 `_ Dataset.
96 | This is a subclass of the `CIFAR10` Dataset.
97 | """
98 | cls_num = 100
99 | base_folder = 'cifar-100-python'
100 | url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
101 | filename = "cifar-100-python.tar.gz"
102 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
103 | train_list = [
104 | ['train', '16019d7e3df5f24257cddd939b257f8d'],
105 | ]
106 |
107 | test_list = [
108 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
109 | ]
110 | meta = {
111 | 'filename': 'meta',
112 | 'key': 'fine_label_names',
113 | 'md5': '7973b15100ade9c7d40fb424638fde48',
114 | }
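
A worked example (standalone sketch) of get_img_num_per_cls under the defaults above: CIFAR-10 has img_max = 50000/10 = 5000 images per class, and with imb_type='exp' and imbalance_ratio=0.02 class c keeps int(5000 * 0.02**(c/9)) images:

img_max, cls_num, imb_factor = 5000, 10, 0.02
counts = [int(img_max * imb_factor ** (c / (cls_num - 1.0))) for c in range(cls_num)]
print(counts)  # [5000, 3237, 2096, 1357, 878, 568, 368, 238, 154, 100]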
--------------------------------------------------------------------------------
/pycls/datasets/randaugment.py:
--------------------------------------------------------------------------------
1 | # This file is directly taken from the code implementation by Prateek Munjal et al., authors of the paper https://arxiv.org/abs/2002.09564
2 |
3 | # ----------------------------------------------------------
4 | #This file was modified from implementation of [Auto-Augment](https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py) to add Random Augmentations.
5 |
6 | """ AutoAugment ==[Modified to]==> RandAugment"""
7 |
8 |
9 | from PIL import Image, ImageEnhance, ImageOps
10 | import PIL.ImageDraw as ImageDraw
11 | import numpy as np
12 | import random
13 |
14 | class RandAugmentPolicy(object):
15 | """ Randomly choose one of the best 25 Sub-policies on CIFAR10.
16 | Example:
17 | >>> policy = RandAugmentPolicy()
18 | >>> transformed = policy(image)
19 | Example as a PyTorch Transform:
20 | >>> transform=transforms.Compose([
21 | >>> transforms.Resize(256),
22 | >>> RandAugmentPolicy(),
23 | >>> transforms.ToTensor()])
24 | """
25 | # The fill color was changed from (128, 128, 128) to (0, 0, 0)
26 | def __init__(self, fillcolor=(0,0,0), N=1, M=5):
27 | self.policies = ["invert","autocontrast","equalize","rotate","solarize","color", \
28 | "posterize","contrast","brightness","sharpness","shearX","shearY","translateX", \
29 | "translateY","cutout"]
30 | self.N = N
31 | self.M = M
32 |
33 | def __call__(self, img):
34 | chosen_policies = np.random.choice(self.policies, self.N)
35 | for policy in chosen_policies:
36 | subpolicy_obj = SubPolicy(operation=policy, magnitude=self.M)
37 | img = subpolicy_obj(img)
38 |
39 | return img
40 |
41 | def __repr__(self):
42 | return "RandAugment CIFAR10 Policy with Cutout"
43 |
44 | class SubPolicy(object):
45 | def __init__(self, operation, magnitude, fillcolor=(128, 128, 128), MAX_PARAM=10):
46 | ranges = {
47 | "shearX": np.linspace(0, 0.3, MAX_PARAM),
48 | "shearY": np.linspace(0, 0.3, MAX_PARAM),
49 | "translateX": np.linspace(0, 150 / 331, MAX_PARAM),
50 | "translateY": np.linspace(0, 150 / 331, MAX_PARAM),
51 | "rotate": np.linspace(0, 30, MAX_PARAM),
52 | "color": np.linspace(0.0, 0.9, MAX_PARAM),
53 | "posterize": np.round(np.linspace(8, 4, MAX_PARAM), 0).astype(np.int),
54 | "solarize": np.linspace(256, 0, MAX_PARAM),
55 | "contrast": np.linspace(0.0, 0.9, MAX_PARAM),
56 | "sharpness": np.linspace(0.0, 0.9, MAX_PARAM),
57 | "brightness": np.linspace(0.0, 0.9, MAX_PARAM),
58 | "autocontrast": [0] * MAX_PARAM,
59 | "equalize": [0] * MAX_PARAM,
60 | "invert": [0] * MAX_PARAM,
61 | "cutout": np.linspace(0.0,0.8, MAX_PARAM),
62 | }
63 |
64 | # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
65 | def rotate_with_fill(img, magnitude):
66 | rot = img.convert("RGBA").rotate(magnitude)
67 | #return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode)
68 | return Image.composite(rot, Image.new("RGBA", rot.size, (0,) * 4), rot).convert(img.mode)
69 |
70 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2]
71 | assert 0.0 <= v <= 0.8
72 | if v <= 0.:
73 | return img
74 |
75 | v = v * img.size[0]
76 |
77 | return CutoutAbs(img, v)
78 |
79 |
80 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2]
81 | # assert 0 <= v <= 20
82 | if v < 0:
83 | return img
84 | w, h = img.size
85 | x0 = np.random.uniform(w)
86 | y0 = np.random.uniform(h)
87 |
88 | x0 = int(max(0, x0 - v / 2.))
89 | y0 = int(max(0, y0 - v / 2.))
90 | x1 = min(w, x0 + v)
91 | y1 = min(h, y0 + v)
92 |
93 | xy = (x0, y0, x1, y1)
94 | #color = (125, 123, 114)
95 | color = (0, 0, 0)
96 | img = img.copy()
97 | ImageDraw.Draw(img).rectangle(xy, color)
98 | return img
99 |
100 | func = {
101 | "shearX": lambda img, magnitude: img.transform(
102 | img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
103 | Image.BICUBIC, fillcolor=fillcolor),
104 | "shearY": lambda img, magnitude: img.transform(
105 | img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
106 | Image.BICUBIC, fillcolor=fillcolor),
107 | "translateX": lambda img, magnitude: img.transform(
108 | img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
109 | fillcolor=fillcolor),
110 | "translateY": lambda img, magnitude: img.transform(
111 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
112 | fillcolor=fillcolor),
113 | "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
114 | # "rotate": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])),
115 | "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])),
116 | "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude),
117 | "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
118 | "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
119 | 1 + magnitude * random.choice([-1, 1])),
120 | "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
121 | 1 + magnitude * random.choice([-1, 1])),
122 | "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
123 | 1 + magnitude * random.choice([-1, 1])),
124 | "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
125 | "equalize": lambda img, magnitude: ImageOps.equalize(img),
126 | "invert": lambda img, magnitude: ImageOps.invert(img),
127 | "cutout": lambda img, magnitude: Cutout(img, magnitude)
128 | }
129 |
130 | self.operation = func[operation]
131 | self.magnitude = ranges[operation][magnitude]
132 |
133 |
134 | def __call__(self, img):
135 | img = self.operation(img, self.magnitude)
136 | return img
137 |
138 |
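
For reference, how (N, M) map onto concrete magnitudes, plus a typical pipeline (a sketch; this composition is not prescribed by the file): with M=5 each sampled op reads index 5 of its 10-point range, e.g. rotate -> 30*5/9 ≈ 16.7 degrees, shearX ≈ 0.167, solarize ≈ 113.8.

from torchvision import transforms

train_transform = transforms.Compose([
    RandAugmentPolicy(N=1, M=5),  # one randomly chosen op per image at magnitude 5
    transforms.ToTensor(),
])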
--------------------------------------------------------------------------------
/pycls/datasets/sampler.py:
--------------------------------------------------------------------------------
1 | # This file is directly taken from a code implementation shared with me by Prateek Munjal et al., authors of the paper https://arxiv.org/abs/2002.09564
2 | # GitHub: https://github.com/PrateekMunjal
3 | # ----------------------------------------------------------
4 |
5 | from torch.utils.data.sampler import Sampler
6 |
7 |
8 | class IndexedSequentialSampler(Sampler):
9 | r"""Samples elements sequentially, always in the same order.
10 |
11 | Arguments:
12 | data_idxes (Dataset indexes): dataset indexes to sample from
13 | """
14 |
15 | def __init__(self, data_idxes, isDebug=False):
16 | if isDebug: print("========= my custom sequential sampler =========")
17 | self.data_idxes = data_idxes
18 |
19 | def __iter__(self):
20 | return (self.data_idxes[i] for i in range(len(self.data_idxes)))
21 |
22 | def __len__(self):
23 | return len(self.data_idxes)
24 |
25 | # class IndexedDistributedSampler(Sampler):
26 | # """Sampler that restricts data loading to a particular index set of dataset.
27 |
28 | # It is especially useful in conjunction with
29 | # :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
30 | # process can pass a DistributedSampler instance as a DataLoader sampler,
31 | # and load a subset of the original dataset that is exclusive to it.
32 |
33 | # .. note::
34 | # Dataset is assumed to be of constant size.
35 |
36 | # Arguments:
37 | # dataset: Dataset used for sampling.
38 | # num_replicas (optional): Number of processes participating in
39 | # distributed training.
40 | # rank (optional): Rank of the current process within num_replicas.
41 | # """
42 |
43 | # def __init__(self, dataset, index_set, num_replicas=None, rank=None, allowRepeat=True):
44 | # if num_replicas is None:
45 | # if not dist.is_available():
46 | # raise RuntimeError("Requires distributed package to be available")
47 | # num_replicas = dist.get_world_size()
48 | # if rank is None:
49 | # if not dist.is_available():
50 | # raise RuntimeError("Requires distributed package to be available")
51 | # rank = dist.get_rank()
52 | # self.dataset = dataset
53 | # self.num_replicas = num_replicas
54 | # self.rank = rank
55 | # self.epoch = 0
56 | # self.index_set = index_set
57 | # self.allowRepeat = allowRepeat
58 | # if self.allowRepeat:
59 | # self.num_samples = int(math.ceil(len(self.index_set) * 1.0 / self.num_replicas))
60 | # self.total_size = self.num_samples * self.num_replicas
61 | # else:
62 | # self.num_samples = int(math.ceil((len(self.index_set)-self.rank) * 1.0 / self.num_replicas))
63 | # self.total_size = len(self.index_set)
64 |
65 | # def __iter__(self):
66 | # # deterministically shuffle based on epoch
67 | # g = torch.Generator()
68 | # g.manual_seed(self.epoch)
69 | # indices = torch.randperm(len(self.index_set), generator=g).tolist()
70 | # #To access valid indices
71 | # #indices = self.index_set[indices]
72 | # # add extra samples to make it evenly divisible
73 | # if self.allowRepeat:
74 | # indices += indices[:(self.total_size - len(indices))]
75 |
76 | # assert len(indices) == self.total_size
77 |
78 | # # subsample
79 | # indices = self.index_set[indices[self.rank:self.total_size:self.num_replicas]]
80 | # assert len(indices) == self.num_samples, "len(indices): {} and self.num_samples: {}"\
81 | # .format(len(indices), self.num_samples)
82 |
83 | # return iter(indices)
84 |
85 | # def __len__(self):
86 | # return self.num_samples
87 |
88 | # def set_epoch(self, epoch):
89 | # self.epoch = epoch
90 |
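
A minimal sketch of the sampler in use (the dataset and indices are illustrative): it serves a fixed index subset in deterministic order, which is exactly what pool scoring needs:

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(100).float().unsqueeze(1))
uSet = [3, 7, 42, 51]  # illustrative pool indices

loader = DataLoader(dataset, batch_size=2, sampler=IndexedSequentialSampler(uSet))
for (batch,) in loader:
    print(batch.squeeze(1).tolist())  # [3.0, 7.0] then [42.0, 51.0]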
--------------------------------------------------------------------------------
/pycls/datasets/simclr_augment.py:
--------------------------------------------------------------------------------
1 | # Modified from the source: https://github.com/sthalles/PyTorch-BYOL
2 | # Previous owner of this file: Thalles Silva
3 |
4 | from torchvision import transforms
5 | from .utils.gaussian_blur import GaussianBlur
6 |
7 | def get_simclr_ops(input_shape, s=1):
8 | # get a set of data augmentation transformations as described in the SimCLR paper.
9 | color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
10 | ops = [transforms.RandomHorizontalFlip(),
11 | transforms.RandomApply([color_jitter], p=0.8),
12 | transforms.RandomGrayscale(p=0.2),
13 | GaussianBlur(kernel_size=int(0.1 * input_shape)),]
14 | return ops
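
The ops above are meant to be spliced into a larger pipeline; a sketch for 32x32 inputs (so GaussianBlur gets kernel_size int(0.1*32) = 3), with the crop and tensor conversion assumed rather than prescribed here:

from torchvision import transforms

simclr_transform = transforms.Compose(
    [transforms.RandomResizedCrop(size=32)]
    + get_simclr_ops(input_shape=32)
    + [transforms.ToTensor()]
)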
15 |
16 |
17 |
--------------------------------------------------------------------------------
/pycls/datasets/tiny_imagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from PIL import Image
4 |
5 | import torch
6 | import torchvision.datasets as datasets
7 |
8 | from typing import Any
9 |
10 |
11 | class TinyImageNet(datasets.ImageFolder):
12 | """`Tiny ImageNet Classification Dataset.
13 |
14 | Args:
15 | root (string): Root directory of the ImageNet Dataset.
16 | split (string, optional): The dataset split, supports ``train``, or ``val``.
17 | transform (callable, optional): A function/transform that takes in an PIL image
18 | and returns a transformed version. E.g, ``transforms.RandomCrop``
19 | target_transform (callable, optional): A function/transform that takes in the
20 | target and transforms it.
21 | loader (callable, optional): A function to load an image given its path.
22 |
23 | Attributes:
24 | classes (list): List of the class name tuples.
25 | class_to_idx (dict): Dict with items (class_name, class_index).
26 | wnids (list): List of the WordNet IDs.
27 | wnid_to_idx (dict): Dict with items (wordnet_id, class_index).
28 | samples (list): List of (image path, class_index) tuples
29 | targets (list): The class_index value for each image in the dataset
30 | """
31 | def __init__(self, root: str, split: str = 'train', transform=None, test_transform=None, **kwargs: Any) -> None:
32 | self.root = root
33 | self.test_transform = test_transform
34 | self.no_aug = False
35 |
36 | assert self.check_root(), "Something is wrong with the Tiny ImageNet dataset path. Download the official dataset zip from http://cs231n.stanford.edu/tiny-imagenet-200.zip and unzip it inside {}.".format(self.root)
37 | self.split = datasets.utils.verify_str_arg(split, "split", ("train", "val"))
38 | wnid_to_classes = self.load_wnid_to_classes()
39 |
40 | super(TinyImageNet, self).__init__(self.split_folder, **kwargs)
41 | self.transform = transform
42 | self.wnids = self.classes
43 | self.wnid_to_idx = self.class_to_idx
44 | self.classes = [wnid_to_classes[wnid] for wnid in self.wnids]
45 | self.class_to_idx = {cls: idx
46 | for idx, clss in enumerate(self.classes)
47 | for cls in clss}
48 | # Tiny ImageNet's val directory is not structured like train's,
49 | # so a custom loading function is necessary
50 | if self.split == 'val':
51 | self.root = root
52 | self.imgs, self.targets = self.load_val_data()
53 | self.samples = [(self.imgs[idx], self.targets[idx]) for idx in range(len(self.imgs))]
54 | self.root = os.path.join(self.root, 'val')
55 |
56 |
57 |
58 | # split_folder is used for the 'super' call. Since the val directory is not
59 | # structured like train's, we reuse train's structure to discover the classes
60 | @property
61 | def split_folder(self) -> str:
62 | return os.path.join(self.root, 'train')
63 |
64 |
65 | def load_val_data(self):
66 | imgs, targets = [], []
67 | with open(os.path.join(self.root, 'val', 'val_annotations.txt'), 'r') as file:
68 | for line in file:
69 | if line.split()[1] in self.wnids:
70 | img_file, wnid = line.split('\t')[:2]
71 | imgs.append(os.path.join(self.root, 'val', 'images', img_file))
72 | targets.append(wnid)
73 | targets = np.array([self.wnid_to_idx[wnid] for wnid in targets])
74 | return imgs, targets
75 |
76 |
77 | def load_wnid_to_classes(self):
78 | wnid_to_classes = {}
79 | with open(os.path.join(self.root, 'words.txt'), 'r') as file:
80 | lines = file.readlines()
81 | lines = [x.split('\t') for x in lines]
82 | wnid_to_classes = {x[0]:x[1].strip() for x in lines}
83 | return wnid_to_classes
84 |
85 | def check_root(self):
86 | tinyim_set = ['words.txt', 'wnids.txt', 'train', 'val', 'test']
87 | for x in os.scandir(self.root):
88 | if x.name not in tinyim_set:
89 | return False
90 | return True
91 |
92 | def __getitem__(self, index: int):
93 | """
94 | Args:
95 | index (int): Index
96 |
97 | Returns:
98 | tuple: (sample, target) where target is class_index of the target class.
99 | """
100 | path, target = self.samples[index]
101 | sample = self.loader(path)
102 | if self.no_aug:
103 | if self.test_transform is not None:
104 | sample = self.test_transform(sample)
105 | else:
106 | if self.transform is not None:
107 | sample = self.transform(sample)
108 |
109 | return sample, target
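
An instantiation sketch; it assumes the official archive (http://cs231n.stanford.edu/tiny-imagenet-200.zip) is extracted under ./data, and the transforms are placeholders:

from torchvision import transforms

tf = transforms.Compose([transforms.ToTensor()])
train_set = TinyImageNet('./data/tiny-imagenet-200', split='train',
                         transform=tf, test_transform=tf)
val_set = TinyImageNet('./data/tiny-imagenet-200', split='val',
                       transform=tf, test_transform=tf)
img, target = train_set[0]  # (3, 64, 64) tensor, int class index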
--------------------------------------------------------------------------------
/pycls/datasets/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/pycls/datasets/utils/__init__.py
--------------------------------------------------------------------------------
/pycls/datasets/utils/gaussian_blur.py:
--------------------------------------------------------------------------------
1 | # Owner of this file: Thalles Silva
2 | # Source: https://github.com/sthalles/PyTorch-BYOL
3 | import torch
4 | from torchvision import transforms
5 | import torch.nn as nn
6 | import numpy as np
7 |
8 |
9 | class GaussianBlur(object):
10 | """blur a single image on CPU"""
11 |
12 | def __init__(self, kernel_size):
13 | radias = kernel_size // 2
14 | kernel_size = radias * 2 + 1
15 | self.blur_h = nn.Conv2d(3, 3, kernel_size=(kernel_size, 1),
16 | stride=1, padding=0, bias=False, groups=3)
17 | self.blur_v = nn.Conv2d(3, 3, kernel_size=(1, kernel_size),
18 | stride=1, padding=0, bias=False, groups=3)
19 | self.k = kernel_size
20 | self.r = radias
21 |
22 | self.blur = nn.Sequential(
23 | nn.ReflectionPad2d(radias),
24 | self.blur_h,
25 | self.blur_v
26 | )
27 |
28 | self.pil_to_tensor = transforms.ToTensor()
29 | self.tensor_to_pil = transforms.ToPILImage()
30 |
31 | def __call__(self, img):
32 | img = self.pil_to_tensor(img).unsqueeze(0)
33 |
34 | sigma = np.random.uniform(0.1, 2.0)
35 | x = np.arange(-self.r, self.r + 1)
36 | x = np.exp(-np.power(x, 2) / (2 * sigma * sigma))
37 | x = x / x.sum()
38 | x = torch.from_numpy(x).view(1, -1).repeat(3, 1)
39 |
40 | self.blur_h.weight.data.copy_(x.view(3, 1, self.k, 1))
41 | self.blur_v.weight.data.copy_(x.view(3, 1, 1, self.k))
42 |
43 | with torch.no_grad():
44 | img = self.blur(img)
45 | img = img.squeeze()
46 |
47 | img = self.tensor_to_pil(img)
48 |
49 | return img
--------------------------------------------------------------------------------
/pycls/datasets/utils/helpers.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import os
3 | def load_imageset(path, set_name):
4 | """
5 | Returns the image set `set_name` present at `path` as a list.
6 | Keyword arguments:
7 | path -- path to data folder
8 | set_name -- image set name - labeled or unlabeled.
9 | """
10 | reader = csv.reader(open(os.path.join(path, set_name+'.csv'), 'rt'))
11 | reader = [r[0] for r in reader]
12 | return reader
--------------------------------------------------------------------------------
/pycls/models/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Expose model constructors."""
9 |
10 | from pycls.models.resnet import *
11 | from pycls.models.vgg import *
12 |
--------------------------------------------------------------------------------
/pycls/models/alexnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
4 | from typing import Any
5 |
6 |
7 | __all__ = ['AlexNet', 'alexnet']
8 |
9 |
10 | model_urls = {
11 | 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
12 | }
13 |
14 |
15 | class AlexNet(nn.Module):
16 | '''
17 | AlexNet modified (features) for CIFAR10. Source: https://github.com/icpm/pytorch-cifar10/blob/master/models/AlexNet.py.
18 | '''
19 | def __init__(self, num_classes: int = 1000, use_dropout=False) -> None:
20 | super(AlexNet, self).__init__()
21 | self.use_dropout = use_dropout
22 | self.features = nn.Sequential(
23 | nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
24 | nn.ReLU(inplace=True),
25 | nn.MaxPool2d(kernel_size=2),
26 | nn.Conv2d(64, 192, kernel_size=3, padding=1),
27 | nn.ReLU(inplace=True),
28 | nn.MaxPool2d(kernel_size=2),
29 | nn.Conv2d(192, 384, kernel_size=3, padding=1),
30 | nn.ReLU(inplace=True),
31 | nn.Conv2d(384, 256, kernel_size=3, padding=1),
32 | nn.ReLU(inplace=True),
33 | nn.Conv2d(256, 256, kernel_size=3, padding=1),
34 | nn.ReLU(inplace=True),
35 | nn.MaxPool2d(kernel_size=2),
36 | )
37 | # self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
38 | self.fc_block = nn.Sequential(
39 | nn.Linear(256 * 2 * 2, 4096, bias=False),
40 | nn.BatchNorm1d(4096),
41 | nn.ReLU(inplace=True),
42 | nn.Linear(4096, 4096, bias=False),
43 | nn.BatchNorm1d(4096),
44 | nn.ReLU(inplace=True),
45 | )
46 | self.classifier = nn.Sequential(
47 | nn.Linear(4096, num_classes),
48 | )
49 | self.penultimate_active = False
50 | self.drop = nn.Dropout(p=0.5)
51 |
52 | def forward(self, x: torch.Tensor) -> torch.Tensor:
53 | x = self.features(x)
54 | # x = self.avgpool(x)
55 | z = torch.flatten(x, 1)
56 | if self.use_dropout:
57 | z = self.drop(z)  # dropout must act on z, which feeds fc_block; x was discarded
58 | z = self.fc_block(z)
59 | x = self.classifier(z)
60 | if self.penultimate_active:
61 | return z, x
62 | return x
63 |
64 |
65 | def alexnet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> AlexNet:
66 | r"""AlexNet model architecture from the
67 | `"One weird trick..." `_ paper.
68 | Args:
69 | pretrained (bool): If True, returns a model pre-trained on ImageNet
70 | progress (bool): If True, displays a progress bar of the download to stderr
71 | """
72 | model = AlexNet(**kwargs)
73 | if pretrained:
74 | state_dict = load_state_dict_from_url(model_urls['alexnet'],
75 | progress=progress)
76 | model.load_state_dict(state_dict)
77 | return model
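
A sketch of the penultimate_active flag in use, which is how embedding-based AL methods can reuse the task model as a feature extractor; shapes assume 32x32 CIFAR input:

import torch

model = alexnet(num_classes=10, use_dropout=True)
model.eval()
model.penultimate_active = True
with torch.no_grad():
    z, logits = model(torch.randn(8, 3, 32, 32))
# z: (8, 4096) penultimate features, logits: (8, 10) class scores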
--------------------------------------------------------------------------------
/pycls/models/vaal_model.py:
--------------------------------------------------------------------------------
1 | # This file is modified from a code implementation shared with me by Prateek Munjal et al., authors of the paper https://arxiv.org/abs/2002.09564
2 | # GitHub: https://github.com/PrateekMunjal
3 | # ----------------------------------------------------------
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import torch.nn.init as init
9 |
10 | import pycls.utils.logging as lu
11 |
12 | logger = lu.get_logger(__name__)
13 |
14 | class View(nn.Module):
15 | def __init__(self, size):
16 | super(View, self).__init__()
17 | self.size = size
18 |
19 | def forward(self, tensor):
20 | return tensor.view(self.size)
21 |
22 |
23 | class VAE(nn.Module):
24 | """Encoder-Decoder architecture for both WAE-MMD and WAE-GAN."""
25 | def __init__(self, device_id,z_dim=32, nc=3):
26 | super(VAE, self).__init__()
27 | print("============================")
28 | logger.info("============================")
29 | print(f"Constructing VAE MODEL with z_dim: {z_dim}")
30 | logger.info(f"Constructing VAE MODEL with z_dim: {z_dim}")
31 | print("============================")
32 | logger.info("============================")
33 | self.encode_shape = int(z_dim/16)
34 | if z_dim == 32:
35 | self.decode_shape = 4
36 | elif z_dim == 64:
37 | self.decode_shape = 8
38 | else:
39 | self.decode_shape = 4
40 | self.device_id = device_id
41 | self.z_dim = z_dim
42 | self.nc = nc
43 | self.encoder = nn.Sequential(
44 | nn.Conv2d(nc, 128, 4, 2, 1, bias=False), # B, 128, 32, 32 or B, 128, 64, 64
45 | nn.BatchNorm2d(128),
46 | nn.ReLU(True),
47 | nn.Conv2d(128, 256, 4, 2, 1, bias=False), # B, 256, 16, 16 or B, 256, 32, 32
48 | nn.BatchNorm2d(256),
49 | nn.ReLU(True),
50 | nn.Conv2d(256, 512, 4, 2, 1, bias=False), # B, 512, 8, 8 or B, 512, 16, 16
51 | nn.BatchNorm2d(512),
52 | nn.ReLU(True),
53 | nn.Conv2d(512, 1024, 4, 2, 1, bias=False), # B, 1024, 4, 4 or B, 1024, 8, 8
54 | nn.BatchNorm2d(1024),
55 | nn.ReLU(True),
56 | View((-1, 1024*self.encode_shape*self.encode_shape)), # B, 1024*4*4 or B, 1024, 4, 4
57 | )
58 |
59 | self.fc_mu = nn.Linear(1024*self.encode_shape*self.encode_shape, z_dim) # B, z_dim
60 | self.fc_logvar = nn.Linear(1024*self.encode_shape*self.encode_shape, z_dim) # B, z_dim
61 | self.decoder = nn.Sequential(
62 | nn.Linear(z_dim, 1024*self.decode_shape*self.decode_shape), # B, 1024*8*8
63 | View((-1, 1024, self.decode_shape, self.decode_shape)), # B, 1024, 8, 8
64 | nn.ConvTranspose2d(1024, 512, 4, 2, 1, bias=False), # B, 512, 16, 16
65 | nn.BatchNorm2d(512),
66 | nn.ReLU(True),
67 | nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False), # B, 256, 32, 32
68 | nn.BatchNorm2d(256),
69 | nn.ReLU(True),
70 | nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False), # B, 128, 64, 64
71 | nn.BatchNorm2d(128),
72 | nn.ReLU(True),
73 | nn.ConvTranspose2d(128, nc, 1), # B, nc, 64, 64
74 | )
75 | self.weight_init()
76 |
77 | def weight_init(self):
78 | for block in self._modules:
79 | try:
80 | for m in self._modules[block]:
81 | kaiming_init(m)
82 | except TypeError:  # single layers (fc_mu, fc_logvar) are not iterable
83 | kaiming_init(self._modules[block])  # the original passed the key string, skipping init
84 |
85 | def forward(self, x):
86 | z = self._encode(x)
87 | mu, logvar = self.fc_mu(z), self.fc_logvar(z)
88 | z = self.reparameterize(mu, logvar)
89 | x_recon = self._decode(z)
90 |
91 | return x_recon, z, mu, logvar
92 |
93 | def reparameterize(self, mu, logvar):
94 | stds = (0.5 * logvar).exp()
95 | epsilon = torch.randn(*mu.size())
96 | # assumes CUDA is available; the original "if mu.is_cuda" guard was commented out
97 | stds, epsilon = stds.cuda(), epsilon.cuda()
98 | mu = mu.cuda()
99 | latents = epsilon * stds + mu
100 | return latents
101 |
102 | def _encode(self, x):
103 | return self.encoder(x)
104 |
105 | def _decode(self, z):
106 | return self.decoder(z)
107 |
108 | class clf_Discriminator(nn.Module):
109 | """
110 | Model that avoids the need to learn a separate task learner by
111 | combining the discriminator and task classifier.
112 | """
113 | def __init__(self, z_dim=10, n_classes=10):
114 | super(clf_Discriminator, self).__init__()
115 | self.z_dim = z_dim
116 | self.n_classes = n_classes
117 |
118 | self.net = nn.Sequential(
119 | nn.Linear(z_dim, 512),
120 | nn.ReLU(True),
121 | nn.Linear(512, 512),
122 | nn.ReLU(True),
123 | nn.Linear(512, 512),
124 | nn.ReLU(True),
125 | nn.Linear(512, 512),
126 | nn.ReLU(True)
127 | )
128 |
129 | self.disc_out = nn.Sequential(
130 | nn.Linear(512, 1),
131 | nn.Sigmoid()
132 | )
133 |
134 | self.clf_out = nn.Sequential(
135 | nn.Linear(512, self.n_classes),
136 | )
137 |
138 | self.weight_init()
139 |
140 | def weight_init(self):
141 | for block in self._modules:
142 | for m in self._modules[block]:
143 | kaiming_init(m)
144 |
145 | def forward(self, z):
146 | z = self.net(z)
147 | disc_out = self.disc_out(z)
148 | clf_out = self.clf_out(z)
149 | return disc_out, clf_out
150 |
151 |
152 | class Discriminator(nn.Module):
153 | """Adversary architecture(Discriminator) for WAE-GAN."""
154 | def __init__(self, z_dim=10):
155 | super(Discriminator, self).__init__()
156 | self.z_dim = z_dim
157 | self.penultimate_active = False
158 | self.net = nn.Sequential(
159 | nn.Linear(z_dim, 512),
160 | nn.ReLU(True),
161 | nn.Linear(512, 512),
162 | nn.ReLU(True),
163 | nn.Linear(512, 512),
164 | nn.ReLU(True),
165 | nn.Linear(512, 512),
166 | nn.ReLU(True)
167 | )
168 |
169 | self.out = nn.Sequential(
170 | nn.Linear(512, 1),
171 | nn.Sigmoid()
172 | )
173 |
174 | self.weight_init()
175 |
176 | def weight_init(self):
177 | for block in self._modules:
178 | for m in self._modules[block]:
179 | kaiming_init(m)
180 |
181 | def forward(self, z):
182 | z = self.net(z)
183 | if self.penultimate_active:
184 | return z, self.out(z)
185 | return self.out(z)
186 |
187 | class WGAN_Discriminator(nn.Module):
188 | """Adversary architecture(Discriminator) for WAE-GAN."""
189 | def __init__(self, z_dim=10):
190 | super(WGAN_Discriminator, self).__init__()
191 | self.z_dim = z_dim
192 | self.penultimate_active = False
193 | self.net = nn.Sequential(
194 | nn.Linear(z_dim, 512),
195 | nn.ReLU(True),
196 | nn.Linear(512, 512),
197 | nn.ReLU(True),
198 | nn.Linear(512, 512),
199 | nn.ReLU(True),
200 | nn.Linear(512, 512),
201 | nn.ReLU(True)
202 | )
203 |
204 | self.out = nn.Sequential(
205 | nn.Linear(512, 1),
206 | #nn.Sigmoid()
207 | )
208 |
209 | self.weight_init()
210 |
211 | def weight_init(self):
212 | for block in self._modules:
213 | for m in self._modules[block]:
214 | kaiming_init(m)
215 |
216 | def forward(self, z):
217 | z = self.net(z)
218 | if self.penultimate_active:
219 | return z, self.out(z)
220 | return self.out(z)
221 |
222 |
223 | def kaiming_init(m):
224 | if isinstance(m, (nn.Linear, nn.Conv2d)):
225 | init.kaiming_normal_(m.weight)
226 | if m.bias is not None:
227 | m.bias.data.fill_(0)
228 | elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
229 | m.weight.data.fill_(1)
230 | if m.bias is not None:
231 | m.bias.data.fill_(0)
232 |
233 |
234 | def normal_init(m, mean, std):
235 | if isinstance(m, (nn.Linear, nn.Conv2d)):
236 | m.weight.data.normal_(mean, std)
237 | if m.bias is not None:
238 | m.bias.data.zero_()
239 | elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
240 | m.weight.data.fill_(1)
241 | if m.bias is not None:
242 | m.bias.data.zero_()
243 |
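
A shape-check sketch of the VAE round trip, plus a standard beta-VAE objective of the kind vaal_util's vae_loss computes (reconstruction plus beta-weighted KL); the exact reductions here are illustrative, and a GPU is required because reparameterize calls .cuda() unconditionally:

import torch
import torch.nn.functional as F

def beta_vae_loss(x, recon, mu, logvar, beta):
    mse = F.mse_loss(recon, x)
    kld = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return mse + beta * kld

vae = VAE(device_id=0, z_dim=32, nc=3).cuda()
x = torch.randn(4, 3, 32, 32).cuda()
recon, z, mu, logvar = vae(x)          # encode_shape=2, decode_shape=4 for z_dim=32
loss = beta_vae_loss(x, recon.view_as(x), mu, logvar, beta=1.0)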
--------------------------------------------------------------------------------
/pycls/models/vgg.py:
--------------------------------------------------------------------------------
1 | # Original Source: https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py
2 |
3 | # This file is modified to meet the implementation by Prateek Munjal et al., authors of the paper https://arxiv.org/abs/2002.09564
4 | # GitHub: https://github.com/PrateekMunjal
5 | # ----------------------------------------------------------
6 |
7 |
8 | import torch
9 | import torch.nn as nn
10 | # from .utils import load_state_dict_from_url
11 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
12 | from typing import Union, List, Dict, Any, cast
13 |
14 | import pycls.utils.logging as lu
15 | logger = lu.get_logger(__name__)
16 |
17 | __all__ = [
18 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
19 | 'vgg19_bn', 'vgg19',
20 | ]
21 |
22 |
23 | model_urls = {
24 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
25 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
26 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
27 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
28 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
29 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
30 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
31 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
32 | }
33 |
34 |
35 | class VGG(nn.Module):
36 |
37 | def __init__(
38 | self,
39 | features: nn.Module,
40 | num_classes: int = 1000,
41 | init_weights: bool = True,
42 | use_dropout = False
43 | ) -> None:
44 | super(VGG, self).__init__()
45 | self.penultimate_active = False
46 | if num_classes == 1000:  # self.num_classes was read before being set; use the argument
47 | logger.warning("This open source implementation is only suitable for small datasets like CIFAR. \
48 | For ImageNet we recommend using ResNet-based models.")
49 | self.penultimate_dim = 4096
50 | else:
51 | self.penultimate_dim = 512
52 | self.features = features
53 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
54 | self.penultimate_act = nn.Sequential(
55 | nn.Linear(512 * 7 * 7, 4096),
56 | nn.ReLU(True),
57 | nn.Dropout(),
58 | nn.Linear(4096, self.penultimate_dim),
59 | nn.ReLU(True),
60 | #nn.Dropout(),
61 | )
62 | self.classifier = nn.Sequential(
63 | nn.Linear(self.penultimate_dim, num_classes)
64 | )
65 |
66 | # Describe model with source code link
67 | self.description = "VGG16 model loaded from VAAL source code with penultimate dim as {}".format(self.penultimate_dim)
68 | self.source_link = "https://github.com/sinhasam/vaal/blob/master/vgg.py"
69 |
70 | if init_weights:
71 | self._initialize_weights()
72 |
73 | def forward(self, x: torch.Tensor) -> torch.Tensor:
74 | x = self.features(x)
75 | x = self.avgpool(x)
76 | x = torch.flatten(x, 1)
77 | z = self.penultimate_act(x)
78 | x = self.classifier(z)
79 | if self.penultimate_active:
80 | return z, x
81 | return x
82 |
83 | def _initialize_weights(self) -> None:
84 | for m in self.modules():
85 | if isinstance(m, nn.Conv2d):
86 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
87 | if m.bias is not None:
88 | nn.init.constant_(m.bias, 0)
89 | elif isinstance(m, nn.BatchNorm2d):
90 | nn.init.constant_(m.weight, 1)
91 | nn.init.constant_(m.bias, 0)
92 | elif isinstance(m, nn.Linear):
93 | nn.init.normal_(m.weight, 0, 0.01)
94 | nn.init.constant_(m.bias, 0)
95 |
96 |
97 | def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential:
98 | layers: List[nn.Module] = []
99 | in_channels = 3
100 | for v in cfg:
101 | if v == 'M':
102 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
103 | else:
104 | v = cast(int, v)
105 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
106 | if batch_norm:
107 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
108 | else:
109 | layers += [conv2d, nn.ReLU(inplace=True)]
110 | in_channels = v
111 | return nn.Sequential(*layers)
112 |
113 |
114 | cfgs: Dict[str, List[Union[str, int]]] = {
115 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
116 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
117 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
118 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
119 | }
120 |
121 |
122 | def _vgg(arch: str, cfg: str, batch_norm: bool, pretrained: bool, progress: bool, **kwargs: Any) -> VGG:
123 | if pretrained:
124 | kwargs['init_weights'] = False
125 | model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
126 | if pretrained:
127 | state_dict = load_state_dict_from_url(model_urls[arch],
128 | progress=progress)
129 | model.load_state_dict(state_dict)
130 | return model
131 |
132 |
133 | def vgg11(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
134 | r"""VGG 11-layer model (configuration "A") from
135 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._
136 | Args:
137 | pretrained (bool): If True, returns a model pre-trained on ImageNet
138 | progress (bool): If True, displays a progress bar of the download to stderr
139 | """
140 | return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs)
141 |
142 |
143 | def vgg11_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
144 | r"""VGG 11-layer model (configuration "A") with batch normalization
145 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._
146 | Args:
147 | pretrained (bool): If True, returns a model pre-trained on ImageNet
148 | progress (bool): If True, displays a progress bar of the download to stderr
149 | """
150 | return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs)
151 |
152 |
153 | def vgg13(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
154 | r"""VGG 13-layer model (configuration "B")
155 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._
156 | Args:
157 | pretrained (bool): If True, returns a model pre-trained on ImageNet
158 | progress (bool): If True, displays a progress bar of the download to stderr
159 | """
160 | return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs)
161 |
162 |
163 | def vgg13_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
164 | r"""VGG 13-layer model (configuration "B") with batch normalization
165 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._
166 | Args:
167 | pretrained (bool): If True, returns a model pre-trained on ImageNet
168 | progress (bool): If True, displays a progress bar of the download to stderr
169 | """
170 | return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs)
171 |
172 |
173 | def vgg16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
174 | r"""VGG 16-layer model (configuration "D")
175 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._
176 | Args:
177 | pretrained (bool): If True, returns a model pre-trained on ImageNet
178 | progress (bool): If True, displays a progress bar of the download to stderr
179 | """
180 | return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs)
181 |
182 |
183 | def vgg16_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
184 | r"""VGG 16-layer model (configuration "D") with batch normalization
185 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._
186 | Args:
187 | pretrained (bool): If True, returns a model pre-trained on ImageNet
188 | progress (bool): If True, displays a progress bar of the download to stderr
189 | """
190 | return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs)
191 |
192 |
193 | def vgg19(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
194 | r"""VGG 19-layer model (configuration "E")
195 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._
196 | Args:
197 | pretrained (bool): If True, returns a model pre-trained on ImageNet
198 | progress (bool): If True, displays a progress bar of the download to stderr
199 | """
200 | return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs)
201 |
202 |
203 | def vgg19_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
204 | r"""VGG 19-layer model (configuration 'E') with batch normalization
205 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._
206 | Args:
207 | pretrained (bool): If True, returns a model pre-trained on ImageNet
208 | progress (bool): If True, displays a progress bar of the download to stderr
209 | """
210 | return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs)
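
A minimal usage sketch for the module above (not part of the source file): with the penultimate flag enabled, forward() returns both the flattened pre-classifier embedding and the logits, so an AL method can reuse one forward pass for scoring and prediction. It assumes the num_classes keyword is forwarded to the VGG constructor (as in torchvision) and that penultimate_active can be toggled on the instance.

import torch
from pycls.models.vgg import vgg16

model = vgg16(num_classes=10)    # assumption: num_classes is forwarded to VGG.__init__
model.penultimate_active = True  # ask forward() to also return the embedding z
images = torch.randn(4, 3, 32, 32)
z, logits = model(images)        # z: flattened pre-classifier features, logits: class scores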
--------------------------------------------------------------------------------
/pycls/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/pycls/utils/__init__.py
--------------------------------------------------------------------------------
/pycls/utils/benchmark.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Benchmarking functions."""
9 |
10 | import pycls.utils.logging as logging
11 | import pycls.core.net as net
12 | import pycls.datasets.loader as loader
13 | import torch
14 | import torch.cuda.amp as amp
15 | from pycls.core.config import cfg
16 | from pycls.utils.timer import Timer
17 |
18 |
19 | logger = logging.get_logger(__name__)
20 |
21 |
22 | @torch.no_grad()
23 | def compute_time_eval(model):
24 | """Computes precise model forward test time using dummy data."""
25 | # Use eval mode
26 | model.eval()
27 | # Generate a dummy mini-batch and copy data to GPU
28 | im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TEST.BATCH_SIZE / cfg.NUM_GPUS)
29 | inputs = torch.zeros(batch_size, 3, im_size, im_size).cuda(non_blocking=False)
30 | # Compute precise forward pass time
31 | timer = Timer()
32 | total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
33 | for cur_iter in range(total_iter):
34 | # Reset the timers after the warmup phase
35 | if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
36 | timer.reset()
37 | # Forward
38 | timer.tic()
39 | model(inputs)
40 | torch.cuda.synchronize()
41 | timer.toc()
42 | return timer.average_time
43 |
44 |
45 | def compute_time_train(model, loss_fun):
46 | """Computes precise model forward + backward time using dummy data."""
47 | # Use train mode
48 | model.train()
49 | # Generate a dummy mini-batch and copy data to GPU
50 | im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS)
51 | inputs = torch.rand(batch_size, 3, im_size, im_size).cuda(non_blocking=False)
52 | labels = torch.zeros(batch_size, dtype=torch.int64).cuda(non_blocking=False)
53 | labels_one_hot = net.smooth_one_hot_labels(labels)
54 | # Cache BatchNorm2D running stats
55 | bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
56 | bn_stats = [[bn.running_mean.clone(), bn.running_var.clone()] for bn in bns]
57 | # Create a GradScaler for mixed precision training
58 | scaler = amp.GradScaler(enabled=cfg.TRAIN.MIXED_PRECISION)
59 | # Compute precise forward backward pass time
60 | fw_timer, bw_timer = Timer(), Timer()
61 | total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
62 | for cur_iter in range(total_iter):
63 | # Reset the timers after the warmup phase
64 | if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
65 | fw_timer.reset()
66 | bw_timer.reset()
67 | # Forward
68 | fw_timer.tic()
69 | with amp.autocast(enabled=cfg.TRAIN.MIXED_PRECISION):
70 | preds = model(inputs)
71 | loss = loss_fun(preds, labels_one_hot)
72 | torch.cuda.synchronize()
73 | fw_timer.toc()
74 | # Backward
75 | bw_timer.tic()
76 | scaler.scale(loss).backward()
77 | torch.cuda.synchronize()
78 | bw_timer.toc()
79 | # Restore BatchNorm2D running stats
80 | for bn, (mean, var) in zip(bns, bn_stats):
81 | bn.running_mean, bn.running_var = mean, var
82 | return fw_timer.average_time, bw_timer.average_time
83 |
84 |
85 | def compute_time_loader(data_loader):
86 | """Computes loader time."""
87 | timer = Timer()
88 | loader.shuffle(data_loader, 0)
89 | data_loader_iterator = iter(data_loader)
90 | total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
91 | total_iter = min(total_iter, len(data_loader))
92 | for cur_iter in range(total_iter):
93 | if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
94 | timer.reset()
95 | timer.tic()
96 | next(data_loader_iterator)
97 | timer.toc()
98 | return timer.average_time
99 |
100 |
101 | def compute_time_model(model, loss_fun):
102 | """Times model."""
103 | logger.info("Computing model timings only...")
104 | # Compute timings
105 | test_fw_time = compute_time_eval(model)
106 | train_fw_time, train_bw_time = compute_time_train(model, loss_fun)
107 | train_fw_bw_time = train_fw_time + train_bw_time
108 | # Output iter timing
109 | iter_times = {
110 | "test_fw_time": test_fw_time,
111 | "train_fw_time": train_fw_time,
112 | "train_bw_time": train_bw_time,
113 | "train_fw_bw_time": train_fw_bw_time,
114 | }
115 | logger.info(logging.dump_log_data(iter_times, "iter_times"))
116 |
117 |
118 | def compute_time_full(model, loss_fun, train_loader, test_loader):
119 | """Times model and data loader."""
120 | logger.info("Computing model and loader timings...")
121 | # Compute timings
122 | test_fw_time = compute_time_eval(model)
123 | train_fw_time, train_bw_time = compute_time_train(model, loss_fun)
124 | train_fw_bw_time = train_fw_time + train_bw_time
125 | train_loader_time = compute_time_loader(train_loader)
126 | # Output iter timing
127 | iter_times = {
128 | "test_fw_time": test_fw_time,
129 | "train_fw_time": train_fw_time,
130 | "train_bw_time": train_bw_time,
131 | "train_fw_bw_time": train_fw_bw_time,
132 | "train_loader_time": train_loader_time,
133 | }
134 | logger.info(logging.dump_log_data(iter_times, "iter_times"))
135 | # Output epoch timing
136 | epoch_times = {
137 | "test_fw_time": test_fw_time * len(test_loader),
138 | "train_fw_time": train_fw_time * len(train_loader),
139 | "train_bw_time": train_bw_time * len(train_loader),
140 | "train_fw_bw_time": train_fw_bw_time * len(train_loader),
141 | "train_loader_time": train_loader_time * len(train_loader),
142 | }
143 | logger.info(logging.dump_log_data(epoch_times, "epoch_times"))
144 | # Compute data loader overhead (assuming DATA_LOADER.NUM_WORKERS>1)
145 | overhead = max(0, train_loader_time - train_fw_bw_time) / train_fw_bw_time
146 | logger.info("Overhead of data loader is {:.2f}%".format(overhead * 100))
147 |
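
The three timing functions above share one pattern: run a few warmup iterations, reset the timer, then bracket each measured iteration with tic/toc plus torch.cuda.synchronize() so asynchronous GPU work is counted. A standalone sketch of that pattern, assuming a CUDA device; the iteration counts are illustrative.

import torch
from pycls.utils.timer import Timer

@torch.no_grad()
def time_forward(model, inputs, warmup_iter=3, num_iter=10):
    model.eval()
    timer = Timer()
    for cur_iter in range(warmup_iter + num_iter):
        if cur_iter == warmup_iter:
            timer.reset()  # discard warmup iterations (cudnn autotuning, cache effects)
        timer.tic()
        model(inputs)
        torch.cuda.synchronize()  # wait for queued GPU kernels before stopping the clock
        timer.toc()
    return timer.average_time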
--------------------------------------------------------------------------------
/pycls/utils/checkpoint.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Functions that handle saving and loading of checkpoints."""
9 |
10 | import os
11 | from shutil import copyfile
12 |
13 | # import pycls.core.distributed as dist
14 | import torch
15 | from pycls.core.config import cfg
16 | from pycls.core.net import unwrap_model
17 |
18 |
19 | # Common prefix for checkpoint file names
20 | _NAME_PREFIX = "model_epoch_"
21 |
22 | # Checkpoints directory name
23 | _DIR_NAME = "checkpoints"
24 |
25 |
26 | def get_checkpoint_dir(episode_dir):
27 | """Retrieves the location for storing checkpoints."""
28 | return os.path.join(episode_dir, _DIR_NAME)
29 |
30 |
31 | def get_checkpoint(epoch, episode_dir):
32 | """Retrieves the path to a checkpoint file."""
33 | name = "{}{:04d}.pyth".format(_NAME_PREFIX, epoch)
34 | # return os.path.join(get_checkpoint_dir(), name)
35 | return os.path.join(episode_dir, name)
36 |
37 |
38 | def get_checkpoint_best(episode_dir):
39 | """Retrieves the path to the best checkpoint file."""
40 | return os.path.join(episode_dir, "model.pyth")
41 |
42 |
43 | def get_last_checkpoint(episode_dir):
44 | """Retrieves the most recent checkpoint (highest epoch number)."""
45 | checkpoint_dir = get_checkpoint_dir(episode_dir)
46 | checkpoints = [f for f in os.listdir(checkpoint_dir) if _NAME_PREFIX in f]
47 | last_checkpoint_name = sorted(checkpoints)[-1]
48 | return os.path.join(checkpoint_dir, last_checkpoint_name)
49 |
50 |
51 | def has_checkpoint(episode_dir):
52 | """Determines if there are checkpoints available."""
53 | checkpoint_dir = get_checkpoint_dir(episode_dir)
54 | if not os.path.exists(checkpoint_dir):
55 | return False
56 | return any(_NAME_PREFIX in f for f in os.listdir(checkpoint_dir))
57 |
58 |
59 | def save_checkpoint(info, model_state, optimizer_state, epoch, cfg):
60 | """Saves a checkpoint; the info tag is prefixed to the file name."""
61 |
62 | # Save checkpoints only from the master process
63 | # if not dist.is_master_proc():
64 | # return
65 | # Ensure that the checkpoint dir exists
66 | os.makedirs(cfg.EPISODE_DIR, exist_ok=True)
67 |
68 | # Record the state
69 | checkpoint = {
70 | "epoch": epoch,
71 | "model_state": model_state,
72 | "optimizer_state": optimizer_state,
73 | "cfg": cfg.dump(),
74 | }
75 | # Temporarily prefix checkpoint file names with the given info tag
76 | global _NAME_PREFIX
77 | _NAME_PREFIX = info + '_' + _NAME_PREFIX
78 |
79 | # Write the checkpoint
80 | checkpoint_file = get_checkpoint(epoch, cfg.EPISODE_DIR)
81 | torch.save(checkpoint, checkpoint_file)
82 |
83 | _NAME_PREFIX = "model_epoch_"  # restore the default prefix
84 | return checkpoint_file
85 |
86 |
87 | def load_checkpoint(checkpoint_file, model, optimizer=None):
88 | """Loads the checkpoint from the given file."""
89 | err_str = "Checkpoint '{}' not found"
90 | assert os.path.exists(checkpoint_file), err_str.format(checkpoint_file)
91 | checkpoint = torch.load(checkpoint_file, map_location="cpu")
92 | unwrap_model(model).load_state_dict(checkpoint["model_state"])
93 | if optimizer: optimizer.load_state_dict(checkpoint["optimizer_state"])
94 | return model
95 |
96 |
97 | def delete_checkpoints(checkpoint_dir=None, keep="all"):
98 | """Deletes unneeded checkpoints, keep can be "all", "last", or "none"."""
99 | assert keep in ["all", "last", "none"], "Invalid keep setting: {}".format(keep)
100 | checkpoint_dir = checkpoint_dir if checkpoint_dir else get_checkpoint_dir()
101 | if keep == "all" or not os.path.exists(checkpoint_dir):
102 | return 0
103 | checkpoints = [f for f in os.listdir(checkpoint_dir) if _NAME_PREFIX in f]
104 | checkpoints = sorted(checkpoints)[:-1] if keep == "last" else checkpoints
105 | for checkpoint in checkpoints: os.remove(os.path.join(checkpoint_dir, checkpoint))
106 | return len(checkpoints)
107 |
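
A hypothetical round trip with the helpers above. It assumes cfg.EPISODE_DIR has been set by the active-learning loop; the info tag ('best' here) is arbitrary and is what gets prefixed to the checkpoint file name.

import torch.nn as nn
import torch.optim as optim
from pycls.core.config import cfg
import pycls.utils.checkpoint as cu

model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)
path = cu.save_checkpoint('best', model.state_dict(), optimizer.state_dict(), epoch=1, cfg=cfg)
model = cu.load_checkpoint(path, model, optimizer)  # restores model and optimizer state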
--------------------------------------------------------------------------------
/pycls/utils/distributed.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Distributed helpers."""
9 |
10 | import multiprocessing
11 | import os
12 | import random
13 | import signal
14 | import threading
15 | import traceback
16 |
17 | import torch
18 | from pycls.core.config import cfg
19 |
20 |
21 | # Make work w recent PyTorch versions (https://github.com/pytorch/pytorch/issues/37377)
22 | os.environ["MKL_THREADING_LAYER"] = "GNU"
23 |
24 |
25 | def is_master_proc():
26 | """Determines if the current process is the master process.
27 |
28 | Master process is responsible for logging, writing and loading checkpoints. In
29 | the multi GPU setting, we assign the master role to the rank 0 process. When
30 | training using a single GPU, there is a single process which is considered master.
31 | """
32 | return cfg.NUM_GPUS == 1 or torch.distributed.get_rank() == 0
33 |
34 |
35 | def init_process_group(proc_rank, world_size, port):
36 | """Initializes the default process group."""
37 | # Set the GPU to use
38 | torch.cuda.set_device(proc_rank)
39 | # Initialize the process group
40 | torch.distributed.init_process_group(
41 | backend=cfg.DIST_BACKEND,
42 | init_method="tcp://{}:{}".format(cfg.HOST, port),
43 | world_size=world_size,
44 | rank=proc_rank,
45 | )
46 |
47 |
48 | def destroy_process_group():
49 | """Destroys the default process group."""
50 | torch.distributed.destroy_process_group()
51 |
52 |
53 | def scaled_all_reduce(tensors):
54 | """Performs the scaled all_reduce operation on the provided tensors.
55 |
56 | The input tensors are modified in-place. Currently supports only the sum
57 | reduction operator. The reduced values are scaled by the inverse size of the
58 | process group (equivalent to cfg.NUM_GPUS).
59 | """
60 | # There is no need for reduction in the single-proc case
61 | if cfg.NUM_GPUS == 1:
62 | return tensors
63 | # Queue the reductions
64 | reductions = []
65 | for tensor in tensors:
66 | reduction = torch.distributed.all_reduce(tensor, async_op=True)
67 | reductions.append(reduction)
68 | # Wait for reductions to finish
69 | for reduction in reductions:
70 | reduction.wait()
71 | # Scale the results
72 | for tensor in tensors:
73 | tensor.mul_(1.0 / cfg.NUM_GPUS)
74 | return tensors
75 |
76 |
77 | class ChildException(Exception):
78 | """Wraps an exception from a child process."""
79 |
80 | def __init__(self, child_trace):
81 | super(ChildException, self).__init__(child_trace)
82 |
83 |
84 | class ErrorHandler(object):
85 | """Multiprocessing error handler (based on fairseq's).
86 |
87 | Listens for errors in child processes and propagates the tracebacks to the parent.
88 | """
89 |
90 | def __init__(self, error_queue):
91 | # Shared error queue
92 | self.error_queue = error_queue
93 | # Children processes sharing the error queue
94 | self.children_pids = []
95 | # Start a thread listening to errors
96 | self.error_listener = threading.Thread(target=self.listen, daemon=True)
97 | self.error_listener.start()
98 | # Register the signal handler
99 | signal.signal(signal.SIGUSR1, self.signal_handler)
100 |
101 | def add_child(self, pid):
102 | """Registers a child process."""
103 | self.children_pids.append(pid)
104 |
105 | def listen(self):
106 | """Listens for errors in the error queue."""
107 | # Wait until there is an error in the queue
108 | child_trace = self.error_queue.get()
109 | # Put the error back for the signal handler
110 | self.error_queue.put(child_trace)
111 | # Invoke the signal handler
112 | os.kill(os.getpid(), signal.SIGUSR1)
113 |
114 | def signal_handler(self, _sig_num, _stack_frame):
115 | """Signal handler."""
116 | # Kill children processes
117 | for pid in self.children_pids:
118 | os.kill(pid, signal.SIGINT)
119 | # Propagate the error from the child process
120 | raise ChildException(self.error_queue.get())
121 |
122 |
123 | def run(proc_rank, world_size, port, error_queue, fun, fun_args, fun_kwargs):
124 | """Runs a function from a child process."""
125 | try:
126 | # Initialize the process group
127 | init_process_group(proc_rank, world_size, port)
128 | # Run the function
129 | fun(*fun_args, **fun_kwargs)
130 | except KeyboardInterrupt:
131 | # Killed by the parent process
132 | pass
133 | except Exception:
134 | # Propagate exception to the parent process
135 | error_queue.put(traceback.format_exc())
136 | finally:
137 | # Destroy the process group
138 | destroy_process_group()
139 |
140 |
141 | def multi_proc_run(num_proc, fun, fun_args=(), fun_kwargs=None):
142 | """Runs a function in a multi-proc setting (unless num_proc == 1)."""
143 | # There is no need for multi-proc in the single-proc case
144 | fun_kwargs = fun_kwargs if fun_kwargs else {}
145 | if num_proc == 1:
146 | fun(*fun_args, **fun_kwargs)
147 | return
148 | # Handle errors from training subprocesses
149 | error_queue = multiprocessing.SimpleQueue()
150 | error_handler = ErrorHandler(error_queue)
151 | # Get a random port to use (without using global random number generator)
152 | port = random.Random().randint(cfg.PORT_RANGE[0], cfg.PORT_RANGE[1])
153 | # Run each training subprocess
154 | ps = []
155 | for i in range(num_proc):
156 | p_i = multiprocessing.Process(
157 | target=run, args=(i, num_proc, port, error_queue, fun, fun_args, fun_kwargs)
158 | )
159 | ps.append(p_i)
160 | p_i.start()
161 | error_handler.add_child(p_i.pid)
162 | # Wait for each subprocess to finish
163 | for p in ps:
164 | p.join()
165 |
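
A sketch of launching a function across several processes with multi_proc_run; train below is a stand-in for the real entry point, and cfg.NUM_GPUS, cfg.HOST, cfg.DIST_BACKEND, and cfg.PORT_RANGE are assumed to be configured.

import pycls.utils.distributed as du
from pycls.core.config import cfg

def train(message):
    if du.is_master_proc():
        print(message)  # only the rank-0 process logs

if __name__ == "__main__":
    du.multi_proc_run(num_proc=cfg.NUM_GPUS, fun=train, fun_args=("hello",))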
--------------------------------------------------------------------------------
/pycls/utils/io.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """IO utilities (adapted from Detectron)"""
9 |
10 | import logging
11 | import os
12 | import re
13 | import sys
14 | from urllib import request as urlrequest
15 |
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 | _PYCLS_BASE_URL = "https://dl.fbaipublicfiles.com/pycls"
20 |
21 |
22 | def cache_url(url_or_file, cache_dir, base_url=_PYCLS_BASE_URL):
23 | """Download the file specified by the URL to the cache_dir and return the path to
24 | the cached file. If the argument is not a URL, simply return it as is.
25 | """
26 | is_url = re.match(r"^(?:http)s?://", url_or_file, re.IGNORECASE) is not None
27 | if not is_url:
28 | return url_or_file
29 | url = url_or_file
30 | assert url.startswith(base_url), "url must start with: {}".format(base_url)
31 | cache_file_path = url.replace(base_url, cache_dir)
32 | if os.path.exists(cache_file_path):
33 | return cache_file_path
34 | cache_file_dir = os.path.dirname(cache_file_path)
35 | if not os.path.exists(cache_file_dir):
36 | os.makedirs(cache_file_dir)
37 | logger.info("Downloading remote file {} to {}".format(url, cache_file_path))
38 | download_url(url, cache_file_path)
39 | return cache_file_path
40 |
41 |
42 | def _progress_bar(count, total):
43 | """Report download progress. Credit:
44 | https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console/27871113
45 | """
46 | bar_len = 60
47 | filled_len = int(round(bar_len * count / float(total)))
48 | percents = round(100.0 * count / float(total), 1)
49 | bar = "=" * filled_len + "-" * (bar_len - filled_len)
50 | sys.stdout.write(
51 | " [{}] {}% of {:.1f}MB file \r".format(bar, percents, total / 1024 / 1024)
52 | )
53 | sys.stdout.flush()
54 | if count >= total:
55 | sys.stdout.write("\n")
56 |
57 |
58 | def download_url(url, dst_file_path, chunk_size=8192, progress_hook=_progress_bar):
59 | """Download url and write it to dst_file_path. Credit:
60 | https://stackoverflow.com/questions/2028517/python-urllib2-progress-hook
61 | """
62 | req = urlrequest.Request(url)
63 | response = urlrequest.urlopen(req)
64 | total_size = response.info().get("Content-Length").strip()
65 | total_size = int(total_size)
66 | bytes_so_far = 0
67 | with open(dst_file_path, "wb") as f:
68 | while True:
69 | chunk = response.read(chunk_size)
70 | bytes_so_far += len(chunk)
71 | if not chunk:
72 | break
73 | if progress_hook:
74 | progress_hook(bytes_so_far, total_size)
75 | f.write(chunk)
76 | return bytes_so_far
77 |
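
Illustrative calls to cache_url; the URL below is hypothetical and only shows the required dl.fbaipublicfiles.com/pycls prefix, while non-URL arguments pass through unchanged.

from pycls.utils.io import cache_url

# Downloads under /tmp/pycls-cache on first call, returns the cached path afterwards
local_path = cache_url("https://dl.fbaipublicfiles.com/pycls/example/model.pyth", "/tmp/pycls-cache")
# Not a URL, so it is returned as-is
same_path = cache_url("/tmp/already_local.pyth", "/tmp/pycls-cache")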
--------------------------------------------------------------------------------
/pycls/utils/logging.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Logging."""
9 |
10 | import builtins
11 | import decimal
12 | import logging
13 | import os
14 | import simplejson
15 | import sys
16 |
17 | from pycls.core.config import cfg
18 |
19 | # import pycls.utils.distributed as du
20 |
21 | # Show filename and line number in logs
22 | _FORMAT = '[%(asctime)s %(filename)s: %(lineno)3d]: %(message)s'
23 |
24 | # Log file name (for cfg.LOG_DEST = 'file')
25 | _LOG_FILE = 'stdout.log'
26 |
27 | # Printed json stats lines will be tagged w/ this
28 | _TAG = 'json_stats: '
29 |
30 |
31 | def _suppress_print():
32 | """Suppresses printing from the current process."""
33 | def ignore(*_objects, **_kwargs):  # accept and drop any print() arguments
34 | pass
35 | builtins.print = ignore
36 |
37 |
38 | def setup_logging(cfg):
39 | """Sets up the logging."""
40 | # Enable logging only for the master process
41 | # if du.is_master_proc():
42 | if True:
43 | # Clear the root logger to prevent any existing logging config
44 | # (e.g. set by another module) from messing with our setup
45 | logging.root.handlers = []
46 | # Construct logging configuration
47 | logging_config = {
48 | 'level': logging.INFO,
49 | 'format': _FORMAT,
50 | 'datefmt': '%Y-%m-%d %H:%M:%S'
51 | }
52 | # Log either to stdout or to a file
53 | if cfg.LOG_DEST == 'stdout':
54 | logging_config['stream'] = sys.stdout
55 | else:
56 | logging_config['filename'] = os.path.join(cfg.EXP_DIR, _LOG_FILE)
57 | # Configure logging
58 | logging.basicConfig(**logging_config)
59 | else:
60 | _suppress_print()
61 |
62 |
63 | def get_logger(name):
64 | """Retrieves the logger."""
65 | return logging.getLogger(name)
66 |
67 |
68 | def log_json_stats(stats):
69 | """Logs json stats."""
70 | # Decimal + string workaround for having fixed len float vals in logs
71 | stats = {
72 | k: decimal.Decimal('{:.12f}'.format(v)) if isinstance(v, float) else v
73 | for k, v in stats.items()
74 | }
75 | json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True)
76 | logger = get_logger(__name__)
77 | logger.info('{:s}{:s}'.format(_TAG, json_stats))
78 |
79 |
80 | def load_json_stats(log_file):
81 | """Loads json_stats from a single log file."""
82 | with open(log_file, 'r') as f:
83 | lines = f.readlines()
84 | json_lines = [l[l.find(_TAG) + len(_TAG):] for l in lines if _TAG in l]
85 | json_stats = [simplejson.loads(l) for l in json_lines]
86 | return json_stats
87 |
88 |
89 | def parse_json_stats(log, row_type, key):
90 | """Extract values corresponding to row_type/key out of log."""
91 | vals = [row[key] for row in log if row['_type'] == row_type and key in row]
92 | if key == 'iter' or key == 'epoch':
93 | vals = [int(val.split('/')[0]) for val in vals]
94 | return vals
95 |
96 |
97 | def get_log_files(log_dir, name_filter=''):
98 | """Get all log files in directory containing subdirs of trained models."""
99 | names = [n for n in sorted(os.listdir(log_dir)) if name_filter in n]
100 | files = [os.path.join(log_dir, n, _LOG_FILE) for n in names]
101 | f_n_ps = [(f, n) for (f, n) in zip(files, names) if os.path.exists(f)]
102 | files, names = zip(*f_n_ps)
103 | return files, names
104 |
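
A sketch of the json_stats round trip: stats logged during training become tagged lines in the log and can be recovered later with load_json_stats/parse_json_stats. It assumes cfg.LOG_DEST and cfg.EXP_DIR are configured so the log lands in stdout.log.

import os
from pycls.core.config import cfg
import pycls.utils.logging as lu

lu.setup_logging(cfg)
lu.log_json_stats({'_type': 'train_epoch', 'epoch': '1/200', 'top1_err': 42.5})

log = lu.load_json_stats(os.path.join(cfg.EXP_DIR, 'stdout.log'))
errs = lu.parse_json_stats(log, row_type='train_epoch', key='top1_err')  # -> [42.5]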
--------------------------------------------------------------------------------
/pycls/utils/meters.py:
--------------------------------------------------------------------------------
1 | # This file is modified from the official pycls repository to adapt it to AL settings.
2 | """Meters."""
3 |
4 | from collections import deque
5 |
6 | import datetime
7 | import numpy as np
8 |
9 | from pycls.core.config import cfg
10 | from pycls.utils.timer import Timer
11 |
12 | import pycls.utils.logging as lu
13 | import pycls.utils.metrics as metrics
14 |
15 |
16 | def eta_str(eta_td):
17 | """Converts an eta timedelta to a fixed-width string format."""
18 | days = eta_td.days
19 | hrs, rem = divmod(eta_td.seconds, 3600)
20 | mins, secs = divmod(rem, 60)
21 | return '{0:02},{1:02}:{2:02}:{3:02}'.format(days, hrs, mins, secs)
22 |
23 |
24 | class ScalarMeter(object):
25 | """Measures a scalar value (adapted from Detectron)."""
26 |
27 | def __init__(self, window_size):
28 | self.deque = deque(maxlen=window_size)
29 | self.total = 0.0
30 | self.count = 0
31 |
32 | def reset(self):
33 | self.deque.clear()
34 | self.total = 0.0
35 | self.count = 0
36 |
37 | def add_value(self, value):
38 | self.deque.append(value)
39 | self.count += 1
40 | self.total += value
41 |
42 | def get_win_median(self):
43 | return np.median(self.deque)
44 |
45 | def get_win_avg(self):
46 | return np.mean(self.deque)
47 |
48 | def get_global_avg(self):
49 | return self.total / self.count
50 |
51 |
52 | class TrainMeter(object):
53 | """Measures training stats."""
54 |
55 | def __init__(self, epoch_iters):
56 | self.epoch_iters = epoch_iters
57 | self.max_iter = cfg.OPTIM.MAX_EPOCH * epoch_iters
58 | self.iter_timer = Timer()
59 | self.loss = ScalarMeter(cfg.LOG_PERIOD)
60 | self.loss_total = 0.0
61 | self.lr = None
62 | # Current minibatch errors (smoothed over a window)
63 | self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD)
64 | # Number of misclassified examples
65 | self.num_top1_mis = 0
66 | self.num_samples = 0
67 |
68 | def reset(self, timer=False):
69 | if timer:
70 | self.iter_timer.reset()
71 | self.loss.reset()
72 | self.loss_total = 0.0
73 | self.lr = None
74 | self.mb_top1_err.reset()
75 | self.num_top1_mis = 0
76 | self.num_samples = 0
77 |
78 | def iter_tic(self):
79 | self.iter_timer.tic()
80 |
81 | def iter_toc(self):
82 | self.iter_timer.toc()
83 |
84 | def update_stats(self, top1_err, loss, lr, mb_size):
85 | # Current minibatch stats
86 | self.mb_top1_err.add_value(top1_err)
87 | self.loss.add_value(loss)
88 | self.lr = lr
89 | # Aggregate stats
90 | self.num_top1_mis += top1_err * mb_size
91 | self.loss_total += loss * mb_size
92 | self.num_samples += mb_size
93 |
94 |
95 | def get_iter_stats(self, cur_epoch, cur_iter):
96 | eta_sec = self.iter_timer.average_time * (
97 | self.max_iter - (cur_epoch * self.epoch_iters + cur_iter + 1)
98 | )
99 | eta_td = datetime.timedelta(seconds=int(eta_sec))
100 | mem_usage = metrics.gpu_mem_usage()
101 | stats = {
102 | '_type': 'train_iter',
103 | 'epoch': '{}/{}'.format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
104 | 'iter': '{}/{}'.format(cur_iter + 1, self.epoch_iters),
105 | 'top1_err': self.mb_top1_err.get_win_median(),
106 | 'loss': self.loss.get_win_median(),
107 | 'lr': self.lr,
108 | }
109 | return stats
110 |
111 | def log_iter_stats(self, cur_epoch, cur_iter):
112 | if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
113 | return
114 | stats = self.get_iter_stats(cur_epoch, cur_iter)
115 | lu.log_json_stats(stats)
116 |
117 | def get_epoch_stats(self, cur_epoch):
118 | eta_sec = self.iter_timer.average_time * (
119 | self.max_iter - (cur_epoch + 1) * self.epoch_iters
120 | )
121 | eta_td = datetime.timedelta(seconds=int(eta_sec))
122 | mem_usage = metrics.gpu_mem_usage()
123 | top1_err = self.num_top1_mis / self.num_samples
124 | avg_loss = self.loss_total / self.num_samples
125 | stats = {
126 | '_type': 'train_epoch',
127 | 'epoch': '{}/{}'.format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
128 | 'top1_err': top1_err,
129 | 'loss': avg_loss,
130 | 'lr': self.lr,
131 | }
132 | return stats
133 |
134 | def log_epoch_stats(self, cur_epoch):
135 | stats = self.get_epoch_stats(cur_epoch)
136 | lu.log_json_stats(stats)
137 |
138 |
139 | class TestMeter(object):
140 | """Measures testing stats."""
141 |
142 | def __init__(self, max_iter):
143 | self.max_iter = max_iter
144 | self.iter_timer = Timer()
145 | # Current minibatch errors (smoothed over a window)
146 | self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD)
147 | # Min errors (over the full test set)
148 | self.min_top1_err = 100.0
149 | # Number of misclassified examples
150 | self.num_top1_mis = 0
151 | self.num_samples = 0
152 |
153 | def reset(self, min_errs=False):
154 | if min_errs:
155 | self.min_top1_err = 100.0
156 | self.iter_timer.reset()
157 | self.mb_top1_err.reset()
158 | self.num_top1_mis = 0
159 | self.num_samples = 0
160 |
161 | def iter_tic(self):
162 | self.iter_timer.tic()
163 |
164 | def iter_toc(self):
165 | self.iter_timer.toc()
166 |
167 | def update_stats(self, top1_err, mb_size):
168 | self.mb_top1_err.add_value(top1_err)
169 | self.num_top1_mis += top1_err * mb_size
170 | self.num_samples += mb_size
171 |
172 | def get_iter_stats(self, cur_epoch, cur_iter):
173 | mem_usage = metrics.gpu_mem_usage()
174 | iter_stats = {
175 | '_type': 'test_iter',
176 | 'epoch': '{}/{}'.format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
177 | 'iter': '{}/{}'.format(cur_iter + 1, self.max_iter),
178 | 'top1_err': self.mb_top1_err.get_win_median(),
179 | }
180 | return iter_stats
181 |
182 | def log_iter_stats(self, cur_epoch, cur_iter):
183 | if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
184 | return
185 | stats = self.get_iter_stats(cur_epoch, cur_iter)
186 | lu.log_json_stats(stats)
187 |
188 | def get_epoch_stats(self, cur_epoch):
189 | top1_err = self.num_top1_mis / self.num_samples
190 | self.min_top1_err = min(self.min_top1_err, top1_err)
191 | mem_usage = metrics.gpu_mem_usage()
192 | stats = {
193 | '_type': 'test_epoch',
194 | 'epoch': '{}/{}'.format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
195 | 'top1_err': top1_err,
196 | 'min_top1_err': self.min_top1_err
197 | }
198 | return stats
199 |
200 | def log_epoch_stats(self, cur_epoch):
201 | stats = self.get_epoch_stats(cur_epoch)
202 | lu.log_json_stats(stats)
203 |
204 | class ValMeter(object):
205 | """Measures Validation stats."""
206 |
207 | def __init__(self, max_iter):
208 | self.max_iter = max_iter
209 | self.iter_timer = Timer()
210 | # Current minibatch errors (smoothed over a window)
211 | self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD)
212 | # Min errors (over the full Val set)
213 | self.min_top1_err = 100.0
214 | # Number of misclassified examples
215 | self.num_top1_mis = 0
216 | self.num_samples = 0
217 |
218 | def reset(self, min_errs=False):
219 | if min_errs:
220 | self.min_top1_err = 100.0
221 | self.iter_timer.reset()
222 | self.mb_top1_err.reset()
223 | self.num_top1_mis = 0
224 | self.num_samples = 0
225 |
226 | def iter_tic(self):
227 | self.iter_timer.tic()
228 |
229 | def iter_toc(self):
230 | self.iter_timer.toc()
231 |
232 | def update_stats(self, top1_err, mb_size):
233 | self.mb_top1_err.add_value(top1_err)
234 | self.num_top1_mis += top1_err * mb_size
235 | self.num_samples += mb_size
236 |
237 | def get_iter_stats(self, cur_epoch, cur_iter):
238 | mem_usage = metrics.gpu_mem_usage()
239 | iter_stats = {
240 | '_type': 'Val_iter',
241 | 'epoch': '{}/{}'.format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
242 | 'iter': '{}/{}'.format(cur_iter + 1, self.max_iter),
243 | 'top1_err': self.mb_top1_err.get_win_median(),
244 | }
245 | return iter_stats
246 |
247 | def log_iter_stats(self, cur_epoch, cur_iter):
248 | if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
249 | return
250 | stats = self.get_iter_stats(cur_epoch, cur_iter)
251 | lu.log_json_stats(stats)
252 |
253 | def get_epoch_stats(self, cur_epoch):
254 | top1_err = self.num_top1_mis / self.num_samples
255 | self.min_top1_err = min(self.min_top1_err, top1_err)
256 | mem_usage = metrics.gpu_mem_usage()
257 | stats = {
258 | '_type': 'Val_epoch',
259 | 'epoch': '{}/{}'.format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
260 | 'top1_err': top1_err,
261 | 'min_top1_err': self.min_top1_err
262 | }
263 | return stats
264 |
265 | def log_epoch_stats(self, cur_epoch):
266 | stats = self.get_epoch_stats(cur_epoch)
267 | lu.log_json_stats(stats)
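
All the meters above are built on ScalarMeter, which keeps a sliding window for smoothed logging alongside global totals. A fully runnable illustration (the window size and values are arbitrary):

from pycls.utils.meters import ScalarMeter

meter = ScalarMeter(window_size=3)
for v in [10.0, 2.0, 4.0, 6.0]:
    meter.add_value(v)
print(meter.get_win_median())  # median of the last 3 values: 4.0
print(meter.get_win_avg())     # mean of the last 3 values: 4.0
print(meter.get_global_avg())  # mean of all 4 values: 5.5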
--------------------------------------------------------------------------------
/pycls/utils/metrics.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Functions for computing metrics."""
9 |
10 | import numpy as np
11 | import torch
12 | import torch.nn as nn
13 |
14 | from pycls.core.config import cfg
15 |
16 | # Number of bytes in a megabyte
17 | _B_IN_MB = 1024 * 1024
18 |
19 |
20 | def topks_correct(preds, labels, ks):
21 | """Computes the number of top-k correct predictions for each k."""
22 | assert preds.size(0) == labels.size(0), \
23 | 'Batch dim of predictions and labels must match'
24 | # Find the top max_k predictions for each sample
25 | _top_max_k_vals, top_max_k_inds = torch.topk(
26 | preds, max(ks), dim=1, largest=True, sorted=True
27 | )
28 | # (batch_size, max_k) -> (max_k, batch_size)
29 | top_max_k_inds = top_max_k_inds.t()
30 | # (batch_size, ) -> (max_k, batch_size)
31 | rep_max_k_labels = labels.view(1, -1).expand_as(top_max_k_inds)
32 | # (i, j) = 1 if top i-th prediction for the j-th sample is correct
33 | top_max_k_correct = top_max_k_inds.eq(rep_max_k_labels)
34 | # Compute the number of topk correct predictions for each k
35 | topks_correct = [
36 | top_max_k_correct[:k, :].reshape(-1).float().sum() for k in ks
37 | ]
38 | return topks_correct
39 |
40 |
41 | def topk_errors(preds, labels, ks):
42 | """Computes the top-k error for each k."""
43 | num_topks_correct = topks_correct(preds, labels, ks)
44 | return [(1.0 - x / preds.size(0)) * 100.0 for x in num_topks_correct]
45 |
46 |
47 | def topk_accuracies(preds, labels, ks):
48 | """Computes the top-k accuracy for each k."""
49 | num_topks_correct = topks_correct(preds, labels, ks)
50 | return [(x / preds.size(0)) * 100.0 for x in num_topks_correct]
51 |
52 |
53 | def params_count(model):
54 | """Computes the number of parameters."""
55 | return np.sum([p.numel() for p in model.parameters()]).item()
56 |
57 |
58 | def flops_count(model):
59 | """Computes the number of flops statically."""
60 | h, w = cfg.TRAIN.IM_SIZE, cfg.TRAIN.IM_SIZE
61 | count = 0
62 | for n, m in model.named_modules():
63 | if isinstance(m, nn.Conv2d):
64 | if 'se.' in n:
65 | count += m.in_channels * m.out_channels + m.bias.numel()
66 | continue
67 | h_out = (h + 2 * m.padding[0] - m.kernel_size[0]) // m.stride[0] + 1
68 | w_out = (w + 2 * m.padding[1] - m.kernel_size[1]) // m.stride[1] + 1
69 | count += np.prod([
70 | m.weight.numel(),
71 | h_out, w_out
72 | ])
73 | if '.proj' not in n:
74 | h, w = h_out, w_out
75 | elif isinstance(m, nn.MaxPool2d):
76 | h = (h + 2 * m.padding - m.kernel_size) // m.stride + 1
77 | w = (w + 2 * m.padding - m.kernel_size) // m.stride + 1
78 | elif isinstance(m, nn.Linear):
79 | count += m.in_features * m.out_features
80 | return count.item()
81 |
82 |
83 | def gpu_mem_usage():
84 | """Computes the GPU memory usage for the current device (MB)."""
85 | mem_usage_bytes = torch.cuda.max_memory_allocated()
86 | return mem_usage_bytes / _B_IN_MB
87 |
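
A runnable check of topk_errors on a toy batch of two samples (values chosen by hand): the first is right at top-1, the second only enters at top-2.

import torch
from pycls.utils.metrics import topk_errors

preds = torch.tensor([[0.1, 0.7, 0.2],   # top-1 = class 1 (correct)
                      [0.5, 0.2, 0.3]])  # top-1 = class 0 (wrong), top-2 adds class 2
labels = torch.tensor([1, 2])
top1_err, top2_err = topk_errors(preds, labels, [1, 2])
print(top1_err.item(), top2_err.item())  # 50.0 0.0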
--------------------------------------------------------------------------------
/pycls/utils/net.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Functions for manipulating networks."""
9 |
10 | import itertools
11 | import math
12 | import torch
13 | import torch.nn as nn
14 |
15 | from pycls.core.config import cfg
16 |
17 |
18 | def init_weights(m):
19 | """Performs ResNet-style weight initialization."""
20 | if isinstance(m, nn.Conv2d):
21 | # Note that there is no bias due to BN
22 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
23 | m.weight.data.normal_(mean=0.0, std=math.sqrt(2.0 / fan_out))
24 | elif isinstance(m, nn.BatchNorm2d):
25 | zero_init_gamma = (
26 | hasattr(m, 'final_bn') and m.final_bn and
27 | cfg.BN.ZERO_INIT_FINAL_GAMMA
28 | )
29 | m.weight.data.fill_(0.0 if zero_init_gamma else 1.0)
30 | m.bias.data.zero_()
31 | elif isinstance(m, nn.Linear):
32 | m.weight.data.normal_(mean=0.0, std=0.01)
33 | m.bias.data.zero_()
34 |
35 |
36 | @torch.no_grad()
37 | def compute_precise_bn_stats(model, loader):
38 | """Computes precise BN stats on training data."""
39 | # Compute the number of minibatches to use
40 | num_iter = min(cfg.BN.NUM_SAMPLES_PRECISE // loader.batch_size, len(loader))
41 | # Retrieve the BN layers
42 | bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
43 | # Initialize stats storage
44 | mus = [torch.zeros_like(bn.running_mean) for bn in bns]
45 | sqs = [torch.zeros_like(bn.running_var) for bn in bns]
46 | # Remember momentum values
47 | moms = [bn.momentum for bn in bns]
48 | # Disable momentum
49 | for bn in bns:
50 | bn.momentum = 1.0
51 | # Accumulate the stats across the data samples
52 | for inputs, _labels in itertools.islice(loader, num_iter):
53 | model(inputs.cuda())
54 | # Accumulate the stats for each BN layer
55 | for i, bn in enumerate(bns):
56 | m, v = bn.running_mean, bn.running_var
57 | sqs[i] += (v + m * m) / num_iter
58 | mus[i] += m / num_iter
59 | # Set the stats and restore momentum values
60 | for i, bn in enumerate(bns):
61 | bn.running_var = sqs[i] - mus[i] * mus[i]
62 | bn.running_mean = mus[i]
63 | bn.momentum = moms[i]
64 |
65 |
66 | def reset_bn_stats(model):
67 | """Resets running BN stats."""
68 | for m in model.modules():
69 | if isinstance(m, torch.nn.BatchNorm2d):
70 | m.reset_running_stats()
71 |
72 |
73 | def drop_connect(x, drop_ratio):
74 | """Drop connect (adapted from DARTS)."""
75 | keep_ratio = 1.0 - drop_ratio
76 | mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device)
77 | mask.bernoulli_(keep_ratio)
78 | x.div_(keep_ratio)
79 | x.mul_(mask)
80 | return x
81 |
82 |
83 | def get_flat_weights(model):
84 | """Gets all model weights as a single flat vector."""
85 | return torch.cat([p.data.view(-1, 1) for p in model.parameters()], 0)
86 |
87 |
88 | def set_flat_weights(model, flat_weights):
89 | """Sets all model weights from a single flat vector."""
90 | k = 0
91 | for p in model.parameters():
92 | n = p.data.numel()
93 | p.data.copy_(flat_weights[k:(k + n)].view_as(p.data))
94 | k += n
95 | assert k == flat_weights.numel()
96 |
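
A runnable round trip for the flat-weight helpers above, useful e.g. for snapshotting or interpolating parameters:

import torch
import torch.nn as nn
from pycls.utils.net import get_flat_weights, set_flat_weights

model = nn.Linear(4, 2)
flat = get_flat_weights(model)         # column vector of all parameters
set_flat_weights(model, flat.clone())  # writes the values back in order
assert torch.equal(get_flat_weights(model), flat)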
--------------------------------------------------------------------------------
/pycls/utils/plotting.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Plotting functions."""
9 |
10 | import colorlover as cl
11 | import matplotlib.pyplot as plt
12 | import plotly.graph_objs as go
13 | import plotly.offline as offline
14 | import pycls.utils.logging as logging
15 |
16 |
17 | def get_plot_colors(max_colors, color_format="pyplot"):
18 | """Generate colors for plotting."""
19 | colors = cl.scales["11"]["qual"]["Paired"]
20 | if max_colors > len(colors):
21 | colors = cl.to_rgb(cl.interp(colors, max_colors))
22 | if color_format == "pyplot":
23 | return [[j / 255.0 for j in c] for c in cl.to_numeric(colors)]
24 | return colors
25 |
26 |
27 | def prepare_plot_data(log_files, names, metric="top1_err"):
28 | """Load logs and extract data for plotting error curves."""
29 | plot_data = []
30 | for file, name in zip(log_files, names):
31 | d, data = {}, logging.sort_log_data(logging.load_log_data(file))
32 | for phase in ["train", "test"]:
33 | x = data[phase + "_epoch"]["epoch_ind"]
34 | y = data[phase + "_epoch"][metric]
35 | d["x_" + phase], d["y_" + phase] = x, y
36 | d[phase + "_label"] = "[{:5.2f}] ".format(min(y) if y else 0) + name
37 | plot_data.append(d)
38 | assert len(plot_data) > 0, "No data to plot"
39 | return plot_data
40 |
41 |
42 | def plot_error_curves_plotly(log_files, names, filename, metric="top1_err"):
43 | """Plot error curves using plotly and save to file."""
44 | plot_data = prepare_plot_data(log_files, names, metric)
45 | colors = get_plot_colors(len(plot_data), "plotly")
46 | # Prepare data for plots (3 sets, train duplicated w and w/o legend)
47 | data = []
48 | for i, d in enumerate(plot_data):
49 | s = str(i)
50 | line_train = {"color": colors[i], "dash": "dashdot", "width": 1.5}
51 | line_test = {"color": colors[i], "dash": "solid", "width": 1.5}
52 | data.append(
53 | go.Scatter(
54 | x=d["x_train"],
55 | y=d["y_train"],
56 | mode="lines",
57 | name=d["train_label"],
58 | line=line_train,
59 | legendgroup=s,
60 | visible=True,
61 | showlegend=False,
62 | )
63 | )
64 | data.append(
65 | go.Scatter(
66 | x=d["x_test"],
67 | y=d["y_test"],
68 | mode="lines",
69 | name=d["test_label"],
70 | line=line_test,
71 | legendgroup=s,
72 | visible=True,
73 | showlegend=True,
74 | )
75 | )
76 | data.append(
77 | go.Scatter(
78 | x=d["x_train"],
79 | y=d["y_train"],
80 | mode="lines",
81 | name=d["train_label"],
82 | line=line_train,
83 | legendgroup=s,
84 | visible=False,
85 | showlegend=True,
86 | )
87 | )
88 | # Prepare layout w ability to toggle 'all', 'train', 'test'
89 | titlefont = {"size": 18, "color": "#7f7f7f"}
90 | vis = [[True, True, False], [False, False, True], [False, True, False]]
91 | buttons = zip(["all", "train", "test"], [[{"visible": v}] for v in vis])
92 | buttons = [{"label": b, "args": v, "method": "update"} for b, v in buttons]
93 | layout = go.Layout(
94 | title=metric + " vs. epoch<br>[dash=train, solid=test]",
95 | xaxis={"title": "epoch", "titlefont": titlefont},
96 | yaxis={"title": metric, "titlefont": titlefont},
97 | showlegend=True,
98 | hoverlabel={"namelength": -1},
99 | updatemenus=[
100 | {
101 | "buttons": buttons,
102 | "direction": "down",
103 | "showactive": True,
104 | "x": 1.02,
105 | "xanchor": "left",
106 | "y": 1.08,
107 | "yanchor": "top",
108 | }
109 | ],
110 | )
111 | # Create plotly plot
112 | offline.plot({"data": data, "layout": layout}, filename=filename)
113 |
114 |
115 | def plot_error_curves_pyplot(log_files, names, filename=None, metric="top1_err"):
116 | """Plot error curves using matplotlib.pyplot and save to file."""
117 | plot_data = prepare_plot_data(log_files, names, metric)
118 | colors = get_plot_colors(len(names))
119 | for ind, d in enumerate(plot_data):
120 | c, lbl = colors[ind], d["test_label"]
121 | plt.plot(d["x_train"], d["y_train"], "--", c=c, alpha=0.8)
122 | plt.plot(d["x_test"], d["y_test"], "-", c=c, alpha=0.8, label=lbl)
123 | plt.title(metric + " vs. epoch\n[dash=train, solid=test]", fontsize=14)
124 | plt.xlabel("epoch", fontsize=14)
125 | plt.ylabel(metric, fontsize=14)
126 | plt.grid(alpha=0.4)
127 | plt.legend()
128 | if filename:
129 | plt.savefig(filename)
130 | plt.clf()
131 | else:
132 | plt.show()
133 |
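
An illustrative call to get_plot_colors, the only helper above that does not need log files. The base Paired scale has 11 colors; asking for more triggers interpolation.

from pycls.utils.plotting import get_plot_colors

pyplot_colors = get_plot_colors(15)                         # [r, g, b] floats in [0, 1]
plotly_colors = get_plot_colors(15, color_format="plotly")  # "rgb(...)" strings for plotly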
--------------------------------------------------------------------------------
/pycls/utils/timer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Timer."""
9 |
10 | import time
11 |
12 |
13 | class Timer(object):
14 | """A simple timer (adapted from Detectron)."""
15 |
16 | def __init__(self):
17 | self.total_time = None
18 | self.calls = None
19 | self.start_time = None
20 | self.diff = None
21 | self.average_time = None
22 | self.reset()
23 |
24 | def tic(self):
25 | # using time.time as time.clock does not normalize for multithreading
26 | self.start_time = time.time()
27 |
28 | def toc(self):
29 | self.diff = time.time() - self.start_time
30 | self.total_time += self.diff
31 | self.calls += 1
32 | self.average_time = self.total_time / self.calls
33 |
34 | def reset(self):
35 | self.total_time = 0.0
36 | self.calls = 0
37 | self.start_time = 0.0
38 | self.diff = 0.0
39 | self.average_time = 0.0
40 |
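
A runnable sketch of the Timer API above; time.sleep stands in for real work.

import time
from pycls.utils.timer import Timer

timer = Timer()
for _ in range(3):
    timer.tic()
    time.sleep(0.01)  # stand-in for real work
    timer.toc()
print(timer.calls, timer.average_time)  # 3 calls, roughly 0.01s average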
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | black==19.3b0
2 | flake8==3.8.4
3 | isort==4.3.21
4 | matplotlib==3.3.4
5 | numpy
6 | opencv-python==4.2.0.34
7 | torch==1.7.1
8 | torchvision==0.8.2
9 | parameterized
10 | setuptools
11 | simplejson
12 | yacs
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acl21/deep-active-learning-pytorch/637fd507235632903bcf84ed841ff524d847b94e/tools/__init__.py
--------------------------------------------------------------------------------
/tools/test_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from datetime import datetime
4 | import argparse
5 | import numpy as np
6 | from tqdm import tqdm
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 | # local
13 |
14 | def add_path(path):
15 | if path not in sys.path:
16 | sys.path.insert(0, path)
17 |
18 | add_path(os.path.abspath('..'))
19 |
20 | import pycls.core.builders as model_builder
21 | from pycls.core.config import cfg, dump_cfg
22 | from pycls.datasets.data import Data
23 | import pycls.utils.checkpoint as cu
24 | import pycls.utils.logging as lu
25 | import pycls.utils.metrics as mu
26 | import pycls.utils.net as nu
27 | from pycls.utils.meters import TestMeter
28 |
29 | logger = lu.get_logger(__name__)
30 |
31 | def argparser():
32 | parser = argparse.ArgumentParser(description='Passive Learning - Image Classification')
33 | parser.add_argument('--cfg', dest='cfg_file', help='Config file', required=True, type=str)
34 |
35 | return parser
36 |
37 | def plot_arrays(x_vals, y_vals, x_name, y_name, dataset_name, out_dir, isDebug=False):
38 | # if not du.is_master_proc():
39 | # return
40 |
41 | import matplotlib.pyplot as plt
42 | temp_name = "{}_vs_{}".format(x_name, y_name)
43 | plt.xlabel(x_name)
44 | plt.ylabel(y_name)
45 | plt.title("Dataset: {}; {}".format(dataset_name, temp_name))
46 | plt.plot(x_vals, y_vals)
47 |
48 | if isDebug: print("plot_saved at : {}".format(os.path.join(out_dir, temp_name+'.png')))
49 |
50 | plt.savefig(os.path.join(out_dir, temp_name+".png"))
51 | plt.close()
52 |
53 | def save_plot_values(temp_arrays, temp_names, out_dir, isParallel=True, saveInTextFormat=False, isDebug=True):
54 | """Saves the arrays provided in the list in npy (or text) format."""
55 |
56 | # Return if not master process
57 | # if isParallel:
58 | # if not du.is_master_proc():
59 | # return
60 |
61 | for i in range(len(temp_arrays)):
62 | temp_arrays[i] = np.array(temp_arrays[i])
63 | temp_dir = out_dir
64 | # if cfg.TRAIN.TRANSFER_EXP:
65 | # temp_dir += os.path.join("transfer_experiment",cfg.MODEL.TRANSFER_MODEL_TYPE+"_depth_"+str(cfg.MODEL.TRANSFER_MODEL_DEPTH))+"/"
66 |
67 | if not os.path.exists(temp_dir):
68 | os.makedirs(temp_dir)
69 | if saveInTextFormat:
70 | # if isDebug: print(f"Saving {temp_names[i]} at {temp_dir+temp_names[i]}.txt in text format!!")
71 | np.savetxt(temp_dir+'/'+temp_names[i]+".txt", temp_arrays[i], fmt="%d")
72 | else:
73 | # if isDebug: print(f"Saving {temp_names[i]} at {temp_dir+temp_names[i]}.npy in numpy format!!")
74 | np.save(temp_dir+'/'+temp_names[i]+".npy", temp_arrays[i])
75 |
76 | def is_eval_epoch(cur_epoch):
77 | """Determines if the model should be evaluated at the current epoch."""
78 | return (
79 | (cur_epoch + 1) % cfg.TRAIN.EVAL_PERIOD == 0 or
80 | (cur_epoch + 1) == cfg.OPTIM.MAX_EPOCH
81 | )
82 |
83 |
84 | def main(cfg):
85 |
86 | # Setting up GPU args
87 | use_cuda = (cfg.NUM_GPUS > 0) and torch.cuda.is_available()
88 | device = torch.device("cuda" if use_cuda else "cpu")
89 | kwargs = {'num_workers': cfg.DATA_LOADER.NUM_WORKERS, 'pin_memory': cfg.DATA_LOADER.PIN_MEMORY} if use_cuda else {}
90 |
91 | # Using specific GPU
92 | # os.environ['NVIDIA_VISIBLE_DEVICES'] = str(cfg.GPU_ID)
93 | # os.environ['CUDA_VISIBLE_DEVICES'] = '0'
94 | # print("Using GPU : {}.\n".format(cfg.GPU_ID))
95 |
96 | # Getting the output directory ready (default is "/output")
97 | cfg.OUT_DIR = os.path.join(os.path.abspath('..'), cfg.OUT_DIR)
98 | if not os.path.exists(cfg.OUT_DIR):
99 | os.mkdir(cfg.OUT_DIR)
100 | # Create "DATASET/MODEL TYPE" specific directory
101 | dataset_out_dir = os.path.join(cfg.OUT_DIR, cfg.DATASET.NAME, cfg.MODEL.TYPE)
102 | if not os.path.exists(dataset_out_dir):
103 | os.makedirs(dataset_out_dir)
104 | # Creating the experiment directory inside the dataset specific directory
105 | # all logs and the labeled, unlabeled, and validation sets are stored here
106 | # E.g., output/CIFAR10/resnet18/{timestamp or cfg.EXP_NAME based on arguments passed}
107 | if cfg.EXP_NAME == 'auto':
108 | now = datetime.now()
109 | exp_dir = f'{now.year}_{now.month}_{now.day}_{now.hour}{now.minute}{now.second}'
110 | else:
111 | exp_dir = cfg.EXP_NAME
112 |
113 | exp_dir = os.path.join(dataset_out_dir, exp_dir)
114 | if not os.path.exists(exp_dir):
115 | os.mkdir(exp_dir)
116 | print("Experiment Directory is {}.\n".format(exp_dir))
117 | else:
118 | print("Experiment Directory Already Exists: {}. Reusing it may lead to loss of old logs in the directory.\n".format(exp_dir))
119 | cfg.EXP_DIR = exp_dir
120 |
121 | # Save the config file in EXP_DIR
122 | dump_cfg(cfg)
123 |
124 | # Setup Logger
125 | lu.setup_logging(cfg)
126 |
127 | # Dataset preparing steps
128 | print("\n======== PREPARING TEST DATA ========\n")
129 | cfg.DATASET.ROOT_DIR = os.path.join(os.path.abspath('..'), cfg.DATASET.ROOT_DIR)
130 | data_obj = Data(cfg)
131 | test_data, test_size = data_obj.getDataset(save_dir=cfg.DATASET.ROOT_DIR, isTrain=False, isDownload=True)
132 |
133 | print("\nDataset {} Loaded Sucessfully. Total Test Size: {}\n".format(cfg.DATASET.NAME, test_size))
134 | logger.info("Dataset {} Loaded Sucessfully. Total Test Size: {}\n".format(cfg.DATASET.NAME, test_size))
135 |
136 | # Preparing dataloaders for testing
137 | test_loader = data_obj.getTestLoader(data=test_data, test_batch_size=cfg.TRAIN.BATCH_SIZE, seed_id=cfg.RNG_SEED)
138 |
139 | print("======== TESTING ========\n")
140 | logger.info("======== TESTING ========\n")
141 | test_acc = test_model(test_loader, os.path.join(os.path.abspath('..'), cfg.TEST.MODEL_PATH), cfg)
142 | print("Test Accuracy: {}.\n".format(round(test_acc, 4)))
143 | logger.info("Test Accuracy {}.\n".format(test_acc))
144 |
145 | print('Check the test accuracy inside {}/stdout.log'.format(cfg.EXP_DIR))
146 |
147 | print("================================\n\n")
148 | logger.info("================================\n\n")
149 |
150 |
151 | def test_model(test_loader, checkpoint_file, cfg, cur_episode=0):
152 |
153 | test_meter = TestMeter(len(test_loader))
154 |
155 | model = model_builder.build_model(cfg)
156 | model = cu.load_checkpoint(checkpoint_file, model)
157 |
158 | test_err = test_epoch(test_loader, model, test_meter, cur_episode)
159 | test_acc = 100. - test_err
160 |
161 | return test_acc
162 |
163 |
164 | @torch.no_grad()
165 | def test_epoch(test_loader, model, test_meter, cur_epoch):
166 | """Evaluates the model on the test set."""
167 |
168 | if torch.cuda.is_available():
169 | model.cuda()
170 |
171 | # Enable eval mode
172 | model.eval()
173 | test_meter.iter_tic()
174 |
175 | misclassifications = 0.
176 | totalSamples = 0.
177 |
178 | for cur_iter, (inputs, labels) in enumerate(tqdm(test_loader, desc="Test Data")):
179 | with torch.no_grad():
180 | # Transfer the data to the current GPU device
181 | inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
182 | inputs = inputs.type(torch.cuda.FloatTensor)
183 | # Compute the predictions
184 | preds = model(inputs)
185 | # Compute the errors
186 | top1_err, top5_err = mu.topk_errors(preds, labels, [1, 5])
187 | # Combine the errors across the GPUs
188 | # if cfg.NUM_GPUS > 1:
189 | # top1_err = du.scaled_all_reduce([top1_err])
190 | # #as above returns a list
191 | # top1_err = top1_err[0]
192 | # Copy the errors from GPU to CPU (sync point)
193 | top1_err = top1_err.item()
194 | # Multiply by the number of GPUs, as top1_err is scaled by 1/NUM_GPUS
195 | misclassifications += top1_err * inputs.size(0) * cfg.NUM_GPUS
196 | totalSamples += inputs.size(0)*cfg.NUM_GPUS
197 | test_meter.iter_toc()
198 | # Update and log stats
199 | test_meter.update_stats(
200 | top1_err=top1_err, mb_size=inputs.size(0) * cfg.NUM_GPUS
201 | )
202 | test_meter.log_iter_stats(cur_epoch, cur_iter)
203 | test_meter.iter_tic()
204 | # Log epoch stats
205 | test_meter.log_epoch_stats(cur_epoch)
206 | test_meter.reset()
207 |
208 | return misclassifications/totalSamples
209 |
210 |
211 | if __name__ == "__main__":
212 | cfg.merge_from_file(argparser().parse_args().cfg_file)
213 | main(cfg)
214 |
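
The script is driven entirely by a config file. A sketch of invoking the same flow programmatically, assuming the evaluate config from the repo's configs tree defines TEST.MODEL_PATH and the dataset settings, and that it is run from the tools/ directory:

# Equivalent to: python test_model.py --cfg ../configs/cifar10/evaluate/RESNET18.yaml
from pycls.core.config import cfg
from tools.test_model import main

cfg.merge_from_file('../configs/cifar10/evaluate/RESNET18.yaml')
main(cfg)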
--------------------------------------------------------------------------------