├── __init__.py
├── eval
│   ├── __init__.py
│   └── vqa
│       ├── __init__.py
│       ├── plot_tail.py
│       └── gqa_eval_from_file.py
├── models
│   ├── __init__.py
│   ├── vqa
│   │   ├── __init__.py
│   │   ├── updn
│   │   │   ├── __init__.py
│   │   │   ├── tda.py
│   │   │   └── net.py
│   │   ├── make_mask.py
│   │   ├── ops
│   │   │   ├── layer_norm.py
│   │   │   └── fc.py
│   │   ├── vqa_adapter.py
│   │   ├── ban
│   │   │   ├── ban.py
│   │   │   └── _ban.py
│   │   └── mcan
│   │       ├── net.py
│   │       └── mca.py
│   ├── model_factory.py
│   ├── fc_models.py
│   ├── coordconv.py
│   └── variable_width_resnet.py
├── utils
│   ├── __init__.py
│   ├── bias_retrievers.py
│   ├── losses.py
│   ├── running_stats.py
│   ├── format_utils.py
│   ├── ema.py
│   ├── trainer_utils.py
│   ├── metric_visualizer.py
│   ├── metrics.py
│   └── data_utils.py
├── .gitignore
├── datasets
│   ├── __init__.py
│   ├── vqa
│   │   ├── __init__.py
│   │   ├── feat_filter.py
│   │   ├── base_vqa_dataset.py
│   │   ├── gqa_feat_preproc.py
│   │   └── ans_punct.py
│   ├── .idea
│   │   ├── .gitignore
│   │   ├── inspectionProfiles
│   │   │   └── profiles_settings.xml
│   │   ├── $CACHE_FILE$
│   │   ├── vcs.xml
│   │   ├── deployment.xml
│   │   ├── modules.xml
│   │   ├── misc.xml
│   │   └── datasets.iml
│   ├── dataloader_factory.py
│   └── shape_generator.py
├── common.sh
├── images
│   ├── main_table.jpg
│   ├── scalability.jpg
│   ├── bias_exploitation.jpg
│   └── distribution_variance.jpg
├── trainers
│   ├── trainer_factory.py
│   ├── __init__.py
│   ├── group_upweighting_trainer.py
│   ├── spectral_decoupling_trainer.py
│   ├── group_dro_trainer.py
│   ├── rubi_trainer.py
│   ├── lnl_trainer.py
│   ├── irm_v1_trainer.py
│   └── learning_from_failure_trainer.py
├── scripts
│   ├── gqa-ood
│   │   ├── run_multiple_gqa.sh
│   │   ├── baseline_gqa.sh
│   │   ├── lff_gqa.sh
│   │   ├── spectral_decoupling_gqa.sh
│   │   ├── upweighting_gqa.sh
│   │   ├── rubi_gqa.sh
│   │   ├── preprocess_gqa.sh
│   │   ├── group_dro_gqa.sh
│   │   ├── irmv1_gqa.sh
│   │   ├── lnl_gqa.sh
│   │   └── hyperparam_search_gqa.sh
│   ├── celebA
│   │   ├── rubi_celebA.sh
│   │   ├── group_upweighting_celebA.sh
│   │   ├── group_dro_celebA.sh
│   │   ├── baseline_celebA.sh
│   │   ├── irmv1_celebA.sh
│   │   ├── spectral_decoupling_celebA.sh
│   │   ├── lnl_celebA.sh
│   │   └── learning_from_failure_celebA.sh
│   ├── biased_mnist
│   │   ├── rubi_biased_mnist.sh
│   │   ├── baseline_biased_mnist.sh
│   │   ├── group_dro_biased_mnist.sh
│   │   ├── group_upweighting_biased_mnist.sh
│   │   ├── irmv1_biased_mnist.sh
│   │   ├── learning_from_failure_biased_mnist.sh
│   │   ├── lnl_biased_mnist.sh
│   │   └── spectral_decoupling_biased_mnist.sh
│   └── biased_mnist_dataset_generator
│       └── full_generator.sh
├── experiments
│   ├── __init__.py
│   ├── gqa_experiments.py
│   └── celebA_experiments.py
├── LICENSE
├── main.py
├── README.md
└── option.py
/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/eval/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/eval/vqa/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/vqa/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/datasets/vqa/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/vqa/updn/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/datasets/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /workspace.xml
--------------------------------------------------------------------------------
/common.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export PYTHONUNBUFFERED=1
3 | ROOT=/hdd/robik # Change this!
4 |
--------------------------------------------------------------------------------
/images/main_table.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erobic/bias-mitigators/HEAD/images/main_table.jpg
--------------------------------------------------------------------------------
/images/scalability.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erobic/bias-mitigators/HEAD/images/scalability.jpg
--------------------------------------------------------------------------------
/images/bias_exploitation.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erobic/bias-mitigators/HEAD/images/bias_exploitation.jpg
--------------------------------------------------------------------------------
/images/distribution_variance.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erobic/bias-mitigators/HEAD/images/distribution_variance.jpg
--------------------------------------------------------------------------------
/trainers/trainer_factory.py:
--------------------------------------------------------------------------------
1 | from trainers import *
2 |
3 | def build_trainer(option):
4 | return eval(option.trainer_name)(option)
5 |
--------------------------------------------------------------------------------
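
Usage sketch (not a repository file): the factory resolves option.trainer_name with eval, which works because trainers/__init__.py (shown later in this dump) hoists every class defined in the package into the trainers namespace. The option fields beyond trainer_name are placeholders here; real runs pass the full namespace built by option.py.

    from types import SimpleNamespace
    from trainers import trainer_factory

    # 'BaseTrainer' must match a class name exported by trainers/__init__.py
    option = SimpleNamespace(trainer_name='BaseTrainer')  # real trainers expect many more fields
    trainer = trainer_factory.build_trainer(option)
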
/datasets/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
--------------------------------------------------------------------------------
/datasets/.idea/$CACHE_FILE$:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/datasets/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/datasets/.idea/deployment.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/scripts/gqa-ood/run_multiple_gqa.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #./scripts/gqa-ood/baseline_gqa.sh
3 |
4 | #./scripts/gqa-ood/upweighting_gqa.sh
5 | #./scripts/gqa-ood/group_dro_gqa.sh
6 |
7 | ./scripts/gqa-ood/rubi_gqa.sh
8 | ./scripts/gqa-ood/lnl_gqa.sh
9 | #./scripts/gqa-ood/irmv1_gqa.sh
10 | #./scripts/gqa-ood/lff_gqa.sh
11 |
--------------------------------------------------------------------------------
/scripts/gqa-ood/baseline_gqa.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source common.sh
3 | set -e
4 | source activate bias_mitigator
5 |
6 | TRAINER_NAME='BaseTrainer'
7 | python main.py \
8 | --expt_type gqa_experiments \
9 | --lr 1e-4 \
10 | --weight_decay 0 \
11 | --trainer_name ${TRAINER_NAME} \
12 | --root_dir ${ROOT}
13 |
--------------------------------------------------------------------------------
/datasets/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/scripts/celebA/rubi_celebA.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source common.sh
3 | set -e
4 | source activate bias_mitigator
5 |
6 | TRAINER_NAME='RUBiTrainer'
7 |
8 | python main.py \
9 | --expt_type celebA_experiments \
10 | --trainer_name ${TRAINER_NAME} \
11 | --lr 1e-4 \
12 | --weight_decay 1e-5 \
13 | --root_dir ${ROOT}
--------------------------------------------------------------------------------
/scripts/gqa-ood/lff_gqa.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source common.sh
3 | set -e
4 | source activate bias_mitigator
5 |
6 | TRAINER_NAME='LffTrainer'
7 | python main.py \
8 | --expt_type gqa_lff \
9 | --lr 1e-4 \
10 | --weight_decay 0 \
11 | --trainer_name ${TRAINER_NAME} \
12 | --bias_loss_gamma 0.7 \
13 | --root_dir ${ROOT}
--------------------------------------------------------------------------------
/datasets/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/scripts/biased_mnist/rubi_biased_mnist.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source common.sh
3 | set -e
4 | source activate bias_mitigator
5 |
6 | TRAINER_NAME='RUBiTrainer'
7 |
8 | python main.py \
9 | --expt_type biased_mnist_experiments \
10 | --trainer_name ${TRAINER_NAME} \
11 | --lr 1e-3 \
12 | --weight_decay 1e-5 \
13 | --root_dir ${ROOT}
--------------------------------------------------------------------------------
/scripts/celebA/group_upweighting_celebA.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source common.sh
3 | set -e
4 | source activate bias_mitigator
5 |
6 | TRAINER_NAME='GroupUpweightingTrainer'
7 | python main.py \
8 | --expt_type celebA_experiments \
9 | --lr 1e-5 \
10 | --weight_decay 0.1 \
11 | --trainer_name ${TRAINER_NAME} \
12 | --root_dir ${ROOT}
--------------------------------------------------------------------------------
/scripts/celebA/group_dro_celebA.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source common.sh
3 | set -e
4 | source activate bias_mitigator
5 |
6 | TRAINER_NAME='GroupDROTrainer'
7 |
8 | python main.py \
9 | --expt_type celebA_experiments \
10 | --trainer_name ${TRAINER_NAME} \
11 | --lr 1e-5 \
12 | --weight_decay 0.1 \
13 | --group_weight_step_size 0.01 \
14 | --root_dir ${ROOT}
--------------------------------------------------------------------------------
/scripts/biased_mnist/baseline_biased_mnist.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source common.sh
3 | set -e
4 | source activate bias_mitigator
5 |
6 | TRAINER_NAME='BaseTrainer'
7 | lr=1e-3
8 | wd=1e-5
9 |
10 | python main.py \
11 | --expt_type biased_mnist_experiments \
12 | --trainer_name ${TRAINER_NAME} \
13 | --lr ${lr} \
14 | --weight_decay ${wd} \
15 | --root_dir ${ROOT}
--------------------------------------------------------------------------------
/scripts/celebA/baseline_celebA.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source common.sh
3 | set -e
4 | source activate bias_mitigator
5 |
6 | TRAINER_NAME='BaseTrainer'
7 | lr=1e-3
8 | wd=0
9 | python main.py \
10 | --expt_type celebA_experiments \
11 | --trainer_name ${TRAINER_NAME} \
12 | --lr ${lr} \
13 | --weight_decay ${wd} \
14 | --expt_name ${TRAINER_NAME} \
15 | --root_dir ${ROOT}
--------------------------------------------------------------------------------
/scripts/celebA/irmv1_celebA.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source common.sh
3 | set -e
4 | source activate bias_mitigator
5 |
6 | TRAINER_NAME='IRMv1Trainer'
7 | python main.py \
8 | --lr 1e-4 \
9 | --weight_decay 0 \
10 | --expt_type celebA_experiments \
11 | --trainer_name ${TRAINER_NAME} \
12 | --grad_penalty_weight 1 \
13 | --num_envs_per_batch 4 \
14 | --root_dir ${ROOT}
15 |
--------------------------------------------------------------------------------
/scripts/biased_mnist/group_dro_biased_mnist.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source common.sh
3 | set -e
4 | source activate bias_mitigator
5 |
6 | TRAINER_NAME='GroupDROTrainer'
7 |
8 | python -u main.py \
9 | --expt_type biased_mnist_experiments \
10 | --lr 1e-3 \
11 | --weight_decay 1e-5 \
12 | --trainer_name ${TRAINER_NAME} \
13 | --group_weight_step_size 1e-3 \
14 | --root_dir ${ROOT}
--------------------------------------------------------------------------------
/scripts/biased_mnist/group_upweighting_biased_mnist.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source common.sh
3 | set -e
4 | source activate bias_mitigator
5 |
6 | TRAINER_NAME='GroupUpweightingTrainer'
7 |
8 | CUDA_VISIBLE_DEVICES=0 python main.py \
9 | --expt_type biased_mnist_experiments \
10 | --trainer_name ${TRAINER_NAME} \
11 | --lr 1e-3 \
12 | --weight_decay 1e-5 \
13 | --root_dir ${ROOT}
--------------------------------------------------------------------------------
/scripts/celebA/spectral_decoupling_celebA.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source common.sh
3 | set -e
4 | source activate bias_mitigator
5 |
6 | TRAINER_NAME='SpectralDecouplingTrainer'
7 |
8 | # Lambdas and gammas are specified in celebA_experiments.py
9 | python main.py \
10 | --lr 1e-4 \
11 | --weight_decay 1e-5 \
12 | --expt_type celebA_experiments \
13 | --trainer_name ${TRAINER_NAME} \
14 | --root_dir ${ROOT}
15 |
--------------------------------------------------------------------------------
/scripts/gqa-ood/spectral_decoupling_gqa.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source common.sh
3 | set -e
4 | source activate bias_mitigator
5 |
6 | TRAINER_NAME='SpectralDecouplingTrainer'
7 |
8 | python main.py \
9 | --lr 1e-4 \
10 | --weight_decay 0 \
11 | --expt_type gqa_sd_experiments \
12 | --trainer_name ${TRAINER_NAME} \
13 | --root_dir ${ROOT} \
14 | --spectral_decoupling_lambda 1e-3 \
15 | --spectral_decoupling_gamma 1e-3
--------------------------------------------------------------------------------
/models/vqa/make_mask.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # OpenVQA
3 | # Written by Yuhao Cui https://github.com/cuiyuhao1996
4 | # --------------------------------------------------------
5 |
6 | import torch
7 |
8 |
9 | # Masking the sequence mask
10 | def make_mask(feature):
11 | return (torch.sum(
12 | torch.abs(feature),
13 | dim=-1
14 | ) == 0).unsqueeze(1).unsqueeze(2)
--------------------------------------------------------------------------------
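
A quick shape check of make_mask (toy tensors; real inputs are Faster R-CNN features of shape (batch, num_objects, feat_dim)): rows whose features are all zero are treated as padding and marked True, and the two unsqueeze calls shape the mask for broadcasting over attention scores.

    import torch
    from models.vqa.make_mask import make_mask

    feat = torch.zeros(2, 4, 8)      # (batch, num_objects, feat_dim)
    feat[0, :2] = 1.0                # sample 0 has two real objects, two padded
    mask = make_mask(feat)           # shape: (2, 1, 1, 4)
    print(mask[0, 0, 0])             # tensor([False, False,  True,  True])
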
/datasets/.idea/datasets.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
--------------------------------------------------------------------------------
/scripts/celebA/lnl_celebA.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source common.sh
3 | set -e
4 | source activate bias_mitigator
5 |
6 | TRAINER_NAME='LNLTrainer'
7 | python main.py \
8 | --expt_type celebA_experiments \
9 | --lr 1e-4 \
10 | --weight_decay 1e-4 \
11 | --trainer_name ${TRAINER_NAME} \
12 | --root_dir ${ROOT}
13 |
14 | #TRAINER_NAME='LNLTrainer'
15 | #python main.py \
16 | #--expt_type celebA_experiments \
17 | #--trainer_name ${TRAINER_NAME} \
18 | #--root_dir ${ROOT}
--------------------------------------------------------------------------------
/scripts/biased_mnist/irmv1_biased_mnist.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source common.sh
3 | set -e
4 | source activate bias_mitigator
5 |
6 | TRAINER_NAME='IRMv1Trainer'
7 | for grad_penalty_weight in 0.01; do
8 | python main.py \
9 | --lr 1e-3 \
10 | --weight_decay 1e-5 \
11 | --expt_type biased_mnist_experiments \
12 | --trainer_name ${TRAINER_NAME} \
13 | --grad_penalty_weight ${grad_penalty_weight} \
14 | --num_envs_per_batch 16 \
15 | --root_dir ${ROOT}
16 | done
--------------------------------------------------------------------------------
/scripts/biased_mnist/learning_from_failure_biased_mnist.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source common.sh
3 | set -e
4 | source activate bias_mitigator
5 |
6 | TRAINER_NAME='LffTrainer'
7 |
8 | for bias_loss_gamma in 0.5; do
9 | python main.py \
10 | --expt_type biased_mnist_experiments \
11 | --lr 1e-3 \
12 | --weight_decay 1e-5 \
13 | --trainer_name ${TRAINER_NAME} \
14 | --optimizer_name Adam \
15 | --bias_loss_gamma ${bias_loss_gamma} \
16 | --root_dir ${ROOT}
17 | done
--------------------------------------------------------------------------------
/scripts/gqa-ood/upweighting_gqa.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source common.sh
3 | set -e
4 | source activate bias_mitigator
5 |
6 | #for key_to_group_by in head_tail answer global_group_name local_group_name; do
7 | for key_to_group_by in head_tail ; do
8 | TRAINER_NAME='GroupUpweightingTrainer'
9 | python main.py \
10 | --expt_type gqa_experiments \
11 | --trainer_name ${TRAINER_NAME} \
12 | --key_to_group_by ${key_to_group_by} \
13 | --lr 1e-3 \
14 | --weight_decay 0 \
15 | --root_dir ${ROOT}
16 | done
--------------------------------------------------------------------------------
/scripts/gqa-ood/rubi_gqa.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source common.sh
3 | set -e
4 | source activate bias_mitigator
5 |
6 | TRAINER_NAME='RUBiTrainer'
7 | #for key_to_group_by in head_tail global_group_name local_group_name; do
8 | for key_to_group_by in head_tail ; do
9 | python main.py \
10 | --expt_type gqa_experiments \
11 | --lr 1e-4 \
12 | --weight_decay 0 \
13 | --trainer_name ${TRAINER_NAME} \
14 | --key_to_group_by ${key_to_group_by} \
15 | --bias_variable_name group_ix \
16 | --root_dir ${ROOT}
17 | done
18 |
--------------------------------------------------------------------------------
/scripts/gqa-ood/preprocess_gqa.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source common.sh
3 | set -e
4 | source activate bias_mitigator
5 |
6 | mkdir -p ${ROOT}/GQA/preprocessed/objects
7 | mkdir -p ${ROOT}/GQA/preprocessed/spatial
8 |
9 | python datasets/vqa/gqa_feat_preproc.py \
10 | --mode object \
11 | --object_dir ${ROOT}/GQA/objects \
12 | --out_dir ${ROOT}/GQA/preprocessed/objects
13 |
14 | python datasets/vqa/gqa_feat_preproc.py \
15 | --mode spatial \
16 | --spatial_dir ${ROOT}/GQA/spatial \
17 | --out_dir ${ROOT}/GQA/preprocessed/spatial
18 |
--------------------------------------------------------------------------------
/scripts/biased_mnist_dataset_generator/full_generator.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source activate bias_mitigator
3 | #python datasets/biased_mnist_generator.py \
4 | #--config_file conf/biased_mnist_generator/full.yaml \
5 | #--p_bias 0.9 \
6 | #--suffix '_0.9' \
7 | #--generate_test_set 1
8 |
9 | for p_bias in 0.9 0.93 0.95 0.97 0.99 1.0; do
10 | python -u datasets/biased_mnist_generator.py \
11 | --config_file conf/biased_mnist_generator/full_v1.yaml \
12 | --p_bias ${p_bias} \
13 | --suffix _${p_bias} \
14 | --generate_test_set 0
15 | done
--------------------------------------------------------------------------------
/scripts/gqa-ood/group_dro_gqa.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source common.sh
3 | set -e
4 | source activate bias_mitigator
5 |
6 | TRAINER_NAME='GroupDROTrainer'
7 | #for key_to_group_by in head_tail answer global_group_name local_group_name; do
8 | for key_to_group_by in head_tail ; do
9 | python main.py \
10 | --expt_type gqa_experiments \
11 | --lr 1e-4 \
12 | --weight_decay 0 \
13 | --group_weight_step_size 0.01 \
14 | --trainer_name ${TRAINER_NAME} \
15 | --key_to_group_by ${key_to_group_by} \
16 | --root_dir ${ROOT}
17 | done
18 |
--------------------------------------------------------------------------------
/scripts/biased_mnist/lnl_biased_mnist.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source common.sh
3 | set -e
4 | source activate bias_mitigator
5 |
6 | TRAINER_NAME='LNLTrainer'
7 | for entropy_loss_weight in 0.01; do
8 | for grad_reverse_factor in -0.1; do
9 | python main.py \
10 | --expt_type biased_mnist_experiments \
11 | --lr 1e-3 \
12 | --weight_decay 1e-5 \
13 | --trainer_name ${TRAINER_NAME} \
14 | --root_dir ${ROOT} \
15 | --entropy_loss_weight ${entropy_loss_weight} \
16 | --grad_reverse_factor ${grad_reverse_factor}
17 | done
18 | done
--------------------------------------------------------------------------------
/scripts/biased_mnist/spectral_decoupling_biased_mnist.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source common.sh
3 | set -e
4 | source activate bias_mitigator
5 |
6 | TRAINER_NAME='SpectralDecouplingTrainer'
7 |
8 | for sd_gamma in 1e-3; do
9 | for sd_lambda in 1e-3; do
10 | python main.py \
11 | --lr 1e-3 \
12 | --weight_decay 1e-5 \
13 | --expt_type biased_mnist_experiments \
14 | --trainer_name ${TRAINER_NAME} \
15 | --root_dir ${ROOT} \
16 | --spectral_decoupling_lambda ${sd_lambda} \
17 | --spectral_decoupling_gamma ${sd_gamma}
18 | done
19 | done
--------------------------------------------------------------------------------
/scripts/gqa-ood/irmv1_gqa.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source common.sh
3 | set -e
4 | source activate bias_mitigator
5 |
6 | TRAINER_NAME='IRMv1Trainer'
7 |
8 | #for key_to_group_by in head_tail answer global_group_name local_group_name; do
9 | for key_to_group_by in head_tail ; do
10 | python main.py \
11 | --lr 1e-4 \
12 | --weight_decay 0 \
13 | --expt_type gqa_experiments \
14 | --key_to_group_by ${key_to_group_by} \
15 | --trainer_name ${TRAINER_NAME} \
16 | --grad_penalty_weight 0.01 \
17 | --num_envs_per_batch 16 \
18 | --root_dir ${ROOT}
19 | done
--------------------------------------------------------------------------------
/scripts/gqa-ood/lnl_gqa.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source common.sh
3 | set -e
4 | source activate bias_mitigator
5 |
6 | TRAINER_NAME='LNLTrainer'
7 | #BIAS_VARIABLE_NAME=answer
8 | #BIAS_VARIABLE_NAME=head_tail
9 |
10 | for key_to_group_by in qtype_detailed; do
11 | EXPT_NAME=bias_variable_${key_to_group_by}
12 | python main.py \
13 | --expt_type gqa_experiments \
14 | --lr 1e-3 \
15 | --weight_decay 0 \
16 | --grad_reverse_factor -0.1 \
17 | --entropy_loss_weight 0.01 \
18 | --key_to_group_by ${key_to_group_by} \
19 | --bias_variable_name group_ix \
20 | --trainer_name ${TRAINER_NAME} \
21 | --expt_name ${EXPT_NAME} \
22 | --root_dir ${ROOT}
23 | done
--------------------------------------------------------------------------------
/trainers/__init__.py:
--------------------------------------------------------------------------------
1 | from inspect import isclass
2 | from pkgutil import iter_modules
3 | from pathlib import Path
4 | from importlib import import_module
5 | # iterate through the modules in the current package
6 | package_dir = Path(__file__).resolve().parent
7 | for (_, module_name, _) in iter_modules([package_dir]):
8 |
9 | # import the module and iterate through its attributes
10 | module = import_module(f"{__name__}.{module_name}")
11 | for attribute_name in dir(module):
12 | attribute = getattr(module, attribute_name)
13 |
14 | if isclass(attribute):
15 | # Add the class to this package's variables
16 | globals()[attribute_name] = attribute
17 |
--------------------------------------------------------------------------------
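
The net effect of this auto-import recipe, as a sketch (assuming a BaseTrainer class is defined in one of the package's modules, as the scripts' TRAINER_NAME='BaseTrainer' suggests): every class defined anywhere under trainers/ becomes importable from the package itself, which is what lets trainer_factory.py resolve names with eval.

    from trainers import BaseTrainer  # lifted into the package namespace by the loop above
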
/experiments/__init__.py:
--------------------------------------------------------------------------------
1 | from inspect import isclass
2 | from pkgutil import iter_modules
3 | from pathlib import Path
4 | from importlib import import_module
5 | # iterate through the modules in the current package
6 | package_dir = Path(__file__).resolve().parent
7 | for (_, module_name, _) in iter_modules([package_dir]):
8 |
9 | # import the module and iterate through its attributes
10 | module = import_module(f"{__name__}.{module_name}")
11 | for attribute_name in dir(module):
12 | attribute = getattr(module, attribute_name)
13 |
14 | if isclass(attribute):
15 | # Add the class to this package's variables
16 | globals()[attribute_name] = attribute
17 |
--------------------------------------------------------------------------------
/scripts/celebA/learning_from_failure_celebA.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source common.sh
3 | set -e
4 | source activate bias_mitigator
5 |
6 | TRAINER_NAME='LffTrainer'
7 |
 8 | # Note that we used SGD for all other methods; however, LfF did not train well with SGD despite a hyperparameter search,
 9 | # so we use Adam here.
10 | # Also, we could not replicate the original paper's results with bias_loss_gamma = 0.7. After tuning, 0.1 worked best in our runs.
11 |
12 | python main.py \
13 | --expt_type celebA_experiments \
14 | --lr 1e-4 \
15 | --weight_decay 0 \
16 | --trainer_name ${TRAINER_NAME} \
17 | --optimizer_name Adam \
18 | --bias_loss_gamma 0.1 \
19 | --root_dir ${ROOT}
20 |
21 |
--------------------------------------------------------------------------------
/models/vqa/ops/layer_norm.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # OpenVQA
3 | # Written by Yuhao Cui https://github.com/cuiyuhao1996
4 | # --------------------------------------------------------
5 |
6 | import torch.nn as nn
7 | import torch
8 |
9 | class LayerNorm(nn.Module):
10 | def __init__(self, size, eps=1e-6):
11 | super(LayerNorm, self).__init__()
12 | self.eps = eps
13 |
14 | self.a_2 = nn.Parameter(torch.ones(size))
15 | self.b_2 = nn.Parameter(torch.zeros(size))
16 |
17 | def forward(self, x):
18 | mean = x.mean(-1, keepdim=True)
19 | std = x.std(-1, keepdim=True)
20 |
21 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
22 |
--------------------------------------------------------------------------------
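
Worth noting: this LayerNorm divides by the unbiased standard deviation plus eps, whereas torch.nn.LayerNorm uses the biased variance inside the square root, so the two differ slightly for small feature sizes. A quick sanity check:

    import torch
    from models.vqa.ops.layer_norm import LayerNorm

    ln = LayerNorm(size=8)
    x = torch.randn(2, 3, 8)
    y = ln(x)
    print(y.mean(-1).abs().max())  # ~0: each feature vector is recentered and rescaled
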
/eval/vqa/plot_tail.py:
--------------------------------------------------------------------------------
 1 | import seaborn as sns
2 | import os
3 |
4 | sns.set()
5 | import matplotlib.pyplot as plt
6 | import pandas as pd
7 | import numpy as np
8 |
9 |
10 | def plot_tail_for_one_model(alpha, accuracy, model_name='default'):
11 | data = {'Tail size': alpha, model_name: accuracy}
12 | df = pd.DataFrame(data, dtype=float)
13 | df = pd.melt(df, ['Tail size'], var_name="Models", value_name="Accuracy")
14 | ax = sns.lineplot(x="Tail size", y="Accuracy", hue="Models", style="Models", data=df, markers=False, ci=None)
15 | plt.xscale('log')
16 | plt.ylim(0, 100)
17 | save_dir = 'figures'
18 | if not os.path.exists(save_dir):
19 | os.makedirs(save_dir)
20 |     plt.savefig(os.path.join(save_dir, 'tail_plot_%s.pdf' % model_name))
21 | plt.close()
22 |
--------------------------------------------------------------------------------
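
A hypothetical call (the numbers are made up; gqa_eval_from_file.py below computes real accuracies, and its commented-out plotting call passes 1 + alpha because the x-axis is log-scaled):

    from eval.vqa.plot_tail import plot_tail_for_one_model

    tail_sizes = [10.0, 6.0, 2.0, 1.1, 0.5]       # 1 + alpha, kept positive for the log axis
    accuracies = [55.2, 52.8, 48.1, 45.0, 41.3]   # hypothetical values
    plot_tail_for_one_model(tail_sizes, accuracies, model_name='UpDn_baseline')
    # writes figures/tail_plot_UpDn_baseline.pdf
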
/datasets/vqa/feat_filter.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # OpenVQA
3 | # Written by Yuhao Cui https://github.com/cuiyuhao1996
4 | # --------------------------------------------------------
5 |
6 |
7 | def feat_filter(dataset, frcn_feat, grid_feat, bbox_feat):
8 | feat_dict = {}
9 |
10 | if dataset in ['vqa']:
11 | feat_dict['FRCN_FEAT'] = frcn_feat
12 | feat_dict['BBOX_FEAT'] = bbox_feat
13 |
14 | elif dataset in ['gqa']:
15 | feat_dict['FRCN_FEAT'] = frcn_feat
16 | feat_dict['GRID_FEAT'] = grid_feat
17 | feat_dict['BBOX_FEAT'] = bbox_feat
18 |
19 | elif dataset in ['clevr']:
20 | feat_dict['GRID_FEAT'] = grid_feat
21 |
22 |     else:
23 |         raise ValueError('Unknown dataset: {}'.format(dataset))
24 |
25 | return feat_dict
26 |
27 |
28 |
--------------------------------------------------------------------------------
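
For example, under the 'gqa' branch the returned dict carries all three feature types (sketch with placeholder tensors and shapes):

    import torch
    from datasets.vqa.feat_filter import feat_filter

    frcn = torch.randn(100, 2048)   # placeholder region features
    grid = torch.randn(49, 2048)    # placeholder grid features
    bbox = torch.randn(100, 5)      # placeholder box geometry
    feats = feat_filter('gqa', frcn, grid, bbox)
    print(sorted(feats))            # ['BBOX_FEAT', 'FRCN_FEAT', 'GRID_FEAT']
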
/utils/bias_retrievers.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def build_bias_retriever(bias_variable_name):
5 | if bias_variable_name == 'color':
6 | return ColorRetriever()
7 | else:
8 | return VariableRetriever(bias_variable_name)
9 |
10 |
11 | class ColorRetriever:
12 |
13 | def retrieve(self, batch):
14 | x = batch['x']
15 | return x.view(x.size(0), x.size(1), -1).max(2)[0]
16 |
17 | def __call__(self, batch, main_out):
18 | return self.retrieve(batch)
19 |
20 |
21 | class VariableRetriever:
22 | def __init__(self, var_name):
23 | self.var_name = var_name
24 |
25 | def __call__(self, batch, main_out):
26 | if self.var_name in batch:
27 | ret = batch[self.var_name]
28 | else:
29 | ret = main_out[self.var_name]
30 | if isinstance(ret, list):
31 | ret = torch.FloatTensor(ret).cuda()
32 | return ret
33 |
--------------------------------------------------------------------------------
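
VariableRetriever looks the named variable up in the batch first and falls back to the model output; a small CPU-only sketch (the .cuda() branch only triggers for plain lists):

    import torch
    from utils.bias_retrievers import build_bias_retriever

    retriever = build_bias_retriever('group_ix')
    batch = {'group_ix': torch.tensor([0, 2, 1])}
    print(retriever(batch, main_out={}))  # tensor([0, 2, 1])
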
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Robik Shrestha, Kushal Kafle and Christopher Kanan
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/models/vqa/ops/fc.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # OpenVQA
3 | # Written by Yuhao Cui https://github.com/cuiyuhao1996
4 | # --------------------------------------------------------
5 |
6 | import torch.nn as nn
7 | import torch
8 |
9 |
10 | class FC(nn.Module):
11 | def __init__(self, in_size, out_size, dropout_r=0., use_relu=True):
12 | super(FC, self).__init__()
13 | self.dropout_r = dropout_r
14 | self.use_relu = use_relu
15 |
16 | self.linear = nn.Linear(in_size, out_size)
17 |
18 | if use_relu:
19 | self.relu = nn.ReLU(inplace=True)
20 |
21 | if dropout_r > 0:
22 | self.dropout = nn.Dropout(dropout_r)
23 |
24 | def forward(self, x):
25 | x = self.linear(x)
26 |
27 | if self.use_relu:
28 | x = self.relu(x)
29 |
30 | if self.dropout_r > 0:
31 | x = self.dropout(x)
32 |
33 | return x
34 |
35 |
36 | class MLP(nn.Module):
37 | def __init__(self, in_size, mid_size, out_size, dropout_r=0., use_relu=True):
38 | super(MLP, self).__init__()
39 |
40 | self.fc = FC(in_size, mid_size, dropout_r=dropout_r, use_relu=use_relu)
41 | self.linear = nn.Linear(mid_size, out_size)
42 |
43 | def forward(self, x):
44 | return self.linear(self.fc(x))
45 |
--------------------------------------------------------------------------------
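
Shape-level usage of the two blocks (MLP is an FC followed by a plain linear layer, so ReLU/dropout apply only to the hidden layer):

    import torch
    from models.vqa.ops.fc import FC, MLP

    mlp = MLP(in_size=512, mid_size=1024, out_size=10, dropout_r=0.1)
    x = torch.randn(32, 512)
    print(mlp(x).shape)  # torch.Size([32, 10])
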
/models/model_factory.py:
--------------------------------------------------------------------------------
1 | from models.fc_models import *
2 | from models.cnn_models import *
3 | from models.vqa.updn.net import *
4 | from models.vqa.mcan.net import MCAN
5 | # from models.vqa.ban.ban import BAN, BANNoDropout
6 |
7 |
8 | def build_model(option,
9 | model_name,
10 | in_dims=None,
11 | hid_dims=None,
12 | out_dims=None,
13 | freeze_layers=None,
14 | dropout=None
15 | ):
16 | if 'updn' in model_name.lower() or 'mcan' in model_name.lower() or 'ban' in model_name.lower():
17 | m = eval(model_name)(option.dataset_info.pretrained_emb,
18 | option.dataset_info.token_size,
19 | option.dataset_info.ans_size)
20 | else:
21 | if in_dims is None and hid_dims is None:
22 | m = eval(model_name)(num_classes=out_dims)
23 | elif hid_dims is None:
24 | m = eval(model_name)(in_dims=in_dims, num_classes=out_dims)
25 | elif dropout is None:
26 | m = eval(model_name)(in_dims=in_dims, hid_dims=hid_dims, num_classes=out_dims)
27 | else:
28 | m = eval(model_name)(in_dims=in_dims, hid_dims=hid_dims, num_classes=out_dims, dropout=dropout)
29 | if freeze_layers is not None:
30 | m.freeze_layers(freeze_layers)
31 | return m
32 |
--------------------------------------------------------------------------------
/models/vqa/vqa_adapter.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Adapted from OpenVQA
3 | # Written by Zhenwei Shao https://github.com/ParadoxZW
4 | # --------------------------------------------------------
5 |
6 | import torch.nn as nn
7 | import torch
8 | from models.vqa.make_mask import make_mask
9 |
10 |
11 | class VQAAdapter(nn.Module):
12 | def __init__(self,
13 | frcn_feat_size=(100, 2048),
14 | bbox_feat_size=(100, 5),
15 | bbox_feat_emb_size=1024,
16 | use_bbox_feats=True,
17 | hidden_size=1024
18 | ):
19 | super(VQAAdapter, self).__init__()
20 | self.frcn_feat_size = frcn_feat_size
21 | self.bbox_feat_size = bbox_feat_size
22 | self.use_bbox_feats = use_bbox_feats
23 | self.hidden_size = hidden_size
24 |         in_size = frcn_feat_size[1]  # feature dim; overwritten below when bbox feats are used
25 | if self.use_bbox_feats:
26 | self.bbox_linear = nn.Linear(5, bbox_feat_emb_size)
27 | in_size = frcn_feat_size[1] + bbox_feat_emb_size
28 | self.frcn_linear = nn.Linear(in_size, hidden_size)
29 |
30 | def forward(self, frcn_feat, bbox_feat):
31 |
32 | img_feat_mask = make_mask(frcn_feat)
33 |
34 | if self.use_bbox_feats:
35 | bbox_feat = self.bbox_linear(bbox_feat)
36 | frcn_feat = torch.cat((frcn_feat, bbox_feat), dim=-1)
37 | img_feat = self.frcn_linear(frcn_feat)
38 |
39 | return img_feat, img_feat_mask
40 |
--------------------------------------------------------------------------------
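
With the default sizes, the adapter embeds the 5-d box geometry to 1024-d, concatenates it with the 2048-d region features, and projects to hidden_size; a shape sketch:

    import torch
    from models.vqa.vqa_adapter import VQAAdapter

    adapter = VQAAdapter()              # defaults: (100, 2048) regions, 5-d boxes
    frcn = torch.randn(8, 100, 2048)
    bbox = torch.randn(8, 100, 5)
    img_feat, img_feat_mask = adapter(frcn, bbox)
    print(img_feat.shape, img_feat_mask.shape)  # (8, 100, 1024) and (8, 1, 1, 100)
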
/scripts/gqa-ood/hyperparam_search_gqa.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source common.sh
3 | set -e
4 | source activate bias_mitigator
5 | PROJECT_NAME='GQA'
6 |
7 | TRAINER_NAME='BaseTrainer'
8 | for lr in 1e-2; do
9 | python main.py \
10 | --expt_type gqa_experiments \
11 | --project_name ${PROJECT_NAME} \
12 | --trainer_name ${TRAINER_NAME} \
13 | --lr ${lr} \
14 | --weight_decay 0 \
15 | --root_dir ${ROOT}
16 | done
17 |
18 | #TRAINER_NAME='RUBiTrainer'
19 | #for lr in 1e-2; do
20 | # for wd in 0 1e-3 0.1; do
21 | # python main.py \
22 | # --expt_type gqa_experiments \
23 | # --project_name ${PROJECT_NAME} \
24 | # --trainer_name ${TRAINER_NAME} \
25 | # --lr ${lr} \
26 | # --weight_decay ${wd} \
27 | # --root_dir ${ROOT}
28 | # done
29 | #done
30 |
31 |
32 | #TRAINER_NAME='GroupDROTrainer'
33 | #for lr in 1e-2; do
34 | # for wd in 0 1e-3 0.1; do
35 | # python main.py \
36 | # --expt_type gqa_experiments \
37 | # --project_name ${PROJECT_NAME} \
38 | # --trainer_name ${TRAINER_NAME} \
39 | # --lr ${lr} \
40 | # --weight_decay ${wd} \
41 | # --root_dir ${ROOT}
42 | # done
43 | #done
44 |
45 | #TRAINER_NAME='LffTrainer'
46 | #for lr in 1e-2; do
47 | # for wd in 0 1e-3 0.1; do
48 | # python main.py \
49 | # --expt_type gqa_experiments \
50 | # --project_name ${PROJECT_NAME} \
51 | # --trainer_name ${TRAINER_NAME} \
52 | # --lr ${lr} \
53 | # --weight_decay ${wd} \
54 | # --root_dir ${ROOT}
55 | # done
56 | #done
57 |
58 |
59 |
--------------------------------------------------------------------------------
/utils/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class GCELoss(nn.Module):
7 | def __init__(self, q=0.7, reduction='none'):
8 | super(GCELoss, self).__init__()
9 | self.q = q
10 | self.reduction = reduction
11 |
12 | def __call__(self, input, target):
13 | p = F.softmax(input, dim=1)
14 | Yg = torch.gather(p, 1, torch.unsqueeze(target, 1))
15 | loss_weight = (Yg.squeeze().detach() ** self.q) * self.q
16 | loss = F.cross_entropy(input, target, reduction=self.reduction) * loss_weight
17 | return loss
18 |
19 |
20 | class ExpLoss(nn.Module):
21 | def __init__(self, reduction='none'):
22 | super(ExpLoss, self).__init__()
23 | self.reduction = reduction
24 |
25 | def __call__(self, input, target):
26 | return torch.exp(torch.gather(1 - F.softmax(input, dim=1), dim=1, index=target.view(-1, 1)))
27 |
28 |
29 | class InverseProbabilityLoss(nn.Module):
30 | def __init__(self, reduction='none'):
31 | super(InverseProbabilityLoss, self).__init__()
32 | self.reduction = reduction
33 |
34 | def __call__(self, input, target):
35 | return 1 / torch.gather(1 - F.softmax(input, dim=1), dim=1, index=target.view(-1, 1))
36 | # return torch.exp(torch.gather(1 - F.softmax(input, dim=1), dim=1, index=target.view(-1, 1)))
37 |
38 |
39 | if __name__ == "__main__":
40 | gce_loss = GCELoss(q=2)
41 | l = gce_loss(torch.FloatTensor([[0.1, 0.9], [0.1, 0.8]]), torch.LongTensor([0, 1]))
42 | print(l)
43 |
--------------------------------------------------------------------------------
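
GCELoss implements the generalized cross entropy L_q(p_y) = (1 - p_y^q) / q. Since dL_q/dtheta = p_y^q * dCE/dtheta, weighting the per-sample cross entropy by the detached p_y^q (times q, as written above) reproduces the GCE gradient up to that constant factor. Minimal usage:

    import torch
    from utils.losses import GCELoss

    criterion = GCELoss(q=0.7)                       # q in (0, 1]; q -> 0 recovers plain CE
    logits = torch.randn(4, 10, requires_grad=True)
    targets = torch.randint(0, 10, (4,))
    loss = criterion(logits, targets).mean()         # reduction='none' returns per-sample losses
    loss.backward()
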
/utils/running_stats.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | # https://stackoverflow.com/a/17637351/1122681
5 | class RunningStats:
6 |
7 | def __init__(self):
8 | self.n = 0
9 | self.old_m = None
10 | self.new_m = None
11 | self.old_s = None
12 | self.new_s = None
13 | self.sum = None
14 |
15 | def clear(self):
16 | self.n = 0
17 |
18 | def push(self, x):
19 | self.n += 1
20 |
21 | if self.n == 1:
22 | self.sum = torch.zeros_like(x)
23 | self.sum += x
24 | self.old_m = self.new_m = x
25 | self.old_s = torch.zeros_like(x)
26 |                 # new_m must remain x after the first push; zeroing it here would corrupt the mean
27 | self.new_s = torch.zeros_like(x)
28 |
29 | else:
30 | self.new_m = self.old_m + (x - self.old_m) / self.n
31 | self.new_s = self.old_s + (x - self.old_m) * (x - self.new_m)
32 | self.sum += x
33 | self.old_m = self.new_m
34 | self.old_s = self.new_s
35 |
36 | def mean(self):
37 | return self.new_m if self.n else torch.zeros_like(self.new_m)
38 |
39 | def variance(self):
40 |         return self.new_s / (self.n - 1) if self.n > 1 else torch.zeros_like(self.new_s)
41 |
42 | def std(self):
43 | return torch.sqrt(self.variance())
44 |
45 | def std_err(self):
46 | return self.std() / (self.n ** 0.5)
47 |
48 | def get_summary(self):
49 | return {
50 | 'n': self.n,
51 | 'mean': self.mean(),
52 | 'std': self.std(),
53 | 'full_mean': self.sum / self.n
54 | }
55 |
56 |
57 | if __name__ == "__main__":
58 | stats = RunningStats()
59 | for i in range(1, 10):
60 | stats.push(torch.randn((3, 4)))
61 | print(stats.mean())
62 | print(stats.std())
63 |
--------------------------------------------------------------------------------
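
The push/mean/variance methods follow Welford's online algorithm; the recurrences being implemented are:

    m_k = m_{k-1} + \frac{x_k - m_{k-1}}{k}, \qquad
    s_k = s_{k-1} + (x_k - m_{k-1})\,(x_k - m_k), \qquad
    \mathrm{Var}_n = \frac{s_n}{n - 1}
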
/utils/format_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib
3 | import matplotlib.pyplot as plt
4 | import seaborn as sns
5 |
6 |
7 | # Define function for string formatting of scientific notation
8 | def sci_notation(num, decimal_digits=1, precision=None, exponent=None):
9 | """
10 | Returns a string representation of the scientific
11 | notation of the given number formatted for use with
12 | LaTeX or Mathtext, with specified number of significant
13 | decimal digits and precision (number of decimal digits
14 | to show). The exponent to be used can also be specified
15 | explicitly.
16 | """
17 | if exponent is None:
18 | exponent = int(np.floor(np.log10(abs(num))))
19 | coeff = round(num / float(10 ** exponent), decimal_digits)
20 | if precision is None:
21 | precision = decimal_digits
22 |
23 | return r"${0:.{2}f}\times 10^{{{1:d}}}$".format(coeff, exponent, precision)
24 |
25 |
26 | # def convert_to_exponent_of_ten(v):
27 | # if v == 0:
28 | # return v
29 | # else:
30 | # try:
31 | # return '$10^{' + str(int(np.log10(v))) + '}$'
32 | # # return '{:.1e}'.format(v)
33 | # except:
34 | # return v
35 |
36 | def format_matplotlib(text_fontsize=10, default_fontsize=14, legend_fontsize=10):
37 | sns.set(style='darkgrid')
38 | # plt.style.use('tableau-colorblind10')
39 | font = {
40 |         'family': 'sans-serif',  # 'normal' is a weight, not a valid matplotlib font family
41 | 'weight': 'bold'
42 | }
43 | matplotlib.rc('font', **font)
44 | params = {'axes.labelsize': default_fontsize, 'axes.titlesize': default_fontsize, 'font.size': text_fontsize,
45 | 'legend.fontsize': legend_fontsize,
46 | 'xtick.labelsize': default_fontsize, 'ytick.labelsize': default_fontsize}
47 | matplotlib.rcParams.update(params)
48 |
49 |
50 | format_matplotlib()
51 |
--------------------------------------------------------------------------------
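
Tracing sci_notation through an example: for 0.00123 the exponent is floor(log10(0.00123)) = -3 and the coefficient rounds to 1.23, so:

    from utils.format_utils import sci_notation

    print(sci_notation(0.00123, decimal_digits=2))  # $1.23\times 10^{-3}$
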
/datasets/vqa/base_vqa_dataset.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # OpenVQA
3 | # Written by Yuhao Cui https://github.com/cuiyuhao1996
4 | # --------------------------------------------------------
5 |
6 | import numpy as np
7 | import glob, json, torch, random
8 | import torch.utils.data as Data
9 | import torch.nn as nn
10 | from datasets.vqa.feat_filter import feat_filter
11 | import logging
12 |
13 |
14 | class BaseVQADataset(Data.Dataset):
15 | def __init__(self, load_visual_feats=True):
16 | self.token_to_ix = None
17 | self.pretrained_emb = None
18 | self.ans_to_ix = None
19 | self.ix_to_ans = None
20 |
21 | self.data_size = None
22 | self.token_size = None
23 | self.ans_size = None
24 | self.load_visual_feats = load_visual_feats
25 |
26 | def load_ques_ans(self, idx):
27 | raise NotImplementedError()
28 |
29 | def load_img_feats(self, idx, iid):
30 | raise NotImplementedError()
31 |
32 | def __getitem__(self, idx):
33 | vqa = self.load_ques_ans(idx)
34 | if self.load_visual_feats:
35 | frcn_feat_iter, bbox_feat_iter = self.load_img_feats(idx, vqa['image_id'])
36 | vqa['frcn_feat'] = torch.from_numpy(frcn_feat_iter)
37 | vqa['bbox_feat'] = torch.from_numpy(bbox_feat_iter)
38 | vqa['dataset_ix'] = idx
39 | vqa['question_token_ixs'] = torch.from_numpy(vqa['question_token_ixs'])
40 | vqa['y'] = torch.from_numpy(vqa['ans_iter'])
41 | # Adding 'group_name' suffix
42 | vqa['answer_group_name'] = vqa['answer']
43 | return vqa
44 | # return {
45 | # 'frcn_feat': torch.from_numpy(frcn_feat_iter),
46 | # 'bbox_feat': torch.from_numpy(bbox_feat_iter),
47 | # 'question_id': q_details['question_id'],
48 | # 'ques_ix_iter': torch.from_numpy(q_details['ques_ix_iter']),
49 | # 'answer': q_details['ans'],
50 | # 'y': torch.from_numpy(q_details['ans_iter']),
51 | # 'dataset_ix': idx,
52 | # 'local_group_name': q_details['local_grp_name'],
53 | # 'group_name': q_details['group_name'],
54 | # 'group_ix': q_details['group_ix']
55 | # }
56 |
57 | def __len__(self):
58 | return self.data_size
59 |
60 |     def shuffle_list(self, items):
61 |         random.shuffle(items)
62 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
  2 | import os, random  # os is used directly by backend_setting below
3 | import torch
4 | from option import get_option
5 | from trainers import trainer_factory
6 | from utils.trainer_utils import save_option, initialize_logger
7 | import logging
8 | from datasets import dataloader_factory
  9 | import json, numpy as np  # np.random.seed is called below
10 | from experiments.celebA_experiments import *
11 | from experiments.biased_mnist_experiments import *
12 | from experiments.gqa_experiments import *
13 |
14 |
15 | def backend_setting(option):
16 | # Initialize the expt_dir where all the results (predictions, checkpoints, logs, metrics) will be saved
17 | if option.expt_dir is None:
18 | option.expt_dir = os.path.join(option.save_dir, option.expt_name)
19 |
20 | if not os.path.exists(option.expt_dir):
21 | os.makedirs(option.expt_dir)
22 |
23 | # Configure the logger
24 | initialize_logger(option.expt_dir)
25 |
26 | # Set the random seeds
27 | if option.random_seed is None:
28 | option.random_seed = random.randint(1, 10000)
29 | random.seed(option.random_seed)
30 | torch.manual_seed(option.random_seed)
31 | torch.cuda.manual_seed_all(option.random_seed)
32 | np.random.seed(option.random_seed)
33 |
34 | if torch.cuda.is_available() and not option.cuda:
 35 |         logging.warning('GPU is available, but we are not using it!')
36 |
37 | if not torch.cuda.is_available() and option.cuda:
38 | option.cuda = False
39 |
40 | # Dataset specific settings
41 | set_if_null(option, 'bias_loss_gamma', 0.7)
42 | set_if_null(option, 'bias_ema_gamma', 0.7)
43 |
44 |
45 | def set_if_null(option, attr_name, val):
46 | if not hasattr(option, attr_name) or getattr(option, attr_name) is None:
47 | setattr(option, attr_name, val)
48 |
49 | def main():
50 | option = get_option()
51 | if option.project_name is None:
52 | option.project_name = option.dataset_name
53 | if option.expt_type is not None:
54 | eval(option.expt_type)(option, run)
55 | else:
56 | run(option)
57 |
58 |
59 | def run(option):
60 | backend_setting(option)
61 | data_loaders = dataloader_factory.build_dataloaders(option)
62 | if 'gqa' in option.dataset_name.lower():
63 | option.bias_variable_dims = option.num_groups
64 | option.num_bias_classes = option.num_groups
65 |
66 | save_option(option)
67 | logging.getLogger().info(json.dumps(option.__dict__, indent=4, sort_keys=True,
68 | default=lambda o: f"<>"))
69 |
70 | trainer = trainer_factory.build_trainer(option)
71 | trainer.train(data_loaders['Train'], data_loaders['Test'], data_loaders['Unbalanced Train'])
72 |
73 |
74 | if __name__ == '__main__':
75 | main()
76 |
--------------------------------------------------------------------------------
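
main() dispatches --expt_type to the matching function from the experiments package (e.g. gqa_experiments), which fills in dataset-specific defaults and then invokes run as a callback. A typical invocation, mirroring scripts/gqa-ood/baseline_gqa.sh (the root_dir path is a placeholder):

    python main.py \
      --expt_type gqa_experiments \
      --lr 1e-4 \
      --weight_decay 0 \
      --trainer_name BaseTrainer \
      --root_dir /path/to/data
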
/models/vqa/ban/ban.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # OpenVQA
3 | # Written by Zhenwei Shao https://github.com/ParadoxZW
4 | # --------------------------------------------------------
5 |
6 | from models.vqa.make_mask import make_mask
7 | from models.vqa.ops.fc import FC, MLP
8 | from models.vqa.ops.layer_norm import LayerNorm
9 | from models.vqa.ban._ban import _BAN
10 | from models.vqa.vqa_adapter import VQAAdapter
11 |
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | from torch.nn.utils.weight_norm import weight_norm
15 | import torch
16 |
17 |
18 | # -------------------------
19 | # ---- Main BAN Model ----
20 | # -------------------------
21 |
22 | class BAN(nn.Module):
23 | def __init__(self, pretrained_emb, token_size, answer_size,
24 | word_embed_size=300,
25 | img_feat_size=1024,
26 | hidden_size=1024,
27 | k_times=3,
28 | dropout_r=0.2,
29 | classifier_dropout_r=0.5,
30 | glimpse=8,
31 | flat_out_size=2048,
32 | use_glove=True):
33 | super(BAN, self).__init__()
34 |
35 | self.embedding = nn.Embedding(
36 | num_embeddings=token_size,
37 | embedding_dim=word_embed_size
38 | )
39 |
40 | # Loading the GloVe embedding weights
41 | if use_glove:
42 | self.embedding.weight.data.copy_(torch.from_numpy(pretrained_emb))
43 |
44 | self.rnn = nn.GRU(
45 | input_size=word_embed_size,
46 | hidden_size=hidden_size,
47 | num_layers=1,
48 | batch_first=True
49 | )
50 |
51 | self.adapter = VQAAdapter(hidden_size=hidden_size)
52 |
53 | self.backbone = _BAN(img_feat_size,
54 | hidden_size,
55 | k_times,
56 | dropout_r,
57 | classifier_dropout_r,
58 | glimpse)
59 |
60 | # Classification layers
61 | layers = [
62 | weight_norm(nn.Linear(hidden_size, flat_out_size), dim=None),
63 | nn.ReLU(),
64 | nn.Dropout(classifier_dropout_r, inplace=True),
65 | weight_norm(nn.Linear(flat_out_size, answer_size), dim=None)
66 | ]
67 | self.classifier = nn.Sequential(*layers)
68 |
69 | def forward(self, frcn_feat, bbox_feat, ques_ix):
70 | # Pre-process Language Feature
71 | # lang_feat_mask = make_mask(ques_ix.unsqueeze(2))
72 | lang_feat = self.embedding(ques_ix)
73 | lang_feat, _ = self.rnn(lang_feat)
74 | img_feat, _ = self.adapter(frcn_feat, bbox_feat)
75 |
76 | # Backbone Framework
77 | lang_feat = self.backbone(
78 | lang_feat,
79 | img_feat
80 | )
81 |
82 | # Classification layers
83 | proj_feat = self.classifier(lang_feat.sum(1))
84 | return {
85 | 'question_features': lang_feat,
86 | 'logits': proj_feat
87 | }
88 |
89 |
90 | class BANNoDropout(BAN):
91 | def __init__(self, pretrained_emb, token_size, answer_size):
92 | super().__init__(pretrained_emb, token_size, answer_size,
93 | dropout_r=0, classifier_dropout_r=0)
94 |
--------------------------------------------------------------------------------
/datasets/dataloader_factory.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch
3 | from torch.utils.data import DataLoader
4 | import json
5 | from datasets.biased_mnist_dataset import create_biased_mnist_dataloaders
6 | from datasets.celebA_dataset import create_celebA_dataloaders
7 | from datasets.vqa.gqa_dataset import create_gqa_dataloaders
8 | from utils.data_utils import dict_collate_fn
9 |
10 | def build_balanced_loader(dataloader, balanced_sampling_attributes=['y'], balanced_sampling_gamma=1, replacement=True):
11 | logger = logging.getLogger()
12 | all_group_names = []
13 |
14 | # Count frequencies for all groups of attributes to balance,
15 | # and assign each sample to a group, so that we can compute its sampling weight later on
16 | group_name_to_count = {}
17 | for batch in dataloader:
18 | batch_group_names = []
19 | for ix, _ in enumerate(batch['y']):
20 | group_name = ""
21 | for attr in balanced_sampling_attributes:
22 | group_name += f"{attr}_{batch[attr][ix]}_"
23 | batch_group_names.append(group_name)
24 |
25 | for group_name in batch_group_names:
26 | if group_name not in group_name_to_count:
27 | group_name_to_count[group_name] = 0
28 | group_name_to_count[group_name] += 1
29 | all_group_names.append(group_name)
30 |
31 | # Create the balanced loader
32 | weights = []
33 | for val in all_group_names:
34 | weights.append(1 / group_name_to_count[val] ** balanced_sampling_gamma)
35 | weighted_sampler = torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights),
36 | replacement=replacement)
37 | balanced_dataloader = DataLoader(dataloader.dataset, batch_size=dataloader.batch_size, sampler=weighted_sampler,
38 | num_workers=dataloader.num_workers, collate_fn=dataloader.collate_fn)
39 | logger.info(f"Created balanced loader for {len(weights)} samples of dataset size {len(dataloader.dataset)}")
40 | logger.info(f"Group counts: {json.dumps(group_name_to_count, indent=4)}")
41 | return balanced_dataloader
42 |
43 |
44 | def build_dataloaders(option):
45 | dataset_name = option.dataset_name.lower()
46 | if dataset_name == 'biased_mnist_v1':
47 | loaders = create_biased_mnist_dataloaders(option) # Sets the num_groups
48 | elif dataset_name == 'celeba':
49 | loaders = create_celebA_dataloaders(option)
50 |     elif dataset_name == 'gqa':
51 |         loaders = create_gqa_dataloaders(option)
52 |     else:
53 |         raise ValueError(f"Unsupported dataset_name: {option.dataset_name}")
54 |     loaders['Unbalanced Train'] = loaders['Train']
55 |     if option.balanced_sampling_attributes is not None:
56 |         unshuffled_train_loader = DataLoader(loaders['Train'].dataset, batch_size=option.batch_size, shuffle=False,
57 |                                              num_workers=option.num_workers,
58 |                                              collate_fn=dict_collate_fn())
59 |         loaders['Train'] = build_balanced_loader(unshuffled_train_loader,
60 |                                                  option.balanced_sampling_attributes,
61 |                                                  balanced_sampling_gamma=option.balanced_sampling_gamma,
62 |                                                  replacement=True)
63 |     return loaders
64 |
--------------------------------------------------------------------------------
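
build_balanced_loader weights each sample by 1 / count(group)^gamma, so gamma=1 equalizes expected draws across groups and gamma=0 keeps the natural distribution. A toy check of that rule (synthetic labels, not repository data):

    from torch.utils.data import WeightedRandomSampler

    counts = {'A': 900, 'B': 100}
    labels = ['A'] * 900 + ['B'] * 100
    gamma = 1.0
    weights = [1.0 / counts[g] ** gamma for g in labels]
    sampler = WeightedRandomSampler(weights, num_samples=len(labels), replacement=True)
    drawn = [labels[i] for i in sampler]
    print(drawn.count('A'), drawn.count('B'))  # roughly balanced, ~500 each
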
/eval/vqa/gqa_eval_from_file.py:
--------------------------------------------------------------------------------
1 | from eval.vqa.gqa_eval import GQAEval
2 | from eval.vqa.plot_tail import plot_tail_for_one_model
3 | import argparse
4 | import numpy as np
5 | import os.path
6 | import glob
7 | import json
8 | from option import ROOT
9 |
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--data_root', default=f'{ROOT}/GQA')
12 | parser.add_argument('--eval_tail_size', action='store_true')
13 | parser.add_argument('--ood_test', action='store_true')
14 | parser.add_argument('--predictions', type=str)
15 | args = parser.parse_args()
16 |
17 |
18 | def loadFile(name):
19 | # load standard json file
20 | if os.path.isfile(name):
21 | with open(name) as file:
22 | data = json.load(file)
23 | # load file chunks if too big
24 | elif os.path.isdir(name.split(".")[0]):
25 | data = {}
26 | chunks = glob.glob('{dir}/{dir}_*.{ext}'.format(dir=name.split(".")[0], ext=name.split(".")[1]))
27 | for chunk in chunks:
28 | with open(chunk) as file:
29 | data.update(json.load(file))
30 | else:
31 | raise Exception("Can't find {}".format(name))
32 | return data
33 |
34 |
35 | if args.eval_tail_size:
36 | result_eval_file = args.predictions
37 |
38 | # Retrieve scores
39 | alpha_list = [9.0, 7.0, 5.0, 3.6, 2.8, 2.2, 1.8, 1.4, 1.0, 0.8, 0.4, 0.3, 0.2, 0.1, 0.0, -0.1, -0.2, -0.3, -0.4,
40 | -0.5, -0.6, -0.7]
41 | acc_list = []
42 | for alpha in alpha_list:
43 | ques_file_path = f'{args.data_root}/val_bal_tail_%.1f.json' % alpha
44 | predictions = loadFile(result_eval_file)
45 | # predictions = {p["questionId"]: p["prediction"] for p in predictions}
46 |
47 | gqa_eval = GQAEval(predictions, ques_file_path, choices_path=None, EVAL_CONSISTENCY=False)
48 | acc = gqa_eval.get_acc_result()['accuracy']
49 | acc_list.append(acc)
50 |
51 | print("Alpha:", alpha_list)
52 | print("Accuracy:", acc_list)
53 | # Plot: save to "tail_plot_[model_name].pdf"
54 | # plot_tail(alpha=list(map(lambda x: x + 1, alpha_list)), accuracy=acc_list,
55 | # model_name='default') # We plot 1+alpha vs. accuracy
56 | elif args.ood_test:
57 | result_eval_file = args.predictions
58 | file_list = {'Tail': 'ood_testdev_tail.json', 'Head': 'ood_testdev_head.json', 'All': 'ood_testdev_all.json'}
59 | result = {}
60 | for setup, ques_file_path in file_list.items():
61 | predictions = loadFile(result_eval_file)
62 | # predictions = {p["questionId"]: p["prediction"] for p in predictions}
63 |
64 | gqa_eval = GQAEval(predictions, f'{args.data_root}/' + ques_file_path, choices_path=None,
65 | EVAL_CONSISTENCY=False)
66 | result[setup] = gqa_eval.get_acc_result()['accuracy']
67 |
68 | result_string, detail_result_string = gqa_eval.get_str_result()
69 | print('\n___%s___' % setup)
70 | for result_string_ in result_string:
71 | print(result_string_)
72 |
73 | print('\nRESULTS:\n')
74 | msg = 'Accuracy (tail, head, all): %.2f, %.2f, %.2f' % (result['Tail'], result['Head'], result['All'])
75 | print(msg)
76 | # Sample command:
77 | # python eval/vqa/gqa_eval_from_file.py --predictions "{EXPT_ROOT}/gqa_1.1/SpectralDecouplingTrainer/lambda_0.001_gamma_0.001/ans_preds_Val All_Main_epoch_30.json" --eval_tail_size
78 |
--------------------------------------------------------------------------------
/experiments/gqa_experiments.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | def set_if_null(option, attr_name, val):
5 | if not hasattr(option, attr_name) or getattr(option, attr_name) is None:
6 | setattr(option, attr_name, val)
7 |
8 |
9 | def gqa_experiments(option, run):
10 | # Method-specific arguments are defined in the bash files inside: scripts/gqa-ood
11 |
12 |     # Here, we configure the rest of the arguments, which likely DO NOT NEED TO BE CHANGED
13 | option.dataset_name = 'GQA'
14 | option.data_dir = option.root_dir + f"/{option.dataset_name}"
15 |     option.train_ratio = None # set to something like 0.1 for debugging
16 |
17 | # Define the bias and target variables
18 | # option.num_bias_classes: # This is set to num_groups in main.py for GQA
19 | # option.num_classes: # This is set by gqa_dataset
20 | # option.num_groups: # This is set by gqa_dataset
21 |
22 | # Optimizer + Model
23 | set_if_null(option, 'model_name', 'UpDn')
24 | set_if_null(option, 'optimizer_name', 'Adam')
25 | set_if_null(option, 'batch_size', 128)
26 | set_if_null(option, 'epochs', 30)
27 |
28 |     # We allow training on a subset of the data using the train_ratio argument. If it is "None", then the full set is used.
29 | if option.train_ratio is not None:
30 | dataset_name = f"{option.dataset_name}_ratio_{option.train_ratio}"
31 | else:
32 | dataset_name = option.dataset_name
33 |
34 | # Configure name of the experiments and directory to save the results
35 | if option.save_dir is None:
36 | option.save_dir = os.path.join(option.root_dir, option.project_name, dataset_name, option.trainer_name)
37 |
38 | if option.expt_name is None:
39 | option.expt_name = f"lr_{option.lr}_wd_{option.weight_decay}"
40 |
41 | if option.key_to_group_by is not None:
42 | option.expt_name += f'_expl_bias_{option.key_to_group_by}'
43 |
44 | # Method-specific configurations
45 | if option.trainer_name == 'GroupDROTrainer':
46 | option.balanced_sampling_attributes = ['group_ix'] # Perform balanced sampling
47 | option.group_by = 'group_ix' # Groups for GDRO
48 | option.key_to_group_by = 'group_name' # Used to find readable group names
49 |
50 | if option.trainer_name == 'RUBiTrainer':
51 | set_if_null(option, 'bias_model_hid_dims', 2048)
52 |         # We experimented with MLP1 and MLP2 too; the default here is MLP3, based on the preliminary experiments
53 | set_if_null(option, 'bias_model_name', 'MLP3')
54 | set_if_null(option, 'bias_variable_type', 'categorical')
55 |
56 | if option.trainer_name == 'LNLTrainer':
57 |         # We experimented with MLP1 and MLP2 too; the default here is MLP3, based on the preliminary experiments
58 | set_if_null(option, 'bias_predictor_name', 'MLP3')
59 | set_if_null(option, 'bias_predictor_in_layer', 'question_features')
60 | option.bias_predictor_in_dims = 1024
61 | option.bias_predictor_hid_dims = 1024
62 |
63 | if option.trainer_name == 'IRMv1Trainer':
64 | set_if_null(option, 'num_envs_per_batch', 16)
65 | if option.bias_variable_name != 'question_features':
66 | option.bias_variable_type = 'categorical'
67 |
68 | # Test epochs
69 | option.test_every = 15
70 | option.save_every = 30
71 | option.save_model_every = 30
72 |
73 | if option.model_name == 'MCAN':
74 | option.bias_variable_dims = 2048
75 | else:
76 | option.bias_variable_dims = 1024
77 |
78 | run(option)
79 |
--------------------------------------------------------------------------------
/models/vqa/updn/tda.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # OpenVQA
3 | # Written by Zhenwei Shao https://github.com/ParadoxZW
4 | # based on the implementation in https://github.com/hengyuan-hu/bottom-up-attention-vqa
5 | # ELU is chosen as the activation function in non-linear layers because
6 | # experimental results indicate that it outperforms ReLU in the BUTD model.
7 | # --------------------------------------------------------
8 |
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from torch.nn.utils.weight_norm import weight_norm
12 | import torch
13 | import math
14 |
15 |
16 | # ------------------------------
17 | # ----- Weight Normal MLP ------
18 | # ------------------------------
19 |
20 | class MLP(nn.Module):
21 | """
22 |     Non-linear fully connected network.
23 | """
24 |
25 | def __init__(self, dims, act='ELU', dropout_r=0.0):
26 | super(MLP, self).__init__()
27 |
28 | layers = []
29 | for i in range(len(dims) - 1):
30 | in_dim = dims[i]
31 | out_dim = dims[i + 1]
32 | if dropout_r > 0:
33 | layers.append(nn.Dropout(dropout_r))
34 | layers.append(weight_norm(nn.Linear(in_dim, out_dim), dim=None))
35 | if act != '':
36 | layers.append(getattr(nn, act)())
37 |
38 | self.mlp = nn.Sequential(*layers)
39 |
40 | def forward(self, x):
41 | return self.mlp(x)
42 |
43 |
44 | # ------------------------------
45 | # ---Top Down Attention Map ----
46 | # ------------------------------
47 |
48 |
49 | class AttnMap(nn.Module):
50 | '''
51 |     Implementation of top-down attention.
52 | '''
53 |
54 | def __init__(self, img_feat_size, hidden_size, dropout_r):
55 | super(AttnMap, self).__init__()
56 | self.linear_q = weight_norm(
57 | nn.Linear(hidden_size, hidden_size), dim=None)
58 | self.linear_v = weight_norm(
59 | nn.Linear(img_feat_size, img_feat_size), dim=None)
60 | self.nonlinear = MLP(
61 | [img_feat_size + hidden_size, hidden_size], dropout_r=dropout_r)
62 | self.linear = weight_norm(nn.Linear(hidden_size, 1), dim=None)
63 |
64 | def forward(self, q, v):
65 | v = self.linear_v(v)
66 | q = self.linear_q(q)
67 | logits = self.logits(q, v)
68 |         w = F.softmax(logits, dim=1)
69 | return w
70 |
71 | def logits(self, q, v):
72 | num_objs = v.size(1)
73 | q = q.unsqueeze(1).repeat(1, num_objs, 1)
74 | vq = torch.cat((v, q), 2)
75 | joint_repr = self.nonlinear(vq)
76 | logits = self.linear(joint_repr)
77 | return logits
78 |
79 |
80 | # ------------------------------
81 | # ---- Attended Joint Map ------
82 | # ------------------------------
83 |
84 |
85 | class TDA(nn.Module):
86 | def __init__(self, img_feat_size, hidden_size, dropout_r):
87 | super(TDA, self).__init__()
88 |
89 | self.v_att = AttnMap(img_feat_size, hidden_size, dropout_r)
90 | self.q_net = MLP([hidden_size, hidden_size])
91 | self.v_net = MLP([img_feat_size, hidden_size])
92 |
93 | def forward(self, q, v):
94 | att = self.v_att(q, v)
95 | atted_v = (att * v).sum(1)
96 | q_repr = self.q_net(q)
97 | v_repr = self.v_net(atted_v)
98 | joint_repr = q_repr * v_repr
99 | return joint_repr
100 |
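For reference, omitting the initial weight-normalized linear projections of $q$ and $v$, `AttnMap` and `TDA` compute $\alpha = \mathrm{softmax}_k\big(w^\top \mathrm{MLP}([v_k; q])\big)$, $\hat{v} = \sum_k \alpha_k v_k$, and $j = \mathrm{MLP}_q(q) \odot \mathrm{MLP}_v(\hat{v})$, where $v_k$ are the $K$ object features, $q$ is the question vector, and $j$ is the joint representation returned by `TDA.forward`.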
--------------------------------------------------------------------------------
/utils/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class ClasswiseEMA:
6 |
7 | def __init__(self, dataset_size, alpha=0.7):
8 | self.dataset_size = dataset_size
9 | self.alpha = alpha
10 | self.labels = torch.ones(dataset_size).long().cuda() * -1000
11 | self.parameter = torch.zeros(dataset_size).cuda()
12 | self.updated = torch.zeros(dataset_size).cuda()
13 |
14 | def update(self, data, dataset_ix, label):
15 | # if dataset_ix.__class__ == torch.Tensor:
16 | # dataset_ix = dataset_ix.cuda()
17 | param = self.alpha * self.parameter[dataset_ix] + (
18 | 1 - self.alpha * self.updated[dataset_ix]) * data
19 | self.updated[dataset_ix] = 1
20 | self.labels[dataset_ix] = label
21 | self.parameter[dataset_ix] = param.detach()
22 | return param
23 |
24 | def max_loss(self, label):
25 | label_index = torch.where(self.labels == label)[0]
26 | if len(label_index) == 0:
27 | return 1
28 | else:
29 | return self.parameter[label_index].max()
30 |
31 |
32 | class EMA:
33 | def __init__(self, dataset_size, alpha=0.9):
34 | self.dataset_size = dataset_size
35 | self.alpha = alpha
36 | self.parameter = torch.zeros(dataset_size).cuda()
37 | self.updated = torch.zeros(dataset_size).cuda()
38 |
39 | def update(self, data, dataset_ix):
40 | param = self.alpha * self.parameter[dataset_ix] + (
41 | 1 - self.alpha * self.updated[dataset_ix]) * data
42 | self.updated[dataset_ix] = 1
43 | self.parameter[dataset_ix] = param.detach()
44 | return param
45 |
46 |
47 | class WeightsEMA:
48 | """Exponential moving average of model parameters.
49 | https://anmoljoshi.com/Pytorch-Dicussions/
50 | Args:
51 | model (torch.nn.Module): Model with parameters whose EMA will be kept.
52 | decay (float): Decay rate for exponential moving average.
53 | """
54 |
55 | def __init__(self, model, decay=0.999):
56 | self.decay = decay
57 | self.shadow = {}
58 | self.original = {}
59 |
60 | # Register model parameters
61 | for name, param in model.named_parameters():
62 | if param.requires_grad:
63 | self.shadow[name] = param.data.clone()
64 |
65 | def __call__(self, model, num_updates):
66 | decay = min(self.decay, (1.0 + num_updates) / (10.0 + num_updates))
67 | for name, param in model.named_parameters():
68 | if param.requires_grad:
69 | assert name in self.shadow
70 | new_average = \
71 | (1.0 - decay) * param.data + decay * self.shadow[name]
72 | self.shadow[name] = new_average.clone()
73 |
74 | def assign(self, model):
75 | """Assign exponential moving average of parameter values to the
76 | respective parameters.
77 | Args:
78 | model (torch.nn.Module): Model to assign parameter values.
79 | """
80 | for name, param in model.named_parameters():
81 | if param.requires_grad:
82 | assert name in self.shadow
83 | self.original[name] = param.data.clone()
84 | param.data = self.shadow[name]
85 |
86 | def resume(self, model):
87 | """Restore original parameters to a model. That is, put back
88 | the values that were in each parameter at the last call to `assign`.
89 | Args:
90 | model (torch.nn.Module): Model to assign parameter values.
91 | """
92 | for name, param in model.named_parameters():
93 | if param.requires_grad:
94 | assert name in self.shadow
95 | param.data = self.original[name]
96 |
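Note: in `ClasswiseEMA.update` and `EMA.update`, the `updated` mask makes the first update an assignment: while `updated[ix] == 0`, the coefficient on `data` is `1 - alpha * 0 = 1` and the stored parameter is still zero, so `param == data`; afterwards the update is the usual `alpha * old + (1 - alpha) * data`. A minimal usage sketch (hypothetical values; a GPU is assumed, since the buffers are allocated with `.cuda()`):

```python
import torch

from utils.ema import EMA

ema = EMA(dataset_size=10, alpha=0.9)
idxs = torch.tensor([0, 1, 2])                 # dataset indices of a batch
losses = torch.tensor([0.5, 1.0, 2.0]).cuda()  # e.g., per-sample losses

smoothed = ema.update(losses, idxs)  # first call: returns `losses` as-is
smoothed = ema.update(losses, idxs)  # later calls: 0.9 * old + 0.1 * new
```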
--------------------------------------------------------------------------------
/models/vqa/updn/net.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # OpenVQA
3 | # Written by Zhenwei Shao https://github.com/ParadoxZW
4 | # --------------------------------------------------------
5 |
6 | from models.vqa.updn.tda import TDA
7 | from models.vqa.vqa_adapter import VQAAdapter
8 |
9 | import torch.nn as nn
10 | from torch.nn.utils.weight_norm import weight_norm
11 | import torch
12 |
13 |
14 | # -------------------------
15 | # ---- Main BUTD Model ----
16 | # -------------------------
17 |
18 | class UpDn(nn.Module):
19 | def __init__(self,
20 | pretrained_emb,
21 | token_size,
22 | answer_size,
23 | img_feat_size=1024,
24 | word_embed_size=300,
25 | hidden_size=1024,
26 | use_glove=True,
27 | flat_out_size=2048,
28 | dropout_r=0.2,
29 | classifier_dropout_r=0.5
30 | ):
31 | super(UpDn, self).__init__()
32 |
33 | self.embedding = nn.Embedding(
34 | num_embeddings=token_size,
35 | embedding_dim=word_embed_size
36 | )
37 |
38 | # Loading the GloVe embedding weights
39 | if use_glove:
40 | self.embedding.weight.data.copy_(torch.from_numpy(pretrained_emb))
41 |
42 | self.rnn = nn.LSTM(
43 | input_size=word_embed_size,
44 | hidden_size=hidden_size,
45 | num_layers=1,
46 | batch_first=True
47 | )
48 |
49 | self.adapter = VQAAdapter(hidden_size=hidden_size)
50 |
51 | self.backbone = TDA(img_feat_size=hidden_size,
52 | hidden_size=hidden_size,
53 | dropout_r=dropout_r
54 | )
55 |
56 | # Classification layers
57 | self.fc1 = weight_norm(nn.Linear(hidden_size, flat_out_size), dim=None)
58 | self.relu = nn.ReLU()
59 | self.dropout = nn.Dropout(classifier_dropout_r, inplace=True)
60 | self.classifier = weight_norm(nn.Linear(flat_out_size, answer_size), dim=None)
61 | # layers = [
62 | # weight_norm(nn.Linear(hidden_size,
63 | # flat_out_size), dim=None),
64 | # nn.ReLU(),
65 | # nn.Dropout(classifier_dropout_r, inplace=True),
66 | # weight_norm(nn.Linear(flat_out_size, answer_size), dim=None)
67 | # ]
68 | # self.classifier = nn.Sequential(*layers)
69 |
70 | def forward(self, frcn_feat, bbox_feat, question_token_ixs):
71 | # Pre-process Language Feature
72 | # lang_feat_mask = make_mask(ques_ix.unsqueeze(2))
73 | lang_feat = self.embedding(question_token_ixs)
74 | lang_feat, _ = self.rnn(lang_feat)
75 |
76 | img_feat, _ = self.adapter(frcn_feat, bbox_feat)
77 |
78 | # Backbone Framework
79 | joint_feat = self.backbone(
80 | lang_feat[:, -1],
81 | img_feat
82 | )
83 |
84 | # Classification layers
85 | # proj_feat = self.classifier(joint_feat)
86 | fc1 = self.fc1(joint_feat)
87 | clf_in = self.dropout(self.relu(fc1))
88 | logits = self.classifier(clf_in)
89 |
90 | return {
91 | 'logits': logits,
92 | 'before_logits': fc1,
93 | 'question_features': lang_feat[:, -1],
94 | 'visual_features': img_feat,
95 | 'joint_features': joint_feat
96 | }
97 |
98 |
99 | class UpDnNoDropout(UpDn):
100 | def __init__(self,
101 | pretrained_emb, token_size, answer_size):
102 | super().__init__(pretrained_emb, token_size, answer_size, dropout_r=0, classifier_dropout_r=0)
103 |
--------------------------------------------------------------------------------
/trainers/group_upweighting_trainer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import logging
3 |
4 | from trainers.base_trainer import BaseTrainer
5 | from utils.losses import *
6 | from utils.metric_visualizer import LossVisualizer
7 |
8 |
9 | class GroupUpweightingTrainer(BaseTrainer):
10 | """
11 |     Simple upweighting technique, which multiplies the loss by the inverse group frequency.
12 |     This has been found to work well when models are sufficiently underparameterized (e.g., low learning rate, high weight decay, fewer model parameters, etc.)
13 |     Paper investigating underparameterization with the upweighting method: https://arxiv.org/abs/2005.04345
14 | """
15 |
16 | def __init__(self, option):
17 | super(GroupUpweightingTrainer, self).__init__(option)
18 | self.loss_visualizer = LossVisualizer(self.option.expt_dir)
19 |
20 | def get_keys_to_save(self):
21 | return super().get_keys_to_save() + ['group_ix_to_cnt', 'group_ix_to_weight', 'group_name_to_weight']
22 |
23 | def init_group_weights(self, loader):
24 | logging.getLogger().info("Initializing the group weights...")
25 | self.group_ix_to_cnt = {}
26 | self.group_ix_to_weight, self.group_name_to_weight = {}, {}
27 | self.group_ix_to_name = {}
28 | total_samples = 0
29 | for batch in loader:
30 |             # If 'explicit_group_ix' is present in the batch, use it to group the data;
31 |             # otherwise, use the 'group_ix' key
32 | if 'explicit_group_ix' in batch:
33 | grp_key = 'explicit_group_ix'
34 | grp_name_key = 'explicit_group_name'
35 | else:
36 | grp_key = 'group_ix'
37 | grp_name_key = 'group_name'
38 | for grp_ix, grp_name in zip(batch[grp_key], batch[grp_name_key]):
39 | grp_ix = int(grp_ix)
40 | if grp_ix not in self.group_ix_to_cnt:
41 | self.group_ix_to_cnt[grp_ix] = 0
42 | self.group_ix_to_cnt[grp_ix] += 1
43 | self.group_ix_to_name[grp_ix] = grp_name
44 | total_samples += 1
45 | for group_ix in self.group_ix_to_cnt:
46 | self.group_ix_to_weight[group_ix] = total_samples / self.group_ix_to_cnt[group_ix]
47 | self.group_name_to_weight[self.group_ix_to_name[group_ix]] \
48 | = total_samples / self.group_ix_to_cnt[group_ix]
49 |
50 | def train(self, train_loader, test_loaders=None, unbalanced_train_loader=None):
51 | self.init_group_weights(train_loader)
52 | super().train(train_loader, test_loaders, unbalanced_train_loader)
53 |
54 | def _train_epoch(self, epoch, data_loader):
55 | self._mode_setting(is_train=True)
56 |
57 | for i, batch in enumerate(data_loader):
58 | batch = self.prepare_batch(batch)
59 | out = self.forward_model(self.model, batch)
60 | logits = out['logits']
61 | batch_losses = self.loss(out['logits'], torch.squeeze(batch['y']))
62 | group_key = 'explicit_group_ix' if 'explicit_group_ix' in batch else 'group_ix'
63 |
64 | weights = torch.FloatTensor([self.group_ix_to_weight[group_ix] for group_ix in batch[group_key]]).cuda()
65 | weighted_batch_losses = weights * batch_losses # Multiply per-sample losses by weights for the corresponding groups
66 |
67 | self.optim.zero_grad()
68 | weighted_batch_losses.mean().backward()
69 | self.optim.step()
70 | self.loss_visualizer.update('Train', 'Loss', batch_losses.mean().detach().item())
71 | self.loss_visualizer.update('Train', 'Weighted Loss', weighted_batch_losses.mean().detach().item())
72 | self.update_generalization_metrics('Train', batch, weighted_batch_losses)
73 | self._after_train_epoch(epoch)
74 |
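For reference, the weighting used above: a group $g$ that accounts for $n_g$ of the $N$ training samples receives weight $w_g = N / n_g$, and `_train_epoch` minimizes the weighted mean $\frac{1}{B} \sum_{i=1}^{B} w_{g(i)} \ell_i$ over each batch (cf. `group_ix_to_weight` and `weighted_batch_losses.mean()`).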
--------------------------------------------------------------------------------
/trainers/spectral_decoupling_trainer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import logging
3 | import math
4 | import os
5 |
6 | import torch
7 | from torch import nn
8 | from torch import optim
9 |
10 | from models.model_factory import build_model
11 | from utils import trainer_utils
12 | from utils.metric_visualizer import AccuracyVisualizer, LossVisualizer
13 | from utils.metrics import Accuracy, GroupWiseAccuracy
14 | from torch.optim import *
15 | import json
16 | from copy import deepcopy
17 | from utils.trainer_utils import create_optimizer
18 | from torch.nn import *
19 | from utils.losses import *
20 | from trainers.base_trainer import BaseTrainer
21 |
22 |
23 | class SpectralDecouplingTrainer(BaseTrainer):
24 | """
25 | Implementation for:
26 | Pezeshki, Mohammad, et al. "Gradient Starvation: A Learning Proclivity in Neural Networks." arXiv preprint arXiv:2011.09468 (2020).
27 |
28 |     The paper shows that decaying and shifting the network's logits can help decouple the learning of features, which may enable learning the signal as well.
29 |
30 |     Changes from the original implementation:
31 |     The original implementation uses this loss: torch.log(1.0 + torch.exp(-yhat[:, 0] * (2.0 * y - 1.0)))
32 |     However, we just use cross-entropy, since the above formulation doesn't make sense when there are more than 2 classes.
33 | """
34 |
35 | def __init__(self, option):
36 | super(SpectralDecouplingTrainer, self).__init__(option)
37 | self.loss_visualizer = LossVisualizer(self.option.expt_dir)
38 | if self.option.spectral_decoupling_lambdas is None:
39 | assert self.option.spectral_decoupling_lambda is not None, 'lambda not specified'
40 | self.option.spectral_decoupling_lambdas = torch.ones(
41 | self.option.num_classes) * self.option.spectral_decoupling_lambda
42 | if self.option.spectral_decoupling_gammas is None:
43 | assert self.option.spectral_decoupling_gamma is not None, 'gammas not specified'
44 | self.option.spectral_decoupling_gammas = torch.ones(
45 | self.option.num_classes) * self.option.spectral_decoupling_gamma
46 |
47 | def _train_epoch(self, epoch, data_loader):
48 | self._mode_setting(is_train=True)
49 |
50 | for i, batch in enumerate(data_loader):
51 | # Forward pass
52 | batch = self.prepare_batch(batch)
53 | out = self.forward_model(self.model, batch)
54 | logits = out['logits']
55 |
56 | # Compute the prediction loss
57 | # The original paper uses this formulation:
58 | # per_sample_losses = torch.log(1.0 + torch.exp(-yhat[:, 0] * (2.0 * y - 1.0)))
59 | # However, we just use cross-entropy since the above formulation doesn't make sense when there are more than 2 classes.
60 | pred_losses = self.loss(out['logits'], torch.squeeze(batch['y']))
61 |
62 | per_class_lambdas = torch.FloatTensor(
63 | [self.option.spectral_decoupling_lambdas[y] for y in batch['y']]).cuda()
64 | per_class_gammas = torch.FloatTensor(
65 | [self.option.spectral_decoupling_gammas[y] for y in batch['y']]).cuda()
66 |
67 | # The loss is based on equation 28 of the paper
68 | # softmax = torch.softmax(logits, dim=1)
69 | # gt_softmax = softmax.gather(1, batch['y'].view(-1, 1)).squeeze()
70 | gt_logits = logits.gather(1, batch['y'].view(-1, 1)).squeeze()
71 |
72 | logit_l2_losses = 0.5 * per_class_lambdas * (gt_logits - per_class_gammas) ** 2
73 | total_losses = pred_losses + logit_l2_losses
74 |
75 | self.optim.zero_grad()
76 | total_losses.mean().backward()
77 | self.optim.step()
78 |
79 | self.loss_visualizer.update('Train', 'Total Loss', total_losses.mean().detach().item())
80 | self.loss_visualizer.update('Train', 'Pred Loss', pred_losses.mean().detach().item())
81 | self.loss_visualizer.update('Train', 'Logit L2 Loss', logit_l2_losses.mean().detach().item())
82 |
83 | self._after_train_epoch(epoch)
84 |
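For reference, with per-class $\lambda_y$ and $\gamma_y$ (`spectral_decoupling_lambdas` / `spectral_decoupling_gammas`), the per-sample objective implemented above is $\ell_i = \mathrm{CE}(z_i, y_i) + \frac{\lambda_{y_i}}{2} (z_{i, y_i} - \gamma_{y_i})^2$, where $z_{i, y_i}$ is the ground-truth-class logit (`gt_logits`).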
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## An Investigation of Critical Issues in Bias Mitigation Techniques
2 |
3 | Our paper examines whether state-of-the-art bias mitigation methods perform well in more realistic settings: with multiple sources of bias, with hidden biases, and without access to the test distribution. This repository contains implementations/re-implementations of seven popular techniques.
4 |
5 |
6 | ### Setup
7 |
8 | #### Install Dependencies
9 |
10 | `conda create -n bias_mitigator python=3.7`
11 |
12 | `source activate bias_mitigator`
13 |
14 | `conda install pytorch==1.4.0 torchvision==0.5.0 cudatoolkit=10.1 -c pytorch`
15 |
16 | `conda install tqdm opencv pandas`
17 |
18 | #### Configure Path
19 |
20 | - Edit the `ROOT` variable in `common.sh`. This directory will contain the datasets and the experimental results.
21 |
22 | #### Datasets
23 | - For each dataset, we evaluate on the train, val, and test splits. Each dataset file contains a function that creates dataloaders for all of these splits.
24 |
25 | ##### Biased MNIST v1/v2
26 | Since publication, we have created BiasedMNISTv2, a more challenging version of the dataset. Version 2 includes larger images, spuriously correlated digit scales, distracting letters in place of the simplistic geometric shapes, and updated background textures.
27 |
28 | We encourage the community to use the [BiasedMNIST v2](https://github.com/erobic/occam-nets-v1).
29 |
30 |
31 | You can download the BiasedMNISTv1 (WACV 2021) [from here](https://drive.google.com/file/d/1RlvskdRjdAj6sqpYeD48sR2uJnxAYmv5/view?usp=sharing).
32 |
33 | Both BiasedMNISTv1 and v2 are released under Creative Commons Attribution 4.0 International (CC BY 4.0) license.
34 |
35 |
36 | ##### CelebA
37 | - Download the dataset [from here](https://drive.google.com/drive/folders/0B7EVK8r0v71pWEZsZE9oNnFzTm8) and extract the data to `${ROOT}`
38 | - We adapted the data loader from `https://github.com/kohpangwei/group_DRO`
39 |
40 | ##### GQA-OOD
41 | - Download GQA (object features, spatial features and questions) [from here](https://cs.stanford.edu/people/dorarad/gqa/download.html)
42 | - Build the GQA-OOD by following [these instructions](https://github.com/gqa-ood/GQA-OOD/tree/master/code)
43 |
44 | - Download embeddings
45 | `python -m spacy download en_vectors_web_lg`
46 |
47 | - Preprocess visual/spatial features
48 | `./scripts/gqa-ood/preprocess_gqa.sh`
49 |
50 | #### Run the methods
51 |
52 | We have provided a separate bash file for running each method on each dataset in the `scripts` directory. Here is a sample script:
53 |
54 | ```bash
55 | source activate bias_mitigator
56 |
57 | TRAINER_NAME='BaseTrainer'
58 | lr=1e-3
59 | wd=0
60 | python main.py \
61 | --expt_type celebA_experiments \
62 | --trainer_name ${TRAINER_NAME} \
63 | --lr ${lr} \
64 | --weight_decay ${wd} \
65 | --expt_name ${TRAINER_NAME} \
66 | --root_dir ${ROOT}
67 | ```
68 |
69 | ### Contribute!
70 |
71 | - If you want to add more methods, simply follow one of the implementations inside the `trainers` directory, as in the sketch below.
72 |
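For instance, a minimal custom trainer could look like the following sketch (`MyTrainer` is hypothetical; the `BaseTrainer` hooks mirror those used by the existing trainers):

```python
import torch

from trainers.base_trainer import BaseTrainer


class MyTrainer(BaseTrainer):
    def _train_epoch(self, epoch, data_loader):
        self._mode_setting(is_train=True)
        for i, batch in enumerate(data_loader):
            batch = self.prepare_batch(batch)  # BaseTrainer preprocessing hook
            out = self.forward_model(self.model, batch)
            # self.loss returns per-sample losses; reweight/regularize them
            # here to implement your bias mitigation method
            losses = self.loss(out['logits'], torch.squeeze(batch['y']))
            self.optim.zero_grad()
            losses.mean().backward()
            self.optim.step()
        self._after_train_epoch(epoch)
```

Then map a `--trainer_name` value to the new class in `trainers/trainer_factory.py`.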
73 |
74 | ### Highlights from the paper:
75 |
76 | 1. Overall, methods fail when datasets contain multiple sources of bias, even if they excel in simpler settings with one or two sources of bias (e.g., CelebA).
77 | 
78 |
79 | 2. Methods can exploit both implicit (hidden) and explicit biases.
80 | 
81 |
82 | 3. Methods cannot handle multiple sources of bias even when they are explicitly labeled.
83 | 
84 |
85 | 4. Most methods show high sensitivity to the tuning distribution, especially for minority groups.
86 | 
87 |
88 |
89 | ### Citation
90 | ```
91 | @article{shrestha2021investigation,
92 | title={An investigation of critical issues in bias mitigation techniques},
93 | author={Shrestha, Robik and Kafle, Kushal and Kanan, Christopher},
94 | journal={Workshop on Applications of Computer Vision},
95 | year={2021}
96 | }
97 | ```
98 |
99 | This work was supported in part by the DARPA/SRI Lifelong Learning Machines program [HR0011-18-C-0051], AFOSR grant [FA9550-18-1-0121], and NSF award #1909696.
100 |
--------------------------------------------------------------------------------
/experiments/celebA_experiments.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | def set_if_null(option, attr_name, val):
5 | if not hasattr(option, attr_name) or getattr(option, attr_name) is None:
6 | setattr(option, attr_name, val)
7 |
8 |
9 | def celebA_experiments(option, run):
10 | # Method-specific arguments are mostly defined in the bash files inside: scripts/celebA
11 |
12 |     # Here, we configure the rest of the arguments, which likely DO NOT NEED TO BE CHANGED
13 | option.dataset_name = 'CelebA'
14 | option.data_dir = option.root_dir + f"/{option.dataset_name}"
15 |
16 |     option.train_ratio = None # set to something like 0.1 for debugging
17 |
18 | # Define the bias and target variables
19 | option.target_name = 'Blond_Hair'
20 | option.bias_variables = ['Male']
21 | option.bias_variable_name = 'Male'
22 | option.num_bias_classes = 2
23 | option.num_classes = 2
24 | option.num_groups = 4
25 |
26 | # Optimizer + Model
27 | set_if_null(option, 'optimizer_name', 'SGD')
28 | set_if_null(option, 'batch_size', 128)
29 | set_if_null(option, 'epochs', 50)
30 | set_if_null(option, 'model_name', 'ResNet18')
31 |
32 |     # We allow training on a subset of the data using the train_ratio argument. If it is "None", then the full set is used.
33 | if option.train_ratio is not None:
34 | dataset_name = f"{option.dataset_name}_ratio_{option.train_ratio}"
35 | else:
36 | dataset_name = option.dataset_name
37 |
38 | # Configure name of the experiments and directory to save the results
39 | if option.save_dir is None:
40 | option.save_dir = os.path.join(option.root_dir, option.project_name, dataset_name, 'predict_' + option.target_name,
41 | option.trainer_name)
42 | if option.expt_name is None:
43 | option.expt_name = f"lr_{option.lr}_wd_{option.weight_decay}"
44 |
45 | # Method-specific configurations
46 |
47 | if option.trainer_name == 'GroupDROTrainer':
48 | option.balanced_sampling_attributes = ['group_ix'] # Perform balanced sampling
49 | option.group_by = 'group_ix' # Groups for GDRO
50 | option.key_to_group_by = 'group_name' # Used to find readable group names
51 |
52 | if option.trainer_name == 'RUBiTrainer':
53 | set_if_null(option, 'bias_model_hid_dims', 512)
54 | # We experimented with MLP2 and MLP3 too, but MLP1 had worked best in the preliminary experiments
55 | set_if_null(option, 'bias_model_name', 'MLP1')
56 | set_if_null(option, 'bias_variable_type', 'categorical')
57 | option.bias_variable_dims = 2
58 |
59 | if option.trainer_name == 'LNLTrainer':
60 | # We experimented with MLP2 and MLP3 too, but MLP1 had worked best in the preliminary experiments
61 | feature_dims = get_feature_dims(option.model_name, option.num_classes)
62 | set_if_null(option, 'bias_predictor_name', 'MLP1')
63 | set_if_null(option, 'bias_predictor_in_layer', 'model.pooled2')
64 | option.bias_predictor_in_dims = feature_dims[option.bias_predictor_in_layer]
65 | option.bias_predictor_hid_dims = feature_dims[option.bias_predictor_in_layer]
66 |
67 | if option.trainer_name == 'IRMv1Trainer':
68 | set_if_null(option, 'num_envs_per_batch', 4)
69 |
70 | if option.trainer_name == 'SpectralDecouplingTrainer':
71 | option.spectral_decoupling_lambdas = [10.0, 10.0] # For CelebA, we have per-class lambdas and gammas for SD.
72 | option.spectral_decoupling_gammas = [0.44, 2.5]
73 |
74 | # Test epochs
75 | option.test_epochs = [e for e in range(40, 51)] # Due to instability, we average accuracies over the last 10 epochs
76 | option.test_every = 10 # We further test every 10 epochs
77 | option.save_every = 50
78 |
79 | run(option)
80 |
81 |
82 | def get_feature_dims(model_name, num_classes):
83 | if 'ResNet18' in model_name:
84 | return {
85 | 'model.conv1': 64,
86 | 'model.pooled_conv1': 64,
87 | 'model.layer1.1.conv2': 64,
88 | 'model.pooled1': 64,
89 | 'model.layer2.1.conv2': 128,
90 | 'model.layer2_flattened': 128 * 28 * 28,
91 | 'model.pooled2': 128,
92 | 'model.layer3.1.conv2': 256,
93 | 'model.pooled3': 256,
94 | 'model.layer4.1.conv2': 512,
95 | 'model.pooled4': 512,
96 | 'model.fc': 10,
97 | 'logits': num_classes
98 | }
99 |
--------------------------------------------------------------------------------
/datasets/vqa/gqa_feat_preproc.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # OpenVQA
3 | # GQA spatial features & object features .h5 files to .npz files transform script
4 | # Written by Pengbing Gao https://github.com/nbgao
5 | # --------------------------------------------------------
6 |
7 | '''
8 | Command line example:
9 | (1) Process spatial features
10 | python gqa_feat_preproc.py --mode=spatial --spatial_dir=./spatialFeatures --out_dir=./feats/gqa-grid
11 |
12 | (2) Process object features
13 | python gqa_feat_preproc.py --mode=object --object_dir=./objectFeatures --out_dir=./feats/gqa-frcn
14 | '''
15 |
16 | import h5py, glob, json, cv2, argparse
17 | import numpy as np
18 |
19 |
20 | # spatial features
21 | def process_spatial_features(feat_path, out_path):
22 | info_file = feat_path + '/gqa_spatial_info.json'
23 | try:
24 | info = json.load(open(info_file, 'r'))
25 |     except Exception:
26 | print('Failed to open info file:', info_file)
27 | return
28 | print('Total grid features', len(info))
29 |
30 |     print('Making the h5 index to image id dict...')
31 | h5idx_to_imgid = {}
32 | for img_id in info:
33 | h5idx_to_imgid[str(info[img_id]['file']) + '_' + str(info[img_id]['idx'])] = img_id
34 |
35 | for ix in range(16):
36 | feat_file = feat_path + '/gqa_spatial_' + str(ix) + '.h5'
37 | print('Processing', feat_file)
38 | try:
39 | feat_dict = h5py.File(feat_file, 'r')
40 |         except Exception:
41 | print('Failed to open feat file:', feat_file)
42 | return
43 |
44 | features = feat_dict['features']
45 |
46 | for iy in range(features.shape[0]):
47 | img_id = h5idx_to_imgid[str(ix) + '_' + str(iy)]
48 | feature = features[iy]
49 | # save to .npz file ['x']
50 | np.savez(
51 | out_path + '/' + img_id + '.npz',
52 | x=feature.reshape(2048, 49).transpose(1, 0), # (49, 2048)
53 | )
54 |
55 |     print('Processed spatial features successfully!')
56 |
57 |
58 | # object features
59 | def process_object_features(feat_path, out_path):
60 | info_file = feat_path + '/gqa_objects_info.json'
61 | try:
62 | info = json.load(open(info_file, 'r'))
63 |     except Exception:
64 | print('Failed to open info file:', info_file)
65 | return
66 | print('Total frcn features', len(info))
67 |
68 |     print('Making the h5 index to image id dict...')
69 | h5idx_to_imgid = {}
70 | for img_id in info:
71 | h5idx_to_imgid[str(info[img_id]['file']) + '_' + str(info[img_id]['idx'])] = img_id
72 |
73 | for ix in range(16):
74 | feat_file = feat_path + '/gqa_objects_' + str(ix) + '.h5'
75 | print('Processing', feat_file)
76 |
77 | try:
78 | feat_dict = h5py.File(feat_file, 'r')
79 |         except Exception:
80 | print('Failed to open feat file:', feat_file)
81 | return
82 |
83 | bboxes = feat_dict['bboxes']
84 | features = feat_dict['features']
85 |
86 | for iy in range(features.shape[0]):
87 | img_id = h5idx_to_imgid[str(ix) + '_' + str(iy)]
88 | img_info = info[img_id]
89 | objects_num = img_info['objectsNum']
90 | # save to .npz file ['x', 'bbox', 'width', 'height']
91 | np.savez(
92 | out_path + '/' + img_id + '.npz',
93 | x=features[iy, :objects_num],
94 | bbox=bboxes[iy, :objects_num],
95 | width=img_info['width'],
96 | height=img_info['height'],
97 | )
98 |
99 |     print('Processed object features successfully!')
100 |
101 |
102 | if __name__ == "__main__":
103 | parser = argparse.ArgumentParser(description='gqa_h52npz')
104 | parser.add_argument('--mode', '-mode', choices=['object', 'spatial', 'frcn', 'grid'], help='mode', type=str)
105 | parser.add_argument('--object_dir', '-object_dir', help='object features dir', type=str)
106 | parser.add_argument('--spatial_dir', '-spatial_dir', help='spatial features dir', type=str)
107 | parser.add_argument('--out_dir', '-out_dir', help='output dir', type=str)
108 |
109 | args = parser.parse_args()
110 |
111 | mode = args.mode
112 | object_path = args.object_dir
113 | spatial_path = args.spatial_dir
114 | out_path = args.out_dir
115 |
116 | print('mode:', mode)
117 | print('object_path:', object_path)
118 | print('spatial_path:', spatial_path)
119 | print('out_path:', out_path)
120 |
121 | # process spatial features
122 | if mode in ['spatial', 'grid']:
123 | process_spatial_features(spatial_path, out_path)
124 |
125 | # process object features
126 | if mode in ['object', 'frcn']:
127 | process_object_features(object_path, out_path)
128 |
--------------------------------------------------------------------------------
/trainers/group_dro_trainer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import logging
3 |
4 | from trainers.base_trainer import BaseTrainer
5 | from utils.losses import *
6 | from utils.metric_visualizer import LossVisualizer
7 |
8 |
9 | class GroupDROTrainer(BaseTrainer):
10 | """
11 | Implementation for:
12 | Sagawa, Shiori, et al. "Distributionally robust neural networks for group shifts: On the importance of regularization for worst-case generalization." (ICLR 2020).
13 | GroupDRO groups the data using explicit bias variables and class labels and optimizes for the worst-case group.
14 | Paper: https://arxiv.org/pdf/1911.08731.pdf
15 | Original codebase: https://github.com/kohpangwei/group_DRO
16 | """
17 | def __init__(self, option):
18 | super(GroupDROTrainer, self).__init__(option)
19 | self.group_weights = torch.ones(self.option.num_groups).cuda() / self.option.num_groups
20 | self.loss_visualizer = LossVisualizer(self.option.expt_dir)
21 | self.group_names = None
22 |
23 | def get_keys_to_save(self):
24 | return super().get_keys_to_save() + ['group_weights']
25 |
26 | def compute_group_loss_and_counts(self, losses, group_idxs):
27 | """
28 | Computes groupwise loss and count
29 | :param losses:
30 | :param group_idxs:
31 | :return:
32 | """
33 | group_map = (torch.LongTensor(group_idxs).cuda() == torch.arange(self.option.num_groups).unsqueeze(
34 | 1).long().cuda()).float()
35 | group_count = group_map.sum(1)
36 | group_denom = group_count + (group_count == 0).float() # avoid nans
37 | group_loss = (group_map @ losses.view(-1)) / group_denom
38 | return group_loss, group_count
39 |
40 | def update_group_weights_and_compute_robust_loss(self, group_loss):
41 | """
42 | Use exponent of weighted loss to update the group weights. Uses those weights to compute overall robust loss
43 | :param group_loss:
44 | :return:
45 | """
46 | self.group_weights = self.group_weights * torch.exp(self.option.group_weight_step_size * group_loss.data)
47 | self.group_weights = self.group_weights / (self.group_weights.sum())
48 | robust_loss = group_loss @ self.group_weights.detach()
49 | return robust_loss, self.group_weights
50 |
51 | def _train_epoch(self, epoch, data_loader):
52 | self._mode_setting(is_train=True)
53 | for i, batch in enumerate(data_loader):
54 | batch = self.prepare_batch(batch)
55 | out = self.forward_model(self.model, batch)
56 | logits = out['logits']
57 | batch_losses = self.loss(out['logits'], torch.squeeze(batch['y']))
58 |
59 | # Compute the GDRO loss
60 | group_ixs = batch[self.option.group_by]
61 | group_loss, _ = self.compute_group_loss_and_counts(batch_losses, group_ixs)
62 | robust_loss, _ = self.update_group_weights_and_compute_robust_loss(group_loss)
63 |
64 | self.optim.zero_grad()
65 | robust_loss.backward(retain_graph=True)
66 | self.optim.step()
67 |
68 | self.loss_visualizer.update('Train', 'Group Loss', group_loss.mean().detach().item())
69 | self.loss_visualizer.update('Train', 'Robust Loss', robust_loss.detach().item())
70 | self.update_generalization_metrics('Train', batch, batch_losses)
71 |
72 | group_name_to_weights = {}
73 | for gix in self.dro_group_ix_to_name:
74 | group_name = self.dro_group_ix_to_name[gix]
75 | group_name_to_weights[group_name] = float(self.group_weights[gix])
76 |
77 | self.loss_visualizer.update_multiple('Train Group Weights', group_name_to_weights)
78 | if self.option.enable_groupwise_metrics:
79 | self.update_groupwise_values('Train', 'Group Loss', group_loss, batch)
80 | self.update_groupwise_values('Train', 'Robust Loss', group_loss, batch)
81 | self._after_train_epoch(epoch)
82 |
83 | def train(self, train_loader, test_loaders=None, unbalanced_train_loader=None):
84 | logging.getLogger().info("Beginning the training process...")
85 | self.compute_max_dataset_ixs(train_loader, test_loaders)
86 | self._initialization()
87 |
88 | # Gather group names
89 | self.dro_group_ix_to_name = {}
90 | for batch in train_loader:
91 | for group_ix, group_name in zip(batch[self.option.group_by], batch[self.option.key_to_group_by]):
92 | self.dro_group_ix_to_name[group_ix] = group_name
93 |
94 | self._mode_setting(is_train=True)
95 | start_epoch = 1
96 | for epoch in range(start_epoch, self.option.epochs + 1):
97 | self._train_epoch(epoch, train_loader)
98 | self._after_one_epoch(epoch, test_loaders)
99 | self.after_all_epochs()
100 |
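For reference, `update_group_weights_and_compute_robust_loss` is the exponentiated-gradient step from the paper, with $\eta$ = `group_weight_step_size` and $\ell_g$ the per-group mean batch loss from `compute_group_loss_and_counts`: $w_g \leftarrow w_g \exp(\eta \ell_g) / \sum_{g'} w_{g'} \exp(\eta \ell_{g'})$, followed by $\mathcal{L}_{\mathrm{robust}} = \sum_g w_g \ell_g$.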
--------------------------------------------------------------------------------
/datasets/vqa/ans_punct.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # OpenVQA
3 | # Written by Yuhao Cui https://github.com/cuiyuhao1996
4 | # based on VQA Evaluation Code
5 | # --------------------------------------------------------
6 |
7 | import re
8 |
9 | contractions = {
10 | "aint": "ain't", "arent": "aren't", "cant": "can't", "couldve":
11 | "could've", "couldnt": "couldn't", "couldn'tve": "couldn't've",
12 | "couldnt've": "couldn't've", "didnt": "didn't", "doesnt":
13 | "doesn't", "dont": "don't", "hadnt": "hadn't", "hadnt've":
14 | "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent":
15 | "haven't", "hed": "he'd", "hed've": "he'd've", "he'dve":
16 | "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll",
17 | "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", "Im":
18 | "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've":
19 | "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's",
20 | "maam": "ma'am", "mightnt": "mightn't", "mightnt've":
21 | "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've",
22 | "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't",
23 | "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't",
24 | "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat":
25 | "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve":
26 | "she'd've", "she's": "she's", "shouldve": "should've", "shouldnt":
27 | "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve":
28 |     "shouldn't've", "somebodyd": "somebody'd", "somebodyd've":
29 | "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll":
30 | "somebody'll", "somebodys": "somebody's", "someoned": "someone'd",
31 | "someoned've": "someone'd've", "someone'dve": "someone'd've",
32 | "someonell": "someone'll", "someones": "someone's", "somethingd":
33 | "something'd", "somethingd've": "something'd've", "something'dve":
34 | "something'd've", "somethingll": "something'll", "thats":
35 | "that's", "thered": "there'd", "thered've": "there'd've",
36 | "there'dve": "there'd've", "therere": "there're", "theres":
37 | "there's", "theyd": "they'd", "theyd've": "they'd've", "they'dve":
38 | "they'd've", "theyll": "they'll", "theyre": "they're", "theyve":
39 | "they've", "twas": "'twas", "wasnt": "wasn't", "wed've":
40 | "we'd've", "we'dve": "we'd've", "weve": "we've", "werent":
41 | "weren't", "whatll": "what'll", "whatre": "what're", "whats":
42 | "what's", "whatve": "what've", "whens": "when's", "whered":
43 | "where'd", "wheres": "where's", "whereve": "where've", "whod":
44 | "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl":
45 | "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll",
46 | "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve":
47 | "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've",
48 | "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll":
49 | "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've",
50 | "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd":
51 | "you'd", "youd've": "you'd've", "you'dve": "you'd've", "youll":
52 | "you'll", "youre": "you're", "youve": "you've"
53 | }
54 |
55 | manual_map = { 'none': '0',
56 | 'zero': '0',
57 | 'one': '1',
58 | 'two': '2',
59 | 'three': '3',
60 | 'four': '4',
61 | 'five': '5',
62 | 'six': '6',
63 | 'seven': '7',
64 | 'eight': '8',
65 | 'nine': '9',
66 | 'ten': '10'}
67 | articles = ['a', 'an', 'the']
68 | period_strip = re.compile(r"(?<!\d)(\.)(?!\d)")
69 | comma_strip = re.compile(r"(\d)(,)(\d)")
70 | punct = [';', r"/", '[', ']', '"', '{', '}',
71 | '(', ')', '=', '+', '\\', '_', '-',
72 | '>', '<', '@', '`', ',', '?', '!']
73 |
74 | def process_punctuation(inText):
75 | outText = inText
76 | for p in punct:
77 | if (p + ' ' in inText or ' ' + p in inText) \
78 | or (re.search(comma_strip, inText) != None):
79 | outText = outText.replace(p, '')
80 | else:
81 | outText = outText.replace(p, ' ')
82 | outText = period_strip.sub("", outText, re.UNICODE)
83 | return outText
84 |
85 |
86 | def process_digit_article(inText):
87 | outText = []
88 | tempText = inText.lower().split()
89 | for word in tempText:
90 | word = manual_map.setdefault(word, word)
91 | if word not in articles:
92 | outText.append(word)
93 | else:
94 | pass
95 | for wordId, word in enumerate(outText):
96 | if word in contractions:
97 | outText[wordId] = contractions[word]
98 | outText = ' '.join(outText)
99 | return outText
100 |
101 |
102 | def prep_ans(answer):
103 | answer = process_digit_article(process_punctuation(answer))
104 | answer = answer.replace(',', '')
105 | return answer
106 |
--------------------------------------------------------------------------------
/models/vqa/mcan/net.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # OpenVQA
3 | # Written by Yuhao Cui https://github.com/cuiyuhao1996
4 | # --------------------------------------------------------
5 |
6 | from models.vqa.make_mask import make_mask
7 | from models.vqa.ops.fc import FC, MLP
8 | from models.vqa.ops.layer_norm import LayerNorm
9 | from models.vqa.mcan.mca import MCA_ED
10 | # from models.vqa.mcan.adapter import Adapter
11 | from models.vqa.vqa_adapter import VQAAdapter
12 | # from openvqa.models.mcan.mca import MCA_ED
13 | # from openvqa.models.mcan.adapter import Adapter
14 |
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | import torch
18 |
19 |
20 | # ------------------------------
21 | # ---- Flatten the sequence ----
22 | # ------------------------------
23 |
24 | class AttFlat(nn.Module):
25 | def __init__(self, hidden_size=1024, flat_mlp_size=512, flat_glimpses=1, dropout_r=0.1, flat_out_size=2048):
26 | super(AttFlat, self).__init__()
27 | self.flat_glimpses = flat_glimpses
28 |
29 | self.mlp = MLP(
30 | in_size=hidden_size,
31 | mid_size=flat_mlp_size,
32 | out_size=flat_glimpses,
33 | dropout_r=dropout_r,
34 | use_relu=True
35 | )
36 |
37 | self.linear_merge = nn.Linear(
38 | hidden_size * flat_glimpses,
39 | flat_out_size
40 | )
41 |
42 | def forward(self, x, x_mask):
43 | att = self.mlp(x)
44 | att = att.masked_fill(
45 | x_mask.squeeze(1).squeeze(1).unsqueeze(2),
46 | -1e9
47 | )
48 | att = F.softmax(att, dim=1)
49 |
50 | att_list = []
51 | for i in range(self.flat_glimpses):
52 | att_list.append(
53 | torch.sum(att[:, :, i: i + 1] * x, dim=1)
54 | )
55 |
56 | x_atted = torch.cat(att_list, dim=1)
57 | x_atted = self.linear_merge(x_atted)
58 |
59 | return x_atted
60 |
61 |
62 | # -------------------------
63 | # ---- Main MCAN Model ----
64 | # -------------------------
65 |
66 | class MCAN(nn.Module):
67 | def __init__(self, pretrained_emb,
68 | token_size,
69 | answer_size,
70 | word_embed_size=300,
71 | use_glove=True,
72 | hidden_size=1024,
73 | flat_out_size=2048,
74 | flat_mlp_size=512,
75 | flat_glimpses=1,
76 | ff_size=4096,
77 | dropout_r=0.1,
78 | multi_head=8):
79 | super(MCAN, self).__init__()
80 |
81 | self.embedding = nn.Embedding(
82 | num_embeddings=token_size,
83 | embedding_dim=word_embed_size
84 | )
85 |
86 | # Loading the GloVe embedding weights
87 | if use_glove:
88 | self.embedding.weight.data.copy_(torch.from_numpy(pretrained_emb))
89 |
90 | self.lstm = nn.LSTM(
91 | input_size=word_embed_size,
92 | hidden_size=hidden_size,
93 | num_layers=1,
94 | batch_first=True
95 | )
96 |
97 | self.adapter = VQAAdapter(hidden_size=hidden_size)
98 |
99 | self.backbone = MCA_ED(hidden_size, ff_size, dropout_r, multi_head)
100 |
101 | # Flatten to vector
102 | self.attflat_img = AttFlat(hidden_size,
103 | flat_mlp_size=flat_mlp_size,
104 | flat_glimpses=flat_glimpses,
105 | dropout_r=dropout_r,
106 | flat_out_size=flat_out_size
107 | )
108 | self.attflat_lang = AttFlat(hidden_size,
109 | flat_mlp_size=flat_mlp_size,
110 | flat_glimpses=flat_glimpses,
111 | dropout_r=dropout_r,
112 | flat_out_size=flat_out_size)
113 |
114 | # Classification layers
115 | self.proj_norm = LayerNorm(flat_out_size)
116 | self.proj = nn.Linear(flat_out_size, answer_size)
117 |
118 | def forward(self, frcn_feat, bbox_feat, ques_ix):
119 | # Pre-process Language Feature
120 | lang_feat_mask = make_mask(ques_ix.unsqueeze(2))
121 | lang_feat = self.embedding(ques_ix)
122 | lang_feat, _ = self.lstm(lang_feat)
123 |
124 | img_feat, img_feat_mask = self.adapter(frcn_feat, bbox_feat)
125 |
126 | # Backbone Framework
127 | lang_feat, img_feat = self.backbone(
128 | lang_feat,
129 | img_feat,
130 | lang_feat_mask,
131 | img_feat_mask
132 | )
133 |
134 | # Flatten to vector
135 | lang_feat = self.attflat_lang(
136 | lang_feat,
137 | lang_feat_mask
138 | )
139 |
140 | img_feat = self.attflat_img(
141 | img_feat,
142 | img_feat_mask
143 | )
144 |
145 | # Classification layers
146 | proj_feat = lang_feat + img_feat
147 | proj_feat = self.proj_norm(proj_feat)
148 | proj_feat = self.proj(proj_feat)
149 |
150 | return {
151 | 'question_features': lang_feat,
152 | 'logits': proj_feat
153 | }
154 |
155 | # return proj_feat
156 |
--------------------------------------------------------------------------------
/datasets/shape_generator.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import cv2
4 |
5 |
6 | class ShapeGenerator():
7 | def __init__(self, dim=32, padding=8):
8 | self.dim = dim
9 | self.padded_dim = dim - padding * 2
10 | self.padding = padding
11 | self.shapes = ['circle', 'right_triangle', 'obtuse_triangle', 'square', 'parallelogram', 'kite',
12 | 'pentagon', 'semi_circle', 'plus', 'arrow']
13 |
14 | def init_image(self):
15 | return np.zeros((self.padded_dim, self.padded_dim, 3), np.uint8)
16 |
17 | def generate_polygon(self, vertices, image=None, color=(255, 255, 255)):
18 | if image is None:
19 | image = self.init_image()
20 |
21 | pts = vertices.reshape((-1, 1, 2))
22 | cv2.polylines(image, [pts], isClosed=True, color=color, thickness=1)
23 | # fill it
24 | cv2.fillPoly(image, [pts], color=color)
25 | return image
26 |
27 | def generate_circle(self):
28 | center = (int(self.padded_dim / 2), int(self.padded_dim / 2))
29 | radius = int(self.padded_dim / 2)
30 | image = self.init_image()
31 | cv2.circle(image, center, radius, (255, 255, 255), thickness=-1)
32 | return image
33 |
34 | def generate_semi_circle(self):
35 | circle = self.generate_circle()
36 | circle[:, int(self.padded_dim / 2):] = 0
37 | return circle
38 |
39 | def generate_right_triangle(self):
40 | vertices = np.array(
41 | [[0, 0],
42 | [0, self.padded_dim],
43 | [self.padded_dim, self.padded_dim]], np.int32)
44 | return self.generate_polygon(vertices)
45 |
46 | def generate_obtuse_triangle(self):
47 | vertices = np.array([[0, 0],
48 | [self.padded_dim / 3, self.padded_dim],
49 | [self.padded_dim, self.padded_dim]], np.int32)
50 | return self.generate_polygon(vertices)
51 |
52 | def generate_square(self):
53 | vertices = np.array([[0, 0],
54 | [0, self.padded_dim],
55 | [self.padded_dim, self.padded_dim],
56 | [self.padded_dim, 0]], np.int32)
57 | return self.generate_polygon(vertices)
58 |
59 | def generate_parallelogram(self):
60 | vertices = np.array([[self.padded_dim / 3, 0],
61 | [0, self.padded_dim],
62 | [self.padded_dim * 2 / 3, self.padded_dim],
63 | [self.padded_dim, 0]], np.int32)
64 | return self.generate_polygon(vertices)
65 |
66 | def generate_kite(self):
67 | vertices = np.array([[self.padded_dim / 2, 0],
68 | [0, self.padded_dim / 3],
69 | [self.padded_dim / 2, self.padded_dim],
70 | [self.padded_dim, self.padded_dim / 3]], np.int32)
71 | return self.generate_polygon(vertices)
72 |
73 | def generate_pentagon(self):
74 | vertices = np.array([[self.padded_dim / 2, 0],
75 | [0, self.padded_dim / 3],
76 | [self.padded_dim / 3, self.padded_dim],
77 | [self.padded_dim * 2 / 3, self.padded_dim],
78 | [self.padded_dim, self.padded_dim / 3]], np.int32)
79 | return self.generate_polygon(vertices)
80 |
81 | def generate_hexagon(self):
82 | vertices = np.array([[self.padded_dim / 4, 0],
83 | [0, self.padded_dim / 2],
84 | [self.padded_dim / 4, self.padded_dim],
85 | [self.padded_dim * 3 / 4, self.padded_dim],
86 | [self.padded_dim, self.padded_dim / 2],
87 | [self.padded_dim * 3 / 4, 0]], np.int32)
88 | return self.generate_polygon(vertices)
89 |
90 | def generate_plus(self):
91 | image = self.init_image()
92 | vert_rect = np.array([[self.padded_dim / 3, 0],
93 | [self.padded_dim / 3, self.padded_dim],
94 | [self.padded_dim * 2 / 3, self.padded_dim],
95 | [self.padded_dim * 2 / 3, 0]], np.int32)
96 | hor_rect = np.array([[0, self.padded_dim / 3],
97 | [0, self.padded_dim * 2 / 3],
98 | [self.padded_dim, self.padded_dim * 2 / 3],
99 | [self.padded_dim, self.padded_dim / 3]], np.int32)
100 | image = self.generate_polygon(vert_rect, image=image)
101 | return self.generate_polygon(hor_rect, image=image)
102 |
103 | def generate_arrow(self):
104 | vertices = np.array([[0, self.padded_dim / 4],
105 | [0, self.padded_dim * 3 / 4],
106 | [self.padded_dim * 2 / 3, self.padded_dim * 3 / 4],
107 | [self.padded_dim, self.padded_dim / 2],
108 | [self.padded_dim * 2 / 3, self.padded_dim / 4]], np.int32)
109 | return self.generate_polygon(vertices)
110 |
111 | def generate(self, shape):
112 | method = getattr(self, 'generate_' + shape)
113 | shape = method()
114 | img = np.zeros((self.dim, self.dim, 3))
115 | img[self.padding:self.dim - self.padding, self.padding:self.dim - self.padding] = shape
116 | return img
117 |
118 |
119 | if __name__ == "__main__":
120 | generator = ShapeGenerator()
121 | save_dir = '/tmp/shapes'
122 | if not os.path.exists(save_dir):
123 | os.makedirs(save_dir)
124 |
125 | for s in generator.shapes:
126 | img = generator.generate(s)
127 | cv2.imwrite(save_dir + f'/{s}.jpg', img)
128 | print(f"Saved {s}")
129 |
--------------------------------------------------------------------------------
/trainers/rubi_trainer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import logging
3 |
4 | import torch
5 |
6 | from models import model_factory
7 | from trainers.base_trainer import BaseTrainer
8 | from utils.bias_retrievers import build_bias_retriever
9 | from utils.metric_visualizer import LossVisualizer
10 | from utils.trainer_utils import create_optimizer
11 |
12 |
13 | class RUBiTrainer(BaseTrainer):
14 | """
15 | Adaptation of
16 | Cadene, Remi, et al. "Rubi: Reducing unimodal biases in visual question answering." (NeurIPS 2019).
17 |
18 | While the original implementation dealt with language biases in VQA by using question features as bias features,
19 | in this work, we use explicit labels (e.g., gender label in CelebA) as the bias input.
20 |
21 | Specifically, RUBi contains a main branch (taking in the full input e.g., image for CelebA or image and question for VQA)
22 | and a bias-only branch (e.g., taking in gender labels for CelebA).
23 |     The sigmoid of the logits from the bias-only branch is used to modulate the logits from the main branch during training.
24 |
25 |     Paper: https://arxiv.org/abs/1906.10169
26 | Original codebase: https://github.com/cdancette/rubi.bootstrap.pytorch
27 | """
28 |
29 | def __init__(self, option):
30 | super(RUBiTrainer, self).__init__(option)
31 | self.loss_visualizer = LossVisualizer(self.option.expt_dir)
32 |
33 | def _build_model(self):
34 | super()._build_model()
35 | if self.option.bias_model_name is None:
36 | self.option.bias_model_name = self.option.model_name
37 | self.bias_model = model_factory.build_model(self.option,
38 | self.option.bias_model_name,
39 | in_dims=self.option.bias_variable_dims,
40 | hid_dims=self.option.bias_model_hid_dims,
41 | out_dims=self.option.num_classes)
42 | logging.getLogger().info("Bias Model")
43 | logging.getLogger().info(self.bias_model)
44 | self.bias_retriever = build_bias_retriever(self.option.bias_variable_name)
45 |
46 | if self.option.cuda:
47 | self.model.cuda()
48 | self.bias_model.cuda()
49 | self.loss.cuda()
50 |
51 | def _build_optimizer(self):
52 | self.optim = create_optimizer(self.option.optimizer_name,
53 | named_params=list(self.model.named_parameters()) + list(
54 | self.bias_model.named_parameters()),
55 | lr=self.option.lr,
56 | weight_decay=self.option.weight_decay,
57 | momentum=self.option.momentum,
58 | freeze_layers=self.option.freeze_layers)
59 |
60 | def _mode_setting(self, is_train=True):
61 | self.model.train(is_train)
62 | self.bias_model.train(is_train)
63 |
64 | def _train_epoch(self, epoch, data_loader):
65 | self._mode_setting(is_train=True)
66 | for i, batch in enumerate(data_loader):
67 | batch = self.prepare_batch(batch)
68 | self.optim.zero_grad()
69 |
70 | main_out = self.forward_model(self.model, batch)
71 | logits = main_out['logits']
72 | bias = self.bias_retriever(batch, main_out)
73 |
74 | if self.option.bias_variable_type == 'categorical':
75 | # If it is a categorical bias variable (e.g., gender), then convert to one hot vectors
76 | _bias = torch.zeros((len(bias), self.option.bias_variable_dims))
77 | for ix, bias_ix in enumerate(bias):
78 | _bias[ix, int(bias_ix)] = 1
79 | bias = _bias
80 |
81 | bias = bias.cuda()
82 |
83 | # Feed into the bias model
84 | bias_out = self.bias_model(bias.detach())
85 | bias_logits = bias_out['logits']
86 | bias_losses = self.loss(bias_logits, batch['y'].squeeze())
87 | bias_loss = bias_losses.mean()
88 |
89 | # Modulate the main model's outputs with bias-only model's outputs
90 | main_losses = self.loss(logits, batch['y'].squeeze())
91 | loss = main_losses.mean()
92 | sigmoid_weight = torch.sigmoid(bias_logits)
93 | rubi_logits = logits * sigmoid_weight
94 | rubi_losses = self.loss(rubi_logits, batch['y'].squeeze())
95 | loss_ratio = rubi_losses / (rubi_losses + main_losses)
96 |
97 | # Optimize main model and bias-only model
98 | fused_losses = rubi_losses + bias_losses
99 | fused_loss = fused_losses.mean()
100 | fused_loss.backward(retain_graph=True)
101 |
102 | if self.option.grad_clip is not None:
103 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.option.grad_clip)
104 | torch.nn.utils.clip_grad_norm_(self.bias_model.parameters(), self.option.grad_clip)
105 | self.optim.step()
106 | self.update_generalization_metrics('Train Main', batch, main_losses)
107 | self.update_generalization_metrics('Train Bias', batch, bias_losses)
108 | rubi_out = {}
109 | rubi_out['logits'] = rubi_logits
110 | self.update_generalization_metrics('Train RUBi', batch, rubi_losses)
111 | self.update_generalization_metrics('Train Fused', batch, fused_losses)
112 | if self.option.enable_groupwise_metrics:
113 | self.update_groupwise_values('RUBi Loss/(RUBi+Main)', 'Loss Ratio', loss_ratio, batch)
114 |
115 | self._after_train_epoch(epoch, 'Train Fused')
116 |
117 | def get_keys_to_save(self):
118 | return super().get_keys_to_save() + ['bias_model']
119 |
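For reference, the modulation in `_train_epoch`: with main-branch logits $z$ and bias-only logits $z_b$, RUBi trains on $z_{\mathrm{RUBi}} = z \odot \sigma(z_b)$ with total loss $\mathrm{CE}(z_{\mathrm{RUBi}}, y) + \mathrm{CE}(z_b, y)$; per the paper, this down-weights the main branch's gradient on examples that the bias-only branch already answers confidently.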
--------------------------------------------------------------------------------
/utils/trainer_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import json
4 | import time
5 | import logging
6 | import torch
7 | from torch import optim
8 | from torch.optim import *
9 |
10 |
11 | def save_option(option):
12 | if not os.path.exists(option.save_dir + '/' + option.expt_name):
13 | os.makedirs(option.save_dir + '/' + option.expt_name)
14 | option_path = os.path.join(option.save_dir, option.expt_name, "options.json")
15 |
16 | with open(option_path, 'w') as fp:
17 | json.dump(option.__dict__, fp, indent=4, sort_keys=True,
18 |                   default=lambda o: f"<<non-serializable: {type(o).__name__}>>")
19 |     logging.getLogger().info(json.dumps(option.__dict__, indent=4, sort_keys=True,
20 |                                         default=lambda o: f"<<non-serializable: {type(o).__name__}>>"))
21 |
22 |
23 | def initialize_logger(expt_dir):
24 | if not os.path.exists(expt_dir):
25 | os.makedirs(expt_dir)
26 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
27 | logger = logging.getLogger()
28 | logger.handlers = []
29 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
30 |
31 | streaming_handler = logging.StreamHandler()
32 | streaming_handler.setFormatter(formatter)
33 | logger.addHandler(streaming_handler)
34 |
35 | null_handler = logging.NullHandler()
36 | null_handler.setLevel(logging.DEBUG)
37 | logging.getLogger("tornado.access").addHandler(null_handler)
38 | logging.getLogger("tornado.access").propagate = False
39 |
40 | file_handler = logging.FileHandler(os.path.join(expt_dir, 'log.txt'))
41 | file_handler.setFormatter(formatter)
42 | logger.addHandler(file_handler)
43 | return logger
44 |
45 |
46 | class Timer(object):
47 | def __init__(self, logger, epochs, last_step=0):
48 | self.logger = logger
49 | self.epochs = epochs
50 | self.step = last_step
51 |
52 | curr_time = time.time()
53 | self.start = curr_time
54 | self.last = curr_time
55 |
56 | def __call__(self):
57 | curr_time = time.time()
58 | self.step += 1
59 |
60 | duration = curr_time - self.last
61 | remaining = (self.epochs - self.step) * (curr_time - self.start) / self.step / 3600
62 |         msg = 'TIMER, duration(s)|remaining(h), %f, %f' % (duration, remaining)
63 |         self.logger.info(msg)
64 |         self.last = curr_time
65 |
66 |
67 | def get_dir(file_path):
68 |     return os.path.dirname(file_path)
69 |
70 |
71 | class GradMult(torch.autograd.Function):
72 |
73 | @staticmethod
74 | def forward(ctx, x, const):
75 | ctx.const = const
76 | return x.view_as(x)
77 |
78 | @staticmethod
79 | def backward(ctx, grad_output):
80 | return grad_output * ctx.const, None
81 |
82 |
83 | def grad_mult(x, const):
84 | return GradMult.apply(x, const)
85 |
86 |
87 | class GradReverse(torch.autograd.Function):
88 | @staticmethod
89 | def forward(ctx, x):
90 | return x.view_as(x)
91 |
92 | @staticmethod
93 | def backward(ctx, grad_output):
94 | return grad_output.neg() # * 0.1
95 |
96 |
97 | def grad_reverse(x):
98 | return GradReverse.apply(x)
99 |
100 |
101 | def create_optimizer(optimizer_name, named_params, lr, weight_decay=0, momentum=0.9, freeze_layers=None,
102 | custom_lr_config=None):
103 | """
104 |     Builds the optimizer of the given name, adding only the provided named_params.
105 |     Also supports freezing layers.
106 |
107 |     For the experiments, we did not use freeze_layers or custom_lr_config, so this code could be simplified.
108 | :return:
109 | """
110 | if weight_decay is None:
111 | weight_decay = 0
112 |
113 | def should_be_added(layer_name):
114 | ret = True
115 | if freeze_layers is None:
116 | return ret
117 | for freeze_layer in freeze_layers:
118 | if layer_name.startswith(freeze_layer):
119 | ret = False
120 | return ret
121 |
122 | filt_params = []
123 | for name, param in named_params:
124 | param_dict = None
125 | if should_be_added(name):
126 | if custom_lr_config is not None:
127 | for custom_lr_name in custom_lr_config:
128 | if name.startswith(custom_lr_name):
129 | param_dict = {'params': param, 'lr': custom_lr_config[custom_lr_name]}
130 | if param_dict is None:
131 | param_dict = {'params': param, 'lr': lr}
132 | filt_params.append(param_dict)
133 | # logging.getLogger().info(f"Adding param: {name}")
134 | else:
135 | param.requires_grad = False # for efficiency
136 | logging.getLogger().info(f"Freezing param: {name}")
137 |
138 | if optimizer_name == 'SGD':
139 | optimizer = optim.SGD(filt_params, lr=lr, momentum=momentum, weight_decay=weight_decay)
140 | else:
141 |         optimizer = getattr(optim, optimizer_name)(filt_params, lr=lr, weight_decay=weight_decay)
142 | return optimizer
143 |
144 |
145 | def clip_grad_norm(parameters, max_norm, norm_type=2):
146 | if isinstance(parameters, torch.Tensor):
147 | parameters = [parameters]
148 | parameters = list(filter(lambda p: p.grad is not None, parameters))
149 | max_norm = float(max_norm)
150 | norm_type = float(norm_type)
151 |     if norm_type == float('inf'):
152 | total_norm = max(p.grad.data.abs().max() for p in parameters)
153 | else:
154 | total_norm = 0
155 | for p in parameters:
156 | param_norm = p.grad.data.norm(norm_type)
157 | total_norm += param_norm.item() ** norm_type
158 | total_norm = total_norm ** (1. / norm_type)
159 | clip_coef = max_norm / (total_norm + 1e-6)
160 | if clip_coef < 1:
161 | for p in parameters:
162 | p.grad.data.mul_(clip_coef)
163 | return total_norm
164 |
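165 |
166 | if __name__ == '__main__':
167 |     # Minimal sketch of grad_mult / grad_reverse (illustrative values only):
168 |     # both are identity in the forward pass but rescale or negate gradients
169 |     # in the backward pass, which is how adversarial branches are driven.
170 |     x = torch.ones(3, requires_grad=True)
171 |     grad_mult(x, 0.5).sum().backward()
172 |     print(x.grad)  # tensor([0.5000, 0.5000, 0.5000])
173 |     z = torch.ones(3, requires_grad=True)
174 |     grad_reverse(z).sum().backward()
175 |     print(z.grad)  # tensor([-1., -1., -1.])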
--------------------------------------------------------------------------------
/models/vqa/mcan/mca.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # OpenVQA
3 | # Written by Yuhao Cui https://github.com/cuiyuhao1996
4 | # --------------------------------------------------------
5 |
6 | from models.vqa.ops.fc import FC, MLP
7 | from models.vqa.ops.layer_norm import LayerNorm
8 |
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import torch
12 | import math
13 |
14 |
15 | # ------------------------------
16 | # ---- Multi-Head Attention ----
17 | # ------------------------------
18 |
19 | class MHAtt(nn.Module):
20 | def __init__(self, hidden_size, dropout_r, multi_head=8):
21 | super(MHAtt, self).__init__()
22 | self.multi_head = multi_head
23 | self.hidden_size = hidden_size
24 |
25 | self.linear_v = nn.Linear(hidden_size, hidden_size)
26 | self.linear_k = nn.Linear(hidden_size, hidden_size)
27 | self.linear_q = nn.Linear(hidden_size, hidden_size)
28 | self.linear_merge = nn.Linear(hidden_size, hidden_size)
29 |
30 | self.dropout = nn.Dropout(dropout_r)
31 |
32 | def forward(self, v, k, q, mask):
33 | n_batches = q.size(0)
34 |
35 | v = self.linear_v(v).view(
36 | n_batches,
37 | -1,
38 | self.multi_head,
39 | int(self.hidden_size / self.multi_head)
40 | ).transpose(1, 2)
41 |
42 | k = self.linear_k(k).view(
43 | n_batches,
44 | -1,
45 | self.multi_head,
46 | int(self.hidden_size / self.multi_head)
47 | ).transpose(1, 2)
48 |
49 | q = self.linear_q(q).view(
50 | n_batches,
51 | -1,
52 | self.multi_head,
53 | int(self.hidden_size / self.multi_head)
54 | ).transpose(1, 2)
55 |
56 | atted = self.att(v, k, q, mask)
57 | atted = atted.transpose(1, 2).contiguous().view(
58 | n_batches,
59 | -1,
60 | self.hidden_size
61 | )
62 |
63 | atted = self.linear_merge(atted)
64 |
65 | return atted
66 |
67 | def att(self, value, key, query, mask):
68 | d_k = query.size(-1)
69 |
70 | scores = torch.matmul(
71 | query, key.transpose(-2, -1)
72 | ) / math.sqrt(d_k)
73 |
74 | if mask is not None:
75 | scores = scores.masked_fill(mask, -1e9)
76 |
77 | att_map = F.softmax(scores, dim=-1)
78 | att_map = self.dropout(att_map)
79 |
80 | return torch.matmul(att_map, value)
81 |
82 |
83 | # ---------------------------
84 | # ---- Feed Forward Nets ----
85 | # ---------------------------
86 |
87 | class FFN(nn.Module):
88 | def __init__(self, hidden_size, ff_size, dropout_r):
89 | super(FFN, self).__init__()
90 |
91 | self.mlp = MLP(
92 | in_size=hidden_size,
93 | mid_size=ff_size,
94 | out_size=hidden_size,
95 | dropout_r=dropout_r,
96 | use_relu=True
97 | )
98 |
99 | def forward(self, x):
100 | return self.mlp(x)
101 |
102 |
103 | # ------------------------
104 | # ---- Self Attention ----
105 | # ------------------------
106 |
107 | class SA(nn.Module):
108 | def __init__(self, hidden_size, ff_size, dropout_r, multi_head):
109 | super(SA, self).__init__()
110 |
111 | self.mhatt = MHAtt(hidden_size, dropout_r, multi_head)
112 | self.ffn = FFN(hidden_size, ff_size, dropout_r)
113 |
114 | self.dropout1 = nn.Dropout(dropout_r)
115 | self.norm1 = LayerNorm(hidden_size)
116 |
117 | self.dropout2 = nn.Dropout(dropout_r)
118 | self.norm2 = LayerNorm(hidden_size)
119 |
120 | def forward(self, y, y_mask):
121 | y = self.norm1(y + self.dropout1(
122 | self.mhatt(y, y, y, y_mask)
123 | ))
124 |
125 | y = self.norm2(y + self.dropout2(
126 | self.ffn(y)
127 | ))
128 |
129 | return y
130 |
131 |
132 | # -------------------------------
133 | # ---- Self Guided Attention ----
134 | # -------------------------------
135 |
136 | class SGA(nn.Module):
137 | def __init__(self, hidden_size, ff_size, dropout_r, multi_head):
138 | super(SGA, self).__init__()
139 |
140 | self.mhatt1 = MHAtt(hidden_size, dropout_r, multi_head)
141 | self.mhatt2 = MHAtt(hidden_size, dropout_r, multi_head)
142 | self.ffn = FFN(hidden_size, ff_size, dropout_r)
143 |
144 | self.dropout1 = nn.Dropout(dropout_r)
145 | self.norm1 = LayerNorm(hidden_size)
146 |
147 | self.dropout2 = nn.Dropout(dropout_r)
148 | self.norm2 = LayerNorm(hidden_size)
149 |
150 | self.dropout3 = nn.Dropout(dropout_r)
151 | self.norm3 = LayerNorm(hidden_size)
152 |
153 | def forward(self, x, y, x_mask, y_mask):
154 | x = self.norm1(x + self.dropout1(
155 | self.mhatt1(v=x, k=x, q=x, mask=x_mask)
156 | ))
157 |
158 | x = self.norm2(x + self.dropout2(
159 | self.mhatt2(v=y, k=y, q=x, mask=y_mask)
160 | ))
161 |
162 | x = self.norm3(x + self.dropout3(
163 | self.ffn(x)
164 | ))
165 |
166 | return x
167 |
168 |
169 | # ------------------------------------------------
170 | # ---- MCA Layers Cascaded by Encoder-Decoder ----
171 | # ------------------------------------------------
172 |
173 | class MCA_ED(nn.Module):
174 | def __init__(self, hidden_size, ff_size, dropout_r, multi_head, layer=6):
175 | super(MCA_ED, self).__init__()
176 |
177 | self.enc_list = nn.ModuleList([SA(hidden_size, ff_size, dropout_r, multi_head) for _ in range(layer)])
178 | self.dec_list = nn.ModuleList([SGA(hidden_size, ff_size, dropout_r, multi_head) for _ in range(layer)])
179 |
180 | def forward(self, y, x, y_mask, x_mask):
181 |         # Encode the question features with the self-attention (SA) stack
182 | for enc in self.enc_list:
183 | y = enc(y, y_mask)
184 |
185 |         # Decode the image features with the SGA stack,
186 |         # guided by the encoded question features y
187 | for dec in self.dec_list:
188 | x = dec(x, y, x_mask, y_mask)
189 |
190 | return y, x
191 |
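192 |
193 | if __name__ == '__main__':
194 |     # Minimal shape check for MCA_ED (hypothetical sizes): question tokens y
195 |     # are self-encoded by the SA stack, then guide the image features x
196 |     # through the SGA stack.
197 |     y = torch.randn(2, 14, 512)   # (batch, n_tokens, hidden)
198 |     x = torch.randn(2, 36, 512)   # (batch, n_regions, hidden)
199 |     mca = MCA_ED(hidden_size=512, ff_size=2048, dropout_r=0.1, multi_head=8, layer=2)
200 |     y_out, x_out = mca(y, x, y_mask=None, x_mask=None)
201 |     print(y_out.shape, x_out.shape)  # torch.Size([2, 14, 512]) torch.Size([2, 36, 512])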
--------------------------------------------------------------------------------
/models/vqa/ban/_ban.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # OpenVQA
3 | # Written by Zhenwei Shao https://github.com/ParadoxZW
4 | # Based on the implementation of the paper "Bilinear Attention Networks" (NeurIPS 2018): https://github.com/jnhwkim/ban-vqa
5 | # --------------------------------------------------------
6 |
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch.nn.utils.weight_norm import weight_norm
10 | import torch, math
11 |
12 |
13 | # ------------------------------
14 | # ----- Weight Normal MLP ------
15 | # ------------------------------
16 |
17 | class MLP(nn.Module):
18 | """
19 |     Simple class for a non-linear fully connected network
20 | """
21 |
22 | def __init__(self, dims, act='ReLU', dropout_r=0.0):
23 | super(MLP, self).__init__()
24 |
25 | layers = []
26 | for i in range(len(dims) - 1):
27 | in_dim = dims[i]
28 | out_dim = dims[i + 1]
29 | if dropout_r > 0:
30 | layers.append(nn.Dropout(dropout_r))
31 | layers.append(weight_norm(nn.Linear(in_dim, out_dim), dim=None))
32 | if act != '':
33 | layers.append(getattr(nn, act)())
34 |
35 | self.mlp = nn.Sequential(*layers)
36 |
37 | def forward(self, x):
38 | return self.mlp(x)
39 |
40 |
41 | # ------------------------------
42 | # ------ Bilinear Connect ------
43 | # ------------------------------
44 |
45 | class BC(nn.Module):
46 | """
47 |     Simple class for a non-linear bilinear connection network
48 | """
49 |
50 | def __init__(self,
51 | img_feat_size=2048,
52 | hidden_size=1024,
53 | k_times=3,
54 | dropout_r=0.2,
55 | classifier_dropout_r=0.5,
56 | glimpse=8,
57 | atten=False):
58 | super(BC, self).__init__()
59 | self.k_times = k_times
60 | ba_hidden_size = k_times * hidden_size
61 | self.v_net = MLP([img_feat_size,
62 | ba_hidden_size], dropout_r=dropout_r)
63 | self.q_net = MLP([hidden_size,
64 | ba_hidden_size], dropout_r=dropout_r)
65 | if not atten:
66 | self.p_net = nn.AvgPool1d(k_times, stride=k_times)
67 | else:
68 | self.dropout = nn.Dropout(classifier_dropout_r) # attention
69 |
70 | self.h_mat = nn.Parameter(torch.Tensor(
71 | 1, glimpse, 1, ba_hidden_size).normal_())
72 | self.h_bias = nn.Parameter(
73 | torch.Tensor(1, glimpse, 1, 1).normal_())
74 |
75 | def forward(self, v, q):
76 | # low-rank bilinear pooling using einsum
77 | v_ = self.dropout(self.v_net(v))
78 | q_ = self.q_net(q)
79 | logits = torch.einsum('xhyk,bvk,bqk->bhvq',
80 | (self.h_mat, v_, q_)) + self.h_bias
81 | return logits # b x h_out x v x q
82 |
83 | def forward_with_weights(self, v, q, w):
84 | v_ = self.v_net(v) # b x v x d
85 | q_ = self.q_net(q) # b x q x d
86 | logits = torch.einsum('bvk,bvq,bqk->bk', (v_, w, q_))
87 | logits = logits.unsqueeze(1) # b x 1 x d
88 | logits = self.p_net(logits).squeeze(1) * self.k_times # sum-pooling
89 | return logits
90 |
91 |
92 | # ------------------------------
93 | # -------- BiAttention ---------
94 | # ------------------------------
95 |
96 |
97 | class BiAttention(nn.Module):
98 | def __init__(self,
99 | img_feat_size=2048,
100 | hidden_size=1024,
101 | k_times=3,
102 | dropout_r=0.2,
103 | classifier_dropout_r=0.5,
104 | glimpse=8
105 | ):
106 | super(BiAttention, self).__init__()
107 | #
108 | self.glimpse = glimpse
109 | self.logits = weight_norm(
110 | BC(img_feat_size, hidden_size, k_times, dropout_r, classifier_dropout_r, glimpse, True), name='h_mat',
111 | dim=None)
112 |
113 | def forward(self, v, q, v_mask=True, logit=False, mask_with=-float('inf')):
114 | v_num = v.size(1)
115 | q_num = q.size(1)
116 | logits = self.logits(v, q) # b x g x v x q
117 |
118 | if v_mask:
119 | mask = (0 == v.abs().sum(2)).unsqueeze(
120 | 1).unsqueeze(3).expand(logits.size())
121 | logits.data.masked_fill_(mask.data, mask_with)
122 |
123 | if not logit:
124 | p = nn.functional.softmax(
125 | logits.view(-1, self.glimpse, v_num * q_num), 2)
126 | return p.view(-1, self.glimpse, v_num, q_num), logits
127 |
128 | return logits
129 |
130 |
131 | # ------------------------------
132 | # - Bilinear Attention Network -
133 | # ------------------------------
134 |
135 | class _BAN(nn.Module):
136 | def __init__(self,
137 | img_feat_size,
138 | hidden_size,
139 | k_times,
140 | dropout_r,
141 | classifier_dropout_r,
142 | glimpse):
143 | super(_BAN, self).__init__()
144 |
145 | self.BiAtt = BiAttention(img_feat_size,
146 | hidden_size,
147 | k_times,
148 | dropout_r,
149 | classifier_dropout_r,
150 | glimpse)
151 | b_net = []
152 | q_prj = []
153 |         c_prj = []  # unused in this port
154 | self.glimpse = glimpse
155 | for i in range(glimpse):
156 | b_net.append(BC(img_feat_size,
157 | hidden_size,
158 | k_times,
159 | dropout_r,
160 | classifier_dropout_r,
161 | glimpse,
162 | False))
163 | q_prj.append(MLP([hidden_size, hidden_size], '', dropout_r))
164 | self.b_net = nn.ModuleList(b_net)
165 | self.q_prj = nn.ModuleList(q_prj)
166 |
167 | def forward(self, q, v):
168 | att, logits = self.BiAtt(v, q) # b x g x v x q
169 |
170 | for g in range(self.glimpse):
171 | bi_emb = self.b_net[g].forward_with_weights(
172 | v, q, att[:, g, :, :]) # b x l x h
173 | q = self.q_prj[g](bi_emb.unsqueeze(1)) + q
174 |
175 | return q
176 |
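177 |
178 | if __name__ == '__main__':
179 |     # Minimal shape sketch (hypothetical sizes): bottom-up image features v
180 |     # and question features q; _BAN refines q over `glimpse` bilinear
181 |     # attention steps and returns it with the same shape.
182 |     v = torch.randn(2, 36, 2048)
183 |     q = torch.randn(2, 14, 1024)
184 |     ban = _BAN(img_feat_size=2048, hidden_size=1024, k_times=3,
185 |                dropout_r=0.2, classifier_dropout_r=0.5, glimpse=2)
186 |     print(ban(q, v).shape)  # torch.Size([2, 14, 1024])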
--------------------------------------------------------------------------------
/utils/metric_visualizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import logging
3 | import torch
4 | # import neptune
5 | import json
6 |
7 |
8 | class MetricVisualizer():
9 | def __init__(self, env_name='main', divide_windows_by='split', is_one_time_metric=True):
10 | self.env_name = env_name
11 | self.plots = {}
12 | self.reset()
13 | self.divide_windows_by = divide_windows_by
14 | self.use_markers = False
15 | self.is_one_time_metric = is_one_time_metric
16 |
17 |     def update_multiple(self, split, metrics_dict):
18 |         for name, value in metrics_dict.items():
19 | self.update(split, name, value)
20 |
21 | def update(self, split, name, value):
22 | if split not in self.metrics:
23 | self.metrics[split] = {}
24 | self.metric_cnts[split] = {}
25 | if name not in self.metrics[split]:
26 | if self.is_one_time_metric:
27 | self.metrics[split][name] = 0
28 | else:
29 | self.metrics[split][name] = []
30 | self.metric_cnts[split][name] = 0
31 |
32 | if self.is_one_time_metric:
33 | self.metrics[split][name] = value
34 | self.metric_cnts[split][name] = 1
35 | else:
36 | if isinstance(value, list):
37 | for val in value:
38 | self.metrics[split][name].append(val)
39 | self.metric_cnts[split][name] += 1
40 | else:
41 | self.metrics[split][name].append(value)
42 | self.metric_cnts[split][name] += 1
43 |
44 | def reset(self):
45 | self.metrics = {}
46 | self.metric_cnts = {}
47 |
48 | def compute_and_save_std_dev(self, split, name):
49 | if split + " Std Dev" not in self.metrics:
50 | self.metrics[split + " Std Dev"] = {}
51 | self.metrics[split + " Variance"] = {}
52 |
53 | self.metrics[split + " Std Dev"][name] = torch.std(torch.Tensor(self.metrics[split][name]))
54 | self.metrics[split + " Variance"][name] = self.metrics[split + " Std Dev"][name] ** 2
55 |
56 | def log(self, epoch, split, avg=True):
57 | log_str = f"Split: {split}"
58 | if epoch is not None:
59 | log_str += f', Epoch: {epoch}'
60 | name_values = {}
61 | for ix, name in enumerate(self.metrics[split]):
62 | if self.is_one_time_metric:
63 | val = self.metrics[split][name]
64 | else:
65 | val = sum(self.metrics[split][name])
66 | if avg:
67 | val = val / (self.metric_cnts[split][name] + int(self.metric_cnts[split][name] == 0))
68 | metric_format = self.get_metric_format()
69 | if 'cnt' in name.lower():
70 | metric_format = '%d'
71 |
72 | # log_str += (", %s: " + metric_format) % (name, val)
73 | # neptune.log_metric(log_name=split + " " + name, x=epoch, y=val)
74 | name_values[name] = val
75 | if len(name_values) <= 1:
76 | logging.getLogger().info(name_values)
77 | else:
78 | logging.getLogger().info(json.dumps(name_values, sort_keys=True, indent=4))
79 | # logging.getLogger().info(sorted(list(name_values.keys())))
80 | # vals = [float(name_values[k]) for k in sorted(list(name_values.keys()))]
81 | # logging.getLogger().info(vals)
82 |
83 | # logging.getLogger().info(log_str)
84 |
85 | def get_metric_format(self):
86 | return "%.4f"
87 |
88 | def accumulate_plot_and_reset(self, epoch, xlabel='Epochs', avg=True):
89 | x = epoch
90 | for split in self.metrics:
91 | for name in self.metrics[split]:
92 | if self.is_one_time_metric:
93 | y = self.metrics[split][name]
94 | else:
95 | try:
96 | y = sum(self.metrics[split][name])
97 | if avg:
98 | y = y / (self.metric_cnts[split][name] + int(self.metric_cnts[split][name] == 0))
99 |                     except TypeError:
100 | y = self.metrics[split][name]
101 | self.plot(split, name, x, y, xlabel)
102 | self.reset()
103 |
104 | def plot(self, split, name, x, y, xlabel='Epochs'):
105 | """
106 | Creates a separate chart for each loss
107 | :return:
108 | """
109 | if self.divide_windows_by == 'name':
110 | win_name = name
111 | line_name = split
112 | else:
113 | win_name = split
114 | line_name = name
115 | # if win_name not in self.plots:
116 | # self.plots[win_name] = self.viz.line(X=np.array([x, x]), Y=np.array([y, y]), env=self.env_name,
117 | # opts=dict(legend=[line_name], title=win_name, xlabel=xlabel,
118 | # ylabel=name, markers=self.use_markers))
119 | #
120 | # else:
121 | # self.viz.line(X=np.array([x]), Y=np.array([y]), env=self.env_name, win=self.plots[win_name], name=line_name,
122 | # update='append')
123 |
124 |
125 | class LossVisualizer(MetricVisualizer):
126 | def __init__(self, env_name='main', divide_windows_by='split'):
127 | super(LossVisualizer, self).__init__(env_name, divide_windows_by, is_one_time_metric=False)
128 | self.use_markers = False
129 |
130 | def get_metric_format(self):
131 | return "%.4f"
132 |
133 |
134 | class AccuracyVisualizer(MetricVisualizer):
135 | def __init__(self, env_name='main', divide_windows_by='split'):
136 | super(AccuracyVisualizer, self).__init__(env_name, divide_windows_by, is_one_time_metric=True)
137 | self.use_markers = True
138 |
139 | def get_metric_format(self):
140 | return "%.2f%%"
141 |
142 |
143 | class CountVisualizer(MetricVisualizer):
144 | def __init__(self, env_name='main', divide_windows_by='name'):
145 | super(CountVisualizer, self).__init__(env_name, divide_windows_by, is_one_time_metric=True)
146 | self.use_markers = False
147 |
148 | def get_metric_format(self):
149 | return "%d"
150 |
151 | def log(self, epoch, split):
152 | return super().log(epoch, split)
153 |
154 | def accumulate_plot_and_reset(self, epoch, xlabel='Epochs'):
155 | return super().accumulate_plot_and_reset(epoch, xlabel)
156 |
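157 |
158 | if __name__ == '__main__':
159 |     # Minimal usage sketch: accumulate per-batch values, then log the epoch
160 |     # average (logging must be configured for the output to be visible).
161 |     logging.basicConfig(level=logging.INFO)
162 |     vis = LossVisualizer()
163 |     for batch_loss in [0.9, 0.7, 0.5]:
164 |         vis.update('Train', 'Loss', batch_loss)
165 |     vis.log(epoch=1, split='Train')  # logs {'Loss': 0.7}, the epoch mean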
--------------------------------------------------------------------------------
/trainers/lnl_trainer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | from torch import nn
4 | from torch import optim
5 |
6 | import models
7 | import models.cnn_models
8 | from trainers.base_trainer import BaseTrainer
9 | from utils.trainer_utils import grad_reverse, grad_mult, create_optimizer
10 | from torch.optim import *
11 | import logging
12 | from models.model_factory import build_model
13 |
14 |
15 | class LNLTrainer(BaseTrainer):
16 | """
17 | Implementation for:
18 | Kim, Byungju, et al. "Learning not to learn: Training deep neural networks with biased data." (CVPR 2019).
19 |
20 | Method Description:
21 |     A main branch takes in bias+signal to predict the classes, while auxiliary branches predict the bias variables.
22 | Bias branches act as adversaries to the main branch, thereby reducing dependencies on biases.
23 |
24 | Reference codebase: https://github.com/feidfoe/learning-not-to-learn
25 |
26 | Known Issue(s)/Confusion(s) with original repo:
27 | 1) The paper does not provide full details to replicate the results: https://github.com/feidfoe/learning-not-to-learn/issues/7
28 |
29 |     2) There are two forward passes, one to compute the main loss + entropy loss and the other to compute the bias loss (training the bias predictors).
30 |     It is unclear why this wasn't done with a single forward pass in the original repo: https://github.com/feidfoe/learning-not-to-learn/issues/5
31 | """
32 |
33 | def __init__(self, option):
34 | super(LNLTrainer, self).__init__(option)
35 |
36 | def _build_model(self):
37 | super()._build_model()
38 | self.bias_predictor = build_model(self.option,
39 | self.option.bias_predictor_name,
40 | in_dims=self.option.bias_predictor_in_dims,
41 | hid_dims=self.option.bias_predictor_hid_dims,
42 | out_dims=self.option.num_bias_classes)
43 | logging.getLogger().info(f"Bias predictor {self.bias_predictor}")
44 |
45 | if self.option.cuda:
46 | self.model.cuda()
47 | self.bias_predictor.cuda()
48 |
49 | def _build_optimizer(self):
50 | super()._build_optimizer()
51 | self.bias_predictor_optim = create_optimizer(self.option.optimizer_name,
52 | named_params=self.bias_predictor.named_parameters(),
53 | lr=self.option.lr,
54 | weight_decay=self.option.weight_decay,
55 | momentum=self.option.momentum,
56 | freeze_layers=self.option.freeze_layers)
57 |
58 | def get_keys_to_save(self):
59 | return super().get_keys_to_save() + ['bias_predictor', 'bias_predictor_optim']
60 |
61 | def _mode_setting(self, is_train=True):
62 | self.model.train(is_train)
63 | self.bias_predictor.train(is_train)
64 |
65 | def _train_epoch(self, epoch, data_loader):
66 | self._mode_setting(is_train=True)
67 |
68 | for i, batch in enumerate(data_loader):
69 | # Prepare data
70 | batch = self.prepare_batch(batch)
71 | labels = batch['y']
72 |
73 | # Forward pass
74 | self.optim.zero_grad()
75 | self.bias_predictor_optim.zero_grad()
76 | out = self.forward_model(self.model, batch)
77 | hidden = out[self.option.bias_predictor_in_layer]
78 | logits = out['logits']
79 |
80 | bias_out = self.bias_predictor(hidden)
81 | bias_logits = bias_out['logits']
82 |
83 | # main loss
84 | batch_loss = self.loss(logits, torch.squeeze(labels))
85 | loss_pred = batch_loss.mean()
86 | self.loss_visualizer.update('Train', 'Loss', loss_pred.item())
87 |
88 | # bias loss
89 | bias_softmax = torch.nn.functional.softmax(bias_logits, dim=1) + 1e-8
90 | bias_entropy_loss = torch.mean(
91 | torch.sum(bias_softmax * torch.log(bias_softmax), 1)) * self.option.entropy_loss_weight
92 | self.loss_visualizer.update('Train', 'Bias Entropy', bias_entropy_loss.item())
93 | loss = loss_pred + bias_entropy_loss
94 | loss.backward(retain_graph=True)
95 |
96 | if self.option.grad_clip is not None:
97 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.option.grad_clip)
98 | torch.nn.utils.clip_grad_norm_(self.bias_predictor.parameters(), self.option.grad_clip)
99 |
100 | self.optim.step()
101 | self.optim.zero_grad()
102 | self.bias_predictor_optim.zero_grad()
103 |
104 | # Train with adversarial loss
105 | out = self.forward_model(self.model, batch)
106 | hidden = out[self.option.bias_predictor_in_layer]
107 | hidden_feat = grad_mult(hidden, self.option.grad_reverse_factor)
108 | bias_out = self.bias_predictor(hidden_feat)
109 |
110 | bias_labels = batch[self.option.bias_variable_name]
111 | if isinstance(bias_labels, list):
112 | bias_labels = torch.LongTensor(bias_labels)
113 | bias_labels = bias_labels.long().cuda()
114 | if len(bias_labels.squeeze().shape) > 1:
115 | bias_labels = torch.argmax(bias_labels, dim=1).squeeze()
116 | bias_loss = self.loss(bias_out['logits'].squeeze(), bias_labels.squeeze())
117 | self.loss_visualizer.update('Train', 'Main loss', loss_pred.mean().item())
118 | self.loss_visualizer.update('Train', 'Bias Entropy loss', bias_entropy_loss.mean().item())
119 | self.loss_visualizer.update('Train', 'Bias loss', bias_loss.mean().item())
120 | self.update_generalization_metrics('Train', batch, batch_loss)
121 | bias_loss.mean().backward(retain_graph=True)
122 |
123 | if self.option.grad_clip is not None:
124 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.option.grad_clip)
125 | torch.nn.utils.clip_grad_norm_(self.bias_predictor.parameters(), self.option.grad_clip)
126 |
127 | self.optim.step()
128 | self.bias_predictor_optim.step()
129 |
130 | self._after_train_epoch(epoch)
131 |
132 | def _after_one_epoch(self, epoch, test_loaders, force_test=False):
133 | super()._after_one_epoch(epoch, test_loaders, force_test)
134 |
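135 |
136 | if __name__ == '__main__':
137 |     # Sketch of the bias-entropy term from _train_epoch (hypothetical logits):
138 |     # sum(p * log p) is the negative entropy, so minimizing it pushes the
139 |     # bias predictor towards uniform, i.e., uninformative, outputs.
140 |     bias_logits = torch.randn(4, 3)
141 |     p = torch.nn.functional.softmax(bias_logits, dim=1) + 1e-8
142 |     neg_entropy = torch.mean(torch.sum(p * torch.log(p), 1))
143 |     print(neg_entropy.item())  # <= 0; equals -H(p) averaged over the batch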
--------------------------------------------------------------------------------
/trainers/irm_v1_trainer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import logging
3 | import random
4 |
5 | import numpy as np
6 | from torch.utils.data import Subset, DataLoader
7 |
8 | from trainers.base_trainer import BaseTrainer
9 | from utils.losses import *
10 |
11 |
12 | class IRMv1Trainer(BaseTrainer):
13 | """
14 | Implementation for:
15 |     Arjovsky, Martin, et al. "Invariant risk minimization." arXiv preprint arXiv:1907.02893 (2019).
16 |
17 | This attempts to learn representations that enable the same classifier to be optimal across environments.
18 |     For our implementation, we form the environments from the explicit groups defined by the explicit bias variables (including the class label).
19 | We uniformly sample from environments within each batch (i.e., our implementation assumes balanced group sampling).
20 | """
21 |
22 | def __init__(self, option):
23 | super(IRMv1Trainer, self).__init__(option)
24 | self.group_names = None
25 | self.env_to_data_loader = None
26 | self.num_envs_per_batch = option.num_envs_per_batch
27 | assert self.option.batch_size % self.num_envs_per_batch == 0, \
28 | "Batch size is not exactly divisible by the number of environments within that batch"
29 |
30 | def compute_gradient_penalty(self, logits, y):
31 | """
32 | Gradient of the risk when all classifier weights are set to 1 (i.e., the IRMv1 regularizer)
33 | Based on unbiased estimation of IRMV1 (Sec 3.2) and Appendix D of https://arxiv.org/pdf/1907.02893.pdf
34 | It assumes that each batch contains per-environment samples from 'num_envs_per_batch' environments in an ordered manner.
35 |
36 | :param logits:
37 | :param y:
38 | :return:
39 | """
40 | dummy_classifier = torch.tensor(1.).cuda().requires_grad_()
41 | loss = self.loss(logits * dummy_classifier, y.squeeze())
42 | start_ix = 0
43 | grad_loss = 0
44 | for env in np.arange(0, self.num_envs_per_batch):
45 | end_ix = start_ix + self.option.batch_size // self.num_envs_per_batch
46 | env_loss = loss[start_ix:end_ix]
47 | loss1 = env_loss[:len(env_loss) // 2]
48 | loss2 = env_loss[len(env_loss) // 2:]
49 | grad1 = torch.autograd.grad(loss1.mean(), [dummy_classifier], create_graph=True)[0]
50 | grad2 = torch.autograd.grad(loss2.mean(), [dummy_classifier], create_graph=True)[0]
51 | grad_loss += grad1 * grad2
52 | start_ix = end_ix
53 |
54 | return grad_loss
55 |
56 | def train(self, train_loader, test_loaders=None, unbalanced_train_loader=None):
57 | self.before_train(train_loader, test_loaders)
58 | start_epoch = 1
59 | orig_loader = train_loader
60 | batch_sampler = EnvironmentWiseBatchSampler(self.option.batch_size, orig_loader, self.num_envs_per_batch)
61 | dataset = orig_loader.dataset
62 | if isinstance(dataset, Subset):
63 | dataset = dataset.dataset
64 |
65 | train_loader = DataLoader(dataset, batch_sampler=batch_sampler,
66 | num_workers=orig_loader.num_workers, collate_fn=orig_loader.collate_fn)
67 |
68 | for epoch in range(start_epoch, self.option.epochs + 1):
69 | self._train_epoch(epoch, train_loader)
70 | self._after_one_epoch(epoch, test_loaders)
71 | self.after_all_epochs()
72 |
73 | def _train_epoch(self, epoch, data_loader):
74 | self._mode_setting(is_train=True)
75 |
76 | for batch_ix, batch in enumerate(data_loader):
77 | batch = self.prepare_batch(batch)
78 | out = self.forward_model(self.model, batch)
79 | logits = out['logits']
80 |
81 | # Unbiased IRMv1 goes through each environment before doing a backward pass
82 |             # However, this is not scalable, e.g., when the number of environments is in the 100s or 1000s,
83 | # so we randomly sample certain environments in every batch
84 | batch_losses = self.loss(logits, torch.squeeze(batch['y']))
85 | grad_penalty = self.option.grad_penalty_weight * self.compute_gradient_penalty(logits, batch['y'])
86 |             self.loss_visualizer.update('Train', 'Main Loss', batch_losses.mean().item())
87 |             # grad_penalty already includes grad_penalty_weight, so log it as-is
88 |             self.loss_visualizer.update('Train', 'Grad Penalty', grad_penalty.mean().item())
89 |
90 | self.optim.zero_grad()
91 | loss = batch_losses.mean() + grad_penalty.mean()
92 | loss.backward(retain_graph=True) # Cannot go through each environment before calling backward()
93 | if self.option.grad_clip is not None:
94 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.option.grad_clip)
95 | self.optim.step()
96 |
97 | self.optim.zero_grad()
98 | self._after_train_epoch(epoch)
99 |
100 |
101 | class EnvironmentWiseBatchSampler():
102 | def __init__(self, batch_size, data_loader, num_envs_per_batch):
103 | """
104 | We first identify the environment for each data item
105 | For each batch, we randomly sample from random environments
106 | """
107 | self.num_items = 0
108 | self.env_to_dataset_ixs = {}
109 | self.batch_size = batch_size
110 | self.num_envs_per_batch = num_envs_per_batch
111 |
112 | # Do one pass through the train_loader to get indices per env
113 | for batch_ix, batch in enumerate(data_loader):
114 | for dix, gix in zip(batch['dataset_ix'], batch['group_ix']):
115 | if gix not in self.env_to_dataset_ixs:
116 | self.env_to_dataset_ixs[gix] = []
117 | self.env_to_dataset_ixs[gix].append(dix)
118 | self.num_items += 1
119 | self.env_keys = list(self.env_to_dataset_ixs.keys())
120 | logging.getLogger().info(f"env keys {self.env_keys}")
121 | # for gix in self.env_to_dataset_ixs:
122 | # logging.getLogger().info(f"Env Key: {gix} Cnt: {len(self.env_to_dataset_ixs[gix])}")
123 |
124 | def __iter__(self):
125 | num_batches_per_epoch = self.__len__()
126 | curr_batch_cnt = 0
127 |
128 |         while curr_batch_cnt < num_batches_per_epoch:
129 | # Randomly select some environments per batch
130 |             env_ixs = np.random.choice(self.env_keys, self.num_envs_per_batch)  # sampled with replacement, so an environment may repeat
131 |
132 | # Randomly select within each of the chosen environments
133 | batch = []
134 |
135 | for env_ix in env_ixs:
136 | for b in np.arange(self.batch_size // self.num_envs_per_batch):
137 | dix = random.choice(self.env_to_dataset_ixs[env_ix])
138 | batch.append(dix)
139 |
140 | curr_batch_cnt += 1
141 | yield batch
142 |
143 | def __len__(self):
144 | # The total budget per epoch is self.num_items
145 | return self.num_items // self.batch_size
146 |
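147 |
148 | if __name__ == '__main__':
149 |     import torch
150 |
151 |     # Standalone CPU sketch of the IRMv1 penalty for a single environment
152 |     # (hypothetical logits/labels), mirroring compute_gradient_penalty above:
153 |     # the gradient of the risk w.r.t. a dummy scale w=1.0 is estimated on
154 |     # two halves of the batch and the two estimates are multiplied.
155 |     logits = torch.randn(8, 5, requires_grad=True)
156 |     y = torch.randint(0, 5, (8,))
157 |     w = torch.tensor(1.).requires_grad_()
158 |     losses = torch.nn.functional.cross_entropy(logits * w, y, reduction='none')
159 |     g1 = torch.autograd.grad(losses[:4].mean(), [w], create_graph=True)[0]
160 |     g2 = torch.autograd.grad(losses[4:].mean(), [w], create_graph=True)[0]
161 |     print((g1 * g2).item())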
--------------------------------------------------------------------------------
/models/fc_models.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torch
4 |
5 |
6 | class GammaRegressor(nn.Module):
7 | def __init__(self, in_dims, num_classes, gamma_coeff=5):
8 | super(GammaRegressor, self).__init__()
9 | self.fc = nn.Linear(in_dims, num_classes)
10 | self.gamma_coeff = gamma_coeff
11 |
12 | def forward(self, x):
13 | if len(x.shape) == 1:
14 | x = x.unsqueeze(1)
15 | # fc = self.gamma_coeff * torch.sigmoid(self.fc(x))
16 | # fc = torch.sigmoid(self.fc(x))
17 | fc = F.relu(self.fc(x))
18 | return {
19 | 'out': fc
20 | }
21 |
22 |
23 | class MLP1(nn.Module):
24 | def __init__(self, in_dims, num_classes, hid_dims=None):
25 | super(MLP1, self).__init__()
26 | self.fc = nn.Linear(in_dims, num_classes)
27 |
28 | def forward(self, x):
29 | if len(x.shape) == 1:
30 | x = x.unsqueeze(1)
31 | fc = self.fc(x)
32 | return {
33 | 'before_logits': x,
34 | 'logits': fc
35 | }
36 |
37 |
38 | class MLP2(nn.Module):
39 | def __init__(self, in_dims, hid_dims, num_classes):
40 | super(MLP2, self).__init__()
41 | self.fc1 = nn.Linear(in_dims, hid_dims)
42 | self.fc2 = nn.Linear(hid_dims, num_classes)
43 |
44 | def forward(self, x):
45 | if len(x.shape) == 1:
46 | x = x.unsqueeze(1)
47 | x = self.fc1(x)
48 | x = F.relu(x)
49 | logits = self.fc2(x)
50 | return {
51 | 'before_logits': x,
52 | 'logits': logits
53 | }
54 |
55 | def freeze_layers(self, freeze_layers):
56 | pass
57 |
58 |
59 | class MLP3(nn.Module):
60 | def __init__(self, in_dims, hid_dims, num_classes):
61 | super(MLP3, self).__init__()
62 | self.fc1 = nn.Linear(in_dims, hid_dims)
63 | self.fc2 = nn.Linear(hid_dims, hid_dims)
64 | self.fc3 = nn.Linear(hid_dims, num_classes)
65 |
66 | def forward(self, x):
67 | if len(x.shape) == 1:
68 | x = x.unsqueeze(1)
69 | x = self.fc1(x)
70 | x = F.relu(x)
71 | x = self.fc2(x)
72 | x = F.relu(x)
73 | logits = self.fc3(x)
74 | return {
75 | 'before_logits': x,
76 | 'logits': logits
77 | }
78 |
79 | def freeze_layers(self, freeze_layers):
80 | pass
81 |
82 |
83 | class MLP2_200(MLP2):
84 | def __init__(self, in_dims=300, hid_dims=600, num_classes=None):
85 |         super().__init__(200, 400, num_classes)  # fixed dims; the in_dims/hid_dims arguments are ignored
86 |
87 |
88 | class MLP2_300(MLP2):
89 | def __init__(self, in_dims=300, hid_dims=600, num_classes=None):
90 | super().__init__(300, 600, num_classes)
91 |
92 |
93 | class MLP2_2(MLP2):
94 | def __init__(self, in_dims=300, hid_dims=600, num_classes=None):
95 | super().__init__(2, 4, num_classes)
96 |
97 |
98 | class MLP4(nn.Module):
99 | def __init__(self, in_dims, hid_dims, num_classes):
100 | super(MLP4, self).__init__()
101 | self.fc1 = nn.Linear(in_dims, hid_dims)
102 | self.fc2 = nn.Linear(hid_dims, hid_dims)
103 | self.fc3 = nn.Linear(hid_dims, hid_dims)
104 | self.fc4 = nn.Linear(hid_dims, num_classes)
105 |
106 | def forward(self, x):
107 | if len(x.shape) == 1:
108 | x = x.unsqueeze(1)
109 | x = self.fc1(x)
110 | x = F.relu(x)
111 | x = self.fc2(x)
112 | x = F.relu(x)
113 | x = self.fc3(x)
114 | x = F.relu(x)
115 | logits = self.fc4(x)
116 | return {
117 | 'before_logits': x,
118 | 'logits': logits
119 | }
120 |
121 |
122 | class ModelWrapper(nn.Module):
123 | def __init__(self, core_model, classifier, classifier_in_layer):
124 | super().__init__()
125 | self.core_model = core_model
126 | self.classifier = classifier
127 | self.classifier_in_layer = classifier_in_layer
128 |
129 | def forward(self, x):
130 | out = self.core_model(x)
131 | feat_repr = out[self.classifier_in_layer]
132 | return self.classifier(feat_repr)
133 |
134 |
135 | class LFFMnistClassifier(nn.Module):
136 | def __init__(self, num_classes=10, in_dims=None, hid_dims=None):
137 | super(LFFMnistClassifier, self).__init__()
138 | self.fc1 = nn.Linear(3 * 28 * 28, 100)
139 | self.fc2 = nn.Linear(100, 100)
140 | self.fc3 = nn.Linear(100, 100)
141 |
142 |         self.feature = nn.Sequential(  # note: forward() uses fc1-fc3, not this block
143 | nn.Linear(3 * 28 * 28, 100),
144 | nn.ReLU(),
145 | nn.Linear(100, 100),
146 | nn.ReLU(),
147 | nn.Linear(100, 100),
148 | nn.ReLU()
149 | )
150 | self.classifier = nn.Linear(100, num_classes)
151 |
152 | def forward(self, x, return_feat=False):
153 | x = x.view(x.size(0), -1)
154 | x = self.fc1(x)
155 | x = F.relu(x)
156 | fc1 = x
157 | x = self.fc2(x)
158 | x = F.relu(x)
159 | fc2 = x
160 | x = self.fc3(x)
161 | x = F.relu(x)
162 | fc3 = x
163 | logits = self.classifier(x)
164 | return {
165 | 'fc1': fc1,
166 | 'fc2': fc2,
167 | 'fc3': fc3,
168 | 'before_logits': fc3,
169 | 'logits': logits
170 | }
171 |
172 |
173 | class MoonNet(nn.Module):
174 | def __init__(self,
175 | hidden_dim=500):
176 | super(MoonNet, self).__init__()
177 |
178 | self.fc1 = nn.Linear(2, hidden_dim)
179 | self.fc2 = nn.Linear(hidden_dim, 1)
180 |
181 | def forward(self, x):
182 | return self.fc2(F.relu(self.fc1(x)))
183 |
184 |
185 | class SlabNet(nn.Module):
186 | def __init__(self, num_classes):
187 | super().__init__()
188 | self.fc1 = nn.Linear(50, 75, bias=True)
189 | self.fc2 = nn.Linear(75, 100, bias=True)
190 | self.classifier = nn.Linear(100, num_classes, bias=True)
191 |
192 | def forward(self, x):
193 | fc1 = self.fc1(x)
194 | fc1 = F.relu(fc1)
195 | before_logits = self.fc2(fc1)
196 | x = F.relu(before_logits)
197 | logits = self.classifier(x)
198 | return {
199 | 'fc1': fc1,
200 | 'fc2': before_logits,
201 | 'before_logits': before_logits,
202 | 'logits': logits
203 | }
204 |
205 | def forward_representation_encoder(self, x):
206 | fc1 = self.fc1(x)
207 | fc1 = F.relu(fc1)
208 | fc2 = self.fc2(fc1)
209 | fc2 = F.relu(fc2)
210 | return {
211 | 'fc1': fc1,
212 | 'fc2': fc2
213 | }
214 |
215 | def forward_classifier(self, x):
216 | return self.classifier(x)
217 |
218 | def reset_classifier(self):
219 | self.classifier.reset_parameters()
220 |
221 | def set_representation_encoder_train(self, is_train):
222 | self.fc1.train(is_train)
223 | self.fc2.train(is_train)
224 |
225 | def set_classifier_train(self, is_train):
226 | self.classifier.train(is_train)
227 |
228 | def get_classifier_named_params(self):
229 | return (('classifier.weight', self.classifier.weight),
230 | ('classifier.bias', self.classifier.bias))
231 |
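232 |
233 | if __name__ == '__main__':
234 |     # Minimal sketch (hypothetical dims): the MLPs return a dict containing
235 |     # both the penultimate features and the logits, which lets trainers tap
236 |     # intermediate layers (e.g., for bias predictors).
237 |     mlp = MLP2(in_dims=16, hid_dims=32, num_classes=4)
238 |     out = mlp(torch.randn(8, 16))
239 |     print(out['before_logits'].shape, out['logits'].shape)  # (8, 32) (8, 4)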
--------------------------------------------------------------------------------
/trainers/learning_from_failure_trainer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | from torch import nn
4 | from torch import optim
5 |
6 | import models
7 | import models.cnn_models
8 | from trainers.base_trainer import BaseTrainer
9 | from utils.trainer_utils import grad_reverse
10 | from models import model_factory
11 | from utils.bias_retrievers import build_bias_retriever
12 | from utils.trainer_utils import grad_mult, create_optimizer
13 | from torch.optim import *
14 | import logging
15 | from models.model_factory import build_model
16 | from utils.losses import GCELoss
17 | import os
18 | from utils.metrics import Accuracy, GroupWiseAccuracy
19 | from utils.ema import ClasswiseEMA
20 |
21 |
22 | class LffTrainer(BaseTrainer):
23 | """
24 | Implementation for:
25 | Nam, Junhyun, et al. "Learning from failure: Training debiased classifier from biased classifier." (NeurIPS 2020)
26 |
27 |     The method trains a bias-only model with the generalized cross-entropy (GCE) loss, which makes the model focus on easier samples and thereby amplifies the bias.
28 |     It then uses the bias model's losses to re-weight the samples for the main model, so that easy/biased samples receive lower weights.
29 | """
30 |
31 | def __init__(self, option):
32 | super(LffTrainer, self).__init__(option)
33 |
34 | def _build_model(self):
35 | super()._build_model()
36 | self.bias_model = build_model(self.option, self.option.model_name,
37 | out_dims=self.option.num_classes,
38 | in_dims=self.option.in_dims,
39 | hid_dims=self.option.hid_dims,
40 | freeze_layers=self.option.freeze_layers)
41 | logging.getLogger().info("Bias model")
42 | logging.getLogger().info(self.bias_model)
43 | self.bias_amplification_loss = GCELoss(q=self.option.bias_loss_gamma)
44 |
45 | if self.option.cuda:
46 | self.model.cuda()
47 | self.bias_model.cuda()
48 | self.loss.cuda()
49 |
50 | def _initialization(self):
51 | super()._initialization()
52 | self.bias_loss_ema_computer = ClasswiseEMA(self.max_dataset_ixs['Train'] + 1, alpha=self.option.bias_ema_gamma)
53 | self.main_loss_ema_computer = ClasswiseEMA(self.max_dataset_ixs['Train'] + 1, alpha=self.option.bias_ema_gamma)
54 |
55 | def _build_optimizer(self):
56 | super()._build_optimizer()
57 | self.bias_optim = create_optimizer(self.option.optimizer_name,
58 | named_params=self.bias_model.named_parameters(),
59 | lr=self.option.lr,
60 | weight_decay=self.option.weight_decay,
61 | momentum=self.option.momentum,
62 | freeze_layers=self.option.freeze_layers)
63 |
64 | def _mode_setting(self, is_train=True):
65 | self.model.train(is_train)
66 | self.bias_model.train(is_train)
67 |
68 | def _train_epoch(self, epoch, data_loader):
69 | self._mode_setting(is_train=True)
70 | for i, batch in enumerate(data_loader):
71 | batch = self.prepare_batch(batch)
72 |
73 | # Pass through the main model
74 | out = self.forward_model(self.model, batch)
75 | logits = out['logits']
76 | main_loss = self.loss(logits, batch['y'].squeeze())
77 |
78 | # Pass through the bias model
79 | bias_out = self.forward_model(self.bias_model, batch)
80 | bias_logits = bias_out['logits']
81 | bias_loss = self.loss(bias_logits, batch['y'].squeeze())
82 |
83 | # Update the bias and main loss EMAs (for computing weights)
84 | self.bias_loss_ema_computer.update(bias_loss.squeeze(),
85 | batch['dataset_ix'],
86 | batch['y'].squeeze())
87 | self.main_loss_ema_computer.update(main_loss.squeeze(),
88 | batch['dataset_ix'],
89 | batch['y'].squeeze())
90 |
91 | # Perform class-wise normalization on bias loss
92 | bias_loss_ema = self.bias_loss_ema_computer.parameter[batch['dataset_ix']].clone()
93 | main_loss_ema = self.main_loss_ema_computer.parameter[batch['dataset_ix']].clone()
94 |
95 | for c in range(self.option.num_classes):
96 | dataset_ixs_for_c = torch.where(batch['y'] == c)[0]
97 | if len(dataset_ixs_for_c) == 0:
98 | continue
99 | max_bias_loss_ema = self.bias_loss_ema_computer.max_loss(c)
100 | bias_loss_ema[dataset_ixs_for_c] /= max_bias_loss_ema
101 | max_main_loss_ema = self.main_loss_ema_computer.max_loss(c)
102 | main_loss_ema[dataset_ixs_for_c] /= max_main_loss_ema
103 |
104 | # Compute sample wise weights for main model
105 | sample_weights = bias_loss_ema.cuda() / (bias_loss_ema.cuda() + main_loss_ema.cuda() + 1e-8)
106 |
107 |                 # Compute the bias amplification loss to update the parameters of the bias model
108 |                 bias_amplification_loss = self.bias_amplification_loss(bias_logits, batch['y'].squeeze())
109 |
110 | # Weight the main model's loss using sample weights
111 | main_loss = self.loss(logits, batch['y'].squeeze()) * sample_weights
112 |
113 |                 loss = main_loss.mean() + bias_amplification_loss.mean()
114 |
115 | self.bias_optim.zero_grad()
116 | self.optim.zero_grad()
117 | loss.backward()
118 | # loss.backward(retain_graph=True)
119 |
120 | # if self.option.grad_clip is not None:
121 | # torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.option.grad_clip)
122 | # torch.nn.utils.clip_grad_norm_(self.bias_model.parameters(), self.option.grad_clip)
123 | self.bias_optim.step()
124 | self.optim.step()
125 |
126 | split = 'Train'
127 | self.update_generalization_metrics(split + " Main", batch, main_loss)
128 | self.update_generalization_metrics(split + " Bias", batch, bias_loss)
129 | if self.option.enable_groupwise_metrics:
130 | self.update_groupwise_values(split, "Sample Weights", sample_weights, batch)
131 | self.update_groupwise_values(split, "Main Loss EMA", main_loss_ema, batch)
132 |                     self.update_groupwise_values(split, "Bias Amp Loss", bias_amplification_loss, batch)
133 | self.update_groupwise_values(split, "Bias Loss EMA", bias_loss_ema, batch)
134 | self.loss_visualizer.log(epoch, f'{split} Main')
135 | self.loss_visualizer.accumulate_plot_and_reset(epoch)
136 |
137 | def get_keys_to_save(self):
138 | return super().get_keys_to_save() + ['bias_model', 'bias_optim']
139 |
140 | def get_current_state(self):
141 | save_state = super().get_current_state()
142 | save_state['bias_loss_ema'] = self.bias_loss_ema_computer.parameter
143 | save_state['main_loss_ema'] = self.main_loss_ema_computer.parameter
144 | return save_state
145 |
146 | def test(self, epoch, data_key, data_loader):
147 | for model, model_key in [[self.model, 'Main'], [self.bias_model, 'Bias']]:
148 | super().test(epoch, data_key, data_loader, model=model, model_key=model_key)
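149 |
150 |
151 | if __name__ == '__main__':
152 |     # Sketch of the LFF re-weighting rule from _train_epoch (hypothetical EMA
153 |     # values): samples the bias model fits easily (low bias loss) get small
154 |     # weights, while samples it struggles with get weights close to 1.
155 |     bias_loss_ema = torch.tensor([0.1, 1.0, 2.0])
156 |     main_loss_ema = torch.tensor([1.0, 1.0, 0.5])
157 |     w = bias_loss_ema / (bias_loss_ema + main_loss_ema + 1e-8)
158 |     print(w)  # tensor([0.0909, 0.5000, 0.8000])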
--------------------------------------------------------------------------------
/option.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import argparse
3 |
4 | parser = argparse.ArgumentParser()
5 |
6 | parser.add_argument('-e', '--expt_name', required=False, help='Experiment will be saved under this name.')
7 | parser.add_argument('--expt_type', type=str,
8 | help='The experimental configuration to use e.g., all CelebA runs can use celebA_experiments value for this argument')
9 | parser.add_argument('--dataset_name', required=False, help='Name of the dataset')
10 | parser.add_argument('--root_dir', default=None)
11 | parser.add_argument('--num_classes', type=int, help='Number of classes')
12 | parser.add_argument('--batch_size', default=None, type=int, help='Mini-batch size')
13 | parser.add_argument('--momentum', default=0.9, type=float, help='SGD momentum')
14 | parser.add_argument('--lr', type=float, help='Learning rate')
15 | parser.add_argument('--weight_decay', default=0, type=float, help='Weight decay')
16 | parser.add_argument('--epochs', type=int, help='Num of epochs')
17 | parser.add_argument('--load_checkpoint', default=None, help='Checkpoint to resume from.')
18 |
19 | parser.add_argument('--test_every', default=5, type=int, help='Interval to perform testing')
20 | parser.add_argument('--test_epochs', nargs='+', type=int,
21 | help='List of epochs to perform testing at. Is compatible with test_every argument -- both will be used')
22 | parser.add_argument('--save_predictions_every', default=25, type=int, help='Interval to save predictions and metrics')
23 | parser.add_argument('--save_model_every', default=25, type=int, help='Interval to save the model checkpoint')
24 | parser.add_argument('--data_dir', required=False,
25 | help='Directory where the dataset is stored. We usually assume this convention: {root_dir}/{dataset_name}')
26 | parser.add_argument('--save_dir', required=False, help='Logs, Checkpoints, predictions and metrics will be saved here.')
27 |
28 | parser.add_argument('--random_seed', type=int, help='Random seed', default=1)
29 | parser.add_argument('--num_workers', default=8, type=int, help='Number of workers in data loader')
30 |
31 | parser.add_argument('--grad_clip', type=float, default=None,
32 | help="Grad clip. This wasn't used for any of the methods for the comparison experiments.")
33 | parser.add_argument('--trainer_name', type=str, help="Name of the method e.g., BaseTrainer or GroupDROTrainer")
34 | parser.add_argument('--model_name', type=str,
35 | help="Name of the main model. For two branch models e.g., RUBi, this refers to the name for the main branch.")
36 | parser.add_argument('--bias_model_name', type=str,
37 | help="For two/multi branch setups, this either predicts the bias variables or uses them as input.")
38 | parser.add_argument('--optimizer_name', type=str, default=None, help="e.g., SGD, Adam")
39 | parser.add_argument('--bias_proba', type=float, default=1.1, help='p_bias for BiasedMNIST')
40 | parser.add_argument('--bias_var', type=float, default=0.02)
41 | parser.add_argument('--dummy', action='store_true',
42 | help="A flag used for debugging runs e.g., setting num_workers=0 to make debugging possible and using a smaller dataset size.")
43 | parser.add_argument('--balanced_sampling_attributes', type=str, nargs='+', default=None,
44 |                     help="List of attributes (as returned in a mini-batch) which should be used for balancing, i.e., every unique combination of these attributes will have equal probability of being sampled. "
45 |                          "Useful for GroupDRO.")
46 | parser.add_argument('--balanced_sampling_gamma', type=float, default=1.0,
47 | help="Exponentiation for inverse group probability. Higher values would oversample minority patterns a lot.")
48 | parser.add_argument('--freeze_layers', default=None, nargs='+',
49 |                     help="Can be used to freeze layers, i.e., exclude them from optimization. "
50 |                          "When freezing, you need to disable batch norm and other model-specific settings yourself.")
51 | parser.add_argument('--custom_lr_config', default=None, type=str, help="Unused (deprecated) argument.")
52 |
53 | parser.add_argument('--grad_reverse_factor', type=float, default=-0.1,
54 | help="Reversal parameter for adversarial debiasing e.g., learning not to learn (LNL). Use a negative value.")
55 | parser.add_argument('--loss_type', type=str, default='CrossEntropyLoss')
56 |
57 | # Arguments specific to GroupDROTrainer
58 | parser.add_argument('--num_groups', type=int, help="Number of groups for grouping methods e.g., GroupDRO.")
59 | parser.add_argument('--group_weight_step_size', type=float, default=0.01,
60 | help="Learning rate to update group weights in GroupDRO.")
61 | parser.add_argument('--group_mode', type=str,
62 | help='Grouping mode e.g., unique_bias_value or majority_minority for BiasedMNIST. TODO: remove this.')
63 | parser.add_argument('--bias_predictor_in_layer', type=str, default=None,
64 | help="LNL predicts bias variables from this layer.")
65 | parser.add_argument('--bias_predictor_name', type=str, default=None, help="Bias model name for LNL.")
66 |
67 | parser.add_argument('--bias_variable_name', type=str, default=None,
68 | help="Name of the bias variable used by explicit methods and also used to compute metrics.")
69 | parser.add_argument('--target_name', type=str, default=None, help="Variable name to predict i.e., class variable.")
70 | parser.add_argument('--group_by', type=str, default=None,
71 | help="Dataset is grouped by this variable, usually set to group_ix.")
72 | parser.add_argument('--key_to_group_by', type=str, default=None, help="This provides names for the groups.")
73 |
74 | # Arguments specific to LffTrainer
75 | parser.add_argument('--bias_loss_gamma', type=float, default=0.7, help="Loss gamma for LFF")
76 | parser.add_argument('--bias_ema_gamma', type=float, default=0.7, help="EMA gamma for LFF")
77 | parser.add_argument('--bias_model_hid_dims', type=int, help='Hidden dimensions for the bias model')
78 |
79 | parser.add_argument('--entropy_loss_weight', type=float, default=0, help="Weight for entropy loss weight in LNL.")
80 |
81 | parser.add_argument('--dataset_info', help="Used internally to set dataset specific attributes.")
82 | parser.add_argument('--enable_groupwise_metrics', action='store_true')
83 | parser.add_argument('--project_name', type=str, default='Bias-Mitigators', help="Results will be saved here.")
84 |
85 | # Arguments specific to RunningFocalLossTrainer
86 | parser.add_argument('--in_dims', type=int, default=None)
87 | parser.add_argument('--hid_dims', type=int, default=None)
88 | parser.add_argument('--grad_penalty_weight', type=float, default=1.0)
89 | parser.add_argument('--expt_dir', type=str)
90 | parser.add_argument('--bias_variable_type', type=str)
91 |
92 | parser.add_argument('--spectral_decoupling_lambda', type=float)
93 | parser.add_argument('--spectral_decoupling_lambdas', type=float, nargs='+')
94 | parser.add_argument('--spectral_decoupling_gamma', type=float)
95 | parser.add_argument('--spectral_decoupling_gammas', type=float, nargs='+')
96 |
97 | parser.add_argument('--num_envs_per_batch', type=int,
98 | help="Used by IRMv1. Each mini-batch will contain the specified number of environments.")
99 |
100 |
101 | def get_option():
102 | option = parser.parse_args()
103 | option.cuda = True
104 | if option.dummy:
105 | option.num_workers = 0
106 | return option
107 |
108 |
109 | # Defaults used when the bash scripts are not used
110 | ROOT = '/hdd/user'
111 | EXPT_ROOT = '/hdd/user/bias_mitigators'
112 |
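113 |
114 | if __name__ == '__main__':
115 |     # Minimal usage sketch (hypothetical argument values):
116 |     #   python option.py --expt_name demo --trainer_name BaseTrainer \
117 |     #       --dataset_name biased_mnist --lr 1e-3 --epochs 10
118 |     opt = get_option()
119 |     print(opt.expt_name, opt.trainer_name, opt.lr, opt.epochs)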
--------------------------------------------------------------------------------
/models/coordconv.py:
--------------------------------------------------------------------------------
1 | # https://raw.githubusercontent.com/walsvid/CoordConv/052d45354bae46f5fdb0f906fac31f3e1f9debe2/coordconv.py
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.modules.conv as conv
5 |
6 |
7 | class AddCoords(nn.Module):
8 | def __init__(self, rank, with_r=False, use_cuda=True):
9 | super(AddCoords, self).__init__()
10 | self.rank = rank
11 | self.with_r = with_r
12 | self.use_cuda = use_cuda
13 |
14 | def forward(self, input_tensor):
15 | """
16 |         :param input_tensor: shape (N, C_in, *spatial_dims), with rank-many spatial dims
17 |         :return: input_tensor with coordinate channels (and optionally a radius channel) appended along dim 1
18 | """
19 | if self.rank == 1:
20 | batch_size_shape, channel_in_shape, dim_x = input_tensor.shape
21 | xx_range = torch.arange(dim_x, dtype=torch.int32)
22 | xx_channel = xx_range[None, None, :]
23 |
24 | xx_channel = xx_channel.float() / (dim_x - 1)
25 | xx_channel = xx_channel * 2 - 1
26 | xx_channel = xx_channel.repeat(batch_size_shape, 1, 1)
27 |
28 |             if torch.cuda.is_available() and self.use_cuda:
29 | input_tensor = input_tensor.cuda()
30 | xx_channel = xx_channel.cuda()
31 | out = torch.cat([input_tensor, xx_channel], dim=1)
32 |
33 | if self.with_r:
34 | rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2))
35 | out = torch.cat([out, rr], dim=1)
36 |
37 | elif self.rank == 2:
38 | batch_size_shape, channel_in_shape, dim_y, dim_x = input_tensor.shape
39 | xx_ones = torch.ones([1, 1, 1, dim_x], dtype=torch.int32)
40 | yy_ones = torch.ones([1, 1, 1, dim_y], dtype=torch.int32)
41 |
42 | xx_range = torch.arange(dim_y, dtype=torch.int32)
43 | yy_range = torch.arange(dim_x, dtype=torch.int32)
44 | xx_range = xx_range[None, None, :, None]
45 | yy_range = yy_range[None, None, :, None]
46 |
47 | xx_channel = torch.matmul(xx_range, xx_ones)
48 | yy_channel = torch.matmul(yy_range, yy_ones)
49 |
50 | # transpose y
51 | yy_channel = yy_channel.permute(0, 1, 3, 2)
52 |
53 | xx_channel = xx_channel.float() / (dim_y - 1)
54 | yy_channel = yy_channel.float() / (dim_x - 1)
55 |
56 | xx_channel = xx_channel * 2 - 1
57 | yy_channel = yy_channel * 2 - 1
58 |
59 | xx_channel = xx_channel.repeat(batch_size_shape, 1, 1, 1)
60 | yy_channel = yy_channel.repeat(batch_size_shape, 1, 1, 1)
61 |
62 |             if torch.cuda.is_available() and self.use_cuda:
63 | input_tensor = input_tensor.cuda()
64 | xx_channel = xx_channel.cuda()
65 | yy_channel = yy_channel.cuda()
66 | out = torch.cat([input_tensor, xx_channel, yy_channel], dim=1)
67 |
68 | if self.with_r:
69 | rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2))
70 | out = torch.cat([out, rr], dim=1)
71 |
72 | elif self.rank == 3:
73 | batch_size_shape, channel_in_shape, dim_z, dim_y, dim_x = input_tensor.shape
74 | xx_ones = torch.ones([1, 1, 1, 1, dim_x], dtype=torch.int32)
75 | yy_ones = torch.ones([1, 1, 1, 1, dim_y], dtype=torch.int32)
76 | zz_ones = torch.ones([1, 1, 1, 1, dim_z], dtype=torch.int32)
77 |
78 | xy_range = torch.arange(dim_y, dtype=torch.int32)
79 | xy_range = xy_range[None, None, None, :, None]
80 |
81 | yz_range = torch.arange(dim_z, dtype=torch.int32)
82 | yz_range = yz_range[None, None, None, :, None]
83 |
84 | zx_range = torch.arange(dim_x, dtype=torch.int32)
85 | zx_range = zx_range[None, None, None, :, None]
86 |
87 | xy_channel = torch.matmul(xy_range, xx_ones)
88 | xx_channel = torch.cat([xy_channel + i for i in range(dim_z)], dim=2)
89 |
90 | yz_channel = torch.matmul(yz_range, yy_ones)
91 | yz_channel = yz_channel.permute(0, 1, 3, 4, 2)
92 | yy_channel = torch.cat([yz_channel + i for i in range(dim_x)], dim=4)
93 |
94 | zx_channel = torch.matmul(zx_range, zz_ones)
95 | zx_channel = zx_channel.permute(0, 1, 4, 2, 3)
96 | zz_channel = torch.cat([zx_channel + i for i in range(dim_y)], dim=3)
97 |             # NOTE: unlike ranks 1 and 2, these channels are neither cast to float nor normalized to [-1, 1] (kept as in the upstream file)
98 |             if torch.cuda.is_available() and self.use_cuda:
99 | input_tensor = input_tensor.cuda()
100 | xx_channel = xx_channel.cuda()
101 | yy_channel = yy_channel.cuda()
102 | zz_channel = zz_channel.cuda()
103 | out = torch.cat([input_tensor, xx_channel, yy_channel, zz_channel], dim=1)
104 |
105 | if self.with_r:
106 | rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) +
107 | torch.pow(yy_channel - 0.5, 2) +
108 | torch.pow(zz_channel - 0.5, 2))
109 | out = torch.cat([out, rr], dim=1)
110 | else:
111 | raise NotImplementedError
112 |
113 | return out
114 |
115 |
116 | class CoordConv1d(conv.Conv1d):
117 | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
118 | padding=0, dilation=1, groups=1, bias=True, with_r=False, use_cuda=True):
119 | super(CoordConv1d, self).__init__(in_channels, out_channels, kernel_size,
120 | stride, padding, dilation, groups, bias)
121 | self.rank = 1
122 | self.addcoords = AddCoords(self.rank, with_r, use_cuda=use_cuda)
123 | self.conv = nn.Conv1d(in_channels + self.rank + int(with_r), out_channels,
124 | kernel_size, stride, padding, dilation, groups, bias)
125 |
126 | def forward(self, input_tensor):
127 | """
128 |         input_tensor shape: (N, C_in, L)
129 |         output_tensor shape: (N, C_out, L_out)
130 |         :return: CoordConv1d result
131 | """
132 | out = self.addcoords(input_tensor)
133 | out = self.conv(out)
134 |
135 | return out
136 |
137 |
138 | class CoordConv2d(conv.Conv2d):
139 | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
140 | padding=0, dilation=1, groups=1, bias=True, with_r=False, use_cuda=True):
141 | super(CoordConv2d, self).__init__(in_channels, out_channels, kernel_size,
142 | stride, padding, dilation, groups, bias)
143 | self.rank = 2
144 | self.addcoords = AddCoords(self.rank, with_r, use_cuda=use_cuda)
145 | self.conv = nn.Conv2d(in_channels + self.rank + int(with_r), out_channels,
146 | kernel_size, stride, padding, dilation, groups, bias)
147 |
148 | def forward(self, input_tensor):
149 | """
150 |         input_tensor shape: (N, C_in, H, W)
151 |         output_tensor shape: (N, C_out, H_out, W_out)
152 |         :return: CoordConv2d result
153 | """
154 | out = self.addcoords(input_tensor)
155 | out = self.conv(out)
156 |
157 | return out
158 |
159 |
160 | class CoordConv3d(conv.Conv3d):
161 | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
162 | padding=0, dilation=1, groups=1, bias=True, with_r=False, use_cuda=True):
163 | super(CoordConv3d, self).__init__(in_channels, out_channels, kernel_size,
164 | stride, padding, dilation, groups, bias)
165 | self.rank = 3
166 | self.addcoords = AddCoords(self.rank, with_r, use_cuda=use_cuda)
167 | self.conv = nn.Conv3d(in_channels + self.rank + int(with_r), out_channels,
168 | kernel_size, stride, padding, dilation, groups, bias)
169 |
170 | def forward(self, input_tensor):
171 | """
172 |         input_tensor shape: (N, C_in, D, H, W)
173 |         output_tensor shape: (N, C_out, D_out, H_out, W_out)
174 |         :return: CoordConv3d result
175 | """
176 | out = self.addcoords(input_tensor)
177 | out = self.conv(out)
178 |
179 | return out
180 |
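181 | # --- Usage sketch (added for illustration; not part of the upstream file) ---
182 | # AddCoords appends normalized x/y coordinate channels (plus an optional
183 | # radius channel when with_r=True), so the wrapped convolution receives
184 | # in_channels + rank (+ 1) input channels.
185 | if __name__ == '__main__':
186 |     conv = CoordConv2d(3, 16, kernel_size=3, padding=1, with_r=True, use_cuda=False)
187 |     x = torch.randn(2, 3, 32, 32)
188 |     y = conv(x)  # AddCoords yields 3 + 2 + 1 = 6 channels; y has shape (2, 16, 32, 32)
189 |     print(y.shape)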
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import logging
3 | import torch
4 |
5 |
6 | class Accuracy():
7 | def __init__(self, num_classes):
8 | self.num_classes = num_classes
9 | self.reset()
10 |
11 | def update(self, pred_ys, gt_ys):
12 | for pred_y, gt_y in zip(pred_ys, gt_ys):
13 | if pred_y == gt_y:
14 | self.correct[pred_y] += 1
15 | self.total[gt_y] += 1
16 |
17 | def get_accuracy(self):
18 | return self.correct.sum() / self.total.sum()
19 |
20 | def get_per_class_accuracy(self):
21 |         correct = np.array(self.correct)  # copy, so the running counts are not mutated below
22 |         total = np.array(self.total)
23 | for tix, (c, t) in enumerate(zip(correct, total)):
24 | if t == 0:
25 | correct[tix] = 1
26 | total[tix] = 1
27 | return correct / total
28 |
29 | def get_mean_per_class_accuracy(self):
30 | return self.get_per_class_accuracy().mean()
31 |
32 | def reset(self):
33 | self.correct = np.zeros((self.num_classes))
34 | self.total = np.zeros((self.num_classes))
35 |
36 |
37 | class GroupWiseAccuracy():
38 | def __init__(self):
39 | self.reset()
40 |
41 | def update(self, pred_ys, gt_ys, group_names):
42 | for pred_y, gt_y, group_name in zip(pred_ys, gt_ys, group_names):
43 | group_name = str(group_name)
44 | if group_name not in self.group_wise_total:
45 | self.group_wise_total[group_name] = 0
46 | self.group_wise_correct[group_name] = 0
47 | if pred_y == gt_y:
48 | self.group_wise_correct[group_name] += 1
49 | self.group_wise_total[group_name] += 1
50 |
51 | def get_per_group_accuracy(self):
52 | per_group_accuracy = {}
53 | for group_name in self.group_wise_correct:
54 | per_group_accuracy[group_name] = self.group_wise_correct[group_name] / self.group_wise_total[group_name]
55 | return per_group_accuracy
56 |
57 | def get_mean_per_group_accuracy(self):
58 | total_acc, total_num = 0, 0
59 | per_group_accuracy = self.get_per_group_accuracy()
60 | for group_name in per_group_accuracy:
61 | total_acc += per_group_accuracy[group_name]
62 | total_num += 1
63 | return total_acc / total_num
64 |
65 | def reset(self):
66 | self.group_wise_correct = {}
67 | self.group_wise_total = {}
68 |
69 | def log(self, prefix=''):
70 | log_str = prefix
71 | per_group_accuracy = self.get_per_group_accuracy()
72 | group_names = sorted([k for k in per_group_accuracy.keys()])
73 | accuracies = ""
74 | for group_name in group_names:
75 | log_str += '%s: %.2f%% ' % (group_name, per_group_accuracy[group_name] * 100)
76 | accuracies += '%.2f%%, ' % (per_group_accuracy[group_name] * 100)
77 | log_str += ' MPG: %.2f%%' % (self.get_mean_per_group_accuracy() * 100)
78 |         logging.getLogger().info(log_str)  # emit the per-group summary built above
79 | logging.getLogger().info(f"Group names {group_names}")
80 | logging.getLogger().info(f"Accuracies {accuracies}")
81 |
82 |
83 | class ModuleStatsComputer:
84 | """Stores sensitivities, GT classes and predicted classes"""
85 |
86 | def __init__(self, num_modules, num_classes):
87 | self.reset()
88 | self.num_modules = num_modules
89 | self.num_classes = num_classes
90 |
91 | def reset(self):
92 | self.sensitivities = []
93 | self.gt_class_ixs, self.pred_class_ixs = [], []
94 | self.group_names = []
95 |
96 | def update(self, sensitivities, gt_class_ixs, pred_class_ixs, group_names):
97 | self.sensitivities += sensitivities
98 | self.gt_class_ixs += gt_class_ixs
99 | self.pred_class_ixs += pred_class_ixs
100 | self.group_names += [str(gn) for gn in group_names]
101 |
102 | def log(self):
103 | sensitivities = np.asarray(self.sensitivities)
104 | gt_class_ixs = np.asarray(self.gt_class_ixs)
105 | pred_class_ixs = np.asarray(self.pred_class_ixs)
106 | most_sensitive_ixs = np.argmax(sensitivities, axis=1)
107 | group_names = np.asarray(self.group_names)
108 |
109 | most_sensitive_counts = {ix: len(np.nonzero(most_sensitive_ixs == ix)[0])
110 | for ix in range(0, self.num_modules)}
111 | most_sens_n_correct = {}
112 | accuracy_when_most_sensitive = {}
113 | group_distribution = {}
114 | overall_metrics = {}
115 | for module_ix in range(self.num_modules):
116 | most_sens_n_correct[module_ix] = len(np.intersect1d(np.nonzero(most_sensitive_ixs == module_ix)[0],
117 | np.nonzero(gt_class_ixs == pred_class_ixs)[0]))
118 | accuracy_when_most_sensitive[module_ix] = 100 * most_sens_n_correct[module_ix] / max(most_sensitive_counts[
119 | module_ix], 1)
120 | if module_ix not in group_distribution:
121 | group_distribution[module_ix] = {}
122 | for gn in group_names[np.nonzero(most_sensitive_ixs == module_ix)[0]]:
123 | if gn not in group_distribution[module_ix]:
124 | group_distribution[module_ix][gn] = 0
125 | group_distribution[module_ix][gn] += 1
126 |
127 | for module_ix in range(self.num_modules):
128 | if int(most_sensitive_counts[module_ix]) > 0:
129 | overall_metrics[module_ix] = {
130 | 'most_sensitive_count': int(most_sensitive_counts[module_ix]),
131 | 'accuracy_when_most_sensitive': '%.2f%%' % (float(accuracy_when_most_sensitive[module_ix])),
132 | 'group_distribution': group_distribution[module_ix]
133 | }
134 |         logging.getLogger().info(f'Module-wise stats: {overall_metrics}')  # actually emit the stats this method computes
135 |
136 | class GradientTracker():
137 | def __init__(self, num_samples, num_epochs):
138 | # l2_norm = torch.norm(gradients.detach(), p=2, dim=1)
139 | # abs = torch.sum(torch.abs(gradients.detach()), dim=1)
140 |
141 | self.l2_norms = torch.zeros((num_samples, num_epochs))
142 | self.abs_vals = torch.zeros((num_samples, num_epochs))
143 |         self.groups = np.asarray(['NoneNoneNoneNoneNoneNone'] * (num_samples))  # long placeholder so numpy's fixed-width string dtype can hold real group names
144 | self.unq_groups = {}
145 |
146 | def update(self, epoch, dataset_ixs, gradients, groups):
147 | self.epoch = epoch
148 |         self.l2_norms[dataset_ixs, epoch - 1] = torch.norm(gradients.detach().flatten(1), p=2, dim=1).cpu()
149 |         self.abs_vals[dataset_ixs, epoch - 1] = torch.sum(torch.abs(gradients.detach().flatten(1)), dim=1).cpu()
150 | if epoch == 1:
151 | self.groups[dataset_ixs] = groups
152 | for g in groups:
153 | if g not in self.unq_groups:
154 | self.unq_groups[g] = g
155 |
156 | def get_groupwise_values(self):
157 | # return mean, variance, normalized variance, mean of running variance?
158 | # How does it evolve over time
159 | values = {}
160 | for g in self.unq_groups:
161 | if g not in values:
162 | values[g] = {}
163 | grp_ixs = np.nonzero(self.groups == g)[0]
164 | l2_norms = self.l2_norms[grp_ixs, self.epoch - 1]
165 | values[g]['mean_of_l2_norms'] = torch.mean(l2_norms)
166 | values[g]['variance_of_l2_norms'] = torch.std(l2_norms) ** 2
167 |
168 | squared_l2_norms = l2_norms ** 2
169 | values[g]['mean_of_squared_l2_norms'] = torch.mean(squared_l2_norms)
170 | values[g]['variance_of_squared_l2_norms'] = torch.std(squared_l2_norms) ** 2
171 |
172 |             abs_vals = self.abs_vals[grp_ixs, self.epoch - 1]  # renamed to avoid shadowing the builtin abs()
173 |             values[g]['mean_of_abs'] = torch.mean(abs_vals)
174 |             values[g]['variance_of_abs'] = torch.std(abs_vals) ** 2
175 |
176 | return values
177 |
178 |
179 | class PredictionChangeTracker():
180 | def __init__(self, num_samples, num_epochs):
181 | self.preds = torch.zeros((num_samples, num_epochs))
182 | self.num_pred_changes = torch.zeros((num_samples))
183 |         self.groups = np.asarray(['NoneNoneNoneNoneNoneNone'] * (num_samples))  # long placeholder so numpy's fixed-width string dtype can hold real group names
184 | self.unq_groups = {}
185 |
186 | def update(self, dataset_ixs, epoch, logits, labels, groups):
187 | self.epoch = epoch
188 | self.preds[dataset_ixs, epoch - 1] = torch.argmax(logits.cpu(), dim=1).float()
189 | if epoch > 1:
190 | pred_change_mask = torch.where(self.preds[dataset_ixs, epoch - 2] != self.preds[dataset_ixs, epoch - 1],
191 | torch.ones_like(dataset_ixs),
192 | torch.zeros_like(dataset_ixs))
193 | self.groups[dataset_ixs] = groups
194 | for g in groups:
195 | if g not in self.unq_groups:
196 | self.unq_groups[g] = g
197 | self.num_pred_changes[dataset_ixs] += pred_change_mask
198 |
199 | def get_values(self):
200 | group_change_pct = {}
201 | max_num_changes = {}
202 | mean_num_changes = {}
203 | for g in self.unq_groups:
204 | grp_ixs = np.nonzero(self.groups == g)[0]
205 | num_changes = self.num_pred_changes[grp_ixs].sum()
206 | grp_len = len(grp_ixs)
207 | group_change_pct[g] = (num_changes / (grp_len * (self.epoch - 1))) * 100
208 | max_num_changes[g] = self.num_pred_changes[grp_ixs].max().item()
209 |
210 | return {
211 | 'group_change_percent': group_change_pct,
212 | 'max_num_changes': max_num_changes,
213 | }
214 |
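215 | # --- Usage sketch (added for illustration; not part of the original file) ---
216 | if __name__ == '__main__':
217 |     acc = Accuracy(num_classes=3)
218 |     acc.update(pred_ys=[0, 1, 2, 2], gt_ys=[0, 1, 1, 2])
219 |     print(acc.get_accuracy())                 # 3 correct out of 4 -> 0.75
220 |     print(acc.get_mean_per_class_accuracy())  # mean of [1.0, 0.5, 1.0]
221 |
222 |     grp = GroupWiseAccuracy()
223 |     grp.update(pred_ys=[0, 1], gt_ys=[0, 0], group_names=['a', 'b'])
224 |     print(grp.get_per_group_accuracy())       # {'a': 1.0, 'b': 0.0}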
--------------------------------------------------------------------------------
/models/variable_width_resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | try:
5 | from torch.hub import load_state_dict_from_url
6 | except ImportError:
7 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
8 |
9 |
10 | # __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
11 | # 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
12 | # 'wide_resnet50_2', 'wide_resnet101_2']
13 |
14 |
15 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
16 | """3x3 convolution with padding"""
17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
18 | padding=dilation, groups=groups, bias=False, dilation=dilation)
19 |
20 |
21 | def conv1x1(in_planes, out_planes, stride=1):
22 | """1x1 convolution"""
23 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
24 |
25 |
26 | class BasicBlock(nn.Module):
27 | expansion = 1
28 | __constants__ = ['downsample']
29 |
30 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
31 | base_width=64, dilation=1, norm_layer=None):
32 | super(BasicBlock, self).__init__()
33 | if norm_layer is None:
34 | norm_layer = nn.BatchNorm2d
35 | if groups != 1 or base_width != 64:
36 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
37 | if dilation > 1:
38 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
39 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
40 | self.conv1 = conv3x3(inplanes, planes, stride)
41 | self.bn1 = norm_layer(planes)
42 | self.relu = nn.ReLU(inplace=True)
43 | self.conv2 = conv3x3(planes, planes)
44 | self.bn2 = norm_layer(planes)
45 | self.downsample = downsample
46 | self.stride = stride
47 |
48 | def forward(self, x):
49 | identity = x
50 |
51 | out = self.conv1(x)
52 | out = self.bn1(out)
53 | out = self.relu(out)
54 |
55 | out = self.conv2(out)
56 | out = self.bn2(out)
57 |
58 | if self.downsample is not None:
59 | identity = self.downsample(x)
60 |
61 | out += identity
62 | out = self.relu(out)
63 |
64 | return out
65 |
66 |
67 | class Bottleneck(nn.Module):
68 | expansion = 4
69 | __constants__ = ['downsample']
70 |
71 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
72 | base_width=64, dilation=1, norm_layer=None):
73 | super(Bottleneck, self).__init__()
74 | if norm_layer is None:
75 | norm_layer = nn.BatchNorm2d
76 | width = int(planes * (base_width / 64.)) * groups
77 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
78 | self.conv1 = conv1x1(inplanes, width)
79 | self.bn1 = norm_layer(width)
80 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
81 | self.bn2 = norm_layer(width)
82 | self.conv3 = conv1x1(width, planes * self.expansion)
83 | self.bn3 = norm_layer(planes * self.expansion)
84 | self.relu = nn.ReLU(inplace=True)
85 | self.downsample = downsample
86 | self.stride = stride
87 |
88 | def forward(self, x):
89 | identity = x
90 |
91 | out = self.conv1(x)
92 | out = self.bn1(out)
93 | out = self.relu(out)
94 |
95 | out = self.conv2(out)
96 | out = self.bn2(out)
97 | out = self.relu(out)
98 |
99 | out = self.conv3(out)
100 | out = self.bn3(out)
101 |
102 | if self.downsample is not None:
103 | identity = self.downsample(x)
104 |
105 | out += identity
106 | out = self.relu(out)
107 |
108 | return out
109 |
110 |
111 | class VariableWidthResNet(nn.Module):
112 |
113 | def __init__(self, block, layers, width, num_classes=1000, zero_init_residual=False,
114 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
115 | norm_layer=None):
116 | super(VariableWidthResNet, self).__init__()
117 | if norm_layer is None:
118 | norm_layer = nn.BatchNorm2d
119 | self._norm_layer = norm_layer
120 |
121 | self.inplanes = width
122 | self.dilation = 1
123 | if replace_stride_with_dilation is None:
124 | # each element in the tuple indicates if we should replace
125 | # the 2x2 stride with a dilated convolution instead
126 | replace_stride_with_dilation = [False, False, False]
127 | if len(replace_stride_with_dilation) != 3:
128 | raise ValueError("replace_stride_with_dilation should be None "
129 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
130 | self.groups = groups
131 | self.base_width = width_per_group
132 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
133 | bias=False)
134 | self.bn1 = norm_layer(self.inplanes)
135 | self.relu = nn.ReLU(inplace=True)
136 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
137 | self.layer1 = self._make_layer(block, width, layers[0])
138 | self.layer2 = self._make_layer(block, width * 2, layers[1], stride=2,
139 | dilate=replace_stride_with_dilation[0])
140 | self.layer3 = self._make_layer(block, width * 4, layers[2], stride=2,
141 | dilate=replace_stride_with_dilation[1])
142 | self.layer4 = self._make_layer(block, width * 8, layers[3], stride=2,
143 | dilate=replace_stride_with_dilation[2])
144 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
145 | self.fc = nn.Linear(8 * width * block.expansion, num_classes)
146 |
147 | for m in self.modules():
148 | if isinstance(m, nn.Conv2d):
149 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
150 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
151 | nn.init.constant_(m.weight, 1)
152 | nn.init.constant_(m.bias, 0)
153 |
154 | # Zero-initialize the last BN in each residual branch,
155 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
156 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
157 | if zero_init_residual:
158 | for m in self.modules():
159 | if isinstance(m, Bottleneck):
160 | nn.init.constant_(m.bn3.weight, 0)
161 | elif isinstance(m, BasicBlock):
162 | nn.init.constant_(m.bn2.weight, 0)
163 |
164 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
165 | norm_layer = self._norm_layer
166 | downsample = None
167 | previous_dilation = self.dilation
168 | if dilate:
169 | self.dilation *= stride
170 | stride = 1
171 | if stride != 1 or self.inplanes != planes * block.expansion:
172 | downsample = nn.Sequential(
173 | conv1x1(self.inplanes, planes * block.expansion, stride),
174 | norm_layer(planes * block.expansion),
175 | )
176 |
177 | layers = []
178 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
179 | self.base_width, previous_dilation, norm_layer))
180 | self.inplanes = planes * block.expansion
181 | for _ in range(1, blocks):
182 | layers.append(block(self.inplanes, planes, groups=self.groups,
183 | base_width=self.base_width, dilation=self.dilation,
184 | norm_layer=norm_layer))
185 |
186 | return nn.Sequential(*layers)
187 |
188 | def _forward_impl(self, x):
189 | # See note [TorchScript super()]
190 | x = self.conv1(x)
191 | x = self.bn1(x)
192 | x = self.relu(x)
193 | x = self.maxpool(x)
194 |
195 | x = self.layer1(x)
196 | x = self.layer2(x)
197 | x = self.layer3(x)
198 | x = self.layer4(x)
199 |
200 | x = self.avgpool(x)
201 | x = torch.flatten(x, 1)
202 | x = self.fc(x)
203 |
204 | return x
205 |
206 | def forward(self, x):
207 | return self._forward_impl(x)
208 |
209 |
210 | def _vwresnet(arch, block, layers, width, pretrained, progress, **kwargs):
211 | assert not pretrained, "No pretrained model for variable width ResNets"
212 | model = VariableWidthResNet(block, layers, width, **kwargs)
213 | return model
214 |
215 |
216 | def resnet10vw(width, pretrained=False, progress=True, **kwargs):
217 | r"""ResNet-18 model from
218 | `"Deep Residual Learning for Image Recognition" `_
219 | Args:
220 | pretrained (bool): If True, returns a model pre-trained on ImageNet
221 | progress (bool): If True, displays a progress bar of the download to stderr
222 | """
223 | return _vwresnet('resnet10', BasicBlock, [1, 1, 1, 1], width, pretrained, progress,
224 | **kwargs)
225 |
226 |
227 | def resnet18vw(width, pretrained=False, progress=True, **kwargs):
228 | r"""ResNet-18 model from
229 | `"Deep Residual Learning for Image Recognition" `_
230 | Args:
231 | pretrained (bool): If True, returns a model pre-trained on ImageNet
232 | progress (bool): If True, displays a progress bar of the download to stderr
233 | """
234 | return _vwresnet('resnet18', BasicBlock, [2, 2, 2, 2], width, pretrained, progress,
235 | **kwargs)
236 |
237 |
238 | def resnet34vw(width, pretrained=False, progress=True, **kwargs):
239 | r"""ResNet-34 model from
240 | `"Deep Residual Learning for Image Recognition" `_
241 | Args:
242 | pretrained (bool): If True, returns a model pre-trained on ImageNet
243 | progress (bool): If True, displays a progress bar of the download to stderr
244 | """
245 | return _vwresnet('resnet34', BasicBlock, [3, 4, 6, 3], width, pretrained, progress,
246 | **kwargs)
247 |
248 |
249 | def resnet50vw(width, pretrained=False, progress=True, **kwargs):
250 | r"""ResNet-50 model from
251 | `"Deep Residual Learning for Image Recognition" `_
252 | Args:
253 | pretrained (bool): If True, returns a model pre-trained on ImageNet
254 | progress (bool): If True, displays a progress bar of the download to stderr
255 | """
256 | return _vwresnet('resnet50', Bottleneck, [3, 4, 6, 3], width, pretrained, progress,
257 | **kwargs)
258 |
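259 | # --- Usage sketch (added for illustration; not part of the original file) ---
260 | # `width` sets the stem width; each stage doubles it, so the classifier
261 | # input has 8 * width * block.expansion features.
262 | if __name__ == '__main__':
263 |     model = resnet18vw(width=32, num_classes=10)
264 |     logits = model(torch.randn(2, 3, 64, 64))
265 |     print(logits.shape)  # torch.Size([2, 10])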
--------------------------------------------------------------------------------
/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | import collections.abc  # Mapping/Sequence live in collections.abc; the bare collections aliases were removed in Python 3.10
2 | import torch
3 | from torch.autograd import Variable
4 | from torch.utils.data import Sampler
5 |
6 |
7 | # https://raw.githubusercontent.com/Cadene/bootstrap.pytorch/master/bootstrap/datasets/transforms.py
8 |
9 | class Compose(object):
10 | """Composes several collate together.
11 |
12 | Args:
13 | transforms (list of ``Collate`` objects): list of transforms to compose.
14 | """
15 |
16 | def __init__(self, transforms):
17 | self.transforms = transforms
18 |
19 | def __call__(self, batch):
20 | for transform in self.transforms:
21 | batch = transform(batch)
22 | return batch
23 |
24 |
25 | class ListDictsToDictLists(object):
26 |
27 | def __init__(self):
28 | pass
29 |
30 | def __call__(self, batch):
31 | batch = self.ld_to_dl(batch)
32 | return batch
33 |
34 | def ld_to_dl(self, batch):
35 |         if isinstance(batch[0], collections.abc.Mapping):
36 | return {key: self.ld_to_dl([d[key] for d in batch]) for key in batch[0]}
37 | else:
38 | return batch
39 |
40 |
41 | class PadTensors(object):
42 |
43 | def __init__(self, value=0, use_keys=[], avoid_keys=[]):
44 | self.value = value
45 | self.use_keys = use_keys
46 | if len(self.use_keys) > 0:
47 | self.avoid_keys = []
48 | else:
49 | self.avoid_keys = avoid_keys
50 |
51 | def __call__(self, batch):
52 | batch = self.pad_tensors(batch)
53 | return batch
54 |
55 | def pad_tensors(self, batch):
56 |         if isinstance(batch, collections.abc.Mapping):
57 | out = {}
58 | for key, value in batch.items():
59 | if (key in self.use_keys) or \
60 | (len(self.use_keys) == 0 and key not in self.avoid_keys):
61 | out[key] = self.pad_tensors(value)
62 | else:
63 | out[key] = value
64 | return out
65 |         elif isinstance(batch, collections.abc.Sequence) and torch.is_tensor(batch[0]):
66 | max_size = [max([item.size(i) for item in batch]) for i in range(batch[0].dim())]
67 | max_size = torch.Size(max_size)
68 | n_batch = []
69 | for item in batch:
70 | if item.size() != max_size:
71 | n_item = item.new(max_size).fill_(self.value)
72 | # TODO: Improve this
73 | if item.dim() == 1:
74 | n_item[:item.size(0)] = item
75 | elif item.dim() == 2:
76 | n_item[:item.size(0), :item.size(1)] = item
77 | elif item.dim() == 3:
78 | n_item[:item.size(0), :item.size(1), :item.size(2)] = item
79 | else:
80 | raise ValueError
81 | n_batch.append(n_item)
82 | else:
83 | n_batch.append(item)
84 | return n_batch
85 | else:
86 | return batch
87 |
88 |
89 | class StackTensors(object):
90 |
91 | def __init__(self, use_shared_memory=False, avoid_keys=[]):
92 | self.use_shared_memory = use_shared_memory
93 | self.avoid_keys = avoid_keys
94 |
95 | def __call__(self, batch):
96 | batch = self.stack_tensors(batch)
97 | return batch
98 |
99 |     # the key argument is useful for debugging
100 | def stack_tensors(self, batch, key=None):
101 |         if isinstance(batch, collections.abc.Mapping):
102 | out = {}
103 | for key, value in batch.items():
104 | if key not in self.avoid_keys:
105 | out[key] = self.stack_tensors(value, key=key)
106 | else:
107 | out[key] = value
108 | return out
109 |         elif isinstance(batch, collections.abc.Sequence) and torch.is_tensor(batch[0]):
110 | out = None
111 | if self.use_shared_memory:
112 | # If we're in a background process, concatenate directly into a
113 | # shared memory tensor to avoid an extra copy
114 | numel = sum([x.numel() for x in batch])
115 | storage = batch[0].storage()._new_shared(numel)
116 | out = batch[0].new(storage)
117 | return torch.stack(batch, 0, out=out)
118 | else:
119 | return batch
120 |
121 |
122 | class CatTensors(object):
123 |
124 | def __init__(self, use_shared_memory=False, use_keys=[], avoid_keys=[]):
125 | self.use_shared_memory = use_shared_memory
126 | self.use_keys = use_keys
127 | if len(self.use_keys) > 0:
128 | self.avoid_keys = []
129 | else:
130 | self.avoid_keys = avoid_keys
131 |
132 | def __call__(self, batch):
133 | batch = self.cat_tensors(batch)
134 | return batch
135 |
136 | def cat_tensors(self, batch):
137 |         if isinstance(batch, collections.abc.Mapping):
138 | out = {}
139 | for key, value in batch.items():
140 | if (key in self.use_keys) or \
141 | (len(self.use_keys) == 0 and key not in self.avoid_keys):
142 | out[key] = self.cat_tensors(value)
143 | if ('batch_id' not in out) and torch.is_tensor(value[0]):
144 | out['batch_id'] = torch.cat([i * torch.ones(x.size(0)) \
145 | for i, x in enumerate(value)], 0)
146 | else:
147 | out[key] = value
148 | return out
149 |         elif isinstance(batch, collections.abc.Sequence) and torch.is_tensor(batch[0]):
150 | out = None
151 | if self.use_shared_memory:
152 | # If we're in a background process, concatenate directly into a
153 | # shared memory tensor to avoid an extra copy
154 | numel = sum([x.numel() for x in batch])
155 | storage = batch[0].storage()._new_shared(numel)
156 | out = batch[0].new(storage)
157 | return torch.cat(batch, 0, out=out)
158 | else:
159 | return batch
160 |
161 |
162 | class ToCuda(object):
163 |
164 | def __init__(self, *args, **kwargs):
165 | pass
166 |
167 | def __call__(self, batch):
168 | batch = self.to_cuda(batch)
169 | return batch
170 |
171 | def to_cuda(self, batch):
172 |         if isinstance(batch, collections.abc.Mapping):
173 | return {key: self.to_cuda(value) for key, value in batch.items()}
174 | elif torch.is_tensor(batch):
175 | # TODO: verify async usage
176 | return batch.cuda(non_blocking=True)
177 | elif type(batch).__name__ == 'Variable':
178 | # TODO: Really hacky
179 | return Variable(batch.data.cuda(non_blocking=True))
180 |         elif isinstance(batch, collections.abc.Sequence) and torch.is_tensor(batch[0]):
181 | return [self.to_cuda(value) for value in batch]
182 | else:
183 | return batch
184 |
185 |
186 | class ToCpu(object):
187 |
188 | def __init__(self, *args, **kwargs):
189 | pass
190 |
191 | def __call__(self, batch):
192 | batch = self.to_cpu(batch)
193 | return batch
194 |
195 | def to_cpu(self, batch):
196 |         if isinstance(batch, collections.abc.Mapping):
197 | return {key: self.to_cpu(value) for key, value in batch.items()}
198 | elif torch.is_tensor(batch):
199 | return batch.cpu()
200 | elif type(batch).__name__ == 'Variable':
201 | # TODO: Really hacky
202 | return Variable(batch.data.cpu())
203 |         elif isinstance(batch, collections.abc.Sequence) and torch.is_tensor(batch[0]):
204 | return [self.to_cpu(value) for value in batch]
205 | else:
206 | return batch
207 |
208 |
209 | class ToVariable(object):
210 |
211 | def __init__(self, volatile=False):
212 | self.volatile = volatile
213 |
214 | def __call__(self, batch):
215 | batch = self.to_variable(batch)
216 | return batch
217 |
218 | def to_variable(self, batch):
219 | if torch.is_tensor(batch):
220 | if self.volatile:
221 | return Variable(batch, volatile=True)
222 | else:
223 | return Variable(batch)
224 |         elif isinstance(batch, collections.abc.Mapping):
225 |             return {key: self.to_variable(value) for key, value in batch.items()}
226 |         elif isinstance(batch, collections.abc.Sequence) and torch.is_tensor(batch[0]):
227 | return [self.to_variable(value) for value in batch]
228 | else:
229 | return batch
230 |
231 |
232 | class ToDetach(object):
233 |
234 | def __init__(self):
235 | pass
236 |
237 | def __call__(self, batch):
238 | batch = self.to_detach(batch)
239 | return batch
240 |
241 | def to_detach(self, batch):
242 | if torch.is_tensor(batch):
243 | return batch.detach_()
244 |         elif isinstance(batch, collections.abc.Mapping):
245 |             return {key: self.to_detach(value) for key, value in batch.items()}
246 |         elif isinstance(batch, collections.abc.Sequence) and torch.is_tensor(batch[0]):
247 | return [self.to_detach(value) for value in batch]
248 | else:
249 | return batch
250 |
251 |
252 | class SortByKey(object):
253 |
254 | def __init__(self, key='lengths', reverse=True):
255 | self.key = key
256 |         self.reverse = reverse  # honor the constructor argument (was hard-coded to True)
257 | self.i = 0
258 |
259 | def __call__(self, batch):
260 | self.set_sort_keys(batch[self.key]) # must be a list
261 | batch = self.sort_by_key(batch)
262 | return batch
263 |
264 | def set_sort_keys(self, sort_keys):
265 | self.i = 0
266 | self.sort_keys = sort_keys
267 |
268 | # ugly hack to be able to sort without lambda function
269 | def get_key(self, _):
270 | key = self.sort_keys[self.i]
271 | self.i += 1
272 | if self.i >= len(self.sort_keys):
273 | self.i = 0
274 | return key
275 |
276 | def sort_by_key(self, batch):
277 |         if isinstance(batch, collections.abc.Mapping):
278 | return {key: self.sort_by_key(value) for key, value in batch.items()}
279 | elif type(batch) is list: # isinstance(batch, collections.Sequence):
280 | return sorted(batch, key=self.get_key, reverse=self.reverse)
281 | else:
282 | return batch
283 |
284 |
285 | def dict_collate_fn():
286 | return Compose([
287 | ListDictsToDictLists(),
288 | StackTensors()
289 | ])
290 |
291 |
292 | class IndexSampler(Sampler):
293 | def __init__(self, data_source):
294 | self.data_source = data_source
295 |
296 | def __iter__(self):
297 | return iter(self.data_source)
298 |
299 | def __len__(self):
300 | return len(self.data_source)
301 |
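302 | # --- Usage sketch (added for illustration; not part of the original file) ---
303 | # dict_collate_fn() turns a list of per-sample dicts into a dict of batched
304 | # values (tensors are stacked), e.g. as the collate_fn of a DataLoader.
305 | if __name__ == '__main__':
306 |     collate = dict_collate_fn()
307 |     samples = [{'x': torch.ones(3), 'y': 0}, {'x': torch.zeros(3), 'y': 1}]
308 |     batch = collate(samples)
309 |     print(batch['x'].shape, batch['y'])  # torch.Size([2, 3]) [0, 1]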
--------------------------------------------------------------------------------