├── .gitignore ├── README.md ├── configs ├── cifar10 │ ├── resnet32 │ │ ├── adam │ │ │ ├── mb128_lr1e3_bn.json │ │ │ ├── mb128_lr1e3_bn_aug.json │ │ │ ├── mb128_lr1e3_wd1_all_bn.json │ │ │ ├── mb128_lr1e3_wd3e1_all_bn_aug.json │ │ │ ├── mb128_lr3e4.json │ │ │ └── mb128_lr3e4_wd1_all.json │ │ ├── kfac │ │ │ ├── mb128_lr1e2_damp1e3_cov95_bn_aug.json │ │ │ ├── mb128_lr1e2_damp1e3_cov95_kl1e4.json │ │ │ ├── mb128_lr3e2_damp1e3_cov95_bn.json │ │ │ ├── mb128_lr3e2_damp1e3_cov95_wd1e2_all_bn.json │ │ │ ├── mb128_lr3e2_damp1e3_cov95_wd3e2_all_bn_aug.json │ │ │ └── mb128_lr3e2_damp1e3_cov95_wd3e2_all_kl1e4.json │ │ ├── kfacf │ │ │ ├── mb128_lr1e3_damp1e3_cov95_bn.json │ │ │ ├── mb128_lr1e3_damp1e3_cov95_bn_aug.json │ │ │ ├── mb128_lr1e3_damp1e3_cov95_wd1_all_bn.json │ │ │ ├── mb128_lr1e3_damp1e3_cov95_wd1_all_kl1e5.json │ │ │ ├── mb128_lr3e4_damp1e3_cov95_kl1e5.json │ │ │ └── mb128_lr3e4_damp1e3_cov95_wd1_all_bn_aug.json │ │ └── sgd │ │ │ ├── mb128_lr1e2_m9_wd3e3_all_bn_aug.json │ │ │ ├── mb128_lr3e2_m9_bn.json │ │ │ ├── mb128_lr3e2_m9_bn_aug.json │ │ │ ├── mb128_lr3e2_m9_wd3e3_all_bn.json │ │ │ ├── mb128_lr3e3_m9.json │ │ │ └── mb128_lr3e3_m9_wd1e3_all.json │ └── vgg16 │ │ ├── adam │ │ ├── mb128_lr1e3_bn.json │ │ ├── mb128_lr1e3_bn_aug.json │ │ ├── mb128_lr1e3_wd3e1_all_bn.json │ │ ├── mb128_lr3e4.json │ │ ├── mb128_lr3e4_wd1_all_bn_aug.json │ │ └── mb128_lr3e4_wd3e1_all.json │ │ ├── kfac │ │ ├── mb128_lr1e1_damp1e3_cov95_bn.json │ │ ├── mb128_lr3e2_damp1e2_cov95_kl1e4.json │ │ ├── mb128_lr3e2_damp1e2_cov95_wd3e2_all_kl1e3.json │ │ ├── mb128_lr3e2_damp1e3_cov95_bn_aug.json │ │ ├── mb128_lr3e2_damp1e3_cov95_wd1e2_all_bn_aug.json │ │ └── mb128_lr3e2_damp1e3_cov95_wd3e2_all_bn.json │ │ ├── kfacf │ │ ├── mb128_lr1e3_damp1e3_cov95_bn.json │ │ ├── mb128_lr1e3_damp1e3_cov95_bn_aug.json │ │ ├── mb128_lr1e3_damp1e3_cov95_wd1_all_kl1e5.json │ │ ├── mb128_lr1e3_damp1e3_cov95_wd3e1_all_bn_aug.json │ │ ├── mb128_lr3e3_damp1e3_cov95_wd1_all_bn.json │ │ └── mb128_lr3e4_damp1e3_cov95_kl1e5.json │ │ └── sgd │ │ ├── mb128_lr1e1_m9_wd1e3_all_bn.json │ │ ├── mb128_lr1e1_m9_wd1e3_all_bn_aug.json │ │ ├── mb128_lr1e2_m9_bn_aug.json │ │ ├── mb128_lr3e2_m9.json │ │ ├── mb128_lr3e2_m9_bn.json │ │ └── mb128_lr3e2_m9_wd3e4_all.json └── cifar100 │ ├── resnet32 │ ├── adam │ │ ├── mb128_lr3e4_bn_aug.json │ │ └── mb128_lr3e4_wd1_all_bn_aug.json │ ├── kfac │ │ ├── mb128_lr5e2_damp1e3_cov95_bn_aug.json │ │ └── mb128_lr5e2_damp1e3_cov95_wd1e2_all_bn_aug.json │ ├── kfacf │ │ ├── mb128_lr1e3_damp1e3_cov95_bn_aug.json │ │ └── mb128_lr1e3_damp1e3_cov95_wd1_all_bn_aug.json │ └── sgd │ │ ├── mb128_lr3e2_bn_aug.json │ │ └── mb128_lr3e2_wd3e3_all_bn_aug.json │ └── vgg16 │ ├── adam │ ├── mb128_lr3e4_bn_aug.json │ └── mb128_lr3e4_wd1_all_bn_aug.json │ ├── kfac │ ├── mb128_lr1e1_damp1e3_cov95_bn_aug.json │ └── mb128_lr1e1_damp1e3_cov95_wd1e2_all_bn_aug.json │ ├── kfacf │ ├── mb128_lr1e3_damp1e3_cov95_wd1_all_bn_aug.json │ └── mb128_lr3e3_damp1e3_cov95_bn_aug.json │ └── sgd │ ├── mb128_lr3e2_bn_aug.json │ └── mb128_lr3e2_wd3e3_all_bn_aug.json ├── core ├── base_model.py ├── base_train.py ├── model.py └── train.py ├── data_loader.py ├── libs ├── adam │ └── optimizer.py ├── kfac │ ├── cmvp.py │ ├── estimator.py │ ├── fisher_blocks.py │ ├── fisher_factors.py │ ├── layer_collection.py │ ├── loss_functions.py │ ├── optimizer.py │ └── utils.py └── sgd │ └── optimizer.py ├── main.py ├── misc ├── config.py ├── summarizer.py └── utils.py └── network ├── __init__.py ├── mlp.py ├── registry.py ├── resnet.py └── vgg.py /.gitignore: 
-------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | tmp 104 | runs 105 | run 106 | 107 | # PyCharm 108 | .idea/ 109 | 110 | # macOS metadata 111 | .DS_Store 112 | 113 | # 114 | *.sh 115 | log/ 116 | summary/ 117 | data/ 118 | experiments/ 119 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Three Mechanisms of Weight Decay Regularization 2 | This repo contains the official implementation of [Three Mechanisms of Weight Decay Regularization](https://openreview.net/forum?id=B1lz-3Rct7). 3 | 4 | 1. The config files for the experiments are under the `configs/` directory; a note on the filename convention follows the Credit section below. 5 | 2. The modified optimization algorithms are in `libs/`. 6 | 7 | # Citation 8 | To cite this work, please use 9 | ``` 10 | @inproceedings{ 11 | zhang2018three, 12 | title={Three Mechanisms of Weight Decay Regularization}, 13 | author={Guodong Zhang and Chaoqi Wang and Bowen Xu and Roger Grosse}, 14 | booktitle={International Conference on Learning Representations}, 15 | year={2019}, 16 | url={https://openreview.net/forum?id=B1lz-3Rct7}, 17 | } 18 | ``` 19 | 20 | # Requirements 21 | This project uses Python 3.5.2. Before running the code, you need to install 22 | * [Tensorflow 1.4+](https://www.tensorflow.org/) 23 | * [PyTorch](http://pytorch.org/) 24 | * [Numpy](http://www.numpy.org/) 25 | * [tqdm](https://pypi.python.org/pypi/tqdm) 26 | 27 | # How to run? 28 | 29 | ``` 30 | # example 31 | $ python main.py --config configs/cifar100/resnet32/kfac/mb128_lr5e2_damp1e3_cov95_bn_aug.json 32 | ``` 33 | 34 | # Credit 35 | This repo uses a modified version of [Tensorflow K-FAC](https://github.com/tensorflow/kfac).
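
# Config naming
The filenames under `configs/` encode the main hyperparameters stored in each JSON: `mb128` is the minibatch size, `lr*` the learning rate (e.g. `lr1e3` = 0.001), `damp*` the K-FAC damping, `cov95` a covariance EMA decay of 0.95, `kl*` the KL clip, `m9` momentum 0.9, `wd*_all` the weight decay value applied to all weights, `bn` batch normalization, and `aug` data augmentation; the `kfacf` folders hold the K-FAC configs with `use_fisher` set to true. The snippet below is only an illustration and is not part of the repo (the configs are actually consumed by `main.py`, presumably together with `misc/config.py`); it simply reads one of the shipped configs with the standard `json` module so the fields can be checked before launching a run:

```
# Illustration only (not repo code): inspect an experiment config before
# passing its path to `python main.py --config ...`.
import json

with open("configs/cifar10/resnet32/sgd/mb128_lr3e3_m9.json") as f:
    cfg = json.load(f)

# These keys appear in every config in this repo.
print(cfg["optimizer"], cfg["model_name"], cfg["dataset"])
print("lr =", cfg["learning_rate"],
      "weight_decay =", cfg["weight_decay"],
      "batch_size =", cfg["batch_size"])
```
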
36 | 37 | # Contact 38 | If you have any questions or suggestions about the code or paper, please do not hesitate to contact Guodong Zhang (`gdzhang.cs@gmail.com` or `gdzhang@cs.toronto.edu`) or Chaoqi Wang (`alecwangcq@gmail.com` or `cqwang@cs.toronto.edu`). 39 | -------------------------------------------------------------------------------- /configs/cifar10/resnet32/adam/mb128_lr1e3_bn.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr1e3_bn", 3 | "dataset": "cifar10", 4 | "optimizer": "adam", 5 | "model_name": "resnet32", 6 | "use_kfac": false, 7 | "roger_init": false, 8 | "data_aug": false, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": false, 12 | "epoch": 200, 13 | "decay_every_itr": 23460, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.001, 19 | "momentum": 0.9, 20 | "weight_decay_type": "l2", 21 | "weight_decay": 0.0, 22 | "weight_list": "all" 23 | } 24 | -------------------------------------------------------------------------------- /configs/cifar10/resnet32/adam/mb128_lr1e3_bn_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr1e3_bn_aug", 3 | "dataset": "cifar10", 4 | "optimizer": "adam", 5 | "model_name": "resnet32", 6 | "use_kfac": false, 7 | "roger_init": false, 8 | "data_aug": true, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": false, 12 | "epoch": 200, 13 | "decay_every_itr": 23460, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.001, 19 | "momentum": 0.9, 20 | "weight_decay_type": "l2", 21 | "weight_decay": 0.0, 22 | "weight_list": "all" 23 | } 24 | -------------------------------------------------------------------------------- /configs/cifar10/resnet32/adam/mb128_lr1e3_wd1_all_bn.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr1e3_wd1_all_bn", 3 | "dataset": "cifar10", 4 | "optimizer": "adam", 5 | "model_name": "resnet32", 6 | "use_kfac": false, 7 | "roger_init": false, 8 | "data_aug": false, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": false, 12 | "epoch": 200, 13 | "decay_every_itr": 23460, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.001, 19 | "momentum": 0.9, 20 | "weight_decay_type": "wd", 21 | "weight_decay": 1.0, 22 | "weight_list": "all" 23 | } 24 | -------------------------------------------------------------------------------- /configs/cifar10/resnet32/adam/mb128_lr1e3_wd3e1_all_bn_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr1e3_wd3e1_all_bn_aug", 3 | "dataset": "cifar10", 4 | "optimizer": "adam", 5 | "model_name": "resnet32", 6 | "use_kfac": false, 7 | "roger_init": false, 8 | "data_aug": true, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": false, 12 | "epoch": 200, 13 | "decay_every_itr": 23460, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.001, 19 | "momentum": 0.9, 20 | "weight_decay_type": "wd", 21 | "weight_decay": 0.3, 22 | "weight_list": "all" 23 | } 24 | -------------------------------------------------------------------------------- /configs/cifar10/resnet32/adam/mb128_lr3e4.json:
-------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr3e4", 3 | "dataset": "cifar10", 4 | "optimizer": "adam", 5 | "model_name": "resnet32", 6 | "use_kfac": false, 7 | "roger_init": false, 8 | "data_aug": false, 9 | "batch_norm": false, 10 | "use_bias": true, 11 | "use_fisher": false, 12 | "epoch": 200, 13 | "decay_every_itr": 23460, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.0003, 19 | "momentum": 0.9, 20 | "weight_decay_type": "l2", 21 | "weight_decay": 0.0, 22 | "weight_list": "all" 23 | } 24 | -------------------------------------------------------------------------------- /configs/cifar10/resnet32/adam/mb128_lr3e4_wd1_all.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr3e4_wd1_all", 3 | "dataset": "cifar10", 4 | "optimizer": "adam", 5 | "model_name": "resnet32", 6 | "use_kfac": false, 7 | "roger_init": false, 8 | "data_aug": false, 9 | "batch_norm": false, 10 | "use_bias": true, 11 | "use_fisher": false, 12 | "epoch": 200, 13 | "decay_every_itr": 23460, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.0003, 19 | "momentum": 0.9, 20 | "weight_decay_type": "wd", 21 | "weight_decay": 1.0, 22 | "weight_list": "all" 23 | } 24 | -------------------------------------------------------------------------------- /configs/cifar10/resnet32/kfac/mb128_lr1e2_damp1e3_cov95_bn_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr1e2_damp1e3_cov95_bn_aug", 3 | "dataset": "cifar10", 4 | "optimizer": "kfac", 5 | "model_name": "resnet32", 6 | "use_kfac": true, 7 | "roger_init": true, 8 | "data_aug": true, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": false, 12 | "epoch": 100, 13 | "decay_every_itr": 16000, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.01, 19 | "momentum": 0.9, 20 | "weight_decay_type": "wd", 21 | "weight_decay": 0.0, 22 | "weight_list": "all", 23 | "damping": 1e-3, 24 | "cov_ema_decay": 0.95 25 | } 26 | -------------------------------------------------------------------------------- /configs/cifar10/resnet32/kfac/mb128_lr1e2_damp1e3_cov95_kl1e4.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr1e2_damp1e3_cov95_kl1e4", 3 | "dataset": "cifar10", 4 | "optimizer": "kfac", 5 | "model_name": "resnet32", 6 | "use_kfac": true, 7 | "roger_init": true, 8 | "data_aug": false, 9 | "batch_norm": false, 10 | "use_bias": true, 11 | "use_fisher": false, 12 | "epoch": 100, 13 | "decay_every_itr": 16000, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.01, 19 | "momentum": 0.9, 20 | "weight_decay_type": "wd", 21 | "weight_decay": 0.0, 22 | "weight_list": "all", 23 | "damping": 1e-3, 24 | "cov_ema_decay": 0.95, 25 | "kl_clip": 1e-4 26 | } 27 | -------------------------------------------------------------------------------- /configs/cifar10/resnet32/kfac/mb128_lr3e2_damp1e3_cov95_bn.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr3e2_damp1e3_cov95_bn", 3 | "dataset": "cifar10", 4 | "optimizer": "kfac", 5 | "model_name": "resnet32", 6 | "use_kfac": true, 7 
| "roger_init": true, 8 | "data_aug": false, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": false, 12 | "epoch": 100, 13 | "decay_every_itr": 16000, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.03, 19 | "momentum": 0.9, 20 | "weight_decay_type": "wd", 21 | "weight_decay": 0.0, 22 | "weight_list": "all", 23 | "damping": 1e-3, 24 | "cov_ema_decay": 0.95 25 | } 26 | -------------------------------------------------------------------------------- /configs/cifar10/resnet32/kfac/mb128_lr3e2_damp1e3_cov95_wd1e2_all_bn.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr3e2_damp1e3_cov95_wd1e2_all_bn", 3 | "dataset": "cifar10", 4 | "optimizer": "kfac", 5 | "model_name": "resnet32", 6 | "use_kfac": true, 7 | "roger_init": true, 8 | "data_aug": false, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": false, 12 | "epoch": 100, 13 | "decay_every_itr": 16000, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.03, 19 | "momentum": 0.9, 20 | "weight_decay_type": "wd", 21 | "weight_decay": 1e-2, 22 | "weight_list": "all", 23 | "damping": 1e-3, 24 | "cov_ema_decay": 0.95 25 | } 26 | -------------------------------------------------------------------------------- /configs/cifar10/resnet32/kfac/mb128_lr3e2_damp1e3_cov95_wd3e2_all_bn_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr3e2_damp1e3_cov95_wd3e2_all_bn_aug", 3 | "dataset": "cifar10", 4 | "optimizer": "kfac", 5 | "model_name": "resnet32", 6 | "use_kfac": true, 7 | "roger_init": true, 8 | "data_aug": true, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": false, 12 | "epoch": 100, 13 | "decay_every_itr": 16000, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.03, 19 | "momentum": 0.9, 20 | "weight_decay_type": "wd", 21 | "weight_decay": 3e-2, 22 | "weight_list": "all", 23 | "damping": 1e-3, 24 | "cov_ema_decay": 0.95 25 | } 26 | -------------------------------------------------------------------------------- /configs/cifar10/resnet32/kfac/mb128_lr3e2_damp1e3_cov95_wd3e2_all_kl1e4.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr3e2_damp1e3_cov95_wd3e2_all_kl1e4", 3 | "dataset": "cifar10", 4 | "optimizer": "kfac", 5 | "model_name": "resnet32", 6 | "use_kfac": true, 7 | "roger_init": true, 8 | "data_aug": false, 9 | "batch_norm": false, 10 | "use_bias": true, 11 | "use_fisher": false, 12 | "epoch": 100, 13 | "decay_every_itr": 16000, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.03, 19 | "momentum": 0.9, 20 | "weight_decay_type": "wd", 21 | "weight_decay": 3e-2, 22 | "weight_list": "all", 23 | "damping": 1e-3, 24 | "cov_ema_decay": 0.95, 25 | "kl_clip": 1e-4 26 | } 27 | -------------------------------------------------------------------------------- /configs/cifar10/resnet32/kfacf/mb128_lr1e3_damp1e3_cov95_bn.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr1e3_damp1e3_cov95_bn", 3 | "dataset": "cifar10", 4 | "optimizer": "kfac", 5 | "model_name": "resnet32", 6 | "use_kfac": true, 7 | "roger_init": true, 8 | "data_aug": false, 9 | 
"batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": true, 12 | "epoch": 100, 13 | "decay_every_itr": 16000, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.001, 19 | "momentum": 0.9, 20 | "weight_decay_type": "l2", 21 | "weight_decay": 0.0, 22 | "weight_list": "all", 23 | "damping": 1e-3, 24 | "cov_ema_decay": 0.95 25 | } 26 | -------------------------------------------------------------------------------- /configs/cifar10/resnet32/kfacf/mb128_lr1e3_damp1e3_cov95_bn_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr1e3_damp1e3_cov95_bn_aug", 3 | "dataset": "cifar10", 4 | "optimizer": "kfac", 5 | "model_name": "resnet32", 6 | "use_kfac": true, 7 | "roger_init": true, 8 | "data_aug": true, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": true, 12 | "epoch": 100, 13 | "decay_every_itr": 16000, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.001, 19 | "momentum": 0.9, 20 | "weight_decay_type": "wd", 21 | "weight_decay": 0.0, 22 | "weight_list": "all", 23 | "damping": 1e-3, 24 | "cov_ema_decay": 0.95 25 | } 26 | -------------------------------------------------------------------------------- /configs/cifar10/resnet32/kfacf/mb128_lr1e3_damp1e3_cov95_wd1_all_bn.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr1e3_damp1e3_cov95_wd1_all_bn", 3 | "dataset": "cifar10", 4 | "optimizer": "kfac", 5 | "model_name": "resnet32", 6 | "use_kfac": true, 7 | "roger_init": true, 8 | "data_aug": false, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": true, 12 | "epoch": 100, 13 | "decay_every_itr": 16000, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.001, 19 | "momentum": 0.9, 20 | "weight_decay_type": "wd", 21 | "weight_decay": 1.0, 22 | "weight_list": "all", 23 | "damping": 1e-3, 24 | "cov_ema_decay": 0.95 25 | } 26 | -------------------------------------------------------------------------------- /configs/cifar10/resnet32/kfacf/mb128_lr1e3_damp1e3_cov95_wd1_all_kl1e5.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr1e3_damp1e3_cov95_wd1_all_kl1e5", 3 | "dataset": "cifar10", 4 | "optimizer": "kfac", 5 | "model_name": "resnet32", 6 | "use_kfac": true, 7 | "roger_init": true, 8 | "data_aug": false, 9 | "batch_norm": false, 10 | "use_bias": true, 11 | "use_fisher": true, 12 | "epoch": 100, 13 | "decay_every_itr": 16000, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.001, 19 | "momentum": 0.9, 20 | "weight_decay_type": "wd", 21 | "weight_decay": 1.0, 22 | "weight_list": "all", 23 | "damping": 1e-3, 24 | "cov_ema_decay": 0.95, 25 | "kl_clip": 1e-5 26 | } 27 | -------------------------------------------------------------------------------- /configs/cifar10/resnet32/kfacf/mb128_lr3e4_damp1e3_cov95_kl1e5.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr3e4_damp1e3_cov95_kl1e5", 3 | "dataset": "cifar10", 4 | "optimizer": "kfac", 5 | "model_name": "resnet32", 6 | "use_kfac": true, 7 | "roger_init": true, 8 | "data_aug": false, 9 | "batch_norm": false, 10 | "use_bias": true, 11 | "use_fisher": true, 12 
| "epoch": 100, 13 | "decay_every_itr": 16000, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.0003, 19 | "momentum": 0.9, 20 | "weight_decay_type": "l2", 21 | "weight_decay": 0.0, 22 | "weight_list": "all", 23 | "damping": 1e-3, 24 | "cov_ema_decay": 0.95, 25 | "kl_clip": 1e-5 26 | } 27 | -------------------------------------------------------------------------------- /configs/cifar10/resnet32/kfacf/mb128_lr3e4_damp1e3_cov95_wd1_all_bn_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr3e4_damp1e3_cov95_wd1_all_bn_aug", 3 | "dataset": "cifar10", 4 | "optimizer": "kfac", 5 | "model_name": "resnet32", 6 | "use_kfac": true, 7 | "roger_init": true, 8 | "data_aug": true, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": true, 12 | "epoch": 100, 13 | "decay_every_itr": 16000, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.0003, 19 | "momentum": 0.9, 20 | "weight_decay_type": "wd", 21 | "weight_decay": 1.0, 22 | "weight_list": "all", 23 | "damping": 1e-3, 24 | "cov_ema_decay": 0.95 25 | } 26 | -------------------------------------------------------------------------------- /configs/cifar10/resnet32/sgd/mb128_lr1e2_m9_wd3e3_all_bn_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size": 128, 3 | "decay_every_itr": 23500, 4 | "exp_name": "sgd_mb128_lr1e2_m9_wd3e3_l2_bias_BN_dataaug", 5 | "batch_norm": true, 6 | "model_name": "resnet32", 7 | "learning_rate": 0.01, 8 | "roger_init": false, 9 | "dataset": "cifar10", 10 | "momentum": 0.9, 11 | "use_bias": true, 12 | "use_fisher": false, 13 | "test_batch_size": 1000, 14 | "use_kfac": false, 15 | "num_workers": 2, 16 | "weight_decay": 0.003, 17 | "epoch": 200, 18 | "optimizer": "sgd", 19 | "data_aug": true, 20 | "max_to_keep": 5, 21 | "weight_list": "all", 22 | "weight_decay_type": "l2" 23 | } -------------------------------------------------------------------------------- /configs/cifar10/resnet32/sgd/mb128_lr3e2_m9_bn.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "cifar10", 3 | "use_bias": true, 4 | "use_kfac": false, 5 | "exp_name": "sgd_mb128_lr3e2_m9_wd0_l2_bias_BN", 6 | "test_batch_size": 1000, 7 | "epoch": 200, 8 | "momentum": 0.9, 9 | "optimizer": "sgd", 10 | "num_workers": 2, 11 | "weight_list": "all", 12 | "learning_rate": 0.03, 13 | "batch_size": 128, 14 | "decay_every_itr": 23500, 15 | "data_aug": false, 16 | "roger_init": false, 17 | "weight_decay_type": "l2", 18 | "max_to_keep": 5, 19 | "use_fisher": false, 20 | "weight_decay": 0, 21 | "model_name": "resnet32", 22 | "batch_norm": true 23 | } -------------------------------------------------------------------------------- /configs/cifar10/resnet32/sgd/mb128_lr3e2_m9_bn_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "cifar10", 3 | "use_bias": true, 4 | "use_kfac": false, 5 | "exp_name": "sgd_mb128_lr3e2_m9_wd0_l2_bias_BN_dataaug", 6 | "test_batch_size": 1000, 7 | "epoch": 200, 8 | "momentum": 0.9, 9 | "optimizer": "sgd", 10 | "num_workers": 2, 11 | "weight_list": "all", 12 | "learning_rate": 0.03, 13 | "batch_size": 128, 14 | "decay_every_itr": 23500, 15 | "data_aug": true, 16 | "roger_init": false, 17 | "weight_decay_type": "l2", 18 | "max_to_keep": 5, 19 | "use_fisher": 
false, 20 | "weight_decay": 0, 21 | "model_name": "resnet32", 22 | "batch_norm": true 23 | } -------------------------------------------------------------------------------- /configs/cifar10/resnet32/sgd/mb128_lr3e2_m9_wd3e3_all_bn.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "resnet32", 3 | "dataset": "cifar10", 4 | "batch_size": 128, 5 | "epoch": 200, 6 | "use_fisher": false, 7 | "momentum": 0.9, 8 | "learning_rate": 0.03, 9 | "use_bias": true, 10 | "use_kfac": false, 11 | "decay_every_itr": 23500, 12 | "data_aug": false, 13 | "roger_init": false, 14 | "weight_decay_type": "l2", 15 | "batch_norm": true, 16 | "test_batch_size": 1000, 17 | "max_to_keep": 5, 18 | "weight_list": "all", 19 | "weight_decay": 0.003, 20 | "num_workers": 2, 21 | "optimizer": "sgd", 22 | "exp_name": "sgd_mb128_lr3e2_m9_wd3e3_l2_bias_BN" 23 | } -------------------------------------------------------------------------------- /configs/cifar10/resnet32/sgd/mb128_lr3e3_m9.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "cifar10", 3 | "use_bias": true, 4 | "use_kfac": false, 5 | "exp_name": "sgd_mb128_lr3e3_m9_wd0_l2_bias", 6 | "test_batch_size": 1000, 7 | "epoch": 200, 8 | "momentum": 0.9, 9 | "optimizer": "sgd", 10 | "num_workers": 2, 11 | "weight_list": "all", 12 | "learning_rate": 0.003, 13 | "batch_size": 128, 14 | "decay_every_itr": 23500, 15 | "data_aug": false, 16 | "roger_init": false, 17 | "weight_decay_type": "l2", 18 | "max_to_keep": 5, 19 | "use_fisher": false, 20 | "weight_decay": 0, 21 | "model_name": "resnet32", 22 | "batch_norm": false 23 | } -------------------------------------------------------------------------------- /configs/cifar10/resnet32/sgd/mb128_lr3e3_m9_wd1e3_all.json: -------------------------------------------------------------------------------- 1 | { 2 | "roger_init": false, 3 | "momentum": 0.9, 4 | "batch_norm": false, 5 | "weight_decay": 0.001, 6 | "data_aug": false, 7 | "num_workers": 2, 8 | "batch_size": 128, 9 | "weight_decay_type": "l2", 10 | "use_kfac": false, 11 | "use_fisher": false, 12 | "dataset": "cifar10", 13 | "weight_list": "all", 14 | "model_name": "resnet32", 15 | "use_bias": true, 16 | "max_to_keep": 5, 17 | "optimizer": "sgd", 18 | "epoch": 200, 19 | "exp_name": "sgd_mb128_lr3e3_m9_wd1e3_l2_bias", 20 | "decay_every_itr": 23500, 21 | "test_batch_size": 1000, 22 | "learning_rate": 0.003 23 | } -------------------------------------------------------------------------------- /configs/cifar10/vgg16/adam/mb128_lr1e3_bn.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr1e3_bn", 3 | "dataset": "cifar10", 4 | "optimizer": "adam", 5 | "model_name": "vgg16", 6 | "use_kfac": false, 7 | "roger_init": false, 8 | "data_aug": false, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": false, 12 | "epoch": 200, 13 | "decay_every_itr": 23460, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.001, 19 | "momentum": 0.9, 20 | "weight_decay_type": "wd", 21 | "weight_decay": 0.0, 22 | "weight_list": "all" 23 | } 24 | -------------------------------------------------------------------------------- /configs/cifar10/vgg16/adam/mb128_lr1e3_bn_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr1e3_bn_aug", 3 | "dataset": "cifar10", 4 | "optimizer": 
"adam", 5 | "model_name": "vgg16", 6 | "use_kfac": false, 7 | "roger_init": false, 8 | "data_aug": true, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": false, 12 | "epoch": 200, 13 | "decay_every_itr": 23460, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.001, 19 | "momentum": 0.9, 20 | "weight_decay_type": "l2", 21 | "weight_decay": 0.0, 22 | "weight_list": "all" 23 | } 24 | -------------------------------------------------------------------------------- /configs/cifar10/vgg16/adam/mb128_lr1e3_wd3e1_all_bn.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr1e3_wd3e1_all_bn", 3 | "dataset": "cifar10", 4 | "optimizer": "adam", 5 | "model_name": "vgg16", 6 | "use_kfac": false, 7 | "roger_init": false, 8 | "data_aug": false, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": false, 12 | "epoch": 200, 13 | "decay_every_itr": 23460, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.001, 19 | "momentum": 0.9, 20 | "weight_decay_type": "wd", 21 | "weight_decay": 0.3, 22 | "weight_list": "all" 23 | } 24 | -------------------------------------------------------------------------------- /configs/cifar10/vgg16/adam/mb128_lr3e4.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr3e4", 3 | "dataset": "cifar10", 4 | "optimizer": "adam", 5 | "model_name": "vgg16", 6 | "use_kfac": false, 7 | "roger_init": false, 8 | "data_aug": false, 9 | "batch_norm": false, 10 | "use_bias": true, 11 | "use_fisher": false, 12 | "epoch": 200, 13 | "decay_every_itr": 23460, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.0003, 19 | "momentum": 0.9, 20 | "weight_decay_type": "wd", 21 | "weight_decay": 0.0, 22 | "weight_list": "all" 23 | } 24 | -------------------------------------------------------------------------------- /configs/cifar10/vgg16/adam/mb128_lr3e4_wd1_all_bn_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr3e4_wd1_all_bn_aug", 3 | "dataset": "cifar10", 4 | "optimizer": "adam", 5 | "model_name": "vgg16", 6 | "use_kfac": false, 7 | "roger_init": false, 8 | "data_aug": true, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": false, 12 | "epoch": 200, 13 | "decay_every_itr": 23460, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.0003, 19 | "momentum": 0.9, 20 | "weight_decay_type": "wd", 21 | "weight_decay": 1.0, 22 | "weight_list": "all" 23 | } 24 | -------------------------------------------------------------------------------- /configs/cifar10/vgg16/adam/mb128_lr3e4_wd3e1_all.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr3e4_wd3e1_all", 3 | "dataset": "cifar10", 4 | "optimizer": "adam", 5 | "model_name": "vgg16", 6 | "use_kfac": false, 7 | "roger_init": false, 8 | "data_aug": false, 9 | "batch_norm": false, 10 | "use_bias": true, 11 | "use_fisher": false, 12 | "epoch": 200, 13 | "decay_every_itr": 23460, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.0003, 19 | "momentum": 0.9, 20 | "weight_decay_type": "wd", 21 | 
"weight_decay": 0.3, 22 | "weight_list": "all" 23 | } 24 | -------------------------------------------------------------------------------- /configs/cifar10/vgg16/kfac/mb128_lr1e1_damp1e3_cov95_bn.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr1e1_damp1e3_cov95_bn", 3 | "dataset": "cifar10", 4 | "optimizer": "kfac", 5 | "model_name": "vgg16", 6 | "use_kfac": true, 7 | "roger_init": true, 8 | "data_aug": false, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": false, 12 | "epoch": 100, 13 | "decay_every_itr": 16000, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.1, 19 | "momentum": 0.9, 20 | "weight_decay_type": "wd", 21 | "weight_decay": 0.0, 22 | "weight_list": "all", 23 | "damping": 1e-3, 24 | "cov_ema_decay": 0.95 25 | } 26 | -------------------------------------------------------------------------------- /configs/cifar10/vgg16/kfac/mb128_lr3e2_damp1e2_cov95_kl1e4.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr3e2_damp1e2_cov95_kl1e4", 3 | "dataset": "cifar10", 4 | "optimizer": "kfac", 5 | "model_name": "vgg16", 6 | "use_kfac": true, 7 | "roger_init": true, 8 | "data_aug": false, 9 | "batch_norm": false, 10 | "use_bias": true, 11 | "use_fisher": false, 12 | "epoch": 100, 13 | "decay_every_itr": 16000, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.03, 19 | "momentum": 0.9, 20 | "weight_decay_type": "wd", 21 | "weight_decay": 0.0, 22 | "weight_list": "all", 23 | "damping": 1e-2, 24 | "cov_ema_decay": 0.95, 25 | "kl_clip": 1e-4 26 | } 27 | -------------------------------------------------------------------------------- /configs/cifar10/vgg16/kfac/mb128_lr3e2_damp1e2_cov95_wd3e2_all_kl1e3.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr3e2_damp1e2_cov95_wd3e2_all_kl1e3", 3 | "dataset": "cifar10", 4 | "optimizer": "kfac", 5 | "model_name": "vgg16", 6 | "use_kfac": true, 7 | "roger_init": true, 8 | "data_aug": false, 9 | "batch_norm": false, 10 | "use_bias": true, 11 | "use_fisher": false, 12 | "epoch": 100, 13 | "decay_every_itr": 16000, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.03, 19 | "momentum": 0.9, 20 | "weight_decay_type": "wd", 21 | "weight_decay": 3e-2, 22 | "weight_list": "all", 23 | "damping": 1e-2, 24 | "cov_ema_decay": 0.95, 25 | "kl_clip": 1e-3 26 | } 27 | -------------------------------------------------------------------------------- /configs/cifar10/vgg16/kfac/mb128_lr3e2_damp1e3_cov95_bn_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr3e2_damp1e3_cov95_bn_aug", 3 | "dataset": "cifar10", 4 | "optimizer": "kfac", 5 | "model_name": "vgg16", 6 | "use_kfac": true, 7 | "roger_init": true, 8 | "data_aug": true, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": false, 12 | "epoch": 100, 13 | "decay_every_itr": 16000, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.03, 19 | "momentum": 0.9, 20 | "weight_decay_type": "wd", 21 | "weight_decay": 0.0, 22 | "weight_list": "all", 23 | "damping": 1e-3, 24 | "cov_ema_decay": 0.95 25 | } 26 | 
-------------------------------------------------------------------------------- /configs/cifar10/vgg16/kfac/mb128_lr3e2_damp1e3_cov95_wd1e2_all_bn_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr3e2_damp1e3_cov95_wd1e2_all_bn_aug", 3 | "dataset": "cifar10", 4 | "optimizer": "kfac", 5 | "model_name": "vgg16", 6 | "use_kfac": true, 7 | "roger_init": true, 8 | "data_aug": true, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": false, 12 | "epoch": 100, 13 | "decay_every_itr": 16000, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.03, 19 | "momentum": 0.9, 20 | "weight_decay_type": "wd", 21 | "weight_decay": 1e-2, 22 | "weight_list": "all", 23 | "damping": 1e-3, 24 | "cov_ema_decay": 0.95 25 | } 26 | -------------------------------------------------------------------------------- /configs/cifar10/vgg16/kfac/mb128_lr3e2_damp1e3_cov95_wd3e2_all_bn.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr3e2_damp1e3_cov95_wd3e2_all_bn", 3 | "dataset": "cifar10", 4 | "optimizer": "kfac", 5 | "model_name": "vgg16", 6 | "use_kfac": true, 7 | "roger_init": true, 8 | "data_aug": false, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": false, 12 | "epoch": 100, 13 | "decay_every_itr": 16000, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.03, 19 | "momentum": 0.9, 20 | "weight_decay_type": "wd", 21 | "weight_decay": 3e-2, 22 | "weight_list": "all", 23 | "damping": 1e-3, 24 | "cov_ema_decay": 0.95 25 | } 26 | -------------------------------------------------------------------------------- /configs/cifar10/vgg16/kfacf/mb128_lr1e3_damp1e3_cov95_bn.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr1e3_damp1e3_cov95_bn", 3 | "dataset": "cifar10", 4 | "optimizer": "kfac", 5 | "model_name": "vgg16", 6 | "use_kfac": true, 7 | "roger_init": true, 8 | "data_aug": false, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": true, 12 | "epoch": 100, 13 | "decay_every_itr": 16000, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.001, 19 | "momentum": 0.9, 20 | "weight_decay_type": "l2", 21 | "weight_decay": 0.0, 22 | "weight_list": "all", 23 | "damping": 1e-3, 24 | "cov_ema_decay": 0.95 25 | } 26 | -------------------------------------------------------------------------------- /configs/cifar10/vgg16/kfacf/mb128_lr1e3_damp1e3_cov95_bn_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr1e3_damp1e3_cov95_bn_aug", 3 | "dataset": "cifar10", 4 | "optimizer": "kfac", 5 | "model_name": "vgg16", 6 | "use_kfac": true, 7 | "roger_init": true, 8 | "data_aug": true, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": true, 12 | "epoch": 100, 13 | "decay_every_itr": 16000, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.001, 19 | "momentum": 0.9, 20 | "weight_decay_type": "wd", 21 | "weight_decay": 0.0, 22 | "weight_list": "all", 23 | "damping": 1e-3, 24 | "cov_ema_decay": 0.95 25 | } 26 | -------------------------------------------------------------------------------- 
/configs/cifar10/vgg16/kfacf/mb128_lr1e3_damp1e3_cov95_wd1_all_kl1e5.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr1e3_damp1e3_cov95_wd1_all_kl1e5", 3 | "dataset": "cifar10", 4 | "optimizer": "kfac", 5 | "model_name": "vgg16", 6 | "use_kfac": true, 7 | "roger_init": true, 8 | "data_aug": false, 9 | "batch_norm": false, 10 | "use_bias": true, 11 | "use_fisher": true, 12 | "epoch": 100, 13 | "decay_every_itr": 16000, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.001, 19 | "momentum": 0.9, 20 | "weight_decay_type": "wd", 21 | "weight_decay": 1.0, 22 | "weight_list": "all", 23 | "damping": 1e-3, 24 | "cov_ema_decay": 0.95, 25 | "kl_clip": 1e-5 26 | } 27 | -------------------------------------------------------------------------------- /configs/cifar10/vgg16/kfacf/mb128_lr1e3_damp1e3_cov95_wd3e1_all_bn_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr1e3_damp1e3_cov95_wd3e1_all_bn_aug", 3 | "dataset": "cifar10", 4 | "optimizer": "kfac", 5 | "model_name": "vgg16", 6 | "use_kfac": true, 7 | "roger_init": true, 8 | "data_aug": true, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": true, 12 | "epoch": 100, 13 | "decay_every_itr": 16000, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.001, 19 | "momentum": 0.9, 20 | "weight_decay_type": "wd", 21 | "weight_decay": 0.3, 22 | "weight_list": "all", 23 | "damping": 1e-3, 24 | "cov_ema_decay": 0.95 25 | } 26 | -------------------------------------------------------------------------------- /configs/cifar10/vgg16/kfacf/mb128_lr3e3_damp1e3_cov95_wd1_all_bn.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr3e3_damp1e3_cov95_wd1_all_bn", 3 | "dataset": "cifar10", 4 | "optimizer": "kfac", 5 | "model_name": "vgg16", 6 | "use_kfac": true, 7 | "roger_init": true, 8 | "data_aug": false, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": true, 12 | "epoch": 100, 13 | "decay_every_itr": 16000, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.003, 19 | "momentum": 0.9, 20 | "weight_decay_type": "wd", 21 | "weight_decay": 1.0, 22 | "weight_list": "all", 23 | "damping": 1e-3, 24 | "cov_ema_decay": 0.95 25 | } 26 | -------------------------------------------------------------------------------- /configs/cifar10/vgg16/kfacf/mb128_lr3e4_damp1e3_cov95_kl1e5.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr3e4_damp1e3_cov95_kl1e5", 3 | "dataset": "cifar10", 4 | "optimizer": "kfac", 5 | "model_name": "vgg16", 6 | "use_kfac": true, 7 | "roger_init": true, 8 | "data_aug": false, 9 | "batch_norm": false, 10 | "use_bias": true, 11 | "use_fisher": true, 12 | "epoch": 100, 13 | "decay_every_itr": 16000, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.0003, 19 | "momentum": 0.9, 20 | "weight_decay_type": "l2", 21 | "weight_decay": 0.0, 22 | "weight_list": "all", 23 | "damping": 1e-3, 24 | "cov_ema_decay": 0.95, 25 | "kl_clip": 1e-5 26 | } 27 | -------------------------------------------------------------------------------- 
/configs/cifar10/vgg16/sgd/mb128_lr1e1_m9_wd1e3_all_bn.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_to_keep": 5, 3 | "exp_name": "sgd_mb128_lr1e1_m9_wd1e3_l2_bias_BNsgd-bn-bias", 4 | "momentum": 0.9, 5 | "learning_rate": 0.1, 6 | "batch_norm": true, 7 | "data_aug": false, 8 | "model_name": "vgg16", 9 | "weight_decay": 0.001, 10 | "dataset": "cifar10", 11 | "test_batch_size": 1000, 12 | "weight_decay_type": "l2", 13 | "epoch": 200, 14 | "use_bias": true, 15 | "use_kfac": false, 16 | "optimizer": "sgd", 17 | "num_workers": 2, 18 | "batch_size": 128, 19 | "_APPROX": "kron", 20 | "roger_init": false, 21 | "decay_every_itr": 23500, 22 | "weight_list": "all" 23 | } -------------------------------------------------------------------------------- /configs/cifar10/vgg16/sgd/mb128_lr1e1_m9_wd1e3_all_bn_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "use_bias": true, 3 | "epoch": 200, 4 | "exp_name": "sgd_mb128_lr1e1_m9_wd1e3_l2_bias_BN_dataaugsgd-bn-aug-bias", 5 | "use_kfac": false, 6 | "momentum": 0.9, 7 | "decay_every_itr": 23500, 8 | "dataset": "cifar10", 9 | "max_to_keep": 5, 10 | "num_workers": 2, 11 | "weight_decay_type": "l2", 12 | "batch_norm": true, 13 | "test_batch_size": 1000, 14 | "batch_size": 128, 15 | "model_name": "vgg16", 16 | "data_aug": true, 17 | "weight_decay": 0.001, 18 | "roger_init": false, 19 | "learning_rate": 0.1, 20 | "comment": "sgd-bn-aug-bias", 21 | "optimizer": "sgd", 22 | "weight_list": "all" 23 | } -------------------------------------------------------------------------------- /configs/cifar10/vgg16/sgd/mb128_lr1e2_m9_bn_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "cifar10", 3 | "use_bias": true, 4 | "use_kfac": false, 5 | "exp_name": "sgd_mb128_lr1e2_m9_wd0_l2_bias_BN_dataaugsgd-bn-aug-bias", 6 | "test_batch_size": 1000, 7 | "epoch": 200, 8 | "momentum": 0.9, 9 | "optimizer": "sgd", 10 | "num_workers": 2, 11 | "learning_rate": 0.01, 12 | "batch_size": 128, 13 | "decay_every_itr": 23500, 14 | "gn_weight": 0, 15 | "data_aug": true, 16 | "roger_init": false, 17 | "weight_decay_type": "l2", 18 | "max_to_keep": 5, 19 | "weight_decay": 0, 20 | "model_name": "vgg16", 21 | "batch_norm": true, 22 | "weight_list": "all" 23 | } -------------------------------------------------------------------------------- /configs/cifar10/vgg16/sgd/mb128_lr3e2_m9.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "cifar10", 3 | "use_bias": true, 4 | "cov_ema_decay": 0, 5 | "use_kfac": false, 6 | "exp_name": "sgd_mb128_lr3e2_m9_wd0_l2_biassgd-bias", 7 | "test_batch_size": 1000, 8 | "epoch": 200, 9 | "momentum": 0.9, 10 | "optimizer": "sgd", 11 | "num_workers": 2, 12 | "learning_rate": 0.03, 13 | "batch_size": 128, 14 | "decay_every_itr": 23500, 15 | "data_aug": false, 16 | "roger_init": false, 17 | "weight_decay_type": "l2", 18 | "max_to_keep": 5, 19 | "weight_decay": 0, 20 | "model_name": "vgg16", 21 | "batch_norm": false, 22 | "weight_list": "all" 23 | } -------------------------------------------------------------------------------- /configs/cifar10/vgg16/sgd/mb128_lr3e2_m9_bn.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "cifar10", 3 | "use_bias": true, 4 | "use_kfac": false, 5 | "exp_name": "sgd_mb128_lr3e2_m9_wd0_l2_bias_BNsgd-bn-bias", 6 | "test_batch_size": 1000, 7 | "epoch": 200, 8 | 
"momentum": 0.9, 9 | "optimizer": "sgd", 10 | "num_workers": 2, 11 | "learning_rate": 0.03, 12 | "batch_size": 128, 13 | "decay_every_itr": 23500, 14 | "gn_weight": 0, 15 | "data_aug": false, 16 | "roger_init": false, 17 | "weight_decay_type": "l2", 18 | "max_to_keep": 5, 19 | "weight_decay": 0, 20 | "model_name": "vgg16", 21 | "batch_norm": true, 22 | "weight_list": "all" 23 | } -------------------------------------------------------------------------------- /configs/cifar10/vgg16/sgd/mb128_lr3e2_m9_wd3e4_all.json: -------------------------------------------------------------------------------- 1 | { 2 | "weight_decay": 0.0003, 3 | "weight_decay_type": "l2", 4 | "roger_init": false, 5 | "decay_every_itr": 23500, 6 | "optimizer": "sgd", 7 | "data_aug": false, 8 | "max_to_keep": 5, 9 | "learning_rate": 0.03, 10 | "test_batch_size": 1000, 11 | "num_workers": 2, 12 | "cov_ema_decay": 0, 13 | "exp_name": "sgd_mb128_lr3e2_m9_wd3e4_l2_biassgd-bias", 14 | "batch_norm": false, 15 | "momentum": 0.9, 16 | "use_kfac": false, 17 | "batch_size": 128, 18 | "use_bias": true, 19 | "epoch": 200, 20 | "model_name": "vgg16", 21 | "dataset": "cifar10", 22 | "weight_list": "all" 23 | } -------------------------------------------------------------------------------- /configs/cifar100/resnet32/adam/mb128_lr3e4_bn_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr3e4_bn_aug", 3 | "dataset": "cifar100", 4 | "optimizer": "adam", 5 | "model_name": "resnet32", 6 | "use_kfac": false, 7 | "roger_init": false, 8 | "data_aug": true, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": false, 12 | "epoch": 200, 13 | "decay_every_itr": 23460, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.0003, 19 | "momentum": 0.9, 20 | "weight_decay_type": "wd", 21 | "weight_decay": 0.0, 22 | "weight_list": "all" 23 | } 24 | -------------------------------------------------------------------------------- /configs/cifar100/resnet32/adam/mb128_lr3e4_wd1_all_bn_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr3e4_wd1_all_bn_aug", 3 | "dataset": "cifar100", 4 | "optimizer": "adam", 5 | "model_name": "resnet32", 6 | "use_kfac": false, 7 | "roger_init": false, 8 | "data_aug": true, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": false, 12 | "epoch": 200, 13 | "decay_every_itr": 23460, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.0003, 19 | "momentum": 0.9, 20 | "weight_decay_type": "wd", 21 | "weight_decay": 1.0, 22 | "weight_list": "all" 23 | } 24 | -------------------------------------------------------------------------------- /configs/cifar100/resnet32/kfac/mb128_lr5e2_damp1e3_cov95_bn_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr5e2_damp1e3_cov95_bn_aug", 3 | "dataset": "cifar100", 4 | "optimizer": "kfac", 5 | "model_name": "resnet32", 6 | "use_kfac": true, 7 | "roger_init": true, 8 | "data_aug": true, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": false, 12 | "epoch": 100, 13 | "decay_every_itr": 16000, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.05, 19 | "momentum": 0.9, 20 | "weight_decay_type": "l2", 21 | "weight_decay": 0.0, 22 | 
"weight_list": "all", 23 | "damping": 1e-3, 24 | "cov_ema_decay": 0.95 25 | } 26 | -------------------------------------------------------------------------------- /configs/cifar100/resnet32/kfac/mb128_lr5e2_damp1e3_cov95_wd1e2_all_bn_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr5e2_damp1e3_cov95_wd1e2_all_bn_aug", 3 | "dataset": "cifar100", 4 | "optimizer": "kfac", 5 | "model_name": "resnet32", 6 | "use_kfac": true, 7 | "roger_init": true, 8 | "data_aug": true, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": false, 12 | "epoch": 100, 13 | "decay_every_itr": 16000, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.05, 19 | "momentum": 0.9, 20 | "weight_decay_type": "wd", 21 | "weight_decay": 1e-2, 22 | "weight_list": "all", 23 | "damping": 1e-3, 24 | "cov_ema_decay": 0.95 25 | } 26 | -------------------------------------------------------------------------------- /configs/cifar100/resnet32/kfacf/mb128_lr1e3_damp1e3_cov95_bn_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr1e3_damp1e3_cov95_bn_aug", 3 | "dataset": "cifar100", 4 | "optimizer": "kfac", 5 | "model_name": "resnet32", 6 | "use_kfac": true, 7 | "roger_init": true, 8 | "data_aug": true, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": true, 12 | "epoch": 100, 13 | "decay_every_itr": 16000, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.001, 19 | "momentum": 0.9, 20 | "weight_decay_type": "wd", 21 | "weight_decay": 0.0, 22 | "weight_list": "all", 23 | "damping": 1e-3, 24 | "cov_ema_decay": 0.95 25 | } 26 | -------------------------------------------------------------------------------- /configs/cifar100/resnet32/kfacf/mb128_lr1e3_damp1e3_cov95_wd1_all_bn_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr1e3_damp1e3_cov95_wd1_all_bn_aug", 3 | "dataset": "cifar100", 4 | "optimizer": "kfac", 5 | "model_name": "resnet32", 6 | "use_kfac": true, 7 | "roger_init": true, 8 | "data_aug": true, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": true, 12 | "epoch": 100, 13 | "decay_every_itr": 16000, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.001, 19 | "momentum": 0.9, 20 | "weight_decay_type": "wd", 21 | "weight_decay": 1.0, 22 | "weight_list": "all", 23 | "damping": 1e-3, 24 | "cov_ema_decay": 0.95 25 | } 26 | -------------------------------------------------------------------------------- /configs/cifar100/resnet32/sgd/mb128_lr3e2_bn_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr3e2_bn_aug_rerun", 3 | "dataset": "cifar100", 4 | "optimizer": "sgd", 5 | "model_name": "resnet32", 6 | "use_kfac": false, 7 | "roger_init": false, 8 | "data_aug": true, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": false, 12 | "epoch": 200, 13 | "decay_every_itr": 100000, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.03, 19 | "momentum": 0.9, 20 | "weight_decay_type": "l2", 21 | "weight_decay": 0.0, 22 | "weight_list": "all" 23 | } 24 | 
-------------------------------------------------------------------------------- /configs/cifar100/resnet32/sgd/mb128_lr3e2_wd3e3_all_bn_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr3e2_wd3e3_all_bn_aug_rerun", 3 | "dataset": "cifar100", 4 | "optimizer": "sgd", 5 | "model_name": "resnet32", 6 | "use_kfac": false, 7 | "roger_init": false, 8 | "data_aug": true, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": false, 12 | "epoch": 200, 13 | "decay_every_itr": 100000, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.03, 19 | "momentum": 0.9, 20 | "weight_decay_type": "l2", 21 | "weight_decay": 3e-3, 22 | "weight_list": "all" 23 | } 24 | -------------------------------------------------------------------------------- /configs/cifar100/vgg16/adam/mb128_lr3e4_bn_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr3e4_bn_aug", 3 | "dataset": "cifar100", 4 | "optimizer": "adam", 5 | "model_name": "vgg16", 6 | "use_kfac": false, 7 | "roger_init": false, 8 | "data_aug": true, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": false, 12 | "epoch": 200, 13 | "decay_every_itr": 23460, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.0003, 19 | "momentum": 0.9, 20 | "weight_decay_type": "wd", 21 | "weight_decay": 0.0, 22 | "weight_list": "all" 23 | } 24 | -------------------------------------------------------------------------------- /configs/cifar100/vgg16/adam/mb128_lr3e4_wd1_all_bn_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr3e4_wd1_all_bn_aug", 3 | "dataset": "cifar100", 4 | "optimizer": "adam", 5 | "model_name": "vgg16", 6 | "use_kfac": false, 7 | "roger_init": false, 8 | "data_aug": true, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": false, 12 | "epoch": 200, 13 | "decay_every_itr": 23460, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.0003, 19 | "momentum": 0.9, 20 | "weight_decay_type": "wd", 21 | "weight_decay": 1.0, 22 | "weight_list": "all" 23 | } 24 | -------------------------------------------------------------------------------- /configs/cifar100/vgg16/kfac/mb128_lr1e1_damp1e3_cov95_bn_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr1e1_damp1e3_cov95_bn_aug", 3 | "dataset": "cifar100", 4 | "optimizer": "kfac", 5 | "model_name": "vgg16", 6 | "use_kfac": true, 7 | "roger_init": true, 8 | "data_aug": true, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": false, 12 | "epoch": 100, 13 | "decay_every_itr": 16000, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.1, 19 | "momentum": 0.9, 20 | "weight_decay_type": "l2", 21 | "weight_decay": 0.0, 22 | "weight_list": "all", 23 | "damping": 1e-3, 24 | "cov_ema_decay": 0.95 25 | } 26 | -------------------------------------------------------------------------------- /configs/cifar100/vgg16/kfac/mb128_lr1e1_damp1e3_cov95_wd1e2_all_bn_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr1e1_damp1e3_cov95_wd1e2_all_bn_aug", 3 | 
"dataset": "cifar100", 4 | "optimizer": "kfac", 5 | "model_name": "vgg16", 6 | "use_kfac": true, 7 | "roger_init": true, 8 | "data_aug": true, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": false, 12 | "epoch": 100, 13 | "decay_every_itr": 16000, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.1, 19 | "momentum": 0.9, 20 | "weight_decay_type": "wd", 21 | "weight_decay": 1e-2, 22 | "weight_list": "all", 23 | "damping": 1e-3, 24 | "cov_ema_decay": 0.95 25 | } 26 | -------------------------------------------------------------------------------- /configs/cifar100/vgg16/kfacf/mb128_lr1e3_damp1e3_cov95_wd1_all_bn_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr1e3_damp1e3_cov95_wd1_all_bn_aug_rerun", 3 | "dataset": "cifar100", 4 | "optimizer": "kfac", 5 | "model_name": "vgg16", 6 | "use_kfac": true, 7 | "roger_init": true, 8 | "data_aug": true, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": true, 12 | "epoch": 100, 13 | "decay_every_itr": 16000, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.001, 19 | "momentum": 0.9, 20 | "weight_decay_type": "wd", 21 | "weight_decay": 1.0, 22 | "weight_list": "all", 23 | "damping": 1e-3, 24 | "cov_ema_decay": 0.95 25 | } 26 | -------------------------------------------------------------------------------- /configs/cifar100/vgg16/kfacf/mb128_lr3e3_damp1e3_cov95_bn_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr3e3_damp1e3_cov95_bn_aug_rerun", 3 | "dataset": "cifar100", 4 | "optimizer": "kfac", 5 | "model_name": "vgg16", 6 | "use_kfac": true, 7 | "roger_init": true, 8 | "data_aug": true, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": true, 12 | "epoch": 100, 13 | "decay_every_itr": 16000, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.003, 19 | "momentum": 0.9, 20 | "weight_decay_type": "l2", 21 | "weight_decay": 0.0, 22 | "weight_list": "all", 23 | "damping": 1e-3, 24 | "cov_ema_decay": 0.95 25 | } 26 | -------------------------------------------------------------------------------- /configs/cifar100/vgg16/sgd/mb128_lr3e2_bn_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr3e2_bn_aug", 3 | "dataset": "cifar100", 4 | "optimizer": "sgd", 5 | "model_name": "vgg16", 6 | "use_kfac": false, 7 | "roger_init": false, 8 | "data_aug": true, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | "use_fisher": false, 12 | "epoch": 200, 13 | "decay_every_itr": 23460, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.03, 19 | "momentum": 0.9, 20 | "weight_decay_type": "l2", 21 | "weight_decay": 0.0, 22 | "weight_list": "all" 23 | } 24 | -------------------------------------------------------------------------------- /configs/cifar100/vgg16/sgd/mb128_lr3e2_wd3e3_all_bn_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "mb128_lr3e2_wd3e3_all_bn_aug", 3 | "dataset": "cifar100", 4 | "optimizer": "sgd", 5 | "model_name": "vgg16", 6 | "use_kfac": false, 7 | "roger_init": false, 8 | "data_aug": true, 9 | "batch_norm": true, 10 | "use_bias": true, 11 | 
"use_fisher": false, 12 | "epoch": 200, 13 | "decay_every_itr": 23460, 14 | "batch_size": 128, 15 | "test_batch_size": 1000, 16 | "max_to_keep": 5, 17 | "num_workers": 2, 18 | "learning_rate": 0.03, 19 | "momentum": 0.9, 20 | "weight_decay_type": "l2", 21 | "weight_decay": 3e-3, 22 | "weight_list": "all" 23 | } 24 | -------------------------------------------------------------------------------- /core/base_model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class BaseModel: 5 | def __init__(self, config): 6 | self.config = config 7 | # init the global step 8 | self.init_global_step() 9 | 10 | # save function that saves the checkpoint in the path defined in the config file 11 | def save(self, sess): 12 | print("Saving model...") 13 | self.saver.save(sess, self.config.checkpoint_dir, self.global_step_tensor) 14 | print("Model saved") 15 | 16 | # load latest checkpoint from the experiment path defined in the config file 17 | def load(self, sess): 18 | latest_checkpoint = tf.train.latest_checkpoint(self.config.checkpoint_dir) 19 | if latest_checkpoint: 20 | print("Loading model checkpoint {} ...\n".format(latest_checkpoint)) 21 | self.saver.restore(sess, latest_checkpoint) 22 | print("Model loaded") 23 | 24 | # just initialize a tensorflow variable to use it as global step counter 25 | def init_global_step(self): 26 | # DON'T forget to add the global step tensor to the tensorflow trainer 27 | with tf.variable_scope('global_step'): 28 | self.global_step_tensor = tf.Variable(0, trainable=False, name='global_step') 29 | 30 | def init_saver(self): 31 | # just copy the following line in your child class 32 | # self.saver = tf.train.Saver(max_to_keep=self.config.max_to_keep) 33 | raise NotImplementedError 34 | 35 | def build_model(self): 36 | raise NotImplementedError 37 | -------------------------------------------------------------------------------- /core/base_train.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from misc.summarizer import Summarizer 3 | 4 | 5 | class BaseTrain: 6 | def __init__(self, sess, model, config, logger): 7 | self.model = model 8 | self.logger = logger 9 | if logger is not None: 10 | self.summarizer = Summarizer(sess, config) 11 | self.config = config 12 | self.sess = sess 13 | self.init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) 14 | self.sess.run(self.init) 15 | 16 | def train_epoch(self): 17 | """ 18 | implement the logic of epoch: 19 | -loop over the number of iterations in the config and call the train step 20 | -add any summaries you want using the summary 21 | """ 22 | raise NotImplementedError 23 | 24 | def test_epoch(self): 25 | """ 26 | implement the logic of the train step 27 | - run the tensorflow session 28 | - return any metrics you need to summarize 29 | """ 30 | raise NotImplementedError 31 | -------------------------------------------------------------------------------- /core/model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from libs.kfac import optimizer as kfac_opt 4 | from libs.kfac import layer_collection as lc 5 | from libs.sgd import optimizer as sgd_opt 6 | from libs.adam import optimizer as adam_opt 7 | from misc.utils import flatten, unflatten 8 | from network.registry import get_model 9 | from core.base_model import BaseModel 10 | 11 | 12 | class Model(BaseModel): 13 | def 
__init__(self, config, sess): 14 | super().__init__(config) 15 | self.sess = sess 16 | self.build_model() 17 | self.init_optim() 18 | self.init_saver() 19 | 20 | @property 21 | def params_net(self): 22 | return tf.trainable_variables('network') 23 | 24 | @property 25 | def params_all(self): 26 | return tf.global_variables() 27 | 28 | @property 29 | def params_w_flatten(self): 30 | return flatten(self.params_net) 31 | 32 | @property 33 | def params_w_flatten_last(self): 34 | return flatten(self.params_net[-2:]) 35 | 36 | def init_saver(self): 37 | self.saver = tf.train.Saver(var_list=self.params_all, 38 | max_to_keep=self.config.max_to_keep) 39 | 40 | def build_model(self): 41 | self.inputs = tf.placeholder(tf.float32, [None] + self.config.input_dim) 42 | self.targets = tf.placeholder(tf.int64, [None]) 43 | self.is_training = tf.placeholder(tf.bool) 44 | if self.config.use_kfac: 45 | self.layer_collection = lc.LayerCollection() 46 | else: 47 | self.layer_collection = None 48 | self.cov_update_op = None 49 | self.inv_update_op = None 50 | 51 | inputs = self.inputs 52 | with tf.variable_scope("network"): 53 | net = get_model(self.config.model_name) 54 | logits = net(inputs, self.is_training, self.config, self.layer_collection) 55 | 56 | self.acc = tf.reduce_mean(tf.cast(tf.equal( 57 | self.targets, tf.argmax(logits, axis=1)), dtype=tf.float32)) 58 | 59 | self.loss = tf.reduce_mean( 60 | tf.nn.sparse_softmax_cross_entropy_with_logits( 61 | labels=self.targets, logits=logits)) 62 | 63 | self.l2_norm = tf.reduce_sum(tf.square(self.params_w_flatten)) 64 | 65 | def init_optim(self): 66 | if self.config.optimizer == "sgd": 67 | self.optim = sgd_opt.SGDOptimizer( 68 | tf.train.exponential_decay(self.config.learning_rate, 69 | self.global_step_tensor, 70 | self.config.decay_every_itr, 0.1, staircase=True), 71 | momentum=self.config.momentum, 72 | weight_decay=self.config.weight_decay, 73 | weight_decay_type=self.config.weight_decay_type, 74 | weight_list=self.config.weight_list) 75 | 76 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 77 | with tf.control_dependencies(update_ops): 78 | self.train_op = self.optim.minimize(self.loss, global_step=self.global_step_tensor) 79 | elif self.config.optimizer == "adam": 80 | self.optim = adam_opt.ADAMOptimizer( 81 | tf.train.exponential_decay(self.config.learning_rate, 82 | self.global_step_tensor, 83 | self.config.decay_every_itr, 0.1, staircase=True), 84 | weight_decay=self.config.weight_decay, 85 | weight_decay_type=self.config.weight_decay_type, 86 | weight_list=self.config.weight_list) 87 | 88 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 89 | with tf.control_dependencies(update_ops): 90 | self.train_op = self.optim.minimize(self.loss, global_step=self.global_step_tensor) 91 | else: 92 | kl_clip = self.config.get("kl_clip", None) 93 | self.optim = kfac_opt.KFACOptimizer( 94 | tf.train.exponential_decay(self.config.learning_rate, 95 | self.global_step_tensor, 96 | self.config.decay_every_itr, 0.1, staircase=True), 97 | cov_ema_decay=self.config.cov_ema_decay, 98 | damping=self.config.damping, 99 | layer_collection=self.layer_collection, 100 | norm_constraint=kl_clip, 101 | momentum=self.config.momentum, 102 | weight_decay=self.config.weight_decay, 103 | weight_decay_type=self.config.weight_decay_type, 104 | weight_list=self.config.weight_list) 105 | 106 | self.cov_update_op = self.optim.cov_update_op 107 | self.inv_update_op = self.optim.inv_update_op 108 | 109 | self.update_ops = update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 
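# Note: tf.GraphKeys.UPDATE_OPS typically holds the batch-norm moving-average
# update ops created when "batch_norm" is enabled in the config; wrapping
# minimize() in the control_dependencies block below runs those updates on
# every training step, mirroring the SGD and Adam branches above.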
110 | with tf.control_dependencies(update_ops): 111 | self.train_op = self.optim.minimize(self.loss, global_step=self.global_step_tensor) 112 | -------------------------------------------------------------------------------- /core/train.py: -------------------------------------------------------------------------------- 1 | from core.base_train import BaseTrain 2 | from tqdm import tqdm 3 | from misc.utils import SetFromFlat, GetFlat, unflatten, flatten, numel 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | 8 | class Trainer(BaseTrain): 9 | def __init__(self, sess, model, train_loader, test_loader, config, logger): 10 | super(Trainer, self).__init__(sess, model, config, logger) 11 | self.train_loader = train_loader 12 | self.test_loader = test_loader 13 | 14 | self.get_params = GetFlat(self.sess, self.model.params_net) 15 | self.set_params = SetFromFlat(self.sess, self.model.params_net) 16 | self.unflatten = unflatten(self.model.params_net) 17 | self.norm_list = [] 18 | 19 | self.summary_op = tf.summary.merge_all() 20 | 21 | def init_kfac(self): 22 | self.logger.info('Roger Initialization!') 23 | self.model.optim._fisher_est.reset(self.sess) 24 | 25 | for itr, (x, y) in enumerate(self.train_loader): 26 | feed_dict = { 27 | self.model.inputs: x, 28 | # self.model.targets: y, 29 | self.model.is_training: True 30 | } 31 | self.sess.run(self.model.optim.init_cov_op, feed_dict=feed_dict) 32 | self.model.optim._fisher_est.rescale(self.sess, 1. / len(self.train_loader)) 33 | 34 | # inverse 35 | if self.model.inv_update_op is not None: 36 | self.sess.run(self.model.inv_update_op) 37 | 38 | self.logger.info('Done Roger Initialization!') 39 | 40 | def train(self): 41 | if self.config.roger_init: 42 | self.init_kfac() 43 | for cur_epoch in range(self.config.epoch): 44 | self.logger.info('epoch: {}'.format(int(cur_epoch))) 45 | self.train_epoch() 46 | self.test_epoch() 47 | 48 | if cur_epoch % 100 == 0: 49 | self.model.save(self.sess) 50 | 51 | def train_epoch(self): 52 | loss_list = [] 53 | acc_list = [] 54 | 55 | for itr, (x, y) in enumerate(tqdm(self.train_loader)): 56 | feed_dict = { 57 | self.model.inputs: x, 58 | self.model.targets: y, 59 | self.model.is_training: True, 60 | } 61 | self.sess.run(self.model.train_op, feed_dict=feed_dict) 62 | cur_iter = self.model.global_step_tensor.eval(self.sess) 63 | 64 | if cur_iter % self.config.get('TCov', 10) == 0 and self.model.cov_update_op is not None: 65 | self.sess.run(self.model.cov_update_op, feed_dict=feed_dict) 66 | 67 | if cur_iter % self.config.get('TInv', 100) == 0 and self.model.inv_update_op is not None: 68 | self.sess.run(self.model.inv_update_op) 69 | 70 | for itr, (x, y) in enumerate(self.train_loader): 71 | feed_dict = { 72 | self.model.inputs: x, 73 | self.model.targets: y, 74 | self.model.is_training: True 75 | } 76 | 77 | loss, acc = self.sess.run( 78 | [self.model.loss, self.model.acc], 79 | feed_dict=feed_dict) 80 | loss_list.append(loss) 81 | acc_list.append(acc) 82 | 83 | avg_loss = np.mean(loss_list) 84 | avg_acc = np.mean(acc_list) 85 | self.logger.info("[Train] loss: %5.4f | accuracy: %5.4f"%(float(avg_loss), float(avg_acc))) 86 | 87 | l2_norm = self.sess.run(self.model.l2_norm) 88 | self.logger.info("l2_norm: %5.4f"%(float(l2_norm))) 89 | 90 | # summarize 91 | summaries_dict = dict() 92 | summaries_dict['train_loss'] = avg_loss 93 | summaries_dict['train_acc'] = avg_acc 94 | summaries_dict['l2_norm'] = l2_norm 95 | 96 | # summarize 97 | cur_iter = self.model.global_step_tensor.eval(self.sess) 98 | 
self.summarizer.summarize(cur_iter, summaries_dict=summaries_dict) 99 | 100 | def test_epoch(self): 101 | loss_list = [] 102 | acc_list = [] 103 | for (x, y) in self.test_loader: 104 | feed_dict = { 105 | self.model.inputs: x, 106 | self.model.targets: y, 107 | self.model.is_training: False 108 | } 109 | loss, acc = self.sess.run([self.model.loss, self.model.acc], feed_dict=feed_dict) 110 | loss_list.append(loss) 111 | acc_list.append(acc) 112 | 113 | avg_loss = np.mean(loss_list) 114 | avg_acc = np.mean(acc_list) 115 | self.logger.info("[Test] loss: %5.4f | accuracy: %5.4f"%(float(avg_loss), float(avg_acc))) 116 | 117 | # summarize 118 | summaries_dict = dict() 119 | summaries_dict['test_loss'] = avg_loss 120 | summaries_dict['test_acc'] = avg_acc 121 | 122 | # summarize 123 | cur_iter = self.model.global_step_tensor.eval(self.sess) 124 | self.summarizer.summarize(cur_iter, summaries_dict=summaries_dict) 125 | 126 | 127 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | 6 | data_path = '../data' 7 | 8 | 9 | class Flatten(object): 10 | def __call__(self, tensor): 11 | return tensor.view(-1) 12 | 13 | def __repr__(self): 14 | return self.__class__.__name__ 15 | 16 | 17 | class Transpose(object): 18 | def __call__(self, tensor): 19 | return tensor.permute(1, 2, 0) 20 | 21 | def __repr__(self): 22 | return self.__class__.__name__ 23 | 24 | 25 | def load_pytorch(config): 26 | if config.dataset == 'cifar10': 27 | if config.data_aug: 28 | train_transform = transforms.Compose([ 29 | transforms.RandomCrop(32, padding=4), 30 | transforms.RandomHorizontalFlip(), 31 | transforms.ToTensor(), 32 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 33 | Transpose() 34 | ]) 35 | else: 36 | train_transform = transforms.Compose([ 37 | transforms.ToTensor(), 38 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 39 | Transpose() 40 | ]) 41 | test_transform = transforms.Compose([ 42 | transforms.ToTensor(), 43 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 44 | Transpose() 45 | ]) 46 | 47 | trainset = torchvision.datasets.CIFAR10(root=data_path, train=True, download=True, transform=train_transform) 48 | testset = torchvision.datasets.CIFAR10(root=data_path, train=False, download=True, transform=test_transform) 49 | elif config.dataset == 'cifar100': 50 | if config.data_aug: 51 | train_transform = transforms.Compose([ 52 | transforms.RandomCrop(32, padding=4), 53 | transforms.RandomHorizontalFlip(), 54 | transforms.ToTensor(), 55 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), 56 | Transpose() 57 | ]) 58 | else: 59 | train_transform = transforms.Compose([ 60 | transforms.ToTensor(), 61 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), 62 | Transpose() 63 | ]) 64 | test_transform = transforms.Compose([ 65 | transforms.ToTensor(), 66 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), 67 | Transpose() 68 | ]) 69 | trainset = torchvision.datasets.CIFAR100(root=data_path, train=True, download=True, transform=train_transform) 70 | testset = torchvision.datasets.CIFAR100(root=data_path, train=False, download=True, transform=test_transform) 71 | elif config.dataset == 'mnist': 72 | transform = transforms.Compose([ 73 | 
transforms.ToTensor(), 74 | Flatten(), 75 | ]) 76 | trainset = torchvision.datasets.MNIST(root=data_path, train=True, download=True, transform=transform) 77 | testset = torchvision.datasets.MNIST(root=data_path, train=False, download=True, transform=transform) 78 | elif config.dataset == 'fmnist': 79 | transform = transforms.Compose([ 80 | transforms.ToTensor(), 81 | Flatten(), 82 | ]) 83 | trainset = torchvision.datasets.FashionMNIST(root=data_path, train=True, download=True, transform=transform) 84 | testset = torchvision.datasets.FashionMNIST(root=data_path, train=False, download=True, transform=transform) 85 | else: 86 | raise ValueError("Unsupported dataset!") 87 | 88 | trainloader = torch.utils.data.DataLoader(trainset, 89 | batch_size=config.batch_size, 90 | shuffle=True, 91 | num_workers=config.num_workers) 92 | testloader = torch.utils.data.DataLoader(testset, 93 | batch_size=config.test_batch_size, 94 | shuffle=False, 95 | num_workers=config.num_workers) 96 | return trainloader, testloader 97 | -------------------------------------------------------------------------------- /libs/adam/optimizer.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from tensorflow.python.framework import ops 6 | from tensorflow.python.ops import array_ops 7 | from tensorflow.python.ops import gen_array_ops 8 | from tensorflow.python.ops import math_ops 9 | from tensorflow.python.ops import variable_scope 10 | from tensorflow.python.ops import variables as tf_variables 11 | from tensorflow.python.training import gradient_descent 12 | 13 | 14 | class ADAMOptimizer(gradient_descent.GradientDescentOptimizer): 15 | """ 16 | ADAM Optimizer 17 | """ 18 | 19 | def __init__(self, 20 | learning_rate, 21 | beta1=0.9, 22 | beta2=0.999, 23 | var_list=None, 24 | epsilon=1e-8, 25 | weight_decay=0., 26 | weight_decay_type="l2", 27 | weight_list="all", 28 | name="ADAM"): 29 | 30 | variables = var_list 31 | if variables is None: 32 | variables = tf_variables.trainable_variables() 33 | self.variables = variables 34 | 35 | weight_decay_type = weight_decay_type.lower() 36 | legal_weight_decay_types = ["wd", "l2"] 37 | 38 | if weight_decay_type not in legal_weight_decay_types: 39 | raise ValueError("Unsupported weight decay type {}. Must be one of {}." 
40 | .format(weight_decay_type, legal_weight_decay_types)) 41 | 42 | self._beta1 = beta1 43 | self._beta2 = beta2 44 | self._epsilon = epsilon 45 | self._weight_decay = weight_decay 46 | self._weight_decay_type = weight_decay_type 47 | self._weight_list = weight_list 48 | self._init() 49 | 50 | super(ADAMOptimizer, self).__init__(learning_rate, name=name) 51 | 52 | def minimize(self, *args, **kwargs): 53 | kwargs["var_list"] = kwargs.get("var_list") or self.variables 54 | if set(kwargs["var_list"]) != set(self.variables): 55 | raise ValueError("var_list doesn't match with set of Fisher-estimating " 56 | "variables.") 57 | return super(ADAMOptimizer, self).minimize(*args, **kwargs) 58 | 59 | def compute_gradients(self, *args, **kwargs): 60 | # args[1] could be our var_list 61 | if len(args) > 1: 62 | var_list = args[1] 63 | else: 64 | kwargs["var_list"] = kwargs.get("var_list") or self.variables 65 | var_list = kwargs["var_list"] 66 | if set(var_list) != set(self.variables): 67 | raise ValueError("var_list doesn't match with set of Fisher-estimating " 68 | "variables.") 69 | return super(ADAMOptimizer, self).compute_gradients(*args, **kwargs) 70 | 71 | def apply_gradients(self, grads_and_vars, *args, **kwargs): 72 | grads_and_vars = list(grads_and_vars) 73 | 74 | if self._weight_decay > 0.0: 75 | if self._weight_decay_type == "l2": 76 | grads_and_vars = self._add_weight_decay(grads_and_vars) 77 | 78 | velocities_and_vars = self._update_velocities(grads_and_vars, self._beta1) 79 | covariances_and_vars = self._update_covariances(grads_and_vars, self._beta2) 80 | 81 | beta1_update_op = self._beta1_power.assign(self._beta1_power * self._beta1) 82 | beta2_update_op = self._beta2_power.assign(self._beta2_power * self._beta2) 83 | with ops.control_dependencies([beta1_update_op, beta2_update_op]): 84 | steps_and_vars = self._compute_update_steps(velocities_and_vars, covariances_and_vars) 85 | if self._weight_decay_type == "wd" and self._weight_decay > 0.0: 86 | steps_and_vars = self._add_weight_decay(steps_and_vars) 87 | update_ops = super(ADAMOptimizer, self).apply_gradients(steps_and_vars, 88 | *args, **kwargs) 89 | return update_ops 90 | 91 | def _compute_update_steps(self, velocities_and_vars, covariances_and_vars): 92 | steps_and_vars = [] 93 | for (velo, var1), (covar, var2) in zip(velocities_and_vars, covariances_and_vars): 94 | if var1 is not var2: 95 | raise ValueError("The variables referenced by the two arguments " 96 | "must match.") 97 | velo = velo / (1 - self._beta1_power) 98 | covar = covar / (1 - self._beta2_power) 99 | step = velo / (math_ops.sqrt(covar) + self._epsilon) 100 | steps_and_vars.append((step, var1)) 101 | 102 | return steps_and_vars 103 | 104 | def _init(self): 105 | first_var = min(self.variables, key=lambda x: x.name) 106 | with ops.colocate_with(first_var): 107 | self._beta1_power = variable_scope.variable(1.0, 108 | name="beta1_power", 109 | trainable=False) 110 | self._beta2_power = variable_scope.variable(1.0, 111 | name="beta2_power", 112 | trainable=False) 113 | 114 | def _add_weight_decay(self, vecs_and_vars): 115 | if self._weight_list == "all": 116 | print("all") 117 | return [(vec + self._weight_decay * gen_array_ops.stop_gradient(var), var) 118 | for vec, var in vecs_and_vars] 119 | elif self._weight_list == "last": 120 | print("last") 121 | grad_list = [] 122 | for vec, var in vecs_and_vars: 123 | if 'fc' not in var.name: 124 | grad_list.append((vec, var)) 125 | else: 126 | grad_list.append( 127 | (vec + self._weight_decay * 128 | 
gen_array_ops.stop_gradient(var), var)) 129 | return grad_list 130 | else: 131 | print("conv") 132 | grad_list = [] 133 | for vec, var in vecs_and_vars: 134 | if 'fc' in var.name: 135 | grad_list.append((vec, var)) 136 | else: 137 | grad_list.append( 138 | (vec + self._weight_decay * 139 | gen_array_ops.stop_gradient(var), var)) 140 | return grad_list 141 | 142 | def _update_velocities(self, vecs_and_vars, decay): 143 | def _update_velocity(vec, var): 144 | velocity = self._zeros_slot(var, "velocity", self._name) 145 | with ops.colocate_with(velocity): 146 | # Compute the new velocity for this variable. 147 | new_velocity = decay * velocity + (1 - decay) * vec 148 | 149 | # Save the updated velocity. 150 | return (array_ops.identity(velocity.assign(new_velocity)), var) 151 | 152 | # Go through variable and update its associated part of the velocity vector. 153 | return [_update_velocity(vec, var) for vec, var in vecs_and_vars] 154 | 155 | def _update_covariances(self, vecs_and_vars, decay): 156 | def _update_covariance(vec, var): 157 | covariance = self._zeros_slot(var, "covariance", self._name) 158 | with ops.colocate_with(covariance): 159 | # Compute the new velocity for this variable. 160 | new_covariance = decay * covariance + (1 - decay) * vec ** 2 161 | 162 | # Save the updated velocity. 163 | return (array_ops.identity(covariance.assign(new_covariance)), var) 164 | 165 | # Go through variable and update its associated part of the velocity vector. 166 | return [_update_covariance(vec, var) for vec, var in vecs_and_vars] 167 | -------------------------------------------------------------------------------- /libs/kfac/cmvp.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from libs.kfac import utils 6 | from tensorflow.python.ops import gradients_impl 7 | from tensorflow.python.ops import math_ops 8 | from tensorflow.python.util import nest 9 | 10 | 11 | class CurvatureMatrixVectorProductComputer(object): 12 | """Class for computing matrix-vector products for Fishers, GGNs and Hessians. 13 | In other words we compute M*v where M is the matrix, v is the vector, and 14 | * refers to standard matrix/vector multiplication (not element-wise 15 | multiplication). 16 | The matrices are defined in terms of some differential quantity of the total 17 | loss function with respect to a provided list of tensors ("wrt_tensors"). 18 | For example, the Fisher associated with a log-prob loss w.r.t. the 19 | parameters. 20 | The 'vecs' argument to each method are lists of tensors that must be the 21 | size as the corresponding ones from "wrt_tensors". They represent 22 | the vector being multiplied. 23 | "factors" of the matrix M are defined as matrices B such that B*B^T = M. 24 | Methods that multiply by the factor B take a 'loss_inner_vecs' argument 25 | instead of 'vecs', which must be a list of tensors with shapes given by the 26 | corresponding XXX_inner_shapes property. 27 | Note that matrix-vector products are not normalized by the batch size, nor 28 | are any damping terms added to the results. These things can be easily 29 | applied externally, if desired. 30 | See for example: www.cs.utoronto.ca/~jmartens/docs/HF_book_chapter.pdf 31 | and https://arxiv.org/abs/1412.1193 for more information about the 32 | generalized Gauss-Newton, Fisher, etc., and how to compute matrix-vector 33 | products. 
34 | """ 35 | 36 | def __init__(self, losses, wrt_tensors): 37 | """Create a CurvatureMatrixVectorProductComputer object. 38 | Args: 39 | losses: A list of LossFunction instances whose sum defines the total loss. 40 | wrt_tensors: A list of Tensors to compute the differential quantities 41 | (defining the matrices) with respect to. See class description for more 42 | info. 43 | """ 44 | self._losses = losses 45 | self._inputs_to_losses = list(loss.inputs for loss in losses) 46 | self._inputs_to_losses_flat = nest.flatten(self._inputs_to_losses) 47 | self._wrt_tensors = wrt_tensors 48 | 49 | @property 50 | def _total_loss(self): 51 | return math_ops.add_n(tuple(loss.evaluate() for loss in self._losses)) 52 | 53 | # Jacobian multiplication functions: 54 | def _multiply_jacobian(self, vecs): 55 | """Multiply vecs by the Jacobian of losses.""" 56 | # We stop gradients at wrt_tensors to produce partial derivatives (which is 57 | # what we want for Jacobians). 58 | jacobian_vecs_flat = utils.fwd_gradients( 59 | self._inputs_to_losses_flat, self._wrt_tensors, grad_xs=vecs, 60 | stop_gradients=self._wrt_tensors) 61 | return nest.pack_sequence_as(self._inputs_to_losses, jacobian_vecs_flat) 62 | 63 | def _multiply_jacobian_transpose(self, loss_vecs): 64 | """Multiply vecs by the transpose Jacobian of losses.""" 65 | loss_vecs_flat = nest.flatten(loss_vecs) 66 | # We stop gradients at wrt_tensors to produce partial derivatives (which is 67 | # what we want for Jacobians). 68 | return gradients_impl.gradients( 69 | self._inputs_to_losses_flat, self._wrt_tensors, grad_ys=loss_vecs_flat, 70 | stop_gradients=self._wrt_tensors) 71 | 72 | # Losses Fisher/Hessian multiplication functions: 73 | def _multiply_loss_fisher(self, loss_vecs): 74 | """Multiply loss_vecs by Fisher of total loss.""" 75 | return tuple( 76 | loss.multiply_fisher(loss_vec) 77 | for loss, loss_vec in zip(self._losses, loss_vecs)) 78 | 79 | def _multiply_loss_fisher_factor(self, loss_inner_vecs): 80 | """Multiply loss_inner_vecs by factor of Fisher of total loss.""" 81 | return tuple( 82 | loss.multiply_fisher_factor(loss_vec) 83 | for loss, loss_vec in zip(self._losses, loss_inner_vecs)) 84 | 85 | def _multiply_loss_fisher_factor_transpose(self, loss_vecs): 86 | """Multiply loss_vecs by transpose factor of Fisher of total loss.""" 87 | return tuple( 88 | loss.multiply_fisher_factor_transpose(loss_vec) 89 | for loss, loss_vec in zip(self._losses, loss_vecs)) 90 | 91 | def _multiply_loss_hessian(self, loss_vecs): 92 | """Multiply loss_vecs by Hessian of total loss.""" 93 | return tuple( 94 | loss.multiply_hessian(loss_vec) 95 | for loss, loss_vec in zip(self._losses, loss_vecs)) 96 | 97 | def _multiply_loss_hessian_factor(self, loss_inner_vecs): 98 | """Multiply loss_inner_vecs by factor of Hessian of total loss.""" 99 | return tuple( 100 | loss.multiply_hessian_factor(loss_vec) 101 | for loss, loss_vec in zip(self._losses, loss_inner_vecs)) 102 | 103 | def _multiply_loss_hessian_factor_transpose(self, loss_vecs): 104 | """Multiply loss_vecs by transpose factor of Hessian of total loss.""" 105 | return tuple( 106 | loss.multiply_hessian_factor_transpose(loss_vec) 107 | for loss, loss_vec in zip(self._losses, loss_vecs)) 108 | 109 | # Matrix-vector product functions: 110 | def multiply_fisher(self, vecs): 111 | """Multiply vecs by Fisher of total loss.""" 112 | jacobian_vecs = self._multiply_jacobian(vecs) 113 | loss_fisher_jacobian_vecs = self._multiply_loss_fisher(jacobian_vecs) 114 | return 
self._multiply_jacobian_transpose(loss_fisher_jacobian_vecs) 115 | 116 | def multiply_fisher_factor_transpose(self, vecs): 117 | """Multiply vecs by transpose of factor of Fisher of total loss.""" 118 | jacobian_vecs = self._multiply_jacobian(vecs) 119 | return self._multiply_loss_fisher_factor_transpose(jacobian_vecs) 120 | 121 | def multiply_fisher_factor(self, loss_inner_vecs): 122 | """Multiply loss_inner_vecs by factor of Fisher of total loss.""" 123 | fisher_factor_transpose_vecs = self._multiply_loss_fisher_factor_transpose( 124 | loss_inner_vecs) 125 | return self._multiply_jacobian_transpose(fisher_factor_transpose_vecs) 126 | 127 | def multiply_hessian(self, vecs): 128 | """Multiply vecs by Hessian of total loss.""" 129 | return gradients_impl.gradients( 130 | gradients_impl.gradients(self._total_loss, self._wrt_tensors), 131 | self._wrt_tensors, 132 | grad_ys=vecs) 133 | 134 | def multiply_gauss_newton(self, vecs): 135 | jacobian_vecs = self._multiply_jacobian(vecs) 136 | return self._multiply_jacobian_transpose(jacobian_vecs) 137 | 138 | def multiply_generalized_gauss_newton(self, vecs): 139 | """Multiply vecs by generalized Gauss-Newton of total loss.""" 140 | jacobian_vecs = self._multiply_jacobian(vecs) 141 | loss_hessian_jacobian_vecs = self._multiply_loss_hessian(jacobian_vecs) 142 | return self._multiply_jacobian_transpose(loss_hessian_jacobian_vecs) 143 | 144 | def multiply_generalized_gauss_newton_factor_transpose(self, vecs): 145 | """Multiply vecs by transpose of factor of GGN of total loss.""" 146 | jacobian_vecs = self._multiply_jacobian(vecs) 147 | return self._multiply_loss_hessian_factor_transpose(jacobian_vecs) 148 | 149 | def multiply_generalized_gauss_newton_factor(self, loss_inner_vecs): 150 | """Multiply loss_inner_vecs by factor of GGN of total loss.""" 151 | hessian_factor_transpose_vecs = ( 152 | self._multiply_loss_hessian_factor_transpose(loss_inner_vecs)) 153 | return self._multiply_jacobian_transpose(hessian_factor_transpose_vecs) 154 | 155 | # Shape properties for multiply_XXX_factor methods: 156 | @property 157 | def fisher_factor_inner_shapes(self): 158 | """Shapes required by multiply_fisher_factor.""" 159 | return tuple(loss.fisher_factor_inner_shape for loss in self._losses) 160 | 161 | @property 162 | def generalized_gauss_newton_factor_inner_shapes(self): 163 | """Shapes required by multiply_generalized_gauss_newton_factor.""" 164 | return tuple(loss.hessian_factor_inner_shape for loss in self._losses) 165 | -------------------------------------------------------------------------------- /libs/kfac/estimator.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import contextlib 6 | import itertools 7 | 8 | from libs.kfac import utils 9 | from tensorflow.python.framework import ops as tf_ops 10 | from tensorflow.python.ops import control_flow_ops 11 | from tensorflow.python.ops import gradients_impl 12 | from tensorflow.python.util import nest 13 | 14 | 15 | class _DeviceContextGenerator(object): 16 | """Class for generating device contexts in a round-robin fashion.""" 17 | 18 | def __init__(self, devices): 19 | self._cycle = None if devices is None else itertools.cycle(devices) 20 | 21 | @contextlib.contextmanager 22 | def __call__(self): 23 | """Returns a context manager specifying the default device.""" 24 | if self._cycle is None: 25 | yield 26 | else: 27 | with 
tf_ops.device(next(self._cycle)): 28 | yield 29 | 30 | 31 | class FisherEstimator(object): 32 | def __init__(self, 33 | variables, 34 | cov_ema_decay, 35 | damping, 36 | layer_collection, 37 | estimation_mode="gradients", 38 | update_type="online", 39 | colocate_gradients_with_ops=False, 40 | cov_devices=None, 41 | inv_devices=None): 42 | 43 | self._variables = variables 44 | self._damping = damping 45 | self._estimation_mode = estimation_mode 46 | self._layers = layer_collection 47 | # self._layers.create_subgraph() 48 | # self._layers.check_registration(variables) 49 | self._gradient_fns = { 50 | "gradients": self._get_grads_lists_gradients, 51 | "empirical": self._get_grads_lists_empirical, 52 | } 53 | self._colocate_gradients_with_ops = colocate_gradients_with_ops 54 | self._cov_device_context_generator = _DeviceContextGenerator(cov_devices) 55 | if inv_devices == cov_devices: 56 | self._inv_device_context_generator = self._cov_device_context_generator 57 | else: 58 | self._inv_device_context_generator = _DeviceContextGenerator(inv_devices) 59 | setup = self._setup(cov_ema_decay, update_type) 60 | self.cov_update_op, self.inv_update_op, self.inv_updates_dict = setup 61 | 62 | self.init_cov_op = self.init_cov_op() 63 | 64 | @property 65 | def variables(self): 66 | return self._variables 67 | 68 | @property 69 | def damping(self): 70 | return self._damping 71 | 72 | def _apply_transformation(self, vecs_and_vars, transform): 73 | vecs = utils.SequenceDict((var, vec) for vec, var in vecs_and_vars) 74 | 75 | trans_vecs = utils.SequenceDict() 76 | 77 | for params, fb in self._layers.fisher_blocks.items(): 78 | trans_vecs[params] = transform(fb, vecs[params]) 79 | 80 | return [(trans_vecs[var], var) for _, var in vecs_and_vars] 81 | 82 | def multiply_inverse(self, vecs_and_vars): 83 | return self._apply_transformation(vecs_and_vars, 84 | lambda fb, vec: fb.multiply_inverse(vec)) 85 | 86 | def multiply(self, vecs_and_vars): 87 | return self._apply_transformation(vecs_and_vars, 88 | lambda fb, vec: fb.multiply(vec)) 89 | 90 | def init_cov_op(self): 91 | cov_updates = [ 92 | factor.make_covariance_update_op(1.0, "accumulate") 93 | for factor in self._layers.get_factors() 94 | ] 95 | return control_flow_ops.group(*cov_updates) 96 | 97 | def rescale_op(self, scale): 98 | rescale_ops = [factor.rescale_covariance_op(scale) for factor in self._layers.get_factors()] 99 | return control_flow_ops.group(*rescale_ops) 100 | 101 | def add_op(self, scale): 102 | add_ops = [factor.add_covariance_op(scale) for factor in self._layers.get_factors()] 103 | return control_flow_ops.group(*add_ops) 104 | 105 | def rescale(self, sess, scale): 106 | rescale_ops = [factor.rescale_covariance_op(scale) for factor in self._layers.get_factors()] 107 | sess.run(control_flow_ops.group(*rescale_ops)) 108 | 109 | def reset(self, sess): 110 | reset_ops = [factor.reset_covariance_op() for factor in self._layers.get_factors()] 111 | sess.run(control_flow_ops.group(*reset_ops)) 112 | 113 | def _setup(self, cov_ema_decay, update_type): 114 | fisher_blocks_list = self._layers.get_blocks() 115 | tensors_to_compute_grads = [ 116 | fb.tensors_to_compute_grads() for fb in fisher_blocks_list 117 | ] 118 | 119 | try: 120 | grads_lists = self._gradient_fns[self._estimation_mode]( 121 | tensors_to_compute_grads) 122 | except KeyError: 123 | raise ValueError("Unrecognized value {} for estimation_mode.".format( 124 | self._estimation_mode)) 125 | 126 | for grads_list, fb in zip(grads_lists, fisher_blocks_list): 127 | with 
self._cov_device_context_generator(): 128 | fb.instantiate_factors(grads_list, self.damping) 129 | 130 | cov_updates = [ 131 | factor.make_covariance_update_op(cov_ema_decay, update_type) 132 | for factor in self._layers.get_factors() 133 | ] 134 | inv_updates = {op.name: op for op in self._get_all_inverse_update_ops()} 135 | 136 | return control_flow_ops.group(*cov_updates), control_flow_ops.group( 137 | *inv_updates.values()), inv_updates 138 | 139 | def _get_all_inverse_update_ops(self): 140 | for factor in self._layers.get_factors(): 141 | with self._inv_device_context_generator(): 142 | for op in factor.make_inverse_update_ops(): 143 | yield op 144 | 145 | def _get_grads_lists_gradients(self, tensors): 146 | grads_flat = gradients_impl.gradients( 147 | self._layers.total_sampled_loss(), 148 | nest.flatten(tensors), 149 | colocate_gradients_with_ops=self._colocate_gradients_with_ops) 150 | grads_all = nest.pack_sequence_as(tensors, grads_flat) 151 | return tuple((grad,) for grad in grads_all) 152 | 153 | def _get_grads_lists_empirical(self, tensors): 154 | grads_flat = gradients_impl.gradients( 155 | self._layers.total_loss(), 156 | nest.flatten(tensors), 157 | colocate_gradients_with_ops=self._colocate_gradients_with_ops) 158 | grads_all = nest.pack_sequence_as(tensors, grads_flat) 159 | return tuple((grad,) for grad in grads_all) 160 | 161 | -------------------------------------------------------------------------------- /libs/kfac/fisher_blocks.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import abc 6 | import six 7 | 8 | from libs.kfac import fisher_factors 9 | from libs.kfac import utils 10 | from tensorflow.python.ops import array_ops 11 | from tensorflow.python.ops import math_ops 12 | 13 | # For blocks corresponding to convolutional layers, or any type of block where 14 | # the parameters can be thought of as being replicated in time or space, 15 | # we want to adjust the scale of the damping by 16 | # damping /= num_replications ** NORMALIZE_DAMPING_POWER 17 | NORMALIZE_DAMPING_POWER = 1.0 18 | 19 | # Methods for adjusting damping for FisherBlocks. See 20 | # _compute_pi_adjusted_damping() for details. 21 | PI_OFF_NAME = "off" 22 | PI_TRACENORM_NAME = "tracenorm" 23 | PI_TYPE = PI_TRACENORM_NAME 24 | 25 | 26 | def set_global_constants(normalize_damping_power=None, pi_type=None): 27 | """Sets various global constants used by the classes in this module.""" 28 | global NORMALIZE_DAMPING_POWER 29 | global PI_TYPE 30 | 31 | if normalize_damping_power is not None: 32 | NORMALIZE_DAMPING_POWER = normalize_damping_power 33 | 34 | if pi_type is not None: 35 | PI_TYPE = pi_type 36 | 37 | 38 | def _compute_pi_tracenorm(left_cov, right_cov): 39 | # Instead of dividing by the dim of the norm, we multiply by the dim of the 40 | # other norm. This works out the same in the ratio. 
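# In symbols: pi = sqrt((trace(A) / dim(A)) / (trace(G) / dim(G))) for input
# factor A (left_cov) and output factor G (right_cov); multiplying each trace
# by the other factor's dimension, as done below, yields the same ratio. This
# trace-norm choice of pi (cf. https://arxiv.org/abs/1503.05671) is what
# _compute_pi_adjusted_damping() uses to split the damping between the two
# Kronecker factors.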
41 | left_norm = math_ops.trace(left_cov) * right_cov.shape.as_list()[0] 42 | right_norm = math_ops.trace(right_cov) * left_cov.shape.as_list()[0] 43 | return math_ops.sqrt(left_norm / right_norm) 44 | 45 | 46 | def _compute_pi_adjusted_damping(left_cov, right_cov, damping): 47 | 48 | if PI_TYPE == PI_TRACENORM_NAME: 49 | pi = _compute_pi_tracenorm(left_cov, right_cov) 50 | return (damping * pi, damping / pi) 51 | 52 | elif PI_TYPE == PI_OFF_NAME: 53 | return (damping, damping) 54 | 55 | 56 | @six.add_metaclass(abc.ABCMeta) 57 | class FisherBlock(object): 58 | """Abstract core class for objects modeling approximate Fisher matrix blocks. 59 | Subclasses must implement multiply_inverse(), instantiate_factors(), and 60 | tensors_to_compute_grads() methods. 61 | """ 62 | 63 | def __init__(self, layer_collection): 64 | self._layer_collection = layer_collection 65 | 66 | @abc.abstractmethod 67 | def instantiate_factors(self, grads_list, damping): 68 | """Creates and registers the component factors of this Fisher block. 69 | Args: 70 | grads_list: A list gradients (each a Tensor or tuple of Tensors) with 71 | respect to the tensors returned by tensors_to_compute_grads() that 72 | are to be used to estimate the block. 73 | damping: The damping factor (float or Tensor). 74 | """ 75 | pass 76 | 77 | @abc.abstractmethod 78 | def multiply_inverse(self, vector): 79 | """Multiplies the vector by the (damped) inverse of the block. 80 | Args: 81 | vector: The vector (a Tensor or tuple of Tensors) to be multiplied. 82 | Returns: 83 | The vector left-multiplied by the (damped) inverse of the block. 84 | """ 85 | pass 86 | 87 | @abc.abstractmethod 88 | def multiply(self, vector): 89 | """Multiplies the vector by the (damped) block. 90 | Args: 91 | vector: The vector (a Tensor or tuple of Tensors) to be multiplied. 92 | Returns: 93 | The vector left-multiplied by the (damped) block. 94 | """ 95 | pass 96 | 97 | @abc.abstractmethod 98 | def tensors_to_compute_grads(self): 99 | """Returns the Tensor(s) with respect to which this FisherBlock needs grads. 100 | """ 101 | pass 102 | 103 | @abc.abstractproperty 104 | def num_registered_minibatches(self): 105 | """Number of minibatches registered for this FisherBlock. 106 | Typically equal to the number of towers in a multi-tower setup. 107 | """ 108 | pass 109 | 110 | 111 | class FullFB(FisherBlock): 112 | """FisherBlock using a full matrix estimate (no approximations). 113 | FullFB uses a full matrix estimate (no approximations), and should only ever 114 | be used for very low dimensional parameters. 115 | Note that this uses the naive "square the sum estimator", and so is applicable 116 | to any type of parameter in principle, but has very high variance. 117 | """ 118 | 119 | def __init__(self, layer_collection, params): 120 | """Creates a FullFB block. 121 | Args: 122 | layer_collection: The collection of all layers in the K-FAC approximate 123 | Fisher information matrix to which this FisherBlock belongs. 124 | params: The parameters of this layer (Tensor or tuple of Tensors). 
125 | """ 126 | self._batch_sizes = [] 127 | self._params = params 128 | 129 | super(FullFB, self).__init__(layer_collection) 130 | 131 | def instantiate_factors(self, grads_list, damping): 132 | self._damping = damping 133 | self._factor = self._layer_collection.make_or_get_factor( 134 | fisher_factors.FullFactor, (grads_list, self._batch_size)) 135 | self._factor.register_damped_inverse(damping) 136 | 137 | def multiply_inverse(self, vector): 138 | inverse = self._factor.get_damped_inverse(self._damping) 139 | out_flat = math_ops.matmul(inverse, utils.tensors_to_column(vector)) 140 | return utils.column_to_tensors(vector, out_flat) 141 | 142 | def multiply(self, vector): 143 | vector_flat = utils.tensors_to_column(vector) 144 | out_flat = ( 145 | math_ops.matmul(self._factor.get_cov(), vector_flat) + 146 | self._damping * vector_flat) 147 | return utils.column_to_tensors(vector, out_flat) 148 | 149 | def full_fisher_block(self): 150 | """Explicitly constructs the full Fisher block.""" 151 | return self._factor.get_cov() 152 | 153 | def tensors_to_compute_grads(self): 154 | return self._params 155 | 156 | def register_additional_minibatch(self, batch_size): 157 | """Register an additional minibatch. 158 | Args: 159 | batch_size: The batch size, used in the covariance estimator. 160 | """ 161 | self._batch_sizes.append(batch_size) 162 | 163 | @property 164 | def num_registered_minibatches(self): 165 | return len(self._batch_sizes) 166 | 167 | @property 168 | def _batch_size(self): 169 | return math_ops.reduce_sum(self._batch_sizes) 170 | 171 | 172 | class NaiveDiagonalFB(FisherBlock): 173 | """FisherBlock using a diagonal matrix approximation. 174 | This type of approximation is generically applicable but quite primitive. 175 | Note that this uses the naive "square the sum estimator", and so is applicable 176 | to any type of parameter in principle, but has very high variance. 177 | """ 178 | 179 | def __init__(self, layer_collection, params): 180 | """Creates a NaiveDiagonalFB block. 181 | Args: 182 | layer_collection: The collection of all layers in the K-FAC approximate 183 | Fisher information matrix to which this FisherBlock belongs. 184 | params: The parameters of this layer (Tensor or tuple of Tensors). 185 | """ 186 | self._params = params 187 | self._batch_sizes = [] 188 | 189 | super(NaiveDiagonalFB, self).__init__(layer_collection) 190 | 191 | def instantiate_factors(self, grads_list, damping): 192 | self._damping = damping 193 | self._factor = self._layer_collection.make_or_get_factor( 194 | fisher_factors.NaiveDiagonalFactor, (grads_list, self._batch_size)) 195 | 196 | def multiply_inverse(self, vector): 197 | vector_flat = utils.tensors_to_column(vector) 198 | out_flat = vector_flat / (self._factor.get_cov() + self._damping) 199 | return utils.column_to_tensors(vector, out_flat) 200 | 201 | def multiply(self, vector): 202 | vector_flat = utils.tensors_to_column(vector) 203 | out_flat = vector_flat * (self._factor.get_cov() + self._damping) 204 | return utils.column_to_tensors(vector, out_flat) 205 | 206 | def full_fisher_block(self): 207 | return array_ops.diag(array_ops.reshape(self._factor.get_cov(), (-1,))) 208 | 209 | def tensors_to_compute_grads(self): 210 | return self._params 211 | 212 | def register_additional_minibatch(self, batch_size): 213 | """Register an additional minibatch. 214 | Args: 215 | batch_size: The batch size, used in the covariance estimator. 
216 | """ 217 | self._batch_sizes.append(batch_size) 218 | 219 | @property 220 | def num_registered_minibatches(self): 221 | return len(self._batch_sizes) 222 | 223 | @property 224 | def _batch_size(self): 225 | return math_ops.reduce_sum(self._batch_sizes) 226 | 227 | 228 | class FullyConnectedDiagonalFB(FisherBlock): 229 | """FisherBlock for fully-connected (dense) layers using a diagonal approx. 230 | Estimates the Fisher Information matrix's diagonal entries for a fully 231 | connected layer. Unlike NaiveDiagonalFB this uses the low-variance "sum of 232 | squares" estimator. 233 | Let 'params' be a vector parameterizing a model and 'i' an arbitrary index 234 | into it. We are interested in Fisher(params)[i, i]. This is, 235 | Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i] 236 | = E[ v(x, y, params)[i] ^ 2 ] 237 | Consider fully connected layer in this model with (unshared) weight matrix 238 | 'w'. For an example 'x' that produces layer inputs 'a' and output 239 | preactivations 's', 240 | v(x, y, w) = vec( a (d loss / d s)^T ) 241 | This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding 242 | to the layer's parameters 'w'. 243 | """ 244 | 245 | def __init__(self, layer_collection, has_bias=False): 246 | """Creates a FullyConnectedDiagonalFB block. 247 | Args: 248 | layer_collection: The collection of all layers in the K-FAC approximate 249 | Fisher information matrix to which this FisherBlock belongs. 250 | has_bias: Whether the component Kronecker factors have an additive bias. 251 | (Default: False) 252 | """ 253 | self._inputs = [] 254 | self._outputs = [] 255 | self._has_bias = has_bias 256 | 257 | super(FullyConnectedDiagonalFB, self).__init__(layer_collection) 258 | 259 | def instantiate_factors(self, grads_list, damping): 260 | inputs = _concat_along_batch_dim(self._inputs) 261 | grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list) 262 | 263 | self._damping = damping 264 | self._factor = self._layer_collection.make_or_get_factor( 265 | fisher_factors.FullyConnectedDiagonalFactor, 266 | (inputs, grads_list, self._has_bias)) 267 | 268 | def multiply_inverse(self, vector): 269 | """Approximate damped inverse Fisher-vector product. 270 | Args: 271 | vector: Tensor or 2-tuple of Tensors. if self._has_bias, Tensor of shape 272 | [input_size, output_size] corresponding to layer's weights. If not, a 273 | 2-tuple of the former and a Tensor of shape [output_size] corresponding 274 | to the layer's bias. 275 | Returns: 276 | Tensor of the same shape, corresponding to the inverse Fisher-vector 277 | product. 278 | """ 279 | reshaped_vect = utils.layer_params_to_mat2d(vector) 280 | reshaped_out = reshaped_vect / (self._factor.get_cov() + self._damping) 281 | return utils.mat2d_to_layer_params(vector, reshaped_out) 282 | 283 | def multiply(self, vector): 284 | """Approximate damped Fisher-vector product. 285 | Args: 286 | vector: Tensor or 2-tuple of Tensors. if self._has_bias, Tensor of shape 287 | [input_size, output_size] corresponding to layer's weights. If not, a 288 | 2-tuple of the former and a Tensor of shape [output_size] corresponding 289 | to the layer's bias. 290 | Returns: 291 | Tensor of the same shape, corresponding to the Fisher-vector product. 
292 | """ 293 | reshaped_vect = utils.layer_params_to_mat2d(vector) 294 | reshaped_out = reshaped_vect * (self._factor.get_cov() + self._damping) 295 | return utils.mat2d_to_layer_params(vector, reshaped_out) 296 | 297 | def tensors_to_compute_grads(self): 298 | """Tensors to compute derivative of loss with respect to.""" 299 | return self._outputs 300 | 301 | def register_additional_minibatch(self, inputs, outputs): 302 | """Registers an additional minibatch to the FisherBlock. 303 | Args: 304 | inputs: Tensor of shape [batch_size, input_size]. Inputs to the 305 | matrix-multiply. 306 | outputs: Tensor of shape [batch_size, output_size]. Layer preactivations. 307 | """ 308 | self._inputs.append(inputs) 309 | self._outputs.append(outputs) 310 | 311 | @property 312 | def num_registered_minibatches(self): 313 | result = len(self._inputs) 314 | assert result == len(self._outputs) 315 | return result 316 | 317 | 318 | class ConvDiagonalFB(FisherBlock): 319 | """FisherBlock for convolutional layers using a diagonal approx. 320 | Estimates the Fisher Information matrix's diagonal entries for a convolutional 321 | layer. Unlike NaiveDiagonalFB this uses the low-variance "sum of squares" 322 | estimator. 323 | Let 'params' be a vector parameterizing a model and 'i' an arbitrary index 324 | into it. We are interested in Fisher(params)[i, i]. This is, 325 | Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i] 326 | = E[ v(x, y, params)[i] ^ 2 ] 327 | Consider a convoluational layer in this model with (unshared) filter matrix 328 | 'w'. For an example image 'x' that produces layer inputs 'a' and output 329 | preactivations 's', 330 | v(x, y, w) = vec( sum_{loc} a_{loc} (d loss / d s_{loc})^T ) 331 | where 'loc' is a single (x, y) location in an image. 332 | This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding 333 | to the layer's parameters 'w'. 334 | """ 335 | 336 | def __init__(self, layer_collection, params, strides, padding): 337 | """Creates a ConvDiagonalFB block. 338 | Args: 339 | layer_collection: The collection of all layers in the K-FAC approximate 340 | Fisher information matrix to which this FisherBlock belongs. 341 | params: The parameters (Tensor or tuple of Tensors) of this layer. If 342 | kernel alone, a Tensor of shape [kernel_height, kernel_width, 343 | in_channels, out_channels]. If kernel and bias, a tuple of 2 elements 344 | containing the previous and a Tensor of shape [out_channels]. 345 | strides: The stride size in this layer (1-D Tensor of length 4). 346 | padding: The padding in this layer (e.g. "SAME"). 347 | """ 348 | self._inputs = [] 349 | self._outputs = [] 350 | self._strides = tuple(strides) if isinstance(strides, list) else strides 351 | self._padding = padding 352 | self._has_bias = isinstance(params, (tuple, list)) 353 | 354 | fltr = params[0] if self._has_bias else params 355 | self._filter_shape = tuple(fltr.shape.as_list()) 356 | 357 | super(ConvDiagonalFB, self).__init__(layer_collection) 358 | 359 | def instantiate_factors(self, grads_list, damping): 360 | # Concatenate inputs, grads_list into single Tensors. 361 | inputs = _concat_along_batch_dim(self._inputs) 362 | grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list) 363 | 364 | # Infer number of locations upon which convolution is applied. 
365 | inputs_shape = tuple(inputs.shape.as_list()) 366 | self._num_locations = ( 367 | inputs_shape[1] * inputs_shape[2] // 368 | (self._strides[1] * self._strides[2])) 369 | 370 | if NORMALIZE_DAMPING_POWER: 371 | damping /= self._num_locations**NORMALIZE_DAMPING_POWER 372 | self._damping = self._num_locations**NORMALIZE_DAMPING_POWER * damping 373 | 374 | self._factor = self._layer_collection.make_or_get_factor( 375 | fisher_factors.ConvDiagonalFactor, 376 | (inputs, grads_list, self._filter_shape, self._strides, self._padding, 377 | self._has_bias)) 378 | 379 | def multiply_inverse(self, vector): 380 | reshaped_vect = utils.layer_params_to_mat2d(vector) 381 | reshaped_out = reshaped_vect / (self._factor.get_cov() + self._damping) 382 | return utils.mat2d_to_layer_params(vector, reshaped_out) 383 | 384 | def multiply(self, vector): 385 | reshaped_vect = utils.layer_params_to_mat2d(vector) 386 | reshaped_out = reshaped_vect * (self._factor.get_cov() + self._damping) 387 | return utils.mat2d_to_layer_params(vector, reshaped_out) 388 | 389 | def tensors_to_compute_grads(self): 390 | return self._outputs 391 | 392 | def register_additional_minibatch(self, inputs, outputs): 393 | """Registers an additional minibatch to the FisherBlock. 394 | Args: 395 | inputs: Tensor of shape [batch_size, height, width, input_size]. Inputs to 396 | the convolution. 397 | outputs: Tensor of shape [batch_size, height, width, output_size]. Layer 398 | preactivations. 399 | """ 400 | self._inputs.append(inputs) 401 | self._outputs.append(outputs) 402 | 403 | @property 404 | def num_registered_minibatches(self): 405 | return len(self._inputs) 406 | 407 | 408 | class KroneckerProductFB(FisherBlock): 409 | """A core class for FisherBlocks with separate input and output factors. 410 | The Fisher block is approximated as a Kronecker product of the input and 411 | output factors. 412 | """ 413 | 414 | def _register_damped_input_and_output_inverses(self, damping): 415 | """Registers damped inverses for both the input and output factors. 416 | Sets the instance members _input_damping and _output_damping. Requires the 417 | instance members _input_factor and _output_factor. 418 | Args: 419 | damping: The core damping factor (float or Tensor) for the damped inverse. 420 | """ 421 | self._input_damping, self._output_damping = _compute_pi_adjusted_damping( 422 | self._input_factor.get_cov(), 423 | self._output_factor.get_cov(), 424 | damping**0.5) 425 | 426 | self._input_factor.register_damped_inverse(self._input_damping) 427 | self._output_factor.register_damped_inverse(self._output_damping) 428 | 429 | @property 430 | def _renorm_coeff(self): 431 | """Kronecker factor multiplier coefficient. 432 | If this FisherBlock is represented as 'FB = c * kron(left, right)', then 433 | this is 'c'. 434 | Returns: 435 | 0-D Tensor. 
436 | """ 437 | return 1.0 438 | 439 | def multiply_inverse(self, vector): 440 | left_factor_inv = self._input_factor.get_damped_inverse(self._input_damping) 441 | right_factor_inv = self._output_factor.get_damped_inverse( 442 | self._output_damping) 443 | reshaped_vector = utils.layer_params_to_mat2d(vector) 444 | reshaped_out = math_ops.matmul(left_factor_inv, 445 | math_ops.matmul(reshaped_vector, 446 | right_factor_inv)) 447 | if self._renorm_coeff != 1.0: 448 | reshaped_out /= math_ops.cast( 449 | self._renorm_coeff, dtype=reshaped_out.dtype) 450 | return utils.mat2d_to_layer_params(vector, reshaped_out) 451 | 452 | def multiply(self, vector): 453 | left_factor = self._input_factor.get_cov() 454 | right_factor = self._output_factor.get_cov() 455 | reshaped_vector = utils.layer_params_to_mat2d(vector) 456 | reshaped_out = ( 457 | math_ops.matmul(reshaped_vector, right_factor) + 458 | self._output_damping * reshaped_vector) 459 | reshaped_out = ( 460 | math_ops.matmul(left_factor, reshaped_out) + 461 | self._input_damping * reshaped_out) 462 | if self._renorm_coeff != 1.0: 463 | reshaped_out *= math_ops.cast( 464 | self._renorm_coeff, dtype=reshaped_out.dtype) 465 | return utils.mat2d_to_layer_params(vector, reshaped_out) 466 | 467 | def full_fisher_block(self): 468 | """Explicitly constructs the full Fisher block. 469 | Used for testing purposes. (In general, the result may be very large.) 470 | Returns: 471 | The full Fisher block. 472 | """ 473 | left_factor = self._input_factor.get_cov() 474 | right_factor = self._output_factor.get_cov() 475 | return self._renorm_coeff * utils.kronecker_product(left_factor, 476 | right_factor) 477 | 478 | 479 | class FullyConnectedKFACBasicFB(KroneckerProductFB): 480 | """K-FAC FisherBlock for fully-connected (dense) layers. 481 | This uses the Kronecker-factorized approximation from the original 482 | K-FAC paper (https://arxiv.org/abs/1503.05671) 483 | """ 484 | 485 | def __init__(self, layer_collection, has_bias=False): 486 | """Creates a FullyConnectedKFACBasicFB block. 487 | Args: 488 | layer_collection: The collection of all layers in the K-FAC approximate 489 | Fisher information matrix to which this FisherBlock belongs. 490 | has_bias: Whether the component Kronecker factors have an additive bias. 491 | (Default: False) 492 | """ 493 | self._inputs = [] 494 | self._outputs = [] 495 | self._has_bias = has_bias 496 | 497 | super(FullyConnectedKFACBasicFB, self).__init__(layer_collection) 498 | 499 | def instantiate_factors(self, grads_list, damping): 500 | """Instantiate Kronecker Factors for this FisherBlock. 501 | Args: 502 | grads_list: List of list of Tensors. grads_list[i][j] is the 503 | gradient of the loss with respect to 'outputs' from source 'i' and 504 | tower 'j'. Each Tensor has shape [tower_minibatch_size, output_size]. 505 | damping: 0-D Tensor or float. 'damping' * identity is approximately added 506 | to this FisherBlock's Fisher approximation. 507 | """ 508 | # TODO(b/68033310): Validate which of, 509 | # (1) summing on a single device (as below), or 510 | # (2) on each device in isolation and aggregating 511 | # is faster. 
512 | inputs = _concat_along_batch_dim(self._inputs) 513 | grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list) 514 | 515 | self._input_factor = self._layer_collection.make_or_get_factor( # 516 | fisher_factors.FullyConnectedKroneckerFactor, # 517 | ((inputs,), self._has_bias)) 518 | self._output_factor = self._layer_collection.make_or_get_factor( # 519 | fisher_factors.FullyConnectedKroneckerFactor, # 520 | (grads_list,)) 521 | self._register_damped_input_and_output_inverses(damping) 522 | 523 | def tensors_to_compute_grads(self): 524 | return self._outputs 525 | 526 | def register_additional_minibatch(self, inputs, outputs): 527 | """Registers an additional minibatch to the FisherBlock. 528 | Args: 529 | inputs: Tensor of shape [batch_size, input_size]. Inputs to the 530 | matrix-multiply. 531 | outputs: Tensor of shape [batch_size, output_size]. Layer preactivations. 532 | """ 533 | self._inputs.append(inputs) 534 | self._outputs.append(outputs) 535 | 536 | @property 537 | def num_registered_minibatches(self): 538 | return len(self._inputs) 539 | 540 | 541 | class ConvKFCBasicFB(KroneckerProductFB): 542 | """FisherBlock for 2D convolutional layers using the basic KFC approx. 543 | Estimates the Fisher Information matrix's block for a convolutional 544 | layer. 545 | Consider a convolutional layer in this model with (unshared) filter matrix 546 | 'w'. For a minibatch that produces inputs 'a' and output preactivations 's', 547 | this FisherBlock estimates, 548 | F(w) = #locations * kronecker(E[flat(a) flat(a)^T], 549 | E[flat(ds) flat(ds)^T]) 550 | where 551 | ds = (d / ds) log p(y | x, w) 552 | #locations = number of (x, y) locations where 'w' is applied. 553 | where the expectation is taken over all examples and locations and flat() 554 | concatenates an array's leading dimensions. 555 | See equation 23 in https://arxiv.org/abs/1602.01407 for details. 556 | """ 557 | 558 | def __init__(self, layer_collection, params, strides, padding): 559 | """Creates a ConvKFCBasicFB block. 560 | Args: 561 | layer_collection: The collection of all layers in the K-FAC approximate 562 | Fisher information matrix to which this FisherBlock belongs. 563 | params: The parameters (Tensor or tuple of Tensors) of this layer. If 564 | kernel alone, a Tensor of shape [kernel_height, kernel_width, 565 | in_channels, out_channels]. If kernel and bias, a tuple of 2 elements 566 | containing the previous and a Tensor of shape [out_channels]. 567 | strides: The stride size in this layer (1-D Tensor of length 4). 568 | padding: The padding in this layer (e.g. "SAME"). 569 | """ 570 | self._inputs = [] 571 | self._outputs = [] 572 | self._strides = tuple(strides) if isinstance(strides, list) else strides 573 | self._padding = padding 574 | self._has_bias = isinstance(params, (tuple, list)) 575 | 576 | fltr = params[0] if self._has_bias else params 577 | self._filter_shape = tuple(fltr.shape.as_list()) 578 | 579 | super(ConvKFCBasicFB, self).__init__(layer_collection) 580 | 581 | def instantiate_factors(self, grads_list, damping): 582 | # TODO(b/68033310): Validate which of, 583 | # (1) summing on a single device (as below), or 584 | # (2) on each device in isolation and aggregating 585 | # is faster. 586 | inputs = _concat_along_batch_dim(self._inputs) 587 | grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list) 588 | 589 | # Infer number of locations upon which convolution is applied.
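# num_locations is used twice below: the damping passed to the Kronecker
# factors is divided by num_locations**NORMALIZE_DAMPING_POWER, and it also
# becomes _renorm_coeff, i.e. the '#locations' scaling in
# F(w) ~= #locations * kron(E[flat(a) flat(a)^T], E[flat(ds) flat(ds)^T])
# from the class docstring above.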
590 | self._num_locations = _num_conv_locations(inputs.shape.as_list(), 591 | self._strides) 592 | 593 | self._input_factor = self._layer_collection.make_or_get_factor( 594 | fisher_factors.ConvInputKroneckerFactor, 595 | (inputs, self._filter_shape, self._strides, self._padding, 596 | self._has_bias)) 597 | self._output_factor = self._layer_collection.make_or_get_factor( 598 | fisher_factors.ConvOutputKroneckerFactor, (grads_list,)) 599 | 600 | if NORMALIZE_DAMPING_POWER: 601 | damping /= self._num_locations**NORMALIZE_DAMPING_POWER 602 | self._damping = damping 603 | 604 | self._register_damped_input_and_output_inverses(damping) 605 | 606 | @property 607 | def _renorm_coeff(self): 608 | return self._num_locations 609 | 610 | def tensors_to_compute_grads(self): 611 | return self._outputs 612 | 613 | def register_additional_minibatch(self, inputs, outputs): 614 | """Registers an additional minibatch to the FisherBlock. 615 | Args: 616 | inputs: Tensor of shape [batch_size, height, width, input_size]. Inputs to 617 | the convolution. 618 | outputs: Tensor of shape [batch_size, height, width, output_size]. Layer 619 | preactivations. 620 | """ 621 | self._inputs.append(inputs) 622 | self._outputs.append(outputs) 623 | 624 | @property 625 | def num_registered_minibatches(self): 626 | return len(self._inputs) 627 | 628 | 629 | def _concat_along_batch_dim(tensor_list): 630 | """Concatenate tensors along batch (first) dimension. 631 | Args: 632 | tensor_list: list of Tensors or list of tuples of Tensors. 633 | Returns: 634 | Tensor or tuple of Tensors. 635 | Raises: 636 | ValueError: If 'tensor_list' is empty. 637 | """ 638 | if not tensor_list: 639 | raise ValueError( 640 | "Cannot concatenate Tensors if there are no Tensors to concatenate.") 641 | 642 | if isinstance(tensor_list[0], (tuple, list)): 643 | # [(tensor1a, tensor1b), 644 | # (tensor2a, tensor2b), ...] --> (tensor_a, tensor_b) 645 | return tuple( 646 | array_ops.concat(tensors, axis=0) for tensors in zip(*tensor_list)) 647 | else: 648 | # [tensor1, tensor2] --> tensor 649 | return array_ops.concat(tensor_list, axis=0) 650 | 651 | 652 | def _num_conv_locations(input_shape, strides): 653 | """Returns the number of locations a Conv kernel is applied to.""" 654 | return input_shape[1] * input_shape[2] // (strides[1] * strides[2]) 655 | 656 | -------------------------------------------------------------------------------- /libs/kfac/layer_collection.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from collections import OrderedDict 6 | import six 7 | 8 | from libs.kfac import fisher_blocks as fb 9 | from libs.kfac import loss_functions as lf 10 | from libs.kfac.utils import ensure_sequence, LayerParametersDict 11 | from tensorflow.python.framework import ops 12 | from tensorflow.python.ops import math_ops 13 | from tensorflow.python.ops import variable_scope 14 | 15 | # Names for various approximations that can be requested for Fisher blocks. 
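For orientation, here is a minimal sketch of how the pieces in this library fit together: a model registers each parameterized layer and its predictive distribution with a `LayerCollection`, and the K-FAC optimizer then builds its Fisher approximation from those registrations. The toy model, shapes, and hyperparameters below are illustrative only and are not taken from this repo's `network/` or `core/` code; the registration and optimizer calls follow the signatures defined in this file and in `libs/kfac/optimizer.py`.

```
import tensorflow as tf

from libs.kfac.layer_collection import LayerCollection
from libs.kfac.optimizer import KFACOptimizer

# Toy graph: one conv layer followed by one fully-connected layer.
x = tf.placeholder(tf.float32, [None, 32, 32, 3])
y = tf.placeholder(tf.int64, [None])

w_conv = tf.get_variable("w_conv", [3, 3, 3, 16])
pre_conv = tf.nn.conv2d(x, w_conv, strides=[1, 1, 1, 1], padding="SAME")
h = tf.reshape(tf.nn.relu(pre_conv), [-1, 32 * 32 * 16])

w_fc = tf.get_variable("w_fc", [32 * 32 * 16, 10])
b_fc = tf.get_variable("b_fc", [10], initializer=tf.zeros_initializer())
logits = tf.matmul(h, w_fc) + b_fc

loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits))

# Tell K-FAC about every parameterized layer and about the model's
# predictive distribution (used to sample targets for the Fisher).
lc = LayerCollection()
lc.register_conv2d(w_conv, [1, 1, 1, 1], "SAME", x, pre_conv)
lc.register_fully_connected((w_fc, b_fc), h, logits)
lc.register_categorical_predictive_distribution(logits)

opt = KFACOptimizer(
    learning_rate=3e-2, damping=1e-3, layer_collection=lc,
    cov_ema_decay=0.95, momentum=0.9)
train_op = opt.minimize(loss)

# In a training loop one would also periodically run opt.cov_update_op and
# opt.inv_update_op to refresh the covariance estimates and their inverses.
```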
16 | APPROX_KRONECKER_NAME = "kron" 17 | APPROX_DIAGONAL_NAME = "diagonal" 18 | APPROX_FULL_NAME = "full" 19 | 20 | _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES = { 21 | APPROX_KRONECKER_NAME: fb.FullyConnectedKFACBasicFB, 22 | APPROX_DIAGONAL_NAME: fb.FullyConnectedDiagonalFB, 23 | } 24 | 25 | _CONV2D_APPROX_TO_BLOCK_TYPES = { 26 | APPROX_KRONECKER_NAME: fb.ConvKFCBasicFB, 27 | APPROX_DIAGONAL_NAME: fb.ConvDiagonalFB, 28 | } 29 | 30 | # Possible value for 'reuse' keyword argument. Sets 'reuse' to 31 | # tf.get_variable_scope().reuse. 32 | VARIABLE_SCOPE = "VARIABLE_SCOPE" 33 | 34 | 35 | class LayerCollection(object): 36 | def __init__(self, 37 | graph=None, 38 | colocate_cov_ops_with_inputs=False, 39 | name="LayerCollection"): 40 | self.fisher_blocks = LayerParametersDict() 41 | self.fisher_factors = OrderedDict() 42 | self._graph = graph or ops.get_default_graph() 43 | self._loss_dict = {} # {str: LossFunction} 44 | self._default_fully_connected_approximation = APPROX_KRONECKER_NAME 45 | self._default_convolution_2d_approximation = APPROX_KRONECKER_NAME 46 | self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs 47 | 48 | with variable_scope.variable_scope(None, default_name=name) as scope: 49 | self._var_scope = scope.name 50 | 51 | @property 52 | def losses(self): 53 | """LossFunctions registered with this LayerCollection.""" 54 | return list(self._loss_dict.values()) 55 | 56 | @property 57 | def registered_variables(self): 58 | """A tuple of all of the variables currently registered.""" 59 | tuple_of_tuples = (ensure_sequence(key) for key, block 60 | in six.iteritems(self.fisher_blocks)) 61 | flat_tuple = tuple(item for tuple_ in tuple_of_tuples for item in tuple_) 62 | return flat_tuple 63 | 64 | @property 65 | def default_fully_connected_approximation(self): 66 | return self._default_fully_connected_approximation 67 | 68 | def set_default_fully_connected_approximation(self, value): 69 | if value not in _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES: 70 | raise ValueError( 71 | "{} is not a valid approximation for fully connected layers.".format( 72 | value)) 73 | self._default_fully_connected_approximation = value 74 | 75 | @property 76 | def default_conv2d_approximation(self): 77 | return self._default_convolution_2d_approximation 78 | 79 | def set_default_conv2d_approximation(self, value): 80 | if value not in _CONV2D_APPROX_TO_BLOCK_TYPES: 81 | raise ValueError( 82 | "{} is not a valid approximation for 2d convolutional layers.".format( 83 | value)) 84 | self._default_convolution_2d_approximation = value 85 | 86 | def register_block(self, layer_key, fisher_block, reuse=VARIABLE_SCOPE): 87 | if reuse is VARIABLE_SCOPE: 88 | reuse = variable_scope.get_variable_scope().reuse 89 | 90 | if reuse is True or (reuse is variable_scope.AUTO_REUSE and 91 | layer_key in self.fisher_blocks): 92 | result = self.fisher_blocks[layer_key] 93 | if type(result) != type(fisher_block): # pylint: disable=unidiomatic-typecheck 94 | raise ValueError( 95 | "Attempted to register FisherBlock of type %s when existing " 96 | "FisherBlock has type %s." % (type(fisher_block), type(result))) 97 | return result 98 | if reuse is False and layer_key in self.fisher_blocks: 99 | raise ValueError("FisherBlock for %s is already in LayerCollection." % 100 | (layer_key,)) 101 | 102 | # Insert fisher_block into self.fisher_blocks. 
103 | if layer_key in self.fisher_blocks: 104 | raise ValueError("Duplicate registration: {}".format(layer_key)) 105 | # Raise an error if any variable in layer_key has been registered in any 106 | # other blocks. 107 | variable_to_block = { 108 | var: (params, block) 109 | for (params, block) in self.fisher_blocks.items() 110 | for var in ensure_sequence(params) 111 | } 112 | for variable in ensure_sequence(layer_key): 113 | if variable in variable_to_block: 114 | prev_key, prev_block = variable_to_block[variable] 115 | raise ValueError( 116 | "Attempted to register layer_key {} with block {}, but variable {}" 117 | " was already registered in key {} with block {}.".format( 118 | layer_key, fisher_block, variable, prev_key, prev_block)) 119 | self.fisher_blocks[layer_key] = fisher_block 120 | return fisher_block 121 | 122 | def get_blocks(self): 123 | return self.fisher_blocks.values() 124 | 125 | def get_factors(self): 126 | return self.fisher_factors.values() 127 | 128 | def total_loss(self): 129 | return math_ops.add_n(tuple(loss.evaluate() for loss in self.losses)) 130 | 131 | def total_sampled_loss(self): 132 | return math_ops.add_n( 133 | tuple(loss.evaluate_on_sample() for loss in self.losses)) 134 | 135 | def register_fully_connected(self, 136 | params, 137 | inputs, 138 | outputs, 139 | approx=None, 140 | reuse=VARIABLE_SCOPE): 141 | if approx is None: 142 | approx = self.default_fully_connected_approximation 143 | 144 | if approx not in _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES: 145 | raise ValueError("Bad value {} for approx.".format(approx)) 146 | 147 | block_type = _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES[approx] 148 | has_bias = isinstance(params, (tuple, list)) 149 | 150 | block = self.register_block(params, block_type(self, has_bias), reuse=reuse) 151 | block.register_additional_minibatch(inputs, outputs) 152 | 153 | def register_conv2d(self, 154 | params, 155 | strides, 156 | padding, 157 | inputs, 158 | outputs, 159 | approx=None, 160 | reuse=VARIABLE_SCOPE): 161 | if approx is None: 162 | approx = self.default_conv2d_approximation 163 | 164 | if approx not in _CONV2D_APPROX_TO_BLOCK_TYPES: 165 | raise ValueError("Bad value {} for approx.".format(approx)) 166 | 167 | block_type = _CONV2D_APPROX_TO_BLOCK_TYPES[approx] 168 | block = self.register_block( 169 | params, block_type(self, params, strides, padding), reuse=reuse) 170 | block.register_additional_minibatch(inputs, outputs) 171 | 172 | def register_categorical_predictive_distribution(self, 173 | logits, 174 | seed=None, 175 | targets=None, 176 | name=None): 177 | name = name or self._graph.unique_name( 178 | "register_categorical_predictive_distribution") 179 | 180 | if name in self._loss_dict: 181 | raise KeyError( 182 | "Loss function named {} already exists. 
Set reuse=True to append " 183 | "another minibatch.".format(name)) 184 | loss = lf.CategoricalLogitsNegativeLogProbLoss( 185 | logits, targets=targets, seed=seed) 186 | self._loss_dict[name] = loss 187 | 188 | def register_normal_predictive_distribution(self, 189 | mean, 190 | var=0.5, 191 | seed=None, 192 | targets=None, 193 | name=None): 194 | name = name or self._graph.unique_name( 195 | "register_normal_predictive_distribution") 196 | if name in self._loss_dict: 197 | raise NotImplementedError( 198 | "Adding logits to an existing LossFunction not yet supported.") 199 | loss = lf.NormalMeanNegativeLogProbLoss( 200 | mean, var, targets=targets, seed=seed) 201 | self._loss_dict[name] = loss 202 | 203 | def register_multi_bernoulli_predictive_distribution(self, 204 | logits, 205 | seed=None, 206 | targets=None, 207 | name=None): 208 | 209 | name = name or self._graph.unique_name( 210 | "register_multi_bernoulli_predictive_distribution") 211 | if name in self._loss_dict: 212 | raise NotImplementedError( 213 | "Adding logits to an existing LossFunction not yet supported.") 214 | loss = lf.MultiBernoulliNegativeLogProbLoss( 215 | logits, targets=targets, seed=seed) 216 | self._loss_dict[name] = loss 217 | 218 | def make_or_get_factor(self, cls, args): 219 | try: 220 | hash(args) 221 | except TypeError: 222 | raise TypeError( 223 | ("Unable to use (cls, args) = ({}, {}) as a key in " 224 | "LayerCollection.fisher_factors. The pair cannot be hashed.").format( 225 | cls, args)) 226 | 227 | key = cls, args 228 | if key not in self.fisher_factors: 229 | colo = self._colocate_cov_ops_with_inputs 230 | with variable_scope.variable_scope(self._var_scope): 231 | self.fisher_factors[key] = cls(*args, colocate_cov_ops_with_inputs=colo) 232 | return self.fisher_factors[key] 233 | -------------------------------------------------------------------------------- /libs/kfac/loss_functions.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import abc 6 | 7 | import six 8 | 9 | from tensorflow.python.framework import tensor_shape 10 | from tensorflow.python.ops import array_ops 11 | from tensorflow.python.ops import math_ops 12 | from tensorflow.python.ops.distributions import bernoulli 13 | from tensorflow.python.ops.distributions import categorical 14 | from tensorflow.python.ops.distributions import normal 15 | 16 | 17 | @six.add_metaclass(abc.ABCMeta) 18 | class LossFunction(object): 19 | """Abstract core class for loss functions. 20 | Note that unlike typical loss functions used in neural networks these are 21 | summed and not averaged across cases in the batch, since this is what the 22 | users of this class (FisherEstimator and MatrixVectorProductComputer) will 23 | be expecting. The implication of this is that you will may want to 24 | normalize things like Fisher-vector products by the batch size when you 25 | use this class. It depends on the use case. 26 | """ 27 | 28 | @abc.abstractproperty 29 | def targets(self): 30 | """The targets being predicted by the model. 31 | Returns: 32 | None or Tensor of appropriate shape for calling self._evaluate() on. 33 | """ 34 | pass 35 | 36 | @abc.abstractproperty 37 | def inputs(self): 38 | """The inputs to the loss function (excluding the targets).""" 39 | pass 40 | 41 | @property 42 | def input_minibatches(self): 43 | """A `list` of inputs to the loss function, separated by minibatch. 
44 | Typically there will be one minibatch per tower in a multi-tower setup. 45 | Returns a list consisting of `self.inputs` by default; `LossFunction`s 46 | supporting registering multiple minibatches should override this method. 47 | Returns: 48 | A `list` of `Tensor`s representing 49 | """ 50 | return [self.inputs] 51 | 52 | @property 53 | def num_registered_minibatches(self): 54 | """Number of minibatches registered for this LossFunction. 55 | Typically equal to the number of towers in a multi-tower setup. 56 | Returns: 57 | An `int` representing the number of registered minibatches. 58 | """ 59 | return len(self.input_minibatches) 60 | 61 | def evaluate(self): 62 | """Evaluate the loss function on the targets.""" 63 | if self.targets is not None: 64 | # We treat the targets as "constant". It's only the inputs that get 65 | # "back-propped" through. 66 | return self._evaluate(array_ops.stop_gradient(self.targets)) 67 | else: 68 | raise Exception("Cannot evaluate losses with unspecified targets.") 69 | 70 | @abc.abstractmethod 71 | def _evaluate(self, targets): 72 | """Evaluates the negative log probability of the targets. 73 | Args: 74 | targets: Tensor that distribution can calculate log_prob() of. 75 | Returns: 76 | negative log probability of each target, summed across all targets. 77 | """ 78 | pass 79 | 80 | @abc.abstractmethod 81 | def multiply_hessian(self, vector): 82 | """Right-multiply a vector by the Hessian. 83 | Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives) 84 | of the loss function with respect to its inputs. 85 | Args: 86 | vector: The vector to multiply. Must be the same shape(s) as the 87 | 'inputs' property. 88 | Returns: 89 | The vector right-multiplied by the Hessian. Will be of the same shape(s) 90 | as the 'inputs' property. 91 | """ 92 | pass 93 | 94 | @abc.abstractmethod 95 | def multiply_hessian_factor(self, vector): 96 | """Right-multiply a vector by a factor B of the Hessian. 97 | Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives) 98 | of the loss function with respect to its inputs. Typically this will be 99 | block-diagonal across different cases in the batch, since the loss function 100 | is typically summed across cases. 101 | Note that B can be any matrix satisfying B * B^T = H where H is the Hessian, 102 | but will agree with the one used in the other methods of this class. 103 | Args: 104 | vector: The vector to multiply. Must be of the shape given by the 105 | 'hessian_factor_inner_shape' property. 106 | Returns: 107 | The vector right-multiplied by B. Will be of the same shape(s) as the 108 | 'inputs' property. 109 | """ 110 | pass 111 | 112 | @abc.abstractmethod 113 | def multiply_hessian_factor_transpose(self, vector): 114 | """Right-multiply a vector by the transpose of a factor B of the Hessian. 115 | Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives) 116 | of the loss function with respect to its inputs. Typically this will be 117 | block-diagonal across different cases in the batch, since the loss function 118 | is typically summed across cases. 119 | Note that B can be any matrix satisfying B * B^T = H where H is the Hessian, 120 | but will agree with the one used in the other methods of this class. 121 | Args: 122 | vector: The vector to multiply. Must be the same shape(s) as the 123 | 'inputs' property. 124 | Returns: 125 | The vector right-multiplied by B^T. Will be of the shape given by the 126 | 'hessian_factor_inner_shape' property. 
127 | """ 128 | pass 129 | 130 | @abc.abstractmethod 131 | def multiply_hessian_factor_replicated_one_hot(self, index): 132 | """Right-multiply a replicated-one-hot vector by a factor B of the Hessian. 133 | Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives) 134 | of the loss function with respect to its inputs. Typically this will be 135 | block-diagonal across different cases in the batch, since the loss function 136 | is typically summed across cases. 137 | A 'replicated-one-hot' vector means a tensor which, for each slice along the 138 | batch dimension (assumed to be dimension 0), is 1.0 in the entry 139 | corresponding to the given index and 0 elsewhere. 140 | Note that B can be any matrix satisfying B * B^T = H where H is the Hessian, 141 | but will agree with the one used in the other methods of this class. 142 | Args: 143 | index: A tuple representing in the index of the entry in each slice that 144 | is 1.0. Note that len(index) must be equal to the number of elements 145 | of the 'hessian_factor_inner_shape' tensor minus one. 146 | Returns: 147 | The vector right-multiplied by B^T. Will be of the same shape(s) as the 148 | 'inputs' property. 149 | """ 150 | pass 151 | 152 | @abc.abstractproperty 153 | def hessian_factor_inner_shape(self): 154 | """The shape of the tensor returned by multiply_hessian_factor.""" 155 | pass 156 | 157 | @abc.abstractproperty 158 | def hessian_factor_inner_static_shape(self): 159 | """Static version of hessian_factor_inner_shape.""" 160 | pass 161 | 162 | 163 | @six.add_metaclass(abc.ABCMeta) 164 | class NegativeLogProbLoss(LossFunction): 165 | """Abstract core class for loss functions that are negative log probs.""" 166 | 167 | def __init__(self, seed=None): 168 | self._default_seed = seed 169 | super(NegativeLogProbLoss, self).__init__() 170 | 171 | @property 172 | def inputs(self): 173 | return self.params 174 | 175 | @abc.abstractproperty 176 | def params(self): 177 | """Parameters to the underlying distribution.""" 178 | pass 179 | 180 | @abc.abstractmethod 181 | def multiply_fisher(self, vector): 182 | """Right-multiply a vector by the Fisher. 183 | Args: 184 | vector: The vector to multiply. Must be the same shape(s) as the 185 | 'inputs' property. 186 | Returns: 187 | The vector right-multiplied by the Fisher. Will be of the same shape(s) 188 | as the 'inputs' property. 189 | """ 190 | pass 191 | 192 | @abc.abstractmethod 193 | def multiply_fisher_factor(self, vector): 194 | """Right-multiply a vector by a factor B of the Fisher. 195 | Here the 'Fisher' is the Fisher information matrix (i.e. expected outer- 196 | product of gradients) with respect to the parameters of the underlying 197 | probability distribtion (whose log-prob defines the loss). Typically this 198 | will be block-diagonal across different cases in the batch, since the 199 | distribution is usually (but not always) conditionally iid across different 200 | cases. 201 | Note that B can be any matrix satisfying B * B^T = F where F is the Fisher, 202 | but will agree with the one used in the other methods of this class. 203 | Args: 204 | vector: The vector to multiply. Must be of the shape given by the 205 | 'fisher_factor_inner_shape' property. 206 | Returns: 207 | The vector right-multiplied by B. Will be of the same shape(s) as the 208 | 'inputs' property. 209 | """ 210 | pass 211 | 212 | @abc.abstractmethod 213 | def multiply_fisher_factor_transpose(self, vector): 214 | """Right-multiply a vector by the transpose of a factor B of the Fisher. 
215 | Here the 'Fisher' is the Fisher information matrix (i.e. expected outer- 216 | product of gradients) with respect to the parameters of the underlying 217 | probability distribution (whose log-prob defines the loss). Typically this 218 | will be block-diagonal across different cases in the batch, since the 219 | distribution is usually (but not always) conditionally iid across different 220 | cases. 221 | Note that B can be any matrix satisfying B * B^T = F where F is the Fisher, 222 | but will agree with the one used in the other methods of this class. 223 | Args: 224 | vector: The vector to multiply. Must be the same shape(s) as the 225 | 'inputs' property. 226 | Returns: 227 | The vector right-multiplied by B^T. Will be of the shape given by the 228 | 'fisher_factor_inner_shape' property. 229 | """ 230 | pass 231 | 232 | @abc.abstractmethod 233 | def multiply_fisher_factor_replicated_one_hot(self, index): 234 | """Right-multiply a replicated-one-hot vector by a factor B of the Fisher. 235 | Here the 'Fisher' is the Fisher information matrix (i.e. expected outer- 236 | product of gradients) with respect to the parameters of the underlying 237 | probability distribution (whose log-prob defines the loss). Typically this 238 | will be block-diagonal across different cases in the batch, since the 239 | distribution is usually (but not always) conditionally iid across different 240 | cases. 241 | A 'replicated-one-hot' vector means a tensor which, for each slice along the 242 | batch dimension (assumed to be dimension 0), is 1.0 in the entry 243 | corresponding to the given index and 0 elsewhere. 244 | Note that B can be any matrix satisfying B * B^T = F where F is the Fisher, 245 | but will agree with the one used in the other methods of this class. 246 | Args: 247 | index: A tuple representing the index of the entry in each slice that 248 | is 1.0. Note that len(index) must be equal to the number of elements 249 | of the 'fisher_factor_inner_shape' tensor minus one. 250 | Returns: 251 | The vector right-multiplied by B. Will be of the same shape(s) as the 252 | 'inputs' property. 253 | """ 254 | pass 255 | 256 | @abc.abstractproperty 257 | def fisher_factor_inner_shape(self): 258 | """The shape of the tensor returned by multiply_fisher_factor.""" 259 | pass 260 | 261 | @abc.abstractproperty 262 | def fisher_factor_inner_static_shape(self): 263 | """Static version of fisher_factor_inner_shape.""" 264 | pass 265 | 266 | @abc.abstractmethod 267 | def sample(self, seed): 268 | """Sample 'targets' from the underlying distribution.""" 269 | pass 270 | 271 | def evaluate_on_sample(self, seed=None): 272 | """Evaluates the negative log probability of a random sample. 273 | Args: 274 | seed: int or None. Random seed for this draw from the distribution. 275 | Returns: 276 | Negative log probability of sampled targets, summed across examples. 277 | """ 278 | if seed is None: 279 | seed = self._default_seed 280 | # We treat the targets as "constant". It's only the inputs that get 281 | # "back-propped" through. 282 | return self._evaluate(array_ops.stop_gradient(self.sample(seed))) 283 | 284 | 285 | # TODO(jamesmartens): should this just inherit from object to avoid "diamond" 286 | # inheritance, or is there a better way? 287 | class NaturalParamsNegativeLogProbLoss(NegativeLogProbLoss): 288 | """Base class for neg log prob losses whose inputs are 'natural' parameters. 289 | Note that the Hessian and Fisher for natural parameters of exponential- 290 | family models are the same, hence the purpose of this class.
291 | See here: https://arxiv.org/abs/1412.1193 292 | 'Natural parameters' are defined for exponential-family models. See for 293 | example: https://en.wikipedia.org/wiki/Exponential_family 294 | """ 295 | 296 | def multiply_hessian(self, vector): 297 | return self.multiply_fisher(vector) 298 | 299 | def multiply_hessian_factor(self, vector): 300 | return self.multiply_fisher_factor(vector) 301 | 302 | def multiply_hessian_factor_transpose(self, vector): 303 | return self.multiply_fisher_factor_transpose(vector) 304 | 305 | def multiply_hessian_factor_replicated_one_hot(self, index): 306 | return self.multiply_fisher_factor_replicated_one_hot(index) 307 | 308 | @property 309 | def hessian_factor_inner_shape(self): 310 | return self.fisher_factor_inner_shape 311 | 312 | @property 313 | def hessian_factor_inner_static_shape(self): 314 | return self.fisher_factor_inner_static_shape 315 | 316 | 317 | class DistributionNegativeLogProbLoss(NegativeLogProbLoss): 318 | """Base class for neg log prob losses that use the TF Distribution classes.""" 319 | 320 | def __init__(self, seed=None): 321 | super(DistributionNegativeLogProbLoss, self).__init__(seed=seed) 322 | 323 | @abc.abstractproperty 324 | def dist(self): 325 | """The underlying tf.distributions.Distribution.""" 326 | pass 327 | 328 | def _evaluate(self, targets): 329 | return -math_ops.reduce_sum(self.dist.log_prob(targets)) 330 | 331 | def sample(self, seed): 332 | return self.dist.sample(seed=seed) 333 | 334 | 335 | class NormalMeanNegativeLogProbLoss(DistributionNegativeLogProbLoss, 336 | NaturalParamsNegativeLogProbLoss): 337 | """Neg log prob loss for a normal distribution parameterized by a mean vector. 338 | Note that the covariance is treated as a constant 'var' times the identity. 339 | Also note that the Fisher for such a normal distribution with respect to the mean 340 | parameter is given by: 341 | F = (1/var) * I 342 | See for example https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf. 343 | """ 344 | 345 | def __init__(self, mean, var=0.5, targets=None, seed=None): 346 | self._mean = mean 347 | self._var = var 348 | self._targets = targets 349 | super(NormalMeanNegativeLogProbLoss, self).__init__(seed=seed) 350 | 351 | @property 352 | def targets(self): 353 | return self._targets 354 | 355 | @property 356 | def dist(self): 357 | return normal.Normal(loc=self._mean, scale=math_ops.sqrt(self._var)) 358 | 359 | @property 360 | def params(self): 361 | return self._mean 362 | 363 | def multiply_fisher(self, vector): 364 | return (1.
/ self._var) * vector 365 | 366 | def multiply_fisher_factor(self, vector): 367 | return self._var**-0.5 * vector 368 | 369 | def multiply_fisher_factor_transpose(self, vector): 370 | return self.multiply_fisher_factor(vector) # it's symmetric in this case 371 | 372 | def multiply_fisher_factor_replicated_one_hot(self, index): 373 | assert len(index) == 1, "Length of index was {}".format(len(index)) 374 | ones_slice = array_ops.expand_dims( 375 | array_ops.ones(array_ops.shape(self._mean)[:1], dtype=self._mean.dtype), 376 | axis=-1) 377 | output_slice = self._var**-0.5 * ones_slice 378 | return insert_slice_in_zeros(output_slice, 1, int(self._mean.shape[1]), 379 | index[0]) 380 | 381 | @property 382 | def fisher_factor_inner_shape(self): 383 | return array_ops.shape(self._mean) 384 | 385 | @property 386 | def fisher_factor_inner_static_shape(self): 387 | return self._mean.shape 388 | 389 | 390 | class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss): 391 | """Negative log prob loss for a normal distribution with mean and variance. 392 | This class parameterizes a multivariate normal distribution with n independent 393 | dimensions. Unlike `NormalMeanNegativeLogProbLoss`, this class does not 394 | assume the variance is held constant. The Fisher Information for n = 1 395 | is given by, 396 | F = [[1 / variance, 0], 397 | [ 0, 0.5 / variance^2]] 398 | where the parameters of the distribution are concatenated into a single 399 | vector as [mean, variance]. For n > 1, the mean parameter vector is 400 | concatenated with the variance parameter vector. 401 | See https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf for derivation. 402 | """ 403 | 404 | def __init__(self, mean, variance, targets=None, seed=None): 405 | assert len(mean.shape) == 2, "Expect 2D mean tensor." 406 | assert len(variance.shape) == 2, "Expect 2D variance tensor." 407 | self._mean = mean 408 | self._variance = variance 409 | self._scale = math_ops.sqrt(variance) 410 | self._targets = targets 411 | super(NormalMeanVarianceNegativeLogProbLoss, self).__init__(seed=seed) 412 | 413 | @property 414 | def targets(self): 415 | return self._targets 416 | 417 | @property 418 | def dist(self): 419 | return normal.Normal(loc=self._mean, scale=self._scale) 420 | 421 | @property 422 | def params(self): 423 | return self._mean, self._variance 424 | 425 | def _concat(self, mean, variance): 426 | return array_ops.concat([mean, variance], axis=-1) 427 | 428 | def _split(self, params): 429 | return array_ops.split(params, 2, axis=-1) 430 | 431 | @property 432 | def _fisher_mean(self): 433 | return 1. / self._variance 434 | 435 | @property 436 | def _fisher_mean_factor(self): 437 | return 1. / self._scale 438 | 439 | @property 440 | def _fisher_var(self): 441 | return 1. / (2 * math_ops.square(self._variance)) 442 | 443 | @property 444 | def _fisher_var_factor(self): 445 | return 1. / (math_ops.sqrt(2.) 
* self._variance) 446 | 447 | def multiply_fisher(self, vecs): 448 | mean_vec, var_vec = vecs 449 | return (self._fisher_mean * mean_vec, self._fisher_var * var_vec) 450 | 451 | def multiply_fisher_factor(self, vecs): 452 | mean_vec, var_vec = self._split(vecs) 453 | return (self._fisher_mean_factor * mean_vec, 454 | self._fisher_var_factor * var_vec) 455 | 456 | def multiply_fisher_factor_transpose(self, vecs): 457 | mean_vec, var_vec = vecs 458 | return self._concat(self._fisher_mean_factor * mean_vec, 459 | self._fisher_var_factor * var_vec) 460 | 461 | def multiply_fisher_factor_replicated_one_hot(self, index): 462 | assert len(index) == 1, "Length of index was {}".format(len(index)) 463 | index = index[0] 464 | 465 | if index < int(self._mean.shape[-1]): 466 | # Index corresponds to mean parameter. 467 | mean_slice = self._fisher_mean_factor[:, index] 468 | mean_slice = array_ops.expand_dims(mean_slice, axis=-1) 469 | mean_output = insert_slice_in_zeros(mean_slice, 1, int( 470 | self._mean.shape[1]), index) 471 | var_output = array_ops.zeros_like(mean_output) 472 | else: 473 | index -= int(self._mean.shape[-1]) 474 | # Index corresponds to variance parameter. 475 | var_slice = self._fisher_var_factor[:, index] 476 | var_slice = array_ops.expand_dims(var_slice, axis=-1) 477 | var_output = insert_slice_in_zeros(var_slice, 1, 478 | int(self._variance.shape[1]), index) 479 | mean_output = array_ops.zeros_like(var_output) 480 | 481 | return mean_output, var_output 482 | 483 | @property 484 | def fisher_factor_inner_shape(self): 485 | return array_ops.concat( 486 | [ 487 | array_ops.shape(self._mean)[:-1], 488 | 2 * array_ops.shape(self._mean)[-1:] 489 | ], 490 | axis=0) 491 | 492 | @property 493 | def fisher_factor_inner_static_shape(self): 494 | shape = self._mean.shape.as_list() 495 | return tensor_shape.TensorShape(shape[-1:] + [2 * shape[-1]]) 496 | 497 | def multiply_hessian(self, vector): 498 | raise NotImplementedError() 499 | 500 | def multiply_hessian_factor(self, vector): 501 | raise NotImplementedError() 502 | 503 | def multiply_hessian_factor_transpose(self, vector): 504 | raise NotImplementedError() 505 | 506 | def multiply_hessian_factor_replicated_one_hot(self, index): 507 | raise NotImplementedError() 508 | 509 | @property 510 | def hessian_factor_inner_shape(self): 511 | raise NotImplementedError() 512 | 513 | @property 514 | def hessian_factor_inner_static_shape(self): 515 | raise NotImplementedError() 516 | 517 | 518 | class CategoricalLogitsNegativeLogProbLoss(DistributionNegativeLogProbLoss, 519 | NaturalParamsNegativeLogProbLoss): 520 | """Neg log prob loss for a categorical distribution parameterized by logits. 521 | Note that the Fisher (for a single case) of a categorical distribution, with 522 | respect to the natural parameters (i.e. the logits), is given by: 523 | F = diag(p) - p*p^T 524 | where p = softmax(logits). F can be factorized as F = B * B^T where 525 | B = diag(q) - p*q^T 526 | where q is the entry-wise square root of p. This is easy to verify using the 527 | fact that q^T*q = 1. 528 | """ 529 | 530 | def __init__(self, logits, targets=None, seed=None): 531 | """Instantiates a CategoricalLogitsNegativeLogProbLoss. 532 | Args: 533 | logits: Tensor of shape [batch_size, output_size]. Parameters for 534 | underlying distribution. 535 | targets: None or Tensor of shape [output_size]. Each elements contains an 536 | index in [0, output_size). 537 | seed: int or None. Default random seed when sampling. 
538 | """ 539 | self._logits_components = [] 540 | self._targets_components = [] 541 | self.register_additional_minibatch(logits, targets=targets) 542 | super(CategoricalLogitsNegativeLogProbLoss, self).__init__(seed=seed) 543 | 544 | def register_additional_minibatch(self, logits, targets=None): 545 | """Register an additiona minibatch's worth of parameters. 546 | Args: 547 | logits: Tensor of shape [batch_size, output_size]. Parameters for 548 | underlying distribution. 549 | targets: None or Tensor of shape [batch_size, output_size]. Each row must 550 | be a one-hot vector. 551 | """ 552 | self._logits_components.append(logits) 553 | self._targets_components.append(targets) 554 | 555 | @property 556 | def _logits(self): 557 | return array_ops.concat(self._logits_components, axis=0) 558 | 559 | @property 560 | def input_minibatches(self): 561 | return self._logits_components 562 | 563 | @property 564 | def targets(self): 565 | if all(target is None for target in self._targets_components): 566 | return None 567 | return array_ops.concat(self._targets_components, axis=0) 568 | 569 | @property 570 | def dist(self): 571 | return categorical.Categorical(logits=self._logits) 572 | 573 | @property 574 | def _probs(self): 575 | return self.dist.probs 576 | 577 | @property 578 | def _sqrt_probs(self): 579 | return math_ops.sqrt(self._probs) 580 | 581 | @property 582 | def params(self): 583 | return self._logits 584 | 585 | def multiply_fisher(self, vector): 586 | probs = self._probs 587 | return vector * probs - probs * math_ops.reduce_sum( 588 | vector * probs, axis=1, keep_dims=True) 589 | 590 | def multiply_fisher_factor(self, vector): 591 | probs = self._probs 592 | sqrt_probs = self._sqrt_probs 593 | return sqrt_probs * vector - probs * math_ops.reduce_sum( 594 | sqrt_probs * vector, axis=1, keep_dims=True) 595 | 596 | def multiply_fisher_factor_transpose(self, vector): 597 | probs = self._probs 598 | sqrt_probs = self._sqrt_probs 599 | return sqrt_probs * vector - sqrt_probs * math_ops.reduce_sum( 600 | probs * vector, axis=1, keep_dims=True) 601 | 602 | def multiply_fisher_factor_replicated_one_hot(self, index): 603 | assert len(index) == 1, "Length of index was {}".format(len(index)) 604 | probs = self._probs 605 | sqrt_probs = self._sqrt_probs 606 | sqrt_probs_slice = array_ops.expand_dims(sqrt_probs[:, index[0]], -1) 607 | padded_slice = insert_slice_in_zeros(sqrt_probs_slice, 1, 608 | int(sqrt_probs.shape[1]), index[0]) 609 | return padded_slice - probs * sqrt_probs_slice 610 | 611 | @property 612 | def fisher_factor_inner_shape(self): 613 | return array_ops.shape(self._logits) 614 | 615 | @property 616 | def fisher_factor_inner_static_shape(self): 617 | return self._logits.shape 618 | 619 | 620 | class MultiBernoulliNegativeLogProbLoss(DistributionNegativeLogProbLoss, 621 | NaturalParamsNegativeLogProbLoss): 622 | """Neg log prob loss for multiple Bernoulli distributions param'd by logits. 623 | Represents N independent Bernoulli distributions where N = len(logits). 
Its 624 | Fisher Information matrix is given by, 625 | F = diag(p * (1-p)) 626 | p = sigmoid(logits) 627 | As F is diagonal with positive entries, its factor B is, 628 | B = diag(sqrt(p * (1-p))) 629 | """ 630 | 631 | def __init__(self, logits, targets=None, seed=None): 632 | self._logits = logits 633 | self._targets = targets 634 | super(MultiBernoulliNegativeLogProbLoss, self).__init__(seed=seed) 635 | 636 | @property 637 | def targets(self): 638 | return self._targets 639 | 640 | @property 641 | def dist(self): 642 | return bernoulli.Bernoulli(logits=self._logits) 643 | 644 | @property 645 | def _probs(self): 646 | return self.dist.probs 647 | 648 | @property 649 | def params(self): 650 | return self._logits 651 | 652 | def multiply_fisher(self, vector): 653 | return self._probs * (1 - self._probs) * vector 654 | 655 | def multiply_fisher_factor(self, vector): 656 | return math_ops.sqrt(self._probs * (1 - self._probs)) * vector 657 | 658 | def multiply_fisher_factor_transpose(self, vector): 659 | return self.multiply_fisher_factor(vector) # it's symmetric in this case 660 | 661 | def multiply_fisher_factor_replicated_one_hot(self, index): 662 | assert len(index) == 1, "Length of index was {}".format(len(index)) 663 | probs_slice = array_ops.expand_dims(self._probs[:, index[0]], -1) 664 | output_slice = math_ops.sqrt(probs_slice * (1 - probs_slice)) 665 | return insert_slice_in_zeros(output_slice, 1, int(self._logits.shape[1]), 666 | index[0]) 667 | 668 | @property 669 | def fisher_factor_inner_shape(self): 670 | return array_ops.shape(self._logits) 671 | 672 | @property 673 | def fisher_factor_inner_static_shape(self): 674 | return self._logits.shape 675 | 676 | 677 | def insert_slice_in_zeros(slice_to_insert, dim, dim_size, position): 678 | """Inserts slice into a larger tensor of zeros. 679 | Forms a new tensor which is the same shape as slice_to_insert, except that 680 | the dimension given by 'dim' is expanded to the size given by 'dim_size'. 681 | 'position' determines the position (index) at which to insert the slice within 682 | that dimension. 683 | Assumes slice_to_insert.shape[dim] = 1. 684 | Args: 685 | slice_to_insert: The slice to insert. 686 | dim: The dimension which to expand with zeros. 687 | dim_size: The new size of the 'dim' dimension. 688 | position: The position of 'slice_to_insert' in the new tensor. 689 | Returns: 690 | The new tensor. 691 | Raises: 692 | ValueError: If the slice's shape at the given dim is not 1. 
693 | """ 694 | slice_shape = slice_to_insert.shape 695 | if slice_shape[dim] != 1: 696 | raise ValueError("Expected slice_to_insert.shape to have {} dim of 1, but " 697 | "was {}".format(dim, slice_to_insert.shape[dim])) 698 | 699 | before = [0] * int(len(slice_shape)) 700 | after = before[:] 701 | before[dim] = position 702 | after[dim] = dim_size - position - 1 703 | 704 | return array_ops.pad(slice_to_insert, list(zip(before, after))) 705 | -------------------------------------------------------------------------------- /libs/kfac/optimizer.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from tensorflow.python.framework import ops 6 | from tensorflow.python.ops import array_ops 7 | from tensorflow.python.ops import gen_array_ops 8 | from tensorflow.python.ops import math_ops 9 | from tensorflow.python.ops import variable_scope 10 | from tensorflow.python.ops import variables as tf_variables 11 | from tensorflow.python.training import gradient_descent 12 | from libs.kfac import estimator as est 13 | 14 | class KFACOptimizer(gradient_descent.GradientDescentOptimizer): 15 | """ 16 | KFAC Optimizer 17 | """ 18 | 19 | def __init__(self, 20 | learning_rate, 21 | damping, 22 | layer_collection, 23 | cov_ema_decay=None, 24 | var_list=None, 25 | momentum=0., 26 | momentum_type="regular", 27 | weight_decay=0., 28 | weight_decay_type="l2", 29 | weight_list="all", 30 | norm_constraint=None, 31 | name="KFAC", 32 | estimation_mode="gradients", 33 | colocate_gradients_with_ops=False, 34 | cov_devices=None, 35 | inv_devices=None): 36 | 37 | variables = var_list 38 | if variables is None: 39 | variables = tf_variables.trainable_variables() 40 | self.variables = variables 41 | self.damping = damping 42 | 43 | momentum_type = momentum_type.lower() 44 | legal_momentum_types = ["regular", "adam"] 45 | 46 | if momentum_type not in legal_momentum_types: 47 | raise ValueError("Unsupported momentum type {}. Must be one of {}." 48 | .format(momentum_type, legal_momentum_types)) 49 | if momentum_type != "regular" and norm_constraint is not None: 50 | raise ValueError("Update clipping is only supported with momentum" 51 | "type 'regular'.") 52 | 53 | weight_decay_type = weight_decay_type.lower() 54 | legal_weight_decay_types = ["wd", "l2"] 55 | 56 | if weight_decay_type not in legal_weight_decay_types: 57 | raise ValueError("Unsupported weight decay type {}. Must be one of {}." 
58 | .format(weight_decay_type, legal_weight_decay_types)) 59 | 60 | self._momentum = momentum 61 | self._momentum_type = momentum_type 62 | self._weight_decay = weight_decay 63 | self._weight_decay_type = weight_decay_type 64 | self._weight_list = weight_list 65 | self._norm_constraint = norm_constraint 66 | 67 | self._batch_size = array_ops.shape(layer_collection.losses[0].inputs)[0] 68 | self._losses = layer_collection.losses 69 | 70 | with variable_scope.variable_scope(name): 71 | self._fisher_est = est.FisherEstimator( 72 | variables, 73 | cov_ema_decay, 74 | damping, 75 | layer_collection, 76 | estimation_mode=estimation_mode, 77 | colocate_gradients_with_ops=colocate_gradients_with_ops, 78 | cov_devices=cov_devices, 79 | inv_devices=inv_devices) 80 | 81 | self.cov_update_op = self._fisher_est.cov_update_op 82 | self.inv_update_op = self._fisher_est.inv_update_op 83 | self.inv_update_dict = self._fisher_est.inv_updates_dict 84 | 85 | self.init_cov_op = self._fisher_est.init_cov_op 86 | 87 | super(KFACOptimizer, self).__init__(learning_rate, name=name) 88 | 89 | def minimize(self, *args, **kwargs): 90 | kwargs["var_list"] = kwargs.get("var_list") or self.variables 91 | if set(kwargs["var_list"]) != set(self.variables): 92 | raise ValueError("var_list doesn't match with set of Fisher-estimating " 93 | "variables.") 94 | return super(KFACOptimizer, self).minimize(*args, **kwargs) 95 | 96 | def compute_gradients(self, *args, **kwargs): 97 | # args[1] could be our var_list 98 | if len(args) > 1: 99 | var_list = args[1] 100 | else: 101 | kwargs["var_list"] = kwargs.get("var_list") or self.variables 102 | var_list = kwargs["var_list"] 103 | if set(var_list) != set(self.variables): 104 | raise ValueError("var_list doesn't match with set of Fisher-estimating " 105 | "variables.") 106 | return super(KFACOptimizer, self).compute_gradients(*args, **kwargs) 107 | 108 | def apply_gradients(self, grads_and_vars, *args, **kwargs): 109 | grads_and_vars = list(grads_and_vars) 110 | 111 | if self._weight_decay_type == "l2" and self._weight_decay > 0.0: 112 | grads_and_vars = self._add_weight_decay(grads_and_vars) 113 | 114 | steps_and_vars = self._compute_update_steps(grads_and_vars) 115 | 116 | if self._weight_decay_type == "wd" and self._weight_decay > 0.0: 117 | steps_and_vars = self._add_weight_decay(steps_and_vars) 118 | 119 | return super(KFACOptimizer, self).apply_gradients(steps_and_vars, 120 | *args, **kwargs) 121 | 122 | def _add_weight_decay(self, vecs_and_vars): 123 | if self._weight_list == "all": 124 | print("all") 125 | return [(vec + self._weight_decay * gen_array_ops.stop_gradient(var), var) 126 | for vec, var in vecs_and_vars] 127 | elif self._weight_list == "last": 128 | print("last") 129 | grad_list = [] 130 | for vec, var in vecs_and_vars: 131 | if 'fc' not in var.name: 132 | grad_list.append((vec, var)) 133 | else: 134 | grad_list.append( 135 | (vec + self._weight_decay * 136 | gen_array_ops.stop_gradient(var), var)) 137 | return grad_list 138 | else: 139 | print("conv") 140 | grad_list = [] 141 | for vec, var in vecs_and_vars: 142 | if 'fc' in var.name: 143 | grad_list.append((vec, var)) 144 | else: 145 | grad_list.append( 146 | (vec + self._weight_decay * 147 | gen_array_ops.stop_gradient(var), var)) 148 | return grad_list 149 | 150 | def _compute_update_steps(self, grads_and_vars): 151 | if self._momentum_type == "regular": 152 | precon_grads_and_vars = self._fisher_est.multiply_inverse(grads_and_vars) 153 | 154 | # Apply "KL clipping" if asked for. 
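# "KL clipping": _squared_fisher_norm below computes g^T F^{-1} g from the gradients and their preconditioned versions, and _update_clip_coeff rescales the update so that learning_rate^2 * g^T F^{-1} g (proportional to a second-order estimate of the KL divergence between the predictive distributions before and after the step) never exceeds norm_constraint (presumably the kl* values appearing in the config file names).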
155 | if self._norm_constraint is not None: 156 | precon_grads_and_vars = self._clip_updates(grads_and_vars, 157 | precon_grads_and_vars) 158 | 159 | # Update the velocity with this and return it as the step. 160 | return self._update_velocities(precon_grads_and_vars, self._momentum) 161 | elif self._momentum_type == "adam": 162 | # Update velocity. 163 | velocities_and_vars = self._update_velocities(grads_and_vars, 164 | self._momentum) 165 | # Return "preconditioned" velocity vector as the step. 166 | precon_grads_and_vars = self._fisher_est.multiply_inverse(velocities_and_vars) 167 | 168 | return precon_grads_and_vars 169 | 170 | def _squared_fisher_norm(self, grads_and_vars, precon_grads_and_vars): 171 | for (_, gvar), (_, pgvar) in zip(grads_and_vars, precon_grads_and_vars): 172 | if gvar is not pgvar: 173 | raise ValueError("The variables referenced by the two arguments " 174 | "must match.") 175 | terms = [ 176 | math_ops.reduce_sum(grad * pgrad) 177 | for (grad, _), (pgrad, _) in zip(grads_and_vars, precon_grads_and_vars) 178 | ] 179 | return math_ops.reduce_sum(terms) 180 | 181 | def _update_clip_coeff(self, grads_and_vars, precon_grads_and_vars): 182 | sq_norm_grad = self._squared_fisher_norm(grads_and_vars, 183 | precon_grads_and_vars) 184 | sq_norm_up = sq_norm_grad * self._learning_rate**2 185 | return math_ops.minimum(1., math_ops.sqrt(self._norm_constraint / sq_norm_up)) 186 | 187 | def _clip_updates(self, grads_and_vars, precon_grads_and_vars): 188 | coeff = self._update_clip_coeff(grads_and_vars, precon_grads_and_vars) 189 | return [(pgrad * coeff, var) for pgrad, var in precon_grads_and_vars] 190 | 191 | def _update_velocities(self, vecs_and_vars, decay, vec_coeff=1.0): 192 | def _update_velocity(vec, var): 193 | velocity = self._zeros_slot(var, "velocity", self._name) 194 | with ops.colocate_with(velocity): 195 | # Compute the new velocity for this variable. 196 | new_velocity = decay * velocity + vec_coeff * vec 197 | 198 | # Save the updated velocity. 199 | return (array_ops.identity(velocity.assign(new_velocity)), var) 200 | 201 | # Go through variable and update its associated part of the velocity vector. 202 | return [_update_velocity(vec, var) for vec, var in vecs_and_vars] 203 | -------------------------------------------------------------------------------- /libs/kfac/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | 7 | from tensorflow.python.framework import dtypes 8 | from tensorflow.python.framework import ops 9 | from tensorflow.python.ops import array_ops 10 | from tensorflow.python.ops import gradients_impl 11 | from tensorflow.python.ops import linalg_ops 12 | from tensorflow.python.ops import math_ops 13 | from tensorflow.python.ops import gen_math_ops 14 | from tensorflow.python.ops import random_ops 15 | from collections import OrderedDict 16 | 17 | # Method used for inverting matrices. 
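The `posdef_inv_*` helpers defined later in this file all compute `inverse(tensor + damping * identity)`, differing only in the matrix factorization they use. A small NumPy sanity check of the Cholesky and eigendecomposition routes (illustrative only; the matrix size and damping value are arbitrary, and the clamping of eigenvalues at `damping` mirrors what `posdef_inv_eig` does below):

```
import numpy as np

np.random.seed(0)
a = np.random.randn(5, 5)
cov = a @ a.T                      # a positive semidefinite "covariance"
damping = 1e-3
damped = cov + damping * np.eye(5)
target = np.linalg.inv(damped)     # what posdef_inv_matrix_inverse computes

# Cholesky route (posdef_inv_cholesky): solve damped @ X = I via L L^T.
chol = np.linalg.cholesky(damped)
x_chol = np.linalg.solve(chol.T, np.linalg.solve(chol, np.eye(5)))

# Eigendecomposition route (posdef_inv_eig), including the clamping of
# eigenvalues at `damping` used in this repo's version.
evals, evecs = np.linalg.eigh(damped)
evals = np.maximum(evals, damping)
x_eig = (evecs / evals) @ evecs.T

print(np.allclose(x_chol, target), np.allclose(x_eig, target))  # True True
```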
18 | POSDEF_INV_METHOD = "eig" 19 | POSDEF_EIG_METHOD = "self_adjoint" 20 | 21 | 22 | def set_global_constants(posdef_inv_method=None): 23 | """Sets various global constants used by the classes in this module.""" 24 | global POSDEF_INV_METHOD 25 | 26 | if posdef_inv_method is not None: 27 | POSDEF_INV_METHOD = posdef_inv_method 28 | 29 | 30 | class SequenceDict(object): 31 | """A dict convenience wrapper that allows getting/setting with sequences.""" 32 | 33 | def __init__(self, iterable=None): 34 | self._dict = dict(iterable or []) 35 | 36 | def __getitem__(self, key_or_keys): 37 | if isinstance(key_or_keys, (tuple, list)): 38 | return list(map(self.__getitem__, key_or_keys)) 39 | else: 40 | return self._dict[key_or_keys] 41 | 42 | def __setitem__(self, key_or_keys, val_or_vals): 43 | if isinstance(key_or_keys, (tuple, list)): 44 | for key, value in zip(key_or_keys, val_or_vals): 45 | self[key] = value 46 | else: 47 | self._dict[key_or_keys] = val_or_vals 48 | 49 | def items(self): 50 | return list(self._dict.items()) 51 | 52 | 53 | def tensors_to_column(tensors): 54 | """Converts a tensor or list of tensors to a column vector. 55 | Args: 56 | tensors: A tensor or list of tensors. 57 | Returns: 58 | The tensors reshaped into vectors and stacked on top of each other. 59 | """ 60 | if isinstance(tensors, (tuple, list)): 61 | return array_ops.concat( 62 | tuple(array_ops.reshape(tensor, [-1, 1]) for tensor in tensors), axis=0) 63 | else: 64 | return array_ops.reshape(tensors, [-1, 1]) 65 | 66 | 67 | def column_to_tensors(tensors_template, colvec): 68 | """Converts a column vector back to the shape of the given template. 69 | Args: 70 | tensors_template: A tensor or list of tensors. 71 | colvec: A 2d column vector with the same shape as the value of 72 | tensors_to_column(tensors_template). 73 | Returns: 74 | X, where X is tensor or list of tensors with the properties: 75 | 1) tensors_to_column(X) = colvec 76 | 2) X (or its elements) have the same shape as tensors_template (or its 77 | elements) 78 | """ 79 | if isinstance(tensors_template, (tuple, list)): 80 | offset = 0 81 | tensors = [] 82 | for tensor_template in tensors_template: 83 | sz = np.prod(tensor_template.shape.as_list(), dtype=np.int32) 84 | tensor = array_ops.reshape(colvec[offset:(offset + sz)], 85 | tensor_template.shape) 86 | tensors.append(tensor) 87 | offset += sz 88 | 89 | tensors = tuple(tensors) 90 | else: 91 | tensors = array_ops.reshape(colvec, tensors_template.shape) 92 | 93 | return tensors 94 | 95 | 96 | def kronecker_product(mat1, mat2): 97 | """Computes the Kronecker product two matrices.""" 98 | m1, n1 = mat1.get_shape().as_list() 99 | mat1_rsh = array_ops.reshape(mat1, [m1, 1, n1, 1]) 100 | m2, n2 = mat2.get_shape().as_list() 101 | mat2_rsh = array_ops.reshape(mat2, [1, m2, 1, n2]) 102 | return array_ops.reshape(mat1_rsh * mat2_rsh, [m1 * m2, n1 * n2]) 103 | 104 | 105 | def layer_params_to_mat2d(vector): 106 | """Converts a vector shaped like layer parameters to a 2D matrix. 107 | In particular, we reshape the weights/filter component of the vector to be 108 | 2D, flattening all leading (input) dimensions. If there is a bias component, 109 | we concatenate it to the reshaped weights/filter component. 110 | Args: 111 | vector: A Tensor or pair of Tensors shaped like layer parameters. 112 | Returns: 113 | A 2D Tensor with the same coefficients and the same output dimension. 
114 | """ 115 | if isinstance(vector, (tuple, list)): 116 | w_part, b_part = vector 117 | w_part_reshaped = array_ops.reshape(w_part, 118 | [-1, w_part.shape.as_list()[-1]]) 119 | return array_ops.concat( 120 | (w_part_reshaped, array_ops.reshape(b_part, [1, -1])), axis=0) 121 | else: 122 | return array_ops.reshape(vector, [-1, vector.shape.as_list()[-1]]) 123 | 124 | 125 | def mat2d_to_layer_params(vector_template, mat2d): 126 | """Converts a canonical 2D matrix representation back to a vector. 127 | Args: 128 | vector_template: A Tensor or pair of Tensors shaped like layer parameters. 129 | mat2d: A 2D Tensor with the same shape as the value of 130 | layer_params_to_mat2d(vector_template). 131 | Returns: 132 | A Tensor or pair of Tensors with the same coefficients as mat2d and the same 133 | shape as vector_template. 134 | """ 135 | if isinstance(vector_template, (tuple, list)): 136 | w_part, b_part = mat2d[:-1], mat2d[-1] 137 | return array_ops.reshape(w_part, vector_template[0].shape), b_part 138 | else: 139 | return array_ops.reshape(mat2d, vector_template.shape) 140 | 141 | 142 | def posdef_inv(tensor, damping): 143 | """Computes the inverse of tensor + damping * identity.""" 144 | identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype) 145 | damping = math_ops.cast(damping, dtype=tensor.dtype) 146 | return posdef_inv_functions[POSDEF_INV_METHOD](tensor, identity, damping) 147 | 148 | 149 | def posdef_inv_matrix_inverse(tensor, identity, damping): 150 | """Computes inverse(tensor + damping * identity) directly.""" 151 | return linalg_ops.matrix_inverse(tensor + damping * identity) 152 | 153 | 154 | def posdef_inv_cholesky(tensor, identity, damping): 155 | """Computes inverse(tensor + damping * identity) with Cholesky.""" 156 | chol = linalg_ops.cholesky(tensor + damping * identity) 157 | return linalg_ops.cholesky_solve(chol, identity) 158 | 159 | 160 | def posdef_inv_eig(tensor, identity, damping): 161 | """Computes inverse(tensor + damping * identity) with eigendecomposition.""" 162 | eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig( 163 | tensor + damping * identity) 164 | # TODO(GD): it's a little hacky 165 | eigenvalues = gen_math_ops.maximum(eigenvalues, damping) 166 | return math_ops.matmul( 167 | eigenvectors / eigenvalues, eigenvectors, transpose_b=True) 168 | 169 | 170 | posdef_inv_functions = { 171 | "matrix_inverse": posdef_inv_matrix_inverse, 172 | "cholesky": posdef_inv_cholesky, 173 | "eig": posdef_inv_eig, 174 | } 175 | 176 | 177 | def posdef_eig(mat): 178 | """Computes the eigendecomposition of a positive semidefinite matrix.""" 179 | return posdef_eig_functions[POSDEF_EIG_METHOD](mat) 180 | 181 | 182 | def posdef_eig_svd(mat): 183 | """Computes the singular values and left singular vectors of a matrix.""" 184 | evals, evecs, _ = linalg_ops.svd(mat) 185 | 186 | return evals, evecs 187 | 188 | 189 | def posdef_eig_self_adjoint(mat): 190 | """Computes eigendecomposition using self_adjoint_eig.""" 191 | evals, evecs = linalg_ops.self_adjoint_eig(mat) 192 | evals = math_ops.abs(evals) # Should be equivalent to svd approach. 
193 | 194 | return evals, evecs 195 | 196 | 197 | posdef_eig_functions = { 198 | "self_adjoint": posdef_eig_self_adjoint, 199 | "svd": posdef_eig_svd, 200 | } 201 | 202 | 203 | def generate_random_signs(shape, dtype=dtypes.float32): 204 | """Generate a random tensor with {-1, +1} entries.""" 205 | ints = random_ops.random_uniform(shape, maxval=2, dtype=dtypes.int32) 206 | return 2 * math_ops.cast(ints, dtype=dtype) - 1 207 | 208 | 209 | def ensure_sequence(obj): 210 | """If `obj` isn't a tuple or list, return a tuple containing `obj`.""" 211 | if isinstance(obj, (tuple, list)): 212 | return obj 213 | else: 214 | return (obj,) 215 | 216 | 217 | def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None): 218 | """Compute forward-mode gradients.""" 219 | us = [array_ops.zeros_like(y) + float("nan") for y in ys] 220 | dydxs = gradients_impl.gradients( 221 | ys, xs, grad_ys=us, stop_gradients=stop_gradients) 222 | # Deal with strange types that gradients_impl.gradients returns but can't 223 | # deal with. 224 | dydxs = [ 225 | ops.convert_to_tensor(dydx) 226 | if isinstance(dydx, ops.IndexedSlices) else dydx for dydx in dydxs 227 | ] 228 | dydxs = [ 229 | array_ops.zeros_like(x) if dydx is None else dydx 230 | for x, dydx in zip(xs, dydxs) 231 | ] 232 | 233 | dysdx = gradients_impl.gradients(dydxs, us, grad_ys=grad_xs) 234 | 235 | return dysdx 236 | 237 | 238 | class LayerParametersDict(OrderedDict): 239 | """An OrderedDict where keys are Tensors or tuples of Tensors. 240 | Ensures that no Tensor is associated with two different keys. 241 | """ 242 | 243 | def __init__(self, *args, **kwargs): 244 | self._tensors = set() 245 | super(LayerParametersDict, self).__init__(*args, **kwargs) 246 | 247 | def __setitem__(self, key, value): 248 | key = self._canonicalize_key(key) 249 | tensors = key if isinstance(key, (tuple, list)) else (key,) 250 | key_collisions = self._tensors.intersection(tensors) 251 | if key_collisions: 252 | raise ValueError("Key(s) already present: {}".format(key_collisions)) 253 | self._tensors.update(tensors) 254 | super(LayerParametersDict, self).__setitem__(key, value) 255 | 256 | def __delitem__(self, key): 257 | key = self._canonicalize_key(key) 258 | self._tensors.remove(key) 259 | super(LayerParametersDict, self).__delitem__(key) 260 | 261 | def __getitem__(self, key): 262 | key = self._canonicalize_key(key) 263 | return super(LayerParametersDict, self).__getitem__(key) 264 | 265 | def __contains__(self, key): 266 | key = self._canonicalize_key(key) 267 | return super(LayerParametersDict, self).__contains__(key) 268 | 269 | def _canonicalize_key(self, key): 270 | if isinstance(key, (list, tuple)): 271 | return tuple(key) 272 | return key 273 | -------------------------------------------------------------------------------- /libs/sgd/optimizer.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from tensorflow.python.framework import ops 6 | from tensorflow.python.ops import array_ops 7 | from tensorflow.python.ops import math_ops 8 | from tensorflow.python.ops import gen_array_ops 9 | from tensorflow.python.ops import variable_scope 10 | from tensorflow.python.ops import variables as tf_variables 11 | from tensorflow.python.training import gradient_descent 12 | from libs.kfac import estimator as est 13 | 14 | 15 | class SGDOptimizer(gradient_descent.GradientDescentOptimizer): 16 | """ 17 | SGD Optimizer 18 | 
""" 19 | 20 | def __init__(self, 21 | learning_rate, 22 | var_list=None, 23 | momentum=0., 24 | weight_decay=0., 25 | weight_decay_type="l2", 26 | weight_list="all", 27 | name="SGD"): 28 | 29 | variables = var_list 30 | if variables is None: 31 | variables = tf_variables.trainable_variables() 32 | self.variables = variables 33 | 34 | weight_decay_type = weight_decay_type.lower() 35 | legal_weight_decay_types = ["wd", "l2", "fisher"] 36 | 37 | if weight_decay_type not in legal_weight_decay_types: 38 | raise ValueError("Unsupported weight decay type {}. Must be one of {}." 39 | .format(weight_decay_type, legal_weight_decay_types)) 40 | 41 | self._momentum = momentum 42 | self._weight_decay = weight_decay 43 | self._weight_decay_type = weight_decay_type 44 | self._weight_list = weight_list 45 | 46 | super(SGDOptimizer, self).__init__(learning_rate, name=name) 47 | 48 | def minimize(self, *args, **kwargs): 49 | kwargs["var_list"] = kwargs.get("var_list") or self.variables 50 | if set(kwargs["var_list"]) != set(self.variables): 51 | raise ValueError("var_list doesn't match with set of Fisher-estimating " 52 | "variables.") 53 | return super(SGDOptimizer, self).minimize(*args, **kwargs) 54 | 55 | def compute_gradients(self, *args, **kwargs): 56 | # args[1] could be our var_list 57 | if len(args) > 1: 58 | var_list = args[1] 59 | else: 60 | kwargs["var_list"] = kwargs.get("var_list") or self.variables 61 | var_list = kwargs["var_list"] 62 | if set(var_list) != set(self.variables): 63 | raise ValueError("var_list doesn't match with set of Fisher-estimating " 64 | "variables.") 65 | return super(SGDOptimizer, self).compute_gradients(*args, **kwargs) 66 | 67 | def apply_gradients(self, grads_and_vars, *args, **kwargs): 68 | grads_and_vars = list(grads_and_vars) 69 | 70 | if self._weight_decay > 0.0: 71 | if self._weight_decay_type == "l2" or self._weight_decay_type == "wd": 72 | grads_and_vars = self._add_weight_decay(grads_and_vars) 73 | 74 | steps_and_vars = self._compute_update_steps(grads_and_vars) 75 | return super(SGDOptimizer, self).apply_gradients(steps_and_vars, 76 | *args, **kwargs) 77 | 78 | def _add_weight_decay(self, vecs_and_vars): 79 | if self._weight_list == "all": 80 | print("all") 81 | return [(vec + self._weight_decay * gen_array_ops.stop_gradient(var), var) 82 | for vec, var in vecs_and_vars] 83 | elif self._weight_list == "last": 84 | print("last") 85 | grad_list = [] 86 | for vec, var in vecs_and_vars: 87 | if 'fc' not in var.name: 88 | grad_list.append((vec, var)) 89 | else: 90 | grad_list.append( 91 | (vec + self._weight_decay * 92 | gen_array_ops.stop_gradient(var), var)) 93 | return grad_list 94 | else: 95 | print("conv") 96 | grad_list = [] 97 | for vec, var in vecs_and_vars: 98 | if 'fc' in var.name: 99 | grad_list.append((vec, var)) 100 | else: 101 | grad_list.append( 102 | (vec + self._weight_decay * 103 | gen_array_ops.stop_gradient(var), var)) 104 | return grad_list 105 | 106 | def _compute_update_steps(self, grads_and_vars): 107 | return self._update_velocities(grads_and_vars, self._momentum) 108 | 109 | def _update_velocities(self, vecs_and_vars, decay, vec_coeff=1.0): 110 | def _update_velocity(vec, var): 111 | velocity = self._zeros_slot(var, "velocity", self._name) 112 | with ops.colocate_with(velocity): 113 | # Compute the new velocity for this variable. 114 | new_velocity = decay * velocity + vec_coeff * vec 115 | 116 | # Save the updated velocity. 
117 | return (array_ops.identity(velocity.assign(new_velocity)), var) 118 | 119 | # Go through variable and update its associated part of the velocity vector. 120 | return [_update_velocity(vec, var) for vec, var in vecs_and_vars] 121 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | import numpy as np 7 | import os 8 | 9 | from misc.utils import get_logger, get_args, makedirs 10 | from misc.config import process_config 11 | from core.model import Model 12 | from core.train import Trainer 13 | from data_loader import load_pytorch 14 | 15 | 16 | _INPUT_DIM = { 17 | 'fmnist': [784], 18 | 'mnist': [784], 19 | 'cifar10': [32, 32, 3], 20 | 'cifar100': [32, 32, 3] 21 | } 22 | 23 | _OUTPUT_DIM = { 24 | 'fmnist': 10, 25 | 'mnist': 10, 26 | 'cifar10': 10, 27 | 'cifar100': 100 28 | } 29 | 30 | 31 | def main(): 32 | tf.set_random_seed(1231) 33 | np.random.seed(1231) 34 | 35 | try: 36 | args = get_args() 37 | config = process_config(args.config) 38 | config.input_dim = _INPUT_DIM[config.dataset] 39 | config.output_dim = _OUTPUT_DIM[config.dataset] 40 | except: 41 | print("Add a config file using \'--config file_name.json\'") 42 | exit(1) 43 | 44 | makedirs(config.summary_dir) 45 | makedirs(config.checkpoint_dir) 46 | 47 | # set logger 48 | path = os.path.dirname(os.path.abspath(__file__)) 49 | path_model = os.path.join(path, 'core/model.py') 50 | path_train = os.path.join(path, 'core/train.py') 51 | logger = get_logger('log', logpath=config.summary_dir+'/', 52 | filepath=path_model, package_files=[path_train]) 53 | logger.info(dict(config)) 54 | 55 | # load data 56 | train_loader, test_loader = load_pytorch(config) 57 | 58 | # define computational graph 59 | sess = tf.Session() 60 | 61 | model = Model(config, sess) 62 | trainer = Trainer(sess, model, train_loader, test_loader, config, logger) 63 | 64 | trainer.train() 65 | 66 | 67 | if __name__ == "__main__": 68 | main() 69 | -------------------------------------------------------------------------------- /misc/config.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from easydict import EasyDict as edict 4 | 5 | 6 | def get_config_from_json(json_file): 7 | """ 8 | Get the config from a json file 9 | :param json_file: 10 | :return: config(namespace) or config(dictionary) 11 | """ 12 | # parse the configurations from the config json file provided 13 | with open(json_file, 'r') as config_file: 14 | config_dict = json.load(config_file) 15 | config = edict(config_dict) 16 | 17 | return config, config_dict 18 | 19 | 20 | def process_config(json_file): 21 | config, _ = get_config_from_json(json_file) 22 | paths = json_file.split('/')[1:-1] 23 | summary_dir = ["./experiments"] + paths + [config.exp_name, "summary/"] 24 | ckpt_dir = ["./experiments"] + paths + [config.exp_name, "checkpoint/"] 25 | # print('Summary dir is', summary_dir) 26 | # print('Checkpoint dir is', ckpt_dir) 27 | config.summary_dir = os.path.join(*summary_dir) 28 | config.checkpoint_dir = os.path.join(*ckpt_dir) 29 | return config 30 | -------------------------------------------------------------------------------- /misc/summarizer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 
| import os 3 | 4 | 5 | class Summarizer: 6 | def __init__(self, sess, config): 7 | self.sess = sess 8 | self.config = config 9 | self.summary_placeholders = {} 10 | self.summary_ops = {} 11 | self.summary_writer = tf.summary.FileWriter(self.config.summary_dir, self.sess.graph) 12 | 13 | # it can summarize scalars and images. 14 | def summarize(self, step, scope="", summaries_dict=None): 15 | """ 16 | :param step: the step of the summary 17 | :param scope: variable scope 18 | :param summaries_dict: the dict of the summaries values (tag,value) 19 | :return: 20 | """ 21 | summary_writer = self.summary_writer 22 | with tf.variable_scope(scope): 23 | 24 | if summaries_dict is not None: 25 | summary_list = [] 26 | for tag, value in summaries_dict.items(): 27 | if tag not in self.summary_ops: 28 | if len(value.shape) <= 1: 29 | self.summary_placeholders[tag] = tf.placeholder('float32', value.shape, name=tag) 30 | else: 31 | self.summary_placeholders[tag] = tf.placeholder('float32', [None] + list(value.shape[1:]), name=tag) 32 | if len(value.shape) <= 1: 33 | self.summary_ops[tag] = tf.summary.scalar(tag, self.summary_placeholders[tag]) 34 | else: 35 | self.summary_ops[tag] = tf.summary.image(tag, self.summary_placeholders[tag]) 36 | 37 | summary_list.append(self.sess.run(self.summary_ops[tag], {self.summary_placeholders[tag]: value})) 38 | 39 | for summary in summary_list: 40 | summary_writer.add_summary(summary, step) 41 | summary_writer.flush() 42 | -------------------------------------------------------------------------------- /misc/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import os 6 | import time 7 | 8 | import logging 9 | import tensorflow as tf 10 | import numpy as np 11 | import argparse 12 | 13 | 14 | def get_args(): 15 | argparser = argparse.ArgumentParser(description=__doc__) 16 | argparser.add_argument( 17 | '-c', '--config', default='None', help='The Configuration file') 18 | argparser.add_argument( 19 | '-f', '--fig_name', default='tmp', help='The Figure name') 20 | args = argparser.parse_args() 21 | return args 22 | 23 | 24 | def makedirs(filename): 25 | if not os.path.exists(os.path.dirname(filename)): 26 | os.makedirs(os.path.dirname(filename)) 27 | 28 | 29 | def get_logger(name, logpath, filepath, package_files=[], 30 | displaying=True, saving=True): 31 | logger = logging.getLogger(name) 32 | logger.setLevel(logging.INFO) 33 | log_path = logpath + name + time.strftime("-%Y%m%d-%H%M%S") 34 | makedirs(log_path) 35 | if saving: 36 | info_file_handler = logging.FileHandler(log_path) 37 | info_file_handler.setLevel(logging.INFO) 38 | logger.addHandler(info_file_handler) 39 | logger.info(filepath) 40 | with open(filepath, 'r') as f: 41 | logger.info(f.read()) 42 | 43 | for f in package_files: 44 | logger.info(f) 45 | with open(f, 'r') as package_f: 46 | logger.info(package_f.read()) 47 | if displaying: 48 | console_handler = logging.StreamHandler() 49 | console_handler.setLevel(logging.INFO) 50 | logger.addHandler(console_handler) 51 | 52 | return logger 53 | 54 | 55 | def var_shape(x): 56 | out = [k.value for k in x.get_shape()] 57 | assert all(isinstance(a, int) for a in out), \ 58 | "shape function assumes that shape is fully known" 59 | return out 60 | 61 | 62 | def numel(x): 63 | return np.prod(var_shape(x)) 64 | 65 | 66 | class GetFlat(object): 67 | def __init__(self, session, var_list): 68 | 
self.session = session 69 | self.op = tf.concat([tf.reshape(v, [numel(v)]) for v in var_list], 0) 70 | 71 | def __call__(self): 72 | return self.op.eval(session=self.session) 73 | 74 | 75 | class SetFromFlat(object): 76 | 77 | def __init__(self, session, var_list): 78 | self.session = session 79 | shapes = map(var_shape, var_list) 80 | total_size = sum(np.prod(shape) for shape in shapes) 81 | self.theta = theta = tf.placeholder(tf.float32, [total_size]) 82 | start = 0 83 | assigns = [] 84 | shapes = map(var_shape, var_list) 85 | for (shape, v) in zip(shapes, var_list): 86 | size = np.prod(shape) 87 | assigns.append( 88 | tf.assign(v, tf.reshape(theta[start:start + size], shape))) 89 | start += size 90 | self.op = tf.group(*assigns) 91 | 92 | def __call__(self, theta): 93 | self.session.run(self.op, feed_dict={self.theta: theta}) 94 | 95 | 96 | def flatten(tensors): 97 | if isinstance(tensors, (tuple, list)): 98 | return tf.concat( 99 | tuple(tf.reshape(tensor, [-1]) for tensor in tensors), axis=0) 100 | else: 101 | return tf.reshape(tensors, [-1]) 102 | 103 | 104 | class unflatten(object): 105 | def __init__(self, tensors_template): 106 | self.tensors_template = tensors_template 107 | 108 | def __call__(self, colvec): 109 | if isinstance(self.tensors_template, (tuple, list)): 110 | offset = 0 111 | tensors = [] 112 | for tensor_template in self.tensors_template: 113 | sz = np.prod(tensor_template.shape.as_list(), dtype=np.int32) 114 | tensor = tf.reshape(colvec[offset:(offset + sz)], 115 | tensor_template.shape) 116 | tensors.append(tensor) 117 | offset += sz 118 | 119 | tensors = list(tensors) 120 | else: 121 | tensors = tf.reshape(colvec, self.tensors_template.shape) 122 | 123 | return tensors 124 | 125 | 126 | def find_trainable_variables(key): 127 | with tf.variable_scope(key): 128 | return tf.trainable_variables() 129 | 130 | 131 | def conv2d(inputs, kernel_size, out_channels, is_training, name, 132 | activation_fn=tf.nn.relu, padding="SAME", strides=(1, 1), use_bias=False, 133 | batch_norm=False, initializer=tf.variance_scaling_initializer(scale=1.0, mode='fan_avg', 134 | distribution='uniform')): 135 | layer = tf.layers.Conv2D( 136 | out_channels, 137 | kernel_size=kernel_size, 138 | strides=strides, 139 | kernel_initializer=initializer, 140 | bias_initializer=tf.constant_initializer(0.0), 141 | padding=padding, 142 | use_bias=use_bias, 143 | name=name) 144 | preactivations = layer(inputs) 145 | if batch_norm: 146 | bn = tf.layers.batch_normalization(preactivations, training=is_training, center=False, scale=False) 147 | activations = activation_fn(bn) 148 | else: 149 | activations = activation_fn(preactivations) 150 | if use_bias: 151 | return preactivations, activations, (layer.kernel, layer.bias) 152 | else: 153 | return preactivations, activations, layer.kernel 154 | 155 | 156 | def dense(inputs, output_size, is_training, name, batch_norm=False, 157 | use_bias=False, activation_fn=tf.nn.relu, 158 | initializer=tf.variance_scaling_initializer(scale=1.0, mode='fan_avg', 159 | distribution='uniform')): 160 | layer = tf.layers.Dense( 161 | output_size, 162 | kernel_initializer=initializer, 163 | bias_initializer=tf.constant_initializer(0.0), 164 | use_bias=use_bias, 165 | name=name) 166 | preactivations = layer(inputs) 167 | if batch_norm: 168 | bn = tf.layers.batch_normalization(preactivations, training=is_training, center=False, scale=False) 169 | activations = activation_fn(bn) 170 | else: 171 | activations = activation_fn(preactivations) 172 | if use_bias: 173 | return 
preactivations, activations, (layer.kernel, layer.bias) 174 | else: 175 | return preactivations, activations, layer.kernel 176 | -------------------------------------------------------------------------------- /network/__init__.py: -------------------------------------------------------------------------------- 1 | from network.vgg import * 2 | from network.resnet import * 3 | from network.mlp import * 4 | -------------------------------------------------------------------------------- /network/mlp.py: -------------------------------------------------------------------------------- 1 | from misc.utils import dense 2 | from network.registry import register_model 3 | 4 | _APPROX = "kron" 5 | 6 | 7 | @register_model("mlp") 8 | def mlp(inputs, training, config, layer_collection=None): 9 | for l in range(config.n_layer): 10 | pre, act, param = dense(inputs, output_size=config.hidden_units, is_training=training, 11 | batch_norm=config.batch_norm, use_bias=config.use_bias, name="fc_"+str(l)) 12 | 13 | if layer_collection is not None: 14 | layer_collection.register_fully_connected(param, inputs, pre, approx=_APPROX) 15 | inputs = act 16 | 17 | logits, _, param = dense(inputs, output_size=config.output_dim, is_training=training, 18 | use_bias=config.use_bias, name="fc_"+str(config.n_layer)) 19 | 20 | if layer_collection is not None: 21 | layer_collection.register_fully_connected(param, inputs, logits, approx=_APPROX) 22 | if config.use_fisher: 23 | layer_collection.register_categorical_predictive_distribution(logits, name="logit") 24 | else: 25 | layer_collection.register_normal_predictive_distribution(logits, name="mean") 26 | 27 | return logits, inputs 28 | -------------------------------------------------------------------------------- /network/registry.py: -------------------------------------------------------------------------------- 1 | MODEL_REGISTRY = {} 2 | 3 | 4 | def register_model(model_name): 5 | def decorator(f): 6 | MODEL_REGISTRY[model_name] = f 7 | return f 8 | 9 | return decorator 10 | 11 | 12 | def get_model(model_name): 13 | if model_name in MODEL_REGISTRY: 14 | return MODEL_REGISTRY[model_name] 15 | else: 16 | raise ValueError("Unknown model {:s}".format(model_name)) 17 | -------------------------------------------------------------------------------- /network/resnet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from misc.utils import dense, conv2d 4 | from network.registry import register_model 5 | 6 | _OUTPUT = [64, 128, 256] 7 | _APPROX = "kron" 8 | 9 | 10 | def BasicBlock(inputs, training, output_dim, name, stride=1, batch_norm=False, use_bias=False, layer_collection=None): 11 | pre1, act1, param1 = conv2d(inputs, kernel_size=(3, 3), out_channels=output_dim, strides=(stride, stride), 12 | is_training=training, batch_norm=batch_norm, use_bias=use_bias, name=name+"conv1") 13 | 14 | pre2, act2, param2 = conv2d(act1, kernel_size=(3, 3), out_channels=output_dim, activation_fn=tf.identity, 15 | is_training=training, batch_norm=batch_norm, use_bias=use_bias, name=name+"conv2") 16 | 17 | if layer_collection is not None: 18 | layer_collection.register_conv2d(param1, (1, stride, stride, 1), "SAME", inputs, pre1, approx=_APPROX) 19 | layer_collection.register_conv2d(param2, (1, 1, 1, 1), "SAME", act1, pre2, approx=_APPROX) 20 | 21 | if stride != 1: 22 | pre3, act3, param3 = conv2d(inputs, kernel_size=(1, 1), out_channels=output_dim, strides=(stride, stride), 23 | is_training=training, batch_norm=batch_norm, 
use_bias=use_bias, 24 | name=name+"conv_skip", activation_fn=tf.identity) 25 | if layer_collection is not None: 26 | layer_collection.register_conv2d(param3, (1, stride, stride, 1), "SAME", inputs, pre3, approx=_APPROX) 27 | 28 | return tf.nn.relu(act2 + act3) 29 | 30 | return tf.nn.relu(act2 + inputs) 31 | 32 | 33 | def ResNet(inputs, training, num_blocks, output_dim, use_fisher=False, 34 | batch_norm=False, use_bias=False, layer_collection=None): 35 | pre1, act1, param1 = conv2d(inputs, kernel_size=(3, 3), out_channels=64, use_bias=use_bias, 36 | is_training=training, batch_norm=batch_norm, name="conv1") 37 | if layer_collection is not None: 38 | layer_collection.register_conv2d(param1, (1, 1, 1, 1), "SAME", inputs, pre1, approx=_APPROX) 39 | out = act1 40 | for i, b in enumerate(num_blocks): 41 | for l in range(b): 42 | if i > 0 and l == 0: 43 | stride = 2 44 | else: 45 | stride = 1 46 | out = BasicBlock(out, training, _OUTPUT[i], name="Res_"+str(i+1)+"Blk_"+str(l+1), use_bias=use_bias, 47 | stride=stride, batch_norm=batch_norm, layer_collection=layer_collection) 48 | 49 | # average pooling 50 | assert out.get_shape().as_list()[1:] == [8, 8, 256] 51 | out = tf.reduce_mean(out, [1, 2]) 52 | assert out.get_shape().as_list()[1:] == [256] 53 | 54 | logits, _, param = dense(out, output_size=output_dim, is_training=training, use_bias=use_bias, name="fc") 55 | if layer_collection is not None: 56 | layer_collection.register_fully_connected(param, out, logits, approx=_APPROX) 57 | if use_fisher: 58 | layer_collection.register_categorical_predictive_distribution(logits, name="logit") 59 | else: 60 | layer_collection.register_normal_predictive_distribution(logits, name="mean") 61 | return logits 62 | 63 | 64 | @register_model("resnet20") 65 | def resnet20(inputs, training, config, layer_collection=None): 66 | return ResNet(inputs, training, [3, 3, 3], config.output_dim, config.use_fisher, 67 | config.batch_norm, config.use_bias, layer_collection) 68 | 69 | 70 | @register_model("resnet32") 71 | def resnet32(inputs, training, config, layer_collection=None): 72 | return ResNet(inputs, training, [5, 5, 5], config.output_dim, config.use_fisher, 73 | config.batch_norm, config.use_bias, layer_collection) 74 | 75 | 76 | @register_model("resnet44") 77 | def resnet44(inputs, training, config, layer_collection=None): 78 | return ResNet(inputs, training, [7, 7, 7], config.output_dim, config.use_fisher, 79 | config.batch_norm, config.use_bias, layer_collection) 80 | 81 | 82 | @register_model("resnet56") 83 | def resnet56(inputs, training, config, layer_collection=None): 84 | return ResNet(inputs, training, [9, 9, 9], config.output_dim, config.use_fisher, 85 | config.batch_norm, config.use_bias, layer_collection) 86 | -------------------------------------------------------------------------------- /network/vgg.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | from misc.utils import dense, conv2d 5 | from network.registry import register_model 6 | 7 | _OUTPUT = [64, 128, 256, 512, 512] 8 | _APPROX = "kron" 9 | 10 | 11 | def VGG(inputs, layer, training, config, layer_collection=None): 12 | for b in range(5): 13 | for l in range(layer[b]): 14 | pre, act, param = conv2d(inputs, kernel_size=(3, 3), 15 | out_channels=_OUTPUT[b], 16 | is_training=training, 17 | batch_norm=config.batch_norm, 18 | use_bias=config.use_bias, 19 | name="conv_"+str(b)+"_"+str(l)) 20 | if layer_collection is not None: 21 | 
layer_collection.register_conv2d(param, (1, 1, 1, 1), "SAME", inputs, pre, approx=_APPROX) 22 | inputs = act 23 | 24 | inputs = tf.layers.max_pooling2d(inputs, 2, 2, "SAME") 25 | 26 | flat = tf.reshape(inputs, shape=[-1, int(np.prod(inputs.shape[1:]))]) 27 | logits, _, param = dense(flat, output_size=config.output_dim, use_bias=config.use_bias, 28 | is_training=training, name="fc") 29 | 30 | if layer_collection is not None: 31 | layer_collection.register_fully_connected(param, flat, logits, approx=_APPROX) 32 | if config.use_fisher: 33 | layer_collection.register_categorical_predictive_distribution(logits, name="logit") 34 | else: 35 | layer_collection.register_normal_predictive_distribution(logits, name="mean") 36 | 37 | return logits 38 | 39 | 40 | @register_model("vgg11") 41 | def vgg11(inputs, training, config, layer_collection=None): 42 | return VGG(inputs, [1, 1, 2, 2, 2], training, config, layer_collection) 43 | 44 | 45 | @register_model("vgg13") 46 | def vgg13(inputs, training, config, layer_collection=None): 47 | return VGG(inputs, [2, 2, 2, 2, 2], training, config, layer_collection) 48 | 49 | 50 | @register_model("vgg16") 51 | def vgg16(inputs, training, config, layer_collection=None): 52 | return VGG(inputs, [2, 2, 3, 3, 3], training, config, layer_collection) 53 | 54 | 55 | @register_model("vgg19") 56 | def vgg19(inputs, training, config, layer_collection=None): 57 | return VGG(inputs, [2, 2, 4, 4, 4], training, config, layer_collection) 58 | --------------------------------------------------------------------------------
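
A minimal usage sketch, assuming TensorFlow 1.x and the config fields read by the network builders above (`batch_norm`, `use_bias`, `output_dim`, `use_fisher`). The hyperparameter values and the `edict` stand-in for a parsed config are placeholders, not taken from any config file; actual experiments should be launched through `main.py` with one of the JSON configs.

```
# Illustrative sketch only: wires a registered network builder to the
# modified SGD optimizer. Config fields and hyperparameters are assumptions.
import tensorflow as tf
from easydict import EasyDict as edict

from network.registry import get_model
from libs.sgd.optimizer import SGDOptimizer

# Minimal stand-in for a parsed JSON config (see misc/config.py).
config = edict(batch_norm=True, use_bias=False, output_dim=10, use_fisher=False)

inputs = tf.placeholder(tf.float32, [None, 32, 32, 3])
labels = tf.placeholder(tf.int64, [None])
is_training = tf.placeholder(tf.bool)

# Plain SGD training, so no K-FAC layer_collection is registered.
logits = get_model("vgg16")(inputs, is_training, config, layer_collection=None)
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))

optimizer = SGDOptimizer(learning_rate=3e-2, momentum=0.9,
                         weight_decay=3e-4, weight_decay_type="wd",
                         weight_list="all")

# Run batch-norm moving-average updates alongside the parameter update.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)
```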