├── classification
│   ├── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── wrn.cpython-38.pyc
│   │   │   ├── wrn.cpython-39.pyc
│   │   │   ├── route.cpython-38.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── densenet.cpython-38.pyc
│   │   │   └── resnet.cpython-38.pyc
│   │   ├── wrn.py
│   │   └── resnet.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── lsun_loader.cpython-38.pyc
│   │   │   ├── lsun_loader.cpython-39.pyc
│   │   │   ├── svhn_loader.cpython-38.pyc
│   │   │   ├── svhn_loader.cpython-39.pyc
│   │   │   ├── ImbalanceCIFAR.cpython-38.pyc
│   │   │   ├── ImbalanceCIFAR.cpython-39.pyc
│   │   │   ├── display_results.cpython-38.pyc
│   │   │   ├── display_results.cpython-39.pyc
│   │   │   ├── tinyimages_300k.cpython-38.pyc
│   │   │   ├── tinyimages_300k.cpython-39.pyc
│   │   │   ├── score_calculation.cpython-38.pyc
│   │   │   ├── score_calculation.cpython-39.pyc
│   │   │   ├── validation_dataset.cpython-38.pyc
│   │   │   ├── validation_dataset.cpython-39.pyc
│   │   │   ├── SCOODBenchmarkDataset.cpython-38.pyc
│   │   │   ├── SCOODBenchmarkDataset.cpython-39.pyc
│   │   │   └── tinyimages_80mn_loader.cpython-38.pyc
│   │   ├── SCOODBenchmarkDataset.py
│   │   ├── tinyimages_300k.py
│   │   ├── validation_dataset.py
│   │   ├── cifar_resnet.py
│   │   ├── tiny_resnet.py
│   │   ├── lsun_loader.py
│   │   ├── calibration_tools.py
│   │   ├── svhn_loader.py
│   │   ├── display_results.py
│   │   ├── score_calculation.py
│   │   └── ImbalanceCIFAR.py
│   ├── inf_run_res.sh
│   ├── inf_run_im_res.sh
│   ├── inf_run_im_wide.sh
│   └── LICENSE
├── segmentation
│   ├── __init__.py
│   ├── code
│   │   ├── __init__.py
│   │   ├── config
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── config.cpython-38.pyc
│   │   │   │   ├── config.cpython-39.pyc
│   │   │   │   ├── __init__.cpython-36.pyc
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── __init__.cpython-39.pyc
│   │   │   │   ├── config_19inf.cpython-38.pyc
│   │   │   │   ├── config_19inf.cpython-39.pyc
│   │   │   │   ├── config_load.cpython-38.pyc
│   │   │   │   ├── config_paper.cpython-38.pyc
│   │   │   │   ├── config_19class.cpython-36.pyc
│   │   │   │   ├── config_19class.cpython-38.pyc
│   │   │   │   ├── config_19class.cpython-39.pyc
│   │   │   │   ├── config_main_OE.cpython-38.pyc
│   │   │   │   ├── config_main_add.cpython-38.pyc
│   │   │   │   ├── config_19inf_init.cpython-38.pyc
│   │   │   │   ├── config_main_energy.cpython-38.pyc
│   │   │   │   ├── config_pebal_influence.cpython-38.pyc
│   │   │   │   ├── config_main_add_energy_reg.cpython-38.pyc
│   │   │   │   ├── config_main_add_energy_reg.cpython-39.pyc
│   │   │   │   ├── config_pebal_influence_load.cpython-38.pyc
│   │   │   │   ├── config_main_add_energy_reg_pure.cpython-38.pyc
│   │   │   │   ├── config_main_energy_pure_weighted.cpython-38.pyc
│   │   │   │   ├── config_main_energy_pure_influence.cpython-38.pyc
│   │   │   │   ├── config_main_energy_pure_influence.cpython-39.pyc
│   │   │   │   ├── config_pebal_influence_load_newmix.cpython-38.pyc
│   │   │   │   ├── config_pebal_influence_load_newmix.cpython-39.pyc
│   │   │   │   ├── config_main_energy_pure_weighted_norm.cpython-38.pyc
│   │   │   │   ├── config_main_energy_pure_influence_load.cpython-38.pyc
│   │   │   │   ├── config_main_energy_pure_influence_newmix.cpython-38.pyc
│   │   │   │   ├── config_main_energy_pure_weighted_gamma2.cpython-38.pyc
│   │   │   │   ├── config_main_energy_pure_weighted_outonly.cpython-38.pyc
│   │   │   │   ├── config_main_energy_pure_influence_newmix_load.cpython-38.pyc
│   │   │   │   ├── config_main_energy_pure_influence_newmix_load.cpython-39.pyc
│   │   │   │   ├── config_main_energy_pure_weighted_gamma2_infocal.cpython-38.pyc
│   │   │   │   ├── config_main_energy_pure_weighted_gamma2_infocal.cpython-39.pyc
│   │   │   │   └── config_main_energy_pure_weighted_gamma2_infocal_both.cpython-38.pyc
│   │   │   └── config.py
│   │   ├── dataset
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── data_loader.cpython-38.pyc
│   │   │   │   └── base_dataset.cpython-38.pyc
│   │   │   └── base_dataset.py
│   │   ├── engine
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── engine.cpython-38.pyc
│   │   │   │   ├── trainer.cpython-38.pyc
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── evaluator.cpython-38.pyc
│   │   │   │   └── lr_policy.cpython-38.pyc
│   │   │   ├── lr_policy.py
│   │   │   ├── trainer.py
│   │   │   ├── evaluator.py
│   │   │   └── engine.py
│   │   ├── model
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── mynn.cpython-38.pyc
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── network.cpython-38.pyc
│   │   │   │   ├── wide_network.cpython-38.pyc
│   │   │   │   └── wide_resnet_base.cpython-38.pyc
│   │   │   ├── network.py
│   │   │   ├── mynn.py
│   │   │   ├── wide_network.py
│   │   │   └── resnet.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── logger.cpython-38.pyc
│   │   │   │   ├── metric.cpython-38.pyc
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── img_utils.cpython-38.pyc
│   │   │   │   ├── pyt_utils.cpython-38.pyc
│   │   │   │   └── wandb_upload.cpython-38.pyc
│   │   │   ├── logger.py
│   │   │   ├── metric.py
│   │   │   └── img_utils.py
│   │   ├── __pycache__
│   │   │   ├── losses.cpython-38.pyc
│   │   │   ├── valid.cpython-38.pyc
│   │   │   └── __init__.cpython-38.pyc
│   │   ├── valid.py
│   │   ├── test.py
│   │   ├── main.py
│   │   └── losses.py
│   └── preparation
│       ├── __init__.py
│       └── prepare_coco_segmentation.py
├── LICENSE
├── balanced.yml
└── README.md

/classification/__init__.py:
/segmentation/__init__.py:
/classification/models/__init__.py:
/classification/utils/__init__.py:
/segmentation/code/__init__.py:
/segmentation/code/config/__init__.py:
/segmentation/code/dataset/__init__.py:
/segmentation/code/engine/__init__.py:
/segmentation/code/model/__init__.py:
/segmentation/code/utils/__init__.py:
/segmentation/preparation/__init__.py:
--------------------------------------------------------------------------------
(all of the above package __init__.py files are empty)
--------------------------------------------------------------------------------
/segmentation/code/model/network.py:
--------------------------------------------------------------------------------
import torch.nn as nn

from model.wide_network import DeepWV3Plus


class Network(nn.Module):
    def __init__(self, num_classes, wide=False):
        super(Network, self).__init__()
        # if wide:
        self.branch1 = DeepWV3Plus(num_classes)

    def forward(self, data, output_anomaly=False):
        return self.branch1(data, output_anomaly=output_anomaly)
--------------------------------------------------------------------------------
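A quick smoke test for the wrapper above (not part of the repository; it assumes you run from segmentation/code/ so the model package resolves, that DeepWV3Plus accepts an `output_anomaly` flag in its forward as network.py implies, and that `num_classes` follows config.py's 19 + 1 layout). Note that the `wide` argument is currently unused: branch1 is always a DeepWV3Plus.

import torch
from model.network import Network

net = Network(num_classes=20).eval()
with torch.no_grad():
    x = torch.randn(1, 3, 512, 512)        # dummy RGB input
    logits = net(x)                        # per-pixel class logits
    anomaly = net(x, output_anomaly=True)  # per-pixel anomaly scores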
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 hyunjunChhoi

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/classification/utils/SCOODBenchmarkDataset.py:
--------------------------------------------------------------------------------
## adopted from https://github.com/amazon-science/long-tailed-ood-detection ##
import os, ast
import numpy as np
from PIL import Image
import torch
from torch.utils.data import DataLoader, Subset
from torchvision import transforms


class SCOODDataset(torch.utils.data.Dataset):

    def __init__(self, root, id_name, ood_name, transform):
        super(SCOODDataset, self).__init__()

        assert id_name in ['cifar10', 'cifar100']

        imglist_path = os.path.join(root, 'data/imglist/benchmark_%s' % id_name, 'test_%s.txt' % ood_name)

        with open(imglist_path) as fp:
            self.imglist = fp.readlines()

        self.transform = transform
        self.root = root

        print("SCOODDataset (id %s, ood %s) contains %d images" % (id_name, ood_name, len(self.imglist)))

    def __len__(self):
        return len(self.imglist)

    def __getitem__(self, index):
        # parse the string in the imglist file:
        line = self.imglist[index].strip("\n")
        tokens = line.split(" ", 1)
        image_name, extra_str = tokens[0], tokens[1]
        extras = ast.literal_eval(extra_str)
        sc_label = extras['sc_label']  # the ood label is here. -1 means ood.

        # read image according to image name:
        img_path = os.path.join(self.root, 'data', 'images', image_name)
        with open(img_path, 'rb') as f:
            img = Image.open(f).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)

        return img, sc_label
--------------------------------------------------------------------------------
/classification/utils/tinyimages_300k.py:
--------------------------------------------------------------------------------
## adopted from https://github.com/amazon-science/long-tailed-ood-detection ##
import os
import numpy as np
from PIL import Image
import torch
from torch.utils.data import DataLoader, Subset
from torchvision import transforms


class TinyImages(torch.utils.data.Dataset):

    def __init__(self, root, transform):
        super(TinyImages, self).__init__()

        self.data = np.load(os.path.join(root, 'tinyimages80m', '300K_random_images.npy'))
        self.transform = transform

        print("TinyImages contains {} images".format(len(self.data)))

    def __getitem__(self, index):
        img = self.data[index]
        img = Image.fromarray(img)
        if self.transform is not None:
            img = self.transform(img)

        return img, -1  # -1 is the class

    def __len__(self):
        return len(self.data)


def tinyimages300k_dataloaders(num_samples=300000, train_batch_size=64, num_workers=8, data_root_path='/home/datasets'):
    num_samples = int(num_samples)

    data_dir = os.path.join(data_root_path)

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    train_set = Subset(TinyImages(data_dir, transform=train_transform), list(range(num_samples)))

    train_loader = DataLoader(train_set, batch_size=train_batch_size, shuffle=True, num_workers=num_workers,
                              drop_last=True, pin_memory=True)

    return train_loader
--------------------------------------------------------------------------------
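How the two loaders above are typically wired up (a sketch, not from the repository; it assumes you run from classification/ so the utils package resolves, and the root paths and the 'texture' OOD split name are placeholders that must match your local data layout):

from torchvision import transforms
from torch.utils.data import DataLoader

from utils.SCOODBenchmarkDataset import SCOODDataset
from utils.tinyimages_300k import tinyimages300k_dataloaders

test_transform = transforms.Compose([transforms.ToTensor()])

# SC-OOD benchmark: __getitem__ returns (img, sc_label), where sc_label == -1 marks OOD
scood_set = SCOODDataset(root='/path/to/scood', id_name='cifar10',
                         ood_name='texture', transform=test_transform)
scood_loader = DataLoader(scood_set, batch_size=128, shuffle=False)

# 300K random TinyImages as the auxiliary outlier training set (every label is -1)
ood_loader = tinyimages300k_dataloaders(train_batch_size=64,
                                        data_root_path='/path/to/datasets')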
/segmentation/code/engine/lr_policy.py:
--------------------------------------------------------------------------------
import numpy
from abc import ABCMeta, abstractmethod


class BaseLR():
    __metaclass__ = ABCMeta

    @abstractmethod
    def get_lr(self, cur_iter): pass


class PolyLR(BaseLR):
    def __init__(self, start_lr, lr_power, total_iters):
        self.start_lr = start_lr
        self.lr_power = lr_power
        self.total_iters = total_iters + 0.0

    def get_lr(self, cur_iter):
        return self.start_lr * (
                (1 - float(cur_iter) / self.total_iters) ** self.lr_power)


class WarmUpPolyLR(BaseLR):
    def __init__(self, start_lr, lr_power, total_iters, warmup_steps, end_lr=1e-8):
        self.start_lr = start_lr
        self.lr_power = lr_power
        self.total_iters = total_iters + 0.0
        self.warmup_steps = warmup_steps
        self.end_lr = end_lr

    def get_lr(self, cur_iter):
        if cur_iter < self.warmup_steps:
            return self.start_lr * (cur_iter / self.warmup_steps)
        else:
            curr_lr = self.start_lr * ((1 - float(cur_iter) / self.total_iters) ** self.lr_power)
            return numpy.real(numpy.clip(curr_lr, a_min=self.end_lr, a_max=self.start_lr))


class MultiStageLR(BaseLR):
    def __init__(self, lr_stages):
        assert type(lr_stages) in [list, tuple] and len(lr_stages[0]) == 2, \
            'lr_stages must be list or tuple, with [iters, lr] format'
        self._lr_stages = lr_stages

    def get_lr(self, epoch):
        for it_lr in self._lr_stages:
            if epoch < it_lr[0]:
                return it_lr[1]


class LinearIncreaseLR(BaseLR):
    def __init__(self, start_lr, end_lr, warm_iters):
        self._start_lr = start_lr
        self._end_lr = end_lr
        self._warm_iters = warm_iters
        self._delta_lr = (end_lr - start_lr) / warm_iters

    def get_lr(self, cur_epoch):
        return self._start_lr + cur_epoch * self._delta_lr
--------------------------------------------------------------------------------
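With the defaults from config.py later in this dump (lr = 1e-5, lr_power = 0.9, 20 epochs of 2975 // 8 = 371 iterations, no warm-up), WarmUpPolyLR decays the rate polynomially from start_lr towards end_lr — a minimal sketch:

sched = WarmUpPolyLR(start_lr=1e-5, lr_power=0.9, total_iters=20 * 371, warmup_steps=0)
for it in (0, 1855, 3710, 7419):
    print(it, sched.get_lr(it))  # 1e-05, then monotonically decaying, clipped at end_lr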
/segmentation/code/utils/logger.py:
--------------------------------------------------------------------------------
import json
import logging


class Logger:
    """
    Training process logger

    Note:
        Used by BaseTrainer to save training history.
    """

    def __init__(self):
        self.entries = {}

    def add_entry(self, entry):
        self.entries[len(self.entries) + 1] = entry

    def __str__(self):
        return json.dumps(self.entries, sort_keys=True, indent=4)


BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8)

# The background is set with 40 plus the number of the color, and the foreground with 30
# These are the sequences needed to get colored output
RESET_SEQ = "\033[0m"
COLOR_SEQ = "\033[1;%dm"
BOLD_SEQ = "\033[1m"


def formatter_message(message, use_color=True):
    if use_color:
        message = message.replace("$RESET", RESET_SEQ).replace("$BOLD", BOLD_SEQ)
    else:
        message = message.replace("$RESET", "").replace("$BOLD", "")
    return message


COLORS = {
    'WARNING': BLUE,
    'INFO': WHITE,
    'DEBUG': GREEN,
    'CRITICAL': YELLOW,
    'ERROR': RED
}


class ColoredFormatter(logging.Formatter):
    def __init__(self, msg, use_color=True):
        logging.Formatter.__init__(self, msg)
        self.use_color = use_color

    def format(self, record):
        levelname = record.levelname
        if self.use_color and levelname in COLORS:
            levelname_color = COLOR_SEQ % (30 + COLORS[levelname]) + levelname + RESET_SEQ
            record.levelname = levelname_color
        return logging.Formatter.format(self, record)


# FORMAT = "[$BOLD%(name)-20s$RESET][%(levelname)-18s] %(message)s ($BOLD%(filename)s$RESET:%(lineno)d)"

class ColoredLogger(logging.Logger):
    FORMAT = "[$BOLD%(name)s$RESET][%(levelname)s] %(message)-s "
    COLOR_FORMAT = formatter_message(FORMAT, True)

    def __init__(self, name):
        logging.Logger.__init__(self, name, logging.INFO)
        color_formatter = ColoredFormatter(self.COLOR_FORMAT)
        console = logging.StreamHandler()
        console.setFormatter(color_formatter)
        self.addHandler(console)
        return


logging.setLoggerClass(ColoredLogger)
--------------------------------------------------------------------------------
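Because logging.setLoggerClass(ColoredLogger) runs at import time, any logger created after this module is imported picks up the colored stream handler — a usage sketch (assumes the module is importable as utils.logger, matching the import style used elsewhere in code/):

import logging
import utils.logger  # side effect: installs ColoredLogger as the default logger class

log = logging.getLogger("trainer")
log.info("validating cityscapes dataset ...")  # rendered white
log.critical("current mIoU is 0.81")           # rendered yellow, per COLORS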
/segmentation/code/utils/metric.py:
--------------------------------------------------------------------------------
# encoding: utf-8

import numpy as np

np.seterr(divide='ignore', invalid='ignore')


# voc cityscapes metric
def hist_info(n_cl, pred, gt):
    assert (pred.shape == gt.shape)
    k = (gt >= 0) & (gt < n_cl)
    labeled = np.sum(k)
    correct = np.sum((pred[k] == gt[k]))

    return np.bincount(n_cl * gt[k].astype(int) + pred[k].astype(int),
                       minlength=n_cl ** 2).reshape(n_cl, n_cl), labeled, correct


def compute_score(hist, correct, labeled):
    iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
    mean_IU = np.nanmean(iu)
    mean_IU_no_back = np.nanmean(iu[1:])
    freq = hist.sum(1) / hist.sum()
    freq_IU = (iu[freq > 0] * freq[freq > 0]).sum()
    mean_pixel_acc = correct / labeled
    return iu, mean_IU, mean_IU_no_back, mean_pixel_acc


# ade metric
def meanIoU(area_intersection, area_union):
    iou = 1.0 * np.sum(area_intersection, axis=1) / np.sum(area_union, axis=1)
    meaniou = np.nanmean(iou)
    meaniou_no_back = np.nanmean(iou[1:])

    return iou, meaniou, meaniou_no_back


def intersectionAndUnion(imPred, imLab, numClass):
    # Remove classes from unlabeled pixels in gt image.
    # We should not penalize detections in unlabeled portions of the image.
    imPred = imPred * (imLab >= 0)

    # Compute area intersection:
    intersection = imPred * (imPred == imLab)
    (area_intersection, _) = np.histogram(intersection, bins=numClass,
                                          range=(1, numClass))

    # Compute area union:
    (area_pred, _) = np.histogram(imPred, bins=numClass, range=(1, numClass))
    (area_lab, _) = np.histogram(imLab, bins=numClass, range=(1, numClass))
    area_union = area_pred + area_lab - area_intersection

    return area_intersection, area_union


def mean_pixel_accuracy(pixel_correct, pixel_labeled):
    mean_pixel_accuracy = 1.0 * np.sum(pixel_correct) / (
            np.spacing(1) + np.sum(pixel_labeled))

    return mean_pixel_accuracy


def pixelAccuracy(imPred, imLab):
    # Remove classes from unlabeled pixels in gt image.
    # We should not penalize detections in unlabeled portions of the image.
    pixel_labeled = np.sum(imLab >= 0)
    pixel_correct = np.sum((imPred == imLab) * (imLab >= 0))
    pixel_accuracy = 1.0 * pixel_correct / pixel_labeled

    return pixel_accuracy, pixel_correct, pixel_labeled
--------------------------------------------------------------------------------
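A toy check of the hist_info/compute_score pipeline above on a 3-class problem (a sketch, not from the repository):

import numpy as np

pred = np.array([0, 1, 1, 2])
gt = np.array([0, 1, 2, 2])

hist, labeled, correct = hist_info(3, pred, gt)  # 3x3 confusion matrix, labeled=4, correct=3
iu, mean_IU, mean_IU_no_back, mean_pixel_acc = compute_score(hist, correct, labeled)
print(mean_IU, mean_pixel_acc)  # ~0.667, 0.75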
/segmentation/code/config/config.py:
--------------------------------------------------------------------------------
import os
import numpy
from easydict import EasyDict

C = EasyDict()
config = C
cfg = C

C.seed = 666

"""Root Directory Config"""
C.repo_name = 'segmentation'
C.root_dir = os.path.realpath(".")

"""Data Dir and Weight Dir"""
C.city_root_path = '/datastorage/dataset/city_scape'  # path/to/your/city_scape
C.coco_root_path = '/datastorage/dataset/coco'  # path/to/your/coco
C.fishy_root_path = '/datastorage/dataset/fishyscapes'  # path/to/your/fishy
C.road_root_path = '/datastorage/dataset/road_anomaly'  # path/to/your/road

C.pebal_weight_path = os.path.join(C.root_dir, 'ckpts', 'pebal_balanced', '6707_best_real.pth')
C.pretrained_weight_path = os.path.join(C.root_dir, 'ckpts', 'pretrained_ckpts', 'cityscapes_best.pth')

"""Network Config"""
C.fix_bias = True
C.bn_eps = 1e-5
C.bn_momentum = 0.1

"""Image Config"""
C.num_classes = 19 + 1  # NOTE: 1 more channel for gambler loss
C.image_mean = numpy.array([0.485, 0.456, 0.406])
C.image_std = numpy.array([0.229, 0.224, 0.225])

C.image_height = 900
C.image_width = 900

C.num_train_imgs = 2975
C.num_eval_imgs = 500

"""Train Config"""
C.lr = 1e-5
C.batch_size = 8
C.lr_power = 0.9
C.momentum = 0.9
C.weight_decay = 0

C.nepochs = 20
C.niters_per_epoch = C.num_train_imgs // C.batch_size
C.num_workers = 8
C.train_scale_array = [0.5, 0.75, 1, 1.5, 1.75, 2.0]
C.void_number = 5
C.warm_up_epoch = 0

"""Eval Config"""
C.eval_epoch = 1
C.eval_stride_rate = 2 / 3
C.eval_scale_array = [1, ]  # 0.5, 0.75, 1, 1.5, 1.75
C.eval_flip = False
C.eval_base_size = 800
C.eval_crop_size = 800

"""Display Config"""
C.record_info_iter = 20
C.display_iter = 50

"""Wandb Config"""
# Specify your wandb API key and paste it here
C.wandb_key = ""

# Your project [work_space] name
C.proj_name = "OoD_Segmentation"

# Your current experiment name
C.gamma1 = 3.00
C.gamma2 = 3.00
C.alpha = 5

C.experiment_name = "pebal_balanced_alpha_" + str(C.alpha) + "_gamma_" + str(C.gamma1) + str(C.gamma2)

# half pretrained_ckpts-loader upload images; loss upload every iteration
C.upload_image_step = [0, int((C.num_train_imgs / C.batch_size) / 2)]

# False for debug; True for visualize
C.wandb_online = True

"""Save Config"""
C.saved_dir = os.path.join(C.root_dir, 'ckpts', C.experiment_name)

if not os.path.exists(C.saved_dir):
    os.mkdir(C.saved_dir)
--------------------------------------------------------------------------------
/classification/utils/validation_dataset.py:
--------------------------------------------------------------------------------
import torch
import numpy as np


class PartialDataset(torch.utils.data.Dataset):
    def __init__(self, parent_ds, offset, length):
        self.parent_ds = parent_ds
        self.offset = offset
        self.length = length
        assert len(parent_ds) >= offset + length, Exception("Parent Dataset not long enough")
        super(PartialDataset, self).__init__()

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        return self.parent_ds[i + self.offset]


def validation_split(dataset, val_share=0.1):
    """
    Split a (training and validation combined) dataset into training and validation.
    Note that to be statistically sound, the items in the dataset should be statistically
    independent (e.g. not sorted by class, not several instances of the same dataset that
    could end up in either set).

    inputs:
        dataset: ("training") dataset to split into training and validation
        val_share: fraction of validation data (should be 0 < val_share < 1)
    """
    val_offset = int(len(dataset) * (1 - val_share))
    return PartialDataset(dataset, 0, val_offset), PartialDataset(dataset, val_offset, len(dataset) - val_offset)
--------------------------------------------------------------------------------
/segmentation/preparation/prepare_coco_segmentation.py:
--------------------------------------------------------------------------------
...
        if h >= min_size and w >= min_size:
            ann_Ids = tools.getAnnIds(imgIds=img['id'], iscrowd=None)
            annotations = tools.loadAnns(ann_Ids)

            # Generate binary segmentation mask
            mask = np.ones((h, w), dtype="uint8") * id_in
            for j in range(len(annotations)):
                mask = np.maximum(tools.annToMask(annotations[j]) * id_out, mask)

            # Save segmentation mask
            Image.fromarray(mask).save(os.path.join(save_dir, "{:012d}.png".format(img_Id)))
            num_masks += 1
        print("\rImages Processed: {}/{}".format(i + 1, len(img_Ids)), end=' ')
        sys.stdout.flush()

    # Print summary
    print("\nNumber of created segmentation masks with height and width of at least %d pixels:" % min_size, num_masks)
    end = time.time()
    hours, rem = divmod(end - start, 3600)
    minutes, seconds = divmod(rem, 60)
    print("FINISHED {:0>2}:{:0>2}:{:05.2f}".format(int(hours), int(minutes), seconds))


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
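The surviving mask-building loop above starts from a canvas filled with the in-distribution id and uses np.maximum to promote annotated pixels to the outlier id. A toy illustration (id_in and id_out are hypothetical placeholders here; their real values are defined in the part of the file missing from this dump):

import numpy as np

id_in, id_out = 0, 254
ann_mask = np.array([[0, 1], [1, 0]], dtype="uint8")  # stand-in for tools.annToMask(...)

mask = np.ones((2, 2), dtype="uint8") * id_in
mask = np.maximum(ann_mask * id_out, mask)
print(mask)  # [[0 254] [254 0]]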
/segmentation/code/valid.py:
--------------------------------------------------------------------------------
import warnings

import numpy as np
import torch.optim
from tqdm import tqdm

from utils.metric import hist_info
from utils.pyt_utils import eval_ood_measure

warnings.filterwarnings('ignore', '.*imshow.*', )


def valid_anomaly(model, test_set, data_name=None, epoch=None, my_wandb=None, logger=None,
                  upload_img_num=4):
    curr_info = {}
    model.eval()

    logger.info("validating {} dataset ...".format(data_name))
    tbar = tqdm(range(len(test_set)), ncols=137, leave=True)

    anomaly_score_list = []
    ood_gts_list = []
    focus_area = []

    with torch.no_grad():
        for idx in tbar:
            img, label = test_set[idx]
            anomaly_score = model.module(img, output_anomaly=True)
            anomaly_score = anomaly_score.cpu().numpy()
            ood_gts_list.append(np.expand_dims(label.detach().cpu().numpy(), 0))
            anomaly_score_list.append(np.expand_dims(anomaly_score, 0))
            if len(focus_area) < upload_img_num:
                anomaly_score[(label != test_set.train_id_out) & (label != test_set.train_id_in)] = 0
                focus_area.append(anomaly_score)

    # evaluation
    ood_gts = np.array(ood_gts_list)
    anomaly_scores = np.array(anomaly_score_list)

    roc_auc, prc_auc, fpr = eval_ood_measure(anomaly_scores, ood_gts, test_set.train_id_in, test_set.train_id_out)

    curr_info['{}_auroc'.format(data_name)] = roc_auc
    curr_info['{}_fpr95'.format(data_name)] = fpr
    curr_info['{}_auprc'.format(data_name)] = prc_auc

    logger.critical(f'AUROC score for {data_name}: {roc_auc}')
    logger.critical(f'AUPRC score for {data_name}: {prc_auc}')
    logger.critical(f'FPR@TPR95 for {data_name}: {fpr}')

    if my_wandb is not None:
        my_wandb.upload_wandb_info(current_step=epoch, info_dict=curr_info)
        my_wandb.upload_ood_image(current_step=epoch, energy_map=focus_area, reserve_map=None,
                                  img_number=upload_img_num, data_name=data_name)

    del curr_info
    return roc_auc, prc_auc, fpr


def valid_epoch(model, engine, test_set, my_wandb, evaluator=None, logger=None, transform=None):
    model.eval()
    logger.info("validating cityscapes dataset ...")

    curr_info = {}
    all_results = []
    tbar = tqdm(range(0, len(test_set)), ncols=137, leave=True)

    with torch.no_grad():
        for idx in tbar:
            img, label = test_set[idx]
            img, label = img.permute(1, 2, 0).numpy(), label.numpy()
            pred = evaluator(img, model)
            hist_tmp, labeled_tmp, correct_tmp = hist_info(19, pred, label)
            results_dict = {'hist': hist_tmp, 'labeled': labeled_tmp, 'correct': correct_tmp}
            all_results.append(results_dict)

            if engine.local_rank <= 0:
                tbar.set_description(" labeled: {}, correct: {}".format(str(labeled_tmp), str(correct_tmp)))

    m_iou, m_acc = evaluator.compute_metric(all_results)
    curr_info['m_iou'] = m_iou
    curr_info['m_acc'] = m_acc

    logger.critical("current mIoU is {}, mAcc is {}".format(curr_info['m_iou'], curr_info['m_acc']))

    if my_wandb is not None:
        my_wandb.upload_wandb_info(info_dict=curr_info, current_step=0)

    return
--------------------------------------------------------------------------------
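eval_ood_measure comes from utils/pyt_utils.py, which this dump does not include. A rough stand-in for what it computes (a sketch using scikit-learn, assuming higher scores mean "more anomalous" and a flat label vector marking OOD pixels with 1, in-distribution pixels with 0):

import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score, roc_curve

def ood_metrics(scores, labels):
    auroc = roc_auc_score(labels, scores)
    auprc = average_precision_score(labels, scores)
    fpr, tpr, _ = roc_curve(labels, scores)
    fpr95 = fpr[np.argmax(tpr >= 0.95)]  # FPR at the first threshold reaching 95% TPR
    return auroc, auprc, fpr95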
53 | if torch.any(is_ood): 54 | loss += self.loss1(pred=out_logits, targets=out_target, wrong_sample=True) 55 | 56 | loss += 0.1 * e_loss 57 | 58 | loss.backward() 59 | optimizer.step() 60 | 61 | # update learning rate 62 | current_lr = self.lr_scheduler.get_lr(cur_iter=curr_idx) 63 | for opt_group in optimizer.param_groups: 64 | opt_group['lr'] = current_lr 65 | 66 | curr_info = {} 67 | if self.engine.local_rank <= 0: 68 | curr_info['gambler_loss'] = loss 69 | curr_info['energy_loss'] = e_loss * .1 70 | self.tensorboard.upload_wandb_info(current_step=curr_idx, info_dict=curr_info) 71 | 72 | tbar.set_description("epoch ({}) | " 73 | "gambler_loss: {:.3f} " 74 | "energy_loss: {:.3f} ".format(epoch, curr_info['gambler_loss'], 75 | curr_info['energy_loss'])) 76 | 77 | if self.engine.local_rank <= 0: 78 | self.engine.save_and_link_checkpoint(snapshot_dir=self.saved_dir, name='epoch_{}.pth'.format(epoch)) 79 | 80 | return 81 | 82 | @staticmethod 83 | def freeze_model_parameters(curr_model): 84 | for name, param in curr_model.named_parameters(): 85 | if 'module.branch1.final' not in name: 86 | param.requires_grad = False 87 | else: 88 | param.requires_grad = True 89 | 90 | -------------------------------------------------------------------------------- /classification/inf_run_res.sh: -------------------------------------------------------------------------------- 1 | methods=(pretrained oe_tune) 2 | data_models=(cifar10_res cifar100_res) 3 | gpu=2 4 | 5 | if [ "$1" = "MSP" ]; then 6 | for dm in ${data_models[$2]}; do 7 | for method in ${methods[0]}; do 8 | # MSP with in-distribution samples as pos 9 | echo "-----------"${dm}_${method}" MSP score-----------------" 10 | CUDA_VISIBLE_DEVICES=$gpu python test.py --method_name ${dm}_${method} --num_to_avg 10 11 | done 12 | done 13 | echo "||||||||done with "${dm}_${method}" above |||||||||||||||||||" 14 | elif [ "$1" = "energy" ]; then 15 | for dm in ${data_models[$2]}; do 16 | for method in ${methods[0]}; do 17 | echo "-----------"${dm}_${method}" energy score-----------------" 18 | CUDA_VISIBLE_DEVICES=$gpu python test.py --method_name ${dm}_${method} --num_to_avg 10 --score energy 19 | done 20 | done 21 | echo "||||||||done with "${dm}_${method}" energy score above |||||||||||||||||||" 22 | elif [ "$1" = "M" ]; then 23 | for dm in ${data_models[$2]}; do 24 | for method in ${methods[0]}; do 25 | for noise in 0.0 0.01 0.005 0.002 0.0014 0.001 0.0005; do 26 | echo "-----------"${dm}_${method}_M_noise_${noise}"-----------------" 27 | CUDA_VISIBLE_DEVICES=$gpu python test.py --method_name ${dm}_${method} --num_to_avg 10 --score M --noise $noise -v 28 | done 29 | done 30 | done 31 | echo "||||||||done with "${dm}_${method}_M" noise above|||||||||||||||||||" 32 | elif [ "$1" = "Odin" ]; then 33 | for T in 1000 100 10 1; do 34 | for noise in 0 0.0004 0.0008 0.0014 0.002 0.0024 0.0028 0.0032 0.0038 0.0048; do 35 | echo "-------T="${T}_$2" noise="$noise"--------" 36 | CUDA_VISIBLE_DEVICES=$gpu python test.py --method_name $2 --score Odin --num_to_avg 10 --T $T --noise $noise -v #--test_bs 50 37 | done 38 | echo "||||Odin temperature|||||||||||||||||||||||||||||||||||||||||||" 39 | done 40 | elif [ "$1" = "oe_tune" ] || [ "$1" = "energy_ft" ]; then # fine-tuning 41 | score=OE 42 | if [ "$1" = "energy_ft" ]; then # fine-tuning 43 | score=energy 44 | fi 45 | for dm in ${data_models[$2]}; do 46 | array=(${dm//_/ }) 47 | data=${array[0]} 48 | model=${array[1]} 49 | for seed in 1; do 50 | echo "---Training with dataset: "$data"---model used:"$model"---seed: "$seed"---score used:"$score"---------" 51 | if [ "$2" = "0" ]; then 52 | m_out=-5 53 | m_in=-23 54 | lamb=10 55 | gamma1=0.05 56 | gamma2=0.05 57 | elif [ "$2" = "1" ]; then 58 | m_out=-5 59 | m_in=-27 60 | lamb=100 61 | gamma1=0.005 62 | gamma2=0.005 63 | fi 64 | echo "---------------"$m_in"------"$m_out"--------------------" 65 | CUDA_VISIBLE_DEVICES=$gpu python inf_train_res.py $data --model $model --score $score --seed $seed --m_in $m_in --m_out $m_out --gamma1 $gamma1 --gamma2 $gamma2 --lamb $lamb --trial $3 66 | CUDA_VISIBLE_DEVICES=$gpu python inf_test_SC_res.py --method_name ${dm}_s${seed}_$1 --num_to_avg 1 --score $score --trial $3 --gamma1 $gamma1 --gamma2 $gamma2 67 | done 68 | done 69 | echo "||||||||done with training above "$1"|||||||||||||||||||" 70 | elif [ "$1" = "T" ]; then 71 | for dm in ${data_models[@]}; do 72 | for method in ${methods[0]}; do 73 | for T in 1 2 5 10 20 50 100 200 500 1000; do 74 | echo "-----------"${dm}_${method}_T_${T}"-----------------" 75 | CUDA_VISIBLE_DEVICES=$gpu python test.py --method_name ${dm}_${method} --num_to_avg 10 --score energy --T $T 76 | done 77 | done 78 | echo "||||||||done with "${dm}_${method}_T" temperature above|||||||||||||||||||" 79 | done 80 | fi 81 | 82 | -------------------------------------------------------------------------------- /classification/inf_run_im_res.sh: -------------------------------------------------------------------------------- 1 | methods=(pretrained oe_tune) 2 | data_models=(cifar10_res0.01im cifar100_res0.01im) 3 | gpu=0 4 | 5 | if [ "$1" = "MSP" ]; then 6 | for dm in ${data_models[$2]}; do 7 | for method in ${methods[0]}; do 8 | # MSP with in-distribution samples as pos 9 | echo "-----------"${dm}_${method}" MSP score-----------------" 10 | CUDA_VISIBLE_DEVICES=$gpu python test.py --method_name ${dm}_${method} --num_to_avg 10 11 | done 12 | done 13 | echo "||||||||done with "${dm}_${method}" above |||||||||||||||||||" 14 | elif [ "$1" = "energy" ]; then 15 | for dm in ${data_models[$2]}; do 16 | for method in ${methods[0]}; do 17 | echo "-----------"${dm}_${method}" energy score-----------------" 18 | CUDA_VISIBLE_DEVICES=$gpu python test.py --method_name ${dm}_${method} --num_to_avg 10 --score energy 19 | done 20 | done 21 | echo "||||||||done with "${dm}_${method}" energy score above |||||||||||||||||||" 22 | elif [ "$1" = "M" ]; then 23 | for dm in ${data_models[$2]}; do 24 | for method in ${methods[0]}; do 25 | for noise in 0.0 0.01 0.005 0.002 0.0014 0.001 0.0005; do 26 | echo "-----------"${dm}_${method}_M_noise_${noise}"-----------------" 27 | CUDA_VISIBLE_DEVICES=$gpu python test.py --method_name ${dm}_${method} --num_to_avg 10 --score M --noise $noise -v 28 | done 29 | done 30 | done 31 | echo "||||||||done with "${dm}_${method}_M" noise above|||||||||||||||||||" 32 | elif [ "$1" = "Odin" ]; then 33 | for T in 1000 100 10 1; do 34 | for noise in 0 0.0004 0.0008 0.0014 0.002 0.0024 0.0028 0.0032 0.0038 0.0048; do 35 | echo "-------T="${T}_$2" noise="$noise"--------" 36 | CUDA_VISIBLE_DEVICES=$gpu python test.py --method_name $2 --score Odin --num_to_avg 10 --T $T --noise $noise -v #--test_bs 50 37 | done 38 | echo "||||Odin temperature|||||||||||||||||||||||||||||||||||||||||||" 39 | done 40 | elif [ "$1" = "oe_tune" ] || [ "$1" = "energy_ft" ]; then # fine-tuning 41 | score=OE 42 | if [ "$1" = "energy_ft" ]; then # fine-tuning 43 | score=energy 44 | fi 45 | for dm in ${data_models[$2]}; do 46 | array=(${dm//_/ }) 47 | data=${array[0]} 48 | model=${array[1]} 49 | for seed in 1; do 50 | echo "---Training with dataset: "$data"---model used:"$model"---seed: "$seed"---score used:"$score"---------" 51 | if [ "$2" = "0" ]; then 52 | m_out=-5 53 | m_in=-23 54 | lamb=10 55 | gamma1=0.75 56 | gamma2=0.75 57 | elif [ "$2" = "1" ]; then 58 | m_out=-5 59 | m_in=-27 60 | lamb=100 61 | gamma1=0.75 62 | gamma2=0.75 63 | fi 64 | echo "---------------"$m_in"------"$m_out"--------------------" 65 | CUDA_VISIBLE_DEVICES=$gpu python inf_train_im_res.py $data --model $model --score $score --seed $seed --m_in $m_in --m_out $m_out --gamma1 $gamma1 --gamma2 $gamma2 --lamb $lamb --trial $3 66 | CUDA_VISIBLE_DEVICES=$gpu python inf_test_SC_res.py --method_name ${dm}_s${seed}_$1 --num_to_avg 1 --score $score --trial $3 --gamma1 $gamma1 --gamma2 $gamma2 67 | done 68 | done 69 | echo "||||||||done with training above "$1"|||||||||||||||||||" 70 | elif [ "$1" = "T" ]; then 71 | for dm in ${data_models[@]}; do 72 | for method in ${methods[0]}; do 73 | for T in 1 2 5 10 20 50 100 200 500 1000; do 74 | echo "-----------"${dm}_${method}_T_${T}"-----------------" 75 | CUDA_VISIBLE_DEVICES=$gpu python test.py --method_name ${dm}_${method} --num_to_avg 10 --score energy --T $T 76 | done 77 | done 78 | echo "||||||||done with "${dm}_${method}_T" temperature above|||||||||||||||||||" 79 | done 80 | fi 81 | 82 | -------------------------------------------------------------------------------- /classification/inf_run_im_wide.sh: -------------------------------------------------------------------------------- 1 | methods=(pretrained oe_tune) 2 | data_models=(cifar10_wrn0.01im cifar100_wrn0.01im) 3 | gpu=1 4 | 5 | if [ "$1" = "MSP" ]; then 6 | for dm in ${data_models[$2]}; do 7 | for method in ${methods[0]}; do 8 | # MSP with in-distribution samples as pos 9 | echo "-----------"${dm}_${method}" MSP score-----------------" 10 | CUDA_VISIBLE_DEVICES=$gpu python test.py --method_name ${dm}_${method} --num_to_avg 10 11 | done 12 | done 13 | echo "||||||||done with "${dm}_${method}" above |||||||||||||||||||" 14 | elif [ "$1" = "energy" ]; then 15 | for dm in ${data_models[$2]}; do 16 | for method in ${methods[0]}; do 17 | echo "-----------"${dm}_${method}" energy score-----------------" 18 | CUDA_VISIBLE_DEVICES=$gpu python test.py --method_name ${dm}_${method} --num_to_avg 10 --score energy 19 | done 20 | done 21 | echo "||||||||done with "${dm}_${method}" energy score above |||||||||||||||||||" 22 | elif [ "$1" = "M" ]; then 23 | for dm in ${data_models[$2]}; do 24 | for method in ${methods[0]}; do 25 | for noise in 0.0 0.01 0.005 0.002 0.0014 0.001 0.0005; do 26 | echo "-----------"${dm}_${method}_M_noise_${noise}"-----------------" 27 | CUDA_VISIBLE_DEVICES=$gpu python test.py --method_name ${dm}_${method} --num_to_avg 10 --score M --noise $noise -v 28 | done 29 | done 30 | done 31 | echo "||||||||done with "${dm}_${method}_M" noise above|||||||||||||||||||" 32 | elif [ "$1" = "Odin" ]; then 33 | for T in 1000 100 10 1; do 34 | for noise in 0 0.0004 0.0008 0.0014 0.002 0.0024 0.0028 0.0032 0.0038 0.0048; do 35 | echo "-------T="${T}_$2" noise="$noise"--------" 36 | CUDA_VISIBLE_DEVICES=$gpu python test.py --method_name $2 --score Odin --num_to_avg 10 --T $T --noise $noise -v #--test_bs 50 37 | done 38 | echo "||||Odin temperature|||||||||||||||||||||||||||||||||||||||||||" 39 | done 40 | elif [ "$1" = "oe_tune" ] || [ "$1" = "energy_ft" ]; then # fine-tuning 41 | score=OE 42 | if [ "$1" = "energy_ft" ]; then # fine-tuning 43 | score=energy 44 | fi 45 | for dm in ${data_models[$2]}; do 46 | array=(${dm//_/ }) 47 | data=${array[0]} 48 | model=${array[1]} 49 | for seed in 1; do 50 | echo "---Training with dataset: "$data"---model used:"$model"---seed: "$seed"---score used:"$score"---------" 51 | if [ "$2" = "0" ]; then 52 | m_out=-5 53 | m_in=-23 54 | lamb=10 55 | gamma1=0.50 56 | gamma2=0.50 57 | elif [ "$2" = "1" ]; then 58 | m_out=-5 59 | m_in=-27 60 | lamb=100 61 | gamma1=0.50 62 | gamma2=0.50 63 | fi 64 | echo "---------------"$m_in"------"$m_out"--------------------" 65 | CUDA_VISIBLE_DEVICES=$gpu python inf_train_im_wide.py $data --model $model --score $score --seed $seed --m_in $m_in --m_out $m_out --gamma1 $gamma1 --gamma2 $gamma2 --lamb $lamb --trial $3 66 | CUDA_VISIBLE_DEVICES=$gpu python inf_test_SC_wide.py --method_name ${dm}_s${seed}_$1 --num_to_avg 1 --score $score --trial $3 --gamma1 $gamma1 --gamma2 $gamma2 67 | done 68 | done 69 | echo "||||||||done with training above "$1"|||||||||||||||||||" 70 | elif [ "$1" = "T" ]; then 71 | for dm in ${data_models[@]}; do 72 | for method in ${methods[0]}; do 73 | for T in 1 2 5 10 20 50 100 200 500 1000; do 74 | echo "-----------"${dm}_${method}_T_${T}"-----------------" 75 | CUDA_VISIBLE_DEVICES=$gpu python test.py --method_name ${dm}_${method} --num_to_avg 10 --score energy --T $T 76 | done 77 | done 78 | echo "||||||||done with "${dm}_${method}_T" temperature above|||||||||||||||||||" 79 | done 80 | fi 81 | 82 | 
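The fine-tuning branches above hand the `--m_in`/`--m_out` energy margins and the weight `--lamb` to inf_train_res.py / inf_train_im_res.py / inf_train_im_wide.py, which are not part of this listing. The margin names match the energy-bounded fine-tuning objective of the energy_ood code base that this repo builds on; a minimal sketch of that squared-hinge regularizer is given below (the balanced per-class reweighting from the paper would be applied on top of it and is not shown):

```python
import torch
import torch.nn.functional as F

def energy_margin_loss(logits_in, logits_out, m_in=-23.0, m_out=-5.0):
    """Sketch of the energy-bounded regularizer; E(x) = -logsumexp(f(x))."""
    e_in = -torch.logsumexp(logits_in, dim=1)    # energies of in-distribution samples
    e_out = -torch.logsumexp(logits_out, dim=1)  # energies of auxiliary outliers
    # push in-distribution energies below m_in and outlier energies above m_out
    return (F.relu(e_in - m_in) ** 2).mean() + (F.relu(m_out - e_out) ** 2).mean()
```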
-------------------------------------------------------------------------------- /classification/utils/cifar_resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class BasicBlock(nn.Module): 8 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 9 | super(BasicBlock, self).__init__() 10 | self.bn1 = nn.BatchNorm2d(in_planes) 11 | self.relu1 = nn.ReLU(inplace=True) 12 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(out_planes) 15 | self.relu2 = nn.ReLU(inplace=True) 16 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 17 | padding=1, bias=False) 18 | self.droprate = dropRate 19 | self.equalInOut = (in_planes == out_planes) 20 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 21 | padding=0, bias=False) or None 22 | 23 | def forward(self, x): 24 | if not self.equalInOut: 25 | x = self.relu1(self.bn1(x)) 26 | else: 27 | out = self.relu1(self.bn1(x)) 28 | if self.equalInOut: 29 | out = self.relu2(self.bn2(self.conv1(out))) 30 | else: 31 | out = self.relu2(self.bn2(self.conv1(x))) 32 | if self.droprate > 0: 33 | out = F.dropout(out, p=self.droprate, training=self.training) 34 | out = self.conv2(out) 35 | if not self.equalInOut: 36 | return torch.add(self.convShortcut(x), out) 37 | else: 38 | return torch.add(x, out) 39 | 40 | 41 | class NetworkBlock(nn.Module): 42 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 43 | super(NetworkBlock, self).__init__() 44 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 45 | 46 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 47 | layers = [] 48 | for i in range(nb_layers): 49 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 50 | return nn.Sequential(*layers) 51 | 52 | def forward(self, x): 53 | return self.layer(x) 54 | 55 | 56 | class WideResNet(nn.Module): 57 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0): 58 | super(WideResNet, self).__init__() 59 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 60 | assert ((depth - 4) % 6 == 0) 61 | n = (depth - 4) // 6 62 | block = BasicBlock 63 | # 1st conv before any network block 64 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 65 | padding=1, bias=False) 66 | # 1st block 67 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 68 | # 2nd block 69 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 70 | # 3rd block 71 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 72 | # global average pooling and classifier 73 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 74 | self.relu = nn.ReLU(inplace=True) 75 | self.fc = nn.Linear(nChannels[3], num_classes) 76 | self.nChannels = nChannels[3] 77 | 78 | for m in self.modules(): 79 | if isinstance(m, nn.Conv2d): 80 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 81 | m.weight.data.normal_(0, math.sqrt(2. / n)) 82 | elif isinstance(m, nn.BatchNorm2d): 83 | m.weight.data.fill_(1) 84 | m.bias.data.zero_() 85 | elif isinstance(m, nn.Linear): 86 | m.bias.data.zero_() 87 | 88 | def forward(self, x): 89 | out = self.conv1(x) 90 | out = self.block1(out) 91 | out = self.block2(out) 92 | out = self.block3(out) 93 | out = self.relu(self.bn1(out)) 94 | out = F.avg_pool2d(out, 8) 95 | out = out.view(-1, self.nChannels) 96 | return self.fc(out) 97 | -------------------------------------------------------------------------------- /classification/utils/tiny_resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class BasicBlock(nn.Module): 8 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 9 | super(BasicBlock, self).__init__() 10 | self.bn1 = nn.BatchNorm2d(in_planes) 11 | self.relu1 = nn.ReLU(inplace=True) 12 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(out_planes) 15 | self.relu2 = nn.ReLU(inplace=True) 16 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 17 | padding=1, bias=False) 18 | self.droprate = dropRate 19 | self.equalInOut = (in_planes == out_planes) 20 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 21 | padding=0, bias=False) or None 22 | 23 | def forward(self, x): 24 | if not self.equalInOut: 25 | x = self.relu1(self.bn1(x)) 26 | else: 27 | out = self.relu1(self.bn1(x)) 28 | if self.equalInOut: 29 | out = self.relu2(self.bn2(self.conv1(out))) 30 | else: 31 | out = self.relu2(self.bn2(self.conv1(x))) 32 | if self.droprate > 0: 33 | out = F.dropout(out, p=self.droprate, training=self.training) 34 | out = self.conv2(out) 35 | if not self.equalInOut: 36 | return torch.add(self.convShortcut(x), out) 37 | else: 38 | return torch.add(x, out) 39 | 40 | 41 | class NetworkBlock(nn.Module): 42 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 43 | super(NetworkBlock, self).__init__() 44 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 45 | 46 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 47 | layers = [] 
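# NOTE: the `a and b or c` expressions in the loop below are the pre-ternary
# Python idiom: `i == 0 and in_planes or out_planes` is equivalent to
# `in_planes if i == 0 else out_planes`, and `i == 0 and stride or 1` to
# `stride if i == 0 else 1` (it relies on the middle operand being truthy,
# which holds here since channel counts and strides are positive).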
48 | for i in range(nb_layers): 49 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 50 | return nn.Sequential(*layers) 51 | 52 | def forward(self, x): 53 | return self.layer(x) 54 | 55 | 56 | class WideResNet(nn.Module): 57 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0): 58 | super(WideResNet, self).__init__() 59 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 60 | assert ((depth - 4) % 6 == 0) 61 | n = (depth - 4) // 6 62 | block = BasicBlock 63 | # 1st conv before any network block 64 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 65 | padding=1, bias=False) 66 | # 1st block 67 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 68 | # 2nd block 69 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 70 | # 3rd block 71 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 72 | # global average pooling and classifier 73 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 74 | self.relu = nn.ReLU(inplace=True) 75 | self.fc = nn.Linear(nChannels[3], num_classes) 76 | self.nChannels = nChannels[3] 77 | 78 | for m in self.modules(): 79 | if isinstance(m, nn.Conv2d): 80 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 81 | m.weight.data.normal_(0, math.sqrt(2. / n)) 82 | elif isinstance(m, nn.BatchNorm2d): 83 | m.weight.data.fill_(1) 84 | m.bias.data.zero_() 85 | elif isinstance(m, nn.Linear): 86 | m.bias.data.zero_() 87 | 88 | def forward(self, x): 89 | out = self.conv1(x) 90 | out = self.block1(out) 91 | out = self.block2(out) 92 | out = self.block3(out) 93 | out = self.relu(self.bn1(out)) 94 | out = F.avg_pool2d(out, 16) 95 | out = out.view(-1, self.nChannels) 96 | return self.fc(out) 97 | -------------------------------------------------------------------------------- /segmentation/code/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch.optim 3 | from valid import * 4 | from utils.logger import * 5 | from engine.engine import Engine 6 | from config.config import config 7 | from model.network import Network 8 | from collections import OrderedDict 9 | from engine.evaluator import SlidingEval 10 | from dataset.data_loader import Fishyscapes, Cityscapes, RoadAnomaly 11 | from utils.img_utils import Compose, Normalize, ToTensor 12 | 13 | # warnings.filterwarnings('ignore', '.*imshow.*', ) 14 | 15 | 16 | def get_anomaly_detector(num_classes, criterion=None): 17 | """ 18 | Get Network Architecture based on arguments provided 19 | """ 20 | ckpt_name = 'best_ad_ckpt.pth' 21 | model = Network(num_classes, criterion=criterion, norm_layer=torch.nn.BatchNorm2d, wide=True) 22 | 23 | tmp = torch.load(ckpt_name) 24 | print('################ restore ckpt from {} #############################'.format(ckpt_name)) 25 | state_dict = tmp 26 | if 'model' in state_dict.keys(): 27 | state_dict = state_dict['model'] 28 | new_state_dict = OrderedDict() 29 | for k, v in state_dict.items(): 30 | name = k  # keys are kept verbatim; strip a 'module.' prefix here if a checkpoint requires it 31 | new_state_dict[name] = v 32 | state_dict = new_state_dict 33 | model.load_state_dict(state_dict, strict=True) 34 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 35 | model.to(device) 36 | return model 37 | 38 | 39 | def main(gpu, ngpus_per_node, config, args): 40 | args.local_rank = gpu 41 | logger = logging.getLogger("pebal") 42 | logger.propagate = False 43 | 
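# NOTE: `Engine` below comes from engine/engine.py, which is not shown in this
# listing; judging from the arguments, it presumably wires up the optional
# distributed (DDP) context from `args` and records `continue_state_object`,
# here the pretrained weights path from the config, as the checkpoint to
# restore from.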
44 | engine = Engine(custom_arg=args, logger=logger, 45 | continue_state_object=config.pretrained_weight_path) 46 | 47 | transform = Compose([ToTensor(), Normalize(config.image_mean, config.image_std)]) 48 | 49 | cityscapes = Cityscapes(root=config.city_root_path, split="val", transform=transform) 50 | evaluator = SlidingEval(config, device=0 if engine.local_rank < 0 else engine.local_rank) 51 | fishyscapes_ls = Fishyscapes(split='fs_lost_and_found', root=config.fishy_root_path, transform=transform) 52 | fishyscapes_static = Fishyscapes(split='fs_static', root=config.fishy_root_path, transform=transform) 53 | Road_anomaly = RoadAnomaly(root=config.road_root_path, transform=transform) 54 | 55 | # we only support 1 gpu for testing 56 | model = Network(config.num_classes, wide=True) 57 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 58 | model = torch.nn.DataParallel(model, device_ids=engine.devices) 59 | model.to(device) 60 | engine.load_pebal_ckpt(config.pebal_weight_path, model=model) 61 | 62 | model.eval() 63 | #valid_epoch(model=model, engine=engine, test_set=cityscapes, my_wandb=None, 64 | # evaluator=evaluator, logger=logger) 65 | 66 | #valid_anomaly(model=model, epoch=0, test_set=fishyscapes_ls, data_name='Fishyscapes_ls', 67 | # my_wandb=None, logger=logger) 68 | 69 | #valid_anomaly(model=model, epoch=0, test_set=fishyscapes_static, 70 | # data_name='Fishyscapes_static', my_wandb=None, logger=logger) 71 | 72 | valid_anomaly(model=model, epoch=0, test_set=Road_anomaly, 73 | data_name='Road_anomaly', my_wandb=None, logger=logger) 74 | 75 | 76 | if __name__ == '__main__': 77 | parser = argparse.ArgumentParser(description='Anomaly Segmentation') 78 | parser.add_argument('--gpus', default=1, 79 | type=int, 80 | help="gpus in use") 81 | parser.add_argument("--ddp", action="store_true", 82 | help="distributed data parallel training or not;" 83 | "MUST BE SPECIFIED") 84 | parser.add_argument('-l', '--local_rank', default=-1, 85 | type=int, 86 | help="distributed or not") 87 | parser.add_argument('-n', '--nodes', default=1, 88 | type=int, 89 | help="distributed or not") 90 | 91 | args = parser.parse_args() 92 | 93 | args.world_size = args.nodes * args.gpus 94 | if args.gpus <= 1: 95 | main(-1, 1, config=config, args=args) 96 | else: 97 | torch.multiprocessing.spawn(main, nprocs=args.gpus, args=(args.gpus, config, args)) 98 | 
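Note that the key-renaming loop in `get_anomaly_detector` above copies every key unchanged (`name = k`); loops of this shape usually exist to strip the `module.` prefix that `torch.nn.DataParallel` adds when a checkpoint is saved from a wrapped model. If such a checkpoint ever needs to be loaded into a bare model, the conventional fix is the following (a sketch, not this repo's code):

```python
from collections import OrderedDict

def strip_module_prefix(state_dict):
    # remove the 'module.' prefix that torch.nn.DataParallel adds to keys
    return OrderedDict((k[len('module.'):] if k.startswith('module.') else k, v)
                       for k, v in state_dict.items())
```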
-------------------------------------------------------------------------------- /classification/utils/lsun_loader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import os 4 | import os.path 5 | import six 6 | import string 7 | import sys 8 | 9 | if sys.version_info[0] == 2: 10 | import cPickle as pickle 11 | else: 12 | import pickle 13 | 14 | 15 | class LSUNClass(data.Dataset): 16 | def __init__(self, db_path, transform=None, target_transform=None): 17 | import lmdb 18 | self.db_path = db_path 19 | self.env = lmdb.open(db_path, max_readers=1, readonly=True, lock=False, 20 | readahead=False, meminit=False) 21 | with self.env.begin(write=False) as txn: 22 | self.length = txn.stat()['entries'] 23 | cache_file = '_cache_' + db_path.replace('/', '_') 24 | if os.path.isfile(cache_file): 25 | self.keys = pickle.load(open(cache_file, "rb")) 26 | else: 27 | with self.env.begin(write=False) as txn: 28 | self.keys = [key for key, _ in txn.cursor()] 29 | pickle.dump(self.keys, open(cache_file, "wb")) 30 | self.transform = transform 31 | self.target_transform = target_transform 32 | 33 | def __getitem__(self, index): 34 | img, target = None, None 35 | env = self.env 36 | with env.begin(write=False) as txn: 37 | imgbuf = txn.get(self.keys[index]) 38 | 39 | buf = six.BytesIO() 40 | buf.write(imgbuf) 41 | buf.seek(0) 42 | img = Image.open(buf).convert('RGB') 43 | 44 | if self.transform is not None: 45 | img = self.transform(img) 46 | 47 | if self.target_transform is not None: 48 | target = self.target_transform(target) 49 | 50 | return img, target 51 | 52 | def __len__(self): 53 | return self.length 54 | 55 | def __repr__(self): 56 | return self.__class__.__name__ + ' (' + self.db_path + ')' 57 | 58 | 59 | class LSUN(data.Dataset): 60 | """ 61 | `LSUN <http://lsun.cs.princeton.edu>`_ dataset. 62 | 63 | Args: 64 | db_path (string): Root directory for the database files. 65 | classes (string or list): One of {'train', 'val', 'test'} or a list of 66 | categories to load. e.g. ['bedroom_train', 'church_train']. 67 | transform (callable, optional): A function/transform that takes in a PIL image 68 | and returns a transformed version. E.g., ``transforms.RandomCrop`` 69 | target_transform (callable, optional): A function/transform that takes in the 70 | target and transforms it. 71 | """ 72 | 73 | def __init__(self, db_path, classes='train', 74 | transform=None, target_transform=None): 75 | categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom', 76 | 'conference_room', 'dining_room', 'kitchen', 77 | 'living_room', 'restaurant', 'tower'] 78 | dset_opts = ['train', 'val', 'test'] 79 | self.db_path = db_path 80 | if type(classes) == str and classes in dset_opts: 81 | if classes == 'test': 82 | classes = [classes] 83 | else: 84 | classes = [c + '_' + classes for c in categories] 85 | self.classes = classes 86 | 87 | # for each class, create an LSUNClass dataset 88 | self.dbs = [] 89 | for c in self.classes: 90 | self.dbs.append(LSUNClass( 91 | db_path=db_path + '/' + c + '_lmdb', 92 | transform=transform)) 93 | 94 | self.indices = [] 95 | count = 0 96 | for db in self.dbs: 97 | count += len(db) 98 | self.indices.append(count) 99 | 100 | self.length = count 101 | self.target_transform = target_transform 102 | 103 | def __getitem__(self, index): 104 | """ 105 | Args: 106 | index (int): Index 107 | 108 | Returns: 109 | tuple: Tuple (image, target) where target is the index of the target category. 
110 | """ 111 | target = 0 112 | sub = 0 113 | for ind in self.indices: 114 | if index < ind: 115 | break 116 | target += 1 117 | sub = ind 118 | 119 | db = self.dbs[target] 120 | index = index - sub 121 | 122 | if self.target_transform is not None: 123 | target = self.target_transform(target) 124 | 125 | img, _ = db[index] 126 | return img, target 127 | 128 | def __len__(self): 129 | return self.length 130 | 131 | def __repr__(self): 132 | return self.__class__.__name__ + ' (' + self.db_path + ')' 133 | -------------------------------------------------------------------------------- /balanced.yml: -------------------------------------------------------------------------------- 1 | name: balanced 2 | channels: 3 | - defaults 4 | - conda-forge 5 | dependencies: 6 | - _libgcc_mutex=0.1=conda_forge 7 | - _openmp_mutex=4.5=1_llvm 8 | - blas=1.0=mkl 9 | - bzip2=1.0.8=h7f98852_4 10 | - ca-certificates=2020.12.5=ha878542_0 11 | - certifi=2020.12.5=py38h578d9bd_1 12 | - freetype=2.10.4=h0708190_1 13 | - gmp=6.2.1=h58526e2_0 14 | - gnutls=3.6.13=h85f3911_1 15 | - jpeg=9b=h024ee3a_2 16 | - lame=3.100=h7f98852_1001 17 | - lcms2=2.11=h396b838_0 18 | - ld_impl_linux-64=2.33.1=h53a641e_7 19 | - libffi=3.3=he6710b0_2 20 | - libgcc-ng=9.3.0=h2828fa1_18 21 | - libiconv=1.16=h516909a_0 22 | - libpng=1.6.37=h21135ba_2 23 | - libstdcxx-ng=9.3.0=h6de172a_18 24 | - libtiff=4.1.0=h2733197_1 25 | - libuv=1.41.0=h7f98852_0 26 | - llvm-openmp=11.1.0=h4bd325d_0 27 | - lz4-c=1.9.3=h9c3ff4c_0 28 | - mkl=2020.4=h726a3e6_304 29 | - mkl-service=2.3.0=py38h1e0a361_2 30 | - mkl_fft=1.3.0=py38h5c078b8_1 31 | - mkl_random=1.2.0=py38hc5bc63f_1 32 | - ncurses=6.2=he6710b0_1 33 | - nettle=3.6=he412f7d_0 34 | - numpy=1.19.2=py38h54aff64_0 35 | - numpy-base=1.19.2=py38hfa32c7d_0 36 | - olefile=0.46=pyh9f0ad1d_1 37 | - openh264=2.1.1=h780b84a_0 38 | - openssl=1.1.1k=h7f98852_0 39 | - pillow=8.1.2=py38he98fc37_0 40 | - pip=21.0.1=py38h06a4308_0 41 | - python=3.8.8=hdb3f193_4 42 | - python_abi=3.8=1_cp38 43 | - setuptools=52.0.0=py38h06a4308_0 44 | - six=1.15.0=pyh9f0ad1d_0 45 | - sqlite=3.35.2=hdfb4753_0 46 | - tk=8.6.10=hbc83047_0 47 | - typing_extensions=3.7.4.3=py_0 48 | - wheel=0.36.2=pyhd3eb1b0_0 49 | - xz=5.2.5=h7b6447c_0 50 | - zlib=1.2.11=h7b6447c_3 51 | - zstd=1.4.9=ha95c52a_0 52 | - pip: 53 | - absl-py==0.12.0 54 | - backcall==0.2.0 55 | - cachetools==4.2.1 56 | - cffi==1.14.5 57 | - chardet==4.0.0 58 | - click==8.0.1 59 | - colorama==0.4.4 60 | - configparser==5.0.2 61 | - cycler==0.10.0 62 | - cython==0.29.24 63 | - decorator==4.4.2 64 | - docker-pycreds==0.4.0 65 | - easydict==1.9 66 | - faiss-gpu==1.7.0 67 | - ffmpeg==1.4 68 | - gcloud==0.18.3 69 | - gitdb==4.0.7 70 | - gitpython==3.1.18 71 | - google-api-core==1.26.2 72 | - google-auth==1.28.0 73 | - google-auth-oauthlib==0.4.4 74 | - google-cloud==0.34.0 75 | - google-cloud-core==1.6.0 76 | - google-cloud-speech==2.2.0 77 | - google-cloud-storage==1.37.0 78 | - google-crc32c==1.1.2 79 | - google-resumable-media==1.2.0 80 | - googleapis-common-protos==1.53.0 81 | - grpcio==1.36.1 82 | - httplib2==0.19.0 83 | - idna==2.10 84 | - imageio==2.9.0 85 | - ipython==7.22.0 86 | - ipython-genutils==0.2.0 87 | - jedi==0.18.0 88 | - joblib==1.0.1 89 | - jsonpatch==1.32 90 | - jsonpointer==2.1 91 | - kiwisolver==1.3.1 92 | - kornia==0.5.1 93 | - libcst==0.3.17 94 | - markdown==3.3.4 95 | - matplotlib==3.4.0 96 | - mypy-extensions==0.4.3 97 | - networkx==2.5 98 | - ninja==1.10.0.post2 99 | - oauth2client==4.1.3 100 | - oauthlib==3.1.0 101 | - opencv-python==4.5.2.54 102 | - 
packaging==20.9 103 | - panda==0.3.1 104 | - pandas==1.2.3 105 | - parso==0.8.1 106 | - pathtools==0.1.2 107 | - pexpect==4.8.0 108 | - pickleshare==0.7.5 109 | - promise==2.3 110 | - prompt-toolkit==3.0.18 111 | - proto-plus==1.18.1 112 | - protobuf==3.15.6 113 | - psutil==5.8.0 114 | - ptflops==0.6.6 115 | - ptyprocess==0.7.0 116 | - pyasn1==0.4.8 117 | - pyasn1-modules==0.2.8 118 | - pycocotools==2.0.2 119 | - pycparser==2.20 120 | - pygments==2.8.1 121 | - pyparsing==2.4.7 122 | - python-dateutil==2.8.1 123 | - pytorch-msssim==0.2.1 124 | - pytz==2021.1 125 | - pywavelets==1.1.1 126 | - pyyaml==5.4.1 127 | - pyzmq==22.1.0 128 | - readline==6.2.4.1 129 | - requests==2.25.1 130 | - requests-oauthlib==1.3.0 131 | - rsa==4.7.2 132 | - scikit-image==0.18.1 133 | - scikit-learn==0.24.1 134 | - scipy==1.6.2 135 | - seaborn==0.11.1 136 | - sentry-sdk==1.3.1 137 | - shortuuid==1.0.1 138 | - sklearn==0.0 139 | - smmap==4.0.0 140 | - subprocess32==3.5.4 141 | - tabulate==0.8.9 142 | - termcolor==1.1.0 143 | - threadpoolctl==2.1.0 144 | - tifffile==2021.3.17 145 | - timm==0.3.2 146 | - torch-tb-profiler==0.1.0 147 | - torchfile==0.1.0 148 | - tornado==6.1 149 | - tqdm==4.59.0 150 | - traitlets==5.0.5 151 | - typing-inspect==0.6.0 152 | - urllib3==1.26.4 153 | - visdom==0.1.8.9 154 | - wcwidth==0.2.5 155 | - websocket-client==1.1.0 156 | - werkzeug==1.0.1 157 | - wandb>=0.12.0 158 | prefix: /anaconda/envs/balanced 159 | -------------------------------------------------------------------------------- /classification/utils/calibration_tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def calib_err(confidence, correct, p='2', beta=100): 5 | # beta is target bin size 6 | idxs = np.argsort(confidence) 7 | confidence = confidence[idxs] 8 | correct = correct[idxs] 9 | bins = [[i * beta, (i + 1) * beta] for i in range(len(confidence) // beta)] 10 | bins[-1] = [bins[-1][0], len(confidence)] 11 | 12 | cerr = 0 13 | total_examples = len(confidence) 14 | for i in range(len(bins) - 1): 15 | bin_confidence = confidence[bins[i][0]:bins[i][1]] 16 | bin_correct = correct[bins[i][0]:bins[i][1]] 17 | num_examples_in_bin = len(bin_confidence) 18 | 19 | if num_examples_in_bin > 0: 20 | difference = np.abs(np.nanmean(bin_confidence) - np.nanmean(bin_correct)) 21 | 22 | if p == '2': 23 | cerr += num_examples_in_bin / total_examples * np.square(difference) 24 | elif p == '1': 25 | cerr += num_examples_in_bin / total_examples * difference 26 | elif p == 'infty' or p == 'infinity' or p == 'max': 27 | cerr = np.maximum(cerr, difference) 28 | else: 29 | assert False, "p must be '1', '2', or 'infty'" 30 | 31 | if p == '2': 32 | cerr = np.sqrt(cerr) 33 | 34 | return cerr 35 | 36 | 37 | def soft_f1(confidence, correct): 38 | wrong = 1 - correct 39 | 40 | # # the incorrectly classified samples are our interest 41 | # # so they make the positive class 42 | # tp_soft = np.sum((1 - confidence) * wrong) 43 | # fp_soft = np.sum((1 - confidence) * correct) 44 | # fn_soft = np.sum(confidence * wrong) 45 | 46 | # return 2 * tp_soft / (2 * tp_soft + fn_soft + fp_soft) 47 | return 2 * ((1 - confidence) * wrong).sum()/(1 - confidence + wrong).sum() 48 | 49 | 50 | def tune_temp(logits, labels, binary_search=True, lower=0.2, upper=5.0, eps=0.0001): 51 | logits = np.array(logits) 52 | 53 | if binary_search: 54 | import torch 55 | import torch.nn.functional as F 56 | 57 | logits = torch.FloatTensor(logits) 58 | labels = torch.LongTensor(labels) 59 | t_guess = 
torch.FloatTensor([0.5*(lower + upper)]).requires_grad_() 60 | 61 | while upper - lower > eps: 62 | if torch.autograd.grad(F.cross_entropy(logits / t_guess, labels), t_guess)[0] > 0: 63 | upper = 0.5 * (lower + upper) 64 | else: 65 | lower = 0.5 * (lower + upper) 66 | t_guess = t_guess * 0 + 0.5 * (lower + upper) 67 | 68 | t = min([lower, 0.5 * (lower + upper), upper], key=lambda x: float(F.cross_entropy(logits / x, labels))) 69 | else: 70 | import cvxpy as cx 71 | 72 | set_size = np.array(logits).shape[0] 73 | 74 | t = cx.Variable() 75 | 76 | expr = sum((cx.Minimize(cx.log_sum_exp(logits[i, :] * t) - logits[i, labels[i]] * t) 77 | for i in range(set_size))) 78 | p = cx.Problem(expr, [lower <= t, t <= upper]) 79 | 80 | p.solve() # p.solve(solver=cx.SCS) 81 | t = 1 / t.value 82 | 83 | return t 84 | 85 | 86 | def get_measures(confidence, correct): 87 | rms = calib_err(confidence, correct, p='2') 88 | mad = calib_err(confidence, correct, p='1') 89 | sf1 = soft_f1(confidence, correct) 90 | 91 | return rms, mad, sf1 92 | 93 | 94 | def print_measures(rms, mad, sf1, method_name='Baseline'): 95 | print('\t\t\t\t\t\t\t' + method_name) 96 | print('RMS Calib Error (%): \t\t{:.2f}'.format(100 * rms)) 97 | print('MAD Calib Error (%): \t\t{:.2f}'.format(100 * mad)) 98 | print('Soft F1 Score (%): \t\t{:.2f}'.format(100 * sf1)) 99 | 100 | 101 | def print_measures_with_std(rmss, mads, sf1s, method_name='Baseline'): 102 | print('\t\t\t\t\t\t\t' + method_name) 103 | print('RMS Calib Error (%): \t\t{:.2f}\t+/- {:.2f}'.format(100 * np.mean(rmss), 100 * np.std(rmss))) 104 | print('MAD Calib Error (%): \t\t{:.2f}\t+/- {:.2f}'.format(100 * np.mean(mads), 100 * np.std(mads))) 105 | print('Soft F1 Score (%): \t\t{:.2f}\t+/- {:.2f}'.format(100 * np.mean(sf1s), 100 * np.std(sf1s))) 106 | 107 | 108 | def show_calibration_results(confidence, correct, method_name='Baseline'): 109 | 110 | print('\t\t\t\t' + method_name) 111 | print('RMS Calib Error (%): \t\t{:.2f}'.format( 112 | 100 * calib_err(confidence, correct, p='2'))) 113 | 114 | print('MAD Calib Error (%): \t\t{:.2f}'.format( 115 | 100 * calib_err(confidence, correct, p='1'))) 116 | 117 | # print('Max Calib Error (%): \t\t{:.2f}'.format( 118 | # 100 * calib_err(confidence, correct, p='infty'))) 119 | 120 | print('Soft F1-Score (%): \t\t{:.2f}'.format( 121 | 100 * soft_f1(confidence, correct))) 122 | 123 | # add error detection measures? 124 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Balanced Energy Regularization Loss for Out-of-distribution Detection 2 | 3 | This repo contains the official implementation for the CVPR2023 paper: 4 | 5 | [Balanced Energy Regularization Loss for Out-of-distribution Detection](https://arxiv.org/abs/2306.10485) 6 | 7 | by Hyunjun Choi, Hawook Jeong, and Jin Young Choi. 
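The method regularizes the energy score of the energy-based OOD detection work linked below. For quick reference, a sketch of that score (not code from this repo): the score of an input is the negative free energy of the classifier logits, and lower energy indicates a more in-distribution sample.

```python
import torch

def energy_score(logits, T=1.0):
    # E(x; f) = -T * log sum_k exp(f_k(x) / T)
    return -T * torch.logsumexp(logits / T, dim=1)
```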
8 | 9 | [Arxiv](https://arxiv.org/abs/2306.10485) 10 | 11 | thumbnail_intro 12 | 13 | 14 | Our code heavily relies on the implementations of [Energy-based Out-of-distribution Detection](https://github.com/wetliu/energy_ood) 15 | and [PEBAL](https://github.com/tianyu0207/PEBAL). 16 | 17 | ## Prerequisite 18 | 19 | ### Prepare Dataset 20 | 21 | #### Segmentation 22 | We follow the installation process of [PEBAL](https://github.com/tianyu0207/PEBAL/blob/main/docs/installation.md). 23 | 24 | The whole data tree has to be placed under the path: Balanced_Energy/segmentation/code/dataset 25 | 26 | #### Classification 27 | 28 | We use CIFAR-10 and CIFAR-100 as training data. 29 | 30 | We use 300K random images as auxiliary data, following [Outlier Exposure](https://github.com/hendrycks/outlier-exposure). 31 | 32 | We test on the SC-OOD benchmark, which can be downloaded from [SC-OOD UDG](https://github.com/Jingkang50/ICCV21_SCOOD) 33 | and should be inserted into the data tree as follows: 34 | 35 | ```shell 36 | classification/data 37 | ├── cifar10 38 | ├── cifar100 39 | ├── data 40 | │ ├── images 41 | │ └── imglist 42 | └── tinyimages80m 43 | └── 300K_random_images.npy 44 | 45 | ``` 46 | 47 | 48 | ### Install dependencies 49 | 50 | The project is based on PyTorch 1.8.1 with Python 3.8. 51 | 52 | 1) Create the conda env 53 | ```shell 54 | $ conda env create -f balanced.yml 55 | ``` 56 | 2) Install torch 1.8.1 57 | ```shell 58 | $ conda activate balanced 59 | # IF cuda version < 11.0 60 | $ pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html 61 | # IF cuda version >= 11.0 (e.g., 30x or above) 62 | $ pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html 63 | ``` 64 | 65 | ### Prepare checkpoint 66 | 67 | All checkpoints are in the zip file at the following Google Drive link: 68 | 69 | [checkpoint](https://drive.google.com/file/d/1V9STZyI4uQ1x_eckkdryfYj36QZzCBnE/view?usp=share_link) 70 | 71 | 72 | #### Segmentation 73 | 74 | Put the checkpoint in the path: Balanced_Energy/segmentation/ckpts/pebal_balanced/ 75 | 76 | Additionally, to fine-tune from the NVIDIA Cityscapes model, 77 | 78 | we follow [Meta-OoD](https://github.com/robin-chan/meta-ood) and use the DeepLabV3+ checkpoint 79 | from [here](https://github.com/NVIDIA/semantic-segmentation/tree/sdcnet). You'll need to put it in the "ckpts/pretrained_ckpts" directory, and 80 | **please note that downloading the checkpoint before running the code is necessary** 81 | 82 | 
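When initializing from that Cityscapes checkpoint, the class count differs from the anomaly-training setup, so a strict `load_state_dict` would fail. The repo ships `forgiving_state_restore` in segmentation/code/model/mynn.py (shown later in this listing), which skips size-mismatched tensors. A hypothetical usage, assuming the checkpoint may nest its weights under a `state_dict` key:

```python
import torch
from model.mynn import forgiving_state_restore  # defined later in this listing

def load_pretrained(net, ckpt_path):
    ckpt = torch.load(ckpt_path, map_location='cpu')
    state_dict = ckpt.get('state_dict', ckpt)  # some checkpoints nest the weights
    return forgiving_state_restore(net, state_dict)  # skips size-mismatched tensors
```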
83 | #### Classification 84 | 85 | Put the checkpoint in the path: Balanced_Energy/classification/snapshots/pretrained 86 | 87 | ## Segmentation code run 88 | 89 | #### Data path setting in the config file 90 | 91 | Open 92 | Balanced_Energy/segmentation/code/config/config.py 93 | 94 | Set the root paths for the datasets 95 | 96 | 97 | #### Run the training: main.py 98 | 99 | 1) Run `python code/main.py` in Balanced_Energy/segmentation/ 100 | 101 | #### Run the evaluation: test.py 102 | 103 | 2) Run `python code/test.py` in Balanced_Energy/segmentation/ 104 | 105 | 106 | ## Classification code run 107 | 108 | 109 | In Balanced_Energy/classification/: 110 | 111 | Run ResNet18 balanced_energy_fine_tune training and testing for CIFAR-10 with trial index 3 112 | ```train 113 | bash inf_run_res.sh energy_ft 0 3 114 | ``` 115 | 116 | Run ResNet18 balanced_energy_fine_tune training and testing for CIFAR-100 with trial index 3 117 | ```train 118 | bash inf_run_res.sh energy_ft 1 3 119 | ``` 120 | 121 | Run ResNet18 balanced_energy_fine_tune training and testing for CIFAR-10-LT with trial index 3 122 | ```train 123 | bash inf_run_im_res.sh energy_ft 0 3 124 | ``` 125 | 126 | Run ResNet18 balanced_energy_fine_tune training and testing for CIFAR-100-LT with trial index 3 127 | ```train 128 | bash inf_run_im_res.sh energy_ft 1 3 129 | ``` 130 | 131 | The hyperparameters alpha and gamma can be set in the bash scripts. 132 | 133 | ## Citation 134 | 135 | If you find this project useful, please consider citing: 136 | 137 | ``` 138 | @inproceedings{choi2023balanced, 139 | title={Balanced Energy Regularization Loss for Out-of-Distribution Detection}, 140 | author={Choi, Hyunjun and Jeong, Hawook and Choi, Jin Young}, 141 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 142 | pages={15691--15700}, 143 | year={2023} 144 | } 145 | ``` 146 | 147 | 148 | 149 | 150 | -------------------------------------------------------------------------------- /classification/models/wrn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class BasicBlock(nn.Module): 8 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 9 | super(BasicBlock, self).__init__() 10 | self.bn1 = nn.BatchNorm2d(in_planes) 11 | self.relu1 = nn.ReLU(inplace=True) 12 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(out_planes) 15 | self.relu2 = nn.ReLU(inplace=True) 16 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 17 | padding=1, bias=False) 18 | self.droprate = dropRate 19 | self.equalInOut = (in_planes == out_planes) 20 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 21 | padding=0, bias=False) or None 22 | 23 | def forward(self, x): 24 | if not self.equalInOut: 25 | x = self.relu1(self.bn1(x)) 26 | else: 27 | out = self.relu1(self.bn1(x)) 28 | if self.equalInOut: 29 | out = self.relu2(self.bn2(self.conv1(out))) 30 | else: 31 | out = self.relu2(self.bn2(self.conv1(x))) 32 | if self.droprate > 0: 33 | out = F.dropout(out, p=self.droprate, training=self.training) 34 | out = 
self.conv2(out) 35 | if not self.equalInOut: 36 | return torch.add(self.convShortcut(x), out) 37 | else: 38 | return torch.add(x, out) 39 | 40 | 41 | class NetworkBlock(nn.Module): 42 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 43 | super(NetworkBlock, self).__init__() 44 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 45 | 46 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 47 | layers = [] 48 | for i in range(nb_layers): 49 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 50 | return nn.Sequential(*layers) 51 | 52 | def forward(self, x): 53 | return self.layer(x) 54 | 55 | 56 | class WideResNet(nn.Module): 57 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0): 58 | super(WideResNet, self).__init__() 59 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 60 | assert ((depth - 4) % 6 == 0) 61 | n = (depth - 4) // 6 62 | block = BasicBlock 63 | # 1st conv before any network block 64 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 65 | padding=1, bias=False) 66 | # 1st block 67 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 68 | # 2nd block 69 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 70 | # 3rd block 71 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 72 | # global average pooling and classifier 73 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 74 | self.relu = nn.ReLU(inplace=True) 75 | self.fc = nn.Linear(nChannels[3], num_classes) 76 | self.nChannels = nChannels[3] 77 | 78 | for m in self.modules(): 79 | if isinstance(m, nn.Conv2d): 80 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 81 | m.weight.data.normal_(0, math.sqrt(2. / n)) 82 | elif isinstance(m, nn.BatchNorm2d): 83 | m.weight.data.fill_(1) 84 | m.bias.data.zero_() 85 | elif isinstance(m, nn.Linear): 86 | m.bias.data.zero_() 87 | 88 | def forward(self, x): 89 | out = self.conv1(x) 90 | out = self.block1(out) 91 | out = self.block2(out) 92 | out = self.block3(out) 93 | out = self.relu(self.bn1(out)) 94 | out = F.avg_pool2d(out, 8) 95 | out = out.view(-1, self.nChannels) 96 | return self.fc(out) 97 | 98 | def intermediate_forward(self, x, layer_index): 99 | out = self.conv1(x) 100 | out = self.block1(out) 101 | out = self.block2(out) 102 | out = self.block3(out) 103 | out = self.relu(self.bn1(out)) 104 | return out 105 | 106 | def feature_list(self, x): 107 | out_list = [] 108 | out = self.conv1(x) 109 | out = self.block1(out) 110 | out = self.block2(out) 111 | out = self.block3(out) 112 | out = self.relu(self.bn1(out)) 113 | out_list.append(out) 114 | out = F.avg_pool2d(out, 8) 115 | out = out.view(-1, self.nChannels) 116 | return self.fc(out), out_list 117 | 118 | -------------------------------------------------------------------------------- /classification/models/resnet.py: -------------------------------------------------------------------------------- 1 | ''' PyTorch implementation of ResNet taken from 2 | https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py 3 | which is originally licensed under MIT. 
4 | ''' 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | class BasicBlock(nn.Module): 11 | 12 | def __init__(self, in_planes, mid_planes, out_planes, norm, stride=1): 13 | super(BasicBlock, self).__init__() 14 | self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=3, stride=stride, padding=1, bias=False) 15 | self.bn1 = norm(mid_planes) 16 | self.conv2 = nn.Conv2d(mid_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False) 17 | self.bn2 = norm(out_planes) 18 | 19 | self.shortcut = nn.Sequential() 20 | if stride != 1 or in_planes != out_planes: 21 | self.shortcut = nn.Sequential( 22 | nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False), 23 | norm(out_planes) 24 | ) 25 | 26 | def forward(self, x): 27 | out = self.bn1(self.conv1(x)) 28 | out = F.relu(out) 29 | out = self.bn2(self.conv2(out)) 30 | out += self.shortcut(x) 31 | out = F.relu(out) 32 | # print(out.size()) 33 | return out 34 | 35 | 36 | class ResNet(nn.Module): 37 | def __init__(self, block, num_blocks, num_classes=10, pooling='avgpool', norm=nn.BatchNorm2d, return_features=False): 38 | super(ResNet, self).__init__() 39 | if pooling == 'avgpool': 40 | self.pooling = nn.AvgPool2d(4) 41 | elif pooling == 'maxpool': 42 | self.pooling = nn.MaxPool2d(4) 43 | else: 44 | raise Exception('Unsupported pooling: %s' % pooling) 45 | self.in_planes = 64 46 | self.return_features = return_features 47 | 48 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 49 | self.bn1 = norm(64) 50 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1, norm=norm) 51 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2, norm=norm) 52 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2, norm=norm) 53 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2, norm=norm) 54 | self.linear = nn.Linear(512, num_classes) 55 | 56 | # self.aux_linear = nn.Linear(512, num_classes) 57 | 58 | self.projection = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 128)) 59 | 60 | def _make_layer(self, block, planes, num_blocks, norm, stride): 61 | strides = [stride] + [1]*(num_blocks-1) 62 | layers = [] 63 | for stride in strides: 64 | layers.append(block(self.in_planes, planes, planes, norm, stride)) 65 | self.in_planes = planes 66 | return nn.Sequential(*layers) 67 | 68 | def forward_features(self, x): 69 | c1 = F.relu(self.bn1(self.conv1(x))) # (3,32,32) 70 | h1 = self.layer1(c1) # (64,32,32) 71 | h2 = self.layer2(h1) # (128,16,16) 72 | h3 = self.layer3(h2) # (256,8,8) 73 | h4 = self.layer4(h3) # (512,4,4) 74 | p4 = self.pooling(h4) # (512,1,1) 75 | p4 = p4.view(p4.size(0), -1) # (512) 76 | return p4 77 | 78 | def forward_classifier(self, p4): 79 | logits = self.linear(p4) # (10) 80 | return logits 81 | 82 | # def forward_aux_classifier(self, p4): 83 | # logits = self.aux_linear(p4) # (10) 84 | # return logits 85 | 86 | def forward(self, x): 87 | p4 = self.forward_features(x) 88 | logits = self.forward_classifier(p4) 89 | 90 | if self.return_features: 91 | return logits, p4 92 | else: 93 | return logits 94 | 95 | # def forward_aux(self, x): 96 | # p4 = self.forward_features(x) 97 | # logits = self.forward_aux_classifier(p4) 98 | 99 | # if self.return_features: 100 | # return logits, p4 101 | # else: 102 | # return logits 103 | 104 | def forward_projection(self, p4): 105 | projected_f = self.projection(p4) # (10) 106 | projected_f = F.normalize(projected_f, dim=1) 107 | return projected_f 108 | 109 
| 110 | def ResNet18(num_classes=10, pooling='avgpool', norm=nn.BatchNorm2d, return_features=False): 111 | ''' 112 | GFLOPS: 0.5579, model size: 11.1740MB 113 | ''' 114 | return ResNet(BasicBlock, [2,2,2,2], num_classes=num_classes, pooling=pooling, norm=norm, return_features=return_features) 115 | 116 | def ResNet34(num_classes=10, pooling='avgpool', norm=nn.BatchNorm2d, return_features=False): 117 | ''' 118 | GFLOPS: 1.1635, model size: 21.2859MB 119 | ''' 120 | return ResNet(BasicBlock, [3,4,6,3], num_classes=num_classes, pooling=pooling, norm=norm, return_features=return_features) 121 | 122 | 123 | 124 | if __name__ == '__main__': 125 | from thop import profile 126 | net = ResNet18(num_classes=10, return_features=True) 127 | x = torch.randn(1,3,32,32) 128 | flops, params = profile(net, inputs=(x, )) 129 | y, features = net(x) 130 | print(y.size()) 131 | print('GFLOPS: %.4f, model size: %.4fMB' % (flops/1e9, params/1e6)) 132 | 133 | 134 | -------------------------------------------------------------------------------- /segmentation/code/model/mynn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom Norm wrappers to enable sync BN, regular BN and for weight initialization 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | def Norm2d(in_channels): 9 | """ 10 | Custom Norm Function to allow flexible switching 11 | """ 12 | layer = torch.nn.BatchNorm2d 13 | normalization_layer = layer(in_channels) 14 | return normalization_layer 15 | 16 | 17 | def initialize_weights(*models): 18 | """ 19 | Initialize Model Weights 20 | """ 21 | for model in models: 22 | for module in model.modules(): 23 | if isinstance(module, (nn.Conv2d, nn.Linear)): 24 | nn.init.kaiming_normal_(module.weight) 25 | if module.bias is not None: 26 | module.bias.data.zero_() 27 | elif isinstance(module, nn.BatchNorm2d): 28 | module.weight.data.fill_(1) 29 | module.bias.data.zero_() 30 | 31 | 32 | def Upsample(x, size): 33 | """ 34 | Wrapper Around the Upsample Call 35 | """ 36 | return nn.functional.interpolate(x, size=size, mode='bilinear', 37 | align_corners=True) 38 | 39 | 40 | def freeze_weights(*models): 41 | for model in models: 42 | for k in model.parameters(): 43 | k.requires_grad = False 44 | 45 | 46 | def unfreeze_weights(*models): 47 | for model in models: 48 | for k in model.parameters(): 49 | k.requires_grad = True 50 | 51 | 52 | def initialize_embedding(*models): 53 | """ 54 | Initialize Model Weights 55 | """ 56 | for model in models: 57 | for module in model.modules(): 58 | if isinstance(module, nn.Embedding): 59 | module.weight.data.zero_() # original 60 | 61 | 62 | def forgiving_state_restore(net, loaded_dict): 63 | """ 64 | Handle partial loading when some tensors don't match up in size. 65 | Because we want to use models that were trained off a different 66 | number of classes. 
67 | """ 68 | net_state_dict = net.state_dict() 69 | new_loaded_dict = {} 70 | for k in net_state_dict: 71 | if k in loaded_dict and net_state_dict[k].size() == loaded_dict[k].size(): 72 | new_loaded_dict[k] = loaded_dict[k] 73 | else: 74 | print("Skipped loading parameter", k) 75 | # logging.info("Skipped loading parameter %s", k) 76 | net_state_dict.update(new_loaded_dict) 77 | net.load_state_dict(net_state_dict) 78 | return net 79 | 80 | 81 | def Zero_Masking(input_tensor, mask_org): 82 | output = input_tensor.clone() 83 | output.mul_(mask_org) 84 | return output 85 | 86 | 87 | def RandomPosZero_Masking(input_tensor, p=0.5): 88 | output = input_tensor.clone() 89 | noise_b = input_tensor.new().resize_(input_tensor.size(0), 1, input_tensor.size(2), input_tensor.size(3)) 90 | noise_u = input_tensor.new().resize_(input_tensor.size(0), input_tensor.size(1), input_tensor.size(2), 91 | input_tensor.size(3)) 92 | noise_b.bernoulli_(1 - p) 93 | noise_b = noise_b.expand_as(input_tensor) 94 | output.mul_(noise_b) 95 | return output 96 | 97 | 98 | def RandomVal_Masking(input_tensor, mask_org): 99 | output = input_tensor.clone() 100 | noise_u = input_tensor.new().resize_(input_tensor.size(0), input_tensor.size(1), input_tensor.size(2), 101 | input_tensor.size(3)) 102 | mask = (mask_org == 0).type(input_tensor.type()) 103 | mask = mask.expand_as(input_tensor) 104 | mask = torch.mul(mask, noise_u.uniform_(torch.min(input_tensor).item(), torch.max(input_tensor).item())) 105 | mask_org = mask_org.expand_as(input_tensor) 106 | output.mul_(mask_org) 107 | output.add_(mask) 108 | return output 109 | 110 | 111 | def RandomPosVal_Masking(input_tensor, p=0.5): 112 | output = input_tensor.clone() 113 | noise_b = input_tensor.new().resize_(input_tensor.size(0), 1, input_tensor.size(2), input_tensor.size(3)) 114 | noise_u = input_tensor.new().resize_(input_tensor.size(0), input_tensor.size(1), input_tensor.size(2), 115 | input_tensor.size(3)) 116 | mask = noise_b.bernoulli_(1 - p) 117 | mask = (mask == 0).type(input_tensor.type()) 118 | mask = mask.expand_as(input_tensor) 119 | mask = torch.mul(mask, noise_u.uniform_(torch.min(input_tensor).item(), torch.max(input_tensor).item())) 120 | noise_b = noise_b.expand_as(input_tensor) 121 | output.mul_(noise_b) 122 | output.add_(mask) 123 | return output 124 | 125 | 126 | def masking(input_tensor, p=0.5): 127 | output = input_tensor.clone() 128 | noise_b = input_tensor.new().resize_(input_tensor.size(0), 1, input_tensor.size(2), input_tensor.size(3)) 129 | noise_u = input_tensor.new().resize_(input_tensor.size(0), 1, input_tensor.size(2), input_tensor.size(3)) 130 | mask = noise_b.bernoulli_(1 - p) 131 | mask = (mask == 0).type(input_tensor.type()) 132 | mask.mul_(noise_u.uniform_(torch.min(input_tensor).item(), torch.max(input_tensor).item())) 133 | # mask.mul_(noise_u.uniform_(5, 10)) 134 | noise_b = noise_b.expand_as(input_tensor) 135 | mask = mask.expand_as(input_tensor) 136 | output.mul_(noise_b) 137 | output.add_(mask) 138 | return output 139 | -------------------------------------------------------------------------------- /segmentation/code/dataset/base_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | import torch.utils.data as data 8 | 9 | 10 | class BaseDataset(data.Dataset): 11 | def __init__(self, setting, split_name, preprocess=None, 12 | file_length=None): 13 | super(BaseDataset, self).__init__() 14 | self._split_name = split_name 
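# NOTE: `setting` is a plain dict; judging from the lookups below and from the
# __main__ example at the bottom of this file, it is expected to provide the
# keys 'img_root', 'gt_root', 'train_source' and 'eval_source', where the two
# *_source entries are text files listing "image<TAB>label" pairs, one per line.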
15 | self._img_path = setting['img_root'] 16 | self._gt_path = setting['gt_root'] 17 | self._train_source = setting['train_source'] 18 | self._eval_source = setting['eval_source'] 19 | self._file_names = self._get_file_names(split_name) 20 | self._file_length = file_length 21 | self.preprocess = preprocess 22 | 23 | def __len__(self): 24 | if self._file_length is not None: 25 | return self._file_length 26 | return len(self._file_names) 27 | 28 | def __getitem__(self, index): 29 | if self._file_length is not None: 30 | names = self._construct_new_file_names(self._file_length)[index] 31 | else: 32 | names = self._file_names[index] 33 | img_path = os.path.join(self._img_path, names[0]) 34 | gt_path = os.path.join(self._gt_path, names[1]) 35 | item_name = names[1].split("/")[-1].split(".")[0] 36 | 37 | img, gt = self._fetch_data(img_path, gt_path) 38 | 39 | img = img[:, :, ::-1] 40 | if self.preprocess is not None: 41 | img, gt, extra_dict = self.preprocess(img, gt) 42 | 43 | if self._split_name == 'train': 44 | img = torch.from_numpy(np.ascontiguousarray(img)).float() 45 | gt = torch.from_numpy(np.ascontiguousarray(gt)).long() 46 | if self.preprocess is not None and extra_dict is not None: 47 | for k, v in extra_dict.items(): 48 | extra_dict[k] = torch.from_numpy(np.ascontiguousarray(v)) 49 | if 'label' in k: 50 | extra_dict[k] = extra_dict[k].long() 51 | if 'img' in k: 52 | extra_dict[k] = extra_dict[k].float() 53 | 54 | output_dict = dict(data=img, label=gt, fn=str(item_name), 55 | n=len(self._file_names)) 56 | if self.preprocess is not None and extra_dict is not None: 57 | output_dict.update(**extra_dict) 58 | 59 | return output_dict 60 | 61 | def _fetch_data(self, img_path, gt_path, dtype=None): 62 | img = self._open_image(img_path) 63 | gt = self._open_image(gt_path, cv2.IMREAD_GRAYSCALE, dtype=dtype) 64 | 65 | return img, gt 66 | 67 | def _get_file_names(self, split_name, train_extra=False): 68 | assert split_name in ['train', 'val'] 69 | source = self._train_source 70 | if split_name == "val": 71 | source = self._eval_source 72 | 73 | file_names = [] 74 | with open(source) as f: 75 | files = f.readlines() 76 | 77 | for item in files: 78 | img_name, gt_name = self._process_item_names(item) 79 | file_names.append([img_name, gt_name]) 80 | 81 | if train_extra: 82 | file_names2 = [] 83 | source2 = self._train_source.replace('train', 'train_extra') 84 | with open(source2) as f: 85 | files2 = f.readlines() 86 | 87 | for item in files2: 88 | img_name, gt_name = self._process_item_names(item) 89 | file_names2.append([img_name, gt_name]) 90 | 91 | return file_names, file_names2 92 | 93 | return file_names 94 | 95 | def _construct_new_file_names(self, length): 96 | assert isinstance(length, int) 97 | files_len = len(self._file_names) 98 | 99 | 100 | if length < files_len: 101 | return self._file_names[:length] 102 | 103 | new_file_names = self._file_names * (length // files_len) 104 | 105 | rand_indices = torch.randperm(files_len).tolist() 106 | new_indices = rand_indices[:length % files_len] 107 | 108 | new_file_names += [self._file_names[i] for i in new_indices] 109 | 110 | return new_file_names 111 | 112 | @staticmethod 113 | def _process_item_names(item): 114 | item = item.strip() 115 | item = item.split('\t') 116 | img_name = item[0] 117 | 118 | if len(item) == 1: 119 | gt_name = None 120 | else: 121 | gt_name = item[1] 122 | 123 | return img_name, gt_name 124 | 125 | def get_length(self): 126 | return self.__len__() 127 | 128 | @staticmethod 129 | def _open_image(filepath, 
mode=cv2.IMREAD_COLOR, dtype=None): 130 | # cv2: B G R 131 | # h w c 132 | img = np.array(cv2.imread(filepath, mode), dtype=dtype) 133 | return img 134 | 135 | @classmethod 136 | def get_class_colors(*args): 137 | raise NotImplementedError 138 | 139 | @classmethod 140 | def get_class_names(*args): 141 | raise NotImplementedError 142 | 143 | 144 | if __name__ == "__main__": 145 | data_setting = {'img_root': '', 146 | 'gt_root': '', 147 | 'train_source': '', 148 | 'eval_source': ''} 149 | bd = BaseDataset(data_setting, 'train', None) 150 | print(bd.get_class_names()) 151 | -------------------------------------------------------------------------------- /classification/utils/svhn_loader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import os 4 | import os.path 5 | import numpy as np 6 | 7 | 8 | class SVHN(data.Dataset): 9 | url = "" 10 | filename = "" 11 | file_md5 = "" 12 | 13 | split_list = { 14 | 'train': ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat", 15 | "train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"], 16 | 'test': ["http://ufldl.stanford.edu/housenumbers/test_32x32.mat", 17 | "test_32x32.mat", "eb5a983be6a315427106f1b164d9cef3"], 18 | 'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat", 19 | "extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"], 20 | 'train_and_extra': [ 21 | ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat", 22 | "train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"], 23 | ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat", 24 | "extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]]} 25 | 26 | def __init__(self, root, split='train', 27 | transform=None, target_transform=None, download=False): 28 | self.root = root 29 | self.transform = transform 30 | self.target_transform = target_transform 31 | self.split = split # training set or test set or extra set 32 | 33 | if self.split not in self.split_list: 34 | raise ValueError('Wrong split entered! 
Please use split="train" ' 35 | 'or split="extra" or split="test" ' 36 | 'or split="train_and_extra" ') 37 | 38 | if self.split == "train_and_extra": 39 | self.url = self.split_list[split][0][0] 40 | self.filename = self.split_list[split][0][1] 41 | self.file_md5 = self.split_list[split][0][2] 42 | else: 43 | self.url = self.split_list[split][0] 44 | self.filename = self.split_list[split][1] 45 | self.file_md5 = self.split_list[split][2] 46 | 47 | # import here rather than at top of file because this is 48 | # an optional dependency for torchvision 49 | import scipy.io as sio 50 | 51 | # reading(loading) mat file as array 52 | loaded_mat = sio.loadmat(os.path.join(root, self.filename)) 53 | 54 | if self.split == "test": 55 | self.data = loaded_mat['X'] 56 | self.targets = loaded_mat['y'] 57 | # Note label 10 == 0 so modulo operator required 58 | self.targets = (self.targets % 10).squeeze() # convert to zero-based indexing 59 | self.data = np.transpose(self.data, (3, 2, 0, 1)) 60 | else: 61 | self.data = loaded_mat['X'] 62 | self.targets = loaded_mat['y'] 63 | 64 | if self.split == "train_and_extra": 65 | extra_filename = self.split_list[split][1][1] 66 | loaded_mat = sio.loadmat(os.path.join(root, extra_filename)) 67 | self.data = np.concatenate([self.data, 68 | loaded_mat['X']], axis=3) 69 | self.targets = np.vstack((self.targets, 70 | loaded_mat['y'])) 71 | # Note label 10 == 0 so modulo operator required 72 | self.targets = (self.targets % 10).squeeze() # convert to zero-based indexing 73 | self.data = np.transpose(self.data, (3, 2, 0, 1)) 74 | 75 | def __getitem__(self, index): 76 | if self.split == "test": 77 | img, target = self.data[index], self.targets[index] 78 | else: 79 | img, target = self.data[index], self.targets[index] 80 | 81 | # doing this so that it is consistent with all other datasets 82 | # to return a PIL Image 83 | img = Image.fromarray(np.transpose(img, (1, 2, 0))) 84 | 85 | if self.transform is not None: 86 | img = self.transform(img) 87 | 88 | if self.target_transform is not None: 89 | target = self.target_transform(target) 90 | 91 | return img, target 92 | 93 | def __len__(self): 94 | if self.split == "test": 95 | return len(self.data) 96 | else: 97 | return len(self.data) 98 | 99 | def _check_integrity(self): 100 | root = self.root 101 | if self.split == "train_and_extra": 102 | md5 = self.split_list[self.split][0][2] 103 | fpath = os.path.join(root, self.filename) 104 | train_integrity = check_integrity(fpath, md5) 105 | extra_filename = self.split_list[self.split][1][1] 106 | md5 = self.split_list[self.split][1][2] 107 | fpath = os.path.join(root, extra_filename) 108 | return check_integrity(fpath, md5) and train_integrity 109 | else: 110 | md5 = self.split_list[self.split][2] 111 | fpath = os.path.join(root, self.filename) 112 | return check_integrity(fpath, md5) 113 | 114 | def download(self): 115 | if self.split == "train_and_extra": 116 | md5 = self.split_list[self.split][0][2] 117 | download_url(self.url, self.root, self.filename, md5) 118 | extra_filename = self.split_list[self.split][1][1] 119 | md5 = self.split_list[self.split][1][2] 120 | download_url(self.url, self.root, extra_filename, md5) 121 | else: 122 | md5 = self.split_list[self.split][2] 123 | download_url(self.url, self.root, self.filename, md5) 124 | -------------------------------------------------------------------------------- /segmentation/code/model/wide_network.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 
|
3 | from model.mynn import *
4 | from model.wide_resnet_base import WiderResNetA2
5 | 
6 | Norm2d = torch.nn.BatchNorm2d
7 | 
8 | 
9 | class _AtrousSpatialPyramidPoolingModule(nn.Module):
10 |     """
11 |     operations performed:
12 |       1x1 x depth
13 |       3x3 x depth dilation 6
14 |       3x3 x depth dilation 12
15 |       3x3 x depth dilation 18
16 |       image pooling
17 |       concatenate all together
18 |       Final 1x1 conv
19 |     """
20 | 
21 |     def __init__(self, in_dim, reduction_dim=256, output_stride=16, rates=(6, 12, 18)):
22 |         super(_AtrousSpatialPyramidPoolingModule, self).__init__()
23 | 
24 |         # Check if we are using distributed BN and use the nn from encoding.nn
25 |         # library rather than using standard pytorch.nn
26 | 
27 |         if output_stride == 8:
28 |             rates = [2 * r for r in rates]
29 |         elif output_stride == 16:
30 |             pass
31 |         else:
32 |             raise ValueError('output stride of {} not supported'.format(output_stride))
33 | 
34 |         self.features = []
35 |         # 1x1
36 |         self.features.append(
37 |             nn.Sequential(nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
38 |                           Norm2d(reduction_dim), nn.ReLU(inplace=True)))
39 |         # other rates
40 |         for r in rates:
41 |             self.features.append(nn.Sequential(
42 |                 nn.Conv2d(in_dim, reduction_dim, kernel_size=3,
43 |                           dilation=r, padding=r, bias=False),
44 |                 Norm2d(reduction_dim),
45 |                 nn.ReLU(inplace=True)
46 |             ))
47 |         self.features = torch.nn.ModuleList(self.features)
48 | 
49 |         # img level features
50 |         self.img_pooling = nn.AdaptiveAvgPool2d(1)
51 |         self.img_conv = nn.Sequential(
52 |             nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
53 |             Norm2d(reduction_dim), nn.ReLU(inplace=True))
54 | 
55 |     def forward(self, x):
56 |         x_size = x.size()
57 | 
58 |         img_features = self.img_pooling(x)
59 |         img_features = self.img_conv(img_features)
60 |         img_features = Upsample(img_features, x_size[2:])
61 |         out = img_features
62 | 
63 |         for f in self.features:
64 |             y = f(x)
65 |             out = torch.cat((out, y), 1)
66 |         return out
67 | 
68 | 
69 | class DeepWV3Plus(torch.nn.Module):
70 |     """
71 |     Wide_resnet version of DeepLabV3
72 |     mod1
73 |     pool2
74 |     mod2 str2
75 |     pool3
76 |     mod3-7
77 |     structure: [3, 3, 6, 3, 1, 1]
78 |     channels = [(128, 128), (256, 256), (512, 512), (512, 1024), (512, 1024, 2048),
79 |                 (1024, 2048, 4096)]
80 |     """
81 | 
82 |     def __init__(self, num_classes, trunk='WideResnet38'):
83 | 
84 |         super(DeepWV3Plus, self).__init__()
85 |         # logging.debug("Trunk: %s", trunk)
86 |         wide_resnet = WiderResNetA2(structure=[3, 3, 6, 3, 1, 1], classes=1000, dilation=True)
87 | 
88 |         self.gaussian_smoothing = transforms.GaussianBlur(7, sigma=1)
89 | 
90 |         self.mod1 = wide_resnet.mod1
91 |         self.mod2 = wide_resnet.mod2
92 |         self.mod3 = wide_resnet.mod3
93 |         self.mod4 = wide_resnet.mod4
94 |         self.mod5 = wide_resnet.mod5
95 |         self.mod6 = wide_resnet.mod6
96 |         self.mod7 = wide_resnet.mod7
97 |         self.pool2 = wide_resnet.pool2
98 |         self.pool3 = wide_resnet.pool3
99 |         del wide_resnet
100 |         self.aspp = _AtrousSpatialPyramidPoolingModule(4096, 256,
101 |                                                        output_stride=8)
102 | 
103 |         self.bot_fine = nn.Conv2d(128, 48, kernel_size=1, bias=False)
104 |         self.bot_aspp = nn.Conv2d(1280, 256, kernel_size=1, bias=False)
105 | 
106 |         self.final = nn.Sequential(
107 |             nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
108 |             Norm2d(256),
109 |             nn.ReLU(inplace=True),
110 |             nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
111 |             Norm2d(256),
112 |             nn.ReLU(inplace=True),
113 |             nn.Conv2d(256, num_classes, kernel_size=1, bias=False))
114 | 
115 |         initialize_weights(self.final)
116 | 
117 |     def compute_anomaly_score(self,
score):
118 |         score = score.squeeze()[:19]
119 |         anomaly_score = -(1. * torch.logsumexp(score[:19, :, :] / 1., dim=0))
120 | 
121 |         # regular gaussian smoothing
122 |         anomaly_score = anomaly_score.unsqueeze(0)
123 | 
124 |         # move the score to the cpu ("anomaly_score.cpu()") first if you hit the bug:
125 |         # CUDA error: CUBLAS_STATUS_ALLOC_FAILED
126 |         anomaly_score = self.gaussian_smoothing(anomaly_score)
127 |         anomaly_score = anomaly_score.squeeze(0)
128 | 
129 |         return anomaly_score
130 | 
131 |     def forward(self, inp, output_anomaly=False):
132 |         if len(inp.shape) == 3:
133 |             inp = inp.unsqueeze(0).cuda()
134 | 
135 |         assert len(inp.shape) == 4  # (B, C, H, W)
136 | 
137 |         x_size = inp.size()
138 |         x = self.mod1(inp)
139 |         m2 = self.mod2(self.pool2(x))
140 |         x = self.mod3(self.pool3(m2))
141 |         x = self.mod4(x)
142 |         x = self.mod5(x)
143 |         x = self.mod6(x)
144 |         x = self.mod7(x)
145 |         x = self.aspp(x)
146 |         dec0_up = self.bot_aspp(x)
147 |         dec0_fine = self.bot_fine(m2)
148 |         dec0_up = Upsample(dec0_up, m2.size()[2:])
149 |         dec0 = [dec0_fine, dec0_up]
150 |         dec0 = torch.cat(dec0, 1)
151 | 
152 |         dec1 = self.final(dec0)
153 |         out = Upsample(dec1, x_size[2:])
154 |         if output_anomaly is True:
155 |             anomaly_score = self.compute_anomaly_score(out)
156 |             return anomaly_score
157 |         else:
158 |             return out
159 | 
--------------------------------------------------------------------------------
/segmentation/code/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch.distributed as dist
3 | import torch.optim
4 | from config.config import config
5 | from dataset.data_loader import Fishyscapes, Cityscapes
6 | from dataset.data_loader import get_mix_loader
7 | from engine.engine import Engine
8 | from engine.evaluator import SlidingEval
9 | from engine.lr_policy import WarmUpPolyLR
10 | from engine.trainer import Trainer_balanced
11 | from losses import *
12 | from model.network import Network
13 | from utils.img_utils import *
14 | from utils.wandb_upload import *
15 | from valid import *
16 | 
17 | from utils.logger import *
18 | 
19 | warnings.filterwarnings('ignore', '.*imshow.*', )
20 | 
21 | 
22 | def declare_settings(config_file, logger, engine):
23 |     logger.critical("distributed data parallel training: {}".format(str("on" if engine.distributed is True
24 |                                                                          else "off")))
25 | 
26 |     logger.critical("gpus: {}, with batch_size[local]: {}".format(engine.world_size, config.batch_size))
27 | 
28 |     logger.critical("network architecture: {}, with ResNet {} backbone".format("deeplabv3+",
29 |                                                                                config_file['pretrained_weight_path']
30 |                                                                                .split('/')[-1].split('_')[0]))
31 |     logger.critical("learning rate: {} for both head and non-head params [world]".format(config_file['lr']))
32 | 
33 |     logger.info("image: {}x{} based on 1024x2048".format(config_file['image_height'],
34 |                                                          config_file['image_width']))
35 | 
36 |     logger.info("current batch: {} [world]".format(int(config_file['batch_size']) * engine.world_size))
37 | 
38 | 
39 | def main(gpu, ngpus_per_node, config, args):
40 |     args.local_rank = gpu
41 |     logger = logging.getLogger("pebal")
42 |     logger.propagate = False
43 |     engine = Engine(custom_arg=args, logger=logger,
44 |                     continue_state_object=config.pretrained_weight_path)
45 | 
46 |     if engine.local_rank <= 0:
47 |         declare_settings(config_file=config, logger=logger, engine=engine)
48 |         visual_tool = Tensorboard(config=config)
49 |     else:
50 |         visual_tool = None
51 | 
52 |     seed = config.seed
53 | 
54 |     if engine.distributed:
55 |         seed = seed + engine.local_rank
56 | 
57 | 
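# --- editor's note (example, not part of the repository) ---------------------
# compute_anomaly_score above is the free-energy OOD score: the negative
# logsumexp over the 19 inlier logits, followed by the same GaussianBlur(7,
# sigma=1) the model registers in __init__. The scoring step in isolation,
# with made-up logits:
#
#     import torch
#     from torchvision import transforms
#     logits = torch.randn(20, 64, 64)               # 19 classes + reserve channel
#     energy = -torch.logsumexp(logits[:19], dim=0)  # higher energy ~ more anomalous
#     energy = transforms.GaussianBlur(7, sigma=1)(energy.unsqueeze(0)).squeeze(0)
# ------------------------------------------------------------------------------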
torch.manual_seed(seed) 58 | if torch.cuda.is_available(): 59 | torch.cuda.manual_seed(seed) 60 | 61 | model = Network(config.num_classes, wide=True) 62 | gambler_loss = Gambler(reward=[4.5], pretrain=-1, device=engine.local_rank if engine.local_rank >= 0 else 0) 63 | 64 | optimizer = torch.optim.Adam(model.parameters(), lr=config.lr) 65 | testing_transform = Compose([ToTensor(), Normalize(config.image_mean, config.image_std)]) 66 | fishyscapes_ls = Fishyscapes(split='fs_lost_and_found', root=config.fishy_root_path, transform=testing_transform) 67 | fishyscapes_static = Fishyscapes(split='fs_static', root=config.fishy_root_path, transform=testing_transform) 68 | cityscapes = Cityscapes(root=config.city_root_path, split="val", transform=testing_transform) 69 | 70 | # config lr policy 71 | base_lr = config.lr 72 | total_iteration = config.nepochs * config.niters_per_epoch 73 | lr_policy = WarmUpPolyLR(base_lr, config.lr_power, total_iteration, config.niters_per_epoch * config.warm_up_epoch) 74 | trainer = Trainer_balanced(engine=engine, loss1=gambler_loss, loss2=balanced_energy_loss_with_smooth_sparsity, lr_scheduler=lr_policy, 75 | ckpt_dir=config.saved_dir, tensorboard=visual_tool, gamma1=config.gamma1, gamma2=config.gamma2, alpha=config.alpha) 76 | 77 | evaluator = SlidingEval(config, device=0 if engine.local_rank < 0 else engine.local_rank) 78 | 79 | if engine.distributed: 80 | torch.cuda.set_device(engine.local_rank) 81 | model.cuda(engine.local_rank) 82 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 83 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[engine.local_rank], 84 | find_unused_parameters=True) 85 | else: 86 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 87 | model = torch.nn.DataParallel(model, device_ids=engine.devices) 88 | model.to(device) 89 | 90 | # starting with the pre-trained weight from https://github.com/NVIDIA/semantic-segmentation/tree/sdcnet 91 | if engine.continue_state_object: 92 | engine.register_state(dataloader=None, model=model, optimizer=optimizer) 93 | engine.restore_checkpoint(extra_channel=True) 94 | # engine.load_pebal_ckpt(config.pebal_weight_path, model=model) 95 | 96 | logger.info('training begin...') 97 | 98 | for curr_epoch in range(engine.state.epoch, config.nepochs): 99 | 100 | train_loader, train_sampler, void_ind = get_mix_loader(engine=engine, augment=True, 101 | cs_root=config.city_root_path, 102 | coco_root=config.coco_root_path) 103 | 104 | engine.register_state(dataloader=train_loader, model=model, optimizer=optimizer) 105 | 106 | trainer.train(model=model, epoch=curr_epoch, train_sampler=train_sampler, train_loader=train_loader, 107 | optimizer=optimizer) 108 | 109 | if curr_epoch % config.eval_epoch == 0: 110 | if engine.local_rank <= 0: 111 | 112 | # valid_epoch(model=model, engine=engine, test_set=cityscapes, my_wandb=visual_tool, 113 | # evaluator=evaluator, logger=logger) 114 | 115 | valid_anomaly(model=model, epoch=curr_epoch, test_set=fishyscapes_ls, data_name='Fishyscapes_ls', 116 | my_wandb=visual_tool, logger=logger) 117 | 118 | valid_anomaly(model=model, epoch=curr_epoch, test_set=fishyscapes_static, 119 | data_name='Fishyscapes_static', my_wandb=visual_tool, logger=logger) 120 | 121 | if engine.distributed: 122 | dist.barrier() 123 | 124 | 125 | if __name__ == '__main__': 126 | parser = argparse.ArgumentParser(description='Anomaly Segmentation') 127 | parser.add_argument('--gpus', default=1, 128 | type=int, 129 | help="gpus in use") 130 | parser.add_argument('-l', 
'--local_rank', default=-1, 131 | type=int, 132 | help="distributed or not") 133 | parser.add_argument('-n', '--nodes', default=1, 134 | type=int, 135 | help="distributed or not") 136 | args = parser.parse_args() 137 | 138 | torch.backends.cudnn.benchmark = True 139 | 140 | args.world_size = args.nodes * args.gpus 141 | 142 | # we enforce the flag of ddp if gpus >= 2; 143 | args.ddp = True if args.world_size > 1 else False 144 | if args.gpus <= 1: 145 | main(-1, 1, config=config, args=args) 146 | else: 147 | torch.multiprocessing.spawn(main, nprocs=args.gpus, args=(args.gpus, config, args)) 148 | -------------------------------------------------------------------------------- /segmentation/code/utils/img_utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import numbers 3 | import random 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | import torchvision.transforms as trans 9 | 10 | 11 | class Compose(object): 12 | """Wraps together multiple image augmentations. 13 | Should also be used with only one augmentation, as it ensures, that input 14 | images are of type 'PIL.Image' and handles the augmentation process. 15 | Args: 16 | augmentations: List of augmentations to be applied. 17 | """ 18 | 19 | def __init__(self, augmentations): 20 | """Initializes the composer with the given augmentations.""" 21 | self.augmentations = augmentations 22 | 23 | def __call__(self, img, mask, *inputs): 24 | """Returns images that are augmented with the given augmentations.""" 25 | # img, mask = Image.fromarray(img, mode='RGB'), Image.fromarray(mask, mode='L') 26 | assert img.size == mask.size 27 | for a in self.augmentations: 28 | img, mask, inputs = a(img, mask, *inputs) 29 | return (img, mask, *inputs) 30 | 31 | 32 | class ToTensor(object): 33 | def __call__(self, image, mask, *inputs, **kwargs): 34 | t = trans.ToTensor() 35 | return (t(image), torch.tensor(np.array(mask, dtype=np.uint8), dtype=torch.long), 36 | tuple(torch.tensor(np.array(i, dtype=np.uint8), dtype=torch.long) for i in inputs)) 37 | 38 | def __repr__(self, *inputs, **kwargs): 39 | return self.__class__.__name__ + '()' 40 | 41 | 42 | class Normalize(object): 43 | def __init__(self, mean, std, *inputs, **kwargs): 44 | self.mean = mean 45 | self.std = std 46 | self.t = trans.Normalize(mean=self.mean, std=self.std) 47 | 48 | def __call__(self, tensor, mask, *inputs, **kwargs): 49 | return self.t(tensor), mask, tuple(i for i in inputs) 50 | 51 | 52 | def get_2dshape(shape, *, zero=True): 53 | if not isinstance(shape, collections.Iterable): 54 | shape = int(shape) 55 | shape = (shape, shape) 56 | else: 57 | h, w = map(int, shape) 58 | shape = (h, w) 59 | if zero: 60 | minv = 0 61 | else: 62 | minv = 1 63 | 64 | assert min(shape) >= minv, 'invalid shape: {}'.format(shape) 65 | return shape 66 | 67 | 68 | def random_crop_pad_to_shape(img, crop_pos, crop_size, pad_label_value): 69 | h, w = img.shape[:2] 70 | start_crop_h, start_crop_w = crop_pos 71 | assert ((start_crop_h < h) and (start_crop_h >= 0)) 72 | assert ((start_crop_w < w) and (start_crop_w >= 0)) 73 | 74 | crop_size = get_2dshape(crop_size) 75 | crop_h, crop_w = crop_size 76 | 77 | img_crop = img[start_crop_h:start_crop_h + crop_h, 78 | start_crop_w:start_crop_w + crop_w, ...] 
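# --- editor's note (example, not part of the repository) ---------------------
# Compose/ToTensor/Normalize above are joint image+mask transforms: each takes
# (img, mask, *extras) and returns the same tuple shape, which is how main.py
# builds its testing_transform. A minimal usage sketch with dummy PIL inputs
# (the mean/std values are the usual ImageNet ones, purely for illustration):
#
#     from PIL import Image
#     t = Compose([ToTensor(), Normalize(mean=(0.485, 0.456, 0.406),
#                                        std=(0.229, 0.224, 0.225))])
#     img, mask = Image.new('RGB', (32, 32)), Image.new('L', (32, 32))
#     img_t, mask_t = t(img, mask)  # float tensor (3, 32, 32), long tensor (32, 32)
# ------------------------------------------------------------------------------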
79 | 80 | img_, margin = pad_image_to_shape(img_crop, crop_size, cv2.BORDER_CONSTANT, 81 | pad_label_value) 82 | 83 | return img_, margin 84 | 85 | 86 | def generate_random_crop_pos(ori_size, crop_size): 87 | ori_size = get_2dshape(ori_size) 88 | h, w = ori_size 89 | 90 | crop_size = get_2dshape(crop_size) 91 | crop_h, crop_w = crop_size 92 | 93 | pos_h, pos_w = 0, 0 94 | 95 | if h > crop_h: 96 | pos_h = random.randint(0, h - crop_h + 1) 97 | 98 | if w > crop_w: 99 | pos_w = random.randint(0, w - crop_w + 1) 100 | 101 | return pos_h, pos_w 102 | 103 | 104 | def pad_image_to_shape(img, shape, border_mode, value): 105 | margin = np.zeros(4, np.uint32) 106 | shape = get_2dshape(shape) 107 | pad_height = shape[0] - img.shape[0] if shape[0] - img.shape[0] > 0 else 0 108 | pad_width = shape[1] - img.shape[1] if shape[1] - img.shape[1] > 0 else 0 109 | 110 | margin[0] = pad_height // 2 111 | margin[1] = pad_height // 2 + pad_height % 2 112 | margin[2] = pad_width // 2 113 | margin[3] = pad_width // 2 + pad_width % 2 114 | 115 | img = cv2.copyMakeBorder(img, margin[0], margin[1], margin[2], margin[3], 116 | border_mode, value=value) 117 | 118 | return img, margin 119 | 120 | 121 | def pad_image_size_to_multiples_of(img, multiple, pad_value): 122 | h, w = img.shape[:2] 123 | d = multiple 124 | 125 | def canonicalize(s): 126 | v = s // d 127 | return (v + (v * d != s)) * d 128 | 129 | th, tw = map(canonicalize, (h, w)) 130 | 131 | return pad_image_to_shape(img, (th, tw), cv2.BORDER_CONSTANT, pad_value) 132 | 133 | 134 | def resize_ensure_shortest_edge(img, edge_length, 135 | interpolation_mode=cv2.INTER_LINEAR): 136 | assert isinstance(edge_length, int) and edge_length > 0, edge_length 137 | h, w = img.shape[:2] 138 | if h < w: 139 | ratio = float(edge_length) / h 140 | th, tw = edge_length, max(1, int(ratio * w)) 141 | else: 142 | ratio = float(edge_length) / w 143 | th, tw = max(1, int(ratio * h)), edge_length 144 | img = cv2.resize(img, (tw, th), interpolation_mode) 145 | 146 | return img 147 | 148 | 149 | def random_scale(img, gt, scales): 150 | scale = random.choice(scales) 151 | sh = int(img.shape[0] * scale) 152 | sw = int(img.shape[1] * scale) 153 | img = cv2.resize(img, (sw, sh), interpolation=cv2.INTER_LINEAR) 154 | gt = cv2.resize(gt, (sw, sh), interpolation=cv2.INTER_NEAREST) 155 | 156 | return img, gt, scale 157 | 158 | 159 | def random_scale_with_length(img, gt, length): 160 | size = random.choice(length) 161 | sh = size 162 | sw = size 163 | img = cv2.resize(img, (sw, sh), interpolation=cv2.INTER_LINEAR) 164 | gt = cv2.resize(gt, (sw, sh), interpolation=cv2.INTER_NEAREST) 165 | 166 | return img, gt, size 167 | 168 | 169 | def random_mirror(img, gt): 170 | if random.random() >= 0.5: 171 | img = cv2.flip(img, 1) 172 | gt = cv2.flip(gt, 1) 173 | 174 | return img, gt, 175 | 176 | 177 | def random_rotation(img, gt): 178 | angle = random.random() * 20 - 10 179 | h, w = img.shape[:2] 180 | rotation_matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1) 181 | img = cv2.warpAffine(img, rotation_matrix, (w, h), flags=cv2.INTER_LINEAR) 182 | gt = cv2.warpAffine(gt, rotation_matrix, (w, h), flags=cv2.INTER_NEAREST) 183 | 184 | return img, gt 185 | 186 | 187 | def random_gaussian_blur(img): 188 | gauss_size = random.choice([1, 3, 5, 7]) 189 | if gauss_size > 1: 190 | # do the gaussian blur 191 | img = cv2.GaussianBlur(img, (gauss_size, gauss_size), 0) 192 | 193 | return img 194 | 195 | 196 | def center_crop(img, shape): 197 | h, w = shape[0], shape[1] 198 | y = (img.shape[0] - h) // 2 199 | x 
= (img.shape[1] - w) // 2 200 | return img[y:y + h, x:x + w] 201 | 202 | 203 | def random_crop(img, gt, size): 204 | if isinstance(size, numbers.Number): 205 | size = (int(size), int(size)) 206 | else: 207 | size = size 208 | 209 | h, w = img.shape[:2] 210 | crop_h, crop_w = size[0], size[1] 211 | 212 | if h > crop_h: 213 | x = random.randint(0, h - crop_h + 1) 214 | img = img[x:x + crop_h, :, :] 215 | gt = gt[x:x + crop_h, :] 216 | 217 | if w > crop_w: 218 | x = random.randint(0, w - crop_w + 1) 219 | img = img[:, x:x + crop_w, :] 220 | gt = gt[:, x:x + crop_w] 221 | 222 | return img, gt 223 | 224 | 225 | def normalize(img, mean, std): 226 | # pytorch pretrained model need the input range: 0-1 227 | img = img.astype(np.float32) / 255.0 228 | img = img - mean 229 | img = img / std 230 | 231 | return img 232 | -------------------------------------------------------------------------------- /segmentation/code/engine/evaluator.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import time 3 | 4 | import cv2 5 | import numpy 6 | import torch 7 | 8 | from utils.metric import compute_score 9 | 10 | 11 | class SlidingEval(torch.nn.Module): 12 | def __init__(self, config, device): 13 | super(SlidingEval, self).__init__() 14 | self.config = config 15 | self.device = device 16 | 17 | # slide the window to evaluate the image 18 | def forward(self, img, model, device=None): 19 | ori_rows, ori_cols, c = img.shape 20 | # fix to be 19, for inlier testing. 21 | num_class = 19 22 | processed_pred = numpy.zeros((ori_rows, ori_cols, num_class)) 23 | 24 | # it is single scale 25 | multi_scales = self.config.eval_scale_array 26 | for s in multi_scales: 27 | img_scale = cv2.resize(img, None, fx=s, fy=s, interpolation=cv2.INTER_LINEAR) 28 | new_rows, new_cols, _ = img_scale.shape 29 | processed_pred += self.scale_process(img_scale, (ori_rows, ori_cols), 30 | self.config.eval_crop_size, self.config.eval_stride_rate, 31 | model, device) 32 | 33 | pred = processed_pred.argmax(2) 34 | return pred 35 | 36 | def process_image(self, img, crop_size=None): 37 | p_img = img 38 | 39 | if img.shape[2] < 3: 40 | im_b = p_img 41 | im_g = p_img 42 | im_r = p_img 43 | p_img = numpy.concatenate((im_b, im_g, im_r), axis=2) 44 | 45 | # p_img = self.normalize(p_img, self.config.image_mean, self.config.image_std) 46 | 47 | if crop_size is not None: 48 | p_img, margin = self.pad_image_to_shape(p_img, crop_size, 49 | cv2.BORDER_CONSTANT, value=0) 50 | p_img = p_img.transpose(2, 0, 1) 51 | 52 | return p_img, margin 53 | 54 | p_img = p_img.transpose(2, 0, 1) 55 | 56 | return p_img 57 | 58 | def val_func_process(self, input_data, val_func, device=None): 59 | input_data = numpy.ascontiguousarray(input_data[None, :, :, :], dtype=numpy.float32) 60 | input_data = torch.FloatTensor(input_data).cuda(device) 61 | 62 | with torch.cuda.device(input_data.get_device()): 63 | val_func.eval() 64 | val_func.to(input_data.get_device()) 65 | with torch.no_grad(): 66 | # modify for 19 classes 67 | start_time = time.time() 68 | score = val_func.module(input_data, output_anomaly=False) 69 | end_time = time.time() 70 | time_len = end_time - start_time 71 | # remove last reservation channel for OoD 72 | score = score.squeeze()[:19] 73 | 74 | # if self.config.eval_flip: 75 | # input_data = input_data.flip(-1) 76 | # score_flip = val_func(input_data) 77 | # score_flip = score_flip[0] 78 | # score += score_flip.flip(-1) 79 | 80 | return score, time_len 81 | 82 | def scale_process(self, img, ori_shape, 
crop_size, stride_rate, model, device=None): 83 | new_rows, new_cols, c = img.shape 84 | long_size = new_cols if new_cols > new_rows else new_rows 85 | 86 | if isinstance(crop_size, int): 87 | crop_size = (crop_size, crop_size) 88 | 89 | # remove last reservation channel for OoD 90 | class_num = 19 91 | if long_size <= min(crop_size[0], crop_size[1]): 92 | input_data, margin = self.process_image(img, crop_size) # pad image 93 | score, _ = self.val_func_process(input_data, model, device) 94 | score = score[:class_num, margin[0]:(score.shape[1] - margin[1]), 95 | margin[2]:(score.shape[2] - margin[3])] 96 | else: 97 | stride_0 = int(numpy.ceil(crop_size[0] * stride_rate)) 98 | stride_1 = int(numpy.ceil(crop_size[1] * stride_rate)) 99 | img_pad, margin = self.pad_image_to_shape(img, crop_size, cv2.BORDER_CONSTANT, value=0) 100 | pad_rows = img_pad.shape[0] 101 | pad_cols = img_pad.shape[1] 102 | r_grid = int(numpy.ceil((pad_rows - crop_size[0]) / stride_0)) + 1 103 | c_grid = int(numpy.ceil((pad_cols - crop_size[1]) / stride_1)) + 1 104 | data_scale = torch.zeros(class_num, pad_rows, pad_cols).cuda( 105 | device) 106 | count_scale = torch.zeros(class_num, pad_rows, pad_cols).cuda( 107 | device) 108 | 109 | for grid_yidx in range(r_grid): 110 | for grid_xidx in range(c_grid): 111 | s_x = grid_xidx * stride_1 112 | s_y = grid_yidx * stride_0 113 | e_x = min(s_x + crop_size[1], pad_cols) 114 | e_y = min(s_y + crop_size[0], pad_rows) 115 | s_x = e_x - crop_size[1] 116 | s_y = e_y - crop_size[0] 117 | img_sub = img_pad[s_y:e_y, s_x: e_x, :] 118 | count_scale[:, s_y: e_y, s_x: e_x] += 1 119 | 120 | input_data, tmargin = self.process_image(img_sub, crop_size) 121 | temp_score, _ = self.val_func_process(input_data, model, device) 122 | temp_score = temp_score[:class_num, 123 | tmargin[0]:(temp_score.shape[1] - tmargin[1]), 124 | tmargin[2]:(temp_score.shape[2] - tmargin[3])] 125 | 126 | data_scale[:, s_y: e_y, s_x: e_x] += temp_score 127 | # score = data_scale / count_scale 128 | score = data_scale 129 | score = score[:, margin[0]:(score.shape[1] - margin[1]), 130 | margin[2]:(score.shape[2] - margin[3])] 131 | 132 | score = score.permute(1, 2, 0) 133 | data_output = cv2.resize(score.cpu().numpy(), (ori_shape[1], ori_shape[0]), 134 | interpolation=cv2.INTER_LINEAR) 135 | 136 | return data_output 137 | 138 | def compute_metric(self, results): 139 | hist = numpy.zeros((19, 19)) 140 | correct = 0 141 | labeled = 0 142 | count = 0 143 | for d in results: 144 | hist += d['hist'] 145 | correct += d['correct'] 146 | labeled += d['labeled'] 147 | count += 1 148 | 149 | iu, mean_IU, _, mean_pixel_acc = compute_score(hist, correct, labeled) 150 | return mean_IU, mean_pixel_acc 151 | 152 | def pad_image_to_shape(self, img, shape, border_mode, value): 153 | margin = numpy.zeros(4, numpy.uint32) 154 | shape = self.get_2dshape(shape) 155 | pad_height = shape[0] - img.shape[0] if shape[0] - img.shape[0] > 0 else 0 156 | pad_width = shape[1] - img.shape[1] if shape[1] - img.shape[1] > 0 else 0 157 | 158 | margin[0] = pad_height // 2 159 | margin[1] = pad_height // 2 + pad_height % 2 160 | margin[2] = pad_width // 2 161 | margin[3] = pad_width // 2 + pad_width % 2 162 | 163 | img = cv2.copyMakeBorder(img, margin[0], margin[1], margin[2], margin[3], 164 | border_mode, value=value) 165 | 166 | return img, margin 167 | 168 | def get_2dshape(self, shape, *, zero=True): 169 | if not isinstance(shape, collections.Iterable): 170 | shape = int(shape) 171 | shape = (shape, shape) 172 | else: 173 | h, w = map(int, shape) 174 | 
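# --- editor's note (example, not part of the repository) ---------------------
# scale_process above tiles the padded image with eval_crop_size windows at a
# stride of ceil(crop * eval_stride_rate), clamping the last window to the
# border and accumulating per-window scores into data_scale. The tiling
# arithmetic in isolation (the crop and stride_rate values are illustrative):
#
#     import math
#     def tile_coords(pad_rows, pad_cols, crop=(512, 512), stride_rate=2 / 3):
#         s0 = math.ceil(crop[0] * stride_rate)
#         s1 = math.ceil(crop[1] * stride_rate)
#         for gy in range(math.ceil((pad_rows - crop[0]) / s0) + 1):
#             for gx in range(math.ceil((pad_cols - crop[1]) / s1) + 1):
#                 e_y = min(gy * s0 + crop[0], pad_rows)
#                 e_x = min(gx * s1 + crop[1], pad_cols)
#                 yield e_y - crop[0], e_x - crop[1], e_y, e_x  # s_y, s_x, e_y, e_x
# ------------------------------------------------------------------------------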
shape = (h, w) 175 | if zero: 176 | minv = 0 177 | else: 178 | minv = 1 179 | 180 | assert min(shape) >= minv, 'invalid shape: {}'.format(shape) 181 | return shape 182 | 183 | @staticmethod 184 | def normalize(img, mean, std): 185 | img = img.astype(numpy.float32) / 255.0 186 | img = img - mean 187 | img = img / std 188 | return img 189 | -------------------------------------------------------------------------------- /segmentation/code/model/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from utils.pyt_utils import load_model 4 | 5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 6 | 'resnet152'] 7 | 8 | 9 | def conv3x3(in_planes, out_planes, stride=1): 10 | """3x3 convolution with padding""" 11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 12 | padding=1, bias=False) 13 | 14 | 15 | class BasicBlock(nn.Module): 16 | expansion = 1 17 | 18 | def __init__(self, inplanes, planes, stride=1, norm_layer=None, 19 | bn_eps=1e-5, bn_momentum=0.1, downsample=None, inplace=True): 20 | super(BasicBlock, self).__init__() 21 | self.conv1 = conv3x3(inplanes, planes, stride) 22 | self.bn1 = norm_layer(planes, eps=bn_eps, momentum=bn_momentum) 23 | self.relu = nn.ReLU(inplace=inplace) 24 | self.relu_inplace = nn.ReLU(inplace=True) 25 | self.conv2 = conv3x3(planes, planes) 26 | self.bn2 = norm_layer(planes, eps=bn_eps, momentum=bn_momentum) 27 | self.downsample = downsample 28 | self.stride = stride 29 | self.inplace = inplace 30 | 31 | def forward(self, x): 32 | residual = x 33 | 34 | out = self.conv1(x) 35 | out = self.bn1(out) 36 | out = self.relu(out) 37 | 38 | out = self.conv2(out) 39 | out = self.bn2(out) 40 | 41 | if self.downsample is not None: 42 | residual = self.downsample(x) 43 | 44 | if self.inplace: 45 | out += residual 46 | else: 47 | out = out + residual 48 | 49 | out = self.relu_inplace(out) 50 | 51 | return out 52 | 53 | 54 | class Bottleneck(nn.Module): 55 | expansion = 4 56 | 57 | def __init__(self, inplanes, planes, stride=1, 58 | norm_layer=None, bn_eps=1e-5, bn_momentum=0.1, 59 | downsample=None, inplace=True): 60 | super(Bottleneck, self).__init__() 61 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 62 | self.bn1 = norm_layer(planes, eps=bn_eps, momentum=bn_momentum) 63 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 64 | padding=1, bias=False) 65 | self.bn2 = norm_layer(planes, eps=bn_eps, momentum=bn_momentum) 66 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, 67 | bias=False) 68 | self.bn3 = norm_layer(planes * self.expansion, eps=bn_eps, 69 | momentum=bn_momentum) 70 | self.relu = nn.ReLU(inplace=inplace) 71 | self.relu_inplace = nn.ReLU(inplace=True) 72 | self.downsample = downsample 73 | self.stride = stride 74 | self.inplace = inplace 75 | 76 | def forward(self, x): 77 | residual = x 78 | 79 | out = self.conv1(x) 80 | out = self.bn1(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv2(out) 84 | out = self.bn2(out) 85 | out = self.relu(out) 86 | 87 | out = self.conv3(out) 88 | out = self.bn3(out) 89 | 90 | if self.downsample is not None: 91 | residual = self.downsample(x) 92 | 93 | if self.inplace: 94 | out += residual 95 | else: 96 | out = out + residual 97 | out = self.relu_inplace(out) 98 | 99 | return out 100 | 101 | 102 | class ResNet(nn.Module): 103 | 104 | def __init__(self, block, layers, norm_layer=nn.BatchNorm2d, bn_eps=1e-5, 105 | bn_momentum=0.1, deep_stem=False, 
stem_width=32, inplace=True): 106 | self.inplanes = stem_width * 2 if deep_stem else 64 107 | super(ResNet, self).__init__() 108 | if deep_stem: 109 | self.conv1 = nn.Sequential( 110 | nn.Conv2d(3, stem_width, kernel_size=3, stride=2, padding=1, 111 | bias=False), 112 | norm_layer(stem_width, eps=bn_eps, momentum=bn_momentum), 113 | nn.ReLU(inplace=inplace), 114 | nn.Conv2d(stem_width, stem_width, kernel_size=3, stride=1, 115 | padding=1, 116 | bias=False), 117 | norm_layer(stem_width, eps=bn_eps, momentum=bn_momentum), 118 | nn.ReLU(inplace=inplace), 119 | nn.Conv2d(stem_width, stem_width * 2, kernel_size=3, stride=1, 120 | padding=1, 121 | bias=False), 122 | ) 123 | else: 124 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 125 | bias=False) 126 | 127 | self.bn1 = norm_layer(stem_width * 2 if deep_stem else 64, eps=bn_eps, 128 | momentum=bn_momentum) 129 | self.relu = nn.ReLU(inplace=inplace) 130 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 131 | self.layer1 = self._make_layer(block, norm_layer, 64, layers[0], 132 | inplace, 133 | bn_eps=bn_eps, bn_momentum=bn_momentum) 134 | self.layer2 = self._make_layer(block, norm_layer, 128, layers[1], 135 | inplace, stride=2, 136 | bn_eps=bn_eps, bn_momentum=bn_momentum) 137 | self.layer3 = self._make_layer(block, norm_layer, 256, layers[2], 138 | inplace, stride=2, 139 | bn_eps=bn_eps, bn_momentum=bn_momentum) 140 | self.layer4 = self._make_layer(block, norm_layer, 512, layers[3], 141 | inplace, stride=2, 142 | bn_eps=bn_eps, bn_momentum=bn_momentum) 143 | 144 | def _make_layer(self, block, norm_layer, planes, blocks, inplace=True, 145 | stride=1, bn_eps=1e-5, bn_momentum=0.1): 146 | downsample = None 147 | if stride != 1 or self.inplanes != planes * block.expansion: 148 | downsample = nn.Sequential( 149 | nn.Conv2d(self.inplanes, planes * block.expansion, 150 | kernel_size=1, stride=stride, bias=False), 151 | norm_layer(planes * block.expansion, eps=bn_eps, 152 | momentum=bn_momentum), 153 | ) 154 | 155 | layers = [] 156 | layers.append(block(self.inplanes, planes, stride, norm_layer, bn_eps, 157 | bn_momentum, downsample, inplace)) 158 | self.inplanes = planes * block.expansion 159 | for i in range(1, blocks): 160 | layers.append(block(self.inplanes, planes, 161 | norm_layer=norm_layer, bn_eps=bn_eps, 162 | bn_momentum=bn_momentum, inplace=inplace)) 163 | 164 | return nn.Sequential(*layers) 165 | 166 | def forward(self, x): 167 | x = self.conv1(x) 168 | x = self.bn1(x) 169 | x = self.relu(x) 170 | x = self.maxpool(x) 171 | 172 | blocks = [] 173 | x = self.layer1(x); 174 | blocks.append(x) 175 | x = self.layer2(x); 176 | blocks.append(x) 177 | x = self.layer3(x); 178 | blocks.append(x) 179 | x = self.layer4(x); 180 | blocks.append(x) 181 | 182 | return blocks 183 | 184 | 185 | def resnet18(pretrained_model=None, **kwargs): 186 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 187 | 188 | if pretrained_model is not None: 189 | model = load_model(model, pretrained_model) 190 | return model 191 | 192 | 193 | def resnet34(pretrained_model=None, **kwargs): 194 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 195 | 196 | if pretrained_model is not None: 197 | model = load_model(model, pretrained_model) 198 | return model 199 | 200 | 201 | def resnet50(pretrained_model=None, **kwargs): 202 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 203 | 204 | if pretrained_model is not None: 205 | model = load_model(model, pretrained_model) 206 | return model 207 | 208 | 209 | def resnet101(pretrained_model=None, 
**kwargs):
210 |     model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
211 | 
212 |     if pretrained_model is not None:
213 |         model = load_model(model, pretrained_model)
214 |     return model
215 | 
216 | 
217 | def resnet152(pretrained_model=None, **kwargs):
218 |     model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
219 | 
220 |     if pretrained_model is not None:
221 |         model = load_model(model, pretrained_model)
222 |     return model
223 | 
--------------------------------------------------------------------------------
/segmentation/code/engine/engine.py:
--------------------------------------------------------------------------------
1 | 
2 | import argparse
3 | import os
4 | import shutil
5 | import time
6 | 
7 | import torch
8 | import torch.distributed as dist
9 | 
10 | from utils.pyt_utils import load_model, link_file, ensure_dir
11 | from utils.pyt_utils import on_load_checkpoint
12 | 
13 | 
14 | class State(object):
15 |     def __init__(self):
16 |         self.epoch = 0
17 |         self.iteration = 0
18 |         self.dataloader = None
19 |         self.model = None
20 |         self.optimizer = None
21 | 
22 |     def register(self, **kwargs):
23 |         for k, v in kwargs.items():
24 |             setattr(self, k, v)
25 | 
26 | 
27 | class Engine(object):
28 |     def __init__(self, custom_arg, logger, continue_state_object):
29 |         assert continue_state_object is not None, "our proj. only works upon the pretrained weight"
30 |         self.logger = logger
31 |         self.state = State()
32 |         self.devices = None
33 |         self.distributed = False
34 |         self.parser = argparse.ArgumentParser()
35 |         self.inject_default_parser()
36 |         self.args = custom_arg
37 |         self.continue_state_object = continue_state_object
38 | 
39 |         if 'WORLD_SIZE' in os.environ:
40 |             self.distributed = int(os.environ['WORLD_SIZE']) >= 1
41 |         else:
42 |             self.distributed = self.args.ddp
43 |         self.local_rank = 0 if self.args.local_rank < 0 else self.args.local_rank
44 |         self.gpus = self.args.gpus
45 |         self.world_size = self.args.world_size
46 | 
47 |         if self.distributed:
48 |             os.environ['MASTER_ADDR'] = '127.0.0.4'
49 |             os.environ['MASTER_PORT'] = '9904'
50 |             dist.init_process_group(backend="nccl",
51 |                                     init_method='env://',
52 |                                     rank=self.local_rank,
53 |                                     world_size=self.world_size)
54 |         else:
55 |             self.world_size = 1
56 | 
57 |     def inject_default_parser(self):
58 |         p = self.parser
59 |         p.add_argument('-d', '--devices', default='0',
60 |                        help='set data parallel training')
61 |         p.add_argument('-c', '--continue', type=str,
62 |                        dest="continue_fpath",
63 |                        help='continue from one certain checkpoint')
64 |         p.add_argument('-p', '--port', type=str,
65 |                        default='16001',
66 |                        dest="port",
67 |                        help='port for init_process_group')
68 |         p.add_argument('--debug', default=0, type=int,
69 |                        help='whether to use the debug mode')
70 |         p.add_argument('-e', '--epochs', default='last', type=str)
71 | 
72 |         p.add_argument('-v', '--verbose', default=False, action='store_true')
73 |         p.add_argument('--show_image', '-s', default=True,
74 |                        action='store_true')
75 |         p.add_argument('--save_path', default=None)
76 | 
77 |     def register_state(self, **kwargs):
78 |         self.state.register(**kwargs)
79 | 
80 |     def update_iteration(self, epoch, iteration):
81 |         self.state.epoch = epoch
82 |         self.state.iteration = iteration
83 | 
84 |     def save_checkpoint(self, path):
85 |         self.logger.info("Saving checkpoint to file {}".format(path))
86 |         t_start = time.time()
87 | 
88 |         state_dict = {}
89 | 
90 |         from collections import OrderedDict
91 |         new_state_dict = OrderedDict()
92 |         for k, v in self.state.model.state_dict().items():
93 |             key = k
94 |             if k.split('.')[0]
== 'module': 95 | key = k[7:] 96 | new_state_dict[key] = v 97 | state_dict['model'] = new_state_dict 98 | if self.state.optimizer is not None: 99 | state_dict['optimizer'] = self.state.optimizer.state_dict() 100 | state_dict['epoch'] = self.state.epoch 101 | state_dict['iteration'] = self.state.iteration 102 | 103 | t_iobegin = time.time() 104 | torch.save(state_dict, path) 105 | del state_dict 106 | del new_state_dict 107 | t_end = time.time() 108 | self.logger.info( 109 | "Save checkpoint to file {}, " 110 | "Time usage:\n\tprepare snapshot: {}, IO: {}".format( 111 | path, t_iobegin - t_start, t_end - t_iobegin)) 112 | 113 | def link_tb(self, source, target): 114 | ensure_dir(source) 115 | ensure_dir(target) 116 | link_file(source, target) 117 | 118 | def save_and_link_checkpoint(self, snapshot_dir, log_dir=None, log_dir_link=None, m_iou=None, name=None): 119 | ensure_dir(snapshot_dir) 120 | if name is None: 121 | current_epoch_checkpoint = os.path.join(snapshot_dir, 'epoch-{}-iou-{}.pth'.format( 122 | self.state.epoch, m_iou)) 123 | else: 124 | current_epoch_checkpoint = os.path.join(snapshot_dir, '{}.pth'.format( 125 | name)) 126 | 127 | 128 | if os.path.exists(current_epoch_checkpoint): 129 | os.remove(current_epoch_checkpoint) 130 | 131 | self.save_checkpoint(current_epoch_checkpoint) 132 | last_epoch_checkpoint = os.path.join(snapshot_dir, 'epoch-last.pth') 133 | # link_file(current_epoch_checkpoint, last_epoch_checkpoint) 134 | try: 135 | shutil.copy(current_epoch_checkpoint, last_epoch_checkpoint) 136 | except: 137 | pass 138 | 139 | def restore_checkpoint(self, extra_channel=False, eval=False): 140 | t_start = time.time() 141 | continue_state_object = self.continue_state_object 142 | self.logger.critical('restoring ckpt from pretrained file {}.'.format(continue_state_object)) 143 | 144 | if self.distributed: 145 | tmp = torch.load(continue_state_object, 146 | map_location=lambda storage, loc: storage.cuda(self.local_rank)) 147 | else: 148 | tmp = torch.load(continue_state_object) 149 | 150 | t_ioend = time.time() 151 | if eval: 152 | self.state.model = on_load_checkpoint(model=self.state.model, checkpoint=tmp['model']) 153 | else: 154 | self.state.model = load_model(self.state.model, tmp['state_dict'], True, strict=True, 155 | extra_channel=extra_channel) 156 | 157 | self.state.epoch = 0 # tmp['epoch'] + 1 158 | self.state.iteration = 0 # tmp['iteration'] 159 | del tmp 160 | t_end = time.time() 161 | self.logger.info("Load checkpoint from file {}, " 162 | "Time usage:\n\tIO: {}, restore snapshot: {}".format(self.continue_state_object, 163 | t_ioend - t_start, t_end - t_ioend)) 164 | def load_pebal_ckpt_dist(self, ckpt_name, model): 165 | if self.distributed: 166 | tmp = torch.load(ckpt_name, 167 | map_location=lambda storage, loc: storage.cuda(self.local_rank)) 168 | else: 169 | tmp = torch.load(ckpt_name) 170 | self.logger.critical('restoring pebal ckpt from {}'.format(ckpt_name)) 171 | state_dict = tmp 172 | if 'model' in state_dict.keys(): 173 | state_dict = state_dict['model'] 174 | 175 | from collections import OrderedDict 176 | new_state_dict = OrderedDict() 177 | 178 | for k, v in state_dict.items(): 179 | name = k 180 | new_state_dict[name] = v 181 | state_dict = new_state_dict 182 | model.module.load_state_dict(state_dict, strict=True) 183 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 184 | model.to(device) 185 | return model 186 | 187 | def load_pebal_ckpt(self, ckpt_name, model): 188 | tmp = torch.load(ckpt_name) 189 | 
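# --- editor's note (example, not part of the repository) ---------------------
# save_checkpoint above strips the 'module.' prefix that DataParallel /
# DistributedDataParallel add to parameter names, so the saved weights load
# into an unwrapped model; the load_pebal_ckpt methods instead load into
# model.module. The stripping idea as a small standalone helper:
#
#     from collections import OrderedDict
#     def strip_module_prefix(state_dict):
#         return OrderedDict((k[7:] if k.startswith('module.') else k, v)
#                            for k, v in state_dict.items())
# ------------------------------------------------------------------------------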
self.logger.critical('restoring pebal ckpt from {}'.format(ckpt_name))
190 |         state_dict = tmp
191 |         if 'model' in state_dict.keys():
192 |             state_dict = state_dict['model']
193 | 
194 |         from collections import OrderedDict
195 |         new_state_dict = OrderedDict()
196 | 
197 |         for k, v in state_dict.items():
198 |             name = k
199 |             new_state_dict[name] = v
200 |         state_dict = new_state_dict
201 |         model.module.load_state_dict(state_dict, strict=True)
202 |         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
203 |         model.to(device)
204 |         return model
205 | 
206 |     def __enter__(self):
207 |         return self
208 | 
209 |     def __exit__(self, type, value, tb):
210 |         torch.cuda.empty_cache()
211 |         if type is not None:
212 |             self.logger.warning(
213 |                 "An exception occurred during Engine initialization, "
214 |                 "giving up the running process")
215 |         return False
216 | 
--------------------------------------------------------------------------------
/classification/utils/display_results.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import sklearn.metrics as sk
3 | 
4 | recall_level_default = 0.95
5 | 
6 | 
7 | def stable_cumsum(arr, rtol=1e-05, atol=1e-08):
8 |     """Use high precision for cumsum and check that final value matches sum
9 |     Parameters
10 |     ----------
11 |     arr : array-like
12 |         To be cumulatively summed as flat
13 |     rtol : float
14 |         Relative tolerance, see ``np.allclose``
15 |     atol : float
16 |         Absolute tolerance, see ``np.allclose``
17 |     """
18 |     out = np.cumsum(arr, dtype=np.float64)
19 |     expected = np.sum(arr, dtype=np.float64)
20 |     if not np.allclose(out[-1], expected, rtol=rtol, atol=atol):
21 |         raise RuntimeError('cumsum was found to be unstable: '
22 |                            'its last element does not correspond to sum')
23 |     return out
24 | 
25 | 
26 | def fpr_and_fdr_at_recall(y_true, y_score, recall_level=recall_level_default, pos_label=None):
27 |     classes = np.unique(y_true)
28 |     if (pos_label is None and
29 |             not (np.array_equal(classes, [0, 1]) or
30 |                  np.array_equal(classes, [-1, 1]) or
31 |                  np.array_equal(classes, [0]) or
32 |                  np.array_equal(classes, [-1]) or
33 |                  np.array_equal(classes, [1]))):
34 |         raise ValueError("Data is not binary and pos_label is not specified")
35 |     elif pos_label is None:
36 |         pos_label = 1.
37 | 
38 |     # make y_true a boolean vector
39 |     y_true = (y_true == pos_label)
40 | 
41 |     # sort scores and corresponding truth values
42 |     desc_score_indices = np.argsort(y_score, kind="mergesort")[::-1]
43 |     y_score = y_score[desc_score_indices]
44 |     y_true = y_true[desc_score_indices]
45 | 
46 |     # y_score typically has many tied values. Here we extract
47 |     # the indices associated with the distinct values. We also
48 |     # concatenate a value for the end of the curve.
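# --- editor's note (example, not part of the repository) ---------------------
# fpr_and_fdr_at_recall computes the false-positive rate at a fixed recall
# (TPR), 95% by default. A rough cross-check with sklearn's roc_curve, assuming
# higher scores mean "positive"/OOD; the two should agree up to tie handling:
#
#     import numpy as np
#     from sklearn.metrics import roc_curve
#     def fpr_at_tpr(labels, scores, tpr_level=0.95):
#         fpr, tpr, _ = roc_curve(labels, scores)
#         return fpr[np.argmin(np.abs(tpr - tpr_level))]
# ------------------------------------------------------------------------------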
49 | distinct_value_indices = np.where(np.diff(y_score))[0] 50 | threshold_idxs = np.r_[distinct_value_indices, y_true.size - 1] 51 | 52 | # accumulate the true positives with decreasing threshold 53 | tps = stable_cumsum(y_true)[threshold_idxs] 54 | fps = 1 + threshold_idxs - tps # add one because of zero-based indexing 55 | 56 | thresholds = y_score[threshold_idxs] 57 | 58 | recall = tps / tps[-1] 59 | 60 | last_ind = tps.searchsorted(tps[-1]) 61 | sl = slice(last_ind, None, -1) # [last_ind::-1] 62 | recall, fps, tps, thresholds = np.r_[recall[sl], 1], np.r_[fps[sl], 0], np.r_[tps[sl], 0], thresholds[sl] 63 | 64 | cutoff = np.argmin(np.abs(recall - recall_level)) 65 | 66 | return fps[cutoff] / (np.sum(np.logical_not(y_true))) # , fps[cutoff]/(fps[cutoff] + tps[cutoff]) 67 | 68 | 69 | def get_measures(_pos, _neg, recall_level=recall_level_default): 70 | pos = np.array(_pos[:]).reshape((-1, 1)) 71 | neg = np.array(_neg[:]).reshape((-1, 1)) 72 | examples = np.squeeze(np.vstack((pos, neg))) 73 | labels = np.zeros(len(examples), dtype=np.int32) 74 | labels[:len(pos)] += 1 75 | 76 | auroc = sk.roc_auc_score(labels, examples) 77 | aupr = sk.average_precision_score(labels, examples) 78 | fpr = fpr_and_fdr_at_recall(labels, examples, recall_level) 79 | 80 | return auroc, aupr, fpr 81 | 82 | 83 | def show_performance(pos, neg, method_name='Ours', recall_level=recall_level_default): 84 | ''' 85 | :param pos: 1's class, class to detect, outliers, or wrongly predicted 86 | example scores 87 | :param neg: 0's class scores 88 | ''' 89 | 90 | auroc, aupr, fpr = get_measures(pos[:], neg[:], recall_level) 91 | 92 | print('\t\t\t' + method_name) 93 | print('FPR{:d}:\t\t\t{:.2f}'.format(int(100 * recall_level), 100 * fpr)) 94 | print('AUROC:\t\t\t{:.2f}'.format(100 * auroc)) 95 | print('AUPR:\t\t\t{:.2f}'.format(100 * aupr)) 96 | # print('FDR{:d}:\t\t\t{:.2f}'.format(int(100 * recall_level), 100 * fdr)) 97 | def show_performance_log(logger, pos, neg, method_name='Ours', recall_level=recall_level_default): 98 | ''' 99 | :param pos: 1's class, class to detect, outliers, or wrongly predicted 100 | example scores 101 | :param neg: 0's class scores 102 | ''' 103 | 104 | auroc, aupr, fpr = get_measures(pos[:], neg[:], recall_level) 105 | 106 | print('\t\t\t' + method_name) 107 | print('FPR{:d}:\t\t\t{:.2f}'.format(int(100 * recall_level), 100 * fpr)) 108 | print('AUROC:\t\t\t{:.2f}'.format(100 * auroc)) 109 | print('AUPR:\t\t\t{:.2f}'.format(100 * aupr)) 110 | logger.info('\t\t\t' + method_name) 111 | logger.info('FPR{:d}:\t\t\t{:.2f}'.format(int(100 * recall_level), 100 * fpr)) 112 | logger.info('AUROC:\t\t\t{:.2f}'.format(100 * auroc)) 113 | logger.info('AUPR:\t\t\t{:.2f}'.format(100 * aupr)) 114 | 115 | 116 | # print('FDR{:d}:\t\t\t{:.2f}'.format(int(100 * recall_level), 100 * fdr)) 117 | 118 | def print_measures(auroc, aupr, fpr, method_name='Ours', recall_level=recall_level_default): 119 | print('\t\t\t\t' + method_name) 120 | print(' FPR{:d} AUROC AUPR'.format(int(100*recall_level))) 121 | print('& {:.2f} & {:.2f} & {:.2f}'.format(100*fpr, 100*auroc, 100*aupr)) 122 | #print('FPR{:d}:\t\t\t{:.2f}'.format(int(100 * recall_level), 100 * fpr)) 123 | #print('AUROC: \t\t\t{:.2f}'.format(100 * auroc)) 124 | #print('AUPR: \t\t\t{:.2f}'.format(100 * aupr)) 125 | 126 | def print_measures_log(logger, auroc, aupr, fpr, method_name='Ours', recall_level=recall_level_default): 127 | print('\t\t\t\t' + method_name) 128 | print(' FPR{:d} AUROC AUPR'.format(int(100*recall_level))) 129 | print('& {:.2f} & {:.2f} & 
{:.2f}'.format(100*fpr, 100*auroc, 100*aupr)) 130 | logger.info('\t\t\t\t' + method_name) 131 | logger.info(' FPR{:d} AUROC AUPR'.format(int(100*recall_level))) 132 | logger.info('& {:.2f} & {:.2f} & {:.2f}'.format(100*fpr, 100*auroc, 100*aupr)) 133 | #print('FPR{:d}:\t\t\t{:.2f}'.format(int(100 * recall_level), 100 * fpr)) 134 | #print('AUROC: \t\t\t{:.2f}'.format(100 * auroc)) 135 | #print('AUPR: \t\t\t{:.2f}'.format(100 * aupr)) 136 | def print_measures_with_std(aurocs, auprs, fprs, method_name='Ours', recall_level=recall_level_default): 137 | print('\t\t\t\t' + method_name) 138 | print(' FPR{:d} AUROC AUPR'.format(int(100*recall_level))) 139 | print('& {:.2f} & {:.2f} & {:.2f}'.format(100*np.mean(fprs), 100*np.mean(aurocs), 100*np.mean(auprs))) 140 | print('& {:.2f} & {:.2f} & {:.2f}'.format(100*np.std(fprs), 100*np.std(aurocs), 100*np.std(auprs))) 141 | #print('FPR{:d}:\t\t\t{:.2f}\t+/- {:.2f}'.format(int(100 * recall_level), 100 * np.mean(fprs), 100 * np.std(fprs))) 142 | #print('AUROC: \t\t\t{:.2f}\t+/- {:.2f}'.format(100 * np.mean(aurocs), 100 * np.std(aurocs))) 143 | #print('AUPR: \t\t\t{:.2f}\t+/- {:.2f}'.format(100 * np.mean(auprs), 100 * np.std(auprs))) 144 | 145 | def print_measures_with_std_log(logger, aurocs, auprs, fprs, method_name='Ours', recall_level=recall_level_default): 146 | print('\t\t\t\t' + method_name) 147 | print(' FPR{:d} AUROC AUPR'.format(int(100*recall_level))) 148 | print('& {:.2f} & {:.2f} & {:.2f}'.format(100*np.mean(fprs), 100*np.mean(aurocs), 100*np.mean(auprs))) 149 | print('& {:.2f} & {:.2f} & {:.2f}'.format(100*np.std(fprs), 100*np.std(aurocs), 100*np.std(auprs))) 150 | logger.info('\t\t\t\t' + method_name) 151 | logger.info(' FPR{:d} AUROC AUPR'.format(int(100*recall_level))) 152 | logger.info('& {:.2f} & {:.2f} & {:.2f}'.format(100*np.mean(fprs), 100*np.mean(aurocs), 100*np.mean(auprs))) 153 | logger.info('& {:.2f} & {:.2f} & {:.2f}'.format(100*np.std(fprs), 100*np.std(aurocs), 100*np.std(auprs))) 154 | #print('FPR{:d}:\t\t\t{:.2f}\t+/- {:.2f}'.format(int(100 * recall_level), 100 * np.mean(fprs), 100 * np.std(fprs))) 155 | #print('AUROC: \t\t\t{:.2f}\t+/- {:.2f}'.format(100 * np.mean(aurocs), 100 * np.std(aurocs))) 156 | #print('AUPR: \t\t\t{:.2f}\t+/- {:.2f}'.format(100 * np.mean(auprs), 100 * np.std(auprs))) 157 | 158 | def show_performance_comparison(pos_base, neg_base, pos_ours, neg_ours, baseline_name='Baseline', 159 | method_name='Ours', recall_level=recall_level_default): 160 | ''' 161 | :param pos_base: 1's class, class to detect, outliers, or wrongly predicted 162 | example scores from the baseline 163 | :param neg_base: 0's class scores generated by the baseline 164 | ''' 165 | auroc_base, aupr_base, fpr_base = get_measures(pos_base[:], neg_base[:], recall_level) 166 | auroc_ours, aupr_ours, fpr_ours = get_measures(pos_ours[:], neg_ours[:], recall_level) 167 | 168 | print('\t\t\t' + baseline_name + '\t' + method_name) 169 | print('FPR{:d}:\t\t\t{:.2f}\t\t{:.2f}'.format( 170 | int(100 * recall_level), 100 * fpr_base, 100 * fpr_ours)) 171 | print('AUROC:\t\t\t{:.2f}\t\t{:.2f}'.format( 172 | 100 * auroc_base, 100 * auroc_ours)) 173 | print('AUPR:\t\t\t{:.2f}\t\t{:.2f}'.format( 174 | 100 * aupr_base, 100 * aupr_ours)) 175 | # print('FDR{:d}:\t\t\t{:.2f}\t\t{:.2f}'.format( 176 | # int(100 * recall_level), 100 * fdr_base, 100 * fdr_ours)) 177 | -------------------------------------------------------------------------------- /segmentation/code/losses.py: -------------------------------------------------------------------------------- 1 | import 
torch 2 | import torch.nn.functional as F 3 | from torchvision import transforms 4 | from torch.nn.functional import normalize 5 | 6 | def influence(value, weight, dim=None, keepdim=False): 7 | 8 | influence= torch.sum(value*F.relu(weight), dim=dim, keepdim=keepdim) 9 | influence = influence.squeeze(dim) 10 | 11 | return influence 12 | 13 | def smooth(arr, lamda1): 14 | new_array = arr 15 | arr2 = torch.zeros_like(arr) 16 | arr2[:, :-1, :] = arr[:, 1:, :] 17 | arr2[:, -1, :] = arr[:, -1, :] 18 | 19 | new_array2 = torch.zeros_like(new_array) 20 | new_array2[:, :, :-1] = new_array[:, :, 1:] 21 | new_array2[:, :, -1] = new_array[:, :, -1] 22 | loss = (torch.sum((arr2 - arr) ** 2) + torch.sum((new_array2 - new_array) ** 2)) / 2 23 | return lamda1 * loss 24 | 25 | 26 | def sparsity(arr, lamda2): 27 | loss = torch.mean(torch.norm(arr, dim=0)) 28 | return lamda2 * loss 29 | 30 | 31 | def energy_loss_with_smooth_sparsity(logits, targets): 32 | ood_ind = 254 33 | void_ind = 255 34 | num_class = 19 35 | T = 1. 36 | m_in = -12 37 | m_out = -6 38 | 39 | energy = -(T * torch.logsumexp(logits[:, :num_class, :, :] / T, dim=1)) 40 | Ec_out = energy[targets == ood_ind] 41 | Ec_in = energy[(targets != ood_ind) & (targets != void_ind)] 42 | 43 | loss = torch.tensor(0.).cuda() 44 | if Ec_out.size()[0] == 0: 45 | loss += torch.pow(F.relu(Ec_in - m_in), 2).mean() 46 | else: 47 | loss += 0.5 * (torch.pow(F.relu(Ec_in - m_in), 2).mean() + torch.pow(F.relu(m_out - Ec_out), 2).mean()) 48 | loss += sparsity(Ec_out, 5e-4) 49 | 50 | loss += smooth(energy, 3e-6) 51 | 52 | return loss, energy 53 | 54 | 55 | def balanced_energy_loss_with_smooth_sparsity(logits, targets, gamma1=0, gamma2=0, alpha=0): 56 | OOD_prior=torch.tensor([0.152509041,0.013298,0.106703,0.040058,0.033866,0.005256,0.001065,0.027989,0.080249,0.097971,0.002175,0.153123,0.029222,0.087057,0.119925,0.016023,0.007071,0.017048,0.009393]).cuda() 57 | 58 | OOD_prior_gamma1=OOD_prior**gamma1 59 | OOD_prior_gamma2=OOD_prior**gamma2 60 | 61 | OOD_prior_gamma1 = normalize(OOD_prior_gamma1, p=1.0, dim=0) 62 | OOD_prior_gamma2 = normalize(OOD_prior_gamma2, p=1.0, dim=0) 63 | 64 | OOD_prior_gamma1=OOD_prior_gamma1[None,:,None,None] 65 | OOD_prior_gamma2=OOD_prior_gamma2[None,:,None,None] 66 | 67 | ood_ind = 254 68 | void_ind = 255 69 | num_class = 19 70 | T = 1.0 71 | m_in = -12 72 | m_out = -6 73 | 74 | softmax=torch.nn.functional.softmax(logits[:, :num_class, :, :] ,dim=1) 75 | 76 | influences_for_margin = influence(softmax,OOD_prior_gamma1,dim=1) 77 | influences_for_loss = influence(softmax,OOD_prior_gamma2,dim=1) 78 | 79 | Ec_out_margin=influences_for_margin[targets==ood_ind] 80 | Ec_out_weight=influences_for_loss[targets==ood_ind] 81 | 82 | energy = -(T * torch.logsumexp(logits[:, :num_class, :, :] / T, dim=1)) 83 | Ec_out = energy[targets == ood_ind] 84 | Ec_in = energy[(targets != ood_ind) & (targets != void_ind)] 85 | 86 | Ec_out= Ec_out-(alpha*Ec_out_margin) 87 | 88 | loss = torch.tensor(0.).cuda() 89 | if Ec_out.size()[0] == 0: 90 | loss += torch.pow(F.relu(Ec_in - m_in), 2).mean() 91 | else: 92 | #loss += 0.5 * (torch.pow(F.relu(Ec_in - m_in), 2).mean() + torch.pow(F.relu(m_out - Ec_out), 2).mean()) 93 | loss += 0.5 * (torch.pow(F.relu(Ec_in - m_in), 2).mean() + (torch.pow(F.relu((m_out - Ec_out)), 2) * Ec_out_weight).sum() / Ec_out_weight.sum()) 94 | 95 | loss += sparsity(Ec_out, 5e-4) 96 | 97 | loss += smooth(energy, 3e-6) 98 | 99 | return loss, energy 100 | 101 | 102 | def energy_loss_pure(logits, targets): 103 | ood_ind = 254 104 | void_ind = 255 105 | 
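# --- editor's note (example, not part of the repository) ---------------------
# The energy losses in this file build on two squared hinges around the margins
# m_in = -12 and m_out = -6: inlier pixels are pushed to energy below m_in, OOD
# pixels above m_out; the balanced variants additionally shift and re-weight
# the OOD hinge with the class-prior "influence" scores. The two hinge terms in
# isolation:
#
#     import torch
#     import torch.nn.functional as F
#     logits = torch.randn(2, 19, 8, 8)
#     energy = -torch.logsumexp(logits, dim=1)
#     m_in, m_out = -12., -6.
#     loss_in = F.relu(energy - m_in).pow(2).mean()    # inlier energy above m_in
#     loss_out = F.relu(m_out - energy).pow(2).mean()  # OOD energy below m_out
# ------------------------------------------------------------------------------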
def energy_loss_pure(logits, targets):
    ood_ind = 254
    void_ind = 255
    num_class = 19
    T = 1.
    m_in = -12
    m_out = -6

    energy = -(T * torch.logsumexp(logits[:, :num_class, :, :] / T, dim=1))
    Ec_out = energy[targets == ood_ind]
    Ec_in = energy[(targets != ood_ind) & (targets != void_ind)]

    loss = torch.tensor(0.).cuda()
    if Ec_out.size()[0] == 0:
        loss += torch.pow(F.relu(Ec_in - m_in), 2).mean()
    else:
        loss += 0.5 * (torch.pow(F.relu(Ec_in - m_in), 2).mean() + torch.pow(F.relu(m_out - Ec_out), 2).mean())

    return loss, energy


def balanced_energy_loss_pure(logits, targets, gamma1=0, gamma2=0, alpha=0):
    OOD_prior = torch.tensor([0.152509041, 0.013298, 0.106703, 0.040058, 0.033866,
                              0.005256, 0.001065, 0.027989, 0.080249, 0.097971,
                              0.002175, 0.153123, 0.029222, 0.087057, 0.119925,
                              0.016023, 0.007071, 0.017048, 0.009393]).cuda()

    OOD_prior_gamma1 = normalize(OOD_prior ** gamma1, p=1.0, dim=0)
    OOD_prior_gamma2 = normalize(OOD_prior ** gamma2, p=1.0, dim=0)

    OOD_prior_gamma1 = OOD_prior_gamma1[None, :, None, None]
    OOD_prior_gamma2 = OOD_prior_gamma2[None, :, None, None]

    ood_ind = 254
    void_ind = 255
    num_class = 19
    T = 1.0
    m_in = -12
    m_out = -6

    softmax = torch.nn.functional.softmax(logits[:, :num_class, :, :], dim=1)

    influences_for_margin = influence(softmax, OOD_prior_gamma1, dim=1)
    influences_for_loss = influence(softmax, OOD_prior_gamma2, dim=1)

    Ec_out_margin = influences_for_margin[targets == ood_ind]
    Ec_out_weight = influences_for_loss[targets == ood_ind]

    energy = -(T * torch.logsumexp(logits[:, :num_class, :, :] / T, dim=1))
    Ec_out = energy[targets == ood_ind]
    Ec_in = energy[(targets != ood_ind) & (targets != void_ind)]

    Ec_out = Ec_out - (alpha * Ec_out_margin)

    loss = torch.tensor(0.).cuda()
    if Ec_out.size()[0] == 0:
        loss += torch.pow(F.relu(Ec_in - m_in), 2).mean()
    else:
        loss += 0.5 * (torch.pow(F.relu(Ec_in - m_in), 2).mean()
                       + (torch.pow(F.relu(m_out - Ec_out), 2) * Ec_out_weight).sum() / Ec_out_weight.sum())

    return loss, energy
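
# --- Editor's note (not part of the original file): a CPU restatement of the
# hinge terms in energy_loss_pure on dummy logits. In-distribution pixels are
# pushed below m_in = -12 and OOD pixels above m_out = -6.
def _demo_energy_margins():
    logits = torch.randn(2, 19, 4, 4) * 5        # fake segmentation logits
    targets = torch.randint(0, 19, (2, 4, 4))
    targets[0, 0, 0] = 254                       # mark one pixel as OOD
    energy = -torch.logsumexp(logits, dim=1)     # free energy at T = 1
    ec_out = energy[targets == 254]
    ec_in = energy[(targets != 254) & (targets != 255)]
    loss = 0.5 * (F.relu(ec_in - (-12)).pow(2).mean()
                  + F.relu(-6 - ec_out).pow(2).mean())
    print(loss)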
class Gambler(torch.nn.Module):
    def __init__(self, reward, device, pretrain=-1, ood_reg=.1):
        super(Gambler, self).__init__()
        self.reward = torch.tensor([reward]).cuda(device)
        self.pretrain = pretrain
        self.ood_reg = ood_reg
        self.device = device

    def forward(self, pred, targets, wrong_sample=False):
        pred_prob = torch.softmax(pred, dim=1)

        assert torch.all(pred_prob > 0), pred_prob[pred_prob <= 0]
        assert torch.all(pred_prob <= 1), pred_prob[pred_prob > 1]
        # last channel is the reservation ("abstain") head
        true_pred, reservation = pred_prob[:, :-1, :, :], pred_prob[:, -1, :, :]

        # compute the reward via the energy score
        reward = torch.logsumexp(pred[:, :-1, :, :], dim=1).pow(2)

        if reward.nelement() > 0:
            gaussian_smoothing = transforms.GaussianBlur(7, sigma=1)
            reward = reward.unsqueeze(0)
            reward = gaussian_smoothing(reward)
            reward = reward.squeeze(0)
        else:
            reward = self.reward

        if wrong_sample:  # if there are OOD pixels inside the image
            reservation = torch.div(reservation, reward)
            mask = targets == 254
            # mask out each of the ood output channels
            reserve_boosting_energy = torch.add(true_pred, reservation.unsqueeze(1))[mask.unsqueeze(1).repeat(1, 19, 1, 1)]

            gambler_loss_out = torch.tensor([.0], device=self.device)
            if reserve_boosting_energy.nelement() > 0:
                reserve_boosting_energy = torch.clamp(reserve_boosting_energy, min=1e-7).log()
                gambler_loss_out = self.ood_reg * reserve_boosting_energy

            # gambler loss for in-lier pixels
            void_mask = targets == 255
            targets[void_mask] = 0  # map void pixels to class 0
            targets[mask] = 0       # map ood pixels to class 0
            gambler_loss_in = torch.gather(true_pred, index=targets.unsqueeze(1), dim=1).squeeze()
            gambler_loss_in = torch.add(gambler_loss_in, reservation)

            # exclude the ood pixel mask and void pixel mask
            gambler_loss_in = gambler_loss_in[(~mask) & (~void_mask)].log()
            return -(gambler_loss_in.mean() + gambler_loss_out.mean())
        else:
            mask = targets == 255
            targets[mask] = 0
            reservation = torch.div(reservation, reward)
            gambler_loss = torch.gather(true_pred, index=targets.unsqueeze(1), dim=1).squeeze()
            gambler_loss = torch.add(gambler_loss, reservation)
            gambler_loss = gambler_loss[~mask].log()
            # assert not torch.any(torch.isnan(gambler_loss)), "nan check"
            return -gambler_loss.mean()
--------------------------------------------------------------------------------
/classification/utils/score_calculation.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

to_np = lambda x: x.data.cpu().numpy()
concat = lambda x: np.concatenate(x, axis=0)

def get_ood_scores_odin(loader, net, bs, ood_num_examples, T, noise, in_dist=False):
    _score = []
    _right_score = []
    _wrong_score = []

    net.eval()
    for batch_idx, (data, target) in enumerate(loader):
        if batch_idx >= ood_num_examples // bs and in_dist is False:
            break
        data = data.cuda()
        data = Variable(data, requires_grad=True)

        output = net(data)
        smax = to_np(F.softmax(output, dim=1))

        odin_score = ODIN(data, output, net, T, noise)
        _score.append(-np.max(odin_score, 1))

        if in_dist:
            preds = np.argmax(smax, axis=1)
            targets = target.numpy().squeeze()
            right_indices = preds == targets
            wrong_indices = np.invert(right_indices)

            _right_score.append(-np.max(smax[right_indices], axis=1))
            _wrong_score.append(-np.max(smax[wrong_indices], axis=1))

    if in_dist:
        return concat(_score).copy(), concat(_right_score).copy(), concat(_wrong_score).copy()
    else:
        return concat(_score)[:ood_num_examples].copy()
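
# --- Editor's note (not part of the original file): the first half of ODIN is
# plain temperature scaling. A quick sketch of its effect on the max-softmax
# confidence that the score above is built from:
def _demo_temperature_scaling():
    logits = torch.tensor([[10.0, 2.0, 1.0]])
    for T in (1.0, 10.0, 100.0):
        print(T, F.softmax(logits / T, dim=1).max().item())
    # Higher T flattens the softmax, which tends to widen the gap between
    # in-distribution and OOD max-softmax scores.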
def ODIN(inputs, outputs, model, temper, noiseMagnitude1):
    # Calculate the perturbation to add: the sign of the gradient of the
    # cross-entropy loss w.r.t. the input
    criterion = nn.CrossEntropyLoss()

    maxIndexTemp = np.argmax(outputs.data.cpu().numpy(), axis=1)

    # Using temperature scaling
    outputs = outputs / temper

    labels = Variable(torch.LongTensor(maxIndexTemp).cuda())
    loss = criterion(outputs, labels)
    loss.backward()

    # Normalizing the gradient to binary in {-1, 1}
    gradient = torch.ge(inputs.grad.data, 0)
    gradient = (gradient.float() - 0.5) * 2

    # Undo the per-channel input normalisation (CIFAR std in pixel units)
    gradient[:, 0] = gradient[:, 0] / (63.0 / 255.0)
    gradient[:, 1] = gradient[:, 1] / (62.1 / 255.0)
    gradient[:, 2] = gradient[:, 2] / (66.7 / 255.0)

    # Adding small perturbations to images
    tempInputs = torch.add(inputs.data, gradient, alpha=-noiseMagnitude1)
    outputs = model(Variable(tempInputs))
    outputs = outputs / temper
    # Calculating the confidence after adding perturbations
    nnOutputs = outputs.data.cpu().numpy()
    nnOutputs = nnOutputs - np.max(nnOutputs, axis=1, keepdims=True)
    nnOutputs = np.exp(nnOutputs) / np.sum(np.exp(nnOutputs), axis=1, keepdims=True)

    return nnOutputs
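
# --- Editor's note (not part of the original file): a standalone sketch of the
# input-preprocessing step above on a dummy image; eps = 0.0014 is a commonly
# reported ODIN noise magnitude, used here purely as an illustration.
def _demo_odin_perturbation(eps=0.0014):
    x = torch.randn(1, 3, 32, 32, requires_grad=True)
    (x * 2).sum().backward()                     # stand-in for the CE loss
    g = (torch.ge(x.grad, 0).float() - 0.5) * 2  # gradient sign in {-1, +1}
    std = torch.tensor([63.0, 62.1, 66.7]) / 255.0
    g = g / std[None, :, None, None]             # undo channel normalisation
    x_pert = x.detach() - eps * g                # step against the gradient
    return x_pert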
def get_Mahalanobis_score(model, test_loader, num_classes, sample_mean, precision, layer_index, magnitude, num_batches, in_dist=False):
    '''
    Compute the proposed Mahalanobis confidence score on the input dataset
    return: Mahalanobis score from layer_index
    '''
    model.eval()
    Mahalanobis = []

    for batch_idx, (data, target) in enumerate(test_loader):
        if batch_idx >= num_batches and in_dist is False:
            break

        data, target = data.cuda(), target.cuda()
        data, target = Variable(data, requires_grad=True), Variable(target)

        out_features = model.intermediate_forward(data, layer_index)
        out_features = out_features.view(out_features.size(0), out_features.size(1), -1)
        out_features = torch.mean(out_features, 2)

        # compute the Mahalanobis score for every class
        gaussian_score = 0
        for i in range(num_classes):
            batch_sample_mean = sample_mean[layer_index][i]
            zero_f = out_features.data - batch_sample_mean
            term_gau = -0.5 * torch.mm(torch.mm(zero_f, precision[layer_index]), zero_f.t()).diag()
            if i == 0:
                gaussian_score = term_gau.view(-1, 1)
            else:
                gaussian_score = torch.cat((gaussian_score, term_gau.view(-1, 1)), 1)

        # input pre-processing: perturb the input against the gradient of the
        # best class score, as in ODIN
        sample_pred = gaussian_score.max(1)[1]
        batch_sample_mean = sample_mean[layer_index].index_select(0, sample_pred)
        zero_f = out_features - Variable(batch_sample_mean)
        pure_gau = -0.5 * torch.mm(torch.mm(zero_f, Variable(precision[layer_index])), zero_f.t()).diag()
        loss = torch.mean(-pure_gau)
        loss.backward()

        gradient = torch.ge(data.grad.data, 0)
        gradient = (gradient.float() - 0.5) * 2
        gradient.index_copy_(1, torch.LongTensor([0]).cuda(), gradient.index_select(1, torch.LongTensor([0]).cuda()) / (63.0 / 255.0))
        gradient.index_copy_(1, torch.LongTensor([1]).cuda(), gradient.index_select(1, torch.LongTensor([1]).cuda()) / (62.1 / 255.0))
        gradient.index_copy_(1, torch.LongTensor([2]).cuda(), gradient.index_select(1, torch.LongTensor([2]).cuda()) / (66.7 / 255.0))

        tempInputs = torch.add(data.data, gradient, alpha=-magnitude)
        with torch.no_grad():
            noise_out_features = model.intermediate_forward(tempInputs, layer_index)
            noise_out_features = noise_out_features.view(noise_out_features.size(0), noise_out_features.size(1), -1)
            noise_out_features = torch.mean(noise_out_features, 2)
            noise_gaussian_score = 0
            for i in range(num_classes):
                batch_sample_mean = sample_mean[layer_index][i]
                zero_f = noise_out_features.data - batch_sample_mean
                term_gau = -0.5 * torch.mm(torch.mm(zero_f, precision[layer_index]), zero_f.t()).diag()
                if i == 0:
                    noise_gaussian_score = term_gau.view(-1, 1)
                else:
                    noise_gaussian_score = torch.cat((noise_gaussian_score, term_gau.view(-1, 1)), 1)

        noise_gaussian_score, _ = torch.max(noise_gaussian_score, dim=1)
        Mahalanobis.extend(-noise_gaussian_score.cpu().numpy())

    return np.asarray(Mahalanobis, dtype=np.float32)
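
# --- Editor's note (not part of the original file): the per-class score above
# is the negative half squared Mahalanobis distance,
#   score_c(x) = -0.5 * (x - mu_c)^T Sigma^{-1} (x - mu_c),
# shown here on toy CPU features with an identity precision matrix.
def _demo_mahalanobis_score():
    feat = torch.randn(5, 8)      # 5 samples, 8-dim features
    mu = torch.randn(3, 8)        # 3 class means
    precision_mat = torch.eye(8)  # identity covariance for the sketch
    scores = []
    for c in range(3):
        z = feat - mu[c]
        scores.append((-0.5 * (z @ precision_mat @ z.t()).diag()).view(-1, 1))
    scores = torch.cat(scores, dim=1)  # (5, 3): one score per class
    print(scores.max(dim=1).values)    # confidence = best class score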
def sample_estimator(model, num_classes, feature_list, train_loader):
    """
    compute sample mean and precision (inverse of covariance)
    return: sample_class_mean: list of class means
            precision: list of precision matrices
    """
    import sklearn.covariance

    model.eval()
    group_lasso = sklearn.covariance.EmpiricalCovariance(assume_centered=False)
    correct, total = 0, 0
    num_output = len(feature_list)
    num_sample_per_class = np.zeros(num_classes)
    list_features = []
    for i in range(num_output):
        list_features.append([0] * num_classes)

    for data, target in train_loader:
        total += data.size(0)
        data = data.cuda()
        with torch.no_grad():
            output, out_features = model.feature_list(data)

        # get hidden features, spatially averaged per channel
        for i in range(num_output):
            out_features[i] = out_features[i].view(out_features[i].size(0), out_features[i].size(1), -1)
            out_features[i] = torch.mean(out_features[i].data, 2)

        # compute the accuracy
        pred = output.data.max(1)[1]
        equal_flag = pred.eq(target.cuda()).cpu()
        correct += equal_flag.sum().item()

        # construct the per-class sample matrices
        for i in range(data.size(0)):
            label = target[i]
            if num_sample_per_class[label] == 0:
                for out_count, out in enumerate(out_features):
                    list_features[out_count][label] = out[i].view(1, -1)
            else:
                for out_count, out in enumerate(out_features):
                    list_features[out_count][label] = torch.cat((list_features[out_count][label], out[i].view(1, -1)), 0)
            num_sample_per_class[label] += 1

    sample_class_mean = []
    for out_count, num_feature in enumerate(feature_list):
        temp_list = torch.Tensor(num_classes, int(num_feature)).cuda()
        for j in range(num_classes):
            temp_list[j] = torch.mean(list_features[out_count][j], 0)
        sample_class_mean.append(temp_list)

    precision = []
    for k in range(num_output):
        X = 0
        for i in range(num_classes):
            if i == 0:
                X = list_features[k][i] - sample_class_mean[k][i]
            else:
                X = torch.cat((X, list_features[k][i] - sample_class_mean[k][i]), 0)

        # find the inverse of the covariance (the precision matrix)
        group_lasso.fit(X.cpu().numpy())
        temp_precision = group_lasso.precision_
        temp_precision = torch.from_numpy(temp_precision).float().cuda()
        precision.append(temp_precision)

    print('\n Training Accuracy: ({:.2f}%)\n'.format(100. * correct / total))

    return sample_class_mean, precision
--------------------------------------------------------------------------------
/classification/utils/ImbalanceCIFAR.py:
--------------------------------------------------------------------------------
"""
Adapted from https://github.com/Megvii-Nanjing/BBN
"""

import torchvision
import torchvision.transforms as transforms
import numpy as np
from PIL import Image

class IMBALANCECIFAR10(torchvision.datasets.CIFAR10):
    cls_num = 10

    def __init__(self, phase, imbalance_ratio, root='/home/datasets', imb_type='exp'):
        train = (phase == "train")
        super(IMBALANCECIFAR10, self).__init__(root, train, transform=None, target_transform=None, download=True)
        self.train = train
        if self.train:
            img_num_list = self.get_img_num_per_cls(self.cls_num, imb_type, imbalance_ratio)
            self.gen_imbalanced_data(img_num_list)
            self.transform = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((125.3/255, 123.0/255, 113.9/255), (63.0/255, 62.1/255, 66.7/255))
            ])
        else:
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((125.3/255, 123.0/255, 113.9/255), (63.0/255, 62.1/255, 66.7/255))
            ])

        self.labels = self.targets

        print("{} Mode: Contains {} images".format(phase, len(self.data)))

    def _get_class_dict(self):
        class_dict = dict()
        for i, anno in enumerate(self.get_annotations()):
            cat_id = anno["category_id"]
            if cat_id not in class_dict:
                class_dict[cat_id] = []
            class_dict[cat_id].append(i)
        return class_dict

    def get_img_num_per_cls(self, cls_num, imb_type, imb_factor):
        img_max = len(self.data) / cls_num
        img_num_per_cls = []
        if imb_type == 'exp':
            for cls_idx in range(cls_num):
                num = img_max * (imb_factor ** (cls_idx / (cls_num - 1.0)))
                img_num_per_cls.append(int(num))
        elif imb_type == 'step':
            for cls_idx in range(cls_num // 2):
                img_num_per_cls.append(int(img_max))
            for cls_idx in range(cls_num // 2):
                img_num_per_cls.append(int(img_max * imb_factor))
        else:
            img_num_per_cls.extend([int(imb_factor)] * cls_num)
        return img_num_per_cls

    def gen_imbalanced_data(self, img_num_per_cls):
        new_data = []
        new_targets = []
        targets_np = np.array(self.targets, dtype=np.int64)
        classes = np.unique(targets_np)

        self.num_per_cls_dict = dict()
        for the_class, the_img_num in zip(classes, img_num_per_cls):
            self.num_per_cls_dict[the_class] = the_img_num
            idx = np.where(targets_np == the_class)[0]
            # deterministic: always keep the first the_img_num samples
            selec_idx = idx[:the_img_num]
            new_data.append(self.data[selec_idx, ...])
            new_targets.extend([the_class, ] * the_img_num)
        new_data = np.vstack(new_data)
        self.data = new_data
        self.targets = new_targets

    def __getitem__(self, index):
        img, label = self.data[index], self.labels[index]

        # return a PIL Image to stay consistent with all other datasets
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return img, label, index

    def __len__(self):
        return len(self.labels)

    def get_num_classes(self):
        return self.cls_num

    def get_annotations(self):
        annos = []
        for label in self.labels:
            annos.append({'category_id': int(label)})
        return annos

    def get_cls_num_list(self):
        cls_num_list = []
        for i in range(self.cls_num):
            cls_num_list.append(self.num_per_cls_dict[i])
        return cls_num_list
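
# --- Editor's note (not part of the original file): a worked example of the
# 'exp' schedule in get_img_num_per_cls for CIFAR-10 (img_max = 5000) with an
# imbalance ratio of 0.01: class k keeps img_max * ratio**(k / (cls_num - 1))
# images, so the rarest class retains img_max * ratio = 50 samples.
def _demo_exp_schedule():
    img_max, ratio, cls_num = 5000, 0.01, 10
    counts = [int(img_max * ratio ** (k / (cls_num - 1.0))) for k in range(cls_num)]
    print(counts)  # [5000, 2997, 1796, 1077, 645, 387, 232, 139, 83, 50]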
class IMBALANCECIFAR100(IMBALANCECIFAR10):
    """CIFAR100 Dataset.
    This is a subclass of the `CIFAR10` Dataset.
    """
    cls_num = 100
    base_folder = 'cifar-100-python'
    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
    train_list = [
        ['train', '16019d7e3df5f24257cddd939b257f8d'],
    ]
    test_list = [
        ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
    ]
    meta = {
        'filename': 'meta',
        'key': 'fine_label_names',
        'md5': '7973b15100ade9c7d40fb424638fde48',
    }


class IMBALANCECIFAR10_sota(torchvision.datasets.CIFAR10):
    cls_num = 10

    def __init__(self, phase, imbalance_ratio, root='/home/datasets', imb_type='exp'):
        train = (phase == "train")
        super(IMBALANCECIFAR10_sota, self).__init__(root, train, transform=None, target_transform=None, download=True)
        self.train = train
        # mean/std must be visible to both branches; defining them only in the
        # training branch (as originally written) raises a NameError at eval time.
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
        if self.train:
            img_num_list = self.get_img_num_per_cls(self.cls_num, imb_type, imbalance_ratio)
            self.gen_imbalanced_data(img_num_list)
            self.transform = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.RandomApply([
                    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
                ], p=0.8),
                transforms.RandomGrayscale(p=0.2),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ])
        else:
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ])

        self.labels = self.targets

        print("{} Mode: Contains {} images".format(phase, len(self.data)))

    def _get_class_dict(self):
        class_dict = dict()
        for i, anno in enumerate(self.get_annotations()):
            cat_id = anno["category_id"]
            if cat_id not in class_dict:
                class_dict[cat_id] = []
            class_dict[cat_id].append(i)
        return class_dict

    def get_img_num_per_cls(self, cls_num, imb_type, imb_factor):
        img_max = len(self.data) / cls_num
        img_num_per_cls = []
        if imb_type == 'exp':
            for cls_idx in range(cls_num):
                num = img_max * (imb_factor ** (cls_idx / (cls_num - 1.0)))
                img_num_per_cls.append(int(num))
        elif imb_type == 'step':
            for cls_idx in range(cls_num // 2):
                img_num_per_cls.append(int(img_max))
            for cls_idx in range(cls_num // 2):
                img_num_per_cls.append(int(img_max * imb_factor))
        else:
            img_num_per_cls.extend([int(imb_factor)] * cls_num)
        return img_num_per_cls

    def gen_imbalanced_data(self, img_num_per_cls):
        new_data = []
        new_targets = []
        targets_np = np.array(self.targets, dtype=np.int64)
        classes = np.unique(targets_np)

        self.num_per_cls_dict = dict()
        for the_class, the_img_num in zip(classes, img_num_per_cls):
            self.num_per_cls_dict[the_class] = the_img_num
            idx = np.where(targets_np == the_class)[0]
            # deterministic: always keep the first the_img_num samples
            selec_idx = idx[:the_img_num]
            new_data.append(self.data[selec_idx, ...])
            new_targets.extend([the_class, ] * the_img_num)
        new_data = np.vstack(new_data)
        self.data = new_data
        self.targets = new_targets

    def __getitem__(self, index):
        img, label = self.data[index], self.labels[index]

        # return a PIL Image to stay consistent with all other datasets
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return img, label, index

    def __len__(self):
        return len(self.labels)

    def get_num_classes(self):
        return self.cls_num

    def get_annotations(self):
        annos = []
        for label in self.labels:
            annos.append({'category_id': int(label)})
        return annos

    def get_cls_num_list(self):
        cls_num_list = []
        for i in range(self.cls_num):
            cls_num_list.append(self.num_per_cls_dict[i])
        return cls_num_list
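
# --- Editor's note (not part of the original file): typical construction,
# assuming CIFAR-10 can be downloaded to ./data (the root is a hypothetical
# path; the constructor passes download=True).
def _demo_usage():
    ds = IMBALANCECIFAR10(phase="train", imbalance_ratio=0.01, root="./data")
    print(ds.get_cls_num_list())  # per-class counts, e.g. [5000, ..., 50]
    img, label, index = ds[0]     # __getitem__ also returns the index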
class IMBALANCECIFAR100_sota(IMBALANCECIFAR10_sota):
    """CIFAR100 Dataset.
    This is a subclass of the `CIFAR10` Dataset.
    """
    cls_num = 100
    base_folder = 'cifar-100-python'
    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
    train_list = [
        ['train', '16019d7e3df5f24257cddd939b257f8d'],
    ]
    test_list = [
        ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
    ]
    meta = {
        'filename': 'meta',
        'key': 'fine_label_names',
        'md5': '7973b15100ade9c7d40fb424638fde48',
    }
--------------------------------------------------------------------------------
/classification/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License.
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | --------------------------------------------------------------------------------