├── .gitignore ├── .gitmodules ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── cfgs ├── anet_res101_vg_feat_10x100prop.yml ├── conda_env_gvd.yml └── conda_env_gvd_py3.yml ├── data └── vg_object_vocab.txt ├── demo ├── good_examples.txt └── gvid_teaser.png ├── main.py ├── misc ├── AttModel.py ├── CaptionModelBU.py ├── __init__.py ├── bbox_transform.py ├── dataloader_anet.py ├── model.py ├── transformer.py └── utils.py ├── opts.py ├── prepro └── prepro_dic_anet.py └── tools ├── download_all.sh └── vg_cls_overlap.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.h5 3 | *.json 4 | *.csv 5 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "tools/densevid_eval"] 2 | path = tools/densevid_eval 3 | url = https://github.com/LuoweiZhou/densevid_eval_spice 4 | [submodule "tools/anet_entities"] 5 | path = tools/anet_entities 6 | url = https://github.com/facebookresearch/ActivityNet-Entities.git 7 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. 4 | Please read the [full text](https://code.fb.com/codeofconduct/) 5 | so that you can understand what actions will and will not be tolerated. 6 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Grounded Video Description 2 | We want to make contributing to this project as easy and transparent as possible. 3 | 4 | ## Our Development Process 5 | Minor changes and improvements will be released on an ongoing basis. 6 | Larger changes (e.g., changesets implementing a new paper) will be released 7 | on a more periodic basis. 8 | 9 | 10 | ## Pull Requests 11 | We actively welcome your pull requests. 12 | 13 | 1. Fork the repo and create your branch from `master`. 14 | 2. If you've added code that should be tested, add tests. 15 | 3. If you've changed APIs, update the documentation. 16 | 4. Ensure the test suite passes. 17 | 5. Make sure your code lints. 18 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 19 | 20 | ## Contributor License Agreement ("CLA") 21 | In order to accept your pull request, we need you to submit a CLA. You only need 22 | to do this once to work on any of Facebook's open source projects. 23 | 24 | Complete your CLA here: 25 | 26 | ## Issues 27 | We use GitHub issues to track public bugs. Please ensure your description is 28 | clear and has sufficient instructions to be able to reproduce the issue. 29 | 30 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 31 | disclosure of security bugs. In those cases, please go through the process 32 | outlined on that page and do not file a public issue. 33 | 34 | ## Coding Style 35 | * 4 spaces for indentation rather than tabs 36 | 37 | ## License 38 | By contributing to Grounded Video Description, you agree that your contributions will 39 | be licensed under the LICENSE file in the root directory of this source tree. 
40 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | ============================================================================= 24 | 25 | MIT License 26 | 27 | Copyright (c) 2017 Jiasen Lu 28 | 29 | Permission is hereby granted, free of charge, to any person obtaining a copy 30 | of this software and associated documentation files (the "Software"), to deal 31 | in the Software without restriction, including without limitation the rights 32 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 33 | copies of the Software, and to permit persons to whom the Software is 34 | furnished to do so, subject to the following conditions: 35 | 36 | The above copyright notice and this permission notice shall be included in all 37 | copies or substantial portions of the Software. 38 | 39 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 40 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 41 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 42 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 43 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 44 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 45 | SOFTWARE. 46 | 47 | ============================================================================= 48 | 49 | For the following file(s): 50 | grounded-video-description/misc/transformer.py 51 | 52 | BSD 3-Clause "New" or "Revised" License 53 | 54 | Copyright (c) 2018, Salesforce.com, Inc. 55 | All rights reserved. 56 | 57 | Redistribution and use in source and binary forms, with or without modification, 58 | are permitted provided that the following conditions are met: 59 | 60 | * Redistributions of source code must retain the above copyright notice, 61 | this list of conditions and the following disclaimer. 62 | 63 | * Redistributions in binary form must reproduce the above copyright notice, 64 | this list of conditions and the following disclaimer in the documentation 65 | and/or other materials provided with the distribution. 
66 | 67 | * Neither the name of Salesforce.com nor the names of its contributors may be 68 | used to endorse or promote products derived from this software without specific 69 | prior written permission. 70 | 71 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 72 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 73 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 74 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 75 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 76 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 77 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 78 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 79 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF 80 | THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 81 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Grounded Video Description 2 | 3 | ### [ActivityNet Entities Object Localization (Grounding) Challenge](http://activity-net.org/challenges/2020/tasks/guest_anet_eol.html) joins the official [ActivityNet Challenge](http://activity-net.org/challenges/2020/challenge.html) as a guest task this year! See [here](https://github.com/facebookresearch/ActivityNet-Entities#activitynet-entities-object-localization-challenge-2020) for how to participate. 4 | 5 | This repo hosts the source code for our paper [Grounded Video Description](https://arxiv.org/pdf/1812.06587.pdf). It supports the [ActivityNet-Entities](https://github.com/facebookresearch/ActivityNet-Entities) dataset. We also have code that supports the [Flickr30k-Entities](https://github.com/BryanPlummer/flickr30k_entities) dataset, hosted in the [flickr_branch](https://github.com/facebookresearch/grounded-video-description/tree/flickr_branch) branch. 6 | 7 | ![teaser results](demo/gvid_teaser.png) 8 | 9 | Note: [42] indicates [Masked Transformer](https://github.com/LuoweiZhou/densecap) 10 | 11 | 12 | ## Quick Start 13 | ### Preparations 14 | Follow instructions 1 to 3 in the [Requirements](#req) section to install the required packages. 15 | 16 | ### Download everything 17 | Simply run the following command to download all the data and pre-trained models (216GB in total): 18 | ``` 19 | bash tools/download_all.sh 20 | ``` 21 | 22 | ### Starter code 23 | Run the following evaluation command to test whether your environment is set up correctly: 24 | ``` 25 | python main.py --batch_size 100 --cuda --num_workers 6 --max_epoch 50 --inference_only \ 26 | --start_from save/anet-sup-0.05-0-0.1-run1 --id anet-sup-0.05-0-0.1-run1 \ 27 | --seq_length 20 --language_eval --eval_obj_grounding --obj_interact 28 | ``` 29 | 30 | (Optional) Run the single-GPU training command below as an additional check: 31 | ``` 32 | python main.py --batch_size 20 --cuda --checkpoint_path save/gvd_starter --id gvd_starter --language_eval 33 | ``` 34 | You can now skip to the [Training and Validation](#train) section! 35 | 36 | 37 | ## Requirements (Recommended) 38 | 1) Clone the repo recursively: 39 | ``` 40 | git clone --recursive git@github.com:facebookresearch/grounded-video-description.git 41 | ``` 42 | Make sure all the submodules, [densevid_eval](https://github.com/LuoweiZhou/densevid_eval_spice) and [coco-caption](https://github.com/tylin/coco-caption), are included.
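If the repo was cloned without `--recursive`, the standard git command below should pull in the submodules listed in `.gitmodules` (a minimal sketch; run it from the repo root):
```
# Fetch the evaluation submodules (e.g., tools/densevid_eval and tools/anet_entities)
# in case --recursive was omitted when cloning
git submodule update --init --recursive
```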
43 | 44 | 2) Install CUDA 9.0 and CUDNN v7.1. Later versions should be fine, but you might need to update the conda env file accordingly (e.g., for PyTorch). 45 | 46 | 3) Install [Miniconda](https://conda.io/miniconda.html) (either Miniconda2 or 3, version 4.6+). We recommend using a conda [environment](https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html) to install the required packages, including Python 3.7 or 2.7, [PyTorch 1.1.0](https://pytorch.org/get-started/locally/), etc.: 47 | ``` 48 | MINICONDA_ROOT=[to your Miniconda root directory] 49 | conda env create -f cfgs/conda_env_gvd_py3.yml --prefix $MINICONDA_ROOT/envs/gvd_pytorch1.1 50 | conda activate gvd_pytorch1.1 51 | ``` 52 | Note that there have been some [breaking changes](https://github.com/pytorch/pytorch/releases/tag/v1.2.0) since PyTorch 1.2 (e.g., bitwise not on torch.bool/torch.uint8 and masked\_fill\_). This code base could potentially work with PyTorch 1.2+ if the corresponding changes are made. 53 | 54 | Replace `cfgs/conda_env_gvd_py3.yml` with `cfgs/conda_env_gvd.yml` for Python 2.7. 55 | 56 | 4) (Optional) If you choose not to use `download_all.sh`, be sure to install Java and download Stanford CoreNLP for SPICE (see [here](https://github.com/tylin/coco-caption)). Also, download the reference [file](https://github.com/jiasenlu/coco-caption/blob/master/annotations/caption_flickr30k.json) and place it under `coco-caption/annotations`. Download [Stanford CoreNLP 3.9.1](https://stanfordnlp.github.io/CoreNLP/history.html) for grounding evaluation and place the uncompressed folder under the `tools` directory. 57 | 58 | 59 | ## Data Preparation 60 | Updates on 04/15/2020: Feature files for the **hidden** test set, used in the ANet-Entities Object Localization Challenge 2020, are available for download ([region features](https://dl.fbaipublicfiles.com/ActivityNet-Entities/ActivityNet-Entities/fc6_feat_100rois_hidden_test.tar.gz) and [frame-wise features](https://dl.fbaipublicfiles.com/ActivityNet-Entities/ActivityNet-Entities/rgb_motion_1d_hidden_test.tar.gz)). Make sure you move the additional *.npy files into your `fc6_feat_100rois` and `rgb_motion_1d` folders, respectively. The following files have been updated to include the **hidden** test set or its video IDs: `anet_detection_vg_fc6_feat_100rois.h5`, `anet_entities_prep.tar.gz`, and `anet_entities_captions.tar.gz`. 61 | 62 | Download the preprocessed annotation files from [here](https://dl.fbaipublicfiles.com/ActivityNet-Entities/ActivityNet-Entities/anet_entities_prep.tar.gz), uncompress them, and place them under `data/anet`. Alternatively, you can reproduce them all using the data from the ActivityNet-Entities [repo](https://github.com/facebookresearch/ActivityNet-Entities) and the preprocessing script `prepro_dic_anet.py` under `prepro`. Then, download the ground-truth caption annotations (under our val/test splits) from [here](https://dl.fbaipublicfiles.com/ActivityNet-Entities/ActivityNet-Entities/anet_entities_captions.tar.gz) and place them under `data/anet` as well. 63 | 64 | The region features and detections are available for download ([feature](https://dl.fbaipublicfiles.com/ActivityNet-Entities/ActivityNet-Entities/fc6_feat_100rois.tar.gz) and [detection](https://dl.fbaipublicfiles.com/ActivityNet-Entities/ActivityNet-Entities/anet_detection_vg_fc6_feat_100rois.h5)). The region feature file should be decompressed and placed under your feature directory. We refer to the region feature directory as `feature_root` in the code.
The H5 region detection (proposal) file is referred to as `proposal_h5` in the code. To extract features for a customized dataset (or, for the adventurous, to re-extract them for ANet-Entities), refer to the feature extraction tool [here](https://github.com/LuoweiZhou/detectron-vlp). 65 | 66 | The frame-wise appearance (with suffix `_resnet.npy`) and motion (with suffix `_bn.npy`) feature files are available [here](https://dl.fbaipublicfiles.com/ActivityNet-Entities/ActivityNet-Entities/rgb_motion_1d.tar.gz). We refer to this directory as `seg_feature_root`. 67 | 68 | Other auxiliary files, such as the weights from the Detectron fc7 layer, are available [here](https://dl.fbaipublicfiles.com/ActivityNet-Entities/ActivityNet-Entities/detectron_weights.tar.gz). Uncompress and place them under the `data` directory. 69 | 70 | 71 | ## Training and Validation 72 | Modify the config file `cfgs/anet_res101_vg_feat_10x100prop.yml` with the correct dataset and feature paths (or set up symlinks). Link `tools/anet_entities` to your ANet-Entities dataset root location. Create new directories `log` and `results` under the root directory to store log and result files. 73 | 74 | An example command for running an 8-GPU data-parallel job: 75 | 76 | For supervised models (with self-attention): 77 | ``` 78 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py --path_opt cfgs/anet_res101_vg_feat_10x100prop.yml \ 79 | --batch_size $batch_size --cuda --checkpoint_path save/$ID --id $ID --mGPUs \ 80 | --language_eval --w_att2 $w_att2 --w_grd $w_grd --w_cls $w_cls --obj_interact | tee log/$ID 81 | ``` 82 | 83 | For unsupervised models (without self-attention): 84 | ``` 85 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py --path_opt cfgs/anet_res101_vg_feat_10x100prop.yml \ 86 | --batch_size $batch_size --cuda --checkpoint_path save/$ID --id $ID --mGPUs \ 87 | --language_eval | tee log/$ID 88 | ``` 89 | Arguments: `batch_size=240`, `w_att2=0.05`, `w_grd=0`, `w_cls=0.1`; `ID` is the model name. 90 | 91 | (Optional) Remove `--mGPUs` to run in single-GPU mode. 92 | 93 | ### Pre-trained Models 94 | The pre-trained models can be downloaded from [here (1.5GB)](https://dl.fbaipublicfiles.com/ActivityNet-Entities/ActivityNet-Entities/pre-trained-models.tar.gz). Make sure you uncompress the file under the `save` directory (create it under the root directory if it does not exist).
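The directory setup described above can be done along the following lines (a minimal sketch; the dataset path is a placeholder to replace with your own location, and the archive is assumed to have been downloaded to the repo root):
```
# Create the output directories expected by the training/eval commands
mkdir -p log results save
# Link tools/anet_entities to your local ActivityNet-Entities dataset root
# (placeholder path; only needed if tools/anet_entities is not already a populated checkout)
ln -s /path/to/ActivityNet-Entities tools/anet_entities
# (Optional) unpack the pre-trained models under save/
tar -xzf pre-trained-models.tar.gz -C save
```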
95 | 96 | 97 | ## Inference and Testing 98 | For supervised models (`ID=anet-sup-0.05-0-0.1-run1`): 99 | 100 | (standard inference: language evaluation and localization evaluation on generated sentences) 101 | 102 | ``` 103 | python main.py --path_opt cfgs/anet_res101_vg_feat_10x100prop.yml --batch_size 100 --cuda \ 104 | --num_workers 6 --max_epoch 50 --inference_only --start_from save/$ID --id $ID \ 105 | --val_split $val_split --densecap_references $dc_references --densecap_verbose --seq_length 20 \ 106 | --language_eval --eval_obj_grounding --obj_interact \ 107 | | tee log/eval-$val_split-$ID-beam$beam_size-standard-inference 108 | ``` 109 | 110 | (GT inference: localization evaluation on GT sentences) 111 | 112 | ``` 113 | python main.py --path_opt cfgs/anet_res101_vg_feat_10x100prop.yml --batch_size 100 --cuda \ 114 | --num_workers 6 --max_epoch 50 --inference_only --start_from save/$ID --id $ID \ 115 | --val_split $val_split --seq_length 40 --eval_obj_grounding_gt --obj_interact \ 116 | --grd_reference $grd_reference | tee log/eval-$val_split-$ID-beam$beam_size-gt-inference 117 | ``` 118 | 119 | For unsupervised models (`ID=anet-unsup-0-0-0-run1`), simply remove the `--obj_interact` option. 120 | 121 | Arguments: `dc_references='./data/anet/anet_entities_val_1.json ./data/anet/anet_entities_val_2.json'`, `grd_reference='tools/anet_entities/data/anet_entities_cleaned_class_thresh50_trainval.json'`, `val_split='validation'`. If you want to evaluate on the test splits, set `val_split` to `'testing'` or `'hidden_test'` and adjust `dc_references` (look for `anet_entities_test_1.json` and `anet_entities_test_2.json`; only available for `'testing'`) and `grd_reference` (the skeleton files `*testing*.json` and `*hidden_test*.json`) accordingly. Then, submit the object localization output files under `results` to the [eval server](https://competitions.codalab.org/competitions/20537). Note that this eval server is for general purposes; the servers designed for the CVPR'20 challenge are instead [here](https://github.com/facebookresearch/ActivityNet-Entities#evaluation-servers). 122 | 123 | You need at least 9GB of free GPU memory for the evaluation. 124 | 125 | 126 | ## Reference 127 | Please acknowledge the following paper if you use the code: 128 | 129 | ``` 130 | @inproceedings{zhou2019grounded, 131 | title={Grounded Video Description}, 132 | author={Zhou, Luowei and Kalantidis, Yannis and Chen, Xinlei and Corso, Jason J and Rohrbach, Marcus}, 133 | booktitle={CVPR}, 134 | year={2019} 135 | } 136 | ``` 137 | 138 | 139 | ## Acknowledgement 140 | We thank Jiasen Lu for his [Neural Baby Talk](https://github.com/jiasenlu/NeuralBabyTalk) repo. We thank Chih-Yao Ma for his helpful discussions. 141 | 142 | 143 | ## License 144 | This project is licensed under the license found in the LICENSE file in the root directory of this source tree. 145 | 146 | Portions of the source code are based on the [Neural Baby Talk](https://github.com/jiasenlu/NeuralBabyTalk) project.
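For reference, the shell variables used in the inference commands in the Inference and Testing section can be set, for example, as follows (a sketch based on the Arguments note above; `beam_size=1` is assumed, since the grounding evaluation in `main.py` asserts a beam size of 1):
```
ID=anet-sup-0.05-0-0.1-run1
val_split=validation
beam_size=1   # grounding evaluation only supports beam size 1
dc_references='./data/anet/anet_entities_val_1.json ./data/anet/anet_entities_val_2.json'
grd_reference='tools/anet_entities/data/anet_entities_cleaned_class_thresh50_trainval.json'
```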
147 | -------------------------------------------------------------------------------- /cfgs/anet_res101_vg_feat_10x100prop.yml: -------------------------------------------------------------------------------- 1 | # dataset setting 2 | dataset: anet 3 | input_json: 'data/anet/cap_anet_trainval.json' 4 | input_dic: 'data/anet/dic_anet.json' 5 | seg_feature_root: 'data/anet/rgb_motion_1d' 6 | image_path: 'data/anet/frames_10frm' # for visualization only 7 | feature_root: 'data/anet/fc6_feat_100rois' 8 | proposal_h5: 'data/anet/anet_detection_vg_fc6_feat_100rois.h5' 9 | num_prop_per_frm: 100 10 | data_path: 'data' 11 | # language model 12 | att_model: topdown 13 | num_layers: 1 14 | seq_per_img: 1 15 | val_images_use: -1 16 | optim: 'adam' 17 | -------------------------------------------------------------------------------- /cfgs/conda_env_gvd.yml: -------------------------------------------------------------------------------- 1 | name: gvd_pytorch1.1 2 | channels: 3 | - defaults 4 | dependencies: 5 | - ca-certificates=2019.1.23=0 6 | - libedit=3.1.20181209=hc058e9b_0 7 | - libffi=3.2.1=hd88cf55_4 8 | - libgcc-ng=8.2.0=hdf63c60_1 9 | - libstdcxx-ng=8.2.0=hdf63c60_1 10 | - ncurses=6.1=he6710b0_1 11 | - openssl=1.1.1b=h7b6447c_1 12 | - pip=19.1=py27_0 13 | - python=2.7.16=h9bab390_0 14 | - readline=7.0=h7b6447c_5 15 | - setuptools=41.0.1=py27_0 16 | - sqlite=3.28.0=h7b6447c_0 17 | - tk=8.6.8=hbc83047_0 18 | - wheel=0.33.1=py27_0 19 | - zlib=1.2.11=h7b6447c_3 20 | - pip: 21 | - backports-functools-lru-cache==1.5 22 | - certifi==2019.3.9 23 | - chardet==3.0.4 24 | - cycler==0.10.0 25 | - future==0.17.1 26 | - h5py==2.9.0 27 | - idna==2.8 28 | - kiwisolver==1.1.0 29 | - matplotlib==2.2.4 30 | - numpy==1.16.3 31 | - pillow==6.0.0 32 | - psutil==5.6.2 33 | - pyparsing==2.4.0 34 | - python-dateutil==2.8.0 35 | - pytz==2019.1 36 | - pyyaml==5.1 37 | - requests==2.21.0 38 | - six==1.12.0 39 | - stanfordcorenlp==3.9.1.1 40 | - subprocess32==3.5.3 41 | - torch==1.1.0 42 | - torchtext==0.3.1 43 | - torchvision==0.2.2.post3 44 | - tqdm==4.31.1 45 | - urllib3==1.24.3 46 | prefix: /home/luozhou/subsystem/tmp/miniconda3/envs/gvd_pytorch1.1 47 | 48 | -------------------------------------------------------------------------------- /cfgs/conda_env_gvd_py3.yml: -------------------------------------------------------------------------------- 1 | name: gvd_py3_pytorch1.1 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - blas=1.0=mkl 8 | - ca-certificates=2019.11.27=0 9 | - certifi=2019.11.28=py37_0 10 | - cffi=1.13.2=py37h2e261b9_0 11 | - cudatoolkit=9.0=h13b8566_0 12 | - freetype=2.9.1=h8a8886c_1 13 | - intel-openmp=2019.4=243 14 | - jpeg=9b=h024ee3a_2 15 | - ld_impl_linux-64=2.33.1=h53a641e_7 16 | - libedit=3.1.20181209=hc058e9b_0 17 | - libffi=3.2.1=hd88cf55_4 18 | - libgcc-ng=9.1.0=hdf63c60_0 19 | - libgfortran-ng=7.3.0=hdf63c60_0 20 | - libpng=1.6.37=hbc83047_0 21 | - libstdcxx-ng=9.1.0=hdf63c60_0 22 | - libtiff=4.1.0=h2733197_0 23 | - mkl=2019.4=243 24 | - mkl-service=2.3.0=py37he904b0f_0 25 | - mkl_fft=1.0.15=py37ha843d7b_0 26 | - mkl_random=1.1.0=py37hd6b4f25_0 27 | - ncurses=6.1=he6710b0_1 28 | - ninja=1.9.0=py37hfd86e86_0 29 | - numpy=1.17.4=py37hc1035e2_0 30 | - numpy-base=1.17.4=py37hde5b4d6_0 31 | - olefile=0.46=py37_0 32 | - openssl=1.1.1d=h7b6447c_3 33 | - pip=19.3.1=py37_0 34 | - pycparser=2.19=py37_0 35 | - python=3.7.6=h0371630_1 36 | - pytorch=1.1.0=py3.7_cuda9.0.176_cudnn7.5.1_0 37 | - readline=7.0=h7b6447c_5 38 | - setuptools=44.0.0=py37_0 39 | - 
six=1.13.0=py37_0 40 | - sqlite=3.30.1=h7b6447c_0 41 | - tk=8.6.8=hbc83047_0 42 | - torchvision=0.3.0=py37_cu9.0.176_1 43 | - wheel=0.33.6=py37_0 44 | - xz=5.2.4=h14c3975_4 45 | - zlib=1.2.11=h7b6447c_3 46 | - zstd=1.3.7=h0b5b093_0 47 | - pip: 48 | - chardet==3.0.4 49 | - cycler==0.10.0 50 | - h5py==2.10.0 51 | - idna==2.8 52 | - kiwisolver==1.1.0 53 | - matplotlib==3.1.2 54 | - pillow==6.1.0 55 | - psutil==5.6.7 56 | - pyparsing==2.4.6 57 | - python-dateutil==2.8.1 58 | - pyyaml==5.3 59 | - requests==2.22.0 60 | - stanfordcorenlp==3.9.1.1 61 | - torchtext==0.4.0 62 | - tqdm==4.41.1 63 | - urllib3==1.25.7 64 | prefix: /z/home/luozhou/miniconda3/envs/gvd_py3_pytorch1.1 65 | 66 | -------------------------------------------------------------------------------- /data/vg_object_vocab.txt: -------------------------------------------------------------------------------- 1 | yolk 2 | goal 3 | bathroom 4 | macaroni 5 | umpire 6 | toothpick 7 | alarm clock 8 | ceiling fan 9 | photos 10 | parrot 11 | tail fin 12 | birthday cake 13 | calculator 14 | catcher 15 | toilet 16 | batter 17 | stop sign,stopsign 18 | cone 19 | microwave,microwave oven 20 | skateboard ramp 21 | tea 22 | dugout 23 | products 24 | halter 25 | kettle 26 | kitchen 27 | refrigerator,fridge 28 | ostrich 29 | bathtub 30 | blinds 31 | court 32 | urinal 33 | knee pads 34 | bed 35 | flamingo 36 | giraffe 37 | helmet 38 | giraffes 39 | tennis court 40 | motorcycle 41 | laptop 42 | tea pot 43 | horse 44 | television,tv 45 | shorts 46 | manhole 47 | dishwasher 48 | jeans 49 | sail 50 | monitor 51 | man 52 | shirt 53 | car 54 | cat 55 | garage door 56 | bus 57 | radiator 58 | tights 59 | sailboat,sail boat 60 | racket,racquet 61 | plate 62 | rock wall 63 | beach 64 | trolley 65 | ocean 66 | headboard,head board 67 | tea kettle 68 | wetsuit 69 | tennis racket,tennis racquet 70 | sink 71 | train 72 | keyboard 73 | sky 74 | match 75 | train station 76 | stereo 77 | bats 78 | tennis player 79 | toilet brush 80 | lighter 81 | pepper shaker 82 | gazebo 83 | hair dryer 84 | elephant 85 | toilet seat 86 | zebra 87 | skateboard,skate board 88 | zebras 89 | floor lamp 90 | french fries 91 | woman 92 | player 93 | tower 94 | bicycle 95 | magazines 96 | christmas tree 97 | umbrella 98 | cow 99 | pants 100 | bike 101 | field 102 | living room 103 | latch 104 | bedroom 105 | grape 106 | castle 107 | table 108 | swan 109 | blender 110 | orange 111 | teddy bear 112 | net 113 | meter 114 | baseball field 115 | runway 116 | screen 117 | ski boot 118 | dog 119 | clock 120 | hair 121 | avocado 122 | highway 123 | skirt 124 | frisbee 125 | parasail 126 | desk 127 | pizza 128 | mouse 129 | sign 130 | shower curtain 131 | polar bear 132 | airplane 133 | jersey 134 | reigns 135 | hot dog,hotdog 136 | surfboard,surf board 137 | couch 138 | glass 139 | snowboard 140 | girl 141 | plane 142 | elephants 143 | oven 144 | dirt bike 145 | tail wing 146 | area rug 147 | bear 148 | washer 149 | date 150 | bow tie 151 | cows 152 | fire extinguisher 153 | bamboo 154 | wallet 155 | tail feathers 156 | truck 157 | beach chair 158 | boat 159 | tablet 160 | ceiling 161 | chandelier 162 | sheep 163 | glasses 164 | ram 165 | kite 166 | salad 167 | pillow 168 | fire hydrant,hydrant 169 | mug 170 | tarmac 171 | computer 172 | swimsuit 173 | tomato 174 | tire 175 | cauliflower 176 | fireplace 177 | snow 178 | building 179 | sandwich 180 | weather vane 181 | bird 182 | jacket 183 | chair 184 | water 185 | cats 186 | soccer ball 187 | horses 188 | drapes 189 | barn 190 | engine 191 | 
cake 192 | head 193 | head band 194 | skier 195 | town 196 | bath tub 197 | bowl 198 | stove 199 | tongue 200 | coffee table 201 | floor 202 | uniform 203 | ottoman 204 | broccoli 205 | olive 206 | mound 207 | pitcher 208 | food 209 | paintings 210 | traffic light 211 | parking meter 212 | bananas 213 | mountain 214 | cage 215 | hedge 216 | motorcycles 217 | wet suit 218 | radish 219 | teddy bears 220 | monitors 221 | suitcase,suit case 222 | drawers 223 | grass 224 | apple 225 | lamp 226 | goggles 227 | boy 228 | armchair 229 | ramp 230 | burner 231 | lamb 232 | cup 233 | tank top 234 | boats 235 | hat 236 | soup 237 | fence 238 | necklace 239 | visor 240 | coffee 241 | bottle 242 | stool 243 | shoe 244 | surfer 245 | stop 246 | backpack 247 | shin guard 248 | wii remote 249 | wall 250 | pizza slice 251 | home plate 252 | van 253 | packet 254 | earrings 255 | wristband 256 | tracks 257 | mitt 258 | dome 259 | snowboarder 260 | faucet 261 | toiletries 262 | ski boots 263 | room 264 | fork 265 | snow suit 266 | banana slice 267 | bench 268 | tie 269 | burners 270 | stuffed animals 271 | zoo 272 | train platform 273 | cupcake 274 | curtain 275 | ear 276 | tissue box 277 | bread 278 | scissors 279 | vase 280 | herd 281 | smoke 282 | skylight 283 | cub 284 | tail 285 | cutting board 286 | wave 287 | hedges 288 | windshield 289 | apples 290 | mirror 291 | license plate 292 | tree 293 | wheel 294 | ski pole 295 | clock tower 296 | freezer 297 | luggage 298 | skateboarder 299 | mousepad 300 | road 301 | bat 302 | toilet tank 303 | vanity 304 | neck 305 | cliff 306 | tub 307 | sprinkles 308 | dresser 309 | street 310 | wing 311 | suit 312 | veggie 313 | palm trees 314 | urinals 315 | door 316 | propeller 317 | keys 318 | skate park 319 | platform 320 | pot 321 | towel 322 | computer monitor 323 | flip flop 324 | eggs 325 | shed 326 | moped 327 | sand 328 | face 329 | scissor 330 | carts 331 | squash 332 | pillows 333 | family 334 | glove 335 | rug 336 | watch 337 | grafitti 338 | dogs 339 | scoreboard 340 | basket 341 | poster 342 | duck 343 | horns 344 | bears 345 | jeep 346 | painting 347 | lighthouse 348 | remote control 349 | toaster 350 | vegetables 351 | surfboards 352 | ducks 353 | lane 354 | carrots 355 | market 356 | paper towels 357 | island 358 | blueberries 359 | smile 360 | balloons 361 | stroller 362 | napkin 363 | towels 364 | papers 365 | person 366 | train tracks 367 | child 368 | headband 369 | pool 370 | plant 371 | harbor 372 | counter 373 | hand 374 | house 375 | donut,doughnut 376 | knot 377 | soccer player 378 | seagull 379 | bottles 380 | buses 381 | coat 382 | trees 383 | geese 384 | bun 385 | toilet bowl 386 | trunk 387 | station 388 | bikini 389 | goatee 390 | lounge chair 391 | breakfast 392 | nose 393 | moon 394 | river 395 | racer 396 | picture 397 | shaker 398 | sidewalk,side walk 399 | shutters 400 | stove top,stovetop 401 | church 402 | lampshade 403 | map 404 | shop 405 | platter 406 | airport 407 | hoodie 408 | oranges 409 | woods 410 | enclosure 411 | skatepark 412 | vases 413 | city 414 | park 415 | mailbox 416 | balloon 417 | billboard 418 | pasture 419 | portrait 420 | forehead 421 | ship 422 | cookie 423 | seaweed 424 | sofa 425 | slats 426 | tomato slice 427 | tractor 428 | bull 429 | suitcases 430 | graffiti 431 | policeman 432 | remotes 433 | pens 434 | window sill 435 | suspenders 436 | easel 437 | tray 438 | straw 439 | collar 440 | shower 441 | bag 442 | scooter 443 | tails 444 | toilet lid 445 | panda 446 | comforter 447 | outlet 448 | stems 449 | 
valley 450 | flag 451 | jockey 452 | gravel 453 | mouth 454 | window 455 | bridge 456 | corn 457 | mountains 458 | beer 459 | pitcher's mound 460 | palm tree 461 | crowd 462 | skis 463 | phone 464 | banana bunch 465 | tennis shoe 466 | ground 467 | carpet 468 | eye 469 | urn 470 | beak 471 | giraffe head 472 | steeple 473 | mattress 474 | baseball player 475 | wine 476 | water bottle 477 | kitten 478 | archway 479 | candle 480 | croissant 481 | tennis ball 482 | dress 483 | column 484 | utensils 485 | cell phone 486 | computer mouse 487 | cap 488 | lawn 489 | airplanes 490 | carriage 491 | snout 492 | cabinets 493 | lemons 494 | grill 495 | umbrellas 496 | meat 497 | wagon 498 | ipod 499 | bookshelf 500 | cart 501 | roof 502 | hay 503 | ski pants 504 | seat 505 | mane 506 | bikes 507 | drawer 508 | game 509 | clock face 510 | boys 511 | rider 512 | fire escape 513 | slope 514 | iphone 515 | pumpkin 516 | pan 517 | chopsticks 518 | hill 519 | uniforms 520 | cleat 521 | costume 522 | cabin 523 | police officer 524 | ears 525 | egg 526 | trash can 527 | horn 528 | arrow 529 | toothbrush 530 | carrot 531 | banana 532 | planes 533 | garden 534 | forest 535 | brocolli 536 | aircraft 537 | front window 538 | dashboard 539 | statue 540 | saucer 541 | people 542 | silverware 543 | fruit 544 | drain 545 | jet 546 | speaker 547 | eyes 548 | railway 549 | lid 550 | soap 551 | rocks 552 | office chair 553 | door knob 554 | banana peel 555 | baseball game 556 | asparagus 557 | spoon 558 | cabinet door 559 | pineapple 560 | traffic cone 561 | nightstand,night stand 562 | teapot 563 | taxi 564 | chimney 565 | lake 566 | suit jacket 567 | train engine 568 | ball 569 | wrist band 570 | pickle 571 | fruits 572 | pad 573 | dispenser 574 | bridle 575 | breast 576 | cones 577 | headlight 578 | necktie 579 | skater 580 | toilet paper 581 | skyscraper 582 | telephone 583 | ox 584 | roadway 585 | sock 586 | paddle 587 | dishes 588 | hills 589 | street sign 590 | headlights 591 | benches 592 | fuselage 593 | card 594 | napkins 595 | bush 596 | rice 597 | computer screen 598 | spokes 599 | flowers 600 | bucket 601 | rock 602 | pole 603 | pear 604 | sauce 605 | store 606 | juice 607 | knobs 608 | mustard 609 | ski 610 | stands 611 | cabinet 612 | dirt 613 | goats 614 | wine glass 615 | spectators 616 | crate 617 | pancakes 618 | kids 619 | engines 620 | shade 621 | feeder 622 | cellphone 623 | pepper 624 | blanket 625 | sunglasses 626 | train car 627 | magnet 628 | donuts,doughnuts 629 | sweater 630 | signal 631 | advertisement 632 | log 633 | vent 634 | whiskers 635 | adult 636 | arch 637 | locomotive 638 | tennis match 639 | tent 640 | motorbike 641 | magnets 642 | night 643 | marina 644 | wool 645 | vest 646 | railroad tracks 647 | stuffed bear 648 | moustache 649 | bib 650 | frame 651 | snow pants 652 | tank 653 | undershirt 654 | icons 655 | neck tie 656 | beams 657 | baseball bat 658 | safety cone 659 | paper towel 660 | bedspread 661 | can 662 | container 663 | flower 664 | vehicle 665 | tomatoes 666 | back wheel 667 | soccer field 668 | nostril 669 | suv 670 | buildings 671 | canopy 672 | flame 673 | kid 674 | baseball 675 | throw pillow 676 | belt 677 | rainbow 678 | lemon 679 | oven door 680 | tag 681 | books 682 | monument 683 | men 684 | shadow 685 | bicycles 686 | cars 687 | lamp shade 688 | pine tree 689 | bouquet 690 | toothpaste 691 | potato 692 | sinks 693 | hook 694 | switch 695 | lamp post,lamppost 696 | lapel 697 | desert 698 | knob 699 | chairs 700 | pasta 701 | feathers 702 | hole 703 | meal 704 
| station wagon 705 | kites 706 | boots 707 | baby 708 | biker 709 | gate 710 | signal light 711 | headphones 712 | goat 713 | waves 714 | bumper 715 | bud 716 | logo 717 | curtains 718 | american flag 719 | yacht 720 | box 721 | baseball cap 722 | fries 723 | controller 724 | awning 725 | path 726 | front legs 727 | life jacket 728 | purse 729 | outfield 730 | pigeon 731 | toddler 732 | beard 733 | thumb 734 | water tank 735 | board 736 | parade 737 | robe 738 | newspaper 739 | wires 740 | camera 741 | pastries 742 | deck 743 | watermelon 744 | clouds 745 | deer 746 | motorcyclist 747 | kneepad 748 | sneakers 749 | women 750 | onions 751 | eyebrow 752 | gas station 753 | vane 754 | girls 755 | trash 756 | numerals 757 | knife 758 | tags 759 | light 760 | bunch 761 | outfit 762 | groom 763 | infield 764 | frosting 765 | forks 766 | entertainment center 767 | stuffed animal 768 | yard 769 | numeral 770 | ladder 771 | shoes 772 | bracelet 773 | teeth 774 | guy 775 | display case 776 | cushion 777 | post 778 | pathway 779 | tablecloth 780 | skiers 781 | trouser 782 | cloud 783 | hands 784 | produce 785 | beam 786 | ketchup 787 | paw 788 | dish 789 | raft 790 | crosswalk 791 | front wheel 792 | toast 793 | cattle 794 | players 795 | group 796 | coffee pot 797 | track 798 | cowboy hat 799 | petal 800 | eyeglasses 801 | handle 802 | table cloth 803 | jets 804 | shakers 805 | remote 806 | snowsuit 807 | bushes 808 | dessert 809 | leg 810 | eagle 811 | fire truck,firetruck 812 | game controller 813 | smartphone 814 | backsplash 815 | trains 816 | shore 817 | signs 818 | bell 819 | cupboards 820 | sweat band 821 | sack 822 | ankle 823 | coin slot 824 | bagel 825 | masts 826 | police 827 | drawing 828 | biscuit 829 | toy 830 | legs 831 | pavement 832 | outside 833 | wheels 834 | driver 835 | numbers 836 | blazer 837 | pen 838 | cabbage 839 | trucks 840 | key 841 | saddle 842 | pillow case 843 | goose 844 | label 845 | boulder 846 | pajamas 847 | wrist 848 | shelf 849 | cross 850 | coffee cup 851 | foliage 852 | lot 853 | fry 854 | air 855 | officer 856 | pepperoni 857 | cheese 858 | lady 859 | kickstand 860 | counter top 861 | veggies 862 | baseball uniform 863 | book shelf 864 | bags 865 | pickles 866 | stand 867 | netting 868 | lettuce 869 | facial hair 870 | lime 871 | animals 872 | drape 873 | boot 874 | railing 875 | end table 876 | shin guards 877 | steps 878 | trashcan 879 | tusk 880 | head light 881 | walkway 882 | cockpit 883 | tennis net 884 | animal 885 | boardwalk 886 | keypad 887 | bookcase 888 | blueberry 889 | trash bag 890 | ski poles 891 | parking lot 892 | gas tank 893 | beds 894 | fan 895 | base 896 | soap dispenser 897 | banner 898 | life vest 899 | train front 900 | word 901 | cab 902 | liquid 903 | exhaust pipe 904 | sneaker 905 | light fixture 906 | power lines 907 | curb 908 | scene 909 | buttons 910 | roman numerals 911 | muzzle 912 | sticker 913 | bacon 914 | pizzas 915 | paper 916 | feet 917 | stairs 918 | triangle 919 | plants 920 | rope 921 | beans 922 | brim 923 | beverage 924 | letters 925 | soda 926 | menu 927 | finger 928 | dvds 929 | candles 930 | picnic table 931 | wine bottle 932 | pencil 933 | tree trunk 934 | nail 935 | mantle 936 | countertop 937 | view 938 | line 939 | motor bike 940 | audience 941 | traffic sign 942 | arm 943 | pedestrian 944 | stabilizer 945 | dock 946 | doorway 947 | bedding 948 | end 949 | worker 950 | canal 951 | crane 952 | grate 953 | little girl 954 | rims 955 | passenger car 956 | plates 957 | background 958 | peel 959 | brake light 
960 | roman numeral 961 | string 962 | tines 963 | turf 964 | armrest 965 | shower head 966 | leash 967 | stones 968 | stoplight 969 | handle bars 970 | front 971 | scarf 972 | band 973 | jean 974 | tennis 975 | pile 976 | doorknob 977 | foot 978 | houses 979 | windows 980 | restaurant 981 | booth 982 | cardboard box 983 | fingers 984 | mountain range 985 | bleachers 986 | rail 987 | pastry 988 | canoe 989 | sun 990 | eye glasses 991 | salt shaker 992 | number 993 | fish 994 | knee pad 995 | fur 996 | she 997 | shower door 998 | rod 999 | branches 1000 | birds 1001 | printer 1002 | sunset 1003 | median 1004 | shutter 1005 | slice 1006 | heater 1007 | prongs 1008 | bathing suit 1009 | skiier 1010 | rack 1011 | book 1012 | blade 1013 | apartment 1014 | manhole cover 1015 | stools 1016 | overhang 1017 | door handle 1018 | couple 1019 | picture frame 1020 | chicken 1021 | planter 1022 | seats 1023 | hour hand 1024 | dvd player 1025 | ski slope 1026 | french fry 1027 | bowls 1028 | top 1029 | landing gear 1030 | coffee maker 1031 | melon 1032 | computers 1033 | light switch 1034 | jar 1035 | tv stand 1036 | overalls 1037 | garage 1038 | tabletop 1039 | writing 1040 | doors 1041 | stadium 1042 | placemat 1043 | air vent 1044 | trick 1045 | sled 1046 | mast 1047 | pond 1048 | steering wheel 1049 | baseball glove 1050 | watermark 1051 | pie 1052 | sandwhich 1053 | cpu 1054 | mushroom 1055 | power pole 1056 | dirt road 1057 | handles 1058 | speakers 1059 | fender 1060 | telephone pole 1061 | strawberry 1062 | mask 1063 | children 1064 | crust 1065 | art 1066 | rim 1067 | branch 1068 | display 1069 | grasses 1070 | photo 1071 | receipt 1072 | instructions 1073 | herbs 1074 | toys 1075 | handlebars 1076 | trailer 1077 | sandal 1078 | skull 1079 | hangar 1080 | pipe 1081 | office 1082 | chest 1083 | lamps 1084 | horizon 1085 | calendar 1086 | foam 1087 | stone 1088 | bars 1089 | button 1090 | poles 1091 | heart 1092 | hose 1093 | jet engine 1094 | potatoes 1095 | rain 1096 | magazine 1097 | chain 1098 | footboard 1099 | tee shirt 1100 | design 1101 | walls 1102 | copyright 1103 | pictures 1104 | pillar 1105 | drink 1106 | barrier 1107 | boxes 1108 | chocolate 1109 | chef 1110 | slot 1111 | sweatpants 1112 | face mask 1113 | icing 1114 | wipers 1115 | circle 1116 | bin 1117 | kitty 1118 | electronics 1119 | wild 1120 | tiles 1121 | steam 1122 | lettering 1123 | bathroom sink 1124 | laptop computer 1125 | cherry 1126 | spire 1127 | conductor 1128 | sheet 1129 | slab 1130 | windshield wipers 1131 | storefront 1132 | hill side 1133 | spatula 1134 | tail light,taillight 1135 | bean 1136 | wire 1137 | intersection 1138 | pier 1139 | snow board 1140 | trunks 1141 | website 1142 | bolt 1143 | kayak 1144 | nuts 1145 | holder 1146 | turbine 1147 | stop light 1148 | olives 1149 | ball cap 1150 | burger 1151 | barrel 1152 | fans 1153 | beanie 1154 | stem 1155 | lines 1156 | traffic signal 1157 | sweatshirt 1158 | handbag 1159 | mulch 1160 | socks 1161 | landscape 1162 | soda can 1163 | shelves 1164 | ski lift 1165 | cord 1166 | vegetable 1167 | apron 1168 | blind 1169 | bracelets 1170 | stickers 1171 | traffic 1172 | strip 1173 | tennis shoes 1174 | swim trunks 1175 | hillside 1176 | sandals 1177 | concrete 1178 | lips 1179 | butter knife 1180 | words 1181 | leaves 1182 | train cars 1183 | spoke 1184 | cereal 1185 | pine trees 1186 | cooler 1187 | bangs 1188 | half 1189 | sheets 1190 | figurine 1191 | park bench 1192 | stack 1193 | second floor 1194 | motor 1195 | hand towel 1196 | wristwatch 1197 | spectator 
1198 | tissues 1199 | flip flops 1200 | quilt 1201 | floret 1202 | calf 1203 | back pack 1204 | grapes 1205 | ski tracks 1206 | skin 1207 | bow 1208 | controls 1209 | dinner 1210 | baseball players 1211 | ad 1212 | ribbon 1213 | hotel 1214 | sea 1215 | cover 1216 | tarp 1217 | weather 1218 | notebook 1219 | mustache 1220 | stone wall 1221 | closet 1222 | statues 1223 | bank 1224 | skateboards 1225 | butter 1226 | dress shirt 1227 | knee 1228 | wood 1229 | laptops 1230 | cuff 1231 | hubcap 1232 | wings 1233 | range 1234 | structure 1235 | balls 1236 | tunnel 1237 | globe 1238 | utensil 1239 | dumpster 1240 | cd 1241 | floors 1242 | wrapper 1243 | folder 1244 | pocket 1245 | mother 1246 | ski goggles 1247 | posts 1248 | power line 1249 | wake 1250 | roses 1251 | train track 1252 | reflection 1253 | air conditioner 1254 | referee 1255 | barricade 1256 | baseball mitt 1257 | mouse pad 1258 | garbage can 1259 | buckle 1260 | footprints 1261 | lights 1262 | muffin 1263 | bracket 1264 | plug 1265 | taxi cab 1266 | drinks 1267 | surfers 1268 | arrows 1269 | control panel 1270 | ring 1271 | twigs 1272 | soil 1273 | skies 1274 | clock hand 1275 | caboose 1276 | playground 1277 | mango 1278 | stump 1279 | brick wall 1280 | screw 1281 | minivan 1282 | leaf 1283 | fencing 1284 | ledge 1285 | clothes 1286 | grass field 1287 | plumbing 1288 | blouse 1289 | patch 1290 | scaffolding 1291 | hamburger 1292 | utility pole 1293 | teddy 1294 | rose 1295 | skillet 1296 | cycle 1297 | cable 1298 | gloves 1299 | bark 1300 | decoration 1301 | tables 1302 | palm 1303 | wii 1304 | mountain top 1305 | shrub 1306 | hoof 1307 | celery 1308 | beads 1309 | plaque 1310 | flooring 1311 | surf 1312 | cloth 1313 | passenger 1314 | spot 1315 | plastic 1316 | knives 1317 | case 1318 | railroad 1319 | pony 1320 | muffler 1321 | hot dogs,hotdogs 1322 | stripe 1323 | scale 1324 | block 1325 | recliner 1326 | body 1327 | shades 1328 | tap 1329 | tools 1330 | cupboard 1331 | wallpaper 1332 | sculpture 1333 | surface 1334 | sedan 1335 | distance 1336 | shrubs 1337 | skiis 1338 | lift 1339 | bottom 1340 | cleats 1341 | roll 1342 | clothing 1343 | bed frame 1344 | slacks 1345 | tail lights 1346 | doll 1347 | traffic lights 1348 | symbol 1349 | strings 1350 | fixtures 1351 | short 1352 | paint 1353 | candle holder 1354 | guard rail 1355 | cyclist 1356 | tree branches 1357 | ripples 1358 | gear 1359 | waist 1360 | trash bin 1361 | onion 1362 | home 1363 | side mirror 1364 | brush 1365 | sweatband 1366 | handlebar 1367 | light pole 1368 | street lamp 1369 | pads 1370 | ham 1371 | artwork 1372 | reflector 1373 | figure 1374 | tile 1375 | mountainside 1376 | black 1377 | bricks 1378 | paper plate 1379 | stick 1380 | beef 1381 | patio 1382 | weeds 1383 | back 1384 | sausage 1385 | paws 1386 | farm 1387 | decal 1388 | harness 1389 | monkey 1390 | fence post 1391 | door frame 1392 | stripes 1393 | clocks 1394 | ponytail 1395 | toppings 1396 | strap 1397 | carton 1398 | greens 1399 | chin 1400 | lunch 1401 | name 1402 | earring 1403 | area 1404 | tshirt,t-shirt,t shirt 1405 | cream 1406 | rails 1407 | cushions 1408 | lanyard 1409 | brick 1410 | hallway 1411 | cucumber 1412 | wire fence 1413 | fern 1414 | tangerine 1415 | windowsill 1416 | pipes 1417 | package 1418 | wheelchair 1419 | chips 1420 | driveway 1421 | tattoo 1422 | side window 1423 | stairway 1424 | basin 1425 | machine 1426 | table lamp 1427 | radio 1428 | pony tail 1429 | ocean water 1430 | inside 1431 | cargo 1432 | overpass 1433 | mat 1434 | socket 1435 | flower pot 1436 | tree 
line 1437 | sign post 1438 | tube 1439 | dial 1440 | splash 1441 | male 1442 | lantern 1443 | lipstick 1444 | lip 1445 | tongs 1446 | ski suit 1447 | trail 1448 | passenger train 1449 | bandana 1450 | antelope 1451 | designs 1452 | tents 1453 | photograph 1454 | catcher's mitt 1455 | electrical outlet 1456 | tires 1457 | boulders 1458 | mannequin 1459 | plain 1460 | layer 1461 | mushrooms 1462 | strawberries 1463 | piece 1464 | oar 1465 | bike rack 1466 | slices 1467 | arms 1468 | fin 1469 | shadows 1470 | hood 1471 | windshield wiper 1472 | letter 1473 | dot 1474 | bus stop 1475 | railings 1476 | pebbles 1477 | mud 1478 | claws 1479 | police car 1480 | crown 1481 | meters 1482 | name tag 1483 | entrance 1484 | staircase 1485 | shrimp 1486 | ladies 1487 | peak 1488 | vines 1489 | computer keyboard 1490 | glass door 1491 | pears 1492 | pant 1493 | wine glasses 1494 | stall 1495 | asphalt 1496 | columns 1497 | sleeve 1498 | pack 1499 | cheek 1500 | baskets 1501 | land 1502 | day 1503 | blocks 1504 | courtyard 1505 | pedal 1506 | panel 1507 | seeds 1508 | balcony 1509 | yellow 1510 | disc 1511 | young man 1512 | eyebrows 1513 | crumbs 1514 | spinach 1515 | emblem 1516 | object 1517 | bar 1518 | cardboard 1519 | tissue 1520 | light post 1521 | ski jacket 1522 | seasoning 1523 | parasol 1524 | terminal 1525 | surfing 1526 | streetlight,street light 1527 | alley 1528 | cords 1529 | image 1530 | jug 1531 | antenna 1532 | puppy 1533 | berries 1534 | diamond 1535 | pans 1536 | fountain 1537 | foreground 1538 | syrup 1539 | bride 1540 | spray 1541 | license 1542 | peppers 1543 | passengers 1544 | cement 1545 | flags 1546 | shack 1547 | trough 1548 | objects 1549 | arches 1550 | streamer 1551 | pots 1552 | border 1553 | baseboard 1554 | beer bottle 1555 | wrist watch 1556 | tile floor 1557 | page 1558 | pin 1559 | items 1560 | baseline 1561 | hanger 1562 | tree branch 1563 | tusks 1564 | donkey 1565 | containers 1566 | condiments 1567 | device 1568 | envelope 1569 | parachute 1570 | mesh 1571 | hut 1572 | butterfly 1573 | salt 1574 | restroom 1575 | twig 1576 | pilot 1577 | ivy 1578 | furniture 1579 | clay 1580 | print 1581 | sandwiches 1582 | lion 1583 | shingles 1584 | pillars 1585 | vehicles 1586 | panes 1587 | shoreline 1588 | stream 1589 | control 1590 | lock 1591 | microphone 1592 | blades 1593 | towel rack 1594 | coaster 1595 | star 1596 | petals 1597 | text 1598 | feather 1599 | spots 1600 | buoy 1601 | -------------------------------------------------------------------------------- /demo/good_examples.txt: -------------------------------------------------------------------------------- 1 | v_9Hxcuf80TK0_segment_00 2 | v_a68fUj833qg_segment_00 3 | v_M1-G6KEhY-M_segment_04 4 | v_PllZQ09sBuI_segment_01 5 | v_tCQiu-qY9XA_segment_01 6 | v_tS2d90ZGmeA_segment_03 7 | v_Gl6EMAgTNKo_segment_00 8 | v_E3UCEbGZmz0_segment_00 9 | -------------------------------------------------------------------------------- /demo/gvid_teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/grounded-video-description/2667e216223a63cdc91b616a34f4c046e507a1dc/demo/gvid_teaser.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | from torch.autograd import Variable 16 | import torch.optim as optim 17 | 18 | import numpy as np 19 | import random 20 | import time 21 | import os 22 | import pickle 23 | import torch.backends.cudnn as cudnn 24 | import yaml 25 | import copy 26 | import json 27 | import math 28 | 29 | import opts 30 | from misc import utils, AttModel 31 | from collections import defaultdict 32 | 33 | import torchvision.transforms as transforms 34 | import pdb 35 | 36 | # hack to allow the imports of evaluation repos 37 | _SCRIPTPATH_ = os.path.dirname(os.path.abspath(__file__)) 38 | import sys 39 | sys.path.insert(0, os.path.join(_SCRIPTPATH_, 'tools/densevid_eval')) 40 | sys.path.insert(0, os.path.join(_SCRIPTPATH_, 'tools/densevid_eval/coco-caption')) 41 | sys.path.insert(0, os.path.join(_SCRIPTPATH_, 'tools/anet_entities/scripts')) 42 | 43 | from evaluate import ANETcaptions 44 | from eval_grd_anet_entities import ANetGrdEval 45 | 46 | 47 | # visualization over generated sentences 48 | def vis_infer(seg_show, seg_id, caption, att2_weights, proposals, num_box, gt_bboxs, sim_mat, seg_dim_info): 49 | cap = caption.split() 50 | output = [] 51 | top_k_prop = 1 # plot the top 1 proposal only 52 | proposal = proposals[:num_box[1].item()] 53 | gt_bbox = gt_bboxs[:num_box[2].item()] 54 | 55 | sim_mat_val, sim_mat_ind = torch.max(sim_mat, dim=0) 56 | 57 | for j in range(len(cap)): 58 | 59 | max_att2_weight, top_k_alpha_idx = torch.max(att2_weights[j], dim=0) 60 | 61 | idx = top_k_alpha_idx 62 | target_frm = int(proposal[idx, 4].item()) 63 | seg = copy.deepcopy(seg_show[target_frm, :seg_dim_info[0], :seg_dim_info[1]].numpy()) 64 | seg_text = np.ones((67, int(seg_dim_info[1]), 3))*255 65 | cv2.putText(seg_text, '%s' % (cap[j]), (50, 50), cv2.FONT_HERSHEY_PLAIN, 3.0, (255, 0, 0), thickness=3) 66 | 67 | # draw the proposal box and text 68 | idx = top_k_alpha_idx 69 | bbox = proposal[idx, :4] 70 | bbox = tuple(int(np.round(x)) for x in proposal[idx, :4]) 71 | class_name = opt.itod.get(sim_mat_ind[idx].item(), '__background__') 72 | cv2.rectangle(seg, bbox[0:2], bbox[2:4], 73 | (0, 255, 0), 2) 74 | cv2.putText(seg, '%s: (%.2f)' % (class_name, sim_mat_val[idx]), 75 | (bbox[0], bbox[1] + 25), cv2.FONT_HERSHEY_PLAIN, 2.0, (0, 0, 255), thickness=2) 76 | 77 | output.append(np.concatenate([seg_text, seg], axis=0)) 78 | 79 | output = np.concatenate(output, axis=1) 80 | if not os.path.isdir('./vis'): 81 | os.mkdir('./vis') 82 | if not os.path.isdir('./vis/'+opt.id): 83 | os.mkdir('./vis/'+opt.id) 84 | # print('Visualize segment {} and the generated sentence!'.format(seg_id)) 85 | cv2.imwrite('./vis/'+opt.id+'/'+str(seg_id)+'_generated_sent.jpg', output[:,:,::-1]) 86 | 87 | 88 | # compute localization (attention/grounding) accuracy over GT sentences 89 | def eval_grounding(opt, vis=None): 90 | model.eval() 91 | 92 | data_iter = iter(dataloader_val) 93 | cls_pred_lst = [] 94 | cls_accu_score = defaultdict(list) 95 | att2_output = defaultdict(dict) 96 | grd_output = defaultdict(dict) 97 | vocab_in_split = set() 98 | 99 | for step in range(len(dataloader_val)): 100 | data = data_iter.next() 101 | seg_feat, iseq, gts_seq, num, proposals, bboxs, box_mask, seg_id, region_feat, 
frm_mask, sample_idx, ppl_mask = data 102 | 103 | proposals = proposals[:,:max(int(max(num[:,1])),1),:] 104 | ppl_mask = ppl_mask[:,:max(int(max(num[:,1])),1)] 105 | assert(max(int(max(num[:,1])),1) == opt.num_sampled_frm*opt.num_prop_per_frm) 106 | bboxs = bboxs[:,:max(int(max(num[:,2])),1),:] 107 | frm_mask = frm_mask[:, :max(int(max(num[:,1])),1), :max(int(max(num[:,2])),1)] 108 | region_feat = region_feat[:,:max(int(max(num[:,1])),1),:] 109 | 110 | segs_feat.resize_(seg_feat.size()).data.copy_(seg_feat) 111 | input_seqs.resize_(iseq.size()).data.copy_(iseq) 112 | gt_seqs.resize_(gts_seq.size()).data.copy_(gts_seq) 113 | input_num.resize_(num.size()).data.copy_(num) 114 | input_ppls.resize_(proposals.size()).data.copy_(proposals) 115 | mask_ppls.resize_(ppl_mask.size()).data.copy_(ppl_mask) 116 | pnt_mask = torch.cat((mask_ppls.new(mask_ppls.size(0), 1).fill_(0), mask_ppls), dim=1) 117 | gt_bboxs.resize_(bboxs.size()).data.copy_(bboxs) # for region cls eval only 118 | mask_frms.resize_(frm_mask.size()).data.copy_(frm_mask) # for region cls eval only 119 | ppls_feat.resize_(region_feat.size()).data.copy_(region_feat) 120 | sample_idx = Variable(sample_idx.type(input_seqs.type())) 121 | 122 | dummy = input_ppls.new(input_ppls.size(0)).byte().fill_(0) 123 | 124 | # cls_pred_hm_lst contains a list of tuples (clss_ind, hit/1 or miss/0) 125 | cls_pred_hm_lst, att2_ind, grd_ind = model(segs_feat, input_seqs, gt_seqs, input_num, 126 | input_ppls, gt_bboxs, dummy, ppls_feat, mask_frms, sample_idx, pnt_mask, 'GRD') 127 | 128 | # save attention/grounding results on GT sentences 129 | obj_mask = (input_seqs[:,0,1:,0] > opt.vocab_size) # Bx20 130 | obj_bbox_att2 = torch.gather(input_ppls.view(-1, opt.num_sampled_frm, opt.num_prop_per_frm, 7) \ 131 | .permute(0, 2, 1, 3).contiguous(), 1, att2_ind.unsqueeze(-1).expand((att2_ind.size(0), \ 132 | att2_ind.size(1), opt.num_sampled_frm, 7))) # Bx20x10x7 133 | obj_bbox_grd = torch.gather(input_ppls.view(-1, opt.num_sampled_frm, opt.num_prop_per_frm, 7) \ 134 | .permute(0, 2, 1, 3).contiguous(), 1, grd_ind.unsqueeze(-1).expand((grd_ind.size(0), \ 135 | grd_ind.size(1), opt.num_sampled_frm, 7))) # Bx20x10x7 136 | 137 | for i in range(obj_mask.size(0)): 138 | vid_id, seg_idx = seg_id[i].split('_segment_') 139 | seg_idx = str(int(seg_idx)) 140 | tmp_result_grd = {'clss':[], 'idx_in_sent':[], 'bbox_for_all_frames':[]} 141 | tmp_result_att2 = {'clss':[], 'idx_in_sent':[], 'bbox_for_all_frames':[]} 142 | for j in range(obj_mask.size(1)): 143 | if obj_mask[i, j]: 144 | cls_name = opt.itod[input_seqs[i,0,j+1,0].item()-opt.vocab_size] 145 | vocab_in_split.update([cls_name]) 146 | tmp_result_att2['clss'].append(cls_name) 147 | tmp_result_att2['idx_in_sent'].append(j) 148 | tmp_result_att2['bbox_for_all_frames'].append(obj_bbox_att2[i, j, :, :4].tolist()) 149 | tmp_result_grd['clss'].append(cls_name) 150 | tmp_result_grd['idx_in_sent'].append(j) 151 | tmp_result_grd['bbox_for_all_frames'].append(obj_bbox_grd[i, j, :, :4].tolist()) 152 | att2_output[vid_id][seg_idx] = tmp_result_att2 153 | grd_output[vid_id][seg_idx] = tmp_result_grd 154 | 155 | cls_pred_lst.append(cls_pred_hm_lst) 156 | 157 | # write results to file 158 | attn_file = 'results/attn-gt-sent-results-'+opt.val_split+'-'+opt.id+'.json' 159 | with open(attn_file, 'w') as f: 160 | json.dump({'results':att2_output, 'eval_mode':'GT', 'external_data':{'used':True, 'details':'Object detector pre-trained on Visual Genome on object detection task.'}}, f) 161 | grd_file = 
'results/grd-gt-sent-results-'+opt.val_split+'-'+opt.id+'.json' 162 | with open(grd_file, 'w') as f: 163 | json.dump({'results':grd_output, 'eval_mode':'GT', 'external_data':{'used':True, 'details':'Object detector pre-trained on Visual Genome on object detection task.'}}, f) 164 | 165 | if not opt.test_mode: 166 | cls_pred_lst = torch.cat(cls_pred_lst, dim=0).cpu() 167 | cls_accu_lst = torch.cat((cls_pred_lst[:, 0:1], (cls_pred_lst[:, 0:1] == cls_pred_lst[:, 1:2]).long()), dim=1) 168 | for i in range(cls_accu_lst.size(0)): 169 | cls_accu_score[cls_accu_lst[i,0].long().item()].append(cls_accu_lst[i,1].item()) 170 | print('Total number of object classes in the split: {}. {} have classification results.'.format(len(vocab_in_split), len(cls_accu_score))) 171 | cls_accu = np.sum([sum(hm)*1./len(hm) for i,hm in cls_accu_score.items()])*1./len(vocab_in_split) 172 | 173 | # offline eval 174 | evaluator = ANetGrdEval(reference_file=opt.grd_reference, submission_file=attn_file, 175 | split_file=opt.split_file, val_split=[opt.val_split], 176 | iou_thresh=0.5) 177 | 178 | attn_accu = evaluator.gt_grd_eval() 179 | evaluator.import_sub(grd_file) 180 | grd_accu = evaluator.gt_grd_eval() 181 | 182 | print('\nResults Summary (GT sent):') 183 | print('The averaged attention / grounding box accuracy across all classes is: {:.4f} / {:.4f}'.format(attn_accu, grd_accu)) 184 | print('The averaged classification accuracy across all classes is: {:.4f}\n'.format(cls_accu)) 185 | 186 | return attn_accu, grd_accu, cls_accu 187 | else: 188 | print('*'*62) 189 | print('* [WARNING] Grounding eval unavailable for the test set!\ 190 | *\n* Please submit your result file named *\ 191 | \n* results/grd-gt-sent-*.json to the eval server! *') 192 | print('*'*62) 193 | 194 | return 0, 0, 0 195 | 196 | 197 | def train(epoch, opt, vis=None, vis_window=None): 198 | model.train() 199 | 200 | data_iter = iter(dataloader) 201 | nbatches = len(dataloader) 202 | train_loss = [] 203 | 204 | lm_loss_temp = [] 205 | att2_loss_temp = [] 206 | ground_loss_temp = [] 207 | cls_loss_temp = [] 208 | start = time.time() 209 | 210 | for step in range(len(dataloader)-1): 211 | data = data_iter.next() 212 | seg_feat, iseq, gts_seq, num, proposals, bboxs, box_mask, seg_id, region_feat, frm_mask, sample_idx, ppl_mask = data 213 | proposals = proposals[:,:max(int(max(num[:,1])),1),:] 214 | ppl_mask = ppl_mask[:,:max(int(max(num[:,1])),1)] 215 | bboxs = bboxs[:,:max(int(max(num[:,2])),1),:] 216 | box_mask = box_mask[:,:,:max(int(max(num[:,2])),1),:] 217 | frm_mask = frm_mask[:,:max(int(max(num[:,1])),1),:max(int(max(num[:,2])),1)] 218 | region_feat = region_feat[:,:max(int(max(num[:,1])),1),:] 219 | 220 | segs_feat.resize_(seg_feat.size()).data.copy_(seg_feat) 221 | input_seqs.resize_(iseq.size()).data.copy_(iseq) 222 | gt_seqs.resize_(gts_seq.size()).data.copy_(gts_seq) 223 | input_num.resize_(num.size()).data.copy_(num) 224 | input_ppls.resize_(proposals.size()).data.copy_(proposals) 225 | mask_ppls.resize_(ppl_mask.size()).data.copy_(ppl_mask) 226 | # pad 1 column from a legacy reason 227 | pnt_mask = torch.cat((mask_ppls.new(mask_ppls.size(0), 1).fill_(0), mask_ppls), dim=1) 228 | gt_bboxs.resize_(bboxs.size()).data.copy_(bboxs) 229 | mask_bboxs.resize_(box_mask.size()).data.copy_(box_mask) 230 | mask_frms.resize_(frm_mask.size()).data.copy_(frm_mask) 231 | ppls_feat.resize_(region_feat.size()).data.copy_(region_feat) 232 | sample_idx = Variable(sample_idx.type(input_seqs.type())) 233 | 234 | loss = 0 235 | lm_loss, att2_loss, ground_loss, 
cls_loss = model(segs_feat, input_seqs, gt_seqs, input_num, 236 | input_ppls, gt_bboxs, mask_bboxs, ppls_feat, mask_frms, sample_idx, pnt_mask, 'MLE') 237 | 238 | w_att2, w_grd, w_cls = opt.w_att2, opt.w_grd, opt.w_cls 239 | att2_loss = w_att2*att2_loss.sum() 240 | ground_loss = w_grd*ground_loss.sum() 241 | cls_loss = w_cls*cls_loss.sum() 242 | 243 | if not opt.disable_caption: 244 | loss += lm_loss.sum() 245 | else: 246 | lm_loss.fill_(0) 247 | 248 | if w_att2: 249 | loss += att2_loss 250 | if w_grd: 251 | loss += ground_loss 252 | if w_cls: 253 | loss += cls_loss 254 | 255 | loss = loss / lm_loss.numel() 256 | train_loss.append(loss.item()) 257 | 258 | lm_loss_temp.append(lm_loss.sum().item() / lm_loss.numel()) 259 | att2_loss_temp.append(att2_loss.sum().item() / lm_loss.numel()) 260 | ground_loss_temp.append(ground_loss.sum().item() / lm_loss.numel()) 261 | cls_loss_temp.append(cls_loss.sum().item() / lm_loss.numel()) 262 | 263 | model.zero_grad() 264 | loss.backward() 265 | nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip) 266 | optimizer.step() 267 | 268 | if step % opt.disp_interval == 0 and step != 0: 269 | end = time.time() 270 | 271 | print("step {}/{} (epoch {}), lm_loss = {:.3f}, att2_loss = {:.3f}, ground_loss = {:.3f},cls_los = {:.3f}, lr = {:.5f}, time/batch = {:.3f}" \ 272 | .format(step, len(dataloader), epoch, np.mean(lm_loss_temp), np.mean(att2_loss_temp), \ 273 | np.mean(ground_loss_temp), np.mean(cls_loss_temp), opt.learning_rate, end - start)) 274 | start = time.time() 275 | 276 | if opt.enable_visdom: 277 | if vis_window['iter'] is None: 278 | vis_window['iter'] = vis.line( 279 | X=np.tile(np.arange(epoch*nbatches+step, epoch*nbatches+step+1), 280 | (5,1)).T, 281 | Y=np.column_stack((np.asarray(np.mean(train_loss)), 282 | np.asarray(np.mean(lm_loss_temp)), 283 | np.asarray(np.mean(att2_loss_temp)), 284 | np.asarray(np.mean(ground_loss_temp)), 285 | np.asarray(np.mean(cls_loss_temp)))), 286 | opts=dict(title='Training Loss', 287 | xlabel='Training Iteration', 288 | ylabel='Loss', 289 | legend=['total', 'lm', 'attn', 'grd', 'cls']) 290 | ) 291 | else: 292 | vis.line( 293 | X=np.tile(np.arange(epoch*nbatches+step, epoch*nbatches+step+1), 294 | (5,1)).T, 295 | Y=np.column_stack((np.asarray(np.mean(train_loss)), 296 | np.asarray(np.mean(lm_loss_temp)), 297 | np.asarray(np.mean(att2_loss_temp)), 298 | np.asarray(np.mean(ground_loss_temp)), 299 | np.asarray(np.mean(cls_loss_temp)))), 300 | opts=dict(title='Training Loss', 301 | xlabel='Training Iteration', 302 | ylabel='Loss', 303 | legend=['total', 'lm', 'attn', 'grd', 'cls']), 304 | win=vis_window['iter'], 305 | update='append' 306 | ) 307 | 308 | # Write the training loss summary 309 | if (iteration % opt.losses_log_every == 0): 310 | loss_history[iteration] = loss.item() 311 | lr_history[iteration] = opt.learning_rate 312 | 313 | 314 | def eval(epoch, opt, vis=None, vis_window=None): 315 | model.eval() 316 | 317 | data_iter_val = iter(dataloader_val) 318 | start = time.time() 319 | 320 | num_show = 0 321 | predictions = defaultdict(list) 322 | count = 0 323 | timestamp_file = json.load(open(opt.grd_reference)) 324 | min_value = -1e8 325 | 326 | if opt.eval_obj_grounding: 327 | grd_output = defaultdict(dict) 328 | 329 | lemma_det_dict = {opt.wtol[key]:idx for key,idx in opt.wtod.items() if key in opt.wtol} 330 | print('{} classes have the associated lemma word!'.format(len(lemma_det_dict))) 331 | 332 | if opt.eval_obj_grounding or opt.language_eval: 333 | for step in range(len(dataloader_val)): 334 | data = 
data_iter_val.next() 335 | if opt.vis_attn: 336 | seg_feat, iseq, gts_seq, num, proposals, bboxs, box_mask, seg_id, seg_show, seg_dim_info, region_feat, frm_mask, sample_idx, ppl_mask = data 337 | else: 338 | seg_feat, iseq, gts_seq, num, proposals, bboxs, box_mask, seg_id, region_feat, frm_mask, sample_idx, ppl_mask = data 339 | 340 | proposals = proposals[:,:max(int(max(num[:,1])),1),:] 341 | ppl_mask = ppl_mask[:,:max(int(max(num[:,1])),1)] 342 | region_feat = region_feat[:,:max(int(max(num[:,1])),1),:] 343 | 344 | segs_feat.resize_(seg_feat.size()).data.copy_(seg_feat) 345 | input_num.resize_(num.size()).data.copy_(num) 346 | input_ppls.resize_(proposals.size()).data.copy_(proposals) 347 | mask_ppls.resize_(ppl_mask.size()).data.copy_(ppl_mask) 348 | pnt_mask = torch.cat((mask_ppls.new(mask_ppls.size(0), 1).fill_(0), mask_ppls), dim=1) # pad 1 column from a legacy reason 349 | ppls_feat.resize_(region_feat.size()).data.copy_(region_feat) 350 | sample_idx = Variable(sample_idx.type(input_num.type())) 351 | 352 | eval_opt = {'sample_max':1, 'beam_size': opt.beam_size, 'inference_mode' : True} 353 | dummy = input_ppls.new(input_ppls.size(0)).byte().fill_(0) 354 | 355 | batch_size = input_ppls.size(0) 356 | 357 | seq, att2_weights, sim_mat = model(segs_feat, dummy, dummy, input_num, \ 358 | input_ppls, dummy, dummy, ppls_feat, dummy, sample_idx, pnt_mask, 'sample', eval_opt) 359 | 360 | # save localization results on generated sentences 361 | if opt.eval_obj_grounding: 362 | assert opt.beam_size == 1, 'only support beam_size is 1' 363 | 364 | att2_ind = torch.max(att2_weights.view(batch_size, att2_weights.size(1), \ 365 | opt.num_sampled_frm, opt.num_prop_per_frm), dim=-1)[1] 366 | obj_bbox_att2 = torch.gather(input_ppls.view(-1, opt.num_sampled_frm, opt.num_prop_per_frm, 7) \ 367 | .permute(0, 2, 1, 3).contiguous(), 1, att2_ind.unsqueeze(-1).expand((batch_size, \ 368 | att2_ind.size(1), opt.num_sampled_frm, input_ppls.size(-1)))) # Bx20x10x7 369 | 370 | for i in range(seq.size(0)): 371 | vid_id, seg_idx = seg_id[i].split('_segment_') 372 | seg_idx = str(int(seg_idx)) 373 | tmp_result = {'clss':[], 'idx_in_sent':[], 'bbox_for_all_frames':[]} 374 | 375 | for j in range(seq.size(1)): 376 | if seq[i,j].item() != 0: 377 | lemma = opt.wtol[opt.itow[str(seq[i,j].item())]] 378 | if lemma in lemma_det_dict: 379 | tmp_result['bbox_for_all_frames'].append(obj_bbox_att2[i, j, :, :4].tolist()) 380 | tmp_result['clss'].append(opt.itod[lemma_det_dict[lemma]]) 381 | tmp_result['idx_in_sent'].append(j) # redundant, for the sake of output format 382 | else: 383 | break 384 | grd_output[vid_id][seg_idx] = tmp_result 385 | 386 | sents = utils.decode_sequence(dataset.itow, dataset.itod, dataset.ltow, dataset.itoc, \ 387 | dataset.wtod, seq.data, opt.vocab_size, opt) 388 | 389 | for k, sent in enumerate(sents): 390 | vid_idx, seg_idx = seg_id[k].split('_segment_') 391 | seg_idx = str(int(seg_idx)) 392 | 393 | predictions[vid_idx].append( 394 | {'sentence':sent, 395 | 'timestamp':[round(timestamp, 2) for timestamp in timestamp_file[ \ 396 | 'annotations'][vid_idx]['segments'][seg_idx]['timestamps']]}) 397 | 398 | if num_show < 20: 399 | print('segment %s: %s' %(seg_id[k], sent)) 400 | num_show += 1 401 | 402 | # visualization 403 | if opt.vis_attn: 404 | assert(opt.beam_size == 1) # only support beam_size=1 405 | att2_weights = F.softmax(att2_weights, dim=2) 406 | # visualize some selected examples 407 | if torch.sum(proposals[k]) != 0: 408 | vis_infer(seg_show[k], seg_id[k], sent, \ 409 | 
att2_weights[k].cpu().data, proposals[k], num[k].long(), \ 410 | bboxs[k], sim_mat[k].cpu().data, seg_dim_info[k]) 411 | 412 | if count % 2 == 0: 413 | print(count) 414 | count += 1 415 | 416 | lang_stats = defaultdict(float) 417 | if opt.language_eval: 418 | print('Total videos to be evaluated %d' %(len(predictions))) 419 | 420 | submission = 'densecap_results/'+'densecap-'+opt.val_split+'-'+opt.id+'.json' 421 | dense_cap_all = {'version':'VERSION 1.0', 'results':predictions, 422 | 'external_data':{'used':'true', 423 | 'details':'Visual Genome for Faster R-CNN pre-training'}} 424 | with open(submission, 'w') as f: 425 | json.dump(dense_cap_all, f) 426 | 427 | references = opt.densecap_references 428 | verbose = opt.densecap_verbose 429 | tious_lst = [0.3, 0.5, 0.7, 0.9] 430 | evaluator = ANETcaptions(ground_truth_filenames=references, 431 | prediction_filename=submission, 432 | tious=tious_lst, 433 | max_proposals=1000, 434 | verbose=verbose) 435 | evaluator.evaluate() 436 | 437 | for m,v in evaluator.scores.items(): 438 | lang_stats[m] = np.mean(v) 439 | 440 | print('\nResults Summary (lang eval):') 441 | print('Printing language evaluation metrics...') 442 | for m, s in lang_stats.items(): 443 | print('{}: {:.3f}'.format(m, s*100)) 444 | print('\n') 445 | 446 | if opt.eval_obj_grounding: 447 | # write attention results to file 448 | attn_file = 'results/attn-gen-sent-results-'+opt.val_split+'-'+opt.id+'.json' 449 | with open(attn_file, 'w') as f: 450 | json.dump({'results':grd_output, 'eval_mode':'gen', 'external_data':{'used':True, 'details':'Object detector pre-trained on Visual Genome on object detection task.'}}, f) 451 | 452 | if not opt.test_mode: 453 | # offline eval 454 | evaluator = ANetGrdEval(reference_file=opt.grd_reference, submission_file=attn_file, 455 | split_file=opt.split_file, val_split=[opt.val_split], 456 | iou_thresh=0.5) 457 | 458 | print('\nResults Summary (generated sent):') 459 | print('Printing attention accuracy on generated sentences, per class and per sentence, respectively...') 460 | prec_all, recall_all, f1_all, prec_all_per_sent, rec_all_per_sent, f1_all_per_sent = evaluator.grd_eval(mode='all') 461 | prec_loc, recall_loc, f1_loc, prec_loc_per_sent, rec_loc_per_sent, f1_loc_per_sent = evaluator.grd_eval(mode='loc') 462 | else: 463 | print('*'*62) 464 | print('* [WARNING] Grounding eval unavailable for the test set!\ 465 | *\n* Please submit your result file named *\ 466 | \n* results/attn-gen-sent-*.json to the eval server!*') 467 | print('*'*62) 468 | 469 | if opt.att_model == 'topdown' and opt.eval_obj_grounding_gt: 470 | with torch.no_grad(): 471 | box_accu_att, box_accu_grd, cls_accu = eval_grounding(opt) # eval grounding 472 | else: 473 | box_accu_att, box_accu_grd, cls_accu = 0, 0, 0 474 | 475 | if opt.enable_visdom: 476 | assert(opt.language_eval) 477 | if vis_window['score'] is None: 478 | vis_window['score'] = vis.line( 479 | X=np.tile(np.arange(epoch, epoch+1), 480 | (7,1)).T, 481 | Y=np.column_stack((np.asarray(box_accu_att), 482 | np.asarray(box_accu_grd), 483 | np.asarray(cls_accu), 484 | np.asarray(lang_stats['Bleu_4']), 485 | np.asarray(lang_stats['METEOR']), 486 | np.asarray(lang_stats['CIDEr']), 487 | np.asarray(lang_stats['SPICE']))), 488 | opts=dict(title='Validation Score', 489 | xlabel='Validation Epoch', 490 | ylabel='Score', 491 | legend=['BA (alpha)', 'BA (beta)', 'CLS Accu', 'Bleu_4', 'METEOR', 'CIDEr', 'SPICE']) 492 | ) 493 | else: 494 | vis.line( 495 | X=np.tile(np.arange(epoch, epoch+1), 496 | (7,1)).T, 497 | 
Y=np.column_stack((np.asarray(box_accu_att), 498 | np.asarray(box_accu_grd), 499 | np.asarray(cls_accu), 500 | np.asarray(lang_stats['Bleu_4']), 501 | np.asarray(lang_stats['METEOR']), 502 | np.asarray(lang_stats['CIDEr']), 503 | np.asarray(lang_stats['SPICE']))), 504 | opts=dict(title='Validation Score', 505 | xlabel='Validation Epoch', 506 | ylabel='Score', 507 | legend=['BA (alpha)', 'BA (beta)', 'CLS Accu', 'Bleu_4', 'METEOR', 'CIDEr', 'SPICE']), 508 | win=vis_window['score'], 509 | update='append' 510 | ) 511 | 512 | print('Saving the predictions') 513 | 514 | # Write validation result into summary 515 | val_result_history[iteration] = {'lang_stats': lang_stats, 'predictions': predictions} 516 | 517 | return lang_stats 518 | 519 | 520 | if __name__ == '__main__': 521 | 522 | opt = opts.parse_opt() 523 | if opt.path_opt is not None: 524 | with open(opt.path_opt, 'r') as handle: 525 | options_yaml = yaml.load(handle) 526 | utils.update_values(options_yaml, vars(opt)) 527 | opt.test_mode = (opt.val_split in ['testing', 'hidden_test']) 528 | if opt.enable_BUTD: 529 | assert opt.att_input_mode == 'region', 'region attention only under the BUTD mode' 530 | 531 | # print(opt) 532 | cudnn.benchmark = True 533 | 534 | if opt.enable_visdom: 535 | import visdom 536 | vis = visdom.Visdom(server=opt.visdom_server, env=opt.id) 537 | vis_window={'iter': None, 'score':None} 538 | 539 | torch.manual_seed(opt.seed) 540 | np.random.seed(opt.seed) 541 | random.seed(opt.seed) 542 | if opt.cuda: 543 | torch.cuda.manual_seed_all(opt.seed) 544 | if opt.vis_attn: 545 | import cv2 546 | 547 | if opt.dataset == 'anet': 548 | from misc.dataloader_anet import DataLoader 549 | else: 550 | raise Exception('only support anet!') 551 | 552 | if not os.path.exists(opt.checkpoint_path): 553 | os.makedirs(opt.checkpoint_path) 554 | 555 | # Data Loader 556 | dataset = DataLoader(opt, split=opt.train_split, seq_per_img=opt.seq_per_img) 557 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, 558 | shuffle=True, num_workers=opt.num_workers) 559 | 560 | dataset_val = DataLoader(opt, split=opt.val_split, seq_per_img=opt.seq_per_img) 561 | dataloader_val = torch.utils.data.DataLoader(dataset_val, batch_size=opt.batch_size, 562 | shuffle=False, num_workers=opt.num_workers) 563 | 564 | segs_feat = torch.FloatTensor(1) 565 | input_seqs = torch.LongTensor(1) 566 | input_ppls = torch.FloatTensor(1) 567 | mask_ppls = torch.ByteTensor(1) 568 | gt_bboxs = torch.FloatTensor(1) 569 | mask_bboxs = torch.ByteTensor(1) 570 | mask_frms = torch.ByteTensor(1) 571 | gt_seqs = torch.LongTensor(1) 572 | input_num = torch.LongTensor(1) 573 | ppls_feat = torch.FloatTensor(1) 574 | 575 | if opt.cuda: 576 | segs_feat = segs_feat.cuda() 577 | input_seqs = input_seqs.cuda() 578 | gt_seqs = gt_seqs.cuda() 579 | input_num = input_num.cuda() 580 | input_ppls = input_ppls.cuda() 581 | mask_ppls = mask_ppls.cuda() 582 | gt_bboxs = gt_bboxs.cuda() 583 | mask_bboxs = mask_bboxs.cuda() 584 | mask_frms = mask_frms.cuda() 585 | ppls_feat = ppls_feat.cuda() 586 | 587 | segs_feat = Variable(segs_feat) 588 | input_seqs = Variable(input_seqs) 589 | gt_seqs = Variable(gt_seqs) 590 | input_num = Variable(input_num) 591 | input_ppls = Variable(input_ppls) 592 | mask_ppls = Variable(mask_ppls) 593 | gt_bboxs = Variable(gt_bboxs) 594 | mask_bboxs = Variable(mask_bboxs) 595 | mask_frms = Variable(mask_frms) 596 | ppls_feat = Variable(ppls_feat) 597 | 598 | # Build the Model 599 | opt.vocab_size = dataset.vocab_size 600 | opt.detect_size = 
dataset.detect_size 601 | opt.seq_length = opt.seq_length 602 | opt.glove_w = torch.from_numpy(dataset.glove_w).float() 603 | opt.glove_vg_cls = torch.from_numpy(dataset.glove_vg_cls).float() 604 | opt.glove_clss = torch.from_numpy(dataset.glove_clss).float() 605 | 606 | opt.wtoi = dataset.wtoi 607 | opt.itow = dataset.itow 608 | opt.itod = dataset.itod 609 | opt.ltow = dataset.ltow 610 | opt.itoc = dataset.itoc 611 | opt.wtol = dataset.wtol 612 | opt.wtod = dataset.wtod 613 | opt.vg_cls = dataset.vg_cls 614 | 615 | if opt.att_model == 'topdown': 616 | model = AttModel.TopDownModel(opt) 617 | elif opt.att_model == 'transformer': 618 | model = AttModel.TransformerModel(opt) 619 | 620 | infos = {} 621 | histories = {} 622 | if opt.start_from is not None: 623 | if opt.load_best_score == 1: 624 | model_path = os.path.join(opt.start_from, 'model-best.pth') 625 | info_path = os.path.join(opt.start_from, 'infos_'+opt.id+'-best.pkl') 626 | else: 627 | model_path = os.path.join(opt.start_from, 'model.pth') 628 | info_path = os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl') 629 | 630 | # open old infos and check if models are compatible 631 | with open(info_path, 'rb') as f: 632 | infos = pickle.load(f, encoding='latin1') # py2 pickle -> py3 633 | # infos = pickle.load(f) 634 | saved_model_opt = infos['opt'] 635 | 636 | # opt.learning_rate = saved_model_opt.learning_rate 637 | print('Loading the model %s...' %(model_path)) 638 | model.load_state_dict(torch.load(model_path)) 639 | 640 | if os.path.isfile(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')): 641 | with open(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl'), 'rb') as f: 642 | histories = pickle.load(f, encoding='latin1') # py2 pickle -> py3 643 | # histories = pickle.load(f) 644 | 645 | best_val_score = infos.get('best_val_score', None) 646 | iteration = infos.get('iter', 0) 647 | start_epoch = infos.get('epoch', 0) 648 | 649 | val_result_history = histories.get('val_result_history', {}) 650 | loss_history = histories.get('loss_history', {}) 651 | lr_history = histories.get('lr_history', {}) 652 | ss_prob_history = histories.get('ss_prob_history', {}) 653 | 654 | if opt.mGPUs: 655 | model = nn.DataParallel(model) 656 | 657 | if opt.cuda: 658 | model.cuda() 659 | 660 | params = [] 661 | for key, value in dict(model.named_parameters()).items(): 662 | if value.requires_grad: 663 | if ('ctx2pool_grd' in key) or ('vis_embed' in key): 664 | print('Finetune param: {}'.format(key)) 665 | params += [{'params':[value], 'lr':opt.learning_rate*0.1, # finetune the fc7 layer 666 | 'weight_decay':opt.weight_decay, 'betas':(opt.optim_alpha, opt.optim_beta)}] 667 | else: 668 | params += [{'params':[value], 'lr':opt.learning_rate, 669 | 'weight_decay':opt.weight_decay, 'betas':(opt.optim_alpha, opt.optim_beta)}] 670 | 671 | print("Use %s as optmization method" %(opt.optim)) 672 | if opt.optim == 'sgd': 673 | optimizer = optim.SGD(params, momentum=0.9) 674 | elif opt.optim == 'adam': 675 | optimizer = optim.Adam(params) 676 | elif opt.optim == 'adamax': 677 | optimizer = optim.Adamax(params) 678 | 679 | for epoch in range(start_epoch, opt.max_epochs): 680 | if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: 681 | if (epoch - opt.learning_rate_decay_start) % opt.learning_rate_decay_every == 0: 682 | # decay the learning rate. 
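# A sketch of the resulting schedule, assuming utils.set_lr simply multiplies each
# param group's lr by the given factor (consistent with how opt.learning_rate is
# updated right below): once epoch > learning_rate_decay_start, the lr is scaled by
# learning_rate_decay_rate every learning_rate_decay_every epochs, i.e. roughly
#   lr(epoch) = base_lr * decay_rate ** ((epoch - decay_start) // decay_every)
# Under that assumption the finetuned groups (ctx2pool_grd / vis_embed, set up with
# 0.1x lr above) keep their 0.1x ratio, since every group is scaled by the same factor.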
683 | utils.set_lr(optimizer, opt.learning_rate_decay_rate) 684 | opt.learning_rate = opt.learning_rate * opt.learning_rate_decay_rate 685 | 686 | if not opt.inference_only: 687 | if opt.enable_visdom: 688 | train(epoch, opt, vis, vis_window) 689 | else: 690 | train(epoch, opt) 691 | 692 | if epoch % opt.val_every_epoch == 0: 693 | with torch.no_grad(): 694 | if opt.enable_visdom: 695 | lang_stats = eval(epoch, opt, vis, vis_window) 696 | else: 697 | lang_stats = eval(epoch, opt) 698 | 699 | if opt.inference_only: 700 | break 701 | 702 | # Save model if is improving on validation result 703 | current_score = lang_stats['CIDEr'] 704 | 705 | best_flag = False 706 | if best_val_score is None or current_score > best_val_score: 707 | best_val_score = current_score 708 | best_flag = True 709 | checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth') 710 | if opt.mGPUs: 711 | torch.save(model.module.state_dict(), checkpoint_path) 712 | else: 713 | torch.save(model.state_dict(), checkpoint_path) 714 | print("model saved to {}".format(checkpoint_path)) 715 | # optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth') 716 | # torch.save(optimizer.state_dict(), optimizer_path) 717 | 718 | # Dump miscalleous informations 719 | infos['iter'] = iteration 720 | infos['epoch'] = epoch 721 | infos['best_val_score'] = best_val_score 722 | infos['opt'] = opt 723 | infos['vocab'] = dataset.itow 724 | 725 | histories['val_result_history'] = val_result_history 726 | histories['loss_history'] = loss_history 727 | histories['lr_history'] = lr_history 728 | histories['ss_prob_history'] = ss_prob_history 729 | with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'.pkl'), 'wb') as f: 730 | pickle.dump(infos, f) 731 | with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'.pkl'), 'wb') as f: 732 | pickle.dump(histories, f) 733 | 734 | if best_flag: 735 | checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth') 736 | if opt.mGPUs: 737 | torch.save(model.module.state_dict(), checkpoint_path) 738 | else: 739 | torch.save(model.state_dict(), checkpoint_path) 740 | 741 | print("model saved to {} with best cider score {:.3f}".format(checkpoint_path, best_val_score)) 742 | with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'-best.pkl'), 'wb') as f: 743 | pickle.dump(infos, f) 744 | -------------------------------------------------------------------------------- /misc/AttModel.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | from torch.autograd import * 16 | from misc.model import AttModel 17 | from torch.nn.parameter import Parameter 18 | import pdb 19 | import random 20 | 21 | 22 | class Attention(nn.Module): 23 | def __init__(self, opt): 24 | super(Attention, self).__init__() 25 | self.rnn_size = opt.rnn_size 26 | self.att_hid_size = opt.att_hid_size 27 | 28 | self.h2att = nn.Linear(self.rnn_size, self.att_hid_size) 29 | self.alpha_net = nn.Linear(self.att_hid_size, 1) 30 | self.min_value = -1e8 31 | # self.batch_norm = nn.BatchNorm1d(self.rnn_size) 32 | 33 | def forward(self, h, att_feats, p_att_feats): 34 | # The p_att_feats here is already projected 35 | batch_size = h.size(0) 36 | att_size = att_feats.numel() // batch_size // self.rnn_size 37 | att = p_att_feats.view(-1, att_size, self.att_hid_size) 38 | 39 | att_h = self.h2att(h) # batch * att_hid_size 40 | att_h = att_h.unsqueeze(1) # batch * att_size * att_hid_size 41 | dot = att + att_h # batch * att_size * att_hid_size 42 | dot = F.tanh(dot) # batch * att_size * att_hid_size 43 | dot = dot.view(-1, self.att_hid_size) # (batch * att_size) * att_hid_size 44 | # dot = F.dropout(dot, 0.3, training=self.training) 45 | dot = self.alpha_net(dot) # (batch * att_size) * 1 46 | dot = dot.view(-1, att_size) # batch * att_size 47 | 48 | weight = F.softmax(dot, dim=1) # batch * att_size 49 | att_feats_ = att_feats.view(-1, att_size, self.rnn_size) # batch * att_size * att_feat_size 50 | att_res = torch.bmm(weight.unsqueeze(1), att_feats_).squeeze(1) # batch * att_feat_size 51 | # att_res = self.batch_norm(att_res) 52 | 53 | return att_res 54 | 55 | 56 | class Attention2(nn.Module): 57 | def __init__(self, opt): 58 | super(Attention2, self).__init__() 59 | self.rnn_size = opt.rnn_size 60 | self.att_hid_size = opt.att_hid_size 61 | 62 | self.h2att = nn.Linear(self.rnn_size, self.att_hid_size) 63 | if opt.region_attn_mode in ('add', 'mix', 'mix_mul'): 64 | self.alpha_net = nn.Linear(self.att_hid_size, 1) 65 | elif opt.region_attn_mode == 'cat': 66 | self.alpha_net = nn.Linear(self.att_hid_size*2, 1) 67 | self.min_value = -1e8 68 | self.region_attn_mode = opt.region_attn_mode 69 | # self.batch_norm = nn.BatchNorm1d(self.rnn_size) 70 | 71 | def forward(self, h, att_feats, p_att_feats, att_mask, pnt_mask): 72 | # The p_att_feats here is already projected 73 | batch_size = h.size(0) 74 | att_size = att_feats.numel() // batch_size // self.rnn_size 75 | att = p_att_feats.view(-1, att_size, self.att_hid_size) 76 | 77 | att_h = self.h2att(h) # batch * att_hid_size 78 | 79 | if hasattr(self, 'alpha_net'): 80 | # print('Additive region attention!') 81 | if self.alpha_net.weight.size(1) == self.att_hid_size: 82 | if self.region_attn_mode == 'mix_mul': 83 | dot = att * att_h.unsqueeze(1) # element-wise multiplication attn. 
84 | else: 85 | dot = att + att_h.unsqueeze(1) # batch * att_size * att_hid_size 86 | else: 87 | dot = torch.cat((xt.unsqueeze(1), att_feats), 2) 88 | 89 | dot = F.tanh(dot) # batch * att_size * att_hid_size 90 | dot = dot.view(-1, self.att_hid_size) # (batch * att_size) * att_hid_size 91 | # dot = F.dropout(dot, 0.3, training=self.training) 92 | hAflat = self.alpha_net(dot) # (batch * att_size) * 1 93 | else: 94 | # print('Dot-product region attention!') 95 | assert(att.size(2) == att_h.size(1)) 96 | hAflat = torch.matmul(att, att_h.view(batch_size, self.att_hid_size, 1)) 97 | 98 | hAflat = hAflat.view(-1, att_size) # batch * att_size 99 | hAflat.masked_fill_(att_mask, self.min_value) 100 | frm_masked_hAflat = hAflat.clone() 101 | 102 | weight = F.softmax(hAflat, dim=1) # batch * att_size 103 | frm_masked_hAflat.masked_fill_(pnt_mask, self.min_value) 104 | 105 | att_feats_ = att_feats.view(-1, att_size, self.rnn_size) # batch * att_size * att_feat_size 106 | att_res = torch.bmm(weight.unsqueeze(1), att_feats_).squeeze(1) # batch * att_feat_size 107 | 108 | return att_res, frm_masked_hAflat, att_h 109 | 110 | 111 | class TopDownCore(nn.Module): 112 | def __init__(self, opt, use_maxout=False): 113 | super(TopDownCore, self).__init__() 114 | self.drop_prob_lm = opt.drop_prob_lm 115 | self.min_value = -1e8 116 | self.att_input_mode=opt.att_input_mode 117 | self.rnn_size = opt.rnn_size 118 | self.att_hid_size = opt.att_hid_size 119 | self.detect_size = opt.detect_size 120 | 121 | self.att_lstm = nn.LSTMCell(opt.input_encoding_size + opt.rnn_size, opt.rnn_size) # we, fc, h^2_t-1 122 | 123 | self.lang_lstm = nn.LSTMCell(opt.rnn_size*2, opt.rnn_size) # h^1_t, \hat v 124 | self.attention = Attention(opt) 125 | self.attention2 = Attention2(opt) 126 | if self.att_input_mode == 'dual_region': 127 | self.attention2_dual = Attention2(opt) 128 | self.dual_pointer = nn.Sequential(nn.Linear(opt.rnn_size, 1), nn.Sigmoid()) 129 | 130 | self.i2h_2 = nn.Linear(opt.rnn_size*2, opt.rnn_size) 131 | self.h2h_2 = nn.Linear(opt.rnn_size, opt.rnn_size) 132 | 133 | 134 | def forward(self, xt, fc_feats, conv_feats, p_conv_feats, pool_feats, p_pool_feats, att_mask, pnt_mask, state, sim_mat_static_update): 135 | # att_mask is for attention , pnt_mask cound be for either attention or grounding 136 | # pnt_mask is frm_mask during training and is att_mask during inference 137 | 138 | att_lstm_input = torch.cat([fc_feats, xt], 1) 139 | h_att, c_att = self.att_lstm(att_lstm_input, (state[0][0], state[1][0])) 140 | if self.att_input_mode != 'region': 141 | att = self.attention(h_att, conv_feats, p_conv_feats) 142 | att2, att2_weight, att_h = self.attention2(h_att, pool_feats, p_pool_feats, att_mask[:,1:], pnt_mask[:,1:]) 143 | 144 | max_grd_val = att2.new(pool_feats.size(0), 1).fill_(0) # dummy 145 | grd_val = att2.new(pool_feats.size(0), 1).fill_(0) 146 | 147 | if self.att_input_mode == 'both': 148 | lang_lstm_input = torch.cat([att+att2, h_att], 1) 149 | elif self.att_input_mode == 'featmap': 150 | lang_lstm_input = torch.cat([att, h_att], 1) 151 | elif self.att_input_mode == 'region': 152 | lang_lstm_input = torch.cat([att2, h_att], 1) 153 | elif self.att_input_mode == 'dual_region': 154 | att2_dual, _, _ = self.attention2_dual(h_att, pool_feats, p_pool_feats, att_mask[:,1:], pnt_mask[:,1:]) 155 | dual_p = self.dual_pointer(h_att) 156 | lang_lstm_input = torch.cat([dual_p*att2+(1-dual_p)*att2_dual, h_att], 1) 157 | else: 158 | raise "Unknown attention input mode!" 
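# What follows is the second half of the two-layer top-down decoder: the language
# LSTM consumes lang_lstm_input (the attended feature(s) selected above, concatenated
# with the attention LSTM's hidden state h_att) plus its previous state, and its
# dropout-regularized hidden state is the output the caller later maps to a next-word
# distribution (e.g. via self.logit + log-softmax in the beam search). att2_weight
# carries the region-attention logits with the attention and pointer masks applied,
# so region attention/grounding can be read off or supervised outside this module.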
159 | 160 | h_lang, c_lang = self.lang_lstm(lang_lstm_input, (state[0][1], state[1][1])) 161 | output = F.dropout(h_lang, self.drop_prob_lm, self.training) # later encoded to P_{txt}^t 162 | state = (torch.stack([h_att, h_lang]), torch.stack([c_att, c_lang])) 163 | 164 | return output, state, att2_weight, att_h, max_grd_val, grd_val 165 | 166 | 167 | class TopDownModel(AttModel): 168 | def __init__(self, opt): 169 | super(TopDownModel, self).__init__(opt) 170 | self.num_layers = 2 171 | self.core = TopDownCore(opt) 172 | 173 | 174 | class TransformerModel(AttModel): 175 | def __init__(self, opt): 176 | super(TransformerModel, self).__init__(opt) 177 | -------------------------------------------------------------------------------- /misc/CaptionModelBU.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | from torch.autograd import * 16 | import math 17 | import pdb 18 | import random 19 | 20 | class CaptionModel(nn.Module): 21 | def __init__(self): 22 | super(CaptionModel, self).__init__() 23 | 24 | def beam_search(self, state, rnn_output, beam_fc_feats, beam_conv_feats, beam_p_conv_feats, \ 25 | beam_pool_feats, beam_p_pool_feats, beam_sim_mat_static, beam_ppls, beam_pnt_mask, vis_offset, roi_offset, opt): 26 | # args are the miscelleous inputs to the core in addition to embedded word and state 27 | # kwargs only accept opt 28 | 29 | # def beam_step(logprobsf, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, beam_bn_seq, \ 30 | # beam_bn_seq_logprobs, beam_fg_seq, beam_fg_seq_logprobs, rnn_output, beam_pnt_mask, state): 31 | def beam_step(logprobsf, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, beam_att2_ind, \ 32 | rnn_output, beam_pnt_mask, state, att2_ind): 33 | #INPUTS: 34 | #logprobsf: probabilities augmented after diversity 35 | #beam_size: obvious 36 | #t : time instant 37 | #beam_seq : tensor contanining the beams 38 | #beam_seq_logprobs: tensor contanining the beam logprobs 39 | #beam_logprobs_sum: tensor contanining joint logprobs 40 | #OUPUTS: 41 | #beam_seq : tensor containing the word indices of the decoded captions 42 | #beam_seq_logprobs : log-probability of each decision made, same size as beam_seq 43 | #beam_logprobs_sum : joint log-probability of each beam 44 | 45 | ys,ix = torch.sort(logprobsf,1,True) 46 | candidates = [] 47 | cols = min(beam_size, ys.size(1)) 48 | rows = beam_size 49 | if t == 0: 50 | rows = 1 51 | for c in range(cols): # for each column (word, essentially) 52 | for q in range(rows): # for each beam expansion 53 | #compute logprob of expanding beam q with word in (sorted) position c 54 | local_logprob = ys[q,c] 55 | 56 | candidate_logprob = beam_logprobs_sum[q] + local_logprob 57 | candidates.append({'c':ix[q,c], 'q':q, 'p':candidate_logprob, 'r':local_logprob, \ 58 | 'w':att2_ind[q] }) 59 | 60 | candidates = sorted(candidates, key=lambda x: -x['p']) 61 | 62 | new_state = [_.clone() for _ in state] 63 | new_rnn_output = rnn_output.clone() 64 | 65 | #beam_seq_prev, beam_seq_logprobs_prev 66 | if t >= 1: 67 | #we''ll need these as reference when we fork beams 
around 68 | beam_seq_prev = beam_seq[:t].clone() 69 | beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone() 70 | beam_att2_ind_prev = beam_att2_ind[:t].clone() 71 | 72 | beam_pnt_mask_prev = beam_pnt_mask.clone() 73 | beam_pnt_mask = beam_pnt_mask.clone() 74 | 75 | for vix in range(beam_size): 76 | v = candidates[vix] 77 | #fork beam index q into index vix 78 | if t >= 1: 79 | beam_seq[:t, vix] = beam_seq_prev[:, v['q']] 80 | beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:, v['q']] 81 | beam_att2_ind[:t, vix] = beam_att2_ind_prev[:, v['q']] 82 | beam_pnt_mask[:, vix] = beam_pnt_mask_prev[:, v['q']] 83 | 84 | #rearrange recurrent states 85 | for state_ix in range(len(new_state)): 86 | # copy over state in previous beam q to new beam at vix 87 | new_state[state_ix][:, vix] = state[state_ix][:, v['q']] # dimension one is time step 88 | 89 | new_rnn_output[vix] = rnn_output[v['q']] # dimension one is time step 90 | 91 | #append new end terminal at the end of this beam 92 | beam_seq[t, vix] = v['c'] # c'th word is the continuation 93 | beam_seq_logprobs[t, vix] = v['r'] # the raw logprob here 94 | if t >= 1: 95 | beam_att2_ind[t, vix] = v['w'] 96 | beam_logprobs_sum[vix] = v['p'] # the new (sum) logprob along this beam 97 | 98 | state = new_state 99 | rnn_output = new_rnn_output 100 | 101 | return beam_seq, beam_seq_logprobs, beam_logprobs_sum, beam_att2_ind, \ 102 | rnn_output, state, beam_pnt_mask.t(), candidates 103 | 104 | # start beam search 105 | # opt = kwargs['opt'] 106 | beam_size = opt.get('beam_size', 5) 107 | beam_att_mask = beam_pnt_mask.clone() 108 | rois_num = beam_ppls.size(1) 109 | 110 | beam_seq = torch.LongTensor(self.seq_length, beam_size).zero_() 111 | beam_seq_logprobs = torch.FloatTensor(self.seq_length, beam_size).zero_() 112 | beam_att2_ind = torch.LongTensor(self.seq_length, beam_size).fill_(-1) 113 | att2_ind = torch.LongTensor(beam_size).fill_(-1) 114 | 115 | beam_logprobs_sum = torch.zeros(beam_size) # running sum of logprobs for each beam 116 | done_beams = [] 117 | beam_pnt_mask_list = [] 118 | beam_pnt_mask_list.append(beam_pnt_mask) 119 | 120 | for t in range(self.seq_length): 121 | """pem a beam merge. that is, 122 | for every previous beam we now many new possibilities to branch out 123 | we need to resort our beams to maintain the loop invariant of keeping 124 | the top beam_size most likely sequences.""" 125 | decoded = F.log_softmax(self.logit(rnn_output), dim=1) 126 | 127 | logprobs = decoded 128 | 129 | logprobsf = logprobs.data.cpu() # lets go to CPU for more efficiency in indexing operations 130 | # suppress UNK tokens in the decoding 131 | # logprobsf[:,logprobsf.size(1)-1] = logprobsf[:, logprobsf.size(1)-1] - 1000 132 | 133 | beam_seq, beam_seq_logprobs, \ 134 | beam_logprobs_sum, beam_att2_ind, \ 135 | rnn_output, state, beam_pnt_mask_new, \ 136 | candidates_divm = beam_step(logprobsf, 137 | beam_size, 138 | t, 139 | beam_seq, 140 | beam_seq_logprobs, 141 | beam_logprobs_sum, 142 | beam_att2_ind, 143 | rnn_output, 144 | beam_pnt_mask_list[-1].t(), 145 | state, att2_ind) 146 | 147 | # encode as vectors 148 | it = beam_seq[t].cuda() 149 | assert(torch.sum(it>=self.vocab_size) == 0) 150 | 151 | roi_idx = it.clone() - self.vocab_size - 1 # starting from 0 152 | roi_mask = roi_idx < 0 153 | 154 | for vix in range(beam_size): 155 | # if time's up... 
or if end token is reached then copy beams 156 | if beam_seq[t, vix] == 0 or t == self.seq_length - 1: 157 | final_beam = { 158 | 'seq': beam_seq[:, vix].clone(), 159 | 'logps': beam_seq_logprobs[:, vix].clone(), 160 | 'att2' : beam_att2_ind[:, vix], 161 | 'p': beam_logprobs_sum[vix], 162 | } 163 | 164 | done_beams.append(final_beam) 165 | # don't continue beams from finished sequences 166 | beam_logprobs_sum[vix] = -1000 167 | 168 | # updating the mask, and make sure that same object won't happen in the caption 169 | pnt_idx_offset = roi_idx + roi_offset + 1 170 | pnt_idx_offset[roi_mask] = 0 171 | beam_pnt_mask = beam_pnt_mask_new.data.clone() 172 | 173 | beam_pnt_mask.view(-1)[pnt_idx_offset] = 1 174 | beam_pnt_mask.view(-1)[0] = 0 175 | beam_pnt_mask_list.append(Variable(beam_pnt_mask)) 176 | 177 | xt = self.embed(Variable(it)) 178 | 179 | rnn_output, state, att2_weight, att_h, _, _ = self.core(xt, beam_fc_feats, beam_conv_feats, 180 | beam_p_conv_feats, beam_pool_feats, beam_p_pool_feats, beam_att_mask, beam_pnt_mask_list[-1], \ 181 | state, Variable(beam_pool_feats.data.new(beam_size, rois_num).fill_(0)), beam_sim_mat_static, self) 182 | _, att2_ind = torch.max(att2_weight, 1) 183 | 184 | done_beams = sorted(done_beams, key=lambda x: -x['p'])[:beam_size] 185 | return done_beams 186 | -------------------------------------------------------------------------------- /misc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/grounded-video-description/2667e216223a63cdc91b616a34f4c046e507a1dc/misc/__init__.py -------------------------------------------------------------------------------- /misc/bbox_transform.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | # -------------------------------------------------------- 9 | # Fast R-CNN 10 | # Copyright (c) 2015 Microsoft 11 | # Licensed under The MIT License [see LICENSE for details] 12 | # Written by Ross Girshick 13 | # -------------------------------------------------------- 14 | # -------------------------------------------------------- 15 | # Reorganized and modified by Jianwei Yang and Jiasen Lu 16 | # -------------------------------------------------------- 17 | 18 | import torch 19 | import numpy as np 20 | import pdb 21 | 22 | def bbox_transform(ex_rois, gt_rois): 23 | ex_widths = ex_rois[:, 2] - ex_rois[:, 0] + 1.0 24 | ex_heights = ex_rois[:, 3] - ex_rois[:, 1] + 1.0 25 | ex_ctr_x = ex_rois[:, 0] + 0.5 * ex_widths 26 | ex_ctr_y = ex_rois[:, 1] + 0.5 * ex_heights 27 | 28 | gt_widths = gt_rois[:, 2] - gt_rois[:, 0] + 1.0 29 | gt_heights = gt_rois[:, 3] - gt_rois[:, 1] + 1.0 30 | gt_ctr_x = gt_rois[:, 0] + 0.5 * gt_widths 31 | gt_ctr_y = gt_rois[:, 1] + 0.5 * gt_heights 32 | 33 | targets_dx = (gt_ctr_x - ex_ctr_x) / ex_widths 34 | targets_dy = (gt_ctr_y - ex_ctr_y) / ex_heights 35 | targets_dw = torch.log(gt_widths / ex_widths) 36 | targets_dh = torch.log(gt_heights / ex_heights) 37 | 38 | targets = torch.stack( 39 | (targets_dx, targets_dy, targets_dw, targets_dh),1) 40 | 41 | return targets 42 | 43 | def bbox_transform_batch(ex_rois, gt_rois): 44 | 45 | if ex_rois.dim() == 2: 46 | ex_widths = ex_rois[:, 2] - ex_rois[:, 0] + 1.0 47 | ex_heights = ex_rois[:, 3] - ex_rois[:, 1] + 1.0 48 | ex_ctr_x = ex_rois[:, 0] + 0.5 * ex_widths 49 | ex_ctr_y = ex_rois[:, 1] + 0.5 * ex_heights 50 | 51 | gt_widths = gt_rois[:, :, 2] - gt_rois[:, :, 0] + 1.0 52 | gt_heights = gt_rois[:, :, 3] - gt_rois[:, :, 1] + 1.0 53 | gt_ctr_x = gt_rois[:, :, 0] + 0.5 * gt_widths 54 | gt_ctr_y = gt_rois[:, :, 1] + 0.5 * gt_heights 55 | 56 | targets_dx = (gt_ctr_x - ex_ctr_x.view(1,-1).expand_as(gt_ctr_x)) / ex_widths 57 | targets_dy = (gt_ctr_y - ex_ctr_y.view(1,-1).expand_as(gt_ctr_y)) / ex_heights 58 | targets_dw = torch.log(gt_widths / ex_widths.view(1,-1).expand_as(gt_widths)) 59 | targets_dh = torch.log(gt_heights / ex_heights.view(1,-1).expand_as(gt_heights)) 60 | 61 | elif ex_rois.dim() == 3: 62 | ex_widths = ex_rois[:, :, 2] - ex_rois[:, :, 0] + 1.0 63 | ex_heights = ex_rois[:,:, 3] - ex_rois[:,:, 1] + 1.0 64 | ex_ctr_x = ex_rois[:, :, 0] + 0.5 * ex_widths 65 | ex_ctr_y = ex_rois[:, :, 1] + 0.5 * ex_heights 66 | 67 | gt_widths = gt_rois[:, :, 2] - gt_rois[:, :, 0] + 1.0 68 | gt_heights = gt_rois[:, :, 3] - gt_rois[:, :, 1] + 1.0 69 | gt_ctr_x = gt_rois[:, :, 0] + 0.5 * gt_widths 70 | gt_ctr_y = gt_rois[:, :, 1] + 0.5 * gt_heights 71 | 72 | targets_dx = (gt_ctr_x - ex_ctr_x) / ex_widths 73 | targets_dy = (gt_ctr_y - ex_ctr_y) / ex_heights 74 | targets_dw = torch.log(gt_widths / ex_widths) 75 | targets_dh = torch.log(gt_heights / ex_heights) 76 | else: 77 | raise ValueError('ex_roi input dimension is not correct.') 78 | 79 | targets = torch.stack( 80 | (targets_dx, targets_dy, targets_dw, targets_dh),2) 81 | 82 | return targets 83 | 84 | def bbox_transform_inv(boxes, deltas, batch_size): 85 | widths = boxes[:, :, 2] - boxes[:, :, 0] + 1.0 86 | heights = boxes[:, :, 3] - boxes[:, :, 1] + 1.0 87 | ctr_x = boxes[:, :, 0] + 0.5 * widths 88 | ctr_y = boxes[:, :, 1] + 0.5 * heights 89 | 90 | dx = deltas[:, :, 0::4] 91 | dy = deltas[:, :, 1::4] 92 | dw = deltas[:, :, 2::4] 93 | dh = deltas[:, :, 3::4] 94 | 95 | pred_ctr_x = dx * widths.unsqueeze(2) + ctr_x.unsqueeze(2) 96 | pred_ctr_y = dy * 
heights.unsqueeze(2) + ctr_y.unsqueeze(2) 97 | pred_w = np.exp(dw) * widths.unsqueeze(2) 98 | pred_h = np.exp(dh) * heights.unsqueeze(2) 99 | 100 | pred_boxes = deltas.clone() 101 | # x1 102 | pred_boxes[:, :, 0::4] = pred_ctr_x - 0.5 * pred_w 103 | # y1 104 | pred_boxes[:, :, 1::4] = pred_ctr_y - 0.5 * pred_h 105 | # x2 106 | pred_boxes[:, :, 2::4] = pred_ctr_x + 0.5 * pred_w 107 | # y2 108 | pred_boxes[:, :, 3::4] = pred_ctr_y + 0.5 * pred_h 109 | 110 | return pred_boxes 111 | 112 | def clip_boxes_batch(boxes, im_shape, batch_size): 113 | """ 114 | Clip boxes to image boundaries. 115 | """ 116 | num_rois = boxes.size(1) 117 | 118 | boxes[boxes < 0] = 0 119 | # batch_x = (im_shape[:,0]-1).view(batch_size, 1).expand(batch_size, num_rois) 120 | # batch_y = (im_shape[:,1]-1).view(batch_size, 1).expand(batch_size, num_rois) 121 | 122 | batch_x = im_shape[:, 1] - 1 123 | batch_y = im_shape[:, 0] - 1 124 | 125 | boxes[:,:,0][boxes[:,:,0] > batch_x] = batch_x 126 | boxes[:,:,1][boxes[:,:,1] > batch_y] = batch_y 127 | boxes[:,:,2][boxes[:,:,2] > batch_x] = batch_x 128 | boxes[:,:,3][boxes[:,:,3] > batch_y] = batch_y 129 | 130 | return boxes 131 | 132 | def clip_boxes(boxes, im_shape, batch_size): 133 | 134 | for i in range(batch_size): 135 | boxes[i,:,0::4].clamp_(0, im_shape[i, 1]-1) 136 | boxes[i,:,1::4].clamp_(0, im_shape[i, 0]-1) 137 | boxes[i,:,2::4].clamp_(0, im_shape[i, 1]-1) 138 | boxes[i,:,3::4].clamp_(0, im_shape[i, 0]-1) 139 | 140 | return boxes 141 | 142 | 143 | def bbox_overlaps(anchors, gt_boxes): 144 | """ 145 | anchors: (N, 4) ndarray of float 146 | gt_boxes: (K, 4) ndarray of float 147 | 148 | overlaps: (N, K) ndarray of overlap between boxes and query_boxes 149 | """ 150 | N = anchors.size(0) 151 | K = gt_boxes.size(0) 152 | 153 | gt_boxes_area = ((gt_boxes[:,2] - gt_boxes[:,0] + 1) * 154 | (gt_boxes[:,3] - gt_boxes[:,1] + 1)).view(1, K) 155 | 156 | anchors_area = ((anchors[:,2] - anchors[:,0] + 1) * 157 | (anchors[:,3] - anchors[:,1] + 1)).view(N, 1) 158 | 159 | boxes = anchors.view(N, 1, 4).expand(N, K, 4) 160 | query_boxes = gt_boxes.view(1, K, 4).expand(N, K, 4) 161 | 162 | iw = (torch.min(boxes[:,:,2], query_boxes[:,:,2]) - 163 | torch.max(boxes[:,:,0], query_boxes[:,:,0]) + 1) 164 | iw[iw < 0] = 0 165 | 166 | ih = (torch.min(boxes[:,:,3], query_boxes[:,:,3]) - 167 | torch.max(boxes[:,:,1], query_boxes[:,:,1]) + 1) 168 | ih[ih < 0] = 0 169 | 170 | ua = anchors_area + gt_boxes_area - (iw * ih) 171 | overlaps = iw * ih / ua 172 | 173 | return overlaps 174 | 175 | # modifed on 09/26/2018 for anet 176 | def bbox_overlaps_batch(anchors, gt_boxes, frm_mask=None): 177 | """ 178 | anchors: (N, 4) ndarray of float 179 | gt_boxes: (b, K, 5) ndarray of float 180 | frm_mask: (b, N, K) ndarray of bool 181 | 182 | overlaps: (N, K) ndarray of overlap between boxes and query_boxes 183 | """ 184 | batch_size = gt_boxes.size(0) 185 | 186 | 187 | if anchors.dim() == 2: 188 | assert frm_mask == None, 'mask not implemented yet' # hasn't updated the mask yet 189 | N = anchors.size(0) 190 | K = gt_boxes.size(1) 191 | 192 | anchors = anchors.view(1, N, 4).expand(batch_size, N, 4).contiguous() 193 | gt_boxes = gt_boxes[:,:,:4].contiguous() 194 | 195 | 196 | gt_boxes_x = (gt_boxes[:,:,2] - gt_boxes[:,:,0] + 1) 197 | gt_boxes_y = (gt_boxes[:,:,3] - gt_boxes[:,:,1] + 1) 198 | gt_boxes_area = (gt_boxes_x * gt_boxes_y).view(batch_size, 1, K) 199 | 200 | anchors_boxes_x = (anchors[:,:,2] - anchors[:,:,0] + 1) 201 | anchors_boxes_y = (anchors[:,:,3] - anchors[:,:,1] + 1) 202 | anchors_area = 
(anchors_boxes_x * anchors_boxes_y).view(batch_size, N, 1) 203 | 204 | gt_area_zero = (gt_boxes_x == 1) & (gt_boxes_y == 1) 205 | anchors_area_zero = (anchors_boxes_x == 1) & (anchors_boxes_y == 1) 206 | 207 | boxes = anchors.view(batch_size, N, 1, 4).expand(batch_size, N, K, 4) 208 | query_boxes = gt_boxes.view(batch_size, 1, K, 4).expand(batch_size, N, K, 4) 209 | 210 | iw = (torch.min(boxes[:,:,:,2], query_boxes[:,:,:,2]) - 211 | torch.max(boxes[:,:,:,0], query_boxes[:,:,:,0]) + 1) 212 | iw[iw < 0] = 0 213 | 214 | ih = (torch.min(boxes[:,:,:,3], query_boxes[:,:,:,3]) - 215 | torch.max(boxes[:,:,:,1], query_boxes[:,:,:,1]) + 1) 216 | ih[ih < 0] = 0 217 | ua = anchors_area + gt_boxes_area - (iw * ih) 218 | overlaps = iw * ih / ua 219 | 220 | # mask the overlap here. 221 | overlaps.masked_fill_(gt_area_zero.view(batch_size, 1, K).expand(batch_size, N, K), 0) 222 | overlaps.masked_fill_(anchors_area_zero.view(batch_size, N, 1).expand(batch_size, N, K), -1) 223 | 224 | elif anchors.dim() == 3: 225 | N = anchors.size(1) 226 | K = gt_boxes.size(1) 227 | 228 | if anchors.size(2) == 5: 229 | anchors = anchors[:,:,:5].contiguous() 230 | else: 231 | anchors = anchors[:,:,1:6].contiguous() 232 | 233 | gt_boxes = gt_boxes[:,:,:5].contiguous() 234 | 235 | gt_boxes_x = (gt_boxes[:,:,2] - gt_boxes[:,:,0] + 1) 236 | gt_boxes_y = (gt_boxes[:,:,3] - gt_boxes[:,:,1] + 1) 237 | gt_boxes_area = (gt_boxes_x * gt_boxes_y).view(batch_size, 1, K) 238 | 239 | anchors_boxes_x = (anchors[:,:,2] - anchors[:,:,0] + 1) 240 | anchors_boxes_y = (anchors[:,:,3] - anchors[:,:,1] + 1) 241 | anchors_area = (anchors_boxes_x * anchors_boxes_y).view(batch_size, N, 1) 242 | 243 | gt_area_zero = (gt_boxes_x == 1) & (gt_boxes_y == 1) 244 | anchors_area_zero = (anchors_boxes_x == 1) & (anchors_boxes_y == 1) 245 | 246 | boxes = anchors.view(batch_size, N, 1, 5).expand(batch_size, N, K, 5) 247 | query_boxes = gt_boxes.view(batch_size, 1, K, 5).expand(batch_size, N, K, 5) 248 | 249 | iw = (torch.min(boxes[:,:,:,2], query_boxes[:,:,:,2]) - 250 | torch.max(boxes[:,:,:,0], query_boxes[:,:,:,0]) + 1) 251 | iw[iw < 0] = 0 252 | 253 | ih = (torch.min(boxes[:,:,:,3], query_boxes[:,:,:,3]) - 254 | torch.max(boxes[:,:,:,1], query_boxes[:,:,:,1]) + 1) 255 | ih[ih < 0] = 0 256 | ua = anchors_area + gt_boxes_area - (iw * ih) 257 | 258 | if frm_mask is not None: 259 | # proposal and gt should be on the same frame to overlap 260 | # frm_mask = ~frm_mask # bitwise not (~) does not work with uint8 in pytorch 1.3 261 | frm_mask = 1 - frm_mask 262 | # print('Percentage of proposals that are in the annotated frame: {}'.format(torch.mean(frm_mask.float()))) 263 | 264 | overlaps = iw * ih / ua 265 | overlaps *= frm_mask.type(overlaps.type()) 266 | 267 | # mask the overlap here. 268 | overlaps.masked_fill_(gt_area_zero.view(batch_size, 1, K).expand(batch_size, N, K), 0) 269 | overlaps.masked_fill_(anchors_area_zero.view(batch_size, N, 1).expand(batch_size, N, K), -1) 270 | else: 271 | raise ValueError('anchors input dimension is not correct.') 272 | 273 | return overlaps 274 | 275 | -------------------------------------------------------------------------------- /misc/dataloader_anet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import json 13 | import h5py 14 | import os 15 | import numpy as np 16 | import random 17 | from torchvision.datasets.folder import default_loader 18 | import torch 19 | import torch.utils.data as data 20 | import copy 21 | from PIL import Image 22 | import torchvision.transforms as transforms 23 | import torchtext.vocab as vocab # use this to load glove vector 24 | from collections import defaultdict 25 | 26 | class DataLoader(data.Dataset): 27 | def __init__(self, opt, split='training', seq_per_img=5): 28 | self.opt = opt 29 | self.batch_size = self.opt.batch_size 30 | self.seq_per_img = opt.seq_per_img 31 | self.seq_length = opt.seq_length 32 | self.split = split 33 | self.seq_per_img = seq_per_img 34 | self.att_feat_size = opt.att_feat_size 35 | self.vis_attn = opt.vis_attn 36 | self.feature_root = opt.feature_root 37 | self.seg_feature_root = opt.seg_feature_root 38 | self.num_sampled_frm = opt.num_sampled_frm 39 | self.num_prop_per_frm = opt.num_prop_per_frm 40 | self.exclude_bgd_det = opt.exclude_bgd_det 41 | self.prop_thresh = opt.prop_thresh 42 | self.t_attn_size = opt.t_attn_size 43 | self.test_mode = opt.test_mode 44 | self.max_gt_box = 100 45 | self.max_proposal = self.num_sampled_frm * self.num_prop_per_frm 46 | self.glove = vocab.GloVe(name='6B', dim=300) 47 | 48 | # load the json file which contains additional information about the dataset 49 | print('DataLoader loading json file: ', opt.input_dic) 50 | self.info = json.load(open(self.opt.input_dic)) 51 | self.itow = self.info['ix_to_word'] 52 | self.wtoi = {w:i for i,w in self.itow.items()} 53 | self.wtod = {w:i+1 for w,i in self.info['wtod'].items()} # word to detection 54 | self.dtoi = self.wtod # detection to index 55 | self.itod = {i:w for w,i in self.dtoi.items()} 56 | self.wtol = self.info['wtol'] 57 | self.ltow = {l:w for w,l in self.wtol.items()} 58 | self.vocab_size = len(self.itow) + 1 # since it start from 1 59 | print('vocab size is ', self.vocab_size) 60 | self.itoc = self.itod 61 | 62 | # get the glove vector for the vg detection cls 63 | obj_cls_file = 'data/vg_object_vocab.txt' # From Peter's repo 64 | with open(obj_cls_file) as f: 65 | data = f.readlines() 66 | classes = ['__background__'] 67 | classes.extend([i.strip() for i in data]) 68 | 69 | # for VG classes 70 | self.vg_cls = classes 71 | self.glove_vg_cls = np.zeros((len(classes), 300)) 72 | for i, w in enumerate(classes): 73 | split_word = w.replace(',', ' ').split(' ') 74 | vector = [] 75 | for word in split_word: 76 | if word in self.glove.stoi: 77 | vector.append(self.glove.vectors[self.glove.stoi[word]].numpy()) 78 | else: # use a random vector instead 79 | vector.append(2*np.random.rand(300) - 1) 80 | 81 | avg_vector = np.zeros((300)) 82 | for v in vector: 83 | avg_vector += v 84 | 85 | self.glove_vg_cls[i] = avg_vector/len(vector) 86 | 87 | # open the caption json file 88 | print('DataLoader loading input file: ', opt.input_json) 89 | self.caption_file = json.load(open(self.opt.input_json)) 90 | 91 | # open the caption json file with segment boundaries 92 | print('DataLoader loading grounding file: ', opt.grd_reference) 93 | self.timestamp_file = json.load(open(opt.grd_reference)) 94 | 95 | # open the detection json file. 
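# The proposal file read below is an HDF5 archive; the two datasets used here are
#   dets_num    - number of valid proposals per segment
#   dets_labels - per-segment proposal array, 7 values per proposal; from how the
#                 columns are indexed in this loader they appear to be
#                 [x1, y1, x2, y2, frame_idx, class_id, confidence]
#                 (column 5 == 0 marks background, column 6 is thresholded by
#                 prop_thresh, column 4 is matched against GT frame indices).
# A hypothetical quick check, with dataset names taken from the reads below:
#   import h5py
#   with h5py.File(opt.proposal_h5, 'r') as h:
#       print(h['dets_num'].shape, h['dets_labels'].shape)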
96 | print('DataLoader loading proposal file: ', opt.proposal_h5) 97 | h5_proposal_file = h5py.File(self.opt.proposal_h5, 'r', driver='core') 98 | self.num_proposals = h5_proposal_file['dets_num'][:] 99 | self.label_proposals = h5_proposal_file['dets_labels'][:] 100 | h5_proposal_file.close() 101 | 102 | # category id to labels. +1 becuase 0 is the background label. 103 | self.glove_clss = np.zeros((len(self.itod)+1, 300)) 104 | self.glove_clss[0] = 2*np.random.rand(300) - 1 # background 105 | for i, word in enumerate(self.itod.values()): 106 | if word in self.glove.stoi: 107 | vector = self.glove.vectors[self.glove.stoi[word]] 108 | else: # use a random vector instead 109 | vector = 2*np.random.rand(300) - 1 110 | self.glove_clss[i+1] = vector 111 | 112 | self.glove_w = np.zeros((len(self.wtoi)+1, 300)) 113 | for i, word in enumerate(self.wtoi.keys()): 114 | vector = np.zeros((300)) 115 | count = 0 116 | for w in word.split(' '): 117 | count += 1 118 | if w in self.glove.stoi: 119 | glove_vector = self.glove.vectors[self.glove.stoi[w]] 120 | vector += glove_vector.numpy() 121 | else: # use a random vector instead 122 | random_vector = 2*np.random.rand(300) - 1 123 | vector += random_vector 124 | self.glove_w[i+1] = vector / count 125 | 126 | self.detect_size = len(self.itod) 127 | 128 | # separate out indexes for each of the provided splits 129 | self.split_ix = [] 130 | self.num_seg_per_vid = defaultdict(list) 131 | for ix in range(len(self.info['videos'])): 132 | seg = self.info['videos'][ix] 133 | seg_id = seg['id'] 134 | vid_id, seg_idx = seg_id.split('_segment_') 135 | self.num_seg_per_vid[vid_id].append(int(seg_idx)) 136 | if seg['split'] == split: 137 | # all the feature files must exist 138 | if os.path.isfile(os.path.join(self.feature_root, seg_id+'.npy')) and \ 139 | os.path.isfile(os.path.join(self.seg_feature_root, vid_id[2:]+'_bn.npy')): 140 | if opt.vis_attn: 141 | if random.random() < 0.001: # randomly sample 0.1% segments to visualize 142 | self.split_ix.append(ix) 143 | else: 144 | self.split_ix.append(ix) 145 | print('assigned %d segments to split %s' %(len(self.split_ix), split)) 146 | 147 | def get_det_word(self, gt_bboxs, caption, bbox_ann): 148 | 149 | # get the present category. 150 | pcats = [] 151 | for i in range(gt_bboxs.shape[0]): 152 | pcats.append(gt_bboxs[i,6]) 153 | # get the orginial form of the caption. 154 | indicator = [] 155 | 156 | indicator.append([(0, 0, 0)]*len(caption)) # category class, binary class, fine-grain class. 157 | for i, bbox in enumerate(bbox_ann): 158 | # if the bbox_idx is not filtered out. 
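# Each surviving box writes a (category class, binary class, fine-grain class)
# triple at its word position, matching the (0, 0, 0) init above: wtod[clss] is the
# detection-category id, bn ends up as 1 when the annotated class string equals the
# caption token at that position and 2 otherwise (a bool plus 1), and fg is the
# box's fine-grained label from dtoi.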
159 | if bbox['bbox_idx'] in pcats: 160 | w_idx = bbox['idx'] 161 | ng = bbox['clss'] 162 | bn = (ng != caption[w_idx]) + 1 163 | fg = bbox['label'] 164 | indicator[0][w_idx] = (self.wtod[bbox['clss']], bn, fg) 165 | 166 | return indicator 167 | 168 | def get_frm_mask(self, proposals, gt_bboxs): 169 | # proposals: num_pps 170 | # gt_bboxs: num_box 171 | num_pps = proposals.shape[0] 172 | num_box = gt_bboxs.shape[0] 173 | return (np.tile(proposals.reshape(-1,1), (1,num_box)) != np.tile(gt_bboxs, (num_pps,1))) 174 | 175 | def __getitem__(self, index): 176 | 177 | ix = self.split_ix[index] 178 | 179 | seg_id = self.info['videos'][ix]['id'] 180 | vid_id_ix, seg_id_ix = seg_id.split('_segment_') 181 | seg_id_ix = str(int(seg_id_ix)) 182 | 183 | # load the proposal file 184 | num_proposal = int(self.num_proposals[ix]) 185 | proposals = copy.deepcopy(self.label_proposals[ix]) 186 | proposals = proposals[:num_proposal,:] 187 | 188 | # no need to resize proposal nor GT box since they are all based on images with 720px in width) 189 | region_feature = np.load(os.path.join(self.feature_root, seg_id+'.npy')) 190 | region_feature = region_feature.reshape(-1, region_feature.shape[2]).copy() 191 | assert(num_proposal == region_feature.shape[0]) 192 | 193 | # proposal mask to filter out low-confidence proposals or backgrounds 194 | pnt_mask = (proposals[:, 6] <= self.prop_thresh) 195 | if self.exclude_bgd_det: 196 | pnt_mask |= (proposals[:, 5] == 0) 197 | 198 | # load the frame-wise segment feature 199 | seg_rgb_feature = np.load(os.path.join(self.seg_feature_root, vid_id_ix[2:]+'_resnet.npy')) 200 | seg_motion_feature = np.load(os.path.join(self.seg_feature_root, vid_id_ix[2:]+'_bn.npy')) 201 | seg_feature_raw = np.concatenate((seg_rgb_feature, seg_motion_feature), axis=1) 202 | 203 | # not accurate, with minor misalignments 204 | timestamps = self.timestamp_file['annotations'][vid_id_ix]['segments'][str(int(seg_id_ix))]['timestamps'] 205 | dur = self.timestamp_file['annotations'][vid_id_ix]['duration'] 206 | num_frm = seg_feature_raw.shape[0] 207 | sample_idx = np.array([np.round(num_frm*timestamps[0]*1./dur), np.round(num_frm*timestamps[1]*1./dur)]) 208 | sample_idx = np.clip(np.round(sample_idx), 0, self.t_attn_size).astype(int) 209 | seg_feature = np.zeros((self.t_attn_size, seg_feature_raw.shape[1])) 210 | seg_feature[:min(self.t_attn_size, num_frm)] = seg_feature_raw[:self.t_attn_size] 211 | 212 | captions = [copy.deepcopy(self.caption_file[vid_id_ix]['segments'][seg_id_ix])] # one per segment 213 | assert len(captions) == 1, 'Only support one caption per segment for now!' 214 | 215 | bbox_ann = [] 216 | bbox_idx = 0 217 | for caption in captions: 218 | for i, clss in enumerate(caption['clss']): 219 | for j, cls in enumerate(clss): # one box might have multiple labels 220 | # we don't care about the boxes outside the length limit. 
221 | # after all our goal is referring, not detection 222 | if caption['idx'][i][j] < self.seq_length: 223 | if self.test_mode: 224 | # dummy bbox and frm_idx for the hidden testing split 225 | bbox_ann.append({'bbox':[0, 0, 0, 0], 'label': self.dtoi[cls], 'clss': cls, 226 | 'bbox_idx':bbox_idx, 'idx':caption['idx'][i][j], 'frm_idx':-1}) 227 | else: 228 | bbox_ann.append({'bbox':caption['bbox'][i], 'label': self.dtoi[cls], 'clss': cls, 229 | 'bbox_idx':bbox_idx, 'idx':caption['idx'][i][j], 'frm_idx':caption['frm_idx'][i]}) 230 | 231 | bbox_idx += 1 232 | 233 | # (optional) sort the box based on idx 234 | bbox_ann = sorted(bbox_ann, key=lambda x:x['idx']) 235 | 236 | gt_bboxs = np.zeros((len(bbox_ann), 8)) 237 | for i, bbox in enumerate(bbox_ann): 238 | gt_bboxs[i, :4] = bbox['bbox'] 239 | gt_bboxs[i, 4] = bbox['frm_idx'] 240 | gt_bboxs[i, 5] = bbox['label'] 241 | gt_bboxs[i, 6] = bbox['bbox_idx'] 242 | gt_bboxs[i, 7] = bbox['idx'] 243 | 244 | if not self.test_mode: # skip this in test mode 245 | gt_x = (gt_bboxs[:,2]-gt_bboxs[:,0]+1) 246 | gt_y = (gt_bboxs[:,3]-gt_bboxs[:,1]+1) 247 | gt_area_nonzero = (((gt_x != 1) & (gt_y != 1))) 248 | gt_bboxs = gt_bboxs[gt_area_nonzero] 249 | 250 | # given the bbox_ann, and caption, this function determine which word belongs to the detection. 251 | det_indicator = self.get_det_word(gt_bboxs, captions[0]['caption'], bbox_ann) 252 | # fetch the captions 253 | ncap = len(captions) # number of captions available for this image 254 | assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t' 255 | 256 | # convert caption into sequence label. 257 | cap_seq = np.zeros([ncap, self.seq_length, 5]) 258 | for i, caption in enumerate(captions): 259 | j = 0 260 | while j < len(caption['caption']) and j < self.seq_length: 261 | is_det = False 262 | if det_indicator[i][j][0] != 0: 263 | cap_seq[i,j,0] = det_indicator[i][j][0] + self.vocab_size 264 | cap_seq[i,j,1] = det_indicator[i][j][1] 265 | cap_seq[i,j,2] = det_indicator[i][j][2] 266 | cap_seq[i,j,3] = self.wtoi[caption['caption'][j]] 267 | cap_seq[i,j,4] = self.wtoi[caption['caption'][j]] 268 | else: 269 | cap_seq[i,j,0] = self.wtoi[caption['caption'][j]] 270 | cap_seq[i,j,4] = self.wtoi[caption['caption'][j]] 271 | j += 1 272 | 273 | # get the mask of the ground truth bounding box. The data shape is 274 | # num_caption x num_box x num_seq 275 | box_mask = np.ones((len(captions), gt_bboxs.shape[0], self.seq_length)) 276 | for i in range(gt_bboxs.shape[0]): 277 | box_mask[0, i, int(gt_bboxs[i][7])] = 0 278 | 279 | gt_bboxs = gt_bboxs[:,:6] 280 | 281 | # get the batch version of the seq and box_mask. 
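# seq_batch ends up [seq_per_img, seq_length, 4] and mask_batch
# [seq_per_img, num_gt_box, seq_length]: if fewer than seq_per_img reference
# captions exist they are re-sampled to fill the slots, otherwise a contiguous
# block of seq_per_img captions is sliced out. input_seq then gains an extra
# leading time step that stays all-zero, effectively a start-of-sequence slot,
# while gt_seq keeps only the plain word ids (column 4 of cap_seq) for up to 10
# reference captions.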
282 | if ncap < self.seq_per_img: 283 | seq_batch = np.zeros([self.seq_per_img, self.seq_length, 4]) 284 | mask_batch = np.zeros([self.seq_per_img, gt_bboxs.shape[0], self.seq_length]) 285 | # we need to subsample (with replacement) 286 | for q in range(self.seq_per_img): 287 | ixl = random.randint(0,ncap) 288 | seq_batch[q,:] = cap_seq[ixl,:,:4] 289 | mask_batch[q,:] = box_mask[ixl] 290 | else: 291 | ixl = random.randint(0, ncap - self.seq_per_img) 292 | seq_batch = cap_seq[ixl:ixl+self.seq_per_img,:,:4] 293 | mask_batch = box_mask[ixl:ixl+self.seq_per_img] 294 | 295 | input_seq = np.zeros([self.seq_per_img, self.seq_length+1, 4]) 296 | input_seq[:,1:] = seq_batch 297 | 298 | gt_seq = np.zeros([10, self.seq_length]) 299 | gt_seq[:ncap,:] = cap_seq[:,:,4] 300 | 301 | # load the image for visualization purposes 302 | if self.vis_attn: 303 | seg_show = np.zeros((self.num_sampled_frm, 1280, 720, 3)) 304 | seg_dim_info = torch.LongTensor(2) 305 | for i in range(self.num_sampled_frm): 306 | try: 307 | img = Image.open(os.path.join(self.opt.image_path, seg_id, str(i+1).zfill(2)+'.jpg')).convert('RGB') 308 | width, height = img.size 309 | seg_show[i, :height, :width] = np.array(img) 310 | seg_dim_info[0] = height 311 | seg_dim_info[1] = width 312 | except: 313 | print('cannot load image...') 314 | break 315 | seg_show = torch.from_numpy(seg_show).type(torch.ByteTensor) 316 | 317 | # padding the proposals and gt_bboxs 318 | pad_proposals = np.zeros((self.max_proposal, 7)) 319 | pad_pnt_mask = np.ones((self.max_proposal)) 320 | pad_gt_bboxs = np.zeros((self.max_gt_box, 6)) 321 | pad_box_mask = np.ones((self.seq_per_img, self.max_gt_box, self.seq_length+1)) 322 | pad_region_feature = np.zeros((self.max_proposal, self.att_feat_size)) 323 | pad_frm_mask = np.ones((self.max_proposal, self.max_gt_box)) # mask out proposals outside the target frames 324 | 325 | num_box = min(gt_bboxs.shape[0], self.max_gt_box) 326 | num_pps = min(proposals.shape[0], self.max_proposal) 327 | pad_proposals[:num_pps] = proposals[:num_pps] 328 | pad_pnt_mask[:num_pps] = pnt_mask[:num_pps] 329 | pad_gt_bboxs[:num_box] = gt_bboxs[:num_box] 330 | pad_box_mask[:,:num_box,1:] = mask_batch[:,:num_box,:] 331 | pad_region_feature[:num_pps] = region_feature[:num_pps] 332 | 333 | frm_mask = self.get_frm_mask(pad_proposals[:num_pps, 4], pad_gt_bboxs[:num_box, 4]) 334 | pad_frm_mask[:num_pps, :num_box] = frm_mask 335 | 336 | input_seq = torch.from_numpy(input_seq).long() 337 | gt_seq = torch.from_numpy(gt_seq).long() 338 | pad_proposals = torch.from_numpy(pad_proposals).float() 339 | pad_pnt_mask = torch.from_numpy(pad_pnt_mask).byte() 340 | pad_gt_bboxs = torch.from_numpy(pad_gt_bboxs).float() 341 | pad_box_mask = torch.from_numpy(pad_box_mask).byte() 342 | pad_region_feature = torch.from_numpy(pad_region_feature).float() 343 | pad_proposals.masked_fill_(pad_pnt_mask.view(-1, 1), 0.) 344 | pad_region_feature.masked_fill_(pad_pnt_mask.view(-1, 1), 0.) 
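        # [Editor's note, not part of the original source] The block above pads the
        # variable-length proposals, GT boxes, box masks, and region features to the
        # fixed sizes self.max_proposal / self.max_gt_box so samples can be batched,
        # then zeroes out padded or filtered proposals with masked_fill_ so padding
        # never leaks into attention or grounding. A tiny sketch of the same
        # pad-and-mask pattern, with made-up sizes and kept as comments so it stays
        # out of the data path:
        #     feats = torch.randn(3, 5)                      # 3 real proposals
        #     pad = torch.zeros(8, 5); pad[:3] = feats       # pad to max_proposal = 8
        #     mask = torch.ones(8, dtype=torch.bool); mask[:3] = False
        #     pad.masked_fill_(mask.view(-1, 1), 0.)         # padded rows -> all zeros
        # One caveat in the caption subsampling further up: Python's random.randint
        # is inclusive of both endpoints, so `ixl = random.randint(0, ncap)` can
        # return ncap and index past cap_seq; that branch is only reached when
        # seq_per_img exceeds the single available caption, and `ncap - 1` (or
        # numpy's exclusive randint) appears to be the intended bound.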
345 | pad_frm_mask = torch.from_numpy(pad_frm_mask).byte() 346 | num = torch.FloatTensor([ncap, num_pps, num_box, int(seg_id_ix), 347 | max(self.num_seg_per_vid[vid_id_ix])+1, timestamps[0]*1./dur, 348 | timestamps[1]*1./dur]) # 3 + 4 (seg_id, num_of_seg_in_video, seg_start_time, seg_end_time) 349 | sample_idx = torch.from_numpy(sample_idx).long() 350 | 351 | if self.vis_attn: 352 | return seg_feature, input_seq, gt_seq, num, pad_proposals, pad_gt_bboxs, pad_box_mask, seg_id, seg_show, seg_dim_info, pad_region_feature, pad_frm_mask, sample_idx, pad_pnt_mask 353 | else: 354 | return seg_feature, input_seq, gt_seq, num, pad_proposals, pad_gt_bboxs, pad_box_mask, seg_id, pad_region_feature, pad_frm_mask, sample_idx, pad_pnt_mask 355 | 356 | 357 | def __len__(self): 358 | return len(self.split_ix) 359 | -------------------------------------------------------------------------------- /misc/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | from torch.autograd import * 16 | from torch.autograd import Variable 17 | import math 18 | import numpy as np 19 | import random 20 | import pdb 21 | import pickle 22 | 23 | import misc.utils as utils 24 | from misc.CaptionModelBU import CaptionModel 25 | from misc.transformer import Transformer, TransformerDecoder 26 | 27 | 28 | class AttModel(CaptionModel): 29 | def __init__(self, opt): 30 | super(AttModel, self).__init__() 31 | self.vocab_size = opt.vocab_size 32 | self.detect_size = opt.detect_size # number of object classes 33 | self.input_encoding_size = opt.input_encoding_size 34 | self.rnn_size = opt.rnn_size 35 | self.num_layers = opt.num_layers 36 | self.drop_prob_lm = opt.drop_prob_lm 37 | self.seq_length = opt.seq_length 38 | self.seg_info_size = 50 39 | self.fc_feat_size = opt.fc_feat_size+self.seg_info_size 40 | self.att_feat_size = opt.att_feat_size 41 | self.att_hid_size = opt.att_hid_size 42 | self.seq_per_img = opt.seq_per_img 43 | self.itod = opt.itod 44 | self.att_input_mode = opt.att_input_mode 45 | self.transfer_mode = opt.transfer_mode 46 | self.test_mode = opt.test_mode 47 | self.enable_BUTD = opt.enable_BUTD 48 | self.w_grd = opt.w_grd 49 | self.w_cls = opt.w_cls 50 | self.num_sampled_frm = opt.num_sampled_frm 51 | self.num_prop_per_frm = opt.num_prop_per_frm 52 | self.att_model = opt.att_model 53 | self.unk_idx = int(opt.wtoi['UNK']) 54 | 55 | if opt.region_attn_mode == 'add': 56 | self.alpha_net = nn.Linear(self.att_hid_size, 1) 57 | elif opt.region_attn_mode == 'cat': 58 | self.alpha_net = nn.Linear(self.att_hid_size*2, 1) 59 | 60 | self.stride = 32 # downsizing from input image to feature map 61 | 62 | self.t_attn_size = opt.t_attn_size 63 | self.tiny_value = 1e-8 64 | 65 | if self.enable_BUTD: 66 | assert(self.att_input_mode == 'region') 67 | self.pool_feat_size = self.att_feat_size 68 | else: 69 | self.pool_feat_size = self.att_feat_size+300+self.detect_size+1 70 | 71 | self.min_value = -1e8 72 | opt.beta = 1 73 | self.beta = opt.beta 74 | 75 | self.loc_fc = nn.Sequential(nn.Linear(5, 300), 76 | nn.ReLU(), 77 | nn.Dropout(inplace=True)) 78 | 79 | self.embed 
= nn.Sequential(nn.Embedding(self.vocab_size, 80 | self.input_encoding_size), # det is 1-indexed 81 | nn.ReLU(), 82 | nn.Dropout(self.drop_prob_lm, inplace=True)) 83 | 84 | if self.transfer_mode in ('none', 'cls'): 85 | self.vis_encoding_size = 2048 86 | elif self.transfer_mode == 'both': 87 | self.vis_encoding_size = 2348 88 | elif self.transfer_mode == 'glove': 89 | self.vis_encoding_size = 300 90 | else: 91 | raise NotImplementedError 92 | 93 | self.vis_embed = nn.Sequential(nn.Embedding(self.detect_size+1, 94 | self.vis_encoding_size), # det is 1-indexed 95 | nn.ReLU(), 96 | nn.Dropout(self.drop_prob_lm, inplace=True) 97 | ) 98 | 99 | self.fc_embed = nn.Sequential(nn.Linear(self.fc_feat_size, self.rnn_size), 100 | nn.ReLU(), 101 | nn.Dropout(self.drop_prob_lm, inplace=True)) 102 | 103 | self.seg_info_embed = nn.Sequential(nn.Linear(4, self.seg_info_size), 104 | nn.ReLU(), 105 | nn.Dropout(self.drop_prob_lm, inplace=True)) 106 | 107 | self.att_embed = nn.ModuleList([nn.Sequential(nn.Linear(2048, self.rnn_size//2), # for rgb feature 108 | nn.ReLU(), 109 | nn.Dropout(self.drop_prob_lm, inplace=True)), 110 | nn.Sequential(nn.Linear(1024, self.rnn_size//2), # for motion feature 111 | nn.ReLU(), 112 | nn.Dropout(self.drop_prob_lm, inplace=True))]) 113 | 114 | self.att_embed_aux = nn.Sequential(nn.BatchNorm1d(self.rnn_size), 115 | nn.ReLU()) 116 | 117 | self.pool_embed = nn.Sequential(nn.Linear(self.pool_feat_size, self.rnn_size), 118 | nn.ReLU(), 119 | nn.Dropout(self.drop_prob_lm, inplace=True)) 120 | 121 | self.ctx2att = nn.Linear(self.rnn_size, self.att_hid_size) 122 | self.ctx2pool = nn.Linear(self.rnn_size, self.att_hid_size) 123 | 124 | self.logit = nn.Linear(self.rnn_size, self.vocab_size) 125 | 126 | if opt.obj_interact: 127 | n_layers = 2 128 | n_heads = 6 129 | attn_drop = 0.2 130 | self.obj_interact = Transformer(self.rnn_size, 0, 0, 131 | d_hidden=int(self.rnn_size/2), 132 | n_layers=n_layers, 133 | n_heads=n_heads, 134 | drop_ratio=attn_drop, 135 | pe=False) 136 | 137 | if self.att_model == 'transformer': 138 | n_layers = 2 139 | n_heads = 6 140 | attn_drop = 0.2 141 | print('initiailze language decoder transformer...') 142 | self.cap_model = TransformerDecoder(self.rnn_size, 0, self.vocab_size, \ 143 | d_hidden = self.rnn_size//2, n_layers=n_layers, n_heads=n_heads, drop_ratio=attn_drop) 144 | 145 | if opt.t_attn_mode == 'bilstm': # frame-wise feature encoding 146 | n_layers = 2 147 | attn_drop = 0.2 148 | self.context_enc = nn.LSTM(self.rnn_size, self.rnn_size//2, n_layers, dropout=attn_drop, \ 149 | bidirectional=True, batch_first=True) 150 | elif opt.t_attn_mode == 'bigru': 151 | n_layers = 2 152 | attn_drop = 0.2 153 | self.context_enc = nn.GRU(self.rnn_size, self.rnn_size//2, n_layers, dropout=attn_drop, \ 154 | bidirectional=True, batch_first=True) 155 | else: 156 | raise NotImplementedError 157 | 158 | self.ctx2pool_grd = nn.Sequential(nn.Linear(self.att_feat_size, self.vis_encoding_size), # fc7 layer 159 | nn.ReLU(), 160 | nn.Dropout(self.drop_prob_lm, inplace=True) 161 | ) 162 | 163 | self.critLM = utils.LMCriterion(opt) 164 | 165 | # initialize the glove weight for the labels. 
166 | # self.det_fc[0].weight.data.copy_(opt.glove_vg_cls) 167 | # for p in self.det_fc[0].parameters(): p.requires_grad=False 168 | 169 | # self.embed[0].weight.data.copy_(torch.cat((opt.glove_w, opt.glove_clss))) 170 | # for p in self.embed[0].parameters(): p.requires_grad=False 171 | 172 | # weights transfer for fc7 layer 173 | with open('data/detectron_weights/fc7_w.pkl', 'rb') as f: 174 | fc7_w = torch.from_numpy(pickle.load(f)) 175 | with open('data/detectron_weights/fc7_b.pkl', 'rb') as f: 176 | fc7_b = torch.from_numpy(pickle.load(f)) 177 | self.ctx2pool_grd[0].weight[:self.att_feat_size].data.copy_(fc7_w) 178 | self.ctx2pool_grd[0].bias[:self.att_feat_size].data.copy_(fc7_b) 179 | 180 | if self.transfer_mode in ('cls', 'both'): 181 | # find nearest neighbour class for transfer 182 | with open('data/detectron_weights/cls_score_w.pkl', 'rb') as f: 183 | cls_score_w = torch.from_numpy(pickle.load(f)) # 1601x2048 184 | with open('data/detectron_weights/cls_score_b.pkl', 'rb') as f: 185 | cls_score_b = torch.from_numpy(pickle.load(f)) # 1601x2048 186 | 187 | assert(len(opt.itod)+1 == opt.glove_clss.size(0)) # index 0 is background 188 | assert(len(opt.vg_cls) == opt.glove_vg_cls.size(0)) # index 0 is background 189 | 190 | sim_matrix = torch.matmul(opt.glove_vg_cls/torch.norm(opt.glove_vg_cls, dim=1).unsqueeze(1), \ 191 | (opt.glove_clss/torch.norm(opt.glove_clss, dim=1).unsqueeze(1)).transpose(1,0)) 192 | 193 | max_sim, matched_cls = torch.max(sim_matrix, dim=0) 194 | self.max_sim = max_sim 195 | self.matched_cls = matched_cls 196 | 197 | vis_classifiers = opt.glove_clss.new(self.detect_size+1, cls_score_w.size(1)).fill_(0) 198 | self.vis_classifiers_bias = nn.Parameter(opt.glove_clss.new(self.detect_size+1).fill_(0)) 199 | vis_classifiers[0] = cls_score_w[0] # background 200 | self.vis_classifiers_bias[0].data.copy_(cls_score_b[0]) 201 | for i in range(1, self.detect_size+1): 202 | vis_classifiers[i] = cls_score_w[matched_cls[i]] 203 | self.vis_classifiers_bias[i].data.copy_(cls_score_b[matched_cls[i]]) 204 | if max_sim[i].item() < 0.9: 205 | print('index: {}, similarity: {:.2}, {}, {}'.format(i, max_sim[i].item(), \ 206 | opt.itod[i], opt.vg_cls[matched_cls[i]])) 207 | 208 | if self.transfer_mode == 'cls': 209 | self.vis_embed[0].weight.data.copy_(vis_classifiers) 210 | else: 211 | self.vis_embed[0].weight.data.copy_(torch.cat((vis_classifiers, opt.glove_clss), dim=1)) 212 | elif self.transfer_mode == 'glove': 213 | self.vis_embed[0].weight.data.copy_(opt.glove_clss) 214 | elif self.transfer_mode == 'none': 215 | print('No knowledge transfer...') 216 | else: 217 | raise NotImplementedError 218 | 219 | # for p in self.ctx2pool_grd.parameters(): p.requires_grad=False 220 | # for p in self.vis_embed[0].parameters(): p.requires_grad=False 221 | 222 | if opt.enable_visdom: 223 | import visdom 224 | self.vis = visdom.Visdom(server=opt.visdom_server, env='vis-'+opt.id) 225 | 226 | 227 | def forward(self, segs_feat, seq, gt_seq, num, ppls, gt_boxes, mask_boxes, ppls_feat, frm_mask, sample_idx, pnt_mask, opt, eval_opt = {}): 228 | if opt == 'MLE': 229 | return self._forward(segs_feat, seq, gt_seq, ppls, gt_boxes, mask_boxes, num, ppls_feat, frm_mask, sample_idx, pnt_mask) 230 | elif opt == 'GRD': 231 | return self._forward(segs_feat, seq, gt_seq, ppls, gt_boxes, mask_boxes, num, ppls_feat, frm_mask, sample_idx, pnt_mask, True) 232 | elif opt == 'sample': 233 | seq, seqLogprobs, att2, sim_mat = self._sample(segs_feat, ppls, num, ppls_feat, sample_idx, pnt_mask, eval_opt) 234 | return 
Variable(seq), Variable(att2), Variable(sim_mat) 235 | 236 | 237 | def init_hidden(self, bsz): 238 | weight = next(self.parameters()).data 239 | return (Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_()), 240 | Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_())) 241 | 242 | 243 | def _grounder(self, xt, att_feats, mask, bias=None): 244 | # xt - B, seq_cnt, enc_size 245 | # att_feats - B, rois_num, enc_size 246 | # mask - B, rois_num 247 | # 248 | # dot - B, seq_cnt, rois_num 249 | 250 | B, S, _ = xt.size() 251 | _, R, _ = att_feats.size() 252 | 253 | if hasattr(self, 'alpha_net'): 254 | # Additive attention for grounding 255 | if self.alpha_net.weight.size(1) == self.att_hid_size: 256 | dot = xt.unsqueeze(2) + att_feats.unsqueeze(1) 257 | else: 258 | dot = torch.cat((xt.unsqueeze(2).expand(B, S, R, self.att_hid_size), 259 | att_feats.unsqueeze(1).expand(B, S, R, self.att_hid_size)), 3) 260 | dot = F.tanh(dot) 261 | dot = self.alpha_net(dot).squeeze(-1) 262 | else: 263 | # Dot-product attention for grounding 264 | assert(xt.size(-1) == att_feats.size(-1)) 265 | dot = torch.matmul(xt, att_feats.permute(0,2,1).contiguous()) # B, seq_cnt, rois_num 266 | 267 | if bias is not None: 268 | assert(bias.numel() == dot.numel()) 269 | dot += bias 270 | 271 | if mask.dim() == 2: 272 | expanded_mask = mask.unsqueeze(1).expand_as(dot) 273 | elif mask.dim() == 3: # if expanded already 274 | expanded_mask = mask 275 | else: 276 | raise NotImplementedError 277 | 278 | dot.masked_fill_(expanded_mask, self.min_value) 279 | 280 | return dot 281 | 282 | 283 | def _forward(self, segs_feat, input_seq, gt_seq, ppls, gt_boxes, mask_boxes, num, ppls_feat, frm_mask, sample_idx, pnt_mask, eval_obj_ground=False): 284 | 285 | seq = gt_seq[:, :self.seq_per_img, :].clone().view(-1, gt_seq.size(2)) # choose the first seq_per_img 286 | seq = torch.cat((Variable(seq.data.new(seq.size(0), 1).fill_(0)), seq), 1) 287 | input_seq = input_seq.view(-1, input_seq.size(2), input_seq.size(3)) # B*self.seq_per_img, self.seq_length+1, 5 288 | input_seq_update = input_seq.data.clone() 289 | 290 | batch_size = segs_feat.size(0) # B 291 | seq_batch_size = seq.size(0) # B*self.seq_per_img 292 | rois_num = ppls.size(1) # max_num_proposal of the batch 293 | 294 | state = self.init_hidden(seq_batch_size) # self.num_layers, B*self.seq_per_img, self.rnn_size 295 | rnn_output = [] 296 | roi_labels = [] # store which proposal match the gt box 297 | att2_weights = [] 298 | h_att_output = [] 299 | max_grd_output = [] 300 | frm_mask_output = [] 301 | 302 | conv_feats = segs_feat 303 | sample_idx_mask = conv_feats.new(batch_size, conv_feats.size(1), 1).fill_(1).byte() 304 | for i in range(batch_size): 305 | sample_idx_mask[i, sample_idx[i,0]:sample_idx[i,1]] = 0 306 | fc_feats = torch.mean(segs_feat, dim=1) 307 | fc_feats = torch.cat((F.layer_norm(fc_feats, [self.fc_feat_size-self.seg_info_size]), \ 308 | F.layer_norm(self.seg_info_embed(num[:, 3:7].float()), [self.seg_info_size])), dim=-1) 309 | 310 | # pooling the conv_feats 311 | pool_feats = ppls_feat 312 | pool_feats = self.ctx2pool_grd(pool_feats) 313 | g_pool_feats = pool_feats 314 | 315 | # calculate the overlaps between the rois/rois and rois/gt_bbox. 
316 | # apply both frame mask and proposal mask 317 | overlaps = utils.bbox_overlaps(ppls.data, gt_boxes.data, \ 318 | (frm_mask | pnt_mask[:, 1:].unsqueeze(-1)).data) 319 | 320 | # visual words embedding 321 | vis_word = Variable(torch.Tensor(range(0, self.detect_size+1)).type(input_seq.type())) 322 | vis_word_embed = self.vis_embed(vis_word) 323 | assert(vis_word_embed.size(0) == self.detect_size+1) 324 | 325 | p_vis_word_embed = vis_word_embed.view(1, self.detect_size+1, self.vis_encoding_size) \ 326 | .expand(batch_size, self.detect_size+1, self.vis_encoding_size).contiguous() 327 | 328 | if hasattr(self, 'vis_classifiers_bias'): 329 | bias = self.vis_classifiers_bias.type(p_vis_word_embed.type()) \ 330 | .view(1,-1,1).expand(p_vis_word_embed.size(0), \ 331 | p_vis_word_embed.size(1), g_pool_feats.size(1)) 332 | else: 333 | bias = None 334 | 335 | # region-class similarity matrix 336 | sim_mat_static = self._grounder(p_vis_word_embed, g_pool_feats, pnt_mask[:,1:], bias) 337 | sim_mat_static_update = sim_mat_static.view(batch_size, 1, self.detect_size+1, rois_num) \ 338 | .expand(batch_size, self.seq_per_img, self.detect_size+1, rois_num).contiguous() \ 339 | .view(seq_batch_size, self.detect_size+1, rois_num) 340 | sim_mat_static = F.softmax(sim_mat_static, dim=1) 341 | 342 | if self.test_mode: 343 | cls_pred = 0 344 | else: 345 | sim_target = utils.sim_mat_target(overlaps, gt_boxes[:,:,5].data) # B, num_box, num_rois 346 | sim_mask = (sim_target > 0) 347 | if not eval_obj_ground: 348 | masked_sim = torch.gather(sim_mat_static, 1, sim_target) 349 | masked_sim = torch.masked_select(masked_sim, sim_mask) 350 | cls_loss = F.binary_cross_entropy(masked_sim, masked_sim.new(masked_sim.size()).fill_(1)) 351 | else: 352 | # region classification accuracy 353 | sim_target_masked = torch.masked_select(sim_target, sim_mask) 354 | sim_mat_masked = torch.masked_select(torch.max(sim_mat_static, dim=1)[1].unsqueeze(1).expand_as(sim_target), sim_mask) 355 | cls_pred = torch.stack((sim_target_masked, sim_mat_masked), dim=1).data 356 | 357 | if not self.enable_BUTD: 358 | loc_input = ppls.data.new(batch_size, rois_num, 5) 359 | loc_input[:,:,:4] = ppls.data[:,:,:4] / 720. 360 | loc_input[:,:,4] = ppls.data[:,:,4]*1./self.num_sampled_frm 361 | loc_feats = self.loc_fc(Variable(loc_input)) # encode the locations 362 | label_feat = sim_mat_static.permute(0,2,1).contiguous() 363 | pool_feats = torch.cat((F.layer_norm(pool_feats, [pool_feats.size(-1)]), \ 364 | F.layer_norm(loc_feats, [loc_feats.size(-1)]), F.layer_norm(label_feat, [label_feat.size(-1)])), 2) 365 | 366 | # replicate the feature to map the seq size. 
367 | fc_feats = fc_feats.view(batch_size, 1, self.fc_feat_size)\ 368 | .expand(batch_size, self.seq_per_img, self.fc_feat_size)\ 369 | .contiguous().view(-1, self.fc_feat_size) 370 | pool_feats = pool_feats.view(batch_size, 1, rois_num, self.pool_feat_size)\ 371 | .expand(batch_size, self.seq_per_img, rois_num, self.pool_feat_size)\ 372 | .contiguous().view(-1, rois_num, self.pool_feat_size) 373 | g_pool_feats = g_pool_feats.view(batch_size, 1, rois_num, self.vis_encoding_size) \ 374 | .expand(batch_size, self.seq_per_img, rois_num, self.vis_encoding_size) \ 375 | .contiguous().view(-1, rois_num, self.vis_encoding_size) 376 | pnt_mask = pnt_mask.view(batch_size, 1, rois_num+1).expand(batch_size, self.seq_per_img, rois_num+1)\ 377 | .contiguous().view(-1, rois_num+1) 378 | overlaps = overlaps.view(batch_size, 1, rois_num, overlaps.size(2)) \ 379 | .expand(batch_size, self.seq_per_img, rois_num, overlaps.size(2)) \ 380 | .contiguous().view(-1, rois_num, overlaps.size(2)) 381 | 382 | # embed fc and att feats 383 | fc_feats = self.fc_embed(fc_feats) 384 | pool_feats = self.pool_embed(pool_feats) 385 | 386 | # object region interactions 387 | if hasattr(self, 'obj_interact'): 388 | pool_feats = self.obj_interact(pool_feats) 389 | 390 | # Project the attention feats first to reduce memory and computation comsumptions. 391 | p_pool_feats = self.ctx2pool(pool_feats) # same here 392 | 393 | if self.att_input_mode in ('both', 'featmap'): 394 | conv_feats_splits = torch.split(conv_feats, 2048, 2) 395 | conv_feats = torch.cat([m(c) for (m,c) in zip(self.att_embed, conv_feats_splits)], dim=2) 396 | conv_feats = conv_feats.permute(0,2,1).contiguous() # inconsistency between Torch TempConv and PyTorch Conv1d 397 | conv_feats = self.att_embed_aux(conv_feats) 398 | conv_feats = conv_feats.permute(0,2,1).contiguous() # inconsistency between Torch TempConv and PyTorch Conv1d 399 | conv_feats = self.context_enc(conv_feats)[0] 400 | 401 | conv_feats = conv_feats.masked_fill(sample_idx_mask, 0) 402 | conv_feats = conv_feats.view(batch_size, 1, self.t_attn_size, self.rnn_size)\ 403 | .expand(batch_size, self.seq_per_img, self.t_attn_size, self.rnn_size)\ 404 | .contiguous().view(-1, self.t_attn_size, self.rnn_size) 405 | p_conv_feats = self.ctx2att(conv_feats) # self.rnn_size (1024) -> self.att_hid_size (512) 406 | else: 407 | # dummy 408 | conv_feats = pool_feats.new(1,1).fill_(0) 409 | p_conv_feats = pool_feats.new(1,1).fill_(0) 410 | 411 | if self.att_model == 'transformer': # Masked Transformer does not support box supervision yet 412 | if self.att_input_mode == 'both': 413 | lm_loss = self.cap_model([conv_feats, pool_feats], seq) 414 | elif self.att_input_mode == 'featmap': 415 | lm_loss = self.cap_model([conv_feats, conv_feats], seq) 416 | elif self.att_input_mode == 'region': 417 | lm_loss = self.cap_model([pool_feats, pool_feats], seq) 418 | return lm_loss.unsqueeze(0), lm_loss.new(1).fill_(0), lm_loss.new(1).fill_(0), \ 419 | lm_loss.new(1).fill_(0), lm_loss.new(1).fill_(0), lm_loss.new(1).fill_(0) 420 | elif self.att_model == 'topdown': 421 | for i in range(self.seq_length): 422 | it = seq[:, i].clone() 423 | 424 | # break if all the sequences end 425 | if i >= 1 and seq[:, i].data.sum() == 0: 426 | break 427 | 428 | xt = self.embed(it) 429 | 430 | if not eval_obj_ground: 431 | roi_label = utils.bbox_target(mask_boxes[:,:,:,i+1], overlaps, input_seq[:,i+1], \ 432 | input_seq_update[:,i+1], self.vocab_size) # roi_label if for the target seq 433 | roi_labels.append(roi_label.view(seq_batch_size, -1)) 
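                    # [Editor's note, not part of the original source] utils.bbox_target
                    # masks out GT boxes that are not referred to at step i+1, takes each
                    # proposal's maximum overlap with the remaining boxes, and returns a
                    # 0/1 label per proposal (overlap > 0.5). Stacked over time below,
                    # roi_labels becomes the (seq_batch_size, seq_cnt, rois_num) target
                    # that supervises the attention (att2) and grounding losses in
                    # self.critLM.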
434 | 435 | # use frame mask during training 436 | box_mask = mask_boxes[:,0,:,i+1].contiguous().unsqueeze(1).expand(( 437 | batch_size, rois_num, mask_boxes.size(2))) 438 | frm_mask_on_prop = (torch.sum((1 - (box_mask | frm_mask)), dim=2)<=0) 439 | frm_mask_on_prop = torch.cat((frm_mask_on_prop.new(batch_size, 1).fill_(0.), \ 440 | frm_mask_on_prop), dim=1) | pnt_mask 441 | output, state, att2_weight, att_h, max_grd_val, grd_val = self.core(xt, fc_feats, \ 442 | conv_feats, p_conv_feats, pool_feats, p_pool_feats, pnt_mask, frm_mask_on_prop, \ 443 | state, sim_mat_static_update) 444 | frm_mask_output.append(frm_mask_on_prop) 445 | else: 446 | output, state, att2_weight, att_h, max_grd_val, grd_val = self.core(xt, fc_feats, \ 447 | conv_feats, p_conv_feats, pool_feats, p_pool_feats, pnt_mask, pnt_mask, \ 448 | state, sim_mat_static_update) 449 | 450 | att2_weights.append(att2_weight) 451 | h_att_output.append(att_h) # the hidden state of attention LSTM 452 | rnn_output.append(output) 453 | max_grd_output.append(max_grd_val) 454 | 455 | seq_cnt = len(rnn_output) 456 | rnn_output = torch.cat([_.unsqueeze(1) for _ in rnn_output], 1) # seq_batch_size, seq_cnt, vocab 457 | h_att_output = torch.cat([_.unsqueeze(1) for _ in h_att_output], 1) 458 | att2_weights = torch.cat([_.unsqueeze(1) for _ in att2_weights], 1) # seq_batch_size, seq_cnt, att_size 459 | max_grd_output = torch.cat([_.unsqueeze(1) for _ in max_grd_output], 1) 460 | if not eval_obj_ground: 461 | frm_mask_output = torch.cat([_.unsqueeze(1) for _ in frm_mask_output], 1) 462 | roi_labels = torch.cat([_.unsqueeze(1) for _ in roi_labels], 1) 463 | 464 | decoded = F.log_softmax(self.beta * self.logit(rnn_output), dim=2) # text word prob 465 | decoded = decoded.view((seq_cnt)*seq_batch_size, -1) 466 | 467 | # object grounding 468 | h_att_all = h_att_output # hidden states from the Attention LSTM 469 | xt_clamp = torch.clamp(input_seq[:, 1:seq_cnt+1, 0].clone()-self.vocab_size, min=0) 470 | xt_all = self.vis_embed(xt_clamp) 471 | 472 | if hasattr(self, 'vis_classifiers_bias'): 473 | bias = self.vis_classifiers_bias[xt_clamp].type(xt_all.type()) \ 474 | .unsqueeze(2).expand(seq_batch_size, seq_cnt, rois_num) 475 | else: 476 | bias = 0 477 | 478 | if not eval_obj_ground: 479 | # att2_weights/ground_weights with both proposal mask and frame mask 480 | ground_weights = self._grounder(xt_all, g_pool_feats, frm_mask_output[:,:,1:], bias+att2_weights) 481 | lm_loss, att2_loss, ground_loss = self.critLM(decoded, att2_weights, ground_weights, \ 482 | seq[:, 1:seq_cnt+1].clone(), roi_labels[:, :seq_cnt, :].clone(), input_seq[:, 1:seq_cnt+1, 0].clone()) 483 | return lm_loss.unsqueeze(0), att2_loss.unsqueeze(0), ground_loss.unsqueeze(0), cls_loss.unsqueeze(0) 484 | else: 485 | # att2_weights/ground_weights with proposal mask only 486 | ground_weights = self._grounder(xt_all, g_pool_feats, pnt_mask[:,1:], bias+att2_weights) 487 | return cls_pred, torch.max(att2_weights.view(seq_batch_size, seq_cnt, self.num_sampled_frm, \ 488 | self.num_prop_per_frm), dim=-1)[1], torch.max(ground_weights.view(seq_batch_size, \ 489 | seq_cnt, self.num_sampled_frm, self.num_prop_per_frm), dim=-1)[1] 490 | 491 | 492 | def _sample(self, segs_feat, ppls, num, ppls_feat, sample_idx, pnt_mask, opt={}): 493 | sample_max = opt.get('sample_max', 1) 494 | beam_size = opt.get('beam_size', 1) 495 | temperature = opt.get('temperature', 1.0) 496 | inference_mode = opt.get('inference_mode', True) 497 | 498 | batch_size = segs_feat.size(0) 499 | rois_num = ppls.size(1) 500 | 501 | if 
beam_size > 1: 502 | return self._sample_beam(segs_feat, ppls, num, ppls_feat, sample_idx, pnt_mask, opt) 503 | 504 | conv_feats = segs_feat 505 | sample_idx_mask = conv_feats.new(batch_size, conv_feats.size(1), 1).fill_(1).byte() 506 | for i in range(batch_size): 507 | sample_idx_mask[i, sample_idx[i,0]:sample_idx[i,1]] = 0 508 | fc_feats = torch.mean(segs_feat, dim=1) 509 | fc_feats = torch.cat((F.layer_norm(fc_feats, [self.fc_feat_size-self.seg_info_size]), \ 510 | F.layer_norm(self.seg_info_embed(num[:, 3:7].float()), [self.seg_info_size])), dim=-1) 511 | 512 | pool_feats = ppls_feat 513 | pool_feats = self.ctx2pool_grd(pool_feats) 514 | g_pool_feats = pool_feats 515 | 516 | att_mask = pnt_mask.clone() 517 | 518 | # visual words embedding 519 | vis_word = Variable(torch.Tensor(range(0, self.detect_size+1)).type(fc_feats.type())).long() 520 | vis_word_embed = self.vis_embed(vis_word) 521 | assert(vis_word_embed.size(0) == self.detect_size+1) 522 | 523 | p_vis_word_embed = vis_word_embed.view(1, self.detect_size+1, self.vis_encoding_size) \ 524 | .expand(batch_size, self.detect_size+1, self.vis_encoding_size).contiguous() 525 | 526 | if hasattr(self, 'vis_classifiers_bias'): 527 | bias = self.vis_classifiers_bias.type(p_vis_word_embed.type()) \ 528 | .view(1,-1,1).expand(p_vis_word_embed.size(0), \ 529 | p_vis_word_embed.size(1), g_pool_feats.size(1)) 530 | else: 531 | bias = None 532 | 533 | sim_mat_static = self._grounder(p_vis_word_embed, g_pool_feats, pnt_mask[:,1:], bias) 534 | sim_mat_static_update = sim_mat_static 535 | sim_mat_static = F.softmax(sim_mat_static, dim=1) 536 | 537 | if not self.enable_BUTD: 538 | loc_input = ppls.data.new(batch_size, rois_num, 5) 539 | loc_input[:,:,:4] = ppls.data[:,:,:4] / 720. 540 | loc_input[:,:,4] = ppls.data[:,:,4]*1./self.num_sampled_frm 541 | loc_feats = self.loc_fc(Variable(loc_input)) # encode the locations 542 | label_feat = sim_mat_static.permute(0,2,1).contiguous() 543 | pool_feats = torch.cat((F.layer_norm(pool_feats, [pool_feats.size(-1)]), F.layer_norm(loc_feats, \ 544 | [loc_feats.size(-1)]), F.layer_norm(label_feat, [label_feat.size(-1)])), 2) 545 | 546 | # embed fc and att feats 547 | pool_feats = self.pool_embed(pool_feats) 548 | fc_feats = self.fc_embed(fc_feats) 549 | # object region interactions 550 | if hasattr(self, 'obj_interact'): 551 | pool_feats = self.obj_interact(pool_feats) 552 | 553 | # Project the attention feats first to reduce memory and computation comsumptions. 
554 | p_pool_feats = self.ctx2pool(pool_feats) 555 | 556 | if self.att_input_mode in ('both', 'featmap'): 557 | conv_feats_splits = torch.split(conv_feats, 2048, 2) 558 | conv_feats = torch.cat([m(c) for (m,c) in zip(self.att_embed, conv_feats_splits)], dim=2) 559 | conv_feats = conv_feats.permute(0,2,1).contiguous() # inconsistency between Torch TempConv and PyTorch Conv1d 560 | conv_feats = self.att_embed_aux(conv_feats) 561 | conv_feats = conv_feats.permute(0,2,1).contiguous() # inconsistency between Torch TempConv and PyTorch Conv1d 562 | conv_feats = self.context_enc(conv_feats)[0] 563 | 564 | conv_feats = conv_feats.masked_fill(sample_idx_mask, 0) 565 | p_conv_feats = self.ctx2att(conv_feats) 566 | else: 567 | conv_feats = pool_feats.new(1,1).fill_(0) 568 | p_conv_feats = pool_feats.new(1,1).fill_(0) 569 | 570 | if self.att_model == 'transformer': 571 | if self.att_input_mode == 'both': 572 | seq = self.cap_model([conv_feats, pool_feats], [], infer=True, seq_length=self.seq_length) 573 | elif self.att_input_mode == 'featmap': 574 | seq = self.cap_model([conv_feats, conv_feats], [], infer=True, seq_length=self.seq_length) 575 | elif self.att_input_mode == 'region': 576 | seq = self.cap_model([pool_feats, pool_feats], [], infer=True, seq_length=self.seq_length) 577 | 578 | return seq, seq.new(batch_size, 1).fill_(0), seq.new(batch_size, 1).fill_(0).long() 579 | elif self.att_model == 'topdown': 580 | state = self.init_hidden(batch_size) 581 | 582 | seq = [] 583 | seqLogprobs = [] 584 | att2_weights = [] 585 | 586 | for t in range(self.seq_length + 1): 587 | if t == 0: # input 588 | it = fc_feats.data.new(batch_size).long().zero_() 589 | elif sample_max: 590 | sampleLogprobs_tmp, it_tmp = torch.topk(logprobs.data, 2, dim=1) 591 | unk_mask = (it_tmp[:,0] != self.unk_idx) # mask on non-unk 592 | sampleLogprobs = unk_mask.float()*sampleLogprobs_tmp[:,0] + (1-unk_mask.float())*sampleLogprobs_tmp[:,1] 593 | it = unk_mask.long()*it_tmp[:,0] + (1-unk_mask.long())*it_tmp[:,1] 594 | it = it.view(-1).long() 595 | else: 596 | if temperature == 1.0: 597 | prob_prev = torch.exp(logprobs.data) # fetch prev distribution: shape Nx(M+1) 598 | else: 599 | # scale logprobs by temperature 600 | prob_prev = torch.exp(torch.div(logprobs.data, temperature)) 601 | it = torch.multinomial(prob_prev, 1) 602 | sampleLogprobs = logprobs.gather(1, Variable(it)) # gather the logprobs at sampled positions 603 | it = it.view(-1).long() # and flatten indices for downstream processing 604 | 605 | xt = self.embed(Variable(it)) 606 | if t >= 1: 607 | seq.append(it) #seq[t] the input of t+2 time step 608 | seqLogprobs.append(sampleLogprobs.view(-1)) 609 | 610 | if t < self.seq_length: 611 | rnn_output, state, att2_weight, att_h, _, _ = self.core(xt, fc_feats, conv_feats, \ 612 | p_conv_feats, pool_feats, p_pool_feats, att_mask, pnt_mask, state, \ 613 | sim_mat_static_update) 614 | 615 | decoded = F.log_softmax(self.beta * self.logit(rnn_output), dim=1) 616 | 617 | logprobs = decoded 618 | att2_weights.append(att2_weight) 619 | 620 | seq = torch.cat([_.unsqueeze(1) for _ in seq], 1) 621 | seqLogprobs = torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1) 622 | att2_weights = torch.cat([_.unsqueeze(1) for _ in att2_weights], 1) # batch_size, seq_cnt, att_size 623 | 624 | return seq, seqLogprobs, att2_weights, sim_mat_static 625 | 626 | 627 | def _sample_beam(self, segs_feat, ppls, num, ppls_feat, sample_idx, pnt_mask, opt={}): 628 | 629 | batch_size = ppls.size(0) 630 | rois_num = ppls.size(1) 631 | 632 | beam_size = 
opt.get('beam_size', 10) 633 | 634 | conv_feats = segs_feat 635 | sample_idx_mask = conv_feats.new(batch_size, conv_feats.size(1), 1).fill_(1).byte() 636 | for i in range(batch_size): 637 | sample_idx_mask[i, sample_idx[i,0]:sample_idx[i,1]] = 0 638 | fc_feats = torch.mean(segs_feat, dim=1) 639 | fc_feats = torch.cat((F.layer_norm(fc_feats, [self.fc_feat_size-self.seg_info_size]), \ 640 | F.layer_norm(self.seg_info_embed(num[:, 3:7].float()), [self.seg_info_size])), dim=-1) 641 | 642 | pool_feats = ppls_feat 643 | pool_feats = self.ctx2pool_grd(pool_feats) 644 | g_pool_feats = pool_feats 645 | 646 | # visual words embedding 647 | vis_word = Variable(torch.Tensor(range(0, self.detect_size+1)).type(fc_feats.type())).long() 648 | vis_word_embed = self.vis_embed(vis_word) 649 | assert(vis_word_embed.size(0) == self.detect_size+1) 650 | 651 | p_vis_word_embed = vis_word_embed.view(1, self.detect_size+1, self.vis_encoding_size) \ 652 | .expand(batch_size, self.detect_size+1, self.vis_encoding_size).contiguous() 653 | 654 | if hasattr(self, 'vis_classifiers_bias'): 655 | bias = self.vis_classifiers_bias.type(p_vis_word_embed.type()) \ 656 | .view(1,-1,1).expand(p_vis_word_embed.size(0), \ 657 | p_vis_word_embed.size(1), g_pool_feats.size(1)) 658 | else: 659 | bias = None 660 | 661 | sim_mat_static = self._grounder(p_vis_word_embed, g_pool_feats, pnt_mask[:,1:], bias) 662 | sim_mat_static_update = sim_mat_static 663 | sim_mat_static = F.softmax(sim_mat_static, dim=1) 664 | 665 | if not self.enable_BUTD: 666 | loc_input = ppls.data.new(batch_size, rois_num, 5) 667 | loc_input[:,:,:4] = ppls.data[:,:,:4] / 720. 668 | loc_input[:,:,4] = ppls.data[:,:,4]*1./self.num_sampled_frm 669 | loc_feats = self.loc_fc(Variable(loc_input)) # encode the locations 670 | 671 | label_feat = sim_mat_static.permute(0,2,1).contiguous() 672 | 673 | pool_feats = torch.cat((F.layer_norm(pool_feats, [pool_feats.size(-1)]), F.layer_norm(loc_feats, [loc_feats.size(-1)]), \ 674 | F.layer_norm(label_feat, [label_feat.size(-1)])), 2) 675 | 676 | # embed fc and att feats 677 | pool_feats = self.pool_embed(pool_feats) 678 | fc_feats = self.fc_embed(fc_feats) 679 | # object region interactions 680 | if hasattr(self, 'obj_interact'): 681 | pool_feats = self.obj_interact(pool_feats) 682 | 683 | # Project the attention feats first to reduce memory and computation comsumptions. 
684 | p_pool_feats = self.ctx2pool(pool_feats) 685 | 686 | if self.att_input_mode in ('both', 'featmap'): 687 | conv_feats_splits = torch.split(conv_feats, 2048, 2) 688 | conv_feats = torch.cat([m(c) for (m,c) in zip(self.att_embed, conv_feats_splits)], dim=2) 689 | conv_feats = conv_feats.permute(0,2,1).contiguous() # inconsistency between Torch TempConv and PyTorch Conv1d 690 | conv_feats = self.att_embed_aux(conv_feats) 691 | conv_feats = conv_feats.permute(0,2,1).contiguous() # inconsistency between Torch TempConv and PyTorch Conv1d 692 | conv_feats = self.context_enc(conv_feats)[0] 693 | 694 | conv_feats = conv_feats.masked_fill(sample_idx_mask, 0) 695 | p_conv_feats = self.ctx2att(conv_feats) 696 | else: 697 | conv_feats = pool_feats.new(1,1).fill_(0) 698 | p_conv_feats = pool_feats.new(1,1).fill_(0) 699 | 700 | vis_offset = (torch.arange(0, beam_size)*rois_num).view(beam_size).type_as(ppls.data).long() 701 | roi_offset = (torch.arange(0, beam_size)*(rois_num+1)).view(beam_size).type_as(ppls.data).long() 702 | 703 | seq = ppls.data.new(self.seq_length, batch_size).zero_().long() 704 | seqLogprobs = ppls.data.new(self.seq_length, batch_size).float() 705 | att2 = ppls.data.new(self.seq_length, batch_size).fill_(-1).long() 706 | 707 | self.done_beams = [[] for _ in range(batch_size)] 708 | for k in range(batch_size): 709 | state = self.init_hidden(beam_size) 710 | beam_fc_feats = fc_feats[k:k+1].expand(beam_size, fc_feats.size(1)) 711 | beam_pool_feats = pool_feats[k:k+1].expand(beam_size, rois_num, self.rnn_size).contiguous() 712 | if self.att_input_mode in ('both', 'featmap'): 713 | beam_conv_feats = conv_feats[k:k+1].expand(beam_size, conv_feats.size(1), self.rnn_size).contiguous() 714 | beam_p_conv_feats = p_conv_feats[k:k+1].expand(beam_size, conv_feats.size(1), self.att_hid_size).contiguous() 715 | else: 716 | beam_conv_feats = beam_pool_feats.new(1,1).fill_(0) 717 | beam_p_conv_feats = beam_pool_feats.new(1,1).fill_(0) 718 | beam_p_pool_feats = p_pool_feats[k:k+1].expand(beam_size, rois_num, self.att_hid_size).contiguous() 719 | 720 | beam_ppls = ppls[k:k+1].expand(beam_size, rois_num, 7).contiguous() 721 | beam_pnt_mask = pnt_mask[k:k+1].expand(beam_size, rois_num+1).contiguous() 722 | 723 | it = fc_feats.data.new(beam_size).long().zero_() 724 | xt = self.embed(Variable(it)) 725 | 726 | beam_sim_mat_static_update = sim_mat_static_update[k:k+1].expand(beam_size, self.detect_size+1, rois_num) 727 | 728 | rnn_output, state, att2_weight, att_h, _, _ = self.core(xt, beam_fc_feats, beam_conv_feats, 729 | beam_p_conv_feats, beam_pool_feats, beam_p_pool_feats, beam_pnt_mask, beam_pnt_mask, 730 | state, beam_sim_mat_static_update) 731 | 732 | assert(att2_weight.size(0) == beam_size) 733 | att2[0, k] = torch.max(att2_weight, 1)[1][0] 734 | 735 | self.done_beams[k] = self.beam_search(state, rnn_output, beam_fc_feats, beam_conv_feats, beam_p_conv_feats, \ 736 | beam_pool_feats, beam_p_pool_feats, beam_sim_mat_static_update, beam_ppls, beam_pnt_mask, vis_offset, roi_offset, opt) 737 | 738 | seq[:, k] = self.done_beams[k][0]['seq'].cuda() # the first beam has highest cumulative score 739 | seqLogprobs[:, k] = self.done_beams[k][0]['logps'].cuda() 740 | att2[1:, k] = self.done_beams[k][0]['att2'][1:].cuda() 741 | 742 | return seq.t(), seqLogprobs.t(), att2.t() 743 | -------------------------------------------------------------------------------- /misc/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. 
and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | # Originally from https://github.com/salesforce/densecap 9 | """ 10 | Copyright (c) 2018, salesforce.com, inc. 11 | All rights reserved. 12 | SPDX-License-Identifier: BSD-3-Clause 13 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 14 | """ 15 | # Last modified by Luowei Zhou on 12/27/2018 16 | 17 | import torch 18 | from torch import nn 19 | from torch.nn import functional as F 20 | from torch.autograd import Variable 21 | 22 | import random 23 | import string 24 | import sys 25 | import math 26 | import uuid 27 | import numpy as np 28 | 29 | INF = 1e10 30 | 31 | def positional_encodings_like(x, t=None): 32 | if t is None: 33 | positions = torch.arange(0, x.size(1)) 34 | if x.is_cuda: 35 | positions = positions.cuda(x.get_device()) 36 | else: 37 | positions = t 38 | encodings = torch.zeros(*x.size()[1:]) 39 | if x.is_cuda: 40 | encodings = encodings.cuda(x.get_device()) 41 | 42 | 43 | for channel in range(x.size(-1)): 44 | if channel % 2 == 0: 45 | encodings[:, channel] = torch.sin( 46 | positions / 10000 ** (channel / x.size(2))) 47 | else: 48 | encodings[:, channel] = torch.cos( 49 | positions / 10000 ** ((channel - 1) / x.size(2))) 50 | return Variable(encodings) 51 | 52 | def mask(targets, out): 53 | mask = (targets != 0) 54 | out_mask = mask.unsqueeze(-1).expand_as(out) 55 | return targets[mask], out[out_mask].view(-1, out.size(-1)) 56 | 57 | # torch.matmul can't do (4, 3, 2) @ (4, 2) -> (4, 3) 58 | # not exactly true but keep it for legacy reason 59 | def matmul(x, y): 60 | if x.dim() == y.dim(): 61 | return torch.matmul(x, y) 62 | if x.dim() == y.dim() - 1: 63 | return torch.matmul(x.unsqueeze(-2), y).squeeze(-2) 64 | return torch.matmul(x, y.unsqueeze(-2)).squeeze(-2) 65 | 66 | class LayerNorm(nn.Module): 67 | 68 | def __init__(self, d_model, eps=1e-6): 69 | super(LayerNorm, self).__init__() 70 | self.gamma = nn.Parameter(torch.ones(d_model)) 71 | self.beta = nn.Parameter(torch.zeros(d_model)) 72 | self.eps = eps 73 | 74 | def forward(self, x): 75 | mean = x.mean(-1, keepdim=True) 76 | std = x.std(-1, keepdim=True) 77 | return self.gamma * (x - mean) / (std + self.eps) + self.beta 78 | 79 | class ResidualBlock(nn.Module): 80 | 81 | def __init__(self, layer, d_model, drop_ratio): 82 | super(ResidualBlock, self).__init__() 83 | self.layer = layer 84 | self.dropout = nn.Dropout(drop_ratio) 85 | self.layernorm = LayerNorm(d_model) 86 | 87 | def forward(self, *x): 88 | return self.layernorm(x[0] + self.dropout(self.layer(*x))) 89 | 90 | class Attention(nn.Module): 91 | 92 | def __init__(self, d_key, drop_ratio, causal): 93 | super(Attention, self).__init__() 94 | self.scale = math.sqrt(d_key) 95 | self.dropout = nn.Dropout(drop_ratio) 96 | self.causal = causal 97 | 98 | def forward(self, query, key, value): 99 | dot_products = matmul(query, key.transpose(1, 2)) 100 | if query.dim() == 3 and (self is None or self.causal): 101 | tri = torch.ones(key.size(1), key.size(1)).triu(1) * INF 102 | if key.is_cuda: 103 | tri = tri.cuda(key.get_device()) 104 | dot_products.data.sub_(tri.unsqueeze(0)) 105 | return matmul(self.dropout(F.softmax(dot_products / self.scale, dim=-1)), value) 106 | 107 | class MultiHead(nn.Module): 108 | 109 | def __init__(self, d_key, d_value, n_heads, drop_ratio, causal=False): 110 | super(MultiHead, self).__init__() 
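        # [Editor's note, not part of the original source] This MultiHead splits the
        # projected query/key/value into n_heads chunks along the feature dimension,
        # runs the single shared scaled dot-product Attention defined above on each
        # chunk, concatenates the per-head outputs, and maps them back with wo. Note
        # that the shared Attention scales by sqrt(d_key) of the full model width
        # rather than the per-head width. A hedged usage sketch (hypothetical sizes):
        #     mh = MultiHead(d_key=512, d_value=512, n_heads=8, drop_ratio=0.1)
        #     out = mh(q, k, v)   # q: (B, Lq, 512), k/v: (B, Lk, 512) -> (B, Lq, 512)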
111 | self.attention = Attention(d_key, drop_ratio, causal=causal) 112 | self.wq = nn.Linear(d_key, d_key, bias=False) 113 | self.wk = nn.Linear(d_key, d_key, bias=False) 114 | self.wv = nn.Linear(d_value, d_value, bias=False) 115 | self.wo = nn.Linear(d_value, d_key, bias=False) 116 | self.n_heads = n_heads 117 | 118 | def forward(self, query, key, value): 119 | query, key, value = self.wq(query), self.wk(key), self.wv(value) 120 | query, key, value = ( 121 | x.chunk(self.n_heads, -1) for x in (query, key, value)) 122 | return self.wo(torch.cat([self.attention(q, k, v) 123 | for q, k, v in zip(query, key, value)], -1)) 124 | 125 | class FeedForward(nn.Module): 126 | 127 | def __init__(self, d_model, d_hidden): 128 | super(FeedForward, self).__init__() 129 | self.linear1 = nn.Linear(d_model, d_hidden) 130 | self.linear2 = nn.Linear(d_hidden, d_model) 131 | 132 | def forward(self, x): 133 | return self.linear2(F.relu(self.linear1(x))) 134 | 135 | class EncoderLayer(nn.Module): 136 | 137 | def __init__(self, d_model, d_hidden, n_heads, drop_ratio): 138 | super(EncoderLayer, self).__init__() 139 | self.selfattn = ResidualBlock( 140 | MultiHead(d_model, d_model, n_heads, drop_ratio), 141 | d_model, drop_ratio) 142 | self.feedforward = ResidualBlock(FeedForward(d_model, d_hidden), 143 | d_model, drop_ratio) 144 | 145 | def forward(self, x): 146 | return self.feedforward(self.selfattn(x, x, x)) 147 | 148 | class DecoderLayer(nn.Module): 149 | 150 | def __init__(self, d_model, d_hidden, n_heads, drop_ratio): 151 | super(DecoderLayer, self).__init__() 152 | self.selfattn = ResidualBlock( 153 | MultiHead(d_model, d_model, n_heads, drop_ratio, causal=True), 154 | d_model, drop_ratio) 155 | self.attention = ResidualBlock( 156 | MultiHead(d_model, d_model, n_heads, drop_ratio), 157 | d_model, drop_ratio) 158 | self.feedforward = ResidualBlock(FeedForward(d_model, d_hidden), 159 | d_model, drop_ratio) 160 | 161 | def forward(self, x, encoding): 162 | x = self.selfattn(x, x, x) 163 | return self.feedforward(self.attention(x, encoding, encoding)) 164 | 165 | class Encoder(nn.Module): 166 | 167 | def __init__(self, d_model, d_hidden, n_vocab, n_layers, n_heads, 168 | drop_ratio, pe): 169 | super(Encoder, self).__init__() 170 | # self.linear = nn.Linear(d_model*2, d_model) 171 | self.layers = nn.ModuleList( 172 | [EncoderLayer(d_model, d_hidden, n_heads, drop_ratio) 173 | for i in range(n_layers)]) 174 | self.dropout = nn.Dropout(drop_ratio) 175 | self.pe = pe 176 | 177 | def forward(self, x, mask=None): 178 | # x = self.linear(x) 179 | if self.pe: 180 | x = x+positional_encodings_like(x) # spatial configuration is already encoded 181 | # x = self.dropout(x) # dropout is already in the pool_embed layer 182 | if mask is not None: 183 | x = x*mask 184 | encoding = [] 185 | for layer in self.layers: 186 | x = layer(x) 187 | if mask is not None: 188 | x = x*mask 189 | encoding.append(x) 190 | return encoding 191 | 192 | class Decoder(nn.Module): 193 | 194 | def __init__(self, d_model, d_hidden, vocab_size, n_layers, n_heads, 195 | drop_ratio): 196 | super(Decoder, self).__init__() 197 | self.layers = nn.ModuleList( 198 | [DecoderLayer(d_model, d_hidden, n_heads, drop_ratio) 199 | for i in range(n_layers)]) 200 | self.out = nn.Linear(d_model, vocab_size) 201 | self.dropout = nn.Dropout(drop_ratio) 202 | self.d_model = d_model 203 | # self.vocab = vocab 204 | self.d_out = vocab_size 205 | 206 | def forward(self, x, encoding): 207 | x = F.embedding(x, self.out.weight * math.sqrt(self.d_model)) 208 | x = 
x+positional_encodings_like(x) 209 | x = self.dropout(x) 210 | for layer, enc in zip(self.layers, encoding): 211 | x = layer(x, enc) 212 | return x 213 | 214 | def greedy(self, encoding, T): 215 | B, _, H = encoding[0].size() 216 | # change T to 20, max # of words in a sentence 217 | # T = 40 218 | # T *= 2 219 | prediction = Variable(encoding[0].data.new(B, T).long().fill_( 220 | 0)) 221 | hiddens = [Variable(encoding[0].data.new(B, T, H).zero_()) 222 | for l in range(len(self.layers) + 1)] 223 | embedW = self.out.weight * math.sqrt(self.d_model) 224 | hiddens[0] = hiddens[0] + positional_encodings_like(hiddens[0]) 225 | for t in range(T): 226 | if t == 0: 227 | hiddens[0][:, t] = hiddens[0][:, t] + F.embedding(Variable( 228 | encoding[0].data.new(B).long().fill_( 229 | 0)), embedW) 230 | else: 231 | hiddens[0][:, t] = hiddens[0][:, t] + F.embedding(prediction[:, t - 1], 232 | embedW) 233 | hiddens[0][:, t] = self.dropout(hiddens[0][:, t]) 234 | for l in range(len(self.layers)): 235 | x = hiddens[l][:, :t + 1] 236 | x = self.layers[l].selfattn(hiddens[l][:, t], x, x) 237 | hiddens[l + 1][:, t] = self.layers[l].feedforward( 238 | self.layers[l].attention(x, encoding[l], encoding[l])) 239 | 240 | _, prediction[:, t] = self.out(hiddens[-1][:, t]).max(-1) 241 | return prediction 242 | 243 | 244 | class Transformer(nn.Module): 245 | 246 | def __init__(self, d_model, n_vocab_src, vocab_trg, d_hidden=2048, 247 | n_layers=6, n_heads=8, drop_ratio=0.1, pe=False): 248 | super(Transformer, self).__init__() 249 | self.encoder = Encoder(d_model, d_hidden, n_vocab_src, n_layers, 250 | n_heads, drop_ratio, pe) 251 | 252 | def forward(self, x): 253 | encoding = self.encoder(x) 254 | return encoding[-1] 255 | # return encoding[-1], encoding 256 | # return torch.cat(encoding, 2) 257 | 258 | def all_outputs(self, x): 259 | encoding = self.encoder(x) 260 | return encoding 261 | 262 | class TransformerDecoder(nn.Module): 263 | 264 | def __init__(self, d_model, n_vocab_src, vocab_trg, d_hidden=2048, 265 | n_layers=2, n_heads=6, drop_ratio=0.2): 266 | super(TransformerDecoder, self).__init__() 267 | self.decoder = Decoder(d_model, d_hidden, vocab_trg, n_layers, 268 | n_heads, drop_ratio) 269 | self.n_layers = n_layers 270 | 271 | def forward(self, encoding, s, ss_ratio=1, infer=False, seq_length=20): 272 | if infer: 273 | greedy = self.decoder.greedy(encoding, seq_length) 274 | return greedy 275 | 276 | out = self.decoder(s[:, :-1].contiguous(), encoding) 277 | targets, out = mask(s[:, 1:].contiguous(), out) 278 | logits = self.decoder.out(out) 279 | assert ss_ratio == 1, 'scheduled sampling does not work under pytorch 0.4' # TODO, ss_ratio<1 triggered gradient issues 280 | return F.cross_entropy(logits, targets) 281 | -------------------------------------------------------------------------------- /misc/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import collections 13 | import torch 14 | import torch.nn as nn 15 | from torch.autograd import Variable 16 | import numpy as np 17 | import pdb 18 | import os 19 | import json 20 | from misc.bbox_transform import bbox_overlaps_batch 21 | import numbers 22 | import random 23 | import math 24 | from PIL import Image, ImageOps, ImageEnhance 25 | try: 26 | import accimage 27 | except ImportError: 28 | accimage = None 29 | import types 30 | import warnings 31 | import torch.nn.functional as F 32 | import sys 33 | import matplotlib.pyplot as plt 34 | import matplotlib.patches as patches 35 | 36 | noc_object = ['bus', 'bottle', 'couch', 'microwave', 'pizza', 'racket', 'suitcase', 'zebra'] 37 | noc_index = [6, 40, 58, 69, 54, 39, 29, 23] 38 | 39 | noc_word_map = {'bus':'car', 'bottle':'cup', 40 | 'couch':'chair', 'microwave':'oven', 41 | 'pizza': 'cake', 'tennis racket': 'baseball bat', 42 | 'suitcase': 'handbag', 'zebra': 'horse'} 43 | 44 | def _is_pil_image(img): 45 | if accimage is not None: 46 | return isinstance(img, (Image.Image, accimage.Image)) 47 | else: 48 | return isinstance(img, Image.Image) 49 | 50 | def update_values(dict_from, dict_to): 51 | for key, value in dict_from.items(): 52 | if isinstance(value, dict): 53 | update_values(dict_from[key], dict_to[key]) 54 | elif value is not None: 55 | dict_to[key] = dict_from[key] 56 | 57 | # Input: seq, N*D numpy array, with element 0 .. vocab_size. 0 is END token. 58 | """ 59 | def decode_sequence(itow, itod, ltow, itoc, wtod, seq, bn_seq, fg_seq, vocab_size, opt): 60 | N, D = seq.size() 61 | 62 | out = [] 63 | for i in range(N): 64 | txt = '' 65 | for j in range(D): 66 | if j >= 1: 67 | txt = txt + ' ' 68 | ix = seq[i,j] 69 | if ix > vocab_size: 70 | det_word = itod[fg_seq[i,j].item()] 71 | det_class = itoc[wtod[det_word]] 72 | if opt.decode_noc and det_class in noc_object: 73 | det_word = det_class 74 | 75 | if (bn_seq[i,j] == 1) and det_word in ltow: 76 | word = ltow[det_word] 77 | else: 78 | word = det_word 79 | # word = '[ ' + word + ' ]' 80 | else: 81 | if ix == 0: 82 | break 83 | else: 84 | word = itow[str(ix.item())] 85 | txt = txt + word 86 | out.append(txt) 87 | return out 88 | """ 89 | 90 | def decode_sequence(itow, itod, ltow, itoc, wtod, seq, vocab_size, opt): 91 | N, D = seq.size() 92 | 93 | out = [] 94 | for i in range(N): 95 | txt = '' 96 | for j in range(D): 97 | if j >= 1: 98 | txt = txt + ' ' 99 | ix = seq[i,j] 100 | if ix == 0: 101 | break 102 | else: 103 | word = itow[str(ix.item())] 104 | txt = txt + word 105 | out.append(txt) 106 | return out 107 | 108 | 109 | def repackage_hidden(h, batch_size): 110 | """Wraps hidden states in new Variables, to detach them from their history.""" 111 | if type(h) == Variable: 112 | return Variable(h.data.resize_(h.size(0), batch_size, h.size(2)).zero_()) 113 | else: 114 | return tuple(repackage_hidden(v, batch_size) for v in h) 115 | 116 | 117 | class LMCriterion(nn.Module): 118 | def __init__(self, opt): 119 | super(LMCriterion, self).__init__() 120 | self.vocab_size = opt.vocab_size 121 | 122 | def forward(self, txt_input, att2_weights, ground_weights, target, att2_target, input_seq): 123 | 124 | # att2_weights and ground_weights have the same target 125 | assert(torch.sum(target >= self.vocab_size) == 0) 126 | txt_mask = target.data.gt(0) # generate the mask 127 | txt_mask = torch.cat([txt_mask.new(txt_mask.size(0), 1).fill_(1), txt_mask[:, 
:-1]], 1) 128 | 129 | target = target.view(-1,1) 130 | txt_select = torch.gather(txt_input, 1, target) 131 | if isinstance(txt_input, Variable): 132 | txt_mask = Variable(txt_mask) 133 | txt_out = - torch.masked_select(txt_select, txt_mask.view(-1,1)) 134 | 135 | assert(txt_out.size(0) == torch.sum(txt_mask.data)) 136 | loss = torch.mean(txt_out) 137 | 138 | # attention loss 139 | att2_loss = -torch.mean(torch.masked_select(F.log_softmax(att2_weights, dim=2), att2_target.byte())) 140 | 141 | # grounding loss 142 | ground_loss = -torch.mean(torch.masked_select(F.log_softmax(ground_weights, dim=2), att2_target.byte())) 143 | 144 | # matching loss 145 | vis_mask = (input_seq > self.vocab_size) 146 | vis_mask = vis_mask.unsqueeze(2).expand_as(att2_weights) 147 | 148 | # match_loss = F.kl_div( 149 | # torch.masked_select(F.log_softmax(att2_weights, dim=2), vis_mask), 150 | # torch.masked_select(F.softmax(Variable(ground_weights.data), dim=2), vis_mask)) 151 | 152 | return loss, att2_loss, ground_loss 153 | 154 | 155 | def set_lr(optimizer, decay_factor): 156 | for group in optimizer.param_groups: 157 | group['lr'] = group['lr'] * decay_factor 158 | 159 | 160 | def crop(img, i, j, h, w): 161 | """Crop the given PIL Image. 162 | Args: 163 | img (PIL Image): Image to be cropped. 164 | i: Upper pixel coordinate. 165 | j: Left pixel coordinate. 166 | h: Height of the cropped image. 167 | w: Width of the cropped image. 168 | Returns: 169 | PIL Image: Cropped image. 170 | """ 171 | if not _is_pil_image(img): 172 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 173 | 174 | return img.crop((j, i, j + w, i + h)) 175 | 176 | 177 | def pad(img, padding, fill=0): 178 | """Pad the given PIL Image on all sides with the given "pad" value. 179 | Args: 180 | img (PIL Image): Image to be padded. 181 | padding (int or tuple): Padding on each border. If a single int is provided this 182 | is used to pad all borders. If tuple of length 2 is provided this is the padding 183 | on left/right and top/bottom respectively. If a tuple of length 4 is provided 184 | this is the padding for the left, top, right and bottom borders 185 | respectively. 186 | fill: Pixel fill value. Default is 0. If a tuple of 187 | length 3, it is used to fill R, G, B channels respectively. 188 | Returns: 189 | PIL Image: Padded image. 190 | """ 191 | if not _is_pil_image(img): 192 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 193 | 194 | if not isinstance(padding, (numbers.Number, tuple)): 195 | raise TypeError('Got inappropriate padding arg') 196 | if not isinstance(fill, (numbers.Number, str, tuple)): 197 | raise TypeError('Got inappropriate fill arg') 198 | 199 | if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]: 200 | raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " + 201 | "{} element tuple".format(len(padding))) 202 | 203 | return ImageOps.expand(img, border=padding, fill=fill) 204 | 205 | 206 | class RandomCropWithBbox(object): 207 | """Crop the given PIL Image at a random location. 208 | Args: 209 | size (sequence or int): Desired output size of the crop. If size is an 210 | int instead of sequence like (h, w), a square crop (size, size) is 211 | made. 212 | padding (int or sequence, optional): Optional padding on each border 213 | of the image. Default is 0, i.e no padding. If a sequence of length 214 | 4 is provided, it is used to pad left, top, right, bottom borders 215 | respectively. 
216 | """ 217 | 218 | def __init__(self, size, padding=0): 219 | if isinstance(size, numbers.Number): 220 | self.size = (int(size), int(size)) 221 | else: 222 | self.size = size 223 | self.padding = padding 224 | 225 | @staticmethod 226 | def get_params(img, output_size): 227 | """Get parameters for ``crop`` for a random crop. 228 | Args: 229 | img (PIL Image): Image to be cropped. 230 | output_size (tuple): Expected output size of the crop. 231 | Returns: 232 | tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. 233 | """ 234 | w, h = img.size 235 | th, tw = output_size 236 | if w == tw and h == th: 237 | return 0, 0, h, w 238 | 239 | i = random.randint(0, h - th) 240 | j = random.randint(0, w - tw) 241 | return i, j, th, tw 242 | 243 | def __call__(self, img, proposals, bboxs): 244 | """ 245 | Args: 246 | img (PIL Image): Image to be cropped. 247 | proposals, bboxs: proposals and bboxs to be cropped. 248 | Returns: 249 | PIL Image: Cropped image. 250 | """ 251 | if self.padding > 0: 252 | img = pad(img, self.padding) 253 | 254 | i, j, h, w = self.get_params(img, self.size) 255 | 256 | proposals[:,1] = proposals[:,1] - i 257 | proposals[:,3] = proposals[:,3] - i 258 | proposals[:, 1] = np.clip(proposals[:, 1], 0, h - 1) 259 | proposals[:, 3] = np.clip(proposals[:, 3], 0, h - 1) 260 | 261 | proposals[:,0] = proposals[:,0] - j 262 | proposals[:,2] = proposals[:,2] - j 263 | proposals[:, 0] = np.clip(proposals[:, 0], 0, w - 1) 264 | proposals[:, 2] = np.clip(proposals[:, 2], 0, w - 1) 265 | 266 | bboxs[:,1] = bboxs[:,1] - i 267 | bboxs[:,3] = bboxs[:,3] - i 268 | bboxs[:, 1] = np.clip(bboxs[:, 1], 0, h - 1) 269 | bboxs[:, 3] = np.clip(bboxs[:, 3], 0, h - 1) 270 | 271 | bboxs[:,0] = bboxs[:,0] - j 272 | bboxs[:,2] = bboxs[:,2] - j 273 | bboxs[:, 0] = np.clip(bboxs[:, 0], 0, w - 1) 274 | bboxs[:, 2] = np.clip(bboxs[:, 2], 0, w - 1) 275 | 276 | return crop(img, i, j, h, w), proposals, bboxs 277 | 278 | def resize_bbox(bbox, width, height, rwidth, rheight): 279 | """ 280 | resize the bbox from height width to rheight rwidth 281 | bbox: x,y,width, height. 282 | """ 283 | width_ratio = rwidth / float(width) 284 | height_ratio = rheight / float(height) 285 | 286 | bbox[:,0] = bbox[:,0] * width_ratio 287 | bbox[:,2] = bbox[:,2] * width_ratio 288 | bbox[:,1] = bbox[:,1] * height_ratio 289 | bbox[:,3] = bbox[:,3] * height_ratio 290 | 291 | return bbox 292 | 293 | def bbox_overlaps(rois, gt_box, frm_mask): 294 | 295 | overlaps = bbox_overlaps_batch(rois[:,:,:5], gt_box[:,:,:5], frm_mask) 296 | 297 | return overlaps 298 | 299 | def sim_mat_target(overlaps, pad_gt_bboxs): 300 | # overlaps: B, num_rois, num_box 301 | # pad_gt_bboxs: B, num_box (class labels) 302 | B, num_rois, num_box = overlaps.size() 303 | assert(num_box == pad_gt_bboxs.size(1)) 304 | masked_labels = (overlaps > 0.5).long() * pad_gt_bboxs.view(B, 1, num_box).long() # could try a higher threshold 305 | return masked_labels.permute(0,2,1).contiguous() 306 | 307 | def bbox_target(mask, overlaps, seq, seq_update, vocab_size): 308 | 309 | mask = mask.data.contiguous() 310 | overlaps_copy = overlaps.clone() 311 | 312 | max_rois = overlaps.size(1) 313 | batch_size = overlaps.size(0) 314 | 315 | overlaps_copy.masked_fill_(mask.view(batch_size, 1, -1).expand_as(overlaps_copy), 0) 316 | max_overlaps, gt_assignment = torch.max(overlaps_copy, 2) 317 | 318 | # get the labels. 
319 | labels = (max_overlaps > 0.5).float() 320 | no_proposal_idx = (labels.sum(1) > 0) != (seq.data[:,2] > 0) 321 | 322 | # (deprecated) convert vis word to text word if there is not matched proposal 323 | if no_proposal_idx.sum() > 0: 324 | seq_update[:,0][no_proposal_idx] = seq_update[:,3][no_proposal_idx] 325 | seq_update[:,1][no_proposal_idx] = 0 326 | seq_update[:,2][no_proposal_idx] = 0 327 | 328 | return labels 329 | 330 | def _affine_grid_gen(rois, input_size, grid_size): 331 | 332 | rois = rois.detach() 333 | x1 = rois[:, 1::4] / 16.0 334 | y1 = rois[:, 2::4] / 16.0 335 | x2 = rois[:, 3::4] / 16.0 336 | y2 = rois[:, 4::4] / 16.0 337 | 338 | height = input_size[0] 339 | width = input_size[1] 340 | 341 | zero = Variable(rois.data.new(rois.size(0), 1).zero_()) 342 | theta = torch.cat([\ 343 | (x2 - x1) / (width - 1), 344 | zero, 345 | (x1 + x2 - width + 1) / (width - 1), 346 | zero, 347 | (y2 - y1) / (height - 1), 348 | (y1 + y2 - height + 1) / (height - 1)], 1).view(-1, 2, 3) 349 | 350 | grid = F.affine_grid(theta, torch.Size((rois.size(0), 1, grid_size, grid_size))) 351 | 352 | return grid 353 | 354 | 355 | def _jitter_boxes(gt_boxes, jitter=0.05): 356 | """ 357 | """ 358 | jittered_boxes = gt_boxes.copy() 359 | ws = jittered_boxes[:, 2] - jittered_boxes[:, 0] + 1.0 360 | hs = jittered_boxes[:, 3] - jittered_boxes[:, 1] + 1.0 361 | width_offset = (np.random.rand(jittered_boxes.shape[0]) - 0.5) * jitter * ws 362 | height_offset = (np.random.rand(jittered_boxes.shape[0]) - 0.5) * jitter * hs 363 | jittered_boxes[:, 0] += width_offset 364 | jittered_boxes[:, 2] += width_offset 365 | jittered_boxes[:, 1] += height_offset 366 | jittered_boxes[:, 3] += height_offset 367 | 368 | return jittered_boxes 369 | 370 | 371 | color_pad = ['red', 'green', 'blue', 'cyan', 'brown', 'orange'] 372 | 373 | def vis_detections(ax, class_name, dets, color_i, rest_flag=0): 374 | """Visual debugging of detections.""" 375 | bbox = tuple(int(np.round(x)) for x in dets[:4]) 376 | score = dets[-1] 377 | 378 | if rest_flag == 0: 379 | ax.add_patch( 380 | patches.Rectangle( 381 | (bbox[0], bbox[1]), 382 | bbox[2]-bbox[0], 383 | bbox[3]-bbox[1], 384 | fill=False, # remove background 385 | lw=3, 386 | color=color_pad[color_i] 387 | ) 388 | ) 389 | 390 | ax.text(bbox[0]+5, bbox[1] + 13, '%s' % (class_name) 391 | , fontsize=9, fontweight='bold', backgroundcolor=color_pad[color_i]) 392 | else: 393 | ax.add_patch( 394 | patches.Rectangle( 395 | (bbox[0], bbox[1]), 396 | bbox[2]-bbox[0], 397 | bbox[3]-bbox[1], 398 | fill=False, # remove background 399 | lw=2, 400 | color='grey' 401 | ) 402 | ) 403 | ax.text(bbox[0]+5, bbox[1] + 13, '%s' % (class_name) 404 | , fontsize=9, fontweight='bold', backgroundcolor='grey') 405 | return ax 406 | 407 | import operator as op 408 | 409 | 410 | import itertools 411 | def cbs_beam_tag(num): 412 | tags = [] 413 | for i in range(num+1): 414 | for tag in itertools.combinations(range(num), i): 415 | tags.append(tag) 416 | return len(tags), tags 417 | 418 | def cmpSet(t1, t2): 419 | return sorted(t1) == sorted(t2) 420 | 421 | def containSet(List1, t1): 422 | # List1: return the index that contain 423 | # t1: tupple we want to match 424 | 425 | if t1 == tuple([]): 426 | return [tag for tag in List1 if len(tag) <= 1] 427 | else: 428 | List = [] 429 | for t in List1: 430 | flag = True 431 | for tag in t1: 432 | if tag not in t: 433 | flag = False 434 | 435 | if flag == True and len(t) <= len(t1)+1: 436 | List.append(t) 437 | return List 438 | 
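As a quick illustration of how the joint image-and-box transforms above fit together, here is a minimal sketch that chains RandomCropWithBbox and resize_bbox on a dummy frame. The toy image size, box coordinates, and the 224x224 crop are made up for illustration and are not taken from the training pipeline; the sketch assumes it is run from the repository root with the project dependencies installed.

import numpy as np
from PIL import Image

from misc.utils import RandomCropWithBbox, resize_bbox

# toy 400x300 RGB frame with one proposal box and one ground-truth box (columns: x1, y1, x2, y2)
img = Image.new('RGB', (400, 300))
proposals = np.array([[50., 40., 200., 180.]])
gt_boxes = np.array([[60., 50., 190., 170.]])

# take a random 224x224 crop; both box arrays are shifted by the crop offset and clipped so
# they stay inside the cropped frame
transform = RandomCropWithBbox(224)
img, proposals, gt_boxes = transform(img, proposals, gt_boxes)

# if the cropped frame is later resized, e.g. to 300x300, rescale the box coordinates as well
gt_boxes = resize_bbox(gt_boxes, 224, 224, 300, 300)
print(img.size, proposals, gt_boxes)

RandomCropWithBbox(224) builds a square (224, 224) crop, and resize_bbox simply scales columns 0/2 by the width ratio and columns 1/3 by the height ratio.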
-------------------------------------------------------------------------------- /opts.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import argparse 9 | 10 | def parse_opt(): 11 | parser = argparse.ArgumentParser() 12 | # Data input settings 13 | parser.add_argument('--path_opt', type=str, default='cfgs/anet_res101_vg_feat_10x100prop.yml', 14 | help='') 15 | parser.add_argument('--dataset', type=str, default='anet', 16 | help='') 17 | parser.add_argument('--input_json', type=str, default='', 18 | help='path to the json file containing additional info and vocab') 19 | parser.add_argument('--input_dic', type=str, default='', 20 | help='path to the json containing the preprocessed dataset') 21 | parser.add_argument('--image_path', type=str, default='', 22 | help='path to the h5file containing the image data') 23 | parser.add_argument('--proposal_h5', type=str, default='', 24 | help='path to the json containing the detection result.') 25 | parser.add_argument('--feature_root', type=str, default='', 26 | help='path to the npy flies containing region features') 27 | parser.add_argument('--seg_feature_root', type=str, default='', 28 | help='path to the npy files containing frame-wise features') 29 | 30 | parser.add_argument('--num_workers', type=int, default=20, 31 | help='number of worker to load data') 32 | parser.add_argument('--cuda', action='store_true', 33 | help='whether use cuda') 34 | parser.add_argument('--mGPUs', action='store_true', 35 | help='whether use multiple GPUs') 36 | 37 | # Model settings 38 | parser.add_argument('--rnn_size', type=int, default=1024, 39 | help='size of the rnn in number of hidden nodes in each layer') 40 | parser.add_argument('--num_layers', type=int, default=1, 41 | help='number of layers in the RNN') 42 | parser.add_argument('--input_encoding_size', type=int, default=512, 43 | help='the encoding size of each token in the vocabulary, and the image.') 44 | parser.add_argument('--att_hid_size', type=int, default=512, 45 | help='the hidden size of the attention MLP; only useful in show_attend_tell; 0 if not using hidden layer') 46 | parser.add_argument('--fc_feat_size', type=int, default=3072, 47 | help='2048 for resnet, 4096 for vgg') 48 | parser.add_argument('--att_feat_size', type=int, default=2048, 49 | help='2048 for resnet, 512 for vgg') 50 | parser.add_argument('--t_attn_size', type=int, default=480, help='number of frames sampled for temopral attention') 51 | parser.add_argument('--num_sampled_frm', type=int, default=10) 52 | parser.add_argument('--num_prop_per_frm', type=int, default=100) 53 | parser.add_argument('--prop_thresh', type=float, default=0.2, 54 | help='threshold to filter out low-confidence proposals') 55 | 56 | parser.add_argument('--att_model', type=str, default='topdown', 57 | help='different attention model, now supporting topdown | transformer(unsupervised)') 58 | parser.add_argument('--att_input_mode', type=str, default='both', 59 | help='use whether featmap|region|dual_region|both in topdown language model') 60 | parser.add_argument('--t_attn_mode', type=str, default='bigru', 61 | help='temporal attention context encoding mode: bilstm | bigru') 62 | parser.add_argument('--transfer_mode', type=str, default='cls', help='knowledge transfer mode, could be cls|glove|both') 
63 | parser.add_argument('--region_attn_mode', type=str, default='mix', 64 | help='options: dp|add|cat|mix, dp stands for dot-product, add for additive, cat for concat, mix indicates dp for grd. and add for attn., mix_mul indicates dp for grd. and element-wise multiplication for attn.') 65 | 66 | parser.add_argument('--enable_BUTD', action='store_true', help='if enable, the region feature will not include location embedding nor class encoding') 67 | parser.add_argument('--obj_interact', action='store_true', help='self-attention encoding for region features') 68 | parser.add_argument('--exclude_bgd_det', action='store_true', help='exclude __background__ RoIs') 69 | 70 | parser.add_argument('--w_att2', type=float, default=0) 71 | parser.add_argument('--w_grd', type=float, default=0) 72 | parser.add_argument('--w_cls', type=float, default=0) 73 | parser.add_argument('--disable_caption', action='store_true', help='set to disable caption generation loss') 74 | 75 | # Optimization: General 76 | parser.add_argument('--max_epochs', type=int, default=40, 77 | help='number of epochs') 78 | parser.add_argument('--batch_size', type=int, default=10, 79 | help='minibatch size') 80 | parser.add_argument('--grad_clip', type=float, default=0.1, #5., 81 | help='clip gradients at this value') 82 | parser.add_argument('--drop_prob_lm', type=float, default=0.5, 83 | help='strength of dropout in the Language Model RNN') 84 | parser.add_argument('--seq_per_img', type=int, default=1, 85 | help='number of captions to sample for each image during training') 86 | parser.add_argument('--seq_length', type=int, default=20, help='') 87 | parser.add_argument('--beam_size', type=int, default=1, 88 | help='used when sample_max = 1, indicates number of beams in beam search. Usually 2 or 3 works well. More is not better. Set this to 1 for faster runtime but a bit worse performance.') 89 | 90 | # Optimization: for the Language Model 91 | parser.add_argument('--optim', type=str, default='adam', 92 | help='what update to use? rmsprop|sgd|sgdmom|adagrad|adam') 93 | parser.add_argument('--learning_rate', type=float, default=5e-4, 94 | help='learning rate') 95 | parser.add_argument('--learning_rate_decay_start', type=int, default=1, 96 | help='at what iteration to start decaying learning rate? (-1 = dont) (in epoch)') 97 | parser.add_argument('--learning_rate_decay_every', type=int, default=3, 98 | help='every how many iterations thereafter to drop LR?(in epoch)') 99 | parser.add_argument('--learning_rate_decay_rate', type=float, default=0.8, 100 | help='every how many iterations thereafter to drop LR?(in epoch)') 101 | parser.add_argument('--optim_alpha', type=float, default=0.9, 102 | help='alpha for adam') 103 | parser.add_argument('--optim_beta', type=float, default=0.999, 104 | help='beta used for adam') 105 | parser.add_argument('--optim_epsilon', type=float, default=1e-8, 106 | help='epsilon that goes into denominator for smoothing') 107 | parser.add_argument('--weight_decay', type=float, default=0, 108 | help='weight_decay') 109 | 110 | # set training session 111 | parser.add_argument('--start_from', type=str, default=None, 112 | help="""continue training from saved model at this path. Path must contain files saved by previous training process: 113 | 'infos.pkl' : configuration; 114 | 'checkpoint' : paths to model file(s) (created by tf). 
115 | Note: this file contains absolute paths, be careful when moving files around; 116 | 'model.ckpt-*' : file(s) with model definition (created by tf) 117 | """) 118 | parser.add_argument('--id', type=str, default='', 119 | help='an id identifying this run/job. used in cross-val and appended when writing progress files') 120 | 121 | # Evaluation/Checkpointing 122 | parser.add_argument('--train_split', type=str, default='training', 123 | help='') 124 | parser.add_argument('--val_split', type=str, default='validation', 125 | help='') 126 | parser.add_argument('--inference_only', action='store_true', 127 | help='') 128 | parser.add_argument('--densecap_references', type=str, nargs='+', default=['./data/anet/anet_entities_val_1.json', './data/anet/anet_entities_val_2.json'], 129 | help='reference files with ground truth captions to compare results against. delimited (,) str') 130 | parser.add_argument('--densecap_verbose', action='store_true', help='evaluate CIDEr only or all language metrics in densecap') 131 | parser.add_argument('--grd_reference', type=str, default='tools/anet_entities/data/anet_entities_cleaned_class_thresh50_trainval.json') 132 | parser.add_argument('--split_file', type=str, default='tools/anet_entities/data/split_ids_anet_entities.json') 133 | 134 | parser.add_argument('--eval_obj_grounding_gt', action='store_true', 135 | help='whether evaluate object grounding accuracy') 136 | parser.add_argument('--eval_obj_grounding', action='store_true', 137 | help='whether evaluate object grounding accuracy') 138 | parser.add_argument('--vis_attn', action='store_true', help='visualize attention') 139 | parser.add_argument('--enable_visdom', action='store_true') 140 | parser.add_argument('--visdom_server', type=str, default='', help='update it with your server url') 141 | 142 | parser.add_argument('--val_images_use', type=int, default=5000, 143 | help='how many segments to use when periodically evaluating the validation loss? (-1 = all)') 144 | parser.add_argument('--val_every_epoch', type=int, default=2, 145 | help='how many segments to use when periodically evaluating the validation loss? (-1 = all)') 146 | parser.add_argument('--checkpoint_path', type=str, default='save', 147 | help='directory to store checkpointed models') 148 | parser.add_argument('--language_eval', action='store_true', 149 | help='Evaluate language as well (1 = yes, 0 = no)?') 150 | parser.add_argument('--load_best_score', type=int, default=1, 151 | help='Do we load previous best score when resuming training.') 152 | parser.add_argument('--disp_interval', type=int, default=100, 153 | help='how many iteration to display an loss.') 154 | parser.add_argument('--losses_log_every', type=int, default=10, 155 | help='how many iteration for log.') 156 | parser.add_argument('--det_oracle', action='store_true', 157 | help='whether use oracle bounding box.') 158 | parser.add_argument('--frm_oracle', action='store_true', 159 | help='whether use oracle frame.') 160 | parser.add_argument('--seed', type=int, default=123) 161 | args = parser.parse_args() 162 | 163 | return args 164 | -------------------------------------------------------------------------------- /prepro/prepro_dic_anet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | # Last modified by Luowei Zhou on 09/25/2018 9 | 10 | import os 11 | import json 12 | import argparse 13 | from random import shuffle, seed 14 | import string 15 | import h5py 16 | import numpy as np 17 | import torch 18 | import torchvision.models as models 19 | from torch.autograd import Variable 20 | import pdb 21 | from stanfordcorenlp import StanfordCoreNLP 22 | from nltk.tokenize import word_tokenize 23 | 24 | nlp = StanfordCoreNLP('tools/stanford-corenlp-full-2018-02-27') 25 | props={'annotators': 'ssplit, tokenize, lemma','pipelineLanguage':'en', 'outputFormat':'json'} 26 | 27 | def build_vocab(vids, split, params): 28 | count_thr = params['word_count_threshold'] 29 | 30 | # count up the number of words 31 | # stats on sentence length distribution 32 | counts = {} 33 | sent_lengths = {} 34 | for vid_id, vid in vids.items(): 35 | if split[vid_id] in ('training', 'validation'): 36 | for seg_id, seg in vid['segments'].items(): 37 | nw = len(seg['tokens']) 38 | sent_lengths[nw] = sent_lengths.get(nw, 0) + 1 39 | for w in seg['tokens']: 40 | counts[w] = counts.get(w, 0)+1 41 | 42 | cw = sorted([(count,w) for w,count in counts.items()], reverse=True) 43 | print('top words and their counts:') 44 | print('\n'.join(map(str,cw[:20]))) 45 | 46 | print('The counts of empty token is {}'.format(counts[''])) 47 | counts[''] = 0 48 | # print some stats 49 | total_words = sum(counts.values()) 50 | print('total words:', total_words) 51 | bad_words = [w for w,n in counts.items() if n <= count_thr] 52 | vocab = [w for w,n in counts.items() if n > count_thr] 53 | bad_count = sum(counts[w] for w in bad_words) 54 | print('number of bad words: %d/%d = %.2f%%' % (len(bad_words), len(counts), len(bad_words)*100.0/len(counts))) 55 | print('number of words in vocab would be %d' % (len(vocab), )) 56 | print('number of UNKs: %d/%d = %.2f%%' % (bad_count, total_words, bad_count*100.0/total_words)) 57 | 58 | max_len = max(sent_lengths.keys()) 59 | print('max length sentence in raw data: ', max_len) 60 | print('sentence length distribution (count, number of words):') 61 | sum_len = sum(sent_lengths.values()) 62 | for i in range(max_len+1): 63 | print('%2d: %10d %f%%' % (i, sent_lengths.get(i,0), sent_lengths.get(i,0)*100.0/sum_len)) 64 | 65 | # lets now produce the final annotations 66 | if bad_count > 0: 67 | # additional special UNK token we will use below to map infrequent words to 68 | print('inserting the special UNK token') 69 | vocab.append('UNK') 70 | 71 | vids_new = {} 72 | for vid_id, vid in vids.items(): 73 | # if split[vid_id] in ('training', 'validation'): 74 | if vid_id in split: 75 | segs_new = {} 76 | for seg_id, seg in vid['segments'].items(): 77 | txt = seg['tokens'] 78 | clss = seg['process_clss'] 79 | bbox = seg['process_bnd_box'] 80 | idx = seg['process_idx'] 81 | frm_idx = seg['frame_ind'] 82 | caption = [w if counts.get(w,0) > count_thr else 'UNK' for w in txt] 83 | segs_new[seg_id] = {'caption':caption, 'clss':clss, 'bbox':bbox, 'idx':idx, 'frm_idx':frm_idx} 84 | vids_new[vid_id] = {} 85 | vids_new[vid_id]['segments'] = segs_new 86 | # vids_new[vid_id]['rwidth'] = vid['rwidth'] 87 | # vids_new[vid_id]['rheight'] = vid['rheight'] 88 | 89 | return vocab, vids_new 90 | 91 | def main(params): 92 | 93 | imgs_split = json.load(open(params['split_file'], 'r')) 94 | split = {} 95 | for s, ids in imgs_split.items(): 96 | for i in ids: 97 | split[i] = s # video names are the ids 98 | 99 | vids_processed = json.load(open(params['input_json'], 'r')) 100 | 101 | # word to detection label 
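    # 'vocab' holds the detection class names; wtod maps each class name to its index and is
    # written out below as out['wtod'], so that downstream code can tie caption words back to
    # detection classes.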
102 | anet_class_all = vids_processed['vocab'] 103 | wtod = {} 104 | for i in range(len(anet_class_all)): 105 | # TODO, assume each object class is a 1-gram, can try n-gram or multiple phrases later 106 | wtod[anet_class_all[i]] = i 107 | 108 | vids_processed = vids_processed['annotations'] 109 | for vid_id, vid in vids_processed.items(): 110 | if vid_id in split: 111 | vid['split'] = split[vid_id] 112 | else: 113 | vid['split'] = 'rest' 114 | print('Video {} can not be found in the dataset!'.format(vid_id)) 115 | seed(123) # make reproducible 116 | 117 | # create the vocab 118 | vocab, vids_new = build_vocab(vids_processed, split, params) 119 | itow = {i+1:w for i,w in enumerate(vocab)} # a 1-indexed vocab translation table 120 | wtoi = {w:i+1 for i,w in enumerate(vocab)} # inverse table 121 | 122 | wtol = {} 123 | for w in vocab: 124 | out = json.loads(nlp.annotate(w.encode('utf-8'), properties=props)) 125 | lemma_w = out['sentences'][0]['tokens'][0]['lemma'] 126 | wtol[w] = lemma_w 127 | 128 | # create output json file 129 | out = {} 130 | out['ix_to_word'] = itow # encode the (1-indexed) vocab 131 | out['wtod'] = wtod 132 | out['wtol'] = wtol 133 | out['videos'] = [] 134 | for vid_id, vid in vids_processed.items(): 135 | jvid = {} 136 | jvid['vid_id'] = vid_id 137 | jvid['split'] = vid['split'] 138 | seg_lst = vid['segments'].keys() 139 | seg_lst = [int(s) for s in seg_lst] 140 | seg_lst.sort() 141 | for i in seg_lst: 142 | jvid['seg_id'] = str(i) 143 | jvid['id'] = vid_id+'_segment_'+str(i).zfill(2) # some info 144 | out['videos'].append(jvid.copy()) 145 | print('Total number of segments: {}'.format(len(out['videos']))) 146 | 147 | json.dump(out, open(params['output_dic_json'], 'w')) 148 | print('wrote ', params['output_dic_json']) 149 | 150 | json.dump(vids_new, open(params['output_cap_json'], 'w')) 151 | print('wrote ', params['output_cap_json']) 152 | 153 | if __name__ == "__main__": 154 | parser = argparse.ArgumentParser() 155 | 156 | # input json 157 | parser.add_argument('--split_file', default='data/anet/split_ids_anet_entities.json') 158 | parser.add_argument('--input_json', default='data/anet/anet_cleaned_class_thresh50.json') 159 | parser.add_argument('--output_dic_json', default='data/anet/dic_anet.json', help='output json file') 160 | parser.add_argument('--output_cap_json', default='data/anet/cap_anet.json', help='output json file') 161 | 162 | # options 163 | parser.add_argument('--max_length', default=20, type=int, help='max length of a caption, in number of words. captions longer than this get clipped.') 164 | parser.add_argument('--word_count_threshold', default=3, type=int, help='only words that occur more than this number of times will be put in vocab') 165 | 166 | args = parser.parse_args() 167 | params = vars(args) # convert to ordinary dict 168 | print('parsed input parameters:') 169 | print(json.dumps(params, indent = 2)) 170 | main(params) 171 | nlp.close() 172 | -------------------------------------------------------------------------------- /tools/download_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 
8 | # 9 | 10 | # Script to download all the necessary data files and place under the data directory 11 | # Written by Luowei Zhou on 05/01/2019 12 | 13 | 14 | DATA_ROOT='data' 15 | 16 | mkdir -p $DATA_ROOT/anet save results log 17 | 18 | # annotation files 19 | wget -P $DATA_ROOT/anet/ https://dl.fbaipublicfiles.com/ActivityNet-Entities/ActivityNet-Entities/anet_entities_prep.tar.gz 20 | wget -P $DATA_ROOT/anet/ https://dl.fbaipublicfiles.com/ActivityNet-Entities/ActivityNet-Entities/anet_entities_captions.tar.gz 21 | wget -P tools/coco-caption/annotations https://github.com/jiasenlu/coco-caption/raw/master/annotations/caption_flickr30k.json 22 | 23 | # feature files 24 | wget -P $DATA_ROOT/ https://dl.fbaipublicfiles.com/ActivityNet-Entities/ActivityNet-Entities/detectron_weights.tar.gz 25 | wget -P $DATA_ROOT/anet/ https://dl.fbaipublicfiles.com/ActivityNet-Entities/ActivityNet-Entities/rgb_motion_1d.tar.gz 26 | wget -P $DATA_ROOT/anet/ https://dl.fbaipublicfiles.com/ActivityNet-Entities/ActivityNet-Entities/anet_detection_vg_fc6_feat_100rois.h5 27 | wget -P $DATA_ROOT/anet/ https://dl.fbaipublicfiles.com/ActivityNet-Entities/ActivityNet-Entities/fc6_feat_100rois.tar.gz 28 | 29 | # Stanford CoreNLP 3.9.1 30 | wget -P tools/ http://nlp.stanford.edu/software/stanford-corenlp-full-2018-02-27.zip 31 | 32 | # pre-trained models 33 | wget -P save/ https://dl.fbaipublicfiles.com/ActivityNet-Entities/ActivityNet-Entities/pre-trained-models.tar.gz 34 | 35 | # uncompress 36 | cd $DATA_ROOT 37 | for file in *.tar.gz; do tar -zxvf "${file}" && rm "${file}"; done 38 | cd anet 39 | for file in *.tar.gz; do tar -zxvf "${file}" && rm "${file}"; done 40 | cd ../../tools 41 | for file in *.zip; do unzip "${file}" && rm "${file}"; done 42 | cd coco-caption 43 | ./get_stanford_models.sh 44 | cd ../../save 45 | for file in *.tar.gz; do tar -zxvf "${file}" && rm "${file}"; done 46 | mv pre-trained-models/* . && rm -r pre-trained-models 47 | -------------------------------------------------------------------------------- /tools/vg_cls_overlap.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | # install spacy and download the spacy English model: 9 | # https://github.com/pytorch/text#optional-requirements 10 | 11 | import json 12 | import os 13 | from stanfordcorenlp import StanfordCoreNLP 14 | import torchtext 15 | 16 | nlp = StanfordCoreNLP('./stanford-corenlp-full-2018-02-27') 17 | props={'annotators': 'ssplit, tokenize, lemma','pipelineLanguage':'en', 'outputFormat':'json'} 18 | 19 | # dataset we support: VG, ANet-Entities, Flickr30k-Entities, Open Image v4, Something-Something v2 20 | src_dataset = 'VG' 21 | grd_dataset = 'ss2' 22 | 23 | freq_thresh = 100 24 | 25 | def get_dataset_cls(src_dataset): 26 | classes = [] 27 | if src_dataset == 'VG': 28 | classes = [] 29 | obj_cls_file = 'vg_object_vocab.txt' 30 | with open(obj_cls_file) as f: 31 | data = f.readlines() 32 | classes.extend([i.strip() for i in data]) 33 | elif src_dataset == 'flickr30k': 34 | obj_cls_file = 'data/flickr30k/flickr30k_class_name.txt' 35 | with open(obj_cls_file) as f: 36 | data = f.readlines() 37 | classes.extend([i.strip() for i in data]) 38 | elif src_dataset == 'Open': 39 | cls_file_map = 'open_image_vocab_map.txt' 40 | vocab_dict = {} 41 | with open(cls_file_map) as f: 42 | data = f.readlines() 43 | for i in data: 44 | spt = i.split(',') 45 | idx = spt[0] 46 | cls = ','.join(spt[1:]) 47 | vocab_dict[idx.strip().replace('"', '')] = cls.strip().replace(' ', '-').replace('"', '') 48 | 49 | print(vocab_dict) 50 | obj_cls_file = 'open_image_vocab.txt' 51 | with open(obj_cls_file) as f: 52 | data = f.readlines() 53 | classes.extend([vocab_dict[i.strip()] for i in data if vocab_dict.get(i.strip(), 'dummy') != 'dummy']) # fill in dummy if does not exist 54 | elif src_dataset == 'ss2': 55 | src_file_lst = ['something-something-v2-train.json', 'something-something-v2-validation.json'] 56 | data = [] 57 | for src_file in src_file_lst: 58 | with open(src_file) as f: 59 | data.extend(json.load(f)) 60 | 61 | obj_vocab = {} 62 | for i in data: 63 | for cls in i['placeholders']: 64 | if cls in obj_vocab: 65 | obj_vocab[cls] += 1 66 | else: 67 | obj_vocab[cls] = 1 68 | print('unique object class in ss2: ', len(obj_vocab)) 69 | print('top 100 frequent object class:', sorted(obj_vocab.keys(), key=lambda k:obj_vocab[k], reverse=True)[:100]) 70 | classes = obj_vocab 71 | elif src_dataset == 'anet': 72 | src_file_lst = ['train.json', 'val_1.json'] 73 | class_dict = {} 74 | for src_file in src_file_lst: 75 | with open(src_file) as f: 76 | data = json.load(f) 77 | for k, i in data.items(): 78 | for s in i['sentences']: 79 | out = json.loads(nlp.annotate(s.encode('utf-8'), properties=props)) 80 | if len(out['sentences']) > 0: 81 | for token in out['sentences'][0]['tokens']: 82 | if 'NN' in token['pos']: 83 | lemma_w = token['lemma'] 84 | if lemma_w in class_dict: 85 | class_dict[lemma_w] += 1 86 | else: 87 | class_dict[lemma_w] = 1 88 | 89 | tmp = {} 90 | for k, freq in class_dict.items(): 91 | if freq >= freq_thresh: 92 | tmp[k] = freq 93 | class_dict = tmp 94 | print('number of lemma word in the vocab: ', len(class_dict), src_dataset) 95 | print('unique object class in ss2: ', len(tmp)) 96 | print('top 100 frequent object class:', sorted(tmp.keys(), key=lambda k:tmp[k], reverse=True)[:100]) 97 | return class_dict 98 | 99 | class_dict = {} 100 | if src_dataset != 'ss2': 101 | for i, w in enumerate(classes): 102 | w_s = w.split(',') 103 | for v in w_s: 104 | # if len(v.split(' ')) == 1: 105 | out = json.loads(nlp.annotate(v.encode('utf-8'), properties=props)) 106 | if len(out['sentences']) > 0: 107 | for 
token in out['sentences'][0]['tokens']: 108 | if 'NN' in token['pos']: 109 | lemma_w = token['lemma'] 110 | class_dict[lemma_w] = i 111 | print('number of lemma word in the vocab: ', len(class_dict), src_dataset) 112 | else: 113 | for w, freq in classes.items(): 114 | w_s = w.split(',') 115 | for v in w_s: 116 | # if len(v.split(' ')) == 1: 117 | out = json.loads(nlp.annotate(v.encode('utf-8'), properties=props)) 118 | if len(out['sentences']) > 0: 119 | for token in out['sentences'][0]['tokens']: 120 | if 'NN' in token['pos']: 121 | lemma_w = token['lemma'] 122 | if lemma_w in class_dict: 123 | class_dict[lemma_w] += freq 124 | else: 125 | class_dict[lemma_w] = freq 126 | 127 | tmp = {} 128 | for k, freq in class_dict.items(): 129 | if freq >= freq_thresh: 130 | tmp[k] = freq 131 | class_dict = tmp 132 | print('number of lemma word in the vocab: ', len(class_dict), src_dataset) 133 | 134 | return class_dict 135 | 136 | 137 | def load_corpus(grd_dataset): 138 | # vocab frequency 139 | sentences = [] 140 | if grd_dataset == 'flickr30k': 141 | imgs_processed = json.load(open('../data/flickr30k/flickr30k_cleaned_class.json', 'r')) 142 | imgs_processed = imgs_processed['annotations'] 143 | 144 | sentences = [] 145 | for img in imgs_processed: 146 | for i in img['captions']: 147 | sentences.append(' '.join(i['tokens'])) 148 | elif grd_dataset == 'ss2': 149 | src_file_lst = ['something-something-v2-train.json', 'something-something-v2-validation.json'] 150 | data = [] 151 | for src_file in src_file_lst: 152 | with open(src_file) as f: 153 | data.extend(json.load(f)) 154 | 155 | for i in data: 156 | sentences.append(i['label']) 157 | elif grd_dataset == 'anet': 158 | src_file_lst = ['train.json', 'val_1.json'] 159 | for src_file in src_file_lst: 160 | with open(src_file) as f: 161 | data = json.load(f) 162 | for k, i in data.items(): 163 | for s in i['sentences']: 164 | sentences.append(s) 165 | else: 166 | raise NotImplementedError 167 | 168 | return sentences 169 | 170 | 171 | def main(): 172 | 173 | class_dict = get_dataset_cls(src_dataset) 174 | g_class_dict = get_dataset_cls(grd_dataset) 175 | text_proc = torchtext.data.Field(sequential=True, tokenize='spacy', 176 | lower=True, batch_first=True) 177 | sentences = load_corpus(grd_dataset) 178 | 179 | print('Total number of sentences: {}'.format(len(sentences))) 180 | sentences_proc = list(map(text_proc.preprocess, sentences)) 181 | text_proc.build_vocab(sentences_proc) 182 | # print(text_proc.vocab.freqs.most_common(20)) 183 | 184 | # check the overlapped vocab 185 | missed_cls = [] 186 | catched_cls = [] 187 | 188 | for k, i in g_class_dict.items(): 189 | if k not in class_dict: 190 | missed_cls.append((k, text_proc.vocab.freqs[k])) 191 | else: 192 | catched_cls.append((k, text_proc.vocab.freqs[k])) 193 | 194 | missed_cls = sorted(missed_cls, key=lambda x:x[1], reverse=True) 195 | catched_cls = sorted(catched_cls, key=lambda x:x[1], reverse=True) 196 | 197 | for i, tup in enumerate(missed_cls): 198 | if i < 20: 199 | print('{}: {}'.format(tup[0], tup[1])) 200 | # for i, tup in enumerate(catched_cls): 201 | # if i < 20: 202 | # print('{}: {}'.format(tup[0], tup[1])) 203 | 204 | print('Number of classes are missing: {}, percentage: {}'.format(len(missed_cls), 205 | len(missed_cls)*1./len(g_class_dict))) 206 | 207 | nlp.close() 208 | 209 | if __name__ == "__main__": 210 | main() 211 | --------------------------------------------------------------------------------
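A note on the recurring pattern above: both get_dataset_cls() and the corpus handling reduce class names and noun phrases to noun lemmas through Stanford CoreNLP before comparing vocabularies. Below is a small stand-alone sketch of just that step; it assumes the CoreNLP 3.9.1 package has already been unzipped next to the script (as done by tools/download_all.sh), uses an arbitrary example phrase, and passes Python 3 strings where the script itself calls .encode('utf-8') for Python 2.

import json

from stanfordcorenlp import StanfordCoreNLP

nlp = StanfordCoreNLP('./stanford-corenlp-full-2018-02-27')
props = {'annotators': 'ssplit, tokenize, lemma', 'pipelineLanguage': 'en', 'outputFormat': 'json'}

def noun_lemmas(phrase):
    # keep tokens whose POS tag contains 'NN' (nouns) and return their lemmas
    out = json.loads(nlp.annotate(phrase, properties=props))
    return [token['lemma']
            for sent in out['sentences']
            for token in sent['tokens']
            if 'NN' in token['pos']]

print(noun_lemmas('a man riding a horse on the beach'))
nlp.close()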