├── cifar
├── robustbench
│ ├── leaderboard
│ │ ├── __init__.py
│ │ ├── leaderboard.html.j2
│ │ └── template.py
│ ├── model_zoo
│ │ ├── architectures
│ │ │ ├── __init__.py
│ │ │ ├── utils_architectures.py
│ │ │ ├── wide_resnet.py
│ │ │ └── resnext.py
│ │ ├── __init__.py
│ │ ├── enums.py
│ │ ├── models.py
│ │ └── imagenet.py
│ ├── __init__.py
│ ├── zenodo_download.py
│ ├── loaders.py
│ └── eval.py
├── run_cifar10_gradual.sh
├── run_cifar10.sh
├── run_cifar100.sh
├── cfgs
│ ├── cifar10
│ │ ├── norm.yaml
│ │ ├── source.yaml
│ │ ├── tent.yaml
│ │ └── cotta.yaml
│ ├── cifar100
│ │ ├── norm.yaml
│ │ ├── source.yaml
│ │ ├── tent.yaml
│ │ └── cotta.yaml
│ └── 10orders
│ │ ├── tent
│ │ │ ├── tent0.yaml
│ │ │ ├── tent1.yaml
│ │ │ ├── tent2.yaml
│ │ │ ├── tent3.yaml
│ │ │ ├── tent4.yaml
│ │ │ ├── tent5.yaml
│ │ │ ├── tent6.yaml
│ │ │ ├── tent7.yaml
│ │ │ ├── tent8.yaml
│ │ │ └── tent9.yaml
│ │ └── cotta
│ │ │ ├── cotta0.yaml
│ │ │ ├── cotta1.yaml
│ │ │ ├── cotta2.yaml
│ │ │ ├── cotta3.yaml
│ │ │ ├── cotta4.yaml
│ │ │ ├── cotta5.yaml
│ │ │ ├── cotta6.yaml
│ │ │ ├── cotta7.yaml
│ │ │ ├── cotta8.yaml
│ │ │ └── cotta9.yaml
├── norm.py
├── tent.py
├── my_transforms.py
├── cifar100c.py
├── cifar10c.py
├── cifar10c_gradual.py
├── utils.py
├── conf.py
└── cotta.py
├── imagenet
├── robustbench
│ ├── leaderboard
│ │ ├── __init__.py
│ │ ├── leaderboard.html.j2
│ │ └── template.py
│ ├── model_zoo
│ │ ├── architectures
│ │ │ ├── __init__.py
│ │ │ ├── utils_architectures.py
│ │ │ ├── wide_resnet.py
│ │ │ └── resnext.py
│ │ ├── __init__.py
│ │ ├── enums.py
│ │ ├── models.py
│ │ └── imagenet.py
│ ├── __init__.py
│ ├── zenodo_download.py
│ └── loaders.py
├── .gitignore
├── cfgs
│ ├── norm.yaml
│ ├── source.yaml
│ └── 10orders
│ │ ├── cotta
│ │ │ ├── cotta0.yaml
│ │ │ ├── cotta1.yaml
│ │ │ ├── cotta2.yaml
│ │ │ ├── cotta3.yaml
│ │ │ ├── cotta4.yaml
│ │ │ ├── cotta5.yaml
│ │ │ ├── cotta6.yaml
│ │ │ ├── cotta7.yaml
│ │ │ ├── cotta8.yaml
│ │ │ └── cotta9.yaml
│ │ └── tent
│ │ │ ├── tent0.yaml
│ │ │ ├── tent1.yaml
│ │ │ ├── tent2.yaml
│ │ │ ├── tent3.yaml
│ │ │ ├── tent4.yaml
│ │ │ ├── tent5.yaml
│ │ │ ├── tent6.yaml
│ │ │ ├── tent7.yaml
│ │ │ ├── tent8.yaml
│ │ │ └── tent9.yaml
├── data
│ └── ImageNet-C
│ │ └── download.sh
├── run.sh
├── eval.py
├── norm.py
├── tent.py
├── my_transforms.py
├── imagenetc.py
├── conf.py
└── cotta.py
├── setup_env.sh
├── LICENSE
├── .gitignore
├── README.md
└── environment.yml
/cifar/robustbench/leaderboard/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/imagenet/robustbench/leaderboard/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/cifar/robustbench/model_zoo/architectures/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/imagenet/robustbench/model_zoo/architectures/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/cifar/robustbench/model_zoo/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import model_dicts
2 |
3 |
--------------------------------------------------------------------------------
/imagenet/robustbench/model_zoo/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import model_dicts
2 |
3 |
--------------------------------------------------------------------------------
/setup_env.sh:
--------------------------------------------------------------------------------
1 | conda update conda
2 | conda env create -f environment.yml
3 |
4 |
5 |
--------------------------------------------------------------------------------
/cifar/robustbench/__init__.py:
--------------------------------------------------------------------------------
1 | from .data import load_cifar10
2 | from .utils import load_model
3 | from .eval import benchmark
4 |
--------------------------------------------------------------------------------
/imagenet/robustbench/__init__.py:
--------------------------------------------------------------------------------
1 | from .data import load_cifar10
2 | from .utils import load_model
3 | from .eval import benchmark
4 |
--------------------------------------------------------------------------------
/imagenet/.gitignore:
--------------------------------------------------------------------------------
1 | ## General
2 |
3 | *.pyc
4 | __pycache__/
5 | .ropeproject/
6 |
7 | *.pkl
8 | *.npy
9 |
10 | *.ipynb_checkpoints
11 |
12 | .envrc
13 |
14 | ## Project
15 |
16 | *.pth
17 |
18 | /notebooks/*wild.ipynb
19 | /experiments
20 |
--------------------------------------------------------------------------------
/cifar/robustbench/model_zoo/enums.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 |
4 | class BenchmarkDataset(Enum):
5 | cifar_10 = 'cifar10'
6 | cifar_100 = 'cifar100'
7 | imagenet = 'imagenet'
8 |
9 |
10 | class ThreatModel(Enum):
11 | Linf = "Linf"
12 | L2 = "L2"
13 | corruptions = "corruptions"
14 |
--------------------------------------------------------------------------------
/imagenet/robustbench/model_zoo/enums.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 |
4 | class BenchmarkDataset(Enum):
5 | cifar_10 = 'cifar10'
6 | cifar_100 = 'cifar100'
7 | imagenet = 'imagenet'
8 |
9 |
10 | class ThreatModel(Enum):
11 | Linf = "Linf"
12 | L2 = "L2"
13 | corruptions = "corruptions"
14 |
--------------------------------------------------------------------------------
/cifar/run_cifar10_gradual.sh:
--------------------------------------------------------------------------------
1 | #source /cluster/home/qwang/miniconda3/etc/profile.d/conda.sh
2 | export PYTHONPATH=
3 | conda deactivate
4 | conda activate cotta
5 |
6 | for i in {0..9}
7 | do
8 | CUDA_VISIBLE_DEVICES=0 python -u cifar10c_gradual.py --cfg cfgs/10orders/tent/tent$i.yaml
9 | CUDA_VISIBLE_DEVICES=0 python -u cifar10c_gradual.py --cfg cfgs/10orders/cotta/cotta$i.yaml
10 | done
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/cifar/run_cifar10.sh:
--------------------------------------------------------------------------------
1 | #source /cluster/home/qwang/miniconda3/etc/profile.d/conda.sh
2 | export PYTHONPATH=
3 | conda deactivate
4 | conda activate cotta
5 | CUDA_VISIBLE_DEVICES=0 python cifar10c.py --cfg cfgs/cifar10/source.yaml
6 | CUDA_VISIBLE_DEVICES=0 python cifar10c.py --cfg cfgs/cifar10/norm.yaml
7 | CUDA_VISIBLE_DEVICES=0 python cifar10c.py --cfg cfgs/cifar10/tent.yaml
8 | CUDA_VISIBLE_DEVICES=0 python cifar10c.py --cfg cfgs/cifar10/cotta.yaml
9 |
10 |
11 |
--------------------------------------------------------------------------------
/cifar/run_cifar100.sh:
--------------------------------------------------------------------------------
1 | #source /cluster/home/qwang/miniconda3/etc/profile.d/conda.sh
2 | export PYTHONPATH=
3 | conda deactivate
4 | conda activate cotta
5 | CUDA_VISIBLE_DEVICES=0 python cifar100c.py --cfg cfgs/cifar100/source.yaml
6 | CUDA_VISIBLE_DEVICES=0 python cifar100c.py --cfg cfgs/cifar100/norm.yaml
7 | CUDA_VISIBLE_DEVICES=0 python cifar100c.py --cfg cfgs/cifar100/tent.yaml
8 | CUDA_VISIBLE_DEVICES=0 python cifar100c.py --cfg cfgs/cifar100/cotta.yaml
9 |
10 |
11 |
--------------------------------------------------------------------------------
/cifar/cfgs/cifar10/norm.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | ADAPTATION: norm
3 | ARCH: Standard
4 | TEST:
5 | BATCH_SIZE: 200
6 | CORRUPTION:
7 | DATASET: cifar10
8 | SEVERITY:
9 | - 5
10 | TYPE:
11 | - gaussian_noise
12 | - shot_noise
13 | - impulse_noise
14 | - defocus_blur
15 | - glass_blur
16 | - motion_blur
17 | - zoom_blur
18 | - snow
19 | - frost
20 | - fog
21 | - brightness
22 | - contrast
23 | - elastic_transform
24 | - pixelate
25 | - jpeg_compression
26 |
--------------------------------------------------------------------------------
/imagenet/cfgs/norm.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | ADAPTATION: norm
3 | ARCH: Standard_R50
4 | TEST:
5 | BATCH_SIZE: 64
6 | CORRUPTION:
7 | DATASET: imagenet
8 | SEVERITY:
9 | - 5
10 | TYPE:
11 | - gaussian_noise
12 | - shot_noise
13 | - impulse_noise
14 | - defocus_blur
15 | - glass_blur
16 | - motion_blur
17 | - zoom_blur
18 | - snow
19 | - frost
20 | - fog
21 | - brightness
22 | - contrast
23 | - elastic_transform
24 | - pixelate
25 | - jpeg_compression
26 |
--------------------------------------------------------------------------------
/imagenet/cfgs/source.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | ADAPTATION: source
3 | ARCH: Standard_R50
4 | TEST:
5 | BATCH_SIZE: 64
6 | CORRUPTION:
7 | DATASET: imagenet
8 | SEVERITY:
9 | - 5
10 | TYPE:
11 | - gaussian_noise
12 | - shot_noise
13 | - impulse_noise
14 | - defocus_blur
15 | - glass_blur
16 | - motion_blur
17 | - zoom_blur
18 | - snow
19 | - frost
20 | - fog
21 | - brightness
22 | - contrast
23 | - elastic_transform
24 | - pixelate
25 | - jpeg_compression
26 |
--------------------------------------------------------------------------------
/cifar/cfgs/cifar10/source.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | ADAPTATION: source
3 | ARCH: Standard
4 | TEST:
5 | BATCH_SIZE: 200
6 | CORRUPTION:
7 | DATASET: cifar10
8 | SEVERITY:
9 | - 5
10 | TYPE:
11 | - gaussian_noise
12 | - shot_noise
13 | - impulse_noise
14 | - defocus_blur
15 | - glass_blur
16 | - motion_blur
17 | - zoom_blur
18 | - snow
19 | - frost
20 | - fog
21 | - brightness
22 | - contrast
23 | - elastic_transform
24 | - pixelate
25 | - jpeg_compression
26 |
--------------------------------------------------------------------------------
/cifar/cfgs/cifar100/norm.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | ADAPTATION: norm
3 | ARCH: Hendrycks2020AugMix_ResNeXt
4 | TEST:
5 | BATCH_SIZE: 200
6 | CORRUPTION:
7 | DATASET: cifar100
8 | SEVERITY:
9 | - 5
10 | TYPE:
11 | - gaussian_noise
12 | - shot_noise
13 | - impulse_noise
14 | - defocus_blur
15 | - glass_blur
16 | - motion_blur
17 | - zoom_blur
18 | - snow
19 | - frost
20 | - fog
21 | - brightness
22 | - contrast
23 | - elastic_transform
24 | - pixelate
25 | - jpeg_compression
26 |
--------------------------------------------------------------------------------
/cifar/cfgs/cifar100/source.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | ADAPTATION: source
3 | ARCH: Hendrycks2020AugMix_ResNeXt
4 | TEST:
5 | BATCH_SIZE: 200
6 | CORRUPTION:
7 | DATASET: cifar100
8 | SEVERITY:
9 | - 5
10 | TYPE:
11 | - gaussian_noise
12 | - shot_noise
13 | - impulse_noise
14 | - defocus_blur
15 | - glass_blur
16 | - motion_blur
17 | - zoom_blur
18 | - snow
19 | - frost
20 | - fog
21 | - brightness
22 | - contrast
23 | - elastic_transform
24 | - pixelate
25 | - jpeg_compression
26 |
--------------------------------------------------------------------------------
/cifar/cfgs/10orders/tent/tent0.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: cifar10
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - brightness
7 | - pixelate
8 | - gaussian_noise
9 | - motion_blur
10 | - zoom_blur
11 | - glass_blur
12 | - impulse_noise
13 | - jpeg_compression
14 | - defocus_blur
15 | - elastic_transform
16 | - shot_noise
17 | - frost
18 | - snow
19 | - fog
20 | - contrast
21 | MODEL:
22 | ADAPTATION: tent
23 | ARCH: Standard
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 1e-3
27 | METHOD: Adam
28 | STEPS: 1
29 | WD: 0.0
30 | TEST:
31 | BATCH_SIZE: 200
32 |
--------------------------------------------------------------------------------
/cifar/cfgs/10orders/tent/tent1.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: cifar10
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - jpeg_compression
7 | - shot_noise
8 | - zoom_blur
9 | - frost
10 | - contrast
11 | - fog
12 | - defocus_blur
13 | - elastic_transform
14 | - gaussian_noise
15 | - brightness
16 | - glass_blur
17 | - impulse_noise
18 | - pixelate
19 | - snow
20 | - motion_blur
21 | MODEL:
22 | ADAPTATION: tent
23 | ARCH: Standard
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 1e-3
27 | METHOD: Adam
28 | STEPS: 1
29 | WD: 0.0
30 | TEST:
31 | BATCH_SIZE: 200
32 |
--------------------------------------------------------------------------------
/cifar/cfgs/10orders/tent/tent2.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: cifar10
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - contrast
7 | - defocus_blur
8 | - gaussian_noise
9 | - shot_noise
10 | - snow
11 | - frost
12 | - glass_blur
13 | - zoom_blur
14 | - elastic_transform
15 | - jpeg_compression
16 | - pixelate
17 | - brightness
18 | - impulse_noise
19 | - motion_blur
20 | - fog
21 | MODEL:
22 | ADAPTATION: tent
23 | ARCH: Standard
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 1e-3
27 | METHOD: Adam
28 | STEPS: 1
29 | WD: 0.0
30 | TEST:
31 | BATCH_SIZE: 200
32 |
--------------------------------------------------------------------------------
/cifar/cfgs/10orders/tent/tent3.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: cifar10
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - shot_noise
7 | - fog
8 | - glass_blur
9 | - pixelate
10 | - snow
11 | - elastic_transform
12 | - brightness
13 | - impulse_noise
14 | - defocus_blur
15 | - frost
16 | - contrast
17 | - gaussian_noise
18 | - motion_blur
19 | - jpeg_compression
20 | - zoom_blur
21 | MODEL:
22 | ADAPTATION: tent
23 | ARCH: Standard
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 1e-3
27 | METHOD: Adam
28 | STEPS: 1
29 | WD: 0.0
30 | TEST:
31 | BATCH_SIZE: 200
32 |
--------------------------------------------------------------------------------
/cifar/cfgs/10orders/tent/tent4.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: cifar10
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - pixelate
7 | - glass_blur
8 | - zoom_blur
9 | - snow
10 | - fog
11 | - impulse_noise
12 | - brightness
13 | - motion_blur
14 | - frost
15 | - jpeg_compression
16 | - gaussian_noise
17 | - shot_noise
18 | - contrast
19 | - defocus_blur
20 | - elastic_transform
21 | MODEL:
22 | ADAPTATION: tent
23 | ARCH: Standard
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 1e-3
27 | METHOD: Adam
28 | STEPS: 1
29 | WD: 0.0
30 | TEST:
31 | BATCH_SIZE: 200
32 |
--------------------------------------------------------------------------------
/cifar/cfgs/10orders/tent/tent5.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: cifar10
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - motion_blur
7 | - snow
8 | - fog
9 | - shot_noise
10 | - defocus_blur
11 | - contrast
12 | - zoom_blur
13 | - brightness
14 | - frost
15 | - elastic_transform
16 | - glass_blur
17 | - gaussian_noise
18 | - pixelate
19 | - jpeg_compression
20 | - impulse_noise
21 | MODEL:
22 | ADAPTATION: tent
23 | ARCH: Standard
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 1e-3
27 | METHOD: Adam
28 | STEPS: 1
29 | WD: 0.0
30 | TEST:
31 | BATCH_SIZE: 200
32 |
--------------------------------------------------------------------------------
/cifar/cfgs/10orders/tent/tent6.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: cifar10
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - frost
7 | - impulse_noise
8 | - jpeg_compression
9 | - contrast
10 | - zoom_blur
11 | - glass_blur
12 | - pixelate
13 | - snow
14 | - defocus_blur
15 | - motion_blur
16 | - brightness
17 | - elastic_transform
18 | - shot_noise
19 | - fog
20 | - gaussian_noise
21 | MODEL:
22 | ADAPTATION: tent
23 | ARCH: Standard
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 1e-3
27 | METHOD: Adam
28 | STEPS: 1
29 | WD: 0.0
30 | TEST:
31 | BATCH_SIZE: 200
32 |
--------------------------------------------------------------------------------
/cifar/cfgs/10orders/tent/tent7.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: cifar10
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - glass_blur
7 | - zoom_blur
8 | - impulse_noise
9 | - fog
10 | - snow
11 | - jpeg_compression
12 | - gaussian_noise
13 | - frost
14 | - shot_noise
15 | - brightness
16 | - contrast
17 | - motion_blur
18 | - pixelate
19 | - defocus_blur
20 | - elastic_transform
21 | MODEL:
22 | ADAPTATION: tent
23 | ARCH: Standard
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 1e-3
27 | METHOD: Adam
28 | STEPS: 1
29 | WD: 0.0
30 | TEST:
31 | BATCH_SIZE: 200
32 |
--------------------------------------------------------------------------------
/cifar/cfgs/10orders/tent/tent8.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: cifar10
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - defocus_blur
7 | - motion_blur
8 | - zoom_blur
9 | - shot_noise
10 | - gaussian_noise
11 | - glass_blur
12 | - jpeg_compression
13 | - fog
14 | - contrast
15 | - pixelate
16 | - frost
17 | - snow
18 | - brightness
19 | - elastic_transform
20 | - impulse_noise
21 | MODEL:
22 | ADAPTATION: tent
23 | ARCH: Standard
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 1e-3
27 | METHOD: Adam
28 | STEPS: 1
29 | WD: 0.0
30 | TEST:
31 | BATCH_SIZE: 200
32 |
--------------------------------------------------------------------------------
/cifar/cfgs/10orders/tent/tent9.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: cifar10
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - contrast
7 | - gaussian_noise
8 | - defocus_blur
9 | - zoom_blur
10 | - frost
11 | - glass_blur
12 | - jpeg_compression
13 | - fog
14 | - pixelate
15 | - elastic_transform
16 | - shot_noise
17 | - impulse_noise
18 | - snow
19 | - motion_blur
20 | - brightness
21 | MODEL:
22 | ADAPTATION: tent
23 | ARCH: Standard
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 1e-3
27 | METHOD: Adam
28 | STEPS: 1
29 | WD: 0.0
30 | TEST:
31 | BATCH_SIZE: 200
32 |
--------------------------------------------------------------------------------
/imagenet/cfgs/10orders/cotta/cotta0.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: imagenet
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - brightness
7 | - pixelate
8 | - gaussian_noise
9 | - motion_blur
10 | - zoom_blur
11 | - glass_blur
12 | - impulse_noise
13 | - jpeg_compression
14 | - defocus_blur
15 | - elastic_transform
16 | - shot_noise
17 | - frost
18 | - snow
19 | - fog
20 | - contrast
21 | MODEL:
22 | ADAPTATION: cotta
23 | ARCH: Standard_R50
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 0.01
27 | METHOD: SGD
28 | STEPS: 1
29 | WD: 0.0
30 | TEST:
31 | BATCH_SIZE: 64
32 |
--------------------------------------------------------------------------------
/imagenet/cfgs/10orders/cotta/cotta1.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: imagenet
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - jpeg_compression
7 | - shot_noise
8 | - zoom_blur
9 | - frost
10 | - contrast
11 | - fog
12 | - defocus_blur
13 | - elastic_transform
14 | - gaussian_noise
15 | - brightness
16 | - glass_blur
17 | - impulse_noise
18 | - pixelate
19 | - snow
20 | - motion_blur
21 | MODEL:
22 | ADAPTATION: cotta
23 | ARCH: Standard_R50
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 0.01
27 | METHOD: SGD
28 | STEPS: 1
29 | WD: 0.0
30 | TEST:
31 | BATCH_SIZE: 64
32 |
--------------------------------------------------------------------------------
/imagenet/cfgs/10orders/cotta/cotta2.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: imagenet
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - contrast
7 | - defocus_blur
8 | - gaussian_noise
9 | - shot_noise
10 | - snow
11 | - frost
12 | - glass_blur
13 | - zoom_blur
14 | - elastic_transform
15 | - jpeg_compression
16 | - pixelate
17 | - brightness
18 | - impulse_noise
19 | - motion_blur
20 | - fog
21 | MODEL:
22 | ADAPTATION: cotta
23 | ARCH: Standard_R50
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 0.01
27 | METHOD: SGD
28 | STEPS: 1
29 | WD: 0.0
30 | TEST:
31 | BATCH_SIZE: 64
32 |
--------------------------------------------------------------------------------
/imagenet/cfgs/10orders/cotta/cotta3.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: imagenet
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - shot_noise
7 | - fog
8 | - glass_blur
9 | - pixelate
10 | - snow
11 | - elastic_transform
12 | - brightness
13 | - impulse_noise
14 | - defocus_blur
15 | - frost
16 | - contrast
17 | - gaussian_noise
18 | - motion_blur
19 | - jpeg_compression
20 | - zoom_blur
21 | MODEL:
22 | ADAPTATION: cotta
23 | ARCH: Standard_R50
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 0.01
27 | METHOD: SGD
28 | STEPS: 1
29 | WD: 0.0
30 | TEST:
31 | BATCH_SIZE: 64
32 |
--------------------------------------------------------------------------------
/imagenet/cfgs/10orders/cotta/cotta4.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: imagenet
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - pixelate
7 | - glass_blur
8 | - zoom_blur
9 | - snow
10 | - fog
11 | - impulse_noise
12 | - brightness
13 | - motion_blur
14 | - frost
15 | - jpeg_compression
16 | - gaussian_noise
17 | - shot_noise
18 | - contrast
19 | - defocus_blur
20 | - elastic_transform
21 | MODEL:
22 | ADAPTATION: cotta
23 | ARCH: Standard_R50
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 0.01
27 | METHOD: SGD
28 | STEPS: 1
29 | WD: 0.0
30 | TEST:
31 | BATCH_SIZE: 64
32 |
--------------------------------------------------------------------------------
/imagenet/cfgs/10orders/cotta/cotta5.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: imagenet
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - motion_blur
7 | - snow
8 | - fog
9 | - shot_noise
10 | - defocus_blur
11 | - contrast
12 | - zoom_blur
13 | - brightness
14 | - frost
15 | - elastic_transform
16 | - glass_blur
17 | - gaussian_noise
18 | - pixelate
19 | - jpeg_compression
20 | - impulse_noise
21 | MODEL:
22 | ADAPTATION: cotta
23 | ARCH: Standard_R50
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 0.01
27 | METHOD: SGD
28 | STEPS: 1
29 | WD: 0.0
30 | TEST:
31 | BATCH_SIZE: 64
32 |
--------------------------------------------------------------------------------
/imagenet/cfgs/10orders/cotta/cotta6.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: imagenet
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - frost
7 | - impulse_noise
8 | - jpeg_compression
9 | - contrast
10 | - zoom_blur
11 | - glass_blur
12 | - pixelate
13 | - snow
14 | - defocus_blur
15 | - motion_blur
16 | - brightness
17 | - elastic_transform
18 | - shot_noise
19 | - fog
20 | - gaussian_noise
21 | MODEL:
22 | ADAPTATION: cotta
23 | ARCH: Standard_R50
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 0.01
27 | METHOD: SGD
28 | STEPS: 1
29 | WD: 0.0
30 | TEST:
31 | BATCH_SIZE: 64
32 |
--------------------------------------------------------------------------------
/imagenet/cfgs/10orders/cotta/cotta7.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: imagenet
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - glass_blur
7 | - zoom_blur
8 | - impulse_noise
9 | - fog
10 | - snow
11 | - jpeg_compression
12 | - gaussian_noise
13 | - frost
14 | - shot_noise
15 | - brightness
16 | - contrast
17 | - motion_blur
18 | - pixelate
19 | - defocus_blur
20 | - elastic_transform
21 | MODEL:
22 | ADAPTATION: cotta
23 | ARCH: Standard_R50
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 0.01
27 | METHOD: SGD
28 | STEPS: 1
29 | WD: 0.0
30 | TEST:
31 | BATCH_SIZE: 64
32 |
--------------------------------------------------------------------------------
/imagenet/cfgs/10orders/cotta/cotta8.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: imagenet
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - defocus_blur
7 | - motion_blur
8 | - zoom_blur
9 | - shot_noise
10 | - gaussian_noise
11 | - glass_blur
12 | - jpeg_compression
13 | - fog
14 | - contrast
15 | - pixelate
16 | - frost
17 | - snow
18 | - brightness
19 | - elastic_transform
20 | - impulse_noise
21 | MODEL:
22 | ADAPTATION: cotta
23 | ARCH: Standard_R50
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 0.01
27 | METHOD: SGD
28 | STEPS: 1
29 | WD: 0.0
30 | TEST:
31 | BATCH_SIZE: 64
32 |
--------------------------------------------------------------------------------
/imagenet/cfgs/10orders/cotta/cotta9.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: imagenet
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - contrast
7 | - gaussian_noise
8 | - defocus_blur
9 | - zoom_blur
10 | - frost
11 | - glass_blur
12 | - jpeg_compression
13 | - fog
14 | - pixelate
15 | - elastic_transform
16 | - shot_noise
17 | - impulse_noise
18 | - snow
19 | - motion_blur
20 | - brightness
21 | MODEL:
22 | ADAPTATION: cotta
23 | ARCH: Standard_R50
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 0.01
27 | METHOD: SGD
28 | STEPS: 1
29 | WD: 0.0
30 | TEST:
31 | BATCH_SIZE: 64
32 |
--------------------------------------------------------------------------------
/imagenet/cfgs/10orders/tent/tent0.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: imagenet
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - brightness
7 | - pixelate
8 | - gaussian_noise
9 | - motion_blur
10 | - zoom_blur
11 | - glass_blur
12 | - impulse_noise
13 | - jpeg_compression
14 | - defocus_blur
15 | - elastic_transform
16 | - shot_noise
17 | - frost
18 | - snow
19 | - fog
20 | - contrast
21 | MODEL:
22 | ADAPTATION: tent
23 | ARCH: Standard_R50
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 0.00025
27 | METHOD: SGD
28 | STEPS: 1
29 | WD: 0.0
30 | TEST:
31 | BATCH_SIZE: 64
32 |
--------------------------------------------------------------------------------
/imagenet/cfgs/10orders/tent/tent1.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: imagenet
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - jpeg_compression
7 | - shot_noise
8 | - zoom_blur
9 | - frost
10 | - contrast
11 | - fog
12 | - defocus_blur
13 | - elastic_transform
14 | - gaussian_noise
15 | - brightness
16 | - glass_blur
17 | - impulse_noise
18 | - pixelate
19 | - snow
20 | - motion_blur
21 | MODEL:
22 | ADAPTATION: tent
23 | ARCH: Standard_R50
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 0.00025
27 | METHOD: SGD
28 | STEPS: 1
29 | WD: 0.0
30 | TEST:
31 | BATCH_SIZE: 64
32 |
--------------------------------------------------------------------------------
/imagenet/cfgs/10orders/tent/tent2.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: imagenet
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - contrast
7 | - defocus_blur
8 | - gaussian_noise
9 | - shot_noise
10 | - snow
11 | - frost
12 | - glass_blur
13 | - zoom_blur
14 | - elastic_transform
15 | - jpeg_compression
16 | - pixelate
17 | - brightness
18 | - impulse_noise
19 | - motion_blur
20 | - fog
21 | MODEL:
22 | ADAPTATION: tent
23 | ARCH: Standard_R50
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 0.00025
27 | METHOD: SGD
28 | STEPS: 1
29 | WD: 0.0
30 | TEST:
31 | BATCH_SIZE: 64
32 |
--------------------------------------------------------------------------------
/imagenet/cfgs/10orders/tent/tent3.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: imagenet
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - shot_noise
7 | - fog
8 | - glass_blur
9 | - pixelate
10 | - snow
11 | - elastic_transform
12 | - brightness
13 | - impulse_noise
14 | - defocus_blur
15 | - frost
16 | - contrast
17 | - gaussian_noise
18 | - motion_blur
19 | - jpeg_compression
20 | - zoom_blur
21 | MODEL:
22 | ADAPTATION: tent
23 | ARCH: Standard_R50
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 0.00025
27 | METHOD: SGD
28 | STEPS: 1
29 | WD: 0.0
30 | TEST:
31 | BATCH_SIZE: 64
32 |
--------------------------------------------------------------------------------
/imagenet/cfgs/10orders/tent/tent4.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: imagenet
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - pixelate
7 | - glass_blur
8 | - zoom_blur
9 | - snow
10 | - fog
11 | - impulse_noise
12 | - brightness
13 | - motion_blur
14 | - frost
15 | - jpeg_compression
16 | - gaussian_noise
17 | - shot_noise
18 | - contrast
19 | - defocus_blur
20 | - elastic_transform
21 | MODEL:
22 | ADAPTATION: tent
23 | ARCH: Standard_R50
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 0.00025
27 | METHOD: SGD
28 | STEPS: 1
29 | WD: 0.0
30 | TEST:
31 | BATCH_SIZE: 64
32 |
--------------------------------------------------------------------------------
/imagenet/cfgs/10orders/tent/tent5.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: imagenet
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - motion_blur
7 | - snow
8 | - fog
9 | - shot_noise
10 | - defocus_blur
11 | - contrast
12 | - zoom_blur
13 | - brightness
14 | - frost
15 | - elastic_transform
16 | - glass_blur
17 | - gaussian_noise
18 | - pixelate
19 | - jpeg_compression
20 | - impulse_noise
21 | MODEL:
22 | ADAPTATION: tent
23 | ARCH: Standard_R50
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 0.00025
27 | METHOD: SGD
28 | STEPS: 1
29 | WD: 0.0
30 | TEST:
31 | BATCH_SIZE: 64
32 |
--------------------------------------------------------------------------------
/imagenet/cfgs/10orders/tent/tent6.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: imagenet
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - frost
7 | - impulse_noise
8 | - jpeg_compression
9 | - contrast
10 | - zoom_blur
11 | - glass_blur
12 | - pixelate
13 | - snow
14 | - defocus_blur
15 | - motion_blur
16 | - brightness
17 | - elastic_transform
18 | - shot_noise
19 | - fog
20 | - gaussian_noise
21 | MODEL:
22 | ADAPTATION: tent
23 | ARCH: Standard_R50
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 0.00025
27 | METHOD: SGD
28 | STEPS: 1
29 | WD: 0.0
30 | TEST:
31 | BATCH_SIZE: 64
32 |
--------------------------------------------------------------------------------
/imagenet/cfgs/10orders/tent/tent7.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: imagenet
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - glass_blur
7 | - zoom_blur
8 | - impulse_noise
9 | - fog
10 | - snow
11 | - jpeg_compression
12 | - gaussian_noise
13 | - frost
14 | - shot_noise
15 | - brightness
16 | - contrast
17 | - motion_blur
18 | - pixelate
19 | - defocus_blur
20 | - elastic_transform
21 | MODEL:
22 | ADAPTATION: tent
23 | ARCH: Standard_R50
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 0.00025
27 | METHOD: SGD
28 | STEPS: 1
29 | WD: 0.0
30 | TEST:
31 | BATCH_SIZE: 64
32 |
--------------------------------------------------------------------------------
/imagenet/cfgs/10orders/tent/tent8.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: imagenet
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - defocus_blur
7 | - motion_blur
8 | - zoom_blur
9 | - shot_noise
10 | - gaussian_noise
11 | - glass_blur
12 | - jpeg_compression
13 | - fog
14 | - contrast
15 | - pixelate
16 | - frost
17 | - snow
18 | - brightness
19 | - elastic_transform
20 | - impulse_noise
21 | MODEL:
22 | ADAPTATION: tent
23 | ARCH: Standard_R50
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 0.00025
27 | METHOD: SGD
28 | STEPS: 1
29 | WD: 0.0
30 | TEST:
31 | BATCH_SIZE: 64
32 |
--------------------------------------------------------------------------------
/imagenet/cfgs/10orders/tent/tent9.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: imagenet
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - contrast
7 | - gaussian_noise
8 | - defocus_blur
9 | - zoom_blur
10 | - frost
11 | - glass_blur
12 | - jpeg_compression
13 | - fog
14 | - pixelate
15 | - elastic_transform
16 | - shot_noise
17 | - impulse_noise
18 | - snow
19 | - motion_blur
20 | - brightness
21 | MODEL:
22 | ADAPTATION: tent
23 | ARCH: Standard_R50
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 0.00025
27 | METHOD: SGD
28 | STEPS: 1
29 | WD: 0.0
30 | TEST:
31 | BATCH_SIZE: 64
32 |
--------------------------------------------------------------------------------
/cifar/cfgs/cifar10/tent.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | ADAPTATION: tent
3 | ARCH: Standard
4 | TEST:
5 | BATCH_SIZE: 200
6 | CORRUPTION:
7 | DATASET: cifar10
8 | SEVERITY:
9 | - 5
10 | TYPE:
11 | - gaussian_noise
12 | - shot_noise
13 | - impulse_noise
14 | - defocus_blur
15 | - glass_blur
16 | - motion_blur
17 | - zoom_blur
18 | - snow
19 | - frost
20 | - fog
21 | - brightness
22 | - contrast
23 | - elastic_transform
24 | - pixelate
25 | - jpeg_compression
26 | OPTIM:
27 | METHOD: Adam
28 | STEPS: 1
29 | BETA: 0.9
30 | LR: 1e-3
31 | WD: 0.
32 |
--------------------------------------------------------------------------------
/imagenet/data/ImageNet-C/download.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# Download the five ImageNet-C corruption archives from Zenodo.
# URLs are quoted so the shell cannot glob-expand the '?download=1' suffix.
wget --content-disposition "https://zenodo.org/record/2235448/files/blur.tar?download=1"
wget --content-disposition "https://zenodo.org/record/2235448/files/digital.tar?download=1"
wget --content-disposition "https://zenodo.org/record/2235448/files/extra.tar?download=1"
wget --content-disposition "https://zenodo.org/record/2235448/files/noise.tar?download=1"
wget --content-disposition "https://zenodo.org/record/2235448/files/weather.tar?download=1"

# Extract each archive. Plain `-xvf` lets GNU tar auto-detect the compression
# on extraction, so it works whether the .tar files are gzipped or not; the
# original `-z` flag forces a gzip filter and fails on uncompressed tars.
for archive in blur.tar digital.tar extra.tar noise.tar weather.tar
do
    tar -xvf "$archive"
done
12 |
--------------------------------------------------------------------------------
/cifar/cfgs/cifar100/tent.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | ADAPTATION: tent
3 | ARCH: Hendrycks2020AugMix_ResNeXt
4 | TEST:
5 | BATCH_SIZE: 200
6 | CORRUPTION:
7 | DATASET: cifar100
8 | SEVERITY:
9 | - 5
10 | TYPE:
11 | - gaussian_noise
12 | - shot_noise
13 | - impulse_noise
14 | - defocus_blur
15 | - glass_blur
16 | - motion_blur
17 | - zoom_blur
18 | - snow
19 | - frost
20 | - fog
21 | - brightness
22 | - contrast
23 | - elastic_transform
24 | - pixelate
25 | - jpeg_compression
26 | OPTIM:
27 | METHOD: Adam
28 | STEPS: 1
29 | BETA: 0.9
30 | LR: 1e-3
31 | WD: 0.
32 |
--------------------------------------------------------------------------------
/cifar/cfgs/10orders/cotta/cotta0.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: cifar10
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - brightness
7 | - pixelate
8 | - gaussian_noise
9 | - motion_blur
10 | - zoom_blur
11 | - glass_blur
12 | - impulse_noise
13 | - jpeg_compression
14 | - defocus_blur
15 | - elastic_transform
16 | - shot_noise
17 | - frost
18 | - snow
19 | - fog
20 | - contrast
21 | MODEL:
22 | ADAPTATION: cotta
23 | ARCH: Standard
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 1e-3
27 | METHOD: Adam
28 | STEPS: 1
29 | WD: 0.0
30 | MT: 0.999
31 | RST: 0.01
32 | AP: 0.92
33 | TEST:
34 | BATCH_SIZE: 200
35 |
--------------------------------------------------------------------------------
/cifar/cfgs/10orders/cotta/cotta1.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: cifar10
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - jpeg_compression
7 | - shot_noise
8 | - zoom_blur
9 | - frost
10 | - contrast
11 | - fog
12 | - defocus_blur
13 | - elastic_transform
14 | - gaussian_noise
15 | - brightness
16 | - glass_blur
17 | - impulse_noise
18 | - pixelate
19 | - snow
20 | - motion_blur
21 | MODEL:
22 | ADAPTATION: cotta
23 | ARCH: Standard
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 1e-3
27 | METHOD: Adam
28 | STEPS: 1
29 | WD: 0.0
30 | MT: 0.999
31 | RST: 0.01
32 | AP: 0.92
33 | TEST:
34 | BATCH_SIZE: 200
35 |
--------------------------------------------------------------------------------
/cifar/cfgs/10orders/cotta/cotta2.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: cifar10
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - contrast
7 | - defocus_blur
8 | - gaussian_noise
9 | - shot_noise
10 | - snow
11 | - frost
12 | - glass_blur
13 | - zoom_blur
14 | - elastic_transform
15 | - jpeg_compression
16 | - pixelate
17 | - brightness
18 | - impulse_noise
19 | - motion_blur
20 | - fog
21 | MODEL:
22 | ADAPTATION: cotta
23 | ARCH: Standard
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 1e-3
27 | METHOD: Adam
28 | STEPS: 1
29 | WD: 0.0
30 | MT: 0.999
31 | RST: 0.01
32 | AP: 0.92
33 | TEST:
34 | BATCH_SIZE: 200
35 |
--------------------------------------------------------------------------------
/cifar/cfgs/10orders/cotta/cotta3.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: cifar10
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - shot_noise
7 | - fog
8 | - glass_blur
9 | - pixelate
10 | - snow
11 | - elastic_transform
12 | - brightness
13 | - impulse_noise
14 | - defocus_blur
15 | - frost
16 | - contrast
17 | - gaussian_noise
18 | - motion_blur
19 | - jpeg_compression
20 | - zoom_blur
21 | MODEL:
22 | ADAPTATION: cotta
23 | ARCH: Standard
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 1e-3
27 | METHOD: Adam
28 | STEPS: 1
29 | WD: 0.0
30 | MT: 0.999
31 | RST: 0.01
32 | AP: 0.92
33 | TEST:
34 | BATCH_SIZE: 200
35 |
--------------------------------------------------------------------------------
/cifar/cfgs/10orders/cotta/cotta4.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: cifar10
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - pixelate
7 | - glass_blur
8 | - zoom_blur
9 | - snow
10 | - fog
11 | - impulse_noise
12 | - brightness
13 | - motion_blur
14 | - frost
15 | - jpeg_compression
16 | - gaussian_noise
17 | - shot_noise
18 | - contrast
19 | - defocus_blur
20 | - elastic_transform
21 | MODEL:
22 | ADAPTATION: cotta
23 | ARCH: Standard
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 1e-3
27 | METHOD: Adam
28 | STEPS: 1
29 | WD: 0.0
30 | MT: 0.999
31 | RST: 0.01
32 | AP: 0.92
33 | TEST:
34 | BATCH_SIZE: 200
35 |
--------------------------------------------------------------------------------
/cifar/cfgs/10orders/cotta/cotta5.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: cifar10
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - motion_blur
7 | - snow
8 | - fog
9 | - shot_noise
10 | - defocus_blur
11 | - contrast
12 | - zoom_blur
13 | - brightness
14 | - frost
15 | - elastic_transform
16 | - glass_blur
17 | - gaussian_noise
18 | - pixelate
19 | - jpeg_compression
20 | - impulse_noise
21 | MODEL:
22 | ADAPTATION: cotta
23 | ARCH: Standard
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 1e-3
27 | METHOD: Adam
28 | STEPS: 1
29 | WD: 0.0
30 | MT: 0.999
31 | RST: 0.01
32 | AP: 0.92
33 | TEST:
34 | BATCH_SIZE: 200
35 |
--------------------------------------------------------------------------------
/cifar/cfgs/10orders/cotta/cotta6.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: cifar10
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - frost
7 | - impulse_noise
8 | - jpeg_compression
9 | - contrast
10 | - zoom_blur
11 | - glass_blur
12 | - pixelate
13 | - snow
14 | - defocus_blur
15 | - motion_blur
16 | - brightness
17 | - elastic_transform
18 | - shot_noise
19 | - fog
20 | - gaussian_noise
21 | MODEL:
22 | ADAPTATION: cotta
23 | ARCH: Standard
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 1e-3
27 | METHOD: Adam
28 | STEPS: 1
29 | WD: 0.0
30 | MT: 0.999
31 | RST: 0.01
32 | AP: 0.92
33 | TEST:
34 | BATCH_SIZE: 200
35 |
--------------------------------------------------------------------------------
/cifar/cfgs/10orders/cotta/cotta7.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: cifar10
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - glass_blur
7 | - zoom_blur
8 | - impulse_noise
9 | - fog
10 | - snow
11 | - jpeg_compression
12 | - gaussian_noise
13 | - frost
14 | - shot_noise
15 | - brightness
16 | - contrast
17 | - motion_blur
18 | - pixelate
19 | - defocus_blur
20 | - elastic_transform
21 | MODEL:
22 | ADAPTATION: cotta
23 | ARCH: Standard
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 1e-3
27 | METHOD: Adam
28 | STEPS: 1
29 | WD: 0.0
30 | MT: 0.999
31 | RST: 0.01
32 | AP: 0.92
33 | TEST:
34 | BATCH_SIZE: 200
35 |
--------------------------------------------------------------------------------
/cifar/cfgs/10orders/cotta/cotta8.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: cifar10
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - defocus_blur
7 | - motion_blur
8 | - zoom_blur
9 | - shot_noise
10 | - gaussian_noise
11 | - glass_blur
12 | - jpeg_compression
13 | - fog
14 | - contrast
15 | - pixelate
16 | - frost
17 | - snow
18 | - brightness
19 | - elastic_transform
20 | - impulse_noise
21 | MODEL:
22 | ADAPTATION: cotta
23 | ARCH: Standard
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 1e-3
27 | METHOD: Adam
28 | STEPS: 1
29 | WD: 0.0
30 | MT: 0.999
31 | RST: 0.01
32 | AP: 0.92
33 | TEST:
34 | BATCH_SIZE: 200
35 |
--------------------------------------------------------------------------------
/cifar/cfgs/10orders/cotta/cotta9.yaml:
--------------------------------------------------------------------------------
1 | CORRUPTION:
2 | DATASET: cifar10
3 | SEVERITY:
4 | - 5
5 | TYPE:
6 | - contrast
7 | - gaussian_noise
8 | - defocus_blur
9 | - zoom_blur
10 | - frost
11 | - glass_blur
12 | - jpeg_compression
13 | - fog
14 | - pixelate
15 | - elastic_transform
16 | - shot_noise
17 | - impulse_noise
18 | - snow
19 | - motion_blur
20 | - brightness
21 | MODEL:
22 | ADAPTATION: cotta
23 | ARCH: Standard
24 | OPTIM:
25 | BETA: 0.9
26 | LR: 1e-3
27 | METHOD: Adam
28 | STEPS: 1
29 | WD: 0.0
30 | MT: 0.999
31 | RST: 0.01
32 | AP: 0.92
33 | TEST:
34 | BATCH_SIZE: 200
35 |
--------------------------------------------------------------------------------
/cifar/cfgs/cifar10/cotta.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | ADAPTATION: cotta
3 | ARCH: Standard
4 | TEST:
5 | BATCH_SIZE: 200
6 | CORRUPTION:
7 | DATASET: cifar10
8 | SEVERITY:
9 | - 5
10 | TYPE:
11 | - gaussian_noise
12 | - shot_noise
13 | - impulse_noise
14 | - defocus_blur
15 | - glass_blur
16 | - motion_blur
17 | - zoom_blur
18 | - snow
19 | - frost
20 | - fog
21 | - brightness
22 | - contrast
23 | - elastic_transform
24 | - pixelate
25 | - jpeg_compression
26 | OPTIM:
27 | METHOD: Adam
28 | STEPS: 1
29 | BETA: 0.9
30 | LR: 1e-3
31 | WD: 0.
32 | MT: 0.999
33 | RST: 0.01
34 | AP: 0.92
35 |
--------------------------------------------------------------------------------
/cifar/cfgs/cifar100/cotta.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | ADAPTATION: cotta
3 | ARCH: Hendrycks2020AugMix_ResNeXt
4 | TEST:
5 | BATCH_SIZE: 200
6 | CORRUPTION:
7 | DATASET: cifar100
8 | SEVERITY:
9 | - 5
10 | TYPE:
11 | - gaussian_noise
12 | - shot_noise
13 | - impulse_noise
14 | - defocus_blur
15 | - glass_blur
16 | - motion_blur
17 | - zoom_blur
18 | - snow
19 | - frost
20 | - fog
21 | - brightness
22 | - contrast
23 | - elastic_transform
24 | - pixelate
25 | - jpeg_compression
26 | OPTIM:
27 | METHOD: Adam
28 | STEPS: 1
29 | BETA: 0.9
30 | LR: 1e-3
31 | WD: 0.
32 | MT: 0.999
33 | RST: 0.01
34 | AP: 0.72
35 |
--------------------------------------------------------------------------------
/cifar/robustbench/model_zoo/models.py:
--------------------------------------------------------------------------------
from collections import OrderedDict
from typing import Any, Dict, OrderedDict as OrderedDictType

from robustbench.model_zoo.cifar10 import cifar_10_models
from robustbench.model_zoo.cifar100 import cifar_100_models
from robustbench.model_zoo.imagenet import imagenet_models
from robustbench.model_zoo.enums import BenchmarkDataset, ThreatModel

# Type aliases for the nested model registry:
# dataset -> threat model -> model name -> model description dict.
# The original aliased plain `Dict` as "OrderedDictType", which was
# misleading; `typing.OrderedDict` (3.7.2+) matches the runtime values.
ModelsDict = OrderedDictType[str, Dict[str, Any]]
ThreatModelsDict = OrderedDictType[ThreatModel, ModelsDict]
BenchmarkDict = OrderedDictType[BenchmarkDataset, ThreatModelsDict]

# Registry of every benchmark model, keyed by dataset.
model_dicts: BenchmarkDict = OrderedDict([
    (BenchmarkDataset.cifar_10, cifar_10_models),
    (BenchmarkDataset.cifar_100, cifar_100_models),
    (BenchmarkDataset.imagenet, imagenet_models)
])
18 |
--------------------------------------------------------------------------------
/imagenet/robustbench/model_zoo/models.py:
--------------------------------------------------------------------------------
from collections import OrderedDict
from typing import Any, Dict, OrderedDict as OrderedDictType

from robustbench.model_zoo.cifar10 import cifar_10_models
from robustbench.model_zoo.cifar100 import cifar_100_models
from robustbench.model_zoo.imagenet import imagenet_models
from robustbench.model_zoo.enums import BenchmarkDataset, ThreatModel

# Type aliases for the nested model registry:
# dataset -> threat model -> model name -> model description dict.
# The original aliased plain `Dict` as "OrderedDictType", which was
# misleading; `typing.OrderedDict` (3.7.2+) matches the runtime values.
ModelsDict = OrderedDictType[str, Dict[str, Any]]
ThreatModelsDict = OrderedDictType[ThreatModel, ModelsDict]
BenchmarkDict = OrderedDictType[BenchmarkDataset, ThreatModelsDict]

# Registry of every benchmark model, keyed by dataset.
model_dicts: BenchmarkDict = OrderedDict([
    (BenchmarkDataset.cifar_10, cifar_10_models),
    (BenchmarkDataset.cifar_100, cifar_100_models),
    (BenchmarkDataset.imagenet, imagenet_models)
])
18 |
--------------------------------------------------------------------------------
/imagenet/run.sh:
--------------------------------------------------------------------------------
#source /cluster/home/qwang/miniconda3/etc/profile.d/conda.sh
# Clean PATH and only use cotta env
export PYTHONPATH=
conda deactivate
conda activate cotta
# Source-only and AdaBN results are not affected by the order as no training is performed. Therefore only need to run once.
CUDA_VISIBLE_DEVICES=0 python -u imagenetc.py --cfg cfgs/source.yaml
CUDA_VISIBLE_DEVICES=0 python -u imagenetc.py --cfg cfgs/norm.yaml
# TENT and CoTTA results are affected by the corruption sequence order
# (one config per order lives under cfgs/10orders/{tent,cotta}/).
for i in {0..9}
do
CUDA_VISIBLE_DEVICES=0 python -u imagenetc.py --cfg cfgs/10orders/tent/tent$i.yaml
CUDA_VISIBLE_DEVICES=0 python -u imagenetc.py --cfg cfgs/10orders/cotta/cotta$i.yaml
done
# Run Mean and AVG for TENT and CoTTA
# eval.py globs the per-run logs in ./output and prints mean/std over the orders.
cd output
python3 -u ../eval.py | tee result.log
18 |
--------------------------------------------------------------------------------
/imagenet/eval.py:
--------------------------------------------------------------------------------
1 | from glob import glob
2 | import numpy as np
3 |
4 |
def read_file(filename, expected=15):
    """Parse one result log and return the mean error rate.

    Result lines look like ``"...]: error % [corruption]: 12.34%"``; the
    last whitespace-separated token minus its trailing ``%`` is the error.

    :param filename: path to a result log produced by a benchmark run.
    :param expected: number of per-corruption result lines the log must
        contain (15 corruption types by default).
    :return: mean of the parsed error percentages.
    """
    res = []
    # `with` closes the handle even on parse errors (the original leaked it).
    with open(filename, "r") as f:
        for line in f:
            # The original also tested `"error" in line`, which is implied
            # by this substring check.
            if "]: error % [" in line:
                res.append(float(line.strip().split(" ")[-1][:-1]))
    assert len(res) == expected, (
        f"{filename}: found {len(res)} results, expected {expected}")
    return np.mean(np.array(res))
13 |
14 |
def read_files(files):
    """Aggregate a set of result logs.

    Reads every log with ``read_file`` and reports the spread across runs.

    :param files: iterable of log file paths.
    :return: (mean, std) of the per-file mean error rates.
    """
    per_run = [read_file(path) for path in files]
    print("read", len(files), "files.")
    print(per_run)
    scores = np.array(per_run)
    return np.mean(scores), np.std(scores)
22 |
23 |
# Summarise each adaptation method. The globs match log files in the current
# working directory (run.sh invokes this script from ./output).
# Source-only: no adaptation, a single run.
print("read source files:")
print(read_files(glob("source_*.txt")))

# AdaBN / test-time normalization: a single run (order-independent).
print("read adabn files:")
print(read_files(glob("norm_*.txt")))

# TENT: one log per corruption order (tent0..tent9).
print("read tent files:")
print(read_files(glob("tent[0-9]_*.txt")))

# CoTTA: one log per corruption order (cotta0..cotta9).
print("read cotta files:")
print(read_files(glob("cotta[0-9]_*.txt")))
35 |
--------------------------------------------------------------------------------
/cifar/robustbench/model_zoo/architectures/utils_architectures.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from collections import OrderedDict
4 | from typing import Tuple
5 | from torch import Tensor
6 |
7 |
class ImageNormalizer(nn.Module):
    """Channel-wise input normalization with fixed statistics.

    The mean and std are registered as (1, 3, 1, 1) buffers so they follow
    the module across devices and broadcast over (N, 3, H, W) batches.
    """

    def __init__(self, mean: Tuple[float, float, float],
                 std: Tuple[float, float, float]) -> None:
        super().__init__()

        mean_buf = torch.as_tensor(mean).view(1, 3, 1, 1)
        std_buf = torch.as_tensor(std).view(1, 3, 1, 1)
        self.register_buffer('mean', mean_buf)
        self.register_buffer('std', std_buf)

    def forward(self, input: Tensor) -> Tensor:
        # Standard (x - mean) / std, broadcast over the batch.
        return (input - self.mean) / self.std
18 |
19 |
def normalize_model(model: nn.Module, mean: Tuple[float, float, float],
                    std: Tuple[float, float, float]) -> nn.Module:
    """Prepend fixed-statistics input normalization to ``model``.

    :return: an ``nn.Sequential`` running ``normalize`` then ``model``.
    """
    wrapped = nn.Sequential(OrderedDict([
        ('normalize', ImageNormalizer(mean, std)),
        ('model', model),
    ]))
    return wrapped
27 |
28 |
--------------------------------------------------------------------------------
/imagenet/robustbench/model_zoo/architectures/utils_architectures.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from collections import OrderedDict
4 | from typing import Tuple
5 | from torch import Tensor
6 |
7 |
class ImageNormalizer(nn.Module):
    """Channel-wise input normalization with fixed statistics.

    The mean and std are registered as (1, 3, 1, 1) buffers so they follow
    the module across devices and broadcast over (N, 3, H, W) batches.
    """

    def __init__(self, mean: Tuple[float, float, float],
                 std: Tuple[float, float, float]) -> None:
        super().__init__()

        mean_buf = torch.as_tensor(mean).view(1, 3, 1, 1)
        std_buf = torch.as_tensor(std).view(1, 3, 1, 1)
        self.register_buffer('mean', mean_buf)
        self.register_buffer('std', std_buf)

    def forward(self, input: Tensor) -> Tensor:
        # Standard (x - mean) / std, broadcast over the batch.
        return (input - self.mean) / self.std
18 |
19 |
def normalize_model(model: nn.Module, mean: Tuple[float, float, float],
                    std: Tuple[float, float, float]) -> nn.Module:
    """Prepend fixed-statistics input normalization to ``model``.

    :return: an ``nn.Sequential`` running ``normalize`` then ``model``.
    """
    wrapped = nn.Sequential(OrderedDict([
        ('normalize', ImageNormalizer(mean, std)),
        ('model', model),
    ]))
    return wrapped
27 |
28 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Qin Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
23 |
24 |
25 | ### Our code is based on TENT, therefore a copy of TENT's license is included here
26 | MIT License
27 |
28 | Copyright (c) 2021 Dequan Wang and Evan Shelhamer
29 |
30 | Permission is hereby granted, free of charge, to any person obtaining a copy
31 | of this software and associated documentation files (the "Software"), to deal
32 | in the Software without restriction, including without limitation the rights
33 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
34 | copies of the Software, and to permit persons to whom the Software is
35 | furnished to do so, subject to the following conditions:
36 |
37 | The above copyright notice and this permission notice shall be included in all
38 | copies or substantial portions of the Software.
39 |
40 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
41 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
42 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
43 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
44 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
45 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
46 | SOFTWARE.
47 |
48 |
--------------------------------------------------------------------------------
/cifar/norm.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
class Norm(nn.Module):
    """Norm adapts a model by estimating feature statistics during testing.

    Once equipped with Norm, the model normalizes its features during testing
    with batch-wise statistics, just like batch norm does during training.
    """

    def __init__(self, model, eps=1e-5, momentum=0.1,
                 reset_stats=False, no_stats=False):
        super().__init__()
        # configure_model mutates `model` in place and returns it, so one
        # assignment suffices (the original assigned self.model twice).
        self.model = configure_model(model, eps, momentum, reset_stats,
                                     no_stats)
        # Snapshot of the configured state so reset() can restore it.
        self.model_state = deepcopy(self.model.state_dict())

    def forward(self, x):
        return self.model(x)

    def reset(self):
        """Restore the parameters/stats captured at construction time."""
        self.model.load_state_dict(self.model_state, strict=True)
27 |
28 |
def collect_stats(model):
    """Collect the normalization stats from batch norms.

    Walks the model's modules and, for every ``nn.BatchNorm2d``, gathers its
    state-dict entries except the learned affine weight/bias. Returns the
    stat tensors alongside their fully-qualified names.
    """
    stats, names = [], []
    for module_name, module in model.named_modules():
        if not isinstance(module, nn.BatchNorm2d):
            continue
        state = module.state_dict()
        if module.affine:
            # Keep only running stats, not the learned affine parameters.
            del state['weight'], state['bias']
        for stat_name, stat in state.items():
            stats.append(stat)
            names.append(f"{module_name}.{stat_name}")
    return stats, names
46 |
47 |
def configure_model(model, eps, momentum, reset_stats, no_stats):
    """Configure model for adaptation by test-time normalization.

    Mutates every ``nn.BatchNorm2d`` in place and returns the same model.
    """
    for bn in (mod for mod in model.modules()
               if isinstance(mod, nn.BatchNorm2d)):
        # Train mode makes BN normalize with batch-wise statistics.
        bn.train()
        # Stability epsilon and running-stat update momentum.
        bn.eps = eps
        bn.momentum = momentum
        if reset_stats:
            # Discard train-time stats so test stats are estimated fresh.
            bn.reset_running_stats()
        if no_stats:
            # Disable running stats entirely; rely purely on batch stats.
            bn.track_running_stats = False
            bn.running_mean = None
            bn.running_var = None
    return model
66 |
--------------------------------------------------------------------------------
/imagenet/norm.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
class Norm(nn.Module):
    """Norm adapts a model by estimating feature statistics during testing.

    Once equipped with Norm, the model normalizes its features during testing
    with batch-wise statistics, just like batch norm does during training.
    """

    def __init__(self, model, eps=1e-5, momentum=0.1,
                 reset_stats=False, no_stats=False):
        super().__init__()
        # configure_model mutates `model` in place and returns it, so one
        # assignment suffices (the original assigned self.model twice).
        self.model = configure_model(model, eps, momentum, reset_stats,
                                     no_stats)
        # Snapshot of the configured state so reset() can restore it.
        self.model_state = deepcopy(self.model.state_dict())

    def forward(self, x):
        return self.model(x)

    def reset(self):
        """Restore the parameters/stats captured at construction time."""
        self.model.load_state_dict(self.model_state, strict=True)
27 |
28 |
def collect_stats(model):
    """Collect the normalization stats from batch norms.

    Walks the model's modules and, for every ``nn.BatchNorm2d``, gathers its
    state-dict entries except the learned affine weight/bias. Returns the
    stat tensors alongside their fully-qualified names.
    """
    stats, names = [], []
    for module_name, module in model.named_modules():
        if not isinstance(module, nn.BatchNorm2d):
            continue
        state = module.state_dict()
        if module.affine:
            # Keep only running stats, not the learned affine parameters.
            del state['weight'], state['bias']
        for stat_name, stat in state.items():
            stats.append(stat)
            names.append(f"{module_name}.{stat_name}")
    return stats, names
46 |
47 |
def configure_model(model, eps, momentum, reset_stats, no_stats):
    """Configure model for adaptation by test-time normalization.

    Mutates every ``nn.BatchNorm2d`` in place and returns the same model.
    """
    for bn in (mod for mod in model.modules()
               if isinstance(mod, nn.BatchNorm2d)):
        # Train mode makes BN normalize with batch-wise statistics.
        bn.train()
        # Stability epsilon and running-stat update momentum.
        bn.eps = eps
        bn.momentum = momentum
        if reset_stats:
            # Discard train-time stats so test stats are estimated fresh.
            bn.reset_running_stats()
        if no_stats:
            # Disable running stats entirely; rely purely on batch stats.
            bn.track_running_stats = False
            bn.running_mean = None
            bn.running_var = None
    return model
66 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/cifar/robustbench/zenodo_download.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import shutil
3 | from pathlib import Path
4 | from typing import Set
5 |
6 | import requests
7 | from tqdm import tqdm
8 |
# Zenodo REST API endpoints; a record's metadata lives at
# RECORDS_ENTRY_POINT + record_id (note the trailing slash).
ZENODO_ENTRY_POINT = "https://zenodo.org/api"
RECORDS_ENTRY_POINT = f"{ZENODO_ENTRY_POINT}/records/"

# Download stream chunk size in bytes (64 KiB).
CHUNK_SIZE = 65536
13 |
14 |
class DownloadError(Exception):
    """Raised when a downloaded file fails its checksum verification."""
17 |
18 |
def download_file(url: str, save_dir: Path, total_bytes: int) -> Path:
    """Downloads large files from the given URL.

    From: https://stackoverflow.com/a/16696317

    :param url: The URL of the file.
    :param save_dir: The directory where the file should be saved.
    :param total_bytes: The total bytes of the file.
    :return: The path to the downloaded file.
    """
    target = save_dir / url.split('/')[-1]
    print(f"Starting download from {url}")
    with requests.get(url, stream=True) as response:
        response.raise_for_status()
        # tqdm's total is the expected number of chunks, not bytes.
        chunk_total = total_bytes // CHUNK_SIZE
        with open(target, 'wb') as out:
            chunks = response.iter_content(chunk_size=CHUNK_SIZE)
            for chunk in tqdm(chunks, total=chunk_total):
                out.write(chunk)

    return target
40 |
41 |
def file_md5(filename: Path) -> str:
    """Computes the MD5 hash of a given file"""
    digest = hashlib.md5()
    with open(filename, "rb") as f:
        while True:
            block = f.read(32768)
            if not block:
                break
            digest.update(block)

    return digest.hexdigest()
50 |
51 |
def zenodo_download(record_id: str, filenames_to_download: Set[str],
                    save_dir: Path) -> None:
    """Downloads the given files from the given Zenodo record.

    Each requested file is downloaded (unless already present), verified
    against the record's MD5 checksum, and unpacked into ``save_dir``.

    :param record_id: The ID of the record.
    :param filenames_to_download: The files to download from the record.
    :param save_dir: The directory where the files should be saved.
    :raises DownloadError: if a downloaded file's MD5 does not match.
    """
    # exist_ok makes the prior .exists() guard redundant.
    save_dir.mkdir(parents=True, exist_ok=True)

    # RECORDS_ENTRY_POINT already ends with '/'; the original inserted a
    # second slash ('.../records//<id>').
    url = f"{RECORDS_ENTRY_POINT}{record_id}"
    res = requests.get(url)
    files = res.json()["files"]
    files_to_download = [
        file for file in files if file["key"] in filenames_to_download
    ]

    for file in files_to_download:
        if (save_dir / file["key"]).exists():
            # Already downloaded on a previous invocation.
            continue
        file_url = file["links"]["self"]
        # Zenodo reports checksums as '<algo>:<hex>'; keep the hex part.
        file_checksum = file["checksum"].split(":")[-1]
        filename = download_file(file_url, save_dir, file["size"])
        if file_md5(filename) != file_checksum:
            raise DownloadError(
                "The hash of the downloaded file does not match"
                " the expected one.")
        print("Download finished, extracting...")
        shutil.unpack_archive(filename,
                              extract_dir=save_dir,
                              format=file["type"])
        print("Downloaded and extracted.")
84 |
--------------------------------------------------------------------------------
/imagenet/robustbench/zenodo_download.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import shutil
3 | from pathlib import Path
4 | from typing import Set
5 |
6 | import requests
7 | from tqdm import tqdm
8 |
# Zenodo REST API endpoints; a record's metadata lives at
# RECORDS_ENTRY_POINT + record_id (note the trailing slash).
ZENODO_ENTRY_POINT = "https://zenodo.org/api"
RECORDS_ENTRY_POINT = f"{ZENODO_ENTRY_POINT}/records/"

# Download stream chunk size in bytes (64 KiB).
CHUNK_SIZE = 65536
13 |
14 |
class DownloadError(Exception):
    """Raised when a downloaded file fails its checksum verification."""
17 |
18 |
def download_file(url: str, save_dir: Path, total_bytes: int) -> Path:
    """Downloads large files from the given URL.

    From: https://stackoverflow.com/a/16696317

    :param url: The URL of the file.
    :param save_dir: The directory where the file should be saved.
    :param total_bytes: The total bytes of the file.
    :return: The path to the downloaded file.
    """
    target = save_dir / url.split('/')[-1]
    print(f"Starting download from {url}")
    with requests.get(url, stream=True) as response:
        response.raise_for_status()
        # tqdm's total is the expected number of chunks, not bytes.
        chunk_total = total_bytes // CHUNK_SIZE
        with open(target, 'wb') as out:
            chunks = response.iter_content(chunk_size=CHUNK_SIZE)
            for chunk in tqdm(chunks, total=chunk_total):
                out.write(chunk)

    return target
40 |
41 |
def file_md5(filename: Path) -> str:
    """Return the hex MD5 digest of the file at *filename*, read in 32 KiB blocks."""
    digest = hashlib.md5()
    with open(filename, "rb") as f:
        while True:
            block = f.read(32768)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()
50 |
51 |
def zenodo_download(record_id: str, filenames_to_download: Set[str],
                    save_dir: Path) -> None:
    """Downloads the given files from the given Zenodo record.

    :param record_id: The ID of the record.
    :param filenames_to_download: The files to download from the record.
    :param save_dir: The directory where the files should be saved.
    :raises DownloadError: If a downloaded file's MD5 does not match the
        checksum reported by Zenodo.
    """
    # exist_ok makes the previous exists() pre-check unnecessary.
    save_dir.mkdir(parents=True, exist_ok=True)

    # RECORDS_ENTRY_POINT already ends with a slash; the original f-string
    # added a second one, producing ".../records//<id>".
    url = f"{RECORDS_ENTRY_POINT}{record_id}"
    res = requests.get(url)
    res.raise_for_status()  # fail early on a bad record id / API error
    files = res.json()["files"]
    files_to_download = [
        file for file in files if file["key"] in filenames_to_download
    ]

    for file in files_to_download:
        # Skip archives that are already present locally.
        if (save_dir / file["key"]).exists():
            continue
        file_url = file["links"]["self"]
        # Zenodo checksums look like "md5:<hex>"; keep only the hex part.
        file_checksum = file["checksum"].split(":")[-1]
        filename = download_file(file_url, save_dir, file["size"])
        if file_md5(filename) != file_checksum:
            raise DownloadError(
                "The hash of the downloaded file does not match"
                " the expected one.")
        print("Download finished, extracting...")
        shutil.unpack_archive(filename,
                              extract_dir=save_dir,
                              format=file["type"])
        print("Downloaded and extracted.")
84 |
--------------------------------------------------------------------------------
/cifar/robustbench/leaderboard/leaderboard.html.j2:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | | Rank |
5 | Method |
6 |
7 | Standard
8 | accuracy
9 | |
10 | {% if threat_model != "corruptions" %}
11 |
12 | AutoAttack
13 | robust
14 | accuracy
15 | |
16 |
17 | Best known
18 | robust
19 | accuracy
20 | |
21 |
22 | AA eval.
23 | potentially
24 | unreliable
25 | |
26 | {% endif %}
27 | {% if threat_model == "corruptions" %}
28 |
29 | Robust
30 | accuracy
31 | |
32 | {% endif %}
33 |
34 | Architecture |
35 | Venue |
36 |
37 |
38 |
39 | {% for model in models %}
40 |
41 | | {{ loop.index }} |
42 |
43 | {{ model.name }}
44 | {% if model.footnote is defined and model.footnote != None %}
45 |
46 |
49 | {% endif %}
50 | |
51 | {{ model.clean_acc }}% |
52 | {{ model[acc_field] }}% |
53 | {% if threat_model != "corruptions" %}
54 | {{ model.external if model.external is defined and model.external else model[acc_field]}}% |
55 | {{ "Unknown" if model.unreliable is not defined else (" ☑ " if model.unreliable else "× ") }} |
56 | {% endif %}
57 | {{ "☑" if model.additional_data else "×" }} |
58 | {{ model.architecture }} |
59 | {{ model.venue }} |
60 |
61 | {% endfor %}
62 |
63 |
64 |
--------------------------------------------------------------------------------
/imagenet/robustbench/leaderboard/leaderboard.html.j2:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | | Rank |
5 | Method |
6 |
7 | Standard
8 | accuracy
9 | |
10 | {% if threat_model != "corruptions" %}
11 |
12 | AutoAttack
13 | robust
14 | accuracy
15 | |
16 |
17 | Best known
18 | robust
19 | accuracy
20 | |
21 |
22 | AA eval.
23 | potentially
24 | unreliable
25 | |
26 | {% endif %}
27 | {% if threat_model == "corruptions" %}
28 |
29 | Robust
30 | accuracy
31 | |
32 | {% endif %}
33 |
34 | Architecture |
35 | Venue |
36 |
37 |
38 |
39 | {% for model in models %}
40 |
41 | | {{ loop.index }} |
42 |
43 | {{ model.name }}
44 | {% if model.footnote is defined and model.footnote != None %}
45 |
46 |
49 | {% endif %}
50 | |
51 | {{ model.clean_acc }}% |
52 | {{ model[acc_field] }}% |
53 | {% if threat_model != "corruptions" %}
54 | {{ model.external if model.external is defined and model.external else model[acc_field]}}% |
55 | {{ "Unknown" if model.unreliable is not defined else (" ☑ " if model.unreliable else "× ") }} |
56 | {% endif %}
57 | {{ "☑" if model.additional_data else "×" }} |
58 | {{ model.architecture }} |
59 | {{ model.venue }} |
60 |
61 | {% endfor %}
62 |
63 |
64 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CoTTA: Continual Test-Time Adaptation
2 | Official code for [Continual Test-Time Domain Adaptation](https://arxiv.org/abs/2203.13591), published in CVPR 2022.
3 |
4 | This repository also includes other continual test-time adaptation methods for classification and segmentation.
5 | We provide benchmarking and comparison for the following methods:
6 | + [CoTTA](https://arxiv.org/abs/2203.13591)
7 | + AdaBN / BN Adapt
8 | + TENT
9 |
on the following tasks:
11 | + CIFAR10/100 -> CIFAR10C/100C (standard/gradual)
12 | + ImageNet -> ImageNetC
13 | + Cityscapes -> ACDC (segmentation)
14 |
15 | ## Prerequisite
Please create and activate the following conda environment. To reproduce our results, please use this exact environment.
17 | ```bash
18 | # It may take several minutes for conda to solve the environment
19 | conda update conda
20 | conda env create -f environment.yml
21 | conda activate cotta
22 | ```
23 |
24 | ## Classification Experiments
25 | ### CIFAR10-to-CIFAR10C-standard task
26 | ```bash
27 | # Tested on RTX2080TI
28 | cd cifar
29 | # This includes the comparison of all three methods as well as baseline
30 | bash run_cifar10.sh
31 | ```
32 | ### CIFAR10-to-CIFAR10C-gradual task
33 | ```bash
34 | # Tested on RTX2080TI
35 | bash run_cifar10_gradual.sh
36 | ```
37 | ### CIFAR100-to-CIFAR100C task
38 | ```bash
39 | # Tested on RTX3090
40 | bash run_cifar100.sh
41 | ```
42 |
43 | ### ImageNet-to-ImageNetC task
44 | ```bash
45 | # Tested on RTX3090
46 | cd imagenet
47 | bash run.sh
48 | ```
49 |
50 | ## Segmentation Experiments
51 | ### Cityscapes-to-ACDC segmentation task
52 | Since April 2022, we also offer the segmentation code based on Segformer.
53 | You can download it [here](https://github.com/qinenergy/cotta/issues/6)
54 | ```
55 | ## environment setup: a new conda environment is needed for segformer
56 | ## You may also want to check https://github.com/qinenergy/cotta/issues/13 if you have problem installing mmcv
57 | conda env create -f environment_segformer.yml
58 | pip install -e . --user
59 | conda activate segformer
60 | ## Run
61 | bash run_base.sh
62 | bash run_tent.sh
63 | bash run_cotta.sh
64 | # Example logs are included in ./example_logs/base.log, tent.log, and cotta.log.
## License for Cityscapes-to-ACDC code
66 | Non-commercial. Code is heavily based on Segformer. Please also check Segformer's LICENSE.
67 | ```
68 |
69 | ## Data links
70 | + [Supplementary PDF](https://1drv.ms/b/s!At2KHTLZCWRegpAOqP-8BCBQze68wg?e=wiyaAl)
71 | + [ACDC experiment code](https://1drv.ms/u/s!At2KHTLZCWRegpAcvrh9SA34gMpzNQ?e=TSiKs6)
72 | + [Other Supplementary](https://1drv.ms/f/s!At2KHTLZCWRegpAKdcRuOGE1S9ZGLg?e=iJcR9L)
73 |
74 | ## Citation
75 | Please cite our work if you find it useful.
76 | ```bibtex
77 | @inproceedings{wang2022continual,
78 | title={Continual Test-Time Domain Adaptation},
79 | author={Wang, Qin and Fink, Olga and Van Gool, Luc and Dai, Dengxin},
80 | booktitle={Proceedings of Conference on Computer Vision and Pattern Recognition},
81 | year={2022}
82 | }
83 | ```
84 |
85 | ## Acknowledgement
86 | + TENT code is heavily used. [official](https://github.com/DequanWang/tent)
87 | + KATANA code is used for augmentation. [official](https://github.com/giladcohen/KATANA)
88 | + Robustbench [official](https://github.com/RobustBench/robustbench)
89 |
90 |
91 | ## External data link
92 | + ImageNet-C [Download](https://zenodo.org/record/2235448#.Yj2RO_co_mF)
93 |
94 | For questions regarding the code, please contact wang@qin.ee .
95 |
--------------------------------------------------------------------------------
/cifar/robustbench/leaderboard/template.py:
--------------------------------------------------------------------------------
1 | import json
2 | from argparse import ArgumentParser
3 | from pathlib import Path
4 | from typing import Union
5 |
6 | from jinja2 import Environment, PackageLoader, select_autoescape
7 |
8 | from robustbench.model_zoo.enums import BenchmarkDataset, ThreatModel
9 | from robustbench.utils import ACC_FIELDS
10 |
11 |
def generate_leaderboard(dataset: Union[str, BenchmarkDataset],
                         threat_model: Union[str, ThreatModel],
                         models_folder: str = "model_info") -> str:
    """Prints the HTML leaderboard starting from the .json results.

    The result is a that can be put directly into the RobustBench index.html page,
    and looks the same as the tables that are already existing.

    The .json results must have the same structure as the following:
    ``
    {
      "link": "https://arxiv.org/abs/2003.09461",
      "name": "Adversarial Robustness on In- and Out-Distribution Improves Explainability",
      "authors": "Maximilian Augustin, Alexander Meinke, Matthias Hein",
      "additional_data": true,
      "number_forward_passes": 1,
      "dataset": "cifar10",
      "venue": "ECCV 2020",
      "architecture": "ResNet-50",
      "eps": "0.5",
      "clean_acc": "91.08",
      "reported": "73.27",
      "autoattack_acc": "72.91"
    }
    ``

    If the model is robust to common corruptions, then the "autoattack_acc" field should be
    "corruptions_acc".

    :param dataset: The dataset of the wanted leaderboard.
    :param threat_model: The threat model of the wanted leaderboard.
    :param models_folder: The base folder of the model jsons (e.g. our "model_info" folder).

    :return: The resulting HTML table.
    """
    dataset_: BenchmarkDataset = BenchmarkDataset(dataset)
    threat_model_: ThreatModel = ThreatModel(threat_model)

    folder = Path(models_folder) / dataset_.value / threat_model_.value

    acc_field = ACC_FIELDS[threat_model_]

    models = []
    for model_path in folder.glob("*.json"):
        with open(model_path) as fp:
            models.append(json.load(fp))

    # Accuracies are stored as strings (e.g. "72.91"): sort numerically, not
    # lexicographically — otherwise e.g. "9.50" would outrank "72.91".
    models.sort(key=lambda x: float(x[acc_field]), reverse=True)

    env = Environment(loader=PackageLoader('robustbench', 'leaderboard'),
                      autoescape=select_autoescape(['html', 'xml']))

    template = env.get_template('leaderboard.html.j2')

    result = template.render(threat_model=threat_model,
                             dataset=dataset,
                             models=models,
                             acc_field=acc_field)
    print(result)
    return result
71 |
72 |
if __name__ == "__main__":
    # Build the CLI, then delegate straight to generate_leaderboard.
    parser = ArgumentParser()
    parser.add_argument("--dataset",
                        type=str,
                        default="cifar10",
                        help="The dataset of the desired leaderboard.")
    parser.add_argument("--threat_model",
                        type=str,
                        help="The threat model of the desired leaderboard.")
    parser.add_argument("--models_folder",
                        type=str,
                        default="model_info",
                        help="The base folder of the model jsons "
                             "(e.g. our 'model_info' folder)")
    args = parser.parse_args()
    generate_leaderboard(args.dataset, args.threat_model, args.models_folder)
95 |
--------------------------------------------------------------------------------
/imagenet/robustbench/leaderboard/template.py:
--------------------------------------------------------------------------------
1 | import json
2 | from argparse import ArgumentParser
3 | from pathlib import Path
4 | from typing import Union
5 |
6 | from jinja2 import Environment, PackageLoader, select_autoescape
7 |
8 | from robustbench.model_zoo.enums import BenchmarkDataset, ThreatModel
9 | from robustbench.utils import ACC_FIELDS
10 |
11 |
def generate_leaderboard(dataset: Union[str, BenchmarkDataset],
                         threat_model: Union[str, ThreatModel],
                         models_folder: str = "model_info") -> str:
    """Prints the HTML leaderboard starting from the .json results.

    The result is a that can be put directly into the RobustBench index.html page,
    and looks the same as the tables that are already existing.

    The .json results must have the same structure as the following:
    ``
    {
      "link": "https://arxiv.org/abs/2003.09461",
      "name": "Adversarial Robustness on In- and Out-Distribution Improves Explainability",
      "authors": "Maximilian Augustin, Alexander Meinke, Matthias Hein",
      "additional_data": true,
      "number_forward_passes": 1,
      "dataset": "cifar10",
      "venue": "ECCV 2020",
      "architecture": "ResNet-50",
      "eps": "0.5",
      "clean_acc": "91.08",
      "reported": "73.27",
      "autoattack_acc": "72.91"
    }
    ``

    If the model is robust to common corruptions, then the "autoattack_acc" field should be
    "corruptions_acc".

    :param dataset: The dataset of the wanted leaderboard.
    :param threat_model: The threat model of the wanted leaderboard.
    :param models_folder: The base folder of the model jsons (e.g. our "model_info" folder).

    :return: The resulting HTML table.
    """
    dataset_: BenchmarkDataset = BenchmarkDataset(dataset)
    threat_model_: ThreatModel = ThreatModel(threat_model)

    folder = Path(models_folder) / dataset_.value / threat_model_.value

    acc_field = ACC_FIELDS[threat_model_]

    models = []
    for model_path in folder.glob("*.json"):
        with open(model_path) as fp:
            models.append(json.load(fp))

    # Accuracies are stored as strings (e.g. "72.91"): sort numerically, not
    # lexicographically — otherwise e.g. "9.50" would outrank "72.91".
    models.sort(key=lambda x: float(x[acc_field]), reverse=True)

    env = Environment(loader=PackageLoader('robustbench', 'leaderboard'),
                      autoescape=select_autoescape(['html', 'xml']))

    template = env.get_template('leaderboard.html.j2')

    result = template.render(threat_model=threat_model,
                             dataset=dataset,
                             models=models,
                             acc_field=acc_field)
    print(result)
    return result
71 |
72 |
if __name__ == "__main__":
    # Build the CLI, then delegate straight to generate_leaderboard.
    parser = ArgumentParser()
    parser.add_argument("--dataset",
                        type=str,
                        default="cifar10",
                        help="The dataset of the desired leaderboard.")
    parser.add_argument("--threat_model",
                        type=str,
                        help="The threat model of the desired leaderboard.")
    parser.add_argument("--models_folder",
                        type=str,
                        default="model_info",
                        help="The base folder of the model jsons "
                             "(e.g. our 'model_info' folder)")
    args = parser.parse_args()
    generate_leaderboard(args.dataset, args.threat_model, args.models_folder)
95 |
--------------------------------------------------------------------------------
/cifar/robustbench/model_zoo/imagenet.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | from torchvision import models as pt_models
4 |
5 | from robustbench.model_zoo.enums import ThreatModel
6 | from robustbench.model_zoo.architectures.utils_architectures import normalize_model
7 |
8 |
# ImageNet per-channel normalization statistics; applied inside each model
# via `normalize_model`, so the zoo models take raw [0, 1] RGB images.
mu = (0.485, 0.456, 0.406)
sigma = (0.229, 0.224, 0.225)


# Models evaluated under the Linf threat model.  Each entry maps a
# RobustBench model id to: a factory building the input-normalized
# architecture, the Google-Drive id of its checkpoint (empty string means
# nothing to download), and the preprocessing the checkpoint expects.
linf = OrderedDict(
    [
        ('Wong2020Fast', {  # requires resolution 288 x 288
            'model': lambda: normalize_model(pt_models.resnet50(), mu, sigma),
            'gdrive_id': '1deM2ZNS5tf3S_-eRURJi-IlvUL8WJQ_w',
            'preprocessing': 'Crop288'
        }),
        ('Engstrom2019Robustness', {
            'model': lambda: normalize_model(pt_models.resnet50(), mu, sigma),
            'gdrive_id': '1T2Fvi1eCJTeAOEzrH_4TAIwO8HTOYVyn',
            'preprocessing': 'Res256Crop224',
        }),
        ('Salman2020Do_R50', {
            'model': lambda: normalize_model(pt_models.resnet50(), mu, sigma),
            'gdrive_id': '1TmT5oGa1UvVjM3d-XeSj_XmKqBNRUg8r',
            'preprocessing': 'Res256Crop224'
        }),
        ('Salman2020Do_R18', {
            'model': lambda: normalize_model(pt_models.resnet18(), mu, sigma),
            'gdrive_id': '1OThCOQCOxY6lAgxZxgiK3YuZDD7PPfPx',
            'preprocessing': 'Res256Crop224'
        }),
        ('Salman2020Do_50_2', {
            'model': lambda: normalize_model(pt_models.wide_resnet50_2(), mu, sigma),
            'gdrive_id': '1OT7xaQYljrTr3vGbM37xK9SPoPJvbSKB',
            'preprocessing': 'Res256Crop224'
        }),
        ('Standard_R50', {
            # NOTE(review): `pretrained=True` is deprecated in newer
            # torchvision (replaced by `weights=`); kept as-is here.
            'model': lambda: normalize_model(pt_models.resnet50(pretrained=True), mu, sigma),
            'gdrive_id': '',
            'preprocessing': 'Res256Crop224'
        }),
    ])

# Models evaluated on common corruptions (ImageNet-C).
common_corruptions = OrderedDict(
    [
        ('Geirhos2018_SIN', {
            'model': lambda: normalize_model(pt_models.resnet50(), mu, sigma),
            'gdrive_id': '1hLgeY_rQIaOT4R-t_KyOqPNkczfaedgs',
            'preprocessing': 'Res256Crop224'
        }),
        ('Geirhos2018_SIN_IN', {
            'model': lambda: normalize_model(pt_models.resnet50(), mu, sigma),
            'gdrive_id': '139pWopDnNERObZeLsXUysRcLg6N1iZHK',
            'preprocessing': 'Res256Crop224'
        }),
        ('Geirhos2018_SIN_IN_IN', {
            'model': lambda: normalize_model(pt_models.resnet50(), mu, sigma),
            'gdrive_id': '1xOvyuxpOZ8I5CZOi0EGYG_R6tu3ZaJdO',
            'preprocessing': 'Res256Crop224'
        }),
        ('Hendrycks2020Many', {
            'model': lambda: normalize_model(pt_models.resnet50(), mu, sigma),
            'gdrive_id': '1kylueoLtYtxkpVzoOA1B6tqdbRl2xt9X',
            'preprocessing': 'Res256Crop224'
        }),
        ('Hendrycks2020AugMix', {
            'model': lambda: normalize_model(pt_models.resnet50(), mu, sigma),
            'gdrive_id': '1xRMj1GlO93tLoCMm0e5wEvZwqhIjxhoJ',
            'preprocessing': 'Res256Crop224'
        }),
        ('Salman2020Do_50_2_Linf', {
            # Same checkpoint id as 'Salman2020Do_50_2' in the Linf table.
            'model': lambda: normalize_model(pt_models.wide_resnet50_2(), mu, sigma),
            'gdrive_id': '1OT7xaQYljrTr3vGbM37xK9SPoPJvbSKB',
            'preprocessing': 'Res256Crop224'
        }),
        ('Standard_R50', {
            'model': lambda: normalize_model(pt_models.resnet50(pretrained=True), mu, sigma),
            'gdrive_id': '',
            'preprocessing': 'Res256Crop224'
        }),
    ])

# Threat model -> model registry, consumed by the robustbench loaders.
imagenet_models = OrderedDict([(ThreatModel.Linf, linf),
                               (ThreatModel.corruptions, common_corruptions)])
88 |
89 |
90 |
--------------------------------------------------------------------------------
/imagenet/robustbench/model_zoo/imagenet.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | from torchvision import models as pt_models
4 |
5 | from robustbench.model_zoo.enums import ThreatModel
6 | from robustbench.model_zoo.architectures.utils_architectures import normalize_model
7 |
8 |
# ImageNet per-channel normalization statistics; applied inside each model
# via `normalize_model`, so the zoo models take raw [0, 1] RGB images.
mu = (0.485, 0.456, 0.406)
sigma = (0.229, 0.224, 0.225)


# Models evaluated under the Linf threat model.  Each entry maps a
# RobustBench model id to: a factory building the input-normalized
# architecture, the Google-Drive id of its checkpoint (empty string means
# nothing to download), and the preprocessing the checkpoint expects.
linf = OrderedDict(
    [
        ('Wong2020Fast', {  # requires resolution 288 x 288
            'model': lambda: normalize_model(pt_models.resnet50(), mu, sigma),
            'gdrive_id': '1deM2ZNS5tf3S_-eRURJi-IlvUL8WJQ_w',
            'preprocessing': 'Crop288'
        }),
        ('Engstrom2019Robustness', {
            'model': lambda: normalize_model(pt_models.resnet50(), mu, sigma),
            'gdrive_id': '1T2Fvi1eCJTeAOEzrH_4TAIwO8HTOYVyn',
            'preprocessing': 'Res256Crop224',
        }),
        ('Salman2020Do_R50', {
            'model': lambda: normalize_model(pt_models.resnet50(), mu, sigma),
            'gdrive_id': '1TmT5oGa1UvVjM3d-XeSj_XmKqBNRUg8r',
            'preprocessing': 'Res256Crop224'
        }),
        ('Salman2020Do_R18', {
            'model': lambda: normalize_model(pt_models.resnet18(), mu, sigma),
            'gdrive_id': '1OThCOQCOxY6lAgxZxgiK3YuZDD7PPfPx',
            'preprocessing': 'Res256Crop224'
        }),
        ('Salman2020Do_50_2', {
            'model': lambda: normalize_model(pt_models.wide_resnet50_2(), mu, sigma),
            'gdrive_id': '1OT7xaQYljrTr3vGbM37xK9SPoPJvbSKB',
            'preprocessing': 'Res256Crop224'
        }),
        ('Standard_R50', {
            # NOTE(review): `pretrained=True` is deprecated in newer
            # torchvision (replaced by `weights=`); kept as-is here.
            'model': lambda: normalize_model(pt_models.resnet50(pretrained=True), mu, sigma),
            'gdrive_id': '',
            'preprocessing': 'Res256Crop224'
        }),
    ])

# Models evaluated on common corruptions (ImageNet-C).
common_corruptions = OrderedDict(
    [
        ('Geirhos2018_SIN', {
            'model': lambda: normalize_model(pt_models.resnet50(), mu, sigma),
            'gdrive_id': '1hLgeY_rQIaOT4R-t_KyOqPNkczfaedgs',
            'preprocessing': 'Res256Crop224'
        }),
        ('Geirhos2018_SIN_IN', {
            'model': lambda: normalize_model(pt_models.resnet50(), mu, sigma),
            'gdrive_id': '139pWopDnNERObZeLsXUysRcLg6N1iZHK',
            'preprocessing': 'Res256Crop224'
        }),
        ('Geirhos2018_SIN_IN_IN', {
            'model': lambda: normalize_model(pt_models.resnet50(), mu, sigma),
            'gdrive_id': '1xOvyuxpOZ8I5CZOi0EGYG_R6tu3ZaJdO',
            'preprocessing': 'Res256Crop224'
        }),
        ('Hendrycks2020Many', {
            'model': lambda: normalize_model(pt_models.resnet50(), mu, sigma),
            'gdrive_id': '1kylueoLtYtxkpVzoOA1B6tqdbRl2xt9X',
            'preprocessing': 'Res256Crop224'
        }),
        ('Hendrycks2020AugMix', {
            'model': lambda: normalize_model(pt_models.resnet50(), mu, sigma),
            'gdrive_id': '1xRMj1GlO93tLoCMm0e5wEvZwqhIjxhoJ',
            'preprocessing': 'Res256Crop224'
        }),
        ('Salman2020Do_50_2_Linf', {
            # Same checkpoint id as 'Salman2020Do_50_2' in the Linf table.
            'model': lambda: normalize_model(pt_models.wide_resnet50_2(), mu, sigma),
            'gdrive_id': '1OT7xaQYljrTr3vGbM37xK9SPoPJvbSKB',
            'preprocessing': 'Res256Crop224'
        }),
        ('Standard_R50', {
            'model': lambda: normalize_model(pt_models.resnet50(pretrained=True), mu, sigma),
            'gdrive_id': '',
            'preprocessing': 'Res256Crop224'
        }),
    ])

# Threat model -> model registry, consumed by the robustbench loaders.
imagenet_models = OrderedDict([(ThreatModel.Linf, linf),
                               (ThreatModel.corruptions, common_corruptions)])
88 |
89 |
90 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: cotta
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - _libgcc_mutex=0.1=main
7 | - _openmp_mutex=4.5=1_gnu
8 | - blas=1.0=mkl
9 | - bzip2=1.0.8=h7b6447c_0
10 | - ca-certificates=2021.10.26=h06a4308_2
11 | - certifi=2021.10.8=py39h06a4308_0
12 | - colorama=0.4.4=pyhd3eb1b0_0
13 | - cudatoolkit=11.3.1=h2bc3f7f_2
14 | - ffmpeg=4.3=hf484d3e_0
15 | - freetype=2.10.4=h5ab3b9f_0
16 | - giflib=5.2.1=h7b6447c_0
17 | - gmp=6.2.1=h2531618_2
18 | - gnutls=3.6.15=he1e5248_0
19 | - intel-openmp=2021.3.0=h06a4308_3350
20 | - jpeg=9d=h7f8727e_0
21 | - lame=3.100=h7b6447c_0
22 | - lcms2=2.12=h3be6417_0
23 | - ld_impl_linux-64=2.35.1=h7274673_9
24 | - libffi=3.3=he6710b0_2
25 | - libgcc-ng=9.3.0=h5101ec6_17
26 | - libgomp=9.3.0=h5101ec6_17
27 | - libiconv=1.15=h63c8f33_5
28 | - libidn2=2.3.2=h7f8727e_0
29 | - libpng=1.6.37=hbc83047_0
30 | - libstdcxx-ng=9.3.0=hd4cf53a_17
31 | - libtasn1=4.16.0=h27cfd23_0
32 | - libtiff=4.2.0=h85742a9_0
33 | - libunistring=0.9.10=h27cfd23_0
34 | - libuv=1.40.0=h7b6447c_0
35 | - libwebp=1.2.0=h89dd481_0
36 | - libwebp-base=1.2.0=h27cfd23_0
37 | - lz4-c=1.9.3=h295c915_1
38 | - mkl=2021.3.0=h06a4308_520
39 | - mkl-service=2.4.0=py39h7f8727e_0
40 | - mkl_fft=1.3.1=py39hd3c417c_0
41 | - mkl_random=1.2.2=py39h51133e4_0
42 | - ncurses=6.2=he6710b0_1
43 | - nettle=3.7.3=hbbd107a_1
44 | - olefile=0.46=pyhd3eb1b0_0
45 | - openh264=2.1.0=hd408876_0
46 | - openssl=1.1.1l=h7f8727e_0
47 | - pillow=8.4.0=py39h5aabda8_0
48 | - pip=21.2.4=py39h06a4308_0
49 | - python=3.9.7=h12debd9_1
50 | - pytorch=1.10.0=py3.9_cuda11.3_cudnn8.2.0_0
51 | - pytorch-mutex=1.0=cuda
52 | - readline=8.1=h27cfd23_0
53 | - setuptools=58.0.4=py39h06a4308_0
54 | - six=1.16.0=pyhd3eb1b0_0
55 | - sqlite=3.36.0=hc218d9a_0
56 | - tk=8.6.11=h1ccaba5_0
57 | - torchaudio=0.10.0=py39_cu113
58 | - torchvision=0.11.1=py39_cu113
59 | - typing_extensions=3.10.0.2=pyh06a4308_0
60 | - tzdata=2021a=h5d7bf9c_0
61 | - wheel=0.37.0=pyhd3eb1b0_1
62 | - xz=5.2.5=h7b6447c_0
63 | - zlib=1.2.11=h7b6447c_3
64 | - zstd=1.4.9=haebb681_0
65 | - pip:
66 | - addict==2.4.0
67 | - appdirs==1.4.4
68 | - attr==0.3.1
69 | - attrs==21.2.0
70 | - backcall==0.2.0
71 | - black==19.3b0
72 | - chardet==3.0.4
73 | - click==8.0.3
74 | - cloudpickle==2.0.0
75 | - decorator==5.1.0
76 | - flake8==4.0.1
77 | - idna==2.10
78 | - imagecorruptions==1.1.2
79 | - imageio==2.10.1
80 | - iopath==0.1.9
81 | - ipython==7.28.0
82 | - isort==4.3.21
83 | - jedi==0.18.0
84 | - jinja2==2.11.3
85 | - joblib==1.1.0
86 | - markupsafe==2.0.1
87 | - matplotlib-inline==0.1.3
88 | - mccabe==0.6.1
89 | - networkx==2.6.3
90 | - numpy==1.19.5
91 | - packaging==21.0
92 | - parameterized==0.8.1
93 | - parso==0.8.2
94 | - pexpect==4.8.0
95 | - pickleshare==0.7.5
96 | - portalocker==2.3.2
97 | - prompt-toolkit==3.0.21
98 | - ptyprocess==0.7.0
99 | - pycodestyle==2.8.0
100 | - pyflakes==2.4.0
101 | - pygments==2.10.0
102 | - pytorchcv==0.0.67
103 | - pytz==2021.3
104 | - pywavelets==1.1.1
105 | - pyyaml==6.0
106 | - requests==2.23.0
107 | - git+https://github.com/robustbench/robustbench@v0.1#egg=robustbench
108 | - scikit-image==0.18.3
109 | - scikit-learn==1.0.1
110 | - scipy==1.7.1
111 | - simplejson==3.17.5
112 | - submitit==1.4.0
113 | - threadpoolctl==3.0.0
114 | - tifffile==2021.10.12
115 | - timm==0.4.12
116 | - toml==0.10.2
117 | - tqdm==4.56.2
118 | - traitlets==5.1.0
119 | - urllib3==1.25.11
120 | - wcwidth==0.2.5
121 | - yacs==0.1.8
122 | - yapf==0.31.0
123 | - yattag==1.14.0
124 | - gdown==5.1.0
125 |
--------------------------------------------------------------------------------
/cifar/robustbench/model_zoo/architectures/wide_resnet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
class BasicBlock(nn.Module):
    """Pre-activation residual block (BN -> ReLU -> conv, twice), with an
    optional 1x1 projection on the shortcut when the shape changes."""

    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate
        self.equalInOut = (in_planes == out_planes)
        # 1x1 projection shortcut, needed only when the residual branch
        # changes the channel count (and possibly the resolution).  A
        # conditional expression replaces the fragile `cond and a or b` idiom.
        self.convShortcut = (None if self.equalInOut else
                             nn.Conv2d(in_planes, out_planes, kernel_size=1,
                                       stride=stride, padding=0, bias=False))

    def forward(self, x):
        if not self.equalInOut:
            # Shared pre-activation: reused by both the residual branch
            # (via `x` below) and the projection shortcut.
            x = self.relu1(self.bn1(x))
        else:
            out = self.relu1(self.bn1(x))
        out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        out = self.conv2(out)
        # Residual addition: identity when shapes match, 1x1 conv otherwise.
        return torch.add(x if self.equalInOut else self.convShortcut(x), out)
33 |
34 |
class NetworkBlock(nn.Module):
    """A stage of `nb_layers` residual blocks; only the first block may change
    the channel count and apply the stride."""

    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
        super(NetworkBlock, self).__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)

    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
        layers = []
        for i in range(int(nb_layers)):
            # First block adapts channels and applies the stride; the rest map
            # out_planes -> out_planes with stride 1.  Conditional expressions
            # replace the fragile `cond and a or b` idiom.
            layers.append(block(in_planes if i == 0 else out_planes,
                                out_planes,
                                stride if i == 0 else 1,
                                dropRate))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.layer(x)
48 |
49 |
class WideResNet(nn.Module):
    """Wide residual network (WRN-depth-widen_factor) for 32x32 inputs.

    Based on code from https://github.com/yaodongyu/TRADES
    """
    def __init__(self, depth=28, num_classes=10, widen_factor=10, sub_block1=False, dropRate=0.0, bias_last=True):
        super(WideResNet, self).__init__()
        # Channel widths of the stem and the three residual stages.
        nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
        # Depth must be 6n + 4: n blocks per stage, 2 convs per block.
        assert ((depth - 4) % 6 == 0)
        n = (depth - 4) / 6
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # 1st block
        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
        if sub_block1:
            # 1st sub-block
            # NOTE(review): sub_block1 is never used in forward() and shares
            # no parameters with block1 — presumably kept for checkpoint
            # compatibility; confirm before removing.
            self.sub_block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
        # 2nd block
        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
        # 3rd block
        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(nChannels[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(nChannels[3], num_classes, bias=bias_last)
        self.nChannels = nChannels[3]

        # He-style init for convs; BN starts as identity; linear bias zeroed.
        # (This loop rebinds `n` — intentional only because `n` is no longer
        # needed after the stages are built.)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear) and not m.bias is None:
                m.bias.data.zero_()

    def forward(self, x):
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.relu(self.bn1(out))
        # The fixed 8x8 pooling window assumes 32x32 inputs (CIFAR);
        # other resolutions would need a different window.
        out = F.avg_pool2d(out, 8)
        out = out.view(-1, self.nChannels)
        return self.fc(out)
95 |
96 |
--------------------------------------------------------------------------------
/imagenet/robustbench/model_zoo/architectures/wide_resnet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
class BasicBlock(nn.Module):
    """Pre-activation residual block (BN -> ReLU -> conv, twice), with an
    optional 1x1 projection on the shortcut when the shape changes."""

    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate
        self.equalInOut = (in_planes == out_planes)
        # 1x1 projection shortcut, needed only when the residual branch
        # changes the channel count (and possibly the resolution).  A
        # conditional expression replaces the fragile `cond and a or b` idiom.
        self.convShortcut = (None if self.equalInOut else
                             nn.Conv2d(in_planes, out_planes, kernel_size=1,
                                       stride=stride, padding=0, bias=False))

    def forward(self, x):
        if not self.equalInOut:
            # Shared pre-activation: reused by both the residual branch
            # (via `x` below) and the projection shortcut.
            x = self.relu1(self.bn1(x))
        else:
            out = self.relu1(self.bn1(x))
        out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        out = self.conv2(out)
        # Residual addition: identity when shapes match, 1x1 conv otherwise.
        return torch.add(x if self.equalInOut else self.convShortcut(x), out)
33 |
34 |
class NetworkBlock(nn.Module):
    """A stage of `nb_layers` residual blocks; only the first block may change
    the channel count and apply the stride."""

    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
        super(NetworkBlock, self).__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)

    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
        layers = []
        for i in range(int(nb_layers)):
            # First block adapts channels and applies the stride; the rest map
            # out_planes -> out_planes with stride 1.  Conditional expressions
            # replace the fragile `cond and a or b` idiom.
            layers.append(block(in_planes if i == 0 else out_planes,
                                out_planes,
                                stride if i == 0 else 1,
                                dropRate))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.layer(x)
48 |
49 |
class WideResNet(nn.Module):
    """Wide residual network (WRN-depth-widen_factor) for 32x32 inputs.

    Based on code from https://github.com/yaodongyu/TRADES
    """
    def __init__(self, depth=28, num_classes=10, widen_factor=10, sub_block1=False, dropRate=0.0, bias_last=True):
        super(WideResNet, self).__init__()
        # Channel widths of the stem and the three residual stages.
        nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
        # Depth must be 6n + 4: n blocks per stage, 2 convs per block.
        assert ((depth - 4) % 6 == 0)
        n = (depth - 4) / 6
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # 1st block
        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
        if sub_block1:
            # 1st sub-block
            # NOTE(review): sub_block1 is never used in forward() and shares
            # no parameters with block1 — presumably kept for checkpoint
            # compatibility; confirm before removing.
            self.sub_block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
        # 2nd block
        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
        # 3rd block
        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(nChannels[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(nChannels[3], num_classes, bias=bias_last)
        self.nChannels = nChannels[3]

        # He-style init for convs; BN starts as identity; linear bias zeroed.
        # (This loop rebinds `n` — intentional only because `n` is no longer
        # needed after the stages are built.)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear) and not m.bias is None:
                m.bias.data.zero_()

    def forward(self, x):
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.relu(self.bn1(out))
        # The fixed 8x8 pooling window assumes 32x32 inputs (CIFAR);
        # other resolutions would need a different window.
        out = F.avg_pool2d(out, 8)
        out = out.view(-1, self.nChannels)
        return self.fc(out)
95 |
96 |
--------------------------------------------------------------------------------
/cifar/tent.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.jit
6 |
7 |
class Tent(nn.Module):
    """Tent adapts a model by entropy minimization during testing.

    Once tented, a model adapts itself by updating on every forward.
    """

    def __init__(self, model, optimizer, steps=1, episodic=False):
        super().__init__()
        self.model = model
        self.optimizer = optimizer
        self.steps = steps
        assert steps > 0, "tent requires >= 1 step(s) to forward and update"
        self.episodic = episodic

        # note: if the model is never reset, like for continual adaptation,
        # then skipping the state copy would save memory
        snapshot = copy_model_and_optimizer(self.model, self.optimizer)
        self.model_state, self.optimizer_state = snapshot

    def forward(self, x):
        """Adapt on batch ``x`` for ``self.steps`` updates and return logits."""
        if self.episodic:
            self.reset()

        outputs = None
        for _ in range(self.steps):
            outputs = forward_and_adapt(x, self.model, self.optimizer)

        return outputs

    def reset(self):
        """Restore the model/optimizer snapshot taken at construction time."""
        if self.model_state is None or self.optimizer_state is None:
            raise Exception("cannot reset without saved model/optimizer state")
        load_model_and_optimizer(self.model, self.optimizer,
                                 self.model_state, self.optimizer_state)
40 |
41 |
@torch.jit.script
def softmax_entropy(x: torch.Tensor) -> torch.Tensor:
    """Entropy of the softmax distribution defined by logits ``x`` (per row)."""
    probs = x.softmax(1)
    log_probs = x.log_softmax(1)
    return -(probs * log_probs).sum(1)
46 |
47 |
@torch.enable_grad()  # ensure grads in possible no grad context for testing
def forward_and_adapt(x, model, optimizer):
    """Forward and adapt model on batch of data.

    Measure entropy of the model prediction, take gradients, and update params.

    Returns the logits computed *before* the optimizer step.
    """
    # forward
    outputs = model(x)
    # adapt
    loss = softmax_entropy(outputs).mean(0)
    loss.backward()
    optimizer.step()
    # grads are cleared after stepping so the next batch starts fresh
    optimizer.zero_grad()
    return outputs
62 |
63 |
def collect_params(model):
    """Collect the affine scale + shift parameters from batch norms.

    Walk the model's modules and collect all batch normalization parameters.
    Return the parameters and their names.

    Note: other choices of parameterization are possible!
    """
    params, names = [], []
    for module_name, module in model.named_modules():
        if not isinstance(module, nn.BatchNorm2d):
            continue
        for param_name, param in module.named_parameters():
            # weight is scale (gamma), bias is shift (beta)
            if param_name in ('weight', 'bias'):
                params.append(param)
                names.append(f"{module_name}.{param_name}")
    return params, names
81 |
82 |
def copy_model_and_optimizer(model, optimizer):
    """Snapshot model and optimizer state dicts (deep copies) for later reset."""
    return deepcopy(model.state_dict()), deepcopy(optimizer.state_dict())
88 |
89 |
def load_model_and_optimizer(model, optimizer, model_state, optimizer_state):
    """Restore model and optimizer from previously captured state dicts."""
    optimizer.load_state_dict(optimizer_state)
    # strict=True: every key must match, catching architecture drift early
    model.load_state_dict(model_state, strict=True)
94 |
95 |
def configure_model(model):
    """Configure model for use with tent.

    Puts the model in train mode, freezes everything, then re-enables
    gradients only on BatchNorm2d layers and forces them to use batch
    statistics in both train and eval modes.
    """
    model.train()
    model.requires_grad_(False)
    for module in model.modules():
        if not isinstance(module, nn.BatchNorm2d):
            continue
        module.requires_grad_(True)
        # dropping the running stats forces per-batch normalization
        module.track_running_stats = False
        module.running_mean = None
        module.running_var = None
    return model
111 |
112 |
def check_model(model):
    """Check model for compatibility with tent.

    Tent requires: train mode, some (but not all) params trainable,
    and at least one batch-norm layer to modulate.
    """
    is_training = model.training
    assert is_training, "tent needs train mode: call model.train()"
    param_grads = [p.requires_grad for p in model.parameters()]
    has_any_params = any(param_grads)
    has_all_params = all(param_grads)
    assert has_any_params, "tent needs params to update: " \
                           "check which require grad"
    assert not has_all_params, "tent should not update all params: " \
                               "check which require grad"
    # generator expression avoids building a throwaway list
    has_bn = any(isinstance(m, nn.BatchNorm2d) for m in model.modules())
    assert has_bn, "tent needs normalization for its optimization"
126 |
--------------------------------------------------------------------------------
/imagenet/tent.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.jit
6 |
7 |
class Tent(nn.Module):
    """Tent adapts a model by entropy minimization during testing.

    Once tented, a model adapts itself by updating on every forward.
    """

    def __init__(self, model, optimizer, steps=1, episodic=False):
        super().__init__()
        self.model = model
        self.optimizer = optimizer
        self.steps = steps
        assert steps > 0, "tent requires >= 1 step(s) to forward and update"
        self.episodic = episodic

        # note: if the model is never reset, like for continual adaptation,
        # then skipping the state copy would save memory
        snapshot = copy_model_and_optimizer(self.model, self.optimizer)
        self.model_state, self.optimizer_state = snapshot

    def forward(self, x):
        """Adapt on batch ``x`` for ``self.steps`` updates and return logits."""
        if self.episodic:
            self.reset()

        outputs = None
        for _ in range(self.steps):
            outputs = forward_and_adapt(x, self.model, self.optimizer)

        return outputs

    def reset(self):
        """Restore the model/optimizer snapshot taken at construction time."""
        if self.model_state is None or self.optimizer_state is None:
            raise Exception("cannot reset without saved model/optimizer state")
        load_model_and_optimizer(self.model, self.optimizer,
                                 self.model_state, self.optimizer_state)
40 |
41 |
@torch.jit.script
def softmax_entropy(x: torch.Tensor) -> torch.Tensor:
    """Entropy of the softmax distribution defined by logits ``x`` (per row)."""
    probs = x.softmax(1)
    log_probs = x.log_softmax(1)
    return -(probs * log_probs).sum(1)
46 |
47 |
@torch.enable_grad()  # ensure grads in possible no grad context for testing
def forward_and_adapt(x, model, optimizer):
    """Forward and adapt model on batch of data.

    Measure entropy of the model prediction, take gradients, and update params.

    Returns the logits computed *before* the optimizer step.
    """
    # forward
    outputs = model(x)
    # adapt
    loss = softmax_entropy(outputs).mean(0)
    loss.backward()
    optimizer.step()
    # grads are cleared after stepping so the next batch starts fresh
    optimizer.zero_grad()
    return outputs
62 |
63 |
def collect_params(model):
    """Collect the affine scale + shift parameters from batch norms.

    Walk the model's modules and collect all batch normalization parameters.
    Return the parameters and their names.

    Note: other choices of parameterization are possible!
    """
    params, names = [], []
    for module_name, module in model.named_modules():
        if not isinstance(module, nn.BatchNorm2d):
            continue
        for param_name, param in module.named_parameters():
            # weight is scale (gamma), bias is shift (beta)
            if param_name in ('weight', 'bias'):
                params.append(param)
                names.append(f"{module_name}.{param_name}")
    return params, names
81 |
82 |
def copy_model_and_optimizer(model, optimizer):
    """Snapshot model and optimizer state dicts (deep copies) for later reset."""
    return deepcopy(model.state_dict()), deepcopy(optimizer.state_dict())
88 |
89 |
def load_model_and_optimizer(model, optimizer, model_state, optimizer_state):
    """Restore model and optimizer from previously captured state dicts."""
    optimizer.load_state_dict(optimizer_state)
    # strict=True: every key must match, catching architecture drift early
    model.load_state_dict(model_state, strict=True)
94 |
95 |
def configure_model(model):
    """Configure model for use with tent.

    Puts the model in train mode, freezes everything, then re-enables
    gradients only on BatchNorm2d layers and forces them to use batch
    statistics in both train and eval modes.
    """
    model.train()
    model.requires_grad_(False)
    for module in model.modules():
        if not isinstance(module, nn.BatchNorm2d):
            continue
        module.requires_grad_(True)
        # dropping the running stats forces per-batch normalization
        module.track_running_stats = False
        module.running_mean = None
        module.running_var = None
    return model
111 |
112 |
def check_model(model):
    """Check model for compatibility with tent.

    Tent requires: train mode, some (but not all) params trainable,
    and at least one batch-norm layer to modulate.
    """
    is_training = model.training
    assert is_training, "tent needs train mode: call model.train()"
    param_grads = [p.requires_grad for p in model.parameters()]
    has_any_params = any(param_grads)
    has_all_params = all(param_grads)
    assert has_any_params, "tent needs params to update: " \
                           "check which require grad"
    assert not has_all_params, "tent should not update all params: " \
                               "check which require grad"
    # generator expression avoids building a throwaway list
    has_bn = any(isinstance(m, nn.BatchNorm2d) for m in model.modules())
    assert has_bn, "tent needs normalization for its optimization"
126 |
--------------------------------------------------------------------------------
/imagenet/my_transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms.functional as F
3 | from torchvision.transforms import ColorJitter, Compose, Lambda
4 | from numpy import random
5 |
class GaussianNoise(torch.nn.Module):
    """Add i.i.d. Gaussian noise with fixed mean/std to a tensor image."""

    def __init__(self, mean=0., std=1.):
        super().__init__()
        self.std = std
        self.mean = mean

    def forward(self, img):
        # Draw the noise on CPU, then move it to wherever the image lives.
        perturbation = torch.randn(img.size()) * self.std + self.mean
        return img + perturbation.to(img.device)

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
19 |
class Clip(torch.nn.Module):
    """Clamp tensor values into the closed interval [min_val, max_val]."""

    def __init__(self, min_val=0., max_val=1.):
        super().__init__()
        self.min_val = min_val
        self.max_val = max_val

    def forward(self, img):
        # torch.clamp is the canonical spelling; torch.clip is its alias.
        return torch.clamp(img, self.min_val, self.max_val)

    def __repr__(self):
        return self.__class__.__name__ + '(min_val={0}, max_val={1})'.format(self.min_val, self.max_val)
31 |
class ColorJitterPro(ColorJitter):
    """Randomly change the brightness, contrast, saturation, hue, and gamma of an image."""

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, gamma=0):
        super().__init__(brightness, contrast, saturation, hue)
        # gamma is validated the same way ColorJitter validates its own factors
        self.gamma = self._check_input(gamma, 'gamma')

    @staticmethod
    @torch.jit.unused
    def get_params(brightness, contrast, saturation, hue, gamma):
        """Get a randomized transform to be applied on image.

        Arguments are same as that of __init__.

        Returns:
            Transform which randomly adjusts brightness, contrast,
            saturation, hue, and gamma in a random order.
        """
        transforms = []

        if brightness is not None:
            brightness_factor = random.uniform(brightness[0], brightness[1])
            transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor)))

        if contrast is not None:
            contrast_factor = random.uniform(contrast[0], contrast[1])
            transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor)))

        if saturation is not None:
            saturation_factor = random.uniform(saturation[0], saturation[1])
            transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor)))

        if hue is not None:
            hue_factor = random.uniform(hue[0], hue[1])
            transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor)))

        if gamma is not None:
            gamma_factor = random.uniform(gamma[0], gamma[1])
            transforms.append(Lambda(lambda img: F.adjust_gamma(img, gamma_factor)))

        random.shuffle(transforms)
        transform = Compose(transforms)

        return transform

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Input image.

        Returns:
            PIL Image or Tensor: Color jittered image.
        """
        # Apply the five adjustments in a fresh random order per call.
        fn_idx = torch.randperm(5)
        for fn_id in fn_idx:
            if fn_id == 0 and self.brightness is not None:
                brightness = self.brightness
                brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
                img = F.adjust_brightness(img, brightness_factor)

            if fn_id == 1 and self.contrast is not None:
                contrast = self.contrast
                contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
                img = F.adjust_contrast(img, contrast_factor)

            if fn_id == 2 and self.saturation is not None:
                saturation = self.saturation
                saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
                img = F.adjust_saturation(img, saturation_factor)

            if fn_id == 3 and self.hue is not None:
                hue = self.hue
                hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
                img = F.adjust_hue(img, hue_factor)

            if fn_id == 4 and self.gamma is not None:
                gamma = self.gamma
                gamma_factor = torch.tensor(1.0).uniform_(gamma[0], gamma[1]).item()
                img = img.clamp(1e-8, 1.0)  # to fix Nan values in gradients, which happens when applying gamma
                                            # after contrast
                img = F.adjust_gamma(img, gamma_factor)

        return img

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        format_string += 'brightness={0}'.format(self.brightness)
        format_string += ', contrast={0}'.format(self.contrast)
        format_string += ', saturation={0}'.format(self.saturation)
        # fixed: the hue segment had a stray ')' that closed the repr early
        format_string += ', hue={0}'.format(self.hue)
        format_string += ', gamma={0})'.format(self.gamma)
        return format_string
124 |
--------------------------------------------------------------------------------
/cifar/my_transforms.py:
--------------------------------------------------------------------------------
1 | # KATANA: Simple Post-Training Robustness Using Test Time Augmentations
2 | # https://arxiv.org/pdf/2109.08191v1.pdf
3 | import torch
4 | import torchvision.transforms.functional as F
5 | from torchvision.transforms import ColorJitter, Compose, Lambda
6 | from numpy import random
7 |
class GaussianNoise(torch.nn.Module):
    """Add i.i.d. Gaussian noise with fixed mean/std to a tensor image."""

    def __init__(self, mean=0., std=1.):
        super().__init__()
        self.std = std
        self.mean = mean

    def forward(self, img):
        # Draw the noise on CPU, then move it to wherever the image lives.
        perturbation = torch.randn(img.size()) * self.std + self.mean
        return img + perturbation.to(img.device)

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
21 |
class Clip(torch.nn.Module):
    """Clamp tensor values into the closed interval [min_val, max_val]."""

    def __init__(self, min_val=0., max_val=1.):
        super().__init__()
        self.min_val = min_val
        self.max_val = max_val

    def forward(self, img):
        # torch.clamp is the canonical spelling; torch.clip is its alias.
        return torch.clamp(img, self.min_val, self.max_val)

    def __repr__(self):
        return self.__class__.__name__ + '(min_val={0}, max_val={1})'.format(self.min_val, self.max_val)
33 |
class ColorJitterPro(ColorJitter):
    """Randomly change the brightness, contrast, saturation, hue, and gamma of an image."""

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, gamma=0):
        super().__init__(brightness, contrast, saturation, hue)
        # gamma is validated the same way ColorJitter validates its own factors
        self.gamma = self._check_input(gamma, 'gamma')

    @staticmethod
    @torch.jit.unused
    def get_params(brightness, contrast, saturation, hue, gamma):
        """Get a randomized transform to be applied on image.

        Arguments are same as that of __init__.

        Returns:
            Transform which randomly adjusts brightness, contrast,
            saturation, hue, and gamma in a random order.
        """
        transforms = []

        if brightness is not None:
            brightness_factor = random.uniform(brightness[0], brightness[1])
            transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor)))

        if contrast is not None:
            contrast_factor = random.uniform(contrast[0], contrast[1])
            transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor)))

        if saturation is not None:
            saturation_factor = random.uniform(saturation[0], saturation[1])
            transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor)))

        if hue is not None:
            hue_factor = random.uniform(hue[0], hue[1])
            transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor)))

        if gamma is not None:
            gamma_factor = random.uniform(gamma[0], gamma[1])
            transforms.append(Lambda(lambda img: F.adjust_gamma(img, gamma_factor)))

        random.shuffle(transforms)
        transform = Compose(transforms)

        return transform

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Input image.

        Returns:
            PIL Image or Tensor: Color jittered image.
        """
        # Apply the five adjustments in a fresh random order per call.
        fn_idx = torch.randperm(5)
        for fn_id in fn_idx:
            if fn_id == 0 and self.brightness is not None:
                brightness = self.brightness
                brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
                img = F.adjust_brightness(img, brightness_factor)

            if fn_id == 1 and self.contrast is not None:
                contrast = self.contrast
                contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
                img = F.adjust_contrast(img, contrast_factor)

            if fn_id == 2 and self.saturation is not None:
                saturation = self.saturation
                saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
                img = F.adjust_saturation(img, saturation_factor)

            if fn_id == 3 and self.hue is not None:
                hue = self.hue
                hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
                img = F.adjust_hue(img, hue_factor)

            if fn_id == 4 and self.gamma is not None:
                gamma = self.gamma
                gamma_factor = torch.tensor(1.0).uniform_(gamma[0], gamma[1]).item()
                img = img.clamp(1e-8, 1.0)  # to fix Nan values in gradients, which happens when applying gamma
                                            # after contrast
                img = F.adjust_gamma(img, gamma_factor)

        return img

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        format_string += 'brightness={0}'.format(self.brightness)
        format_string += ', contrast={0}'.format(self.contrast)
        format_string += ', saturation={0}'.format(self.saturation)
        # fixed: the hue segment had a stray ')' that closed the repr early
        format_string += ', hue={0}'.format(self.hue)
        format_string += ', gamma={0})'.format(self.gamma)
        return format_string
126 |
--------------------------------------------------------------------------------
/imagenet/imagenetc.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch
4 | import torch.optim as optim
5 |
6 | from robustbench.data import load_imagenetc
7 | from robustbench.model_zoo.enums import ThreatModel
8 | from robustbench.utils import load_model
9 | from robustbench.utils import clean_accuracy as accuracy
10 |
11 | import tent
12 | import norm
13 | import cotta
14 |
15 | from conf import cfg, load_cfg_fom_args
16 |
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 |
def evaluate(description):
    """Run the configured test-time adaptation method over ImageNet-C.

    Iterates severities x corruption types, resets the adapted model at the
    start of each corruption sweep (when the method supports reset), and
    logs the error rate for every combination.
    """
    load_cfg_fom_args(description)
    # configure model
    base_model = load_model(cfg.MODEL.ARCH, cfg.CKPT_DIR,
                            cfg.CORRUPTION.DATASET, ThreatModel.corruptions).cuda()
    if cfg.MODEL.ADAPTATION == "source":
        logger.info("test-time adaptation: NONE")
        model = setup_source(base_model)
    if cfg.MODEL.ADAPTATION == "norm":
        logger.info("test-time adaptation: NORM")
        model = setup_norm(base_model)
    if cfg.MODEL.ADAPTATION == "tent":
        logger.info("test-time adaptation: TENT")
        model = setup_tent(base_model)
    if cfg.MODEL.ADAPTATION == "cotta":
        logger.info("test-time adaptation: CoTTA")
        model = setup_cotta(base_model)
    # evaluate on each severity and type of corruption in turn
    for severity in cfg.CORRUPTION.SEVERITY:
        for i_x, corruption_type in enumerate(cfg.CORRUPTION.TYPE):
            # reset adaptation for each combination of corruption x severity
            # note: for evaluation protocol, but not necessarily needed
            try:
                if i_x == 0:
                    model.reset()
                    logger.info("resetting model")
                else:
                    logger.warning("not resetting model")
            except Exception:
                # narrowed from a bare ``except:`` so KeyboardInterrupt and
                # SystemExit still propagate; source/norm models have no reset()
                logger.warning("not resetting model")
            x_test, y_test = load_imagenetc(cfg.CORRUPTION.NUM_EX,
                                            severity, cfg.DATA_DIR, False,
                                            [corruption_type])
            x_test, y_test = x_test.cuda(), y_test.cuda()
            acc = accuracy(model, x_test, y_test, cfg.TEST.BATCH_SIZE)
            err = 1. - acc
            logger.info(f"error % [{corruption_type}{severity}]: {err:.2%}")
59 |
60 |
def setup_source(model):
    """Set up the baseline source model without adaptation."""
    model.eval()
    # plain %-style lazy formatting; the old f-string prefix was a no-op
    logger.info("model for evaluation: %s", model)
    return model
66 |
67 |
def setup_norm(model):
    """Set up test-time normalization adaptation.

    Adapt by normalizing features with test batch statistics.
    The statistics are measured independently for each batch;
    no running average or other cross-batch estimation is used.
    """
    norm_model = norm.Norm(model)
    # plain %-style lazy formatting; the old f-string prefixes were no-ops
    logger.info("model for adaptation: %s", model)
    _, stat_names = norm.collect_stats(model)
    logger.info("stats for adaptation: %s", stat_names)
    return norm_model
80 |
81 |
def setup_tent(model):
    """Set up tent adaptation.

    Configure the model for training + feature modulation by batch statistics,
    collect the parameters for feature modulation by gradient optimization,
    set up the optimizer, and then tent the model.
    """
    model = tent.configure_model(model)
    params, param_names = tent.collect_params(model)
    optimizer = setup_optimizer(params)
    tent_model = tent.Tent(model, optimizer,
                           steps=cfg.OPTIM.STEPS,
                           episodic=cfg.MODEL.EPISODIC)
    # plain %-style lazy formatting; the old f-string prefixes were no-ops
    logger.info("model for adaptation: %s", model)
    logger.info("params for adaptation: %s", param_names)
    logger.info("optimizer for adaptation: %s", optimizer)
    return tent_model
99 |
100 |
def setup_optimizer(params):
    """Set up optimizer for tent adaptation.

    Tent needs an optimizer for test-time entropy minimization.
    In principle, tent could make use of any gradient optimizer.
    In practice, we advise choosing Adam or SGD+momentum.
    For optimization settings, we advise to use the settings from the end of
    training, if known, or start with a low learning rate (like 0.001) if not.

    For best results, try tuning the learning rate and batch size.

    Raises:
        NotImplementedError: if cfg.OPTIM.METHOD is neither 'Adam' nor 'SGD'.
    """
    if cfg.OPTIM.METHOD == 'Adam':
        return optim.Adam(params,
                          lr=cfg.OPTIM.LR,
                          betas=(cfg.OPTIM.BETA, 0.999),
                          weight_decay=cfg.OPTIM.WD)
    elif cfg.OPTIM.METHOD == 'SGD':
        return optim.SGD(params,
                         lr=cfg.OPTIM.LR,
                         momentum=0.9,
                         dampening=0,
                         weight_decay=cfg.OPTIM.WD,
                         nesterov=True)
    else:
        raise NotImplementedError(f"unknown optimizer method: {cfg.OPTIM.METHOD}")
126 |
def setup_cotta(model):
    """Set up CoTTA adaptation.

    Configure the model for training + feature modulation by batch statistics,
    collect the parameters for feature modulation by gradient optimization,
    set up the optimizer, and then wrap the model with CoTTA.
    """
    model = cotta.configure_model(model)
    params, param_names = cotta.collect_params(model)
    optimizer = setup_optimizer(params)
    cotta_model = cotta.CoTTA(model, optimizer,
                              steps=cfg.OPTIM.STEPS,
                              episodic=cfg.MODEL.EPISODIC)
    # plain %-style lazy formatting; the old f-string prefixes were no-ops
    logger.info("model for adaptation: %s", model)
    logger.info("params for adaptation: %s", param_names)
    logger.info("optimizer for adaptation: %s", optimizer)
    return cotta_model
144 |
145 |
if __name__ == '__main__':
    # fixed: description string had a stray leading double quote
    evaluate('ImageNet-C evaluation.')
148 |
--------------------------------------------------------------------------------
/cifar/cifar100c.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch
4 | import torch.optim as optim
5 |
6 | from robustbench.data import load_cifar100c
7 | from robustbench.model_zoo.enums import ThreatModel
8 | from robustbench.utils import load_model
9 | from robustbench.utils import clean_accuracy as accuracy
10 |
11 | import tent
12 | import norm
13 | import cotta
14 |
15 | from conf import cfg, load_cfg_fom_args
16 |
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 |
def evaluate(description):
    """Run the configured test-time adaptation method over CIFAR-100-C.

    The model adapts continually across corruption types; it is reset only
    at the start of each severity sweep (when the method supports reset).
    Error rates are logged per (corruption, severity) pair.
    """
    load_cfg_fom_args(description)
    # configure model
    base_model = load_model(cfg.MODEL.ARCH, cfg.CKPT_DIR,
                            cfg.CORRUPTION.DATASET, ThreatModel.corruptions).cuda()
    if cfg.MODEL.ADAPTATION == "source":
        logger.info("test-time adaptation: NONE")
        model = setup_source(base_model)
    if cfg.MODEL.ADAPTATION == "norm":
        logger.info("test-time adaptation: NORM")
        model = setup_norm(base_model)
    if cfg.MODEL.ADAPTATION == "tent":
        logger.info("test-time adaptation: TENT")
        model = setup_tent(base_model)
    if cfg.MODEL.ADAPTATION == "cotta":
        logger.info("test-time adaptation: CoTTA")
        model = setup_cotta(base_model)
    # evaluate on each severity and type of corruption in turn
    for severity in cfg.CORRUPTION.SEVERITY:
        for i_c, corruption_type in enumerate(cfg.CORRUPTION.TYPE):
            # continual adaptation for all corruption
            if i_c == 0:
                try:
                    model.reset()
                    logger.info("resetting model")
                except Exception:
                    # narrowed from a bare ``except:`` so KeyboardInterrupt
                    # still propagates; source/norm models have no reset()
                    logger.warning("not resetting model")
            else:
                logger.warning("not resetting model")
            x_test, y_test = load_cifar100c(cfg.CORRUPTION.NUM_EX,
                                            severity, cfg.DATA_DIR, False,
                                            [corruption_type])
            x_test, y_test = x_test.cuda(), y_test.cuda()
            acc = accuracy(model, x_test, y_test, cfg.TEST.BATCH_SIZE)
            err = 1. - acc
            logger.info(f"error % [{corruption_type}{severity}]: {err:.2%}")
58 |
59 |
def setup_source(model):
    """Set up the baseline source model without adaptation."""
    model.eval()
    # plain %-style lazy formatting; the old f-string prefix was a no-op
    logger.info("model for evaluation: %s", model)
    return model
65 |
66 |
def setup_norm(model):
    """Set up test-time normalization adaptation.

    Adapt by normalizing features with test batch statistics.
    The statistics are measured independently for each batch;
    no running average or other cross-batch estimation is used.
    """
    norm_model = norm.Norm(model)
    # plain %-style lazy formatting; the old f-string prefixes were no-ops
    logger.info("model for adaptation: %s", model)
    _, stat_names = norm.collect_stats(model)
    logger.info("stats for adaptation: %s", stat_names)
    return norm_model
79 |
80 |
def setup_tent(model):
    """Set up tent adaptation.

    Configure the model for training + feature modulation by batch statistics,
    collect the parameters for feature modulation by gradient optimization,
    set up the optimizer, and then tent the model.
    """
    model = tent.configure_model(model)
    params, param_names = tent.collect_params(model)
    optimizer = setup_optimizer(params)
    tent_model = tent.Tent(model, optimizer,
                           steps=cfg.OPTIM.STEPS,
                           episodic=cfg.MODEL.EPISODIC)
    # plain %-style lazy formatting; the old f-string prefixes were no-ops
    logger.info("model for adaptation: %s", model)
    logger.info("params for adaptation: %s", param_names)
    logger.info("optimizer for adaptation: %s", optimizer)
    return tent_model
98 |
99 |
def setup_cotta(model):
    """Set up CoTTA adaptation.

    Configure the model for training + feature modulation by batch statistics,
    collect the parameters for feature modulation by gradient optimization,
    set up the optimizer, and then wrap the model with CoTTA.
    """
    model = cotta.configure_model(model)
    params, param_names = cotta.collect_params(model)
    optimizer = setup_optimizer(params)
    cotta_model = cotta.CoTTA(model, optimizer,
                              steps=cfg.OPTIM.STEPS,
                              episodic=cfg.MODEL.EPISODIC,
                              mt_alpha=cfg.OPTIM.MT,
                              rst_m=cfg.OPTIM.RST,
                              ap=cfg.OPTIM.AP)
    # plain %-style lazy formatting; the old f-string prefixes were no-ops
    logger.info("model for adaptation: %s", model)
    logger.info("params for adaptation: %s", param_names)
    logger.info("optimizer for adaptation: %s", optimizer)
    return cotta_model
120 |
121 |
def setup_optimizer(params):
    """Set up optimizer for tent adaptation.

    Tent needs an optimizer for test-time entropy minimization.
    In principle, tent could make use of any gradient optimizer.
    In practice, we advise choosing Adam or SGD+momentum.
    For optimization settings, we advise to use the settings from the end of
    training, if known, or start with a low learning rate (like 0.001) if not.

    For best results, try tuning the learning rate and batch size.

    Raises:
        NotImplementedError: if cfg.OPTIM.METHOD is neither 'Adam' nor 'SGD'.
    """
    if cfg.OPTIM.METHOD == 'Adam':
        return optim.Adam(params,
                          lr=cfg.OPTIM.LR,
                          betas=(cfg.OPTIM.BETA, 0.999),
                          weight_decay=cfg.OPTIM.WD)
    elif cfg.OPTIM.METHOD == 'SGD':
        return optim.SGD(params,
                         lr=cfg.OPTIM.LR,
                         momentum=cfg.OPTIM.MOMENTUM,
                         dampening=cfg.OPTIM.DAMPENING,
                         weight_decay=cfg.OPTIM.WD,
                         nesterov=cfg.OPTIM.NESTEROV)
    else:
        raise NotImplementedError(f"unknown optimizer method: {cfg.OPTIM.METHOD}")
147 |
148 |
if __name__ == '__main__':
    # fixed: stray leading quote, and the description said CIFAR-10-C even
    # though this script evaluates on CIFAR-100-C (load_cifar100c above)
    evaluate('CIFAR-100-C evaluation.')
151 |
--------------------------------------------------------------------------------
/cifar/cifar10c.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch
4 | import torch.optim as optim
5 |
6 | from robustbench.data import load_cifar10c
7 | from robustbench.model_zoo.enums import ThreatModel
8 | from robustbench.utils import load_model
9 | from robustbench.utils import clean_accuracy as accuracy
10 |
11 | import tent
12 | import norm
13 | import cotta
14 |
15 | from conf import cfg, load_cfg_fom_args
16 |
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 |
def evaluate(description):
    """Run the configured test-time adaptation method over CIFAR-10-C.

    The model adapts continually across corruption types; it is reset only
    at the start of each severity sweep (when the method supports reset).
    Error rates are logged per (corruption, severity) pair.
    """
    load_cfg_fom_args(description)
    # configure model
    base_model = load_model(cfg.MODEL.ARCH, cfg.CKPT_DIR,
                            cfg.CORRUPTION.DATASET, ThreatModel.corruptions).cuda()
    if cfg.MODEL.ADAPTATION == "source":
        logger.info("test-time adaptation: NONE")
        model = setup_source(base_model)
    if cfg.MODEL.ADAPTATION == "norm":
        logger.info("test-time adaptation: NORM")
        model = setup_norm(base_model)
    if cfg.MODEL.ADAPTATION == "tent":
        logger.info("test-time adaptation: TENT")
        model = setup_tent(base_model)
    if cfg.MODEL.ADAPTATION == "cotta":
        logger.info("test-time adaptation: CoTTA")
        model = setup_cotta(base_model)
    # evaluate on each severity and type of corruption in turn
    for severity in cfg.CORRUPTION.SEVERITY:
        for i_c, corruption_type in enumerate(cfg.CORRUPTION.TYPE):
            # continual adaptation for all corruption
            if i_c == 0:
                try:
                    model.reset()
                    logger.info("resetting model")
                except Exception:
                    # narrowed from a bare ``except:`` so KeyboardInterrupt
                    # still propagates; source/norm models have no reset()
                    logger.warning("not resetting model")
            else:
                logger.warning("not resetting model")
            x_test, y_test = load_cifar10c(cfg.CORRUPTION.NUM_EX,
                                           severity, cfg.DATA_DIR, False,
                                           [corruption_type])
            x_test, y_test = x_test.cuda(), y_test.cuda()
            acc = accuracy(model, x_test, y_test, cfg.TEST.BATCH_SIZE)
            err = 1. - acc
            logger.info(f"error % [{corruption_type}{severity}]: {err:.2%}")
58 |
59 |
def setup_source(model):
    """Set up the baseline source model without adaptation.

    Only switches the model to eval mode; parameters and normalization
    statistics are left untouched.
    """
    model.eval()
    # fix: drop the spurious f-prefix; %-style args keep formatting lazy
    logger.info("model for evaluation: %s", model)
    return model
65 |
66 |
def setup_norm(model):
    """Set up test-time normalization adaptation.

    Adapt by normalizing features with test batch statistics.
    The statistics are measured independently for each batch;
    no running average or other cross-batch estimation is used.
    """
    norm_model = norm.Norm(model)
    # fix: drop the spurious f-prefix; %-style args keep formatting lazy
    logger.info("model for adaptation: %s", model)
    stats, stat_names = norm.collect_stats(model)
    logger.info("stats for adaptation: %s", stat_names)
    return norm_model
79 |
80 |
def setup_tent(model):
    """Set up tent adaptation.

    Configure the model for training + feature modulation by batch statistics,
    collect the parameters for feature modulation by gradient optimization,
    set up the optimizer, and then tent the model.
    """
    model = tent.configure_model(model)
    params, param_names = tent.collect_params(model)
    optimizer = setup_optimizer(params)
    tent_model = tent.Tent(model, optimizer,
                           steps=cfg.OPTIM.STEPS,
                           episodic=cfg.MODEL.EPISODIC)
    # fix: drop the spurious f-prefix; %-style args keep formatting lazy
    logger.info("model for adaptation: %s", model)
    logger.info("params for adaptation: %s", param_names)
    logger.info("optimizer for adaptation: %s", optimizer)
    return tent_model
98 |
99 |
def setup_cotta(model):
    """Set up CoTTA adaptation.

    Configure the model for training + feature modulation by batch statistics,
    collect the parameters for feature modulation by gradient optimization,
    set up the optimizer, and then wrap the model with CoTTA.
    (fix: the docstring previously said "tent" — copy-paste error)
    """
    model = cotta.configure_model(model)
    params, param_names = cotta.collect_params(model)
    optimizer = setup_optimizer(params)
    cotta_model = cotta.CoTTA(model, optimizer,
                              steps=cfg.OPTIM.STEPS,
                              episodic=cfg.MODEL.EPISODIC,
                              mt_alpha=cfg.OPTIM.MT,
                              rst_m=cfg.OPTIM.RST,
                              ap=cfg.OPTIM.AP)
    # %-style args keep formatting lazy (f-prefix removed)
    logger.info("model for adaptation: %s", model)
    logger.info("params for adaptation: %s", param_names)
    logger.info("optimizer for adaptation: %s", optimizer)
    return cotta_model
120 |
121 |
def setup_optimizer(params):
    """Set up optimizer for tent adaptation.

    Tent needs an optimizer for test-time entropy minimization.
    In principle, tent could make use of any gradient optimizer.
    In practice, we advise choosing Adam or SGD+momentum.
    For optimization settings, we advise to use the settings from the end of
    training, if known, or start with a low learning rate (like 0.001) if not.

    For best results, try tuning the learning rate and batch size.
    """
    if cfg.OPTIM.METHOD == 'Adam':
        return optim.Adam(params,
                          lr=cfg.OPTIM.LR,
                          betas=(cfg.OPTIM.BETA, 0.999),
                          weight_decay=cfg.OPTIM.WD)
    elif cfg.OPTIM.METHOD == 'SGD':
        return optim.SGD(params,
                         lr=cfg.OPTIM.LR,
                         momentum=cfg.OPTIM.MOMENTUM,
                         dampening=cfg.OPTIM.DAMPENING,
                         weight_decay=cfg.OPTIM.WD,
                         nesterov=cfg.OPTIM.NESTEROV)
    else:
        # fix: name the unsupported method instead of a bare NotImplementedError
        raise NotImplementedError(f"Unknown optimizer: {cfg.OPTIM.METHOD}")
147 |
148 |
if __name__ == '__main__':
    # fix: the description string contained a stray leading double quote
    evaluate('CIFAR-10-C evaluation.')
151 |
--------------------------------------------------------------------------------
/cifar/cifar10c_gradual.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch
4 | import torch.optim as optim
5 |
6 | from robustbench.data import load_cifar10c
7 | from robustbench.model_zoo.enums import ThreatModel
8 | from robustbench.utils import load_model
9 | from robustbench.utils import clean_accuracy as accuracy
10 |
11 | import tent
12 | import norm
13 | import cotta
14 |
15 | from conf import cfg, load_cfg_fom_args
16 |
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 |
def evaluate(description):
    """Evaluate the configured adaptation method on gradually-changing CIFAR-10-C.

    For each corruption type the severity ramps 1..5..1 (the first type starts
    at 5 to simulate the large gap between source and first target domain).
    The model adapts continually; it is reset only before the first corruption
    type, when the wrapper supports resetting.
    """
    load_cfg_fom_args(description)
    # configure model
    base_model = load_model(cfg.MODEL.ARCH, cfg.CKPT_DIR,
                       cfg.CORRUPTION.DATASET, ThreatModel.corruptions).cuda()
    if cfg.MODEL.ADAPTATION == "source":
        logger.info("test-time adaptation: NONE")
        model = setup_source(base_model)
    elif cfg.MODEL.ADAPTATION == "norm":
        logger.info("test-time adaptation: NORM")
        model = setup_norm(base_model)
    elif cfg.MODEL.ADAPTATION == "tent":
        logger.info("test-time adaptation: TENT")
        model = setup_tent(base_model)
    elif cfg.MODEL.ADAPTATION == "cotta":
        logger.info("test-time adaptation: CoTTA")
        model = setup_cotta(base_model)
    else:
        # fix: previously `model` stayed unbound here and crashed later
        raise NotImplementedError(
            f"Unknown adaptation method: {cfg.MODEL.ADAPTATION}")
    # evaluate on each type of corruption in turn, ramping severity gradually
    for i_c, corruption_type in enumerate(cfg.CORRUPTION.TYPE):
        # continual adaptation for all corruption
        if i_c == 0:
            try:
                model.reset()
                logger.info("resetting model")
            except Exception:  # wrapper may not implement reset()
                logger.warning("not resetting model")
            severities = [5, 4, 3, 2, 1]
            # To simulate the large gap between the source and first target
            # domain, we start from severity 5.
        else:
            severities = [1, 2, 3, 4, 5, 4, 3, 2, 1]
            logger.info("not resetting model")
        for severity in severities:
            x_test, y_test = load_cifar10c(cfg.CORRUPTION.NUM_EX,
                                           severity, cfg.DATA_DIR, False,
                                           [corruption_type])
            x_test, y_test = x_test.cuda(), y_test.cuda()
            acc = accuracy(model, x_test, y_test, cfg.TEST.BATCH_SIZE)
            err = 1. - acc
            logger.info(f"error % [{corruption_type}{severity}]: {err:.2%}")
61 |
62 |
def setup_source(model):
    """Set up the baseline source model without adaptation.

    Only switches the model to eval mode; parameters and normalization
    statistics are left untouched.
    """
    model.eval()
    # fix: drop the spurious f-prefix; %-style args keep formatting lazy
    logger.info("model for evaluation: %s", model)
    return model
68 |
69 |
def setup_norm(model):
    """Set up test-time normalization adaptation.

    Adapt by normalizing features with test batch statistics.
    The statistics are measured independently for each batch;
    no running average or other cross-batch estimation is used.
    """
    norm_model = norm.Norm(model)
    # fix: drop the spurious f-prefix; %-style args keep formatting lazy
    logger.info("model for adaptation: %s", model)
    stats, stat_names = norm.collect_stats(model)
    logger.info("stats for adaptation: %s", stat_names)
    return norm_model
82 |
83 |
def setup_tent(model):
    """Set up tent adaptation.

    Configure the model for training + feature modulation by batch statistics,
    collect the parameters for feature modulation by gradient optimization,
    set up the optimizer, and then tent the model.
    """
    model = tent.configure_model(model)
    params, param_names = tent.collect_params(model)
    optimizer = setup_optimizer(params)
    tent_model = tent.Tent(model, optimizer,
                           steps=cfg.OPTIM.STEPS,
                           episodic=cfg.MODEL.EPISODIC)
    # fix: drop the spurious f-prefix; %-style args keep formatting lazy
    logger.info("model for adaptation: %s", model)
    logger.info("params for adaptation: %s", param_names)
    logger.info("optimizer for adaptation: %s", optimizer)
    return tent_model
101 |
102 |
def setup_cotta(model):
    """Set up CoTTA adaptation.

    Configure the model for training + feature modulation by batch statistics,
    collect the parameters for feature modulation by gradient optimization,
    set up the optimizer, and then wrap the model with CoTTA.
    (fix: the docstring previously said "tent" — copy-paste error)
    """
    model = cotta.configure_model(model)
    params, param_names = cotta.collect_params(model)
    optimizer = setup_optimizer(params)
    cotta_model = cotta.CoTTA(model, optimizer,
                              steps=cfg.OPTIM.STEPS,
                              episodic=cfg.MODEL.EPISODIC,
                              mt_alpha=cfg.OPTIM.MT,
                              rst_m=cfg.OPTIM.RST,
                              ap=cfg.OPTIM.AP)
    # %-style args keep formatting lazy (f-prefix removed)
    logger.info("model for adaptation: %s", model)
    logger.info("params for adaptation: %s", param_names)
    logger.info("optimizer for adaptation: %s", optimizer)
    return cotta_model
123 |
124 |
def setup_optimizer(params):
    """Set up optimizer for tent adaptation.

    Tent needs an optimizer for test-time entropy minimization.
    In principle, tent could make use of any gradient optimizer.
    In practice, we advise choosing Adam or SGD+momentum.
    For optimization settings, we advise to use the settings from the end of
    training, if known, or start with a low learning rate (like 0.001) if not.

    For best results, try tuning the learning rate and batch size.
    """
    if cfg.OPTIM.METHOD == 'Adam':
        return optim.Adam(params,
                          lr=cfg.OPTIM.LR,
                          betas=(cfg.OPTIM.BETA, 0.999),
                          weight_decay=cfg.OPTIM.WD)
    elif cfg.OPTIM.METHOD == 'SGD':
        return optim.SGD(params,
                         lr=cfg.OPTIM.LR,
                         momentum=cfg.OPTIM.MOMENTUM,
                         dampening=cfg.OPTIM.DAMPENING,
                         weight_decay=cfg.OPTIM.WD,
                         nesterov=cfg.OPTIM.NESTEROV)
    else:
        # fix: name the unsupported method instead of a bare NotImplementedError
        raise NotImplementedError(f"Unknown optimizer: {cfg.OPTIM.METHOD}")
150 |
151 |
if __name__ == '__main__':
    # fix: the description string contained a stray leading double quote
    evaluate('CIFAR-10-C evaluation.')
154 |
--------------------------------------------------------------------------------
/cifar/robustbench/model_zoo/architectures/resnext.py:
--------------------------------------------------------------------------------
1 | """ResNeXt implementation (https://arxiv.org/abs/1611.05431).
2 |
3 | MIT License
4 |
5 | Copyright (c) 2017 Xuanyi Dong
6 |
7 | Permission is hereby granted, free of charge, to any person obtaining a copy
8 | of this software and associated documentation files (the "Software"), to deal
9 | in the Software without restriction, including without limitation the rights
10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | copies of the Software, and to permit persons to whom the Software is
12 | furnished to do so, subject to the following conditions:
13 |
14 | The above copyright notice and this permission notice shall be included in all
15 | copies or substantial portions of the Software.
16 |
17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | SOFTWARE.
24 |
25 | From:
26 | https://github.com/google-research/augmix/blob/master/third_party/WideResNet_pytorch/wideresnet.py
27 |
28 | """
29 |
30 | import math
31 |
32 | import torch.nn as nn
33 | import torch.nn.functional as F
34 | from torch.nn import init
35 |
36 |
class ResNeXtBottleneck(nn.Module):
    """ResNeXt Bottleneck Block type C
    (https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua).

    Structure: 1x1 reduce -> 3x3 grouped conv -> 1x1 expand, plus a
    residual connection (optionally projected via `downsample`).
    """
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 cardinality,
                 base_width,
                 stride=1,
                 downsample=None):
        super(ResNeXtBottleneck, self).__init__()

        # Per-group width; the grouped conv operates on width * cardinality
        # channels in total.
        group_width = int(math.floor(planes * (base_width / 64.0)))
        width = group_width * cardinality

        # 1x1 convolution: reduce to the grouped width.
        self.conv_reduce = nn.Conv2d(inplanes, width, kernel_size=1,
                                     stride=1, padding=0, bias=False)
        self.bn_reduce = nn.BatchNorm2d(width)

        # 3x3 grouped convolution: the only layer that applies the stride.
        self.conv_conv = nn.Conv2d(width, width, kernel_size=3,
                                   stride=stride, padding=1,
                                   groups=cardinality, bias=False)
        self.bn = nn.BatchNorm2d(width)

        # 1x1 convolution: expand back to planes * expansion channels.
        self.conv_expand = nn.Conv2d(width, planes * 4, kernel_size=1,
                                     stride=1, padding=0, bias=False)
        self.bn_expand = nn.BatchNorm2d(planes * 4)

        self.downsample = downsample

    def forward(self, x):
        out = F.relu(self.bn_reduce(self.conv_reduce(x)), inplace=True)
        out = F.relu(self.bn(self.conv_conv(out)), inplace=True)
        out = self.bn_expand(self.conv_expand(out))
        shortcut = x if self.downsample is None else self.downsample(x)
        return F.relu(shortcut + out, inplace=True)
100 |
101 |
class CifarResNeXt(nn.Module):
    """ResNext optimized for the Cifar dataset, as specified in
    https://arxiv.org/pdf/1611.05431.pdf.

    Three stages of ResNeXt blocks at 32x32/16x16/8x8 resolution, followed
    by global average pooling and a linear classifier.
    """

    def __init__(self, block, depth, cardinality, base_width, num_classes):
        super(CifarResNeXt, self).__init__()

        # CIFAR variants stack 3 stages of (depth - 2) / 9 bottleneck blocks.
        assert (depth - 2) % 9 == 0, 'depth should be one of 29, 38, 47, 56, 101'
        layer_blocks = (depth - 2) // 9

        self.cardinality = cardinality
        self.base_width = base_width
        self.num_classes = num_classes

        # Stem: 3x3 conv keeping the 32x32 resolution.
        self.conv_1_3x3 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
        self.bn_1 = nn.BatchNorm2d(64)

        self.inplanes = 64
        self.stage_1 = self._make_layer(block, 64, layer_blocks, 1)
        self.stage_2 = self._make_layer(block, 128, layer_blocks, 2)
        self.stage_3 = self._make_layer(block, 256, layer_blocks, 2)
        self.avgpool = nn.AvgPool2d(8)
        self.classifier = nn.Linear(256 * block.expansion, num_classes)

        # He-style init for convs, unit scale for BN, Kaiming for the FC.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        # Projection shortcut when the resolution or the width changes.
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        # First block carries the stride (and possible projection); the rest
        # keep resolution and width.
        layers = [block(self.inplanes, planes, self.cardinality,
                        self.base_width, stride, downsample)]
        self.inplanes = planes * block.expansion
        layers.extend(
            block(self.inplanes, planes, self.cardinality, self.base_width)
            for _ in range(1, blocks))
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn_1(self.conv_1_3x3(x)), inplace=True)
        out = self.stage_1(out)
        out = self.stage_2(out)
        out = self.stage_3(out)
        out = self.avgpool(out)
        return self.classifier(out.view(out.size(0), -1))
171 |
--------------------------------------------------------------------------------
/imagenet/robustbench/model_zoo/architectures/resnext.py:
--------------------------------------------------------------------------------
1 | """ResNeXt implementation (https://arxiv.org/abs/1611.05431).
2 |
3 | MIT License
4 |
5 | Copyright (c) 2017 Xuanyi Dong
6 |
7 | Permission is hereby granted, free of charge, to any person obtaining a copy
8 | of this software and associated documentation files (the "Software"), to deal
9 | in the Software without restriction, including without limitation the rights
10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | copies of the Software, and to permit persons to whom the Software is
12 | furnished to do so, subject to the following conditions:
13 |
14 | The above copyright notice and this permission notice shall be included in all
15 | copies or substantial portions of the Software.
16 |
17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | SOFTWARE.
24 |
25 | From:
26 | https://github.com/google-research/augmix/blob/master/third_party/WideResNet_pytorch/wideresnet.py
27 |
28 | """
29 |
30 | import math
31 |
32 | import torch.nn as nn
33 | import torch.nn.functional as F
34 | from torch.nn import init
35 |
36 |
class ResNeXtBottleneck(nn.Module):
    """ResNeXt Bottleneck Block type C
    (https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua).

    Structure: 1x1 reduce -> 3x3 grouped conv -> 1x1 expand, plus a
    residual connection (optionally projected via `downsample`).
    """
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 cardinality,
                 base_width,
                 stride=1,
                 downsample=None):
        super(ResNeXtBottleneck, self).__init__()

        # Per-group width; the grouped conv operates on width * cardinality
        # channels in total.
        group_width = int(math.floor(planes * (base_width / 64.0)))
        width = group_width * cardinality

        # 1x1 convolution: reduce to the grouped width.
        self.conv_reduce = nn.Conv2d(inplanes, width, kernel_size=1,
                                     stride=1, padding=0, bias=False)
        self.bn_reduce = nn.BatchNorm2d(width)

        # 3x3 grouped convolution: the only layer that applies the stride.
        self.conv_conv = nn.Conv2d(width, width, kernel_size=3,
                                   stride=stride, padding=1,
                                   groups=cardinality, bias=False)
        self.bn = nn.BatchNorm2d(width)

        # 1x1 convolution: expand back to planes * expansion channels.
        self.conv_expand = nn.Conv2d(width, planes * 4, kernel_size=1,
                                     stride=1, padding=0, bias=False)
        self.bn_expand = nn.BatchNorm2d(planes * 4)

        self.downsample = downsample

    def forward(self, x):
        out = F.relu(self.bn_reduce(self.conv_reduce(x)), inplace=True)
        out = F.relu(self.bn(self.conv_conv(out)), inplace=True)
        out = self.bn_expand(self.conv_expand(out))
        shortcut = x if self.downsample is None else self.downsample(x)
        return F.relu(shortcut + out, inplace=True)
100 |
101 |
class CifarResNeXt(nn.Module):
    """ResNext optimized for the Cifar dataset, as specified in
    https://arxiv.org/pdf/1611.05431.pdf.

    Three stages of ResNeXt blocks at 32x32/16x16/8x8 resolution, followed
    by global average pooling and a linear classifier.
    """

    def __init__(self, block, depth, cardinality, base_width, num_classes):
        super(CifarResNeXt, self).__init__()

        # CIFAR variants stack 3 stages of (depth - 2) / 9 bottleneck blocks.
        assert (depth - 2) % 9 == 0, 'depth should be one of 29, 38, 47, 56, 101'
        layer_blocks = (depth - 2) // 9

        self.cardinality = cardinality
        self.base_width = base_width
        self.num_classes = num_classes

        # Stem: 3x3 conv keeping the 32x32 resolution.
        self.conv_1_3x3 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
        self.bn_1 = nn.BatchNorm2d(64)

        self.inplanes = 64
        self.stage_1 = self._make_layer(block, 64, layer_blocks, 1)
        self.stage_2 = self._make_layer(block, 128, layer_blocks, 2)
        self.stage_3 = self._make_layer(block, 256, layer_blocks, 2)
        self.avgpool = nn.AvgPool2d(8)
        self.classifier = nn.Linear(256 * block.expansion, num_classes)

        # He-style init for convs, unit scale for BN, Kaiming for the FC.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        # Projection shortcut when the resolution or the width changes.
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        # First block carries the stride (and possible projection); the rest
        # keep resolution and width.
        layers = [block(self.inplanes, planes, self.cardinality,
                        self.base_width, stride, downsample)]
        self.inplanes = planes * block.expansion
        layers.extend(
            block(self.inplanes, planes, self.cardinality, self.base_width)
            for _ in range(1, blocks))
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn_1(self.conv_1_3x3(x)), inplace=True)
        out = self.stage_1(out)
        out = self.stage_2(out)
        out = self.stage_3(out)
        out = self.avgpool(out)
        return self.classifier(out.view(out.size(0), -1))
171 |
--------------------------------------------------------------------------------
/cifar/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | import numpy as np
5 | from typing import Dict, List, Tuple
6 | import torch.nn as nn
7 | import torch.utils.data as data
8 | import logging
9 | from tqdm import tqdm
10 |
def pytorch_evaluate(net: nn.Module, data_loader: data.DataLoader, fetch_keys: List,
                     x_shape: Tuple = None, output_shapes: Dict = None, to_tensor: bool=False, verbose=False) -> Tuple:
    """Run `net` over every batch of `data_loader` and gather selected outputs.

    `net` is expected to return a dict per batch; for each key in
    `fetch_keys` the per-batch outputs are stacked into one numpy array
    (optionally reshaped via `output_shapes` and converted back to a tensor
    when `to_tensor` is True). Returns a tuple of arrays/tensors, one per
    fetch key, in the order of `fetch_keys`.
    """
    if output_shapes is not None:
        for key in fetch_keys:
            assert key in output_shapes

    # Inference outputs are collected per key as lists of numpy batches.
    batch_size = data_loader.batch_size
    num_samples = len(data_loader.dataset)
    batch_count = int(np.ceil(num_samples / batch_size))
    collected = {key: [] for key in fetch_keys}

    net.eval()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    batches = enumerate(data_loader)
    if verbose:
        batches = tqdm(batches)
    for batch_idx, (inputs, targets) in batches:
        if x_shape is not None:
            inputs = inputs.reshape(x_shape)
        inputs, targets = inputs.to(device), targets.to(device)
        outputs_dict = net(inputs)
        for key in fetch_keys:
            collected[key].append(outputs_dict[key].data.cpu().detach().numpy())

    # Stack the per-batch pieces into one array per key.
    fetches = []
    for key in fetch_keys:
        stacked = np.vstack(collected[key])
        if output_shapes is not None:
            stacked = stacked.reshape(output_shapes[key])
        if to_tensor:
            stacked = torch.as_tensor(stacked, device=torch.device(device))
        fetches.append(stacked)

    # Sanity checks: we saw every batch and every sample.
    assert batch_idx + 1 == batch_count
    assert fetches[0].shape[0] == num_samples

    return tuple(fetches)
51 |
def boolean_string(s):
    """Parse the exact strings 'True'/'False' into a bool.

    Lets shell flags like --use_bn True / --use_bn False work with argparse; see
    https://stackoverflow.com/questions/44561722/why-in-argparse-a-true-is-always-true
    """
    if s == 'True':
        return True
    if s == 'False':
        return False
    raise ValueError('Not a valid boolean string')
58 |
def convert_tensor_to_image(x: np.ndarray):
    """
    :param x: np.array of size (Batch, feature_dims, H, W) or (feature_dims, H, W)
    :return: x with (Batch, H, W, feature_dims) or (H, W, feature_dims) between 0:255, uint8
    """
    # Scale [0, 1] floats to [0, 255] and quantize.
    img = np.round(x.copy() * 255.0).astype(np.uint8)
    # Move the channel axis last: CHW -> HWC, or NCHW -> NHWC.
    axes = [1, 2, 0] if x.ndim == 3 else [0, 2, 3, 1]
    return np.transpose(img, axes)
73 |
def convert_image_to_tensor(x: np.ndarray):
    """
    :param x: np.array of size (Batch, H, W, feature_dims) between 0:255, uint8
    :return: x with (Batch, feature_dims, H, W) float between [0:1]
    """
    assert x.dtype == np.uint8
    # Rescale to [0, 1] floats, then move channels first: NHWC -> NCHW.
    scaled = x.astype(np.float32) / 255.0
    return np.transpose(scaled, [0, 3, 1, 2])
85 |
def majority_vote(x):
    """Return the most frequent non-negative integer label in `x`
    (ties break toward the smaller label, per np.bincount().argmax())."""
    counts = np.bincount(x)
    return counts.argmax()
88 |
def get_ensemble_paths(ensemble_dir):
    """Return the checkpoint path of every network in an ensemble.

    Each immediate subdirectory of `ensemble_dir` is assumed to hold one
    network's 'ckpt.pth'; paths are returned sorted by subdirectory name.
    """
    ensemble_subdirs = next(os.walk(ensemble_dir))[1]
    ensemble_subdirs.sort()
    # fix: the old loop shadowed the builtin `dir` and carried an unused index
    return [os.path.join(ensemble_dir, subdir, 'ckpt.pth')
            for subdir in ensemble_subdirs]
97 |
def set_logger(log_file):
    """Configure root logging to write to `log_file` (truncated) and stdout."""
    handlers = [logging.FileHandler(log_file, mode='w'),
                logging.StreamHandler(sys.stdout)]
    logging.basicConfig(
        format='%(asctime)s %(name)s %(levelname)s %(message)s',
        datefmt='%m/%d/%Y %I:%M:%S %p',
        level=logging.INFO,
        handlers=handlers,
    )
105 |
def print_Linf_dists(X, X_test):
    """Log the max and mean L-inf distance between adversarial and clean sets.

    Zero-distance (unperturbed) samples are excluded from the statistics.
    """
    logger = logging.getLogger()
    X_diff = (X - X_test).reshape(X.shape[0], -1)
    X_diff_abs = np.abs(X_diff)
    Linf_dist = X_diff_abs.max(axis=1)
    Linf_dist = Linf_dist[np.where(Linf_dist > 0.0)[0]]
    if Linf_dist.size == 0:
        # fix: np.max/np.mean on an empty array raised ValueError
        logger.info('The adversarial attacks distance: all samples are unperturbed')
        return
    logger.info('The adversarial attacks distance: Max[L_inf]={}, E[L_inf]={}'.format(np.max(Linf_dist), np.mean(Linf_dist)))
113 |
def calc_attack_rate(y_preds: np.ndarray, y_orig_norm_preds: np.ndarray, y_gt: np.ndarray) -> float:
    """
    Args:
        y_preds: The adv image's final prediction after the defense method
        y_orig_norm_preds: The original image's predictions
        y_gt: The GT labels

    Returns: attack rate in %, i.e. the fraction of originally-correct samples
    whose prediction the attack flipped.
    """
    net_succ = 0   # samples the network classified correctly before the attack
    flipped = 0    # of those, samples whose prediction the attack changed

    for i in range(len(y_gt)):
        if y_orig_norm_preds[i] == y_gt[i]:
            net_succ += 1
            if y_preds[i] != y_orig_norm_preds[i]:
                flipped += 1

    if net_succ == 0:
        # fix: avoid ZeroDivisionError when the network got every sample wrong
        return 0.0
    return flipped / net_succ
140 |
def get_all_files_recursive(path, suffix=None):
    """Collect all files under `path`, recursively.

    When `suffix` is given (without the leading dot), only files whose name
    ends with '.<suffix>' are returned.
    """
    files = []
    # r=root, d=directories, f=files
    for r, d, f in os.walk(path):
        for file in f:
            # fix: use endswith; the old substring test also matched names
            # like 'a.txt.bak' for suffix='txt'
            if suffix is None or file.endswith('.' + suffix):
                files.append(os.path.join(r, file))
    return files
151 |
def convert_grayscale_to_rgb(x: np.ndarray) -> np.ndarray:
    """
    Converts a 2D image shape=(x, y) to a RGB image (x, y, 3) by
    replicating the single channel three times.
    Args:
        x: gray image
    Returns: rgb image
    """
    return np.stack([x, x, x], axis=-1)
160 |
def inverse_map(x: dict) -> dict:
    """
    :param x: dictionary
    :return: inverse mapping, showing for each val its key
    (for duplicate values, the last key in iteration order wins)
    """
    return {v: k for k, v in x.items()}
170 |
def get_image_shape(dataset: str) -> Tuple[int, int, int]:
    """Return the (H, W, C) image shape for a supported dataset name."""
    shapes = {
        'cifar10': (32, 32, 3),
        'cifar100': (32, 32, 3),
        'svhn': (32, 32, 3),
        'tiny_imagenet': (64, 64, 3),
    }
    if dataset not in shapes:
        # Keeps the original AssertionError contract for unknown datasets.
        raise AssertionError('Unsupported dataset {}'.format(dataset))
    return shapes[dataset]
178 |
--------------------------------------------------------------------------------
/cifar/conf.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | """Configuration file (powered by YACS)."""
7 |
8 | import argparse
9 | import os
10 | import sys
11 | import logging
12 | import random
13 | import torch
14 | import numpy as np
15 | from datetime import datetime
16 | from iopath.common.file_io import g_pathmgr
17 | from yacs.config import CfgNode as CfgNode
18 |
19 |
20 | # Global config object (example usage: from core.config import cfg)
21 | _C = CfgNode()
22 | cfg = _C
23 |
24 |
25 | # ----------------------------- Model options ------------------------------- #
26 | _C.MODEL = CfgNode()
27 |
28 | # Check https://github.com/RobustBench/robustbench for available models
29 | _C.MODEL.ARCH = 'Standard'
30 |
# Choice of (source, norm, tent, cotta)
# - source: baseline without adaptation
# - norm: test-time normalization
# - tent: test-time entropy minimization
# - cotta: continual test-time adaptation
35 | _C.MODEL.ADAPTATION = 'source'
36 |
37 | # By default tent is online, with updates persisting across batches.
38 | # To make adaptation episodic, and reset the model for each batch, choose True.
39 | _C.MODEL.EPISODIC = False
40 |
41 | # ----------------------------- Corruption options -------------------------- #
42 | _C.CORRUPTION = CfgNode()
43 |
44 | # Dataset for evaluation
45 | _C.CORRUPTION.DATASET = 'cifar10'
46 |
47 | # Check https://github.com/hendrycks/robustness for corruption details
48 | _C.CORRUPTION.TYPE = ['gaussian_noise', 'shot_noise', 'impulse_noise',
49 | 'defocus_blur', 'glass_blur', 'motion_blur', 'zoom_blur',
50 | 'snow', 'frost', 'fog', 'brightness', 'contrast',
51 | 'elastic_transform', 'pixelate', 'jpeg_compression']
52 | _C.CORRUPTION.SEVERITY = [5, 4, 3, 2, 1]
53 |
54 | # Number of examples to evaluate (10000 for all samples in CIFAR-10)
55 | _C.CORRUPTION.NUM_EX = 10000
56 |
57 | # ------------------------------- Batch norm options ------------------------ #
# ------------------------------- Batch norm options ------------------------ #
_C.BN = CfgNode()

# BN epsilon
_C.BN.EPS = 1e-5

# BN momentum (BN momentum in PyTorch = 1 - BN momentum in Caffe2)
_C.BN.MOM = 0.1

# ------------------------------- Optimizer options ------------------------- #
_C.OPTIM = CfgNode()

# Number of updates per batch
_C.OPTIM.STEPS = 1

# Learning rate
_C.OPTIM.LR = 1e-3

# Choices: Adam, SGD
_C.OPTIM.METHOD = 'Adam'

# Beta
_C.OPTIM.BETA = 0.9

# Momentum
_C.OPTIM.MOMENTUM = 0.9

# Momentum dampening
_C.OPTIM.DAMPENING = 0.0

# Nesterov momentum
_C.OPTIM.NESTEROV = True

# L2 regularization
_C.OPTIM.WD = 0.0

# COTTA hyper-parameters -- presumably mapped to CoTTA(mt_alpha, rst_m, ap)
# in cotta.py (MT: teacher EMA momentum, RST: stochastic-restore probability,
# AP: anchor confidence threshold). TODO confirm mapping in the runner script.
_C.OPTIM.MT = 0.999
_C.OPTIM.RST = 0.01
_C.OPTIM.AP = 0.92

# ------------------------------- Testing options --------------------------- #
_C.TEST = CfgNode()

# Batch size for evaluation (and updates for norm + tent)
_C.TEST.BATCH_SIZE = 128

# --------------------------------- CUDNN options --------------------------- #
_C.CUDNN = CfgNode()

# Benchmark to select fastest CUDNN algorithms (best for fixed input sizes)
_C.CUDNN.BENCHMARK = True

# ---------------------------------- Misc options --------------------------- #

# Optional description of a config
_C.DESC = ""

# Note that non-determinism is still present due to non-deterministic GPU ops
_C.RNG_SEED = 1

# Output directory
_C.SAVE_DIR = "./output"

# Data directory
_C.DATA_DIR = "./data"

# Weight directory
_C.CKPT_DIR = "./ckpt"

# Log destination (in SAVE_DIR)
_C.LOG_DEST = "log.txt"

# Log datetime
_C.LOG_TIME = ''

# # Config destination (in SAVE_DIR)
# _C.CFG_DEST = "cfg.yaml"
# NOTE: CFG_DEST is commented out; any code reading _C.CFG_DEST must supply
# a fallback or it will raise AttributeError.

# --------------------------------- Default config -------------------------- #
# Frozen snapshot of the defaults; reset_cfg() restores state from it.
_CFG_DEFAULT = _C.clone()
_CFG_DEFAULT.freeze()
139 |
140 |
def assert_and_infer_cfg():
    """Checks config values invariants.

    Raises:
        AssertionError: if MODEL.ADAPTATION or LOG_DEST holds an
            unsupported value.
    """
    # 'cotta' is a valid adaptation method (see the OPTIM.MT/RST/AP options
    # above and cfgs/*/cotta.yaml) but was missing from the whitelist; also,
    # the original first err_str was assigned but never used in the assert.
    err_str = "Unknown adaptation method '{}'"
    assert _C.MODEL.ADAPTATION in ["source", "norm", "tent", "cotta"], \
        err_str.format(_C.MODEL.ADAPTATION)
    err_str = "Log destination '{}' not supported"
    assert _C.LOG_DEST in ["stdout", "file"], err_str.format(_C.LOG_DEST)
147 |
148 |
def merge_from_file(cfg_file):
    """Merge the options from a YAML config file into the global config."""
    with g_pathmgr.open(cfg_file, "r") as f:
        loaded = _C.load_cfg(f)  # avoid shadowing the module-level `cfg`
    _C.merge_from_other_cfg(loaded)
153 |
154 |
def dump_cfg(cfg_dest="cfg.yaml"):
    """Dumps the config to the output directory.

    Args:
        cfg_dest: file name for the dumped config inside SAVE_DIR. Used as a
            fallback because `_C.CFG_DEST` is commented out above, so reading
            it directly would raise AttributeError.
    """
    cfg_file = os.path.join(_C.SAVE_DIR, getattr(_C, "CFG_DEST", cfg_dest))
    with g_pathmgr.open(cfg_file, "w") as f:
        _C.dump(stream=f)
160 |
161 |
def load_cfg(out_dir, cfg_dest="config.yaml"):
    """Load a previously saved config from `out_dir` into the global config."""
    merge_from_file(os.path.join(out_dir, cfg_dest))
166 |
167 |
def reset_cfg():
    """Reset config to initial state."""
    # Restore every option from the frozen snapshot taken at import time.
    cfg.merge_from_other_cfg(_CFG_DEFAULT)
171 |
172 |
def load_cfg_fom_args(description="Config options."):
    """Load config from command line args and set any specified options.

    Side effects: parses sys.argv (--cfg plus trailing KEY VALUE overrides),
    merges them into the global cfg and freezes it, creates SAVE_DIR,
    configures logging to a file and stdout, and seeds numpy / torch /
    random with cfg.RNG_SEED.
    """
    current_time = datetime.now().strftime("%y%m%d_%H%M%S")
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument("--cfg", dest="cfg_file", type=str, required=True,
                        help="Config file location")
    parser.add_argument("opts", default=None, nargs=argparse.REMAINDER,
                        help="See conf.py for all options")
    # With no arguments at all, print usage instead of erroring on --cfg.
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)
    args = parser.parse_args()

    merge_from_file(args.cfg_file)
    cfg.merge_from_list(args.opts)

    # Derive the log file name from the config file name plus a timestamp,
    # e.g. tent.yaml -> tent_240101_120000.txt (written into SAVE_DIR).
    log_dest = os.path.basename(args.cfg_file)
    log_dest = log_dest.replace('.yaml', '_{}.txt'.format(current_time))

    g_pathmgr.mkdirs(cfg.SAVE_DIR)
    cfg.LOG_TIME, cfg.LOG_DEST = current_time, log_dest
    cfg.freeze()  # no further mutation of cfg after this point

    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s] [%(filename)s: %(lineno)4d]: %(message)s",
        datefmt="%y/%m/%d %H:%M:%S",
        handlers=[
            logging.FileHandler(os.path.join(cfg.SAVE_DIR, cfg.LOG_DEST)),
            logging.StreamHandler()
        ])

    # Seed RNGs for (partial) reproducibility; GPU ops may remain
    # non-deterministic (see the RNG_SEED comment above).
    np.random.seed(cfg.RNG_SEED)
    torch.manual_seed(cfg.RNG_SEED)
    random.seed(cfg.RNG_SEED)
    torch.backends.cudnn.benchmark = cfg.CUDNN.BENCHMARK

    logger = logging.getLogger(__name__)
    version = [torch.__version__, torch.version.cuda,
               torch.backends.cudnn.version()]
    logger.info(
        "PyTorch Version: torch={}, cuda={}, cudnn={}".format(*version))
    logger.info(cfg)
216 |
--------------------------------------------------------------------------------
/imagenet/conf.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | """Configuration file (powered by YACS)."""
7 |
8 | import argparse
9 | import os
10 | import sys
11 | import logging
12 | import random
13 | import torch
14 | import numpy as np
15 | from datetime import datetime
16 | from iopath.common.file_io import g_pathmgr
17 | from yacs.config import CfgNode as CfgNode
18 |
19 |
# Global config object (example usage: from core.config import cfg)
_C = CfgNode()
cfg = _C  # public alias; mutating `cfg` mutates `_C`


# ----------------------------- Model options ------------------------------- #
_C.MODEL = CfgNode()

# Check https://github.com/RobustBench/robustbench for available models
_C.MODEL.ARCH = 'Standard'

# Choice of (source, norm, tent)
# - source: baseline without adaptation
# - norm: test-time normalization
# - tent: test-time entropy minimization (ours)
_C.MODEL.ADAPTATION = 'source'

# By default tent is online, with updates persisting across batches.
# To make adaptation episodic, and reset the model for each batch, choose True.
_C.MODEL.EPISODIC = False

# ----------------------------- Corruption options -------------------------- #
_C.CORRUPTION = CfgNode()

# Dataset for evaluation
# NOTE(review): default is 'cifar10' even in the imagenet tree -- presumably
# overridden by the yaml configs in cfgs/; verify before relying on it.
_C.CORRUPTION.DATASET = 'cifar10'

# Check https://github.com/hendrycks/robustness for corruption details
_C.CORRUPTION.TYPE = ['gaussian_noise', 'shot_noise', 'impulse_noise',
                      'defocus_blur', 'glass_blur', 'motion_blur', 'zoom_blur',
                      'snow', 'frost', 'fog', 'brightness', 'contrast',
                      'elastic_transform', 'pixelate', 'jpeg_compression']
_C.CORRUPTION.SEVERITY = [5, 4, 3, 2, 1]

# Number of examples to evaluate
# The 5000 val images defined by Robustbench were actually used:
# Please see https://github.com/RobustBench/robustbench/blob/7af0e34c6b383cd73ea7a1bbced358d7ce6ad22f/robustbench/data/imagenet_test_image_ids.txt
_C.CORRUPTION.NUM_EX = 5000

# ------------------------------- Batch norm options ------------------------ #
_C.BN = CfgNode()

# BN epsilon
_C.BN.EPS = 1e-5

# BN momentum (BN momentum in PyTorch = 1 - BN momentum in Caffe2)
_C.BN.MOM = 0.1

# ------------------------------- Optimizer options ------------------------- #
_C.OPTIM = CfgNode()

# Number of updates per batch
_C.OPTIM.STEPS = 1

# Learning rate
_C.OPTIM.LR = 1e-3

# Choices: Adam, SGD
_C.OPTIM.METHOD = 'Adam'

# Beta
_C.OPTIM.BETA = 0.9

# Momentum
_C.OPTIM.MOMENTUM = 0.9

# Momentum dampening
_C.OPTIM.DAMPENING = 0.0

# Nesterov momentum
_C.OPTIM.NESTEROV = True

# L2 regularization
_C.OPTIM.WD = 0.0

# ------------------------------- Testing options --------------------------- #
_C.TEST = CfgNode()

# Batch size for evaluation (and updates for norm + tent)
_C.TEST.BATCH_SIZE = 128

# --------------------------------- CUDNN options --------------------------- #
_C.CUDNN = CfgNode()

# Benchmark to select fastest CUDNN algorithms (best for fixed input sizes)
_C.CUDNN.BENCHMARK = True

# ---------------------------------- Misc options --------------------------- #

# Optional description of a config
_C.DESC = ""

# Note that non-determinism is still present due to non-deterministic GPU ops
_C.RNG_SEED = 1

# Output directory
_C.SAVE_DIR = "./output"

# Data directory
_C.DATA_DIR = "./data"

# Weight directory
_C.CKPT_DIR = "./ckpt"

# Log destination (in SAVE_DIR)
_C.LOG_DEST = "log.txt"

# Log datetime
_C.LOG_TIME = ''

# # Config destination (in SAVE_DIR)
# _C.CFG_DEST = "cfg.yaml"
# NOTE: CFG_DEST is commented out; any code reading _C.CFG_DEST must supply
# a fallback or it will raise AttributeError.

# --------------------------------- Default config -------------------------- #
# Frozen snapshot of the defaults; reset_cfg() restores state from it.
_CFG_DEFAULT = _C.clone()
_CFG_DEFAULT.freeze()
136 |
137 |
def assert_and_infer_cfg():
    """Checks config values invariants.

    Raises:
        AssertionError: if MODEL.ADAPTATION or LOG_DEST holds an
            unsupported value.
    """
    # 'cotta' is a valid adaptation method (see cfgs/10orders/cotta/*.yaml)
    # but was missing from the whitelist; also, the original first err_str
    # was assigned but never used in the assert.
    err_str = "Unknown adaptation method '{}'"
    assert _C.MODEL.ADAPTATION in ["source", "norm", "tent", "cotta"], \
        err_str.format(_C.MODEL.ADAPTATION)
    err_str = "Log destination '{}' not supported"
    assert _C.LOG_DEST in ["stdout", "file"], err_str.format(_C.LOG_DEST)
144 |
145 |
def merge_from_file(cfg_file):
    """Merge the options from a YAML config file into the global config."""
    with g_pathmgr.open(cfg_file, "r") as f:
        loaded = _C.load_cfg(f)  # avoid shadowing the module-level `cfg`
    _C.merge_from_other_cfg(loaded)
150 |
151 |
def dump_cfg(cfg_dest="cfg.yaml"):
    """Dumps the config to the output directory.

    Args:
        cfg_dest: file name for the dumped config inside SAVE_DIR. Used as a
            fallback because `_C.CFG_DEST` is commented out above, so reading
            it directly would raise AttributeError.
    """
    cfg_file = os.path.join(_C.SAVE_DIR, getattr(_C, "CFG_DEST", cfg_dest))
    with g_pathmgr.open(cfg_file, "w") as f:
        _C.dump(stream=f)
157 |
158 |
def load_cfg(out_dir, cfg_dest="config.yaml"):
    """Load a previously saved config from `out_dir` into the global config."""
    merge_from_file(os.path.join(out_dir, cfg_dest))
163 |
164 |
def reset_cfg():
    """Reset config to initial state."""
    # Restore every option from the frozen snapshot taken at import time.
    cfg.merge_from_other_cfg(_CFG_DEFAULT)
168 |
169 |
def load_cfg_fom_args(description="Config options."):
    """Load config from command line args and set any specified options.

    Side effects: parses sys.argv (--cfg plus trailing KEY VALUE overrides),
    merges them into the global cfg and freezes it, creates SAVE_DIR,
    configures logging to a file and stdout, and seeds numpy / torch /
    random with cfg.RNG_SEED.
    """
    current_time = datetime.now().strftime("%y%m%d_%H%M%S")
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument("--cfg", dest="cfg_file", type=str, required=True,
                        help="Config file location")
    parser.add_argument("opts", default=None, nargs=argparse.REMAINDER,
                        help="See conf.py for all options")
    # With no arguments at all, print usage instead of erroring on --cfg.
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)
    args = parser.parse_args()

    merge_from_file(args.cfg_file)
    cfg.merge_from_list(args.opts)

    # Derive the log file name from the config file name plus a timestamp,
    # e.g. cotta0.yaml -> cotta0_240101_120000.txt (written into SAVE_DIR).
    log_dest = os.path.basename(args.cfg_file)
    log_dest = log_dest.replace('.yaml', '_{}.txt'.format(current_time))

    g_pathmgr.mkdirs(cfg.SAVE_DIR)
    cfg.LOG_TIME, cfg.LOG_DEST = current_time, log_dest
    cfg.freeze()  # no further mutation of cfg after this point

    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s] [%(filename)s: %(lineno)4d]: %(message)s",
        datefmt="%y/%m/%d %H:%M:%S",
        handlers=[
            logging.FileHandler(os.path.join(cfg.SAVE_DIR, cfg.LOG_DEST)),
            logging.StreamHandler()
        ])

    # Seed RNGs for (partial) reproducibility; GPU ops may remain
    # non-deterministic (see the RNG_SEED comment above).
    np.random.seed(cfg.RNG_SEED)
    torch.manual_seed(cfg.RNG_SEED)
    random.seed(cfg.RNG_SEED)
    torch.backends.cudnn.benchmark = cfg.CUDNN.BENCHMARK

    logger = logging.getLogger(__name__)
    version = [torch.__version__, torch.version.cuda,
               torch.backends.cudnn.version()]
    logger.info(
        "PyTorch Version: torch={}, cuda={}, cudnn={}".format(*version))
    logger.info(cfg)
213 |
--------------------------------------------------------------------------------
/cifar/robustbench/loaders.py:
--------------------------------------------------------------------------------
1 | """
2 | This file is based on the code from https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py.
3 | """
4 | from torchvision.datasets.vision import VisionDataset
5 |
6 | import torch
7 | import torch.utils.data as data
8 | import torchvision.transforms as transforms
9 |
10 | from PIL import Image
11 |
12 | import os
13 | import os.path
14 | import sys
15 | import json
16 |
17 |
def make_custom_dataset(root, path_imgs, cls_dict):
    """Build (image path, class index) pairs from a fixed image-id list.

    Args:
        root: dataset root directory; each listed id is joined onto it.
        path_imgs: text file with one relative path per line, of the form
            "<class_dir>/<image_file>".
        cls_dict: JSON file mapping class directory name -> class index.

    Returns:
        list of (joined image path, class index) tuples, in file order.
    """
    with open(cls_dict, 'r') as f:
        class_to_idx = json.load(f)
    samples = []
    with open(path_imgs, 'r') as f:
        for line in f:
            rel_path = line.split('\n')[0]   # strip the trailing newline
            label = class_to_idx[line.split('/')[0]]
            samples.append((os.path.join(root, rel_path), label))
    return samples
26 |
27 |
class CustomDatasetFolder(VisionDataset):
    """A generic data loader where the samples are arranged in this way: ::
        root/class_x/xxx.ext
        root/class_x/xxy.ext
        root/class_x/xxz.ext
        root/class_y/123.ext
        root/class_y/nsdf3.ext
        root/class_y/asd932_.ext
    Args:
        root (string): Root directory path.
        loader (callable): A function to load a sample given its path.
        extensions (tuple[string]): A list of allowed extensions.
            both extensions and is_valid_file should not be passed.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        is_valid_file (callable, optional): A function that takes path of an Image file
            and check if the file is a valid_file (used to check of corrupt files)
            both extensions and is_valid_file should not be passed.
    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        samples (list): List of (sample path, class_index) tuples
        targets (list): The class_index value for each image in the dataset

    Note:
        Unlike torchvision's DatasetFolder, samples are NOT discovered by
        scanning ``root``: they come from the fixed robustbench id list (the
        list/json paths are relative to the working directory, so this must
        be run from the repository root). ``extensions``/``is_valid_file``
        are accepted only for interface compatibility and never used for
        filtering. ``__getitem__`` returns a 3-tuple (sample, target, path).
    """

    def __init__(self, root, loader, extensions=None, transform=None, target_transform=None, is_valid_file=None):
        super(CustomDatasetFolder, self).__init__(root)
        self.transform = transform
        self.target_transform = target_transform
        # Class names/indices are derived from the subdirectories of root...
        classes, class_to_idx = self._find_classes(self.root)
        # ...but the sample list itself comes from the fixed robustbench files.
        samples = make_custom_dataset(self.root, 'robustbench/data/imagenet_test_image_ids.txt',
                                      'robustbench/data/imagenet_class_to_id_map.json')
        if len(samples) == 0:
            raise (RuntimeError("Found 0 files in subfolders of: " + self.root + "\n"
                                "Supported extensions are: " + ",".join(extensions)))

        self.loader = loader
        self.extensions = extensions  # stored but unused; see class Note

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

    def _find_classes(self, dir):
        """
        Finds the class folders in a dataset.
        Args:
            dir (string): Root directory path.
        Returns:
            tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
        Ensures:
            No class is a subdirectory of another.
        """
        if sys.version_info >= (3, 5):
            # Faster and available in Python 3.5 and above
            classes = [d.name for d in os.scandir(dir) if d.is_dir()]
        else:
            classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
        classes.sort()
        class_to_idx = {classes[i]: i for i in range(len(classes))}
        return classes, class_to_idx

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (sample, target, path) where target is class_index of the
            target class and path is the sample's file path.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target, path

    def __len__(self):
        return len(self.samples)
111 |
112 |
# Extensions torchvision treats as images; stored by CustomDatasetFolder for
# API compatibility but not used to filter the fixed robustbench sample list.
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
114 |
115 |
def pil_loader(path):
    """Load an image with PIL and force 3-channel RGB."""
    # Open through a file object so the handle is closed deterministically
    # (https://github.com/python-pillow/Pillow/issues/835).
    with open(path, 'rb') as fh:
        return Image.open(fh).convert('RGB')
121 |
122 |
def accimage_loader(path):
    """Load an image with the optional accimage backend, falling back to PIL."""
    import accimage  # imported lazily: accimage is an optional dependency
    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)
130 |
131 |
def default_loader(path):
    """Dispatch to the loader matching torchvision's active image backend."""
    from torchvision import get_image_backend
    backend = get_image_backend()
    if backend == 'accimage':
        return accimage_loader(path)
    return pil_loader(path)
138 |
139 |
class CustomImageFolder(CustomDatasetFolder):
    """A generic data loader where the images are arranged in this way: ::
        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png
        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png
    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        is_valid_file (callable, optional): A function that takes path of an Image file
            and check if the file is a valid_file (used to check of corrupt files)
    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples

    Note:
        Samples come from the fixed robustbench id list (see
        CustomDatasetFolder); ``__getitem__`` yields (image, target, path).
    """

    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader, is_valid_file=None):
        super(CustomImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None,
                                                transform=transform,
                                                target_transform=target_transform,
                                                is_valid_file=is_valid_file)

        # Alias kept for torchvision ImageFolder compatibility.
        self.imgs = self.samples
171 |
172 |
if __name__ == '__main__':
    # One-off helper script: draw a random, seed-fixed batch of 5000
    # validation images and record their paths -- presumably used to build
    # the fixed robustbench image-id list (verify before reuse).
    data_dir = '/home/scratch/datasets/imagenet/val'
    imagenet = CustomImageFolder(data_dir, transforms.Compose([
        transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()]))

    torch.manual_seed(0)  # fixed seed so the shuffled sample is reproducible

    test_loader = data.DataLoader(imagenet, batch_size=5000, shuffle=True, num_workers=30)

    x, y, path = next(iter(test_loader))

    with open('path_imgs_2.txt', 'w') as f:
        f.write('\n'.join(path))
        f.flush()
187 |
188 |
--------------------------------------------------------------------------------
/imagenet/robustbench/loaders.py:
--------------------------------------------------------------------------------
1 | """
2 | This file is based on the code from https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py.
3 | """
4 | from torchvision.datasets.vision import VisionDataset
5 |
6 | import torch
7 | import torch.utils.data as data
8 | import torchvision.transforms as transforms
9 |
10 | from PIL import Image
11 |
12 | import os
13 | import os.path
14 | import sys
15 | import json
16 |
17 |
def make_custom_dataset(root, path_imgs, cls_dict):
    """Build (image path, class index) pairs from a fixed image-id list.

    Args:
        root: dataset root directory; each listed id is joined onto it.
        path_imgs: text file with one relative path per line, of the form
            "<class_dir>/<image_file>".
        cls_dict: JSON file mapping class directory name -> class index.

    Returns:
        list of (joined image path, class index) tuples, in file order.
    """
    with open(cls_dict, 'r') as f:
        class_to_idx = json.load(f)
    samples = []
    with open(path_imgs, 'r') as f:
        for line in f:
            rel_path = line.split('\n')[0]   # strip the trailing newline
            label = class_to_idx[line.split('/')[0]]
            samples.append((os.path.join(root, rel_path), label))
    return samples
26 |
27 |
class CustomDatasetFolder(VisionDataset):
    """A generic data loader where the samples are arranged in this way: ::
        root/class_x/xxx.ext
        root/class_x/xxy.ext
        root/class_x/xxz.ext
        root/class_y/123.ext
        root/class_y/nsdf3.ext
        root/class_y/asd932_.ext
    Args:
        root (string): Root directory path.
        loader (callable): A function to load a sample given its path.
        extensions (tuple[string]): A list of allowed extensions.
            both extensions and is_valid_file should not be passed.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        is_valid_file (callable, optional): A function that takes path of an Image file
            and check if the file is a valid_file (used to check of corrupt files)
            both extensions and is_valid_file should not be passed.
    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        samples (list): List of (sample path, class_index) tuples
        targets (list): The class_index value for each image in the dataset

    Note:
        Unlike torchvision's DatasetFolder, samples are NOT discovered by
        scanning ``root``: they come from the fixed robustbench id list (the
        list/json paths are relative to the working directory, so this must
        be run from the repository root). ``extensions``/``is_valid_file``
        are accepted only for interface compatibility and never used for
        filtering. ``__getitem__`` returns a 3-tuple (sample, target, path).
    """

    def __init__(self, root, loader, extensions=None, transform=None, target_transform=None, is_valid_file=None):
        super(CustomDatasetFolder, self).__init__(root)
        self.transform = transform
        self.target_transform = target_transform
        # Class names/indices are derived from the subdirectories of root...
        classes, class_to_idx = self._find_classes(self.root)
        # ...but the sample list itself comes from the fixed robustbench files.
        samples = make_custom_dataset(self.root, 'robustbench/data/imagenet_test_image_ids.txt',
                                      'robustbench/data/imagenet_class_to_id_map.json')
        if len(samples) == 0:
            raise (RuntimeError("Found 0 files in subfolders of: " + self.root + "\n"
                                "Supported extensions are: " + ",".join(extensions)))

        self.loader = loader
        self.extensions = extensions  # stored but unused; see class Note

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

    def _find_classes(self, dir):
        """
        Finds the class folders in a dataset.
        Args:
            dir (string): Root directory path.
        Returns:
            tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
        Ensures:
            No class is a subdirectory of another.
        """
        if sys.version_info >= (3, 5):
            # Faster and available in Python 3.5 and above
            classes = [d.name for d in os.scandir(dir) if d.is_dir()]
        else:
            classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
        classes.sort()
        class_to_idx = {classes[i]: i for i in range(len(classes))}
        return classes, class_to_idx

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (sample, target, path) where target is class_index of the
            target class and path is the sample's file path.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target, path

    def __len__(self):
        return len(self.samples)
111 |
112 |
# Extensions torchvision treats as images; stored by CustomDatasetFolder for
# API compatibility but not used to filter the fixed robustbench sample list.
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
114 |
115 |
def pil_loader(path):
    """Load an image with PIL and force 3-channel RGB."""
    # Open through a file object so the handle is closed deterministically
    # (https://github.com/python-pillow/Pillow/issues/835).
    with open(path, 'rb') as fh:
        return Image.open(fh).convert('RGB')
121 |
122 |
def accimage_loader(path):
    """Load an image with the optional accimage backend, falling back to PIL."""
    import accimage  # imported lazily: accimage is an optional dependency
    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)
130 |
131 |
def default_loader(path):
    """Dispatch to the loader matching torchvision's active image backend."""
    from torchvision import get_image_backend
    backend = get_image_backend()
    if backend == 'accimage':
        return accimage_loader(path)
    return pil_loader(path)
138 |
139 |
class CustomImageFolder(CustomDatasetFolder):
    """A generic data loader where the images are arranged in this way: ::
        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png
        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png
    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        is_valid_file (callable, optional): A function that takes path of an Image file
            and check if the file is a valid_file (used to check of corrupt files)
    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples

    Note:
        Samples come from the fixed robustbench id list (see
        CustomDatasetFolder); ``__getitem__`` yields (image, target, path).
    """

    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader, is_valid_file=None):
        super(CustomImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None,
                                                transform=transform,
                                                target_transform=target_transform,
                                                is_valid_file=is_valid_file)

        # Alias kept for torchvision ImageFolder compatibility.
        self.imgs = self.samples
171 |
172 |
if __name__ == '__main__':
    # One-off helper script: draw a random, seed-fixed batch of 5000
    # validation images and record their paths -- presumably used to build
    # the fixed robustbench image-id list (verify before reuse).
    data_dir = '/home/scratch/datasets/imagenet/val'
    imagenet = CustomImageFolder(data_dir, transforms.Compose([
        transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()]))

    torch.manual_seed(0)  # fixed seed so the shuffled sample is reproducible

    test_loader = data.DataLoader(imagenet, batch_size=5000, shuffle=True, num_workers=30)

    x, y, path = next(iter(test_loader))

    with open('path_imgs_2.txt', 'w') as f:
        f.write('\n'.join(path))
        f.flush()
187 |
188 |
--------------------------------------------------------------------------------
/cifar/cotta.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.jit
6 |
7 | import PIL
8 | import torchvision.transforms as transforms
9 | import my_transforms as my_transforms
10 | from time import time
11 | import logging
12 |
13 |
def get_tta_transforms(gaussian_std: float=0.005, soft=False, clip_inputs=False):
    """Return the CoTTA test-time augmentation pipeline for CIFAR (32x32).

    Args:
        gaussian_std: std of the additive Gaussian noise step.
        soft: if True, use the milder augmentation ranges.
        clip_inputs: unused; kept for interface compatibility.
    """
    n_pixels = 32  # CIFAR images are 32x32
    clip_min, clip_max = 0.0, 1.0
    p_hflip = 0.5

    # Milder jitter ranges when `soft`, the defaults otherwise.
    jitter = my_transforms.ColorJitterPro(
        brightness=[0.8, 1.2] if soft else [0.6, 1.4],
        contrast=[0.85, 1.15] if soft else [0.7, 1.3],
        saturation=[0.75, 1.25] if soft else [0.5, 1.5],
        hue=[-0.03, 0.03] if soft else [-0.06, 0.06],
        gamma=[0.85, 1.15] if soft else [0.7, 1.3]
    )
    affine = transforms.RandomAffine(
        degrees=[-8, 8] if soft else [-15, 15],
        translate=(1/16, 1/16),
        scale=(0.95, 1.05) if soft else (0.9, 1.1),
        shear=None,
        resample=PIL.Image.BILINEAR,
        fillcolor=None
    )
    blur = transforms.GaussianBlur(
        kernel_size=5, sigma=[0.001, 0.25] if soft else [0.001, 0.5])

    return transforms.Compose([
        my_transforms.Clip(0.0, 1.0),
        jitter,
        # Pad with edge values so the affine warp / crop never see empty borders.
        transforms.Pad(padding=int(n_pixels / 2), padding_mode='edge'),
        affine,
        blur,
        transforms.CenterCrop(size=n_pixels),
        transforms.RandomHorizontalFlip(p=p_hflip),
        my_transforms.GaussianNoise(0, gaussian_std),
        my_transforms.Clip(clip_min, clip_max)
    ])
47 |
48 |
def update_ema_variables(ema_model, model, alpha_teacher):
    """In-place EMA update of the teacher: ema <- a*ema + (1-a)*student."""
    for t_param, s_param in zip(ema_model.parameters(), model.parameters()):
        blended = alpha_teacher * t_param.data + (1 - alpha_teacher) * s_param.data
        t_param.data.copy_(blended)
    return ema_model
53 |
54 |
55 | class CoTTA(nn.Module):
56 | """CoTTA adapts a model by entropy minimization during testing.
57 |
58 | Once tented, a model adapts itself by updating on every forward.
59 | """
    def __init__(self, model, optimizer, steps=1, episodic=False, mt_alpha=0.99, rst_m=0.1, ap=0.9):
        super().__init__()
        self.model = model
        self.optimizer = optimizer
        self.steps = steps  # gradient updates per incoming batch
        assert steps > 0, "cotta requires >= 1 step(s) to forward and update"
        self.episodic = episodic  # if True, reset to the source model each batch

        # Snapshot the source weights/optimizer state, and create the EMA
        # teacher plus a frozen anchor copy of the source model.
        self.model_state, self.optimizer_state, self.model_ema, self.model_anchor = \
            copy_model_and_optimizer(self.model, self.optimizer)
        self.transform = get_tta_transforms()
        self.mt = mt_alpha  # teacher EMA momentum -- consumed in forward_and_adapt
        self.rst = rst_m    # stochastic restore probability -- TODO confirm usage
        self.ap = ap        # anchor confidence threshold (see forward_and_adapt)
74 |
    def forward(self, x):
        # In episodic mode, start every batch from the pristine source model.
        if self.episodic:
            self.reset()

        # Each step adapts the model on this batch and re-predicts; the
        # outputs of the final step are returned.
        for _ in range(self.steps):
            outputs = self.forward_and_adapt(x, self.model, self.optimizer)

        return outputs
83 |
    def reset(self):
        """Restore student, teacher, and anchor to their source states."""
        if self.model_state is None or self.optimizer_state is None:
            raise Exception("cannot reset without saved model/optimizer state")
        load_model_and_optimizer(self.model, self.optimizer,
                                 self.model_state, self.optimizer_state)
        # Use this line to also restore the teacher model
        self.model_state, self.optimizer_state, self.model_ema, self.model_anchor = \
            copy_model_and_optimizer(self.model, self.optimizer)
92 |
93 |
94 | @torch.enable_grad() # ensure grads in possible no grad context for testing
95 | def forward_and_adapt(self, x, model, optimizer):
96 | outputs = self.model(x)
97 | # Teacher Prediction
98 | anchor_prob = torch.nn.functional.softmax(self.model_anchor(x), dim=1).max(1)[0]
99 | standard_ema = self.model_ema(x)
100 | # Augmentation-averaged Prediction
101 | N = 32
102 | outputs_emas = []
103 | for i in range(N):
104 | outputs_ = self.model_ema(self.transform(x)).detach()
105 | outputs_emas.append(outputs_)
106 | # Threshold choice discussed in supplementary
107 | if anchor_prob.mean(0) torch.Tensor:
131 | """Entropy of softmax distribution from logits."""
132 | return -(x_ema.softmax(1) * x.log_softmax(1)).sum(1)
133 |
def collect_params(model):
    """Collect all trainable weight/bias parameters.

    Walks the model's modules and collects every `weight`/`bias` parameter
    that requires grad. Unlike tent, CoTTA collects from ALL modules, not
    just BatchNorm layers.

    Args:
        model: the nn.Module to scan.

    Returns:
        (params, names): parameter tensors and their dotted names.
    """
    params = []
    names = []
    for nm, m in model.named_modules():
        # `np_` avoids shadowing the conventional numpy alias; the original
        # debug print of every parameter name has been removed.
        for np_, p in m.named_parameters():
            if np_ in ['weight', 'bias'] and p.requires_grad:
                params.append(p)
                names.append(f"{nm}.{np_}")
    return params, names
152 |
153 |
def copy_model_and_optimizer(model, optimizer):
    """Snapshot everything needed to reset and to run CoTTA's teachers.

    Returns:
        model_state: deep-copied state_dict of the source model.
        optimizer_state: deep-copied optimizer state.
        ema_model: detached deep copy used as the EMA teacher.
        model_anchor: deep copy kept as the frozen anchor model.
    """
    model_state = deepcopy(model.state_dict())
    optimizer_state = deepcopy(optimizer.state_dict())
    model_anchor = deepcopy(model)
    ema_model = deepcopy(model)
    # The teacher is never updated by the optimizer: detach its parameters.
    for p in ema_model.parameters():
        p.detach_()
    return model_state, optimizer_state, ema_model, model_anchor
163 |
164 |
def load_model_and_optimizer(model, optimizer, model_state, optimizer_state):
    """Restore model weights and optimizer state from saved copies."""
    optimizer.load_state_dict(optimizer_state)
    model.load_state_dict(model_state, strict=True)
169 |
170 |
def configure_model(model):
    """Configure model for CoTTA adaptation (adapted from tent's setup).

    Puts the model in train mode and enables gradients for every module;
    BatchNorm layers are additionally switched to always use batch
    statistics instead of running estimates.
    """
    # train mode, because tent optimizes the model to minimize entropy
    model.train()
    # disable grad, to (re-)enable only what we update
    model.requires_grad_(False)
    # enable all trainable
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.requires_grad_(True)
            # force use of batch stats in train and eval modes
            m.track_running_stats = False
            m.running_mean = None
            m.running_var = None
        else:
            # NOTE(review): unlike tent, non-BN modules are trainable too,
            # which makes the requires_grad_(False) above effectively a no-op.
            m.requires_grad_(True)
    return model
188 |
189 |
def check_model(model):
    """Sanity-check that the model is set up for test-time adaptation."""
    assert model.training, "tent needs train mode: call model.train()"
    grads = [p.requires_grad for p in model.parameters()]
    # Some parameters must be trainable, but never all of them.
    assert any(grads), "tent needs params to update: " \
                       "check which require grad"
    assert not all(grads), "tent should not update all params: " \
                           "check which require grad"
    # Adaptation relies on normalization layers being present.
    assert any(isinstance(m, nn.BatchNorm2d) for m in model.modules()), \
        "tent needs normalization for its optimization"
203 |
--------------------------------------------------------------------------------
/imagenet/cotta.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.jit
6 |
7 | import PIL
8 | import torchvision.transforms as transforms
9 | import my_transforms as my_transforms
10 | from time import time
11 | import logging
12 |
13 |
def get_tta_transforms(gaussian_std: float=0.005, soft=False, clip_inputs=False):
    """Return the CoTTA test-time augmentation pipeline for ImageNet (224x224).

    Args:
        gaussian_std: std of the additive Gaussian noise step.
        soft: if True, use the milder augmentation ranges.
        clip_inputs: unused; kept for interface compatibility.
    """
    img_shape = (224, 224, 3)
    n_pixels = img_shape[0]

    clip_min, clip_max = 0.0, 1.0

    p_hflip = 0.5

    tta_transforms = transforms.Compose([
        my_transforms.Clip(0.0, 1.0),
        # Milder jitter ranges when `soft`, the defaults otherwise.
        my_transforms.ColorJitterPro(
            brightness=[0.8, 1.2] if soft else [0.6, 1.4],
            contrast=[0.85, 1.15] if soft else [0.7, 1.3],
            saturation=[0.75, 1.25] if soft else [0.5, 1.5],
            hue=[-0.03, 0.03] if soft else [-0.06, 0.06],
            gamma=[0.85, 1.15] if soft else [0.7, 1.3]
        ),
        # Pad with edge values so the affine warp / crop never see empty borders.
        transforms.Pad(padding=int(n_pixels / 2), padding_mode='edge'),
        transforms.RandomAffine(
            degrees=[-8, 8] if soft else [-15, 15],
            translate=(1/16, 1/16),
            scale=(0.95, 1.05) if soft else (0.9, 1.1),
            shear=None,
            # NOTE(review): resample/fillcolor were removed in newer
            # torchvision (use interpolation/fill); pin the version or migrate.
            resample=PIL.Image.BILINEAR,
            fillcolor=None
        ),
        transforms.GaussianBlur(kernel_size=5, sigma=[0.001, 0.25] if soft else [0.001, 0.5]),
        transforms.CenterCrop(size=n_pixels),
        transforms.RandomHorizontalFlip(p=p_hflip),
        my_transforms.GaussianNoise(0, gaussian_std),
        my_transforms.Clip(clip_min, clip_max)
    ])
    return tta_transforms
47 |
48 |
def update_ema_variables(ema_model, model, alpha_teacher):
    """In-place EMA update of the teacher: ema = a * ema + (1 - a) * student.

    :param ema_model: teacher model, updated in place.
    :param model: student model providing the new parameter values.
    :param alpha_teacher: EMA momentum in [0, 1]; higher = slower teacher.
    :return: the (mutated) ``ema_model`` for call-chaining convenience.
    """
    one_minus_alpha = 1 - alpha_teacher
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        # In-place mul_/add_ replaces the original redundant [:].data[:] view
        # chains; same arithmetic, no temporary full-size tensors.
        ema_param.data.mul_(alpha_teacher).add_(param.data, alpha=one_minus_alpha)
    return ema_model
53 |
54 |
class CoTTA(nn.Module):
    """CoTTA adapts a model by entropy minimization during testing.

    Once tented, a model adapts itself by updating on every forward.
    A frozen "anchor" copy gates augmentation averaging, an EMA teacher
    provides the pseudo-targets, and a small random fraction of student
    weights is restored to the source state after every update.
    """
    def __init__(self, model, optimizer, steps=1, episodic=False):
        super().__init__()
        self.model = model
        self.optimizer = optimizer
        self.steps = steps
        assert steps > 0, "cotta requires >= 1 step(s) to forward and update"
        self.episodic = episodic

        # Snapshot source weights/optimizer (for reset and stochastic restore)
        # plus the detached EMA teacher and the frozen anchor copy.
        self.model_state, self.optimizer_state, self.model_ema, self.model_anchor = \
            copy_model_and_optimizer(self.model, self.optimizer)
        self.transform = get_tta_transforms()

    def forward(self, x):
        if self.episodic:
            self.reset()

        for _ in range(self.steps):
            outputs = self.forward_and_adapt(x, self.model, self.optimizer)

        return outputs

    def reset(self):
        if self.model_state is None or self.optimizer_state is None:
            raise Exception("cannot reset without saved model/optimizer state")
        load_model_and_optimizer(self.model, self.optimizer,
                                 self.model_state, self.optimizer_state)
        # use this line if you want to reset the teacher model as well. Maybe you also
        # want to del self.model_ema first to save gpu memory.
        self.model_state, self.optimizer_state, self.model_ema, self.model_anchor = \
            copy_model_and_optimizer(self.model, self.optimizer)

    @torch.enable_grad()  # ensure grads in possible no grad context for testing
    def forward_and_adapt(self, x, model, optimizer):
        """One adaptation step: student forward, teacher target, update, restore."""
        outputs = self.model(x)
        self.model_ema.train()
        # Teacher prediction; anchor confidence decides whether to average
        # over augmentations (threshold choice discussed in supplementary).
        anchor_prob = torch.nn.functional.softmax(self.model_anchor(x), dim=1).max(1)[0]
        standard_ema = self.model_ema(x)
        # Augmentation-averaged prediction
        N = 32  # number of augmented views averaged when confidence is low
        outputs_emas = []
        to_aug = anchor_prob.mean(0) < 0.1
        if to_aug:
            for _ in range(N):
                outputs_ = self.model_ema(self.transform(x)).detach()
                outputs_emas.append(outputs_)
            outputs_ema = torch.stack(outputs_emas).mean(0)
        else:
            outputs_ema = standard_ema
        # Student update: symmetric soft cross-entropy against the teacher target.
        loss = (softmax_entropy(outputs, outputs_ema.detach())).mean(0)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # Teacher update
        self.model_ema = update_ema_variables(ema_model=self.model_ema, model=self.model, alpha_teacher=0.999)
        # Stochastic restore: reset ~0.1% of trainable weights to the source state.
        for nm, m in self.model.named_modules():
            for npp, p in m.named_parameters():
                if npp in ['weight', 'bias'] and p.requires_grad:
                    # was .cuda(); use p.device so CPU-only runs also work
                    mask = (torch.rand(p.shape) < 0.001).float().to(p.device)
                    with torch.no_grad():
                        p.data = self.model_state[f"{nm}.{npp}"] * mask + p * (1. - mask)
        return outputs_ema
129 |
130 |
@torch.jit.script
def softmax_entropy(x, x_ema):
    """Symmetric soft cross-entropy between student and teacher logits.

    Averages the cross-entropy of the student against the teacher's
    softmax and vice versa; returns one value per row (per sample).
    """
    teacher_to_student = -(x_ema.softmax(1) * x.log_softmax(1)).sum(1)
    student_to_teacher = -(x.softmax(1) * x_ema.log_softmax(1)).sum(1)
    return 0.5 * teacher_to_student + 0.5 * student_to_teacher
135 |
def collect_params(model):
    """Collect all trainable parameters.

    Walk the model's modules and collect all parameters.
    Return the parameters and their names.

    Note: other choices of parameterization are possible!
    """
    params = []
    names = []
    # CoTTA adapts every module; the original tent restriction to
    # nn.BatchNorm2d was deliberately dropped here.
    for module_name, module in model.named_modules():
        # named_parameters() recurses, but only direct 'weight'/'bias'
        # entries pass the filter below, so each parameter is kept once
        # under its fully qualified "<module>.<param>" name.
        for param_name, param in module.named_parameters():
            if param_name in ['weight', 'bias'] and param.requires_grad:
                params.append(param)
                names.append(f"{module_name}.{param_name}")
                print(module_name, param_name)
    return params, names
154 |
155 |
def copy_model_and_optimizer(model, optimizer):
    """Snapshot everything needed to reset and regularize adaptation.

    Returns, in order: the source model state_dict, the optimizer
    state_dict, a detached EMA (teacher) copy of the model, and a
    frozen anchor copy of the model.
    """
    source_state = deepcopy(model.state_dict())
    anchor = deepcopy(model)
    optim_state = deepcopy(optimizer.state_dict())
    teacher = deepcopy(model)
    # Detach the teacher's parameters in place: it is updated via EMA
    # assignments, never by gradient descent.
    for p in teacher.parameters():
        p.detach_()
    return source_state, optim_state, teacher, anchor
165 |
166 |
def load_model_and_optimizer(model, optimizer, model_state, optimizer_state):
    """Restore the model and optimizer states from saved copies.

    ``strict=True`` guarantees the snapshot's keys exactly match the
    model's current parameter/buffer names, failing loudly otherwise.
    """
    model.load_state_dict(model_state, strict=True)
    optimizer.load_state_dict(optimizer_state)
171 |
172 |
def configure_model(model):
    """Configure model for CoTTA-style adaptation.

    Puts the model in train mode and makes every parameter trainable,
    while switching BatchNorm2d layers to batch statistics in both
    train and eval modes (running stats are dropped entirely).
    """
    # train mode, because the entropy objective is optimized online
    model.train()
    # start from all-frozen, then re-enable below
    model.requires_grad_(False)
    for module in model.modules():
        if not isinstance(module, nn.BatchNorm2d):
            # CoTTA updates all params, not only normalization layers
            module.requires_grad_(True)
            continue
        module.requires_grad_(True)
        # force use of batch stats in train and eval modes
        module.track_running_stats = False
        module.running_mean = None
        module.running_var = None
    return model
190 |
191 |
def check_model(model):
    """Check model for compatibility with tent.

    NOTE(review): inherited from tent — the "not all params" assertion
    conflicts with this file's configure_model, which enables grad on
    every parameter; confirm this checker is not called on a
    CoTTA-configured model.
    """
    is_training = model.training
    assert is_training, "tent needs train mode: call model.train()"
    param_grads = [p.requires_grad for p in model.parameters()]
    has_any_params = any(param_grads)
    has_all_params = all(param_grads)
    assert has_any_params, "tent needs params to update: " \
                           "check which require grad"
    assert not has_all_params, "tent should not update all params: " \
                               "check which require grad"
    # Generator expression: no need to materialize a list just for any().
    has_bn = any(isinstance(m, nn.BatchNorm2d) for m in model.modules())
    assert has_bn, "tent needs normalization for its optimization"
205 |
--------------------------------------------------------------------------------
/cifar/robustbench/eval.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from argparse import Namespace
3 | from pathlib import Path
4 | from typing import Dict, Optional, Sequence, Tuple, Union
5 |
6 | import numpy as np
7 | import pandas as pd
8 | import torch
9 | import random
10 | from autoattack import AutoAttack
11 | from torch import nn
12 | from tqdm import tqdm
13 |
14 | from robustbench.data import CORRUPTIONS, load_clean_dataset, \
15 | CORRUPTION_DATASET_LOADERS
16 | from robustbench.model_zoo.enums import BenchmarkDataset, ThreatModel
17 | from robustbench.utils import clean_accuracy, load_model, parse_args, update_json
18 | from robustbench.model_zoo import model_dicts as all_models
19 |
20 |
def benchmark(model: Union[nn.Module, Sequence[nn.Module]],
              n_examples: int = 10000,
              dataset: Union[str,
                             BenchmarkDataset] = BenchmarkDataset.cifar_10,
              threat_model: Union[str, ThreatModel] = ThreatModel.Linf,
              to_disk: bool = False,
              model_name: Optional[str] = None,
              data_dir: str = "./data",
              device: Optional[Union[torch.device,
                                     Sequence[torch.device]]] = None,
              batch_size: int = 32,
              eps: Optional[float] = None,
              log_path: Optional[str] = None) -> Tuple[float, float]:
    """Benchmarks the given model(s).

    It is possible to benchmark on 3 different threat models, and to save the results on disk. In
    the future benchmarking multiple models in parallel is going to be possible.

    :param model: The model to benchmark.
    :param n_examples: The number of examples to use to benchmark the model.
    :param dataset: The dataset to use to benchmark. Must be one of {cifar10, cifar100}
    :param threat_model: The threat model to use to benchmark, must be one of {L2, Linf
    corruptions}
    :param to_disk: Whether the results must be saved on disk as .json.
    :param model_name: The name of the model to use to save the results. Must be specified if
    to_json is True.
    :param data_dir: The directory where the dataset is or where the dataset must be downloaded.
    :param device: The device to run the computations.
    :param batch_size: The batch size to run the computations. The larger, the faster the
    evaluation.
    :param eps: The epsilon to use for L2 and Linf threat models. Must not be specified for
    corruptions threat model.
    :param log_path: Optional file path passed to AutoAttack for its log output.

    :return: A Tuple with the clean accuracy and the accuracy in the given threat model.
    """
    if isinstance(model, Sequence) or isinstance(device, Sequence):
        # Multiple models evaluation in parallel not yet implemented
        raise NotImplementedError

    try:
        if model.training:
            warnings.warn(Warning("The given model is *not* in eval mode."))
    except AttributeError:
        warnings.warn(
            Warning(
                "It is not possible to asses if the model is in eval mode"))

    # Normalize string arguments to their enum counterparts.
    dataset_: BenchmarkDataset = BenchmarkDataset(dataset)
    threat_model_: ThreatModel = ThreatModel(threat_model)

    device = device or torch.device("cpu")
    model = model.to(device)

    # BUGFIX: compare on the normalized enum's value instead of the raw
    # `dataset` argument, so that passing BenchmarkDataset.imagenet (and not
    # only the string 'imagenet') also selects the model's preprocessing.
    if dataset_.value == 'imagenet':
        prepr = all_models[dataset_][threat_model_][model_name]['preprocessing']
    else:
        prepr = 'none'

    clean_x_test, clean_y_test = load_clean_dataset(dataset_, n_examples,
                                                    data_dir, prepr)

    accuracy = clean_accuracy(model,
                              clean_x_test,
                              clean_y_test,
                              batch_size=batch_size,
                              device=device)
    print(f'Clean accuracy: {accuracy:.2%}')

    if threat_model_ in {ThreatModel.Linf, ThreatModel.L2}:
        if eps is None:
            raise ValueError(
                "If the threat model is L2 or Linf, `eps` must be specified.")

        adversary = AutoAttack(model,
                               norm=threat_model_.value,
                               eps=eps,
                               version='standard',
                               device=device,
                               log_path=log_path)
        x_adv = adversary.run_standard_evaluation(clean_x_test, clean_y_test, bs=batch_size)
        # Re-measure accuracy on the returned adversarial examples.
        adv_accuracy = clean_accuracy(model,
                                      x_adv,
                                      clean_y_test,
                                      batch_size=batch_size,
                                      device=device)
    elif threat_model_ == ThreatModel.corruptions:
        corruptions = CORRUPTIONS
        print(f"Evaluating over {len(corruptions)} corruptions")
        # Save into a dict to make a Pandas DF with nested index
        adv_accuracy = corruptions_evaluation(batch_size, data_dir, dataset_,
                                              device, model, n_examples,
                                              to_disk, prepr, model_name)
    else:
        raise NotImplementedError
    print(f'Adversarial accuracy: {adv_accuracy:.2%}')

    if to_disk:
        if model_name is None:
            raise ValueError(
                "If `to_disk` is True, `model_name` should be specified.")

        update_json(dataset_, threat_model_, model_name, accuracy,
                    adv_accuracy, eps)

    return accuracy, adv_accuracy
126 |
127 |
def corruptions_evaluation(batch_size: int, data_dir: str,
                           dataset: BenchmarkDataset, device: torch.device,
                           model: nn.Module, n_examples: int, to_disk: bool,
                           prepr: str, model_name: Optional[str]) -> float:
    """Evaluate `model` on every corruption at severities 1-5; return mean accuracy.

    When `to_disk` is True, the per-(corruption, severity) accuracies are also
    appended to model_info/<dataset>/corruptions/unaggregated_results.csv,
    whose columns are a two-level (corruption, severity) MultiIndex with one
    row per model name.
    """
    if to_disk and model_name is None:
        raise ValueError(
            "If `to_disk` is True, `model_name` should be specified.")

    corruptions = CORRUPTIONS
    model_results_dict: Dict[Tuple[str, int], float] = {}
    for corruption in tqdm(corruptions):
        for severity in range(1, 6):
            # Load only this corruption at this severity; shuffle=False keeps
            # the example set identical across models for fair comparison.
            x_corrupt, y_corrupt = CORRUPTION_DATASET_LOADERS[dataset](
                n_examples,
                severity,
                data_dir,
                shuffle=False,
                corruptions=[corruption],
                prepr=prepr)

            corruption_severity_accuracy = clean_accuracy(
                model,
                x_corrupt,
                y_corrupt,
                batch_size=batch_size,
                device=device)
            print('corruption={}, severity={}: {:.2%} accuracy'.format(
                corruption, severity, corruption_severity_accuracy))

            model_results_dict[(corruption,
                                severity)] = corruption_severity_accuracy

    # Tuple keys become a (corruption, severity) column MultiIndex.
    model_results = pd.DataFrame(model_results_dict, index=[model_name])
    adv_accuracy = model_results.values.mean()

    if not to_disk:
        return adv_accuracy

    # Save disaggregated results on disk
    existing_results_path = Path(
        "model_info"
    ) / dataset.value / "corruptions" / "unaggregated_results.csv"
    if not existing_results_path.parent.exists():
        existing_results_path.parent.mkdir(parents=True, exist_ok=True)
    try:
        existing_results = pd.read_csv(existing_results_path,
                                       header=[0, 1],
                                       index_col=0)
        # The CSV round-trip reads the severity level back as strings; cast
        # to int so concat aligns these columns with the fresh results.
        existing_results.columns = existing_results.columns.set_levels([
            existing_results.columns.levels[0],
            existing_results.columns.levels[1].astype(int)
        ])
        full_results = pd.concat([existing_results, model_results])
    except FileNotFoundError:
        # First run: nothing on disk yet, write this model's row alone.
        full_results = model_results
    full_results.to_csv(existing_results_path)

    return adv_accuracy
186 |
187 |
def main(args: Namespace) -> None:
    """Entry point: seed all RNG sources, load the requested model, benchmark it."""
    # Seed every RNG source with the same value for reproducibility.
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed,
                    np.random.seed, random.seed):
        seed_fn(args.seed)

    model = load_model(args.model_name,
                       model_dir=args.model_dir,
                       dataset=args.dataset,
                       threat_model=args.threat_model)
    model.eval()

    benchmark(model,
              n_examples=args.n_ex,
              dataset=args.dataset,
              threat_model=args.threat_model,
              to_disk=args.to_disk,
              model_name=args.model_name,
              data_dir=args.data_dir,
              device=torch.device(args.device),
              batch_size=args.batch_size,
              eps=args.eps)
212 |
213 |
if __name__ == '__main__':
    # Example:
    # python -m robustbench.eval --n_ex=5000 --dataset=imagenet --threat_model=Linf \
    # --model_name=Salman2020Do_R18 --data_dir=/tmldata1/andriush/imagenet/val \
    # --batch_size=128 --eps=0.0156862745
    # See robustbench.utils.parse_args (imported above) for the full flag list.
    args_ = parse_args()
    main(args_)
--------------------------------------------------------------------------------