├── cifar ├── robustbench │ ├── leaderboard │ │ ├── __init__.py │ │ ├── leaderboard.html.j2 │ │ └── template.py │ ├── model_zoo │ │ ├── architectures │ │ │ ├── __init__.py │ │ │ ├── utils_architectures.py │ │ │ ├── wide_resnet.py │ │ │ └── resnext.py │ │ ├── __init__.py │ │ ├── enums.py │ │ ├── models.py │ │ └── imagenet.py │ ├── __init__.py │ ├── zenodo_download.py │ ├── loaders.py │ └── eval.py ├── run_cifar10_gradual.sh ├── run_cifar10.sh ├── run_cifar100.sh ├── cfgs │ ├── cifar10 │ │ ├── norm.yaml │ │ ├── source.yaml │ │ ├── tent.yaml │ │ └── cotta.yaml │ ├── cifar100 │ │ ├── norm.yaml │ │ ├── source.yaml │ │ ├── tent.yaml │ │ └── cotta.yaml │ └── 10orders │ │ ├── tent │ │ ├── tent0.yaml │ │ ├── tent1.yaml │ │ ├── tent2.yaml │ │ ├── tent3.yaml │ │ ├── tent4.yaml │ │ ├── tent5.yaml │ │ ├── tent6.yaml │ │ ├── tent7.yaml │ │ ├── tent8.yaml │ │ └── tent9.yaml │ │ └── cotta │ │ ├── cotta0.yaml │ │ ├── cotta1.yaml │ │ ├── cotta2.yaml │ │ ├── cotta3.yaml │ │ ├── cotta4.yaml │ │ ├── cotta5.yaml │ │ ├── cotta6.yaml │ │ ├── cotta7.yaml │ │ ├── cotta8.yaml │ │ └── cotta9.yaml ├── norm.py ├── tent.py ├── my_transforms.py ├── cifar100c.py ├── cifar10c.py ├── cifar10c_gradual.py ├── utils.py ├── conf.py └── cotta.py ├── imagenet ├── robustbench │ ├── leaderboard │ │ ├── __init__.py │ │ ├── leaderboard.html.j2 │ │ └── template.py │ ├── model_zoo │ │ ├── architectures │ │ │ ├── __init__.py │ │ │ ├── utils_architectures.py │ │ │ ├── wide_resnet.py │ │ │ └── resnext.py │ │ ├── __init__.py │ │ ├── enums.py │ │ ├── models.py │ │ └── imagenet.py │ ├── __init__.py │ ├── zenodo_download.py │ └── loaders.py ├── .gitignore ├── cfgs │ ├── norm.yaml │ ├── source.yaml │ └── 10orders │ │ ├── cotta │ │ ├── cotta0.yaml │ │ ├── cotta1.yaml │ │ ├── cotta2.yaml │ │ ├── cotta3.yaml │ │ ├── cotta4.yaml │ │ ├── cotta5.yaml │ │ ├── cotta6.yaml │ │ ├── cotta7.yaml │ │ ├── cotta8.yaml │ │ └── cotta9.yaml │ │ └── tent │ │ ├── tent0.yaml │ │ ├── tent1.yaml │ │ ├── tent2.yaml │ │ ├── tent3.yaml │ │ ├── 
tent4.yaml │ │ ├── tent5.yaml │ │ ├── tent6.yaml │ │ ├── tent7.yaml │ │ ├── tent8.yaml │ │ └── tent9.yaml ├── data │ └── ImageNet-C │ │ └── download.sh ├── run.sh ├── eval.py ├── norm.py ├── tent.py ├── my_transforms.py ├── imagenetc.py ├── conf.py └── cotta.py ├── setup_env.sh ├── LICENSE ├── .gitignore ├── README.md └── environment.yml /cifar/robustbench/leaderboard/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imagenet/robustbench/leaderboard/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cifar/robustbench/model_zoo/architectures/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imagenet/robustbench/model_zoo/architectures/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cifar/robustbench/model_zoo/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import model_dicts 2 | 3 | -------------------------------------------------------------------------------- /imagenet/robustbench/model_zoo/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import model_dicts 2 | 3 | -------------------------------------------------------------------------------- /setup_env.sh: -------------------------------------------------------------------------------- 1 | conda update conda 2 | conda env create -f environment.yml 3 | 4 | 5 | -------------------------------------------------------------------------------- /cifar/robustbench/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .data import load_cifar10 2 | from .utils import load_model 3 | from .eval import benchmark 4 | -------------------------------------------------------------------------------- /imagenet/robustbench/__init__.py: -------------------------------------------------------------------------------- 1 | from .data import load_cifar10 2 | from .utils import load_model 3 | from .eval import benchmark 4 | -------------------------------------------------------------------------------- /imagenet/.gitignore: -------------------------------------------------------------------------------- 1 | ## General 2 | 3 | *.pyc 4 | __pycache__/ 5 | .ropeproject/ 6 | 7 | *.pkl 8 | *.npy 9 | 10 | *.ipynb_checkpoints 11 | 12 | .envrc 13 | 14 | ## Project 15 | 16 | *.pth 17 | 18 | /notebooks/*wild.ipynb 19 | /experiments 20 | -------------------------------------------------------------------------------- /cifar/robustbench/model_zoo/enums.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class BenchmarkDataset(Enum): 5 | cifar_10 = 'cifar10' 6 | cifar_100 = 'cifar100' 7 | imagenet = 'imagenet' 8 | 9 | 10 | class ThreatModel(Enum): 11 | Linf = "Linf" 12 | L2 = "L2" 13 | corruptions = "corruptions" 14 | -------------------------------------------------------------------------------- /imagenet/robustbench/model_zoo/enums.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class BenchmarkDataset(Enum): 5 | cifar_10 = 'cifar10' 6 | cifar_100 = 'cifar100' 7 | imagenet = 'imagenet' 8 | 9 | 10 | class ThreatModel(Enum): 11 | Linf = "Linf" 12 | L2 = "L2" 13 | corruptions = "corruptions" 14 | -------------------------------------------------------------------------------- /cifar/run_cifar10_gradual.sh: 
-------------------------------------------------------------------------------- 1 | #source /cluster/home/qwang/miniconda3/etc/profile.d/conda.sh 2 | export PYTHONPATH= 3 | conda deactivate 4 | conda activate cotta 5 | 6 | for i in {0..9} 7 | do 8 | CUDA_VISIBLE_DEVICES=0 python -u cifar10c_gradual.py --cfg cfgs/10orders/tent/tent$i.yaml 9 | CUDA_VISIBLE_DEVICES=0 python -u cifar10c_gradual.py --cfg cfgs/10orders/cotta/cotta$i.yaml 10 | done 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /cifar/run_cifar10.sh: -------------------------------------------------------------------------------- 1 | #source /cluster/home/qwang/miniconda3/etc/profile.d/conda.sh 2 | export PYTHONPATH= 3 | conda deactivate 4 | conda activate cotta 5 | CUDA_VISIBLE_DEVICES=0 python cifar10c.py --cfg cfgs/cifar10/source.yaml 6 | CUDA_VISIBLE_DEVICES=0 python cifar10c.py --cfg cfgs/cifar10/norm.yaml 7 | CUDA_VISIBLE_DEVICES=0 python cifar10c.py --cfg cfgs/cifar10/tent.yaml 8 | CUDA_VISIBLE_DEVICES=0 python cifar10c.py --cfg cfgs/cifar10/cotta.yaml 9 | 10 | 11 | -------------------------------------------------------------------------------- /cifar/run_cifar100.sh: -------------------------------------------------------------------------------- 1 | #source /cluster/home/qwang/miniconda3/etc/profile.d/conda.sh 2 | export PYTHONPATH= 3 | conda deactivate 4 | conda activate cotta 5 | CUDA_VISIBLE_DEVICES=0 python cifar100c.py --cfg cfgs/cifar100/source.yaml 6 | CUDA_VISIBLE_DEVICES=0 python cifar100c.py --cfg cfgs/cifar100/norm.yaml 7 | CUDA_VISIBLE_DEVICES=0 python cifar100c.py --cfg cfgs/cifar100/tent.yaml 8 | CUDA_VISIBLE_DEVICES=0 python cifar100c.py --cfg cfgs/cifar100/cotta.yaml 9 | 10 | 11 | -------------------------------------------------------------------------------- /cifar/cfgs/cifar10/norm.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | ADAPTATION: norm 3 | ARCH: Standard 4 | 
TEST: 5 | BATCH_SIZE: 200 6 | CORRUPTION: 7 | DATASET: cifar10 8 | SEVERITY: 9 | - 5 10 | TYPE: 11 | - gaussian_noise 12 | - shot_noise 13 | - impulse_noise 14 | - defocus_blur 15 | - glass_blur 16 | - motion_blur 17 | - zoom_blur 18 | - snow 19 | - frost 20 | - fog 21 | - brightness 22 | - contrast 23 | - elastic_transform 24 | - pixelate 25 | - jpeg_compression 26 | -------------------------------------------------------------------------------- /imagenet/cfgs/norm.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | ADAPTATION: norm 3 | ARCH: Standard_R50 4 | TEST: 5 | BATCH_SIZE: 64 6 | CORRUPTION: 7 | DATASET: imagenet 8 | SEVERITY: 9 | - 5 10 | TYPE: 11 | - gaussian_noise 12 | - shot_noise 13 | - impulse_noise 14 | - defocus_blur 15 | - glass_blur 16 | - motion_blur 17 | - zoom_blur 18 | - snow 19 | - frost 20 | - fog 21 | - brightness 22 | - contrast 23 | - elastic_transform 24 | - pixelate 25 | - jpeg_compression 26 | -------------------------------------------------------------------------------- /imagenet/cfgs/source.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | ADAPTATION: source 3 | ARCH: Standard_R50 4 | TEST: 5 | BATCH_SIZE: 64 6 | CORRUPTION: 7 | DATASET: imagenet 8 | SEVERITY: 9 | - 5 10 | TYPE: 11 | - gaussian_noise 12 | - shot_noise 13 | - impulse_noise 14 | - defocus_blur 15 | - glass_blur 16 | - motion_blur 17 | - zoom_blur 18 | - snow 19 | - frost 20 | - fog 21 | - brightness 22 | - contrast 23 | - elastic_transform 24 | - pixelate 25 | - jpeg_compression 26 | -------------------------------------------------------------------------------- /cifar/cfgs/cifar10/source.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | ADAPTATION: source 3 | ARCH: Standard 4 | TEST: 5 | BATCH_SIZE: 200 6 | CORRUPTION: 7 | DATASET: cifar10 8 | SEVERITY: 9 | - 5 10 | TYPE: 11 | - gaussian_noise 12 | - shot_noise 
13 | - impulse_noise 14 | - defocus_blur 15 | - glass_blur 16 | - motion_blur 17 | - zoom_blur 18 | - snow 19 | - frost 20 | - fog 21 | - brightness 22 | - contrast 23 | - elastic_transform 24 | - pixelate 25 | - jpeg_compression 26 | -------------------------------------------------------------------------------- /cifar/cfgs/cifar100/norm.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | ADAPTATION: norm 3 | ARCH: Hendrycks2020AugMix_ResNeXt 4 | TEST: 5 | BATCH_SIZE: 200 6 | CORRUPTION: 7 | DATASET: cifar100 8 | SEVERITY: 9 | - 5 10 | TYPE: 11 | - gaussian_noise 12 | - shot_noise 13 | - impulse_noise 14 | - defocus_blur 15 | - glass_blur 16 | - motion_blur 17 | - zoom_blur 18 | - snow 19 | - frost 20 | - fog 21 | - brightness 22 | - contrast 23 | - elastic_transform 24 | - pixelate 25 | - jpeg_compression 26 | -------------------------------------------------------------------------------- /cifar/cfgs/cifar100/source.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | ADAPTATION: source 3 | ARCH: Hendrycks2020AugMix_ResNeXt 4 | TEST: 5 | BATCH_SIZE: 200 6 | CORRUPTION: 7 | DATASET: cifar100 8 | SEVERITY: 9 | - 5 10 | TYPE: 11 | - gaussian_noise 12 | - shot_noise 13 | - impulse_noise 14 | - defocus_blur 15 | - glass_blur 16 | - motion_blur 17 | - zoom_blur 18 | - snow 19 | - frost 20 | - fog 21 | - brightness 22 | - contrast 23 | - elastic_transform 24 | - pixelate 25 | - jpeg_compression 26 | -------------------------------------------------------------------------------- /cifar/cfgs/10orders/tent/tent0.yaml: -------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | DATASET: cifar10 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - brightness 7 | - pixelate 8 | - gaussian_noise 9 | - motion_blur 10 | - zoom_blur 11 | - glass_blur 12 | - impulse_noise 13 | - jpeg_compression 14 | - defocus_blur 15 | - elastic_transform 16 | - 
shot_noise 17 | - frost 18 | - snow 19 | - fog 20 | - contrast 21 | MODEL: 22 | ADAPTATION: tent 23 | ARCH: Standard 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 1e-3 27 | METHOD: Adam 28 | STEPS: 1 29 | WD: 0.0 30 | TEST: 31 | BATCH_SIZE: 200 32 | -------------------------------------------------------------------------------- /cifar/cfgs/10orders/tent/tent1.yaml: -------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | DATASET: cifar10 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - jpeg_compression 7 | - shot_noise 8 | - zoom_blur 9 | - frost 10 | - contrast 11 | - fog 12 | - defocus_blur 13 | - elastic_transform 14 | - gaussian_noise 15 | - brightness 16 | - glass_blur 17 | - impulse_noise 18 | - pixelate 19 | - snow 20 | - motion_blur 21 | MODEL: 22 | ADAPTATION: tent 23 | ARCH: Standard 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 1e-3 27 | METHOD: Adam 28 | STEPS: 1 29 | WD: 0.0 30 | TEST: 31 | BATCH_SIZE: 200 32 | -------------------------------------------------------------------------------- /cifar/cfgs/10orders/tent/tent2.yaml: -------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | DATASET: cifar10 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - contrast 7 | - defocus_blur 8 | - gaussian_noise 9 | - shot_noise 10 | - snow 11 | - frost 12 | - glass_blur 13 | - zoom_blur 14 | - elastic_transform 15 | - jpeg_compression 16 | - pixelate 17 | - brightness 18 | - impulse_noise 19 | - motion_blur 20 | - fog 21 | MODEL: 22 | ADAPTATION: tent 23 | ARCH: Standard 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 1e-3 27 | METHOD: Adam 28 | STEPS: 1 29 | WD: 0.0 30 | TEST: 31 | BATCH_SIZE: 200 32 | -------------------------------------------------------------------------------- /cifar/cfgs/10orders/tent/tent3.yaml: -------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | DATASET: cifar10 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - shot_noise 7 | - fog 8 | - glass_blur 9 | - pixelate 10 | - 
snow 11 | - elastic_transform 12 | - brightness 13 | - impulse_noise 14 | - defocus_blur 15 | - frost 16 | - contrast 17 | - gaussian_noise 18 | - motion_blur 19 | - jpeg_compression 20 | - zoom_blur 21 | MODEL: 22 | ADAPTATION: tent 23 | ARCH: Standard 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 1e-3 27 | METHOD: Adam 28 | STEPS: 1 29 | WD: 0.0 30 | TEST: 31 | BATCH_SIZE: 200 32 | -------------------------------------------------------------------------------- /cifar/cfgs/10orders/tent/tent4.yaml: -------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | DATASET: cifar10 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - pixelate 7 | - glass_blur 8 | - zoom_blur 9 | - snow 10 | - fog 11 | - impulse_noise 12 | - brightness 13 | - motion_blur 14 | - frost 15 | - jpeg_compression 16 | - gaussian_noise 17 | - shot_noise 18 | - contrast 19 | - defocus_blur 20 | - elastic_transform 21 | MODEL: 22 | ADAPTATION: tent 23 | ARCH: Standard 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 1e-3 27 | METHOD: Adam 28 | STEPS: 1 29 | WD: 0.0 30 | TEST: 31 | BATCH_SIZE: 200 32 | -------------------------------------------------------------------------------- /cifar/cfgs/10orders/tent/tent5.yaml: -------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | DATASET: cifar10 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - motion_blur 7 | - snow 8 | - fog 9 | - shot_noise 10 | - defocus_blur 11 | - contrast 12 | - zoom_blur 13 | - brightness 14 | - frost 15 | - elastic_transform 16 | - glass_blur 17 | - gaussian_noise 18 | - pixelate 19 | - jpeg_compression 20 | - impulse_noise 21 | MODEL: 22 | ADAPTATION: tent 23 | ARCH: Standard 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 1e-3 27 | METHOD: Adam 28 | STEPS: 1 29 | WD: 0.0 30 | TEST: 31 | BATCH_SIZE: 200 32 | -------------------------------------------------------------------------------- /cifar/cfgs/10orders/tent/tent6.yaml: 
-------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | DATASET: cifar10 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - frost 7 | - impulse_noise 8 | - jpeg_compression 9 | - contrast 10 | - zoom_blur 11 | - glass_blur 12 | - pixelate 13 | - snow 14 | - defocus_blur 15 | - motion_blur 16 | - brightness 17 | - elastic_transform 18 | - shot_noise 19 | - fog 20 | - gaussian_noise 21 | MODEL: 22 | ADAPTATION: tent 23 | ARCH: Standard 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 1e-3 27 | METHOD: Adam 28 | STEPS: 1 29 | WD: 0.0 30 | TEST: 31 | BATCH_SIZE: 200 32 | -------------------------------------------------------------------------------- /cifar/cfgs/10orders/tent/tent7.yaml: -------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | DATASET: cifar10 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - glass_blur 7 | - zoom_blur 8 | - impulse_noise 9 | - fog 10 | - snow 11 | - jpeg_compression 12 | - gaussian_noise 13 | - frost 14 | - shot_noise 15 | - brightness 16 | - contrast 17 | - motion_blur 18 | - pixelate 19 | - defocus_blur 20 | - elastic_transform 21 | MODEL: 22 | ADAPTATION: tent 23 | ARCH: Standard 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 1e-3 27 | METHOD: Adam 28 | STEPS: 1 29 | WD: 0.0 30 | TEST: 31 | BATCH_SIZE: 200 32 | -------------------------------------------------------------------------------- /cifar/cfgs/10orders/tent/tent8.yaml: -------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | DATASET: cifar10 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - defocus_blur 7 | - motion_blur 8 | - zoom_blur 9 | - shot_noise 10 | - gaussian_noise 11 | - glass_blur 12 | - jpeg_compression 13 | - fog 14 | - contrast 15 | - pixelate 16 | - frost 17 | - snow 18 | - brightness 19 | - elastic_transform 20 | - impulse_noise 21 | MODEL: 22 | ADAPTATION: tent 23 | ARCH: Standard 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 1e-3 27 | METHOD: Adam 28 | STEPS: 1 29 | WD: 0.0 30 | TEST: 31 | 
BATCH_SIZE: 200 32 | -------------------------------------------------------------------------------- /cifar/cfgs/10orders/tent/tent9.yaml: -------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | DATASET: cifar10 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - contrast 7 | - gaussian_noise 8 | - defocus_blur 9 | - zoom_blur 10 | - frost 11 | - glass_blur 12 | - jpeg_compression 13 | - fog 14 | - pixelate 15 | - elastic_transform 16 | - shot_noise 17 | - impulse_noise 18 | - snow 19 | - motion_blur 20 | - brightness 21 | MODEL: 22 | ADAPTATION: tent 23 | ARCH: Standard 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 1e-3 27 | METHOD: Adam 28 | STEPS: 1 29 | WD: 0.0 30 | TEST: 31 | BATCH_SIZE: 200 32 | -------------------------------------------------------------------------------- /imagenet/cfgs/10orders/cotta/cotta0.yaml: -------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | DATASET: imagenet 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - brightness 7 | - pixelate 8 | - gaussian_noise 9 | - motion_blur 10 | - zoom_blur 11 | - glass_blur 12 | - impulse_noise 13 | - jpeg_compression 14 | - defocus_blur 15 | - elastic_transform 16 | - shot_noise 17 | - frost 18 | - snow 19 | - fog 20 | - contrast 21 | MODEL: 22 | ADAPTATION: cotta 23 | ARCH: Standard_R50 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 0.01 27 | METHOD: SGD 28 | STEPS: 1 29 | WD: 0.0 30 | TEST: 31 | BATCH_SIZE: 64 32 | -------------------------------------------------------------------------------- /imagenet/cfgs/10orders/cotta/cotta1.yaml: -------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | DATASET: imagenet 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - jpeg_compression 7 | - shot_noise 8 | - zoom_blur 9 | - frost 10 | - contrast 11 | - fog 12 | - defocus_blur 13 | - elastic_transform 14 | - gaussian_noise 15 | - brightness 16 | - glass_blur 17 | - impulse_noise 18 | - pixelate 19 | - snow 20 | - motion_blur 21 | 
MODEL: 22 | ADAPTATION: cotta 23 | ARCH: Standard_R50 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 0.01 27 | METHOD: SGD 28 | STEPS: 1 29 | WD: 0.0 30 | TEST: 31 | BATCH_SIZE: 64 32 | -------------------------------------------------------------------------------- /imagenet/cfgs/10orders/cotta/cotta2.yaml: -------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | DATASET: imagenet 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - contrast 7 | - defocus_blur 8 | - gaussian_noise 9 | - shot_noise 10 | - snow 11 | - frost 12 | - glass_blur 13 | - zoom_blur 14 | - elastic_transform 15 | - jpeg_compression 16 | - pixelate 17 | - brightness 18 | - impulse_noise 19 | - motion_blur 20 | - fog 21 | MODEL: 22 | ADAPTATION: cotta 23 | ARCH: Standard_R50 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 0.01 27 | METHOD: SGD 28 | STEPS: 1 29 | WD: 0.0 30 | TEST: 31 | BATCH_SIZE: 64 32 | -------------------------------------------------------------------------------- /imagenet/cfgs/10orders/cotta/cotta3.yaml: -------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | DATASET: imagenet 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - shot_noise 7 | - fog 8 | - glass_blur 9 | - pixelate 10 | - snow 11 | - elastic_transform 12 | - brightness 13 | - impulse_noise 14 | - defocus_blur 15 | - frost 16 | - contrast 17 | - gaussian_noise 18 | - motion_blur 19 | - jpeg_compression 20 | - zoom_blur 21 | MODEL: 22 | ADAPTATION: cotta 23 | ARCH: Standard_R50 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 0.01 27 | METHOD: SGD 28 | STEPS: 1 29 | WD: 0.0 30 | TEST: 31 | BATCH_SIZE: 64 32 | -------------------------------------------------------------------------------- /imagenet/cfgs/10orders/cotta/cotta4.yaml: -------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | DATASET: imagenet 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - pixelate 7 | - glass_blur 8 | - zoom_blur 9 | - snow 10 | - fog 11 | - impulse_noise 12 | - 
brightness 13 | - motion_blur 14 | - frost 15 | - jpeg_compression 16 | - gaussian_noise 17 | - shot_noise 18 | - contrast 19 | - defocus_blur 20 | - elastic_transform 21 | MODEL: 22 | ADAPTATION: cotta 23 | ARCH: Standard_R50 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 0.01 27 | METHOD: SGD 28 | STEPS: 1 29 | WD: 0.0 30 | TEST: 31 | BATCH_SIZE: 64 32 | -------------------------------------------------------------------------------- /imagenet/cfgs/10orders/cotta/cotta5.yaml: -------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | DATASET: imagenet 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - motion_blur 7 | - snow 8 | - fog 9 | - shot_noise 10 | - defocus_blur 11 | - contrast 12 | - zoom_blur 13 | - brightness 14 | - frost 15 | - elastic_transform 16 | - glass_blur 17 | - gaussian_noise 18 | - pixelate 19 | - jpeg_compression 20 | - impulse_noise 21 | MODEL: 22 | ADAPTATION: cotta 23 | ARCH: Standard_R50 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 0.01 27 | METHOD: SGD 28 | STEPS: 1 29 | WD: 0.0 30 | TEST: 31 | BATCH_SIZE: 64 32 | -------------------------------------------------------------------------------- /imagenet/cfgs/10orders/cotta/cotta6.yaml: -------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | DATASET: imagenet 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - frost 7 | - impulse_noise 8 | - jpeg_compression 9 | - contrast 10 | - zoom_blur 11 | - glass_blur 12 | - pixelate 13 | - snow 14 | - defocus_blur 15 | - motion_blur 16 | - brightness 17 | - elastic_transform 18 | - shot_noise 19 | - fog 20 | - gaussian_noise 21 | MODEL: 22 | ADAPTATION: cotta 23 | ARCH: Standard_R50 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 0.01 27 | METHOD: SGD 28 | STEPS: 1 29 | WD: 0.0 30 | TEST: 31 | BATCH_SIZE: 64 32 | -------------------------------------------------------------------------------- /imagenet/cfgs/10orders/cotta/cotta7.yaml: -------------------------------------------------------------------------------- 1 | 
CORRUPTION: 2 | DATASET: imagenet 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - glass_blur 7 | - zoom_blur 8 | - impulse_noise 9 | - fog 10 | - snow 11 | - jpeg_compression 12 | - gaussian_noise 13 | - frost 14 | - shot_noise 15 | - brightness 16 | - contrast 17 | - motion_blur 18 | - pixelate 19 | - defocus_blur 20 | - elastic_transform 21 | MODEL: 22 | ADAPTATION: cotta 23 | ARCH: Standard_R50 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 0.01 27 | METHOD: SGD 28 | STEPS: 1 29 | WD: 0.0 30 | TEST: 31 | BATCH_SIZE: 64 32 | -------------------------------------------------------------------------------- /imagenet/cfgs/10orders/cotta/cotta8.yaml: -------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | DATASET: imagenet 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - defocus_blur 7 | - motion_blur 8 | - zoom_blur 9 | - shot_noise 10 | - gaussian_noise 11 | - glass_blur 12 | - jpeg_compression 13 | - fog 14 | - contrast 15 | - pixelate 16 | - frost 17 | - snow 18 | - brightness 19 | - elastic_transform 20 | - impulse_noise 21 | MODEL: 22 | ADAPTATION: cotta 23 | ARCH: Standard_R50 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 0.01 27 | METHOD: SGD 28 | STEPS: 1 29 | WD: 0.0 30 | TEST: 31 | BATCH_SIZE: 64 32 | -------------------------------------------------------------------------------- /imagenet/cfgs/10orders/cotta/cotta9.yaml: -------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | DATASET: imagenet 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - contrast 7 | - gaussian_noise 8 | - defocus_blur 9 | - zoom_blur 10 | - frost 11 | - glass_blur 12 | - jpeg_compression 13 | - fog 14 | - pixelate 15 | - elastic_transform 16 | - shot_noise 17 | - impulse_noise 18 | - snow 19 | - motion_blur 20 | - brightness 21 | MODEL: 22 | ADAPTATION: cotta 23 | ARCH: Standard_R50 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 0.01 27 | METHOD: SGD 28 | STEPS: 1 29 | WD: 0.0 30 | TEST: 31 | BATCH_SIZE: 64 32 | 
-------------------------------------------------------------------------------- /imagenet/cfgs/10orders/tent/tent0.yaml: -------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | DATASET: imagenet 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - brightness 7 | - pixelate 8 | - gaussian_noise 9 | - motion_blur 10 | - zoom_blur 11 | - glass_blur 12 | - impulse_noise 13 | - jpeg_compression 14 | - defocus_blur 15 | - elastic_transform 16 | - shot_noise 17 | - frost 18 | - snow 19 | - fog 20 | - contrast 21 | MODEL: 22 | ADAPTATION: tent 23 | ARCH: Standard_R50 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 0.00025 27 | METHOD: SGD 28 | STEPS: 1 29 | WD: 0.0 30 | TEST: 31 | BATCH_SIZE: 64 32 | -------------------------------------------------------------------------------- /imagenet/cfgs/10orders/tent/tent1.yaml: -------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | DATASET: imagenet 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - jpeg_compression 7 | - shot_noise 8 | - zoom_blur 9 | - frost 10 | - contrast 11 | - fog 12 | - defocus_blur 13 | - elastic_transform 14 | - gaussian_noise 15 | - brightness 16 | - glass_blur 17 | - impulse_noise 18 | - pixelate 19 | - snow 20 | - motion_blur 21 | MODEL: 22 | ADAPTATION: tent 23 | ARCH: Standard_R50 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 0.00025 27 | METHOD: SGD 28 | STEPS: 1 29 | WD: 0.0 30 | TEST: 31 | BATCH_SIZE: 64 32 | -------------------------------------------------------------------------------- /imagenet/cfgs/10orders/tent/tent2.yaml: -------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | DATASET: imagenet 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - contrast 7 | - defocus_blur 8 | - gaussian_noise 9 | - shot_noise 10 | - snow 11 | - frost 12 | - glass_blur 13 | - zoom_blur 14 | - elastic_transform 15 | - jpeg_compression 16 | - pixelate 17 | - brightness 18 | - impulse_noise 19 | - motion_blur 20 | - fog 21 | MODEL: 22 | 
ADAPTATION: tent 23 | ARCH: Standard_R50 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 0.00025 27 | METHOD: SGD 28 | STEPS: 1 29 | WD: 0.0 30 | TEST: 31 | BATCH_SIZE: 64 32 | -------------------------------------------------------------------------------- /imagenet/cfgs/10orders/tent/tent3.yaml: -------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | DATASET: imagenet 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - shot_noise 7 | - fog 8 | - glass_blur 9 | - pixelate 10 | - snow 11 | - elastic_transform 12 | - brightness 13 | - impulse_noise 14 | - defocus_blur 15 | - frost 16 | - contrast 17 | - gaussian_noise 18 | - motion_blur 19 | - jpeg_compression 20 | - zoom_blur 21 | MODEL: 22 | ADAPTATION: tent 23 | ARCH: Standard_R50 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 0.00025 27 | METHOD: SGD 28 | STEPS: 1 29 | WD: 0.0 30 | TEST: 31 | BATCH_SIZE: 64 32 | -------------------------------------------------------------------------------- /imagenet/cfgs/10orders/tent/tent4.yaml: -------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | DATASET: imagenet 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - pixelate 7 | - glass_blur 8 | - zoom_blur 9 | - snow 10 | - fog 11 | - impulse_noise 12 | - brightness 13 | - motion_blur 14 | - frost 15 | - jpeg_compression 16 | - gaussian_noise 17 | - shot_noise 18 | - contrast 19 | - defocus_blur 20 | - elastic_transform 21 | MODEL: 22 | ADAPTATION: tent 23 | ARCH: Standard_R50 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 0.00025 27 | METHOD: SGD 28 | STEPS: 1 29 | WD: 0.0 30 | TEST: 31 | BATCH_SIZE: 64 32 | -------------------------------------------------------------------------------- /imagenet/cfgs/10orders/tent/tent5.yaml: -------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | DATASET: imagenet 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - motion_blur 7 | - snow 8 | - fog 9 | - shot_noise 10 | - defocus_blur 11 | - contrast 12 | - zoom_blur 13 | - 
brightness 14 | - frost 15 | - elastic_transform 16 | - glass_blur 17 | - gaussian_noise 18 | - pixelate 19 | - jpeg_compression 20 | - impulse_noise 21 | MODEL: 22 | ADAPTATION: tent 23 | ARCH: Standard_R50 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 0.00025 27 | METHOD: SGD 28 | STEPS: 1 29 | WD: 0.0 30 | TEST: 31 | BATCH_SIZE: 64 32 | -------------------------------------------------------------------------------- /imagenet/cfgs/10orders/tent/tent6.yaml: -------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | DATASET: imagenet 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - frost 7 | - impulse_noise 8 | - jpeg_compression 9 | - contrast 10 | - zoom_blur 11 | - glass_blur 12 | - pixelate 13 | - snow 14 | - defocus_blur 15 | - motion_blur 16 | - brightness 17 | - elastic_transform 18 | - shot_noise 19 | - fog 20 | - gaussian_noise 21 | MODEL: 22 | ADAPTATION: tent 23 | ARCH: Standard_R50 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 0.00025 27 | METHOD: SGD 28 | STEPS: 1 29 | WD: 0.0 30 | TEST: 31 | BATCH_SIZE: 64 32 | -------------------------------------------------------------------------------- /imagenet/cfgs/10orders/tent/tent7.yaml: -------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | DATASET: imagenet 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - glass_blur 7 | - zoom_blur 8 | - impulse_noise 9 | - fog 10 | - snow 11 | - jpeg_compression 12 | - gaussian_noise 13 | - frost 14 | - shot_noise 15 | - brightness 16 | - contrast 17 | - motion_blur 18 | - pixelate 19 | - defocus_blur 20 | - elastic_transform 21 | MODEL: 22 | ADAPTATION: tent 23 | ARCH: Standard_R50 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 0.00025 27 | METHOD: SGD 28 | STEPS: 1 29 | WD: 0.0 30 | TEST: 31 | BATCH_SIZE: 64 32 | -------------------------------------------------------------------------------- /imagenet/cfgs/10orders/tent/tent8.yaml: -------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | 
DATASET: imagenet 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - defocus_blur 7 | - motion_blur 8 | - zoom_blur 9 | - shot_noise 10 | - gaussian_noise 11 | - glass_blur 12 | - jpeg_compression 13 | - fog 14 | - contrast 15 | - pixelate 16 | - frost 17 | - snow 18 | - brightness 19 | - elastic_transform 20 | - impulse_noise 21 | MODEL: 22 | ADAPTATION: tent 23 | ARCH: Standard_R50 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 0.00025 27 | METHOD: SGD 28 | STEPS: 1 29 | WD: 0.0 30 | TEST: 31 | BATCH_SIZE: 64 32 | -------------------------------------------------------------------------------- /imagenet/cfgs/10orders/tent/tent9.yaml: -------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | DATASET: imagenet 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - contrast 7 | - gaussian_noise 8 | - defocus_blur 9 | - zoom_blur 10 | - frost 11 | - glass_blur 12 | - jpeg_compression 13 | - fog 14 | - pixelate 15 | - elastic_transform 16 | - shot_noise 17 | - impulse_noise 18 | - snow 19 | - motion_blur 20 | - brightness 21 | MODEL: 22 | ADAPTATION: tent 23 | ARCH: Standard_R50 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 0.00025 27 | METHOD: SGD 28 | STEPS: 1 29 | WD: 0.0 30 | TEST: 31 | BATCH_SIZE: 64 32 | -------------------------------------------------------------------------------- /cifar/cfgs/cifar10/tent.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | ADAPTATION: tent 3 | ARCH: Standard 4 | TEST: 5 | BATCH_SIZE: 200 6 | CORRUPTION: 7 | DATASET: cifar10 8 | SEVERITY: 9 | - 5 10 | TYPE: 11 | - gaussian_noise 12 | - shot_noise 13 | - impulse_noise 14 | - defocus_blur 15 | - glass_blur 16 | - motion_blur 17 | - zoom_blur 18 | - snow 19 | - frost 20 | - fog 21 | - brightness 22 | - contrast 23 | - elastic_transform 24 | - pixelate 25 | - jpeg_compression 26 | OPTIM: 27 | METHOD: Adam 28 | STEPS: 1 29 | BETA: 0.9 30 | LR: 1e-3 31 | WD: 0. 
32 | -------------------------------------------------------------------------------- /imagenet/data/ImageNet-C/download.sh: -------------------------------------------------------------------------------- 1 | wget --content-disposition https://zenodo.org/record/2235448/files/blur.tar?download=1 2 | wget --content-disposition https://zenodo.org/record/2235448/files/digital.tar?download=1 3 | wget --content-disposition https://zenodo.org/record/2235448/files/extra.tar?download=1 4 | wget --content-disposition https://zenodo.org/record/2235448/files/noise.tar?download=1 5 | wget --content-disposition https://zenodo.org/record/2235448/files/weather.tar?download=1 6 | 7 | tar -zxvf blur.tar 8 | tar -zxvf digital.tar 9 | tar -zxvf extra.tar 10 | tar -zxvf noise.tar 11 | tar -zxvf weather.tar 12 | -------------------------------------------------------------------------------- /cifar/cfgs/cifar100/tent.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | ADAPTATION: tent 3 | ARCH: Hendrycks2020AugMix_ResNeXt 4 | TEST: 5 | BATCH_SIZE: 200 6 | CORRUPTION: 7 | DATASET: cifar100 8 | SEVERITY: 9 | - 5 10 | TYPE: 11 | - gaussian_noise 12 | - shot_noise 13 | - impulse_noise 14 | - defocus_blur 15 | - glass_blur 16 | - motion_blur 17 | - zoom_blur 18 | - snow 19 | - frost 20 | - fog 21 | - brightness 22 | - contrast 23 | - elastic_transform 24 | - pixelate 25 | - jpeg_compression 26 | OPTIM: 27 | METHOD: Adam 28 | STEPS: 1 29 | BETA: 0.9 30 | LR: 1e-3 31 | WD: 0. 
32 | -------------------------------------------------------------------------------- /cifar/cfgs/10orders/cotta/cotta0.yaml: -------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | DATASET: cifar10 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - brightness 7 | - pixelate 8 | - gaussian_noise 9 | - motion_blur 10 | - zoom_blur 11 | - glass_blur 12 | - impulse_noise 13 | - jpeg_compression 14 | - defocus_blur 15 | - elastic_transform 16 | - shot_noise 17 | - frost 18 | - snow 19 | - fog 20 | - contrast 21 | MODEL: 22 | ADAPTATION: cotta 23 | ARCH: Standard 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 1e-3 27 | METHOD: Adam 28 | STEPS: 1 29 | WD: 0.0 30 | MT: 0.999 31 | RST: 0.01 32 | AP: 0.92 33 | TEST: 34 | BATCH_SIZE: 200 35 | -------------------------------------------------------------------------------- /cifar/cfgs/10orders/cotta/cotta1.yaml: -------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | DATASET: cifar10 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - jpeg_compression 7 | - shot_noise 8 | - zoom_blur 9 | - frost 10 | - contrast 11 | - fog 12 | - defocus_blur 13 | - elastic_transform 14 | - gaussian_noise 15 | - brightness 16 | - glass_blur 17 | - impulse_noise 18 | - pixelate 19 | - snow 20 | - motion_blur 21 | MODEL: 22 | ADAPTATION: cotta 23 | ARCH: Standard 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 1e-3 27 | METHOD: Adam 28 | STEPS: 1 29 | WD: 0.0 30 | MT: 0.999 31 | RST: 0.01 32 | AP: 0.92 33 | TEST: 34 | BATCH_SIZE: 200 35 | -------------------------------------------------------------------------------- /cifar/cfgs/10orders/cotta/cotta2.yaml: -------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | DATASET: cifar10 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - contrast 7 | - defocus_blur 8 | - gaussian_noise 9 | - shot_noise 10 | - snow 11 | - frost 12 | - glass_blur 13 | - zoom_blur 14 | - elastic_transform 15 | - jpeg_compression 16 | - pixelate 17 | - 
brightness 18 | - impulse_noise 19 | - motion_blur 20 | - fog 21 | MODEL: 22 | ADAPTATION: cotta 23 | ARCH: Standard 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 1e-3 27 | METHOD: Adam 28 | STEPS: 1 29 | WD: 0.0 30 | MT: 0.999 31 | RST: 0.01 32 | AP: 0.92 33 | TEST: 34 | BATCH_SIZE: 200 35 | -------------------------------------------------------------------------------- /cifar/cfgs/10orders/cotta/cotta3.yaml: -------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | DATASET: cifar10 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - shot_noise 7 | - fog 8 | - glass_blur 9 | - pixelate 10 | - snow 11 | - elastic_transform 12 | - brightness 13 | - impulse_noise 14 | - defocus_blur 15 | - frost 16 | - contrast 17 | - gaussian_noise 18 | - motion_blur 19 | - jpeg_compression 20 | - zoom_blur 21 | MODEL: 22 | ADAPTATION: cotta 23 | ARCH: Standard 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 1e-3 27 | METHOD: Adam 28 | STEPS: 1 29 | WD: 0.0 30 | MT: 0.999 31 | RST: 0.01 32 | AP: 0.92 33 | TEST: 34 | BATCH_SIZE: 200 35 | -------------------------------------------------------------------------------- /cifar/cfgs/10orders/cotta/cotta4.yaml: -------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | DATASET: cifar10 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - pixelate 7 | - glass_blur 8 | - zoom_blur 9 | - snow 10 | - fog 11 | - impulse_noise 12 | - brightness 13 | - motion_blur 14 | - frost 15 | - jpeg_compression 16 | - gaussian_noise 17 | - shot_noise 18 | - contrast 19 | - defocus_blur 20 | - elastic_transform 21 | MODEL: 22 | ADAPTATION: cotta 23 | ARCH: Standard 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 1e-3 27 | METHOD: Adam 28 | STEPS: 1 29 | WD: 0.0 30 | MT: 0.999 31 | RST: 0.01 32 | AP: 0.92 33 | TEST: 34 | BATCH_SIZE: 200 35 | -------------------------------------------------------------------------------- /cifar/cfgs/10orders/cotta/cotta5.yaml: 
-------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | DATASET: cifar10 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - motion_blur 7 | - snow 8 | - fog 9 | - shot_noise 10 | - defocus_blur 11 | - contrast 12 | - zoom_blur 13 | - brightness 14 | - frost 15 | - elastic_transform 16 | - glass_blur 17 | - gaussian_noise 18 | - pixelate 19 | - jpeg_compression 20 | - impulse_noise 21 | MODEL: 22 | ADAPTATION: cotta 23 | ARCH: Standard 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 1e-3 27 | METHOD: Adam 28 | STEPS: 1 29 | WD: 0.0 30 | MT: 0.999 31 | RST: 0.01 32 | AP: 0.92 33 | TEST: 34 | BATCH_SIZE: 200 35 | -------------------------------------------------------------------------------- /cifar/cfgs/10orders/cotta/cotta6.yaml: -------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | DATASET: cifar10 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - frost 7 | - impulse_noise 8 | - jpeg_compression 9 | - contrast 10 | - zoom_blur 11 | - glass_blur 12 | - pixelate 13 | - snow 14 | - defocus_blur 15 | - motion_blur 16 | - brightness 17 | - elastic_transform 18 | - shot_noise 19 | - fog 20 | - gaussian_noise 21 | MODEL: 22 | ADAPTATION: cotta 23 | ARCH: Standard 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 1e-3 27 | METHOD: Adam 28 | STEPS: 1 29 | WD: 0.0 30 | MT: 0.999 31 | RST: 0.01 32 | AP: 0.92 33 | TEST: 34 | BATCH_SIZE: 200 35 | -------------------------------------------------------------------------------- /cifar/cfgs/10orders/cotta/cotta7.yaml: -------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | DATASET: cifar10 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - glass_blur 7 | - zoom_blur 8 | - impulse_noise 9 | - fog 10 | - snow 11 | - jpeg_compression 12 | - gaussian_noise 13 | - frost 14 | - shot_noise 15 | - brightness 16 | - contrast 17 | - motion_blur 18 | - pixelate 19 | - defocus_blur 20 | - elastic_transform 21 | MODEL: 22 | ADAPTATION: cotta 23 | ARCH: Standard 24 | 
OPTIM: 25 | BETA: 0.9 26 | LR: 1e-3 27 | METHOD: Adam 28 | STEPS: 1 29 | WD: 0.0 30 | MT: 0.999 31 | RST: 0.01 32 | AP: 0.92 33 | TEST: 34 | BATCH_SIZE: 200 35 | -------------------------------------------------------------------------------- /cifar/cfgs/10orders/cotta/cotta8.yaml: -------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | DATASET: cifar10 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - defocus_blur 7 | - motion_blur 8 | - zoom_blur 9 | - shot_noise 10 | - gaussian_noise 11 | - glass_blur 12 | - jpeg_compression 13 | - fog 14 | - contrast 15 | - pixelate 16 | - frost 17 | - snow 18 | - brightness 19 | - elastic_transform 20 | - impulse_noise 21 | MODEL: 22 | ADAPTATION: cotta 23 | ARCH: Standard 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 1e-3 27 | METHOD: Adam 28 | STEPS: 1 29 | WD: 0.0 30 | MT: 0.999 31 | RST: 0.01 32 | AP: 0.92 33 | TEST: 34 | BATCH_SIZE: 200 35 | -------------------------------------------------------------------------------- /cifar/cfgs/10orders/cotta/cotta9.yaml: -------------------------------------------------------------------------------- 1 | CORRUPTION: 2 | DATASET: cifar10 3 | SEVERITY: 4 | - 5 5 | TYPE: 6 | - contrast 7 | - gaussian_noise 8 | - defocus_blur 9 | - zoom_blur 10 | - frost 11 | - glass_blur 12 | - jpeg_compression 13 | - fog 14 | - pixelate 15 | - elastic_transform 16 | - shot_noise 17 | - impulse_noise 18 | - snow 19 | - motion_blur 20 | - brightness 21 | MODEL: 22 | ADAPTATION: cotta 23 | ARCH: Standard 24 | OPTIM: 25 | BETA: 0.9 26 | LR: 1e-3 27 | METHOD: Adam 28 | STEPS: 1 29 | WD: 0.0 30 | MT: 0.999 31 | RST: 0.01 32 | AP: 0.92 33 | TEST: 34 | BATCH_SIZE: 200 35 | -------------------------------------------------------------------------------- /cifar/cfgs/cifar10/cotta.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | ADAPTATION: cotta 3 | ARCH: Standard 4 | TEST: 5 | BATCH_SIZE: 200 6 | CORRUPTION: 7 | DATASET: cifar10 8 | 
SEVERITY: 9 | - 5 10 | TYPE: 11 | - gaussian_noise 12 | - shot_noise 13 | - impulse_noise 14 | - defocus_blur 15 | - glass_blur 16 | - motion_blur 17 | - zoom_blur 18 | - snow 19 | - frost 20 | - fog 21 | - brightness 22 | - contrast 23 | - elastic_transform 24 | - pixelate 25 | - jpeg_compression 26 | OPTIM: 27 | METHOD: Adam 28 | STEPS: 1 29 | BETA: 0.9 30 | LR: 1e-3 31 | WD: 0. 32 | MT: 0.999 33 | RST: 0.01 34 | AP: 0.92 35 | -------------------------------------------------------------------------------- /cifar/cfgs/cifar100/cotta.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | ADAPTATION: cotta 3 | ARCH: Hendrycks2020AugMix_ResNeXt 4 | TEST: 5 | BATCH_SIZE: 200 6 | CORRUPTION: 7 | DATASET: cifar100 8 | SEVERITY: 9 | - 5 10 | TYPE: 11 | - gaussian_noise 12 | - shot_noise 13 | - impulse_noise 14 | - defocus_blur 15 | - glass_blur 16 | - motion_blur 17 | - zoom_blur 18 | - snow 19 | - frost 20 | - fog 21 | - brightness 22 | - contrast 23 | - elastic_transform 24 | - pixelate 25 | - jpeg_compression 26 | OPTIM: 27 | METHOD: Adam 28 | STEPS: 1 29 | BETA: 0.9 30 | LR: 1e-3 31 | WD: 0. 
32 | MT: 0.999 33 | RST: 0.01 34 | AP: 0.72 35 | -------------------------------------------------------------------------------- /cifar/robustbench/model_zoo/models.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Any, Dict, Dict as OrderedDictType 3 | 4 | from robustbench.model_zoo.cifar10 import cifar_10_models 5 | from robustbench.model_zoo.cifar100 import cifar_100_models 6 | from robustbench.model_zoo.imagenet import imagenet_models 7 | from robustbench.model_zoo.enums import BenchmarkDataset, ThreatModel 8 | 9 | ModelsDict = OrderedDictType[str, Dict[str, Any]] 10 | ThreatModelsDict = OrderedDictType[ThreatModel, ModelsDict] 11 | BenchmarkDict = OrderedDictType[BenchmarkDataset, ThreatModelsDict] 12 | 13 | model_dicts: BenchmarkDict = OrderedDict([ 14 | (BenchmarkDataset.cifar_10, cifar_10_models), 15 | (BenchmarkDataset.cifar_100, cifar_100_models), 16 | (BenchmarkDataset.imagenet, imagenet_models) 17 | ]) 18 | -------------------------------------------------------------------------------- /imagenet/robustbench/model_zoo/models.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Any, Dict, Dict as OrderedDictType 3 | 4 | from robustbench.model_zoo.cifar10 import cifar_10_models 5 | from robustbench.model_zoo.cifar100 import cifar_100_models 6 | from robustbench.model_zoo.imagenet import imagenet_models 7 | from robustbench.model_zoo.enums import BenchmarkDataset, ThreatModel 8 | 9 | ModelsDict = OrderedDictType[str, Dict[str, Any]] 10 | ThreatModelsDict = OrderedDictType[ThreatModel, ModelsDict] 11 | BenchmarkDict = OrderedDictType[BenchmarkDataset, ThreatModelsDict] 12 | 13 | model_dicts: BenchmarkDict = OrderedDict([ 14 | (BenchmarkDataset.cifar_10, cifar_10_models), 15 | (BenchmarkDataset.cifar_100, cifar_100_models), 16 | (BenchmarkDataset.imagenet, 
imagenet_models) 17 | ]) 18 | -------------------------------------------------------------------------------- /imagenet/run.sh: -------------------------------------------------------------------------------- 1 | #source /cluster/home/qwang/miniconda3/etc/profile.d/conda.sh 2 | # Clean PATH and only use cotta env 3 | export PYTHONPATH= 4 | conda deactivate 5 | conda activate cotta 6 | # Source-only and AdaBN results are not affected by the order as no training is performed. Therefore only need to run once. 7 | CUDA_VISIBLE_DEVICES=0 python -u imagenetc.py --cfg cfgs/source.yaml 8 | CUDA_VISIBLE_DEVICES=0 python -u imagenetc.py --cfg cfgs/norm.yaml 9 | # TENT and CoTTA results are affected by the corruption sequence order 10 | for i in {0..9} 11 | do 12 | CUDA_VISIBLE_DEVICES=0 python -u imagenetc.py --cfg cfgs/10orders/tent/tent$i.yaml 13 | CUDA_VISIBLE_DEVICES=0 python -u imagenetc.py --cfg cfgs/10orders/cotta/cotta$i.yaml 14 | done 15 | # Run Mean and AVG for TENT and CoTTA 16 | cd output 17 | python3 -u ../eval.py | tee result.log 18 | -------------------------------------------------------------------------------- /imagenet/eval.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | import numpy as np 3 | 4 | 5 | def read_file(filename): 6 | lines = open(filename, "r").readlines() 7 | res = [] 8 | for l in lines: 9 | if "error" in l and "]: error % [" in l: 10 | res.append(float(l.strip().split(" ")[-1][:-1])) 11 | assert len(res)==15 12 | return np.mean(np.array(res)) 13 | 14 | 15 | def read_files(files): 16 | res = [] 17 | for f in files: 18 | res.append(read_file(f)) 19 | print("read", len(files), "files.") 20 | print(res) 21 | return np.mean(np.array(res)), np.std(np.array(res)) 22 | 23 | 24 | print("read source files:") 25 | print(read_files(glob("source_*.txt"))) 26 | 27 | print("read adabn files:") 28 | print(read_files(glob("norm_*.txt"))) 29 | 30 | print("read tent files:") 31 | 
print(read_files(glob("tent[0-9]_*.txt"))) 32 | 33 | print("read cotta files:") 34 | print(read_files(glob("cotta[0-9]_*.txt"))) 35 | -------------------------------------------------------------------------------- /cifar/robustbench/model_zoo/architectures/utils_architectures.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from collections import OrderedDict 4 | from typing import Tuple 5 | from torch import Tensor 6 | 7 | 8 | class ImageNormalizer(nn.Module): 9 | def __init__(self, mean: Tuple[float, float, float], 10 | std: Tuple[float, float, float]) -> None: 11 | super(ImageNormalizer, self).__init__() 12 | 13 | self.register_buffer('mean', torch.as_tensor(mean).view(1, 3, 1, 1)) 14 | self.register_buffer('std', torch.as_tensor(std).view(1, 3, 1, 1)) 15 | 16 | def forward(self, input: Tensor) -> Tensor: 17 | return (input - self.mean) / self.std 18 | 19 | 20 | def normalize_model(model: nn.Module, mean: Tuple[float, float, float], 21 | std: Tuple[float, float, float]) -> nn.Module: 22 | layers = OrderedDict([ 23 | ('normalize', ImageNormalizer(mean, std)), 24 | ('model', model) 25 | ]) 26 | return nn.Sequential(layers) 27 | 28 | -------------------------------------------------------------------------------- /imagenet/robustbench/model_zoo/architectures/utils_architectures.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from collections import OrderedDict 4 | from typing import Tuple 5 | from torch import Tensor 6 | 7 | 8 | class ImageNormalizer(nn.Module): 9 | def __init__(self, mean: Tuple[float, float, float], 10 | std: Tuple[float, float, float]) -> None: 11 | super(ImageNormalizer, self).__init__() 12 | 13 | self.register_buffer('mean', torch.as_tensor(mean).view(1, 3, 1, 1)) 14 | self.register_buffer('std', torch.as_tensor(std).view(1, 3, 1, 1)) 15 | 16 | def forward(self, input: Tensor) 
-> Tensor: 17 | return (input - self.mean) / self.std 18 | 19 | 20 | def normalize_model(model: nn.Module, mean: Tuple[float, float, float], 21 | std: Tuple[float, float, float]) -> nn.Module: 22 | layers = OrderedDict([ 23 | ('normalize', ImageNormalizer(mean, std)), 24 | ('model', model) 25 | ]) 26 | return nn.Sequential(layers) 27 | 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Qin Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 23 | 24 | 25 | ### Our code is based on TENT, therefore a copy of TENT's license is included here 26 | MIT License 27 | 28 | Copyright (c) 2021 Dequan Wang and Evan Shelhamer 29 | 30 | Permission is hereby granted, free of charge, to any person obtaining a copy 31 | of this software and associated documentation files (the "Software"), to deal 32 | in the Software without restriction, including without limitation the rights 33 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 34 | copies of the Software, and to permit persons to whom the Software is 35 | furnished to do so, subject to the following conditions: 36 | 37 | The above copyright notice and this permission notice shall be included in all 38 | copies or substantial portions of the Software. 39 | 40 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 41 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 42 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 43 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 44 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 45 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 46 | SOFTWARE. 47 | 48 | -------------------------------------------------------------------------------- /cifar/norm.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class Norm(nn.Module): 8 | """Norm adapts a model by estimating feature statistics during testing. 9 | 10 | Once equipped with Norm, the model normalizes its features during testing 11 | with batch-wise statistics, just like batch norm does during training. 
12 | """ 13 | 14 | def __init__(self, model, eps=1e-5, momentum=0.1, 15 | reset_stats=False, no_stats=False): 16 | super().__init__() 17 | self.model = model 18 | self.model = configure_model(model, eps, momentum, reset_stats, 19 | no_stats) 20 | self.model_state = deepcopy(self.model.state_dict()) 21 | 22 | def forward(self, x): 23 | return self.model(x) 24 | 25 | def reset(self): 26 | self.model.load_state_dict(self.model_state, strict=True) 27 | 28 | 29 | def collect_stats(model): 30 | """Collect the normalization stats from batch norms. 31 | 32 | Walk the model's modules and collect all batch normalization stats. 33 | Return the stats and their names. 34 | """ 35 | stats = [] 36 | names = [] 37 | for nm, m in model.named_modules(): 38 | if isinstance(m, nn.BatchNorm2d): 39 | state = m.state_dict() 40 | if m.affine: 41 | del state['weight'], state['bias'] 42 | for ns, s in state.items(): 43 | stats.append(s) 44 | names.append(f"{nm}.{ns}") 45 | return stats, names 46 | 47 | 48 | def configure_model(model, eps, momentum, reset_stats, no_stats): 49 | """Configure model for adaptation by test-time normalization.""" 50 | for m in model.modules(): 51 | if isinstance(m, nn.BatchNorm2d): 52 | # use batch-wise statistics in forward 53 | m.train() 54 | # configure epsilon for stability, and momentum for updates 55 | m.eps = eps 56 | m.momentum = momentum 57 | if reset_stats: 58 | # reset state to estimate test stats without train stats 59 | m.reset_running_stats() 60 | if no_stats: 61 | # disable state entirely and use only batch stats 62 | m.track_running_stats = False 63 | m.running_mean = None 64 | m.running_var = None 65 | return model 66 | -------------------------------------------------------------------------------- /imagenet/norm.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class Norm(nn.Module): 8 | """Norm adapts a model by 
estimating feature statistics during testing. 9 | 10 | Once equipped with Norm, the model normalizes its features during testing 11 | with batch-wise statistics, just like batch norm does during training. 12 | """ 13 | 14 | def __init__(self, model, eps=1e-5, momentum=0.1, 15 | reset_stats=False, no_stats=False): 16 | super().__init__() 17 | self.model = model 18 | self.model = configure_model(model, eps, momentum, reset_stats, 19 | no_stats) 20 | self.model_state = deepcopy(self.model.state_dict()) 21 | 22 | def forward(self, x): 23 | return self.model(x) 24 | 25 | def reset(self): 26 | self.model.load_state_dict(self.model_state, strict=True) 27 | 28 | 29 | def collect_stats(model): 30 | """Collect the normalization stats from batch norms. 31 | 32 | Walk the model's modules and collect all batch normalization stats. 33 | Return the stats and their names. 34 | """ 35 | stats = [] 36 | names = [] 37 | for nm, m in model.named_modules(): 38 | if isinstance(m, nn.BatchNorm2d): 39 | state = m.state_dict() 40 | if m.affine: 41 | del state['weight'], state['bias'] 42 | for ns, s in state.items(): 43 | stats.append(s) 44 | names.append(f"{nm}.{ns}") 45 | return stats, names 46 | 47 | 48 | def configure_model(model, eps, momentum, reset_stats, no_stats): 49 | """Configure model for adaptation by test-time normalization.""" 50 | for m in model.modules(): 51 | if isinstance(m, nn.BatchNorm2d): 52 | # use batch-wise statistics in forward 53 | m.train() 54 | # configure epsilon for stability, and momentum for updates 55 | m.eps = eps 56 | m.momentum = momentum 57 | if reset_stats: 58 | # reset state to estimate test stats without train stats 59 | m.reset_running_stats() 60 | if no_stats: 61 | # disable state entirely and use only batch stats 62 | m.track_running_stats = False 63 | m.running_mean = None 64 | m.running_var = None 65 | return model 66 | -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /cifar/robustbench/zenodo_download.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import shutil 3 | from pathlib import Path 4 | from typing import Set 5 | 6 | import requests 7 | from tqdm import tqdm 8 | 9 | ZENODO_ENTRY_POINT = "https://zenodo.org/api" 10 | RECORDS_ENTRY_POINT = f"{ZENODO_ENTRY_POINT}/records/" 11 | 12 | CHUNK_SIZE = 65536 13 | 14 | 15 | class DownloadError(Exception): 16 | pass 17 | 18 | 19 | def download_file(url: str, save_dir: Path, total_bytes: int) -> Path: 20 | """Downloads large files from the given URL. 21 | 22 | From: https://stackoverflow.com/a/16696317 23 | 24 | :param url: The URL of the file. 25 | :param save_dir: The directory where the file should be saved. 26 | :param total_bytes: The total bytes of the file. 27 | :return: The path to the downloaded file. 
28 | """ 29 | local_filename = save_dir / url.split('/')[-1] 30 | print(f"Starting download from {url}") 31 | with requests.get(url, stream=True) as r: 32 | r.raise_for_status() 33 | with open(local_filename, 'wb') as f: 34 | iters = total_bytes // CHUNK_SIZE 35 | for chunk in tqdm(r.iter_content(chunk_size=CHUNK_SIZE), 36 | total=iters): 37 | f.write(chunk) 38 | 39 | return local_filename 40 | 41 | 42 | def file_md5(filename: Path) -> str: 43 | """Computes the MD5 hash of a given file""" 44 | hash_md5 = hashlib.md5() 45 | with open(filename, "rb") as f: 46 | for chunk in iter(lambda: f.read(32768), b""): 47 | hash_md5.update(chunk) 48 | 49 | return hash_md5.hexdigest() 50 | 51 | 52 | def zenodo_download(record_id: str, filenames_to_download: Set[str], 53 | save_dir: Path) -> None: 54 | """Downloads the given files from the given Zenodo record. 55 | 56 | :param record_id: The ID of the record. 57 | :param filenames_to_download: The files to download from the record. 58 | :param save_dir: The directory where the files should be saved. 
59 | """ 60 | if not save_dir.exists(): 61 | save_dir.mkdir(parents=True, exist_ok=True) 62 | 63 | url = f"{RECORDS_ENTRY_POINT}/{record_id}" 64 | res = requests.get(url) 65 | files = res.json()["files"] 66 | files_to_download = list( 67 | filter(lambda file: file["key"] in filenames_to_download, files)) 68 | 69 | for file in files_to_download: 70 | if (save_dir / file["key"]).exists(): 71 | continue 72 | file_url = file["links"]["self"] 73 | file_checksum = file["checksum"].split(":")[-1] 74 | filename = download_file(file_url, save_dir, file["size"]) 75 | if file_md5(filename) != file_checksum: 76 | raise DownloadError( 77 | "The hash of the downloaded file does not match" 78 | " the expected one.") 79 | print("Download finished, extracting...") 80 | shutil.unpack_archive(filename, 81 | extract_dir=save_dir, 82 | format=file["type"]) 83 | print("Downloaded and extracted.") 84 | -------------------------------------------------------------------------------- /imagenet/robustbench/zenodo_download.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import shutil 3 | from pathlib import Path 4 | from typing import Set 5 | 6 | import requests 7 | from tqdm import tqdm 8 | 9 | ZENODO_ENTRY_POINT = "https://zenodo.org/api" 10 | RECORDS_ENTRY_POINT = f"{ZENODO_ENTRY_POINT}/records/" 11 | 12 | CHUNK_SIZE = 65536 13 | 14 | 15 | class DownloadError(Exception): 16 | pass 17 | 18 | 19 | def download_file(url: str, save_dir: Path, total_bytes: int) -> Path: 20 | """Downloads large files from the given URL. 21 | 22 | From: https://stackoverflow.com/a/16696317 23 | 24 | :param url: The URL of the file. 25 | :param save_dir: The directory where the file should be saved. 26 | :param total_bytes: The total bytes of the file. 27 | :return: The path to the downloaded file. 
28 | """ 29 | local_filename = save_dir / url.split('/')[-1] 30 | print(f"Starting download from {url}") 31 | with requests.get(url, stream=True) as r: 32 | r.raise_for_status() 33 | with open(local_filename, 'wb') as f: 34 | iters = total_bytes // CHUNK_SIZE 35 | for chunk in tqdm(r.iter_content(chunk_size=CHUNK_SIZE), 36 | total=iters): 37 | f.write(chunk) 38 | 39 | return local_filename 40 | 41 | 42 | def file_md5(filename: Path) -> str: 43 | """Computes the MD5 hash of a given file""" 44 | hash_md5 = hashlib.md5() 45 | with open(filename, "rb") as f: 46 | for chunk in iter(lambda: f.read(32768), b""): 47 | hash_md5.update(chunk) 48 | 49 | return hash_md5.hexdigest() 50 | 51 | 52 | def zenodo_download(record_id: str, filenames_to_download: Set[str], 53 | save_dir: Path) -> None: 54 | """Downloads the given files from the given Zenodo record. 55 | 56 | :param record_id: The ID of the record. 57 | :param filenames_to_download: The files to download from the record. 58 | :param save_dir: The directory where the files should be saved. 
59 | """ 60 | if not save_dir.exists(): 61 | save_dir.mkdir(parents=True, exist_ok=True) 62 | 63 | url = f"{RECORDS_ENTRY_POINT}/{record_id}" 64 | res = requests.get(url) 65 | files = res.json()["files"] 66 | files_to_download = list( 67 | filter(lambda file: file["key"] in filenames_to_download, files)) 68 | 69 | for file in files_to_download: 70 | if (save_dir / file["key"]).exists(): 71 | continue 72 | file_url = file["links"]["self"] 73 | file_checksum = file["checksum"].split(":")[-1] 74 | filename = download_file(file_url, save_dir, file["size"]) 75 | if file_md5(filename) != file_checksum: 76 | raise DownloadError( 77 | "The hash of the downloaded file does not match" 78 | " the expected one.") 79 | print("Download finished, extracting...") 80 | shutil.unpack_archive(filename, 81 | extract_dir=save_dir, 82 | format=file["type"]) 83 | print("Downloaded and extracted.") 84 | -------------------------------------------------------------------------------- /cifar/robustbench/leaderboard/leaderboard.html.j2: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 10 | {% if threat_model != "corruptions" %} 11 | 16 | 21 | 26 | {% endif %} 27 | {% if threat_model == "corruptions" %} 28 | 32 | {% endif %} 33 | 34 | 35 | 36 | 37 | 38 | 39 | {% for model in models %} 40 | 41 | 42 | 51 | 52 | 53 | {% if threat_model != "corruptions" %} 54 | 55 | 56 | {% endif %} 57 | 58 | 59 | 60 | 61 | {% endfor %} 62 | 63 |
RankMethod 7 | Standard
8 | accuracy 9 |
12 | AutoAttack
13 | robust
14 | accuracy 15 |
17 | Best known
18 | robust
19 | accuracy 20 |
22 | AA eval.
23 | potentially
24 | unreliable 25 |
29 | Robust
30 | accuracy 31 |
Extra
data
ArchitectureVenue
{{ loop.index }} 43 | {{ model.name }} 44 | {% if model.footnote is defined and model.footnote != None %} 45 |
46 | 47 | {{ model.footnote }} 48 | 49 | {% endif %} 50 |
{{ model.clean_acc }}%{{ model[acc_field] }}%{{ model.external if model.external is defined and model.external else model[acc_field]}}%{{ "Unknown" if model.unreliable is not defined else ("
" if model.unreliable else "
×
") }}
{{ "☑" if model.additional_data else "×" }}{{ model.architecture }}{{ model.venue }}
64 | -------------------------------------------------------------------------------- /imagenet/robustbench/leaderboard/leaderboard.html.j2: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 10 | {% if threat_model != "corruptions" %} 11 | 16 | 21 | 26 | {% endif %} 27 | {% if threat_model == "corruptions" %} 28 | 32 | {% endif %} 33 | 34 | 35 | 36 | 37 | 38 | 39 | {% for model in models %} 40 | 41 | 42 | 51 | 52 | 53 | {% if threat_model != "corruptions" %} 54 | 55 | 56 | {% endif %} 57 | 58 | 59 | 60 | 61 | {% endfor %} 62 | 63 |
RankMethod 7 | Standard
8 | accuracy 9 |
12 | AutoAttack
13 | robust
14 | accuracy 15 |
17 | Best known
18 | robust
19 | accuracy 20 |
22 | AA eval.
23 | potentially
24 | unreliable 25 |
29 | Robust
30 | accuracy 31 |
Extra
data
ArchitectureVenue
{{ loop.index }} 43 | {{ model.name }} 44 | {% if model.footnote is defined and model.footnote != None %} 45 |
46 | 47 | {{ model.footnote }} 48 | 49 | {% endif %} 50 |
{{ model.clean_acc }}%{{ model[acc_field] }}%{{ model.external if model.external is defined and model.external else model[acc_field]}}%{{ "Unknown" if model.unreliable is not defined else ("
" if model.unreliable else "
×
") }}
{{ "☑" if model.additional_data else "×" }}{{ model.architecture }}{{ model.venue }}
64 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CoTTA: Continual Test-Time Adaptation 2 | Official code for [Continual Test-Time Domain Adaptation](https://arxiv.org/abs/2203.13591), published in CVPR 2022. 3 | 4 | This repository also includes other continual test-time adaptation methods for classification and segmentation. 5 | We provide benchmarking and comparison for the following methods: 6 | + [CoTTA](https://arxiv.org/abs/2203.13591) 7 | + AdaBN / BN Adapt 8 | + TENT 9 | 10 | on the following tasks 11 | + CIFAR10/100 -> CIFAR10C/100C (standard/gradual) 12 | + ImageNet -> ImageNetC 13 | + Cityscapes -> ACDC (segmentation) 14 | 15 | ## Prerequisite 16 | Please create and activate the following conda environment. To reproduce our results, please kindly create and use this environment. 17 | ```bash 18 | # It may take several minutes for conda to solve the environment 19 | conda update conda 20 | conda env create -f environment.yml 21 | conda activate cotta 22 | ``` 23 | 24 | ## Classification Experiments 25 | ### CIFAR10-to-CIFAR10C-standard task 26 | ```bash 27 | # Tested on RTX2080TI 28 | cd cifar 29 | # This includes the comparison of all three methods as well as the baseline 30 | bash run_cifar10.sh 31 | ``` 32 | ### CIFAR10-to-CIFAR10C-gradual task 33 | ```bash 34 | # Tested on RTX2080TI 35 | bash run_cifar10_gradual.sh 36 | ``` 37 | ### CIFAR100-to-CIFAR100C task 38 | ```bash 39 | # Tested on RTX3090 40 | bash run_cifar100.sh 41 | ``` 42 | 43 | ### ImageNet-to-ImageNetC task 44 | ```bash 45 | # Tested on RTX3090 46 | cd imagenet 47 | bash run.sh 48 | ``` 49 | 50 | ## Segmentation Experiments 51 | ### Cityscapes-to-ACDC segmentation task 52 | Since April 2022, we also offer the segmentation code based on Segformer.
53 | You can download it [here](https://github.com/qinenergy/cotta/issues/6) 54 | ``` 55 | ## environment setup: a new conda environment is needed for segformer 56 | ## You may also want to check https://github.com/qinenergy/cotta/issues/13 if you have problems installing mmcv 57 | conda env create -f environment_segformer.yml 58 | pip install -e . --user 59 | conda activate segformer 60 | ## Run 61 | bash run_base.sh 62 | bash run_tent.sh 63 | bash run_cotta.sh 64 | # Example logs are included in ./example_logs/base.log, tent.log, and cotta.log. 65 | ## License for Cityscapes-to-ACDC code 66 | Non-commercial. Code is heavily based on Segformer. Please also check Segformer's LICENSE. 67 | ``` 68 | 69 | ## Data links 70 | + [Supplementary PDF](https://1drv.ms/b/s!At2KHTLZCWRegpAOqP-8BCBQze68wg?e=wiyaAl) 71 | + [ACDC experiment code](https://1drv.ms/u/s!At2KHTLZCWRegpAcvrh9SA34gMpzNQ?e=TSiKs6) 72 | + [Other Supplementary](https://1drv.ms/f/s!At2KHTLZCWRegpAKdcRuOGE1S9ZGLg?e=iJcR9L) 73 | 74 | ## Citation 75 | Please cite our work if you find it useful. 76 | ```bibtex 77 | @inproceedings{wang2022continual, 78 | title={Continual Test-Time Domain Adaptation}, 79 | author={Wang, Qin and Fink, Olga and Van Gool, Luc and Dai, Dengxin}, 80 | booktitle={Proceedings of Conference on Computer Vision and Pattern Recognition}, 81 | year={2022} 82 | } 83 | ``` 84 | 85 | ## Acknowledgement 86 | + TENT code is heavily used. [official](https://github.com/DequanWang/tent) 87 | + KATANA code is used for augmentation. [official](https://github.com/giladcohen/KATANA) 88 | + Robustbench [official](https://github.com/RobustBench/robustbench) 89 | 90 | 91 | ## External data link 92 | + ImageNet-C [Download](https://zenodo.org/record/2235448#.Yj2RO_co_mF) 93 | 94 | For questions regarding the code, please contact wang@qin.ee .
95 | -------------------------------------------------------------------------------- /cifar/robustbench/leaderboard/template.py: -------------------------------------------------------------------------------- 1 | import json 2 | from argparse import ArgumentParser 3 | from pathlib import Path 4 | from typing import Union 5 | 6 | from jinja2 import Environment, PackageLoader, select_autoescape 7 | 8 | from robustbench.model_zoo.enums import BenchmarkDataset, ThreatModel 9 | from robustbench.utils import ACC_FIELDS 10 | 11 | 12 | def generate_leaderboard(dataset: Union[str, BenchmarkDataset], 13 | threat_model: Union[str, ThreatModel], 14 | models_folder: str = "model_info") -> str: 15 | """Prints the HTML leaderboard starting from the .json results. 16 | 17 | The result is a that can be put directly into the RobustBench index.html page, 18 | and looks the same as the tables that are already existing. 19 | 20 | The .json results must have the same structure as the following: 21 | `` 22 | { 23 | "link": "https://arxiv.org/abs/2003.09461", 24 | "name": "Adversarial Robustness on In- and Out-Distribution Improves Explainability", 25 | "authors": "Maximilian Augustin, Alexander Meinke, Matthias Hein", 26 | "additional_data": true, 27 | "number_forward_passes": 1, 28 | "dataset": "cifar10", 29 | "venue": "ECCV 2020", 30 | "architecture": "ResNet-50", 31 | "eps": "0.5", 32 | "clean_acc": "91.08", 33 | "reported": "73.27", 34 | "autoattack_acc": "72.91" 35 | } 36 | `` 37 | 38 | If the model is robust to common corruptions, then the "autoattack_acc" field should be 39 | "corruptions_acc". 40 | 41 | :param dataset: The dataset of the wanted leaderboard. 42 | :param threat_model: The threat model of the wanted leaderboard. 43 | :param models_folder: The base folder of the model jsons (e.g. our "model_info" folder). 44 | 45 | :return: The resulting HTML table. 
46 | """ 47 | dataset_: BenchmarkDataset = BenchmarkDataset(dataset) 48 | threat_model_: ThreatModel = ThreatModel(threat_model) 49 | 50 | folder = Path(models_folder) / dataset_.value / threat_model_.value 51 | 52 | acc_field = ACC_FIELDS[threat_model_] 53 | 54 | models = [] 55 | for model_path in folder.glob("*.json"): 56 | with open(model_path) as fp: 57 | model = json.load(fp) 58 | 59 | models.append(model) 60 | 61 | models.sort(key=lambda x: x[acc_field], reverse=True) 62 | 63 | env = Environment(loader=PackageLoader('robustbench', 'leaderboard'), 64 | autoescape=select_autoescape(['html', 'xml'])) 65 | 66 | template = env.get_template('leaderboard.html.j2') 67 | 68 | result = template.render(threat_model=threat_model, dataset=dataset, models=models, acc_field=acc_field) 69 | print(result) 70 | return result 71 | 72 | 73 | if __name__ == "__main__": 74 | parser = ArgumentParser() 75 | parser.add_argument( 76 | "--dataset", 77 | type=str, 78 | default="cifar10", 79 | help="The dataset of the desired leaderboard." 80 | ) 81 | parser.add_argument( 82 | "--threat_model", 83 | type=str, 84 | help="The threat model of the desired leaderboard." 85 | ) 86 | parser.add_argument( 87 | "--models_folder", 88 | type=str, 89 | default="model_info", 90 | help="The base folder of the model jsons (e.g. 
our 'model_info' folder)" 91 | ) 92 | args = parser.parse_args() 93 | 94 | generate_leaderboard(args.dataset, args.threat_model, args.models_folder) 95 | -------------------------------------------------------------------------------- /imagenet/robustbench/leaderboard/template.py: -------------------------------------------------------------------------------- 1 | import json 2 | from argparse import ArgumentParser 3 | from pathlib import Path 4 | from typing import Union 5 | 6 | from jinja2 import Environment, PackageLoader, select_autoescape 7 | 8 | from robustbench.model_zoo.enums import BenchmarkDataset, ThreatModel 9 | from robustbench.utils import ACC_FIELDS 10 | 11 | 12 | def generate_leaderboard(dataset: Union[str, BenchmarkDataset], 13 | threat_model: Union[str, ThreatModel], 14 | models_folder: str = "model_info") -> str: 15 | """Prints the HTML leaderboard starting from the .json results. 16 | 17 | The result is a
that can be put directly into the RobustBench index.html page, 18 | and looks the same as the tables that are already existing. 19 | 20 | The .json results must have the same structure as the following: 21 | `` 22 | { 23 | "link": "https://arxiv.org/abs/2003.09461", 24 | "name": "Adversarial Robustness on In- and Out-Distribution Improves Explainability", 25 | "authors": "Maximilian Augustin, Alexander Meinke, Matthias Hein", 26 | "additional_data": true, 27 | "number_forward_passes": 1, 28 | "dataset": "cifar10", 29 | "venue": "ECCV 2020", 30 | "architecture": "ResNet-50", 31 | "eps": "0.5", 32 | "clean_acc": "91.08", 33 | "reported": "73.27", 34 | "autoattack_acc": "72.91" 35 | } 36 | `` 37 | 38 | If the model is robust to common corruptions, then the "autoattack_acc" field should be 39 | "corruptions_acc". 40 | 41 | :param dataset: The dataset of the wanted leaderboard. 42 | :param threat_model: The threat model of the wanted leaderboard. 43 | :param models_folder: The base folder of the model jsons (e.g. our "model_info" folder). 44 | 45 | :return: The resulting HTML table. 
46 | """ 47 | dataset_: BenchmarkDataset = BenchmarkDataset(dataset) 48 | threat_model_: ThreatModel = ThreatModel(threat_model) 49 | 50 | folder = Path(models_folder) / dataset_.value / threat_model_.value 51 | 52 | acc_field = ACC_FIELDS[threat_model_] 53 | 54 | models = [] 55 | for model_path in folder.glob("*.json"): 56 | with open(model_path) as fp: 57 | model = json.load(fp) 58 | 59 | models.append(model) 60 | 61 | models.sort(key=lambda x: x[acc_field], reverse=True) 62 | 63 | env = Environment(loader=PackageLoader('robustbench', 'leaderboard'), 64 | autoescape=select_autoescape(['html', 'xml'])) 65 | 66 | template = env.get_template('leaderboard.html.j2') 67 | 68 | result = template.render(threat_model=threat_model, dataset=dataset, models=models, acc_field=acc_field) 69 | print(result) 70 | return result 71 | 72 | 73 | if __name__ == "__main__": 74 | parser = ArgumentParser() 75 | parser.add_argument( 76 | "--dataset", 77 | type=str, 78 | default="cifar10", 79 | help="The dataset of the desired leaderboard." 80 | ) 81 | parser.add_argument( 82 | "--threat_model", 83 | type=str, 84 | help="The threat model of the desired leaderboard." 85 | ) 86 | parser.add_argument( 87 | "--models_folder", 88 | type=str, 89 | default="model_info", 90 | help="The base folder of the model jsons (e.g. 
our 'model_info' folder)" 91 | ) 92 | args = parser.parse_args() 93 | 94 | generate_leaderboard(args.dataset, args.threat_model, args.models_folder) 95 | -------------------------------------------------------------------------------- /cifar/robustbench/model_zoo/imagenet.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | from torchvision import models as pt_models 4 | 5 | from robustbench.model_zoo.enums import ThreatModel 6 | from robustbench.model_zoo.architectures.utils_architectures import normalize_model 7 | 8 | 9 | mu = (0.485, 0.456, 0.406) 10 | sigma = (0.229, 0.224, 0.225) 11 | 12 | 13 | linf = OrderedDict( 14 | [ 15 | ('Wong2020Fast', { # requires resolution 288 x 288 16 | 'model': lambda: normalize_model(pt_models.resnet50(), mu, sigma), 17 | 'gdrive_id': '1deM2ZNS5tf3S_-eRURJi-IlvUL8WJQ_w', 18 | 'preprocessing': 'Crop288' 19 | }), 20 | ('Engstrom2019Robustness', { 21 | 'model': lambda: normalize_model(pt_models.resnet50(), mu, sigma), 22 | 'gdrive_id': '1T2Fvi1eCJTeAOEzrH_4TAIwO8HTOYVyn', 23 | 'preprocessing': 'Res256Crop224', 24 | }), 25 | ('Salman2020Do_R50', { 26 | 'model': lambda: normalize_model(pt_models.resnet50(), mu, sigma), 27 | 'gdrive_id': '1TmT5oGa1UvVjM3d-XeSj_XmKqBNRUg8r', 28 | 'preprocessing': 'Res256Crop224' 29 | }), 30 | ('Salman2020Do_R18', { 31 | 'model': lambda: normalize_model(pt_models.resnet18(), mu, sigma), 32 | 'gdrive_id': '1OThCOQCOxY6lAgxZxgiK3YuZDD7PPfPx', 33 | 'preprocessing': 'Res256Crop224' 34 | }), 35 | ('Salman2020Do_50_2', { 36 | 'model': lambda: normalize_model(pt_models.wide_resnet50_2(), mu, sigma), 37 | 'gdrive_id': '1OT7xaQYljrTr3vGbM37xK9SPoPJvbSKB', 38 | 'preprocessing': 'Res256Crop224' 39 | }), 40 | ('Standard_R50', { 41 | 'model': lambda: normalize_model(pt_models.resnet50(pretrained=True), mu, sigma), 42 | 'gdrive_id': '', 43 | 'preprocessing': 'Res256Crop224' 44 | }), 45 | ]) 46 | 47 | common_corruptions = OrderedDict( 48 | [ 
49 | ('Geirhos2018_SIN', { 50 | 'model': lambda: normalize_model(pt_models.resnet50(), mu, sigma), 51 | 'gdrive_id': '1hLgeY_rQIaOT4R-t_KyOqPNkczfaedgs', 52 | 'preprocessing': 'Res256Crop224' 53 | }), 54 | ('Geirhos2018_SIN_IN', { 55 | 'model': lambda: normalize_model(pt_models.resnet50(), mu, sigma), 56 | 'gdrive_id': '139pWopDnNERObZeLsXUysRcLg6N1iZHK', 57 | 'preprocessing': 'Res256Crop224' 58 | }), 59 | ('Geirhos2018_SIN_IN_IN', { 60 | 'model': lambda: normalize_model(pt_models.resnet50(), mu, sigma), 61 | 'gdrive_id': '1xOvyuxpOZ8I5CZOi0EGYG_R6tu3ZaJdO', 62 | 'preprocessing': 'Res256Crop224' 63 | }), 64 | ('Hendrycks2020Many', { 65 | 'model': lambda: normalize_model(pt_models.resnet50(), mu, sigma), 66 | 'gdrive_id': '1kylueoLtYtxkpVzoOA1B6tqdbRl2xt9X', 67 | 'preprocessing': 'Res256Crop224' 68 | }), 69 | ('Hendrycks2020AugMix', { 70 | 'model': lambda: normalize_model(pt_models.resnet50(), mu, sigma), 71 | 'gdrive_id': '1xRMj1GlO93tLoCMm0e5wEvZwqhIjxhoJ', 72 | 'preprocessing': 'Res256Crop224' 73 | }), 74 | ('Salman2020Do_50_2_Linf', { 75 | 'model': lambda: normalize_model(pt_models.wide_resnet50_2(), mu, sigma), 76 | 'gdrive_id': '1OT7xaQYljrTr3vGbM37xK9SPoPJvbSKB', 77 | 'preprocessing': 'Res256Crop224' 78 | }), 79 | ('Standard_R50', { 80 | 'model': lambda: normalize_model(pt_models.resnet50(pretrained=True), mu, sigma), 81 | 'gdrive_id': '', 82 | 'preprocessing': 'Res256Crop224' 83 | }), 84 | ]) 85 | 86 | imagenet_models = OrderedDict([(ThreatModel.Linf, linf), 87 | (ThreatModel.corruptions, common_corruptions)]) 88 | 89 | 90 | -------------------------------------------------------------------------------- /imagenet/robustbench/model_zoo/imagenet.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | from torchvision import models as pt_models 4 | 5 | from robustbench.model_zoo.enums import ThreatModel 6 | from robustbench.model_zoo.architectures.utils_architectures import 
normalize_model 7 | 8 | 9 | mu = (0.485, 0.456, 0.406) 10 | sigma = (0.229, 0.224, 0.225) 11 | 12 | 13 | linf = OrderedDict( 14 | [ 15 | ('Wong2020Fast', { # requires resolution 288 x 288 16 | 'model': lambda: normalize_model(pt_models.resnet50(), mu, sigma), 17 | 'gdrive_id': '1deM2ZNS5tf3S_-eRURJi-IlvUL8WJQ_w', 18 | 'preprocessing': 'Crop288' 19 | }), 20 | ('Engstrom2019Robustness', { 21 | 'model': lambda: normalize_model(pt_models.resnet50(), mu, sigma), 22 | 'gdrive_id': '1T2Fvi1eCJTeAOEzrH_4TAIwO8HTOYVyn', 23 | 'preprocessing': 'Res256Crop224', 24 | }), 25 | ('Salman2020Do_R50', { 26 | 'model': lambda: normalize_model(pt_models.resnet50(), mu, sigma), 27 | 'gdrive_id': '1TmT5oGa1UvVjM3d-XeSj_XmKqBNRUg8r', 28 | 'preprocessing': 'Res256Crop224' 29 | }), 30 | ('Salman2020Do_R18', { 31 | 'model': lambda: normalize_model(pt_models.resnet18(), mu, sigma), 32 | 'gdrive_id': '1OThCOQCOxY6lAgxZxgiK3YuZDD7PPfPx', 33 | 'preprocessing': 'Res256Crop224' 34 | }), 35 | ('Salman2020Do_50_2', { 36 | 'model': lambda: normalize_model(pt_models.wide_resnet50_2(), mu, sigma), 37 | 'gdrive_id': '1OT7xaQYljrTr3vGbM37xK9SPoPJvbSKB', 38 | 'preprocessing': 'Res256Crop224' 39 | }), 40 | ('Standard_R50', { 41 | 'model': lambda: normalize_model(pt_models.resnet50(pretrained=True), mu, sigma), 42 | 'gdrive_id': '', 43 | 'preprocessing': 'Res256Crop224' 44 | }), 45 | ]) 46 | 47 | common_corruptions = OrderedDict( 48 | [ 49 | ('Geirhos2018_SIN', { 50 | 'model': lambda: normalize_model(pt_models.resnet50(), mu, sigma), 51 | 'gdrive_id': '1hLgeY_rQIaOT4R-t_KyOqPNkczfaedgs', 52 | 'preprocessing': 'Res256Crop224' 53 | }), 54 | ('Geirhos2018_SIN_IN', { 55 | 'model': lambda: normalize_model(pt_models.resnet50(), mu, sigma), 56 | 'gdrive_id': '139pWopDnNERObZeLsXUysRcLg6N1iZHK', 57 | 'preprocessing': 'Res256Crop224' 58 | }), 59 | ('Geirhos2018_SIN_IN_IN', { 60 | 'model': lambda: normalize_model(pt_models.resnet50(), mu, sigma), 61 | 'gdrive_id': '1xOvyuxpOZ8I5CZOi0EGYG_R6tu3ZaJdO', 62 | 
'preprocessing': 'Res256Crop224' 63 | }), 64 | ('Hendrycks2020Many', { 65 | 'model': lambda: normalize_model(pt_models.resnet50(), mu, sigma), 66 | 'gdrive_id': '1kylueoLtYtxkpVzoOA1B6tqdbRl2xt9X', 67 | 'preprocessing': 'Res256Crop224' 68 | }), 69 | ('Hendrycks2020AugMix', { 70 | 'model': lambda: normalize_model(pt_models.resnet50(), mu, sigma), 71 | 'gdrive_id': '1xRMj1GlO93tLoCMm0e5wEvZwqhIjxhoJ', 72 | 'preprocessing': 'Res256Crop224' 73 | }), 74 | ('Salman2020Do_50_2_Linf', { 75 | 'model': lambda: normalize_model(pt_models.wide_resnet50_2(), mu, sigma), 76 | 'gdrive_id': '1OT7xaQYljrTr3vGbM37xK9SPoPJvbSKB', 77 | 'preprocessing': 'Res256Crop224' 78 | }), 79 | ('Standard_R50', { 80 | 'model': lambda: normalize_model(pt_models.resnet50(pretrained=True), mu, sigma), 81 | 'gdrive_id': '', 82 | 'preprocessing': 'Res256Crop224' 83 | }), 84 | ]) 85 | 86 | imagenet_models = OrderedDict([(ThreatModel.Linf, linf), 87 | (ThreatModel.corruptions, common_corruptions)]) 88 | 89 | 90 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: cotta 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=4.5=1_gnu 8 | - blas=1.0=mkl 9 | - bzip2=1.0.8=h7b6447c_0 10 | - ca-certificates=2021.10.26=h06a4308_2 11 | - certifi=2021.10.8=py39h06a4308_0 12 | - colorama=0.4.4=pyhd3eb1b0_0 13 | - cudatoolkit=11.3.1=h2bc3f7f_2 14 | - ffmpeg=4.3=hf484d3e_0 15 | - freetype=2.10.4=h5ab3b9f_0 16 | - giflib=5.2.1=h7b6447c_0 17 | - gmp=6.2.1=h2531618_2 18 | - gnutls=3.6.15=he1e5248_0 19 | - intel-openmp=2021.3.0=h06a4308_3350 20 | - jpeg=9d=h7f8727e_0 21 | - lame=3.100=h7b6447c_0 22 | - lcms2=2.12=h3be6417_0 23 | - ld_impl_linux-64=2.35.1=h7274673_9 24 | - libffi=3.3=he6710b0_2 25 | - libgcc-ng=9.3.0=h5101ec6_17 26 | - libgomp=9.3.0=h5101ec6_17 27 | - libiconv=1.15=h63c8f33_5 28 | - libidn2=2.3.2=h7f8727e_0 29 | - 
libpng=1.6.37=hbc83047_0 30 | - libstdcxx-ng=9.3.0=hd4cf53a_17 31 | - libtasn1=4.16.0=h27cfd23_0 32 | - libtiff=4.2.0=h85742a9_0 33 | - libunistring=0.9.10=h27cfd23_0 34 | - libuv=1.40.0=h7b6447c_0 35 | - libwebp=1.2.0=h89dd481_0 36 | - libwebp-base=1.2.0=h27cfd23_0 37 | - lz4-c=1.9.3=h295c915_1 38 | - mkl=2021.3.0=h06a4308_520 39 | - mkl-service=2.4.0=py39h7f8727e_0 40 | - mkl_fft=1.3.1=py39hd3c417c_0 41 | - mkl_random=1.2.2=py39h51133e4_0 42 | - ncurses=6.2=he6710b0_1 43 | - nettle=3.7.3=hbbd107a_1 44 | - olefile=0.46=pyhd3eb1b0_0 45 | - openh264=2.1.0=hd408876_0 46 | - openssl=1.1.1l=h7f8727e_0 47 | - pillow=8.4.0=py39h5aabda8_0 48 | - pip=21.2.4=py39h06a4308_0 49 | - python=3.9.7=h12debd9_1 50 | - pytorch=1.10.0=py3.9_cuda11.3_cudnn8.2.0_0 51 | - pytorch-mutex=1.0=cuda 52 | - readline=8.1=h27cfd23_0 53 | - setuptools=58.0.4=py39h06a4308_0 54 | - six=1.16.0=pyhd3eb1b0_0 55 | - sqlite=3.36.0=hc218d9a_0 56 | - tk=8.6.11=h1ccaba5_0 57 | - torchaudio=0.10.0=py39_cu113 58 | - torchvision=0.11.1=py39_cu113 59 | - typing_extensions=3.10.0.2=pyh06a4308_0 60 | - tzdata=2021a=h5d7bf9c_0 61 | - wheel=0.37.0=pyhd3eb1b0_1 62 | - xz=5.2.5=h7b6447c_0 63 | - zlib=1.2.11=h7b6447c_3 64 | - zstd=1.4.9=haebb681_0 65 | - pip: 66 | - addict==2.4.0 67 | - appdirs==1.4.4 68 | - attr==0.3.1 69 | - attrs==21.2.0 70 | - backcall==0.2.0 71 | - black==19.3b0 72 | - chardet==3.0.4 73 | - click==8.0.3 74 | - cloudpickle==2.0.0 75 | - decorator==5.1.0 76 | - flake8==4.0.1 77 | - idna==2.10 78 | - imagecorruptions==1.1.2 79 | - imageio==2.10.1 80 | - iopath==0.1.9 81 | - ipython==7.28.0 82 | - isort==4.3.21 83 | - jedi==0.18.0 84 | - jinja2==2.11.3 85 | - joblib==1.1.0 86 | - markupsafe==2.0.1 87 | - matplotlib-inline==0.1.3 88 | - mccabe==0.6.1 89 | - networkx==2.6.3 90 | - numpy==1.19.5 91 | - packaging==21.0 92 | - parameterized==0.8.1 93 | - parso==0.8.2 94 | - pexpect==4.8.0 95 | - pickleshare==0.7.5 96 | - portalocker==2.3.2 97 | - prompt-toolkit==3.0.21 98 | - ptyprocess==0.7.0 99 | - 
pycodestyle==2.8.0 100 | - pyflakes==2.4.0 101 | - pygments==2.10.0 102 | - pytorchcv==0.0.67 103 | - pytz==2021.3 104 | - pywavelets==1.1.1 105 | - pyyaml==6.0 106 | - requests==2.23.0 107 | - git+https://github.com/robustbench/robustbench@v0.1#egg=robustbench 108 | - scikit-image==0.18.3 109 | - scikit-learn==1.0.1 110 | - scipy==1.7.1 111 | - simplejson==3.17.5 112 | - submitit==1.4.0 113 | - threadpoolctl==3.0.0 114 | - tifffile==2021.10.12 115 | - timm==0.4.12 116 | - toml==0.10.2 117 | - tqdm==4.56.2 118 | - traitlets==5.1.0 119 | - urllib3==1.25.11 120 | - wcwidth==0.2.5 121 | - yacs==0.1.8 122 | - yapf==0.31.0 123 | - yattag==1.14.0 124 | - gdown==5.1.0 125 | -------------------------------------------------------------------------------- /cifar/robustbench/model_zoo/architectures/wide_resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class BasicBlock(nn.Module): 8 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 9 | super(BasicBlock, self).__init__() 10 | self.bn1 = nn.BatchNorm2d(in_planes) 11 | self.relu1 = nn.ReLU(inplace=True) 12 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(out_planes) 15 | self.relu2 = nn.ReLU(inplace=True) 16 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 17 | padding=1, bias=False) 18 | self.droprate = dropRate 19 | self.equalInOut = (in_planes == out_planes) 20 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 21 | padding=0, bias=False) or None 22 | 23 | def forward(self, x): 24 | if not self.equalInOut: 25 | x = self.relu1(self.bn1(x)) 26 | else: 27 | out = self.relu1(self.bn1(x)) 28 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 29 | if self.droprate > 0: 30 | out 
= F.dropout(out, p=self.droprate, training=self.training) 31 | out = self.conv2(out) 32 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 33 | 34 | 35 | class NetworkBlock(nn.Module): 36 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 37 | super(NetworkBlock, self).__init__() 38 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 39 | 40 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 41 | layers = [] 42 | for i in range(int(nb_layers)): 43 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 44 | return nn.Sequential(*layers) 45 | 46 | def forward(self, x): 47 | return self.layer(x) 48 | 49 | 50 | class WideResNet(nn.Module): 51 | """ Based on code from https://github.com/yaodongyu/TRADES """ 52 | def __init__(self, depth=28, num_classes=10, widen_factor=10, sub_block1=False, dropRate=0.0, bias_last=True): 53 | super(WideResNet, self).__init__() 54 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 55 | assert ((depth - 4) % 6 == 0) 56 | n = (depth - 4) / 6 57 | block = BasicBlock 58 | # 1st conv before any network block 59 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 60 | padding=1, bias=False) 61 | # 1st block 62 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 63 | if sub_block1: 64 | # 1st sub-block 65 | self.sub_block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 66 | # 2nd block 67 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 68 | # 3rd block 69 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 70 | # global average pooling and classifier 71 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 72 | self.relu = nn.ReLU(inplace=True) 73 | self.fc = nn.Linear(nChannels[3], num_classes, bias=bias_last) 74 | self.nChannels = 
nChannels[3] 75 | 76 | for m in self.modules(): 77 | if isinstance(m, nn.Conv2d): 78 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 79 | m.weight.data.normal_(0, math.sqrt(2. / n)) 80 | elif isinstance(m, nn.BatchNorm2d): 81 | m.weight.data.fill_(1) 82 | m.bias.data.zero_() 83 | elif isinstance(m, nn.Linear) and not m.bias is None: 84 | m.bias.data.zero_() 85 | 86 | def forward(self, x): 87 | out = self.conv1(x) 88 | out = self.block1(out) 89 | out = self.block2(out) 90 | out = self.block3(out) 91 | out = self.relu(self.bn1(out)) 92 | out = F.avg_pool2d(out, 8) 93 | out = out.view(-1, self.nChannels) 94 | return self.fc(out) 95 | 96 | -------------------------------------------------------------------------------- /imagenet/robustbench/model_zoo/architectures/wide_resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class BasicBlock(nn.Module): 8 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 9 | super(BasicBlock, self).__init__() 10 | self.bn1 = nn.BatchNorm2d(in_planes) 11 | self.relu1 = nn.ReLU(inplace=True) 12 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(out_planes) 15 | self.relu2 = nn.ReLU(inplace=True) 16 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 17 | padding=1, bias=False) 18 | self.droprate = dropRate 19 | self.equalInOut = (in_planes == out_planes) 20 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 21 | padding=0, bias=False) or None 22 | 23 | def forward(self, x): 24 | if not self.equalInOut: 25 | x = self.relu1(self.bn1(x)) 26 | else: 27 | out = self.relu1(self.bn1(x)) 28 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 29 | if self.droprate > 0: 30 | out = F.dropout(out, 
p=self.droprate, training=self.training) 31 | out = self.conv2(out) 32 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 33 | 34 | 35 | class NetworkBlock(nn.Module): 36 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 37 | super(NetworkBlock, self).__init__() 38 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 39 | 40 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 41 | layers = [] 42 | for i in range(int(nb_layers)): 43 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 44 | return nn.Sequential(*layers) 45 | 46 | def forward(self, x): 47 | return self.layer(x) 48 | 49 | 50 | class WideResNet(nn.Module): 51 | """ Based on code from https://github.com/yaodongyu/TRADES """ 52 | def __init__(self, depth=28, num_classes=10, widen_factor=10, sub_block1=False, dropRate=0.0, bias_last=True): 53 | super(WideResNet, self).__init__() 54 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 55 | assert ((depth - 4) % 6 == 0) 56 | n = (depth - 4) / 6 57 | block = BasicBlock 58 | # 1st conv before any network block 59 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 60 | padding=1, bias=False) 61 | # 1st block 62 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 63 | if sub_block1: 64 | # 1st sub-block 65 | self.sub_block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 66 | # 2nd block 67 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 68 | # 3rd block 69 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 70 | # global average pooling and classifier 71 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 72 | self.relu = nn.ReLU(inplace=True) 73 | self.fc = nn.Linear(nChannels[3], num_classes, bias=bias_last) 74 | self.nChannels = nChannels[3] 75 | 76 | 
for m in self.modules(): 77 | if isinstance(m, nn.Conv2d): 78 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 79 | m.weight.data.normal_(0, math.sqrt(2. / n)) 80 | elif isinstance(m, nn.BatchNorm2d): 81 | m.weight.data.fill_(1) 82 | m.bias.data.zero_() 83 | elif isinstance(m, nn.Linear) and not m.bias is None: 84 | m.bias.data.zero_() 85 | 86 | def forward(self, x): 87 | out = self.conv1(x) 88 | out = self.block1(out) 89 | out = self.block2(out) 90 | out = self.block3(out) 91 | out = self.relu(self.bn1(out)) 92 | out = F.avg_pool2d(out, 8) 93 | out = out.view(-1, self.nChannels) 94 | return self.fc(out) 95 | 96 | -------------------------------------------------------------------------------- /cifar/tent.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.jit 6 | 7 | 8 | class Tent(nn.Module): 9 | """Tent adapts a model by entropy minimization during testing. 10 | 11 | Once tented, a model adapts itself by updating on every forward. 
12 | """ 13 | def __init__(self, model, optimizer, steps=1, episodic=False): 14 | super().__init__() 15 | self.model = model 16 | self.optimizer = optimizer 17 | self.steps = steps 18 | assert steps > 0, "tent requires >= 1 step(s) to forward and update" 19 | self.episodic = episodic 20 | 21 | # note: if the model is never reset, like for continual adaptation, 22 | # then skipping the state copy would save memory 23 | self.model_state, self.optimizer_state = \ 24 | copy_model_and_optimizer(self.model, self.optimizer) 25 | 26 | def forward(self, x): 27 | if self.episodic: 28 | self.reset() 29 | 30 | for _ in range(self.steps): 31 | outputs = forward_and_adapt(x, self.model, self.optimizer) 32 | 33 | return outputs 34 | 35 | def reset(self): 36 | if self.model_state is None or self.optimizer_state is None: 37 | raise Exception("cannot reset without saved model/optimizer state") 38 | load_model_and_optimizer(self.model, self.optimizer, 39 | self.model_state, self.optimizer_state) 40 | 41 | 42 | @torch.jit.script 43 | def softmax_entropy(x: torch.Tensor) -> torch.Tensor: 44 | """Entropy of softmax distribution from logits.""" 45 | return -(x.softmax(1) * x.log_softmax(1)).sum(1) 46 | 47 | 48 | @torch.enable_grad() # ensure grads in possible no grad context for testing 49 | def forward_and_adapt(x, model, optimizer): 50 | """Forward and adapt model on batch of data. 51 | 52 | Measure entropy of the model prediction, take gradients, and update params. 53 | """ 54 | # forward 55 | outputs = model(x) 56 | # adapt 57 | loss = softmax_entropy(outputs).mean(0) 58 | loss.backward() 59 | optimizer.step() 60 | optimizer.zero_grad() 61 | return outputs 62 | 63 | 64 | def collect_params(model): 65 | """Collect the affine scale + shift parameters from batch norms. 66 | 67 | Walk the model's modules and collect all batch normalization parameters. 68 | Return the parameters and their names. 69 | 70 | Note: other choices of parameterization are possible! 
71 | """ 72 | params = [] 73 | names = [] 74 | for nm, m in model.named_modules(): 75 | if isinstance(m, nn.BatchNorm2d): 76 | for np, p in m.named_parameters(): 77 | if np in ['weight', 'bias']: # weight is scale, bias is shift 78 | params.append(p) 79 | names.append(f"{nm}.{np}") 80 | return params, names 81 | 82 | 83 | def copy_model_and_optimizer(model, optimizer): 84 | """Copy the model and optimizer states for resetting after adaptation.""" 85 | model_state = deepcopy(model.state_dict()) 86 | optimizer_state = deepcopy(optimizer.state_dict()) 87 | return model_state, optimizer_state 88 | 89 | 90 | def load_model_and_optimizer(model, optimizer, model_state, optimizer_state): 91 | """Restore the model and optimizer states from copies.""" 92 | model.load_state_dict(model_state, strict=True) 93 | optimizer.load_state_dict(optimizer_state) 94 | 95 | 96 | def configure_model(model): 97 | """Configure model for use with tent.""" 98 | # train mode, because tent optimizes the model to minimize entropy 99 | model.train() 100 | # disable grad, to (re-)enable only what tent updates 101 | model.requires_grad_(False) 102 | # configure norm for tent updates: enable grad + force batch statisics 103 | for m in model.modules(): 104 | if isinstance(m, nn.BatchNorm2d): 105 | m.requires_grad_(True) 106 | # force use of batch stats in train and eval modes 107 | m.track_running_stats = False 108 | m.running_mean = None 109 | m.running_var = None 110 | return model 111 | 112 | 113 | def check_model(model): 114 | """Check model for compatability with tent.""" 115 | is_training = model.training 116 | assert is_training, "tent needs train mode: call model.train()" 117 | param_grads = [p.requires_grad for p in model.parameters()] 118 | has_any_params = any(param_grads) 119 | has_all_params = all(param_grads) 120 | assert has_any_params, "tent needs params to update: " \ 121 | "check which require grad" 122 | assert not has_all_params, "tent should not update all params: " \ 123 | "check 
which require grad" 124 | has_bn = any([isinstance(m, nn.BatchNorm2d) for m in model.modules()]) 125 | assert has_bn, "tent needs normalization for its optimization" 126 | -------------------------------------------------------------------------------- /imagenet/tent.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.jit 6 | 7 | 8 | class Tent(nn.Module): 9 | """Tent adapts a model by entropy minimization during testing. 10 | 11 | Once tented, a model adapts itself by updating on every forward. 12 | """ 13 | def __init__(self, model, optimizer, steps=1, episodic=False): 14 | super().__init__() 15 | self.model = model 16 | self.optimizer = optimizer 17 | self.steps = steps 18 | assert steps > 0, "tent requires >= 1 step(s) to forward and update" 19 | self.episodic = episodic 20 | 21 | # note: if the model is never reset, like for continual adaptation, 22 | # then skipping the state copy would save memory 23 | self.model_state, self.optimizer_state = \ 24 | copy_model_and_optimizer(self.model, self.optimizer) 25 | 26 | def forward(self, x): 27 | if self.episodic: 28 | self.reset() 29 | 30 | for _ in range(self.steps): 31 | outputs = forward_and_adapt(x, self.model, self.optimizer) 32 | 33 | return outputs 34 | 35 | def reset(self): 36 | if self.model_state is None or self.optimizer_state is None: 37 | raise Exception("cannot reset without saved model/optimizer state") 38 | load_model_and_optimizer(self.model, self.optimizer, 39 | self.model_state, self.optimizer_state) 40 | 41 | 42 | @torch.jit.script 43 | def softmax_entropy(x: torch.Tensor) -> torch.Tensor: 44 | """Entropy of softmax distribution from logits.""" 45 | return -(x.softmax(1) * x.log_softmax(1)).sum(1) 46 | 47 | 48 | @torch.enable_grad() # ensure grads in possible no grad context for testing 49 | def forward_and_adapt(x, model, optimizer): 50 | """Forward and adapt model on 
batch of data. 51 | 52 | Measure entropy of the model prediction, take gradients, and update params. 53 | """ 54 | # forward 55 | outputs = model(x) 56 | # adapt 57 | loss = softmax_entropy(outputs).mean(0) 58 | loss.backward() 59 | optimizer.step() 60 | optimizer.zero_grad() 61 | return outputs 62 | 63 | 64 | def collect_params(model): 65 | """Collect the affine scale + shift parameters from batch norms. 66 | 67 | Walk the model's modules and collect all batch normalization parameters. 68 | Return the parameters and their names. 69 | 70 | Note: other choices of parameterization are possible! 71 | """ 72 | params = [] 73 | names = [] 74 | for nm, m in model.named_modules(): 75 | if isinstance(m, nn.BatchNorm2d): 76 | for np, p in m.named_parameters(): 77 | if np in ['weight', 'bias']: # weight is scale, bias is shift 78 | params.append(p) 79 | names.append(f"{nm}.{np}") 80 | return params, names 81 | 82 | 83 | def copy_model_and_optimizer(model, optimizer): 84 | """Copy the model and optimizer states for resetting after adaptation.""" 85 | model_state = deepcopy(model.state_dict()) 86 | optimizer_state = deepcopy(optimizer.state_dict()) 87 | return model_state, optimizer_state 88 | 89 | 90 | def load_model_and_optimizer(model, optimizer, model_state, optimizer_state): 91 | """Restore the model and optimizer states from copies.""" 92 | model.load_state_dict(model_state, strict=True) 93 | optimizer.load_state_dict(optimizer_state) 94 | 95 | 96 | def configure_model(model): 97 | """Configure model for use with tent.""" 98 | # train mode, because tent optimizes the model to minimize entropy 99 | model.train() 100 | # disable grad, to (re-)enable only what tent updates 101 | model.requires_grad_(False) 102 | # configure norm for tent updates: enable grad + force batch statisics 103 | for m in model.modules(): 104 | if isinstance(m, nn.BatchNorm2d): 105 | m.requires_grad_(True) 106 | # force use of batch stats in train and eval modes 107 | m.track_running_stats = False 
108 | m.running_mean = None 109 | m.running_var = None 110 | return model 111 | 112 | 113 | def check_model(model): 114 | """Check model for compatability with tent.""" 115 | is_training = model.training 116 | assert is_training, "tent needs train mode: call model.train()" 117 | param_grads = [p.requires_grad for p in model.parameters()] 118 | has_any_params = any(param_grads) 119 | has_all_params = all(param_grads) 120 | assert has_any_params, "tent needs params to update: " \ 121 | "check which require grad" 122 | assert not has_all_params, "tent should not update all params: " \ 123 | "check which require grad" 124 | has_bn = any([isinstance(m, nn.BatchNorm2d) for m in model.modules()]) 125 | assert has_bn, "tent needs normalization for its optimization" 126 | -------------------------------------------------------------------------------- /imagenet/my_transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms.functional as F 3 | from torchvision.transforms import ColorJitter, Compose, Lambda 4 | from numpy import random 5 | 6 | class GaussianNoise(torch.nn.Module): 7 | def __init__(self, mean=0., std=1.): 8 | super().__init__() 9 | self.std = std 10 | self.mean = mean 11 | 12 | def forward(self, img): 13 | noise = torch.randn(img.size()) * self.std + self.mean 14 | noise = noise.to(img.device) 15 | return img + noise 16 | 17 | def __repr__(self): 18 | return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) 19 | 20 | class Clip(torch.nn.Module): 21 | def __init__(self, min_val=0., max_val=1.): 22 | super().__init__() 23 | self.min_val = min_val 24 | self.max_val = max_val 25 | 26 | def forward(self, img): 27 | return torch.clip(img, self.min_val, self.max_val) 28 | 29 | def __repr__(self): 30 | return self.__class__.__name__ + '(min_val={0}, max_val={1})'.format(self.min_val, self.max_val) 31 | 32 | class ColorJitterPro(ColorJitter): 33 | 
"""Randomly change the brightness, contrast, saturation, and gamma correction of an image.""" 34 | 35 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, gamma=0): 36 | super().__init__(brightness, contrast, saturation, hue) 37 | self.gamma = self._check_input(gamma, 'gamma') 38 | 39 | @staticmethod 40 | @torch.jit.unused 41 | def get_params(brightness, contrast, saturation, hue, gamma): 42 | """Get a randomized transform to be applied on image. 43 | 44 | Arguments are same as that of __init__. 45 | 46 | Returns: 47 | Transform which randomly adjusts brightness, contrast and 48 | saturation in a random order. 49 | """ 50 | transforms = [] 51 | 52 | if brightness is not None: 53 | brightness_factor = random.uniform(brightness[0], brightness[1]) 54 | transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor))) 55 | 56 | if contrast is not None: 57 | contrast_factor = random.uniform(contrast[0], contrast[1]) 58 | transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor))) 59 | 60 | if saturation is not None: 61 | saturation_factor = random.uniform(saturation[0], saturation[1]) 62 | transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor))) 63 | 64 | if hue is not None: 65 | hue_factor = random.uniform(hue[0], hue[1]) 66 | transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor))) 67 | 68 | if gamma is not None: 69 | gamma_factor = random.uniform(gamma[0], gamma[1]) 70 | transforms.append(Lambda(lambda img: F.adjust_gamma(img, gamma_factor))) 71 | 72 | random.shuffle(transforms) 73 | transform = Compose(transforms) 74 | 75 | return transform 76 | 77 | def forward(self, img): 78 | """ 79 | Args: 80 | img (PIL Image or Tensor): Input image. 81 | 82 | Returns: 83 | PIL Image or Tensor: Color jittered image. 
84 | """ 85 | fn_idx = torch.randperm(5) 86 | for fn_id in fn_idx: 87 | if fn_id == 0 and self.brightness is not None: 88 | brightness = self.brightness 89 | brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item() 90 | img = F.adjust_brightness(img, brightness_factor) 91 | 92 | if fn_id == 1 and self.contrast is not None: 93 | contrast = self.contrast 94 | contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item() 95 | img = F.adjust_contrast(img, contrast_factor) 96 | 97 | if fn_id == 2 and self.saturation is not None: 98 | saturation = self.saturation 99 | saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item() 100 | img = F.adjust_saturation(img, saturation_factor) 101 | 102 | if fn_id == 3 and self.hue is not None: 103 | hue = self.hue 104 | hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item() 105 | img = F.adjust_hue(img, hue_factor) 106 | 107 | if fn_id == 4 and self.gamma is not None: 108 | gamma = self.gamma 109 | gamma_factor = torch.tensor(1.0).uniform_(gamma[0], gamma[1]).item() 110 | img = img.clamp(1e-8, 1.0) # to fix Nan values in gradients, which happens when applying gamma 111 | # after contrast 112 | img = F.adjust_gamma(img, gamma_factor) 113 | 114 | return img 115 | 116 | def __repr__(self): 117 | format_string = self.__class__.__name__ + '(' 118 | format_string += 'brightness={0}'.format(self.brightness) 119 | format_string += ', contrast={0}'.format(self.contrast) 120 | format_string += ', saturation={0}'.format(self.saturation) 121 | format_string += ', hue={0})'.format(self.hue) 122 | format_string += ', gamma={0})'.format(self.gamma) 123 | return format_string 124 | -------------------------------------------------------------------------------- /cifar/my_transforms.py: -------------------------------------------------------------------------------- 1 | # KATANA: Simple Post-Training Robustness Using Test Time Augmentations 2 | # 
https://arxiv.org/pdf/2109.08191v1.pdf 3 | import torch 4 | import torchvision.transforms.functional as F 5 | from torchvision.transforms import ColorJitter, Compose, Lambda 6 | from numpy import random 7 | 8 | class GaussianNoise(torch.nn.Module): 9 | def __init__(self, mean=0., std=1.): 10 | super().__init__() 11 | self.std = std 12 | self.mean = mean 13 | 14 | def forward(self, img): 15 | noise = torch.randn(img.size()) * self.std + self.mean 16 | noise = noise.to(img.device) 17 | return img + noise 18 | 19 | def __repr__(self): 20 | return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) 21 | 22 | class Clip(torch.nn.Module): 23 | def __init__(self, min_val=0., max_val=1.): 24 | super().__init__() 25 | self.min_val = min_val 26 | self.max_val = max_val 27 | 28 | def forward(self, img): 29 | return torch.clip(img, self.min_val, self.max_val) 30 | 31 | def __repr__(self): 32 | return self.__class__.__name__ + '(min_val={0}, max_val={1})'.format(self.min_val, self.max_val) 33 | 34 | class ColorJitterPro(ColorJitter): 35 | """Randomly change the brightness, contrast, saturation, and gamma correction of an image.""" 36 | 37 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, gamma=0): 38 | super().__init__(brightness, contrast, saturation, hue) 39 | self.gamma = self._check_input(gamma, 'gamma') 40 | 41 | @staticmethod 42 | @torch.jit.unused 43 | def get_params(brightness, contrast, saturation, hue, gamma): 44 | """Get a randomized transform to be applied on image. 45 | 46 | Arguments are same as that of __init__. 47 | 48 | Returns: 49 | Transform which randomly adjusts brightness, contrast and 50 | saturation in a random order. 
51 | """ 52 | transforms = [] 53 | 54 | if brightness is not None: 55 | brightness_factor = random.uniform(brightness[0], brightness[1]) 56 | transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor))) 57 | 58 | if contrast is not None: 59 | contrast_factor = random.uniform(contrast[0], contrast[1]) 60 | transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor))) 61 | 62 | if saturation is not None: 63 | saturation_factor = random.uniform(saturation[0], saturation[1]) 64 | transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor))) 65 | 66 | if hue is not None: 67 | hue_factor = random.uniform(hue[0], hue[1]) 68 | transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor))) 69 | 70 | if gamma is not None: 71 | gamma_factor = random.uniform(gamma[0], gamma[1]) 72 | transforms.append(Lambda(lambda img: F.adjust_gamma(img, gamma_factor))) 73 | 74 | random.shuffle(transforms) 75 | transform = Compose(transforms) 76 | 77 | return transform 78 | 79 | def forward(self, img): 80 | """ 81 | Args: 82 | img (PIL Image or Tensor): Input image. 83 | 84 | Returns: 85 | PIL Image or Tensor: Color jittered image. 
86 | """ 87 | fn_idx = torch.randperm(5) 88 | for fn_id in fn_idx: 89 | if fn_id == 0 and self.brightness is not None: 90 | brightness = self.brightness 91 | brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item() 92 | img = F.adjust_brightness(img, brightness_factor) 93 | 94 | if fn_id == 1 and self.contrast is not None: 95 | contrast = self.contrast 96 | contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item() 97 | img = F.adjust_contrast(img, contrast_factor) 98 | 99 | if fn_id == 2 and self.saturation is not None: 100 | saturation = self.saturation 101 | saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item() 102 | img = F.adjust_saturation(img, saturation_factor) 103 | 104 | if fn_id == 3 and self.hue is not None: 105 | hue = self.hue 106 | hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item() 107 | img = F.adjust_hue(img, hue_factor) 108 | 109 | if fn_id == 4 and self.gamma is not None: 110 | gamma = self.gamma 111 | gamma_factor = torch.tensor(1.0).uniform_(gamma[0], gamma[1]).item() 112 | img = img.clamp(1e-8, 1.0) # to fix Nan values in gradients, which happens when applying gamma 113 | # after contrast 114 | img = F.adjust_gamma(img, gamma_factor) 115 | 116 | return img 117 | 118 | def __repr__(self): 119 | format_string = self.__class__.__name__ + '(' 120 | format_string += 'brightness={0}'.format(self.brightness) 121 | format_string += ', contrast={0}'.format(self.contrast) 122 | format_string += ', saturation={0}'.format(self.saturation) 123 | format_string += ', hue={0})'.format(self.hue) 124 | format_string += ', gamma={0})'.format(self.gamma) 125 | return format_string 126 | -------------------------------------------------------------------------------- /imagenet/imagenetc.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.optim as optim 5 | 6 | from robustbench.data import 
load_imagenetc 7 | from robustbench.model_zoo.enums import ThreatModel 8 | from robustbench.utils import load_model 9 | from robustbench.utils import clean_accuracy as accuracy 10 | 11 | import tent 12 | import norm 13 | import cotta 14 | 15 | from conf import cfg, load_cfg_fom_args 16 | 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | def evaluate(description): 22 | load_cfg_fom_args(description) 23 | # configure model 24 | base_model = load_model(cfg.MODEL.ARCH, cfg.CKPT_DIR, 25 | cfg.CORRUPTION.DATASET, ThreatModel.corruptions).cuda() 26 | if cfg.MODEL.ADAPTATION == "source": 27 | logger.info("test-time adaptation: NONE") 28 | model = setup_source(base_model) 29 | if cfg.MODEL.ADAPTATION == "norm": 30 | logger.info("test-time adaptation: NORM") 31 | model = setup_norm(base_model) 32 | if cfg.MODEL.ADAPTATION == "tent": 33 | logger.info("test-time adaptation: TENT") 34 | model = setup_tent(base_model) 35 | if cfg.MODEL.ADAPTATION == "cotta": 36 | logger.info("test-time adaptation: CoTTA") 37 | model = setup_cotta(base_model) 38 | # evaluate on each severity and type of corruption in turn 39 | prev_ct = "x0" 40 | for ii, severity in enumerate(cfg.CORRUPTION.SEVERITY): 41 | for i_x, corruption_type in enumerate(cfg.CORRUPTION.TYPE): 42 | # reset adaptation for each combination of corruption x severity 43 | # note: for evaluation protocol, but not necessarily needed 44 | try: 45 | if i_x == 0: 46 | model.reset() 47 | logger.info("resetting model") 48 | else: 49 | logger.warning("not resetting model") 50 | except: 51 | logger.warning("not resetting model") 52 | x_test, y_test = load_imagenetc(cfg.CORRUPTION.NUM_EX, 53 | severity, cfg.DATA_DIR, False, 54 | [corruption_type]) 55 | x_test, y_test = x_test.cuda(), y_test.cuda() 56 | acc = accuracy(model, x_test, y_test, cfg.TEST.BATCH_SIZE) 57 | err = 1. 
- acc 58 | logger.info(f"error % [{corruption_type}{severity}]: {err:.2%}") 59 | 60 | 61 | def setup_source(model): 62 | """Set up the baseline source model without adaptation.""" 63 | model.eval() 64 | logger.info(f"model for evaluation: %s", model) 65 | return model 66 | 67 | 68 | def setup_norm(model): 69 | """Set up test-time normalization adaptation. 70 | 71 | Adapt by normalizing features with test batch statistics. 72 | The statistics are measured independently for each batch; 73 | no running average or other cross-batch estimation is used. 74 | """ 75 | norm_model = norm.Norm(model) 76 | logger.info(f"model for adaptation: %s", model) 77 | stats, stat_names = norm.collect_stats(model) 78 | logger.info(f"stats for adaptation: %s", stat_names) 79 | return norm_model 80 | 81 | 82 | def setup_tent(model): 83 | """Set up tent adaptation. 84 | 85 | Configure the model for training + feature modulation by batch statistics, 86 | collect the parameters for feature modulation by gradient optimization, 87 | set up the optimizer, and then tent the model. 88 | """ 89 | model = tent.configure_model(model) 90 | params, param_names = tent.collect_params(model) 91 | optimizer = setup_optimizer(params) 92 | tent_model = tent.Tent(model, optimizer, 93 | steps=cfg.OPTIM.STEPS, 94 | episodic=cfg.MODEL.EPISODIC) 95 | logger.info(f"model for adaptation: %s", model) 96 | logger.info(f"params for adaptation: %s", param_names) 97 | logger.info(f"optimizer for adaptation: %s", optimizer) 98 | return tent_model 99 | 100 | 101 | def setup_optimizer(params): 102 | """Set up optimizer for tent adaptation. 103 | 104 | Tent needs an optimizer for test-time entropy minimization. 105 | In principle, tent could make use of any gradient optimizer. 106 | In practice, we advise choosing Adam or SGD+momentum. 107 | For optimization settings, we advise to use the settings from the end of 108 | trainig, if known, or start with a low learning rate (like 0.001) if not. 
109 | 110 | For best results, try tuning the learning rate and batch size. 111 | """ 112 | if cfg.OPTIM.METHOD == 'Adam': 113 | return optim.Adam(params, 114 | lr=cfg.OPTIM.LR, 115 | betas=(cfg.OPTIM.BETA, 0.999), 116 | weight_decay=cfg.OPTIM.WD) 117 | elif cfg.OPTIM.METHOD == 'SGD': 118 | return optim.SGD(params, 119 | lr=cfg.OPTIM.LR, 120 | momentum=0.9, 121 | dampening=0, 122 | weight_decay=cfg.OPTIM.WD, 123 | nesterov=True) 124 | else: 125 | raise NotImplementedError 126 | 127 | def setup_cotta(model): 128 | """Set up tent adaptation. 129 | 130 | Configure the model for training + feature modulation by batch statistics, 131 | collect the parameters for feature modulation by gradient optimization, 132 | set up the optimizer, and then tent the model. 133 | """ 134 | model = cotta.configure_model(model) 135 | params, param_names = cotta.collect_params(model) 136 | optimizer = setup_optimizer(params) 137 | cotta_model = cotta.CoTTA(model, optimizer, 138 | steps=cfg.OPTIM.STEPS, 139 | episodic=cfg.MODEL.EPISODIC) 140 | logger.info(f"model for adaptation: %s", model) 141 | logger.info(f"params for adaptation: %s", param_names) 142 | logger.info(f"optimizer for adaptation: %s", optimizer) 143 | return cotta_model 144 | 145 | 146 | if __name__ == '__main__': 147 | evaluate('"Imagenet-C evaluation.') 148 | -------------------------------------------------------------------------------- /cifar/cifar100c.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.optim as optim 5 | 6 | from robustbench.data import load_cifar100c 7 | from robustbench.model_zoo.enums import ThreatModel 8 | from robustbench.utils import load_model 9 | from robustbench.utils import clean_accuracy as accuracy 10 | 11 | import tent 12 | import norm 13 | import cotta 14 | 15 | from conf import cfg, load_cfg_fom_args 16 | 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | def evaluate(description): 22 | 
load_cfg_fom_args(description) 23 | # configure model 24 | base_model = load_model(cfg.MODEL.ARCH, cfg.CKPT_DIR, 25 | cfg.CORRUPTION.DATASET, ThreatModel.corruptions).cuda() 26 | if cfg.MODEL.ADAPTATION == "source": 27 | logger.info("test-time adaptation: NONE") 28 | model = setup_source(base_model) 29 | if cfg.MODEL.ADAPTATION == "norm": 30 | logger.info("test-time adaptation: NORM") 31 | model = setup_norm(base_model) 32 | if cfg.MODEL.ADAPTATION == "tent": 33 | logger.info("test-time adaptation: TENT") 34 | model = setup_tent(base_model) 35 | if cfg.MODEL.ADAPTATION == "cotta": 36 | logger.info("test-time adaptation: CoTTA") 37 | model = setup_cotta(base_model) 38 | # evaluate on each severity and type of corruption in turn 39 | prev_ct = "x0" 40 | for severity in cfg.CORRUPTION.SEVERITY: 41 | for i_c, corruption_type in enumerate(cfg.CORRUPTION.TYPE): 42 | # continual adaptation for all corruption 43 | if i_c == 0: 44 | try: 45 | model.reset() 46 | logger.info("resetting model") 47 | except: 48 | logger.warning("not resetting model") 49 | else: 50 | logger.warning("not resetting model") 51 | x_test, y_test = load_cifar100c(cfg.CORRUPTION.NUM_EX, 52 | severity, cfg.DATA_DIR, False, 53 | [corruption_type]) 54 | x_test, y_test = x_test.cuda(), y_test.cuda() 55 | acc = accuracy(model, x_test, y_test, cfg.TEST.BATCH_SIZE) 56 | err = 1. - acc 57 | logger.info(f"error % [{corruption_type}{severity}]: {err:.2%}") 58 | 59 | 60 | def setup_source(model): 61 | """Set up the baseline source model without adaptation.""" 62 | model.eval() 63 | logger.info(f"model for evaluation: %s", model) 64 | return model 65 | 66 | 67 | def setup_norm(model): 68 | """Set up test-time normalization adaptation. 69 | 70 | Adapt by normalizing features with test batch statistics. 71 | The statistics are measured independently for each batch; 72 | no running average or other cross-batch estimation is used. 
73 | """ 74 | norm_model = norm.Norm(model) 75 | logger.info(f"model for adaptation: %s", model) 76 | stats, stat_names = norm.collect_stats(model) 77 | logger.info(f"stats for adaptation: %s", stat_names) 78 | return norm_model 79 | 80 | 81 | def setup_tent(model): 82 | """Set up tent adaptation. 83 | 84 | Configure the model for training + feature modulation by batch statistics, 85 | collect the parameters for feature modulation by gradient optimization, 86 | set up the optimizer, and then tent the model. 87 | """ 88 | model = tent.configure_model(model) 89 | params, param_names = tent.collect_params(model) 90 | optimizer = setup_optimizer(params) 91 | tent_model = tent.Tent(model, optimizer, 92 | steps=cfg.OPTIM.STEPS, 93 | episodic=cfg.MODEL.EPISODIC) 94 | logger.info(f"model for adaptation: %s", model) 95 | logger.info(f"params for adaptation: %s", param_names) 96 | logger.info(f"optimizer for adaptation: %s", optimizer) 97 | return tent_model 98 | 99 | 100 | def setup_cotta(model): 101 | """Set up tent adaptation. 102 | 103 | Configure the model for training + feature modulation by batch statistics, 104 | collect the parameters for feature modulation by gradient optimization, 105 | set up the optimizer, and then tent the model. 106 | """ 107 | model = cotta.configure_model(model) 108 | params, param_names = cotta.collect_params(model) 109 | optimizer = setup_optimizer(params) 110 | cotta_model = cotta.CoTTA(model, optimizer, 111 | steps=cfg.OPTIM.STEPS, 112 | episodic=cfg.MODEL.EPISODIC, 113 | mt_alpha=cfg.OPTIM.MT, 114 | rst_m=cfg.OPTIM.RST, 115 | ap=cfg.OPTIM.AP) 116 | logger.info(f"model for adaptation: %s", model) 117 | logger.info(f"params for adaptation: %s", param_names) 118 | logger.info(f"optimizer for adaptation: %s", optimizer) 119 | return cotta_model 120 | 121 | 122 | def setup_optimizer(params): 123 | """Set up optimizer for tent adaptation. 124 | 125 | Tent needs an optimizer for test-time entropy minimization. 
126 | In principle, tent could make use of any gradient optimizer. 127 | In practice, we advise choosing Adam or SGD+momentum. 128 | For optimization settings, we advise to use the settings from the end of 129 | trainig, if known, or start with a low learning rate (like 0.001) if not. 130 | 131 | For best results, try tuning the learning rate and batch size. 132 | """ 133 | if cfg.OPTIM.METHOD == 'Adam': 134 | return optim.Adam(params, 135 | lr=cfg.OPTIM.LR, 136 | betas=(cfg.OPTIM.BETA, 0.999), 137 | weight_decay=cfg.OPTIM.WD) 138 | elif cfg.OPTIM.METHOD == 'SGD': 139 | return optim.SGD(params, 140 | lr=cfg.OPTIM.LR, 141 | momentum=cfg.OPTIM.MOMENTUM, 142 | dampening=cfg.OPTIM.DAMPENING, 143 | weight_decay=cfg.OPTIM.WD, 144 | nesterov=cfg.OPTIM.NESTEROV) 145 | else: 146 | raise NotImplementedError 147 | 148 | 149 | if __name__ == '__main__': 150 | evaluate('"CIFAR-10-C evaluation.') 151 | -------------------------------------------------------------------------------- /cifar/cifar10c.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.optim as optim 5 | 6 | from robustbench.data import load_cifar10c 7 | from robustbench.model_zoo.enums import ThreatModel 8 | from robustbench.utils import load_model 9 | from robustbench.utils import clean_accuracy as accuracy 10 | 11 | import tent 12 | import norm 13 | import cotta 14 | 15 | from conf import cfg, load_cfg_fom_args 16 | 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | def evaluate(description): 22 | load_cfg_fom_args(description) 23 | # configure model 24 | base_model = load_model(cfg.MODEL.ARCH, cfg.CKPT_DIR, 25 | cfg.CORRUPTION.DATASET, ThreatModel.corruptions).cuda() 26 | if cfg.MODEL.ADAPTATION == "source": 27 | logger.info("test-time adaptation: NONE") 28 | model = setup_source(base_model) 29 | if cfg.MODEL.ADAPTATION == "norm": 30 | logger.info("test-time adaptation: NORM") 31 | model = setup_norm(base_model) 
32 | if cfg.MODEL.ADAPTATION == "tent": 33 | logger.info("test-time adaptation: TENT") 34 | model = setup_tent(base_model) 35 | if cfg.MODEL.ADAPTATION == "cotta": 36 | logger.info("test-time adaptation: CoTTA") 37 | model = setup_cotta(base_model) 38 | # evaluate on each severity and type of corruption in turn 39 | prev_ct = "x0" 40 | for severity in cfg.CORRUPTION.SEVERITY: 41 | for i_c, corruption_type in enumerate(cfg.CORRUPTION.TYPE): 42 | # continual adaptation for all corruption 43 | if i_c == 0: 44 | try: 45 | model.reset() 46 | logger.info("resetting model") 47 | except: 48 | logger.warning("not resetting model") 49 | else: 50 | logger.warning("not resetting model") 51 | x_test, y_test = load_cifar10c(cfg.CORRUPTION.NUM_EX, 52 | severity, cfg.DATA_DIR, False, 53 | [corruption_type]) 54 | x_test, y_test = x_test.cuda(), y_test.cuda() 55 | acc = accuracy(model, x_test, y_test, cfg.TEST.BATCH_SIZE) 56 | err = 1. - acc 57 | logger.info(f"error % [{corruption_type}{severity}]: {err:.2%}") 58 | 59 | 60 | def setup_source(model): 61 | """Set up the baseline source model without adaptation.""" 62 | model.eval() 63 | logger.info(f"model for evaluation: %s", model) 64 | return model 65 | 66 | 67 | def setup_norm(model): 68 | """Set up test-time normalization adaptation. 69 | 70 | Adapt by normalizing features with test batch statistics. 71 | The statistics are measured independently for each batch; 72 | no running average or other cross-batch estimation is used. 73 | """ 74 | norm_model = norm.Norm(model) 75 | logger.info(f"model for adaptation: %s", model) 76 | stats, stat_names = norm.collect_stats(model) 77 | logger.info(f"stats for adaptation: %s", stat_names) 78 | return norm_model 79 | 80 | 81 | def setup_tent(model): 82 | """Set up tent adaptation. 83 | 84 | Configure the model for training + feature modulation by batch statistics, 85 | collect the parameters for feature modulation by gradient optimization, 86 | set up the optimizer, and then tent the model. 
87 | """ 88 | model = tent.configure_model(model) 89 | params, param_names = tent.collect_params(model) 90 | optimizer = setup_optimizer(params) 91 | tent_model = tent.Tent(model, optimizer, 92 | steps=cfg.OPTIM.STEPS, 93 | episodic=cfg.MODEL.EPISODIC) 94 | logger.info(f"model for adaptation: %s", model) 95 | logger.info(f"params for adaptation: %s", param_names) 96 | logger.info(f"optimizer for adaptation: %s", optimizer) 97 | return tent_model 98 | 99 | 100 | def setup_cotta(model): 101 | """Set up tent adaptation. 102 | 103 | Configure the model for training + feature modulation by batch statistics, 104 | collect the parameters for feature modulation by gradient optimization, 105 | set up the optimizer, and then tent the model. 106 | """ 107 | model = cotta.configure_model(model) 108 | params, param_names = cotta.collect_params(model) 109 | optimizer = setup_optimizer(params) 110 | cotta_model = cotta.CoTTA(model, optimizer, 111 | steps=cfg.OPTIM.STEPS, 112 | episodic=cfg.MODEL.EPISODIC, 113 | mt_alpha=cfg.OPTIM.MT, 114 | rst_m=cfg.OPTIM.RST, 115 | ap=cfg.OPTIM.AP) 116 | logger.info(f"model for adaptation: %s", model) 117 | logger.info(f"params for adaptation: %s", param_names) 118 | logger.info(f"optimizer for adaptation: %s", optimizer) 119 | return cotta_model 120 | 121 | 122 | def setup_optimizer(params): 123 | """Set up optimizer for tent adaptation. 124 | 125 | Tent needs an optimizer for test-time entropy minimization. 126 | In principle, tent could make use of any gradient optimizer. 127 | In practice, we advise choosing Adam or SGD+momentum. 128 | For optimization settings, we advise to use the settings from the end of 129 | trainig, if known, or start with a low learning rate (like 0.001) if not. 130 | 131 | For best results, try tuning the learning rate and batch size. 
132 | """ 133 | if cfg.OPTIM.METHOD == 'Adam': 134 | return optim.Adam(params, 135 | lr=cfg.OPTIM.LR, 136 | betas=(cfg.OPTIM.BETA, 0.999), 137 | weight_decay=cfg.OPTIM.WD) 138 | elif cfg.OPTIM.METHOD == 'SGD': 139 | return optim.SGD(params, 140 | lr=cfg.OPTIM.LR, 141 | momentum=cfg.OPTIM.MOMENTUM, 142 | dampening=cfg.OPTIM.DAMPENING, 143 | weight_decay=cfg.OPTIM.WD, 144 | nesterov=cfg.OPTIM.NESTEROV) 145 | else: 146 | raise NotImplementedError 147 | 148 | 149 | if __name__ == '__main__': 150 | evaluate('"CIFAR-10-C evaluation.') 151 | -------------------------------------------------------------------------------- /cifar/cifar10c_gradual.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.optim as optim 5 | 6 | from robustbench.data import load_cifar10c 7 | from robustbench.model_zoo.enums import ThreatModel 8 | from robustbench.utils import load_model 9 | from robustbench.utils import clean_accuracy as accuracy 10 | 11 | import tent 12 | import norm 13 | import cotta 14 | 15 | from conf import cfg, load_cfg_fom_args 16 | 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | def evaluate(description): 22 | load_cfg_fom_args(description) 23 | # configure model 24 | base_model = load_model(cfg.MODEL.ARCH, cfg.CKPT_DIR, 25 | cfg.CORRUPTION.DATASET, ThreatModel.corruptions).cuda() 26 | if cfg.MODEL.ADAPTATION == "source": 27 | logger.info("test-time adaptation: NONE") 28 | model = setup_source(base_model) 29 | if cfg.MODEL.ADAPTATION == "norm": 30 | logger.info("test-time adaptation: NORM") 31 | model = setup_norm(base_model) 32 | if cfg.MODEL.ADAPTATION == "tent": 33 | logger.info("test-time adaptation: TENT") 34 | model = setup_tent(base_model) 35 | if cfg.MODEL.ADAPTATION == "cotta": 36 | logger.info("test-time adaptation: CoTTA") 37 | model = setup_cotta(base_model) 38 | # evaluate on each severity and type of corruption in turn 39 | prev_ct = "x0" 40 | for i_c, 
corruption_type in enumerate(cfg.CORRUPTION.TYPE): 41 | # continual adaptation for all corruption 42 | if i_c == 0: 43 | try: 44 | model.reset() 45 | logger.info("resetting model") 46 | except: 47 | logger.warning("not resetting model") 48 | severities = [5,4,3,2,1] 49 | #To simulate the large gap between the source and first target domain, we start from severity 5. 50 | else: 51 | severities = [1,2,3,4,5,4,3,2,1] 52 | logger.info("not resetting model") 53 | for severity in severities: 54 | x_test, y_test = load_cifar10c(cfg.CORRUPTION.NUM_EX, 55 | severity, cfg.DATA_DIR, False, 56 | [corruption_type]) 57 | x_test, y_test = x_test.cuda(), y_test.cuda() 58 | acc = accuracy(model, x_test, y_test, cfg.TEST.BATCH_SIZE) 59 | err = 1. - acc 60 | logger.info(f"error % [{corruption_type}{severity}]: {err:.2%}") 61 | 62 | 63 | def setup_source(model): 64 | """Set up the baseline source model without adaptation.""" 65 | model.eval() 66 | logger.info(f"model for evaluation: %s", model) 67 | return model 68 | 69 | 70 | def setup_norm(model): 71 | """Set up test-time normalization adaptation. 72 | 73 | Adapt by normalizing features with test batch statistics. 74 | The statistics are measured independently for each batch; 75 | no running average or other cross-batch estimation is used. 76 | """ 77 | norm_model = norm.Norm(model) 78 | logger.info(f"model for adaptation: %s", model) 79 | stats, stat_names = norm.collect_stats(model) 80 | logger.info(f"stats for adaptation: %s", stat_names) 81 | return norm_model 82 | 83 | 84 | def setup_tent(model): 85 | """Set up tent adaptation. 86 | 87 | Configure the model for training + feature modulation by batch statistics, 88 | collect the parameters for feature modulation by gradient optimization, 89 | set up the optimizer, and then tent the model. 
90 | """ 91 | model = tent.configure_model(model) 92 | params, param_names = tent.collect_params(model) 93 | optimizer = setup_optimizer(params) 94 | tent_model = tent.Tent(model, optimizer, 95 | steps=cfg.OPTIM.STEPS, 96 | episodic=cfg.MODEL.EPISODIC) 97 | logger.info(f"model for adaptation: %s", model) 98 | logger.info(f"params for adaptation: %s", param_names) 99 | logger.info(f"optimizer for adaptation: %s", optimizer) 100 | return tent_model 101 | 102 | 103 | def setup_cotta(model): 104 | """Set up tent adaptation. 105 | 106 | Configure the model for training + feature modulation by batch statistics, 107 | collect the parameters for feature modulation by gradient optimization, 108 | set up the optimizer, and then tent the model. 109 | """ 110 | model = cotta.configure_model(model) 111 | params, param_names = cotta.collect_params(model) 112 | optimizer = setup_optimizer(params) 113 | cotta_model = cotta.CoTTA(model, optimizer, 114 | steps=cfg.OPTIM.STEPS, 115 | episodic=cfg.MODEL.EPISODIC, 116 | mt_alpha=cfg.OPTIM.MT, 117 | rst_m=cfg.OPTIM.RST, 118 | ap=cfg.OPTIM.AP) 119 | logger.info(f"model for adaptation: %s", model) 120 | logger.info(f"params for adaptation: %s", param_names) 121 | logger.info(f"optimizer for adaptation: %s", optimizer) 122 | return cotta_model 123 | 124 | 125 | def setup_optimizer(params): 126 | """Set up optimizer for tent adaptation. 127 | 128 | Tent needs an optimizer for test-time entropy minimization. 129 | In principle, tent could make use of any gradient optimizer. 130 | In practice, we advise choosing Adam or SGD+momentum. 131 | For optimization settings, we advise to use the settings from the end of 132 | trainig, if known, or start with a low learning rate (like 0.001) if not. 133 | 134 | For best results, try tuning the learning rate and batch size. 
135 | """ 136 | if cfg.OPTIM.METHOD == 'Adam': 137 | return optim.Adam(params, 138 | lr=cfg.OPTIM.LR, 139 | betas=(cfg.OPTIM.BETA, 0.999), 140 | weight_decay=cfg.OPTIM.WD) 141 | elif cfg.OPTIM.METHOD == 'SGD': 142 | return optim.SGD(params, 143 | lr=cfg.OPTIM.LR, 144 | momentum=cfg.OPTIM.MOMENTUM, 145 | dampening=cfg.OPTIM.DAMPENING, 146 | weight_decay=cfg.OPTIM.WD, 147 | nesterov=cfg.OPTIM.NESTEROV) 148 | else: 149 | raise NotImplementedError 150 | 151 | 152 | if __name__ == '__main__': 153 | evaluate('"CIFAR-10-C evaluation.') 154 | -------------------------------------------------------------------------------- /cifar/robustbench/model_zoo/architectures/resnext.py: -------------------------------------------------------------------------------- 1 | """ResNeXt implementation (https://arxiv.org/abs/1611.05431). 2 | 3 | MIT License 4 | 5 | Copyright (c) 2017 Xuanyi Dong 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 
24 | 
25 | From:
26 | https://github.com/google-research/augmix/blob/master/third_party/WideResNet_pytorch/wideresnet.py
27 | 
28 | """
29 | 
30 | import math
31 | 
32 | import torch.nn as nn
33 | import torch.nn.functional as F
34 | from torch.nn import init
35 | 
36 | 
37 | class ResNeXtBottleneck(nn.Module):
38 |     """
39 |     ResNeXt Bottleneck Block type C (https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua).
40 |     """
41 |     expansion = 4
42 | 
43 |     def __init__(self,
44 |                  inplanes,
45 |                  planes,
46 |                  cardinality,
47 |                  base_width,
48 |                  stride=1,
49 |                  downsample=None):
50 |         super(ResNeXtBottleneck, self).__init__()
51 | 
52 |         dim = int(math.floor(planes * (base_width / 64.0)))  # channels per group; conv_conv below uses groups=cardinality
53 | 
54 |         self.conv_reduce = nn.Conv2d(
55 |             inplanes,
56 |             dim * cardinality,
57 |             kernel_size=1,
58 |             stride=1,
59 |             padding=0,
60 |             bias=False)
61 |         self.bn_reduce = nn.BatchNorm2d(dim * cardinality)
62 | 
63 |         self.conv_conv = nn.Conv2d(
64 |             dim * cardinality,
65 |             dim * cardinality,
66 |             kernel_size=3,
67 |             stride=stride,
68 |             padding=1,
69 |             groups=cardinality,
70 |             bias=False)
71 |         self.bn = nn.BatchNorm2d(dim * cardinality)
72 | 
73 |         self.conv_expand = nn.Conv2d(
74 |             dim * cardinality,
75 |             planes * 4,
76 |             kernel_size=1,
77 |             stride=1,
78 |             padding=0,
79 |             bias=False)
80 |         self.bn_expand = nn.BatchNorm2d(planes * 4)
81 | 
82 |         self.downsample = downsample
83 | 
84 |     def forward(self, x):
85 |         residual = x
86 | 
87 |         bottleneck = self.conv_reduce(x)
88 |         bottleneck = F.relu(self.bn_reduce(bottleneck), inplace=True)
89 | 
90 |         bottleneck = self.conv_conv(bottleneck)
91 |         bottleneck = F.relu(self.bn(bottleneck), inplace=True)
92 | 
93 |         bottleneck = self.conv_expand(bottleneck)
94 |         bottleneck = self.bn_expand(bottleneck)
95 | 
96 |         if self.downsample is not None:
97 |             residual = self.downsample(x)
98 | 
99 |         return F.relu(residual + bottleneck, inplace=True)
100 | 
101 | 
102 | class CifarResNeXt(nn.Module):
103 |     """ResNext optimized for the Cifar dataset, as specified in
104 |     https://arxiv.org/pdf/1611.05431.pdf."""
105 | 
106 |     def __init__(self, block, depth, cardinality, base_width, num_classes):
107 |         super(CifarResNeXt, self).__init__()
108 | 
109 |         # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
110 |         assert (depth - 2) % 9 == 0, 'depth should be one of 29, 38, 47, 56, 101'  # depth = 3 stages * layer_blocks * 3 convs + 2
111 |         layer_blocks = (depth - 2) // 9
112 | 
113 |         self.cardinality = cardinality
114 |         self.base_width = base_width
115 |         self.num_classes = num_classes
116 | 
117 |         self.conv_1_3x3 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
118 |         self.bn_1 = nn.BatchNorm2d(64)
119 | 
120 |         self.inplanes = 64
121 |         self.stage_1 = self._make_layer(block, 64, layer_blocks, 1)
122 |         self.stage_2 = self._make_layer(block, 128, layer_blocks, 2)
123 |         self.stage_3 = self._make_layer(block, 256, layer_blocks, 2)
124 |         self.avgpool = nn.AvgPool2d(8)  # NOTE(review): assumes 8x8 maps after two stride-2 stages, i.e. 32x32 inputs — confirm
125 |         self.classifier = nn.Linear(256 * block.expansion, num_classes)
126 | 
127 |         for m in self.modules():
128 |             if isinstance(m, nn.Conv2d):
129 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
130 |                 m.weight.data.normal_(0, math.sqrt(2. / n))
131 |             elif isinstance(m, nn.BatchNorm2d):
132 |                 m.weight.data.fill_(1)
133 |                 m.bias.data.zero_()
134 |             elif isinstance(m, nn.Linear):
135 |                 init.kaiming_normal_(m.weight)
136 |                 m.bias.data.zero_()
137 | 
138 |     def _make_layer(self, block, planes, blocks, stride=1):
139 |         downsample = None
140 |         if stride != 1 or self.inplanes != planes * block.expansion:
141 |             downsample = nn.Sequential(
142 |                 nn.Conv2d(
143 |                     self.inplanes,
144 |                     planes * block.expansion,
145 |                     kernel_size=1,
146 |                     stride=stride,
147 |                     bias=False),
148 |                 nn.BatchNorm2d(planes * block.expansion),
149 |             )
150 | 
151 |         layers = []
152 |         layers.append(
153 |             block(self.inplanes, planes, self.cardinality, self.base_width, stride,
154 |                   downsample))
155 |         self.inplanes = planes * block.expansion
156 |         for _ in range(1, blocks):
157 |             layers.append(
158 |                 block(self.inplanes, planes, self.cardinality, self.base_width))
159 | 
160 |         return nn.Sequential(*layers)
161 | 
162 |     def forward(self, x):
163 |         x = self.conv_1_3x3(x)
164 |         x = F.relu(self.bn_1(x), inplace=True)
165 |         x = self.stage_1(x)
166 |         x = self.stage_2(x)
167 |         x = self.stage_3(x)
168 |         x = self.avgpool(x)
169 |         x = x.view(x.size(0), -1)
170 |         return self.classifier(x)
171 | 
--------------------------------------------------------------------------------
/imagenet/robustbench/model_zoo/architectures/resnext.py:
--------------------------------------------------------------------------------
1 | """ResNeXt implementation (https://arxiv.org/abs/1611.05431).
2 | 
3 | MIT License
4 | 
5 | Copyright (c) 2017 Xuanyi Dong
6 | 
7 | Permission is hereby granted, free of charge, to any person obtaining a copy
8 | of this software and associated documentation files (the "Software"), to deal
9 | in the Software without restriction, including without limitation the rights
10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | copies of the Software, and to permit persons to whom the Software is
12 | furnished to do so, subject to the following conditions:
13 | 
14 | The above copyright notice and this permission notice shall be included in all
15 | copies or substantial portions of the Software.
16 | 
17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | SOFTWARE.
24 | 
25 | From:
26 | https://github.com/google-research/augmix/blob/master/third_party/WideResNet_pytorch/wideresnet.py
27 | 
28 | """
29 | 
30 | import math
31 | 
32 | import torch.nn as nn
33 | import torch.nn.functional as F
34 | from torch.nn import init
35 | 
36 | 
37 | class ResNeXtBottleneck(nn.Module):
38 |     """
39 |     ResNeXt Bottleneck Block type C (https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua).
40 |     """
41 |     expansion = 4
42 | 
43 |     def __init__(self,
44 |                  inplanes,
45 |                  planes,
46 |                  cardinality,
47 |                  base_width,
48 |                  stride=1,
49 |                  downsample=None):
50 |         super(ResNeXtBottleneck, self).__init__()
51 | 
52 |         dim = int(math.floor(planes * (base_width / 64.0)))  # channels per group; conv_conv below uses groups=cardinality
53 | 
54 |         self.conv_reduce = nn.Conv2d(
55 |             inplanes,
56 |             dim * cardinality,
57 |             kernel_size=1,
58 |             stride=1,
59 |             padding=0,
60 |             bias=False)
61 |         self.bn_reduce = nn.BatchNorm2d(dim * cardinality)
62 | 
63 |         self.conv_conv = nn.Conv2d(
64 |             dim * cardinality,
65 |             dim * cardinality,
66 |             kernel_size=3,
67 |             stride=stride,
68 |             padding=1,
69 |             groups=cardinality,
70 |             bias=False)
71 |         self.bn = nn.BatchNorm2d(dim * cardinality)
72 | 
73 |         self.conv_expand = nn.Conv2d(
74 |             dim * cardinality,
75 |             planes * 4,
76 |             kernel_size=1,
77 |             stride=1,
78 |             padding=0,
79 |             bias=False)
80 |         self.bn_expand = nn.BatchNorm2d(planes * 4)
81 | 
82 |         self.downsample = downsample
83 | 
84 |     def forward(self, x):
85 |         residual = x
86 | 
87 |         bottleneck = self.conv_reduce(x)
88 |         bottleneck = F.relu(self.bn_reduce(bottleneck), inplace=True)
89 | 
90 |         bottleneck = self.conv_conv(bottleneck)
91 |         bottleneck = F.relu(self.bn(bottleneck), inplace=True)
92 | 
93 |         bottleneck = self.conv_expand(bottleneck)
94 |         bottleneck = self.bn_expand(bottleneck)
95 | 
96 |         if self.downsample is not None:
97 |             residual = self.downsample(x)
98 | 
99 |         return F.relu(residual + bottleneck, inplace=True)
100 | 
101 | 
102 | class CifarResNeXt(nn.Module):
103 |     """ResNext optimized for the Cifar dataset, as specified in
104 |     https://arxiv.org/pdf/1611.05431.pdf."""
105 | 
106 |     def __init__(self, block, depth, cardinality, base_width, num_classes):
107 |         super(CifarResNeXt, self).__init__()
108 | 
109 |         # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
110 |         assert (depth - 2) % 9 == 0, 'depth should be one of 29, 38, 47, 56, 101'  # depth = 3 stages * layer_blocks * 3 convs + 2
111 |         layer_blocks = (depth - 2) // 9
112 | 
113 |         self.cardinality = cardinality
114 |         self.base_width = base_width
115 |         self.num_classes = num_classes
116 | 
117 |         self.conv_1_3x3 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
118 |         self.bn_1 = nn.BatchNorm2d(64)
119 | 
120 |         self.inplanes = 64
121 |         self.stage_1 = self._make_layer(block, 64, layer_blocks, 1)
122 |         self.stage_2 = self._make_layer(block, 128, layer_blocks, 2)
123 |         self.stage_3 = self._make_layer(block, 256, layer_blocks, 2)
124 |         self.avgpool = nn.AvgPool2d(8)  # NOTE(review): assumes 8x8 maps after two stride-2 stages, i.e. 32x32 inputs — confirm
125 |         self.classifier = nn.Linear(256 * block.expansion, num_classes)
126 | 
127 |         for m in self.modules():
128 |             if isinstance(m, nn.Conv2d):
129 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
130 |                 m.weight.data.normal_(0, math.sqrt(2. / n))
131 |             elif isinstance(m, nn.BatchNorm2d):
132 |                 m.weight.data.fill_(1)
133 |                 m.bias.data.zero_()
134 |             elif isinstance(m, nn.Linear):
135 |                 init.kaiming_normal_(m.weight)
136 |                 m.bias.data.zero_()
137 | 
138 |     def _make_layer(self, block, planes, blocks, stride=1):
139 |         downsample = None
140 |         if stride != 1 or self.inplanes != planes * block.expansion:
141 |             downsample = nn.Sequential(
142 |                 nn.Conv2d(
143 |                     self.inplanes,
144 |                     planes * block.expansion,
145 |                     kernel_size=1,
146 |                     stride=stride,
147 |                     bias=False),
148 |                 nn.BatchNorm2d(planes * block.expansion),
149 |             )
150 | 
151 |         layers = []
152 |         layers.append(
153 |             block(self.inplanes, planes, self.cardinality, self.base_width, stride,
154 |                   downsample))
155 |         self.inplanes = planes * block.expansion
156 |         for _ in range(1, blocks):
157 |             layers.append(
158 |                 block(self.inplanes, planes, self.cardinality, self.base_width))
159 | 
160 |         return nn.Sequential(*layers)
161 | 
162 |     def forward(self, x):
163 |         x = self.conv_1_3x3(x)
164 |         x = F.relu(self.bn_1(x), inplace=True)
165 |         x = self.stage_1(x)
166 |         x = self.stage_2(x)
167 |         x = self.stage_3(x)
168 |         x = self.avgpool(x)
169 |         x = x.view(x.size(0), -1)
170 |         return self.classifier(x)
171 | 
--------------------------------------------------------------------------------
/cifar/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import numpy as np 5 | from typing import Dict, List, Tuple 6 | import torch.nn as nn 7 | import torch.utils.data as data 8 | import logging 9 | from tqdm import tqdm 10 | 11 | def pytorch_evaluate(net: nn.Module, data_loader: data.DataLoader, fetch_keys: List, 12 | x_shape: Tuple = None, output_shapes: Dict = None, to_tensor: bool=False, verbose=False) -> Tuple: 13 | 14 | if output_shapes is not None: 15 | for key in fetch_keys: 16 | assert key in output_shapes 17 | 18 | # Fetching inference outputs as numpy arrays 19 | batch_size = data_loader.batch_size 20 | num_samples = len(data_loader.dataset) 21 | batch_count = int(np.ceil(num_samples / batch_size)) 22 | fetches_dict = {} 23 | fetches = [] 24 | for key in fetch_keys: 25 | fetches_dict[key] = [] 26 | 27 | net.eval() 28 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 29 | 30 | for batch_idx, (inputs, targets) in (tqdm(enumerate(data_loader)) if verbose else enumerate(data_loader)): 31 | if x_shape is not None: 32 | inputs = inputs.reshape(x_shape) 33 | inputs, targets = inputs.to(device), targets.to(device) 34 | outputs_dict = net(inputs) 35 | for key in fetch_keys: 36 | fetches_dict[key].append(outputs_dict[key].data.cpu().detach().numpy()) 37 | 38 | # stack variables together 39 | for key in fetch_keys: 40 | fetch = np.vstack(fetches_dict[key]) 41 | if output_shapes is not None: 42 | fetch = fetch.reshape(output_shapes[key]) 43 | if to_tensor: 44 | fetch = torch.as_tensor(fetch, device=torch.device(device)) 45 | fetches.append(fetch) 46 | 47 | assert batch_idx + 1 == batch_count 48 | assert fetches[0].shape[0] == num_samples 49 | 50 | return tuple(fetches) 51 | 52 | def boolean_string(s): 53 | # to use --use_bn True or --use_bn False in the shell. 
See: 54 | # https://stackoverflow.com/questions/44561722/why-in-argparse-a-true-is-always-true 55 | if s not in {'False', 'True'}: 56 | raise ValueError('Not a valid boolean string') 57 | return s == 'True' 58 | 59 | def convert_tensor_to_image(x: np.ndarray): 60 | """ 61 | :param X: np.array of size (Batch, feature_dims, H, W) or (feature_dims, H, W) 62 | :return: X with (Batch, H, W, feature_dims) or (H, W, feature_dims) between 0:255, uint8 63 | """ 64 | X = x.copy() 65 | X *= 255.0 66 | X = np.round(X) 67 | X = X.astype(np.uint8) 68 | if len(x.shape) == 3: 69 | X = np.transpose(X, [1, 2, 0]) 70 | else: 71 | X = np.transpose(X, [0, 2, 3, 1]) 72 | return X 73 | 74 | def convert_image_to_tensor(x: np.ndarray): 75 | """ 76 | :param X: np.array of size (Batch, H, W, feature_dims) between 0:255, uint8 77 | :return: X with (Batch, feature_dims, H, W) float between [0:1] 78 | """ 79 | assert x.dtype == np.uint8 80 | X = x.copy() 81 | X = X.astype(np.float32) 82 | X /= 255.0 83 | X = np.transpose(X, [0, 3, 1, 2]) 84 | return X 85 | 86 | def majority_vote(x): 87 | return np.bincount(x).argmax() 88 | 89 | def get_ensemble_paths(ensemble_dir): 90 | ensemble_subdirs = next(os.walk(ensemble_dir))[1] 91 | ensemble_subdirs.sort() 92 | ensemble_paths = [] 93 | for j, dir in enumerate(ensemble_subdirs): # for network j 94 | ensemble_paths.append(os.path.join(ensemble_dir, dir, 'ckpt.pth')) 95 | 96 | return ensemble_paths 97 | 98 | def set_logger(log_file): 99 | logging.basicConfig(format='%(asctime)s %(name)s %(levelname)s %(message)s', 100 | datefmt='%m/%d/%Y %I:%M:%S %p', 101 | level=logging.INFO, 102 | handlers=[logging.FileHandler(log_file, mode='w'), 103 | logging.StreamHandler(sys.stdout)] 104 | ) 105 | 106 | def print_Linf_dists(X, X_test): 107 | logger = logging.getLogger() 108 | X_diff = (X - X_test).reshape(X.shape[0], -1) 109 | X_diff_abs = np.abs(X_diff) 110 | Linf_dist = X_diff_abs.max(axis=1) 111 | Linf_dist = Linf_dist[np.where(Linf_dist > 0.0)[0]] 112 | 
logger.info('The adversarial attacks distance: Max[L_inf]={}, E[L_inf]={}'.format(np.max(Linf_dist), np.mean(Linf_dist))) 113 | 114 | def calc_attack_rate(y_preds: np.ndarray, y_orig_norm_preds: np.ndarray, y_gt: np.ndarray) -> float: 115 | """ 116 | Args: 117 | y_preds: The adv image's final prediction after the defense method 118 | y_orig_norm_preds: The original image's predictions 119 | y_gt: The GT labels 120 | targeted: Whether or not the attack was targeted 121 | 122 | Returns: attack rate in % 123 | """ 124 | f0_inds = [] # net_fail 125 | f1_inds = [] # net_succ 126 | f2_inds = [] # net_succ AND attack_flip 127 | 128 | for i in range(len(y_gt)): 129 | f1 = y_orig_norm_preds[i] == y_gt[i] 130 | f2 = f1 and y_preds[i] != y_orig_norm_preds[i] 131 | if f1: 132 | f1_inds.append(i) 133 | else: 134 | f0_inds.append(i) 135 | if f2: 136 | f2_inds.append(i) 137 | 138 | attack_rate = len(f2_inds) / len(f1_inds) 139 | return attack_rate 140 | 141 | def get_all_files_recursive(path, suffix=None): 142 | files = [] 143 | # r=root, d=directories, f=files 144 | for r, d, f in os.walk(path): 145 | for file in f: 146 | if suffix is None: 147 | files.append(os.path.join(r, file)) 148 | elif '.' + suffix in file: 149 | files.append(os.path.join(r, file)) 150 | return files 151 | 152 | def convert_grayscale_to_rgb(x: np.ndarray) -> np.ndarray: 153 | """ 154 | Converts a 2D image shape=(x, y) to a RGB image (x, y, 3). 
155 |     Args:
156 |         x: gray image
157 |     Returns: rgb image
158 |     """
159 |     return np.stack((x, ) * 3, axis=-1)
160 | 
161 | def inverse_map(x: dict) -> dict:
162 |     """
163 |     :param x: dictionary
164 |     :return: inverse mapping, showing for each val its key (for duplicate values the last key wins)
165 |     """
166 |     inv_map = {}
167 |     for k, v in x.items():
168 |         inv_map[v] = k
169 |     return inv_map
170 | 
171 | def get_image_shape(dataset: str) -> Tuple[int, int, int]:
172 |     if dataset in ['cifar10', 'cifar100', 'svhn']:
173 |         return 32, 32, 3
174 |     elif dataset == 'tiny_imagenet':
175 |         return 64, 64, 3
176 |     else:
177 |         raise AssertionError('Unsupported dataset {}'.format(dataset))
178 | 
--------------------------------------------------------------------------------
/cifar/conf.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | 
6 | """Configuration file (powered by YACS)."""
7 | 
8 | import argparse
9 | import os
10 | import sys
11 | import logging
12 | import random
13 | import torch
14 | import numpy as np
15 | from datetime import datetime
16 | from iopath.common.file_io import g_pathmgr
17 | from yacs.config import CfgNode as CfgNode
18 | 
19 | 
20 | # Global config object (example usage: from conf import cfg)
21 | _C = CfgNode()
22 | cfg = _C
23 | 
24 | 
25 | # ----------------------------- Model options ------------------------------- #
26 | _C.MODEL = CfgNode()
27 | 
28 | # Check https://github.com/RobustBench/robustbench for available models
29 | _C.MODEL.ARCH = 'Standard'
30 | 
31 | # Choice of (source, norm, tent, cotta)
32 | # - source: baseline without adaptation
33 | # - norm: test-time normalization
34 | # - tent: test-time entropy minimization; cotta: continual test-time adaptation (CoTTA)
35 | _C.MODEL.ADAPTATION = 'source'
36 | 
37 | # By default tent is online, with updates persisting across batches.
38 | # To make adaptation episodic, and reset the model for each batch, choose True. 39 | _C.MODEL.EPISODIC = False 40 | 41 | # ----------------------------- Corruption options -------------------------- # 42 | _C.CORRUPTION = CfgNode() 43 | 44 | # Dataset for evaluation 45 | _C.CORRUPTION.DATASET = 'cifar10' 46 | 47 | # Check https://github.com/hendrycks/robustness for corruption details 48 | _C.CORRUPTION.TYPE = ['gaussian_noise', 'shot_noise', 'impulse_noise', 49 | 'defocus_blur', 'glass_blur', 'motion_blur', 'zoom_blur', 50 | 'snow', 'frost', 'fog', 'brightness', 'contrast', 51 | 'elastic_transform', 'pixelate', 'jpeg_compression'] 52 | _C.CORRUPTION.SEVERITY = [5, 4, 3, 2, 1] 53 | 54 | # Number of examples to evaluate (10000 for all samples in CIFAR-10) 55 | _C.CORRUPTION.NUM_EX = 10000 56 | 57 | # ------------------------------- Batch norm options ------------------------ # 58 | _C.BN = CfgNode() 59 | 60 | # BN epsilon 61 | _C.BN.EPS = 1e-5 62 | 63 | # BN momentum (BN momentum in PyTorch = 1 - BN momentum in Caffe2) 64 | _C.BN.MOM = 0.1 65 | 66 | # ------------------------------- Optimizer options ------------------------- # 67 | _C.OPTIM = CfgNode() 68 | 69 | # Number of updates per batch 70 | _C.OPTIM.STEPS = 1 71 | 72 | # Learning rate 73 | _C.OPTIM.LR = 1e-3 74 | 75 | # Choices: Adam, SGD 76 | _C.OPTIM.METHOD = 'Adam' 77 | 78 | # Beta 79 | _C.OPTIM.BETA = 0.9 80 | 81 | # Momentum 82 | _C.OPTIM.MOMENTUM = 0.9 83 | 84 | # Momentum dampening 85 | _C.OPTIM.DAMPENING = 0.0 86 | 87 | # Nesterov momentum 88 | _C.OPTIM.NESTEROV = True 89 | 90 | # L2 regularization 91 | _C.OPTIM.WD = 0.0 92 | 93 | # COTTA 94 | _C.OPTIM.MT = 0.999 95 | _C.OPTIM.RST = 0.01 96 | _C.OPTIM.AP = 0.92 97 | 98 | # ------------------------------- Testing options --------------------------- # 99 | _C.TEST = CfgNode() 100 | 101 | # Batch size for evaluation (and updates for norm + tent) 102 | _C.TEST.BATCH_SIZE = 128 103 | 104 | # --------------------------------- CUDNN options 
--------------------------- # 105 | _C.CUDNN = CfgNode() 106 | 107 | # Benchmark to select fastest CUDNN algorithms (best for fixed input sizes) 108 | _C.CUDNN.BENCHMARK = True 109 | 110 | # ---------------------------------- Misc options --------------------------- # 111 | 112 | # Optional description of a config 113 | _C.DESC = "" 114 | 115 | # Note that non-determinism is still present due to non-deterministic GPU ops 116 | _C.RNG_SEED = 1 117 | 118 | # Output directory 119 | _C.SAVE_DIR = "./output" 120 | 121 | # Data directory 122 | _C.DATA_DIR = "./data" 123 | 124 | # Weight directory 125 | _C.CKPT_DIR = "./ckpt" 126 | 127 | # Log destination (in SAVE_DIR) 128 | _C.LOG_DEST = "log.txt" 129 | 130 | # Log datetime 131 | _C.LOG_TIME = '' 132 | 133 | # # Config destination (in SAVE_DIR) 134 | # _C.CFG_DEST = "cfg.yaml" 135 | 136 | # --------------------------------- Default config -------------------------- # 137 | _CFG_DEFAULT = _C.clone() 138 | _CFG_DEFAULT.freeze() 139 | 140 | 141 | def assert_and_infer_cfg(): 142 | """Checks config values invariants.""" 143 | err_str = "Unknown adaptation method." 
144 | assert _C.MODEL.ADAPTATION in ["source", "norm", "tent"] 145 | err_str = "Log destination '{}' not supported" 146 | assert _C.LOG_DEST in ["stdout", "file"], err_str.format(_C.LOG_DEST) 147 | 148 | 149 | def merge_from_file(cfg_file): 150 | with g_pathmgr.open(cfg_file, "r") as f: 151 | cfg = _C.load_cfg(f) 152 | _C.merge_from_other_cfg(cfg) 153 | 154 | 155 | def dump_cfg(): 156 | """Dumps the config to the output directory.""" 157 | cfg_file = os.path.join(_C.SAVE_DIR, _C.CFG_DEST) 158 | with g_pathmgr.open(cfg_file, "w") as f: 159 | _C.dump(stream=f) 160 | 161 | 162 | def load_cfg(out_dir, cfg_dest="config.yaml"): 163 | """Loads config from specified output directory.""" 164 | cfg_file = os.path.join(out_dir, cfg_dest) 165 | merge_from_file(cfg_file) 166 | 167 | 168 | def reset_cfg(): 169 | """Reset config to initial state.""" 170 | cfg.merge_from_other_cfg(_CFG_DEFAULT) 171 | 172 | 173 | def load_cfg_fom_args(description="Config options."): 174 | """Load config from command line args and set any specified options.""" 175 | current_time = datetime.now().strftime("%y%m%d_%H%M%S") 176 | parser = argparse.ArgumentParser(description=description) 177 | parser.add_argument("--cfg", dest="cfg_file", type=str, required=True, 178 | help="Config file location") 179 | parser.add_argument("opts", default=None, nargs=argparse.REMAINDER, 180 | help="See conf.py for all options") 181 | if len(sys.argv) == 1: 182 | parser.print_help() 183 | sys.exit(1) 184 | args = parser.parse_args() 185 | 186 | merge_from_file(args.cfg_file) 187 | cfg.merge_from_list(args.opts) 188 | 189 | log_dest = os.path.basename(args.cfg_file) 190 | log_dest = log_dest.replace('.yaml', '_{}.txt'.format(current_time)) 191 | 192 | g_pathmgr.mkdirs(cfg.SAVE_DIR) 193 | cfg.LOG_TIME, cfg.LOG_DEST = current_time, log_dest 194 | cfg.freeze() 195 | 196 | logging.basicConfig( 197 | level=logging.INFO, 198 | format="[%(asctime)s] [%(filename)s: %(lineno)4d]: %(message)s", 199 | datefmt="%y/%m/%d %H:%M:%S", 
200 |         handlers=[
201 |             logging.FileHandler(os.path.join(cfg.SAVE_DIR, cfg.LOG_DEST)),
202 |             logging.StreamHandler()
203 |         ])
204 | 
205 |     np.random.seed(cfg.RNG_SEED)  # seed NumPy, PyTorch and stdlib RNGs with the same value
206 |     torch.manual_seed(cfg.RNG_SEED)
207 |     random.seed(cfg.RNG_SEED)
208 |     torch.backends.cudnn.benchmark = cfg.CUDNN.BENCHMARK
209 | 
210 |     logger = logging.getLogger(__name__)
211 |     version = [torch.__version__, torch.version.cuda,
212 |                torch.backends.cudnn.version()]
213 |     logger.info(
214 |         "PyTorch Version: torch={}, cuda={}, cudnn={}".format(*version))
215 |     logger.info(cfg)
216 | 
--------------------------------------------------------------------------------
/imagenet/conf.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | 
6 | """Configuration file (powered by YACS)."""
7 | 
8 | import argparse
9 | import os
10 | import sys
11 | import logging
12 | import random
13 | import torch
14 | import numpy as np
15 | from datetime import datetime
16 | from iopath.common.file_io import g_pathmgr
17 | from yacs.config import CfgNode as CfgNode
18 | 
19 | 
20 | # Global config object (example usage: from conf import cfg)
21 | _C = CfgNode()
22 | cfg = _C
23 | 
24 | 
25 | # ----------------------------- Model options ------------------------------- #
26 | _C.MODEL = CfgNode()
27 | 
28 | # Check https://github.com/RobustBench/robustbench for available models
29 | _C.MODEL.ARCH = 'Standard'
30 | 
31 | # Choice of (source, norm, tent, cotta)
32 | # - source: baseline without adaptation
33 | # - norm: test-time normalization
34 | # - tent: test-time entropy minimization; cotta: continual test-time adaptation (CoTTA)
35 | _C.MODEL.ADAPTATION = 'source'
36 | 
37 | # By default tent is online, with updates persisting across batches.
38 | # To make adaptation episodic, and reset the model for each batch, choose True.
39 | _C.MODEL.EPISODIC = False 40 | 41 | # ----------------------------- Corruption options -------------------------- # 42 | _C.CORRUPTION = CfgNode() 43 | 44 | # Dataset for evaluation 45 | _C.CORRUPTION.DATASET = 'cifar10' 46 | 47 | # Check https://github.com/hendrycks/robustness for corruption details 48 | _C.CORRUPTION.TYPE = ['gaussian_noise', 'shot_noise', 'impulse_noise', 49 | 'defocus_blur', 'glass_blur', 'motion_blur', 'zoom_blur', 50 | 'snow', 'frost', 'fog', 'brightness', 'contrast', 51 | 'elastic_transform', 'pixelate', 'jpeg_compression'] 52 | _C.CORRUPTION.SEVERITY = [5, 4, 3, 2, 1] 53 | 54 | # Number of examples to evaluate 55 | # The 5000 val images defined by Robustbench were actually used: 56 | # Please see https://github.com/RobustBench/robustbench/blob/7af0e34c6b383cd73ea7a1bbced358d7ce6ad22f/robustbench/data/imagenet_test_image_ids.txt 57 | _C.CORRUPTION.NUM_EX = 5000 58 | 59 | # ------------------------------- Batch norm options ------------------------ # 60 | _C.BN = CfgNode() 61 | 62 | # BN epsilon 63 | _C.BN.EPS = 1e-5 64 | 65 | # BN momentum (BN momentum in PyTorch = 1 - BN momentum in Caffe2) 66 | _C.BN.MOM = 0.1 67 | 68 | # ------------------------------- Optimizer options ------------------------- # 69 | _C.OPTIM = CfgNode() 70 | 71 | # Number of updates per batch 72 | _C.OPTIM.STEPS = 1 73 | 74 | # Learning rate 75 | _C.OPTIM.LR = 1e-3 76 | 77 | # Choices: Adam, SGD 78 | _C.OPTIM.METHOD = 'Adam' 79 | 80 | # Beta 81 | _C.OPTIM.BETA = 0.9 82 | 83 | # Momentum 84 | _C.OPTIM.MOMENTUM = 0.9 85 | 86 | # Momentum dampening 87 | _C.OPTIM.DAMPENING = 0.0 88 | 89 | # Nesterov momentum 90 | _C.OPTIM.NESTEROV = True 91 | 92 | # L2 regularization 93 | _C.OPTIM.WD = 0.0 94 | 95 | # ------------------------------- Testing options --------------------------- # 96 | _C.TEST = CfgNode() 97 | 98 | # Batch size for evaluation (and updates for norm + tent) 99 | _C.TEST.BATCH_SIZE = 128 100 | 101 | # --------------------------------- CUDNN options 
--------------------------- # 102 | _C.CUDNN = CfgNode() 103 | 104 | # Benchmark to select fastest CUDNN algorithms (best for fixed input sizes) 105 | _C.CUDNN.BENCHMARK = True 106 | 107 | # ---------------------------------- Misc options --------------------------- # 108 | 109 | # Optional description of a config 110 | _C.DESC = "" 111 | 112 | # Note that non-determinism is still present due to non-deterministic GPU ops 113 | _C.RNG_SEED = 1 114 | 115 | # Output directory 116 | _C.SAVE_DIR = "./output" 117 | 118 | # Data directory 119 | _C.DATA_DIR = "./data" 120 | 121 | # Weight directory 122 | _C.CKPT_DIR = "./ckpt" 123 | 124 | # Log destination (in SAVE_DIR) 125 | _C.LOG_DEST = "log.txt" 126 | 127 | # Log datetime 128 | _C.LOG_TIME = '' 129 | 130 | # # Config destination (in SAVE_DIR) 131 | # _C.CFG_DEST = "cfg.yaml" 132 | 133 | # --------------------------------- Default config -------------------------- # 134 | _CFG_DEFAULT = _C.clone() 135 | _CFG_DEFAULT.freeze() 136 | 137 | 138 | def assert_and_infer_cfg(): 139 | """Checks config values invariants.""" 140 | err_str = "Unknown adaptation method." 
141 | assert _C.MODEL.ADAPTATION in ["source", "norm", "tent"] 142 | err_str = "Log destination '{}' not supported" 143 | assert _C.LOG_DEST in ["stdout", "file"], err_str.format(_C.LOG_DEST) 144 | 145 | 146 | def merge_from_file(cfg_file): 147 | with g_pathmgr.open(cfg_file, "r") as f: 148 | cfg = _C.load_cfg(f) 149 | _C.merge_from_other_cfg(cfg) 150 | 151 | 152 | def dump_cfg(): 153 | """Dumps the config to the output directory.""" 154 | cfg_file = os.path.join(_C.SAVE_DIR, _C.CFG_DEST) 155 | with g_pathmgr.open(cfg_file, "w") as f: 156 | _C.dump(stream=f) 157 | 158 | 159 | def load_cfg(out_dir, cfg_dest="config.yaml"): 160 | """Loads config from specified output directory.""" 161 | cfg_file = os.path.join(out_dir, cfg_dest) 162 | merge_from_file(cfg_file) 163 | 164 | 165 | def reset_cfg(): 166 | """Reset config to initial state.""" 167 | cfg.merge_from_other_cfg(_CFG_DEFAULT) 168 | 169 | 170 | def load_cfg_fom_args(description="Config options."): 171 | """Load config from command line args and set any specified options.""" 172 | current_time = datetime.now().strftime("%y%m%d_%H%M%S") 173 | parser = argparse.ArgumentParser(description=description) 174 | parser.add_argument("--cfg", dest="cfg_file", type=str, required=True, 175 | help="Config file location") 176 | parser.add_argument("opts", default=None, nargs=argparse.REMAINDER, 177 | help="See conf.py for all options") 178 | if len(sys.argv) == 1: 179 | parser.print_help() 180 | sys.exit(1) 181 | args = parser.parse_args() 182 | 183 | merge_from_file(args.cfg_file) 184 | cfg.merge_from_list(args.opts) 185 | 186 | log_dest = os.path.basename(args.cfg_file) 187 | log_dest = log_dest.replace('.yaml', '_{}.txt'.format(current_time)) 188 | 189 | g_pathmgr.mkdirs(cfg.SAVE_DIR) 190 | cfg.LOG_TIME, cfg.LOG_DEST = current_time, log_dest 191 | cfg.freeze() 192 | 193 | logging.basicConfig( 194 | level=logging.INFO, 195 | format="[%(asctime)s] [%(filename)s: %(lineno)4d]: %(message)s", 196 | datefmt="%y/%m/%d %H:%M:%S", 
197 | handlers=[ 198 | logging.FileHandler(os.path.join(cfg.SAVE_DIR, cfg.LOG_DEST)), 199 | logging.StreamHandler() 200 | ]) 201 | 202 | np.random.seed(cfg.RNG_SEED) 203 | torch.manual_seed(cfg.RNG_SEED) 204 | random.seed(cfg.RNG_SEED) 205 | torch.backends.cudnn.benchmark = cfg.CUDNN.BENCHMARK 206 | 207 | logger = logging.getLogger(__name__) 208 | version = [torch.__version__, torch.version.cuda, 209 | torch.backends.cudnn.version()] 210 | logger.info( 211 | "PyTorch Version: torch={}, cuda={}, cudnn={}".format(*version)) 212 | logger.info(cfg) 213 | -------------------------------------------------------------------------------- /cifar/robustbench/loaders.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is based on the code from https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py. 3 | """ 4 | from torchvision.datasets.vision import VisionDataset 5 | 6 | import torch 7 | import torch.utils.data as data 8 | import torchvision.transforms as transforms 9 | 10 | from PIL import Image 11 | 12 | import os 13 | import os.path 14 | import sys 15 | import json 16 | 17 | 18 | def make_custom_dataset(root, path_imgs, cls_dict): 19 | with open(path_imgs, 'r') as f: 20 | fnames = f.readlines() 21 | with open(cls_dict, 'r') as f: 22 | class_to_idx = json.load(f) 23 | images = [(os.path.join(root, c.split('\n')[0]), class_to_idx[c.split('/')[0]]) for c in fnames] 24 | 25 | return images 26 | 27 | 28 | class CustomDatasetFolder(VisionDataset): 29 | """A generic data loader where the samples are arranged in this way: :: 30 | root/class_x/xxx.ext 31 | root/class_x/xxy.ext 32 | root/class_x/xxz.ext 33 | root/class_y/123.ext 34 | root/class_y/nsdf3.ext 35 | root/class_y/asd932_.ext 36 | Args: 37 | root (string): Root directory path. 38 | loader (callable): A function to load a sample given its path. 39 | extensions (tuple[string]): A list of allowed extensions. 
40 | both extensions and is_valid_file should not be passed. 41 | transform (callable, optional): A function/transform that takes in 42 | a sample and returns a transformed version. 43 | E.g, ``transforms.RandomCrop`` for images. 44 | target_transform (callable, optional): A function/transform that takes 45 | in the target and transforms it. 46 | is_valid_file (callable, optional): A function that takes path of an Image file 47 | and check if the file is a valid_file (used to check of corrupt files) 48 | both extensions and is_valid_file should not be passed. 49 | Attributes: 50 | classes (list): List of the class names. 51 | class_to_idx (dict): Dict with items (class_name, class_index). 52 | samples (list): List of (sample path, class_index) tuples 53 | targets (list): The class_index value for each image in the dataset 54 | """ 55 | 56 | def __init__(self, root, loader, extensions=None, transform=None, target_transform=None, is_valid_file=None): 57 | super(CustomDatasetFolder, self).__init__(root) 58 | self.transform = transform 59 | self.target_transform = target_transform 60 | classes, class_to_idx = self._find_classes(self.root) 61 | samples = make_custom_dataset(self.root, 'robustbench/data/imagenet_test_image_ids.txt', 62 | 'robustbench/data/imagenet_class_to_id_map.json') 63 | if len(samples) == 0: 64 | raise (RuntimeError("Found 0 files in subfolders of: " + self.root + "\n" 65 | "Supported extensions are: " + ",".join(extensions))) 66 | 67 | self.loader = loader 68 | self.extensions = extensions 69 | 70 | self.classes = classes 71 | self.class_to_idx = class_to_idx 72 | self.samples = samples 73 | self.targets = [s[1] for s in samples] 74 | 75 | def _find_classes(self, dir): 76 | """ 77 | Finds the class folders in a dataset. 78 | Args: 79 | dir (string): Root directory path. 80 | Returns: 81 | tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary. 82 | Ensures: 83 | No class is a subdirectory of another. 
84 | """ 85 | if sys.version_info >= (3, 5): 86 | # Faster and available in Python 3.5 and above 87 | classes = [d.name for d in os.scandir(dir) if d.is_dir()] 88 | else: 89 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 90 | classes.sort() 91 | class_to_idx = {classes[i]: i for i in range(len(classes))} 92 | return classes, class_to_idx 93 | 94 | def __getitem__(self, index): 95 | """ 96 | Args: 97 | index (int): Index 98 | Returns: 99 | tuple: (sample, target) where target is class_index of the target class. 100 | """ 101 | path, target = self.samples[index] 102 | sample = self.loader(path) 103 | if self.transform is not None: 104 | sample = self.transform(sample) 105 | if self.target_transform is not None: 106 | target = self.target_transform(target) 107 | return sample, target, path 108 | 109 | def __len__(self): 110 | return len(self.samples) 111 | 112 | 113 | IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp') 114 | 115 | 116 | def pil_loader(path): 117 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 118 | with open(path, 'rb') as f: 119 | img = Image.open(f) 120 | return img.convert('RGB') 121 | 122 | 123 | def accimage_loader(path): 124 | import accimage 125 | try: 126 | return accimage.Image(path) 127 | except IOError: 128 | # Potentially a decoding problem, fall back to PIL.Image 129 | return pil_loader(path) 130 | 131 | 132 | def default_loader(path): 133 | from torchvision import get_image_backend 134 | if get_image_backend() == 'accimage': 135 | return accimage_loader(path) 136 | else: 137 | return pil_loader(path) 138 | 139 | 140 | class CustomImageFolder(CustomDatasetFolder): 141 | """A generic data loader where the images are arranged in this way: :: 142 | root/dog/xxx.png 143 | root/dog/xxy.png 144 | root/dog/xxz.png 145 | root/cat/123.png 146 | root/cat/nsdf3.png 147 | root/cat/asd932_.png 148 | Args: 149 | root (string): 
Root directory path. 150 | transform (callable, optional): A function/transform that takes in an PIL image 151 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 152 | target_transform (callable, optional): A function/transform that takes in the 153 | target and transforms it. 154 | loader (callable, optional): A function to load an image given its path. 155 | is_valid_file (callable, optional): A function that takes path of an Image file 156 | and check if the file is a valid_file (used to check of corrupt files) 157 | Attributes: 158 | classes (list): List of the class names. 159 | class_to_idx (dict): Dict with items (class_name, class_index). 160 | imgs (list): List of (image path, class_index) tuples 161 | """ 162 | 163 | def __init__(self, root, transform=None, target_transform=None, 164 | loader=default_loader, is_valid_file=None): 165 | super(CustomImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None, 166 | transform=transform, 167 | target_transform=target_transform, 168 | is_valid_file=is_valid_file) 169 | 170 | self.imgs = self.samples 171 | 172 | 173 | if __name__ == '__main__': 174 | data_dir = '/home/scratch/datasets/imagenet/val' 175 | imagenet = CustomImageFolder(data_dir, transforms.Compose([ 176 | transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()])) 177 | 178 | torch.manual_seed(0) 179 | 180 | test_loader = data.DataLoader(imagenet, batch_size=5000, shuffle=True, num_workers=30) 181 | 182 | x, y, path = next(iter(test_loader)) 183 | 184 | with open('path_imgs_2.txt', 'w') as f: 185 | f.write('\n'.join(path)) 186 | f.flush() 187 | 188 | -------------------------------------------------------------------------------- /imagenet/robustbench/loaders.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is based on the code from https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py. 
3 | """ 4 | from torchvision.datasets.vision import VisionDataset 5 | 6 | import torch 7 | import torch.utils.data as data 8 | import torchvision.transforms as transforms 9 | 10 | from PIL import Image 11 | 12 | import os 13 | import os.path 14 | import sys 15 | import json 16 | 17 | 18 | def make_custom_dataset(root, path_imgs, cls_dict): 19 | with open(path_imgs, 'r') as f: 20 | fnames = f.readlines() 21 | with open(cls_dict, 'r') as f: 22 | class_to_idx = json.load(f) 23 | images = [(os.path.join(root, c.split('\n')[0]), class_to_idx[c.split('/')[0]]) for c in fnames] 24 | 25 | return images 26 | 27 | 28 | class CustomDatasetFolder(VisionDataset): 29 | """A generic data loader where the samples are arranged in this way: :: 30 | root/class_x/xxx.ext 31 | root/class_x/xxy.ext 32 | root/class_x/xxz.ext 33 | root/class_y/123.ext 34 | root/class_y/nsdf3.ext 35 | root/class_y/asd932_.ext 36 | Args: 37 | root (string): Root directory path. 38 | loader (callable): A function to load a sample given its path. 39 | extensions (tuple[string]): A list of allowed extensions. 40 | both extensions and is_valid_file should not be passed. 41 | transform (callable, optional): A function/transform that takes in 42 | a sample and returns a transformed version. 43 | E.g, ``transforms.RandomCrop`` for images. 44 | target_transform (callable, optional): A function/transform that takes 45 | in the target and transforms it. 46 | is_valid_file (callable, optional): A function that takes path of an Image file 47 | and check if the file is a valid_file (used to check of corrupt files) 48 | both extensions and is_valid_file should not be passed. 49 | Attributes: 50 | classes (list): List of the class names. 51 | class_to_idx (dict): Dict with items (class_name, class_index). 
52 | samples (list): List of (sample path, class_index) tuples 53 | targets (list): The class_index value for each image in the dataset 54 | """ 55 | 56 | def __init__(self, root, loader, extensions=None, transform=None, target_transform=None, is_valid_file=None): 57 | super(CustomDatasetFolder, self).__init__(root) 58 | self.transform = transform 59 | self.target_transform = target_transform 60 | classes, class_to_idx = self._find_classes(self.root) 61 | samples = make_custom_dataset(self.root, 'robustbench/data/imagenet_test_image_ids.txt', 62 | 'robustbench/data/imagenet_class_to_id_map.json') 63 | if len(samples) == 0: 64 | raise (RuntimeError("Found 0 files in subfolders of: " + self.root + "\n" 65 | "Supported extensions are: " + ",".join(extensions))) 66 | 67 | self.loader = loader 68 | self.extensions = extensions 69 | 70 | self.classes = classes 71 | self.class_to_idx = class_to_idx 72 | self.samples = samples 73 | self.targets = [s[1] for s in samples] 74 | 75 | def _find_classes(self, dir): 76 | """ 77 | Finds the class folders in a dataset. 78 | Args: 79 | dir (string): Root directory path. 80 | Returns: 81 | tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary. 82 | Ensures: 83 | No class is a subdirectory of another. 84 | """ 85 | if sys.version_info >= (3, 5): 86 | # Faster and available in Python 3.5 and above 87 | classes = [d.name for d in os.scandir(dir) if d.is_dir()] 88 | else: 89 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 90 | classes.sort() 91 | class_to_idx = {classes[i]: i for i in range(len(classes))} 92 | return classes, class_to_idx 93 | 94 | def __getitem__(self, index): 95 | """ 96 | Args: 97 | index (int): Index 98 | Returns: 99 | tuple: (sample, target) where target is class_index of the target class. 
100 | """ 101 | path, target = self.samples[index] 102 | sample = self.loader(path) 103 | if self.transform is not None: 104 | sample = self.transform(sample) 105 | if self.target_transform is not None: 106 | target = self.target_transform(target) 107 | return sample, target, path 108 | 109 | def __len__(self): 110 | return len(self.samples) 111 | 112 | 113 | IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp') 114 | 115 | 116 | def pil_loader(path): 117 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 118 | with open(path, 'rb') as f: 119 | img = Image.open(f) 120 | return img.convert('RGB') 121 | 122 | 123 | def accimage_loader(path): 124 | import accimage 125 | try: 126 | return accimage.Image(path) 127 | except IOError: 128 | # Potentially a decoding problem, fall back to PIL.Image 129 | return pil_loader(path) 130 | 131 | 132 | def default_loader(path): 133 | from torchvision import get_image_backend 134 | if get_image_backend() == 'accimage': 135 | return accimage_loader(path) 136 | else: 137 | return pil_loader(path) 138 | 139 | 140 | class CustomImageFolder(CustomDatasetFolder): 141 | """A generic data loader where the images are arranged in this way: :: 142 | root/dog/xxx.png 143 | root/dog/xxy.png 144 | root/dog/xxz.png 145 | root/cat/123.png 146 | root/cat/nsdf3.png 147 | root/cat/asd932_.png 148 | Args: 149 | root (string): Root directory path. 150 | transform (callable, optional): A function/transform that takes in an PIL image 151 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 152 | target_transform (callable, optional): A function/transform that takes in the 153 | target and transforms it. 154 | loader (callable, optional): A function to load an image given its path. 
155 | is_valid_file (callable, optional): A function that takes path of an Image file 156 | and check if the file is a valid_file (used to check of corrupt files) 157 | Attributes: 158 | classes (list): List of the class names. 159 | class_to_idx (dict): Dict with items (class_name, class_index). 160 | imgs (list): List of (image path, class_index) tuples 161 | """ 162 | 163 | def __init__(self, root, transform=None, target_transform=None, 164 | loader=default_loader, is_valid_file=None): 165 | super(CustomImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None, 166 | transform=transform, 167 | target_transform=target_transform, 168 | is_valid_file=is_valid_file) 169 | 170 | self.imgs = self.samples 171 | 172 | 173 | if __name__ == '__main__': 174 | data_dir = '/home/scratch/datasets/imagenet/val' 175 | imagenet = CustomImageFolder(data_dir, transforms.Compose([ 176 | transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()])) 177 | 178 | torch.manual_seed(0) 179 | 180 | test_loader = data.DataLoader(imagenet, batch_size=5000, shuffle=True, num_workers=30) 181 | 182 | x, y, path = next(iter(test_loader)) 183 | 184 | with open('path_imgs_2.txt', 'w') as f: 185 | f.write('\n'.join(path)) 186 | f.flush() 187 | 188 | -------------------------------------------------------------------------------- /cifar/cotta.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.jit 6 | 7 | import PIL 8 | import torchvision.transforms as transforms 9 | import my_transforms as my_transforms 10 | from time import time 11 | import logging 12 | 13 | 14 | def get_tta_transforms(gaussian_std: float=0.005, soft=False, clip_inputs=False): 15 | img_shape = (32, 32, 3) 16 | n_pixels = img_shape[0] 17 | 18 | clip_min, clip_max = 0.0, 1.0 19 | 20 | p_hflip = 0.5 21 | 22 | tta_transforms = transforms.Compose([ 23 | 
my_transforms.Clip(0.0, 1.0), 24 | my_transforms.ColorJitterPro( 25 | brightness=[0.8, 1.2] if soft else [0.6, 1.4], 26 | contrast=[0.85, 1.15] if soft else [0.7, 1.3], 27 | saturation=[0.75, 1.25] if soft else [0.5, 1.5], 28 | hue=[-0.03, 0.03] if soft else [-0.06, 0.06], 29 | gamma=[0.85, 1.15] if soft else [0.7, 1.3] 30 | ), 31 | transforms.Pad(padding=int(n_pixels / 2), padding_mode='edge'), 32 | transforms.RandomAffine( 33 | degrees=[-8, 8] if soft else [-15, 15], 34 | translate=(1/16, 1/16), 35 | scale=(0.95, 1.05) if soft else (0.9, 1.1), 36 | shear=None, 37 | resample=PIL.Image.BILINEAR, 38 | fillcolor=None 39 | ), 40 | transforms.GaussianBlur(kernel_size=5, sigma=[0.001, 0.25] if soft else [0.001, 0.5]), 41 | transforms.CenterCrop(size=n_pixels), 42 | transforms.RandomHorizontalFlip(p=p_hflip), 43 | my_transforms.GaussianNoise(0, gaussian_std), 44 | my_transforms.Clip(clip_min, clip_max) 45 | ]) 46 | return tta_transforms 47 | 48 | 49 | def update_ema_variables(ema_model, model, alpha_teacher): 50 | for ema_param, param in zip(ema_model.parameters(), model.parameters()): 51 | ema_param.data[:] = alpha_teacher * ema_param[:].data[:] + (1 - alpha_teacher) * param[:].data[:] 52 | return ema_model 53 | 54 | 55 | class CoTTA(nn.Module): 56 | """CoTTA adapts a model by entropy minimization during testing. 57 | 58 | Once tented, a model adapts itself by updating on every forward. 
59 | """ 60 | def __init__(self, model, optimizer, steps=1, episodic=False, mt_alpha=0.99, rst_m=0.1, ap=0.9): 61 | super().__init__() 62 | self.model = model 63 | self.optimizer = optimizer 64 | self.steps = steps 65 | assert steps > 0, "cotta requires >= 1 step(s) to forward and update" 66 | self.episodic = episodic 67 | 68 | self.model_state, self.optimizer_state, self.model_ema, self.model_anchor = \ 69 | copy_model_and_optimizer(self.model, self.optimizer) 70 | self.transform = get_tta_transforms() 71 | self.mt = mt_alpha 72 | self.rst = rst_m 73 | self.ap = ap 74 | 75 | def forward(self, x): 76 | if self.episodic: 77 | self.reset() 78 | 79 | for _ in range(self.steps): 80 | outputs = self.forward_and_adapt(x, self.model, self.optimizer) 81 | 82 | return outputs 83 | 84 | def reset(self): 85 | if self.model_state is None or self.optimizer_state is None: 86 | raise Exception("cannot reset without saved model/optimizer state") 87 | load_model_and_optimizer(self.model, self.optimizer, 88 | self.model_state, self.optimizer_state) 89 | # Use this line to also restore the teacher model 90 | self.model_state, self.optimizer_state, self.model_ema, self.model_anchor = \ 91 | copy_model_and_optimizer(self.model, self.optimizer) 92 | 93 | 94 | @torch.enable_grad() # ensure grads in possible no grad context for testing 95 | def forward_and_adapt(self, x, model, optimizer): 96 | outputs = self.model(x) 97 | # Teacher Prediction 98 | anchor_prob = torch.nn.functional.softmax(self.model_anchor(x), dim=1).max(1)[0] 99 | standard_ema = self.model_ema(x) 100 | # Augmentation-averaged Prediction 101 | N = 32 102 | outputs_emas = [] 103 | for i in range(N): 104 | outputs_ = self.model_ema(self.transform(x)).detach() 105 | outputs_emas.append(outputs_) 106 | # Threshold choice discussed in supplementary 107 | if anchor_prob.mean(0) torch.Tensor: 131 | """Entropy of softmax distribution from logits.""" 132 | return -(x_ema.softmax(1) * x.log_softmax(1)).sum(1) 133 | 134 | def 
collect_params(model): 135 | """Collect all trainable parameters. 136 | 137 | Walk the model's modules and collect all parameters. 138 | Return the parameters and their names. 139 | 140 | Note: other choices of parameterization are possible! 141 | """ 142 | params = [] 143 | names = [] 144 | for nm, m in model.named_modules(): 145 | if True:#isinstance(m, nn.BatchNorm2d): collect all 146 | for np, p in m.named_parameters(): 147 | if np in ['weight', 'bias'] and p.requires_grad: 148 | params.append(p) 149 | names.append(f"{nm}.{np}") 150 | print(nm, np) 151 | return params, names 152 | 153 | 154 | def copy_model_and_optimizer(model, optimizer): 155 | """Copy the model and optimizer states for resetting after adaptation.""" 156 | model_state = deepcopy(model.state_dict()) 157 | model_anchor = deepcopy(model) 158 | optimizer_state = deepcopy(optimizer.state_dict()) 159 | ema_model = deepcopy(model) 160 | for param in ema_model.parameters(): 161 | param.detach_() 162 | return model_state, optimizer_state, ema_model, model_anchor 163 | 164 | 165 | def load_model_and_optimizer(model, optimizer, model_state, optimizer_state): 166 | """Restore the model and optimizer states from copies.""" 167 | model.load_state_dict(model_state, strict=True) 168 | optimizer.load_state_dict(optimizer_state) 169 | 170 | 171 | def configure_model(model): 172 | """Configure model for use with tent.""" 173 | # train mode, because tent optimizes the model to minimize entropy 174 | model.train() 175 | # disable grad, to (re-)enable only what we update 176 | model.requires_grad_(False) 177 | # enable all trainable 178 | for m in model.modules(): 179 | if isinstance(m, nn.BatchNorm2d): 180 | m.requires_grad_(True) 181 | # force use of batch stats in train and eval modes 182 | m.track_running_stats = False 183 | m.running_mean = None 184 | m.running_var = None 185 | else: 186 | m.requires_grad_(True) 187 | return model 188 | 189 | 190 | def check_model(model): 191 | """Check model for compatability 
with tent.""" 192 | is_training = model.training 193 | assert is_training, "tent needs train mode: call model.train()" 194 | param_grads = [p.requires_grad for p in model.parameters()] 195 | has_any_params = any(param_grads) 196 | has_all_params = all(param_grads) 197 | assert has_any_params, "tent needs params to update: " \ 198 | "check which require grad" 199 | assert not has_all_params, "tent should not update all params: " \ 200 | "check which require grad" 201 | has_bn = any([isinstance(m, nn.BatchNorm2d) for m in model.modules()]) 202 | assert has_bn, "tent needs normalization for its optimization" 203 | -------------------------------------------------------------------------------- /imagenet/cotta.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.jit 6 | 7 | import PIL 8 | import torchvision.transforms as transforms 9 | import my_transforms as my_transforms 10 | from time import time 11 | import logging 12 | 13 | 14 | def get_tta_transforms(gaussian_std: float=0.005, soft=False, clip_inputs=False): 15 | img_shape = (224, 224, 3) 16 | n_pixels = img_shape[0] 17 | 18 | clip_min, clip_max = 0.0, 1.0 19 | 20 | p_hflip = 0.5 21 | 22 | tta_transforms = transforms.Compose([ 23 | my_transforms.Clip(0.0, 1.0), 24 | my_transforms.ColorJitterPro( 25 | brightness=[0.8, 1.2] if soft else [0.6, 1.4], 26 | contrast=[0.85, 1.15] if soft else [0.7, 1.3], 27 | saturation=[0.75, 1.25] if soft else [0.5, 1.5], 28 | hue=[-0.03, 0.03] if soft else [-0.06, 0.06], 29 | gamma=[0.85, 1.15] if soft else [0.7, 1.3] 30 | ), 31 | transforms.Pad(padding=int(n_pixels / 2), padding_mode='edge'), 32 | transforms.RandomAffine( 33 | degrees=[-8, 8] if soft else [-15, 15], 34 | translate=(1/16, 1/16), 35 | scale=(0.95, 1.05) if soft else (0.9, 1.1), 36 | shear=None, 37 | resample=PIL.Image.BILINEAR, 38 | fillcolor=None 39 | ), 40 | 
transforms.GaussianBlur(kernel_size=5, sigma=[0.001, 0.25] if soft else [0.001, 0.5]), 41 | transforms.CenterCrop(size=n_pixels), 42 | transforms.RandomHorizontalFlip(p=p_hflip), 43 | my_transforms.GaussianNoise(0, gaussian_std), 44 | my_transforms.Clip(clip_min, clip_max) 45 | ]) 46 | return tta_transforms 47 | 48 | 49 | def update_ema_variables(ema_model, model, alpha_teacher):#, iteration): 50 | for ema_param, param in zip(ema_model.parameters(), model.parameters()): 51 | ema_param.data[:] = alpha_teacher * ema_param[:].data[:] + (1 - alpha_teacher) * param[:].data[:] 52 | return ema_model 53 | 54 | 55 | class CoTTA(nn.Module): 56 | """CoTTA adapts a model by entropy minimization during testing. 57 | 58 | Once tented, a model adapts itself by updating on every forward. 59 | """ 60 | def __init__(self, model, optimizer, steps=1, episodic=False): 61 | super().__init__() 62 | self.model = model 63 | self.optimizer = optimizer 64 | self.steps = steps 65 | assert steps > 0, "cotta requires >= 1 step(s) to forward and update" 66 | self.episodic = episodic 67 | 68 | self.model_state, self.optimizer_state, self.model_ema, self.model_anchor = \ 69 | copy_model_and_optimizer(self.model, self.optimizer) 70 | self.transform = get_tta_transforms() 71 | 72 | def forward(self, x): 73 | if self.episodic: 74 | self.reset() 75 | 76 | for _ in range(self.steps): 77 | outputs = self.forward_and_adapt(x, self.model, self.optimizer) 78 | 79 | return outputs 80 | 81 | def reset(self): 82 | if self.model_state is None or self.optimizer_state is None: 83 | raise Exception("cannot reset without saved model/optimizer state") 84 | load_model_and_optimizer(self.model, self.optimizer, 85 | self.model_state, self.optimizer_state) 86 | # use this line if you want to reset the teacher model as well. Maybe you also 87 | # want to del self.model_ema first to save gpu memory. 
88 | self.model_state, self.optimizer_state, self.model_ema, self.model_anchor = \ 89 | copy_model_and_optimizer(self.model, self.optimizer) 90 | 91 | 92 | @torch.enable_grad() # ensure grads in possible no grad context for testing 93 | def forward_and_adapt(self, x, model, optimizer): 94 | outputs = self.model(x) 95 | self.model_ema.train() 96 | # Teacher Prediction 97 | anchor_prob = torch.nn.functional.softmax(self.model_anchor(x), dim=1).max(1)[0] 98 | standard_ema = self.model_ema(x) 99 | # Augmentation-averaged Prediction 100 | N = 32 101 | outputs_emas = [] 102 | to_aug = anchor_prob.mean(0)<0.1 103 | if to_aug: 104 | for i in range(N): 105 | outputs_ = self.model_ema(self.transform(x)).detach() 106 | outputs_emas.append(outputs_) 107 | # Threshold choice discussed in supplementary 108 | if to_aug: 109 | outputs_ema = torch.stack(outputs_emas).mean(0) 110 | else: 111 | outputs_ema = standard_ema 112 | # Augmentation-averaged Prediction 113 | # Student update 114 | loss = (softmax_entropy(outputs, outputs_ema.detach())).mean(0) 115 | loss.backward() 116 | optimizer.step() 117 | optimizer.zero_grad() 118 | # Teacher update 119 | self.model_ema = update_ema_variables(ema_model = self.model_ema, model = self.model, alpha_teacher=0.999) 120 | # Stochastic restore 121 | if True: 122 | for nm, m in self.model.named_modules(): 123 | for npp, p in m.named_parameters(): 124 | if npp in ['weight', 'bias'] and p.requires_grad: 125 | mask = (torch.rand(p.shape)<0.001).float().cuda() 126 | with torch.no_grad(): 127 | p.data = self.model_state[f"{nm}.{npp}"] * mask + p * (1.-mask) 128 | return outputs_ema 129 | 130 | 131 | @torch.jit.script 132 | def softmax_entropy(x, x_ema):# -> torch.Tensor: 133 | """Entropy of softmax distribution from logits.""" 134 | return -0.5*(x_ema.softmax(1) * x.log_softmax(1)).sum(1)-0.5*(x.softmax(1) * x_ema.log_softmax(1)).sum(1) 135 | 136 | def collect_params(model): 137 | """Collect all trainable parameters. 
138 | 139 | Walk the model's modules and collect all parameters. 140 | Return the parameters and their names. 141 | 142 | Note: other choices of parameterization are possible! 143 | """ 144 | params = [] 145 | names = [] 146 | for nm, m in model.named_modules(): 147 | if True:#isinstance(m, nn.BatchNorm2d): collect all 148 | for np, p in m.named_parameters(): 149 | if np in ['weight', 'bias'] and p.requires_grad: 150 | params.append(p) 151 | names.append(f"{nm}.{np}") 152 | print(nm, np) 153 | return params, names 154 | 155 | 156 | def copy_model_and_optimizer(model, optimizer): 157 | """Copy the model and optimizer states for resetting after adaptation.""" 158 | model_state = deepcopy(model.state_dict()) 159 | model_anchor = deepcopy(model) 160 | optimizer_state = deepcopy(optimizer.state_dict()) 161 | ema_model = deepcopy(model) 162 | for param in ema_model.parameters(): 163 | param.detach_() 164 | return model_state, optimizer_state, ema_model, model_anchor 165 | 166 | 167 | def load_model_and_optimizer(model, optimizer, model_state, optimizer_state): 168 | """Restore the model and optimizer states from copies.""" 169 | model.load_state_dict(model_state, strict=True) 170 | optimizer.load_state_dict(optimizer_state) 171 | 172 | 173 | def configure_model(model): 174 | """Configure model for use with tent.""" 175 | # train mode, because tent optimizes the model to minimize entropy 176 | model.train() 177 | # disable grad, to (re-)enable only what we update 178 | model.requires_grad_(False) 179 | # enable all trainable 180 | for m in model.modules(): 181 | if isinstance(m, nn.BatchNorm2d): 182 | m.requires_grad_(True) 183 | # force use of batch stats in train and eval modes 184 | m.track_running_stats = False 185 | m.running_mean = None 186 | m.running_var = None 187 | else: 188 | m.requires_grad_(True) 189 | return model 190 | 191 | 192 | def check_model(model): 193 | """Check model for compatability with tent.""" 194 | is_training = model.training 195 | assert 
is_training, "tent needs train mode: call model.train()" 196 | param_grads = [p.requires_grad for p in model.parameters()] 197 | has_any_params = any(param_grads) 198 | has_all_params = all(param_grads) 199 | assert has_any_params, "tent needs params to update: " \ 200 | "check which require grad" 201 | assert not has_all_params, "tent should not update all params: " \ 202 | "check which require grad" 203 | has_bn = any([isinstance(m, nn.BatchNorm2d) for m in model.modules()]) 204 | assert has_bn, "tent needs normalization for its optimization" 205 | -------------------------------------------------------------------------------- /cifar/robustbench/eval.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from argparse import Namespace 3 | from pathlib import Path 4 | from typing import Dict, Optional, Sequence, Tuple, Union 5 | 6 | import numpy as np 7 | import pandas as pd 8 | import torch 9 | import random 10 | from autoattack import AutoAttack 11 | from torch import nn 12 | from tqdm import tqdm 13 | 14 | from robustbench.data import CORRUPTIONS, load_clean_dataset, \ 15 | CORRUPTION_DATASET_LOADERS 16 | from robustbench.model_zoo.enums import BenchmarkDataset, ThreatModel 17 | from robustbench.utils import clean_accuracy, load_model, parse_args, update_json 18 | from robustbench.model_zoo import model_dicts as all_models 19 | 20 | 21 | def benchmark(model: Union[nn.Module, Sequence[nn.Module]], 22 | n_examples: int = 10000, 23 | dataset: Union[str, 24 | BenchmarkDataset] = BenchmarkDataset.cifar_10, 25 | threat_model: Union[str, ThreatModel] = ThreatModel.Linf, 26 | to_disk: bool = False, 27 | model_name: Optional[str] = None, 28 | data_dir: str = "./data", 29 | device: Optional[Union[torch.device, 30 | Sequence[torch.device]]] = None, 31 | batch_size: int = 32, 32 | eps: Optional[float] = None, 33 | log_path: Optional[str] = None) -> Tuple[float, float]: 34 | """Benchmarks the given model(s). 
35 | 36 | It is possible to benchmark on 3 different threat models, and to save the results on disk. In 37 | the future benchmarking multiple models in parallel is going to be possible. 38 | 39 | :param model: The model to benchmark. 40 | :param n_examples: The number of examples to use to benchmark the model. 41 | :param dataset: The dataset to use to benchmark. Must be one of {cifar10, cifar100} 42 | :param threat_model: The threat model to use to benchmark, must be one of {L2, Linf 43 | corruptions} 44 | :param to_disk: Whether the results must be saved on disk as .json. 45 | :param model_name: The name of the model to use to save the results. Must be specified if 46 | to_json is True. 47 | :param data_dir: The directory where the dataset is or where the dataset must be downloaded. 48 | :param device: The device to run the computations. 49 | :param batch_size: The batch size to run the computations. The larger, the faster the 50 | evaluation. 51 | :param eps: The epsilon to use for L2 and Linf threat models. Must not be specified for 52 | corruptions threat model. 53 | 54 | :return: A Tuple with the clean accuracy and the accuracy in the given threat model. 
55 | """ 56 | if isinstance(model, Sequence) or isinstance(device, Sequence): 57 | # Multiple models evaluation in parallel not yet implemented 58 | raise NotImplementedError 59 | 60 | try: 61 | if model.training: 62 | warnings.warn(Warning("The given model is *not* in eval mode.")) 63 | except AttributeError: 64 | warnings.warn( 65 | Warning( 66 | "It is not possible to asses if the model is in eval mode")) 67 | 68 | dataset_: BenchmarkDataset = BenchmarkDataset(dataset) 69 | threat_model_: ThreatModel = ThreatModel(threat_model) 70 | 71 | device = device or torch.device("cpu") 72 | model = model.to(device) 73 | 74 | if dataset == 'imagenet': 75 | prepr = all_models[dataset_][threat_model_][model_name]['preprocessing'] 76 | else: 77 | prepr = 'none' 78 | 79 | clean_x_test, clean_y_test = load_clean_dataset(dataset_, n_examples, 80 | data_dir, prepr) 81 | 82 | accuracy = clean_accuracy(model, 83 | clean_x_test, 84 | clean_y_test, 85 | batch_size=batch_size, 86 | device=device) 87 | print(f'Clean accuracy: {accuracy:.2%}') 88 | 89 | if threat_model_ in {ThreatModel.Linf, ThreatModel.L2}: 90 | if eps is None: 91 | raise ValueError( 92 | "If the threat model is L2 or Linf, `eps` must be specified.") 93 | 94 | adversary = AutoAttack(model, 95 | norm=threat_model_.value, 96 | eps=eps, 97 | version='standard', 98 | device=device, 99 | log_path=log_path) 100 | x_adv = adversary.run_standard_evaluation(clean_x_test, clean_y_test, bs=batch_size) 101 | adv_accuracy = clean_accuracy(model, 102 | x_adv, 103 | clean_y_test, 104 | batch_size=batch_size, 105 | device=device) 106 | elif threat_model_ == ThreatModel.corruptions: 107 | corruptions = CORRUPTIONS 108 | print(f"Evaluating over {len(corruptions)} corruptions") 109 | # Save into a dict to make a Pandas DF with nested index 110 | adv_accuracy = corruptions_evaluation(batch_size, data_dir, dataset_, 111 | device, model, n_examples, 112 | to_disk, prepr, model_name) 113 | else: 114 | raise NotImplementedError 115 | 
print(f'Adversarial accuracy: {adv_accuracy:.2%}') 116 | 117 | if to_disk: 118 | if model_name is None: 119 | raise ValueError( 120 | "If `to_disk` is True, `model_name` should be specified.") 121 | 122 | update_json(dataset_, threat_model_, model_name, accuracy, 123 | adv_accuracy, eps) 124 | 125 | return accuracy, adv_accuracy 126 | 127 | 128 | def corruptions_evaluation(batch_size: int, data_dir: str, 129 | dataset: BenchmarkDataset, device: torch.device, 130 | model: nn.Module, n_examples: int, to_disk: bool, 131 | prepr: str, model_name: Optional[str]) -> float: 132 | if to_disk and model_name is None: 133 | raise ValueError( 134 | "If `to_disk` is True, `model_name` should be specified.") 135 | 136 | corruptions = CORRUPTIONS 137 | model_results_dict: Dict[Tuple[str, int], float] = {} 138 | for corruption in tqdm(corruptions): 139 | for severity in range(1, 6): 140 | x_corrupt, y_corrupt = CORRUPTION_DATASET_LOADERS[dataset]( 141 | n_examples, 142 | severity, 143 | data_dir, 144 | shuffle=False, 145 | corruptions=[corruption], 146 | prepr=prepr) 147 | 148 | corruption_severity_accuracy = clean_accuracy( 149 | model, 150 | x_corrupt, 151 | y_corrupt, 152 | batch_size=batch_size, 153 | device=device) 154 | print('corruption={}, severity={}: {:.2%} accuracy'.format( 155 | corruption, severity, corruption_severity_accuracy)) 156 | 157 | model_results_dict[(corruption, 158 | severity)] = corruption_severity_accuracy 159 | 160 | model_results = pd.DataFrame(model_results_dict, index=[model_name]) 161 | adv_accuracy = model_results.values.mean() 162 | 163 | if not to_disk: 164 | return adv_accuracy 165 | 166 | # Save disaggregated results on disk 167 | existing_results_path = Path( 168 | "model_info" 169 | ) / dataset.value / "corruptions" / "unaggregated_results.csv" 170 | if not existing_results_path.parent.exists(): 171 | existing_results_path.parent.mkdir(parents=True, exist_ok=True) 172 | try: 173 | existing_results = pd.read_csv(existing_results_path, 174 | 
header=[0, 1], 175 | index_col=0) 176 | existing_results.columns = existing_results.columns.set_levels([ 177 | existing_results.columns.levels[0], 178 | existing_results.columns.levels[1].astype(int) 179 | ]) 180 | full_results = pd.concat([existing_results, model_results]) 181 | except FileNotFoundError: 182 | full_results = model_results 183 | full_results.to_csv(existing_results_path) 184 | 185 | return adv_accuracy 186 | 187 | 188 | def main(args: Namespace) -> None: 189 | torch.manual_seed(args.seed) 190 | torch.cuda.manual_seed(args.seed) 191 | np.random.seed(args.seed) 192 | random.seed(args.seed) 193 | 194 | model = load_model(args.model_name, 195 | model_dir=args.model_dir, 196 | dataset=args.dataset, 197 | threat_model=args.threat_model) 198 | 199 | model.eval() 200 | 201 | device = torch.device(args.device) 202 | benchmark(model, 203 | n_examples=args.n_ex, 204 | dataset=args.dataset, 205 | threat_model=args.threat_model, 206 | to_disk=args.to_disk, 207 | model_name=args.model_name, 208 | data_dir=args.data_dir, 209 | device=device, 210 | batch_size=args.batch_size, 211 | eps=args.eps) 212 | 213 | 214 | if __name__ == '__main__': 215 | # Example: 216 | # python -m robustbench.eval --n_ex=5000 --dataset=imagenet --threat_model=Linf \ 217 | # --model_name=Salman2020Do_R18 --data_dir=/tmldata1/andriush/imagenet/val \ 218 | # --batch_size=128 --eps=0.0156862745 219 | args_ = parse_args() 220 | main(args_) 221 | --------------------------------------------------------------------------------