├── experiments
│   ├── __init__.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── STL10
│   │   │   ├── __init__.py
│   │   │   ├── data_loader_stl10frac2.py
│   │   │   └── data_loader_stl10frac.py
│   │   ├── mnist
│   │   │   ├── __init__.py
│   │   │   └── data_loader_mnist.py
│   │   ├── cifar10
│   │   │   ├── __init__.py
│   │   │   ├── autoaugment.py
│   │   │   └── data_loader_cifar10.py
│   │   ├── cifar100
│   │   │   ├── __init__.py
│   │   │   ├── autoaugment.py
│   │   │   └── data_loader_cifar100.py
│   │   ├── mnist12k
│   │   │   ├── __init__.py
│   │   │   ├── convert.py
│   │   │   ├── data_loader_mnist12k.py
│   │   │   └── own_transforms.py
│   │   ├── mnist_fliprot
│   │   │   ├── __init__.py
│   │   │   ├── convert.py
│   │   │   ├── own_transforms.py
│   │   │   └── data_loader_mnist_fliprot.py
│   │   ├── mnist_rot
│   │   │   ├── __init__.py
│   │   │   ├── convert.py
│   │   │   ├── own_transforms.py
│   │   │   └── data_loader_mnist_rot.py
│   │   └── download_mnist.sh
│   ├── .gitignore
│   ├── models
│   │   ├── __init__.py
│   │   ├── exp_cnn.py
│   │   ├── example.py
│   │   ├── utils.py
│   │   ├── e2sfcnn.py
│   │   ├── e2sfcnn_quotient.py
│   │   └── wide_resnet.py
│   ├── mnist_final.sh
│   ├── setting_up_env.sh
│   ├── stl10_ablation.sh
│   ├── stl10_experiments.sh
│   ├── mnist_restrict.sh
│   ├── multiple_exps.py
│   ├── print_results.py
│   ├── cifar_experiments.sh
│   ├── cifar_single.sh
│   ├── retrieve_model.py
│   ├── mnist_final_single.sh
│   ├── count_parameters.py
│   ├── stl10_single.sh
│   ├── mnist_bench_single.sh
│   ├── mnist_bench.sh
│   ├── plot_exps.py
│   └── optimizers_L1L2.py
├── LICENSE
└── README.md

/experiments/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/experiments/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/experiments/datasets/STL10/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/experiments/datasets/mnist/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/experiments/datasets/cifar10/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/experiments/datasets/cifar100/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/experiments/datasets/mnist12k/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/experiments/datasets/mnist_fliprot/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/experiments/datasets/mnist_rot/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | 
--------------------------------------------------------------------------------
/experiments/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | 
3 | datasets/*/*
4 | 
5 | !datasets/*/*.py
6 | !datasets/*/*.sh
7 | !datasets/*.py
--------------------------------------------------------------------------------
/experiments/models/__init__.py:
-------------------------------------------------------------------------------- 1 | 2 | from .e2sfcnn import E2SFCNN 3 | from .e2sfcnn_quotient import E2SFCNN_QUOT 4 | 5 | from .exp_e2sfcnn import ExpE2SFCNN 6 | from .exp_cnn import ExpCNN 7 | 8 | from .e2_wide_resnet import * 9 | from .wide_resnet import * 10 | 11 | -------------------------------------------------------------------------------- /experiments/datasets/download_mnist.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | wget -nc http://www.iro.umontreal.ca/~lisa/icml2007data/mnist_rotation_new.zip 4 | # uncompress 5 | unzip -n mnist_rotation_new.zip -d mnist_rot 6 | 7 | 8 | wget -nc http://www.iro.umontreal.ca/~lisa/icml2007data/mnist.zip 9 | # uncompress 10 | unzip -n mnist.zip -d mnist12k -------------------------------------------------------------------------------- /experiments/mnist_final.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | # C16 regular model 5 | ./mnist_final_single.sh --N 16 --model E2SFCNN --F "None" --fixparams --S 6 6 | 7 | # D16 -> C16 regular model 8 | ./mnist_final_single.sh --N 16 --flip --restrict 5 --model E2SFCNN --F "None" --fixparams --S 6 9 | 10 | # C16 quotient model 11 | ./mnist_final_single.sh --N 16 --model E2SFCNN_QUOT --F "None" --fixparams --S 6 12 | 13 | -------------------------------------------------------------------------------- /experiments/setting_up_env.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | conda create --name e2exp python=3.6 4 | source activate e2exp 5 | 6 | conda install -y pytorch=1.3 torchvision cudatoolkit=10.0 -c pytorch 7 | conda install -y -c conda-forge matplotlib 8 | conda install -y scipy=1.5 pandas scikit-learn=0.23 9 | conda install -y -c anaconda sqlite 10 | 11 | #pip install e2cnn 12 | 13 | mkdir tmp_e2cnn 14 | cd tmp_e2cnn 15 | git clone --single-branch --branch legacy_py3.6 https://github.com/QUVA-Lab/e2cnn 16 | mv e2cnn/e2cnn ../e2cnn 17 | cd .. 
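# the clone above was only needed to vendor the e2cnn/ package into ./experiments;
# the temporary checkout can now be deleted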
18 | rm -rf tmp_e2cnn 19 | 20 | -------------------------------------------------------------------------------- /experiments/stl10_ablation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | sizes=("250" "500" "1000" "2000" "4000") 4 | 5 | for i in $(seq 1 3) 6 | do 7 | 8 | for s in "${sizes[@]}" 9 | do 10 | 11 | dataset="STL10cif|$s" 12 | 13 | # Equivariant Big Model D8D4D1 14 | ./stl10_exp.sh --augment --deltaorth --model e2wrn16_8 --restrict 3 --fixparams --split --dataset "$dataset" --validate 15 | 16 | # Equivariant Small Model D8D4D1 17 | ./stl10_exp.sh --augment --deltaorth --model e2wrn16_8 --restrict 3 --dataset "$dataset" --validate 18 | 19 | # Conventional Model (with DeltaOrthogonal) 20 | ./stl10_exp.sh --augment --deltaorth --model wrn16_8 --dataset "$dataset" --validate 21 | 22 | done 23 | done 24 | -------------------------------------------------------------------------------- /experiments/stl10_experiments.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for i in $(seq 1 5) 4 | do 5 | 6 | # Flip Equivariant Small Model D1D1D1 7 | ./stl10_single.sh --augment --deltaorth --model e2wrn16_8_stl --restrict 0 8 | 9 | # Flip Equivariant Large Model D1D1D1 10 | ./stl10_single.sh --augment --deltaorth --model e2wrn16_8_stl --restrict 0 --fixparams 11 | 12 | # Equivariant Small Model D8D4D1 13 | ./stl10_single.sh --augment --deltaorth --model e2wrn16_8_stl --restrict 3 14 | 15 | # Equivariant Large Model D8D4D1 16 | ./stl10_single.sh --augment --deltaorth --model e2wrn16_8_stl --restrict 3 --fixparams --split 17 | 18 | # Conventional Model (with DeltaOrthogonal) 19 | ./stl10_single.sh --augment --deltaorth --model wrn16_8_stl 20 | 21 | done 22 | -------------------------------------------------------------------------------- /experiments/mnist_restrict.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | INTERPOLATION=3 4 | # frequency cut-off policy; use the default one 5 | F="None" 6 | 7 | # layers where to perform restriction 8 | res=("1" "2" "3" "4" "5") 9 | 10 | for i in $(seq 1 5) 11 | do 12 | for r in "${res[@]}" 13 | do 14 | # D_16 -> C_16 model on MNIST rot 15 | ./mnist_bench_single.sh --S 1 --type regular --dataset "mnist_rot" --N 16 --flip --sgid 16 --restrict $r --F "$F" --sigma "None" --fixparams 16 | # D_16 -> C_1={e} model on MNIST 12k 17 | ./mnist_bench_single.sh --S 1 --type regular --dataset "mnist12k" --N 16 --flip --sgid 1 --restrict $r --F "$F" --sigma "None" --fixparams 18 | # C_16 -> C_1={e} model on MNIST 12k 19 | ./mnist_bench_single.sh --S 1 --type regular --dataset "mnist12k" --N 16 --sgid 1 --restrict $r --F "$F" --sigma "None" --fixparams 20 | done 21 | done 22 | -------------------------------------------------------------------------------- /experiments/datasets/mnist12k/convert.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import numpy as np 4 | 5 | np.random.seed(42) 6 | 7 | f = open("mnist_test.amat", "r") 8 | 9 | test = [] 10 | 11 | for line in f: 12 | test.append([float(x) for x in line.split()]) 13 | 14 | test = np.array(test) 15 | 16 | 17 | f = open("mnist_train.amat", "r") 18 | 19 | trainval = [] 20 | 21 | for line in f: 22 | trainval.append([float(x) for x in line.split()]) 23 | 24 | trainval = np.array(trainval) 25 | 26 | train = trainval[:10000, :].copy() 27 | valid = trainval[10000:, :].copy() 28 | 29 | 
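# each row of the .amat file stores the 28x28 = 784 pixel intensities followed by the
# class label in the last column; the first 10000 rows of the 12k training set are kept
# as the train split and the remaining rows as the validation split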
np.savez("mnist_trainval", images=trainval[:, :-1].reshape(-1, 28, 28), labels=trainval[:, -1]) 30 | np.savez("mnist_test", images=test[:, :-1].reshape(-1, 28, 28), labels=test[:, -1]) 31 | np.savez("mnist_train", images=train[:, :-1].reshape(-1, 28, 28), labels=train[:, -1]) 32 | np.savez("mnist_valid", images=valid[:, :-1].reshape(-1, 28, 28), labels=valid[:, -1]) 33 | 34 | del train 35 | del valid 36 | del test 37 | 38 | np.random.shuffle(trainval) 39 | 40 | train = trainval[:10000, :].copy() 41 | valid = trainval[10000:, :].copy() 42 | 43 | np.savez("mnist_train_shuffled", images=train[:, :-1].reshape(-1, 28, 28), labels=train[:, -1]) 44 | np.savez("mnist_valid_shuffled", images=valid[:, :-1].reshape(-1, 28, 28), labels=valid[:, -1]) 45 | 46 | -------------------------------------------------------------------------------- /experiments/datasets/mnist_rot/convert.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import numpy as np 4 | 5 | 6 | f = open("mnist_all_rotation_normalized_float_test.amat", "r") 7 | 8 | test = [] 9 | 10 | for line in f: 11 | test.append([float(x) for x in line.split()]) 12 | 13 | test = np.array(test) 14 | 15 | 16 | f = open("mnist_all_rotation_normalized_float_train_valid.amat", "r") 17 | 18 | trainval = [] 19 | 20 | for line in f: 21 | trainval.append([float(x) for x in line.split()]) 22 | 23 | trainval = np.array(trainval) 24 | 25 | train = trainval[:10000, :].copy() 26 | valid = trainval[10000:, :].copy() 27 | 28 | np.savez("mnist_rot_trainval", images=trainval[:, :-1].reshape(-1, 28, 28), labels=trainval[:, -1]) 29 | np.savez("mnist_rot_test", images=test[:, :-1].reshape(-1, 28, 28), labels=test[:, -1]) 30 | np.savez("mnist_rot_train", images=train[:, :-1].reshape(-1, 28, 28), labels=train[:, -1]) 31 | np.savez("mnist_rot_valid", images=valid[:, :-1].reshape(-1, 28, 28), labels=valid[:, -1]) 32 | 33 | del train 34 | del valid 35 | del test 36 | 37 | np.random.shuffle(trainval) 38 | 39 | train = trainval[:10000, :].copy() 40 | valid = trainval[10000:, :].copy() 41 | 42 | np.savez("mnist_rot_train_shuffled", images=train[:, :-1].reshape(-1, 28, 28), labels=train[:, -1]) 43 | np.savez("mnist_rot_valid_shuffled", images=valid[:, :-1].reshape(-1, 28, 28), labels=valid[:, -1]) 44 | 45 | -------------------------------------------------------------------------------- /experiments/multiple_exps.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import random 4 | from experiment import run_experiment 5 | from plot_exps import plot 6 | import utils 7 | import torch 8 | 9 | 10 | def main(config): 11 | 12 | torch.set_default_dtype(torch.float32) 13 | 14 | seed_sampler = random.Random() 15 | 16 | nexp = config.S 17 | del config.S 18 | 19 | logsfile = utils.logs_path(config) 20 | plotpath = utils.plot_path(config) 21 | 22 | for i in range(nexp): 23 | config.seed = seed_sampler.randint(0, 10000) 24 | run_experiment(config) 25 | plot(logsfile, plotpath) 26 | 27 | 28 | ######################################################################################################################## 29 | ######################################################################################################################## 30 | 31 | 32 | if __name__ == "__main__": 33 | # Parse training configuration 34 | parser = argparse.ArgumentParser() 35 | 36 | ######## Number of experiments ########## 37 | parser.add_argument('--S', type=int, help='Number of different experiments 
(different Seeds)') 38 | 39 | ######## EXPERIMENT'S PARAMETERS ######## 40 | parser = utils.args_exp_parameters(parser) 41 | 42 | config = parser.parse_args() 43 | main(config) 44 | 45 | -------------------------------------------------------------------------------- /experiments/print_results.py: -------------------------------------------------------------------------------- 1 | 2 | import glob 3 | import os.path 4 | 5 | import sqlite3 6 | import pandas as pd 7 | 8 | 9 | def retrieve_logs(path: str) -> pd.DataFrame: 10 | conn = sqlite3.connect(path) 11 | logs = pd.read_sql_query("select * from logs;", conn) 12 | conn.close() 13 | 14 | return logs 15 | 16 | 17 | for folder in glob.iglob("./results/*/"): 18 | dataset = os.path.basename(os.path.normpath(folder)) 19 | 20 | print('###########################################################################################################') 21 | print(f"DATASET: {dataset}") 22 | 23 | exps = [] 24 | 25 | for db in glob.iglob(os.path.join(folder, "*.db")): 26 | model = os.path.splitext(os.path.basename(db))[0] 27 | 28 | # if model.endswith("confusion"): 29 | # continue 30 | 31 | logs = retrieve_logs(db) 32 | 33 | logs = logs[logs.split == "test"] 34 | last_iter = logs.iteration.max() 35 | logs = logs[logs.iteration == last_iter].groupby("seed").first() 36 | accuracies = logs.accuracy 37 | 38 | errors = 100.0 - 100.0 * accuracies 39 | 40 | e = "{:<55} | min: {:.5f}; mean: {:.5f}; std: {:.5f} | samples: {}".format(model, errors.min(), errors.mean(), errors.std(), len(accuracies)) 41 | exps.append(e) 42 | 43 | for exp in sorted(exps): 44 | print(exp) 45 | 46 | 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /experiments/cifar_experiments.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | for i in $(seq 1 3) 5 | do 6 | 7 | # small D8D4D1 models 8 | ./cifar_single.sh --fixparams --model e2wrn28_7 --restrict 3 --dataset cifar10 9 | ./cifar_single.sh --fixparams --model e2wrn28_7 --restrict 3 --dataset cifar100 10 | 11 | # big D8D4D1 models 12 | ./cifar_single.sh --fixparams --model e2wrn28_10 --restrict 3 --dataset cifar10 13 | ./cifar_single.sh --fixparams --model e2wrn28_10 --restrict 3 --dataset cifar100 14 | 15 | # D8D4D4 models 16 | ./cifar_single.sh --fixparams --model e2wrn28_10 --restrict 1 --dataset cifar10 17 | ./cifar_single.sh --fixparams --model e2wrn28_10 --restrict 1 --dataset cifar100 18 | 19 | # C8C4C1 models 20 | ./cifar_single.sh --fixparams --model e2wrn28_10R --restrict 3 --dataset cifar10 21 | ./cifar_single.sh --fixparams --model e2wrn28_10R --restrict 3 --dataset cifar100 22 | 23 | # D1D1D1 models 24 | ./cifar_single.sh --fixparams --model e2wrn28_10 --restrict 0 --dataset cifar10 25 | ./cifar_single.sh --fixparams --model e2wrn28_10 --restrict 0 --dataset cifar100 26 | 27 | 28 | # AutoAugment Experiments 29 | 30 | # small D8D4D1 models + AA 31 | ./cifar_single.sh --fixparams --model e2wrn28_7 --restrict 3 --augment --dataset cifar10 32 | ./cifar_single.sh --fixparams --model e2wrn28_7 --restrict 3 --augment --dataset cifar100 33 | 34 | # big D8D4D1 models + AA 35 | ./cifar_single.sh --fixparams --model e2wrn28_10 --restrict 3 --augment --dataset cifar10 36 | ./cifar_single.sh --fixparams --model e2wrn28_10 --restrict 3 --augment --dataset cifar100 37 | 38 | done 39 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | 2 | Copyright (c) 2021 Qualcomm Innovation Center, Inc. 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted (subject to the limitations in the 7 | disclaimer below) provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright 10 | notice, this list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright 13 | notice, this list of conditions and the following disclaimer in the 14 | documentation and/or other materials provided with the distribution. 15 | 16 | * Neither the name of Qualcomm Innovation Center, Inc. nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE 21 | GRANTED BY THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT 22 | HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED 23 | WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF 24 | MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 25 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 26 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 27 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 28 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR 29 | BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 30 | WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE 31 | OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN 32 | IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 33 | 34 | 35 | -------------------------------------------------------------------------------- /experiments/cifar_single.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -p gpu 4 | #SBATCH -N 1 5 | 6 | 7 | ################ ignore this ############ SBATCH --time=30:00:00 8 | 9 | #module load CUDA/9.0.176 10 | #module load CUDA/10.0.130 11 | #source activate e2cnn 12 | 13 | 14 | SEEDS=1 15 | N="8" 16 | 17 | #MODEL="e2wrn28_10" 18 | #MODEL="e2wrn28_10R" 19 | #MODEL="e2wrn28_7" 20 | MODEL="e2wrn28_7R" 21 | 22 | RESTRICT="3" 23 | DATASET="cifar10" 24 | WD="0.0005" 25 | FIXPARAMS=0 26 | AUGMENT=0 27 | 28 | while [[ $# -gt 0 ]] 29 | do 30 | key="$1" 31 | 32 | case $key in 33 | --N) 34 | N="$2" 35 | shift 36 | ;; 37 | --restrict) 38 | RESTRICT="$2" 39 | shift 40 | ;; 41 | --weight_decay) 42 | WD="$2" 43 | shift 44 | ;; 45 | --model) 46 | MODEL="$2" 47 | shift 48 | ;; 49 | --dataset) 50 | DATASET="$2" 51 | shift 52 | ;; 53 | --fixparams) 54 | FIXPARAMS=1 55 | shift 56 | ;; 57 | --augment) 58 | AUGMENT=1 59 | shift 60 | ;; 61 | --S) 62 | SEEDS="$2" 63 | shift 64 | ;; 65 | *) # unknown option 66 | shift # past argument 67 | ;; 68 | esac 69 | done 70 | 71 | 72 | PARAMS="--dataset=$DATASET --model=$MODEL --N=$N --restrict=$RESTRICT --F=1. 
--sigma=0.45" 73 | TRAIN_PARAMS="--adapt_lr=exponential --epochs=200 --lr=0.1 --batch_size=128 --optimizer=SGD --momentum=0.9 --weight_decay=$WD --eval_frequency=-1 --backup_model --no_earlystop --eval_test" 74 | TRAIN_PARAMS="$TRAIN_PARAMS --lr_decay_epoch=60 --lr_decay_factor=0.2 --lr_decay_start=0" 75 | #TRAIN_PARAMS="$TRAIN_PARAMS --lr_decay_schedule 60 120 180 --lr_decay_factor=0.2" 76 | 77 | PARAMS="$PARAMS --deltaorth" 78 | 79 | if [ "$FIXPARAMS" -eq "1" ]; then 80 | PARAMS="$PARAMS --fixparams" 81 | fi 82 | 83 | if [ "$AUGMENT" -eq "1" ]; then 84 | TRAIN_PARAMS="$TRAIN_PARAMS --augment" 85 | fi 86 | 87 | echo $DATASET 88 | 89 | # python count_parameters.py $PARAMS $TRAIN_PARAMS 90 | python multiple_exps.py --S=$SEEDS $PARAMS $TRAIN_PARAMS 91 | #python -O multiple_exps.py --S=$SEEDS $PARAMS $TRAIN_PARAMS 92 | 93 | 94 | 95 | -------------------------------------------------------------------------------- /experiments/datasets/mnist_fliprot/convert.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import numpy as np 4 | 5 | np.random.seed(42) 6 | 7 | def preprocess(dataset, flip_all=False): 8 | 9 | images = dataset[:, :-1].reshape(-1, 28, 28) 10 | labels = dataset[:, -1] 11 | 12 | if flip_all: 13 | # augment the dataset with a flipped copy of each datapoint 14 | flipped_images = images[:, :, ::-1] 15 | 16 | images = np.concatenate([images, flipped_images]) 17 | labels = np.concatenate([labels, labels]) 18 | else: 19 | # for each datapoint, we choose randomly whether to flip it or not 20 | idxs = np.random.binomial(1, 0.5, dataset.shape[0]) 21 | 22 | images[idxs, ...] = images[idxs, :, ::-1] 23 | 24 | return {"images": images, "labels": labels} 25 | 26 | f = open("../mnist_rot/mnist_all_rotation_normalized_float_test.amat", "r") 27 | 28 | test = [] 29 | 30 | for line in f: 31 | test.append([float(x) for x in line.split()]) 32 | 33 | test = np.array(test) 34 | np.savez("mnist_fliprot_test", **preprocess(test, flip_all=True)) 35 | del test 36 | 37 | f = open("../mnist_rot/mnist_all_rotation_normalized_float_train_valid.amat", "r") 38 | 39 | trainval = [] 40 | 41 | for line in f: 42 | trainval.append([float(x) for x in line.split()]) 43 | 44 | npoints = len(trainval) 45 | 46 | trainval = np.array(trainval) 47 | 48 | trainval = preprocess(trainval) 49 | 50 | np.savez("mnist_fliprot_trainval", **trainval) 51 | np.savez("mnist_fliprot_train", images=trainval["images"][:10000, ...], labels=trainval["labels"][:10000, ...]) 52 | np.savez("mnist_fliprot_valid", images=trainval["images"][10000:, ...], labels=trainval["labels"][10000:, ...]) 53 | 54 | idxs = np.arange(npoints) 55 | np.random.shuffle(idxs) 56 | trainval["images"] = trainval["images"][idxs, ...] 
57 | trainval["labels"] = trainval["labels"][idxs] 58 | 59 | np.savez("mnist_fliprot_train_shuffled", images=trainval["images"][:10000, ...], labels=trainval["labels"][:10000, ...]) 60 | np.savez("mnist_fliprot_valid_shuffled", images=trainval["images"][10000:, ...], labels=trainval["labels"][10000:, ...]) 61 | 62 | 63 | -------------------------------------------------------------------------------- /experiments/retrieve_model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import utils 3 | import torch 4 | 5 | from collections import OrderedDict 6 | 7 | ############################################################################################### 8 | # retrieve a stored model 9 | # use the usual command line parameters to specify the experiment that generated such model 10 | # add the "--seed " to specify the seed used to run that experiment 11 | # in order to discriminate between different run of the same experiments 12 | ############################################################################################### 13 | 14 | def retrieve(config): 15 | 16 | path = utils.backup_path(config) 17 | 18 | device = torch.device('cuda' if torch.cuda.is_available() else "cpu") 19 | 20 | # N.B.: the saved file contains the model wrapped in DataParallel 21 | 22 | if device == "cpu": 23 | state_dict = torch.load(path, map_location=device) 24 | else: 25 | state_dict = torch.load(path) 26 | 27 | _, n_inputs, n_outputs = utils.build_dataloaders( 28 | config.dataset, 29 | config.batch_size, 30 | config.workers, 31 | config.augment, 32 | config.earlystop, 33 | config.reshuffle 34 | ) 35 | model = utils.build_model(config, n_inputs, n_outputs) 36 | 37 | # filter out `module.` prefix (coming from DataParallel) 38 | new_state_dict = OrderedDict() 39 | for k, v in state_dict.items(): 40 | name = k[7:] # remove `module.` 41 | new_state_dict[name] = v 42 | del state_dict 43 | 44 | model = model.to(device) 45 | # load params 46 | model.load_state_dict(new_state_dict) 47 | 48 | return model 49 | 50 | 51 | if __name__ == "__main__": 52 | 53 | # Parse training configuration 54 | parser = argparse.ArgumentParser() 55 | 56 | parser = utils.args_exp_parameters(parser) 57 | 58 | parser.add_argument('--seed', type=int, help='Seed of the experiment') 59 | parser.add_argument('--output', type=str, default=None, help="Path where to store the extracted model") 60 | 61 | config = parser.parse_args() 62 | 63 | # Train the model 64 | model = retrieve(config) 65 | 66 | if config.output is not None: 67 | torch.save(model.state_dict(), config.output) 68 | -------------------------------------------------------------------------------- /experiments/mnist_final_single.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SEEDS=1 4 | N="16" 5 | MODEL="E2SFCNN" 6 | RESTRICT="0" 7 | DATASET="mnist_rot" 8 | J="-1" 9 | SGID="" 10 | F="None" 11 | sigma="None" 12 | 13 | FLIP=0 14 | 15 | FIXPARAMS=0 16 | DELTAORTH=0 17 | 18 | INTERPOLATION=0 19 | 20 | 21 | while [[ $# -gt 0 ]] 22 | do 23 | key="$1" 24 | 25 | case $key in 26 | --N) 27 | N="$2" 28 | shift 29 | shift 30 | ;; 31 | --model) 32 | MODEL="$2" 33 | shift 34 | shift 35 | ;; 36 | --J) 37 | J="$2" 38 | shift 39 | shift 40 | ;; 41 | --F) 42 | F="$2" 43 | shift 44 | shift 45 | ;; 46 | --sigma) 47 | sigma="$2" 48 | shift 49 | shift 50 | ;; 51 | --restrict) 52 | RESTRICT="$2" 53 | shift 54 | shift 55 | ;; 56 | --flip) 57 | FLIP=1 58 | shift 59 | ;; 60 | --fixparams) 61 | 
FIXPARAMS=1 62 | shift 63 | ;; 64 | --sgid) 65 | SGID="$2" 66 | shift 67 | shift 68 | ;; 69 | --deltaorth) 70 | DELTAORTH=1 71 | shift 72 | ;; 73 | --interpolation) 74 | INTERPOLATION="$2" 75 | shift 76 | shift 77 | ;; 78 | --dataset) 79 | DATASET="$2" 80 | shift 81 | shift 82 | ;; 83 | --S) 84 | SEEDS="$2" 85 | shift 86 | shift 87 | ;; 88 | *) # unknown option 89 | shift # past argument 90 | ;; 91 | esac 92 | done 93 | 94 | 95 | PARAMS="--dataset=$DATASET --model=$MODEL --N=$N --restrict=$RESTRICT --F=$F --sigma=$sigma --epochs=40 --lr=0.015 --batch_size=64 --augment --time_limit=300 --verbose=2 --optimizer=sfcnn --l1 --adapt_lr=exponential --lr_decay_start=15 --lr_decay_factor=0.8 --lr_decay_epoch=1 --no_earlystop" 96 | PARAMS="$PARAMS --interpolation=$INTERPOLATION" 97 | 98 | if [ "$SGID" != "" ]; then 99 | PARAMS="$PARAMS --sgsize=$SGID" 100 | fi 101 | 102 | if [ "$FLIP" -eq "1" ]; then 103 | PARAMS="$PARAMS --flip" 104 | fi 105 | 106 | if [ "$FIXPARAMS" -eq "1" ]; then 107 | PARAMS="$PARAMS --fixparams" 108 | fi 109 | 110 | if [ "$DELTAORTH" -eq "1" ]; then 111 | PARAMS="$PARAMS --deltaorth" 112 | fi 113 | 114 | 115 | if [ "$J" -ne "-1" ]; then 116 | PARAMS="$PARAMS --J=$J" 117 | fi 118 | 119 | echo $PARAMS 120 | 121 | python multiple_exps.py --S=$SEEDS $PARAMS 122 | #python count_parameters.py $PARAMS 123 | 124 | 125 | 126 | -------------------------------------------------------------------------------- /experiments/count_parameters.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import argparse 4 | import utils 5 | 6 | import datetime 7 | 8 | SHOW_PLOT = False 9 | SAVE_PLOT = True 10 | RESHUFFLE = False 11 | AUGMENT_TRAIN = False 12 | LEARNING_RATE = 1e-4 13 | BATCH_SIZE = 64 14 | EPOCHS = 40 15 | PLOT_FREQ = 100 16 | EVAL_FREQ = 100 17 | BACKUP = False 18 | BACKUP_FREQ = -1 19 | 20 | 21 | def count_params(config): 22 | 23 | _, n_inputs, n_outputs = utils.build_dataloaders( 24 | config.dataset, 25 | config.batch_size, 26 | config.workers, 27 | config.augment, 28 | config.earlystop, 29 | config.reshuffle 30 | ) 31 | expname = utils.exp_name(config) 32 | 33 | model = utils.build_model(config, n_inputs, n_outputs) 34 | 35 | nparams = sum([p.numel() for p in model.parameters() if p.requires_grad]) 36 | 37 | totmemory = sum([p.numel() * p.element_size() for p in model.parameters() if p.requires_grad]) 38 | totmemory += sum([p.numel() * p.element_size() for p in model.buffers()]) 39 | totmemory //= 1024 ** 2 40 | 41 | print("Total Parameters: {:<15} | Total Memory (MB): {:<15}".format(nparams, totmemory)) 42 | 43 | # for i, (name, mod) in enumerate(model.named_modules()): 44 | # print("\t", i, el.__class__, el.in_type.size, el.out_type.size) 45 | 46 | # mem = sum([p.numel() * p.element_size() for p in mod.parameters(recurse=False)]) 47 | # mem += sum([p.numel() * p.element_size() for p in mod.buffers(recurse=False)]) 48 | 49 | # mem //= 1024**2 50 | 51 | # print(f"\t{i}: {name}", mod, mem) 52 | 53 | return expname, nparams 54 | 55 | 56 | ################################################################################ 57 | ################################################################################ 58 | 59 | 60 | if __name__ == "__main__": 61 | # Parse training configuration 62 | parser = argparse.ArgumentParser() 63 | 64 | ######## EXPERIMENT'S PARAMETERS ######## 65 | parser = utils.args_exp_parameters(parser) 66 | 67 | config = parser.parse_args() 68 | 69 | print("----------------------------------------------------------") 70 | 
print(datetime.datetime.now()) 71 | 72 | expname, nparams = count_params(config) 73 | 74 | print(f"{expname}:\t{nparams} parameters") 75 | 76 | print(datetime.datetime.now()) 77 | print("----------------------------------------------------------") 78 | -------------------------------------------------------------------------------- /experiments/stl10_single.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | SEEDS=1 5 | N="8" 6 | 7 | #MODEL="e2wrn16_8" 8 | MODEL="wrn16_8" 9 | 10 | RESTRICT="3" 11 | DATASET="STL10cif" 12 | WD="0.0005" 13 | F="1." 14 | SIGMA="None" 15 | FIXPARAMS=0 16 | AUGMENT=0 17 | DELTAORTH=0 18 | SPLIT=0 19 | VALIDATE=0 20 | 21 | while [[ $# -gt 0 ]] 22 | do 23 | key="$1" 24 | 25 | case $key in 26 | --N) 27 | N="$2" 28 | shift 29 | ;; 30 | --restrict) 31 | RESTRICT="$2" 32 | shift 33 | ;; 34 | --weight_decay) 35 | WD="$2" 36 | shift 37 | ;; 38 | --F) 39 | F="$2" 40 | shift 41 | ;; 42 | --sigma) 43 | SIGMA="$2" 44 | shift 45 | ;; 46 | --model) 47 | MODEL="$2" 48 | shift 49 | ;; 50 | --dataset) 51 | DATASET="$2" 52 | shift 53 | ;; 54 | --fixparams) 55 | FIXPARAMS=1 56 | shift 57 | ;; 58 | --augment) 59 | AUGMENT=1 60 | shift 61 | ;; 62 | --deltaorth) 63 | DELTAORTH=1 64 | shift 65 | ;; 66 | --split) 67 | SPLIT=1 68 | shift 69 | ;; 70 | --validate) 71 | VALIDATE=1 72 | shift 73 | ;; 74 | --S) 75 | SEEDS="$2" 76 | shift 77 | ;; 78 | *) # unknown option 79 | shift # past argument 80 | ;; 81 | esac 82 | done 83 | 84 | 85 | PARAMS="--dataset=$DATASET --model=$MODEL --N=$N --restrict=$RESTRICT --F=$F --sigma=$SIGMA" 86 | TRAIN_PARAMS="--adapt_lr=exponential --epochs=1000 --lr=0.1 --optimizer=SGD --momentum=0.9 --weight_decay=$WD --eval_frequency=-5" 87 | TRAIN_PARAMS="$TRAIN_PARAMS --lr_decay_schedule 300 400 600 800 --lr_decay_factor=0.2" 88 | 89 | if [ "$VALIDATE" -eq "0" ]; then 90 | TRAIN_PARAMS="$TRAIN_PARAMS --no_earlystop --eval_test" 91 | fi 92 | 93 | if [ "$SPLIT" -eq "1" ]; then 94 | TRAIN_PARAMS="$TRAIN_PARAMS --eval_batch_size=128 --batch_size=64 --accumulate=2" 95 | else 96 | TRAIN_PARAMS="$TRAIN_PARAMS --eval_batch_size=128 --batch_size=128 --accumulate=1" 97 | fi 98 | 99 | if [ "$DELTAORTH" -eq "1" ]; then 100 | TRAIN_PARAMS="$TRAIN_PARAMS --deltaorth" 101 | fi 102 | 103 | if [ "$FIXPARAMS" -eq "1" ]; then 104 | PARAMS="$PARAMS --fixparams" 105 | fi 106 | 107 | if [ "$AUGMENT" -eq "1" ]; then 108 | TRAIN_PARAMS="$TRAIN_PARAMS --augment" 109 | fi 110 | 111 | TRAIN_PARAMS="$TRAIN_PARAMS --store_plot --plot_frequency=-5 --backup_model --verbose=4" 112 | 113 | 114 | echo $PARAMS 115 | echo $TRAIN_PARAMS 116 | 117 | #python -O multiple_exps.py --S=$SEEDS $PARAMS $TRAIN_PARAMS 118 | python multiple_exps.py --S=$SEEDS $PARAMS $TRAIN_PARAMS 119 | #python count_parameters.py $PARAMS $TRAIN_PARAMS 120 | 121 | 122 | 123 | -------------------------------------------------------------------------------- /experiments/mnist_bench_single.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SEEDS=1 4 | N="16" 5 | MODEL="EXP" 6 | TYPE="regular" 7 | RESTRICT="0" 8 | DATASET="mnist_rot" 9 | J="-1" 10 | F="None" 11 | sigma="None" 12 | SGID="" 13 | 14 | FIXPARAMS=0 15 | DELTAORTH=0 16 | FLIP=0 17 | INTERPOLATION=2 18 | REGULARIZE=0 19 | 20 | while [[ $# -gt 0 ]] 21 | do 22 | key="$1" 23 | 24 | case $key in 25 | --N) 26 | N="$2" 27 | shift 28 | shift 29 | ;; 30 | --J) 31 | J="$2" 32 | shift 33 | shift 34 | ;; 35 | --F) 36 | F="$2" 37 | shift 38 | shift 39 | ;; 40 | 
--sgid) 41 | SGID="$2" 42 | shift 43 | shift 44 | ;; 45 | --sigma) 46 | sigma="$2" 47 | shift 48 | shift 49 | ;; 50 | --restrict) 51 | RESTRICT="$2" 52 | shift 53 | shift 54 | ;; 55 | --flip) 56 | FLIP=1 57 | shift 58 | ;; 59 | --regularize) 60 | REGULARIZE=1 61 | shift 62 | ;; 63 | --fixparams) 64 | FIXPARAMS=1 65 | shift 66 | ;; 67 | --deltaorth) 68 | DELTAORTH=1 69 | shift 70 | ;; 71 | --interpolation) 72 | INTERPOLATION="$2" 73 | shift 74 | shift 75 | ;; 76 | --model) 77 | MODEL="$2" 78 | shift 79 | shift 80 | ;; 81 | --type) 82 | TYPE="$2" 83 | shift 84 | shift 85 | ;; 86 | --dataset) 87 | DATASET="$2" 88 | shift 89 | shift 90 | ;; 91 | --S) 92 | SEEDS="$2" 93 | shift 94 | shift 95 | ;; 96 | *) # unknown option 97 | shift # past argument 98 | ;; 99 | esac 100 | done 101 | 102 | 103 | PARAMS="--dataset=$DATASET --model=$MODEL --type=$TYPE --N=$N --restrict=$RESTRICT --F=$F --sigma=$sigma --interpolation=$INTERPOLATION --epochs=30 --lr=0.001 --batch_size=64 --augment --time_limit=300 --verbose=2 --adapt_lr=exponential --lr_decay_start=10 --reshuffle" #--no_earlystop 104 | 105 | if [ "$SGID" != "" ]; then 106 | PARAMS="$PARAMS --sgsize=$SGID" 107 | fi 108 | 109 | if [ "$FLIP" -eq "1" ]; then 110 | PARAMS="$PARAMS --flip" 111 | fi 112 | 113 | if [ "$REGULARIZE" -eq "1" ]; then 114 | PARAMS="$PARAMS --weight_decay=0.0 --optimizer=sfcnn --lamb_fully_L2=0.0000001 --lamb_conv_L2=0.0000001 --lamb_bn_L2=0 --lamb_softmax_L2=0" 115 | else 116 | PARAMS="$PARAMS --weight_decay=0.0 --optimizer=Adam" 117 | PARAMS="$PARAMS --lamb_fully_L2=0.0 --lamb_conv_L2=0.0 --lamb_bn_L2=0 --lamb_softmax_L2=0" 118 | fi 119 | 120 | if [ "$FIXPARAMS" -eq "1" ]; then 121 | PARAMS="$PARAMS --fixparams" 122 | fi 123 | 124 | if [ "$DELTAORTH" -eq "1" ]; then 125 | PARAMS="$PARAMS --deltaorth" 126 | fi 127 | 128 | 129 | if [ "$J" -ne "-1" ]; then 130 | PARAMS="$PARAMS --J=$J" 131 | fi 132 | 133 | echo $PARAMS 134 | 135 | #python3.7 -O multiple_exps.py --S=$SEEDS $PARAMS 136 | python multiple_exps.py --S=$SEEDS $PARAMS 137 | 138 | 139 | 140 | -------------------------------------------------------------------------------- /experiments/datasets/mnist/data_loader_mnist.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision.datasets as dset 4 | import torchvision.transforms as transforms 5 | from torch.utils.data.sampler import SubsetRandomSampler 6 | 7 | 8 | def build_mnist_loader(mode, batch_size, num_workers=8, augment=False, reshuffle_seed=None): 9 | """ """ 10 | 11 | assert mode in ['train', 'valid', 'trainval', 'test'] 12 | assert reshuffle_seed is None or (mode != "test" and mode != 'trainval') 13 | 14 | rot_trans = transforms.Compose([transforms.RandomRotation(degrees=180, resample=False), 15 | transforms.ToTensor(), 16 | transforms.Normalize((0.5,), (1.0,))]) 17 | 18 | trans = transforms.Compose([transforms.ToTensor(), 19 | transforms.Normalize((0.5,), (1.0,))]) 20 | 21 | if mode == "test": 22 | # if doesn't exist, download mnist dataset 23 | test_set = dset.MNIST(root='./datasets/mnist/', train=False, transform=trans, download=True) 24 | 25 | loader = torch.utils.data.DataLoader( 26 | test_set, 27 | batch_size=batch_size, 28 | num_workers=num_workers, 29 | pin_memory=True 30 | ) 31 | else: 32 | 33 | # if doesn't exist, download mnist dataset 34 | train_set = dset.MNIST(root='./datasets/mnist/', train=True, transform=trans, download=True) 35 | 36 | if mode in ["valid", "train"]: 37 | 38 | valid_set = 
dset.MNIST(root='./datasets/mnist/', train=True, transform=trans, download=True) 39 | num_train = len(train_set) 40 | indices = list(range(num_train)) 41 | split = int(np.floor(num_train * 5/6)) 42 | 43 | if reshuffle_seed is not None: 44 | rng = np.random.RandomState(reshuffle_seed) 45 | rng.shuffle(indices) 46 | 47 | train_idx, valid_idx = indices[:split], indices[split:] 48 | train_sampler = SubsetRandomSampler(train_idx) 49 | valid_sampler = SubsetRandomSampler(valid_idx) 50 | if mode == "train": 51 | loader = torch.utils.data.DataLoader( 52 | train_set, batch_size=batch_size, sampler=train_sampler, 53 | num_workers=num_workers, pin_memory=True 54 | ) 55 | else: 56 | loader = torch.utils.data.DataLoader( 57 | valid_set, batch_size=batch_size, sampler=valid_sampler, 58 | num_workers=num_workers, pin_memory=True 59 | ) 60 | else: 61 | # mode == "trainval" 62 | loader = torch.utils.data.DataLoader( 63 | train_set, batch_size=batch_size, 64 | num_workers=num_workers, pin_memory=True 65 | ) 66 | 67 | n_inputs = 1 68 | n_outputs = 10 69 | 70 | return loader, n_inputs, n_outputs 71 | 72 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Experiments for General E(2)-Equivariant Steerable CNNs 2 | -------------------------------------------------------------------------------- 3 | **[Paper](https://arxiv.org/abs/1911.08251)** | **[Library](https://github.com/QUVA-Lab/e2cnn)** 4 | 5 | 6 | ## Getting Started - Environment 7 | 8 | 9 | First, you can set up a Conda environment containing some packages required 10 | 11 | ``` 12 | conda create --name e2exp python=3.6 13 | source activate e2exp 14 | 15 | conda install -y pytorch=1.3 torchvision cudatoolkit=10.0 -c pytorch 16 | conda install -y -c conda-forge matplotlib 17 | conda install -y scipy=1.5 pandas scikit-learn=0.23 18 | conda install -y -c anaconda sqlite 19 | ``` 20 | 21 | Now, we add the [e2cnn](https://github.com/QUVA-Lab/e2cnn) library. 22 | Since the environment has Python 3.6, we clone the [legacy_py3.6](https://github.com/QUVA-Lab/e2cnn/tree/legacy_py3.6) 23 | branch. 24 | 25 | NOTE: make sure you are in the `./experiments/` folder before running the following commands. 26 | 27 | ``` 28 | mkdir tmp_e2cnn 29 | cd tmp_e2cnn 30 | git clone --single-branch --branch legacy_py3.6 https://github.com/QUVA-Lab/e2cnn 31 | mv e2cnn/e2cnn ../e2cnn 32 | cd .. 33 | rm -rf tmp_e2cnn 34 | ``` 35 | 36 | If you use Python 3.7 or higher, you can install the library just using 37 | ``` 38 | pip install e2cnn 39 | ``` 40 | 41 | These commands are already included in the file [setting_up_env.sh](./experiments/setting_up_env.sh), so you can also just run 42 | ``` 43 | cd experiments 44 | ./setting_up_env.sh 45 | ``` 46 | 47 | ## Getting Started - Datasets 48 | 49 | To automatically download the MNIST variants datasets, you can run the following commands 50 | (assuming you are in the `./experiments/` folder): 51 | 52 | ``` 53 | cd datasets 54 | ./download_mnist.sh 55 | 56 | source activate e2exp 57 | 58 | cd mnist_rot 59 | python convert.py 60 | 61 | cd ../mnist_fliprot 62 | python convert.py 63 | 64 | cd ../mnist12k 65 | python convert.py 66 | 67 | ``` 68 | 69 | 70 | ## Getting Started - Experiments 71 | 72 | All the experiments can be run automatically through the following few scripts 73 | (assuming you are in the `./experiments/` folder). 
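All of these wrappers eventually call `multiple_exps.py`, which repeats a single configuration with `--S` different random seeds and logs each run. As a rough sketch, a single custom experiment can also be launched directly; the flags below are the ones assembled by `mnist_bench_single.sh` (the wrapper additionally sets regularization flags, so prefer the scripts for reproducing the paper's numbers):

```
python multiple_exps.py --S=1 --dataset=mnist_rot --model=EXP --type=regular \
    --N=16 --restrict=0 --F=None --sigma=None --interpolation=2 \
    --epochs=30 --lr=0.001 --batch_size=64 --augment --optimizer=Adam --weight_decay=0.0
```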
74 | 75 | 76 | To run all the model benchmarking experiments on transformed MNIST datasets: 77 | ``` 78 | ./mnist_bench.sh 79 | ``` 80 | 81 | To run the MNIST experiments with group restriction: 82 | ``` 83 | ./mnist_restrict.sh 84 | ``` 85 | 86 | To run the competitive MNIST experiments: 87 | ``` 88 | ./mnist_final.sh 89 | ``` 90 | 91 | To run the CIFAR10 and the CIFAR100 experiments: 92 | ``` 93 | ./cifar_experiments.sh 94 | ``` 95 | 96 | To run the experiments on the full STL10 dataset: 97 | ``` 98 | ./stl10_experiments.sh 99 | ``` 100 | 101 | To run the data ablation study on STL10 102 | ``` 103 | ./stl10_ablation.sh 104 | ``` 105 | 106 | You can find more details about the single experiments in each bash file. 107 | 108 | 109 | Experiments' logs and results are stored in a new `./results` folder. 110 | A summary of all experiments can be printed with the `print_results.py` script. 111 | 112 | 113 | ## Cite 114 | 115 | The development of the library and the experiments was part of the work done for our paper 116 | [General E(2)-Equivariant Steerable CNNs](https://arxiv.org/abs/1911.08251). 117 | Please cite this work if you use our code: 118 | 119 | ``` 120 | @inproceedings{e3cnn, 121 | title={{General E(2)-Equivariant Steerable CNNs}}, 122 | author={Weiler, Maurice and Cesa, Gabriele}, 123 | booktitle={Conference on Neural Information Processing Systems (NeurIPS)}, 124 | year={2019}, 125 | } 126 | ``` 127 | 128 | Feel free to [contact us](mailto:cesa.gabriele@gmail.com,m.weiler@uva.nl). 129 | 130 | ## License 131 | 132 | This code and the *e2cnn* library are distributed under BSD Clear license. See LICENSE file. 133 | -------------------------------------------------------------------------------- /experiments/models/exp_cnn.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | from e2cnn.nn import * 4 | from e2cnn.group import * 5 | 6 | import torch 7 | import torch.nn as nn 8 | import numpy as np 9 | 10 | import datetime 11 | 12 | from scipy import stats 13 | 14 | 15 | class ExpCNN(torch.nn.Module): 16 | 17 | def __init__(self, n_channels, n_classes, 18 | fix_param: bool = False, 19 | deltaorth: bool = False 20 | ): 21 | 22 | super(ExpCNN, self).__init__() 23 | 24 | self.n_channels = n_channels 25 | self.n_classes = n_classes 26 | 27 | self.fix_param = fix_param 28 | 29 | layers = [] 30 | 31 | self.LAYER = 0 32 | channels = n_channels 33 | 34 | # 28 px 35 | # Convolutional Layer 1 36 | 37 | self.LAYER += 1 38 | # l, channels = self.layer_builder(channels, 16, 7, 0) 39 | l, channels = self.layer_builder(channels, 16, 7, 1) 40 | layers += l 41 | 42 | # Convolutional Layer 2 43 | self.LAYER += 1 44 | l, channels = self.layer_builder(channels, 24, 5, 2, 2) 45 | layers += l 46 | 47 | # 14 px 48 | # Convolutional Layer 3 49 | self.LAYER += 1 50 | l, channels = self.layer_builder(channels, 32, 5, 2) 51 | layers += l 52 | 53 | # Convolutional Layer 4 54 | self.LAYER += 1 55 | l, channels = self.layer_builder(channels, 32, 5, 2, 2) 56 | layers += l 57 | 58 | # 7 px 59 | 60 | # Convolutional Layer 5 61 | self.LAYER += 1 62 | l, channels = self.layer_builder(channels, 48, 5, 2) 63 | layers += l 64 | 65 | # Convolutional Layer 6 66 | self.LAYER += 1 67 | l, channels = self.layer_builder(channels, 64, 5, 0, None, True) 68 | layers += l 69 | 70 | # Adaptive Pooling 71 | mpl = nn.AdaptiveAvgPool2d(1) 72 | layers.append(mpl) 73 | 74 | # 1 px 75 | 76 | # c = 64 77 | 78 | self.layers = torch.nn.ModuleList(layers) 79 | 80 | # Fully 
Connected 81 | 82 | self.fully_net = nn.Sequential( 83 | nn.Linear(channels, 64), 84 | nn.BatchNorm1d(64), 85 | nn.ELU(inplace=True), 86 | nn.Linear(64, n_classes), 87 | ) 88 | 89 | if deltaorth: 90 | for name, module in self.named_modules(): 91 | if isinstance(module, nn.Conv2d): 92 | # delta orthogonal intialization for the Pytorch's 1x1 Conv 93 | o, i, w, h = module.weight.shape 94 | if o >= i: 95 | module.weight.data.fill_(0.) 96 | module.weight.data[:, :, w // 2, h // 2] = torch.tensor( 97 | stats.ortho_group.rvs(max(i, o))[:o, :i]) 98 | else: 99 | torch.nn.init.xavier_uniform_(module.weight.data, gain=torch.nn.init.calculate_gain('sigmoid')) 100 | 101 | print("MODEL TOPOLOGY:") 102 | for i, (name, mod) in enumerate(self.named_modules()): 103 | params = sum([p.numel() for p in mod.parameters() if p.requires_grad]) 104 | if isinstance(mod, nn.Conv2d): 105 | print(f"\t{i: <3} - {name: <70} | {params: <8} | {mod.weight.shape[1]: <4}- {mod.weight.shape[0]: <4}") 106 | else: 107 | print(f"\t{i: <3} - {name: <70} | {params: <8} |") 108 | tot_param = sum([p.numel() for p in self.parameters() if p.requires_grad]) 109 | print("Total number of parameters:", tot_param) 110 | 111 | def forward(self, input): 112 | x = input 113 | for layer in self.layers: 114 | x = layer(x) 115 | 116 | x = self.fully_net(x.reshape(x.shape[0], -1)) 117 | 118 | return x 119 | 120 | def layer_builder(self, channels, C: int, s: int, padding: int = 0, pooling: int = None, 121 | orientation_pooling: bool = False): 122 | 123 | if self.fix_param and not orientation_pooling and self.LAYER > 1: 124 | # if self.fix_param and self.LAYER > 1: 125 | # to keep number of parameters more or less constant when changing groups 126 | # (more precisely, we try to keep them close to the number of parameters in the original SFCNN) 127 | t = 1 / 16 128 | C = int(round(C / np.sqrt(t))) 129 | 130 | layers = [] 131 | 132 | cl = nn.Conv2d(channels, C, s, padding=padding, bias=False) 133 | layers.append(cl) 134 | 135 | bn = nn.BatchNorm2d(C) 136 | layers.append(bn) 137 | 138 | nnl = nn.ELU(inplace=True) 139 | layers.append(nnl) 140 | 141 | if pooling is not None: 142 | pl = nn.MaxPool2d(pooling) 143 | layers.append(pl) 144 | 145 | return layers, C 146 | -------------------------------------------------------------------------------- /experiments/models/example.py: -------------------------------------------------------------------------------- 1 | from e2cnn.nn import * 2 | from e2cnn.group import * 3 | 4 | import torch 5 | import torch.nn as nn 6 | import numpy as np 7 | 8 | 9 | class E2SFCNN(torch.nn.Module): 10 | 11 | def __init__(self, 12 | n_channels: int, 13 | n_classes: int, 14 | N: int = 16, 15 | restrict: int = -1, 16 | ): 17 | r""" 18 | 19 | Args: 20 | n_channels: number of channels in the input 21 | n_classes: number of output classes 22 | N: number of rotations of the equivariance group 23 | restrict: number of initial convolutional layers which are also flip equivariant. 24 | After these layers, we restrict to only rotation equivariance. 25 | By default (-1) the restriciton is never done so the model is flip and rotations equivariant. 
26 | If set to 0 the model is only rotation equivariant from the beginning 27 | """ 28 | 29 | super(E2SFCNN, self).__init__() 30 | 31 | assert N > 1 32 | 33 | self.n_channels = n_channels 34 | self.n_classes = n_classes 35 | self.N = N 36 | self.restrict = restrict 37 | 38 | gc = FlipRot2dOnR2(N) 39 | 40 | self.gc = gc 41 | 42 | self.LAYER = 0 43 | 44 | if self.restrict == self.LAYER: 45 | gc, _, _ = self.gc.restrict((None, N)) 46 | 47 | r1 = FieldType(gc, [gc.trivial_repr] * n_channels) 48 | 49 | eq_layers = [] 50 | 51 | # 28 px 52 | # Convolutional Layer 1 53 | 54 | self.LAYER += 1 55 | eq_layers += self.build_layers(r1, 24, 9, 0, None) 56 | 57 | # Convolutional Layer 2 58 | self.LAYER += 1 59 | eq_layers += self.build_layers(eq_layers[-1].out_type, 32, 7, 3, 2) 60 | 61 | # 14 px 62 | # Convolutional Layer 3 63 | self.LAYER += 1 64 | eq_layers += self.build_layers(eq_layers[-1].out_type, 36, 7, 3, None) 65 | 66 | # Convolutional Layer 4 67 | self.LAYER += 1 68 | eq_layers += self.build_layers(eq_layers[-1].out_type, 36, 7, 3, 2) 69 | 70 | # 7 px 71 | 72 | # Convolutional Layer 5 73 | self.LAYER += 1 74 | eq_layers += self.build_layers(eq_layers[-1].out_type, 64, 7, 3) 75 | 76 | # Convolutional Layer 6 77 | self.LAYER += 1 78 | eq_layers += self.build_layers(eq_layers[-1].out_type, 96, 5, 0, None, True) 79 | 80 | # Adaptive Pooling 81 | mpl = PointwiseAdaptiveMaxPool(eq_layers[-1].out_type, 1) 82 | eq_layers.append(mpl) 83 | 84 | # 1 px 85 | 86 | # c = 96 87 | c = eq_layers[-1].out_type.size 88 | 89 | self.in_repr = eq_layers[0].in_type 90 | self.eq_layers = SequentialModule(*eq_layers) 91 | 92 | # Fully Connected 93 | 94 | self.fully_net = nn.Sequential( 95 | nn.Dropout(p=0.3), 96 | nn.Linear(c, 96), 97 | nn.BatchNorm1d(96), 98 | nn.ELU(inplace=True), 99 | 100 | nn.Dropout(p=0.3), 101 | nn.Linear(96, 96), 102 | nn.BatchNorm1d(96), 103 | nn.ELU(inplace=True), 104 | 105 | nn.Dropout(p=0.3), 106 | nn.Linear(96, n_classes), 107 | ) 108 | 109 | def forward(self, input): 110 | x = GeometricTensor(input, self.in_repr) 111 | 112 | features = self.eq_layers(x) 113 | 114 | features = features.tensor.reshape(x.tensor.shape[0], -1) 115 | 116 | out = self.fully_net(features) 117 | 118 | return out 119 | 120 | def build_layers(self, 121 | r1: FieldType, 122 | C: int, 123 | s: int, 124 | padding: int = 0, 125 | pooling: int = None, 126 | orientation_pooling: bool = False, 127 | ): 128 | 129 | gc = r1.isometries 130 | 131 | layers = [] 132 | 133 | r2 = FieldType(gc, [gc.representations['regular']] * C) 134 | 135 | cl = R2Conv(r1, 136 | r2, 137 | s, 138 | padding=padding 139 | ) 140 | layers.append(cl) 141 | 142 | if self.restrict == self.LAYER: 143 | layers.append(RestrictionModule(layers[-1].out_type, (None, self.N))) 144 | layers.append(DisentangleModule(layers[-1].out_type)) 145 | 146 | bn = InnerBatchNorm(layers[-1].out_type) 147 | layers.append(bn) 148 | 149 | if orientation_pooling: 150 | pl = CapsulePool(layers[-1].out_type) 151 | layers.append(pl) 152 | 153 | if pooling is not None: 154 | pl = PointwiseMaxPool(layers[-1].out_type, pooling) 155 | layers.append(pl) 156 | 157 | nnl = ReLU(layers[-1].out_type, inplace=True) 158 | layers.append(nnl) 159 | 160 | return layers 161 | 162 | -------------------------------------------------------------------------------- /experiments/datasets/mnist12k/data_loader_mnist12k.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import torch 4 | import torch.utils.data as data 
5 | import os.path 6 | from torchvision import transforms 7 | 8 | from . import own_transforms 9 | 10 | ROOT = "./datasets/mnist12k/" 11 | 12 | 13 | class mnist_dataset(data.Dataset): 14 | """ flip-rotated MNIST dataset """ 15 | 16 | def __init__(self, mode, transform=None, target_transform=None, reshuffle_seed=None): 17 | """ 18 | :type mode: string from ['train', 'valid', 'test'] 19 | :param mode: determines which subset of the dataset is loaded and whether augmentation is used 20 | :type transform: callable 21 | :param transform: transformation applied to PIL images, returning transformed version 22 | :type target_transform: callable 23 | :param target_transform: transformation applied to labels 24 | :type reshuffle_seed: int 25 | :param reshuffle_seed: seed to use to reshuffle train or valid sets. If None (default), they are not reshuffled 26 | """ 27 | assert mode in ['train', 'valid', 'trainval', 'test'] 28 | assert reshuffle_seed is None or (mode != "test" and mode != 'trainval') 29 | 30 | self.mode = mode 31 | self.transform = transform 32 | self.target_transform = target_transform 33 | 34 | # load the numpy arrays 35 | if mode in ["train", "valid", "trainval"]: 36 | filename = os.path.join(ROOT, 'mnist_trainval.npz') 37 | 38 | data = np.load(filename) 39 | 40 | num_train = len(data["labels"]) 41 | indices = np.arange(0, num_train) 42 | 43 | if reshuffle_seed is not None: 44 | rng = np.random.RandomState(reshuffle_seed) 45 | rng.shuffle(indices) 46 | 47 | split = int(np.floor(num_train * 5/6)) 48 | 49 | if mode == "train": 50 | data = { 51 | "images": data["images"][indices[:split], :], 52 | "labels": data["labels"][indices[:split]] 53 | } 54 | elif mode == "valid": 55 | data = { 56 | "images": data["images"][indices[split:], :], 57 | "labels": data["labels"][indices[split:]] 58 | } 59 | 60 | else: 61 | filename = os.path.join(ROOT, 'mnist_test.npz') 62 | data = np.load(filename) 63 | 64 | self.images = data['images'].astype(np.float32) 65 | self.labels = data['labels'].astype(np.int64) 66 | self.num_samples = len(self.labels) 67 | 68 | def __getitem__(self, index): 69 | """ 70 | :type index: int 71 | :param index: index of data 72 | Returns: 73 | tuple: (image, target) where target is index of the target class. 
74 | """ 75 | image, label = self.images[index], self.labels[index] 76 | # convert to PIL Image 77 | image = Image.fromarray(image) 78 | # transform images and labels 79 | if self.transform is not None: 80 | self.transform.update_randomization() 81 | image = self.transform(image) 82 | if self.target_transform is not None: 83 | label = self.target_transform(label) 84 | return image, label 85 | 86 | def __len__(self): 87 | return len(self.labels) 88 | 89 | 90 | def build_mnist12k_loader(mode, batch_size, num_workers=8, rot_interpol_augmentation=False, interpolation=0, 91 | reshuffle_seed=None, coords=False): 92 | """ """ 93 | rng = np.random.RandomState(42) 94 | 95 | assert mode in ['train', 'valid', 'trainval', 'test'] 96 | assert reshuffle_seed is None or (mode != "test" and mode != 'trainval') 97 | 98 | assert interpolation in [0, 2, 3] # NEAREST, BILINEAR, BICUBIC 99 | 100 | transform = [] 101 | if mode in ['valid', 'test']: 102 | 103 | shuffle = False 104 | drop_last = False 105 | transform = [own_transforms.GrayToTensor()] 106 | elif mode in ['train', 'trainval']: 107 | shuffle = True 108 | drop_last = True 109 | if rot_interpol_augmentation: 110 | transform = [ 111 | transforms.RandomRotation(5, resample=interpolation), 112 | own_transforms.GrayToTensor(), 113 | ] 114 | else: 115 | transform = [own_transforms.GrayToTensor()] 116 | else: 117 | raise ValueError('unknown mode for building mnist_rot loader') 118 | 119 | if coords: 120 | transform += [own_transforms.CoordinateField((28, 28))] 121 | 122 | transform = own_transforms.Compose(transform) 123 | 124 | dataset = mnist_dataset(mode, transform=transform, reshuffle_seed=reshuffle_seed) 125 | loader = torch.utils.data.DataLoader( 126 | dataset, 127 | batch_size=batch_size, 128 | shuffle=shuffle, 129 | # sampler=torch.utils.data.sampler.RandomSampler(dataset), 130 | num_workers=num_workers, 131 | drop_last=drop_last, 132 | pin_memory=True 133 | ) 134 | n_inputs = 1 135 | n_outputs = 10 136 | 137 | if coords: 138 | n_inputs += 2 139 | 140 | return loader, n_inputs, n_outputs 141 | 142 | -------------------------------------------------------------------------------- /experiments/datasets/mnist_rot/own_transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import torch 4 | import torch.utils.data as data 5 | from torchvision import transforms 6 | 7 | 8 | class Compose(transforms.Compose): 9 | """ Composes several transforms together. 
10 |     Adapted from torchvision.transforms to update randomization
11 |     """
12 | 
13 |     def update_randomization(self):
14 |         """ iterate over the composed transforms and update each one that implements an update_randomization method """
15 |         for t in self.transforms:
16 |             update_fct = getattr(t, 'update_randomization', None)
17 |             if update_fct is not None and callable(update_fct):
18 |                 update_fct()
19 | 
20 | 
21 | class Rotate(object):
22 |     """ return image rotated by a random angle or zero degrees
23 |     an angle of zero still introduces interpolation effects; apply this to the test set when the train set is rotated by a random angle
24 |     """
25 | 
26 |     def __init__(self, rng=None, interpolation=0):
27 |         self.rng = rng
28 |         assert interpolation in [0, 2, 3]  # NEAREST, BILINEAR, BICUBIC
29 |         self.interpolation = interpolation
30 |         self.update_randomization()
31 | 
32 |     def update_randomization(self):
33 |         if self.rng:
34 |             self.angle = self.rng.uniform(0, 360)  # sample the angle uniformly from [0, 360)
35 |         else:
36 |             self.angle = 0
37 | 
38 |     def __call__(self, img):
39 |         """
40 |         :type img: PIL.Image
41 |         :param img: image to be transformed
42 |         """
43 |         return img.rotate(angle=self.angle, resample=self.interpolation)
44 | 
45 | 
46 | class Rotate90(object):
47 |     """ return image rotated by a random multiple of 90 degrees """
48 | 
49 |     def __init__(self, rng=None):
50 |         self.rng = rng
51 |         self.update_randomization()
52 | 
53 |     def update_randomization(self):
54 |         self.multiple = self.rng.randint(4) if self.rng else 0  # guard against rng=None, as in Rotate
55 | 
56 |     def __call__(self, img):
57 |         """
58 |         :type img: PIL.Image
59 |         :param img: image to be rotated
60 |         """
61 |         img = np.rot90(img, self.multiple)
62 |         return Image.fromarray(img)
63 | 
64 | 
65 | class ShiftScale(object):
66 |     """ return a randomly shifted and rescaled image """
67 | 
68 |     def __init__(self, rng, shiftmax=1, scalemax=.025):
69 |         self.shiftmax = shiftmax  # default shifts by 0 to 1 pixel
70 |         self.scalemax = scalemax  # default scales by a factor between 0.975 and 1.025
71 |         self.rng = rng
72 |         self.update_randomization()
73 | 
74 |     def update_randomization(self):
75 |         self.shiftX = self.rng.uniform(0, self.shiftmax)  # stored as attributes so that __call__ can use them
76 |         self.shiftY = self.rng.uniform(0, self.shiftmax)
77 |         self.scaleX = self.rng.uniform(1 - self.scalemax, 1 + self.scalemax)
78 |         self.scaleY = self.rng.uniform(1 - self.scalemax, 1 + self.scalemax)
79 | 
80 |     def __call__(self, img):
81 |         """
82 |         :type img: PIL.Image
83 |         :param img: image to be transformed
84 |         """
85 |         return img.transform(img.size, Image.AFFINE, data=(self.scaleX, 0, self.shiftX, 0, self.scaleY, self.shiftY),
86 |                              resample=Image.BILINEAR)
87 | 
88 | 
89 | class Reflect(object):
90 |     """ reflect image """
91 | 
92 |     def __init__(self, rng):
93 |         self.rng = rng
94 |         self.update_randomization()
95 | 
96 |     def update_randomization(self):
97 |         self.flipX, self.flipY = self.rng.randint(2, size=2)  # two random numbers, each from {0,1}
98 | 
99 |     def __call__(self, img):
100 |         """
101 |         :type img: PIL.Image
102 |         :param img: image to be transformed
103 |         """
104 |         if self.flipX:
105 |             img = img.transpose(Image.FLIP_LEFT_RIGHT)
106 |         if self.flipY:
107 |             img = img.transpose(Image.FLIP_TOP_BOTTOM)
108 |         return img
109 | 
110 | 
111 | class GrayToTensor(object):
112 |     """ converts gray image to tensor and adds channel dimension """
113 | 
114 |     def __call__(self, img):
115 |         """
116 |         :type img: PIL.Image
117 |         :param img: grayscale image to convert to a 1-channel float tensor
118 |         """
119 |         img = np.array(img, np.float32, copy=False)[np.newaxis, ...]
--------------------------------------------------------------------------------
/experiments/mnist_bench.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | datasets=("mnist_rot" "mnist12k" "mnist_fliprot" )
5 | Ns=("2" "3" "4" "5" "6" "7" "8" "9" "12" "16" "20")
6 |
7 | # frequency cut-off policy; use the default one
8 | F="None"
9 |
10 | for i in $(seq 1 5)
11 | do
12 |
13 | # C_N and D_N regular and quotient models and a conventional CNN
14 | for dataset in "${datasets[@]}"
15 | do
16 | # baseline standard CNN
17 | ./mnist_bench_single.sh --S 1 --model CNN --type "None" --dataset "$dataset" --N 1 --F "$F" --sigma "None" --fixparams --regularize
18 |
19 | # steerable CNN with C_1 equivariance (equivalent to a standard CNN)
20 | ./mnist_bench_single.sh --S 1 --type regular --dataset "$dataset" --N 1 --F "$F" --sigma "None" --fixparams --regularize
21 |
22 | # steerable CNN with D_1 equivariance (flip equivariance)
23 | ./mnist_bench_single.sh --S 1 --type regular --dataset "$dataset" --N 1 --flip --F "$F" --sigma "None" --fixparams --regularize
24 |
25 | for N in "${Ns[@]}"
26 | do
27 | # C_N
28 | ./mnist_bench_single.sh --S 1 --type regular --dataset "$dataset" --N $N --F "$F" --sigma "None" --fixparams --regularize
29 | # D_N
30 | ./mnist_bench_single.sh --S 1 --type regular --dataset "$dataset" --N $N --flip --F "$F" --sigma "None" --fixparams --regularize
31 |
32 | # C_N with quotient representations
33 | ./mnist_bench_single.sh --S 1 --type quotient --dataset "$dataset" --N $N --F "$F" --sigma "None" --fixparams --regularize
34 | done
35 | done
36 |
37 | # Other D_N and C_N models, only instantiated for N=16
38 | for dataset in "${datasets[@]}"
39 | do
40 |
41 | ./mnist_bench_single.sh --S 1 --type "scalarfield" --dataset "$dataset" --N 16 --F "$F" --sigma "None" --fixparams --regularize
42 | ./mnist_bench_single.sh --S 1 --type "scalarfield" --dataset "$dataset" --N 16 --flip --F "$F" --sigma "None" --fixparams --regularize
43 |
44 | discrete_types=( "vectorfield" "regvector")
45 | # only C_N models
46 | for type in "${discrete_types[@]}"
47 | do
48 | ./mnist_bench_single.sh --S 1 --type "$type" --dataset "$dataset" --N 16 --F "$F" --sigma "None" --fixparams --regularize
49 | done
50 | done
51 |
52 | # SO(2) and O(2) models
53 | # frequency is encoded with a minus sign, to distinguish it from the order of C_N when passed as an argument with --N
54 | freq=("-1" "-3" "-5" "-7")
55 | for dataset in "${datasets[@]}"
56 | do
57 | # O(2)
invariant network using only isotropic filters 58 | ./mnist_bench_single.sh --S 1 --type "trivial" --dataset "$dataset" --N -1 --flip --F "$F" --sigma "None" --fixparams --regularize 59 | 60 | # experiment with different irrep types, up to frequency "-$f" 61 | for f in "${freq[@]}" 62 | do 63 | ./mnist_bench_single.sh --S 1 --type "hnet_conv" --dataset "$dataset" --N $f --F "$F" --sigma "None" --fixparams --regularize 64 | ./mnist_bench_single.sh --S 1 --type "hnet_conv" --dataset "$dataset" --N $f --flip --F "$F" --sigma "None" --fixparams --regularize 65 | 66 | ./mnist_bench_single.sh --S 1 --type "realhnet" --dataset "$dataset" --N $f --F "$F" --sigma "None" --fixparams --regularize 67 | ./mnist_bench_single.sh --S 1 --type "realhnet2" --dataset "$dataset" --N $f --F "$F" --sigma "None" --fixparams --regularize 68 | 69 | ./mnist_bench_single.sh --S 1 --type "inducedhnet_conv" --dataset "$dataset" --N $f --flip --F "$F" --sigma "None" --fixparams --regularize 70 | 71 | done 72 | 73 | # other SO(2) models 74 | ./mnist_bench_single.sh --S 1 --type "squash" --dataset "$dataset" --N -3 --F "$F" --sigma "None" --fixparams --regularize 75 | ./mnist_bench_single.sh --S 1 --type "hnet_norm" --dataset "$dataset" --N -3 --F "$F" --sigma "None" --fixparams --regularize 76 | ./mnist_bench_single.sh --S 1 --type "sharednorm" --dataset "$dataset" --N -3 --F "$F" --sigma "None" --fixparams --regularize 77 | ./mnist_bench_single.sh --S 1 --type "sharednorm2" --dataset "$dataset" --N -3 --F "$F" --sigma "None" --fixparams --regularize 78 | 79 | ./mnist_bench_single.sh --S 1 --type "gated_conv" --dataset "$dataset" --N -3 --F "$F" --sigma "None" --fixparams --regularize 80 | ./mnist_bench_single.sh --S 1 --type "gated_conv_shared" --dataset "$dataset" --N -3 --F "$F" --sigma "None" --fixparams --regularize 81 | ./mnist_bench_single.sh --S 1 --type "gated_norm" --dataset "$dataset" --N -3 --F "$F" --sigma "None" --fixparams --regularize 82 | ./mnist_bench_single.sh --S 1 --type "gated_norm_shared" --dataset "$dataset" --N -3 --F "$F" --sigma "None" --fixparams --regularize 83 | 84 | # other O(2) models 85 | ./mnist_bench_single.sh --S 1 --type "gated_conv" --dataset "$dataset" --N -3 --flip --F "$F" --sigma "None" --fixparams --regularize 86 | ./mnist_bench_single.sh --S 1 --type "gated_norm" --dataset "$dataset" --N -3 --flip --F "$F" --sigma "None" --fixparams --regularize 87 | ./mnist_bench_single.sh --S 1 --type "inducedgated_conv" --dataset "$dataset" --N -3 --flip --F "$F" --sigma "None" --fixparams --regularize 88 | ./mnist_bench_single.sh --S 1 --type "inducedgated_norm" --dataset "$dataset" --N -3 --flip --F "$F" --sigma "None" --fixparams --regularize 89 | 90 | done 91 | 92 | done 93 | 94 | -------------------------------------------------------------------------------- /experiments/datasets/mnist12k/own_transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import torch 4 | import torch.utils.data as data 5 | from torchvision import transforms 6 | 7 | 8 | class Compose(transforms.Compose): 9 | """ Composes several transforms together. 
10 | Adapted from torchvision.transforms to support updating the randomization
11 | """
12 |
13 | def update_randomization(self):
14 | """ Iterate over the composed transforms and update the randomization of each one implementing update_randomization """
15 | for t in self.transforms:
16 | update_fct = getattr(t, 'update_randomization', None)
17 | if update_fct is not None and callable(update_fct):
18 | update_fct()
19 |
20 |
21 | class FlipRotate(object):
22 | """ Return the image randomly flipped and rotated by a random angle (or by zero degrees if no rng is given).
23 | An angle of zero still introduces interpolation effects; apply this to the test set when the train set is rotated by random angles.
24 | """
25 |
26 | def __init__(self, rng=None, interpolation=0):
27 | self.rng = rng
28 | assert interpolation in [0, 2, 3]  # NEAREST, BILINEAR, BICUBIC
29 | self.interpolation = interpolation
30 | self.update_randomization()
31 |
32 | def update_randomization(self):
33 | if self.rng:
34 | self.angle = self.rng.uniform(0, 360)  # sample an angle uniformly in [0, 360)
35 | else:
36 | self.angle = 0
37 |
38 | def __call__(self, img):
39 | """
40 | :type img: PIL.Image
41 | :param img: image to be transformed
42 | """
43 | if self.rng is not None and self.rng.rand() < 0.5:
44 | img = img.transpose(Image.FLIP_LEFT_RIGHT)
45 |
46 | return img.rotate(angle=self.angle, resample=self.interpolation)
47 |
48 |
49 | class Rotate90(object):
50 | """ Return the image rotated by a random multiple of 90 degrees """
51 |
52 | def __init__(self, rng=None):
53 | self.rng = rng
54 | self.update_randomization()
55 |
56 | def update_randomization(self):
57 | self.multiple = self.rng.randint(4) if self.rng else 0  # guard against rng=None
58 |
59 | def __call__(self, img):
60 | """
61 | :type img: PIL.Image
62 | :param img: image to be rotated
63 | """
64 | img = np.rot90(img, self.multiple)
65 | return Image.fromarray(img)
66 |
67 |
68 | class ShiftScale(object):
69 | """ Return a randomly shifted and rescaled image """
70 |
71 | def __init__(self, rng, shiftmax=1, scalemax=.025):
72 | self.shiftmax = shiftmax  # default shifts by 0 to 1 pixel
73 | self.scalemax = scalemax  # default scales by a factor between 0.975 and 1.025
74 | self.rng = rng
75 | self.update_randomization()
76 |
77 | def update_randomization(self):
78 | self.shiftX = self.rng.uniform(0, self.shiftmax)  # stored on self: __call__ reads these
79 | self.shiftY = self.rng.uniform(0, self.shiftmax)
80 | self.scaleX = self.rng.uniform(1 - self.scalemax, 1 + self.scalemax)
81 | self.scaleY = self.rng.uniform(1 - self.scalemax, 1 + self.scalemax)
82 |
83 | def __call__(self, img):
84 | """
85 | :type img: PIL.Image
86 | :param img: image to be transformed
87 | """
88 | return img.transform(img.size, Image.AFFINE, data=(self.scaleX, 0, self.shiftX, 0, self.scaleY, self.shiftY),
89 | resample=Image.BILINEAR)
90 |
91 |
92 | class Reflect(object):
93 | """ Randomly reflect the image along the horizontal and/or vertical axis """
94 |
95 | def __init__(self, rng):
96 | self.rng = rng
97 | self.update_randomization()
98 |
99 | def update_randomization(self):
100 | self.flipX, self.flipY = self.rng.randint(2, size=2)  # two random numbers, each from {0, 1}
101 |
102 | def __call__(self, img):
103 | """
104 | :type img: PIL.Image
105 | :param img: image to be transformed
106 | """
107 | if self.flipX:
108 | img = img.transpose(Image.FLIP_LEFT_RIGHT)
109 | if self.flipY:
110 | img = img.transpose(Image.FLIP_TOP_BOTTOM)
111 | return img
112 |
113 |
114 | class GrayToTensor(object):
115 | """ converts gray image to tensor and adds channel dimension """
116 |
117 | def __call__(self, img):
118 | """
119 | :type img: PIL.Image
120 | :param img: grayscale image to convert; a channel dimension is prepended
121 | """
122 | img = np.array(img, np.float32, copy=False)[np.newaxis, ...]  # add channel dimension
123 | return torch.from_numpy(img)
124 |
125 |
126 | class CoordinateField(object):
127 | """ Add the x and y coordinates of each pixel as two additional scalar features """
128 |
129 | def __init__(self, shape):
130 | coords = [torch.arange(s) for s in shape]
131 | coords = torch.stack(torch.meshgrid(coords))
132 | coords = coords.to(dtype=torch.float)
133 |
134 | l = len(shape)
135 |
136 | assert coords.shape == (l,) + shape, coords.shape
137 |
138 | coords = coords.reshape(l, -1)
139 |
140 | coords -= coords.mean(dim=1, keepdim=True)  # standardize each coordinate channel
141 | coords /= coords.std(dim=1, keepdim=True)
142 |
143 | coords = coords.reshape(l, *shape)
144 |
145 | self.coords = coords
146 |
147 | self._expand_shape = tuple(-1 for _ in range(len(shape) + 1))
148 |
149 | def __call__(self, img):
150 | """
151 | :type img: torch.FloatTensor
152 | :param img: image tensor to which the coordinate channels are appended
153 | """
154 | return torch.cat([img, self.coords], dim=0)
155 |
156 |
157 |
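CoordinateField above standardizes each coordinate channel to zero mean and unit standard deviation and appends the two channels to the image tensor; this is what the coords=True option of the loader builders feeds to models expecting two extra scalar input fields. A small shape sketch (values illustrative):

    cf = CoordinateField((28, 28))
    img = torch.zeros(1, 28, 28)           # one grayscale channel
    x = cf(img)                            # two coordinate channels appended
    assert x.shape == (3, 28, 28)
    assert abs(x[1].mean().item()) < 1e-5  # each coordinate channel is standardized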
--------------------------------------------------------------------------------
/experiments/datasets/mnist_fliprot/own_transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | import torch
4 | import torch.utils.data as data
5 | from torchvision import transforms
6 |
7 |
8 | class Compose(transforms.Compose):
9 | """ Composes several transforms together.
10 | Adapted from torchvision.transforms to support updating the randomization
11 | """
12 |
13 | def update_randomization(self):
14 | """ Iterate over the composed transforms and update the randomization of each one implementing update_randomization """
15 | for t in self.transforms:
16 | update_fct = getattr(t, 'update_randomization', None)
17 | if update_fct is not None and callable(update_fct):
18 | update_fct()
19 |
20 |
21 | class FlipRotate(object):
22 | """ Return the image randomly flipped and rotated by a random angle (or by zero degrees if no rng is given).
23 | An angle of zero still introduces interpolation effects; apply this to the test set when the train set is rotated by random angles.
24 | """
25 |
26 | def __init__(self, rng=None, interpolation=0):
27 | self.rng = rng
28 | assert interpolation in [0, 2, 3]  # NEAREST, BILINEAR, BICUBIC
29 | self.interpolation = interpolation
30 | self.update_randomization()
31 |
32 | def update_randomization(self):
33 | if self.rng:
34 | self.angle = self.rng.uniform(0, 360)  # sample an angle uniformly in [0, 360)
35 | else:
36 | self.angle = 0
37 |
38 | def __call__(self, img):
39 | """
40 | :type img: PIL.Image
41 | :param img: image to be transformed
42 | """
43 | if self.rng is not None and self.rng.rand() < 0.5:
44 | img = img.transpose(Image.FLIP_LEFT_RIGHT)
45 |
46 | return img.rotate(angle=self.angle, resample=self.interpolation)
47 |
48 |
49 | class Rotate90(object):
50 | """ Return the image rotated by a random multiple of 90 degrees """
51 |
52 | def __init__(self, rng=None):
53 | self.rng = rng
54 | self.update_randomization()
55 |
56 | def update_randomization(self):
57 | self.multiple = self.rng.randint(4) if self.rng else 0  # guard against rng=None
58 |
59 | def __call__(self, img):
60 | """
61 | :type img: PIL.Image
62 | :param img: image to be rotated
63 | """
64 | img = np.rot90(img, self.multiple)
65 | return Image.fromarray(img)
66 |
67 |
68 | class ShiftScale(object):
69 | """ Return a randomly shifted and rescaled image """
70 |
71 | def __init__(self, rng, shiftmax=1, scalemax=.025):
72 | self.shiftmax = shiftmax  # default shifts by 0 to 1 pixel
73 | self.scalemax = scalemax  # default scales by a factor between 0.975 and 1.025
74 | self.rng = rng
75 | self.update_randomization()
76 |
77 | def update_randomization(self):
78 | self.shiftX = self.rng.uniform(0, self.shiftmax)  # stored on self: __call__ reads these
79 | self.shiftY = self.rng.uniform(0, self.shiftmax)
80 | self.scaleX = self.rng.uniform(1 - self.scalemax, 1 + self.scalemax)
81 | self.scaleY = self.rng.uniform(1 - self.scalemax, 1 + self.scalemax)
82 |
83 | def __call__(self, img):
84 | """
85 | :type img: PIL.Image
86 | :param img: image to be transformed
87 | """
88 | return img.transform(img.size, Image.AFFINE, data=(self.scaleX, 0, self.shiftX, 0, self.scaleY, self.shiftY),
89 | resample=Image.BILINEAR)
90 |
91 |
92 | class Reflect(object):
93 | """ Randomly reflect the image along the horizontal and/or vertical axis """
94 |
95 | def __init__(self, rng):
96 | self.rng = rng
97 | self.update_randomization()
98 |
99 | def update_randomization(self):
100 | self.flipX, self.flipY = self.rng.randint(2, size=2)  # two random numbers, each from {0, 1}
101 |
102 | def __call__(self, img):
103 | """
104 | :type img: PIL.Image
105 | :param img: image to be transformed
106 | """
107 | if self.flipX:
108 | img = img.transpose(Image.FLIP_LEFT_RIGHT)
109 | if self.flipY:
110 | img = img.transpose(Image.FLIP_TOP_BOTTOM)
111 | return img
112 |
113 |
114 | class GrayToTensor(object):
115 | """ converts gray image to tensor and adds channel dimension """
116 |
117 | def __call__(self, img):
118 | """
119 | :type img: PIL.Image
120 | :param img: grayscale image to convert; a channel dimension is prepended
121 | """
122 | img = np.array(img, np.float32, copy=False)[np.newaxis, ...]  # add channel dimension
123 | return torch.from_numpy(img)
124 |
125 |
126 | class CoordinateField(object):
127 | """ Add the x and y coordinates of each pixel as two additional scalar features """
128 |
129 | def __init__(self, shape):
130 | coords = [torch.arange(s) for s in shape]
131 | coords = torch.stack(torch.meshgrid(coords))
132 | coords = coords.to(dtype=torch.float)
133 |
134 | l = len(shape)
135 |
136 | assert coords.shape == (l,) + shape, coords.shape
137 |
138 | coords = coords.reshape(l, -1)
139 |
140 | coords -= coords.mean(dim=1, keepdim=True)  # standardize each coordinate channel
141 | coords /= coords.std(dim=1, keepdim=True)
142 |
143 | coords = coords.reshape(l, *shape)
144 |
145 | self.coords = coords
146 |
147 | self._expand_shape = tuple(-1 for _ in range(len(shape) + 1))
148 |
149 | def __call__(self, img):
150 | """
151 | :type img: torch.FloatTensor
152 | :param img: image tensor to which the coordinate channels are appended
153 | """
154 | return torch.cat([img, self.coords], dim=0)
155 |
156 |
157 |
--------------------------------------------------------------------------------
/experiments/datasets/mnist_rot/data_loader_mnist_rot.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | import torch
4 | import torch.utils.data as data
5 | from torchvision import transforms
6 |
7 | from . import own_transforms
8 |
9 |
10 | class mnist_rot_dataset(data.Dataset):
11 | """ rotated MNIST dataset """
12 |
13 | def __init__(self, mode, transform=None, target_transform=None, reshuffle_seed=None):
14 | """
15 | :type mode: string from ['train', 'valid', 'trainval', 'test']
16 | :param mode: determines which subset of the dataset is loaded and whether augmentation is used
17 | :type transform: callable
18 | :param transform: transformation applied to PIL images, returning the transformed version
19 | :type target_transform: callable
20 | :param target_transform: transformation applied to labels
21 | :type reshuffle_seed: int
22 | :param reshuffle_seed: seed to use to reshuffle train or valid sets.
If None (default), they are not reshuffled 23 | """ 24 | assert mode in ['train', 'valid', 'trainval', 'test'] 25 | assert reshuffle_seed is None or (mode != "test" and mode != 'trainval') 26 | 27 | self.mode = mode 28 | self.transform = transform 29 | self.target_transform = target_transform 30 | 31 | # load the numpy arrays 32 | if mode in ["train", "valid", "trainval"]: 33 | filename = './datasets/mnist_rot/mnist_rot_trainval.npz' 34 | 35 | data = np.load(filename) 36 | 37 | num_train = len(data["labels"]) 38 | indices = np.arange(0, num_train) 39 | 40 | if reshuffle_seed is not None: 41 | rng = np.random.RandomState(reshuffle_seed) 42 | rng.shuffle(indices) 43 | 44 | split = int(np.floor(num_train * 5/6)) 45 | 46 | if mode == "train": 47 | data = { 48 | "images": data["images"][indices[:split], :], 49 | "labels": data["labels"][indices[:split]] 50 | } 51 | elif mode == "valid": 52 | data = { 53 | "images": data["images"][indices[split:], :], 54 | "labels": data["labels"][indices[split:]] 55 | } 56 | 57 | else: 58 | filename = './datasets/mnist_rot/mnist_rot_test.npz' 59 | data = np.load(filename) 60 | 61 | self.images = data['images'].astype(np.float32) 62 | self.labels = data['labels'].astype(np.int64) 63 | self.num_samples = len(self.labels) 64 | 65 | def __getitem__(self, index): 66 | """ 67 | :type index: int 68 | :param index: index of data 69 | Returns: 70 | tuple: (image, target) where target is index of the target class. 71 | """ 72 | image, label = self.images[index], self.labels[index] 73 | # convert to PIL Image 74 | image = Image.fromarray(image) 75 | # transform images and labels 76 | if self.transform is not None: 77 | self.transform.update_randomization() 78 | image = self.transform(image) 79 | if self.target_transform is not None: 80 | label = self.target_transform(label) 81 | return image, label 82 | 83 | def __len__(self): 84 | return len(self.labels) 85 | 86 | 87 | def build_mnist_rot_loader(mode, batch_size, num_workers=8, rot_interpol_augmentation=False, interpolation=0, reshuffle_seed=None, coords=False): 88 | """ """ 89 | rng = np.random.RandomState(42) 90 | 91 | assert mode in ['train', 'valid', 'trainval', 'test'] 92 | assert reshuffle_seed is None or (mode != "test" and mode != 'trainval') 93 | 94 | transform = [] 95 | if mode in ['valid', 'test']: 96 | 97 | shuffle = False 98 | drop_last = False 99 | if rot_interpol_augmentation: 100 | transform = [ 101 | own_transforms.Rotate(rng=None, interpolation=interpolation), # only resamples image 102 | # own_transforms.ShiftScale(rng), 103 | own_transforms.GrayToTensor() 104 | ] 105 | else: 106 | transform = [own_transforms.GrayToTensor()] 107 | elif mode in ['train', 'trainval']: 108 | shuffle = True 109 | drop_last = True 110 | if rot_interpol_augmentation: 111 | transform = [ 112 | own_transforms.Rotate(rng=rng, interpolation=interpolation), 113 | # own_transforms.Rotate90(rng=rng), 114 | # own_transforms.ShiftScale(rng), 115 | own_transforms.GrayToTensor() 116 | ] 117 | 118 | else: 119 | transform = [ 120 | # own_transforms.Rotate90(rng=rng), 121 | own_transforms.GrayToTensor() 122 | ] 123 | else: 124 | raise ValueError('unknown mode for building mnist_rot loader') 125 | 126 | if coords: 127 | transform += [own_transforms.CoordinateField((28, 28))] 128 | 129 | transform = own_transforms.Compose(transform) 130 | 131 | dataset = mnist_rot_dataset(mode, transform=transform, reshuffle_seed=reshuffle_seed) 132 | loader = torch.utils.data.DataLoader( 133 | dataset, 134 | batch_size=batch_size, 135 | shuffle=shuffle, 
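# note: the sampler argument below is mutually exclusive with shuffle;
# torch.utils.data.DataLoader raises a ValueError when a sampler is passed
# together with shuffle=True, which is presumably why it is commented out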
136 | # sampler=torch.utils.data.sampler.RandomSampler(dataset), 137 | num_workers=num_workers, 138 | drop_last=drop_last, 139 | pin_memory=True 140 | ) 141 | n_inputs = 1 142 | n_outputs = 10 143 | 144 | if coords: 145 | n_inputs += 2 146 | 147 | return loader, n_inputs, n_outputs 148 | 149 | -------------------------------------------------------------------------------- /experiments/datasets/mnist_fliprot/data_loader_mnist_fliprot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import torch 4 | import torch.utils.data as data 5 | from torchvision import transforms 6 | 7 | from . import own_transforms 8 | 9 | 10 | class mnist_fliprot_dataset(data.Dataset): 11 | """ flip-rotated MNIST dataset """ 12 | 13 | def __init__(self, mode, transform=None, target_transform=None, reshuffle_seed=None): 14 | """ 15 | :type mode: string from ['train', 'valid', 'test'] 16 | :param mode: determines which subset of the dataset is loaded and whether augmentation is used 17 | :type transform: callable 18 | :param transform: transformation applied to PIL images, returning transformed version 19 | :type target_transform: callable 20 | :param target_transform: transformation applied to labels 21 | :type reshuffle_seed: int 22 | :param reshuffle_seed: seed to use to reshuffle train or valid sets. If None (default), they are not reshuffled 23 | """ 24 | assert mode in ['train', 'valid', 'trainval', 'test'] 25 | assert reshuffle_seed is None or (mode != "test" and mode != 'trainval') 26 | 27 | self.mode = mode 28 | self.transform = transform 29 | self.target_transform = target_transform 30 | 31 | # load the numpy arrays 32 | if mode in ["train", "valid", "trainval"]: 33 | filename = './datasets/mnist_fliprot/mnist_fliprot_trainval.npz' 34 | 35 | data = np.load(filename) 36 | 37 | num_train = len(data["labels"]) 38 | indices = np.arange(0, num_train) 39 | 40 | if reshuffle_seed is not None: 41 | rng = np.random.RandomState(reshuffle_seed) 42 | rng.shuffle(indices) 43 | 44 | split = int(np.floor(num_train * 5 / 6)) 45 | 46 | if mode == "train": 47 | data = { 48 | "images": data["images"][indices[:split], :], 49 | "labels": data["labels"][indices[:split]] 50 | } 51 | elif mode == "valid": 52 | data = { 53 | "images": data["images"][indices[split:], :], 54 | "labels": data["labels"][indices[split:]] 55 | } 56 | 57 | else: 58 | filename = './datasets/mnist_fliprot/mnist_fliprot_test.npz' 59 | data = np.load(filename) 60 | 61 | self.images = data['images'].astype(np.float32) 62 | self.labels = data['labels'].astype(np.int64) 63 | self.num_samples = len(self.labels) 64 | 65 | def __getitem__(self, index): 66 | """ 67 | :type index: int 68 | :param index: index of data 69 | Returns: 70 | tuple: (image, target) where target is index of the target class. 
71 | """ 72 | image, label = self.images[index], self.labels[index] 73 | # convert to PIL Image 74 | image = Image.fromarray(image) 75 | # transform images and labels 76 | if self.transform is not None: 77 | self.transform.update_randomization() 78 | image = self.transform(image) 79 | if self.target_transform is not None: 80 | label = self.target_transform(label) 81 | return image, label 82 | 83 | def __len__(self): 84 | return len(self.labels) 85 | 86 | 87 | def build_mnist_rot_loader(mode, batch_size, num_workers=8, rot_interpol_augmentation=False, interpolation=0, reshuffle_seed=None, coords=False): 88 | """ """ 89 | rng = np.random.RandomState(42) 90 | 91 | assert mode in ['train', 'valid', 'trainval', 'test'] 92 | assert reshuffle_seed is None or (mode != "test" and mode != 'trainval') 93 | 94 | transform = [] 95 | if mode in ['valid', 'test']: 96 | 97 | shuffle = False 98 | drop_last = False 99 | if rot_interpol_augmentation: 100 | transform = [ 101 | own_transforms.FlipRotate(rng=None, interpolation=interpolation), # only resamples image 102 | # own_transforms.ShiftScale(rng), 103 | own_transforms.GrayToTensor() 104 | ] 105 | else: 106 | transform = [own_transforms.GrayToTensor()] 107 | elif mode in ['train', 'trainval']: 108 | shuffle = True 109 | drop_last = True 110 | if rot_interpol_augmentation: 111 | transform = [ 112 | own_transforms.FlipRotate(rng=rng, interpolation=interpolation), 113 | # own_transforms.Rotate90(rng=rng), 114 | # own_transforms.ShiftScale(rng), 115 | own_transforms.GrayToTensor() 116 | ] 117 | else: 118 | transform = [ 119 | # own_transforms.Rotate90(rng=rng), 120 | own_transforms.GrayToTensor() 121 | ] 122 | else: 123 | raise ValueError('unknown mode for building mnist_rot loader') 124 | 125 | if coords: 126 | transform += [own_transforms.CoordinateField((28, 28))] 127 | 128 | transform = own_transforms.Compose(transform) 129 | 130 | dataset = mnist_fliprot_dataset(mode, transform=transform, reshuffle_seed=reshuffle_seed) 131 | loader = torch.utils.data.DataLoader( 132 | dataset, 133 | batch_size=batch_size, 134 | shuffle=shuffle, 135 | # sampler=torch.utils.data.sampler.RandomSampler(dataset), 136 | num_workers=num_workers, 137 | drop_last=drop_last, 138 | pin_memory=True 139 | ) 140 | n_inputs = 1 141 | n_outputs = 10 142 | 143 | if coords: 144 | n_inputs += 2 145 | 146 | return loader, n_inputs, n_outputs 147 | 148 | -------------------------------------------------------------------------------- /experiments/datasets/STL10/data_loader_stl10frac2.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | 5 | from torchvision import datasets 6 | from torchvision import transforms 7 | from torch.utils.data.sampler import SubsetRandomSampler 8 | 9 | from .data_loader_stl10 import DATA_DIR, CIFAR_MEAN, CIFAR_STD, MEAN, STD 10 | from .data_loader_stl10 import Cutout 11 | 12 | 13 | def __balanced_subdataset_idxs(train_size, validation_size, labels, reshuffle): 14 | num_train = len(labels) 15 | assert train_size + validation_size <= num_train 16 | 17 | classes = set(labels) 18 | 19 | labels_idxs = {c: list() for c in classes} 20 | ratios = {c: 0. for c in classes} 21 | 22 | for i, l in enumerate(labels): 23 | labels_idxs[l].append(i) 24 | ratios[l] += 1. 
25 | 26 | train_idx = list() 27 | valid_idx = list() 28 | for c in classes: 29 | ratios[c] /= num_train 30 | 31 | if reshuffle: 32 | np.random.shuffle(labels_idxs[c]) 33 | 34 | ts = int(round(train_size * ratios[c])) 35 | vs = int(round(validation_size * ratios[c])) 36 | 37 | valid_idx += labels_idxs[c][:vs] 38 | train_idx += labels_idxs[c][vs:vs+ts] 39 | 40 | return train_idx, valid_idx 41 | 42 | 43 | def __build_stl10_frac_loaders(size, 44 | batch_size, 45 | eval_batchsize, 46 | validation=True, 47 | num_workers=8, 48 | augment=False, 49 | reshuffle=True, 50 | mean=MEAN, 51 | std=STD, 52 | ): 53 | 54 | normalize = transforms.Normalize( 55 | mean=mean, 56 | std=std, 57 | ) 58 | 59 | # define transforms 60 | valid_transform = transforms.Compose([ 61 | transforms.ToTensor(), 62 | normalize, 63 | ]) 64 | 65 | if augment: 66 | train_transform = transforms.Compose([ 67 | transforms.RandomCrop(96, padding=12), 68 | transforms.RandomHorizontalFlip(), 69 | transforms.ToTensor(), 70 | # Cutout(32), 71 | Cutout(60), 72 | normalize, 73 | ]) 74 | else: 75 | train_transform = transforms.Compose([ 76 | transforms.ToTensor(), 77 | # Cutout(24), 78 | Cutout(48), 79 | normalize, 80 | ]) 81 | # train_transform = transforms.Compose([ 82 | # transforms.ToTensor(), 83 | # normalize, 84 | # ]) 85 | 86 | # load the dataset 87 | train_dataset = datasets.STL10( 88 | root=DATA_DIR, split="train", 89 | download=True, transform=train_transform, 90 | ) 91 | 92 | test_dataset = datasets.STL10( 93 | root=DATA_DIR, split="test", 94 | download=True, transform=valid_transform, 95 | ) 96 | 97 | if validation: 98 | 99 | valid_dataset = datasets.STL10( 100 | root=DATA_DIR, split="train", 101 | download=True, transform=valid_transform, 102 | ) 103 | 104 | validation_size = 1000 105 | train_idx, valid_idx = __balanced_subdataset_idxs(size, validation_size, train_dataset.labels, reshuffle) 106 | train_sampler = SubsetRandomSampler(train_idx) 107 | valid_sampler = SubsetRandomSampler(valid_idx) 108 | 109 | train_loader = torch.utils.data.DataLoader( 110 | train_dataset, batch_size=batch_size, sampler=train_sampler, 111 | num_workers=num_workers, pin_memory=True, 112 | ) 113 | valid_loader = torch.utils.data.DataLoader( 114 | valid_dataset, batch_size=eval_batchsize, sampler=valid_sampler, 115 | num_workers=num_workers, pin_memory=True, 116 | ) 117 | else: 118 | 119 | train_idx, _ = __balanced_subdataset_idxs(size, 0, train_dataset.labels, reshuffle) 120 | 121 | train_sampler = SubsetRandomSampler(train_idx) 122 | 123 | train_loader = torch.utils.data.DataLoader( 124 | train_dataset, batch_size=batch_size, sampler=train_sampler, 125 | num_workers=num_workers, pin_memory=True, 126 | ) 127 | valid_loader = None 128 | 129 | test_loader = torch.utils.data.DataLoader( 130 | test_dataset, batch_size=eval_batchsize, shuffle=False, 131 | num_workers=num_workers, pin_memory=True, 132 | ) 133 | 134 | n_inputs = 3 135 | n_classes = 10 136 | 137 | return train_loader, valid_loader, test_loader, n_inputs, n_classes 138 | 139 | 140 | def build_stl10_frac_loaders(size, 141 | batch_size, 142 | eval_batchsize, 143 | validation=True, 144 | num_workers=8, 145 | augment=False, 146 | reshuffle=True, 147 | ): 148 | return __build_stl10_frac_loaders(size, batch_size, eval_batchsize, validation, num_workers, augment, reshuffle, 149 | mean=MEAN, std=STD) 150 | 151 | 152 | def build_stl10cif_frac_loaders(size, 153 | batch_size, 154 | eval_batchsize, 155 | validation=True, 156 | num_workers=8, 157 | augment=False, 158 | reshuffle=True, 159 | ): 160 | 
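# identical to build_stl10_frac_loaders above, except that images are
# normalized with CIFAR statistics (CIFAR_MEAN / CIFAR_STD) instead of the
# STL-10 MEAN / STD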
return __build_stl10_frac_loaders(size, batch_size, eval_batchsize, validation, num_workers, augment, reshuffle, 161 | mean=CIFAR_MEAN, std=CIFAR_STD) 162 | 163 | -------------------------------------------------------------------------------- /experiments/models/utils.py: -------------------------------------------------------------------------------- 1 | 2 | from e2cnn import nn 3 | from e2cnn import group 4 | from e2cnn import gspaces 5 | from e2cnn.nn import init 6 | 7 | import math 8 | import numpy as np 9 | 10 | STORE_PATH = "./models/stored/" 11 | 12 | CHANNELS_CONSTANT = 1 13 | 14 | 15 | def _get_fco(fco): 16 | if fco > 0.: 17 | fco *= np.pi 18 | return fco 19 | 20 | 21 | def conv7x7(in_type: nn.FieldType, out_type: nn.FieldType, stride=1, padding=3, dilation=1, bias=False, sigma=None, F=1., initialize=True): 22 | """7x7 convolution with padding""" 23 | fco = _get_fco(F) 24 | return nn.R2Conv(in_type, out_type, 7, 25 | stride=stride, 26 | padding=padding, 27 | dilation=dilation, 28 | bias=bias, 29 | sigma=sigma, 30 | frequencies_cutoff=fco, 31 | initialize=initialize 32 | ) 33 | 34 | 35 | def conv5x5(in_type: nn.FieldType, out_type: nn.FieldType, stride=1, padding=2, dilation=1, bias=False, sigma=None, F=1., initialize=True): 36 | """5x5 convolution with padding""" 37 | fco = _get_fco(F) 38 | return nn.R2Conv(in_type, out_type, 5, 39 | stride=stride, 40 | padding=padding, 41 | dilation=dilation, 42 | bias=bias, 43 | sigma=sigma, 44 | frequencies_cutoff=fco, 45 | initialize=initialize 46 | ) 47 | 48 | 49 | def conv3x3(in_type: nn.FieldType, out_type: nn.FieldType, padding=1, stride=1, dilation=1, bias=False, sigma=None, F=1., initialize=True): 50 | """3x3 convolution with padding""" 51 | fco = _get_fco(F) 52 | return nn.R2Conv(in_type, out_type, 3, 53 | stride=stride, 54 | padding=padding, 55 | dilation=dilation, 56 | bias=bias, 57 | sigma=sigma, 58 | frequencies_cutoff=fco, 59 | initialize=initialize 60 | ) 61 | 62 | 63 | def conv1x1(in_type: nn.FieldType, out_type: nn.FieldType, padding=0, stride=1, dilation=1, bias=False, sigma=None, F=1., initialize=True): 64 | """1x1 convolution""" 65 | fco = _get_fco(F) 66 | return nn.R2Conv(in_type, out_type, 1, 67 | stride=stride, 68 | padding=padding, 69 | dilation=dilation, 70 | bias=bias, 71 | sigma=sigma, 72 | frequencies_cutoff=fco, 73 | initialize=initialize 74 | ) 75 | 76 | 77 | def regular_fiber(gspace: gspaces.GeneralOnR2, planes: int, fixparams: bool = True): 78 | """ build a regular fiber with the specified number of channels""" 79 | assert gspace.fibergroup.order() > 0 80 | N = gspace.fibergroup.order() 81 | planes = planes / N 82 | if fixparams: 83 | planes *= math.sqrt(N * CHANNELS_CONSTANT) 84 | planes = int(planes) 85 | 86 | return nn.FieldType(gspace, [gspace.regular_repr] * planes) 87 | 88 | 89 | def quotient_fiber(gspace: gspaces.GeneralOnR2, planes: int, fixparams: bool = True): 90 | """ build a quotient fiber with the specified number of channels""" 91 | N = gspace.fibergroup.order() 92 | assert N > 0 93 | if isinstance(gspace, gspaces.FlipRot2dOnR2): 94 | n = N/2 95 | subgroups = [] 96 | for axis in [0, round(n/4), round(n/2)]: 97 | subgroups.append((int(axis), 1)) 98 | elif isinstance(gspace, gspaces.Rot2dOnR2): 99 | assert N % 4 == 0 100 | # subgroups = [int(round(N/2)), int(round(N/4))] 101 | subgroups = [2, 4] 102 | elif isinstance(gspace, gspaces.Flip2dOnR2): 103 | subgroups = [2] 104 | else: 105 | raise ValueError(f"Space {gspace} not supported") 106 | 107 | rs = [gspace.quotient_repr(subgroup) for subgroup in 
subgroups] 108 | size = sum([r.size for r in rs]) 109 | planes = planes / size 110 | if fixparams: 111 | planes *= math.sqrt(N * CHANNELS_CONSTANT) 112 | planes = int(planes) 113 | return nn.FieldType(gspace, rs * planes).sorted() 114 | 115 | 116 | def trivial_fiber(gspace: gspaces.GeneralOnR2, planes: int, fixparams: bool = True): 117 | """ build a trivial fiber with the specified number of channels""" 118 | 119 | if fixparams: 120 | planes *= math.sqrt(gspace.fibergroup.order() * CHANNELS_CONSTANT) 121 | planes = int(planes) 122 | return nn.FieldType(gspace, [gspace.trivial_repr] * planes) 123 | 124 | 125 | def mixed_fiber(gspace: gspaces.GeneralOnR2, planes: int, ratio: float, fixparams: bool = True): 126 | 127 | N = gspace.fibergroup.order() 128 | assert N > 0 129 | if isinstance(gspace, gspaces.FlipRot2dOnR2): 130 | subgroup = (0, 1) 131 | elif isinstance(gspace, gspaces.Flip2dOnR2): 132 | subgroup = 1 133 | else: 134 | raise ValueError(f"Space {gspace} not supported") 135 | 136 | qr = gspace.quotient_repr(subgroup) 137 | rr = gspace.regular_repr 138 | 139 | planes = planes / rr.size 140 | 141 | if fixparams: 142 | planes *= math.sqrt(N * CHANNELS_CONSTANT) 143 | 144 | r_planes = int(planes * ratio) 145 | q_planes = int(2*planes * (1-ratio)) 146 | 147 | return nn.FieldType(gspace, [rr] * r_planes + [qr] * q_planes).sorted() 148 | 149 | 150 | def mixed1_fiber(gspace: gspaces.GeneralOnR2, planes: int, fixparams: bool = True): 151 | return mixed_fiber(gspace=gspace, planes=planes, ratio=0.5, fixparams=fixparams) 152 | 153 | 154 | def mixed2_fiber(gspace: gspaces.GeneralOnR2, planes: int, fixparams: bool = True): 155 | return mixed_fiber(gspace=gspace, planes=planes, ratio=0.25, fixparams=fixparams) 156 | 157 | 158 | FIBERS = { 159 | "trivial": trivial_fiber, 160 | "quotient": quotient_fiber, 161 | "regular": regular_fiber, 162 | "mixed1": mixed1_fiber, 163 | "mixed2": mixed2_fiber, 164 | } 165 | 166 | -------------------------------------------------------------------------------- /experiments/models/e2sfcnn.py: -------------------------------------------------------------------------------- 1 | from e2cnn.nn import * 2 | from e2cnn.gspaces import * 3 | 4 | import torch 5 | import torch.nn as nn 6 | import numpy as np 7 | 8 | 9 | class E2SFCNN(torch.nn.Module): 10 | 11 | def __init__(self, n_channels, n_classes, 12 | N=None, 13 | restrict: int = -1, 14 | fix_param: bool = False, 15 | fco: float = 0.8, 16 | p_drop_fully: float = 0.3, 17 | J: int = 0, 18 | sigma: float = 0.6, 19 | sgsize: int = None, 20 | flip: bool = True, 21 | ): 22 | 23 | super(E2SFCNN, self).__init__() 24 | 25 | if N is None: 26 | N = 16 27 | 28 | assert N > 1 29 | 30 | self.n_channels = n_channels 31 | self.n_classes = n_classes 32 | 33 | # build the group O(2) or D_N depending on the number N of rotations specified 34 | if N > 1: 35 | self.gspace = FlipRot2dOnR2(N) 36 | elif N == 1: 37 | self.gspace = Flip2dOnR2() 38 | else: 39 | raise ValueError(N) 40 | 41 | # if flips are not required, immediately restrict to the SO(2) or C_N subgroup 42 | if not flip: 43 | if N != 1: 44 | sg = (None, N) 45 | else: 46 | sg = 1 47 | self.gspace, _, _ = self.gspace.restrict(sg) 48 | 49 | # id of the subgroup if group restriction is applied through the network 50 | if sgsize is not None: 51 | self.sgid = sgsize 52 | else: 53 | self.sgid = N 54 | 55 | if flip and N != 1: 56 | self.sgid = (None, self.sgid) 57 | 58 | if fco is not None and fco > 0.: 59 | fco *= np.pi 60 | frequencies_cutoff = fco 61 | 62 | eq_layers = [] 63 | 64 | LAYER 
= 0
65 |
66 | def build_layers(r1: FieldType, C: int, s: int, padding: int = 0, pooling: int = None, orientation_pooling: bool = False):
67 |
68 | gspace = r1.gspace
69 |
70 | if fix_param:
71 | # to keep the number of parameters more or less constant when changing groups
72 | # (more precisely, we try to keep them close to the number of parameters in the original SFCNN)
73 | C /= np.sqrt(gspace.fibergroup.order()/16)
74 | C = int(C)
75 |
76 | layers = []
77 |
78 | r2 = FieldType(gspace, [gspace.representations['regular']] * C)
79 |
80 | cl = R2Conv(r1, r2, s,
81 | frequencies_cutoff=frequencies_cutoff,
82 | padding=padding,
83 | sigma=sigma,
84 | maximum_offset=J)
85 | layers.append(cl)
86 |
87 | if restrict == LAYER:
88 | layers.append(RestrictionModule(layers[-1].out_type, self.sgid))
89 | layers.append(DisentangleModule(layers[-1].out_type))
90 |
91 | bn = InnerBatchNorm(layers[-1].out_type)
92 | layers.append(bn)
93 |
94 | if orientation_pooling:
95 | pl = GroupPooling(layers[-1].out_type)
96 | layers.append(pl)
97 |
98 | if pooling is not None:
99 | pl = PointwiseMaxPool(layers[-1].out_type, pooling)
100 | layers.append(pl)
101 |
102 | nnl = ELU(layers[-1].out_type, inplace=True)
103 | layers.append(nnl)
104 |
105 | return layers
106 |
107 | if restrict == LAYER:
108 | self.gspace, _, _ = self.gspace.restrict(self.sgid)
109 |
110 | r1 = FieldType(self.gspace, [self.gspace.trivial_repr] * n_channels)
111 |
112 | # 28 px
113 | # Convolutional Layer 1
114 |
115 | LAYER += 1
116 | # TODO no padding here? with such a large filter???
117 | eq_layers += build_layers(r1, 24, 9, 0, None)
118 |
119 | # Convolutional Layer 2
120 | LAYER += 1
121 | eq_layers += build_layers(eq_layers[-1].out_type, 32, 7, 3, 2)
122 |
123 | # TODO this number is right iff we used padding in the first layer!
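# size bookkeeping, for reference: with stride 1, out = in - k + 1 + 2*p, so
# layer 1 (9x9, padding 0) maps 28 -> 20 px, layer 2 (7x7, padding 3, pool 2)
# maps 20 -> 10 px, layers 3-4 keep 10 px and pool to 5 px, layer 5 keeps 5 px,
# and layer 6 (5x5, padding 0) maps 5 -> 1 px; the 28/14/7 px annotations
# around here assume padding 4 in the first layer (which is what the TODOs
# flag), but the adaptive max pooling at the end absorbs the mismatch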
124 | # 14 px 125 | # Convolutional Layer 3 126 | LAYER += 1 127 | eq_layers += build_layers(eq_layers[-1].out_type, 36, 7, 3, None) 128 | 129 | # Convolutional Layer 4 130 | LAYER += 1 131 | eq_layers += build_layers(eq_layers[-1].out_type, 36, 7, 3, 2) 132 | 133 | # 7 px 134 | 135 | # Convolutional Layer 5 136 | LAYER += 1 137 | eq_layers += build_layers(eq_layers[-1].out_type, 64, 7, 3) 138 | 139 | # Convolutional Layer 6 140 | LAYER += 1 141 | eq_layers += build_layers(eq_layers[-1].out_type, 96, 5, 0, None, True) 142 | 143 | # Adaptive Pooling 144 | mpl = PointwiseAdaptiveMaxPool(eq_layers[-1].out_type, 1) 145 | eq_layers.append(mpl) 146 | 147 | # 1 px 148 | 149 | # c = 96 150 | c = eq_layers[-1].out_type.size 151 | 152 | self.in_repr = eq_layers[0].in_type 153 | self.eq_layers = torch.nn.ModuleList(eq_layers) 154 | 155 | # Fully Connected 156 | 157 | self.fully_net = nn.Sequential( 158 | nn.Dropout(p=p_drop_fully), 159 | nn.Linear(c, 96), 160 | nn.BatchNorm1d(96), 161 | nn.ELU(inplace=True), 162 | 163 | nn.Dropout(p=p_drop_fully), 164 | nn.Linear(96, 96), 165 | nn.BatchNorm1d(96), 166 | nn.ELU(inplace=True), 167 | 168 | nn.Dropout(p=p_drop_fully), 169 | nn.Linear(96, n_classes), 170 | ) 171 | 172 | def forward(self, input): 173 | x = GeometricTensor(input, self.in_repr) 174 | 175 | for layer in self.eq_layers: 176 | x = layer(x) 177 | 178 | x = self.fully_net(x.tensor.reshape(x.tensor.shape[0], -1)) 179 | 180 | return x 181 | -------------------------------------------------------------------------------- /experiments/plot_exps.py: -------------------------------------------------------------------------------- 1 | 2 | import pandas as pd 3 | import argparse 4 | import os 5 | import matplotlib 6 | 7 | import utils 8 | 9 | if "DISPLAY" not in os.environ: 10 | matplotlib.use('Agg') 11 | 12 | import matplotlib.pyplot as plt 13 | 14 | SHOW_PLOT = False 15 | SAVE_PLOT = True 16 | 17 | RESHUFFLE = False 18 | AUGMENT_TRAIN = False 19 | 20 | colors = { 21 | "train": "blue", 22 | "valid": "green", 23 | "test": "red" 24 | } 25 | 26 | 27 | def plot_mean_with_variance(axis, data, label): 28 | mean = data.mean() 29 | std = data.std() 30 | axis.plot(mean, label=label, color=colors[label]) 31 | axis.fill_between( 32 | mean.index, 33 | mean - std, 34 | mean + std, 35 | color=colors[label], 36 | alpha=0.1 37 | ) 38 | 39 | 40 | def plot(logs, plotpath=None, show=False, outfig=None): 41 | 42 | if isinstance(logs, str) and os.path.isfile(logs): 43 | logs = utils.retrieve_logs(logs) 44 | elif not isinstance(logs, pd.DataFrame): 45 | raise ValueError() 46 | 47 | if outfig is None: 48 | figure, (loss_axis, acc_axis) = plt.subplots(1, 2, figsize=(10, 4)) 49 | else: 50 | figure, (loss_axis, acc_axis) = outfig 51 | 52 | train = logs[logs.split.str.startswith("train")].groupby("iteration") 53 | valid = logs[logs.split == "valid"].groupby("iteration") 54 | test = logs[logs.split == "test"].groupby("iteration") 55 | 56 | #################### Plot Loss trends #################### 57 | 58 | loss_axis.cla() 59 | 60 | plot_mean_with_variance(loss_axis, train.loss, "train") 61 | if len(valid) > 0: 62 | plot_mean_with_variance(loss_axis, valid.loss, "valid") 63 | if len(test) > 0: 64 | plot_mean_with_variance(loss_axis, test.loss, "test") 65 | 66 | loss_axis.legend() 67 | loss_axis.set_xlabel('iterations') 68 | loss_axis.set_ylabel('Loss') 69 | 70 | #################### Plot Accuracy trends #################### 71 | 72 | acc_axis.cla() 73 | 74 | plot_mean_with_variance(acc_axis, train.accuracy, "train") 75 | if 
len(valid) > 0: 76 | plot_mean_with_variance(acc_axis, valid.accuracy, "valid") 77 | if len(test) > 0: 78 | plot_mean_with_variance(acc_axis, test.accuracy, "test") 79 | 80 | ################## Test score ######################## 81 | 82 | test = logs[logs.split == "test"] 83 | 84 | xmax = logs.iteration.max() 85 | 86 | if len(test) > 0: 87 | best_acc = test.accuracy.max() 88 | acc_axis.hlines(best_acc, xmin=0, xmax=xmax, linewidth=0.5, linestyles='--', label='Max Test Accuracy') 89 | acc_axis.set_yticks(list(acc_axis.get_yticks()) + [best_acc]) 90 | 91 | if len(test) > 1: 92 | mean_acc = test.accuracy.mean() 93 | mean_std = test.accuracy.std() 94 | acc_axis.hlines(mean_acc, xmin=0, xmax=xmax, linewidth=0.5, color=colors["test"], label='Mean Test Accuracy') 95 | acc_axis.fill_between([0, xmax], [mean_acc - mean_std] * 2, [mean_acc + mean_std] * 2, color=colors["test"], 96 | alpha=0.1) 97 | acc_axis.set_yticks(list(acc_axis.get_yticks()) + [mean_acc]) 98 | 99 | acc_axis.legend() 100 | acc_axis.set_xlabel('iterations') 101 | acc_axis.set_ylabel('Accuracy') 102 | 103 | figure.tight_layout() 104 | plt.draw() 105 | 106 | if plotpath is not None: 107 | figure.savefig(plotpath, format='svg', dpi=256, bbox_inches="tight") 108 | 109 | if show: 110 | figure.show() 111 | plt.pause(0.01) 112 | 113 | 114 | ################################################################################ 115 | ################################################################################ 116 | 117 | 118 | if __name__ == "__main__": 119 | # Parse training configuration 120 | parser = argparse.ArgumentParser() 121 | 122 | # Dataset params 123 | parser.add_argument('--dataset', type=str, help='The name of the dataset to use') 124 | parser.add_argument('--augment', dest="augment", action="store_true", 125 | help='Augment the training set with rotated images') 126 | parser.set_defaults(augment=AUGMENT_TRAIN) 127 | 128 | parser.add_argument('--reshuffle', dest="reshuffle", action="store_true", 129 | help='Reshuffle train and valid splits instead of using the default split') 130 | parser.set_defaults(reshuffle=RESHUFFLE) 131 | 132 | # Model params 133 | parser.add_argument('--model', type=str, help='The name of the model to use') 134 | parser.add_argument('--N', type=int, help='Size of cyclic group for GCNN and maximum frequency for HNET') 135 | parser.add_argument('--flip', dest="flip", action="store_true", 136 | help='Use also reflection equivariance in the EXP model') 137 | parser.set_defaults(flip=False) 138 | parser.add_argument('--sgsize', type=int, default=None, 139 | help='Number of rotations in the subgroup to restrict to in the EXP e2sfcnn models') 140 | parser.add_argument('--fixparams', dest="fixparams", action="store_true", 141 | help='Keep the number of parameters of the model fixed by adjusting its topology') 142 | parser.set_defaults(fixparams=False) 143 | parser.add_argument('--F', type=float, default=0.8, help='Frequency cut-off: maximum frequency at radius "r" is "F*r"') 144 | parser.add_argument('--sigma', type=float, default=0.6, help='Width of the rings building the bases (std of the gaussian window)') 145 | parser.add_argument('--J', type=int, default=0, help='Number of additional frequencies in the interwiners of finite groups') 146 | parser.add_argument('--restrict', type=int, default=-1, help='Layer where to restrict SFCNN from E(2) to SE(2)') 147 | 148 | # plot configs 149 | parser.add_argument('--show', dest="show", action="store_true", help='Show the plots during execution') 150 | 
parser.set_defaults(show=SHOW_PLOT)
151 |
152 | parser.add_argument('--store_plot', dest="store_plot", action="store_true", help='Save the plots in a file or not')
153 | parser.set_defaults(store_plot=SAVE_PLOT)
154 |
155 | config = parser.parse_args()
156 |
157 | # Draw the plot
158 | logs_file = utils.logs_path(config)
159 | plotpath = utils.plot_path(config)
160 | plot(logs_file, plotpath, config.show)
161 |
--------------------------------------------------------------------------------
/experiments/datasets/STL10/data_loader_stl10frac.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | from torchvision import datasets
5 | from torchvision import transforms
6 | from torch.utils.data.sampler import SubsetRandomSampler
7 |
8 | from .data_loader_stl10 import DATA_DIR, CIFAR_MEAN, CIFAR_STD, MEAN, STD
9 | from .data_loader_stl10 import Cutout
10 |
11 | from torch.utils.data import Dataset, ConcatDataset
12 |
13 |
14 | class TransformedSubsetDataset(Dataset):
15 | def __init__(self, dataset: Dataset, transform, indices):
16 | assert max(indices) < len(dataset)
17 | assert min(indices) >= 0
18 |
19 | self.dataset = dataset
20 | self.transform = transform
21 | self.indices = list(indices)
22 |
23 | def __getitem__(self, index):
24 | x, t = self.dataset[self.indices[index]]
25 | return self.transform(x), t
26 |
27 | def __len__(self):
28 | return len(self.indices)
29 |
30 |
31 | def __balanced_subdataset_idxs(train_size, validation_size, labels, reshuffle):
32 | num_train = len(labels)
33 |
34 | test_size = num_train - train_size - validation_size
35 | assert test_size >= 0
36 |
37 | classes = set(labels)
38 |
39 | labels_idxs = {c: list() for c in classes}
40 | ratios = {c: 0. for c in classes}
41 |
42 | for i, l in enumerate(labels):
43 | labels_idxs[l].append(i)
44 | ratios[l] += 1.  # per-class count; normalized below to take a stratified split
45 | 46 | train_idx = list() 47 | valid_idx = list() 48 | test_idx = list() 49 | for c in classes: 50 | ratios[c] /= num_train 51 | 52 | if reshuffle: 53 | np.random.shuffle(labels_idxs[c]) 54 | 55 | ts = int(round(train_size * ratios[c])) 56 | vs = int(round(validation_size * ratios[c])) 57 | 58 | valid_idx += labels_idxs[c][:vs] 59 | train_idx += labels_idxs[c][vs:vs+ts] 60 | test_idx += labels_idxs[c][vs+ts:] 61 | 62 | return train_idx, valid_idx, test_idx 63 | 64 | 65 | def __build_stl10_frac_loaders(size, 66 | batch_size, 67 | eval_batchsize, 68 | validation=True, 69 | num_workers=8, 70 | augment=False, 71 | reshuffle=True, 72 | mean=MEAN, 73 | std=STD, 74 | ): 75 | 76 | normalize = transforms.Normalize( 77 | mean=mean, 78 | std=std, 79 | ) 80 | 81 | # define transforms 82 | valid_transform = transforms.Compose([ 83 | transforms.ToTensor(), 84 | normalize, 85 | ]) 86 | 87 | if augment: 88 | train_transform = transforms.Compose([ 89 | transforms.RandomCrop(96, padding=12), 90 | transforms.RandomHorizontalFlip(), 91 | transforms.ToTensor(), 92 | # Cutout(32), 93 | Cutout(60), 94 | normalize, 95 | ]) 96 | else: 97 | train_transform = transforms.Compose([ 98 | transforms.ToTensor(), 99 | # Cutout(24), 100 | Cutout(48), 101 | normalize, 102 | ]) 103 | # train_transform = transforms.Compose([ 104 | # transforms.ToTensor(), 105 | # normalize, 106 | # ]) 107 | 108 | # load the dataset 109 | train = datasets.STL10( 110 | root=DATA_DIR, split="train", 111 | download=True, transform=None, 112 | ) 113 | test = datasets.STL10( 114 | root=DATA_DIR, split="test", 115 | download=True, transform=None, 116 | ) 117 | total_dataset = ConcatDataset([train, test]) 118 | labels = np.concatenate([train.labels, test.labels]) 119 | 120 | if validation: 121 | validation_size = 1000 122 | else: 123 | validation_size = 0 124 | 125 | train_idx, valid_idx, test_idx = __balanced_subdataset_idxs(size, validation_size, labels, reshuffle) 126 | 127 | train_dataset = TransformedSubsetDataset(total_dataset, train_transform, train_idx) 128 | train_loader = torch.utils.data.DataLoader( 129 | train_dataset, batch_size=batch_size, shuffle=True, 130 | num_workers=num_workers, pin_memory=True, 131 | ) 132 | 133 | if validation: 134 | valid_dataset = TransformedSubsetDataset(total_dataset, valid_transform, valid_idx) 135 | valid_loader = torch.utils.data.DataLoader( 136 | valid_dataset, batch_size=eval_batchsize, shuffle=False, 137 | num_workers=num_workers, pin_memory=True, 138 | ) 139 | else: 140 | valid_loader = None 141 | 142 | test_dataset = TransformedSubsetDataset(total_dataset, valid_transform, test_idx) 143 | test_loader = torch.utils.data.DataLoader( 144 | test_dataset, batch_size=eval_batchsize, shuffle=False, 145 | num_workers=num_workers, pin_memory=True, 146 | ) 147 | 148 | n_inputs = 3 149 | n_classes = 10 150 | 151 | return train_loader, valid_loader, test_loader, n_inputs, n_classes 152 | 153 | 154 | def build_stl10_frac_loaders(size, 155 | batch_size, 156 | eval_batchsize, 157 | validation=True, 158 | num_workers=8, 159 | augment=False, 160 | reshuffle=True, 161 | ): 162 | return __build_stl10_frac_loaders(size, batch_size, eval_batchsize, validation, num_workers, augment, reshuffle, 163 | mean=MEAN, std=STD) 164 | 165 | 166 | def build_stl10cif_frac_loaders(size, 167 | batch_size, 168 | eval_batchsize, 169 | validation=True, 170 | num_workers=8, 171 | augment=False, 172 | reshuffle=True, 173 | ): 174 | return __build_stl10_frac_loaders(size, batch_size, eval_batchsize, validation, num_workers, augment, 
reshuffle, 175 | mean=CIFAR_MEAN, std=CIFAR_STD) 176 | 177 | -------------------------------------------------------------------------------- /experiments/models/e2sfcnn_quotient.py: -------------------------------------------------------------------------------- 1 | from e2cnn.nn import * 2 | from e2cnn.gspaces import * 3 | 4 | import torch 5 | import torch.nn as nn 6 | import numpy as np 7 | 8 | class E2SFCNN_QUOT(torch.nn.Module): 9 | 10 | def __init__(self, n_channels, n_classes, 11 | N=None, 12 | restrict: int = -1, 13 | fix_param: bool = False, 14 | fco: float = 0.8, 15 | p_drop_fully: float = 0.3, 16 | J: int = 0, 17 | sigma: float = 0.6, 18 | sgsize: int = None, 19 | flip: bool = True, 20 | ): 21 | 22 | super(E2SFCNN_QUOT, self).__init__() 23 | 24 | if N is None: 25 | N = 16 26 | 27 | assert N > 1 28 | 29 | self.N = N 30 | self.n_channels = n_channels 31 | self.n_classes = n_classes 32 | 33 | # build the group O(2) or D_N depending on the number N of rotations specified 34 | if N > 1: 35 | self.gspace = FlipRot2dOnR2(N) 36 | elif N == 1: 37 | self.gspace = Flip2dOnR2() 38 | else: 39 | raise ValueError(N) 40 | 41 | # if flips are not required, immediately restrict to the SO(2) or C_N subgroup 42 | if not flip: 43 | if N != 1: 44 | sg = (None, N) 45 | else: 46 | sg = 1 47 | self.gspace, _, _ = self.gspace.restrict(sg) 48 | 49 | # id of the subgroup if group restriction is applied through the network 50 | if sgsize is not None: 51 | self.sgid = sgsize 52 | else: 53 | self.sgid = N 54 | 55 | if flip and N != 1: 56 | self.sgid = (None, self.sgid) 57 | 58 | if fco is not None and fco > 0.: 59 | fco *= np.pi 60 | self.frequencies_cutoff = fco 61 | self.sigma = sigma 62 | self.J = J 63 | self.fix_param = fix_param 64 | self.restrict = restrict 65 | 66 | eq_layers = [] 67 | 68 | self.LAYER = 0 69 | 70 | if restrict == self.LAYER: 71 | gspace, _, _ = self.gspace.restrict(self.sgid) 72 | 73 | r1 = FieldType(gspace, [gspace.trivial_repr] * n_channels) 74 | 75 | # 28 px 76 | # Convolutional Layer 1 77 | 78 | self.LAYER += 1 79 | eq_layers += self.build_layer_quotient(r1, 24, 9, 0, None) 80 | 81 | # Convolutional Layer 2 82 | self.LAYER += 1 83 | eq_layers += self.build_layer_quotient(eq_layers[-1].out_type, 32, 7, 3, 2) 84 | 85 | # 14 px 86 | # Convolutional Layer 3 87 | self.LAYER += 1 88 | eq_layers += self.build_layer_quotient(eq_layers[-1].out_type, 36, 7, 3, None) 89 | 90 | # Convolutional Layer 4 91 | self.LAYER += 1 92 | eq_layers += self.build_layer_quotient(eq_layers[-1].out_type, 36, 7, 3, 2) 93 | 94 | # 7 px 95 | # Convolutional Layer 5 96 | self.LAYER += 1 97 | eq_layers += self.build_layer_quotient(eq_layers[-1].out_type, 64, 7, 3) 98 | 99 | # Convolutional Layer 6 100 | self.LAYER += 1 101 | eq_layers += self.build_layer_quotient(eq_layers[-1].out_type, 96, 5, 0, None, True) 102 | 103 | # Adaptive Pooling 104 | mpl = PointwiseAdaptiveMaxPool(eq_layers[-1].out_type, 1) 105 | eq_layers.append(mpl) 106 | 107 | # 1 px 108 | 109 | # c = 96 110 | c = eq_layers[-1].out_type.size 111 | 112 | self.in_repr = eq_layers[0].in_type 113 | self.eq_layers = torch.nn.ModuleList(eq_layers) 114 | 115 | # Fully Connected 116 | self.fully_net = nn.Sequential( 117 | nn.Dropout(p=p_drop_fully), 118 | nn.Linear(c, 96), 119 | nn.BatchNorm1d(96), 120 | nn.ELU(inplace=True), 121 | 122 | nn.Dropout(p=p_drop_fully), 123 | nn.Linear(96, 96), 124 | nn.BatchNorm1d(96), 125 | nn.ELU(inplace=True), 126 | 127 | nn.Dropout(p=p_drop_fully), 128 | nn.Linear(96, n_classes), 129 | ) 130 | 131 | def forward(self, input: 
torch.tensor): 132 | x = GeometricTensor(input, self.in_repr) 133 | 134 | for layer in self.eq_layers: 135 | x = layer(x) 136 | 137 | x = self.fully_net(x.tensor.reshape(x.tensor.shape[0], -1)) 138 | 139 | return x 140 | 141 | def build_quotient_feature_type(self, gspace): 142 | 143 | assert gspace.fibergroup.order() > 0 144 | if isinstance(gspace, FlipRot2dOnR2): 145 | n = int(gspace.fibergroup.order() / 2) 146 | repr = [gspace.regular_repr] * 5 147 | for i in [0, round(n / 4), round(n / 2)]: 148 | repr += [gspace.quotient_repr((int(i), 1))] * 2 149 | repr += [gspace.quotient_repr((None, int(n / 2)))] * 2 150 | repr += [gspace.trivial_repr] * int(gspace.fibergroup.order() / 4) 151 | elif isinstance(gspace, Rot2dOnR2): 152 | n = gspace.fibergroup.order() 153 | repr = [gspace.regular_repr] * 5 154 | repr += [gspace.quotient_repr(int(round(n / 2)))] * 2 155 | repr += [gspace.quotient_repr(int(round(n / 4)))] * 2 156 | repr += [gspace.trivial_repr] * int(gspace.fibergroup.order() / 4) 157 | else: 158 | repr = [gspace.regular_repr] 159 | 160 | return repr 161 | 162 | def build_layer_quotient(self, r1: FieldType, C: int, s: int, padding: int = 0, pooling: int = None, 163 | orientation_pooling: bool = False): 164 | 165 | gspace = r1.gspace 166 | 167 | if self.fix_param and not orientation_pooling and self.LAYER > 1: 168 | # to keep number of parameters more or less constant when changing groups 169 | # (more precisely, we try to keep them close to the number of parameters in the original SFCNN) 170 | t = gspace.fibergroup.order() / 16 171 | C = C / np.sqrt(t) 172 | 173 | layers = [] 174 | 175 | repr = self.build_quotient_feature_type(gspace) 176 | 177 | C /= sum([r.size for r in repr]) / gspace.fibergroup.order() 178 | 179 | C = int(round(C)) 180 | 181 | r2 = FieldType(gspace, repr * C).sorted() 182 | 183 | cl = R2Conv(r1, r2, s, 184 | frequencies_cutoff=self.frequencies_cutoff, 185 | padding=padding, 186 | sigma=self.sigma, 187 | maximum_offset=self.J) 188 | layers.append(cl) 189 | 190 | if self.restrict == self.LAYER: 191 | layers.append(RestrictionModule(layers[-1].out_type, self.sgid)) 192 | layers.append(DisentangleModule(layers[-1].out_type)) 193 | 194 | if orientation_pooling: 195 | pl = GroupPooling(layers[-1].out_type) 196 | layers.append(pl) 197 | 198 | bn = InnerBatchNorm(layers[-1].out_type) 199 | layers.append(bn) 200 | nnl = ELU(layers[-1].out_type, inplace=True) 201 | layers.append(nnl) 202 | 203 | if pooling is not None: 204 | pl = PointwiseMaxPool(layers[-1].out_type, pooling) 205 | layers.append(pl) 206 | 207 | return layers 208 | 209 | -------------------------------------------------------------------------------- /experiments/optimizers_L1L2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modify optimizers for including L1 loss 3 | Per default they only support weight decay (i.e. 
L2) 4 | """ 5 | 6 | import math 7 | from torch.optim import Optimizer 8 | 9 | ######################################################################################################################## 10 | # Code adapted from 11 | # - https://pytorch.org/docs/stable/_modules/torch/optim/sgd.html#SGD 12 | # - https://pytorch.org/docs/stable/_modules/torch/optim/adam.html#Adam 13 | # which have the following License: 14 | ######################################################################################################################## 15 | # From PyTorch: 16 | # 17 | # Copyright (c) 2016- Facebook, Inc (Adam Paszke) 18 | # Copyright (c) 2014- Facebook, Inc (Soumith Chintala) 19 | # Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) 20 | # Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) 21 | # Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) 22 | # Copyright (c) 2011-2013 NYU (Clement Farabet) 23 | # Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) 24 | # Copyright (c) 2006 Idiap Research Institute (Samy Bengio) 25 | # Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) 26 | # 27 | # From Caffe2: 28 | # 29 | # Copyright (c) 2016-present, Facebook Inc. All rights reserved. 30 | # 31 | # All contributions by Facebook: 32 | # Copyright (c) 2016 Facebook Inc. 33 | # 34 | # All contributions by Google: 35 | # Copyright (c) 2015 Google Inc. 36 | # All rights reserved. 37 | # 38 | # All contributions by Yangqing Jia: 39 | # Copyright (c) 2015 Yangqing Jia 40 | # All rights reserved. 41 | # 42 | # All contributions by Kakao Brain: 43 | # Copyright 2019-2020 Kakao Brain 44 | # 45 | # All contributions from Caffe: 46 | # Copyright(c) 2013, 2014, 2015, the respective contributors 47 | # All rights reserved. 48 | # 49 | # All other contributions: 50 | # Copyright(c) 2015, 2016 the respective contributors 51 | # All rights reserved. 52 | # 53 | # Caffe2 uses a copyright model similar to Caffe: each contributor holds 54 | # copyright over their contributions to Caffe2. The project versioning records 55 | # all such contribution and copyright details. If a contributor wants to further 56 | # mark their specific copyright on a particular contribution, they should 57 | # indicate their copyright solely in the commit message of the change when it is 58 | # committed. 59 | # 60 | # All rights reserved. 61 | # 62 | # Redistribution and use in source and binary forms, with or without 63 | # modification, are permitted provided that the following conditions are met: 64 | # 65 | # 1. Redistributions of source code must retain the above copyright 66 | # notice, this list of conditions and the following disclaimer. 67 | # 68 | # 2. Redistributions in binary form must reproduce the above copyright 69 | # notice, this list of conditions and the following disclaimer in the 70 | # documentation and/or other materials provided with the distribution. 71 | # 72 | # 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America 73 | # and IDIAP Research Institute nor the names of its contributors may be 74 | # used to endorse or promote products derived from this software without 75 | # specific prior written permission. 
76 | # 77 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 78 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 79 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 80 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 81 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 82 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 83 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 84 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 85 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 86 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 87 | # POSSIBILITY OF SUCH DAMAGE. 88 | ######################################################################################################################## 89 | 90 | class SGD(Optimizer): 91 | r"""Implements stochastic gradient descent (optionally with momentum). 92 | 93 | Nesterov momentum is based on the formula from 94 | `On the importance of initialization and momentum in deep learning`__. 95 | 96 | Args: 97 | params (iterable): iterable of parameters to optimize or dicts defining 98 | parameter groups 99 | lr (float): learning rate 100 | momentum (float, optional): momentum factor (default: 0) 101 | lamb_L1, lamb_L2 (float, optional): coefficients of the L1 and L2 penalties (default: 0) 102 | dampening (float, optional): dampening for momentum (default: 0) 103 | nesterov (bool, optional): enables Nesterov momentum (default: False) 104 | 105 | Example: 106 | >>> optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9, lamb_L1=1e-5, lamb_L2=5e-4) 107 | >>> optimizer.zero_grad() 108 | >>> loss_fn(model(input), target).backward() 109 | >>> optimizer.step() 110 | 111 | __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf 112 | 113 | .. note:: 114 | The implementation of SGD with Momentum/Nesterov subtly differs from 115 | Sutskever et al. and implementations in some other frameworks. 116 | 117 | Considering the specific case of Momentum, the update can be written as 118 | 119 | .. math:: 120 | v = \rho * v + g \\ 121 | p = p - lr * v 122 | 123 | where p, g, v and :math:`\rho` denote the parameters, gradient, velocity, and 124 | momentum respectively. 125 | 126 | This is in contrast to Sutskever et al. and 127 | other frameworks which employ an update of the form 128 | 129 | .. math:: 130 | v = \rho * v + lr * g \\ 131 | p = p - v 132 | 133 | The Nesterov version is analogously modified. 134 | """ 135 | 136 | def __init__(self, params, lr, momentum=0, dampening=0, lamb_L1=0, lamb_L2=0, nesterov=False): 137 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening, lamb_L1=lamb_L1, lamb_L2=lamb_L2, nesterov=nesterov) 138 | if nesterov and (momentum <= 0 or dampening != 0): 139 | raise ValueError("Nesterov momentum requires a momentum and zero dampening") 140 | super(SGD, self).__init__(params, defaults) 141 | 142 | def __setstate__(self, state): 143 | super(SGD, self).__setstate__(state) 144 | for group in self.param_groups: 145 | group.setdefault('nesterov', False) 146 | 147 | def step(self, closure=None): 148 | """Performs a single optimization step. 149 | Arguments: 150 | closure (callable, optional): A closure that reevaluates the model 151 | and returns the loss.
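A minimal closure sketch for illustration (model, loss_fn, input and target are placeholders, not names defined in this module):
>>> def closure():
...     optimizer.zero_grad()
...     loss = loss_fn(model(input), target)
...     loss.backward()
...     return loss
>>> loss = optimizer.step(closure)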
152 | """ 153 | loss = None 154 | if closure is not None: 155 | loss = closure() 156 | 157 | for group in self.param_groups: 158 | momentum = group['momentum'] 159 | dampening = group['dampening'] 160 | nesterov = group['nesterov'] 161 | 162 | lamb_L1 = group['lamb_L1'] 163 | lamb_L2 = group['lamb_L2'] 164 | 165 | for p in group['params']: 166 | if p.grad is None: 167 | continue 168 | d_p = p.grad.data 169 | if lamb_L1 != 0: 170 | d_p.add_(lamb_L1, p.sign().data) 171 | if lamb_L2 != 0: 172 | d_p.add_(lamb_L2, p.data) 173 | if momentum != 0: 174 | param_state = self.state[p] 175 | if 'momentum_buffer' not in param_state: 176 | buf = param_state['momentum_buffer'] = d_p.clone() 177 | else: 178 | buf = param_state['momentum_buffer'] 179 | buf.mul_(momentum).add_(1 - dampening, d_p) 180 | if nesterov: 181 | d_p = d_p.add(momentum, buf) 182 | else: 183 | d_p = buf 184 | 185 | p.data.add_(-group['lr'], d_p) 186 | 187 | return loss 188 | 189 | 190 | 191 | class Adam(Optimizer): 192 | """Implements Adam algorithm. 193 | 194 | It has been proposed in `Adam: A Method for Stochastic Optimization`_. 195 | 196 | Arguments: 197 | params (iterable): iterable of parameters to optimize or dicts defining 198 | parameter groups 199 | lr (float, optional): learning rate (default: 1e-3) 200 | betas (Tuple[float, float], optional): coefficients used for computing 201 | running averages of gradient and its square (default: (0.9, 0.999)) 202 | eps (float, optional): term added to the denominator to improve 203 | numerical stability (default: 1e-8) 204 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 205 | 206 | .. _Adam\: A Method for Stochastic Optimization: 207 | https://arxiv.org/abs/1412.6980 208 | """ 209 | 210 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, lamb_L1=0, lamb_L2=0): 211 | defaults = dict(lr=lr, betas=betas, eps=eps, lamb_L1=lamb_L1, lamb_L2=lamb_L2) 212 | super(Adam, self).__init__(params, defaults) 213 | 214 | def step(self, closure=None): 215 | """Performs a single optimization step. 216 | 217 | Arguments: 218 | closure (callable, optional): A closure that reevaluates the model 219 | and returns the loss. 
220 | """ 221 | loss = None 222 | if closure is not None: 223 | loss = closure() 224 | 225 | for group in self.param_groups: 226 | for p in group['params']: 227 | if p.grad is None: 228 | continue 229 | grad = p.grad.data 230 | state = self.state[p] 231 | 232 | # State initialization 233 | if len(state) == 0: 234 | state['step'] = 0 235 | # Exponential moving average of gradient values 236 | state['exp_avg'] = grad.new().resize_as_(grad).zero_() 237 | # Exponential moving average of squared gradient values 238 | state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_() 239 | 240 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 241 | beta1, beta2 = group['betas'] 242 | 243 | state['step'] += 1 244 | 245 | if group['lamb_L1'] != 0: 246 | grad.add_(group['lamb_L1'], p.sign().data) 247 | if group['lamb_L2'] != 0: 248 | grad.add_(group['lamb_L2'], p.data) 249 | 250 | # Decay the first and second moment running average coefficient 251 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 252 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 253 | 254 | denom = exp_avg_sq.sqrt().add_(group['eps']) 255 | 256 | bias_correction1 = 1 - beta1 ** state['step'] 257 | bias_correction2 = 1 - beta2 ** state['step'] 258 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 259 | 260 | p.data.addcdiv_(-step_size, exp_avg, denom) 261 | 262 | return loss 263 | 264 | 265 | -------------------------------------------------------------------------------- /experiments/datasets/cifar100/autoaugment.py: -------------------------------------------------------------------------------- 1 | ######################################################################################################################## 2 | # Code adapted from https://github.com/DeepVoltaire/AutoAugment 3 | # with the following MIT Licence: 4 | ######################################################################################################################## 5 | # 6 | # MIT License 7 | # 8 | # Copyright (c) 2018 Philip Popien 9 | # 10 | # Permission is hereby granted, free of charge, to any person obtaining a copy 11 | # of this software and associated documentation files (the "Software"), to deal 12 | # in the Software without restriction, including without limitation the rights 13 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | # copies of the Software, and to permit persons to whom the Software is 15 | # furnished to do so, subject to the following conditions: 16 | # 17 | # The above copyright notice and this permission notice shall be included in all 18 | # copies or substantial portions of the Software. 19 | # 20 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | # SOFTWARE. 27 | # 28 | ######################################################################################################################## 29 | from PIL import Image, ImageEnhance, ImageOps 30 | import numpy as np 31 | import random 32 | 33 | 34 | class ImageNetPolicy(object): 35 | """ Randomly choose one of the best 24 Sub-policies on ImageNet. 
36 | 37 | Example: 38 | >>> policy = ImageNetPolicy() 39 | >>> transformed = policy(image) 40 | 41 | Example as a PyTorch Transform: 42 | >>> transform=transforms.Compose([ 43 | >>> transforms.Resize(256), 44 | >>> ImageNetPolicy(), 45 | >>> transforms.ToTensor()]) 46 | """ 47 | 48 | def __init__(self, fillcolor=(128, 128, 128)): 49 | self.policies = [ 50 | SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor), 51 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 52 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor), 53 | SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor), 54 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 55 | 56 | SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor), 57 | SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor), 58 | SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor), 59 | SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor), 60 | SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor), 61 | 62 | SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor), 63 | SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor), 64 | SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor), 65 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 66 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), 67 | 68 | SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor), 69 | SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor), 70 | SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor), 71 | SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor), 72 | SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor), 73 | 74 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 75 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 76 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 77 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor) 78 | ] 79 | 80 | def __call__(self, img): 81 | policy_idx = random.randint(0, len(self.policies) - 1) 82 | return self.policies[policy_idx](img) 83 | 84 | def __repr__(self): 85 | return "AutoAugment ImageNet Policy" 86 | 87 | 88 | class CIFAR10Policy(object): 89 | """ Randomly choose one of the best 25 Sub-policies on CIFAR10. 
90 | 91 | Example: 92 | >>> policy = CIFAR10Policy() 93 | >>> transformed = policy(image) 94 | 95 | Example as a PyTorch Transform: 96 | >>> transform=transforms.Compose([ 97 | >>> transforms.Resize(256), 98 | >>> CIFAR10Policy(), 99 | >>> transforms.ToTensor()]) 100 | """ 101 | 102 | def __init__(self, fillcolor=(128, 128, 128)): 103 | self.policies = [ 104 | SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor), 105 | SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor), 106 | SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor), 107 | SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor), 108 | SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor), 109 | 110 | SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor), 111 | SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor), 112 | SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor), 113 | SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor), 114 | SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor), 115 | 116 | SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor), 117 | SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor), 118 | SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor), 119 | SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor), 120 | SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor), 121 | 122 | SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor), 123 | SubPolicy(0.2, "equalize", 8, 0.8, "equalize", 4, fillcolor), 124 | SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor), 125 | SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor), 126 | SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor), 127 | 128 | SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor), 129 | SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor), 130 | SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor), 131 | SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor), 132 | SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor) 133 | ] 134 | 135 | def __call__(self, img): 136 | policy_idx = random.randint(0, len(self.policies) - 1) 137 | return self.policies[policy_idx](img) 138 | 139 | def __repr__(self): 140 | return "AutoAugment CIFAR10 Policy" 141 | 142 | 143 | class SVHNPolicy(object): 144 | """ Randomly choose one of the best 25 Sub-policies on SVHN. 
145 | 146 | Example: 147 | >>> policy = SVHNPolicy() 148 | >>> transformed = policy(image) 149 | 150 | Example as a PyTorch Transform: 151 | >>> transform=transforms.Compose([ 152 | >>> transforms.Resize(256), 153 | >>> SVHNPolicy(), 154 | >>> transforms.ToTensor()]) 155 | """ 156 | 157 | def __init__(self, fillcolor=(128, 128, 128)): 158 | self.policies = [ 159 | SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor), 160 | SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor), 161 | SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor), 162 | SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor), 163 | SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor), 164 | 165 | SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor), 166 | SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor), 167 | SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor), 168 | SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor), 169 | SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor), 170 | 171 | SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor), 172 | SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor), 173 | SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor), 174 | SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor), 175 | SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor), 176 | 177 | SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor), 178 | SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor), 179 | SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor), 180 | SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor), 181 | SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor), 182 | 183 | SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor), 184 | SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor), 185 | SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor), 186 | SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor), 187 | SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor) 188 | ] 189 | 190 | def __call__(self, img): 191 | policy_idx = random.randint(0, len(self.policies) - 1) 192 | return self.policies[policy_idx](img) 193 | 194 | def __repr__(self): 195 | return "AutoAugment SVHN Policy" 196 | 197 | 198 | class SubPolicy(object): 199 | def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)): 200 | ranges = { 201 | "shearX": np.linspace(0, 0.3, 10), 202 | "shearY": np.linspace(0, 0.3, 10), 203 | "translateX": np.linspace(0, 150 / 331, 10), 204 | "translateY": np.linspace(0, 150 / 331, 10), 205 | "rotate": np.linspace(0, 30, 10), 206 | "color": np.linspace(0.0, 0.9, 10), 207 | "posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int), 208 | "solarize": np.linspace(256, 0, 10), 209 | "contrast": np.linspace(0.0, 0.9, 10), 210 | "sharpness": np.linspace(0.0, 0.9, 10), 211 | "brightness": np.linspace(0.0, 0.9, 10), 212 | "autocontrast": [0] * 10, 213 | "equalize": [0] * 10, 214 | "invert": [0] * 10 215 | } 216 | 217 | # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand 218 | def rotate_with_fill(img, magnitude): 219 | rot = img.convert("RGBA").rotate(magnitude) 220 | return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode) 221 | 222 | func = { 223 | "shearX": lambda img, magnitude: img.transform( 224 | img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), 225 | 
Image.BICUBIC, fillcolor=fillcolor), 226 | "shearY": lambda img, magnitude: img.transform( 227 | img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0), 228 | Image.BICUBIC, fillcolor=fillcolor), 229 | "translateX": lambda img, magnitude: img.transform( 230 | img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0), 231 | fillcolor=fillcolor), 232 | "translateY": lambda img, magnitude: img.transform( 233 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])), 234 | fillcolor=fillcolor), 235 | "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude), 236 | # "rotate": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])), 237 | "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])), 238 | "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude), 239 | "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude), 240 | "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance( 241 | 1 + magnitude * random.choice([-1, 1])), 242 | "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance( 243 | 1 + magnitude * random.choice([-1, 1])), 244 | "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance( 245 | 1 + magnitude * random.choice([-1, 1])), 246 | "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img), 247 | "equalize": lambda img, magnitude: ImageOps.equalize(img), 248 | "invert": lambda img, magnitude: ImageOps.invert(img) 249 | } 250 | 251 | # self.name = "{}_{:.2f}_and_{}_{:.2f}".format( 252 | # operation1, ranges[operation1][magnitude_idx1], 253 | # operation2, ranges[operation2][magnitude_idx2]) 254 | self.p1 = p1 255 | self.operation1 = func[operation1] 256 | self.magnitude1 = ranges[operation1][magnitude_idx1] 257 | self.p2 = p2 258 | self.operation2 = func[operation2] 259 | self.magnitude2 = ranges[operation2][magnitude_idx2] 260 | 261 | def __call__(self, img): 262 | if random.random() < self.p1: img = self.operation1(img, self.magnitude1) 263 | if random.random() < self.p2: img = self.operation2(img, self.magnitude2) 264 | return img 265 | -------------------------------------------------------------------------------- /experiments/datasets/cifar10/autoaugment.py: -------------------------------------------------------------------------------- 1 | ######################################################################################################################## 2 | # Code adapted from https://github.com/DeepVoltaire/AutoAugment 3 | # with the following MIT Licence: 4 | ######################################################################################################################## 5 | # 6 | # MIT License 7 | # 8 | # Copyright (c) 2018 Philip Popien 9 | # 10 | # Permission is hereby granted, free of charge, to any person obtaining a copy 11 | # of this software and associated documentation files (the "Software"), to deal 12 | # in the Software without restriction, including without limitation the rights 13 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | # copies of the Software, and to permit persons to whom the Software is 15 | # furnished to do so, subject to the following conditions: 16 | # 17 | # The above copyright notice and this permission notice shall be included in all 18 | # copies or substantial portions of the Software. 
19 | # 20 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | # SOFTWARE. 27 | # 28 | ######################################################################################################################## 29 | 30 | from PIL import Image, ImageEnhance, ImageOps 31 | import numpy as np 32 | import random 33 | 34 | 35 | class ImageNetPolicy(object): 36 | """ Randomly choose one of the best 24 Sub-policies on ImageNet. 37 | 38 | Example: 39 | >>> policy = ImageNetPolicy() 40 | >>> transformed = policy(image) 41 | 42 | Example as a PyTorch Transform: 43 | >>> transform=transforms.Compose([ 44 | >>> transforms.Resize(256), 45 | >>> ImageNetPolicy(), 46 | >>> transforms.ToTensor()]) 47 | """ 48 | 49 | def __init__(self, fillcolor=(128, 128, 128)): 50 | self.policies = [ 51 | SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor), 52 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 53 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor), 54 | SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor), 55 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 56 | 57 | SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor), 58 | SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor), 59 | SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor), 60 | SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor), 61 | SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor), 62 | 63 | SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor), 64 | SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor), 65 | SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor), 66 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 67 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), 68 | 69 | SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor), 70 | SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor), 71 | SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor), 72 | SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor), 73 | SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor), 74 | 75 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 76 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 77 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 78 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor) 79 | ] 80 | 81 | def __call__(self, img): 82 | policy_idx = random.randint(0, len(self.policies) - 1) 83 | return self.policies[policy_idx](img) 84 | 85 | def __repr__(self): 86 | return "AutoAugment ImageNet Policy" 87 | 88 | 89 | class CIFAR10Policy(object): 90 | """ Randomly choose one of the best 25 Sub-policies on CIFAR10. 
91 | 92 | Example: 93 | >>> policy = CIFAR10Policy() 94 | >>> transformed = policy(image) 95 | 96 | Example as a PyTorch Transform: 97 | >>> transform=transforms.Compose([ 98 | >>> transforms.Resize(256), 99 | >>> CIFAR10Policy(), 100 | >>> transforms.ToTensor()]) 101 | """ 102 | 103 | def __init__(self, fillcolor=(128, 128, 128)): 104 | self.policies = [ 105 | SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor), 106 | SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor), 107 | SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor), 108 | SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor), 109 | SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor), 110 | 111 | SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor), 112 | SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor), 113 | SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor), 114 | SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor), 115 | SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor), 116 | 117 | SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor), 118 | SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor), 119 | SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor), 120 | SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor), 121 | SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor), 122 | 123 | SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor), 124 | SubPolicy(0.2, "equalize", 8, 0.8, "equalize", 4, fillcolor), 125 | SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor), 126 | SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor), 127 | SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor), 128 | 129 | SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor), 130 | SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor), 131 | SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor), 132 | SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor), 133 | SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor) 134 | ] 135 | 136 | def __call__(self, img): 137 | policy_idx = random.randint(0, len(self.policies) - 1) 138 | return self.policies[policy_idx](img) 139 | 140 | def __repr__(self): 141 | return "AutoAugment CIFAR10 Policy" 142 | 143 | 144 | class SVHNPolicy(object): 145 | """ Randomly choose one of the best 25 Sub-policies on SVHN. 
146 | 147 | Example: 148 | >>> policy = SVHNPolicy() 149 | >>> transformed = policy(image) 150 | 151 | Example as a PyTorch Transform: 152 | >>> transform=transforms.Compose([ 153 | >>> transforms.Resize(256), 154 | >>> SVHNPolicy(), 155 | >>> transforms.ToTensor()]) 156 | """ 157 | 158 | def __init__(self, fillcolor=(128, 128, 128)): 159 | self.policies = [ 160 | SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor), 161 | SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor), 162 | SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor), 163 | SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor), 164 | SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor), 165 | 166 | SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor), 167 | SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor), 168 | SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor), 169 | SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor), 170 | SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor), 171 | 172 | SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor), 173 | SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor), 174 | SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor), 175 | SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor), 176 | SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor), 177 | 178 | SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor), 179 | SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor), 180 | SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor), 181 | SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor), 182 | SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor), 183 | 184 | SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor), 185 | SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor), 186 | SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor), 187 | SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor), 188 | SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor) 189 | ] 190 | 191 | def __call__(self, img): 192 | policy_idx = random.randint(0, len(self.policies) - 1) 193 | return self.policies[policy_idx](img) 194 | 195 | def __repr__(self): 196 | return "AutoAugment SVHN Policy" 197 | 198 | 199 | class SubPolicy(object): 200 | def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)): 201 | ranges = { 202 | "shearX": np.linspace(0, 0.3, 10), 203 | "shearY": np.linspace(0, 0.3, 10), 204 | "translateX": np.linspace(0, 150 / 331, 10), 205 | "translateY": np.linspace(0, 150 / 331, 10), 206 | "rotate": np.linspace(0, 30, 10), 207 | "color": np.linspace(0.0, 0.9, 10), 208 | "posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int), 209 | "solarize": np.linspace(256, 0, 10), 210 | "contrast": np.linspace(0.0, 0.9, 10), 211 | "sharpness": np.linspace(0.0, 0.9, 10), 212 | "brightness": np.linspace(0.0, 0.9, 10), 213 | "autocontrast": [0] * 10, 214 | "equalize": [0] * 10, 215 | "invert": [0] * 10 216 | } 217 | 218 | # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand 219 | def rotate_with_fill(img, magnitude): 220 | rot = img.convert("RGBA").rotate(magnitude) 221 | return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode) 222 | 223 | func = { 224 | "shearX": lambda img, magnitude: img.transform( 225 | img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), 226 | 
Image.BICUBIC, fillcolor=fillcolor), 227 | "shearY": lambda img, magnitude: img.transform( 228 | img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0), 229 | Image.BICUBIC, fillcolor=fillcolor), 230 | "translateX": lambda img, magnitude: img.transform( 231 | img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0), 232 | fillcolor=fillcolor), 233 | "translateY": lambda img, magnitude: img.transform( 234 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])), 235 | fillcolor=fillcolor), 236 | "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude), 237 | # "rotate": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])), 238 | "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])), 239 | "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude), 240 | "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude), 241 | "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance( 242 | 1 + magnitude * random.choice([-1, 1])), 243 | "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance( 244 | 1 + magnitude * random.choice([-1, 1])), 245 | "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance( 246 | 1 + magnitude * random.choice([-1, 1])), 247 | "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img), 248 | "equalize": lambda img, magnitude: ImageOps.equalize(img), 249 | "invert": lambda img, magnitude: ImageOps.invert(img) 250 | } 251 | 252 | # self.name = "{}_{:.2f}_and_{}_{:.2f}".format( 253 | # operation1, ranges[operation1][magnitude_idx1], 254 | # operation2, ranges[operation2][magnitude_idx2]) 255 | self.p1 = p1 256 | self.operation1 = func[operation1] 257 | self.magnitude1 = ranges[operation1][magnitude_idx1] 258 | self.p2 = p2 259 | self.operation2 = func[operation2] 260 | self.magnitude2 = ranges[operation2][magnitude_idx2] 261 | 262 | def __call__(self, img): 263 | if random.random() < self.p1: img = self.operation1(img, self.magnitude1) 264 | if random.random() < self.p2: img = self.operation2(img, self.magnitude2) 265 | return img 266 | -------------------------------------------------------------------------------- /experiments/models/wide_resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from scipy import stats 7 | 8 | 9 | __all__ = [ 10 | 'wrn16_8_stl', 11 | ] 12 | 13 | 14 | ######################################################################################################################## 15 | # The following piece of code was adapted from https://github.com/uoguelph-mlrg/Cutout/blob/master/model/wide_resnet.py 16 | # which has the following ECL-2.0 license: 17 | ######################################################################################################################## 18 | # Educational Community License, Version 2.0 (ECL-2.0) 19 | # 20 | # Version 2.0, April 2007 21 | # 22 | # http://www.osedu.org/licenses/ 23 | # 24 | # The Educational Community License version 2.0 ("ECL") consists of the Apache 2.0 license, modified to change the scope of the patent grant in section 3 to be specific to the needs of the education communities using this license. 
The original Apache 2.0 license can be found at: http://www.apache.org/licenses /LICENSE-2.0 25 | # 26 | # TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 27 | # 28 | # 1. Definitions. 29 | # 30 | # "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. 31 | # 32 | # "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. 33 | # 34 | # "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. 35 | # 36 | # "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. 37 | # 38 | # "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. 39 | # 40 | # "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. 41 | # 42 | # "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). 43 | # 44 | # "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 45 | # 46 | # "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 47 | # 48 | # "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 49 | # 50 | # 2. Grant of Copyright License. 
51 | # 52 | # Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 53 | # 54 | # 3. Grant of Patent License. 55 | # 56 | # Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. Any patent license granted hereby with respect to contributions by an individual employed by an institution or organization is limited to patent claims where the individual that is the author of the Work is also the inventor of the patent claims licensed, and where the organization or institution has the right to grant such license under applicable grant and research funding agreements. No other express or implied licenses are granted. 57 | # 58 | # 4. Redistribution. 59 | # 60 | # You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: 61 | # 62 | # You must give any other recipients of the Work or Derivative Works a copy of this License; and You must cause any modified files to carry prominent notices stating that You changed the files; and You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. 
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 63 | # 64 | # 5. Submission of Contributions. 65 | # 66 | # Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 67 | # 68 | # 6. Trademarks. 69 | # 70 | # This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 71 | # 72 | # 7. Disclaimer of Warranty. 73 | # 74 | # Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 75 | # 76 | # 8. Limitation of Liability. 77 | # 78 | # In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 79 | # 80 | # 9. Accepting Warranty or Additional Liability. 81 | # 82 | # While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. 83 | # 84 | # END OF TERMS AND CONDITIONS 85 | # 86 | # APPENDIX: How to apply the Educational Community License to your work 87 | # 88 | # To apply the Educational Community License to your work, attach 89 | # the following boilerplate notice, with the fields enclosed by 90 | # brackets "[]" replaced with your own identifying information. 91 | # (Don't include the brackets!) 
The text should be enclosed in the 92 | # appropriate comment syntax for the file format. We also recommend 93 | # that a file or class name and description of purpose be included on 94 | # the same "printed page" as the copyright notice for easier 95 | # identification within third-party archives. 96 | # 97 | # Copyright [yyyy] [name of copyright owner] Licensed under the 98 | # Educational Community License, Version 2.0 (the "License"); you may 99 | # not use this file except in compliance with the License. You may 100 | # obtain a copy of the License at 101 | # 102 | # http://www.osedu.org/licenses /ECL-2.0 103 | # 104 | # Unless required by applicable law or agreed to in writing, 105 | # software distributed under the License is distributed on an "AS IS" 106 | # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 107 | # or implied. See the License for the specific language governing 108 | # permissions and limitations under the license. 109 | 110 | class BasicBlock(nn.Module): 111 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 112 | super(BasicBlock, self).__init__() 113 | self.bn1 = nn.BatchNorm2d(in_planes) 114 | self.relu1 = nn.ReLU(inplace=True) 115 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 116 | padding=1, bias=False) 117 | self.bn2 = nn.BatchNorm2d(out_planes) 118 | self.relu2 = nn.ReLU(inplace=True) 119 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 120 | padding=1, bias=False) 121 | self.droprate = dropRate 122 | self.equalInOut = (in_planes == out_planes) 123 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 124 | padding=0, bias=False) or None 125 | 126 | def forward(self, x): 127 | 128 | if not self.equalInOut: 129 | x = self.relu1(self.bn1(x)) 130 | else: 131 | out = self.relu1(self.bn1(x)) 132 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 133 | if self.droprate > 0: 134 | out = F.dropout(out, p=self.droprate, training=self.training) 135 | out = self.conv2(out) 136 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 137 | 138 | 139 | class NetworkBlock(nn.Module): 140 | 141 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 142 | super(NetworkBlock, self).__init__() 143 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 144 | 145 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 146 | layers = [] 147 | for i in range(nb_layers): 148 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 149 | return nn.Sequential(*layers) 150 | 151 | def forward(self, x): 152 | return self.layer(x) 153 | 154 | 155 | class WideResNet(nn.Module): 156 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0, initial_stride: int = 1, deltaorth: bool = False): 157 | super(WideResNet, self).__init__() 158 | nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] 159 | assert((depth - 4) % 6 == 0) 160 | n = int((depth - 4) / 6) 161 | 162 | block = BasicBlock 163 | # 1st conv before any network block 164 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 165 | padding=1, bias=False) 166 | # 1st block 167 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, initial_stride, dropRate) 168 | # 2nd block 169 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 170 
| # 3rd block 171 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 172 | # global average pooling and classifier 173 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 174 | self.relu = nn.ReLU(inplace=True) 175 | self.fc = nn.Linear(nChannels[3], num_classes) 176 | self.nChannels = nChannels[3] 177 | 178 | for m in self.modules(): 179 | if isinstance(m, nn.Conv2d): 180 | if deltaorth: 181 | # delta-orthogonal initialization: an orthogonal matrix at the kernel's spatial center 182 | o, i, w, h = m.weight.shape 183 | if o >= i: 184 | m.weight.data.fill_(0.) 185 | m.weight.data[:, :, w//2, h//2] = torch.tensor(stats.ortho_group.rvs(max(i, o))[:o, :i]) 186 | 187 | else: 188 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 189 | m.weight.data.normal_(0, math.sqrt(2. / n)) 190 | elif isinstance(m, nn.BatchNorm2d): 191 | m.weight.data.fill_(1) 192 | m.bias.data.zero_() 193 | elif isinstance(m, nn.Linear): 194 | m.bias.data.zero_() 195 | 196 | def forward(self, x): 197 | out = self.conv1(x) 198 | out = self.block1(out) 199 | out = self.block2(out) 200 | out = self.block3(out) 201 | out = self.relu(self.bn1(out)) 202 | 203 | b, c, w, h = out.shape 204 | 205 | out = F.avg_pool2d(out, (w, h)) 206 | out = out.view(-1, self.nChannels) 207 | out = self.fc(out) 208 | return out 209 | 210 | 211 | def wrn16_8_stl(**kwargs): 212 | """Constructs a Wide ResNet 16-8 model with initial stride of 2 as mentioned here: 213 | https://github.com/uoguelph-mlrg/Cutout/issues/2 214 | 215 | """ 216 | model = WideResNet(16, widen_factor=8, dropRate=0.3, initial_stride=2, **kwargs) 217 | return model 218 | -------------------------------------------------------------------------------- /experiments/datasets/cifar10/data_loader_cifar10.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | 5 | from torchvision import datasets 6 | from torchvision import transforms 7 | from torch.utils.data.sampler import SubsetRandomSampler 8 | 9 | from .autoaugment import CIFAR10Policy 10 | 11 | DATA_DIR = "./datasets/cifar10/"  # cifar-10-batches-py 12 | 13 | MEAN = np.array([125.3, 123.0, 113.9]) / 255.0 # = np.array([0.49137255, 0.48235294, 0.44666667]) 14 | STD = np.array([63.0, 62.1, 66.7]) / 255.0 # = np.array([0.24705882, 0.24352941, 0.26156863]) 15 | 16 | 17 | ######################################################################################################################## 18 | # The following piece of code was adapted from https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py 19 | # which has the following ECL-2.0 license: 20 | ######################################################################################################################## 21 | # Educational Community License, Version 2.0 (ECL-2.0) 22 | # 23 | # Version 2.0, April 2007 24 | # 25 | # http://www.osedu.org/licenses/ 26 | # 27 | # The Educational Community License version 2.0 ("ECL") consists of the Apache 2.0 license, modified to change the scope of the patent grant in section 3 to be specific to the needs of the education communities using this license. The original Apache 2.0 license can be found at: http://www.apache.org/licenses/LICENSE-2.0 28 | # 29 | # TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 30 | # 31 | # 1. Definitions. 32 | # 33 | # "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
34 | # 35 | # "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. 36 | # 37 | # "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. 38 | # 39 | # "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. 40 | # 41 | # "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. 42 | # 43 | # "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. 44 | # 45 | # "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). 46 | # 47 | # "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 48 | # 49 | # "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 50 | # 51 | # "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 52 | # 53 | # 2. Grant of Copyright License. 54 | # 55 | # Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 56 | # 57 | # 3. Grant of Patent License. 
58 | # 59 | # Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. Any patent license granted hereby with respect to contributions by an individual employed by an institution or organization is limited to patent claims where the individual that is the author of the Work is also the inventor of the patent claims licensed, and where the organization or institution has the right to grant such license under applicable grant and research funding agreements. No other express or implied licenses are granted. 60 | # 61 | # 4. Redistribution. 62 | # 63 | # You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: 64 | # 65 | # You must give any other recipients of the Work or Derivative Works a copy of this License; and You must cause any modified files to carry prominent notices stating that You changed the files; and You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 66 | # 67 | # 5. Submission of Contributions. 
68 | # 69 | # Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 70 | # 71 | # 6. Trademarks. 72 | # 73 | # This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 74 | # 75 | # 7. Disclaimer of Warranty. 76 | # 77 | # Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 78 | # 79 | # 8. Limitation of Liability. 80 | # 81 | # In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 82 | # 83 | # 9. Accepting Warranty or Additional Liability. 84 | # 85 | # While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. 86 | # 87 | # END OF TERMS AND CONDITIONS 88 | # 89 | # APPENDIX: How to apply the Educational Community License to your work 90 | # 91 | # To apply the Educational Community License to your work, attach 92 | # the following boilerplate notice, with the fields enclosed by 93 | # brackets "[]" replaced with your own identifying information. 94 | # (Don't include the brackets!) The text should be enclosed in the 95 | # appropriate comment syntax for the file format. We also recommend 96 | # that a file or class name and description of purpose be included on 97 | # the same "printed page" as the copyright notice for easier 98 | # identification within third-party archives. 
99 | # 100 | # Copyright [yyyy] [name of copyright owner] Licensed under the 101 | # Educational Community License, Version 2.0 (the "License"); you may 102 | # not use this file except in compliance with the License. You may 103 | # obtain a copy of the License at 104 | # 105 | # http://www.osedu.org/licenses/ECL-2.0 106 | # 107 | # Unless required by applicable law or agreed to in writing, 108 | # software distributed under the License is distributed on an "AS IS" 109 | # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 110 | # or implied. See the License for the specific language governing 111 | # permissions and limitations under the license. 112 | 113 | class Cutout: 114 | 115 | """Randomly mask out a patch from an image. 116 | Args: 117 | size (int): The size of the square patch. 118 | """ 119 | def __init__(self, size): 120 | self.size = size 121 | 122 | def __call__(self, img): 123 | """ 124 | Args: 125 | img (Tensor): Tensor image 126 | Returns: 127 | Tensor: Image with a hole of dimension size x size cut out of it. 128 | """ 129 | h = img.size(1) 130 | w = img.size(2) 131 | 132 | mask = np.ones((h, w), np.float32) 133 | 134 | y = np.random.randint(h) # random patch center 135 | x = np.random.randint(w) 136 | 137 | y1 = np.clip(y - self.size // 2, 0, h) # the patch may be clipped at the image border 138 | y2 = np.clip(y + self.size // 2, 0, h) 139 | x1 = np.clip(x - self.size // 2, 0, w) 140 | x2 = np.clip(x + self.size // 2, 0, w) 141 | 142 | mask[y1: y2, x1: x2] = 0. 143 | 144 | mask = torch.from_numpy(mask) 145 | mask = mask.expand_as(img) # broadcast the mask over all channels 146 | img = img * mask 147 | 148 | return img 149 | 150 | ######################################################################################################################## 151 | 152 | 153 | def build_cifar10_loaders(batch_size, 154 | eval_batchsize, 155 | validation=True, 156 | num_workers=8, 157 | augment=False, 158 | reshuffle=True, 159 | ): 160 | 161 | # train_normalize = transforms.Normalize( 162 | # mean=[0.4914, 0.4822, 0.4465], 163 | # std=[0.2023, 0.1994, 0.2010], 164 | # ) 165 | 166 | # test_normalize = transforms.Normalize( 167 | # mean=[0.485, 0.456, 0.406], 168 | # std=[0.229, 0.224, 0.225], 169 | # ) 170 | 171 | normalize = transforms.Normalize( 172 | mean=MEAN, 173 | std=STD, 174 | ) 175 | 176 | # define transforms 177 | valid_transform = transforms.Compose([ 178 | transforms.ToTensor(), 179 | normalize, 180 | ]) 181 | 182 | if augment: 183 | train_transform = transforms.Compose([ 184 | transforms.RandomCrop(32, padding=4), 185 | transforms.RandomHorizontalFlip(), 186 | CIFAR10Policy(), 187 | transforms.ToTensor(), 188 | Cutout(16), 189 | normalize, 190 | ]) 191 | else: 192 | train_transform = transforms.Compose([ 193 | transforms.RandomCrop(32, padding=4), 194 | transforms.RandomHorizontalFlip(), 195 | transforms.ToTensor(), 196 | normalize, 197 | ]) 198 | # train_transform = transforms.Compose([ 199 | # transforms.ToTensor(), 200 | # normalize, 201 | # ]) 202 | 203 | # load the dataset 204 | train_dataset = datasets.CIFAR10( 205 | root=DATA_DIR, train=True, 206 | download=True, transform=train_transform, 207 | ) 208 | 209 | test_dataset = datasets.CIFAR10( 210 | root=DATA_DIR, train=False, 211 | download=True, transform=valid_transform, 212 | ) 213 | 214 | if validation: 215 | 216 | valid_dataset = datasets.CIFAR10( 217 | root=DATA_DIR, train=True, 218 | download=True, transform=valid_transform, 219 | ) 220 | num_train = len(train_dataset) 221 | indices = list(range(num_train)) 222 | split = int(np.floor(0.2 * num_train)) 223 | 224 | if reshuffle: 225 | 
np.random.shuffle(indices) 226 | 227 | train_idx, valid_idx = indices[split:], indices[:split] 228 | train_sampler = SubsetRandomSampler(train_idx) 229 | valid_sampler = SubsetRandomSampler(valid_idx) 230 | 231 | train_loader = torch.utils.data.DataLoader( 232 | train_dataset, batch_size=batch_size, sampler=train_sampler, 233 | num_workers=num_workers, pin_memory=True, 234 | ) 235 | valid_loader = torch.utils.data.DataLoader( 236 | valid_dataset, batch_size=eval_batchsize, sampler=valid_sampler, 237 | num_workers=num_workers, pin_memory=True, 238 | ) 239 | else: 240 | 241 | train_loader = torch.utils.data.DataLoader( 242 | train_dataset, batch_size=batch_size, shuffle=True, 243 | num_workers=num_workers, pin_memory=True, 244 | ) 245 | valid_loader = None 246 | 247 | test_loader = torch.utils.data.DataLoader( 248 | test_dataset, batch_size=eval_batchsize, shuffle=False, 249 | num_workers=num_workers, pin_memory=True, 250 | ) 251 | 252 | n_inputs = 3 253 | n_classes = 10 254 | 255 | return train_loader, valid_loader, test_loader, n_inputs, n_classes 256 | 257 | 258 | -------------------------------------------------------------------------------- /experiments/datasets/cifar100/data_loader_cifar100.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | 5 | from torchvision import datasets 6 | from torchvision import transforms 7 | from torch.utils.data.sampler import SubsetRandomSampler 8 | 9 | from .autoaugment import CIFAR10Policy 10 | 11 | 12 | DATA_DIR = "./datasets/cifar100/" # cifar-100-batches-py 13 | 14 | MEAN = np.array([125.3, 123.0, 113.9]) / 255.0 # = np.array([0.49137255, 0.48235294, 0.44666667]) 15 | STD = np.array([63.0, 62.1, 66.7]) / 255.0 # = np.array([0.24705882, 0.24352941, 0.26156863]) 16 | 17 | 18 | ######################################################################################################################## 19 | # The following piece of code was adapted from https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py 20 | # which has the following ECL-2.0 license: 21 | ######################################################################################################################## 22 | # Educational Community License, Version 2.0 (ECL-2.0) 23 | # 24 | # Version 2.0, April 2007 25 | # 26 | # http://www.osedu.org/licenses/ 27 | # 28 | # The Educational Community License version 2.0 ("ECL") consists of the Apache 2.0 license, modified to change the scope of the patent grant in section 3 to be specific to the needs of the education communities using this license. The original Apache 2.0 license can be found at: http://www.apache.org/licenses/LICENSE-2.0 29 | # 30 | # TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 31 | # 32 | # 1. Definitions. 33 | # 34 | # "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. 35 | # 36 | # "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. 37 | # 38 | # "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity.
For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. 39 | # 40 | # "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. 41 | # 42 | # "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. 43 | # 44 | # "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. 45 | # 46 | # "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). 47 | # 48 | # "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 49 | # 50 | # "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 51 | # 52 | # "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 53 | # 54 | # 2. Grant of Copyright License. 55 | # 56 | # Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 57 | # 58 | # 3. Grant of Patent License. 
59 | # 60 | # Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. Any patent license granted hereby with respect to contributions by an individual employed by an institution or organization is limited to patent claims where the individual that is the author of the Work is also the inventor of the patent claims licensed, and where the organization or institution has the right to grant such license under applicable grant and research funding agreements. No other express or implied licenses are granted. 61 | # 62 | # 4. Redistribution. 63 | # 64 | # You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: 65 | # 66 | # You must give any other recipients of the Work or Derivative Works a copy of this License; and You must cause any modified files to carry prominent notices stating that You changed the files; and You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 67 | # 68 | # 5. Submission of Contributions. 
69 | # 70 | # Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 71 | # 72 | # 6. Trademarks. 73 | # 74 | # This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 75 | # 76 | # 7. Disclaimer of Warranty. 77 | # 78 | # Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 79 | # 80 | # 8. Limitation of Liability. 81 | # 82 | # In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 83 | # 84 | # 9. Accepting Warranty or Additional Liability. 85 | # 86 | # While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. 87 | # 88 | # END OF TERMS AND CONDITIONS 89 | # 90 | # APPENDIX: How to apply the Educational Community License to your work 91 | # 92 | # To apply the Educational Community License to your work, attach 93 | # the following boilerplate notice, with the fields enclosed by 94 | # brackets "[]" replaced with your own identifying information. 95 | # (Don't include the brackets!) The text should be enclosed in the 96 | # appropriate comment syntax for the file format. We also recommend 97 | # that a file or class name and description of purpose be included on 98 | # the same "printed page" as the copyright notice for easier 99 | # identification within third-party archives. 
100 | # 101 | # Copyright [yyyy] [name of copyright owner] Licensed under the 102 | # Educational Community License, Version 2.0 (the "License"); you may 103 | # not use this file except in compliance with the License. You may 104 | # obtain a copy of the License at 105 | # 106 | # http://www.osedu.org/licenses/ECL-2.0 107 | # 108 | # Unless required by applicable law or agreed to in writing, 109 | # software distributed under the License is distributed on an "AS IS" 110 | # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 111 | # or implied. See the License for the specific language governing 112 | # permissions and limitations under the license. 113 | class Cutout: 114 | 115 | """Randomly mask out a patch from an image. 116 | Args: 117 | size (int): The size of the square patch. 118 | """ 119 | 120 | def __init__(self, size): 121 | self.size = size 122 | 123 | def __call__(self, img): 124 | """ 125 | Args: 126 | img (Tensor): Tensor image 127 | Returns: 128 | Tensor: Image with a hole of dimension size x size cut out of it. 129 | """ 130 | h = img.size(1) 131 | w = img.size(2) 132 | 133 | mask = np.ones((h, w), np.float32) 134 | 135 | y = np.random.randint(h) # random patch center 136 | x = np.random.randint(w) 137 | 138 | y1 = np.clip(y - self.size // 2, 0, h) # the patch may be clipped at the image border 139 | y2 = np.clip(y + self.size // 2, 0, h) 140 | x1 = np.clip(x - self.size // 2, 0, w) 141 | x2 = np.clip(x + self.size // 2, 0, w) 142 | 143 | mask[y1: y2, x1: x2] = 0. 144 | 145 | mask = torch.from_numpy(mask) 146 | mask = mask.expand_as(img) # broadcast the mask over all channels 147 | img = img * mask 148 | 149 | return img 150 | 151 | ######################################################################################################################## 152 | 153 | 154 | def build_cifar100_loaders(batch_size, 155 | eval_batchsize, 156 | validation=True, 157 | num_workers=8, 158 | augment=False, 159 | reshuffle=True, 160 | ): 161 | 162 | # train_normalize = transforms.Normalize( 163 | # mean=[0.4914, 0.4822, 0.4465], 164 | # std=[0.2023, 0.1994, 0.2010], 165 | # ) 166 | 167 | # test_normalize = transforms.Normalize( 168 | # mean=[0.485, 0.456, 0.406], 169 | # std=[0.229, 0.224, 0.225], 170 | # ) 171 | 172 | normalize = transforms.Normalize( 173 | mean=MEAN, 174 | std=STD, 175 | ) 176 | 177 | # define transforms 178 | valid_transform = transforms.Compose([ 179 | transforms.ToTensor(), 180 | normalize, 181 | ]) 182 | 183 | if augment: 184 | train_transform = transforms.Compose([ 185 | transforms.RandomCrop(32, padding=4), 186 | transforms.RandomHorizontalFlip(), 187 | CIFAR10Policy(), 188 | transforms.ToTensor(), 189 | Cutout(16), 190 | normalize, 191 | ]) 192 | else: 193 | train_transform = transforms.Compose([ 194 | transforms.RandomCrop(32, padding=4), 195 | transforms.RandomHorizontalFlip(), 196 | transforms.ToTensor(), 197 | normalize, 198 | ]) 199 | # train_transform = transforms.Compose([ 200 | # transforms.ToTensor(), 201 | # normalize, 202 | # ]) 203 | 204 | # load the dataset 205 | train_dataset = datasets.CIFAR100( 206 | root=DATA_DIR, train=True, 207 | download=True, transform=train_transform, 208 | ) 209 | 210 | test_dataset = datasets.CIFAR100( 211 | root=DATA_DIR, train=False, 212 | download=True, transform=valid_transform, 213 | ) 214 | 215 | if validation: 216 | 217 | valid_dataset = datasets.CIFAR100( 218 | root=DATA_DIR, train=True, 219 | download=True, transform=valid_transform, 220 | ) 221 | num_train = len(train_dataset) 222 | indices = list(range(num_train)) 223 | split = int(np.floor(0.2 * num_train)) 224 | 225 | if reshuffle: 226 | 
np.random.shuffle(indices) 227 | 228 | train_idx, valid_idx = indices[split:], indices[:split] 229 | train_sampler = SubsetRandomSampler(train_idx) 230 | valid_sampler = SubsetRandomSampler(valid_idx) 231 | 232 | train_loader = torch.utils.data.DataLoader( 233 | train_dataset, batch_size=batch_size, sampler=train_sampler, 234 | num_workers=num_workers, pin_memory=True, 235 | ) 236 | valid_loader = torch.utils.data.DataLoader( 237 | valid_dataset, batch_size=eval_batchsize, sampler=valid_sampler, 238 | num_workers=num_workers, pin_memory=True, 239 | ) 240 | else: 241 | 242 | train_loader = torch.utils.data.DataLoader( 243 | train_dataset, batch_size=batch_size, shuffle=True, 244 | num_workers=num_workers, pin_memory=True, 245 | ) 246 | valid_loader = None 247 | 248 | test_loader = torch.utils.data.DataLoader( 249 | test_dataset, batch_size=eval_batchsize, shuffle=False, 250 | num_workers=num_workers, pin_memory=True, 251 | ) 252 | 253 | n_inputs = 3 254 | n_classes = 100 255 | 256 | return train_loader, valid_loader, test_loader, n_inputs, n_classes 257 | 258 | 259 | --------------------------------------------------------------------------------
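Both CIFAR loader modules above define the same Cutout transform, so its behaviour is easy to sanity-check in isolation. The snippet below is a minimal sketch, not part of the repository; it assumes it is run from the experiments/ directory so that the datasets package and its relative imports resolve:

import torch

from datasets.cifar10.data_loader_cifar10 import Cutout

img = torch.ones(3, 32, 32)   # dummy CHW image of ones
out = Cutout(16)(img)         # zero out one square patch of (up to) 16x16 pixels

# The patch is clipped at the image border, so between 8x8 and 16x16 pixels
# end up masked, and the same mask is applied to all three channels.
masked = (out == 0).all(dim=0).sum().item()
assert 64 <= masked <= 256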
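For reference, build_cifar10_loaders and build_cifar100_loaders share the same signature and both return the 5-tuple (train_loader, valid_loader, test_loader, n_inputs, n_classes). A usage sketch under the same assumption (run from experiments/); the batch sizes here are illustrative, not values taken from the experiment scripts:

from datasets.cifar10.data_loader_cifar10 import build_cifar10_loaders

train_loader, valid_loader, test_loader, n_inputs, n_classes = build_cifar10_loaders(
    batch_size=128,      # training mini-batch size
    eval_batchsize=256,  # batch size of the validation/test loaders
    validation=True,     # hold out 20% of the training set for validation
    num_workers=0,       # keep the sketch single-process
    augment=True,        # RandomCrop + flip + AutoAugment policy + Cutout(16)
    reshuffle=True,      # shuffle the indices before the 80/20 split
)

images, labels = next(iter(train_loader))   # triggers the CIFAR-10 download on first use
assert images.shape == (128, n_inputs, 32, 32) and n_classes == 10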