├── .gitignore ├── pyproject.toml ├── main.py ├── src └── maaf │ ├── config │ ├── __init__.py │ ├── config.py │ ├── defaults.py │ ├── compat.py │ └── arguments.py │ ├── __init__.py │ ├── utils │ ├── misc_utils.py │ ├── io_utils.py │ └── bn_utils.py │ ├── actions │ ├── eval_cfq.py │ └── eval_retrieval.py │ ├── datasets │ ├── datasets.py │ ├── fashiongen.py │ ├── mitstates.py │ ├── birdstowords.py │ ├── fashioniq.py │ └── fashion200k.py │ ├── models │ ├── heads.py │ ├── build.py │ ├── loss.py │ ├── image_model.py │ └── transformer.py │ ├── main.py │ └── train.py ├── experiment_scripts ├── extra_datasets │ ├── mitstates │ │ ├── mitstates.sh │ │ ├── mitstates_tirg.sh │ │ ├── mitstates_raf.sh │ │ ├── mitstates_baseline.sh │ │ ├── mitstates_maaf.sh │ │ ├── mitstates_all.sh │ │ ├── mitstates.yaml │ │ ├── mitstates_maaf.yaml │ │ ├── mitstates_tirg.yaml │ │ └── mitstates_raf.yaml │ └── birdstowords │ │ ├── birdstowords.sh │ │ ├── birdstowords_tirg.sh │ │ ├── birdstowords_raf.sh │ │ ├── birdstowords_baseline.sh │ │ ├── birdstowords_maaf.sh │ │ ├── birdstowords_all.sh │ │ ├── birdstowords.yaml │ │ ├── birdstowords_maaf.yaml │ │ ├── birdstowords_tirg.yaml │ │ └── birdstowords_raf.yaml ├── extras │ ├── clipmaaf.sh │ ├── cliptirg.sh │ ├── clipmaaf.yaml │ └── cliptirg.yaml ├── eval_all.sh ├── paper │ ├── no_imfq_pretrain.sh │ ├── no_imfq_pretrain_af.sh │ ├── no_imfq_pretrain_afbig.sh │ ├── imfq_pretrain.sh │ ├── imfq_pretrain_af.sh │ ├── imfq_pretrain.yaml │ └── imfq_pretrain_af.yaml └── eval.sh ├── old_train.py ├── configs ├── random.yaml ├── clip_fiq.yaml ├── clip_fg.yaml ├── clip_imat.yaml ├── adam_imat.yaml ├── clipresmaaf_fiq.yaml ├── clipresmaaf_fg.yaml ├── clipresmaaf_imat.yaml ├── fiq_maaf_roberta.yaml └── adam_crm_imat.yaml ├── Contributing.md ├── setup.py ├── README.md ├── setup.cfg ├── Code-of-Conduct.md └── LICENSE /.gitignore: -------------------------------------------------------------------------------- 1 | **.pyc 2 | experiments 3 | .DS_Store 4 | .ipynb_checkpoints 5 | .sync-config.cson 6 | .cache 7 | *.log 8 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | # Minimum requirements for the build system to execute. 3 | requires = ["setuptools", "wheel"] # PEP 508 specifications. 4 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Yahoo, Licensed under the terms of the Apache License, Version 2.0. 2 | # See LICENSE file in project root for terms. 3 | 4 | from maaf import main 5 | 6 | main.main() 7 | -------------------------------------------------------------------------------- /src/maaf/config/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Yahoo, Licensed under the terms of the Apache License, Version 2.0. 2 | # See LICENSE file in project root for terms. 
3 | 4 | 5 | from .config import get_config 6 | -------------------------------------------------------------------------------- /experiment_scripts/extra_datasets/mitstates/mitstates.sh: -------------------------------------------------------------------------------- 1 | python main.py --config experiment_scripts/extra_datasets/mitstates/mitstates.yaml \ 2 | --final_eval_on_test \ 3 | EXP_NAME mitstates-clip-finetune 4 | -------------------------------------------------------------------------------- /old_train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Yahoo, Licensed under the terms of the Apache License, Version 2.0. 2 | # See LICENSE file in project root for terms. 3 | 4 | from maaf import main 5 | 6 | main.main(old_args=True) 7 | -------------------------------------------------------------------------------- /experiment_scripts/extra_datasets/mitstates/mitstates_tirg.sh: -------------------------------------------------------------------------------- 1 | python main.py --config experiment_scripts/extra_datasets/mitstates/mitstates_tirg.yaml \ 2 | --final_eval_on_test \ 3 | EXP_NAME mitstates-cliptirg 4 | -------------------------------------------------------------------------------- /experiment_scripts/extra_datasets/mitstates/mitstates_raf.sh: -------------------------------------------------------------------------------- 1 | python main.py --config experiment_scripts/extra_datasets/mitstates/mitstates_raf.yaml \ 2 | --final_eval_on_test \ 3 | EXP_NAME mitstates-clipraf-finetune 4 | -------------------------------------------------------------------------------- /experiment_scripts/extra_datasets/birdstowords/birdstowords.sh: -------------------------------------------------------------------------------- 1 | python main.py --config experiment_scripts/extra_datasets/birdstowords/birdstowords.yaml \ 2 | --final_eval_on_test \ 3 | EXP_NAME birdstowords-clip-finetune 4 | -------------------------------------------------------------------------------- /experiment_scripts/extra_datasets/mitstates/mitstates_baseline.sh: -------------------------------------------------------------------------------- 1 | python main.py --config experiment_scripts/extra_datasets/mitstates/mitstates.yaml \ 2 | --no-train --final_eval_on_test \ 3 | EXP_NAME mitstates_baseline 4 | -------------------------------------------------------------------------------- /experiment_scripts/extra_datasets/mitstates/mitstates_maaf.sh: -------------------------------------------------------------------------------- 1 | python main.py --config experiment_scripts/extra_datasets/mitstates/mitstates_maaf.yaml \ 2 | --final_eval_on_test \ 3 | EXP_NAME mitstates-clipaf-finetune 4 | -------------------------------------------------------------------------------- /experiment_scripts/extra_datasets/birdstowords/birdstowords_tirg.sh: -------------------------------------------------------------------------------- 1 | python main.py --config experiment_scripts/extra_datasets/birdstowords/birdstowords_tirg.yaml \ 2 | --final_eval_on_test \ 3 | EXP_NAME birdstowords-cliptirg 4 | -------------------------------------------------------------------------------- /experiment_scripts/extra_datasets/birdstowords/birdstowords_raf.sh: -------------------------------------------------------------------------------- 1 | python main.py --config experiment_scripts/extra_datasets/birdstowords/birdstowords_raf.yaml \ 2 | --final_eval_on_test \ 3 | EXP_NAME birdstowords-clipraf-finetune 
4 | -------------------------------------------------------------------------------- /experiment_scripts/extra_datasets/birdstowords/birdstowords_baseline.sh: -------------------------------------------------------------------------------- 1 | python main.py --config experiment_scripts/extra_datasets/birdstowords/birdstowords.yaml \ 2 | --no-train --final_eval_on_test \ 3 | EXP_NAME birdstowords_baseline 4 | -------------------------------------------------------------------------------- /experiment_scripts/extra_datasets/birdstowords/birdstowords_maaf.sh: -------------------------------------------------------------------------------- 1 | python main.py --config experiment_scripts/extra_datasets/birdstowords/birdstowords_maaf.yaml \ 2 | --final_eval_on_test \ 3 | EXP_NAME birdstowords-clipaf-finetune 4 | -------------------------------------------------------------------------------- /src/maaf/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | maaf module 3 | """ 4 | import pkg_resources 5 | from typing import List 6 | 7 | 8 | __all__: List[str] = [] 9 | __copyright__: str = "Copyright 2022 Yahoo" 10 | __version__: str = pkg_resources.get_distribution("maaf").version 11 | -------------------------------------------------------------------------------- /src/maaf/utils/misc_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Yahoo, Licensed under the terms of the Apache License, Version 2.0. 2 | # See LICENSE file in project root for terms. 3 | 4 | from tqdm import tqdm as original_tqdm 5 | from functools import partial 6 | 7 | tqdm = partial(original_tqdm, dynamic_ncols=True) 8 | -------------------------------------------------------------------------------- /experiment_scripts/extra_datasets/mitstates/mitstates_all.sh: -------------------------------------------------------------------------------- 1 | bash experiment_scripts/extra_datasets/mitstates/mitstates_baseline.sh 2 | bash experiment_scripts/extra_datasets/mitstates/mitstates_maaf.sh 3 | bash experiment_scripts/extra_datasets/mitstates/mitstates_raf.sh 4 | bash experiment_scripts/extra_datasets/mitstates/mitstates_tirg.sh 5 | bash experiment_scripts/extra_datasets/mitstates/mitstates.sh 6 | -------------------------------------------------------------------------------- /experiment_scripts/extra_datasets/birdstowords/birdstowords_all.sh: -------------------------------------------------------------------------------- 1 | bash experiment_scripts/extra_datasets/birdstowords/birdstowords_baseline.sh 2 | bash experiment_scripts/extra_datasets/birdstowords/birdstowords_maaf.sh 3 | bash experiment_scripts/extra_datasets/birdstowords/birdstowords_raf.sh 4 | bash experiment_scripts/extra_datasets/birdstowords/birdstowords_tirg.sh 5 | bash experiment_scripts/extra_datasets/birdstowords/birdstowords.sh 6 | -------------------------------------------------------------------------------- /experiment_scripts/extras/clipmaaf.sh: -------------------------------------------------------------------------------- 1 | OUTPUT_DIR=/home/default/ephemeral_drive/experiments/afhp 2 | BASE_CONFIG=experiment_scripts/extras/clipmaaf.yaml 3 | 4 | 5 | TAG="clipmaaf" 6 | EXP_NAME=${TAG}_$(date "+%Y-%m-%d-%H-%M-%S") 7 | exp_dir=${OUTPUT_DIR}/$EXP_NAME 8 | 9 | python main.py --config $BASE_CONFIG --no-timestamp \ 10 | EXP_NAME $EXP_NAME \ 11 | OUTPUT_DIR $OUTPUT_DIR 12 | 13 | bash experiment_scripts/eval.sh $exp_dir 14 | 
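Aside: `src/maaf/utils/misc_utils.py` above simply re-exports `tqdm` with `dynamic_ncols=True` pre-applied, so progress bars resize with the terminal. A minimal usage sketch follows; the iterable and loop body are placeholders, not code from this repository:

```python
# Usage sketch for the tqdm wrapper defined in src/maaf/utils/misc_utils.py.
# "batches" is a placeholder iterable; in the real training code this would be a DataLoader.
from maaf.utils.misc_utils import tqdm

batches = range(1000)
for batch in tqdm(batches, desc="train"):
    pass  # a training step would run here
```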
-------------------------------------------------------------------------------- /experiment_scripts/extras/cliptirg.sh: -------------------------------------------------------------------------------- 1 | OUTPUT_DIR=/home/default/ephemeral_drive/experiments/tirg 2 | BASE_CONFIG=experiment_scripts/extras/cliptirg.yaml 3 | 4 | TAG="ctirg-fiq" 5 | EXP_NAME=${TAG}_$(date "+%Y-%m-%d-%H-%M-%S") 6 | exp_dir=${OUTPUT_DIR}/$EXP_NAME 7 | 8 | python main.py --config $BASE_CONFIG --no-timestamp \ 9 | EXP_NAME $EXP_NAME \ 10 | OUTPUT_DIR $OUTPUT_DIR \ 11 | 12 | bash experiment_scripts/eval.sh $exp_dir 13 | -------------------------------------------------------------------------------- /experiment_scripts/eval_all.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Yahoo, Licensed under the terms of the Apache License, Version 2.0. 2 | # See LICENSE file in project root for terms. 3 | 4 | 5 | for expdir in /home/default/ephemeral_drive/experiments/paper/*; do 6 | [ -e "${expdir}/config.yaml" ] || continue # skip empty or empty glob 7 | [ ! -e "${expdir}/claimed.temp" ] || continue # skip what another process is doing 8 | echo ${expdir} 9 | touch ${expdir}/claimed.temp 10 | bash experiment_scripts/eval.sh $expdir 11 | rm ${expdir}/claimed.temp 12 | done 13 | -------------------------------------------------------------------------------- /experiment_scripts/paper/no_imfq_pretrain.sh: -------------------------------------------------------------------------------- 1 | BASE_CONFIG=experiment_scripts/paper/imfq_pretrain.yaml 2 | OUTPUT_DIR=/home/default/ephemeral_drive/experiments/ 3 | 4 | TAG="clip-fiq" 5 | EXP_NAME=${TAG}_$(date "+%Y-%m-%d-%H%M%S") 6 | exp_dir=${OUTPUT_DIR}/$EXP_NAME 7 | 8 | # fine tune (and eval) on Fashion IQ 9 | python main.py --config $BASE_CONFIG --no-timestamp \ 10 | EXP_NAME ${EXP_NAME} \ 11 | OUTPUT_DIR $OUTPUT_DIR \ 12 | DATASET.NAME fashioniq \ 13 | DATASET.PATH /home/default/ephemeral_drive/Data/FashionIQ/ \ 14 | DATASET.AUGMENTATION.IMAGE_AUGMENTATION True \ 15 | SOLVER.LEARNING_RATE_DECAY_FREQUENCY 980 \ 16 | SOLVER.NUM_ITERS 1960 \ 17 | 18 | bash experiment_scripts/eval.sh ${exp_dir} 19 | -------------------------------------------------------------------------------- /experiment_scripts/paper/no_imfq_pretrain_af.sh: -------------------------------------------------------------------------------- 1 | BASE_CONFIG=experiment_scripts/paper/imfq_pretrain_af.yaml 2 | OUTPUT_DIR=/home/default/ephemeral_drive/experiments/ 3 | 4 | TAG="clipresmaaf-fiq" 5 | EXP_NAME=${TAG}_$(date "+%Y-%m-%d-%H%M%S") 6 | exp_dir=${OUTPUT_DIR}/$EXP_NAME 7 | 8 | # fine tune (and eval) on Fashion IQ 9 | python main.py --config $BASE_CONFIG --no-timestamp \ 10 | EXP_NAME ${EXP_NAME} \ 11 | OUTPUT_DIR $OUTPUT_DIR \ 12 | DATASET.NAME fashioniq \ 13 | DATASET.PATH /home/default/ephemeral_drive/Data/FashionIQ/ \ 14 | DATASET.AUGMENTATION.IMAGE_AUGMENTATION True \ 15 | SOLVER.LEARNING_RATE_DECAY_FREQUENCY 980 \ 16 | SOLVER.NUM_ITERS 1960 \ 17 | 18 | bash experiment_scripts/eval.sh ${exp_dir} 19 | -------------------------------------------------------------------------------- /configs/random.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Yahoo, Licensed under the terms of the Apache License, Version 2.0. 2 | # See LICENSE file in project root for terms. 
3 | 4 | DATASET: 5 | NAME: fashioniq 6 | PATH: /home/default/ephemeral_drive/Data/FashionIQ/ 7 | REQUIRE_IMAGES: false 8 | AUGMENTATION: 9 | IMAGE_AUGMENTATION: false 10 | DATA_LOADER: 11 | LOADER_NUM_WORKERS: 0 12 | OUTPUT_DIR: /home/default/ephemeral_drive/experiments/ 13 | EXP_NAME: random 14 | MODEL: 15 | EMBED_DIM: 1024 16 | COMPOSITION: random 17 | DEVICE: cpu 18 | LOSS: batch_based_classification 19 | TEXT_MODEL: 20 | ARCHITECTURE: null 21 | TOKENIZER: null 22 | IMAGE_MODEL: 23 | ARCHITECTURE: null 24 | INCLUDES_IMAGE_TRANSFORM: false 25 | -------------------------------------------------------------------------------- /experiment_scripts/paper/no_imfq_pretrain_afbig.sh: -------------------------------------------------------------------------------- 1 | BASE_CONFIG=experiment_scripts/paper/imfq_pretrain_af.yaml 2 | OUTPUT_DIR=/home/default/ephemeral_drive/experiments/bigres 3 | 4 | TAG="clipbigresmaaf-fiq" 5 | EXP_NAME=${TAG}_$(date "+%Y-%m-%d-%H%M%S") 6 | exp_dir=${OUTPUT_DIR}/$EXP_NAME 7 | 8 | # fine tune (and eval) on Fashion IQ 9 | python main.py --config $BASE_CONFIG --no-timestamp \ 10 | EXP_NAME ${EXP_NAME} \ 11 | OUTPUT_DIR $OUTPUT_DIR \ 12 | DATASET.NAME fashioniq \ 13 | DATASET.PATH /home/default/ephemeral_drive/Data/FashionIQ/ \ 14 | DATASET.AUGMENTATION.IMAGE_AUGMENTATION True \ 15 | SOLVER.LEARNING_RATE_DECAY_FREQUENCY 980 \ 16 | SOLVER.NUM_ITERS 1960 \ 17 | MODEL.MAAF.RESIDUAL.INITIAL_MAAF_WEIGHT 1. 18 | 19 | bash experiment_scripts/eval.sh ${exp_dir} 20 | -------------------------------------------------------------------------------- /Contributing.md: -------------------------------------------------------------------------------- 1 | ## Static Project 2 | 3 | First, thanks for your interest in our project. Please note, this project is inactive. We published this in the interest of making our research transparent and reproducible. 4 | 5 | You may have questions, issues, feature requests, or code contributions. This project is inactive. What can you do next? 6 | 7 | * Research. By looking at the project information on GitHub, you can find out who was active in this project and reach out to those people directly to find out about the project. They may be willing to address your questions, issues, etc. They may be inclined to re-activate the project too. 8 | * Fork and Resurrect. This project is public and under an open source license. You are welcome to fork the project and grow a community around your project fork. 
9 | -------------------------------------------------------------------------------- /experiment_scripts/paper/imfq_pretrain.sh: -------------------------------------------------------------------------------- 1 | BASE_CONFIG=experiment_scripts/paper/imfq_pretrain.yaml 2 | OUTPUT_DIR=/home/default/ephemeral_drive/experiments/ 3 | 4 | TAG="clip-imfq-fiq" 5 | EXP_NAME=${TAG}_$(date "+%Y-%m-%d-%H%M%S") 6 | exp_dir=${OUTPUT_DIR}/$EXP_NAME 7 | 8 | python main.py --config $BASE_CONFIG --no-timestamp \ 9 | EXP_NAME $EXP_NAME \ 10 | OUTPUT_DIR $OUTPUT_DIR \ 11 | 12 | bash experiment_scripts/eval.sh $exp_dir 13 | 14 | # fine tune (and eval) on Fashion IQ 15 | python main.py --config $exp_dir/config.yaml --no-timestamp \ 16 | EXP_NAME ${EXP_NAME}-finetune \ 17 | DATASET.NAME fashioniq \ 18 | DATASET.PATH /home/default/ephemeral_drive/Data/FashionIQ/ \ 19 | DATASET.AUGMENTATION.IMAGE_AUGMENTATION True \ 20 | SOLVER.LEARNING_RATE_DECAY_FREQUENCY 980 \ 21 | SOLVER.NUM_ITERS 1960 \ 22 | MODEL.WEIGHTS $exp_dir/latest_checkpoint.pth \ 23 | 24 | bash experiment_scripts/eval.sh ${exp_dir}-finetune 25 | -------------------------------------------------------------------------------- /experiment_scripts/paper/imfq_pretrain_af.sh: -------------------------------------------------------------------------------- 1 | BASE_CONFIG=experiment_scripts/paper/imfq_pretrain_af.yaml 2 | OUTPUT_DIR=/home/default/ephemeral_drive/experiments/ 3 | 4 | TAG="clipresmaaf-imfq-fiq" 5 | EXP_NAME=${TAG}_$(date "+%Y-%m-%d-%H%M%S") 6 | exp_dir=${OUTPUT_DIR}/$EXP_NAME 7 | 8 | python main.py --config $BASE_CONFIG --no-timestamp \ 9 | EXP_NAME $EXP_NAME \ 10 | OUTPUT_DIR $OUTPUT_DIR \ 11 | 12 | bash experiment_scripts/eval.sh $exp_dir 13 | 14 | # fine tune (and eval) on Fashion IQ 15 | python main.py --config $exp_dir/config.yaml --no-timestamp \ 16 | EXP_NAME ${EXP_NAME}-finetune \ 17 | DATASET.NAME fashioniq \ 18 | DATASET.PATH /home/default/ephemeral_drive/Data/FashionIQ/ \ 19 | DATASET.AUGMENTATION.IMAGE_AUGMENTATION True \ 20 | SOLVER.LEARNING_RATE_DECAY_FREQUENCY 980 \ 21 | SOLVER.NUM_ITERS 1960 \ 22 | MODEL.WEIGHTS $exp_dir/latest_checkpoint.pth \ 23 | 24 | bash experiment_scripts/eval.sh ${exp_dir}-finetune 25 | -------------------------------------------------------------------------------- /experiment_scripts/paper/imfq_pretrain.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: imat_fashion 3 | PATH: /home/default/ephemeral_drive/Data/imat2018/ 4 | AUGMENTATION: 5 | IMAGE_AUGMENTATION: null 6 | DATA_LOADER: 7 | LOADER_NUM_WORKERS: 0 8 | OUTPUT_DIR: /home/default/ephemeral_drive/experiments/ 9 | EXP_NAME: defaultexpname 10 | MODEL: 11 | EMBED_DIM: 1024 12 | COMPOSITION: clip 13 | DEVICE: cuda 14 | LOSS: batch_based_classification # consider 15 | TEXT_MODEL: 16 | ARCHITECTURE: default # use the model paired with IMAGE_MODEL.ARCHITECTURE 17 | TOKENIZER: null 18 | IMAGE_MODEL: 19 | ARCHITECTURE: RN50 20 | OUTPUTS: 21 | - 4 22 | - attnpool 23 | INCLUDES_IMAGE_TRANSFORM: false 24 | SOLVER: 25 | OPTIMIZER: adam 26 | BATCH_SIZE: 128 27 | DROP_WORST_RATE: 0 28 | EVAL_EVERY: 1 29 | LEARNING_RATE: 1.0e-06 30 | LEARNING_RATE_DECAY: 0.1 31 | LEARNING_RATE_DECAY_FREQUENCY: 7254 # consider 980 (yes, an accident if so) 32 | NUM_ITERS: 21762 33 | BATCH_NORM_MODE: freeze_bn 34 | PRETRAINED_WEIGHT_LR_FACTOR_TEXT: 1.0 35 | PRETRAINED_WEIGHT_LR_FACTOR_IMAGE: 1.0 36 | -------------------------------------------------------------------------------- 
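The paper scripts above all follow the same convention: `--config` points at a YAML file, and any trailing `KEY VALUE` pairs override individual entries. Below is a minimal sketch of that override mechanism in plain YACS; the config tree is a toy subset (the real defaults live in `src/maaf/config/defaults.py`), and the actual argument parsing inside `maaf.main` may differ in detail:

```python
# Sketch of YACS-style "KEY VALUE" overrides, as used by the experiment scripts.
# The config tree below is a small toy subset of the project's real defaults.
from yacs.config import CfgNode as CN

cfg = CN()
cfg.EXP_NAME = "defaultexpname"
cfg.DATASET = CN()
cfg.DATASET.NAME = "imat_fashion"
cfg.SOLVER = CN()
cfg.SOLVER.NUM_ITERS = 21762
cfg.SOLVER.LEARNING_RATE_DECAY_FREQUENCY = 7254

# Equivalent of the trailing arguments in no_imfq_pretrain.sh:
overrides = [
    "DATASET.NAME", "fashioniq",
    "SOLVER.LEARNING_RATE_DECAY_FREQUENCY", "980",
    "SOLVER.NUM_ITERS", "1960",
]
cfg.merge_from_list(overrides)  # string values are coerced to the existing types
cfg.freeze()

print(cfg.DATASET.NAME, cfg.SOLVER.NUM_ITERS)  # -> fashioniq 1960
```

This is why the shell scripts can switch `DATASET.NAME` or shorten `NUM_ITERS` for fine-tuning without editing the YAML files themselves.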
/experiment_scripts/extra_datasets/mitstates/mitstates.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: mitstates 3 | PATH: /home/default/ephemeral_drive/Data/mitstates/ 4 | REQUIRE_IMAGES: false 5 | AUGMENTATION: 6 | IMAGE_AUGMENTATION: null 7 | DATA_LOADER: 8 | LOADER_NUM_WORKERS: 0 9 | OUTPUT_DIR: /home/default/ephemeral_drive/experiments/fashioniq/ 10 | EXP_NAME: clip-mitstates 11 | MODEL: 12 | EMBED_DIM: 1024 13 | COMPOSITION: clip 14 | DEVICE: cuda 15 | LOSS: batch_based_classification 16 | TEXT_MODEL: 17 | ARCHITECTURE: default 18 | TOKENIZER: null 19 | IMAGE_MODEL: 20 | ARCHITECTURE: RN50 21 | INCLUDES_IMAGE_TRANSFORM: false 22 | INITIAL_NORMALIZATION_FACTOR: 4. 23 | SOLVER: 24 | BATCH_SIZE: 128 25 | OPTIMIZER: adam 26 | BATCH_NORM_MODE: freeze_bn 27 | DROP_WORST_RATE: 0 28 | EVAL_EVERY: 1 29 | LEARNING_RATE: 1.0e-05 30 | LEARNING_RATE_DECAY: 0.1 31 | LEARNING_RATE_DECAY_FREQUENCY: 2696 32 | NUM_ITERS: 3370 33 | BATCH_NORM_MODE: freeze_bn 34 | PRETRAINED_WEIGHT_LR_FACTOR_TEXT: 1.0 35 | PRETRAINED_WEIGHT_LR_FACTOR_IMAGE: 1.0 36 | ALWAYS_EVAL_TEST: true 37 | FINAL_EVAL_ON_TEST: true 38 | -------------------------------------------------------------------------------- /experiment_scripts/extra_datasets/birdstowords/birdstowords.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: birdstowords 3 | PATH: /home/default/ephemeral_drive/Data/birds-to-words/ 4 | REQUIRE_IMAGES: false 5 | AUGMENTATION: 6 | IMAGE_AUGMENTATION: null 7 | DATA_LOADER: 8 | LOADER_NUM_WORKERS: 0 9 | OUTPUT_DIR: /home/default/ephemeral_drive/experiments/ 10 | EXP_NAME: clip-birdstowords 11 | MODEL: 12 | EMBED_DIM: 1024 13 | COMPOSITION: clip 14 | DEVICE: cuda 15 | LOSS: batch_based_classification 16 | TEXT_MODEL: 17 | ARCHITECTURE: default 18 | TOKENIZER: null 19 | IMAGE_MODEL: 20 | ARCHITECTURE: RN50 21 | INCLUDES_IMAGE_TRANSFORM: false 22 | INITIAL_NORMALIZATION_FACTOR: 4. 23 | SOLVER: 24 | BATCH_SIZE: 128 25 | OPTIMIZER: adam 26 | BATCH_NORM_MODE: freeze_bn 27 | DROP_WORST_RATE: 0 28 | EVAL_EVERY: 1 29 | LEARNING_RATE: 1.0e-05 30 | LEARNING_RATE_DECAY: 0.1 31 | LEARNING_RATE_DECAY_FREQUENCY: 176 32 | NUM_ITERS: 220 33 | BATCH_NORM_MODE: freeze_bn 34 | PRETRAINED_WEIGHT_LR_FACTOR_TEXT: 1.0 35 | PRETRAINED_WEIGHT_LR_FACTOR_IMAGE: 1.0 36 | ALWAYS_EVAL_TEST: true 37 | FINAL_EVAL_ON_TEST: true 38 | -------------------------------------------------------------------------------- /configs/clip_fiq.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Yahoo, Licensed under the terms of the Apache License, Version 2.0. 2 | # See LICENSE file in project root for terms. 3 | 4 | DATASET: 5 | NAME: fashioniq 6 | PATH: /home/default/ephemeral_drive/Data/FashionIQ/ 7 | REQUIRE_IMAGES: false 8 | AUGMENTATION: 9 | IMAGE_AUGMENTATION: true 10 | DATA_LOADER: 11 | LOADER_NUM_WORKERS: 0 12 | OUTPUT_DIR: /home/default/ephemeral_drive/experiments/fashioniq/ 13 | EXP_NAME: clip-fiq-lr_6-ds 14 | MODEL: 15 | EMBED_DIM: 1024 16 | COMPOSITION: clip 17 | DEVICE: cuda 18 | LOSS: double_softmax 19 | TEXT_MODEL: 20 | ARCHITECTURE: null 21 | TOKENIZER: null 22 | IMAGE_MODEL: 23 | ARCHITECTURE: null # because included in clip 24 | INCLUDES_IMAGE_TRANSFORM: false 25 | INITIAL_NORMALIZATION_FACTOR: 4. 
26 | SOLVER: 27 | BATCH_SIZE: 128 28 | DROP_WORST_RATE: 0 29 | EVAL_EVERY: 1 30 | LEARNING_RATE: 0.00001 31 | LEARNING_RATE_DECAY: 0.1 32 | LEARNING_RATE_DECAY_FREQUENCY: 4000 33 | NUM_ITERS: 6000 34 | BATCH_NORM_MODE: freeze_bn 35 | PRETRAINED_WEIGHT_LR_FACTOR_TEXT: 1.0 36 | PRETRAINED_WEIGHT_LR_FACTOR_IMAGE: 1.0 37 | -------------------------------------------------------------------------------- /configs/clip_fg.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Yahoo, Licensed under the terms of the Apache License, Version 2.0. 2 | # See LICENSE file in project root for terms. 3 | 4 | EXP_NAME: fashiongen-clip 5 | OUTPUT_DIR: /home/default/ephemeral_drive/experiments/ 6 | MODEL: 7 | EMBED_DIM: 1024 8 | COMPOSITION: clip 9 | DEVICE: cuda 10 | TEXT_MODEL: 11 | ARCHITECTURE: null 12 | TOKENIZER: null 13 | IMAGE_MODEL: 14 | ARCHITECTURE: null # because included in clip 15 | INCLUDES_IMAGE_TRANSFORM: false 16 | LOSS: double_softmax 17 | INITIAL_NORMALIZATION_FACTOR: 4. 18 | DATASET: 19 | NAME: fashiongen 20 | PATH: /home/default/ephemeral_drive/Data/fashiongen/ 21 | CLASS_WEIGHTS: null 22 | IMAGE_DIR: null 23 | REQUIRE_IMAGES: false 24 | AUGMENTATION: 25 | IMAGE_AUGMENTATION: null 26 | SOLVER: 27 | OPTIMIZER: sgd 28 | BATCH_SIZE: 128 29 | NUM_ITERS: 24420 # 12 epochs 30 | LEARNING_RATE_DECAY_FREQUENCY: 20350 # 10 epochs 31 | EVAL_EVERY: 1 32 | LEARNING_RATE: 0.00001 33 | PRETRAINED_WEIGHT_LR_FACTOR_TEXT: 1. 34 | PRETRAINED_WEIGHT_LR_FACTOR_IMAGE: 1. 35 | BATCH_NORM_MODE: freeze_bn 36 | DATA_LOADER: 37 | LOADER_NUM_WORKERS: 0 # otherwise hdf5 causes problems 38 | -------------------------------------------------------------------------------- /configs/clip_imat.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Yahoo, Licensed under the terms of the Apache License, Version 2.0. 2 | # See LICENSE file in project root for terms. 3 | 4 | DATASET: 5 | CLASS_WEIGHTS: null 6 | IMAGE_DIR: null 7 | NAME: imat_fashion 8 | NUM_CLASSES: null 9 | PATH: /home/default/ephemeral_drive/Data/imat2018/ 10 | REQUIRE_IMAGES: false 11 | AUGMENTATION: 12 | IMAGE_AUGMENTATION: null 13 | DATA_LOADER: 14 | LOADER_NUM_WORKERS: 0 15 | OUTPUT_DIR: /home/default/ephemeral_drive/experiments/imat/ 16 | EXP_NAME: clip-imat-adam 17 | MODEL: 18 | EMBED_DIM: 1024 19 | COMPOSITION: clip 20 | DEVICE: cuda 21 | LOSS: double_softmax 22 | TEXT_MODEL: 23 | ARCHITECTURE: null 24 | TOKENIZER: null 25 | IMAGE_MODEL: 26 | ARCHITECTURE: null # because included in clip 27 | INCLUDES_IMAGE_TRANSFORM: false 28 | INITIAL_NORMALIZATION_FACTOR: 4. 29 | SOLVER: 30 | OPTIMIZER: sgd 31 | BATCH_SIZE: 128 32 | DROP_WORST_RATE: 0 33 | EVAL_EVERY: 1 34 | LEARNING_RATE: 0.00001 35 | LEARNING_RATE_DECAY: 0.1 36 | LEARNING_RATE_DECAY_FREQUENCY: 72540 # 10 epochs 37 | NUM_ITERS: 87048 # 12 epochs 38 | BATCH_NORM_MODE: freeze_bn 39 | PRETRAINED_WEIGHT_LR_FACTOR_TEXT: 1. 40 | PRETRAINED_WEIGHT_LR_FACTOR_IMAGE: 1. 
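The `# 10 epochs` / `# 12 epochs` comments in `clip_fg.yaml` and `clip_imat.yaml` are iteration counts expressed in epochs. The arithmetic below is a sanity check; the per-epoch iteration counts are inferred by dividing the configured values and are not documented anywhere in the repository:

```python
# Relating NUM_ITERS and LEARNING_RATE_DECAY_FREQUENCY to the "epochs" comments
# in configs/clip_fg.yaml and configs/clip_imat.yaml. Iterations per epoch are
# inferred from the configs (value / epochs); the implied dataset sizes follow.
BATCH_SIZE = 128

# clip_imat.yaml: 7254 iterations per epoch
assert 7254 * 10 == 72540   # LEARNING_RATE_DECAY_FREQUENCY ("10 epochs")
assert 7254 * 12 == 87048   # NUM_ITERS ("12 epochs")
print("imat_fashion examples per epoch ~", 7254 * BATCH_SIZE)   # 928512

# clip_fg.yaml: 2035 iterations per epoch
assert 2035 * 10 == 20350   # LEARNING_RATE_DECAY_FREQUENCY ("10 epochs")
assert 2035 * 12 == 24420   # NUM_ITERS ("12 epochs")
print("fashiongen examples per epoch ~", 2035 * BATCH_SIZE)     # 260480
```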
41 | -------------------------------------------------------------------------------- /experiment_scripts/extras/clipmaaf.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: fashioniq 3 | PATH: /home/default/ephemeral_drive/Data/FashionIQ/ 4 | REQUIRE_IMAGES: false 5 | AUGMENTATION: 6 | IMAGE_AUGMENTATION: true 7 | DATA_LOADER: 8 | LOADER_NUM_WORKERS: 0 9 | OUTPUT_DIR: /home/default/ephemeral_drive/experiments/ 10 | EXP_NAME: default 11 | MODEL: 12 | EMBED_DIM: 1024 13 | COMPOSITION: clipmaaf 14 | DEVICE: cuda 15 | LOSS: batch_based_classification 16 | TEXT_MODEL: 17 | ARCHITECTURE: default # use the model paired with IMAGE_MODEL.ARCHITECTURE 18 | TOKENIZER: null 19 | IMAGE_MODEL: 20 | ARCHITECTURE: RN50 21 | OUTPUTS: 22 | - 4 23 | - attnpool 24 | INCLUDES_IMAGE_TRANSFORM: false 25 | MAAF: 26 | ATTENTION_HEADS: 8 27 | ATTN_SOFTMAX_REPLACEMENT: null 28 | BLOCK_WIDTH: 256 29 | NUM_BLOCKS: 1 30 | OUTPUT: rwpool 31 | POSITION_ENCODING: null 32 | SOLVER: 33 | OPTIMIZER: adam 34 | BATCH_SIZE: 128 35 | DROP_WORST_RATE: 0 36 | EVAL_EVERY: 1 37 | LEARNING_RATE: 1.0e-05 38 | LEARNING_RATE_DECAY: 0.1 39 | LEARNING_RATE_DECAY_FREQUENCY: 980 40 | NUM_ITERS: 1960 41 | BATCH_NORM_MODE: freeze_bn 42 | PRETRAINED_WEIGHT_LR_FACTOR_TEXT: 0.1 43 | PRETRAINED_WEIGHT_LR_FACTOR_IMAGE: 0.1 44 | PROJECTION_LR_TIED_TO_PRETRAINED: false 45 | -------------------------------------------------------------------------------- /experiment_scripts/extras/cliptirg.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: fashioniq 3 | PATH: /home/default/ephemeral_drive/Data/FashionIQ/ 4 | REQUIRE_IMAGES: false 5 | AUGMENTATION: 6 | IMAGE_AUGMENTATION: true 7 | DATA_LOADER: 8 | LOADER_NUM_WORKERS: 0 9 | OUTPUT_DIR: /home/default/ephemeral_drive/experiments/ 10 | EXP_NAME: clipaf-fiq-baseline 11 | MODEL: 12 | EMBED_DIM: 1024 13 | COMPOSITION: cliptirg 14 | DEVICE: cuda 15 | LOSS: batch_based_classification 16 | TEXT_MODEL: 17 | ARCHITECTURE: default # use the model paired with IMAGE_MODEL.ARCHITECTURE 18 | TOKENIZER: null 19 | IMAGE_MODEL: 20 | ARCHITECTURE: RN50 21 | OUTPUTS: 22 | - 4 23 | - attnpool 24 | INCLUDES_IMAGE_TRANSFORM: false 25 | MAAF: 26 | ATTENTION_HEADS: 8 27 | ATTN_SOFTMAX_REPLACEMENT: null 28 | BLOCK_WIDTH: 256 29 | NUM_BLOCKS: 1 30 | OUTPUT: rwpool 31 | POSITION_ENCODING: null 32 | RESIDUAL: 33 | INITIAL_MAAF_WEIGHT: 0.1 34 | SOLVER: 35 | OPTIMIZER: adam 36 | BATCH_SIZE: 128 37 | DROP_WORST_RATE: 0 38 | EVAL_EVERY: 1 39 | LEARNING_RATE: 1.0e-05 40 | LEARNING_RATE_DECAY: 0.1 41 | LEARNING_RATE_DECAY_FREQUENCY: 700 42 | NUM_ITERS: 1400 43 | BATCH_NORM_MODE: freeze_bn 44 | PRETRAINED_WEIGHT_LR_FACTOR_TEXT: 0.1 45 | PRETRAINED_WEIGHT_LR_FACTOR_IMAGE: 0.1 46 | -------------------------------------------------------------------------------- /experiment_scripts/extra_datasets/mitstates/mitstates_maaf.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: mitstates 3 | PATH: /home/default/ephemeral_drive/Data/mitstates/ 4 | REQUIRE_IMAGES: false 5 | AUGMENTATION: 6 | IMAGE_AUGMENTATION: null 7 | DATA_LOADER: 8 | LOADER_NUM_WORKERS: 0 9 | OUTPUT_DIR: /home/default/ephemeral_drive/experiments/fashioniq/ 10 | EXP_NAME: clip-mitstates-raf 11 | MODEL: 12 | EMBED_DIM: 1024 13 | COMPOSITION: clipmaaf 14 | DEVICE: cuda 15 | LOSS: batch_based_classification 16 | TEXT_MODEL: 17 | ARCHITECTURE: default # use the model paired with IMAGE_MODEL.ARCHITECTURE 18 | 
TOKENIZER: null 19 | IMAGE_MODEL: 20 | ARCHITECTURE: RN50 21 | OUTPUTS: 22 | - 4 23 | - attnpool 24 | INCLUDES_IMAGE_TRANSFORM: false 25 | MAAF: 26 | ATTENTION_HEADS: 8 27 | ATTN_SOFTMAX_REPLACEMENT: null 28 | BLOCK_WIDTH: 256 29 | NUM_BLOCKS: 2 30 | OUTPUT: rwpool 31 | POSITION_ENCODING: null 32 | SOLVER: 33 | BATCH_SIZE: 128 34 | OPTIMIZER: adam 35 | BATCH_NORM_MODE: freeze_bn 36 | DROP_WORST_RATE: 0 37 | EVAL_EVERY: 1 38 | LEARNING_RATE: 1.0e-05 39 | LEARNING_RATE_DECAY: 0.1 40 | LEARNING_RATE_DECAY_FREQUENCY: 2696 41 | NUM_ITERS: 3370 42 | BATCH_NORM_MODE: freeze_bn 43 | PRETRAINED_WEIGHT_LR_FACTOR_TEXT: 0.1 44 | PRETRAINED_WEIGHT_LR_FACTOR_IMAGE: 0.1 45 | ALWAYS_EVAL_TEST: true 46 | FINAL_EVAL_ON_TEST: true 47 | PROJECTION_LR_TIED_TO_PRETRAINED: false 48 | -------------------------------------------------------------------------------- /experiment_scripts/extra_datasets/birdstowords/birdstowords_maaf.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: birdstowords 3 | PATH: /home/default/ephemeral_drive/Data/birds-to-words/ 4 | REQUIRE_IMAGES: false 5 | AUGMENTATION: 6 | IMAGE_AUGMENTATION: null 7 | DATA_LOADER: 8 | LOADER_NUM_WORKERS: 0 9 | OUTPUT_DIR: /home/default/ephemeral_drive/experiments/ 10 | EXP_NAME: clip-birdstowords-raf 11 | MODEL: 12 | EMBED_DIM: 1024 13 | COMPOSITION: clipmaaf 14 | DEVICE: cuda 15 | LOSS: batch_based_classification 16 | TEXT_MODEL: 17 | ARCHITECTURE: default # use the model paired with IMAGE_MODEL.ARCHITECTURE 18 | TOKENIZER: null 19 | IMAGE_MODEL: 20 | ARCHITECTURE: RN50 21 | OUTPUTS: 22 | - 4 23 | - attnpool 24 | INCLUDES_IMAGE_TRANSFORM: false 25 | MAAF: 26 | ATTENTION_HEADS: 8 27 | ATTN_SOFTMAX_REPLACEMENT: null 28 | BLOCK_WIDTH: 256 29 | NUM_BLOCKS: 2 30 | OUTPUT: rwpool 31 | POSITION_ENCODING: null 32 | SOLVER: 33 | BATCH_SIZE: 128 34 | OPTIMIZER: adam 35 | BATCH_NORM_MODE: freeze_bn 36 | DROP_WORST_RATE: 0 37 | EVAL_EVERY: 1 38 | LEARNING_RATE: 1.0e-05 39 | LEARNING_RATE_DECAY: 0.1 40 | LEARNING_RATE_DECAY_FREQUENCY: 176 41 | NUM_ITERS: 220 42 | BATCH_NORM_MODE: freeze_bn 43 | PRETRAINED_WEIGHT_LR_FACTOR_TEXT: 0.1 44 | PRETRAINED_WEIGHT_LR_FACTOR_IMAGE: 0.1 45 | ALWAYS_EVAL_TEST: true 46 | FINAL_EVAL_ON_TEST: true 47 | PROJECTION_LR_TIED_TO_PRETRAINED: false 48 | -------------------------------------------------------------------------------- /experiment_scripts/extra_datasets/mitstates/mitstates_tirg.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: mitstates 3 | PATH: /home/default/ephemeral_drive/Data/mitstates/ 4 | REQUIRE_IMAGES: false 5 | AUGMENTATION: 6 | IMAGE_AUGMENTATION: null 7 | DATA_LOADER: 8 | LOADER_NUM_WORKERS: 0 9 | OUTPUT_DIR: /home/default/ephemeral_drive/experiments/fashioniq/ 10 | EXP_NAME: cliptirg-mitstates 11 | MODEL: 12 | EMBED_DIM: 1024 13 | COMPOSITION: cliptirg 14 | DEVICE: cuda 15 | LOSS: batch_based_classification 16 | TEXT_MODEL: 17 | ARCHITECTURE: default # use the model paired with IMAGE_MODEL.ARCHITECTURE 18 | TOKENIZER: null 19 | IMAGE_MODEL: 20 | ARCHITECTURE: RN50 21 | OUTPUTS: 22 | - 4 23 | - attnpool 24 | INCLUDES_IMAGE_TRANSFORM: false 25 | MAAF: 26 | ATTENTION_HEADS: 8 27 | ATTN_SOFTMAX_REPLACEMENT: null 28 | BLOCK_WIDTH: 256 29 | NUM_BLOCKS: 1 30 | OUTPUT: rwpool 31 | POSITION_ENCODING: null 32 | RESIDUAL: 33 | INITIAL_MAAF_WEIGHT: 0.1 34 | SOLVER: 35 | BATCH_SIZE: 128 36 | OPTIMIZER: adam 37 | BATCH_NORM_MODE: freeze_bn 38 | DROP_WORST_RATE: 0 39 | EVAL_EVERY: 1 40 | LEARNING_RATE: 1.0e-05 
41 | LEARNING_RATE_DECAY: 0.1 42 | LEARNING_RATE_DECAY_FREQUENCY: 2696 43 | NUM_ITERS: 3370 44 | BATCH_NORM_MODE: freeze_bn 45 | PRETRAINED_WEIGHT_LR_FACTOR_TEXT: 0.1 46 | PRETRAINED_WEIGHT_LR_FACTOR_IMAGE: 0.1 47 | ALWAYS_EVAL_TEST: true 48 | FINAL_EVAL_ON_TEST: true 49 | -------------------------------------------------------------------------------- /experiment_scripts/extra_datasets/birdstowords/birdstowords_tirg.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: birdstowords 3 | PATH: /home/default/ephemeral_drive/Data/birds-to-words/ 4 | REQUIRE_IMAGES: false 5 | AUGMENTATION: 6 | IMAGE_AUGMENTATION: null 7 | DATA_LOADER: 8 | LOADER_NUM_WORKERS: 0 9 | OUTPUT_DIR: /home/default/ephemeral_drive/experiments/ 10 | EXP_NAME: cliptirg-birdstowords 11 | MODEL: 12 | EMBED_DIM: 1024 13 | COMPOSITION: cliptirg 14 | DEVICE: cuda 15 | LOSS: batch_based_classification 16 | TEXT_MODEL: 17 | ARCHITECTURE: default # use the model paired with IMAGE_MODEL.ARCHITECTURE 18 | TOKENIZER: null 19 | IMAGE_MODEL: 20 | ARCHITECTURE: RN50 21 | OUTPUTS: 22 | - 4 23 | - attnpool 24 | INCLUDES_IMAGE_TRANSFORM: false 25 | MAAF: 26 | ATTENTION_HEADS: 8 27 | ATTN_SOFTMAX_REPLACEMENT: null 28 | BLOCK_WIDTH: 256 29 | NUM_BLOCKS: 1 30 | OUTPUT: rwpool 31 | POSITION_ENCODING: null 32 | RESIDUAL: 33 | INITIAL_MAAF_WEIGHT: 0.1 34 | SOLVER: 35 | BATCH_SIZE: 128 36 | OPTIMIZER: adam 37 | BATCH_NORM_MODE: freeze_bn 38 | DROP_WORST_RATE: 0 39 | EVAL_EVERY: 1 40 | LEARNING_RATE: 1.0e-05 41 | LEARNING_RATE_DECAY: 0.1 42 | LEARNING_RATE_DECAY_FREQUENCY: 176 43 | NUM_ITERS: 220 44 | BATCH_NORM_MODE: freeze_bn 45 | PRETRAINED_WEIGHT_LR_FACTOR_TEXT: 0.1 46 | PRETRAINED_WEIGHT_LR_FACTOR_IMAGE: 0.1 47 | ALWAYS_EVAL_TEST: true 48 | FINAL_EVAL_ON_TEST: true 49 | -------------------------------------------------------------------------------- /experiment_scripts/paper/imfq_pretrain_af.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: imat_fashion 3 | PATH: /home/default/ephemeral_drive/Data/imat2018/ 4 | AUGMENTATION: 5 | IMAGE_AUGMENTATION: null 6 | DATA_LOADER: 7 | LOADER_NUM_WORKERS: 0 8 | OUTPUT_DIR: /home/default/ephemeral_drive/experiments/ 9 | EXP_NAME: defaultexpname 10 | MODEL: 11 | EMBED_DIM: 1024 12 | COMPOSITION: clipresmaaf 13 | DEVICE: cuda 14 | LOSS: batch_based_classification # consider 15 | TEXT_MODEL: 16 | ARCHITECTURE: default # use the model paired with IMAGE_MODEL.ARCHITECTURE 17 | TOKENIZER: null 18 | IMAGE_MODEL: 19 | ARCHITECTURE: RN50 20 | OUTPUTS: 21 | - 4 22 | - attnpool 23 | INCLUDES_IMAGE_TRANSFORM: false 24 | MAAF: 25 | ATTENTION_HEADS: 8 26 | ATTN_SOFTMAX_REPLACEMENT: null 27 | BLOCK_WIDTH: 256 28 | NUM_BLOCKS: 1 29 | OUTPUT: rwpool 30 | POSITION_ENCODING: null 31 | RESIDUAL: 32 | INITIAL_MAAF_PRESIGMOID: null 33 | INITIAL_MAAF_WEIGHT: 0.0067 34 | LEARN_WEIGHTS: false 35 | SOLVER: 36 | OPTIMIZER: adam 37 | BATCH_SIZE: 128 38 | DROP_WORST_RATE: 0 39 | EVAL_EVERY: 1 40 | LEARNING_RATE: 1.0e-05 41 | LEARNING_RATE_DECAY: 0.1 42 | LEARNING_RATE_DECAY_FREQUENCY: 7254 # consider 980 (yes, an accident if so) 43 | NUM_ITERS: 21762 44 | BATCH_NORM_MODE: freeze_bn 45 | PRETRAINED_WEIGHT_LR_FACTOR_TEXT: 0.1 46 | PRETRAINED_WEIGHT_LR_FACTOR_IMAGE: 0.1 47 | PROJECTION_LR_TIED_TO_PRETRAINED: false 48 | -------------------------------------------------------------------------------- /src/maaf/utils/io_utils.py: -------------------------------------------------------------------------------- 1 | # 
Copyright 2022 Yahoo, Licensed under the terms of the Apache License, Version 2.0. 2 | # See LICENSE file in project root for terms. 3 | 4 | 5 | import numpy as np 6 | from io import BytesIO 7 | 8 | def ensure_json_serializable(value): 9 | """ 10 | Recursively ensures all values passed in are json serializable 11 | """ 12 | if isinstance(value, np.ndarray): 13 | return value.tolist() 14 | elif isinstance(value, (np.float, np.float32, np.float64)): 15 | return float(value) 16 | elif isinstance(value, (np.uint8, np.int32, np.int64, np.integer)): 17 | return int(value) 18 | elif isinstance(value, dict): 19 | new_dict = {} 20 | for k, v in value.items(): 21 | new_dict[k] = ensure_json_serializable(v) 22 | return new_dict 23 | elif isinstance(value, list): 24 | new_list = [] 25 | for element in value: 26 | new_list.append(ensure_json_serializable(element)) 27 | return new_list 28 | else: 29 | return value 30 | 31 | def pil_image_to_bytes(pim): 32 | """ 33 | Converts PIL image to b64 string. 34 | :params PIL.Image pim: 35 | the PIL image we want to convert 36 | :returns str: 37 | Returns a string of the b64 encoded pixels 38 | """ 39 | buffer = BytesIO() 40 | pim.save(buffer, format="JPEG") 41 | img_str= buffer.getvalue() 42 | #img_str = base64.b64encode(buffer.getvalue()) 43 | return(img_str) 44 | -------------------------------------------------------------------------------- /configs/adam_imat.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Yahoo, Licensed under the terms of the Apache License, Version 2.0. 2 | # See LICENSE file in project root for terms. 3 | 4 | DATASET: 5 | AUGMENTATION: 6 | IMAGE_AUGMENTATION: null 7 | CLASS_WEIGHTS: null 8 | IMAGE_DIR: null 9 | NAME: imat_fashion 10 | NUM_CLASSES: null 11 | PATH: /home/default/ephemeral_drive/Data/imat2018/ 12 | REQUIRE_IMAGES: false 13 | SINGLE_CLASS_BATCHES: false 14 | DATA_LOADER: 15 | LOADER_NUM_WORKERS: 0 16 | EXP_NAME: clip-imat 17 | MODEL: 18 | EMBED_DIM: 1024 19 | COMPOSITION: clip 20 | DEVICE: cuda 21 | LOSS: double_softmax 22 | TEXT_MODEL: 23 | ARCHITECTURE: null 24 | TOKENIZER: null 25 | IMAGE_MODEL: 26 | ARCHITECTURE: null # because included in clip 27 | INCLUDES_IMAGE_TRANSFORM: false 28 | INITIAL_NORMALIZATION_FACTOR: 4. 29 | OUTPUT_DIR: /home/default/ephemeral_drive/experiments/paper/ 30 | SOLVER: 31 | ALWAYS_EVAL_TEST: false 32 | BATCH_NORM_MODE: freeze_bn 33 | BATCH_SIZE: 128 34 | DROP_WORST_RATE: 0 35 | EVAL_EVERY: 1 36 | FINAL_EVAL_ON_TEST: false 37 | LEARNING_RATE: 1.0e-05 38 | LEARNING_RATE_DECAY: 0.1 39 | LEARNING_RATE_DECAY_FREQUENCY: 7254 40 | LR_DECAY_ONLY_ONCE: false 41 | MOMENTUM: 0.9 42 | NUM_ITERS: 21762 43 | OPTIMIZER: adam 44 | PRETRAINED_WEIGHT_LR_FACTOR_IMAGE: 1. 45 | PRETRAINED_WEIGHT_LR_FACTOR_TEXT: 1. 46 | PROJECTION_LR_TIED_TO_PRETRAINED: true 47 | SAVE_EVERY: 100 48 | SCHEDULE_ITERS: [] 49 | SCHEDULE_RATES: [] 50 | SOFTMAX_MARGIN: 0 51 | WEIGHT_DECAY: 1.0e-06 52 | -------------------------------------------------------------------------------- /configs/clipresmaaf_fiq.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Yahoo, Licensed under the terms of the Apache License, Version 2.0. 2 | # See LICENSE file in project root for terms. 
3 | 4 | DATASET: 5 | NAME: fashioniq 6 | PATH: /home/default/ephemeral_drive/Data/FashionIQ/ 7 | REQUIRE_IMAGES: false 8 | DATA_LOADER: 9 | LOADER_NUM_WORKERS: 0 10 | OUTPUT_DIR: /home/default/ephemeral_drive/experiments/fashioniq/clipmaaf/ 11 | EXP_NAME: clipmaaf1-fiq 12 | MODEL: 13 | COMPOSITION: clipmaaf 14 | DEVICE: cuda 15 | LOSS: double_softmax 16 | EMBED_DIM: 1024 17 | TEXT_MODEL: 18 | ARCHITECTURE: null 19 | TOKENIZER: null 20 | IMAGE_MODEL: 21 | ARCHITECTURE: null # because included in clip 22 | OUTPUTS: 23 | - 4 24 | - attnpool 25 | INCLUDES_IMAGE_TRANSFORM: false 26 | MAAF: 27 | ATTENTION_HEADS: 8 28 | ATTN_SOFTMAX_REPLACEMENT: null 29 | BLOCK_WIDTH: 256 30 | NUM_BLOCKS: 1 31 | OUTPUT: rwpool 32 | POSITION_ENCODING: null 33 | RESIDUAL: 34 | INITIAL_MAAF_WEIGHT: 0.1 35 | WEIGHTS: null 36 | INITIAL_NORMALIZATION_FACTOR: 4. 37 | SOLVER: 38 | BATCH_SIZE: 128 39 | DROP_WORST_RATE: 0 40 | EVAL_EVERY: 1 41 | LEARNING_RATE: 0.0001 42 | LEARNING_RATE_DECAY: 0.1 43 | LEARNING_RATE_DECAY_FREQUENCY: 4000 44 | LR_DECAY_ONLY_ONCE: false 45 | NUM_ITERS: 6000 46 | PRETRAINED_WEIGHT_LR_FACTOR_IMAGE: 0.1 47 | PRETRAINED_WEIGHT_LR_FACTOR_TEXT: 0.1 48 | SAVE_EVERY: 100 49 | SOFTMAX_MARGIN: 0 50 | WEIGHT_DECAY: 1.0e-06 51 | PROJECTION_LR_TIED_TO_PRETRAINED: false 52 | BATCH_NORM_MODE: freeze_bn 53 | -------------------------------------------------------------------------------- /experiment_scripts/extra_datasets/mitstates/mitstates_raf.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: mitstates 3 | PATH: /home/default/ephemeral_drive/Data/mitstates/ 4 | REQUIRE_IMAGES: false 5 | AUGMENTATION: 6 | IMAGE_AUGMENTATION: null 7 | DATA_LOADER: 8 | LOADER_NUM_WORKERS: 0 9 | OUTPUT_DIR: /home/default/ephemeral_drive/experiments/fashioniq/ 10 | EXP_NAME: clip-mitstates-raf 11 | MODEL: 12 | EMBED_DIM: 1024 13 | COMPOSITION: clipresmaaf 14 | DEVICE: cuda 15 | LOSS: batch_based_classification # consider 16 | TEXT_MODEL: 17 | ARCHITECTURE: default # use the model paired with IMAGE_MODEL.ARCHITECTURE 18 | TOKENIZER: null 19 | IMAGE_MODEL: 20 | ARCHITECTURE: RN50 21 | OUTPUTS: 22 | - 4 23 | - attnpool 24 | INCLUDES_IMAGE_TRANSFORM: false 25 | MAAF: 26 | ATTENTION_HEADS: 8 27 | ATTN_SOFTMAX_REPLACEMENT: null 28 | BLOCK_WIDTH: 256 29 | NUM_BLOCKS: 2 30 | OUTPUT: rwpool 31 | POSITION_ENCODING: null 32 | RESIDUAL: 33 | INITIAL_MAAF_PRESIGMOID: null 34 | INITIAL_MAAF_WEIGHT: 0.0067 35 | LEARN_WEIGHTS: false 36 | SOLVER: 37 | BATCH_SIZE: 128 38 | OPTIMIZER: adam 39 | BATCH_NORM_MODE: freeze_bn 40 | DROP_WORST_RATE: 0 41 | EVAL_EVERY: 1 42 | LEARNING_RATE: 1.0e-05 43 | LEARNING_RATE_DECAY: 0.1 44 | LEARNING_RATE_DECAY_FREQUENCY: 2696 45 | NUM_ITERS: 3370 46 | BATCH_NORM_MODE: freeze_bn 47 | PRETRAINED_WEIGHT_LR_FACTOR_TEXT: 0.1 48 | PRETRAINED_WEIGHT_LR_FACTOR_IMAGE: 0.1 49 | ALWAYS_EVAL_TEST: true 50 | FINAL_EVAL_ON_TEST: true 51 | PROJECTION_LR_TIED_TO_PRETRAINED: false 52 | -------------------------------------------------------------------------------- /experiment_scripts/extra_datasets/birdstowords/birdstowords_raf.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: birdstowords 3 | PATH: /home/default/ephemeral_drive/Data/birds-to-words/ 4 | REQUIRE_IMAGES: false 5 | AUGMENTATION: 6 | IMAGE_AUGMENTATION: null 7 | DATA_LOADER: 8 | LOADER_NUM_WORKERS: 0 9 | OUTPUT_DIR: /home/default/ephemeral_drive/experiments/ 10 | EXP_NAME: clip-birdstowords-raf 11 | MODEL: 12 | EMBED_DIM: 1024 13 | COMPOSITION: 
clipresmaaf 14 | DEVICE: cuda 15 | LOSS: batch_based_classification # consider 16 | TEXT_MODEL: 17 | ARCHITECTURE: default # use the model paired with IMAGE_MODEL.ARCHITECTURE 18 | TOKENIZER: null 19 | IMAGE_MODEL: 20 | ARCHITECTURE: RN50 21 | OUTPUTS: 22 | - 4 23 | - attnpool 24 | INCLUDES_IMAGE_TRANSFORM: false 25 | MAAF: 26 | ATTENTION_HEADS: 8 27 | ATTN_SOFTMAX_REPLACEMENT: null 28 | BLOCK_WIDTH: 256 29 | NUM_BLOCKS: 2 30 | OUTPUT: rwpool 31 | POSITION_ENCODING: null 32 | RESIDUAL: 33 | INITIAL_MAAF_PRESIGMOID: null 34 | INITIAL_MAAF_WEIGHT: 0.0067 35 | LEARN_WEIGHTS: false 36 | SOLVER: 37 | BATCH_SIZE: 128 38 | OPTIMIZER: adam 39 | BATCH_NORM_MODE: freeze_bn 40 | DROP_WORST_RATE: 0 41 | EVAL_EVERY: 1 42 | LEARNING_RATE: 1.0e-05 43 | LEARNING_RATE_DECAY: 0.1 44 | LEARNING_RATE_DECAY_FREQUENCY: 176 45 | NUM_ITERS: 220 46 | BATCH_NORM_MODE: freeze_bn 47 | PRETRAINED_WEIGHT_LR_FACTOR_TEXT: 0.1 48 | PRETRAINED_WEIGHT_LR_FACTOR_IMAGE: 0.1 49 | ALWAYS_EVAL_TEST: true 50 | FINAL_EVAL_ON_TEST: true 51 | PROJECTION_LR_TIED_TO_PRETRAINED: false 52 | -------------------------------------------------------------------------------- /experiment_scripts/eval.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Yahoo, Licensed under the terms of the Apache License, Version 2.0. 2 | # See LICENSE file in project root for terms. 3 | 4 | 5 | expdir=$1 6 | 7 | if [[ $expdir == *"baseline"* ]] && [[ ! $expdir == *"finetune" ]]; then 8 | weights=None 9 | else 10 | weights=$expdir/latest_checkpoint.pth 11 | fi 12 | 13 | 14 | if ! compgen -G "${expdir}/fashioniq*eval.json" > /dev/null || [ $REDO_EVALS ]; then 15 | python main.py --config $expdir/config.yaml \ 16 | --no-train --no-timestamp --no-config-save \ 17 | DATASET.NAME fashioniq \ 18 | DATASET.PATH /home/default/ephemeral_drive/Data/FashionIQ/ \ 19 | MODEL.WEIGHTS $weights 20 | fi 21 | if ! compgen -G "${expdir}/imat_fashion*eval.json" > /dev/null || [ $REDO_EVALS ]; then 22 | python main.py --config $expdir/config.yaml \ 23 | --no-train --no-timestamp --no-config-save \ 24 | DATASET.NAME imat_fashion \ 25 | DATASET.PATH /home/default/ephemeral_drive/Data/imat2018/ \ 26 | MODEL.WEIGHTS $weights \ 27 | DATASET.AUGMENTATION.IMAGE_AUGMENTATION None 28 | fi 29 | if ! compgen -G "${expdir}/fashiongen*eval.json" > /dev/null || [ $REDO_EVALS ]; then 30 | python main.py --config $expdir/config.yaml \ 31 | --no-train --no-timestamp --no-config-save \ 32 | DATASET.NAME fashiongen \ 33 | DATASET.PATH /home/default/ephemeral_drive/Data/fashiongen/ \ 34 | MODEL.WEIGHTS $weights \ 35 | DATASET.AUGMENTATION.IMAGE_AUGMENTATION None 36 | fi 37 | if [[ ! -e $expdir/cfq_results.json ]] || [ $REDO_EVALS ]; then 38 | python src/maaf/actions/eval_cfq.py --config $expdir/config.yaml 39 | fi 40 | -------------------------------------------------------------------------------- /configs/clipresmaaf_fg.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Yahoo, Licensed under the terms of the Apache License, Version 2.0. 2 | # See LICENSE file in project root for terms. 
3 | 4 | DATASET: 5 | NAME: fashiongen 6 | PATH: /home/default/ephemeral_drive/Data/fashiongen/ 7 | CLASS_WEIGHTS: null 8 | IMAGE_DIR: null 9 | REQUIRE_IMAGES: false 10 | DATA_LOADER: 11 | LOADER_NUM_WORKERS: 0 12 | EXP_NAME: clipresmaaf1-fg 13 | MODEL: 14 | COMPOSITION: clipresmaaf 15 | DEVICE: cuda 16 | DROPOUT_RATE: 0.1 17 | EMBED_DIM: 1024 18 | IMAGE_MODEL: 19 | ARCHITECTURE: null 20 | OUTPUTS: 21 | - 4 22 | - attnpool 23 | INCLUDES_IMAGE_TRANSFORM: false 24 | LOSS: double_softmax 25 | MAAF: 26 | ATTENTION_HEADS: 8 27 | ATTN_SOFTMAX_REPLACEMENT: null 28 | BLOCK_WIDTH: 256 29 | NUM_BLOCKS: 1 30 | OUTPUT: rwpool 31 | POSITION_ENCODING: null 32 | RESIDUAL: 33 | INITIAL_MAAF_WEIGHT: 0.01 34 | TEXT_MODEL: 35 | ARCHITECTURE: null 36 | TOKENIZER: null 37 | WEIGHTS: null 38 | INITIAL_NORMALIZATION_FACTOR: 4. 39 | CLIP: 40 | MISALIGNMENT: null 41 | OUTPUT_DIR: /home/default/ephemeral_drive/experiments/ 42 | SOLVER: 43 | OPTIMIZER: sgd 44 | BATCH_SIZE: 128 45 | DROP_WORST_RATE: 0 46 | EVAL_EVERY: 1 47 | LEARNING_RATE: 0.0001 48 | LEARNING_RATE_DECAY: 0.1 49 | NUM_ITERS: 24420 # 12 epochs 50 | LEARNING_RATE_DECAY_FREQUENCY: 20350 # 10 epochs 51 | LR_DECAY_ONLY_ONCE: false 52 | PRETRAINED_WEIGHT_LR_FACTOR_IMAGE: 0.1 53 | PRETRAINED_WEIGHT_LR_FACTOR_TEXT: 0.1 54 | SAVE_EVERY: 100 55 | SOFTMAX_MARGIN: 0 56 | WEIGHT_DECAY: 1.0e-06 57 | PROJECTION_LR_TIED_TO_PRETRAINED: false 58 | BATCH_NORM_MODE: freeze_bn 59 | -------------------------------------------------------------------------------- /configs/clipresmaaf_imat.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Yahoo, Licensed under the terms of the Apache License, Version 2.0. 2 | # See LICENSE file in project root for terms. 3 | 4 | DATASET: 5 | CLASS_WEIGHTS: null 6 | IMAGE_DIR: null 7 | NAME: imat_fashion 8 | NUM_CLASSES: null 9 | PATH: /home/default/ephemeral_drive/Data/imat2018/ 10 | REQUIRE_IMAGES: false 11 | DATA_LOADER: 12 | LOADER_NUM_WORKERS: 0 13 | EXP_NAME: clipresmaaf1-imat 14 | MODEL: 15 | COMPOSITION: clipresmaaf 16 | DEVICE: cuda 17 | DROPOUT_RATE: 0.1 18 | EMBED_DIM: 1024 19 | IMAGE_MODEL: 20 | ARCHITECTURE: null 21 | OUTPUTS: 22 | - 4 23 | - attnpool 24 | INCLUDES_IMAGE_TRANSFORM: false 25 | LOSS: double_softmax 26 | MAAF: 27 | ATTENTION_HEADS: 8 28 | ATTN_SOFTMAX_REPLACEMENT: null 29 | BLOCK_WIDTH: 256 30 | NUM_BLOCKS: 1 31 | OUTPUT: rwpool 32 | POSITION_ENCODING: null 33 | RESIDUAL: 34 | INITIAL_MAAF_WEIGHT: 0.01 35 | TEXT_MODEL: 36 | ARCHITECTURE: null 37 | TOKENIZER: null 38 | WEIGHTS: null 39 | INITIAL_NORMALIZATION_FACTOR: 4. 40 | CLIP: 41 | MISALIGNMENT: null 42 | OUTPUT_DIR: /home/default/ephemeral_drive/experiments/ 43 | SOLVER: 44 | OPTIMIZER: sgd 45 | BATCH_SIZE: 128 46 | DROP_WORST_RATE: 0 47 | EVAL_EVERY: 1 48 | LEARNING_RATE: 0.0001 49 | LEARNING_RATE_DECAY: 0.1 50 | LEARNING_RATE_DECAY_FREQUENCY: 72540 # 10 epoch 51 | LR_DECAY_ONLY_ONCE: false 52 | NUM_ITERS: 87048 # 12 epochs 53 | PRETRAINED_WEIGHT_LR_FACTOR_IMAGE: 0.1 54 | PRETRAINED_WEIGHT_LR_FACTOR_TEXT: 0.1 55 | SAVE_EVERY: 100 56 | SOFTMAX_MARGIN: 0 57 | WEIGHT_DECAY: 1.0e-06 58 | PROJECTION_LR_TIED_TO_PRETRAINED: false 59 | BATCH_NORM_MODE: freeze_bn 60 | -------------------------------------------------------------------------------- /configs/fiq_maaf_roberta.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Yahoo, Licensed under the terms of the Apache License, Version 2.0. 2 | # See LICENSE file in project root for terms. 
3 | 4 | 5 | # adapted from old script compatibility logic output 6 | 7 | DATASET: 8 | CLASS_WEIGHTS: null 9 | IMAGE_DIR: null 10 | NAME: fashioniq 11 | NUM_CLASSES: null 12 | PATH: /home/default/Data/FashionIQ/ 13 | REQUIRE_IMAGES: false 14 | DATA_LOADER: 15 | LOADER_NUM_WORKERS: 0 16 | EXP_NAME: sequence_concat_attention_34_2_128_1_1_1_0 17 | MODEL: 18 | COMPOSITION: sequence_concat_attention 19 | DEVICE: cuda 20 | DROPOUT_RATE: 0.1 21 | EMBED_DIM: 512 22 | IMAGE_MODEL: 23 | ARCHITECTURE: resnet50 24 | FREEZE_WEIGHTS: false 25 | OUTPUTS: 26 | - 3 27 | - 4 28 | PRETRAINED: true 29 | WEIGHTS: null 30 | LOSS: batch_based_classification 31 | MAAF: 32 | ATTENTION_HEADS: 8 33 | ATTN_SOFTMAX_REPLACEMENT: null 34 | BLOCK_WIDTH: 128 35 | NUM_BLOCKS: 2 36 | OUTPUT: rwpool 37 | POSITION_ENCODING: null 38 | TEXT_MODEL: 39 | ARCHITECTURE: roberta 40 | EMBED_DIM: 512 41 | FREEZE_WEIGHTS: false 42 | MAX_VOCAB: 52000 43 | NUM_LAYERS: 1 44 | TOKENIZER: bpe 45 | TOKENIZER_PATH: roberta-base 46 | VOCAB_DATA: null 47 | VOCAB_MIN_FREQ: 0 48 | WEIGHTS: null 49 | OUTPUT_DIR: /home/default/ephemeral_drive/experiments/fashioniq/ 50 | SOLVER: 51 | BATCH_SIZE: 32 52 | DROP_WORST_RATE: 0 53 | EVAL_EVERY: 3 54 | LEARNING_RATE: 0.01 55 | LEARNING_RATE_DECAY: 0.1 56 | LEARNING_RATE_DECAY_FREQUENCY: 50000 57 | LR_DECAY_ONLY_ONCE: false 58 | NUM_ITERS: 150000 59 | PRETRAINED_WEIGHT_LR_FACTOR_IMAGE: 0.1 60 | PRETRAINED_WEIGHT_LR_FACTOR_TEXT: 1.0 61 | SAVE_EVERY: 100 62 | SCHEDULE_ITERS: '' 63 | SCHEDULE_RATES: '' 64 | SOFTMAX_MARGIN: 0 65 | WEIGHT_DECAY: 1.0e-06 66 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Package setup file for python module maaf 4 | """ 5 | # WARNING: Do not modify this file unless you know exactly what you are doing. 6 | # 7 | # Note: If you believe you know exactly what you are doing, you are most likely wrong 8 | # 9 | # This file is an executable python script that will be included in the sdist package, this file 10 | # is used to obtain package metadata and install the package on other computer systems. As a 11 | # result, it must be able to execute without generating errors or execeptions in an environment with 12 | # only the python interpreter and nothing else. 13 | # 14 | # This file should never do any of the following: 15 | # - Import Python modules that are not included with the Python standard library 16 | # - Import any code that is contained in the package 17 | # - Execute external commands (do not use the subprocess module, os.system, os.popen*) 18 | # - Access or use any file that is not included in this package 19 | # - Call the sys.exit() or exit() functions unless there is an error that will prevent the package from installing. 
20 | # - Generate any output to stdout (do not use the print function or call or use any code that uses it) 21 | import os 22 | import setuptools 23 | import sys 24 | 25 | 26 | def scripts(): 27 | """ 28 | Get the scripts in the "scripts" directory 29 | 30 | Returns 31 | list 32 | List of filenames 33 | """ 34 | script_list = [] 35 | if os.path.isdir('scripts'): 36 | for item in os.listdir('scripts'): 37 | filename = os.path.join('scripts', item) 38 | if os.path.isfile(filename): 39 | script_list.append(filename) 40 | return script_list 41 | 42 | 43 | if __name__ == '__main__': 44 | # We're being run from the command line so call setup with our arguments 45 | setuptools.setup(scripts=scripts()) 46 | -------------------------------------------------------------------------------- /src/maaf/utils/bn_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Yahoo, Licensed under the terms of the Apache License, Version 2.0. 2 | # See LICENSE file in project root for terms. 3 | 4 | 5 | import torch 6 | 7 | 8 | def apply_bn_mode(model, bn_mode): 9 | if bn_mode == "freeze_bn": 10 | freeze_bn(model) 11 | elif bn_mode == "freeze_except_bn": 12 | freeze_except_bn(model) 13 | elif bn_mode == "freeze_bn_averages": 14 | change_bn_mode(model, bn_mode="eval") 15 | elif bn_mode == "freeze_except_bn_averages": 16 | freeze_except_bn(model) 17 | change_bn_mode(model, bn_mode="train") 18 | elif bn_mode == "ordinary": 19 | pass 20 | else: 21 | raise ValueError(f"Invalid batch norm mode {bn_mode}") 22 | 23 | 24 | def freeze_bn(model): 25 | for module in model.modules(): 26 | if isinstance(module, torch.nn.BatchNorm2d): 27 | if hasattr(module, 'weight'): 28 | module.weight.requires_grad_(False) 29 | if hasattr(module, 'bias'): 30 | module.bias.requires_grad_(False) 31 | module.eval() 32 | 33 | 34 | def freeze_except_bn(model): 35 | for module in model.modules(): 36 | if isinstance(module, torch.nn.BatchNorm2d): 37 | if hasattr(module, 'weight'): 38 | module.weight.requires_grad_(True) 39 | if hasattr(module, 'bias'): 40 | module.bias.requires_grad_(True) 41 | module.train() 42 | else: 43 | for param in module.parameters(): 44 | param.requires_grad_(False) 45 | module.eval() 46 | 47 | 48 | def change_bn_mode(model, bn_mode="eval"): 49 | for module in model.modules(): 50 | if isinstance(module, torch.nn.BatchNorm2d): 51 | if bn_mode == "eval": 52 | module.eval() 53 | elif bn_mode == "train": 54 | module.train() 55 | else: 56 | raise ValueError(f"Invalid bn_mode {bn_mode}") 57 | -------------------------------------------------------------------------------- /configs/adam_crm_imat.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Yahoo, Licensed under the terms of the Apache License, Version 2.0. 2 | # See LICENSE file in project root for terms. 
3 | 4 | DATASET: 5 | AUGMENTATION: 6 | IMAGE_AUGMENTATION: null 7 | CLASS_WEIGHTS: null 8 | IMAGE_DIR: null 9 | NAME: imat_fashion 10 | NUM_CLASSES: null 11 | PATH: /home/default/ephemeral_drive/Data/imat2018/ 12 | REQUIRE_IMAGES: false 13 | SINGLE_CLASS_BATCHES: false 14 | DATA_LOADER: 15 | LOADER_NUM_WORKERS: 0 16 | EXP_NAME: clipresmaaf-imat-lrm5-pt0_1-adam 17 | MODEL: 18 | CLIP: 19 | MISALIGNMENT: null 20 | PROMPT: '' 21 | COMPOSITION: clipresmaaf 22 | DEVICE: cuda 23 | DROPOUT_RATE: 0.1 24 | EMBED_DIM: 1024 25 | IMAGE_MODEL: 26 | ARCHITECTURE: null 27 | FREEZE_WEIGHTS: false 28 | OUTPUTS: 29 | - 4 30 | - attnpool 31 | PRETRAINED: true 32 | WEIGHTS: null 33 | INCLUDES_IMAGE_TRANSFORM: false 34 | INITIAL_NORMALIZATION_FACTOR: 4.0 35 | LOSS: batch_based_classification # consider double_softmax 36 | MAAF: 37 | ATTENTION_HEADS: 8 38 | ATTN_SOFTMAX_REPLACEMENT: null 39 | BLOCK_WIDTH: 256 40 | NUM_BLOCKS: 1 41 | OUTPUT: rwpool 42 | POSITION_ENCODING: null 43 | RESIDUAL: 44 | INITIAL_MAAF_PRESIGMOID: -5.0 45 | LEARN_WEIGHTS: true 46 | TEXT_MODEL: 47 | ARCHITECTURE: null 48 | EMBED_DIM: 512 49 | FREEZE_WEIGHTS: false 50 | MAX_TOKENS: 128 51 | MAX_VOCAB: 52000 52 | MODEL_PATH: null 53 | NUM_LAYERS: 1 54 | OUTPUT_RELU: false 55 | TOKENIZER: null 56 | TOKENIZER_PATH: null 57 | VOCAB_DATA: null 58 | VOCAB_MIN_FREQ: 0 59 | OUTPUT_DIR: /home/default/ephemeral_drive/experiments/paper/ 60 | SOLVER: 61 | ALWAYS_EVAL_TEST: false 62 | BATCH_NORM_MODE: freeze_bn 63 | BATCH_SIZE: 128 64 | DROP_WORST_RATE: 0 65 | EVAL_EVERY: 1 66 | FINAL_EVAL_ON_TEST: false 67 | LEARNING_RATE: 1.0e-05 68 | LEARNING_RATE_DECAY: 0.1 69 | LEARNING_RATE_DECAY_FREQUENCY: 7254 70 | LR_DECAY_ONLY_ONCE: false 71 | MOMENTUM: 0.9 72 | NUM_ITERS: 21762 73 | OPTIMIZER: adam 74 | PRETRAINED_WEIGHT_LR_FACTOR_IMAGE: 0.1 75 | PRETRAINED_WEIGHT_LR_FACTOR_TEXT: 0.1 76 | PROJECTION_LR_TIED_TO_PRETRAINED: false 77 | SAVE_EVERY: 100 78 | SCHEDULE_ITERS: [] 79 | SCHEDULE_RATES: [] 80 | SOFTMAX_MARGIN: 0 81 | WEIGHT_DECAY: 1.0e-06 82 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # (Residual) Modality Agnostic Attention Fusion for visual search with text feedback 2 | 3 | The methods in this repository were used for the experiments described in two papers from Yahoo's Visual Intelligence Research team. The [more recent paper](https://arxiv.org/abs/2204.11004)'s main results can be reproduced using the scripts in the `experiment_scripts` directory. If you find this code useful please cite 4 | ``` 5 | @article{dodds2022training, 6 | title = {Training and challenging models for text-guided fashion image retrieval}, 7 | author = {Dodds, Eric and Culpepper, Jack and Srivastava, Gaurav}, 8 | journal={arXiv preprint arXiv:2204.11004} 9 | year = {2022}, 10 | doi = {10.48550/ARXIV.2204.11004}, 11 | } 12 | ``` 13 | 14 | We also recommend using the latest version of the code if you wish to build upon our general methods. However if you are interested specifically in reproducing the results in our earlier paper or using datasets discussed there, it will likely be easier to start from commit [49a0df9](https://github.com/yahoo/maaf/commit/49a0df90baf4b9d4a194ed646620375b5b837b15). 
The [earlier paper](https://arxiv.org/abs/2007.00145) can be cited as: 15 | ``` 16 | @article{dodds2020modality, 17 | title={Modality-Agnostic Attention Fusion for visual search with text feedback}, 18 | author={Dodds, Eric and Culpepper, Jack and Herdade, Simao and Zhang, Yang and Boakye, Kofi}, 19 | journal={arXiv preprint arXiv:2007.00145}, 20 | year={2020} 21 | } 22 | ``` 23 | 24 | This codebase was originally adapted from [TIRG code](https://github.com/google/tirg) written by the authors of [Composing Text and Image for Image Retrieval - An Empirical Odyssey](https://arxiv.org/abs/1812.07119); the core model and training code is based on that repository. Transformer code is adapted from [The Annotated Transformer](http://nlp.seas.harvard.edu/2018/04/03/attention.html). Further modifications are our own. 25 | We use [YACS](https://github.com/rbgirshick/yacs) for configurations. 26 | 27 | ## Setup 28 | 29 | The code is tested on Python 3.6 with PyTorch 1.5 and should also work on newer versions. Installing with pip should install the requirements. 30 | 31 | ## Datasets 32 | 33 | ### Challenging Fashion Queries (CFQ) 34 | 35 | The Challenging Fashion Queries dataset described in our paper can be found [here](https://webscope.sandbox.yahoo.com/catalog.php?datatype=a&did=92) and may be used for research purposes. 36 | 37 | We do not own any of the other datasets used in our experiments here. Below we link to the datasets where we acquired them. 38 | 39 | ### Fashion IQ 40 | 41 | Download the dataset from [here](https://github.com/XiaoxiaoGuo/fashion-iq). 42 | -------------------------------------------------------------------------------- /src/maaf/config/config.py: -------------------------------------------------------------------------------- 1 | # adapted from fvcore.common.config.py 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | import os 4 | from yacs.config import CfgNode as YacsNode 5 | 6 | BASE_KEY = "_BASE_" 7 | 8 | def get_config(): 9 | from .defaults import _C 10 | return _C.clone() 11 | 12 | class CfgNode(YacsNode): 13 | """ 14 | Extended version of :class:`yacs.config.CfgNode`. 15 | It contains the following extra features: 16 | 1. The :meth:`merge_from_file` method supports the "_BASE_" key, 17 | which allows the new CfgNode to inherit all the attributes from the 18 | base configuration file. 19 | """ 20 | 21 | @classmethod 22 | def load_yaml_with_base(cls, filename): 23 | """ 24 | Just like `yaml.load(open(filename))`, but inherit attributes from its 25 | `_BASE_`. 26 | Args: 27 | filename (str or file-like object): the file name or file of the current config. 28 | Will be used to find the base config file. 29 | Returns: 30 | (dict): the loaded yaml 31 | """ 32 | with open(filename, "r") as f: 33 | cfg = cls.load_cfg(f) 34 | 35 | def merge_a_into_b(a, b): 36 | # merge dict a into dict b. values in a will overwrite b. 37 | for k, v in a.items(): 38 | if isinstance(v, dict) and k in b: 39 | assert isinstance( 40 | b[k], dict 41 | ), "Cannot inherit key '{}' from base!".format(k) 42 | merge_a_into_b(v, b[k]) 43 | else: 44 | b[k] = v 45 | 46 | if BASE_KEY in cfg: 47 | base_cfg_file = cfg[BASE_KEY] 48 | if base_cfg_file.startswith("~"): 49 | base_cfg_file = os.path.expanduser(base_cfg_file) 50 | if not any(map(base_cfg_file.startswith, ["/", "https://", "http://"])): 51 | # the path to base cfg is relative to the config file itself.
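# For illustration, a hypothetical child config (not a file in this repo) could contain
#     _BASE_: base.yaml
#     SOLVER:
#       LEARNING_RATE: 1.0e-05
# The relative path is resolved against the child config's directory just below, the base
# file is loaded recursively through this same method, and merge_a_into_b then overwrites
# the base values with the child's.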
52 | base_cfg_file = os.path.join(os.path.dirname(filename), base_cfg_file) 53 | base_cfg = cls.load_yaml_with_base(base_cfg_file) 54 | del cfg[BASE_KEY] 55 | 56 | merge_a_into_b(cfg, base_cfg) 57 | return base_cfg 58 | return cfg 59 | 60 | def merge_from_file(self, cfg_filename, allow_unsafe = False): 61 | """ 62 | Merge configs from a given yaml file. 63 | Args: 64 | cfg_filename: the file name of the yaml config. 65 | allow_unsafe: whether to allow loading the config file with 66 | `yaml.unsafe_load`. 67 | """ 68 | loaded_cfg = self.load_yaml_with_base(cfg_filename) 69 | loaded_cfg = type(self)(loaded_cfg) 70 | self.merge_from_other_cfg(loaded_cfg) 71 | -------------------------------------------------------------------------------- /src/maaf/actions/eval_cfq.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Yahoo, Licensed under the terms of the Apache License, Version 2.0. 2 | # See LICENSE file in project root for terms. 3 | 4 | 5 | from maaf.main import setup 6 | from maaf.models.build import build_model 7 | from maaf.datasets.cfq import CFQSet 8 | from maaf.datasets.datasets import get_default_image_transform 9 | import argparse 10 | import os 11 | import json 12 | import torch 13 | 14 | 15 | def parse_opt(): 16 | parser = argparse.ArgumentParser() 17 | add_arg = parser.add_argument 18 | 19 | add_arg('--config_file', type=str) 20 | add_arg('--weights_path', type=str, default=None) 21 | add_arg('--debug', action="store_true") 22 | add_arg('--non-strict_loading', action="store_false", dest="strict_loading") 23 | add_arg('--data_path', default="/home/default/ephemeral_drive/Data/cfq/") 24 | add_arg('--output_path', type=str, default=None) 25 | 26 | add_arg( 27 | "opts", 28 | help="Modify config options by adding 'KEY VALUE' pairs at the end of the command. 
", 29 | default=None, 30 | nargs=argparse.REMAINDER, 31 | ) 32 | 33 | args = parser.parse_args() 34 | return args 35 | 36 | 37 | def evaluate(): 38 | args = parse_opt() 39 | cfg = setup(args) 40 | 41 | model, task = build_model(cfg, None, strict_loading=args.strict_loading) 42 | if args.weights_path is None: 43 | weights_path = os.path.dirname(args.config_file) 44 | weights_path = os.path.join(weights_path, "latest_checkpoint.pth") 45 | else: 46 | weights_path = args.weights_path 47 | if not os.path.exists(weights_path): 48 | print(f"Checkpoint {weights_path} not found, evaluating without") 49 | else: 50 | print(f"Loading from {weights_path}...") 51 | state_dict = torch.load(weights_path, map_location=model.device)["model_state_dict"] 52 | model.load_state_dict(state_dict, strict=args.strict_loading) 53 | 54 | data_path = os.path.join(args.data_path) 55 | image_path = os.path.join(args.data_path, "images") 56 | if hasattr(model, "image_transform"): 57 | transform = model.image_transform 58 | else: 59 | transform = get_default_image_transform( 60 | clip="clip" in cfg.MODEL.COMPOSITION) 61 | datasets = CFQSet(data_path, image_path, transform) 62 | 63 | print("Computing metrics...") 64 | results = datasets.compute_metrics(model, with_dots=False) 65 | primary = datasets.get_primary_metrics(results) 66 | for key, val in results.items(): 67 | for kk, vv in val.items(): 68 | if not isinstance(vv, dict): 69 | results[key][kk] = vv.to_dict() # convert Series for json serialization 70 | 71 | for met, res in primary.items(): 72 | print(met, res) 73 | 74 | if args.output_path is None: 75 | output_path = os.path.dirname(args.config_file) 76 | output_path = os.path.join(output_path, "cfq_results.json") 77 | else: 78 | output_path = args.output_path 79 | 80 | if args.debug: 81 | import IPython 82 | IPython.embed() 83 | 84 | with open(output_path, "w") as fh: 85 | json.dump(results, fh) 86 | 87 | 88 | if __name__ == "__main__": 89 | evaluate() 90 | -------------------------------------------------------------------------------- /src/maaf/datasets/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Yahoo, Licensed under the terms of the Apache License, Version 2.0. 2 | # See LICENSE file in project root for terms. 
3 | 4 | import torch 5 | import torch.utils.data 6 | from torchvision import transforms as tvt 7 | 8 | 9 | def get_image_normalizer(): 10 | return tvt.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 11 | 12 | 13 | def get_default_image_transform(clip=False): 14 | if clip: 15 | from ..models.clip import get_image_transform_for_clip 16 | return get_image_transform_for_clip() 17 | normalizer = get_image_normalizer() 18 | transform = tvt.Compose([ 19 | tvt.Resize(224), 20 | tvt.CenterCrop(224), 21 | tvt.ToTensor(), 22 | normalizer]) 23 | return transform 24 | 25 | 26 | def get_augmenting_image_transform(clip=False): 27 | if clip: 28 | from ..models.clip import get_augmenting_image_transform_for_clip 29 | return get_augmenting_image_transform_for_clip() 30 | normalizer = get_image_normalizer() 31 | train_transform = tvt.Compose([ 32 | tvt.RandomResizedCrop(224, scale=(0.8, 1.0), ratio=(0.75, 1.3)), 33 | tvt.RandomHorizontalFlip(), 34 | tvt.ToTensor(), 35 | tvt.Lambda(lambda xx: xx + 0.01 * torch.randn(xx.shape, device="cpu")), 36 | normalizer 37 | ]) 38 | 39 | return train_transform 40 | 41 | 42 | def load_dataset(cfg): 43 | """Loads the input datasets.""" 44 | print('Reading dataset ', cfg.DATASET.NAME) 45 | 46 | if cfg.MODEL.INCLUDES_IMAGE_TRANSFORM: 47 | transform = None 48 | else: 49 | transform = get_default_image_transform(clip="clip" in cfg.MODEL.COMPOSITION) 50 | 51 | if cfg.DATASET.AUGMENTATION.IMAGE_AUGMENTATION is not None: 52 | if cfg.DATASET.NAME != "fashioniq" or cfg.MODEL.INCLUDES_IMAGE_TRANSFORM: 53 | raise NotImplementedError() 54 | train_transform = get_augmenting_image_transform( 55 | clip="clip" in cfg.MODEL.COMPOSITION) 56 | else: 57 | train_transform = transform 58 | 59 | if cfg.DATASET.NAME == 'fashioniq': 60 | from .fashioniq import FashionIQDataset as DatasetClass 61 | 62 | trainset = DatasetClass( 63 | path=cfg.DATASET.PATH, 64 | split='train', 65 | transform=train_transform) 66 | valset = DatasetClass( 67 | path=cfg.DATASET.PATH, 68 | split='val', 69 | transform=transform) 70 | testset = DatasetClass( 71 | path=cfg.DATASET.PATH, 72 | split='test', 73 | transform=transform) 74 | dataset_dict = {"train": trainset, "val": valset, "test": testset} 75 | else: 76 | import importlib 77 | datamod = importlib.import_module(f"maaf.datasets.{cfg.DATASET.NAME}") 78 | DatasetClass = getattr(datamod, datamod.DATASET_CLASS_NAME) 79 | dataset_dict = {split: DatasetClass(path=cfg.DATASET.PATH, split=split, 80 | transform=transform) 81 | for split in ["train", "val", "test"]} 82 | 83 | for name, data in dataset_dict.items(): 84 | if data is not None: 85 | print(name, 'size', len(data)) 86 | 87 | if "test" not in dataset_dict: 88 | dataset_dict["test"] = None 89 | 90 | return dataset_dict 91 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | author = Eric Dodds 3 | author_email = edodds@yahooinc.com 4 | classifiers = 5 | Programming Language :: Python :: 3 6 | description = Python module maaf 7 | long_description = file:README.md 8 | long_description_content_type = text/markdown 9 | name = maaf 10 | version = 2.0.0 11 | 12 | [options] 13 | package_dir = 14 | = src 15 | 16 | namespace_packages = 17 | 18 | packages = find: 19 | 20 | [options.packages.find] 21 | where = src 22 | 23 | 24 | install_requires = 25 | numpy>=1.16.4 26 | Pillow>=6.1.0 27 | tensorboardX 28 | torch>=1.4 29 | torchvision>=0.7.0 30 | tqdm 31 | tokenizers>=0.8.1.rc1 32 
| transformers>=3.0.2 33 | yacs 34 | gitpython 35 | 36 | # By default new packages require at minimum the current supported Python release. 37 | python_requires = >="3.6" 38 | 39 | [options.extras_require] 40 | # This config section allows you to define optional dependencies. For the general case, the defaults will 41 | # work fine. So these settings aren't required. However, many of the screwdriver CI Pipeline steps 42 | # will install the appropriate extras for that step. This makes it possible to install packages that install 43 | # or enhance the functionality of the CI Pipeline step, 44 | # such as packages that implement plugins or themes for the step in question. 45 | 46 | # Additional packages for testing (test step) 47 | test = 48 | 49 | # Additional packages needed for documentation generation (doc_build/doc_publish steps) 50 | # If you want to use a sphinx theme from a package, list it here. 51 | doc_build = 52 | 53 | # Additional packages needed for mypy type checking 54 | mypy = 55 | 56 | # Additional packages needed for pep8/pycodestyle style checking 57 | pep8 = 58 | 59 | # Additional packages needed for pylint code analysis 60 | pylint = 61 | 62 | [options.entry_points] 63 | # Console script entry points are used to create wrapper scripts that run a specific function; the resulting wrapper 64 | # is installed in the bin directory. 65 | 66 | # They are defined using the following format: 67 | # scriptname = modulename:function 68 | # console_scripts = 69 | # multimodal=maaf.cli:main 70 | 71 | [screwdrivercd.version] 72 | # Package versioning plugin/method to use. 73 | # 74 | # Note: 75 | # Switching between different plugins can be difficult, since new package numbers need to have a higher 76 | # major number than previously published packages. So for example, if the version in the metadata of 77 | # this file is 0.0.0 (default value) and the version_type below is sdv4_SD_BUILD, the version generated 78 | # will be 0.0.build_number. If the version_type is changed to sdv4_date, the version generated will be 79 | # year.month.build_number. If the version_type is then changed back to sdv4_SD_BUILD, the versions 80 | # generated will again be 0.0.build_number, which will be a lower version than those generated by the sdv4_date 81 | # plugin, so installations will always treat the older packages generated by the sdv4_date plugin as newer 82 | # until the version in this file is updated with a major number (first part) that is higher than the 2-digit year 83 | # number. 84 | # 85 | # Essentially this means that switching to date-based versions via the sdv4_date plugin is easy, but 86 | # switching back to the default requires more changes than simply changing the value back. 87 | 88 | # These versioners require the CI Pipeline to have an update_version build step that runs before 89 | # any packaging steps.
90 | 91 | # Base the autoversion build number on the screwdriver build number 92 | version_type = sdv4_SD_BUILD 93 | 94 | # Base the autoversion build number on the current date and the screwdriver build number 95 | # version_type = sdv4_date 96 | -------------------------------------------------------------------------------- /src/maaf/models/heads.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import torch 3 | from .loss import build_loss 4 | 5 | 6 | class TaskHead(torch.nn.Module, ABC): 7 | 8 | def __init__(self): 9 | torch.nn.Module.__init__(self) 10 | 11 | @abstractmethod 12 | def forward(self, images, texts): 13 | pass 14 | 15 | @abstractmethod 16 | def compute_loss(self, source, target=None, labels=None): 17 | pass 18 | 19 | 20 | class Classification(TaskHead): 21 | 22 | def __init__(self, loss, embed_dim=512, num_classes=3): 23 | super().__init__() 24 | self.classification_head = torch.nn.Linear(embed_dim, num_classes) 25 | self.softmax = torch.nn.Softmax(dim=1) 26 | self.loss = loss 27 | 28 | def forward(self, composed): 29 | return self.classification_head(composed) 30 | 31 | def probabilities(self, composed): 32 | logits = self(composed) 33 | return self.softmax(logits) 34 | 35 | def compute_loss(self, source, target=None, labels=None): 36 | logits = self(source) 37 | preds = torch.argmax(logits, dim=1) 38 | accuracy = (preds == labels).sum().item() / len(labels) 39 | loss_value = self.loss(logits, labels) 40 | metrics = {"loss": loss_value.item(), 41 | "accuracy": accuracy} 42 | return loss_value, metrics 43 | 44 | 45 | class Regression(TaskHead): 46 | 47 | def __init__(self, loss, embed_dim=512): 48 | super().__init__() 49 | self.regression_head = torch.nn.Linear(embed_dim, 1) 50 | self.loss = loss 51 | 52 | def forward(self, composed): 53 | return torch.sigmoid(self.regression_head(composed)) 54 | 55 | def compute_loss(self, source, target=None, labels=None): 56 | output = self(source) 57 | loss_value = self.loss(output, labels.float()) 58 | metrics = {"loss": loss_value.item()} 59 | return loss_value, metrics 60 | 61 | 62 | class NormalizationLayer(torch.nn.Module): 63 | """Class for normalization layer.""" 64 | 65 | def __init__(self, normalize_scale=1.0, learn_scale=True): 66 | super().__init__() 67 | self.norm_s = torch.log(torch.FloatTensor([normalize_scale])) 68 | if learn_scale: 69 | self.norm_s = torch.nn.Parameter(self.norm_s) 70 | self.epsilon = 1e-9 71 | 72 | def forward(self, x): 73 | norm = torch.norm(x, dim=1, keepdim=True).expand_as(x) 74 | factor = torch.exp(self.norm_s) 75 | features = factor * x / (norm + self.epsilon) 76 | return features 77 | 78 | 79 | class Metric(TaskHead): 80 | 81 | def __init__(self, loss, initial_normalization_factor=4.0): 82 | super().__init__() 83 | self.loss = loss 84 | self.normalization_layer = NormalizationLayer( 85 | normalize_scale=initial_normalization_factor, learn_scale=True) 86 | 87 | def forward(self, composed): 88 | return self.normalization_layer(composed) 89 | 90 | def compute_loss(self, source, target, labels=None): 91 | source_emb = self.forward(source) 92 | target_emb = self.forward(target) 93 | 94 | assert source_emb.shape[1] == target_emb.shape[1] 95 | loss_value = self.loss(source_emb, target_emb, labels=labels) 96 | metrics = {"loss": loss_value.item()} 97 | if torch.isnan(loss_value): 98 | import IPython; IPython.embed() 99 | return loss_value, metrics 100 | 101 | 102 | def get_task_head(cfg): 103 | loss_obj, task = 
build_loss(cfg) 104 | 105 | if task == "metric": 106 | head = Metric(loss_obj, cfg.MODEL.INITIAL_NORMALIZATION_FACTOR) 107 | elif task == "regression": 108 | head = Regression(loss_obj, cfg.MODEL.EMBED_DIM) 109 | else: 110 | head = Classification(loss_obj, embed_dim=cfg.MODEL.EMBED_DIM, 111 | num_classes=cfg.DATASET.NUM_CLASSES) 112 | 113 | return head, task 114 | -------------------------------------------------------------------------------- /src/maaf/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Yahoo, Licensed under the terms of the Apache License, Version 2.0. 2 | # See LICENSE file in project root for terms. 3 | 4 | 5 | import os 6 | import torch 7 | import datetime 8 | import json 9 | from .datasets.datasets import load_dataset 10 | from tensorboardX import SummaryWriter 11 | import git # pip install gitpython 12 | from .config.arguments import parse_opt 13 | from .models.build import build_model, get_optimizer 14 | from .config import get_config 15 | from .train import Trainer, MetricTrainer 16 | 17 | # avoids a crash on some systems 18 | torch.set_num_threads(1) 19 | 20 | 21 | def setup(args, modify_exp_name=False): 22 | cfg = get_config() 23 | cfg.merge_from_file(args.config_file) 24 | cfg.merge_from_list(args.opts) 25 | if cfg.MODEL.DEVICE != "cpu" and not torch.cuda.is_available(): 26 | cfg.MODEL.DEVICE = "cpu" 27 | if modify_exp_name: 28 | curr_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") 29 | cfg.EXP_NAME = cfg.EXP_NAME + f"-{curr_time}" 30 | cfg.freeze() 31 | return cfg 32 | 33 | 34 | def main(old_args=False): 35 | if old_args: 36 | from .config.compat import compat_setup 37 | cfg, args = compat_setup() 38 | else: 39 | args = parse_opt() 40 | append_datetime = (args.resume is None) and args.timestamp 41 | cfg = setup(args, modify_exp_name=append_datetime) 42 | 43 | if args.resume is not None: 44 | logger = SummaryWriter(args.resume) 45 | else: 46 | logger = SummaryWriter(logdir=os.path.join(cfg.OUTPUT_DIR, cfg.EXP_NAME)) 47 | print('Log files saved to', logger.file_writer.get_logdir()) 48 | 49 | if args.save_config: 50 | with open(os.path.join(logger.file_writer.get_logdir(), "config.yaml"), 51 | "w") as fh: 52 | fh.write(cfg.dump()) 53 | 54 | # get and save the version of the code being run 55 | repo = git.Repo(search_parent_directories=True) 56 | sha = repo.head.object.hexsha 57 | logger.add_text("git_sha", sha) 58 | 59 | dataset_dict = load_dataset(cfg) 60 | if cfg.MODEL.TEXT_MODEL.TOKENIZER == "simple": 61 | texts = dataset_dict["train"].get_all_texts() 62 | else: 63 | texts = None 64 | model, task = build_model(cfg, texts, strict_loading=args.strict_loading) 65 | optimizer = get_optimizer(cfg, model) 66 | 67 | if args.resume is not None: 68 | print("loading from: %s" % args.resume) 69 | loaded_dict = torch.load( 70 | logger.file_writer.get_logdir() + "/latest_checkpoint.pth") 71 | model.load_state_dict(loaded_dict["model_state_dict"]) 72 | iteration = loaded_dict["it"] 73 | else: 74 | iteration = 0 75 | 76 | if task == "metric": 77 | trainer = MetricTrainer(cfg, logger, dataset_dict, model, optimizer, 78 | iteration) 79 | else: 80 | trainer = Trainer(cfg, logger, dataset_dict, model, optimizer, 81 | iteration) 82 | 83 | if args.debug: 84 | import IPython 85 | IPython.embed() 86 | 87 | if args.train: 88 | iteration = trainer.train() 89 | 90 | if args.eval: 91 | results = trainer.run_eval(eval_on_test=args.final_eval_on_test) 92 | results = {key: val for key, val in results} 93 | 94 | 
results_file = os.path.join( 95 | logger.file_writer.get_logdir(), 96 | f"{cfg.DATASET.NAME}-{iteration}-eval.json") 97 | with open(results_file, "w") as fh: 98 | json.dump(results, fh) 99 | print(f"Evaluation results saved to {results_file}") 100 | 101 | # if cfg.DATASET.NAME == "fashioniq": 102 | # print('Generating FashionIQ submission...') 103 | # eval_retrieval.predict(cfg, model, dataset_dict["test"], 104 | # filter_categories=True) 105 | # print('done') 106 | 107 | logger.close() 108 | -------------------------------------------------------------------------------- /src/maaf/config/defaults.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Yahoo, Licensed under the terms of the Apache License, Version 2.0. 2 | # See LICENSE file in project root for terms. 3 | 4 | from .config import CfgNode 5 | 6 | _C = CfgNode() 7 | 8 | _C.EXP_NAME = "debug" 9 | _C.OUTPUT_DIR = "./output" 10 | 11 | # ---------------------------------------------------------------------------- # 12 | # Model 13 | # ---------------------------------------------------------------------------- # 14 | _C.MODEL = CfgNode() 15 | _C.MODEL.DEVICE = "cuda" 16 | _C.MODEL.COMPOSITION = "maaf" 17 | _C.MODEL.WEIGHTS = None # load from this path 18 | _C.MODEL.DROPOUT_RATE = 0.1 19 | _C.MODEL.EMBED_DIM = 512 20 | _C.MODEL.LOSS = "batch_based_classification" 21 | _C.MODEL.INCLUDES_IMAGE_TRANSFORM = False 22 | _C.MODEL.INITIAL_NORMALIZATION_FACTOR = 4.0 23 | 24 | _C.MODEL.TEXT_MODEL = CfgNode() 25 | _C.MODEL.TEXT_MODEL.ARCHITECTURE = "lstm" 26 | _C.MODEL.TEXT_MODEL.TOKENIZER = "simple" 27 | _C.MODEL.TEXT_MODEL.VOCAB_DATA = None # path to text file to create vocab 28 | _C.MODEL.TEXT_MODEL.VOCAB_MIN_FREQ = 0 29 | _C.MODEL.TEXT_MODEL.MAX_VOCAB = 52000 30 | _C.MODEL.TEXT_MODEL.TOKENIZER_PATH = None 31 | _C.MODEL.TEXT_MODEL.NUM_LAYERS = 1 32 | _C.MODEL.TEXT_MODEL.FREEZE_WEIGHTS = False 33 | _C.MODEL.TEXT_MODEL.EMBED_DIM = 512 # may need to equal MODEL.EMBED_DIM 34 | _C.MODEL.TEXT_MODEL.OUTPUT_RELU = False 35 | _C.MODEL.TEXT_MODEL.MAX_TOKENS = 128 36 | _C.MODEL.TEXT_MODEL.MODEL_PATH = None 37 | 38 | _C.MODEL.IMAGE_MODEL = CfgNode() 39 | _C.MODEL.IMAGE_MODEL.ARCHITECTURE = "resnet50" 40 | _C.MODEL.IMAGE_MODEL.WEIGHTS = None 41 | _C.MODEL.IMAGE_MODEL.PRETRAINED = True 42 | _C.MODEL.IMAGE_MODEL.FREEZE_WEIGHTS = False 43 | _C.MODEL.IMAGE_MODEL.OUTPUTS = [3, 4] 44 | 45 | _C.MODEL.MAAF = CfgNode() 46 | _C.MODEL.MAAF.NUM_BLOCKS = 1 47 | _C.MODEL.MAAF.BLOCK_WIDTH = 256 48 | _C.MODEL.MAAF.ATTENTION_HEADS = 8 49 | _C.MODEL.MAAF.POSITION_ENCODING = None 50 | _C.MODEL.MAAF.OUTPUT = "simple_pool" # rwpool, token 51 | _C.MODEL.MAAF.ATTN_SOFTMAX_REPLACEMENT = None 52 | _C.MODEL.MAAF.RESIDUAL = CfgNode() 53 | _C.MODEL.MAAF.RESIDUAL.LEARN_WEIGHTS = False 54 | _C.MODEL.MAAF.RESIDUAL.INITIAL_MAAF_WEIGHT = 1. 
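# Interpretation inferred from the field names (the residual variants are selected with
# MODEL.COMPOSITION values such as "resmaaf" or "clipresmaaf"): INITIAL_MAAF_WEIGHT appears
# to set the starting weight of the MAAF branch relative to the residual path, and
# LEARN_WEIGHTS whether that weight is trained.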
55 | _C.MODEL.MAAF.RESIDUAL.INITIAL_MAAF_PRESIGMOID = None # deprecated 56 | 57 | _C.MODEL.CLIP = CfgNode() 58 | _C.MODEL.CLIP.PROMPT = "" 59 | _C.MODEL.CLIP.MISALIGNMENT = None 60 | 61 | # ---------------------------------------------------------------------------- # 62 | # Dataset 63 | # ---------------------------------------------------------------------------- # 64 | _C.DATASET = CfgNode() 65 | _C.DATASET.NAME = "fashioniq" 66 | _C.DATASET.PATH = '/home/default/Data/fashioniq' 67 | _C.DATASET.IMAGE_DIR = "" 68 | _C.DATASET.REQUIRE_IMAGES = False 69 | _C.DATASET.NUM_CLASSES = 3 70 | _C.DATASET.CLASS_WEIGHTS = [1, 1, 1] 71 | 72 | _C.DATASET.SINGLE_CLASS_BATCHES = False 73 | 74 | _C.DATASET.DPA_ATTRIBUTES = CfgNode() 75 | _C.DATASET.DPA_ATTRIBUTES.DELETE_TERMS = None 76 | _C.DATASET.DPA_ATTRIBUTES.USE_CATEGORY = False 77 | 78 | _C.DATASET.CROSS_MODAL = CfgNode() 79 | _C.DATASET.CROSS_MODAL.SOURCE = "title" 80 | 81 | _C.DATASET.AUGMENTATION = CfgNode() 82 | _C.DATASET.AUGMENTATION.IMAGE_AUGMENTATION = None 83 | 84 | _C.DATA_LOADER = CfgNode() 85 | _C.DATA_LOADER.LOADER_NUM_WORKERS = 4 86 | 87 | # ---------------------------------------------------------------------------- # 88 | # Solver 89 | # ---------------------------------------------------------------------------- # 90 | _C.SOLVER = CfgNode() 91 | _C.SOLVER.OPTIMIZER = "sgd" 92 | _C.SOLVER.BATCH_SIZE = 32 93 | _C.SOLVER.WEIGHT_DECAY = 1e-6 94 | _C.SOLVER.NUM_ITERS = 150000 95 | _C.SOLVER.DROP_WORST_RATE = 0 # 0.2 96 | _C.SOLVER.SOFTMAX_MARGIN = 0 97 | 98 | _C.SOLVER.LEARNING_RATE = 1e-2 99 | _C.SOLVER.LEARNING_RATE_DECAY = 0.1 100 | _C.SOLVER.LEARNING_RATE_DECAY_FREQUENCY = 9999999 101 | _C.SOLVER.LR_DECAY_ONLY_ONCE = False 102 | _C.SOLVER.SCHEDULE_RATES = [] 103 | _C.SOLVER.SCHEDULE_ITERS = [] 104 | _C.SOLVER.PRETRAINED_WEIGHT_LR_FACTOR_TEXT = 1. 105 | _C.SOLVER.PRETRAINED_WEIGHT_LR_FACTOR_IMAGE = 0.1 106 | # By default, projections from image/text model get learning rates including 107 | # the 2 factors above. Setting the parameter below to False changes this. 108 | _C.SOLVER.PROJECTION_LR_TIED_TO_PRETRAINED = True 109 | _C.SOLVER.MOMENTUM = 0.9 110 | 111 | _C.SOLVER.SAVE_EVERY = 100 112 | _C.SOLVER.EVAL_EVERY = 3 113 | _C.SOLVER.FINAL_EVAL_ON_TEST = False 114 | _C.SOLVER.ALWAYS_EVAL_TEST = False 115 | 116 | _C.SOLVER.BATCH_NORM_MODE = "ordinary" 117 | -------------------------------------------------------------------------------- /src/maaf/config/compat.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Yahoo, Licensed under the terms of the Apache License, Version 2.0. 2 | # See LICENSE file in project root for terms. 
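# This module maps the legacy argparse flags from arguments.old_parser() onto the CfgNode
# schema used by the rest of the code, so runs launched via old_train.py
# (main.main(old_args=True)) behave like config-file runs.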
3 | 4 | 5 | """Backwards compatibility""" 6 | from .config import CfgNode 7 | from .arguments import old_parse_opt 8 | 9 | 10 | MAAF_ALIASES = ["sequence_concat_attention", "seqcat_outtoken", 11 | "concat_attention", "maaf"] 12 | 13 | def compat_setup(): 14 | args = old_parse_opt() 15 | cfg = config_from_args(args).clone() 16 | cfg.freeze() 17 | args.resume = args.load 18 | return cfg, args 19 | 20 | 21 | def config_from_args(args): 22 | _C = CfgNode() 23 | _C.EXP_NAME = args.exp_name 24 | _C.OUTPUT_DIR = args.savedir 25 | 26 | # ----------------------------------------------------------------------- # 27 | # Model 28 | # ----------------------------------------------------------------------- # 29 | _C.MODEL = CfgNode() 30 | _C.MODEL.DEVICE = args.device 31 | _C.MODEL.COMPOSITION = args.model 32 | _C.MODEL.WEIGHTS = args.load 33 | _C.MODEL.DROPOUT_RATE = args.dropout_rate 34 | _C.MODEL.EMBED_DIM = args.embed_dim 35 | _C.MODEL.LOSS = args.loss 36 | 37 | _C.MODEL.TEXT_MODEL = CfgNode() 38 | _C.MODEL.TEXT_MODEL.ARCHITECTURE = \ 39 | None if args.image_only else args.text_model_arch 40 | _C.MODEL.TEXT_MODEL.TOKENIZER = args.text_tokenizer 41 | _C.MODEL.TEXT_MODEL.VOCAB_DATA = args.text_data 42 | _C.MODEL.TEXT_MODEL.VOCAB_MIN_FREQ = args.threshold_rare_words 43 | _C.MODEL.TEXT_MODEL.MAX_VOCAB = args.max_vocab 44 | _C.MODEL.TEXT_MODEL.TOKENIZER_PATH = args.tokenizer_path 45 | _C.MODEL.TEXT_MODEL.NUM_LAYERS = args.text_model_layers 46 | _C.MODEL.TEXT_MODEL.FREEZE_WEIGHTS = args.freeze_text_model 47 | _C.MODEL.TEXT_MODEL.EMBED_DIM = args.embed_dim 48 | _C.MODEL.TEXT_MODEL.OUTPUT_RELU = False 49 | 50 | _C.MODEL.IMAGE_MODEL = CfgNode() 51 | _C.MODEL.IMAGE_MODEL.ARCHITECTURE = \ 52 | None if args.text_only else args.image_model_arch 53 | _C.MODEL.IMAGE_MODEL.WEIGHTS = args.image_model_path 54 | _C.MODEL.IMAGE_MODEL.PRETRAINED = not args.not_pretrained 55 | _C.MODEL.IMAGE_MODEL.FREEZE_WEIGHTS = args.freeze_img_model 56 | _C.MODEL.IMAGE_MODEL.OUTPUTS = \ 57 | [ii for ii in range(5) if str(ii) in args.att_layer_spec] 58 | 59 | _C.MODEL.MAAF = CfgNode() 60 | _C.MODEL.MAAF.NUM_BLOCKS = args.number_attention_blocks 61 | _C.MODEL.MAAF.BLOCK_WIDTH = args.width_per_attention_block 62 | _C.MODEL.MAAF.ATTENTION_HEADS = args.number_attention_heads 63 | _C.MODEL.MAAF.POSITION_ENCODING = args.attn_positional_encoding 64 | maaf_out = "simple_pool" 65 | if args.resolutionwise_pool: 66 | maaf_out = "rwpool" 67 | if args.model == "seqcat_outtoken": 68 | maaf_out = "token" 69 | _C.MODEL.MAAF.OUTPUT = maaf_out 70 | _C.MODEL.MAAF.ATTN_SOFTMAX_REPLACEMENT = args.attn_softmax_replacement 71 | 72 | # ----------------------------------------------------------------------- # 73 | # Dataset 74 | # ----------------------------------------------------------------------- # 75 | _C.DATASET = CfgNode() 76 | _C.DATASET.NAME = args.dataset 77 | _C.DATASET.PATH = args.dataset_path 78 | _C.DATASET.IMAGE_DIR = args.image_dir 79 | _C.DATASET.REQUIRE_IMAGES = args.require_images 80 | _C.DATASET.NUM_CLASSES = args.num_classes 81 | _C.DATASET.CLASS_WEIGHTS = args.class_weights 82 | 83 | _C.DATASET.SINGLE_CLASS_BATCHES = args.dataset == "fashioniq" 84 | 85 | _C.DATASET.DPA_ATTRIBUTES = CfgNode() 86 | _C.DATASET.DPA_ATTRIBUTES.DELETE_TERMS = None 87 | _C.DATASET.DPA_ATTRIBUTES.USE_CATEGORY = False 88 | 89 | _C.DATA_LOADER = CfgNode() 90 | _C.DATA_LOADER.LOADER_NUM_WORKERS = args.loader_num_workers 91 | 92 | # ----------------------------------------------------------------------- # 93 | # Solver 94 | # 
----------------------------------------------------------------------- # 95 | _C.SOLVER = CfgNode() 96 | _C.SOLVER.BATCH_SIZE = args.batch_size 97 | _C.SOLVER.WEIGHT_DECAY = args.weight_decay 98 | _C.SOLVER.NUM_ITERS = args.num_iters 99 | _C.SOLVER.DROP_WORST_RATE = \ 100 | args.drop_worst_rate if args.drop_worst_flag else 0 101 | _C.SOLVER.SOFTMAX_MARGIN = args.softmax_margin 102 | 103 | _C.SOLVER.LEARNING_RATE = args.learning_rate 104 | _C.SOLVER.LEARNING_RATE_DECAY = args.learning_rate_decay 105 | _C.SOLVER.LEARNING_RATE_DECAY_FREQUENCY = args.learning_rate_decay_frequency 106 | _C.SOLVER.LR_DECAY_ONLY_ONCE = args.lr_decay_only_once 107 | _C.SOLVER.SCHEDULE_RATES = args.scheduled_lr_rates 108 | _C.SOLVER.SCHEDULE_ITERS = args.scheduled_lr_iters 109 | _C.SOLVER.PRETRAINED_WEIGHT_LR_FACTOR_TEXT = args.pretrained_weight_lr_factor_text 110 | _C.SOLVER.PRETRAINED_WEIGHT_LR_FACTOR_IMAGE = args.pretrained_weight_lr_factor_image 111 | 112 | _C.SOLVER.SAVE_EVERY = args.save_every 113 | _C.SOLVER.EVAL_EVERY = args.eval_every 114 | 115 | return _C 116 | -------------------------------------------------------------------------------- /src/maaf/datasets/fashiongen.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Yahoo, Licensed under the terms of the Apache License, Version 2.0. 2 | # See LICENSE file in project root for terms. 3 | 4 | from torch.utils.data import Dataset, DataLoader 5 | from PIL import Image 6 | import os 7 | import h5py 8 | from ..utils.misc_utils import tqdm 9 | import torch 10 | import numpy as np 11 | from sklearn.metrics import average_precision_score 12 | 13 | DATASET_CLASS_NAME = "FashionGen" 14 | 15 | 16 | class FashionGen(Dataset): 17 | 18 | def __init__(self, 19 | path="/home/default/ephemeral_drive/Data/fashiongen/", 20 | split="train", 21 | transform=None, 22 | default_image_size=224): 23 | super().__init__() 24 | 25 | self.data_path = path 26 | self.default_image_size = default_image_size 27 | 28 | self.transform = transform 29 | 30 | if split == "val": 31 | split = "validation" 32 | if split == "test": 33 | self.data = {"input_description": []} 34 | return 35 | self.data = h5py.File(os.path.join( 36 | self.data_path, f"fashiongen_256_256_{split}.h5"), mode="r") 37 | self.split = split 38 | 39 | def __len__(self): 40 | return len(self.data["input_description"]) 41 | 42 | def __getitem__(self, idx): 43 | item = {} 44 | item["image"] = Image.fromarray(self.data["input_image"][idx]) 45 | if self.transform is not None: 46 | item["image"] = self.transform(item["image"]) 47 | item["text"] = self.data["input_description"][idx][0].decode("latin-1") 48 | 49 | item["target_text"] = item["text"] 50 | item["target_image"] = None 51 | item["source_text"] = None 52 | item["source_image"] = item["image"] 53 | 54 | return item 55 | 56 | def get_all_texts(self): 57 | return [cap[0].decode("utf-8") for cap in self.data["input_description"]] 58 | 59 | def get_loader(self, 60 | batch_size, 61 | shuffle=False, 62 | drop_last=False, 63 | num_workers=0, 64 | category=None): 65 | return DataLoader( 66 | self, 67 | batch_size=batch_size, 68 | shuffle=shuffle, 69 | num_workers=num_workers, 70 | drop_last=drop_last, 71 | collate_fn=lambda i: i) 72 | 73 | def evaluate(self, model, cfg=None): 74 | model.eval() 75 | 76 | text_embed = {} 77 | image_embed = [] 78 | product_ids = [] 79 | for idx in tqdm(range(len(self))): 80 | img = Image.fromarray(self.data["input_image"][idx]) 81 | if self.transform is not None: 82 | img = self.transform(img) 
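# Each image is embedded individually and L2-normalized; text embeddings are computed and
# normalized once per product ID, so duplicate descriptions are not re-encoded.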
83 | img = torch.stack([img]).float().to(model.device) 84 | text = self.data["input_description"][idx][0].decode("latin-1") 85 | pid = self.data["input_productID"][idx][0] 86 | img_emb = model(img, [None]).cpu().numpy() 87 | image_embed.append(img_emb / np.linalg.norm(img_emb)) 88 | product_ids.append(pid) 89 | if pid not in text_embed: 90 | text_emb = model([None], [text]).cpu().numpy() 91 | text_embed[pid] = text_emb / np.linalg.norm(text_emb) 92 | 93 | text_emb_array = np.concatenate(list(text_embed.values())) 94 | image_emb_array = np.concatenate(image_embed) 95 | text_pids = list(text_embed.keys()) 96 | 97 | sims = image_emb_array @ text_emb_array.T 98 | 99 | correct = np.zeros(sims.shape, dtype=bool) 100 | for ii, pid in enumerate(product_ids): 101 | for jj, text_pid in enumerate(text_pids): 102 | correct[ii, jj] = pid == text_pid 103 | 104 | img_ap_scores = [] 105 | for ii in range(len(correct)): 106 | img_ap_scores.append(average_precision_score(correct[ii], sims[ii])) 107 | img_map = np.mean(img_ap_scores) 108 | txt_ap_scores = [] 109 | for jj in range(correct.shape[1]): 110 | txt_ap_scores.append(average_precision_score(correct[:, jj], sims[:, jj])) 111 | txt_map = np.mean(txt_ap_scores) 112 | maps = [("image_to_text_mAP", img_map), ("text_to_image_mAP", txt_map)] 113 | 114 | sorter_per_img = np.argsort(sims, axis=1)[:, ::-1] 115 | sorter_per_text = np.argsort(sims, axis=0)[::-1] 116 | 117 | correct_sorted_per_img = np.take_along_axis( 118 | correct, sorter_per_img, axis=1) 119 | correct_sorted_per_text = np.take_along_axis( 120 | correct, sorter_per_text, axis=0) 121 | 122 | indic_per_img = np.cumsum(correct_sorted_per_img, axis=1) > 0 123 | indic_per_txt = np.cumsum(correct_sorted_per_text, axis=0) > 0 124 | ks = [1, 5, 10] 125 | recalls_img = [(f"image_to_text_recall{kk}", 126 | np.mean(indic_per_img[:, kk])) for kk in ks] 127 | recalls_txt = [(f"text_to_image_recall{kk}", 128 | np.mean(indic_per_txt[kk])) for kk in ks] 129 | 130 | recall_sum = sum([thing[1] for thing in recalls_img]) + \ 131 | sum([thing[1] for thing in recalls_txt]) 132 | recall_sum = [("recall_sum", recall_sum)] 133 | 134 | return maps + recalls_img + recalls_txt + recall_sum 135 | -------------------------------------------------------------------------------- /src/maaf/models/build.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Yahoo, Licensed under the terms of the Apache License, Version 2.0. 2 | # See LICENSE file in project root for terms. 3 | """Functions to put everything together and build the model.""" 4 | 5 | import sys 6 | import os 7 | import torch 8 | from .
import composition_models 9 | from .heads import get_task_head 10 | from .image_model import build_image_model 11 | from .text_model import build_text_model 12 | from ..config.compat import MAAF_ALIASES 13 | from ..config import get_config 14 | from ..datasets.datasets import load_dataset 15 | 16 | 17 | def load_model(path, modifications=[], strict=True): 18 | config_file = os.path.join(path, "config.yaml") 19 | model, task, cfg = build_from_config_file(config_file, modifications, 20 | strict_loading=strict) 21 | checkpoint = os.path.join(path, "latest_checkpoint.pth") 22 | state_dict = torch.load(checkpoint, map_location=model.device)["model_state_dict"] 23 | model.load_state_dict(state_dict, strict=strict) 24 | return model 25 | 26 | 27 | def build_from_config_file(path, modifications=[], strict_loading=True): 28 | cfg = get_config() 29 | cfg.merge_from_file(path) 30 | cfg.merge_from_list(modifications) 31 | if not torch.cuda.is_available(): 32 | cfg.MODEL.DEVICE = "cpu" 33 | cfg.freeze() 34 | model, task = build_model(cfg, strict_loading=strict_loading) 35 | return model, task, cfg 36 | 37 | 38 | def build_model(cfg, texts=None, strict_loading=True): 39 | print('Building model', cfg.MODEL.COMPOSITION) 40 | 41 | tokenizer_needs_texts = cfg.MODEL.TEXT_MODEL.TOKENIZER == "simple" and \ 42 | cfg.MODEL.TEXT_MODEL.ARCHITECTURE is not None 43 | if texts is None and tokenizer_needs_texts: 44 | texts = load_dataset(cfg)["train"].get_all_texts() 45 | 46 | if "clip" in cfg.MODEL.COMPOSITION: 47 | from .clip import get_clip_class 48 | ModelClass, kwargs = get_clip_class(cfg) 49 | else: 50 | kwargs = {"image_model": build_image_model(cfg), 51 | "text_model": build_text_model(texts, cfg)} 52 | if cfg.MODEL.COMPOSITION == 'imgonly': 53 | ModelClass = composition_models.SimpleModelImageOnly 54 | elif cfg.MODEL.COMPOSITION == 'textonly': 55 | ModelClass = composition_models.SimpleModelTextOnly 56 | elif cfg.MODEL.COMPOSITION in MAAF_ALIASES \ 57 | or cfg.MODEL.COMPOSITION in ["residualMAAF", "resmaaf"]: 58 | if cfg.MODEL.COMPOSITION in ["residualMAAF", "resmaaf"]: 59 | print("Setting up Residual MAAF") 60 | ModelClass = composition_models.ResidualMAAF 61 | else: 62 | print("Setting up MAAF") 63 | ModelClass = composition_models.MAAF 64 | kwargs.update({ 65 | "model_dim": cfg.MODEL.EMBED_DIM, 66 | "num_heads": cfg.MODEL.MAAF.ATTENTION_HEADS, 67 | "ff_width": cfg.MODEL.MAAF.BLOCK_WIDTH, 68 | "dropout": cfg.MODEL.DROPOUT_RATE, 69 | "num_blocks": cfg.MODEL.MAAF.NUM_BLOCKS, 70 | "position_encodings": cfg.MODEL.MAAF.POSITION_ENCODING, 71 | "softmax_replacement": cfg.MODEL.MAAF.ATTN_SOFTMAX_REPLACEMENT, 72 | "output": cfg.MODEL.MAAF.OUTPUT, 73 | }) 74 | elif cfg.MODEL.COMPOSITION == 'concat': 75 | ModelClass = composition_models.Concat 76 | kwargs.update({"embed_dim": cfg.MODEL.EMBED_DIM, 77 | "dropout": cfg.MODEL.DROPOUT_RATE}) 78 | elif cfg.MODEL.COMPOSITION == 'add': 79 | ModelClass = composition_models.Addition 80 | elif "clip" in cfg.MODEL.COMPOSITION: 81 | pass 82 | elif cfg.MODEL.COMPOSITION == "random": 83 | ModelClass = composition_models.RandomComposition 84 | else: 85 | print('Invalid model', cfg.MODEL.COMPOSITION) 86 | sys.exit() 87 | 88 | head, task = get_task_head(cfg) 89 | 90 | model = ModelClass(head, **kwargs) 91 | 92 | device = torch.device(cfg.MODEL.DEVICE) 93 | model.to(device) 94 | 95 | if cfg.MODEL.WEIGHTS is not None: 96 | print("Loading model weights from", cfg.MODEL.WEIGHTS) 97 | loaded_dict = torch.load(cfg.MODEL.WEIGHTS, map_location=model.device) 98 | 
model.load_state_dict(loaded_dict["model_state_dict"], 99 | strict=strict_loading) 100 | 101 | return model, task 102 | 103 | 104 | def get_optimizer(cfg, model): 105 | # create optimizer 106 | param_dicts = [] 107 | gathered_params = set() 108 | # apply learning rate adjustments for model components 109 | image_fc = [p for p in model.image_model_fc_parameters()] 110 | gathered_params.update(image_fc) 111 | param_dicts.append({ 112 | 'params': image_fc, 113 | 'lr': cfg.SOLVER.LEARNING_RATE 114 | }) 115 | image_params = model.image_model_parameters( 116 | include_scratch=cfg.SOLVER.PROJECTION_LR_TIED_TO_PRETRAINED) 117 | other_img = [p for p in image_params if p not in gathered_params] 118 | gathered_params.update(other_img) 119 | param_dicts.append({ 120 | 'params': other_img, 121 | 'lr': cfg.SOLVER.PRETRAINED_WEIGHT_LR_FACTOR_IMAGE * cfg.SOLVER.LEARNING_RATE 122 | }) 123 | 124 | text_params = model.text_model_parameters( 125 | include_scratch=cfg.SOLVER.PROJECTION_LR_TIED_TO_PRETRAINED) 126 | text_params = [p for p in text_params] 127 | gathered_params.update(text_params) 128 | param_dicts.append({ 129 | 'params': text_params, 130 | 'lr': cfg.SOLVER.PRETRAINED_WEIGHT_LR_FACTOR_TEXT * cfg.SOLVER.LEARNING_RATE 131 | }) 132 | param_dicts.append( 133 | {'params': [p for p in model.parameters() if p not in gathered_params]}) 134 | 135 | if cfg.SOLVER.OPTIMIZER == "adam": 136 | optimizer = torch.optim.Adam( 137 | param_dicts, 138 | lr=cfg.SOLVER.LEARNING_RATE, 139 | eps=1e-4, 140 | weight_decay=cfg.SOLVER.WEIGHT_DECAY) 141 | else: 142 | optimizer = torch.optim.SGD( 143 | param_dicts, lr=cfg.SOLVER.LEARNING_RATE, momentum=cfg.SOLVER.MOMENTUM, 144 | weight_decay=cfg.SOLVER.WEIGHT_DECAY) 145 | 146 | return optimizer 147 | -------------------------------------------------------------------------------- /src/maaf/config/arguments.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Yahoo, Licensed under the terms of the Apache License, Version 2.0. 2 | # See LICENSE file in project root for terms. 3 | 4 | 5 | import argparse 6 | 7 | 8 | def get_parser(): 9 | parser = argparse.ArgumentParser() 10 | add_arg = parser.add_argument 11 | 12 | add_arg('--config_file', type=str, default="") 13 | add_arg('--no-train', action="store_false", dest="train") 14 | add_arg('--no-eval', action="store_false", dest="eval", 15 | help="skip the eval following training loop") 16 | add_arg('--no-config-save', action="store_false", dest="save_config") 17 | add_arg('--debug', action="store_true") 18 | add_arg('--resume', type=str, default=None) 19 | add_arg('--final_eval_on_test', action="store_true") 20 | add_arg('--non-strict_loading', action="store_false", dest="strict_loading") 21 | add_arg('--no-timestamp', action="store_false", dest="timestamp") 22 | 23 | add_arg( 24 | "opts", 25 | help="Modify config options by adding 'KEY VALUE' pairs at the end of the command. 
", 26 | default=None, 27 | nargs=argparse.REMAINDER, 28 | ) 29 | 30 | return parser 31 | 32 | 33 | def parse_opt(): 34 | args = get_parser().parse_args() 35 | return args 36 | 37 | 38 | def old_parser(): 39 | """Parses the input arguments.""" 40 | parser = argparse.ArgumentParser() 41 | add_arg = parser.add_argument 42 | add_arg('-f', type=str, default='') 43 | add_arg('--exp_name', type=str, default='debug') 44 | add_arg('--comment', type=str, default='') 45 | add_arg('--savedir', type=str, default='') 46 | 47 | add_arg('--device', type=str, default='cuda') 48 | 49 | add_arg('--dataset', type=str, default='dpa_gender') 50 | add_arg('--dataset_path', type=str, 51 | default='') 52 | add_arg('--image_dir', type=str, 53 | default="") 54 | add_arg('--require_images', action='store_true', 55 | help="for datasets where some examples have images and some do not") 56 | add_arg('--text_only', action='store_true') 57 | add_arg('--image_only', action='store_true') 58 | 59 | add_arg('--text_tokenizer', type=str, default="simple") 60 | add_arg('--text_data', type=str, default=None, 61 | help="if building vocab/tokenizer from data, a file with all text") 62 | add_arg('--max_vocab', type=int, default=52000) 63 | add_arg('--tokenizer_path', type=str, default="") 64 | 65 | add_arg('--class_weights', nargs='+', type=float, default=[23, 2.5, 6.4], 66 | help="default is inverse frequency for unisex, female, male in apparel dataset") 67 | 68 | add_arg('--model', type=str, default='maaf') 69 | add_arg('--embed_dim', type=int, default=512) 70 | add_arg('--batch_size', type=int, default=32) 71 | add_arg('--weight_decay', type=float, default=1e-6) 72 | add_arg('--num_iters', type=int, default=150000) 73 | add_arg('--loss', type=str, default='batch_based_classification') 74 | add_arg('--num_classes', type=int, default=3) 75 | 76 | add_arg('--learning_rate', type=float, default=1e-2) 77 | add_arg('--learning_rate_decay', type=float, default=0.1) 78 | add_arg('--learning_rate_decay_frequency', type=int, default=9999999) 79 | add_arg('--lr_decay_only_once', action="store_true") 80 | # more flexible learning rate scheduling. 81 | # both args below must be set or we default to old scheme 82 | add_arg('--scheduled_lr_rates', type=str, default="", 83 | help="Separate rates by commas." + 84 | "The learning_rate argument sets the initial rate; " + 85 | "this param sets rates after each scheduled_lr_iters entry" + 86 | "If empty string, old regular decay schedule is used.") 87 | add_arg('--scheduled_lr_iters', type=str, default="", 88 | help="Separate iteration numbers by commas." + 89 | "If empty string, old regular decay schedule is used.") 90 | add_arg('--pretrained_weight_lr_factor_image', type=float, default=0.1) 91 | add_arg('--pretrained_weight_lr_factor_text', type=float, default=1.) 
92 | 93 | add_arg('--loader_num_workers', type=int, default=4) 94 | 95 | add_arg('-t', "--test_only", action="store_true") 96 | add_arg('-l', '--load', type=str, default="") 97 | 98 | add_arg('--dropout_rate', type=float, default=0.1) 99 | 100 | add_arg('--drop_worst_flag', action='store_true', 101 | help='If added the model will ingore the highest --drop_worst_rate losses') 102 | add_arg('--drop_worst_rate', type=float, default=0.2) 103 | 104 | add_arg('--image_model_arch', type=str, default='resnet50') 105 | add_arg('--image_model_path', type=str, default='') 106 | add_arg('--not_pretrained', action='store_true', 107 | help='If added, the image network will be trained WITHOUT ImageNet-pretrained weights.') 108 | add_arg('--freeze_img_model', action='store_true', 109 | help='If added the loaded image model weights will not be finetuned') 110 | 111 | add_arg('--text_model_arch', type=str, default='lstm') 112 | add_arg('--text_model_layers', type=int, default=1) 113 | add_arg('--normalize_text', action='store_true') 114 | add_arg('--threshold_rare_words', type=int, default=0) 115 | add_arg('--freeze_text_model', action='store_true', 116 | help='If added the loaded text model weights will not be finetuned') 117 | 118 | add_arg('--number_attention_blocks', type=int, default=1) 119 | add_arg('--width_per_attention_block', type=int, default=256) 120 | add_arg('--number_attention_heads', type=int, default=8) 121 | add_arg('--attn_positional_encoding', default=None) 122 | add_arg('--resolutionwise_pool', action='store_true') 123 | add_arg('--attn_softmax_replacement', type=str, default="none") 124 | add_arg('--att_layer_spec', type=str, default="3_4") 125 | 126 | add_arg('--softmax_margin', type=float, default=0) 127 | 128 | add_arg('--save_every', type=int, default=100, 129 | help="keep checkpoints this often in epochs") 130 | add_arg('--eval_every', type=int, default=3, 131 | help="run eval on val set this often in epochs") 132 | add_arg('--final_eval_on_test', action="store_true") 133 | 134 | add_arg('--inspect_after', action="store_true") 135 | 136 | return parser 137 | 138 | 139 | def old_parse_opt(): 140 | args, unknown = old_parser().parse_known_args() 141 | if args.load == "": 142 | args.load = None 143 | if args.image_model_path in ["", "none", "None"]: 144 | args.image_model_path = None 145 | if args.image_model_arch in ["", "none", "None"]: 146 | args.image_model_arch = None 147 | args.eval_only = args.test_only 148 | if args.attn_softmax_replacement == "none": 149 | args.attn_softmax_replacement = None 150 | args.debug = args.inspect_after 151 | 152 | for unk in unknown: 153 | print(f"WARNING: unrecognized argument: {unk}") 154 | 155 | return args 156 | -------------------------------------------------------------------------------- /Code-of-Conduct.md: -------------------------------------------------------------------------------- 1 | # Yahoo Open Source Code of Conduct 2 | 3 | ## Summary 4 | This Code of Conduct is our way to encourage good behavior and discourage bad behavior in our open source community. We invite participation from many people to bring different perspectives to support this project. We pledge to do our part to foster a welcoming and professional environment free of harassment. We expect participants to communicate professionally and thoughtfully during their involvement with this project. 5 | 6 | Participants may lose their good standing by engaging in misconduct. For example: insulting, threatening, or conveying unwelcome sexual content. 
We ask participants who observe conduct issues to report the incident directly to the project's Response Team at opensource-conduct@yahooinc.com. Yahoo will assign a respondent to address the issue. We may remove harassers from this project. 7 | 8 | This code does not replace the terms of service or acceptable use policies of the websites used to support this project. We acknowledge that participants may be subject to additional conduct terms based on their employment which may govern their online expressions. 9 | 10 | ## Details 11 | This Code of Conduct makes our expectations of participants in this community explicit. 12 | * We forbid harassment and abusive speech within this community. 13 | * We request participants to report misconduct to the project’s Response Team. 14 | * We urge participants to refrain from using discussion forums to play out a fight. 15 | 16 | ### Expected Behaviors 17 | We expect participants in this community to conduct themselves professionally. Since our primary mode of communication is text on an online forum (e.g. issues, pull requests, comments, emails, or chats) devoid of vocal tone, gestures, or other context that is often vital to understanding, it is important that participants are attentive to their interaction style. 18 | 19 | * **Assume positive intent.** We ask community members to assume positive intent on the part of other people’s communications. We may disagree on details, but we expect all suggestions to be supportive of the community goals. 20 | * **Respect participants.** We expect participants will occasionally disagree. Even if we reject an idea, we welcome everyone’s participation. Open Source projects are learning experiences. Ask, explore, challenge, and then respectfully assert if you agree or disagree. If your idea is rejected, be more persuasive not bitter. 21 | * **Welcoming to new members.** New members bring new perspectives. Some may raise questions that have been addressed before. Kindly point them to existing discussions. Everyone is new to every project once. 22 | * **Be kind to beginners.** Beginners use open source projects to get experience. They might not be talented coders yet, and projects should not accept poor quality code. But we were all beginners once, and we need to engage kindly. 23 | * **Consider your impact on others.** Your work will be used by others, and you depend on the work of others. We expect community members to be considerate and establish a balance their self-interest with communal interest. 24 | * **Use words carefully.** We may not understand intent when you say something ironic. Poe’s Law suggests that without an emoticon people will misinterpret sarcasm. We ask community members to communicate plainly. 25 | * **Leave with class.** When you wish to resign from participating in this project for any reason, you are free to fork the code and create a competitive project. Open Source explicitly allows this. Your exit should not be dramatic or bitter. 26 | 27 | ### Unacceptable Behaviors 28 | Participants remain in good standing when they do not engage in misconduct or harassment. To elaborate: 29 | * **Don't be a bigot.** Calling out project members by their identity or background in a negative or insulting manner. This includes, but is not limited to, slurs or insinuations related to protected or suspect classes e.g. 
race, color, citizenship, national origin, political belief, religion, sexual orientation, gender identity and expression, age, size, culture, ethnicity, genetic features, language, profession, national minority status, mental or physical ability. 30 | * **Don't insult.** Insulting remarks about a person’s lifestyle practices. 31 | * **Don't dox.** Revealing private information about other participants without explicit permission. 32 | * **Don't intimidate.** Threats of violence or intimidation of any project member. 33 | * **Don't creep.** Unwanted sexual attention or content unsuited for the subject of this project. 34 | * **Don't disrupt.** Sustained disruptions in a discussion. 35 | * **Let us help.** Refusal to assist the Response Team to resolve an issue in the community. 36 | 37 | We do not list all forms of harassment, nor imply some forms of harassment are not worthy of action. Any participant who *feels* harassed or *observes* harassment, should report the incident. Victim of harassment should not address grievances in the public forum, as this often intensifies the problem. Report it, and let us address it off-line. 38 | 39 | ### Reporting Issues 40 | If you experience or witness misconduct, or have any other concerns about the conduct of members of this project, please report it by contacting our Response Team at opensource-conduct@yahooinc.com who will handle your report with discretion. Your report should include: 41 | * Your preferred contact information. We cannot process anonymous reports. 42 | * Names (real or usernames) of those involved in the incident. 43 | * Your account of what occurred, and if the incident is ongoing. Please provide links to or transcripts of the publicly available records (e.g. a mailing list archive or a public IRC logger), so that we can review it. 44 | * Any additional information that may be helpful to achieve resolution. 45 | 46 | After filing a report, a representative will contact you directly to review the incident and ask additional questions. If a member of the Yahoo Response Team is named in an incident report, that member will be recused from handling your incident. If the complaint originates from a member of the Response Team, it will be addressed by a different member of the Response Team. We will consider reports to be confidential for the purpose of protecting victims of abuse. 47 | 48 | ### Scope 49 | Yahoo will assign a Response Team member with admin rights on the project and legal rights on the project copyright. The Response Team is empowered to restrict some privileges to the project as needed. Since this project is governed by an open source license, any participant may fork the code under the terms of the project license. The Response Team’s goal is to preserve the project if possible, and will restrict or remove participation from those who disrupt the project. 50 | 51 | This code does not replace the terms of service or acceptable use policies that are provided by the websites used to support this community. Nor does this code apply to communications or actions that take place outside of the context of this community. Many participants in this project are also subject to codes of conduct based on their employment. This code is a social-contract that informs participants of our social expectations. It is not a terms of service or legal contract. 52 | 53 | ## License and Acknowledgment. 54 | This text is shared under the [CC-BY-4.0 license](https://creativecommons.org/licenses/by/4.0/). 
This code is based on a study conducted by the [TODO Group](https://todogroup.org/) of many codes used in the open source community. If you have feedback about this code, contact our Response Team at the address listed above. 55 | -------------------------------------------------------------------------------- /src/maaf/datasets/mitstates.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Yahoo, Licensed under the terms of the Apache License, Version 2.0. 2 | # See LICENSE file in project root for terms. 3 | 4 | from torch.utils.data import Dataset, DataLoader 5 | import random 6 | import PIL 7 | from os import listdir 8 | 9 | DATASET_CLASS_NAME = "MITStates" 10 | 11 | TEST_NOUNS = [ 12 | u'armor', u'bracelet', u'bush', u'camera', u'candy', u'castle', 13 | u'ceramic', u'cheese', u'clock', u'clothes', u'coffee', u'fan', u'fig', 14 | u'fish', u'foam', u'forest', u'fruit', u'furniture', u'garden', u'gate', 15 | u'glass', u'horse', u'island', u'laptop', u'lead', u'lightning', 16 | u'mirror', u'orange', u'paint', u'persimmon', u'plastic', u'plate', 17 | u'potato', u'road', u'rubber', u'sand', u'shell', u'sky', u'smoke', 18 | u'steel', u'stream', u'table', u'tea', u'tomato', u'vacuum', u'wax', 19 | u'wheel', u'window', u'wool' 20 | ] 21 | 22 | 23 | class MITStatesGallery(Dataset): 24 | 25 | def __init__(self, gallery, transform=None): 26 | super().__init__() 27 | self.gallery = gallery 28 | self.transform = transform 29 | 30 | def __getitem__(self, ind): 31 | example = self.gallery[ind] 32 | image = self.get_img(ind) 33 | example["target_image"] = image 34 | example["target_text"] = None 35 | 36 | return example 37 | 38 | def __len__(self): 39 | return len(self.gallery) 40 | 41 | def get_img(self, idx, raw_img=False): 42 | img_path = self.gallery[idx]['file_path'] 43 | with open(img_path, 'rb') as f: 44 | img = PIL.Image.open(f) 45 | img = img.convert('RGB') 46 | if raw_img: 47 | return img 48 | if self.transform: 49 | img = self.transform(img) 50 | return img 51 | 52 | 53 | class MITStates(Dataset): 54 | 55 | def __init__(self, path, split="train", transform=None): 56 | super().__init__() 57 | 58 | self.path = path 59 | self.transform = transform 60 | self.split = split 61 | 62 | self.data = [] 63 | 64 | for f in listdir(path + '/images'): 65 | if ' ' not in f: 66 | continue 67 | adj, noun = f.split() 68 | if adj == 'adj': 69 | continue 70 | if split == 'train' and noun in TEST_NOUNS: 71 | continue 72 | if split == 'test' and noun not in TEST_NOUNS: 73 | continue 74 | 75 | for file_path in listdir(path + '/images/' + f): 76 | assert (file_path.endswith('jpg')) 77 | this_index = len(self.data) 78 | self.data += [{ 79 | 'file_path': path + '/images/' + f + '/' + file_path, 80 | 'captions': [f], 81 | 'adj': adj, 82 | 'noun': noun, 83 | "image_id": this_index 84 | }] 85 | 86 | self.gallery = MITStatesGallery(self.data, self.transform) 87 | self.caption_index_init_() 88 | if split == 'test': 89 | self.generate_test_queries_() 90 | else: 91 | self.test_queries = None 92 | 93 | self.saved_item = None 94 | 95 | def get_all_texts(self): 96 | texts = [] 97 | for img in self.data: 98 | texts += img['captions'] 99 | return texts 100 | 101 | def __getitem__(self, idx): 102 | if self.split == "test": 103 | query = self.test_queries[idx] 104 | return { 105 | 'source_image': self.get_img(query["source_img_id"]), 106 | 'source_text': query["mod"]["str"] 107 | } 108 | else: 109 | idx, target_idx = self.get_random_pair(idx) 110 | mod_str = 
self.data[target_idx]['adj'] 111 | 112 | return { 113 | 'source_img_id': idx, 114 | 'source_image': self.get_img(idx), 115 | 'source_cap': self.data[idx]['captions'][0], 116 | 'target_img_id': target_idx, 117 | 'target_image': self.get_img(target_idx), 118 | 'target_cap': self.data[target_idx]['captions'][0], 119 | 'source_text': mod_str, 120 | "target_text": None 121 | } 122 | 123 | def get_random_pair(self, idx): 124 | """ 125 | Eric doesn't know why this pairing thing was in the TIRG code 126 | but we're keeping it for consistency. 127 | """ 128 | if self.saved_item is None: 129 | while True: 130 | idx, target_idx1 = self.caption_index_sample_(idx) 131 | idx, target_idx2 = self.caption_index_sample_(idx) 132 | if self.data[target_idx1]['adj'] != self.data[target_idx2]['adj']: 133 | break 134 | idx, target_idx = [idx, target_idx1] 135 | self.saved_item = [idx, target_idx2] 136 | else: 137 | idx, target_idx = self.saved_item 138 | self.saved_item = None 139 | return idx, target_idx 140 | 141 | def caption_index_init_(self): 142 | self.caption2imgids = {} 143 | self.noun2adjs = {} 144 | for i, img in enumerate(self.data): 145 | cap = img['captions'][0] 146 | adj = img['adj'] 147 | noun = img['noun'] 148 | if cap not in self.caption2imgids.keys(): 149 | self.caption2imgids[cap] = [] 150 | if noun not in self.noun2adjs.keys(): 151 | self.noun2adjs[noun] = [] 152 | self.caption2imgids[cap].append(i) 153 | if adj not in self.noun2adjs[noun]: 154 | self.noun2adjs[noun].append(adj) 155 | for noun, adjs in self.noun2adjs.items(): 156 | assert len(adjs) >= 2 157 | 158 | def caption_index_sample_(self, idx): 159 | noun = self.data[idx]['noun'] 160 | # adj = self.data[idx]['adj'] 161 | target_adj = random.choice(self.noun2adjs[noun]) 162 | target_caption = target_adj + ' ' + noun 163 | target_idx = random.choice(self.caption2imgids[target_caption]) 164 | return idx, target_idx 165 | 166 | def generate_test_queries_(self): 167 | self.test_queries = [] 168 | for idx, img in enumerate(self.data): 169 | adj = img['adj'] 170 | noun = img['noun'] 171 | for target_adj in self.noun2adjs[noun]: 172 | if target_adj != adj: 173 | mod_str = target_adj 174 | self.test_queries += [{ 175 | 'source_img_id': idx, 176 | 'source_caption': adj + ' ' + noun, 177 | 'target_caption': target_adj + ' ' + noun, 178 | 'mod': { 179 | 'str': mod_str 180 | } 181 | }] 182 | print(len(self.test_queries), 'test queries') 183 | 184 | def __len__(self): 185 | if self.split == "test": 186 | return len(self.test_queries) 187 | return len(self.data) 188 | 189 | def get_img(self, idx, raw_img=False): 190 | img_path = self.data[idx]['file_path'] 191 | with open(img_path, 'rb') as f: 192 | img = PIL.Image.open(f) 193 | img = img.convert('RGB') 194 | if raw_img: 195 | return img 196 | if self.transform: 197 | img = self.transform(img) 198 | return img 199 | 200 | def get_loader(self, 201 | batch_size, 202 | shuffle=False, 203 | drop_last=False, 204 | num_workers=0, 205 | category=None): 206 | return DataLoader( 207 | self, 208 | batch_size=batch_size, 209 | shuffle=shuffle, 210 | num_workers=num_workers, 211 | drop_last=drop_last, 212 | collate_fn=lambda i: i) 213 | 214 | def get_gallery_loader(self, batch_size, num_workers=0): 215 | return DataLoader( 216 | self.gallery, 217 | batch_size=batch_size, 218 | shuffle=False, 219 | num_workers=num_workers, 220 | drop_last=False, 221 | collate_fn=lambda i: i) 222 | -------------------------------------------------------------------------------- /src/maaf/models/loss.py: 
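A minimal sketch of how the loss factory in the file below is typically consumed; the SimpleNamespace config is a stand-in for the project's real config object, and the import path is assumed from the src layout:

from types import SimpleNamespace
import torch
from maaf.models.loss import build_loss

# "batch_based_classification" maps to BatchSoftmaxLoss with learning_type "metric".
cfg = SimpleNamespace(
    MODEL=SimpleNamespace(LOSS="batch_based_classification"),
    DATASET=SimpleNamespace(CLASS_WEIGHTS=[], NUM_CLASSES=2),
)
loss_obj, learning_type = build_loss(cfg)

sources = torch.randn(8, 512)  # composed query embeddings, one row per example
targets = torch.randn(8, 512)  # target embeddings; row i is the positive for row i
loss = loss_obj(sources, targets, labels=None)  # scalar tensor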
-------------------------------------------------------------------------------- 1 | # Copyright 2022 Yahoo, Licensed under the terms of the Apache License, Version 2.0. 2 | # See LICENSE file in project root for terms. 3 | 4 | import sys 5 | import torch 6 | import torch.nn as nn 7 | import numpy as np 8 | 9 | 10 | def build_loss(cfg): 11 | loss_name = cfg.MODEL.LOSS 12 | 13 | learning_type = "metric" 14 | if loss_name == "softmax_cross_entropy": 15 | class_weights = torch.Tensor(cfg.DATASET.CLASS_WEIGHTS) 16 | loss_obj = torch.nn.CrossEntropyLoss(weight=class_weights) 17 | learning_type = "classification" 18 | elif loss_name == "multilabel_soft_margin": 19 | class_weights = torch.Tensor(cfg.DATASET.CLASS_WEIGHTS) 20 | loss_obj = torch.nn.MultiLabelSoftMarginLoss(weight=class_weights) 21 | learning_type = "classification" 22 | elif loss_name == "mse": 23 | loss_obj = torch.nn.MSELoss() 24 | learning_type = "regression" 25 | elif loss_name == "soft_triplet": 26 | loss_obj = SoftTripletLoss() 27 | elif loss_name == "batch_based_classification": 28 | loss_obj = BatchSoftmaxLoss() 29 | elif loss_name == "double_softmax": 30 | loss_obj = DoubleBatchSoftmaxLoss() 31 | elif loss_name == "logistic": 32 | loss_obj = Logistic() 33 | elif loss_name == "logistic_cumulative": 34 | loss_obj = LogisticCumulativeLink(cfg.DATASET.NUM_CLASSES) 35 | elif loss_name == "logratio": 36 | loss_obj = LogRatioLoss() 37 | elif loss_name == "softlabel_softmax": 38 | loss_obj = SoftLabelSoftmaxLoss() 39 | else: 40 | print('Invalid loss function', loss_name) 41 | sys.exit() 42 | 43 | return loss_obj, learning_type 44 | 45 | 46 | class MetricLossBase: 47 | 48 | def __call__(self, sources, targets, labels): 49 | raise NotImplementedError 50 | 51 | 52 | class SoftTripletLoss(MetricLossBase): 53 | 54 | def __call__(self, sources, targets, labels=None): 55 | triplets = [] 56 | labels = list(range(sources.shape[0])) + list(range(targets.shape[0])) 57 | for i in range(len(labels)): 58 | triplets_i = [] 59 | for j in range(len(labels)): 60 | if labels[i] == labels[j] and i != j: 61 | for k in range(len(labels)): 62 | if labels[i] != labels[k]: 63 | triplets_i.append([i, j, k]) 64 | np.random.shuffle(triplets_i) 65 | triplets += triplets_i[:3] # WHY? 66 | assert (triplets and len(triplets) < 2000) 67 | return self.soft_triplet_loss(torch.cat([sources, targets]), triplets) 68 | 69 | 70 | class BatchSoftmaxLoss(MetricLossBase): 71 | """ 72 | Implements batch-wise softmax cross-entropy loss. 73 | Source-target pairs are assumed to be positive matches. 
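    Illustrative example (batch size and embedding width are arbitrary choices,
    not values taken from the project config):

        loss_fn = BatchSoftmaxLoss()
        sources = torch.randn(8, 512)   # composed query embeddings
        targets = torch.randn(8, 512)   # row i of targets is the positive for row i of sources
        value = loss_fn(sources, targets)  # scalar tensor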
74 | """ 75 | 76 | def __init__(self, softmax_margin=0, drop_worst_rate=0): 77 | self.softmax_margin = softmax_margin 78 | self.drop_worst_rate = drop_worst_rate 79 | 80 | def __call__(self, sources, targets, labels=None): 81 | dots = torch.mm(sources, targets.transpose(0, 1)) 82 | if self.softmax_margin > 0: 83 | dots = dots - (torch.Tensor(np.eye(dots.shape[0])).to(sources.device) 84 | * self.softmax_margin) 85 | labels = torch.tensor(range(dots.shape[0])).long().to(dots.device) 86 | losses = nn.functional.cross_entropy(dots, labels, reduction='none') 87 | if self.drop_worst_rate > 0: 88 | losses, idx = torch.topk( 89 | losses, k=int(losses.shape[0] * (1 - self.drop_worst_rate)), 90 | largest=False) 91 | final_loss = losses.mean() 92 | 93 | return final_loss 94 | 95 | 96 | class DoubleBatchSoftmaxLoss(MetricLossBase): 97 | 98 | def __init__(self): 99 | self.loss_first = nn.CrossEntropyLoss() 100 | self.loss_second = nn.CrossEntropyLoss() 101 | 102 | def __call__(self, sources, targets, labels=None): 103 | dots = torch.mm(sources, targets.transpose(0, 1)) 104 | 105 | labels = torch.tensor(range(dots.shape[0])).long().to(dots.device) 106 | losses_a = self.loss_first(dots, labels) 107 | losses_b = self.loss_second(dots.transpose(0, 1), labels) 108 | 109 | final_loss = losses_a + losses_b 110 | return final_loss / 2 111 | 112 | 113 | class SoftLabelSoftmaxLoss(MetricLossBase): 114 | 115 | def __init__(self): 116 | self.celoss = nn.CrossEntropyLoss() 117 | 118 | def __call__(self, sources, targets, labels): 119 | dots = torch.mm(sources, targets.transpose(0, 1)) 120 | iou = labels_from_attributes(labels, dots.shape).to(dots.device) 121 | # NOTE: using CrossEntropyLoss with probabilities requires 122 | # a recent PyTorch version. Installing this naively may cause problems 123 | # e.g. a protobuf conflict...but downgrading protobuf may solve this 124 | return self.celoss(dots, iou) 125 | 126 | 127 | def intersection_over_union(first, other): 128 | first_set = set(first) 129 | other_set = set(other) 130 | union = len(first_set.union(other_set)) 131 | intersection = len(first_set.intersection(other_set)) 132 | return intersection / union 133 | 134 | 135 | def labels_from_attributes(labels, shape): 136 | """Compute IoU labels given attribute lists""" 137 | iou = torch.zeros(shape) 138 | for ii in range(len(labels)): 139 | source_att = labels[ii][0] 140 | for jj in range(len(labels)): 141 | target_att = labels[jj][1] 142 | iou[ii, jj] = intersection_over_union(source_att, target_att) 143 | return iou 144 | 145 | 146 | class LogRatioLoss(MetricLossBase): 147 | """ 148 | Adapted from 149 | Kim et al. 'Deep metric learning beyond binary supervision' CVPR 2019. 
150 | See 151 | https://github.com/tjddus9597/Beyond-Binary-Supervision-CVPR19/blob/master/code/LogRatioLoss.py 152 | """ 153 | epsilon = 1e-6 154 | 155 | def __init__(self): 156 | pass 157 | 158 | def __call__(self, sources, targets, labels): 159 | # get all pairwise distances by broadcasting 160 | distances = torch.linalg.vector_norm( 161 | sources[:, None, :] - targets[None, :, :], dim=2) 162 | log_dist = torch.log(distances + self.epsilon) 163 | 164 | iou = labels_from_attributes(labels, log_dist.shape).to(log_dist.device) 165 | log_iou = torch.log(iou + self.epsilon) 166 | 167 | # get a loss term for each triple (a, i, j) for i and j in target 168 | # note that the i=j terms are 0 169 | loss_terms = (log_dist[:, :, None] - log_dist[:, None, :]) - \ 170 | (log_iou[:, :, None] - log_iou[:, None, :]) 171 | loss_terms = loss_terms * loss_terms 172 | 173 | loss = torch.mean(loss_terms) 174 | 175 | return loss 176 | 177 | 178 | class Logistic(MetricLossBase, nn.Module): 179 | 180 | def __init__(self): 181 | nn.Module.__init__(self) 182 | 183 | self.criterion = nn.SoftMarginLoss() 184 | # self.criterion = nn.BCELoss() for this need sigmoid in __call__ 185 | 186 | def __call__(self, sources, targets, labels=None): 187 | dots = torch.sum(sources * targets, dim=1) 188 | labels = torch.Tensor(labels) 189 | return self.criterion(dots, labels) 190 | 191 | 192 | class LogisticCumulativeLink(MetricLossBase, nn.Module): 193 | """Adapted from https://github.com/EthanRosenthal/spacecutter/... 194 | blob/master/spacecutter/losses.py""" 195 | def __init__(self, num_classes): 196 | MetricLossBase.__init__(self) 197 | nn.Module.__init__(self) 198 | 199 | num_thresh = num_classes - 1 200 | self.thresholds = torch.arange(num_thresh).float() - num_thresh / 2 201 | self.thresholds = nn.Parameter(self.thresholds) 202 | 203 | 204 | def __call__(self, sources, targets, labels=None): 205 | dots = torch.sum(sources * targets, dim=1).unsqueeze(-1) 206 | sigmoids = torch.sigmoid(self.thresholds - dots) 207 | link_mat = sigmoids[:, 1:] - sigmoids[:, :-1] 208 | link_mat = torch.cat(( 209 | sigmoids[:, [0]], 210 | link_mat, 211 | (1 - sigmoids[:, [-1]]) 212 | ), 213 | dim=1 214 | ) # batch, num_classes 215 | 216 | labels = torch.Tensor(labels).long() 217 | 218 | likelihoods = link_mat[labels] 219 | eps = 1e-15 220 | likelihoods = torch.clamp(likelihoods, eps, 1 - eps) 221 | neg_log_likelihood = -torch.log(likelihoods) 222 | 223 | loss = torch.mean(neg_log_likelihood) 224 | return loss 225 | -------------------------------------------------------------------------------- /src/maaf/datasets/birdstowords.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Yahoo, Licensed under the terms of the Apache License, Version 2.0. 2 | # See LICENSE file in project root for terms. 
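# Data layout assumed by this loader (inferred from the parsing code below):
# <path>/birds-to-words-v1.0.tsv is tab-separated; the split name sits in the
# third-to-last column, the two image paths in columns 1 and 5 (matched by
# basename under <path>/images/), and the comparative description in the last
# column. During training, each pair is also used in flipped order with
# "animal1"/"animal2" swapped in the description text.
#
# Minimal usage sketch (the dataset root and transform are placeholders):
#
#     from torchvision import transforms
#     dataset = BirdsToWords("/data/birds-to-words", split="train",
#                            transform=transforms.ToTensor())
#     loader = dataset.get_loader(batch_size=32, shuffle=True)
#     batch = next(iter(loader))  # list of dicts with source/target images and text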
3 | 4 | import random 5 | import os 6 | from torch.utils.data import Dataset, DataLoader 7 | import PIL 8 | 9 | 10 | DATASET_CLASS_NAME = "BirdsToWords" 11 | 12 | 13 | class BirdsToWordsGallery(Dataset): 14 | 15 | def __init__(self, gallery, transform=None): 16 | super().__init__() 17 | self.gallery = gallery 18 | self.transform = transform 19 | 20 | def __getitem__(self, ind): 21 | example = self.gallery[ind] 22 | image = self.get_img(ind) 23 | example["target_image"] = image 24 | example["target_text"] = None 25 | 26 | return example 27 | 28 | def __len__(self): 29 | return len(self.gallery) 30 | 31 | def get_img(self, idx, raw_img=False): 32 | """Retrieve image by global index.""" 33 | img_path = self.gallery[idx]['file_path'] 34 | try: 35 | with open(img_path, 'rb') as f: 36 | img = PIL.Image.open(f) 37 | img = img.convert('RGB') 38 | except EnvironmentError as ee: 39 | print("WARNING: EnvironmentError, defaulting to image 0", ee) 40 | img = self.get_img(0, raw_img=True) 41 | if raw_img: 42 | return img 43 | if self.transform: 44 | img = self.transform(img) 45 | return img 46 | 47 | 48 | class BirdsToWords(Dataset): 49 | 50 | def __init__(self, path, split='train', transform=None, 51 | batch_size=None, normalize=False): 52 | 53 | super().__init__() 54 | self.normalize = normalize 55 | self.batch_size = batch_size 56 | 57 | self.split = split 58 | self.transform = transform 59 | self.img_path = path + '/images/' 60 | 61 | failures = set() 62 | 63 | tsv = os.path.join(path, "birds-to-words-v1.0.tsv") 64 | raw_data = [] 65 | 66 | with open(tsv, "r") as fh: 67 | for line in fh: 68 | entry = line.strip().split("\t") 69 | if entry[-3] == self.split: 70 | raw_data.append(entry) 71 | 72 | ####### 73 | # get image data 74 | image_fpaths = set() 75 | for entry in raw_data: 76 | first_image_fpath = entry[1].split("/")[-1] 77 | second_image_fpath = entry[5].split("/")[-1] 78 | image_fpaths.add(first_image_fpath) 79 | image_fpaths.add(second_image_fpath) 80 | 81 | image_fpaths = sorted(list(image_fpaths)) 82 | img_fpath_to_id = {} 83 | all_images = [] 84 | for fpath in image_fpaths: 85 | full_path = self.img_path + fpath 86 | if os.path.exists(full_path): 87 | image_id = len(all_images) 88 | entry = [{ 89 | 'photo_number': fpath.split("?")[-1], 90 | 'file_path': full_path, 91 | 'captions': [image_id], # not really a caption! 
92 | 'image_id': image_id, 93 | }] 94 | all_images += entry 95 | img_fpath_to_id[fpath] = image_id 96 | else: 97 | failures.add(fpath) 98 | 99 | print(len(failures), " files not found in ", split) 100 | assert len(all_images) > 0, "no data found" 101 | 102 | 103 | ####### 104 | # get pairs and descriptions 105 | queries = {} 106 | for entry in raw_data: 107 | first_image_fpath = entry[1].split("/")[-1] 108 | second_image_fpath = entry[5].split("/")[-1] 109 | 110 | if first_image_fpath in failures or second_image_fpath in failures: 111 | continue 112 | 113 | query_dict_key = first_image_fpath + "_" + second_image_fpath 114 | descrip = entry[-1] 115 | assert isinstance(descrip, str) 116 | if query_dict_key in queries: 117 | query = queries[query_dict_key] 118 | query["captions"] += [descrip] 119 | else: 120 | query = {} 121 | query["source_id"] = img_fpath_to_id[first_image_fpath] 122 | query["target_id"] = img_fpath_to_id[second_image_fpath] 123 | query["captions"] = [descrip] 124 | 125 | queries[query_dict_key] = query 126 | 127 | # during training also use triplet with images swapped 128 | if self.split == "train": 129 | word_by_word = entry[-1].split(" ") 130 | for ii in range(len(word_by_word)): 131 | if word_by_word[ii] == "animal1": 132 | word_by_word[ii] = "animal2" 133 | elif word_by_word[ii] == "animal2": 134 | word_by_word[ii] = "animal1" 135 | flipped_descrip = " ".join(word_by_word) 136 | assert len(flipped_descrip) == len(entry[-1]) 137 | 138 | flipped_key = second_image_fpath + "_" + first_image_fpath 139 | if flipped_key in queries: 140 | flipped_query = queries[flipped_key] 141 | flipped_query["captions"] += [flipped_descrip] 142 | else: 143 | flipped_query = {} 144 | flipped_query["source_id"] = query["target_id"] 145 | flipped_query["target_id"] = query["source_id"] 146 | flipped_query["captions"] = [flipped_descrip] 147 | 148 | queries[flipped_key] = flipped_query 149 | 150 | query_keys = sorted(list(queries.keys())) 151 | queries = [queries[key] for key in query_keys] 152 | 153 | self.data = all_images 154 | self.queries = queries 155 | 156 | if split in ["val", "test"]: 157 | self.test_queries = [] 158 | for query in queries: 159 | self.test_queries += [{ 160 | 'source_img_id': query['source_id'], 161 | 'target_img_id': query['target_id'], 162 | 'target_caption': query['target_id'], 163 | 'mod': {'str': query['captions'][ii]} 164 | } for ii in range(len(query['captions']))] 165 | 166 | self.gallery = BirdsToWordsGallery(self.data, transform=self.transform) 167 | 168 | def get_all_texts(self): 169 | texts = [] 170 | for query in self.queries: 171 | texts += query['captions'] 172 | return texts 173 | 174 | def __len__(self): 175 | return len(self.data) 176 | 177 | def __getitem__(self, idx): 178 | if self.split in ["val", "test"]: 179 | query = self.test_queries[idx] 180 | return { 181 | 'source_image': self.get_img(query["source_img_id"]), 182 | 'source_text': query["mod"]["str"] 183 | } 184 | return self.generate_random_query_target() 185 | 186 | def generate_random_query_target(self): 187 | query = random.choice(self.queries) 188 | mod_str = random.choice(query['captions']) 189 | 190 | return { 191 | 'source_img_id': query['source_id'], 192 | 'source_image': self.get_img(query['source_id']), 193 | 'target_img_id': query['target_id'], 194 | 'target_caption': query['target_id'], 195 | 'target_image': self.get_img(query['target_id']), 196 | 'source_text': {'str': mod_str}, 197 | "target_text": None 198 | } 199 | 200 | def get_img(self, idx, raw_img=False): 201 | 
"""Retrieve image by global index.""" 202 | img_path = self.data[idx]['file_path'] 203 | try: 204 | with open(img_path, 'rb') as f: 205 | img = PIL.Image.open(f) 206 | img = img.convert('RGB') 207 | except EnvironmentError as ee: 208 | print("WARNING: EnvironmentError, defaulting to image 0", ee) 209 | img = self.get_img(0, raw_img=True) 210 | if raw_img: 211 | return img 212 | if self.transform: 213 | img = self.transform(img) 214 | return img 215 | 216 | def get_loader(self, 217 | batch_size, 218 | shuffle=False, 219 | drop_last=False, 220 | num_workers=0, 221 | category=None): 222 | return DataLoader( 223 | self, 224 | batch_size=batch_size, 225 | shuffle=shuffle, 226 | num_workers=num_workers, 227 | drop_last=drop_last, 228 | collate_fn=lambda i: i) 229 | 230 | def get_gallery_loader(self, batch_size, num_workers=0): 231 | return DataLoader( 232 | self.gallery, 233 | batch_size=batch_size, 234 | shuffle=False, 235 | num_workers=num_workers, 236 | drop_last=False, 237 | collate_fn=lambda i: i) 238 | -------------------------------------------------------------------------------- /src/maaf/models/image_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Yahoo, Licensed under the terms of the Apache License, Version 2.0. 2 | # See LICENSE file in project root for terms. 3 | 4 | import torch 5 | import torchvision 6 | from torch.nn import functional as tfunc 7 | 8 | 9 | # Loading weights Auxiliary functions 10 | def remove_prefix(prefix, a_string): 11 | l_prefix = len(prefix) 12 | if a_string[:l_prefix] == prefix: 13 | final_string = a_string[l_prefix:] 14 | else: 15 | final_string = a_string 16 | return(final_string) 17 | 18 | 19 | def load_pretrained_weights(model, weights_path, freeze, prefix_to_remove): 20 | model_dict = model.state_dict() 21 | saved_state_dict = torch.load(weights_path)['state_dict'] 22 | print("Loading image model weights from: %s" % weights_path) 23 | # 1. filter out unnecessary keys 24 | pretrained_dict = {remove_prefix(prefix_to_remove, k): v 25 | for k, v in saved_state_dict.items() if remove_prefix(prefix_to_remove, k) in model_dict} 26 | # 2. overwrite entries in the existing state dict 27 | for tensor_name in pretrained_dict.keys(): 28 | print('Loading %s' % tensor_name) 29 | model_dict.update(pretrained_dict) 30 | # 3. 
load the new state dict 31 | model.load_state_dict(model_dict) 32 | 33 | if freeze: 34 | for name, param in model.named_parameters(): 35 | if name in pretrained_dict: 36 | print('freezing parameter: %s' % name) 37 | param.requires_grad = False 38 | model.eval() 39 | 40 | 41 | def repeating_eye(in_channels, out_channels): 42 | repetitions = in_channels // out_channels 43 | eye = torch.eye(out_channels) 44 | return eye.repeat(1, repetitions) 45 | 46 | 47 | class ConvProjection(torch.nn.Module): 48 | 49 | def __init__(self, in_channels, out_channels, kernel_size=1, 50 | dtype=None, initialization=None, **kwargs): 51 | super().__init__() 52 | self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size, 53 | bias=False, **kwargs) 54 | self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001) 55 | 56 | if initialization == "identity": 57 | assert kernel_size == 1 58 | eyes = repeating_eye(in_channels, out_channels) 59 | self.conv.weight.data = eyes.reshape(out_channels, in_channels, 1, 1) 60 | elif initialization is not None: 61 | raise ValueError(f"Unsupported initialization {initialization}") 62 | 63 | if dtype is not None: 64 | self.conv = self.conv.to(dtype) 65 | 66 | def forward(self, x): 67 | x = self.conv(x) 68 | x = self.bn(x) 69 | x = tfunc.relu(x, inplace=True) 70 | return x.view(x.shape[0], x.shape[1], -1).transpose(-2, -1) 71 | 72 | 73 | class GlobalAvgPool2d(torch.nn.Module): 74 | 75 | def forward(self, x): 76 | return tfunc.adaptive_avg_pool2d(x, (1, 1)) 77 | 78 | 79 | class ResNet(torch.nn.Module): 80 | """ 81 | ResNet, possibly returning intermediate layer outputs, possibly projected 82 | forward returns a dict 83 | """ 84 | 85 | def __init__(self, architecture="resnet50", out_features=["fc"], 86 | out_channels=None, pretrained=True, dict_out=False): 87 | """ 88 | Args: 89 | architecture (str): resnet50 or resnet18 90 | out_features (list[str | int]): layers whose outputs to return, 91 | from among "stem", 1, 2, 3, 4, "fc" 92 | out_channels (None | int): if None, outputs returned directly from 93 | resnet. 
If an int, learnable convs project outputs to out_channels 94 | pretrained (bool): if True, load torchvision's ImageNet-trained weights 95 | Note this requires internet or precaching these weights 96 | """ 97 | super().__init__() 98 | if architecture == 'resnet50': 99 | top_channel = 2048 100 | print("Using ResNet50") 101 | self.model = torchvision.models.resnet50(pretrained=pretrained) 102 | elif architecture == 'resnet18': 103 | top_channel = 512 104 | print("Using ResNet18") 105 | self.model = torchvision.models.resnet18(pretrained=pretrained) 106 | else: 107 | raise ValueError("Invalid image_model_arch {}".format( 108 | architecture)) 109 | 110 | self.out_features = out_features 111 | 112 | # drop any layers that aren't being used 113 | labels_to_inds = {"stem": 0, 1: 1, 2: 2, 3: 3, 4: 4, "fc": 5} 114 | # inds_to_labels = {val: key for key, val in labels_to_inds.items()} 115 | inds = [labels_to_inds[label] for label in self.out_features] 116 | last_ind = max(inds) 117 | self.layers = [] 118 | for ii in range(1, 5): 119 | if ii > last_ind: 120 | delattr(self.model, f"layer{ii}") 121 | else: 122 | self.layers.append((ii, getattr(self.model, f"layer{ii}"))) 123 | 124 | if out_channels is not None: 125 | self.projections = torch.nn.ModuleDict() 126 | for ii, layer in self.layers: 127 | if ii in self.out_features: 128 | in_channel = top_channel // (2**(4 - ii)) 129 | self.projections[str(ii)] = \ 130 | ConvProjection(in_channel, out_channels, kernel_size=1) 131 | 132 | if "fc" in self.out_features: 133 | self.model.avgpool = GlobalAvgPool2d() 134 | self.avgpool = self.model.avgpool 135 | 136 | self.model.fc = torch.nn.Sequential( 137 | torch.nn.Linear(self.model.fc.weight.shape[1], 138 | out_channels)) 139 | if len(self.out_features) == 1: 140 | self.projections = None 141 | else: 142 | self.projections = None 143 | 144 | if "fc" in out_features: 145 | self.fc = self.model.fc 146 | else: 147 | del self.model.fc 148 | 149 | def pretrained_parameters(self): 150 | if self.projections is not None: 151 | scratch = set([param for param in self.projections.parameters()]) 152 | all_param = set([param for param in self.parameters()]) 153 | return all_param.difference(scratch) 154 | else: 155 | return self.parameters() 156 | 157 | def forward(self, imgs): 158 | """Returns: {layer_name: output}""" 159 | out = {} 160 | xx = imgs 161 | 162 | xx = self.model.conv1(xx) 163 | xx = self.model.bn1(xx) 164 | xx = self.model.relu(xx) 165 | xx = self.model.maxpool(xx) 166 | if "stem" in self.out_features: 167 | out["stem"] = xx 168 | 169 | for ind, layer in self.layers: 170 | xx = layer(xx) 171 | if ind in self.out_features: 172 | out[ind] = xx 173 | 174 | if "fc" in self.out_features: 175 | xx = self.avgpool(xx) 176 | out["fc"] = self.fc(xx.view(xx.size(0), -1)) 177 | 178 | if self.projections is not None: 179 | proj = [self.projections[str(ii)](out[ii]) 180 | for ii, layer in self.layers if ii in out] 181 | if "fc" in out: 182 | proj += [out["fc"].unsqueeze(1)] 183 | out["projections"] = torch.cat(proj, dim=1) 184 | 185 | return out 186 | 187 | def get_num_tokens(self): 188 | num = 0 189 | if 2 in self.out_features: 190 | num += 28**2 191 | if 3 in self.out_features: 192 | num += 14**2 193 | if 4 in self.out_features: 194 | num += 7**2 195 | if "fc" in self.out_features: 196 | num += 1 197 | return num 198 | 199 | def resolutionwise_pool(self, xx): 200 | """Pool over space at each resolution, then average results.""" 201 | resolutions = [] 202 | start = 0 203 | if 2 in self.out_features: 204 | x2 = xx[:, 
:28**2] 205 | resolutions.append(x2) 206 | start = 28**2 207 | if 3 in self.out_features: 208 | x3 = xx[:, start:start+14**2] 209 | resolutions.append(x3) 210 | start += 14**2 211 | if 4 in self.out_features: 212 | x4 = xx[:, start:start+7**2] 213 | resolutions.append(x4) 214 | start += 7**2 215 | if "fc" in self.out_features: 216 | xfc = xx[:, start:] 217 | resolutions.append(xfc) 218 | 219 | resmeans = [] 220 | for res in resolutions: 221 | resmeans.append(torch.mean(res, 1)) 222 | 223 | return torch.mean(torch.stack(resmeans), 0) 224 | 225 | 226 | def build_image_model(cfg): 227 | architecture = cfg.MODEL.IMAGE_MODEL.ARCHITECTURE 228 | if architecture is None: 229 | return None 230 | 231 | out_features = cfg.MODEL.IMAGE_MODEL.OUTPUTS 232 | # if cfg.MODEL.COMPOSITION in MAAF_ALIASES: 233 | # out_channels = cfg.MODEL.EMBED_DIM 234 | # else: 235 | # out_channels = None 236 | out_channels = cfg.MODEL.EMBED_DIM 237 | pretrained = cfg.MODEL.IMAGE_MODEL.PRETRAINED and \ 238 | cfg.MODEL.IMAGE_MODEL.WEIGHTS is None and \ 239 | cfg.MODEL.WEIGHTS is None 240 | img_model = ResNet(architecture, out_features, out_channels=out_channels, 241 | pretrained=pretrained) 242 | 243 | if cfg.MODEL.IMAGE_MODEL.WEIGHTS is not None: 244 | # saved_state_dict = torch.load(opt.image_model_path)['state_dict'] 245 | # self.model.load_state_dict(saved_state_dict) 246 | load_pretrained_weights( 247 | model=img_model, weights_path=cfg.MODEL.IMAGE_MODEL.WEIGHTS, 248 | freeze=cfg.MODEL.IMAGE_MODEL.FREEZE_WEIGHTS, 249 | prefix_to_remove='img_model.') 250 | 251 | if cfg.MODEL.IMAGE_MODEL.FREEZE_WEIGHTS: 252 | print("Freezing Image model weights") 253 | for param in img_model.parameters(): 254 | param.requires_grad = False 255 | 256 | return img_model 257 | -------------------------------------------------------------------------------- /src/maaf/datasets/fashioniq.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Yahoo, Licensed under the terms of the Apache License, Version 2.0. 2 | # See LICENSE file in project root for terms. 3 | 4 | import random 5 | import os 6 | import json 7 | from torch.utils.data import Dataset, DataLoader, Subset 8 | import PIL 9 | import numpy as np 10 | 11 | CATEGORIES = ["dress", "shirt", "toptee"] 12 | 13 | 14 | class FashionIQGallery(Dataset): 15 | 16 | def __init__(self, gallery, img_by_cat=None, transform=None): 17 | super().__init__() 18 | self.gallery = gallery 19 | self.transform = transform 20 | if img_by_cat is not None: 21 | self.gallery_by_cat = \ 22 | {cat: FashionIQGallery(img_by_cat[cat], transform=transform) 23 | for cat in img_by_cat} 24 | 25 | def __getitem__(self, ind): 26 | example = self.gallery[ind] 27 | image = self.get_img(ind) 28 | example["target_image"] = image 29 | example["target_text"] = None 30 | 31 | return example 32 | 33 | def get_img(self, idx, raw_img=False): 34 | """Retrieve image by global index.""" 35 | img_path = self.gallery[idx]['file_path'] 36 | try: 37 | with open(img_path, 'rb') as f: 38 | img = PIL.Image.open(f) 39 | img = img.convert('RGB') 40 | except EnvironmentError as ee: 41 | print("WARNING: EnvironmentError, defaulting to image 0", ee) 42 | img = self.get_img(0, raw_img=True) 43 | if raw_img: 44 | return img 45 | if self.transform is not None: 46 | img = self.transform(img) 47 | return img 48 | 49 | def __len__(self): 50 | return len(self.gallery) 51 | 52 | 53 | class FashionIQTrainLoader: 54 | """ 55 | Each batch is drawn from one of the loaders at random. 
56 | Iteration stops when a loader is queried and raises StopIteration, 57 | even if the other loaders have more data left. 58 | """ 59 | 60 | def __init__(self, loaders, random_seed=5222020): 61 | self.loaders = {key: iter(loaders[key]) for key in loaders} 62 | self.rng = np.random.RandomState(random_seed) 63 | 64 | def __iter__(self): 65 | return self 66 | 67 | def __next__(self): 68 | category = self.rng.choice(list(self.loaders.keys())) 69 | return next(self.loaders[category]) 70 | 71 | 72 | def __len__(self): 73 | return sum([len(ld) for ld in self.loaders.values()]) 74 | 75 | 76 | class FashionIQDataset(Dataset): 77 | 78 | def __init__(self, path, split='train', transform=None, normalize=False): 79 | super().__init__() 80 | self.categories = CATEGORIES 81 | self.normalize = normalize 82 | 83 | self.split = split 84 | self.transform = transform 85 | self.img_path = path + '/' 86 | 87 | failures = [] 88 | 89 | data = { 90 | 'image_splits': {}, 91 | 'captions': {} 92 | } 93 | 94 | def wanted_captions(filename): 95 | """Select normalized/original caption files.""" 96 | if normalize and "cap." in filename: 97 | return "normcap." in filename 98 | else: 99 | return "normcap." not in filename 100 | for data_type in data: 101 | for datafile in os.listdir(path + '/' + data_type): 102 | if split in datafile and wanted_captions(datafile): 103 | data[data_type][datafile] = \ 104 | json.load(open(path + '/' + data_type + '/' + datafile)) 105 | 106 | split_labels = sorted(list(data["image_splits"].keys())) 107 | 108 | global_imgs = [] 109 | img_by_cat = {cat: [] for cat in CATEGORIES} 110 | self.asin2id = {} 111 | for splabel in split_labels: 112 | for asin in data['image_splits'][splabel]: 113 | # if asin in failures: 114 | # continue 115 | category = splabel.split(".")[1] 116 | file_path = path + '/img/' + category + '/' + asin 117 | if os.path.exists(file_path) or split == "test": 118 | global_id = len(global_imgs) 119 | category_id = len(img_by_cat[category]) 120 | entry = [{ 121 | 'asin': asin, 122 | 'file_path': file_path, 123 | 'captions': [global_id], 124 | "image_id": global_id, 125 | "category": {category: category_id} 126 | }] 127 | if asin in self.asin2id: 128 | # handle duplicates 129 | oldglobal = self.asin2id[asin] 130 | subentry = global_imgs[oldglobal] 131 | assert category not in subentry["category"], \ 132 | "{} duplicated in {}".format(asin, category) 133 | 134 | # update entry to include additional category and id 135 | subentry["category"][category] = category_id 136 | img_by_cat[category] += [subentry] 137 | else: 138 | # just add the entry 139 | global_imgs += entry 140 | img_by_cat[category] += entry 141 | self.asin2id[asin] = global_id 142 | else: 143 | failures.append(asin) 144 | 145 | print(len(failures), " files not found in ", split) 146 | assert len(global_imgs) > 0, "no data found" 147 | 148 | queries = [] 149 | captions = sorted(list(data["captions"].keys())) 150 | for cap in captions: 151 | for query in data['captions'][cap]: 152 | if split != "test" and (query['candidate'] in failures 153 | or query.get('target') in failures): 154 | continue 155 | query['source_id'] = self.asin2id[query['candidate']] 156 | query["category"] = cap.split(".")[1] 157 | if split != "test": 158 | query['target_id'] = self.asin2id[query['target']] 159 | tarcat = global_imgs[query['target_id']]["category"] 160 | if query["category"] not in tarcat: 161 | print("WARNING: a {} found with a target in {}".format( 162 | query["category"], tarcat 163 | )) 164 | soucat = 
global_imgs[query['source_id']]["category"] 165 | assert query["category"] in soucat 166 | 167 | queries += [query] 168 | 169 | self.img_by_cat = img_by_cat 170 | self.data = queries 171 | self.gallery = FashionIQGallery(global_imgs, img_by_cat, 172 | transform=transform) 173 | self.data_by_category = \ 174 | {cat: Subset(self, [ii for ii in range(len(self)) 175 | if self.data[ii]["category"] == cat]) 176 | for cat in self.categories} 177 | 178 | 179 | self.id2asin = {val: key for key, val in self.asin2id.items()} 180 | 181 | def get_all_texts(self): 182 | texts = [' inadditiontothat '] 183 | for query in self.data: 184 | texts += query['captions'] 185 | return texts 186 | 187 | def __len__(self): 188 | return len(self.data) 189 | 190 | def get_loader(self, 191 | batch_size, 192 | shuffle=False, 193 | drop_last=False, 194 | num_workers=0, 195 | category=None): 196 | if category == "batchwise": 197 | loaders = { 198 | cat: self.get_loader( 199 | batch_size, shuffle=shuffle, 200 | drop_last=drop_last, num_workers=num_workers, 201 | category=cat) 202 | for cat in self.categories} 203 | return FashionIQTrainLoader(loaders) 204 | elif category is None: 205 | ds = self 206 | else: 207 | ds = self.data_by_category[category] 208 | return DataLoader( 209 | ds, 210 | batch_size=batch_size, 211 | shuffle=shuffle, 212 | num_workers=num_workers, 213 | drop_last=drop_last, 214 | collate_fn=lambda i: i) 215 | 216 | def get_gallery_loader(self, batch_size, num_workers=0, category=None): 217 | if category is None: 218 | gallery = self.gallery 219 | else: 220 | gallery = self.gallery.gallery_by_cat[category] 221 | return DataLoader( 222 | gallery, 223 | batch_size=batch_size, 224 | shuffle=False, 225 | num_workers=num_workers, 226 | drop_last=False, 227 | collate_fn=lambda i: i) 228 | 229 | def __getitem__(self, idx): 230 | example = self.data[idx] 231 | 232 | if self.split == "train": 233 | mod_str = random.choice([ 234 | example['captions'][0] + ' inadditiontothat ' + example['captions'][1], 235 | example['captions'][1] + ' inadditiontothat ' + example['captions'][0], 236 | ]) 237 | else: 238 | mod_str = example['captions'][0] + ' inadditiontothat ' + example['captions'][1] 239 | 240 | if len(mod_str) < 2: 241 | # can happen during training if a caption is tiny 242 | mod_str = example['captions'][0] + ' inadditiontothat ' + example['captions'][1] 243 | 244 | item = {key: val for key, val in example.items()} 245 | 246 | item["source_image"] = self.gallery.get_img(example['source_id']) 247 | item["source_text"] = mod_str 248 | 249 | if self.split != "test": 250 | item["target_image"] = self.gallery.get_img(example['target_id']) 251 | item["target_text"] = None 252 | 253 | item["judgment"] = 1 254 | 255 | return item 256 | 257 | def get_test_queries(self, category=None): 258 | if category is not None: 259 | return [que for que in self.data if que["category"] == category] 260 | return self.data 261 | 262 | def parse_judgment(self, judgment, loss=None): 263 | return judgment 264 | -------------------------------------------------------------------------------- /src/maaf/datasets/fashion200k.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Yahoo, Licensed under the terms of the Apache License, Version 2.0. 2 | # See LICENSE file in project root for terms. 
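# Note: __init__ below raises NotImplementedError before doing any work, so this
# dataset is effectively disabled as shipped.
#
# Training queries are built from "parent" captions: a parent is a caption with one
# word removed, so two captions that share a parent differ by a single word. The
# modification string is then built by get_different_word, e.g. (illustrative
# captions, not taken from the data):
#
#     source:  "blue floral maxi dress"
#     target:  "red floral maxi dress"
#     result:  ("blue", "red", "replace blue with red")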
3 | 4 | from torch.utils.data import Dataset, DataLoader 5 | import numpy as np 6 | import random 7 | import PIL 8 | 9 | DATASET_CLASS_NAME = "Fashion200k" 10 | 11 | 12 | class Fashion200kGallery(Dataset): 13 | 14 | def __init__(self, gallery, transform=None, img_path=""): 15 | super().__init__() 16 | self.gallery = gallery 17 | self.transform = transform 18 | self.img_path = img_path 19 | 20 | def __getitem__(self, ind): 21 | example = self.gallery[ind] 22 | image = self.get_img(ind) 23 | example["target_image"] = image 24 | example["target_text"] = None 25 | 26 | return example 27 | 28 | def __len__(self): 29 | return len(self.gallery) 30 | 31 | def get_img(self, idx, raw_img=False): 32 | img_path = self.img_path + self.gallery[idx]['file_path'] 33 | with open(img_path, 'rb') as f: 34 | img = PIL.Image.open(f) 35 | img = img.convert('RGB') 36 | if raw_img: 37 | return img 38 | if self.transform: 39 | img = self.transform(img) 40 | return img 41 | 42 | 43 | class Fashion200k(Dataset): 44 | 45 | def __init__(self, path, split='train', transform=None): 46 | raise NotImplementedError("This dataset may not be implemented correctly") 47 | super().__init__() 48 | 49 | if split != "train": 50 | split = "test" 51 | self.split = split 52 | self.transform = transform 53 | self.img_path = path + '/' 54 | 55 | # get label files for the split 56 | label_path = path + '/labels/' 57 | from os import listdir 58 | from os.path import isfile 59 | from os.path import join 60 | label_files = [ 61 | f for f in listdir(label_path) if isfile(join(label_path, f)) 62 | ] 63 | label_files = [f for f in label_files if split in f] 64 | 65 | # read image info from label files 66 | self.data = [] 67 | 68 | def caption_post_process(s): 69 | return s.strip().replace( 70 | '.', 71 | 'dotmark').replace('?', 'questionmark').replace( 72 | '&', 'andmark').replace('*', 'starmark') 73 | 74 | for filename in label_files: 75 | print('read ' + filename) 76 | with open(label_path + '/' + filename) as f: 77 | lines = f.readlines() 78 | for line in lines: 79 | line = line.split(' ') 80 | this_index = len(self.data) 81 | img = { 82 | 'file_path': line[0], 83 | 'detection_score': line[1], 84 | 'captions': [caption_post_process(line[2])], 85 | 'split': split, 86 | 'modifiable': False, 87 | "target_id": this_index 88 | } 89 | self.data += [img] 90 | print('Fashion200k:', len(self.data), 'images') 91 | 92 | # generate query for training or testing 93 | if split == 'train': 94 | self.caption_index_init_() 95 | else: 96 | self.generate_test_queries_() 97 | 98 | self.gallery = Fashion200kGallery(self.data, self.transform, self.img_path) 99 | 100 | def get_gallery_loader(self, batch_size, num_workers=0): 101 | return DataLoader( 102 | self.gallery, 103 | batch_size=batch_size, 104 | shuffle=False, 105 | num_workers=num_workers, 106 | drop_last=False, 107 | collate_fn=lambda i: i) 108 | 109 | def get_different_word(self, source_caption, target_caption): 110 | source_words = source_caption.split() 111 | target_words = target_caption.split() 112 | for source_word in source_words: 113 | if source_word not in target_words: 114 | break 115 | for target_word in target_words: 116 | if target_word not in source_words: 117 | break 118 | mod_str = 'replace ' + source_word + ' with ' + target_word 119 | return source_word, target_word, mod_str 120 | 121 | def generate_test_queries_(self): 122 | file2imgid = {} 123 | for i, img in enumerate(self.data): 124 | file2imgid[img['file_path']] = i 125 | with open(self.img_path + '/test_queries.txt') as f: 126 | 
lines = f.readlines() 127 | self.test_queries = [] 128 | for line in lines: 129 | source_file, target_file = line.split() 130 | idx = file2imgid[source_file] 131 | target_idx = file2imgid[target_file] 132 | source_caption = self.data[idx]['captions'][0] 133 | target_caption = self.data[target_idx]['captions'][0] 134 | source_word, target_word, mod_str = self.get_different_word( 135 | source_caption, target_caption) 136 | self.test_queries += [{ 137 | 'source_img_id': idx, 138 | 'source_caption': source_caption, 139 | 'target_caption': target_caption, 140 | 'target_id': target_idx, 141 | 'mod': {'str': mod_str} 142 | }] 143 | 144 | def caption_index_init_(self): 145 | """ index caption to generate training query-target example on the fly later""" 146 | 147 | # index caption 2 caption_id and caption 2 target_ids 148 | caption2id = {} 149 | id2caption = {} 150 | caption2imgids = {} 151 | for i, img in enumerate(self.data): 152 | for c in img['captions']: 153 | if c not in caption2id: 154 | id2caption[len(caption2id)] = c 155 | caption2id[c] = len(caption2id) 156 | caption2imgids[c] = [] 157 | caption2imgids[c].append(i) 158 | self.caption2imgids = caption2imgids 159 | print(len(caption2imgids), 'unique cations') 160 | 161 | # parent captions are 1-word shorter than their children 162 | parent2children_captions = {} 163 | for c in caption2id.keys(): 164 | for w in c.split(): 165 | p = c.replace(w, '') 166 | p = p.replace(' ', ' ').strip() 167 | if p not in parent2children_captions: 168 | parent2children_captions[p] = [] 169 | if c not in parent2children_captions[p]: 170 | parent2children_captions[p].append(c) 171 | self.parent2children_captions = parent2children_captions 172 | 173 | # identify parent captions for each image 174 | for img in self.data: 175 | img['modifiable'] = False 176 | img['parent_captions'] = [] 177 | for p in parent2children_captions: 178 | if len(parent2children_captions[p]) >= 2: 179 | for c in parent2children_captions[p]: 180 | for imgid in caption2imgids[c]: 181 | self.data[imgid]['modifiable'] = True 182 | self.data[imgid]['parent_captions'] += [p] 183 | num_modifiable_imgs = 0 184 | for img in self.data: 185 | if img['modifiable']: 186 | num_modifiable_imgs += 1 187 | print('Modifiable images', num_modifiable_imgs) 188 | 189 | def caption_index_sample_(self, idx): 190 | while not self.data[idx]['modifiable']: 191 | idx = np.random.randint(0, len(self.data)) 192 | 193 | # find random target image (same parent) 194 | img = self.data[idx] 195 | while True: 196 | p = random.choice(img['parent_captions']) 197 | c = random.choice(self.parent2children_captions[p]) 198 | if c not in img['captions']: 199 | break 200 | target_idx = random.choice(self.caption2imgids[c]) 201 | 202 | # find the word difference between query and target (not in parent caption) 203 | source_caption = self.data[idx]['captions'][0] 204 | target_caption = self.data[target_idx]['captions'][0] 205 | source_word, target_word, mod_str = self.get_different_word( 206 | source_caption, target_caption) 207 | return idx, target_idx, source_word, target_word, mod_str 208 | 209 | def get_all_texts(self): 210 | texts = [] 211 | for img in self.data: 212 | for c in img['captions']: 213 | texts.append(c) 214 | return texts 215 | 216 | def __len__(self): 217 | return len(self.data) 218 | 219 | def __getitem__(self, idx): 220 | if self.split == "train": 221 | idx, target_idx, source_word, target_word, mod_str = \ 222 | self.caption_index_sample_(idx) 223 | else: 224 | query = self.test_queries[idx] 225 | idx = 
query["source_img_id"] 226 | target_idx = query["target_id"] 227 | mod_str = query["mod"]["str"] 228 | 229 | out = {} 230 | out['source_id'] = idx 231 | # out['source_caption'] = self.data[idx]['captions'][0] 232 | out['target_id'] = target_idx 233 | # out['target_caption'] = self.data[target_idx]['captions'][0] 234 | 235 | out["source_image"] = self.get_img(idx) 236 | out["source_text"] = mod_str 237 | out["target_image"] = self.get_img(target_idx) 238 | out["target_text"] = None 239 | 240 | return out 241 | 242 | def get_img(self, idx, raw_img=False): 243 | img_path = self.img_path + self.data[idx]['file_path'] 244 | with open(img_path, 'rb') as f: 245 | img = PIL.Image.open(f) 246 | img = img.convert('RGB') 247 | if raw_img: 248 | return img 249 | if self.transform: 250 | img = self.transform(img) 251 | return img 252 | 253 | def get_loader(self, 254 | batch_size, 255 | shuffle=False, 256 | drop_last=False, 257 | num_workers=0, 258 | category=None): 259 | return DataLoader( 260 | self, 261 | batch_size=batch_size, 262 | shuffle=shuffle, 263 | num_workers=num_workers, 264 | drop_last=drop_last, 265 | collate_fn=lambda i: i) 266 | 267 | def parse_judgment(self, judgment, loss=None): 268 | return judgment 269 | -------------------------------------------------------------------------------- /src/maaf/train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Yahoo, Licensed under the terms of the Apache License, Version 2.0. 2 | # See LICENSE file in project root for terms. 3 | 4 | import time 5 | import numpy as np 6 | from .actions import eval_retrieval 7 | import torch 8 | import torch.utils.data 9 | from .utils.misc_utils import tqdm # with dynamic_ncols=True 10 | from collections import defaultdict 11 | from .utils.bn_utils import apply_bn_mode 12 | 13 | 14 | class Trainer: 15 | 16 | def __init__(self, cfg, logger, dataset_dict, model, optimizer, initial_it): 17 | self.cfg = cfg 18 | self.device = torch.device(cfg.MODEL.DEVICE) 19 | self.logger = logger 20 | self.optimizer = optimizer 21 | self.step = initial_it 22 | self.epoch = initial_it // len(dataset_dict["train"]) 23 | self.scheduled_lr = False 24 | if len(cfg.SOLVER.SCHEDULE_RATES) > 0: 25 | self.scheduled_rates = \ 26 | [float(rate) for rate in cfg.SOLVER.SCHEDULE_RATES] 27 | if cfg.SOLVER.SCHEDULE_ITERS != "": 28 | self.schedule_iters = \ 29 | [int(itr) for itr in cfg.SOLVER.SCHEDULE_ITERS] 30 | self.scheduled_lr = True 31 | self.current_lr = cfg.SOLVER.LEARNING_RATE 32 | 33 | self.losses_tracking = defaultdict(list) 34 | 35 | self.dataset_dict = dataset_dict 36 | self.trainset = dataset_dict["train"] 37 | self.model = model 38 | 39 | def parse_batch(self, batch): 40 | images = [dd['image'] for dd in batch] 41 | if hasattr(self.model, "image_transform"): 42 | images = [self.model.image_transform(im) for im in images] 43 | if images[0] is not None: 44 | images = torch.stack(images).float().to(self.device) 45 | texts = [dd["text"] for dd in batch] 46 | labels = [dd["label"] for dd in batch] 47 | labels = torch.Tensor(labels).long().to(self.device) 48 | return (images, texts, labels) 49 | 50 | def train_step(self, batch): 51 | self.model.train() 52 | apply_bn_mode(self.model, self.cfg.SOLVER.BATCH_NORM_MODE) 53 | 54 | parsed_batch = self.parse_batch(batch) 55 | 56 | if self.step == 0 and parsed_batch[0][0] is not None: 57 | for ii, im in enumerate(parsed_batch[0]): 58 | self.logger.add_image(f"image_inputs_{ii}", im, self.step) 59 | 60 | losses = [] 61 | loss_value, metrics = 
self.model.compute_loss(*parsed_batch) 62 | 63 | loss_name = self.cfg.MODEL.LOSS 64 | loss_weight = 1.0 65 | losses += [(loss_name, loss_weight, loss_value)] 66 | total_loss = sum([ 67 | l_weight * l_value 68 | for _, l_weight, l_value in losses 69 | ]) 70 | assert not torch.isnan(total_loss) 71 | losses += [('total training loss', None, total_loss)] 72 | 73 | for key, val in metrics.items(): 74 | losses += [("train_" + key, None, val)] 75 | 76 | # track losses 77 | for l_name, l_weight, l_value in losses: 78 | self.losses_tracking[loss_name].append(loss_value.item()) 79 | for key, val in metrics.items(): 80 | self.losses_tracking[key].append(val) 81 | 82 | self.optimizer.zero_grad() 83 | total_loss.backward() 84 | self.optimizer.step() 85 | 86 | return metrics 87 | 88 | def train(self): 89 | self.model.to(self.device) 90 | 91 | tic = time.time() 92 | while self.step < self.cfg.SOLVER.NUM_ITERS: 93 | cat = "batchwise" if self.cfg.DATASET.SINGLE_CLASS_BATCHES else None 94 | trainloader = self.trainset.get_loader( 95 | batch_size=self.cfg.SOLVER.BATCH_SIZE, 96 | shuffle=True, 97 | drop_last=True, 98 | num_workers=self.cfg.DATA_LOADER.LOADER_NUM_WORKERS, 99 | category=cat) 100 | 101 | # show/log stats 102 | print("It {} epoch {} Elapse time {:.4f}".format( 103 | self.step, self.epoch, time.time() - tic 104 | )) 105 | tic = time.time() 106 | for loss_name in self.losses_tracking: 107 | the_loss = self.losses_tracking[loss_name] 108 | avg_loss = np.mean(the_loss[-len(trainloader):]) 109 | print(' ', loss_name, round(avg_loss, 4)) 110 | self.logger.add_scalar(loss_name, avg_loss, self.step) 111 | self.logger.add_scalar( 112 | 'learning_rate', self.optimizer.param_groups[0]['lr'], 113 | self.step) 114 | 115 | # test 116 | evalstep = self.epoch % self.cfg.SOLVER.EVAL_EVERY == 1 or \ 117 | self.cfg.SOLVER.EVAL_EVERY == 1 118 | if evalstep and self.epoch > 0: 119 | self.run_eval(eval_on_test=self.cfg.SOLVER.ALWAYS_EVAL_TEST) 120 | 121 | # save checkpoint 122 | torch.save({ 123 | 'it': self.step, 124 | 'model_state_dict': self.model.state_dict(), 125 | }, 126 | self.logger.file_writer.get_logdir() + '/latest_checkpoint.pth') 127 | 128 | if self.epoch % self.cfg.SOLVER.SAVE_EVERY == 0 and self.epoch > 0: 129 | torch.save({ 130 | 'it': self.step, 131 | 'model_state_dict': self.model.state_dict()}, 132 | self.logger.file_writer.get_logdir() 133 | + '/ckpt_epoch{}.pth'.format(self.epoch)) 134 | 135 | for batch in tqdm(trainloader, desc='Training for epoch ' + str(self.epoch)): 136 | self.train_step(batch) 137 | self.step += 1 138 | self.update_learning_rate() 139 | self.epoch += 1 140 | 141 | torch.save({ 142 | 'it': self.step, 143 | 'model_state_dict': self.model.state_dict(), 144 | }, 145 | self.logger.file_writer.get_logdir() + '/latest_checkpoint.pth') 146 | return self.step 147 | 148 | def simple_test(self, testset, name="val"): 149 | self.model.eval() 150 | loader = testset.get_loader( 151 | batch_size=self.cfg.SOLVER.BATCH_SIZE, 152 | shuffle=False, 153 | drop_last=False, 154 | num_workers=self.cfg.DATA_LOADER.LOADER_NUM_WORKERS) 155 | metrics = defaultdict(list) 156 | for batch in tqdm(loader): 157 | with torch.no_grad(): 158 | loss_value, met_dict = self.model.compute_loss(*self.parse_batch(batch)) 159 | for key, val in met_dict.items(): 160 | metrics[key].append(val * len(batch)) 161 | metrics = {key: sum(val) / len(testset) for key, val in metrics.items()} 162 | output = [(f"{name}_{key}", val) for key, val in metrics.items()] 163 | return output 164 | 165 | def run_eval(self, 
eval_on_test=False): 166 | self.model.eval() 167 | # trainset = self.dataset_dict["train"] 168 | if eval_on_test: 169 | testset = self.dataset_dict["test"] 170 | else: 171 | testset = self.dataset_dict.get("val", self.dataset_dict["test"]) 172 | 173 | tests = [] 174 | 175 | tests = self.simple_test(testset, "val") 176 | 177 | try: 178 | special_subset, name = testset.special_subset() 179 | special_results = self.simple_test(special_subset, name=name) 180 | tests += [(metric_name, metric_value) 181 | for metric_name, metric_value in special_results] 182 | except AttributeError: 183 | pass 184 | 185 | for metric_name, metric_value in tests: 186 | self.logger.add_scalar(metric_name, metric_value, self.step) 187 | print(f' {metric_name}: {metric_value:.4f}') 188 | 189 | return tests 190 | 191 | def update_learning_rate(self): 192 | if self.scheduled_lr: 193 | if len(self.schedule_iters) > 0: 194 | if self.step > self.schedule_iters[0]: 195 | lr_factor = self.scheduled_rates[0] / self.current_lr 196 | for g in self.optimizer.param_groups: 197 | g['lr'] *= lr_factor 198 | self.current_lr = self.scheduled_rates[0] 199 | 200 | del self.schedule_iters[0] 201 | del self.scheduled_rates[0] 202 | else: 203 | # decay learing rate by old method 204 | decay = False 205 | if self.step >= self.cfg.SOLVER.LEARNING_RATE_DECAY_FREQUENCY: 206 | if self.step == self.cfg.SOLVER.LEARNING_RATE_DECAY_FREQUENCY: 207 | decay = True 208 | elif self.step % self.cfg.SOLVER.LEARNING_RATE_DECAY_FREQUENCY == 0: 209 | decay = not self.cfg.SOLVER.LR_DECAY_ONLY_ONCE 210 | if decay: 211 | for g in self.optimizer.param_groups: 212 | g['lr'] *= self.cfg.SOLVER.LEARNING_RATE_DECAY 213 | 214 | 215 | class MetricTrainer(Trainer): 216 | 217 | def parse_batch(self, batch): 218 | source_img = [dd['source_image'] for dd in batch] 219 | target_img = [dd['target_image'] for dd in batch] 220 | if source_img[0] is not None: 221 | if hasattr(self.model, "image_transform"): 222 | source_img = [self.model.image_transform(im) for im in source_img] 223 | source_img = torch.stack(source_img).to(self.model.device).float() 224 | if target_img[0] is not None: 225 | if hasattr(self.model, "image_transform"): 226 | target_img = [self.model.image_transform(im) for im in target_img] 227 | target_img = torch.stack(target_img).to(self.model.device).float() 228 | 229 | source_text = [dd["source_text"] for dd in batch] 230 | target_text = [dd["target_text"] for dd in batch] 231 | 232 | if "judgment" in batch[0]: 233 | judgments = [self.trainset.parse_judgment( 234 | dd["judgment"], loss=self.cfg.MODEL.LOSS) for dd in batch] 235 | else: 236 | judgments = [None for dd in batch] 237 | 238 | return (source_img, source_text, target_img, target_text, judgments) 239 | 240 | def run_eval(self, eval_on_test=False): 241 | self.model.eval() 242 | if eval_on_test: 243 | testset = self.dataset_dict["test"] 244 | else: 245 | testset = self.dataset_dict.get("val", self.dataset_dict["test"]) 246 | 247 | try: 248 | with torch.no_grad(): 249 | test_results = testset.evaluate(self.model, self.cfg) 250 | except AttributeError: 251 | with torch.no_grad(): 252 | test_results = self.metric_eval(testset, eval_on_test=eval_on_test) 253 | 254 | for metric_name, metric_value in test_results: 255 | self.logger.add_scalar(metric_name, metric_value, self.step) 256 | print(' ', metric_name, round(metric_value, 4)) 257 | 258 | return test_results 259 | 260 | 261 | def metric_eval(self, testset, eval_on_test=False): 262 | if self.cfg.DATASET.NAME in ["fashioniq"]: 263 | categ = 
self.cfg.DATASET.NAME == "fashioniq" 264 | test_results = eval_retrieval.test( 265 | self.cfg, self.model, testset, filter_categories=categ) 266 | else: 267 | print(f"No special validation for {self.cfg.DATASET.NAME};" 268 | "computing average validation loss") 269 | test_results = self.simple_test(testset, "val") 270 | 271 | return test_results 272 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright 2019 Google LLC 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | https://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /src/maaf/actions/eval_retrieval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Yahoo, Licensed under the terms of the Apache License, Version 2.0. 2 | # See LICENSE file in project root for terms. 
3 | 4 | 5 | """Evaluates the retrieval model.""" 6 | import numpy as np 7 | import torch 8 | from ..utils.misc_utils import tqdm # with dynamic_ncols=True 9 | import os 10 | import json 11 | 12 | 13 | def test(cfg, model, testset, filter_categories=False): 14 | if filter_categories: 15 | out = [] 16 | for category in testset.categories: 17 | print("Evaluating on", category) 18 | cat_out = _test(cfg, model, testset, category) 19 | out += [[category + name, val] for name, val in cat_out] 20 | else: 21 | out = _test(cfg, model, testset) 22 | 23 | if cfg.DATASET.NAME == "fashioniq": 24 | scores = [metric for name, metric in out if 25 | "top100" not in name and ("top10" in name or "top50" in name)] 26 | out += [["fiq_score", np.mean(scores)]] 27 | return out 28 | 29 | 30 | def _test(cfg, model, testset, category=None): 31 | """Tests a model over the given testset.""" 32 | model.eval() 33 | 34 | all_queries = compute_query_features(cfg, model, testset, 35 | category=category) 36 | if hasattr(testset, "test_queries"): 37 | test_queries = testset.test_queries 38 | all_targets = [que["target_caption"] for que in testset.test_queries] 39 | else: 40 | test_queries = testset.data 41 | if category is None: 42 | all_targets = [tq['target_id'] for tq in test_queries] 43 | else: 44 | all_targets = [testset.gallery[tq['target_id']]["category"][category] 45 | for tq in testset.data_by_category[category]] 46 | 47 | # compute all gallery features (within category if applicable) 48 | gallery, all_labels = \ 49 | compute_db_features(cfg, model, testset, category=category) 50 | if hasattr(testset, "test_queries"): 51 | all_labels = [dd["captions"][0] for dd in testset.data] 52 | 53 | nn_result, sorted_sims = nn_and_sims(testset, all_queries, gallery, 54 | all_labels, category=category) 55 | 56 | # compute recalls 57 | out = [] 58 | for k in [1, 5, 10, 50, 100]: 59 | recall = 0.0 60 | for i, nns in enumerate(nn_result): 61 | if all_targets[i] in nns[:k]: 62 | recall += 1 63 | recall /= len(nn_result) 64 | out += [('_test_recall_top' + str(k), recall)] 65 | 66 | if cfg.DATASET.NAME == "fashioniq": 67 | dirname = os.path.join(cfg.OUTPUT_DIR, cfg.EXP_NAME, "val") 68 | if not os.path.exists(dirname): 69 | os.makedirs(dirname) 70 | fname = os.path.join(dirname, "{}.predict.json".format(category)) 71 | write_fashioniq(nn_result, testset, sorted_sims, fname, category) 72 | 73 | return out 74 | 75 | 76 | def nn_and_sims(testset, all_queries, gallery, 77 | all_labels, category=None): 78 | # match test queries to target images, get nearest neighbors 79 | sims = all_queries.dot(gallery.T) 80 | data = testset.data if category is None else testset.data_by_category[category] 81 | for ii, entry in enumerate(data): 82 | if "source_id" in entry: 83 | source_id = entry["source_id"] 84 | if category is not None: 85 | # get index within category 86 | source_id = testset.gallery.gallery[source_id]["category"][category] 87 | sims[ii, source_id] = -10e10 # remove query image 88 | 89 | nn_result = [np.argsort(-sims[i, :]) for i in range(sims.shape[0])] 90 | nn_result = [[all_labels[nn] for nn in nns] for nns in nn_result] 91 | 92 | sorted_sims = [np.sort(sims[ii, :])[::-1] for ii in range(sims.shape[0])] 93 | 94 | return nn_result, sorted_sims 95 | 96 | def compute_db_features(cfg, model, testset, category=None): 97 | """Compute all gallery features.""" 98 | all_feat = [] 99 | if category is None: 100 | loader = testset.get_gallery_loader( 101 | batch_size=cfg.SOLVER.BATCH_SIZE, 102 | num_workers=cfg.DATA_LOADER.LOADER_NUM_WORKERS) 103 
| else: 104 | loader = testset.get_gallery_loader( 105 | batch_size=cfg.SOLVER.BATCH_SIZE, 106 | num_workers=cfg.DATA_LOADER.LOADER_NUM_WORKERS, 107 | category=category) 108 | 109 | for batch in tqdm(loader): 110 | images = [dd["target_image"] for dd in batch] 111 | if len(images) > 0: 112 | if hasattr(model, "image_transform"): 113 | images = [model.image_transform(im) for im in images] 114 | images = torch.stack(images).float().to(model.device) 115 | texts = [dd["target_text"] for dd in batch] 116 | 117 | emb = model(images, texts).data.cpu().numpy() 118 | all_feat += [emb] 119 | 120 | all_feat = np.concatenate(all_feat) 121 | all_labels = list(range(len(testset.gallery))) 122 | return all_feat, all_labels 123 | 124 | 125 | def compute_query_features(cfg, model, testset, category=None): 126 | if category is None: 127 | loader = testset.get_loader( 128 | batch_size=cfg.SOLVER.BATCH_SIZE, 129 | shuffle=False, 130 | drop_last=False, 131 | num_workers=cfg.DATA_LOADER.LOADER_NUM_WORKERS) 132 | else: 133 | loader = testset.get_loader( 134 | batch_size=cfg.SOLVER.BATCH_SIZE, 135 | shuffle=False, 136 | drop_last=False, 137 | num_workers=cfg.DATA_LOADER.LOADER_NUM_WORKERS, 138 | category=category) 139 | 140 | all_queries = [] 141 | 142 | # compute query/source features 143 | for batch in tqdm(loader): 144 | source_img = [dd['source_image'] for dd in batch] 145 | if len(source_img) > 0: 146 | if hasattr(model, "image_transform"): 147 | source_img = [model.image_transform(im) for im in source_img] 148 | source_img = torch.stack(source_img).float().to(model.device) 149 | source_text = [dd["source_text"] for dd in batch] 150 | 151 | query_emb = model(source_img, source_text).data.cpu().numpy() 152 | all_queries += [query_emb] 153 | 154 | return np.concatenate(all_queries) 155 | 156 | 157 | def predict(cfg, model, testset, filter_categories=False): 158 | if filter_categories: 159 | for category in testset.categories: 160 | print("Evaluating on ", category) 161 | _predict(cfg, model, testset, category) 162 | else: 163 | _predict(cfg, model, testset) 164 | 165 | 166 | def _predict(cfg, model, testset, category=None): 167 | model.eval() 168 | 169 | all_queries = compute_query_features(cfg, model, testset) 170 | 171 | gallery, all_labels = \ 172 | compute_db_features(cfg, model, testset, category=category) 173 | 174 | nn_result, sorted_sims = nn_and_sims(testset, all_queries, gallery, 175 | all_labels, category=category) 176 | 177 | 178 | if cfg.DATASET.NAME == "fashioniq": 179 | dirname = os.path.join(cfg.OUTPUT_DIR, cfg.EXP_NAME, "test") 180 | if not os.path.exists(dirname): 181 | os.makedirs(dirname) 182 | fname = os.path.join(dirname, "{}.predict.json".format(category)) 183 | write_fashioniq(nn_result, testset, sorted_sims, fname, category) 184 | 185 | 186 | def write_fashioniq(results, testset, scores, fname, category, 187 | num_to_keep=100): 188 | try: 189 | output = [] 190 | for que, res, sc in zip(testset.data_by_category[category], results, scores): 191 | que_asin = testset.gallery.gallery[que["source_id"]]["asin"] 192 | res_asin = [testset.gallery.gallery_by_cat[category].gallery[res[ii]]["asin"] 193 | for ii in range(num_to_keep)] 194 | entry = {"candidate": str(que_asin), 195 | "ranking": [str(ra) for ra in res_asin], 196 | "scores": sc.tolist()[:num_to_keep]} 197 | output.append(entry) 198 | 199 | with open(fname, "w") as fh: 200 | json.dump(output, fh) 201 | except BaseException: 202 | print("Error in write_fashioniq") 203 | import IPython; IPython.embed() 204 | else: 205 | print("wrote to", 
fname) 206 | 207 | 208 | def ndcg(relevances, all_bad=1): 209 | """When all judgments are 0, return all_bad. Otherwise, return NDCG.""" 210 | dcg = relevances / np.log2(np.arange(len(relevances)) + 2) 211 | dcg = np.sum(dcg) 212 | sorted_rel = np.sort(relevances)[::-1] 213 | max_dcg = sorted_rel / np.log2(np.arange(len(sorted_rel)) + 2) 214 | max_dcg = np.sum(max_dcg) 215 | if max_dcg == 0: 216 | return all_bad 217 | else: 218 | return dcg / max_dcg 219 | 220 | 221 | def test_ndcg(cfg, model, dataset_dict): 222 | testset = dataset_dict["test"] 223 | all_query_emb = compute_query_features(cfg, model, testset) 224 | gallery, all_labels = compute_db_features(cfg, model, testset) 225 | 226 | ndcg_values = [] 227 | for query_data in [testset.head_query_data, testset.random_query_data]: 228 | aggregate_ndcg = [] 229 | for query, entries in query_data.items(): 230 | gallery_indices = [ent["target_id"] for ent in entries] 231 | gallery_embs = [gallery[idx] for idx in gallery_indices] 232 | query_emb = all_query_emb[testset.query_to_index[query]] 233 | dots = np.dot(gallery_embs, query_emb) 234 | sorter = np.argsort(dots)[::-1] 235 | judgments = [entries[ii]["judgment"] for ii in sorter] 236 | relevances = [0 if jj == "Bad" else 1 for jj in judgments] 237 | this_ndcg = ndcg(np.array(relevances)) 238 | aggregate_ndcg.append(this_ndcg) 239 | 240 | ndcg_values.append(np.mean(aggregate_ndcg)) 241 | 242 | out = [("head_queries_ndcg", ndcg_values[0]), 243 | ("random_queries_ndcg", ndcg_values[1])] 244 | 245 | return out 246 | 247 | def test_paired(testset, cfg, model): 248 | """ 249 | Retrieval evaluation suitable when data is entirely source-target pairs. 250 | Compute recall@k for several k, and relative rank of pairs. 251 | """ 252 | model.eval() 253 | 254 | # compute query/source and target features 255 | all_queries = [] 256 | all_targets = [] 257 | loader = testset.get_loader(cfg.SOLVER.BATCH_SIZE) 258 | for batch in tqdm(loader): 259 | source_img = [dd['source_image'] for dd in batch] 260 | if source_img[0] is not None: 261 | if hasattr(model, "image_transform"): 262 | source_img = [model.image_transform(im) for im in source_img] 263 | source_img = torch.stack(source_img).to(model.device).float() 264 | source_text = [dd["source_text"] for dd in batch] 265 | 266 | query_emb = model(source_img, source_text).data.cpu().numpy() 267 | all_queries += [query_emb] 268 | 269 | target_img = [dd['target_image'] for dd in batch] 270 | if target_img[0] is not None: 271 | if hasattr(model, "image_transform"): 272 | target_img = [model.image_transform(im) for im in target_img] 273 | target_img = torch.stack(target_img).to(model.device).float() 274 | target_text = [dd["target_text"] for dd in batch] 275 | 276 | target_emb = model(target_img, target_text).data.cpu().numpy() 277 | all_targets += [target_emb] 278 | 279 | all_queries = np.concatenate(all_queries) 280 | all_targets = np.concatenate(all_targets) 281 | 282 | # compute similarities and nearest neighbors 283 | sims = all_queries.dot(all_targets.T) 284 | nn_result = [np.argsort(-sims[i, :]) for i in range(sims.shape[0])] 285 | 286 | # compute recall@k 287 | out = [] 288 | for k in [1, 5, 10, 50, 100]: 289 | recall = 0.0 290 | for i, nns in enumerate(nn_result): 291 | if i in nns[:k]: 292 | recall += 1 293 | recall /= len(nn_result) 294 | out += [('_recall_top' + str(k), recall)] 295 | 296 | # compute ranks of matches 297 | ranks = [np.where(nn_result[ii] == ii)[0][0] for ii in range(len(nn_result))] 298 | relative = np.array(ranks) / len(nn_result) 299 | 
out += [('_mean_rel_rank', relative.mean())] 300 | out += [('_median_rel_rank', np.median(relative))] 301 | 302 | return out 303 | -------------------------------------------------------------------------------- /src/maaf/models/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Yahoo, Licensed under the terms of the Apache License, Version 2.0. 2 | # See LICENSE file in project root for terms. 3 | 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | import copy 10 | import math 11 | import numpy as np 12 | 13 | 14 | def clones(module, N): 15 | "Produce N identical layers." 16 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) 17 | 18 | 19 | class LayerNorm(nn.Module): 20 | "Construct a layernorm module (See citation for details)." 21 | def __init__(self, features, eps=1e-6): 22 | super(LayerNorm, self).__init__() 23 | self.a_2 = nn.Parameter(torch.ones(features)) 24 | self.b_2 = nn.Parameter(torch.zeros(features)) 25 | self.eps = eps 26 | 27 | def forward(self, x): 28 | mean = x.mean(-1, keepdim=True) 29 | std = x.std(-1, keepdim=True) 30 | # import IPython; IPython.embed() 31 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 32 | 33 | 34 | def subsequent_mask(size): 35 | "Mask out subsequent positions." 36 | attn_shape = (1, size, size) 37 | subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8') 38 | return torch.from_numpy(subsequent_mask) == 0 39 | 40 | 41 | def attention(query, key, value, mask=None, dropout=None, 42 | softmax_replacement=None): 43 | "Compute 'Scaled Dot Product Attention'" 44 | d_k = query.size(-1) 45 | scores = torch.matmul(query, key.transpose(-2, -1)) \ 46 | / math.sqrt(d_k) 47 | 48 | if softmax_replacement is not None: 49 | scores = scores.masked_fill(mask == 0, 0.) 50 | p_attn = softmax_replacement(scores) 51 | else: 52 | if mask is not None: 53 | scores = scores.masked_fill(mask == 0, -1e9) 54 | p_attn = F.softmax(scores, dim = -1) 55 | if dropout is not None: 56 | p_attn = dropout(p_attn) 57 | return torch.matmul(p_attn, value), p_attn 58 | 59 | class MultiHeadedAttention(nn.Module): 60 | def __init__(self, h, d_model, dropout=0.1, softmax_replacement=None): 61 | "Take in model size and number of heads." 62 | super(MultiHeadedAttention, self).__init__() 63 | assert d_model % h == 0 64 | # We assume d_v always equals d_k 65 | self.d_k = d_model // h 66 | self.h = h 67 | self.linears = clones(nn.Linear(d_model, d_model), 4) 68 | self.attn = None 69 | self.dropout = nn.Dropout(p=dropout) 70 | self.softmax_replacement = softmax_replacement 71 | 72 | def forward(self, query, key, value, mask=None): 73 | "Implements Figure 2" 74 | if mask is not None: 75 | # Same mask applied to all h heads. 76 | mask = mask.unsqueeze(1) 77 | # Same mask applied to all 49 visual tokens 78 | mask = mask.unsqueeze(1) 79 | nbatches = query.size(0) 80 | # 1) Do all the linear projections in batch from d_model => h x d_k 81 | query, key, value = \ 82 | [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) 83 | for l, x in zip(self.linears, (query, key, value))] 84 | 85 | # 2) Apply attention on all the projected vectors in batch. 86 | x, self.attn = attention(query, key, value, mask=mask, 87 | dropout=self.dropout, 88 | softmax_replacement=self.softmax_replacement) 89 | 90 | # 3) "Concat" using a view and apply a final linear. 
91 | x = x.transpose(1, 2).contiguous() \ 92 | .view(nbatches, -1, self.h * self.d_k) 93 | return self.linears[-1](x) 94 | 95 | class PositionwiseFeedForward(nn.Module): 96 | "Implements FFN equation." 97 | def __init__(self, d_model, d_ff, dropout=0.1): 98 | super(PositionwiseFeedForward, self).__init__() 99 | self.w_1 = nn.Linear(d_model, d_ff) 100 | self.w_2 = nn.Linear(d_ff, d_model) 101 | self.dropout = nn.Dropout(dropout) 102 | 103 | def forward(self, x): 104 | return self.w_2(self.dropout(F.relu(self.w_1(x)))) 105 | 106 | 107 | class Embeddings(nn.Module): 108 | def __init__(self, d_model, vocab): 109 | super(Embeddings, self).__init__() 110 | self.lut = nn.Embedding(vocab, d_model) 111 | self.d_model = d_model 112 | 113 | def forward(self, x): 114 | return self.lut(x) * math.sqrt(self.d_model) 115 | 116 | class PositionalEncoding(nn.Module): 117 | "Implement the PE function." 118 | def __init__(self, d_model, dropout, max_len=5000): 119 | super(PositionalEncoding, self).__init__() 120 | self.dropout = nn.Dropout(p=dropout) 121 | 122 | # Compute the positional encodings once in log space. 123 | pe = torch.zeros(max_len, d_model) 124 | position = torch.arange(0, max_len).unsqueeze(1).float() 125 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * 126 | -(math.log(10000.0) / d_model)) 127 | pe[:, 0::2] = torch.sin(position * div_term) 128 | pe[:, 1::2] = torch.cos(position * div_term) 129 | pe = pe.unsqueeze(0) 130 | self.register_buffer('pe', pe) 131 | 132 | def forward(self, x): 133 | x = x + self.pe[:, :x.size(1)] 134 | return self.dropout(x) 135 | 136 | 137 | class PositionalDecoder(nn.Module): 138 | """ 139 | Add positioning infomation and decode, without an encoder. 140 | """ 141 | def __init__(self, decoder, pos_embed): 142 | super(PositionalDecoder, self).__init__() 143 | self.decoder = decoder 144 | self.pos_embed = pos_embed 145 | 146 | def forward(self, src, src_mask, tgt, tgt_mask): 147 | "Take in and process masked src and target sequences." 148 | return self.decode(src, src_mask, tgt, tgt_mask) 149 | 150 | def decode(self, src, src_mask, tgt, tgt_mask): 151 | return self.decoder(self.pos_embed(src), tgt, tgt_mask, src_mask) 152 | 153 | 154 | class PositionalEncoder(nn.Module): 155 | """ 156 | Add positioning infomation and encode, without an decoder. 157 | """ 158 | def __init__(self, encoder, pos_embed): 159 | super().__init__() 160 | self.encoder = encoder 161 | self.pos_embed = pos_embed 162 | 163 | def forward(self, xx, mask): 164 | "Take in and process masked sequence." 165 | return self.encoder(self.pos_embed(xx), mask) 166 | 167 | 168 | class EncoderDecoder(nn.Module): 169 | """ 170 | A standard Encoder-Decoder architecture. 171 | """ 172 | def __init__(self, encoder, decoder, src_embed, tgt_embed, generator): 173 | super(EncoderDecoder, self).__init__() 174 | self.encoder = encoder 175 | self.decoder = decoder 176 | self.src_embed = src_embed 177 | self.tgt_embed = tgt_embed 178 | self.generator = generator 179 | 180 | def forward(self, src, tgt, src_mask, tgt_mask): 181 | "Take in and process masked src and target sequences." 
182 | return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask) 183 | 184 | def encode(self, src, src_mask): 185 | return self.encoder(self.src_embed(src), src_mask) 186 | 187 | def decode(self, memory, src_mask, tgt, tgt_mask): 188 | return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask) 189 | 190 | 191 | class Encoder(nn.Module): 192 | "Core encoder is a stack of N layers" 193 | def __init__(self, layer, N): 194 | super(Encoder, self).__init__() 195 | self.layers = clones(layer, N) 196 | self.norm = LayerNorm(layer.size) 197 | 198 | def forward(self, x, mask): 199 | "Pass the input (and mask) through each layer in turn." 200 | for layer in self.layers: 201 | x = layer(x, mask) 202 | return self.norm(x) 203 | 204 | 205 | class SublayerConnection(nn.Module): 206 | """ 207 | A residual connection followed by a layer norm. 208 | Note for code simplicity the norm is first as opposed to last. 209 | """ 210 | def __init__(self, size, dropout): 211 | super(SublayerConnection, self).__init__() 212 | self.norm = LayerNorm(size) 213 | self.dropout = nn.Dropout(dropout) 214 | 215 | def forward(self, x, sublayer): 216 | "Apply residual connection to any sublayer with the same size." 217 | return x + self.dropout(sublayer(self.norm(x))) 218 | 219 | 220 | class EncoderLayer(nn.Module): 221 | "Encoder is made up of self-attn and feed forward (defined below)" 222 | def __init__(self, size, self_attn, feed_forward, dropout): 223 | super(EncoderLayer, self).__init__() 224 | self.self_attn = self_attn 225 | self.feed_forward = feed_forward 226 | self.sublayer = clones(SublayerConnection(size, dropout), 2) 227 | self.size = size 228 | 229 | def forward(self, x, mask): 230 | "Follow Figure 1 (left) for connections." 231 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) 232 | return self.sublayer[1](x, self.feed_forward) 233 | 234 | 235 | class Decoder(nn.Module): 236 | """Generic N layer decoder with masking.""" 237 | def __init__(self, layer, N): 238 | super(Decoder, self).__init__() 239 | self.layers = clones(layer, N) 240 | self.norm = LayerNorm(layer.size) 241 | 242 | def forward(self, x, memory, src_mask, tgt_mask): 243 | for layer in self.layers: 244 | x = layer(x, memory, src_mask, tgt_mask) 245 | return self.norm(x) 246 | 247 | 248 | class DecoderLayer(nn.Module): 249 | "Decoder is made of self-attn, src-attn, and feed forward (defined below)" 250 | def __init__(self, size, self_attn, src_attn, feed_forward, dropout): 251 | super(DecoderLayer, self).__init__() 252 | self.size = size 253 | self.self_attn = self_attn 254 | self.src_attn = src_attn 255 | self.feed_forward = feed_forward 256 | self.sublayer = clones(SublayerConnection(size, dropout), 3) 257 | 258 | def forward(self, x, memory, src_mask, tgt_mask): 259 | "Follow Figure 1 (right) for connections." 
260 | m = memory 261 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask)) 262 | x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask)) 263 | return self.sublayer[2](x, self.feed_forward) 264 | 265 | 266 | class SymmetricDecoder(nn.Module): 267 | """Like Decoder but allows 'memory' to be modified and passed around.""" 268 | def __init__(self, layer, N): 269 | super().__init__() 270 | self.layers = clones(layer, N) 271 | self.norm_x = LayerNorm(layer.size) 272 | self.norm_m = LayerNorm(layer.size) 273 | 274 | def forward(self, x, memory, src_mask, tgt_mask): 275 | for layer in self.layers: 276 | x, memory = layer(x, memory, src_mask, tgt_mask) 277 | return self.norm_x(x), self.norm_m(memory) 278 | 279 | 280 | class FlexibleDecoderLayer(nn.Module): 281 | """Decoder is made of self-attn, src-attn, and feed-forward 282 | mode format: 283 | xxx : self attn on x 284 | mmm : self attn on m 285 | xmm : x is query, m is keys and values 286 | xff: feed_forward on x 287 | xmm.mxx: cross-attn 'in parallel' with x, m inputs the same 288 | (i.e., mxx does not use output of xmm) 289 | 290 | Separate by _ for a sequence of operations. 291 | 292 | 293 | DecoderLayer is equivalent to xxx_xmm_xff, 294 | except that this layer returns (x, m) rather than just x 295 | """ 296 | def __init__(self, size, self_attn, src_attn, feed_forward, dropout, 297 | mode="xxx_xmm_xff"): 298 | super().__init__() 299 | self.mode = mode 300 | self.size = size 301 | self.self_attn_x = self_attn 302 | self.src_attn_x = src_attn 303 | self.feed_forward_x = feed_forward 304 | self.self_attn_m = copy.deepcopy(self_attn) 305 | self.src_attn_m = copy.deepcopy(src_attn) 306 | self.feed_forward_m = copy.deepcopy(feed_forward) 307 | self.sublayer = clones(SublayerConnection(size, dropout), 308 | len(self.mode.split("_"))) 309 | 310 | def get_sublayer(self, x, m, src_mask, tgt_mask, index, spec): 311 | sublayer = self.sublayer[index] 312 | if spec == "xxx": 313 | x = sublayer(x, lambda y: self.self_attn_x(y, y, y, tgt_mask)) 314 | elif spec == "xmm": 315 | x = sublayer(x, lambda y: self.src_attn_x(y, m, m, src_mask)) 316 | elif spec == "mmm": 317 | m = sublayer(m, lambda y: self.self_attn_m(y, y, y, src_mask)) 318 | elif spec == "mxx": 319 | m = sublayer(m, lambda y: self.src_attn_m(y, x, x, tgt_mask)) 320 | elif spec == "xff": 321 | x = sublayer(x, self.feed_forward_x) 322 | elif spec == "mff": 323 | m = sublayer(m, self.feed_forward_m) 324 | elif spec == "xmm.mxx": 325 | x_temp = sublayer(x, lambda y: self.src_attn_x(y, m, m, src_mask)) 326 | m = sublayer(m, lambda y: self.src_attn_m(y, x, x, tgt_mask)) 327 | x = x_temp 328 | else: 329 | raise ValueError("Invalid attn_2stream_mode") 330 | return x, m 331 | 332 | def forward(self, x, memory, src_mask, tgt_mask): 333 | m = memory 334 | for ii, spec in enumerate(self.mode.split("_")): 335 | x, m = self.get_sublayer(x, m, src_mask, tgt_mask, ii, spec) 336 | return x, m 337 | --------------------------------------------------------------------------------
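
The mode string documented in FlexibleDecoderLayer composes per-layer operations on the two token streams (x and m), and SymmetricDecoder stacks such layers while passing both streams along. Below is a minimal sketch of wiring these pieces together; the dimensions, the "xxx_mmm_xmm.mxx_xff_mff" mode, and running with the package installed (e.g. pip install -e . so that maaf.models.transformer is importable) are illustrative assumptions, not settings prescribed by the repository or its configs.

import copy

import torch

from maaf.models.transformer import (
    FlexibleDecoderLayer, MultiHeadedAttention,
    PositionwiseFeedForward, SymmetricDecoder)

d_model, n_heads, d_ff, dropout = 512, 8, 2048, 0.1

attn = MultiHeadedAttention(n_heads, d_model, dropout)
ff = PositionwiseFeedForward(d_model, d_ff, dropout)

# One layer: self-attention on each stream ("xxx", "mmm"), a parallel
# cross-attention exchange ("xmm.mxx"), then a feed-forward per stream.
layer = FlexibleDecoderLayer(
    d_model, copy.deepcopy(attn), copy.deepcopy(attn), ff, dropout,
    mode="xxx_mmm_xmm.mxx_xff_mff")
decoder = SymmetricDecoder(layer, N=2)

x = torch.randn(2, 7, d_model)   # e.g. text-token features
m = torch.randn(2, 49, d_model)  # e.g. image-token features
x_out, m_out = decoder(x, m, src_mask=None, tgt_mask=None)
print(x_out.shape, m_out.shape)  # torch.Size([2, 7, 512]) torch.Size([2, 49, 512])

Because each spec in the mode string draws on its own SublayerConnection, the standard DecoderLayer behaviour is recovered with mode="xxx_xmm_xff", except that the flexible layer also returns the (possibly updated) memory stream.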