├── README.md
├── configs
│   ├── eval
│   │   ├── cifar10
│   │   │   ├── adversarial.json
│   │   │   ├── hmc_p=1.json
│   │   │   ├── hmc_p=10.json
│   │   │   ├── hmc_p=100.json
│   │   │   ├── hmc_p=1000.json
│   │   │   ├── random_sampling_p=1.json
│   │   │   ├── random_sampling_p=10.json
│   │   │   ├── random_sampling_p=100.json
│   │   │   ├── random_sampling_p=1000.json
│   │   │   └── standard.json
│   │   └── mnist
│   │       ├── adversarial.json
│   │       ├── hmc_p=1.json
│   │       ├── hmc_p=10.json
│   │       ├── hmc_p=100.json
│   │       ├── hmc_p=1000.json
│   │       ├── random_sampling_p=1.json
│   │       ├── random_sampling_p=10.json
│   │       ├── random_sampling_p=100.json
│   │       ├── random_sampling_p=1000.json
│   │       └── standard.json
│   └── train
│       ├── cifar10
│       │   ├── adversarial.json
│       │   ├── random_sampling_p=1.json
│       │   ├── random_sampling_p=10.json
│       │   ├── random_sampling_p=100.json
│       │   ├── random_sampling_p=1000.json
│       │   ├── rs_p=100_discrete.json
│       │   ├── rs_p=10_discrete.json
│       │   ├── rs_p=1_discrete.json
│       │   ├── standard.json
│       │   └── standard_discrete.json
│       └── mnist
│           ├── adversarial.json
│           ├── hmc_p=1.json
│           ├── hmc_p=10.json
│           ├── hmc_p=100.json
│           ├── hmc_p=1000.json
│           ├── random_sampling_p=1.json
│           ├── random_sampling_p=10.json
│           ├── random_sampling_p=100.json
│           ├── random_sampling_p=1000.json
│           └── standard.json
├── datasets.py
├── eval_discrete.py
├── models
│   ├── mnist.py
│   └── preactresnet.py
├── requirements.txt
├── train.py
├── train_discrete.py
└── utils.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Robustness between the worst and average case
2 | 
3 | *This repository implements intermediate robustness training and evaluation from the NeurIPS 2021 paper [Robustness between the worst and average case](https://proceedings.neurips.cc/paper/2021/file/ea4c796cccfc3899b5f9ae2874237c20-Paper.pdf).
4 | Created by [Leslie Rice](https://leslierice1.github.io/), [Anna Bair](https://annaebair.github.io/), [Huan Zhang](https://www.huan-zhang.com/), and [Zico Kolter](http://zicokolter.com).*
5 | 
6 | ## Installation and usage
7 | - To install all required packages, run: `pip install -r requirements.txt`.
8 | - Pretrained model weights can be downloaded [here](https://drive.google.com/drive/folders/1YCFXzdx2dGjmQGU30v6CRhUORjQsjHKV?usp=sharing) (see the loading sketch below).
9 | - To train (l_infty perturbations), run `python train.py -c {path_to_training_config_file}.json`.
10 | - To evaluate (l_infty perturbations), run `python train.py -c {path_to_evaluation_config_file}.json`; the evaluation configs set `"training": null`, so only the listed evaluations are run.
11 | - To train (spatial transformations), run `python train_discrete.py -c {path_to_training_config_file}.json`.
12 | - To evaluate (spatial transformations), run `python eval_discrete.py --checkpoint {path_to_model_checkpoint}.pth`.
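For reference, a minimal sketch of loading one of the pretrained checkpoints into the bundled model definitions. The filename `checkpoint_199.pth` is a placeholder, and the fallback handling of the checkpoint layout (a raw state dict vs. one stored under a `'model_state_dict'` key) is an assumption; inspect the downloaded `.pth` file if its keys differ:

```python
import torch

from models.preactresnet import preactresnet18  # CIFAR-10 model
# from models.mnist import mnist_classifier     # MNIST model

model = preactresnet18().cuda()
# 'checkpoint_199.pth' is a placeholder; point this at a downloaded file.
ckpt = torch.load('checkpoint_199.pth', map_location='cuda')
# Assumption: the file is either a raw state dict or nests one under
# 'model_state_dict'; fall back to the raw dict if the key is absent.
state = ckpt.get('model_state_dict', ckpt) if isinstance(ckpt, dict) else ckpt
model.load_state_dict(state)
model.eval()
```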
13 | -------------------------------------------------------------------------------- /configs/eval/cifar10/adversarial.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "adversarial", 3 | "seed": 1, 4 | "training": null, 5 | "data": { 6 | "dataset": "cifar10", 7 | "num_workers": 2, 8 | "root": "../data/cifar10", 9 | "training": null, 10 | "test": { 11 | "batch_size": 128 12 | }, 13 | "use_half": false 14 | }, 15 | "evaluations": [ 16 | { 17 | "type": "standard", 18 | "params": {} 19 | }, 20 | { 21 | "type": "adversarial", 22 | "params": { 23 | "epsilon": 0.03, 24 | "m": 50, 25 | "restarts": 10 26 | } 27 | }, 28 | { 29 | "type": "random_sampling", 30 | "params": { 31 | "epsilon": 0.03, 32 | "p": 1, 33 | "m": 500 34 | } 35 | }, 36 | { 37 | "type": "random_sampling", 38 | "params": { 39 | "epsilon": 0.03, 40 | "p": 10, 41 | "m": 500 42 | } 43 | }, 44 | { 45 | "type": "random_sampling", 46 | "params": { 47 | "epsilon": 0.03, 48 | "p": 100, 49 | "m": 500 50 | } 51 | }, 52 | { 53 | "type": "random_sampling", 54 | "params": { 55 | "epsilon": 0.03, 56 | "p": 1000, 57 | "m": 500 58 | } 59 | }, 60 | { 61 | "type": "hmc", 62 | "params": { 63 | "epsilon": 0.03, 64 | "p": 1, 65 | "m": 50, 66 | "l": 10, 67 | "path_len": 0.6, 68 | "sigma": 0.1, 69 | "anneal_theta": true 70 | } 71 | }, 72 | { 73 | "type": "hmc", 74 | "params": { 75 | "epsilon": 0.03, 76 | "p": 10, 77 | "m": 50, 78 | "l": 10, 79 | "path_len": 0.6, 80 | "sigma": 0.1, 81 | "anneal_theta": true 82 | } 83 | }, 84 | { 85 | "type": "hmc", 86 | "params": { 87 | "epsilon": 0.03, 88 | "p": 100, 89 | "m": 50, 90 | "l": 10, 91 | "path_len": 0.4, 92 | "sigma": 0.1, 93 | "anneal_theta": true 94 | } 95 | }, 96 | { 97 | "type": "hmc", 98 | "params": { 99 | "epsilon": 0.03, 100 | "p": 1000, 101 | "m": 50, 102 | "l": 10, 103 | "path_len": 0.09, 104 | "sigma": 0.1, 105 | "anneal_theta": true 106 | } 107 | } 108 | ], 109 | "checkpoint_filename": "experiments/cifar10/adversarial_seed=2/checkpoints/checkpoint_10.pth" 110 | } 111 | -------------------------------------------------------------------------------- /configs/eval/cifar10/hmc_p=1.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "hmc_p=1", 3 | "seed": 1, 4 | "training": null, 5 | "data": { 6 | "dataset": "cifar10", 7 | "num_workers": 2, 8 | "root": "../data/cifar10", 9 | "training": null, 10 | "test": { 11 | "batch_size": 128 12 | }, 13 | "use_half": false 14 | }, 15 | "evaluations": [ 16 | { 17 | "type": "standard", 18 | "params": {} 19 | }, 20 | { 21 | "type": "adversarial", 22 | "params": { 23 | "epsilon": 0.03, 24 | "m": 50, 25 | "restarts": 10 26 | } 27 | }, 28 | { 29 | "type": "random_sampling", 30 | "params": { 31 | "epsilon": 0.03, 32 | "p": 1, 33 | "m": 500 34 | } 35 | }, 36 | { 37 | "type": "random_sampling", 38 | "params": { 39 | "epsilon": 0.03, 40 | "p": 10, 41 | "m": 500 42 | } 43 | }, 44 | { 45 | "type": "random_sampling", 46 | "params": { 47 | "epsilon": 0.03, 48 | "p": 100, 49 | "m": 500 50 | } 51 | }, 52 | { 53 | "type": "random_sampling", 54 | "params": { 55 | "epsilon": 0.03, 56 | "p": 1000, 57 | "m": 500 58 | } 59 | }, 60 | { 61 | "type": "hmc", 62 | "params": { 63 | "epsilon": 0.03, 64 | "p": 1, 65 | "m": 100, 66 | "l": 10, 67 | "path_len": 0.6, 68 | "sigma": 0.1, 69 | "anneal_theta": true 70 | } 71 | }, 72 | { 73 | "type": "hmc", 74 | "params": { 75 | "epsilon": 0.03, 76 | "p": 10, 77 | "m": 50, 78 | "l": 10, 79 | "path_len": 0.6, 80 | "sigma": 0.1, 81 | 
"anneal_theta": true 82 | } 83 | }, 84 | { 85 | "type": "hmc", 86 | "params": { 87 | "epsilon": 0.03, 88 | "p": 100, 89 | "m": 50, 90 | "l": 10, 91 | "path_len": 0.4, 92 | "sigma": 0.1, 93 | "anneal_theta": true 94 | } 95 | }, 96 | { 97 | "type": "hmc", 98 | "params": { 99 | "epsilon": 0.03, 100 | "p": 1000, 101 | "m": 50, 102 | "l": 10, 103 | "path_len": 0.09, 104 | "sigma": 0.1, 105 | "anneal_theta": true 106 | } 107 | } 108 | ], 109 | "checkpoint_filename": "experiments/cifar10/hmc_p=1/checkpoints/checkpoint_99.pth" 110 | } 111 | -------------------------------------------------------------------------------- /configs/eval/cifar10/hmc_p=10.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "hmc_p=10", 3 | "seed": 1, 4 | "training": null, 5 | "data": { 6 | "dataset": "cifar10", 7 | "num_workers": 2, 8 | "root": "../data/cifar10", 9 | "training": null, 10 | "test": { 11 | "batch_size": 128 12 | }, 13 | "use_half": false 14 | }, 15 | "evaluations": [ 16 | { 17 | "type": "standard", 18 | "params": {} 19 | }, 20 | { 21 | "type": "adversarial", 22 | "params": { 23 | "epsilon": 0.03, 24 | "m": 50, 25 | "restarts": 10 26 | } 27 | }, 28 | { 29 | "type": "random_sampling", 30 | "params": { 31 | "epsilon": 0.03, 32 | "p": 1, 33 | "m": 500 34 | } 35 | }, 36 | { 37 | "type": "random_sampling", 38 | "params": { 39 | "epsilon": 0.03, 40 | "p": 10, 41 | "m": 500 42 | } 43 | }, 44 | { 45 | "type": "random_sampling", 46 | "params": { 47 | "epsilon": 0.03, 48 | "p": 100, 49 | "m": 500 50 | } 51 | }, 52 | { 53 | "type": "random_sampling", 54 | "params": { 55 | "epsilon": 0.03, 56 | "p": 1000, 57 | "m": 500 58 | } 59 | }, 60 | { 61 | "type": "hmc", 62 | "params": { 63 | "epsilon": 0.03, 64 | "p": 1, 65 | "m": 100, 66 | "l": 10, 67 | "path_len": 0.6, 68 | "sigma": 0.1, 69 | "anneal_theta": true 70 | } 71 | }, 72 | { 73 | "type": "hmc", 74 | "params": { 75 | "epsilon": 0.03, 76 | "p": 10, 77 | "m": 50, 78 | "l": 10, 79 | "path_len": 0.6, 80 | "sigma": 0.1, 81 | "anneal_theta": true 82 | } 83 | }, 84 | { 85 | "type": "hmc", 86 | "params": { 87 | "epsilon": 0.03, 88 | "p": 100, 89 | "m": 50, 90 | "l": 10, 91 | "path_len": 0.4, 92 | "sigma": 0.1, 93 | "anneal_theta": true 94 | } 95 | }, 96 | { 97 | "type": "hmc", 98 | "params": { 99 | "epsilon": 0.03, 100 | "p": 1000, 101 | "m": 50, 102 | "l": 10, 103 | "path_len": 0.09, 104 | "sigma": 0.1, 105 | "anneal_theta": true 106 | } 107 | } 108 | ], 109 | "checkpoint_filename": "experiments/cifar10/hmc_p=10/checkpoints/checkpoint_99.pth" 110 | } 111 | -------------------------------------------------------------------------------- /configs/eval/cifar10/hmc_p=100.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "hmc_p=100", 3 | "seed": 1, 4 | "training": null, 5 | "data": { 6 | "dataset": "cifar10", 7 | "num_workers": 2, 8 | "root": "../data/cifar10", 9 | "training": null, 10 | "test": { 11 | "batch_size": 128 12 | }, 13 | "use_half": false 14 | }, 15 | "evaluations": [ 16 | { 17 | "type": "standard", 18 | "params": {} 19 | }, 20 | { 21 | "type": "adversarial", 22 | "params": { 23 | "epsilon": 0.03, 24 | "m": 50, 25 | "restarts": 10 26 | } 27 | }, 28 | { 29 | "type": "random_sampling", 30 | "params": { 31 | "epsilon": 0.03, 32 | "p": 1, 33 | "m": 500 34 | } 35 | }, 36 | { 37 | "type": "random_sampling", 38 | "params": { 39 | "epsilon": 0.03, 40 | "p": 10, 41 | "m": 500 42 | } 43 | }, 44 | { 45 | "type": "random_sampling", 46 | "params": { 47 | 
"epsilon": 0.03, 48 | "p": 100, 49 | "m": 500 50 | } 51 | }, 52 | { 53 | "type": "random_sampling", 54 | "params": { 55 | "epsilon": 0.03, 56 | "p": 1000, 57 | "m": 500 58 | } 59 | }, 60 | { 61 | "type": "hmc", 62 | "params": { 63 | "epsilon": 0.03, 64 | "p": 1, 65 | "m": 100, 66 | "l": 10, 67 | "path_len": 0.6, 68 | "sigma": 0.1, 69 | "anneal_theta": true 70 | } 71 | }, 72 | { 73 | "type": "hmc", 74 | "params": { 75 | "epsilon": 0.03, 76 | "p": 10, 77 | "m": 50, 78 | "l": 10, 79 | "path_len": 0.6, 80 | "sigma": 0.1, 81 | "anneal_theta": true 82 | } 83 | }, 84 | { 85 | "type": "hmc", 86 | "params": { 87 | "epsilon": 0.03, 88 | "p": 100, 89 | "m": 50, 90 | "l": 10, 91 | "path_len": 0.4, 92 | "sigma": 0.1, 93 | "anneal_theta": true 94 | } 95 | }, 96 | { 97 | "type": "hmc", 98 | "params": { 99 | "epsilon": 0.03, 100 | "p": 1000, 101 | "m": 50, 102 | "l": 10, 103 | "path_len": 0.09, 104 | "sigma": 0.1, 105 | "anneal_theta": true 106 | } 107 | } 108 | ], 109 | "checkpoint_filename": "experiments/cifar10/hmc_p=100/checkpoints/checkpoint_99.pth" 110 | } 111 | -------------------------------------------------------------------------------- /configs/eval/cifar10/hmc_p=1000.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "hmc_p=1000", 3 | "seed": 1, 4 | "training": null, 5 | "data": { 6 | "dataset": "cifar10", 7 | "num_workers": 2, 8 | "root": "../data/cifar10", 9 | "training": null, 10 | "test": { 11 | "batch_size": 128 12 | }, 13 | "use_half": false 14 | }, 15 | "evaluations": [ 16 | { 17 | "type": "standard", 18 | "params": {} 19 | }, 20 | { 21 | "type": "adversarial", 22 | "params": { 23 | "epsilon": 0.03, 24 | "m": 50, 25 | "restarts": 10 26 | } 27 | }, 28 | { 29 | "type": "random_sampling", 30 | "params": { 31 | "epsilon": 0.03, 32 | "p": 1, 33 | "m": 500 34 | } 35 | }, 36 | { 37 | "type": "random_sampling", 38 | "params": { 39 | "epsilon": 0.03, 40 | "p": 10, 41 | "m": 500 42 | } 43 | }, 44 | { 45 | "type": "random_sampling", 46 | "params": { 47 | "epsilon": 0.03, 48 | "p": 100, 49 | "m": 500 50 | } 51 | }, 52 | { 53 | "type": "random_sampling", 54 | "params": { 55 | "epsilon": 0.03, 56 | "p": 1000, 57 | "m": 500 58 | } 59 | }, 60 | { 61 | "type": "hmc", 62 | "params": { 63 | "epsilon": 0.03, 64 | "p": 1, 65 | "m": 100, 66 | "l": 10, 67 | "path_len": 0.6, 68 | "sigma": 0.1, 69 | "anneal_theta": true 70 | } 71 | }, 72 | { 73 | "type": "hmc", 74 | "params": { 75 | "epsilon": 0.03, 76 | "p": 10, 77 | "m": 50, 78 | "l": 10, 79 | "path_len": 0.6, 80 | "sigma": 0.1, 81 | "anneal_theta": true 82 | } 83 | }, 84 | { 85 | "type": "hmc", 86 | "params": { 87 | "epsilon": 0.03, 88 | "p": 100, 89 | "m": 50, 90 | "l": 10, 91 | "path_len": 0.4, 92 | "sigma": 0.1, 93 | "anneal_theta": true 94 | } 95 | }, 96 | { 97 | "type": "hmc", 98 | "params": { 99 | "epsilon": 0.03, 100 | "p": 1000, 101 | "m": 50, 102 | "l": 10, 103 | "path_len": 0.09, 104 | "sigma": 0.1, 105 | "anneal_theta": true 106 | } 107 | } 108 | ], 109 | "checkpoint_filename": "experiments/cifar10/hmc_p=1000/checkpoints/checkpoint_99.pth" 110 | } 111 | -------------------------------------------------------------------------------- /configs/eval/cifar10/random_sampling_p=1.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "random_sampling_p=1", 3 | "seed": 1, 4 | "training": null, 5 | "data": { 6 | "dataset": "cifar10", 7 | "num_workers": 2, 8 | "root": "../data/cifar10", 9 | "training": null, 10 | "test": { 11 | "batch_size": 
128 12 | }, 13 | "use_half": false 14 | }, 15 | "evaluations": [ 16 | { 17 | "type": "standard", 18 | "params": {} 19 | }, 20 | { 21 | "type": "adversarial", 22 | "params": { 23 | "epsilon": 0.03, 24 | "m": 50, 25 | "restarts": 10 26 | } 27 | }, 28 | { 29 | "type": "random_sampling", 30 | "params": { 31 | "epsilon": 0.03, 32 | "p": 1, 33 | "m": 500 34 | } 35 | }, 36 | { 37 | "type": "random_sampling", 38 | "params": { 39 | "epsilon": 0.03, 40 | "p": 10, 41 | "m": 500 42 | } 43 | }, 44 | { 45 | "type": "random_sampling", 46 | "params": { 47 | "epsilon": 0.03, 48 | "p": 100, 49 | "m": 500 50 | } 51 | }, 52 | { 53 | "type": "random_sampling", 54 | "params": { 55 | "epsilon": 0.03, 56 | "p": 1000, 57 | "m": 500 58 | } 59 | }, 60 | { 61 | "type": "hmc", 62 | "params": { 63 | "epsilon": 0.03, 64 | "p": 1, 65 | "m": 100, 66 | "l": 10, 67 | "path_len": 0.6, 68 | "sigma": 0.1, 69 | "anneal_theta": true 70 | } 71 | }, 72 | { 73 | "type": "hmc", 74 | "params": { 75 | "epsilon": 0.03, 76 | "p": 10, 77 | "m": 50, 78 | "l": 10, 79 | "path_len": 0.6, 80 | "sigma": 0.1, 81 | "anneal_theta": true 82 | } 83 | }, 84 | { 85 | "type": "hmc", 86 | "params": { 87 | "epsilon": 0.03, 88 | "p": 100, 89 | "m": 50, 90 | "l": 10, 91 | "path_len": 0.4, 92 | "sigma": 0.1, 93 | "anneal_theta": true 94 | } 95 | }, 96 | { 97 | "type": "hmc", 98 | "params": { 99 | "epsilon": 0.03, 100 | "p": 1000, 101 | "m": 50, 102 | "l": 10, 103 | "path_len": 0.09, 104 | "sigma": 0.1, 105 | "anneal_theta": true 106 | } 107 | } 108 | ], 109 | "checkpoint_filename": "experiments/cifar10/random_sampling_p=1_seed=2/checkpoints/checkpoint_199.pth" 110 | } 111 | -------------------------------------------------------------------------------- /configs/eval/cifar10/random_sampling_p=10.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "random_sampling_p=10", 3 | "seed": 1, 4 | "training": null, 5 | "data": { 6 | "dataset": "cifar10", 7 | "num_workers": 2, 8 | "root": "../data/cifar10", 9 | "training": null, 10 | "test": { 11 | "batch_size": 128 12 | }, 13 | "use_half": false 14 | }, 15 | "evaluations": [ 16 | { 17 | "type": "standard", 18 | "params": {} 19 | }, 20 | { 21 | "type": "adversarial", 22 | "params": { 23 | "epsilon": 0.03, 24 | "m": 50, 25 | "restarts": 10 26 | } 27 | }, 28 | { 29 | "type": "random_sampling", 30 | "params": { 31 | "epsilon": 0.03, 32 | "p": 1, 33 | "m": 500 34 | } 35 | }, 36 | { 37 | "type": "random_sampling", 38 | "params": { 39 | "epsilon": 0.03, 40 | "p": 10, 41 | "m": 500 42 | } 43 | }, 44 | { 45 | "type": "random_sampling", 46 | "params": { 47 | "epsilon": 0.03, 48 | "p": 100, 49 | "m": 500 50 | } 51 | }, 52 | { 53 | "type": "random_sampling", 54 | "params": { 55 | "epsilon": 0.03, 56 | "p": 1000, 57 | "m": 500 58 | } 59 | }, 60 | { 61 | "type": "hmc", 62 | "params": { 63 | "epsilon": 0.03, 64 | "p": 1, 65 | "m": 100, 66 | "l": 10, 67 | "path_len": 0.6, 68 | "sigma": 0.1, 69 | "anneal_theta": true 70 | } 71 | }, 72 | { 73 | "type": "hmc", 74 | "params": { 75 | "epsilon": 0.03, 76 | "p": 10, 77 | "m": 50, 78 | "l": 10, 79 | "path_len": 0.6, 80 | "sigma": 0.1, 81 | "anneal_theta": true 82 | } 83 | }, 84 | { 85 | "type": "hmc", 86 | "params": { 87 | "epsilon": 0.03, 88 | "p": 100, 89 | "m": 50, 90 | "l": 10, 91 | "path_len": 0.4, 92 | "sigma": 0.1, 93 | "anneal_theta": true 94 | } 95 | }, 96 | { 97 | "type": "hmc", 98 | "params": { 99 | "epsilon": 0.03, 100 | "p": 1000, 101 | "m": 50, 102 | "l": 10, 103 | "path_len": 0.09, 104 | "sigma": 0.1, 105 | 
"anneal_theta": true 106 | } 107 | } 108 | ], 109 | "checkpoint_filename": "experiments/cifar10/random_sampling_p=10_seed=3/checkpoints/checkpoint_199.pth" 110 | } 111 | -------------------------------------------------------------------------------- /configs/eval/cifar10/random_sampling_p=100.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "random_sampling_p=100", 3 | "seed": 1, 4 | "training": null, 5 | "data": { 6 | "dataset": "cifar10", 7 | "num_workers": 2, 8 | "root": "../data/cifar10", 9 | "training": null, 10 | "test": { 11 | "batch_size": 128 12 | }, 13 | "use_half": false 14 | }, 15 | "evaluations": [ 16 | { 17 | "type": "standard", 18 | "params": {} 19 | }, 20 | { 21 | "type": "adversarial", 22 | "params": { 23 | "epsilon": 0.03, 24 | "m": 50, 25 | "restarts": 10 26 | } 27 | }, 28 | { 29 | "type": "random_sampling", 30 | "params": { 31 | "epsilon": 0.03, 32 | "p": 1, 33 | "m": 500 34 | } 35 | }, 36 | { 37 | "type": "random_sampling", 38 | "params": { 39 | "epsilon": 0.03, 40 | "p": 10, 41 | "m": 500 42 | } 43 | }, 44 | { 45 | "type": "random_sampling", 46 | "params": { 47 | "epsilon": 0.03, 48 | "p": 100, 49 | "m": 500 50 | } 51 | }, 52 | { 53 | "type": "random_sampling", 54 | "params": { 55 | "epsilon": 0.03, 56 | "p": 1000, 57 | "m": 500 58 | } 59 | }, 60 | { 61 | "type": "hmc", 62 | "params": { 63 | "epsilon": 0.03, 64 | "p": 1, 65 | "m": 100, 66 | "l": 10, 67 | "path_len": 0.6, 68 | "sigma": 0.1, 69 | "anneal_theta": true 70 | } 71 | }, 72 | { 73 | "type": "hmc", 74 | "params": { 75 | "epsilon": 0.03, 76 | "p": 10, 77 | "m": 50, 78 | "l": 10, 79 | "path_len": 0.6, 80 | "sigma": 0.1, 81 | "anneal_theta": true 82 | } 83 | }, 84 | { 85 | "type": "hmc", 86 | "params": { 87 | "epsilon": 0.03, 88 | "p": 100, 89 | "m": 50, 90 | "l": 10, 91 | "path_len": 0.4, 92 | "sigma": 0.1, 93 | "anneal_theta": true 94 | } 95 | }, 96 | { 97 | "type": "hmc", 98 | "params": { 99 | "epsilon": 0.03, 100 | "p": 1000, 101 | "m": 50, 102 | "l": 10, 103 | "path_len": 0.09, 104 | "sigma": 0.1, 105 | "anneal_theta": true 106 | } 107 | } 108 | ], 109 | "checkpoint_filename": "experiments/cifar10/random_sampling_p=100_seed=2/checkpoints/checkpoint_199.pth" 110 | } 111 | -------------------------------------------------------------------------------- /configs/eval/cifar10/random_sampling_p=1000.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "random_sampling_p=1000", 3 | "seed": 1, 4 | "training": null, 5 | "data": { 6 | "dataset": "cifar10", 7 | "num_workers": 2, 8 | "root": "../data/cifar10", 9 | "training": null, 10 | "test": { 11 | "batch_size": 128 12 | }, 13 | "use_half": false 14 | }, 15 | "evaluations": [ 16 | { 17 | "type": "standard", 18 | "params": {} 19 | }, 20 | { 21 | "type": "adversarial", 22 | "params": { 23 | "epsilon": 0.03, 24 | "m": 50, 25 | "restarts": 10 26 | } 27 | }, 28 | { 29 | "type": "random_sampling", 30 | "params": { 31 | "epsilon": 0.03, 32 | "p": 1, 33 | "m": 500 34 | } 35 | }, 36 | { 37 | "type": "random_sampling", 38 | "params": { 39 | "epsilon": 0.03, 40 | "p": 10, 41 | "m": 500 42 | } 43 | }, 44 | { 45 | "type": "random_sampling", 46 | "params": { 47 | "epsilon": 0.03, 48 | "p": 100, 49 | "m": 500 50 | } 51 | }, 52 | { 53 | "type": "random_sampling", 54 | "params": { 55 | "epsilon": 0.03, 56 | "p": 1000, 57 | "m": 500 58 | } 59 | }, 60 | { 61 | "type": "hmc", 62 | "params": { 63 | "epsilon": 0.03, 64 | "p": 1, 65 | "m": 100, 66 | "l": 10, 67 | 
"path_len": 0.6, 68 | "sigma": 0.1, 69 | "anneal_theta": true 70 | } 71 | }, 72 | { 73 | "type": "hmc", 74 | "params": { 75 | "epsilon": 0.03, 76 | "p": 10, 77 | "m": 50, 78 | "l": 10, 79 | "path_len": 0.6, 80 | "sigma": 0.1, 81 | "anneal_theta": true 82 | } 83 | }, 84 | { 85 | "type": "hmc", 86 | "params": { 87 | "epsilon": 0.03, 88 | "p": 100, 89 | "m": 50, 90 | "l": 10, 91 | "path_len": 0.4, 92 | "sigma": 0.1, 93 | "anneal_theta": true 94 | } 95 | }, 96 | { 97 | "type": "hmc", 98 | "params": { 99 | "epsilon": 0.03, 100 | "p": 1000, 101 | "m": 50, 102 | "l": 10, 103 | "path_len": 0.09, 104 | "sigma": 0.1, 105 | "anneal_theta": true 106 | } 107 | } 108 | ], 109 | "checkpoint_filename": "experiments/cifar10/random_sampling_p=1000_seed=3/checkpoints/checkpoint_199.pth" 110 | } 111 | -------------------------------------------------------------------------------- /configs/eval/cifar10/standard.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "standard", 3 | "seed": 3, 4 | "training": null, 5 | "data": { 6 | "dataset": "cifar10", 7 | "num_workers": 2, 8 | "root": "../data/cifar10", 9 | "training": null, 10 | "test": { 11 | "batch_size": 128 12 | }, 13 | "use_half": false 14 | }, 15 | "evaluations": [ 16 | { 17 | "type": "standard", 18 | "params": {} 19 | }, 20 | { 21 | "type": "adversarial", 22 | "params": { 23 | "epsilon": 0.03, 24 | "m": 50, 25 | "restarts": 10 26 | } 27 | }, 28 | { 29 | "type": "random_sampling", 30 | "params": { 31 | "epsilon": 0.03, 32 | "p": 1, 33 | "m": 500 34 | } 35 | }, 36 | { 37 | "type": "random_sampling", 38 | "params": { 39 | "epsilon": 0.03, 40 | "p": 10, 41 | "m": 500 42 | } 43 | }, 44 | { 45 | "type": "random_sampling", 46 | "params": { 47 | "epsilon": 0.03, 48 | "p": 100, 49 | "m": 500 50 | } 51 | }, 52 | { 53 | "type": "random_sampling", 54 | "params": { 55 | "epsilon": 0.03, 56 | "p": 1000, 57 | "m": 500 58 | } 59 | }, 60 | { 61 | "type": "hmc", 62 | "params": { 63 | "epsilon": 0.03, 64 | "p": 1, 65 | "m": 100, 66 | "l": 10, 67 | "path_len": 0.6, 68 | "sigma": 0.1, 69 | "anneal_theta": true 70 | } 71 | }, 72 | { 73 | "type": "hmc", 74 | "params": { 75 | "epsilon": 0.03, 76 | "p": 10, 77 | "m": 50, 78 | "l": 10, 79 | "path_len": 0.6, 80 | "sigma": 0.1, 81 | "anneal_theta": true 82 | } 83 | }, 84 | { 85 | "type": "hmc", 86 | "params": { 87 | "epsilon": 0.03, 88 | "p": 100, 89 | "m": 50, 90 | "l": 10, 91 | "path_len": 0.4, 92 | "sigma": 0.1, 93 | "anneal_theta": true 94 | } 95 | }, 96 | { 97 | "type": "hmc", 98 | "params": { 99 | "epsilon": 0.03, 100 | "p": 1000, 101 | "m": 50, 102 | "l": 10, 103 | "path_len": 0.09, 104 | "sigma": 0.1, 105 | "anneal_theta": true 106 | } 107 | } 108 | ], 109 | "checkpoint_filename": "experiments/cifar10/standard_seed=3/checkpoints/checkpoint_199.pth" 110 | } 111 | -------------------------------------------------------------------------------- /configs/eval/mnist/adversarial.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "adversarial", 3 | "seed": 1, 4 | "training": null, 5 | "data": { 6 | "dataset": "mnist", 7 | "num_workers": 1, 8 | "root": "../data/mnist", 9 | "training": null, 10 | "test": { 11 | "batch_size": 128 12 | }, 13 | "use_half": false 14 | }, 15 | "evaluations": [ 16 | { 17 | "type": "standard", 18 | "params": {} 19 | }, 20 | { 21 | "type": "adversarial", 22 | "params": { 23 | "epsilon": 0.3, 24 | "m": 100 25 | } 26 | }, 27 | { 28 | "type": "random_sampling", 29 | "params": { 30 | "m": 
2000, 31 | "p": 1, 32 | "epsilon": 0.3 33 | } 34 | }, 35 | { 36 | "type": "random_sampling", 37 | "params": { 38 | "m": 2000, 39 | "p": 10, 40 | "epsilon": 0.3 41 | } 42 | }, 43 | { 44 | "type": "random_sampling", 45 | "params": { 46 | "m": 2000, 47 | "p": 100, 48 | "epsilon": 0.3 49 | } 50 | }, 51 | { 52 | "type": "random_sampling", 53 | "params": { 54 | "m": 2000, 55 | "p": 1000, 56 | "epsilon": 0.3 57 | } 58 | }, 59 | { 60 | "type": "hmc", 61 | "params": { 62 | "epsilon": 0.3, 63 | "p": 1, 64 | "m": 100, 65 | "l": 20, 66 | "path_len": 0.6, 67 | "sigma": 0.1 68 | } 69 | }, 70 | { 71 | "type": "hmc", 72 | "params": { 73 | "epsilon": 0.3, 74 | "p": 10, 75 | "m": 100, 76 | "l": 20, 77 | "path_len": 0.6, 78 | "sigma": 0.1 79 | } 80 | }, 81 | { 82 | "type": "hmc", 83 | "params": { 84 | "epsilon": 0.3, 85 | "p": 100, 86 | "m": 100, 87 | "l": 20, 88 | "path_len": 0.4, 89 | "sigma": 0.1 90 | } 91 | }, 92 | { 93 | "type": "hmc", 94 | "params": { 95 | "epsilon": 0.3, 96 | "p": 1000, 97 | "m": 100, 98 | "l": 20, 99 | "path_len": 0.2, 100 | "sigma": 0.1 101 | } 102 | } 103 | ], 104 | "checkpoint_filename": "experiments/mnist/adversarial_seed=3/checkpoints/checkpoint_9.pth" 105 | } 106 | -------------------------------------------------------------------------------- /configs/eval/mnist/hmc_p=1.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "hmc_p=1", 3 | "seed": 1, 4 | "training": null, 5 | "data": { 6 | "dataset": "mnist", 7 | "num_workers": 1, 8 | "root": "../data/mnist", 9 | "training": null, 10 | "test": { 11 | "batch_size": 128 12 | }, 13 | "use_half": false 14 | }, 15 | "evaluations": [ 16 | { 17 | "type": "standard", 18 | "params": {} 19 | }, 20 | { 21 | "type": "adversarial", 22 | "params": { 23 | "epsilon": 0.3, 24 | "m": 100 25 | } 26 | }, 27 | { 28 | "type": "random_sampling", 29 | "params": { 30 | "m": 2000, 31 | "p": 1, 32 | "epsilon": 0.3 33 | } 34 | }, 35 | { 36 | "type": "random_sampling", 37 | "params": { 38 | "m": 2000, 39 | "p": 10, 40 | "epsilon": 0.3 41 | } 42 | }, 43 | { 44 | "type": "random_sampling", 45 | "params": { 46 | "m": 2000, 47 | "p": 100, 48 | "epsilon": 0.3 49 | } 50 | }, 51 | { 52 | "type": "random_sampling", 53 | "params": { 54 | "m": 2000, 55 | "p": 1000, 56 | "epsilon": 0.3 57 | } 58 | }, 59 | { 60 | "type": "hmc", 61 | "params": { 62 | "epsilon": 0.3, 63 | "p": 1, 64 | "m": 100, 65 | "l": 20, 66 | "path_len": 0.6, 67 | "sigma": 0.1 68 | } 69 | }, 70 | { 71 | "type": "hmc", 72 | "params": { 73 | "epsilon": 0.3, 74 | "p": 10, 75 | "m": 100, 76 | "l": 20, 77 | "path_len": 0.6, 78 | "sigma": 0.1 79 | } 80 | }, 81 | { 82 | "type": "hmc", 83 | "params": { 84 | "epsilon": 0.3, 85 | "p": 100, 86 | "m": 100, 87 | "l": 20, 88 | "path_len": 0.4, 89 | "sigma": 0.1 90 | } 91 | }, 92 | { 93 | "type": "hmc", 94 | "params": { 95 | "epsilon": 0.3, 96 | "p": 1000, 97 | "m": 100, 98 | "l": 20, 99 | "path_len": 0.2, 100 | "sigma": 0.1 101 | } 102 | } 103 | ], 104 | "checkpoint_filename": "experiments/mnist/hmc_p=1_seed=2/checkpoints/checkpoint_9.pth" 105 | } 106 | -------------------------------------------------------------------------------- /configs/eval/mnist/hmc_p=10.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "hmc_p=10", 3 | "seed": 1, 4 | "training": null, 5 | "data": { 6 | "dataset": "mnist", 7 | "num_workers": 1, 8 | "root": "../data/mnist", 9 | "training": null, 10 | "test": { 11 | "batch_size": 128 12 | }, 13 | "use_half": false 14 | }, 15 | 
"evaluations": [ 16 | { 17 | "type": "standard", 18 | "params": {} 19 | }, 20 | { 21 | "type": "adversarial", 22 | "params": { 23 | "epsilon": 0.3, 24 | "m": 100 25 | } 26 | }, 27 | { 28 | "type": "random_sampling", 29 | "params": { 30 | "m": 2000, 31 | "p": 1, 32 | "epsilon": 0.3 33 | } 34 | }, 35 | { 36 | "type": "random_sampling", 37 | "params": { 38 | "m": 2000, 39 | "p": 10, 40 | "epsilon": 0.3 41 | } 42 | }, 43 | { 44 | "type": "random_sampling", 45 | "params": { 46 | "m": 2000, 47 | "p": 100, 48 | "epsilon": 0.3 49 | } 50 | }, 51 | { 52 | "type": "random_sampling", 53 | "params": { 54 | "m": 2000, 55 | "p": 1000, 56 | "epsilon": 0.3 57 | } 58 | }, 59 | { 60 | "type": "hmc", 61 | "params": { 62 | "epsilon": 0.3, 63 | "p": 1, 64 | "m": 100, 65 | "l": 20, 66 | "path_len": 0.6, 67 | "sigma": 0.1 68 | } 69 | }, 70 | { 71 | "type": "hmc", 72 | "params": { 73 | "epsilon": 0.3, 74 | "p": 10, 75 | "m": 100, 76 | "l": 20, 77 | "path_len": 0.6, 78 | "sigma": 0.1 79 | } 80 | }, 81 | { 82 | "type": "hmc", 83 | "params": { 84 | "epsilon": 0.3, 85 | "p": 100, 86 | "m": 100, 87 | "l": 20, 88 | "path_len": 0.4, 89 | "sigma": 0.1 90 | } 91 | }, 92 | { 93 | "type": "hmc", 94 | "params": { 95 | "epsilon": 0.3, 96 | "p": 1000, 97 | "m": 100, 98 | "l": 20, 99 | "path_len": 0.2, 100 | "sigma": 0.1 101 | } 102 | } 103 | ], 104 | "checkpoint_filename": "experiments/mnist/hmc_p=10_seed=2/checkpoints/checkpoint_9.pth" 105 | } 106 | -------------------------------------------------------------------------------- /configs/eval/mnist/hmc_p=100.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "hmc_p=100", 3 | "seed": 1, 4 | "training": null, 5 | "data": { 6 | "dataset": "mnist", 7 | "num_workers": 1, 8 | "root": "../data/mnist", 9 | "training": null, 10 | "test": { 11 | "batch_size": 128 12 | }, 13 | "use_half": false 14 | }, 15 | "evaluations": [ 16 | { 17 | "type": "standard", 18 | "params": {} 19 | }, 20 | { 21 | "type": "adversarial", 22 | "params": { 23 | "epsilon": 0.3, 24 | "m": 100 25 | } 26 | }, 27 | { 28 | "type": "random_sampling", 29 | "params": { 30 | "m": 2000, 31 | "p": 1, 32 | "epsilon": 0.3 33 | } 34 | }, 35 | { 36 | "type": "random_sampling", 37 | "params": { 38 | "m": 2000, 39 | "p": 10, 40 | "epsilon": 0.3 41 | } 42 | }, 43 | { 44 | "type": "random_sampling", 45 | "params": { 46 | "m": 2000, 47 | "p": 100, 48 | "epsilon": 0.3 49 | } 50 | }, 51 | { 52 | "type": "random_sampling", 53 | "params": { 54 | "m": 2000, 55 | "p": 1000, 56 | "epsilon": 0.3 57 | } 58 | }, 59 | { 60 | "type": "hmc", 61 | "params": { 62 | "epsilon": 0.3, 63 | "p": 1, 64 | "m": 100, 65 | "l": 20, 66 | "path_len": 0.6, 67 | "sigma": 0.1 68 | } 69 | }, 70 | { 71 | "type": "hmc", 72 | "params": { 73 | "epsilon": 0.3, 74 | "p": 10, 75 | "m": 100, 76 | "l": 20, 77 | "path_len": 0.6, 78 | "sigma": 0.1 79 | } 80 | }, 81 | { 82 | "type": "hmc", 83 | "params": { 84 | "epsilon": 0.3, 85 | "p": 100, 86 | "m": 100, 87 | "l": 20, 88 | "path_len": 0.4, 89 | "sigma": 0.1 90 | } 91 | }, 92 | { 93 | "type": "hmc", 94 | "params": { 95 | "epsilon": 0.3, 96 | "p": 1000, 97 | "m": 100, 98 | "l": 20, 99 | "path_len": 0.2, 100 | "sigma": 0.1 101 | } 102 | } 103 | ], 104 | "checkpoint_filename": "experiments/mnist/hmc_p=100_seed=2/checkpoints/checkpoint_9.pth" 105 | } 106 | -------------------------------------------------------------------------------- /configs/eval/mnist/hmc_p=1000.json: -------------------------------------------------------------------------------- 1 | { 2 | 
"output_dir": "hmc_p=1000", 3 | "seed": 1, 4 | "training": null, 5 | "data": { 6 | "dataset": "mnist", 7 | "num_workers": 1, 8 | "root": "../data/mnist", 9 | "training": null, 10 | "test": { 11 | "batch_size": 128 12 | }, 13 | "use_half": false 14 | }, 15 | "evaluations": [ 16 | { 17 | "type": "standard", 18 | "params": {} 19 | }, 20 | { 21 | "type": "adversarial", 22 | "params": { 23 | "epsilon": 0.3, 24 | "m": 100 25 | } 26 | }, 27 | { 28 | "type": "random_sampling", 29 | "params": { 30 | "m": 2000, 31 | "p": 1, 32 | "epsilon": 0.3 33 | } 34 | }, 35 | { 36 | "type": "random_sampling", 37 | "params": { 38 | "m": 2000, 39 | "p": 10, 40 | "epsilon": 0.3 41 | } 42 | }, 43 | { 44 | "type": "random_sampling", 45 | "params": { 46 | "m": 2000, 47 | "p": 100, 48 | "epsilon": 0.3 49 | } 50 | }, 51 | { 52 | "type": "random_sampling", 53 | "params": { 54 | "m": 2000, 55 | "p": 1000, 56 | "epsilon": 0.3 57 | } 58 | }, 59 | { 60 | "type": "hmc", 61 | "params": { 62 | "epsilon": 0.3, 63 | "p": 1, 64 | "m": 100, 65 | "l": 20, 66 | "path_len": 0.6, 67 | "sigma": 0.1 68 | } 69 | }, 70 | { 71 | "type": "hmc", 72 | "params": { 73 | "epsilon": 0.3, 74 | "p": 10, 75 | "m": 100, 76 | "l": 20, 77 | "path_len": 0.6, 78 | "sigma": 0.1 79 | } 80 | }, 81 | { 82 | "type": "hmc", 83 | "params": { 84 | "epsilon": 0.3, 85 | "p": 100, 86 | "m": 100, 87 | "l": 20, 88 | "path_len": 0.4, 89 | "sigma": 0.1 90 | } 91 | }, 92 | { 93 | "type": "hmc", 94 | "params": { 95 | "epsilon": 0.3, 96 | "p": 1000, 97 | "m": 100, 98 | "l": 20, 99 | "path_len": 0.2, 100 | "sigma": 0.1 101 | } 102 | } 103 | ], 104 | "checkpoint_filename": "experiments/mnist/hmc_p=1000_seed=2/checkpoints/checkpoint_9.pth" 105 | } 106 | -------------------------------------------------------------------------------- /configs/eval/mnist/random_sampling_p=1.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "random_sampling_p=1", 3 | "seed": 1, 4 | "training": null, 5 | "data": { 6 | "dataset": "mnist", 7 | "num_workers": 1, 8 | "root": "../data/mnist", 9 | "training": null, 10 | "test": { 11 | "batch_size": 128 12 | }, 13 | "use_half": false 14 | }, 15 | "evaluations": [ 16 | { 17 | "type": "standard", 18 | "params": {} 19 | }, 20 | { 21 | "type": "adversarial", 22 | "params": { 23 | "epsilon": 0.3, 24 | "m": 100 25 | } 26 | }, 27 | { 28 | "type": "random_sampling", 29 | "params": { 30 | "m": 2000, 31 | "p": 1, 32 | "epsilon": 0.3 33 | } 34 | }, 35 | { 36 | "type": "random_sampling", 37 | "params": { 38 | "m": 2000, 39 | "p": 10, 40 | "epsilon": 0.3 41 | } 42 | }, 43 | { 44 | "type": "random_sampling", 45 | "params": { 46 | "m": 2000, 47 | "p": 100, 48 | "epsilon": 0.3 49 | } 50 | }, 51 | { 52 | "type": "random_sampling", 53 | "params": { 54 | "m": 2000, 55 | "p": 1000, 56 | "epsilon": 0.3 57 | } 58 | }, 59 | { 60 | "type": "hmc", 61 | "params": { 62 | "epsilon": 0.3, 63 | "p": 1, 64 | "m": 100, 65 | "l": 20, 66 | "path_len": 0.6, 67 | "sigma": 0.1 68 | } 69 | }, 70 | { 71 | "type": "hmc", 72 | "params": { 73 | "epsilon": 0.3, 74 | "p": 10, 75 | "m": 100, 76 | "l": 20, 77 | "path_len": 0.6, 78 | "sigma": 0.1 79 | } 80 | }, 81 | { 82 | "type": "hmc", 83 | "params": { 84 | "epsilon": 0.3, 85 | "p": 100, 86 | "m": 100, 87 | "l": 20, 88 | "path_len": 0.4, 89 | "sigma": 0.1 90 | } 91 | }, 92 | { 93 | "type": "hmc", 94 | "params": { 95 | "epsilon": 0.3, 96 | "p": 1000, 97 | "m": 100, 98 | "l": 20, 99 | "path_len": 0.2, 100 | "sigma": 0.1 101 | } 102 | } 103 | ], 104 | "checkpoint_filename": 
"experiments/mnist/random_sampling_p=1_seed=2/checkpoints/checkpoint_9.pth" 105 | } 106 | -------------------------------------------------------------------------------- /configs/eval/mnist/random_sampling_p=10.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "random_sampling_p=10", 3 | "seed": 1, 4 | "training": null, 5 | "data": { 6 | "dataset": "mnist", 7 | "num_workers": 1, 8 | "root": "../data/mnist", 9 | "training": null, 10 | "test": { 11 | "batch_size": 128 12 | }, 13 | "use_half": false 14 | }, 15 | "evaluations": [ 16 | { 17 | "type": "standard", 18 | "params": {} 19 | }, 20 | { 21 | "type": "adversarial", 22 | "params": { 23 | "epsilon": 0.3, 24 | "m": 100 25 | } 26 | }, 27 | { 28 | "type": "random_sampling", 29 | "params": { 30 | "m": 2000, 31 | "p": 1, 32 | "epsilon": 0.3 33 | } 34 | }, 35 | { 36 | "type": "random_sampling", 37 | "params": { 38 | "m": 2000, 39 | "p": 10, 40 | "epsilon": 0.3 41 | } 42 | }, 43 | { 44 | "type": "random_sampling", 45 | "params": { 46 | "m": 2000, 47 | "p": 100, 48 | "epsilon": 0.3 49 | } 50 | }, 51 | { 52 | "type": "random_sampling", 53 | "params": { 54 | "m": 2000, 55 | "p": 1000, 56 | "epsilon": 0.3 57 | } 58 | }, 59 | { 60 | "type": "hmc", 61 | "params": { 62 | "epsilon": 0.3, 63 | "p": 1, 64 | "m": 100, 65 | "l": 20, 66 | "path_len": 0.6, 67 | "sigma": 0.1 68 | } 69 | }, 70 | { 71 | "type": "hmc", 72 | "params": { 73 | "epsilon": 0.3, 74 | "p": 10, 75 | "m": 100, 76 | "l": 20, 77 | "path_len": 0.6, 78 | "sigma": 0.1 79 | } 80 | }, 81 | { 82 | "type": "hmc", 83 | "params": { 84 | "epsilon": 0.3, 85 | "p": 100, 86 | "m": 100, 87 | "l": 20, 88 | "path_len": 0.4, 89 | "sigma": 0.1 90 | } 91 | }, 92 | { 93 | "type": "hmc", 94 | "params": { 95 | "epsilon": 0.3, 96 | "p": 1000, 97 | "m": 100, 98 | "l": 20, 99 | "path_len": 0.2, 100 | "sigma": 0.1 101 | } 102 | } 103 | ], 104 | "checkpoint_filename": "experiments/mnist/random_sampling_p=10_seed=3/checkpoints/checkpoint_9.pth" 105 | } 106 | -------------------------------------------------------------------------------- /configs/eval/mnist/random_sampling_p=100.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "random_sampling_p=100", 3 | "seed": 1, 4 | "training": null, 5 | "data": { 6 | "dataset": "mnist", 7 | "num_workers": 1, 8 | "root": "../data/mnist", 9 | "training": null, 10 | "test": { 11 | "batch_size": 128 12 | }, 13 | "use_half": false 14 | }, 15 | "evaluations": [ 16 | { 17 | "type": "standard", 18 | "params": {} 19 | }, 20 | { 21 | "type": "adversarial", 22 | "params": { 23 | "epsilon": 0.3, 24 | "m": 100 25 | } 26 | }, 27 | { 28 | "type": "random_sampling", 29 | "params": { 30 | "m": 2000, 31 | "p": 1, 32 | "epsilon": 0.3 33 | } 34 | }, 35 | { 36 | "type": "random_sampling", 37 | "params": { 38 | "m": 2000, 39 | "p": 10, 40 | "epsilon": 0.3 41 | } 42 | }, 43 | { 44 | "type": "random_sampling", 45 | "params": { 46 | "m": 2000, 47 | "p": 100, 48 | "epsilon": 0.3 49 | } 50 | }, 51 | { 52 | "type": "random_sampling", 53 | "params": { 54 | "m": 2000, 55 | "p": 1000, 56 | "epsilon": 0.3 57 | } 58 | }, 59 | { 60 | "type": "hmc", 61 | "params": { 62 | "epsilon": 0.3, 63 | "p": 1, 64 | "m": 100, 65 | "l": 20, 66 | "path_len": 0.6, 67 | "sigma": 0.1 68 | } 69 | }, 70 | { 71 | "type": "hmc", 72 | "params": { 73 | "epsilon": 0.3, 74 | "p": 10, 75 | "m": 100, 76 | "l": 20, 77 | "path_len": 0.6, 78 | "sigma": 0.1 79 | } 80 | }, 81 | { 82 | "type": "hmc", 83 | "params": { 84 | 
"epsilon": 0.3, 85 | "p": 100, 86 | "m": 100, 87 | "l": 20, 88 | "path_len": 0.4, 89 | "sigma": 0.1 90 | } 91 | }, 92 | { 93 | "type": "hmc", 94 | "params": { 95 | "epsilon": 0.3, 96 | "p": 1000, 97 | "m": 100, 98 | "l": 20, 99 | "path_len": 0.2, 100 | "sigma": 0.1 101 | } 102 | } 103 | ], 104 | "checkpoint_filename": "experiments/mnist/random_sampling_p=100_seed=3/checkpoints/checkpoint_9.pth" 105 | } 106 | -------------------------------------------------------------------------------- /configs/eval/mnist/random_sampling_p=1000.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "random_sampling_p=1000", 3 | "seed": 1, 4 | "training": null, 5 | "data": { 6 | "dataset": "mnist", 7 | "num_workers": 1, 8 | "root": "../data/mnist", 9 | "training": null, 10 | "test": { 11 | "batch_size": 128 12 | }, 13 | "use_half": false 14 | }, 15 | "evaluations": [ 16 | { 17 | "type": "standard", 18 | "params": {} 19 | }, 20 | { 21 | "type": "adversarial", 22 | "params": { 23 | "epsilon": 0.3, 24 | "m": 100 25 | } 26 | }, 27 | { 28 | "type": "random_sampling", 29 | "params": { 30 | "m": 2000, 31 | "p": 1, 32 | "epsilon": 0.3 33 | } 34 | }, 35 | { 36 | "type": "random_sampling", 37 | "params": { 38 | "m": 2000, 39 | "p": 10, 40 | "epsilon": 0.3 41 | } 42 | }, 43 | { 44 | "type": "random_sampling", 45 | "params": { 46 | "m": 2000, 47 | "p": 100, 48 | "epsilon": 0.3 49 | } 50 | }, 51 | { 52 | "type": "random_sampling", 53 | "params": { 54 | "m": 2000, 55 | "p": 1000, 56 | "epsilon": 0.3 57 | } 58 | }, 59 | { 60 | "type": "hmc", 61 | "params": { 62 | "epsilon": 0.3, 63 | "p": 1, 64 | "m": 100, 65 | "l": 20, 66 | "path_len": 0.6, 67 | "sigma": 0.1 68 | } 69 | }, 70 | { 71 | "type": "hmc", 72 | "params": { 73 | "epsilon": 0.3, 74 | "p": 10, 75 | "m": 100, 76 | "l": 20, 77 | "path_len": 0.6, 78 | "sigma": 0.1 79 | } 80 | }, 81 | { 82 | "type": "hmc", 83 | "params": { 84 | "epsilon": 0.3, 85 | "p": 100, 86 | "m": 100, 87 | "l": 20, 88 | "path_len": 0.4, 89 | "sigma": 0.1 90 | } 91 | }, 92 | { 93 | "type": "hmc", 94 | "params": { 95 | "epsilon": 0.3, 96 | "p": 1000, 97 | "m": 100, 98 | "l": 20, 99 | "path_len": 0.2, 100 | "sigma": 0.1 101 | } 102 | } 103 | ], 104 | "checkpoint_filename": "experiments/mnist/random_sampling_p=1000_seed=2/checkpoints/checkpoint_9.pth" 105 | } 106 | -------------------------------------------------------------------------------- /configs/eval/mnist/standard.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "standard", 3 | "seed": 1, 4 | "training": null, 5 | "data": { 6 | "dataset": "mnist", 7 | "num_workers": 1, 8 | "root": "../data/mnist", 9 | "training": null, 10 | "test": { 11 | "batch_size": 128 12 | }, 13 | "use_half": false 14 | }, 15 | "evaluations": [ 16 | { 17 | "type": "standard", 18 | "params": {} 19 | }, 20 | { 21 | "type": "adversarial", 22 | "params": { 23 | "epsilon": 0.3, 24 | "m": 100 25 | } 26 | }, 27 | { 28 | "type": "random_sampling", 29 | "params": { 30 | "m": 2000, 31 | "p": 1, 32 | "epsilon": 0.3 33 | } 34 | }, 35 | { 36 | "type": "random_sampling", 37 | "params": { 38 | "m": 2000, 39 | "p": 10, 40 | "epsilon": 0.3 41 | } 42 | }, 43 | { 44 | "type": "random_sampling", 45 | "params": { 46 | "m": 2000, 47 | "p": 100, 48 | "epsilon": 0.3 49 | } 50 | }, 51 | { 52 | "type": "random_sampling", 53 | "params": { 54 | "m": 2000, 55 | "p": 1000, 56 | "epsilon": 0.3 57 | } 58 | }, 59 | { 60 | "type": "hmc", 61 | "params": { 62 | "epsilon": 0.3, 63 | "p": 
1, 64 | "m": 100, 65 | "l": 20, 66 | "path_len": 0.6, 67 | "sigma": 0.1 68 | } 69 | }, 70 | { 71 | "type": "hmc", 72 | "params": { 73 | "epsilon": 0.3, 74 | "p": 10, 75 | "m": 100, 76 | "l": 20, 77 | "path_len": 0.6, 78 | "sigma": 0.1 79 | } 80 | }, 81 | { 82 | "type": "hmc", 83 | "params": { 84 | "epsilon": 0.3, 85 | "p": 100, 86 | "m": 100, 87 | "l": 20, 88 | "path_len": 0.4, 89 | "sigma": 0.1 90 | } 91 | }, 92 | { 93 | "type": "hmc", 94 | "params": { 95 | "epsilon": 0.3, 96 | "p": 1000, 97 | "m": 100, 98 | "l": 20, 99 | "path_len": 0.2, 100 | "sigma": 0.1 101 | } 102 | } 103 | ], 104 | "checkpoint_filename": "experiments/mnist/standard_10/checkpoints/checkpoint_9.pth" 105 | } 106 | -------------------------------------------------------------------------------- /configs/train/cifar10/adversarial.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "adversarial", 3 | "seed": 1, 4 | "training": { 5 | "type": "adversarial", 6 | "params": { 7 | "m": 10, 8 | "epsilon": 0.03, 9 | "alpha_scale": 2.5 10 | }, 11 | "epochs": 20, 12 | "test_interval": 1, 13 | "checkpoint_interval": 1, 14 | "opt": { 15 | "type": "sgd", 16 | "params": { 17 | "lr": 0.1, 18 | "momentum": 0.9, 19 | "weight_decay": 5e-4, 20 | "nesterov": true 21 | } 22 | }, 23 | "lr_schedule": "multi_step" 24 | }, 25 | "data": { 26 | "dataset": "cifar10", 27 | "num_workers": 2, 28 | "root": "../data/cifar10", 29 | "training": { 30 | "flip_crop": false, 31 | "batch_size": 128 32 | }, 33 | "test": { 34 | "batch_size": 128 35 | }, 36 | "use_half": false 37 | }, 38 | "evaluations": [ 39 | { 40 | "type": "standard", 41 | "params": {} 42 | }, 43 | { 44 | "type": "adversarial", 45 | "params": { 46 | "m": 20, 47 | "epsilon": 0.03, 48 | "alpha_scale": 2.5 49 | } 50 | } 51 | ] 52 | } 53 | -------------------------------------------------------------------------------- /configs/train/cifar10/random_sampling_p=1.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "random_sampling_p=1", 3 | "seed": 1, 4 | "training": { 5 | "type": "random_sampling", 6 | "params": { 7 | "m": 10, 8 | "p": 1, 9 | "epsilon": 0.03 10 | }, 11 | "epochs": 200, 12 | "test_interval": 10, 13 | "checkpoint_interval": 10, 14 | "opt": { 15 | "type": "sgd", 16 | "params": { 17 | "lr": 0.1, 18 | "momentum": 0.9, 19 | "weight_decay": 5e-4, 20 | "nesterov": true 21 | } 22 | }, 23 | "lr_schedule": "multi_step" 24 | }, 25 | "data": { 26 | "dataset": "cifar10", 27 | "num_workers": 2, 28 | "root": "../data/cifar10", 29 | "training": { 30 | "flip_crop": false, 31 | "batch_size": 128 32 | }, 33 | "test": { 34 | "batch_size": 128 35 | }, 36 | "use_half": false 37 | }, 38 | "evaluations": [ 39 | { 40 | "type": "standard", 41 | "params": {} 42 | }, 43 | { 44 | "type": "random_sampling", 45 | "params": { 46 | "m": 10, 47 | "p": 1, 48 | "epsilon": 0.03 49 | } 50 | } 51 | ] 52 | } 53 | -------------------------------------------------------------------------------- /configs/train/cifar10/random_sampling_p=10.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "random_sampling_p=10", 3 | "seed": 1, 4 | "training": { 5 | "type": "random_sampling", 6 | "params": { 7 | "m": 10, 8 | "p": 10, 9 | "epsilon": 0.03 10 | }, 11 | "epochs": 200, 12 | "test_interval": 10, 13 | "checkpoint_interval": 10, 14 | "opt": { 15 | "type": "sgd", 16 | "params": { 17 | "lr": 0.1, 18 | "momentum": 0.9, 19 | "weight_decay": 5e-4, 20 | "nesterov": 
true 21 | } 22 | }, 23 | "lr_schedule": "multi_step" 24 | }, 25 | "data": { 26 | "dataset": "cifar10", 27 | "num_workers": 2, 28 | "root": "../data/cifar10", 29 | "training": { 30 | "flip_crop": false, 31 | "batch_size": 128 32 | }, 33 | "test": { 34 | "batch_size": 128 35 | }, 36 | "use_half": false 37 | }, 38 | "evaluations": [ 39 | { 40 | "type": "standard", 41 | "params": {} 42 | }, 43 | { 44 | "type": "random_sampling", 45 | "params": { 46 | "m": 10, 47 | "p": 10, 48 | "epsilon": 0.03 49 | } 50 | } 51 | ] 52 | } 53 | -------------------------------------------------------------------------------- /configs/train/cifar10/random_sampling_p=100.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "random_sampling_p=100", 3 | "seed": 1, 4 | "training": { 5 | "type": "random_sampling", 6 | "params": { 7 | "m": 10, 8 | "p": 100, 9 | "epsilon": 0.03 10 | }, 11 | "epochs": 200, 12 | "test_interval": 10, 13 | "checkpoint_interval": 10, 14 | "opt": { 15 | "type": "sgd", 16 | "params": { 17 | "lr": 0.1, 18 | "momentum": 0.9, 19 | "weight_decay": 5e-4, 20 | "nesterov": true 21 | } 22 | }, 23 | "lr_schedule": "multi_step" 24 | }, 25 | "data": { 26 | "dataset": "cifar10", 27 | "num_workers": 2, 28 | "root": "../data/cifar10", 29 | "training": { 30 | "flip_crop": false, 31 | "batch_size": 128 32 | }, 33 | "test": { 34 | "batch_size": 128 35 | }, 36 | "use_half": false 37 | }, 38 | "evaluations": [ 39 | { 40 | "type": "standard", 41 | "params": {} 42 | }, 43 | { 44 | "type": "random_sampling", 45 | "params": { 46 | "m": 10, 47 | "p": 100, 48 | "epsilon": 0.03 49 | } 50 | } 51 | ] 52 | } 53 | -------------------------------------------------------------------------------- /configs/train/cifar10/random_sampling_p=1000.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "random_sampling_p=1000", 3 | "seed": 1, 4 | "training": { 5 | "type": "random_sampling", 6 | "params": { 7 | "m": 10, 8 | "p": 1000, 9 | "epsilon": 0.03 10 | }, 11 | "epochs": 200, 12 | "test_interval": 10, 13 | "checkpoint_interval": 10, 14 | "opt": { 15 | "type": "sgd", 16 | "params": { 17 | "lr": 0.1, 18 | "momentum": 0.9, 19 | "weight_decay": 5e-4, 20 | "nesterov": true 21 | } 22 | }, 23 | "lr_schedule": "multi_step" 24 | }, 25 | "data": { 26 | "dataset": "cifar10", 27 | "num_workers": 2, 28 | "root": "../data/cifar10", 29 | "training": { 30 | "flip_crop": false, 31 | "batch_size": 128 32 | }, 33 | "test": { 34 | "batch_size": 128 35 | }, 36 | "use_half": false 37 | }, 38 | "evaluations": [ 39 | { 40 | "type": "standard", 41 | "params": {} 42 | }, 43 | { 44 | "type": "random_sampling", 45 | "params": { 46 | "m": 10, 47 | "p": 1000, 48 | "epsilon": 0.03 49 | } 50 | } 51 | ] 52 | } 53 | -------------------------------------------------------------------------------- /configs/train/cifar10/rs_p=100_discrete.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "rs_p=100_200_epochs", 3 | "seed": 1, 4 | "training": { 5 | "type": "discrete_random_sampling", 6 | "params": { 7 | "m": 10, 8 | "p": 100 9 | }, 10 | "epochs": 200, 11 | "test_interval": 10, 12 | "checkpoint_interval": 10, 13 | "opt": { 14 | "type": "sgd", 15 | "params": { 16 | "lr": 0.1, 17 | "momentum": 0.9, 18 | "weight_decay": 5e-4, 19 | "nesterov": true 20 | } 21 | }, 22 | "lr_schedule": "cyclic" 23 | }, 24 | "data": { 25 | "dataset": "cifar10", 26 | "num_workers": 2, 27 | "root": "../data/cifar10", 28 | 
"training": { 29 | "flip_crop": true, 30 | "batch_size": 128 31 | }, 32 | "test": { 33 | "batch_size": 128 34 | } 35 | }, 36 | "evaluations": [ 37 | { 38 | "type": "standard", 39 | "params": {} 40 | } 41 | ] 42 | } 43 | -------------------------------------------------------------------------------- /configs/train/cifar10/rs_p=10_discrete.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "rs_p=10_200_epochs", 3 | "seed": 1, 4 | "training": { 5 | "type": "discrete_random_sampling", 6 | "params": { 7 | "m": 10, 8 | "p": 10 9 | }, 10 | "epochs": 200, 11 | "test_interval": 10, 12 | "checkpoint_interval": 10, 13 | "opt": { 14 | "type": "sgd", 15 | "params": { 16 | "lr": 0.1, 17 | "momentum": 0.9, 18 | "weight_decay": 5e-4, 19 | "nesterov": true 20 | } 21 | }, 22 | "lr_schedule": "cyclic" 23 | }, 24 | "data": { 25 | "dataset": "cifar10", 26 | "num_workers": 2, 27 | "root": "../data/cifar10", 28 | "training": { 29 | "flip_crop": true, 30 | "batch_size": 128 31 | }, 32 | "test": { 33 | "batch_size": 128 34 | } 35 | }, 36 | "evaluations": [ 37 | { 38 | "type": "standard", 39 | "params": {} 40 | } 41 | ] 42 | } 43 | -------------------------------------------------------------------------------- /configs/train/cifar10/rs_p=1_discrete.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "rs_p=1_200_epochs", 3 | "seed": 1, 4 | "training": { 5 | "type": "discrete_random_sampling", 6 | "params": { 7 | "m": 10, 8 | "p": 1 9 | }, 10 | "epochs": 200, 11 | "test_interval": 10, 12 | "checkpoint_interval": 10, 13 | "opt": { 14 | "type": "sgd", 15 | "params": { 16 | "lr": 0.1, 17 | "momentum": 0.9, 18 | "weight_decay": 5e-4, 19 | "nesterov": true 20 | } 21 | }, 22 | "lr_schedule": "cyclic" 23 | }, 24 | "data": { 25 | "dataset": "cifar10", 26 | "num_workers": 2, 27 | "root": "../data/cifar10", 28 | "training": { 29 | "flip_crop": true, 30 | "batch_size": 128 31 | }, 32 | "test": { 33 | "batch_size": 128 34 | } 35 | }, 36 | "evaluations": [ 37 | { 38 | "type": "standard", 39 | "params": {} 40 | } 41 | ] 42 | } 43 | -------------------------------------------------------------------------------- /configs/train/cifar10/standard.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "standard", 3 | "seed": 1, 4 | "training": { 5 | "type": "standard", 6 | "params": {}, 7 | "epochs": 200, 8 | "test_interval": 10, 9 | "checkpoint_interval": 10, 10 | "opt": { 11 | "type": "sgd", 12 | "params": { 13 | "lr": 0.1, 14 | "momentum": 0.9, 15 | "weight_decay": 5e-4, 16 | "nesterov": true 17 | } 18 | }, 19 | "lr_schedule": "multi_step" 20 | }, 21 | "data": { 22 | "dataset": "cifar10", 23 | "num_workers": 2, 24 | "root": "../data/cifar10", 25 | "training": { 26 | "flip_crop": false, 27 | "batch_size": 128 28 | }, 29 | "test": { 30 | "batch_size": 128 31 | }, 32 | "use_half": false 33 | }, 34 | "evaluations": [ 35 | { 36 | "type": "standard", 37 | "params": {} 38 | } 39 | ] 40 | } 41 | -------------------------------------------------------------------------------- /configs/train/cifar10/standard_discrete.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "standard", 3 | "seed": 1, 4 | "training": { 5 | "type": "standard", 6 | "params": {}, 7 | "epochs": 50, 8 | "test_interval": 10, 9 | "checkpoint_interval": 10, 10 | "opt": { 11 | "type": "sgd", 12 | "params": { 13 | "lr": 0.1, 14 | "momentum": 
0.9, 15 | "weight_decay": 5e-4, 16 | "nesterov": true 17 | } 18 | } 19 | }, 20 | "data": { 21 | "dataset": "cifar10", 22 | "num_workers": 2, 23 | "root": "../data/cifar10", 24 | "training": { 25 | "flip_crop": true, 26 | "batch_size": 128 27 | }, 28 | "test": { 29 | "batch_size": 128 30 | } 31 | }, 32 | "evaluations": [ 33 | { 34 | "type": "standard", 35 | "params": {} 36 | } 37 | ] 38 | } 39 | -------------------------------------------------------------------------------- /configs/train/mnist/adversarial.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "adversarial", 3 | "seed": 1, 4 | "training": { 5 | "type": "adversarial", 6 | "params": { 7 | "m": 50, 8 | "epsilon": 0.3, 9 | "alpha_scale": 2.5 10 | }, 11 | "epochs": 10, 12 | "test_interval": 10, 13 | "checkpoint_interval": 10, 14 | "opt": { 15 | "type": "adam", 16 | "params": { 17 | "lr": 1e-3 18 | } 19 | }, 20 | "lr_schedule": "multi_step" 21 | }, 22 | "data": { 23 | "dataset": "mnist", 24 | "num_workers": 1, 25 | "root": "../data/mnist", 26 | "training": { 27 | "flip_crop": false, 28 | "batch_size": 128 29 | }, 30 | "test": { 31 | "batch_size": 128 32 | }, 33 | "use_half": false 34 | }, 35 | "evaluations": [ 36 | { 37 | "type": "standard", 38 | "params": {} 39 | }, 40 | { 41 | "type": "adversarial", 42 | "params": { 43 | "m": 50, 44 | "epsilon": 0.3, 45 | "alpha_scale": 2.5 46 | } 47 | } 48 | ] 49 | } 50 | -------------------------------------------------------------------------------- /configs/train/mnist/hmc_p=1.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "hmc_p=1", 3 | "seed": 1, 4 | "training": { 5 | "type": "hmc", 6 | "params": { 7 | "p": 1, 8 | "m": 25, 9 | "l": 2, 10 | "epsilon": 0.3, 11 | "path_len": 0.6, 12 | "sigma": 0.1 13 | }, 14 | "epochs": 10, 15 | "test_interval": 10, 16 | "checkpoint_interval": 10, 17 | "opt": { 18 | "type": "adam", 19 | "params": { 20 | "lr": 1e-3 21 | } 22 | }, 23 | "lr_schedule": "multi_step" 24 | }, 25 | "data": { 26 | "dataset": "mnist", 27 | "num_workers": 2, 28 | "root": "../data/mnist", 29 | "use_half": false, 30 | "training": { 31 | "flip_crop": false, 32 | "batch_size": 128 33 | }, 34 | "test": { 35 | "batch_size": 128 36 | } 37 | }, 38 | "evaluations": [ 39 | { 40 | "type": "standard", 41 | "params": {} 42 | }, 43 | { 44 | "type": "random_sampling", 45 | "params": { 46 | "epsilon": 0.03, 47 | "p": 1, 48 | "m": 50 49 | } 50 | }, 51 | { 52 | "type": "hmc", 53 | "params": { 54 | "p": 1, 55 | "m": 25, 56 | "l": 2, 57 | "epsilon": 0.3, 58 | "path_len": 0.6, 59 | "sigma": 0.1 60 | } 61 | } 62 | ] 63 | } 64 | -------------------------------------------------------------------------------- /configs/train/mnist/hmc_p=10.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "hmc_p=10", 3 | "seed": 1, 4 | "training": { 5 | "type": "hmc", 6 | "params": { 7 | "p": 10, 8 | "m": 25, 9 | "l": 2, 10 | "epsilon": 0.3, 11 | "path_len": 0.6, 12 | "sigma": 0.1 13 | }, 14 | "epochs": 10, 15 | "test_interval": 10, 16 | "checkpoint_interval": 10, 17 | "opt": { 18 | "type": "adam", 19 | "params": { 20 | "lr": 1e-3 21 | } 22 | }, 23 | "lr_schedule": "multi_step" 24 | }, 25 | "data": { 26 | "dataset": "mnist", 27 | "num_workers": 2, 28 | "root": "../data/mnist", 29 | "use_half": false, 30 | "training": { 31 | "flip_crop": false, 32 | "batch_size": 128 33 | }, 34 | "test": { 35 | "batch_size": 128 36 | } 37 | }, 38 | "evaluations": [ 
39 | { 40 | "type": "standard", 41 | "params": {} 42 | }, 43 | { 44 | "type": "random_sampling", 45 | "params": { 46 | "epsilon": 0.03, 47 | "p": 10, 48 | "m": 50 49 | } 50 | }, 51 | { 52 | "type": "hmc", 53 | "params": { 54 | "p": 10, 55 | "m": 25, 56 | "l": 2, 57 | "epsilon": 0.3, 58 | "path_len": 0.6, 59 | "sigma": 0.1 60 | } 61 | } 62 | ] 63 | } 64 | -------------------------------------------------------------------------------- /configs/train/mnist/hmc_p=100.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "hmc_p=100", 3 | "seed": 1, 4 | "training": { 5 | "type": "hmc", 6 | "params": { 7 | "p": 100, 8 | "m": 25, 9 | "l": 2, 10 | "epsilon": 0.3, 11 | "path_len": 0.4, 12 | "sigma": 0.1 13 | }, 14 | "epochs": 10, 15 | "test_interval": 10, 16 | "checkpoint_interval": 10, 17 | "opt": { 18 | "type": "adam", 19 | "params": { 20 | "lr": 1e-3 21 | } 22 | }, 23 | "lr_schedule": "multi_step" 24 | }, 25 | "data": { 26 | "dataset": "mnist", 27 | "num_workers": 2, 28 | "root": "../data/mnist", 29 | "use_half": false, 30 | "training": { 31 | "flip_crop": false, 32 | "batch_size": 128 33 | }, 34 | "test": { 35 | "batch_size": 128 36 | } 37 | }, 38 | "evaluations": [ 39 | { 40 | "type": "standard", 41 | "params": {} 42 | }, 43 | { 44 | "type": "random_sampling", 45 | "params": { 46 | "epsilon": 0.03, 47 | "p": 100, 48 | "m": 50 49 | } 50 | }, 51 | { 52 | "type": "hmc", 53 | "params": { 54 | "p": 100, 55 | "m": 25, 56 | "l": 2, 57 | "epsilon": 0.3, 58 | "path_len": 0.6, 59 | "sigma": 0.1 60 | } 61 | } 62 | ] 63 | } 64 | -------------------------------------------------------------------------------- /configs/train/mnist/hmc_p=1000.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "hmc_p=1000", 3 | "seed": 1, 4 | "training": { 5 | "type": "hmc", 6 | "params": { 7 | "p": 1000, 8 | "m": 25, 9 | "l": 2, 10 | "epsilon": 0.3, 11 | "path_len": 0.2, 12 | "sigma": 0.1 13 | }, 14 | "epochs": 10, 15 | "test_interval": 10, 16 | "checkpoint_interval": 10, 17 | "opt": { 18 | "type": "adam", 19 | "params": { 20 | "lr": 1e-3 21 | } 22 | }, 23 | "lr_schedule": "multi_step" 24 | }, 25 | "data": { 26 | "dataset": "mnist", 27 | "num_workers": 2, 28 | "root": "../data/mnist", 29 | "use_half": false, 30 | "training": { 31 | "flip_crop": false, 32 | "batch_size": 128 33 | }, 34 | "test": { 35 | "batch_size": 128 36 | } 37 | }, 38 | "evaluations": [ 39 | { 40 | "type": "standard", 41 | "params": {} 42 | }, 43 | { 44 | "type": "random_sampling", 45 | "params": { 46 | "epsilon": 0.03, 47 | "p": 1000, 48 | "m": 50 49 | } 50 | }, 51 | { 52 | "type": "hmc", 53 | "params": { 54 | "p": 1000, 55 | "m": 25, 56 | "l": 2, 57 | "epsilon": 0.3, 58 | "path_len": 0.2, 59 | "sigma": 0.1 60 | } 61 | } 62 | ] 63 | } 64 | -------------------------------------------------------------------------------- /configs/train/mnist/random_sampling_p=1.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "random_sampling_p=1", 3 | "seed": 1, 4 | "training": { 5 | "type": "random_sampling", 6 | "params": { 7 | "m": 50, 8 | "p": 1, 9 | "epsilon": 0.3 10 | }, 11 | "epochs": 10, 12 | "test_interval": 10, 13 | "checkpoint_interval": 10, 14 | "opt": { 15 | "type": "adam", 16 | "params": { 17 | "lr": 1e-3 18 | } 19 | }, 20 | "lr_schedule": "multi_step" 21 | }, 22 | "data": { 23 | "dataset": "mnist", 24 | "num_workers": 1, 25 | "root": "../data/mnist", 26 | "training": { 27 | 
"flip_crop": false, 28 | "batch_size": 128 29 | }, 30 | "test": { 31 | "batch_size": 128 32 | }, 33 | "use_half": false 34 | }, 35 | "evaluations": [ 36 | { 37 | "type": "standard", 38 | "params": {} 39 | }, 40 | { 41 | "type": "random_sampling", 42 | "params": { 43 | "p": 1, 44 | "m": 50, 45 | "epsilon": 0.3 46 | } 47 | } 48 | ] 49 | } 50 | -------------------------------------------------------------------------------- /configs/train/mnist/random_sampling_p=10.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "random_sampling_p=10", 3 | "seed": 1, 4 | "training": { 5 | "type": "random_sampling", 6 | "params": { 7 | "m": 50, 8 | "p": 10, 9 | "epsilon": 0.3 10 | }, 11 | "epochs": 10, 12 | "test_interval": 10, 13 | "checkpoint_interval": 10, 14 | "opt": { 15 | "type": "adam", 16 | "params": { 17 | "lr": 1e-3 18 | } 19 | }, 20 | "lr_schedule": "multi_step" 21 | }, 22 | "data": { 23 | "dataset": "mnist", 24 | "num_workers": 1, 25 | "root": "../data/mnist", 26 | "training": { 27 | "flip_crop": false, 28 | "batch_size": 128 29 | }, 30 | "test": { 31 | "batch_size": 128 32 | }, 33 | "use_half": false 34 | }, 35 | "evaluations": [ 36 | { 37 | "type": "standard", 38 | "params": {} 39 | }, 40 | { 41 | "type": "random_sampling", 42 | "params": { 43 | "p": 10, 44 | "m": 50, 45 | "epsilon": 0.3 46 | } 47 | } 48 | ] 49 | } 50 | -------------------------------------------------------------------------------- /configs/train/mnist/random_sampling_p=100.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "random_sampling_p=100", 3 | "seed": 1, 4 | "training": { 5 | "type": "random_sampling", 6 | "params": { 7 | "m": 50, 8 | "p": 100, 9 | "epsilon": 0.3 10 | }, 11 | "epochs": 10, 12 | "test_interval": 10, 13 | "checkpoint_interval": 10, 14 | "opt": { 15 | "type": "adam", 16 | "params": { 17 | "lr": 1e-3 18 | } 19 | }, 20 | "lr_schedule": "multi_step" 21 | }, 22 | "data": { 23 | "dataset": "mnist", 24 | "num_workers": 1, 25 | "root": "../data/mnist", 26 | "training": { 27 | "flip_crop": false, 28 | "batch_size": 128 29 | }, 30 | "test": { 31 | "batch_size": 128 32 | }, 33 | "use_half": false 34 | }, 35 | "evaluations": [ 36 | { 37 | "type": "standard", 38 | "params": {} 39 | }, 40 | { 41 | "type": "random_sampling", 42 | "params": { 43 | "p": 100, 44 | "m": 50, 45 | "epsilon": 0.3 46 | } 47 | } 48 | ] 49 | } 50 | -------------------------------------------------------------------------------- /configs/train/mnist/random_sampling_p=1000.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "random_sampling_p=1000", 3 | "seed": 1, 4 | "training": { 5 | "type": "random_sampling", 6 | "params": { 7 | "m": 50, 8 | "p": 1000, 9 | "epsilon": 0.3 10 | }, 11 | "epochs": 10, 12 | "test_interval": 10, 13 | "checkpoint_interval": 10, 14 | "opt": { 15 | "type": "adam", 16 | "params": { 17 | "lr": 1e-3 18 | } 19 | }, 20 | "lr_schedule": "multi_step" 21 | }, 22 | "data": { 23 | "dataset": "mnist", 24 | "num_workers": 1, 25 | "root": "../data/mnist", 26 | "training": { 27 | "flip_crop": false, 28 | "batch_size": 128 29 | }, 30 | "test": { 31 | "batch_size": 128 32 | }, 33 | "use_half": false 34 | }, 35 | "evaluations": [ 36 | { 37 | "type": "standard", 38 | "params": {} 39 | }, 40 | { 41 | "type": "random_sampling", 42 | "params": { 43 | "p": 1000, 44 | "m": 50, 45 | "epsilon": 0.3 46 | } 47 | } 48 | ] 49 | } 50 | 
-------------------------------------------------------------------------------- /configs/train/mnist/standard.json: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "standard", 3 | "seed": 1, 4 | "training": { 5 | "type": "standard", 6 | "params": {}, 7 | "epochs": 10, 8 | "test_interval": 1, 9 | "checkpoint_interval": 1, 10 | "opt": { 11 | "type": "adam", 12 | "params": { 13 | "lr": 1e-3 14 | } 15 | }, 16 | "lr_schedule": "multi_step" 17 | }, 18 | "data": { 19 | "dataset": "mnist", 20 | "num_workers": 1, 21 | "root": "../data/mnist", 22 | "training": { 23 | "flip_crop": false, 24 | "batch_size": 128 25 | }, 26 | "test": { 27 | "batch_size": 128 28 | }, 29 | "use_half": false 30 | }, 31 | "evaluations": [ 32 | { 33 | "type": "standard", 34 | "params": {} 35 | } 36 | ] 37 | } 38 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import datasets, transforms 3 | 4 | 5 | dsets = { 6 | 'mnist': datasets.MNIST, 7 | 'cifar10': datasets.CIFAR10 8 | } 9 | 10 | 11 | def get_train_loader(config): 12 | print(config.data.dataset) 13 | kwargs = {'num_workers': config.data.num_workers, 'pin_memory': True} 14 | train_transforms = [] 15 | if config.data.training.flip_crop: 16 | train_transforms += [transforms.RandomCrop(size=32, padding=4), transforms.RandomHorizontalFlip()] 17 | train_transforms.append(transforms.ToTensor()) 18 | train_transforms = transforms.Compose(train_transforms) 19 | train_loader = torch.utils.data.DataLoader( 20 | dsets[config.data.dataset](config.data.root, train=True, transform=train_transforms, download=True), 21 | batch_size=config.data.training.batch_size, shuffle=True, **kwargs) 22 | return train_loader 23 | 24 | 25 | def get_test_loader(config): 26 | kwargs = {'num_workers': config.data.num_workers, 'pin_memory': True} 27 | test_transforms = [transforms.ToTensor()] 28 | test_transforms = transforms.Compose(test_transforms) 29 | test_loader = torch.utils.data.DataLoader( 30 | dsets[config.data.dataset](config.data.root, train=False, transform=test_transforms), 31 | batch_size=config.data.test.batch_size, shuffle=False, **kwargs) 32 | return test_loader 33 | -------------------------------------------------------------------------------- /eval_discrete.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import math 5 | import itertools 6 | import numpy as np 7 | import os 8 | import torch 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | import torchvision.transforms.functional as TF 12 | import matplotlib.pyplot as plt 13 | from torchvision import datasets, transforms 14 | from torch.utils.data import DataLoader 15 | from utils import * 16 | from models.mnist import mnist_classifier 17 | from models.preactresnet import preactresnet18 18 | from datasets import get_train_loader, get_test_loader 19 | 20 | mu = torch.tensor(cifar10_mean).view(3,1,1).cuda() 21 | std = torch.tensor(cifar10_std).view(3,1,1).cuda() 22 | 23 | def normalize(X): 24 | return (X - mu)/std 25 | 26 | def clamp(X, lower_limit, upper_limit): 27 | return torch.max(torch.min(X, upper_limit), lower_limit) 28 | 29 | def reflect(x, a, b): 30 | if not torch.is_tensor(a): 31 | a = torch.zeros_like(x).fill_(a) 32 | if not torch.is_tensor(b): 33 | b = torch.zeros_like(x).fill_(b) 34 | while 
len(torch.where((x < a) | (x > b))[0]) > 0: 35 | low = torch.where(x < a) 36 | high = torch.where(x > b) 37 | x[low] = a[low] + (a[low] - x[low]) 38 | x[high] = b[high] - (x[high] - b[high]) 39 | return x 40 | 41 | 42 | def epoch_standard(loader, model, opt=None, eval_mode=None): 43 | """Standard training/evaluation epoch over the dataset""" 44 | total_loss, total_err = 0.,0. 45 | for batch in loader: 46 | if type(batch) is dict: 47 | X, y = batch['input'], batch['target'] 48 | else: 49 | X, y = batch 50 | X,y = X.cuda(), y.cuda() 51 | if X.shape[1] == 1: 52 | yp = model(X) 53 | else: 54 | yp = model(normalize(X)) 55 | loss = nn.CrossEntropyLoss()(yp,y) 56 | if opt: 57 | opt.zero_grad() 58 | loss.backward() 59 | opt.step() 60 | total_err += (yp.max(dim=1)[1] != y).sum().item() 61 | total_loss += loss.item() * X.shape[0] 62 | return total_err / len(loader.dataset), total_loss / len(loader.dataset) 63 | 64 | 65 | def epoch_discrete_random_sampling(loader, model, opt=None, p=1, m=20): 66 | padded = 40 67 | unpadded = 32 68 | total_loss = 0. 69 | for batch in loader: 70 | X, y = batch['input'], batch['target'] 71 | X, y = X.cuda(), y.cuda() 72 | 73 | losses = torch.zeros(m, X.size(0)).cuda() 74 | max_loss = 0 75 | for i in range(m): 76 | scale_min, scale_max = 0.9, 1.1 77 | rotate_min, rotate_max = -10, 10 78 | 79 | flip = np.random.choice([True, False]) 80 | scale = np.random.uniform(scale_min, scale_max) 81 | rotate = np.random.uniform(rotate_min, rotate_max) 82 | crop_range = math.floor(scale * padded) - unpadded 83 | crop_x = np.random.randint(0, crop_range + 1) 84 | crop_y = np.random.randint(0, crop_range + 1) 85 | 86 | sample = X.detach().clone() 87 | if flip: sample = TF.hflip(sample) 88 | sample = TF.resize(sample, size=(int(padded * scale), int(padded * scale))) 89 | sample = TF.rotate(sample, angle=rotate) 90 | sample = TF.crop(sample, top=crop_y, left=crop_x, height=unpadded, width=unpadded) 91 | yp = model(normalize(sample)) 92 | loss = nn.CrossEntropyLoss(reduction='none')(yp, y) 93 | losses[i] = loss.detach() 94 | 95 | loss = losses.transpose(0, 1) 96 | loss = (torch.exp(torch.logsumexp(torch.log(loss + 1e-10) * p, dim=1)/ p) * (1/m)**(1/p)) 97 | loss = loss.mean() 98 | total_loss += loss.item() * y.size(0) 99 | return total_loss / len(loader.dataset) 100 | 101 | 102 | def epoch_mcmc(loader, model, opt=None, p=1, m=20): 103 | """MCMC with discrete data augmentation transformations""" 104 | padded = 40 105 | unpadded = 32 106 | total_loss = 0. 107 | total_adv_loss = 0. 
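    # Overview of the sampler below: each batch starts from one randomly drawn
    # spatial transformation (horizontal flip, scale in [0.9, 1.1], rotation in
    # [-10, 10] degrees, and a random crop from the padded 40x40 image back to
    # 32x32) applied to every image, then runs a per-image Metropolis chain over
    # the transformation parameters. The target density at annealing step i is
    # proportional to loss^theta_i, with theta annealed linearly from 0 to p, so
    # early steps explore transformations roughly uniformly while later steps
    # concentrate on high-loss ones. A proposal is accepted with probability
    # min(1, (loss_next / loss)^theta), evaluated in log space as
    # log_loss_next - log_loss.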
108 | 109 | for batch in loader: 110 | if type(batch) is dict: 111 | X, y = batch['input'], batch['target'] 112 | X, y = X.cuda(), y.cuda() 113 | 114 | losses = torch.zeros(X.size(0), m) 115 | 116 | transforms = torch.zeros(X.shape[0], 5) 117 | scale_min, scale_max = 0.9, 1.1 118 | rotate_min, rotate_max = -10, 10 119 | 120 | flip = np.random.choice([True, False]) 121 | scale = np.random.uniform(scale_min, scale_max) 122 | rotate = np.random.uniform(rotate_min, rotate_max) 123 | crop_range = math.floor(scale * padded) - unpadded 124 | crop_x = np.random.randint(0, crop_range + 1) 125 | crop_y = np.random.randint(0, crop_range + 1) 126 | 127 | previous = X.detach().clone() 128 | if flip: previous = TF.hflip(previous) 129 | previous = TF.resize(previous, size=(int(padded * scale), int(padded * scale))) 130 | previous = TF.rotate(previous, angle=rotate) 131 | previous = TF.crop(previous, top=crop_y, left=crop_x, height=unpadded , width=unpadded) 132 | 133 | transforms[:, 0] = int(flip) 134 | transforms[:, 1] = rotate 135 | transforms[:, 2] = crop_x 136 | transforms[:, 3] = crop_y 137 | transforms[:, 4] = scale 138 | 139 | thetas = np.linspace(0, p, m) 140 | max_loss = 0 141 | idxx = 0 142 | for i in range(len(thetas)): 143 | theta = thetas[i] 144 | for j in range(20): 145 | yp = model(normalize(previous)) 146 | loss = nn.CrossEntropyLoss(reduction='none')(yp,y) 147 | if j == 19: 148 | losses[:, i] = loss.detach().cpu() 149 | log_loss = theta * torch.log(loss + 1e-10) 150 | 151 | # transforms from previous iteration 152 | flip = transforms[:, 0] 153 | rotate = transforms[:, 1] 154 | crop_x = transforms[:, 2] 155 | crop_y = transforms[:, 3] 156 | scale = transforms[:, 4] 157 | 158 | # proposed deltas 159 | flip_delta = np.random.choice([True, False]) 160 | rotate_delta = np.random.normal(0, 5) 161 | crop_x_delta = int(np.random.normal(0, 2)) 162 | crop_y_delta = int(np.random.normal(0, 2)) 163 | scale_delta = np.random.normal(0, 0.5) 164 | 165 | # proposed transforms with reflections 166 | if flip_delta: 167 | flip_proposal = 1-flip 168 | else: 169 | flip_proposal = flip 170 | scale_proposal = reflect((scale + scale_delta).float(), scale_min, scale_max) 171 | rotate_proposal = reflect((rotate + rotate_delta).float(), rotate_min, rotate_max) 172 | crop_range_max = (np.floor(scale_proposal * padded) - unpadded).int() 173 | crop_x_proposal = reflect((crop_x + crop_x_delta).int(), 0, crop_range_max) 174 | crop_y_proposal = reflect((crop_y + crop_y_delta).int(), 0, crop_range_max) 175 | 176 | new_transforms = torch.stack((flip_proposal.int(), rotate_proposal, crop_x_proposal, crop_y_proposal, scale_proposal), dim=1) 177 | proposal = X.detach().clone() 178 | transformed_proposal = torch.zeros(X.shape[0], X.shape[1], unpadded, unpadded) 179 | proposal[flip_proposal.bool()] = TF.hflip(proposal[flip_proposal.bool()]) 180 | for idx in range(X.size(0)): 181 | scaled = TF.resize(proposal[idx], size=(int(padded * scale_proposal[idx]), int(padded * scale_proposal[idx]))) 182 | rotated = TF.rotate(scaled, angle=rotate_proposal[idx].item()) 183 | cropped = TF.crop(rotated, top=crop_y_proposal[idx], left=crop_x_proposal[idx], height=unpadded, width=unpadded) 184 | transformed_proposal[idx] = cropped 185 | 186 | proposal = transformed_proposal.cuda() 187 | yp_next = model(normalize(proposal)) 188 | loss_next = nn.CrossEntropyLoss(reduction='none')(yp_next, y) 189 | log_loss_next = theta * torch.log(loss_next + 1e-10) 190 | 191 | log_ratio = log_loss_next - log_loss 192 | 193 | idx_ = torch.where(log_ratio > 
np.log(1)) 194 | log_ratio[idx_] = np.log(1)  # cap the log acceptance ratio at log(1) = 0 195 | u = torch.log(torch.zeros_like(log_ratio).uniform_(0,1)).cuda() 196 | idx_accept = torch.where(u <= log_ratio.cuda())[0] 197 | transforms[idx_accept] = new_transforms[idx_accept] 198 | previous.data[idx_accept] = proposal.data[idx_accept].half() 199 | 200 | adv_loss = losses.detach().clone() 201 | adv_loss = torch.max(adv_loss, dim=1)[0] 202 | adv_loss = adv_loss.mean() 203 | total_adv_loss += adv_loss.item() * y.size(0) 204 | 205 | loss = losses.detach().clone() 206 | loss = torch.exp(torch.log(loss + 1e-10).sum(dim=1) / m) 207 | loss = loss.mean() 208 | total_loss += loss.item() * y.size(0) 209 | return total_loss / len(loader.dataset), total_adv_loss / len(loader.dataset) 210 | 211 | def main(): 212 | parser = argparse.ArgumentParser() 213 | parser.add_argument('--checkpoint', type=str) 214 | args = parser.parse_args() 215 | 216 | d = torch.load(args.checkpoint) 217 | 218 | torch.manual_seed(1) 219 | torch.cuda.manual_seed(1) 220 | 221 | test_dataset = torchvision.datasets.CIFAR10(root='../data', train=False) 222 | test_data = test_dataset.data[:1000] 223 | test_labels = test_dataset.targets[:1000] 224 | test_examples = list(zip(transpose(pad(test_data, 4) / 255.), test_labels)) 225 | batch_size = 100 226 | test_loader = Batches(test_examples, batch_size, shuffle=False, num_workers=2) 227 | 228 | model = preactresnet18().cuda() 229 | model.load_state_dict(d["model"]) 230 | model.eval() 231 | _, test_loss = epoch_standard(test_loader, model) 232 | print('standard loss: ', test_loss) 233 | adv_losses = [] 234 | for p in [1, 10, 100, 1000]: 235 | test_loss, adv_loss = epoch_mcmc(test_loader, model, p=p, m=500) 236 | print(f'mcmc loss p={p}: ', test_loss) 237 | test_loss = epoch_discrete_random_sampling(test_loader, model, p=p, m=500) 238 | print(f'rs loss p={p}: ', test_loss) 239 | adv_losses.append(adv_loss) 240 | print('adversarial loss: ', max(adv_losses)) 241 | 242 | if __name__ == '__main__': 243 | main() 244 | 245 | -------------------------------------------------------------------------------- /models/mnist.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class Flatten(nn.Module): 6 | def forward(self, x): 7 | return x.view(x.size(0),-1) 8 | 9 | 10 | class mnist_classifier(nn.Module): 11 | def __init__(self, N=32): 12 | super(mnist_classifier, self).__init__() 13 | self.net = nn.Sequential(*[ 14 | nn.Conv2d(1, N, 4, stride=2, padding=1), 15 | nn.ReLU(), 16 | nn.Conv2d(N, 2 * N, 4, stride=2, padding=1), 17 | nn.ReLU(), 18 | Flatten(), 19 | nn.Linear(7 * 7 * 2 * N, 32 * N), 20 | nn.ReLU(), 21 | nn.Linear(32 * N, 10)]) 22 | 23 | def forward(self, X): 24 | X = torch.clamp(X, min=0, max=1) 25 | return self.net(X) 26 | -------------------------------------------------------------------------------- /models/preactresnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class PreActBlock(nn.Module): 7 | '''Pre-activation version of the BasicBlock.''' 8 | expansion = 1 9 | 10 | def __init__(self, in_planes, planes, stride=1): 11 | super(PreActBlock, self).__init__() 12 | self.bn1 = nn.BatchNorm2d(in_planes) 13 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(planes) 15 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1,
bias=False) 16 | 17 | if stride != 1 or in_planes != self.expansion*planes: 18 | self.shortcut = nn.Sequential( 19 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 20 | ) 21 | 22 | def forward(self, x): 23 | out = F.relu(self.bn1(x)) 24 | shortcut = self.shortcut(x) if hasattr(self, 'shortcut') else x 25 | out = self.conv1(out) 26 | out = self.conv2(F.relu(self.bn2(out))) 27 | out += shortcut 28 | return out 29 | 30 | 31 | class PreActResNet(nn.Module): 32 | def __init__(self, block, num_blocks, num_classes=10, dset='cifar10'): 33 | super(PreActResNet, self).__init__() 34 | self.in_planes = 64 35 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 36 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 37 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 38 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 39 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 40 | self.bn = nn.BatchNorm2d(512 * block.expansion) 41 | self.linear = nn.Linear(512 * block.expansion, num_classes) 42 | 43 | 44 | def _make_layer(self, block, planes, num_blocks, stride): 45 | strides = [stride] + [1]*(num_blocks-1) 46 | layers = [] 47 | for stride in strides: 48 | layers.append(block(self.in_planes, planes, stride)) 49 | self.in_planes = planes * block.expansion 50 | return nn.Sequential(*layers) 51 | 52 | def forward(self, x): 53 | out = self.conv1(x) 54 | out = self.layer1(out) 55 | out = self.layer2(out) 56 | out = self.layer3(out) 57 | out = self.layer4(out) 58 | out = F.relu(self.bn(out)) 59 | out = F.avg_pool2d(out, 4) 60 | out = out.view(out.size(0), -1) 61 | out = self.linear(out) 62 | return out 63 | 64 | def preactresnet18(): 65 | return PreActResNet(PreActBlock, [2,2,2,2]) 66 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jupyter==1.0.0 2 | matplotlib==3.3.3 3 | numpy==1.19.2 4 | torch==1.6.0 5 | torchvision==0.7.0 6 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import math 5 | import numpy as np 6 | import os 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | from torchvision import datasets, transforms 11 | from torch.utils.data import DataLoader 12 | from utils import * 13 | from models.preactresnet import preactresnet18 14 | from models.mnist import mnist_classifier 15 | from datasets import get_train_loader, get_test_loader 16 | import torch.nn.functional as F 17 | 18 | 19 | mu = torch.tensor(cifar10_mean).view(3,1,1).cuda() 20 | std = torch.tensor(cifar10_std).view(3,1,1).cuda() 21 | 22 | 23 | def normalize(X): 24 | return (X - mu)/std 25 | 26 | 27 | def pgd_linf(model, X, y, epsilon=0.03, m=20, randomize=True, alpha_scale=2.5, restarts=1): 28 | """Construct PGD (l_inf) adversarial examples on the examples X, using m gradient steps and random restarts""" 29 | alpha = alpha_scale * epsilon / m 30 | 31 | max_loss = torch.zeros(y.shape[0]).cuda() 32 | max_delta = torch.zeros_like(X).cuda() 33 | for _ in range(restarts): 34 | if randomize: 35 | delta = torch.zeros_like(X) 36 | delta.uniform_(-epsilon, epsilon) 37 | delta.requires_grad = True 38 | else: 39 | delta = torch.zeros_like(X, requires_grad=True) 40 | for t in range(m): 41 | if X.shape[1] == 1: 42 | 
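                # single-channel inputs (MNIST) are fed to the model unnormalized;
                # the mnist_classifier clamps its input to [0, 1] internally, while
                # 3-channel CIFAR-10 inputs go through mean/std normalization below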
loss = nn.CrossEntropyLoss()(model(torch.clamp(X + delta, min=0, max=1)), y) 43 | else: 44 | loss = nn.CrossEntropyLoss()(model(normalize(torch.clamp(X + delta, min=0, max=1))), y) 45 | loss.backward() 46 | delta.data = delta + alpha*delta.grad.detach().sign() 47 | delta.data = torch.clamp(delta, -epsilon, epsilon) 48 | delta.grad.zero_() 49 | if X.shape[1] == 1: 50 | all_loss = F.cross_entropy(model(X+delta), y, reduction='none') 51 | else: 52 | all_loss = F.cross_entropy(model(normalize(X+delta)), y, reduction='none') 53 | max_delta[all_loss >= max_loss] = delta.detach()[all_loss >= max_loss] 54 | max_loss = torch.max(max_loss, all_loss) 55 | return max_delta.detach() 56 | 57 | 58 | def epoch_standard(loader, model, opt=None, eval_mode=None): 59 | """Standard training/evaluation epoch over the dataset""" 60 | total_loss, total_err = 0.,0. 61 | for batch in loader: 62 | if type(batch) is dict: 63 | X, y = batch['input'], batch['target'] 64 | else: 65 | X, y = batch 66 | X,y = X.cuda(), y.cuda() 67 | if X.shape[1] == 1: 68 | yp = model(X) 69 | else: 70 | yp = model(normalize(X)) 71 | loss = nn.CrossEntropyLoss()(yp,y) 72 | if opt: 73 | opt.zero_grad() 74 | loss.backward() 75 | opt.step() 76 | total_err += (yp.max(dim=1)[1] != y).sum().item() 77 | total_loss += loss.item() * X.shape[0] 78 | return total_err / len(loader.dataset), total_loss / len(loader.dataset) 79 | 80 | 81 | def epoch_adversarial(loader, model, opt=None, eval_mode=False, **kwargs): 82 | """Adversarial training/evaluation epoch over the dataset""" 83 | total_loss, total_err = 0.,0. 84 | for batch in loader: 85 | if type(batch) is dict: 86 | X, y = batch['input'], batch['target'] 87 | else: 88 | X, y = batch 89 | X,y = X.cuda(), y.cuda() 90 | delta = pgd_linf(model, X, y, **kwargs) 91 | if X.shape[1] == 1: 92 | yp = model(X+delta) 93 | else: 94 | yp = model(normalize(X+delta)) 95 | loss = nn.CrossEntropyLoss()(yp,y) 96 | if opt: 97 | opt.zero_grad() 98 | loss.backward() 99 | opt.step() 100 | 101 | total_err += (yp.max(dim=1)[1] != y).sum().item() 102 | total_loss += loss.item() * X.shape[0] 103 | return total_err / len(loader.dataset), total_loss / len(loader.dataset) 104 | 105 | 106 | def epoch_random_sampling(loader, model, opt=None, epsilon=0.03, p=1, m=10, eval_mode=False): 107 | """Random-sampling training/evaluation epoch over the dataset""" 108 | total_loss = 0.
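    # The loop below estimates the intermediate robustness objective: with m
    # perturbations delta_1..delta_m drawn uniformly from the l_inf ball
    # (intersected with the valid pixel range),
    #   ((1/m) * sum_i loss(x + delta_i)^p)^(1/p),
    # which interpolates between the average-case loss (p = 1) and the
    # worst-case, adversarial loss (p -> infinity). For numerical stability it
    # is computed in log space as
    #   exp((logsumexp_i(p * log loss_i) - log m) / p),
    # which is exactly the torch.logsumexp expression used below.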
109 | for batch in loader: 110 | if type(batch) is dict: 111 | X, y = batch['input'], batch['target'] 112 | else: 113 | X, y = batch 114 | 115 | X,y = X.cuda(), y.cuda() 116 | lower_limit = torch.max(-X, torch.tensor(-epsilon, dtype=X.dtype).view(1, 1, 1).cuda()) 117 | upper_limit = torch.min(1 - X, torch.tensor(epsilon, dtype=X.dtype).view(1, 1, 1).cuda()) 118 | if not eval_mode: 119 | deltas = (lower_limit - upper_limit) * torch.rand(m, *X.shape).cuda() + upper_limit 120 | deltas.requires_grad = True 121 | 122 | X_delta = (X[None] + deltas).transpose(0, 1).contiguous().view(-1, *X.shape[1:]) 123 | y_delta = y[None].expand(m,*y.shape).transpose(0, 1).contiguous().view(-1) 124 | if X.shape[1] == 1: 125 | yp_delta = model(X_delta) 126 | else: 127 | yp_delta = model(normalize(X_delta)) 128 | loss = nn.CrossEntropyLoss(reduction='none')(yp_delta,y_delta) 129 | loss = loss.view(X.size(0), m) 130 | loss = (torch.exp(torch.logsumexp(torch.log(loss + 1e-10) * p - math.log(m), dim=1)/ p)).mean() 131 | if opt: 132 | opt.zero_grad() 133 | loss.backward() 134 | opt.step() 135 | else: 136 | losses = torch.zeros(m, X.size(0)).cuda() 137 | for i in range(m): 138 | delta = (lower_limit - upper_limit) * torch.rand_like(X) + upper_limit 139 | if X.shape[1] == 1: 140 | yp_delta = model(X + delta) 141 | else: 142 | yp_delta = model(normalize(X + delta)) 143 | loss = nn.CrossEntropyLoss(reduction='none')(yp_delta,y) 144 | losses[i] = loss.detach() 145 | loss = losses.transpose(0, 1) 146 | loss = (torch.exp(torch.logsumexp(torch.log(loss + 1e-10) * p - math.log(m), dim=1)/ p)).mean() 147 | total_loss += loss.item() * y.size(0) 148 | return 0.0, total_loss / len(loader.dataset) 149 | 150 | 151 | def epoch_hmc(loader, model, opt=None, epsilon=0.03, p=1, m=10, l=1, path_len=0.05, sigma=0.1, 152 | eval_mode=False, anneal_theta=True): 153 | """HMC training/evaluation epoch over the dataset""" 154 | total_loss = 0.
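    # Same objective as epoch_random_sampling, but the perturbations are drawn
    # with Hamiltonian Monte Carlo rather than uniform sampling: each of the m
    # samples targets an (unnormalized) density over delta proportional to
    # loss(x + delta)^theta, with theta annealed linearly from 0 to p (or drawn
    # uniformly from [0, p] if anneal_theta is false). Each HMC update runs l
    # leapfrog steps of size alpha = path_len * sigma^2 / l with Gaussian
    # momentum of scale sigma, reflects positions off the l_inf-ball boundary,
    # and accepts the proposal with probability min(1, exp(-(H_proposal - H))).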
155 | num_accepts = 0 156 | total_n = 0 157 | 158 | alpha = path_len * sigma ** 2 / l 159 | print('hmc leapfrog step size:', alpha) 160 | 161 | for batch in loader: 162 | if type(batch) is dict: 163 | X, y = batch['input'], batch['target'] 164 | else: 165 | X, y = batch 166 | X,y = X.cuda(), y.cuda() 167 | 168 | lower_limit = torch.max(-X, torch.tensor(-epsilon, dtype=X.dtype).view(1, 1, 1).cuda()) 169 | upper_limit = torch.min(1 - X, torch.tensor(epsilon, dtype=X.dtype).view(1, 1, 1).cuda()) 170 | 171 | if eval_mode: 172 | losses = torch.zeros(X.size(0), m) 173 | else: 174 | deltas = torch.zeros(m, *X.shape).cuda() 175 | 176 | delta = (lower_limit - upper_limit) * torch.rand_like(X) + upper_limit 177 | delta.requires_grad = True 178 | 179 | if anneal_theta: 180 | thetas = np.linspace(0, p, m) 181 | else: 182 | thetas = [np.random.uniform(0, p) for i in range(m)] 183 | model.eval() 184 | for i, theta in enumerate(thetas): 185 | mom = torch.randn_like(X).cuda() * sigma 186 | if not eval_mode: 187 | deltas[i] = delta.data 188 | 189 | if X.shape[1] == 1: 190 | yp = model(X + delta) 191 | else: 192 | yp = model(normalize(X + delta)) 193 | loss = nn.CrossEntropyLoss(reduction='none')(yp,y) 194 | if eval_mode: 195 | losses[:, i] = loss.detach().cpu() 196 | log_loss = theta * torch.log(loss + 1e-10) 197 | log_loss.sum().backward() 198 | 199 | h_delta = torch.norm(mom.view(X.size(0), -1), dim=1)**2/sigma**2/2 - log_loss 200 | mom += 0.5 * alpha * delta.grad # half step of momentum 201 | proposal = delta.data 202 | for j in range(l): 203 | proposal = proposal.data + alpha * mom / sigma**2 # full step of position 204 | # reflection 205 | while len(torch.where(proposal < lower_limit)[0]) > 0 or len(torch.where(proposal > upper_limit)[0]) > 0: 206 | idx_ = torch.where(proposal < lower_limit) 207 | if len(idx_[0]) > 0: 208 | proposal.data[idx_] = 2*lower_limit[idx_] - proposal.data[idx_] 209 | mom[idx_] = -mom[idx_] 210 | idx_ = torch.where(proposal > upper_limit) 211 | if len(idx_[0]) > 0: 212 | proposal.data[idx_] = 2*upper_limit[idx_] - proposal.data[idx_] 213 | mom[idx_] = -mom[idx_] 214 | proposal.requires_grad = True 215 | if X.shape[1] == 1: 216 | yp_next = model(X + proposal) 217 | else: 218 | yp_next = model(normalize(X + proposal)) 219 | loss_next = nn.CrossEntropyLoss(reduction='none')(yp_next, y) 220 | log_loss_next = theta * torch.log(loss_next + 1e-10) 221 | log_loss_next.sum().backward() 222 | if j != (l-1): 223 | mom += alpha * proposal.grad # full step of momentum 224 | mom += 0.5 * alpha * proposal.grad # final half step of momentum 225 | 226 | h_proposal = torch.norm(mom.view(X.size(0), -1), dim=1)**2/sigma**2/2 - log_loss_next 227 | delta_h = h_proposal - h_delta 228 | u = torch.zeros_like(delta_h).uniform_(0,1) 229 | idx_accept = torch.where(u <= torch.exp(-delta_h)) 230 | delta.data[idx_accept] = proposal.data[idx_accept] 231 | 232 | num_accepts += len(idx_accept[0]) 233 | total_n += delta.size(0) 234 | delta.grad.zero_() 235 | 236 | if eval_mode: 237 | loss = losses 238 | else: 239 | model.train() 240 | y_delta = y[None].expand(m,*y.shape).transpose(0, 1).contiguous().view(-1) 241 | X_delta = (X[None] + deltas).transpose(0, 1).contiguous().view(-1, *X.shape[1:]) 242 | if X.shape[1] == 1: 243 | yp_delta = model(torch.clamp(X_delta, min=0, max=1)) 244 | else: 245 | yp_delta = model(normalize(torch.clamp(X_delta, min=0, max=1))) 246 | loss = nn.CrossEntropyLoss(reduction='none')(yp_delta,y_delta) 247 | loss = loss.view(X.size(0), m) 248 | loss = torch.exp(torch.log(loss + 1e-10).sum(dim=1) / m) 249 | loss = loss.mean() 250 | if opt: 251 | 
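            # training step: backpropagate through the geometric mean of the
            # loss over the m retained HMC samples computed just above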
opt.zero_grad() 252 | loss.backward() 253 | opt.step() 254 | 255 | total_loss += loss.item() * y.size(0) 256 | 257 | print('percent accepts ', num_accepts / total_n * 100.) 258 | return 0.0, total_loss / len(loader.dataset) 259 | 260 | 261 | epoch_modes = { 262 | "standard": epoch_standard, 263 | "adversarial": epoch_adversarial, 264 | "hmc": epoch_hmc, 265 | "random_sampling": epoch_random_sampling 266 | } 267 | 268 | 269 | def main(): 270 | parser = argparse.ArgumentParser() 271 | parser.add_argument('-c', '--config', type=str, 272 | help='path to config file', required=True) 273 | parser.add_argument('--resume', default=None, help='epoch') 274 | parser.add_argument('--dp', action='store_true') 275 | args = parser.parse_args() 276 | config_dict = get_config(args.config) 277 | 278 | assert os.path.splitext(os.path.basename(args.config))[0] == config_dict['output_dir'] 279 | 280 | torch.manual_seed(config_dict['seed']) 281 | torch.cuda.manual_seed(config_dict['seed']) 282 | 283 | if config_dict['training'] is None: 284 | output_dir = os.path.join('evaluations', config_dict['data']['dataset'], config_dict['output_dir']) 285 | else: 286 | output_dir = os.path.join('experiments', config_dict['data']['dataset'], config_dict['output_dir']) 287 | if not os.path.exists(output_dir): 288 | os.makedirs(output_dir) 289 | checkpoint_dir = os.path.join(output_dir, 'checkpoints') 290 | if not os.path.exists(checkpoint_dir): 291 | os.makedirs(checkpoint_dir) 292 | 293 | with open(os.path.join(output_dir, 'config.json'), 'w') as f: 294 | json.dump(config_dict, f, sort_keys=True, indent=4) 295 | 296 | config = config_to_namedtuple(config_dict) 297 | 298 | logger = get_logger(__name__, output_dir) 299 | 300 | logger.info(f'cuda {torch.cuda.is_available()}') 301 | 302 | if config.training is not None: 303 | if config.training.lr_schedule == 'cyclic': 304 | lr_schedule = lambda t: np.interp([t], [0, config.training.epochs * 2 // 5, config.training.epochs], [0, config.training.opt.params.lr, 0])[0] 305 | elif config.training.lr_schedule == 'multi_step': 306 | def lr_schedule(t): 307 | if t / config.training.epochs < 0.5: 308 | return config.training.opt.params.lr 309 | elif t / config.training.epochs < 0.75: 310 | return config.training.opt.params.lr / 10. 311 | else: 312 | return config.training.opt.params.lr / 100. 
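        # worked example of the multi_step schedule, using the epochs=10 and
        # lr=1e-3 settings from the MNIST configs above:
        #   epochs 0-4 (t/epochs < 0.5)  -> 1e-3
        #   epochs 5-7 (t/epochs < 0.75) -> 1e-4
        #   epochs 8-9                   -> 1e-5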
313 | 314 | if config.data.dataset == 'mnist': 315 | model = mnist_classifier().cuda() 316 | elif config.data.dataset == 'cifar10': 317 | model = preactresnet18().cuda() 318 | 319 | if config.training is not None: 320 | if config.data.use_half: 321 | transforms = [Crop(32, 32), FlipLR()] 322 | dataset = cifar10(config.data.root) 323 | train_set = list(zip(transpose(pad(dataset['train']['data'], 4)/255.), dataset['train']['labels'])) 324 | train_set_x = Transform(train_set, transforms) 325 | train_loader = Batches(train_set_x, config.data.training.batch_size, shuffle=True, set_random_choices=True, num_workers=config.data.num_workers) 326 | 327 | test_set = list(zip(transpose(dataset['test']['data']/255.), dataset['test']['labels'])) 328 | test_loader = Batches(test_set, config.data.test.batch_size, shuffle=False, num_workers=config.data.num_workers) 329 | else: 330 | train_loader = get_train_loader(config) 331 | test_loader = get_test_loader(config) 332 | if config.training.opt.type == 'adam': 333 | opt = torch.optim.Adam(model.parameters(), **config.training.opt.params._asdict()) 334 | elif config.training.opt.type == 'sgd': 335 | opt = torch.optim.SGD(model.parameters(),**config.training.opt.params._asdict()) 336 | 337 | if args.resume is not None: 338 | checkpoint_filename = os.path.join(output_dir, 'checkpoints', f'checkpoint_{args.resume}.pth') 339 | d = torch.load(checkpoint_filename) 340 | logger.info(f"Resume model checkpoint {d['epoch']}...") 341 | model.load_state_dict(d["model"]) 342 | opt.load_state_dict(d["opt"]) 343 | start_epoch = d["epoch"] + 1 344 | else: 345 | start_epoch = 0 346 | 347 | if args.dp: 348 | model = nn.DataParallel(model) 349 | 350 | # Train 351 | logger.info(f"Epoch \t \t Train Loss \t Train Error \t LR") 352 | for epoch in range(start_epoch, config.training.epochs): 353 | lr = lr_schedule(epoch) 354 | 355 | opt.param_groups[0]['lr'] = lr 356 | epoch_mode = epoch_modes[config.training.type] 357 | train_err, train_loss = epoch_mode(train_loader, model, opt, **config.training.params._asdict()) 358 | logger.info(f'{epoch} \t \t \t {train_loss:.6f} \t \t {train_err:.6f} \t {lr:.6f}') 359 | 360 | # evaluate 361 | if (epoch+1) % config.training.test_interval == 0 or epoch + 1 == config.training.epochs: 362 | model.eval() 363 | for evaluation in config.evaluations: 364 | test_err, test_loss = epoch_modes[evaluation.type](test_loader, model, **evaluation.params._asdict(), eval_mode=True) 365 | logger.info(f"{evaluation.type}: \t Test Loss {test_loss:.6f} \t Test Error {test_err:.6f} \t {lr:.6f}") 366 | model.train() 367 | if (epoch+1) % config.training.checkpoint_interval == 0 or epoch + 1 == config.training.epochs: 368 | if args.dp: 369 | save_m = model.module 370 | else: 371 | save_m = model 372 | d = {"epoch": epoch, "model": save_m.state_dict(), "opt": opt.state_dict()} 373 | torch.save(d, os.path.join(output_dir, 'checkpoints', f'checkpoint_{epoch}.pth')) 374 | 375 | else: 376 | if config.data.use_half: 377 | dataset = cifar10(config.data.root) 378 | test_set = list(zip(transpose(dataset['test']['data']/255.), dataset['test']['labels'])) 379 | test_loader = Batches(test_set, config.data.test.batch_size, shuffle=False, num_workers=config.data.num_workers) 380 | else: 381 | test_loader = get_test_loader(config) 382 | checkpoint_filename = config.checkpoint_filename 383 | d = torch.load(checkpoint_filename) 384 | logger.info(f"Loading model checkpoint {config.checkpoint_filename}...") 385 | model.load_state_dict(d["model"]) 386 | if args.dp: 
387 | model = nn.DataParallel(model) 388 | model.eval() 389 | for evaluation in config.evaluations: 390 | test_err, test_loss = epoch_modes[evaluation.type](test_loader, model, eval_mode=True, **evaluation.params._asdict()) 391 | logger.info(f"{evaluation.type}: \t Test Loss {test_loss:.6f} \t Test Error {test_err:.6f}") 392 | 393 | 394 | if __name__ == "__main__": 395 | main() 396 | -------------------------------------------------------------------------------- /train_discrete.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import math 5 | import numpy as np 6 | import os 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | from torchvision import datasets, transforms 11 | from torch.utils.data import DataLoader 12 | import torchvision.transforms.functional as TF 13 | from utils import * 14 | from models.preactresnet import preactresnet18 15 | from datasets import get_train_loader, get_test_loader 16 | 17 | mu = torch.tensor(cifar10_mean).view(3,1,1).cuda() 18 | std = torch.tensor(cifar10_std).view(3,1,1).cuda() 19 | 20 | 21 | def normalize(X): 22 | return (X - mu)/std 23 | 24 | 25 | def epoch_standard(loader, model, opt=None, eval_mode=None): 26 | """Standard training/evaluation epoch over the dataset""" 27 | total_loss, total_err = 0.,0. 28 | for batch in loader: 29 | if type(batch) is dict: 30 | X, y = batch['input'], batch['target'] 31 | else: 32 | X, y = batch 33 | X,y = X.cuda(), y.cuda() 34 | if X.shape[1] == 1: 35 | yp = model(X) 36 | else: 37 | yp = model(normalize(X)) 38 | loss = nn.CrossEntropyLoss()(yp,y) 39 | if opt: 40 | opt.zero_grad() 41 | loss.backward() 42 | opt.step() 43 | total_err += (yp.max(dim=1)[1] != y).sum().item() 44 | total_loss += loss.item() * X.shape[0] 45 | return total_err / len(loader.dataset), total_loss / len(loader.dataset) 46 | 47 | 48 | def epoch_discrete_random_sampling(loader, model, opt=None, p=1, m=20, eval_mode=False): 49 | padded = 40 50 | unpadded = 32 51 | total_loss = 0. 
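    # Discrete analogue of epoch_random_sampling in train.py: instead of l_inf
    # noise, each of the m samples draws a random spatial transformation
    # (horizontal flip, scale in [0.9, 1.1], rotation in [-10, 10] degrees, and
    # a random crop from the padded 40x40 image back to 32x32), and the m
    # per-example losses are combined with the same soft-max formula; the
    # (1/m)**(1/p) factor below is the -log(m)/p term of the logsumexp
    # estimator written multiplicatively.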
52 | for batch in loader: 53 | X, y = batch['input'], batch['target'] 54 | X, y = X.cuda(), y.cuda() 55 | if eval_mode: 56 | losses = torch.zeros(m, X.size(0)).cuda() 57 | else: 58 | Xs = torch.zeros(m, X.size(0), X.size(1), unpadded, unpadded).cuda() 59 | for i in range(m): 60 | scale_min, scale_max = 0.9, 1.1 61 | rotate_min, rotate_max = -10, 10 62 | 63 | flip = np.random.choice([True, False]) 64 | scale = np.random.uniform(scale_min, scale_max) 65 | rotate = np.random.uniform(rotate_min, rotate_max) 66 | crop_range = math.floor(scale * padded) - unpadded 67 | crop_x = np.random.randint(0, crop_range + 1) 68 | crop_y = np.random.randint(0, crop_range + 1) 69 | 70 | sample = X.detach().clone() 71 | if flip: sample = TF.hflip(sample) 72 | sample = TF.resize(sample, size=(int(padded * scale), int(padded * scale))) 73 | sample = TF.rotate(sample, angle=rotate) 74 | sample = TF.crop(sample, top=crop_y, left=crop_x, height=unpadded, width=unpadded) 75 | if eval_mode: 76 | yp = model(normalize(sample)) 77 | loss = nn.CrossEntropyLoss(reduction='none')(yp, y) 78 | losses[i] = loss.detach() 79 | else: 80 | Xs[i] = sample 81 | 82 | if eval_mode: 83 | loss = losses.transpose(0, 1) 84 | loss = (torch.exp(torch.logsumexp(torch.log(loss + 1e-10) * p, dim=1)/ p) * (1/m)**(1/p)) 85 | loss = loss.mean() 86 | else: 87 | y_samples = y[None].expand(m,*y.shape).transpose(0, 1).contiguous().view(-1) 88 | X_samples = Xs.transpose(0, 1).contiguous().view(-1, X.shape[1], unpadded, unpadded) 89 | yp_samples = model(normalize(X_samples)) 90 | loss = nn.CrossEntropyLoss(reduction='none')(yp_samples,y_samples) 91 | loss = loss.view(X.size(0), m) 92 | loss = (torch.exp(torch.logsumexp(torch.log(loss + 1e-10) * p, dim=1)/ p) * (1/m)**(1/p)).mean() 93 | if opt: 94 | opt.zero_grad() 95 | loss.backward() 96 | opt.step() 97 | total_loss += loss.item() * y.size(0) 98 | return 0, total_loss / len(loader.dataset) 99 | 100 | 101 | epoch_modes = { 102 | "standard": epoch_standard, 103 | "discrete_random_sampling": epoch_discrete_random_sampling 104 | } 105 | 106 | 107 | def main(): 108 | parser = argparse.ArgumentParser() 109 | parser.add_argument('-c', '--config', type=str, 110 | help='path to config file', required=True) 111 | parser.add_argument('--resume', default=None, help='epoch') 112 | parser.add_argument('--dp', action='store_true') 113 | args = parser.parse_args() 114 | config_dict = get_config(args.config) 115 | 116 | assert os.path.splitext(os.path.basename(args.config))[0] == config_dict['output_dir'] 117 | 118 | torch.manual_seed(config_dict['seed']) 119 | torch.cuda.manual_seed(config_dict['seed']) 120 | 121 | if config_dict['training'] is None: 122 | output_dir = os.path.join('evaluations', config_dict['data']['dataset'], config_dict['output_dir']) 123 | else: 124 | output_dir = os.path.join('experiments', config_dict['data']['dataset'], config_dict['output_dir']) 125 | if not os.path.exists(output_dir): 126 | os.makedirs(output_dir) 127 | checkpoint_dir = os.path.join(output_dir, 'checkpoints') 128 | if not os.path.exists(checkpoint_dir): 129 | os.makedirs(checkpoint_dir) 130 | 131 | with open(os.path.join(output_dir, 'config.json'), 'w') as f: 132 | json.dump(config_dict, f, sort_keys=True, indent=4) 133 | 134 | config = config_to_namedtuple(config_dict) 135 | 136 | logger = get_logger(__name__, output_dir) 137 | 138 | logger.info(f'cuda {torch.cuda.is_available()}') 139 | 140 | if config.training is not None: 141 | lr_schedule = lambda t: np.interp([t], [0, config.training.epochs * 2 // 5, 
config.training.epochs], [0, config.training.opt.params.lr, 0])[0] 142 | 143 | model = preactresnet18().cuda() 144 | 145 | if config.training is not None: 146 | dataset = cifar10(config.data.root) 147 | train_set = list(zip(transpose(pad(dataset['train']['data'], 4)/255.), dataset['train']['labels'])) 148 | if config.training.type == 'discrete_random_sampling': 149 | train_loader = Batches(train_set, config.data.training.batch_size, shuffle=True, num_workers=config.data.num_workers) 150 | test_set = list(zip(transpose(pad(dataset['test']['data'], 4)/255.), dataset['test']['labels'])) 151 | test_loader = Batches(test_set, config.data.test.batch_size, shuffle=False, num_workers=config.data.num_workers) 152 | else: 153 | transforms = [Crop(32, 32), FlipLR()] 154 | train_set_x = Transform(train_set, transforms) 155 | train_loader = Batches(train_set_x, config.data.training.batch_size, shuffle=True, set_random_choices=True, num_workers=config.data.num_workers) 156 | 157 | test_set = list(zip(transpose(dataset['test']['data']/255.), dataset['test']['labels'])) 158 | test_loader = Batches(test_set, config.data.test.batch_size, shuffle=False, num_workers=config.data.num_workers) 159 | if config.training.opt.type == 'adam': 160 | opt = torch.optim.Adam(model.parameters(), **config.training.opt.params._asdict()) 161 | elif config.training.opt.type == 'sgd': 162 | opt = torch.optim.SGD(model.parameters(),**config.training.opt.params._asdict()) 163 | 164 | if args.resume is not None: 165 | checkpoint_filename = os.path.join(output_dir, 'checkpoints', f'checkpoint_{args.resume}.pth') 166 | d = torch.load(checkpoint_filename) 167 | logger.info(f"Resume model checkpoint {d['epoch']}...") 168 | model.load_state_dict(d["model"]) 169 | opt.load_state_dict(d["opt"]) 170 | start_epoch = d["epoch"] + 1 171 | else: 172 | start_epoch = 0 173 | 174 | if args.dp: 175 | model = nn.DataParallel(model) 176 | 177 | # Train 178 | logger.info(f"Epoch \t \t Train Loss \t Train Error \t LR") 179 | for epoch in range(start_epoch, config.training.epochs): 180 | lr = lr_schedule(epoch) 181 | 182 | opt.param_groups[0]['lr'] = lr 183 | epoch_mode = epoch_modes[config.training.type] 184 | train_err, train_loss = epoch_mode(train_loader, model, opt, **config.training.params._asdict()) 185 | logger.info(f'{epoch} \t \t \t {train_loss:.6f} \t \t {train_err:.6f} \t {lr:.6f}') 186 | 187 | # evaluate 188 | if (epoch+1) % config.training.test_interval == 0 or epoch + 1 == config.training.epochs: 189 | model.eval() 190 | for evaluation in config.evaluations: 191 | test_err, test_loss = epoch_modes[evaluation.type](test_loader, model, **evaluation.params._asdict(), eval_mode=True) 192 | logger.info(f"{evaluation.type}: \t Test Loss {test_loss:.6f} \t Test Error {test_err:.6f} \t {lr:.6f}") 193 | model.train() 194 | if (epoch+1) % config.training.checkpoint_interval == 0 or epoch + 1 == config.training.epochs: 195 | if args.dp: 196 | save_m = model.module 197 | else: 198 | save_m = model 199 | d = {"epoch": epoch, "model": save_m.state_dict(), "opt": opt.state_dict()} 200 | torch.save(d, os.path.join(output_dir, 'checkpoints', f'checkpoint_{epoch}.pth')) 201 | 202 | else: 203 | if config.data.use_half: 204 | dataset = cifar10(config.data.root) 205 | test_set = list(zip(transpose(dataset['test']['data']/255.), dataset['test']['labels'])) 206 | test_loader = Batches(test_set, config.data.test.batch_size, shuffle=False, num_workers=config.data.num_workers) 207 | else: 208 | test_loader = get_test_loader(config) 209 | 
checkpoint_filename = config.checkpoint_filename 210 | d = torch.load(checkpoint_filename) 211 | logger.info(f"Loading model checkpoint {config.checkpoint_filename}...") 212 | model.load_state_dict(d["model"]) 213 | if args.dp: 214 | model = nn.DataParallel(model) 215 | model.eval() 216 | for evaluation in config.evaluations: 217 | test_err, test_loss = epoch_modes[evaluation.type](test_loader, model, eval_mode=True, **evaluation.params._asdict()) 218 | logger.info(f"{evaluation.type}: \t Test Loss {test_loss:.6f} \t Test Error {test_err:.6f}") 219 | 220 | 221 | if __name__ == "__main__": 222 | main() 223 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import numpy as np 5 | import torch 6 | import torchvision 7 | from collections import namedtuple 8 | 9 | 10 | def get_config(config_path): 11 | with open(config_path) as config_file: 12 | config = json.load(config_file) 13 | return config 14 | 15 | 16 | def config_to_namedtuple(obj): 17 | if isinstance(obj, dict): 18 | for key, value in obj.items(): 19 | obj[key] = config_to_namedtuple(value) 20 | return namedtuple('GenericDict', obj.keys())(**obj) 21 | elif isinstance(obj, list): 22 | return [config_to_namedtuple(item) for item in obj] 23 | else: 24 | return obj 25 | 26 | 27 | def get_logger(name_, output_dir=None): 28 | logger = logging.getLogger(name_) 29 | if output_dir is not None: 30 | handlers=[logging.FileHandler(os.path.join(output_dir,'output.log')), 31 | logging.StreamHandler()] 32 | else: 33 | handlers = [logging.StreamHandler()] 34 | logging.basicConfig( 35 | format='[%(asctime)s] - %(message)s', 36 | datefmt='%Y/%m/%d %H:%M:%S', 37 | level=logging.DEBUG, 38 | handlers=handlers) 39 | return logger 40 | 41 | 42 | ######## mixed precision CIFAR-10 ######### 43 | ######## source: https://github.com/davidcpage/cifar10-fast ######## 44 | 45 | cifar10_mean = (0.4914, 0.4822, 0.4465) # equals np.mean(train_set.train_data, axis=(0,1,2))/255 46 | cifar10_std = (0.2471, 0.2435, 0.2616) # equals np.std(train_set.train_data, axis=(0,1,2))/255 47 | 48 | 49 | def pad(x, border=4): 50 | return np.pad(x, [(0, 0), (border, border), (border, border), (0, 0)], mode='reflect') 51 | 52 | 53 | def transpose(x, source='NHWC', target='NCHW'): 54 | return x.transpose([source.index(d) for d in target]) 55 | 56 | 57 | class Batches(): 58 | def __init__(self, dataset, batch_size, shuffle, set_random_choices=False, num_workers=0, drop_last=False): 59 | self.dataset = dataset 60 | self.batch_size = batch_size 61 | self.set_random_choices = set_random_choices 62 | self.dataloader = torch.utils.data.DataLoader( 63 | dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=True, shuffle=shuffle, drop_last=drop_last 64 | ) 65 | 66 | def __iter__(self): 67 | if self.set_random_choices: 68 | self.dataset.set_random_choices() 69 | return ({'input': x.cuda().half(), 'target': y.cuda().long()} for (x,y) in self.dataloader) 70 | 71 | def __len__(self): 72 | return len(self.dataloader) 73 | 74 | 75 | class Transform(): 76 | def __init__(self, dataset, transforms): 77 | self.dataset, self.transforms = dataset, transforms 78 | self.choices = None 79 | 80 | def __len__(self): 81 | return len(self.dataset) 82 | 83 | def __getitem__(self, index): 84 | data, labels = self.dataset[index] 85 | for choices, f in zip(self.choices, self.transforms): 86 | args = {k: v[index] for (k,v) in 
choices.items()} 87 | data = f(data, **args) 88 | return data, labels 89 | 90 | def set_random_choices(self): 91 | self.choices = [] 92 | x_shape = self.dataset[0][0].shape 93 | N = len(self) 94 | for t in self.transforms: 95 | options = t.options(x_shape) 96 | x_shape = t.output_shape(x_shape) if hasattr(t, 'output_shape') else x_shape 97 | self.choices.append({k:np.random.choice(v, size=N) for (k,v) in options.items()}) 98 | 99 | 100 | class Crop(namedtuple('Crop', ('h', 'w'))): 101 | def __call__(self, x, x0, y0): 102 | return x[:,y0:y0+self.h,x0:x0+self.w] 103 | 104 | def options(self, x_shape): 105 | C, H, W = x_shape 106 | return {'x0': range(W+1-self.w), 'y0': range(H+1-self.h)} 107 | 108 | def output_shape(self, x_shape): 109 | C, H, W = x_shape 110 | return (C, self.h, self.w) 111 | 112 | 113 | class FlipLR(namedtuple('FlipLR', ())): 114 | def __call__(self, x, choice): 115 | return x[:, :, ::-1].copy() if choice else x 116 | 117 | def options(self, x_shape): 118 | return {'choice': [True, False]} 119 | 120 | 121 | def cifar10(root): 122 | train_set = torchvision.datasets.CIFAR10(root=root, train=True, download=True) 123 | test_set = torchvision.datasets.CIFAR10(root=root, train=False, download=True) 124 | return { 125 | 'train': {'data': train_set.data, 'labels': train_set.targets}, 126 | 'test': {'data': test_set.data, 'labels': test_set.targets} 127 | } 128 | --------------------------------------------------------------------------------