├── .editorconfig ├── .gitignore ├── .vocab_cache ├── char.json ├── coco_precomp.json ├── complete_precomp.json ├── f30k_precomp.json ├── glove_840B_coco_precomp.json.pkl ├── glove_840B_f30k_precomp.json.pkl ├── glove_coco_precomp.json.pkl ├── glove_f30k_precomp.json.pkl ├── jap_precomp.json ├── w2v_coco_precomp.json.pkl └── w2v_f30k_precomp.json.pkl ├── .vscode ├── launch.json └── settings.json ├── README.md ├── assets ├── adapt.png ├── adapt_poster.pdf └── adapt_poster.png ├── extract_features.py ├── options ├── adapt │ ├── abstract.yaml │ ├── coco │ │ ├── i2t.yaml │ │ └── t2i.yaml │ ├── f30k │ │ ├── i2t.yaml │ │ └── t2i.yaml │ └── latent_size │ │ ├── adapt_t2i_l1024.yaml │ │ ├── adapt_t2i_l128.yaml │ │ ├── adapt_t2i_l1280.yaml │ │ ├── adapt_t2i_l1536.yaml │ │ ├── adapt_t2i_l256.yaml │ │ ├── adapt_t2i_l512.yaml │ │ ├── adapt_t2i_l64.yaml │ │ └── adapt_t2i_l768.yaml ├── scan-t2i │ ├── coco.yaml │ ├── f30k.yaml │ └── scan.yaml └── vsepp │ ├── coco.yaml │ ├── f30k.yaml │ └── vsepp.yaml ├── params.py ├── requirements.txt ├── results └── print_result.py ├── retrieval ├── __init__.py ├── data │ ├── __init__.py │ ├── adapters.py │ ├── collate_fns.py │ ├── dataiterator.py │ ├── datasets.py │ ├── loaders.py │ ├── preprocessing.py │ └── tokenizer.py ├── model │ ├── __init__.py │ ├── data_parallel.py │ ├── imgenc │ │ ├── __init__.py │ │ ├── common.py │ │ ├── factory.py │ │ ├── fullencoder.py │ │ ├── pooling.py │ │ └── precomp.py │ ├── layers │ │ ├── __init__.py │ │ ├── adapt.py │ │ ├── attention.py │ │ └── convblocks.py │ ├── loss.py │ ├── model.py │ ├── similarity │ │ ├── __init__.py │ │ ├── factory.py │ │ ├── measure.py │ │ └── similarity.py │ └── txtenc │ │ ├── __init__.py │ │ ├── embedding.py │ │ ├── factory.py │ │ ├── pooling.py │ │ └── txtenc.py ├── train │ ├── __init__.py │ ├── evaluation.py │ ├── lr_scheduler.py │ ├── optimizers.py │ ├── test.py │ └── train.py └── utils │ ├── __init__.py │ ├── file_utils.py │ ├── helper.py │ ├── layers.py │ ├── logger.py │ └── options.py ├── run.py ├── test.py ├── test_ens.py └── vocab └── align_vocab.py /.editorconfig: -------------------------------------------------------------------------------- 1 | # http://editorconfig.org 2 | 3 | root = true 4 | 5 | 6 | [*] 7 | end_of_line = lf 8 | insert_final_newline = true 9 | charset = utf-8 10 | 11 | [*.py] 12 | trim_trailing_whitespace = true 13 | indent_style = space 14 | indent_size = 4 15 | quote_type = single 16 | max_line_length = 100 17 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.swp 3 | *.ipynb_checkpoints 4 | # *.json 5 | *.pth.tar 6 | .DS_Store 7 | #data/ 8 | pycocotools/ 9 | runs/ 10 | temp/ 11 | .vocab_cache/ 12 | logs/ 13 | figures/ 14 | logs/ 15 | logs2/ 16 | logs_aaai/ 17 | logs_aaai2/ 18 | jupyter/correct_predictions/ 19 | jupyter/incorrect_predictions/ 20 | jupyter/t2i_imgs/ 21 | 22 | 23 | 24 | # Byte-compiled / optimized / DLL files 25 | __pycache__/ 26 | *.py[cod] 27 | *$py.class 28 | 29 | # C extensions 30 | *.so 31 | 32 | # Distribution / packaging 33 | .Python 34 | build/ 35 | develop-eggs/ 36 | dist/ 37 | downloads/ 38 | eggs/ 39 | .eggs/ 40 | lib/ 41 | lib64/ 42 | parts/ 43 | sdist/ 44 | var/ 45 | wheels/ 46 | pip-wheel-metadata/ 47 | share/python-wheels/ 48 | *.egg-info/ 49 | .installed.cfg 50 | *.egg 51 | MANIFEST 52 | 53 | # PyInstaller 54 | # Usually these files are written by a python script from a template 55 | # before PyInstaller 
builds the exe, so as to inject date/other infos into it. 56 | *.manifest 57 | *.spec 58 | 59 | # Installer logs 60 | pip-log.txt 61 | pip-delete-this-directory.txt 62 | 63 | # Unit test / coverage reports 64 | htmlcov/ 65 | .tox/ 66 | .nox/ 67 | .coverage 68 | .coverage.* 69 | .cache 70 | nosetests.xml 71 | coverage.xml 72 | *.cover 73 | .hypothesis/ 74 | .pytest_cache/ 75 | 76 | # Translations 77 | *.mo 78 | *.pot 79 | 80 | # Django stuff: 81 | *.log 82 | local_settings.py 83 | db.sqlite3 84 | 85 | # Flask stuff: 86 | instance/ 87 | .webassets-cache 88 | 89 | # Scrapy stuff: 90 | .scrapy 91 | 92 | # Sphinx documentation 93 | docs/_build/ 94 | 95 | # PyBuilder 96 | target/ 97 | 98 | # Jupyter Notebook 99 | .ipynb_checkpoints 100 | 101 | # IPython 102 | profile_default/ 103 | ipython_config.py 104 | 105 | # pyenv 106 | .python-version 107 | 108 | # pipenv 109 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 110 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 111 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not 112 | # install all needed dependencies. 113 | #Pipfile.lock 114 | 115 | # celery beat schedule file 116 | celerybeat-schedule 117 | 118 | # SageMath parsed files 119 | *.sage.py 120 | 121 | # Environments 122 | .env 123 | .venv 124 | env/ 125 | venv/ 126 | ENV/ 127 | env.bak/ 128 | venv.bak/ 129 | 130 | # Spyder project settings 131 | .spyderproject 132 | .spyproject 133 | 134 | # Rope project settings 135 | .ropeproject 136 | 137 | # mkdocs documentation 138 | /site 139 | 140 | # mypy 141 | .mypy_cache/ 142 | .dmypy.json 143 | dmypy.json 144 | 145 | # Pyre type checker 146 | .pyre/ 147 | -------------------------------------------------------------------------------- /.vocab_cache/char.json: -------------------------------------------------------------------------------- 1 | { 2 | "word2idx": { 3 | "": 0, 4 | "": 1, 5 | "": 2, 6 | "": 3, 7 | " ": 4, 8 | "T": 5, 9 | "w": 6, 10 | "o": 7, 11 | "y": 8, 12 | "u": 9, 13 | "n": 10, 14 | "g": 11, 15 | "s": 12, 16 | "i": 13, 17 | "t": 14, 18 | "h": 15, 19 | "a": 16, 20 | "r": 17, 21 | "l": 18, 22 | "k": 19, 23 | "e": 20, 24 | "d": 21, 25 | ".": 22, 26 | ",": 23, 27 | "W": 24, 28 | "m": 25, 29 | "b": 26, 30 | "A": 27, 31 | "f": 28, 32 | "j": 29, 33 | "p": 30, 34 | "S": 31, 35 | "v": 32, 36 | "c": 33, 37 | "q": 34, 38 | "F": 35, 39 | "x": 36, 40 | "'": 37, 41 | "-": 38, 42 | "G": 39, 43 | "B": 40, 44 | "J": 41, 45 | "D": 42, 46 | "Y": 43, 47 | "(": 44, 48 | ")": 45, 49 | "P": 46, 50 | "M": 47, 51 | "C": 48, 52 | "2": 49, 53 | "L": 50, 54 | "z": 51, 55 | "\"": 52, 56 | "5": 53, 57 | "O": 54, 58 | "?": 55, 59 | "4": 56, 60 | "N": 57, 61 | "V": 58, 62 | "I": 59, 63 | "1": 60, 64 | "3": 61, 65 | "K": 62, 66 | "H": 63, 67 | "0": 64, 68 | ";": 65, 69 | "R": 66, 70 | "E": 67, 71 | "X": 68, 72 | "!": 69, 73 | ":": 70, 74 | "&": 71, 75 | "7": 72, 76 | "U": 73, 77 | "6": 74, 78 | "9": 75, 79 | "Q": 76, 80 | "#": 77, 81 | "8": 78, 82 | "Z": 79, 83 | "$": 80, 84 | "`": 81, 85 | "%": 82, 86 | "\u00e4": 83, 87 | "\u00fc": 84, 88 | "\u00df": 85, 89 | "\u00f6": 86, 90 | "\u00dc": 87, 91 | "\u00d6": 88, 92 | "\u00c4": 89, 93 | ">": 90, 94 | "/": 91, 95 | "\u00b4": 92, 96 | "\u00e9": 93, 97 | "_": 94, 98 | "=": 95, 99 | "+": 96, 100 | "[": 97, 101 | "\\": 98, 102 | "@": 99, 103 | "]": 100 104 | }, 105 | "char_level": true, 106 | "max_len": null 107 | } 108 | 
-------------------------------------------------------------------------------- /.vocab_cache/glove_840B_coco_precomp.json.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jwehrmann/retrieval.pytorch/80569810cd5d84873afe29ab387275f14ab1d677/.vocab_cache/glove_840B_coco_precomp.json.pkl -------------------------------------------------------------------------------- /.vocab_cache/glove_840B_f30k_precomp.json.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jwehrmann/retrieval.pytorch/80569810cd5d84873afe29ab387275f14ab1d677/.vocab_cache/glove_840B_f30k_precomp.json.pkl -------------------------------------------------------------------------------- /.vocab_cache/glove_coco_precomp.json.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jwehrmann/retrieval.pytorch/80569810cd5d84873afe29ab387275f14ab1d677/.vocab_cache/glove_coco_precomp.json.pkl -------------------------------------------------------------------------------- /.vocab_cache/glove_f30k_precomp.json.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jwehrmann/retrieval.pytorch/80569810cd5d84873afe29ab387275f14ab1d677/.vocab_cache/glove_f30k_precomp.json.pkl -------------------------------------------------------------------------------- /.vocab_cache/w2v_coco_precomp.json.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jwehrmann/retrieval.pytorch/80569810cd5d84873afe29ab387275f14ab1d677/.vocab_cache/w2v_coco_precomp.json.pkl -------------------------------------------------------------------------------- /.vocab_cache/w2v_f30k_precomp.json.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jwehrmann/retrieval.pytorch/80569810cd5d84873afe29ab387275f14ab1d677/.vocab_cache/w2v_f30k_precomp.json.pkl -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | 8 | { 9 | "name": "Python: Current File (Integrated Terminal)", 10 | "type": "python", 11 | "request": "launch", 12 | "program": "${file}", 13 | "console": "integratedTerminal", 14 | "args": [ 15 | "--data_path", "/home/jonatas/data/retrieval/", 16 | "--train_data", "f30k_precomp.en", 17 | "--val_data", "f30k_precomp.en", 18 | "--outpath", "temp/", 19 | "--sim", "adaptive", 20 | "--valid_interval", "500", 21 | "--image_encoder", "hierarchical", 22 | "--text_encoder", "attngru", 23 | "--text_pooling", "none", 24 | "--image_pooling", "none", 25 | "--lr", "5e-4", 26 | "--beta", "0.999", 27 | ], 28 | }, 29 | { 30 | "name": "Python: Remote Attach", 31 | "type": "python", 32 | "request": "attach", 33 | "port": 5678, 34 | "host": "localhost", 35 | "pathMappings": [ 36 | { 37 | "localRoot": "${workspaceFolder}", 38 | "remoteRoot": "." 
39 | } 40 | ] 41 | }, 42 | { 43 | "name": "Python: Module", 44 | "type": "python", 45 | "request": "launch", 46 | "module": "enter-your-module-name-here", 47 | "console": "integratedTerminal" 48 | }, 49 | { 50 | "name": "Python: Django", 51 | "type": "python", 52 | "request": "launch", 53 | "program": "${workspaceFolder}/manage.py", 54 | "console": "integratedTerminal", 55 | "args": [ 56 | "runserver", 57 | "--noreload", 58 | "--nothreading" 59 | ], 60 | "django": true 61 | }, 62 | { 63 | "name": "Python: Flask", 64 | "type": "python", 65 | "request": "launch", 66 | "module": "flask", 67 | "env": { 68 | "FLASK_APP": "app.py" 69 | }, 70 | "args": [ 71 | "run", 72 | "--no-debugger", 73 | "--no-reload" 74 | ], 75 | "jinja": true 76 | }, 77 | { 78 | "name": "Python: Current File (External Terminal)", 79 | "type": "python", 80 | "request": "launch", 81 | "program": "${file}", 82 | "console": "externalTerminal" 83 | } 84 | ] 85 | } 86 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.pythonPath": "/home/jonatas/anaconda2/envs/retrieval/bin/python", 3 | "python.linting.pylintEnabled": true, 4 | "python.linting.enabled": true 5 | } 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Adaptive Cross-modal Embeddings for Image-Text Alignment (ADAPT) 2 | 3 | This code implements a novel approach for training image-text alignment models, namely ADAPT. 4 | 5 |
![ADAPT model overview](assets/adapt.png)
8 | 9 | 10 | ADAPT is designed to adjust an intermediate representation of instances from a modality _a_ using an embedding vector of an instance from modality _b_. Such an adaptation is designed to filter and enhance important information across internal features, allowing for guided vector representations – which resembles the working of attention modules, though far more computationally efficient. For further information, please read our [AAAI 2020 paper](https://www.researchgate.net/publication/337636199_Adaptive_Cross-modal_Embeddings_for_Image-Text_Alignment). 11 | 12 | 13 | ## Table of Contents 14 | 15 | * [Installation](#installation) 16 | * [Quick start](#quickstart) 17 | * [Training models](#training) 18 | * [Pre-trained models](#pretrained) 19 | * [Citation](#citation) 20 | * [Poster](#poster) 21 | 22 | ## Installation 23 | 24 | 25 | ### 1. Python 3 & Anaconda 26 | 27 | We don't provide support for python 2. We advise you to install python 3 with [Anaconda](https://docs.anaconda.com/anaconda/install/) and then create an environment. 28 | 29 | ### 2. As standalone project 30 | 31 | ``` 32 | conda create --name adapt python=3 33 | conda activate adapt 34 | git clone https://github.com/jwehrmann/retrieval.pytorch 35 | cd retrieval.pytorch 36 | pip install -r requirements.txt 37 | ``` 38 | 39 | ### 3. Download datasets 40 | 41 | ``` 42 | wget https://scanproject.blob.core.windows.net/scan-data/data.zip 43 | ``` 44 | 45 | ## Quick start 46 | 47 | 48 | 49 | ### Setup 50 | 51 | * Option 1: 52 | 53 | ``` 54 | conda activate adapt 55 | export DATA_PATH=/path/to/dataset 56 | ``` 57 | 58 | * Option 2: 59 | 60 | You can also create a shell alias (shortcut to reference a command). For example, add this command to your shell profile: 61 | ``` 62 | alias adapt='source activate adapt && export DATA_PATH=/path/to/dataset' 63 | ``` 64 | 65 | And then only run the declared name of the alias to have everything configured: 66 | ``` 67 | $ adapt 68 | ``` 69 | 70 | ## Training Models 71 | 72 | 73 | You can reproduce our main results using the following scripts. 74 | 75 | * Training on Flickr30k: 76 | ``` 77 | python run.py options/adapt/f30k/t2i.yaml 78 | python test.py options/adapt/f30k/t2i.yaml -data_split test 79 | python run.py options/adapt/f30k/i2t.yaml 80 | python test.py options/adapt/f30k/i2t.yaml -data_split test 81 | ``` 82 | 83 | * Training on MS COCO: 84 | ``` 85 | python run.py options/adapt/coco/t2i.yaml 86 | python test.py options/adapt/coco/t2i.yaml -data_split test 87 | python run.py options/adapt/coco/i2t.yaml 88 | python test.py options/adapt/coco/i2t.yaml -data_split test 89 | ``` 90 | 91 | ### Ensembling results 92 | 93 | To ensemble multiple models (ADAPT-Ens) one can use: 94 | 95 | * MS COCO models: 96 | ``` 97 | python test_ens.py options/adapt/coco/t2i.yaml options/adapt/coco/i2t.yaml -data_split test 98 | ``` 99 | 100 | * F30k models: 101 | ``` 102 | python test_ens.py options/adapt/f30k/t2i.yaml options/adapt/f30k/i2t.yaml -data_split test 103 | ``` 104 | 105 | ### Pre-trained models 106 | 107 | 108 | We make available all the main models generated in this research. Each file has the best model of the run (according to validation result), the last checkpoint generated, all tensorboard logs (loss and recall curves), result files, and configuration options used for training. 
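As a quick sanity check, a released archive can be unpacked and its checkpoints inspected as sketched below. This is an illustrative snippet only: the folder layout and checkpoint filenames inside the archives are assumptions here (the snippet simply searches for PyTorch `*.pth.tar` files), so adjust the paths to whatever the extracted directory actually contains.

```
import glob
import tarfile
import torch

# Illustrative only: the exact layout of the released archives may differ.
with tarfile.open('f30k_adapt_t2i.tar') as tar:   # downloaded from the table below
    tar.extractall('f30k_adapt_t2i')

# Checkpoints are assumed to be regular PyTorch files (*.pth.tar); load them on
# CPU and inspect the stored keys before wiring them into test.py / test_ens.py.
for path in glob.glob('f30k_adapt_t2i/**/*.pth.tar', recursive=True):
    checkpoint = torch.load(path, map_location='cpu')
    keys = list(checkpoint.keys()) if isinstance(checkpoint, dict) else type(checkpoint)
    print(path, keys)
```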
109 | 110 | #### F30k and COCO models: 111 | 112 | | Dataset | Model | Image Annotation R@1 | Image Retrieval R@1 | 113 | |:--: | :--: | :--: | :--: | 114 | | F30k | [ADAPT-t2i](https://wehrmann.s3-us-west-2.amazonaws.com/adapt_models/f30k_adapt_t2i.tar) | 76.4% | 57.8% | 115 | | F30k | [ADAPT-i2t](https://wehrmann.s3-us-west-2.amazonaws.com/adapt_models/f30k_adapt_i2t.tar) | 66.3% | 53.8% | 116 | | F30k | ADAPT-ens | 76.2% | 60.5% | 117 | | COCO | [ADAPT-t2i](https://wehrmann.s3-us-west-2.amazonaws.com/adapt_models/coco_adapt_t2i.tar) | 75.4% | 64.0% | 118 | | COCO | [ADAPT-i2t](https://wehrmann.s3-us-west-2.amazonaws.com/adapt_models/coco_adapt_i2t.tar) | 67.2% | 57.8% | 119 | | COCO | ADAPT-ens | 75.3% | 64.4% | 120 | 121 | ## Citation 122 | 123 | 124 | If you find this research or code useful, please consider citing our paper: 125 | 126 | ``` 127 | @inproceedings{wehrmann2020adaptive, 128 | title={Adaptive Cross-modal Embeddings for Image-Text Alignment}, 129 | author={Wehrmann, J{\^o}natas and Kolling, Camila and Barros, Rodrigo C}, 130 | booktitle={The Thirty-Fourth AAAI Conference on Artificial Intelligence (AAAI 2020)}, 131 | year={2020} 132 | } 133 | ``` 134 | 135 | 136 | ## Poster 137 | 138 | 
![ADAPT poster](assets/adapt_poster.png)
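For a concrete picture of the adaptation step described at the top of this README, the sketch below conditions the intermediate features of one modality on an embedding vector of the other through learned feature-wise scaling and shifting. It is a simplified illustration under that assumption, not the repository's implementation: the actual layers live in `retrieval/model/layers/adapt.py` and the `adapt_t2i`/`adapt_i2t` similarity modules, and they differ in details.

```
import torch
import torch.nn as nn

class FeatureWiseAdapt(nn.Module):
    """Simplified illustration: rescale and shift the features of modality `a`
    using vectors predicted from an embedding of an instance of modality `b`."""

    def __init__(self, feat_dim, embed_dim):
        super().__init__()
        self.scale = nn.Linear(embed_dim, feat_dim)  # per-feature scaling vector
        self.shift = nn.Linear(embed_dim, feat_dim)  # per-feature shifting vector

    def forward(self, feats_a, embed_b):
        # feats_a: (n_items, feat_dim), e.g. region or word features of one instance
        # embed_b: (embed_dim,), embedding of one instance from the other modality
        gamma = self.scale(embed_b)      # (feat_dim,)
        beta = self.shift(embed_b)       # (feat_dim,)
        return feats_a * gamma + beta    # adapted features, same shape as feats_a

# Toy usage: 36 image-region features of size 2048 adapted by a 1024-d caption embedding.
adapt = FeatureWiseAdapt(feat_dim=2048, embed_dim=1024)
regions = torch.randn(36, 2048)
caption_embedding = torch.randn(1024)
print(adapt(regions, caption_embedding).shape)  # torch.Size([36, 2048])
```

In the repository, the adapted features then feed the similarity module selected in the `similarity` section of the configuration files under `options/`.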
142 | -------------------------------------------------------------------------------- /assets/adapt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jwehrmann/retrieval.pytorch/80569810cd5d84873afe29ab387275f14ab1d677/assets/adapt.png -------------------------------------------------------------------------------- /assets/adapt_poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jwehrmann/retrieval.pytorch/80569810cd5d84873afe29ab387275f14ab1d677/assets/adapt_poster.pdf -------------------------------------------------------------------------------- /assets/adapt_poster.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jwehrmann/retrieval.pytorch/80569810cd5d84873afe29ab387275f14ab1d677/assets/adapt_poster.png -------------------------------------------------------------------------------- /extract_features.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from tqdm import tqdm 4 | 5 | from params import get_extractfeats_params 6 | from retrieval.data.loaders import get_loader 7 | from run import load_model, get_data_path, get_tokenizers 8 | from retrieval.utils.logger import create_logger 9 | from retrieval.model.similarity.measure import l2norm 10 | from retrieval.data.collate_fns import default_padding 11 | from retrieval.utils.file_utils import load_yaml_opts, load_pickle 12 | 13 | 14 | if __name__ == '__main__': 15 | args = get_extractfeats_params() 16 | opt = load_yaml_opts(args.options) 17 | logger = create_logger(level='debug' if opt.engine.debug else 'info') 18 | 19 | logger.info(f'Used args : \n{args}') 20 | logger.info(f'Used options: \n{opt}') 21 | 22 | data_path = get_data_path(opt) 23 | 24 | loader = get_loader( 25 | data_split=args.data_split, 26 | data_path=data_path, 27 | data_info=opt.dataset.train.data, 28 | loader_name=opt.dataset.loader_name, 29 | local_rank=args.local_rank, 30 | text_repr=opt.dataset.text_repr, 31 | vocab_paths=opt.dataset.vocab_paths, 32 | ngpu=torch.cuda.device_count(), 33 | **opt.dataset.val, 34 | ) 35 | 36 | tokenizer = get_tokenizers(loader) 37 | 38 | path = args.captions_path 39 | outpath_file = args.outpath 40 | outpath_folder = os.path.dirname(args.outpath) 41 | file = load_pickle(path) 42 | model = load_model(opt, [tokenizer]) 43 | model.eval() 44 | 45 | with torch.no_grad(): 46 | outfile = {} 47 | for k, v in tqdm(file.items(), total=len(file)): 48 | tv, l = default_padding([tokenizer(x) for x in v]) 49 | batch = {'caption': (tv, l)} 50 | cap = l2norm(model.embed_captions(batch).cpu(), dim=-1) 51 | 52 | torch.save(cap, outpath_folder / f'{k}.pkl') 53 | outfile[k] = cap.cpu() 54 | torch.save(outfile, outpath_file) 55 | -------------------------------------------------------------------------------- /options/adapt/abstract.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | text_repr: word 3 | loader_name: precomp 4 | train: 5 | batch_size: 128 6 | workers: 1 7 | val: 8 | batch_size: 32 9 | limit: 5000 10 | adapt: 11 | data: [] 12 | model: 13 | latent_size: 1024 14 | freeze_modules: [model.txt_enc.embed.glove,] 15 | txt_enc: 16 | name: gru_glove 17 | params: 18 | embed_dim: 300 19 | use_bi_gru: true 20 | add_rand_embed: true 21 | pooling: none 22 | devices: [cuda,] 23 | img_enc: 24 | name: simple 25 | params: 
26 | img_dim: 2048 27 | devices: [cuda,] 28 | pooling: none 29 | similarity: 30 | name: adapt_t2i 31 | params: 32 | latent_size: 1024 33 | gamma: 10 34 | train_gamma: False 35 | device: cuda 36 | device: cuda 37 | optimizer: 38 | name: adamax 39 | params: 40 | lr: 0.001 # 7e-4 41 | gradual_warmup_steps: [0.5, 2.0, 4000] #torch.linspace 42 | lr_decay_epochs: [10000, 20000, 3000] #range 43 | lr_decay_rate: .25 44 | lr_scheduler: 45 | name: null 46 | params: 47 | step_size: 1000 48 | gamma: 1 49 | grad_clip: 2. 50 | criterion: 51 | margin: 0.2 52 | max_violation: False 53 | beta: 0.991 54 | engine: 55 | eval_before_training: False 56 | print_freq: 10 57 | nb_epochs: 30 58 | early_stop: 20 59 | valid_interval: 500 60 | -------------------------------------------------------------------------------- /options/adapt/coco/i2t.yaml: -------------------------------------------------------------------------------- 1 | __include__: ../abstract.yaml 2 | exp: 3 | outpath: logs/coco_precomp/adapt_i2t/ 4 | resume: null # last, best_[...], or empty (from scratch) 5 | dataset: 6 | vocab_paths: [.vocab_cache/coco_precomp.json] 7 | train: 8 | data: coco_precomp.en 9 | val: 10 | data: [coco_precomp.en] 11 | model: 12 | txt_enc: 13 | params: 14 | glove_path: '.vocab_cache/glove_coco_precomp.json.pkl' 15 | similarity: 16 | name: adapt_i2t 17 | params: 18 | latent_size: 1024 19 | gamma: 5 20 | optimizer: 21 | name: adamax 22 | params: 23 | lr: 0.0007 24 | gradual_warmup_steps: [0.5, 2.0, 16000] #torch.linspace 25 | lr_decay_epochs: [40000, 80000, 8000] #range 26 | lr_decay_rate: .25 27 | lr_scheduler: 28 | name: null 29 | grad_clip: 2. 30 | -------------------------------------------------------------------------------- /options/adapt/coco/t2i.yaml: -------------------------------------------------------------------------------- 1 | __include__: ../abstract.yaml 2 | exp: 3 | outpath: logs/coco_precomp/adapt_t2i/ 4 | resume: null # last, best_[...], or empty (from scratch) 5 | dataset: 6 | vocab_paths: [.vocab_cache/coco_precomp.json] 7 | train: 8 | data: coco_precomp.en 9 | val: 10 | data: [coco_precomp.en] 11 | model: 12 | txt_enc: 13 | params: 14 | glove_path: '.vocab_cache/glove_coco_precomp.json.pkl' 15 | similarity: 16 | name: adapt_t2i 17 | params: 18 | latent_size: 1024 19 | gamma: 5 20 | optimizer: 21 | name: adamax 22 | params: 23 | lr: 0.0007 24 | gradual_warmup_steps: [0.5, 2.0, 16000] #torch.linspace 25 | lr_decay_epochs: [40000, 80000, 8000] #range 26 | lr_decay_rate: .25 27 | lr_scheduler: 28 | name: null 29 | grad_clip: 2. 
30 | -------------------------------------------------------------------------------- /options/adapt/f30k/i2t.yaml: -------------------------------------------------------------------------------- 1 | __include__: ../abstract.yaml 2 | exp: 3 | outpath: logs/f30k_precomp/adapt_i2t/ 4 | dataset: 5 | vocab_paths: [.vocab_cache/f30k_precomp.json,] 6 | train: 7 | data: f30k_precomp.en 8 | workers: 1 9 | batch_size: 128 10 | val: 11 | data: [f30k_precomp.en] 12 | workers: 1 13 | batch_size: 32 14 | limit: 5000 15 | model: 16 | txt_enc: 17 | params: 18 | glove_path: '.vocab_cache/glove_f30k_precomp.json.pkl' 19 | similarity: 20 | name: adapt_i2t 21 | params: 22 | latent_size: 1024 23 | k: 1 24 | gamma: 10 25 | train_gamma: False 26 | device: cuda 27 | device: cuda 28 | -------------------------------------------------------------------------------- /options/adapt/f30k/t2i.yaml: -------------------------------------------------------------------------------- 1 | __include__: ../abstract.yaml 2 | exp: 3 | outpath: logs/f30k_precomp/adapt_t2i/ 4 | dataset: 5 | vocab_paths: [.vocab_cache/f30k_precomp.json,] 6 | train: 7 | data: f30k_precomp.en 8 | workers: 1 9 | batch_size: 105 10 | val: 11 | data: [f30k_precomp.en] 12 | workers: 1 13 | batch_size: 32 14 | limit: 5000 15 | model: 16 | txt_enc: 17 | params: 18 | glove_path: '.vocab_cache/glove_f30k_precomp.json.pkl' 19 | similarity: 20 | name: adapt_t2i 21 | params: 22 | latent_size: 1024 23 | k: 1 24 | gamma: 10 25 | train_gamma: False 26 | device: cuda 27 | device: cuda 28 | -------------------------------------------------------------------------------- /options/adapt/latent_size/adapt_t2i_l1024.yaml: -------------------------------------------------------------------------------- 1 | __include__: ../abstract.yaml 2 | exp: 3 | outpath: logs/f30k_precomp/adapt_t2i/latent_1024/ 4 | resume: null # last, best_[...], or empty (from scratch) 5 | model: 6 | latent_size: 1024 7 | freeze_modules: [model.txt_enc.embed.glove,] 8 | txt_enc: 9 | name: gru_glove 10 | params: 11 | embed_dim: 300 12 | use_bi_gru: true 13 | glove_path: '.vocab_cache/glove_f30k_precomp.json.pkl' 14 | add_rand_embed: true 15 | pooling: none 16 | devices: [cuda,] 17 | img_enc: 18 | name: simple 19 | params: 20 | img_dim: 2048 21 | devices: [cuda,] 22 | pooling: none 23 | similarity: 24 | name: adapt_t2i 25 | params: 26 | latent_size: 1024 27 | k: 1 28 | gamma: 10 29 | norm_output: True 30 | device: cuda # FIXME 31 | device: cuda # FIXME 32 | optimizer: 33 | name: adamax 34 | params: 35 | lr: 0.001 # 7e-4 36 | gradual_warmup_steps: [0.5, 2.0, 4000] #torch.linspace 37 | lr_decay_epochs: [10000, 20000, 3000] #range 38 | lr_decay_rate: .25 39 | lr_scheduler: 40 | name: null 41 | params: 42 | step_size: 1000 43 | gamma: 1 44 | grad_clip: 2. 
45 | -------------------------------------------------------------------------------- /options/adapt/latent_size/adapt_t2i_l128.yaml: -------------------------------------------------------------------------------- 1 | __include__: ../abstract.yaml 2 | exp: 3 | outpath: logs/f30k_precomp/adapt_t2i/latent_128/ 4 | resume: null # last, best_[...], or empty (from scratch) 5 | model: 6 | latent_size: 128 7 | freeze_modules: [model.txt_enc.embed.glove,] 8 | txt_enc: 9 | name: gru_glove 10 | params: 11 | embed_dim: 300 12 | use_bi_gru: true 13 | glove_path: '.vocab_cache/glove_f30k_precomp.json.pkl' 14 | add_rand_embed: true 15 | pooling: none 16 | devices: [cuda,] 17 | img_enc: 18 | name: simple 19 | params: 20 | img_dim: 2048 21 | devices: [cuda,] 22 | pooling: none 23 | similarity: 24 | name: adapt_t2i 25 | params: 26 | latent_size: 128 27 | k: 1 28 | gamma: 10 29 | norm_output: True 30 | device: cuda # FIXME 31 | device: cuda # FIXME 32 | optimizer: 33 | name: adamax 34 | params: 35 | lr: 0.001 # 7e-4 36 | gradual_warmup_steps: [0.5, 2.0, 4000] #torch.linspace 37 | lr_decay_epochs: [10000, 20000, 3000] #range 38 | lr_decay_rate: .25 39 | lr_scheduler: 40 | name: null 41 | params: 42 | step_size: 1000 43 | gamma: 1 44 | grad_clip: 2. 45 | -------------------------------------------------------------------------------- /options/adapt/latent_size/adapt_t2i_l1280.yaml: -------------------------------------------------------------------------------- 1 | __include__: ../abstract.yaml 2 | exp: 3 | outpath: logs/f30k_precomp/adapt_t2i/latent_1280/ 4 | resume: null # last, best_[...], or empty (from scratch) 5 | model: 6 | latent_size: 1280 7 | freeze_modules: [model.txt_enc.embed.glove,] 8 | txt_enc: 9 | name: gru_glove 10 | params: 11 | embed_dim: 300 12 | use_bi_gru: true 13 | glove_path: '.vocab_cache/glove_f30k_precomp.json.pkl' 14 | add_rand_embed: true 15 | pooling: none 16 | devices: [cuda,] 17 | img_enc: 18 | name: simple 19 | params: 20 | img_dim: 2048 21 | devices: [cuda,] 22 | pooling: none 23 | similarity: 24 | name: adapt_t2i 25 | params: 26 | latent_size: 1280 27 | k: 1 28 | gamma: 10 29 | norm_output: True 30 | device: cuda # FIXME 31 | device: cuda # FIXME 32 | optimizer: 33 | name: adamax 34 | params: 35 | lr: 0.001 # 7e-4 36 | gradual_warmup_steps: [0.5, 2.0, 4000] #torch.linspace 37 | lr_decay_epochs: [10000, 20000, 3000] #range 38 | lr_decay_rate: .25 39 | lr_scheduler: 40 | name: null 41 | params: 42 | step_size: 1000 43 | gamma: 1 44 | grad_clip: 2. 
45 | -------------------------------------------------------------------------------- /options/adapt/latent_size/adapt_t2i_l1536.yaml: -------------------------------------------------------------------------------- 1 | __include__: ../abstract.yaml 2 | exp: 3 | outpath: logs/f30k_precomp/adapt_t2i/latent_1536/ 4 | resume: null # last, best_[...], or empty (from scratch) 5 | model: 6 | latent_size: 1536 7 | freeze_modules: [model.txt_enc.embed.glove,] 8 | txt_enc: 9 | name: gru_glove 10 | params: 11 | embed_dim: 300 12 | use_bi_gru: true 13 | glove_path: '.vocab_cache/glove_f30k_precomp.json.pkl' 14 | add_rand_embed: true 15 | pooling: none 16 | devices: [cuda,] 17 | img_enc: 18 | name: simple 19 | params: 20 | img_dim: 2048 21 | devices: [cuda,] 22 | pooling: none 23 | similarity: 24 | name: adapt_t2i 25 | params: 26 | latent_size: 1536 27 | k: 1 28 | gamma: 10 29 | norm_output: True 30 | device: cuda # FIXME 31 | device: cuda # FIXME 32 | optimizer: 33 | name: adamax 34 | params: 35 | lr: 0.001 # 7e-4 36 | gradual_warmup_steps: [0.5, 2.0, 4000] #torch.linspace 37 | lr_decay_epochs: [10000, 20000, 3000] #range 38 | lr_decay_rate: .25 39 | lr_scheduler: 40 | name: null 41 | params: 42 | step_size: 1000 43 | gamma: 1 44 | grad_clip: 2. 45 | -------------------------------------------------------------------------------- /options/adapt/latent_size/adapt_t2i_l256.yaml: -------------------------------------------------------------------------------- 1 | __include__: ../abstract.yaml 2 | exp: 3 | outpath: logs/f30k_precomp/adapt_t2i/latent_256/ 4 | resume: null # last, best_[...], or empty (from scratch) 5 | model: 6 | latent_size: 256 7 | freeze_modules: [model.txt_enc.embed.glove,] 8 | txt_enc: 9 | name: gru_glove 10 | params: 11 | embed_dim: 300 12 | use_bi_gru: true 13 | glove_path: '.vocab_cache/glove_f30k_precomp.json.pkl' 14 | add_rand_embed: true 15 | pooling: none 16 | devices: [cuda,] 17 | img_enc: 18 | name: simple 19 | params: 20 | img_dim: 2048 21 | devices: [cuda,] 22 | pooling: none 23 | similarity: 24 | name: adapt_t2i 25 | params: 26 | latent_size: 256 27 | k: 1 28 | gamma: 10 29 | norm_output: True 30 | device: cuda # FIXME 31 | device: cuda # FIXME 32 | optimizer: 33 | name: adamax 34 | params: 35 | lr: 0.001 # 7e-4 36 | gradual_warmup_steps: [0.5, 2.0, 4000] #torch.linspace 37 | lr_decay_epochs: [10000, 20000, 3000] #range 38 | lr_decay_rate: .25 39 | lr_scheduler: 40 | name: null 41 | params: 42 | step_size: 1000 43 | gamma: 1 44 | grad_clip: 2. 
45 | -------------------------------------------------------------------------------- /options/adapt/latent_size/adapt_t2i_l512.yaml: -------------------------------------------------------------------------------- 1 | __include__: ../abstract.yaml 2 | exp: 3 | outpath: logs/f30k_precomp/adapt_t2i/latent_512/ 4 | resume: null # last, best_[...], or empty (from scratch) 5 | model: 6 | latent_size: 512 7 | freeze_modules: [model.txt_enc.embed.glove,] 8 | txt_enc: 9 | name: gru_glove 10 | params: 11 | embed_dim: 300 12 | use_bi_gru: true 13 | glove_path: '.vocab_cache/glove_f30k_precomp.json.pkl' 14 | add_rand_embed: true 15 | pooling: none 16 | devices: [cuda,] 17 | img_enc: 18 | name: simple 19 | params: 20 | img_dim: 2048 21 | devices: [cuda,] 22 | pooling: none 23 | similarity: 24 | name: adapt_t2i 25 | params: 26 | latent_size: 512 27 | k: 1 28 | gamma: 10 29 | norm_output: True 30 | device: cuda # FIXME 31 | device: cuda # FIXME 32 | optimizer: 33 | name: adamax 34 | params: 35 | lr: 0.001 # 7e-4 36 | gradual_warmup_steps: [0.5, 2.0, 4000] #torch.linspace 37 | lr_decay_epochs: [10000, 20000, 3000] #range 38 | lr_decay_rate: .25 39 | lr_scheduler: 40 | name: null 41 | params: 42 | step_size: 1000 43 | gamma: 1 44 | grad_clip: 2. 45 | -------------------------------------------------------------------------------- /options/adapt/latent_size/adapt_t2i_l64.yaml: -------------------------------------------------------------------------------- 1 | __include__: ../abstract.yaml 2 | exp: 3 | outpath: logs/f30k_precomp/adapt_t2i/latent_64/ 4 | resume: null # last, best_[...], or empty (from scratch) 5 | model: 6 | latent_size: 64 7 | freeze_modules: [model.txt_enc.embed.glove,] 8 | txt_enc: 9 | name: gru_glove 10 | params: 11 | embed_dim: 300 12 | use_bi_gru: true 13 | glove_path: '.vocab_cache/glove_f30k_precomp.json.pkl' 14 | add_rand_embed: true 15 | pooling: none 16 | devices: [cuda,] 17 | img_enc: 18 | name: simple 19 | params: 20 | img_dim: 2048 21 | devices: [cuda,] 22 | pooling: none 23 | similarity: 24 | name: adapt_t2i 25 | params: 26 | latent_size: 64 27 | k: 1 28 | gamma: 10 29 | norm_output: True 30 | device: cuda # FIXME 31 | device: cuda # FIXME 32 | optimizer: 33 | name: adamax 34 | params: 35 | lr: 0.001 # 7e-4 36 | gradual_warmup_steps: [0.5, 2.0, 4000] #torch.linspace 37 | lr_decay_epochs: [10000, 20000, 3000] #range 38 | lr_decay_rate: .25 39 | lr_scheduler: 40 | name: null 41 | params: 42 | step_size: 1000 43 | gamma: 1 44 | grad_clip: 2. 
45 | -------------------------------------------------------------------------------- /options/adapt/latent_size/adapt_t2i_l768.yaml: -------------------------------------------------------------------------------- 1 | __include__: ../abstract.yaml 2 | exp: 3 | outpath: logs/f30k_precomp/adapt_t2i/latent_768/ 4 | resume: null # last, best_[...], or empty (from scratch) 5 | model: 6 | latent_size: 768 7 | freeze_modules: [model.txt_enc.embed.glove,] 8 | txt_enc: 9 | name: gru_glove 10 | params: 11 | embed_dim: 300 12 | use_bi_gru: true 13 | glove_path: '.vocab_cache/glove_f30k_precomp.json.pkl' 14 | add_rand_embed: true 15 | pooling: none 16 | devices: [cuda,] 17 | img_enc: 18 | name: simple 19 | params: 20 | img_dim: 2048 21 | devices: [cuda,] 22 | pooling: none 23 | similarity: 24 | name: adapt_t2i 25 | params: 26 | latent_size: 768 27 | k: 1 28 | gamma: 10 29 | norm_output: True 30 | device: cuda # FIXME 31 | device: cuda # FIXME 32 | optimizer: 33 | name: adamax 34 | params: 35 | lr: 0.001 # 7e-4 36 | gradual_warmup_steps: [0.5, 2.0, 4000] #torch.linspace 37 | lr_decay_epochs: [10000, 20000, 3000] #range 38 | lr_decay_rate: .25 39 | lr_scheduler: 40 | name: null 41 | params: 42 | step_size: 1000 43 | gamma: 1 44 | grad_clip: 2. 45 | -------------------------------------------------------------------------------- /options/scan-t2i/coco.yaml: -------------------------------------------------------------------------------- 1 | __include__: 'scan.yaml' 2 | exp: 3 | outpath: logs/coco_precomp.en/scan-t2i-nobeta/ 4 | dataset: 5 | vocab_paths: [.vocab_cache/coco_precomp.json,] 6 | train: 7 | data: coco_precomp.en 8 | workers: 0 9 | val: 10 | data: [coco_precomp.en] 11 | workers: 0 12 | optimizer: 13 | params: 14 | lr: .0005 15 | lr_scheduler: 16 | params: 17 | step_size: 40000 18 | gamma: 0.1 19 | -------------------------------------------------------------------------------- /options/scan-t2i/f30k.yaml: -------------------------------------------------------------------------------- 1 | __include__: 'scan.yaml' 2 | exp: 3 | outpath: logs/f30k_precomp.en/scan/ 4 | dataset: 5 | vocab_paths: [.vocab_cache/f30k_precomp.json,] 6 | train: 7 | data: f30k_precomp.en 8 | val: 9 | data: [f30k_precomp.en] 10 | optimizer: 11 | lr_scheduler: 12 | params: 13 | step_size: 15000 14 | gamma: 0.1 15 | -------------------------------------------------------------------------------- /options/scan-t2i/scan.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | resume: null # last, best_[...], or empty (from scratch) 3 | dataset: 4 | vocab_paths: [.vocab_cache/coco_precomp.json,] 5 | text_repr: word 6 | loader_name: precomp 7 | train: 8 | workers: 1 9 | batch_size: 128 10 | val: 11 | workers: 1 12 | batch_size: 64 13 | adapt: 14 | data: [] 15 | model: 16 | latent_size: 1024 17 | freeze_modules: [] 18 | txt_enc: 19 | name: gru 20 | params: 21 | embed_dim: 300 22 | use_bi_gru: true 23 | pooling: none 24 | devices: [cuda,] 25 | img_enc: 26 | name: scan 27 | params: 28 | img_dim: 2048 29 | devices: [cuda,] 30 | pooling: none 31 | similarity: 32 | name: scan_t2i 33 | params: 34 | device: cuda 35 | feature_norm: clipped_l2norm 36 | smooth: 9 37 | agg_function: Mean 38 | device: cuda 39 | criterion: 40 | margin: 0.2 41 | max_violation: False 42 | beta: 0.997 43 | optimizer: 44 | name: adam 45 | import: retrieval.optimizers.factory 46 | params: 47 | lr: 0.0002 48 | lr_scheduler: 49 | name: step 50 | params: 51 | step_size: 15000 52 | gamma: 0.1 53 | grad_clip: 2. 
54 | engine: 55 | eval_before_training: False 56 | debug: False 57 | print_freq: 10 58 | nb_epochs: 30 59 | early_stop: 50 60 | valid_interval: 500 61 | misc: # TODO 62 | cuda: True 63 | distributed: False # TODO 64 | seed: 1337 # TODO 65 | -------------------------------------------------------------------------------- /options/vsepp/coco.yaml: -------------------------------------------------------------------------------- 1 | __include__: 'vsepp.yaml' 2 | exp: 3 | outpath: logs/coco_precomp.en/vsepp/ 4 | dataset: 5 | train: 6 | data: coco_precomp.en 7 | workers: 0 8 | val: 9 | data: [coco_precomp.en] 10 | workers: 0 11 | adapt: 12 | data: [] 13 | optimizer: 14 | lr_scheduler: 15 | params: 16 | step_size: 32800 17 | gamma: 0.1 18 | -------------------------------------------------------------------------------- /options/vsepp/f30k.yaml: -------------------------------------------------------------------------------- 1 | __include__: 'vsepp.yaml' 2 | exp: 3 | outpath: logs/f30k_precomp.en/vsepp/ 4 | dataset: 5 | train: 6 | data: f30k_precomp.en 7 | val: 8 | data: [f30k_precomp.en] 9 | adapt: 10 | data: [] 11 | optimizer: 12 | lr_scheduler: 13 | params: 14 | step_size: 15000 15 | gamma: 0.1 16 | -------------------------------------------------------------------------------- /options/vsepp/vsepp.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | resume: null # last, best_[...], or empty (from scratch) 3 | dataset: 4 | vocab_paths: [.vocab_cache/coco_precomp.json,] 5 | text_repr: word 6 | loader_name: precomp 7 | train: 8 | workers: 1 9 | batch_size: 128 10 | val: 11 | workers: 1 12 | batch_size: 64 13 | model: 14 | latent_size: 1024 15 | freeze_modules: [] 16 | txt_enc: 17 | name: gru 18 | params: 19 | embed_dim: 300 20 | use_bi_gru: true 21 | pooling: lens 22 | devices: [cuda,] 23 | img_enc: 24 | name: vsepp_precomp 25 | params: 26 | img_dim: 2048 27 | devices: [cuda,] 28 | pooling: mean 29 | similarity: 30 | name: cosine 31 | params: 32 | device: cuda # FIXME 33 | device: cuda # FIXME 34 | criterion: 35 | margin: 0.2 36 | max_violation: True 37 | beta: 0.991 38 | optimizer: 39 | name: adam 40 | params: 41 | lr: 0.0006 42 | lr_scheduler: 43 | name: step 44 | params: 45 | step_size: 10000 46 | gamma: 0.1 47 | grad_clip: 2. 
48 | engine: 49 | eval_before_training: False 50 | debug: False 51 | print_freq: 10 52 | nb_epochs: 30 53 | early_stop: 50 54 | valid_interval: 500 55 | -------------------------------------------------------------------------------- /params.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from addict import Dict 3 | 4 | 5 | def get_train_params(): 6 | """Get arguments to train model""" 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('options', type=str, help='YAML path to training options') 9 | parser.add_argument('-local_rank', type=int, default=0) 10 | 11 | args = parser.parse_args() 12 | args = Dict(vars(args)) 13 | return args 14 | 15 | 16 | def get_test_params(ensemble=False): 17 | """Get arguments to test model""" 18 | parser = argparse.ArgumentParser() 19 | if ensemble: 20 | parser.add_argument('options', type=str, nargs='+', help='YAML paths to test options') 21 | else: 22 | parser.add_argument('options', type=str, help='YAML path to test options') 23 | parser.add_argument('-local_rank', type=int, default=0) 24 | parser.add_argument('-device', default='cuda', help='Device option to run test script') 25 | parser.add_argument('-data_split', '-s', default='dev', help='Data split to run test script') 26 | parser.add_argument('-outpath', '-o', default=None, help='Output file') 27 | 28 | args = parser.parse_args() 29 | args = Dict(vars(args)) 30 | return args 31 | 32 | 33 | def get_extractfeats_params(): 34 | """Get arguments to test model""" 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument('options', type=str, help='YAML path to test options') 37 | parser.add_argument('-local_rank', type=int, default=0) 38 | parser.add_argument('-device', default='cuda', help='Device option to run test script') 39 | parser.add_argument('-data_split', '-s', default='dev', help='Data split to run test script') 40 | parser.add_argument('-captions_path', default=None, help='Captions file (.pkl)') 41 | parser.add_argument('-outpath', '-o', default=None, help='Output file') 42 | 43 | args = parser.parse_args() 44 | args = Dict(vars(args)) 45 | return args 46 | 47 | 48 | def get_vocab_alignment_params(): 49 | """Get arguments to test model""" 50 | parser = argparse.ArgumentParser() 51 | 52 | parser.add_argument('vocab_path', help='Vocabulary path') 53 | parser.add_argument('-emb_path', help='Path to glove or w2v file') 54 | parser.add_argument('-outpath', help='Output path') 55 | 56 | args = parser.parse_args() 57 | return args 58 | 59 | 60 | def get_vocab_builder_params(): 61 | parser = argparse.ArgumentParser() 62 | parser.add_argument('-data_path', help='Data path') 63 | parser.add_argument('-char_level', action='store_true', help='Define character or word level') 64 | parser.add_argument('-data_name', nargs='+', default=['f30k_precomp'], help='Name of dataset') 65 | parser.add_argument('-outpath', help='Output file name') 66 | args = parser.parse_args() 67 | return args 68 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.1.0 2 | torchvision==0.2.2 3 | tensorboardx==1.6 4 | numpy==1.16.2 5 | addict==2.2.0 6 | nltk==3.4.5 7 | pillow==6.2.2 8 | pyaml==19.4.1 9 | argparse==1.4.0 10 | -------------------------------------------------------------------------------- /results/print_result.py: -------------------------------------------------------------------------------- 1 | ''' 
2 | find ../logs/ -name *json -print0 | xargs -0 python print_result.py 3 | ''' 4 | import sys 5 | sys.path.append('../') 6 | 7 | from retrieval.utils.file_utils import load_json 8 | from collections import defaultdict 9 | 10 | files = sys.argv[1:] 11 | 12 | metrics = [ 13 | 'i2t_r1', 'i2t_r5', 'i2t_r10', 'i2t_meanr', 'i2t_medr', 14 | 't2i_r1', 't2i_r5', 't2i_r10', 't2i_meanr', 't2i_medr', 15 | ] 16 | 17 | 18 | def load_and_filter_file(file_path): 19 | 20 | result = load_json(file) 21 | print(result) 22 | result_filtered = { 23 | k.split('/')[1]: v 24 | for k, v in result.items() 25 | if k.split('/')[1] in metrics 26 | } 27 | 28 | 29 | res_line = '\t'.join( 30 | [f'{result_filtered[metric]:>4.2f}' for metric in metrics] 31 | ) 32 | _file = '/'.join(file_path.split('/')[-3:]) 33 | print( 34 | f'{_file:60s}\t{res_line}' 35 | ) 36 | 37 | 38 | def load_and_filter_file(file_path): 39 | 40 | result = load_json(file) 41 | result_filtered = defaultdict(dict) 42 | for k, v in result.items(): 43 | try: 44 | data_name, metr = k.split('/') 45 | except: 46 | continue 47 | if metr not in metrics: 48 | continue 49 | result_filtered[data_name].update({metr: v}) 50 | 51 | _file = '/'.join(file_path.split('/')[-3:-1]) 52 | for data_name, vals in result_filtered.items(): 53 | # print(data_name, [vals[m] for m in metrics]) 54 | res_line = '\t'.join( 55 | [f'{vals[metric]:>4.2f}' for metric in metrics] 56 | ) 57 | # print(_file, data_name, res_line) 58 | 59 | print( 60 | f'{_file:55s}\t{data_name:20s}\t{res_line}' 61 | ) 62 | 63 | for file in files: 64 | load_and_filter_file(file) 65 | -------------------------------------------------------------------------------- /retrieval/__init__.py: -------------------------------------------------------------------------------- 1 | from . import model 2 | from . import utils 3 | from . import data 4 | from . import train 5 | -------------------------------------------------------------------------------- /retrieval/data/__init__.py: -------------------------------------------------------------------------------- 1 | from . import collate_fns 2 | from . import datasets 3 | from . import loaders 4 | from . import tokenizer 5 | from . import adapters 6 | -------------------------------------------------------------------------------- /retrieval/data/adapters.py: -------------------------------------------------------------------------------- 1 | from ..utils.file_utils import load_json 2 | from collections import defaultdict 3 | from ..utils.logger import get_logger 4 | from pathlib import Path 5 | 6 | 7 | logger = get_logger() 8 | 9 | 10 | class Flickr: 11 | 12 | def __init__(self, data_path, data_split): 13 | 14 | data_split = data_split.replace('dev', 'val') 15 | 16 | self.data_path = Path(data_path) 17 | self.annotation_path = ( 18 | self.data_path / 'dataset_flickr30k.json' 19 | ) 20 | self.data = load_json(self.annotation_path) 21 | self.image_ids, self.img_dict, self.img_captions = self._get_img_ids(data_split) 22 | 23 | for k, v in self.img_captions.items(): 24 | assert len(v) == 5 25 | 26 | logger.info(( 27 | f'[Flickr] Loaded {len(self.img_captions)} images ' 28 | f'and {len(self.img_captions)*5} annotations.' 
29 | )) 30 | 31 | def _get_img_ids(self, data_split): 32 | image_ids = [] 33 | img_dict = {} 34 | annotations = defaultdict(list) 35 | for img in self.data['images']: 36 | if img['split'].lower() != data_split.lower(): 37 | continue 38 | img_dict[img['imgid']] = img 39 | image_ids.append(img['imgid']) 40 | annotations[img['imgid']].extend( 41 | [x['raw'] for x in img['sentences']][:5] 42 | ) 43 | return image_ids, img_dict, annotations 44 | 45 | def get_image_id_by_filename(self, filename): 46 | return self.img_dict[filename]['imgid'] 47 | 48 | def get_captions_by_image_id(self, img_id): 49 | return self.img_captions[img_id] 50 | 51 | def get_filename_by_image_id(self, image_id): 52 | return ( 53 | Path('images') / 54 | Path('flickr30k_images') / 55 | self.img_dict[image_id]['filename'] 56 | ) 57 | 58 | def __call__(self, filename): 59 | return self.img_dict[filename] 60 | 61 | def __len__(self, ): 62 | return len(self.img_captions) 63 | 64 | 65 | class Coco: 66 | 67 | def __init__(self, path, data_split): 68 | 69 | data_split = data_split.replace('dev', 'val') 70 | 71 | self.data_path = Path(path) 72 | self.annotation_path = ( 73 | self.data_path / 'dataset_coco.json' 74 | ) 75 | self.data = load_json(self.annotation_path) 76 | 77 | self.image_ids, self.img_dict, self.img_captions = self._get_img_ids(data_split) 78 | 79 | for k, v in self.img_captions.items(): 80 | assert len(v) == 5 81 | 82 | logger.info(( 83 | f'[Coco] Loaded {len(self.img_captions)} images ' 84 | f'and {len(self.img_captions)*5} annotations.' 85 | )) 86 | 87 | def _get_img_ids(self, data_split): 88 | img_dict = {} 89 | image_ids = [] 90 | annotations = defaultdict(list) 91 | for img in self.data['images']: 92 | split = img['split'].lower().replace('restval', 'train') 93 | if split != data_split.lower(): 94 | continue 95 | img_dict[img['imgid']] = img 96 | image_ids.append(img['imgid']) 97 | 98 | annotations[img['imgid']].extend( 99 | [x['raw'] for x in img['sentences']][:5] 100 | ) 101 | return image_ids, img_dict, annotations 102 | 103 | def get_image_id_by_filename(self, filename): 104 | return self.img_dict[filename]['imgid'] 105 | 106 | def get_captions_by_image_id(self, img_id): 107 | return self.img_captions[img_id] 108 | 109 | def get_filename_by_image_id(self, image_id): 110 | return ( 111 | Path('images') / 112 | self.img_dict[image_id]['filename'].split('_')[1] / 113 | self.img_dict[image_id]['filename'] 114 | ) 115 | 116 | def __call__(self, filename): 117 | return self.img_dict[filename] 118 | 119 | def __len__(self, ): 120 | return len(self.img_captions) 121 | -------------------------------------------------------------------------------- /retrieval/data/collate_fns.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from addict import Dict 4 | 5 | 6 | def split_array(iterable, splitters=[4,]): 7 | import itertools 8 | return [ 9 | torch.LongTensor(list(g)) 10 | for k, g in itertools.groupby( 11 | iterable, lambda x: x in splitters 12 | ) 13 | if not k 14 | ] 15 | 16 | 17 | def default_padding(captions, device=None): 18 | lengths = [len(cap) for cap in captions] 19 | targets = torch.zeros(len(captions), max(lengths)).long() 20 | 21 | for i, cap in enumerate(captions): 22 | end = lengths[i] 23 | targets[i, :end] = cap[:end] 24 | 25 | if device is None: 26 | return targets, lengths 27 | 28 | return targets.to(device), lengths 29 | 30 | 31 | def stack(x,): 32 | return torch.stack(x, 0) 33 | 34 | 35 | def no_preprocess(x,): 36 | 
return x 37 | 38 | 39 | def to_numpy(x,): 40 | return np.array(x) 41 | 42 | 43 | _preprocessing_fn = { 44 | 'image': stack, 45 | 'caption': default_padding, 46 | 'index': to_numpy, 47 | 'img_id': to_numpy, 48 | 'attributes': stack, 49 | } 50 | 51 | 52 | def liwe_padding(captions): 53 | splitted_caps = [] 54 | for caption in captions: 55 | sc = split_array(caption) 56 | splitted_caps.append(sc) 57 | sent_lens = np.array([len(x) for x in splitted_caps]) 58 | max_nb_steps = max(sent_lens) 59 | word_maxlen = 26 60 | targets = torch.zeros(len(captions), max_nb_steps, word_maxlen).long() 61 | for i, cap in enumerate(splitted_caps): 62 | end_sentence = sent_lens[i] 63 | for j, word in enumerate(cap): 64 | end_word = word_maxlen if len(word) > word_maxlen else len(word) 65 | targets[i, j, :end_word] = word[:end_word] 66 | 67 | return targets, sent_lens 68 | 69 | 70 | class Collate: 71 | 72 | def __init__(self, text_repr='words'): 73 | if text_repr == 'liwe': 74 | self.padding = liwe_padding 75 | else: 76 | self.padding = default_padding 77 | pass 78 | 79 | def __call__(self, data): 80 | attributes = data[0].keys() 81 | 82 | batch = Dict() 83 | if len(data[0]['caption']) == 2: 84 | words, chars = zip(*[x['caption'] for x in data]) 85 | words = default_padding(words) 86 | char = liwe_padding(chars) 87 | batch['caption'] = (words, char) 88 | else: 89 | batch['caption'] = self.padding([x['caption'][0] for x in data]) 90 | 91 | for att in attributes: 92 | if att == 'caption': 93 | continue 94 | batch[att] = _preprocessing_fn[att]([x[att] for x in data]) 95 | 96 | return batch 97 | 98 | 99 | def collate_lang_word(data): 100 | """Build mini-batch tensors from a list of (image, caption) tuples. 101 | Args: 102 | data: list of (image, caption) tuple. 103 | - image: torch tensor of shape (3, 256, 256). 104 | - caption: torch tensor of shape (?); variable length. 105 | 106 | Returns: 107 | images: torch tensor of shape (batch_size, 3, 256, 256). 108 | targets: torch tensor of shape (batch_size, padded_length). 109 | lengths: list; valid length for each padded caption. 110 | """ 111 | # Sort a data list by caption length 112 | lang_a, lang_b, ids = zip(*data) 113 | 114 | # Merget captions (convert tuple of 1D tensor to 2D tensor) 115 | targ_a, lens_a = default_padding(lang_a) 116 | targ_b, lens_b = default_padding(lang_b) 117 | 118 | return targ_a, lens_a, targ_b, lens_b, ids 119 | 120 | 121 | def collate_lang_liwe(data): 122 | """Build mini-batch tensors from a list of (image, caption) tuples. 123 | Args: 124 | data: list of (image, caption) tuple. 125 | - image: torch tensor of shape (3, 256, 256). 126 | - caption: torch tensor of shape (?); variable length. 127 | Returns: 128 | images: torch tensor of shape (batch_size, 3, 256, 256). 129 | targets: torch tensor of shape (batch_size, padded_length). 130 | lengths: list; valid length for each padded caption. 
131 | """ 132 | # Sort a data list by caption length 133 | lang_a, lang_b, ids = zip(*data) 134 | 135 | # Merget captions (convert tuple of 1D tensor to 2D tensor) 136 | targ_a, lens_a = liwe_padding(lang_a) 137 | targ_b, lens_b = liwe_padding(lang_b) 138 | 139 | return targ_a, lens_a, targ_b, lens_b, ids 140 | -------------------------------------------------------------------------------- /retrieval/data/dataiterator.py: -------------------------------------------------------------------------------- 1 | from .loaders import prepare_ml_data 2 | from ..utils.logger import get_logger 3 | 4 | 5 | logger = get_logger() 6 | 7 | class DataIterator: 8 | 9 | def __init__(self, loader, device, non_stop=False): 10 | self.data_iter = iter(loader) 11 | self.loader = loader 12 | self.non_stop = non_stop 13 | self.device = device 14 | 15 | def __str__(self): 16 | return f'{self.loader.dataset.data_name}.{self.loader.dataset.data_split}' 17 | 18 | def next(self): 19 | try: 20 | instance = next(self.data_iter) 21 | 22 | targ_a, lens_a, targ_b, lens_b, ids = prepare_ml_data( 23 | instance, self.device 24 | ) 25 | logger.debug(( 26 | f'DataIter - CrossLang - Images: {targ_a.shape} ' 27 | f'DataIter - CrossLang - Target: {targ_a.shape} ' 28 | f'DataIter - CrossLang - Ids: {ids[:10]}\n' 29 | )) 30 | return targ_a, lens_a, targ_b, lens_b, ids 31 | 32 | except StopIteration: 33 | if self.non_stop: 34 | self.data_iter = iter(self.loader) 35 | return self.next() 36 | else: 37 | raise StopIteration( 38 | 'The data iterator has finished its job.' 39 | ) 40 | -------------------------------------------------------------------------------- /retrieval/data/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pickle 4 | import numpy as np 5 | from PIL import Image 6 | from addict import Dict 7 | from pathlib import Path 8 | from torch.utils.data import Dataset 9 | from torchvision.datasets.folder import default_loader 10 | 11 | from ..utils.logger import get_logger 12 | from ..utils.file_utils import read_txt, load_pickle 13 | from .preprocessing import get_transform 14 | 15 | logger = get_logger() 16 | 17 | 18 | class Birds(Dataset): 19 | def __init__(self, data_path, data_name, transform=None, 20 | target_transform=None, data_split='train', 21 | tokenizer=None, lang=None): 22 | 23 | self.data_path = data_path 24 | self.data_name = data_name 25 | self.tokenizer = tokenizer 26 | self.target_transform = target_transform 27 | self.data_split = data_split 28 | self.__dataset_path = Path(self.data_path) / self.data_name / data_split 29 | 30 | self.embeddings = load_pickle( 31 | path=os.path.join(self.__dataset_path, 'char-CNN-RNN-embeddings.pickle'), 32 | encoding='latin1' 33 | ) 34 | 35 | self.fnames = load_pickle( 36 | path=os.path.join(self.__dataset_path, 'filenames.pickle'), 37 | encoding='latin1' 38 | ) 39 | 40 | self.images = load_pickle( 41 | path=os.path.join(self.__dataset_path, '304images.pickle'), 42 | encoding='latin1' 43 | ) 44 | 45 | self.class_info = load_pickle( 46 | path=os.path.join(self.__dataset_path, 'class_info.pickle'), 47 | encoding='latin1' 48 | ) 49 | 50 | self.captions = self._get_captions() 51 | self.transform = get_transform(data_split) 52 | 53 | self.captions_per_image = 10 54 | 55 | logger.info(f'[Birds] #Captions {len(self.captions)}') 56 | logger.info(f'[Birds] #Images {len(self.images)}') 57 | self.n = self._set_len_captions(data_split) 58 | 59 | def __getitem__(self, ix): 60 | img = 
Image.fromarray(self.images[ix//self.captions_per_image]).convert('RGB') 61 | if self.transform is not None: 62 | img = self.transform(img) 63 | 64 | fname = self.fnames[ix//self.captions_per_image] 65 | caption = self.captions[ix] 66 | tokens = self.tokenizer(caption) 67 | return img, tokens, ix, fname 68 | 69 | def __len__(self): 70 | return self.n 71 | 72 | def __repr__(self): 73 | return f'Birds.{self.data_name}.{self.data_split}' 74 | 75 | def __str__(self): 76 | return f'{self.data_name}.{self.data_split}' 77 | 78 | def _get_captions(self): 79 | captions = [] 80 | for fname in self.fnames: 81 | cap_file = Path(self.data_path) / self.data_name / 'text_c10' / f'{fname}.txt' 82 | with open(cap_file, 'r') as f: 83 | cap = f.readlines() 84 | captions.extend(cap) 85 | return captions 86 | 87 | def _set_len_captions(self, data_split): 88 | n = len(self.captions) 89 | if data_split in ['test', 'dev']: 90 | n = 10000 91 | return n 92 | 93 | 94 | class PrecompDataset(Dataset): 95 | """ 96 | Load precomputed captions and image features 97 | Possible options: f30k_precomp, coco_precomp 98 | """ 99 | 100 | def __init__( 101 | self, data_path, data_name, 102 | data_split, tokenizers, lang='en', 103 | ): 104 | logger.debug(f'Precomp dataset\n {[data_path, data_split, tokenizers, lang]}') 105 | self.tokenizers = tokenizers 106 | self.lang = lang 107 | self.data_split = '.'.join([data_split, lang]) 108 | self.data_path = data_path = Path(data_path) 109 | self.data_name = Path(data_name) 110 | self.full_path = self.data_path / self.data_name 111 | # Load Captions 112 | caption_file = self.full_path / f'{data_split}_caps.{lang}.txt' 113 | self.captions = read_txt(caption_file) 114 | logger.debug(f'Read captions. Found: {len(self.captions)}') 115 | 116 | # Load Image features 117 | img_features_file = self.full_path / f'{data_split}_ims.npy' 118 | self.images = np.load(img_features_file) 119 | self.length = len(self.captions) 120 | self.ids = np.loadtxt(data_path/ data_name / f'{data_split}_ids.txt', dtype=int) 121 | 122 | self.captions_per_image = 5 123 | 124 | logger.debug(f'Read feature file. Shape: {len(self.images.shape)}') 125 | 126 | # Each image must have five captions 127 | assert ( 128 | self.images.shape[0] == len(self.captions) 129 | or self.images.shape[0]*5 == len(self.captions) 130 | ) 131 | 132 | if self.images.shape[0] != len(self.captions): 133 | self.im_div = 5 134 | else: 135 | self.im_div = 1 136 | # the development set for coco is large and so validation would be slow 137 | # if data_split == 'dev' and self.length > 5000: 138 | # self.length = 5000 139 | 140 | print('Image div', self.im_div) 141 | 142 | logger.info('Precomputing captions') 143 | # self.precomp_captions = [ 144 | # self.tokenizer(x) 145 | # for x in self.captions 146 | # ] 147 | 148 | # self.maxlen = max([len(x) for x in self.precomp_captions]) 149 | # logger.info(f'Maxlen {self.maxlen}') 150 | 151 | logger.info(( 152 | f'Loaded PrecompDataset {self.data_name}/{self.data_split} with ' 153 | f'images: {self.images.shape} and captions: {self.length}.' 
154 | )) 155 | 156 | def get_img_dim(self): 157 | return self.images.shape[-1] 158 | 159 | def __getitem__(self, index): 160 | # handle the image redundancy 161 | img_id = index//self.im_div 162 | image = self.images[img_id] 163 | image = torch.FloatTensor(image) 164 | 165 | # caption = self.precomp_captions[index] 166 | caption = self.captions[index] 167 | 168 | ret_caption = [] 169 | for tokenizer in self.tokenizers: 170 | tokens = tokenizer(caption) 171 | ret_caption.append(tokens) 172 | 173 | batch = Dict( 174 | image=image, 175 | caption=ret_caption, 176 | index=index, 177 | img_id=img_id, 178 | ) 179 | 180 | return batch 181 | 182 | def __len__(self): 183 | return self.length 184 | 185 | def __repr__(self): 186 | return f'PrecompDataset.{self.data_name}.{self.data_split}' 187 | 188 | def __str__(self): 189 | return f'{self.data_name}.{self.data_split}' 190 | 191 | 192 | class DummyDataset(Dataset): 193 | """ 194 | Load precomputed captions and image features 195 | Possible options: f30k_precomp, coco_precomp 196 | """ 197 | 198 | def __init__( 199 | self, data_path, data_name, 200 | data_split, tokenizer, lang='en' 201 | ): 202 | logger.debug(f'Precomp dataset\n {[data_path, data_split, tokenizer, lang]}') 203 | self.tokenizer = tokenizer 204 | 205 | self.captions = np.random.randint(0, 1000, size=(5000, 50)) 206 | logger.debug(f'Read captions. Found: {len(self.captions)}') 207 | 208 | # Load Image features 209 | self.images = np.random.uniform(size=(1000, 36, 2048)) 210 | self.length = 5000 211 | 212 | logger.debug(f'Read feature file. Shape: {len(self.images.shape)}') 213 | 214 | # Each image must have five captions 215 | assert ( 216 | self.images.shape[0] == len(self.captions) 217 | or self.images.shape[0]*5 == len(self.captions) 218 | ) 219 | 220 | if self.images.shape[0] != len(self.captions): 221 | self.im_div = 5 222 | else: 223 | self.im_div = 1 224 | # the development set for coco is large and so validation would be slow 225 | if data_split == 'dev': 226 | self.length = 5000 227 | print('Image div', self.im_div) 228 | 229 | # self.precomp_captions = [ 230 | # self.tokenizer(x) 231 | # for x in self.captions 232 | # ] 233 | 234 | # self.maxlen = max([len(x) for x in self.precomp_captions]) 235 | # logger.info(f'Maxlen {self.maxlen}') 236 | 237 | def get_img_dim(self): 238 | return self.images.shape[-1] 239 | 240 | def __getitem__(self, index): 241 | # handle the image redundancy 242 | img_id = index//self.im_div 243 | image = self.images[img_id] 244 | image = torch.FloatTensor(image) 245 | # caption = self.precomp_captions[index] 246 | caption = torch.LongTensor(self.captions[index]) 247 | 248 | return image, caption, index, img_id 249 | 250 | def __len__(self): 251 | return self.length 252 | 253 | 254 | class CrossLanguageLoader(Dataset): 255 | """ 256 | Load precomputed captions and image features 257 | Possible options: f30k_precomp, coco_precomp 258 | """ 259 | 260 | def __init__( 261 | self, data_path, data_name, data_split, 262 | tokenizers, lang='en-de', 263 | ): 264 | logger.debug(( 265 | 'CrossLanguageLoader dataset\n ' 266 | f'{[data_path, data_split, tokenizers, lang]}' 267 | )) 268 | 269 | self.data_path = Path(data_path) 270 | self.data_name = Path(data_name) 271 | self.full_path = self.data_path / self.data_name 272 | self.data_split = '.'.join([data_split, lang]) 273 | 274 | self.lang = lang 275 | 276 | assert len(tokenizers) == 1 # TODO: implement multi-tokenizer 277 | 278 | self.tokenizer = tokenizers[0] 279 | 280 | lang_base, lang_target = lang.split('-') 
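        # Illustrative example: with data_split='train' and lang='en-de', the two
        # files below become 'train_caps.en.txt' and 'train_caps.de.txt'; they are
        # read in parallel, so line i of the base file is the translation pair of
        # line i of the target file (the length assert below enforces this).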
281 | base_filename = f'{data_split}_caps.{lang_base}.txt' 282 | target_filename = f'{data_split}_caps.{lang_target}.txt' 283 | 284 | base_file = self.full_path / base_filename 285 | target_file = self.full_path / target_filename 286 | 287 | logger.debug(f'Base: {base_file} - Target: {target_file}') 288 | # Paired files 289 | self.lang_a = read_txt(base_file) 290 | self.lang_b = read_txt(target_file) 291 | 292 | logger.debug(f'Base and target size: {(len(self.lang_a), len(self.lang_b))}') 293 | self.length = len(self.lang_a) 294 | assert len(self.lang_a) == len(self.lang_b) 295 | 296 | logger.info(( 297 | f'Loaded CrossLangDataset {self.data_name}/{self.data_split} with ' 298 | f'captions: {self.length}' 299 | )) 300 | 301 | def __getitem__(self, index): 302 | caption_a = self.lang_a[index] 303 | caption_b = self.lang_b[index] 304 | 305 | target_a = self.tokenizer(caption_a) 306 | target_b = self.tokenizer(caption_b) 307 | 308 | return target_a, target_b, index 309 | 310 | def __len__(self): 311 | return self.length 312 | 313 | def __str__(self): 314 | return f'{self.data_name}.{self.data_split}' 315 | 316 | 317 | class ImageDataset(Dataset): 318 | """ 319 | Load precomputed captions and image features 320 | Possible options: f30k_precomp, coco_precomp 321 | """ 322 | 323 | def __init__( 324 | self, data_path, data_name, 325 | data_split, tokenizer, lang='en', 326 | resize_to=256, crop_size=224, 327 | ): 328 | from .adapters import Flickr, Coco 329 | 330 | logger.debug(f'ImageDataset\n {[data_path, data_split, tokenizer, lang]}') 331 | self.tokenizer = tokenizer 332 | self.lang = lang 333 | self.data_split = data_split 334 | self.split = '.'.join([data_split, lang]) 335 | self.data_path = Path(data_path) 336 | self.data_name = Path(data_name) 337 | self.full_path = self.data_path / self.data_name 338 | 339 | self.data_wrapper = ( 340 | Flickr( 341 | self.full_path, 342 | data_split=data_split, 343 | ) if 'f30k' in data_name 344 | else Coco( 345 | self.full_path, 346 | data_split=data_split, 347 | ) 348 | ) 349 | 350 | self._fetch_captions() 351 | self.length = len(self.ids) 352 | 353 | self.transform = get_transform( 354 | data_split, resize_to=resize_to, crop_size=crop_size 355 | ) 356 | 357 | self.captions_per_image = 5 358 | 359 | if data_split == 'dev' and len(self.length) > 5000: 360 | self.length = 5000 361 | 362 | logger.debug(f'Split size: {len(self.ids)}') 363 | 364 | def _fetch_captions(self,): 365 | self.captions = [] 366 | for image_id in sorted(self.data_wrapper.image_ids): 367 | self.captions.extend( 368 | self.data_wrapper.get_captions_by_image_id(image_id)[:5] 369 | ) 370 | 371 | self.ids = range(len(self.captions)) 372 | logger.debug(f'Loaded {len(self.captions)} captions') 373 | 374 | def load_img(self, image_id): 375 | 376 | filename = self.data_wrapper.get_filename_by_image_id(image_id) 377 | feat_path = self.full_path / filename 378 | try: 379 | image = default_loader(feat_path) 380 | image = self.transform(image) 381 | except OSError: 382 | print('Error to load image: ', feat_path) 383 | image = torch.zeros(3, 224, 224,) 384 | 385 | return image 386 | 387 | def __getitem__(self, index): 388 | # handle the image redundancy 389 | seq_id = self.ids[index] 390 | image_id = self.data_wrapper.image_ids[seq_id//5] 391 | 392 | image = self.load_img(image_id) 393 | 394 | caption = self.captions[index] 395 | cap_tokens = self.tokenizer(caption) 396 | 397 | return image, cap_tokens, index, image_id 398 | 399 | def __len__(self): 400 | return self.length 401 | 402 | def 
__repr__(self): 403 | return f'ImageDataset.{self.data_name}.{self.split}' 404 | 405 | def __str__(self): 406 | return f'{self.data_name}.{self.split}' 407 | -------------------------------------------------------------------------------- /retrieval/data/loaders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | 4 | from . import datasets 5 | from . import collate_fns 6 | from .tokenizer import Tokenizer 7 | from ..utils.logger import get_logger 8 | 9 | from retrieval.utils.file_utils import load_yaml_opts, parse_loader_name 10 | 11 | 12 | logger = get_logger() 13 | 14 | __loaders__ = { 15 | 'dummy': { 16 | 'class': datasets.DummyDataset, 17 | }, 18 | 'precomp': { 19 | 'class': datasets.PrecompDataset, 20 | }, 21 | 'tensor': { 22 | 'class': datasets.PrecompDataset, 23 | }, 24 | 'lang': { 25 | 'class': datasets.CrossLanguageLoader, 26 | }, 27 | 'image': { 28 | 'class': datasets.ImageDataset, 29 | }, 30 | 'birds': { 31 | 'class': datasets.Birds, 32 | }, 33 | } 34 | 35 | def get_dataset_class(loader_name): 36 | loader = __loaders__[loader_name] 37 | return loader['class'] 38 | 39 | 40 | def prepare_ml_data(instance, device): 41 | targ_a, lens_a, targ_b, lens_b, ids = instance 42 | targ_a = targ_a.to(device).long() 43 | targ_b = targ_b.to(device).long() 44 | return targ_a, lens_a, targ_b, lens_b, ids 45 | 46 | 47 | def get_loader( 48 | loader_name, data_path, data_info, data_split, 49 | batch_size, vocab_paths, text_repr, 50 | workers=4, ngpu=1, local_rank=0, 51 | **kwargs 52 | ): 53 | data_name, lang = parse_loader_name(data_info) 54 | if not lang: 55 | lang = 'en' 56 | logger.debug('Get loader') 57 | dataset_class = get_dataset_class(loader_name) 58 | logger.debug(f'Dataset class is {dataset_class}') 59 | 60 | tokenizers = [] 61 | for vocab_path in vocab_paths: 62 | tokenizers.append(Tokenizer(vocab_path)) 63 | logger.debug(f'Tokenizer built: {tokenizers[-1]}') 64 | 65 | dataset = dataset_class( 66 | data_path=data_path, 67 | data_name=data_name, 68 | data_split=data_split, 69 | tokenizers=tokenizers, 70 | lang=lang, 71 | ) 72 | logger.debug(f'Dataset built: {dataset}') 73 | 74 | sampler = None 75 | shuffle = (data_split == 'train') 76 | if ngpu > 1: 77 | sampler = torch.utils.data.distributed.DistributedSampler( 78 | dataset, 79 | num_replicas=ngpu, 80 | rank=local_rank, 81 | ) 82 | shuffle = False 83 | 84 | collate = collate_fns.Collate(text_repr) 85 | 86 | if loader_name == 'lang' and text_repr == 'liwe': 87 | collate = collate_fns.collate_lang_liwe 88 | if loader_name == 'lang' and text_repr == 'word': 89 | collate = collate_fns.collate_lang_word 90 | 91 | loader = DataLoader( 92 | dataset=dataset, 93 | batch_size=batch_size, 94 | shuffle=shuffle, 95 | pin_memory=True, 96 | collate_fn=collate, 97 | num_workers=workers, 98 | sampler=sampler, 99 | ) 100 | logger.debug(f'Loader built: {loader}') 101 | 102 | return loader 103 | 104 | 105 | def get_loaders(data_path, local_rank, opt): 106 | train_loader = get_loader( 107 | data_split='train', 108 | data_path=data_path, 109 | data_info=opt.dataset.train.data, 110 | loader_name=opt.dataset.loader_name, 111 | local_rank=local_rank, 112 | text_repr=opt.dataset.text_repr, 113 | vocab_paths=opt.dataset.vocab_paths, 114 | ngpu=torch.cuda.device_count(), 115 | **opt.dataset.train 116 | ) 117 | 118 | val_loaders = [] 119 | for val_data in opt.dataset.val.data: 120 | val_loaders.append( 121 | get_loader( 122 | data_split='dev', 123 | data_path=data_path, 
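                # Validation loaders are fixed to the 'dev' split and ngpu=1, so
                # evaluation data is never sharded across GPUs regardless of how
                # the training loader above was configured.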
124 | data_info=val_data, 125 | loader_name=opt.dataset.loader_name, 126 | local_rank=local_rank, 127 | text_repr=opt.dataset.text_repr, 128 | vocab_paths=opt.dataset.vocab_paths, 129 | ngpu=1, 130 | **opt.dataset.val 131 | ) 132 | ) 133 | assert len(val_loaders) > 0 134 | 135 | adapt_loaders = [] 136 | for adapt_data in opt.dataset.adapt.data: 137 | adapt_loaders.append( 138 | get_loader( 139 | data_split='train', 140 | data_path=data_path, 141 | data_info=adapt_data, 142 | loader_name='lang', 143 | local_rank=local_rank, 144 | text_repr=opt.dataset.text_repr, 145 | vocab_paths=opt.dataset.vocab_paths, 146 | ngpu=1, 147 | **opt.dataset.adapt 148 | ) 149 | ) 150 | logger.info(f'Adapt loaders: {len(adapt_loaders)}') 151 | return train_loader, val_loaders, adapt_loaders 152 | 153 | 154 | # def get_loaders( 155 | # data_path, loader_name, data_name, 156 | # vocab_path, batch_size, 157 | # workers, text_repr, 158 | # splits=['train', 'val', 'test'], 159 | # langs=['en', 'en', 'en'], 160 | # ): 161 | 162 | # loaders = [] 163 | # loader_class = get_dataset_class(loader_name) 164 | # for split, lang in zip(splits, langs): 165 | # logger.debug(f'Getting loader {loader_class}/ {split} / Lang {lang}') 166 | # loader = get_loader( 167 | # loader_name=loader_name, 168 | # data_path=data_path, 169 | # data_name=data_name, 170 | # batch_size=batch_size, 171 | # workers=workers, 172 | # text_repr=text_repr, 173 | # data_split=split, 174 | # lang=lang, 175 | # vocab_path=vocab_path, 176 | # ) 177 | # loaders.append(loader) 178 | # return tuple(loaders) 179 | -------------------------------------------------------------------------------- /retrieval/data/preprocessing.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | 3 | 4 | def get_transform( 5 | split, 6 | resize_to=256, 7 | crop_size=224, 8 | ): 9 | 10 | normalizer = transforms.Normalize( 11 | mean=[0.485, 0.456, 0.406], 12 | std=[0.229, 0.224, 0.225], 13 | ) 14 | 15 | if split == 'train': 16 | t_list = [ 17 | # transforms.Resize(resize_to), 18 | transforms.RandomHorizontalFlip(), 19 | transforms.RandomResizedCrop(crop_size), 20 | ] 21 | else: 22 | t_list = [ 23 | transforms.Resize(resize_to), 24 | transforms.CenterCrop(crop_size), 25 | ] 26 | 27 | t_list.extend([transforms.ToTensor(), normalizer]) 28 | transform = transforms.Compose(t_list) 29 | 30 | return transform 31 | 32 | -------------------------------------------------------------------------------- /retrieval/data/tokenizer.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from collections import Counter 4 | 5 | import torch 6 | from tqdm import tqdm 7 | 8 | import nltk 9 | 10 | from ..utils.logger import get_logger 11 | from ..utils.file_utils import read_txt 12 | 13 | 14 | logger = get_logger() 15 | 16 | class Vocabulary(object): 17 | """Simple vocabulary wrapper.""" 18 | 19 | def __init__(self): 20 | self.word2idx = {} 21 | self.idx2word = {} 22 | self.idx = 0 23 | 24 | def add_word(self, word): 25 | if word not in self.word2idx: 26 | self.word2idx[word] = self.idx 27 | self.idx2word[self.idx] = word 28 | self.idx += 1 29 | 30 | def get_word(self, idx): 31 | if idx in self.idx2word: 32 | return self.idx2word[idx] 33 | else: 34 | return '' 35 | 36 | def __call__(self, word): 37 | if word not in self.word2idx: 38 | return self.word2idx[''] 39 | return self.word2idx[word] 40 | 41 | def __len__(self): 42 | return len(self.word2idx) 43 | 44 | 45 | 
class Tokenizer(object): 46 | """ 47 | This class converts texts into character or word-level tokens 48 | """ 49 | 50 | def __init__( 51 | self, vocab_path=None, char_level=False, 52 | maxlen=None, download_tokenizer=False 53 | ): 54 | # Create a vocab wrapper and add some special tokens. 55 | self.char_level = char_level 56 | self.maxlen = maxlen 57 | 58 | vocab = Vocabulary() 59 | vocab.add_word('') 60 | vocab.add_word('') 61 | vocab.add_word('') 62 | vocab.add_word('') 63 | 64 | if char_level: 65 | vocab.add_word(' ') # Space is allways #2 66 | 67 | self.vocab = vocab 68 | 69 | if download_tokenizer: 70 | nltk.download('punkt') 71 | 72 | if vocab_path is not None: 73 | self.load(vocab_path) 74 | 75 | logger.info(f'Loaded from {vocab_path}.') 76 | logger.info(f'Created tokenizer with init {len(self.vocab)} tokens.') 77 | 78 | def fit_on_files(self, txt_files): 79 | logger.debug('Fit on files.') 80 | for file in txt_files: 81 | logger.info(f'Updating vocab with {file}') 82 | sentences = read_txt(file) 83 | self.fit(sentences) 84 | 85 | def fit(self, sentences, threshold=4): 86 | logger.debug( 87 | f'Fit word vocab on {len(sentences)} and t={threshold}' 88 | ) 89 | counter = Counter() 90 | 91 | for tokens in tqdm(sentences, total=len(sentences)): 92 | if not self.char_level: 93 | tokens = self.split_sentence(tokens) 94 | counter.update(tokens) 95 | 96 | # Discard if the occurrence of the word is less than threshold 97 | tokens = [ 98 | token for token, cnt in counter.items() 99 | if cnt >= threshold 100 | ] 101 | 102 | # Add words to the vocabulary. 103 | for token in tokens: 104 | self.vocab.add_word(token) 105 | 106 | logger.info(f'Vocab built. Tokens found {len(self.vocab)}') 107 | return self.vocab 108 | 109 | def save(self, outpath): 110 | logger.debug(f'Saving vocab to {outpath}') 111 | 112 | state = { 113 | 'word2idx': self.vocab.word2idx, 114 | 'char_level': self.char_level, 115 | 'max_len': self.maxlen 116 | } 117 | 118 | with open(outpath, "w") as f: 119 | json.dump(state, f) 120 | 121 | logger.info( 122 | f'Vocab stored into {outpath} with {len(self.vocab)} tokens.' 
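            # The JSON written above is exactly what load() reads back:
            # 'word2idx' plus the 'char_level' and 'max_len' settings.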
123 | ) 124 | 125 | def load(self, path): 126 | logger.debug(f'Loading vocab from {path}') 127 | with open(path) as f: 128 | state = json.load(f) 129 | 130 | vocab = Vocabulary() 131 | vocab.word2idx = state['word2idx'] 132 | vocab.idx2word = {v: k for k, v in vocab.word2idx.items()} 133 | vocab.idx = max(vocab.idx2word) 134 | self.vocab = vocab 135 | self.char_level = state['char_level'] 136 | self.maxlen = state['max_len'] 137 | logger.info(f'Loaded vocab containing {len(self.vocab)} tokens') 138 | return self 139 | 140 | def split_sentence(self, sentence): 141 | tokens = nltk.tokenize.word_tokenize( 142 | sentence.lower() 143 | ) 144 | return tokens 145 | 146 | def tokens_to_int(self, tokens): 147 | return [self.vocab(token) for token in tokens] 148 | 149 | def tokenize(self, sentence): 150 | tokens = self.split_sentence(sentence) 151 | if self.char_level: 152 | tokens = ' '.join(tokens) 153 | tokens = ( 154 | [self.vocab('')] 155 | + self.tokens_to_int(tokens) 156 | + [self.vocab('')] 157 | ) 158 | return torch.LongTensor(tokens) 159 | 160 | def decode_tokens(self, tokens): 161 | logger.debug(f'Decode tokens {tokens}') 162 | join_char = '' if self.char_level else ' ' 163 | text = join_char.join([ 164 | self.vocab.get_word(token) for token in tokens 165 | ]) 166 | return text 167 | 168 | def __len__(self): 169 | return len(self.vocab) 170 | 171 | def __call__(self, sentence): 172 | return self.tokenize(sentence) 173 | -------------------------------------------------------------------------------- /retrieval/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import Retrieval 2 | from . import model 3 | from . import loss 4 | from . import similarity 5 | from . import imgenc 6 | from . import txtenc 7 | from . import layers 8 | from . import data_parallel 9 | -------------------------------------------------------------------------------- /retrieval/model/data_parallel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.parallel._functions import Gather 4 | 5 | 6 | def gather(outputs, target_device, dim=0): 7 | r""" 8 | Gathers tensors from different GPUs on a specified device 9 | (-1 means the CPU). 10 | """ 11 | def gather_map(outputs): 12 | out = outputs[0] 13 | if torch.is_tensor(out): 14 | return Gather.apply(target_device, dim, *outputs) 15 | if out is None: 16 | return None 17 | if isinstance(out, dict): 18 | if not all((len(out) == len(d) for d in outputs)): 19 | raise ValueError('All dicts must have the same number of keys') 20 | return type(out)(((k, gather_map([d[k] for d in outputs])) 21 | for k in out)) 22 | return type(out)(map(gather_map, zip(*outputs))) 23 | 24 | # Recursive function calls like this create reference cycles. 25 | # Setting the function to None clears the refcycle. 
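    # The try/finally below makes sure the local closure is released even if
    # gathering fails, which breaks the reference cycle mentioned above.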
26 | try: 27 | return gather_map(outputs) 28 | finally: 29 | gather_map = None 30 | 31 | 32 | class DataParallel(nn.DataParallel): 33 | 34 | def __getattr__(self, key): 35 | try: 36 | return super(DataParallel, self).__getattr__(key) 37 | except AttributeError: 38 | return self.module.__getattribute__(key) 39 | 40 | def state_dict(self, *args, **kwgs): 41 | return self.module.state_dict(*args, **kwgs) 42 | 43 | def load_state_dict(self, *args, **kwgs): 44 | self.module.load_state_dict(*args, **kwgs) 45 | 46 | def gather(self, outputs, output_device): 47 | return gather(outputs, output_device, dim=self.dim) 48 | 49 | 50 | class DistributedDataParallel(nn.parallel.DistributedDataParallel): 51 | 52 | def __getattr__(self, key): 53 | try: 54 | return super(DistributedDataParallel, self).__getattr__(key) 55 | except AttributeError: 56 | return self.module.__getattribute__(key) 57 | 58 | def state_dict(self, *args, **kwgs): 59 | return self.module.state_dict(*args, **kwgs) 60 | 61 | def load_state_dict(self, *args, **kwgs): 62 | self.module.load_state_dict(*args, **kwgs) 63 | -------------------------------------------------------------------------------- /retrieval/model/imgenc/__init__.py: -------------------------------------------------------------------------------- 1 | from .factory import get_image_encoder, get_available_imgenc, get_img_pooling -------------------------------------------------------------------------------- /retrieval/model/imgenc/common.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | 4 | def load_state_dict_with_replace(state_dict, own_state): 5 | new_state = OrderedDict() 6 | for name, param in state_dict.items(): 7 | if name in own_state: 8 | new_state[name] = param 9 | return new_state 10 | -------------------------------------------------------------------------------- /retrieval/model/imgenc/factory.py: -------------------------------------------------------------------------------- 1 | from . import precomp 2 | from . import fullencoder 3 | from . 
import pooling 4 | import torchvision 5 | 6 | 7 | _image_encoders = { 8 | 'simple': { 9 | 'class': precomp.SimplePrecomp, 10 | 'args': {} 11 | }, 12 | 'scan': { 13 | 'class': precomp.SCANImagePrecomp, 14 | 'args': { 15 | 'img_dim': 2048, 16 | }, 17 | }, 18 | 'vsepp_precomp': { 19 | 'class': precomp.VSEImageEncoder, 20 | 'args': { 21 | 'img_dim': 2048, 22 | }, 23 | }, 24 | 'full_image': { 25 | 'class': fullencoder.ImageEncoder, 26 | 'args': { 27 | 'cnn': torchvision.models.resnet152, 28 | 'img_dim': 248, 29 | }, 30 | }, 31 | 'resnet50': { 32 | 'class': fullencoder.FullImageEncoder, 33 | 'args': { 34 | 'cnn': torchvision.models.resnet50, 35 | 'img_dim': 2048, 36 | }, 37 | }, 38 | 'resnet50_ft': { 39 | 'class': fullencoder.FullImageEncoder, 40 | 'args': { 41 | 'cnn': torchvision.models.resnet50, 42 | 'img_dim': 2048, 43 | 'finetune': True, 44 | }, 45 | }, 46 | 'resnet101': { 47 | 'class': fullencoder.FullImageEncoder, 48 | 'args': { 49 | 'cnn': torchvision.models.resnet101, 50 | 'img_dim': 2048, 51 | 'proj_regions': False, 52 | }, 53 | }, 54 | 'resnet101_ft': { 55 | 'class': fullencoder.FullImageEncoder, 56 | 'args': { 57 | 'cnn': torchvision.models.resnet101, 58 | 'img_dim': 2048, 59 | 'finetune': True, 60 | }, 61 | }, 62 | 'resnet152': { 63 | 'class': fullencoder.FullImageEncoder, 64 | 'args': { 65 | 'cnn': torchvision.models.resnet152, 66 | 'img_dim': 2048, 67 | 'proj_regions': False, 68 | }, 69 | }, 70 | 'resnet152_ft': { 71 | 'class': fullencoder.FullImageEncoder, 72 | 'args': { 73 | 'cnn': torchvision.models.resnet152, 74 | 'img_dim': 2048, 75 | 'finetune': True, 76 | }, 77 | }, 78 | 'vsepp_pt': { 79 | 'class': fullencoder.VSEPPEncoder, 80 | 'args': { 81 | 'cnn_type': 'resnet152', 82 | }, 83 | }, 84 | } 85 | 86 | 87 | def get_available_imgenc(): 88 | return _image_encoders.keys() 89 | 90 | 91 | # def get_image_encoder(name, **kwargs): 92 | # model_settings = _image_encoders[name] 93 | # model_class = model_settings['class'] 94 | # model_args = model_settings['args'] 95 | # arg_dict = dict(kwargs) 96 | # arg_dict.update(model_args) 97 | # model = model_class(**arg_dict) 98 | # return model 99 | 100 | 101 | def get_image_encoder(name, **kwargs): 102 | model_class = _image_encoders[name]['class'] 103 | model = model_class(**kwargs) 104 | return model 105 | 106 | 107 | def get_img_pooling(pool_name): 108 | 109 | _pooling = { 110 | 'mean': pooling.mean_pooling, 111 | 'max': pooling.max_pooling, 112 | 'none': lambda x: x, 113 | } 114 | 115 | return _pooling[pool_name] 116 | -------------------------------------------------------------------------------- /retrieval/model/imgenc/fullencoder.py: -------------------------------------------------------------------------------- 1 | '''Neural Network Assembler and Extender''' 2 | import types 3 | from typing import Dict, List 4 | 5 | import numpy as np 6 | import torch 7 | from torch import nn 8 | from torchvision import models 9 | 10 | from ...model.layers import attention, convblocks 11 | from ...utils.layers import default_initializer 12 | from ..similarity.measure import l1norm, l2norm 13 | from .common import load_state_dict_with_replace 14 | 15 | 16 | class BaseFeatures(nn.Module): 17 | 18 | def __init__(self, model): 19 | super(BaseFeatures, self).__init__() 20 | self.conv1 = model.conv1 21 | self.bn1 = model.bn1 22 | self.relu = model.relu 23 | self.maxpool = model.maxpool 24 | self.layer1 = model.layer1 25 | self.layer2 = model.layer2 26 | self.layer3 = model.layer3 27 | self.layer4 = model.layer4 28 | self.num_features = 
model.fc.in_features 29 | 30 | def forward(self, _input): 31 | x = self.conv1(_input) 32 | x = self.bn1(x) 33 | x = self.relu(x) 34 | x = self.maxpool(x) 35 | 36 | x = self.layer1(x) 37 | x = self.layer2(x) 38 | x = self.layer3(x) 39 | x = self.layer4(x) 40 | return x 41 | 42 | 43 | class HierarchicalFeatures(nn.Module): 44 | 45 | def __init__(self, model): 46 | super().__init__() 47 | self.conv1 = model.conv1 48 | self.bn1 = model.bn1 49 | self.relu = model.relu 50 | self.maxpool = model.maxpool 51 | self.layer1 = model.layer1 52 | self.layer2 = model.layer2 53 | self.layer3 = model.layer3 54 | self.layer4 = model.layer4 55 | self.num_features = model.fc.in_features 56 | 57 | def forward(self, _input): 58 | x = self.conv1(_input) 59 | x = self.bn1(x) 60 | x = self.relu(x) 61 | x = self.maxpool(x) 62 | 63 | a = self.layer1(x) 64 | b = self.layer2(a) 65 | c = self.layer3(b) 66 | d = self.layer4(c) 67 | 68 | return a, b, c, d 69 | 70 | 71 | class Aggregate(nn.Module): 72 | 73 | def __init__(self, ): 74 | super(Aggregate, self).__init__() 75 | 76 | def forward(self, _input): 77 | x = nn.AdaptiveAvgPool2d(1)(_input) 78 | x = x.squeeze(2).squeeze(2) 79 | return x 80 | 81 | 82 | class ImageEncoder(nn.Module): 83 | 84 | def __init__( 85 | self, cnn, img_dim, 86 | latent_size, pretrained=True, 87 | ): 88 | super().__init__() 89 | self.latent_size = latent_size 90 | 91 | # Full text encoder 92 | self.cnn = BaseFeatures(cnn(pretrained)) 93 | 94 | def forward(self, images): 95 | """Extract image feature vectors.""" 96 | # assuming that the precomputed features are already l2-normalized 97 | features = self.cnn(images) 98 | B, D, H, W = features.shape 99 | features = features.view(B, D, H*W) 100 | return features 101 | 102 | def load_state_dict(self, state_dict): 103 | """Copies parameters. overwritting the default one to 104 | accept state_dict from Full model 105 | """ 106 | new_state = load_state_dict_with_replace( 107 | state_dict=state_dict, own_state=self.state_dict() 108 | ) 109 | 110 | super().load_state_dict(new_state) 111 | 112 | # tutorials/09 - Image Captioning 113 | class VSEPPEncoder(nn.Module): 114 | 115 | def __init__(self, latent_size, finetune=False, cnn_type='vgg19', 116 | use_abs=False, no_imgnorm=False): 117 | """Load pretrained VGG19 and replace top fc layer.""" 118 | super(VSEPPEncoder, self).__init__() 119 | self.latent_size = latent_size 120 | self.no_imgnorm = no_imgnorm 121 | self.use_abs = use_abs 122 | 123 | # Load a pre-trained model 124 | self.cnn = self.get_cnn(cnn_type, True) 125 | 126 | if finetune: 127 | if cnn_type.startswith('alexnet') or cnn_type.startswith('vgg'): 128 | model.features = nn.DataParallel(model.features) 129 | model.cuda() 130 | else: 131 | model = nn.DataParallel(model).cuda() 132 | 133 | # For efficient memory usage. 
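        # requires_grad mirrors the `finetune` flag: with finetune=False the
        # backbone is frozen and only the replacement fc head below is trained.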
134 | for param in self.cnn.parameters(): 135 | param.requires_grad = finetune 136 | 137 | # Replace the last fully connected layer of CNN with a new one 138 | if cnn_type.startswith('vgg'): 139 | self.fc = nn.Linear( 140 | self.cnn.classifier._modules['6'].in_features, 141 | latent_size 142 | ) 143 | self.cnn.classifier = nn.Sequential( 144 | *list(self.cnn.classifier.children())[:-1] 145 | ) 146 | elif cnn_type.startswith('resnet'): 147 | if hasattr(self.cnn, 'module'): 148 | self.fc = nn.Linear(self.cnn.module.fc.in_features, latent_size) 149 | self.cnn.module.fc = nn.Sequential() 150 | else: 151 | self.fc = nn.Linear(self.cnn.fc.in_features, latent_size) 152 | self.cnn.fc = nn.Sequential() 153 | 154 | self.init_weights() 155 | 156 | def get_cnn(self, arch, pretrained): 157 | """Load a pretrained CNN and parallelize over GPUs 158 | """ 159 | if pretrained: 160 | print("=> using pre-trained model '{}'".format(arch)) 161 | model = models.__dict__[arch](pretrained=True) 162 | else: 163 | print("=> creating model '{}'".format(arch)) 164 | model = models.__dict__[arch]() 165 | 166 | return model 167 | 168 | def load_state_dict(self, state_dict): 169 | """ 170 | Handle the models saved before commit pytorch/vision@989d52a 171 | """ 172 | if 'cnn.classifier.1.weight' in state_dict: 173 | state_dict['cnn.classifier.0.weight'] = state_dict[ 174 | 'cnn.classifier.1.weight'] 175 | del state_dict['cnn.classifier.1.weight'] 176 | state_dict['cnn.classifier.0.bias'] = state_dict[ 177 | 'cnn.classifier.1.bias'] 178 | del state_dict['cnn.classifier.1.bias'] 179 | state_dict['cnn.classifier.3.weight'] = state_dict[ 180 | 'cnn.classifier.4.weight'] 181 | del state_dict['cnn.classifier.4.weight'] 182 | state_dict['cnn.classifier.3.bias'] = state_dict[ 183 | 'cnn.classifier.4.bias'] 184 | del state_dict['cnn.classifier.4.bias'] 185 | 186 | super(EncoderImageFull, self).load_state_dict(state_dict) 187 | 188 | def init_weights(self): 189 | """Xavier initialization for the fully connected layer 190 | """ 191 | r = np.sqrt(6.) / np.sqrt(self.fc.in_features + 192 | self.fc.out_features) 193 | self.fc.weight.data.uniform_(-r, r) 194 | self.fc.bias.data.fill_(0) 195 | 196 | def forward(self, images): 197 | """Extract image feature vectors.""" 198 | features = self.cnn(images) 199 | # normalization in the image embedding space 200 | features = l2norm(features, dim=-1) 201 | # linear projection to the joint embedding space 202 | features = self.fc(features) 203 | 204 | # normalization in the joint embedding space 205 | if not self.no_imgnorm: 206 | features = l2norm(features, dim=-1) 207 | 208 | # take the absolute value of the embedding (used in order embeddings) 209 | if self.use_abs: 210 | features = torch.abs(features) 211 | 212 | return features 213 | 214 | 215 | class FullImageEncoder(nn.Module): 216 | 217 | def __init__( 218 | self, cnn, img_dim, latent_size, 219 | no_imgnorm=False, pretrained=True, 220 | proj_regions=True, finetune=False 221 | ): 222 | super(FullImageEncoder, self).__init__() 223 | self.latent_size = latent_size 224 | self.proj_regions = proj_regions 225 | self.no_imgnorm = no_imgnorm 226 | 227 | # Full text encoder 228 | self.cnn = BaseFeatures(cnn(pretrained)) 229 | 230 | # # For efficient memory usage. 
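        # NOTE: with the freezing block below left commented out, the backbone
        # always receives gradients and the `finetune` argument is effectively
        # unused in this encoder.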
231 | # for param in self.cnn.parameters(): 232 | # param.requires_grad = finetune 233 | 234 | # Only applies pooling when region_pool is enabled 235 | self.region_pool = nn.AdaptiveAvgPool1d(1) 236 | # if proj_regions: 237 | # self.region_pool = lambda x: x 238 | 239 | self.fc = nn.Linear(img_dim, latent_size) 240 | 241 | # self.apply(default_initializer) 242 | self.init_weights() 243 | 244 | # self.aggregate = Aggregate() 245 | 246 | def init_weights(self): 247 | """Xavier initialization for the fully connected layer 248 | """ 249 | import numpy as np 250 | r = np.sqrt(6.) / np.sqrt(self.fc.in_features + 251 | self.fc.out_features) 252 | self.fc.weight.data.uniform_(-r, r) 253 | self.fc.bias.data.fill_(0) 254 | 255 | def forward(self, images): 256 | """Extract image feature vectors.""" 257 | # assuming that the precomputed features are already l2-normalized 258 | features = self.cnn(images) 259 | 260 | features = features.view(features.shape[0], features.shape[1], -1) 261 | features = l2norm(features, dim=1) 262 | 263 | if not self.proj_regions: 264 | features = self.region_pool(features) 265 | 266 | features = features.permute(0, 2, 1) 267 | 268 | features = self.fc(features) 269 | 270 | # normalize in the joint embedding space 271 | if not self.no_imgnorm: 272 | features = l2norm(features, dim=-1) 273 | 274 | return features 275 | 276 | def load_state_dict(self, state_dict): 277 | """Copies parameters. overwritting the default one to 278 | accept state_dict from Full model 279 | """ 280 | new_state = load_state_dict_with_replace( 281 | state_dict=state_dict, own_state=self.state_dict() 282 | ) 283 | 284 | super(FullImageEncoder, self).load_state_dict(new_state) 285 | 286 | 287 | class FullHierImageEncoder(nn.Module): 288 | 289 | def __init__( 290 | self, cnn, img_dim, latent_size, 291 | no_imgnorm=False, pretrained=True, 292 | proj_regions=True, 293 | ): 294 | super().__init__() 295 | self.latent_size = latent_size 296 | self.proj_regions = proj_regions 297 | self.no_imgnorm = no_imgnorm 298 | 299 | # Full text encoder 300 | self.cnn = HierarchicalFeatures(cnn(pretrained)) 301 | 302 | # Only applies pooling when region_pool is enabled 303 | self.region_pool = nn.AdaptiveAvgPool1d(1) 304 | # if proj_regions: 305 | # self.region_pool = lambda x: x 306 | 307 | self.fc = nn.Linear(5888, latent_size) 308 | self.max_pool = nn.AdaptiveMaxPool2d(1) 309 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 310 | 311 | self.apply(default_initializer) 312 | 313 | # self.aggregate = Aggregate() 314 | 315 | def forward(self, images): 316 | """Extract image feature vectors.""" 317 | # assuming that the precomputed features are already l2-normalized 318 | a, b, c, d = self.cnn(images) 319 | vectors = [self.max_pool(x) for x in [a, b, c, d]] 320 | d = self.avg_pool(d) 321 | vectors = torch.cat(vectors + [d], dim=1) 322 | vectors = vectors.squeeze(-1)#.squeeze(-1) 323 | vectors = vectors.permute(0, 2, 1) 324 | latent = self.fc(vectors) 325 | 326 | return latent 327 | 328 | def load_state_dict(self, state_dict): 329 | """Copies parameters. 
overwritting the default one to 330 | accept state_dict from Full model 331 | """ 332 | new_state = load_state_dict_with_replace( 333 | state_dict=state_dict, own_state=self.state_dict() 334 | ) 335 | 336 | super(FullHierImageEncoder, self).load_state_dict(new_state) 337 | -------------------------------------------------------------------------------- /retrieval/model/imgenc/pooling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def mean_pooling(x,): 5 | out = x.mean(1) 6 | return out 7 | 8 | 9 | def max_pooling(texts,): 10 | out = torch.stack( 11 | [t[:l].max(0)[0] for t, l in zip(texts, lengths) 12 | ], dim=0) 13 | return out 14 | 15 | 16 | # def last_hidden_state_pool(texts, lengths): 17 | # I = torch.LongTensor(lengths).view(-1, 1, 1) 18 | # I = I.expand(texts.size(0), 1, texts[0].size(1))-1 19 | 20 | # if torch.cuda.is_available(): 21 | # I = I.cuda() 22 | 23 | # out = torch.gather(texts, 1, I).squeeze(1) 24 | -------------------------------------------------------------------------------- /retrieval/model/imgenc/precomp.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from ...utils.layers import default_initializer 7 | from ..similarity.measure import l1norm, l2norm 8 | from ..layers import attention, convblocks 9 | 10 | import numpy as np 11 | 12 | 13 | def load_state_dict_with_replace(state_dict, own_state): 14 | new_state = OrderedDict() 15 | for name, param in state_dict.items(): 16 | if name in own_state: 17 | new_state[name] = param 18 | return new_state 19 | 20 | 21 | class SCANImagePrecomp(nn.Module): 22 | 23 | def __init__(self, img_dim, latent_size, no_imgnorm=False, ): 24 | super(SCANImagePrecomp, self).__init__() 25 | self.latent_size = latent_size 26 | self.no_imgnorm = no_imgnorm 27 | self.fc = nn.Linear(img_dim, latent_size) 28 | 29 | self.init_weights() 30 | 31 | def init_weights(self): 32 | """Xavier initialization for the fully connected layer 33 | """ 34 | r = np.sqrt(6.) / np.sqrt(self.fc.in_features + 35 | self.fc.out_features) 36 | self.fc.weight.data.uniform_(-r, r) 37 | self.fc.bias.data.fill_(0) 38 | 39 | def forward(self, batch): 40 | """Extract image feature vectors.""" 41 | # assuming that the precomputed features are already l2-normalized 42 | images = batch['image'].to(self.device) 43 | features = self.fc(images) 44 | 45 | # normalize in the joint embedding space 46 | if not self.no_imgnorm: 47 | features = l2norm(features, dim=-1) 48 | 49 | return features 50 | 51 | def load_state_dict(self, state_dict): 52 | """Copies parameters. overwritting the default one to 53 | accept state_dict from Full model 54 | """ 55 | new_state = load_state_dict_with_replace( 56 | state_dict=state_dict, own_state=self.state_dict() 57 | ) 58 | 59 | super(SCANImagePrecomp, self).load_state_dict(new_state) 60 | 61 | 62 | class SimplePrecomp(nn.Module): 63 | 64 | def __init__(self, img_dim, latent_size, no_imgnorm=False, ): 65 | super(SimplePrecomp, self).__init__() 66 | self.latent_size = latent_size 67 | self.no_imgnorm = no_imgnorm 68 | self.fc = nn.Linear(img_dim, latent_size) 69 | 70 | self.init_weights() 71 | 72 | def init_weights(self): 73 | """Xavier initialization for the fully connected layer 74 | """ 75 | r = np.sqrt(6.) 
/ np.sqrt(self.fc.in_features + 76 | self.fc.out_features) 77 | self.fc.weight.data.uniform_(-r, r) 78 | self.fc.bias.data.fill_(0) 79 | 80 | def forward(self, batch): 81 | """Extract image feature vectors.""" 82 | # assuming that the precomputed features are already l2-normalized 83 | images = batch['image'].to(self.device) 84 | features = self.fc(images) 85 | features = nn.LeakyReLU(0.1)(features) 86 | 87 | # normalize in the joint embedding space 88 | if not self.no_imgnorm: 89 | features = l2norm(features, dim=-1) 90 | 91 | return features 92 | 93 | def load_state_dict(self, state_dict): 94 | """Copies parameters. overwritting the default one to 95 | accept state_dict from Full model 96 | """ 97 | new_state = load_state_dict_with_replace( 98 | state_dict=state_dict, own_state=self.state_dict() 99 | ) 100 | 101 | super(SCANImagePrecomp, self).load_state_dict(new_state) 102 | 103 | 104 | class VSEImageEncoder(nn.Module): 105 | 106 | def __init__(self, img_dim, latent_size, no_imgnorm=False, device=None): 107 | super(VSEImageEncoder, self).__init__() 108 | self.device = device 109 | self.latent_size = latent_size 110 | self.no_imgnorm = no_imgnorm 111 | self.fc = nn.Linear(img_dim, latent_size) 112 | self.pool = nn.AdaptiveAvgPool1d(1) 113 | 114 | self.apply(default_initializer) 115 | 116 | def forward(self, batch): 117 | """Extract image feature vectors.""" 118 | # assuming that the precomputed features are already l2-normalized 119 | 120 | images = batch['image'].to(self.device).to(self.device) 121 | 122 | images = self.pool(images.permute(0, 2, 1)) # Global pooling 123 | images = images.permute(0, 2, 1) 124 | features = self.fc(images) 125 | # normalize in the joint embedding space 126 | if not self.no_imgnorm: 127 | features = l2norm(features, dim=-1) 128 | 129 | return features 130 | 131 | def load_state_dict(self, state_dict): 132 | """Copies parameters. overwritting the default one to 133 | accept state_dict from Full model 134 | """ 135 | new_state = load_state_dict_with_replace( 136 | state_dict=state_dict, own_state=self.state_dict() 137 | ) 138 | 139 | super(VSEImageEncoder, self).load_state_dict(new_state) 140 | 141 | -------------------------------------------------------------------------------- /retrieval/model/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from . import attention 2 | from . 
import convblocks 3 | -------------------------------------------------------------------------------- /retrieval/model/layers/adapt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ADAPT(nn.Module): 6 | 7 | def __init__( 8 | self, value_size, k=None, query_size=None, 9 | nonlinear_proj=False, groups=1, 10 | ): 11 | ''' 12 | value_size (int): size of the features from the value matrix 13 | query_size (int): size of the global query vector 14 | k (int, optional): only used for non-linear projection 15 | nonlinear_proj (bool): whether to project gamma and beta non-linearly 16 | groups (int): number of feature groups (default=1) 17 | ''' 18 | super().__init__() 19 | 20 | self.query_size = query_size 21 | self.groups = groups 22 | 23 | if query_size is None: 24 | query_size = value_size 25 | 26 | if nonlinear_proj: 27 | self.fc_gamma = nn.Sequential( 28 | nn.Linear(query_size, value_size//k), 29 | nn.ReLU(inplace=True), 30 | nn.Linear(value_size//k, value_size), 31 | ) 32 | 33 | self.fc_beta = nn.Sequential( 34 | nn.Linear(query_size, value_size//k), 35 | nn.ReLU(inplace=True), 36 | nn.Linear(value_size//k, value_size), 37 | ) 38 | else: 39 | self.fc_gamma = nn.Sequential( 40 | nn.Linear(query_size, value_size//groups), 41 | ) 42 | 43 | self.fc_beta = nn.Sequential( 44 | nn.Linear(query_size, value_size//groups), 45 | ) 46 | 47 | # self.fc_gamma = nn.Linear(cond_vector_size, in_features) 48 | # self.fc_beta = nn.Linear(cond_vector_size, in_features) 49 | 50 | def forward(self, value, query): 51 | ''' 52 | 53 | Adapt embedding matrix (value) given a query vector. 54 | Dimension order is the same of the convolutional layers. 55 | 56 | Arguments: 57 | feat_matrix {torch.FloatTensor} 58 | -- shape: batch, features, timesteps 59 | cond_vector {torch.FloatTensor} 60 | -- shape: ([1 or batch], features) 61 | 62 | Returns: 63 | torch.FloatTensor 64 | -- shape: batch, features, timesteps 65 | 66 | Special cases: 67 | When query shape is (1, features) it is performed 68 | one-to-many embedding adaptation. A single vector is 69 | used to filter all instances from the value matrix 70 | leveraging the brodacast of the query vector. 71 | This is the default option for retrieval. 72 | 73 | When query shape is (batch, features) it is performed 74 | pairwise embedding adaptation. i.e., adaptation is performed 75 | line by line, and value and query must be aligned. 76 | This could be used for VQA or other tasks that don't require 77 | ranking all instances from a set. 
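        Example (illustrative shapes only, assuming groups=1 and
        query_size == value_size == 1024):

            adapt = ADAPT(value_size=1024, query_size=1024)
            value = torch.randn(128, 1024, 36)   # region features, one row per image
            query = torch.randn(1, 1024)         # single global query -> broadcast
            out = adapt(value, query)            # (128, 1024, 36)

            query = torch.randn(128, 1024)       # one query per row -> pairwise
            out = adapt(value, query)            # (128, 1024, 36)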
78 | 79 | ''' 80 | 81 | B, D, _ = value.shape 82 | Bv, Dv = query.shape 83 | 84 | value = value.view( 85 | B, D//self.groups, self.groups, -1 86 | ) 87 | 88 | gammas = self.fc_gamma(query).view( 89 | Bv, Dv//self.groups, 1, 1 90 | ) 91 | betas = self.fc_beta(query).view( 92 | Bv, Dv//self.groups, 1, 1 93 | ) 94 | 95 | normalized = value * (gammas + 1) + betas 96 | normalized = normalized.view(B, D, -1) 97 | return normalized 98 | -------------------------------------------------------------------------------- /retrieval/model/layers/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class SelfAttention(nn.Module): 7 | 8 | """ Self attention Layer """ 9 | def __init__(self, in_dim, activation, k=8): 10 | super(SelfAttention, self).__init__() 11 | self.chanel_in = in_dim 12 | self.activation = activation 13 | 14 | self.query_conv = nn.Conv1d( 15 | in_channels=in_dim, 16 | out_channels=in_dim//k, 17 | kernel_size=1, 18 | ) 19 | self.key_conv = nn.Conv1d( 20 | in_channels=in_dim, 21 | out_channels=in_dim//k, 22 | kernel_size=1, 23 | ) 24 | self.value_conv = nn.Conv1d( 25 | in_channels=in_dim, 26 | out_channels=in_dim, 27 | kernel_size=1 28 | ) 29 | self.gamma = nn.Parameter(torch.zeros(1)) 30 | 31 | self.softmax = nn.Softmax(dim=-1) 32 | 33 | def forward(self, x, return_attn=False): 34 | """ 35 | inputs : 36 | x : input feature maps(B X C X T) 37 | returns : 38 | out : self attention value + input feature 39 | attention: B X N X N (N is Width*T) 40 | """ 41 | B, C, T = x.size() 42 | 43 | # B X C X (N) 44 | proj_query = self.query_conv(x).view(B, -1, T).permute(0,2,1) 45 | # B X C x (W*H) 46 | proj_key = self.key_conv(x).view(B, -1, T) 47 | energy = torch.bmm(proj_query, proj_key) 48 | # B X (N) X (N) 49 | attention = self.softmax(energy) 50 | # B X C X N 51 | proj_value = self.value_conv(x).view(B, -1, T) 52 | 53 | out = torch.bmm(proj_value, attention.permute(0, 2, 1)) 54 | out = out.view(B, C, T) 55 | 56 | out = self.gamma * out + x 57 | 58 | if not return_attn: 59 | return out 60 | 61 | return out, attention 62 | 63 | 64 | 65 | class MultiHeadAttention(nn.Module): 66 | 67 | def __init__( 68 | self, input_dim, h=8, 69 | k=8, r=8, inner_activation=nn.Identity(), 70 | dropout=0.1 71 | ): 72 | super(MultiHeadAttention, self).__init__() 73 | 74 | self.num_heads = h 75 | if inner_activation is None: 76 | inner_activation = nn.Identity() 77 | 78 | self_attentions = [] 79 | for i in range(h): 80 | sa = SelfAttention(input_dim, inner_activation, k=k) 81 | self_attentions.append(sa) 82 | 83 | self.fcs = nn.Sequential( 84 | nn.Linear(input_dim, input_dim//r, 1), 85 | # nn.BatchNorm1d(input_dim//r//h), 86 | nn.ReLU(inplace=True), 87 | nn.Dropout(dropout), 88 | nn.Linear(input_dim//r, input_dim // h, 1), 89 | # nn.BatchNorm1d(input_dim//h), 90 | nn.ReLU(inplace=True), 91 | nn.Dropout(dropout), 92 | ) 93 | self.inner_size = input_dim // h 94 | 95 | self.self_attentions = nn.ModuleList(self_attentions) 96 | 97 | def forward(self, x): 98 | """Extract image feature vectors.""" 99 | residual = x 100 | 101 | outs = [] 102 | for sa in self.self_attentions: 103 | _x = sa(x) 104 | outs.append(_x) 105 | 106 | out = torch.stack(outs, 2) 107 | b, d, heads, regions = out.shape 108 | 109 | out = out.view(out.shape[0], out.shape[1], -1).permute(0, 2, 1) 110 | out = self.fcs(out) 111 | out = out.permute(0, 2, 1).contiguous() 112 | out = out.view(b, self.inner_size * self.num_heads, 
regions) 113 | 114 | out = out + residual 115 | 116 | return out 117 | 118 | 119 | class ModuleSelfAttention(nn.Module): 120 | 121 | """ Self attention Layer """ 122 | def __init__( 123 | self, module, in_dim, 124 | activation, groups=1, k=8, 125 | **kwargs 126 | ): 127 | super(ModuleSelfAttention, self).__init__() 128 | self.chanel_in = in_dim 129 | self.activation = activation 130 | 131 | self.query_conv = module( 132 | in_channels=in_dim, 133 | out_channels=in_dim//k, 134 | kernel_size=1, 135 | groups=groups, 136 | **kwargs, 137 | ) 138 | self.key_conv = module( 139 | in_channels=in_dim, 140 | out_channels=in_dim//k, 141 | kernel_size=1, 142 | groups=groups, 143 | **kwargs, 144 | ) 145 | self.value_conv = module( 146 | in_channels=in_dim, 147 | out_channels=in_dim, 148 | kernel_size=1, 149 | groups=groups, 150 | **kwargs, 151 | ) 152 | self.gamma = nn.Parameter(torch.zeros(1)) 153 | 154 | self.softmax = nn.Softmax(dim=-1) 155 | 156 | def forward(self, x, base, return_attn=False): 157 | """ 158 | inputs : 159 | x : input feature maps(B X C X T) 160 | returns : 161 | out : self attention value + input feature 162 | attention: B X N X N (N is Width*T) 163 | """ 164 | B, C, T = x.size() 165 | 166 | # B X C X (N) 167 | proj_query = self.query_conv(x, base).view(B, -1, T).permute(0,2,1) 168 | # B X C x (W*H) 169 | proj_key = self.key_conv(x, base).view(B, -1, T) 170 | energy = torch.bmm(proj_query, proj_key) 171 | # B X (N) X (N) 172 | attention = self.softmax(energy) 173 | # B X C X N 174 | proj_value = self.value_conv(x, base).view(B, -1, T) 175 | 176 | out = torch.bmm(proj_value, attention.permute(0, 2, 1)) 177 | out = out.view(B, C, T) 178 | 179 | out = self.gamma * out + x 180 | 181 | if not return_attn: 182 | return out 183 | 184 | return out, attention 185 | 186 | 187 | class Attention(nn.Module): 188 | 189 | def __init__( 190 | self, input_dim=1024, hidden_units=512, 191 | mlp_glimpses=0, smooth=10.): 192 | super(Attention, self).__init__() 193 | self.mlp_glimpses = mlp_glimpses 194 | 195 | self.fusion = lambda x: torch.cat(x, -1) 196 | self.smooth = smooth 197 | 198 | if self.mlp_glimpses > 0: 199 | self.linear0 = nn.Linear(input_dim, hidden_units) 200 | self.linear1 = nn.Linear(hidden_units, mlp_glimpses) 201 | 202 | def forward(self, q, v): 203 | ''' 204 | q: (batch, dim) 205 | v: (1, regions/timesteps, dim) 206 | ''' 207 | alpha = self.process_attention(q, v) 208 | 209 | if self.mlp_glimpses > 0: 210 | alpha = self.linear0(alpha) 211 | alpha = F.relu(alpha) 212 | alpha = self.linear1(alpha) 213 | 214 | alpha = F.softmax(alpha, dim=1) 215 | 216 | if alpha.size(2) > 1: # nb_glimpses > 1 217 | alphas = torch.unbind(alpha, dim=2) 218 | v_outs = [] 219 | for alpha in alphas: 220 | alpha = alpha.unsqueeze(2).expand_as(v) 221 | v_out = alpha*v 222 | v_out = v_out.sum(1) 223 | v_outs.append(v_out) 224 | v_out = torch.cat(v_outs, dim=1) 225 | else: 226 | # alpha = alpha.expand_as(v) 227 | v_out = alpha*v 228 | v_out = v_out.sum(1) 229 | 230 | return v_out 231 | 232 | def process_attention(self, q, v): 233 | batch_size = q.size(0) 234 | n_regions = v.size(1) 235 | q = q[:,None,:].expand(batch_size, n_regions, q.size(1)) 236 | v = v.expand(batch_size, n_regions, v.size(-1)) 237 | 238 | alpha = torch.cat([q, v], -1) 239 | return alpha 240 | 241 | 242 | 243 | class AttentionL(nn.Module): 244 | 245 | def __init__(self, input_dim, hidden_units, mlp_glimpses=0): 246 | super().__init__() 247 | self.mlp_glimpses = mlp_glimpses 248 | self.fusion = lambda x: torch.cat(x, dim=-1) 249 | if 
self.mlp_glimpses > 0: 250 | self.linear0 = nn.Linear(input_dim, hidden_units) 251 | self.linear1 = nn.Linear(hidden_units, mlp_glimpses) 252 | 253 | def forward(self, q, v): 254 | alpha = self.process_attention(q, v) 255 | 256 | if self.mlp_glimpses > 0: 257 | alpha = self.linear0(alpha) 258 | alpha = F.relu(alpha) 259 | alpha = self.linear1(alpha) 260 | 261 | alpha = F.softmax(alpha, dim=1) 262 | 263 | # FIXME: wtf is happening here 264 | if alpha.size(2) > 1: # nb_glimpses > 1 265 | alphas = torch.unbind(alpha, dim=2) 266 | v_outs = [] 267 | for alpha in alphas: 268 | alpha = alpha.unsqueeze(2).expand_as(v) 269 | v_out = alpha*v 270 | v_out = v_out.sum(1) 271 | v_outs.append(v_out) 272 | v_out = torch.cat(v_outs, dim=1) 273 | else: 274 | # alpha = alpha.expand_as(v) 275 | v_out = alpha*v 276 | v_out = v_out.sum(1) 277 | return v_out 278 | 279 | def process_attention(self, q, v): 280 | ''' 281 | q: (batch, dim) 282 | v: (regions, dim) 283 | 284 | ''' 285 | batch_size, dimensions = q.shape 286 | _, regions, _ = v.shape 287 | 288 | q = q[:,None,:].expand(batch_size, regions, dimensions) 289 | v = v.expand(batch_size, regions, dimensions) 290 | 291 | alpha = self.fusion([q, v]) 292 | # alpha = alpha.view(batch_size, n_regions, -1) 293 | return alpha 294 | -------------------------------------------------------------------------------- /retrieval/model/layers/convblocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | 6 | class GELU(nn.Module): 7 | """ 8 | Paper Section 3.4, last paragraph notice that BERT used the GELU instead of RELU 9 | """ 10 | 11 | def forward(self, x): 12 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 13 | 14 | 15 | class ConvBlock(nn.Module): 16 | 17 | def __init__( 18 | self, 19 | batchnorm=True, 20 | activation=nn.LeakyReLU(0.1, inplace=True), 21 | **kwargs 22 | ): 23 | super(ConvBlock, self).__init__() 24 | 25 | layers = [] 26 | layers.append( 27 | nn.Conv1d( 28 | in_channels=kwargs['in_channels'], 29 | out_channels=kwargs['out_channels'], 30 | kernel_size=kwargs['kernel_size'], 31 | padding=kwargs['padding'], 32 | ) 33 | ) 34 | 35 | if batchnorm: 36 | layers.append(nn.BatchNorm1d(kwargs['out_channels'])) 37 | if activation is not None: 38 | layers.append(activation) 39 | 40 | self.conv = nn.Sequential(*layers) 41 | 42 | def forward(self,x): 43 | return self.conv(x) 44 | 45 | 46 | class ParallelBlock(nn.Module): 47 | 48 | def __init__( 49 | self, in_channels, kernel_sizes, 50 | out_channels, paddings, 51 | ): 52 | super().__init__() 53 | 54 | self.convs = nn.ModuleList([ 55 | ConvBlock( 56 | in_channels=in_channels, 57 | out_channels=o, 58 | kernel_size=k, 59 | padding=p, 60 | ) 61 | for k, p, o 62 | in zip(kernel_sizes, paddings, out_channels) 63 | ]) 64 | 65 | def forward(self, x): 66 | 67 | outs = [ 68 | conv(x) for conv in self.convs 69 | ] 70 | t = min([out.shape[-1] for out in outs]) 71 | outs = torch.cat([out[:,:,:t] for out in outs], dim=1) 72 | return outs 73 | 74 | -------------------------------------------------------------------------------- /retrieval/model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | def adjust_k(epoch, initial_k, increase_k, max_violation=False): 7 | """ 8 | Update loss hyper-parameter k 9 | linearly from intial_k to 1 according to 10 | the number of epochs 11 | 
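    Example (illustrative values): with initial_k=0.5 and increase_k=0.1,
    epoch 0 gives k=0.5, epoch 3 gives k=0.8, and from epoch 5 onward k stays
    capped at 1.0. With max_violation=True the schedule is bypassed and k is
    always 1.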
""" 12 | if max_violation: 13 | return 1. 14 | 15 | return min(initial_k + (increase_k * epoch), 1.) 16 | 17 | 18 | def cosine_sim(im, s,): 19 | """ 20 | Cosine similarity between all the 21 | image and sentence pairs 22 | """ 23 | return im.mm(s.t()) 24 | 25 | 26 | def cosine_sim_numpy(im, s): 27 | """ 28 | Cosine similarity between all the 29 | image and sentence pairs 30 | """ 31 | return im.dot(s.T) 32 | 33 | 34 | class ContrastiveLoss(nn.Module): 35 | """ 36 | Compute contrastive loss 37 | """ 38 | 39 | def __init__( 40 | self, margin=0.2, 41 | max_violation=True, 42 | weight=1., beta=0.999, 43 | ): 44 | super().__init__() 45 | self.margin = margin 46 | self.sim = cosine_sim 47 | self.weight = weight 48 | self.max_violation = max_violation 49 | self.beta = beta 50 | 51 | self.iteration = 0 52 | self.k = 0 53 | 54 | def adjust_k(self, ): 55 | """ 56 | Update loss hyper-parameter k 57 | linearly from intial_k to 1 according to 58 | the number of epochs 59 | """ 60 | self.iteration += 1 61 | 62 | if self.max_violation: 63 | self.k = 1 64 | return 1. 65 | 66 | self.k = (1.-self.beta**np.float(self.iteration)) 67 | return self.k 68 | 69 | def forward(self, scores ): 70 | # compute image-sentence score matrix 71 | # scores = self.sim(im, s) 72 | 73 | diagonal = scores.diag().view(scores.size(0), 1) 74 | d1 = diagonal.expand_as(scores) 75 | d2 = diagonal.t().expand_as(scores) 76 | 77 | # compare every diagonal score to scores in its column 78 | # caption retrieval 79 | cost_s = (self.margin + scores - d1).clamp(min=0) 80 | # compare every diagonal score to scores in its row 81 | # image retrieval 82 | cost_im = (self.margin + scores - d2).clamp(min=0) 83 | 84 | # clear diagonals 85 | mask = torch.eye(scores.size(0)) > .5 86 | I = mask#.cuda() 87 | I = I.to(cost_s.device) 88 | cost_s = cost_s.masked_fill_(I, 0) 89 | cost_im = cost_im.masked_fill_(I, 0) 90 | 91 | cost_s_t = cost_s.sum() 92 | cost_im_t = cost_im.sum() 93 | 94 | k = self.adjust_k() 95 | 96 | cost_all_k = (cost_s_t + cost_im_t) * (1. - k) 97 | 98 | # keep the maximum violating negative for each query 99 | cost_s_max = cost_s.max(1)[0] 100 | cost_im_max = cost_im.max(0)[0] 101 | 102 | cost_hard_k = (cost_s_max.sum() + cost_im_max.sum()) * k 103 | 104 | total_loss = cost_all_k + cost_hard_k 105 | 106 | return total_loss * self.weight 107 | 108 | def __repr__(self): 109 | return(( 110 | f'ContrastiveLoss (margin={self.margin}, ' 111 | f'device={self.device}, ' 112 | f'similarity_fn={self.sim}, ' 113 | f'weight={self.weight}, ' 114 | f'max_violation={self.max_violation}, ' 115 | f'beta={self.beta})' 116 | )) 117 | 118 | 119 | 120 | class ContrastiveLossWithSoftmax(nn.Module): 121 | """ 122 | Compute contrastive loss 123 | """ 124 | 125 | def __init__( 126 | self, margin=0.2, 127 | max_violation=True, 128 | weight=1., beta=0.999, smooth=20, 129 | ): 130 | super().__init__() 131 | self.margin = margin 132 | self.sim = cosine_sim 133 | self.weight = weight 134 | self.max_violation = max_violation 135 | self.beta = beta 136 | 137 | self.loss_softmax = SoftmaxLoss(smooth=smooth) 138 | 139 | self.iteration = 0 140 | self.k = 0 141 | 142 | def adjust_k(self, ): 143 | """ 144 | Update loss hyper-parameter k 145 | linearly from intial_k to 1 according to 146 | the number of epochs 147 | """ 148 | self.iteration += 1 149 | 150 | if self.max_violation: 151 | self.k = 1 152 | return 1. 
153 | 154 | self.k = (1.-self.beta**np.float(self.iteration)) 155 | return self.k 156 | 157 | def forward(self, scores ): 158 | 159 | lst = self.loss_softmax(scores) 160 | diagonal = scores.diag().view(scores.size(0), 1) 161 | d1 = diagonal.expand_as(scores) 162 | d2 = diagonal.t().expand_as(scores) 163 | 164 | # compare every diagonal score to scores in its column 165 | # caption retrieval 166 | cost_s = (self.margin + scores - d1).clamp(min=0) 167 | # compare every diagonal score to scores in its row 168 | # image retrieval 169 | cost_im = (self.margin + scores - d2).clamp(min=0) 170 | 171 | # clear diagonals 172 | mask = torch.eye(scores.size(0)) > .5 173 | I = mask#.cuda() 174 | I = I.to(cost_s.device) 175 | cost_s = cost_s.masked_fill_(I, 0) 176 | cost_im = cost_im.masked_fill_(I, 0) 177 | 178 | cost_s_t = cost_s.sum() 179 | cost_im_t = cost_im.sum() 180 | 181 | k = self.adjust_k() 182 | 183 | cost_all_k = (cost_s_t + cost_im_t) * (1. - k) 184 | 185 | # keep the maximum violating negative for each query 186 | cost_s_max = cost_s.max(1)[0] 187 | cost_im_max = cost_im.max(0)[0] 188 | 189 | cost_hard_k = (cost_s_max.sum() + cost_im_max.sum()) * k 190 | 191 | total_loss = cost_all_k + cost_hard_k 192 | 193 | return total_loss * self.weight + lst 194 | 195 | def __repr__(self): 196 | return(( 197 | f'ContrastiveLoss (margin={self.margin}, ' 198 | f'device={self.device}, ' 199 | f'similarity_fn={self.sim}, ' 200 | f'weight={self.weight}, ' 201 | f'max_violation={self.max_violation}, ' 202 | f'beta={self.beta})' 203 | )) 204 | 205 | 206 | class SoftmaxLoss(nn.Module): 207 | """ 208 | Compute contrastive loss 209 | """ 210 | 211 | def __init__( 212 | self, smooth=15, **kwargs 213 | ): 214 | super().__init__() 215 | 216 | self.loss_im = nn.CrossEntropyLoss() 217 | self.loss_tx = nn.CrossEntropyLoss() 218 | self.iteration = 0 219 | self.k = 0 220 | self.smooth = smooth 221 | 222 | def forward(self, scores ): 223 | self.iteration += 1 224 | 225 | scores = scores.cuda() 226 | scores = scores * self.smooth 227 | 228 | labels = torch.arange(0, len(scores)).cuda().long() 229 | 230 | l_im = self.loss_im(scores, labels) 231 | l_tx = self.loss_tx(scores.t(), labels) 232 | 233 | l = l_im + l_tx 234 | return l 235 | 236 | 237 | class ContrastiveLoss_(nn.Module): 238 | """ 239 | Compute contrastive loss 240 | """ 241 | 242 | def __init__( 243 | self, device, 244 | margin=0.2, max_violation=True, 245 | weight=1., initial_k=1., increase_k=0, 246 | ): 247 | super(ContrastiveLoss, self).__init__() 248 | self.margin = margin 249 | self.device = device 250 | self.sim = cosine_sim 251 | self.weight = weight 252 | self.initial_k = initial_k 253 | self.increase_k = increase_k 254 | self.max_violation = max_violation 255 | self.epoch = -1 256 | 257 | def update_epoch(self, epoch=None): 258 | if epoch is None: 259 | self.epoch += 1 260 | 261 | def update_k(self,): 262 | self.k = adjust_k( 263 | self.epoch, 264 | initial_k=self.initial_k, 265 | increase_k=self.increase_k, 266 | max_violation=self.max_violation, 267 | ) 268 | return self.k 269 | 270 | def forward(self, scores): 271 | 272 | # compute image-sentence score matrix 273 | diagonal = scores.diag().view(scores.size(0), 1) 274 | d1 = diagonal.expand_as(scores) 275 | d2 = diagonal.t().expand_as(scores) 276 | 277 | # compare every diagonal score to scores in its column 278 | # caption retrieval 279 | cost_s = (self.margin + scores - d1).clamp(min=0) 280 | # compare every diagonal score to scores in its row 281 | # image retrieval 282 | cost_im = (self.margin + 
scores - d2).clamp(min=0) 283 | 284 | # clear diagonals 285 | mask = torch.eye(scores.size(0)) > .5 286 | I = mask.to(self.device) 287 | 288 | cost_s = cost_s.masked_fill_(I, 0) 289 | cost_im = cost_im.masked_fill_(I, 0) 290 | 291 | cost_s_t = cost_s.sum() 292 | cost_im_t = cost_im.sum() 293 | 294 | k = self.update_k() 295 | 296 | cost_all_k = (cost_s_t + cost_im_t) * (1. - k) 297 | 298 | # keep the maximum violating negative for each query 299 | cost_s_max = cost_s.max(1)[0] 300 | cost_im_max = cost_im.max(0)[0] 301 | 302 | cost_hard_k = (cost_s_max.sum() + cost_im_max.sum()) * k 303 | 304 | total_loss = cost_all_k + cost_hard_k 305 | 306 | return total_loss * self.weight 307 | 308 | 309 | _loss = { 310 | 'contrastive': ContrastiveLoss, 311 | 'softmax': SoftmaxLoss, 312 | 'contrastive_softmax': ContrastiveLossWithSoftmax, 313 | } 314 | 315 | def get_loss(name, params): 316 | return _loss[name](**params) 317 | -------------------------------------------------------------------------------- /retrieval/model/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from . import data_parallel 5 | from ..utils.logger import get_logger 6 | from .imgenc import get_image_encoder, get_img_pooling 7 | from .similarity.factory import get_similarity_object 8 | from .similarity.similarity import Similarity 9 | from .txtenc import get_text_encoder, get_txt_pooling 10 | 11 | logger = get_logger() 12 | 13 | 14 | class Retrieval(nn.Module): 15 | 16 | def __init__( 17 | self, txt_enc={}, img_enc={}, similarity={}, 18 | ml_similarity={}, tokenizers=None, latent_size=1024, 19 | **kwargs 20 | ): 21 | super().__init__() 22 | 23 | self.master = True 24 | 25 | self.latent_size = latent_size 26 | self.img_enc = get_image_encoder( 27 | name=img_enc.name, 28 | latent_size=latent_size, 29 | **img_enc.params 30 | ) 31 | 32 | logger.info(( 33 | 'Image encoder created: ' 34 | f'{img_enc.name,}' 35 | )) 36 | 37 | self.txt_enc = get_text_encoder( 38 | name = txt_enc.name, 39 | latent_size=latent_size, 40 | tokenizers=tokenizers, 41 | **txt_enc.params, 42 | ) 43 | 44 | self.tokenizers = tokenizers 45 | 46 | self.txt_pool = get_txt_pooling(txt_enc.pooling) 47 | self.img_pool = get_img_pooling(img_enc.pooling) 48 | 49 | logger.info(( 50 | 'Text encoder created: ' 51 | f'{txt_enc.name}' 52 | )) 53 | 54 | sim_obj = get_similarity_object( 55 | similarity.name, 56 | **similarity.params 57 | ) 58 | 59 | self.similarity = Similarity( 60 | similarity_object=sim_obj, 61 | device=similarity.device, 62 | latent_size=latent_size, 63 | **kwargs 64 | ) 65 | 66 | self.ml_similarity = nn.Identity() 67 | if ml_similarity is not None: 68 | self.ml_similarity = self.similarity 69 | 70 | if ml_similarity != {}: 71 | ml_sim_obj = get_similarity_object( 72 | ml_similarity.name, 73 | **ml_similarity.params 74 | ) 75 | 76 | self.ml_similarity = Similarity( 77 | similarity_object=ml_sim_obj, 78 | device=similarity.device, 79 | latent_size=latent_size, 80 | **kwargs 81 | ) 82 | 83 | self.set_devices_() 84 | 85 | logger.info(f'Using similarity: {similarity.name,}') 86 | 87 | def set_devices_(self, txt_devices=['cuda'], img_devices=['cuda'], loss_device='cuda'): 88 | if len(txt_devices) > 1: 89 | self.txt_enc = data_parallel.DataParallel(self.txt_enc) 90 | self.txt_enc.device = torch.device('cuda') 91 | elif len(txt_devices) == 1: 92 | self.txt_enc.to(txt_devices[0]) 93 | self.txt_enc.device = torch.device(txt_devices[0]) 94 | 95 | if len(img_devices) > 1: 96 | 
            self.img_enc = data_parallel.DataParallel(self.img_enc)  # wrap the image encoder for multi-GPU
97 |             self.img_enc.device = torch.device('cuda')
98 |         elif len(img_devices) == 1:
99 |             self.img_enc.to(img_devices[0])
100 |             self.img_enc.device = torch.device(img_devices[0])
101 |
102 |         self.loss_device = torch.device(loss_device)
103 |
104 |         self.similarity = self.similarity.to(self.loss_device)
105 |         self.ml_similarity = self.ml_similarity.to(self.loss_device)
106 |
107 |         logger.info((
108 |             f'Setting devices: '
109 |             f'img: {self.img_enc.device}, '
110 |             f'txt: {self.txt_enc.device}, '
111 |             f'loss: {self.loss_device}'
112 |         ))
113 |
114 |     def embed_caption_features(self, cap_features, lengths):
115 |         return self.txt_pool(cap_features, lengths)
116 |
117 |     def embed_image_features(self, img_features):
118 |         return self.img_pool(img_features)
119 |
120 |     def embed_images(self, batch):
121 |         img_tensor = self.img_enc(batch)
122 |         img_embed = self.embed_image_features(img_tensor)
123 |         return img_embed
124 |
125 |     def embed_captions(self, batch):
126 |         txt_tensor, lengths = self.txt_enc(batch)
127 |         txt_embed = self.embed_caption_features(txt_tensor, lengths)
128 |         return txt_embed
129 |
130 |     def forward_batch(self, batch):
131 |         img_embed = self.embed_images(batch)
132 |         txt_embed = self.embed_captions(batch)
133 |         return img_embed, txt_embed
134 |
135 |     # def forward(self, images, captions, lengths):
136 |     #     img_embed = self.embed_images(images)
137 |     #     txt_embed = self.embed_captions(captions, lengths)
138 |     #     return img_embed, txt_embed
139 |
140 |     def get_sim_matrix(self, embed_a, embed_b, lens=None):
141 |         return self.similarity(embed_a, embed_b, lens)
142 |
143 |     def get_ml_sim_matrix(self, embed_a, embed_b, lens=None):
144 |         return self.ml_similarity(embed_a, embed_b, lens)
145 |
146 |     def get_sim_matrix_shared(self, embed_a, embed_b, lens=None, shared_size=128):
147 |         return self.similarity.forward_shared(
148 |             embed_a, embed_b, lens,
149 |             shared_size=shared_size
150 |         )
151 |
-------------------------------------------------------------------------------- /retrieval/model/similarity/__init__.py: --------------------------------------------------------------------------------
1 | from . import similarity
2 | from . import measure
3 | from . import factory
4 |
-------------------------------------------------------------------------------- /retrieval/model/similarity/factory.py: --------------------------------------------------------------------------------
1 | # TODO: improve this
2 | from .
import similarity as sim 3 | from addict import Dict 4 | 5 | 6 | _similarities = { 7 | 'cosine': { 8 | 'class': sim.Cosine, 9 | }, 10 | 'adapt_t2i': { 11 | 'class': sim.AdaptiveEmbeddingT2I, 12 | }, 13 | 'adapt_i2t': { 14 | 'class': sim.AdaptiveEmbeddingI2T, 15 | }, 16 | 'scan_i2t': { 17 | 'class': sim.StackedAttention, 18 | }, 19 | 'scan_t2i': { 20 | 'class': sim.StackedAttention, 21 | }, 22 | 'order': None, 23 | } 24 | 25 | def get_similarity_object(name, **kwargs): 26 | settings = _similarities[name] 27 | return settings['class'](**kwargs) 28 | 29 | 30 | def get_sim_names(): 31 | return _similarities.keys() 32 | -------------------------------------------------------------------------------- /retrieval/model/similarity/measure.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def l1norm(X, dim, eps=1e-8): 5 | """L1-normalize columns of X 6 | """ 7 | norm = torch.abs(X).sum(dim=dim, keepdim=True) + eps 8 | X = torch.div(X, norm) 9 | return X 10 | 11 | 12 | def l2norm(X, dim, eps=1e-8): 13 | """L2-normalize columns of X 14 | """ 15 | norm = torch.pow(X, 2).sum( 16 | dim=dim, keepdim=True 17 | ).sqrt() + eps 18 | X = torch.div(X, norm) 19 | return X 20 | 21 | 22 | def cosine_sim(im, s,): 23 | """ 24 | Cosine similarity between all the 25 | image and sentence pairs 26 | """ 27 | return im.mm(s.t()) 28 | 29 | 30 | def cosine_sim_numpy(im, s): 31 | """ 32 | Cosine similarity between all the 33 | image and sentence pairs 34 | """ 35 | return im.dot(s.T) 36 | 37 | -------------------------------------------------------------------------------- /retrieval/model/similarity/similarity.py: -------------------------------------------------------------------------------- 1 | from timeit import default_timer as dt 2 | 3 | import numpy as np 4 | import torch 5 | from addict import Dict 6 | from torch import nn 7 | from torch.nn import _VF 8 | from torch.nn import functional as F 9 | from tqdm import tqdm 10 | 11 | from ...utils import helper 12 | from ...utils.logger import get_logger 13 | from .. 
import txtenc 14 | from ..layers import attention, adapt 15 | from ..txtenc.pooling import mean_pooling 16 | from ..txtenc import pooling 17 | from ..txtenc import factory 18 | from .measure import cosine_sim, l2norm 19 | 20 | logger = get_logger() 21 | 22 | 23 | class Similarity(nn.Module): 24 | 25 | def __init__(self, device, similarity_object, **kwargs): 26 | super().__init__() 27 | self.device = device 28 | self.similarity = similarity_object 29 | # self.similarity = factory.get_similarity_object(similarity_name, device=device, **kwargs) 30 | logger.info(f'Created similarity: {similarity_object}') 31 | self.set_master_() 32 | 33 | def set_master_(self, is_master=True): 34 | self.master = is_master 35 | 36 | def forward(self, img_embed, cap_embed, lens, shared=False): 37 | logger.debug(( 38 | f'Similarity - img_shape: {img_embed.shape} ' 39 | 'cap_shape: {cap_embed.shape}' 40 | )) 41 | 42 | return self.similarity(img_embed, cap_embed, lens) 43 | 44 | def forward_shared(self, img_embed, cap_embed, lens, shared_size=128): 45 | """ 46 | Compute pairwise i2t image-caption distance with locality sharding 47 | """ 48 | 49 | #img_embed = img_embed.to(self.device) 50 | #cap_embed = cap_embed.to(self.device) 51 | 52 | n_im_shard = (len(img_embed)-1)//shared_size + 1 53 | n_cap_shard = (len(cap_embed)-1)//shared_size + 1 54 | 55 | logger.debug('Calculating shared similarities') 56 | 57 | pbar_fn = lambda x: range(x) 58 | if self.master: 59 | pbar_fn = lambda x: tqdm( 60 | range(x), total=x, 61 | desc='Test ', 62 | leave=False, 63 | ) 64 | 65 | d = torch.zeros(len(img_embed), len(cap_embed)).cpu() 66 | for i in pbar_fn(n_im_shard): 67 | im_start = shared_size*i 68 | im_end = min(shared_size*(i+1), len(img_embed)) 69 | for j in range(n_cap_shard): 70 | cap_start = shared_size*j 71 | cap_end = min(shared_size*(j+1), len(cap_embed)) 72 | im = img_embed[im_start:im_end] 73 | s = cap_embed[cap_start:cap_end] 74 | l = lens[cap_start:cap_end] 75 | sim = self.forward(im, s, l) 76 | d[im_start:im_end, cap_start:cap_end] = sim 77 | 78 | logger.debug('Done computing shared similarities.') 79 | return d 80 | 81 | 82 | class Cosine(nn.Module): 83 | 84 | def __init__(self, device, latent_size=1024): 85 | super().__init__() 86 | self.device = device 87 | 88 | def forward(self, img_embed, cap_embed, *args, **kwargs): 89 | img_embed = img_embed.to(self.device) 90 | cap_embed = cap_embed.to(self.device) 91 | 92 | img_embed = l2norm(img_embed, dim=1) 93 | cap_embed = l2norm(cap_embed, dim=1) 94 | 95 | return cosine_sim(img_embed, cap_embed)#.cpu() 96 | 97 | 98 | class Fovea(nn.Module): 99 | 100 | def __init__(self, smooth=10, train_smooth=False): 101 | super().__init__() 102 | 103 | self.smooth = smooth 104 | self.train_smooth = train_smooth 105 | self.softmax = nn.Softmax(dim=-1) 106 | 107 | if train_smooth: 108 | self.smooth = nn.Parameter(torch.zeros(1) + self.smooth) 109 | 110 | def forward(self, x): 111 | ''' 112 | x: [batch_size, features, k] 113 | ''' 114 | mask = self.softmax(x * self.smooth) 115 | output = mask * x 116 | return output 117 | 118 | def __repr__(self): 119 | return ( 120 | f'Fovea(smooth={self.smooth},' 121 | f'train_smooth: {self.train_smooth})' 122 | ) 123 | 124 | 125 | class Normalization(nn.Module): 126 | 127 | def __init__(self, latent_size, norm_method=None): 128 | super().__init__() 129 | if norm_method is None: 130 | self.norm = nn.Identity() 131 | elif norm_method == 'batchnorm': 132 | self.norm = nn.BatchNorm1d(latent_size, affine=False) 133 | elif norm_method == 'instancenorm': 134 
| self.norm = nn.InstanceNorm1d(latent_size, affine=False) 135 | 136 | def forward(self, x): 137 | return self.norm(x) 138 | 139 | # FIND 140 | class AdaptiveEmbeddingT2I(nn.Module): 141 | 142 | def __init__( 143 | self, device, latent_size=1024, k=1, 144 | gamma=10, train_gamma=False, clip_embeddings=True, 145 | normalization='batchnorm', use_fovea=True 146 | ): 147 | super().__init__() 148 | 149 | self.device = device 150 | 151 | self.norm = Normalization(latent_size, normalization) 152 | 153 | self.adapt_img = adapt.ADAPT(latent_size, k) 154 | 155 | self.fovea = nn.Identity() 156 | if use_fovea: 157 | self.fovea = Fovea(smooth=gamma, train_smooth=train_gamma) 158 | 159 | def forward(self, img_embed, cap_embed, lens, **kwargs): 160 | ''' 161 | img_embed: (B, 36, latent_size) 162 | cap_embed: (B, T, latent_size) 163 | lens (List[int]): (B) 164 | ''' 165 | # (B, 1024, T) 166 | cap_embed = cap_embed.permute(0, 2, 1) 167 | img_embed = img_embed.permute(0, 2, 1) 168 | 169 | img_embed = self.norm(img_embed) 170 | # cap_embed = self.norm(cap_embed) 171 | 172 | sims = torch.zeros( 173 | img_embed.shape[0], cap_embed.shape[0] 174 | ).to(self.device) 175 | 176 | for i, cap_tensor in enumerate(cap_embed): 177 | # cap_tensor: 1, 1024, T 178 | # img_embed : B, 1024, 36 179 | n_words = lens[i] 180 | 181 | # Global textual representation 182 | # cap_vector: 1, 1024 183 | cap_repr = cap_tensor[:,:n_words].mean(-1).unsqueeze(0) 184 | 185 | img_output = self.adapt_img(img_embed, cap_repr) 186 | img_output = self.fovea(img_output) 187 | # Filtered global representation of the images 188 | img_vector = img_output.mean(-1) 189 | 190 | img_vector = l2norm(img_vector, dim=-1) 191 | cap_vector = l2norm(cap_repr, dim=-1) 192 | 193 | # sim = cosine_sim(img_vector, cap_vector) 194 | sim = cosine_sim(img_vector, cap_vector).squeeze(-1) 195 | 196 | # sim = sim.squeeze(-1) 197 | sims[:,i] = sim 198 | 199 | return sims 200 | 201 | 202 | class AdaptiveEmbeddingI2T(nn.Module): 203 | 204 | def __init__( 205 | self, device, latent_size=1024, k=1, 206 | gamma=1, train_gamma=False, 207 | normalization='batchnorm', use_fovea=True 208 | ): 209 | super().__init__() 210 | 211 | self.device = device 212 | 213 | if normalization: 214 | self.norm = Normalization(latent_size, normalization) 215 | 216 | self.adapt_txt = adapt.ADAPT(latent_size, k) 217 | 218 | if use_fovea: 219 | self.fovea = Fovea(smooth=gamma, train_smooth=train_gamma) 220 | else: 221 | self.fovea = nn.Identity() 222 | 223 | def forward(self, img_embed, cap_embed, lens, **kwargs): 224 | ''' 225 | img_embed: (B, 36, latent_size) 226 | cap_embed: (B, T, latent_size) 227 | ''' 228 | # (B, 1024, T) 229 | # 230 | cap_embed = cap_embed.permute(0, 2, 1)[...,:34] 231 | img_embed = img_embed.permute(0, 2, 1) 232 | 233 | cap_embed = self.norm(cap_embed) 234 | 235 | sims = torch.zeros( 236 | img_embed.shape[0], cap_embed.shape[0] 237 | ).to(self.device) 238 | 239 | # Global image representation 240 | img_embed = img_embed.mean(-1) 241 | 242 | for i, img_tensor in enumerate(img_embed): 243 | # cap_tensor : B, 1024, T 244 | # image_embed: 1, 1024 245 | 246 | img_vector = img_tensor.unsqueeze(0) 247 | txt_output = self.adapt_txt(value=cap_embed, query=img_vector) 248 | txt_output = self.fovea(txt_output) 249 | 250 | txt_vector = txt_output.max(dim=-1)[0] 251 | 252 | txt_vector = l2norm(txt_vector, dim=-1) 253 | img_vector = l2norm(img_vector, dim=-1) 254 | sim = cosine_sim(img_vector, txt_vector) 255 | sim = sim.squeeze(-1) 256 | sims[i,:] = sim 257 | 258 | return sims 259 | 260 
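
# Illustrative sketch (not part of the original similarity.py): the adaptive
# similarity modules above all follow the same contract: image features of shape
# (n_images, n_regions, latent_size), caption features of shape
# (n_captions, max_len, latent_size) plus a list of caption lengths, returning a
# similarity matrix of shape (n_images, n_captions). A minimal stand-in with plain
# mean pooling and cosine similarity, only to make the shapes concrete
# (toy_similarity and its random inputs are hypothetical):
import torch
import torch.nn.functional as F

def toy_similarity(img_embed, cap_embed, lens):
    # img_embed: (n_img, n_regions, d); cap_embed: (n_cap, T, d); lens: List[int]
    img_vec = F.normalize(img_embed.mean(1), dim=-1)
    cap_vec = torch.stack([c[:l].mean(0) for c, l in zip(cap_embed, lens)])
    cap_vec = F.normalize(cap_vec, dim=-1)
    return img_vec @ cap_vec.t()  # (n_img, n_cap)

sims = toy_similarity(torch.randn(4, 36, 1024), torch.randn(6, 20, 1024), [20] * 6)
assert sims.shape == (4, 6)
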
| 261 | class LogSumExp(nn.Module): 262 | def __init__(self, lambda_lse): 263 | self.lambda_lse = lambda_lse 264 | 265 | def forward(self, x): 266 | x.mul_(self.lambda_lse).exp_() 267 | x = x.sum(dim=1, keepdim=True) 268 | x = torch.log(x)/self.lambda_lse 269 | return x 270 | 271 | 272 | class ClippedL2Norm(nn.Module): 273 | def __init__(self, ): 274 | super().__init__() 275 | self.leaky = nn.LeakyReLU(0.1) 276 | 277 | def forward(self, x): 278 | return l2norm(self.leaky(x), 2) 279 | 280 | 281 | class StackedAttention(nn.Module): 282 | 283 | def __init__( 284 | self, i2t=True, agg_function='Mean', 285 | feature_norm='softmax', lambda_lse=None, 286 | smooth=4, **kwargs, 287 | ): 288 | super().__init__() 289 | self.i2t = i2t 290 | self.lambda_lse = lambda_lse 291 | self.agg_function = agg_function 292 | self.feature_norm = feature_norm 293 | self.lambda_lse = lambda_lse 294 | self.smooth = smooth 295 | self.kwargs = kwargs 296 | 297 | self.attention = Attention( 298 | smooth=smooth, feature_norm=feature_norm, 299 | ) 300 | 301 | if agg_function == 'LogSumExp': 302 | self.aggregate_function = LogSumExp(lambda_lse) 303 | elif agg_function == 'Max': 304 | self.aggregate_function = lambda x: x.max(dim=1, keepdim=True)[0] 305 | elif agg_function == 'Sum': 306 | self.aggregate_function = lambda x: x.sum(dim=1, keepdim=True) 307 | elif agg_function == 'Mean': 308 | self.aggregate_function = lambda x: x.mean(dim=1, keepdim=True) 309 | else: 310 | raise ValueError("unknown aggfunc: {}".format(agg_function)) 311 | 312 | self.task = 'i2t' if i2t else 't2i' 313 | 314 | def forward(self, images, captions, cap_lens, ): 315 | """ 316 | Images: (n_image, n_regions, d) matrix of images 317 | Captions: (n_caption, max_n_word, d) matrix of captions 318 | CapLens: (n_caption) array of caption lengths 319 | """ 320 | similarities = [] 321 | n_image = images.size(0) 322 | n_caption = captions.size(0) 323 | 324 | for i in range(n_caption): 325 | # Get the i-th text description 326 | n_word = cap_lens[i] 327 | cap_i = captions[i, :n_word, :].unsqueeze(0).contiguous() 328 | # --> (n_image, n_word, d) 329 | cap_i_expand = cap_i.repeat(n_image, 1, 1) 330 | """ 331 | word(query): (n_image, n_word, d) 332 | image(context): (n_image, n_regions, d) 333 | weiContext: (n_image, n_word, d) or (n_image, n_region, d) 334 | attn: (n_image, n_region, n_word) 335 | """ 336 | emb_a = cap_i_expand 337 | emb_b = images 338 | if self.i2t: 339 | emb_a = images 340 | emb_b = cap_i_expand 341 | 342 | weiContext, attn = self.attention(emb_a, emb_b) 343 | emb_a = emb_a.contiguous() 344 | weiContext = weiContext.contiguous() 345 | # (n_image, n_word) 346 | row_sim = cosine_similarity(emb_a, weiContext, dim=2) 347 | row_sim = self.aggregate_function(row_sim) 348 | similarities.append(row_sim) 349 | 350 | # (n_image, n_caption) 351 | similarities = torch.cat(similarities, 1) 352 | 353 | return similarities 354 | 355 | def __repr__(self, ): 356 | return ( 357 | f'StackedAttention(task: {self.task},' 358 | f'i2t: {self.i2t}, ' 359 | f'attention: {self.attention}, ' 360 | f'lambda_lse: {self.lambda_lse}, ' 361 | f'agg_function: {self.agg_function}, ' 362 | f'feature_norm: {self.feature_norm}, ' 363 | f'lambda_lse: {self.lambda_lse}, ' 364 | f'smooth: {self.smooth}, ' 365 | f'kwargs: {self.kwargs})' 366 | ) 367 | 368 | 369 | 370 | def attn_softmax(attn): 371 | batch_size, sourceL, queryL = attn.shape 372 | attn = attn.view(batch_size*sourceL, queryL) 373 | attn = nn.Softmax(dim=-1)(attn) 374 | # --> (batch, sourceL, queryL) 375 | attn = 
attn.view(batch_size, sourceL, queryL) 376 | return attn 377 | 378 | 379 | class Attention(nn.Module): 380 | 381 | def __init__(self, smooth, feature_norm='softmax'): 382 | super().__init__() 383 | self.smooth = smooth 384 | self.feature_norm = feature_norm 385 | 386 | if feature_norm == "softmax": 387 | self.normalize_attn = attn_softmax 388 | # elif feature_norm == "l2norm": 389 | # attn = lambda x: l2norm(x, 2) 390 | elif feature_norm == "clipped_l2norm": 391 | self.normalize_attn = ClippedL2Norm() 392 | # elif feature_norm == "l1norm": 393 | # attn = l1norm_d(attn, 2) 394 | # elif feature_norm == "clipped_l1norm": 395 | # attn = nn.LeakyReLU(0.1)(attn) 396 | # attn = l1norm_d(attn, 2) 397 | elif feature_norm == "clipped": 398 | self.normalize_attn = lambda x: nn.LeakyReLU(0.1)(x) 399 | elif feature_norm == "no_norm": 400 | self.normalize_attn = lambda x: x 401 | else: 402 | raise ValueError("unknown first norm type:", feature_norm) 403 | 404 | def forward(self, query, context, ): 405 | batch_size_q, queryL = query.size(0), query.size(1) 406 | batch_size, sourceL = context.size(0), context.size(1) 407 | 408 | # Get attention 409 | # --> (batch, d, queryL) 410 | queryT = torch.transpose(query, 1, 2) 411 | 412 | # (batch, sourceL, d)(batch, d, queryL) 413 | # --> (batch, sourceL, queryL) 414 | attn = torch.bmm(context, queryT) 415 | attn = self.normalize_attn(attn) 416 | # --> (batch, queryL, sourceL) 417 | attn = torch.transpose(attn, 1, 2).contiguous() 418 | # --> (batch*queryL, sourceL) 419 | attn = attn.view(batch_size*queryL, sourceL) 420 | attn = nn.Softmax(dim=-1)(attn*self.smooth) 421 | # --> (batch, queryL, sourceL) 422 | attn = attn.view(batch_size, queryL, sourceL) 423 | # --> (batch, sourceL, queryL) 424 | attnT = torch.transpose(attn, 1, 2).contiguous() 425 | 426 | # --> (batch, d, sourceL) 427 | contextT = torch.transpose(context, 1, 2) 428 | # (batch x d x sourceL)(batch x sourceL x queryL) 429 | # --> (batch, d, queryL) 430 | weightedContext = torch.bmm(contextT, attnT) 431 | # --> (batch, queryL, d) 432 | weightedContext = torch.transpose(weightedContext, 1, 2) 433 | 434 | return weightedContext, attnT 435 | 436 | 437 | def cosine_similarity(x1, x2, dim=1, eps=1e-8): 438 | """Returns cosine similarity between x1 and x2, computed along dim.""" 439 | w12 = torch.sum(x1 * x2, dim) 440 | w1 = torch.norm(x1, 2, dim) 441 | w2 = torch.norm(x2, 2, dim) 442 | return (w12 / (w1 * w2).clamp(min=eps)).squeeze() 443 | -------------------------------------------------------------------------------- /retrieval/model/txtenc/__init__.py: -------------------------------------------------------------------------------- 1 | from .factory import get_text_encoder, get_available_txtenc, get_txt_pooling 2 | from . 
import embedding 3 | -------------------------------------------------------------------------------- /retrieval/model/txtenc/embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from ..layers import convblocks 4 | from ..layers import attention 5 | 6 | 7 | class PartialConcat(nn.Module): 8 | 9 | def __init__( 10 | self, 11 | num_embeddings, 12 | embed_dim=300, 13 | liwe_char_dim=24, 14 | liwe_neurons=[128, 256], 15 | liwe_dropout=0.0, 16 | liwe_wnorm=True, 17 | liwe_batch_norm=True, 18 | liwe_activation=nn.ReLU(inplace=True), 19 | max_chars = 26, 20 | **kwargs 21 | ): 22 | super(PartialConcat, self).__init__() 23 | 24 | weight_norm = nn.Identity 25 | if liwe_wnorm: 26 | from torch.nn.utils import weight_norm 27 | 28 | self.embed = nn.Embedding(num_embeddings, liwe_char_dim) 29 | 30 | self.total_embed_size = liwe_char_dim * max_chars 31 | 32 | layers = [] 33 | liwe_neurons = liwe_neurons + [embed_dim] 34 | in_sizes = [liwe_char_dim * max_chars] + liwe_neurons 35 | 36 | batch_norm = nn.BatchNorm1d 37 | if not liwe_batch_norm: 38 | batch_norm = nn.Identity 39 | 40 | for n, i in zip(liwe_neurons, in_sizes): 41 | layer = nn.Sequential(*[ 42 | weight_norm( 43 | nn.Conv1d(i, n, 1) 44 | ), 45 | nn.Dropout(liwe_dropout), 46 | batch_norm(n), 47 | liwe_activation, 48 | ]) 49 | layers.append(layer) 50 | 51 | self.layers = nn.Sequential(*layers) 52 | 53 | def forward_embed(self, x): 54 | 55 | partial_words = x.view(self.B, -1) # (B, W*Ct) 56 | char_embed = self.embed(partial_words) # (B, W*Ct, Cw) 57 | # char_embed = l2norm(char_embed, 2) 58 | char_embed = char_embed.view(self.B, self.W, -1) 59 | # a, b, c = char_embed.shape 60 | # left = self.total_embed_size - c 61 | # char_embed = nn.ReplicationPad1d(left//2)(char_embed) 62 | char_embed = char_embed.permute(0, 2, 1) 63 | word_embed = self.layers(char_embed) 64 | return word_embed 65 | 66 | def forward(self, x): 67 | ''' 68 | x: (batch, nb_words, nb_characters [tokens]) 69 | ''' 70 | self.B, self.W, self.Ct = x.size() 71 | return self.forward_embed(x) 72 | 73 | 74 | class PartialConcat(nn.Module): 75 | 76 | def __init__( 77 | self, 78 | num_embeddings, 79 | embed_dim=300, 80 | liwe_char_dim=24, 81 | liwe_neurons=[128, 256], 82 | liwe_dropout=0.0, 83 | liwe_wnorm=True, 84 | liwe_batch_norm=True, 85 | liwe_activation=nn.ReLU(inplace=True), 86 | max_chars=26, 87 | **kwargs 88 | ): 89 | super(PartialConcat, self).__init__() 90 | 91 | if type(liwe_activation) == str: 92 | liwe_activation = eval(liwe_activation) 93 | 94 | weight_norm = nn.Identity 95 | if liwe_wnorm: 96 | from torch.nn.utils import weight_norm 97 | 98 | self.embed = nn.Embedding(num_embeddings, liwe_char_dim) 99 | 100 | self.attn = attention.SelfAttention( 101 | liwe_char_dim * max_chars, 102 | activation=liwe_activation, 103 | k=4, 104 | ) 105 | 106 | self.total_embed_size = liwe_char_dim * max_chars 107 | 108 | layers = [] 109 | liwe_neurons = liwe_neurons + [embed_dim] 110 | in_sizes = [liwe_char_dim * max_chars] + liwe_neurons 111 | 112 | batch_norm = nn.BatchNorm1d 113 | if not liwe_batch_norm: 114 | batch_norm = nn.Identity 115 | 116 | for n, i in zip(liwe_neurons, in_sizes): 117 | layer = nn.Sequential(*[ 118 | weight_norm( 119 | nn.Conv1d(i, n, 1) 120 | ), 121 | nn.Dropout(liwe_dropout), 122 | batch_norm(n), 123 | liwe_activation, 124 | ]) 125 | layers.append(layer) 126 | 127 | self.layers = nn.Sequential(*layers) 128 | 129 | def forward_embed(self, x): 130 | 131 | partial_words = 
x.view(self.B, -1) # (B, W*Ct) 132 | char_embed = self.embed(partial_words) # (B, W*Ct, Cw) 133 | # char_embed = l2norm(char_embed, 2) 134 | char_embed = char_embed.view(self.B, self.W, -1) 135 | # a, b, c = char_embed.shape 136 | # left = self.total_embed_size - c 137 | # char_embed = nn.ReplicationPad1d(left//2)(char_embed) 138 | char_embed = char_embed.permute(0, 2, 1) 139 | word_embed = self.layers(char_embed) 140 | return word_embed 141 | 142 | def forward(self, x): 143 | ''' 144 | x: (batch, nb_words, nb_characters [tokens]) 145 | ''' 146 | self.B, self.W, self.Ct = x.size() 147 | return self.forward_embed(x) 148 | 149 | 150 | class PartialGRUs(nn.Module): 151 | 152 | def __init__( 153 | self, 154 | num_embeddings, 155 | embed_dim=300, 156 | liwe_char_dim=24, 157 | **kwargs 158 | ): 159 | super().__init__() 160 | 161 | self.embed = nn.Embedding(num_embeddings, liwe_char_dim) 162 | 163 | self.rnn = nn.GRU(liwe_char_dim, embed_dim, 1, 164 | batch_first=True, bidirectional=False) 165 | 166 | def forward_embed(self, x): 167 | 168 | partial_words = x.view(self.B, -1) # (B, W*Ct) 169 | char_embed = self.embed(partial_words) # (B, W*Ct, Cw) 170 | char_embed = char_embed.view(self.B*self.W, self.Ct, -1) 171 | x, _ = self.rnn(char_embed) 172 | 173 | b, t, d = x.shape 174 | # x = x.view(b, t, 2, d//2).mean(-2) 175 | x = x.max(1)[0] 176 | 177 | return x 178 | 179 | def forward(self, x): 180 | ''' 181 | x: (batch, nb_words, nb_characters [tokens]) 182 | ''' 183 | x = x[:,:,:30] 184 | self.B, self.W, self.Ct = x.size() 185 | embed_word = self.forward_embed(x) 186 | embed_word = embed_word.view(self.B, self.W, -1) 187 | 188 | return embed_word.permute(0, 2, 1) 189 | 190 | 191 | class PartialConcatScale(nn.Module): 192 | 193 | def __init__( 194 | self, 195 | num_embeddings, 196 | embed_dim=300, 197 | liwe_char_dim=24, 198 | liwe_neurons=[128, 256], 199 | liwe_dropout=0.0, 200 | liwe_wnorm=True, 201 | max_chars = 26, 202 | liwe_activation=nn.ReLU(), 203 | liwe_batch_norm=True, 204 | ): 205 | super(PartialConcatScale, self).__init__() 206 | 207 | if type(liwe_activation) == str: 208 | liwe_activation = eval(liwe_activation) 209 | 210 | if not liwe_wnorm: 211 | weight_norm = nn.Identity 212 | else: 213 | from torch.nn.utils import weight_norm 214 | 215 | self.embed = nn.Embedding(num_embeddings, liwe_char_dim) 216 | self.embed_dim = embed_dim 217 | self.max_chars = max_chars 218 | self.total_embed_size = liwe_char_dim * max_chars 219 | 220 | layers = [] 221 | liwe_neurons = liwe_neurons + [embed_dim] 222 | in_sizes = [liwe_char_dim * max_chars] + liwe_neurons 223 | 224 | batch_norm = nn.BatchNorm1d 225 | if not liwe_batch_norm: 226 | batch_norm = nn.Identity 227 | 228 | for n, i in zip(liwe_neurons, in_sizes): 229 | layer = nn.Sequential(*[ 230 | weight_norm( 231 | nn.Conv1d(i, n, 1) 232 | ), 233 | nn.Dropout(liwe_dropout), 234 | batch_norm(n), 235 | liwe_activation, 236 | ]) 237 | layers.append(layer) 238 | 239 | self.scale = nn.Parameter(torch.ones(1)) 240 | 241 | self.layers = nn.Sequential(*layers) 242 | 243 | def forward_embed(self, x): 244 | positive_mask = (x == 0) 245 | words_length = positive_mask.sum(-1) 246 | words_length = words_length.view(self.B, -1, 1, 1) 247 | words_length = words_length.float() 248 | 249 | partial_words = x.view(self.B, -1) # (B, W*Ct) 250 | char_embed = self.embed(partial_words) # (B, W*Ct, Cw) 251 | 252 | mask = (partial_words == 0) 253 | 254 | char_embed[mask] = 0 255 | # char_embed = l2norm(char_embed, 2) 256 | char_embed = char_embed.view(self.B, self.W, -1) 257 
| 258 | char_embed_scale = char_embed.view( 259 | self.B, self.W, self.max_chars, -1 260 | ) 261 | 262 | char_embed_scale = char_embed_scale * (torch.sqrt(words_length) * self.scale) 263 | char_embed_scale = char_embed_scale.view(self.B, self.W, -1) 264 | 265 | # a, b, c = char_embed.shape 266 | # left = self.total_embed_size - c 267 | # char_embed = nn.ReplicationPad1d(left//2)(char_embed) 268 | char_embed_scale = char_embed_scale.permute(0, 2, 1) 269 | word_embed_scaled = self.layers(char_embed_scale) 270 | 271 | # char_embed = char_embed.permute(0, 2, 1) 272 | # word_embed_nonscaled = self.layers(char_embed) 273 | 274 | # sample = word_embed_nonscaled[0][:20] 275 | 276 | # print(sample.mean(-1)) 277 | # print(sample.std(-1)) 278 | # print('\n') 279 | # sample = word_embed_scaled[0][:20] 280 | # print(sample.mean(-1)) 281 | # print(sample.std(-1)) 282 | # exit() 283 | return word_embed_scaled 284 | 285 | def forward(self, x): 286 | ''' 287 | x: (batch, nb_words, nb_characters [tokens]) 288 | ''' 289 | self.B, self.W, self.Ct = x.size() 290 | return self.forward_embed(x) 291 | 292 | 293 | class PartialGRUProj(nn.Module): 294 | 295 | def __init__( 296 | self, 297 | num_embeddings, 298 | hidden_size=384, 299 | embed_dim=300, 300 | liwe_char_dim=24, 301 | **kwargs 302 | ): 303 | super().__init__() 304 | 305 | self.embed = nn.Embedding(num_embeddings, liwe_char_dim) 306 | 307 | self.rnn = nn.GRU(liwe_char_dim, hidden_size, 1, 308 | batch_first=True, bidirectional=False) 309 | 310 | self.fc = nn.Linear(hidden_size, embed_dim) 311 | 312 | def forward_embed(self, x): 313 | 314 | partial_words = x.view(self.B, -1) # (B, W*Ct) 315 | char_embed = self.embed(partial_words) # (B, W*Ct, Cw) 316 | char_embed = char_embed.view(self.B*self.W, self.Ct, -1) 317 | x, _ = self.rnn(char_embed) 318 | 319 | b, t, d = x.shape 320 | # x = x.view(b, t, 2, d//2).mean(-2) 321 | x = x.max(1)[0] 322 | 323 | x = self.fc(x) 324 | 325 | return x 326 | 327 | def forward(self, x): 328 | ''' 329 | x: (batch, nb_words, nb_characters [tokens]) 330 | ''' 331 | x = x[:,:,:30] 332 | self.B, self.W, self.Ct = x.size() 333 | embed_word = self.forward_embed(x) 334 | embed_word = embed_word.view(self.B, self.W, -1) 335 | 336 | return embed_word.permute(0, 2, 1) 337 | 338 | 339 | class GloveEmb(nn.Module): 340 | 341 | def __init__( 342 | self, 343 | num_embeddings, 344 | glove_dim, 345 | glove_path, 346 | add_rand_embed=False, 347 | rand_dim=None, 348 | **kwargs 349 | ): 350 | super().__init__() 351 | 352 | self.num_embeddings = num_embeddings 353 | self.add_rand_embed = add_rand_embed 354 | self.glove_dim = glove_dim 355 | self.final_word_emb = glove_dim 356 | 357 | # word embedding 358 | self.glove = nn.Embedding(num_embeddings, glove_dim) 359 | glove = nn.Parameter(torch.load(glove_path)) 360 | self.glove.weight = glove 361 | self.glove.requires_grad = False 362 | 363 | if add_rand_embed: 364 | self.embed = nn.Embedding(num_embeddings, rand_dim) 365 | self.final_word_emb = glove_dim + rand_dim 366 | 367 | def get_word_embed_size(self,): 368 | return self.final_word_emb 369 | 370 | def forward(self, x): 371 | ''' 372 | x: (batch, nb_words, nb_characters [tokens]) 373 | ''' 374 | emb = self.glove(x) 375 | if self.add_rand_embed: 376 | emb2 = self.embed(x) 377 | emb = torch.cat([emb, emb2], dim=2) 378 | 379 | return emb 380 | -------------------------------------------------------------------------------- /retrieval/model/txtenc/factory.py: -------------------------------------------------------------------------------- 1 | from . 
import txtenc 2 | from . import embedding 3 | from . import pooling 4 | import torch.nn as nn 5 | 6 | import torch 7 | import math 8 | 9 | 10 | __text_encoders__ = { 11 | 'gru': { 12 | 'class': txtenc.RNNEncoder, 13 | 'args': { 14 | 'use_bi_gru': True, 15 | 'rnn_type': nn.GRU, 16 | }, 17 | }, 18 | 'gru_glove': { 19 | 'class': txtenc.GloveRNNEncoder, 20 | 'args': { 21 | }, 22 | }, 23 | # 'scan': { 24 | # 'class': txtenc.EncoderText, 25 | # 'args': { 26 | # 'use_bi_gru': True, 27 | # 'num_layers': 1, 28 | # }, 29 | # }, 30 | } 31 | 32 | 33 | def get_available_txtenc(): 34 | return __text_encoders__.keys() 35 | 36 | 37 | def get_text_encoder(name, tokenizers, **kwargs): 38 | model_class = __text_encoders__[name]['class'] 39 | model = model_class(tokenizers=tokenizers, **kwargs) 40 | return model 41 | 42 | 43 | def get_txt_pooling(pool_name): 44 | 45 | _pooling = { 46 | 'mean': pooling.mean_pooling, 47 | 'max': pooling.max_pooling, 48 | 'lens': pooling.last_hidden_state_pool, 49 | 'none': pooling.none, 50 | } 51 | 52 | return _pooling[pool_name] 53 | -------------------------------------------------------------------------------- /retrieval/model/txtenc/pooling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def mean_pooling(texts, lengths): 5 | out = torch.stack( 6 | [t[:l].mean(0) for t, l in zip(texts, lengths) 7 | ], dim=0) 8 | return out 9 | 10 | 11 | def max_pooling(texts, lengths): 12 | out = torch.stack( 13 | [t[:l].max(0)[0] for t, l in zip(texts, lengths) 14 | ], dim=0) 15 | return out 16 | 17 | 18 | def last_hidden_state_pool(texts, lengths): 19 | out = torch.stack( 20 | [t[l-1] for t, l in zip(texts, lengths) 21 | ], dim=0) 22 | return out 23 | 24 | 25 | def none(x, l): 26 | return x 27 | 28 | 29 | # def last_hidden_state_pool(texts, lengths): 30 | # I = torch.LongTensor(lengths).view(-1, 1, 1) 31 | # I = I.expand(texts.size(0), 1, texts[0].size(1))-1 32 | 33 | # if torch.cuda.is_available(): 34 | # I = I.cuda() 35 | 36 | # out = torch.gather(texts, 1, I).squeeze(1) 37 | -------------------------------------------------------------------------------- /retrieval/model/txtenc/txtenc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from ..similarity.measure import l2norm 5 | from ...utils.layers import default_initializer 6 | 7 | from torch.nn.utils.rnn import pack_padded_sequence 8 | from torch.nn.utils.rnn import pad_packed_sequence 9 | 10 | from ...model.layers import attention, convblocks 11 | from .embedding import PartialConcat, GloveEmb, PartialConcatScale 12 | from . 
import pooling 13 | 14 | import numpy as np 15 | 16 | 17 | # RNN Based Language Model 18 | class GloveRNNEncoder(nn.Module): 19 | 20 | def __init__( 21 | self, tokenizers, embed_dim, latent_size, 22 | num_layers=1, use_bi_gru=True, no_txtnorm=False, 23 | rnn_type=nn.GRU, glove_path=None, add_rand_embed=False): 24 | 25 | super().__init__() 26 | self.latent_size = latent_size 27 | self.no_txtnorm = no_txtnorm 28 | 29 | assert len(tokenizers) == 1 30 | 31 | num_embeddings = len(tokenizers[0]) 32 | 33 | self.embed = GloveEmb( 34 | num_embeddings, 35 | glove_dim=embed_dim, 36 | glove_path=glove_path, 37 | add_rand_embed=add_rand_embed, 38 | rand_dim=embed_dim, 39 | ) 40 | 41 | 42 | # caption embedding 43 | self.use_bi_gru = use_bi_gru 44 | self.rnn = rnn_type( 45 | self.embed.final_word_emb, 46 | latent_size, num_layers, 47 | batch_first=True, 48 | bidirectional=use_bi_gru 49 | ) 50 | 51 | if hasattr(self.embed, 'embed'): 52 | self.embed.embed.weight.data.uniform_(-0.1, 0.1) 53 | 54 | def forward(self, batch): 55 | """Handles variable size captions 56 | """ 57 | captions, lengths = batch['caption'] 58 | captions = captions.to(self.device) 59 | # Embed word ids to vectors 60 | emb = self.embed(captions) 61 | 62 | # Forward propagate RNN 63 | # self.rnn.flatten_parameters() 64 | cap_emb, _ = self.rnn(emb) 65 | 66 | if self.use_bi_gru: 67 | b, t, d = cap_emb.shape 68 | cap_emb = cap_emb.view(b, t, 2, d//2).mean(-2) 69 | 70 | # normalization in the joint embedding space 71 | if not self.no_txtnorm: 72 | cap_emb = l2norm(cap_emb, dim=-1) 73 | 74 | return cap_emb, lengths 75 | 76 | 77 | # RNN Based Language Model 78 | class RNNEncoder(nn.Module): 79 | 80 | def __init__( 81 | self, tokenizers, embed_dim, latent_size, 82 | num_layers=1, use_bi_gru=True, no_txtnorm=False, 83 | rnn_type=nn.GRU): 84 | 85 | super(RNNEncoder, self).__init__() 86 | self.latent_size = latent_size 87 | self.no_txtnorm = no_txtnorm 88 | 89 | assert len(tokenizers) == 1 90 | num_embeddings = len(tokenizers[0]) 91 | 92 | # word embedding 93 | self.embed = nn.Embedding(num_embeddings, embed_dim) 94 | 95 | # caption embedding 96 | self.use_bi_gru = use_bi_gru 97 | self.rnn = rnn_type( 98 | embed_dim, latent_size, num_layers, 99 | batch_first=True, 100 | bidirectional=use_bi_gru 101 | ) 102 | 103 | self.apply(default_initializer) 104 | 105 | def forward(self, batch): 106 | """Handles variable size captions 107 | """ 108 | captions, lengths = batch['caption'] 109 | captions = captions.to(self.device) 110 | 111 | # Embed word ids to vectors 112 | x = self.embed(captions) 113 | # Forward propagate RNN 114 | # self.rnn.flatten_parameters() 115 | cap_emb, _ = self.rnn(x) 116 | 117 | if self.use_bi_gru: 118 | b, t, d = cap_emb.shape 119 | cap_emb = cap_emb.view(b, t, 2, d//2).mean(-2) 120 | 121 | # normalization in the joint embedding space 122 | if not self.no_txtnorm: 123 | cap_emb = l2norm(cap_emb, dim=-1) 124 | 125 | return cap_emb, lengths 126 | -------------------------------------------------------------------------------- /retrieval/train/__init__.py: -------------------------------------------------------------------------------- 1 | from . import evaluation 2 | from . import test 3 | from . 
import train 4 | -------------------------------------------------------------------------------- /retrieval/train/evaluation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from tqdm import tqdm 4 | from timeit import default_timer as dt 5 | 6 | from ..utils import layers 7 | 8 | 9 | @torch.no_grad() 10 | def predict_loader(model, data_loader, device): 11 | img_embs, cap_embs, cap_lens = None, None, None 12 | max_n_word = 77 13 | model.eval() 14 | 15 | pbar_fn = lambda x: x 16 | if model.master: 17 | pbar_fn = lambda x: tqdm( 18 | x, total=len(x), 19 | desc='Pred ', 20 | leave=False, 21 | ) 22 | 23 | for batch in pbar_fn(data_loader): 24 | ids = batch['index'] 25 | if len(batch['caption'][0]) == 2: 26 | (_, _), (_, lengths) = batch['caption'] 27 | else: 28 | cap, lengths = batch['caption'] 29 | img_emb, cap_emb = model.forward_batch(batch) 30 | 31 | if img_embs is None: 32 | if len(img_emb.shape) == 3: 33 | is_tensor = True 34 | img_embs = np.zeros((len(data_loader.dataset), img_emb.size(1), img_emb.size(2))) 35 | cap_embs = np.zeros((len(data_loader.dataset), max_n_word, cap_emb.size(2))) 36 | else: 37 | is_tensor = False 38 | img_embs = np.zeros((len(data_loader.dataset), img_emb.size(1))) 39 | cap_embs = np.zeros((len(data_loader.dataset), cap_emb.size(1))) 40 | cap_lens = [0] * len(data_loader.dataset) 41 | # cache embeddings 42 | img_embs[ids] = img_emb.data.cpu().numpy() 43 | if is_tensor: 44 | cap_embs[ids,:max(lengths),:] = cap_emb.data.cpu().numpy() 45 | else: 46 | cap_embs[ids,] = cap_emb.data.cpu().numpy() 47 | 48 | for j, nid in enumerate(ids): 49 | cap_lens[nid] = lengths[j] 50 | 51 | if img_embs.shape[0] == cap_embs.shape[0]: 52 | img_embs = remove_img_feat_redundancy(img_embs, data_loader) 53 | 54 | return img_embs, cap_embs, cap_lens 55 | 56 | def remove_img_feat_redundancy(img_embs, data_loader): 57 | return img_embs[np.arange( 58 | start=0, 59 | stop=img_embs.shape[0], 60 | step=data_loader.dataset.captions_per_image, 61 | ).astype(np.int)] 62 | 63 | @torch.no_grad() 64 | def evaluate( 65 | model, img_emb, txt_emb, lengths, 66 | device, shared_size=128, return_sims=False 67 | ): 68 | model.eval() 69 | _metrics_ = ('r1', 'r5', 'r10', 'medr', 'meanr') 70 | 71 | begin_pred = dt() 72 | 73 | img_emb = torch.FloatTensor(img_emb).to(device) 74 | txt_emb = torch.FloatTensor(txt_emb).to(device) 75 | 76 | end_pred = dt() 77 | sims = model.get_sim_matrix_shared( 78 | embed_a=img_emb, embed_b=txt_emb, 79 | lens=lengths, shared_size=shared_size 80 | ) 81 | sims = layers.tensor_to_numpy(sims) 82 | end_sim = dt() 83 | 84 | i2t_metrics = i2t(sims) 85 | t2i_metrics = t2i(sims) 86 | rsum = np.sum(i2t_metrics[:3]) + np.sum(t2i_metrics[:3]) 87 | 88 | i2t_metrics = {f'i2t_{k}': v for k, v in zip(_metrics_, i2t_metrics)} 89 | t2i_metrics = {f't2i_{k}': v for k, v in zip(_metrics_, t2i_metrics)} 90 | 91 | metrics = { 92 | 'pred_time': end_pred-begin_pred, 93 | 'sim_time': end_sim-end_pred, 94 | } 95 | metrics.update(i2t_metrics) 96 | metrics.update(t2i_metrics) 97 | metrics['rsum'] = rsum 98 | 99 | if return_sims: 100 | return metrics, sims 101 | 102 | return metrics 103 | 104 | 105 | def i2t(sims): 106 | npts, ncaps = sims.shape 107 | captions_per_image = ncaps // npts 108 | 109 | ranks = np.zeros(npts) 110 | top1 = np.zeros(npts) 111 | for index in range(npts): 112 | inds = np.argsort(sims[index])[::-1] 113 | # Score 114 | rank = 1e20 115 | begin = captions_per_image * index 116 | end = captions_per_image * 
index + captions_per_image 117 | for i in range(begin, end, 1): 118 | tmp = np.where(inds == i)[0][0] 119 | if tmp < rank: 120 | rank = tmp 121 | ranks[index] = rank 122 | top1[index] = inds[0] 123 | 124 | # Compute metrics 125 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 126 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 127 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 128 | medr = np.floor(np.median(ranks)) + 1 129 | meanr = ranks.mean() + 1 130 | 131 | return (r1, r5, r10, medr, meanr) 132 | 133 | 134 | def t2i(sims): 135 | npts, ncaps = sims.shape 136 | captions_per_image = ncaps // npts 137 | 138 | ranks = np.zeros(captions_per_image * npts) 139 | top1 = np.zeros(captions_per_image * npts) 140 | 141 | # --> (5N(caption), N(image)) 142 | sims = sims.T 143 | for index in range(npts): 144 | for i in range(captions_per_image): 145 | inds = np.argsort(sims[captions_per_image * index + i])[::-1] 146 | ranks[captions_per_image * index + i] = np.where(inds == index)[0][0] 147 | top1[captions_per_image * index + i] = inds[0] 148 | 149 | # Compute metrics 150 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 151 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 152 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 153 | medr = np.floor(np.median(ranks)) + 1 154 | meanr = ranks.mean() + 1 155 | 156 | return (r1, r5, r10, medr, meanr) 157 | -------------------------------------------------------------------------------- /retrieval/train/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.optim.lr_scheduler 3 | 4 | 5 | _scheduler = { 6 | 'cosine': torch.optim.lr_scheduler.CosineAnnealingLR, 7 | 'step': torch.optim.lr_scheduler.StepLR, 8 | } 9 | 10 | def get_scheduler(name, optimizer, **kwargs): 11 | return _scheduler[name](optimizer=optimizer, **kwargs) 12 | -------------------------------------------------------------------------------- /retrieval/train/optimizers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from ..utils.logger import get_logger 5 | 6 | logger = get_logger() 7 | 8 | 9 | # Inspired from https://github.com/jnhwkim/ban-vqa/blob/master/train.py 10 | class BanOptimizer(): 11 | 12 | def __init__(self, parameters, name='Adamax', lr=0.0007, 13 | gradual_warmup_steps=[0.5, 2.0, 4], lr_decay_epochs=[10, 20, 2], 14 | lr_decay_rate=.25 15 | ): 16 | 17 | logger.info(f'lr {lr}') 18 | logger.info(f'{gradual_warmup_steps}') 19 | logger.info(f'{lr_decay_epochs}') 20 | 21 | self.optimizer = torch.optim.__dict__[name]( 22 | filter(lambda p: p.requires_grad, parameters), 23 | lr=lr 24 | ) 25 | self.lr_decay_rate = lr_decay_rate 26 | self.lr_decay_epochs = eval("range({},{},{})".format(*lr_decay_epochs)) 27 | 28 | self.gradual_warmup_steps = [ 29 | weight * lr for weight in eval("torch.linspace({},{},{})".format( 30 | gradual_warmup_steps[0], 31 | gradual_warmup_steps[1], 32 | int(gradual_warmup_steps[2]) 33 | )) 34 | ] 35 | self.grad_clip = .25 36 | self.total_norm = 0 37 | self.count_norm = 0 38 | self.iteration = 0 39 | 40 | def set_lr(self): 41 | epoch_id = self.iteration 42 | optim = self.optimizer 43 | old_lr = optim.param_groups[0]['lr'] 44 | if epoch_id < len(self.gradual_warmup_steps): 45 | new_lr = self.gradual_warmup_steps[epoch_id] 46 | optim.param_groups[0]['lr'] = new_lr 47 | elif epoch_id in self.lr_decay_epochs: 48 | new_lr = optim.param_groups[0]['lr'] * self.lr_decay_rate 49 
|             optim.param_groups[0]['lr'] = new_lr
50 |
51 |     def display_norm(self):
52 |         logger.info(' norm: {:.5f}'.format(self.total_norm / self.count_norm))
53 |
54 |     def step(self):
55 |         self.iteration += 1
56 |         self.count_norm += 1
57 |         self.optimizer.step()
58 |         self.set_lr()
59 |
60 |     def zero_grad(self):
61 |         self.optimizer.zero_grad()
62 |
63 |     def state_dict(self):
64 |         state = {}
65 |         state['optimizer'] = self.optimizer.state_dict()
66 |         return state
67 |
68 |     def load_state_dict(self, state):
69 |         self.optimizer.load_state_dict(state['optimizer'])
70 |
71 |     def __getattr__(self, key):
72 |         return self.optimizer.__getattribute__(key)  # delegate anything else to the wrapped optimizer
73 |
74 |
75 | _optimizers = {
76 |     'adam': torch.optim.Adam,
77 |     'adamax': BanOptimizer,
78 | }
79 |
80 | def get_optimizer(name, parameters, **kwargs):
81 |     return _optimizers[name](parameters, **kwargs)
82 |
-------------------------------------------------------------------------------- /retrieval/train/test.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jwehrmann/retrieval.pytorch/80569810cd5d84873afe29ab387275f14ab1d677/retrieval/train/test.py -------------------------------------------------------------------------------- /retrieval/train/train.py: --------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import random
4 | import numpy as np
5 | from tqdm import tqdm
6 | from addict import Dict
7 | from pathlib import Path
8 | from timeit import default_timer as dt
9 |
10 | from torch.utils.data import dataset
11 | from torch.nn.utils.clip_grad import clip_grad_norm_
12 |
13 | from . import evaluation
14 | from . import optimizers
15 | from .lr_scheduler import get_scheduler
16 | from ..data.dataiterator import DataIterator
17 | from ..utils import helper, logger, file_utils
18 |
19 | torch.manual_seed(0)
20 | random.seed(0, version=2)
21 |
22 | def freeze(module):
23 |     for x in module.parameters():
24 |         x.requires_grad = False
25 |
26 |
27 | class Trainer:
28 |
29 |     def __init__(
30 |         self, model=None, device=torch.device('cuda'), world_size=1,
31 |         args=None, sysoutlog=tqdm.write, master=True, path='runs/'
32 |     ):
33 |         self.model = model
34 |         self.device = device
35 |         self.train_logger = logger.LogCollector()
36 |         self.val_logger = logger.LogCollector()
37 |         self.args = args
38 |         self.sysoutlog = sysoutlog
39 |         self.optimizer = None
40 |         self.metrics = {}
41 |         self.master = master
42 |         self.world_size = world_size
43 |         self.continue_training = True
44 |         self.path = path
45 |
46 |     def setup_optim(self, optimizer={}, lr=1e-3, lr_scheduler=None, clip_grad=2.,
47 |         log_histograms=False, log_grad_norm=False, early_stop=50, freeze_modules=[],
48 |         **kwargs
49 |     ):
50 |         # TODO: improve this!
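
# Illustrative sketch (not part of the original train.py): the count_params helper
# just below multiplies out each parameter's shape to count weights; an equivalent
# and slightly more direct formulation with plain PyTorch, using a hypothetical
# nn.Linear as a stand-in for the retrieval model:
import torch.nn as nn

toy = nn.Linear(300, 1024)  # hypothetical stand-in
n_trainable = sum(p.numel() for p in toy.parameters() if p.requires_grad)
print(n_trainable)  # 300 * 1024 + 1024 = 308224
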
51 | count_params = lambda p: np.sum([ 52 | np.product(tuple(x.shape)) for x in p 53 | ]) 54 | total_params = count_params(self.model.parameters()) 55 | 56 | for fmod in freeze_modules: 57 | print(f'Freezing {fmod}') 58 | freeze(eval(f'self.{fmod}')) 59 | 60 | trainable_params = [ 61 | x for x in self.model.parameters() 62 | if x.requires_grad 63 | ] 64 | 65 | self.optimizer = optimizers.get_optimizer( 66 | optimizer.name, trainable_params, **optimizer.params, 67 | ) 68 | 69 | scheduler = None 70 | if lr_scheduler.name is not None: 71 | scheduler = get_scheduler( 72 | lr_scheduler.name, self.optimizer, **lr_scheduler.params) 73 | 74 | for k in self.optimizer.param_groups: 75 | self.sysoutlog( 76 | f"lr: {k['lr']}, #layers: {len(k['params'])}, #params: {count_params(k['params']):,}") 77 | 78 | self.sysoutlog(f'Total Params: {total_params:,}, ') 79 | self.initial_lr = lr 80 | self.lr_scheduler = scheduler 81 | self.clip_grad = clip_grad 82 | self.log_histograms = log_histograms 83 | self.log_grad_norm = False 84 | self.save_all = False 85 | self.best_val = 0 86 | self.count = early_stop 87 | self.early_stop = early_stop 88 | 89 | def fit(self, train_loader, valid_loaders, lang_loaders=[], 90 | init_iteration=0, nb_epochs=2000, 91 | log_interval=50, valid_interval=500, world_size=1, 92 | ): 93 | 94 | self.tb_writer = helper.set_tensorboard_logger(self.path) 95 | self.path = self.tb_writer.file_writer.get_logdir() 96 | file_utils.save_yaml_opts( 97 | Path(self.path) / 'options.yaml', self.args 98 | ) 99 | self.best_model_path = Path(self.path) / Path('best_model.pkl') 100 | 101 | self.check_optimizer_setup() 102 | pbar = lambda x: range(x) 103 | if self.master: 104 | pbar = lambda x: tqdm(range(x), desc='Epochs') 105 | 106 | for epoch in pbar(nb_epochs): 107 | self.train_epoch( 108 | train_loader=train_loader, 109 | lang_loaders=lang_loaders, 110 | epoch=epoch, 111 | log_interval=log_interval, 112 | valid_loaders=valid_loaders, 113 | valid_interval=valid_interval, 114 | path=self.path, 115 | ) 116 | if not self.continue_training: 117 | break 118 | 119 | def check_optimizer_setup(self): 120 | if self.optimizer is None: 121 | print('You forgot to setup_optim.') 122 | exit() 123 | 124 | def _forward_multimodal_loss(self, batch): 125 | img_emb, cap_emb = self.model.forward_batch(batch) 126 | lens = batch['caption'][1] 127 | sim_matrix = self.model.get_sim_matrix(img_emb, cap_emb, lens) 128 | loss = self.model.mm_criterion(sim_matrix) 129 | return loss 130 | 131 | def _forward_multilanguage_loss(self, captions_a, lens_a, captions_b, lens_b, *args): 132 | cap_a_embed = self.model.embed_captions({'caption': (captions_a, lens_a)}) 133 | cap_b_embed = self.model.embed_captions({'caption': (captions_b, lens_b)}) 134 | 135 | if len(cap_a_embed.shape) == 3: 136 | from ..model.txtenc import pooling 137 | cap_a_embed = pooling.last_hidden_state_pool(cap_a_embed, lens_a) 138 | cap_b_embed = pooling.last_hidden_state_pool(cap_b_embed, lens_b) 139 | 140 | sim_matrix = self.model.get_ml_sim_matrix(cap_a_embed, cap_b_embed, lens_b) 141 | loss = self.model.ml_criterion(sim_matrix) 142 | return loss 143 | 144 | def _get_lang_iters(self, lang_loaders): 145 | lang_iters = [ 146 | DataIterator(loader=loader, device=self.device, non_stop=True) 147 | for loader in lang_loaders] 148 | return lang_iters 149 | 150 | def _get_multilanguage_total_loss(self, lang_iters): 151 | total_lang_loss = 0. 
152 | loss_info = {} 153 | for lang_iter in lang_iters: 154 | lang_data = lang_iter.next() 155 | lang_loss = self._forward_multilanguage_loss(*lang_data) 156 | total_lang_loss += lang_loss 157 | loss_info[f'train_loss_{str(lang_iter)}'] = lang_loss 158 | return total_lang_loss, loss_info 159 | 160 | def train_epoch(self, train_loader, lang_loaders, 161 | epoch, valid_loaders=[], log_interval=50, 162 | valid_interval=500, path='' 163 | ): 164 | lang_iters = self._get_lang_iters(lang_loaders) 165 | 166 | pbar = lambda x: x 167 | if self.master: 168 | pbar = lambda x: tqdm( 169 | x, total=len(x), 170 | desc='Steps ', 171 | leave=False, 172 | ) 173 | 174 | for batch in pbar(train_loader): 175 | train_info = self._forward(batch, lang_iters, epoch) 176 | self._update_tb_log_info(train_info) 177 | 178 | if self.model.mm_criterion.iteration % valid_interval == 0: 179 | self.run_evaluation(valid_loaders) 180 | 181 | if self.model.mm_criterion.iteration % log_interval == 0 and self.master: 182 | self._print_log_info(train_info) 183 | 184 | def _forward(self, batch, lang_iters, epoch): 185 | self.model.train() 186 | self.optimizer.zero_grad() 187 | begin_forward = dt() 188 | 189 | multimodal_loss = self._forward_multimodal_loss(batch) 190 | total_lang_loss, loss_info = self._get_multilanguage_total_loss(lang_iters) 191 | total_loss = multimodal_loss + total_lang_loss 192 | total_loss.backward() 193 | 194 | norm = 0. 195 | if self.clip_grad > 0: 196 | norm = clip_grad_norm_(self.model.parameters(), self.clip_grad) 197 | 198 | self.optimizer.step() 199 | if self.lr_scheduler is not None: 200 | self.lr_scheduler.step() 201 | 202 | end_backward = dt() 203 | batch_time = end_backward - begin_forward 204 | return self._update_train_info(batch_time, multimodal_loss, 205 | total_loss, epoch, norm, loss_info) 206 | 207 | def _print_log_info(self, train_info): 208 | helper.print_tensor_dict(train_info, print_fn=self.sysoutlog) 209 | if self.log_histograms: 210 | logger.log_param_histograms(self.model, self.tb_writer, self.model.mm_criterion.iteration) 211 | 212 | def _update_train_info(self, batch_time, multimodal_loss, total_loss, epoch, norm, loss_info): 213 | train_info = Dict({ 214 | 'loss': multimodal_loss, 215 | 'iteration': self.model.mm_criterion.iteration, 216 | 'total_loss': total_loss, 217 | 'k': self.model.mm_criterion.k, 218 | 'batch_time': batch_time, 219 | 'countdown': self.count, 220 | 'epoch': epoch, 221 | 'norm': norm, 222 | }) 223 | train_info.update(loss_info) 224 | for param_group in self.optimizer.param_groups: 225 | if 'name' in param_group: 226 | train_info.update({f"lr_{param_group['name']}": param_group['lr']}) 227 | else: 228 | train_info.update({'lr_base': param_group['lr']}) 229 | return train_info 230 | 231 | def _update_tb_log_info(self, train_info): 232 | if self.master: 233 | logger.tb_log_dict( 234 | tb_writer=self.tb_writer, data_dict=train_info, 235 | iteration=self.model.mm_criterion.iteration, prefix='train' 236 | ) 237 | 238 | def run_evaluation(self, valid_loaders): 239 | metrics, val_metric = self.evaluate_loaders(valid_loaders) 240 | self._update_early_stop_vars(val_metric) 241 | if self.master: 242 | self.save(path=self.path, is_best=(val_metric >= self.best_val), args=self.args, rsum=val_metric) 243 | for metric, values in metrics.items(): 244 | self.tb_writer.add_scalar(metric, values, self.model.mm_criterion.iteration) 245 | self._check_early_stop() 246 | 247 | def _check_early_stop(self): 248 | if self.count == 0 and self.master: 249 | self.sysoutlog('\n\nEarly 
stop\n\n') 250 | self.continue_training = False 251 | 252 | def _update_early_stop_vars(self, val_metric): 253 | if val_metric < self.best_val: 254 | self.count -= 1 255 | elif not self.save_all: 256 | self.count = self.early_stop 257 | self.best_val = val_metric 258 | 259 | def evaluate_loaders(self, loaders): 260 | loader_metrics = {} 261 | final_sum = 0. 262 | nb_loaders = len(loaders) 263 | 264 | for i, loader in enumerate(loaders): 265 | loader_name = str(loader.dataset) 266 | self.sysoutlog(f'Evaluating {i+1:2d}/{nb_loaders:2d} - {loader_name}') 267 | img_emb, txt_emb, lens = evaluation.predict_loader(self.model, loader, self.device) 268 | 269 | result = evaluation.evaluate( 270 | model=self.model, img_emb=img_emb, 271 | txt_emb=txt_emb, lengths=lens, 272 | device=self.device, shared_size=128) 273 | 274 | for k, v in result.items(): 275 | self.sysoutlog(f'{k:<10s}: {v:>6.1f}') 276 | 277 | result = { 278 | f'{loader_name}/{metric_name}': v 279 | for metric_name, v in result.items() 280 | } 281 | 282 | loader_metrics.update(result) 283 | final_sum += result[f'{loader_name}/rsum'] 284 | return loader_metrics, final_sum/float(nb_loaders) 285 | 286 | def save(self, path=None, is_best=False, args=None, **kwargs): 287 | helper.save_checkpoint( 288 | path, self.model, 289 | optimizer=self.optimizer, 290 | is_best=is_best, 291 | save_all=self.save_all, 292 | iteration=self.model.mm_criterion.iteration, 293 | args=self.args, 294 | **kwargs 295 | ) 296 | 297 | def load(self, path=None): 298 | if path is None: 299 | path = self.best_model_path 300 | states = helper.restore_checkpoint(path, self.model, None) 301 | self.model = states['model'].to(self.device) 302 | 303 | def __repr__(self,): 304 | string = ( 305 | f'{type(self).__name__} ' 306 | f'{type(self.model).__name__} ' 307 | ) 308 | return string 309 | -------------------------------------------------------------------------------- /retrieval/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from . import helper 2 | from . import logger 3 | from . import file_utils 4 | from . 
import layers -------------------------------------------------------------------------------- /retrieval/utils/file_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import yaml 4 | import json 5 | import pickle 6 | from yaml import Dumper 7 | from addict import Dict 8 | 9 | 10 | def read_txt(path): 11 |     return open(path).read().strip().split('\n') 12 | 13 | 14 | def save_json(path, obj): 15 |     with open(path, 'w') as fp: 16 |         json.dump(obj, fp) 17 | 18 | 19 | def load_json(path): 20 |     with open(path, 'rb') as fp: 21 |         return json.load(fp) 22 | 23 | 24 | def save_yaml_opts(path_yaml, opts): 25 |     # Warning: copy is not nested 26 |     options = copy.copy(opts) 27 | 28 |     # https://gist.github.com/oglops/c70fb69eef42d40bed06 29 |     def dict_representer(dumper, data): 30 |         return dumper.represent_dict(data.items()) 31 |     Dumper.add_representer(Dict, dict_representer) 32 | 33 |     with open(path_yaml, 'w') as yaml_file: 34 |         yaml.dump(options, yaml_file, Dumper=Dumper, default_flow_style=False) 35 | 36 | 37 | def merge_dictionaries(dict1, dict2): 38 |     for key in dict2: 39 |         if key in dict1 and isinstance(dict1[key], dict) and isinstance(dict2[key], dict): 40 |             merge_dictionaries(dict1[key], dict2[key]) 41 |         else: 42 |             dict1[key] = dict2[key] 43 | 44 | 45 | def load_yaml_opts(path_yaml): 46 |     """ Load options dictionary from a yaml file 47 |     """ 48 |     result = {} 49 |     with open(path_yaml, 'r') as yaml_file: 50 |         options_yaml = yaml.safe_load(yaml_file) 51 |         includes = options_yaml.get('__include__', False) 52 |         if includes: 53 |             if type(includes) != list: 54 |                 includes = [includes] 55 |             for include in includes: 56 |                 filename = '{}/{}'.format(os.path.dirname(path_yaml), include) 57 |                 if os.path.isfile(filename): 58 |                     parent = load_yaml_opts(filename) 59 |                 else: 60 |                     parent = load_yaml_opts(include) 61 |                 merge_dictionaries(result, parent) 62 |     # to be sure the main options overwrite the parent options 63 |     merge_dictionaries(result, options_yaml) 64 |     result.pop('__include__', None) 65 |     result = Dict(result) 66 |     return result 67 | 68 | 69 | def parse_loader_name(data_name): 70 |     if '.' in data_name: 71 |         name, lang = data_name.split('.') 72 |         return name, lang 73 |     else: 74 |         return data_name, None 75 | 76 | 77 | def load_pickle(path, encoding="ASCII"): 78 |     with open(path, 'rb') as f: 79 |         return pickle.load(f, encoding=encoding) 80 | 81 | 82 | def save_pickle(path, obj): 83 |     with open(path, 'wb') as f: 84 |         pickle.dump(obj, f) 85 | -------------------------------------------------------------------------------- /retrieval/utils/helper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from tensorboardX import SummaryWriter 4 | 5 | from ..utils.logger import get_logger 6 | 7 | logger = get_logger() 8 | 9 | def save_checkpoint( 10 |     outpath, model, optimizer=None, 11 |     is_best=False, save_all=False, **kwargs 12 | ): 13 | 14 |     if hasattr(model, 'module'): 15 |         model = model.module 16 | 17 |     state_dict = {'model': model.state_dict(), 'optimizer': optimizer.state_dict()} 18 |     state_dict.update(**kwargs) 19 | 20 |     if not save_all: 21 |         epoch = -1 22 |     else: epoch = kwargs.get('iteration', 0)  # keep one file per call when save_all is set; assumes the caller passes 'iteration', as Trainer.save does 23 |     torch.save( 24 |         obj=state_dict, 25 |         f=os.path.join(outpath, f'checkpoint_{epoch}.pkl'), 26 |     ) 27 | 28 |     if is_best: 29 |         import shutil 30 |         shutil.copy( 31 |             os.path.join(outpath, f'checkpoint_{epoch}.pkl'), 32 |             os.path.join(outpath, 'best_model.pkl'), 33 |         ) 34 | 35 | def load_model(path): 36 | 37 |     from .. 
import model 38 | from addict import Dict 39 | from ..data.tokenizer import Tokenizer 40 | 41 | checkpoint = torch.load( 42 | path, map_location=lambda storage, loc: storage 43 | ) 44 | vocab_paths = checkpoint['args']['dataset']['vocab_paths'] 45 | tokenizers = [Tokenizer(vocab_path=x) for x in vocab_paths] 46 | 47 | model_params = Dict(**checkpoint['args']['model']) 48 | model = model.Retrieval(**model_params, tokenizers=tokenizers) 49 | model.load_state_dict(checkpoint['model']) 50 | 51 | return model 52 | 53 | def restore_checkpoint(path, model=None, optimizer=False): 54 | state_dict = torch.load( 55 | path, map_location=lambda storage, loc: storage 56 | ) 57 | new_state = {} 58 | for k, v in state_dict['model'].items(): 59 | new_state[k.replace('module.', '')] = v 60 | 61 | if model is None: 62 | from .. import model 63 | model_params = state_dict['args']['model_args'] 64 | model = model.Retrieval(**model_params) 65 | 66 | model.load_state_dict(new_state) 67 | state_dict['model'] = model 68 | 69 | if optimizer: 70 | optimizer = state_dict['optimizer'] 71 | state_dict['optimizer'] = None 72 | 73 | return state_dict 74 | 75 | 76 | def get_tb_writer(logger_path): 77 | if logger_path == 'runs/': 78 | tb_writer = SummaryWriter() 79 | logger_path = tb_writer.file_writer.get_logdir() 80 | else: 81 | tb_writer = SummaryWriter(logger_path) 82 | return tb_writer 83 | 84 | 85 | def get_device(gpu_id): 86 | if gpu_id >= 0: 87 | return torch.device('cuda:{}'.format(gpu_id)) 88 | return torch.device('cpu') 89 | 90 | 91 | def reset_pbar(pbar): 92 | from time import time 93 | pbar.n = 0 94 | pbar.last_print_n = 0 95 | pbar.start_t = time() 96 | pbar.last_print_t = time() 97 | pbar.update() 98 | return pbar 99 | 100 | 101 | def print_tensor_dict(tensor_dict, print_fn): 102 | line = [] 103 | for k, v in sorted(tensor_dict.items()): 104 | try: 105 | v = v.item() 106 | except AttributeError: 107 | pass 108 | line.append(f'{k.title()}: {v:10.6f}') 109 | print_fn(', '.join(line)) 110 | 111 | 112 | def set_tensorboard_logger(path): 113 | if path is not None: 114 | if os.path.exists(path): 115 | a = input(f'{path} already exists! Do you want to rewrite it? 
[y/n] ') 116 |             if a.lower() == 'y': 117 |                 import shutil 118 |                 shutil.rmtree(path) 119 |                 tb_writer = get_tb_writer(path) 120 |             else: 121 |                 exit() 122 |         else: 123 |             tb_writer = get_tb_writer(path) 124 |     else: 125 |         tb_writer = get_tb_writer('runs/')  # get_tb_writer has no default path; 'runs/' selects the default SummaryWriter run dir 126 |     return tb_writer 127 | -------------------------------------------------------------------------------- /retrieval/utils/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def default_initializer(m): 6 |     if type(m) == nn.Linear: 7 |         torch.nn.init.xavier_uniform_(m.weight) 8 |         m.bias.data.fill_(0.00) 9 |     elif type(m) == nn.Conv1d: 10 |         torch.nn.init.xavier_uniform_(m.weight) 11 |         m.bias.data.fill_(0.00) 12 |     elif type(m) == nn.Embedding: 13 |         m.weight.data.uniform_(-0.1, 0.1) 14 | 15 | 16 | def tensor_to_numpy(x): 17 |     return x.data.cpu().numpy() 18 | -------------------------------------------------------------------------------- /retrieval/utils/logger.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import logging 3 | 4 | 5 | class AverageMeter(object): 6 |     """Computes and stores the average and current value""" 7 | 8 |     def __init__(self): 9 |         self.reset() 10 | 11 |     def reset(self): 12 |         self.val = 0 13 |         self.avg = 0 14 |         self.sum = 0 15 |         self.count = 0 16 | 17 |     def update(self, val, n=0): 18 |         self.val = val 19 |         self.sum += val * n 20 |         self.count += n 21 |         self.avg = self.sum / (.0001 + self.count) 22 | 23 |     def __str__(self): 24 |         """String representation for logging 25 |         """ 26 |         # for values that should be recorded exactly e.g. iteration number 27 |         if self.count == 0: 28 |             return f'{self.val:6g}' 29 |         # for stats 30 |         return f'{self.val:.3f} ({self.avg:.3f})' 31 | 32 | 33 | class LogCollector(object): 34 |     """A collection of logging objects that can change from train to val""" 35 | 36 |     def __init__(self): 37 |         # to keep the order of logged variables deterministic 38 |         self.meters = OrderedDict() 39 | 40 |     def update(self, k, v, n=0): 41 |         # create a new meter if previously not recorded 42 |         if k not in self.meters: 43 |             self.meters[k] = AverageMeter() 44 |         self.meters[k].update(v, n) 45 | 46 |     def __str__(self): 47 |         """ 48 |         Concatenate the meters in one log line 49 |         """ 50 |         s = '' 51 |         for k, v in self.meters.items(): 52 |             s += f'{k.title()} {v}\t' 53 |         return s.rstrip() 54 | 55 |     def tb_log(self, tb_logger, prefix='data/', step=None): 56 |         """ 57 |         Log using tensorboard 58 |         """ 59 |         for k, v in self.meters.items(): 60 |             tb_logger.add_scalar(prefix + k, v.val, step) 61 | 62 |     def update_dict( 63 |         self, val_metrics, 64 |     ): 65 | 66 |         for metric_name, metric_val in val_metrics.items(): 67 |             try: 68 |                 v = metric_val.item() 69 |             except AttributeError: 70 |                 v = metric_val 71 | 72 |             self.update( 73 |                 k=f'{metric_name}', v=v, n=0 74 |             ) 75 | 76 | 77 | def create_logger(level='info'): 78 | 79 |     level = getattr(logging, level.upper()) 80 | 81 |     logging.basicConfig( 82 |         format='%(asctime)s - [%(levelname)-8s] - %(message)s', 83 |         level=level 84 |     ) 85 | 86 |     logger = logging.getLogger(__name__) 87 |     return logger 88 | 89 | 90 | def get_logger(): 91 |     logger = logging.getLogger(__name__) 92 |     return logger 93 | 94 | 95 | def tb_log_dict(tb_writer, data_dict, iteration, prefix=''): 96 |     for k, v in data_dict.items(): 97 |         tb_writer.add_scalar(f'{prefix}/{k}', v, iteration) 98 | 99 | 100 | def log_param_histograms(model, tb_writer, iteration): 101 |     for k, p in 
model.named_parameters(): 102 | tb_writer.add_histogram( 103 | f'params/{k}', 104 | p.data, 105 | iteration, 106 | ) 107 | 108 | def log_grad_norm(model, tb_writer, iteration, reduce=sum): 109 | 110 | grads = [] 111 | for k, p in model.named_parameters(): 112 | if p.grad is None: 113 | continue 114 | tb_writer.add_scalar( 115 | f'grads/{k}', 116 | p.grad.data.norm(2).item(), 117 | iteration 118 | ) 119 | grads.append(p.grad.data.norm(2).item()) 120 | return reduce(grads) 121 | 122 | 123 | def print_log_param_stats(model, iteration): 124 | 125 | print('Iter s{}'.format(iteration)) 126 | for k, v in model.txt_enc.named_parameters(): 127 | print('{:35s}: {:8.5f}, {:8.5f}, {:8.5f}, {:8.5f}'.format( 128 | k, v.data.cpu().min().numpy(), 129 | v.data.cpu().mean().numpy(), 130 | v.data.cpu().max().numpy(), 131 | v.data.cpu().std().numpy(), 132 | )) 133 | for k, v in model.img_enc.named_parameters(): 134 | print('{:35s}: {:8.5f}, {:8.5f}, {:8.5f}, {:8.5f}'.format( k, v.data.cpu().min().numpy(), 135 | v.data.cpu().mean().numpy(), 136 | v.data.cpu().max().numpy(), 137 | v.data.cpu().std().numpy(), 138 | )) 139 | for k, p in model.txt_enc.named_parameters(): 140 | if p.grad is None: 141 | continue 142 | print('{:35s}: {:8.5f}'.format(k, p.grad.data.norm(2).item(),)) 143 | 144 | for k, p in model.img_enc.named_parameters(): 145 | if p.grad is None: 146 | continue 147 | print('{:35s}: {:8.5f}'.format(k, p.grad.data.norm(2).item(),)) 148 | 149 | print('\n\n') 150 | -------------------------------------------------------------------------------- /retrieval/utils/options.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import yaml 3 | import json 4 | import copy 5 | import inspect 6 | import argparse 7 | import collections 8 | from yaml import Dumper 9 | from collections import OrderedDict 10 | 11 | from retrieval.utils.file_utils import load_yaml_opts 12 | 13 | 14 | class OptionsDict(OrderedDict): 15 | """ Dictionary of options contained in the Options class 16 | """ 17 | 18 | def __init__(self, *args, **kwargs): 19 | self.__locked = False 20 | super(OptionsDict, self).__init__(*args, **kwargs) 21 | 22 | def __getitem__(self, key): 23 | if OrderedDict.__contains__(self, key): 24 | val = OrderedDict.__getitem__(self, key) 25 | elif '.' in key: 26 | keys = key.split('.') 27 | val = self[keys[0]] 28 | for k in keys[1:]: 29 | val = val[k] 30 | else: 31 | return OrderedDict.__getitem__(self, key) 32 | return val 33 | 34 | def __contains__(self, key): 35 | # Cannot use "key in" due to recursion, reusing rules for dotted keys from __getitem__ 36 | try: 37 | self[key] 38 | return True 39 | except KeyError: 40 | return False 41 | 42 | def __setitem__(self, key, val): 43 | if key == '_{}__locked'.format(type(self).__name__): 44 | OrderedDict.__setitem__(self, key, val) 45 | elif hasattr(self, '_{}__locked'.format(type(self).__name__)): 46 | if self.__locked: 47 | raise PermissionError('Options\' dictionnary is locked and cannot be changed.') 48 | if type(val) == dict: 49 | val = OptionsDict(val) 50 | OrderedDict.__setitem__(self, key, val) 51 | elif '.' 
in key: 52 | keys = key.split('.') 53 | d = self[keys[0]] 54 | for k in keys[1:-1]: 55 | d = d[k] 56 | d[keys[-1]] = val 57 | else: 58 | OrderedDict.__setitem__(self, key, val) 59 | else: 60 | raise PermissionError('Tried to access Options\' dictionnary bypassing the lock feature.') 61 | 62 | def __getattr__(self, key): 63 | if key in self: 64 | return self[key] 65 | else: 66 | return OrderedDict.__getattr__(self, key) 67 | 68 | # def __setattr__(self, key, value): 69 | # self[key] = value 70 | 71 | def __repr__(self): 72 | dictrepr = dict.__repr__(self) 73 | return '{}({})'.format(type(self).__name__, dictrepr) 74 | 75 | def get(self, key, default): 76 | if key in self: 77 | return self[key] 78 | else: 79 | return default 80 | 81 | 82 | def update(self, *args, **kwargs): 83 | for k, v in OrderedDict(*args, **kwargs).items(): 84 | self[k] = v 85 | 86 | def asdict(self): 87 | d = {} 88 | for k, v in self.items(): 89 | if isinstance(v, dict): 90 | d[k] = dict(v) 91 | else: 92 | d[k] = v 93 | return d 94 | 95 | def lock(self): 96 | self.__locked = True 97 | for key in self.keys(): 98 | if type(key) == OptionsDict: 99 | self[key].lock() 100 | 101 | def islocked(self): 102 | return self.__locked 103 | 104 | def unlock(self): 105 | stack_this = inspect.stack()[1] 106 | stack_caller = inspect.stack()[2] 107 | if stack_this.filename != stack_caller.filename or stack_this.function != stack_caller.function: 108 | for i in range(10): 109 | print('WARNING: Options unlocked by {}[{}]: {}.'.format( 110 | stack_caller.filename, 111 | stack_caller.lineno, 112 | stack_caller.function)) 113 | self.__locked = False 114 | for key in self.keys(): 115 | if type(key) == OptionsDict: 116 | self[key].unlock() 117 | 118 | # https://stackoverflow.com/questions/6760685/creating-a-singleton-in-python 119 | class Options(object): 120 | """ Options is a singleton. It parses a yaml file to generate rules to the argument parser. 121 | If a path to a yaml file is not provided, it relies on the `-o/--path_opts` command line argument. 122 | Args: 123 | path_yaml(str): path to the yaml file 124 | arguments_callback(func): function to be called after running argparse, 125 | if values need to be preprocessed 126 | lock(bool): if True, Options will be locked and no changes to values authorized 127 | run_parser(bool): if False, argparse will not be executed, and values from options 128 | file will be used as is 129 | Example usage: 130 | 131 | .. 
code-block:: python 132 | # parse the yaml file and create options 133 | Options(path_yaml='bootstrap/options/example.yaml', run_parser=False) 134 | 135 | opt = Options() # get the options dictionary from the singleton 136 | print(opt['exp']) # display a dictionary 137 | print(opt['exp.dir']) # display a value 138 | print(opt['exp']['dir']) # display the same value 139 | # the values cannot be changed by command line because run_parser=False 140 | 141 | """ 142 | 143 | # Attributs 144 | 145 | __instance = None # singleton instance of this class 146 | options = None # dictionnary of the singleton 147 | path_yaml = None 148 | 149 | class HelpParser(argparse.ArgumentParser): 150 | def error(self, message): 151 | print('\nError: %s\n' % message) 152 | self.print_help() 153 | sys.exit(2) 154 | 155 | def __new__(self, path_yaml=None, arguments_callback=None, lock=False, run_parser=True): 156 | # Options is a singleton, we will only build if it has not been built before 157 | if not Options.__instance: 158 | Options.__instance = object.__new__(Options) 159 | 160 | if path_yaml: 161 | self.path_yaml = path_yaml 162 | else: 163 | # Parsing only the path_opts argument to find yaml file 164 | optfile_parser = argparse.ArgumentParser(add_help=False) 165 | optfile_parser.add_argument('-o', '--path_opts', type=str, required=True) 166 | self.path_yaml = optfile_parser.parse_known_args()[0].path_opts 167 | 168 | options_yaml = load_yaml_opts(self.path_yaml) 169 | 170 | if run_parser: 171 | fullopt_parser = Options.HelpParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 172 | fullopt_parser.add_argument('-o', '--path_opts', type=str, required=True) 173 | Options.__instance.add_options(fullopt_parser, options_yaml) 174 | 175 | arguments = fullopt_parser.parse_args() 176 | if arguments_callback: 177 | arguments = arguments_callback(Options.__instance, arguments, options_yaml) 178 | 179 | Options.__instance.options = OptionsDict() 180 | for argname in vars(arguments): 181 | nametree = argname.split('.') 182 | value = getattr(arguments, argname) 183 | 184 | position = Options.__instance.options 185 | for piece in nametree[:-1]: 186 | if piece in position and isinstance(position[piece], collections.Mapping): 187 | position = position[piece] 188 | else: 189 | position[piece] = {} 190 | position = position[piece] 191 | position[nametree[-1]] = value 192 | else: 193 | Options.__instance.options = options_yaml 194 | 195 | if lock: 196 | Options.__instance.lock() 197 | return Options.__instance 198 | 199 | 200 | def __getitem__(self, key): 201 | """ 202 | """ 203 | val = self.options[key] 204 | return val 205 | 206 | 207 | def __setitem__(self, key, val): 208 | self.options[key] = val 209 | 210 | 211 | def __getattr__(self, key): 212 | if key in self: 213 | return self[key] 214 | else: 215 | return object.__getattr__(self, key) 216 | 217 | 218 | def __contains__(self, item): 219 | return item in self.options 220 | 221 | 222 | def __str__(self): 223 | return json.dumps(self.options, indent=2) 224 | 225 | 226 | def get(self, key, default): 227 | return self.options.get(key, default) 228 | 229 | 230 | def copy(self): 231 | return self.options.copy() 232 | 233 | 234 | def has_key(self, k): 235 | return k in self.options 236 | 237 | 238 | def keys(self): 239 | return self.options.keys() 240 | 241 | 242 | def values(self): 243 | return self.options.values() 244 | 245 | 246 | def items(self): 247 | return self.options.items() 248 | 249 | 250 | def add_options(self, parser, options, prefix=''): 251 | if prefix: 
252 | prefix += '.' 253 | 254 | for key, value in options.items(): 255 | if isinstance(value, dict): 256 | self.add_options(parser, value, '{}{}'.format(prefix, key)) 257 | else: 258 | argname = '--{}{}'.format(prefix, key) 259 | nargs = '*' if isinstance(value, list) else '?' 260 | if value is None: 261 | datatype = str 262 | elif isinstance(value, bool): 263 | datatype = self.str_to_bool 264 | elif isinstance(value, list): 265 | if len(value) == 0: 266 | datatype = str 267 | else: 268 | datatype = type(value[0]) 269 | else: 270 | datatype = type(value) 271 | parser.add_argument(argname, help='Default: %(default)s', default=value, nargs=nargs, type=datatype) 272 | 273 | 274 | def str_to_bool(self, v): 275 | true_strings = ['yes', 'true'] 276 | false_strings = ['no', 'false'] 277 | if isinstance(v, str): 278 | if v.lower() in true_strings: 279 | return True 280 | elif v.lower() in false_strings: 281 | return False 282 | raise argparse.ArgumentTypeError('{} cant be converted to bool ('.format(v)+'|'.join(true_strings+false_strings)+' can be)') 283 | 284 | 285 | def save(self, path_yaml): 286 | """ Write options dictionary to a yaml file 287 | """ 288 | Options.save_yaml_opts(self.options, path_yaml) 289 | 290 | 291 | def lock(self): 292 | Options.__instance.options.lock() 293 | 294 | 295 | def unlock(self): 296 | Options.__instance.options.unlock() 297 | 298 | 299 | # Static methods 300 | 301 | def save_yaml_opts(opts, path_yaml): 302 | # Warning: copy is not nested 303 | options = copy.copy(opts) 304 | if 'path_opts' in options: 305 | del options['path_opts'] 306 | 307 | # https://gist.github.com/oglops/c70fb69eef42d40bed06 308 | def dict_representer(dumper, data): 309 | return dumper.represent_dict(data.items()) 310 | Dumper.add_representer(OptionsDict, dict_representer) 311 | 312 | with open(path_yaml, 'w') as yaml_file: 313 | yaml.dump(options, yaml_file, Dumper=Dumper, default_flow_style=False) 314 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from tqdm import tqdm 4 | 5 | import params 6 | from retrieval.train import train 7 | from retrieval.utils import helper 8 | from retrieval.model import loss 9 | from retrieval.model.model import Retrieval 10 | from retrieval.data.loaders import get_loaders 11 | from retrieval.utils.logger import create_logger 12 | from retrieval.utils.helper import load_model 13 | from retrieval.utils.file_utils import load_yaml_opts, parse_loader_name 14 | 15 | 16 | def get_data_path(opt): 17 | if 'DATA_PATH' not in os.environ: 18 | if not opt.dataset.data_path: 19 | raise Exception(''' 20 | DATA_PATH not specified. 
21 | Please, run "$ export DATA_PATH=/path/to/dataset" 22 | or add path to yaml file 23 | ''') 24 | return opt.dataset.data_path 25 | else: 26 | return os.environ['DATA_PATH'] 27 | 28 | 29 | def get_tokenizers(train_loader): 30 | tokenizers = train_loader.dataset.tokenizers 31 | if type(tokenizers) != list: 32 | tokenizers = [tokenizers] 33 | return tokenizers 34 | 35 | 36 | def set_criterion(opt, model): 37 | if 'name' in opt.criterion: 38 | logger.info(opt.criterion) 39 | multimodal_criterion = loss.get_loss(**opt.criterion) 40 | multilanguage_criterion = loss.get_loss(**opt.criterion) 41 | else: 42 | multimodal_criterion = loss.ContrastiveLoss(**opt.criterion) 43 | multilanguage_criterion = loss.ContrastiveLoss(**opt.ml_criterion) 44 | set_model_criterion(opt, model, multilanguage_criterion, multimodal_criterion) 45 | # return multimodal_criterion, multilanguage_criterion 46 | 47 | 48 | def set_model_criterion(opt, model, multilanguage_criterion, multimodal_criterion): 49 | model.mm_criterion = multimodal_criterion 50 | model.ml_criterion = None 51 | if len(opt.dataset.adapt.data) > 0: 52 | model.ml_criterion = multilanguage_criterion 53 | 54 | 55 | if __name__ == '__main__': 56 | args = params.get_train_params() 57 | opt = load_yaml_opts(args.options) 58 | logger = create_logger(level='debug' if opt.engine.debug else 'info') 59 | 60 | logger.info(f'Used args : \n{args}') 61 | logger.info(f'Used options: \n{opt}') 62 | 63 | data_path = get_data_path(opt) 64 | 65 | train_loader, val_loaders, adapt_loaders = get_loaders(data_path, args.local_rank, opt) 66 | 67 | tokenizers = get_tokenizers(train_loader) 68 | model = Retrieval(**opt.model, tokenizers=tokenizers) 69 | 70 | # TODO: Implement complete resume of training 71 | if opt.exp.resume: 72 | model = helper.load_model(opt.exp.resume) 73 | # model, optimizer = restore_checkpoint(opt, tokenizers) 74 | print(model) 75 | 76 | print_fn = (lambda x: x) if not model.master else tqdm.write 77 | 78 | set_criterion(opt, model) 79 | trainer = train.Trainer( 80 | model=model, 81 | args=opt, 82 | sysoutlog=print_fn, 83 | path=opt.exp.outpath, 84 | world_size=1 # TODO 85 | ) 86 | 87 | trainer.setup_optim( 88 | lr=opt.optimizer.lr, 89 | lr_scheduler=opt.optimizer.lr_scheduler, 90 | clip_grad=opt.optimizer.grad_clip, 91 | log_grad_norm=False, 92 | log_histograms=False, 93 | optimizer=opt.optimizer, 94 | freeze_modules=opt.model.freeze_modules 95 | ) 96 | 97 | if opt.engine.eval_before_training: 98 | result, rs = trainer.evaluate_loaders( 99 | val_loaders 100 | ) 101 | 102 | trainer.fit( 103 | train_loader=train_loader, 104 | valid_loaders=val_loaders, 105 | lang_loaders=adapt_loaders, 106 | nb_epochs=opt.engine.nb_epochs, 107 | valid_interval=opt.engine.valid_interval, 108 | log_interval=opt.engine.print_freq 109 | ) 110 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | from tqdm import tqdm 5 | from pathlib import Path 6 | 7 | import params 8 | from run import load_model, get_tokenizers 9 | from retrieval.data.loaders import get_loader 10 | from retrieval.model import model 11 | from retrieval.train.train import Trainer 12 | from retrieval.utils import file_utils, helper 13 | from retrieval.utils.logger import create_logger 14 | from run import load_yaml_opts, parse_loader_name, get_data_path 15 | 16 | 17 | if __name__ == '__main__': 18 | args = params.get_test_params() 19 | opt = 
load_yaml_opts(args.options) 20 | logger = create_logger(level='debug' if opt.engine.debug else 'info') 21 | 22 | logger.info(f'Used args : \n{args}') 23 | logger.info(f'Used options: \n{opt}') 24 | 25 | data_path = get_data_path(opt) 26 | 27 | loaders = [] 28 | for data_info in opt.dataset.val.data: 29 | _, lang = parse_loader_name(data_info) 30 | loaders.append( 31 | get_loader( 32 | data_split=args.data_split, 33 | data_path=data_path, 34 | data_info=data_info, 35 | loader_name=opt.dataset.loader_name, 36 | local_rank=args.local_rank, 37 | text_repr=opt.dataset.text_repr, 38 | vocab_paths=opt.dataset.vocab_paths, 39 | ngpu=torch.cuda.device_count(), 40 | **opt.dataset.val 41 | ) 42 | ) 43 | 44 | tokenizers = get_tokenizers(loaders[0]) 45 | model = helper.load_model(f'{opt.exp.outpath}/best_model.pkl') 46 | print_fn = (lambda x: x) if not model.master else tqdm.write 47 | 48 | trainer = Trainer( 49 | model=model, 50 | args={'args': args, 'model_args': opt.model}, 51 | sysoutlog=print_fn, 52 | ) 53 | 54 | result, rs = trainer.evaluate_loaders(loaders) 55 | logger.info(result) 56 | 57 | if args.outpath is not None: 58 | outpath = args.outpath 59 | else: 60 | filename = f'{data_info}.{lang}:{args.data_split}:results.json' 61 | outpath = Path(opt.exp.outpath) / filename 62 | 63 | logger.info('Saving into {}'.format(outpath)) 64 | file_utils.save_json(outpath, result) 65 | -------------------------------------------------------------------------------- /test_ens.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from pathlib import Path 4 | 5 | from params import get_test_params 6 | from retrieval.train import evaluation 7 | from retrieval.data.loaders import get_loader 8 | from retrieval.utils.logger import create_logger 9 | from run import load_model, get_data_path, get_tokenizers 10 | from retrieval.utils.file_utils import save_json, load_yaml_opts, parse_loader_name 11 | 12 | 13 | if __name__ == '__main__': 14 | args = get_test_params(ensemble=True) 15 | opt = load_yaml_opts(args.options[0]) 16 | logger = create_logger(level='debug' if opt.engine.debug else 'info') 17 | 18 | logger.info(f'Used args : \n{args}') 19 | logger.info(f'Used options: \n{opt}') 20 | 21 | train_data = opt.dataset.train_data 22 | 23 | data_path = get_data_path(opt) 24 | 25 | data_name, lang = parse_loader_name(opt.dataset.train.data) 26 | 27 | loader = get_loader( 28 | data_split=args.data_split, 29 | data_path=data_path, 30 | data_info=opt.dataset.train.data, 31 | loader_name=opt.dataset.loader_name, 32 | local_rank=args.local_rank, 33 | text_repr=opt.dataset.text_repr, 34 | vocab_paths=opt.dataset.vocab_paths, 35 | ngpu=torch.cuda.device_count(), 36 | **opt.dataset.val, 37 | ) 38 | 39 | device = torch.device(args.device) 40 | # device = torch.device('cuda') 41 | tokenizers = get_tokenizers(loader) 42 | 43 | sims = [] 44 | for options in args.options: 45 | options = load_yaml_opts(options) 46 | _model = load_model(f'{options.exp.outpath}/best_model.pkl') 47 | 48 | img_emb, cap_emb, lens = evaluation.predict_loader(_model, loader, device) 49 | _, sim_matrix = evaluation.evaluate(_model, img_emb, cap_emb, lens, device, return_sims=True) 50 | sims.append(sim_matrix) 51 | 52 | sims = np.array(sims) 53 | sims = sims.mean(0) 54 | 55 | i2t_metrics = evaluation.i2t(sims) 56 | t2i_metrics = evaluation.t2i(sims) 57 | 58 | _metrics_ = ('r1', 'r5', 'r10', 'medr', 'meanr') 59 | metrics = {} 60 | 61 | rsum = np.sum(i2t_metrics[:3]) + 
np.sum(t2i_metrics[:3]) 62 | 63 |     i2t_metrics = { 64 |         f'i2t_{k}': float(v) for k, v in zip(_metrics_, i2t_metrics) 65 |     } 66 |     t2i_metrics = { 67 |         f't2i_{k}': float(v) for k, v in zip(_metrics_, t2i_metrics) 68 |     } 69 | 70 |     metrics.update(i2t_metrics) 71 |     metrics.update(t2i_metrics) 72 |     metrics['rsum'] = rsum 73 |     logger.info(metrics) 74 | 75 |     if args.outpath is not None: 76 |         outpath = args.outpath 77 |     else: 78 |         filename = f'{data_name}.{lang}:{args.data_split}:ens_results.json' 79 |         outpath = Path(opt.exp.outpath) / filename 80 | 81 |     logger.info(f'Saving into: {outpath}') 82 |     save_json(outpath, metrics) 83 | -------------------------------------------------------------------------------- /vocab/align_vocab.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import numpy as np 4 | from tqdm import tqdm 5 | 6 | sys.path.append('../') 7 | from retrieval.data.tokenizer import Tokenizer 8 | from retrieval.utils.logger import create_logger 9 | from params import get_vocab_alignment_params 10 | 11 | 12 | def loadEmbModel(embFile, logger): 13 |     """Loads a W2V or GloVe embedding model from a plain-text file""" 14 |     logger.info("Loading Embedding Model") 15 |     f = open(embFile, 'r') 16 |     model = {} 17 |     v = [] 18 |     for line in f: 19 |         splitLine = line.split(' ') 20 |         word = splitLine[0] 21 |         try: 22 |             embedding = np.array([float(val) for val in splitLine[1:]]) 23 |         except ValueError: 24 |             logger.info('Skipping malformed embedding line %d: %s', len(v), line.strip()); continue 25 |         model[word] = embedding 26 |         v.append(embedding) 27 |     mean = np.array(v).mean(0) 28 |     logger.info(mean.shape) 29 |     model['<unk>'] = torch.tensor(mean)  # unknown words fall back to the mean embedding 30 |     model['<pad>'] = torch.zeros(embedding.shape) 31 |     model['<start>'] = torch.zeros(embedding.shape) 32 |     model['<end>'] = torch.zeros(embedding.shape) 33 |     logger.info('Done. %d words loaded!', len(model)) 34 |     return model 35 | 36 | 37 | def align_vocabs(emb_model, tokenizer): 38 |     """Align vocabulary and embedding model""" 39 |     hi_emb = emb_model['hi']  # any common word works here; only its dimensionality is used 40 |     logger.info(hi_emb.shape) 41 |     total_unk = 0 42 |     nmax = max(tokenizer.vocab.idx2word.keys()) + 1 43 |     word_matrix = torch.zeros(nmax, hi_emb.shape[-1]) 44 |     logger.info(word_matrix.shape) 45 | 46 |     for k, v in tqdm(tokenizer.vocab.idx2word.items(), total=len(tokenizer)): 47 |         try: 48 |             word_matrix[k] = torch.tensor(emb_model[v]) 49 |         except KeyError: 50 |             word_matrix[k] = emb_model['<unk>'] 51 |             total_unk += 1 52 |     return word_matrix, total_unk 53 | 54 | 55 | def load_tokenizer(args): 56 |     """Load tokenizer""" 57 |     tokenizer = Tokenizer() 58 |     tokenizer.load(args.vocab_path) 59 |     return tokenizer 60 | 61 | if __name__ == '__main__': 62 |     args = get_vocab_alignment_params() 63 |     logger = create_logger(level='debug') 64 | 65 |     tokenizer = load_tokenizer(args) 66 | 67 |     emb_model = loadEmbModel(args.emb_path, logger) 68 | 69 |     word_matrix, total_unk = align_vocabs(emb_model, tokenizer) 70 | 71 |     logger.info(f'Finished. Total UNK: {total_unk}') 72 |     torch.save(word_matrix, args.outpath) 73 |     logger.info(f'Saved into: {args.outpath}') 74 | --------------------------------------------------------------------------------
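A minimal usage sketch of the __include__ option merging implemented by load_yaml_opts in retrieval/utils/file_utils.py. It shows how a per-dataset option file could pull in shared defaults (presumably something like options/adapt/abstract.yaml) and override only a few keys. The YAML contents quoted in the comments and the output file name are illustrative assumptions, not taken from the repository.

from retrieval.utils.file_utils import load_yaml_opts, save_yaml_opts

# Hypothetical contents of options/adapt/f30k/t2i.yaml:
#
#   __include__: ../abstract.yaml     # pull in shared defaults
#   dataset:
#     train:
#       data: f30k_precomp.en         # child keys override the included parent
#
# load_yaml_opts resolves the include relative to the child file, merges the
# parent first and the child file on top, and returns an addict.Dict that
# supports dotted attribute access.
opt = load_yaml_opts('options/adapt/f30k/t2i.yaml')
print(opt.dataset.train.data)

# The fully merged options can be written back to a single flat YAML file.
save_yaml_opts('merged_options.yaml', opt)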