├── ODIR
│   ├── vilt
│   │   ├── __init__.py
│   │   ├── gadgets
│   │   │   ├── __init__.py
│   │   │   └── my_metrics.py
│   │   ├── modules
│   │   │   ├── __init__.py
│   │   │   ├── heads.py
│   │   │   └── lora.py
│   │   ├── utils
│   │   │   ├── write_hatememes.py
│   │   │   ├── write_food101.py
│   │   │   ├── write_mmimdb.py
│   │   │   └── glossary.py
│   │   ├── datasets
│   │   │   ├── __init__.py
│   │   │   ├── hatememes_dataset.py
│   │   │   ├── mmimdb_dataset.py
│   │   │   ├── food101_dataset.py
│   │   │   └── ODIR_dataset.py
│   │   ├── datamodules
│   │   │   ├── __init__.py
│   │   │   ├── ODIR_datamodule.py
│   │   │   ├── mmimdb_datamodule.py
│   │   │   ├── food101_datamodule.py
│   │   │   ├── hatememes_datamodule.py
│   │   │   ├── multitask_datamodule.py
│   │   │   └── datamodule_base.py
│   │   ├── transforms
│   │   │   ├── __init__.py
│   │   │   ├── pixelbert.py
│   │   │   ├── utils.py
│   │   │   └── randaug.py
│   │   └── Costom_task_guideline
│   ├── datasets
│   │   └── missing_tables
│   │       ├── ODIR_test_missing_both_01.pt
│   │       ├── ODIR_test_missing_both_02.pt
│   │       ├── ODIR_test_missing_both_03.pt
│   │       ├── ODIR_test_missing_both_04.pt
│   │       ├── ODIR_test_missing_both_05.pt
│   │       ├── ODIR_test_missing_both_06.pt
│   │       ├── ODIR_test_missing_both_07.pt
│   │       ├── ODIR_test_missing_both_08.pt
│   │       ├── ODIR_test_missing_both_09.pt
│   │       ├── ODIR_test_missing_text_07.pt
│   │       ├── ODIR_val_missing_both_01.pt
│   │       ├── ODIR_val_missing_both_02.pt
│   │       ├── ODIR_val_missing_both_03.pt
│   │       ├── ODIR_val_missing_both_04.pt
│   │       ├── ODIR_val_missing_both_05.pt
│   │       ├── ODIR_val_missing_both_06.pt
│   │       ├── ODIR_val_missing_both_07.pt
│   │       ├── ODIR_val_missing_both_08.pt
│   │       ├── ODIR_val_missing_both_09.pt
│   │       ├── ODIR_val_missing_image_07.pt
│   │       ├── ODIR_val_missing_text_07.pt
│   │       ├── ODIR_test_missing_image_07.pt
│   │       ├── ODIR_train_missing_both_01.pt
│   │       ├── ODIR_train_missing_both_02.pt
│   │       ├── ODIR_train_missing_both_03.pt
│   │       ├── ODIR_train_missing_both_04.pt
│   │       ├── ODIR_train_missing_both_05.pt
│   │       ├── ODIR_train_missing_both_06.pt
│   │       ├── ODIR_train_missing_both_07.pt
│   │       ├── ODIR_train_missing_both_08.pt
│   │       ├── ODIR_train_missing_both_09.pt
│   │       ├── ODIR_train_missing_image_07.pt
│   │       └── ODIR_train_missing_text_07.pt
│   ├── scripts
│   │   ├── ODIR_training.sh
│   │   └── ODIR_testing.sh
│   ├── test.py
│   ├── sort.py
│   ├── run.py
│   └── make_ODIR_arrow.py
├── chestXray
│   ├── vilt
│   │   ├── __init__.py
│   │   ├── gadgets
│   │   │   ├── __init__.py
│   │   │   └── my_metrics.py
│   │   ├── modules
│   │   │   ├── __init__.py
│   │   │   ├── heads.py
│   │   │   └── lora.py
│   │   ├── utils
│   │   │   ├── write_hatememes.py
│   │   │   ├── write_food101.py
│   │   │   ├── write_mmimdb.py
│   │   │   └── glossary.py
│   │   ├── datasets
│   │   │   ├── __init__.py
│   │   │   ├── hatememes_dataset.py
│   │   │   ├── mmimdb_dataset.py
│   │   │   ├── food101_dataset.py
│   │   │   └── chestXray_dataset.py
│   │   ├── datamodules
│   │   │   ├── __init__.py
│   │   │   ├── mmimdb_datamodule.py
│   │   │   ├── food101_datamodule.py
│   │   │   ├── chestXray_datamodule.py
│   │   │   ├── hatememes_datamodule.py
│   │   │   ├── multitask_datamodule.py
│   │   │   └── datamodule_base.py
│   │   ├── transforms
│   │   │   ├── __init__.py
│   │   │   ├── pixelbert.py
│   │   │   └── utils.py
│   │   └── Costom_task_guideline
│   ├── datasets
│   │   └── missing_tables
│   │       ├── chestXray_test_missing_both_05.pt
│   │       ├── chestXray_test_missing_both_07.pt
│   │       ├── chestXray_test_missing_test_07.pt
│   │       ├── chestXray_test_missing_text_05.pt
│   │       ├── chestXray_test_missing_text_07.pt
│   │       ├── chestXray_test_missing_text_1.pt
│   │       ├── chestXray_train_missing_text_1.pt
│   │       ├── chestXray_val_missing_both_07.pt
│   │       ├── chestXray_val_missing_image_05.pt
│   │       ├── chestXray_val_missing_image_07.pt
│   │       ├── chestXray_val_missing_test_07.pt
│   │       ├── chestXray_val_missing_text_05.pt
│   │       ├── chestXray_val_missing_text_07.pt
│   │       ├── chestXray_val_missing_text_1.pt
│   │       ├── chestXray_test_missing_image_05.pt
│   │       ├── chestXray_test_missing_image_07.pt
│   │       ├── chestXray_train_missing_both_07.pt
│   │       ├── chestXray_train_missing_image_05.pt
│   │       ├── chestXray_train_missing_image_07.pt
│   │       ├── chestXray_train_missing_test_07.pt
│   │       ├── chestXray_train_missing_text_05.pt
│   │       ├── chestXray_train_missing_text_07.pt
│   │       ├── chestXray_test_missing_text_05_version1.pt
│   │       ├── chestXray_train_missing_text_05_version1.pt
│   │       └── chestXray_val_missing_text_05_version1.pt
│   ├── scripts
│   │   ├── chestXray_training.sh
│   │   └── chestXray_testing.sh
│   ├── test.py
│   ├── sort.py
│   ├── run.py
│   └── make_chestXray_arrow.py
└── README.md

/ODIR/vilt/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/chestXray/vilt/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ODIR/vilt/gadgets/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/chestXray/vilt/gadgets/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MoRA
2 | Official code for MoRA: LoRA Guided Multi-Modal Disease Diagnosis with Missing Modality
--------------------------------------------------------------------------------
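Note on the missing tables: the `datasets/missing_tables/*.pt` files listed in the tree above follow the pattern `{dataset}_{split}_missing_{modality}_{ratio}.pt`, where the modality is `text`, `image`, or `both` and the trailing digits read as a missing ratio (`_07` ≈ 0.7, `_1` ≈ 1.0). A minimal sketch for fetching and inspecting one of them, assuming only that they are ordinary `torch.save()` artifacts — the payload layout is not documented in this listing:

```python
# Download one published missing table and inspect it before assuming a layout.
import urllib.request

import torch

BASE = "https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD"
NAME = "ODIR/datasets/missing_tables/ODIR_train_missing_both_07.pt"

urllib.request.urlretrieve(f"{BASE}/{NAME}", "missing_table.pt")
table = torch.load("missing_table.pt", map_location="cpu")

print(type(table))  # the stored structure (tensor/dict/list) is not documented
if torch.is_tensor(table):
    print(table.shape, table.unique())
```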
/ODIR/vilt/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # from .vilt_module import ViLTransformerSS
2 | from .vilt_missing_aware_prompt_module import ViLTransformerSS
--------------------------------------------------------------------------------
/chestXray/vilt/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # from .vilt_module import ViLTransformerSS
2 | from .vilt_missing_aware_prompt_module import ViLTransformerSS
--------------------------------------------------------------------------------
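Both `modules/__init__.py` files export the missing-aware-prompt variant of `ViLTransformerSS` and keep the plain `vilt_module` import commented out; each `modules` package also ships a `lora.py` for the adapters referenced in the title. The repository's adapter code is not reproduced in this listing, so purely as orientation a textbook LoRA linear layer looks roughly like the sketch below (names, rank, and scaling are illustrative, not MoRA's actual implementation):

```python
# Generic LoRA layer sketch; illustrative only, not the code in vilt/modules/lora.py.
import math

import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """A frozen base linear layer plus a trainable low-rank update B @ A."""

    def __init__(self, base: nn.Linear, rank: int = 4, alpha: float = 8.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # only the adapter is trained
        self.lora_a = nn.Parameter(torch.empty(rank, base.in_features))
        self.lora_b = nn.Parameter(torch.zeros(base.out_features, rank))
        nn.init.kaiming_uniform_(self.lora_a, a=math.sqrt(5))
        self.scaling = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Base projection plus scaled low-rank correction; B starts at zero,
        # so training begins from the unmodified pretrained behaviour.
        return self.base(x) + (x @ self.lora_a.T @ self.lora_b.T) * self.scaling
```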
/ODIR/datasets/missing_tables/ODIR_test_missing_both_01.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/ODIR/datasets/missing_tables/ODIR_test_missing_both_01.pt
--------------------------------------------------------------------------------
/ODIR/datasets/missing_tables/ODIR_test_missing_both_02.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/ODIR/datasets/missing_tables/ODIR_test_missing_both_02.pt
--------------------------------------------------------------------------------
/ODIR/datasets/missing_tables/ODIR_test_missing_both_03.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/ODIR/datasets/missing_tables/ODIR_test_missing_both_03.pt
--------------------------------------------------------------------------------
/ODIR/datasets/missing_tables/ODIR_test_missing_both_04.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/ODIR/datasets/missing_tables/ODIR_test_missing_both_04.pt
--------------------------------------------------------------------------------
/ODIR/datasets/missing_tables/ODIR_test_missing_both_05.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/ODIR/datasets/missing_tables/ODIR_test_missing_both_05.pt
--------------------------------------------------------------------------------
/ODIR/datasets/missing_tables/ODIR_test_missing_both_06.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/ODIR/datasets/missing_tables/ODIR_test_missing_both_06.pt
--------------------------------------------------------------------------------
/ODIR/datasets/missing_tables/ODIR_test_missing_both_07.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/ODIR/datasets/missing_tables/ODIR_test_missing_both_07.pt
--------------------------------------------------------------------------------
/ODIR/datasets/missing_tables/ODIR_test_missing_both_08.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/ODIR/datasets/missing_tables/ODIR_test_missing_both_08.pt
--------------------------------------------------------------------------------
/ODIR/datasets/missing_tables/ODIR_test_missing_both_09.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/ODIR/datasets/missing_tables/ODIR_test_missing_both_09.pt
--------------------------------------------------------------------------------
/ODIR/datasets/missing_tables/ODIR_test_missing_text_07.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/ODIR/datasets/missing_tables/ODIR_test_missing_text_07.pt
--------------------------------------------------------------------------------
/ODIR/datasets/missing_tables/ODIR_val_missing_both_01.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/ODIR/datasets/missing_tables/ODIR_val_missing_both_01.pt
--------------------------------------------------------------------------------
/ODIR/datasets/missing_tables/ODIR_val_missing_both_02.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/ODIR/datasets/missing_tables/ODIR_val_missing_both_02.pt
--------------------------------------------------------------------------------
/ODIR/datasets/missing_tables/ODIR_val_missing_both_03.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/ODIR/datasets/missing_tables/ODIR_val_missing_both_03.pt
--------------------------------------------------------------------------------
/ODIR/datasets/missing_tables/ODIR_val_missing_both_04.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/ODIR/datasets/missing_tables/ODIR_val_missing_both_04.pt
--------------------------------------------------------------------------------
/ODIR/datasets/missing_tables/ODIR_val_missing_both_05.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/ODIR/datasets/missing_tables/ODIR_val_missing_both_05.pt
--------------------------------------------------------------------------------
/ODIR/datasets/missing_tables/ODIR_val_missing_both_06.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/ODIR/datasets/missing_tables/ODIR_val_missing_both_06.pt
--------------------------------------------------------------------------------
/ODIR/datasets/missing_tables/ODIR_val_missing_both_07.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/ODIR/datasets/missing_tables/ODIR_val_missing_both_07.pt
--------------------------------------------------------------------------------
/ODIR/datasets/missing_tables/ODIR_val_missing_both_08.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/ODIR/datasets/missing_tables/ODIR_val_missing_both_08.pt
--------------------------------------------------------------------------------
/ODIR/datasets/missing_tables/ODIR_val_missing_both_09.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/ODIR/datasets/missing_tables/ODIR_val_missing_both_09.pt
--------------------------------------------------------------------------------
/ODIR/datasets/missing_tables/ODIR_val_missing_image_07.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/ODIR/datasets/missing_tables/ODIR_val_missing_image_07.pt
--------------------------------------------------------------------------------
/ODIR/datasets/missing_tables/ODIR_val_missing_text_07.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/ODIR/datasets/missing_tables/ODIR_val_missing_text_07.pt
--------------------------------------------------------------------------------
/ODIR/datasets/missing_tables/ODIR_test_missing_image_07.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/ODIR/datasets/missing_tables/ODIR_test_missing_image_07.pt
--------------------------------------------------------------------------------
/ODIR/datasets/missing_tables/ODIR_train_missing_both_01.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/ODIR/datasets/missing_tables/ODIR_train_missing_both_01.pt
--------------------------------------------------------------------------------
/ODIR/datasets/missing_tables/ODIR_train_missing_both_02.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/ODIR/datasets/missing_tables/ODIR_train_missing_both_02.pt
--------------------------------------------------------------------------------
/ODIR/datasets/missing_tables/ODIR_train_missing_both_03.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/ODIR/datasets/missing_tables/ODIR_train_missing_both_03.pt
--------------------------------------------------------------------------------
/ODIR/datasets/missing_tables/ODIR_train_missing_both_04.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/ODIR/datasets/missing_tables/ODIR_train_missing_both_04.pt
--------------------------------------------------------------------------------
/ODIR/datasets/missing_tables/ODIR_train_missing_both_05.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/ODIR/datasets/missing_tables/ODIR_train_missing_both_05.pt
--------------------------------------------------------------------------------
/ODIR/datasets/missing_tables/ODIR_train_missing_both_06.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/ODIR/datasets/missing_tables/ODIR_train_missing_both_06.pt
--------------------------------------------------------------------------------
/ODIR/datasets/missing_tables/ODIR_train_missing_both_07.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/ODIR/datasets/missing_tables/ODIR_train_missing_both_07.pt
--------------------------------------------------------------------------------
/ODIR/datasets/missing_tables/ODIR_train_missing_both_08.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/ODIR/datasets/missing_tables/ODIR_train_missing_both_08.pt
--------------------------------------------------------------------------------
/ODIR/datasets/missing_tables/ODIR_train_missing_both_09.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/ODIR/datasets/missing_tables/ODIR_train_missing_both_09.pt
--------------------------------------------------------------------------------
/ODIR/datasets/missing_tables/ODIR_train_missing_image_07.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/ODIR/datasets/missing_tables/ODIR_train_missing_image_07.pt
--------------------------------------------------------------------------------
/ODIR/datasets/missing_tables/ODIR_train_missing_text_07.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/ODIR/datasets/missing_tables/ODIR_train_missing_text_07.pt
--------------------------------------------------------------------------------
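The ratio suffixes above span `_01` through `_09` for the `both` setting and `_07` for the single-modality settings. A hedged sketch of how such per-sample assignment tables could be generated; the integer encoding (0 = complete, 1 = text missing, 2 = image missing) and the even text/image split for `both` are assumptions for illustration, not the repository's verified format:

```python
# Sketch: build a per-sample missing-modality table for one split.
# Encoding is assumed: 0 = both present, 1 = text missing, 2 = image missing.
import torch


def make_missing_table(n_samples: int, ratio: float, missing_type: str,
                       seed: int = 0) -> torch.Tensor:
    g = torch.Generator().manual_seed(seed)
    table = torch.zeros(n_samples, dtype=torch.long)
    idx = torch.randperm(n_samples, generator=g)[: int(n_samples * ratio)]
    if missing_type == "text":
        table[idx] = 1
    elif missing_type == "image":
        table[idx] = 2
    elif missing_type == "both":
        half = len(idx) // 2  # split the affected samples across modalities
        table[idx[:half]] = 1
        table[idx[half:]] = 2
    return table


# e.g. a 0.7 "both" table, comparable in spirit to ODIR_train_missing_both_07.pt
torch.save(make_missing_table(1000, 0.7, "both"), "example_missing_both_07.pt")
```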
/chestXray/datasets/missing_tables/chestXray_test_missing_both_05.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/chestXray/datasets/missing_tables/chestXray_test_missing_both_05.pt
--------------------------------------------------------------------------------
/chestXray/datasets/missing_tables/chestXray_test_missing_both_07.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/chestXray/datasets/missing_tables/chestXray_test_missing_both_07.pt
--------------------------------------------------------------------------------
/chestXray/datasets/missing_tables/chestXray_test_missing_test_07.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/chestXray/datasets/missing_tables/chestXray_test_missing_test_07.pt
--------------------------------------------------------------------------------
/chestXray/datasets/missing_tables/chestXray_test_missing_text_05.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/chestXray/datasets/missing_tables/chestXray_test_missing_text_05.pt
--------------------------------------------------------------------------------
/chestXray/datasets/missing_tables/chestXray_test_missing_text_07.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/chestXray/datasets/missing_tables/chestXray_test_missing_text_07.pt
--------------------------------------------------------------------------------
/chestXray/datasets/missing_tables/chestXray_test_missing_text_1.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/chestXray/datasets/missing_tables/chestXray_test_missing_text_1.pt
--------------------------------------------------------------------------------
/chestXray/datasets/missing_tables/chestXray_train_missing_text_1.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/chestXray/datasets/missing_tables/chestXray_train_missing_text_1.pt
--------------------------------------------------------------------------------
/chestXray/datasets/missing_tables/chestXray_val_missing_both_07.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/chestXray/datasets/missing_tables/chestXray_val_missing_both_07.pt
--------------------------------------------------------------------------------
/chestXray/datasets/missing_tables/chestXray_val_missing_image_05.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/chestXray/datasets/missing_tables/chestXray_val_missing_image_05.pt
--------------------------------------------------------------------------------
/chestXray/datasets/missing_tables/chestXray_val_missing_image_07.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/chestXray/datasets/missing_tables/chestXray_val_missing_image_07.pt
--------------------------------------------------------------------------------
/chestXray/datasets/missing_tables/chestXray_val_missing_test_07.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/chestXray/datasets/missing_tables/chestXray_val_missing_test_07.pt
--------------------------------------------------------------------------------
/chestXray/datasets/missing_tables/chestXray_val_missing_text_05.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/chestXray/datasets/missing_tables/chestXray_val_missing_text_05.pt
--------------------------------------------------------------------------------
/chestXray/datasets/missing_tables/chestXray_val_missing_text_07.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/chestXray/datasets/missing_tables/chestXray_val_missing_text_07.pt
--------------------------------------------------------------------------------
/chestXray/datasets/missing_tables/chestXray_val_missing_text_1.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/chestXray/datasets/missing_tables/chestXray_val_missing_text_1.pt
--------------------------------------------------------------------------------
/chestXray/datasets/missing_tables/chestXray_test_missing_image_05.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/chestXray/datasets/missing_tables/chestXray_test_missing_image_05.pt
--------------------------------------------------------------------------------
/chestXray/datasets/missing_tables/chestXray_test_missing_image_07.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/chestXray/datasets/missing_tables/chestXray_test_missing_image_07.pt
--------------------------------------------------------------------------------
/chestXray/datasets/missing_tables/chestXray_train_missing_both_07.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/chestXray/datasets/missing_tables/chestXray_train_missing_both_07.pt
--------------------------------------------------------------------------------
/chestXray/datasets/missing_tables/chestXray_train_missing_image_05.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/chestXray/datasets/missing_tables/chestXray_train_missing_image_05.pt
--------------------------------------------------------------------------------
/chestXray/datasets/missing_tables/chestXray_train_missing_image_07.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/chestXray/datasets/missing_tables/chestXray_train_missing_image_07.pt
--------------------------------------------------------------------------------
/chestXray/datasets/missing_tables/chestXray_train_missing_test_07.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/chestXray/datasets/missing_tables/chestXray_train_missing_test_07.pt
--------------------------------------------------------------------------------
/chestXray/datasets/missing_tables/chestXray_train_missing_text_05.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/chestXray/datasets/missing_tables/chestXray_train_missing_text_05.pt
--------------------------------------------------------------------------------
/chestXray/datasets/missing_tables/chestXray_train_missing_text_07.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/chestXray/datasets/missing_tables/chestXray_train_missing_text_07.pt
--------------------------------------------------------------------------------
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/chestXray/vilt/datasets/__pycache__/chestXray_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /chestXray/vilt/datasets/__pycache__/hatememes_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/chestXray/vilt/datasets/__pycache__/hatememes_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /chestXray/vilt/datasets/__pycache__/hatememes_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/chestXray/vilt/datasets/__pycache__/hatememes_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /ODIR/vilt/modules/__pycache__/vision_transformer_prompts.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/ODIR/vilt/modules/__pycache__/vision_transformer_prompts.cpython-37.pyc -------------------------------------------------------------------------------- /ODIR/vilt/modules/__pycache__/vision_transformer_prompts.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/ODIR/vilt/modules/__pycache__/vision_transformer_prompts.cpython-38.pyc -------------------------------------------------------------------------------- /chestXray/vilt/datamodules/__pycache__/food101_datamodule.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/chestXray/vilt/datamodules/__pycache__/food101_datamodule.cpython-37.pyc -------------------------------------------------------------------------------- /chestXray/vilt/datamodules/__pycache__/food101_datamodule.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/chestXray/vilt/datamodules/__pycache__/food101_datamodule.cpython-38.pyc -------------------------------------------------------------------------------- /chestXray/vilt/datamodules/__pycache__/mmimdb_datamodule.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/chestXray/vilt/datamodules/__pycache__/mmimdb_datamodule.cpython-37.pyc -------------------------------------------------------------------------------- /chestXray/vilt/datamodules/__pycache__/mmimdb_datamodule.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/chestXray/vilt/datamodules/__pycache__/mmimdb_datamodule.cpython-38.pyc -------------------------------------------------------------------------------- /chestXray/vilt/datamodules/__pycache__/chestXray_datamodule.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/chestXray/vilt/datamodules/__pycache__/chestXray_datamodule.cpython-37.pyc -------------------------------------------------------------------------------- 
/chestXray/vilt/datamodules/__pycache__/chestXray_datamodule.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/chestXray/vilt/datamodules/__pycache__/chestXray_datamodule.cpython-38.pyc -------------------------------------------------------------------------------- /chestXray/vilt/datamodules/__pycache__/hatememes_datamodule.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/chestXray/vilt/datamodules/__pycache__/hatememes_datamodule.cpython-37.pyc -------------------------------------------------------------------------------- /chestXray/vilt/datamodules/__pycache__/hatememes_datamodule.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/chestXray/vilt/datamodules/__pycache__/hatememes_datamodule.cpython-38.pyc -------------------------------------------------------------------------------- /chestXray/vilt/datamodules/__pycache__/multitask_datamodule.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/chestXray/vilt/datamodules/__pycache__/multitask_datamodule.cpython-37.pyc -------------------------------------------------------------------------------- /chestXray/vilt/datamodules/__pycache__/multitask_datamodule.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/chestXray/vilt/datamodules/__pycache__/multitask_datamodule.cpython-38.pyc -------------------------------------------------------------------------------- /ODIR/vilt/modules/__pycache__/vilt_missing_aware_prompt_module.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/ODIR/vilt/modules/__pycache__/vilt_missing_aware_prompt_module.cpython-37.pyc -------------------------------------------------------------------------------- /ODIR/vilt/modules/__pycache__/vilt_missing_aware_prompt_module.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/ODIR/vilt/modules/__pycache__/vilt_missing_aware_prompt_module.cpython-38.pyc -------------------------------------------------------------------------------- /chestXray/datasets/missing_tables/chestXray_test_missing_text_05_version1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/chestXray/datasets/missing_tables/chestXray_test_missing_text_05_version1.pt -------------------------------------------------------------------------------- /chestXray/datasets/missing_tables/chestXray_train_missing_text_05_version1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhiyiscs/MoRA/HEAD/chestXray/datasets/missing_tables/chestXray_train_missing_text_05_version1.pt -------------------------------------------------------------------------------- /chestXray/datasets/missing_tables/chestXray_val_missing_text_05_version1.pt: -------------------------------------------------------------------------------- 
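The missing-table files condensed above drive the simulated missing-modality settings: the dataset classes index into one table per split, and the two-digit suffix presumably encodes the missing ratio (07 -> 70%). A minimal inspection sketch; that each file holds one torch-saved entry per sample is an assumption here (the authoritative schema lives in the repo's base_dataset.py, which this dump does not include):

import torch

# Hypothetical local path; any of the condensed .pt tables above works the same way.
table = torch.load("datasets/missing_tables/chestXray_train_missing_both_07.pt")
print(type(table))   # assumed: a Tensor or list of per-sample missing-type codes
print(len(table))    # assumed: one entry per training sample
--------------------------------------------------------------------------------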
/ODIR/vilt/datasets/__init__.py: --------------------------------------------------------------------------------
1 | from .mmimdb_dataset import MMIMDBDataset 2 | from .hatememes_dataset import HateMemesDataset 3 | from .food101_dataset import FOOD101Dataset 4 | from .ODIR_dataset import ODIRDataset 5 | --------------------------------------------------------------------------------
/chestXray/vilt/datasets/__init__.py: --------------------------------------------------------------------------------
1 | from .mmimdb_dataset import MMIMDBDataset 2 | from .hatememes_dataset import HateMemesDataset 3 | from .food101_dataset import FOOD101Dataset 4 | from .chestXray_dataset import CHESTXRAYDataset 5 | --------------------------------------------------------------------------------
/ODIR/vilt/Costom_task_guideline: --------------------------------------------------------------------------------
1 | 1. config.py - add the task config and its loss_name 2 | 2. module/vilt_*_module 3 | 3. module/objectives 4 | 4. module/vilt_utils 5 | 5. datasets/__init__.py and datasets/{task}_dataset.py 6 | 6. datamodules/__init__.py and datamodules/{task}_datamodule.py 7 | 7. gadgets/my_metrics.py if a new metric is needed --------------------------------------------------------------------------------
/chestXray/vilt/Costom_task_guideline: -------------------------------------------------------------------------------- (duplicate of /ODIR/vilt/Costom_task_guideline above) --------------------------------------------------------------------------------
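The Costom_task_guideline above is the checklist for wiring a new task into either copy of the package. A hypothetical sketch of steps 5-6 for a task named "newtask" (all "NewTask" names are placeholders; the `names` and `text_column_name` arguments follow the ViLT BaseDataset convention, which this dump does not show, so they are assumptions too):

# datasets/newtask_dataset.py (hypothetical)
from .base_dataset import BaseDataset

class NewTaskDataset(BaseDataset):
    def __init__(self, *args, split="", **kwargs):
        # point the base class at the arrow files written for this task
        super().__init__(*args, **kwargs, names=[f"newtask_{split}"], text_column_name="text")

# datamodules/newtask_datamodule.py (hypothetical), mirroring ODIRDataModule
from vilt.datasets import NewTaskDataset
from .datamodule_base import BaseDataModule

class NewTaskDataModule(BaseDataModule):
    @property
    def dataset_cls(self):
        return NewTaskDataset

    @property
    def dataset_name(self):
        return "newtask"

# plus one import line in datasets/__init__.py and one registry entry,
# _datamodules["newtask"] = NewTaskDataModule, in datamodules/__init__.py
--------------------------------------------------------------------------------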
/ODIR/vilt/transforms/__init__.py: --------------------------------------------------------------------------------
1 | from .pixelbert import ( 2 | pixelbert_transform, 3 | pixelbert_transform_randaug, 4 | ) 5 | 6 | _transforms = { 7 | "pixelbert": pixelbert_transform, 8 | "pixelbert_randaug": pixelbert_transform_randaug, 9 | } 10 | 11 | 12 | def keys_to_transforms(keys: list, size=224): 13 | return [_transforms[key](size=size) for key in keys] 14 | --------------------------------------------------------------------------------
/chestXray/vilt/transforms/__init__.py: -------------------------------------------------------------------------------- (duplicate of /ODIR/vilt/transforms/__init__.py above) --------------------------------------------------------------------------------
/ODIR/scripts/ODIR_training.sh: --------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0 2 | python run.py with data_root="/ssd4tb/datasets/ODIR/" \ 3 | num_gpus=1 \ 4 | num_nodes=1 \ 5 | per_gpu_batchsize=4 \ 6 | task_finetune_ODIR \ 7 | load_path="/ssd4tb/datasets/missing_datasets/pre_trian/vilt_200k_mlm_itm.ckpt" \ 8 | exp_name=ODIR_text_missing07_MSP --------------------------------------------------------------------------------
/ODIR/vilt/datamodules/__init__.py: --------------------------------------------------------------------------------
1 | from .mmimdb_datamodule import MMIMDBDataModule 2 | from .hatememes_datamodule import HateMemesDataModule 3 | from .food101_datamodule import FOOD101DataModule 4 | from .ODIR_datamodule import ODIRDataModule 5 | 6 | _datamodules = { 7 | "mmimdb": MMIMDBDataModule, 8 | "Hatefull_Memes": HateMemesDataModule, 9 | "Food101": FOOD101DataModule, 10 | "ODIR": ODIRDataModule, 11 | } 12 | --------------------------------------------------------------------------------
/chestXray/scripts/chestXray_training.sh: --------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=1 2 | python run.py with data_root="/ssd4tb/datasets/chestXray/" \ 3 | num_gpus=1 \ 4 | num_nodes=1 \ 5 | per_gpu_batchsize=4 \ 6 | task_finetune_chestXray \ 7 | load_path="/ssd4tb/datasets/missing_datasets/pre_trian/vilt_200k_mlm_itm.ckpt" \ 8 | exp_name=chestXray_both_missing07_MAP_new --------------------------------------------------------------------------------
/chestXray/vilt/datamodules/__init__.py: --------------------------------------------------------------------------------
1 | from .mmimdb_datamodule import MMIMDBDataModule 2 | from .hatememes_datamodule import HateMemesDataModule 3 | from .food101_datamodule import FOOD101DataModule 4 | from .chestXray_datamodule import CHESTXRAYDataModule 5 | 6 | _datamodules = { 7 | "mmimdb": MMIMDBDataModule, 8 | "Hatefull_Memes": HateMemesDataModule, 9 | "Food101": FOOD101DataModule, 10 | "chestXray": CHESTXRAYDataModule, 11 | } 12 | --------------------------------------------------------------------------------
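The _datamodules dict above is the registry that the multitask datamodule resolves config keys against. A hedged sketch of that lookup; the real construction happens in MTDataModule/datamodule_base.py, so the exact _config key used here is an assumption:

from vilt.datamodules import _datamodules

def build_datamodules(_config):
    # e.g. _config["datasets"] == ["ODIR"] would select ODIRDataModule
    return [_datamodules[name](_config) for name in _config["datasets"]]
--------------------------------------------------------------------------------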
/chestXray/scripts/chestXray_testing.sh: --------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=2 2 | python run.py with data_root="/ssd4tb/datasets/chestXray/" \ 3 | num_gpus=1 \ 4 | num_nodes=1 \ 5 | per_gpu_batchsize=4 \ 6 | task_finetune_chestXray \ 7 | load_path="result/chestXray_text_missing07_MAP_new_seed0_from_vilt_200k_mlm_itm/version_0/checkpoints/epoch=45-step=34321.ckpt" \ 8 | exp_name=chestXray \ 9 | test_only=True --------------------------------------------------------------------------------
/ODIR/scripts/ODIR_testing.sh: --------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0 2 | python run.py with data_root="/ssd4tb/datasets/ODIR/" \ 3 | num_gpus=1 \ 4 | num_nodes=1 \ 5 | per_gpu_batchsize=4 \ 6 | task_finetune_ODIR \ 7 | load_path="/ssd4tb/Zhiyi/ODIR/result/ODIR_both_missing07_baseline_seed0_from_vilt_200k_mlm_itm/version_0/checkpoints/epoch=18-step=809.ckpt" \ 8 | exp_name=ODIR_both_missing07 \ 9 | test_only=True --------------------------------------------------------------------------------
/ODIR/vilt/datamodules/ODIR_datamodule.py: --------------------------------------------------------------------------------
1 | from vilt.datasets import ODIRDataset 2 | from .datamodule_base import BaseDataModule 3 | 4 | 5 | class ODIRDataModule(BaseDataModule): 6 | def __init__(self, *args, **kwargs): 7 | super().__init__(*args, **kwargs) 8 | 9 | @property 10 | def dataset_cls(self): 11 | return ODIRDataset 12 | 13 | @property 14 | def dataset_name(self): 15 | return "ODIR" 16 | 17 | def setup(self, stage): 18 | super().setup(stage) 19 | --------------------------------------------------------------------------------
/ODIR/vilt/datamodules/mmimdb_datamodule.py: --------------------------------------------------------------------------------
1 | from vilt.datasets import MMIMDBDataset 2 | from .datamodule_base import BaseDataModule 3 | 4 | 5 | class MMIMDBDataModule(BaseDataModule): 6 | def __init__(self, *args, **kwargs): 7 | super().__init__(*args, **kwargs) 8 | 9 | @property 10 | def dataset_cls(self): 11 | return MMIMDBDataset 12 | 13 | @property 14 | def dataset_name(self): 15 | return "mmimdb" 16 | 17 | def setup(self, stage): 18 | super().setup(stage) 19 | --------------------------------------------------------------------------------
/ODIR/vilt/datamodules/food101_datamodule.py: --------------------------------------------------------------------------------
1 | from vilt.datasets import FOOD101Dataset 2 | from .datamodule_base import BaseDataModule 3 | 4 | 5 | class FOOD101DataModule(BaseDataModule): 6 | def __init__(self, *args, **kwargs): 7 | super().__init__(*args, **kwargs) 8 | 9 | @property 10 | def dataset_cls(self): 11 | return FOOD101Dataset 12 | 13 | @property 14 | def dataset_name(self): 15 | return "food101" 16 | 17 | def setup(self, stage): 18 | super().setup(stage) 19 | --------------------------------------------------------------------------------
/chestXray/vilt/datamodules/mmimdb_datamodule.py: -------------------------------------------------------------------------------- (duplicate of /ODIR/vilt/datamodules/mmimdb_datamodule.py above) --------------------------------------------------------------------------------
/chestXray/vilt/datamodules/food101_datamodule.py: -------------------------------------------------------------------------------- (duplicate of /ODIR/vilt/datamodules/food101_datamodule.py above) --------------------------------------------------------------------------------
/ODIR/vilt/datamodules/hatememes_datamodule.py: --------------------------------------------------------------------------------
1 | from vilt.datasets import HateMemesDataset 2 | from .datamodule_base import BaseDataModule 3 | 4 | 5 | class HateMemesDataModule(BaseDataModule): 6 | def __init__(self, *args, **kwargs): 7 | super().__init__(*args, **kwargs) 8 | 9 | @property 10 | def dataset_cls(self): 11 | return HateMemesDataset 12 | 13 | @property 14 | def dataset_name(self): 15 | return "Hatefull_Memes" 16 | 17 | def setup(self, stage): 18 | super().setup(stage) 19 | --------------------------------------------------------------------------------
/chestXray/vilt/datamodules/chestXray_datamodule.py: --------------------------------------------------------------------------------
1 | from vilt.datasets import CHESTXRAYDataset 2 | from .datamodule_base import BaseDataModule 3 | 4 | 5 | class CHESTXRAYDataModule(BaseDataModule): 6 | def __init__(self, *args, **kwargs): 7 | super().__init__(*args, **kwargs) 8 | 9 | @property 10 | def dataset_cls(self): 11 | return CHESTXRAYDataset 12 | 13 | @property 14 | def dataset_name(self): 15 | return "chestXray" 16 | 17 | def setup(self, stage): 18 | super().setup(stage) 19 | --------------------------------------------------------------------------------
/chestXray/vilt/datamodules/hatememes_datamodule.py: -------------------------------------------------------------------------------- (duplicate of /ODIR/vilt/datamodules/hatememes_datamodule.py above) --------------------------------------------------------------------------------
/ODIR/test.py: --------------------------------------------------------------------------------
1 | import pyarrow as pa 2 | 3 | # sanity check: read an arrow file back and report its row count 4 | # (note: this copy under ODIR/ also points at the chestXray arrow files) 5 | train_table = pa.ipc.RecordBatchFileReader(pa.memory_map("/ssd4tb/datasets/chestXray/chestXray_train.arrow", "r")).read_all() 6 | #dev_table = pa.ipc.RecordBatchFileReader(pa.memory_map("/ssd4tb/datasets/chestXray/chestXray_val.arrow", "r")).read_all() 7 | #test_table = pa.ipc.RecordBatchFileReader(pa.memory_map("/ssd4tb/datasets/chestXray/chestXray_test.arrow", "r")).read_all() 8 | 9 | print(len(train_table)) --------------------------------------------------------------------------------
/chestXray/test.py: -------------------------------------------------------------------------------- (duplicate of /ODIR/test.py above; both copies read the chestXray arrow paths) --------------------------------------------------------------------------------
/ODIR/vilt/transforms/pixelbert.py: --------------------------------------------------------------------------------
1 | from .utils import ( 2 | inception_normalize, 3 | MinMaxResize, 4 | ) 5 | from torchvision import transforms 6 | from .randaug import RandAugment 7 | 8 | 9 | def pixelbert_transform(size=800): 10 | longer = int((1333 / 800) * size) 11 | return transforms.Compose( 12 | [ 13 | MinMaxResize(shorter=size, longer=longer), 14 | transforms.ToTensor(), 15 | inception_normalize, 16 | ] 17 | ) 18 | 19 | 20 | def pixelbert_transform_randaug(size=800): 21 | longer = int((1333 / 800) * size) 22 | trs = transforms.Compose( 23 | [ 24 | MinMaxResize(shorter=size, longer=longer), 25 | transforms.ToTensor(), 26 | inception_normalize, 27 | ] 28 | ) 29 | trs.transforms.insert(0, RandAugment(2, 9)) 30 | return trs 31 | --------------------------------------------------------------------------------
/chestXray/vilt/transforms/pixelbert.py: -------------------------------------------------------------------------------- (duplicate of /ODIR/vilt/transforms/pixelbert.py above) --------------------------------------------------------------------------------
/ODIR/vilt/utils/write_hatememes.py: --------------------------------------------------------------------------------
1 | import json, jsonlines 2 | import pandas as pd 3 | import pyarrow as pa 4 | import random 5 | import os 6 | 7 | from tqdm import tqdm 8 | from glob import glob 9 | from collections import defaultdict, Counter 10 | from .glossary import normalize_word 11 | 12 | def make_arrow(root, dataset_root, single_plot=False): 13 | split_sets = ['train', 'dev', 'test'] 14 | 15 | for split in split_sets: 16 | data_list = [] 17 | with jsonlines.open(os.path.join(root, f'data/{split}.jsonl'), 'r') as rfd: 18 | for data in tqdm(rfd): 19 | image_path = os.path.join(root, 'data', data['img']) 20 | 21 | with open(image_path, "rb") as fp: 22 | binary = fp.read() 23 | 24 | text = [data['text']] 25 | label = data['label'] 26 | # text_aug = text_aug_dir['{}.png'.format(data['id'])]  # disabled: text_aug_dir is never defined in this file and text_aug is unused, so the original line raised a NameError 27 | 28 | data = (binary, text, label, split) 29 | data_list.append(data) 30 | 31 | 32 | dataframe = pd.DataFrame( 33 | data_list, 34 | columns=[ 35 | "image", 36 | "text", 37 | "label", 38 | "split", 39 | ], 40 | ) 41 | 42 | table = pa.Table.from_pandas(dataframe) 43 | 44 | os.makedirs(dataset_root, exist_ok=True) 45 | with pa.OSFile(f"{dataset_root}/hatememes_{split}.arrow", "wb") as sink: 46 | with pa.RecordBatchFileWriter(sink, table.schema) as writer: 47 | writer.write_table(table) --------------------------------------------------------------------------------
/chestXray/vilt/utils/write_hatememes.py: -------------------------------------------------------------------------------- (duplicate of /ODIR/vilt/utils/write_hatememes.py above, including the disabled text_aug line) --------------------------------------------------------------------------------
/ODIR/vilt/transforms/utils.py: --------------------------------------------------------------------------------
1 | from torchvision import transforms 2 | from PIL import Image 3 | 4 | 5 | class MinMaxResize: 6 | def __init__(self, shorter=800, longer=1333): 7 | self.min = shorter 8 | self.max = longer 9 | 10 | def __call__(self, x): 11 | w, h = x.size 12 | scale = self.min / min(w, h) 13 | if h < w: 14 | newh, neww = self.min, scale * w 15 | else: 16 | newh, neww = scale * h, self.min 17 | 18 | if max(newh, neww) > self.max: 19 | scale = self.max / max(newh, neww) 20 | newh = newh * scale 21 | neww = neww * scale 22 | 23 | newh, neww = int(newh + 0.5), int(neww + 0.5) 24 | newh, neww = newh // 32 * 32, neww // 32 * 32 25 | 26 | return x.resize((neww, newh), resample=Image.BICUBIC) 27 | 28 | 29 | class UnNormalize(object): 30 | def __init__(self, mean, std): 31 | self.mean = mean 32 | self.std = std 33 | 34 | def __call__(self, tensor): 35 | """ 36 | Args: 37 | tensor (Tensor): Tensor image of size (C, H, W) to be unnormalized. 38 | Returns: 39 | Tensor: Unnormalized image. 40 | """ 41 | for t, m, s in zip(tensor, self.mean, self.std): 42 | t.mul_(s).add_(m) 43 | # inverts the normalize step t.sub_(m).div_(s) 44 | return tensor 45 | 46 | 47 | # Simple max-entropy normalization, as used in the Inception paper 48 | inception_normalize = transforms.Compose( 49 | [transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])] 50 | ) 51 | 52 | # ViT uses the same simple non-biased inception normalization 53 | # https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py#L132 54 | inception_unnormalize = transforms.Compose( 55 | [UnNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])] 56 | ) 57 | --------------------------------------------------------------------------------
/chestXray/vilt/transforms/utils.py: -------------------------------------------------------------------------------- (duplicate of /ODIR/vilt/transforms/utils.py above) --------------------------------------------------------------------------------
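MinMaxResize above implements the detector-style resize used by the pixelbert transforms: the shorter side is scaled to `shorter`, the longer side is capped at `longer`, and both are snapped down to multiples of 32. A quick usage sketch with a synthetic image:

from PIL import Image
from torchvision import transforms
from vilt.transforms.utils import MinMaxResize, inception_normalize

resize = MinMaxResize(shorter=384, longer=640)
img = Image.new("RGB", (1000, 500))          # wide synthetic image
out = resize(img)
print(out.size)                               # (640, 320): longer side capped, both multiples of 32
tensor = inception_normalize(transforms.ToTensor()(out))
--------------------------------------------------------------------------------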
/ODIR/vilt/utils/write_food101.py: --------------------------------------------------------------------------------
1 | import json 2 | import pandas as pd 3 | import pyarrow as pa 4 | import random 5 | import os 6 | 7 | from tqdm import tqdm 8 | from glob import glob 9 | from collections import defaultdict, Counter 10 | from .glossary import normalize_word 11 | 12 | def make_arrow(root, dataset_root, single_plot=False, missing_type=None): 13 | image_root = os.path.join(root, 'images') 14 | 15 | with open(f"{root}/class_idx.json", "r") as fp: 16 | FOOD_CLASS_DICT = json.load(fp) 17 | 18 | with open(f"{root}/text.json", "r") as fp: 19 | text_dir = json.load(fp) 20 | 21 | with open(f"{root}/split.json", "r") as fp: 22 | split_sets = json.load(fp) 23 | 24 | 25 | for split, samples in split_sets.items(): 26 | split_type = 'train' if split != 'test' else 'test' 27 | data_list = [] 28 | for sample in tqdm(samples): 29 | if sample not in text_dir: 30 | print("ignore no text data: ", sample) 31 | continue 32 | cls = sample[:sample.rindex('_')] 33 | label = FOOD_CLASS_DICT[cls] 34 | image_path = os.path.join(image_root, split_type, cls, sample) 35 | 36 | with open(image_path, "rb") as fp: 37 | binary = fp.read() 38 | 39 | text = [text_dir[sample]] 40 | 41 | 42 | data = (binary, text, label, sample, split) 43 | data_list.append(data) 44 | 45 | dataframe = pd.DataFrame( 46 | data_list, 47 | columns=[ 48 | "image", 49 | "text", 50 | "label", 51 | "image_id", 52 | "split", 53 | ], 54 | ) 55 | 56 | table = pa.Table.from_pandas(dataframe) 57 | 58 | os.makedirs(dataset_root, exist_ok=True) 59 | with pa.OSFile(f"{dataset_root}/food101_{split}.arrow", "wb") as sink: 60 | with pa.RecordBatchFileWriter(sink, table.schema) as writer: 61 | writer.write_table(table) --------------------------------------------------------------------------------
/chestXray/vilt/utils/write_food101.py: -------------------------------------------------------------------------------- (duplicate of /ODIR/vilt/utils/write_food101.py above) --------------------------------------------------------------------------------
/ODIR/sort.py: --------------------------------------------------------------------------------
1 | import os 2 | import numpy as np 3 | import torch 4 | 5 | def read_folder(folder): 6 | lis = [] 7 | # files are named <index>.npy, so sort numerically by stem 8 | order = sorted(os.listdir(folder), key=lambda x: int(x.split('.')[0])) 9 | for temp in order: 10 | x = torch.Tensor(np.load(folder + '/' + temp)) 11 | lis.append(x) 12 | lis = torch.stack(lis, dim=0) 13 | return lis 14 | 15 | training_image_embedding = read_folder("training_image_logit") 16 | torch.save(training_image_embedding, "training_image_logit.pt") 17 | 18 | test_image_embedding = read_folder("test_image_logit") 19 | torch.save(test_image_embedding, "test_image_logit.pt") 20 | 21 | 22 | def kl_divergence(P, Q): 23 | # Ensure tensors have the same shape 24 | assert P.size() == Q.size(), "Tensors must have the same shape" 25 | P = torch.softmax(P, dim=-1) 26 | Q = torch.softmax(Q, dim=-1) 27 | # KL(P || Q) = sum(P * (log P - log Q)) 28 | return torch.sum(P * (torch.log(P) - torch.log(Q))) 29 | 30 | 31 | flag = 0 32 | i = 0 33 | index_list = [] 34 | 35 | # sorted once up front; the directory contents do not change inside the loop 36 | order = sorted(os.listdir("test_image_logit"), key=lambda x: int(x.split('.')[0])) 37 | while flag < test_image_embedding.size(0): 38 | if str(i) == order[flag].split('.')[0]: 39 | temp = test_image_embedding[flag] 40 | # Cosine similarity between temp and each tensor in the training set 41 | cos_similarities = [torch.nn.functional.cosine_similarity(temp.unsqueeze(0), s, dim=1) for s in training_image_embedding] 42 | # Index of the training tensor with the maximum similarity 43 | max_index = np.argmax([similarity.item() for similarity in cos_similarities]) 44 | print(max_index) 45 | index_list.append(max_index) 46 | flag += 1 47 | i += 1  # advance the running index too; without this a matched sample was also appended as an identity index on the next pass 48 | else: 49 | index_list.append(i) 50 | i += 1 51 | 52 | np.save("test_text_index.npy", index_list) --------------------------------------------------------------------------------
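The per-sample Python loop over cosine similarities in sort.py above can be collapsed into one matrix product. An equivalent vectorized sketch over the same stacked logit tensors, assuming each saved logit is a 1-D vector (flattened here otherwise):

import torch
import torch.nn.functional as F

def nearest_training_indices(test_emb, train_emb):
    # Row-normalise both sets, then one matmul yields all pairwise cosine similarities.
    test_n = F.normalize(test_emb.flatten(1), dim=1)
    train_n = F.normalize(train_emb.flatten(1), dim=1)
    return (test_n @ train_n.T).argmax(dim=1)   # one nearest training index per test row
--------------------------------------------------------------------------------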
/chestXray/sort.py: -------------------------------------------------------------------------------- (duplicate of /ODIR/sort.py above, with the same fixes applying) --------------------------------------------------------------------------------
/ODIR/vilt/utils/write_mmimdb.py: --------------------------------------------------------------------------------
1 | import json 2 | import pandas as pd 3 | import pyarrow as pa 4 | import random 5 | import os 6 | 7 | from tqdm import tqdm 8 | from glob import glob 9 | from collections import defaultdict, Counter 10 | from .glossary import normalize_word 11 | 12 | def make_arrow(root, dataset_root, single_plot=False, missing_type=None): 13 | GENRE_CLASS = ['Drama', 'Comedy', 'Romance', 'Thriller', 'Crime', 'Action', 'Adventure', 'Horror' 14 | , 'Documentary', 'Mystery', 'Sci-Fi', 'Fantasy', 'Family', 'Biography', 'War', 'History', 'Music', 15 | 'Animation', 'Musical', 'Western', 'Sport', 'Short', 'Film-Noir'] 16 | GENRE_CLASS_DICT = {} 17 | for idx, genre in enumerate(GENRE_CLASS): 18 | GENRE_CLASS_DICT[genre] = idx 19 | 20 | image_root = os.path.join(root, 'images') 21 | label_root = os.path.join(root, 'labels') 22 | 23 | with open(f"{root}/split.json", "r") as fp: 24 | split_sets = json.load(fp) 25 | 26 | 27 | total_genres = [] 28 | for split, samples in split_sets.items(): 29 | data_list = [] 30 | for sample in tqdm(samples): 31 | image_path = os.path.join(image_root, sample+'.jpeg') 32 | label_path = os.path.join(label_root, sample+'.json') 33 | with open(image_path, "rb") as fp: 34 | binary = fp.read() 35 | with open(label_path, "r") as fp: 36 | labels = json.load(fp) 37 | 38 | # There can be more than one plot per movie; 39 | # with single_plot, only the first plot is used 40 | if single_plot: 41 | plots = [labels['plot'][0]] 42 | else: 43 | plots = labels['plot'] 44 | 45 | genres = labels['genres'] 46 | label = [1 if g in genres else 0 for g in GENRE_CLASS_DICT] 47 | data = (binary, plots, label, genres, sample, split) 48 | data_list.append(data) 49 | 50 | dataframe = pd.DataFrame( 51 | data_list, 52 | columns=[ 53 | "image", 54 | "plots", 55 | "label", 56 | "genres", 57 | "image_id", 58 | "split", 59 | ], 60 | ) 61 | 62 | table = pa.Table.from_pandas(dataframe) 63 | 64 | os.makedirs(dataset_root, exist_ok=True) 65 | with pa.OSFile(f"{dataset_root}/mmimdb_{split}.arrow", "wb") as sink: 66 | with pa.RecordBatchFileWriter(sink, table.schema) as writer: 67 | writer.write_table(table) --------------------------------------------------------------------------------
/chestXray/vilt/utils/write_mmimdb.py: -------------------------------------------------------------------------------- (duplicate of /ODIR/vilt/utils/write_mmimdb.py above) --------------------------------------------------------------------------------
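The genre label in write_mmimdb.py above is a 23-way multi-hot vector built by membership tests against GENRE_CLASS_DICT; the same construction spelled out on a toy sample:

GENRE_CLASS = ['Drama', 'Comedy', 'Romance']        # truncated to three classes for illustration
GENRE_CLASS_DICT = {g: i for i, g in enumerate(GENRE_CLASS)}

genres = ['Comedy', 'Romance']
label = [1 if g in genres else 0 for g in GENRE_CLASS_DICT]
print(label)  # [0, 1, 1] -- one slot per class, in dict insertion order
--------------------------------------------------------------------------------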
/ODIR/run.py: --------------------------------------------------------------------------------
1 | import os 2 | import copy 3 | import pytorch_lightning as pl 4 | 5 | from vilt.config import ex 6 | from vilt.modules import ViLTransformerSS 7 | from vilt.datamodules.multitask_datamodule import MTDataModule 8 | 9 | 10 | @ex.automain 11 | def main(_config): 12 | _config = copy.deepcopy(_config) 13 | pl.seed_everything(_config["seed"]) 14 | 15 | dm = MTDataModule(_config, dist=True) 16 | 17 | model = ViLTransformerSS(_config) 18 | exp_name = f'{_config["exp_name"]}' 19 | 20 | os.makedirs(_config["log_dir"], exist_ok=True) 21 | checkpoint_callback = pl.callbacks.ModelCheckpoint( 22 | save_top_k=1, 23 | verbose=True, 24 | monitor="val/the_metric", 25 | mode="max", 26 | save_last=True, 27 | ) 28 | logger = pl.loggers.TensorBoardLogger( 29 | _config["log_dir"], 30 | name=f'{exp_name}_seed{_config["seed"]}_from_{_config["load_path"].split("/")[-1][:-5]}', 31 | ) 32 | 33 | lr_callback = pl.callbacks.LearningRateMonitor(logging_interval="step") 34 | callbacks = [checkpoint_callback, lr_callback] 35 | 36 | num_gpus = ( 37 | _config["num_gpus"] 38 | if isinstance(_config["num_gpus"], int) 39 | else len(_config["num_gpus"]) 40 | ) 41 | 42 | # the effective batch_size is reached by accumulating per-GPU micro-batches 43 | grad_steps = _config["batch_size"] // ( 44 | _config["per_gpu_batchsize"] * num_gpus * _config["num_nodes"] 45 | ) 46 | print(_config["batch_size"], _config["per_gpu_batchsize"], num_gpus, _config["num_nodes"]) 47 | max_steps = _config["max_steps"]  # may be None, meaning no step cap 48 | 49 | trainer = pl.Trainer( 50 | gpus=1,  # hard-coded to a single GPU/node; num_gpus above only scales grad accumulation 51 | num_nodes=1, 52 | precision=_config["precision"], 53 | accelerator="ddp", 54 | benchmark=True, 55 | deterministic=True, 56 | max_epochs=_config["max_epoch"] if max_steps is None else 20, 57 | max_steps=max_steps, 58 | callbacks=callbacks, 59 | logger=logger, 60 | prepare_data_per_node=False, 61 | accumulate_grad_batches=grad_steps, 62 | log_every_n_steps=100, 63 | flush_logs_every_n_steps=100, 64 | resume_from_checkpoint=_config["resume_from"], 65 | weights_summary="top", 66 | fast_dev_run=_config["fast_dev_run"], 67 | val_check_interval=_config["val_check_interval"], 68 | ) 69 | 70 | if not _config["test_only"]: 71 | trainer.fit(model, datamodule=dm) 72 | else: 73 | trainer.test(model, datamodule=dm) --------------------------------------------------------------------------------
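The grad_steps computation in run.py above is plain gradient accumulation: the configured batch_size is reached by accumulating per-GPU micro-batches before each optimiser step. Worked through with the values from the training scripts and an assumed batch_size of 256 (the actual default lives in vilt/config.py, which this dump does not include):

batch_size = 256          # assumed config default; see vilt/config.py
per_gpu_batchsize = 4     # from the training scripts above
num_gpus = num_nodes = 1

grad_steps = batch_size // (per_gpu_batchsize * num_gpus * num_nodes)
print(grad_steps)  # 64 micro-batches accumulated per optimiser update
--------------------------------------------------------------------------------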
/chestXray/run.py: -------------------------------------------------------------------------------- (duplicate of /ODIR/run.py above) --------------------------------------------------------------------------------
/ODIR/vilt/modules/heads.py: --------------------------------------------------------------------------------
1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from transformers.models.bert.modeling_bert import BertPredictionHeadTransform 6 | 7 | 8 | class Pooler(nn.Module): 9 | def __init__(self, hidden_size): 10 | super().__init__() 11 | self.dense = nn.Linear(hidden_size, hidden_size) 12 | self.activation = nn.Tanh() 13 | 14 | def forward(self, hidden_states): 15 | first_token_tensor = hidden_states[:, 0]  # the [CLS] token 16 | pooled_output = self.dense(first_token_tensor) 17 | pooled_output = self.activation(pooled_output) 18 | return pooled_output 19 | 20 | class Prompt_Pooler(nn.Module): 21 | def __init__(self, prompt_num, prompt_length, hidden_size): 22 | super().__init__() 23 | self.prompt_num = prompt_num 24 | self.prompt_length = prompt_length 25 | self.hidden_size = hidden_size 26 | 27 | self.attn = nn.MultiheadAttention(hidden_size, 16, batch_first=True) 28 | 29 | self.dense = nn.Linear(hidden_size, hidden_size) 30 | self.activation = nn.Tanh() 31 | 32 | def forward(self, prompt, hidden_states): 33 | first_token_tensor = hidden_states[:, 0:1]  # [CLS] kept 3-D for the concat below 34 | prompt_token = prompt.view(prompt.size(0), self.prompt_num * self.prompt_length, self.hidden_size) 35 | 36 | attn_input = torch.cat([first_token_tensor, prompt_token], dim=1) 37 | 38 | attn_output = self.attn(attn_input, attn_input, attn_input)  # self-attention over [CLS] + prompt tokens 39 | 40 | pooled_output = self.dense(attn_output[0][:, 0] + hidden_states[:, 0])  # residual with the raw [CLS] 41 | pooled_output = self.activation(pooled_output) 42 | return pooled_output 43 | 44 | 45 | class ITMHead(nn.Module): 46 | def __init__(self, hidden_size): 47 | super().__init__() 48 | self.fc = nn.Linear(hidden_size, 2) 49 | 50 | def forward(self, x): 51 | x = self.fc(x) 52 | return x 53 | 54 | 55 | class MLMHead(nn.Module): 56 | def __init__(self, config, weight=None): 57 | super().__init__() 58 | self.transform = BertPredictionHeadTransform(config) 59 | self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 60 | self.bias = nn.Parameter(torch.zeros(config.vocab_size)) 61 | if weight is not None: 62 | self.decoder.weight = weight 63 | 64 | def forward(self, x): 65 | x = self.transform(x) 66 | x = self.decoder(x) + self.bias 67 | return x 68 | 69 | 70 | class MPPHead(nn.Module): 71 | def __init__(self, config): 72 | super().__init__() 73 | self.transform = BertPredictionHeadTransform(config) 74 | self.decoder = nn.Linear(config.hidden_size, 256 * 3) 75 | 76 | def forward(self, x): 77 | x = self.transform(x) 78 | x = self.decoder(x) 79 | return x 80 | --------------------------------------------------------------------------------
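Pooler above is the standard BERT-style pooler: project the [CLS] token and squash with tanh. A shape-level usage sketch:

import torch
from vilt.modules.heads import Pooler  # importable from either copy of the package

hidden = torch.randn(2, 40, 768)   # (batch, sequence, hidden)
pooler = Pooler(hidden_size=768)
print(pooler(hidden).shape)        # torch.Size([2, 768]) -- one pooled vector per sample
--------------------------------------------------------------------------------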
65 |         x = self.transform(x)
66 |         x = self.decoder(x) + self.bias
67 |         return x
68 |
69 |
70 | class MPPHead(nn.Module):
71 |     def __init__(self, config):
72 |         super().__init__()
73 |         self.transform = BertPredictionHeadTransform(config)
74 |         self.decoder = nn.Linear(config.hidden_size, 256 * 3)  # 256-way prediction for each of the 3 RGB channels
75 |
76 |     def forward(self, x):
77 |         x = self.transform(x)
78 |         x = self.decoder(x)
79 |         return x
80 |
--------------------------------------------------------------------------------
/chestXray/vilt/modules/heads.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from transformers.models.bert.modeling_bert import BertPredictionHeadTransform
6 |
7 |
8 | class Pooler(nn.Module):
9 |     def __init__(self, hidden_size):
10 |         super().__init__()
11 |         self.dense = nn.Linear(hidden_size, hidden_size)
12 |         self.activation = nn.Tanh()
13 |
14 |     def forward(self, hidden_states):
15 |         first_token_tensor = hidden_states[:, 0]
16 |         pooled_output = self.dense(first_token_tensor)
17 |         pooled_output = self.activation(pooled_output)
18 |         return pooled_output
19 |
20 | class Prompt_Pooler(nn.Module):
21 |     def __init__(self, prompt_num, prompt_length, hidden_size):
22 |         super().__init__()
23 |         self.prompt_num = prompt_num
24 |         self.prompt_length = prompt_length
25 |         self.hidden_size = hidden_size
26 |
27 |         self.attn = nn.MultiheadAttention(hidden_size, 16, batch_first=True)  # 16-head self-attention over [CLS] + prompt tokens
28 |
29 |         self.dense = nn.Linear(hidden_size, hidden_size)
30 |         self.activation = nn.Tanh()
31 |
32 |     def forward(self, prompt, hidden_states):
33 |         first_token_tensor = hidden_states[:, 0:1]
34 |         prompt_token = prompt.view(prompt.size(0), self.prompt_num * self.prompt_length, self.hidden_size)
35 |
36 |         attn_input = torch.cat([first_token_tensor, prompt_token], dim=1)
37 |
38 |         attn_output = self.attn(attn_input, attn_input, attn_input)  # returns (output, attention weights)
39 |
40 |         pooled_output = self.dense(attn_output[0][:, 0] + hidden_states[:, 0])  # attended [CLS] plus a residual to the raw [CLS]
41 |         pooled_output = self.activation(pooled_output)
42 |         return pooled_output
43 |
44 |
45 | class ITMHead(nn.Module):
46 |     def __init__(self, hidden_size):
47 |         super().__init__()
48 |         self.fc = nn.Linear(hidden_size, 2)
49 |
50 |     def forward(self, x):
51 |         x = self.fc(x)
52 |         return x
53 |
54 |
55 | class MLMHead(nn.Module):
56 |     def __init__(self, config, weight=None):
57 |         super().__init__()
58 |         self.transform = BertPredictionHeadTransform(config)
59 |         self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
60 |         self.bias = nn.Parameter(torch.zeros(config.vocab_size))
61 |         if weight is not None:
62 |             self.decoder.weight = weight
63 |
64 |     def forward(self, x):
65 |         x = self.transform(x)
66 |         x = self.decoder(x) + self.bias
67 |         return x
68 |
69 |
70 | class MPPHead(nn.Module):
71 |     def __init__(self, config):
72 |         super().__init__()
73 |         self.transform = BertPredictionHeadTransform(config)
74 |         self.decoder = nn.Linear(config.hidden_size, 256 * 3)  # 256-way prediction for each of the 3 RGB channels
75 |
76 |     def forward(self, x):
77 |         x = self.transform(x)
78 |         x = self.decoder(x)
79 |         return x
80 |
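Prompt_Pooler attends from the [CLS] token over the flattened prompt tokens and adds a residual back to the raw [CLS]. A shape sketch (hypothetical sizes; any hidden_size divisible by the 16 attention heads works):

    import torch
    from vilt.modules.heads import Prompt_Pooler

    pooler = Prompt_Pooler(prompt_num=2, prompt_length=8, hidden_size=768)
    prompt = torch.randn(4, 2, 8, 768)       # (batch, prompt_num, prompt_length, hidden)
    hidden_states = torch.randn(4, 40, 768)  # (batch, seq_len, hidden)
    pooled = pooler(prompt, hidden_states)
    print(pooled.shape)                      # torch.Size([4, 768])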
--------------------------------------------------------------------------------
/ODIR/make_ODIR_arrow.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import argparse
4 | import json
5 | import pyarrow as pa
6 | import random
7 | import os
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument('--root', default='./datasets', type=str, help='Root of datasets')
10 | args = parser.parse_args()
11 |
12 | def make_arrow(root):
13 |
14 |
15 |     # Path to the ODIR annotation spreadsheet under the dataset root
16 |     file_path = root + "/" + 'ODIR-5K_Training_Annotations(Updated)_V2.xlsx'  # use the root argument, not the global args
17 |
18 |     # Reading the Excel file
19 |     df_reports = pd.read_excel(file_path, engine='openpyxl')
20 |
21 |     # df_reports holds one annotated patient per row
22 |     num_samples = df_reports.shape[0]
23 |     print("num_samples:", num_samples)
24 |     # Create a random array of floats between 0 and 1
25 |     random_floats = np.random.rand(num_samples)
26 |     # Assign 'train' for 80%, 'val' for 10%, and 'test' for the remaining 10%
27 |     split_values = np.where(random_floats < 0.8, 'train', np.where(random_floats < 0.9, 'val', 'test'))
28 |     # Add the 'split' column to the DataFrame
29 |     df_reports['split'] = split_values
30 |
31 |     split_types = ['train', 'val', 'test']
32 |
33 |
34 |     for split_type in split_types:
35 |         data_list = []
36 |         df = df_reports[df_reports['split'] == split_type]
37 |         for index, sample in df.iterrows():
38 |             left_image_path = root + "/" + sample['Left-Fundus']
39 |             right_image_path = root + "/" + sample['Right-Fundus']
40 |             with open(left_image_path, "rb") as fp:
41 |                 left_binary = fp.read()
42 |             with open(right_image_path, "rb") as fp:
43 |                 right_binary = fp.read()
44 |
45 |             label = [int(sample["N"]), int(sample["D"]), int(sample["G"]), int(sample["C"]), int(sample["A"]), int(sample["H"]), int(sample["M"]), int(sample["O"])]
46 |
47 |
48 |             # Convert NaN to empty strings
49 |             Left_text = str(sample['Left-Diagnostic Keywords']) if pd.notnull(sample['Left-Diagnostic Keywords']) else ''
50 |             Right_text = str(sample['Right-Diagnostic Keywords']) if pd.notnull(sample['Right-Diagnostic Keywords']) else ''
51 |             # Now concatenate the left- and right-eye diagnostic keywords
52 |             text = ["Left_text:" + Left_text + " Right_text:" + Right_text]  # keep a space so the two fields do not run together
53 |
54 |             split = sample['split']
55 |
56 |             data = (left_binary, right_binary, text, label, split)
57 |             data_list.append(data)
58 |
59 |         dataframe = pd.DataFrame(
60 |             data_list,
61 |             columns=[
62 |                 "left_image",
63 |                 "right_image",
64 |                 "text",
65 |                 "label",
66 |                 "split",
67 |             ],
68 |         )
69 |
70 |         print(dataframe.shape)
71 |
72 |         table = pa.Table.from_pandas(dataframe)
73 |
74 |         with pa.OSFile(f"{root}/ODIR_{split_type}.arrow", "wb") as sink:
75 |             with pa.RecordBatchFileWriter(sink, table.schema) as writer:
76 |                 writer.write_table(table)
77 |
78 |
79 | make_arrow(args.root)
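The writer above produces one Arrow IPC file per split. Reading a file back with pyarrow is a quick sanity check; a sketch, assuming the default --root:

    import pyarrow as pa

    # Memory-map one split and load it as a Table
    with pa.memory_map("./datasets/ODIR_train.arrow", "r") as source:
        table = pa.ipc.open_file(source).read_all()
    print(table.num_rows, table.column_names)  # columns: left_image, right_image, text, label, split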
--------------------------------------------------------------------------------
/ODIR/vilt/datamodules/multitask_datamodule.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | from pytorch_lightning import LightningDataModule
4 | from torch.utils.data import DataLoader
5 | from torch.utils.data.dataset import ConcatDataset
6 | from torch.utils.data.distributed import DistributedSampler
7 |
8 | from . import _datamodules
9 |
10 |
11 | class MTDataModule(LightningDataModule):
12 |     def __init__(self, _config, dist=False):
13 |         datamodule_keys = _config["datasets"]
14 |         assert len(datamodule_keys) > 0
15 |
16 |         super().__init__()
17 |
18 |         self.dm_keys = datamodule_keys
19 |         self.dm_dicts = {key: _datamodules[key](_config) for key in datamodule_keys}
20 |         self.dms = [v for k, v in self.dm_dicts.items()]
21 |
22 |         self.batch_size = self.dms[0].batch_size
23 |         self.vocab_size = self.dms[0].vocab_size
24 |         self.num_workers = self.dms[0].num_workers
25 |
26 |         self.dist = dist
27 |
28 |     def prepare_data(self):
29 |         for dm in self.dms:
30 |             dm.prepare_data()
31 |
32 |     def setup(self, stage):
33 |         for dm in self.dms:
34 |             dm.setup(stage)
35 |
36 |         self.train_dataset = ConcatDataset([dm.train_dataset for dm in self.dms])
37 |         self.val_dataset = ConcatDataset([dm.val_dataset for dm in self.dms])
38 |         self.test_dataset = ConcatDataset([dm.test_dataset for dm in self.dms])
39 |         self.tokenizer = self.dms[0].tokenizer
40 |
41 |         self.train_collate = functools.partial(
42 |             self.dms[0].train_dataset.collate, mlm_collator=self.dms[0].mlm_collator,
43 |         )
44 |         self.val_collate = functools.partial(
45 |             self.dms[0].val_dataset.collate, mlm_collator=self.dms[0].mlm_collator,
46 |         )
47 |         self.test_collate = functools.partial(
48 |             self.dms[0].test_dataset.collate, mlm_collator=self.dms[0].mlm_collator,
49 |         )
50 |
51 |         if self.dist:
52 |             self.train_sampler = DistributedSampler(self.train_dataset, shuffle=True)
53 |             self.val_sampler = DistributedSampler(self.val_dataset, shuffle=True)  # NOTE: validation batches are shuffled; aggregate metrics are unaffected
54 |             self.test_sampler = DistributedSampler(self.test_dataset, shuffle=False)
55 |         else:
56 |             self.train_sampler = None
57 |             self.val_sampler = None
58 |             self.test_sampler = None
59 |
60 |     def train_dataloader(self):
61 |         loader = DataLoader(
62 |             self.train_dataset,
63 |             batch_size=self.batch_size,
64 |             sampler=self.train_sampler,
65 |             num_workers=self.num_workers,
66 |             collate_fn=self.train_collate,
67 |         )
68 |         return loader
69 |
70 |     def val_dataloader(self, batch_size=None):
71 |         loader = DataLoader(
72 |             self.val_dataset,
73 |             batch_size=batch_size if batch_size is not None else self.batch_size,
74 |             sampler=self.val_sampler,
75 |             num_workers=self.num_workers,
76 |             collate_fn=self.val_collate,
77 |         )
78 |         return loader
79 |
80 |     def test_dataloader(self):
81 |         loader = DataLoader(
82 |             self.test_dataset,
83 |             batch_size=self.batch_size,
84 |             sampler=self.test_sampler,
85 |             num_workers=self.num_workers,
86 |             collate_fn=self.test_collate,
87 |         )
88 |         return loader
89 |
--------------------------------------------------------------------------------
/chestXray/vilt/datamodules/multitask_datamodule.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | from pytorch_lightning import LightningDataModule
4 | from torch.utils.data import DataLoader
5 | from torch.utils.data.dataset import ConcatDataset
6 | from torch.utils.data.distributed import DistributedSampler
7 |
8 | from .
import _datamodules 9 | 10 | 11 | class MTDataModule(LightningDataModule): 12 | def __init__(self, _config, dist=False): 13 | datamodule_keys = _config["datasets"] 14 | assert len(datamodule_keys) > 0 15 | 16 | super().__init__() 17 | 18 | self.dm_keys = datamodule_keys 19 | self.dm_dicts = {key: _datamodules[key](_config) for key in datamodule_keys} 20 | self.dms = [v for k, v in self.dm_dicts.items()] 21 | 22 | self.batch_size = self.dms[0].batch_size 23 | self.vocab_size = self.dms[0].vocab_size 24 | self.num_workers = self.dms[0].num_workers 25 | 26 | self.dist = dist 27 | 28 | def prepare_data(self): 29 | for dm in self.dms: 30 | dm.prepare_data() 31 | 32 | def setup(self, stage): 33 | for dm in self.dms: 34 | dm.setup(stage) 35 | 36 | self.train_dataset = ConcatDataset([dm.train_dataset for dm in self.dms]) 37 | self.val_dataset = ConcatDataset([dm.val_dataset for dm in self.dms]) 38 | self.test_dataset = ConcatDataset([dm.test_dataset for dm in self.dms]) 39 | self.tokenizer = self.dms[0].tokenizer 40 | 41 | self.train_collate = functools.partial( 42 | self.dms[0].train_dataset.collate, mlm_collator=self.dms[0].mlm_collator, 43 | ) 44 | self.val_collate = functools.partial( 45 | self.dms[0].val_dataset.collate, mlm_collator=self.dms[0].mlm_collator, 46 | ) 47 | self.test_collate = functools.partial( 48 | self.dms[0].test_dataset.collate, mlm_collator=self.dms[0].mlm_collator, 49 | ) 50 | 51 | if self.dist: 52 | self.train_sampler = DistributedSampler(self.train_dataset, shuffle=True) 53 | self.val_sampler = DistributedSampler(self.val_dataset, shuffle=True) 54 | self.test_sampler = DistributedSampler(self.test_dataset, shuffle=False) 55 | else: 56 | self.train_sampler = None 57 | self.val_sampler = None 58 | self.test_sampler = None 59 | 60 | def train_dataloader(self): 61 | loader = DataLoader( 62 | self.train_dataset, 63 | batch_size=self.batch_size, 64 | sampler=self.train_sampler, 65 | num_workers=self.num_workers, 66 | collate_fn=self.train_collate, 67 | ) 68 | return loader 69 | 70 | def val_dataloader(self, batch_size=None): 71 | loader = DataLoader( 72 | self.val_dataset, 73 | batch_size=batch_size if batch_size is not None else self.batch_size, 74 | sampler=self.val_sampler, 75 | num_workers=self.num_workers, 76 | collate_fn=self.val_collate, 77 | ) 78 | return loader 79 | 80 | def test_dataloader(self): 81 | loader = DataLoader( 82 | self.test_dataset, 83 | batch_size=self.batch_size, 84 | sampler=self.test_sampler, 85 | num_workers=self.num_workers, 86 | collate_fn=self.test_collate, 87 | ) 88 | return loader 89 | -------------------------------------------------------------------------------- /chestXray/make_chestXray_arrow.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import argparse 4 | import json 5 | import pyarrow as pa 6 | import random 7 | import os 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--root', default='./datasets', type=str, help='Root of datasets') 10 | args = parser.parse_args() 11 | 12 | def make_arrow(root): 13 | CLASS = ['Lung', 'Opacity', 'Cardiomegaly', 'Calcinosis', 'Pulmonary Atelectasis', 'Calcified Granuloma', 14 | 'Thoracic Vertebrae', 'Cicatrix', 'Spine', 'Markings', 'Pleural Effusion', 'Aorta', 'Diaphragm', 15 | 'Density', 'Atherosclerosis', 'Deformity', 'Airspace Disease', 'Catheters, Indwelling', 'Scoliosis', 'Nodule'] 16 | CLASS_DICT = {} 17 | for idx, genre in enumerate(CLASS): 18 | CLASS_DICT[genre] = idx 19 | 20 | df_reports = 
pd.read_csv(os.path.join(root, 'indiana_reports.csv'))  # join paths robustly; root need not end with '/'
21 |     df_projections = pd.read_csv(os.path.join(root, 'indiana_projections.csv'))
22 |     frontal_projections = df_projections[df_projections['projection'] == 'Frontal']
23 |     combined_df = pd.merge(frontal_projections, df_reports, on='uid')
24 |     cleaned_df = combined_df.dropna(subset=['findings', 'impression'], how='all')
25 |
26 |     # 'cleaned_df' is the DataFrame after removing rows with neither findings nor impression
27 |     num_samples = cleaned_df.shape[0]
28 |     # Create a random array of floats between 0 and 1
29 |     random_floats = np.random.rand(num_samples)
30 |     # Assign 'train' for 80%, 'val' for 10%, and 'test' for the remaining 10%
31 |     split_values = np.where(random_floats < 0.8, 'train', np.where(random_floats < 0.9, 'val', 'test'))
32 |     # Add the 'split' column to the DataFrame
33 |     cleaned_df['split'] = split_values
34 |
35 |     split_types = ['train', 'val', 'test']
36 |
37 |
38 |     for split_type in split_types:
39 |         data_list = []
40 |         df = cleaned_df[cleaned_df['split'] == split_type]
41 |         for index, sample in df.iterrows():
42 |             image_path = os.path.join(root, "images/images_normalized", sample['filename'])
43 |             with open(image_path, "rb") as fp:
44 |                 binary = fp.read()
45 |
46 |             problems = sample['Problems'].split(";")
47 |             if problems[0] == "normal":
48 |                 label = [0] * len(CLASS)  # all-negative vector for normal studies
49 |             else:
50 |                 label = [1 if p in problems else 0 for p in CLASS]
51 |
52 |             # Convert NaN to empty strings
53 |             findings = str(sample['findings']) if pd.notnull(sample['findings']) else ''
54 |             impression = str(sample['impression']) if pd.notnull(sample['impression']) else ''
55 |             # Now concatenate findings and impression
56 |             text = [findings + " " + impression]  # keep a space so the two sections do not run together
57 |
58 |             split = sample['split']
59 |
60 |             data = (binary, text, label, split)
61 |             data_list.append(data)
62 |
63 |         dataframe = pd.DataFrame(
64 |             data_list,
65 |             columns=[
66 |                 "image",
67 |                 "text",
68 |                 "label",
69 |                 "split",
70 |             ],
71 |         )
72 |
73 |         print(dataframe.shape)
74 |
75 |         table = pa.Table.from_pandas(dataframe)
76 |
77 |         with pa.OSFile(f"{root}/chestXray_{split_type}.arrow", "wb") as sink:
78 |             with pa.RecordBatchFileWriter(sink, table.schema) as writer:
79 |                 writer.write_table(table)
80 |
81 |
82 | make_arrow(args.root)
--------------------------------------------------------------------------------
/ODIR/vilt/datasets/hatememes_dataset.py:
--------------------------------------------------------------------------------
1 | from .base_dataset import BaseDataset
2 | import torch
3 | import random, os
4 |
5 | class HateMemesDataset(BaseDataset):
6 |     def __init__(self, *args, split="", missing_info={}, **kwargs):
7 |         assert split in ["train", "val", "test"]
8 |         self.split = split
9 |
10 |         if split == "train":
11 |             names = ["hatememes_train"]
12 |         elif split == "val":
13 |             names = ["hatememes_dev"]
14 |         elif split == "test":
15 |             names = ["hatememes_test"]
16 |
17 |         super().__init__(
18 |             *args,
19 |             **kwargs,
20 |             names=names,
21 |             text_column_name="text",
22 |             remove_duplicate=False,
23 |         )
24 |
25 |         # missing modality control
26 |         self.simulate_missing = missing_info['simulate_missing']
27 |         missing_ratio = missing_info['ratio'][split]
28 |         mratio = str(missing_ratio).replace('.','')
29 |         missing_type = missing_info['type'][split]
30 |         both_ratio = missing_info['both_ratio']
31 |         missing_table_root = missing_info['missing_table_root']
32 |         missing_table_name = f'{names[0]}_missing_{missing_type}_{mratio}.pt'
33 |         missing_table_path = os.path.join(missing_table_root, missing_table_name)
34 |
35 |         # use image data to formulate missing table
36 |
total_num = len(self.table['image']) 37 | 38 | if os.path.exists(missing_table_path): 39 | missing_table = torch.load(missing_table_path) 40 | if len(missing_table) != total_num: 41 | print('missing table mismatched!') 42 | exit() 43 | else: 44 | missing_table = torch.zeros(total_num) 45 | 46 | if missing_ratio > 0: 47 | missing_index = random.sample(range(total_num), int(total_num*missing_ratio)) 48 | 49 | if missing_type == 'text': 50 | missing_table[missing_index] = 1 51 | elif missing_type == 'image': 52 | missing_table[missing_index] = 2 53 | elif missing_type == 'both': 54 | missing_table[missing_index] = 1 55 | missing_index_image = random.sample(missing_index, int(len(missing_index)*both_ratio)) 56 | missing_table[missing_index_image] = 2 57 | 58 | torch.save(missing_table, missing_table_path) 59 | 60 | self.missing_table = missing_table 61 | 62 | def __getitem__(self, index): 63 | # index -> pair data index 64 | # image_index -> image index in table 65 | # question_index -> plot index in texts of the given image 66 | image_index, question_index = self.index_mapper[index] 67 | 68 | # For the case of training with modality-complete data 69 | # Simulate missing modality with random assign the missing type of samples 70 | simulate_missing_type = 0 71 | if self.split == 'train' and self.simulate_missing and self.missing_table[image_index] == 0: 72 | simulate_missing_type = random.choice([0,1,2]) 73 | 74 | image_tensor = self.get_image(index)["image"] 75 | 76 | # missing image, dummy image is all-one image 77 | if self.missing_table[image_index] == 2 or simulate_missing_type == 2: 78 | for idx in range(len(image_tensor)): 79 | image_tensor[idx] = torch.ones(image_tensor[idx].size()).float() 80 | 81 | #missing text, dummy text is '' 82 | if self.missing_table[image_index] == 1 or simulate_missing_type == 1: 83 | text = '' 84 | encoding = self.tokenizer( 85 | text, 86 | padding="max_length", 87 | truncation=True, 88 | max_length=self.max_text_len, 89 | return_special_tokens_mask=True, 90 | ) 91 | text = (text, encoding) 92 | else: 93 | text = self.get_text(index)["text"] 94 | 95 | 96 | labels = self.table["label"][image_index].as_py() 97 | 98 | return { 99 | "image": image_tensor, 100 | "text": text, 101 | "label": labels, 102 | "missing_type": self.missing_table[image_index].item()+simulate_missing_type, 103 | } 104 | -------------------------------------------------------------------------------- /ODIR/vilt/datasets/mmimdb_dataset.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import BaseDataset 2 | import torch 3 | import random 4 | import os 5 | 6 | class MMIMDBDataset(BaseDataset): 7 | def __init__(self, *args, split="", missing_info={}, **kwargs): 8 | assert split in ["train", "val", "test"] 9 | self.split = split 10 | 11 | if split == "train": 12 | names = ["mmimdb_train"] 13 | elif split == "val": 14 | names = ["mmimdb_dev"] 15 | elif split == "test": 16 | names = ["mmimdb_test"] 17 | 18 | super().__init__( 19 | *args, 20 | **kwargs, 21 | names=names, 22 | text_column_name="plots", 23 | remove_duplicate=False, 24 | ) 25 | 26 | # missing modality control 27 | self.simulate_missing = missing_info['simulate_missing'] 28 | missing_ratio = missing_info['ratio'][split] 29 | mratio = str(missing_ratio).replace('.','') 30 | missing_type = missing_info['type'][split] 31 | both_ratio = missing_info['both_ratio'] 32 | missing_table_root = missing_info['missing_table_root'] 33 | missing_table_name = 
f'{names[0]}_missing_{missing_type}_{mratio}.pt' 34 | missing_table_path = os.path.join(missing_table_root, missing_table_name) 35 | 36 | # use image data to formulate missing table 37 | total_num = len(self.table['image']) 38 | 39 | if os.path.exists(missing_table_path): 40 | missing_table = torch.load(missing_table_path) 41 | if len(missing_table) != total_num: 42 | print('missing table mismatched!') 43 | exit() 44 | else: 45 | missing_table = torch.zeros(total_num) 46 | 47 | if missing_ratio > 0: 48 | missing_index = random.sample(range(total_num), int(total_num*missing_ratio)) 49 | 50 | if missing_type == 'text': 51 | missing_table[missing_index] = 1 52 | elif missing_type == 'image': 53 | missing_table[missing_index] = 2 54 | elif missing_type == 'both': 55 | missing_table[missing_index] = 1 56 | missing_index_image = random.sample(missing_index, int(len(missing_index)*both_ratio)) 57 | missing_table[missing_index_image] = 2 58 | 59 | torch.save(missing_table, missing_table_path) 60 | 61 | self.missing_table = missing_table 62 | 63 | def __getitem__(self, index): 64 | # index -> pair data index 65 | # image_index -> image index in table 66 | # question_index -> plot index in texts of the given image 67 | image_index, question_index = self.index_mapper[index] 68 | 69 | # For the case of training with modality-complete data 70 | # Simulate missing modality with random assign the missing type of samples 71 | simulate_missing_type = 0 72 | if self.split == 'train' and self.simulate_missing and self.missing_table[image_index] == 0: 73 | simulate_missing_type = random.choice([0,1,2]) 74 | 75 | image_tensor = self.get_image(index)["image"] 76 | 77 | # missing image, dummy image is all-one image 78 | if self.missing_table[image_index] == 2 or simulate_missing_type == 2: 79 | for idx in range(len(image_tensor)): 80 | image_tensor[idx] = torch.ones(image_tensor[idx].size()).float() 81 | 82 | #missing text, dummy text is '' 83 | if self.missing_table[image_index] == 1 or simulate_missing_type == 1: 84 | text = '' 85 | encoding = self.tokenizer( 86 | text, 87 | padding="max_length", 88 | truncation=True, 89 | max_length=self.max_text_len, 90 | return_special_tokens_mask=True, 91 | ) 92 | text = (text, encoding) 93 | else: 94 | text = self.get_text(index)["text"] 95 | 96 | 97 | labels = self.table["label"][image_index].as_py() 98 | 99 | return { 100 | "image": image_tensor, 101 | "text": text, 102 | "label": labels, 103 | "missing_type": self.missing_table[image_index].item()+simulate_missing_type, 104 | } 105 | -------------------------------------------------------------------------------- /chestXray/vilt/datasets/hatememes_dataset.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import BaseDataset 2 | import torch 3 | import random, os 4 | 5 | class HateMemesDataset(BaseDataset): 6 | def __init__(self, *args, split="", missing_info={}, **kwargs): 7 | assert split in ["train", "val", "test"] 8 | self.split = split 9 | 10 | if split == "train": 11 | names = ["hatememes_train"] 12 | elif split == "val": 13 | names = ["hatememes_dev"] 14 | elif split == "test": 15 | names = ["hatememes_test"] 16 | 17 | super().__init__( 18 | *args, 19 | **kwargs, 20 | names=names, 21 | text_column_name="text", 22 | remove_duplicate=False, 23 | ) 24 | 25 | # missing modality control 26 | self.simulate_missing = missing_info['simulate_missing'] 27 | missing_ratio = missing_info['ratio'][split] 28 | mratio = str(missing_ratio).replace('.','') 29 | missing_type 
= missing_info['type'][split] 30 | both_ratio = missing_info['both_ratio'] 31 | missing_table_root = missing_info['missing_table_root'] 32 | missing_table_name = f'{names[0]}_missing_{missing_type}_{mratio}.pt' 33 | missing_table_path = os.path.join(missing_table_root, missing_table_name) 34 | 35 | # use image data to formulate missing table 36 | total_num = len(self.table['image']) 37 | 38 | if os.path.exists(missing_table_path): 39 | missing_table = torch.load(missing_table_path) 40 | if len(missing_table) != total_num: 41 | print('missing table mismatched!') 42 | exit() 43 | else: 44 | missing_table = torch.zeros(total_num) 45 | 46 | if missing_ratio > 0: 47 | missing_index = random.sample(range(total_num), int(total_num*missing_ratio)) 48 | 49 | if missing_type == 'text': 50 | missing_table[missing_index] = 1 51 | elif missing_type == 'image': 52 | missing_table[missing_index] = 2 53 | elif missing_type == 'both': 54 | missing_table[missing_index] = 1 55 | missing_index_image = random.sample(missing_index, int(len(missing_index)*both_ratio)) 56 | missing_table[missing_index_image] = 2 57 | 58 | torch.save(missing_table, missing_table_path) 59 | 60 | self.missing_table = missing_table 61 | 62 | def __getitem__(self, index): 63 | # index -> pair data index 64 | # image_index -> image index in table 65 | # question_index -> plot index in texts of the given image 66 | image_index, question_index = self.index_mapper[index] 67 | 68 | # For the case of training with modality-complete data 69 | # Simulate missing modality with random assign the missing type of samples 70 | simulate_missing_type = 0 71 | if self.split == 'train' and self.simulate_missing and self.missing_table[image_index] == 0: 72 | simulate_missing_type = random.choice([0,1,2]) 73 | 74 | image_tensor = self.get_image(index)["image"] 75 | 76 | # missing image, dummy image is all-one image 77 | if self.missing_table[image_index] == 2 or simulate_missing_type == 2: 78 | for idx in range(len(image_tensor)): 79 | image_tensor[idx] = torch.ones(image_tensor[idx].size()).float() 80 | 81 | #missing text, dummy text is '' 82 | if self.missing_table[image_index] == 1 or simulate_missing_type == 1: 83 | text = '' 84 | encoding = self.tokenizer( 85 | text, 86 | padding="max_length", 87 | truncation=True, 88 | max_length=self.max_text_len, 89 | return_special_tokens_mask=True, 90 | ) 91 | text = (text, encoding) 92 | else: 93 | text = self.get_text(index)["text"] 94 | 95 | 96 | labels = self.table["label"][image_index].as_py() 97 | 98 | return { 99 | "image": image_tensor, 100 | "text": text, 101 | "label": labels, 102 | "missing_type": self.missing_table[image_index].item()+simulate_missing_type, 103 | } 104 | -------------------------------------------------------------------------------- /chestXray/vilt/datasets/mmimdb_dataset.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import BaseDataset 2 | import torch 3 | import random 4 | import os 5 | 6 | class MMIMDBDataset(BaseDataset): 7 | def __init__(self, *args, split="", missing_info={}, **kwargs): 8 | assert split in ["train", "val", "test"] 9 | self.split = split 10 | 11 | if split == "train": 12 | names = ["mmimdb_train"] 13 | elif split == "val": 14 | names = ["mmimdb_dev"] 15 | elif split == "test": 16 | names = ["mmimdb_test"] 17 | 18 | super().__init__( 19 | *args, 20 | **kwargs, 21 | names=names, 22 | text_column_name="plots", 23 | remove_duplicate=False, 24 | ) 25 | 26 | # missing modality control 27 | 
self.simulate_missing = missing_info['simulate_missing'] 28 | missing_ratio = missing_info['ratio'][split] 29 | mratio = str(missing_ratio).replace('.','') 30 | missing_type = missing_info['type'][split] 31 | both_ratio = missing_info['both_ratio'] 32 | missing_table_root = missing_info['missing_table_root'] 33 | missing_table_name = f'{names[0]}_missing_{missing_type}_{mratio}.pt' 34 | missing_table_path = os.path.join(missing_table_root, missing_table_name) 35 | 36 | # use image data to formulate missing table 37 | total_num = len(self.table['image']) 38 | 39 | if os.path.exists(missing_table_path): 40 | missing_table = torch.load(missing_table_path) 41 | if len(missing_table) != total_num: 42 | print('missing table mismatched!') 43 | exit() 44 | else: 45 | missing_table = torch.zeros(total_num) 46 | 47 | if missing_ratio > 0: 48 | missing_index = random.sample(range(total_num), int(total_num*missing_ratio)) 49 | 50 | if missing_type == 'text': 51 | missing_table[missing_index] = 1 52 | elif missing_type == 'image': 53 | missing_table[missing_index] = 2 54 | elif missing_type == 'both': 55 | missing_table[missing_index] = 1 56 | missing_index_image = random.sample(missing_index, int(len(missing_index)*both_ratio)) 57 | missing_table[missing_index_image] = 2 58 | 59 | torch.save(missing_table, missing_table_path) 60 | 61 | self.missing_table = missing_table 62 | 63 | def __getitem__(self, index): 64 | # index -> pair data index 65 | # image_index -> image index in table 66 | # question_index -> plot index in texts of the given image 67 | image_index, question_index = self.index_mapper[index] 68 | 69 | # For the case of training with modality-complete data 70 | # Simulate missing modality with random assign the missing type of samples 71 | simulate_missing_type = 0 72 | if self.split == 'train' and self.simulate_missing and self.missing_table[image_index] == 0: 73 | simulate_missing_type = random.choice([0,1,2]) 74 | 75 | image_tensor = self.get_image(index)["image"] 76 | 77 | # missing image, dummy image is all-one image 78 | if self.missing_table[image_index] == 2 or simulate_missing_type == 2: 79 | for idx in range(len(image_tensor)): 80 | image_tensor[idx] = torch.ones(image_tensor[idx].size()).float() 81 | 82 | #missing text, dummy text is '' 83 | if self.missing_table[image_index] == 1 or simulate_missing_type == 1: 84 | text = '' 85 | encoding = self.tokenizer( 86 | text, 87 | padding="max_length", 88 | truncation=True, 89 | max_length=self.max_text_len, 90 | return_special_tokens_mask=True, 91 | ) 92 | text = (text, encoding) 93 | else: 94 | text = self.get_text(index)["text"] 95 | 96 | 97 | labels = self.table["label"][image_index].as_py() 98 | 99 | return { 100 | "image": image_tensor, 101 | "text": text, 102 | "label": labels, 103 | "missing_type": self.missing_table[image_index].item()+simulate_missing_type, 104 | } 105 | -------------------------------------------------------------------------------- /ODIR/vilt/datasets/food101_dataset.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import BaseDataset 2 | import torch 3 | import random, os 4 | 5 | class FOOD101Dataset(BaseDataset): 6 | def __init__(self, *args, split="", missing_info={}, **kwargs): 7 | assert split in ["train", "val", "test"] 8 | self.split = split 9 | 10 | if split == "train": 11 | names = ["food101_train"] 12 | elif split == "val": 13 | names = ["food101_val"] 14 | elif split == "test": 15 | names = ["food101_test"] 16 | 17 | super().__init__( 18 | 
*args, 19 | **kwargs, 20 | names=names, 21 | text_column_name="text", 22 | remove_duplicate=False, 23 | ) 24 | 25 | # missing modality control 26 | self.simulate_missing = missing_info['simulate_missing'] 27 | missing_ratio = missing_info['ratio'][split] 28 | mratio = str(missing_ratio).replace('.','') 29 | missing_type = missing_info['type'][split] 30 | both_ratio = missing_info['both_ratio'] 31 | missing_table_root = missing_info['missing_table_root'] 32 | missing_table_name = f'{names[0]}_missing_{missing_type}_{mratio}.pt' 33 | missing_table_path = os.path.join(missing_table_root, missing_table_name) 34 | 35 | # use image data to formulate missing table 36 | total_num = len(self.table['image']) 37 | 38 | if os.path.exists(missing_table_path): 39 | missing_table = torch.load(missing_table_path) 40 | if len(missing_table) != total_num: 41 | print('missing table mismatched!') 42 | exit() 43 | else: 44 | missing_table = torch.zeros(total_num) 45 | 46 | if missing_ratio > 0: 47 | missing_index = random.sample(range(total_num), int(total_num*missing_ratio)) 48 | 49 | if missing_type == 'text': 50 | missing_table[missing_index] = 1 51 | elif missing_type == 'image': 52 | missing_table[missing_index] = 2 53 | elif missing_type == 'both': 54 | missing_table[missing_index] = 1 55 | missing_index_image = random.sample(missing_index, int(len(missing_index)*both_ratio)) 56 | missing_table[missing_index_image] = 2 57 | 58 | torch.save(missing_table, missing_table_path) 59 | 60 | self.missing_table = missing_table 61 | 62 | print(self.index_mapper) 63 | 64 | 65 | def __getitem__(self, index): 66 | # index -> pair data index 67 | # image_index -> image index in table 68 | # question_index -> plot index in texts of the given image 69 | image_index, question_index = self.index_mapper[index] 70 | 71 | # For the case of training with modality-complete data 72 | # Simulate missing modality with random assign the missing type of samples 73 | simulate_missing_type = 0 74 | if self.split == 'train' and self.simulate_missing and self.missing_table[image_index] == 0: 75 | simulate_missing_type = random.choice([0,1,2]) 76 | 77 | image_tensor = self.get_image(index)["image"] 78 | 79 | # missing image, dummy image is all-one image 80 | if self.missing_table[image_index] == 2 or simulate_missing_type == 2: 81 | for idx in range(len(image_tensor)): 82 | image_tensor[idx] = torch.ones(image_tensor[idx].size()).float() 83 | 84 | #missing text, dummy text is '' 85 | if self.missing_table[image_index] == 1 or simulate_missing_type == 1: 86 | text = '' 87 | encoding = self.tokenizer( 88 | text, 89 | padding="max_length", 90 | truncation=True, 91 | max_length=self.max_text_len, 92 | return_special_tokens_mask=True, 93 | ) 94 | text = (text, encoding) 95 | else: 96 | text = self.get_text(index)["text"] 97 | 98 | 99 | labels = self.table["label"][image_index].as_py() 100 | 101 | return { 102 | "image": image_tensor, 103 | "text": text, 104 | "label": labels, 105 | "missing_type": self.missing_table[image_index].item()+simulate_missing_type, 106 | } 107 | -------------------------------------------------------------------------------- /chestXray/vilt/datasets/food101_dataset.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import BaseDataset 2 | import torch 3 | import random, os 4 | 5 | class FOOD101Dataset(BaseDataset): 6 | def __init__(self, *args, split="", missing_info={}, **kwargs): 7 | assert split in ["train", "val", "test"] 8 | self.split = split 9 | 10 | if 
split == "train": 11 | names = ["food101_train"] 12 | elif split == "val": 13 | names = ["food101_val"] 14 | elif split == "test": 15 | names = ["food101_test"] 16 | 17 | super().__init__( 18 | *args, 19 | **kwargs, 20 | names=names, 21 | text_column_name="text", 22 | remove_duplicate=False, 23 | ) 24 | 25 | # missing modality control 26 | self.simulate_missing = missing_info['simulate_missing'] 27 | missing_ratio = missing_info['ratio'][split] 28 | mratio = str(missing_ratio).replace('.','') 29 | missing_type = missing_info['type'][split] 30 | both_ratio = missing_info['both_ratio'] 31 | missing_table_root = missing_info['missing_table_root'] 32 | missing_table_name = f'{names[0]}_missing_{missing_type}_{mratio}.pt' 33 | missing_table_path = os.path.join(missing_table_root, missing_table_name) 34 | 35 | # use image data to formulate missing table 36 | total_num = len(self.table['image']) 37 | 38 | if os.path.exists(missing_table_path): 39 | missing_table = torch.load(missing_table_path) 40 | if len(missing_table) != total_num: 41 | print('missing table mismatched!') 42 | exit() 43 | else: 44 | missing_table = torch.zeros(total_num) 45 | 46 | if missing_ratio > 0: 47 | missing_index = random.sample(range(total_num), int(total_num*missing_ratio)) 48 | 49 | if missing_type == 'text': 50 | missing_table[missing_index] = 1 51 | elif missing_type == 'image': 52 | missing_table[missing_index] = 2 53 | elif missing_type == 'both': 54 | missing_table[missing_index] = 1 55 | missing_index_image = random.sample(missing_index, int(len(missing_index)*both_ratio)) 56 | missing_table[missing_index_image] = 2 57 | 58 | torch.save(missing_table, missing_table_path) 59 | 60 | self.missing_table = missing_table 61 | 62 | print(self.index_mapper) 63 | 64 | 65 | def __getitem__(self, index): 66 | # index -> pair data index 67 | # image_index -> image index in table 68 | # question_index -> plot index in texts of the given image 69 | image_index, question_index = self.index_mapper[index] 70 | 71 | # For the case of training with modality-complete data 72 | # Simulate missing modality with random assign the missing type of samples 73 | simulate_missing_type = 0 74 | if self.split == 'train' and self.simulate_missing and self.missing_table[image_index] == 0: 75 | simulate_missing_type = random.choice([0,1,2]) 76 | 77 | image_tensor = self.get_image(index)["image"] 78 | 79 | # missing image, dummy image is all-one image 80 | if self.missing_table[image_index] == 2 or simulate_missing_type == 2: 81 | for idx in range(len(image_tensor)): 82 | image_tensor[idx] = torch.ones(image_tensor[idx].size()).float() 83 | 84 | #missing text, dummy text is '' 85 | if self.missing_table[image_index] == 1 or simulate_missing_type == 1: 86 | text = '' 87 | encoding = self.tokenizer( 88 | text, 89 | padding="max_length", 90 | truncation=True, 91 | max_length=self.max_text_len, 92 | return_special_tokens_mask=True, 93 | ) 94 | text = (text, encoding) 95 | else: 96 | text = self.get_text(index)["text"] 97 | 98 | 99 | labels = self.table["label"][image_index].as_py() 100 | 101 | return { 102 | "image": image_tensor, 103 | "text": text, 104 | "label": labels, 105 | "missing_type": self.missing_table[image_index].item()+simulate_missing_type, 106 | } 107 | -------------------------------------------------------------------------------- /ODIR/vilt/datasets/ODIR_dataset.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import BaseDataset 2 | import torch 3 | import random 4 | 
import os 5 | 6 | class ODIRDataset(BaseDataset): 7 | def __init__(self, *args, split="", missing_info={}, **kwargs): 8 | assert split in ["train", "val", "test"] 9 | self.split = split 10 | 11 | if split == "train": 12 | names = ["ODIR_train"] 13 | elif split == "val": 14 | names = ["ODIR_val"] 15 | elif split == "test": 16 | names = ["ODIR_test"] 17 | 18 | super().__init__( 19 | *args, 20 | **kwargs, 21 | names=names, 22 | text_column_name="text", 23 | remove_duplicate=False, 24 | ) 25 | 26 | # missing modality control 27 | self.simulate_missing = missing_info['simulate_missing'] 28 | missing_ratio = missing_info['ratio'][split] 29 | mratio = str(missing_ratio).replace('.','') 30 | missing_type = missing_info['type'][split] 31 | both_ratio = missing_info['both_ratio'] 32 | missing_table_root = missing_info['missing_table_root'] 33 | missing_table_name = f'{names[0]}_missing_{missing_type}_{mratio}.pt' 34 | missing_table_path = os.path.join(missing_table_root, missing_table_name) 35 | 36 | # use image data to formulate missing table 37 | total_num = len(self.table['text']) 38 | 39 | 40 | if os.path.exists(missing_table_path): 41 | missing_table = torch.load(missing_table_path) 42 | if len(missing_table) != total_num: 43 | print('missing table mismatched!') 44 | exit() 45 | else: 46 | missing_table = torch.zeros(total_num) 47 | 48 | if missing_ratio > 0: 49 | missing_index = random.sample(range(total_num), int(total_num*missing_ratio)) 50 | 51 | if missing_type == 'text': 52 | missing_table[missing_index] = 1 53 | elif missing_type == 'image': 54 | missing_table[missing_index] = 2 55 | elif missing_type == 'both': 56 | missing_table[missing_index] = 1 57 | missing_index_image = random.sample(missing_index, int(len(missing_index)*both_ratio)) 58 | missing_table[missing_index_image] = 2 59 | 60 | torch.save(missing_table, missing_table_path) 61 | 62 | self.missing_table = missing_table 63 | 64 | 65 | def __getitem__(self, index): 66 | # index -> pair data index 67 | # image_index -> image index in table 68 | # question_index -> plot index in texts of the given image 69 | 70 | image_index, question_index = self.index_mapper[index] 71 | 72 | 73 | # For the case of training with modality-complete data 74 | # Simulate missing modality with random assign the missing type of samples 75 | simulate_missing_type = 0 76 | if self.split == 'train' and self.simulate_missing and self.missing_table[image_index] == 0: 77 | simulate_missing_type = random.choice([0,1,2]) 78 | 79 | image_tensor = self.get_image(index)["image"] 80 | 81 | 82 | # missing image, dummy image is all-one image 83 | if self.missing_table[image_index] == 2 or simulate_missing_type == 2: 84 | for idx in range(len(image_tensor)): 85 | image_tensor[idx] = torch.ones(image_tensor[idx].size()).float() 86 | 87 | #missing text, dummy text is '' 88 | if self.missing_table[image_index] == 1 or simulate_missing_type == 1: 89 | text = '' 90 | encoding = self.tokenizer( 91 | text, 92 | padding="max_length", 93 | truncation=True, 94 | max_length=self.max_text_len, 95 | return_special_tokens_mask=True, 96 | ) 97 | text = (text, encoding) 98 | else: 99 | text = self.get_text(index)["text"] 100 | 101 | 102 | 103 | 104 | labels = self.table["label"][image_index].as_py() 105 | 106 | return { 107 | "image": image_tensor, 108 | "text": text, 109 | "label": labels, 110 | "missing_type": self.missing_table[image_index].item()+simulate_missing_type, 111 | "index":index, 112 | } 113 | 
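The tables this constructor writes are the .pt files shipped under datasets/missing_tables/, one entry per sample using the 0/1/2 encoding above (0 = complete, 1 = text missing, 2 = image missing). A quick inspection sketch:

    import torch

    table = torch.load("datasets/missing_tables/ODIR_train_missing_both_07.pt")
    for code, meaning in [(0, "complete"), (1, "text missing"), (2, "image missing")]:
        print(meaning, int((table == code).sum()))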
-------------------------------------------------------------------------------- /chestXray/vilt/datasets/chestXray_dataset.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import BaseDataset 2 | import torch 3 | import random 4 | import os 5 | 6 | class CHESTXRAYDataset(BaseDataset): 7 | def __init__(self, *args, split="", missing_info={}, **kwargs): 8 | assert split in ["train", "val", "test"] 9 | self.split = split 10 | 11 | if split == "train": 12 | names = ["chestXray_train"] 13 | elif split == "val": 14 | names = ["chestXray_val"] 15 | elif split == "test": 16 | names = ["chestXray_test"] 17 | 18 | super().__init__( 19 | *args, 20 | **kwargs, 21 | names=names, 22 | text_column_name="text", 23 | remove_duplicate=False, 24 | ) 25 | 26 | # missing modality control 27 | self.simulate_missing = missing_info['simulate_missing'] 28 | missing_ratio = missing_info['ratio'][split] 29 | mratio = str(missing_ratio).replace('.','') 30 | missing_type = missing_info['type'][split] 31 | both_ratio = missing_info['both_ratio'] 32 | missing_table_root = missing_info['missing_table_root'] 33 | missing_table_name = f'{names[0]}_missing_{missing_type}_{mratio}.pt' 34 | missing_table_path = os.path.join(missing_table_root, missing_table_name) 35 | 36 | # use image data to formulate missing table 37 | total_num = len(self.table['image']) 38 | 39 | 40 | if os.path.exists(missing_table_path): 41 | missing_table = torch.load(missing_table_path) 42 | if len(missing_table) != total_num: 43 | print('missing table mismatched!') 44 | exit() 45 | else: 46 | missing_table = torch.zeros(total_num) 47 | 48 | if missing_ratio > 0: 49 | missing_index = random.sample(range(total_num), int(total_num*missing_ratio)) 50 | 51 | if missing_type == 'text': 52 | missing_table[missing_index] = 1 53 | elif missing_type == 'image': 54 | missing_table[missing_index] = 2 55 | elif missing_type == 'both': 56 | missing_table[missing_index] = 1 57 | missing_index_image = random.sample(missing_index, int(len(missing_index)*both_ratio)) 58 | missing_table[missing_index_image] = 2 59 | 60 | torch.save(missing_table, missing_table_path) 61 | 62 | self.missing_table = missing_table 63 | 64 | 65 | def __getitem__(self, index): 66 | # index -> pair data index 67 | # image_index -> image index in table 68 | # question_index -> plot index in texts of the given image 69 | 70 | image_index, question_index = self.index_mapper[index] 71 | 72 | 73 | # For the case of training with modality-complete data 74 | # Simulate missing modality with random assign the missing type of samples 75 | simulate_missing_type = 0 76 | if self.split == 'train' and self.simulate_missing and self.missing_table[image_index] == 0: 77 | simulate_missing_type = random.choice([0,1,2]) 78 | 79 | image_tensor = self.get_image(index)["image"] 80 | 81 | 82 | # missing image, dummy image is all-one image 83 | if self.missing_table[image_index] == 2 or simulate_missing_type == 2: 84 | for idx in range(len(image_tensor)): 85 | image_tensor[idx] = torch.ones(image_tensor[idx].size()).float() 86 | 87 | #missing text, dummy text is '' 88 | if self.missing_table[image_index] == 1 or simulate_missing_type == 1: 89 | text = '' 90 | encoding = self.tokenizer( 91 | text, 92 | padding="max_length", 93 | truncation=True, 94 | max_length=self.max_text_len, 95 | return_special_tokens_mask=True, 96 | ) 97 | text = (text, encoding) 98 | else: 99 | text = self.get_text(index)["text"] 100 | 101 | 102 | 103 | 104 | labels = 
self.table["label"][image_index].as_py() 105 | 106 | return { 107 | "image": image_tensor, 108 | "text": text, 109 | "label": labels, 110 | "missing_type": self.missing_table[image_index].item()+simulate_missing_type, 111 | "index":index, 112 | } 113 | -------------------------------------------------------------------------------- /ODIR/vilt/utils/glossary.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | contractions = { 4 | "aint": "ain't", 5 | "arent": "aren't", 6 | "cant": "can't", 7 | "couldve": "could've", 8 | "couldnt": "couldn't", 9 | "couldn'tve": "couldn't've", 10 | "couldnt've": "couldn't've", 11 | "didnt": "didn't", 12 | "doesnt": "doesn't", 13 | "dont": "don't", 14 | "hadnt": "hadn't", 15 | "hadnt've": "hadn't've", 16 | "hadn'tve": "hadn't've", 17 | "hasnt": "hasn't", 18 | "havent": "haven't", 19 | "hed": "he'd", 20 | "hed've": "he'd've", 21 | "he'dve": "he'd've", 22 | "hes": "he's", 23 | "howd": "how'd", 24 | "howll": "how'll", 25 | "hows": "how's", 26 | "Id've": "I'd've", 27 | "I'dve": "I'd've", 28 | "Im": "I'm", 29 | "Ive": "I've", 30 | "isnt": "isn't", 31 | "itd": "it'd", 32 | "itd've": "it'd've", 33 | "it'dve": "it'd've", 34 | "itll": "it'll", 35 | "let's": "let's", 36 | "maam": "ma'am", 37 | "mightnt": "mightn't", 38 | "mightnt've": "mightn't've", 39 | "mightn'tve": "mightn't've", 40 | "mightve": "might've", 41 | "mustnt": "mustn't", 42 | "mustve": "must've", 43 | "neednt": "needn't", 44 | "notve": "not've", 45 | "oclock": "o'clock", 46 | "oughtnt": "oughtn't", 47 | "ow's'at": "'ow's'at", 48 | "'ows'at": "'ow's'at", 49 | "'ow'sat": "'ow's'at", 50 | "shant": "shan't", 51 | "shed've": "she'd've", 52 | "she'dve": "she'd've", 53 | "she's": "she's", 54 | "shouldve": "should've", 55 | "shouldnt": "shouldn't", 56 | "shouldnt've": "shouldn't've", 57 | "shouldn'tve": "shouldn't've", 58 | "somebody'd": "somebodyd", 59 | "somebodyd've": "somebody'd've", 60 | "somebody'dve": "somebody'd've", 61 | "somebodyll": "somebody'll", 62 | "somebodys": "somebody's", 63 | "someoned": "someone'd", 64 | "someoned've": "someone'd've", 65 | "someone'dve": "someone'd've", 66 | "someonell": "someone'll", 67 | "someones": "someone's", 68 | "somethingd": "something'd", 69 | "somethingd've": "something'd've", 70 | "something'dve": "something'd've", 71 | "somethingll": "something'll", 72 | "thats": "that's", 73 | "thered": "there'd", 74 | "thered've": "there'd've", 75 | "there'dve": "there'd've", 76 | "therere": "there're", 77 | "theres": "there's", 78 | "theyd": "they'd", 79 | "theyd've": "they'd've", 80 | "they'dve": "they'd've", 81 | "theyll": "they'll", 82 | "theyre": "they're", 83 | "theyve": "they've", 84 | "twas": "'twas", 85 | "wasnt": "wasn't", 86 | "wed've": "we'd've", 87 | "we'dve": "we'd've", 88 | "weve": "we've", 89 | "werent": "weren't", 90 | "whatll": "what'll", 91 | "whatre": "what're", 92 | "whats": "what's", 93 | "whatve": "what've", 94 | "whens": "when's", 95 | "whered": "where'd", 96 | "wheres": "where's", 97 | "whereve": "where've", 98 | "whod": "who'd", 99 | "whod've": "who'd've", 100 | "who'dve": "who'd've", 101 | "wholl": "who'll", 102 | "whos": "who's", 103 | "whove": "who've", 104 | "whyll": "why'll", 105 | "whyre": "why're", 106 | "whys": "why's", 107 | "wont": "won't", 108 | "wouldve": "would've", 109 | "wouldnt": "wouldn't", 110 | "wouldnt've": "wouldn't've", 111 | "wouldn'tve": "wouldn't've", 112 | "yall": "y'all", 113 | "yall'll": "y'all'll", 114 | "y'allll": "y'all'll", 115 | "yall'd've": "y'all'd've", 116 | 
"y'alld've": "y'all'd've", 117 | "y'all'dve": "y'all'd've", 118 | "youd": "you'd", 119 | "youd've": "you'd've", 120 | "you'dve": "you'd've", 121 | "youll": "you'll", 122 | "youre": "you're", 123 | "youve": "you've", 124 | } 125 | 126 | manual_map = { 127 | "none": "0", 128 | "zero": "0", 129 | "one": "1", 130 | "two": "2", 131 | "three": "3", 132 | "four": "4", 133 | "five": "5", 134 | "six": "6", 135 | "seven": "7", 136 | "eight": "8", 137 | "nine": "9", 138 | "ten": "10", 139 | } 140 | articles = ["a", "an", "the"] 141 | period_strip = re.compile("(?!<=\d)(\.)(?!\d)") 142 | comma_strip = re.compile("(\d)(\,)(\d)") 143 | punct = [ 144 | ";", 145 | r"/", 146 | "[", 147 | "]", 148 | '"', 149 | "{", 150 | "}", 151 | "(", 152 | ")", 153 | "=", 154 | "+", 155 | "\\", 156 | "_", 157 | "-", 158 | ">", 159 | "<", 160 | "@", 161 | "`", 162 | ",", 163 | "?", 164 | "!", 165 | ] 166 | 167 | 168 | def normalize_word(token): 169 | _token = token 170 | for p in punct: 171 | if (p + " " in token or " " + p in token) or ( 172 | re.search(comma_strip, token) != None 173 | ): 174 | _token = _token.replace(p, "") 175 | else: 176 | _token = _token.replace(p, " ") 177 | token = period_strip.sub("", _token, re.UNICODE) 178 | 179 | _token = [] 180 | temp = token.lower().split() 181 | for word in temp: 182 | word = manual_map.setdefault(word, word) 183 | if word not in articles: 184 | _token.append(word) 185 | for i, word in enumerate(_token): 186 | if word in contractions: 187 | _token[i] = contractions[word] 188 | token = " ".join(_token) 189 | token = token.replace(",", "") 190 | return token 191 | -------------------------------------------------------------------------------- /chestXray/vilt/utils/glossary.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | contractions = { 4 | "aint": "ain't", 5 | "arent": "aren't", 6 | "cant": "can't", 7 | "couldve": "could've", 8 | "couldnt": "couldn't", 9 | "couldn'tve": "couldn't've", 10 | "couldnt've": "couldn't've", 11 | "didnt": "didn't", 12 | "doesnt": "doesn't", 13 | "dont": "don't", 14 | "hadnt": "hadn't", 15 | "hadnt've": "hadn't've", 16 | "hadn'tve": "hadn't've", 17 | "hasnt": "hasn't", 18 | "havent": "haven't", 19 | "hed": "he'd", 20 | "hed've": "he'd've", 21 | "he'dve": "he'd've", 22 | "hes": "he's", 23 | "howd": "how'd", 24 | "howll": "how'll", 25 | "hows": "how's", 26 | "Id've": "I'd've", 27 | "I'dve": "I'd've", 28 | "Im": "I'm", 29 | "Ive": "I've", 30 | "isnt": "isn't", 31 | "itd": "it'd", 32 | "itd've": "it'd've", 33 | "it'dve": "it'd've", 34 | "itll": "it'll", 35 | "let's": "let's", 36 | "maam": "ma'am", 37 | "mightnt": "mightn't", 38 | "mightnt've": "mightn't've", 39 | "mightn'tve": "mightn't've", 40 | "mightve": "might've", 41 | "mustnt": "mustn't", 42 | "mustve": "must've", 43 | "neednt": "needn't", 44 | "notve": "not've", 45 | "oclock": "o'clock", 46 | "oughtnt": "oughtn't", 47 | "ow's'at": "'ow's'at", 48 | "'ows'at": "'ow's'at", 49 | "'ow'sat": "'ow's'at", 50 | "shant": "shan't", 51 | "shed've": "she'd've", 52 | "she'dve": "she'd've", 53 | "she's": "she's", 54 | "shouldve": "should've", 55 | "shouldnt": "shouldn't", 56 | "shouldnt've": "shouldn't've", 57 | "shouldn'tve": "shouldn't've", 58 | "somebody'd": "somebodyd", 59 | "somebodyd've": "somebody'd've", 60 | "somebody'dve": "somebody'd've", 61 | "somebodyll": "somebody'll", 62 | "somebodys": "somebody's", 63 | "someoned": "someone'd", 64 | "someoned've": "someone'd've", 65 | "someone'dve": "someone'd've", 66 | "someonell": "someone'll", 67 | 
"someones": "someone's", 68 | "somethingd": "something'd", 69 | "somethingd've": "something'd've", 70 | "something'dve": "something'd've", 71 | "somethingll": "something'll", 72 | "thats": "that's", 73 | "thered": "there'd", 74 | "thered've": "there'd've", 75 | "there'dve": "there'd've", 76 | "therere": "there're", 77 | "theres": "there's", 78 | "theyd": "they'd", 79 | "theyd've": "they'd've", 80 | "they'dve": "they'd've", 81 | "theyll": "they'll", 82 | "theyre": "they're", 83 | "theyve": "they've", 84 | "twas": "'twas", 85 | "wasnt": "wasn't", 86 | "wed've": "we'd've", 87 | "we'dve": "we'd've", 88 | "weve": "we've", 89 | "werent": "weren't", 90 | "whatll": "what'll", 91 | "whatre": "what're", 92 | "whats": "what's", 93 | "whatve": "what've", 94 | "whens": "when's", 95 | "whered": "where'd", 96 | "wheres": "where's", 97 | "whereve": "where've", 98 | "whod": "who'd", 99 | "whod've": "who'd've", 100 | "who'dve": "who'd've", 101 | "wholl": "who'll", 102 | "whos": "who's", 103 | "whove": "who've", 104 | "whyll": "why'll", 105 | "whyre": "why're", 106 | "whys": "why's", 107 | "wont": "won't", 108 | "wouldve": "would've", 109 | "wouldnt": "wouldn't", 110 | "wouldnt've": "wouldn't've", 111 | "wouldn'tve": "wouldn't've", 112 | "yall": "y'all", 113 | "yall'll": "y'all'll", 114 | "y'allll": "y'all'll", 115 | "yall'd've": "y'all'd've", 116 | "y'alld've": "y'all'd've", 117 | "y'all'dve": "y'all'd've", 118 | "youd": "you'd", 119 | "youd've": "you'd've", 120 | "you'dve": "you'd've", 121 | "youll": "you'll", 122 | "youre": "you're", 123 | "youve": "you've", 124 | } 125 | 126 | manual_map = { 127 | "none": "0", 128 | "zero": "0", 129 | "one": "1", 130 | "two": "2", 131 | "three": "3", 132 | "four": "4", 133 | "five": "5", 134 | "six": "6", 135 | "seven": "7", 136 | "eight": "8", 137 | "nine": "9", 138 | "ten": "10", 139 | } 140 | articles = ["a", "an", "the"] 141 | period_strip = re.compile("(?!<=\d)(\.)(?!\d)") 142 | comma_strip = re.compile("(\d)(\,)(\d)") 143 | punct = [ 144 | ";", 145 | r"/", 146 | "[", 147 | "]", 148 | '"', 149 | "{", 150 | "}", 151 | "(", 152 | ")", 153 | "=", 154 | "+", 155 | "\\", 156 | "_", 157 | "-", 158 | ">", 159 | "<", 160 | "@", 161 | "`", 162 | ",", 163 | "?", 164 | "!", 165 | ] 166 | 167 | 168 | def normalize_word(token): 169 | _token = token 170 | for p in punct: 171 | if (p + " " in token or " " + p in token) or ( 172 | re.search(comma_strip, token) != None 173 | ): 174 | _token = _token.replace(p, "") 175 | else: 176 | _token = _token.replace(p, " ") 177 | token = period_strip.sub("", _token, re.UNICODE) 178 | 179 | _token = [] 180 | temp = token.lower().split() 181 | for word in temp: 182 | word = manual_map.setdefault(word, word) 183 | if word not in articles: 184 | _token.append(word) 185 | for i, word in enumerate(_token): 186 | if word in contractions: 187 | _token[i] = contractions[word] 188 | token = " ".join(_token) 189 | token = token.replace(",", "") 190 | return token 191 | -------------------------------------------------------------------------------- /ODIR/vilt/modules/lora.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn 6 | 7 | class LoRALinearLayer(nn.Module): 8 | r""" 9 | A linear layer that is used with LoRA. 10 | 11 | Parameters: 12 | in_features (`int`): 13 | Number of input features. 14 | out_features (`int`): 15 | Number of output features. 
16 | rank (`int`, `optional`, defaults to 4): 17 | The rank of the LoRA layer. 18 | network_alpha (`float`, `optional`, defaults to `None`): 19 | The value of the network alpha used for stable learning and preventing underflow. This value has the same 20 | meaning as the `--network_alpha` option in the kohya-ss trainer script. See 21 | https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning 22 | device (`torch.device`, `optional`, defaults to `None`): 23 | The device to use for the layer's weights. 24 | dtype (`torch.dtype`, `optional`, defaults to `None`): 25 | The dtype to use for the layer's weights. 26 | """ 27 | 28 | def __init__( 29 | self, 30 | in_features: int, 31 | out_features: int, 32 | rank: int = 4, 33 | network_alpha: Optional[float] = None, 34 | device: Optional[Union[torch.device, str]] = None, 35 | dtype: Optional[torch.dtype] = None, 36 | ): 37 | super().__init__() 38 | 39 | self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype) 40 | self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype) 41 | # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. 42 | # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning 43 | self.network_alpha = network_alpha 44 | self.rank = rank 45 | self.out_features = out_features 46 | self.in_features = in_features 47 | 48 | nn.init.normal_(self.down.weight, std=1 / rank) 49 | nn.init.zeros_(self.up.weight) 50 | 51 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 52 | 53 | down_hidden_states = self.down(hidden_states) 54 | up_hidden_states = self.up(down_hidden_states) 55 | 56 | if self.network_alpha is not None: 57 | up_hidden_states *= self.network_alpha / self.rank 58 | 59 | return up_hidden_states 60 | 61 | 62 | 63 | class LoRAConv1dLayer(nn.Module): 64 | def __init__( 65 | self, 66 | in_features: int, 67 | out_features: int, 68 | rank: int = 4, 69 | kernel_size: Union[int, Tuple[int, int]] = 1, 70 | stride: Union[int, Tuple[int, int]] = 1, 71 | padding: Union[int, Tuple[int, int], str] = 0, 72 | network_alpha: Optional[float] = None, 73 | device: Optional[Union[torch.device, str]] = None, 74 | dtype: Optional[torch.dtype] = None, 75 | ): 76 | super().__init__() 77 | 78 | self.down = nn.Conv1d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False,device=device,dtype=dtype) 79 | # according to the official kohya_ss trainer kernel_size are always fixed for the up layer 80 | # # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129 81 | self.up = nn.Conv1d(rank, out_features, kernel_size=1, stride=1, bias=False,device=device,dtype=dtype) 82 | 83 | # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. 
84 | # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning 85 | self.network_alpha = network_alpha 86 | self.rank = rank 87 | 88 | nn.init.normal_(self.down.weight, std=1 / rank) 89 | nn.init.zeros_(self.up.weight) 90 | 91 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 92 | 93 | input = hidden_states.permute(0,2,1) 94 | 95 | down_hidden_states = self.down(input) 96 | up_hidden_states = self.up(down_hidden_states) 97 | 98 | if self.network_alpha is not None: 99 | up_hidden_states *= self.network_alpha / self.rank 100 | 101 | return up_hidden_states.permute(0,2,1) 102 | 103 | 104 | 105 | 106 | 107 | class LoRAMissingLayer(nn.Module): 108 | def __init__( 109 | self, 110 | in_features: int, 111 | out_features: int, 112 | rank: int = 4, 113 | network_alpha: Optional[float] = None, 114 | device: Optional[Union[torch.device, str]] = None, 115 | dtype: Optional[torch.dtype] = None, 116 | ): 117 | super().__init__() 118 | 119 | self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype) 120 | self.text_missing_up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype) 121 | self.image_missing_up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype) 122 | 123 | 124 | self.network_alpha = network_alpha 125 | self.rank = rank 126 | self.out_features = out_features 127 | self.in_features = in_features 128 | 129 | nn.init.normal_(self.down.weight, std=1 / rank) 130 | nn.init.zeros_(self.text_missing_up.weight) 131 | nn.init.zeros_(self.image_missing_up.weight) 132 | 133 | 134 | def forward(self, hidden_states: torch.Tensor,missing_type) -> torch.Tensor: 135 | 136 | down_hidden_states = self.down(hidden_states) 137 | 138 | # lora_tensor = [] 139 | # for b in range(down_hidden_states.shape[0]): 140 | # temp = down_hidden_states[b:b+1] 141 | # if missing_type[b] == 0: 142 | # temp_lora = self.image_missing_up(temp) + self.text_missing_up(temp) 143 | # elif missing_type[b] == 1: 144 | # temp_lora = self.text_missing_up(temp) 145 | # else: 146 | # temp_lora = self.image_missing_up(temp) 147 | # lora_tensor.append(temp_lora) 148 | # lora_tensor = torch.cat(lora_tensor,dim=0) 149 | # up_hidden_states = lora_tensor 150 | 151 | if missing_type == 0: 152 | up_hidden_states = self.image_missing_up(down_hidden_states) + self.text_missing_up(down_hidden_states) 153 | elif missing_type == 1: 154 | up_hidden_states = self.text_missing_up(down_hidden_states) 155 | #up_hidden_states = self.image_missing_up(down_hidden_states) + self.text_missing_up(down_hidden_states) 156 | else: 157 | up_hidden_states = self.image_missing_up(down_hidden_states) 158 | #up_hidden_states = self.image_missing_up(down_hidden_states) + self.text_missing_up(down_hidden_states) 159 | 160 | 161 | 162 | if self.network_alpha is not None: 163 | up_hidden_states *= self.network_alpha / self.rank 164 | 165 | return up_hidden_states 166 | 167 | def up_weights(self): 168 | return [self.text_missing_up.weight, self.image_missing_up.weight] 169 | -------------------------------------------------------------------------------- /chestXray/vilt/modules/lora.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn 6 | 7 | class LoRALinearLayer(nn.Module): 8 | r""" 9 | A linear layer that is used with LoRA. 
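    Usage sketch (editor's illustration, not part of the original file; the shapes and
    `network_alpha` value below are assumptions): because `up` is zero-initialized, a
    freshly constructed layer outputs zeros, so adding its result to a frozen pretrained
    layer leaves the base behavior unchanged until the LoRA weights are trained:

        lora = LoRALinearLayer(in_features=768, out_features=768, rank=4, network_alpha=4.0)
        x = torch.randn(2, 196, 768)    # (batch, tokens, hidden)
        delta = lora(x)                 # all zeros at init; scaled by alpha/rank once set
        out = frozen_linear(x) + delta  # `frozen_linear` is a hypothetical frozen nn.Linear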
10 | 11 | Parameters: 12 | in_features (`int`): 13 | Number of input features. 14 | out_features (`int`): 15 | Number of output features. 16 | rank (`int`, `optional`, defaults to 4): 17 | The rank of the LoRA layer. 18 | network_alpha (`float`, `optional`, defaults to `None`): 19 | The value of the network alpha used for stable learning and preventing underflow. This value has the same 20 | meaning as the `--network_alpha` option in the kohya-ss trainer script. See 21 | https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning 22 | device (`torch.device`, `optional`, defaults to `None`): 23 | The device to use for the layer's weights. 24 | dtype (`torch.dtype`, `optional`, defaults to `None`): 25 | The dtype to use for the layer's weights. 26 | """ 27 | 28 | def __init__( 29 | self, 30 | in_features: int, 31 | out_features: int, 32 | rank: int = 4, 33 | network_alpha: Optional[float] = None, 34 | device: Optional[Union[torch.device, str]] = None, 35 | dtype: Optional[torch.dtype] = None, 36 | ): 37 | super().__init__() 38 | 39 | self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype) 40 | self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype) 41 | # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. 42 | # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning 43 | self.network_alpha = network_alpha 44 | self.rank = rank 45 | self.out_features = out_features 46 | self.in_features = in_features 47 | 48 | nn.init.normal_(self.down.weight, std=1 / rank) 49 | nn.init.zeros_(self.up.weight) 50 | 51 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 52 | 53 | down_hidden_states = self.down(hidden_states) 54 | up_hidden_states = self.up(down_hidden_states) 55 | 56 | if self.network_alpha is not None: 57 | up_hidden_states *= self.network_alpha / self.rank 58 | 59 | return up_hidden_states 60 | 61 | 62 | 63 | class LoRAConv1dLayer(nn.Module): 64 | def __init__( 65 | self, 66 | in_features: int, 67 | out_features: int, 68 | rank: int = 4, 69 | kernel_size: Union[int, Tuple[int, int]] = 1, 70 | stride: Union[int, Tuple[int, int]] = 1, 71 | padding: Union[int, Tuple[int, int], str] = 0, 72 | network_alpha: Optional[float] = None, 73 | device: Optional[Union[torch.device, str]] = None, 74 | dtype: Optional[torch.dtype] = None, 75 | ): 76 | super().__init__() 77 | 78 | self.down = nn.Conv1d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False,device=device,dtype=dtype) 79 | # according to the official kohya_ss trainer kernel_size are always fixed for the up layer 80 | # # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129 81 | self.up = nn.Conv1d(rank, out_features, kernel_size=1, stride=1, bias=False,device=device,dtype=dtype) 82 | 83 | # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. 
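        # Editor's note (illustrative, not part of the original file): forward() below
        # expects inputs shaped (batch, seq_len, in_features) and permutes them into
        # Conv1d's (batch, channels, length) layout around the down/up projections, e.g.
        #
        #     conv_lora = LoRAConv1dLayer(768, 768, rank=4, kernel_size=1)
        #     delta = conv_lora(torch.randn(2, 196, 768))   # -> (2, 196, 768), zeros at init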
84 | # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning 85 | self.network_alpha = network_alpha 86 | self.rank = rank 87 | 88 | nn.init.normal_(self.down.weight, std=1 / rank) 89 | nn.init.zeros_(self.up.weight) 90 | 91 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 92 | 93 | input = hidden_states.permute(0,2,1) 94 | 95 | down_hidden_states = self.down(input) 96 | up_hidden_states = self.up(down_hidden_states) 97 | 98 | if self.network_alpha is not None: 99 | up_hidden_states *= self.network_alpha / self.rank 100 | 101 | return up_hidden_states.permute(0,2,1) 102 | 103 | 104 | 105 | 106 | 107 | class LoRAMissingLayer(nn.Module): 108 | def __init__( 109 | self, 110 | in_features: int, 111 | out_features: int, 112 | rank: int = 4, 113 | network_alpha: Optional[float] = None, 114 | device: Optional[Union[torch.device, str]] = None, 115 | dtype: Optional[torch.dtype] = None, 116 | ): 117 | super().__init__() 118 | 119 | self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype) 120 | self.text_missing_up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype) 121 | self.image_missing_up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype) 122 | 123 | 124 | self.network_alpha = network_alpha 125 | self.rank = rank 126 | self.out_features = out_features 127 | self.in_features = in_features 128 | 129 | nn.init.normal_(self.down.weight, std=1 / rank) 130 | nn.init.zeros_(self.text_missing_up.weight) 131 | nn.init.zeros_(self.image_missing_up.weight) 132 | 133 | 134 | def forward(self, hidden_states: torch.Tensor,missing_type) -> torch.Tensor: 135 | 136 | down_hidden_states = self.down(hidden_states) 137 | 138 | # lora_tensor = [] 139 | # for b in range(down_hidden_states.shape[0]): 140 | # temp = down_hidden_states[b:b+1] 141 | # if missing_type[b] == 0: 142 | # temp_lora = self.image_missing_up(temp) + self.text_missing_up(temp) 143 | # elif missing_type[b] == 1: 144 | # temp_lora = self.text_missing_up(temp) 145 | # else: 146 | # temp_lora = self.image_missing_up(temp) 147 | # lora_tensor.append(temp_lora) 148 | # lora_tensor = torch.cat(lora_tensor,dim=0) 149 | # up_hidden_states = lora_tensor 150 | 151 | if missing_type == 0: 152 | up_hidden_states = self.image_missing_up(down_hidden_states) + self.text_missing_up(down_hidden_states) 153 | elif missing_type == 1: 154 | up_hidden_states = self.text_missing_up(down_hidden_states) 155 | #up_hidden_states = self.image_missing_up(down_hidden_states) + self.text_missing_up(down_hidden_states) 156 | else: 157 | up_hidden_states = self.image_missing_up(down_hidden_states) 158 | #up_hidden_states = self.image_missing_up(down_hidden_states) + self.text_missing_up(down_hidden_states) 159 | 160 | 161 | 162 | if self.network_alpha is not None: 163 | up_hidden_states *= self.network_alpha / self.rank 164 | 165 | return up_hidden_states 166 | 167 | def up_weights(self): 168 | return [self.text_missing_up.weight, self.image_missing_up.weight] 169 | -------------------------------------------------------------------------------- /ODIR/vilt/datamodules/datamodule_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from pytorch_lightning import LightningDataModule 4 | from torch.utils.data import DataLoader 5 | from transformers import ( 6 | DataCollatorForLanguageModeling, 7 | DataCollatorForWholeWordMask, 8 | BertTokenizer, 9 | ) 10 | 11 | 12 | def 
get_pretrained_tokenizer(from_pretrained): 13 | if torch.distributed.is_initialized(): 14 | if torch.distributed.get_rank() == 0: 15 | BertTokenizer.from_pretrained( 16 | from_pretrained, do_lower_case="uncased" in from_pretrained 17 | ) 18 | torch.distributed.barrier() 19 | return BertTokenizer.from_pretrained( 20 | from_pretrained, do_lower_case="uncased" in from_pretrained 21 | ) 22 | 23 | 24 | class BaseDataModule(LightningDataModule): 25 | def __init__(self, _config): 26 | super().__init__() 27 | 28 | self.data_dir = _config["data_root"] 29 | 30 | self.num_workers = _config["num_workers"] 31 | self.batch_size = _config["per_gpu_batchsize"] 32 | self.eval_batch_size = self.batch_size 33 | 34 | self.image_size = _config["image_size"] 35 | self.max_text_len = _config["max_text_len"] 36 | self.draw_false_image = _config["draw_false_image"] 37 | self.draw_false_text = _config["draw_false_text"] 38 | self.image_only = _config["image_only"] 39 | 40 | # construct missing modality info 41 | self.missing_info = { 42 | 'ratio' : _config["missing_ratio"], 43 | 'type' : _config["missing_type"], 44 | 'both_ratio' : _config["both_ratio"], 45 | 'missing_table_root': _config["missing_table_root"], 46 | 'simulate_missing' : _config["simulate_missing"] 47 | } 48 | # for bash execution 49 | if _config["test_ratio"] is not None: 50 | self.missing_info['ratio']['val'] = _config["test_ratio"] 51 | self.missing_info['ratio']['test'] = _config["test_ratio"] 52 | if _config["test_type"] is not None: 53 | self.missing_info['type']['val'] = _config["test_type"] 54 | self.missing_info['type']['test'] = _config["test_type"] 55 | 56 | self.train_transform_keys = ( 57 | ["default_train"] 58 | if len(_config["train_transform_keys"]) == 0 59 | else _config["train_transform_keys"] 60 | ) 61 | 62 | self.val_transform_keys = ( 63 | ["default_val"] 64 | if len(_config["val_transform_keys"]) == 0 65 | else _config["val_transform_keys"] 66 | ) 67 | 68 | tokenizer = _config["tokenizer"] 69 | self.tokenizer = get_pretrained_tokenizer(tokenizer) 70 | self.vocab_size = self.tokenizer.vocab_size 71 | 72 | collator = ( 73 | DataCollatorForWholeWordMask 74 | if _config["whole_word_masking"] 75 | else DataCollatorForLanguageModeling 76 | ) 77 | 78 | self.mlm_collator = collator( 79 | tokenizer=self.tokenizer, mlm=True, mlm_probability=_config["mlm_prob"] 80 | ) 81 | self.setup_flag = False 82 | 83 | @property 84 | def dataset_cls(self): 85 | raise NotImplementedError("return tuple of dataset class") 86 | 87 | @property 88 | def dataset_name(self): 89 | raise NotImplementedError("return name of dataset") 90 | 91 | def set_train_dataset(self): 92 | self.train_dataset = self.dataset_cls( 93 | self.data_dir, 94 | self.train_transform_keys, 95 | split="train", 96 | image_size=self.image_size, 97 | max_text_len=self.max_text_len, 98 | draw_false_image=self.draw_false_image, 99 | draw_false_text=self.draw_false_text, 100 | image_only=self.image_only, 101 | missing_info=self.missing_info, 102 | ) 103 | 104 | def set_val_dataset(self): 105 | self.val_dataset = self.dataset_cls( 106 | self.data_dir, 107 | self.val_transform_keys, 108 | split="val", 109 | image_size=self.image_size, 110 | max_text_len=self.max_text_len, 111 | draw_false_image=self.draw_false_image, 112 | draw_false_text=self.draw_false_text, 113 | image_only=self.image_only, 114 | missing_info=self.missing_info, 115 | ) 116 | 117 | if hasattr(self, "dataset_cls_no_false"): 118 | self.val_dataset_no_false = self.dataset_cls_no_false( 119 | self.data_dir, 120 | 
self.val_transform_keys, 121 | split="val", 122 | image_size=self.image_size, 123 | max_text_len=self.max_text_len, 124 | draw_false_image=0, 125 | draw_false_text=0, 126 | image_only=self.image_only, 127 | ) 128 | 129 | def make_no_false_val_dset(self, image_only=False): 130 | return self.dataset_cls_no_false( 131 | self.data_dir, 132 | self.val_transform_keys, 133 | split="val", 134 | image_size=self.image_size, 135 | max_text_len=self.max_text_len, 136 | draw_false_image=0, 137 | draw_false_text=0, 138 | image_only=image_only, 139 | ) 140 | 141 | def set_test_dataset(self): 142 | self.test_dataset = self.dataset_cls( 143 | self.data_dir, 144 | self.val_transform_keys, 145 | split="test", 146 | image_size=self.image_size, 147 | max_text_len=self.max_text_len, 148 | draw_false_image=self.draw_false_image, 149 | draw_false_text=self.draw_false_text, 150 | image_only=self.image_only, 151 | missing_info=self.missing_info, 152 | ) 153 | 154 | def setup(self, stage): 155 | if not self.setup_flag: 156 | self.set_train_dataset() 157 | self.set_val_dataset() 158 | self.set_test_dataset() 159 | 160 | self.train_dataset.tokenizer = self.tokenizer 161 | self.val_dataset.tokenizer = self.tokenizer 162 | self.test_dataset.tokenizer = self.tokenizer 163 | 164 | self.setup_flag = True 165 | 166 | def train_dataloader(self): 167 | loader = DataLoader( 168 | self.train_dataset, 169 | batch_size=self.batch_size, 170 | shuffle=True, 171 | num_workers=self.num_workers, 172 | pin_memory=True, 173 | collate_fn=self.train_dataset.collate, 174 | ) 175 | return loader 176 | 177 | def val_dataloader(self): 178 | loader = DataLoader( 179 | self.val_dataset, 180 | batch_size=self.eval_batch_size, 181 | shuffle=False, 182 | num_workers=self.num_workers, 183 | pin_memory=True, 184 | collate_fn=self.val_dataset.collate, 185 | ) 186 | return loader 187 | 188 | def test_dataloader(self): 189 | loader = DataLoader( 190 | self.test_dataset, 191 | batch_size=self.eval_batch_size, 192 | shuffle=False, 193 | num_workers=self.num_workers, 194 | pin_memory=True, 195 | collate_fn=self.test_dataset.collate, 196 | ) 197 | return loader 198 | -------------------------------------------------------------------------------- /chestXray/vilt/datamodules/datamodule_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from pytorch_lightning import LightningDataModule 4 | from torch.utils.data import DataLoader 5 | from transformers import ( 6 | DataCollatorForLanguageModeling, 7 | DataCollatorForWholeWordMask, 8 | BertTokenizer, 9 | ) 10 | 11 | 12 | def get_pretrained_tokenizer(from_pretrained): 13 | if torch.distributed.is_initialized(): 14 | if torch.distributed.get_rank() == 0: 15 | BertTokenizer.from_pretrained( 16 | from_pretrained, do_lower_case="uncased" in from_pretrained 17 | ) 18 | torch.distributed.barrier() 19 | return BertTokenizer.from_pretrained( 20 | from_pretrained, do_lower_case="uncased" in from_pretrained 21 | ) 22 | 23 | 24 | class BaseDataModule(LightningDataModule): 25 | def __init__(self, _config): 26 | super().__init__() 27 | 28 | self.data_dir = _config["data_root"] 29 | 30 | self.num_workers = _config["num_workers"] 31 | self.batch_size = _config["per_gpu_batchsize"] 32 | self.eval_batch_size = self.batch_size 33 | 34 | self.image_size = _config["image_size"] 35 | self.max_text_len = _config["max_text_len"] 36 | self.draw_false_image = _config["draw_false_image"] 37 | self.draw_false_text = _config["draw_false_text"] 38 | self.image_only = 
_config["image_only"] 39 | 40 | # construct missing modality info 41 | self.missing_info = { 42 | 'ratio' : _config["missing_ratio"], 43 | 'type' : _config["missing_type"], 44 | 'both_ratio' : _config["both_ratio"], 45 | 'missing_table_root': _config["missing_table_root"], 46 | 'simulate_missing' : _config["simulate_missing"] 47 | } 48 | # for bash execution 49 | if _config["test_ratio"] is not None: 50 | self.missing_info['ratio']['val'] = _config["test_ratio"] 51 | self.missing_info['ratio']['test'] = _config["test_ratio"] 52 | if _config["test_type"] is not None: 53 | self.missing_info['type']['val'] = _config["test_type"] 54 | self.missing_info['type']['test'] = _config["test_type"] 55 | 56 | self.train_transform_keys = ( 57 | ["default_train"] 58 | if len(_config["train_transform_keys"]) == 0 59 | else _config["train_transform_keys"] 60 | ) 61 | 62 | self.val_transform_keys = ( 63 | ["default_val"] 64 | if len(_config["val_transform_keys"]) == 0 65 | else _config["val_transform_keys"] 66 | ) 67 | 68 | tokenizer = _config["tokenizer"] 69 | self.tokenizer = get_pretrained_tokenizer(tokenizer) 70 | self.vocab_size = self.tokenizer.vocab_size 71 | 72 | collator = ( 73 | DataCollatorForWholeWordMask 74 | if _config["whole_word_masking"] 75 | else DataCollatorForLanguageModeling 76 | ) 77 | 78 | self.mlm_collator = collator( 79 | tokenizer=self.tokenizer, mlm=True, mlm_probability=_config["mlm_prob"] 80 | ) 81 | self.setup_flag = False 82 | 83 | @property 84 | def dataset_cls(self): 85 | raise NotImplementedError("return tuple of dataset class") 86 | 87 | @property 88 | def dataset_name(self): 89 | raise NotImplementedError("return name of dataset") 90 | 91 | def set_train_dataset(self): 92 | self.train_dataset = self.dataset_cls( 93 | self.data_dir, 94 | self.train_transform_keys, 95 | split="train", 96 | image_size=self.image_size, 97 | max_text_len=self.max_text_len, 98 | draw_false_image=self.draw_false_image, 99 | draw_false_text=self.draw_false_text, 100 | image_only=self.image_only, 101 | missing_info=self.missing_info, 102 | ) 103 | 104 | def set_val_dataset(self): 105 | self.val_dataset = self.dataset_cls( 106 | self.data_dir, 107 | self.val_transform_keys, 108 | split="val", 109 | image_size=self.image_size, 110 | max_text_len=self.max_text_len, 111 | draw_false_image=self.draw_false_image, 112 | draw_false_text=self.draw_false_text, 113 | image_only=self.image_only, 114 | missing_info=self.missing_info, 115 | ) 116 | 117 | if hasattr(self, "dataset_cls_no_false"): 118 | self.val_dataset_no_false = self.dataset_cls_no_false( 119 | self.data_dir, 120 | self.val_transform_keys, 121 | split="val", 122 | image_size=self.image_size, 123 | max_text_len=self.max_text_len, 124 | draw_false_image=0, 125 | draw_false_text=0, 126 | image_only=self.image_only, 127 | ) 128 | 129 | def make_no_false_val_dset(self, image_only=False): 130 | return self.dataset_cls_no_false( 131 | self.data_dir, 132 | self.val_transform_keys, 133 | split="val", 134 | image_size=self.image_size, 135 | max_text_len=self.max_text_len, 136 | draw_false_image=0, 137 | draw_false_text=0, 138 | image_only=image_only, 139 | ) 140 | 141 | def set_test_dataset(self): 142 | self.test_dataset = self.dataset_cls( 143 | self.data_dir, 144 | self.val_transform_keys, 145 | split="test", 146 | image_size=self.image_size, 147 | max_text_len=self.max_text_len, 148 | draw_false_image=self.draw_false_image, 149 | draw_false_text=self.draw_false_text, 150 | image_only=self.image_only, 151 | missing_info=self.missing_info, 152 | ) 153 | 
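    # Editor's note, a hedged sketch (the values below are assumptions, not taken from the
    # repo's config.py): BaseDataModule only reads the _config keys referenced above, so a
    # minimal config for a 70%-missing "both" setup could look like:
    #
    #     _config = {
    #         "data_root": "./datasets", "num_workers": 8, "per_gpu_batchsize": 32,
    #         "image_size": 384, "max_text_len": 40,
    #         "draw_false_image": 0, "draw_false_text": 0, "image_only": False,
    #         "missing_ratio": {"train": 0.7, "val": 0.7, "test": 0.7},
    #         "missing_type": {"train": "both", "val": "both", "test": "both"},
    #         "both_ratio": 0.5, "missing_table_root": "./datasets/missing_tables/",
    #         "simulate_missing": False,
    #         "train_transform_keys": [], "val_transform_keys": [],
    #         "tokenizer": "bert-base-uncased", "whole_word_masking": False, "mlm_prob": 0.15,
    #         "test_ratio": None, "test_type": None,
    #     }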
154 | def setup(self, stage): 155 | if not self.setup_flag: 156 | self.set_train_dataset() 157 | self.set_val_dataset() 158 | self.set_test_dataset() 159 | 160 | self.train_dataset.tokenizer = self.tokenizer 161 | self.val_dataset.tokenizer = self.tokenizer 162 | self.test_dataset.tokenizer = self.tokenizer 163 | 164 | self.setup_flag = True 165 | 166 | def train_dataloader(self): 167 | loader = DataLoader( 168 | self.train_dataset, 169 | batch_size=self.batch_size, 170 | shuffle=True, 171 | num_workers=self.num_workers, 172 | pin_memory=True, 173 | collate_fn=self.train_dataset.collate, 174 | ) 175 | return loader 176 | 177 | def val_dataloader(self): 178 | loader = DataLoader( 179 | self.val_dataset, 180 | batch_size=self.eval_batch_size, 181 | shuffle=False, 182 | num_workers=self.num_workers, 183 | pin_memory=True, 184 | collate_fn=self.val_dataset.collate, 185 | ) 186 | return loader 187 | 188 | def test_dataloader(self): 189 | loader = DataLoader( 190 | self.test_dataset, 191 | batch_size=self.eval_batch_size, 192 | shuffle=False, 193 | num_workers=self.num_workers, 194 | pin_memory=True, 195 | collate_fn=self.test_dataset.collate, 196 | ) 197 | return loader 198 | -------------------------------------------------------------------------------- /ODIR/vilt/gadgets/my_metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchmetrics.functional import f1_score, auroc 3 | from pytorch_lightning.metrics import Metric 4 | 5 | 6 | class Accuracy(Metric): 7 | def __init__(self, dist_sync_on_step=False): 8 | super().__init__(dist_sync_on_step=dist_sync_on_step) 9 | self.add_state("correct", default=torch.tensor(0.0), dist_reduce_fx="sum") 10 | self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum") 11 | 12 | def update(self, logits, target): 13 | logits, target = ( 14 | logits.detach().to(self.correct.device), 15 | target.detach().to(self.correct.device), 16 | ) 17 | if logits.size(-1)>1: 18 | preds = logits.argmax(dim=-1) 19 | else: 20 | preds = (torch.sigmoid(logits)>0.5).long() 21 | 22 | preds = preds[target != -100] 23 | target = target[target != -100] 24 | if target.numel() == 0: 25 | return 1 26 | 27 | assert preds.shape == target.shape 28 | 29 | self.correct += torch.sum(preds == target) 30 | self.total += target.numel() 31 | 32 | def compute(self): 33 | return self.correct / self.total 34 | 35 | class AUROC(Metric): 36 | def __init__(self, dist_sync_on_step=False): 37 | super().__init__(dist_sync_on_step=dist_sync_on_step) 38 | self.add_state("correct", default=torch.tensor(0.0), dist_reduce_fx="sum") 39 | self.add_state("logits", default=[], dist_reduce_fx="cat") 40 | self.add_state("targets", default=[], dist_reduce_fx="cat") 41 | 42 | def update(self, logits, target): 43 | logits, targets = ( 44 | logits.detach().to(self.correct.device), 45 | target.detach().to(self.correct.device), 46 | ) 47 | 48 | self.logits.append(logits) 49 | self.targets.append(targets) 50 | 51 | 52 | def compute(self): 53 | if type(self.logits) == list: 54 | all_logits = torch.cat(self.logits) 55 | all_targets = torch.cat(self.targets).long() 56 | else: 57 | all_logits = self.logits 58 | all_targets = self.targets.long() 59 | 60 | if all_logits.size(-1)>1: 61 | all_logits = torch.softmax(all_logits, dim=1) 62 | AUROC = auroc(all_logits, all_targets, num_classes=2) 63 | else: 64 | all_logits = torch.sigmoid(all_logits) 65 | AUROC = auroc(all_logits, all_targets) 66 | 67 | return AUROC 68 | 69 | class F1_Score(Metric): 70 | def 
__init__(self, dist_sync_on_step=False): 71 | super().__init__(dist_sync_on_step=dist_sync_on_step) 72 | self.add_state("correct", default=torch.tensor(0.0), dist_reduce_fx="sum") 73 | self.add_state("logits", default=[], dist_reduce_fx="cat") 74 | self.add_state("targets", default=[], dist_reduce_fx="cat") 75 | 76 | def update(self, logits, target): 77 | logits, targets = ( 78 | logits.detach().to(self.correct.device), 79 | target.detach().to(self.correct.device), 80 | ) 81 | 82 | self.logits.append(logits) 83 | self.targets.append(targets) 84 | 85 | 86 | def compute(self, use_sigmoid=True): 87 | if type(self.logits) == list: 88 | all_logits = torch.cat(self.logits) 89 | all_targets = torch.cat(self.targets).long() 90 | else: 91 | all_logits = self.logits 92 | all_targets = self.targets.long() 93 | if use_sigmoid: 94 | all_logits = torch.sigmoid(all_logits) 95 | F1_Micro = f1_score(all_logits, all_targets, average='micro') 96 | F1_Macro = f1_score(all_logits, all_targets, average='macro', num_classes=8) 97 | F1_Samples = f1_score(all_logits, all_targets, average='samples') 98 | F1_Weighted = f1_score(all_logits, all_targets, average='weighted', num_classes=8) 99 | return (F1_Micro, F1_Macro, F1_Samples, F1_Weighted) 100 | 101 | class check(Metric): 102 | def __init__(self, dist_sync_on_step=False): 103 | super().__init__(dist_sync_on_step=dist_sync_on_step) 104 | self.add_state("correct", default=torch.tensor(0.0), dist_reduce_fx="sum") 105 | self.add_state("logits", default=[], dist_reduce_fx="cat") 106 | self.add_state("targets", default=[], dist_reduce_fx="cat") 107 | 108 | def update(self, logits, target): 109 | logits, targets = ( 110 | logits.detach().to(self.correct.device), 111 | target.detach().to(self.correct.device), 112 | ) 113 | 114 | self.logits.append(logits) 115 | self.targets.append(targets) 116 | 117 | 118 | def compute(self, use_sigmoid=True): 119 | if type(self.logits) == list: 120 | all_logits = torch.cat(self.logits).long() 121 | all_targets = torch.cat(self.targets).long() 122 | else: 123 | all_logits = self.logits.long() 124 | all_targets = self.targets.long() 125 | 126 | mislead = all_logits ^ all_targets 127 | accuracy = mislead.sum(dim=0) 128 | return accuracy 129 | 130 | class Scalar(Metric): 131 | def __init__(self, dist_sync_on_step=False): 132 | super().__init__(dist_sync_on_step=dist_sync_on_step) 133 | self.add_state("scalar", default=torch.tensor(0.0), dist_reduce_fx="sum") 134 | self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum") 135 | 136 | def update(self, scalar): 137 | if isinstance(scalar, torch.Tensor): 138 | scalar = scalar.detach().to(self.scalar.device) 139 | else: 140 | scalar = torch.tensor(scalar).float().to(self.scalar.device) 141 | self.scalar += scalar 142 | self.total += 1 143 | 144 | def compute(self): 145 | return self.scalar / self.total 146 | 147 | class Scalar2(Metric): 148 | def __init__(self, dist_sync_on_step=False): 149 | super().__init__(dist_sync_on_step=dist_sync_on_step) 150 | self.add_state("scalar", default=torch.tensor(0.0), dist_reduce_fx="sum") 151 | self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum") 152 | 153 | def update(self, scalar, num): 154 | if isinstance(scalar, torch.Tensor): 155 | scalar = scalar.detach().to(self.scalar.device) 156 | else: 157 | scalar = torch.tensor(scalar).float().to(self.scalar.device) 158 | 159 | self.scalar += scalar 160 | self.total += num 161 | 162 | def compute(self): 163 | return self.scalar / self.total 164 | 165 | 166 | class 
VQAScore(Metric): 167 | def __init__(self, dist_sync_on_step=False): 168 | super().__init__(dist_sync_on_step=dist_sync_on_step) 169 | self.add_state("score", default=torch.tensor(0.0), dist_reduce_fx="sum") 170 | self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum") 171 | 172 | def update(self, logits, target): 173 | logits, target = ( 174 | logits.detach().float().to(self.score.device), 175 | target.detach().float().to(self.score.device), 176 | ) 177 | logits = torch.max(logits, 1)[1] 178 | one_hots = torch.zeros(*target.size()).to(target) 179 | one_hots.scatter_(1, logits.view(-1, 1), 1) 180 | scores = one_hots * target 181 | 182 | self.score += scores.sum() 183 | self.total += len(logits) 184 | 185 | def compute(self): 186 | return self.score / self.total 187 | -------------------------------------------------------------------------------- /chestXray/vilt/gadgets/my_metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchmetrics.functional import f1_score, auroc 3 | from pytorch_lightning.metrics import Metric 4 | 5 | 6 | class Accuracy(Metric): 7 | def __init__(self, dist_sync_on_step=False): 8 | super().__init__(dist_sync_on_step=dist_sync_on_step) 9 | self.add_state("correct", default=torch.tensor(0.0), dist_reduce_fx="sum") 10 | self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum") 11 | 12 | def update(self, logits, target): 13 | logits, target = ( 14 | logits.detach().to(self.correct.device), 15 | target.detach().to(self.correct.device), 16 | ) 17 | if logits.size(-1)>1: 18 | preds = logits.argmax(dim=-1) 19 | else: 20 | preds = (torch.sigmoid(logits)>0.5).long() 21 | 22 | preds = preds[target != -100] 23 | target = target[target != -100] 24 | if target.numel() == 0: 25 | return 1 26 | 27 | assert preds.shape == target.shape 28 | 29 | self.correct += torch.sum(preds == target) 30 | self.total += target.numel() 31 | 32 | def compute(self): 33 | return self.correct / self.total 34 | 35 | class AUROC(Metric): 36 | def __init__(self, dist_sync_on_step=False): 37 | super().__init__(dist_sync_on_step=dist_sync_on_step) 38 | self.add_state("correct", default=torch.tensor(0.0), dist_reduce_fx="sum") 39 | self.add_state("logits", default=[], dist_reduce_fx="cat") 40 | self.add_state("targets", default=[], dist_reduce_fx="cat") 41 | 42 | def update(self, logits, target): 43 | logits, targets = ( 44 | logits.detach().to(self.correct.device), 45 | target.detach().to(self.correct.device), 46 | ) 47 | 48 | self.logits.append(logits) 49 | self.targets.append(targets) 50 | 51 | 52 | def compute(self): 53 | if type(self.logits) == list: 54 | all_logits = torch.cat(self.logits) 55 | all_targets = torch.cat(self.targets).long() 56 | else: 57 | all_logits = self.logits 58 | all_targets = self.targets.long() 59 | 60 | if all_logits.size(-1)>1: 61 | all_logits = torch.softmax(all_logits, dim=1) 62 | AUROC = auroc(all_logits, all_targets, num_classes=2) 63 | else: 64 | all_logits = torch.sigmoid(all_logits) 65 | AUROC = auroc(all_logits, all_targets) 66 | 67 | return AUROC 68 | 69 | class F1_Score(Metric): 70 | def __init__(self, dist_sync_on_step=False): 71 | super().__init__(dist_sync_on_step=dist_sync_on_step) 72 | self.add_state("correct", default=torch.tensor(0.0), dist_reduce_fx="sum") 73 | self.add_state("logits", default=[], dist_reduce_fx="cat") 74 | self.add_state("targets", default=[], dist_reduce_fx="cat") 75 | 76 | def update(self, logits, target): 77 | logits, targets = ( 78 | 
logits.detach().to(self.correct.device), 79 | target.detach().to(self.correct.device), 80 | ) 81 | 82 | self.logits.append(logits) 83 | self.targets.append(targets) 84 | 85 | 86 | def compute(self, use_sigmoid=True): 87 | if type(self.logits) == list: 88 | all_logits = torch.cat(self.logits) 89 | all_targets = torch.cat(self.targets).long() 90 | else: 91 | all_logits = self.logits 92 | all_targets = self.targets.long() 93 | if use_sigmoid: 94 | all_logits = torch.sigmoid(all_logits) 95 | F1_Micro = f1_score(all_logits, all_targets, average='micro') 96 | F1_Macro = f1_score(all_logits, all_targets, average='macro', num_classes=20) 97 | F1_Samples = f1_score(all_logits, all_targets, average='samples') 98 | F1_Weighted = f1_score(all_logits, all_targets, average='weighted', num_classes=20) 99 | return (F1_Micro, F1_Macro, F1_Samples, F1_Weighted) 100 | 101 | class check(Metric): 102 | def __init__(self, dist_sync_on_step=False): 103 | super().__init__(dist_sync_on_step=dist_sync_on_step) 104 | self.add_state("correct", default=torch.tensor(0.0), dist_reduce_fx="sum") 105 | self.add_state("logits", default=[], dist_reduce_fx="cat") 106 | self.add_state("targets", default=[], dist_reduce_fx="cat") 107 | 108 | def update(self, logits, target): 109 | logits, targets = ( 110 | logits.detach().to(self.correct.device), 111 | target.detach().to(self.correct.device), 112 | ) 113 | 114 | self.logits.append(logits) 115 | self.targets.append(targets) 116 | 117 | 118 | def compute(self, use_sigmoid=True): 119 | if type(self.logits) == list: 120 | all_logits = torch.cat(self.logits).long() 121 | all_targets = torch.cat(self.targets).long() 122 | else: 123 | all_logits = self.logits.long() 124 | all_targets = self.targets.long() 125 | 126 | mislead = all_logits ^ all_targets 127 | accuracy = mislead.sum(dim=0) 128 | return accuracy 129 | 130 | class Scalar(Metric): 131 | def __init__(self, dist_sync_on_step=False): 132 | super().__init__(dist_sync_on_step=dist_sync_on_step) 133 | self.add_state("scalar", default=torch.tensor(0.0), dist_reduce_fx="sum") 134 | self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum") 135 | 136 | def update(self, scalar): 137 | if isinstance(scalar, torch.Tensor): 138 | scalar = scalar.detach().to(self.scalar.device) 139 | else: 140 | scalar = torch.tensor(scalar).float().to(self.scalar.device) 141 | self.scalar += scalar 142 | self.total += 1 143 | 144 | def compute(self): 145 | return self.scalar / self.total 146 | 147 | class Scalar2(Metric): 148 | def __init__(self, dist_sync_on_step=False): 149 | super().__init__(dist_sync_on_step=dist_sync_on_step) 150 | self.add_state("scalar", default=torch.tensor(0.0), dist_reduce_fx="sum") 151 | self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum") 152 | 153 | def update(self, scalar, num): 154 | if isinstance(scalar, torch.Tensor): 155 | scalar = scalar.detach().to(self.scalar.device) 156 | else: 157 | scalar = torch.tensor(scalar).float().to(self.scalar.device) 158 | 159 | self.scalar += scalar 160 | self.total += num 161 | 162 | def compute(self): 163 | return self.scalar / self.total 164 | 165 | 166 | class VQAScore(Metric): 167 | def __init__(self, dist_sync_on_step=False): 168 | super().__init__(dist_sync_on_step=dist_sync_on_step) 169 | self.add_state("score", default=torch.tensor(0.0), dist_reduce_fx="sum") 170 | self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum") 171 | 172 | def update(self, logits, target): 173 | logits, target = ( 174 | 
logits.detach().float().to(self.score.device),
175 |             target.detach().float().to(self.score.device),
176 |         )
177 |         logits = torch.max(logits, 1)[1]
178 |         one_hots = torch.zeros(*target.size()).to(target)
179 |         one_hots.scatter_(1, logits.view(-1, 1), 1)
180 |         scores = one_hots * target
181 | 
182 |         self.score += scores.sum()
183 |         self.total += len(logits)
184 | 
185 |     def compute(self):
186 |         return self.score / self.total
187 | 
--------------------------------------------------------------------------------
/ODIR/vilt/transforms/randaug.py:
--------------------------------------------------------------------------------
1 | # code in this file is adapted from rpmcruz/autoaugment
2 | # https://github.com/rpmcruz/autoaugment/blob/master/transformations.py
3 | import random
4 | 
5 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw
6 | import numpy as np
7 | import torch
8 | from PIL import Image
9 | 
10 | 
11 | def ShearX(img, v):  # [-0.3, 0.3]
12 |     assert -0.3 <= v <= 0.3
13 |     if random.random() > 0.5:
14 |         v = -v
15 |     return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
16 | 
17 | 
18 | def ShearY(img, v):  # [-0.3, 0.3]
19 |     assert -0.3 <= v <= 0.3
20 |     if random.random() > 0.5:
21 |         v = -v
22 |     return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
23 | 
24 | 
25 | def TranslateX(img, v):  # [-150, 150] => percentage: [-0.45, 0.45]
26 |     assert -0.45 <= v <= 0.45
27 |     if random.random() > 0.5:
28 |         v = -v
29 |     v = v * img.size[0]
30 |     return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
31 | 
32 | 
33 | def TranslateXabs(img, v):  # absolute pixels, v >= 0
34 |     assert 0 <= v
35 |     if random.random() > 0.5:
36 |         v = -v
37 |     return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
38 | 
39 | 
40 | def TranslateY(img, v):  # [-150, 150] => percentage: [-0.45, 0.45]
41 |     assert -0.45 <= v <= 0.45
42 |     if random.random() > 0.5:
43 |         v = -v
44 |     v = v * img.size[1]
45 |     return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
46 | 
47 | 
48 | def TranslateYabs(img, v):  # absolute pixels, v >= 0
49 |     assert 0 <= v
50 |     if random.random() > 0.5:
51 |         v = -v
52 |     return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
53 | 
54 | 
55 | def Rotate(img, v):  # [-30, 30]
56 |     assert -30 <= v <= 30
57 |     if random.random() > 0.5:
58 |         v = -v
59 |     return img.rotate(v)
60 | 
61 | 
62 | def AutoContrast(img, _):
63 |     return PIL.ImageOps.autocontrast(img)
64 | 
65 | 
66 | def Invert(img, _):
67 |     return PIL.ImageOps.invert(img)
68 | 
69 | 
70 | def Equalize(img, _):
71 |     return PIL.ImageOps.equalize(img)
72 | 
73 | 
74 | def Flip(img, _):  # not from the paper
75 |     return PIL.ImageOps.mirror(img)
76 | 
77 | 
78 | def Solarize(img, v):  # [0, 256]
79 |     assert 0 <= v <= 256
80 |     return PIL.ImageOps.solarize(img, v)
81 | 
82 | 
83 | def SolarizeAdd(img, addition=0, threshold=128):
84 |     img_np = np.array(img).astype(int)  # builtin int: np.int was removed in NumPy 1.24
85 |     img_np = img_np + addition
86 |     img_np = np.clip(img_np, 0, 255)
87 |     img_np = img_np.astype(np.uint8)
88 |     img = Image.fromarray(img_np)
89 |     return PIL.ImageOps.solarize(img, threshold)
90 | 
91 | 
92 | def Posterize(img, v):  # [4, 8]
93 |     v = int(v)
94 |     v = max(1, v)
95 |     return PIL.ImageOps.posterize(img, v)
96 | 
97 | 
98 | def Contrast(img, v):  # [0.1,1.9]
99 |     assert 0.1 <= v <= 1.9
100 |     return PIL.ImageEnhance.Contrast(img).enhance(v)
101 | 
102 | 
103 | def Color(img, v):  # [0.1,1.9]
104 |     assert 0.1 <= v <= 1.9
105 |     return PIL.ImageEnhance.Color(img).enhance(v)
106 | 
107 | 
108 | def Brightness(img, v):  # [0.1,1.9]
109 |     assert 0.1 <= v <= 1.9
110 |     return PIL.ImageEnhance.Brightness(img).enhance(v)
111 | 
112 | 
113 | def Sharpness(img, v):  # [0.1,1.9]
114 |     assert 0.1 <= v <= 1.9
115 |     return PIL.ImageEnhance.Sharpness(img).enhance(v)
116 | 
117 | 
118 | def Cutout(img, v):  # [0, 60] => percentage: [0, 0.2]
119 |     assert 0.0 <= v <= 0.2
120 |     if v <= 0.0:
121 |         return img
122 | 
123 |     v = v * img.size[0]
124 |     return CutoutAbs(img, v)
125 | 
126 | 
127 | def CutoutAbs(img, v):  # absolute pixels
128 |     # assert 0 <= v <= 20
129 |     if v < 0:
130 |         return img
131 |     w, h = img.size
132 |     x0 = np.random.uniform(w)
133 |     y0 = np.random.uniform(h)
134 | 
135 |     x0 = int(max(0, x0 - v / 2.0))
136 |     y0 = int(max(0, y0 - v / 2.0))
137 |     x1 = min(w, x0 + v)
138 |     y1 = min(h, y0 + v)
139 | 
140 |     xy = (x0, y0, x1, y1)
141 |     color = (125, 123, 114)
142 |     # color = (0, 0, 0)
143 |     img = img.copy()
144 |     PIL.ImageDraw.Draw(img).rectangle(xy, color)
145 |     return img
146 | 
147 | 
148 | def SamplePairing(imgs):  # [0, 0.4]
149 |     def f(img1, v):
150 |         i = np.random.choice(len(imgs))
151 |         img2 = PIL.Image.fromarray(imgs[i])
152 |         return PIL.Image.blend(img1, img2, v)
153 | 
154 |     return f
155 | 
156 | 
157 | def Identity(img, v):
158 |     return img
159 | 
160 | 
161 | def augment_list():  # 14 operations and their ranges
162 |     # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57
163 |     # l = [
164 |     #     (Identity, 0., 1.0),
165 |     #     (ShearX, 0., 0.3),  # 0
166 |     #     (ShearY, 0., 0.3),  # 1
167 |     #     (TranslateX, 0., 0.33),  # 2
168 |     #     (TranslateY, 0., 0.33),  # 3
169 |     #     (Rotate, 0, 30),  # 4
170 |     #     (AutoContrast, 0, 1),  # 5
171 |     #     (Invert, 0, 1),  # 6
172 |     #     (Equalize, 0, 1),  # 7
173 |     #     (Solarize, 0, 110),  # 8
174 |     #     (Posterize, 4, 8),  # 9
175 |     #     # (Contrast, 0.1, 1.9),  # 10
176 |     #     (Color, 0.1, 1.9),  # 11
177 |     #     (Brightness, 0.1, 1.9),  # 12
178 |     #     (Sharpness, 0.1, 1.9),  # 13
179 |     #     # (Cutout, 0, 0.2),  # 14
180 |     #     # (SamplePairing(imgs), 0, 0.4),  # 15
181 |     # ]
182 | 
183 |     # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505
184 |     l = [
185 |         (AutoContrast, 0, 1),
186 |         (Equalize, 0, 1),
187 |         # (Invert, 0, 1),
188 |         (Rotate, 0, 30),
189 |         (Posterize, 0, 4),
190 |         (Solarize, 0, 256),
191 |         (SolarizeAdd, 0, 110),
192 |         (Color, 0.1, 1.9),
193 |         (Contrast, 0.1, 1.9),
194 |         (Brightness, 0.1, 1.9),
195 |         (Sharpness, 0.1, 1.9),
196 |         (ShearX, 0.0, 0.3),
197 |         (ShearY, 0.0, 0.3),
198 |         # (CutoutAbs, 0, 40),
199 |         (TranslateXabs, 0.0, 100),
200 |         (TranslateYabs, 0.0, 100),
201 |     ]
202 | 
203 |     return l
204 | 
205 | 
206 | class Lighting(object):
207 |     """Lighting noise (AlexNet-style PCA-based noise)"""
208 | 
209 |     def __init__(self, alphastd, eigval, eigvec):
210 |         self.alphastd = alphastd
211 |         self.eigval = torch.Tensor(eigval)
212 |         self.eigvec = torch.Tensor(eigvec)
213 | 
214 |     def __call__(self, img):
215 |         if self.alphastd == 0:
216 |             return img
217 | 
218 |         alpha = img.new().resize_(3).normal_(0, self.alphastd)
219 |         rgb = (
220 |             self.eigvec.type_as(img)
221 |             .clone()
222 |             .mul(alpha.view(1, 3).expand(3, 3))
223 |             .mul(self.eigval.view(1, 3).expand(3, 3))
224 |             .sum(1)
225 |             .squeeze()
226 |         )
227 | 
228 |         return img.add(rgb.view(3, 1, 1).expand_as(img))
229 | 
230 | 
231 | class CutoutDefault(object):
232 |     """
233 |     Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py
234 |     """
235 | 
236 |     def __init__(self, length):
237 |         self.length = length
238 | 
239 |     def 
__call__(self, img): 240 | h, w = img.size(1), img.size(2) 241 | mask = np.ones((h, w), np.float32) 242 | y = np.random.randint(h) 243 | x = np.random.randint(w) 244 | 245 | y1 = np.clip(y - self.length // 2, 0, h) 246 | y2 = np.clip(y + self.length // 2, 0, h) 247 | x1 = np.clip(x - self.length // 2, 0, w) 248 | x2 = np.clip(x + self.length // 2, 0, w) 249 | 250 | mask[y1:y2, x1:x2] = 0.0 251 | mask = torch.from_numpy(mask) 252 | mask = mask.expand_as(img) 253 | img *= mask 254 | return img 255 | 256 | 257 | class RandAugment: 258 | def __init__(self, n, m): 259 | self.n = n 260 | self.m = m # [0, 30] 261 | self.augment_list = augment_list() 262 | 263 | def __call__(self, img): 264 | ops = random.choices(self.augment_list, k=self.n) 265 | for op, minval, maxval in ops: 266 | val = (float(self.m) / 30) * float(maxval - minval) + minval 267 | img = op(img, val) 268 | 269 | return img 270 | --------------------------------------------------------------------------------
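Editor's usage sketch for the RandAugment class above (illustrative; the `n`/`m` values and
the file path are assumptions, not taken from the repo's configs). Each call draws `n` ops at
random (with replacement) and applies each at magnitude `m`, mapped linearly into that op's
[minval, maxval] range:

    from PIL import Image
    aug = RandAugment(n=2, m=9)   # 2 ops per image, magnitude 9 on the 0..30 scale
    img = Image.open("example.jpg").convert("RGB")
    img = aug(img)                # e.g. m=9 maps Rotate's (0, 30) range to 9 degrees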