├── ESL.yaml
├── README.md
├── arguments.py
├── eval.py
├── eval_ensemble.py
├── lib
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── encoders.cpython-36.pyc
│   │   ├── evaluation.cpython-36.pyc
│   │   ├── loss.cpython-36.pyc
│   │   └── vse.cpython-36.pyc
│   ├── _init_paths.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   └── image_caption.cpython-36.pyc
│   │   └── image_caption.py
│   ├── encoders.py
│   ├── evaluation.py
│   ├── loss.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── mlp.cpython-36.pyc
│   │   │   └── resnet.cpython-36.pyc
│   │   ├── aggr
│   │   │   ├── __init__.py
│   │   │   └── gpo.py
│   │   ├── mlp.py
│   │   └── resnet.py
│   ├── pytorch-logo-dark.png
│   └── vse.py
├── motivation.png
├── overview.png
├── runs
│   └── runX
│       └── log
│           └── performance.log
├── test.py
├── test_stack.py
├── testall.py
├── train.py
├── train_region_coco.sh
└── train_region_f30k.sh
/ESL.yaml:
--------------------------------------------------------------------------------
1 | name: py36
2 | channels:
3 | - https://mirrors.ustc.edu.cn/anaconda/pkgs/free
4 | - pytorch
5 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge
6 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
7 | - conda-forge
8 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/
9 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/
10 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
11 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
12 | - defaults
13 | dependencies:
14 | - _libgcc_mutex=0.1=conda_forge
15 | - _openmp_mutex=4.5=1_gnu
16 | - blas=1.0=mkl
17 | - bzip2=1.0.8=h7f98852_4
18 | - ca-certificates=2023.7.22=hbcca054_0
19 | - certifi=2016.9.26=py36_0
20 | - cudatoolkit=11.1.1=h6406543_8
21 | - dataclasses=0.8=pyh787bdff_0
22 | - ffmpeg=4.3=hf484d3e_0
23 | - freetype=2.10.4=h0708190_1
24 | - gmp=6.2.1=h58526e2_0
25 | - gnutls=3.6.13=h85f3911_1
26 | - intel-openmp=2021.2.0=h06a4308_610
27 | - jpeg=9b=0
28 | - lame=3.100=h7f98852_1001
29 | - lcms2=2.12=h3be6417_0
30 | - ld_impl_linux-64=2.35.1=hea4e1c9_2
31 | - libblas=3.9.0=9_mkl
32 | - libcblas=3.9.0=9_mkl
33 | - libffi=3.3=h58526e2_2
34 | - libgcc-ng=9.3.0=h2828fa1_19
35 | - libgomp=9.3.0=h2828fa1_19
36 | - libiconv=1.16=h516909a_0
37 | - liblapack=3.9.0=9_mkl
38 | - libpng=1.6.37=h21135ba_2
39 | - libstdcxx-ng=9.3.0=h6de172a_19
40 | - libtiff=4.1.0=h2733197_1
41 | - libuv=1.41.0=h7f98852_0
42 | - lz4-c=1.9.3=h9c3ff4c_0
43 | - mkl=2021.2.0=h06a4308_296
44 | - ncurses=6.2=h58526e2_4
45 | - nettle=3.6=he412f7d_0
46 | - ninja=1.10.2=h4bd325d_0
47 | - numpy=1.19.5=py36h2aa4a07_1
48 | - olefile=0.46=pyh9f0ad1d_1
49 | - openh264=2.1.1=h780b84a_0
50 | - openssl=1.1.1v=h7f8727e_0
51 | - pillow=8.2.0=py36he98fc37_0
52 | - pip=21.1.1=pyhd8ed1ab_0
53 | - python=3.6.13=hffdb5ce_0_cpython
54 | - python_abi=3.6=1_cp36m
55 | - pytorch=1.8.0=py3.6_cuda11.1_cudnn8.0.5_0
56 | - readline=8.1=h46c0cb4_0
57 | - ruamel_yaml=0.15.80=py36h8f6f2f9_1004
58 | - setuptools=49.6.0=py36h5fab9bb_3
59 | - sqlite=3.35.5=h74cdb3f_0
60 | - tk=8.6.10=h21135ba_1
61 | - torchaudio=0.8.0=py36
62 | - torchvision=0.9.0=py36_cu111
63 | - typing_extensions=3.7.4.3=py_0
64 | - wheel=0.36.2=pyhd3deb0d_0
65 | - xz=5.2.5=h516909a_1
66 | - yaml=0.2.5=h516909a_0
67 | - zlib=1.2.11=h516909a_1010
68 | - zstd=1.4.9=ha95c52a_0
69 | - pip:
70 | - blessed==1.20.0
71 | - boto3==1.23.10
72 | - botocore==1.26.10
73 | - cached-property==1.5.2
74 | - charset-normalizer==2.0.12
75 | - click==8.0.2
76 | - cycler==0.11.0
77 | - ftfy==3.0.1
78 | - gpustat==1.1
79 | - h5py==3.1.0
80 | - idna==3.4
81 | - imageio==2.9.0
82 | - importlib-metadata==4.8.1
83 | - jmespath==0.10.0
84 | - joblib==1.1.0
85 | - kiwisolver==1.3.1
86 | - lmdb==1.3.0
87 | - matplotlib==3.3.4
88 | - nltk==3.6.3
89 | - nvidia-ml-py==11.525.112
90 | - opencv-python==4.7.0.72
91 | - pandas==1.1.5
92 | - protobuf==3.17.0
93 | - psutil==5.9.4
94 | - pyparsing==3.0.6
95 | - python-dateutil==2.8.2
96 | - pytz==2021.3
97 | - regex==2021.10.8
98 | - requests==2.27.1
99 | - s3transfer==0.5.2
100 | - sacremoses==0.0.53
101 | - scikit-learn==0.24.2
102 | - scipy==1.5.4
103 | - seaborn==0.11.2
104 | - sentencepiece==0.1.98
105 | - six==1.16.0
106 | - sklearn==0.0
107 | - tensorboard-logger==0.1.0
108 | - threadpoolctl==3.1.0
109 | - tqdm==4.62.3
110 | - transformers==2.1.0
111 | - urllib3==1.26.15
112 | - wcwidth==0.2.6
113 | - zipp==3.6.0
114 | prefix: /mnt/data10t/bakuphome20210617/zhangkun/anaconda3/envs/py36
115 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Enhanced Semantic Similarity Learning Framework for Image-Text Matching
2 |
3 |
[](https://opensource.org/licenses/MIT)
4 |
5 | Official PyTorch implementation of the paper [Enhanced Semantic Similarity Learning Framework for Image-Text Matching](https://www.researchgate.net/publication/373318149_Enhanced_Semantic_Similarity_Learning_Framework_for_Image-Text_Matching).
6 |
7 | Please use the following bib entry to cite this paper if you are using any resources from the repo.
8 |
9 | ```
10 | @article{zhang2023enhanced,
11 | author={Zhang, Kun and Hu, Bo and Zhang, Huatian and Li, Zhe and Mao, Zhendong},
12 | journal={IEEE Transactions on Circuits and Systems for Video Technology},
13 | title={Enhanced Semantic Similarity Learning Framework for Image-Text Matching},
14 | year={2024},
15 | volume={34},
16 | number={4},
17 | pages={2973-2988}
18 | }
19 | ```
20 |
21 |
22 | We referred to the implementations of [GPO](https://github.com/woodfrog/vse_infty/blob/master/README.md) to build up our codebase.
23 |
24 | ## Motivation
25 |

31 | In this paper, different from the single-dimensional correspondence with limited semantic expressive capability, we propose a novel enhanced semantic similarity learning (ESL), which generalizes both measure-units and their correspondences into a dynamic learnable framework to examine the multi-dimensional enhanced correspondence between visual and textual features. Specifically, we first devise the intra-modal multi-dimensional aggregators with iterative enhancing mechanism, which dynamically captures new measure-units integrated by hierarchical multi-dimensions, producing diverse semantic combinatorial expressive capabilities to provide richer and discriminative information for similarity examination. Then, we devise the inter-modal enhanced correspondence learning with sparse contribution degrees, which comprehensively and efficiently determines the cross-modal semantic similarity. Extensive experiments verify its superiority in achieving state-of-the-art performance.
32 |
33 | ### Image-text Matching Results
34 |
35 | The following tables show partial image-text retrieval results on the COCO and Flickr30K datasets. In these experiments, we use BERT-base as the text encoder for our methods. This branch provides our code and pre-trained models for **using BERT as the text backbone**. Some results are better than those reported in the paper. Note, however, that the ensemble results in the paper may not be reproducible from the two best checkpoints provided here; the original checkpoints were lost because they were not saved in time. You can train the model a few more times and combine any two checkpoints to find the best ensemble performance. Please check out [**the ```CLIP-based``` branch**](https://github.com/kkzhang95/ESL/blob/main/README.md) for the corresponding code and pre-trained models.
36 |
37 | #### Results of 5-fold evaluation on COCO 1K Test Split
38 |
39 | | |Visual Backbone|Text Backbone|I2T R1|I2T R5|I2T R10|T2I R1|T2I R5|T2I R10|Rsum|Link|
40 | |---|:---:|:---:|---|---|---|---|---|---|---|---|
41 | |ESL-H | BUTD region |BERT-base|**82.5**|**97.4**|**99.0**|**66.2**|**91.9**|**96.7**|**533.5**|[Here](https://drive.google.com/file/d/1NgTLNFGhEt14YgLb3gCkWfBp1gBxvl9w/view?usp=sharing)|
42 | |ESL-A | BUTD region |BERT-base|**82.2**|**96.9**|**98.9**|**66.5**|**92.1**|**96.7**|**533.4**|[Here](https://drive.google.com/file/d/17jaJm2DSJbF5IuUij9s3c2fupcy4CW8T/view?usp=sharing)|
43 |
44 |
45 | #### Results on COCO 5K Test Split
46 |
47 | | |Visual Backbone|Text Backbone|I2T R1|I2T R5|I2T R10|T2I R1|T2I R5|T2I R10|Rsum|Link|
48 | |---|:---:|:---:|---|---|---|---|---|---|---|---|
49 | |ESL-H | BUTD region |BERT-base|**63.6**|**87.4**|**93.5**|**44.2**|**74.1**|**84.0**|**446.9**|[Here](https://drive.google.com/file/d/1NgTLNFGhEt14YgLb3gCkWfBp1gBxvl9w/view?usp=sharing)|
50 | |ESL-A | BUTD region |BERT-base|**63.0**|**87.6**|**93.3**|**44.5**|**74.4**|**84.1**|**447.0**|[Here](https://drive.google.com/file/d/17jaJm2DSJbF5IuUij9s3c2fupcy4CW8T/view?usp=sharing)|
51 |
52 |
53 | #### Results on Flickr30K Test Split
54 |
55 | | |Visual Backbone|Text Backbone|I2T R1|I2T R5|I2T R10|T2I R1|T2I R5|T2I R10|Rsum|Link|
56 | |---|:---:|:---:|---|---|---|---|---|---|---|---|
57 | |ESL-H | BUTD region |BERT-base|**83.5**|**96.3**|**98.4**|**65.1**|**87.6**|**92.7**|**523.7**|[Here](https://drive.google.com/file/d/17FnwyH8aSOwvUuZco0lQ5eY0TM_4LXxv/view?usp=sharing)|
58 | |ESL-A | BUTD region |BERT-base|**84.3**|**96.3**|**98.0**|**64.1**|**87.4**|**92.2**|**522.4**|[Here](https://drive.google.com/file/d/1ZoPW8azNkBWVq1jaQxHfI_XpINzvmv1n/view?usp=sharing)|
59 |
60 |
61 |
62 |
63 |
64 | ## Preparation
65 |
66 | ### Environment
67 |
68 | We recommend the following dependencies.
69 |
70 | * Python 3.6
71 | * [PyTorch](http://pytorch.org/) 1.8.0
72 | * [NumPy](http://www.numpy.org/) (>1.19.5)
73 | * [TensorBoard](https://github.com/TeamHG-Memex/tensorboard_logger)
74 | * The full environment specification can be found [here](https://github.com/CrossmodalGroup/ESL/blob/main/ESL.yaml). Use **conda env create -f ESL.yaml** to create the corresponding environment.
75 |
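A minimal setup sketch based on the `ESL.yaml` above (the environment name `py36` comes from that file):

```bash
# Create the conda environment from the spec shipped with the repo
conda env create -f ESL.yaml
# Activate it; ESL.yaml names the environment "py36"
conda activate py36
```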
76 | ### Data
77 |
78 | You can download the dataset through Baidu Cloud. Download links are [Flickr30K]( https://pan.baidu.com/s/1Fr_bviuWLcrJ9MiiRn_H2Q) and [MSCOCO]( https://pan.baidu.com/s/1vp3gtQhT7GO0PQACBSnOrQ), the extraction code is: USTC.
79 |
80 | ## Training
81 |
82 | ```bash
83 | sh train_region_f30k.sh
84 | ```
85 |
86 | ```bash
87 | sh train_region_coco.sh
88 | ```
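The contents of the two shell scripts are not reproduced in this listing; as a rough, hypothetical sketch, they launch `train.py` with options defined in `arguments.py`, for example:

```bash
# Hypothetical launch command for Flickr30K; the real flags live in train_region_f30k.sh
python train.py --data_path ../Flickr30K/ --data_name f30k_precomp \
  --logger_name ../log/f30k --model_name ../checkpoint/f30k \
  --batch_size 128 --num_epochs 30 --lr_update 15 --learning_rate 5e-4
```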
89 | For the dimensional selective mask, we design both heuristic and adaptive strategies. You can use the flag in [vse.py](https://github.com/CrossmodalGroup/ESL/blob/main/lib/vse.py) (line 44)
90 | ```python
91 | heuristic_strategy = False
92 | ```
93 | to control which strategy is selected. True -> heuristic strategy, False -> adaptive strategy.
94 |
95 | ## Evaluation
96 |
97 | Test on Flickr30K
98 | ```bash
99 | python test.py
100 | ```
101 |
102 | To do cross-validation on MSCOCO, pass `fold5=True` with a model trained using
103 | `--data_name coco_precomp`.
104 |
105 | ```bash
106 | python testall.py
107 | ```
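For reference, a sketch of driving the same evaluation directly from Python (the `evalrank` signature comes from `lib/evaluation.py`; the checkpoint and data paths below are placeholders):

```python
from lib import evaluation

# COCO 5-fold 1K evaluation (placeholder paths)
evaluation.evalrank('runs/runX/model_best.pth',
                    data_path='/path/to/coco', split='testall', fold5=True)
# COCO full 5K evaluation
evaluation.evalrank('runs/runX/model_best.pth',
                    data_path='/path/to/coco', split='testall', fold5=False)
```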
108 |
109 | To ensemble models, specify the model paths in test_stack.py and run
110 | ```bash
111 | python test_stack.py
112 | ```
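Alternatively, similarity matrices saved with `--save_results` in `eval.py` can be averaged across runs with `evaluation.eval_ensemble`, as in `eval_ensemble.py`:

```python
from lib import evaluation

# Average the saved similarity matrices of two runs (example paths from eval_ensemble.py)
paths = ['runs/coco_butd_grid_bert/results_coco.npy',
         'runs/coco_butd_region_bert/results_coco.npy']

evaluation.eval_ensemble(results_paths=paths, fold5=True)   # COCO 5-fold 1K
evaluation.eval_ensemble(results_paths=paths, fold5=False)  # COCO 5K
```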
113 |
114 |
115 |
--------------------------------------------------------------------------------
/arguments.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def get_argument_parser():
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument('--data_path', default='../Flickr30K/', # Flickr30K MS-COCO
7 | help='path to datasets')
8 | parser.add_argument('--data_name', default='f30k_precomp',
9 | help='{coco,f30k}_precomp')
10 | parser.add_argument('--vocab_path', default='./vocab/',
11 | help='Path to saved vocabulary json files.')
12 | parser.add_argument('--margin', default=0.2, type=float,
13 | help='Rank loss margin.')
14 | parser.add_argument('--num_epochs', default=30, type=int,
15 | help='Number of training epochs.')
16 | parser.add_argument('--batch_size', default=128, type=int,
17 | help='Size of a training mini-batch.')
18 | parser.add_argument('--word_dim', default=300, type=int,
19 | help='Dimensionality of the word embedding.')
20 | parser.add_argument('--grad_clip', default=2., type=float,
21 | help='Gradient clipping threshold.')
22 | parser.add_argument('--learning_rate', default=.0005, type=float,
23 | help='Initial learning rate.')
24 | parser.add_argument('--lr_update', default=15, type=int,
25 | help='Number of epochs to update the learning rate.')
26 | parser.add_argument('--optim', default='adam', type=str,
27 | help='the optimizer')
28 | parser.add_argument('--workers', default=10, type=int,
29 | help='Number of data loader workers.')
30 | parser.add_argument('--log_step', default=500, type=int,
31 | help='Number of steps to logger.info and record the log.')
32 | parser.add_argument('--val_step', default=500, type=int,
33 | help='Number of steps to run validation.')
34 | parser.add_argument('--logger_name', default='../log/',
35 | help='Path to save Tensorboard log.')
36 | parser.add_argument('--model_name', default='../checkpoint/',
37 | help='Path to save the model.')
38 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
39 | help='path to latest checkpoint (default: none)')
40 | parser.add_argument('--max_violation', action='store_true',
41 | help='Use max instead of sum in the rank loss.')
42 | parser.add_argument('--img_dim', default=2048, type=int,
43 | help='Dimensionality of the image embedding.')
44 | parser.add_argument('--no_imgnorm', action='store_true',
45 | help='Do not normalize the image embeddings.')
46 | parser.add_argument('--no_txtnorm', action='store_true',
47 | help='Do not normalize the text embeddings.')
48 | parser.add_argument('--precomp_enc_type', default="basic",
49 | help='basic|backbone')
50 | parser.add_argument('--backbone_path', type=str, default='',
51 | help='path to the pre-trained backbone net')
52 | parser.add_argument('--backbone_source', type=str, default='detector',
53 | help='the source of the backbone model, detector|imagenet')
54 | parser.add_argument('--vse_mean_warmup_epochs', type=int, default=1,
55 | help='The number of warmup epochs using mean vse loss')
56 | parser.add_argument('--reset_start_epoch', action='store_true',
57 | help='Whether restart the start epoch when load weights')
58 | parser.add_argument('--backbone_warmup_epochs', type=int, default=5,
59 | help='The number of epochs for warmup')
60 | parser.add_argument('--embedding_warmup_epochs', type=int, default=2,
61 | help='The number of epochs for warming up the embedding layers')
62 | parser.add_argument('--backbone_lr_factor', default=0.01, type=float,
63 | help='The lr factor for fine-tuning the backbone, it will be multiplied to the lr of '
64 | 'the embedding layers')
65 | parser.add_argument('--input_scale_factor', type=float, default=1,
66 | help='The factor for scaling the input image')
67 | parser.add_argument('--kernel_size', default=2, type=int,
68 | help='Dimensionality of the sim embedding.')
69 | parser.add_argument('--embed_size', default=512, type=int,
70 | help='Dimensionality of the joint embedding.')
71 |
72 |
73 | return parser
74 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import logging
4 | from lib import evaluation
5 |
6 | logging.basicConfig()
7 | logger = logging.getLogger()
8 | logger.setLevel(logging.INFO)
9 |
10 |
11 | def main():
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('--dataset', default='coco',
14 | help='coco or f30k')
15 | parser.add_argument('--data_path', default='/tmp/data/coco')
16 | parser.add_argument('--save_results', action='store_true')
17 | parser.add_argument('--evaluate_cxc', action='store_true')
18 | opt = parser.parse_args()
19 |
20 | if opt.dataset == 'coco':
21 | weights_bases = [
22 | 'runs/release_weights/coco_butd_region_bert',
23 | 'runs/release_weights/coco_butd_grid_bert',
24 | 'runs/release_weights/coco_wsl_grid_bert',
25 | ]
26 | elif opt.dataset == 'f30k':
27 | weights_bases = [
28 | 'runs/release_weights/f30k_butd_region_bert',
29 | 'runs/release_weights/f30k_butd_grid_bert',
30 | 'runs/release_weights/f30k_wsl_grid_bert',
31 | ]
32 | else:
33 | raise ValueError('Invalid dataset argument {}'.format(opt.dataset))
34 |
35 | for base in weights_bases:
36 | logger.info('Evaluating {}...'.format(base))
37 | model_path = os.path.join(base, 'model_best.pth')
38 | if opt.save_results: # Save the final results for computing ensemble results
39 | save_path = os.path.join(base, 'results_{}.npy'.format(opt.dataset))
40 | else:
41 | save_path = None
42 |
43 | if opt.dataset == 'coco':
44 | if not opt.evaluate_cxc:
45 | # Evaluate COCO 5-fold 1K
46 | evaluation.evalrank(model_path, data_path=opt.data_path, split='testall', fold5=True)
47 | # Evaluate COCO 5K
48 | evaluation.evalrank(model_path, data_path=opt.data_path, split='testall', fold5=False, save_path=save_path)
49 | else:
50 | # Evaluate COCO-trained models on CxC
51 | evaluation.evalrank(model_path, data_path=opt.data_path, split='testall', fold5=True, cxc=True)
52 | elif opt.dataset == 'f30k':
53 | # Evaluate Flickr30K
54 | evaluation.evalrank(model_path, data_path=opt.data_path, split='test', fold5=False, save_path=save_path)
55 |
56 |
57 | if __name__ == '__main__':
58 | main()
59 |
--------------------------------------------------------------------------------
/eval_ensemble.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from lib import evaluation
3 |
4 | logging.basicConfig()
5 | logger = logging.getLogger()
6 | logger.setLevel(logging.INFO)
7 |
8 | # Evaluate model ensemble
9 | paths = ['runs/coco_butd_grid_bert/results_coco.npy',
10 | 'runs/coco_butd_region_bert/results_coco.npy']
11 |
12 | evaluation.eval_ensemble(results_paths=paths, fold5=True)
13 | evaluation.eval_ensemble(results_paths=paths, fold5=False)
14 |
--------------------------------------------------------------------------------
/lib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CrossmodalGroup/ESL/b91e287b238fcf6dd7dc784c04e7bf1ac53404ae/lib/__init__.py
--------------------------------------------------------------------------------
/lib/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CrossmodalGroup/ESL/b91e287b238fcf6dd7dc784c04e7bf1ac53404ae/lib/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/encoders.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CrossmodalGroup/ESL/b91e287b238fcf6dd7dc784c04e7bf1ac53404ae/lib/__pycache__/encoders.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/evaluation.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CrossmodalGroup/ESL/b91e287b238fcf6dd7dc784c04e7bf1ac53404ae/lib/__pycache__/evaluation.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/loss.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CrossmodalGroup/ESL/b91e287b238fcf6dd7dc784c04e7bf1ac53404ae/lib/__pycache__/loss.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/vse.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CrossmodalGroup/ESL/b91e287b238fcf6dd7dc784c04e7bf1ac53404ae/lib/__pycache__/vse.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/_init_paths.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import sys
3 |
4 |
5 | def add_path(path):
6 | if path not in sys.path:
7 | sys.path.insert(0, path)
8 |
9 |
10 | root_dir = osp.abspath(osp.join(osp.dirname(__file__), '..'))  # repo root (parent of lib/)
11 |
12 | # Add lib to PYTHONPATH
13 | lib_path = osp.join(root_dir, 'lib')
14 | datasets_path = osp.join(root_dir, 'datasets')
15 | add_path(lib_path)
16 | add_path(datasets_path)
17 |
--------------------------------------------------------------------------------
/lib/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CrossmodalGroup/ESL/b91e287b238fcf6dd7dc784c04e7bf1ac53404ae/lib/datasets/__init__.py
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CrossmodalGroup/ESL/b91e287b238fcf6dd7dc784c04e7bf1ac53404ae/lib/datasets/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/image_caption.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CrossmodalGroup/ESL/b91e287b238fcf6dd7dc784c04e7bf1ac53404ae/lib/datasets/__pycache__/image_caption.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/datasets/image_caption.py:
--------------------------------------------------------------------------------
1 | """COCO dataset loader"""
2 | import torch
3 | import torch.utils.data as data
4 | import os
5 | import os.path as osp
6 | import numpy as np
7 | from imageio import imread
8 | import random
9 | import json
10 | import cv2
11 |
12 | import logging
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | class RawImageDataset(data.Dataset):
18 | """
19 | Load precomputed captions and image features
20 | Possible options: f30k_precomp, coco_precomp
21 | """
22 |
23 | def __init__(self, data_path, data_name, data_split, tokenizer, opt, train):
24 | self.opt = opt
25 | self.train = train
26 | self.data_path = data_path
27 | self.data_name = data_name
28 | self.tokenizer = tokenizer
29 |
30 | loc_cap = osp.join(data_path, 'precomp')
31 | loc_image = osp.join(data_path, 'precomp')
32 | loc_mapping = osp.join(data_path, 'id_mapping.json')
33 | if 'coco' in data_name:
34 | self.image_base = osp.join(data_path, 'images')
35 | else:
36 | self.image_base = osp.join(data_path, 'flickr30k-images')
37 |
38 | with open(loc_mapping, 'r') as f_mapping:
39 | self.id_to_path = json.load(f_mapping)
40 |
41 | # Read Captions
42 | self.captions = []
43 | with open(osp.join(loc_cap, '%s_caps.txt' % data_split), 'r') as f:
44 | for line in f:
45 | self.captions.append(line.strip())
46 |
47 | # Get the image ids
48 | with open(osp.join(loc_image, '{}_ids.txt'.format(data_split)), 'r') as f:
49 | image_ids = f.readlines()
50 | self.images = [int(x.strip()) for x in image_ids]
51 |
52 | # Set related parameters according to the pre-trained backbone **
53 | assert 'backbone' in opt.precomp_enc_type
54 |
55 | self.backbone_source = opt.backbone_source
56 | self.base_target_size = 256
57 | self.crop_ratio = 0.875
58 | self.train_scale_rate = 1
59 | if hasattr(opt, 'input_scale_factor') and opt.input_scale_factor != 1:
60 | self.base_target_size = int(self.base_target_size * opt.input_scale_factor)
61 | logger.info('Input images are scaled by factor {}'.format(opt.input_scale_factor))
62 | if 'detector' in self.backbone_source:
63 | self.pixel_means = np.array([[[102.9801, 115.9465, 122.7717]]])
64 | else:
65 | self.imagenet_mean = [0.485, 0.456, 0.406]
66 | self.imagenet_std = [0.229, 0.224, 0.225]
67 |
68 | self.length = len(self.captions)
69 |
70 | # rkiros data has redundancy in images, we divide by 5, 10crop doesn't
71 | num_images = len(self.images)
72 |
73 | if num_images != self.length:
74 | self.im_div = 5
75 | else:
76 | self.im_div = 1
77 | # the development set for coco is large and so validation would be slow
78 | if data_split == 'dev':
79 | self.length = 5000
80 |
81 | def __getitem__(self, index):
82 | img_index = index // self.im_div
83 | caption = self.captions[index]
84 | caption_tokens = self.tokenizer.basic_tokenizer.tokenize(caption)
85 |
86 | # Convert caption (string) to word ids (with Size Augmentation at training time).
87 | target = process_caption(self.tokenizer, caption_tokens, self.train)
88 |
89 | image_id = self.images[img_index]
90 | image_path = os.path.join(self.image_base, self.id_to_path[str(image_id)])
91 | im_in = np.array(imread(image_path))
92 | processed_image = self._process_image(im_in)
93 | image = torch.Tensor(processed_image)
94 | image = image.permute(2, 0, 1)
95 | return image, target, index, img_index
96 |
97 | def __len__(self):
98 | return self.length
99 |
100 | def _process_image(self, im_in):
101 | """
102 | Converts an image into a network input, with pre-processing including re-scaling, padding, etc, and data
103 | augmentation.
104 | """
105 | if len(im_in.shape) == 2:
106 | im_in = im_in[:, :, np.newaxis]
107 | im_in = np.concatenate((im_in, im_in, im_in), axis=2)
108 |
109 | if 'detector' in self.backbone_source:
110 | im_in = im_in[:, :, ::-1]
111 | im = im_in.astype(np.float32, copy=True)
112 |
113 | if self.train:
114 | target_size = self.base_target_size * self.train_scale_rate
115 | else:
116 | target_size = self.base_target_size
117 |
118 | # 2. Random crop when in training mode, otherwise just skip
119 | if self.train:
120 | crop_ratio = np.random.random() * 0.4 + 0.6
121 | crop_size_h = int(im.shape[0] * crop_ratio)
122 | crop_size_w = int(im.shape[1] * crop_ratio)
123 | processed_im = self._crop(im, crop_size_h, crop_size_w, random=True)
124 | else:
125 | processed_im = im
126 |
127 | # 3. Resize to the target resolution
128 | im_shape = processed_im.shape
129 | im_scale_x = float(target_size) / im_shape[1]
130 | im_scale_y = float(target_size) / im_shape[0]
131 | processed_im = cv2.resize(processed_im, None, None, fx=im_scale_x, fy=im_scale_y,
132 | interpolation=cv2.INTER_LINEAR)
133 |
134 | if self.train:
135 | if np.random.random() > 0.5:
136 | processed_im = self._hori_flip(processed_im)
137 |
138 | # Normalization
139 | if 'detector' in self.backbone_source:
140 | processed_im = self._detector_norm(processed_im)
141 | else:
142 | processed_im = self._imagenet_norm(processed_im)
143 |
144 | return processed_im
145 |
146 | def _imagenet_norm(self, im_in):
147 | im_in = im_in.astype(np.float32)
148 | im_in = im_in / 255
149 | for i in range(im_in.shape[-1]):
150 | im_in[:, :, i] = (im_in[:, :, i] - self.imagenet_mean[i]) / self.imagenet_std[i]
151 | return im_in
152 |
153 | def _detector_norm(self, im_in):
154 | im_in = im_in.astype(np.float32)
155 | im_in -= self.pixel_means
156 | return im_in
157 |
158 | @staticmethod
159 | def _crop(im, crop_size_h, crop_size_w, random):
160 | h, w = im.shape[0], im.shape[1]
161 | if random:
162 | if w - crop_size_w == 0:
163 | x_start = 0
164 | else:
165 | x_start = np.random.randint(w - crop_size_w, size=1)[0]
166 | if h - crop_size_h == 0:
167 | y_start = 0
168 | else:
169 | y_start = np.random.randint(h - crop_size_h, size=1)[0]
170 | else:
171 | x_start = (w - crop_size_w) // 2
172 | y_start = (h - crop_size_h) // 2
173 |
174 | cropped_im = im[y_start:y_start + crop_size_h, x_start:x_start + crop_size_w, :]
175 |
176 | return cropped_im
177 |
178 | @staticmethod
179 | def _hori_flip(im):
180 | im = np.fliplr(im).copy()
181 | return im
182 |
183 |
184 | class PrecompRegionDataset(data.Dataset):
185 | """
186 | Load precomputed captions and image features for COCO or Flickr
187 | """
188 |
189 | def __init__(self, data_path, data_name, data_split, tokenizer, opt, train):
190 | self.tokenizer = tokenizer
191 | self.opt = opt
192 | self.train = train
193 | self.data_path = data_path
194 | self.data_name = data_name
195 |
196 | loc_cap = osp.join(data_path, data_name)
197 | loc_image = osp.join(data_path, data_name)
198 |
199 | # Captions
200 | self.captions = []
201 | with open(osp.join(loc_cap, '%s_caps.txt' % data_split), 'r') as f:
202 | for line in f:
203 | self.captions.append(line.strip())
204 | # Image features
205 | self.images = np.load(os.path.join(loc_image, '%s_ims.npy' % data_split))
206 |
207 | self.length = len(self.captions)
208 | # rkiros data has redundancy in images, we divide by 5, 10crop doesn't
209 | num_images = len(self.images)
210 |
211 | if num_images != self.length:
212 | self.im_div = 5
213 | else:
214 | self.im_div = 1
215 | # the development set for coco is large and so validation would be slow
216 | if data_split == 'dev':
217 | self.length = 5000
218 |
219 | def __getitem__(self, index):
220 | # handle the image redundancy
221 | img_index = index // self.im_div
222 | caption = self.captions[index]
223 | caption_tokens = self.tokenizer.basic_tokenizer.tokenize(caption)
224 |
225 | # Convert caption (string) to word ids (with Size Augmentation at training time)
226 | target = process_caption(self.tokenizer, caption_tokens, self.train)
227 | image = self.images[img_index]
228 | if self.train: # Size augmentation for region feature
229 | num_features = image.shape[0]
230 | rand_list = np.random.rand(num_features)
231 | image = image[np.where(rand_list > 0.20)]
232 | image = torch.Tensor(image)
233 | return image, target, index, img_index
234 |
235 | def __len__(self):
236 | return self.length
237 |
238 |
239 | def process_caption(tokenizer, tokens, train=True):
240 | output_tokens = []
241 | deleted_idx = []
242 |
243 | for i, token in enumerate(tokens):
244 | sub_tokens = tokenizer.wordpiece_tokenizer.tokenize(token)
245 | prob = random.random()
246 |
247 | if prob < 0.20 and train: # mask/remove the tokens only during training
248 | prob /= 0.20
249 |
250 | # 50% randomly change token to mask token
251 | if prob < 0.5:
252 | for sub_token in sub_tokens:
253 | output_tokens.append("[MASK]")
254 | # 10% randomly change token to random token
255 | elif prob < 0.6:
256 | for sub_token in sub_tokens:
257 | output_tokens.append(random.choice(list(tokenizer.vocab.keys())))
258 | # -> the remaining sampled tokens are dropped (recorded in deleted_idx and removed below)
259 | else:
260 | for sub_token in sub_tokens:
261 | output_tokens.append(sub_token)
262 | deleted_idx.append(len(output_tokens) - 1)
263 | else:
264 | for sub_token in sub_tokens:
265 | # no masking token (will be ignored by loss function later)
266 | output_tokens.append(sub_token)
267 |
268 | if len(deleted_idx) != 0:
269 | output_tokens = [output_tokens[i] for i in range(len(output_tokens)) if i not in deleted_idx]
270 |
271 | output_tokens = ['[CLS]'] + output_tokens + ['[SEP]']
272 | target = tokenizer.convert_tokens_to_ids(output_tokens)
273 | target = torch.Tensor(target)
274 | return target
275 |
276 |
277 | def collate_fn(data):
278 | """Build mini-batch tensors from a list of (image, caption) tuples.
279 | Args:
280 | data: list of (image, caption) tuple.
281 | - image: torch tensor of shape (3, 256, 256).
282 | - caption: torch tensor of shape (?); variable length.
283 |
284 | Returns:
285 | images: torch tensor of shape (batch_size, 3, 256, 256).
286 | targets: torch tensor of shape (batch_size, padded_length).
287 | lengths: list; valid length for each padded caption.
288 | """
289 | images, captions, ids, img_ids = zip(*data)
290 | if len(images[0].shape) == 2: # region feature
291 | # Sort a data list by caption length
292 | # Merge images (convert tuple of 3D tensor to 4D tensor)
293 | # images = torch.stack(images, 0)
294 | img_lengths = [len(image) for image in images]
295 | all_images = torch.zeros(len(images), max(img_lengths), images[0].size(-1))
296 | for i, image in enumerate(images):
297 | end = img_lengths[i]
298 | all_images[i, :end] = image[:end]
299 | img_lengths = torch.Tensor(img_lengths)
300 |
301 | # Merge captions (convert tuple of 1D tensor to 2D tensor)
302 | lengths = [len(cap) for cap in captions]
303 | targets = torch.zeros(len(captions), max(lengths)).long()
304 |
305 | for i, cap in enumerate(captions):
306 | end = lengths[i]
307 | targets[i, :end] = cap[:end]
308 |
309 | return all_images, img_lengths, targets, lengths, ids
310 | else: # raw input image
311 | # Merge images (convert tuple of 3D tensor to 4D tensor)
312 | images = torch.stack(images, 0)
313 |
314 | # Merge captions (convert tuple of 1D tensor to 2D tensor)
315 | lengths = [len(cap) for cap in captions]
316 | targets = torch.zeros(len(captions), max(lengths)).long()
317 | for i, cap in enumerate(captions):
318 | end = lengths[i]
319 | targets[i, :end] = cap[:end]
320 | return images, targets, lengths, ids
321 |
322 |
323 | def get_loader(data_path, data_name, data_split, tokenizer, opt, batch_size=100,
324 | shuffle=True, num_workers=2, train=True):
325 | """Returns torch.utils.data.DataLoader for custom coco dataset."""
326 | if train:
327 | drop_last = True
328 | else:
329 | drop_last = False
330 | if opt.precomp_enc_type == 'basic':
331 | dset = PrecompRegionDataset(data_path, data_name, data_split, tokenizer, opt, train)
332 | data_loader = torch.utils.data.DataLoader(dataset=dset,
333 | batch_size=batch_size,
334 | shuffle=shuffle,
335 | pin_memory=True,
336 | collate_fn=collate_fn,
337 | num_workers=num_workers,
338 | drop_last=drop_last)
339 | else:
340 | dset = RawImageDataset(data_path, data_name, data_split, tokenizer, opt, train)
341 | data_loader = torch.utils.data.DataLoader(dataset=dset,
342 | batch_size=batch_size,
343 | shuffle=shuffle,
344 | num_workers=num_workers,
345 | pin_memory=True,
346 | collate_fn=collate_fn)
347 | return data_loader
348 |
349 |
350 | def get_loaders(data_path, data_name, tokenizer, batch_size, workers, opt):
351 | train_loader = get_loader(data_path, data_name, 'train', tokenizer, opt,
352 | batch_size, True, workers)
353 | val_loader = get_loader(data_path, data_name, 'test', tokenizer, opt,
354 | batch_size, False, workers, train=False)
355 | return train_loader, val_loader
356 |
357 |
358 | def get_train_loader(data_path, data_name, tokenizer, batch_size, workers, opt, shuffle):
359 | train_loader = get_loader(data_path, data_name, 'train', tokenizer, opt,
360 | batch_size, shuffle, workers)
361 | return train_loader
362 |
363 |
364 | def get_test_loader(split_name, data_name, tokenizer, batch_size, workers, opt):
365 |
366 | # opt.data_path = '/mnt/data10t/bakuphome20210617/lz/data/I-T/Flickr30K/'
367 | # data_name = 'f30k_precomp'
368 | # split_name = 'test'
369 | test_loader = get_loader(opt.data_path, data_name, split_name, tokenizer, opt,
370 | batch_size, False, workers, train=False)
371 | return test_loader
372 |
--------------------------------------------------------------------------------
/lib/encoders.py:
--------------------------------------------------------------------------------
1 | """VSE modules"""
2 |
3 | import torch
4 | import torch.nn as nn
5 | import numpy as np
6 | # from collections import OrderedDict
7 |
8 | from transformers import BertModel
9 |
10 | from lib.modules.resnet import ResnetFeatureExtractor
11 | # from lib.modules.aggr.gpo import GPO
12 | from lib.modules.mlp import MLP
13 |
14 | import logging
15 |
16 | logger = logging.getLogger(__name__)
17 |
18 |
19 | def l1norm(X, dim, eps=1e-8):
20 | """L1-normalize columns of X
21 | """
22 | norm = torch.abs(X).sum(dim=dim, keepdim=True) + eps
23 | X = torch.div(X, norm)
24 | return X
25 |
26 |
27 | def l2norm(X, dim, eps=1e-8):
28 | """L2-normalize columns of X
29 | """
30 | norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps
31 | X = torch.div(X, norm)
32 | return X
33 |
34 |
35 | def maxk_pool1d_var(x, dim, k, lengths):
36 | results = list()
37 | lengths = list(lengths.cpu().numpy())
38 | lengths = [int(x) for x in lengths]
39 | for idx, length in enumerate(lengths):
40 | k = min(k, length)
41 | max_k_i = maxk(x[idx, :length, :], dim - 1, k).mean(dim - 1)
42 | results.append(max_k_i)
43 | results = torch.stack(results, dim=0)
44 | return results
45 |
46 |
47 | def maxk_pool1d(x, dim, k):
48 | max_k = maxk(x, dim, k)
49 | return max_k.mean(dim)
50 |
51 |
52 | def maxk(x, dim, k):
53 | index = x.topk(k, dim=dim)[1]
54 | return x.gather(dim, index)
55 |
56 |
57 | def get_text_encoder(embed_size, no_txtnorm=False):
58 | return EncoderText(embed_size, no_txtnorm=no_txtnorm)
59 |
60 |
61 | def get_image_encoder(data_name, img_dim, embed_size, precomp_enc_type='basic',
62 | backbone_source=None, backbone_path=None, no_imgnorm=False):
63 | """A wrapper to image encoders. Chooses between an different encoders
64 | that uses precomputed image features.
65 | """
66 | if precomp_enc_type == 'basic':
67 | img_enc = EncoderImageAggr(
68 | img_dim, embed_size, precomp_enc_type, no_imgnorm)
69 | elif precomp_enc_type == 'backbone':
70 | backbone_cnn = ResnetFeatureExtractor(backbone_source, backbone_path, fixed_blocks=2)
71 | img_enc = EncoderImageFull(backbone_cnn, img_dim, embed_size, precomp_enc_type, no_imgnorm)
72 | else:
73 | raise ValueError("Unknown precomp_enc_type: {}".format(precomp_enc_type))
74 |
75 | return img_enc
76 |
77 |
78 | class EncoderImageAggr(nn.Module):
79 | def __init__(self, img_dim, embed_size, precomp_enc_type='basic', no_imgnorm=False):
80 | super(EncoderImageAggr, self).__init__()
81 | self.embed_size = embed_size
82 | self.no_imgnorm = no_imgnorm
83 | self.fc = nn.Linear(img_dim, embed_size)
84 | self.precomp_enc_type = precomp_enc_type
85 | if precomp_enc_type == 'basic':
86 | self.mlp = MLP(img_dim, embed_size // 2, embed_size, 2)
87 | # self.gpool = GPO(32, 32)
88 | self.init_weights()
89 |
90 | def init_weights(self):
91 | """Xavier initialization for the fully connected layer
92 | """
93 | r = np.sqrt(6.) / np.sqrt(self.fc.in_features +
94 | self.fc.out_features)
95 | self.fc.weight.data.uniform_(-r, r)
96 | self.fc.bias.data.fill_(0)
97 |
98 | def forward(self, images, image_lengths):
99 | """Extract image feature vectors."""
100 | features = self.fc(images)
101 | if self.precomp_enc_type == 'basic':
102 | # When using pre-extracted region features, add an extra MLP for the embedding transformation
103 | features = self.mlp(images) + features
104 |
105 | # features, pool_weights = self.gpool(features, image_lengths)
106 |
107 | if not self.no_imgnorm:
108 | features = l2norm(features, dim=-1)
109 |
110 | return features
111 |
112 |
113 | class EncoderImageFull(nn.Module):
114 | def __init__(self, backbone_cnn, img_dim, embed_size, precomp_enc_type='basic', no_imgnorm=False):
115 | super(EncoderImageFull, self).__init__()
116 | self.backbone = backbone_cnn
117 | self.image_encoder = EncoderImageAggr(img_dim, embed_size, precomp_enc_type, no_imgnorm)
118 | self.backbone_freezed = False
119 |
120 | def forward(self, images):
121 | """Extract image feature vectors."""
122 | base_features = self.backbone(images)
123 |
124 | if self.training:
125 | # Size Augmentation during training, randomly drop grids
126 | base_length = base_features.size(1)
127 | features = []
128 | feat_lengths = []
129 | rand_list_1 = np.random.rand(base_features.size(0), base_features.size(1))
130 | rand_list_2 = np.random.rand(base_features.size(0))
131 | for i in range(base_features.size(0)):
132 | if rand_list_2[i] > 0.2:
133 | feat_i = base_features[i][np.where(rand_list_1[i] > 0.20 * rand_list_2[i])]
134 | len_i = len(feat_i)
135 | pads_i = torch.zeros(base_length - len_i, base_features.size(-1)).to(base_features.device)
136 | feat_i = torch.cat([feat_i, pads_i], dim=0)
137 | else:
138 | feat_i = base_features[i]
139 | len_i = base_length
140 | feat_lengths.append(len_i)
141 | features.append(feat_i)
142 | base_features = torch.stack(features, dim=0)
143 | base_features = base_features[:, :max(feat_lengths), :]
144 | feat_lengths = torch.tensor(feat_lengths).to(base_features.device)
145 | else:
146 | feat_lengths = torch.zeros(base_features.size(0)).to(base_features.device)
147 | feat_lengths[:] = base_features.size(1)
148 |
149 | features = self.image_encoder(base_features, feat_lengths)
150 |
151 | return features
152 |
153 | def freeze_backbone(self):
154 | for param in self.backbone.parameters():
155 | param.requires_grad = False
156 | logger.info('Backbone frozen.')
157 |
158 | def unfreeze_backbone(self, fixed_blocks):
159 | for param in self.backbone.parameters(): # open up all params first, then adjust the base parameters
160 | param.requires_grad = True
161 | self.backbone.set_fixed_blocks(fixed_blocks)
162 | self.backbone.unfreeze_base()
163 | logger.info('Backbone unfrozen, fixed blocks {}'.format(self.backbone.get_fixed_blocks()))
164 |
165 |
166 | # Language Model with BERT
167 | class EncoderText(nn.Module):
168 | def __init__(self, embed_size, no_txtnorm=False):
169 | super(EncoderText, self).__init__()
170 | self.embed_size = embed_size
171 | self.no_txtnorm = no_txtnorm
172 |
173 | self.bert = BertModel.from_pretrained('bert-base-uncased')
174 | self.linear = nn.Linear(768, embed_size)
175 | # self.gpool = GPO(32, 32)
176 | # self.dropout = nn.Dropout(0.4)
177 |
178 | def forward(self, x, lengths):
179 | """Handles variable size captions
180 | """
181 | # Embed word ids to vectors
182 | bert_attention_mask = (x != 0).float()
183 | bert_emb = self.bert(x, bert_attention_mask)[0] # B x N x D
184 |
185 | cap_len = lengths
186 | # bert_emb = self.dropout(bert_emb)
187 |
188 | cap_emb = self.linear(bert_emb)
189 |
190 | # pooled_features, pool_weights = self.gpool(cap_emb, cap_len.to(cap_emb.device))
191 |
192 | # normalization in the joint embedding space
193 | if not self.no_txtnorm:
194 | cap_emb = l2norm(cap_emb, dim=-1)
195 |
196 | return cap_emb
197 |
--------------------------------------------------------------------------------
/lib/evaluation.py:
--------------------------------------------------------------------------------
1 | """Evaluation"""
2 | from __future__ import print_function
3 | import logging
4 | import time
5 | import torch
6 | import numpy as np
7 |
8 | import sys
9 |
10 | from collections import OrderedDict
11 | from transformers import BertTokenizer
12 |
13 | from lib.datasets import image_caption
14 | from lib.vse import VSEModel
15 |
16 | logger = logging.getLogger(__name__)
17 |
18 |
19 | class AverageMeter(object):
20 | """Computes and stores the average and current value"""
21 |
22 | def __init__(self):
23 | self.reset()
24 |
25 | def reset(self):
26 | self.val = 0
27 | self.avg = 0
28 | self.sum = 0
29 | self.count = 0
30 |
31 | def update(self, val, n=0):
32 | self.val = val
33 | self.sum += val * n
34 | self.count += n
35 | self.avg = self.sum / (.0001 + self.count)
36 |
37 | def __str__(self):
38 | """String representation for logging
39 | """
40 | # for values that should be recorded exactly e.g. iteration number
41 | if self.count == 0:
42 | return str(self.val)
43 | # for stats
44 | return '%.4f (%.4f)' % (self.val, self.avg)
45 |
46 |
47 | class LogCollector(object):
48 | """A collection of logging objects that can change from train to val"""
49 |
50 | def __init__(self):
51 | # to keep the order of logged variables deterministic
52 | self.meters = OrderedDict()
53 |
54 | def update(self, k, v, n=0):
55 | # create a new meter if previously not recorded
56 | if k not in self.meters:
57 | self.meters[k] = AverageMeter()
58 | self.meters[k].update(v, n)
59 |
60 | def __str__(self):
61 | """Concatenate the meters in one log line
62 | """
63 | s = ''
64 | for i, (k, v) in enumerate(self.meters.items()):
65 | if i > 0:
66 | s += ' '
67 | s += k + ' ' + str(v)
68 | return s
69 |
70 | def tb_log(self, tb_logger, prefix='', step=None):
71 | """Log using tensorboard
72 | """
73 | for k, v in self.meters.items():
74 | tb_logger.log_value(prefix + k, v.val, step=step)
75 |
76 |
77 | def encode_data(model, data_loader, log_step=10, logging=logger.info, backbone=False):
78 | """Encode all images and captions loadable by `data_loader`
79 | """
80 | val_logger = LogCollector()
81 |
82 | # switch to evaluate mode
83 | model.val_start()
84 |
85 | # np array to keep all the embeddings
86 | img_embs = None
87 | cap_embs = None
88 | cap_lens = None
89 |
90 | max_n_word = 0
91 | for i, data_i in enumerate(data_loader):
92 | max_n_word = max(max_n_word, max(data_i[-2]))  # caption lengths are second-to-last in both loader formats
93 |
94 | for i, data_i in enumerate(data_loader):
95 | # make sure val logger is used
96 | if not backbone:
97 | images, image_lengths, captions, lengths, ids = data_i
98 | else:
99 | images, captions, lengths, ids = data_i
100 | model.logger = val_logger
101 |
102 | # compute the embeddings
103 | if not backbone:
104 | img_emb, cap_emb, cap_len = model.forward_emb(images, captions, lengths, image_lengths=image_lengths)
105 | else:
106 | img_emb, cap_emb, cap_len = model.forward_emb(images, captions, lengths)
107 |
108 | if img_embs is None:
109 | if img_emb.dim() == 3:
110 | img_embs = np.zeros((len(data_loader.dataset), img_emb.size(1), img_emb.size(2)))
111 | else:
112 | img_embs = np.zeros((len(data_loader.dataset), img_emb.size(1)))
113 | # cap_embs = np.zeros((len(data_loader.dataset), cap_emb.size(1)))
114 | cap_embs = np.zeros((len(data_loader.dataset), max_n_word, cap_emb.size(2)))
115 | cap_lens = [0] * len(data_loader.dataset)
116 |
117 | # cache embeddings
118 | img_embs[ids] = img_emb.data.cpu().numpy().copy()
119 | # cap_embs[ids, :] = cap_emb.data.cpu().numpy().copy()
120 | cap_embs[ids, :max(lengths), :] = cap_emb.data.cpu().numpy().copy()
121 |
122 | for j, nid in enumerate(ids):
123 | cap_lens[nid] = cap_len[j]
124 |
125 | del images, captions
126 | return img_embs, cap_embs, cap_lens
127 |
128 |
129 | def eval_ensemble(results_paths, fold5=False):
130 | all_sims = []
131 | all_npts = []
132 | for sim_path in results_paths:
133 | results = np.load(sim_path, allow_pickle=True).tolist()
134 | npts = results['npts']
135 | sims = results['sims']
136 | all_npts.append(npts)
137 | all_sims.append(sims)
138 | all_npts = np.array(all_npts)
139 | all_sims = np.array(all_sims)
140 | assert np.all(all_npts == all_npts[0])
141 | npts = int(all_npts[0])
142 | sims = all_sims.mean(axis=0)
143 |
144 | if not fold5:
145 | r, rt = i2t(npts, sims, return_ranks=True)
146 | ri, rti = t2i(npts, sims, return_ranks=True)
147 | ar = (r[0] + r[1] + r[2]) / 3
148 | ari = (ri[0] + ri[1] + ri[2]) / 3
149 | rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
150 | logger.info("rsum: %.1f" % rsum)
151 | logger.info("Average i2t Recall: %.1f" % ar)
152 | logger.info("Image to text: %.1f %.1f %.1f %.1f %.1f" % r)
153 | logger.info("Average t2i Recall: %.1f" % ari)
154 | logger.info("Text to image: %.1f %.1f %.1f %.1f %.1f" % ri)
155 | else:
156 | npts = npts // 5
157 | results = []
158 | all_sims = sims.copy()
159 | for i in range(5):
160 | sims = all_sims[i * npts: (i + 1) * npts, i * npts * 5: (i + 1) * npts * 5]
161 | r, rt0 = i2t(npts, sims, return_ranks=True)
162 | logger.info("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % r)
163 | ri, rti0 = t2i(npts, sims, return_ranks=True)
164 | logger.info("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % ri)
165 |
166 | if i == 0:
167 | rt, rti = rt0, rti0
168 | ar = (r[0] + r[1] + r[2]) / 3
169 | ari = (ri[0] + ri[1] + ri[2]) / 3
170 | rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
171 | logger.info("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari))
172 | results += [list(r) + list(ri) + [ar, ari, rsum]]
173 | logger.info("-----------------------------------")
174 | logger.info("Mean metrics: ")
175 | mean_metrics = tuple(np.array(results).mean(axis=0).flatten())
176 | logger.info("rsum: %.1f" % (mean_metrics[12]))
177 | logger.info("Average i2t Recall: %.1f" % mean_metrics[10])
178 | logger.info("Image to text: %.1f %.1f %.1f %.1f %.1f" %
179 | mean_metrics[:5])
180 | logger.info("Average t2i Recall: %.1f" % mean_metrics[11])
181 | logger.info("Text to image: %.1f %.1f %.1f %.1f %.1f" %
182 | mean_metrics[5:10])
183 |
184 |
185 | def evalrank(model_path, data_path=None, split='dev', fold5=False, save_path=None, cxc=False):
186 | """
187 | Evaluate a trained model on either dev or test. If `fold5=True`, 5 fold
188 | cross-validation is done (only for MSCOCO). Otherwise, the full data is
189 | used for evaluation.
190 | """
191 | # load model and options
192 | checkpoint = torch.load(model_path)
193 | opt = checkpoint['opt']
194 | opt.workers = 5
195 |
196 | logger.info(opt)
197 | if not hasattr(opt, 'caption_loss'):
198 | opt.caption_loss = False
199 |
200 | # load vocabulary used by the model
201 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
202 | vocab = tokenizer.vocab
203 | opt.vocab_size = len(vocab)
204 |
205 | opt.backbone_path = '/tmp/data/weights/original_updown_backbone.pth'
206 | if data_path is not None:
207 | opt.data_path = data_path
208 |
209 | # construct model
210 | model = VSEModel(opt)
211 |
212 | model.make_data_parallel()
213 | # load model state
214 | model.load_state_dict(checkpoint['model'])
215 | model.val_start()
216 |
217 | logger.info('Loading dataset')
218 | data_loader = image_caption.get_test_loader(split, opt.data_name, tokenizer,
219 | opt.batch_size, opt.workers, opt)
220 |
221 | logger.info('Computing results...')
222 | with torch.no_grad():
223 | if opt.precomp_enc_type == 'basic':
224 | # img_embs, cap_embs = encode_data(model, data_loader)
225 | img_embs, cap_embs, cap_lens = encode_data(model, data_loader)
226 | else:
227 | # img_embs, cap_embs = encode_data(model, data_loader, backbone=True)
228 | img_embs, cap_embs, cap_lens = encode_data(model, data_loader, backbone=True)
229 |
230 | logger.info('Images: %d, Captions: %d' %
231 | (img_embs.shape[0] / 5, cap_embs.shape[0]))
232 |
233 | if cxc:
234 | eval_cxc(img_embs, cap_embs, data_path)
235 | else:
236 | if not fold5:
237 | # no cross-validation, full evaluation
238 | img_embs = np.array([img_embs[i] for i in range(0, len(img_embs), 5)])
239 |
240 | # sims = compute_sim(img_embs, cap_embs)
241 | start = time.time()
242 | sims = shard_attn_scores(model, img_embs, cap_embs, cap_lens, opt, shard_size=1000)
243 | npts = img_embs.shape[0]
244 | end = time.time()
245 | print("calculate similarity time:", end - start)
246 |
247 | np.savetxt('/mnt/data2/zk/ESL_bert/' + 'sim_best.txt', sims, fmt='%.5f')
248 |
249 |
250 | if save_path is not None:
251 | np.save(save_path, {'npts': npts, 'sims': sims})
252 | logger.info('Save the similarity into {}'.format(save_path))
253 |
254 | end = time.time()
255 | logger.info("calculate similarity time: {}".format(end - start))
256 | end = time.time()
257 | print("calculate similarity time:", end - start)
258 |
259 | r, rt = i2t(npts, sims, return_ranks=True)
260 | ri, rti = t2i(npts, sims, return_ranks=True)
261 | ar = (r[0] + r[1] + r[2]) / 3
262 | ari = (ri[0] + ri[1] + ri[2]) / 3
263 | rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
264 |
265 | print("rsum: %.1f" % rsum)
266 | print("Average i2t Recall: %.1f" % ar)
267 | print("Image to text: %.1f %.1f %.1f %.1f %.1f" % r)
268 | print("Average t2i Recall: %.1f" % ari)
269 | print("Text to image: %.1f %.1f %.1f %.1f %.1f" % ri)
270 | else:
271 | # 5fold cross-validation, only for MSCOCO
272 | results = []
273 | for i in range(5):
274 | img_embs_shard = img_embs[i * 5000:(i + 1) * 5000:5]
275 |
276 | cap_embs_shard = cap_embs[i * 5000:(i + 1) * 5000]
277 | cap_lens_shard = cap_lens[i * 5000:(i + 1) * 5000]
278 | start = time.time()
279 | # sims = compute_sim(img_embs_shard, cap_embs_shard)
280 |
281 | sims = shard_attn_scores(model, img_embs_shard, cap_embs_shard, cap_lens_shard, opt, shard_size=1000)
282 | end = time.time()
283 | # logger.info("calculate similarity time: {}".format(end - start))
284 |
285 | print("calculate similarity time:", end - start)
286 |
287 | np.savetxt('/mnt/data2/zk/ESL_bert/' + str(i) + 'sim_best.txt', sims, fmt='%.5f')
288 |
289 | npts = img_embs_shard.shape[0]
290 | r, rt0 = i2t(npts, sims, return_ranks=True)
291 | print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % r)
292 | ri, rti0 = t2i(npts, sims, return_ranks=True)
293 | print("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % ri)
294 |
295 | if i == 0:
296 | rt, rti = rt0, rti0
297 | ar = (r[0] + r[1] + r[2]) / 3
298 | ari = (ri[0] + ri[1] + ri[2]) / 3
299 | rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
300 | print("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari))
301 | results += [list(r) + list(ri) + [ar, ari, rsum]]
302 |
303 | print("-----------------------------------")
304 | print("Mean metrics: ")
305 | mean_metrics = tuple(np.array(results).mean(axis=0).flatten())
306 | print("rsum: %.1f" % (mean_metrics[12]))
307 | print("Average i2t Recall: %.1f" % mean_metrics[11])
308 | print("Image to text: %.1f %.1f %.1f %.1f %.1f" %
309 | mean_metrics[:5])
310 | print("Average t2i Recall: %.1f" % mean_metrics[12])
311 | print("Text to image: %.1f %.1f %.1f %.1f %.1f" %
312 | mean_metrics[5:10])
313 |
314 |
315 | def compute_sim(images, captions):
316 | similarities = np.matmul(images, np.matrix.transpose(captions))
317 | return similarities
318 |
319 |
320 | def shard_attn_scores(model, img_embs, cap_embs, cap_lens, opt, shard_size=100):
321 | n_im_shard = (len(img_embs) - 1) // shard_size + 1
322 | n_cap_shard = (len(cap_embs) - 1) // shard_size + 1
323 |
324 | sims = np.zeros((len(img_embs), len(cap_embs)))
325 |
326 | for i in range(n_im_shard):
327 | im_start, im_end = shard_size * i, min(shard_size * (i + 1), len(img_embs))
328 | for j in range(n_cap_shard):
329 | sys.stdout.write('\r>> shard_attn_scores batch (%d,%d)' % (i, j))
330 | ca_start, ca_end = shard_size * j, min(shard_size * (j + 1), len(cap_embs))
331 |
332 | with torch.no_grad():
333 | im = torch.from_numpy(img_embs[im_start:im_end]).float().cuda()
334 | ca = torch.from_numpy(cap_embs[ca_start:ca_end]).float().cuda()
335 | l = cap_lens[ca_start:ca_end]
336 | sim, A = model.forward_sim_test(im, ca, l)
337 |
338 | sims[im_start:im_end, ca_start:ca_end] = sim.data.cpu().numpy()
339 | sys.stdout.write('\n')
340 | return sims
341 |
342 |
343 | def i2t(npts, sims, return_ranks=False, mode='coco'):
344 | """
345 | Images->Text (Image Annotation)
346 | Images: (N, n_region, d) matrix of images
347 | Captions: (5N, max_n_word, d) matrix of captions
348 | CapLens: (5N) array of caption lengths
349 | sims: (N, 5N) matrix of similarity im-cap
350 | """
351 | ranks = np.zeros(npts)
352 | top1 = np.zeros(npts)
353 | for index in range(npts):
354 | inds = np.argsort(sims[index])[::-1]
355 | if mode == 'coco':
356 | rank = 1e20
357 | for i in range(5 * index, 5 * index + 5, 1):
358 | tmp = np.where(inds == i)[0][0]
359 | if tmp < rank:
360 | rank = tmp
361 | ranks[index] = rank
362 | top1[index] = inds[0]
363 | else:
364 | rank = np.where(inds == index)[0][0]
365 | ranks[index] = rank
366 | top1[index] = inds[0]
367 |
368 | # Compute metrics
369 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
370 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
371 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
372 | medr = np.floor(np.median(ranks)) + 1
373 | meanr = ranks.mean() + 1
374 |
375 | if return_ranks:
376 | return (r1, r5, r10, medr, meanr), (ranks, top1)
377 | else:
378 | return (r1, r5, r10, medr, meanr)
379 |
380 |
381 | def t2i(npts, sims, return_ranks=False, mode='coco'):
382 | """
383 | Text->Images (Image Search)
384 | Images: (N, n_region, d) matrix of images
385 | Captions: (5N, max_n_word, d) matrix of captions
386 | CapLens: (5N) array of caption lengths
387 | sims: (N, 5N) matrix of similarity im-cap
388 | """
389 | # npts = images.shape[0]
390 |
391 | if mode == 'coco':
392 | ranks = np.zeros(5 * npts)
393 | top1 = np.zeros(5 * npts)
394 | else:
395 | ranks = np.zeros(npts)
396 | top1 = np.zeros(npts)
397 |
398 | # --> (5N(caption), N(image))
399 | sims = sims.T
400 |
401 | for index in range(npts):
402 | if mode == 'coco':
403 | for i in range(5):
404 | inds = np.argsort(sims[5 * index + i])[::-1]
405 | ranks[5 * index + i] = np.where(inds == index)[0][0]
406 | top1[5 * index + i] = inds[0]
407 | else:
408 | inds = np.argsort(sims[index])[::-1]
409 | ranks[index] = np.where(inds == index)[0][0]
410 | top1[index] = inds[0]
411 |
412 | # Compute metrics
413 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
414 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
415 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
416 | medr = np.floor(np.median(ranks)) + 1
417 | meanr = ranks.mean() + 1
418 | if return_ranks:
419 | return (r1, r5, r10, medr, meanr), (ranks, top1)
420 | else:
421 | return (r1, r5, r10, medr, meanr)
422 |
423 |
424 | """
425 | CxC related evaluation.
426 | """
427 |
428 | def eval_cxc(images, captions, data_path):
429 | import os
430 | import json
431 | cxc_annot_base = os.path.join(data_path, 'cxc_annots')
432 | img_id_path = os.path.join(cxc_annot_base, 'testall_ids.txt')
433 | cap_id_path = os.path.join(cxc_annot_base, 'testall_capids.txt')
434 |
435 | images = images[::5, :]
436 |
437 | with open(img_id_path) as f:
438 | img_ids = f.readlines()
439 | with open(cap_id_path) as f:
440 | cap_ids = f.readlines()
441 |
442 | img_ids = [img_id.strip() for i, img_id in enumerate(img_ids) if i % 5 == 0]
443 | cap_ids = [cap_id.strip() for cap_id in cap_ids]
444 |
445 | with open(os.path.join(cxc_annot_base, 'cxc_it.json')) as f_it:
446 | cxc_it = json.load(f_it)
447 | with open(os.path.join(cxc_annot_base, 'cxc_i2i.json')) as f_i2i:
448 | cxc_i2i = json.load(f_i2i)
449 | with open(os.path.join(cxc_annot_base, 'cxc_t2t.json')) as f_t2t:
450 | cxc_t2t = json.load(f_t2t)
451 |
452 | sims = compute_sim(images, captions)
453 | t2i_recalls = cxc_inter(sims.T, img_ids, cap_ids, cxc_it['t2i'])
454 | i2t_recalls = cxc_inter(sims, cap_ids, img_ids, cxc_it['i2t'])
455 | logger.info('T2I R@1: {}, R@5: {}, R@10: {}'.format(*t2i_recalls))
456 | logger.info('I2T R@1: {}, R@5: {}, R@10: {}'.format(*i2t_recalls))
457 |
458 | i2i_recalls = cxc_intra(images, img_ids, cxc_i2i)
459 | t2t_recalls = cxc_intra(captions, cap_ids, cxc_t2t, text=True)
460 | logger.info('I2I R@1: {}, R@5: {}, R@10: {}'.format(*i2i_recalls))
461 | logger.info('T2T R@1: {}, R@5: {}, R@10: {}'.format(*t2t_recalls))
462 |
463 |
464 | def cxc_inter(sims, data_ids, query_ids, annot):
465 | ranks = list()
466 | for idx, query_id in enumerate(query_ids):
467 | if query_id not in annot:
468 | raise ValueError('unexpected query id {}'.format(query_id))
469 | pos_data_ids = annot[query_id]
470 | pos_data_ids = [pos_data_id for pos_data_id in pos_data_ids if str(pos_data_id[0]) in data_ids]
471 | pos_data_indices = [data_ids.index(str(pos_data_id[0])) for pos_data_id in pos_data_ids]
472 | rank = 1e20
473 | inds = np.argsort(sims[idx])[::-1]
474 | for pos_data_idx in pos_data_indices:
475 | tmp = np.where(inds == pos_data_idx)[0][0]
476 | if tmp < rank:
477 | rank = tmp
478 | ranks.append(rank)
479 | ranks = np.array(ranks)
480 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
481 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
482 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
483 | return (r1, r5, r10)
484 |
485 |
486 | def cxc_intra(embs, data_ids, annot, text=False):
487 | pos_thresh = 3.0 if text else 2.5 # threshold for positive pairs according to the CxC paper
488 |
489 | sims = compute_sim(embs, embs)
490 | np.fill_diagonal(sims, 0)
491 |
492 | ranks = list()
493 | for idx, data_id in enumerate(data_ids):
494 | sim_items = annot[data_id]
495 | pos_items = [item for item in sim_items if item[1] >= pos_thresh]
496 | rank = 1e20
497 | inds = np.argsort(sims[idx])[::-1]
498 | if text:
499 | coco_pos = list(range(idx // 5 * 5, (idx // 5 + 1) * 5))
500 | coco_pos.remove(idx)
501 | pos_indices = coco_pos
502 | pos_indices.extend([data_ids.index(str(pos_item[0])) for pos_item in pos_items])
503 | else:
504 | pos_indices = [data_ids.index(str(pos_item[0])) for pos_item in pos_items]
505 |         if len(pos_indices) == 0:  # skip this query since there is no positive example in the annotation
506 | continue
507 | for pos_idx in pos_indices:
508 | tmp = np.where(inds == pos_idx)[0][0]
509 | if tmp < rank:
510 | rank = tmp
511 | ranks.append(rank)
512 |
513 | ranks = np.array(ranks)
514 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
515 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
516 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
517 | return (r1, r5, r10)
518 |
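if __name__ == '__main__':
    # Hedged usage sketch (added for illustration, not part of the original file):
    # i2t/t2i only need the (N, 5N) similarity matrix and follow the COCO
    # convention that image i is paired with captions 5*i .. 5*i+4. The random
    # matrix below is a stand-in; real scores come from shard_attn_scores.
    rng = np.random.RandomState(0)
    npts = 4                                    # four images, twenty captions
    demo_sims = rng.rand(npts, 5 * npts)
    r_i2t = i2t(npts, demo_sims, mode='coco')   # (R@1, R@5, R@10, medr, meanr)
    r_t2i = t2i(npts, demo_sims, mode='coco')
    print('i2t:', r_i2t)
    print('t2i:', r_t2i)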
--------------------------------------------------------------------------------
/lib/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 |
5 |
6 | class ContrastiveLoss(nn.Module):
7 | """
8 | Compute contrastive loss (max-margin based)
9 | """
10 |
11 | def __init__(self, opt, margin=0, max_violation=False):
12 | super(ContrastiveLoss, self).__init__()
13 | self.opt = opt
14 | self.margin = margin
15 | self.max_violation = max_violation
16 |
17 | def max_violation_on(self):
18 | self.max_violation = True
19 | print('Use VSE++ objective.')
20 |
21 | def max_violation_off(self):
22 | self.max_violation = False
23 | print('Use VSE0 objective.')
24 |
25 | def forward(self, scores):
26 | # compute image-sentence score matrix
27 | # scores = get_sim(im, s)
28 | # diagonal = scores.diag().view(im.size(0), 1)
29 |
30 | diagonal = scores.diag().view(scores.size(0), 1)
31 | d1 = diagonal.expand_as(scores)
32 | d2 = diagonal.t().expand_as(scores)
33 |
34 | # compare every diagonal score to scores in its column
35 | # caption retrieval
36 | cost_s = (self.margin + scores - d1).clamp(min=0)
37 | # compare every diagonal score to scores in its row
38 | # image retrieval
39 | cost_im = (self.margin + scores - d2).clamp(min=0)
40 |
41 | # clear diagonals
42 | mask = torch.eye(scores.size(0)) > .5
43 | I = Variable(mask)
44 | if torch.cuda.is_available():
45 | I = I.cuda()
46 | cost_s = cost_s.masked_fill_(I, 0)
47 | cost_im = cost_im.masked_fill_(I, 0)
48 |
49 | # keep the maximum violating negative for each query
50 | if self.max_violation:
51 | cost_s = cost_s.max(1)[0]
52 | cost_im = cost_im.max(0)[0]
53 |
54 | return cost_s.sum() + cost_im.sum()
55 |
56 |
57 | def get_sim(images, captions):
58 | similarities = images.mm(captions.t())
59 | return similarities
60 |
61 |
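if __name__ == '__main__':
    # Hedged usage sketch (added for illustration, not part of the original file):
    # exercise the max-margin loss on a random square score matrix whose diagonal
    # holds the matched image-caption pairs. ContrastiveLoss only stores `opt`,
    # so an empty namespace is enough here, and margin=0.2 is an arbitrary demo value.
    from types import SimpleNamespace

    scores = torch.randn(8, 8)
    if torch.cuda.is_available():
        scores = scores.cuda()
    criterion = ContrastiveLoss(opt=SimpleNamespace(), margin=0.2, max_violation=True)
    print('contrastive loss:', criterion(scores).item())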
--------------------------------------------------------------------------------
/lib/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CrossmodalGroup/ESL/b91e287b238fcf6dd7dc784c04e7bf1ac53404ae/lib/modules/__init__.py
--------------------------------------------------------------------------------
/lib/modules/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CrossmodalGroup/ESL/b91e287b238fcf6dd7dc784c04e7bf1ac53404ae/lib/modules/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/modules/__pycache__/mlp.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CrossmodalGroup/ESL/b91e287b238fcf6dd7dc784c04e7bf1ac53404ae/lib/modules/__pycache__/mlp.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/modules/__pycache__/resnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CrossmodalGroup/ESL/b91e287b238fcf6dd7dc784c04e7bf1ac53404ae/lib/modules/__pycache__/resnet.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/modules/aggr/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CrossmodalGroup/ESL/b91e287b238fcf6dd7dc784c04e7bf1ac53404ae/lib/modules/aggr/__init__.py
--------------------------------------------------------------------------------
/lib/modules/aggr/gpo.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import torch
3 | import torch.nn as nn
4 | import math
5 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
6 |
7 |
8 | def positional_encoding_1d(d_model, length):
9 | """
10 | :param d_model: dimension of the model
11 | :param length: length of positions
12 | :return: length*d_model position matrix
13 | """
14 | if d_model % 2 != 0:
15 | raise ValueError("Cannot use sin/cos positional encoding with "
16 | "odd dim (got dim={:d})".format(d_model))
17 | pe = torch.zeros(length, d_model)
18 | position = torch.arange(0, length).unsqueeze(1)
19 | div_term = torch.exp((torch.arange(0, d_model, 2, dtype=torch.float) *
20 | -(math.log(10000.0) / d_model)))
21 | pe[:, 0::2] = torch.sin(position.float() * div_term)
22 | pe[:, 1::2] = torch.cos(position.float() * div_term)
23 |
24 | return pe
25 |
26 |
27 | class GPO(nn.Module):
28 | def __init__(self, d_pe, d_hidden):
29 | super(GPO, self).__init__()
30 | self.d_pe = d_pe
31 | self.d_hidden = d_hidden
32 |
33 | self.pe_database = {}
34 | self.gru = nn.GRU(self.d_pe, d_hidden, 1, batch_first=True, bidirectional=True)
35 | self.linear = nn.Linear(self.d_hidden, 1, bias=False)
36 |
37 | def compute_pool_weights(self, lengths, features):
38 | max_len = int(lengths.max())
39 | pe_max_len = self.get_pe(max_len)
40 | pes = pe_max_len.unsqueeze(0).repeat(lengths.size(0), 1, 1).to(lengths.device)
41 | mask = torch.arange(max_len).expand(lengths.size(0), max_len).to(lengths.device)
42 | mask = (mask < lengths.long().unsqueeze(1)).unsqueeze(-1)
43 | pes = pes.masked_fill(mask == 0, 0)
44 |
45 | self.gru.flatten_parameters()
46 | packed = pack_padded_sequence(pes, lengths.cpu(), batch_first=True, enforce_sorted=False)
47 | out, _ = self.gru(packed)
48 | padded = pad_packed_sequence(out, batch_first=True)
49 | out_emb, out_len = padded
50 | out_emb = (out_emb[:, :, :out_emb.size(2) // 2] + out_emb[:, :, out_emb.size(2) // 2:]) / 2
51 | scores = self.linear(out_emb)
52 | scores[torch.where(mask == 0)] = -10000
53 |
54 | weights = torch.softmax(scores / 0.1, 1)
55 | return weights, mask
56 |
57 | def forward(self, features, lengths):
58 | """
59 | :param features: features with shape B x K x D
60 |         :param lengths: (B,) tensor specifying the valid length of each data sample.
61 | :return: pooled feature with shape B x D
62 | """
63 | pool_weights, mask = self.compute_pool_weights(lengths, features)
64 |
65 | features = features[:, :int(lengths.max()), :]
66 | sorted_features = features.masked_fill(mask == 0, -10000)
67 | sorted_features = sorted_features.sort(dim=1, descending=True)[0]
68 | sorted_features = sorted_features.masked_fill(mask == 0, 0)
69 |
70 | pooled_features = (sorted_features * pool_weights).sum(1)
71 | return pooled_features, pool_weights
72 |
73 | def get_pe(self, length):
74 | """
75 |
76 | :param length: the length of the sequence
77 | :return: the positional encoding of the given length
78 | """
79 | length = int(length)
80 | if length in self.pe_database:
81 | return self.pe_database[length]
82 | else:
83 | pe = positional_encoding_1d(self.d_pe, length)
84 | self.pe_database[length] = pe
85 | return pe
86 |
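if __name__ == '__main__':
    # Hedged usage sketch (added for illustration, not part of the original file):
    # pool a batch of variable-length feature sequences. d_pe=32, d_hidden=32 and
    # the 256-d features below are arbitrary demo values; the real sizes come from
    # the encoders that own this module.
    gpo = GPO(d_pe=32, d_hidden=32)
    features = torch.randn(4, 10, 256)             # B x K x D
    lengths = torch.tensor([10., 7., 5., 3.])      # valid length per sample
    pooled, weights = gpo(features, lengths)
    print(pooled.shape, weights.shape)             # (4, 256) and (4, 10, 1)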
--------------------------------------------------------------------------------
/lib/modules/mlp.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 |
5 | class TwoLayerMLP(nn.Module):
6 | def __init__(self, num_features, hid_dim, out_dim, return_hidden=False):
7 | super().__init__()
8 | self.return_hidden = return_hidden
9 | self.model = nn.Sequential(
10 | nn.Linear(num_features, hid_dim),
11 | nn.ReLU(),
12 | nn.Linear(hid_dim, out_dim),
13 | )
14 |
15 | for m in self.model:
16 | if isinstance(m, nn.Linear):
17 | nn.init.kaiming_normal_(m.weight)
18 | nn.init.constant_(m.bias, 0)
19 |
20 | def forward(self, x):
21 | if not self.return_hidden:
22 | return self.model(x)
23 | else:
24 | hid_feat = self.model[:2](x)
25 | results = self.model[2:](hid_feat)
26 | return hid_feat, results
27 |
28 |
29 | class MLP(nn.Module):
30 | """ Very simple multi-layer perceptron (also called FFN)"""
31 |
32 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
33 | super().__init__()
34 | self.output_dim = output_dim
35 | self.num_layers = num_layers
36 | h = [hidden_dim] * (num_layers - 1)
37 | self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
38 | self.bns = nn.ModuleList(nn.BatchNorm1d(k) for k in h + [output_dim])
39 |
40 | def forward(self, x):
41 | B, N, D = x.size()
42 | x = x.reshape(B*N, D)
43 | for i, (bn, layer) in enumerate(zip(self.bns, self.layers)):
44 | x = F.relu(bn(layer(x))) if i < self.num_layers - 1 else layer(x)
45 | x = x.view(B, N, self.output_dim)
46 | return x
47 |
48 |
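if __name__ == '__main__':
    import torch

    # Hedged usage sketch (added for illustration, not part of the original file):
    # MLP flattens (B, N, D) to (B*N, D) so BatchNorm1d normalises over all
    # regions/words in the batch. The 2048 -> 1024 sizes are only illustrative
    # (e.g. bottom-up region features projected to the joint embedding size).
    mlp = MLP(input_dim=2048, hidden_dim=1024, output_dim=1024, num_layers=2)
    regions = torch.randn(4, 36, 2048)     # 4 images, 36 region features each
    print(mlp(regions).shape)              # torch.Size([4, 36, 1024])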
--------------------------------------------------------------------------------
/lib/modules/resnet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import math
5 |
6 | import torch.utils.model_zoo as model_zoo
7 | import logging
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 | __all__ = ['ResNet', 'resnet50', 'resnet101', 'resnet152']
12 |
13 | model_urls = {
14 | 'resnet50': 'https://s3.amazonaws.com/pytorch/models/resnet50-19c8e357.pth',
15 | 'resnet101': 'https://s3.amazonaws.com/pytorch/models/resnet101-5d3b4d8f.pth',
16 | 'resnet152': 'https://s3.amazonaws.com/pytorch/models/resnet152-b121ed2d.pth',
17 | }
18 |
19 |
20 | def conv3x3(in_planes, out_planes, stride=1):
21 | "3x3 convolution with padding"
22 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
23 | padding=1, bias=False)
24 |
25 |
26 | class BasicBlock(nn.Module):
27 | expansion = 1
28 |
29 | def __init__(self, inplanes, planes, stride=1, downsample=None):
30 | super(BasicBlock, self).__init__()
31 | self.conv1 = conv3x3(inplanes, planes, stride)
32 | self.bn1 = nn.BatchNorm2d(planes)
33 | self.relu = nn.ReLU(inplace=True)
34 | self.conv2 = conv3x3(planes, planes)
35 | self.bn2 = nn.BatchNorm2d(planes)
36 | self.downsample = downsample
37 | self.stride = stride
38 |
39 | def forward(self, x):
40 | residual = x
41 |
42 | out = self.conv1(x)
43 | out = self.bn1(out)
44 | out = self.relu(out)
45 |
46 | out = self.conv2(out)
47 | out = self.bn2(out)
48 |
49 | if self.downsample is not None:
50 | residual = self.downsample(x)
51 |
52 | out += residual
53 | out = self.relu(out)
54 |
55 | return out
56 |
57 |
58 | class Bottleneck(nn.Module):
59 | expansion = 4
60 |
61 | def __init__(self, inplanes, planes, stride=1, downsample=None):
62 | super(Bottleneck, self).__init__()
63 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) # change
64 | self.bn1 = nn.BatchNorm2d(planes)
65 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, # change
66 | padding=1, bias=False)
67 | self.bn2 = nn.BatchNorm2d(planes)
68 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
69 | self.bn3 = nn.BatchNorm2d(planes * 4)
70 | self.relu = nn.ReLU(inplace=True)
71 | self.downsample = downsample
72 | self.stride = stride
73 |
74 | def forward(self, x):
75 | residual = x
76 |
77 | out = self.conv1(x)
78 | out = self.bn1(out)
79 | out = self.relu(out)
80 |
81 | out = self.conv2(out)
82 | out = self.bn2(out)
83 | out = self.relu(out)
84 |
85 | out = self.conv3(out)
86 | out = self.bn3(out)
87 |
88 | if self.downsample is not None:
89 | residual = self.downsample(x)
90 |
91 | out += residual
92 | out = self.relu(out)
93 |
94 | return out
95 |
96 |
97 | class ResNet(nn.Module):
98 | def __init__(self, block, layers, width_mult, num_classes=1000):
99 | self.inplanes = 64 * width_mult
100 | super(ResNet, self).__init__()
101 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
102 | bias=False)
103 | self.bn1 = nn.BatchNorm2d(self.inplanes)
104 | self.relu = nn.ReLU(inplace=True)
105 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) # change
106 | self.layer1 = self._make_layer(block, 64 * width_mult, layers[0])
107 | self.layer2 = self._make_layer(block, 128 * width_mult, layers[1], stride=2)
108 | self.layer3 = self._make_layer(block, 256 * width_mult, layers[2], stride=2)
109 | self.layer4 = self._make_layer(block, 512 * width_mult, layers[3], stride=2)
110 |
111 | self.avgpool = nn.AvgPool2d(7)
112 | self.fc = nn.Linear(512 * block.expansion * width_mult, num_classes)
113 |
114 | for m in self.modules():
115 | if isinstance(m, nn.Conv2d):
116 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
117 | m.weight.data.normal_(0, math.sqrt(2. / n))
118 | elif isinstance(m, nn.BatchNorm2d):
119 | m.weight.data.fill_(1)
120 | m.bias.data.zero_()
121 |
122 | def _make_layer(self, block, planes, blocks, stride=1):
123 | downsample = None
124 | if stride != 1 or self.inplanes != planes * block.expansion:
125 | downsample = nn.Sequential(
126 | nn.Conv2d(self.inplanes, planes * block.expansion,
127 | kernel_size=1, stride=stride, bias=False),
128 | nn.BatchNorm2d(planes * block.expansion),
129 | )
130 |
131 | layers = []
132 | layers.append(block(self.inplanes, planes, stride, downsample))
133 | self.inplanes = planes * block.expansion
134 | for i in range(1, blocks):
135 | layers.append(block(self.inplanes, planes))
136 |
137 | return nn.Sequential(*layers)
138 |
139 | def forward(self, x):
140 | x = self.conv1(x)
141 | x = self.bn1(x)
142 | x = self.relu(x)
143 | x = self.maxpool(x)
144 |
145 | x = self.layer1(x)
146 | x = self.layer2(x)
147 | x = self.layer3(x)
148 | x = self.layer4(x)
149 |
150 | x = self.avgpool(x)
151 | x = x.view(x.size(0), -1)
152 | x = self.fc(x)
153 |
154 | return x
155 |
156 |
157 | def resnet50(pretrained=False, width_mult=1):
158 | """Constructs a ResNet-50 model.
159 | Args:
160 | pretrained (bool): If True, returns a model pre-trained on ImageNet
161 | """
162 | model = ResNet(Bottleneck, [3, 4, 6, 3], width_mult)
163 | if pretrained:
164 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
165 | return model
166 |
167 |
168 | def resnet101(pretrained=False, width_mult=1):
169 | """Constructs a ResNet-101 model.
170 | Args:
171 | pretrained (bool): If True, returns a model pre-trained on ImageNet
172 | """
173 | model = ResNet(Bottleneck, [3, 4, 23, 3], width_mult)
174 | if pretrained:
175 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
176 | return model
177 |
178 |
179 | def resnet152(pretrained=False, width_mult=1):
180 | """Constructs a ResNet-152 model.
181 | Args:
182 | pretrained (bool): If True, returns a model pre-trained on ImageNet
183 | """
184 | model = ResNet(Bottleneck, [3, 8, 36, 3], width_mult)
185 | if pretrained:
186 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
187 | return model
188 |
189 |
190 | class ResnetFeatureExtractor(nn.Module):
191 | def __init__(self, backbone_source, weights_path, pooling_size=7, fixed_blocks=2):
192 | super(ResnetFeatureExtractor, self).__init__()
193 | self.backbone_source = backbone_source
194 | self.weights_path = weights_path
195 | self.pooling_size = pooling_size
196 | self.fixed_blocks = fixed_blocks
197 |
198 | if 'detector' in self.backbone_source:
199 | self.resnet = resnet101()
200 | elif self.backbone_source == 'imagenet':
201 | self.resnet = resnet101(pretrained=True)
202 | elif self.backbone_source == 'imagenet_res50':
203 | self.resnet = resnet50(pretrained=True)
204 | elif self.backbone_source == 'imagenet_res152':
205 | self.resnet = resnet152(pretrained=True)
206 | elif self.backbone_source == 'imagenet_resnext':
207 | self.resnet = torch.hub.load('pytorch/vision:v0.4.2', 'resnext101_32x8d', pretrained=True)
208 | elif 'wsl' in self.backbone_source:
209 | self.resnet = torch.hub.load('facebookresearch/WSL-Images', 'resnext101_32x8d_wsl')
210 | else:
211 | raise ValueError('Unknown backbone source {}'.format(self.backbone_source))
212 |
213 | self._init_modules()
214 |
215 | def _init_modules(self):
216 | # Build resnet.
217 | self.base = nn.Sequential(self.resnet.conv1, self.resnet.bn1, self.resnet.relu,
218 | self.resnet.maxpool, self.resnet.layer1, self.resnet.layer2, self.resnet.layer3)
219 | self.top = nn.Sequential(self.resnet.layer4)
220 |
221 | if self.weights_path != '':
222 | if 'detector' in self.backbone_source:
223 | if os.path.exists(self.weights_path):
224 | logger.info(
225 | 'Loading pretrained backbone weights from {} for backbone source {}'.format(self.weights_path,
226 | self.backbone_source))
227 | backbone_ckpt = torch.load(self.weights_path)
228 | self.base.load_state_dict(backbone_ckpt['base'])
229 | self.top.load_state_dict(backbone_ckpt['top'])
230 | else:
231 | raise ValueError('Could not find weights for backbone CNN at {}'.format(self.weights_path))
232 | else:
233 | logger.info('Did not load external checkpoints')
234 | self.unfreeze_base()
235 |
236 | def set_fixed_blocks(self, fixed_blocks):
237 | self.fixed_blocks = fixed_blocks
238 |
239 | def get_fixed_blocks(self):
240 | return self.fixed_blocks
241 |
242 | def unfreeze_base(self):
243 | assert (0 <= self.fixed_blocks < 4)
244 | if self.fixed_blocks == 3:
245 | for p in self.base[6].parameters(): p.requires_grad = False
246 | for p in self.base[5].parameters(): p.requires_grad = False
247 | for p in self.base[4].parameters(): p.requires_grad = False
248 | for p in self.base[0].parameters(): p.requires_grad = False
249 | for p in self.base[1].parameters(): p.requires_grad = False
250 | if self.fixed_blocks == 2:
251 | for p in self.base[6].parameters(): p.requires_grad = True
252 | for p in self.base[5].parameters(): p.requires_grad = False
253 | for p in self.base[4].parameters(): p.requires_grad = False
254 | for p in self.base[0].parameters(): p.requires_grad = False
255 | for p in self.base[1].parameters(): p.requires_grad = False
256 | if self.fixed_blocks == 1:
257 | for p in self.base[6].parameters(): p.requires_grad = True
258 | for p in self.base[5].parameters(): p.requires_grad = True
259 | for p in self.base[4].parameters(): p.requires_grad = False
260 | for p in self.base[0].parameters(): p.requires_grad = False
261 | for p in self.base[1].parameters(): p.requires_grad = False
262 | if self.fixed_blocks == 0:
263 | for p in self.base[6].parameters(): p.requires_grad = True
264 | for p in self.base[5].parameters(): p.requires_grad = True
265 | for p in self.base[4].parameters(): p.requires_grad = True
266 | for p in self.base[0].parameters(): p.requires_grad = True
267 | for p in self.base[1].parameters(): p.requires_grad = True
268 |
269 | logger.info('Resnet backbone now has fixed blocks {}'.format(self.fixed_blocks))
270 |
271 | def freeze_base(self):
272 | for p in self.base.parameters():
273 | p.requires_grad = False
274 |
275 | def train(self, mode=True):
276 | # Override train so that the training mode is set as we want (BN does not update the running stats)
277 | nn.Module.train(self, mode)
278 | if mode:
279 | # fix all bn layers
280 | def set_bn_eval(m):
281 | classname = m.__class__.__name__
282 | if classname.find('BatchNorm') != -1:
283 | m.eval()
284 |
285 | self.base.apply(set_bn_eval)
286 | self.top.apply(set_bn_eval)
287 |
288 | def _head_to_tail(self, pool5):
289 | fc7 = self.top(pool5).mean(3).mean(2)
290 | return fc7
291 |
292 | def forward(self, im_data):
293 | b_s = im_data.size(0)
294 | base_feat = self.base(im_data)
295 | top_feat = self.top(base_feat)
296 | features = top_feat.view(b_s, top_feat.size(1), -1).permute(0, 2, 1)
297 | return features
298 |
299 |
300 | if __name__ == '__main__':
301 | import numpy as np
302 |
303 | def count_params(model):
304 | model_parameters = model.parameters()
305 | params = sum([np.prod(p.size()) for p in model_parameters])
306 | return params
307 |
308 | model = resnet50(pretrained=False, width_mult=1)
309 | num_params = count_params(model)
310 |
311 |
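    # Hedged usage sketch (added for illustration, not part of the original file):
    # shape check for the grid-feature path. 'detector' simply selects an
    # un-pretrained ResNet-101 so this runs offline; at training time the real
    # weights are loaded from opt.backbone_path instead.
    extractor = ResnetFeatureExtractor(backbone_source='detector', weights_path='')
    with torch.no_grad():
        feats = extractor(torch.randn(1, 3, 224, 224))
    print(num_params, feats.shape)   # feats: (1, 49, 2048), a 7x7 grid of 2048-d features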
--------------------------------------------------------------------------------
/lib/pytorch-logo-dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CrossmodalGroup/ESL/b91e287b238fcf6dd7dc784c04e7bf1ac53404ae/lib/pytorch-logo-dark.png
--------------------------------------------------------------------------------
/lib/vse.py:
--------------------------------------------------------------------------------
1 | """ESL model, building on the top of VSE model"""
2 | import numpy as np
3 | import os
4 | import torch
5 | import torch.nn as nn
6 | import math
7 | import torch.nn.functional as F
8 |
9 | import torch.nn.init
10 | import torch.backends.cudnn as cudnn
11 | from torch.nn.utils import clip_grad_norm_
12 |
13 | from lib.encoders import get_image_encoder, get_text_encoder
14 | from lib.loss import ContrastiveLoss
15 |
16 | import logging
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 | def logging_func(log_file, message):
21 |     with open(log_file, 'a') as f:
22 |         f.write(message)
23 |
24 |
25 | def l2norm(X, dim=-1, eps=1e-8):
26 |     """L2-normalize X along the given dimension"""
27 | norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps
28 | X = torch.div(X, norm)
29 | return X
30 |
31 |
32 | def cosine_similarity(x1, x2, dim=1, eps=1e-8):
33 | """Returns cosine similarity between x1 and x2, computed along dim."""
34 | w12 = torch.sum(x1 * x2, dim)
35 | w1 = torch.norm(x1, 2, dim)
36 | w2 = torch.norm(x2, 2, dim)
37 | return (w12 / (w1 * w2).clamp(min=eps)).squeeze()
38 |
39 |
40 | ########################################################
41 | ### For the dimensional selective mask, we design both heuristic and adaptive strategies.
42 | ### You can use this flag to control which strategy is selected. True -> heuristic strategy, False -> adaptive strategy
43 |
44 | heuristic_strategy = False
45 | ########################################################
46 |
47 |
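# Added note (illustration, not part of the original file): each level below
# reduces the previous dimension by a factor of kernel_size, under either the
# heuristic Conv1d pooling or the adaptive learned masks. Assuming embed_size=1024
# (the value the hard-coded index tensors assume) and kernel_size=2, the nine
# concatenated levels per region/word are
#   1024 + 512 + 256 + 128 + 64 + 32 + 16 + 8 + 4 = 2044
# dimensions, which is sum(list_length[:9]) and is sliced back apart in
# sims_claculator.forward further down.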
48 | if heuristic_strategy:
49 |
50 | ### Heuristic Dimensional Selective Mask
51 | class Image_levels(nn.Module):
52 | def __init__(self, opt):
53 | super(Image_levels, self).__init__()
54 | self.sub_space = opt.embed_size
55 | self.kernel_size = int(opt.kernel_size)
56 |
57 | self.kernel_img_1 = nn.Conv1d(in_channels=1, out_channels=4, kernel_size=self.kernel_size, stride=self.kernel_size, bias=False)
58 | self.kernel_img_2 = nn.Conv1d(in_channels=1, out_channels=4, kernel_size=self.kernel_size, stride=self.kernel_size, bias=False)
59 | self.kernel_img_3 = nn.Conv1d(in_channels=1, out_channels=4, kernel_size=self.kernel_size, stride=self.kernel_size, bias=False)
60 | self.kernel_img_4 = nn.Conv1d(in_channels=1, out_channels=4, kernel_size=self.kernel_size, stride=self.kernel_size, bias=False)
61 | self.kernel_img_5 = nn.Conv1d(in_channels=1, out_channels=4, kernel_size=self.kernel_size, stride=self.kernel_size, bias=False)
62 | self.kernel_img_6 = nn.Conv1d(in_channels=1, out_channels=4, kernel_size=self.kernel_size, stride=self.kernel_size, bias=False)
63 | self.kernel_img_7 = nn.Conv1d(in_channels=1, out_channels=4, kernel_size=self.kernel_size, stride=self.kernel_size, bias=False)
64 | self.kernel_img_8 = nn.Conv1d(in_channels=1, out_channels=4, kernel_size=self.kernel_size, stride=self.kernel_size, bias=False)
65 |
66 |
67 | def get_image_levels(self, img_emb, batch_size, n_region):
68 | img_emb_1 = self.kernel_img_1(img_emb.reshape(-1, self.sub_space).unsqueeze(-2)).sum(1)
69 | img_emb_1 = l2norm(img_emb_1.reshape(batch_size, n_region, -1), -1)
70 |
71 | emb_size = img_emb_1.size(-1)
72 | img_emb_2 = self.kernel_img_2(img_emb_1.reshape(-1, emb_size).unsqueeze(-2)).sum(1)
73 | img_emb_2 = l2norm(img_emb_2.reshape(batch_size, n_region, -1), -1)
74 |
75 | emb_size = img_emb_2.size(-1)
76 | img_emb_3 = self.kernel_img_3(img_emb_2.reshape(-1, emb_size).unsqueeze(-2)).sum(1)
77 | img_emb_3 = l2norm(img_emb_3.reshape(batch_size, n_region, -1), -1)
78 |
79 | emb_size = img_emb_3.size(-1)
80 | img_emb_4 = self.kernel_img_4(img_emb_3.reshape(-1, emb_size).unsqueeze(-2)).sum(1)
81 | img_emb_4 = l2norm(img_emb_4.reshape(batch_size, n_region, -1), -1)
82 |
83 | emb_size = img_emb_4.size(-1)
84 | img_emb_5 = self.kernel_img_5(img_emb_4.reshape(-1, emb_size).unsqueeze(-2)).sum(1)
85 | img_emb_5 = l2norm(img_emb_5.reshape(batch_size, n_region, -1), -1)
86 |
87 | emb_size = img_emb_5.size(-1)
88 | img_emb_6 = self.kernel_img_6(img_emb_5.reshape(-1, emb_size).unsqueeze(-2)).sum(1)
89 | img_emb_6 = l2norm(img_emb_6.reshape(batch_size, n_region, -1), -1)
90 |
91 | emb_size = img_emb_6.size(-1)
92 | img_emb_7 = self.kernel_img_7(img_emb_6.reshape(-1, emb_size).unsqueeze(-2)).sum(1)
93 | img_emb_7 = l2norm(img_emb_7.reshape(batch_size, n_region, -1), -1)
94 |
95 | emb_size = img_emb_7.size(-1)
96 | img_emb_8 = self.kernel_img_8(img_emb_7.reshape(-1, emb_size).unsqueeze(-2)).sum(1)
97 | img_emb_8 = l2norm(img_emb_8.reshape(batch_size, n_region, -1), -1)
98 |
99 |
100 |
101 | return torch.cat([img_emb, img_emb_1, img_emb_2, img_emb_3, img_emb_4, img_emb_5, img_emb_6, img_emb_7, img_emb_8], -1)
102 |
103 | def forward(self, img_emb):
104 |
105 | batch_size, n_region, embed_size = img_emb.size(0), img_emb.size(1), img_emb.size(2)
106 |
107 | return self.get_image_levels(img_emb, batch_size, n_region)
108 |
109 |
110 | class Text_levels(nn.Module):
111 |
112 | def __init__(self, opt):
113 | super(Text_levels, self).__init__()
114 | self.sub_space = opt.embed_size
115 | self.kernel_size = int(opt.kernel_size)
116 |
117 | self.kernel_txt_1 = nn.Conv1d(in_channels=1, out_channels=4, kernel_size=self.kernel_size, stride=self.kernel_size, bias=False)
118 | self.kernel_txt_2 = nn.Conv1d(in_channels=1, out_channels=4, kernel_size=self.kernel_size, stride=self.kernel_size, bias=False)
119 | self.kernel_txt_3 = nn.Conv1d(in_channels=1, out_channels=4, kernel_size=self.kernel_size, stride=self.kernel_size, bias=False)
120 | self.kernel_txt_4 = nn.Conv1d(in_channels=1, out_channels=4, kernel_size=self.kernel_size, stride=self.kernel_size, bias=False)
121 | self.kernel_txt_5 = nn.Conv1d(in_channels=1, out_channels=4, kernel_size=self.kernel_size, stride=self.kernel_size, bias=False)
122 | self.kernel_txt_6 = nn.Conv1d(in_channels=1, out_channels=4, kernel_size=self.kernel_size, stride=self.kernel_size, bias=False)
123 | self.kernel_txt_7 = nn.Conv1d(in_channels=1, out_channels=4, kernel_size=self.kernel_size, stride=self.kernel_size, bias=False)
124 | self.kernel_txt_8 = nn.Conv1d(in_channels=1, out_channels=4, kernel_size=self.kernel_size, stride=self.kernel_size, bias=False)
125 |
126 |
127 | def get_text_levels(self, cap_i, batch_size, n_word):
128 | cap_i_1 = self.kernel_txt_1(cap_i.reshape(-1, self.sub_space).unsqueeze(-2)).sum(1)
129 | cap_i_expand_1 = l2norm(cap_i_1.reshape(batch_size, n_word, -1), -1)
130 |
131 | emb_size = cap_i_expand_1.size(-1)
132 | cap_i_2 = self.kernel_txt_2(cap_i_1.reshape(-1, emb_size).unsqueeze(-2)).sum(1)
133 | cap_i_expand_2 = l2norm(cap_i_2.reshape(batch_size, n_word, -1), -1)
134 |
135 | emb_size = cap_i_expand_2.size(-1)
136 | cap_i_3 = self.kernel_txt_3(cap_i_2.reshape(-1, emb_size).unsqueeze(-2)).sum(1)
137 | cap_i_expand_3 = l2norm(cap_i_3.reshape(batch_size, n_word, -1), -1)
138 |
139 | emb_size = cap_i_expand_3.size(-1)
140 | cap_i_4 = self.kernel_txt_4(cap_i_3.reshape(-1, emb_size).unsqueeze(-2)).sum(1)
141 | cap_i_expand_4 = l2norm(cap_i_4.reshape(batch_size, n_word, -1), -1)
142 |
143 | emb_size = cap_i_expand_4.size(-1)
144 | cap_i_5 = self.kernel_txt_5(cap_i_4.reshape(-1, emb_size).unsqueeze(-2)).sum(1)
145 | cap_i_expand_5 = l2norm(cap_i_5.reshape(batch_size, n_word, -1), -1)
146 |
147 | emb_size = cap_i_expand_5.size(-1)
148 | cap_i_6 = self.kernel_txt_6(cap_i_5.reshape(-1, emb_size).unsqueeze(-2)).sum(1)
149 | cap_i_expand_6 = l2norm(cap_i_6.reshape(batch_size, n_word, -1), -1)
150 |
151 | emb_size = cap_i_expand_6.size(-1)
152 | cap_i_7 = self.kernel_txt_7(cap_i_6.reshape(-1, emb_size).unsqueeze(-2)).sum(1)
153 | cap_i_expand_7 = l2norm(cap_i_7.reshape(batch_size, n_word, -1), -1)
154 |
155 | emb_size = cap_i_expand_7.size(-1)
156 | cap_i_8 = self.kernel_txt_8(cap_i_7.reshape(-1, emb_size).unsqueeze(-2)).sum(1)
157 | cap_i_expand_8 = l2norm(cap_i_8.reshape(batch_size, n_word, -1), -1)
158 |
159 | return torch.cat([cap_i, cap_i_expand_1, cap_i_expand_2, cap_i_expand_3, cap_i_expand_4, cap_i_expand_5, cap_i_expand_6, cap_i_expand_7, cap_i_expand_8], -1)
160 |
161 |
162 | def forward(self, cap_i):
163 |
164 | batch_size, n_word, embed_size = cap_i.size(0), cap_i.size(1), cap_i.size(2)
165 |
166 | return self.get_text_levels(cap_i, batch_size, n_word)
167 |
168 | else:
169 |
170 | #### Adaptive Dimensional Selective Mask
171 | class Image_levels(nn.Module):
172 | def __init__(self, opt):
173 | super(Image_levels, self).__init__()
174 | self.sub_space = opt.embed_size
175 | self.kernel_size = int(opt.kernel_size)
176 | self.out_channels = 2
177 |
178 | self.masks_1 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 1)), int(opt.embed_size/math.pow(self.kernel_size, 0))) # num_embedding, dims_input
179 | self.masks_2 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 2)), int(opt.embed_size/math.pow(self.kernel_size, 1)))
180 | self.masks_3 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 3)), int(opt.embed_size/math.pow(self.kernel_size, 2)))
181 | self.masks_4 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 4)), int(opt.embed_size/math.pow(self.kernel_size, 3)))
182 | self.masks_5 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 5)), int(opt.embed_size/math.pow(self.kernel_size, 4)))
183 | self.masks_6 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 6)), int(opt.embed_size/math.pow(self.kernel_size, 5)))
184 | self.masks_7 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 7)), int(opt.embed_size/math.pow(self.kernel_size, 6)))
185 | self.masks_8 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 8)), int(opt.embed_size/math.pow(self.kernel_size, 7)))
186 |
187 |
188 | def get_image_levels(self, img_emb, batch_size, n_region):
189 |
190 |
191 | sub_space_index = torch.tensor(torch.linspace(0, 1024, steps=1024, dtype=torch.int)).cuda()
192 | Dim_learned_mask_1 = l2norm(self.masks_1(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 1))]), dim=-1)
193 | Dim_learned_mask_2 = l2norm(self.masks_2(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 2))]), dim=-1)
194 | Dim_learned_mask_3 = l2norm(self.masks_3(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 3))]), dim=-1)
195 | Dim_learned_mask_4 = l2norm(self.masks_4(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 4))]), dim=-1)
196 | Dim_learned_mask_5 = l2norm(self.masks_5(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 5))]), dim=-1)
197 | Dim_learned_mask_6 = l2norm(self.masks_6(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 6))]), dim=-1)
198 | Dim_learned_mask_7 = l2norm(self.masks_7(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 7))]), dim=-1)
199 | Dim_learned_mask_8 = l2norm(self.masks_8(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 8))]), dim=-1)
200 |
201 |
202 | if Dim_learned_mask_1.size(1) < self.out_channels:
203 | select_nums = Dim_learned_mask_1.size(1)
204 | else:
205 | select_nums = self.out_channels
206 | Dim_learned_range = Dim_learned_mask_1.sort(1, descending =True)[0][:, select_nums -1].unsqueeze(-1)
207 | Dim_learned_mask_1 = (Dim_learned_mask_1 >= Dim_learned_range).float() * Dim_learned_mask_1
208 |
209 |
210 | if Dim_learned_mask_2.size(1) < self.out_channels:
211 | select_nums = Dim_learned_mask_2.size(1)
212 | else:
213 | select_nums = self.out_channels
214 | Dim_learned_range = Dim_learned_mask_2.sort(1, descending =True)[0][:, select_nums -1].unsqueeze(-1)
215 | Dim_learned_mask_2 = (Dim_learned_mask_2 >= Dim_learned_range).float() * Dim_learned_mask_2
216 |
217 |
218 | if Dim_learned_mask_3.size(1) < self.out_channels:
219 | select_nums = Dim_learned_mask_3.size(1)
220 | else:
221 | select_nums = self.out_channels
222 | Dim_learned_range = Dim_learned_mask_3.sort(1, descending =True)[0][:, select_nums -1].unsqueeze(-1)
223 | Dim_learned_mask_3 = (Dim_learned_mask_3 >= Dim_learned_range).float() * Dim_learned_mask_3
224 |
225 |
226 | if Dim_learned_mask_4.size(1) < self.out_channels:
227 | select_nums = Dim_learned_mask_4.size(1)
228 | else:
229 | select_nums = self.out_channels
230 | Dim_learned_range = Dim_learned_mask_4.sort(1, descending =True)[0][:, select_nums -1].unsqueeze(-1)
231 | Dim_learned_mask_4 = (Dim_learned_mask_4 >= Dim_learned_range).float() * Dim_learned_mask_4
232 |
233 |
234 | if Dim_learned_mask_5.size(1) < self.out_channels:
235 | select_nums = Dim_learned_mask_5.size(1)
236 | else:
237 | select_nums = self.out_channels
238 | Dim_learned_range = Dim_learned_mask_5.sort(1, descending =True)[0][:, select_nums -1].unsqueeze(-1)
239 | Dim_learned_mask_5 = (Dim_learned_mask_5 >= Dim_learned_range).float() * Dim_learned_mask_5
240 |
241 |
242 | if Dim_learned_mask_6.size(1) < self.out_channels:
243 | select_nums = Dim_learned_mask_6.size(1)
244 | else:
245 | select_nums = self.out_channels
246 | Dim_learned_range = Dim_learned_mask_6.sort(1, descending =True)[0][:, select_nums -1].unsqueeze(-1)
247 | Dim_learned_mask_6 = (Dim_learned_mask_6 >= Dim_learned_range).float() * Dim_learned_mask_6
248 |
249 |
250 | if Dim_learned_mask_7.size(1) < self.out_channels:
251 | select_nums = Dim_learned_mask_7.size(1)
252 | else:
253 | select_nums = self.out_channels
254 | Dim_learned_range = Dim_learned_mask_7.sort(1, descending =True)[0][:, select_nums -1].unsqueeze(-1)
255 | Dim_learned_mask_7 = (Dim_learned_mask_7 >= Dim_learned_range).float() * Dim_learned_mask_7
256 |
257 |
258 | if Dim_learned_mask_8.size(1) < self.out_channels:
259 | select_nums = Dim_learned_mask_8.size(1)
260 | else:
261 | select_nums = self.out_channels
262 | Dim_learned_range = Dim_learned_mask_8.sort(1, descending =True)[0][:, select_nums -1].unsqueeze(-1)
263 | Dim_learned_mask_8 = (Dim_learned_mask_8 >= Dim_learned_range).float() * Dim_learned_mask_8
264 |
265 |
266 | img_emb_1 = img_emb.reshape(-1, self.sub_space) @ Dim_learned_mask_1.t()
267 | img_emb_1 = l2norm(img_emb_1.reshape(batch_size, n_region, -1), -1)
268 |
269 | emb_size = img_emb_1.size(-1)
270 | img_emb_2 = img_emb_1.reshape(-1, emb_size) @ Dim_learned_mask_2.t()
271 | img_emb_2 = l2norm(img_emb_2.reshape(batch_size, n_region, -1), -1)
272 |
273 | emb_size = img_emb_2.size(-1)
274 | img_emb_3 = img_emb_2.reshape(-1, emb_size) @ Dim_learned_mask_3.t()
275 | img_emb_3 = l2norm(img_emb_3.reshape(batch_size, n_region, -1), -1)
276 |
277 | emb_size = img_emb_3.size(-1)
278 | img_emb_4 = img_emb_3.reshape(-1, emb_size) @ Dim_learned_mask_4.t()
279 | img_emb_4 = l2norm(img_emb_4.reshape(batch_size, n_region, -1), -1)
280 |
281 | emb_size = img_emb_4.size(-1)
282 | img_emb_5 = img_emb_4.reshape(-1, emb_size) @ Dim_learned_mask_5.t()
283 | img_emb_5 = l2norm(img_emb_5.reshape(batch_size, n_region, -1), -1)
284 |
285 | emb_size = img_emb_5.size(-1)
286 | img_emb_6 = img_emb_5.reshape(-1, emb_size) @ Dim_learned_mask_6.t()
287 | img_emb_6 = l2norm(img_emb_6.reshape(batch_size, n_region, -1), -1)
288 |
289 | emb_size = img_emb_6.size(-1)
290 | img_emb_7 = img_emb_6.reshape(-1, emb_size) @ Dim_learned_mask_7.t()
291 | img_emb_7 = l2norm(img_emb_7.reshape(batch_size, n_region, -1), -1)
292 |
293 | emb_size = img_emb_7.size(-1)
294 | img_emb_8 = img_emb_7.reshape(-1, emb_size) @ Dim_learned_mask_8.t()
295 | img_emb_8 = l2norm(img_emb_8.reshape(batch_size, n_region, -1), -1)
296 |
297 |
298 | return torch.cat([img_emb, img_emb_1, img_emb_2, img_emb_3, img_emb_4, img_emb_5, img_emb_6, img_emb_7, img_emb_8], -1)
299 |
300 | def forward(self, img_emb):
301 |
302 | batch_size, n_region, embed_size = img_emb.size(0), img_emb.size(1), img_emb.size(2)
303 |
304 | return self.get_image_levels(img_emb, batch_size, n_region)
305 |
306 |
307 | class Text_levels(nn.Module):
308 |
309 | def __init__(self, opt):
310 | super(Text_levels, self).__init__()
311 | self.sub_space = opt.embed_size
312 | self.kernel_size = int(opt.kernel_size)
313 | self.out_channels = 2
314 |
315 | self.masks_1 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 1)), int(opt.embed_size/math.pow(self.kernel_size, 0))) # num_embedding, dims_input
316 | self.masks_2 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 2)), int(opt.embed_size/math.pow(self.kernel_size, 1)))
317 | self.masks_3 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 3)), int(opt.embed_size/math.pow(self.kernel_size, 2)))
318 | self.masks_4 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 4)), int(opt.embed_size/math.pow(self.kernel_size, 3)))
319 | self.masks_5 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 5)), int(opt.embed_size/math.pow(self.kernel_size, 4)))
320 | self.masks_6 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 6)), int(opt.embed_size/math.pow(self.kernel_size, 5)))
321 | self.masks_7 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 7)), int(opt.embed_size/math.pow(self.kernel_size, 6)))
322 | self.masks_8 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 8)), int(opt.embed_size/math.pow(self.kernel_size, 7)))
323 |
324 | def get_text_levels(self, cap_i, batch_size, n_word):
325 |
326 | sub_space_index = torch.tensor(torch.linspace(0, 1024, steps=1024, dtype=torch.int)).cuda()
327 | Dim_learned_mask_1 = l2norm(self.masks_1(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 1))]), dim=-1)
328 | Dim_learned_mask_2 = l2norm(self.masks_2(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 2))]), dim=-1)
329 | Dim_learned_mask_3 = l2norm(self.masks_3(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 3))]), dim=-1)
330 | Dim_learned_mask_4 = l2norm(self.masks_4(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 4))]), dim=-1)
331 | Dim_learned_mask_5 = l2norm(self.masks_5(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 5))]), dim=-1)
332 | Dim_learned_mask_6 = l2norm(self.masks_6(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 6))]), dim=-1)
333 | Dim_learned_mask_7 = l2norm(self.masks_7(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 7))]), dim=-1)
334 | Dim_learned_mask_8 = l2norm(self.masks_8(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 8))]), dim=-1)
335 |
336 |
337 | if Dim_learned_mask_1.size(1) < self.out_channels:
338 | select_nums = Dim_learned_mask_1.size(1)
339 | else:
340 | select_nums = self.out_channels
341 | Dim_learned_range = Dim_learned_mask_1.sort(1, descending =True)[0][:, select_nums -1].unsqueeze(-1)
342 | Dim_learned_mask_1 = (Dim_learned_mask_1 >= Dim_learned_range).float() * Dim_learned_mask_1
343 |
344 |
345 | if Dim_learned_mask_2.size(1) < self.out_channels:
346 | select_nums = Dim_learned_mask_2.size(1)
347 | else:
348 | select_nums = self.out_channels
349 | Dim_learned_range = Dim_learned_mask_2.sort(1, descending =True)[0][:, select_nums -1].unsqueeze(-1)
350 | Dim_learned_mask_2 = (Dim_learned_mask_2 >= Dim_learned_range).float() * Dim_learned_mask_2
351 |
352 |
353 | if Dim_learned_mask_3.size(1) < self.out_channels:
354 | select_nums = Dim_learned_mask_3.size(1)
355 | else:
356 | select_nums = self.out_channels
357 | Dim_learned_range = Dim_learned_mask_3.sort(1, descending =True)[0][:, select_nums -1].unsqueeze(-1)
358 | Dim_learned_mask_3 = (Dim_learned_mask_3 >= Dim_learned_range).float() * Dim_learned_mask_3
359 |
360 |
361 | if Dim_learned_mask_4.size(1) < self.out_channels:
362 | select_nums = Dim_learned_mask_4.size(1)
363 | else:
364 | select_nums = self.out_channels
365 | Dim_learned_range = Dim_learned_mask_4.sort(1, descending =True)[0][:, select_nums -1].unsqueeze(-1)
366 | Dim_learned_mask_4 = (Dim_learned_mask_4 >= Dim_learned_range).float() * Dim_learned_mask_4
367 |
368 |
369 | if Dim_learned_mask_5.size(1) < self.out_channels:
370 | select_nums = Dim_learned_mask_5.size(1)
371 | else:
372 | select_nums = self.out_channels
373 | Dim_learned_range = Dim_learned_mask_5.sort(1, descending =True)[0][:, select_nums -1].unsqueeze(-1)
374 | Dim_learned_mask_5 = (Dim_learned_mask_5 >= Dim_learned_range).float() * Dim_learned_mask_5
375 |
376 |
377 | if Dim_learned_mask_6.size(1) < self.out_channels:
378 | select_nums = Dim_learned_mask_6.size(1)
379 | else:
380 | select_nums = self.out_channels
381 | Dim_learned_range = Dim_learned_mask_6.sort(1, descending =True)[0][:, select_nums -1].unsqueeze(-1)
382 | Dim_learned_mask_6 = (Dim_learned_mask_6 >= Dim_learned_range).float() * Dim_learned_mask_6
383 |
384 |
385 | if Dim_learned_mask_7.size(1) < self.out_channels:
386 | select_nums = Dim_learned_mask_7.size(1)
387 | else:
388 | select_nums = self.out_channels
389 | Dim_learned_range = Dim_learned_mask_7.sort(1, descending =True)[0][:, select_nums -1].unsqueeze(-1)
390 | Dim_learned_mask_7 = (Dim_learned_mask_7 >= Dim_learned_range).float() * Dim_learned_mask_7
391 |
392 |
393 | if Dim_learned_mask_8.size(1) < self.out_channels:
394 | select_nums = Dim_learned_mask_8.size(1)
395 | else:
396 | select_nums = self.out_channels
397 | Dim_learned_range = Dim_learned_mask_8.sort(1, descending =True)[0][:, select_nums -1].unsqueeze(-1)
398 | Dim_learned_mask_8 = (Dim_learned_mask_8 >= Dim_learned_range).float() * Dim_learned_mask_8
399 |
400 | cap_i_1 = cap_i.reshape(-1, self.sub_space) @ Dim_learned_mask_1.t()
401 | cap_i_expand_1 = l2norm(cap_i_1.reshape(batch_size, n_word, -1), -1)
402 |
403 | emb_size = cap_i_expand_1.size(-1)
404 | cap_i_2 = cap_i_1.reshape(-1, emb_size) @ Dim_learned_mask_2.t()
405 | cap_i_expand_2 = l2norm(cap_i_2.reshape(batch_size, n_word, -1), -1)
406 |
407 | emb_size = cap_i_expand_2.size(-1)
408 | cap_i_3 = cap_i_2.reshape(-1, emb_size) @ Dim_learned_mask_3.t()
409 | cap_i_expand_3 = l2norm(cap_i_3.reshape(batch_size, n_word, -1), -1)
410 |
411 | emb_size = cap_i_expand_3.size(-1)
412 | cap_i_4 = cap_i_3.reshape(-1, emb_size) @ Dim_learned_mask_4.t()
413 | cap_i_expand_4 = l2norm(cap_i_4.reshape(batch_size, n_word, -1), -1)
414 |
415 | emb_size = cap_i_expand_4.size(-1)
416 | cap_i_5 = cap_i_4.reshape(-1, emb_size) @ Dim_learned_mask_5.t()
417 | cap_i_expand_5 = l2norm(cap_i_5.reshape(batch_size, n_word, -1), -1)
418 |
419 | emb_size = cap_i_expand_5.size(-1)
420 | cap_i_6 = cap_i_5.reshape(-1, emb_size) @ Dim_learned_mask_6.t()
421 | cap_i_expand_6 = l2norm(cap_i_6.reshape(batch_size, n_word, -1), -1)
422 |
423 | emb_size = cap_i_expand_6.size(-1)
424 | cap_i_7 = cap_i_6.reshape(-1, emb_size) @ Dim_learned_mask_7.t()
425 | cap_i_expand_7 = l2norm(cap_i_7.reshape(batch_size, n_word, -1), -1)
426 |
427 | emb_size = cap_i_expand_7.size(-1)
428 | cap_i_8 = cap_i_7.reshape(-1, emb_size) @ Dim_learned_mask_8.t()
429 | cap_i_expand_8 = l2norm(cap_i_8.reshape(batch_size, n_word, -1), -1)
430 |
431 |
432 | return torch.cat([cap_i, cap_i_expand_1, cap_i_expand_2, cap_i_expand_3, cap_i_expand_4, cap_i_expand_5, cap_i_expand_6, cap_i_expand_7, cap_i_expand_8], -1)
433 |
434 | def forward(self, cap_i):
435 |
436 | batch_size, n_word, embed_size = cap_i.size(0), cap_i.size(1), cap_i.size(2)
437 |
438 | return self.get_text_levels(cap_i, batch_size, n_word)
439 |
440 |
441 |
442 |
443 |
444 |
445 |
446 |
447 |
448 |
449 |
450 | class Image_Text_Encoders(nn.Module):
451 | def __init__(self, opt):
452 |
453 | super(Image_Text_Encoders, self).__init__()
454 | self.text_levels = Text_levels(opt)
455 | self.image_levels = Image_levels(opt)
456 |
457 | def forward(self, images, captions, return_type):
458 |
459 | if return_type == 'image':
460 | img_embs = self.image_levels(images)
461 | return img_embs
462 | else:
463 | cap_embs = self.text_levels(captions)
464 | return cap_embs
465 |
466 | class Image_Text_Processing(nn.Module):
467 |
468 | def __init__(self, opt):
469 | super(Image_Text_Processing, self).__init__()
470 | self.encoders_1 = Image_Text_Encoders(opt)
471 |
472 |
473 | def forward(self, images, captions):
474 |
475 | image_processed = self.encoders_1(images, captions, 'image')
476 | text_processed = self.encoders_1(images, captions, 'text')
477 |
478 | return image_processed, text_processed
479 |
480 |
481 |
482 |
483 | class sims_claculator(nn.Module):
484 | def __init__(self, opt):
485 | super(sims_claculator, self).__init__()
486 | self.sub_space = opt.embed_size
487 | self.kernel_size = int(opt.kernel_size)
488 |
489 | self.opt = opt
490 | self.sim_eval = nn.Linear(9, 1, bias=False)
491 | self.temp_scale = nn.Linear(1, 1, bias=False)
492 | self.temp_scale_1 = nn.Linear(1, 1, bias=False)
493 | self.temp_scale_2 = nn.Linear(1, 1, bias=False)
494 |
495 | self.masks_0 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 0)), int(opt.embed_size/math.pow(self.kernel_size, 0)))
496 | self.masks_1 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 1)), int(opt.embed_size/math.pow(self.kernel_size, 1)))
497 | self.masks_2 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 2)), int(opt.embed_size/math.pow(self.kernel_size, 2)))
498 | self.masks_3 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 3)), int(opt.embed_size/math.pow(self.kernel_size, 3)))
499 | self.masks_4 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 4)), int(opt.embed_size/math.pow(self.kernel_size, 4)))
500 | self.masks_5 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 5)), int(opt.embed_size/math.pow(self.kernel_size, 5)))
501 | self.masks_6 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 6)), int(opt.embed_size/math.pow(self.kernel_size, 6)))
502 | self.masks_7 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 7)), int(opt.embed_size/math.pow(self.kernel_size, 7)))
503 | self.masks_8 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 8)), int(opt.embed_size/math.pow(self.kernel_size, 8)))
504 |
505 | self.lynorm_0 = nn.LayerNorm([int(opt.embed_size/math.pow(self.kernel_size, 0)), int(opt.embed_size/math.pow(self.kernel_size, 0))], eps=1e-08, elementwise_affine=True)
506 | self.lynorm_1 = nn.LayerNorm([int(opt.embed_size/math.pow(self.kernel_size, 1)), int(opt.embed_size/math.pow(self.kernel_size, 1))], eps=1e-08, elementwise_affine=True)
507 | self.lynorm_2 = nn.LayerNorm([int(opt.embed_size/math.pow(self.kernel_size, 2)), int(opt.embed_size/math.pow(self.kernel_size, 2))], eps=1e-08, elementwise_affine=True)
508 | self.lynorm_3 = nn.LayerNorm([int(opt.embed_size/math.pow(self.kernel_size, 3)), int(opt.embed_size/math.pow(self.kernel_size, 3))], eps=1e-08, elementwise_affine=True)
509 | self.lynorm_4 = nn.LayerNorm([int(opt.embed_size/math.pow(self.kernel_size, 4)), int(opt.embed_size/math.pow(self.kernel_size, 4))], eps=1e-08, elementwise_affine=True)
510 | self.lynorm_5 = nn.LayerNorm([int(opt.embed_size/math.pow(self.kernel_size, 5)), int(opt.embed_size/math.pow(self.kernel_size, 5))], eps=1e-08, elementwise_affine=True)
511 | self.lynorm_6 = nn.LayerNorm([int(opt.embed_size/math.pow(self.kernel_size, 6)), int(opt.embed_size/math.pow(self.kernel_size, 6))], eps=1e-08, elementwise_affine=True)
512 | self.lynorm_7 = nn.LayerNorm([int(opt.embed_size/math.pow(self.kernel_size, 7)), int(opt.embed_size/math.pow(self.kernel_size, 7))], eps=1e-08, elementwise_affine=True)
513 | self.lynorm_8 = nn.LayerNorm([int(opt.embed_size/math.pow(self.kernel_size, 8)), int(opt.embed_size/math.pow(self.kernel_size, 8))], eps=1e-08, elementwise_affine=True)
514 |
515 | self.list_length = [int(opt.embed_size/math.pow(self.kernel_size, 0)), int(opt.embed_size/math.pow(self.kernel_size, 1)), int(opt.embed_size/math.pow(self.kernel_size, 2)),
516 | int(opt.embed_size/math.pow(self.kernel_size, 3)), int(opt.embed_size/math.pow(self.kernel_size, 4)), int(opt.embed_size/math.pow(self.kernel_size, 5)),
517 | int(opt.embed_size/math.pow(self.kernel_size, 6)), int(opt.embed_size/math.pow(self.kernel_size, 7)), int(opt.embed_size/math.pow(self.kernel_size, 8))]
518 |
519 |
520 | self.init_weights()
521 |
522 | def init_weights(self):
523 | self.temp_scale.weight.data.fill_(np.log(1 / 0.07))
524 | self.sim_eval.weight.data.fill_(0.1)
525 | self.temp_scale_1.weight.data.fill_(0)
526 | self.temp_scale_2.weight.data.fill_(3)
527 |
528 | def get_weighted_features(self, attn, smooth):
529 | # --> (batch, sourceL, queryL)
530 | attnT = torch.transpose(attn, 1, 2).contiguous()
531 | attn = nn.LeakyReLU(0.1)(attnT)
532 | attn = l2norm(attn, 2)
533 | # --> (batch, queryL, sourceL)
534 | attn = torch.transpose(attn, 1, 2).contiguous()
535 |         # softmax over sourceL --> (batch, queryL, sourceL)
536 | attn = F.softmax(attn*smooth, dim=2)
537 |
538 | return attn
539 |
540 |
541 | def get_sims_levels(self, X_0, X_1, X_2, X_3, X_4, X_5, X_6, X_7, X_8,
542 | Y_0, Y_1, Y_2, Y_3, Y_4, Y_5, Y_6, Y_7, Y_8,
543 | D_0, D_1, D_2, D_3, D_4, D_5, D_6, D_7, D_8):
544 | attn_0 = (X_0 @ D_0 @ Y_0.transpose(1, 2))
545 | attn_1 = (X_1 @ D_1 @ Y_1.transpose(1, 2))
546 | attn_2 = (X_2 @ D_2 @ Y_2.transpose(1, 2))
547 | attn_3 = (X_3 @ D_3 @ Y_3.transpose(1, 2))
548 | attn_4 = (X_4 @ D_4 @ Y_4.transpose(1, 2))
549 | attn_5 = (X_5 @ D_5 @ Y_5.transpose(1, 2))
550 | attn_6 = (X_6 @ D_6 @ Y_6.transpose(1, 2))
551 | attn_7 = (X_7 @ D_7 @ Y_7.transpose(1, 2))
552 | attn_8 = (X_8 @ D_8 @ Y_8.transpose(1, 2))
553 |
554 | attn = attn_0 + attn_1 + attn_2 + attn_3 + attn_4 + attn_5 + attn_6 + attn_7 + attn_8
555 |
556 | return attn
557 |
558 | def forward(self, img_emb, cap_emb, cap_lens):
559 |
560 | n_caption = cap_emb.size(0)
561 | sim_all = []
562 | batch_size, n_region, embed_size = img_emb.size(0), img_emb.size(1), img_emb.size(2)
563 | sub_space_index = torch.tensor(torch.linspace(0, self.sub_space, steps=self.sub_space, dtype=torch.int)).cuda()
564 | smooth=torch.exp(self.temp_scale.weight)
565 |
566 | sigma_ = self.temp_scale_1.weight
567 | lambda_ = torch.exp(self.temp_scale_2.weight)
568 | threshold = (torch.abs(self.sim_eval.weight).max() - torch.abs(self.sim_eval.weight).min()) * sigma_ + torch.abs(self.sim_eval.weight).min()
569 | if not heuristic_strategy:
570 | lambda_ = 0
571 |
572 | weight_0 = (torch.exp((torch.abs(self.sim_eval.weight[0, 0]) - threshold) * lambda_) * self.sim_eval.weight[0, 0])
573 | weight_1 = (torch.exp((torch.abs(self.sim_eval.weight[0, 1]) - threshold) * lambda_) * self.sim_eval.weight[0, 1])
574 | weight_2 = (torch.exp((torch.abs(self.sim_eval.weight[0, 2]) - threshold) * lambda_) * self.sim_eval.weight[0, 2])
575 | weight_3 = (torch.exp((torch.abs(self.sim_eval.weight[0, 3]) - threshold) * lambda_) * self.sim_eval.weight[0, 3])
576 | weight_4 = (torch.exp((torch.abs(self.sim_eval.weight[0, 4]) - threshold) * lambda_) * self.sim_eval.weight[0, 4])
577 | weight_5 = (torch.exp((torch.abs(self.sim_eval.weight[0, 5]) - threshold) * lambda_) * self.sim_eval.weight[0, 5])
578 | weight_6 = (torch.exp((torch.abs(self.sim_eval.weight[0, 6]) - threshold) * lambda_) * self.sim_eval.weight[0, 6])
579 | weight_7 = (torch.exp((torch.abs(self.sim_eval.weight[0, 7]) - threshold) * lambda_) * self.sim_eval.weight[0, 7])
580 | weight_8 = (torch.exp((torch.abs(self.sim_eval.weight[0, 8]) - threshold) * lambda_) * self.sim_eval.weight[0, 8])
581 |
582 |
583 |
584 | Dim_learned_mask_0 = self.lynorm_0(self.masks_0(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 0))])) * weight_0
585 | Dim_learned_mask_1 = self.lynorm_1(self.masks_1(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 1))])) * weight_1
586 | Dim_learned_mask_2 = self.lynorm_2(self.masks_2(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 2))])) * weight_2
587 | Dim_learned_mask_3 = self.lynorm_3(self.masks_3(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 3))])) * weight_3
588 | Dim_learned_mask_4 = self.lynorm_4(self.masks_4(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 4))])) * weight_4
589 | Dim_learned_mask_5 = self.lynorm_5(self.masks_5(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 5))])) * weight_5
590 | Dim_learned_mask_6 = self.lynorm_6(self.masks_6(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 6))])) * weight_6
591 | Dim_learned_mask_7 = self.lynorm_7(self.masks_7(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 7))])) * weight_7
592 | Dim_learned_mask_8 = self.lynorm_8(self.masks_8(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 8))])) * weight_8
593 |
594 |
595 |
596 | for i in range(n_caption):
597 | # get the i-th sentence
598 | n_word = cap_lens[i]
599 |
600 | ## ------------------------------------------------------------------------------------------------------------------------
601 | # attention
602 |
603 | attn = self.get_sims_levels(
604 | cap_emb[i, :n_word, :sum(self.list_length[:1])].unsqueeze(0).repeat(batch_size, 1, 1),
605 | cap_emb[i, :n_word, sum(self.list_length[:1]):sum(self.list_length[:2])].unsqueeze(0).repeat(batch_size, 1, 1),
606 | cap_emb[i, :n_word, sum(self.list_length[:2]):sum(self.list_length[:3])].unsqueeze(0).repeat(batch_size, 1, 1),
607 | cap_emb[i, :n_word, sum(self.list_length[:3]):sum(self.list_length[:4])].unsqueeze(0).repeat(batch_size, 1, 1),
608 | cap_emb[i, :n_word, sum(self.list_length[:4]):sum(self.list_length[:5])].unsqueeze(0).repeat(batch_size, 1, 1),
609 | cap_emb[i, :n_word, sum(self.list_length[:5]):sum(self.list_length[:6])].unsqueeze(0).repeat(batch_size, 1, 1),
610 | cap_emb[i, :n_word, sum(self.list_length[:6]):sum(self.list_length[:7])].unsqueeze(0).repeat(batch_size, 1, 1),
611 | cap_emb[i, :n_word, sum(self.list_length[:7]):sum(self.list_length[:8])].unsqueeze(0).repeat(batch_size, 1, 1),
612 | cap_emb[i, :n_word, sum(self.list_length[:8]):sum(self.list_length[:9])].unsqueeze(0).repeat(batch_size, 1, 1),
613 | img_emb[:, :, :sum(self.list_length[:1])],
614 | img_emb[:, :, sum(self.list_length[:1]):sum(self.list_length[:2])],
615 | img_emb[:, :, sum(self.list_length[:2]):sum(self.list_length[:3])],
616 | img_emb[:, :, sum(self.list_length[:3]):sum(self.list_length[:4])],
617 | img_emb[:, :, sum(self.list_length[:4]):sum(self.list_length[:5])],
618 | img_emb[:, :, sum(self.list_length[:5]):sum(self.list_length[:6])],
619 | img_emb[:, :, sum(self.list_length[:6]):sum(self.list_length[:7])],
620 | img_emb[:, :, sum(self.list_length[:7]):sum(self.list_length[:8])],
621 | img_emb[:, :, sum(self.list_length[:8]):sum(self.list_length[:9])],
622 | Dim_learned_mask_0, Dim_learned_mask_1, Dim_learned_mask_2, Dim_learned_mask_3, Dim_learned_mask_4, Dim_learned_mask_5, Dim_learned_mask_6, Dim_learned_mask_7, Dim_learned_mask_8)
623 |
624 | ##################################################################################################
625 | # --> (batch, sourceL, queryL)
626 | attnT = torch.transpose(attn, 1, 2).contiguous()
627 | attn_t2i_weight = self.get_weighted_features(torch.tanh(attn), smooth)
628 | sims_t2i = attn.mul(attn_t2i_weight).sum(-1).mean(dim=1, keepdim=True)
629 | ##################################################################################################
630 | attn_i2t_weight = self.get_weighted_features(torch.tanh(attnT), smooth)
631 | sims_i2t = attnT.mul(attn_i2t_weight).sum(-1).mean(dim=1, keepdim=True)
632 | ##################################################################################################
633 |
634 | sims = sims_t2i + sims_i2t  # final score: sum of the text-to-image and image-to-text pooled similarities
635 | sim_all.append(sims)
636 |
637 | sim_all = torch.cat(sim_all, 1)
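# sim_all: each appended entry was an (n_image, 1) column, so after concatenation this is the
# (n_image, n_caption) image-caption similarity matrix returned below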
638 | return sim_all
639 |
640 |
641 |
642 |
643 | class Sims_Measuring(nn.Module):
644 | def __init__(self, opt):
645 | super(Sims_Measuring, self).__init__()
646 |
647 | self.calculator_1 = sims_claculator(opt)
648 |
649 | def forward(self, img_embs, cap_embs, lengths):
650 |
651 | sims = self.calculator_1(img_embs, cap_embs, lengths)
652 |
653 | return sims
654 |
655 |
656 |
657 | class Sim_vec(nn.Module):
658 |
659 | def __init__(self, embed_size, opt):
660 | super(Sim_vec, self).__init__()
661 | self.plus_encoder = Image_Text_Processing(opt)
662 | self.sims = Sims_Measuring(opt)
663 |
664 | def forward(self, img_emb, cap_emb, cap_lens, is_Train):
665 |
666 | region_features, word_features = self.plus_encoder(img_emb, cap_emb)
667 | sims = self.sims(region_features, word_features, cap_lens)
668 |
669 | return sims, sims  # the similarity matrix is returned twice so callers can unpack (sim_all, L1)
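# A minimal usage sketch (shapes assumed from the VSEModel callers below):
#   sim_vec = Sim_vec(opt.embed_size, opt)
#   sims, _ = sim_vec(img_emb, cap_emb, cap_lens, is_Train=True)
# where img_emb is (n_image, n_region, embed_size), cap_emb is (n_caption, max_n_word, embed_size),
# and cap_lens holds the number of words in each caption.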
670 |
671 |
672 |
673 |
674 |
675 | class VSEModel(object):
676 |
677 |
678 | def __init__(self, opt):
679 | # Build Models
680 | self.grad_clip = opt.grad_clip
681 | self.img_enc = get_image_encoder(opt.data_name, opt.img_dim, opt.embed_size,
682 | precomp_enc_type=opt.precomp_enc_type,
683 | backbone_source=opt.backbone_source,
684 | backbone_path=opt.backbone_path,
685 | no_imgnorm=opt.no_imgnorm)
686 | self.txt_enc = get_text_encoder(opt.embed_size, no_txtnorm=opt.no_txtnorm)
687 |
688 | self.sim_vec = Sim_vec(opt.embed_size, opt)
689 |
690 | if torch.cuda.is_available():
691 | self.img_enc.cuda()
692 | self.txt_enc.cuda()
693 | self.sim_vec.cuda()
694 | cudnn.benchmark = True
695 |
696 | # Loss and Optimizer
697 | self.criterion = ContrastiveLoss(opt=opt,
698 | margin=opt.margin,
699 | max_violation=opt.max_violation)
700 |
701 | params = list(self.txt_enc.parameters())
702 | params += list(self.img_enc.parameters())
703 | params += list(self.sim_vec.parameters())
704 |
705 | self.params = params
706 | self.opt = opt
707 |
708 | # Set up the lr for different parts of the VSE model
709 | decay_factor = 1e-4
710 | if opt.precomp_enc_type == 'basic':
711 | if self.opt.optim == 'adam':
712 | all_text_params = list(self.txt_enc.parameters())
713 | bert_params = list(self.txt_enc.bert.parameters())
714 | bert_params_ptr = [p.data_ptr() for p in bert_params]
715 | text_params_no_bert = list()
716 | for p in all_text_params:
717 | if p.data_ptr() not in bert_params_ptr:
718 | text_params_no_bert.append(p)
719 | self.optimizer = torch.optim.AdamW([
720 | {'params': text_params_no_bert, 'lr': opt.learning_rate},
721 | {'params': bert_params, 'lr': opt.learning_rate * 0.1},
722 | {'params': self.img_enc.parameters(), 'lr': opt.learning_rate},
723 | {'params': self.sim_vec.parameters(), 'lr': opt.learning_rate},
724 | ],
725 | lr=opt.learning_rate, weight_decay=decay_factor)
726 | elif self.opt.optim == 'sgd':
727 | self.optimizer = torch.optim.SGD(self.params, lr=opt.learning_rate, momentum=0.9)
728 | else:
729 | raise ValueError('Invalid optim option {}'.format(self.opt.optim))
730 | else:
731 | if self.opt.optim == 'adam':
732 | all_text_params = list(self.txt_enc.parameters())
733 | bert_params = list(self.txt_enc.bert.parameters())
734 | bert_params_ptr = [p.data_ptr() for p in bert_params]
735 | text_params_no_bert = list()
736 | for p in all_text_params:
737 | if p.data_ptr() not in bert_params_ptr:
738 | text_params_no_bert.append(p)
739 | self.optimizer = torch.optim.AdamW([
740 | {'params': text_params_no_bert, 'lr': opt.learning_rate},
741 | {'params': bert_params, 'lr': opt.learning_rate * 0.1},
742 | {'params': self.img_enc.backbone.top.parameters(),
743 | 'lr': opt.learning_rate * opt.backbone_lr_factor, },
744 | {'params': self.img_enc.backbone.base.parameters(),
745 | 'lr': opt.learning_rate * opt.backbone_lr_factor, },
746 | {'params': self.img_enc.image_encoder.parameters(), 'lr': opt.learning_rate},
747 | ], lr=opt.learning_rate, weight_decay=decay_factor)
748 | elif self.opt.optim == 'sgd':
749 | self.optimizer = torch.optim.SGD([
750 | {'params': self.txt_enc.parameters(), 'lr': opt.learning_rate},
751 | {'params': self.img_enc.backbone.parameters(), 'lr': opt.learning_rate * opt.backbone_lr_factor,
752 | 'weight_decay': decay_factor},
753 | {'params': self.img_enc.image_encoder.parameters(), 'lr': opt.learning_rate},
754 | ], lr=opt.learning_rate, momentum=0.9, nesterov=True)
755 | else:
756 | raise ValueError('Invalid optim option {}'.format(self.opt.optim))
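# Both AdamW setups above give the BERT parameters 0.1x the base learning rate; the backbone
# setups additionally scale the CNN learning rate by opt.backbone_lr_factor.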
757 |
758 | logger.info('Use {} as the optimizer, with init lr {}'.format(self.opt.optim, opt.learning_rate))
759 |
760 | self.Eiters = 0
761 | self.data_parallel = False
762 |
763 | def set_max_violation(self, max_violation):
764 | if max_violation:
765 | self.criterion.max_violation_on()
766 | else:
767 | self.criterion.max_violation_off()
768 |
769 | def state_dict(self):
770 | state_dict = [self.img_enc.state_dict(), self.txt_enc.state_dict(),
771 | self.sim_vec.state_dict()]
772 | return state_dict
773 |
774 | def load_state_dict(self, state_dict):
775 | self.img_enc.load_state_dict(state_dict[0], strict=False)
776 | self.txt_enc.load_state_dict(state_dict[1], strict=False)
777 | self.sim_vec.load_state_dict(state_dict[2], strict=False)
778 |
779 | def train_start(self):
780 | """switch to train mode
781 | """
782 | self.img_enc.train()
783 | self.txt_enc.train()
784 | self.sim_vec.train()
785 |
786 | def val_start(self):
787 | """switch to evaluate mode
788 | """
789 | self.img_enc.eval()
790 | self.txt_enc.eval()
791 | self.sim_vec.eval()
792 |
793 | def freeze_backbone(self):
794 | if 'backbone' in self.opt.precomp_enc_type:
795 | if isinstance(self.img_enc, nn.DataParallel):
796 | self.img_enc.module.freeze_backbone()
797 | else:
798 | self.img_enc.freeze_backbone()
799 |
800 | def unfreeze_backbone(self, fixed_blocks):
801 | if 'backbone' in self.opt.precomp_enc_type:
802 | if isinstance(self.img_enc, nn.DataParallel):
803 | self.img_enc.module.unfreeze_backbone(fixed_blocks)
804 | else:
805 | self.img_enc.unfreeze_backbone(fixed_blocks)
806 |
807 | def make_data_parallel(self):
808 | self.img_enc = nn.DataParallel(self.img_enc)
809 | self.txt_enc = nn.DataParallel(self.txt_enc)
810 | self.sim_vec = nn.DataParallel(self.sim_vec)
811 | self.data_parallel = True
812 | logger.info('Image encoder, text encoder and similarity module are wrapped in DataParallel now.')
813 |
814 | @property
815 | def is_data_parallel(self):
816 | return self.data_parallel
817 |
818 | def forward_emb(self, images, captions, lengths, image_lengths=None):
819 | """Compute the image and caption embeddings
820 | """
821 | # Set mini-batch dataset
822 | if self.opt.precomp_enc_type == 'basic':
823 | if torch.cuda.is_available():
824 | images = images.cuda()
825 | captions = captions.cuda()
826 | image_lengths = image_lengths.cuda()
827 | img_emb = self.img_enc(images, image_lengths)
828 | else:
829 | if torch.cuda.is_available():
830 | images = images.cuda()
831 | captions = captions.cuda()
832 | img_emb = self.img_enc(images)
833 |
834 | # lengths = torch.Tensor(lengths).cuda()
835 | cap_emb = self.txt_enc(captions, lengths)
836 | return img_emb, cap_emb, lengths
837 |
838 | def forward_sim(self, img_emb, cap_emb, cap_lens):
839 | is_Train = True
840 | sim_all, L1 = self.sim_vec(img_emb, cap_emb, cap_lens, is_Train)
841 |
842 | return sim_all, L1
843 |
844 | def forward_sim_test(self, img_emb, cap_emb, cap_lens):
845 | is_Train = False
846 | sim_all, L1 = self.sim_vec(img_emb, cap_emb, cap_lens, is_Train)
847 |
848 | return sim_all, L1
849 |
850 |
851 | def forward_loss(self, sims):
852 | """Compute the loss given pairs of image and caption embeddings
853 | """
854 | loss = self.criterion(sims)
855 | self.logger.update('Le', loss.data.item(), sims.size(0))
856 | return loss
857 |
858 | def train_emb(self, images, captions, lengths, image_lengths=None, warmup_alpha=None):
859 | """One training step given images and captions.
860 | """
861 | self.Eiters += 1
862 | self.logger.update('Eit', self.Eiters)
863 | self.logger.update('lr', self.optimizer.param_groups[0]['lr'])
864 |
865 | # compute the embeddings
866 | img_emb, cap_emb, cap_lens = self.forward_emb(images, captions, lengths, image_lengths=image_lengths)
867 | sims, L1= self.forward_sim(img_emb, cap_emb, cap_lens)
868 |
869 | # measure accuracy and record loss
870 | self.optimizer.zero_grad()
871 | loss = self.forward_loss(sims)
872 |
873 | if warmup_alpha is not None:
874 | loss = loss * warmup_alpha
875 |
876 |
877 | message = "%f\n" % loss.item()
878 | log_file2 = os.path.join(self.opt.logger_name, "loss.txt")
879 | logging_func(log_file2, message)
880 |
881 |
882 | # compute gradient and update
883 | loss.backward()
884 | if self.grad_clip > 0:
885 | clip_grad_norm_(self.params, self.grad_clip)
886 | self.optimizer.step()
--------------------------------------------------------------------------------
/motivation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CrossmodalGroup/ESL/b91e287b238fcf6dd7dc784c04e7bf1ac53404ae/motivation.png
--------------------------------------------------------------------------------
/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CrossmodalGroup/ESL/b91e287b238fcf6dd7dc784c04e7bf1ac53404ae/overview.png
--------------------------------------------------------------------------------
/runs/runX/log/performance.log:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CrossmodalGroup/ESL/b91e287b238fcf6dd7dc784c04e7bf1ac53404ae/runs/runX/log/performance.log
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | from lib import evaluation
3 |
4 | os.environ["CUDA_VISIBLE_DEVICES"] = "7"
5 |
6 |
7 | ## For the heuristic strategy, set the flag variable 'heuristic_strategy' in line 44 of lib/vse.py to True.
8 | # RUN_PATH = "../checkpoint_heuristic_flickr30k_bert.tar"
9 |
10 |
11 | ## For the adaptive strategy, set the flag variable 'heuristic_strategy' in line 44 of lib/vse.py to False.
12 | RUN_PATH = "../checkpoint_adaptive_flickr30k_bert.tar"
13 |
14 |
15 | DATA_PATH = "../Flickr30K/"
16 | evaluation.evalrank(RUN_PATH, data_path=DATA_PATH, split="test")
17 |
--------------------------------------------------------------------------------
/test_stack.py:
--------------------------------------------------------------------------------
1 | # from vocab import Vocabulary
2 | # import evaluation
3 | import numpy as np
4 | import os
5 |
6 |
7 | def i2t(im_len, sims, npts=None, return_ranks=False):
8 | """
9 | Images->Text (Image Annotation)
10 | Images: (N, n_region, d) matrix of images
11 | Captions: (5N, max_n_word, d) matrix of captions
12 | CapLens: (5N) array of caption lengths
13 | sims: (N, 5N) matrix of similarity im-cap
14 | """
15 | npts = im_len
16 | ranks = np.zeros(npts)
17 | top1 = np.zeros(npts)
18 | for index in range(npts):
19 | inds = np.argsort(sims[index])[::-1]
20 | # Score
21 | rank = 1e20
22 | for i in range(5 * index, 5 * index + 5, 1):
23 | tmp = np.where(inds == i)[0][0]
24 | if tmp < rank:
25 | rank = tmp
26 | ranks[index] = rank
27 | top1[index] = inds[0]
28 |
29 | # Compute metrics
30 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
31 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
32 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
33 | medr = np.floor(np.median(ranks)) + 1
34 | meanr = ranks.mean() + 1
35 | if return_ranks:
36 | return (r1, r5, r10, medr, meanr), (ranks, top1)
37 | else:
38 | return (r1, r5, r10, medr, meanr)
39 |
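# A tiny worked example with illustrative values (not from the repo):
#   >>> i2t(1, np.array([[0.9, 0.8, 0.1, 0.2, 0.3]]))
#   (100.0, 100.0, 100.0, 1.0, 1.0)
# The single image's five ground-truth captions land at ranks 0, 1, 4, 3, 2, so the best
# rank is 0 and every recall threshold is hit.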
40 |
41 | def t2i(im_len, sims, npts=None, return_ranks=False):
42 | """
43 | Text->Images (Image Search)
44 | Images: (N, n_region, d) matrix of images
45 | Captions: (5N, max_n_word, d) matrix of captions
46 | CapLens: (5N) array of caption lengths
47 | sims: (N, 5N) matrix of similarity im-cap
48 | """
49 | npts = im_len
50 | ranks = np.zeros(5 * npts)
51 | top1 = np.zeros(5 * npts)
52 |
53 | # --> (5N(caption), N(image))
54 | sims = sims.T
55 |
56 | for index in range(npts):
57 | for i in range(5):
58 | inds = np.argsort(sims[5 * index + i])[::-1]
59 | ranks[5 * index + i] = np.where(inds == index)[0][0]
60 | top1[5 * index + i] = inds[0]
61 |
62 | # Compute metrics
63 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
64 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
65 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
66 | medr = np.floor(np.median(ranks)) + 1
67 | meanr = ranks.mean() + 1
68 | if return_ranks:
69 | return (r1, r5, r10, medr, meanr), (ranks, top1)
70 | else:
71 | return (r1, r5, r10, medr, meanr)
72 |
73 |
74 |
75 | if __name__ == '__main__':
76 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
77 |
78 | isfold5 = True
79 |
80 | if not isfold5:
81 |
82 | # ## Flickr30K
83 | # Path_of_Model_1 = ''
84 | # Path_of_Model_2 = ''
85 |
86 | ## MS-COCO
87 | Path_of_Model_1 = ''
88 | Path_of_Model_2 = ''
89 |
90 | sims1 = np.loadtxt(Path_of_Model_1)
91 | sims2 = np.loadtxt(Path_of_Model_2)
92 |
93 | sims = (sims1 + sims2)
94 | im_len = len(sims)
95 | print('im length:', im_len)
96 | r, rt = i2t(im_len, sims, return_ranks=True)
97 | ri, rti = t2i(im_len, sims, return_ranks=True)
98 | ar = (r[0] + r[1] + r[2]) / 3
99 | ari = (ri[0] + ri[1] + ri[2]) / 3
100 | rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
101 | print("rsum: %.1f" % rsum)
102 | print("Average i2t Recall: %.1f" % ar)
103 | print("Image to text: %.1f %.1f %.1f %.1f %.1f" % r)
104 | print("Average t2i Recall: %.1f" % ari)
105 | print("Text to image: %.1f %.1f %.1f %.1f %.1f" % ri)
106 | else:
107 | results = []
108 | for i in range(5):
109 |
110 | Path_of_Model_1 = '/mnt/data2/zk/ESL_bert/checkpoint2/COCO-LEARNABLE/'
111 | Path_of_Model_2 = '/mnt/data2/zk/ESL_bert/checkpoint2/COCO-NON-LEARNABLE/'
112 |
113 | sims1 = np.loadtxt(Path_of_Model_1 + str(i) + 'sim_best.txt')
114 | sims2 = np.loadtxt(Path_of_Model_2 + str(i) + 'sim_best.txt')
115 |
116 | sim_shard = (sims1 + sims2) / 2
117 | im_len = len(sim_shard)
118 | print('im length:', im_len)
119 | r, rt0 = i2t(im_len, sim_shard, return_ranks=True)
120 | print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % r)
121 | ri, rti0 = t2i(im_len, sim_shard, return_ranks=True)
122 | print("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % ri)
123 |
124 | if i == 0:
125 | rt, rti = rt0, rti0
126 | ar = (r[0] + r[1] + r[2]) / 3
127 | ari = (ri[0] + ri[1] + ri[2]) / 3
128 | rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
129 | print("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari))
130 | results += [list(r) + list(ri) + [ar, ari, rsum]]
131 |
132 | print("-----------------------------------")
133 | print("Mean metrics: ")
134 | mean_metrics = tuple(np.array(results).mean(axis=0).flatten())
135 | print("rsum: %.1f" % ( mean_metrics[12]))
136 | print("Average i2t Recall: %.1f" % mean_metrics[11])
137 | print("Image to text: %.1f %.1f %.1f %.1f %.1f" %
138 | mean_metrics[:5])
139 | print("Average t2i Recall: %.1f" % mean_metrics[12])
140 | print("Text to image: %.1f %.1f %.1f %.1f %.1f" %
141 | mean_metrics[5:10])
142 |
--------------------------------------------------------------------------------
/testall.py:
--------------------------------------------------------------------------------
1 | import os
2 | from lib import evaluation
3 |
4 |
5 |
6 | import torch
7 | torch.set_num_threads(4)
8 |
9 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
10 |
11 |
12 | ## For the heuristic strategy, set the flag variable 'heuristic_strategy' in line 44 of lib/vse.py to True.
13 | # RUN_PATH = "../checkpoint_heuristic_mscoco_bert.tar"
14 |
15 |
16 | ## For the adaptive strategy, set the flag variable 'heuristic_strategy' in line 44 of lib/vse.py to False.
17 | RUN_PATH = "../checkpoint_adaptive_mscoco_bert.tar"
18 |
19 | DATA_PATH = "../MS-COCO/"
20 |
21 | ### Set fold5 to False for the 5K test; set it to True for the averaged 1K test.
22 | evaluation.evalrank(RUN_PATH, data_path=DATA_PATH, split="testall", fold5=False)
23 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """Training script"""
2 | import os
3 | import time
4 | import numpy as np
5 | import torch
6 | from transformers import BertTokenizer
7 |
8 | from lib.datasets import image_caption
9 | from lib.vse import VSEModel
10 | from lib.evaluation import i2t, t2i, AverageMeter, LogCollector, encode_data, shard_attn_scores
11 |
12 | import logging
13 | import tensorboard_logger as tb_logger
14 |
15 | import arguments
16 |
17 |
18 | def logging_func(log_file, message):
19 | with open(log_file, 'a') as f:
20 | f.write(message)
21 | f.close()
22 |
23 |
24 | def main():
25 | # Hyper Parameters
26 | parser = arguments.get_argument_parser()
27 | opt = parser.parse_args()
28 |
29 | if not os.path.exists(opt.model_name):
30 | os.makedirs(opt.model_name)
31 | logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
32 | tb_logger.configure(opt.logger_name, flush_secs=5)
33 |
34 | logger = logging.getLogger(__name__)
35 | logger.info(opt)
36 |
37 | # Load Tokenizer and Vocabulary
38 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
39 | vocab = tokenizer.vocab
40 | opt.vocab_size = len(vocab)
41 |
42 | train_loader, val_loader = image_caption.get_loaders(
43 | opt.data_path, opt.data_name, tokenizer, opt.batch_size, opt.workers, opt)
44 |
45 | model = VSEModel(opt)
46 |
47 | lr_schedules = [opt.lr_update, ]
48 |
49 | # optionally resume from a checkpoint
50 | start_epoch = 0
51 | if opt.resume:
52 | if os.path.isfile(opt.resume):
53 | logger.info("=> loading checkpoint '{}'".format(opt.resume))
54 | checkpoint = torch.load(opt.resume)
55 | start_epoch = checkpoint['epoch']
56 | best_rsum = checkpoint['best_rsum']
57 | if not model.is_data_parallel:
58 | model.make_data_parallel()
59 | model.load_state_dict(checkpoint['model'])
60 | # Eiters is used to show logs as the continuation of another training
61 | model.Eiters = checkpoint['Eiters']
62 | logger.info("=> loaded checkpoint '{}' (epoch {}, best_rsum {})"
63 | .format(opt.resume, start_epoch, best_rsum))
64 | # validate(opt, val_loader, model)
65 | if opt.reset_start_epoch:
66 | start_epoch = 0
67 | else:
68 | logger.info("=> no checkpoint found at '{}'".format(opt.resume))
69 |
70 | if not model.is_data_parallel:
71 | model.make_data_parallel()
72 |
73 | # Train the Model
74 | best_rsum = 0
75 | for epoch in range(start_epoch, opt.num_epochs):
76 | logger.info(opt.logger_name)
77 | logger.info(opt.model_name)
78 |
79 | adjust_learning_rate(opt, model.optimizer, epoch, lr_schedules)
80 |
81 | if epoch >= opt.vse_mean_warmup_epochs:
82 | opt.max_violation = True
83 | model.set_max_violation(opt.max_violation)
84 |
85 | # Set up all the warm-up options for backbone training
86 | if opt.precomp_enc_type == 'backbone':
87 | if epoch < opt.embedding_warmup_epochs:
88 | model.freeze_backbone()
89 | logger.info('All backbone weights are frozen, only train the embedding layers')
90 | else:
91 | model.unfreeze_backbone(3)
92 |
93 | if epoch < opt.embedding_warmup_epochs:
94 | logger.info('Warm up the embedding layers')
95 | elif epoch < opt.embedding_warmup_epochs + opt.backbone_warmup_epochs:
96 | model.unfreeze_backbone(3) # only train the last block of resnet backbone
97 | elif epoch < opt.embedding_warmup_epochs + opt.backbone_warmup_epochs * 2:
98 | model.unfreeze_backbone(2)
99 | elif epoch < opt.embedding_warmup_epochs + opt.backbone_warmup_epochs * 3:
100 | model.unfreeze_backbone(1)
101 | else:
102 | model.unfreeze_backbone(0)
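# A concrete reading with illustrative values (not from the repo): with embedding_warmup_epochs=1
# and backbone_warmup_epochs=5, the backbone stays frozen for epoch 0, keeps 3 fixed blocks for
# epochs 1-5, 2 for epochs 6-10, 1 for epochs 11-15, and is fully unfrozen from epoch 16 onward
# (assuming the unfreeze_backbone argument is the number of ResNet blocks kept fixed, as the
# comment at the unfreeze_backbone(3) call above suggests).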
103 |
104 | # train for one epoch
105 | train(opt, train_loader, model, epoch, val_loader)
106 |
107 | # evaluate on validation set
108 | rsum = validate(opt, val_loader, model, epoch)
109 |
110 | # remember best R@ sum and save checkpoint
111 | is_best = rsum > best_rsum
112 | best_rsum = max(rsum, best_rsum)
113 | if not os.path.exists(opt.model_name):
114 | os.mkdir(opt.model_name)
115 |
116 | save_checkpoint({
117 | 'epoch': epoch + 1,
118 | 'model': model.state_dict(),
119 | 'best_rsum': best_rsum,
120 | 'opt': opt,
121 | 'Eiters': model.Eiters,
122 | }, is_best, filename='checkpoint_{}.pth.tar'.format(epoch), prefix=opt.model_name + '/')
123 |
124 |
125 | def train(opt, train_loader, model, epoch, val_loader):
126 | # average meters to record the training statistics
127 | logger = logging.getLogger(__name__)
128 | batch_time = AverageMeter()
129 | data_time = AverageMeter()
130 | train_logger = LogCollector()
131 |
132 | logger.info('image encoder trainable parameters: {}'.format(count_params(model.img_enc)))
133 | logger.info('txt encoder trainable parameters: {}'.format(count_params(model.txt_enc)))
134 |
135 | num_loader_iter = len(train_loader.dataset) // train_loader.batch_size + 1
136 |
137 | end = time.time()
138 | # opt.viz = True
139 | for i, train_data in enumerate(train_loader):
140 | # switch to train mode
141 | model.train_start()
142 |
143 | # measure data loading time
144 | data_time.update(time.time() - end)
145 |
146 | # make sure train logger is used
147 | model.logger = train_logger
148 |
149 | # Update the model
150 | if opt.precomp_enc_type == 'basic':
151 | images, img_lengths, captions, lengths, _ = train_data
152 | model.train_emb(images, captions, lengths, image_lengths=img_lengths)
153 | else:
154 | images, captions, lengths, _ = train_data
155 | if epoch == opt.embedding_warmup_epochs:
156 | warmup_alpha = float(i) / num_loader_iter
157 | model.train_emb(images, captions, lengths, warmup_alpha=warmup_alpha)
158 | else:
159 | model.train_emb(images, captions, lengths)
160 |
161 | # measure elapsed time
162 | batch_time.update(time.time() - end)
163 | end = time.time()
164 |
165 | # log training info every opt.log_step iterations
166 | if model.Eiters % opt.log_step == 0:
167 | if opt.precomp_enc_type == 'backbone' and epoch == opt.embedding_warmup_epochs:
168 | logging.info('Current epoch-{}, the first epoch for training backbone, warmup alpha {}'.format(epoch,
169 | warmup_alpha))
170 | logging.info(
171 | 'Epoch: [{0}][{1}/{2}]\t'
172 | '{e_log}\t'
173 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
174 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
175 | .format(
176 | epoch, i, len(train_loader.dataset) // train_loader.batch_size + 1, batch_time=batch_time,
177 | data_time=data_time, e_log=str(model.logger)))
178 |
179 | # Record logs in tensorboard
180 | tb_logger.log_value('epoch', epoch, step=model.Eiters)
181 | tb_logger.log_value('step', i, step=model.Eiters)
182 | tb_logger.log_value('batch_time', batch_time.val, step=model.Eiters)
183 | tb_logger.log_value('data_time', data_time.val, step=model.Eiters)
184 | model.logger.tb_log(tb_logger, step=model.Eiters)
185 |
186 |
187 | def validate(opt, val_loader, model, epoch):
188 | logger = logging.getLogger(__name__)
189 | model.val_start()
190 | with torch.no_grad():
191 | # compute the encoding for all the validation images and captions
192 | img_embs, cap_embs, cap_lens = encode_data(
193 | model, val_loader, opt.log_step, logging.info, backbone=opt.precomp_enc_type == 'backbone')
194 |
195 | img_embs = np.array([img_embs[i] for i in range(0, len(img_embs), 5)])
196 |
197 | start = time.time()
198 | # sims = compute_sim(img_embs, cap_embs)
199 | sims = shard_attn_scores(model, img_embs, cap_embs, cap_lens, opt, shard_size=opt.batch_size)
200 | end = time.time()
201 | logger.info("calculate similarity time: {}".format(end - start))
202 |
203 | # caption retrieval
204 | npts = img_embs.shape[0]
205 | # (r1, r5, r10, medr, meanr) = i2t(img_embs, cap_embs, cap_lens, sims)
206 | (r1, r5, r10, medr, meanr) = i2t(npts, sims)
207 | logging.info("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" %
208 | (r1, r5, r10, medr, meanr))
209 | # image retrieval
210 | # (r1i, r5i, r10i, medri, meanr) = t2i(img_embs, cap_embs, cap_lens, sims)
211 | (r1i, r5i, r10i, medri, meanr) = t2i(npts, sims)
212 | logging.info("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" %
213 | (r1i, r5i, r10i, medri, meanr))
214 | # sum of recalls to be used for early stopping
215 | currscore = r1 + r5 + r10 + r1i + r5i + r10i
216 | logger.info('Current rsum is {}'.format(currscore))
217 |
218 |
219 | message = "Epoch: %d: Image to text: (%.1f, %.1f, %.1f) " % (epoch, r1, r5, r10)
220 | message += "Text to image: (%.1f, %.1f, %.1f) " % (r1i, r5i, r10i)
221 | message += "rsum: %.1f\n" % currscore
222 |
223 | log_file = os.path.join(opt.logger_name, "performance.log")
224 | logging_func(log_file, message)
225 |
226 | return currscore
227 |
228 |
229 | def save_checkpoint(state, is_best, filename='checkpoint.pth', prefix=''):
230 | logger = logging.getLogger(__name__)
231 | tries = 15
232 |
233 | # deal with unstable I/O. Usually not necessary.
234 | while tries:
235 | try:
236 | torch.save(state, prefix + filename)
237 | if is_best:
238 | torch.save(state, prefix + 'model_best.pth')
239 | except IOError as e:
240 | error = e
241 | tries -= 1
242 | else:
243 | break
244 | logger.info('model save {} failed, remaining {} trials'.format(filename, tries))
245 | if not tries:
246 | raise error
247 |
248 |
249 | def adjust_learning_rate(opt, optimizer, epoch, lr_schedules):
250 | """Sets the learning rate to the initial LR
251 | decayed by 10 every opt.lr_update epochs"""
252 | logger = logging.getLogger(__name__)
253 | if epoch in lr_schedules:
254 | logger.info('Current epoch num is {}, decrease all lr by a factor of 10'.format(epoch))
255 | for param_group in optimizer.param_groups:
256 | old_lr = param_group['lr']
257 | new_lr = old_lr * 0.1
258 | param_group['lr'] = new_lr
259 | logger.info('new lr {}'.format(new_lr))
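# Example with the settings in train_region_f30k.sh: --learning_rate=.0005 and --lr_update=15
# give lr_schedules = [15], so every parameter group is divided by 10 at epoch 15 (e.g. the base
# lr goes from 5e-4 to 5e-5, and the BERT group from 5e-5 to 5e-6).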
260 |
261 |
262 | def count_params(model):
263 | model_parameters = filter(lambda p: p.requires_grad, model.parameters())
264 | params = sum([np.prod(p.size()) for p in model_parameters])
265 | return params
266 |
267 |
268 | if __name__ == '__main__':
269 | os.environ["CUDA_VISIBLE_DEVICES"] = "4"
270 | main()
271 |
--------------------------------------------------------------------------------
/train_region_coco.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=2 python3 ../ESL_MAIN_BERT/train.py \
2 | --data_path ../MS-COCO/ \
3 | --data_name coco_precomp \
4 | --logger_name ../ESL_MAIN_BERT/log \
5 | --model_name ../ESL_MAIN_BERT/checkpoint \
6 | --batch_size 128 \
7 | --num_epochs=20 \
8 | --lr_update=10 \
9 | --learning_rate=.0005 \
10 | --precomp_enc_type basic \
11 | --workers 10 \
12 | --log_step 200 \
13 | --embed_size 512 \
14 | --vse_mean_warmup_epochs 1 \
15 | --kernel_size 2
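# A possible follow-up (paths assumed): train.py saves the best checkpoint as
# <model_name>/model_best.pth, so RUN_PATH in testall.py can point at
# ../ESL_MAIN_BERT/checkpoint/model_best.pth before running `python3 testall.py`.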
16 |
--------------------------------------------------------------------------------
/train_region_f30k.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=1 python3 ../ESL_MAIN_BERT/train.py \
2 | --data_path ../Flickr30K/ \
3 | --data_name f30k_precomp \
4 | --logger_name ../ESL_MAIN_BERT/log \
5 | --model_name ../ESL_MAIN_BERT/checkpoint2 \
6 | --batch_size 128 \
7 | --num_epochs=30 \
8 | --lr_update=15 \
9 | --learning_rate=.0005 \
10 | --precomp_enc_type basic \
11 | --workers 10 \
12 | --log_step 200 \
13 | --embed_size 512 \
14 | --vse_mean_warmup_epochs 1 \
15 | --kernel_size 2
16 |
--------------------------------------------------------------------------------