├── __init__.py
├── eval
│   ├── __init__.py
│   └── vqa
│       ├── __init__.py
│       ├── plot_tail.py
│       └── gqa_eval_from_file.py
├── models
│   ├── __init__.py
│   ├── vqa
│   │   ├── __init__.py
│   │   ├── updn
│   │   │   ├── __init__.py
│   │   │   ├── tda.py
│   │   │   └── net.py
│   │   ├── make_mask.py
│   │   ├── ops
│   │   │   ├── layer_norm.py
│   │   │   └── fc.py
│   │   ├── vqa_adapter.py
│   │   ├── ban
│   │   │   ├── ban.py
│   │   │   └── _ban.py
│   │   └── mcan
│   │       ├── net.py
│   │       └── mca.py
│   ├── model_factory.py
│   ├── fc_models.py
│   ├── coordconv.py
│   └── variable_width_resnet.py
├── utils
│   ├── __init__.py
│   ├── bias_retrievers.py
│   ├── losses.py
│   ├── running_stats.py
│   ├── format_utils.py
│   ├── ema.py
│   ├── trainer_utils.py
│   ├── metric_visualizer.py
│   ├── metrics.py
│   └── data_utils.py
├── .gitignore
├── datasets
│   ├── __init__.py
│   ├── vqa
│   │   ├── __init__.py
│   │   ├── feat_filter.py
│   │   ├── base_vqa_dataset.py
│   │   ├── gqa_feat_preproc.py
│   │   └── ans_punct.py
│   ├── .idea
│   │   ├── .gitignore
│   │   ├── inspectionProfiles
│   │   │   └── profiles_settings.xml
│   │   ├── $CACHE_FILE$
│   │   ├── vcs.xml
│   │   ├── deployment.xml
│   │   ├── modules.xml
│   │   ├── misc.xml
│   │   └── datasets.iml
│   ├── dataloader_factory.py
│   └── shape_generator.py
├── common.sh
├── images
│   ├── main_table.jpg
│   ├── scalability.jpg
│   ├── bias_exploitation.jpg
│   └── distribution_variance.jpg
├── trainers
│   ├── trainer_factory.py
│   ├── __init__.py
│   ├── group_upweighting_trainer.py
│   ├── spectral_decoupling_trainer.py
│   ├── group_dro_trainer.py
│   ├── rubi_trainer.py
│   ├── lnl_trainer.py
│   ├── irm_v1_trainer.py
│   └── learning_from_failure_trainer.py
├── scripts
│   ├── gqa-ood
│   │   ├── run_multiple_gqa.sh
│   │   ├── baseline_gqa.sh
│   │   ├── lff_gqa.sh
│   │   ├── spectral_decoupling_gqa.sh
│   │   ├── upweighting_gqa.sh
│   │   ├── rubi_gqa.sh
│   │   ├── preprocess_gqa.sh
│   │   ├── group_dro_gqa.sh
│   │   ├── irmv1_gqa.sh
│   │   ├── lnl_gqa.sh
│   │   └── hyperparam_search_gqa.sh
│   ├── celebA
│   │   ├── rubi_celebA.sh
│   │   ├── group_upweighting_celebA.sh
│   │   ├── group_dro_celebA.sh
│   │   ├── baseline_celebA.sh
│   │   ├── irmv1_celebA.sh
│   │   ├── spectral_decoupling_celebA.sh
│   │   ├── lnl_celebA.sh
│   │   └── learning_from_failure_celebA.sh
│   ├── biased_mnist
│   │   ├── rubi_biased_mnist.sh
│   │   ├── baseline_biased_mnist.sh
│   │   ├── group_dro_biased_mnist.sh
│   │   ├── group_upweighting_biased_mnist.sh
│   │   ├── irmv1_biased_mnist.sh
│   │   ├── learning_from_failure_biased_mnist.sh
│   │   ├── lnl_biased_mnist.sh
│   │   └── spectral_decoupling_biased_mnist.sh
│   └── biased_mnist_dataset_generator
│       └── full_generator.sh
├── experiments
│   ├── __init__.py
│   ├── gqa_experiments.py
│   └── celebA_experiments.py
├── LICENSE
├── main.py
├── README.md
└── option.py

/__init__.py:
--------------------------------------------------------------------------------
(empty)

/eval/__init__.py:
--------------------------------------------------------------------------------
(empty)

/models/__init__.py:
--------------------------------------------------------------------------------
(empty)

/utils/__init__.py:
--------------------------------------------------------------------------------
(empty)

/.gitignore:
--------------------------------------------------------------------------------
.idea

/datasets/__init__.py:
--------------------------------------------------------------------------------
(empty)
/eval/vqa/__init__.py:
--------------------------------------------------------------------------------
(empty)

/models/vqa/__init__.py:
--------------------------------------------------------------------------------
(empty)

/datasets/vqa/__init__.py:
--------------------------------------------------------------------------------
(empty)

/models/vqa/updn/__init__.py:
--------------------------------------------------------------------------------
(empty)

/datasets/.idea/.gitignore:
--------------------------------------------------------------------------------
# Default ignored files
/workspace.xml

/common.sh:
--------------------------------------------------------------------------------
#!/bin/bash
export PYTHONUNBUFFERED=1
ROOT=/hdd/robik # Change this!

/images/main_table.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erobic/bias-mitigators/HEAD/images/main_table.jpg

/images/scalability.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erobic/bias-mitigators/HEAD/images/scalability.jpg

/images/bias_exploitation.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erobic/bias-mitigators/HEAD/images/bias_exploitation.jpg

/images/distribution_variance.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erobic/bias-mitigators/HEAD/images/distribution_variance.jpg

/trainers/trainer_factory.py:
--------------------------------------------------------------------------------
from trainers import *


def build_trainer(option):
    return eval(option.trainer_name)(option)
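A minimal usage sketch of the factory above (the Namespace is illustrative; real
runs build `option` via option.py). `build_trainer` works because
trainers/__init__.py, shown later in this listing, imports every trainer class
into the `trainers` namespace, so `eval('BaseTrainer')` resolves to the class:

    from argparse import Namespace
    from trainers.trainer_factory import build_trainer

    # Hypothetical option; a real option also carries lr, dataset settings, etc.
    option = Namespace(trainer_name='BaseTrainer')
    trainer = build_trainer(option)  # equivalent to BaseTrainer(option)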
/scripts/gqa-ood/run_multiple_gqa.sh:
--------------------------------------------------------------------------------
#./scripts/gqa-ood/baseline_gqa.sh

#./scripts/gqa-ood/upweighting_gqa.sh
#./scripts/gqa-ood/group_dro_gqa.sh

./scripts/gqa-ood/rubi_gqa.sh
./scripts/gqa-ood/lnl_gqa.sh
#./scripts/gqa-ood/irmv1_gqa.sh
#./scripts/gqa-ood/lff_gqa.sh

/scripts/gqa-ood/baseline_gqa.sh:
--------------------------------------------------------------------------------
#!/bin/bash
source common.sh
set -e
source activate bias_mitigator

TRAINER_NAME='BaseTrainer'
python main.py \
--expt_type gqa_experiments \
--lr 1e-4 \
--weight_decay 0 \
--trainer_name ${TRAINER_NAME} \
--root_dir ${ROOT}

/scripts/celebA/rubi_celebA.sh:
--------------------------------------------------------------------------------
#!/bin/bash
source common.sh
set -e
source activate bias_mitigator

TRAINER_NAME='RUBiTrainer'

python main.py \
--expt_type celebA_experiments \
--trainer_name ${TRAINER_NAME} \
--lr 1e-4 \
--weight_decay 1e-5 \
--root_dir ${ROOT}

/scripts/gqa-ood/lff_gqa.sh:
--------------------------------------------------------------------------------
#!/bin/bash
source common.sh
set -e
source activate bias_mitigator

TRAINER_NAME='LffTrainer'
python main.py \
--expt_type gqa_lff \
--lr 1e-4 \
--weight_decay 0 \
--trainer_name ${TRAINER_NAME} \
--bias_loss_gamma 0.7 \
--root_dir ${ROOT}

/scripts/biased_mnist/rubi_biased_mnist.sh:
--------------------------------------------------------------------------------
#!/bin/bash
source common.sh
set -e
source activate bias_mitigator

TRAINER_NAME='RUBiTrainer'

python main.py \
--expt_type biased_mnist_experiments \
--trainer_name ${TRAINER_NAME} \
--lr 1e-3 \
--weight_decay 1e-5 \
--root_dir ${ROOT}

/scripts/celebA/group_upweighting_celebA.sh:
--------------------------------------------------------------------------------
#!/bin/bash
source common.sh
set -e
source activate bias_mitigator

TRAINER_NAME='GroupUpweightingTrainer'
python main.py \
--expt_type celebA_experiments \
--lr 1e-5 \
--weight_decay 0.1 \
--trainer_name ${TRAINER_NAME} \
--root_dir ${ROOT}

/scripts/celebA/group_dro_celebA.sh:
--------------------------------------------------------------------------------
#!/bin/bash
source common.sh
set -e
source activate bias_mitigator

TRAINER_NAME='GroupDROTrainer'

python main.py \
--expt_type celebA_experiments \
--trainer_name ${TRAINER_NAME} \
--lr 1e-5 \
--weight_decay 0.1 \
--group_weight_step_size 0.01 \
--root_dir ${ROOT}

/scripts/biased_mnist/baseline_biased_mnist.sh:
--------------------------------------------------------------------------------
#!/bin/bash
source common.sh
set -e
source activate bias_mitigator

TRAINER_NAME='BaseTrainer'
lr=1e-3
wd=1e-5

python main.py \
--expt_type biased_mnist_experiments \
--trainer_name ${TRAINER_NAME} \
--lr ${lr} \
--weight_decay ${wd} \
--root_dir ${ROOT}

/scripts/celebA/baseline_celebA.sh:
--------------------------------------------------------------------------------
#!/bin/bash
source common.sh
set -e
source activate bias_mitigator

TRAINER_NAME='BaseTrainer'
lr=1e-3
wd=0
python main.py \
--expt_type celebA_experiments \
--trainer_name ${TRAINER_NAME} \
--lr ${lr} \
--weight_decay ${wd} \
--expt_name ${TRAINER_NAME} \
--root_dir ${ROOT}

/scripts/celebA/irmv1_celebA.sh:
--------------------------------------------------------------------------------
#!/bin/bash
source common.sh
set -e
source activate bias_mitigator

TRAINER_NAME='IRMv1Trainer'
python main.py \
--lr 1e-4 \
--weight_decay 0 \
--expt_type celebA_experiments \
--trainer_name ${TRAINER_NAME} \
--grad_penalty_weight 1 \
--num_envs_per_batch 4 \
--root_dir ${ROOT}

/scripts/biased_mnist/group_dro_biased_mnist.sh:
--------------------------------------------------------------------------------
#!/bin/bash
source common.sh
set -e
source activate bias_mitigator

TRAINER_NAME='GroupDROTrainer'

python -u main.py \
--expt_type biased_mnist_experiments \
--lr 1e-3 \
--weight_decay 1e-5 \
--trainer_name ${TRAINER_NAME} \
--group_weight_step_size 1e-3 \
--root_dir ${ROOT}

/scripts/biased_mnist/group_upweighting_biased_mnist.sh:
--------------------------------------------------------------------------------
#!/bin/bash
source common.sh
set -e
source activate bias_mitigator

TRAINER_NAME='GroupUpweightingTrainer'

CUDA_VISIBLE_DEVICES=0 python main.py \
--expt_type biased_mnist_experiments \
--trainer_name ${TRAINER_NAME} \
--lr 1e-3 \
--weight_decay 1e-5 \
--root_dir ${ROOT}
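The GroupDRO scripts above pass --group_weight_step_size to GroupDROTrainer
(group_dro_trainer.py is not reproduced in this listing). A minimal sketch of
the update that parameter typically controls, assuming the trainer follows the
exponentiated-gradient formulation of Sagawa et al. (2020):

    import torch

    def group_dro_step(group_losses, group_weights, step_size):
        # Upweight groups with higher loss, then renormalize.
        group_weights = group_weights * torch.exp(step_size * group_losses.detach())
        group_weights = group_weights / group_weights.sum()
        # The model minimizes the weighted sum of per-group losses.
        robust_loss = group_weights @ group_losses
        return group_weights, robust_loss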
/scripts/celebA/spectral_decoupling_celebA.sh:
--------------------------------------------------------------------------------
#!/bin/bash
source common.sh
set -e
source activate bias_mitigator

TRAINER_NAME='SpectralDecouplingTrainer'

# Lambdas and gammas are specified in celebA_experiments.py
python main.py \
--lr 1e-4 \
--weight_decay 1e-5 \
--expt_type celebA_experiments \
--trainer_name ${TRAINER_NAME} \
--root_dir ${ROOT}

/scripts/gqa-ood/spectral_decoupling_gqa.sh:
--------------------------------------------------------------------------------
#!/bin/bash
source common.sh
set -e
source activate bias_mitigator

TRAINER_NAME='SpectralDecouplingTrainer'

python main.py \
--lr 1e-4 \
--weight_decay 0 \
--expt_type gqa_sd_experiments \
--trainer_name ${TRAINER_NAME} \
--root_dir ${ROOT} \
--spectral_decoupling_lambda 1e-3 \
--spectral_decoupling_gamma 1e-3

/models/vqa/make_mask.py:
--------------------------------------------------------------------------------
# --------------------------------------------------------
# OpenVQA
# Written by Yuhao Cui https://github.com/cuiyuhao1996
# --------------------------------------------------------

import torch


# Mask the positions whose feature vector is all zeros (i.e., padded objects).
# Input: [batch, num_objs, feat_dim] -> output mask: [batch, 1, 1, num_objs]
def make_mask(feature):
    return (torch.sum(
        torch.abs(feature),
        dim=-1
    ) == 0).unsqueeze(1).unsqueeze(2)

/scripts/celebA/lnl_celebA.sh:
--------------------------------------------------------------------------------
#!/bin/bash
source common.sh
set -e
source activate bias_mitigator

TRAINER_NAME='LNLTrainer'
python main.py \
--expt_type celebA_experiments \
--lr 1e-4 \
--weight_decay 1e-4 \
--trainer_name ${TRAINER_NAME} \
--root_dir ${ROOT}

#TRAINER_NAME='LNLTrainer'
#python main.py \
#--expt_type celebA_experiments \
#--trainer_name ${TRAINER_NAME} \
#--root_dir ${ROOT}

/scripts/biased_mnist/irmv1_biased_mnist.sh:
--------------------------------------------------------------------------------
#!/bin/bash
source common.sh
set -e
source activate bias_mitigator

TRAINER_NAME='IRMv1Trainer'
for grad_penalty_weight in 0.01; do
  python main.py \
  --lr 1e-3 \
  --weight_decay 1e-5 \
  --expt_type biased_mnist_experiments \
  --trainer_name ${TRAINER_NAME} \
  --grad_penalty_weight ${grad_penalty_weight} \
  --num_envs_per_batch 16 \
  --root_dir ${ROOT}
done

/scripts/biased_mnist/learning_from_failure_biased_mnist.sh:
--------------------------------------------------------------------------------
#!/bin/bash
source common.sh
set -e
source activate bias_mitigator

TRAINER_NAME='LffTrainer'

for bias_loss_gamma in 0.5; do
  python main.py \
  --expt_type biased_mnist_experiments \
  --lr 1e-3 \
  --weight_decay 1e-5 \
  --trainer_name ${TRAINER_NAME} \
  --optimizer_name Adam \
  --bias_loss_gamma ${bias_loss_gamma} \
  --root_dir ${ROOT}
done
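The spectral-decoupling scripts in this listing pass --spectral_decoupling_lambda
and --spectral_decoupling_gamma (for CelebA the per-class values live in
celebA_experiments.py). A sketch of the penalty those flags configure, assuming
SpectralDecouplingTrainer follows Pezeshki et al. (2021): cross-entropy plus an
L2 penalty on the logit of the true class, shifted by gamma:

    import torch
    import torch.nn.functional as F

    def spectral_decoupling_loss(logits, target, sd_lambda, sd_gamma):
        # L = CE + (lambda / 2) * (z_y - gamma)^2
        ce = F.cross_entropy(logits, target)
        z_y = logits.gather(1, target.view(-1, 1)).squeeze(1)
        return ce + 0.5 * sd_lambda * ((z_y - sd_gamma) ** 2).mean()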
/scripts/gqa-ood/upweighting_gqa.sh:
--------------------------------------------------------------------------------
#!/bin/bash
source common.sh
set -e
source activate bias_mitigator

#for key_to_group_by in head_tail answer global_group_name local_group_name; do
for key_to_group_by in head_tail; do
  TRAINER_NAME='GroupUpweightingTrainer'
  python main.py \
  --expt_type gqa_experiments \
  --trainer_name ${TRAINER_NAME} \
  --key_to_group_by ${key_to_group_by} \
  --lr 1e-3 \
  --weight_decay 0 \
  --root_dir ${ROOT}
done

/scripts/gqa-ood/rubi_gqa.sh:
--------------------------------------------------------------------------------
#!/bin/bash
source common.sh
set -e
source activate bias_mitigator

TRAINER_NAME='RUBiTrainer'
#for key_to_group_by in head_tail global_group_name local_group_name; do
for key_to_group_by in head_tail; do
  python main.py \
  --expt_type gqa_experiments \
  --lr 1e-4 \
  --weight_decay 0 \
  --trainer_name ${TRAINER_NAME} \
  --key_to_group_by ${key_to_group_by} \
  --bias_variable_name group_ix \
  --root_dir ${ROOT}
done

/scripts/gqa-ood/preprocess_gqa.sh:
--------------------------------------------------------------------------------
#!/bin/bash
source common.sh
set -e
source activate bias_mitigator

mkdir -p ${ROOT}/GQA/preprocessed/objects
mkdir -p ${ROOT}/GQA/preprocessed/spatial

python datasets/vqa/gqa_feat_preproc.py \
--mode object \
--object_dir ${ROOT}/GQA/objects \
--out_dir ${ROOT}/GQA/preprocessed/objects

python datasets/vqa/gqa_feat_preproc.py \
--mode spatial \
--spatial_dir ${ROOT}/GQA/spatial \
--out_dir ${ROOT}/GQA/preprocessed/spatial

/scripts/biased_mnist_dataset_generator/full_generator.sh:
--------------------------------------------------------------------------------
#!/bin/bash
source activate bias_mitigator
#python datasets/biased_mnist_generator.py \
#--config_file conf/biased_mnist_generator/full.yaml \
#--p_bias 0.9 \
#--suffix '_0.9' \
#--generate_test_set 1

for p_bias in 0.9 0.93 0.95 0.97 0.99 1.0; do
  python -u datasets/biased_mnist_generator.py \
  --config_file conf/biased_mnist_generator/full_v1.yaml \
  --p_bias ${p_bias} \
  --suffix _${p_bias} \
  --generate_test_set 0
done

/scripts/gqa-ood/group_dro_gqa.sh:
--------------------------------------------------------------------------------
#!/bin/bash
source common.sh
set -e
source activate bias_mitigator

TRAINER_NAME='GroupDROTrainer'
#for key_to_group_by in head_tail answer global_group_name local_group_name; do
for key_to_group_by in head_tail; do
  python main.py \
  --expt_type gqa_experiments \
  --lr 1e-4 \
  --weight_decay 0 \
  --group_weight_step_size 0.01 \
  --trainer_name ${TRAINER_NAME} \
  --key_to_group_by ${key_to_group_by} \
  --root_dir ${ROOT}
done
/scripts/biased_mnist/lnl_biased_mnist.sh:
--------------------------------------------------------------------------------
#!/bin/bash
source common.sh
set -e
source activate bias_mitigator

TRAINER_NAME='LNLTrainer'
for entropy_loss_weight in 0.01; do
  for grad_reverse_factor in -0.1; do
    python main.py \
    --expt_type biased_mnist_experiments \
    --lr 1e-3 \
    --weight_decay 1e-5 \
    --trainer_name ${TRAINER_NAME} \
    --root_dir ${ROOT} \
    --entropy_loss_weight ${entropy_loss_weight} \
    --grad_reverse_factor ${grad_reverse_factor}
  done
done

/scripts/biased_mnist/spectral_decoupling_biased_mnist.sh:
--------------------------------------------------------------------------------
#!/bin/bash
source common.sh
set -e
source activate bias_mitigator

TRAINER_NAME='SpectralDecouplingTrainer'

for sd_gamma in 1e-3; do
  for sd_lambda in 1e-3; do
    python main.py \
    --lr 1e-3 \
    --weight_decay 1e-5 \
    --expt_type biased_mnist_experiments \
    --trainer_name ${TRAINER_NAME} \
    --root_dir ${ROOT} \
    --spectral_decoupling_lambda ${sd_lambda} \
    --spectral_decoupling_gamma ${sd_gamma}
  done
done

/scripts/gqa-ood/irmv1_gqa.sh:
--------------------------------------------------------------------------------
#!/bin/bash
source common.sh
set -e
source activate bias_mitigator

TRAINER_NAME='IRMv1Trainer'

#for key_to_group_by in head_tail answer global_group_name local_group_name; do
for key_to_group_by in head_tail; do
  python main.py \
  --lr 1e-4 \
  --weight_decay 0 \
  --expt_type gqa_experiments \
  --key_to_group_by ${key_to_group_by} \
  --trainer_name ${TRAINER_NAME} \
  --grad_penalty_weight 0.01 \
  --num_envs_per_batch 16 \
  --root_dir ${ROOT}
done

/scripts/gqa-ood/lnl_gqa.sh:
--------------------------------------------------------------------------------
#!/bin/bash
source common.sh
set -e
source activate bias_mitigator

TRAINER_NAME='LNLTrainer'
#BIAS_VARIABLE_NAME=answer
#BIAS_VARIABLE_NAME=head_tail

for key_to_group_by in qtype_detailed; do
  EXPT_NAME=bias_variable_${key_to_group_by}
  python main.py \
  --expt_type gqa_experiments \
  --lr 1e-3 \
  --weight_decay 0 \
  --grad_reverse_factor -0.1 \
  --entropy_loss_weight 0.01 \
  --key_to_group_by ${key_to_group_by} \
  --bias_variable_name group_ix \
  --trainer_name ${TRAINER_NAME} \
  --expt_name ${EXPT_NAME} \
  --root_dir ${ROOT}
done

/trainers/__init__.py:
--------------------------------------------------------------------------------
from inspect import isclass
from pkgutil import iter_modules
from pathlib import Path
from importlib import import_module

# iterate through the modules in the current package
package_dir = Path(__file__).resolve().parent
for (_, module_name, _) in iter_modules([package_dir]):

    # import the module and iterate through its attributes
    module = import_module(f"{__name__}.{module_name}")
    for attribute_name in dir(module):
        attribute = getattr(module, attribute_name)

        if isclass(attribute):
            # Add the class to this package's variables
            globals()[attribute_name] = attribute
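The IRMv1 scripts above (irmv1_celebA.sh, irmv1_biased_mnist.sh, irmv1_gqa.sh)
configure --grad_penalty_weight and --num_envs_per_batch for IRMv1Trainer
(irm_v1_trainer.py is not reproduced here). A sketch of the penalty those flags
refer to, assuming the trainer follows IRMv1 as in Arjovsky et al. (2019): the
squared gradient of each environment's risk with respect to a fixed dummy
scale on the logits.

    import torch
    import torch.nn.functional as F

    def irmv1_penalty(logits, target):
        scale = torch.ones(1, requires_grad=True, device=logits.device)
        loss = F.cross_entropy(logits * scale, target)
        grad = torch.autograd.grad(loss, [scale], create_graph=True)[0]
        return (grad ** 2).sum()

    # Sketch of the full objective: the mean environment risk plus
    # grad_penalty_weight times the mean per-environment penalty.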
/experiments/__init__.py:
--------------------------------------------------------------------------------
from inspect import isclass
from pkgutil import iter_modules
from pathlib import Path
from importlib import import_module

# iterate through the modules in the current package
package_dir = Path(__file__).resolve().parent
for (_, module_name, _) in iter_modules([package_dir]):

    # import the module and iterate through its attributes
    module = import_module(f"{__name__}.{module_name}")
    for attribute_name in dir(module):
        attribute = getattr(module, attribute_name)

        if isclass(attribute):
            # Add the class to this package's variables
            globals()[attribute_name] = attribute

/scripts/celebA/learning_from_failure_celebA.sh:
--------------------------------------------------------------------------------
#!/bin/bash
source common.sh
set -e
source activate bias_mitigator

TRAINER_NAME='LffTrainer'

# Note: we used SGD for all other methods; however, LFF did not train well with SGD
# despite a hyperparameter search, so we use Adam here.
# Also, we were unable to replicate the original paper's results with bias_loss_gamma = 0.7.
# We tuned it and found that 0.1 worked best for our runs.

python main.py \
--expt_type celebA_experiments \
--lr 1e-4 \
--weight_decay 0 \
--trainer_name ${TRAINER_NAME} \
--optimizer_name Adam \
--bias_loss_gamma 0.1 \
--root_dir ${ROOT}

/models/vqa/ops/layer_norm.py:
--------------------------------------------------------------------------------
# --------------------------------------------------------
# OpenVQA
# Written by Yuhao Cui https://github.com/cuiyuhao1996
# --------------------------------------------------------

import torch.nn as nn
import torch


class LayerNorm(nn.Module):
    def __init__(self, size, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.eps = eps

        self.a_2 = nn.Parameter(torch.ones(size))
        self.b_2 = nn.Parameter(torch.zeros(size))

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)

        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2

/eval/vqa/plot_tail.py:
--------------------------------------------------------------------------------
import os

import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

sns.set()


def plot_tail_for_one_model(alpha, accuracy, model_name='default'):
    data = {'Tail size': alpha, model_name: accuracy}
    df = pd.DataFrame(data, dtype=float)
    df = pd.melt(df, ['Tail size'], var_name="Models", value_name="Accuracy")
    ax = sns.lineplot(x="Tail size", y="Accuracy", hue="Models", style="Models", data=df, markers=False, ci=None)
    plt.xscale('log')
    plt.ylim(0, 100)
    save_dir = 'figures'
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    plt.savefig('figures/tail_plot_%s.pdf' % model_name)
    plt.close()
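A usage sketch for plot_tail_for_one_model (the numbers are made up, purely to
show the expected call shape; alpha is the tail size and accuracy is in percent):

    from eval.vqa.plot_tail import plot_tail_for_one_model

    alphas = [9.0, 5.0, 1.0, 0.3, 0.0]
    accuracies = [61.2, 58.4, 52.0, 47.3, 45.1]  # hypothetical values
    plot_tail_for_one_model(alphas, accuracies, model_name='UpDn')
    # writes figures/tail_plot_UpDn.pdf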
/datasets/vqa/feat_filter.py:
--------------------------------------------------------------------------------
# --------------------------------------------------------
# OpenVQA
# Written by Yuhao Cui https://github.com/cuiyuhao1996
# --------------------------------------------------------


def feat_filter(dataset, frcn_feat, grid_feat, bbox_feat):
    feat_dict = {}

    if dataset in ['vqa']:
        feat_dict['FRCN_FEAT'] = frcn_feat
        feat_dict['BBOX_FEAT'] = bbox_feat

    elif dataset in ['gqa']:
        feat_dict['FRCN_FEAT'] = frcn_feat
        feat_dict['GRID_FEAT'] = grid_feat
        feat_dict['BBOX_FEAT'] = bbox_feat

    elif dataset in ['clevr']:
        feat_dict['GRID_FEAT'] = grid_feat

    else:
        raise ValueError("Unknown dataset: {}".format(dataset))

    return feat_dict

/utils/bias_retrievers.py:
--------------------------------------------------------------------------------
import torch


def build_bias_retriever(bias_variable_name):
    if bias_variable_name == 'color':
        return ColorRetriever()
    else:
        return VariableRetriever(bias_variable_name)


class ColorRetriever:

    def retrieve(self, batch):
        # Per-channel max intensity over all pixels, used as a proxy for color.
        x = batch['x']
        return x.view(x.size(0), x.size(1), -1).max(2)[0]

    def __call__(self, batch, main_out):
        return self.retrieve(batch)


class VariableRetriever:
    def __init__(self, var_name):
        self.var_name = var_name

    def __call__(self, batch, main_out):
        # Prefer the variable from the batch; fall back to the model output.
        if self.var_name in batch:
            ret = batch[self.var_name]
        else:
            ret = main_out[self.var_name]
        if isinstance(ret, list):
            ret = torch.FloatTensor(ret).cuda()
        return ret

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 Robik Shrestha, Kushal Kafle and Christopher Kanan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
/models/vqa/ops/fc.py:
--------------------------------------------------------------------------------
# --------------------------------------------------------
# OpenVQA
# Written by Yuhao Cui https://github.com/cuiyuhao1996
# --------------------------------------------------------

import torch.nn as nn
import torch


class FC(nn.Module):
    def __init__(self, in_size, out_size, dropout_r=0., use_relu=True):
        super(FC, self).__init__()
        self.dropout_r = dropout_r
        self.use_relu = use_relu

        self.linear = nn.Linear(in_size, out_size)

        if use_relu:
            self.relu = nn.ReLU(inplace=True)

        if dropout_r > 0:
            self.dropout = nn.Dropout(dropout_r)

    def forward(self, x):
        x = self.linear(x)

        if self.use_relu:
            x = self.relu(x)

        if self.dropout_r > 0:
            x = self.dropout(x)

        return x


class MLP(nn.Module):
    def __init__(self, in_size, mid_size, out_size, dropout_r=0., use_relu=True):
        super(MLP, self).__init__()

        self.fc = FC(in_size, mid_size, dropout_r=dropout_r, use_relu=use_relu)
        self.linear = nn.Linear(mid_size, out_size)

    def forward(self, x):
        return self.linear(self.fc(x))

/models/model_factory.py:
--------------------------------------------------------------------------------
from models.fc_models import *
from models.cnn_models import *
from models.vqa.updn.net import *
from models.vqa.mcan.net import MCAN
# from models.vqa.ban.ban import BAN, BANNoDropout


def build_model(option,
                model_name,
                in_dims=None,
                hid_dims=None,
                out_dims=None,
                freeze_layers=None,
                dropout=None
                ):
    if 'updn' in model_name.lower() or 'mcan' in model_name.lower() or 'ban' in model_name.lower():
        m = eval(model_name)(option.dataset_info.pretrained_emb,
                             option.dataset_info.token_size,
                             option.dataset_info.ans_size)
    else:
        if in_dims is None and hid_dims is None:
            m = eval(model_name)(num_classes=out_dims)
        elif hid_dims is None:
            m = eval(model_name)(in_dims=in_dims, num_classes=out_dims)
        elif dropout is None:
            m = eval(model_name)(in_dims=in_dims, hid_dims=hid_dims, num_classes=out_dims)
        else:
            m = eval(model_name)(in_dims=in_dims, hid_dims=hid_dims, num_classes=out_dims, dropout=dropout)
    if freeze_layers is not None:
        m.freeze_layers(freeze_layers)
    return m
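A sketch of the two dispatch paths in build_model. VQA models ('updn', 'mcan',
'ban' in the name) are constructed from option.dataset_info; anything else is
resolved by name and given the dims explicitly. 'MLP3' here is assumed to be
defined in models/fc_models.py, which is not reproduced in this listing:

    from argparse import Namespace
    from models.model_factory import build_model

    option = Namespace(dataset_info=None)  # dataset_info only matters for the VQA path
    model = build_model(option, 'MLP3', in_dims=1024, hid_dims=1024, out_dims=10)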
/models/vqa/vqa_adapter.py:
--------------------------------------------------------------------------------
# --------------------------------------------------------
# Adapted from OpenVQA
# Written by Zhenwei Shao https://github.com/ParadoxZW
# --------------------------------------------------------

import torch.nn as nn
import torch
from models.vqa.make_mask import make_mask


class VQAAdapter(nn.Module):
    def __init__(self,
                 frcn_feat_size=(100, 2048),
                 bbox_feat_size=(100, 5),
                 bbox_feat_emb_size=1024,
                 use_bbox_feats=True,
                 hidden_size=1024
                 ):
        super(VQAAdapter, self).__init__()
        self.frcn_feat_size = frcn_feat_size
        self.bbox_feat_size = bbox_feat_size
        self.use_bbox_feats = use_bbox_feats
        self.hidden_size = hidden_size
        # frcn_feat_size is (num_objects, feat_dim); only the feature dim feeds the linear layer
        in_size = frcn_feat_size[1]
        if self.use_bbox_feats:
            self.bbox_linear = nn.Linear(5, bbox_feat_emb_size)
            in_size = frcn_feat_size[1] + bbox_feat_emb_size
        self.frcn_linear = nn.Linear(in_size, hidden_size)

    def forward(self, frcn_feat, bbox_feat):

        img_feat_mask = make_mask(frcn_feat)

        if self.use_bbox_feats:
            bbox_feat = self.bbox_linear(bbox_feat)
            frcn_feat = torch.cat((frcn_feat, bbox_feat), dim=-1)
        img_feat = self.frcn_linear(frcn_feat)

        return img_feat, img_feat_mask
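A shape sketch for VQAAdapter with its defaults (100 boxes, 2048-d FRCN
features, 5-d box coordinates):

    import torch
    from models.vqa.vqa_adapter import VQAAdapter

    adapter = VQAAdapter()
    frcn_feat = torch.randn(8, 100, 2048)
    bbox_feat = torch.randn(8, 100, 5)
    img_feat, img_feat_mask = adapter(frcn_feat, bbox_feat)
    # img_feat: [8, 100, 1024]; img_feat_mask: [8, 1, 1, 100],
    # True wherever the FRCN feature vector is all zeros (padded objects).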
/scripts/gqa-ood/hyperparam_search_gqa.sh:
--------------------------------------------------------------------------------
#!/bin/bash
source common.sh
set -e
source activate bias_mitigator
PROJECT_NAME='GQA'

TRAINER_NAME='BaseTrainer'
for lr in 1e-2; do
  python main.py \
  --expt_type gqa_experiments \
  --project_name ${PROJECT_NAME} \
  --trainer_name ${TRAINER_NAME} \
  --lr ${lr} \
  --weight_decay 0 \
  --root_dir ${ROOT}
done

#TRAINER_NAME='RUBiTrainer'
#for lr in 1e-2; do
#  for wd in 0 1e-3 0.1; do
#    python main.py \
#    --expt_type gqa_experiments \
#    --project_name ${PROJECT_NAME} \
#    --trainer_name ${TRAINER_NAME} \
#    --lr ${lr} \
#    --weight_decay ${wd} \
#    --root_dir ${ROOT}
#  done
#done

#TRAINER_NAME='GroupDROTrainer'
#for lr in 1e-2; do
#  for wd in 0 1e-3 0.1; do
#    python main.py \
#    --expt_type gqa_experiments \
#    --project_name ${PROJECT_NAME} \
#    --trainer_name ${TRAINER_NAME} \
#    --lr ${lr} \
#    --weight_decay ${wd} \
#    --root_dir ${ROOT}
#  done
#done

#TRAINER_NAME='LffTrainer'
#for lr in 1e-2; do
#  for wd in 0 1e-3 0.1; do
#    python main.py \
#    --expt_type gqa_experiments \
#    --project_name ${PROJECT_NAME} \
#    --trainer_name ${TRAINER_NAME} \
#    --lr ${lr} \
#    --weight_decay ${wd} \
#    --root_dir ${ROOT}
#  done
#done

/utils/losses.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class GCELoss(nn.Module):
    def __init__(self, q=0.7, reduction='none'):
        super(GCELoss, self).__init__()
        self.q = q
        self.reduction = reduction

    def __call__(self, input, target):
        p = F.softmax(input, dim=1)
        Yg = torch.gather(p, 1, torch.unsqueeze(target, 1))
        loss_weight = (Yg.squeeze().detach() ** self.q) * self.q
        loss = F.cross_entropy(input, target, reduction=self.reduction) * loss_weight
        return loss


class ExpLoss(nn.Module):
    def __init__(self, reduction='none'):
        super(ExpLoss, self).__init__()
        self.reduction = reduction

    def __call__(self, input, target):
        return torch.exp(torch.gather(1 - F.softmax(input, dim=1), dim=1, index=target.view(-1, 1)))


class InverseProbabilityLoss(nn.Module):
    def __init__(self, reduction='none'):
        super(InverseProbabilityLoss, self).__init__()
        self.reduction = reduction

    def __call__(self, input, target):
        return 1 / torch.gather(1 - F.softmax(input, dim=1), dim=1, index=target.view(-1, 1))
        # return torch.exp(torch.gather(1 - F.softmax(input, dim=1), dim=1, index=target.view(-1, 1)))


if __name__ == "__main__":
    gce_loss = GCELoss(q=2)
    l = gce_loss(torch.FloatTensor([[0.1, 0.9], [0.1, 0.8]]), torch.LongTensor([0, 1]))
    print(l)
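A small demonstration of what the GCE weight does: the per-sample cross-entropy
is scaled by the (detached) probability of the correct class raised to q, so
confidently classified "easy" samples keep most of their gradient while hard
samples are downweighted. This is how an intentionally biased model can be
encouraged to latch onto shortcuts:

    import torch
    from utils.losses import GCELoss

    logits = torch.tensor([[3.0, 0.0], [0.2, 0.0]])  # a confident and an uncertain sample
    targets = torch.tensor([0, 0])
    # The confident sample keeps a much larger share of its CE loss.
    print(GCELoss(q=0.7)(logits, targets))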
16 | """ 17 | if exponent is None: 18 | exponent = int(np.floor(np.log10(abs(num)))) 19 | coeff = round(num / float(10 ** exponent), decimal_digits) 20 | if precision is None: 21 | precision = decimal_digits 22 | 23 | return r"${0:.{2}f}\times 10^{{{1:d}}}$".format(coeff, exponent, precision) 24 | 25 | 26 | # def convert_to_exponent_of_ten(v): 27 | # if v == 0: 28 | # return v 29 | # else: 30 | # try: 31 | # return '$10^{' + str(int(np.log10(v))) + '}$' 32 | # # return '{:.1e}'.format(v) 33 | # except: 34 | # return v 35 | 36 | def format_matplotlib(text_fontsize=10, default_fontsize=14, legend_fontsize=10): 37 | sns.set(style='darkgrid') 38 | # plt.style.use('tableau-colorblind10') 39 | font = { 40 | 'family': 'normal', 41 | 'weight': 'bold' 42 | } 43 | matplotlib.rc('font', **font) 44 | params = {'axes.labelsize': default_fontsize, 'axes.titlesize': default_fontsize, 'font.size': text_fontsize, 45 | 'legend.fontsize': legend_fontsize, 46 | 'xtick.labelsize': default_fontsize, 'ytick.labelsize': default_fontsize} 47 | matplotlib.rcParams.update(params) 48 | 49 | 50 | format_matplotlib() 51 | -------------------------------------------------------------------------------- /datasets/vqa/base_vqa_dataset.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # OpenVQA 3 | # Written by Yuhao Cui https://github.com/cuiyuhao1996 4 | # -------------------------------------------------------- 5 | 6 | import numpy as np 7 | import glob, json, torch, random 8 | import torch.utils.data as Data 9 | import torch.nn as nn 10 | from datasets.vqa.feat_filter import feat_filter 11 | import logging 12 | 13 | 14 | class BaseVQADataset(Data.Dataset): 15 | def __init__(self, load_visual_feats=True): 16 | self.token_to_ix = None 17 | self.pretrained_emb = None 18 | self.ans_to_ix = None 19 | self.ix_to_ans = None 20 | 21 | self.data_size = None 22 | self.token_size = None 23 | self.ans_size = None 24 | self.load_visual_feats = load_visual_feats 25 | 26 | def load_ques_ans(self, idx): 27 | raise NotImplementedError() 28 | 29 | def load_img_feats(self, idx, iid): 30 | raise NotImplementedError() 31 | 32 | def __getitem__(self, idx): 33 | vqa = self.load_ques_ans(idx) 34 | if self.load_visual_feats: 35 | frcn_feat_iter, bbox_feat_iter = self.load_img_feats(idx, vqa['image_id']) 36 | vqa['frcn_feat'] = torch.from_numpy(frcn_feat_iter) 37 | vqa['bbox_feat'] = torch.from_numpy(bbox_feat_iter) 38 | vqa['dataset_ix'] = idx 39 | vqa['question_token_ixs'] = torch.from_numpy(vqa['question_token_ixs']) 40 | vqa['y'] = torch.from_numpy(vqa['ans_iter']) 41 | # Adding 'group_name' suffix 42 | vqa['answer_group_name'] = vqa['answer'] 43 | return vqa 44 | # return { 45 | # 'frcn_feat': torch.from_numpy(frcn_feat_iter), 46 | # 'bbox_feat': torch.from_numpy(bbox_feat_iter), 47 | # 'question_id': q_details['question_id'], 48 | # 'ques_ix_iter': torch.from_numpy(q_details['ques_ix_iter']), 49 | # 'answer': q_details['ans'], 50 | # 'y': torch.from_numpy(q_details['ans_iter']), 51 | # 'dataset_ix': idx, 52 | # 'local_group_name': q_details['local_grp_name'], 53 | # 'group_name': q_details['group_name'], 54 | # 'group_ix': q_details['group_ix'] 55 | # } 56 | 57 | def __len__(self): 58 | return self.data_size 59 | 60 | def shuffle_list(self, list): 61 | random.shuffle(list) 62 | -------------------------------------------------------------------------------- /main.py: 
/main.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import random
import torch
from option import get_option
from trainers import trainer_factory
from utils.trainer_utils import save_option, initialize_logger
import logging
from datasets import dataloader_factory
import json
from experiments.celebA_experiments import *
from experiments.biased_mnist_experiments import *
from experiments.gqa_experiments import *


def backend_setting(option):
    # Initialize the expt_dir where all the results (predictions, checkpoints, logs, metrics) will be saved
    if option.expt_dir is None:
        option.expt_dir = os.path.join(option.save_dir, option.expt_name)

    if not os.path.exists(option.expt_dir):
        os.makedirs(option.expt_dir)

    # Configure the logger
    initialize_logger(option.expt_dir)

    # Set the random seeds
    if option.random_seed is None:
        option.random_seed = random.randint(1, 10000)
    random.seed(option.random_seed)
    torch.manual_seed(option.random_seed)
    torch.cuda.manual_seed_all(option.random_seed)
    np.random.seed(option.random_seed)

    if torch.cuda.is_available() and not option.cuda:
        logging.warning('GPU is available, but we are not using it!!!')

    if not torch.cuda.is_available() and option.cuda:
        option.cuda = False

    # Dataset specific settings
    set_if_null(option, 'bias_loss_gamma', 0.7)
    set_if_null(option, 'bias_ema_gamma', 0.7)


def set_if_null(option, attr_name, val):
    if not hasattr(option, attr_name) or getattr(option, attr_name) is None:
        setattr(option, attr_name, val)


def main():
    option = get_option()
    if option.project_name is None:
        option.project_name = option.dataset_name
    if option.expt_type is not None:
        eval(option.expt_type)(option, run)
    else:
        run(option)


def run(option):
    backend_setting(option)
    data_loaders = dataloader_factory.build_dataloaders(option)
    if 'gqa' in option.dataset_name.lower():
        option.bias_variable_dims = option.num_groups
        option.num_bias_classes = option.num_groups

    save_option(option)
    logging.getLogger().info(json.dumps(option.__dict__, indent=4, sort_keys=True,
                                        default=lambda o: f"<>"))

    trainer = trainer_factory.build_trainer(option)
    trainer.train(data_loaders['Train'], data_loaders['Test'], data_loaders['Unbalanced Train'])


if __name__ == '__main__':
    main()
/models/vqa/ban/ban.py:
--------------------------------------------------------------------------------
# --------------------------------------------------------
# OpenVQA
# Written by Zhenwei Shao https://github.com/ParadoxZW
# --------------------------------------------------------

from models.vqa.make_mask import make_mask
from models.vqa.ops.fc import FC, MLP
from models.vqa.ops.layer_norm import LayerNorm
from models.vqa.ban._ban import _BAN
from models.vqa.vqa_adapter import VQAAdapter

import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.weight_norm import weight_norm
import torch


# -------------------------
# ---- Main BAN Model ----
# -------------------------

class BAN(nn.Module):
    def __init__(self,
                 pretrained_emb, token_size, answer_size,
                 word_embed_size=300,
                 img_feat_size=1024,
                 hidden_size=1024,
                 k_times=3,
                 dropout_r=0.2,
                 classifier_dropout_r=0.5,
                 glimpse=8,
                 flat_out_size=2048,
                 use_glove=True):
        super(BAN, self).__init__()

        self.embedding = nn.Embedding(
            num_embeddings=token_size,
            embedding_dim=word_embed_size
        )

        # Loading the GloVe embedding weights
        if use_glove:
            self.embedding.weight.data.copy_(torch.from_numpy(pretrained_emb))

        self.rnn = nn.GRU(
            input_size=word_embed_size,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True
        )

        self.adapter = VQAAdapter(hidden_size=hidden_size)

        self.backbone = _BAN(img_feat_size,
                             hidden_size,
                             k_times,
                             dropout_r,
                             classifier_dropout_r,
                             glimpse)

        # Classification layers
        layers = [
            weight_norm(nn.Linear(hidden_size, flat_out_size), dim=None),
            nn.ReLU(),
            nn.Dropout(classifier_dropout_r, inplace=True),
            weight_norm(nn.Linear(flat_out_size, answer_size), dim=None)
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, frcn_feat, bbox_feat, ques_ix):
        # Pre-process Language Feature
        # lang_feat_mask = make_mask(ques_ix.unsqueeze(2))
        lang_feat = self.embedding(ques_ix)
        lang_feat, _ = self.rnn(lang_feat)
        img_feat, _ = self.adapter(frcn_feat, bbox_feat)

        # Backbone Framework
        lang_feat = self.backbone(
            lang_feat,
            img_feat
        )

        # Classification layers
        proj_feat = self.classifier(lang_feat.sum(1))
        return {
            'question_features': lang_feat,
            'logits': proj_feat
        }


class BANNoDropout(BAN):
    def __init__(self, pretrained_emb, token_size, answer_size):
        super().__init__(pretrained_emb, token_size, answer_size,
                         dropout_r=0, classifier_dropout_r=0)
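A shape sketch for BAN (the sizes are illustrative; real values come from
option.dataset_info, and the backbone _BAN from _ban.py is not reproduced in
this listing):

    import numpy as np
    import torch
    from models.vqa.ban.ban import BAN

    token_size, answer_size = 1000, 500
    pretrained_emb = np.random.randn(token_size, 300).astype(np.float32)
    model = BAN(pretrained_emb, token_size, answer_size)

    out = model(torch.randn(2, 100, 2048),              # frcn_feat
                torch.randn(2, 100, 5),                 # bbox_feat
                torch.randint(0, token_size, (2, 14)))  # ques_ix
    # out['logits']: [2, 500]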
/datasets/dataloader_factory.py:
--------------------------------------------------------------------------------
import logging
import torch
from torch.utils.data import DataLoader
import json
from datasets.biased_mnist_dataset import create_biased_mnist_dataloaders
from datasets.celebA_dataset import create_celebA_dataloaders
from datasets.vqa.gqa_dataset import create_gqa_dataloaders
from utils.data_utils import dict_collate_fn


def build_balanced_loader(dataloader, balanced_sampling_attributes=['y'], balanced_sampling_gamma=1, replacement=True):
    logger = logging.getLogger()
    all_group_names = []

    # Count frequencies for all groups of attributes to balance,
    # and assign each sample to a group, so that we can compute its sampling weight later on
    group_name_to_count = {}
    for batch in dataloader:
        batch_group_names = []
        for ix, _ in enumerate(batch['y']):
            group_name = ""
            for attr in balanced_sampling_attributes:
                group_name += f"{attr}_{batch[attr][ix]}_"
            batch_group_names.append(group_name)

        for group_name in batch_group_names:
            if group_name not in group_name_to_count:
                group_name_to_count[group_name] = 0
            group_name_to_count[group_name] += 1
            all_group_names.append(group_name)

    # Create the balanced loader
    weights = []
    for val in all_group_names:
        weights.append(1 / group_name_to_count[val] ** balanced_sampling_gamma)
    weighted_sampler = torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights),
                                                              replacement=replacement)
    balanced_dataloader = DataLoader(dataloader.dataset, batch_size=dataloader.batch_size, sampler=weighted_sampler,
                                     num_workers=dataloader.num_workers, collate_fn=dataloader.collate_fn)
    logger.info(f"Created balanced loader for {len(weights)} samples of dataset size {len(dataloader.dataset)}")
    logger.info(f"Group counts: {json.dumps(group_name_to_count, indent=4)}")
    return balanced_dataloader


def build_dataloaders(option):
    dataset_name = option.dataset_name.lower()
    if dataset_name == 'biased_mnist_v1':
        loaders = create_biased_mnist_dataloaders(option)  # Sets the num_groups
    elif dataset_name == 'celeba':
        loaders = create_celebA_dataloaders(option)
    elif dataset_name == 'gqa':
        loaders = create_gqa_dataloaders(option)
    loaders['Unbalanced Train'] = loaders['Train']
    if option.balanced_sampling_attributes is not None:
        unshuffled_train_loader = DataLoader(loaders['Train'].dataset, batch_size=option.batch_size, shuffle=False,
                                             num_workers=option.num_workers,
                                             collate_fn=dict_collate_fn())
        loaders['Train'] = build_balanced_loader(unshuffled_train_loader,
                                                 option.balanced_sampling_attributes,
                                                 balanced_sampling_gamma=option.balanced_sampling_gamma,
                                                 replacement=True)
    return loaders
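The weighting in build_balanced_loader is inverse group frequency raised to
balanced_sampling_gamma. A minimal standalone sketch of the same arithmetic,
with hypothetical group counts:

    from torch.utils.data import WeightedRandomSampler

    counts = {'y_0_': 900, 'y_1_': 100}  # group frequencies, as counted above
    gamma = 1.0                          # balanced_sampling_gamma
    names = ['y_0_'] * 900 + ['y_1_'] * 100
    weights = [1 / counts[n] ** gamma for n in names]
    sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
    # With gamma=1 both groups are drawn equally often in expectation;
    # gamma=0 recovers the original distribution, and 0 < gamma < 1 interpolates.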
/eval/vqa/gqa_eval_from_file.py:
--------------------------------------------------------------------------------
from eval.vqa.gqa_eval import GQAEval
from eval.vqa.plot_tail import plot_tail_for_one_model
import argparse
import numpy as np
import os.path
import glob
import json
from option import ROOT

parser = argparse.ArgumentParser()
parser.add_argument('--data_root', default=f'{ROOT}/GQA')
parser.add_argument('--eval_tail_size', action='store_true')
parser.add_argument('--ood_test', action='store_true')
parser.add_argument('--predictions', type=str)
args = parser.parse_args()


def loadFile(name):
    # load standard json file
    if os.path.isfile(name):
        with open(name) as file:
            data = json.load(file)
    # load file chunks if too big
    elif os.path.isdir(name.split(".")[0]):
        data = {}
        chunks = glob.glob('{dir}/{dir}_*.{ext}'.format(dir=name.split(".")[0], ext=name.split(".")[1]))
        for chunk in chunks:
            with open(chunk) as file:
                data.update(json.load(file))
    else:
        raise Exception("Can't find {}".format(name))
    return data


if args.eval_tail_size:
    result_eval_file = args.predictions

    # Retrieve scores
    alpha_list = [9.0, 7.0, 5.0, 3.6, 2.8, 2.2, 1.8, 1.4, 1.0, 0.8, 0.4, 0.3, 0.2, 0.1, 0.0, -0.1, -0.2, -0.3, -0.4,
                  -0.5, -0.6, -0.7]
    acc_list = []
    for alpha in alpha_list:
        ques_file_path = f'{args.data_root}/val_bal_tail_%.1f.json' % alpha
        predictions = loadFile(result_eval_file)
        # predictions = {p["questionId"]: p["prediction"] for p in predictions}

        gqa_eval = GQAEval(predictions, ques_file_path, choices_path=None, EVAL_CONSISTENCY=False)
        acc = gqa_eval.get_acc_result()['accuracy']
        acc_list.append(acc)

    print("Alpha:", alpha_list)
    print("Accuracy:", acc_list)
    # Plot: save to "tail_plot_[model_name].pdf"
    # plot_tail(alpha=list(map(lambda x: x + 1, alpha_list)), accuracy=acc_list,
    #           model_name='default')  # We plot 1+alpha vs. accuracy
elif args.ood_test:
    result_eval_file = args.predictions
    file_list = {'Tail': 'ood_testdev_tail.json', 'Head': 'ood_testdev_head.json', 'All': 'ood_testdev_all.json'}
    result = {}
    for setup, ques_file_path in file_list.items():
        predictions = loadFile(result_eval_file)
        # predictions = {p["questionId"]: p["prediction"] for p in predictions}

        gqa_eval = GQAEval(predictions, f'{args.data_root}/' + ques_file_path, choices_path=None,
                           EVAL_CONSISTENCY=False)
        result[setup] = gqa_eval.get_acc_result()['accuracy']

        result_string, detail_result_string = gqa_eval.get_str_result()
        print('\n___%s___' % setup)
        for result_string_ in result_string:
            print(result_string_)

    print('\nRESULTS:\n')
    msg = 'Accuracy (tail, head, all): %.2f, %.2f, %.2f' % (result['Tail'], result['Head'], result['All'])
    print(msg)

# Sample command:
# python eval/vqa/gqa_eval_from_file.py --predictions "{EXPT_ROOT}/gqa_1.1/SpectralDecouplingTrainer/lambda_0.001_gamma_0.001/ans_preds_Val All_Main_epoch_30.json" --eval_tail_size
29 | if option.train_ratio is not None: 30 | dataset_name = f"{option.dataset_name}_ratio_{option.train_ratio}" 31 | else: 32 | dataset_name = option.dataset_name 33 | 34 | # Configure the name of the experiment and the directory to save the results 35 | if option.save_dir is None: 36 | option.save_dir = os.path.join(option.root_dir, option.project_name, dataset_name, option.trainer_name) 37 | 38 | if option.expt_name is None: 39 | option.expt_name = f"lr_{option.lr}_wd_{option.weight_decay}" 40 | 41 | if option.key_to_group_by is not None: 42 | option.expt_name += f'_expl_bias_{option.key_to_group_by}' 43 | 44 | # Method-specific configurations 45 | if option.trainer_name == 'GroupDROTrainer': 46 | option.balanced_sampling_attributes = ['group_ix'] # Perform balanced sampling 47 | option.group_by = 'group_ix' # Groups for GDRO 48 | option.key_to_group_by = 'group_name' # Used to find readable group names 49 | 50 | if option.trainer_name == 'RUBiTrainer': 51 | set_if_null(option, 'bias_model_hid_dims', 2048) 52 | # We experimented with MLP1 and MLP2 too; MLP3 is the default used for GQA 53 | set_if_null(option, 'bias_model_name', 'MLP3') 54 | set_if_null(option, 'bias_variable_type', 'categorical') 55 | 56 | if option.trainer_name == 'LNLTrainer': 57 | # We experimented with MLP1 and MLP2 too; MLP3 is the default used for GQA 58 | set_if_null(option, 'bias_predictor_name', 'MLP3') 59 | set_if_null(option, 'bias_predictor_in_layer', 'question_features') 60 | option.bias_predictor_in_dims = 1024 61 | option.bias_predictor_hid_dims = 1024 62 | 63 | if option.trainer_name == 'IRMv1Trainer': 64 | set_if_null(option, 'num_envs_per_batch', 16) 65 | if option.bias_variable_name != 'question_features': 66 | option.bias_variable_type = 'categorical' 67 | 68 | # Test epochs 69 | option.test_every = 15 70 | option.save_every = 30 71 | option.save_model_every = 30 72 | 73 | if option.model_name == 'MCAN': 74 | option.bias_variable_dims = 2048 75 | else: 76 | option.bias_variable_dims = 1024 77 | 78 | run(option) 79 | -------------------------------------------------------------------------------- /models/vqa/updn/tda.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # OpenVQA 3 | # Written by Zhenwei Shao https://github.com/ParadoxZW 4 | # based on the implementation in https://github.com/hengyuan-hu/bottom-up-attention-vqa 5 | # ELU is chosen as the activation function in non-linear layers due to 6 | # experimental results indicating that ELU works better than ReLU in the BUTD model.
7 | # -------------------------------------------------------- 8 | 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.nn.utils.weight_norm import weight_norm 12 | import torch 13 | import math 14 | 15 | 16 | # ------------------------------ 17 | # ----- Weight Normal MLP ------ 18 | # ------------------------------ 19 | 20 | class MLP(nn.Module): 21 | """ 22 | class for non-linear fully connect network 23 | """ 24 | 25 | def __init__(self, dims, act='ELU', dropout_r=0.0): 26 | super(MLP, self).__init__() 27 | 28 | layers = [] 29 | for i in range(len(dims) - 1): 30 | in_dim = dims[i] 31 | out_dim = dims[i + 1] 32 | if dropout_r > 0: 33 | layers.append(nn.Dropout(dropout_r)) 34 | layers.append(weight_norm(nn.Linear(in_dim, out_dim), dim=None)) 35 | if act != '': 36 | layers.append(getattr(nn, act)()) 37 | 38 | self.mlp = nn.Sequential(*layers) 39 | 40 | def forward(self, x): 41 | return self.mlp(x) 42 | 43 | 44 | # ------------------------------ 45 | # ---Top Down Attention Map ---- 46 | # ------------------------------ 47 | 48 | 49 | class AttnMap(nn.Module): 50 | ''' 51 | implementation of top down attention 52 | ''' 53 | 54 | def __init__(self, img_feat_size, hidden_size, dropout_r): 55 | super(AttnMap, self).__init__() 56 | self.linear_q = weight_norm( 57 | nn.Linear(hidden_size, hidden_size), dim=None) 58 | self.linear_v = weight_norm( 59 | nn.Linear(img_feat_size, img_feat_size), dim=None) 60 | self.nonlinear = MLP( 61 | [img_feat_size + hidden_size, hidden_size], dropout_r=dropout_r) 62 | self.linear = weight_norm(nn.Linear(hidden_size, 1), dim=None) 63 | 64 | def forward(self, q, v): 65 | v = self.linear_v(v) 66 | q = self.linear_q(q) 67 | logits = self.logits(q, v) 68 | w = nn.functional.softmax(logits, 1) 69 | return w 70 | 71 | def logits(self, q, v): 72 | num_objs = v.size(1) 73 | q = q.unsqueeze(1).repeat(1, num_objs, 1) 74 | vq = torch.cat((v, q), 2) 75 | joint_repr = self.nonlinear(vq) 76 | logits = self.linear(joint_repr) 77 | return logits 78 | 79 | 80 | # ------------------------------ 81 | # ---- Attended Joint Map ------ 82 | # ------------------------------ 83 | 84 | 85 | class TDA(nn.Module): 86 | def __init__(self, img_feat_size, hidden_size, dropout_r): 87 | super(TDA, self).__init__() 88 | 89 | self.v_att = AttnMap(img_feat_size, hidden_size, dropout_r) 90 | self.q_net = MLP([hidden_size, hidden_size]) 91 | self.v_net = MLP([img_feat_size, hidden_size]) 92 | 93 | def forward(self, q, v): 94 | att = self.v_att(q, v) 95 | atted_v = (att * v).sum(1) 96 | q_repr = self.q_net(q) 97 | v_repr = self.v_net(atted_v) 98 | joint_repr = q_repr * v_repr 99 | return joint_repr 100 | -------------------------------------------------------------------------------- /utils/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class ClasswiseEMA: 6 | 7 | def __init__(self, dataset_size, alpha=0.7): 8 | self.dataset_size = dataset_size 9 | self.alpha = alpha 10 | self.labels = torch.ones(dataset_size).long().cuda() * -1000 11 | self.parameter = torch.zeros(dataset_size).cuda() 12 | self.updated = torch.zeros(dataset_size).cuda() 13 | 14 | def update(self, data, dataset_ix, label): 15 | # if dataset_ix.__class__ == torch.Tensor: 16 | # dataset_ix = dataset_ix.cuda() 17 | param = self.alpha * self.parameter[dataset_ix] + ( 18 | 1 - self.alpha * self.updated[dataset_ix]) * data 19 | self.updated[dataset_ix] = 1 20 | self.labels[dataset_ix] = label 21 | 
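        # Note on the blend above: on a sample's first visit, updated was 0, so the coefficient
        # (1 - alpha * updated) is 1 and the raw value is stored; on later visits the update
        # becomes the usual EMA, alpha * old + (1 - alpha) * new.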
self.parameter[dataset_ix] = param.detach() 22 | return param 23 | 24 | def max_loss(self, label): 25 | label_index = torch.where(self.labels == label)[0] 26 | if len(label_index) == 0: 27 | return 1 28 | else: 29 | return self.parameter[label_index].max() 30 | 31 | 32 | class EMA: 33 | def __init__(self, dataset_size, alpha=0.9): 34 | self.dataset_size = dataset_size 35 | self.alpha = alpha 36 | self.parameter = torch.zeros(dataset_size).cuda() 37 | self.updated = torch.zeros(dataset_size).cuda() 38 | 39 | def update(self, data, dataset_ix): 40 | param = self.alpha * self.parameter[dataset_ix] + ( 41 | 1 - self.alpha * self.updated[dataset_ix]) * data 42 | self.updated[dataset_ix] = 1 43 | self.parameter[dataset_ix] = param.detach() 44 | return param 45 | 46 | 47 | class WeightsEMA: 48 | """Exponential moving average of model parameters. 49 | https://anmoljoshi.com/Pytorch-Dicussions/ 50 | Args: 51 | model (torch.nn.Module): Model with parameters whose EMA will be kept. 52 | decay (float): Decay rate for exponential moving average. 53 | """ 54 | 55 | def __init__(self, model, decay=0.999): 56 | self.decay = decay 57 | self.shadow = {} 58 | self.original = {} 59 | 60 | # Register model parameters 61 | for name, param in model.named_parameters(): 62 | if param.requires_grad: 63 | self.shadow[name] = param.data.clone() 64 | 65 | def __call__(self, model, num_updates): 66 | decay = min(self.decay, (1.0 + num_updates) / (10.0 + num_updates)) 67 | for name, param in model.named_parameters(): 68 | if param.requires_grad: 69 | assert name in self.shadow 70 | new_average = \ 71 | (1.0 - decay) * param.data + decay * self.shadow[name] 72 | self.shadow[name] = new_average.clone() 73 | 74 | def assign(self, model): 75 | """Assign exponential moving average of parameter values to the 76 | respective parameters. 77 | Args: 78 | model (torch.nn.Module): Model to assign parameter values. 79 | """ 80 | for name, param in model.named_parameters(): 81 | if param.requires_grad: 82 | assert name in self.shadow 83 | self.original[name] = param.data.clone() 84 | param.data = self.shadow[name] 85 | 86 | def resume(self, model): 87 | """Restore original parameters to a model. That is, put back 88 | the values that were in each parameter at the last call to `assign`. 89 | Args: 90 | model (torch.nn.Module): Model to assign parameter values. 
91 | """ 92 | for name, param in model.named_parameters(): 93 | if param.requires_grad: 94 | assert name in self.shadow 95 | param.data = self.original[name] 96 | -------------------------------------------------------------------------------- /models/vqa/updn/net.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # OpenVQA 3 | # Written by Zhenwei Shao https://github.com/ParadoxZW 4 | # -------------------------------------------------------- 5 | 6 | from models.vqa.updn.tda import TDA 7 | from models.vqa.vqa_adapter import VQAAdapter 8 | 9 | import torch.nn as nn 10 | from torch.nn.utils.weight_norm import weight_norm 11 | import torch 12 | 13 | 14 | # ------------------------- 15 | # ---- Main BUTD Model ---- 16 | # ------------------------- 17 | 18 | class UpDn(nn.Module): 19 | def __init__(self, 20 | pretrained_emb, 21 | token_size, 22 | answer_size, 23 | img_feat_size=1024, 24 | word_embed_size=300, 25 | hidden_size=1024, 26 | use_glove=True, 27 | flat_out_size=2048, 28 | dropout_r=0.2, 29 | classifier_dropout_r=0.5 30 | ): 31 | super(UpDn, self).__init__() 32 | 33 | self.embedding = nn.Embedding( 34 | num_embeddings=token_size, 35 | embedding_dim=word_embed_size 36 | ) 37 | 38 | # Loading the GloVe embedding weights 39 | if use_glove: 40 | self.embedding.weight.data.copy_(torch.from_numpy(pretrained_emb)) 41 | 42 | self.rnn = nn.LSTM( 43 | input_size=word_embed_size, 44 | hidden_size=hidden_size, 45 | num_layers=1, 46 | batch_first=True 47 | ) 48 | 49 | self.adapter = VQAAdapter(hidden_size=hidden_size) 50 | 51 | self.backbone = TDA(img_feat_size=hidden_size, 52 | hidden_size=hidden_size, 53 | dropout_r=dropout_r 54 | ) 55 | 56 | # Classification layers 57 | self.fc1 = weight_norm(nn.Linear(hidden_size, flat_out_size), dim=None) 58 | self.relu = nn.ReLU() 59 | self.dropout = nn.Dropout(classifier_dropout_r, inplace=True) 60 | self.classifier = weight_norm(nn.Linear(flat_out_size, answer_size), dim=None) 61 | # layers = [ 62 | # weight_norm(nn.Linear(hidden_size, 63 | # flat_out_size), dim=None), 64 | # nn.ReLU(), 65 | # nn.Dropout(classifier_dropout_r, inplace=True), 66 | # weight_norm(nn.Linear(flat_out_size, answer_size), dim=None) 67 | # ] 68 | # self.classifier = nn.Sequential(*layers) 69 | 70 | def forward(self, frcn_feat, bbox_feat, question_token_ixs): 71 | # Pre-process Language Feature 72 | # lang_feat_mask = make_mask(ques_ix.unsqueeze(2)) 73 | lang_feat = self.embedding(question_token_ixs) 74 | lang_feat, _ = self.rnn(lang_feat) 75 | 76 | img_feat, _ = self.adapter(frcn_feat, bbox_feat) 77 | 78 | # Backbone Framework 79 | joint_feat = self.backbone( 80 | lang_feat[:, -1], 81 | img_feat 82 | ) 83 | 84 | # Classification layers 85 | # proj_feat = self.classifier(joint_feat) 86 | fc1 = self.fc1(joint_feat) 87 | clf_in = self.dropout(self.relu(fc1)) 88 | logits = self.classifier(clf_in) 89 | 90 | return { 91 | 'logits': logits, 92 | 'before_logits': fc1, 93 | 'question_features': lang_feat[:, -1], 94 | 'visual_features': img_feat, 95 | 'joint_features': joint_feat 96 | } 97 | 98 | 99 | class UpDnNoDropout(UpDn): 100 | def __init__(self, 101 | pretrained_emb, token_size, answer_size): 102 | super().__init__(pretrained_emb, token_size, answer_size, dropout_r=0, classifier_dropout_r=0) 103 | -------------------------------------------------------------------------------- /trainers/group_upweighting_trainer.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging 3 | 4 | from trainers.base_trainer import BaseTrainer 5 | from utils.losses import * 6 | from utils.metric_visualizer import LossVisualizer 7 | 8 | 9 | class GroupUpweightingTrainer(BaseTrainer): 10 | """ 11 | Simple upweighting technique which multiplies the loss by inverse group frequency. 12 | This has been found to work well when models are sufficiently underparameterized (e.g., low learning rate, high weight decay, fewer model parameters etc) 13 | Paper that investigated underparameterization with upweighting method: https://arxiv.org/abs/2005.04345 14 | """ 15 | 16 | def __init__(self, option): 17 | super(GroupUpweightingTrainer, self).__init__(option) 18 | self.loss_visualizer = LossVisualizer(self.option.expt_dir) 19 | 20 | def get_keys_to_save(self): 21 | return super().get_keys_to_save() + ['group_ix_to_cnt', 'group_ix_to_weight', 'group_name_to_weight'] 22 | 23 | def init_group_weights(self, loader): 24 | logging.getLogger().info("Initializing the group weights...") 25 | self.group_ix_to_cnt = {} 26 | self.group_ix_to_weight, self.group_name_to_weight = {}, {} 27 | self.group_ix_to_name = {} 28 | total_samples = 0 29 | for batch in loader: 30 | # If 'explicit_group_ix' is returned in the batch, then uses that to group the data 31 | # Else, uses 'group_ix' key to group the data 32 | if 'explicit_group_ix' in batch: 33 | grp_key = 'explicit_group_ix' 34 | grp_name_key = 'explicit_group_name' 35 | else: 36 | grp_key = 'group_ix' 37 | grp_name_key = 'group_name' 38 | for grp_ix, grp_name in zip(batch[grp_key], batch[grp_name_key]): 39 | grp_ix = int(grp_ix) 40 | if grp_ix not in self.group_ix_to_cnt: 41 | self.group_ix_to_cnt[grp_ix] = 0 42 | self.group_ix_to_cnt[grp_ix] += 1 43 | self.group_ix_to_name[grp_ix] = grp_name 44 | total_samples += 1 45 | for group_ix in self.group_ix_to_cnt: 46 | self.group_ix_to_weight[group_ix] = total_samples / self.group_ix_to_cnt[group_ix] 47 | self.group_name_to_weight[self.group_ix_to_name[group_ix]] \ 48 | = total_samples / self.group_ix_to_cnt[group_ix] 49 | 50 | def train(self, train_loader, test_loaders=None, unbalanced_train_loader=None): 51 | self.init_group_weights(train_loader) 52 | super().train(train_loader, test_loaders, unbalanced_train_loader) 53 | 54 | def _train_epoch(self, epoch, data_loader): 55 | self._mode_setting(is_train=True) 56 | 57 | for i, batch in enumerate(data_loader): 58 | batch = self.prepare_batch(batch) 59 | out = self.forward_model(self.model, batch) 60 | logits = out['logits'] 61 | batch_losses = self.loss(out['logits'], torch.squeeze(batch['y'])) 62 | group_key = 'explicit_group_ix' if 'explicit_group_ix' in batch else 'group_ix' 63 | 64 | weights = torch.FloatTensor([self.group_ix_to_weight[group_ix] for group_ix in batch[group_key]]).cuda() 65 | weighted_batch_losses = weights * batch_losses # Multiply per-sample losses by weights for the corresponding groups 66 | 67 | self.optim.zero_grad() 68 | weighted_batch_losses.mean().backward() 69 | self.optim.step() 70 | self.loss_visualizer.update('Train', 'Loss', batch_losses.mean().detach().item()) 71 | self.loss_visualizer.update('Train', 'Weighted Loss', weighted_batch_losses.mean().detach().item()) 72 | self.update_generalization_metrics('Train', batch, weighted_batch_losses) 73 | self._after_train_epoch(epoch) 74 | -------------------------------------------------------------------------------- /trainers/spectral_decoupling_trainer.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging 3 | import math 4 | import os 5 | 6 | import torch 7 | from torch import nn 8 | from torch import optim 9 | 10 | from models.model_factory import build_model 11 | from utils import trainer_utils 12 | from utils.metric_visualizer import AccuracyVisualizer, LossVisualizer 13 | from utils.metrics import Accuracy, GroupWiseAccuracy 14 | from torch.optim import * 15 | import json 16 | from copy import deepcopy 17 | from utils.trainer_utils import create_optimizer 18 | from torch.nn import * 19 | from utils.losses import * 20 | from trainers.base_trainer import BaseTrainer 21 | 22 | 23 | class SpectralDecouplingTrainer(BaseTrainer): 24 | """ 25 | Implementation for: 26 | Pezeshki, Mohammad, et al. "Gradient Starvation: A Learning Proclivity in Neural Networks." arXiv preprint arXiv:2011.09468 (2020). 27 | 28 | The paper shows that decay and shift in network's logits can help decouple learning of features, which may enable learning of signal too. 29 | 30 | Changes from the original implementation: 31 | The original implementation uses this loss: = torch.log(1.0 + torch.exp(-yhat[:, 0] * (2.0 * y - 1.0))) 32 | However, we just use cross-entropy since the above formulation doesn't make sense when there are more than 2 classes. 33 | """ 34 | 35 | def __init__(self, option): 36 | super(SpectralDecouplingTrainer, self).__init__(option) 37 | self.loss_visualizer = LossVisualizer(self.option.expt_dir) 38 | if self.option.spectral_decoupling_lambdas is None: 39 | assert self.option.spectral_decoupling_lambda is not None, 'lambda not specified' 40 | self.option.spectral_decoupling_lambdas = torch.ones( 41 | self.option.num_classes) * self.option.spectral_decoupling_lambda 42 | if self.option.spectral_decoupling_gammas is None: 43 | assert self.option.spectral_decoupling_gamma is not None, 'gammas not specified' 44 | self.option.spectral_decoupling_gammas = torch.ones( 45 | self.option.num_classes) * self.option.spectral_decoupling_gamma 46 | 47 | def _train_epoch(self, epoch, data_loader): 48 | self._mode_setting(is_train=True) 49 | 50 | for i, batch in enumerate(data_loader): 51 | # Forward pass 52 | batch = self.prepare_batch(batch) 53 | out = self.forward_model(self.model, batch) 54 | logits = out['logits'] 55 | 56 | # Compute the prediction loss 57 | # The original paper uses this formulation: 58 | # per_sample_losses = torch.log(1.0 + torch.exp(-yhat[:, 0] * (2.0 * y - 1.0))) 59 | # However, we just use cross-entropy since the above formulation doesn't make sense when there are more than 2 classes. 
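            # The per-sample objective implemented below follows Eq. 28 of the paper, adapted to
            # multiple classes, with lambda_y and gamma_y indexed by the ground-truth class y:
            #     L(x, y) = CE(logits, y) + (lambda_y / 2) * (logit_y - gamma_y)^2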
60 | pred_losses = self.loss(out['logits'], torch.squeeze(batch['y'])) 61 | 62 | per_class_lambdas = torch.FloatTensor( 63 | [self.option.spectral_decoupling_lambdas[y] for y in batch['y']]).cuda() 64 | per_class_gammas = torch.FloatTensor( 65 | [self.option.spectral_decoupling_gammas[y] for y in batch['y']]).cuda() 66 | 67 | # The loss is based on equation 28 of the paper 68 | # softmax = torch.softmax(logits, dim=1) 69 | # gt_softmax = softmax.gather(1, batch['y'].view(-1, 1)).squeeze() 70 | gt_logits = logits.gather(1, batch['y'].view(-1, 1)).squeeze() 71 | 72 | logit_l2_losses = 0.5 * per_class_lambdas * (gt_logits - per_class_gammas) ** 2 73 | total_losses = pred_losses + logit_l2_losses 74 | 75 | self.optim.zero_grad() 76 | total_losses.mean().backward() 77 | self.optim.step() 78 | 79 | self.loss_visualizer.update('Train', 'Total Loss', total_losses.mean().detach().item()) 80 | self.loss_visualizer.update('Train', 'Pred Loss', pred_losses.mean().detach().item()) 81 | self.loss_visualizer.update('Train', 'Logit L2 Loss', logit_l2_losses.mean().detach().item()) 82 | 83 | self._after_train_epoch(epoch) 84 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## An Investigation of Critical Issues in Bias Mitigation Techniques 2 | 3 | Our paper examines whether state-of-the-art bias mitigation methods perform well in more realistic settings: with multiple sources of bias, with hidden biases, and without access to the test distribution. This repository contains implementations/re-implementations of seven popular techniques. 4 | 5 | 6 | ### Setup 7 | 8 | #### Install Dependencies 9 | 10 | `conda create -n bias_mitigator python=3.7` 11 | 12 | `source activate bias_mitigator` 13 | 14 | `conda install pytorch==1.4.0 torchvision==0.5.0 cudatoolkit=10.1 -c pytorch` 15 | 16 | `conda install tqdm opencv pandas` 17 | 18 | #### Configure Path 19 | 20 | - Edit the `ROOT` variable in `common.sh`. This directory will contain the datasets and the experimental results. 21 | 22 | #### Datasets 23 | - For each dataset, we evaluate on the train, val and test splits. Each dataset file contains a function that creates dataloaders for all of these splits (a usage sketch follows the Biased MNIST section below). 24 | 25 | ##### Biased MNIST v1/v2 26 | Since publication, we have created BiasedMNISTv2, a more challenging version of the dataset. Version 2 includes larger images, spuriously correlated digit scales, distracting letters instead of simplistic geometric shapes, and updated background textures. 27 | 28 | We encourage the community to use the [BiasedMNIST v2](https://github.com/erobic/occam-nets-v1). 29 | 30 | 31 | You can download BiasedMNISTv1 (WACV 2021) [from here](https://drive.google.com/file/d/1RlvskdRjdAj6sqpYeD48sR2uJnxAYmv5/view?usp=sharing). 32 | 33 | Both BiasedMNISTv1 and v2 are released under the Creative Commons Attribution 4.0 International (CC BY 4.0) license.
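A minimal sketch of how the dataloaders are consumed (it assumes the `option` namespace built by `main.py`/`option.py`; the exact split keys depend on the dataset):

```python
from datasets.dataloader_factory import build_dataloaders

# 'option' is the parsed argument namespace used throughout this repo
loaders = build_dataloaders(option)  # e.g., {'Train': ..., 'Unbalanced Train': ...}
for batch in loaders['Train']:
    pass  # each trainer's _train_epoch() iterates over batches like this
```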
34 | 35 | 36 | ##### CelebA 37 | - Download the dataset [from here](https://drive.google.com/drive/folders/0B7EVK8r0v71pWEZsZE9oNnFzTm8) and extract the data to `${ROOT}` 38 | - We adapted the data loader from `https://github.com/kohpangwei/group_DRO` 39 | 40 | ##### GQA-OOD 41 | - Download GQA (object features, spatial features and questions) [from here](https://cs.stanford.edu/people/dorarad/gqa/download.html) 42 | - Build the GQA-OOD splits by following [these instructions](https://github.com/gqa-ood/GQA-OOD/tree/master/code) 43 | 44 | - Download the word embeddings 45 | `python -m spacy download en_vectors_web_lg` 46 | 47 | - Preprocess the visual/spatial features 48 | `./scripts/gqa-ood/preprocess_gqa.sh` 49 | 50 | #### Run the methods 51 | 52 | We provide a separate bash file for running each method on each dataset in the `scripts` directory. Here is a sample script: 53 | 54 | ```bash 55 | source activate bias_mitigator 56 | 57 | TRAINER_NAME='BaseTrainer' 58 | lr=1e-3 59 | wd=0 60 | python main.py \ 61 | --expt_type celebA_experiments \ 62 | --trainer_name ${TRAINER_NAME} \ 63 | --lr ${lr} \ 64 | --weight_decay ${wd} \ 65 | --expt_name ${TRAINER_NAME} \ 66 | --root_dir ${ROOT} 67 | ``` 68 | 69 | ### Contribute! 70 | 71 | - If you want to add more methods, simply follow one of the implementations inside the `trainers` directory. A minimal trainer skeleton is sketched at the end of this README. 72 | 73 | 74 | ### Highlights from the paper: 75 | 76 | 1. Overall, methods fail when datasets contain multiple sources of bias, even if they excel in simpler settings with one or two sources of bias (e.g., CelebA). 77 | ![](images/main_table.jpg) 78 | 79 | 2. Methods can exploit both implicit (hidden) and explicit biases. 80 | ![](images/bias_exploitation.jpg) 81 | 82 | 3. Methods cannot handle multiple sources of bias even when they are explicitly labeled. 83 | ![](images/scalability.jpg) 84 | 85 | 4. Most methods show high sensitivity to the tuning distribution, especially for minority groups. 86 | ![](images/distribution_variance.jpg) 87 | 88 | 89 | ### Citation 90 | ``` 91 | @article{shrestha2021investigation, 92 | title={An investigation of critical issues in bias mitigation techniques}, 93 | author={Shrestha, Robik and Kafle, Kushal and Kanan, Christopher}, 94 | journal={Workshop on Applications of Computer Vision}, 95 | year={2021} 96 | } 97 | ``` 98 | 99 | This work was supported in part by the DARPA/SRI Lifelong Learning Machines program [HR0011-18-C-0051], AFOSR grant [FA9550-18-1-0121], and NSF award #1909696.
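For reference, here is a minimal sketch of what a new trainer can look like. It is only a skeleton built on the hooks the existing trainers use (`prepare_batch`, `forward_model`, `self.loss`, `self.optim`, `_mode_setting`, `_after_train_epoch`); the file and class names are placeholders:

```python
# trainers/my_trainer.py (hypothetical file/class names)
import torch

from trainers.base_trainer import BaseTrainer


class MyTrainer(BaseTrainer):
    """Skeleton trainer: plain ERM; add your own per-sample loss manipulation here."""

    def _train_epoch(self, epoch, data_loader):
        self._mode_setting(is_train=True)
        for i, batch in enumerate(data_loader):
            batch = self.prepare_batch(batch)
            out = self.forward_model(self.model, batch)
            # self.loss is expected to return per-sample losses (reduction='none')
            batch_losses = self.loss(out['logits'], torch.squeeze(batch['y']))

            self.optim.zero_grad()
            batch_losses.mean().backward()
            self.optim.step()
        self._after_train_epoch(epoch)
```

You can then select it via `--trainer_name MyTrainer`, just like in the sample script above.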
100 | -------------------------------------------------------------------------------- /experiments/celebA_experiments.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def set_if_null(option, attr_name, val): 5 | if not hasattr(option, attr_name) or getattr(option, attr_name) is None: 6 | setattr(option, attr_name, val) 7 | 8 | 9 | def celebA_experiments(option, run): 10 | # Method-specific arguments are mostly defined in the bash files inside: scripts/celebA 11 | 12 | # Here, we configure the rest of the arguments, which likely DO NOT NEED TO BE CHANGED 13 | option.dataset_name = 'CelebA' 14 | option.data_dir = option.root_dir + f"/{option.dataset_name}" 15 | 16 | option.train_ratio = None # set to something like 0.1 for debugging 17 | 18 | # Define the bias and target variables 19 | option.target_name = 'Blond_Hair' 20 | option.bias_variables = ['Male'] 21 | option.bias_variable_name = 'Male' 22 | option.num_bias_classes = 2 23 | option.num_classes = 2 24 | option.num_groups = 4 25 | 26 | # Optimizer + Model 27 | set_if_null(option, 'optimizer_name', 'SGD') 28 | set_if_null(option, 'batch_size', 128) 29 | set_if_null(option, 'epochs', 50) 30 | set_if_null(option, 'model_name', 'ResNet18') 31 | 32 | # We allow training on a subset of the data using the train_ratio argument. If it is "None", then the full set is used. 33 | if option.train_ratio is not None: 34 | dataset_name = f"{option.dataset_name}_ratio_{option.train_ratio}" 35 | else: 36 | dataset_name = option.dataset_name 37 | 38 | # Configure the name of the experiment and the directory to save the results 39 | if option.save_dir is None: 40 | option.save_dir = os.path.join(option.root_dir, option.project_name, dataset_name, 'predict_' + option.target_name, 41 | option.trainer_name) 42 | if option.expt_name is None: 43 | option.expt_name = f"lr_{option.lr}_wd_{option.weight_decay}" 44 | 45 | # Method-specific configurations 46 | 47 | if option.trainer_name == 'GroupDROTrainer': 48 | option.balanced_sampling_attributes = ['group_ix'] # Perform balanced sampling 49 | option.group_by = 'group_ix' # Groups for GDRO 50 | option.key_to_group_by = 'group_name' # Used to find readable group names 51 | 52 | if option.trainer_name == 'RUBiTrainer': 53 | set_if_null(option, 'bias_model_hid_dims', 512) 54 | # We experimented with MLP2 and MLP3 too, but MLP1 had worked best in the preliminary experiments 55 | set_if_null(option, 'bias_model_name', 'MLP1') 56 | set_if_null(option, 'bias_variable_type', 'categorical') 57 | option.bias_variable_dims = 2 58 | 59 | if option.trainer_name == 'LNLTrainer': 60 | # We experimented with MLP2 and MLP3 too, but MLP1 had worked best in the preliminary experiments 61 | feature_dims = get_feature_dims(option.model_name, option.num_classes) 62 | set_if_null(option, 'bias_predictor_name', 'MLP1') 63 | set_if_null(option, 'bias_predictor_in_layer', 'model.pooled2') 64 | option.bias_predictor_in_dims = feature_dims[option.bias_predictor_in_layer] 65 | option.bias_predictor_hid_dims = feature_dims[option.bias_predictor_in_layer] 66 | 67 | if option.trainer_name == 'IRMv1Trainer': 68 | set_if_null(option, 'num_envs_per_batch', 4) 69 | 70 | if option.trainer_name == 'SpectralDecouplingTrainer': 71 | option.spectral_decoupling_lambdas = [10.0, 10.0] # For CelebA, we have per-class lambdas and gammas for SD.
72 | option.spectral_decoupling_gammas = [0.44, 2.5] 73 | 74 | # Test epochs 75 | option.test_epochs = [e for e in range(40, 51)] # Due to instability, we average accuracies over the last 10 epochs 76 | option.test_every = 10 # We further test every 10 epochs 77 | option.save_every = 50 78 | 79 | run(option) 80 | 81 | 82 | def get_feature_dims(model_name, num_classes): 83 | if 'ResNet18' in model_name: 84 | return { 85 | 'model.conv1': 64, 86 | 'model.pooled_conv1': 64, 87 | 'model.layer1.1.conv2': 64, 88 | 'model.pooled1': 64, 89 | 'model.layer2.1.conv2': 128, 90 | 'model.layer2_flattened': 128 * 28 * 28, 91 | 'model.pooled2': 128, 92 | 'model.layer3.1.conv2': 256, 93 | 'model.pooled3': 256, 94 | 'model.layer4.1.conv2': 512, 95 | 'model.pooled4': 512, 96 | 'model.fc': 10, 97 | 'logits': num_classes 98 | } 99 | -------------------------------------------------------------------------------- /datasets/vqa/gqa_feat_preproc.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # OpenVQA 3 | # GQA spatial features & object features .h5 files to .npz files transform script 4 | # Written by Pengbing Gao https://github.com/nbgao 5 | # -------------------------------------------------------- 6 | 7 | ''' 8 | Command line example: 9 | (1) Process spatial features 10 | python gqa_feat_preproc.py --mode=spatial --spatial_dir=./spatialFeatures --out_dir=./feats/gqa-grid 11 | 12 | (2) Process object features 13 | python gqa_feat_preproc.py --mode=object --object_dir=./objectFeatures --out_dir=./feats/gqa-frcn 14 | ''' 15 | 16 | import h5py, glob, json, cv2, argparse 17 | import numpy as np 18 | 19 | 20 | # spatial features 21 | def process_spatial_features(feat_path, out_path): 22 | info_file = feat_path + '/gqa_spatial_info.json' 23 | try: 24 | info = json.load(open(info_file, 'r')) 25 | except: 26 | print('Failed to open info file:', info_file) 27 | return 28 | print('Total grid features', len(info)) 29 | 30 | print('Making the
(h5 file, index) to image_id dict...') 31 | h5idx_to_imgid = {} 32 | for img_id in info: 33 | h5idx_to_imgid[str(info[img_id]['file']) + '_' + str(info[img_id]['idx'])] = img_id 34 | 35 | for ix in range(16): 36 | feat_file = feat_path + '/gqa_spatial_' + str(ix) + '.h5' 37 | print('Processing', feat_file) 38 | try: 39 | feat_dict = h5py.File(feat_file, 'r') 40 | except: 41 | print('Failed to open feat file:', feat_file) 42 | return 43 | 44 | features = feat_dict['features'] 45 | 46 | for iy in range(features.shape[0]): 47 | img_id = h5idx_to_imgid[str(ix) + '_' + str(iy)] 48 | feature = features[iy] 49 | # save to .npz file ['x'] 50 | np.savez( 51 | out_path + '/' + img_id + '.npz', 52 | x=feature.reshape(2048, 49).transpose(1, 0), # (49, 2048) 53 | ) 54 | 55 | print('Processed spatial features successfully!') 56 | 57 | 58 | # object features 59 | def process_object_features(feat_path, out_path): 60 | info_file = feat_path + '/gqa_objects_info.json' 61 | try: 62 | info = json.load(open(info_file, 'r')) 63 | except: 64 | print('Failed to open info file:', info_file) 65 | return 66 | print('Total frcn features', len(info)) 67 | 68 | print('Making the
(h5 file, index) to image_id dict...') 69 | h5idx_to_imgid = {} 70 | for img_id in info: 71 | h5idx_to_imgid[str(info[img_id]['file']) + '_' + str(info[img_id]['idx'])] = img_id 72 | 73 | for ix in range(16): 74 | feat_file = feat_path + '/gqa_objects_' + str(ix) + '.h5' 75 | print('Processing', feat_file) 76 | 77 | try: 78 | feat_dict = h5py.File(feat_file, 'r') 79 | except: 80 | print('Failed to open feat file:', feat_file) 81 | return 82 | 83 | bboxes = feat_dict['bboxes'] 84 | features = feat_dict['features'] 85 | 86 | for iy in range(features.shape[0]): 87 | img_id = h5idx_to_imgid[str(ix) + '_' + str(iy)] 88 | img_info = info[img_id] 89 | objects_num = img_info['objectsNum'] 90 | # save to .npz file ['x', 'bbox', 'width', 'height'] 91 | np.savez( 92 | out_path + '/' + img_id + '.npz', 93 | x=features[iy, :objects_num], 94 | bbox=bboxes[iy, :objects_num], 95 | width=img_info['width'], 96 | height=img_info['height'], 97 | ) 98 | 99 | print('Processed object features successfully!') 100 | 101 | 102 | if __name__ == "__main__": 103 | parser = argparse.ArgumentParser(description='gqa_h52npz') 104 | parser.add_argument('--mode', '-mode', choices=['object', 'spatial', 'frcn', 'grid'], help='mode', type=str) 105 | parser.add_argument('--object_dir', '-object_dir', help='object features dir', type=str) 106 | parser.add_argument('--spatial_dir', '-spatial_dir', help='spatial features dir', type=str) 107 | parser.add_argument('--out_dir', '-out_dir', help='output dir', type=str) 108 | 109 | args = parser.parse_args() 110 | 111 | mode = args.mode 112 | object_path = args.object_dir 113 | spatial_path = args.spatial_dir 114 | out_path = args.out_dir 115 | 116 | print('mode:', mode) 117 | print('object_path:', object_path) 118 | print('spatial_path:', spatial_path) 119 | print('out_path:', out_path) 120 | 121 | # process spatial features 122 | if mode in ['spatial', 'grid']: 123 | process_spatial_features(spatial_path, out_path) 124 | 125 | # process object features 126 | if mode in ['object', 'frcn']: 127 | process_object_features(object_path, out_path) 128 | -------------------------------------------------------------------------------- /trainers/group_dro_trainer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging 3 | 4 | from trainers.base_trainer import BaseTrainer 5 | from utils.losses import * 6 | from utils.metric_visualizer import LossVisualizer 7 | 8 | 9 | class GroupDROTrainer(BaseTrainer): 10 | """ 11 | Implementation for: 12 | Sagawa, Shiori, et al. "Distributionally robust neural networks for group shifts: On the importance of regularization for worst-case generalization." (ICLR 2020). 13 | GroupDRO groups the data using explicit bias variables and class labels and optimizes for the worst-case group.
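    At each step, the group weights are updated by exponentiated gradient ascent on the per-group
    losses and then renormalized, and the model minimizes the re-weighted loss
    (see update_group_weights_and_compute_robust_loss below):
        w_g <- w_g * exp(step_size * L_g);  w <- w / sum(w);  robust_loss = sum_g w_g * L_g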
14 | Paper: https://arxiv.org/pdf/1911.08731.pdf 15 | Original codebase: https://github.com/kohpangwei/group_DRO 16 | """ 17 | def __init__(self, option): 18 | super(GroupDROTrainer, self).__init__(option) 19 | self.group_weights = torch.ones(self.option.num_groups).cuda() / self.option.num_groups 20 | self.loss_visualizer = LossVisualizer(self.option.expt_dir) 21 | self.group_names = None 22 | 23 | def get_keys_to_save(self): 24 | return super().get_keys_to_save() + ['group_weights'] 25 | 26 | def compute_group_loss_and_counts(self, losses, group_idxs): 27 | """ 28 | Computes groupwise loss and count 29 | :param losses: 30 | :param group_idxs: 31 | :return: 32 | """ 33 | group_map = (torch.LongTensor(group_idxs).cuda() == torch.arange(self.option.num_groups).unsqueeze( 34 | 1).long().cuda()).float() 35 | group_count = group_map.sum(1) 36 | group_denom = group_count + (group_count == 0).float() # avoid nans 37 | group_loss = (group_map @ losses.view(-1)) / group_denom 38 | return group_loss, group_count 39 | 40 | def update_group_weights_and_compute_robust_loss(self, group_loss): 41 | """ 42 | Use exponent of weighted loss to update the group weights. Uses those weights to compute overall robust loss 43 | :param group_loss: 44 | :return: 45 | """ 46 | self.group_weights = self.group_weights * torch.exp(self.option.group_weight_step_size * group_loss.data) 47 | self.group_weights = self.group_weights / (self.group_weights.sum()) 48 | robust_loss = group_loss @ self.group_weights.detach() 49 | return robust_loss, self.group_weights 50 | 51 | def _train_epoch(self, epoch, data_loader): 52 | self._mode_setting(is_train=True) 53 | for i, batch in enumerate(data_loader): 54 | batch = self.prepare_batch(batch) 55 | out = self.forward_model(self.model, batch) 56 | logits = out['logits'] 57 | batch_losses = self.loss(out['logits'], torch.squeeze(batch['y'])) 58 | 59 | # Compute the GDRO loss 60 | group_ixs = batch[self.option.group_by] 61 | group_loss, _ = self.compute_group_loss_and_counts(batch_losses, group_ixs) 62 | robust_loss, _ = self.update_group_weights_and_compute_robust_loss(group_loss) 63 | 64 | self.optim.zero_grad() 65 | robust_loss.backward(retain_graph=True) 66 | self.optim.step() 67 | 68 | self.loss_visualizer.update('Train', 'Group Loss', group_loss.mean().detach().item()) 69 | self.loss_visualizer.update('Train', 'Robust Loss', robust_loss.detach().item()) 70 | self.update_generalization_metrics('Train', batch, batch_losses) 71 | 72 | group_name_to_weights = {} 73 | for gix in self.dro_group_ix_to_name: 74 | group_name = self.dro_group_ix_to_name[gix] 75 | group_name_to_weights[group_name] = float(self.group_weights[gix]) 76 | 77 | self.loss_visualizer.update_multiple('Train Group Weights', group_name_to_weights) 78 | if self.option.enable_groupwise_metrics: 79 | self.update_groupwise_values('Train', 'Group Loss', group_loss, batch) 80 | self.update_groupwise_values('Train', 'Robust Loss', group_loss, batch) 81 | self._after_train_epoch(epoch) 82 | 83 | def train(self, train_loader, test_loaders=None, unbalanced_train_loader=None): 84 | logging.getLogger().info("Beginning the training process...") 85 | self.compute_max_dataset_ixs(train_loader, test_loaders) 86 | self._initialization() 87 | 88 | # Gather group names 89 | self.dro_group_ix_to_name = {} 90 | for batch in train_loader: 91 | for group_ix, group_name in zip(batch[self.option.group_by], batch[self.option.key_to_group_by]): 92 | self.dro_group_ix_to_name[group_ix] = group_name 93 | 94 | self._mode_setting(is_train=True) 
95 | start_epoch = 1 96 | for epoch in range(start_epoch, self.option.epochs + 1): 97 | self._train_epoch(epoch, train_loader) 98 | self._after_one_epoch(epoch, test_loaders) 99 | self.after_all_epochs() 100 | -------------------------------------------------------------------------------- /datasets/vqa/ans_punct.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # OpenVQA 3 | # Written by Yuhao Cui https://github.com/cuiyuhao1996 4 | # based on VQA Evaluation Code 5 | # -------------------------------------------------------- 6 | 7 | import re 8 | 9 | contractions = { 10 | "aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": 11 | "could've", "couldnt": "couldn't", "couldn'tve": "couldn't've", 12 | "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": 13 | "doesn't", "dont": "don't", "hadnt": "hadn't", "hadnt've": 14 | "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": 15 | "haven't", "hed": "he'd", "hed've": "he'd've", "he'dve": 16 | "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", 17 | "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", "Im": 18 | "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": 19 | "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's", 20 | "maam": "ma'am", "mightnt": "mightn't", "mightnt've": 21 | "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've", 22 | "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", 23 | "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't", 24 | "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": 25 | "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": 26 | "she'd've", "she's": "she's", "shouldve": "should've", "shouldnt": 27 | "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": 28 | "shouldn't've", "somebody'd": "somebodyd", "somebodyd've": 29 | "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": 30 | "somebody'll", "somebodys": "somebody's", "someoned": "someone'd", 31 | "someoned've": "someone'd've", "someone'dve": "someone'd've", 32 | "someonell": "someone'll", "someones": "someone's", "somethingd": 33 | "something'd", "somethingd've": "something'd've", "something'dve": 34 | "something'd've", "somethingll": "something'll", "thats": 35 | "that's", "thered": "there'd", "thered've": "there'd've", 36 | "there'dve": "there'd've", "therere": "there're", "theres": 37 | "there's", "theyd": "they'd", "theyd've": "they'd've", "they'dve": 38 | "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": 39 | "they've", "twas": "'twas", "wasnt": "wasn't", "wed've": 40 | "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": 41 | "weren't", "whatll": "what'll", "whatre": "what're", "whats": 42 | "what's", "whatve": "what've", "whens": "when's", "whered": 43 | "where'd", "wheres": "where's", "whereve": "where've", "whod": 44 | "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": 45 | "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll", 46 | "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": 47 | "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've", 48 | "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": 49 | "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've", 50 | "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": 51 | "you'd", "youd've": "you'd've", "you'dve": "you'd've", "youll": 52 | "you'll", "youre": "you're", "youve": "you've" 
53 | } 54 | 55 | manual_map = { 'none': '0', 56 | 'zero': '0', 57 | 'one': '1', 58 | 'two': '2', 59 | 'three': '3', 60 | 'four': '4', 61 | 'five': '5', 62 | 'six': '6', 63 | 'seven': '7', 64 | 'eight': '8', 65 | 'nine': '9', 66 | 'ten': '10'} 67 | articles = ['a', 'an', 'the'] 68 | period_strip = re.compile("(?!<=\d)(\.)(?!\d)") 69 | comma_strip = re.compile("(\d)(\,)(\d)") 70 | punct = [';', r"/", '[', ']', '"', '{', '}', 71 | '(', ')', '=', '+', '\\', '_', '-', 72 | '>', '<', '@', '`', ',', '?', '!'] 73 | 74 | def process_punctuation(inText): 75 | outText = inText 76 | for p in punct: 77 | if (p + ' ' in inText or ' ' + p in inText) \ 78 | or (re.search(comma_strip, inText) != None): 79 | outText = outText.replace(p, '') 80 | else: 81 | outText = outText.replace(p, ' ') 82 | outText = period_strip.sub("", outText, re.UNICODE) 83 | return outText 84 | 85 | 86 | def process_digit_article(inText): 87 | outText = [] 88 | tempText = inText.lower().split() 89 | for word in tempText: 90 | word = manual_map.setdefault(word, word) 91 | if word not in articles: 92 | outText.append(word) 93 | else: 94 | pass 95 | for wordId, word in enumerate(outText): 96 | if word in contractions: 97 | outText[wordId] = contractions[word] 98 | outText = ' '.join(outText) 99 | return outText 100 | 101 | 102 | def prep_ans(answer): 103 | answer = process_digit_article(process_punctuation(answer)) 104 | answer = answer.replace(',', '') 105 | return answer 106 | -------------------------------------------------------------------------------- /models/vqa/mcan/net.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # OpenVQA 3 | # Written by Yuhao Cui https://github.com/cuiyuhao1996 4 | # -------------------------------------------------------- 5 | 6 | from models.vqa.make_mask import make_mask 7 | from models.vqa.ops.fc import FC, MLP 8 | from models.vqa.ops.layer_norm import LayerNorm 9 | from models.vqa.mcan.mca import MCA_ED 10 | # from models.vqa.mcan.adapter import Adapter 11 | from models.vqa.vqa_adapter import VQAAdapter 12 | # from openvqa.models.mcan.mca import MCA_ED 13 | # from openvqa.models.mcan.adapter import Adapter 14 | 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | import torch 18 | 19 | 20 | # ------------------------------ 21 | # ---- Flatten the sequence ---- 22 | # ------------------------------ 23 | 24 | class AttFlat(nn.Module): 25 | def __init__(self, hidden_size=1024, flat_mlp_size=512, flat_glimpses=1, dropout_r=0.1, flat_out_size=2048): 26 | super(AttFlat, self).__init__() 27 | self.flat_glimpses = flat_glimpses 28 | 29 | self.mlp = MLP( 30 | in_size=hidden_size, 31 | mid_size=flat_mlp_size, 32 | out_size=flat_glimpses, 33 | dropout_r=dropout_r, 34 | use_relu=True 35 | ) 36 | 37 | self.linear_merge = nn.Linear( 38 | hidden_size * flat_glimpses, 39 | flat_out_size 40 | ) 41 | 42 | def forward(self, x, x_mask): 43 | att = self.mlp(x) 44 | att = att.masked_fill( 45 | x_mask.squeeze(1).squeeze(1).unsqueeze(2), 46 | -1e9 47 | ) 48 | att = F.softmax(att, dim=1) 49 | 50 | att_list = [] 51 | for i in range(self.flat_glimpses): 52 | att_list.append( 53 | torch.sum(att[:, :, i: i + 1] * x, dim=1) 54 | ) 55 | 56 | x_atted = torch.cat(att_list, dim=1) 57 | x_atted = self.linear_merge(x_atted) 58 | 59 | return x_atted 60 | 61 | 62 | # ------------------------- 63 | # ---- Main MCAN Model ---- 64 | # ------------------------- 65 | 66 | class MCAN(nn.Module): 67 | def __init__(self, 
pretrained_emb, 68 | token_size, 69 | answer_size, 70 | word_embed_size=300, 71 | use_glove=True, 72 | hidden_size=1024, 73 | flat_out_size=2048, 74 | flat_mlp_size=512, 75 | flat_glimpses=1, 76 | ff_size=4096, 77 | dropout_r=0.1, 78 | multi_head=8): 79 | super(MCAN, self).__init__() 80 | 81 | self.embedding = nn.Embedding( 82 | num_embeddings=token_size, 83 | embedding_dim=word_embed_size 84 | ) 85 | 86 | # Loading the GloVe embedding weights 87 | if use_glove: 88 | self.embedding.weight.data.copy_(torch.from_numpy(pretrained_emb)) 89 | 90 | self.lstm = nn.LSTM( 91 | input_size=word_embed_size, 92 | hidden_size=hidden_size, 93 | num_layers=1, 94 | batch_first=True 95 | ) 96 | 97 | self.adapter = VQAAdapter(hidden_size=hidden_size) 98 | 99 | self.backbone = MCA_ED(hidden_size, ff_size, dropout_r, multi_head) 100 | 101 | # Flatten to vector 102 | self.attflat_img = AttFlat(hidden_size, 103 | flat_mlp_size=flat_mlp_size, 104 | flat_glimpses=flat_glimpses, 105 | dropout_r=dropout_r, 106 | flat_out_size=flat_out_size 107 | ) 108 | self.attflat_lang = AttFlat(hidden_size, 109 | flat_mlp_size=flat_mlp_size, 110 | flat_glimpses=flat_glimpses, 111 | dropout_r=dropout_r, 112 | flat_out_size=flat_out_size) 113 | 114 | # Classification layers 115 | self.proj_norm = LayerNorm(flat_out_size) 116 | self.proj = nn.Linear(flat_out_size, answer_size) 117 | 118 | def forward(self, frcn_feat, bbox_feat, ques_ix): 119 | # Pre-process Language Feature 120 | lang_feat_mask = make_mask(ques_ix.unsqueeze(2)) 121 | lang_feat = self.embedding(ques_ix) 122 | lang_feat, _ = self.lstm(lang_feat) 123 | 124 | img_feat, img_feat_mask = self.adapter(frcn_feat, bbox_feat) 125 | 126 | # Backbone Framework 127 | lang_feat, img_feat = self.backbone( 128 | lang_feat, 129 | img_feat, 130 | lang_feat_mask, 131 | img_feat_mask 132 | ) 133 | 134 | # Flatten to vector 135 | lang_feat = self.attflat_lang( 136 | lang_feat, 137 | lang_feat_mask 138 | ) 139 | 140 | img_feat = self.attflat_img( 141 | img_feat, 142 | img_feat_mask 143 | ) 144 | 145 | # Classification layers 146 | proj_feat = lang_feat + img_feat 147 | proj_feat = self.proj_norm(proj_feat) 148 | proj_feat = self.proj(proj_feat) 149 | 150 | return { 151 | 'question_features': lang_feat, 152 | 'logits': proj_feat 153 | } 154 | 155 | # return proj_feat 156 | -------------------------------------------------------------------------------- /datasets/shape_generator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | 5 | 6 | class ShapeGenerator(): 7 | def __init__(self, dim=32, padding=8): 8 | self.dim = dim 9 | self.padded_dim = dim - padding * 2 10 | self.padding = padding 11 | self.shapes = ['circle', 'right_triangle', 'obtuse_triangle', 'square', 'parallelogram', 'kite', 12 | 'pentagon', 'semi_circle', 'plus', 'arrow'] 13 | 14 | def init_image(self): 15 | return np.zeros((self.padded_dim, self.padded_dim, 3), np.uint8) 16 | 17 | def generate_polygon(self, vertices, image=None, color=(255, 255, 255)): 18 | if image is None: 19 | image = self.init_image() 20 | 21 | pts = vertices.reshape((-1, 1, 2)) 22 | cv2.polylines(image, [pts], isClosed=True, color=color, thickness=1) 23 | # fill it 24 | cv2.fillPoly(image, [pts], color=color) 25 | return image 26 | 27 | def generate_circle(self): 28 | center = (int(self.padded_dim / 2), int(self.padded_dim / 2)) 29 | radius = int(self.padded_dim / 2) 30 | image = self.init_image() 31 | cv2.circle(image, center, radius, (255, 255, 255), thickness=-1) 
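        # thickness=-1 tells OpenCV to draw a filled circle rather than just the outline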
32 | return image 33 | 34 | def generate_semi_circle(self): 35 | circle = self.generate_circle() 36 | circle[:, int(self.padded_dim / 2):] = 0 37 | return circle 38 | 39 | def generate_right_triangle(self): 40 | vertices = np.array( 41 | [[0, 0], 42 | [0, self.padded_dim], 43 | [self.padded_dim, self.padded_dim]], np.int32) 44 | return self.generate_polygon(vertices) 45 | 46 | def generate_obtuse_triangle(self): 47 | vertices = np.array([[0, 0], 48 | [self.padded_dim / 3, self.padded_dim], 49 | [self.padded_dim, self.padded_dim]], np.int32) 50 | return self.generate_polygon(vertices) 51 | 52 | def generate_square(self): 53 | vertices = np.array([[0, 0], 54 | [0, self.padded_dim], 55 | [self.padded_dim, self.padded_dim], 56 | [self.padded_dim, 0]], np.int32) 57 | return self.generate_polygon(vertices) 58 | 59 | def generate_parallelogram(self): 60 | vertices = np.array([[self.padded_dim / 3, 0], 61 | [0, self.padded_dim], 62 | [self.padded_dim * 2 / 3, self.padded_dim], 63 | [self.padded_dim, 0]], np.int32) 64 | return self.generate_polygon(vertices) 65 | 66 | def generate_kite(self): 67 | vertices = np.array([[self.padded_dim / 2, 0], 68 | [0, self.padded_dim / 3], 69 | [self.padded_dim / 2, self.padded_dim], 70 | [self.padded_dim, self.padded_dim / 3]], np.int32) 71 | return self.generate_polygon(vertices) 72 | 73 | def generate_pentagon(self): 74 | vertices = np.array([[self.padded_dim / 2, 0], 75 | [0, self.padded_dim / 3], 76 | [self.padded_dim / 3, self.padded_dim], 77 | [self.padded_dim * 2 / 3, self.padded_dim], 78 | [self.padded_dim, self.padded_dim / 3]], np.int32) 79 | return self.generate_polygon(vertices) 80 | 81 | def generate_hexagon(self): 82 | vertices = np.array([[self.padded_dim / 4, 0], 83 | [0, self.padded_dim / 2], 84 | [self.padded_dim / 4, self.padded_dim], 85 | [self.padded_dim * 3 / 4, self.padded_dim], 86 | [self.padded_dim, self.padded_dim / 2], 87 | [self.padded_dim * 3 / 4, 0]], np.int32) 88 | return self.generate_polygon(vertices) 89 | 90 | def generate_plus(self): 91 | image = self.init_image() 92 | vert_rect = np.array([[self.padded_dim / 3, 0], 93 | [self.padded_dim / 3, self.padded_dim], 94 | [self.padded_dim * 2 / 3, self.padded_dim], 95 | [self.padded_dim * 2 / 3, 0]], np.int32) 96 | hor_rect = np.array([[0, self.padded_dim / 3], 97 | [0, self.padded_dim * 2 / 3], 98 | [self.padded_dim, self.padded_dim * 2 / 3], 99 | [self.padded_dim, self.padded_dim / 3]], np.int32) 100 | image = self.generate_polygon(vert_rect, image=image) 101 | return self.generate_polygon(hor_rect, image=image) 102 | 103 | def generate_arrow(self): 104 | vertices = np.array([[0, self.padded_dim / 4], 105 | [0, self.padded_dim * 3 / 4], 106 | [self.padded_dim * 2 / 3, self.padded_dim * 3 / 4], 107 | [self.padded_dim, self.padded_dim / 2], 108 | [self.padded_dim * 2 / 3, self.padded_dim / 4]], np.int32) 109 | return self.generate_polygon(vertices) 110 | 111 | def generate(self, shape): 112 | method = getattr(self, 'generate_' + shape) 113 | shape = method() 114 | img = np.zeros((self.dim, self.dim, 3)) 115 | img[self.padding:self.dim - self.padding, self.padding:self.dim - self.padding] = shape 116 | return img 117 | 118 | 119 | if __name__ == "__main__": 120 | generator = ShapeGenerator() 121 | save_dir = '/tmp/shapes' 122 | if not os.path.exists(save_dir): 123 | os.makedirs(save_dir) 124 | 125 | for s in generator.shapes: 126 | img = generator.generate(s) 127 | cv2.imwrite(save_dir + f'/{s}.jpg', img) 128 | print(f"Saved {s}") 129 | 
-------------------------------------------------------------------------------- /trainers/rubi_trainer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging 3 | 4 | import torch 5 | 6 | from models import model_factory 7 | from trainers.base_trainer import BaseTrainer 8 | from utils.bias_retrievers import build_bias_retriever 9 | from utils.metric_visualizer import LossVisualizer 10 | from utils.trainer_utils import create_optimizer 11 | 12 | 13 | class RUBiTrainer(BaseTrainer): 14 | """ 15 | Adaptation of 16 | Cadene, Remi, et al. "RUBi: Reducing Unimodal Biases for Visual Question Answering." (NeurIPS 2019). 17 | 18 | While the original implementation dealt with language biases in VQA by using question features as bias features, 19 | in this work, we use explicit labels (e.g., the gender label in CelebA) as the bias input. 20 | 21 | Specifically, RUBi contains a main branch (taking in the full input, e.g., the image for CelebA, or the image and question for VQA) 22 | and a bias-only branch (e.g., taking in gender labels for CelebA). 23 | The sigmoid of the logits from the bias-only branch is used to modulate the logits from the main branch during training. 24 | 25 | Paper: https://arxiv.org/abs/1906.10169 26 | Original codebase: https://github.com/cdancette/rubi.bootstrap.pytorch 27 | """ 28 | 29 | def __init__(self, option): 30 | super(RUBiTrainer, self).__init__(option) 31 | self.loss_visualizer = LossVisualizer(self.option.expt_dir) 32 | 33 | def _build_model(self): 34 | super()._build_model() 35 | if self.option.bias_model_name is None: 36 | self.option.bias_model_name = self.option.model_name 37 | self.bias_model = model_factory.build_model(self.option, 38 | self.option.bias_model_name, 39 | in_dims=self.option.bias_variable_dims, 40 | hid_dims=self.option.bias_model_hid_dims, 41 | out_dims=self.option.num_classes) 42 | logging.getLogger().info("Bias Model") 43 | logging.getLogger().info(self.bias_model) 44 | self.bias_retriever = build_bias_retriever(self.option.bias_variable_name) 45 | 46 | if self.option.cuda: 47 | self.model.cuda() 48 | self.bias_model.cuda() 49 | self.loss.cuda() 50 | 51 | def _build_optimizer(self): 52 | self.optim = create_optimizer(self.option.optimizer_name, 53 | named_params=list(self.model.named_parameters()) + list( 54 | self.bias_model.named_parameters()), 55 | lr=self.option.lr, 56 | weight_decay=self.option.weight_decay, 57 | momentum=self.option.momentum, 58 | freeze_layers=self.option.freeze_layers) 59 | 60 | def _mode_setting(self, is_train=True): 61 | self.model.train(is_train) 62 | self.bias_model.train(is_train) 63 | 64 | def _train_epoch(self, epoch, data_loader): 65 | self._mode_setting(is_train=True) 66 | for i, batch in enumerate(data_loader): 67 | batch = self.prepare_batch(batch) 68 | self.optim.zero_grad() 69 | 70 | main_out = self.forward_model(self.model, batch) 71 | logits = main_out['logits'] 72 | bias = self.bias_retriever(batch, main_out) 73 | 74 | if self.option.bias_variable_type == 'categorical': 75 | # If it is a categorical bias variable (e.g., gender), then convert to one hot vectors 76 | _bias = torch.zeros((len(bias), self.option.bias_variable_dims)) 77 | for ix, bias_ix in enumerate(bias): 78 | _bias[ix, int(bias_ix)] = 1 79 | bias = _bias 80 | 81 | bias = bias.cuda() 82 | 83 | # Feed into the bias model 84 | bias_out = self.bias_model(bias.detach()) 85 | bias_logits = bias_out['logits'] 86 | bias_losses = self.loss(bias_logits, batch['y'].squeeze()) 87 | bias_loss =
bias_losses.mean() 88 | 89 | # Modulate the main model's outputs with bias-only model's outputs 90 | main_losses = self.loss(logits, batch['y'].squeeze()) 91 | loss = main_losses.mean() 92 | sigmoid_weight = torch.sigmoid(bias_logits) 93 | rubi_logits = logits * sigmoid_weight 94 | rubi_losses = self.loss(rubi_logits, batch['y'].squeeze()) 95 | loss_ratio = rubi_losses / (rubi_losses + main_losses) 96 | 97 | # Optimize main model and bias-only model 98 | fused_losses = rubi_losses + bias_losses 99 | fused_loss = fused_losses.mean() 100 | fused_loss.backward(retain_graph=True) 101 | 102 | if self.option.grad_clip is not None: 103 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.option.grad_clip) 104 | torch.nn.utils.clip_grad_norm_(self.bias_model.parameters(), self.option.grad_clip) 105 | self.optim.step() 106 | self.update_generalization_metrics('Train Main', batch, main_losses) 107 | self.update_generalization_metrics('Train Bias', batch, bias_losses) 108 | rubi_out = {} 109 | rubi_out['logits'] = rubi_logits 110 | self.update_generalization_metrics('Train RUBi', batch, rubi_losses) 111 | self.update_generalization_metrics('Train Fused', batch, fused_losses) 112 | if self.option.enable_groupwise_metrics: 113 | self.update_groupwise_values('RUBi Loss/(RUBi+Main)', 'Loss Ratio', loss_ratio, batch) 114 | 115 | self._after_train_epoch(epoch, 'Train Fused') 116 | 117 | def get_keys_to_save(self): 118 | return super().get_keys_to_save() + ['bias_model'] 119 | -------------------------------------------------------------------------------- /utils/trainer_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import json 4 | import time 5 | import logging 6 | import torch 7 | from torch import optim 8 | from torch.optim import * 9 | 10 | 11 | def save_option(option): 12 | if not os.path.exists(option.save_dir + '/' + option.expt_name): 13 | os.makedirs(option.save_dir + '/' + option.expt_name) 14 | option_path = os.path.join(option.save_dir, option.expt_name, "options.json") 15 | 16 | with open(option_path, 'w') as fp: 17 | json.dump(option.__dict__, fp, indent=4, sort_keys=True, 18 | default=lambda o: f"<>") 19 | logging.getLogger().info(json.dumps(option.__dict__, indent=4, sort_keys=True, 20 | default=lambda o: f"<>")) 21 | 22 | 23 | def initialize_logger(expt_dir): 24 | if not os.path.exists(expt_dir): 25 | os.makedirs(expt_dir) 26 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') 27 | logger = logging.getLogger() 28 | logger.handlers = [] 29 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 30 | 31 | streaming_handler = logging.StreamHandler() 32 | streaming_handler.setFormatter(formatter) 33 | logger.addHandler(streaming_handler) 34 | 35 | null_handler = logging.NullHandler() 36 | null_handler.setLevel(logging.DEBUG) 37 | logging.getLogger("tornado.access").addHandler(null_handler) 38 | logging.getLogger("tornado.access").propagate = False 39 | 40 | file_handler = logging.FileHandler(os.path.join(expt_dir, 'log.txt')) 41 | file_handler.setFormatter(formatter) 42 | logger.addHandler(file_handler) 43 | return logger 44 | 45 | 46 | class Timer(object): 47 | def __init__(self, logger, epochs, last_step=0): 48 | self.logger = logger 49 | self.epochs = epochs 50 | self.step = last_step 51 | 52 | curr_time = time.time() 53 | self.start = curr_time 54 | self.last = curr_time 55 | 56 | def __call__(self): 57 | 
curr_time = time.time() 58 | self.step += 1 59 | 60 | duration = curr_time - self.last 61 | remaining = (self.epochs - self.step) * (curr_time - self.start) / self.step / 3600 62 | msg = 'TIMER, duration(s)|remaining(h), %f, %f' % (duration, remaining) 63 | self.logger.info(msg) 64 | self.last = curr_time 65 | 66 | 67 | def get_dir(file_path): 68 | return '/'.join(file_path.split('/')[:-1]) 69 | 70 | 71 | class GradMult(torch.autograd.Function): 72 | 73 | @staticmethod 74 | def forward(ctx, x, const): 75 | ctx.const = const 76 | return x.view_as(x) 77 | 78 | @staticmethod 79 | def backward(ctx, grad_output): 80 | return grad_output * ctx.const, None 81 | 82 | 83 | def grad_mult(x, const): 84 | return GradMult.apply(x, const) 85 | 86 | 87 | class GradReverse(torch.autograd.Function): 88 | @staticmethod 89 | def forward(ctx, x): 90 | return x.view_as(x) 91 | 92 | @staticmethod 93 | def backward(ctx, grad_output): 94 | return grad_output.neg() # * 0.1 95 | 96 | 97 | def grad_reverse(x): 98 | return GradReverse.apply(x) 99 | 100 | 101 | def create_optimizer(optimizer_name, named_params, lr, weight_decay=0, momentum=0.9, freeze_layers=None, 102 | custom_lr_config=None): 103 | """ 104 | Builds the optimizer of the given name, adding only the provided named_params. 105 | Supports freezing layers too. 106 | 107 | For the experiments, we did not use freeze_layers or custom_lr_config. So, this code can be simplified a lot now. 108 | :return: 109 | """ 110 | if weight_decay is None: 111 | weight_decay = 0 112 | 113 | def should_be_added(layer_name): 114 | ret = True 115 | if freeze_layers is None: 116 | return ret 117 | for freeze_layer in freeze_layers: 118 | if layer_name.startswith(freeze_layer): 119 | ret = False 120 | return ret 121 | 122 | filt_params = [] 123 | for name, param in named_params: 124 | param_dict = None 125 | if should_be_added(name): 126 | if custom_lr_config is not None: 127 | for custom_lr_name in custom_lr_config: 128 | if name.startswith(custom_lr_name): 129 | param_dict = {'params': param, 'lr': custom_lr_config[custom_lr_name]} 130 | if param_dict is None: 131 | param_dict = {'params': param, 'lr': lr} 132 | filt_params.append(param_dict) 133 | # logging.getLogger().info(f"Adding param: {name}") 134 | else: 135 | param.requires_grad = False # for efficiency 136 | logging.getLogger().info(f"Freezing param: {name}") 137 | 138 | if optimizer_name == 'SGD': 139 | optimizer = optim.SGD(filt_params, lr=lr, momentum=momentum, weight_decay=weight_decay) 140 | else: 141 | optimizer = eval(optimizer_name)(filt_params, lr=lr, weight_decay=weight_decay) 142 | return optimizer 143 | 144 | 145 | def clip_grad_norm(parameters, max_norm, norm_type=2): 146 | if isinstance(parameters, torch.Tensor): 147 | parameters = [parameters] 148 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 149 | max_norm = float(max_norm) 150 | norm_type = float(norm_type) 151 | if norm_type == float('inf'): 152 | total_norm = max(p.grad.data.abs().max() for p in parameters) 153 | else: 154 | total_norm = 0 155 | for p in parameters: 156 | param_norm = p.grad.data.norm(norm_type) 157 | total_norm += param_norm.item() ** norm_type 158 | total_norm = total_norm ** (1.
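`GradMult` and `GradReverse` are identity maps in the forward pass and only act on gradients; a quick standalone check of their behavior, using the `grad_mult`/`grad_reverse` helpers defined above:

```python
import torch
from utils.trainer_utils import grad_reverse, grad_mult

x = torch.ones(3, requires_grad=True)
grad_reverse(x).sum().backward()
print(x.grad)   # tensor([-1., -1., -1.]): gradients are negated

x.grad = None
grad_mult(x, 0.5).sum().backward()
print(x.grad)   # tensor([0.5000, 0.5000, 0.5000]): gradients are scaled by const
```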
/ norm_type) 159 | clip_coef = max_norm / (total_norm + 1e-6) 160 | if clip_coef < 1: 161 | for p in parameters: 162 | p.grad.data.mul_(clip_coef) 163 | return total_norm 164 | -------------------------------------------------------------------------------- /models/vqa/mcan/mca.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # OpenVQA 3 | # Written by Yuhao Cui https://github.com/cuiyuhao1996 4 | # -------------------------------------------------------- 5 | 6 | from models.vqa.ops.fc import FC, MLP 7 | from models.vqa.ops.layer_norm import LayerNorm 8 | 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch 12 | import math 13 | 14 | 15 | # ------------------------------ 16 | # ---- Multi-Head Attention ---- 17 | # ------------------------------ 18 | 19 | class MHAtt(nn.Module): 20 | def __init__(self, hidden_size, dropout_r, multi_head=8): 21 | super(MHAtt, self).__init__() 22 | self.multi_head = multi_head 23 | self.hidden_size = hidden_size 24 | 25 | self.linear_v = nn.Linear(hidden_size, hidden_size) 26 | self.linear_k = nn.Linear(hidden_size, hidden_size) 27 | self.linear_q = nn.Linear(hidden_size, hidden_size) 28 | self.linear_merge = nn.Linear(hidden_size, hidden_size) 29 | 30 | self.dropout = nn.Dropout(dropout_r) 31 | 32 | def forward(self, v, k, q, mask): 33 | n_batches = q.size(0) 34 | 35 | v = self.linear_v(v).view( 36 | n_batches, 37 | -1, 38 | self.multi_head, 39 | int(self.hidden_size / self.multi_head) 40 | ).transpose(1, 2) 41 | 42 | k = self.linear_k(k).view( 43 | n_batches, 44 | -1, 45 | self.multi_head, 46 | int(self.hidden_size / self.multi_head) 47 | ).transpose(1, 2) 48 | 49 | q = self.linear_q(q).view( 50 | n_batches, 51 | -1, 52 | self.multi_head, 53 | int(self.hidden_size / self.multi_head) 54 | ).transpose(1, 2) 55 | 56 | atted = self.att(v, k, q, mask) 57 | atted = atted.transpose(1, 2).contiguous().view( 58 | n_batches, 59 | -1, 60 | self.hidden_size 61 | ) 62 | 63 | atted = self.linear_merge(atted) 64 | 65 | return atted 66 | 67 | def att(self, value, key, query, mask): 68 | d_k = query.size(-1) 69 | 70 | scores = torch.matmul( 71 | query, key.transpose(-2, -1) 72 | ) / math.sqrt(d_k) 73 | 74 | if mask is not None: 75 | scores = scores.masked_fill(mask, -1e9) 76 | 77 | att_map = F.softmax(scores, dim=-1) 78 | att_map = self.dropout(att_map) 79 | 80 | return torch.matmul(att_map, value) 81 | 82 | 83 | # --------------------------- 84 | # ---- Feed Forward Nets ---- 85 | # --------------------------- 86 | 87 | class FFN(nn.Module): 88 | def __init__(self, hidden_size, ff_size, dropout_r): 89 | super(FFN, self).__init__() 90 | 91 | self.mlp = MLP( 92 | in_size=hidden_size, 93 | mid_size=ff_size, 94 | out_size=hidden_size, 95 | dropout_r=dropout_r, 96 | use_relu=True 97 | ) 98 | 99 | def forward(self, x): 100 | return self.mlp(x) 101 | 102 | 103 | # ------------------------ 104 | # ---- Self Attention ---- 105 | # ------------------------ 106 | 107 | class SA(nn.Module): 108 | def __init__(self, hidden_size, ff_size, dropout_r, multi_head): 109 | super(SA, self).__init__() 110 | 111 | self.mhatt = MHAtt(hidden_size, dropout_r, multi_head) 112 | self.ffn = FFN(hidden_size, ff_size, dropout_r) 113 | 114 | self.dropout1 = nn.Dropout(dropout_r) 115 | self.norm1 = LayerNorm(hidden_size) 116 | 117 | self.dropout2 = nn.Dropout(dropout_r) 118 | self.norm2 = LayerNorm(hidden_size) 119 | 120 | def forward(self, y, y_mask): 121 | y = self.norm1(y 
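The head splitting in `MHAtt` above reshapes `(N, L, hidden)` into `(N, heads, L, hidden/heads)` before the scaled dot-product; a shape check with assumed sizes:

```python
import torch

N, L, hidden, heads = 2, 14, 512, 8
q = torch.randn(N, L, hidden)
q = q.view(N, -1, heads, hidden // heads).transpose(1, 2)
print(q.shape)       # torch.Size([2, 8, 14, 64])

scores = torch.matmul(q, q.transpose(-2, -1)) / (hidden // heads) ** 0.5
print(scores.shape)  # torch.Size([2, 8, 14, 14]): one attention map per head
```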
+ self.dropout1( 122 | self.mhatt(y, y, y, y_mask) 123 | )) 124 | 125 | y = self.norm2(y + self.dropout2( 126 | self.ffn(y) 127 | )) 128 | 129 | return y 130 | 131 | 132 | # ------------------------------- 133 | # ---- Self Guided Attention ---- 134 | # ------------------------------- 135 | 136 | class SGA(nn.Module): 137 | def __init__(self, hidden_size, ff_size, dropout_r, multi_head): 138 | super(SGA, self).__init__() 139 | 140 | self.mhatt1 = MHAtt(hidden_size, dropout_r, multi_head) 141 | self.mhatt2 = MHAtt(hidden_size, dropout_r, multi_head) 142 | self.ffn = FFN(hidden_size, ff_size, dropout_r) 143 | 144 | self.dropout1 = nn.Dropout(dropout_r) 145 | self.norm1 = LayerNorm(hidden_size) 146 | 147 | self.dropout2 = nn.Dropout(dropout_r) 148 | self.norm2 = LayerNorm(hidden_size) 149 | 150 | self.dropout3 = nn.Dropout(dropout_r) 151 | self.norm3 = LayerNorm(hidden_size) 152 | 153 | def forward(self, x, y, x_mask, y_mask): 154 | x = self.norm1(x + self.dropout1( 155 | self.mhatt1(v=x, k=x, q=x, mask=x_mask) 156 | )) 157 | 158 | x = self.norm2(x + self.dropout2( 159 | self.mhatt2(v=y, k=y, q=x, mask=y_mask) 160 | )) 161 | 162 | x = self.norm3(x + self.dropout3( 163 | self.ffn(x) 164 | )) 165 | 166 | return x 167 | 168 | 169 | # ------------------------------------------------ 170 | # ---- MCA Layers Cascaded by Encoder-Decoder ---- 171 | # ------------------------------------------------ 172 | 173 | class MCA_ED(nn.Module): 174 | def __init__(self, hidden_size, ff_size, dropout_r, multi_head, layer=6): 175 | super(MCA_ED, self).__init__() 176 | 177 | self.enc_list = nn.ModuleList([SA(hidden_size, ff_size, dropout_r, multi_head) for _ in range(layer)]) 178 | self.dec_list = nn.ModuleList([SGA(hidden_size, ff_size, dropout_r, multi_head) for _ in range(layer)]) 179 | 180 | def forward(self, y, x, y_mask, x_mask): 181 | # Get encoder last hidden vector 182 | for enc in self.enc_list: 183 | y = enc(y, y_mask) 184 | 185 | # Input encoder last hidden vector 186 | # And obtain decoder last hidden vectors 187 | for dec in self.dec_list: 188 | x = dec(x, y, x_mask, y_mask) 189 | 190 | return y, x 191 | -------------------------------------------------------------------------------- /models/vqa/ban/_ban.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # OpenVQA 3 | # Written by Zhenwei Shao https://github.com/ParadoxZW 4 | # Based on the implementation of the paper "Bilinear Attention Networks", NeurIPS 2018 (https://github.com/jnhwkim/ban-vqa) 5 | # -------------------------------------------------------- 6 | 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.nn.utils.weight_norm import weight_norm 10 | import torch, math 11 | 12 | 13 | # ------------------------------ 14 | # ----- Weight Normal MLP ------ 15 | # ------------------------------ 16 | 17 | class MLP(nn.Module): 18 | """ 19 | Simple class for a non-linear fully connected network 20 | """ 21 | 22 | def __init__(self, dims, act='ReLU', dropout_r=0.0): 23 | super(MLP, self).__init__() 24 | 25 | layers = [] 26 | for i in range(len(dims) - 1): 27 | in_dim = dims[i] 28 | out_dim = dims[i + 1] 29 | if dropout_r > 0: 30 | layers.append(nn.Dropout(dropout_r)) 31 | layers.append(weight_norm(nn.Linear(in_dim, out_dim), dim=None)) 32 | if act != '': 33 | layers.append(getattr(nn, act)()) 34 | 35 | self.mlp = nn.Sequential(*layers) 36 | 37 | def forward(self, x): 38 | return self.mlp(x) 39 | 40 | 41 | #
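`MCA_ED` can be exercised on its own; the sketch below assumes the `(N, 1, 1, L)` boolean mask convention used by `MHAtt.att`, with `True` marking padded positions, and picks illustrative sizes:

```python
import torch
from models.vqa.mcan.mca import MCA_ED

mca = MCA_ED(hidden_size=512, ff_size=2048, dropout_r=0.1, multi_head=8, layer=2)
lang = torch.randn(2, 14, 512)               # question token features
img = torch.randn(2, 100, 512)               # image region features
lang_mask = torch.zeros(2, 1, 1, 14).bool()  # True would mark padded positions
img_mask = torch.zeros(2, 1, 1, 100).bool()

lang_out, img_out = mca(lang, img, lang_mask, img_mask)
print(lang_out.shape, img_out.shape)         # (2, 14, 512) (2, 100, 512)
```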
------------------------------ 42 | # ------ Bilinear Connect ------ 43 | # ------------------------------ 44 | 45 | class BC(nn.Module): 46 | """ 47 | Simple class for non-linear bilinear connect network 48 | """ 49 | 50 | def __init__(self, 51 | img_feat_size=2048, 52 | hidden_size=1024, 53 | k_times=3, 54 | dropout_r=0.2, 55 | classifier_dropout_r=0.5, 56 | glimpse=8, 57 | atten=False): 58 | super(BC, self).__init__() 59 | self.k_times = k_times 60 | ba_hidden_size = k_times * hidden_size 61 | self.v_net = MLP([img_feat_size, 62 | ba_hidden_size], dropout_r=dropout_r) 63 | self.q_net = MLP([hidden_size, 64 | ba_hidden_size], dropout_r=dropout_r) 65 | if not atten: 66 | self.p_net = nn.AvgPool1d(k_times, stride=k_times) 67 | else: 68 | self.dropout = nn.Dropout(classifier_dropout_r) # attention 69 | 70 | self.h_mat = nn.Parameter(torch.Tensor( 71 | 1, glimpse, 1, ba_hidden_size).normal_()) 72 | self.h_bias = nn.Parameter( 73 | torch.Tensor(1, glimpse, 1, 1).normal_()) 74 | 75 | def forward(self, v, q): 76 | # low-rank bilinear pooling using einsum 77 | v_ = self.dropout(self.v_net(v)) 78 | q_ = self.q_net(q) 79 | logits = torch.einsum('xhyk,bvk,bqk->bhvq', 80 | (self.h_mat, v_, q_)) + self.h_bias 81 | return logits # b x h_out x v x q 82 | 83 | def forward_with_weights(self, v, q, w): 84 | v_ = self.v_net(v) # b x v x d 85 | q_ = self.q_net(q) # b x q x d 86 | logits = torch.einsum('bvk,bvq,bqk->bk', (v_, w, q_)) 87 | logits = logits.unsqueeze(1) # b x 1 x d 88 | logits = self.p_net(logits).squeeze(1) * self.k_times # sum-pooling 89 | return logits 90 | 91 | 92 | # ------------------------------ 93 | # -------- BiAttention --------- 94 | # ------------------------------ 95 | 96 | 97 | class BiAttention(nn.Module): 98 | def __init__(self, 99 | img_feat_size=2048, 100 | hidden_size=1024, 101 | k_times=3, 102 | dropout_r=0.2, 103 | classifier_dropout_r=0.5, 104 | glimpse=8 105 | ): 106 | super(BiAttention, self).__init__() 107 | # 108 | self.glimpse = glimpse 109 | self.logits = weight_norm( 110 | BC(img_feat_size, hidden_size, k_times, dropout_r, classifier_dropout_r, glimpse, True), name='h_mat', 111 | dim=None) 112 | 113 | def forward(self, v, q, v_mask=True, logit=False, mask_with=-float('inf')): 114 | v_num = v.size(1) 115 | q_num = q.size(1) 116 | logits = self.logits(v, q) # b x g x v x q 117 | 118 | if v_mask: 119 | mask = (0 == v.abs().sum(2)).unsqueeze( 120 | 1).unsqueeze(3).expand(logits.size()) 121 | logits.data.masked_fill_(mask.data, mask_with) 122 | 123 | if not logit: 124 | p = nn.functional.softmax( 125 | logits.view(-1, self.glimpse, v_num * q_num), 2) 126 | return p.view(-1, self.glimpse, v_num, q_num), logits 127 | 128 | return logits 129 | 130 | 131 | # ------------------------------ 132 | # - Bilinear Attention Network - 133 | # ------------------------------ 134 | 135 | class _BAN(nn.Module): 136 | def __init__(self, 137 | img_feat_size, 138 | hidden_size, 139 | k_times, 140 | dropout_r, 141 | classifier_dropout_r, 142 | glimpse): 143 | super(_BAN, self).__init__() 144 | 145 | self.BiAtt = BiAttention(img_feat_size, 146 | hidden_size, 147 | k_times, 148 | dropout_r, 149 | classifier_dropout_r, 150 | glimpse) 151 | b_net = [] 152 | q_prj = [] 153 | c_prj = [] 154 | self.glimpse = glimpse 155 | for i in range(glimpse): 156 | b_net.append(BC(img_feat_size, 157 | hidden_size, 158 | k_times, 159 | dropout_r, 160 | classifier_dropout_r, 161 | glimpse, 162 | False)) 163 | q_prj.append(MLP([hidden_size, hidden_size], '', dropout_r)) 164 | self.b_net = nn.ModuleList(b_net) 165 
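The attention-map einsum in `BC.forward` contracts the joint embedding dimension `k` against every image-region/question-token pair; its shapes can be checked directly (sizes below mirror the defaults but are otherwise arbitrary):

```python
import torch

b, v, q, k, glimpse = 2, 100, 14, 3 * 1024, 8
h_mat = torch.randn(1, glimpse, 1, k)
v_ = torch.randn(b, v, k)  # projected image features
q_ = torch.randn(b, q, k)  # projected question features

att_logits = torch.einsum('xhyk,bvk,bqk->bhvq', h_mat, v_, q_)
print(att_logits.shape)    # torch.Size([2, 8, 100, 14]): one v-q map per glimpse
```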
| self.q_prj = nn.ModuleList(q_prj) 166 | 167 | def forward(self, q, v): 168 | att, logits = self.BiAtt(v, q) # b x g x v x q 169 | 170 | for g in range(self.glimpse): 171 | bi_emb = self.b_net[g].forward_with_weights( 172 | v, q, att[:, g, :, :]) # b x l x h 173 | q = self.q_prj[g](bi_emb.unsqueeze(1)) + q 174 | 175 | return q 176 | -------------------------------------------------------------------------------- /utils/metric_visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import logging 3 | import torch 4 | # import neptune 5 | import json 6 | 7 | 8 | class MetricVisualizer(): 9 | def __init__(self, env_name='main', divide_windows_by='split', is_one_time_metric=True): 10 | self.env_name = env_name 11 | self.plots = {} 12 | self.reset() 13 | self.divide_windows_by = divide_windows_by 14 | self.use_markers = False 15 | self.is_one_time_metric = is_one_time_metric 16 | 17 | def update_multiple(self, split, dict): 18 | for name, value in zip(dict.keys(), dict.values()): 19 | self.update(split, name, value) 20 | 21 | def update(self, split, name, value): 22 | if split not in self.metrics: 23 | self.metrics[split] = {} 24 | self.metric_cnts[split] = {} 25 | if name not in self.metrics[split]: 26 | if self.is_one_time_metric: 27 | self.metrics[split][name] = 0 28 | else: 29 | self.metrics[split][name] = [] 30 | self.metric_cnts[split][name] = 0 31 | 32 | if self.is_one_time_metric: 33 | self.metrics[split][name] = value 34 | self.metric_cnts[split][name] = 1 35 | else: 36 | if isinstance(value, list): 37 | for val in value: 38 | self.metrics[split][name].append(val) 39 | self.metric_cnts[split][name] += 1 40 | else: 41 | self.metrics[split][name].append(value) 42 | self.metric_cnts[split][name] += 1 43 | 44 | def reset(self): 45 | self.metrics = {} 46 | self.metric_cnts = {} 47 | 48 | def compute_and_save_std_dev(self, split, name): 49 | if split + " Std Dev" not in self.metrics: 50 | self.metrics[split + " Std Dev"] = {} 51 | self.metrics[split + " Variance"] = {} 52 | 53 | self.metrics[split + " Std Dev"][name] = torch.std(torch.Tensor(self.metrics[split][name])) 54 | self.metrics[split + " Variance"][name] = self.metrics[split + " Std Dev"][name] ** 2 55 | 56 | def log(self, epoch, split, avg=True): 57 | log_str = f"Split: {split}" 58 | if epoch is not None: 59 | log_str += f', Epoch: {epoch}' 60 | name_values = {} 61 | for ix, name in enumerate(self.metrics[split]): 62 | if self.is_one_time_metric: 63 | val = self.metrics[split][name] 64 | else: 65 | val = sum(self.metrics[split][name]) 66 | if avg: 67 | val = val / (self.metric_cnts[split][name] + int(self.metric_cnts[split][name] == 0)) 68 | metric_format = self.get_metric_format() 69 | if 'cnt' in name.lower(): 70 | metric_format = '%d' 71 | 72 | # log_str += (", %s: " + metric_format) % (name, val) 73 | # neptune.log_metric(log_name=split + " " + name, x=epoch, y=val) 74 | name_values[name] = val 75 | if len(name_values) <= 1: 76 | logging.getLogger().info(name_values) 77 | else: 78 | logging.getLogger().info(json.dumps(name_values, sort_keys=True, indent=4)) 79 | # logging.getLogger().info(sorted(list(name_values.keys()))) 80 | # vals = [float(name_values[k]) for k in sorted(list(name_values.keys()))] 81 | # logging.getLogger().info(vals) 82 | 83 | # logging.getLogger().info(log_str) 84 | 85 | def get_metric_format(self): 86 | return "%.4f" 87 | 88 | def accumulate_plot_and_reset(self, epoch, xlabel='Epochs', avg=True): 89 | x = epoch 90 | for split in 
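These visualizers follow an update/log/reset cycle per epoch; a minimal usage sketch (note that the visdom/neptune plotting calls are commented out in this codebase, so `log` only writes to the Python logger):

```python
from utils.metric_visualizer import LossVisualizer

viz = LossVisualizer(env_name='demo')
for step in range(10):
    viz.update('Train', 'Loss', 0.5 / (step + 1))  # appends to a running list
viz.log(epoch=1, split='Train')                    # logs the running average
viz.accumulate_plot_and_reset(epoch=1)             # averages, plots (no-op), resets
```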
self.metrics: 91 | for name in self.metrics[split]: 92 | if self.is_one_time_metric: 93 | y = self.metrics[split][name] 94 | else: 95 | try: 96 | y = sum(self.metrics[split][name]) 97 | if avg: 98 | y = y / (self.metric_cnts[split][name] + int(self.metric_cnts[split][name] == 0)) 99 | except: 100 | y = self.metrics[split][name] 101 | self.plot(split, name, x, y, xlabel) 102 | self.reset() 103 | 104 | def plot(self, split, name, x, y, xlabel='Epochs'): 105 | """ 106 | Creates a separate chart for each loss 107 | :return: 108 | """ 109 | if self.divide_windows_by == 'name': 110 | win_name = name 111 | line_name = split 112 | else: 113 | win_name = split 114 | line_name = name 115 | # if win_name not in self.plots: 116 | # self.plots[win_name] = self.viz.line(X=np.array([x, x]), Y=np.array([y, y]), env=self.env_name, 117 | # opts=dict(legend=[line_name], title=win_name, xlabel=xlabel, 118 | # ylabel=name, markers=self.use_markers)) 119 | # 120 | # else: 121 | # self.viz.line(X=np.array([x]), Y=np.array([y]), env=self.env_name, win=self.plots[win_name], name=line_name, 122 | # update='append') 123 | 124 | 125 | class LossVisualizer(MetricVisualizer): 126 | def __init__(self, env_name='main', divide_windows_by='split'): 127 | super(LossVisualizer, self).__init__(env_name, divide_windows_by, is_one_time_metric=False) 128 | self.use_markers = False 129 | 130 | def get_metric_format(self): 131 | return "%.4f" 132 | 133 | 134 | class AccuracyVisualizer(MetricVisualizer): 135 | def __init__(self, env_name='main', divide_windows_by='split'): 136 | super(AccuracyVisualizer, self).__init__(env_name, divide_windows_by, is_one_time_metric=True) 137 | self.use_markers = True 138 | 139 | def get_metric_format(self): 140 | return "%.2f%%" 141 | 142 | 143 | class CountVisualizer(MetricVisualizer): 144 | def __init__(self, env_name='main', divide_windows_by='name'): 145 | super(CountVisualizer, self).__init__(env_name, divide_windows_by, is_one_time_metric=True) 146 | self.use_markers = False 147 | 148 | def get_metric_format(self): 149 | return "%d" 150 | 151 | def log(self, epoch, split): 152 | return super().log(epoch, split) 153 | 154 | def accumulate_plot_and_reset(self, epoch, xlabel='Epochs'): 155 | return super().accumulate_plot_and_reset(epoch, xlabel) 156 | -------------------------------------------------------------------------------- /trainers/lnl_trainer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | from torch import nn 4 | from torch import optim 5 | 6 | import models 7 | import models.cnn_models 8 | from trainers.base_trainer import BaseTrainer 9 | from utils.trainer_utils import grad_reverse, grad_mult, create_optimizer 10 | from torch.optim import * 11 | import logging 12 | from models.model_factory import build_model 13 | 14 | 15 | class LNLTrainer(BaseTrainer): 16 | """ 17 | Implementation for: 18 | Kim, Byungju, et al. "Learning not to learn: Training deep neural networks with biased data." (CVPR 2019). 19 | 20 | Method Description: 21 | Has a main branch taking in bias+signal to predict classes and has branches predicting bias variables. 22 | Bias branches act as adversaries to the main branch, thereby reducing dependencies on biases. 
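The entropy half of that adversarial pressure is easy to see in isolation (a standalone sketch of the term computed in `_train_epoch` below; sizes are illustrative):

```python
import torch
import torch.nn.functional as F

bias_logits = torch.randn(4, 2, requires_grad=True)
p = F.softmax(bias_logits, dim=1) + 1e-8

# mean of sum_i p_i * log(p_i) is the negative entropy; minimizing it pushes
# the bias predictor's outputs toward the uniform distribution
neg_entropy = torch.mean(torch.sum(p * torch.log(p), dim=1))
neg_entropy.backward()
```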
23 | 24 | Reference codebase: https://github.com/feidfoe/learning-not-to-learn 25 | 26 | Known Issue(s)/Confusion(s) with original repo: 27 | 1) The paper does not provide full details to replicate the results: https://github.com/feidfoe/learning-not-to-learn/issues/7 28 | 29 | 2) There are two forward passes, one to compute the main loss+entropy loss and the other to compute bias loss (training bias predictors). 30 | Unsure why it hasn't been done with single forward pass in the original repo: https://github.com/feidfoe/learning-not-to-learn/issues/5 31 | """ 32 | 33 | def __init__(self, option): 34 | super(LNLTrainer, self).__init__(option) 35 | 36 | def _build_model(self): 37 | super()._build_model() 38 | self.bias_predictor = build_model(self.option, 39 | self.option.bias_predictor_name, 40 | in_dims=self.option.bias_predictor_in_dims, 41 | hid_dims=self.option.bias_predictor_hid_dims, 42 | out_dims=self.option.num_bias_classes) 43 | logging.getLogger().info(f"Bias predictor {self.bias_predictor}") 44 | 45 | if self.option.cuda: 46 | self.model.cuda() 47 | self.bias_predictor.cuda() 48 | 49 | def _build_optimizer(self): 50 | super()._build_optimizer() 51 | self.bias_predictor_optim = create_optimizer(self.option.optimizer_name, 52 | named_params=self.bias_predictor.named_parameters(), 53 | lr=self.option.lr, 54 | weight_decay=self.option.weight_decay, 55 | momentum=self.option.momentum, 56 | freeze_layers=self.option.freeze_layers) 57 | 58 | def get_keys_to_save(self): 59 | return super().get_keys_to_save() + ['bias_predictor', 'bias_predictor_optim'] 60 | 61 | def _mode_setting(self, is_train=True): 62 | self.model.train(is_train) 63 | self.bias_predictor.train(is_train) 64 | 65 | def _train_epoch(self, epoch, data_loader): 66 | self._mode_setting(is_train=True) 67 | 68 | for i, batch in enumerate(data_loader): 69 | # Prepare data 70 | batch = self.prepare_batch(batch) 71 | labels = batch['y'] 72 | 73 | # Forward pass 74 | self.optim.zero_grad() 75 | self.bias_predictor_optim.zero_grad() 76 | out = self.forward_model(self.model, batch) 77 | hidden = out[self.option.bias_predictor_in_layer] 78 | logits = out['logits'] 79 | 80 | bias_out = self.bias_predictor(hidden) 81 | bias_logits = bias_out['logits'] 82 | 83 | # main loss 84 | batch_loss = self.loss(logits, torch.squeeze(labels)) 85 | loss_pred = batch_loss.mean() 86 | self.loss_visualizer.update('Train', 'Loss', loss_pred.item()) 87 | 88 | # bias loss 89 | bias_softmax = torch.nn.functional.softmax(bias_logits, dim=1) + 1e-8 90 | bias_entropy_loss = torch.mean( 91 | torch.sum(bias_softmax * torch.log(bias_softmax), 1)) * self.option.entropy_loss_weight 92 | self.loss_visualizer.update('Train', 'Bias Entropy', bias_entropy_loss.item()) 93 | loss = loss_pred + bias_entropy_loss 94 | loss.backward(retain_graph=True) 95 | 96 | if self.option.grad_clip is not None: 97 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.option.grad_clip) 98 | torch.nn.utils.clip_grad_norm_(self.bias_predictor.parameters(), self.option.grad_clip) 99 | 100 | self.optim.step() 101 | self.optim.zero_grad() 102 | self.bias_predictor_optim.zero_grad() 103 | 104 | # Train with adversarial loss 105 | out = self.forward_model(self.model, batch) 106 | hidden = out[self.option.bias_predictor_in_layer] 107 | hidden_feat = grad_mult(hidden, self.option.grad_reverse_factor) 108 | bias_out = self.bias_predictor(hidden_feat) 109 | 110 | bias_labels = batch[self.option.bias_variable_name] 111 | if isinstance(bias_labels, list): 112 | bias_labels = 
torch.LongTensor(bias_labels) 113 | bias_labels = bias_labels.long().cuda() 114 | if len(bias_labels.squeeze().shape) > 1: 115 | bias_labels = torch.argmax(bias_labels, dim=1).squeeze() 116 | bias_loss = self.loss(bias_out['logits'].squeeze(), bias_labels.squeeze()) 117 | self.loss_visualizer.update('Train', 'Main loss', loss_pred.mean().item()) 118 | self.loss_visualizer.update('Train', 'Bias Entropy loss', bias_entropy_loss.mean().item()) 119 | self.loss_visualizer.update('Train', 'Bias loss', bias_loss.mean().item()) 120 | self.update_generalization_metrics('Train', batch, batch_loss) 121 | bias_loss.mean().backward(retain_graph=True) 122 | 123 | if self.option.grad_clip is not None: 124 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.option.grad_clip) 125 | torch.nn.utils.clip_grad_norm_(self.bias_predictor.parameters(), self.option.grad_clip) 126 | 127 | self.optim.step() 128 | self.bias_predictor_optim.step() 129 | 130 | self._after_train_epoch(epoch) 131 | 132 | def _after_one_epoch(self, epoch, test_loaders, force_test=False): 133 | super()._after_one_epoch(epoch, test_loaders, force_test) 134 | -------------------------------------------------------------------------------- /trainers/irm_v1_trainer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging 3 | import random 4 | 5 | import numpy as np 6 | from torch.utils.data import Subset, DataLoader 7 | 8 | from trainers.base_trainer import BaseTrainer 9 | from utils.losses import * 10 | 11 | 12 | class IRMv1Trainer(BaseTrainer): 13 | """ 14 | Implementation for: 15 | Arjovsky, Martin, et al. "Invariant risk minimization." (ICLR 2021). 16 | 17 | This attempts to learn representations that enable the same classifier to be optimal across environments. 18 | For our implementation, we use the explicit groups based on the explicit bias variables including the class label to form the environments. 19 | We uniformly sample from environments within each batch (i.e., our implementation assumes balanced group sampling). 20 | """ 21 | 22 | def __init__(self, option): 23 | super(IRMv1Trainer, self).__init__(option) 24 | self.group_names = None 25 | self.env_to_data_loader = None 26 | self.num_envs_per_batch = option.num_envs_per_batch 27 | assert self.option.batch_size % self.num_envs_per_batch == 0, \ 28 | "Batch size is not exactly divisible by the number of environments within that batch" 29 | 30 | def compute_gradient_penalty(self, logits, y): 31 | """ 32 | Gradient of the risk when all classifier weights are set to 1 (i.e., the IRMv1 regularizer) 33 | Based on unbiased estimation of IRMV1 (Sec 3.2) and Appendix D of https://arxiv.org/pdf/1907.02893.pdf 34 | It assumes that each batch contains per-environment samples from 'num_envs_per_batch' environments in an ordered manner. 
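In symbols, the penalty estimates ||∇_w R_e(w)||² at the dummy classifier w = 1.0 without bias, by multiplying gradients taken on two disjoint halves of each environment's samples; a standalone sketch for a single environment:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(8, 3, requires_grad=True)  # one environment's logits
y = torch.randint(0, 3, (8,))

w = torch.tensor(1.).requires_grad_()           # frozen dummy scalar classifier
losses = F.cross_entropy(logits * w, y, reduction='none')
g1 = torch.autograd.grad(losses[:4].mean(), [w], create_graph=True)[0]
g2 = torch.autograd.grad(losses[4:].mean(), [w], create_graph=True)[0]
penalty = g1 * g2  # unbiased estimate of the squared gradient norm
```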
35 | 36 | :param logits: 37 | :param y: 38 | :return: 39 | """ 40 | dummy_classifier = torch.tensor(1.).cuda().requires_grad_() 41 | loss = self.loss(logits * dummy_classifier, y.squeeze()) 42 | start_ix = 0 43 | grad_loss = 0 44 | for env in np.arange(0, self.num_envs_per_batch): 45 | end_ix = start_ix + self.option.batch_size // self.num_envs_per_batch 46 | env_loss = loss[start_ix:end_ix] 47 | loss1 = env_loss[:len(env_loss) // 2] 48 | loss2 = env_loss[len(env_loss) // 2:] 49 | grad1 = torch.autograd.grad(loss1.mean(), [dummy_classifier], create_graph=True)[0] 50 | grad2 = torch.autograd.grad(loss2.mean(), [dummy_classifier], create_graph=True)[0] 51 | grad_loss += grad1 * grad2 52 | start_ix = end_ix 53 | 54 | return grad_loss 55 | 56 | def train(self, train_loader, test_loaders=None, unbalanced_train_loader=None): 57 | self.before_train(train_loader, test_loaders) 58 | start_epoch = 1 59 | orig_loader = train_loader 60 | batch_sampler = EnvironmentWiseBatchSampler(self.option.batch_size, orig_loader, self.num_envs_per_batch) 61 | dataset = orig_loader.dataset 62 | if isinstance(dataset, Subset): 63 | dataset = dataset.dataset 64 | 65 | train_loader = DataLoader(dataset, batch_sampler=batch_sampler, 66 | num_workers=orig_loader.num_workers, collate_fn=orig_loader.collate_fn) 67 | 68 | for epoch in range(start_epoch, self.option.epochs + 1): 69 | self._train_epoch(epoch, train_loader) 70 | self._after_one_epoch(epoch, test_loaders) 71 | self.after_all_epochs() 72 | 73 | def _train_epoch(self, epoch, data_loader): 74 | self._mode_setting(is_train=True) 75 | 76 | for batch_ix, batch in enumerate(data_loader): 77 | batch = self.prepare_batch(batch) 78 | out = self.forward_model(self.model, batch) 79 | logits = out['logits'] 80 | 81 | # Unbiased IRMv1 goes through each environment before doing a backward pass 82 | # However this is not scalable e.g., when # of environments are in 100s or 1000s, 83 | # so we randomly sample certain environments in every batch 84 | batch_losses = self.loss(logits, torch.squeeze(batch['y'])) 85 | grad_penalty = self.option.grad_penalty_weight * self.compute_gradient_penalty(logits, batch['y']) 86 | self.loss_visualizer.update(f'Train', 'Main Loss', batch_losses.mean().item()) 87 | self.loss_visualizer.update(f'Train', 'Grad Penalty', 88 | self.option.grad_penalty_weight * grad_penalty.mean().item()) 89 | 90 | self.optim.zero_grad() 91 | loss = batch_losses.mean() + grad_penalty.mean() 92 | loss.backward(retain_graph=True) # Cannot go through each environment before calling backward() 93 | if self.option.grad_clip is not None: 94 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.option.grad_clip) 95 | self.optim.step() 96 | 97 | self.optim.zero_grad() 98 | self._after_train_epoch(epoch) 99 | 100 | 101 | class EnvironmentWiseBatchSampler(): 102 | def __init__(self, batch_size, data_loader, num_envs_per_batch): 103 | """ 104 | We first identify the environment for each data item 105 | For each batch, we randomly sample from random environments 106 | """ 107 | self.num_items = 0 108 | self.env_to_dataset_ixs = {} 109 | self.batch_size = batch_size 110 | self.num_envs_per_batch = num_envs_per_batch 111 | 112 | # Do one pass through the train_loader to get indices per env 113 | for batch_ix, batch in enumerate(data_loader): 114 | for dix, gix in zip(batch['dataset_ix'], batch['group_ix']): 115 | if gix not in self.env_to_dataset_ixs: 116 | self.env_to_dataset_ixs[gix] = [] 117 | self.env_to_dataset_ixs[gix].append(dix) 118 | self.num_items += 1 119 | 
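A batch sampler like this plugs into `DataLoader` through the `batch_sampler` argument, which supersedes `batch_size` and `shuffle`. A toy sketch (the dataset, its `group_ix` assignment, and the list-preserving collate are inventions for illustration):

```python
import torch
from torch.utils.data import DataLoader, Dataset
from trainers.irm_v1_trainer import EnvironmentWiseBatchSampler

class ToyGroupedDataset(Dataset):
    """Toy dataset carrying the dataset_ix/group_ix keys the sampler expects."""
    def __len__(self):
        return 64

    def __getitem__(self, ix):
        return {'x': torch.randn(5), 'y': ix % 2, 'dataset_ix': ix, 'group_ix': ix % 4}

def collate(items):  # keep dataset_ix/group_ix as plain int lists
    return {key: [item[key] for item in items] for key in items[0]}

plain_loader = DataLoader(ToyGroupedDataset(), batch_size=16, collate_fn=collate)
sampler = EnvironmentWiseBatchSampler(batch_size=16, data_loader=plain_loader, num_envs_per_batch=4)
env_loader = DataLoader(plain_loader.dataset, batch_sampler=sampler, collate_fn=collate)
batch = next(iter(env_loader))  # 4 samples from each of 4 sampled environments
```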
self.env_keys = list(self.env_to_dataset_ixs.keys()) 120 | logging.getLogger().info(f"env keys {self.env_keys}") 121 | # for gix in self.env_to_dataset_ixs: 122 | # logging.getLogger().info(f"Env Key: {gix} Cnt: {len(self.env_to_dataset_ixs[gix])}") 123 | 124 | def __iter__(self): 125 | num_batches_per_epoch = self.__len__() 126 | curr_batch_cnt = 0 127 | 128 | while curr_batch_cnt <= num_batches_per_epoch: 129 | # Randomly select some environments per batch 130 | env_ixs = np.random.choice(self.env_keys, self.num_envs_per_batch) # Randomly sample some environments 131 | 132 | # Randomly select within each of the chosen environments 133 | batch = [] 134 | 135 | for env_ix in env_ixs: 136 | for b in np.arange(self.batch_size // self.num_envs_per_batch): 137 | dix = random.choice(self.env_to_dataset_ixs[env_ix]) 138 | batch.append(dix) 139 | 140 | curr_batch_cnt += 1 141 | yield batch 142 | 143 | def __len__(self): 144 | # The total budget per epoch is self.num_items 145 | return self.num_items // self.batch_size 146 | -------------------------------------------------------------------------------- /models/fc_models.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | 5 | 6 | class GammaRegressor(nn.Module): 7 | def __init__(self, in_dims, num_classes, gamma_coeff=5): 8 | super(GammaRegressor, self).__init__() 9 | self.fc = nn.Linear(in_dims, num_classes) 10 | self.gamma_coeff = gamma_coeff 11 | 12 | def forward(self, x): 13 | if len(x.shape) == 1: 14 | x = x.unsqueeze(1) 15 | # fc = self.gamma_coeff * torch.sigmoid(self.fc(x)) 16 | # fc = torch.sigmoid(self.fc(x)) 17 | fc = F.relu(self.fc(x)) 18 | return { 19 | 'out': fc 20 | } 21 | 22 | 23 | class MLP1(nn.Module): 24 | def __init__(self, in_dims, num_classes, hid_dims=None): 25 | super(MLP1, self).__init__() 26 | self.fc = nn.Linear(in_dims, num_classes) 27 | 28 | def forward(self, x): 29 | if len(x.shape) == 1: 30 | x = x.unsqueeze(1) 31 | fc = self.fc(x) 32 | return { 33 | 'before_logits': x, 34 | 'logits': fc 35 | } 36 | 37 | 38 | class MLP2(nn.Module): 39 | def __init__(self, in_dims, hid_dims, num_classes): 40 | super(MLP2, self).__init__() 41 | self.fc1 = nn.Linear(in_dims, hid_dims) 42 | self.fc2 = nn.Linear(hid_dims, num_classes) 43 | 44 | def forward(self, x): 45 | if len(x.shape) == 1: 46 | x = x.unsqueeze(1) 47 | x = self.fc1(x) 48 | x = F.relu(x) 49 | logits = self.fc2(x) 50 | return { 51 | 'before_logits': x, 52 | 'logits': logits 53 | } 54 | 55 | def freeze_layers(self, freeze_layers): 56 | pass 57 | 58 | 59 | class MLP3(nn.Module): 60 | def __init__(self, in_dims, hid_dims, num_classes): 61 | super(MLP3, self).__init__() 62 | self.fc1 = nn.Linear(in_dims, hid_dims) 63 | self.fc2 = nn.Linear(hid_dims, hid_dims) 64 | self.fc3 = nn.Linear(hid_dims, num_classes) 65 | 66 | def forward(self, x): 67 | if len(x.shape) == 1: 68 | x = x.unsqueeze(1) 69 | x = self.fc1(x) 70 | x = F.relu(x) 71 | x = self.fc2(x) 72 | x = F.relu(x) 73 | logits = self.fc3(x) 74 | return { 75 | 'before_logits': x, 76 | 'logits': logits 77 | } 78 | 79 | def freeze_layers(self, freeze_layers): 80 | pass 81 | 82 | 83 | class MLP2_200(MLP2): 84 | def __init__(self, in_dims=300, hid_dims=600, num_classes=None): 85 | super().__init__(200, 400, num_classes) 86 | 87 | 88 | class MLP2_300(MLP2): 89 | def __init__(self, in_dims=300, hid_dims=600, num_classes=None): 90 | super().__init__(300, 600, num_classes) 91 | 92 | 93 | class MLP2_2(MLP2): 94 | def 
__init__(self, in_dims=300, hid_dims=600, num_classes=None): 95 | super().__init__(2, 4, num_classes) 96 | 97 | 98 | class MLP4(nn.Module): 99 | def __init__(self, in_dims, hid_dims, num_classes): 100 | super(MLP4, self).__init__() 101 | self.fc1 = nn.Linear(in_dims, hid_dims) 102 | self.fc2 = nn.Linear(hid_dims, hid_dims) 103 | self.fc3 = nn.Linear(hid_dims, hid_dims) 104 | self.fc4 = nn.Linear(hid_dims, num_classes) 105 | 106 | def forward(self, x): 107 | if len(x.shape) == 1: 108 | x = x.unsqueeze(1) 109 | x = self.fc1(x) 110 | x = F.relu(x) 111 | x = self.fc2(x) 112 | x = F.relu(x) 113 | x = self.fc3(x) 114 | x = F.relu(x) 115 | logits = self.fc4(x) 116 | return { 117 | 'before_logits': x, 118 | 'logits': logits 119 | } 120 | 121 | 122 | class ModelWrapper(nn.Module): 123 | def __init__(self, core_model, classifier, classifier_in_layer): 124 | super().__init__() 125 | self.core_model = core_model 126 | self.classifier = classifier 127 | self.classifier_in_layer = classifier_in_layer 128 | 129 | def forward(self, x): 130 | out = self.core_model(x) 131 | feat_repr = out[self.classifier_in_layer] 132 | return self.classifier(feat_repr) 133 | 134 | 135 | class LFFMnistClassifier(nn.Module): 136 | def __init__(self, num_classes=10, in_dims=None, hid_dims=None): 137 | super(LFFMnistClassifier, self).__init__() 138 | self.fc1 = nn.Linear(3 * 28 * 28, 100) 139 | self.fc2 = nn.Linear(100, 100) 140 | self.fc3 = nn.Linear(100, 100) 141 | 142 | self.feature = nn.Sequential( 143 | nn.Linear(3 * 28 * 28, 100), 144 | nn.ReLU(), 145 | nn.Linear(100, 100), 146 | nn.ReLU(), 147 | nn.Linear(100, 100), 148 | nn.ReLU() 149 | ) 150 | self.classifier = nn.Linear(100, num_classes) 151 | 152 | def forward(self, x, return_feat=False): 153 | x = x.view(x.size(0), -1) 154 | x = self.fc1(x) 155 | x = F.relu(x) 156 | fc1 = x 157 | x = self.fc2(x) 158 | x = F.relu(x) 159 | fc2 = x 160 | x = self.fc3(x) 161 | x = F.relu(x) 162 | fc3 = x 163 | logits = self.classifier(x) 164 | return { 165 | 'fc1': fc1, 166 | 'fc2': fc2, 167 | 'fc3': fc3, 168 | 'before_logits': fc3, 169 | 'logits': logits 170 | } 171 | 172 | 173 | class MoonNet(nn.Module): 174 | def __init__(self, 175 | hidden_dim=500): 176 | super(MoonNet, self).__init__() 177 | 178 | self.fc1 = nn.Linear(2, hidden_dim) 179 | self.fc2 = nn.Linear(hidden_dim, 1) 180 | 181 | def forward(self, x): 182 | return self.fc2(F.relu(self.fc1(x))) 183 | 184 | 185 | class SlabNet(nn.Module): 186 | def __init__(self, num_classes): 187 | super().__init__() 188 | self.fc1 = nn.Linear(50, 75, bias=True) 189 | self.fc2 = nn.Linear(75, 100, bias=True) 190 | self.classifier = nn.Linear(100, num_classes, bias=True) 191 | 192 | def forward(self, x): 193 | fc1 = self.fc1(x) 194 | fc1 = F.relu(fc1) 195 | before_logits = self.fc2(fc1) 196 | x = F.relu(before_logits) 197 | logits = self.classifier(x) 198 | return { 199 | 'fc1': fc1, 200 | 'fc2': before_logits, 201 | 'before_logits': before_logits, 202 | 'logits': logits 203 | } 204 | 205 | def forward_representation_encoder(self, x): 206 | fc1 = self.fc1(x) 207 | fc1 = F.relu(fc1) 208 | fc2 = self.fc2(fc1) 209 | fc2 = F.relu(fc2) 210 | return { 211 | 'fc1': fc1, 212 | 'fc2': fc2 213 | } 214 | 215 | def forward_classifier(self, x): 216 | return self.classifier(x) 217 | 218 | def reset_classifier(self): 219 | self.classifier.reset_parameters() 220 | 221 | def set_representation_encoder_train(self, is_train): 222 | self.fc1.train(is_train) 223 | self.fc2.train(is_train) 224 | 225 | def set_classifier_train(self, is_train): 226 | 
self.classifier.train(is_train) 227 | 228 | def get_classifier_named_params(self): 229 | return (('classifier.weight', self.classifier.weight), 230 | ('classifier.bias', self.classifier.bias)) 231 | -------------------------------------------------------------------------------- /trainers/learning_from_failure_trainer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | from torch import nn 4 | from torch import optim 5 | 6 | import models 7 | import models.cnn_models 8 | from trainers.base_trainer import BaseTrainer 9 | from utils.trainer_utils import grad_reverse 10 | from models import model_factory 11 | from utils.bias_retrievers import build_bias_retriever 12 | from utils.trainer_utils import grad_mult, create_optimizer 13 | from torch.optim import * 14 | import logging 15 | from models.model_factory import build_model 16 | from utils.losses import GCELoss 17 | import os 18 | from utils.metrics import Accuracy, GroupWiseAccuracy 19 | from utils.ema import ClasswiseEMA 20 | 21 | 22 | class LffTrainer(BaseTrainer): 23 | """ 24 | Implementation for: 25 | Nam, Junhyun, et al. "Learning from failure: Training debiased classifier from biased classifier." (NeurIPS 2020) 26 | 27 | The method trains a bias-only model using generalized cross entropy loss which helps models focus on easier samples, thereby amplifying the bias 28 | and then uses that to re-weight the samples for the main model, so that easy/biased samples receive lower weights. 29 | """ 30 | 31 | def __init__(self, option): 32 | super(LffTrainer, self).__init__(option) 33 | 34 | def _build_model(self): 35 | super()._build_model() 36 | self.bias_model = build_model(self.option, self.option.model_name, 37 | out_dims=self.option.num_classes, 38 | in_dims=self.option.in_dims, 39 | hid_dims=self.option.hid_dims, 40 | freeze_layers=self.option.freeze_layers) 41 | logging.getLogger().info("Bias model") 42 | logging.getLogger().info(self.bias_model) 43 | self.bias_amplification_loss = GCELoss(q=self.option.bias_loss_gamma) 44 | 45 | if self.option.cuda: 46 | self.model.cuda() 47 | self.bias_model.cuda() 48 | self.loss.cuda() 49 | 50 | def _initialization(self): 51 | super()._initialization() 52 | self.bias_loss_ema_computer = ClasswiseEMA(self.max_dataset_ixs['Train'] + 1, alpha=self.option.bias_ema_gamma) 53 | self.main_loss_ema_computer = ClasswiseEMA(self.max_dataset_ixs['Train'] + 1, alpha=self.option.bias_ema_gamma) 54 | 55 | def _build_optimizer(self): 56 | super()._build_optimizer() 57 | self.bias_optim = create_optimizer(self.option.optimizer_name, 58 | named_params=self.bias_model.named_parameters(), 59 | lr=self.option.lr, 60 | weight_decay=self.option.weight_decay, 61 | momentum=self.option.momentum, 62 | freeze_layers=self.option.freeze_layers) 63 | 64 | def _mode_setting(self, is_train=True): 65 | self.model.train(is_train) 66 | self.bias_model.train(is_train) 67 | 68 | def _train_epoch(self, epoch, data_loader): 69 | self._mode_setting(is_train=True) 70 | for i, batch in enumerate(data_loader): 71 | batch = self.prepare_batch(batch) 72 | 73 | # Pass through the main model 74 | out = self.forward_model(self.model, batch) 75 | logits = out['logits'] 76 | main_loss = self.loss(logits, batch['y'].squeeze()) 77 | 78 | # Pass through the bias model 79 | bias_out = self.forward_model(self.bias_model, batch) 80 | bias_logits = bias_out['logits'] 81 | bias_loss = self.loss(bias_logits, batch['y'].squeeze()) 82 | 83 | # Update the bias and main loss EMAs 
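`GCELoss` is the generalized cross entropy L_q(p, y) = (1 − p_y^q)/q, which saturates on hard samples so the bias model concentrates on easy, bias-aligned ones. A sketch of the functional form (the repo's `utils.losses.GCELoss` is not shown here, so details may differ):

```python
import torch
import torch.nn.functional as F

def gce_loss(logits, targets, q=0.7):
    """Generalized cross entropy: (1 - p_y^q) / q, per sample."""
    p_y = F.softmax(logits, dim=1).gather(1, targets.unsqueeze(1)).squeeze(1)
    return (1.0 - p_y ** q) / q

logits = torch.randn(4, 10, requires_grad=True)
targets = torch.randint(0, 10, (4,))
print(gce_loss(logits, targets))  # approaches CE as q -> 0 and MAE as q -> 1
```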
(for computing weights) 84 | self.bias_loss_ema_computer.update(bias_loss.squeeze(), 85 | batch['dataset_ix'], 86 | batch['y'].squeeze()) 87 | self.main_loss_ema_computer.update(main_loss.squeeze(), 88 | batch['dataset_ix'], 89 | batch['y'].squeeze()) 90 | 91 | # Perform class-wise normalization on bias loss 92 | bias_loss_ema = self.bias_loss_ema_computer.parameter[batch['dataset_ix']].clone() 93 | main_loss_ema = self.main_loss_ema_computer.parameter[batch['dataset_ix']].clone() 94 | 95 | for c in range(self.option.num_classes): 96 | dataset_ixs_for_c = torch.where(batch['y'] == c)[0] 97 | if len(dataset_ixs_for_c) == 0: 98 | continue 99 | max_bias_loss_ema = self.bias_loss_ema_computer.max_loss(c) 100 | bias_loss_ema[dataset_ixs_for_c] /= max_bias_loss_ema 101 | max_main_loss_ema = self.main_loss_ema_computer.max_loss(c) 102 | main_loss_ema[dataset_ixs_for_c] /= max_main_loss_ema 103 | 104 | # Compute sample wise weights for main model 105 | sample_weights = bias_loss_ema.cuda() / (bias_loss_ema.cuda() + main_loss_ema.cuda() + 1e-8) 106 | 107 | # Compute bias amplification loss to update the parameters of bias model 108 | bias_amplication_loss = self.bias_amplification_loss(bias_logits, batch['y'].squeeze()) 109 | 110 | # Weight the main model's loss using sample weights 111 | main_loss = self.loss(logits, batch['y'].squeeze()) * sample_weights 112 | 113 | loss = main_loss.mean() + bias_amplication_loss.mean() 114 | 115 | self.bias_optim.zero_grad() 116 | self.optim.zero_grad() 117 | loss.backward() 118 | # loss.backward(retain_graph=True) 119 | 120 | # if self.option.grad_clip is not None: 121 | # torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.option.grad_clip) 122 | # torch.nn.utils.clip_grad_norm_(self.bias_model.parameters(), self.option.grad_clip) 123 | self.bias_optim.step() 124 | self.optim.step() 125 | 126 | split = 'Train' 127 | self.update_generalization_metrics(split + " Main", batch, main_loss) 128 | self.update_generalization_metrics(split + " Bias", batch, bias_loss) 129 | if self.option.enable_groupwise_metrics: 130 | self.update_groupwise_values(split, "Sample Weights", sample_weights, batch) 131 | self.update_groupwise_values(split, "Main Loss EMA", main_loss_ema, batch) 132 | self.update_groupwise_values(split, "Bias Amp Loss", bias_amplication_loss, batch) 133 | self.update_groupwise_values(split, "Bias Loss EMA", bias_loss_ema, batch) 134 | self.loss_visualizer.log(epoch, f'{split} Main') 135 | self.loss_visualizer.accumulate_plot_and_reset(epoch) 136 | 137 | def get_keys_to_save(self): 138 | return super().get_keys_to_save() + ['bias_model', 'bias_optim'] 139 | 140 | def get_current_state(self): 141 | save_state = super().get_current_state() 142 | save_state['bias_loss_ema'] = self.bias_loss_ema_computer.parameter 143 | save_state['main_loss_ema'] = self.main_loss_ema_computer.parameter 144 | return save_state 145 | 146 | def test(self, epoch, data_key, data_loader): 147 | for model, model_key in [[self.model, 'Main'], [self.bias_model, 'Bias']]: 148 | super().test(epoch, data_key, data_loader, model=model, model_key=model_key) -------------------------------------------------------------------------------- /option.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import argparse 3 | 4 | parser = argparse.ArgumentParser() 5 | 6 | parser.add_argument('-e', '--expt_name', required=False, help='Experiment will be saved under this name.') 7 | parser.add_argument('--expt_type', type=str, 8 | 
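The `ClasswiseEMA` buffers keep one exponentially averaged loss per training sample, `ema ← α·ema + (1 − α)·loss`; the weighting rule in isolation (a sketch, not the repo's exact `utils.ema` implementation):

```python
import torch

alpha = 0.7
bias_ema = torch.zeros(100)  # one slot per training sample
main_ema = torch.zeros(100)

def update(ema, losses, ixs):
    ema[ixs] = alpha * ema[ixs] + (1 - alpha) * losses.detach()

ixs = torch.arange(8)
update(bias_ema, torch.rand(8), ixs)
update(main_ema, torch.rand(8), ixs)

# Easy (bias-aligned) samples have a low bias loss and thus a low weight
weights = bias_ema[ixs] / (bias_ema[ixs] + main_ema[ixs] + 1e-8)
```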
help='The experimental configuration to use e.g., all CelebA runs can use celebA_experiments value for this argument') 9 | parser.add_argument('--dataset_name', required=False, help='Name of the dataset') 10 | parser.add_argument('--root_dir', default=None) 11 | parser.add_argument('--num_classes', type=int, help='Number of classes') 12 | parser.add_argument('--batch_size', default=None, type=int, help='Mini-batch size') 13 | parser.add_argument('--momentum', default=0.9, type=float, help='SGD momentum') 14 | parser.add_argument('--lr', type=float, help='Learning rate') 15 | parser.add_argument('--weight_decay', default=0, type=float, help='Weight decay') 16 | parser.add_argument('--epochs', type=int, help='Num of epochs') 17 | parser.add_argument('--load_checkpoint', default=None, help='Checkpoint to resume from.') 18 | 19 | parser.add_argument('--test_every', default=5, type=int, help='Interval to perform testing') 20 | parser.add_argument('--test_epochs', nargs='+', type=int, 21 | help='List of epochs to perform testing at. Is compatible with test_every argument -- both will be used') 22 | parser.add_argument('--save_predictions_every', default=25, type=int, help='Interval to save predictions and metrics') 23 | parser.add_argument('--save_model_every', default=25, type=int, help='Interval to save the model checkpoint') 24 | parser.add_argument('--data_dir', required=False, 25 | help='Directory where the dataset is stored. We usually assume this convention: {root_dir}/{dataset_name}') 26 | parser.add_argument('--save_dir', required=False, help='Logs, Checkpoints, predictions and metrics will be saved here.') 27 | 28 | parser.add_argument('--random_seed', type=int, help='Random seed', default=1) 29 | parser.add_argument('--num_workers', default=8, type=int, help='Number of workers in data loader') 30 | 31 | parser.add_argument('--grad_clip', type=float, default=None, 32 | help="Grad clip. This wasn't used for any of the methods for the comparison experiments.") 33 | parser.add_argument('--trainer_name', type=str, help="Name of the method e.g., BaseTrainer or GroupDROTrainer") 34 | parser.add_argument('--model_name', type=str, 35 | help="Name of the main model. For two branch models e.g., RUBi, this refers to the name for the main branch.") 36 | parser.add_argument('--bias_model_name', type=str, 37 | help="For two/multi branch setups, this either predicts the bias variables or uses them as input.") 38 | parser.add_argument('--optimizer_name', type=str, default=None, help="e.g., SGD, Adam") 39 | parser.add_argument('--bias_proba', type=float, default=1.1, help='p_bias for BiasedMNIST') 40 | parser.add_argument('--bias_var', type=float, default=0.02) 41 | parser.add_argument('--dummy', action='store_true', 42 | help="A flag used for debugging runs e.g., setting num_workers=0 to make debugging possible and using a smaller dataset size.") 43 | parser.add_argument('--balanced_sampling_attributes', type=str, nargs='+', default=None, 44 | help="List of attributes (as returned in a mini-batch) which should be used for balancing i.e., every unique combination of these attributes will have equal probability of being sampled." 45 | "Useful for GroupDRO") 46 | parser.add_argument('--balanced_sampling_gamma', type=float, default=1.0, 47 | help="Exponentiation for inverse group probability. Higher values would oversample minority patterns a lot.") 48 | parser.add_argument('--freeze_layers', default=None, nargs='+', 49 | help="Can be used to freeze layers i.e., not used for optimization." 
50 | "When freezing, you need to disable batch norm and other model-specific settings yourself.") 51 | parser.add_argument('--custom_lr_config', default=None, type=str, help="Unused (deprecated) argument.") 52 | 53 | parser.add_argument('--grad_reverse_factor', type=float, default=-0.1, 54 | help="Reversal parameter for adversarial debiasing e.g., learning not to learn (LNL). Use a negative value.") 55 | parser.add_argument('--loss_type', type=str, default='CrossEntropyLoss') 56 | 57 | # Arguments specific to GroupDROTrainer 58 | parser.add_argument('--num_groups', type=int, help="Number of groups for grouping methods e.g., GroupDRO.") 59 | parser.add_argument('--group_weight_step_size', type=float, default=0.01, 60 | help="Learning rate to update group weights in GroupDRO.") 61 | parser.add_argument('--group_mode', type=str, 62 | help='Grouping mode e.g., unique_bias_value or majority_minority for BiasedMNIST. TODO: remove this.') 63 | parser.add_argument('--bias_predictor_in_layer', type=str, default=None, 64 | help="LNL predicts bias variables from this layer.") 65 | parser.add_argument('--bias_predictor_name', type=str, default=None, help="Bias model name for LNL.") 66 | 67 | parser.add_argument('--bias_variable_name', type=str, default=None, 68 | help="Name of the bias variable used by explicit methods and also used to compute metrics.") 69 | parser.add_argument('--target_name', type=str, default=None, help="Variable name to predict i.e., class variable.") 70 | parser.add_argument('--group_by', type=str, default=None, 71 | help="Dataset is grouped by this variable, usually set to group_ix.") 72 | parser.add_argument('--key_to_group_by', type=str, default=None, help="This provides names for the groups.") 73 | 74 | # Arguments specific to LffTrainer 75 | parser.add_argument('--bias_loss_gamma', type=float, default=0.7, help="Loss gamma for LFF") 76 | parser.add_argument('--bias_ema_gamma', type=float, default=0.7, help="EMA gamma for LFF") 77 | parser.add_argument('--bias_model_hid_dims', type=int, help='Hidden dimensions for the bias model') 78 | 79 | parser.add_argument('--entropy_loss_weight', type=float, default=0, help="Weight for entropy loss weight in LNL.") 80 | 81 | parser.add_argument('--dataset_info', help="Used internally to set dataset specific attributes.") 82 | parser.add_argument('--enable_groupwise_metrics', action='store_true') 83 | parser.add_argument('--project_name', type=str, default='Bias-Mitigators', help="Results will be saved here.") 84 | 85 | # Arguments specific to RunningFocalLossTrainer 86 | parser.add_argument('--in_dims', type=int, default=None) 87 | parser.add_argument('--hid_dims', type=int, default=None) 88 | parser.add_argument('--grad_penalty_weight', type=float, default=1.0) 89 | parser.add_argument('--expt_dir', type=str) 90 | parser.add_argument('--bias_variable_type', type=str) 91 | 92 | parser.add_argument('--spectral_decoupling_lambda', type=float) 93 | parser.add_argument('--spectral_decoupling_lambdas', type=float, nargs='+') 94 | parser.add_argument('--spectral_decoupling_gamma', type=float) 95 | parser.add_argument('--spectral_decoupling_gammas', type=float, nargs='+') 96 | 97 | parser.add_argument('--num_envs_per_batch', type=int, 98 | help="Used by IRMv1. 
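To see how these flags compose, the parser can be driven directly; the argument values below are illustrative only (the real, tuned settings live under scripts/):

```python
from option import parser

# Hypothetical RUBi run on BiasedMNIST; real settings live under scripts/
option = parser.parse_args([
    '--expt_name', 'rubi_demo',
    '--trainer_name', 'RUBiTrainer',
    '--dataset_name', 'biased_mnist',
    '--num_classes', '10',
    '--lr', '1e-3',
    '--epochs', '50',
    '--bias_variable_name', 'group_ix',
    '--bias_variable_type', 'categorical',
])
print(option.trainer_name, option.lr)
```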
Each mini-batch will contain the specified number of environments.") 99 | 100 | 101 | def get_option(): 102 | option = parser.parse_args() 103 | option.cuda = True 104 | if option.dummy: 105 | option.num_workers = 0 106 | return option 107 | 108 | 109 | # Used when bash files are not used 110 | ROOT = '/hdd/user' 111 | EXPT_ROOT = '/hdd/user/bias_mitigators' 112 | -------------------------------------------------------------------------------- /models/coordconv.py: -------------------------------------------------------------------------------- 1 | # https://raw.githubusercontent.com/walsvid/CoordConv/052d45354bae46f5fdb0f906fac31f3e1f9debe2/coordconv.py 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.modules.conv as conv 5 | 6 | 7 | class AddCoords(nn.Module): 8 | def __init__(self, rank, with_r=False, use_cuda=True): 9 | super(AddCoords, self).__init__() 10 | self.rank = rank 11 | self.with_r = with_r 12 | self.use_cuda = use_cuda 13 | 14 | def forward(self, input_tensor): 15 | """ 16 | :param input_tensor: shape (N, C_in, H, W) 17 | :return: 18 | """ 19 | if self.rank == 1: 20 | batch_size_shape, channel_in_shape, dim_x = input_tensor.shape 21 | xx_range = torch.arange(dim_x, dtype=torch.int32) 22 | xx_channel = xx_range[None, None, :] 23 | 24 | xx_channel = xx_channel.float() / (dim_x - 1) 25 | xx_channel = xx_channel * 2 - 1 26 | xx_channel = xx_channel.repeat(batch_size_shape, 1, 1) 27 | 28 | if torch.cuda.is_available and self.use_cuda: 29 | input_tensor = input_tensor.cuda() 30 | xx_channel = xx_channel.cuda() 31 | out = torch.cat([input_tensor, xx_channel], dim=1) 32 | 33 | if self.with_r: 34 | rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2)) 35 | out = torch.cat([out, rr], dim=1) 36 | 37 | elif self.rank == 2: 38 | batch_size_shape, channel_in_shape, dim_y, dim_x = input_tensor.shape 39 | xx_ones = torch.ones([1, 1, 1, dim_x], dtype=torch.int32) 40 | yy_ones = torch.ones([1, 1, 1, dim_y], dtype=torch.int32) 41 | 42 | xx_range = torch.arange(dim_y, dtype=torch.int32) 43 | yy_range = torch.arange(dim_x, dtype=torch.int32) 44 | xx_range = xx_range[None, None, :, None] 45 | yy_range = yy_range[None, None, :, None] 46 | 47 | xx_channel = torch.matmul(xx_range, xx_ones) 48 | yy_channel = torch.matmul(yy_range, yy_ones) 49 | 50 | # transpose y 51 | yy_channel = yy_channel.permute(0, 1, 3, 2) 52 | 53 | xx_channel = xx_channel.float() / (dim_y - 1) 54 | yy_channel = yy_channel.float() / (dim_x - 1) 55 | 56 | xx_channel = xx_channel * 2 - 1 57 | yy_channel = yy_channel * 2 - 1 58 | 59 | xx_channel = xx_channel.repeat(batch_size_shape, 1, 1, 1) 60 | yy_channel = yy_channel.repeat(batch_size_shape, 1, 1, 1) 61 | 62 | if torch.cuda.is_available and self.use_cuda: 63 | input_tensor = input_tensor.cuda() 64 | xx_channel = xx_channel.cuda() 65 | yy_channel = yy_channel.cuda() 66 | out = torch.cat([input_tensor, xx_channel, yy_channel], dim=1) 67 | 68 | if self.with_r: 69 | rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2)) 70 | out = torch.cat([out, rr], dim=1) 71 | 72 | elif self.rank == 3: 73 | batch_size_shape, channel_in_shape, dim_z, dim_y, dim_x = input_tensor.shape 74 | xx_ones = torch.ones([1, 1, 1, 1, dim_x], dtype=torch.int32) 75 | yy_ones = torch.ones([1, 1, 1, 1, dim_y], dtype=torch.int32) 76 | zz_ones = torch.ones([1, 1, 1, 1, dim_z], dtype=torch.int32) 77 | 78 | xy_range = torch.arange(dim_y, dtype=torch.int32) 79 | xy_range = xy_range[None, None, None, :, None] 80 | 81 | yz_range = torch.arange(dim_z, dtype=torch.int32) 82 | yz_range = 
yz_range[None, None, None, :, None] 83 | 84 | zx_range = torch.arange(dim_x, dtype=torch.int32) 85 | zx_range = zx_range[None, None, None, :, None] 86 | 87 | xy_channel = torch.matmul(xy_range, xx_ones) 88 | xx_channel = torch.cat([xy_channel + i for i in range(dim_z)], dim=2) 89 | 90 | yz_channel = torch.matmul(yz_range, yy_ones) 91 | yz_channel = yz_channel.permute(0, 1, 3, 4, 2) 92 | yy_channel = torch.cat([yz_channel + i for i in range(dim_x)], dim=4) 93 | 94 | zx_channel = torch.matmul(zx_range, zz_ones) 95 | zx_channel = zx_channel.permute(0, 1, 4, 2, 3) 96 | zz_channel = torch.cat([zx_channel + i for i in range(dim_y)], dim=3) 97 | 98 | if torch.cuda.is_available and self.use_cuda: 99 | input_tensor = input_tensor.cuda() 100 | xx_channel = xx_channel.cuda() 101 | yy_channel = yy_channel.cuda() 102 | zz_channel = zz_channel.cuda() 103 | out = torch.cat([input_tensor, xx_channel, yy_channel, zz_channel], dim=1) 104 | 105 | if self.with_r: 106 | rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + 107 | torch.pow(yy_channel - 0.5, 2) + 108 | torch.pow(zz_channel - 0.5, 2)) 109 | out = torch.cat([out, rr], dim=1) 110 | else: 111 | raise NotImplementedError 112 | 113 | return out 114 | 115 | 116 | class CoordConv1d(conv.Conv1d): 117 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 118 | padding=0, dilation=1, groups=1, bias=True, with_r=False, use_cuda=True): 119 | super(CoordConv1d, self).__init__(in_channels, out_channels, kernel_size, 120 | stride, padding, dilation, groups, bias) 121 | self.rank = 1 122 | self.addcoords = AddCoords(self.rank, with_r, use_cuda=use_cuda) 123 | self.conv = nn.Conv1d(in_channels + self.rank + int(with_r), out_channels, 124 | kernel_size, stride, padding, dilation, groups, bias) 125 | 126 | def forward(self, input_tensor): 127 | """ 128 | input_tensor_shape: (N, C_in,H,W) 129 | output_tensor_shape: N,C_out,H_out,W_out) 130 | :return: CoordConv2d Result 131 | """ 132 | out = self.addcoords(input_tensor) 133 | out = self.conv(out) 134 | 135 | return out 136 | 137 | 138 | class CoordConv2d(conv.Conv2d): 139 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 140 | padding=0, dilation=1, groups=1, bias=True, with_r=False, use_cuda=True): 141 | super(CoordConv2d, self).__init__(in_channels, out_channels, kernel_size, 142 | stride, padding, dilation, groups, bias) 143 | self.rank = 2 144 | self.addcoords = AddCoords(self.rank, with_r, use_cuda=use_cuda) 145 | self.conv = nn.Conv2d(in_channels + self.rank + int(with_r), out_channels, 146 | kernel_size, stride, padding, dilation, groups, bias) 147 | 148 | def forward(self, input_tensor): 149 | """ 150 | input_tensor_shape: (N, C_in,H,W) 151 | output_tensor_shape: N,C_out,H_out,W_out) 152 | :return: CoordConv2d Result 153 | """ 154 | out = self.addcoords(input_tensor) 155 | out = self.conv(out) 156 | 157 | return out 158 | 159 | 160 | class CoordConv3d(conv.Conv3d): 161 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 162 | padding=0, dilation=1, groups=1, bias=True, with_r=False, use_cuda=True): 163 | super(CoordConv3d, self).__init__(in_channels, out_channels, kernel_size, 164 | stride, padding, dilation, groups, bias) 165 | self.rank = 3 166 | self.addcoords = AddCoords(self.rank, with_r, use_cuda=use_cuda) 167 | self.conv = nn.Conv3d(in_channels + self.rank + int(with_r), out_channels, 168 | kernel_size, stride, padding, dilation, groups, bias) 169 | 170 | def forward(self, input_tensor): 171 | """ 172 | input_tensor_shape: (N, C_in,H,W) 173 | 
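The added channels explain the channel arithmetic in these wrappers: internally a `CoordConv2d` convolves over `in_channels + 2` channels (`+ 3` with `with_r=True`). A quick check (`use_cuda=False` keeps everything on the CPU):

```python
import torch
from models.coordconv import CoordConv2d

conv = CoordConv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1, use_cuda=False)
x = torch.randn(2, 3, 28, 28)
print(conv(x).shape)  # torch.Size([2, 8, 28, 28]); 2 coordinate channels added internally
```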
116 | class CoordConv1d(conv.Conv1d):
117 |     def __init__(self, in_channels, out_channels, kernel_size, stride=1,
118 |                  padding=0, dilation=1, groups=1, bias=True, with_r=False, use_cuda=True):
119 |         super(CoordConv1d, self).__init__(in_channels, out_channels, kernel_size,
120 |                                           stride, padding, dilation, groups, bias)
121 |         self.rank = 1
122 |         self.addcoords = AddCoords(self.rank, with_r, use_cuda=use_cuda)
123 |         self.conv = nn.Conv1d(in_channels + self.rank + int(with_r), out_channels,
124 |                               kernel_size, stride, padding, dilation, groups, bias)
125 | 
126 |     def forward(self, input_tensor):
127 |         """
128 |         input_tensor shape: (N, C_in, L)
129 |         output_tensor shape: (N, C_out, L_out)
130 |         :return: CoordConv1d result
131 |         """
132 |         out = self.addcoords(input_tensor)
133 |         out = self.conv(out)
134 | 
135 |         return out
136 | 
137 | 
138 | class CoordConv2d(conv.Conv2d):
139 |     def __init__(self, in_channels, out_channels, kernel_size, stride=1,
140 |                  padding=0, dilation=1, groups=1, bias=True, with_r=False, use_cuda=True):
141 |         super(CoordConv2d, self).__init__(in_channels, out_channels, kernel_size,
142 |                                           stride, padding, dilation, groups, bias)
143 |         self.rank = 2
144 |         self.addcoords = AddCoords(self.rank, with_r, use_cuda=use_cuda)
145 |         self.conv = nn.Conv2d(in_channels + self.rank + int(with_r), out_channels,
146 |                               kernel_size, stride, padding, dilation, groups, bias)
147 | 
148 |     def forward(self, input_tensor):
149 |         """
150 |         input_tensor shape: (N, C_in, H, W)
151 |         output_tensor shape: (N, C_out, H_out, W_out)
152 |         :return: CoordConv2d result
153 |         """
154 |         out = self.addcoords(input_tensor)
155 |         out = self.conv(out)
156 | 
157 |         return out
158 | 
159 | 
160 | class CoordConv3d(conv.Conv3d):
161 |     def __init__(self, in_channels, out_channels, kernel_size, stride=1,
162 |                  padding=0, dilation=1, groups=1, bias=True, with_r=False, use_cuda=True):
163 |         super(CoordConv3d, self).__init__(in_channels, out_channels, kernel_size,
164 |                                           stride, padding, dilation, groups, bias)
165 |         self.rank = 3
166 |         self.addcoords = AddCoords(self.rank, with_r, use_cuda=use_cuda)
167 |         self.conv = nn.Conv3d(in_channels + self.rank + int(with_r), out_channels,
168 |                               kernel_size, stride, padding, dilation, groups, bias)
169 | 
170 |     def forward(self, input_tensor):
171 |         """
172 |         input_tensor shape: (N, C_in, D, H, W)
173 |         output_tensor shape: (N, C_out, D_out, H_out, W_out)
174 |         :return: CoordConv3d result
175 |         """
176 |         out = self.addcoords(input_tensor)
177 |         out = self.conv(out)
178 | 
179 |         return out
180 | 
--------------------------------------------------------------------------------
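Before moving on, a minimal smoke test for the layers above — a sketch assuming CPU execution (use_cuda=False); AddCoords appends rank-many coordinate channels plus an optional radius channel before the wrapped convolution:

import torch
from models.coordconv import CoordConv2d

conv = CoordConv2d(in_channels=3, out_channels=8, kernel_size=3,
                   padding=1, with_r=True, use_cuda=False)
x = torch.randn(2, 3, 16, 16)  # (N, C_in, H, W)
y = conv(x)                    # internally: cat -> (2, 3 + 2 + 1, 16, 16) -> conv
print(y.shape)                 # torch.Size([2, 8, 16, 16])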
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import logging
3 | import torch
4 | 
5 | 
6 | class Accuracy():
7 |     def __init__(self, num_classes):
8 |         self.num_classes = num_classes
9 |         self.reset()
10 | 
11 |     def update(self, pred_ys, gt_ys):
12 |         for pred_y, gt_y in zip(pred_ys, gt_ys):
13 |             if pred_y == gt_y:
14 |                 self.correct[pred_y] += 1
15 |             self.total[gt_y] += 1
16 | 
17 |     def get_accuracy(self):
18 |         return self.correct.sum() / self.total.sum()
19 | 
20 |     def get_per_class_accuracy(self):
21 |         correct = np.asarray(self.correct)
22 |         total = np.asarray(self.total)
23 |         for tix, (c, t) in enumerate(zip(correct, total)):
24 |             if t == 0:  # classes with no samples count as fully correct (avoids division by zero)
25 |                 correct[tix] = 1
26 |                 total[tix] = 1
27 |         return correct / total
28 | 
29 |     def get_mean_per_class_accuracy(self):
30 |         return self.get_per_class_accuracy().mean()
31 | 
32 |     def reset(self):
33 |         self.correct = np.zeros((self.num_classes))
34 |         self.total = np.zeros((self.num_classes))
35 | 
36 | 
37 | class GroupWiseAccuracy():
38 |     def __init__(self):
39 |         self.reset()
40 | 
41 |     def update(self, pred_ys, gt_ys, group_names):
42 |         for pred_y, gt_y, group_name in zip(pred_ys, gt_ys, group_names):
43 |             group_name = str(group_name)
44 |             if group_name not in self.group_wise_total:
45 |                 self.group_wise_total[group_name] = 0
46 |                 self.group_wise_correct[group_name] = 0
47 |             if pred_y == gt_y:
48 |                 self.group_wise_correct[group_name] += 1
49 |             self.group_wise_total[group_name] += 1
50 | 
51 |     def get_per_group_accuracy(self):
52 |         per_group_accuracy = {}
53 |         for group_name in self.group_wise_correct:
54 |             per_group_accuracy[group_name] = self.group_wise_correct[group_name] / self.group_wise_total[group_name]
55 |         return per_group_accuracy
56 | 
57 |     def get_mean_per_group_accuracy(self):
58 |         total_acc, total_num = 0, 0
59 |         per_group_accuracy = self.get_per_group_accuracy()
60 |         for group_name in per_group_accuracy:
61 |             total_acc += per_group_accuracy[group_name]
62 |             total_num += 1
63 |         return total_acc / total_num
64 | 
65 |     def reset(self):
66 |         self.group_wise_correct = {}
67 |         self.group_wise_total = {}
68 | 
69 |     def log(self, prefix=''):
70 |         log_str = prefix
71 |         per_group_accuracy = self.get_per_group_accuracy()
72 |         group_names = sorted([k for k in per_group_accuracy.keys()])
73 |         accuracies = ""
74 |         for group_name in group_names:
75 |             log_str += '%s: %.2f%% ' % (group_name, per_group_accuracy[group_name] * 100)
76 |             accuracies += '%.2f%%, ' % (per_group_accuracy[group_name] * 100)
77 |         log_str += ' MPG: %.2f%%' % (self.get_mean_per_group_accuracy() * 100)
78 |         logging.getLogger().info(log_str)
79 |         logging.getLogger().info(f"Group names {group_names}")
80 |         logging.getLogger().info(f"Accuracies {accuracies}")
81 | 
82 | 
83 | class ModuleStatsComputer:
84 |     """Stores sensitivities, GT classes and predicted classes."""
85 | 
86 |     def __init__(self, num_modules, num_classes):
87 |         self.reset()
88 |         self.num_modules = num_modules
89 |         self.num_classes = num_classes
90 | 
91 |     def reset(self):
92 |         self.sensitivities = []
93 |         self.gt_class_ixs, self.pred_class_ixs = [], []
94 |         self.group_names = []
95 | 
96 |     def update(self, sensitivities, gt_class_ixs, pred_class_ixs, group_names):
97 |         self.sensitivities += sensitivities
98 |         self.gt_class_ixs += gt_class_ixs
99 |         self.pred_class_ixs += pred_class_ixs
100 |         self.group_names += [str(gn) for gn in group_names]
101 | 
102 |     def log(self):
103 |         sensitivities = np.asarray(self.sensitivities)
104 |         gt_class_ixs = np.asarray(self.gt_class_ixs)
105 |         pred_class_ixs = np.asarray(self.pred_class_ixs)
106 |         most_sensitive_ixs = np.argmax(sensitivities, axis=1)
107 |         group_names = np.asarray(self.group_names)
108 | 
109 |         most_sensitive_counts = {ix: len(np.nonzero(most_sensitive_ixs == ix)[0])
110 |                                  for ix in range(0, self.num_modules)}
111 |         most_sens_n_correct = {}
112 |         accuracy_when_most_sensitive = {}
113 |         group_distribution = {}
114 |         overall_metrics = {}
115 |         for module_ix in range(self.num_modules):
116 |             most_sens_n_correct[module_ix] = len(np.intersect1d(np.nonzero(most_sensitive_ixs == module_ix)[0],
117 |                                                                 np.nonzero(gt_class_ixs == pred_class_ixs)[0]))
118 |             accuracy_when_most_sensitive[module_ix] = 100 * most_sens_n_correct[module_ix] / max(
119 |                 most_sensitive_counts[module_ix], 1)
120 |             if module_ix not in group_distribution:
121 |                 group_distribution[module_ix] = {}
122 |             for gn in group_names[np.nonzero(most_sensitive_ixs == module_ix)[0]]:
123 |                 if gn not in group_distribution[module_ix]:
124 |                     group_distribution[module_ix][gn] = 0
125 |                 group_distribution[module_ix][gn] += 1
126 | 
127 |         for module_ix in range(self.num_modules):
128 |             if int(most_sensitive_counts[module_ix]) > 0:
129 |                 overall_metrics[module_ix] = {
130 |                     'most_sensitive_count': int(most_sensitive_counts[module_ix]),
131 |                     'accuracy_when_most_sensitive': '%.2f%%' % (float(accuracy_when_most_sensitive[module_ix])),
132 |                     'group_distribution': group_distribution[module_ix]
133 |                 }
134 |         logging.getLogger().info(f"Module stats: {overall_metrics}")
135 | 
136 | class GradientTracker():
137 |     def __init__(self, num_samples, num_epochs):
138 |         # Tracks per-sample gradient statistics across epochs: L2 norms and
139 |         # summed absolute values, plus each sample's group assignment.
140 | 
141 |         self.l2_norms = torch.zeros((num_samples, num_epochs))
142 |         self.abs_vals = torch.zeros((num_samples, num_epochs))
143 |         self.groups = np.asarray(['NoneNoneNoneNoneNoneNone'] * (num_samples))  # placeholder widens the numpy string dtype so real group names fit
144 |         self.unq_groups = {}
145 | 
146 |     def update(self, epoch, dataset_ixs, gradients, groups):
147 |         self.epoch = epoch
148 |         self.l2_norms[dataset_ixs, epoch - 1] = torch.norm(gradients.detach().flatten(1), p=2, dim=1).cpu()
149 |         self.abs_vals[dataset_ixs, epoch - 1] = torch.sum(torch.abs(gradients.detach().flatten(1)), dim=1).cpu()
150 |         if epoch == 1:
151 |             self.groups[dataset_ixs] = groups
152 |             for g in groups:
153 |                 if g not in self.unq_groups:
154 |                     self.unq_groups[g] = g
155 | 
156 |     def get_groupwise_values(self):
157 |         # Returns per-group mean/variance of the L2 norms, squared L2 norms,
158 |         # and absolute-value sums of the tracked gradients for the current epoch.
159 |         values = {}
160 |         for g in self.unq_groups:
161 |             if g not in values:
162 |                 values[g] = {}
163 |             grp_ixs = np.nonzero(self.groups == g)[0]
164 |             l2_norms = self.l2_norms[grp_ixs, self.epoch - 1]
165 |             values[g]['mean_of_l2_norms'] = torch.mean(l2_norms)
166 |             values[g]['variance_of_l2_norms'] = torch.std(l2_norms) ** 2
167 | 
168 |             squared_l2_norms = l2_norms ** 2
169 |             values[g]['mean_of_squared_l2_norms'] = torch.mean(squared_l2_norms)
170 |             values[g]['variance_of_squared_l2_norms'] = torch.std(squared_l2_norms) ** 2
171 | 
172 |             abs = self.abs_vals[grp_ixs, self.epoch - 1]
173 |             values[g]['mean_of_abs'] = torch.mean(abs)
174 |             values[g]['variance_of_abs'] = torch.std(abs) ** 2
175 | 
176 |         return values
177 | 
178 | 
179 | class PredictionChangeTracker():
180 |     def __init__(self, num_samples, num_epochs):
181 |         self.preds = torch.zeros((num_samples, num_epochs))
182 |         self.num_pred_changes = torch.zeros((num_samples))
183 |         self.groups = np.asarray(['NoneNoneNoneNoneNoneNone'] * (num_samples))  # placeholder widens the numpy string dtype so real group names fit
184 |         self.unq_groups = {}
185 | 
186 |     def update(self, dataset_ixs, epoch, logits, labels, groups):
187 |         self.epoch = epoch
188 |         self.preds[dataset_ixs, epoch - 1] = torch.argmax(logits.cpu(), dim=1).float()
189 |         if epoch > 1:
190 |             pred_change_mask = torch.where(self.preds[dataset_ixs, epoch - 2] != self.preds[dataset_ixs, epoch - 1],
191 |                                            torch.ones_like(dataset_ixs),
192 |                                            torch.zeros_like(dataset_ixs))
193 |             self.groups[dataset_ixs] = groups
194 |             for g in groups:
195 |                 if g not in self.unq_groups:
196 |                     self.unq_groups[g] = g
197 |             self.num_pred_changes[dataset_ixs] += pred_change_mask
198 | 
199 |     def get_values(self):
200 |         group_change_pct = {}
201 |         max_num_changes = {}
202 |         mean_num_changes = {}
203 |         for g in self.unq_groups:
204 |             grp_ixs = np.nonzero(self.groups == g)[0]
205 |             num_changes = self.num_pred_changes[grp_ixs].sum()
206 |             grp_len = len(grp_ixs)
207 |             group_change_pct[g] = (num_changes / (grp_len * (self.epoch - 1))) * 100
208 |             max_num_changes[g] = self.num_pred_changes[grp_ixs].max().item()
209 |             mean_num_changes[g] = (num_changes / max(grp_len, 1)).item()
210 |         return {
211 |             'group_change_percent': group_change_pct,
212 |             'max_num_changes': max_num_changes,
213 |             'mean_num_changes': mean_num_changes}
214 | 
--------------------------------------------------------------------------------
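A minimal sketch (toy numbers, not from the repo) of how the two accuracy trackers above are driven during evaluation:

from utils.metrics import Accuracy, GroupWiseAccuracy

acc = Accuracy(num_classes=3)
acc.update(pred_ys=[0, 1, 2, 2], gt_ys=[0, 1, 1, 2])
print(acc.get_accuracy())                 # 0.75 (3 of 4 correct)
print(acc.get_mean_per_class_accuracy())  # (1.0 + 0.5 + 1.0) / 3

gacc = GroupWiseAccuracy()
gacc.update(pred_ys=[0, 1], gt_ys=[0, 0], group_names=['blond', 'not_blond'])
print(gacc.get_per_group_accuracy())      # {'blond': 1.0, 'not_blond': 0.0}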
/models/variable_width_resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | try:
5 |     from torch.hub import load_state_dict_from_url
6 | except ImportError:
7 |     from torch.utils.model_zoo import load_url as load_state_dict_from_url
8 | 
9 | 
10 | # __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
11 | #            'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
12 | #            'wide_resnet50_2', 'wide_resnet101_2']
13 | 
14 | 
15 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
16 |     """3x3 convolution with padding"""
17 |     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
18 |                      padding=dilation, groups=groups, bias=False, dilation=dilation)
19 | 
20 | 
21 | def conv1x1(in_planes, out_planes, stride=1):
22 |     """1x1 convolution"""
23 |     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
24 | 
25 | 
26 | class BasicBlock(nn.Module):
27 |     expansion = 1
28 |     __constants__ = ['downsample']
29 | 
30 |     def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
31 |                  base_width=64, dilation=1, norm_layer=None):
32 |         super(BasicBlock, self).__init__()
33 |         if norm_layer is None:
34 |             norm_layer = nn.BatchNorm2d
35 |         if groups != 1 or base_width != 64:
36 |             raise ValueError('BasicBlock only supports groups=1 and base_width=64')
37 |         if dilation > 1:
38 |             raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
39 |         # Both self.conv1 and self.downsample layers downsample the input when stride != 1
40 |         self.conv1 = conv3x3(inplanes, planes, stride)
41 |         self.bn1 = norm_layer(planes)
42 |         self.relu = nn.ReLU(inplace=True)
43 |         self.conv2 = conv3x3(planes, planes)
44 |         self.bn2 = norm_layer(planes)
45 |         self.downsample = downsample
46 |         self.stride = stride
47 | 
48 |     def forward(self, x):
49 |         identity = x
50 | 
51 |         out = self.conv1(x)
52 |         out = self.bn1(out)
53 |         out = self.relu(out)
54 | 
55 |         out = self.conv2(out)
56 |         out = self.bn2(out)
57 | 
58 |         if self.downsample is not None:
59 |             identity = self.downsample(x)
60 | 
61 |         out += identity
62 |         out = self.relu(out)
63 | 
64 |         return out
65 | 
66 | 
67 | class Bottleneck(nn.Module):
68 |     expansion = 4
69 |     __constants__ = ['downsample']
70 | 
71 |     def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
72 |                  base_width=64, dilation=1, norm_layer=None):
73 |         super(Bottleneck, self).__init__()
74 |         if norm_layer is None:
75 |             norm_layer = nn.BatchNorm2d
76 |         width = int(planes * (base_width / 64.)) * groups
77 |         # Both self.conv2 and self.downsample layers downsample the input when stride != 1
78 |         self.conv1 = conv1x1(inplanes, width)
79 |         self.bn1 = norm_layer(width)
80 |         self.conv2 = conv3x3(width, width, stride, groups, dilation)
81 |         self.bn2 = norm_layer(width)
82 |         self.conv3 = conv1x1(width, planes * self.expansion)
83 |         self.bn3 = norm_layer(planes * self.expansion)
84 |         self.relu = nn.ReLU(inplace=True)
85 |         self.downsample = downsample
86 |         self.stride = stride
87 | 
88 |     def forward(self, x):
89 |         identity = x
90 | 
91 |         out = self.conv1(x)
92 |         out = self.bn1(out)
93 |         out = self.relu(out)
94 | 
95 |         out = self.conv2(out)
96 |         out = self.bn2(out)
97 |         out = self.relu(out)
98 | 
99 |         out = self.conv3(out)
100 |         out = self.bn3(out)
101 | 
102 |         if self.downsample is not None:
103 |             identity = self.downsample(x)
104 | 
105 |         out += identity
106 |         out = self.relu(out)
107 | 
108 |         return out
109 | 
110 | 
111 | class VariableWidthResNet(nn.Module):
112 | 
113 |     def __init__(self, block, layers, width, num_classes=1000, zero_init_residual=False,
114 |                  groups=1, width_per_group=64, replace_stride_with_dilation=None,
115 |                  norm_layer=None):
116 |         super(VariableWidthResNet, self).__init__()
117 |         if norm_layer is None:
118 |             norm_layer = nn.BatchNorm2d
119 |         self._norm_layer = norm_layer
120 | 
121 |         self.inplanes = width
122 |         self.dilation = 1
123 |         if replace_stride_with_dilation is None:
124 |             # each element in the tuple indicates if we should replace
125 |             # the 2x2 stride with a dilated convolution instead
126 |             replace_stride_with_dilation = [False, False, False]
127 |         if len(replace_stride_with_dilation) != 3:
128 |             raise ValueError("replace_stride_with_dilation should be None "
129 |                              "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
130 |         self.groups = groups
131 |         self.base_width = width_per_group
132 |         self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
133 |                                bias=False)
134 |         self.bn1 = norm_layer(self.inplanes)
135 |         self.relu = nn.ReLU(inplace=True)
136 |         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
137 |         self.layer1 = self._make_layer(block, width, layers[0])
138 |         self.layer2 = self._make_layer(block, width * 2, layers[1], stride=2,
139 |                                        dilate=replace_stride_with_dilation[0])
140 |         self.layer3 = self._make_layer(block, width * 4, layers[2], stride=2,
141 |                                        dilate=replace_stride_with_dilation[1])
142 |         self.layer4 = self._make_layer(block, width * 8, layers[3], stride=2,
143 |                                        dilate=replace_stride_with_dilation[2])
144 |         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
145 |         self.fc = nn.Linear(8 * width * block.expansion, num_classes)
146 | 
147 |         for m in self.modules():
148 |             if isinstance(m, nn.Conv2d):
149 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
150 |             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
151 |                 nn.init.constant_(m.weight, 1)
152 |                 nn.init.constant_(m.bias, 0)
153 | 
154 |         # Zero-initialize the last BN in each residual branch,
155 |         # so that the residual branch starts with zeros, and each residual block behaves like an identity.
156 |         # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
157 |         if zero_init_residual:
158 |             for m in self.modules():
159 |                 if isinstance(m, Bottleneck):
160 |                     nn.init.constant_(m.bn3.weight, 0)
161 |                 elif isinstance(m, BasicBlock):
162 |                     nn.init.constant_(m.bn2.weight, 0)
163 | 
164 |     def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
165 |         norm_layer = self._norm_layer
166 |         downsample = None
167 |         previous_dilation = self.dilation
168 |         if dilate:
169 |             self.dilation *= stride
170 |             stride = 1
171 |         if stride != 1 or self.inplanes != planes * block.expansion:
172 |             downsample = nn.Sequential(
173 |                 conv1x1(self.inplanes, planes * block.expansion, stride),
174 |                 norm_layer(planes * block.expansion),
175 |             )
176 | 
177 |         layers = []
178 |         layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
179 |                             self.base_width, previous_dilation, norm_layer))
180 |         self.inplanes = planes * block.expansion
181 |         for _ in range(1, blocks):
182 |             layers.append(block(self.inplanes, planes, groups=self.groups,
183 |                                 base_width=self.base_width, dilation=self.dilation,
184 |                                 norm_layer=norm_layer))
185 | 
186 |         return nn.Sequential(*layers)
187 | 
188 |     def _forward_impl(self, x):
189 |         # See note [TorchScript super()]
190 |         x = self.conv1(x)
191 |         x = self.bn1(x)
192 |         x = self.relu(x)
193 |         x = self.maxpool(x)
194 | 
195 |         x = self.layer1(x)
196 |         x = self.layer2(x)
197 |         x = self.layer3(x)
198 |         x = self.layer4(x)
199 | 
200 |         x = self.avgpool(x)
201 |         x = torch.flatten(x, 1)
202 |         x = self.fc(x)
203 | 
204 |         return x
205 | 
206 |     def forward(self, x):
207 |         return self._forward_impl(x)
208 | 
209 | 
210 | def _vwresnet(arch, block, layers, width, pretrained, progress, **kwargs):
211 |     assert not pretrained, "No pretrained model for variable width ResNets"
212 |     model = VariableWidthResNet(block, layers, width, **kwargs)
213 |     return model
214 | 
215 | 
216 | def resnet10vw(width, pretrained=False, progress=True, **kwargs):
217 |     r"""ResNet-10 model adapted from
218 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
219 |     Args:
220 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
221 |         progress (bool): If True, displays a progress bar of the download to stderr
222 |     """
223 |     return _vwresnet('resnet10', BasicBlock, [1, 1, 1, 1], width, pretrained, progress,
224 |                      **kwargs)
225 | 
226 | 
227 | def resnet18vw(width, pretrained=False, progress=True, **kwargs):
228 |     r"""ResNet-18 model from
229 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
230 |     Args:
231 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
232 |         progress (bool): If True, displays a progress bar of the download to stderr
233 |     """
234 |     return _vwresnet('resnet18', BasicBlock, [2, 2, 2, 2], width, pretrained, progress,
235 |                      **kwargs)
236 | 
237 | 
238 | def resnet34vw(width, pretrained=False, progress=True, **kwargs):
239 |     r"""ResNet-34 model from
240 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
241 |     Args:
242 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
243 |         progress (bool): If True, displays a progress bar of the download to stderr
244 |     """
245 |     return _vwresnet('resnet34', BasicBlock, [3, 4, 6, 3], width, pretrained, progress,
246 |                      **kwargs)
247 | 
248 | 
249 | def resnet50vw(width, pretrained=False, progress=True, **kwargs):
250 |     r"""ResNet-50 model from
251 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
252 |     Args:
253 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
254 |         progress (bool): If True, displays a progress bar of the download to stderr
255 |     """
256 |     return _vwresnet('resnet50', Bottleneck, [3, 4, 6, 3], width, pretrained, progress,
257 |                      **kwargs)
258 | 
--------------------------------------------------------------------------------
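A minimal instantiation sketch for the factories above (illustrative values); width sets the stem channel count, so width=64 recovers the standard torchvision widths:

import torch
from models.variable_width_resnet import resnet18vw

model = resnet18vw(width=32, num_classes=10)  # half-width ResNet-18
x = torch.randn(2, 3, 224, 224)
logits = model(x)                             # final layer is Linear(8 * width * expansion, num_classes)
print(logits.shape)                           # torch.Size([2, 10])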
/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | import collections.abc
2 | import torch
3 | from torch.autograd import Variable
4 | from torch.utils.data import Sampler
5 | 
6 | 
7 | # https://raw.githubusercontent.com/Cadene/bootstrap.pytorch/master/bootstrap/datasets/transforms.py
8 | 
9 | class Compose(object):
10 |     """Composes several collate transforms together.
11 | 
12 |     Args:
13 |         transforms (list of ``Collate`` objects): list of transforms to compose.
14 |     """
15 | 
16 |     def __init__(self, transforms):
17 |         self.transforms = transforms
18 | 
19 |     def __call__(self, batch):
20 |         for transform in self.transforms:
21 |             batch = transform(batch)
22 |         return batch
23 | 
24 | 
25 | class ListDictsToDictLists(object):
26 | 
27 |     def __init__(self):
28 |         pass
29 | 
30 |     def __call__(self, batch):
31 |         batch = self.ld_to_dl(batch)
32 |         return batch
33 | 
34 |     def ld_to_dl(self, batch):
35 |         if isinstance(batch[0], collections.abc.Mapping):
36 |             return {key: self.ld_to_dl([d[key] for d in batch]) for key in batch[0]}
37 |         else:
38 |             return batch
39 | 
40 | 
41 | class PadTensors(object):
42 | 
43 |     def __init__(self, value=0, use_keys=[], avoid_keys=[]):
44 |         self.value = value
45 |         self.use_keys = use_keys
46 |         if len(self.use_keys) > 0:
47 |             self.avoid_keys = []
48 |         else:
49 |             self.avoid_keys = avoid_keys
50 | 
51 |     def __call__(self, batch):
52 |         batch = self.pad_tensors(batch)
53 |         return batch
54 | 
55 |     def pad_tensors(self, batch):
56 |         if isinstance(batch, collections.abc.Mapping):
57 |             out = {}
58 |             for key, value in batch.items():
59 |                 if (key in self.use_keys) or \
60 |                         (len(self.use_keys) == 0 and key not in self.avoid_keys):
61 |                     out[key] = self.pad_tensors(value)
62 |                 else:
63 |                     out[key] = value
64 |             return out
65 |         elif isinstance(batch, collections.abc.Sequence) and torch.is_tensor(batch[0]):
66 |             max_size = [max([item.size(i) for item in batch]) for i in range(batch[0].dim())]
67 |             max_size = torch.Size(max_size)
68 |             n_batch = []
69 |             for item in batch:
70 |                 if item.size() != max_size:
71 |                     n_item = item.new(max_size).fill_(self.value)
72 |                     # TODO: Improve this
73 |                     if item.dim() == 1:
74 |                         n_item[:item.size(0)] = item
75 |                     elif item.dim() == 2:
76 |                         n_item[:item.size(0), :item.size(1)] = item
77 |                     elif item.dim() == 3:
78 |                         n_item[:item.size(0), :item.size(1), :item.size(2)] = item
79 |                     else:
80 |                         raise ValueError
81 |                     n_batch.append(n_item)
82 |                 else:
83 |                     n_batch.append(item)
84 |             return n_batch
85 |         else:
86 |             return batch
87 | 
88 | 
89 | class StackTensors(object):
90 | 
91 |     def __init__(self, use_shared_memory=False, avoid_keys=[]):
92 |         self.use_shared_memory = use_shared_memory
93 |         self.avoid_keys = avoid_keys
94 | 
95 |     def __call__(self, batch):
96 |         batch = self.stack_tensors(batch)
97 |         return batch
98 | 
99 |     # key argument is useful for debugging
100 |     def stack_tensors(self, batch, key=None):
101 |         if isinstance(batch, collections.abc.Mapping):
102 |             out = {}
103 |             for key, value in batch.items():
104 |                 if key not in self.avoid_keys:
105 |                     out[key] = self.stack_tensors(value, key=key)
106 |                 else:
107 |                     out[key] = value
108 |             return out
109 |         elif isinstance(batch, collections.abc.Sequence) and torch.is_tensor(batch[0]):
110 |             out = None
111 |             if self.use_shared_memory:
112 |                 # If we're in a background process, concatenate directly into a
113 |                 # shared memory tensor to avoid an extra copy
114 |                 numel = sum([x.numel() for x in batch])
115 |                 storage = batch[0].storage()._new_shared(numel)
116 |                 out = batch[0].new(storage)
117 |             return torch.stack(batch, 0, out=out)
118 |         else:
119 |             return batch
120 | 
121 | 
122 | class CatTensors(object):
123 | 
124 |     def __init__(self, use_shared_memory=False, use_keys=[], avoid_keys=[]):
125 |         self.use_shared_memory = use_shared_memory
126 |         self.use_keys = use_keys
127 |         if len(self.use_keys) > 0:
128 |             self.avoid_keys = []
129 |         else:
130 |             self.avoid_keys = avoid_keys
131 | 
132 |     def __call__(self, batch):
133 |         batch = self.cat_tensors(batch)
134 |         return batch
135 | 
136 |     def cat_tensors(self, batch):
137 |         if isinstance(batch, collections.abc.Mapping):
138 |             out = {}
139 |             for key, value in batch.items():
140 |                 if (key in self.use_keys) or \
141 |                         (len(self.use_keys) == 0 and key not in self.avoid_keys):
142 |                     out[key] = self.cat_tensors(value)
143 |                     if ('batch_id' not in out) and torch.is_tensor(value[0]):
144 |                         out['batch_id'] = torch.cat([i * torch.ones(x.size(0))
145 |                                                      for i, x in enumerate(value)], 0)
146 |                 else:
147 |                     out[key] = value
148 |             return out
149 |         elif isinstance(batch, collections.abc.Sequence) and torch.is_tensor(batch[0]):
150 |             out = None
151 |             if self.use_shared_memory:
152 |                 # If we're in a background process, concatenate directly into a
153 |                 # shared memory tensor to avoid an extra copy
154 |                 numel = sum([x.numel() for x in batch])
155 |                 storage = batch[0].storage()._new_shared(numel)
156 |                 out = batch[0].new(storage)
157 |             return torch.cat(batch, 0, out=out)
158 |         else:
159 |             return batch
160 | 
161 | 
162 | class ToCuda(object):
163 | 
164 |     def __init__(self, *args, **kwargs):
165 |         pass
166 | 
167 |     def __call__(self, batch):
168 |         batch = self.to_cuda(batch)
169 |         return batch
170 | 
171 |     def to_cuda(self, batch):
172 |         if isinstance(batch, collections.abc.Mapping):
173 |             return {key: self.to_cuda(value) for key, value in batch.items()}
174 |         elif torch.is_tensor(batch):
175 |             # TODO: verify async usage
176 |             return batch.cuda(non_blocking=True)
177 |         elif type(batch).__name__ == 'Variable':
178 |             # TODO: Really hacky
179 |             return Variable(batch.data.cuda(non_blocking=True))
180 |         elif isinstance(batch, collections.abc.Sequence) and torch.is_tensor(batch[0]):
181 |             return [self.to_cuda(value) for value in batch]
182 |         else:
183 |             return batch
184 | 
185 | 
186 | class ToCpu(object):
187 | 
188 |     def __init__(self, *args, **kwargs):
189 |         pass
190 | 
191 |     def __call__(self, batch):
192 |         batch = self.to_cpu(batch)
193 |         return batch
194 | 
195 |     def to_cpu(self, batch):
196 |         if isinstance(batch, collections.abc.Mapping):
197 |             return {key: self.to_cpu(value) for key, value in batch.items()}
198 |         elif torch.is_tensor(batch):
199 |             return batch.cpu()
200 |         elif type(batch).__name__ == 'Variable':
201 |             # TODO: Really hacky
202 |             return Variable(batch.data.cpu())
203 |         elif isinstance(batch, collections.abc.Sequence) and torch.is_tensor(batch[0]):
204 |             return [self.to_cpu(value) for value in batch]
205 |         else:
206 |             return batch
207 | 
208 | 
209 | class ToVariable(object):
210 | 
211 |     def __init__(self, volatile=False):
212 |         self.volatile = volatile
213 | 
214 |     def __call__(self, batch):
215 |         batch = self.to_variable(batch)
216 |         return batch
217 | 
218 |     def to_variable(self, batch):
219 |         if torch.is_tensor(batch):
220 |             if self.volatile:
221 |                 return Variable(batch, volatile=True)
222 |             else:
223 |                 return Variable(batch)
224 |         elif isinstance(batch, collections.abc.Mapping):
225 |             return {key: self.to_variable(value) for key, value in batch.items()}
226 |         elif isinstance(batch, collections.abc.Sequence) and torch.is_tensor(batch[0]):
227 |             return [self.to_variable(value) for value in batch]
228 |         else:
229 |             return batch
230 | 
231 | 
232 | class ToDetach(object):
233 | 
234 |     def __init__(self):
235 |         pass
236 | 
237 |     def __call__(self, batch):
238 |         batch = self.to_detach(batch)
239 |         return batch
240 | 
241 |     def to_detach(self, batch):
242 |         if torch.is_tensor(batch):
243 |             return batch.detach_()
244 |         elif isinstance(batch, collections.abc.Mapping):
245 |             return {key: self.to_detach(value) for key, value in batch.items()}
246 |         elif isinstance(batch, collections.abc.Sequence) and torch.is_tensor(batch[0]):
247 |             return [self.to_detach(value) for value in batch]
248 |         else:
249 |             return batch
250 | 
251 | 
252 | class SortByKey(object):
253 | 
254 |     def __init__(self, key='lengths', reverse=True):
255 |         self.key = key
256 |         self.reverse = reverse
257 |         self.i = 0
258 | 
259 |     def __call__(self, batch):
260 |         self.set_sort_keys(batch[self.key])  # must be a list
261 |         batch = self.sort_by_key(batch)
262 |         return batch
263 | 
264 |     def set_sort_keys(self, sort_keys):
265 |         self.i = 0
266 |         self.sort_keys = sort_keys
267 | 
268 |     # ugly hack to be able to sort without a lambda function
269 |     def get_key(self, _):
270 |         key = self.sort_keys[self.i]
271 |         self.i += 1
272 |         if self.i >= len(self.sort_keys):
273 |             self.i = 0
274 |         return key
275 | 
276 |     def sort_by_key(self, batch):
277 |         if isinstance(batch, collections.abc.Mapping):
278 |             return {key: self.sort_by_key(value) for key, value in batch.items()}
279 |         elif type(batch) is list:  # isinstance(batch, collections.abc.Sequence)
280 |             return sorted(batch, key=self.get_key, reverse=self.reverse)
281 |         else:
282 |             return batch
283 | 
284 | 
285 | def dict_collate_fn():
286 |     return Compose([
287 |         ListDictsToDictLists(),
288 |         StackTensors()
289 |     ])
290 | 
291 | 
292 | class IndexSampler(Sampler):
293 |     def __init__(self, data_source):
294 |         self.data_source = data_source
295 | 
296 |     def __iter__(self):
297 |         return iter(self.data_source)
298 | 
299 |     def __len__(self):
300 |         return len(self.data_source)
301 | 
--------------------------------------------------------------------------------
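To close out the dump, a minimal sketch (toy dataset, not part of the repo) showing the collate pipeline from data_utils.py plugged into a PyTorch DataLoader:

import torch
from torch.utils.data import DataLoader, Dataset
from utils.data_utils import dict_collate_fn

class ToyDataset(Dataset):
    def __len__(self):
        return 4

    def __getitem__(self, ix):
        return {'x': torch.randn(3), 'y': ix}  # dict-style samples, as the trainers expect

loader = DataLoader(ToyDataset(), batch_size=2, collate_fn=dict_collate_fn())
batch = next(iter(loader))
print(batch['x'].shape)  # torch.Size([2, 3]) -- ListDictsToDictLists then StackTensors
print(batch['y'])        # [0, 1] -- non-tensor fields stay as plain lists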