├── ESL.yaml ├── README.md ├── arguments.py ├── eval.py ├── eval_ensemble.py ├── lib ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── encoders.cpython-36.pyc │ ├── evaluation.cpython-36.pyc │ ├── loss.cpython-36.pyc │ └── vse.cpython-36.pyc ├── _init_paths.py ├── datasets │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ └── image_caption.cpython-36.pyc │ └── image_caption.py ├── encoders.py ├── evaluation.py ├── loss.py ├── modules │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── mlp.cpython-36.pyc │ │ └── resnet.cpython-36.pyc │ ├── aggr │ │ ├── __init__.py │ │ └── gpo.py │ ├── mlp.py │ └── resnet.py ├── pytorch-logo-dark.png └── vse.py ├── motivation.png ├── overview.png ├── runs └── runX │ └── log │ └── performance.log ├── test.py ├── test_stack.py ├── testall.py ├── train.py ├── train_region_coco.sh └── train_region_f30k.sh /ESL.yaml: -------------------------------------------------------------------------------- 1 | name: py36 2 | channels: 3 | - https://mirrors.ustc.edu.cn/anaconda/pkgs/free 4 | - pytorch 5 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge 6 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main 7 | - conda-forge 8 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/ 9 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/ 10 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ 11 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ 12 | - defaults 13 | dependencies: 14 | - _libgcc_mutex=0.1=conda_forge 15 | - _openmp_mutex=4.5=1_gnu 16 | - blas=1.0=mkl 17 | - bzip2=1.0.8=h7f98852_4 18 | - ca-certificates=2023.7.22=hbcca054_0 19 | - certifi=2016.9.26=py36_0 20 | - cudatoolkit=11.1.1=h6406543_8 21 | - dataclasses=0.8=pyh787bdff_0 22 | - ffmpeg=4.3=hf484d3e_0 23 | - freetype=2.10.4=h0708190_1 24 | - gmp=6.2.1=h58526e2_0 25 | - gnutls=3.6.13=h85f3911_1 26 | - intel-openmp=2021.2.0=h06a4308_610 27 | - jpeg=9b=0 28 | - lame=3.100=h7f98852_1001 29 | - lcms2=2.12=h3be6417_0 30 | - ld_impl_linux-64=2.35.1=hea4e1c9_2 31 | - libblas=3.9.0=9_mkl 32 | - libcblas=3.9.0=9_mkl 33 | - libffi=3.3=h58526e2_2 34 | - libgcc-ng=9.3.0=h2828fa1_19 35 | - libgomp=9.3.0=h2828fa1_19 36 | - libiconv=1.16=h516909a_0 37 | - liblapack=3.9.0=9_mkl 38 | - libpng=1.6.37=h21135ba_2 39 | - libstdcxx-ng=9.3.0=h6de172a_19 40 | - libtiff=4.1.0=h2733197_1 41 | - libuv=1.41.0=h7f98852_0 42 | - lz4-c=1.9.3=h9c3ff4c_0 43 | - mkl=2021.2.0=h06a4308_296 44 | - ncurses=6.2=h58526e2_4 45 | - nettle=3.6=he412f7d_0 46 | - ninja=1.10.2=h4bd325d_0 47 | - numpy=1.19.5=py36h2aa4a07_1 48 | - olefile=0.46=pyh9f0ad1d_1 49 | - openh264=2.1.1=h780b84a_0 50 | - openssl=1.1.1v=h7f8727e_0 51 | - pillow=8.2.0=py36he98fc37_0 52 | - pip=21.1.1=pyhd8ed1ab_0 53 | - python=3.6.13=hffdb5ce_0_cpython 54 | - python_abi=3.6=1_cp36m 55 | - pytorch=1.8.0=py3.6_cuda11.1_cudnn8.0.5_0 56 | - readline=8.1=h46c0cb4_0 57 | - ruamel_yaml=0.15.80=py36h8f6f2f9_1004 58 | - setuptools=49.6.0=py36h5fab9bb_3 59 | - sqlite=3.35.5=h74cdb3f_0 60 | - tk=8.6.10=h21135ba_1 61 | - torchaudio=0.8.0=py36 62 | - torchvision=0.9.0=py36_cu111 63 | - typing_extensions=3.7.4.3=py_0 64 | - wheel=0.36.2=pyhd3deb0d_0 65 | - xz=5.2.5=h516909a_1 66 | - yaml=0.2.5=h516909a_0 67 | - zlib=1.2.11=h516909a_1010 68 | - zstd=1.4.9=ha95c52a_0 69 | - pip: 70 | - blessed==1.20.0 71 | - boto3==1.23.10 72 | - botocore==1.26.10 73 | - cached-property==1.5.2 74 | - charset-normalizer==2.0.12 75 | - click==8.0.2 76 | - cycler==0.11.0 77 | - ftfy==3.0.1 
78 | - gpustat==1.1 79 | - h5py==3.1.0 80 | - idna==3.4 81 | - imageio==2.9.0 82 | - importlib-metadata==4.8.1 83 | - jmespath==0.10.0 84 | - joblib==1.1.0 85 | - kiwisolver==1.3.1 86 | - lmdb==1.3.0 87 | - matplotlib==3.3.4 88 | - nltk==3.6.3 89 | - nvidia-ml-py==11.525.112 90 | - opencv-python==4.7.0.72 91 | - pandas==1.1.5 92 | - protobuf==3.17.0 93 | - psutil==5.9.4 94 | - pyparsing==3.0.6 95 | - python-dateutil==2.8.2 96 | - pytz==2021.3 97 | - regex==2021.10.8 98 | - requests==2.27.1 99 | - s3transfer==0.5.2 100 | - sacremoses==0.0.53 101 | - scikit-learn==0.24.2 102 | - scipy==1.5.4 103 | - seaborn==0.11.2 104 | - sentencepiece==0.1.98 105 | - six==1.16.0 106 | - sklearn==0.0 107 | - tensorboard-logger==0.1.0 108 | - threadpoolctl==3.1.0 109 | - tqdm==4.62.3 110 | - transformers==2.1.0 111 | - urllib3==1.26.15 112 | - wcwidth==0.2.6 113 | - zipp==3.6.0 114 | prefix: /mnt/data10t/bakuphome20210617/zhangkun/anaconda3/envs/py36 115 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Enhanced Semantic Similarity Learning Framework for Image-Text Matching 2 | 3 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 4 | 5 | Official PyTorch implementation of the paper [Enhanced Semantic Similarity Learning Framework for Image-Text Matching](https://www.researchgate.net/publication/373318149_Enhanced_Semantic_Similarity_Learning_Framework_for_Image-Text_Matching). 6 | 7 | Please use the following bib entry to cite this paper if you are using any resources from the repo. 8 | 9 | ``` 10 | @article{zhang2023enhanced, 11 | author={Zhang, Kun and Hu, Bo and Zhang, Huatian and Li, Zhe and Mao, Zhendong}, 12 | journal={IEEE Transactions on Circuits and Systems for Video Technology}, 13 | title={Enhanced Semantic Similarity Learning Framework for Image-Text Matching}, 14 | year={2024}, 15 | volume={34}, 16 | number={4}, 17 | pages={2973-2988} 18 | } 19 | ``` 20 | 21 | 22 | We referred to the implementations of [GPO](https://github.com/woodfrog/vse_infty/blob/master/README.md) to build up our codebase. 23 | 24 | ## Motivation 25 |
26 | 27 | Squares denote local dimension elements in a feature. Circles denote the measure-unit, i.e., the minimal basic component used to examine semantic similarity. Compared with (a) existing methods, which typically default to a static mechanism that examines only the single-dimensional cross-modal correspondence, (b) our key idea is to dynamically capture and learn a multi-dimensional enhanced correspondence. That is, the number of dimensions constituting a measure-unit is extended from a single dimension to hierarchical multi-level groups, enriching the granularity of information the measure-units can examine and promoting more comprehensive semantic similarity learning. 28 | 29 | ## Introduction 30 | 31 | In this paper, in contrast to the single-dimensional correspondence with its limited semantic expressive capability, we propose a novel enhanced semantic similarity learning (ESL) framework, which generalizes both the measure-units and their correspondences into a dynamic learnable framework to examine the multi-dimensional enhanced correspondence between visual and textual features. Specifically, we first devise intra-modal multi-dimensional aggregators with an iterative enhancing mechanism, which dynamically capture new measure-units integrated over hierarchical multi-dimensions, producing diverse semantic combinatorial expressive capabilities that provide richer and more discriminative information for similarity examination. Then, we devise inter-modal enhanced correspondence learning with sparse contribution degrees, which comprehensively and efficiently determines the cross-modal semantic similarity. Extensive experiments verify its superiority in achieving state-of-the-art performance. 32 | 33 | ### Image-text Matching Results 34 | 35 | The following tables show partial retrieval results on the COCO and Flickr30K datasets; in each table, the first R1/R5/R10 columns report image-to-text retrieval and the second report text-to-image retrieval. In these experiments, we use BERT-base as the text encoder. This branch provides our code and pre-trained models for **using BERT as the text backbone**. Some results are better than those reported in the paper. Note, however, that the ensemble results in the paper may not be reproducible from the two best checkpoints provided here, because the original checkpoints were not saved in time. You can train the model several more times and combine any two checkpoints to find the best ensemble performance. Please check out [**the ```CLIP-based``` branch**](https://github.com/kkzhang95/ESL/blob/main/README.md) for the corresponding code and pre-trained models.
36 | 37 | #### Results of 5-fold evaluation on COCO 1K Test Split 38 | 39 | | |Visual Backbone|Text Backbone|R1|R5|R10|R1|R5|R10|Rsum|Link| 40 | |---|:---:|:---:|---|---|---|---|---|---|---|---| 41 | |ESL-H | BUTD region |BERT-base|**82.5**|**97.4**|**99.0**|**66.2**|**91.9**|**96.7**|**533.5**|[Here](https://drive.google.com/file/d/1NgTLNFGhEt14YgLb3gCkWfBp1gBxvl9w/view?usp=sharing)| 42 | |ESL-A | BUTD region |BERT-base|**82.2**|**96.9**|**98.9**|**66.5**|**92.1**|**96.7**|**533.4**|[Here](https://drive.google.com/file/d/17jaJm2DSJbF5IuUij9s3c2fupcy4CW8T/view?usp=sharing)| 43 | 44 | 45 | #### Results on COCO 5K Test Split 46 | 47 | | |Visual Backbone|Text Backbone|R1|R5|R10|R1|R5|R10|Rsum|Link| 48 | |---|:---:|:---:|---|---|---|---|---|---|---|---| 49 | |ESL-H | BUTD region |BERT-base|**63.6**|**87.4**|**93.5**|**44.2**|**74.1**|**84.0**|**446.9**|[Here](https://drive.google.com/file/d/1NgTLNFGhEt14YgLb3gCkWfBp1gBxvl9w/view?usp=sharing)| 50 | |ESL-A | BUTD region |BERT-base|**63.0**|**87.6**|**93.3**|**44.5**|**74.4**|**84.1**|**447.0**|[Here](https://drive.google.com/file/d/17jaJm2DSJbF5IuUij9s3c2fupcy4CW8T/view?usp=sharing)| 51 | 52 | 53 | #### Results on Flickr30K Test Split 54 | 55 | | |Visual Backbone|Text Backbone|R1|R5|R10|R1|R5|R10|Rsum|Link| 56 | |---|:---:|:---:|---|---|---|---|---|---|---|---| 57 | |ESL-H | BUTD region |BERT-base|**83.5**|**96.3**|**98.4**|**65.1**|**87.6**|**92.7**|**523.7**|[Here](https://drive.google.com/file/d/17FnwyH8aSOwvUuZco0lQ5eY0TM_4LXxv/view?usp=sharing)| 58 | |ESL-A | BUTD region |BERT-base|**84.3**|**96.3**|**98.0**|**64.1**|**87.4**|**92.2**|**522.4**|[Here](https://drive.google.com/file/d/1ZoPW8azNkBWVq1jaQxHfI_XpINzvmv1n/view?usp=sharing)| 59 | 60 | 61 | 62 | 63 | 64 | ## Preparation 65 | 66 | ### Environment 67 | 68 | We recommend the following dependencies. 69 | 70 | * Python 3.6 71 | * [PyTorch](http://pytorch.org/) 1.8.0 72 | * [NumPy](http://www.numpy.org/) (>1.19.5) 73 | * [TensorBoard](https://github.com/TeamHG-Memex/tensorboard_logger) 74 | * The full environment specification is provided [here](https://github.com/CrossmodalGroup/ESL/blob/main/ESL.yaml). Use **conda env create -f ESL.yaml** to create the corresponding environment. 75 | 76 | ### Data 77 | 78 | You can download the datasets through Baidu Cloud. The download links are [Flickr30K](https://pan.baidu.com/s/1Fr_bviuWLcrJ9MiiRn_H2Q) and [MSCOCO](https://pan.baidu.com/s/1vp3gtQhT7GO0PQACBSnOrQ), and the extraction code is: USTC. 79 | 80 | ## Training 81 | 82 | ```bash 83 | sh train_region_f30k.sh 84 | ``` 85 | 86 | ```bash 87 | sh train_region_coco.sh 88 | ``` 89 | For the dimensional selective mask, we design both heuristic and adaptive strategies. You can use the flag in [vse.py](https://github.com/CrossmodalGroup/ESL/blob/main/lib/vse.py) (line 44) 90 | ```python 91 | heuristic_strategy = False 92 | ``` 93 | to control which strategy is selected: True -> heuristic strategy, False -> adaptive strategy. 94 | 95 | ## Evaluation 96 | 97 | Test on Flickr30K: 98 | ```bash 99 | python test.py 100 | ``` 101 | 102 | To do cross-validation on MSCOCO, pass `fold5=True` with a model trained using 103 | `--data_name coco_precomp`.
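As a reference, the underlying call made by `eval.py` in this repo for the 5-fold COCO evaluation is sketched below; the checkpoint and data paths mirror the defaults in `eval.py` and are placeholders to adjust to your own locations.

```python
from lib import evaluation

# 5-fold cross-validation on the COCO 'testall' split
# (paths below follow the defaults in eval.py; replace them with your checkpoint/data)
evaluation.evalrank('runs/release_weights/coco_butd_region_bert/model_best.pth',
                    data_path='/tmp/data/coco', split='testall', fold5=True)
```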
104 | 105 | ```bash 106 | python testall.py 107 | ``` 108 | 109 | To ensemble model, specify the model_path in test_stack.py, and run 110 | ```bash 111 | python test_stack.py 112 | ``` 113 | 114 | 115 | -------------------------------------------------------------------------------- /arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_argument_parser(): 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('--data_path', default='../Flickr30K/', # Flickr30K MS-COCO 7 | help='path to datasets') 8 | parser.add_argument('--data_name', default='f30k_precomp', 9 | help='{coco,f30k}_precomp') 10 | parser.add_argument('--vocab_path', default='./vocab/', 11 | help='Path to saved vocabulary json files.') 12 | parser.add_argument('--margin', default=0.2, type=float, 13 | help='Rank loss margin.') 14 | parser.add_argument('--num_epochs', default=30, type=int, 15 | help='Number of training epochs.') 16 | parser.add_argument('--batch_size', default=128, type=int, 17 | help='Size of a training mini-batch.') 18 | parser.add_argument('--word_dim', default=300, type=int, 19 | help='Dimensionality of the word embedding.') 20 | parser.add_argument('--grad_clip', default=2., type=float, 21 | help='Gradient clipping threshold.') 22 | parser.add_argument('--learning_rate', default=.0005, type=float, 23 | help='Initial learning rate.') 24 | parser.add_argument('--lr_update', default=15, type=int, 25 | help='Number of epochs to update the learning rate.') 26 | parser.add_argument('--optim', default='adam', type=str, 27 | help='the optimizer') 28 | parser.add_argument('--workers', default=10, type=int, 29 | help='Number of data loader workers.') 30 | parser.add_argument('--log_step', default=500, type=int, 31 | help='Number of steps to logger.info and record the log.') 32 | parser.add_argument('--val_step', default=500, type=int, 33 | help='Number of steps to run validation.') 34 | parser.add_argument('--logger_name', default='../log/', 35 | help='Path to save Tensorboard log.') 36 | parser.add_argument('--model_name', default='../checkpoint/', 37 | help='Path to save the model.') 38 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 39 | help='path to latest checkpoint (default: none)') 40 | parser.add_argument('--max_violation', action='store_true', 41 | help='Use max instead of sum in the rank loss.') 42 | parser.add_argument('--img_dim', default=2048, type=int, 43 | help='Dimensionality of the image embedding.') 44 | parser.add_argument('--no_imgnorm', action='store_true', 45 | help='Do not normalize the image embeddings.') 46 | parser.add_argument('--no_txtnorm', action='store_true', 47 | help='Do not normalize the text embeddings.') 48 | parser.add_argument('--precomp_enc_type', default="basic", 49 | help='basic|backbone') 50 | parser.add_argument('--backbone_path', type=str, default='', 51 | help='path to the pre-trained backbone net') 52 | parser.add_argument('--backbone_source', type=str, default='detector', 53 | help='the source of the backbone model, detector|imagenet') 54 | parser.add_argument('--vse_mean_warmup_epochs', type=int, default=1, 55 | help='The number of warmup epochs using mean vse loss') 56 | parser.add_argument('--reset_start_epoch', action='store_true', 57 | help='Whether restart the start epoch when load weights') 58 | parser.add_argument('--backbone_warmup_epochs', type=int, default=5, 59 | help='The number of epochs for warmup') 60 | 
parser.add_argument('--embedding_warmup_epochs', type=int, default=2, 61 | help='The number of epochs for warming up the embedding layers') 62 | parser.add_argument('--backbone_lr_factor', default=0.01, type=float, 63 | help='The lr factor for fine-tuning the backbone, it will be multiplied to the lr of ' 64 | 'the embedding layers') 65 | parser.add_argument('--input_scale_factor', type=float, default=1, 66 | help='The factor for scaling the input image') 67 | parser.add_argument('--kernel_size', default=2, type=int, 68 | help='Dimensionality of the sim embedding.') 69 | parser.add_argument('--embed_size', default=512, type=int, 70 | help='Dimensionality of the joint embedding.') 71 | 72 | 73 | return parser 74 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import logging 4 | from lib import evaluation 5 | 6 | logging.basicConfig() 7 | logger = logging.getLogger() 8 | logger.setLevel(logging.INFO) 9 | 10 | 11 | def main(): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--dataset', default='coco', 14 | help='coco or f30k') 15 | parser.add_argument('--data_path', default='/tmp/data/coco') 16 | parser.add_argument('--save_results', action='store_true') 17 | parser.add_argument('--evaluate_cxc', action='store_true') 18 | opt = parser.parse_args() 19 | 20 | if opt.dataset == 'coco': 21 | weights_bases = [ 22 | 'runs/release_weights/coco_butd_region_bert', 23 | 'runs/release_weights/coco_butd_grid_bert', 24 | 'runs/release_weights/coco_wsl_grid_bert', 25 | ] 26 | elif opt.dataset == 'f30k': 27 | weights_bases = [ 28 | 'runs/release_weights/f30k_butd_region_bert', 29 | 'runs/release_weights/f30k_butd_grid_bert', 30 | 'runs/release_weights/f30k_wsl_grid_bert', 31 | ] 32 | else: 33 | raise ValueError('Invalid dataset argument {}'.format(opt.dataset)) 34 | 35 | for base in weights_bases: 36 | logger.info('Evaluating {}...'.format(base)) 37 | model_path = os.path.join(base, 'model_best.pth') 38 | if opt.save_results: # Save the final results for computing ensemble results 39 | save_path = os.path.join(base, 'results_{}.npy'.format(opt.dataset)) 40 | else: 41 | save_path = None 42 | 43 | if opt.dataset == 'coco': 44 | if not opt.evaluate_cxc: 45 | # Evaluate COCO 5-fold 1K 46 | evaluation.evalrank(model_path, data_path=opt.data_path, split='testall', fold5=True) 47 | # Evaluate COCO 5K 48 | evaluation.evalrank(model_path, data_path=opt.data_path, split='testall', fold5=False, save_path=save_path) 49 | else: 50 | # Evaluate COCO-trained models on CxC 51 | evaluation.evalrank(model_path, data_path=opt.data_path, split='testall', fold5=True, cxc=True) 52 | elif opt.dataset == 'f30k': 53 | # Evaluate Flickr30K 54 | evaluation.evalrank(model_path, data_path=opt.data_path, split='test', fold5=False, save_path=save_path) 55 | 56 | 57 | if __name__ == '__main__': 58 | main() 59 | -------------------------------------------------------------------------------- /eval_ensemble.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from lib import evaluation 3 | 4 | logging.basicConfig() 5 | logger = logging.getLogger() 6 | logger.setLevel(logging.INFO) 7 | 8 | # Evaluate model ensemble 9 | paths = ['runs/coco_butd_grid_bert/results_coco.npy', 10 | 'runs/coco_butd_region_bert/results_coco.npy'] 11 | 12 | evaluation.eval_ensemble(results_paths=paths, fold5=True) 13 | 
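# Note: fold5=True above reports mean metrics over the five COCO 1K folds, while the
# fold5=False call below evaluates the averaged similarity matrix once over the full 5K test split.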
evaluation.eval_ensemble(results_paths=paths, fold5=False) 14 | -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrossmodalGroup/ESL/b91e287b238fcf6dd7dc784c04e7bf1ac53404ae/lib/__init__.py -------------------------------------------------------------------------------- /lib/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrossmodalGroup/ESL/b91e287b238fcf6dd7dc784c04e7bf1ac53404ae/lib/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /lib/__pycache__/encoders.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrossmodalGroup/ESL/b91e287b238fcf6dd7dc784c04e7bf1ac53404ae/lib/__pycache__/encoders.cpython-36.pyc -------------------------------------------------------------------------------- /lib/__pycache__/evaluation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrossmodalGroup/ESL/b91e287b238fcf6dd7dc784c04e7bf1ac53404ae/lib/__pycache__/evaluation.cpython-36.pyc -------------------------------------------------------------------------------- /lib/__pycache__/loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrossmodalGroup/ESL/b91e287b238fcf6dd7dc784c04e7bf1ac53404ae/lib/__pycache__/loss.cpython-36.pyc -------------------------------------------------------------------------------- /lib/__pycache__/vse.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrossmodalGroup/ESL/b91e287b238fcf6dd7dc784c04e7bf1ac53404ae/lib/__pycache__/vse.cpython-36.pyc -------------------------------------------------------------------------------- /lib/_init_paths.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import sys 3 | 4 | 5 | def add_path(path): 6 | if path not in sys.path: 7 | sys.path.insert(0, path) 8 | 9 | 10 | root_dir = osp.abspath(osp.dirname(osp.join(__file__, '..'))) 11 | 12 | # Add lib to PYTHONPATH 13 | lib_path = osp.join(root_dir, 'lib') 14 | datasets_path = osp.join(root_dir, 'datasets') 15 | add_path(lib_path) 16 | add_path(datasets_path) 17 | -------------------------------------------------------------------------------- /lib/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrossmodalGroup/ESL/b91e287b238fcf6dd7dc784c04e7bf1ac53404ae/lib/datasets/__init__.py -------------------------------------------------------------------------------- /lib/datasets/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrossmodalGroup/ESL/b91e287b238fcf6dd7dc784c04e7bf1ac53404ae/lib/datasets/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /lib/datasets/__pycache__/image_caption.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CrossmodalGroup/ESL/b91e287b238fcf6dd7dc784c04e7bf1ac53404ae/lib/datasets/__pycache__/image_caption.cpython-36.pyc -------------------------------------------------------------------------------- /lib/datasets/image_caption.py: -------------------------------------------------------------------------------- 1 | """COCO dataset loader""" 2 | import torch 3 | import torch.utils.data as data 4 | import os 5 | import os.path as osp 6 | import numpy as np 7 | from imageio import imread 8 | import random 9 | import json 10 | import cv2 11 | 12 | import logging 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class RawImageDataset(data.Dataset): 18 | """ 19 | Load precomputed captions and image features 20 | Possible options: f30k_precomp, coco_precomp 21 | """ 22 | 23 | def __init__(self, data_path, data_name, data_split, tokenzier, opt, train): 24 | self.opt = opt 25 | self.train = train 26 | self.data_path = data_path 27 | self.data_name = data_name 28 | self.tokenizer = tokenzier 29 | 30 | loc_cap = osp.join(data_path, 'precomp') 31 | loc_image = osp.join(data_path, 'precomp') 32 | loc_mapping = osp.join(data_path, 'id_mapping.json') 33 | if 'coco' in data_name: 34 | self.image_base = osp.join(data_path, 'images') 35 | else: 36 | self.image_base = osp.join(data_path, 'flickr30k-images') 37 | 38 | with open(loc_mapping, 'r') as f_mapping: 39 | self.id_to_path = json.load(f_mapping) 40 | 41 | # Read Captions 42 | self.captions = [] 43 | with open(osp.join(loc_cap, '%s_caps.txt' % data_split), 'r') as f: 44 | for line in f: 45 | self.captions.append(line.strip()) 46 | 47 | # Get the image ids 48 | with open(osp.join(loc_image, '{}_ids.txt'.format(data_split)), 'r') as f: 49 | image_ids = f.readlines() 50 | self.images = [int(x.strip()) for x in image_ids] 51 | 52 | # Set related parameters according to the pre-trained backbone ** 53 | assert 'backbone' in opt.precomp_enc_type 54 | 55 | self.backbone_source = opt.backbone_source 56 | self.base_target_size = 256 57 | self.crop_ratio = 0.875 58 | self.train_scale_rate = 1 59 | if hasattr(opt, 'input_scale_factor') and opt.input_scale_factor != 1: 60 | self.base_target_size = int(self.base_target_size * opt.input_scale_factor) 61 | logger.info('Input images are scaled by factor {}'.format(opt.input_scale_factor)) 62 | if 'detector' in self.backbone_source: 63 | self.pixel_means = np.array([[[102.9801, 115.9465, 122.7717]]]) 64 | else: 65 | self.imagenet_mean = [0.485, 0.456, 0.406] 66 | self.imagenet_std = [0.229, 0.224, 0.225] 67 | 68 | self.length = len(self.captions) 69 | 70 | # rkiros data has redundancy in images, we divide by 5, 10crop doesn't 71 | num_images = len(self.images) 72 | 73 | if num_images != self.length: 74 | self.im_div = 5 75 | else: 76 | self.im_div = 1 77 | # the development set for coco is large and so validation would be slow 78 | if data_split == 'dev': 79 | self.length = 5000 80 | 81 | def __getitem__(self, index): 82 | img_index = index // self.im_div 83 | caption = self.captions[index] 84 | caption_tokens = self.tokenizer.basic_tokenizer.tokenize(caption) 85 | 86 | # Convert caption (string) to word ids (with Size Augmentation at training time). 
87 | target = process_caption(self.tokenizer, caption_tokens, self.train) 88 | 89 | image_id = self.images[img_index] 90 | image_path = os.path.join(self.image_base, self.id_to_path[str(image_id)]) 91 | im_in = np.array(imread(image_path)) 92 | processed_image = self._process_image(im_in) 93 | image = torch.Tensor(processed_image) 94 | image = image.permute(2, 0, 1) 95 | return image, target, index, img_index 96 | 97 | def __len__(self): 98 | return self.length 99 | 100 | def _process_image(self, im_in): 101 | """ 102 | Converts an image into a network input, with pre-processing including re-scaling, padding, etc, and data 103 | augmentation. 104 | """ 105 | if len(im_in.shape) == 2: 106 | im_in = im_in[:, :, np.newaxis] 107 | im_in = np.concatenate((im_in, im_in, im_in), axis=2) 108 | 109 | if 'detector' in self.backbone_source: 110 | im_in = im_in[:, :, ::-1] 111 | im = im_in.astype(np.float32, copy=True) 112 | 113 | if self.train: 114 | target_size = self.base_target_size * self.train_scale_rate 115 | else: 116 | target_size = self.base_target_size 117 | 118 | # 2. Random crop when in training mode, elsewise just skip 119 | if self.train: 120 | crop_ratio = np.random.random() * 0.4 + 0.6 121 | crop_size_h = int(im.shape[0] * crop_ratio) 122 | crop_size_w = int(im.shape[1] * crop_ratio) 123 | processed_im = self._crop(im, crop_size_h, crop_size_w, random=True) 124 | else: 125 | processed_im = im 126 | 127 | # 3. Resize to the target resolution 128 | im_shape = processed_im.shape 129 | im_scale_x = float(target_size) / im_shape[1] 130 | im_scale_y = float(target_size) / im_shape[0] 131 | processed_im = cv2.resize(processed_im, None, None, fx=im_scale_x, fy=im_scale_y, 132 | interpolation=cv2.INTER_LINEAR) 133 | 134 | if self.train: 135 | if np.random.random() > 0.5: 136 | processed_im = self._hori_flip(processed_im) 137 | 138 | # Normalization 139 | if 'detector' in self.backbone_source: 140 | processed_im = self._detector_norm(processed_im) 141 | else: 142 | processed_im = self._imagenet_norm(processed_im) 143 | 144 | return processed_im 145 | 146 | def _imagenet_norm(self, im_in): 147 | im_in = im_in.astype(np.float32) 148 | im_in = im_in / 255 149 | for i in range(im_in.shape[-1]): 150 | im_in[:, :, i] = (im_in[:, :, i] - self.imagenet_mean[i]) / self.imagenet_std[i] 151 | return im_in 152 | 153 | def _detector_norm(self, im_in): 154 | im_in = im_in.astype(np.float32) 155 | im_in -= self.pixel_means 156 | return im_in 157 | 158 | @staticmethod 159 | def _crop(im, crop_size_h, crop_size_w, random): 160 | h, w = im.shape[0], im.shape[1] 161 | if random: 162 | if w - crop_size_w == 0: 163 | x_start = 0 164 | else: 165 | x_start = np.random.randint(w - crop_size_w, size=1)[0] 166 | if h - crop_size_h == 0: 167 | y_start = 0 168 | else: 169 | y_start = np.random.randint(h - crop_size_h, size=1)[0] 170 | else: 171 | x_start = (w - crop_size_w) // 2 172 | y_start = (h - crop_size_h) // 2 173 | 174 | cropped_im = im[y_start:y_start + crop_size_h, x_start:x_start + crop_size_w, :] 175 | 176 | return cropped_im 177 | 178 | @staticmethod 179 | def _hori_flip(im): 180 | im = np.fliplr(im).copy() 181 | return im 182 | 183 | 184 | class PrecompRegionDataset(data.Dataset): 185 | """ 186 | Load precomputed captions and image features for COCO or Flickr 187 | """ 188 | 189 | def __init__(self, data_path, data_name, data_split, tokenizer, opt, train): 190 | self.tokenizer = tokenizer 191 | self.opt = opt 192 | self.train = train 193 | self.data_path = data_path 194 | self.data_name = data_name 195 | 196 
| loc_cap = osp.join(data_path, data_name) 197 | loc_image = osp.join(data_path, data_name) 198 | 199 | # Captions 200 | self.captions = [] 201 | with open(osp.join(loc_cap, '%s_caps.txt' % data_split), 'r') as f: 202 | for line in f: 203 | self.captions.append(line.strip()) 204 | # Image features 205 | self.images = np.load(os.path.join(loc_image, '%s_ims.npy' % data_split)) 206 | 207 | self.length = len(self.captions) 208 | # rkiros data has redundancy in images, we divide by 5, 10crop doesn't 209 | num_images = len(self.images) 210 | 211 | if num_images != self.length: 212 | self.im_div = 5 213 | else: 214 | self.im_div = 1 215 | # the development set for coco is large and so validation would be slow 216 | if data_split == 'dev': 217 | self.length = 5000 218 | 219 | def __getitem__(self, index): 220 | # handle the image redundancy 221 | img_index = index // self.im_div 222 | caption = self.captions[index] 223 | caption_tokens = self.tokenizer.basic_tokenizer.tokenize(caption) 224 | 225 | # Convert caption (string) to word ids (with Size Augmentation at training time) 226 | target = process_caption(self.tokenizer, caption_tokens, self.train) 227 | image = self.images[img_index] 228 | if self.train: # Size augmentation for region feature 229 | num_features = image.shape[0] 230 | rand_list = np.random.rand(num_features) 231 | image = image[np.where(rand_list > 0.20)] 232 | image = torch.Tensor(image) 233 | return image, target, index, img_index 234 | 235 | def __len__(self): 236 | return self.length 237 | 238 | 239 | def process_caption(tokenizer, tokens, train=True): 240 | output_tokens = [] 241 | deleted_idx = [] 242 | 243 | for i, token in enumerate(tokens): 244 | sub_tokens = tokenizer.wordpiece_tokenizer.tokenize(token) 245 | prob = random.random() 246 | 247 | if prob < 0.20 and train: # mask/remove the tokens only during training 248 | prob /= 0.20 249 | 250 | # 50% randomly change token to mask token 251 | if prob < 0.5: 252 | for sub_token in sub_tokens: 253 | output_tokens.append("[MASK]") 254 | # 10% randomly change token to random token 255 | elif prob < 0.6: 256 | for sub_token in sub_tokens: 257 | output_tokens.append(random.choice(list(tokenizer.vocab.keys()))) 258 | # -> rest 10% randomly keep current token 259 | else: 260 | for sub_token in sub_tokens: 261 | output_tokens.append(sub_token) 262 | deleted_idx.append(len(output_tokens) - 1) 263 | else: 264 | for sub_token in sub_tokens: 265 | # no masking token (will be ignored by loss function later) 266 | output_tokens.append(sub_token) 267 | 268 | if len(deleted_idx) != 0: 269 | output_tokens = [output_tokens[i] for i in range(len(output_tokens)) if i not in deleted_idx] 270 | 271 | output_tokens = ['[CLS]'] + output_tokens + ['[SEP]'] 272 | target = tokenizer.convert_tokens_to_ids(output_tokens) 273 | target = torch.Tensor(target) 274 | return target 275 | 276 | 277 | def collate_fn(data): 278 | """Build mini-batch tensors from a list of (image, caption) tuples. 279 | Args: 280 | data: list of (image, caption) tuple. 281 | - image: torch tensor of shape (3, 256, 256). 282 | - caption: torch tensor of shape (?); variable length. 283 | 284 | Returns: 285 | images: torch tensor of shape (batch_size, 3, 256, 256). 286 | targets: torch tensor of shape (batch_size, padded_length). 287 | lengths: list; valid length for each padded caption. 
288 | """ 289 | images, captions, ids, img_ids = zip(*data) 290 | if len(images[0].shape) == 2: # region feature 291 | # Sort a data list by caption length 292 | # Merge images (convert tuple of 3D tensor to 4D tensor) 293 | # images = torch.stack(images, 0) 294 | img_lengths = [len(image) for image in images] 295 | all_images = torch.zeros(len(images), max(img_lengths), images[0].size(-1)) 296 | for i, image in enumerate(images): 297 | end = img_lengths[i] 298 | all_images[i, :end] = image[:end] 299 | img_lengths = torch.Tensor(img_lengths) 300 | 301 | # Merget captions (convert tuple of 1D tensor to 2D tensor) 302 | lengths = [len(cap) for cap in captions] 303 | targets = torch.zeros(len(captions), max(lengths)).long() 304 | 305 | for i, cap in enumerate(captions): 306 | end = lengths[i] 307 | targets[i, :end] = cap[:end] 308 | 309 | return all_images, img_lengths, targets, lengths, ids 310 | else: # raw input image 311 | # Merge images (convert tuple of 3D tensor to 4D tensor) 312 | images = torch.stack(images, 0) 313 | 314 | # Merget captions (convert tuple of 1D tensor to 2D tensor) 315 | lengths = [len(cap) for cap in captions] 316 | targets = torch.zeros(len(captions), max(lengths)).long() 317 | for i, cap in enumerate(captions): 318 | end = lengths[i] 319 | targets[i, :end] = cap[:end] 320 | return images, targets, lengths, ids 321 | 322 | 323 | def get_loader(data_path, data_name, data_split, tokenizer, opt, batch_size=100, 324 | shuffle=True, num_workers=2, train=True): 325 | """Returns torch.utils.data.DataLoader for custom coco dataset.""" 326 | if train: 327 | drop_last = True 328 | else: 329 | drop_last = False 330 | if opt.precomp_enc_type == 'basic': 331 | dset = PrecompRegionDataset(data_path, data_name, data_split, tokenizer, opt, train) 332 | data_loader = torch.utils.data.DataLoader(dataset=dset, 333 | batch_size=batch_size, 334 | shuffle=shuffle, 335 | pin_memory=True, 336 | collate_fn=collate_fn, 337 | num_workers=num_workers, 338 | drop_last=drop_last) 339 | else: 340 | dset = RawImageDataset(data_path, data_name, data_split, tokenizer, opt, train) 341 | data_loader = torch.utils.data.DataLoader(dataset=dset, 342 | batch_size=batch_size, 343 | shuffle=shuffle, 344 | num_workers=num_workers, 345 | pin_memory=True, 346 | collate_fn=collate_fn) 347 | return data_loader 348 | 349 | 350 | def get_loaders(data_path, data_name, tokenizer, batch_size, workers, opt): 351 | train_loader = get_loader(data_path, data_name, 'train', tokenizer, opt, 352 | batch_size, True, workers) 353 | val_loader = get_loader(data_path, data_name, 'test', tokenizer, opt, 354 | batch_size, False, workers, train=False) 355 | return train_loader, val_loader 356 | 357 | 358 | def get_train_loader(data_path, data_name, tokenizer, batch_size, workers, opt, shuffle): 359 | train_loader = get_loader(data_path, data_name, 'train', tokenizer, opt, 360 | batch_size, shuffle, workers) 361 | return train_loader 362 | 363 | 364 | def get_test_loader(split_name, data_name, tokenizer, batch_size, workers, opt): 365 | 366 | # opt.data_path = '/mnt/data10t/bakuphome20210617/lz/data/I-T/Flickr30K/' 367 | # data_name = 'f30k_precomp' 368 | # split_name = 'test' 369 | test_loader = get_loader(opt.data_path, data_name, split_name, tokenizer, opt, 370 | batch_size, False, workers, train=False) 371 | return test_loader 372 | -------------------------------------------------------------------------------- /lib/encoders.py: -------------------------------------------------------------------------------- 1 | """VSE 
modules""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | # from collections import OrderedDict 7 | 8 | from transformers import BertModel 9 | 10 | from lib.modules.resnet import ResnetFeatureExtractor 11 | # from lib.modules.aggr.gpo import GPO 12 | from lib.modules.mlp import MLP 13 | 14 | import logging 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | def l1norm(X, dim, eps=1e-8): 20 | """L1-normalize columns of X 21 | """ 22 | norm = torch.abs(X).sum(dim=dim, keepdim=True) + eps 23 | X = torch.div(X, norm) 24 | return X 25 | 26 | 27 | def l2norm(X, dim, eps=1e-8): 28 | """L2-normalize columns of X 29 | """ 30 | norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps 31 | X = torch.div(X, norm) 32 | return X 33 | 34 | 35 | def maxk_pool1d_var(x, dim, k, lengths): 36 | results = list() 37 | lengths = list(lengths.cpu().numpy()) 38 | lengths = [int(x) for x in lengths] 39 | for idx, length in enumerate(lengths): 40 | k = min(k, length) 41 | max_k_i = maxk(x[idx, :length, :], dim - 1, k).mean(dim - 1) 42 | results.append(max_k_i) 43 | results = torch.stack(results, dim=0) 44 | return results 45 | 46 | 47 | def maxk_pool1d(x, dim, k): 48 | max_k = maxk(x, dim, k) 49 | return max_k.mean(dim) 50 | 51 | 52 | def maxk(x, dim, k): 53 | index = x.topk(k, dim=dim)[1] 54 | return x.gather(dim, index) 55 | 56 | 57 | def get_text_encoder(embed_size, no_txtnorm=False): 58 | return EncoderText(embed_size, no_txtnorm=no_txtnorm) 59 | 60 | 61 | def get_image_encoder(data_name, img_dim, embed_size, precomp_enc_type='basic', 62 | backbone_source=None, backbone_path=None, no_imgnorm=False): 63 | """A wrapper to image encoders. Chooses between an different encoders 64 | that uses precomputed image features. 65 | """ 66 | if precomp_enc_type == 'basic': 67 | img_enc = EncoderImageAggr( 68 | img_dim, embed_size, precomp_enc_type, no_imgnorm) 69 | elif precomp_enc_type == 'backbone': 70 | backbone_cnn = ResnetFeatureExtractor(backbone_source, backbone_path, fixed_blocks=2) 71 | img_enc = EncoderImageFull(backbone_cnn, img_dim, embed_size, precomp_enc_type, no_imgnorm) 72 | else: 73 | raise ValueError("Unknown precomp_enc_type: {}".format(precomp_enc_type)) 74 | 75 | return img_enc 76 | 77 | 78 | class EncoderImageAggr(nn.Module): 79 | def __init__(self, img_dim, embed_size, precomp_enc_type='basic', no_imgnorm=False): 80 | super(EncoderImageAggr, self).__init__() 81 | self.embed_size = embed_size 82 | self.no_imgnorm = no_imgnorm 83 | self.fc = nn.Linear(img_dim, embed_size) 84 | self.precomp_enc_type = precomp_enc_type 85 | if precomp_enc_type == 'basic': 86 | self.mlp = MLP(img_dim, embed_size // 2, embed_size, 2) 87 | # self.gpool = GPO(32, 32) 88 | self.init_weights() 89 | 90 | def init_weights(self): 91 | """Xavier initialization for the fully connected layer 92 | """ 93 | r = np.sqrt(6.) 
/ np.sqrt(self.fc.in_features + 94 | self.fc.out_features) 95 | self.fc.weight.data.uniform_(-r, r) 96 | self.fc.bias.data.fill_(0) 97 | 98 | def forward(self, images, image_lengths): 99 | """Extract image feature vectors.""" 100 | features = self.fc(images) 101 | if self.precomp_enc_type == 'basic': 102 | # When using pre-extracted region features, add an extra MLP for the embedding transformation 103 | features = self.mlp(images) + features 104 | 105 | # features, pool_weights = self.gpool(features, image_lengths) 106 | 107 | if not self.no_imgnorm: 108 | features = l2norm(features, dim=-1) 109 | 110 | return features 111 | 112 | 113 | class EncoderImageFull(nn.Module): 114 | def __init__(self, backbone_cnn, img_dim, embed_size, precomp_enc_type='basic', no_imgnorm=False): 115 | super(EncoderImageFull, self).__init__() 116 | self.backbone = backbone_cnn 117 | self.image_encoder = EncoderImageAggr(img_dim, embed_size, precomp_enc_type, no_imgnorm) 118 | self.backbone_freezed = False 119 | 120 | def forward(self, images): 121 | """Extract image feature vectors.""" 122 | base_features = self.backbone(images) 123 | 124 | if self.training: 125 | # Size Augmentation during training, randomly drop grids 126 | base_length = base_features.size(1) 127 | features = [] 128 | feat_lengths = [] 129 | rand_list_1 = np.random.rand(base_features.size(0), base_features.size(1)) 130 | rand_list_2 = np.random.rand(base_features.size(0)) 131 | for i in range(base_features.size(0)): 132 | if rand_list_2[i] > 0.2: 133 | feat_i = base_features[i][np.where(rand_list_1[i] > 0.20 * rand_list_2[i])] 134 | len_i = len(feat_i) 135 | pads_i = torch.zeros(base_length - len_i, base_features.size(-1)).to(base_features.device) 136 | feat_i = torch.cat([feat_i, pads_i], dim=0) 137 | else: 138 | feat_i = base_features[i] 139 | len_i = base_length 140 | feat_lengths.append(len_i) 141 | features.append(feat_i) 142 | base_features = torch.stack(features, dim=0) 143 | base_features = base_features[:, :max(feat_lengths), :] 144 | feat_lengths = torch.tensor(feat_lengths).to(base_features.device) 145 | else: 146 | feat_lengths = torch.zeros(base_features.size(0)).to(base_features.device) 147 | feat_lengths[:] = base_features.size(1) 148 | 149 | features = self.image_encoder(base_features, feat_lengths) 150 | 151 | return features 152 | 153 | def freeze_backbone(self): 154 | for param in self.backbone.parameters(): 155 | param.requires_grad = False 156 | logger.info('Backbone freezed.') 157 | 158 | def unfreeze_backbone(self, fixed_blocks): 159 | for param in self.backbone.parameters(): # open up all params first, then adjust the base parameters 160 | param.requires_grad = True 161 | self.backbone.set_fixed_blocks(fixed_blocks) 162 | self.backbone.unfreeze_base() 163 | logger.info('Backbone unfreezed, fixed blocks {}'.format(self.backbone.get_fixed_blocks())) 164 | 165 | 166 | # Language Model with BERT 167 | class EncoderText(nn.Module): 168 | def __init__(self, embed_size, no_txtnorm=False): 169 | super(EncoderText, self).__init__() 170 | self.embed_size = embed_size 171 | self.no_txtnorm = no_txtnorm 172 | 173 | self.bert = BertModel.from_pretrained('bert-base-uncased') 174 | self.linear = nn.Linear(768, embed_size) 175 | # self.gpool = GPO(32, 32) 176 | # self.dropout = nn.Dropout(0.4) 177 | 178 | def forward(self, x, lengths): 179 | """Handles variable size captions 180 | """ 181 | # Embed word ids to vectors 182 | bert_attention_mask = (x != 0).float() 183 | bert_emb = self.bert(x, bert_attention_mask)[0] # B x N x D 184 | 
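# Note: GPO pooling is commented out in this encoder, so it returns per-token embeddings
# of shape (B, N, D) rather than a pooled sentence vector; these token-level features are
# consumed later by the similarity computation (see shard_attn_scores in lib/evaluation.py).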
185 | cap_len = lengths 186 | # bert_emb = self.dropout(bert_emb) 187 | 188 | cap_emb = self.linear(bert_emb) 189 | 190 | # pooled_features, pool_weights = self.gpool(cap_emb, cap_len.to(cap_emb.device)) 191 | 192 | # normalization in the joint embedding space 193 | if not self.no_txtnorm: 194 | cap_emb = l2norm(cap_emb, dim=-1) 195 | 196 | return cap_emb 197 | -------------------------------------------------------------------------------- /lib/evaluation.py: -------------------------------------------------------------------------------- 1 | """Evaluation""" 2 | from __future__ import print_function 3 | import logging 4 | import time 5 | import torch 6 | import numpy as np 7 | 8 | import sys 9 | 10 | from collections import OrderedDict 11 | from transformers import BertTokenizer 12 | 13 | from lib.datasets import image_caption 14 | from lib.vse import VSEModel 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | class AverageMeter(object): 20 | """Computes and stores the average and current value""" 21 | 22 | def __init__(self): 23 | self.reset() 24 | 25 | def reset(self): 26 | self.val = 0 27 | self.avg = 0 28 | self.sum = 0 29 | self.count = 0 30 | 31 | def update(self, val, n=0): 32 | self.val = val 33 | self.sum += val * n 34 | self.count += n 35 | self.avg = self.sum / (.0001 + self.count) 36 | 37 | def __str__(self): 38 | """String representation for logging 39 | """ 40 | # for values that should be recorded exactly e.g. iteration number 41 | if self.count == 0: 42 | return str(self.val) 43 | # for stats 44 | return '%.4f (%.4f)' % (self.val, self.avg) 45 | 46 | 47 | class LogCollector(object): 48 | """A collection of logging objects that can change from train to val""" 49 | 50 | def __init__(self): 51 | # to keep the order of logged variables deterministic 52 | self.meters = OrderedDict() 53 | 54 | def update(self, k, v, n=0): 55 | # create a new meter if previously not recorded 56 | if k not in self.meters: 57 | self.meters[k] = AverageMeter() 58 | self.meters[k].update(v, n) 59 | 60 | def __str__(self): 61 | """Concatenate the meters in one log line 62 | """ 63 | s = '' 64 | for i, (k, v) in enumerate(self.meters.items()): 65 | if i > 0: 66 | s += ' ' 67 | s += k + ' ' + str(v) 68 | return s 69 | 70 | def tb_log(self, tb_logger, prefix='', step=None): 71 | """Log using tensorboard 72 | """ 73 | for k, v in self.meters.items(): 74 | tb_logger.log_value(prefix + k, v.val, step=step) 75 | 76 | 77 | def encode_data(model, data_loader, log_step=10, logging=logger.info, backbone=False): 78 | """Encode all images and captions loadable by `data_loader` 79 | """ 80 | val_logger = LogCollector() 81 | 82 | # switch to evaluate mode 83 | model.val_start() 84 | 85 | # np array to keep all the embeddings 86 | img_embs = None 87 | cap_embs = None 88 | cap_lens = None 89 | 90 | max_n_word = 0 91 | for i, (images, img_lengths, captions, lengths, _) in enumerate(data_loader): 92 | max_n_word = max(max_n_word, max(lengths)) 93 | 94 | for i, data_i in enumerate(data_loader): 95 | # make sure val logger is used 96 | if not backbone: 97 | images, image_lengths, captions, lengths, ids = data_i 98 | else: 99 | images, captions, lengths, ids = data_i 100 | model.logger = val_logger 101 | 102 | # compute the embeddings 103 | if not backbone: 104 | img_emb, cap_emb, cap_len = model.forward_emb(images, captions, lengths, image_lengths=image_lengths) 105 | else: 106 | img_emb, cap_emb, cap_len = model.forward_emb(images, captions, lengths) 107 | 108 | if img_embs is None: 109 | if img_emb.dim() == 
3: 110 | img_embs = np.zeros((len(data_loader.dataset), img_emb.size(1), img_emb.size(2))) 111 | else: 112 | img_embs = np.zeros((len(data_loader.dataset), img_emb.size(1))) 113 | # cap_embs = np.zeros((len(data_loader.dataset), cap_emb.size(1))) 114 | cap_embs = np.zeros((len(data_loader.dataset), max_n_word, cap_emb.size(2))) 115 | cap_lens = [0] * len(data_loader.dataset) 116 | 117 | # cache embeddings 118 | img_embs[ids] = img_emb.data.cpu().numpy().copy() 119 | # cap_embs[ids, :] = cap_emb.data.cpu().numpy().copy() 120 | cap_embs[ids, :max(lengths), :] = cap_emb.data.cpu().numpy().copy() 121 | 122 | for j, nid in enumerate(ids): 123 | cap_lens[nid] = cap_len[j] 124 | 125 | del images, captions 126 | return img_embs, cap_embs, cap_lens 127 | 128 | 129 | def eval_ensemble(results_paths, fold5=False): 130 | all_sims = [] 131 | all_npts = [] 132 | for sim_path in results_paths: 133 | results = np.load(sim_path, allow_pickle=True).tolist() 134 | npts = results['npts'] 135 | sims = results['sims'] 136 | all_npts.append(npts) 137 | all_sims.append(sims) 138 | all_npts = np.array(all_npts) 139 | all_sims = np.array(all_sims) 140 | assert np.all(all_npts == all_npts[0]) 141 | npts = int(all_npts[0]) 142 | sims = all_sims.mean(axis=0) 143 | 144 | if not fold5: 145 | r, rt = i2t(npts, sims, return_ranks=True) 146 | ri, rti = t2i(npts, sims, return_ranks=True) 147 | ar = (r[0] + r[1] + r[2]) / 3 148 | ari = (ri[0] + ri[1] + ri[2]) / 3 149 | rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2] 150 | logger.info("rsum: %.1f" % rsum) 151 | logger.info("Average i2t Recall: %.1f" % ar) 152 | logger.info("Image to text: %.1f %.1f %.1f %.1f %.1f" % r) 153 | logger.info("Average t2i Recall: %.1f" % ari) 154 | logger.info("Text to image: %.1f %.1f %.1f %.1f %.1f" % ri) 155 | else: 156 | npts = npts // 5 157 | results = [] 158 | all_sims = sims.copy() 159 | for i in range(5): 160 | sims = all_sims[i * npts: (i + 1) * npts, i * npts * 5: (i + 1) * npts * 5] 161 | r, rt0 = i2t(npts, sims, return_ranks=True) 162 | logger.info("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % r) 163 | ri, rti0 = t2i(npts, sims, return_ranks=True) 164 | logger.info("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % ri) 165 | 166 | if i == 0: 167 | rt, rti = rt0, rti0 168 | ar = (r[0] + r[1] + r[2]) / 3 169 | ari = (ri[0] + ri[1] + ri[2]) / 3 170 | rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2] 171 | logger.info("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari)) 172 | results += [list(r) + list(ri) + [ar, ari, rsum]] 173 | logger.info("-----------------------------------") 174 | logger.info("Mean metrics: ") 175 | mean_metrics = tuple(np.array(results).mean(axis=0).flatten()) 176 | logger.info("rsum: %.1f" % (mean_metrics[12])) 177 | logger.info("Average i2t Recall: %.1f" % mean_metrics[10]) 178 | logger.info("Image to text: %.1f %.1f %.1f %.1f %.1f" % 179 | mean_metrics[:5]) 180 | logger.info("Average t2i Recall: %.1f" % mean_metrics[11]) 181 | logger.info("Text to image: %.1f %.1f %.1f %.1f %.1f" % 182 | mean_metrics[5:10]) 183 | 184 | 185 | def evalrank(model_path, data_path=None, split='dev', fold5=False, save_path=None, cxc=False): 186 | """ 187 | Evaluate a trained model on either dev or test. If `fold5=True`, 5 fold 188 | cross-validation is done (only for MSCOCO). Otherwise, the full data is 189 | used for evaluation. 
190 | """ 191 | # load model and options 192 | checkpoint = torch.load(model_path) 193 | opt = checkpoint['opt'] 194 | opt.workers = 5 195 | 196 | logger.info(opt) 197 | if not hasattr(opt, 'caption_loss'): 198 | opt.caption_loss = False 199 | 200 | # load vocabulary used by the model 201 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 202 | vocab = tokenizer.vocab 203 | opt.vocab_size = len(vocab) 204 | 205 | opt.backbone_path = '/tmp/data/weights/original_updown_backbone.pth' 206 | if data_path is not None: 207 | opt.data_path = data_path 208 | 209 | # construct model 210 | model = VSEModel(opt) 211 | 212 | model.make_data_parallel() 213 | # load model state 214 | model.load_state_dict(checkpoint['model']) 215 | model.val_start() 216 | 217 | logger.info('Loading dataset') 218 | data_loader = image_caption.get_test_loader(split, opt.data_name, tokenizer, 219 | opt.batch_size, opt.workers, opt) 220 | 221 | logger.info('Computing results...') 222 | with torch.no_grad(): 223 | if opt.precomp_enc_type == 'basic': 224 | # img_embs, cap_embs = encode_data(model, data_loader) 225 | img_embs, cap_embs, cap_lens = encode_data(model, data_loader) 226 | else: 227 | # img_embs, cap_embs = encode_data(model, data_loader, backbone=True) 228 | img_embs, cap_embs, cap_lens = encode_data(model, data_loader, backbone=True) 229 | 230 | logger.info('Images: %d, Captions: %d' % 231 | (img_embs.shape[0] / 5, cap_embs.shape[0])) 232 | 233 | if cxc: 234 | eval_cxc(img_embs, cap_embs, data_path) 235 | else: 236 | if not fold5: 237 | # no cross-validation, full evaluation 238 | img_embs = np.array([img_embs[i] for i in range(0, len(img_embs), 5)]) 239 | 240 | # sims = compute_sim(img_embs, cap_embs) 241 | start = time.time() 242 | sims = shard_attn_scores(model, img_embs, cap_embs, cap_lens, opt, shard_size=1000) 243 | npts = img_embs.shape[0] 244 | end = time.time() 245 | print("calculate similarity time:", end - start) 246 | 247 | np.savetxt('/mnt/data2/zk/ESL_bert/' + 'sim_best.txt', sims, fmt='%.5f') 248 | 249 | 250 | if save_path is not None: 251 | np.save(save_path, {'npts': npts, 'sims': sims}) 252 | logger.info('Save the similarity into {}'.format(save_path)) 253 | 254 | end = time.time() 255 | logger.info("calculate similarity time: {}".format(end - start)) 256 | end = time.time() 257 | print("calculate similarity time:", end - start) 258 | 259 | r, rt = i2t(npts, sims, return_ranks=True) 260 | ri, rti = t2i(npts, sims, return_ranks=True) 261 | ar = (r[0] + r[1] + r[2]) / 3 262 | ari = (ri[0] + ri[1] + ri[2]) / 3 263 | rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2] 264 | 265 | print("rsum: %.1f" % rsum) 266 | print("Average i2t Recall: %.1f" % ar) 267 | print("Image to text: %.1f %.1f %.1f %.1f %.1f" % r) 268 | print("Average t2i Recall: %.1f" % ari) 269 | print("Text to image: %.1f %.1f %.1f %.1f %.1f" % ri) 270 | else: 271 | # 5fold cross-validation, only for MSCOCO 272 | results = [] 273 | for i in range(5): 274 | img_embs_shard = img_embs[i * 5000:(i + 1) * 5000:5] 275 | 276 | cap_embs_shard = cap_embs[i * 5000:(i + 1) * 5000] 277 | cap_lens_shard = cap_lens[i * 5000:(i + 1) * 5000] 278 | start = time.time() 279 | # sims = compute_sim(img_embs_shard, cap_embs_shard) 280 | 281 | sims = shard_attn_scores(model, img_embs_shard, cap_embs_shard, cap_lens_shard, opt, shard_size=1000) 282 | end = time.time() 283 | # logger.info("calculate similarity time: {}".format(end - start)) 284 | 285 | print("calculate similarity time:", end - start) 286 | 287 | np.savetxt('/mnt/data2/zk/ESL_bert/' 
+ str(i) + 'sim_best.txt', sims, fmt='%.5f') 288 | 289 | npts = img_embs_shard.shape[0] 290 | r, rt0 = i2t(npts, sims, return_ranks=True) 291 | print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % r) 292 | ri, rti0 = t2i(npts, sims, return_ranks=True) 293 | print("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % ri) 294 | 295 | if i == 0: 296 | rt, rti = rt0, rti0 297 | ar = (r[0] + r[1] + r[2]) / 3 298 | ari = (ri[0] + ri[1] + ri[2]) / 3 299 | rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2] 300 | print("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari)) 301 | results += [list(r) + list(ri) + [ar, ari, rsum]] 302 | 303 | print("-----------------------------------") 304 | print("Mean metrics: ") 305 | mean_metrics = tuple(np.array(results).mean(axis=0).flatten()) 306 | print("rsum: %.1f" % (mean_metrics[12])) 307 | print("Average i2t Recall: %.1f" % mean_metrics[11]) 308 | print("Image to text: %.1f %.1f %.1f %.1f %.1f" % 309 | mean_metrics[:5]) 310 | print("Average t2i Recall: %.1f" % mean_metrics[12]) 311 | print("Text to image: %.1f %.1f %.1f %.1f %.1f" % 312 | mean_metrics[5:10]) 313 | 314 | 315 | def compute_sim(images, captions): 316 | similarities = np.matmul(images, np.matrix.transpose(captions)) 317 | return similarities 318 | 319 | 320 | def shard_attn_scores(model, img_embs, cap_embs, cap_lens, opt, shard_size=100): 321 | n_im_shard = (len(img_embs) - 1) // shard_size + 1 322 | n_cap_shard = (len(cap_embs) - 1) // shard_size + 1 323 | 324 | sims = np.zeros((len(img_embs), len(cap_embs))) 325 | 326 | for i in range(n_im_shard): 327 | im_start, im_end = shard_size * i, min(shard_size * (i + 1), len(img_embs)) 328 | for j in range(n_cap_shard): 329 | sys.stdout.write('\r>> shard_attn_scores batch (%d,%d)' % (i, j)) 330 | ca_start, ca_end = shard_size * j, min(shard_size * (j + 1), len(cap_embs)) 331 | 332 | with torch.no_grad(): 333 | im = torch.from_numpy(img_embs[im_start:im_end]).float().cuda() 334 | ca = torch.from_numpy(cap_embs[ca_start:ca_end]).float().cuda() 335 | l = cap_lens[ca_start:ca_end] 336 | sim, A = model.forward_sim_test(im, ca, l) 337 | 338 | sims[im_start:im_end, ca_start:ca_end] = sim.data.cpu().numpy() 339 | sys.stdout.write('\n') 340 | return sims 341 | 342 | 343 | def i2t(npts, sims, return_ranks=False, mode='coco'): 344 | """ 345 | Images->Text (Image Annotation) 346 | Images: (N, n_region, d) matrix of images 347 | Captions: (5N, max_n_word, d) matrix of captions 348 | CapLens: (5N) array of caption lengths 349 | sims: (N, 5N) matrix of similarity im-cap 350 | """ 351 | ranks = np.zeros(npts) 352 | top1 = np.zeros(npts) 353 | for index in range(npts): 354 | inds = np.argsort(sims[index])[::-1] 355 | if mode == 'coco': 356 | rank = 1e20 357 | for i in range(5 * index, 5 * index + 5, 1): 358 | tmp = np.where(inds == i)[0][0] 359 | if tmp < rank: 360 | rank = tmp 361 | ranks[index] = rank 362 | top1[index] = inds[0] 363 | else: 364 | rank = np.where(inds == index)[0][0] 365 | ranks[index] = rank 366 | top1[index] = inds[0] 367 | 368 | # Compute metrics 369 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 370 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 371 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 372 | medr = np.floor(np.median(ranks)) + 1 373 | meanr = ranks.mean() + 1 374 | 375 | if return_ranks: 376 | return (r1, r5, r10, medr, meanr), (ranks, top1) 377 | else: 378 | return (r1, r5, r10, medr, meanr) 379 | 380 | 381 | def t2i(npts, sims, return_ranks=False, mode='coco'): 382 | """ 383 | Text->Images (Image Search) 384 | 
Images: (N, n_region, d) matrix of images 385 | Captions: (5N, max_n_word, d) matrix of captions 386 | CapLens: (5N) array of caption lengths 387 | sims: (N, 5N) matrix of similarity im-cap 388 | """ 389 | # npts = images.shape[0] 390 | 391 | if mode == 'coco': 392 | ranks = np.zeros(5 * npts) 393 | top1 = np.zeros(5 * npts) 394 | else: 395 | ranks = np.zeros(npts) 396 | top1 = np.zeros(npts) 397 | 398 | # --> (5N(caption), N(image)) 399 | sims = sims.T 400 | 401 | for index in range(npts): 402 | if mode == 'coco': 403 | for i in range(5): 404 | inds = np.argsort(sims[5 * index + i])[::-1] 405 | ranks[5 * index + i] = np.where(inds == index)[0][0] 406 | top1[5 * index + i] = inds[0] 407 | else: 408 | inds = np.argsort(sims[index])[::-1] 409 | ranks[index] = np.where(inds == index)[0][0] 410 | top1[index] = inds[0] 411 | 412 | # Compute metrics 413 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 414 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 415 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 416 | medr = np.floor(np.median(ranks)) + 1 417 | meanr = ranks.mean() + 1 418 | if return_ranks: 419 | return (r1, r5, r10, medr, meanr), (ranks, top1) 420 | else: 421 | return (r1, r5, r10, medr, meanr) 422 | 423 | 424 | """ 425 | CxC related evaluation. 426 | """ 427 | 428 | def eval_cxc(images, captions, data_path): 429 | import os 430 | import json 431 | cxc_annot_base = os.path.join(data_path, 'cxc_annots') 432 | img_id_path = os.path.join(cxc_annot_base, 'testall_ids.txt') 433 | cap_id_path = os.path.join(cxc_annot_base, 'testall_capids.txt') 434 | 435 | images = images[::5, :] 436 | 437 | with open(img_id_path) as f: 438 | img_ids = f.readlines() 439 | with open(cap_id_path) as f: 440 | cap_ids = f.readlines() 441 | 442 | img_ids = [img_id.strip() for i, img_id in enumerate(img_ids) if i % 5 == 0] 443 | cap_ids = [cap_id.strip() for cap_id in cap_ids] 444 | 445 | with open(os.path.join(cxc_annot_base, 'cxc_it.json')) as f_it: 446 | cxc_it = json.load(f_it) 447 | with open(os.path.join(cxc_annot_base, 'cxc_i2i.json')) as f_i2i: 448 | cxc_i2i = json.load(f_i2i) 449 | with open(os.path.join(cxc_annot_base, 'cxc_t2t.json')) as f_t2t: 450 | cxc_t2t = json.load(f_t2t) 451 | 452 | sims = compute_sim(images, captions) 453 | t2i_recalls = cxc_inter(sims.T, img_ids, cap_ids, cxc_it['t2i']) 454 | i2t_recalls = cxc_inter(sims, cap_ids, img_ids, cxc_it['i2t']) 455 | logger.info('T2I R@1: {}, R@5: {}, R@10: {}'.format(*t2i_recalls)) 456 | logger.info('I2T R@1: {}, R@5: {}, R@10: {}'.format(*i2t_recalls)) 457 | 458 | i2i_recalls = cxc_intra(images, img_ids, cxc_i2i) 459 | t2t_recalls = cxc_intra(captions, cap_ids, cxc_t2t, text=True) 460 | logger.info('I2I R@1: {}, R@5: {}, R@10: {}'.format(*i2i_recalls)) 461 | logger.info('T2T R@1: {}, R@5: {}, R@10: {}'.format(*t2t_recalls)) 462 | 463 | 464 | def cxc_inter(sims, data_ids, query_ids, annot): 465 | ranks = list() 466 | for idx, query_id in enumerate(query_ids): 467 | if query_id not in annot: 468 | raise ValueError('unexpected query id {}'.format(query_id)) 469 | pos_data_ids = annot[query_id] 470 | pos_data_ids = [pos_data_id for pos_data_id in pos_data_ids if str(pos_data_id[0]) in data_ids] 471 | pos_data_indices = [data_ids.index(str(pos_data_id[0])) for pos_data_id in pos_data_ids] 472 | rank = 1e20 473 | inds = np.argsort(sims[idx])[::-1] 474 | for pos_data_idx in pos_data_indices: 475 | tmp = np.where(inds == pos_data_idx)[0][0] 476 | if tmp < rank: 477 | rank = tmp 478 | ranks.append(rank) 479 | ranks = np.array(ranks) 
480 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 481 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 482 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 483 | return (r1, r5, r10) 484 | 485 | 486 | def cxc_intra(embs, data_ids, annot, text=False): 487 | pos_thresh = 3.0 if text else 2.5 # threshold for positive pairs according to the CxC paper 488 | 489 | sims = compute_sim(embs, embs) 490 | np.fill_diagonal(sims, 0) 491 | 492 | ranks = list() 493 | for idx, data_id in enumerate(data_ids): 494 | sim_items = annot[data_id] 495 | pos_items = [item for item in sim_items if item[1] >= pos_thresh] 496 | rank = 1e20 497 | inds = np.argsort(sims[idx])[::-1] 498 | if text: 499 | coco_pos = list(range(idx // 5 * 5, (idx // 5 + 1) * 5)) 500 | coco_pos.remove(idx) 501 | pos_indices = coco_pos 502 | pos_indices.extend([data_ids.index(str(pos_item[0])) for pos_item in pos_items]) 503 | else: 504 | pos_indices = [data_ids.index(str(pos_item[0])) for pos_item in pos_items] 505 | if len(pos_indices) == 0: # skip it since there is positive example in the annotation 506 | continue 507 | for pos_idx in pos_indices: 508 | tmp = np.where(inds == pos_idx)[0][0] 509 | if tmp < rank: 510 | rank = tmp 511 | ranks.append(rank) 512 | 513 | ranks = np.array(ranks) 514 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 515 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 516 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 517 | return (r1, r5, r10) 518 | -------------------------------------------------------------------------------- /lib/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | 5 | 6 | class ContrastiveLoss(nn.Module): 7 | """ 8 | Compute contrastive loss (max-margin based) 9 | """ 10 | 11 | def __init__(self, opt, margin=0, max_violation=False): 12 | super(ContrastiveLoss, self).__init__() 13 | self.opt = opt 14 | self.margin = margin 15 | self.max_violation = max_violation 16 | 17 | def max_violation_on(self): 18 | self.max_violation = True 19 | print('Use VSE++ objective.') 20 | 21 | def max_violation_off(self): 22 | self.max_violation = False 23 | print('Use VSE0 objective.') 24 | 25 | def forward(self, scores): 26 | # compute image-sentence score matrix 27 | # scores = get_sim(im, s) 28 | # diagonal = scores.diag().view(im.size(0), 1) 29 | 30 | diagonal = scores.diag().view(scores.size(0), 1) 31 | d1 = diagonal.expand_as(scores) 32 | d2 = diagonal.t().expand_as(scores) 33 | 34 | # compare every diagonal score to scores in its column 35 | # caption retrieval 36 | cost_s = (self.margin + scores - d1).clamp(min=0) 37 | # compare every diagonal score to scores in its row 38 | # image retrieval 39 | cost_im = (self.margin + scores - d2).clamp(min=0) 40 | 41 | # clear diagonals 42 | mask = torch.eye(scores.size(0)) > .5 43 | I = Variable(mask) 44 | if torch.cuda.is_available(): 45 | I = I.cuda() 46 | cost_s = cost_s.masked_fill_(I, 0) 47 | cost_im = cost_im.masked_fill_(I, 0) 48 | 49 | # keep the maximum violating negative for each query 50 | if self.max_violation: 51 | cost_s = cost_s.max(1)[0] 52 | cost_im = cost_im.max(0)[0] 53 | 54 | return cost_s.sum() + cost_im.sum() 55 | 56 | 57 | def get_sim(images, captions): 58 | similarities = images.mm(captions.t()) 59 | return similarities 60 | 61 | -------------------------------------------------------------------------------- /lib/modules/__init__.py: 
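A minimal usage sketch for the ContrastiveLoss defined in lib/loss.py above (the margin value and the random embeddings are placeholders, and `opt` is stored but not used inside the loss); it expects a square image-caption score matrix whose diagonal holds the matched pairs:

import torch
import torch.nn.functional as F
from lib.loss import ContrastiveLoss, get_sim

criterion = ContrastiveLoss(opt=None, margin=0.2, max_violation=True)
img = F.normalize(torch.randn(8, 1024), dim=-1)   # 8 L2-normalized image embeddings
cap = F.normalize(torch.randn(8, 1024), dim=-1)   # the 8 matching caption embeddings
loss = criterion(get_sim(img, cap))               # hinge loss over the hardest negatives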
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrossmodalGroup/ESL/b91e287b238fcf6dd7dc784c04e7bf1ac53404ae/lib/modules/__init__.py -------------------------------------------------------------------------------- /lib/modules/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrossmodalGroup/ESL/b91e287b238fcf6dd7dc784c04e7bf1ac53404ae/lib/modules/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /lib/modules/__pycache__/mlp.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrossmodalGroup/ESL/b91e287b238fcf6dd7dc784c04e7bf1ac53404ae/lib/modules/__pycache__/mlp.cpython-36.pyc -------------------------------------------------------------------------------- /lib/modules/__pycache__/resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrossmodalGroup/ESL/b91e287b238fcf6dd7dc784c04e7bf1ac53404ae/lib/modules/__pycache__/resnet.cpython-36.pyc -------------------------------------------------------------------------------- /lib/modules/aggr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrossmodalGroup/ESL/b91e287b238fcf6dd7dc784c04e7bf1ac53404ae/lib/modules/aggr/__init__.py -------------------------------------------------------------------------------- /lib/modules/aggr/gpo.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 6 | 7 | 8 | def positional_encoding_1d(d_model, length): 9 | """ 10 | :param d_model: dimension of the model 11 | :param length: length of positions 12 | :return: length*d_model position matrix 13 | """ 14 | if d_model % 2 != 0: 15 | raise ValueError("Cannot use sin/cos positional encoding with " 16 | "odd dim (got dim={:d})".format(d_model)) 17 | pe = torch.zeros(length, d_model) 18 | position = torch.arange(0, length).unsqueeze(1) 19 | div_term = torch.exp((torch.arange(0, d_model, 2, dtype=torch.float) * 20 | -(math.log(10000.0) / d_model))) 21 | pe[:, 0::2] = torch.sin(position.float() * div_term) 22 | pe[:, 1::2] = torch.cos(position.float() * div_term) 23 | 24 | return pe 25 | 26 | 27 | class GPO(nn.Module): 28 | def __init__(self, d_pe, d_hidden): 29 | super(GPO, self).__init__() 30 | self.d_pe = d_pe 31 | self.d_hidden = d_hidden 32 | 33 | self.pe_database = {} 34 | self.gru = nn.GRU(self.d_pe, d_hidden, 1, batch_first=True, bidirectional=True) 35 | self.linear = nn.Linear(self.d_hidden, 1, bias=False) 36 | 37 | def compute_pool_weights(self, lengths, features): 38 | max_len = int(lengths.max()) 39 | pe_max_len = self.get_pe(max_len) 40 | pes = pe_max_len.unsqueeze(0).repeat(lengths.size(0), 1, 1).to(lengths.device) 41 | mask = torch.arange(max_len).expand(lengths.size(0), max_len).to(lengths.device) 42 | mask = (mask < lengths.long().unsqueeze(1)).unsqueeze(-1) 43 | pes = pes.masked_fill(mask == 0, 0) 44 | 45 | self.gru.flatten_parameters() 46 | packed = pack_padded_sequence(pes, lengths.cpu(), batch_first=True, enforce_sorted=False) 47 | out, _ = self.gru(packed) 48 | padded = pad_packed_sequence(out, 
batch_first=True) 49 | out_emb, out_len = padded 50 | out_emb = (out_emb[:, :, :out_emb.size(2) // 2] + out_emb[:, :, out_emb.size(2) // 2:]) / 2 51 | scores = self.linear(out_emb) 52 | scores[torch.where(mask == 0)] = -10000 53 | 54 | weights = torch.softmax(scores / 0.1, 1) 55 | return weights, mask 56 | 57 | def forward(self, features, lengths): 58 | """ 59 | :param features: features with shape B x K x D 60 | :param lengths: B x 1, specify the length of each data sample. 61 | :return: pooled feature with shape B x D 62 | """ 63 | pool_weights, mask = self.compute_pool_weights(lengths, features) 64 | 65 | features = features[:, :int(lengths.max()), :] 66 | sorted_features = features.masked_fill(mask == 0, -10000) 67 | sorted_features = sorted_features.sort(dim=1, descending=True)[0] 68 | sorted_features = sorted_features.masked_fill(mask == 0, 0) 69 | 70 | pooled_features = (sorted_features * pool_weights).sum(1) 71 | return pooled_features, pool_weights 72 | 73 | def get_pe(self, length): 74 | """ 75 | 76 | :param length: the length of the sequence 77 | :return: the positional encoding of the given length 78 | """ 79 | length = int(length) 80 | if length in self.pe_database: 81 | return self.pe_database[length] 82 | else: 83 | pe = positional_encoding_1d(self.d_pe, length) 84 | self.pe_database[length] = pe 85 | return pe 86 | -------------------------------------------------------------------------------- /lib/modules/mlp.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class TwoLayerMLP(nn.Module): 6 | def __init__(self, num_features, hid_dim, out_dim, return_hidden=False): 7 | super().__init__() 8 | self.return_hidden = return_hidden 9 | self.model = nn.Sequential( 10 | nn.Linear(num_features, hid_dim), 11 | nn.ReLU(), 12 | nn.Linear(hid_dim, out_dim), 13 | ) 14 | 15 | for m in self.model: 16 | if isinstance(m, nn.Linear): 17 | nn.init.kaiming_normal_(m.weight) 18 | nn.init.constant_(m.bias, 0) 19 | 20 | def forward(self, x): 21 | if not self.return_hidden: 22 | return self.model(x) 23 | else: 24 | hid_feat = self.model[:2](x) 25 | results = self.model[2:](hid_feat) 26 | return hid_feat, results 27 | 28 | 29 | class MLP(nn.Module): 30 | """ Very simple multi-layer perceptron (also called FFN)""" 31 | 32 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers): 33 | super().__init__() 34 | self.output_dim = output_dim 35 | self.num_layers = num_layers 36 | h = [hidden_dim] * (num_layers - 1) 37 | self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) 38 | self.bns = nn.ModuleList(nn.BatchNorm1d(k) for k in h + [output_dim]) 39 | 40 | def forward(self, x): 41 | B, N, D = x.size() 42 | x = x.reshape(B*N, D) 43 | for i, (bn, layer) in enumerate(zip(self.bns, self.layers)): 44 | x = F.relu(bn(layer(x))) if i < self.num_layers - 1 else layer(x) 45 | x = x.view(B, N, self.output_dim) 46 | return x 47 | 48 | -------------------------------------------------------------------------------- /lib/modules/resnet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | 6 | import torch.utils.model_zoo as model_zoo 7 | import logging 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | __all__ = ['ResNet', 'resnet50', 'resnet101', 'resnet152'] 12 | 13 | model_urls = { 14 | 'resnet50': 
'https://s3.amazonaws.com/pytorch/models/resnet50-19c8e357.pth', 15 | 'resnet101': 'https://s3.amazonaws.com/pytorch/models/resnet101-5d3b4d8f.pth', 16 | 'resnet152': 'https://s3.amazonaws.com/pytorch/models/resnet152-b121ed2d.pth', 17 | } 18 | 19 | 20 | def conv3x3(in_planes, out_planes, stride=1): 21 | "3x3 convolution with padding" 22 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 23 | padding=1, bias=False) 24 | 25 | 26 | class BasicBlock(nn.Module): 27 | expansion = 1 28 | 29 | def __init__(self, inplanes, planes, stride=1, downsample=None): 30 | super(BasicBlock, self).__init__() 31 | self.conv1 = conv3x3(inplanes, planes, stride) 32 | self.bn1 = nn.BatchNorm2d(planes) 33 | self.relu = nn.ReLU(inplace=True) 34 | self.conv2 = conv3x3(planes, planes) 35 | self.bn2 = nn.BatchNorm2d(planes) 36 | self.downsample = downsample 37 | self.stride = stride 38 | 39 | def forward(self, x): 40 | residual = x 41 | 42 | out = self.conv1(x) 43 | out = self.bn1(out) 44 | out = self.relu(out) 45 | 46 | out = self.conv2(out) 47 | out = self.bn2(out) 48 | 49 | if self.downsample is not None: 50 | residual = self.downsample(x) 51 | 52 | out += residual 53 | out = self.relu(out) 54 | 55 | return out 56 | 57 | 58 | class Bottleneck(nn.Module): 59 | expansion = 4 60 | 61 | def __init__(self, inplanes, planes, stride=1, downsample=None): 62 | super(Bottleneck, self).__init__() 63 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) # change 64 | self.bn1 = nn.BatchNorm2d(planes) 65 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, # change 66 | padding=1, bias=False) 67 | self.bn2 = nn.BatchNorm2d(planes) 68 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 69 | self.bn3 = nn.BatchNorm2d(planes * 4) 70 | self.relu = nn.ReLU(inplace=True) 71 | self.downsample = downsample 72 | self.stride = stride 73 | 74 | def forward(self, x): 75 | residual = x 76 | 77 | out = self.conv1(x) 78 | out = self.bn1(out) 79 | out = self.relu(out) 80 | 81 | out = self.conv2(out) 82 | out = self.bn2(out) 83 | out = self.relu(out) 84 | 85 | out = self.conv3(out) 86 | out = self.bn3(out) 87 | 88 | if self.downsample is not None: 89 | residual = self.downsample(x) 90 | 91 | out += residual 92 | out = self.relu(out) 93 | 94 | return out 95 | 96 | 97 | class ResNet(nn.Module): 98 | def __init__(self, block, layers, width_mult, num_classes=1000): 99 | self.inplanes = 64 * width_mult 100 | super(ResNet, self).__init__() 101 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 102 | bias=False) 103 | self.bn1 = nn.BatchNorm2d(self.inplanes) 104 | self.relu = nn.ReLU(inplace=True) 105 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) # change 106 | self.layer1 = self._make_layer(block, 64 * width_mult, layers[0]) 107 | self.layer2 = self._make_layer(block, 128 * width_mult, layers[1], stride=2) 108 | self.layer3 = self._make_layer(block, 256 * width_mult, layers[2], stride=2) 109 | self.layer4 = self._make_layer(block, 512 * width_mult, layers[3], stride=2) 110 | 111 | self.avgpool = nn.AvgPool2d(7) 112 | self.fc = nn.Linear(512 * block.expansion * width_mult, num_classes) 113 | 114 | for m in self.modules(): 115 | if isinstance(m, nn.Conv2d): 116 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 117 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 118 | elif isinstance(m, nn.BatchNorm2d): 119 | m.weight.data.fill_(1) 120 | m.bias.data.zero_() 121 | 122 | def _make_layer(self, block, planes, blocks, stride=1): 123 | downsample = None 124 | if stride != 1 or self.inplanes != planes * block.expansion: 125 | downsample = nn.Sequential( 126 | nn.Conv2d(self.inplanes, planes * block.expansion, 127 | kernel_size=1, stride=stride, bias=False), 128 | nn.BatchNorm2d(planes * block.expansion), 129 | ) 130 | 131 | layers = [] 132 | layers.append(block(self.inplanes, planes, stride, downsample)) 133 | self.inplanes = planes * block.expansion 134 | for i in range(1, blocks): 135 | layers.append(block(self.inplanes, planes)) 136 | 137 | return nn.Sequential(*layers) 138 | 139 | def forward(self, x): 140 | x = self.conv1(x) 141 | x = self.bn1(x) 142 | x = self.relu(x) 143 | x = self.maxpool(x) 144 | 145 | x = self.layer1(x) 146 | x = self.layer2(x) 147 | x = self.layer3(x) 148 | x = self.layer4(x) 149 | 150 | x = self.avgpool(x) 151 | x = x.view(x.size(0), -1) 152 | x = self.fc(x) 153 | 154 | return x 155 | 156 | 157 | def resnet50(pretrained=False, width_mult=1): 158 | """Constructs a ResNet-50 model. 159 | Args: 160 | pretrained (bool): If True, returns a model pre-trained on ImageNet 161 | """ 162 | model = ResNet(Bottleneck, [3, 4, 6, 3], width_mult) 163 | if pretrained: 164 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 165 | return model 166 | 167 | 168 | def resnet101(pretrained=False, width_mult=1): 169 | """Constructs a ResNet-101 model. 170 | Args: 171 | pretrained (bool): If True, returns a model pre-trained on ImageNet 172 | """ 173 | model = ResNet(Bottleneck, [3, 4, 23, 3], width_mult) 174 | if pretrained: 175 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 176 | return model 177 | 178 | 179 | def resnet152(pretrained=False, width_mult=1): 180 | """Constructs a ResNet-152 model. 181 | Args: 182 | pretrained (bool): If True, returns a model pre-trained on ImageNet 183 | """ 184 | model = ResNet(Bottleneck, [3, 8, 36, 3], width_mult) 185 | if pretrained: 186 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 187 | return model 188 | 189 | 190 | class ResnetFeatureExtractor(nn.Module): 191 | def __init__(self, backbone_source, weights_path, pooling_size=7, fixed_blocks=2): 192 | super(ResnetFeatureExtractor, self).__init__() 193 | self.backbone_source = backbone_source 194 | self.weights_path = weights_path 195 | self.pooling_size = pooling_size 196 | self.fixed_blocks = fixed_blocks 197 | 198 | if 'detector' in self.backbone_source: 199 | self.resnet = resnet101() 200 | elif self.backbone_source == 'imagenet': 201 | self.resnet = resnet101(pretrained=True) 202 | elif self.backbone_source == 'imagenet_res50': 203 | self.resnet = resnet50(pretrained=True) 204 | elif self.backbone_source == 'imagenet_res152': 205 | self.resnet = resnet152(pretrained=True) 206 | elif self.backbone_source == 'imagenet_resnext': 207 | self.resnet = torch.hub.load('pytorch/vision:v0.4.2', 'resnext101_32x8d', pretrained=True) 208 | elif 'wsl' in self.backbone_source: 209 | self.resnet = torch.hub.load('facebookresearch/WSL-Images', 'resnext101_32x8d_wsl') 210 | else: 211 | raise ValueError('Unknown backbone source {}'.format(self.backbone_source)) 212 | 213 | self._init_modules() 214 | 215 | def _init_modules(self): 216 | # Build resnet. 
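# The backbone is split detector-style below: `base` holds the stem plus
# layer1-3 (a stride-16 feature map) and `top` holds layer4, so detector
# checkpoints that store their weights under separate 'base' and 'top' keys
# can be loaded further down in this method.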
217 | self.base = nn.Sequential(self.resnet.conv1, self.resnet.bn1, self.resnet.relu, 218 | self.resnet.maxpool, self.resnet.layer1, self.resnet.layer2, self.resnet.layer3) 219 | self.top = nn.Sequential(self.resnet.layer4) 220 | 221 | if self.weights_path != '': 222 | if 'detector' in self.backbone_source: 223 | if os.path.exists(self.weights_path): 224 | logger.info( 225 | 'Loading pretrained backbone weights from {} for backbone source {}'.format(self.weights_path, 226 | self.backbone_source)) 227 | backbone_ckpt = torch.load(self.weights_path) 228 | self.base.load_state_dict(backbone_ckpt['base']) 229 | self.top.load_state_dict(backbone_ckpt['top']) 230 | else: 231 | raise ValueError('Could not find weights for backbone CNN at {}'.format(self.weights_path)) 232 | else: 233 | logger.info('Did not load external checkpoints') 234 | self.unfreeze_base() 235 | 236 | def set_fixed_blocks(self, fixed_blocks): 237 | self.fixed_blocks = fixed_blocks 238 | 239 | def get_fixed_blocks(self): 240 | return self.fixed_blocks 241 | 242 | def unfreeze_base(self): 243 | assert (0 <= self.fixed_blocks < 4) 244 | if self.fixed_blocks == 3: 245 | for p in self.base[6].parameters(): p.requires_grad = False 246 | for p in self.base[5].parameters(): p.requires_grad = False 247 | for p in self.base[4].parameters(): p.requires_grad = False 248 | for p in self.base[0].parameters(): p.requires_grad = False 249 | for p in self.base[1].parameters(): p.requires_grad = False 250 | if self.fixed_blocks == 2: 251 | for p in self.base[6].parameters(): p.requires_grad = True 252 | for p in self.base[5].parameters(): p.requires_grad = False 253 | for p in self.base[4].parameters(): p.requires_grad = False 254 | for p in self.base[0].parameters(): p.requires_grad = False 255 | for p in self.base[1].parameters(): p.requires_grad = False 256 | if self.fixed_blocks == 1: 257 | for p in self.base[6].parameters(): p.requires_grad = True 258 | for p in self.base[5].parameters(): p.requires_grad = True 259 | for p in self.base[4].parameters(): p.requires_grad = False 260 | for p in self.base[0].parameters(): p.requires_grad = False 261 | for p in self.base[1].parameters(): p.requires_grad = False 262 | if self.fixed_blocks == 0: 263 | for p in self.base[6].parameters(): p.requires_grad = True 264 | for p in self.base[5].parameters(): p.requires_grad = True 265 | for p in self.base[4].parameters(): p.requires_grad = True 266 | for p in self.base[0].parameters(): p.requires_grad = True 267 | for p in self.base[1].parameters(): p.requires_grad = True 268 | 269 | logger.info('Resnet backbone now has fixed blocks {}'.format(self.fixed_blocks)) 270 | 271 | def freeze_base(self): 272 | for p in self.base.parameters(): 273 | p.requires_grad = False 274 | 275 | def train(self, mode=True): 276 | # Override train so that the training mode is set as we want (BN does not update the running stats) 277 | nn.Module.train(self, mode) 278 | if mode: 279 | # fix all bn layers 280 | def set_bn_eval(m): 281 | classname = m.__class__.__name__ 282 | if classname.find('BatchNorm') != -1: 283 | m.eval() 284 | 285 | self.base.apply(set_bn_eval) 286 | self.top.apply(set_bn_eval) 287 | 288 | def _head_to_tail(self, pool5): 289 | fc7 = self.top(pool5).mean(3).mean(2) 290 | return fc7 291 | 292 | def forward(self, im_data): 293 | b_s = im_data.size(0) 294 | base_feat = self.base(im_data) 295 | top_feat = self.top(base_feat) 296 | features = top_feat.view(b_s, top_feat.size(1), -1).permute(0, 2, 1) 297 | return features 298 | 299 | 300 | if __name__ == 
'__main__': 301 | import numpy as np 302 | 303 | def count_params(model): 304 | model_parameters = model.parameters() 305 | params = sum([np.prod(p.size()) for p in model_parameters]) 306 | return params 307 | 308 | model = resnet50(pretrained=False, width_mult=1) 309 | num_params = count_params(model) 310 | 311 | -------------------------------------------------------------------------------- /lib/pytorch-logo-dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrossmodalGroup/ESL/b91e287b238fcf6dd7dc784c04e7bf1ac53404ae/lib/pytorch-logo-dark.png -------------------------------------------------------------------------------- /lib/vse.py: -------------------------------------------------------------------------------- 1 | """ESL model, building on the top of VSE model""" 2 | import numpy as np 3 | import os 4 | import torch 5 | import torch.nn as nn 6 | import math 7 | import torch.nn.functional as F 8 | 9 | import torch.nn.init 10 | import torch.backends.cudnn as cudnn 11 | from torch.nn.utils import clip_grad_norm_ 12 | 13 | from lib.encoders import get_image_encoder, get_text_encoder 14 | from lib.loss import ContrastiveLoss 15 | 16 | import logging 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | def logging_func(log_file, message): 21 | with open(log_file,'a') as f: 22 | f.write(message) 23 | f.close() 24 | 25 | def l2norm(X, dim=-1, eps=1e-8): 26 | """L2-normalize columns of X""" 27 | norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps 28 | X = torch.div(X, norm) 29 | return X 30 | 31 | 32 | def cosine_similarity(x1, x2, dim=1, eps=1e-8): 33 | """Returns cosine similarity between x1 and x2, computed along dim.""" 34 | w12 = torch.sum(x1 * x2, dim) 35 | w1 = torch.norm(x1, 2, dim) 36 | w2 = torch.norm(x2, 2, dim) 37 | return (w12 / (w1 * w2).clamp(min=eps)).squeeze() 38 | 39 | 40 | ######################################################## 41 | ### For the dimensional selective mask, we design both heuristic and adaptive strategies. 42 | ### You can use this flag to control which strategy is selected. 
True -> heuristic strategy, False -> adaptive strategy 43 | 44 | heuristic_strategy = False 45 | ######################################################## 46 | 47 | 48 | if heuristic_strategy: 49 | 50 | ### Heuristic Dimensional Selective Mask 51 | class Image_levels(nn.Module): 52 | def __init__(self, opt): 53 | super(Image_levels, self).__init__() 54 | self.sub_space = opt.embed_size 55 | self.kernel_size = int(opt.kernel_size) 56 | 57 | self.kernel_img_1 = nn.Conv1d(in_channels=1, out_channels=4, kernel_size=self.kernel_size, stride=self.kernel_size, bias=False) 58 | self.kernel_img_2 = nn.Conv1d(in_channels=1, out_channels=4, kernel_size=self.kernel_size, stride=self.kernel_size, bias=False) 59 | self.kernel_img_3 = nn.Conv1d(in_channels=1, out_channels=4, kernel_size=self.kernel_size, stride=self.kernel_size, bias=False) 60 | self.kernel_img_4 = nn.Conv1d(in_channels=1, out_channels=4, kernel_size=self.kernel_size, stride=self.kernel_size, bias=False) 61 | self.kernel_img_5 = nn.Conv1d(in_channels=1, out_channels=4, kernel_size=self.kernel_size, stride=self.kernel_size, bias=False) 62 | self.kernel_img_6 = nn.Conv1d(in_channels=1, out_channels=4, kernel_size=self.kernel_size, stride=self.kernel_size, bias=False) 63 | self.kernel_img_7 = nn.Conv1d(in_channels=1, out_channels=4, kernel_size=self.kernel_size, stride=self.kernel_size, bias=False) 64 | self.kernel_img_8 = nn.Conv1d(in_channels=1, out_channels=4, kernel_size=self.kernel_size, stride=self.kernel_size, bias=False) 65 | 66 | 67 | def get_image_levels(self, img_emb, batch_size, n_region): 68 | img_emb_1 = self.kernel_img_1(img_emb.reshape(-1, self.sub_space).unsqueeze(-2)).sum(1) 69 | img_emb_1 = l2norm(img_emb_1.reshape(batch_size, n_region, -1), -1) 70 | 71 | emb_size = img_emb_1.size(-1) 72 | img_emb_2 = self.kernel_img_2(img_emb_1.reshape(-1, emb_size).unsqueeze(-2)).sum(1) 73 | img_emb_2 = l2norm(img_emb_2.reshape(batch_size, n_region, -1), -1) 74 | 75 | emb_size = img_emb_2.size(-1) 76 | img_emb_3 = self.kernel_img_3(img_emb_2.reshape(-1, emb_size).unsqueeze(-2)).sum(1) 77 | img_emb_3 = l2norm(img_emb_3.reshape(batch_size, n_region, -1), -1) 78 | 79 | emb_size = img_emb_3.size(-1) 80 | img_emb_4 = self.kernel_img_4(img_emb_3.reshape(-1, emb_size).unsqueeze(-2)).sum(1) 81 | img_emb_4 = l2norm(img_emb_4.reshape(batch_size, n_region, -1), -1) 82 | 83 | emb_size = img_emb_4.size(-1) 84 | img_emb_5 = self.kernel_img_5(img_emb_4.reshape(-1, emb_size).unsqueeze(-2)).sum(1) 85 | img_emb_5 = l2norm(img_emb_5.reshape(batch_size, n_region, -1), -1) 86 | 87 | emb_size = img_emb_5.size(-1) 88 | img_emb_6 = self.kernel_img_6(img_emb_5.reshape(-1, emb_size).unsqueeze(-2)).sum(1) 89 | img_emb_6 = l2norm(img_emb_6.reshape(batch_size, n_region, -1), -1) 90 | 91 | emb_size = img_emb_6.size(-1) 92 | img_emb_7 = self.kernel_img_7(img_emb_6.reshape(-1, emb_size).unsqueeze(-2)).sum(1) 93 | img_emb_7 = l2norm(img_emb_7.reshape(batch_size, n_region, -1), -1) 94 | 95 | emb_size = img_emb_7.size(-1) 96 | img_emb_8 = self.kernel_img_8(img_emb_7.reshape(-1, emb_size).unsqueeze(-2)).sum(1) 97 | img_emb_8 = l2norm(img_emb_8.reshape(batch_size, n_region, -1), -1) 98 | 99 | 100 | 101 | return torch.cat([img_emb, img_emb_1, img_emb_2, img_emb_3, img_emb_4, img_emb_5, img_emb_6, img_emb_7, img_emb_8], -1) 102 | 103 | def forward(self, img_emb): 104 | 105 | batch_size, n_region, embed_size = img_emb.size(0), img_emb.size(1), img_emb.size(2) 106 | 107 | return self.get_image_levels(img_emb, batch_size, n_region) 108 | 109 | 110 | class Text_levels(nn.Module): 
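"""Heuristic counterpart of Image_levels above, applied to word features: each
level feeds the previous one through a stride-`kernel_size` Conv1d (summed over
its 4 output channels), shrinking the embedding dimension by a factor of
kernel_size per level, and the original embedding plus all eight reduced levels
are concatenated along the last dimension."""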
111 | 112 | def __init__(self, opt): 113 | super(Text_levels, self).__init__() 114 | self.sub_space = opt.embed_size 115 | self.kernel_size = int(opt.kernel_size) 116 | 117 | self.kernel_txt_1 = nn.Conv1d(in_channels=1, out_channels=4, kernel_size=self.kernel_size, stride=self.kernel_size, bias=False) 118 | self.kernel_txt_2 = nn.Conv1d(in_channels=1, out_channels=4, kernel_size=self.kernel_size, stride=self.kernel_size, bias=False) 119 | self.kernel_txt_3 = nn.Conv1d(in_channels=1, out_channels=4, kernel_size=self.kernel_size, stride=self.kernel_size, bias=False) 120 | self.kernel_txt_4 = nn.Conv1d(in_channels=1, out_channels=4, kernel_size=self.kernel_size, stride=self.kernel_size, bias=False) 121 | self.kernel_txt_5 = nn.Conv1d(in_channels=1, out_channels=4, kernel_size=self.kernel_size, stride=self.kernel_size, bias=False) 122 | self.kernel_txt_6 = nn.Conv1d(in_channels=1, out_channels=4, kernel_size=self.kernel_size, stride=self.kernel_size, bias=False) 123 | self.kernel_txt_7 = nn.Conv1d(in_channels=1, out_channels=4, kernel_size=self.kernel_size, stride=self.kernel_size, bias=False) 124 | self.kernel_txt_8 = nn.Conv1d(in_channels=1, out_channels=4, kernel_size=self.kernel_size, stride=self.kernel_size, bias=False) 125 | 126 | 127 | def get_text_levels(self, cap_i, batch_size, n_word): 128 | cap_i_1 = self.kernel_txt_1(cap_i.reshape(-1, self.sub_space).unsqueeze(-2)).sum(1) 129 | cap_i_expand_1 = l2norm(cap_i_1.reshape(batch_size, n_word, -1), -1) 130 | 131 | emb_size = cap_i_expand_1.size(-1) 132 | cap_i_2 = self.kernel_txt_2(cap_i_1.reshape(-1, emb_size).unsqueeze(-2)).sum(1) 133 | cap_i_expand_2 = l2norm(cap_i_2.reshape(batch_size, n_word, -1), -1) 134 | 135 | emb_size = cap_i_expand_2.size(-1) 136 | cap_i_3 = self.kernel_txt_3(cap_i_2.reshape(-1, emb_size).unsqueeze(-2)).sum(1) 137 | cap_i_expand_3 = l2norm(cap_i_3.reshape(batch_size, n_word, -1), -1) 138 | 139 | emb_size = cap_i_expand_3.size(-1) 140 | cap_i_4 = self.kernel_txt_4(cap_i_3.reshape(-1, emb_size).unsqueeze(-2)).sum(1) 141 | cap_i_expand_4 = l2norm(cap_i_4.reshape(batch_size, n_word, -1), -1) 142 | 143 | emb_size = cap_i_expand_4.size(-1) 144 | cap_i_5 = self.kernel_txt_5(cap_i_4.reshape(-1, emb_size).unsqueeze(-2)).sum(1) 145 | cap_i_expand_5 = l2norm(cap_i_5.reshape(batch_size, n_word, -1), -1) 146 | 147 | emb_size = cap_i_expand_5.size(-1) 148 | cap_i_6 = self.kernel_txt_6(cap_i_5.reshape(-1, emb_size).unsqueeze(-2)).sum(1) 149 | cap_i_expand_6 = l2norm(cap_i_6.reshape(batch_size, n_word, -1), -1) 150 | 151 | emb_size = cap_i_expand_6.size(-1) 152 | cap_i_7 = self.kernel_txt_7(cap_i_6.reshape(-1, emb_size).unsqueeze(-2)).sum(1) 153 | cap_i_expand_7 = l2norm(cap_i_7.reshape(batch_size, n_word, -1), -1) 154 | 155 | emb_size = cap_i_expand_7.size(-1) 156 | cap_i_8 = self.kernel_txt_8(cap_i_7.reshape(-1, emb_size).unsqueeze(-2)).sum(1) 157 | cap_i_expand_8 = l2norm(cap_i_8.reshape(batch_size, n_word, -1), -1) 158 | 159 | return torch.cat([cap_i, cap_i_expand_1, cap_i_expand_2, cap_i_expand_3, cap_i_expand_4, cap_i_expand_5, cap_i_expand_6, cap_i_expand_7, cap_i_expand_8], -1) 160 | 161 | 162 | def forward(self, cap_i): 163 | 164 | batch_size, n_word, embed_size = cap_i.size(0), cap_i.size(1), cap_i.size(2) 165 | 166 | return self.get_text_levels(cap_i, batch_size, n_word) 167 | 168 | else: 169 | 170 | #### Adaptive Dimensional Selective Mask 171 | class Image_levels(nn.Module): 172 | def __init__(self, opt): 173 | super(Image_levels, self).__init__() 174 | self.sub_space = opt.embed_size 175 | self.kernel_size = 
int(opt.kernel_size) 176 | self.out_channels = 2 177 | 178 | self.masks_1 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 1)), int(opt.embed_size/math.pow(self.kernel_size, 0))) # num_embedding, dims_input 179 | self.masks_2 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 2)), int(opt.embed_size/math.pow(self.kernel_size, 1))) 180 | self.masks_3 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 3)), int(opt.embed_size/math.pow(self.kernel_size, 2))) 181 | self.masks_4 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 4)), int(opt.embed_size/math.pow(self.kernel_size, 3))) 182 | self.masks_5 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 5)), int(opt.embed_size/math.pow(self.kernel_size, 4))) 183 | self.masks_6 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 6)), int(opt.embed_size/math.pow(self.kernel_size, 5))) 184 | self.masks_7 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 7)), int(opt.embed_size/math.pow(self.kernel_size, 6))) 185 | self.masks_8 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 8)), int(opt.embed_size/math.pow(self.kernel_size, 7))) 186 | 187 | 188 | def get_image_levels(self, img_emb, batch_size, n_region): 189 | 190 | 191 | sub_space_index = torch.tensor(torch.linspace(0, 1024, steps=1024, dtype=torch.int)).cuda() 192 | Dim_learned_mask_1 = l2norm(self.masks_1(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 1))]), dim=-1) 193 | Dim_learned_mask_2 = l2norm(self.masks_2(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 2))]), dim=-1) 194 | Dim_learned_mask_3 = l2norm(self.masks_3(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 3))]), dim=-1) 195 | Dim_learned_mask_4 = l2norm(self.masks_4(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 4))]), dim=-1) 196 | Dim_learned_mask_5 = l2norm(self.masks_5(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 5))]), dim=-1) 197 | Dim_learned_mask_6 = l2norm(self.masks_6(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 6))]), dim=-1) 198 | Dim_learned_mask_7 = l2norm(self.masks_7(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 7))]), dim=-1) 199 | Dim_learned_mask_8 = l2norm(self.masks_8(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 8))]), dim=-1) 200 | 201 | 202 | if Dim_learned_mask_1.size(1) < self.out_channels: 203 | select_nums = Dim_learned_mask_1.size(1) 204 | else: 205 | select_nums = self.out_channels 206 | Dim_learned_range = Dim_learned_mask_1.sort(1, descending =True)[0][:, select_nums -1].unsqueeze(-1) 207 | Dim_learned_mask_1 = (Dim_learned_mask_1 >= Dim_learned_range).float() * Dim_learned_mask_1 208 | 209 | 210 | if Dim_learned_mask_2.size(1) < self.out_channels: 211 | select_nums = Dim_learned_mask_2.size(1) 212 | else: 213 | select_nums = self.out_channels 214 | Dim_learned_range = Dim_learned_mask_2.sort(1, descending =True)[0][:, select_nums -1].unsqueeze(-1) 215 | Dim_learned_mask_2 = (Dim_learned_mask_2 >= Dim_learned_range).float() * Dim_learned_mask_2 216 | 217 | 218 | if Dim_learned_mask_3.size(1) < self.out_channels: 219 | select_nums = Dim_learned_mask_3.size(1) 220 | else: 221 | select_nums = self.out_channels 222 | Dim_learned_range = Dim_learned_mask_3.sort(1, descending =True)[0][:, select_nums -1].unsqueeze(-1) 223 | Dim_learned_mask_3 = (Dim_learned_mask_3 >= Dim_learned_range).float() * Dim_learned_mask_3 224 | 225 | 226 | if 
Dim_learned_mask_4.size(1) < self.out_channels: 227 | select_nums = Dim_learned_mask_4.size(1) 228 | else: 229 | select_nums = self.out_channels 230 | Dim_learned_range = Dim_learned_mask_4.sort(1, descending =True)[0][:, select_nums -1].unsqueeze(-1) 231 | Dim_learned_mask_4 = (Dim_learned_mask_4 >= Dim_learned_range).float() * Dim_learned_mask_4 232 | 233 | 234 | if Dim_learned_mask_5.size(1) < self.out_channels: 235 | select_nums = Dim_learned_mask_5.size(1) 236 | else: 237 | select_nums = self.out_channels 238 | Dim_learned_range = Dim_learned_mask_5.sort(1, descending =True)[0][:, select_nums -1].unsqueeze(-1) 239 | Dim_learned_mask_5 = (Dim_learned_mask_5 >= Dim_learned_range).float() * Dim_learned_mask_5 240 | 241 | 242 | if Dim_learned_mask_6.size(1) < self.out_channels: 243 | select_nums = Dim_learned_mask_6.size(1) 244 | else: 245 | select_nums = self.out_channels 246 | Dim_learned_range = Dim_learned_mask_6.sort(1, descending =True)[0][:, select_nums -1].unsqueeze(-1) 247 | Dim_learned_mask_6 = (Dim_learned_mask_6 >= Dim_learned_range).float() * Dim_learned_mask_6 248 | 249 | 250 | if Dim_learned_mask_7.size(1) < self.out_channels: 251 | select_nums = Dim_learned_mask_7.size(1) 252 | else: 253 | select_nums = self.out_channels 254 | Dim_learned_range = Dim_learned_mask_7.sort(1, descending =True)[0][:, select_nums -1].unsqueeze(-1) 255 | Dim_learned_mask_7 = (Dim_learned_mask_7 >= Dim_learned_range).float() * Dim_learned_mask_7 256 | 257 | 258 | if Dim_learned_mask_8.size(1) < self.out_channels: 259 | select_nums = Dim_learned_mask_8.size(1) 260 | else: 261 | select_nums = self.out_channels 262 | Dim_learned_range = Dim_learned_mask_8.sort(1, descending =True)[0][:, select_nums -1].unsqueeze(-1) 263 | Dim_learned_mask_8 = (Dim_learned_mask_8 >= Dim_learned_range).float() * Dim_learned_mask_8 264 | 265 | 266 | img_emb_1 = img_emb.reshape(-1, self.sub_space) @ Dim_learned_mask_1.t() 267 | img_emb_1 = l2norm(img_emb_1.reshape(batch_size, n_region, -1), -1) 268 | 269 | emb_size = img_emb_1.size(-1) 270 | img_emb_2 = img_emb_1.reshape(-1, emb_size) @ Dim_learned_mask_2.t() 271 | img_emb_2 = l2norm(img_emb_2.reshape(batch_size, n_region, -1), -1) 272 | 273 | emb_size = img_emb_2.size(-1) 274 | img_emb_3 = img_emb_2.reshape(-1, emb_size) @ Dim_learned_mask_3.t() 275 | img_emb_3 = l2norm(img_emb_3.reshape(batch_size, n_region, -1), -1) 276 | 277 | emb_size = img_emb_3.size(-1) 278 | img_emb_4 = img_emb_3.reshape(-1, emb_size) @ Dim_learned_mask_4.t() 279 | img_emb_4 = l2norm(img_emb_4.reshape(batch_size, n_region, -1), -1) 280 | 281 | emb_size = img_emb_4.size(-1) 282 | img_emb_5 = img_emb_4.reshape(-1, emb_size) @ Dim_learned_mask_5.t() 283 | img_emb_5 = l2norm(img_emb_5.reshape(batch_size, n_region, -1), -1) 284 | 285 | emb_size = img_emb_5.size(-1) 286 | img_emb_6 = img_emb_5.reshape(-1, emb_size) @ Dim_learned_mask_6.t() 287 | img_emb_6 = l2norm(img_emb_6.reshape(batch_size, n_region, -1), -1) 288 | 289 | emb_size = img_emb_6.size(-1) 290 | img_emb_7 = img_emb_6.reshape(-1, emb_size) @ Dim_learned_mask_7.t() 291 | img_emb_7 = l2norm(img_emb_7.reshape(batch_size, n_region, -1), -1) 292 | 293 | emb_size = img_emb_7.size(-1) 294 | img_emb_8 = img_emb_7.reshape(-1, emb_size) @ Dim_learned_mask_8.t() 295 | img_emb_8 = l2norm(img_emb_8.reshape(batch_size, n_region, -1), -1) 296 | 297 | 298 | return torch.cat([img_emb, img_emb_1, img_emb_2, img_emb_3, img_emb_4, img_emb_5, img_emb_6, img_emb_7, img_emb_8], -1) 299 | 300 | def forward(self, img_emb): 301 | 302 | batch_size, n_region, 
embed_size = img_emb.size(0), img_emb.size(1), img_emb.size(2) 303 | 304 | return self.get_image_levels(img_emb, batch_size, n_region) 305 | 306 | 307 | class Text_levels(nn.Module): 308 | 309 | def __init__(self, opt): 310 | super(Text_levels, self).__init__() 311 | self.sub_space = opt.embed_size 312 | self.kernel_size = int(opt.kernel_size) 313 | self.out_channels = 2 314 | 315 | self.masks_1 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 1)), int(opt.embed_size/math.pow(self.kernel_size, 0))) # num_embedding, dims_input 316 | self.masks_2 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 2)), int(opt.embed_size/math.pow(self.kernel_size, 1))) 317 | self.masks_3 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 3)), int(opt.embed_size/math.pow(self.kernel_size, 2))) 318 | self.masks_4 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 4)), int(opt.embed_size/math.pow(self.kernel_size, 3))) 319 | self.masks_5 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 5)), int(opt.embed_size/math.pow(self.kernel_size, 4))) 320 | self.masks_6 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 6)), int(opt.embed_size/math.pow(self.kernel_size, 5))) 321 | self.masks_7 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 7)), int(opt.embed_size/math.pow(self.kernel_size, 6))) 322 | self.masks_8 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 8)), int(opt.embed_size/math.pow(self.kernel_size, 7))) 323 | 324 | def get_text_levels(self, cap_i, batch_size, n_word): 325 | 326 | sub_space_index = torch.tensor(torch.linspace(0, 1024, steps=1024, dtype=torch.int)).cuda() 327 | Dim_learned_mask_1 = l2norm(self.masks_1(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 1))]), dim=-1) 328 | Dim_learned_mask_2 = l2norm(self.masks_2(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 2))]), dim=-1) 329 | Dim_learned_mask_3 = l2norm(self.masks_3(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 3))]), dim=-1) 330 | Dim_learned_mask_4 = l2norm(self.masks_4(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 4))]), dim=-1) 331 | Dim_learned_mask_5 = l2norm(self.masks_5(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 5))]), dim=-1) 332 | Dim_learned_mask_6 = l2norm(self.masks_6(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 6))]), dim=-1) 333 | Dim_learned_mask_7 = l2norm(self.masks_7(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 7))]), dim=-1) 334 | Dim_learned_mask_8 = l2norm(self.masks_8(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 8))]), dim=-1) 335 | 336 | 337 | if Dim_learned_mask_1.size(1) < self.out_channels: 338 | select_nums = Dim_learned_mask_1.size(1) 339 | else: 340 | select_nums = self.out_channels 341 | Dim_learned_range = Dim_learned_mask_1.sort(1, descending =True)[0][:, select_nums -1].unsqueeze(-1) 342 | Dim_learned_mask_1 = (Dim_learned_mask_1 >= Dim_learned_range).float() * Dim_learned_mask_1 343 | 344 | 345 | if Dim_learned_mask_2.size(1) < self.out_channels: 346 | select_nums = Dim_learned_mask_2.size(1) 347 | else: 348 | select_nums = self.out_channels 349 | Dim_learned_range = Dim_learned_mask_2.sort(1, descending =True)[0][:, select_nums -1].unsqueeze(-1) 350 | Dim_learned_mask_2 = (Dim_learned_mask_2 >= Dim_learned_range).float() * Dim_learned_mask_2 351 | 352 | 353 | if Dim_learned_mask_3.size(1) < self.out_channels: 354 | select_nums = 
Dim_learned_mask_3.size(1) 355 | else: 356 | select_nums = self.out_channels 357 | Dim_learned_range = Dim_learned_mask_3.sort(1, descending =True)[0][:, select_nums -1].unsqueeze(-1) 358 | Dim_learned_mask_3 = (Dim_learned_mask_3 >= Dim_learned_range).float() * Dim_learned_mask_3 359 | 360 | 361 | if Dim_learned_mask_4.size(1) < self.out_channels: 362 | select_nums = Dim_learned_mask_4.size(1) 363 | else: 364 | select_nums = self.out_channels 365 | Dim_learned_range = Dim_learned_mask_4.sort(1, descending =True)[0][:, select_nums -1].unsqueeze(-1) 366 | Dim_learned_mask_4 = (Dim_learned_mask_4 >= Dim_learned_range).float() * Dim_learned_mask_4 367 | 368 | 369 | if Dim_learned_mask_5.size(1) < self.out_channels: 370 | select_nums = Dim_learned_mask_5.size(1) 371 | else: 372 | select_nums = self.out_channels 373 | Dim_learned_range = Dim_learned_mask_5.sort(1, descending =True)[0][:, select_nums -1].unsqueeze(-1) 374 | Dim_learned_mask_5 = (Dim_learned_mask_5 >= Dim_learned_range).float() * Dim_learned_mask_5 375 | 376 | 377 | if Dim_learned_mask_6.size(1) < self.out_channels: 378 | select_nums = Dim_learned_mask_6.size(1) 379 | else: 380 | select_nums = self.out_channels 381 | Dim_learned_range = Dim_learned_mask_6.sort(1, descending =True)[0][:, select_nums -1].unsqueeze(-1) 382 | Dim_learned_mask_6 = (Dim_learned_mask_6 >= Dim_learned_range).float() * Dim_learned_mask_6 383 | 384 | 385 | if Dim_learned_mask_7.size(1) < self.out_channels: 386 | select_nums = Dim_learned_mask_7.size(1) 387 | else: 388 | select_nums = self.out_channels 389 | Dim_learned_range = Dim_learned_mask_7.sort(1, descending =True)[0][:, select_nums -1].unsqueeze(-1) 390 | Dim_learned_mask_7 = (Dim_learned_mask_7 >= Dim_learned_range).float() * Dim_learned_mask_7 391 | 392 | 393 | if Dim_learned_mask_8.size(1) < self.out_channels: 394 | select_nums = Dim_learned_mask_8.size(1) 395 | else: 396 | select_nums = self.out_channels 397 | Dim_learned_range = Dim_learned_mask_8.sort(1, descending =True)[0][:, select_nums -1].unsqueeze(-1) 398 | Dim_learned_mask_8 = (Dim_learned_mask_8 >= Dim_learned_range).float() * Dim_learned_mask_8 399 | 400 | cap_i_1 = cap_i.reshape(-1, self.sub_space) @ Dim_learned_mask_1.t() 401 | cap_i_expand_1 = l2norm(cap_i_1.reshape(batch_size, n_word, -1), -1) 402 | 403 | emb_size = cap_i_expand_1.size(-1) 404 | cap_i_2 = cap_i_1.reshape(-1, emb_size) @ Dim_learned_mask_2.t() 405 | cap_i_expand_2 = l2norm(cap_i_2.reshape(batch_size, n_word, -1), -1) 406 | 407 | emb_size = cap_i_expand_2.size(-1) 408 | cap_i_3 = cap_i_2.reshape(-1, emb_size) @ Dim_learned_mask_3.t() 409 | cap_i_expand_3 = l2norm(cap_i_3.reshape(batch_size, n_word, -1), -1) 410 | 411 | emb_size = cap_i_expand_3.size(-1) 412 | cap_i_4 = cap_i_3.reshape(-1, emb_size) @ Dim_learned_mask_4.t() 413 | cap_i_expand_4 = l2norm(cap_i_4.reshape(batch_size, n_word, -1), -1) 414 | 415 | emb_size = cap_i_expand_4.size(-1) 416 | cap_i_5 = cap_i_4.reshape(-1, emb_size) @ Dim_learned_mask_5.t() 417 | cap_i_expand_5 = l2norm(cap_i_5.reshape(batch_size, n_word, -1), -1) 418 | 419 | emb_size = cap_i_expand_5.size(-1) 420 | cap_i_6 = cap_i_5.reshape(-1, emb_size) @ Dim_learned_mask_6.t() 421 | cap_i_expand_6 = l2norm(cap_i_6.reshape(batch_size, n_word, -1), -1) 422 | 423 | emb_size = cap_i_expand_6.size(-1) 424 | cap_i_7 = cap_i_6.reshape(-1, emb_size) @ Dim_learned_mask_7.t() 425 | cap_i_expand_7 = l2norm(cap_i_7.reshape(batch_size, n_word, -1), -1) 426 | 427 | emb_size = cap_i_expand_7.size(-1) 428 | cap_i_8 = cap_i_7.reshape(-1, emb_size) @ 
Dim_learned_mask_8.t() 429 | cap_i_expand_8 = l2norm(cap_i_8.reshape(batch_size, n_word, -1), -1) 430 | 431 | 432 | return torch.cat([cap_i, cap_i_expand_1, cap_i_expand_2, cap_i_expand_3, cap_i_expand_4, cap_i_expand_5, cap_i_expand_6, cap_i_expand_7, cap_i_expand_8], -1) 433 | 434 | def forward(self, cap_i): 435 | 436 | batch_size, n_word, embed_size = cap_i.size(0), cap_i.size(1), cap_i.size(2) 437 | 438 | return self.get_text_levels(cap_i, batch_size, n_word) 439 | 440 | 441 | 442 | 443 | 444 | 445 | 446 | 447 | 448 | 449 | 450 | class Image_Text_Encoders(nn.Module): 451 | def __init__(self, opt): 452 | 453 | super(Image_Text_Encoders, self).__init__() 454 | self.text_levels = Text_levels(opt) 455 | self.image_levels = Image_levels(opt) 456 | 457 | def forward(self, images, captions, return_type): 458 | 459 | if return_type == 'image': 460 | img_embs = self.image_levels(images) 461 | return img_embs 462 | else: 463 | cap_embs = self.text_levels(captions) 464 | return cap_embs 465 | 466 | class Image_Text_Processing(nn.Module): 467 | 468 | def __init__(self, opt): 469 | super(Image_Text_Processing, self).__init__() 470 | self.encoders_1 = Image_Text_Encoders(opt) 471 | 472 | 473 | def forward(self, images, captions): 474 | 475 | image_processed = self.encoders_1(images, captions, 'image') 476 | text_processed = self.encoders_1(images, captions, 'text') 477 | 478 | return image_processed, text_processed 479 | 480 | 481 | 482 | 483 | class sims_claculator(nn.Module): 484 | def __init__(self, opt): 485 | super(sims_claculator, self).__init__() 486 | self.sub_space = opt.embed_size 487 | self.kernel_size = int(opt.kernel_size) 488 | 489 | self.opt = opt 490 | self.sim_eval = nn.Linear(9, 1, bias=False) 491 | self.temp_scale = nn.Linear(1, 1, bias=False) 492 | self.temp_scale_1 = nn.Linear(1, 1, bias=False) 493 | self.temp_scale_2 = nn.Linear(1, 1, bias=False) 494 | 495 | self.masks_0 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 0)), int(opt.embed_size/math.pow(self.kernel_size, 0))) 496 | self.masks_1 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 1)), int(opt.embed_size/math.pow(self.kernel_size, 1))) 497 | self.masks_2 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 2)), int(opt.embed_size/math.pow(self.kernel_size, 2))) 498 | self.masks_3 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 3)), int(opt.embed_size/math.pow(self.kernel_size, 3))) 499 | self.masks_4 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 4)), int(opt.embed_size/math.pow(self.kernel_size, 4))) 500 | self.masks_5 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 5)), int(opt.embed_size/math.pow(self.kernel_size, 5))) 501 | self.masks_6 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 6)), int(opt.embed_size/math.pow(self.kernel_size, 6))) 502 | self.masks_7 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 7)), int(opt.embed_size/math.pow(self.kernel_size, 7))) 503 | self.masks_8 = torch.nn.Embedding(int(opt.embed_size/math.pow(self.kernel_size, 8)), int(opt.embed_size/math.pow(self.kernel_size, 8))) 504 | 505 | self.lynorm_0 = nn.LayerNorm([int(opt.embed_size/math.pow(self.kernel_size, 0)), int(opt.embed_size/math.pow(self.kernel_size, 0))], eps=1e-08, elementwise_affine=True) 506 | self.lynorm_1 = nn.LayerNorm([int(opt.embed_size/math.pow(self.kernel_size, 1)), int(opt.embed_size/math.pow(self.kernel_size, 1))], eps=1e-08, elementwise_affine=True) 507 | self.lynorm_2 = 
nn.LayerNorm([int(opt.embed_size/math.pow(self.kernel_size, 2)), int(opt.embed_size/math.pow(self.kernel_size, 2))], eps=1e-08, elementwise_affine=True) 508 | self.lynorm_3 = nn.LayerNorm([int(opt.embed_size/math.pow(self.kernel_size, 3)), int(opt.embed_size/math.pow(self.kernel_size, 3))], eps=1e-08, elementwise_affine=True) 509 | self.lynorm_4 = nn.LayerNorm([int(opt.embed_size/math.pow(self.kernel_size, 4)), int(opt.embed_size/math.pow(self.kernel_size, 4))], eps=1e-08, elementwise_affine=True) 510 | self.lynorm_5 = nn.LayerNorm([int(opt.embed_size/math.pow(self.kernel_size, 5)), int(opt.embed_size/math.pow(self.kernel_size, 5))], eps=1e-08, elementwise_affine=True) 511 | self.lynorm_6 = nn.LayerNorm([int(opt.embed_size/math.pow(self.kernel_size, 6)), int(opt.embed_size/math.pow(self.kernel_size, 6))], eps=1e-08, elementwise_affine=True) 512 | self.lynorm_7 = nn.LayerNorm([int(opt.embed_size/math.pow(self.kernel_size, 7)), int(opt.embed_size/math.pow(self.kernel_size, 7))], eps=1e-08, elementwise_affine=True) 513 | self.lynorm_8 = nn.LayerNorm([int(opt.embed_size/math.pow(self.kernel_size, 8)), int(opt.embed_size/math.pow(self.kernel_size, 8))], eps=1e-08, elementwise_affine=True) 514 | 515 | self.list_length = [int(opt.embed_size/math.pow(self.kernel_size, 0)), int(opt.embed_size/math.pow(self.kernel_size, 1)), int(opt.embed_size/math.pow(self.kernel_size, 2)), 516 | int(opt.embed_size/math.pow(self.kernel_size, 3)), int(opt.embed_size/math.pow(self.kernel_size, 4)), int(opt.embed_size/math.pow(self.kernel_size, 5)), 517 | int(opt.embed_size/math.pow(self.kernel_size, 6)), int(opt.embed_size/math.pow(self.kernel_size, 7)), int(opt.embed_size/math.pow(self.kernel_size, 8))] 518 | 519 | 520 | self.init_weights() 521 | 522 | def init_weights(self): 523 | self.temp_scale.weight.data.fill_(np.log(1 / 0.07)) 524 | self.sim_eval.weight.data.fill_(0.1) 525 | self.temp_scale_1.weight.data.fill_(0) 526 | self.temp_scale_2.weight.data.fill_(3) 527 | 528 | def get_weighted_features(self, attn, smooth): 529 | # --> (batch, sourceL, queryL) 530 | attnT = torch.transpose(attn, 1, 2).contiguous() 531 | attn = nn.LeakyReLU(0.1)(attnT) 532 | attn = l2norm(attn, 2) 533 | # --> (batch, queryL, sourceL) 534 | attn = torch.transpose(attn, 1, 2).contiguous() 535 | # --> (batch, queryL, sourceL 536 | attn = F.softmax(attn*smooth, dim=2) 537 | 538 | return attn 539 | 540 | 541 | def get_sims_levels(self, X_0, X_1, X_2, X_3, X_4, X_5, X_6, X_7, X_8, 542 | Y_0, Y_1, Y_2, Y_3, Y_4, Y_5, Y_6, Y_7, Y_8, 543 | D_0, D_1, D_2, D_3, D_4, D_5, D_6, D_7, D_8): 544 | attn_0 = (X_0 @ D_0 @ Y_0.transpose(1, 2)) 545 | attn_1 = (X_1 @ D_1 @ Y_1.transpose(1, 2)) 546 | attn_2 = (X_2 @ D_2 @ Y_2.transpose(1, 2)) 547 | attn_3 = (X_3 @ D_3 @ Y_3.transpose(1, 2)) 548 | attn_4 = (X_4 @ D_4 @ Y_4.transpose(1, 2)) 549 | attn_5 = (X_5 @ D_5 @ Y_5.transpose(1, 2)) 550 | attn_6 = (X_6 @ D_6 @ Y_6.transpose(1, 2)) 551 | attn_7 = (X_7 @ D_7 @ Y_7.transpose(1, 2)) 552 | attn_8 = (X_8 @ D_8 @ Y_8.transpose(1, 2)) 553 | 554 | attn = attn_0 + attn_1 + attn_2 + attn_3 + attn_4 + attn_5 + attn_6 + attn_7 + attn_8 555 | 556 | return attn 557 | 558 | def forward(self, img_emb, cap_emb, cap_lens): 559 | 560 | n_caption = cap_emb.size(0) 561 | sim_all = [] 562 | batch_size, n_region, embed_size = img_emb.size(0), img_emb.size(1), img_emb.size(2) 563 | sub_space_index = torch.tensor(torch.linspace(0, self.sub_space, steps=self.sub_space, dtype=torch.int)).cuda() 564 | smooth=torch.exp(self.temp_scale.weight) 565 | 566 | sigma_ = 
self.temp_scale_1.weight 567 | lambda_ = torch.exp(self.temp_scale_2.weight) 568 | threshold = (torch.abs(self.sim_eval.weight).max() - torch.abs(self.sim_eval.weight).min()) * sigma_ + torch.abs(self.sim_eval.weight).min() 569 | if not heuristic_strategy: 570 | lambda_ = 0 571 | 572 | weight_0 = (torch.exp((torch.abs(self.sim_eval.weight[0, 0]) - threshold) * lambda_) * self.sim_eval.weight[0, 0]) 573 | weight_1 = (torch.exp((torch.abs(self.sim_eval.weight[0, 1]) - threshold) * lambda_) * self.sim_eval.weight[0, 1]) 574 | weight_2 = (torch.exp((torch.abs(self.sim_eval.weight[0, 2]) - threshold) * lambda_) * self.sim_eval.weight[0, 2]) 575 | weight_3 = (torch.exp((torch.abs(self.sim_eval.weight[0, 3]) - threshold) * lambda_) * self.sim_eval.weight[0, 3]) 576 | weight_4 = (torch.exp((torch.abs(self.sim_eval.weight[0, 4]) - threshold) * lambda_) * self.sim_eval.weight[0, 4]) 577 | weight_5 = (torch.exp((torch.abs(self.sim_eval.weight[0, 5]) - threshold) * lambda_) * self.sim_eval.weight[0, 5]) 578 | weight_6 = (torch.exp((torch.abs(self.sim_eval.weight[0, 6]) - threshold) * lambda_) * self.sim_eval.weight[0, 6]) 579 | weight_7 = (torch.exp((torch.abs(self.sim_eval.weight[0, 7]) - threshold) * lambda_) * self.sim_eval.weight[0, 7]) 580 | weight_8 = (torch.exp((torch.abs(self.sim_eval.weight[0, 8]) - threshold) * lambda_) * self.sim_eval.weight[0, 8]) 581 | 582 | 583 | 584 | Dim_learned_mask_0 = self.lynorm_0(self.masks_0(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 0))])) * weight_0 585 | Dim_learned_mask_1 = self.lynorm_1(self.masks_1(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 1))])) * weight_1 586 | Dim_learned_mask_2 = self.lynorm_2(self.masks_2(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 2))])) * weight_2 587 | Dim_learned_mask_3 = self.lynorm_3(self.masks_3(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 3))])) * weight_3 588 | Dim_learned_mask_4 = self.lynorm_4(self.masks_4(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 4))])) * weight_4 589 | Dim_learned_mask_5 = self.lynorm_5(self.masks_5(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 5))])) * weight_5 590 | Dim_learned_mask_6 = self.lynorm_6(self.masks_6(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 6))])) * weight_6 591 | Dim_learned_mask_7 = self.lynorm_7(self.masks_7(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 7))])) * weight_7 592 | Dim_learned_mask_8 = self.lynorm_8(self.masks_8(sub_space_index[:int(self.sub_space/math.pow(self.kernel_size, 8))])) * weight_8 593 | 594 | 595 | 596 | for i in range(n_caption): 597 | # get the i-th sentence 598 | n_word = cap_lens[i] 599 | 600 | ## ------------------------------------------------------------------------------------------------------------------------ 601 | # attention 602 | 603 | attn = self.get_sims_levels( 604 | cap_emb[i, :n_word, :sum(self.list_length[:1])].unsqueeze(0).repeat(batch_size, 1, 1), 605 | cap_emb[i, :n_word, sum(self.list_length[:1]):sum(self.list_length[:2])].unsqueeze(0).repeat(batch_size, 1, 1), 606 | cap_emb[i, :n_word, sum(self.list_length[:2]):sum(self.list_length[:3])].unsqueeze(0).repeat(batch_size, 1, 1), 607 | cap_emb[i, :n_word, sum(self.list_length[:3]):sum(self.list_length[:4])].unsqueeze(0).repeat(batch_size, 1, 1), 608 | cap_emb[i, :n_word, sum(self.list_length[:4]):sum(self.list_length[:5])].unsqueeze(0).repeat(batch_size, 1, 1), 609 | cap_emb[i, :n_word, 
sum(self.list_length[:5]):sum(self.list_length[:6])].unsqueeze(0).repeat(batch_size, 1, 1), 610 | cap_emb[i, :n_word, sum(self.list_length[:6]):sum(self.list_length[:7])].unsqueeze(0).repeat(batch_size, 1, 1), 611 | cap_emb[i, :n_word, sum(self.list_length[:7]):sum(self.list_length[:8])].unsqueeze(0).repeat(batch_size, 1, 1), 612 | cap_emb[i, :n_word, sum(self.list_length[:8]):sum(self.list_length[:9])].unsqueeze(0).repeat(batch_size, 1, 1), 613 | img_emb[:, :, :sum(self.list_length[:1])], 614 | img_emb[:, :, sum(self.list_length[:1]):sum(self.list_length[:2])], 615 | img_emb[:, :, sum(self.list_length[:2]):sum(self.list_length[:3])], 616 | img_emb[:, :, sum(self.list_length[:3]):sum(self.list_length[:4])], 617 | img_emb[:, :, sum(self.list_length[:4]):sum(self.list_length[:5])], 618 | img_emb[:, :, sum(self.list_length[:5]):sum(self.list_length[:6])], 619 | img_emb[:, :, sum(self.list_length[:6]):sum(self.list_length[:7])], 620 | img_emb[:, :, sum(self.list_length[:7]):sum(self.list_length[:8])], 621 | img_emb[:, :, sum(self.list_length[:8]):sum(self.list_length[:9])], 622 | Dim_learned_mask_0, Dim_learned_mask_1, Dim_learned_mask_2, Dim_learned_mask_3, Dim_learned_mask_4, Dim_learned_mask_5, Dim_learned_mask_6, Dim_learned_mask_7, Dim_learned_mask_8) 623 | 624 | ################################################################################################## 625 | # --> (batch, sourceL, queryL) 626 | attnT = torch.transpose(attn, 1, 2).contiguous() 627 | attn_t2i_weight = self.get_weighted_features(torch.tanh(attn), smooth) 628 | sims_t2i = attn.mul(attn_t2i_weight).sum(-1).mean(dim=1, keepdim=True) 629 | ################################################################################################## 630 | attn_i2t_weight = self.get_weighted_features(torch.tanh(attnT), smooth) 631 | sims_i2t = attnT.mul(attn_i2t_weight).sum(-1).mean(dim=1, keepdim=True) 632 | ################################################################################################## 633 | 634 | sims = sims_t2i + sims_i2t 635 | sim_all.append(sims) 636 | 637 | sim_all = torch.cat(sim_all, 1) 638 | return sim_all 639 | 640 | 641 | 642 | 643 | class Sims_Measuring(nn.Module): 644 | def __init__(self, opt): 645 | super(Sims_Measuring, self).__init__() 646 | 647 | self.calculator_1 = sims_claculator(opt) 648 | 649 | def forward(self, img_embs, cap_embs, lengths): 650 | 651 | sims = self.calculator_1(img_embs, cap_embs, lengths) 652 | 653 | return sims 654 | 655 | 656 | 657 | class Sim_vec(nn.Module): 658 | 659 | def __init__(self, embed_size, opt): 660 | super(Sim_vec, self).__init__() 661 | self.plus_encoder = Image_Text_Processing(opt) 662 | self.sims = Sims_Measuring(opt) 663 | 664 | def forward(self, img_emb, cap_emb, cap_lens, is_Train): 665 | 666 | region_features, word_features = self.plus_encoder(img_emb, cap_emb) 667 | sims = self.sims(region_features, word_features, cap_lens) 668 | 669 | return sims, sims 670 | 671 | 672 | 673 | 674 | 675 | class VSEModel(object): 676 | 677 | 678 | def __init__(self, opt): 679 | # Build Models 680 | self.grad_clip = opt.grad_clip 681 | self.img_enc = get_image_encoder(opt.data_name, opt.img_dim, opt.embed_size, 682 | precomp_enc_type=opt.precomp_enc_type, 683 | backbone_source=opt.backbone_source, 684 | backbone_path=opt.backbone_path, 685 | no_imgnorm=opt.no_imgnorm) 686 | self.txt_enc = get_text_encoder(opt.embed_size, no_txtnorm=opt.no_txtnorm) 687 | 688 | self.sim_vec = Sim_vec(opt.embed_size, opt) 689 | 690 | if torch.cuda.is_available(): 691 | self.img_enc.cuda() 692 
| self.txt_enc.cuda() 693 | self.sim_vec.cuda() 694 | cudnn.benchmark = True 695 | 696 | # Loss and Optimizer 697 | self.criterion = ContrastiveLoss(opt=opt, 698 | margin=opt.margin, 699 | max_violation=opt.max_violation) 700 | 701 | params = list(self.txt_enc.parameters()) 702 | params += list(self.img_enc.parameters()) 703 | params += list(self.sim_vec.parameters()) 704 | 705 | self.params = params 706 | self.opt = opt 707 | 708 | # Set up the lr for different parts of the VSE model 709 | decay_factor = 1e-4 710 | if opt.precomp_enc_type == 'basic': 711 | if self.opt.optim == 'adam': 712 | all_text_params = list(self.txt_enc.parameters()) 713 | bert_params = list(self.txt_enc.bert.parameters()) 714 | bert_params_ptr = [p.data_ptr() for p in bert_params] 715 | text_params_no_bert = list() 716 | for p in all_text_params: 717 | if p.data_ptr() not in bert_params_ptr: 718 | text_params_no_bert.append(p) 719 | self.optimizer = torch.optim.AdamW([ 720 | {'params': text_params_no_bert, 'lr': opt.learning_rate}, 721 | {'params': bert_params, 'lr': opt.learning_rate * 0.1}, 722 | {'params': self.img_enc.parameters(), 'lr': opt.learning_rate}, 723 | {'params': self.sim_vec.parameters(), 'lr': opt.learning_rate}, 724 | ], 725 | lr=opt.learning_rate, weight_decay=decay_factor) 726 | elif self.opt.optim == 'sgd': 727 | self.optimizer = torch.optim.SGD(self.params, lr=opt.learning_rate, momentum=0.9) 728 | else: 729 | raise ValueError('Invalid optim option {}'.format(self.opt.optim)) 730 | else: 731 | if self.opt.optim == 'adam': 732 | all_text_params = list(self.txt_enc.parameters()) 733 | bert_params = list(self.txt_enc.bert.parameters()) 734 | bert_params_ptr = [p.data_ptr() for p in bert_params] 735 | text_params_no_bert = list() 736 | for p in all_text_params: 737 | if p.data_ptr() not in bert_params_ptr: 738 | text_params_no_bert.append(p) 739 | self.optimizer = torch.optim.AdamW([ 740 | {'params': text_params_no_bert, 'lr': opt.learning_rate}, 741 | {'params': bert_params, 'lr': opt.learning_rate * 0.1}, 742 | {'params': self.img_enc.backbone.top.parameters(), 743 | 'lr': opt.learning_rate * opt.backbone_lr_factor, }, 744 | {'params': self.img_enc.backbone.base.parameters(), 745 | 'lr': opt.learning_rate * opt.backbone_lr_factor, }, 746 | {'params': self.img_enc.image_encoder.parameters(), 'lr': opt.learning_rate}, 747 | ], lr=opt.learning_rate, weight_decay=decay_factor) 748 | elif self.opt.optim == 'sgd': 749 | self.optimizer = torch.optim.SGD([ 750 | {'params': self.txt_enc.parameters(), 'lr': opt.learning_rate}, 751 | {'params': self.img_enc.backbone.parameters(), 'lr': opt.learning_rate * opt.backbone_lr_factor, 752 | 'weight_decay': decay_factor}, 753 | {'params': self.img_enc.image_encoder.parameters(), 'lr': opt.learning_rate}, 754 | ], lr=opt.learning_rate, momentum=0.9, nesterov=True) 755 | else: 756 | raise ValueError('Invalid optim option {}'.format(self.opt.optim)) 757 | 758 | logger.info('Use {} as the optimizer, with init lr {}'.format(self.opt.optim, opt.learning_rate)) 759 | 760 | self.Eiters = 0 761 | self.data_parallel = False 762 | 763 | def set_max_violation(self, max_violation): 764 | if max_violation: 765 | self.criterion.max_violation_on() 766 | else: 767 | self.criterion.max_violation_off() 768 | 769 | def state_dict(self): 770 | state_dict = [self.img_enc.state_dict(), self.txt_enc.state_dict(), 771 | self.sim_vec.state_dict()] 772 | return state_dict 773 | 774 | def load_state_dict(self, state_dict): 775 | self.img_enc.load_state_dict(state_dict[0], strict=False) 776 
| self.txt_enc.load_state_dict(state_dict[1], strict=False) 777 | self.sim_vec.load_state_dict(state_dict[2], strict=False) 778 | 779 | def train_start(self): 780 | """switch to train mode 781 | """ 782 | self.img_enc.train() 783 | self.txt_enc.train() 784 | self.sim_vec.train() 785 | 786 | def val_start(self): 787 | """switch to evaluate mode 788 | """ 789 | self.img_enc.eval() 790 | self.txt_enc.eval() 791 | self.sim_vec.eval() 792 | 793 | def freeze_backbone(self): 794 | if 'backbone' in self.opt.precomp_enc_type: 795 | if isinstance(self.img_enc, nn.DataParallel): 796 | self.img_enc.module.freeze_backbone() 797 | else: 798 | self.img_enc.freeze_backbone() 799 | 800 | def unfreeze_backbone(self, fixed_blocks): 801 | if 'backbone' in self.opt.precomp_enc_type: 802 | if isinstance(self.img_enc, nn.DataParallel): 803 | self.img_enc.module.unfreeze_backbone(fixed_blocks) 804 | else: 805 | self.img_enc.unfreeze_backbone(fixed_blocks) 806 | 807 | def make_data_parallel(self): 808 | self.img_enc = nn.DataParallel(self.img_enc) 809 | self.txt_enc = nn.DataParallel(self.txt_enc) 810 | self.sim_vec = nn.DataParallel(self.sim_vec) 811 | self.data_parallel = True 812 | logger.info('Image encoder is data paralleled now.') 813 | 814 | @property 815 | def is_data_parallel(self): 816 | return self.data_parallel 817 | 818 | def forward_emb(self, images, captions, lengths, image_lengths=None): 819 | """Compute the image and caption embeddings 820 | """ 821 | # Set mini-batch dataset 822 | if self.opt.precomp_enc_type == 'basic': 823 | if torch.cuda.is_available(): 824 | images = images.cuda() 825 | captions = captions.cuda() 826 | image_lengths = image_lengths.cuda() 827 | img_emb = self.img_enc(images, image_lengths) 828 | else: 829 | if torch.cuda.is_available(): 830 | images = images.cuda() 831 | captions = captions.cuda() 832 | img_emb = self.img_enc(images) 833 | 834 | # lengths = torch.Tensor(lengths).cuda() 835 | cap_emb = self.txt_enc(captions, lengths) 836 | return img_emb, cap_emb, lengths 837 | 838 | def forward_sim(self, img_emb, cap_emb, cap_lens): 839 | is_Train = True 840 | sim_all, L1 = self.sim_vec(img_emb, cap_emb, cap_lens, is_Train) 841 | 842 | return sim_all, L1 843 | 844 | def forward_sim_test(self, img_emb, cap_emb, cap_lens): 845 | is_Train = False 846 | sim_all, L1 = self.sim_vec(img_emb, cap_emb, cap_lens, is_Train) 847 | 848 | return sim_all, L1 849 | 850 | 851 | def forward_loss(self, sims): 852 | """Compute the loss given pairs of image and caption embeddings 853 | """ 854 | loss = self.criterion(sims) 855 | self.logger.update('Le', loss.data.item(), sims.size(0)) 856 | return loss 857 | 858 | def train_emb(self, images, captions, lengths, image_lengths=None, warmup_alpha=None): 859 | """One training step given images and captions. 
860 |         """
861 |         self.Eiters += 1
862 |         self.logger.update('Eit', self.Eiters)
863 |         self.logger.update('lr', self.optimizer.param_groups[0]['lr'])
864 | 
865 |         # compute the embeddings
866 |         img_emb, cap_emb, cap_lens = self.forward_emb(images, captions, lengths, image_lengths=image_lengths)
867 |         sims, L1 = self.forward_sim(img_emb, cap_emb, cap_lens)
868 | 
869 |         # measure accuracy and record loss
870 |         self.optimizer.zero_grad()
871 |         loss = self.forward_loss(sims)
872 | 
873 |         if warmup_alpha is not None:
874 |             loss = loss * warmup_alpha
875 | 
876 | 
877 |         message = "%f\n" % loss
878 |         log_file2 = os.path.join(self.opt.logger_name, "loss.txt")
879 |         logging_func(log_file2, message)
880 | 
881 | 
882 |         # compute gradient and update
883 |         loss.backward()
884 |         if self.grad_clip > 0:
885 |             clip_grad_norm_(self.params, self.grad_clip)
886 |         self.optimizer.step()
--------------------------------------------------------------------------------
/motivation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CrossmodalGroup/ESL/b91e287b238fcf6dd7dc784c04e7bf1ac53404ae/motivation.png
--------------------------------------------------------------------------------
/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CrossmodalGroup/ESL/b91e287b238fcf6dd7dc784c04e7bf1ac53404ae/overview.png
--------------------------------------------------------------------------------
/runs/runX/log/performance.log:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CrossmodalGroup/ESL/b91e287b238fcf6dd7dc784c04e7bf1ac53404ae/runs/runX/log/performance.log
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | from lib import evaluation
3 | 
4 | os.environ["CUDA_VISIBLE_DEVICES"] = "7"
5 | 
6 | 
7 | ## for the heuristic strategy, note that you should set the flag variable 'heuristic_strategy' in line 44 of vse.py to True
8 | # RUN_PATH = "../checkpoint_heuristic_flickr30k_bert.tar"
9 | 
10 | 
11 | ## for the adaptive strategy, note that you should set the flag variable 'heuristic_strategy' in line 44 of vse.py to False
12 | RUN_PATH = "../checkpoint_adaptive_flickr30k_bert.tar"
13 | 
14 | 
15 | DATA_PATH = "../Flickr30K/"
16 | evaluation.evalrank(RUN_PATH, data_path=DATA_PATH, split="test")
17 | 
--------------------------------------------------------------------------------
/test_stack.py:
--------------------------------------------------------------------------------
1 | # from vocab import Vocabulary
2 | # import evaluation
3 | import numpy as np
4 | import os
5 | 
6 | 
7 | def i2t(im_len, sims, npts=None, return_ranks=False):
8 |     """
9 |     Images->Text (Image Annotation)
10 |     Images: (N, n_region, d) matrix of images
11 |     Captions: (5N, max_n_word, d) matrix of captions
12 |     CapLens: (5N) array of caption lengths
13 |     sims: (N, 5N) matrix of similarity im-cap
14 |     """
15 |     npts = im_len
16 |     ranks = np.zeros(npts)
17 |     top1 = np.zeros(npts)
18 |     for index in range(npts):
19 |         inds = np.argsort(sims[index])[::-1]
20 |         # Score
21 |         rank = 1e20
22 |         for i in range(5 * index, 5 * index + 5, 1):
23 |             tmp = np.where(inds == i)[0][0]
24 |             if tmp < rank:
25 |                 rank = tmp
26 |         ranks[index] = rank
27 |         top1[index] = inds[0]
28 | 
29 |     # Compute metrics
30 |     r1 = 100.0 * len(np.where(ranks < 1)[0]) / 
len(ranks) 31 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 32 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 33 | medr = np.floor(np.median(ranks)) + 1 34 | meanr = ranks.mean() + 1 35 | if return_ranks: 36 | return (r1, r5, r10, medr, meanr), (ranks, top1) 37 | else: 38 | return (r1, r5, r10, medr, meanr) 39 | 40 | 41 | def t2i(im_len, sims, npts=None, return_ranks=False): 42 | """ 43 | Text->Images (Image Search) 44 | Images: (N, n_region, d) matrix of images 45 | Captions: (5N, max_n_word, d) matrix of captions 46 | CapLens: (5N) array of caption lengths 47 | sims: (N, 5N) matrix of similarity im-cap 48 | """ 49 | npts = im_len 50 | ranks = np.zeros(5 * npts) 51 | top1 = np.zeros(5 * npts) 52 | 53 | # --> (5N(caption), N(image)) 54 | sims = sims.T 55 | 56 | for index in range(npts): 57 | for i in range(5): 58 | inds = np.argsort(sims[5 * index + i])[::-1] 59 | ranks[5 * index + i] = np.where(inds == index)[0][0] 60 | top1[5 * index + i] = inds[0] 61 | 62 | # Compute metrics 63 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 64 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 65 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 66 | medr = np.floor(np.median(ranks)) + 1 67 | meanr = ranks.mean() + 1 68 | if return_ranks: 69 | return (r1, r5, r10, medr, meanr), (ranks, top1) 70 | else: 71 | return (r1, r5, r10, medr, meanr) 72 | 73 | 74 | 75 | if __name__ == '__main__': 76 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 77 | 78 | isfold5 = True 79 | 80 | if not isfold5: 81 | 82 | # ## Flickr30K 83 | # Path_of_Model_1 = '' 84 | # Path_of_Model_2 = '' 85 | 86 | ## MS-COCO 87 | Path_of_Model_1 = '' 88 | Path_of_Model_2 = '' 89 | 90 | sims1 = np.loadtxt(Path_of_Model_1) 91 | sims2 = np.loadtxt(Path_of_Model_2) 92 | 93 | sims = (sims1 + sims2) 94 | im_len = len(sims) 95 | print('im length:', im_len) 96 | r, rt = i2t(im_len, sims, return_ranks=True) 97 | ri, rti = t2i(im_len, sims, return_ranks=True) 98 | ar = (r[0] + r[1] + r[2]) / 3 99 | ari = (ri[0] + ri[1] + ri[2]) / 3 100 | rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2] 101 | print("rsum: %.1f" % rsum) 102 | print("Average i2t Recall: %.1f" % ar) 103 | print("Image to text: %.1f %.1f %.1f %.1f %.1f" % r) 104 | print("Average t2i Recall: %.1f" % ari) 105 | print("Text to image: %.1f %.1f %.1f %.1f %.1f" % ri) 106 | else: 107 | results = [] 108 | for i in range(5): 109 | 110 | Path_of_Model_1 = '/mnt/data2/zk/ESL_bert/checkpoint2/COCO-LEARNABLE/' 111 | Path_of_Model_2 = '/mnt/data2/zk/ESL_bert/checkpoint2/COCO-NON-LEARNABLE/' 112 | 113 | sims1 = np.loadtxt(Path_of_Model_1 + str(i) + 'sim_best.txt') 114 | sims2 = np.loadtxt(Path_of_Model_2 + str(i) + 'sim_best.txt') 115 | 116 | sim_shard = (sims1 + sims2) / 2 117 | im_len = len(sim_shard) 118 | print('im length:', im_len) 119 | r, rt0 = i2t(im_len, sim_shard, return_ranks=True) 120 | print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % r) 121 | ri, rti0 = t2i(im_len, sim_shard, return_ranks=True) 122 | print("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % ri) 123 | 124 | if i == 0: 125 | rt, rti = rt0, rti0 126 | ar = (r[0] + r[1] + r[2]) / 3 127 | ari = (ri[0] + ri[1] + ri[2]) / 3 128 | rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2] 129 | print("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari)) 130 | results += [list(r) + list(ri) + [ar, ari, rsum]] 131 | 132 | print("-----------------------------------") 133 | print("Mean metrics: ") 134 | mean_metrics = tuple(np.array(results).mean(axis=0).flatten()) 135 | print("rsum: %.1f" % ( 
mean_metrics[12]))
136 |         print("Average i2t Recall: %.1f" % mean_metrics[10])
137 |         print("Image to text: %.1f %.1f %.1f %.1f %.1f" %
138 |               mean_metrics[:5])
139 |         print("Average t2i Recall: %.1f" % mean_metrics[11])
140 |         print("Text to image: %.1f %.1f %.1f %.1f %.1f" %
141 |               mean_metrics[5:10])
142 | 
--------------------------------------------------------------------------------
/testall.py:
--------------------------------------------------------------------------------
1 | from logging import fatal
2 | import os
3 | from pickle import TRUE
4 | from lib import evaluation
5 | 
6 | import torch
7 | torch.set_num_threads(4)
8 | 
9 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
10 | 
11 | 
12 | ## for the heuristic strategy, note that you should set the flag variable 'heuristic_strategy' in line 44 of vse.py to True
13 | # RUN_PATH = "../checkpoint_heuristic_mscoco_bert.tar"
14 | 
15 | 
16 | ## for the adaptive strategy, note that you should set the flag variable 'heuristic_strategy' in line 44 of vse.py to False
17 | RUN_PATH = "../checkpoint_adaptive_mscoco_bert.tar"
18 | 
19 | DATA_PATH = "../MS-COCO/"
20 | 
21 | ### set fold5 to False for the 5K test, or True for the averaged 1K test
22 | evaluation.evalrank(RUN_PATH, data_path=DATA_PATH, split="testall", fold5=False)
23 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """Training script"""
2 | import os
3 | import time
4 | import numpy as np
5 | import torch
6 | from transformers import BertTokenizer
7 | 
8 | from lib.datasets import image_caption
9 | from lib.vse import VSEModel
10 | from lib.evaluation import i2t, t2i, AverageMeter, LogCollector, encode_data, shard_attn_scores
11 | 
12 | import logging
13 | import tensorboard_logger as tb_logger
14 | 
15 | import arguments
16 | 
17 | 
18 | def logging_func(log_file, message):
19 |     with open(log_file, 'a') as f:
20 |         f.write(message)
21 |     f.close()
22 | 
23 | 
24 | def main():
25 |     # Hyper Parameters
26 |     parser = arguments.get_argument_parser()
27 |     opt = parser.parse_args()
28 | 
29 |     if not os.path.exists(opt.model_name):
30 |         os.makedirs(opt.model_name)
31 |     logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
32 |     tb_logger.configure(opt.logger_name, flush_secs=5)
33 | 
34 |     logger = logging.getLogger(__name__)
35 |     logger.info(opt)
36 | 
37 |     # Load Tokenizer and Vocabulary
38 |     tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
39 |     vocab = tokenizer.vocab
40 |     opt.vocab_size = len(vocab)
41 | 
42 |     train_loader, val_loader = image_caption.get_loaders(
43 |         opt.data_path, opt.data_name, tokenizer, opt.batch_size, opt.workers, opt)
44 | 
45 |     model = VSEModel(opt)
46 | 
47 |     lr_schedules = [opt.lr_update, ]
48 | 
49 |     # optionally resume from a checkpoint
50 |     start_epoch = 0
51 |     if opt.resume:
52 |         if os.path.isfile(opt.resume):
53 |             logger.info("=> loading checkpoint '{}'".format(opt.resume))
54 |             checkpoint = torch.load(opt.resume)
55 |             start_epoch = checkpoint['epoch']
56 |             best_rsum = checkpoint['best_rsum']
57 |             if not model.is_data_parallel:
58 |                 model.make_data_parallel()
59 |             model.load_state_dict(checkpoint['model'])
60 |             # Eiters is used to show logs as the continuation of another training
61 |             model.Eiters = checkpoint['Eiters']
62 |             logger.info("=> loaded checkpoint '{}' (epoch {}, best_rsum {})"
63 |                         .format(opt.resume, start_epoch, best_rsum))
64 |             # validate(opt, val_loader, model)
65 |             if opt.reset_start_epoch:
66 |                 start_epoch = 0
67 | 
else: 68 | logger.info("=> no checkpoint found at '{}'".format(opt.resume)) 69 | 70 | if not model.is_data_parallel: 71 | model.make_data_parallel() 72 | 73 | # Train the Model 74 | best_rsum = 0 75 | for epoch in range(start_epoch, opt.num_epochs): 76 | logger.info(opt.logger_name) 77 | logger.info(opt.model_name) 78 | 79 | adjust_learning_rate(opt, model.optimizer, epoch, lr_schedules) 80 | 81 | if epoch >= opt.vse_mean_warmup_epochs: 82 | opt.max_violation = True 83 | model.set_max_violation(opt.max_violation) 84 | 85 | # Set up the all warm-up options 86 | if opt.precomp_enc_type == 'backbone': 87 | if epoch < opt.embedding_warmup_epochs: 88 | model.freeze_backbone() 89 | logger.info('All backbone weights are frozen, only train the embedding layers') 90 | else: 91 | model.unfreeze_backbone(3) 92 | 93 | if epoch < opt.embedding_warmup_epochs: 94 | logger.info('Warm up the embedding layers') 95 | elif epoch < opt.embedding_warmup_epochs + opt.backbone_warmup_epochs: 96 | model.unfreeze_backbone(3) # only train the last block of resnet backbone 97 | elif epoch < opt.embedding_warmup_epochs + opt.backbone_warmup_epochs * 2: 98 | model.unfreeze_backbone(2) 99 | elif epoch < opt.embedding_warmup_epochs + opt.backbone_warmup_epochs * 3: 100 | model.unfreeze_backbone(1) 101 | else: 102 | model.unfreeze_backbone(0) 103 | 104 | # train for one epoch 105 | train(opt, train_loader, model, epoch, val_loader) 106 | 107 | # evaluate on validation set 108 | rsum = validate(opt, val_loader, model, epoch) 109 | 110 | # remember best R@ sum and save checkpoint 111 | is_best = rsum > best_rsum 112 | best_rsum = max(rsum, best_rsum) 113 | if not os.path.exists(opt.model_name): 114 | os.mkdir(opt.model_name) 115 | 116 | save_checkpoint({ 117 | 'epoch': epoch + 1, 118 | 'model': model.state_dict(), 119 | 'best_rsum': best_rsum, 120 | 'opt': opt, 121 | 'Eiters': model.Eiters, 122 | }, is_best, filename='checkpoint_{}.pth.tar'.format(epoch), prefix=opt.model_name + '/') 123 | 124 | 125 | def train(opt, train_loader, model, epoch, val_loader): 126 | # average meters to record the training statistics 127 | logger = logging.getLogger(__name__) 128 | batch_time = AverageMeter() 129 | data_time = AverageMeter() 130 | train_logger = LogCollector() 131 | 132 | logger.info('image encoder trainable parameters: {}'.format(count_params(model.img_enc))) 133 | logger.info('txt encoder trainable parameters: {}'.format(count_params(model.txt_enc))) 134 | 135 | num_loader_iter = len(train_loader.dataset) // train_loader.batch_size + 1 136 | 137 | end = time.time() 138 | # opt.viz = True 139 | for i, train_data in enumerate(train_loader): 140 | # switch to train mode 141 | model.train_start() 142 | 143 | # measure data loading time 144 | data_time.update(time.time() - end) 145 | 146 | # make sure train logger is used 147 | model.logger = train_logger 148 | 149 | # Update the model 150 | if opt.precomp_enc_type == 'basic': 151 | images, img_lengths, captions, lengths, _ = train_data 152 | model.train_emb(images, captions, lengths, image_lengths=img_lengths) 153 | else: 154 | images, captions, lengths, _ = train_data 155 | if epoch == opt.embedding_warmup_epochs: 156 | warmup_alpha = float(i) / num_loader_iter 157 | model.train_emb(images, captions, lengths, warmup_alpha=warmup_alpha) 158 | else: 159 | model.train_emb(images, captions, lengths) 160 | 161 | # measure elapsed time 162 | batch_time.update(time.time() - end) 163 | end = time.time() 164 | 165 | # logger.info log info 166 | if model.Eiters % opt.log_step == 0: 167 | 
if opt.precomp_enc_type == 'backbone' and epoch == opt.embedding_warmup_epochs: 168 | logging.info('Current epoch-{}, the first epoch for training backbone, warmup alpha {}'.format(epoch, 169 | warmup_alpha)) 170 | logging.info( 171 | 'Epoch: [{0}][{1}/{2}]\t' 172 | '{e_log}\t' 173 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 174 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 175 | .format( 176 | epoch, i, len(train_loader.dataset) // train_loader.batch_size + 1, batch_time=batch_time, 177 | data_time=data_time, e_log=str(model.logger))) 178 | 179 | # Record logs in tensorboard 180 | tb_logger.log_value('epoch', epoch, step=model.Eiters) 181 | tb_logger.log_value('step', i, step=model.Eiters) 182 | tb_logger.log_value('batch_time', batch_time.val, step=model.Eiters) 183 | tb_logger.log_value('data_time', data_time.val, step=model.Eiters) 184 | model.logger.tb_log(tb_logger, step=model.Eiters) 185 | 186 | 187 | def validate(opt, val_loader, model, epoch): 188 | logger = logging.getLogger(__name__) 189 | model.val_start() 190 | with torch.no_grad(): 191 | # compute the encoding for all the validation images and captions 192 | img_embs, cap_embs, cap_lens = encode_data( 193 | model, val_loader, opt.log_step, logging.info, backbone=opt.precomp_enc_type == 'backbone') 194 | 195 | img_embs = np.array([img_embs[i] for i in range(0, len(img_embs), 5)]) 196 | 197 | start = time.time() 198 | # sims = compute_sim(img_embs, cap_embs) 199 | sims = shard_attn_scores(model, img_embs, cap_embs, cap_lens, opt, shard_size=opt.batch_size) 200 | end = time.time() 201 | logger.info("calculate similarity time: {}".format(end - start)) 202 | 203 | # caption retrieval 204 | npts = img_embs.shape[0] 205 | # (r1, r5, r10, medr, meanr) = i2t(img_embs, cap_embs, cap_lens, sims) 206 | (r1, r5, r10, medr, meanr) = i2t(npts, sims) 207 | logging.info("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % 208 | (r1, r5, r10, medr, meanr)) 209 | # image retrieval 210 | # (r1i, r5i, r10i, medri, meanr) = t2i(img_embs, cap_embs, cap_lens, sims) 211 | (r1i, r5i, r10i, medri, meanr) = t2i(npts, sims) 212 | logging.info("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % 213 | (r1i, r5i, r10i, medri, meanr)) 214 | # sum of recalls to be used for early stopping 215 | currscore = r1 + r5 + r10 + r1i + r5i + r10i 216 | logger.info('Current rsum is {}'.format(currscore)) 217 | 218 | 219 | message = "Epoch: %d: Image to text: (%.1f, %.1f, %.1f) " % (epoch, r1, r5, r10) 220 | message += "Text to image: (%.1f, %.1f, %.1f) " % (r1i, r5i, r10i) 221 | message += "rsum: %.1f\n" % currscore 222 | 223 | log_file = os.path.join(opt.logger_name, "performance.log") 224 | logging_func(log_file, message) 225 | 226 | return currscore 227 | 228 | 229 | def save_checkpoint(state, is_best, filename='checkpoint.pth', prefix=''): 230 | logger = logging.getLogger(__name__) 231 | tries = 15 232 | 233 | # deal with unstable I/O. Usually not necessary. 
234 | while tries: 235 | try: 236 | torch.save(state, prefix + filename) 237 | if is_best: 238 | torch.save(state, prefix + 'model_best.pth') 239 | except IOError as e: 240 | error = e 241 | tries -= 1 242 | else: 243 | break 244 | logger.info('model save {} failed, remaining {} trials'.format(filename, tries)) 245 | if not tries: 246 | raise error 247 | 248 | 249 | def adjust_learning_rate(opt, optimizer, epoch, lr_schedules): 250 | logger = logging.getLogger(__name__) 251 | """Sets the learning rate to the initial LR 252 | decayed by 10 every opt.lr_update epochs""" 253 | if epoch in lr_schedules: 254 | logger.info('Current epoch num is {}, decrease all lr by 10'.format(epoch, )) 255 | for param_group in optimizer.param_groups: 256 | old_lr = param_group['lr'] 257 | new_lr = old_lr * 0.1 258 | param_group['lr'] = new_lr 259 | logger.info('new lr {}'.format(new_lr)) 260 | 261 | 262 | def count_params(model): 263 | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) 264 | params = sum([np.prod(p.size()) for p in model_parameters]) 265 | return params 266 | 267 | 268 | if __name__ == '__main__': 269 | os.environ["CUDA_VISIBLE_DEVICES"] = "4" 270 | main() 271 | -------------------------------------------------------------------------------- /train_region_coco.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=2 python3 ../ESL_MAIN_BERT/train.py \ 2 | --data_path ../MS-COCO/ \ 3 | --data_name coco_precomp \ 4 | --logger_name ../ESL_MAIN_BERT/log \ 5 | --model_name ../ESL_MAIN_BERT/checkpoint \ 6 | --batch_size 128 \ 7 | --num_epochs=20 \ 8 | --lr_update=10 \ 9 | --learning_rate=.0005 \ 10 | --precomp_enc_type basic \ 11 | --workers 10 \ 12 | --log_step 200 \ 13 | --embed_size 512 \ 14 | --vse_mean_warmup_epochs 1 \ 15 | --kernel_size 2 16 | -------------------------------------------------------------------------------- /train_region_f30k.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=1 python3 ../ESL_MAIN_BERT/train.py \ 2 | --data_path ../Flickr30K/ \ 3 | --data_name f30k_precomp \ 4 | --logger_name ../ESL_MAIN_BERT/log \ 5 | --model_name ../ESL_MAIN_BERT/checkpoint2 \ 6 | --batch_size 128 \ 7 | --num_epochs=30 \ 8 | --lr_update=15 \ 9 | --learning_rate=.0005 \ 10 | --precomp_enc_type basic \ 11 | --workers 10 \ 12 | --log_step 200 \ 13 | --embed_size 512 \ 14 | --vse_mean_warmup_epochs 1 \ 15 | --kernel_size 2 16 | --------------------------------------------------------------------------------
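The ensemble results reported by test_stack.py come from simply averaging two models' precomputed similarity matrices (the sim_best.txt files referenced above) and re-ranking with the same i2t/t2i recall metrics. The snippet below is a minimal, self-contained sketch of that averaging-and-ranking step on a toy (N, 5N) similarity matrix; it is illustrative only, and the names sims_model_a / sims_model_b are placeholders rather than files produced by this repository.

# Illustrative sketch (toy random data, not part of the repo): ensemble two
# similarity matrices by averaging, then compute image-to-text R@1 in the same
# way as i2t() in test_stack.py (captions 5*i .. 5*i+4 match image i).
import numpy as np

def recall_at_1_i2t(sims):
    # sims: (N, 5N) image-to-caption similarity matrix
    n_images = sims.shape[0]
    hits = 0
    for i in range(n_images):
        best_caption = int(np.argmax(sims[i]))    # top-ranked caption for image i
        if 5 * i <= best_caption < 5 * i + 5:     # correct if it is one of its 5 captions
            hits += 1
    return 100.0 * hits / n_images

if __name__ == '__main__':
    rng = np.random.default_rng(0)
    sims_model_a = rng.random((3, 15))            # placeholder for model 1's saved matrix
    sims_model_b = rng.random((3, 15))            # placeholder for model 2's saved matrix
    sims = (sims_model_a + sims_model_b) / 2      # averaging (or summing) gives the same ranking
    print("ensemble i2t R@1: %.1f" % recall_at_1_i2t(sims))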