├── compressed.png ├── complexity ├── dtd.npy ├── isun.npy ├── lsun.npy ├── lsunR.npy ├── mnist.npy ├── stl10.npy ├── svhn.npy ├── cifar10.npy ├── kmnist.npy ├── cifar100.npy ├── place365.npy └── fasionmnist.npy ├── figs ├── 10Results1.pdf ├── 10Results1.png ├── 10Results2.pdf ├── 10Results2.png ├── performance.png └── architecture.png ├── Flops ├── Glow-PyTorch-master │ ├── x.pkl │ ├── __pycache__ │ │ ├── model.cpython-37.pyc │ │ ├── modules.cpython-37.pyc │ │ ├── utils.cpython-37.pyc │ │ ├── datasets.cpython-37.pyc │ │ └── utils_flop.cpython-37.pyc │ ├── images │ │ └── histogram_glow_cifar_svhn.png │ ├── sum.py │ ├── output │ │ └── hparams.json │ ├── calculate_flops.py │ ├── ops.txt │ ├── utils.py │ ├── LICENSE │ ├── datasets.py │ ├── Sample_from_Glow.ipynb │ ├── Do_deep_generative_models_know_what_they_dont_know.ipynb │ ├── README.md │ ├── model.py │ ├── utils_flop.py │ └── modules.py ├── pixel-cnn-pp-master │ ├── __pycache__ │ │ ├── global2.cpython-37.pyc │ │ ├── layers.cpython-37.pyc │ │ ├── model.cpython-37.pyc │ │ └── utils.cpython-37.pyc │ ├── images │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_11.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_14.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_17.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_2.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_20.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_23.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_26.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_29.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_32.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_35.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_38.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_41.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_44.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_47.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_5.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_50.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_53.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_56.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_59.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_62.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_65.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_68.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_71.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_74.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_77.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_8.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_80.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_83.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_86.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_89.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_92.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_95.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_98.png │ │ ├── pcnn_lr:0.00050_nr-resnet3_nr-filters160_19.png │ │ ├── pcnn_lr:0.00050_nr-resnet3_nr-filters160_29.png │ │ ├── pcnn_lr:0.00050_nr-resnet3_nr-filters160_39.png │ │ ├── pcnn_lr:0.00050_nr-resnet3_nr-filters160_49.png │ │ ├── pcnn_lr:0.00050_nr-resnet3_nr-filters160_9.png │ │ ├── pcnn_lr:0.00100_nr-resnet5_nr-filters160_0.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_101.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_104.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_107.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_110.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_113.png │ │ ├── 
pcnn_lr:0.00020_nr-resnet5_nr-filters160_116.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_119.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_122.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_125.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_128.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_131.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_134.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_137.png │ │ ├── pcnn_lr:0.00020_nr-resnet5_nr-filters160_140.png │ │ └── pcnn_lr:0.00020_nr-resnet5_nr-filters160_143.png │ ├── sum.py │ ├── license.md │ ├── readme.md │ ├── ops.txt │ ├── main.py │ ├── model.py │ ├── layers.py │ └── utils.py └── README.md ├── mahalanobis_parameters ├── magnitude.pkl ├── precision.pkl ├── num_classes.pkl └── sample_mean.pkl ├── utils ├── __pycache__ │ ├── MOOD.cpython-37.pyc │ ├── dataloader.cpython-37.pyc │ ├── svhn_loader.cpython-37.pyc │ └── msdnet_function.cpython-37.pyc ├── msdnet_function.py ├── svhn_loader.py ├── dataloader.py └── MOOD.py ├── models ├── __pycache__ │ ├── msdnet.cpython-37.pyc │ ├── __init__.cpython-37.pyc │ ├── msdnet_ge.cpython-37.pyc │ └── msdnet_imta.cpython-37.pyc ├── __init__.py ├── msdnet_imta.py └── msdnet.py ├── msd_dataloader.py ├── README.md ├── main.py └── msd_args.py /compressed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/compressed.png -------------------------------------------------------------------------------- /complexity/dtd.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/complexity/dtd.npy -------------------------------------------------------------------------------- /complexity/isun.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/complexity/isun.npy -------------------------------------------------------------------------------- /complexity/lsun.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/complexity/lsun.npy -------------------------------------------------------------------------------- /complexity/lsunR.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/complexity/lsunR.npy -------------------------------------------------------------------------------- /complexity/mnist.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/complexity/mnist.npy -------------------------------------------------------------------------------- /complexity/stl10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/complexity/stl10.npy -------------------------------------------------------------------------------- /complexity/svhn.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/complexity/svhn.npy -------------------------------------------------------------------------------- /figs/10Results1.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/figs/10Results1.pdf -------------------------------------------------------------------------------- /figs/10Results1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/figs/10Results1.png -------------------------------------------------------------------------------- /figs/10Results2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/figs/10Results2.pdf -------------------------------------------------------------------------------- /figs/10Results2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/figs/10Results2.png -------------------------------------------------------------------------------- /figs/performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/figs/performance.png -------------------------------------------------------------------------------- /complexity/cifar10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/complexity/cifar10.npy -------------------------------------------------------------------------------- /complexity/kmnist.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/complexity/kmnist.npy -------------------------------------------------------------------------------- /figs/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/figs/architecture.png -------------------------------------------------------------------------------- /complexity/cifar100.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/complexity/cifar100.npy -------------------------------------------------------------------------------- /complexity/place365.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/complexity/place365.npy -------------------------------------------------------------------------------- /complexity/fasionmnist.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/complexity/fasionmnist.npy -------------------------------------------------------------------------------- /Flops/Glow-PyTorch-master/x.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/Glow-PyTorch-master/x.pkl -------------------------------------------------------------------------------- /mahalanobis_parameters/magnitude.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/mahalanobis_parameters/magnitude.pkl -------------------------------------------------------------------------------- /mahalanobis_parameters/precision.pkl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/mahalanobis_parameters/precision.pkl -------------------------------------------------------------------------------- /utils/__pycache__/MOOD.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/utils/__pycache__/MOOD.cpython-37.pyc -------------------------------------------------------------------------------- /mahalanobis_parameters/num_classes.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/mahalanobis_parameters/num_classes.pkl -------------------------------------------------------------------------------- /mahalanobis_parameters/sample_mean.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/mahalanobis_parameters/sample_mean.pkl -------------------------------------------------------------------------------- /models/__pycache__/msdnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/models/__pycache__/msdnet.cpython-37.pyc -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .msdnet import MSDNet as msdnet 2 | from .msdnet_ge import msdnet_ge 3 | from .msdnet_imta import IMTA_MSDNet 4 | 5 | -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/msdnet_ge.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/models/__pycache__/msdnet_ge.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/msdnet_imta.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/models/__pycache__/msdnet_imta.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/dataloader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/utils/__pycache__/dataloader.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/svhn_loader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/utils/__pycache__/svhn_loader.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/msdnet_function.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/utils/__pycache__/msdnet_function.cpython-37.pyc -------------------------------------------------------------------------------- /Flops/Glow-PyTorch-master/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/Glow-PyTorch-master/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /Flops/Glow-PyTorch-master/__pycache__/modules.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/Glow-PyTorch-master/__pycache__/modules.cpython-37.pyc -------------------------------------------------------------------------------- /Flops/Glow-PyTorch-master/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/Glow-PyTorch-master/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/__pycache__/global2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/__pycache__/global2.cpython-37.pyc -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/__pycache__/layers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/__pycache__/layers.cpython-37.pyc -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /Flops/Glow-PyTorch-master/__pycache__/datasets.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/Glow-PyTorch-master/__pycache__/datasets.cpython-37.pyc -------------------------------------------------------------------------------- /Flops/Glow-PyTorch-master/images/histogram_glow_cifar_svhn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/Glow-PyTorch-master/images/histogram_glow_cifar_svhn.png -------------------------------------------------------------------------------- /Flops/Glow-PyTorch-master/__pycache__/utils_flop.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/Glow-PyTorch-master/__pycache__/utils_flop.cpython-37.pyc -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_11.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_14.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_17.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_17.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_2.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_20.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_23.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_23.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_26.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_26.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_29.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_29.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_32.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_32.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_35.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_35.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_38.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_38.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_41.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_41.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_44.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_44.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_47.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_47.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_5.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_50.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_50.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_53.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_53.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_56.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_56.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_59.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_59.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_62.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_62.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_65.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_65.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_68.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_68.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_71.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_71.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_74.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_74.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_77.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_77.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_8.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_80.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_80.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_83.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_83.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_86.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_86.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_89.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_89.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_92.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_92.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_95.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_95.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_98.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_98.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00050_nr-resnet3_nr-filters160_19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00050_nr-resnet3_nr-filters160_19.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00050_nr-resnet3_nr-filters160_29.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00050_nr-resnet3_nr-filters160_29.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00050_nr-resnet3_nr-filters160_39.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00050_nr-resnet3_nr-filters160_39.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00050_nr-resnet3_nr-filters160_49.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00050_nr-resnet3_nr-filters160_49.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00050_nr-resnet3_nr-filters160_9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00050_nr-resnet3_nr-filters160_9.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00100_nr-resnet5_nr-filters160_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00100_nr-resnet5_nr-filters160_0.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_101.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_101.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_104.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_104.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_107.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_107.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_110.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_110.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_113.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_113.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_116.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_116.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_119.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_119.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_122.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_122.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_125.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_125.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_128.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_128.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_131.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_131.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_134.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_134.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_137.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_137.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_140.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_140.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_143.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/deeplearning-wisc/MOOD/HEAD/Flops/pixel-cnn-pp-master/images/pcnn_lr:0.00020_nr-resnet5_nr-filters160_143.png -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/sum.py: -------------------------------------------------------------------------------- 1 | 2 | numbers = [] 3 | l=0 4 | with open("ops.txt", 'rt') as handle: 5 | for ln in handle: 6 | l=l+1 7 | numbers.append(int(ln)) 8 | print(l) 9 | print('Flops = ',sum(numbers)) 10 | #27806798720 -------------------------------------------------------------------------------- /Flops/Glow-PyTorch-master/sum.py: -------------------------------------------------------------------------------- 1 | 2 | numbers = [] 3 | l=0 4 | with open("ops.txt", 'rt') as handle: 5 | for ln in handle: 6 | l=l+1 7 | numbers.append(int(ln)) 8 | print(l) 9 | print('Flops = ',sum(numbers)) 10 | #4093771776 -------------------------------------------------------------------------------- /Flops/README.md: -------------------------------------------------------------------------------- 1 | This file contains codes for calculating Flops 2 | 3 | For Glow: 4 | Original code from https://github.com/y0ast/Glow-PyTorch. We revised it for calculating Flops. 5 | (1) cd to Flops/Glow-PyTorch-master/ 6 | (2) python main.py using the terminal 7 | (3) copy all intergers into ops.txt 8 | you can skip (1)(2)(3) since we have done this for ops.txt 9 | (4) python sum.py using the terminal 10 | 11 | For PixelCNN++: 12 | Original code from https://github.com/pclucas14/pixel-cnn-pp. We revised it for calculating Flops. 13 | (1) cd to Flops/pixel-cnn-pp-master/ 14 | (2) python main.py using the terminal 15 | (3) copy all intergers between two strings "0******" into ops.txt 16 | you can skip (1)(2)(3) since we have done this for ops.txt 17 | (4) python sum.py using the terminal 18 | -------------------------------------------------------------------------------- /Flops/Glow-PyTorch-master/output/hparams.json: -------------------------------------------------------------------------------- 1 | { 2 | "K": 32, 3 | "L": 3, 4 | "LU_decomposed": true, 5 | "actnorm_scale": 1.0, 6 | "augment": true, 7 | "batch_size": 64, 8 | "cuda": true, 9 | "dataroot": "./", 10 | "dataset": "cifar10", 11 | "download": true, 12 | "epochs": 250, 13 | "eval_batch_size": 512, 14 | "flow_coupling": "affine", 15 | "flow_permutation": "invconv", 16 | "hidden_channels": 512, 17 | "learn_top": true, 18 | "lr": 0.0005, 19 | "max_grad_clip": 0, 20 | "max_grad_norm": 0, 21 | "n_init_batches": 8, 22 | "n_workers": 6, 23 | "output_dir": "output/", 24 | "saved_model": "", 25 | "saved_optimizer": "", 26 | "seed": 0, 27 | "warmup": 5, 28 | "y_condition": false, 29 | "y_weight": 0.01 30 | } -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/license.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [year] [fullname] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, but NOT sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all 
copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | -------------------------------------------------------------------------------- /Flops/Glow-PyTorch-master/calculate_flops.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | import shutil 5 | import random 6 | from itertools import islice 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | import torch.utils.data as data 12 | 13 | from ignite.contrib.handlers import ProgressBar 14 | from ignite.engine import Engine, Events 15 | from ignite.handlers import ModelCheckpoint, Timer 16 | from ignite.metrics import RunningAverage, Loss 17 | 18 | from datasets import get_CIFAR10, get_SVHN 19 | from model import Glow 20 | 21 | 22 | if 1: 23 | model = Glow( 24 | image_shape=(32,32,3), 25 | hidden_channels=512, 26 | K=32, 27 | L=3, 28 | actnorm_scale=1.0, 29 | flow_permutation="invconv", 30 | flow_coupling="affine", 31 | LU_decomposed=True, 32 | y_classes=10, 33 | learn_top=True, 34 | y_condition=False, 35 | ) 36 | print(model) 37 | model = model.cuda() 38 | 39 | import pickle 40 | data_input = open('x.pkl','rb') 41 | image = pickle.load(data_input) 42 | data_input.close() 43 | 44 | pre = model(image[0:1], None) -------------------------------------------------------------------------------- /Flops/Glow-PyTorch-master/ops.txt: -------------------------------------------------------------------------------- 1 | 90050560 2 | 90050560 3 | 90050560 4 | 90050560 5 | 90050560 6 | 90050560 7 | 90050560 8 | 90050560 9 | 90050560 10 | 90050560 11 | 90050560 12 | 90050560 13 | 90050560 14 | 90050560 15 | 90050560 16 | 90050560 17 | 90050560 18 | 90050560 19 | 90050560 20 | 90050560 21 | 90050560 22 | 90050560 23 | 90050560 24 | 90050560 25 | 90050560 26 | 90050560 27 | 90050560 28 | 90050560 29 | 90050560 30 | 90050560 31 | 90050560 32 | 90050560 33 | 28069888 34 | 28069888 35 | 28069888 36 | 28069888 37 | 28069888 38 | 28069888 39 | 28069888 40 | 28069888 41 | 28069888 42 | 28069888 43 | 28069888 44 | 28069888 45 | 28069888 46 | 28069888 47 | 28069888 48 | 28069888 49 | 28069888 50 | 28069888 51 | 28069888 52 | 28069888 53 | 28069888 54 | 28069888 55 | 28069888 56 | 28069888 57 | 28069888 58 | 28069888 59 | 28069888 60 | 28069888 61 | 28069888 62 | 28069888 63 | 28069888 64 | 28069888 65 | 9809920 66 | 9809920 67 | 9809920 68 | 9809920 69 | 9809920 70 | 9809920 71 | 9809920 72 | 9809920 73 | 9809920 74 | 9809920 75 | 9809920 76 | 9809920 77 | 9809920 78 | 9809920 79 | 9809920 80 | 9809920 81 | 9809920 82 | 9809920 83 | 9809920 84 | 9809920 85 | 9809920 86 | 9809920 87 | 9809920 88 | 9809920 89 | 9809920 90 | 9809920 91 | 9809920 92 | 9809920 93 | 9809920 94 | 9809920 95 | 9809920 96 | 9809920 97 | -------------------------------------------------------------------------------- /Flops/Glow-PyTorch-master/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | 5 | def 
compute_same_pad(kernel_size, stride): 6 | if isinstance(kernel_size, int): 7 | kernel_size = [kernel_size] 8 | 9 | if isinstance(stride, int): 10 | stride = [stride] 11 | 12 | assert len(stride) == len( 13 | kernel_size 14 | ), "Pass kernel size and stride both as int, or both as equal length iterable" 15 | 16 | return [((k - 1) * s + 1) // 2 for k, s in zip(kernel_size, stride)] 17 | 18 | 19 | def uniform_binning_correction(x, n_bits=8): 20 | """Replaces x^i with q^i(x) = U(x, x + 1.0 / 256.0). 21 | 22 | Args: 23 | x: 4-D Tensor of shape (NCHW) 24 | n_bits: optional. 25 | Returns: 26 | x: x ~ U(x, x + 1.0 / 256) 27 | objective: Equivalent to -q(x)*log(q(x)). 28 | """ 29 | b, c, h, w = x.size() 30 | n_bins = 2 ** n_bits 31 | chw = c * h * w 32 | x += torch.zeros_like(x).uniform_(0, 1.0 / n_bins) 33 | 34 | objective = -math.log(n_bins) * chw * torch.ones(b, device=x.device) 35 | return x, objective 36 | 37 | 38 | def split_feature(tensor, type="split"): 39 | """ 40 | type = ["split", "cross"] 41 | """ 42 | C = tensor.size(1) 43 | if type == "split": 44 | return tensor[:, : C // 2, ...], tensor[:, C // 2 :, ...] 45 | elif type == "cross": 46 | return tensor[:, 0::2, ...], tensor[:, 1::2, ...] 47 | -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/readme.md: -------------------------------------------------------------------------------- 1 | ## PixelCNN++ 2 | 3 | A Pytorch Implementation of [PixelCNN++.](https://arxiv.org/pdf/1701.05517.pdf) 4 | 5 | Main work taken from the [official implementation](https://github.com/openai/pixel-cnn) 6 | 7 | Pre-trained models are available [here](https://mega.nz/#F!W7IhST7R!PV7Pbet8Q07GxVLGnmQrZg) 8 | 9 | I kept the code structure to facilitate comparison with the official code. 10 | 11 | The code achieves **2.95** BPD on test set, compared to **2.92** BPD on the official tensorflow implementation. 12 |
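For context on the BPD figures above: bits per dimension is the per-image negative log-likelihood in nats divided by the number of dimensions times ln 2. A minimal sketch of the conversion (illustrative only; the helper name is not from this codebase):

```python
import math

def bits_per_dim(nll_nats, num_dims=3 * 32 * 32):
    # Convert a per-image negative log-likelihood (in nats) to bits per dimension
    # for 32x32 RGB inputs such as CIFAR-10.
    return nll_nats / (num_dims * math.log(2))

print(round(bits_per_dim(6282), 2))  # ~2.95, the test-set figure quoted above
```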
13 | 14 | 15 | 16 | 17 |
18 | 19 | ### Running the code 20 | ``` 21 | python main.py 22 | ``` 23 | 24 | ### Differences with official implementation 25 | 1. No data dependant weight initialization 26 | 2. No exponential moving average of past models for test set evalutation 27 | 28 | ### Contact 29 | For questions / comments / requests, feel free to send me an email.\ 30 | Happy generative modelling :) 31 | -------------------------------------------------------------------------------- /Flops/Glow-PyTorch-master/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Joost van Amersfoort 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | MIT License 24 | 25 | Copyright (c) 2019 Yuki-Chai 26 | 27 | Permission is hereby granted, free of charge, to any person obtaining a copy 28 | of this software and associated documentation files (the "Software"), to deal 29 | in the Software without restriction, including without limitation the rights 30 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 31 | copies of the Software, and to permit persons to whom the Software is 32 | furnished to do so, subject to the following conditions: 33 | 34 | The above copyright notice and this permission notice shall be included in all 35 | copies or substantial portions of the Software. 36 | 37 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 38 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 39 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 40 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 41 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 42 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 43 | SOFTWARE. 
44 | 45 | -------------------------------------------------------------------------------- /models/msdnet_imta.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import unicode_literals 3 | from __future__ import print_function 4 | from __future__ import division 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.autograd import Variable 10 | 11 | import math 12 | from .msdnet_ge import msdnet_ge 13 | 14 | __all__ = ['IMTA_MSDNet'] 15 | 16 | class IMTA_MSDNet(nn.Module): 17 | def __init__(self, args): 18 | super(IMTA_MSDNet, self).__init__() 19 | self.nBlocks = args.nBlocks 20 | if args.data == 'ImageNet': 21 | if args.step == 7: 22 | logits_channels = [576, 640, 608, 528, 976] 23 | elif args.step == 6: 24 | logits_channels = [512, 544, 496, 880, 792] 25 | else: 26 | logits_channels = [384, 384, 352, 304, 560] # step=4 27 | else: 28 | logits_channels = [] 29 | for i in range(args.nBlocks): 30 | logits_channels.append(128) # 128 for cifar10/100 31 | 32 | self.net = msdnet_ge(args) 33 | self.classifier = nn.ModuleList() 34 | self.isc_modules = nn.ModuleList() 35 | for i in range(args.nBlocks): 36 | if i == 0: 37 | in_channels = logits_channels[i] 38 | else: 39 | in_channels = logits_channels[i] * 2 40 | self.classifier.append(nn.Linear(in_channels, args.num_classes)) 41 | for i in range(args.nBlocks - 1): 42 | out_channels = logits_channels[i + 1] 43 | self.isc_modules.append(nn.Sequential( 44 | nn.Linear(args.num_classes, out_channels), 45 | nn.BatchNorm1d(out_channels), 46 | nn.ReLU(inplace=True))) 47 | 48 | 49 | def forward(self, x): 50 | pred = [] 51 | real_logits, logits = self.net(x) 52 | 53 | for i in range(self.nBlocks): 54 | if i == 0: 55 | in_logits = logits[i] 56 | else: 57 | in_logits = torch.cat((logits[i], feat), dim=-1) 58 | pd = self.classifier[i](in_logits) 59 | if i < self.nBlocks - 1: 60 | feat = self.isc_modules[i](pd) 61 | pred.append(pd) 62 | 63 | if self.training: 64 | return pred, real_logits[-1] 65 | else: 66 | return pred 67 | 68 | 69 | -------------------------------------------------------------------------------- /Flops/Glow-PyTorch-master/datasets.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from torchvision import transforms, datasets 7 | 8 | n_bits = 8 9 | 10 | 11 | def preprocess(x): 12 | # Follows: 13 | # https://github.com/tensorflow/tensor2tensor/blob/e48cf23c505565fd63378286d9722a1632f4bef7/tensor2tensor/models/research/glow.py#L78 14 | 15 | x = x * 255 # undo ToTensor scaling to [0,1] 16 | 17 | n_bins = 2 ** n_bits 18 | if n_bits < 8: 19 | x = torch.floor(x / 2 ** (8 - n_bits)) 20 | x = x / n_bins - 0.5 21 | 22 | return x 23 | 24 | 25 | def postprocess(x): 26 | x = torch.clamp(x, -0.5, 0.5) 27 | x += 0.5 28 | x = x * 2 ** n_bits 29 | return torch.clamp(x, 0, 255).byte() 30 | 31 | 32 | def get_CIFAR10(augment, dataroot, download): 33 | image_shape = (32, 32, 3) 34 | num_classes = 10 35 | 36 | if augment: 37 | transformations = [ 38 | transforms.RandomAffine(0, translate=(0.1, 0.1)), 39 | transforms.RandomHorizontalFlip(), 40 | ] 41 | else: 42 | transformations = [] 43 | 44 | transformations.extend([transforms.ToTensor(), preprocess]) 45 | train_transform = transforms.Compose(transformations) 46 | 47 | test_transform = transforms.Compose([transforms.ToTensor(), preprocess]) 48 | 49 | 
one_hot_encode = lambda target: F.one_hot(torch.tensor(target), num_classes) 50 | 51 | path = Path(dataroot) / "data" / "CIFAR10" 52 | train_dataset = datasets.CIFAR10( 53 | path, 54 | train=True, 55 | transform=train_transform, 56 | target_transform=one_hot_encode, 57 | download=download, 58 | ) 59 | 60 | test_dataset = datasets.CIFAR10( 61 | path, 62 | train=False, 63 | transform=test_transform, 64 | target_transform=one_hot_encode, 65 | download=download, 66 | ) 67 | 68 | return image_shape, num_classes, train_dataset, test_dataset 69 | 70 | 71 | def get_SVHN(augment, dataroot, download): 72 | image_shape = (32, 32, 3) 73 | num_classes = 10 74 | 75 | if augment: 76 | transformations = [transforms.RandomAffine(0, translate=(0.1, 0.1))] 77 | else: 78 | transformations = [] 79 | 80 | transformations.extend([transforms.ToTensor(), preprocess]) 81 | train_transform = transforms.Compose(transformations) 82 | 83 | test_transform = transforms.Compose([transforms.ToTensor(), preprocess]) 84 | 85 | one_hot_encode = lambda target: F.one_hot(torch.tensor(target), num_classes) 86 | 87 | path = Path(dataroot) / "data" / "SVHN" 88 | train_dataset = datasets.SVHN( 89 | path, 90 | split="train", 91 | transform=train_transform, 92 | target_transform=one_hot_encode, 93 | download=download, 94 | ) 95 | 96 | test_dataset = datasets.SVHN( 97 | path, 98 | split="test", 99 | transform=test_transform, 100 | target_transform=one_hot_encode, 101 | download=download, 102 | ) 103 | 104 | return image_shape, num_classes, train_dataset, test_dataset 105 | -------------------------------------------------------------------------------- /Flops/Glow-PyTorch-master/Sample_from_Glow.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import json\n", 10 | "\n", 11 | "import torch\n", 12 | "from torchvision.utils import make_grid\n", 13 | "import matplotlib.pyplot as plt\n", 14 | "\n", 15 | "from datasets import get_CIFAR10, get_SVHN, postprocess\n", 16 | "from model import Glow\n", 17 | "\n", 18 | "device = torch.device(\"cuda\")\n", 19 | "\n", 20 | "output_folder = 'output/'\n", 21 | "model_name = 'glow_model_250.pth'\n", 22 | "\n", 23 | "with open(output_folder + 'hparams.json') as json_file: \n", 24 | " hparams = json.load(json_file)\n", 25 | " \n", 26 | "image_shape, num_classes, _, test_cifar = get_CIFAR10(hparams['augment'], hparams['dataroot'], hparams['download'])\n", 27 | "image_shape, num_classes, _, test_svhn = get_SVHN(hparams['augment'], hparams['dataroot'], hparams['download'])\n", 28 | "\n", 29 | "\n", 30 | "model = Glow(image_shape, hparams['hidden_channels'], hparams['K'], hparams['L'], hparams['actnorm_scale'],\n", 31 | " hparams['flow_permutation'], hparams['flow_coupling'], hparams['LU_decomposed'], num_classes,\n", 32 | " hparams['learn_top'], hparams['y_condition'])\n", 33 | "\n", 34 | "model.load_state_dict(torch.load(output_folder + model_name))\n", 35 | "model.set_actnorm_init()\n", 36 | "\n", 37 | "model = model.to(device)\n", 38 | "\n", 39 | "model = model.eval()" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "def sample(model):\n", 49 | " with torch.no_grad():\n", 50 | " if hparams['y_condition']:\n", 51 | " y = torch.eye(num_classes)\n", 52 | " y = y.repeat(batch_size // num_classes + 1)\n", 53 | " y = y[:32, :].to(device) # 
number hardcoded in model for now\n", 54 | " else:\n", 55 | " y = None\n", 56 | "\n", 57 | " images = postprocess(model(y_onehot=y, temperature=1, reverse=True))\n", 58 | "\n", 59 | " return images.cpu()" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": { 66 | "scrolled": false 67 | }, 68 | "outputs": [], 69 | "source": [ 70 | "images = sample(model)\n", 71 | "grid = make_grid(images[:30], nrow=6).permute(1,2,0)\n", 72 | "\n", 73 | "plt.figure(figsize=(10,10))\n", 74 | "plt.imshow(grid)\n", 75 | "plt.axis('off')" 76 | ] 77 | } 78 | ], 79 | "metadata": { 80 | "kernelspec": { 81 | "display_name": "Python 3", 82 | "language": "python", 83 | "name": "python3" 84 | }, 85 | "language_info": { 86 | "codemirror_mode": { 87 | "name": "ipython", 88 | "version": 3 89 | }, 90 | "file_extension": ".py", 91 | "mimetype": "text/x-python", 92 | "name": "python", 93 | "nbconvert_exporter": "python", 94 | "pygments_lexer": "ipython3", 95 | "version": "3.7.1" 96 | } 97 | }, 98 | "nbformat": 4, 99 | "nbformat_minor": 2 100 | } 101 | -------------------------------------------------------------------------------- /Flops/Glow-PyTorch-master/Do_deep_generative_models_know_what_they_dont_know.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import json\n", 10 | "\n", 11 | "import torch\n", 12 | "import matplotlib.pyplot as plt\n", 13 | "import seaborn as sns\n", 14 | "sns.set()\n", 15 | "\n", 16 | "from datasets import get_CIFAR10, get_SVHN\n", 17 | "from model import Glow\n", 18 | "\n", 19 | "device = torch.device(\"cuda\")\n", 20 | "\n", 21 | "output_folder = 'glow/'\n", 22 | "model_name = 'glow_affine_coupling.pt'\n", 23 | "\n", 24 | "with open(output_folder + 'hparams.json') as json_file: \n", 25 | " hparams = json.load(json_file)\n", 26 | " \n", 27 | "print(hparams)\n", 28 | "\n", 29 | "image_shape, num_classes, _, test_cifar = get_CIFAR10(hparams['augment'], hparams['dataroot'], True)\n", 30 | "image_shape, num_classes, _, test_svhn = get_SVHN(hparams['augment'], hparams['dataroot'], True)\n", 31 | "\n", 32 | "model = Glow(image_shape, hparams['hidden_channels'], hparams['K'], hparams['L'], hparams['actnorm_scale'],\n", 33 | " hparams['flow_permutation'], hparams['flow_coupling'], hparams['LU_decomposed'], num_classes,\n", 34 | " hparams['learn_top'], hparams['y_condition'])\n", 35 | "\n", 36 | "model.load_state_dict(torch.load(output_folder + model_name))\n", 37 | "model.set_actnorm_init()\n", 38 | "\n", 39 | "model = model.to(device)\n", 40 | "\n", 41 | "model = model.eval()" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "def compute_nll(dataset, model):\n", 51 | " dataloader = torch.utils.data.DataLoader(dataset, batch_size=512, num_workers=6)\n", 52 | " \n", 53 | " nlls = []\n", 54 | " for x,y in dataloader:\n", 55 | " x = x.to(device)\n", 56 | " \n", 57 | " if hparams['y_condition']:\n", 58 | " y = y.to(device)\n", 59 | " else:\n", 60 | " y = None\n", 61 | " \n", 62 | " with torch.no_grad():\n", 63 | " _, nll, _ = model(x, y_onehot=y)\n", 64 | " nlls.append(nll)\n", 65 | " \n", 66 | " return torch.cat(nlls).cpu()" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "metadata": { 73 | "scrolled": true 74 | }, 75 | "outputs": [], 76 | "source": [ 77 | 
"cifar_nll = compute_nll(test_cifar, model)\n", 78 | "svhn_nll = compute_nll(test_svhn, model)\n", 79 | "\n", 80 | "print(\"CIFAR NLL\", torch.mean(cifar_nll))\n", 81 | "print(\"SVHN NLL\", torch.mean(svhn_nll))" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "metadata": { 88 | "scrolled": false 89 | }, 90 | "outputs": [], 91 | "source": [ 92 | "plt.figure(figsize=(20,10))\n", 93 | "plt.title(\"Histogram Glow - trained on CIFAR10\")\n", 94 | "plt.xlabel(\"Negative bits per dimension\")\n", 95 | "plt.hist(-svhn_nll.numpy(), label=\"SVHN\", density=True, bins=30)\n", 96 | "plt.hist(-cifar_nll.numpy(), label=\"CIFAR10\", density=True, bins=50)\n", 97 | "plt.legend()\n", 98 | "plt.show()\n", 99 | "# plt.savefig(\"images/histogram_glow_cifar_svhn.png\", dpi=300)" 100 | ] 101 | } 102 | ], 103 | "metadata": { 104 | "kernelspec": { 105 | "display_name": "Python 3", 106 | "language": "python", 107 | "name": "python3" 108 | }, 109 | "language_info": { 110 | "codemirror_mode": { 111 | "name": "ipython", 112 | "version": 3 113 | }, 114 | "file_extension": ".py", 115 | "mimetype": "text/x-python", 116 | "name": "python", 117 | "nbconvert_exporter": "python", 118 | "pygments_lexer": "ipython3", 119 | "version": "3.8.5" 120 | } 121 | }, 122 | "nbformat": 4, 123 | "nbformat_minor": 2 124 | } 125 | -------------------------------------------------------------------------------- /Flops/Glow-PyTorch-master/README.md: -------------------------------------------------------------------------------- 1 | # Glow 2 | 3 | This repository implements the [Glow](https://arxiv.org/abs/1807.03039) model using PyTorch on the CIFAR-10 and SVHN dataset. We use the trained Glow to reproduce some of the results of the paper ["Do Deep Generative Models Know What They Don't Know?"](https://arxiv.org/abs/1810.09136): 4 | 5 | ![Histogram Glow - CIFAR10 and SVHN](images/histogram_glow_cifar_svhn.png) 6 | 7 | **To create histogram**: 8 | See [notebook](Do_deep_generative_models_know_what_they_dont_know.ipynb). 9 | Pretrained model (on CIFAR-10): [download](http://www.cs.ox.ac.uk/people/joost.vanamersfoort/glow.zip) (unzip before use). 10 | 11 | Note this pretrained model was created using the `affine` coupling layer, so it does not work well for generative sampling (see qualitative vs quantitative models in the Glow paper). The pretrained model achieves 3.39 bpd, while the original paper gets 3.35. The difference between our pretrained model and the paper is that we use batch size 64 (single GPU) and the paper uses 512 (8 GPU). 12 | 13 | This code uses some layers and groundwork from [glow-pytorch](https://github.com/chaiyujin/glow-pytorch), but is more modular, extendable, faster, easier to read and supports training on CIFAR-10 and SVHN. There are fewer dependencies and a consistent interface for new datasets. Thanks to [Milad](https://github.com/mi-lad) for comments and help with debugging. 14 | 15 | ## Setup and run 16 | 17 | The code has minimal dependencies. You need python 3.6+ and up to date versions of: 18 | 19 | ``` 20 | pytorch (tested on 1.1.0) 21 | torchvision 22 | pytorch-ignite 23 | tqdm 24 | ``` 25 | 26 | To install in a local conda: 27 | 28 | ``` 29 | conda install pytorch torchvision pytorch-ignite tqdm -c pytorch 30 | ``` 31 | 32 | **To train your own model:** 33 | 34 | ``` 35 | python train.py --download 36 | ``` 37 | 38 | Will download the CIFAR10 dataset for you, and start training. 
The defaults are tested on a `1080Ti`; Glow is a memory-hungry model, and it might be necessary to tune down the model size for your specific GPU. The output files will be sent to `output/`. 39 | 40 | Everything is configurable through command line arguments, see 41 | 42 | ``` 43 | python train.py --help 44 | ``` 45 | 46 | for what is possible. 47 | 48 | ## Evaluate 49 | 50 | There are two notebooks available for evaluation: 51 | 52 | * The [first notebook](Do_deep_generative_models_know_what_they_dont_know.ipynb) reproduces a plot from "Do Deep Generative Models Know What They Don't Know?" (see above) and computes the average bpd on the CIFAR-10 and SVHN test sets. 53 | * The [second notebook](Sample_from_Glow.ipynb) allows you to visualise samples from the model (this works best with a model trained using the `additive` coupling layer). 54 | 55 | 56 | ## Extensions 57 | 58 | There are several possible extensions: 59 | 60 | - Multiclass conditional training 61 | - multiGPU 62 | - port over the [tests](https://github.com/chaiyujin/glow-pytorch/blob/master/test_modules.py) 63 | 64 | PRs for any of these would be very welcome. If you find any problem, feel free to make an [issue](https://github.com/y0ast/Glow-PyTorch/issues) too. 65 | 66 | The model is trained using `adamax` instead of `adam` as in the original implementation. Using `adam` leads to an NLL of 3.48 (vs. 3.39 with `adamax`). Note: when using `adam` you need to set `warmup` to 1, otherwise optimisation gets stuck in a poor local minimum. It's unclear why `adamax` is so important and I'm curious to hear any ideas! 67 | 68 | ## References: 69 | 70 | ``` 71 | @inproceedings{kingma2018glow, 72 | title={Glow: Generative flow with invertible 1x1 convolutions}, 73 | author={Kingma, Durk P and Dhariwal, Prafulla}, 74 | booktitle={Advances in Neural Information Processing Systems}, 75 | pages={10215--10224}, 76 | year={2018} 77 | } 78 | 79 | @inproceedings{nalisnick2018do, 80 | title={Do Deep Generative Models Know What They Don't Know? 
}, 81 | author={Eric Nalisnick and Akihiro Matsukawa and Yee Whye Teh and Dilan Gorur and Balaji Lakshminarayanan}, 82 | booktitle={International Conference on Learning Representations}, 83 | year={2019}, 84 | url={https://openreview.net/forum?id=H1xwNhCcYm}, 85 | } 86 | ``` 87 | -------------------------------------------------------------------------------- /utils/msdnet_function.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from msd_args import arg_parser, arch_resume_names 4 | args = arg_parser.parse_args() 5 | 6 | args.grFactor = list(map(int, args.grFactor.split('-'))) 7 | args.bnFactor = list(map(int, args.bnFactor.split('-'))) 8 | args.nScales = len(args.grFactor) 9 | 10 | if args.use_valid: 11 | args.splits = ['train', 'val', 'test'] 12 | else: 13 | args.splits = ['train', 'val'] 14 | 15 | if args.data == 'cifar10': 16 | args.num_classes = 10 17 | elif args.data == 'cifar100': 18 | args.num_classes = 100 19 | else: 20 | args.num_classes = 1000 21 | 22 | def validate(val_loader, model, criterion): 23 | print(args.nBlocks) 24 | batch_time = AverageMeter() 25 | losses = AverageMeter() 26 | data_time = AverageMeter() 27 | top1, top5 = [], [] 28 | for i in range(args.nBlocks): 29 | top1.append(AverageMeter()) 30 | top5.append(AverageMeter()) 31 | 32 | # switch to evaluate mode 33 | model.eval() 34 | 35 | end = time.time() 36 | with torch.no_grad(): 37 | for i, (input, target) in enumerate(val_loader): 38 | target = target.cuda(non_blocking=True) 39 | input = input.cuda() 40 | 41 | input_var = torch.autograd.Variable(input) 42 | target_var = torch.autograd.Variable(target) 43 | 44 | data_time.update(time.time() - end) 45 | 46 | # compute output 47 | output, _ = model(input_var) 48 | if not isinstance(output, list): 49 | output = [output] 50 | 51 | loss = 0.0 52 | for j in range(len(output)): 53 | loss += criterion(output[j], target_var) 54 | 55 | # measure error and record loss 56 | losses.update(loss.item(), input.size(0)) 57 | 58 | for j in range(len(output)): 59 | err1, err5 = accuracy(output[j].data, target, topk=(1, 5)) 60 | top1[j].update(err1.item(), input.size(0)) 61 | top5[j].update(err5.item(), input.size(0)) 62 | 63 | # measure elapsed time 64 | batch_time.update(time.time() - end) 65 | end = time.time() 66 | 67 | if i % args.print_freq == 0: 68 | print('Epoch: [{0}/{1}]\t' 69 | 'Time {batch_time.avg:.3f}\t' 70 | 'Data {data_time.avg:.3f}\t' 71 | 'Loss {loss.val:.4f}\t' 72 | 'Err@1 {top1.val:.4f}\t' 73 | 'Err@5 {top5.val:.4f}'.format( 74 | i + 1, len(val_loader), 75 | batch_time=batch_time, data_time=data_time, 76 | loss=losses, top1=top1[-1], top5=top5[-1])) 77 | # break 78 | for j in range(args.nBlocks): 79 | print(' * Err@1 {top1.avg:.3f} Err@5 {top5.avg:.3f}'.format(top1=top1[j], top5=top5[j])) 80 | """ 81 | print('Exit {}\t' 82 | 'Err@1 {:.4f}\t' 83 | 'Err@5 {:.4f}'.format( 84 | j, top1[j].avg, top5[j].avg)) 85 | """ 86 | # print(' * Err@1 {top1.avg:.3f} Err@5 {top5.avg:.3f}'.format(top1=top1[-1], top5=top5[-1])) 87 | return losses.avg, top1[-1].avg, top5[-1].avg 88 | 89 | class AverageMeter(object): 90 | """Computes and stores the average and current value""" 91 | 92 | def __init__(self): 93 | self.reset() 94 | 95 | def reset(self): 96 | self.val = 0 97 | self.avg = 0 98 | self.sum = 0 99 | self.count = 0 100 | 101 | def update(self, val, n=1): 102 | self.val = val 103 | self.sum += val * n 104 | self.count += n 105 | self.avg = self.sum / self.count 106 | 107 | def accuracy(output, target, 
topk=(1,)): 108 | """Computes the error@k for the specified values of k""" 109 | maxk = max(topk) 110 | batch_size = target.size(0) 111 | 112 | _, pred = output.topk(maxk, 1, True, True) 113 | pred = pred.t() 114 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 115 | 116 | res = [] 117 | for k in topk: 118 | #correct_k = correct[:k].view(-1).float().sum(0) 119 | correct_k = correct[:k].reshape(-1).float().sum(0) 120 | # res.append(100.0 - correct_k.mul_(100.0 / batch_size)) 121 | res.append(correct_k.mul_(100.0 / batch_size)) 122 | return res -------------------------------------------------------------------------------- /msd_dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.datasets as dset 3 | import torchvision.transforms as transforms 4 | import torchvision.datasets as datasets 5 | import os 6 | 7 | 8 | def msd_get_dataloaders(args): 9 | train_loader, val_loader, test_loader = None, None, None 10 | if args.data == 'cifar10': 11 | normalize = transforms.Normalize(mean=[0.4914, 0.4824, 0.4467], 12 | std=[0.2471, 0.2435, 0.2616]) 13 | train_set = datasets.CIFAR10(args.data_root, train=True, download=True, 14 | transform=transforms.Compose([ 15 | transforms.RandomCrop(32, padding=4), 16 | transforms.RandomHorizontalFlip(), 17 | transforms.ToTensor(), 18 | normalize 19 | ])) 20 | val_set = datasets.CIFAR10(args.data_root, train=False, 21 | transform=transforms.Compose([ 22 | transforms.ToTensor(), 23 | normalize 24 | ])) 25 | elif args.data == 'cifar100': 26 | normalize = transforms.Normalize(mean=[0.5071, 0.4867, 0.4408], 27 | std=[0.2675, 0.2565, 0.2761]) 28 | train_set = datasets.CIFAR100(args.data_root, train=True, download=True, 29 | transform=transforms.Compose([ 30 | transforms.RandomCrop(32, padding=4), 31 | transforms.RandomHorizontalFlip(), 32 | transforms.ToTensor(), 33 | normalize 34 | ])) 35 | val_set = datasets.CIFAR100(args.data_root, train=False, download=True, 36 | transform=transforms.Compose([ 37 | transforms.ToTensor(), 38 | normalize 39 | ])) 40 | else: 41 | # ImageNet 42 | traindir = os.path.join(args.data_root, 'train') 43 | valdir = os.path.join(args.data_root, 'val') 44 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 45 | std=[0.229, 0.224, 0.225]) 46 | train_set = datasets.ImageFolder(traindir, transforms.Compose([ 47 | transforms.RandomResizedCrop(224), 48 | transforms.RandomHorizontalFlip(), 49 | transforms.ToTensor(), 50 | normalize 51 | ])) 52 | val_set = datasets.ImageFolder(valdir, transforms.Compose([ 53 | transforms.Resize(256), 54 | transforms.CenterCrop(224), 55 | transforms.ToTensor(), 56 | normalize 57 | ])) 58 | if args.use_valid: 59 | train_set_index = torch.randperm(len(train_set)) 60 | if os.path.exists(os.path.join(args.save, 'index.pth')): 61 | print('!!!!!! Load train_set_index !!!!!!') 62 | train_set_index = torch.load(os.path.join(args.save, 'index.pth')) 63 | else: 64 | print('!!!!!! 
Save train_set_index !!!!!!') 65 | torch.save(train_set_index, os.path.join(args.save, 'index.pth')) 66 | if args.data.startswith('cifar'): 67 | num_sample_valid = 5000 68 | else: 69 | num_sample_valid = 50000 70 | # num_sample_valid = len(val_set) 71 | print("------------------------------------") 72 | print("split num_sample_valid: %d" % num_sample_valid) 73 | print("------------------------------------") 74 | 75 | if 'train' in args.splits: 76 | train_loader = torch.utils.data.DataLoader( 77 | train_set, batch_size=args.batch_size, 78 | sampler=torch.utils.data.sampler.SubsetRandomSampler( 79 | train_set_index[:-num_sample_valid]), 80 | num_workers=args.workers, pin_memory=True) 81 | if 'val' in args.splits: 82 | val_loader = torch.utils.data.DataLoader( 83 | train_set, batch_size=args.batch_size, 84 | sampler=torch.utils.data.sampler.SubsetRandomSampler( 85 | train_set_index[-num_sample_valid:]), 86 | num_workers=args.workers, pin_memory=True) 87 | if 'test' in args.splits: 88 | test_loader = torch.utils.data.DataLoader( 89 | val_set, 90 | batch_size=args.batch_size, shuffle=False, 91 | num_workers=args.workers, pin_memory=True) 92 | else: 93 | if 'train' in args.splits: 94 | train_loader = torch.utils.data.DataLoader( 95 | train_set, 96 | batch_size=args.batch_size, shuffle=True, 97 | num_workers=args.workers, pin_memory=True) 98 | if 'val' or 'test' in args.splits: 99 | val_loader = torch.utils.data.DataLoader( 100 | val_set, 101 | batch_size=args.batch_size, shuffle=False, 102 | num_workers=args.workers, pin_memory=True) 103 | test_loader = val_loader 104 | 105 | return train_loader, val_loader, test_loader 106 | -------------------------------------------------------------------------------- /utils/svhn_loader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import os 4 | import os.path 5 | import numpy as np 6 | 7 | 8 | class SVHN(data.Dataset): 9 | url = "" 10 | filename = "" 11 | file_md5 = "" 12 | 13 | split_list = { 14 | 'train': ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat", 15 | "train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"], 16 | 'test': ["http://ufldl.stanford.edu/housenumbers/test_32x32.mat", 17 | "selected_test_32x32.mat", "eb5a983be6a315427106f1b164d9cef3"], 18 | 'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat", 19 | "extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"], 20 | 'train_and_extra': [ 21 | ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat", 22 | "train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"], 23 | ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat", 24 | "extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]]} 25 | 26 | def __init__(self, root, split='train', 27 | transform=None, target_transform=None, download=False): 28 | self.root = root 29 | self.transform = transform 30 | self.target_transform = target_transform 31 | self.split = split # training set or test set or extra set 32 | 33 | if self.split not in self.split_list: 34 | raise ValueError('Wrong split entered! 
Please use split="train" ' 35 | 'or split="extra" or split="test" ' 36 | 'or split="train_and_extra" ') 37 | 38 | if self.split == "train_and_extra": 39 | self.url = self.split_list[split][0][0] 40 | self.filename = self.split_list[split][0][1] 41 | self.file_md5 = self.split_list[split][0][2] 42 | else: 43 | self.url = self.split_list[split][0] 44 | self.filename = self.split_list[split][1] 45 | self.file_md5 = self.split_list[split][2] 46 | 47 | # import here rather than at top of file because this is 48 | # an optional dependency for torchvision 49 | import scipy.io as sio 50 | 51 | # reading(loading) mat file as array 52 | loaded_mat = sio.loadmat(os.path.join(root, self.filename)) 53 | 54 | if self.split == "test": 55 | self.data = loaded_mat['X'] 56 | self.targets = loaded_mat['y'] 57 | # Note label 10 == 0 so modulo operator required 58 | self.targets = (self.targets % 10).squeeze() # convert to zero-based indexing 59 | self.data = np.transpose(self.data, (3, 2, 0, 1)) 60 | else: 61 | self.data = loaded_mat['X'] 62 | self.targets = loaded_mat['y'] 63 | 64 | if self.split == "train_and_extra": 65 | extra_filename = self.split_list[split][1][1] 66 | loaded_mat = sio.loadmat(os.path.join(root, extra_filename)) 67 | self.data = np.concatenate([self.data, 68 | loaded_mat['X']], axis=3) 69 | self.targets = np.vstack((self.targets, 70 | loaded_mat['y'])) 71 | # Note label 10 == 0 so modulo operator required 72 | self.targets = (self.targets % 10).squeeze() # convert to zero-based indexing 73 | self.data = np.transpose(self.data, (3, 2, 0, 1)) 74 | 75 | def __getitem__(self, index): 76 | if self.split == "test": 77 | img, target = self.data[index], self.targets[index] 78 | else: 79 | img, target = self.data[index], self.targets[index] 80 | 81 | # doing this so that it is consistent with all other datasets 82 | # to return a PIL Image 83 | img = Image.fromarray(np.transpose(img, (1, 2, 0))) 84 | 85 | if self.transform is not None: 86 | img = self.transform(img) 87 | 88 | if self.target_transform is not None: 89 | target = self.target_transform(target) 90 | 91 | return img, target.astype(np.long) 92 | 93 | def __len__(self): 94 | if self.split == "test": 95 | return len(self.data) 96 | else: 97 | return len(self.data) 98 | 99 | def _check_integrity(self): 100 | root = self.root 101 | if self.split == "train_and_extra": 102 | md5 = self.split_list[self.split][0][2] 103 | fpath = os.path.join(root, self.filename) 104 | train_integrity = check_integrity(fpath, md5) 105 | extra_filename = self.split_list[self.split][1][1] 106 | md5 = self.split_list[self.split][1][2] 107 | fpath = os.path.join(root, extra_filename) 108 | return check_integrity(fpath, md5) and train_integrity 109 | else: 110 | md5 = self.split_list[self.split][2] 111 | fpath = os.path.join(root, self.filename) 112 | return check_integrity(fpath, md5) 113 | 114 | def download(self): 115 | if self.split == "train_and_extra": 116 | md5 = self.split_list[self.split][0][2] 117 | download_url(self.url, self.root, self.filename, md5) 118 | extra_filename = self.split_list[self.split][1][1] 119 | md5 = self.split_list[self.split][1][2] 120 | download_url(self.url, self.root, extra_filename, md5) 121 | else: 122 | md5 = self.split_list[self.split][2] 123 | download_url(self.url, self.root, self.filename, md5) 124 | -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/ops.txt: -------------------------------------------------------------------------------- 1 | 4259840 2 | 2293760 3 | 
1638400 4 | 1638400 5 | 314900480 6 | 1638400 7 | 629800960 8 | 655360 9 | 1638400 10 | 210042880 11 | 1638400 12 | 1638400 13 | 1638400 14 | 420085760 15 | 105512960 16 | 1638400 17 | 314900480 18 | 1638400 19 | 629800960 20 | 655360 21 | 1638400 22 | 210042880 23 | 1638400 24 | 1638400 25 | 1638400 26 | 420085760 27 | 105512960 28 | 1638400 29 | 314900480 30 | 1638400 31 | 629800960 32 | 655360 33 | 1638400 34 | 210042880 35 | 1638400 36 | 1638400 37 | 1638400 38 | 420085760 39 | 105512960 40 | 1638400 41 | 314900480 42 | 1638400 43 | 629800960 44 | 655360 45 | 1638400 46 | 210042880 47 | 1638400 48 | 1638400 49 | 1638400 50 | 420085760 51 | 105512960 52 | 1638400 53 | 314900480 54 | 1638400 55 | 629800960 56 | 655360 57 | 1638400 58 | 210042880 59 | 1638400 60 | 1638400 61 | 1638400 62 | 420085760 63 | 105512960 64 | 39403520 65 | 26296320 66 | 409600 67 | 78725120 68 | 409600 69 | 157450240 70 | 163840 71 | 409600 72 | 52510720 73 | 409600 74 | 409600 75 | 409600 76 | 105021440 77 | 26378240 78 | 409600 79 | 78725120 80 | 409600 81 | 157450240 82 | 163840 83 | 409600 84 | 52510720 85 | 409600 86 | 409600 87 | 409600 88 | 105021440 89 | 26378240 90 | 409600 91 | 78725120 92 | 409600 93 | 157450240 94 | 163840 95 | 409600 96 | 52510720 97 | 409600 98 | 409600 99 | 409600 100 | 105021440 101 | 26378240 102 | 409600 103 | 78725120 104 | 409600 105 | 157450240 106 | 163840 107 | 409600 108 | 52510720 109 | 409600 110 | 409600 111 | 409600 112 | 105021440 113 | 26378240 114 | 409600 115 | 78725120 116 | 409600 117 | 157450240 118 | 163840 119 | 409600 120 | 52510720 121 | 409600 122 | 409600 123 | 409600 124 | 105021440 125 | 26378240 126 | 9850880 127 | 6574080 128 | 102400 129 | 19681280 130 | 102400 131 | 39362560 132 | 40960 133 | 102400 134 | 13127680 135 | 102400 136 | 102400 137 | 102400 138 | 26255360 139 | 6594560 140 | 102400 141 | 19681280 142 | 102400 143 | 39362560 144 | 40960 145 | 102400 146 | 13127680 147 | 102400 148 | 102400 149 | 102400 150 | 26255360 151 | 6594560 152 | 102400 153 | 19681280 154 | 102400 155 | 39362560 156 | 40960 157 | 102400 158 | 13127680 159 | 102400 160 | 102400 161 | 102400 162 | 26255360 163 | 6594560 164 | 102400 165 | 19681280 166 | 102400 167 | 39362560 168 | 40960 169 | 102400 170 | 13127680 171 | 102400 172 | 102400 173 | 102400 174 | 26255360 175 | 6594560 176 | 102400 177 | 19681280 178 | 102400 179 | 39362560 180 | 40960 181 | 102400 182 | 13127680 183 | 102400 184 | 102400 185 | 102400 186 | 26255360 187 | 6594560 188 | 102400 189 | 19681280 190 | 102400 191 | 102400 192 | 102400 193 | 39362560 194 | 6594560 195 | 102400 196 | 13127680 197 | 204800 198 | 204800 199 | 102400 200 | 26255360 201 | 26255360 202 | 102400 203 | 19681280 204 | 102400 205 | 102400 206 | 102400 207 | 39362560 208 | 6594560 209 | 102400 210 | 13127680 211 | 204800 212 | 204800 213 | 102400 214 | 26255360 215 | 26255360 216 | 102400 217 | 19681280 218 | 102400 219 | 102400 220 | 102400 221 | 39362560 222 | 6594560 223 | 102400 224 | 13127680 225 | 204800 226 | 204800 227 | 102400 228 | 26255360 229 | 26255360 230 | 102400 231 | 19681280 232 | 102400 233 | 102400 234 | 102400 235 | 39362560 236 | 6594560 237 | 102400 238 | 13127680 239 | 204800 240 | 204800 241 | 102400 242 | 26255360 243 | 26255360 244 | 102400 245 | 19681280 246 | 102400 247 | 102400 248 | 102400 249 | 39362560 250 | 6594560 251 | 102400 252 | 13127680 253 | 204800 254 | 204800 255 | 102400 256 | 26255360 257 | 26255360 258 | 9928320 259 | 6646080 260 | 409600 261 | 78725120 262 | 409600 263 | 
409600 264 | 409600 265 | 157450240 266 | 26378240 267 | 409600 268 | 52510720 269 | 819200 270 | 819200 271 | 409600 272 | 105021440 273 | 105021440 274 | 409600 275 | 78725120 276 | 409600 277 | 409600 278 | 409600 279 | 157450240 280 | 26378240 281 | 409600 282 | 52510720 283 | 819200 284 | 819200 285 | 409600 286 | 105021440 287 | 105021440 288 | 409600 289 | 78725120 290 | 409600 291 | 409600 292 | 409600 293 | 157450240 294 | 26378240 295 | 409600 296 | 52510720 297 | 819200 298 | 819200 299 | 409600 300 | 105021440 301 | 105021440 302 | 409600 303 | 78725120 304 | 409600 305 | 409600 306 | 409600 307 | 157450240 308 | 26378240 309 | 409600 310 | 52510720 311 | 819200 312 | 819200 313 | 409600 314 | 105021440 315 | 105021440 316 | 409600 317 | 78725120 318 | 409600 319 | 409600 320 | 409600 321 | 157450240 322 | 26378240 323 | 409600 324 | 52510720 325 | 819200 326 | 819200 327 | 409600 328 | 105021440 329 | 105021440 330 | 409600 331 | 78725120 332 | 409600 333 | 409600 334 | 409600 335 | 157450240 336 | 26378240 337 | 409600 338 | 52510720 339 | 819200 340 | 819200 341 | 409600 342 | 105021440 343 | 105021440 344 | 39680640 345 | 26562880 346 | 1638400 347 | 314900480 348 | 1638400 349 | 1638400 350 | 1638400 351 | 629800960 352 | 105512960 353 | 1638400 354 | 210042880 355 | 3276800 356 | 3276800 357 | 1638400 358 | 420085760 359 | 420085760 360 | 1638400 361 | 314900480 362 | 1638400 363 | 1638400 364 | 1638400 365 | 629800960 366 | 105512960 367 | 1638400 368 | 210042880 369 | 3276800 370 | 3276800 371 | 1638400 372 | 420085760 373 | 420085760 374 | 1638400 375 | 314900480 376 | 1638400 377 | 1638400 378 | 1638400 379 | 629800960 380 | 105512960 381 | 1638400 382 | 210042880 383 | 3276800 384 | 3276800 385 | 1638400 386 | 420085760 387 | 420085760 388 | 1638400 389 | 314900480 390 | 1638400 391 | 1638400 392 | 1638400 393 | 629800960 394 | 105512960 395 | 1638400 396 | 210042880 397 | 3276800 398 | 3276800 399 | 1638400 400 | 420085760 401 | 420085760 402 | 1638400 403 | 314900480 404 | 1638400 405 | 1638400 406 | 1638400 407 | 629800960 408 | 105512960 409 | 1638400 410 | 210042880 411 | 3276800 412 | 3276800 413 | 1638400 414 | 420085760 415 | 420085760 416 | 1638400 417 | 314900480 418 | 1638400 419 | 1638400 420 | 1638400 421 | 629800960 422 | 105512960 423 | 1638400 424 | 210042880 425 | 3276800 426 | 3276800 427 | 1638400 428 | 420085760 429 | 420085760 430 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MOOD: Multi-level Out-of-distribution Detection 2 | *** STL will be removed from figures later *** 3 | 4 | This is a [PyTorch](http://pytorch.org) implementation for detecting out-of-distribution examples in neural networks. The method is described in the paper [MOOD: Multi-level Out-of-distribution Detection](http://arxiv.org/abs/2104.14726) by Ziqian Lin*, Sreya Dutta Roy* and Yixuan Li (*Authors contributed equally.). We propose a novel framework, multi-level out-of-distribution detection (MOOD), which exploits intermediate classifier outputs for dynamic and efficient OOD inference, where easy OOD examples can be effectively detected early without propagating to deeper layers. 5 |
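The core inference idea can be sketched in a few lines. The snippet below is only an illustration (the function names, the 5-exit setup, and the thresholds are assumptions, not this repository's API): each input is routed to an intermediate exit according to its complexity, approximated by its compressed size, and the OOD score is the energy score computed from that exit's logits.

```
# Illustrative sketch of multi-level OOD detection (hypothetical names, not the repo API).
import torch

def choose_exit(compressed_bytes, thresholds):
    # Route simple inputs (small compressed size) to early exits.
    for k, t in enumerate(thresholds):
        if compressed_bytes <= t:
            return k
    return len(thresholds)  # most complex inputs use the last exit

def energy_score(logits, temperature=1.0):
    # Higher score indicates more in-distribution (energy-based OOD scoring).
    return temperature * torch.logsumexp(logits / temperature, dim=-1)

exit_logits = [torch.randn(10) for _ in range(5)]   # one logit vector per exit (toy values)
k = choose_exit(compressed_bytes=900, thresholds=[600, 800, 1000, 1200])
score = energy_score(exit_logits[k])
is_ood = score.item() < 0.5                          # threshold tuned on validation data in practice
```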

6 | 7 |

8 | The method achieves up to 71.05% computational reduction in inference, while maintaining competitive OOD detection performance. 9 |
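The repository also ships per-layer operation counts for the Glow and PixelCNN++ models under `Flops/*/ops.txt` (one integer per line). A minimal sketch of totaling such a file, assuming that format:

```
# Sum per-layer operation counts stored one integer per line (path is an example).
def total_ops(path="Flops/pixel-cnn-pp-master/ops.txt"):
    with open(path) as f:
        return sum(int(line) for line in f if line.strip())

print("total: %.2f billion ops" % (total_ops() / 1e9))
```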

10 | 11 |

12 |

13 | 14 |

15 | 16 | ## Experimental Results 17 | 18 | We used the deep learning model [MSDNet](https://openaccess.thecvf.com/content_ICCV_2019/papers/Li_Improved_Techniques_for_Training_Adaptive_Deep_Networks_ICCV_2019_paper.pdf) for our experiments. The PyTorch implementation of [MSDNet](https://github.com/kalviny/IMTA) is provided by [Hao Li](https://github.com/andreasveit/densenet-pytorch). 19 | The experimental results are shown as follows. The definition of each metric can be found in the [paper](https://arxiv.org/abs/2104.14726). 20 | ![performance](./figs/performance.png) 21 | 22 | 23 | 24 | ## Pre-trained Models 25 | 26 | We provide two pre-trained neural networks: [MSDNet](https://drive.google.com/drive/folders/1SxytZVfrV_FWN3BkYl_LyPZKBVqW8LG0?usp=sharing) models trained on CIFAR-10 and CIFAR-100, respectively. Please put the unzipped files in the folder '/trained_model'. The test accuracies are given by: 27 | 28 | Architecture | CIFAR-10 | CIFAR-100 29 | ------------ | --------- | --------- 30 | MSDNet | 94.09 | 75.43 31 | 32 | ## Dataset 33 | 34 | ### Description 35 | We use CIFAR-10 and CIFAR-100 as in-distribution datasets, which are common benchmarks for OOD detection. For the OOD detection evaluation, we consider a total of 9 datasets with a diverse spectrum of image complexity. In order of increasing complexity, we use MNIST, K-MNIST, fashion-MNIST, LSUN (crop), SVHN, Textures, Places365, iSUN and LSUN (resize). All images are resized to 32×32 before being fed into the network. For each OOD dataset, we evaluate on the entire test split. 36 | 37 | ### Downloading Out-of-Distribution Datasets 38 | We provide download links for 6 out-of-distribution [datasets](https://drive.google.com/drive/folders/1ypLnPHgnukDO0bJxhDmpSejhxeoJuaPP?usp=sharing); please put the unzipped files in the folder '/data'. 39 | The other 2 in-distribution datasets and 4 out-of-distribution datasets are downloaded automatically by the code, since they are included in torchvision.datasets. 40 | 41 | Datasets | Download Through 42 | ------------------------------|----------------------- 43 | CIFAR-10 | torchvision.datasets 44 | CIFAR-100 | torchvision.datasets 45 | MNIST | torchvision.datasets 46 | K-MNIST | torchvision.datasets 47 | fashion-MNIST | torchvision.datasets 48 | LSUN (crop) | google drive 49 | SVHN | google drive 50 | Textures | google drive 51 | Places365 | google drive 52 | iSUN | google drive 53 | LSUN (resize) | google drive 54 | 55 | 56 | 57 | ## Running the code 58 | 59 | ### Dependencies 60 | * python 3.7 61 | * CUDA 10.2 62 | * PyTorch with GPU 63 | * Anaconda3 64 | * opencv 3.4.2 65 | * scikit-learn 66 | 67 | 68 | ### Running 69 | 70 | Here is an example reproducing the results of the MOOD method, where MSDNet is trained on CIFAR-10 and the out-of-distribution data includes 10 datasets. In the **root** directory, run 71 | 72 | ``` 73 | python main.py -ms energy -ml 5 -ma 1 -mc png 74 | ``` 75 | **Note:** Please choose arguments according to the following. 
76 | 77 | #### args 78 | * **args.score**: the arguments for the score function of the MOOD method are shown as follows 79 | 80 | Score Functions | args.score 81 | ------------------|-------- 82 | Energy Score | energy 83 | MSP Score | msp 84 | ODIN Score | odin 85 | Mahalanobis Score | mahalanobis 86 | * **args.id**: the arguments for the in-distribution datasets are shown as follows 87 | 88 | Neural Network Models | args.id 89 | ----------------------|-------- 90 | MSDNet trained on CIFAR-10 | cifar10 91 | MSDNet trained on CIFAR-100 | cifar100 92 | * **args.od**: the arguments for the out-of-distribution datasets are shown as follows 93 | 94 | Out-of-Distribution Datasets | args.od 95 | ------------------------------|----------------- 96 | MNIST | mnist 97 | K-MNIST | kmnist 98 | fashion-MNIST | fasionmnist 99 | LSUN (crop) | lsun 100 | SVHN | svhn 101 | Textures | dtd 102 | Places365 | place365 103 | iSUN | isun 104 | LSUN (resize) | lsunR 105 | * **args.compressor**: the arguments for the compressor of the MOOD method are shown as follows 106 | 107 | Image Compressor Method | args.compressor 108 | ----------------------|------------------ 109 | PNG | png 110 | * **args.adjusted**: the arguments for whether to use the adjusted score in the MOOD method are shown as follows 111 | 112 | Score Function | args.adjusted 113 | ------------------|------------------ 114 | Energy Score | 1 115 | MSP Score | 0 116 | ODIN Score | 0 117 | Mahalanobis Score | 0 118 | 119 | ### Outputs 120 | Here is an example of the output. 121 | 122 | ``` 123 | 124 | ********** auroc result cifar10 with energy ********** 125 | auroc fpr95 126 | OOD dataset exit@last MOOD exit@last MOOD 127 | mnist 0.9903 0.9979 0.0413 0.0036 128 | kmnist 0.9844 0.9986 0.0699 0.0033 129 | fasionmnist 0.9923 0.9991 0.0248 0.0011 130 | lsun 0.9873 0.9923 0.0591 0.0320 131 | svhn 0.9282 0.9649 0.3409 0.1716 132 | dtd 0.8229 0.8329 0.5537 0.5603 133 | place365 0.8609 0.8674 0.4568 0.4687 134 | isun 0.9384 0.9296 0.3179 0.3882 135 | lsunR 0.9412 0.9325 0.2911 0.3616 136 | average 0.9384 0.9461 0.2395 0.2212 137 | ``` 138 | 139 | ### For BibTeX citation 140 | ``` 141 | @inproceedings{lin2021mood, 142 | author = {Lin, Ziqian and Roy, Sreya Dutta and Li, Yixuan}, 143 | title = {MOOD: Multi-level Out-of-distribution Detection}, 144 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 145 | year = {2021} 146 | } 147 | ``` 148 | -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/main.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import argparse 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | from torch.optim import lr_scheduler 9 | from torchvision import datasets, transforms, utils 10 | from tensorboardX import SummaryWriter 11 | from utils import * 12 | from model import * 13 | from PIL import Image 14 | import sys 15 | parser = argparse.ArgumentParser() 16 | # data I/O 17 | parser.add_argument('-i', '--data_dir', type=str, 18 | default='data', help='Location for the dataset') 19 | parser.add_argument('-o', '--save_dir', type=str, default='models', 20 | help='Location for parameter checkpoints and samples') 21 | parser.add_argument('-d', '--dataset', type=str, 22 | default='cifar', help='Can be either cifar|mnist') 23 | parser.add_argument('-p', '--print_every', type=int, default=50, 24 | help='how many iterations between print statements') 25 | 
parser.add_argument('-t', '--save_interval', type=int, default=10, 26 | help='Every how many epochs to write checkpoint/samples?') 27 | parser.add_argument('-r', '--load_params', type=str, default=None, 28 | help='Restore training from previous model checkpoint?') 29 | # model 30 | parser.add_argument('-q', '--nr_resnet', type=int, default=5, 31 | help='Number of residual blocks per stage of the model') 32 | parser.add_argument('-n', '--nr_filters', type=int, default=160, 33 | help='Number of filters to use across the model. Higher = larger model.') 34 | parser.add_argument('-m', '--nr_logistic_mix', type=int, default=10, 35 | help='Number of logistic components in the mixture. Higher = more flexible model') 36 | parser.add_argument('-l', '--lr', type=float, 37 | default=0.0002, help='Base learning rate') 38 | parser.add_argument('-e', '--lr_decay', type=float, default=0.999995, 39 | help='Learning rate decay, applied every step of the optimization') 40 | parser.add_argument('-b', '--batch_size', type=int, default=1, 41 | help='Batch size during training per GPU') 42 | parser.add_argument('-x', '--max_epochs', type=int, 43 | default=5000, help='How many epochs to run in total?') 44 | parser.add_argument('-s', '--seed', type=int, default=1, 45 | help='Random seed to use') 46 | args = parser.parse_args() 47 | 48 | # reproducibility 49 | torch.manual_seed(args.seed) 50 | np.random.seed(args.seed) 51 | 52 | model_name = 'pcnn_lr:{:.5f}_nr-resnet{}_nr-filters{}'.format(args.lr, args.nr_resnet, args.nr_filters) 53 | #assert not os.path.exists(os.path.join('runs', model_name)), '{} already exists!'.format(model_name) 54 | writer = SummaryWriter(log_dir=os.path.join('runs', model_name)) 55 | 56 | sample_batch_size = 25 57 | obs = (1, 28, 28) if 'mnist' in args.dataset else (3, 32, 32) 58 | input_channels = obs[0] 59 | rescaling = lambda x : (x - .5) * 2. 
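# Note: `rescaling` maps ToTensor's [0, 1] images to [-1, 1], the input range assumed by the discretized mixture-of-logistics loss; `rescaling_inv` (next line) undoes this before samples are saved.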
60 | rescaling_inv = lambda x : .5 * x + .5 61 | kwargs = {'num_workers':1, 'pin_memory':True, 'drop_last':True} 62 | ds_transforms = transforms.Compose([transforms.ToTensor(), rescaling]) 63 | 64 | if 'mnist' in args.dataset : 65 | train_loader = torch.utils.data.DataLoader(datasets.MNIST(args.data_dir, download=True, 66 | train=True, transform=ds_transforms), batch_size=args.batch_size, 67 | shuffle=True, **kwargs) 68 | 69 | test_loader = torch.utils.data.DataLoader(datasets.MNIST(args.data_dir, train=False, 70 | transform=ds_transforms), batch_size=args.batch_size, shuffle=True, **kwargs) 71 | 72 | loss_op = lambda real, fake : discretized_mix_logistic_loss_1d(real, fake) 73 | sample_op = lambda x : sample_from_discretized_mix_logistic_1d(x, args.nr_logistic_mix) 74 | 75 | elif 'cifar' in args.dataset : 76 | train_loader = torch.utils.data.DataLoader(datasets.CIFAR10(args.data_dir, train=True, 77 | download=True, transform=ds_transforms), batch_size=args.batch_size, shuffle=True, **kwargs) 78 | 79 | test_loader = torch.utils.data.DataLoader(datasets.CIFAR10(args.data_dir, train=False, 80 | transform=ds_transforms), batch_size=args.batch_size, shuffle=True, **kwargs) 81 | 82 | loss_op = lambda real, fake : discretized_mix_logistic_loss(real, fake) 83 | sample_op = lambda x : sample_from_discretized_mix_logistic(x, args.nr_logistic_mix) 84 | else : 85 | raise Exception('{} dataset not in {mnist, cifar10}'.format(args.dataset)) 86 | 87 | model = PixelCNN(nr_resnet=args.nr_resnet, nr_filters=args.nr_filters, 88 | input_channels=input_channels, nr_logistic_mix=args.nr_logistic_mix) 89 | model = model.cuda() 90 | 91 | if args.load_params: 92 | load_part_of_model(model, args.load_params) 93 | # model.load_state_dict(torch.load(args.load_params)) 94 | print('model parameters loaded') 95 | 96 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 97 | scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=args.lr_decay) 98 | 99 | def sample(model): 100 | model.train(False) 101 | data = torch.zeros(sample_batch_size, obs[0], obs[1], obs[2]) 102 | data = data.cuda() 103 | for i in range(obs[1]): 104 | for j in range(obs[2]): 105 | data_v = Variable(data, volatile=True) 106 | out = model(data_v, sample=True) 107 | out_sample = sample_op(out) 108 | data[:, :, i, j] = out_sample.data[:, :, i, j] 109 | return data 110 | 111 | print('starting training') 112 | writes = 0 113 | for epoch in range(args.max_epochs): 114 | model.train(True) 115 | torch.cuda.synchronize() 116 | train_loss = 0. 117 | time_ = time.time() 118 | model.train() 119 | print(model) 120 | for batch_idx, (input,_) in enumerate(train_loader): 121 | input = input.cuda() 122 | input = Variable(input) 123 | print('input: ',input.shape) 124 | output = model(input) 125 | print('output: ',output.shape) 126 | 127 | loss = loss_op(input, output) 128 | 129 | 130 | sys.exit() 131 | optimizer.zero_grad() 132 | loss.backward() 133 | optimizer.step() 134 | train_loss += loss.data 135 | if (batch_idx +1) % args.print_every == 0 : 136 | deno = args.print_every * args.batch_size * np.prod(obs) * np.log(2.) 137 | writer.add_scalar('train/bpd', (train_loss / deno), writes) 138 | print('loss : {:.4f}, time : {:.4f}'.format( 139 | (train_loss / deno), 140 | (time.time() - time_))) 141 | train_loss = 0. 142 | writes += 1 143 | time_ = time.time() 144 | 145 | 146 | # decrease learning rate 147 | scheduler.step() 148 | 149 | torch.cuda.synchronize() 150 | model.eval() 151 | test_loss = 0. 
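# Test pass below: per-batch losses are accumulated in nats, and `deno` (number of evaluated pixel values times ln 2) converts the running total to bits per dimension for logging.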
152 | for batch_idx, (input,_) in enumerate(test_loader): 153 | input = input.cuda() 154 | input_var = Variable(input) 155 | output = model(input_var) 156 | loss = loss_op(input_var, output) 157 | test_loss += loss.data[0] 158 | del loss, output 159 | 160 | deno = batch_idx * args.batch_size * np.prod(obs) * np.log(2.) 161 | writer.add_scalar('test/bpd', (test_loss / deno), writes) 162 | print('test loss : %s' % (test_loss / deno)) 163 | 164 | if (epoch + 1) % args.save_interval == 0: 165 | torch.save(model.state_dict(), 'models/{}_{}.pth'.format(model_name, epoch)) 166 | print('sampling...') 167 | sample_t = sample(model) 168 | sample_t = rescaling_inv(sample_t) 169 | utils.save_image(sample_t,'images/{}_{}.png'.format(model_name, epoch), 170 | nrow=5, padding=0) 171 | -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/model.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | from layers import * 7 | from utils import * 8 | import numpy as np 9 | 10 | class PixelCNNLayer_up(nn.Module): 11 | def __init__(self, nr_resnet, nr_filters, resnet_nonlinearity): 12 | super(PixelCNNLayer_up, self).__init__() 13 | self.nr_resnet = nr_resnet 14 | # stream from pixels above 15 | self.u_stream = nn.ModuleList([gated_resnet(nr_filters, down_shifted_conv2d, 16 | resnet_nonlinearity, skip_connection=0) 17 | for _ in range(nr_resnet)]) 18 | 19 | # stream from pixels above and to thes left 20 | self.ul_stream = nn.ModuleList([gated_resnet(nr_filters, down_right_shifted_conv2d, 21 | resnet_nonlinearity, skip_connection=1) 22 | for _ in range(nr_resnet)]) 23 | 24 | def forward(self, u, ul): 25 | u_list, ul_list = [], [] 26 | 27 | for i in range(self.nr_resnet): 28 | u = self.u_stream[i](u) 29 | ul = self.ul_stream[i](ul, a=u) 30 | u_list += [u] 31 | ul_list += [ul] 32 | 33 | return u_list, ul_list 34 | 35 | 36 | class PixelCNNLayer_down(nn.Module): 37 | def __init__(self, nr_resnet, nr_filters, resnet_nonlinearity): 38 | super(PixelCNNLayer_down, self).__init__() 39 | self.nr_resnet = nr_resnet 40 | # stream from pixels above 41 | self.u_stream = nn.ModuleList([gated_resnet(nr_filters, down_shifted_conv2d, 42 | resnet_nonlinearity, skip_connection=1) 43 | for _ in range(nr_resnet)]) 44 | 45 | # stream from pixels above and to thes left 46 | self.ul_stream = nn.ModuleList([gated_resnet(nr_filters, down_right_shifted_conv2d, 47 | resnet_nonlinearity, skip_connection=2) 48 | for _ in range(nr_resnet)]) 49 | 50 | def forward(self, u, ul, u_list, ul_list): 51 | for i in range(self.nr_resnet): 52 | u = self.u_stream[i](u, a=u_list.pop()) 53 | ul = self.ul_stream[i](ul, a=torch.cat((u, ul_list.pop()), 1)) 54 | 55 | return u, ul 56 | 57 | 58 | class PixelCNN(nn.Module): 59 | def __init__(self, nr_resnet=5, nr_filters=80, nr_logistic_mix=10, 60 | resnet_nonlinearity='concat_elu', input_channels=3): 61 | super(PixelCNN, self).__init__() 62 | if resnet_nonlinearity == 'concat_elu' : 63 | self.resnet_nonlinearity = lambda x : concat_elu(x) 64 | else : 65 | raise Exception('right now only concat elu is supported as resnet nonlinearity.') 66 | 67 | self.nr_filters = nr_filters 68 | self.input_channels = input_channels 69 | self.nr_logistic_mix = nr_logistic_mix 70 | self.right_shift_pad = nn.ZeroPad2d((1, 0, 0, 0)) 71 | self.down_shift_pad = nn.ZeroPad2d((0, 0, 1, 0)) 72 | 73 | down_nr_resnet = 
[nr_resnet] + [nr_resnet + 1] * 2 74 | self.down_layers = nn.ModuleList([PixelCNNLayer_down(down_nr_resnet[i], nr_filters, 75 | self.resnet_nonlinearity) for i in range(3)]) 76 | 77 | self.up_layers = nn.ModuleList([PixelCNNLayer_up(nr_resnet, nr_filters, 78 | self.resnet_nonlinearity) for _ in range(3)]) 79 | 80 | self.downsize_u_stream = nn.ModuleList([down_shifted_conv2d(nr_filters, nr_filters, 81 | stride=(2,2)) for _ in range(2)]) 82 | 83 | self.downsize_ul_stream = nn.ModuleList([down_right_shifted_conv2d(nr_filters, 84 | nr_filters, stride=(2,2)) for _ in range(2)]) 85 | 86 | self.upsize_u_stream = nn.ModuleList([down_shifted_deconv2d(nr_filters, nr_filters, 87 | stride=(2,2)) for _ in range(2)]) 88 | 89 | self.upsize_ul_stream = nn.ModuleList([down_right_shifted_deconv2d(nr_filters, 90 | nr_filters, stride=(2,2)) for _ in range(2)]) 91 | 92 | self.u_init = down_shifted_conv2d(input_channels + 1, nr_filters, filter_size=(2,3), 93 | shift_output_down=True) 94 | 95 | self.ul_init = nn.ModuleList([down_shifted_conv2d(input_channels + 1, nr_filters, 96 | filter_size=(1,3), shift_output_down=True), 97 | down_right_shifted_conv2d(input_channels + 1, nr_filters, 98 | filter_size=(2,1), shift_output_right=True)]) 99 | 100 | num_mix = 3 if self.input_channels == 1 else 10 101 | self.nin_out = nin(nr_filters, num_mix * nr_logistic_mix) 102 | self.init_padding = None 103 | 104 | 105 | def forward(self, x, sample=False): 106 | # similar as done in the tf repo : 107 | if self.init_padding is None and not sample: 108 | print(1) 109 | xs = [int(y) for y in x.size()] 110 | padding = Variable(torch.ones(xs[0], 1, xs[2], xs[3]), requires_grad=False) 111 | self.init_padding = padding.cuda() if x.is_cuda else padding 112 | 113 | if sample : 114 | print('**********') 115 | xs = [int(y) for y in x.size()] 116 | padding = Variable(torch.ones(xs[0], 1, xs[2], xs[3]), requires_grad=False) 117 | padding = padding.cuda() if x.is_cuda else padding 118 | x = torch.cat((x, padding), 1) 119 | 120 | ### UP PASS ### 121 | print(2) 122 | print(x.shape) 123 | x = x if sample else torch.cat((x, self.init_padding), 1) 124 | print(x.shape) 125 | u_list = [self.u_init(x)] 126 | ul_list = [self.ul_init[0](x) + self.ul_init[1](x)] 127 | for i in range(3): 128 | # resnet block 129 | u_out, ul_out = self.up_layers[i](u_list[-1], ul_list[-1]) 130 | u_list += u_out 131 | ul_list += ul_out 132 | 133 | if i != 2: 134 | # downscale (only twice) 135 | u_list += [self.downsize_u_stream[i](u_list[-1])] 136 | ul_list += [self.downsize_ul_stream[i](ul_list[-1])] 137 | 138 | ### DOWN PASS ### 139 | u = u_list.pop() 140 | ul = ul_list.pop() 141 | 142 | for i in range(3): 143 | # resnet block 144 | u, ul = self.down_layers[i](u, ul, u_list, ul_list) 145 | 146 | # upscale (only twice) 147 | if i != 2 : 148 | u = self.upsize_u_stream[i](u) 149 | ul = self.upsize_ul_stream[i](ul) 150 | 151 | x_out = self.nin_out(F.elu(ul)) 152 | 153 | assert len(u_list) == len(ul_list) == 0, pdb.set_trace() 154 | 155 | return x_out 156 | 157 | 158 | if __name__ == '__main__': 159 | ''' testing loss with tf version ''' 160 | np.random.seed(1) 161 | xx_t = (np.random.rand(15, 32, 32, 100) * 3).astype('float32') 162 | yy_t = np.random.uniform(-1, 1, size=(15, 32, 32, 3)).astype('float32') 163 | x_t = Variable(torch.from_numpy(xx_t)).cuda() 164 | y_t = Variable(torch.from_numpy(yy_t)).cuda() 165 | loss = discretized_mix_logistic_loss(y_t, x_t) 166 | 167 | ''' testing model and deconv dimensions ''' 168 | x = torch.cuda.FloatTensor(32, 3, 32, 32).uniform_(-1., 
1.) 169 | xv = Variable(x).cpu() 170 | ds = down_shifted_deconv2d(3, 40, stride=(2,2)) 171 | x_v = Variable(x) 172 | 173 | ''' testing loss compatibility ''' 174 | model = PixelCNN(nr_resnet=3, nr_filters=100, input_channels=x.size(1)) 175 | model = model.cuda() 176 | out = model(x_v) 177 | loss = discretized_mix_logistic_loss(x_v, out) 178 | print('loss : %s' % loss.data[0]) 179 | -------------------------------------------------------------------------------- /Flops/pixel-cnn-pp-master/layers.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | import pdb 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | from torch.nn.utils import weight_norm as wn 8 | import numpy as np 9 | 10 | class nin(nn.Module): 11 | def __init__(self, dim_in, dim_out): 12 | super(nin, self).__init__() 13 | self.lin_a = wn(nn.Linear(dim_in, dim_out)) 14 | self.dim_out = dim_out 15 | 16 | def forward(self, x): 17 | og_x = x 18 | # assumes pytorch ordering 19 | """ a network in network layer (1x1 CONV) """ 20 | # TODO : try with original ordering 21 | x = x.permute(0, 2, 3, 1) 22 | shp = [int(y) for y in x.size()] 23 | out = self.lin_a(x.contiguous().view(shp[0]*shp[1]*shp[2], shp[3])) 24 | shp[-1] = self.dim_out 25 | out = out.view(shp) 26 | return out.permute(0, 3, 1, 2) 27 | 28 | 29 | class down_shifted_conv2d(nn.Module): 30 | def __init__(self, num_filters_in, num_filters_out, filter_size=(2,3), stride=(1,1), 31 | shift_output_down=False, norm='weight_norm'): 32 | super(down_shifted_conv2d, self).__init__() 33 | 34 | assert norm in [None, 'batch_norm', 'weight_norm'] 35 | self.conv = nn.Conv2d(num_filters_in, num_filters_out, filter_size, stride) 36 | self.num_filters_in = num_filters_in 37 | self.num_filters_out = num_filters_out 38 | self.filter_size = filter_size 39 | 40 | self.shift_output_down = shift_output_down 41 | self.norm = norm 42 | self.pad = nn.ZeroPad2d((int((filter_size[1] - 1) / 2), # pad left 43 | int((filter_size[1] - 1) / 2), # pad right 44 | filter_size[0] - 1, # pad top 45 | 0) ) # pad down 46 | 47 | if norm == 'weight_norm': 48 | self.conv = wn(self.conv) 49 | elif norm == 'batch_norm': 50 | self.bn = nn.BatchNorm2d(num_filters_out) 51 | 52 | if shift_output_down : 53 | self.down_shift = lambda x : down_shift(x, pad=nn.ZeroPad2d((0, 0, 1, 0))) 54 | 55 | def forward(self, x): 56 | #print('down_shifted_conv2d') 57 | ops = 0 58 | x = self.pad(x) 59 | x = self.conv(x) 60 | ho, wo = x.shape[2], x.shape[3] 61 | ops = ops + ho*wo*self.num_filters_out * (self.filter_size[0]*self.filter_size[1]*self.num_filters_in) 62 | 63 | x = self.bn(x) if self.norm == 'batch_norm' else x 64 | if self.norm == 'batch_norm': 65 | ops = ops + x.numel() 66 | if self.norm == 'weight_norm': 67 | ops = ops + x.numel()*2 68 | print(ops) 69 | return self.down_shift(x) if self.shift_output_down else x 70 | 71 | 72 | class down_shifted_deconv2d(nn.Module): 73 | def __init__(self, num_filters_in, num_filters_out, filter_size=(2,3), stride=(1,1)): 74 | super(down_shifted_deconv2d, self).__init__() 75 | self.deconv = wn(nn.ConvTranspose2d(num_filters_in, num_filters_out, filter_size, stride, 76 | output_padding=1)) 77 | self.num_filters_in = num_filters_in 78 | self.num_filters_out = num_filters_out 79 | 80 | self.filter_size = filter_size 81 | self.stride = stride 82 | 83 | def forward(self, x): 84 | ops = 0 85 | ho, wo = x.shape[2], x.shape[3] 86 | x = self.deconv(x) 87 | ops = ops + 
ho*wo*self.num_filters_in * (self.filter_size[0]*self.filter_size[1]*self.num_filters_out) 88 | 89 | ops = ops + x.numel()*2 90 | 91 | xs = [int(y) for y in x.size()] 92 | print(ops) 93 | return x[:, :, :(xs[2] - self.filter_size[0] + 1), 94 | int((self.filter_size[1] - 1) / 2):(xs[3] - int((self.filter_size[1] - 1) / 2))] 95 | 96 | 97 | class down_right_shifted_conv2d(nn.Module): 98 | def __init__(self, num_filters_in, num_filters_out, filter_size=(2,2), stride=(1,1), 99 | shift_output_right=False, norm='weight_norm'): 100 | super(down_right_shifted_conv2d, self).__init__() 101 | 102 | assert norm in [None, 'batch_norm', 'weight_norm'] 103 | self.pad = nn.ZeroPad2d((filter_size[1] - 1, 0, filter_size[0] - 1, 0)) 104 | self.conv = nn.Conv2d(num_filters_in, num_filters_out, filter_size, stride=stride) 105 | self.num_filters_in = num_filters_in 106 | self.num_filters_out = num_filters_out 107 | self.filter_size = filter_size 108 | 109 | 110 | self.shift_output_right = shift_output_right 111 | self.norm = norm 112 | 113 | if norm == 'weight_norm': 114 | self.conv = wn(self.conv) 115 | elif norm == 'batch_norm': 116 | self.bn = nn.BatchNorm2d(num_filters_out) 117 | 118 | if shift_output_right : 119 | self.right_shift = lambda x : right_shift(x, pad=nn.ZeroPad2d((1, 0, 0, 0))) 120 | 121 | def forward(self, x): 122 | ops = 0 123 | x = self.pad(x) 124 | x = self.conv(x) 125 | ho, wo = x.shape[2], x.shape[3] 126 | ops = ops + ho*wo*self.num_filters_out * (self.filter_size[0]*self.filter_size[1]*self.num_filters_in) 127 | 128 | x = self.bn(x) if self.norm == 'batch_norm' else x 129 | if self.norm == 'batch_norm': 130 | ops = ops + x.numel() 131 | if self.norm == 'weight_norm': 132 | ops = ops + x.numel()*2 133 | print(ops) 134 | return self.right_shift(x) if self.shift_output_right else x 135 | 136 | 137 | class down_right_shifted_deconv2d(nn.Module): 138 | def __init__(self, num_filters_in, num_filters_out, filter_size=(2,2), stride=(1,1), 139 | shift_output_right=False): 140 | super(down_right_shifted_deconv2d, self).__init__() 141 | self.deconv = wn(nn.ConvTranspose2d(num_filters_in, num_filters_out, filter_size, 142 | stride, output_padding=1)) 143 | self.num_filters_in = num_filters_in 144 | self.num_filters_out = num_filters_out 145 | 146 | self.filter_size = filter_size 147 | self.stride = stride 148 | 149 | def forward(self, x): 150 | ops = 0 151 | ho, wo = x.shape[2], x.shape[3] 152 | x = self.deconv(x) 153 | ops = ops + ho*wo*self.num_filters_in * (self.filter_size[0]*self.filter_size[1]*self.num_filters_out) 154 | 155 | ops = ops + x.numel()*2 156 | 157 | xs = [int(y) for y in x.size()] 158 | x = x[:, :, :(xs[2] - self.filter_size[0] + 1):, :(xs[3] - self.filter_size[1] + 1)] 159 | print(ops) 160 | return x 161 | 162 | 163 | ''' 164 | skip connection parameter : 0 = no skip connection 165 | 1 = skip connection where skip input size === input size 166 | 2 = skip connection where skip input size === 2 * input size 167 | ''' 168 | class gated_resnet(nn.Module): 169 | def __init__(self, num_filters, conv_op, nonlinearity=concat_elu, skip_connection=0): 170 | super(gated_resnet, self).__init__() 171 | self.skip_connection = skip_connection 172 | self.nonlinearity = nonlinearity 173 | self.conv_input = conv_op(2 * num_filters, num_filters) # cuz of concat elu 174 | 175 | if skip_connection != 0 : 176 | self.nin_skip = nin(2 * skip_connection * num_filters, num_filters) 177 | self.nin_i = 2 * skip_connection * num_filters 178 | self.nin_o = num_filters 179 | 180 | self.dropout = 
nn.Dropout2d(0.5) 181 | self.conv_out = conv_op(2 * num_filters, 2 * num_filters) 182 | 183 | 184 | def forward(self, og_x, a=None): 185 | ops = 0 186 | x = self.conv_input(self.nonlinearity(og_x)) 187 | if a is not None : 188 | x += self.nin_skip(self.nonlinearity(a)) 189 | ops = ops + self.nonlinearity(a).numel()*self.nin_i 190 | x = self.nonlinearity(x) 191 | x = self.dropout(x) 192 | x = self.conv_out(x) 193 | a, b = torch.chunk(x, 2, dim=1) 194 | c3 = a * F.sigmoid(b) 195 | ops = ops + b.numel()*3 196 | ops = ops + c3.numel() 197 | print(ops) 198 | return og_x + c3 199 | -------------------------------------------------------------------------------- /utils/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | from torchvision import datasets 4 | 5 | def get_dataloader(name,normalizer,bs): 6 | if name == 'cifar10': 7 | dataloader = cifar10(normalizer,bs) 8 | elif name == 'cifar100': 9 | dataloader = cifar100(normalizer,bs) 10 | elif name == 'mnist': 11 | dataloader = mnist(normalizer,bs) 12 | elif name == 'kmnist': 13 | dataloader = kmnist(normalizer,bs) 14 | elif name == 'fasionmnist': 15 | dataloader = fasionmnist(normalizer,bs) 16 | elif name == 'svhn': 17 | dataloader = svhn(normalizer,bs) 18 | elif name == 'stl10': 19 | dataloader = stl10(normalizer,bs) 20 | elif name == 'dtd': 21 | dataloader = dtd(normalizer,bs) 22 | elif name == 'place365': 23 | dataloader = place365(normalizer,bs) 24 | elif name == 'lsun': 25 | dataloader = lsun(normalizer,bs) 26 | elif name == 'lsunR': 27 | dataloader = lsunR(normalizer,bs) 28 | elif name == 'isun': 29 | dataloader = isun(normalizer,bs) 30 | elif name == 'celebA': 31 | dataloader = celebA(normalizer,bs) 32 | else: 33 | print('the dataset is not used in this project') 34 | return None 35 | return dataloader 36 | 37 | 38 | def cifar10(normalizer,bs): 39 | transform_cifar10 = transforms.Compose([transforms.ToTensor(), 40 | normalizer 41 | ]) 42 | dataloader = torch.utils.data.DataLoader( 43 | datasets.CIFAR10('data/cifar10', 44 | train=False, 45 | download=True, 46 | transform=transform_cifar10), 47 | batch_size=bs, 48 | shuffle=False, 49 | num_workers=1, 50 | pin_memory=True) 51 | return dataloader 52 | 53 | def celebA(normalizer,bs): 54 | transformer = transforms.Compose([transforms.Resize(32), 55 | transforms.ToTensor(), 56 | normalizer 57 | ]) 58 | dataloader = torch.utils.data.DataLoader( 59 | datasets.CelebA('data/celebA', 60 | split='test', 61 | download=True, 62 | transform=transformer), 63 | batch_size=bs, 64 | shuffle=False, 65 | num_workers=1, 66 | pin_memory=True) 67 | return dataloader 68 | 69 | def cifar100(normalizer,bs): 70 | transform_cifar100 = transforms.Compose([transforms.ToTensor(), 71 | normalizer 72 | ]) 73 | dataloader = torch.utils.data.DataLoader( 74 | datasets.CIFAR100('data/cifar100', 75 | train=False, 76 | download=True, 77 | transform=transform_cifar100), 78 | batch_size=bs, 79 | shuffle=False, 80 | num_workers=1, 81 | pin_memory=True) 82 | return dataloader 83 | def mnist(normalizer,bs): 84 | transformer = transforms.Compose([transforms.Grayscale(num_output_channels=3), 85 | transforms.Pad(padding=2), 86 | transforms.ToTensor(), 87 | normalizer 88 | ]) 89 | dataloader = torch.utils.data.DataLoader( 90 | datasets.MNIST('data/mnist', 91 | train=False, 92 | download=True, 93 | transform=transformer), 94 | batch_size=bs, 95 | shuffle=False, 96 | num_workers=1, 97 | pin_memory=True) 98 | return dataloader 99 
| def kmnist(normalizer,bs): 100 | transformer = transforms.Compose([transforms.Grayscale(num_output_channels=3), 101 | transforms.Pad(padding=2), 102 | transforms.ToTensor(), 103 | normalizer 104 | ]) 105 | dataloader = torch.utils.data.DataLoader( 106 | datasets.KMNIST('data/kmnist', 107 | train=False, 108 | download=True, 109 | transform=transformer), 110 | batch_size=bs, 111 | shuffle=False, 112 | num_workers=1, 113 | pin_memory=True) 114 | return dataloader 115 | def fasionmnist(normalizer,bs): 116 | transformer = transforms.Compose([transforms.Grayscale(num_output_channels=3), 117 | transforms.Pad(padding=2), 118 | transforms.ToTensor(), 119 | normalizer 120 | ]) 121 | dataloader = torch.utils.data.DataLoader( 122 | datasets.FashionMNIST('data/fasionmnist', 123 | train=False, 124 | download=True, 125 | transform=transformer), 126 | batch_size=bs, 127 | shuffle=False, 128 | num_workers=1, 129 | pin_memory=True) 130 | return dataloader 131 | ''' 132 | def svhn(normalizer,bs): 133 | transformer = transforms.Compose([transforms.ToTensor(), 134 | normalizer 135 | ]) 136 | dataloader = torch.utils.data.DataLoader( 137 | datasets.SVHN('data/svhn', 138 | split='test', 139 | download=True, 140 | transform=transformer), 141 | batch_size=bs, 142 | shuffle=False, 143 | num_workers=1, 144 | pin_memory=True) 145 | return dataloader 146 | ''' 147 | def stl10(normalizer,bs): 148 | transformer = transforms.Compose([transforms.Resize(32), 149 | transforms.ToTensor(), 150 | normalizer 151 | ]) 152 | dataloader = torch.utils.data.DataLoader( 153 | datasets.STL10('data/STL10', 154 | split='test', 155 | folds=0, 156 | download=(True), 157 | transform=transformer), 158 | batch_size=bs, 159 | shuffle=False, 160 | num_workers=1, 161 | pin_memory=True) 162 | return dataloader 163 | 164 | def svhn(normalizer,bs): 165 | import utils.svhn_loader as svhn 166 | transformer = transforms.Compose([transforms.ToTensor(), 167 | normalizer 168 | ]) 169 | info_svhn_dataset = svhn.SVHN('data/svhn', split='test', 170 | transform=transformer, download=True) 171 | dataloader = torch.utils.data.DataLoader( 172 | info_svhn_dataset, 173 | batch_size=bs, 174 | shuffle=False, 175 | num_workers=1, 176 | pin_memory=True) 177 | return dataloader 178 | 179 | def dtd(normalizer,bs): 180 | import torchvision 181 | transformer = transforms.Compose([transforms.Resize(32), 182 | transforms.CenterCrop(32),#32*40 exist 183 | transforms.ToTensor(), 184 | normalizer 185 | ]) 186 | info_dtd_dataset = torchvision.datasets.ImageFolder(root="data/dtd/images", 187 | transform=transformer) 188 | dataloader = torch.utils.data.DataLoader( 189 | info_dtd_dataset, 190 | batch_size=bs, 191 | shuffle=False, 192 | num_workers=1, 193 | pin_memory=True) 194 | return dataloader 195 | def place365(normalizer,bs): 196 | import torchvision 197 | transformer = transforms.Compose([transforms.Resize(32), 198 | #transforms.CenterCrop(32), 199 | transforms.ToTensor(), 200 | normalizer 201 | ]) 202 | info_place365_dataset = torchvision.datasets.ImageFolder(root="data/places365/test_subset", 203 | transform=transformer) 204 | dataloader = torch.utils.data.DataLoader( 205 | info_place365_dataset, 206 | batch_size=bs, 207 | shuffle=False, 208 | num_workers=1, 209 | pin_memory=True) 210 | return dataloader 211 | def lsun(normalizer,bs): 212 | import torchvision 213 | transformer = transforms.Compose([transforms.Resize(32), 214 | #transforms.CenterCrop(32), 215 | transforms.ToTensor(), 216 | normalizer 217 | ]) 218 | info_lsun_dataset = 
torchvision.datasets.ImageFolder("data/LSUN", 219 | transform=transformer) 220 | dataloader = torch.utils.data.DataLoader( 221 | info_lsun_dataset, 222 | batch_size=bs, 223 | shuffle=False, 224 | num_workers=1, 225 | pin_memory=True) 226 | return dataloader 227 | def lsunR(normalizer,bs): 228 | import torchvision 229 | transformer = transforms.Compose([transforms.Resize(32), 230 | #transforms.CenterCrop(32), 231 | transforms.ToTensor(), 232 | normalizer 233 | ]) 234 | info_lsunR_dataset = torchvision.datasets.ImageFolder("data/LSUN_resize", 235 | transform=transformer) 236 | dataloader = torch.utils.data.DataLoader( 237 | info_lsunR_dataset, 238 | batch_size=bs, 239 | shuffle=False, 240 | num_workers=1, 241 | pin_memory=True) 242 | return dataloader 243 | def isun(normalizer,bs): 244 | import torchvision 245 | transformer = transforms.Compose([transforms.Resize(32), 246 | #transforms.CenterCrop(32), 247 | transforms.ToTensor(), 248 | normalizer 249 | ]) 250 | info_isun_dataset = torchvision.datasets.ImageFolder("data/iSUN", 251 | transform=transformer) 252 | dataloader = torch.utils.data.DataLoader( 253 | info_isun_dataset, 254 | batch_size=bs, 255 | shuffle=False, 256 | num_workers=1, 257 | pin_memory=True) 258 | return dataloader 259 | 260 | -------------------------------------------------------------------------------- /Flops/Glow-PyTorch-master/model.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from modules import ( 7 | Conv2d, 8 | Conv2dZeros, 9 | ActNorm2d, 10 | InvertibleConv1x1, 11 | Permute2d, 12 | LinearZeros, 13 | SqueezeLayer, 14 | Split2d, 15 | gaussian_likelihood, 16 | gaussian_sample, 17 | ) 18 | from utils import split_feature, uniform_binning_correction 19 | 20 | 21 | def get_block(in_channels, out_channels, hidden_channels): 22 | block = nn.Sequential( 23 | Conv2d(in_channels, hidden_channels), 24 | # in_channels * hidden_channels 25 | nn.ReLU(inplace=False), 26 | Conv2d(hidden_channels, hidden_channels, kernel_size=(1, 1)), 27 | nn.ReLU(inplace=False), 28 | Conv2dZeros(hidden_channels, out_channels), 29 | ) 30 | return block 31 | 32 | 33 | class FlowStep(nn.Module): 34 | def __init__( 35 | self, 36 | in_channels, 37 | hidden_channels, 38 | actnorm_scale, 39 | flow_permutation, 40 | flow_coupling, 41 | LU_decomposed, 42 | ): 43 | super().__init__() 44 | self.in_channels = in_channels 45 | self.hidden_channels = hidden_channels 46 | 47 | self.flow_coupling = flow_coupling 48 | 49 | self.actnorm = ActNorm2d(in_channels, actnorm_scale) 50 | 51 | # 2. permute 52 | if flow_permutation == "invconv": 53 | self.invconv = InvertibleConv1x1(in_channels, LU_decomposed=LU_decomposed) 54 | self.flow_permutation = lambda z, logdet, rev: self.invconv(z, logdet, rev) 55 | elif flow_permutation == "shuffle": 56 | self.shuffle = Permute2d(in_channels, shuffle=True) 57 | self.flow_permutation = lambda z, logdet, rev: ( 58 | self.shuffle(z, rev), 59 | logdet, 60 | ) 61 | else: 62 | self.reverse = Permute2d(in_channels, shuffle=False) 63 | self.flow_permutation = lambda z, logdet, rev: ( 64 | self.reverse(z, rev), 65 | logdet, 66 | ) 67 | 68 | # 3. 
coupling 69 | if flow_coupling == "additive": 70 | self.block = get_block(in_channels // 2, in_channels // 2, hidden_channels) 71 | elif flow_coupling == "affine": 72 | self.block = get_block(in_channels // 2, in_channels, hidden_channels) 73 | 74 | def forward(self, input, logdet=None, reverse=False): 75 | if not reverse: 76 | return self.normal_flow(input, logdet) 77 | else: 78 | return self.reverse_flow(input, logdet) 79 | 80 | def normal_flow(self, input, logdet): 81 | assert input.size(1) % 2 == 0 82 | 83 | ops = 0 84 | # 1. actnorm 85 | z, logdet = self.actnorm(input, logdet=logdet, reverse=False) 86 | #print( '# 1. actnorm' ) 87 | #print(input.shape) 88 | #print(z.shape) 89 | #print(logdet) 90 | ops = ops + input.numel()*2 91 | # 2. permute 92 | z, logdet = self.flow_permutation(z, logdet, False) 93 | #print( '# 2. permute' ) 94 | #print(z.shape) 95 | #print(logdet) 96 | ops = ops + z.shape[2]*z.shape[3] *z.shape[1] *(1*1*z.shape[1] ) 97 | # 3. coupling 98 | z1, z2 = split_feature(z, "split") 99 | if self.flow_coupling == "additive": 100 | z2 = z2 + self.block(z1) 101 | elif self.flow_coupling == "affine": 102 | h = self.block(z1) 103 | #print(z1.shape[1]) 104 | #print(self.hidden_channels,z1.shape[2],z1.shape[3]) 105 | ops = ops + z1.shape[2]*z1.shape[3] *self.hidden_channels *(3*3*self.in_channels//2 ) 106 | ops = ops + z1.shape[2]*z1.shape[3] *self.hidden_channels *3 # actnorm relu 107 | ops = ops + z1.shape[2]*z1.shape[3] *self.hidden_channels *(1*1*self.hidden_channels) 108 | ops = ops + z1.shape[2]*z1.shape[3] *self.hidden_channels *3 # actnorm relu 109 | ops = ops + z1.shape[2]*z1.shape[3] *self.in_channels *(3*3*self.hidden_channels) 110 | #print(ops) 111 | shift, scale = split_feature(h, "cross") 112 | scale = torch.sigmoid(scale + 2.0) 113 | ops = ops + scale.numel()*(1+3) 114 | z2 = z2 + shift 115 | z2 = z2 * scale 116 | ops = ops + z2.numel() 117 | logdet = torch.sum(torch.log(scale), dim=[1, 2, 3]) + logdet 118 | z = torch.cat((z1, z2), dim=1) 119 | print(ops) 120 | return z, logdet 121 | 122 | def reverse_flow(self, input, logdet): 123 | assert input.size(1) % 2 == 0 124 | 125 | # 1.coupling 126 | z1, z2 = split_feature(input, "split") 127 | if self.flow_coupling == "additive": 128 | z2 = z2 - self.block(z1) 129 | elif self.flow_coupling == "affine": 130 | h = self.block(z1) 131 | shift, scale = split_feature(h, "cross") 132 | scale = torch.sigmoid(scale + 2.0) 133 | z2 = z2 / scale 134 | z2 = z2 - shift 135 | logdet = -torch.sum(torch.log(scale), dim=[1, 2, 3]) + logdet 136 | z = torch.cat((z1, z2), dim=1) 137 | 138 | # 2. permute 139 | z, logdet = self.flow_permutation(z, logdet, True) 140 | 141 | # 3. actnorm 142 | z, logdet = self.actnorm(z, logdet=logdet, reverse=True) 143 | 144 | return z, logdet 145 | 146 | 147 | class FlowNet(nn.Module): 148 | def __init__( 149 | self, 150 | image_shape, 151 | hidden_channels, 152 | K, 153 | L, 154 | actnorm_scale, 155 | flow_permutation, 156 | flow_coupling, 157 | LU_decomposed, 158 | ): 159 | super().__init__() 160 | 161 | self.layers = nn.ModuleList() 162 | self.output_shapes = [] 163 | 164 | self.K = K 165 | self.L = L 166 | 167 | H, W, C = image_shape 168 | 169 | for i in range(L): 170 | # 1. Squeeze 171 | C, H, W = C * 4, H // 2, W // 2 172 | self.layers.append(SqueezeLayer(factor=2)) 173 | self.output_shapes.append([-1, C, H, W]) 174 | 175 | # 2. 
K FlowStep 176 | for _ in range(K): 177 | self.layers.append( 178 | FlowStep( 179 | in_channels=C, 180 | hidden_channels=hidden_channels, 181 | actnorm_scale=actnorm_scale, 182 | flow_permutation=flow_permutation, 183 | flow_coupling=flow_coupling, 184 | LU_decomposed=LU_decomposed, 185 | ) 186 | ) 187 | self.output_shapes.append([-1, C, H, W]) 188 | 189 | # 3. Split2d 190 | if i < L - 1: 191 | self.layers.append(Split2d(num_channels=C)) 192 | self.output_shapes.append([-1, C // 2, H, W]) 193 | C = C // 2 194 | 195 | def forward(self, input, logdet=0.0, reverse=False, temperature=None): 196 | if reverse: 197 | return self.decode(input, temperature) 198 | else: 199 | return self.encode(input, logdet) 200 | 201 | def encode(self, z, logdet=0.0): 202 | for layer, shape in zip(self.layers, self.output_shapes): 203 | z, logdet = layer(z, logdet, reverse=False) 204 | return z, logdet 205 | 206 | def decode(self, z, temperature=None): 207 | for layer in reversed(self.layers): 208 | if isinstance(layer, Split2d): 209 | z, logdet = layer(z, logdet=0, reverse=True, temperature=temperature) 210 | else: 211 | z, logdet = layer(z, logdet=0, reverse=True) 212 | return z 213 | 214 | 215 | class Glow(nn.Module): 216 | def __init__( 217 | self, 218 | image_shape, 219 | hidden_channels, 220 | K, 221 | L, 222 | actnorm_scale, 223 | flow_permutation, 224 | flow_coupling, 225 | LU_decomposed, 226 | y_classes, 227 | learn_top, 228 | y_condition, 229 | ): 230 | super().__init__() 231 | self.flow = FlowNet( 232 | image_shape=image_shape, 233 | hidden_channels=hidden_channels, 234 | K=K, 235 | L=L, 236 | actnorm_scale=actnorm_scale, 237 | flow_permutation=flow_permutation, 238 | flow_coupling=flow_coupling, 239 | LU_decomposed=LU_decomposed, 240 | ) 241 | self.y_classes = y_classes 242 | self.y_condition = y_condition 243 | 244 | self.learn_top = learn_top 245 | 246 | # learned prior 247 | if learn_top: 248 | C = self.flow.output_shapes[-1][1] 249 | self.learn_top_fn = Conv2dZeros(C * 2, C * 2) 250 | 251 | if y_condition: 252 | C = self.flow.output_shapes[-1][1] 253 | self.project_ycond = LinearZeros(y_classes, 2 * C) 254 | self.project_class = LinearZeros(C, y_classes) 255 | 256 | self.register_buffer( 257 | "prior_h", 258 | torch.zeros( 259 | [ 260 | 1, 261 | self.flow.output_shapes[-1][1] * 2, 262 | self.flow.output_shapes[-1][2], 263 | self.flow.output_shapes[-1][3], 264 | ] 265 | ), 266 | ) 267 | 268 | def prior(self, data, y_onehot=None): 269 | if data is not None: 270 | h = self.prior_h.repeat(data.shape[0], 1, 1, 1) 271 | else: 272 | # Hardcoded a batch size of 32 here 273 | h = self.prior_h.repeat(32, 1, 1, 1) 274 | 275 | channels = h.size(1) 276 | 277 | if self.learn_top: 278 | h = self.learn_top_fn(h) 279 | 280 | if self.y_condition: 281 | assert y_onehot is not None 282 | yp = self.project_ycond(y_onehot) 283 | h += yp.view(h.shape[0], channels, 1, 1) 284 | 285 | return split_feature(h, "split") 286 | 287 | def forward(self, x=None, y_onehot=None, z=None, temperature=None, reverse=False): 288 | if reverse: 289 | return self.reverse_flow(z, y_onehot, temperature) 290 | else: 291 | return self.normal_flow(x, y_onehot) 292 | 293 | def normal_flow(self, x, y_onehot): 294 | b, c, h, w = x.shape 295 | 296 | x, logdet = uniform_binning_correction(x) 297 | 298 | z, objective = self.flow(x, logdet=logdet, reverse=False) 299 | 300 | mean, logs = self.prior(x, y_onehot) 301 | objective += gaussian_likelihood(mean, logs, z) 302 | 303 | if self.y_condition: 304 | y_logits = self.project_class(z.mean(2).mean(2)) 305 
| else: 306 | y_logits = None 307 | 308 | # Full objective - converted to bits per dimension 309 | bpd = (-objective) / (math.log(2.0) * c * h * w) 310 | 311 | print(z.shape) 312 | print(bpd.shape) 313 | return z, bpd, y_logits 314 | 315 | def reverse_flow(self, z, y_onehot, temperature): 316 | with torch.no_grad(): 317 | if z is None: 318 | mean, logs = self.prior(z, y_onehot) 319 | z = gaussian_sample(mean, logs, temperature) 320 | x = self.flow(z, temperature=temperature, reverse=True) 321 | return x 322 | 323 | def set_actnorm_init(self): 324 | for name, m in self.named_modules(): 325 | if isinstance(m, ActNorm2d): 326 | m.inited = True 327 | -------------------------------------------------------------------------------- /Flops/Glow-PyTorch-master/utils_flop.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | import os 4 | import shutil 5 | import torch 6 | import torchvision.transforms as transforms 7 | from PIL import Image 8 | import math 9 | import random 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import torch.autograd 15 | 16 | class KDLoss(nn.Module): 17 | def __init__(self, args): 18 | super(KDLoss, self).__init__() 19 | 20 | self.kld_loss = nn.KLDivLoss().cuda() 21 | self.ce_loss = nn.CrossEntropyLoss().cuda() 22 | self.log_softmax = nn.LogSoftmax(dim=1).cuda() 23 | self.softmax = nn.Softmax(dim=1).cuda() 24 | 25 | self.T = args.T 26 | self.gamma = args.gamma 27 | self.nBlocks = args.nBlocks 28 | 29 | def loss_fn_kd(self, outputs, targets, soft_targets): 30 | loss = self.ce_loss(outputs[-1], targets) 31 | T = self.T 32 | for i in range(self.nBlocks - 1): 33 | _ce = (1. - self.gamma) * self.ce_loss(outputs[i], targets) 34 | _kld = self.kld_loss(self.log_softmax(outputs[i] / T), self.softmax(soft_targets.detach() / T)) * self.gamma * T * T 35 | loss = loss + _ce + _kld 36 | return loss 37 | 38 | 39 | class MyRandomSizedCrop(object): 40 | def __init__(self, size, augmentation=0.08, interpolation=Image.BILINEAR): 41 | self.size = size 42 | self.interpolation = interpolation 43 | self.augmentation = augmentation 44 | 45 | def __call__(self, img): 46 | for _ in range(10): 47 | area = img.size[0] * img.size[1] 48 | target_area = random.uniform(self.augmentation, 1.0) * area 49 | aspect_ratio = random.uniform(3. / 4, 4. 
/ 3) 50 | 51 | w = int(round(math.sqrt(target_area * aspect_ratio))) 52 | h = int(round(math.sqrt(target_area / aspect_ratio))) 53 | 54 | if random.random() < 0.5: 55 | w, h = h, w 56 | 57 | if w <= img.size[0] and h <= img.size[1]: 58 | x1 = random.randint(0, img.size[0] - w) 59 | y1 = random.randint(0, img.size[1] - h) 60 | 61 | img = img.crop((x1, y1, x1 + w, y1 + h)) 62 | assert(img.size == (w, h)) 63 | 64 | return img.resize((self.size, self.size), self.interpolation) 65 | 66 | # Fallback 67 | scale = transforms.Scale(self.size, interpolation=self.interpolation) 68 | crop = transforms.CenterCrop(self.size) 69 | return crop(scale(img)) 70 | 71 | 72 | def create_save_folder(save_path, ignore_patterns=[]): 73 | if not os.path.exists(save_path): 74 | os.makedirs(save_path) 75 | print('create folder: ' + save_path) 76 | 77 | def adjust_learning_rate(optimizer, lr_init, decay_rate, epoch, num_epochs, args): 78 | """Decay Learning rate at 1/2 and 3/4 of the num_epochs""" 79 | lr = lr_init 80 | if args.data == 'imagenet': 81 | if epoch >= 30: 82 | lr *= decay_rate 83 | elif epoch >= 60: 84 | lr *= decay_rate ** 2 85 | else: 86 | if epoch >= num_epochs * 0.75: 87 | lr *= decay_rate**2 88 | elif epoch >= num_epochs * 0.5: 89 | lr *= decay_rate 90 | for param_group in optimizer.param_groups: 91 | param_group['lr'] = lr 92 | return lr 93 | 94 | def save_checkpoint(state, args, is_best, filename, result): 95 | print(args) 96 | result_filename = os.path.join(args.save, 'scores.tsv') 97 | model_dir = os.path.join(args.save, 'save_models') 98 | latest_filename = os.path.join(model_dir, 'latest.txt') 99 | model_filename = os.path.join(model_dir, filename) 100 | best_filename = os.path.join(model_dir, 'model_best.pth.tar') 101 | os.makedirs(args.save, exist_ok=True) 102 | os.makedirs(model_dir, exist_ok=True) 103 | print("=> saving checkpoint '{}'".format(model_filename)) 104 | 105 | torch.save(state, model_filename) 106 | 107 | with open(result_filename, 'w') as f: 108 | print('\n'.join(result), file=f) 109 | 110 | with open(latest_filename, 'w') as fout: 111 | fout.write(model_filename) 112 | if is_best: 113 | shutil.copyfile(model_filename, best_filename) 114 | 115 | print("=> saved checkpoint '{}'".format(model_filename)) 116 | return 117 | 118 | def load_checkpoint(args): 119 | model_dir = os.path.join(args.save, 'save_models') 120 | latest_filename = os.path.join(model_dir, 'latest.txt') 121 | if os.path.exists(latest_filename): 122 | with open(latest_filename, 'r') as fin: 123 | model_filename = fin.readlines()[0] 124 | else: 125 | return None 126 | print("=> loading checkpoint '{}'".format(model_filename)) 127 | state = torch.load(model_filename) 128 | print("=> loaded checkpoint '{}'".format(model_filename)) 129 | return state 130 | 131 | 132 | def get_optimizer(model, args): 133 | if args.optimizer == 'sgd': 134 | return torch.optim.SGD(model.parameters(), args.lr, 135 | momentum=args.momentum, nesterov=args.nesterov, 136 | weight_decay=args.weight_decay) 137 | elif args.optimizer == 'rmsprop': 138 | return torch.optim.RMSprop(model.parameters(), args.lr, 139 | alpha=args.alpha, 140 | weight_decay=args.weight_decay) 141 | elif args.optimizer == 'adam': 142 | return torch.optim.Adam(model.parameters(), args.lr, 143 | beta=(args.beta1, args.beta2), 144 | weight_decay=args.weight_decay) 145 | else: 146 | raise NotImplementedError 147 | 148 | 149 | class AverageMeter(object): 150 | """Computes and stores the average and current value""" 151 | 152 | def __init__(self): 153 | self.reset() 154 | 155 
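    # Descriptive note: AverageMeter keeps a running sum and count so that
    # `avg` is always the mean of every value seen so far, weighted by the
    # batch size passed as `n`. A small illustrative sketch (the values are
    # made up for the example):
    #
    #     meter = AverageMeter()
    #     meter.update(0.5, n=4)   # sum=2.0, count=4, avg=0.5
    #     meter.update(1.0, n=1)   # sum=3.0, count=5, avg=0.6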
| def reset(self): 156 | self.val = 0 157 | self.avg = 0 158 | self.sum = 0 159 | self.count = 0 160 | 161 | def update(self, val, n=1): 162 | self.val = val 163 | self.sum += val * n 164 | self.count += n 165 | self.avg = self.sum / self.count 166 | 167 | 168 | def error(output, target, topk=(1,)): 169 | """Computes the error@k for the specified values of k""" 170 | maxk = max(topk) 171 | batch_size = target.size(0) 172 | 173 | _, pred = output.topk(maxk, 1, True, True) 174 | pred = pred.t() 175 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 176 | 177 | res = [] 178 | for k in topk: 179 | correct_k = correct[:k].view(-1).float().sum(0) 180 | res.append(100.0 - correct_k.mul_(100.0 / batch_size)) 181 | return res 182 | 183 | 184 | 185 | import torch 186 | from torch.autograd import Variable 187 | from functools import reduce 188 | import operator 189 | 190 | 191 | count_ops = 0 192 | count_params = 0 193 | 194 | 195 | def get_num_gen(gen): 196 | return sum(1 for x in gen) 197 | 198 | 199 | def is_pruned(layer): 200 | try: 201 | layer.mask 202 | return True 203 | except AttributeError: 204 | return False 205 | 206 | 207 | def is_leaf(model): 208 | return get_num_gen(model.children()) == 0 209 | 210 | 211 | def get_layer_info(layer): 212 | layer_str = str(layer) 213 | type_name = layer_str[:layer_str.find('(')].strip() 214 | return type_name 215 | 216 | 217 | def get_layer_param(model): 218 | return sum([reduce(operator.mul, i.size(), 1) for i in model.parameters()]) 219 | 220 | 221 | """ The input batch size should be 1 to call this function """ 222 | def measure_layer(layer, x): 223 | global count_ops, count_params 224 | delta_ops = 0 225 | delta_params = 0 226 | multi_add = 1 227 | type_name = get_layer_info(layer) 228 | print(str(layer), count_ops) 229 | ### ops_conv 230 | if type_name in ['Conv2d']: 231 | out_h = int((x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0]) / 232 | layer.stride[0] + 1) 233 | out_w = int((x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1]) / 234 | layer.stride[1] + 1) 235 | delta_ops = layer.in_channels * layer.out_channels * layer.kernel_size[0] * \ 236 | layer.kernel_size[1] * out_h * out_w / layer.groups * multi_add 237 | delta_params = get_layer_param(layer) 238 | 239 | ### ops_learned_conv 240 | elif type_name in ['LearnedGroupConv']: 241 | measure_layer(layer.relu, x) 242 | measure_layer(layer.norm, x) 243 | conv = layer.conv 244 | out_h = int((x.size()[2] + 2 * conv.padding[0] - conv.kernel_size[0]) / 245 | conv.stride[0] + 1) 246 | out_w = int((x.size()[3] + 2 * conv.padding[1] - conv.kernel_size[1]) / 247 | conv.stride[1] + 1) 248 | delta_ops = conv.in_channels * conv.out_channels * conv.kernel_size[0] * \ 249 | conv.kernel_size[1] * out_h * out_w / layer.condense_factor * multi_add 250 | delta_params = get_layer_param(conv) / layer.condense_factor 251 | 252 | ### ops_nonlinearity 253 | elif type_name in ['ReLU']: 254 | delta_ops = x.numel() 255 | delta_params = get_layer_param(layer) 256 | 257 | ### ops_pooling 258 | elif type_name in ['AvgPool2d', 'MaxPool2d']: 259 | in_w = x.size()[2] 260 | kernel_ops = layer.kernel_size * layer.kernel_size 261 | out_w = int((in_w + 2 * layer.padding - layer.kernel_size) / layer.stride + 1) 262 | out_h = int((in_w + 2 * layer.padding - layer.kernel_size) / layer.stride + 1) 263 | delta_ops = x.size()[0] * x.size()[1] * out_w * out_h * kernel_ops 264 | delta_params = get_layer_param(layer) 265 | 266 | elif type_name in ['AdaptiveAvgPool2d']: 267 | delta_ops = x.size()[0] * x.size()[1] * 
x.size()[2] * x.size()[3] 268 | delta_params = get_layer_param(layer) 269 | 270 | ### ops_linear 271 | elif type_name in ['Linear']: 272 | weight_ops = layer.weight.numel() * multi_add 273 | bias_ops = layer.bias.numel() 274 | delta_ops = x.size()[0] * (weight_ops + bias_ops) 275 | delta_params = get_layer_param(layer) 276 | 277 | ### ops_nothing 278 | elif type_name in ['BatchNorm2d', 'Dropout2d', 'DropChannel', 'Dropout']: 279 | delta_params = get_layer_param(layer) 280 | 281 | ### unknown layer type 282 | else: 283 | print('unknown layer type: %s' % type_name) 284 | #raise TypeError('unknown layer type: %s' % type_name) 285 | 286 | count_ops += delta_ops 287 | count_params += delta_params 288 | return 289 | 290 | 291 | def measure_model(model, H, W): 292 | global count_ops, count_params 293 | count_ops = 0 294 | count_params = 0 295 | data = Variable(torch.zeros(1, 3, H, W)).cuda() 296 | 297 | def modify_forward(model): 298 | for child in model.children(): 299 | if is_leaf(child): 300 | def new_forward(m): 301 | def lambda_forward(x): 302 | measure_layer(m, x) 303 | return m.old_forward(x) 304 | return lambda_forward 305 | child.old_forward = child.forward 306 | child.forward = new_forward(child) 307 | else: 308 | modify_forward(child) 309 | 310 | def restore_forward(model): 311 | for child in model.children(): 312 | # leaf node 313 | if is_leaf(child) and hasattr(child, 'old_forward'): 314 | child.forward = child.old_forward 315 | child.old_forward = None 316 | else: 317 | restore_forward(child) 318 | 319 | modify_forward(model) 320 | model.forward(data) 321 | restore_forward(model) 322 | 323 | return count_ops, count_params 324 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.parallel 6 | import torch.backends.cudnn as cudnn 7 | import torchvision.transforms as transforms 8 | from utils.dataloader import get_dataloader 9 | from utils.MOOD import get_ood_score, sample_estimator 10 | from utils.MOOD import auroc, fpr95 11 | #import argparse 12 | ''' 13 | mood_parser = argparse.ArgumentParser() 14 | mood_parser.add_argument('-s', '--score', type=str, 15 | default='energy', 16 | help='basic score for MOOD method, choose from: energy, msp, odin, mahalanobis') 17 | 18 | mood_parser.add_argument('-f', '--file', type=str, 19 | default='trained_model/msdnet_cifar10.pth.tar', 20 | help='model file for MSDNet') 21 | mood_parser.add_argument('-l', '--layer', type=int, 22 | default=5, 23 | help='# of exits for MSDNet') 24 | 25 | mood_parser.add_argument('-i', '--id', type=str, 26 | default='cifar10', 27 | help='in distribution dataset: cifar10 or cifar100') 28 | mood_parser.add_argument('-o', '--od', type=list, 29 | default=['mnist', 30 | 'kmnist', 31 | 'fasionmnist', 32 | 'lsun', 33 | 'svhn', 34 | 'dtd', 35 | 'stl10', 36 | 'place365', 37 | 'isun', 38 | 'lsunR' 39 | ], 40 | help='all 10 OOD datasets used in experiment') 41 | 42 | mood_parser.add_argument('-c', '--compressor', type=str, 43 | default='png', 44 | help='compressor for complexity') 45 | mood_parser.add_argument('-t', '--threshold', type=int, 46 | default=[0, 47 | 1*2700/5, 48 | 2*2700/5, 49 | 3*2700/5, 50 | 4*2700/5, 51 | 9999], 52 | 53 | help='the complex thresholds for different exits in MSDNet') 54 | mood_parser.add_argument('-a', '--adjusted', type=int, 55 | default=1, 56 | help='adjusted energy score: mode 
1: minus mean; mode 0: keep as original') 57 | 58 | mood_parser.add_argument('-b', '--bs', type=int, 59 | default=64, 60 | help='batch size') 61 | mood_args = mood_parser.parse_args() 62 | ''' 63 | 64 | if 1:#load and test model 65 | from msd_args import arg_parser 66 | import models 67 | from msd_dataloader import msd_get_dataloaders 68 | mood_args = arg_parser.parse_args() 69 | mood_args.grFactor = list(map(int, mood_args.grFactor.split('-'))) 70 | mood_args.bnFactor = list(map(int, mood_args.bnFactor.split('-'))) 71 | mood_args.nScales = len(mood_args.grFactor) 72 | 73 | if mood_args.use_valid: 74 | mood_args.splits = ['train', 'val', 'test'] 75 | else: 76 | mood_args.splits = ['train', 'val'] 77 | mood_args.data = mood_args.id 78 | if mood_args.data == 'cifar10': 79 | mood_args.num_classes = 10 80 | elif mood_args.data == 'cifar100': 81 | mood_args.num_classes = 100 82 | else: 83 | print('dataset not support!') 84 | 85 | model = getattr(models, mood_args.arch)(mood_args) 86 | model = torch.nn.DataParallel(model).cuda() 87 | 88 | criterion = nn.CrossEntropyLoss().cuda() 89 | 90 | cudnn.benchmark = True 91 | 92 | train_loader, val_loader, test_loader = msd_get_dataloaders(mood_args) 93 | print("*************************************") 94 | print(mood_args.use_valid, len(train_loader), len(val_loader)) 95 | print("*************************************") 96 | 97 | model.load_state_dict(torch.load(mood_args.file)['state_dict']) 98 | print(sum(p.numel() for p in model.parameters() if p.requires_grad)) 99 | 100 | model.eval() 101 | 102 | if 1: 103 | from utils.msdnet_function import validate 104 | val_loss, val_err1, val_err5 = validate(test_loader, model, criterion) 105 | 106 | 107 | 108 | if mood_args.id == 'cifar10': 109 | MEAN=[0.4914, 0.4824, 0.4467] 110 | STD=[0.2471, 0.2435, 0.2616] 111 | NM = [MEAN,STD] 112 | elif mood_args.id == 'cifar100': 113 | MEAN=[0.5071, 0.4867, 0.4408] 114 | STD=[0.2675, 0.2565, 0.2761] 115 | NM = [MEAN,STD] 116 | else: 117 | print('wrong indistribution dataset! 
use cifar10 or cifar100!') 118 | 119 | normalizer = transforms.Normalize(mean=MEAN, std=STD) 120 | print('calculating ood scores and complexity takes long time') 121 | print('process ',mood_args.id) 122 | 123 | dataloader = get_dataloader(mood_args.id, normalizer, mood_args.bs) 124 | if mood_args.score == 'mahalanobis': 125 | print('processing mahalanobis parameters') 126 | if mood_args.id == 'cifar10': 127 | num_classes = 10 128 | magnitude = 0.012 129 | elif mood_args.id == 'cifar100': 130 | num_classes = 100 131 | magnitude = 0.006 132 | else: 133 | print('did not support this in distribution dataset!') 134 | # get fake feature list 135 | model.eval() 136 | temp_x = torch.rand(2,3,32,32).cuda() 137 | temp_list = model(temp_x)[1] 138 | num_output = len(temp_list) 139 | feature_list = np.empty(num_output) 140 | count = 0 141 | for out in temp_list: 142 | feature_list[count] = out.size(1) 143 | count += 1 144 | sample_mean, precision = sample_estimator(model, num_classes, feature_list, dataloader) 145 | data_output = open('mahalanobis_parameters/sample_mean.pkl','wb') 146 | pickle.dump(sample_mean, data_output) 147 | data_output.close() 148 | data_output = open('mahalanobis_parameters/precision.pkl','wb') 149 | pickle.dump(precision, data_output) 150 | data_output.close() 151 | data_output = open('mahalanobis_parameters/num_classes.pkl','wb') 152 | pickle.dump(num_classes, data_output) 153 | data_output.close() 154 | data_output = open('mahalanobis_parameters/magnitude.pkl','wb') 155 | pickle.dump(magnitude, data_output) 156 | data_output.close() 157 | print('processing mahalanobis parameters finished!') 158 | 159 | i_score, i_adjusted_score, i_complexity = get_ood_score(data_name=mood_args.id, 160 | model=model, 161 | L=mood_args.layer, 162 | dataloader=dataloader, 163 | score_type=mood_args.score, 164 | threshold=mood_args.threshold, 165 | NM=NM, 166 | adjusted_mode=0, 167 | mean=None, 168 | cal_complexity=True 169 | ) 170 | mean=[] 171 | for i in range(mood_args.layer): 172 | mean.append( np.mean(i_score[i]) ) 173 | 174 | i_score, i_adjusted_score, i_complexity = get_ood_score(data_name=mood_args.id, 175 | model=model, 176 | L=mood_args.layer, 177 | dataloader=dataloader, 178 | score_type=mood_args.score, 179 | threshold=mood_args.threshold, 180 | NM=NM, 181 | adjusted_mode=mood_args.adjusted, 182 | mean=mean, 183 | cal_complexity=True 184 | ) 185 | auroc_base = [] 186 | fpr95_base = [] 187 | auroc_mood = [] 188 | fpr95_mood = [] 189 | auroc_for_barplot = [] 190 | complexity_for_arplot = [] 191 | for o_name in mood_args.od: 192 | print('process ',o_name) 193 | dataloader = get_dataloader(o_name, normalizer, mood_args.bs) 194 | o_score, o_adjusted_score, o_complexity = get_ood_score(data_name=o_name, 195 | model=model, 196 | L=mood_args.layer, 197 | dataloader=dataloader, 198 | score_type=mood_args.score, 199 | threshold=mood_args.threshold, 200 | NM=NM, 201 | adjusted_mode=mood_args.adjusted, 202 | mean=mean, 203 | cal_complexity=True 204 | ) 205 | auroc_base.append(auroc(i_score[-1], o_score[-1])) 206 | fpr95_base.append(fpr95(i_score[-1], o_score[-1])) 207 | auroc_mood.append(auroc(i_adjusted_score, o_adjusted_score)) 208 | fpr95_mood.append(fpr95(i_adjusted_score, o_adjusted_score)) 209 | auroc_for_barplot.append([auroc(i_score[i], o_score[i]) for i in range(mood_args.layer)]) 210 | complexity_for_arplot.append(o_complexity) 211 | 212 | print('********** auroc result ',mood_args.id,' with ',mood_args.score,' **********') 213 | print(' auroc fpr95 ') 214 | print('OOD dataset exit@last 
MOOD exit@last MOOD') 215 | for i in range(len(mood_args.od)): 216 | data_name=mood_args.od[i] 217 | data_name = data_name + ' '*(17-len(data_name)) 218 | print(data_name,"%.4f"%auroc_base[i],' ',"%.4f"%auroc_mood[i],' ',"%.4f"%fpr95_base[i],' ',"%.4f"%fpr95_mood[i]) 219 | data_name = 'average' 220 | data_name = data_name + ' '*(17-len(data_name)) 221 | print(data_name,"%.4f"%np.mean(auroc_base),' ',"%.4f"%np.mean(auroc_mood),' ',"%.4f"%np.mean(fpr95_base),' ',"%.4f"%np.mean(fpr95_mood)) 222 | 223 | 224 | if mood_args.score == 'energy' and mood_args.adjusted == 1 : 225 | flops = np.array([26621540, 51598536, 68873004, 88417936, 105102580]) 226 | auroc_score = np.array(auroc_for_barplot) 227 | S=20 228 | selected_datasets = mood_args.od 229 | selected_score = np.zeros_like(auroc_score) 230 | for k, complexity in enumerate(complexity_for_arplot): 231 | for i in range(mood_args.layer): 232 | index = (mood_args.threshold[i] 0.999, log_one_minus_cdf_min, tf.log(cdf_delta))) 87 | 88 | # robust version, that still works if probabilities are below 1e-5 (which never happens in our code) 89 | # tensorflow backpropagates through tf.select() by multiplying with zero instead of selecting: this requires use to use some ugly tricks to avoid potential NaNs 90 | # the 1e-12 in tf.maximum(cdf_delta, 1e-12) is never actually used as output, it's purely there to get around the tf.select() gradient issue 91 | # if the probability on a sub-pixel is below 1e-5, we use an approximation 92 | # based on the assumption that the log-density is constant in the bin of 93 | # the observed sub-pixel value 94 | 95 | inner_inner_cond = (cdf_delta > 1e-5).float() 96 | inner_inner_out = inner_inner_cond * torch.log(torch.clamp(cdf_delta, min=1e-12)) + (1. - inner_inner_cond) * (log_pdf_mid - np.log(127.5)) 97 | inner_cond = (x > 0.999).float() 98 | inner_out = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out 99 | cond = (x < -0.999).float() 100 | log_probs = cond * log_cdf_plus + (1. - cond) * inner_out 101 | log_probs = torch.sum(log_probs, dim=3) + log_prob_from_logits(logit_probs) 102 | 103 | return -torch.sum(log_sum_exp(log_probs)) 104 | 105 | 106 | def discretized_mix_logistic_loss_1d(x, l): 107 | """ log-likelihood for mixture of discretized logistics, assumes the data has been rescaled to [-1,1] interval """ 108 | # Pytorch ordering 109 | x = x.permute(0, 2, 3, 1) 110 | l = l.permute(0, 2, 3, 1) 111 | xs = [int(y) for y in x.size()] 112 | ls = [int(y) for y in l.size()] 113 | 114 | # here and below: unpacking the params of the mixture of logistics 115 | nr_mix = int(ls[-1] / 3) 116 | logit_probs = l[:, :, :, :nr_mix] 117 | l = l[:, :, :, nr_mix:].contiguous().view(xs + [nr_mix * 2]) # 2 for mean, scale 118 | means = l[:, :, :, :, :nr_mix] 119 | log_scales = torch.clamp(l[:, :, :, :, nr_mix:2 * nr_mix], min=-7.) 120 | # here and below: getting the means and adjusting them based on preceding 121 | # sub-pixels 122 | x = x.contiguous() 123 | x = x.unsqueeze(-1) + Variable(torch.zeros(xs + [nr_mix]).cuda(), requires_grad=False) 124 | 125 | # means = torch.cat((means[:, :, :, 0, :].unsqueeze(3), m2, m3), dim=3) 126 | centered_x = x - means 127 | inv_stdv = torch.exp(-log_scales) 128 | plus_in = inv_stdv * (centered_x + 1. / 255.) 129 | cdf_plus = F.sigmoid(plus_in) 130 | min_in = inv_stdv * (centered_x - 1. / 255.) 
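    # Descriptive comment: plus_in and min_in are the standardized upper and
    # lower edges of the discretization bin around x (the data lives in
    # [-1, 1], so each of the 256 intensity levels occupies a bin of width
    # 2/255). Under a logistic with the selected mean and scale, the mass of
    # the bin is
    #     P(x - 1/255 < X < x + 1/255) = sigmoid(plus_in) - sigmoid(min_in),
    # which is the cdf_delta computed below; the edge bins at -1 and +1 are
    # handled with the one-sided terms log_cdf_plus and log_one_minus_cdf_min.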
131 | cdf_min = F.sigmoid(min_in) 132 | # log probability for edge case of 0 (before scaling) 133 | log_cdf_plus = plus_in - F.softplus(plus_in) 134 | # log probability for edge case of 255 (before scaling) 135 | log_one_minus_cdf_min = -F.softplus(min_in) 136 | cdf_delta = cdf_plus - cdf_min # probability for all other cases 137 | mid_in = inv_stdv * centered_x 138 | # log probability in the center of the bin, to be used in extreme cases 139 | # (not actually used in our code) 140 | log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in) 141 | 142 | inner_inner_cond = (cdf_delta > 1e-5).float() 143 | inner_inner_out = inner_inner_cond * torch.log(torch.clamp(cdf_delta, min=1e-12)) + (1. - inner_inner_cond) * (log_pdf_mid - np.log(127.5)) 144 | inner_cond = (x > 0.999).float() 145 | inner_out = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out 146 | cond = (x < -0.999).float() 147 | log_probs = cond * log_cdf_plus + (1. - cond) * inner_out 148 | log_probs = torch.sum(log_probs, dim=3) + log_prob_from_logits(logit_probs) 149 | 150 | return -torch.sum(log_sum_exp(log_probs)) 151 | 152 | 153 | def to_one_hot(tensor, n, fill_with=1.): 154 | # we perform one hot encore with respect to the last axis 155 | one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_() 156 | if tensor.is_cuda : one_hot = one_hot.cuda() 157 | one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with) 158 | return Variable(one_hot) 159 | 160 | 161 | def sample_from_discretized_mix_logistic_1d(l, nr_mix): 162 | # Pytorch ordering 163 | l = l.permute(0, 2, 3, 1) 164 | ls = [int(y) for y in l.size()] 165 | xs = ls[:-1] + [1] #[3] 166 | 167 | # unpack parameters 168 | logit_probs = l[:, :, :, :nr_mix] 169 | l = l[:, :, :, nr_mix:].contiguous().view(xs + [nr_mix * 2]) # for mean, scale 170 | 171 | # sample mixture indicator from softmax 172 | temp = torch.FloatTensor(logit_probs.size()) 173 | if l.is_cuda : temp = temp.cuda() 174 | temp.uniform_(1e-5, 1. - 1e-5) 175 | temp = logit_probs.data - torch.log(- torch.log(temp)) 176 | _, argmax = temp.max(dim=3) 177 | 178 | one_hot = to_one_hot(argmax, nr_mix) 179 | sel = one_hot.view(xs[:-1] + [1, nr_mix]) 180 | # select logistic parameters 181 | means = torch.sum(l[:, :, :, :, :nr_mix] * sel, dim=4) 182 | log_scales = torch.clamp(torch.sum( 183 | l[:, :, :, :, nr_mix:2 * nr_mix] * sel, dim=4), min=-7.) 184 | u = torch.FloatTensor(means.size()) 185 | if l.is_cuda : u = u.cuda() 186 | u.uniform_(1e-5, 1. - 1e-5) 187 | u = Variable(u) 188 | x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u)) 189 | x0 = torch.clamp(torch.clamp(x[:, :, :, 0], min=-1.), max=1.) 190 | out = x0.unsqueeze(1) 191 | return out 192 | 193 | 194 | def sample_from_discretized_mix_logistic(l, nr_mix): 195 | # Pytorch ordering 196 | l = l.permute(0, 2, 3, 1) 197 | ls = [int(y) for y in l.size()] 198 | xs = ls[:-1] + [3] 199 | 200 | # unpack parameters 201 | logit_probs = l[:, :, :, :nr_mix] 202 | l = l[:, :, :, nr_mix:].contiguous().view(xs + [nr_mix * 3]) 203 | # sample mixture indicator from softmax 204 | temp = torch.FloatTensor(logit_probs.size()) 205 | if l.is_cuda : temp = temp.cuda() 206 | temp.uniform_(1e-5, 1. 
- 1e-5) 207 | temp = logit_probs.data - torch.log(- torch.log(temp)) 208 | _, argmax = temp.max(dim=3) 209 | 210 | one_hot = to_one_hot(argmax, nr_mix) 211 | sel = one_hot.view(xs[:-1] + [1, nr_mix]) 212 | # select logistic parameters 213 | means = torch.sum(l[:, :, :, :, :nr_mix] * sel, dim=4) 214 | log_scales = torch.clamp(torch.sum( 215 | l[:, :, :, :, nr_mix:2 * nr_mix] * sel, dim=4), min=-7.) 216 | coeffs = torch.sum(F.tanh( 217 | l[:, :, :, :, 2 * nr_mix:3 * nr_mix]) * sel, dim=4) 218 | # sample from logistic & clip to interval 219 | # we don't actually round to the nearest 8bit value when sampling 220 | u = torch.FloatTensor(means.size()) 221 | if l.is_cuda : u = u.cuda() 222 | u.uniform_(1e-5, 1. - 1e-5) 223 | u = Variable(u) 224 | x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u)) 225 | x0 = torch.clamp(torch.clamp(x[:, :, :, 0], min=-1.), max=1.) 226 | x1 = torch.clamp(torch.clamp( 227 | x[:, :, :, 1] + coeffs[:, :, :, 0] * x0, min=-1.), max=1.) 228 | x2 = torch.clamp(torch.clamp( 229 | x[:, :, :, 2] + coeffs[:, :, :, 1] * x0 + coeffs[:, :, :, 2] * x1, min=-1.), max=1.) 230 | 231 | out = torch.cat([x0.view(xs[:-1] + [1]), x1.view(xs[:-1] + [1]), x2.view(xs[:-1] + [1])], dim=3) 232 | # put back in Pytorch ordering 233 | out = out.permute(0, 3, 1, 2) 234 | return out 235 | 236 | 237 | 238 | ''' utilities for shifting the image around, efficient alternative to masking convolutions ''' 239 | def down_shift(x, pad=None): 240 | # Pytorch ordering 241 | xs = [int(y) for y in x.size()] 242 | # when downshifting, the last row is removed 243 | x = x[:, :, :xs[2] - 1, :] 244 | # padding left, padding right, padding top, padding bottom 245 | pad = nn.ZeroPad2d((0, 0, 1, 0)) if pad is None else pad 246 | return pad(x) 247 | 248 | 249 | def right_shift(x, pad=None): 250 | # Pytorch ordering 251 | xs = [int(y) for y in x.size()] 252 | # when righshifting, the last column is removed 253 | x = x[:, :, :, :xs[3] - 1] 254 | # padding left, padding right, padding top, padding bottom 255 | pad = nn.ZeroPad2d((1, 0, 0, 0)) if pad is None else pad 256 | return pad(x) 257 | 258 | 259 | def load_part_of_model(model, path): 260 | params = torch.load(path) 261 | added = 0 262 | for name, param in params.items(): 263 | if name in model.state_dict().keys(): 264 | try : 265 | model.state_dict()[name].copy_(param) 266 | added += 1 267 | except Exception as e: 268 | print(e) 269 | pass 270 | print('added %s of params:' % (added / float(len(model.state_dict().keys())))) 271 | -------------------------------------------------------------------------------- /Flops/Glow-PyTorch-master/modules.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from utils import split_feature, compute_same_pad 8 | 9 | 10 | def gaussian_p(mean, logs, x): 11 | """ 12 | lnL = -1/2 * { ln|Var| + ((X - Mu)^T)(Var^-1)(X - Mu) + kln(2*PI) } 13 | k = 1 (Independent) 14 | Var = logs ** 2 15 | """ 16 | c = math.log(2 * math.pi) 17 | return -0.5 * (logs * 2.0 + ((x - mean) ** 2) / torch.exp(logs * 2.0) + c) 18 | 19 | 20 | def gaussian_likelihood(mean, logs, x): 21 | p = gaussian_p(mean, logs, x) 22 | return torch.sum(p, dim=[1, 2, 3]) 23 | 24 | 25 | def gaussian_sample(mean, logs, temperature=1): 26 | # Sample from Gaussian with temperature 27 | z = torch.normal(mean, torch.exp(logs) * temperature) 28 | 29 | return z 30 | 31 | 32 | def squeeze2d(input, factor): 33 | if factor == 1: 34 | 
return input 35 | 36 | B, C, H, W = input.size() 37 | 38 | assert H % factor == 0 and W % factor == 0, "H or W modulo factor is not 0" 39 | 40 | x = input.view(B, C, H // factor, factor, W // factor, factor) 41 | x = x.permute(0, 1, 3, 5, 2, 4).contiguous() 42 | x = x.view(B, C * factor * factor, H // factor, W // factor) 43 | 44 | return x 45 | 46 | 47 | def unsqueeze2d(input, factor): 48 | if factor == 1: 49 | return input 50 | 51 | factor2 = factor ** 2 52 | 53 | B, C, H, W = input.size() 54 | 55 | assert C % (factor2) == 0, "C module factor squared is not 0" 56 | 57 | x = input.view(B, C // factor2, factor, factor, H, W) 58 | x = x.permute(0, 1, 4, 2, 5, 3).contiguous() 59 | x = x.view(B, C // (factor2), H * factor, W * factor) 60 | 61 | return x 62 | 63 | 64 | class _ActNorm(nn.Module): 65 | """ 66 | Activation Normalization 67 | Initialize the bias and scale with a given minibatch, 68 | so that the output per-channel have zero mean and unit variance for that. 69 | 70 | After initialization, `bias` and `logs` will be trained as parameters. 71 | """ 72 | 73 | def __init__(self, num_features, scale=1.0): 74 | super().__init__() 75 | # register mean and scale 76 | size = [1, num_features, 1, 1] 77 | self.bias = nn.Parameter(torch.zeros(*size)) 78 | self.logs = nn.Parameter(torch.zeros(*size)) 79 | self.num_features = num_features 80 | self.scale = scale 81 | self.inited = False 82 | 83 | def initialize_parameters(self, input): 84 | if not self.training: 85 | raise ValueError("In Eval mode, but ActNorm not inited") 86 | 87 | with torch.no_grad(): 88 | bias = -torch.mean(input.clone(), dim=[0, 2, 3], keepdim=True) 89 | vars = torch.mean((input.clone() + bias) ** 2, dim=[0, 2, 3], keepdim=True) 90 | logs = torch.log(self.scale / (torch.sqrt(vars) + 1e-6)) 91 | 92 | self.bias.data.copy_(bias.data) 93 | self.logs.data.copy_(logs.data) 94 | 95 | self.inited = True 96 | 97 | def _center(self, input, reverse=False): 98 | if reverse: 99 | return input - self.bias 100 | else: 101 | return input + self.bias 102 | 103 | def _scale(self, input, logdet=None, reverse=False): 104 | 105 | if reverse: 106 | input = input * torch.exp(-self.logs) 107 | else: 108 | input = input * torch.exp(self.logs) 109 | 110 | if logdet is not None: 111 | """ 112 | logs is log_std of `mean of channels` 113 | so we need to multiply by number of pixels 114 | """ 115 | b, c, h, w = input.shape 116 | 117 | dlogdet = torch.sum(self.logs) * h * w 118 | 119 | if reverse: 120 | dlogdet *= -1 121 | 122 | logdet = logdet + dlogdet 123 | 124 | return input, logdet 125 | 126 | def forward(self, input, logdet=None, reverse=False): 127 | self._check_input_dim(input) 128 | 129 | if not self.inited: 130 | self.initialize_parameters(input) 131 | 132 | if reverse: 133 | input, logdet = self._scale(input, logdet, reverse) 134 | input = self._center(input, reverse) 135 | else: 136 | input = self._center(input, reverse) 137 | input, logdet = self._scale(input, logdet, reverse) 138 | 139 | return input, logdet 140 | 141 | 142 | class ActNorm2d(_ActNorm): 143 | def __init__(self, num_features, scale=1.0): 144 | super().__init__(num_features, scale) 145 | 146 | def _check_input_dim(self, input): 147 | assert len(input.size()) == 4 148 | assert input.size(1) == self.num_features, ( 149 | "[ActNorm]: input should be in shape as `BCHW`," 150 | " channels should be {} rather than {}".format( 151 | self.num_features, input.size() 152 | ) 153 | ) 154 | 155 | 156 | class LinearZeros(nn.Module): 157 | def __init__(self, in_channels, out_channels, 
logscale_factor=3): 158 | super().__init__() 159 | 160 | self.linear = nn.Linear(in_channels, out_channels) 161 | self.linear.weight.data.zero_() 162 | self.linear.bias.data.zero_() 163 | 164 | self.logscale_factor = logscale_factor 165 | 166 | self.logs = nn.Parameter(torch.zeros(out_channels)) 167 | 168 | def forward(self, input): 169 | output = self.linear(input) 170 | return output * torch.exp(self.logs * self.logscale_factor) 171 | 172 | 173 | class Conv2d(nn.Module): 174 | def __init__( 175 | self, 176 | in_channels, 177 | out_channels, 178 | kernel_size=(3, 3), 179 | stride=(1, 1), 180 | padding="same", 181 | do_actnorm=True, 182 | weight_std=0.05, 183 | ): 184 | super().__init__() 185 | 186 | if padding == "same": 187 | padding = compute_same_pad(kernel_size, stride) 188 | elif padding == "valid": 189 | padding = 0 190 | 191 | self.conv = nn.Conv2d( 192 | in_channels, 193 | out_channels, 194 | kernel_size, 195 | stride, 196 | padding, 197 | bias=(not do_actnorm), 198 | ) 199 | 200 | # init weight with std 201 | self.conv.weight.data.normal_(mean=0.0, std=weight_std) 202 | 203 | if not do_actnorm: 204 | self.conv.bias.data.zero_() 205 | else: 206 | self.actnorm = ActNorm2d(out_channels) 207 | 208 | self.do_actnorm = do_actnorm 209 | 210 | def forward(self, input): 211 | x = self.conv(input) 212 | if self.do_actnorm: 213 | x, _ = self.actnorm(x) 214 | return x 215 | 216 | 217 | class Conv2dZeros(nn.Module): 218 | def __init__( 219 | self, 220 | in_channels, 221 | out_channels, 222 | kernel_size=(3, 3), 223 | stride=(1, 1), 224 | padding="same", 225 | logscale_factor=3, 226 | ): 227 | super().__init__() 228 | 229 | if padding == "same": 230 | padding = compute_same_pad(kernel_size, stride) 231 | elif padding == "valid": 232 | padding = 0 233 | 234 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding) 235 | 236 | self.conv.weight.data.zero_() 237 | self.conv.bias.data.zero_() 238 | 239 | self.logscale_factor = logscale_factor 240 | self.logs = nn.Parameter(torch.zeros(out_channels, 1, 1)) 241 | 242 | def forward(self, input): 243 | output = self.conv(input) 244 | return output * torch.exp(self.logs * self.logscale_factor) 245 | 246 | 247 | class Permute2d(nn.Module): 248 | def __init__(self, num_channels, shuffle): 249 | super().__init__() 250 | self.num_channels = num_channels 251 | self.indices = torch.arange(self.num_channels - 1, -1, -1, dtype=torch.long) 252 | self.indices_inverse = torch.zeros((self.num_channels), dtype=torch.long) 253 | 254 | for i in range(self.num_channels): 255 | self.indices_inverse[self.indices[i]] = i 256 | 257 | if shuffle: 258 | self.reset_indices() 259 | 260 | def reset_indices(self): 261 | shuffle_idx = torch.randperm(self.indices.shape[0]) 262 | self.indices = self.indices[shuffle_idx] 263 | 264 | for i in range(self.num_channels): 265 | self.indices_inverse[self.indices[i]] = i 266 | 267 | def forward(self, input, reverse=False): 268 | assert len(input.size()) == 4 269 | 270 | if not reverse: 271 | input = input[:, self.indices, :, :] 272 | return input 273 | else: 274 | return input[:, self.indices_inverse, :, :] 275 | 276 | 277 | class Split2d(nn.Module): 278 | def __init__(self, num_channels): 279 | super().__init__() 280 | self.conv = Conv2dZeros(num_channels // 2, num_channels) 281 | 282 | def split2d_prior(self, z): 283 | h = self.conv(z) 284 | return split_feature(h, "cross") 285 | 286 | def forward(self, input, logdet=0.0, reverse=False, temperature=None): 287 | if reverse: 288 | z1 = input 289 | mean, logs = 
self.split2d_prior(z1) 290 | z2 = gaussian_sample(mean, logs, temperature) 291 | z = torch.cat((z1, z2), dim=1) 292 | return z, logdet 293 | else: 294 | z1, z2 = split_feature(input, "split") 295 | mean, logs = self.split2d_prior(z1) 296 | logdet = gaussian_likelihood(mean, logs, z2) + logdet 297 | return z1, logdet 298 | 299 | 300 | class SqueezeLayer(nn.Module): 301 | def __init__(self, factor): 302 | super().__init__() 303 | self.factor = factor 304 | 305 | def forward(self, input, logdet=None, reverse=False): 306 | if reverse: 307 | output = unsqueeze2d(input, self.factor) 308 | else: 309 | output = squeeze2d(input, self.factor) 310 | 311 | return output, logdet 312 | 313 | 314 | class InvertibleConv1x1(nn.Module): 315 | def __init__(self, num_channels, LU_decomposed): 316 | super().__init__() 317 | w_shape = [num_channels, num_channels] 318 | w_init = torch.qr(torch.randn(*w_shape))[0] 319 | 320 | if not LU_decomposed: 321 | self.weight = nn.Parameter(torch.Tensor(w_init)) 322 | else: 323 | p, lower, upper = torch.lu_unpack(*torch.lu(w_init)) 324 | s = torch.diag(upper) 325 | sign_s = torch.sign(s) 326 | log_s = torch.log(torch.abs(s)) 327 | upper = torch.triu(upper, 1) 328 | l_mask = torch.tril(torch.ones(w_shape), -1) 329 | eye = torch.eye(*w_shape) 330 | 331 | self.register_buffer("p", p) 332 | self.register_buffer("sign_s", sign_s) 333 | self.lower = nn.Parameter(lower) 334 | self.log_s = nn.Parameter(log_s) 335 | self.upper = nn.Parameter(upper) 336 | self.l_mask = l_mask 337 | self.eye = eye 338 | 339 | self.w_shape = w_shape 340 | self.LU_decomposed = LU_decomposed 341 | 342 | def get_weight(self, input, reverse): 343 | b, c, h, w = input.shape 344 | 345 | if not self.LU_decomposed: 346 | dlogdet = torch.slogdet(self.weight)[1] * h * w 347 | if reverse: 348 | weight = torch.inverse(self.weight) 349 | else: 350 | weight = self.weight 351 | else: 352 | self.l_mask = self.l_mask.to(input.device) 353 | self.eye = self.eye.to(input.device) 354 | 355 | lower = self.lower * self.l_mask + self.eye 356 | 357 | u = self.upper * self.l_mask.transpose(0, 1).contiguous() 358 | u += torch.diag(self.sign_s * torch.exp(self.log_s)) 359 | 360 | dlogdet = torch.sum(self.log_s) * h * w 361 | 362 | if reverse: 363 | u_inv = torch.inverse(u) 364 | l_inv = torch.inverse(lower) 365 | p_inv = torch.inverse(self.p) 366 | 367 | weight = torch.matmul(u_inv, torch.matmul(l_inv, p_inv)) 368 | else: 369 | weight = torch.matmul(self.p, torch.matmul(lower, u)) 370 | 371 | return weight.view(self.w_shape[0], self.w_shape[1], 1, 1), dlogdet 372 | 373 | def forward(self, input, logdet=None, reverse=False): 374 | """ 375 | log-det = log|abs(|W|)| * pixels 376 | """ 377 | weight, dlogdet = self.get_weight(input, reverse) 378 | 379 | if not reverse: 380 | z = F.conv2d(input, weight) 381 | if logdet is not None: 382 | logdet = logdet + dlogdet 383 | return z, logdet 384 | else: 385 | z = F.conv2d(input, weight) 386 | if logdet is not None: 387 | logdet = logdet - dlogdet 388 | return z, logdet 389 | -------------------------------------------------------------------------------- /utils/MOOD.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | 4 | import cv2 5 | import pickle 6 | import numpy as np 7 | import time 8 | 9 | from torch.autograd import Variable 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | def calculate_complex(batch_data,NM): 14 | MEAN = NM[0] 15 | STD = NM[1] 16 | MEAN= 
torch.from_numpy(np.array(MEAN)).view([3,1,1]) 17 | STD = torch.from_numpy(np.array(STD)).view([3,1,1]) 18 | 19 | S = batch_data.shape 20 | complexity = np.zeros([S[0],1]) 21 | for i in range(S[0]): 22 | img = batch_data[i] 23 | img = (img*STD+MEAN)*255. 24 | img = img.numpy().transpose([1,2,0]) 25 | img = np.round(img).astype('uint8') 26 | cv2.imwrite('compressed.png', img, [cv2.IMWRITE_PNG_COMPRESSION , 9]) 27 | complexity[i] = os.path.getsize('compressed.png') 28 | return complexity 29 | 30 | def msp_score(pres, TF, L): 31 | for i in range(L): 32 | scores = np.max(F.softmax(pres[i], dim=1).detach().cpu().numpy(), axis=1) 33 | TF[i].append(scores) 34 | return scores 35 | 36 | def energy_score(pres, TF, L, T=1): 37 | for i in range(L): 38 | scores = T*torch.log( torch.sum( torch.exp(pres[i].detach().cpu().type(torch.DoubleTensor) ) / T, dim=1)).numpy() 39 | TF[i].append(scores) 40 | return scores 41 | 42 | def odin_score(inputs, TF, model, L, temper=1000, noiseMagnitude=0.001): 43 | for i in range(L): 44 | criterion = nn.CrossEntropyLoss() 45 | inputs = Variable(inputs, requires_grad = True) 46 | inputs = inputs.cuda() 47 | inputs.retain_grad() 48 | 49 | outputs = model(inputs)[0][i] 50 | 51 | maxIndexTemp = np.argmax(outputs.data.cpu().numpy(), axis=1) 52 | 53 | # Using temperature scaling 54 | outputs = outputs / temper 55 | 56 | labels = Variable(torch.LongTensor(maxIndexTemp).cuda()) 57 | loss = criterion(outputs, labels) 58 | loss.backward() 59 | 60 | # Normalizing the gradient to binary in {0, 1} 61 | gradient = torch.ge(inputs.grad.data, 0) 62 | gradient = (gradient.float() - 0.5) * 2 63 | 64 | # Adding small perturbations to images 65 | tempInputs = torch.add(inputs.data, -noiseMagnitude, gradient) 66 | outputs = model(Variable(tempInputs))[0][i] 67 | outputs = outputs / temper 68 | # Calculating the confidence after adding perturbations 69 | nnOutputs = outputs.data.cpu() 70 | nnOutputs = nnOutputs.numpy() 71 | nnOutputs = nnOutputs - np.max(nnOutputs, axis=1, keepdims=True) 72 | nnOutputs = np.exp(nnOutputs) / np.sum(np.exp(nnOutputs), axis=1, keepdims=True) 73 | scores = np.max(nnOutputs, axis=1) 74 | 75 | TF[i].append(scores) 76 | return scores 77 | 78 | def sample_estimator(model, num_classes, feature_list, data_loader): 79 | """ 80 | compute sample mean and precision (inverse of covariance) 81 | return: sample_class_mean: list of class mean 82 | precision: list of precisions 83 | """ 84 | import sklearn.covariance 85 | 86 | model.eval() 87 | group_lasso = sklearn.covariance.EmpiricalCovariance(assume_centered=False) 88 | correct, total = 0, 0 89 | num_output = len(feature_list) 90 | num_sample_per_class = np.empty(num_classes) 91 | num_sample_per_class.fill(0) 92 | list_features = [] 93 | for i in range(num_output): 94 | temp_list = [] 95 | for j in range(num_classes): 96 | temp_list.append(0) 97 | list_features.append(temp_list) 98 | 99 | for W in range(1): 100 | for data, target in data_loader: 101 | total += data.size(0) 102 | data = Variable(data) 103 | data = data.cuda() 104 | output, out_features = model(data) 105 | 106 | # get hidden features 107 | for i in range(num_output): 108 | out_features[i] = out_features[i].view(out_features[i].size(0), out_features[i].size(1), -1) 109 | out_features[i] = torch.mean(out_features[i].data, 2) 110 | 111 | # compute the accuracy 112 | output = output[-1] 113 | pred = output.data.max(1)[1] 114 | equal_flag = pred.eq(target.cuda()).cpu() 115 | correct += equal_flag.sum() 116 | 117 | # construct the sample matrix 118 | for i in 
range(data.size(0)): 119 | label = target[i] 120 | if num_sample_per_class[label] == 0: 121 | out_count = 0 122 | for out in out_features: 123 | list_features[out_count][label] = out[i].view(1, -1) 124 | out_count += 1 125 | else: 126 | out_count = 0 127 | for out in out_features: 128 | list_features[out_count][label] \ 129 | = torch.cat((list_features[out_count][label], out[i].view(1, -1)), 0) 130 | out_count += 1 131 | num_sample_per_class[label] += 1 132 | 133 | sample_class_mean = [] 134 | out_count = 0 135 | for num_feature in feature_list: 136 | temp_list = torch.Tensor(num_classes, int(num_feature)).cuda() 137 | for j in range(num_classes): 138 | temp_list[j] = torch.mean(list_features[out_count][j], 0) 139 | sample_class_mean.append(temp_list) 140 | out_count += 1 141 | 142 | precision = [] 143 | for k in range(num_output): 144 | X = 0 145 | for i in range(num_classes): 146 | if i == 0: 147 | X = list_features[k][i] - sample_class_mean[k][i] 148 | else: 149 | X = torch.cat((X, list_features[k][i] - sample_class_mean[k][i]), 0) 150 | 151 | # find inverse 152 | group_lasso.fit(X.cpu().numpy()) 153 | temp_precision = group_lasso.precision_ 154 | temp_precision = torch.from_numpy(temp_precision).float().cuda() 155 | precision.append(temp_precision) 156 | 157 | print('\n Training Accuracy:({:.2f}%)\n'.format(100. * correct / total)) 158 | 159 | return sample_class_mean, precision 160 | 161 | def mahalanobis_score(inputs, TF, model, L): 162 | data_input = open('mahalanobis_parameters/sample_mean.pkl','rb') 163 | sample_mean = pickle.load(data_input) 164 | data_input.close() 165 | data_input = open('mahalanobis_parameters/precision.pkl','rb') 166 | precision = pickle.load(data_input) 167 | data_input.close() 168 | data_input = open('mahalanobis_parameters/num_classes.pkl','rb') 169 | num_classes = pickle.load(data_input) 170 | data_input.close() 171 | data_input = open('mahalanobis_parameters/magnitude.pkl','rb') 172 | magnitude = pickle.load(data_input) 173 | data_input.close() 174 | for layer_index in range(L): 175 | data = Variable(inputs, requires_grad = True) 176 | data = data.cuda() 177 | data.retain_grad() 178 | out_features = model(data)[1][layer_index] 179 | 180 | out_features = out_features.view(out_features.size(0), out_features.size(1), -1) 181 | out_features = torch.mean(out_features, 2) 182 | 183 | gaussian_score = 0 184 | for i in range(num_classes): 185 | batch_sample_mean = sample_mean[layer_index][i] 186 | zero_f = out_features.data - batch_sample_mean 187 | term_gau = -0.5*torch.mm(torch.mm(zero_f, precision[layer_index]), zero_f.t()).diag() 188 | if i == 0: 189 | gaussian_score = term_gau.view(-1,1) 190 | else: 191 | gaussian_score = torch.cat((gaussian_score, term_gau.view(-1,1)), 1) 192 | 193 | # Input_processing 194 | sample_pred = gaussian_score.max(1)[1] 195 | batch_sample_mean = sample_mean[layer_index].index_select(0, sample_pred) 196 | zero_f = out_features - Variable(batch_sample_mean) 197 | pure_gau = -0.5*torch.mm(torch.mm(zero_f, Variable(precision[layer_index])), zero_f.t()).diag() 198 | loss = torch.mean(-pure_gau) 199 | loss.backward() 200 | 201 | gradient = torch.ge(data.grad.data, 0) 202 | gradient = (gradient.float() - 0.5) * 2 203 | 204 | tempInputs = torch.add(data.data, -magnitude, gradient) 205 | 206 | noise_out_features = model(Variable(tempInputs))[1][layer_index] 207 | noise_out_features = noise_out_features.view(noise_out_features.size(0), noise_out_features.size(1), -1) 208 | noise_out_features = torch.mean(noise_out_features, 2) 209 | 
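        # Descriptive comment: at this point the input has been perturbed
        # against the Mahalanobis loss (tempInputs = x - magnitude *
        # sign(grad)). The loop below recomputes, for each class c, the scaled
        # negative Mahalanobis distance of the perturbed features f(x~):
        #     score_c = -0.5 * (f(x~) - mu_c)^T Sigma^{-1} (f(x~) - mu_c)
        # using the per-layer class means (sample_mean) and the precision
        # matrix shared across classes (precision) estimated by
        # sample_estimator(); the maximum over classes is kept as this
        # layer's confidence score.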
        noise_gaussian_score = 0
210 |         for i in range(num_classes):
211 |             batch_sample_mean = sample_mean[layer_index][i]
212 |             zero_f = noise_out_features.data - batch_sample_mean
213 |             term_gau = -0.5*torch.mm(torch.mm(zero_f, precision[layer_index]), zero_f.t()).diag()
214 |             if i == 0:
215 |                 noise_gaussian_score = term_gau.view(-1,1)
216 |             else:
217 |                 noise_gaussian_score = torch.cat((noise_gaussian_score, term_gau.view(-1,1)), 1)
218 | 
219 |         noise_gaussian_score, _ = torch.max(noise_gaussian_score, dim=1)
220 | 
221 |         noise_gaussian_score = np.asarray(noise_gaussian_score.cpu().numpy(), dtype=np.float32)
222 |         if layer_index == 0:
223 |             Mahalanobis_scores = noise_gaussian_score.reshape((noise_gaussian_score.shape[0], -1))
224 |         else:
225 |             Mahalanobis_scores = np.concatenate((Mahalanobis_scores, noise_gaussian_score.reshape((noise_gaussian_score.shape[0], -1))), axis=1)
226 | 
227 |     for i in range(L):
228 |         TF[i].append(Mahalanobis_scores[:, i])
229 |     return Mahalanobis_scores
230 | 
231 | def cut_transfer(L, threshold, energy, complexity, mean):
232 |     cut_score = []
233 |     for i in range(L):
234 |         index = (threshold[i]=tpr95:
311 |             break
312 |         fpr0=fpr1
313 |         tpr0=tpr1
314 |     fpr95 = ((tpr95-tpr0)*fpr1 + (tpr1-tpr95)*fpr0) / (tpr1-tpr0)
315 |     return fpr95
--------------------------------------------------------------------------------
/models/msdnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import math
4 | 
5 | class ConvBasic(nn.Module):
6 |     def __init__(self, nIn, nOut, kernel=3, stride=1,
7 |                  padding=1):
8 |         super(ConvBasic, self).__init__()
9 |         self.net = nn.Sequential(
10 |             nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride,
11 |                       padding=padding, bias=False),
12 |             nn.BatchNorm2d(nOut),
13 |             nn.ReLU(True)
14 |         )
15 | 
16 |     def forward(self, x):
17 |         return self.net(x)
18 | 
19 | 
20 | class ConvBN(nn.Module):
21 |     def __init__(self, nIn, nOut, type: str, bottleneck,
22 |                  bnWidth):
23 |         """
24 |         a basic conv in MSDNet, two types
25 |         :param nIn:
26 |         :param nOut:
27 |         :param type: normal or down
28 |         :param bottleneck: use bottleneck or not
29 |         :param bnWidth: bottleneck factor
30 |         """
31 |         super(ConvBN, self).__init__()
32 |         layer = []
33 |         nInner = nIn
34 |         if bottleneck is True:
35 |             nInner = min(nInner, bnWidth * nOut)
36 |             layer.append(nn.Conv2d(
37 |                 nIn, nInner, kernel_size=1, stride=1, padding=0, bias=False))
38 |             layer.append(nn.BatchNorm2d(nInner))
39 |             layer.append(nn.ReLU(True))
40 | 
41 |         if type == 'normal':
42 |             layer.append(nn.Conv2d(nInner, nOut, kernel_size=3,
43 |                                    stride=1, padding=1, bias=False))
44 |         elif type == 'down':
45 |             layer.append(nn.Conv2d(nInner, nOut, kernel_size=3,
46 |                                    stride=2, padding=1, bias=False))
47 |         else:
48 |             raise ValueError
49 | 
50 |         layer.append(nn.BatchNorm2d(nOut))
51 |         layer.append(nn.ReLU(True))
52 | 
53 |         self.net = nn.Sequential(*layer)
54 | 
55 |     def forward(self, x):
56 |         # print(self.net, x.size())
57 |         # pdb.set_trace()
58 |         return self.net(x)
59 | 
60 | 
61 | class ConvDownNormal(nn.Module):
62 |     def __init__(self, nIn1, nIn2, nOut, bottleneck, bnWidth1, bnWidth2):
63 |         super(ConvDownNormal, self).__init__()
64 |         self.conv_down = ConvBN(nIn1, nOut // 2, 'down',
65 |                                 bottleneck, bnWidth1)
66 |         self.conv_normal = ConvBN(nIn2, nOut // 2, 'normal',
67 |                                   bottleneck, bnWidth2)
68 | 
69 |     def forward(self, x):
70 |         # print(self.conv_down, self.conv_normal, '========')
71 |         # pdb.set_trace()
72 |         res = [x[1],
73 |                self.conv_down(x[0]),
74 |                self.conv_normal(x[1])]
75 | # print(res[0].size(), res[1].size(), res[2].size()) 76 | return torch.cat(res, dim=1) 77 | 78 | 79 | class ConvNormal(nn.Module): 80 | def __init__(self, nIn, nOut, bottleneck, bnWidth): 81 | super(ConvNormal, self).__init__() 82 | self.conv_normal = ConvBN(nIn, nOut, 'normal', 83 | bottleneck, bnWidth) 84 | 85 | def forward(self, x): 86 | if not isinstance(x, list): 87 | x = [x] 88 | res = [x[0], 89 | self.conv_normal(x[0])] 90 | 91 | return torch.cat(res, dim=1) 92 | 93 | 94 | class MSDNFirstLayer(nn.Module): 95 | def __init__(self, nIn, nOut, args): 96 | super(MSDNFirstLayer, self).__init__() 97 | self.layers = nn.ModuleList() 98 | if args.data.startswith('cifar'): 99 | self.layers.append(ConvBasic(nIn, nOut * args.grFactor[0], 100 | kernel=3, stride=1, padding=1)) 101 | elif args.data == 'ImageNet': 102 | conv = nn.Sequential( 103 | nn.Conv2d(nIn, nOut * args.grFactor[0], 7, 2, 3), 104 | nn.BatchNorm2d(nOut * args.grFactor[0]), 105 | nn.ReLU(inplace=True), 106 | nn.MaxPool2d(3, 2, 1)) 107 | self.layers.append(conv) 108 | 109 | nIn = nOut * args.grFactor[0] 110 | 111 | for i in range(1, args.nScales): 112 | self.layers.append(ConvBasic(nIn, nOut * args.grFactor[i], 113 | kernel=3, stride=2, padding=1)) 114 | nIn = nOut * args.grFactor[i] 115 | 116 | def forward(self, x): 117 | res = [] 118 | for i in range(len(self.layers)): 119 | x = self.layers[i](x) 120 | res.append(x) 121 | 122 | return res 123 | 124 | 125 | class MSDNLayer(nn.Module): 126 | def __init__(self, nIn, nOut, args, inScales=None, outScales=None): 127 | super(MSDNLayer, self).__init__() 128 | self.nIn = nIn 129 | self.nOut = nOut 130 | self.inScales = inScales if inScales is not None else args.nScales 131 | self.outScales = outScales if outScales is not None else args.nScales 132 | 133 | self.nScales = args.nScales 134 | self.discard = self.inScales - self.outScales 135 | 136 | self.offset = self.nScales - self.outScales 137 | self.layers = nn.ModuleList() 138 | 139 | if self.discard > 0: 140 | nIn1 = nIn * args.grFactor[self.offset - 1] 141 | nIn2 = nIn * args.grFactor[self.offset] 142 | _nOut = nOut * args.grFactor[self.offset] 143 | self.layers.append(ConvDownNormal(nIn1, nIn2, _nOut, args.bottleneck, 144 | args.bnFactor[self.offset - 1], 145 | args.bnFactor[self.offset])) 146 | else: 147 | self.layers.append(ConvNormal(nIn * args.grFactor[self.offset], 148 | nOut * args.grFactor[self.offset], 149 | args.bottleneck, 150 | args.bnFactor[self.offset])) 151 | 152 | for i in range(self.offset + 1, self.nScales): 153 | nIn1 = nIn * args.grFactor[i - 1] 154 | nIn2 = nIn * args.grFactor[i] 155 | _nOut = nOut * args.grFactor[i] 156 | self.layers.append(ConvDownNormal(nIn1, nIn2, _nOut, args.bottleneck, 157 | args.bnFactor[i - 1], 158 | args.bnFactor[i])) 159 | 160 | def forward(self, x): 161 | if self.discard > 0: 162 | inp = [] 163 | for i in range(1, self.outScales + 1): 164 | inp.append([x[i - 1], x[i]]) 165 | else: 166 | inp = [[x[0]]] 167 | for i in range(1, self.outScales): 168 | inp.append([x[i - 1], x[i]]) 169 | 170 | res = [] 171 | for i in range(self.outScales): 172 | res.append(self.layers[i](inp[i])) 173 | 174 | return res 175 | 176 | 177 | class ParallelModule(nn.Module): 178 | """ 179 | This module is similar to luatorch's Parallel Table 180 | input: N tensor 181 | network: N module 182 | output: N tensor 183 | """ 184 | def __init__(self, parallel_modules): 185 | super(ParallelModule, self).__init__() 186 | self.m = nn.ModuleList(parallel_modules) 187 | 188 | def forward(self, x): 189 | res = [] 190 | for 
i in range(len(x)): 191 | res.append(self.m[i](x[i])) 192 | 193 | return res 194 | 195 | 196 | class ClassifierModule(nn.Module): 197 | def __init__(self, m, channel, num_classes): 198 | super(ClassifierModule, self).__init__() 199 | self.m = m 200 | self.linear = nn.Linear(channel, num_classes) 201 | 202 | def forward(self, x): 203 | res = self.m(x[-1]) 204 | res = res.view(res.size(0), -1) 205 | return self.linear(res), res 206 | 207 | class MSDNet(nn.Module): 208 | def __init__(self, args): 209 | super(MSDNet, self).__init__() 210 | self.blocks = nn.ModuleList() 211 | self.classifier = nn.ModuleList() 212 | self.nBlocks = args.nBlocks 213 | self.steps = [args.base] 214 | self.args = args 215 | # todo: how many block? 216 | n_layers_all, n_layer_curr = args.base, 0 217 | for i in range(1, self.nBlocks): 218 | self.steps.append(args.step if args.stepmode == 'even' 219 | else args.step * i + 1) 220 | n_layers_all += self.steps[-1] 221 | 222 | print("building network of steps: ") 223 | print(self.steps, n_layers_all) 224 | 225 | nIn = args.nChannels 226 | for i in range(self.nBlocks): 227 | print(' ********************** Block {} ' 228 | ' **********************'.format(i + 1)) 229 | m, nIn = \ 230 | self._build_block(nIn, args, self.steps[i], 231 | n_layers_all, n_layer_curr) 232 | self.blocks.append(m) 233 | n_layer_curr += self.steps[i] 234 | 235 | if args.data.startswith('cifar100'): 236 | self.classifier.append( 237 | self._build_classifier_cifar(nIn * args.grFactor[-1], 100)) 238 | elif args.data.startswith('cifar10'): 239 | self.classifier.append( 240 | self._build_classifier_cifar(nIn * args.grFactor[-1], 10)) 241 | elif args.data == 'ImageNet': 242 | self.classifier.append( 243 | self._build_classifier_imagenet(nIn * args.grFactor[-1], 1000)) 244 | else: 245 | raise NotImplementedError 246 | 247 | # adding initialization functions 248 | for m in self.blocks: 249 | if hasattr(m, '__iter__'): 250 | for _m in m: 251 | self._init_weights(_m) 252 | else: 253 | self._init_weights(m) 254 | 255 | for m in self.classifier: 256 | if hasattr(m, '__iter__'): 257 | for _m in m: 258 | self._init_weights(_m) 259 | else: 260 | self._init_weights(m) 261 | 262 | def _init_weights(self, m): 263 | if isinstance(m, nn.Conv2d): 264 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 265 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 266 | elif isinstance(m, nn.BatchNorm2d): 267 | m.weight.data.fill_(1) 268 | m.bias.data.zero_() 269 | elif isinstance(m, nn.Linear): 270 | m.bias.data.zero_() 271 | 272 | def _build_block(self, nIn, args, step, n_layer_all, n_layer_curr): 273 | 274 | layers = [MSDNFirstLayer(3, nIn, args)] \ 275 | if n_layer_curr == 0 else [] 276 | for i in range(step): 277 | n_layer_curr += 1 278 | inScales = args.nScales 279 | outScales = args.nScales 280 | if args.prune == 'min': 281 | inScales = min(args.nScales, n_layer_all - n_layer_curr + 2) 282 | outScales = min(args.nScales, n_layer_all - n_layer_curr + 1) 283 | elif args.prune == 'max': 284 | interval = math.ceil(1.0 * n_layer_all / args.nScales) 285 | inScales = args.nScales - math.floor(1.0 * (max(0, n_layer_curr - 2)) / interval) 286 | outScales = args.nScales - math.floor(1.0 * (n_layer_curr - 1) / interval) 287 | # print(i, interval, inScales, outScales, n_layer_curr, n_layer_all) 288 | else: 289 | raise ValueError 290 | # print('|\t\tinScales {} outScales {}\t\t\t|'.format(inScales, outScales)) 291 | 292 | layers.append(MSDNLayer(nIn, args.growthRate, args, inScales, outScales)) 293 | print('|\t\tinScales {} outScales {} inChannels {} outChannels {}\t\t|'.format(inScales, outScales, nIn, args.growthRate)) 294 | 295 | nIn += args.growthRate 296 | if args.prune == 'max' and inScales > outScales and \ 297 | args.reduction > 0: 298 | offset = args.nScales - outScales 299 | layers.append( 300 | self._build_transition(nIn, math.floor(1.0 * args.reduction * nIn), 301 | outScales, offset, args)) 302 | _t = nIn 303 | nIn = math.floor(1.0 * args.reduction * nIn) 304 | print('|\t\tTransition layer inserted! (max), inChannels {}, outChannels {}\t|'.format(_t, math.floor(1.0 * args.reduction * _t))) 305 | elif args.prune == 'min' and args.reduction > 0 and \ 306 | ((n_layer_curr == math.floor(1.0 * n_layer_all / 3)) or 307 | n_layer_curr == math.floor(2.0 * n_layer_all / 3)): 308 | offset = args.nScales - outScales 309 | layers.append(self._build_transition(nIn, math.floor(1.0 * args.reduction * nIn), 310 | outScales, offset, args)) 311 | 312 | nIn = math.floor(1.0 * args.reduction * nIn) 313 | print('|\t\tTransition layer inserted! (min)\t|') 314 | print("") 315 | # print('|\t\tinScales {} outScales {} inChannels {} outChannels {}\t\t\t|'.format(inScales, outScales, in_channel, nIn)) 316 | 317 | return nn.Sequential(*layers), nIn 318 | 319 | def _build_transition(self, nIn, nOut, outScales, offset, args): 320 | net = [] 321 | for i in range(outScales): 322 | net.append(ConvBasic(nIn * args.grFactor[offset + i], 323 | nOut * args.grFactor[offset + i], 324 | kernel=1, stride=1, padding=0)) 325 | return ParallelModule(net) 326 | 327 | def _build_classifier_cifar(self, nIn, num_classes): 328 | interChannels1, interChannels2 = 128, 128 329 | conv = nn.Sequential( 330 | ConvBasic(nIn, interChannels1, kernel=3, stride=2, padding=1), 331 | ConvBasic(interChannels1, interChannels2, kernel=3, stride=2, padding=1), 332 | nn.AvgPool2d(2), 333 | ) 334 | return ClassifierModule(conv, interChannels2, num_classes) 335 | 336 | def _build_classifier_imagenet(self, nIn, num_classes): 337 | conv = nn.Sequential( 338 | ConvBasic(nIn, nIn, kernel=3, stride=2, padding=1), 339 | ConvBasic(nIn, nIn, kernel=3, stride=2, padding=1), 340 | nn.AvgPool2d(2) 341 | ) 342 | return ClassifierModule(conv, nIn, num_classes) 343 | 344 | def forward(self, x): 345 | res = [] 346 | feat = [] 347 | for i in range(self.nBlocks): 348 | # print('!!!!! 
The {}-th block !!!!!'.format(i)) 349 | x = self.blocks[i](x) 350 | pred, t = self.classifier[i](x) 351 | res.append(pred) 352 | feat.append(t) 353 | # res.append(self.classifier[i](x)) 354 | return res, feat 355 | # return res 356 | 357 | class WrappedModel(nn.Module): 358 | def __init__(self, module): 359 | super(WrappedModel, self).__init__() 360 | self.module = module 361 | 362 | def forward(self, x): 363 | return self.module(x) 364 | 365 | def msdnet(args): 366 | 367 | model = MSDNet(args) 368 | if args.pretrained is not None: 369 | print('!!!!!! Load pretrained model !!!!!!') 370 | model = WrappedModel(model) 371 | checkpoint = torch.load(args.pretrained) 372 | model.load_state_dict(checkpoint['state_dict']) 373 | return model 374 | 375 | --------------------------------------------------------------------------------
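As a minimal usage sketch (not part of the repository), the per-exit scoring helpers from utils/MOOD.py can be driven by the MSDNet defined in models/msdnet.py roughly as follows. It assumes a model built via msdnet(args) with L = args.nBlocks exits, a standard PyTorch DataLoader, and CUDA availability; the function name score_loader, the import path, and the loader itself are illustrative placeholders rather than code from the repository.

import numpy as np
import torch
from utils.MOOD import msp_score, energy_score   # helpers shown above; import path assumed

@torch.no_grad()
def score_loader(model, loader, L, score='energy'):
    # Collect one OOD score per exit for every batch in `loader`.
    # `TF` mirrors the per-exit buffer layout expected by msp_score / energy_score.
    model.eval()
    TF = [[] for _ in range(L)]
    for x, _ in loader:
        x = x.cuda()
        logits_per_exit, _ = model(x)   # MSDNet.forward returns (predictions, features), one entry per block
        if score == 'energy':
            energy_score(logits_per_exit, TF, L)
        else:
            msp_score(logits_per_exit, TF, L)
    # One score array per exit, concatenated over all batches.
    return [np.concatenate(per_exit) for per_exit in TF]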