├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── TMM.png ├── common ├── __init__.py ├── coco_caption │ ├── .gitignore │ ├── README.md │ ├── annotations │ │ └── captions_val2014.json │ ├── cocoEvalCapDemo.ipynb │ ├── get_stanford_models.sh │ ├── license.txt │ ├── pycocoevalcap │ │ ├── __init__.py │ │ ├── bleu │ │ │ ├── LICENSE │ │ │ ├── __init__.py │ │ │ ├── bleu.py │ │ │ └── bleu_scorer.py │ │ ├── cider │ │ │ ├── __init__.py │ │ │ ├── cider.py │ │ │ └── cider_scorer.py │ │ ├── eval.py │ │ ├── meteor │ │ │ ├── __init__.py │ │ │ ├── data │ │ │ │ └── paraphrase-en.gz │ │ │ ├── meteor-1.5.jar │ │ │ └── meteor.py │ │ ├── rouge │ │ │ ├── __init__.py │ │ │ └── rouge.py │ │ ├── spice │ │ │ ├── __init__.py │ │ │ ├── lib │ │ │ │ ├── Meteor-1.5.jar │ │ │ │ ├── SceneGraphParser-1.0.jar │ │ │ │ ├── ejml-0.23.jar │ │ │ │ ├── fst-2.47.jar │ │ │ │ ├── guava-19.0.jar │ │ │ │ ├── hamcrest-core-1.3.jar │ │ │ │ ├── jackson-core-2.5.3.jar │ │ │ │ ├── javassist-3.19.0-GA.jar │ │ │ │ ├── json-simple-1.1.1.jar │ │ │ │ ├── junit-4.12.jar │ │ │ │ ├── lmdbjni-0.4.6.jar │ │ │ │ ├── lmdbjni-linux64-0.4.6.jar │ │ │ │ ├── lmdbjni-osx64-0.4.6.jar │ │ │ │ ├── lmdbjni-win64-0.4.6.jar │ │ │ │ ├── objenesis-2.4.jar │ │ │ │ ├── slf4j-api-1.7.12.jar │ │ │ │ └── slf4j-simple-1.7.21.jar │ │ │ ├── spice-1.0.jar │ │ │ └── spice.py │ │ └── tokenizer │ │ │ ├── __init__.py │ │ │ ├── ptbtokenizer.py │ │ │ └── stanford-corenlp-3.4.1.jar │ ├── pycocotools │ │ ├── __init__.py │ │ └── coco.py │ └── results │ │ └── captions_val2014_fakecap_results.json ├── configuration.py ├── get_repo.py ├── inputs │ ├── __init__.py │ ├── manager_image_caption.py │ └── preprocessing │ │ ├── __init__.py │ │ ├── cifar_preprocessing.py │ │ ├── cifarnet_preprocessing.py │ │ ├── inception_preprocessing_radix.py │ │ ├── preprocessing_factory.py │ │ └── vgg_preprocessing.py ├── natural_sort.py ├── net_params.py ├── nets │ ├── __init__.py │ ├── alexnet.py │ ├── alexnet_test.py │ ├── cifarnet.py │ ├── cyclegan.py │ ├── cyclegan_test.py │ ├── dcgan.py │ ├── dcgan_test.py │ ├── i3d.py │ ├── i3d_test.py │ ├── i3d_utils.py │ ├── inception.py │ ├── inception_resnet_v2.py │ ├── inception_resnet_v2_test.py │ ├── inception_utils.py │ ├── inception_v1.py │ ├── inception_v1_test.py │ ├── inception_v2.py │ ├── inception_v2_test.py │ ├── inception_v3.py │ ├── inception_v3_test.py │ ├── inception_v4.py │ ├── inception_v4_test.py │ ├── lenet.py │ ├── mobilenet │ │ ├── README.md │ │ ├── __init__.py │ │ ├── conv_blocks.py │ │ ├── madds_top1_accuracy.png │ │ ├── mnet_v1_vs_v2_pixel1_latency.png │ │ ├── mobilenet.py │ │ ├── mobilenet_example.ipynb │ │ ├── mobilenet_v2.py │ │ └── mobilenet_v2_test.py │ ├── mobilenet_v1.md │ ├── mobilenet_v1.png │ ├── mobilenet_v1.py │ ├── mobilenet_v1_eval.py │ ├── mobilenet_v1_test.py │ ├── mobilenet_v1_train.py │ ├── nasnet │ │ ├── README.md │ │ ├── __init__.py │ │ ├── nasnet.py │ │ ├── nasnet_test.py │ │ ├── nasnet_utils.py │ │ ├── nasnet_utils_test.py │ │ ├── pnasnet.py │ │ └── pnasnet_test.py │ ├── nets_factory.py │ ├── nets_factory_test.py │ ├── overfeat.py │ ├── overfeat_test.py │ ├── pix2pix.py │ ├── pix2pix_test.py │ ├── resnet_utils.py │ ├── resnet_v1.py │ ├── resnet_v1_test.py │ ├── resnet_v2.py │ ├── resnet_v2_test.py │ ├── s3dg.py │ ├── s3dg_test.py │ ├── vgg.py │ └── vgg_test.py ├── ops.py ├── ops_rnn.py ├── scst │ ├── __init__.py │ ├── cider_ruotianluo │ │ ├── README.md │ │ ├── data │ │ │ └── abstract_candsB.json │ │ ├── license.txt │ │ ├── pyciderevalcap │ │ │ ├── __init__.py │ │ │ ├── cider │ │ │ │ ├── __init__.py │ │ 
│ │ ├── cider.py │ │ │ │ └── cider_scorer.py │ │ │ ├── ciderD │ │ │ │ ├── __init__.py │ │ │ │ ├── ciderD.py │ │ │ │ └── ciderD_scorer.py │ │ │ ├── eval.py │ │ │ └── tokenizer │ │ │ │ ├── __init__.py │ │ │ │ ├── ptbtokenizer.py │ │ │ │ └── stanford-corenlp-3.4.1.jar │ │ └── pydataformat │ │ │ ├── __init__.py │ │ │ ├── jsonify_refs.py │ │ │ └── loadData.py │ ├── prepro_ngrams.py │ └── scorers.py ├── stanford_corenlp.py └── utils.py ├── datasets └── preprocessing │ ├── __init__.py │ ├── coco_prepro.py │ ├── insta_prepro.py │ ├── prepro_base.py │ └── ptb_tokenizer.py └── src ├── example.sh ├── infer.py ├── infer_fn.py ├── model.py ├── model_base.py ├── setup.sh ├── train.py └── train_fn.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # These files are text and should be normalized (Convert crlf => lf) 2 | *.tex text 3 | *.php text 4 | *.css text 5 | *.js text 6 | *.htm text 7 | *.html text 8 | *.xml text 9 | *.txt text 10 | *.ini text 11 | *.inc text 12 | .htaccess text 13 | 14 | # These files are binary and should be left untouched 15 | # (binary is a macro for -text -diff) 16 | *.psd binary 17 | *.png binary 18 | *.jpg binary 19 | *.jpeg binary 20 | *.gif binary 21 | *.ico binary 22 | *.mov binary 23 | *.mp4 binary 24 | *.mp3 binary 25 | *.flv binary 26 | *.fla binary 27 | *.swf binary 28 | *.gz binary 29 | *.zip binary 30 | *.7z binary 31 | *.ttf binary 32 | 33 | # Auto detect text files and perform LF normalization 34 | # http://davidlaing.com/2012/09/19/customise-your-gitattributes-to-become-a-git-ninja/ 35 | * text=auto 36 | 37 | # Documents 38 | *.doc diff=astextplain 39 | *.DOC diff=astextplain 40 | *.docx diff=astextplain 41 | *.DOCX diff=astextplain 42 | *.dot diff=astextplain 43 | *.DOT diff=astextplain 44 | *.pdf diff=astextplain 45 | *.PDF diff=astextplain 46 | *.rtf diff=astextplain 47 | *.RTF diff=astextplain 48 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # PyCharm 2 | .idea/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # pyenv 79 | .python-version 80 | 81 | # celery beat schedule file 82 | celerybeat-schedule 83 | 84 | # SageMath parsed files 85 | *.sage.py 86 | 87 | # Environments 88 | .env 89 | .venv 90 | env/ 91 | venv/ 92 | ENV/ 93 | env.bak/ 94 | venv.bak/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ 108 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019, Tan Jia Huei 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /TMM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiahuei/COMIC-Compact-Image-Captioning-with-Attention/73165e0aac2816e89732571814f978801958e1ac/TMM.png -------------------------------------------------------------------------------- /common/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue Mar 28 18:12:55 2017 4 | 5 | @author: jiahuei 6 | """ 7 | 8 | -------------------------------------------------------------------------------- /common/coco_caption/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | -------------------------------------------------------------------------------- /common/coco_caption/README.md: -------------------------------------------------------------------------------- 1 | Microsoft COCO Caption Evaluation 2 | =================== 3 | 4 | Evaluation codes for MS COCO caption generation. 5 | 6 | ## Requirements ## 7 | - java 1.8.0 8 | - python 2.7 9 | 10 | ## Files ## 11 | ./ 12 | - cocoEvalCapDemo.py (demo script) 13 | 14 | ./annotations 15 | - captions_val2014.json (MS COCO 2014 caption validation set) 16 | - Visit MS COCO [download](http://mscoco.org/dataset/#download) page for more details. 17 | 18 | ./results 19 | - captions_val2014_fakecap_results.json (an example of fake results for running demo) 20 | - Visit MS COCO [format](http://mscoco.org/dataset/#format) page for more details. 21 | 22 | ./pycocoevalcap: The folder where all evaluation codes are stored. 23 | - eval.py: The file includes the COCOEvalCap class that can be used to evaluate results on COCO. 24 | - tokenizer: Python wrapper of Stanford CoreNLP PTBTokenizer 25 | - bleu: Bleu evaluation codes 26 | - meteor: Meteor evaluation codes 27 | - rouge: Rouge-L evaluation codes 28 | - cider: CIDEr evaluation codes 29 | - spice: SPICE evaluation codes 30 | 31 | ## Setup ## 32 | 33 | - You will first need to download the [Stanford CoreNLP 3.6.0](http://stanfordnlp.github.io/CoreNLP/index.html) code and models for use by SPICE. To do this, run: 34 | ./get_stanford_models.sh 35 | - Note: SPICE will try to create a cache of parsed sentences in ./pycocoevalcap/spice/cache/. This dramatically speeds up repeated evaluations. The cache directory can be moved by setting 'CACHE_DIR' in ./pycocoevalcap/spice. In the same file, caching can be turned off by removing the '-cache' argument to 'spice_cmd'. 36 | 37 | ## References ## 38 | 39 | - [Microsoft COCO Captions: Data Collection and Evaluation Server](http://arxiv.org/abs/1504.00325) 40 | - PTBTokenizer: We use the [Stanford Tokenizer](http://nlp.stanford.edu/software/tokenizer.shtml) which is included in [Stanford CoreNLP 3.4.1](http://nlp.stanford.edu/software/corenlp.shtml). 41 | - BLEU: [BLEU: a Method for Automatic Evaluation of Machine Translation](http://www.aclweb.org/anthology/P02-1040.pdf) 42 | - Meteor: [Project page](http://www.cs.cmu.edu/~alavie/METEOR/) with related publications. We use the latest version (1.5) of the [Code](https://github.com/mjdenkowski/meteor). Changes have been made to the source code to properly aggregate the statistics for the entire corpus. 
43 | - Rouge-L: [ROUGE: A Package for Automatic Evaluation of Summaries](http://anthology.aclweb.org/W/W04/W04-1013.pdf) 44 | - CIDEr: [CIDEr: Consensus-based Image Description Evaluation](http://arxiv.org/pdf/1411.5726.pdf) 45 | - SPICE: [SPICE: Semantic Propositional Image Caption Evaluation](https://arxiv.org/abs/1607.08822) 46 | 47 | ## Developers ## 48 | - Xinlei Chen (CMU) 49 | - Hao Fang (University of Washington) 50 | - Tsung-Yi Lin (Cornell) 51 | - Ramakrishna Vedantam (Virginia Tech) 52 | 53 | ## Acknowledgement ## 54 | - David Chiang (University of Notre Dame) 55 | - Michael Denkowski (CMU) 56 | - Alexander Rush (Harvard University) 57 | -------------------------------------------------------------------------------- /common/coco_caption/get_stanford_models.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | # This script downloads the Stanford CoreNLP models. 3 | 4 | CORENLP=stanford-corenlp-full-2015-12-09 5 | SPICELIB=pycocoevalcap/spice/lib 6 | JAR=stanford-corenlp-3.6.0 7 | 8 | DIR="$( cd "$(dirname "$0")" ; pwd -P )" 9 | cd $DIR 10 | 11 | if [ -f $SPICELIB/$JAR.jar ]; then 12 | echo "Found Stanford CoreNLP." 13 | else 14 | echo "Downloading..." 15 | wget http://nlp.stanford.edu/software/$CORENLP.zip 16 | echo "Unzipping..." 17 | unzip $CORENLP.zip -d $SPICELIB/ 18 | mv $SPICELIB/$CORENLP/$JAR.jar $SPICELIB/ 19 | mv $SPICELIB/$CORENLP/$JAR-models.jar $SPICELIB/ 20 | rm -f $CORENLP.zip 21 | rm -rf $SPICELIB/$CORENLP/ 22 | echo "Done." 23 | fi 24 | -------------------------------------------------------------------------------- /common/coco_caption/license.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015, Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 2. Redistributions in binary form must reproduce the above copyright notice, 10 | this list of conditions and the following disclaimer in the documentation 11 | and/or other materials provided with the distribution. 12 | 13 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 14 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 15 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 16 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 17 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 18 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 19 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 20 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 21 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 22 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 23 | 24 | The views and conclusions contained in the software and documentation are those 25 | of the authors and should not be interpreted as representing official policies, 26 | either expressed or implied, of the FreeBSD Project. 
27 | -------------------------------------------------------------------------------- /common/coco_caption/pycocoevalcap/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /common/coco_caption/pycocoevalcap/bleu/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015 Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /common/coco_caption/pycocoevalcap/bleu/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /common/coco_caption/pycocoevalcap/bleu/bleu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : bleu.py 4 | # 5 | # Description : Wrapper for BLEU scorer. 6 | # 7 | # Creation Date : 06-01-2015 8 | # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | from bleu_scorer import BleuScorer 12 | 13 | 14 | class Bleu: 15 | def __init__(self, n=4): 16 | # default compute Blue score up to 4 17 | self._n = n 18 | self._hypo_for_image = {} 19 | self.ref_for_image = {} 20 | 21 | def compute_score(self, gts, res): 22 | 23 | assert(gts.keys() == res.keys()) 24 | imgIds = gts.keys() 25 | 26 | bleu_scorer = BleuScorer(n=self._n) 27 | for id in imgIds: 28 | hypo = res[id] 29 | ref = gts[id] 30 | 31 | # Sanity check. 
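# Each image id must map to a list holding exactly one hypothesis string, and its references must form a non-empty list; the asserts below enforce this before the pair is fed to BleuScorer.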
32 | assert(type(hypo) is list) 33 | assert(len(hypo) == 1) 34 | assert(type(ref) is list) 35 | assert(len(ref) >= 1) 36 | 37 | bleu_scorer += (hypo[0], ref) 38 | 39 | #score, scores = bleu_scorer.compute_score(option='shortest') 40 | score, scores = bleu_scorer.compute_score(option='closest', verbose=1) 41 | #score, scores = bleu_scorer.compute_score(option='average', verbose=1) 42 | 43 | # return (bleu, bleu_info) 44 | return score, scores 45 | 46 | def method(self): 47 | return "Bleu" 48 | -------------------------------------------------------------------------------- /common/coco_caption/pycocoevalcap/cider/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /common/coco_caption/pycocoevalcap/cider/cider.py: -------------------------------------------------------------------------------- 1 | # Filename: cider.py 2 | # 3 | # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric 4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) 5 | # 6 | # Creation Date: Sun Feb 8 14:16:54 2015 7 | # 8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin 9 | 10 | from cider_scorer import CiderScorer 11 | import pdb 12 | 13 | class Cider: 14 | """ 15 | Main Class to compute the CIDEr metric 16 | 17 | """ 18 | def __init__(self, test=None, refs=None, n=4, sigma=6.0): 19 | # set cider to sum over 1 to 4-grams 20 | self._n = n 21 | # set the standard deviation parameter for gaussian penalty 22 | self._sigma = sigma 23 | 24 | def compute_score(self, gts, res): 25 | """ 26 | Main function to compute CIDEr score 27 | :param hypo_for_image (dict) : dictionary with key and value 28 | ref_for_image (dict) : dictionary with key and value 29 | :return: cider (float) : computed CIDEr score for the corpus 30 | """ 31 | 32 | assert(gts.keys() == res.keys()) 33 | imgIds = gts.keys() 34 | 35 | cider_scorer = CiderScorer(n=self._n, sigma=self._sigma) 36 | 37 | for id in imgIds: 38 | hypo = res[id] 39 | ref = gts[id] 40 | 41 | # Sanity check. 
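# Same input contract as the BLEU wrapper: one hypothesis string and at least one reference string per image id.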
42 | assert(type(hypo) is list) 43 | assert(len(hypo) == 1) 44 | assert(type(ref) is list) 45 | assert(len(ref) > 0) 46 | 47 | cider_scorer += (hypo[0], ref) 48 | 49 | (score, scores) = cider_scorer.compute_score() 50 | 51 | return score, scores 52 | 53 | def method(self): 54 | return "CIDEr" -------------------------------------------------------------------------------- /common/coco_caption/pycocoevalcap/eval.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | from tokenizer.ptbtokenizer import PTBTokenizer 3 | from bleu.bleu import Bleu 4 | from meteor.meteor import Meteor 5 | from rouge.rouge import Rouge 6 | from cider.cider import Cider 7 | from spice.spice import Spice 8 | 9 | class COCOEvalCap: 10 | def __init__(self, coco, cocoRes): 11 | self.evalImgs = [] 12 | self.eval = {} 13 | self.imgToEval = {} 14 | self.coco = coco 15 | self.cocoRes = cocoRes 16 | self.params = {'image_id': coco.getImgIds()} 17 | 18 | def evaluate(self): 19 | imgIds = self.params['image_id'] 20 | # imgIds = self.coco.getImgIds() 21 | gts = {} 22 | res = {} 23 | for imgId in imgIds: 24 | gts[imgId] = self.coco.imgToAnns[imgId] 25 | res[imgId] = self.cocoRes.imgToAnns[imgId] 26 | 27 | # ================================================= 28 | # Set up scorers 29 | # ================================================= 30 | print 'tokenization...' 31 | tokenizer = PTBTokenizer() 32 | gts = tokenizer.tokenize(gts) 33 | res = tokenizer.tokenize(res) 34 | 35 | # ================================================= 36 | # Set up scorers 37 | # ================================================= 38 | print 'setting up scorers...' 39 | scorers = [ 40 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]), 41 | (Meteor(),"METEOR"), 42 | (Rouge(), "ROUGE_L"), 43 | (Cider(), "CIDEr"), 44 | (Spice(), "SPICE") 45 | ] 46 | 47 | # ================================================= 48 | # Compute scores 49 | # ================================================= 50 | for scorer, method in scorers: 51 | print 'computing %s score...'%(scorer.method()) 52 | score, scores = scorer.compute_score(gts, res) 53 | if type(method) == list: 54 | for sc, scs, m in zip(score, scores, method): 55 | self.setEval(sc, m) 56 | self.setImgToEvalImgs(scs, gts.keys(), m) 57 | print "%s: %0.3f"%(m, sc) 58 | else: 59 | self.setEval(score, method) 60 | self.setImgToEvalImgs(scores, gts.keys(), method) 61 | print "%s: %0.3f"%(method, score) 62 | self.setEvalImgs() 63 | 64 | def setEval(self, score, method): 65 | self.eval[method] = score 66 | 67 | def setImgToEvalImgs(self, scores, imgIds, method): 68 | for imgId, score in zip(imgIds, scores): 69 | if not imgId in self.imgToEval: 70 | self.imgToEval[imgId] = {} 71 | self.imgToEval[imgId]["image_id"] = imgId 72 | self.imgToEval[imgId][method] = score 73 | 74 | def setEvalImgs(self): 75 | self.evalImgs = [eval for imgId, eval in self.imgToEval.items()] -------------------------------------------------------------------------------- /common/coco_caption/pycocoevalcap/meteor/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /common/coco_caption/pycocoevalcap/meteor/data/paraphrase-en.gz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jiahuei/COMIC-Compact-Image-Captioning-with-Attention/73165e0aac2816e89732571814f978801958e1ac/common/coco_caption/pycocoevalcap/meteor/data/paraphrase-en.gz -------------------------------------------------------------------------------- /common/coco_caption/pycocoevalcap/meteor/meteor-1.5.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiahuei/COMIC-Compact-Image-Captioning-with-Attention/73165e0aac2816e89732571814f978801958e1ac/common/coco_caption/pycocoevalcap/meteor/meteor-1.5.jar -------------------------------------------------------------------------------- /common/coco_caption/pycocoevalcap/meteor/meteor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Python wrapper for METEOR implementation, by Xinlei Chen 4 | # Acknowledge Michael Denkowski for the generous discussion and help 5 | 6 | import os 7 | import sys 8 | import subprocess 9 | import threading 10 | 11 | # Assumes meteor-1.5.jar is in the same directory as meteor.py. Change as needed. 12 | METEOR_JAR = 'meteor-1.5.jar' 13 | # print METEOR_JAR 14 | 15 | class Meteor: 16 | 17 | def __init__(self): 18 | self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR, \ 19 | '-', '-', '-stdio', '-l', 'en', '-norm'] 20 | self.meteor_p = subprocess.Popen(self.meteor_cmd, \ 21 | cwd=os.path.dirname(os.path.abspath(__file__)), \ 22 | stdin=subprocess.PIPE, \ 23 | stdout=subprocess.PIPE, \ 24 | stderr=subprocess.PIPE) 25 | # Used to guarantee thread safety 26 | self.lock = threading.Lock() 27 | 28 | def compute_score(self, gts, res): 29 | assert(gts.keys() == res.keys()) 30 | imgIds = gts.keys() 31 | scores = [] 32 | 33 | eval_line = 'EVAL' 34 | self.lock.acquire() 35 | for i in imgIds: 36 | assert(len(res[i]) == 1) 37 | stat = self._stat(res[i][0], gts[i]) 38 | eval_line += ' ||| {}'.format(stat) 39 | 40 | self.meteor_p.stdin.write('{}\n'.format(eval_line)) 41 | for i in range(0,len(imgIds)): 42 | scores.append(float(self.meteor_p.stdout.readline().strip())) 43 | score = float(self.meteor_p.stdout.readline().strip()) 44 | self.lock.release() 45 | 46 | return score, scores 47 | 48 | def method(self): 49 | return "METEOR" 50 | 51 | def _stat(self, hypothesis_str, reference_list): 52 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 53 | hypothesis_str = hypothesis_str.replace('|||','').replace(' ',' ') 54 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) 55 | self.meteor_p.stdin.write('{}\n'.format(score_line)) 56 | return self.meteor_p.stdout.readline().strip() 57 | 58 | def _score(self, hypothesis_str, reference_list): 59 | self.lock.acquire() 60 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 61 | hypothesis_str = hypothesis_str.replace('|||','').replace(' ',' ') 62 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) 63 | self.meteor_p.stdin.write('{}\n'.format(score_line)) 64 | stats = self.meteor_p.stdout.readline().strip() 65 | eval_line = 'EVAL ||| {}'.format(stats) 66 | # EVAL ||| stats 67 | self.meteor_p.stdin.write('{}\n'.format(eval_line)) 68 | score = float(self.meteor_p.stdout.readline().strip()) 69 | # bug fix: there are two values returned by the jar file, one average, and one all, so do it twice 70 | # thanks for Andrej for pointing this out 71 | score = float(self.meteor_p.stdout.readline().strip()) 72 | self.lock.release() 73 | 
return score 74 | 75 | def __del__(self): 76 | self.lock.acquire() 77 | self.meteor_p.stdin.close() 78 | self.meteor_p.kill() 79 | self.meteor_p.wait() 80 | self.lock.release() 81 | -------------------------------------------------------------------------------- /common/coco_caption/pycocoevalcap/rouge/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'vrama91' 2 | -------------------------------------------------------------------------------- /common/coco_caption/pycocoevalcap/rouge/rouge.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : rouge.py 4 | # 5 | # Description : Computes ROUGE-L metric as described by Lin and Hovey (2004) 6 | # 7 | # Creation Date : 2015-01-07 06:03 8 | # Author : Ramakrishna Vedantam 9 | 10 | import numpy as np 11 | import pdb 12 | 13 | def my_lcs(string, sub): 14 | """ 15 | Calculates longest common subsequence for a pair of tokenized strings 16 | :param string : list of str : tokens from a string split using whitespace 17 | :param sub : list of str : shorter string, also split using whitespace 18 | :returns: length (list of int): length of the longest common subsequence between the two strings 19 | 20 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS 21 | """ 22 | if(len(string)< len(sub)): 23 | sub, string = string, sub 24 | 25 | lengths = [[0 for i in range(0,len(sub)+1)] for j in range(0,len(string)+1)] 26 | 27 | for j in range(1,len(sub)+1): 28 | for i in range(1,len(string)+1): 29 | if(string[i-1] == sub[j-1]): 30 | lengths[i][j] = lengths[i-1][j-1] + 1 31 | else: 32 | lengths[i][j] = max(lengths[i-1][j] , lengths[i][j-1]) 33 | 34 | return lengths[len(string)][len(sub)] 35 | 36 | class Rouge(): 37 | ''' 38 | Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set 39 | 40 | ''' 41 | def __init__(self): 42 | # vrama91: updated the value below based on discussion with Hovey 43 | self.beta = 1.2 44 | 45 | def calc_score(self, candidate, refs): 46 | """ 47 | Compute ROUGE-L score given one candidate and references for an image 48 | :param candidate: str : candidate sentence to be evaluated 49 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated 50 | :returns score: int (ROUGE-L score for the candidate evaluated against references) 51 | """ 52 | assert(len(candidate)==1) 53 | assert(len(refs)>0) 54 | prec = [] 55 | rec = [] 56 | 57 | # split into tokens 58 | token_c = candidate[0].split(" ") 59 | 60 | for reference in refs: 61 | # split into tokens 62 | token_r = reference.split(" ") 63 | # compute the longest common subsequence 64 | lcs = my_lcs(token_r, token_c) 65 | prec.append(lcs/float(len(token_c))) 66 | rec.append(lcs/float(len(token_r))) 67 | 68 | prec_max = max(prec) 69 | rec_max = max(rec) 70 | 71 | if(prec_max!=0 and rec_max !=0): 72 | score = ((1 + self.beta**2)*prec_max*rec_max)/float(rec_max + self.beta**2*prec_max) 73 | else: 74 | score = 0.0 75 | return score 76 | 77 | def compute_score(self, gts, res): 78 | """ 79 | Computes Rouge-L score given a set of reference and candidate sentences for the dataset 80 | Invoked by evaluate_captions.py 81 | :param hypo_for_image: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values 82 | :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values 83 | 
:returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images) 84 | """ 85 | assert(gts.keys() == res.keys()) 86 | imgIds = gts.keys() 87 | 88 | score = [] 89 | for id in imgIds: 90 | hypo = res[id] 91 | ref = gts[id] 92 | 93 | score.append(self.calc_score(hypo, ref)) 94 | 95 | # Sanity check. 96 | assert(type(hypo) is list) 97 | assert(len(hypo) == 1) 98 | assert(type(ref) is list) 99 | assert(len(ref) > 0) 100 | 101 | average_score = np.mean(np.array(score)) 102 | return average_score, np.array(score) 103 | 104 | def method(self): 105 | return "Rouge" 106 | -------------------------------------------------------------------------------- /common/coco_caption/pycocoevalcap/spice/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiahuei/COMIC-Compact-Image-Captioning-with-Attention/73165e0aac2816e89732571814f978801958e1ac/common/coco_caption/pycocoevalcap/spice/__init__.py -------------------------------------------------------------------------------- /common/coco_caption/pycocoevalcap/spice/lib/Meteor-1.5.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiahuei/COMIC-Compact-Image-Captioning-with-Attention/73165e0aac2816e89732571814f978801958e1ac/common/coco_caption/pycocoevalcap/spice/lib/Meteor-1.5.jar -------------------------------------------------------------------------------- /common/coco_caption/pycocoevalcap/spice/lib/SceneGraphParser-1.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiahuei/COMIC-Compact-Image-Captioning-with-Attention/73165e0aac2816e89732571814f978801958e1ac/common/coco_caption/pycocoevalcap/spice/lib/SceneGraphParser-1.0.jar -------------------------------------------------------------------------------- /common/coco_caption/pycocoevalcap/spice/lib/ejml-0.23.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiahuei/COMIC-Compact-Image-Captioning-with-Attention/73165e0aac2816e89732571814f978801958e1ac/common/coco_caption/pycocoevalcap/spice/lib/ejml-0.23.jar -------------------------------------------------------------------------------- /common/coco_caption/pycocoevalcap/spice/lib/fst-2.47.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiahuei/COMIC-Compact-Image-Captioning-with-Attention/73165e0aac2816e89732571814f978801958e1ac/common/coco_caption/pycocoevalcap/spice/lib/fst-2.47.jar -------------------------------------------------------------------------------- /common/coco_caption/pycocoevalcap/spice/lib/guava-19.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiahuei/COMIC-Compact-Image-Captioning-with-Attention/73165e0aac2816e89732571814f978801958e1ac/common/coco_caption/pycocoevalcap/spice/lib/guava-19.0.jar -------------------------------------------------------------------------------- /common/coco_caption/pycocoevalcap/spice/lib/hamcrest-core-1.3.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiahuei/COMIC-Compact-Image-Captioning-with-Attention/73165e0aac2816e89732571814f978801958e1ac/common/coco_caption/pycocoevalcap/spice/lib/hamcrest-core-1.3.jar 
-------------------------------------------------------------------------------- /common/coco_caption/pycocoevalcap/spice/lib/jackson-core-2.5.3.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiahuei/COMIC-Compact-Image-Captioning-with-Attention/73165e0aac2816e89732571814f978801958e1ac/common/coco_caption/pycocoevalcap/spice/lib/jackson-core-2.5.3.jar -------------------------------------------------------------------------------- /common/coco_caption/pycocoevalcap/spice/lib/javassist-3.19.0-GA.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiahuei/COMIC-Compact-Image-Captioning-with-Attention/73165e0aac2816e89732571814f978801958e1ac/common/coco_caption/pycocoevalcap/spice/lib/javassist-3.19.0-GA.jar -------------------------------------------------------------------------------- /common/coco_caption/pycocoevalcap/spice/lib/json-simple-1.1.1.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiahuei/COMIC-Compact-Image-Captioning-with-Attention/73165e0aac2816e89732571814f978801958e1ac/common/coco_caption/pycocoevalcap/spice/lib/json-simple-1.1.1.jar -------------------------------------------------------------------------------- /common/coco_caption/pycocoevalcap/spice/lib/junit-4.12.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiahuei/COMIC-Compact-Image-Captioning-with-Attention/73165e0aac2816e89732571814f978801958e1ac/common/coco_caption/pycocoevalcap/spice/lib/junit-4.12.jar -------------------------------------------------------------------------------- /common/coco_caption/pycocoevalcap/spice/lib/lmdbjni-0.4.6.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiahuei/COMIC-Compact-Image-Captioning-with-Attention/73165e0aac2816e89732571814f978801958e1ac/common/coco_caption/pycocoevalcap/spice/lib/lmdbjni-0.4.6.jar -------------------------------------------------------------------------------- /common/coco_caption/pycocoevalcap/spice/lib/lmdbjni-linux64-0.4.6.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiahuei/COMIC-Compact-Image-Captioning-with-Attention/73165e0aac2816e89732571814f978801958e1ac/common/coco_caption/pycocoevalcap/spice/lib/lmdbjni-linux64-0.4.6.jar -------------------------------------------------------------------------------- /common/coco_caption/pycocoevalcap/spice/lib/lmdbjni-osx64-0.4.6.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiahuei/COMIC-Compact-Image-Captioning-with-Attention/73165e0aac2816e89732571814f978801958e1ac/common/coco_caption/pycocoevalcap/spice/lib/lmdbjni-osx64-0.4.6.jar -------------------------------------------------------------------------------- /common/coco_caption/pycocoevalcap/spice/lib/lmdbjni-win64-0.4.6.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiahuei/COMIC-Compact-Image-Captioning-with-Attention/73165e0aac2816e89732571814f978801958e1ac/common/coco_caption/pycocoevalcap/spice/lib/lmdbjni-win64-0.4.6.jar -------------------------------------------------------------------------------- /common/coco_caption/pycocoevalcap/spice/lib/objenesis-2.4.jar: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiahuei/COMIC-Compact-Image-Captioning-with-Attention/73165e0aac2816e89732571814f978801958e1ac/common/coco_caption/pycocoevalcap/spice/lib/objenesis-2.4.jar -------------------------------------------------------------------------------- /common/coco_caption/pycocoevalcap/spice/lib/slf4j-api-1.7.12.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiahuei/COMIC-Compact-Image-Captioning-with-Attention/73165e0aac2816e89732571814f978801958e1ac/common/coco_caption/pycocoevalcap/spice/lib/slf4j-api-1.7.12.jar -------------------------------------------------------------------------------- /common/coco_caption/pycocoevalcap/spice/lib/slf4j-simple-1.7.21.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiahuei/COMIC-Compact-Image-Captioning-with-Attention/73165e0aac2816e89732571814f978801958e1ac/common/coco_caption/pycocoevalcap/spice/lib/slf4j-simple-1.7.21.jar -------------------------------------------------------------------------------- /common/coco_caption/pycocoevalcap/spice/spice-1.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiahuei/COMIC-Compact-Image-Captioning-with-Attention/73165e0aac2816e89732571814f978801958e1ac/common/coco_caption/pycocoevalcap/spice/spice-1.0.jar -------------------------------------------------------------------------------- /common/coco_caption/pycocoevalcap/spice/spice.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import os 3 | import sys 4 | import subprocess 5 | import threading 6 | import json 7 | import numpy as np 8 | import ast 9 | import tempfile 10 | 11 | # Assumes spice.jar is in the same directory as spice.py. Change as needed. 12 | SPICE_JAR = 'spice-1.0.jar' 13 | TEMP_DIR = 'tmp' 14 | CACHE_DIR = 'cache' 15 | 16 | class Spice: 17 | """ 18 | Main Class to compute the SPICE metric 19 | """ 20 | 21 | def float_convert(self, obj): 22 | try: 23 | return float(obj) 24 | except: 25 | return np.nan 26 | 27 | def compute_score(self, gts, res): 28 | assert(sorted(gts.keys()) == sorted(res.keys())) 29 | imgIds = sorted(gts.keys()) 30 | 31 | # Prepare temp input file for the SPICE scorer 32 | input_data = [] 33 | for id in imgIds: 34 | hypo = res[id] 35 | ref = gts[id] 36 | 37 | # Sanity check. 
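# One test caption plus its list of reference captions per image id; these are packed into the JSON payload passed to spice-1.0.jar below.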
38 | assert(type(hypo) is list) 39 | assert(len(hypo) == 1) 40 | assert(type(ref) is list) 41 | assert(len(ref) >= 1) 42 | 43 | input_data.append({ 44 | "image_id" : id, 45 | "test" : hypo[0], 46 | "refs" : ref 47 | }) 48 | 49 | cwd = os.path.dirname(os.path.abspath(__file__)) 50 | temp_dir=os.path.join(cwd, TEMP_DIR) 51 | if not os.path.exists(temp_dir): 52 | os.makedirs(temp_dir) 53 | in_file = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir) 54 | json.dump(input_data, in_file, indent=2) 55 | in_file.close() 56 | 57 | # Start job 58 | out_file = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir) 59 | out_file.close() 60 | cache_dir=os.path.join(cwd, CACHE_DIR) 61 | if not os.path.exists(cache_dir): 62 | os.makedirs(cache_dir) 63 | spice_cmd = ['java', '-jar', '-Xmx8G', SPICE_JAR, in_file.name, 64 | '-cache', cache_dir, 65 | '-out', out_file.name, 66 | '-subset', 67 | '-silent' 68 | ] 69 | subprocess.check_call(spice_cmd, 70 | cwd=os.path.dirname(os.path.abspath(__file__))) 71 | 72 | # Read and process results 73 | with open(out_file.name) as data_file: 74 | results = json.load(data_file) 75 | os.remove(in_file.name) 76 | os.remove(out_file.name) 77 | 78 | imgId_to_scores = {} 79 | spice_scores = [] 80 | for item in results: 81 | imgId_to_scores[item['image_id']] = item['scores'] 82 | spice_scores.append(self.float_convert(item['scores']['All']['f'])) 83 | average_score = np.mean(np.array(spice_scores)) 84 | scores = [] 85 | for image_id in imgIds: 86 | # Convert none to NaN before saving scores over subcategories 87 | score_set = {} 88 | for category,score_tuple in imgId_to_scores[image_id].iteritems(): 89 | score_set[category] = {k: self.float_convert(v) for k, v in score_tuple.items()} 90 | scores.append(score_set) 91 | return average_score, scores 92 | 93 | def method(self): 94 | return "SPICE" 95 | 96 | 97 | -------------------------------------------------------------------------------- /common/coco_caption/pycocoevalcap/tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'hfang' 2 | -------------------------------------------------------------------------------- /common/coco_caption/pycocoevalcap/tokenizer/ptbtokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : ptbtokenizer.py 4 | # 5 | # Description : Do the PTB Tokenization and remove punctuations. 
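# Note: tokenization is delegated to the bundled stanford-corenlp-3.4.1.jar, invoked as a Java subprocess on a temporary file of sentences.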
6 | # 7 | # Creation Date : 29-12-2014 8 | # Last Modified : Thu Mar 19 09:53:35 2015 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | import os 12 | import sys 13 | import subprocess 14 | import tempfile 15 | import itertools 16 | 17 | # path to the stanford corenlp jar 18 | STANFORD_CORENLP_3_4_1_JAR = 'stanford-corenlp-3.4.1.jar' 19 | 20 | # punctuations to be removed from the sentences 21 | PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \ 22 | ".", "?", "!", ",", ":", "-", "--", "...", ";"] 23 | 24 | class PTBTokenizer: 25 | """Python wrapper of Stanford PTBTokenizer""" 26 | 27 | def tokenize(self, captions_for_image): 28 | cmd = ['java', '-cp', STANFORD_CORENLP_3_4_1_JAR, \ 29 | 'edu.stanford.nlp.process.PTBTokenizer', \ 30 | '-preserveLines', '-lowerCase'] 31 | 32 | # ====================================================== 33 | # prepare data for PTB Tokenizer 34 | # ====================================================== 35 | final_tokenized_captions_for_image = {} 36 | image_id = [k for k, v in captions_for_image.items() for _ in range(len(v))] 37 | sentences = '\n'.join([c['caption'].replace('\n', ' ') for k, v in captions_for_image.items() for c in v]) 38 | 39 | # ====================================================== 40 | # save sentences to temporary file 41 | # ====================================================== 42 | path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__)) 43 | tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname) 44 | tmp_file.write(sentences) 45 | tmp_file.close() 46 | 47 | # ====================================================== 48 | # tokenize sentence 49 | # ====================================================== 50 | cmd.append(os.path.basename(tmp_file.name)) 51 | p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dirname, \ 52 | stdout=subprocess.PIPE) 53 | token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0] 54 | lines = token_lines.split('\n') 55 | # remove temp file 56 | os.remove(tmp_file.name) 57 | 58 | # ====================================================== 59 | # create dictionary for tokenized captions 60 | # ====================================================== 61 | for k, line in zip(image_id, lines): 62 | if not k in final_tokenized_captions_for_image: 63 | final_tokenized_captions_for_image[k] = [] 64 | tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \ 65 | if w not in PUNCTUATIONS]) 66 | final_tokenized_captions_for_image[k].append(tokenized_caption) 67 | 68 | return final_tokenized_captions_for_image 69 | -------------------------------------------------------------------------------- /common/coco_caption/pycocoevalcap/tokenizer/stanford-corenlp-3.4.1.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiahuei/COMIC-Compact-Image-Captioning-with-Attention/73165e0aac2816e89732571814f978801958e1ac/common/coco_caption/pycocoevalcap/tokenizer/stanford-corenlp-3.4.1.jar -------------------------------------------------------------------------------- /common/coco_caption/pycocotools/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /common/configuration.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Jun 22 14:48:09 2017 4 | 5 | @author: 
jiahuei 6 | """ 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import os 13 | import cPickle as pickle 14 | from time import localtime, strftime 15 | from natural_sort import natural_keys 16 | 17 | 18 | class Config(object): 19 | """ Configuration object.""" 20 | def __init__(self, **kwargs): 21 | for key, value in sorted(kwargs.iteritems()): 22 | setattr(self, key, value) 23 | 24 | 25 | def save_config_to_file(self): 26 | params = sorted(self.__dict__.keys(), key=natural_keys) 27 | f_dump = ['%s = %s' % (k, self.__dict__[k]) for k in params] 28 | config_name = 'config___%s.txt' % strftime('%Y-%m-%d_%H-%M-%S', localtime()) 29 | with open(os.path.join(self.log_path, config_name), 'w') as f: 30 | f.write('\r\n'.join(f_dump)) 31 | # Save the dictionary instead of the object for maximum flexibility 32 | # Avoid this error: 33 | # https://stackoverflow.com/questions/27732354/unable-to-load-files-using-pickle-and-multiple-modules 34 | with open(os.path.join(self.log_path, 'config.pkl'), 'wb') as f: 35 | pickle.dump(self.__dict__, f, pickle.HIGHEST_PROTOCOL) 36 | 37 | 38 | def overwrite_safety_check(self, overwrite): 39 | """ Exits if log_path exists but `overwrite` is set to `False`.""" 40 | path_exists = os.path.exists(self.log_path) 41 | if path_exists: 42 | if not overwrite: 43 | print('\nINFO: log_path already exists. ' 44 | 'Set `overwrite` to True? Exiting now.') 45 | raise SystemExit 46 | else: print('\nINFO: log_path already exists. ' 47 | 'The directory will be overwritten.') 48 | else: 49 | print('\nINFO: log_path does not exist. ' 50 | 'The directory will be created.') 51 | #os.mkdir(self.log_path) 52 | os.makedirs(self.log_path) 53 | 54 | 55 | def load_config(config_filepath): 56 | with open(config_filepath, 'rb') as f: 57 | c_dict = pickle.load(f) 58 | config = Config(**c_dict) 59 | return config 60 | 61 | 62 | -------------------------------------------------------------------------------- /common/get_repo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Tue Jun 11 16:57:10 2019 5 | 6 | @author: jiahuei 7 | """ 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import os 13 | CURR_DIR = os.path.dirname(os.path.realpath(__file__)) 14 | import utils 15 | pjoin = os.path.join 16 | 17 | 18 | common = CURR_DIR 19 | scst = pjoin(CURR_DIR, 'scst') 20 | 21 | ## 22 | 23 | print('\nINFO: Fetching `tylin/coco-caption` @ commit 3a9afb2 ...') 24 | dest = common 25 | zip_path = utils.maybe_download_from_url( 26 | r'https://github.com/tylin/coco-caption/archive/3a9afb2682141a03e1cdc02b0df6770d2c884f6f.zip', 27 | dest) 28 | utils.extract_zip(zip_path) 29 | os.remove(zip_path) 30 | old_name = pjoin(dest, 'coco-caption-3a9afb2682141a03e1cdc02b0df6770d2c884f6f') 31 | new_name = pjoin(dest, 'coco_caption') 32 | os.rename(old_name, new_name) 33 | 34 | 35 | print('\nINFO: Fetching `ruotianluo/cider` @ commit 77dff32 ...') 36 | dest = scst 37 | zip_path = utils.maybe_download_from_url( 38 | r'https://github.com/ruotianluo/cider/archive/dbb3960165d86202ed3c417b412a000fc8e717f3.zip', 39 | dest) 40 | utils.extract_zip(zip_path) 41 | os.remove(zip_path) 42 | old_name = pjoin(dest, 'cider-dbb3960165d86202ed3c417b412a000fc8e717f3') 43 | new_name = pjoin(dest, 'cider_ruotianluo') 44 | os.rename(old_name, new_name) 45 | 46 | 47 | 48 
| -------------------------------------------------------------------------------- /common/inputs/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue Mar 28 18:12:55 2017 4 | 5 | @author: jiahuei 6 | """ 7 | 8 | -------------------------------------------------------------------------------- /common/inputs/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /common/inputs/preprocessing/cifar_preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides utilities to preprocess images in CIFAR-10. 16 | 17 | """ 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow as tf 24 | 25 | _PADDING = 4 26 | 27 | slim = tf.contrib.slim 28 | 29 | 30 | def preprocess_for_train(image, 31 | output_height, 32 | output_width, 33 | padding=_PADDING, 34 | add_image_summaries=True): 35 | """Preprocesses the given image for training. 36 | 37 | Note that the actual resizing scale is sampled from 38 | [`resize_size_min`, `resize_size_max`]. 39 | 40 | Args: 41 | image: A `Tensor` representing an image of arbitrary size. 42 | output_height: The height of the image after preprocessing. 43 | output_width: The width of the image after preprocessing. 44 | padding: The amound of padding before and after each dimension of the image. 45 | add_image_summaries: Enable image summaries. 46 | 47 | Returns: 48 | A preprocessed image. 49 | """ 50 | if add_image_summaries: 51 | tf.summary.image('image', tf.expand_dims(image, 0)) 52 | 53 | # Transform the image to floats. 54 | if image.dtype != tf.float32: 55 | image = tf.image.convert_image_dtype(image, dtype=tf.float32) 56 | 57 | if padding > 0: 58 | image = tf.pad(image, [[padding, padding], [padding, padding], [0, 0]]) 59 | # Randomly crop a [height, width] section of the image. 60 | distorted_image = tf.random_crop(image, 61 | [output_height, output_width, 3]) 62 | 63 | # Randomly flip the image horizontally. 64 | distorted_image = tf.image.random_flip_left_right(distorted_image) 65 | 66 | if add_image_summaries: 67 | tf.summary.image('distorted_image', tf.expand_dims(distorted_image, 0)) 68 | ''' 69 | # Because these operations are not commutative, consider randomizing 70 | # the order their operation. 
71 | distorted_image = tf.image.random_brightness(distorted_image, 72 | max_delta=63) 73 | distorted_image = tf.image.random_contrast(distorted_image, 74 | lower=0.2, upper=1.8) 75 | ''' 76 | distorted_image = tf.subtract(distorted_image, 0.5) 77 | distorted_image = tf.multiply(distorted_image, 2.0) 78 | return distorted_image 79 | 80 | 81 | def preprocess_for_eval(image, output_height, output_width, 82 | add_image_summaries=True): 83 | """Preprocesses the given image for evaluation. 84 | 85 | Args: 86 | image: A `Tensor` representing an image of arbitrary size. 87 | output_height: The height of the image after preprocessing. 88 | output_width: The width of the image after preprocessing. 89 | add_image_summaries: Enable image summaries. 90 | 91 | Returns: 92 | A preprocessed image. 93 | """ 94 | if add_image_summaries: 95 | tf.summary.image('image', tf.expand_dims(image, 0)) 96 | # Transform the image to floats. 97 | if image.dtype != tf.float32: 98 | image = tf.image.convert_image_dtype(image, dtype=tf.float32) 99 | 100 | # Resize and crop if needed. 101 | resized_image = tf.image.resize_image_with_crop_or_pad(image, 102 | output_width, 103 | output_height) 104 | if add_image_summaries: 105 | tf.summary.image('resized_image', tf.expand_dims(resized_image, 0)) 106 | 107 | image = tf.subtract(image, 0.5) 108 | image = tf.multiply(image, 2.0) 109 | return image 110 | 111 | 112 | def preprocess_image(image, output_height, output_width, is_training=False, 113 | add_image_summaries=True): 114 | """Preprocesses the given image. 115 | 116 | Args: 117 | image: A `Tensor` representing an image of arbitrary size. 118 | output_height: The height of the image after preprocessing. 119 | output_width: The width of the image after preprocessing. 120 | is_training: `True` if we're preprocessing the image for training and 121 | `False` otherwise. 122 | add_image_summaries: Enable image summaries. 123 | 124 | Returns: 125 | A preprocessed image. 126 | """ 127 | if is_training: 128 | return preprocess_for_train( 129 | image, output_height, output_width, 130 | add_image_summaries=add_image_summaries) 131 | else: 132 | return preprocess_for_eval( 133 | image, output_height, output_width, 134 | add_image_summaries=add_image_summaries) -------------------------------------------------------------------------------- /common/inputs/preprocessing/cifarnet_preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides utilities to preprocess images in CIFAR-10. 
16 | 17 | """ 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow as tf 24 | 25 | _PADDING = 4 26 | 27 | slim = tf.contrib.slim 28 | 29 | 30 | def preprocess_for_train(image, 31 | output_height, 32 | output_width, 33 | padding=_PADDING, 34 | add_image_summaries=True): 35 | """Preprocesses the given image for training. 36 | 37 | Note that the actual resizing scale is sampled from 38 | [`resize_size_min`, `resize_size_max`]. 39 | 40 | Args: 41 | image: A `Tensor` representing an image of arbitrary size. 42 | output_height: The height of the image after preprocessing. 43 | output_width: The width of the image after preprocessing. 44 | padding: The amound of padding before and after each dimension of the image. 45 | add_image_summaries: Enable image summaries. 46 | 47 | Returns: 48 | A preprocessed image. 49 | """ 50 | if add_image_summaries: 51 | tf.summary.image('image', tf.expand_dims(image, 0)) 52 | 53 | # Transform the image to floats. 54 | image = tf.to_float(image) 55 | if padding > 0: 56 | image = tf.pad(image, [[padding, padding], [padding, padding], [0, 0]]) 57 | # Randomly crop a [height, width] section of the image. 58 | distorted_image = tf.random_crop(image, 59 | [output_height, output_width, 3]) 60 | 61 | # Randomly flip the image horizontally. 62 | distorted_image = tf.image.random_flip_left_right(distorted_image) 63 | 64 | if add_image_summaries: 65 | tf.summary.image('distorted_image', tf.expand_dims(distorted_image, 0)) 66 | 67 | # Because these operations are not commutative, consider randomizing 68 | # the order their operation. 69 | distorted_image = tf.image.random_brightness(distorted_image, 70 | max_delta=63) 71 | distorted_image = tf.image.random_contrast(distorted_image, 72 | lower=0.2, upper=1.8) 73 | # Subtract off the mean and divide by the variance of the pixels. 74 | return tf.image.per_image_standardization(distorted_image) 75 | 76 | 77 | def preprocess_for_eval(image, output_height, output_width, 78 | add_image_summaries=True): 79 | """Preprocesses the given image for evaluation. 80 | 81 | Args: 82 | image: A `Tensor` representing an image of arbitrary size. 83 | output_height: The height of the image after preprocessing. 84 | output_width: The width of the image after preprocessing. 85 | add_image_summaries: Enable image summaries. 86 | 87 | Returns: 88 | A preprocessed image. 89 | """ 90 | if add_image_summaries: 91 | tf.summary.image('image', tf.expand_dims(image, 0)) 92 | # Transform the image to floats. 93 | image = tf.to_float(image) 94 | 95 | # Resize and crop if needed. 96 | resized_image = tf.image.resize_image_with_crop_or_pad(image, 97 | output_width, 98 | output_height) 99 | if add_image_summaries: 100 | tf.summary.image('resized_image', tf.expand_dims(resized_image, 0)) 101 | 102 | # Subtract off the mean and divide by the variance of the pixels. 103 | return tf.image.per_image_standardization(resized_image) 104 | 105 | 106 | def preprocess_image(image, output_height, output_width, is_training=False, 107 | add_image_summaries=True): 108 | """Preprocesses the given image. 109 | 110 | Args: 111 | image: A `Tensor` representing an image of arbitrary size. 112 | output_height: The height of the image after preprocessing. 113 | output_width: The width of the image after preprocessing. 114 | is_training: `True` if we're preprocessing the image for training and 115 | `False` otherwise. 116 | add_image_summaries: Enable image summaries. 
117 | 118 | Returns: 119 | A preprocessed image. 120 | """ 121 | if is_training: 122 | return preprocess_for_train( 123 | image, output_height, output_width, 124 | add_image_summaries=add_image_summaries) 125 | else: 126 | return preprocess_for_eval( 127 | image, output_height, output_width, 128 | add_image_summaries=add_image_summaries) 129 | -------------------------------------------------------------------------------- /common/inputs/preprocessing/preprocessing_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a factory for building various models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from inputs.preprocessing import cifarnet_preprocessing 24 | from inputs.preprocessing import cifar_preprocessing 25 | from inputs.preprocessing import inception_preprocessing_radix as inception_preprocessing 26 | from inputs.preprocessing import vgg_preprocessing 27 | 28 | slim = tf.contrib.slim 29 | 30 | 31 | def get_preprocessing(name, is_training=False): 32 | """Returns preprocessing_fn(image, height, width, **kwargs). 33 | 34 | Args: 35 | name: The name of the preprocessing function. 36 | is_training: `True` if the model is being used for training and `False` 37 | otherwise. 38 | 39 | Returns: 40 | preprocessing_fn: A function that preprocessing a single image (pre-batch). 41 | It has the following signature: 42 | image = preprocessing_fn(image, output_height, output_width, ...). 43 | 44 | Raises: 45 | ValueError: If Preprocessing `name` is not recognized. 
46 | """ 47 | preprocessing_fn_map = { 48 | 'cifarnet': cifarnet_preprocessing, 49 | 'inception': inception_preprocessing, 50 | 'inception_v1': inception_preprocessing, 51 | 'inception_v2': inception_preprocessing, 52 | 'inception_v3': inception_preprocessing, 53 | 'inception_v4': inception_preprocessing, 54 | 'inception_resnet_v1': inception_preprocessing, 55 | 'inception_resnet_v2': inception_preprocessing, 56 | 'mobilenet_v1': inception_preprocessing, 57 | 'mobilenet_v2': inception_preprocessing, 58 | 'mobilenet_v2_035': inception_preprocessing, 59 | 'mobilenet_v2_140': inception_preprocessing, 60 | 'nasnet_mobile': inception_preprocessing, 61 | 'nasnet_small': inception_preprocessing, 62 | 'nasnet_large': inception_preprocessing, 63 | 'resnet_v1_50': vgg_preprocessing, 64 | 'resnet_v1_101': vgg_preprocessing, 65 | 'resnet_v1_152': vgg_preprocessing, 66 | 'resnet_v1_200': vgg_preprocessing, 67 | 'resnet_v2_50': vgg_preprocessing, 68 | 'resnet_v2_101': vgg_preprocessing, 69 | 'resnet_v2_152': vgg_preprocessing, 70 | 'resnet_v2_200': vgg_preprocessing, 71 | 'vgg': vgg_preprocessing, 72 | 'vgg_a': vgg_preprocessing, 73 | 'vgg_16': vgg_preprocessing, 74 | 'vgg_19': vgg_preprocessing, 75 | } 76 | 77 | if name not in preprocessing_fn_map: 78 | raise ValueError('Preprocessing name [%s] was not recognized' % name) 79 | 80 | def preprocessing_fn(image, output_height, output_width, **kwargs): 81 | return preprocessing_fn_map[name].preprocess_image( 82 | image, output_height, output_width, is_training=is_training, **kwargs) 83 | 84 | return preprocessing_fn 85 | -------------------------------------------------------------------------------- /common/natural_sort.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Thu Jun 28 15:36:44 2018 5 | 6 | @author: jiahuei 7 | """ 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | 13 | import re 14 | 15 | def atoi(text): 16 | return int(text) if text.isdigit() else text 17 | 18 | def natural_keys(text): 19 | ''' 20 | alist.sort(key=natural_keys) sorts in human order 21 | http://nedbatchelder.com/blog/200712/human_sorting.html 22 | (See Toothy's implementation in the comments) 23 | ''' 24 | return [ atoi(c) for c in re.split('(\d+)', text) ] 25 | 26 | ''' 27 | for d in tqdm(sorted(dirs)): 28 | if os.path.isfile(d): continue 29 | model_files = [] 30 | compact_files = [] 31 | for f in sorted(os.listdir(d)): 32 | if 'model_compact-' in f: 33 | compact_files.append(pjoin(d, f)) 34 | elif 'model-' in f: 35 | model_files.append(pjoin(d, f)) 36 | else: 37 | pass 38 | compact_files.sort(key=natural_keys) 39 | model_files.sort(key=natural_keys) 40 | for f in compact_files[:-3]: 41 | os.remove(f) 42 | for f in model_files[:-3]: 43 | os.remove(f) 44 | ''' 45 | 46 | 47 | 48 | -------------------------------------------------------------------------------- /common/net_params.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Jan 9 23:21:36 2019 4 | 5 | @author: jiahuei 6 | 7 | Network parameters, preprocessing functions, etc. 
8 | 9 | """ 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | import os 15 | 16 | pjoin = os.path.join 17 | 18 | 19 | all_net_params = dict( 20 | vgg_16 = dict( 21 | name = 'vgg_16', 22 | ckpt_path = 'vgg_16.ckpt', 23 | url = 'http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz', 24 | ), 25 | resnet_v1_50 = dict( 26 | name = 'resnet_v1_50', 27 | ckpt_path = 'resnet_v1_50.ckpt', 28 | url = 'http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz', 29 | ), 30 | resnet_v1_101 = dict( 31 | name = 'resnet_v1_101', 32 | ckpt_path = 'resnet_v1_101.ckpt', 33 | url = 'http://download.tensorflow.org/models/resnet_v1_101_2016_08_28.tar.gz', 34 | ), 35 | resnet_v1_152 = dict( 36 | name = 'resnet_v1_152', 37 | ckpt_path = 'resnet_v1_152.ckpt', 38 | url = 'http://download.tensorflow.org/models/resnet_v1_152_2016_08_28.tar.gz', 39 | ), 40 | resnet_v2_50 = dict( 41 | name = 'resnet_v2_50', 42 | ckpt_path = 'resnet_v2_50.ckpt', 43 | url = 'http://download.tensorflow.org/models/resnet_v2_50_2017_04_14.tar.gz', 44 | ), 45 | resnet_v2_101 = dict( 46 | name = 'resnet_v2_101', 47 | ckpt_path = 'resnet_v2_101.ckpt', 48 | url = 'http://download.tensorflow.org/models/resnet_v2_101_2017_04_14.tar.gz', 49 | ), 50 | resnet_v2_152 = dict( 51 | name = 'resnet_v2_152', 52 | ckpt_path = 'resnet_v2_152.ckpt', 53 | url = 'http://download.tensorflow.org/models/resnet_v2_152_2017_04_14.tar.gz', 54 | ), 55 | inception_v1 = dict( 56 | name = 'inception_v1', 57 | ckpt_path = 'inception_v1.ckpt', 58 | url = 'http://download.tensorflow.org/models/inception_v1_2016_08_28.tar.gz', 59 | ), 60 | inception_v2 = dict( 61 | name = 'inception_v2', 62 | ckpt_path = 'inception_v2.ckpt', 63 | url = 'http://download.tensorflow.org/models/inception_v2_2016_08_28.tar.gz', 64 | ), 65 | inception_v3 = dict( 66 | name = 'inception_v3', 67 | ckpt_path = 'inception_v3.ckpt', 68 | url = 'http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz', 69 | ), 70 | inception_v4 = dict( 71 | name = 'inception_v4', 72 | ckpt_path = 'inception_v4.ckpt', 73 | url = 'http://download.tensorflow.org/models/inception_v4_2016_09_09.tar.gz', 74 | ), 75 | inception_resnet_v2 = dict( 76 | name = 'inception_resnet_v2', 77 | ckpt_path = 'inception_resnet_v2_2016_08_30.ckpt', 78 | url = 'http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz', 79 | ), 80 | mobilenet_v2 = dict( 81 | name = 'mobilenet_v2', 82 | ckpt_path = 'mobilenet_v2_1.0_224.ckpt', 83 | url = 'https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_224.tgz', 84 | ), 85 | mobilenet_v2_140 = dict( 86 | name = 'mobilenet_v2_140', 87 | ckpt_path = 'mobilenet_v2_1.4_224.ckpt', 88 | url = 'https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz', 89 | ), 90 | ) 91 | 92 | 93 | def get_net_params(net_name, ckpt_dir_or_file=''): 94 | net_params = all_net_params[net_name] 95 | ckpt_name = net_params['ckpt_path'] 96 | 97 | if ckpt_dir_or_file is None or ckpt_dir_or_file == '': 98 | base_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 99 | ckpt_dir_or_file = pjoin(base_dir, 'ckpt', ckpt_name) 100 | else: 101 | if os.path.isdir(ckpt_dir_or_file): 102 | ckpt_dir_or_file = pjoin(ckpt_dir_or_file, ckpt_name) 103 | if os.path.isfile(ckpt_dir_or_file): 104 | assert os.path.basename(ckpt_dir_or_file) == ckpt_name 105 | net_params['ckpt_path'] = ckpt_dir_or_file 106 | return net_params 107 | 108 | 109 | 110 | 
-------------------------------------------------------------------------------- /common/nets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /common/nets/alexnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a model definition for AlexNet. 16 | 17 | This work was first described in: 18 | ImageNet Classification with Deep Convolutional Neural Networks 19 | Alex Krizhevsky, Ilya Sutskever and Geoffrey E. Hinton 20 | 21 | and later refined in: 22 | One weird trick for parallelizing convolutional neural networks 23 | Alex Krizhevsky, 2014 24 | 25 | Here we provide the implementation proposed in "One weird trick" and not 26 | "ImageNet Classification", as per the paper, the LRN layers have been removed. 27 | 28 | Usage: 29 | with slim.arg_scope(alexnet.alexnet_v2_arg_scope()): 30 | outputs, end_points = alexnet.alexnet_v2(inputs) 31 | 32 | @@alexnet_v2 33 | """ 34 | 35 | from __future__ import absolute_import 36 | from __future__ import division 37 | from __future__ import print_function 38 | 39 | import tensorflow as tf 40 | 41 | slim = tf.contrib.slim 42 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev) 43 | 44 | 45 | def alexnet_v2_arg_scope(weight_decay=0.0005): 46 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 47 | activation_fn=tf.nn.relu, 48 | biases_initializer=tf.constant_initializer(0.1), 49 | weights_regularizer=slim.l2_regularizer(weight_decay)): 50 | with slim.arg_scope([slim.conv2d], padding='SAME'): 51 | with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc: 52 | return arg_sc 53 | 54 | 55 | def alexnet_v2(inputs, 56 | num_classes=1000, 57 | is_training=True, 58 | dropout_keep_prob=0.5, 59 | spatial_squeeze=True, 60 | scope='alexnet_v2', 61 | global_pool=False): 62 | """AlexNet version 2. 63 | 64 | Described in: http://arxiv.org/pdf/1404.5997v2.pdf 65 | Parameters from: 66 | github.com/akrizhevsky/cuda-convnet2/blob/master/layers/ 67 | layers-imagenet-1gpu.cfg 68 | 69 | Note: All the fully_connected layers have been transformed to conv2d layers. 70 | To use in classification mode, resize input to 224x224 or set 71 | global_pool=True. To use in fully convolutional mode, set 72 | spatial_squeeze to false. 73 | The LRN layers have been removed and change the initializers from 74 | random_normal_initializer to xavier_initializer. 75 | 76 | Args: 77 | inputs: a tensor of size [batch_size, height, width, channels]. 78 | num_classes: the number of predicted classes. If 0 or None, the logits layer 79 | is omitted and the input features to the logits layer are returned instead. 
80 | is_training: whether or not the model is being trained. 81 | dropout_keep_prob: the probability that activations are kept in the dropout 82 | layers during training. 83 | spatial_squeeze: whether or not should squeeze the spatial dimensions of the 84 | logits. Useful to remove unnecessary dimensions for classification. 85 | scope: Optional scope for the variables. 86 | global_pool: Optional boolean flag. If True, the input to the classification 87 | layer is avgpooled to size 1x1, for any input size. (This is not part 88 | of the original AlexNet.) 89 | 90 | Returns: 91 | net: the output of the logits layer (if num_classes is a non-zero integer), 92 | or the non-dropped-out input to the logits layer (if num_classes is 0 93 | or None). 94 | end_points: a dict of tensors with intermediate activations. 95 | """ 96 | with tf.variable_scope(scope, 'alexnet_v2', [inputs]) as sc: 97 | end_points_collection = sc.original_name_scope + '_end_points' 98 | # Collect outputs for conv2d, fully_connected and max_pool2d. 99 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], 100 | outputs_collections=[end_points_collection]): 101 | net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID', 102 | scope='conv1') 103 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool1') 104 | net = slim.conv2d(net, 192, [5, 5], scope='conv2') 105 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool2') 106 | net = slim.conv2d(net, 384, [3, 3], scope='conv3') 107 | net = slim.conv2d(net, 384, [3, 3], scope='conv4') 108 | net = slim.conv2d(net, 256, [3, 3], scope='conv5') 109 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool5') 110 | 111 | # Use conv2d instead of fully_connected layers. 112 | with slim.arg_scope([slim.conv2d], 113 | weights_initializer=trunc_normal(0.005), 114 | biases_initializer=tf.constant_initializer(0.1)): 115 | net = slim.conv2d(net, 4096, [5, 5], padding='VALID', 116 | scope='fc6') 117 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 118 | scope='dropout6') 119 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 120 | # Convert end_points_collection into a end_point dict. 121 | end_points = slim.utils.convert_collection_to_dict( 122 | end_points_collection) 123 | if global_pool: 124 | net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool') 125 | end_points['global_pool'] = net 126 | if num_classes: 127 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 128 | scope='dropout7') 129 | net = slim.conv2d(net, num_classes, [1, 1], 130 | activation_fn=None, 131 | normalizer_fn=None, 132 | biases_initializer=tf.zeros_initializer(), 133 | scope='fc8') 134 | if spatial_squeeze: 135 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 136 | end_points[sc.name + '/fc8'] = net 137 | return net, end_points 138 | alexnet_v2.default_image_size = 224 139 | -------------------------------------------------------------------------------- /common/nets/alexnet_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for slim.nets.alexnet.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import tensorflow as tf 21 | 22 | from nets import alexnet 23 | 24 | slim = tf.contrib.slim 25 | 26 | 27 | class AlexnetV2Test(tf.test.TestCase): 28 | 29 | def testBuild(self): 30 | batch_size = 5 31 | height, width = 224, 224 32 | num_classes = 1000 33 | with self.test_session(): 34 | inputs = tf.random_uniform((batch_size, height, width, 3)) 35 | logits, _ = alexnet.alexnet_v2(inputs, num_classes) 36 | self.assertEquals(logits.op.name, 'alexnet_v2/fc8/squeezed') 37 | self.assertListEqual(logits.get_shape().as_list(), 38 | [batch_size, num_classes]) 39 | 40 | def testFullyConvolutional(self): 41 | batch_size = 1 42 | height, width = 300, 400 43 | num_classes = 1000 44 | with self.test_session(): 45 | inputs = tf.random_uniform((batch_size, height, width, 3)) 46 | logits, _ = alexnet.alexnet_v2(inputs, num_classes, spatial_squeeze=False) 47 | self.assertEquals(logits.op.name, 'alexnet_v2/fc8/BiasAdd') 48 | self.assertListEqual(logits.get_shape().as_list(), 49 | [batch_size, 4, 7, num_classes]) 50 | 51 | def testGlobalPool(self): 52 | batch_size = 1 53 | height, width = 256, 256 54 | num_classes = 1000 55 | with self.test_session(): 56 | inputs = tf.random_uniform((batch_size, height, width, 3)) 57 | logits, _ = alexnet.alexnet_v2(inputs, num_classes, spatial_squeeze=False, 58 | global_pool=True) 59 | self.assertEquals(logits.op.name, 'alexnet_v2/fc8/BiasAdd') 60 | self.assertListEqual(logits.get_shape().as_list(), 61 | [batch_size, 1, 1, num_classes]) 62 | 63 | def testEndPoints(self): 64 | batch_size = 5 65 | height, width = 224, 224 66 | num_classes = 1000 67 | with self.test_session(): 68 | inputs = tf.random_uniform((batch_size, height, width, 3)) 69 | _, end_points = alexnet.alexnet_v2(inputs, num_classes) 70 | expected_names = ['alexnet_v2/conv1', 71 | 'alexnet_v2/pool1', 72 | 'alexnet_v2/conv2', 73 | 'alexnet_v2/pool2', 74 | 'alexnet_v2/conv3', 75 | 'alexnet_v2/conv4', 76 | 'alexnet_v2/conv5', 77 | 'alexnet_v2/pool5', 78 | 'alexnet_v2/fc6', 79 | 'alexnet_v2/fc7', 80 | 'alexnet_v2/fc8' 81 | ] 82 | self.assertSetEqual(set(end_points.keys()), set(expected_names)) 83 | 84 | def testNoClasses(self): 85 | batch_size = 5 86 | height, width = 224, 224 87 | num_classes = None 88 | with self.test_session(): 89 | inputs = tf.random_uniform((batch_size, height, width, 3)) 90 | net, end_points = alexnet.alexnet_v2(inputs, num_classes) 91 | expected_names = ['alexnet_v2/conv1', 92 | 'alexnet_v2/pool1', 93 | 'alexnet_v2/conv2', 94 | 'alexnet_v2/pool2', 95 | 'alexnet_v2/conv3', 96 | 'alexnet_v2/conv4', 97 | 'alexnet_v2/conv5', 98 | 'alexnet_v2/pool5', 99 | 'alexnet_v2/fc6', 100 | 'alexnet_v2/fc7' 101 | ] 102 | self.assertSetEqual(set(end_points.keys()), set(expected_names)) 103 | self.assertTrue(net.op.name.startswith('alexnet_v2/fc7')) 104 | self.assertListEqual(net.get_shape().as_list(), 105 | 
[batch_size, 1, 1, 4096]) 106 | 107 | def testModelVariables(self): 108 | batch_size = 5 109 | height, width = 224, 224 110 | num_classes = 1000 111 | with self.test_session(): 112 | inputs = tf.random_uniform((batch_size, height, width, 3)) 113 | alexnet.alexnet_v2(inputs, num_classes) 114 | expected_names = ['alexnet_v2/conv1/weights', 115 | 'alexnet_v2/conv1/biases', 116 | 'alexnet_v2/conv2/weights', 117 | 'alexnet_v2/conv2/biases', 118 | 'alexnet_v2/conv3/weights', 119 | 'alexnet_v2/conv3/biases', 120 | 'alexnet_v2/conv4/weights', 121 | 'alexnet_v2/conv4/biases', 122 | 'alexnet_v2/conv5/weights', 123 | 'alexnet_v2/conv5/biases', 124 | 'alexnet_v2/fc6/weights', 125 | 'alexnet_v2/fc6/biases', 126 | 'alexnet_v2/fc7/weights', 127 | 'alexnet_v2/fc7/biases', 128 | 'alexnet_v2/fc8/weights', 129 | 'alexnet_v2/fc8/biases', 130 | ] 131 | model_variables = [v.op.name for v in slim.get_model_variables()] 132 | self.assertSetEqual(set(model_variables), set(expected_names)) 133 | 134 | def testEvaluation(self): 135 | batch_size = 2 136 | height, width = 224, 224 137 | num_classes = 1000 138 | with self.test_session(): 139 | eval_inputs = tf.random_uniform((batch_size, height, width, 3)) 140 | logits, _ = alexnet.alexnet_v2(eval_inputs, is_training=False) 141 | self.assertListEqual(logits.get_shape().as_list(), 142 | [batch_size, num_classes]) 143 | predictions = tf.argmax(logits, 1) 144 | self.assertListEqual(predictions.get_shape().as_list(), [batch_size]) 145 | 146 | def testTrainEvalWithReuse(self): 147 | train_batch_size = 2 148 | eval_batch_size = 1 149 | train_height, train_width = 224, 224 150 | eval_height, eval_width = 300, 400 151 | num_classes = 1000 152 | with self.test_session(): 153 | train_inputs = tf.random_uniform( 154 | (train_batch_size, train_height, train_width, 3)) 155 | logits, _ = alexnet.alexnet_v2(train_inputs) 156 | self.assertListEqual(logits.get_shape().as_list(), 157 | [train_batch_size, num_classes]) 158 | tf.get_variable_scope().reuse_variables() 159 | eval_inputs = tf.random_uniform( 160 | (eval_batch_size, eval_height, eval_width, 3)) 161 | logits, _ = alexnet.alexnet_v2(eval_inputs, is_training=False, 162 | spatial_squeeze=False) 163 | self.assertListEqual(logits.get_shape().as_list(), 164 | [eval_batch_size, 4, 7, num_classes]) 165 | logits = tf.reduce_mean(logits, [1, 2]) 166 | predictions = tf.argmax(logits, 1) 167 | self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size]) 168 | 169 | def testForward(self): 170 | batch_size = 1 171 | height, width = 224, 224 172 | with self.test_session() as sess: 173 | inputs = tf.random_uniform((batch_size, height, width, 3)) 174 | logits, _ = alexnet.alexnet_v2(inputs) 175 | sess.run(tf.global_variables_initializer()) 176 | output = sess.run(logits) 177 | self.assertTrue(output.any()) 178 | 179 | if __name__ == '__main__': 180 | tf.test.main() 181 | -------------------------------------------------------------------------------- /common/nets/cifarnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a variant of the CIFAR-10 model definition.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | slim = tf.contrib.slim 24 | 25 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(stddev=stddev) 26 | 27 | 28 | def cifarnet(images, num_classes=10, is_training=False, 29 | dropout_keep_prob=0.5, 30 | prediction_fn=slim.softmax, 31 | scope='CifarNet'): 32 | """Creates a variant of the CifarNet model. 33 | 34 | Note that since the output is a set of 'logits', the values fall in the 35 | interval of (-infinity, infinity). Consequently, to convert the outputs to a 36 | probability distribution over the characters, one will need to convert them 37 | using the softmax function: 38 | 39 | logits = cifarnet.cifarnet(images, is_training=False) 40 | probabilities = tf.nn.softmax(logits) 41 | predictions = tf.argmax(logits, 1) 42 | 43 | Args: 44 | images: A batch of `Tensors` of size [batch_size, height, width, channels]. 45 | num_classes: the number of classes in the dataset. If 0 or None, the logits 46 | layer is omitted and the input features to the logits layer are returned 47 | instead. 48 | is_training: specifies whether or not we're currently training the model. 49 | This variable will determine the behaviour of the dropout layer. 50 | dropout_keep_prob: the percentage of activation values that are retained. 51 | prediction_fn: a function to get predictions out of logits. 52 | scope: Optional variable_scope. 53 | 54 | Returns: 55 | net: a 2D Tensor with the logits (pre-softmax activations) if num_classes 56 | is a non-zero integer, or the input to the logits layer if num_classes 57 | is 0 or None. 58 | end_points: a dictionary from components of the network to the corresponding 59 | activation. 
60 | """ 61 | end_points = {} 62 | 63 | with tf.variable_scope(scope, 'CifarNet', [images]): 64 | net = slim.conv2d(images, 64, [5, 5], scope='conv1') 65 | end_points['conv1'] = net 66 | net = slim.max_pool2d(net, [2, 2], 2, scope='pool1') 67 | end_points['pool1'] = net 68 | net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm1') 69 | net = slim.conv2d(net, 64, [5, 5], scope='conv2') 70 | end_points['conv2'] = net 71 | net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm2') 72 | net = slim.max_pool2d(net, [2, 2], 2, scope='pool2') 73 | end_points['pool2'] = net 74 | net = slim.flatten(net) 75 | end_points['Flatten'] = net 76 | net = slim.fully_connected(net, 384, scope='fc3') 77 | end_points['fc3'] = net 78 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 79 | scope='dropout3') 80 | net = slim.fully_connected(net, 192, scope='fc4') 81 | end_points['fc4'] = net 82 | if not num_classes: 83 | return net, end_points 84 | logits = slim.fully_connected(net, num_classes, 85 | biases_initializer=tf.zeros_initializer(), 86 | weights_initializer=trunc_normal(1/192.0), 87 | weights_regularizer=None, 88 | activation_fn=None, 89 | scope='logits') 90 | 91 | end_points['Logits'] = logits 92 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions') 93 | 94 | return logits, end_points 95 | cifarnet.default_image_size = 32 96 | 97 | 98 | def cifarnet_arg_scope(weight_decay=0.004): 99 | """Defines the default cifarnet argument scope. 100 | 101 | Args: 102 | weight_decay: The weight decay to use for regularizing the model. 103 | 104 | Returns: 105 | An `arg_scope` to use for the inception v3 model. 106 | """ 107 | with slim.arg_scope( 108 | [slim.conv2d], 109 | weights_initializer=tf.truncated_normal_initializer(stddev=5e-2), 110 | activation_fn=tf.nn.relu): 111 | with slim.arg_scope( 112 | [slim.fully_connected], 113 | biases_initializer=tf.constant_initializer(0.1), 114 | weights_initializer=trunc_normal(0.04), 115 | weights_regularizer=slim.l2_regularizer(weight_decay), 116 | activation_fn=tf.nn.relu) as sc: 117 | return sc 118 | -------------------------------------------------------------------------------- /common/nets/cyclegan_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for tensorflow.contrib.slim.nets.cyclegan.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from nets import cyclegan 24 | 25 | 26 | # TODO(joelshor): Add a test to check generator endpoints. 
27 | class CycleganTest(tf.test.TestCase): 28 | 29 | def test_generator_inference(self): 30 | """Check one inference step.""" 31 | img_batch = tf.zeros([2, 32, 32, 3]) 32 | model_output, _ = cyclegan.cyclegan_generator_resnet(img_batch) 33 | with self.test_session() as sess: 34 | sess.run(tf.global_variables_initializer()) 35 | sess.run(model_output) 36 | 37 | def _test_generator_graph_helper(self, shape): 38 | """Check that generator can take small and non-square inputs.""" 39 | output_imgs, _ = cyclegan.cyclegan_generator_resnet(tf.ones(shape)) 40 | self.assertAllEqual(shape, output_imgs.shape.as_list()) 41 | 42 | def test_generator_graph_small(self): 43 | self._test_generator_graph_helper([4, 32, 32, 3]) 44 | 45 | def test_generator_graph_medium(self): 46 | self._test_generator_graph_helper([3, 128, 128, 3]) 47 | 48 | def test_generator_graph_nonsquare(self): 49 | self._test_generator_graph_helper([2, 80, 400, 3]) 50 | 51 | def test_generator_unknown_batch_dim(self): 52 | """Check that generator can take unknown batch dimension inputs.""" 53 | img = tf.placeholder(tf.float32, shape=[None, 32, None, 3]) 54 | output_imgs, _ = cyclegan.cyclegan_generator_resnet(img) 55 | 56 | self.assertAllEqual([None, 32, None, 3], output_imgs.shape.as_list()) 57 | 58 | def _input_and_output_same_shape_helper(self, kernel_size): 59 | img_batch = tf.placeholder(tf.float32, shape=[None, 32, 32, 3]) 60 | output_img_batch, _ = cyclegan.cyclegan_generator_resnet( 61 | img_batch, kernel_size=kernel_size) 62 | 63 | self.assertAllEqual(img_batch.shape.as_list(), 64 | output_img_batch.shape.as_list()) 65 | 66 | def input_and_output_same_shape_kernel3(self): 67 | self._input_and_output_same_shape_helper(3) 68 | 69 | def input_and_output_same_shape_kernel4(self): 70 | self._input_and_output_same_shape_helper(4) 71 | 72 | def input_and_output_same_shape_kernel5(self): 73 | self._input_and_output_same_shape_helper(5) 74 | 75 | def input_and_output_same_shape_kernel6(self): 76 | self._input_and_output_same_shape_helper(6) 77 | 78 | def _error_if_height_not_multiple_of_four_helper(self, height): 79 | self.assertRaisesRegexp( 80 | ValueError, 81 | 'The input height must be a multiple of 4.', 82 | cyclegan.cyclegan_generator_resnet, 83 | tf.placeholder(tf.float32, shape=[None, height, 32, 3])) 84 | 85 | def test_error_if_height_not_multiple_of_four_height29(self): 86 | self._error_if_height_not_multiple_of_four_helper(29) 87 | 88 | def test_error_if_height_not_multiple_of_four_height30(self): 89 | self._error_if_height_not_multiple_of_four_helper(30) 90 | 91 | def test_error_if_height_not_multiple_of_four_height31(self): 92 | self._error_if_height_not_multiple_of_four_helper(31) 93 | 94 | def _error_if_width_not_multiple_of_four_helper(self, width): 95 | self.assertRaisesRegexp( 96 | ValueError, 97 | 'The input width must be a multiple of 4.', 98 | cyclegan.cyclegan_generator_resnet, 99 | tf.placeholder(tf.float32, shape=[None, 32, width, 3])) 100 | 101 | def test_error_if_width_not_multiple_of_four_width29(self): 102 | self._error_if_width_not_multiple_of_four_helper(29) 103 | 104 | def test_error_if_width_not_multiple_of_four_width30(self): 105 | self._error_if_width_not_multiple_of_four_helper(30) 106 | 107 | def test_error_if_width_not_multiple_of_four_width31(self): 108 | self._error_if_width_not_multiple_of_four_helper(31) 109 | 110 | 111 | if __name__ == '__main__': 112 | tf.test.main() 113 | -------------------------------------------------------------------------------- /common/nets/dcgan_test.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for dcgan.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from six.moves import xrange # pylint: disable=redefined-builtin 22 | import tensorflow as tf 23 | 24 | from nets import dcgan 25 | 26 | 27 | class DCGANTest(tf.test.TestCase): 28 | 29 | def test_generator_run(self): 30 | tf.set_random_seed(1234) 31 | noise = tf.random_normal([100, 64]) 32 | image, _ = dcgan.generator(noise) 33 | with self.test_session() as sess: 34 | sess.run(tf.global_variables_initializer()) 35 | image.eval() 36 | 37 | def test_generator_graph(self): 38 | tf.set_random_seed(1234) 39 | # Check graph construction for a number of image size/depths and batch 40 | # sizes. 41 | for i, batch_size in zip(xrange(3, 7), xrange(3, 8)): 42 | tf.reset_default_graph() 43 | final_size = 2 ** i 44 | noise = tf.random_normal([batch_size, 64]) 45 | image, end_points = dcgan.generator( 46 | noise, 47 | depth=32, 48 | final_size=final_size) 49 | 50 | self.assertAllEqual([batch_size, final_size, final_size, 3], 51 | image.shape.as_list()) 52 | 53 | expected_names = ['deconv%i' % j for j in xrange(1, i)] + ['logits'] 54 | self.assertSetEqual(set(expected_names), set(end_points.keys())) 55 | 56 | # Check layer depths. 57 | for j in range(1, i): 58 | layer = end_points['deconv%i' % j] 59 | self.assertEqual(32 * 2**(i-j-1), layer.get_shape().as_list()[-1]) 60 | 61 | def test_generator_invalid_input(self): 62 | wrong_dim_input = tf.zeros([5, 32, 32]) 63 | with self.assertRaises(ValueError): 64 | dcgan.generator(wrong_dim_input) 65 | 66 | correct_input = tf.zeros([3, 2]) 67 | with self.assertRaisesRegexp(ValueError, 'must be a power of 2'): 68 | dcgan.generator(correct_input, final_size=30) 69 | 70 | with self.assertRaisesRegexp(ValueError, 'must be greater than 8'): 71 | dcgan.generator(correct_input, final_size=4) 72 | 73 | def test_discriminator_run(self): 74 | image = tf.random_uniform([5, 32, 32, 3], -1, 1) 75 | output, _ = dcgan.discriminator(image) 76 | with self.test_session() as sess: 77 | sess.run(tf.global_variables_initializer()) 78 | output.eval() 79 | 80 | def test_discriminator_graph(self): 81 | # Check graph construction for a number of image size/depths and batch 82 | # sizes. 
83 | for i, batch_size in zip(xrange(1, 6), xrange(3, 8)): 84 | tf.reset_default_graph() 85 | img_w = 2 ** i 86 | image = tf.random_uniform([batch_size, img_w, img_w, 3], -1, 1) 87 | output, end_points = dcgan.discriminator( 88 | image, 89 | depth=32) 90 | 91 | self.assertAllEqual([batch_size, 1], output.get_shape().as_list()) 92 | 93 | expected_names = ['conv%i' % j for j in xrange(1, i+1)] + ['logits'] 94 | self.assertSetEqual(set(expected_names), set(end_points.keys())) 95 | 96 | # Check layer depths. 97 | for j in range(1, i+1): 98 | layer = end_points['conv%i' % j] 99 | self.assertEqual(32 * 2**(j-1), layer.get_shape().as_list()[-1]) 100 | 101 | def test_discriminator_invalid_input(self): 102 | wrong_dim_img = tf.zeros([5, 32, 32]) 103 | with self.assertRaises(ValueError): 104 | dcgan.discriminator(wrong_dim_img) 105 | 106 | spatially_undefined_shape = tf.placeholder(tf.float32, [5, 32, None, 3]) 107 | with self.assertRaises(ValueError): 108 | dcgan.discriminator(spatially_undefined_shape) 109 | 110 | not_square = tf.zeros([5, 32, 16, 3]) 111 | with self.assertRaisesRegexp(ValueError, 'not have equal width and height'): 112 | dcgan.discriminator(not_square) 113 | 114 | not_power_2 = tf.zeros([5, 30, 30, 3]) 115 | with self.assertRaisesRegexp(ValueError, 'not a power of 2'): 116 | dcgan.discriminator(not_power_2) 117 | 118 | 119 | if __name__ == '__main__': 120 | tf.test.main() 121 | -------------------------------------------------------------------------------- /common/nets/i3d.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains the definition for Inflated 3D Inception V1 (I3D). 16 | 17 | The network architecture is proposed by: 18 | Joao Carreira and Andrew Zisserman, 19 | Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset. 20 | https://arxiv.org/abs/1705.07750 21 | """ 22 | 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | import tensorflow as tf 28 | 29 | from nets import i3d_utils 30 | from nets import s3dg 31 | 32 | slim = tf.contrib.slim 33 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev) 34 | conv3d_spatiotemporal = i3d_utils.conv3d_spatiotemporal 35 | 36 | 37 | def i3d_arg_scope(weight_decay=1e-7, 38 | batch_norm_decay=0.999, 39 | batch_norm_epsilon=0.001, 40 | use_renorm=False, 41 | separable_conv3d=False): 42 | """Defines default arg_scope for I3D. 43 | 44 | Args: 45 | weight_decay: The weight decay to use for regularizing the model. 46 | batch_norm_decay: Decay for batch norm moving average. 47 | batch_norm_epsilon: Small float added to variance to avoid dividing by zero 48 | in batch norm. 49 | use_renorm: Whether to use batch renormalization or not. 
50 | separable_conv3d: Whether to use separable 3d Convs. 51 | 52 | Returns: 53 | sc: An arg_scope to use for the models. 54 | """ 55 | batch_norm_params = { 56 | # Decay for the moving averages. 57 | 'decay': batch_norm_decay, 58 | # epsilon to prevent 0s in variance. 59 | 'epsilon': batch_norm_epsilon, 60 | # Turns off fused batch norm. 61 | 'fused': False, 62 | 'renorm': use_renorm, 63 | # collection containing the moving mean and moving variance. 64 | 'variables_collections': { 65 | 'beta': None, 66 | 'gamma': None, 67 | 'moving_mean': ['moving_vars'], 68 | 'moving_variance': ['moving_vars'], 69 | } 70 | } 71 | 72 | with slim.arg_scope( 73 | [slim.conv3d, conv3d_spatiotemporal], 74 | weights_regularizer=slim.l2_regularizer(weight_decay), 75 | activation_fn=tf.nn.relu, 76 | normalizer_fn=slim.batch_norm, 77 | normalizer_params=batch_norm_params): 78 | with slim.arg_scope( 79 | [conv3d_spatiotemporal], separable=separable_conv3d) as sc: 80 | return sc 81 | 82 | 83 | def i3d_base(inputs, final_endpoint='Mixed_5c', 84 | scope='InceptionV1'): 85 | """Defines the I3D base architecture. 86 | 87 | Note that we use the names as defined in Inception V1 to facilitate checkpoint 88 | conversion from an image-trained Inception V1 checkpoint to I3D checkpoint. 89 | 90 | Args: 91 | inputs: A 5-D float tensor of size [batch_size, num_frames, height, width, 92 | channels]. 93 | final_endpoint: Specifies the endpoint to construct the network up to. It 94 | can be one of ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1', 95 | 'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c', 96 | 'MaxPool_4a_3x3', 'Mixed_4b', 'Mixed_4c', 'Mixed_4d', 'Mixed_4e', 97 | 'Mixed_4f', 'MaxPool_5a_2x2', 'Mixed_5b', 'Mixed_5c'] 98 | scope: Optional variable_scope. 99 | 100 | Returns: 101 | A dictionary from components of the network to the corresponding activation. 102 | 103 | Raises: 104 | ValueError: if final_endpoint is not set to one of the predefined values. 105 | """ 106 | 107 | return s3dg.s3dg_base( 108 | inputs, 109 | first_temporal_kernel_size=7, 110 | temporal_conv_startat='Conv2d_2c_3x3', 111 | gating_startat=None, 112 | final_endpoint=final_endpoint, 113 | min_depth=16, 114 | depth_multiplier=1.0, 115 | data_format='NDHWC', 116 | scope=scope) 117 | 118 | 119 | def i3d(inputs, 120 | num_classes=1000, 121 | dropout_keep_prob=0.8, 122 | is_training=True, 123 | prediction_fn=slim.softmax, 124 | spatial_squeeze=True, 125 | reuse=None, 126 | scope='InceptionV1'): 127 | """Defines the I3D architecture. 128 | 129 | The default image size used to train this network is 224x224. 130 | 131 | Args: 132 | inputs: A 5-D float tensor of size [batch_size, num_frames, height, width, 133 | channels]. 134 | num_classes: number of predicted classes. 135 | dropout_keep_prob: the percentage of activation values that are retained. 136 | is_training: whether is training or not. 137 | prediction_fn: a function to get predictions out of logits. 138 | spatial_squeeze: if True, logits is of shape is [B, C], if false logits is 139 | of shape [B, 1, 1, C], where B is batch_size and C is number of classes. 140 | reuse: whether or not the network and its variables should be reused. To be 141 | able to reuse 'scope' must be given. 142 | scope: Optional variable_scope. 143 | 144 | Returns: 145 | logits: the pre-softmax activations, a tensor of size 146 | [batch_size, num_classes] 147 | end_points: a dictionary from components of the network to the corresponding 148 | activation. 
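    Illustrative usage sketch (mirrors the calls exercised in i3d_test.py):
      inputs = tf.random_uniform((batch_size, num_frames, 224, 224, 3))
      logits, end_points = i3d(inputs, num_classes=1000, is_training=False)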
149 | """ 150 | # Final pooling and prediction 151 | with tf.variable_scope( 152 | scope, 'InceptionV1', [inputs, num_classes], reuse=reuse) as scope: 153 | with slim.arg_scope( 154 | [slim.batch_norm, slim.dropout], is_training=is_training): 155 | net, end_points = i3d_base(inputs, scope=scope) 156 | with tf.variable_scope('Logits'): 157 | kernel_size = i3d_utils.reduced_kernel_size_3d(net, [2, 7, 7]) 158 | net = slim.avg_pool3d( 159 | net, kernel_size, stride=1, scope='AvgPool_0a_7x7') 160 | net = slim.dropout(net, dropout_keep_prob, scope='Dropout_0b') 161 | logits = slim.conv3d( 162 | net, 163 | num_classes, [1, 1, 1], 164 | activation_fn=None, 165 | normalizer_fn=None, 166 | scope='Conv2d_0c_1x1') 167 | # Temporal average pooling. 168 | logits = tf.reduce_mean(logits, axis=1) 169 | if spatial_squeeze: 170 | logits = tf.squeeze(logits, [1, 2], name='SpatialSqueeze') 171 | 172 | end_points['Logits'] = logits 173 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions') 174 | return logits, end_points 175 | 176 | 177 | i3d.default_image_size = 224 178 | -------------------------------------------------------------------------------- /common/nets/i3d_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for networks.i3d.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from nets import i3d 24 | 25 | 26 | class I3DTest(tf.test.TestCase): 27 | 28 | def testBuildClassificationNetwork(self): 29 | batch_size = 5 30 | num_frames = 64 31 | height, width = 224, 224 32 | num_classes = 1000 33 | 34 | inputs = tf.random_uniform((batch_size, num_frames, height, width, 3)) 35 | logits, end_points = i3d.i3d(inputs, num_classes) 36 | self.assertTrue(logits.op.name.startswith('InceptionV1/Logits')) 37 | self.assertListEqual(logits.get_shape().as_list(), 38 | [batch_size, num_classes]) 39 | self.assertTrue('Predictions' in end_points) 40 | self.assertListEqual(end_points['Predictions'].get_shape().as_list(), 41 | [batch_size, num_classes]) 42 | 43 | def testBuildBaseNetwork(self): 44 | batch_size = 5 45 | num_frames = 64 46 | height, width = 224, 224 47 | 48 | inputs = tf.random_uniform((batch_size, num_frames, height, width, 3)) 49 | mixed_6c, end_points = i3d.i3d_base(inputs) 50 | self.assertTrue(mixed_6c.op.name.startswith('InceptionV1/Mixed_5c')) 51 | self.assertListEqual(mixed_6c.get_shape().as_list(), 52 | [batch_size, 8, 7, 7, 1024]) 53 | expected_endpoints = ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1', 54 | 'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b', 55 | 'Mixed_3c', 'MaxPool_4a_3x3', 'Mixed_4b', 'Mixed_4c', 56 | 'Mixed_4d', 'Mixed_4e', 'Mixed_4f', 'MaxPool_5a_2x2', 57 | 'Mixed_5b', 'Mixed_5c'] 58 | self.assertItemsEqual(end_points.keys(), expected_endpoints) 59 | 60 | def testBuildOnlyUptoFinalEndpoint(self): 61 | batch_size = 5 62 | num_frames = 64 63 | height, width = 224, 224 64 | endpoints = ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1', 65 | 'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c', 66 | 'MaxPool_4a_3x3', 'Mixed_4b', 'Mixed_4c', 'Mixed_4d', 67 | 'Mixed_4e', 'Mixed_4f', 'MaxPool_5a_2x2', 'Mixed_5b', 68 | 'Mixed_5c'] 69 | for index, endpoint in enumerate(endpoints): 70 | with tf.Graph().as_default(): 71 | inputs = tf.random_uniform((batch_size, num_frames, height, width, 3)) 72 | out_tensor, end_points = i3d.i3d_base( 73 | inputs, final_endpoint=endpoint) 74 | self.assertTrue(out_tensor.op.name.startswith( 75 | 'InceptionV1/' + endpoint)) 76 | self.assertItemsEqual(endpoints[:index+1], end_points) 77 | 78 | def testBuildAndCheckAllEndPointsUptoMixed5c(self): 79 | batch_size = 5 80 | num_frames = 64 81 | height, width = 224, 224 82 | 83 | inputs = tf.random_uniform((batch_size, num_frames, height, width, 3)) 84 | _, end_points = i3d.i3d_base(inputs, 85 | final_endpoint='Mixed_5c') 86 | endpoints_shapes = {'Conv2d_1a_7x7': [5, 32, 112, 112, 64], 87 | 'MaxPool_2a_3x3': [5, 32, 56, 56, 64], 88 | 'Conv2d_2b_1x1': [5, 32, 56, 56, 64], 89 | 'Conv2d_2c_3x3': [5, 32, 56, 56, 192], 90 | 'MaxPool_3a_3x3': [5, 32, 28, 28, 192], 91 | 'Mixed_3b': [5, 32, 28, 28, 256], 92 | 'Mixed_3c': [5, 32, 28, 28, 480], 93 | 'MaxPool_4a_3x3': [5, 16, 14, 14, 480], 94 | 'Mixed_4b': [5, 16, 14, 14, 512], 95 | 'Mixed_4c': [5, 16, 14, 14, 512], 96 | 'Mixed_4d': [5, 16, 14, 14, 512], 97 | 'Mixed_4e': [5, 16, 14, 14, 528], 98 | 'Mixed_4f': [5, 16, 14, 14, 832], 99 | 'MaxPool_5a_2x2': [5, 8, 7, 7, 832], 100 | 'Mixed_5b': [5, 8, 7, 7, 832], 101 | 'Mixed_5c': [5, 8, 7, 7, 1024]} 102 | 103 | self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys()) 104 | for 
endpoint_name, expected_shape in endpoints_shapes.iteritems(): 105 | self.assertTrue(endpoint_name in end_points) 106 | self.assertListEqual(end_points[endpoint_name].get_shape().as_list(), 107 | expected_shape) 108 | 109 | def testHalfSizeImages(self): 110 | batch_size = 5 111 | num_frames = 64 112 | height, width = 112, 112 113 | 114 | inputs = tf.random_uniform((batch_size, num_frames, height, width, 3)) 115 | mixed_5c, _ = i3d.i3d_base(inputs) 116 | self.assertTrue(mixed_5c.op.name.startswith('InceptionV1/Mixed_5c')) 117 | self.assertListEqual(mixed_5c.get_shape().as_list(), 118 | [batch_size, 8, 4, 4, 1024]) 119 | 120 | def testTenFrames(self): 121 | batch_size = 5 122 | num_frames = 10 123 | height, width = 224, 224 124 | 125 | inputs = tf.random_uniform((batch_size, num_frames, height, width, 3)) 126 | mixed_5c, _ = i3d.i3d_base(inputs) 127 | self.assertTrue(mixed_5c.op.name.startswith('InceptionV1/Mixed_5c')) 128 | self.assertListEqual(mixed_5c.get_shape().as_list(), 129 | [batch_size, 2, 7, 7, 1024]) 130 | 131 | def testEvaluation(self): 132 | batch_size = 2 133 | num_frames = 64 134 | height, width = 224, 224 135 | num_classes = 1000 136 | 137 | eval_inputs = tf.random_uniform((batch_size, num_frames, height, width, 3)) 138 | logits, _ = i3d.i3d(eval_inputs, num_classes, 139 | is_training=False) 140 | predictions = tf.argmax(logits, 1) 141 | 142 | with self.test_session() as sess: 143 | sess.run(tf.global_variables_initializer()) 144 | output = sess.run(predictions) 145 | self.assertEquals(output.shape, (batch_size,)) 146 | 147 | 148 | if __name__ == '__main__': 149 | tf.test.main() 150 | -------------------------------------------------------------------------------- /common/nets/inception.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Brings all inception models under one namespace.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | # pylint: disable=unused-import 22 | from nets.inception_resnet_v2 import inception_resnet_v2 23 | from nets.inception_resnet_v2 import inception_resnet_v2_arg_scope 24 | from nets.inception_resnet_v2 import inception_resnet_v2_base 25 | from nets.inception_v1 import inception_v1 26 | from nets.inception_v1 import inception_v1_arg_scope 27 | from nets.inception_v1 import inception_v1_base 28 | from nets.inception_v2 import inception_v2 29 | from nets.inception_v2 import inception_v2_arg_scope 30 | from nets.inception_v2 import inception_v2_base 31 | from nets.inception_v3 import inception_v3 32 | from nets.inception_v3 import inception_v3_arg_scope 33 | from nets.inception_v3 import inception_v3_base 34 | from nets.inception_v4 import inception_v4 35 | from nets.inception_v4 import inception_v4_arg_scope 36 | from nets.inception_v4 import inception_v4_base 37 | # pylint: enable=unused-import 38 | -------------------------------------------------------------------------------- /common/nets/inception_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains common code shared by all inception models. 16 | 17 | Usage of arg scope: 18 | with slim.arg_scope(inception_arg_scope()): 19 | logits, end_points = inception.inception_v3(images, num_classes, 20 | is_training=is_training) 21 | 22 | """ 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | import tensorflow as tf 28 | 29 | slim = tf.contrib.slim 30 | 31 | 32 | def inception_arg_scope(weight_decay=0.00004, 33 | use_batch_norm=True, 34 | batch_norm_decay=0.9997, 35 | batch_norm_epsilon=0.001, 36 | activation_fn=tf.nn.relu, 37 | batch_norm_updates_collections=tf.GraphKeys.UPDATE_OPS, 38 | batch_norm_scale=False): 39 | """Defines the default arg scope for inception models. 40 | 41 | Args: 42 | weight_decay: The weight decay to use for regularizing the model. 43 | use_batch_norm: "If `True`, batch_norm is applied after each convolution. 44 | batch_norm_decay: Decay for batch norm moving average. 45 | batch_norm_epsilon: Small float added to variance to avoid dividing by zero 46 | in batch norm. 47 | activation_fn: Activation function for conv2d. 48 | batch_norm_updates_collections: Collection for the update ops for 49 | batch norm. 50 | batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the 51 | activations in the batch normalization layer. 
52 | 53 | Returns: 54 | An `arg_scope` to use for the inception models. 55 | """ 56 | batch_norm_params = { 57 | # Decay for the moving averages. 58 | 'decay': batch_norm_decay, 59 | # epsilon to prevent 0s in variance. 60 | 'epsilon': batch_norm_epsilon, 61 | # collection containing update_ops. 62 | 'updates_collections': batch_norm_updates_collections, 63 | # use fused batch norm if possible. 64 | 'fused': None, 65 | 'scale': batch_norm_scale, 66 | } 67 | if use_batch_norm: 68 | normalizer_fn = slim.batch_norm 69 | normalizer_params = batch_norm_params 70 | else: 71 | normalizer_fn = None 72 | normalizer_params = {} 73 | # Set weight_decay for weights in Conv and FC layers. 74 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 75 | weights_regularizer=slim.l2_regularizer(weight_decay)): 76 | with slim.arg_scope( 77 | [slim.conv2d], 78 | weights_initializer=slim.variance_scaling_initializer(), 79 | activation_fn=activation_fn, 80 | normalizer_fn=normalizer_fn, 81 | normalizer_params=normalizer_params) as sc: 82 | return sc 83 | -------------------------------------------------------------------------------- /common/nets/lenet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a variant of the LeNet model definition.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | slim = tf.contrib.slim 24 | 25 | 26 | def lenet(images, num_classes=10, is_training=False, 27 | dropout_keep_prob=0.5, 28 | prediction_fn=slim.softmax, 29 | scope='LeNet'): 30 | """Creates a variant of the LeNet model. 31 | 32 | Note that since the output is a set of 'logits', the values fall in the 33 | interval of (-infinity, infinity). Consequently, to convert the outputs to a 34 | probability distribution over the characters, one will need to convert them 35 | using the softmax function: 36 | 37 | logits = lenet.lenet(images, is_training=False) 38 | probabilities = tf.nn.softmax(logits) 39 | predictions = tf.argmax(logits, 1) 40 | 41 | Args: 42 | images: A batch of `Tensors` of size [batch_size, height, width, channels]. 43 | num_classes: the number of classes in the dataset. If 0 or None, the logits 44 | layer is omitted and the input features to the logits layer are returned 45 | instead. 46 | is_training: specifies whether or not we're currently training the model. 47 | This variable will determine the behaviour of the dropout layer. 48 | dropout_keep_prob: the percentage of activation values that are retained. 49 | prediction_fn: a function to get predictions out of logits. 50 | scope: Optional variable_scope. 
51 | 52 | Returns: 53 | net: a 2D Tensor with the logits (pre-softmax activations) if num_classes 54 | is a non-zero integer, or the non-dropped-out input to the logits layer 55 | if num_classes is 0 or None. 56 | end_points: a dictionary from components of the network to the corresponding 57 | activation. 58 | """ 59 | end_points = {} 60 | 61 | with tf.variable_scope(scope, 'LeNet', [images]): 62 | net = end_points['conv1'] = slim.conv2d(images, 32, [5, 5], scope='conv1') 63 | net = end_points['pool1'] = slim.max_pool2d(net, [2, 2], 2, scope='pool1') 64 | net = end_points['conv2'] = slim.conv2d(net, 64, [5, 5], scope='conv2') 65 | net = end_points['pool2'] = slim.max_pool2d(net, [2, 2], 2, scope='pool2') 66 | net = slim.flatten(net) 67 | end_points['Flatten'] = net 68 | 69 | net = end_points['fc3'] = slim.fully_connected(net, 1024, scope='fc3') 70 | if not num_classes: 71 | return net, end_points 72 | net = end_points['dropout3'] = slim.dropout( 73 | net, dropout_keep_prob, is_training=is_training, scope='dropout3') 74 | logits = end_points['Logits'] = slim.fully_connected( 75 | net, num_classes, activation_fn=None, scope='fc4') 76 | 77 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions') 78 | 79 | return logits, end_points 80 | lenet.default_image_size = 28 81 | 82 | 83 | def lenet_arg_scope(weight_decay=0.0): 84 | """Defines the default lenet argument scope. 85 | 86 | Args: 87 | weight_decay: The weight decay to use for regularizing the model. 88 | 89 | Returns: 90 | An `arg_scope` to use for the lenet model. 91 | """ 92 | with slim.arg_scope( 93 | [slim.conv2d, slim.fully_connected], 94 | weights_regularizer=slim.l2_regularizer(weight_decay), 95 | weights_initializer=tf.truncated_normal_initializer(stddev=0.1), 96 | activation_fn=tf.nn.relu) as sc: 97 | return sc 98 | -------------------------------------------------------------------------------- /common/nets/mobilenet/README.md: -------------------------------------------------------------------------------- 1 | # MobileNetV2 2 | This folder contains building code for MobileNetV2, based on 3 | [MobileNetV2: Inverted Residuals and Linear Bottlenecks](https://arxiv.org/abs/1801.04381) 4 | 5 | # Performance 6 | ## Latency 7 | This is the timing of [MobileNetV1](../mobilenet_v1.md) vs MobileNetV2 using 8 | TF-Lite on the large core of a Pixel 1 phone. 9 | 10 | ![mnet_v1_vs_v2_pixel1_latency.png](mnet_v1_vs_v2_pixel1_latency.png) 11 | 12 | ## MACs 13 | MACs, also sometimes known as MADDs, are the number of multiply-accumulates needed 14 | to compute an inference on a single image, and are a common metric for measuring the efficiency of a model. 15 | 16 | Below is the graph comparing V2 vs a few selected networks. The size 17 | of each blob represents the number of parameters. Note that for [ShuffleNet](https://arxiv.org/abs/1707.01083) there 18 | are no published size numbers. We estimate it to be comparable to MobileNetV2 numbers.
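The checkpoint table further below reports MACs in millions. As a rough, illustrative sketch (not part of the original README; the helper function and the layer sizes are invented for the example), the MAC count of a single standard convolution layer can be estimated from its output size, kernel size, and channel counts:

```python
def conv2d_macs(out_h, out_w, k_h, k_w, c_in, c_out):
    # Rough MAC count for one standard 2-D convolution layer:
    # each output element needs k_h * k_w * c_in multiply-accumulates.
    return out_h * out_w * k_h * k_w * c_in * c_out

# Illustrative stem layer of a 224x224 network: 3x3 conv, stride 2,
# 3 input channels -> 32 output channels, 112x112 output feature map.
print(conv2d_macs(112, 112, 3, 3, 3, 32))  # ~10.8 million MACs
```

Summing such per-layer estimates over a whole network gives figures on the order of those listed in the checkpoint table.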
19 | 20 | ![madds_top1_accuracy](madds_top1_accuracy.png) 21 | 22 | # Pretrained models 23 | ## Imagenet Checkpoints 24 | 25 | Classification Checkpoint | MACs (M)| Parameters (M)| Top 1 Accuracy| Top 5 Accuracy | Mobile CPU (ms) Pixel 1 26 | ---------------------------|---------|---------------|---------|----|------------- 27 | | [mobilenet_v2_1.4_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz) | 582 | 6.06 | 75.0 | 92.5 | 138.0 28 | | [mobilenet_v2_1.3_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.3_224.tgz) | 509 | 5.34 | 74.4 | 92.1 | 123.0 29 | | [mobilenet_v2_1.0_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_224.tgz) | 300 | 3.47 | 71.8 | 91.0 | 73.8 30 | | [mobilenet_v2_1.0_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_192.tgz) | 221 | 3.47 | 70.7 | 90.1 | 55.1 31 | | [mobilenet_v2_1.0_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_160.tgz) | 154 | 3.47 | 68.8 | 89.0 | 40.2 32 | | [mobilenet_v2_1.0_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_128.tgz) | 99 | 3.47 | 65.3 | 86.9 | 27.6 33 | | [mobilenet_v2_1.0_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_96.tgz) | 56 | 3.47 | 60.3 | 83.2 | 17.6 34 | | [mobilenet_v2_0.75_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_224.tgz) | 209 | 2.61 | 69.8 | 89.6 | 55.8 35 | | [mobilenet_v2_0.75_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_192.tgz) | 153 | 2.61 | 68.7 | 88.9 | 41.6 36 | | [mobilenet_v2_0.75_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_160.tgz) | 107 | 2.61 | 66.4 | 87.3 | 30.4 37 | | [mobilenet_v2_0.75_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_128.tgz) | 69 | 2.61 | 63.2 | 85.3 | 21.9 38 | | [mobilenet_v2_0.75_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_96.tgz) | 39 | 2.61 | 58.8 | 81.6 | 14.2 39 | | [mobilenet_v2_0.5_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_224.tgz) | 97 | 1.95 | 65.4 | 86.4 | 28.7 40 | | [mobilenet_v2_0.5_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_192.tgz) | 71 | 1.95 | 63.9 | 85.4 | 21.1 41 | | [mobilenet_v2_0.5_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_160.tgz) | 50 | 1.95 | 61.0 | 83.2 | 14.9 42 | | [mobilenet_v2_0.5_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_128.tgz) | 32 | 1.95 | 57.7 | 80.8 | 9.9 43 | | [mobilenet_v2_0.5_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_96.tgz) | 18 | 1.95 | 51.2 | 75.8 | 6.4 44 | | [mobilenet_v2_0.35_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_224.tgz) | 59 | 1.66 | 60.3 | 82.9 | 19.7 45 | | [mobilenet_v2_0.35_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_192.tgz) | 43 | 1.66 | 58.2 | 81.2 | 14.6 46 | | [mobilenet_v2_0.35_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_160.tgz) | 30 | 1.66 | 55.7 | 79.1 | 10.5 47 | | [mobilenet_v2_0.35_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_128.tgz) | 20 | 1.66 | 50.8 | 75.0 | 6.9 48 | | [mobilenet_v2_0.35_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_96.tgz) | 11 | 1.66 | 45.5 | 70.4 | 4.5 49 | 50 | # Training 
51 | The numbers above can be reproduced using slim's `train_image_classifier`. 52 | Below is the set of parameters that achieves 72.0% for the full-size MobileNetV2, after about 700K steps when trained on 8 GPUs. 53 | If trained on a single GPU, full convergence takes about 5.5M steps. Also note that the learning rate and 54 | num_epochs_per_decay both need to be adjusted depending on how many GPUs are being 55 | used, due to slim's internal averaging. 56 | 57 | ```bash 58 | --model_name="mobilenet_v2" 59 | --learning_rate=0.045 * NUM_GPUS #slim internally averages clones so we compensate 60 | --preprocessing_name="inception_v2" 61 | --label_smoothing=0.1 62 | --moving_average_decay=0.9999 63 | --batch_size= 96 64 | --num_clones = NUM_GPUS # you can use any number here between 1 and 8 depending on your hardware setup. 65 | --learning_rate_decay_factor=0.98 66 | --num_epochs_per_decay = 2.5 / NUM_GPUS # train_image_classifier does per clone epochs 67 | ``` 68 | 69 | # Example 70 | 71 | 72 | See this [ipython notebook](mobilenet_example.ipynb) or open and run the network directly in [Colaboratory](https://colab.research.google.com/github/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_example.ipynb). 73 | 74 | -------------------------------------------------------------------------------- /common/nets/mobilenet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiahuei/COMIC-Compact-Image-Captioning-with-Attention/73165e0aac2816e89732571814f978801958e1ac/common/nets/mobilenet/__init__.py -------------------------------------------------------------------------------- /common/nets/mobilenet/madds_top1_accuracy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiahuei/COMIC-Compact-Image-Captioning-with-Attention/73165e0aac2816e89732571814f978801958e1ac/common/nets/mobilenet/madds_top1_accuracy.png -------------------------------------------------------------------------------- /common/nets/mobilenet/mnet_v1_vs_v2_pixel1_latency.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiahuei/COMIC-Compact-Image-Captioning-with-Attention/73165e0aac2816e89732571814f978801958e1ac/common/nets/mobilenet/mnet_v1_vs_v2_pixel1_latency.png -------------------------------------------------------------------------------- /common/nets/mobilenet_v1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiahuei/COMIC-Compact-Image-Captioning-with-Attention/73165e0aac2816e89732571814f978801958e1ac/common/nets/mobilenet_v1.png -------------------------------------------------------------------------------- /common/nets/mobilenet_v1_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Validate mobilenet_v1 with options for quantization.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import math 22 | import tensorflow as tf 23 | 24 | from datasets import dataset_factory 25 | from nets import mobilenet_v1 26 | from preprocessing import preprocessing_factory 27 | 28 | slim = tf.contrib.slim 29 | 30 | flags = tf.app.flags 31 | 32 | flags.DEFINE_string('master', '', 'Session master') 33 | flags.DEFINE_integer('batch_size', 250, 'Batch size') 34 | flags.DEFINE_integer('num_classes', 1001, 'Number of classes to distinguish') 35 | flags.DEFINE_integer('num_examples', 50000, 'Number of examples to evaluate') 36 | flags.DEFINE_integer('image_size', 224, 'Input image resolution') 37 | flags.DEFINE_float('depth_multiplier', 1.0, 'Depth multiplier for mobilenet') 38 | flags.DEFINE_bool('quantize', False, 'Quantize training') 39 | flags.DEFINE_string('checkpoint_dir', '', 'The directory for checkpoints') 40 | flags.DEFINE_string('eval_dir', '', 'Directory for writing eval event logs') 41 | flags.DEFINE_string('dataset_dir', '', 'Location of dataset') 42 | 43 | FLAGS = flags.FLAGS 44 | 45 | 46 | def imagenet_input(is_training): 47 | """Data reader for imagenet. 48 | 49 | Reads in imagenet data and performs pre-processing on the images. 50 | 51 | Args: 52 | is_training: bool specifying if train or validation dataset is needed. 53 | Returns: 54 | A batch of images and labels. 55 | """ 56 | if is_training: 57 | dataset = dataset_factory.get_dataset('imagenet', 'train', 58 | FLAGS.dataset_dir) 59 | else: 60 | dataset = dataset_factory.get_dataset('imagenet', 'validation', 61 | FLAGS.dataset_dir) 62 | 63 | provider = slim.dataset_data_provider.DatasetDataProvider( 64 | dataset, 65 | shuffle=is_training, 66 | common_queue_capacity=2 * FLAGS.batch_size, 67 | common_queue_min=FLAGS.batch_size) 68 | [image, label] = provider.get(['image', 'label']) 69 | 70 | image_preprocessing_fn = preprocessing_factory.get_preprocessing( 71 | 'mobilenet_v1', is_training=is_training) 72 | 73 | image = image_preprocessing_fn(image, FLAGS.image_size, FLAGS.image_size) 74 | 75 | images, labels = tf.train.batch( 76 | tensors=[image, label], 77 | batch_size=FLAGS.batch_size, 78 | num_threads=4, 79 | capacity=5 * FLAGS.batch_size) 80 | return images, labels 81 | 82 | 83 | def metrics(logits, labels): 84 | """Specify the metrics for eval. 85 | 86 | Args: 87 | logits: Logits output from the graph. 88 | labels: Ground truth labels for inputs. 89 | 90 | Returns: 91 | Eval Op for the graph. 92 | """ 93 | labels = tf.squeeze(labels) 94 | names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({ 95 | 'Accuracy': tf.metrics.accuracy(tf.argmax(logits, 1), labels), 96 | 'Recall_5': tf.metrics.recall_at_k(labels, logits, 5), 97 | }) 98 | for name, value in names_to_values.iteritems(): 99 | slim.summaries.add_scalar_summary( 100 | value, name, prefix='eval', print_summary=True) 101 | return names_to_updates.values() 102 | 103 | 104 | def build_model(): 105 | """Build the mobilenet_v1 model for evaluation. 106 | 107 | Returns: 108 | g: graph with rewrites after insertion of quantization ops and batch norm 109 | folding. 110 | eval_ops: eval ops for inference. 
111 | variables_to_restore: List of variables to restore from checkpoint. 112 | """ 113 | g = tf.Graph() 114 | with g.as_default(): 115 | inputs, labels = imagenet_input(is_training=False) 116 | 117 | scope = mobilenet_v1.mobilenet_v1_arg_scope( 118 | is_training=False, weight_decay=0.0) 119 | with slim.arg_scope(scope): 120 | logits, _ = mobilenet_v1.mobilenet_v1( 121 | inputs, 122 | is_training=False, 123 | depth_multiplier=FLAGS.depth_multiplier, 124 | num_classes=FLAGS.num_classes) 125 | 126 | if FLAGS.quantize: 127 | tf.contrib.quantize.create_eval_graph() 128 | 129 | eval_ops = metrics(logits, labels) 130 | 131 | return g, eval_ops 132 | 133 | 134 | def eval_model(): 135 | """Evaluates mobilenet_v1.""" 136 | g, eval_ops = build_model() 137 | with g.as_default(): 138 | num_batches = math.ceil(FLAGS.num_examples / float(FLAGS.batch_size)) 139 | slim.evaluation.evaluate_once( 140 | FLAGS.master, 141 | FLAGS.checkpoint_dir, 142 | logdir=FLAGS.eval_dir, 143 | num_evals=num_batches, 144 | eval_op=eval_ops) 145 | 146 | 147 | def main(unused_arg): 148 | eval_model() 149 | 150 | 151 | if __name__ == '__main__': 152 | tf.app.run(main) 153 | -------------------------------------------------------------------------------- /common/nets/nasnet/README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow-Slim NASNet-A Implementation/Checkpoints 2 | This directory contains the code for the NASNet-A model from the paper 3 | [Learning Transferable Architectures for Scalable Image Recognition](https://arxiv.org/abs/1707.07012) by Zoph et al. 4 | In nasnet.py there are three different configurations of NASNet-A that are implemented. One of the models is the NASNet-A built for CIFAR-10, and the 5 | other two are variants of NASNet-A trained on ImageNet, which are listed below. 6 | 7 | # Pre-Trained Models 8 | Two NASNet-A checkpoints are available that have been trained on the 9 | [ILSVRC-2012-CLS](http://www.image-net.org/challenges/LSVRC/2012/) 10 | image classification dataset. Accuracies were computed by evaluating using a single image crop. 11 | 12 | Model Checkpoint | Million MACs | Million Parameters | Top-1 Accuracy| Top-5 Accuracy | 13 | :----:|:------------:|:----------:|:-------:|:-------:| 14 | [NASNet-A_Mobile_224](https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_mobile_04_10_2017.tar.gz)|564|5.3|74.0|91.6| 15 | [NASNet-A_Large_331](https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_large_04_10_2017.tar.gz)|23800|88.9|82.7|96.2| 16 | 17 | 18 | Here is an example of how to download the NASNet-A_Mobile_224 checkpoint. The way to download the NASNet-A_Large_331 is the same. 19 | 20 | ```shell 21 | CHECKPOINT_DIR=/tmp/checkpoints 22 | mkdir ${CHECKPOINT_DIR} 23 | cd ${CHECKPOINT_DIR} 24 | wget https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_mobile_04_10_2017.tar.gz 25 | tar -xvf nasnet-a_mobile_04_10_2017.tar.gz 26 | rm nasnet-a_mobile_04_10_2017.tar.gz 27 | ``` 28 | More information on integrating NASNet Models into your project can be found at the [TF-Slim Image Classification Library](https://github.com/tensorflow/models/blob/master/research/slim/README.md). 29 | 30 | To get started running models on-device go to [TensorFlow Mobile](https://www.tensorflow.org/mobile/).
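As a minimal sketch of using a downloaded checkpoint from Python (this snippet is not part of the original README; it assumes the `nasnet_mobile_arg_scope` and `build_nasnet_mobile` functions exposed by `nets/nasnet/nasnet.py`, already-preprocessed inputs, and checkpoint variable names that match the freshly built graph), the mobile model could be constructed and restored roughly like this:

```python
import tensorflow as tf
from nets.nasnet import nasnet

slim = tf.contrib.slim

# Placeholder for a batch of preprocessed 224x224 RGB images.
images = tf.placeholder(tf.float32, [None, 224, 224, 3])

# Build NASNet-A Mobile; 1001 classes matches the ILSVRC checkpoints
# (1000 classes plus a background class).
with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
    logits, end_points = nasnet.build_nasnet_mobile(
        images, num_classes=1001, is_training=False)

saver = tf.train.Saver(slim.get_model_variables())
with tf.Session() as sess:
    saver.restore(sess, '/tmp/checkpoints/model.ckpt')
    # end_points['Predictions'] holds the per-class probabilities.
```

The large checkpoint would be used the same way, swapping in the large-model builder and 331x331 inputs.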
31 | 32 | ## Sample Commands for using NASNet-A Mobile and Large Checkpoints for Inference 33 | ------- 34 | Run eval with the NASNet-A mobile ImageNet model 35 | 36 | ```shell 37 | DATASET_DIR=/tmp/imagenet 38 | EVAL_DIR=/tmp/tfmodel/eval 39 | CHECKPOINT_DIR=/tmp/checkpoints/model.ckpt 40 | python tensorflow_models/research/slim/eval_image_classifier \ 41 | --checkpoint_path=${CHECKPOINT_DIR} \ 42 | --eval_dir=${EVAL_DIR} \ 43 | --dataset_dir=${DATASET_DIR} \ 44 | --dataset_name=imagenet \ 45 | --dataset_split_name=validation \ 46 | --model_name=nasnet_mobile \ 47 | --eval_image_size=224 48 | ``` 49 | 50 | Run eval with the NASNet-A large ImageNet model 51 | 52 | ```shell 53 | DATASET_DIR=/tmp/imagenet 54 | EVAL_DIR=/tmp/tfmodel/eval 55 | CHECKPOINT_DIR=/tmp/checkpoints/model.ckpt 56 | python tensorflow_models/research/slim/eval_image_classifier \ 57 | --checkpoint_path=${CHECKPOINT_DIR} \ 58 | --eval_dir=${EVAL_DIR} \ 59 | --dataset_dir=${DATASET_DIR} \ 60 | --dataset_name=imagenet \ 61 | --dataset_split_name=validation \ 62 | --model_name=nasnet_large \ 63 | --eval_image_size=331 64 | ``` 65 | -------------------------------------------------------------------------------- /common/nets/nasnet/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /common/nets/nasnet/nasnet_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for slim.nets.nasnet.nasnet_utils.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from nets.nasnet import nasnet_utils 24 | 25 | 26 | class NasnetUtilsTest(tf.test.TestCase): 27 | 28 | def testCalcReductionLayers(self): 29 | num_cells = 18 30 | num_reduction_layers = 2 31 | reduction_layers = nasnet_utils.calc_reduction_layers( 32 | num_cells, num_reduction_layers) 33 | self.assertEqual(len(reduction_layers), 2) 34 | self.assertEqual(reduction_layers[0], 6) 35 | self.assertEqual(reduction_layers[1], 12) 36 | 37 | def testGetChannelIndex(self): 38 | data_formats = ['NHWC', 'NCHW'] 39 | for data_format in data_formats: 40 | index = nasnet_utils.get_channel_index(data_format) 41 | correct_index = 3 if data_format == 'NHWC' else 1 42 | self.assertEqual(index, correct_index) 43 | 44 | def testGetChannelDim(self): 45 | data_formats = ['NHWC', 'NCHW'] 46 | shape = [10, 20, 30, 40] 47 | for data_format in data_formats: 48 | dim = nasnet_utils.get_channel_dim(shape, data_format) 49 | correct_dim = shape[3] if data_format == 'NHWC' else shape[1] 50 | self.assertEqual(dim, correct_dim) 51 | 52 | def testGlobalAvgPool(self): 53 | data_formats = ['NHWC', 'NCHW'] 54 | inputs = tf.placeholder(tf.float32, (5, 10, 20, 10)) 55 | for data_format in data_formats: 56 | output = nasnet_utils.global_avg_pool( 57 | inputs, data_format) 58 | self.assertEqual(output.shape, [5, 10]) 59 | 60 | 61 | if __name__ == '__main__': 62 | tf.test.main() 63 | -------------------------------------------------------------------------------- /common/nets/nets_factory_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Tests for slim.inception.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | 23 | import tensorflow as tf 24 | 25 | from nets import nets_factory 26 | 27 | 28 | class NetworksTest(tf.test.TestCase): 29 | 30 | def testGetNetworkFnFirstHalf(self): 31 | batch_size = 5 32 | num_classes = 1000 33 | for net in list(nets_factory.networks_map.keys())[:10]: 34 | with tf.Graph().as_default() as g, self.test_session(g): 35 | net_fn = nets_factory.get_network_fn(net, num_classes=num_classes) 36 | # Most networks use 224 as their default_image_size 37 | image_size = getattr(net_fn, 'default_image_size', 224) 38 | if net not in ['i3d', 's3dg']: 39 | inputs = tf.random_uniform( 40 | (batch_size, image_size, image_size, 3)) 41 | logits, end_points = net_fn(inputs) 42 | self.assertTrue(isinstance(logits, tf.Tensor)) 43 | self.assertTrue(isinstance(end_points, dict)) 44 | self.assertEqual(logits.get_shape().as_list()[0], batch_size) 45 | self.assertEqual(logits.get_shape().as_list()[-1], num_classes) 46 | 47 | def testGetNetworkFnSecondHalf(self): 48 | batch_size = 5 49 | num_classes = 1000 50 | for net in list(nets_factory.networks_map.keys())[10:]: 51 | with tf.Graph().as_default() as g, self.test_session(g): 52 | net_fn = nets_factory.get_network_fn(net, num_classes=num_classes) 53 | # Most networks use 224 as their default_image_size 54 | image_size = getattr(net_fn, 'default_image_size', 224) 55 | if net not in ['i3d', 's3dg']: 56 | inputs = tf.random_uniform( 57 | (batch_size, image_size, image_size, 3)) 58 | logits, end_points = net_fn(inputs) 59 | self.assertTrue(isinstance(logits, tf.Tensor)) 60 | self.assertTrue(isinstance(end_points, dict)) 61 | self.assertEqual(logits.get_shape().as_list()[0], batch_size) 62 | self.assertEqual(logits.get_shape().as_list()[-1], num_classes) 63 | 64 | def testGetNetworkFnVideoModels(self): 65 | batch_size = 5 66 | num_classes = 400 67 | for net in ['i3d', 's3dg']: 68 | with tf.Graph().as_default() as g, self.test_session(g): 69 | net_fn = nets_factory.get_network_fn(net, num_classes=num_classes) 70 | # Most networks use 224 as their default_image_size 71 | image_size = getattr(net_fn, 'default_image_size', 224) // 2 72 | inputs = tf.random_uniform( 73 | (batch_size, 10, image_size, image_size, 3)) 74 | logits, end_points = net_fn(inputs) 75 | self.assertTrue(isinstance(logits, tf.Tensor)) 76 | self.assertTrue(isinstance(end_points, dict)) 77 | self.assertEqual(logits.get_shape().as_list()[0], batch_size) 78 | self.assertEqual(logits.get_shape().as_list()[-1], num_classes) 79 | 80 | if __name__ == '__main__': 81 | tf.test.main() 82 | -------------------------------------------------------------------------------- /common/nets/overfeat.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains the model definition for the OverFeat network. 16 | 17 | The definition for the network was obtained from: 18 | OverFeat: Integrated Recognition, Localization and Detection using 19 | Convolutional Networks 20 | Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus and 21 | Yann LeCun, 2014 22 | http://arxiv.org/abs/1312.6229 23 | 24 | Usage: 25 | with slim.arg_scope(overfeat.overfeat_arg_scope()): 26 | outputs, end_points = overfeat.overfeat(inputs) 27 | 28 | @@overfeat 29 | """ 30 | from __future__ import absolute_import 31 | from __future__ import division 32 | from __future__ import print_function 33 | 34 | import tensorflow as tf 35 | 36 | slim = tf.contrib.slim 37 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev) 38 | 39 | 40 | def overfeat_arg_scope(weight_decay=0.0005): 41 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 42 | activation_fn=tf.nn.relu, 43 | weights_regularizer=slim.l2_regularizer(weight_decay), 44 | biases_initializer=tf.zeros_initializer()): 45 | with slim.arg_scope([slim.conv2d], padding='SAME'): 46 | with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc: 47 | return arg_sc 48 | 49 | 50 | def overfeat(inputs, 51 | num_classes=1000, 52 | is_training=True, 53 | dropout_keep_prob=0.5, 54 | spatial_squeeze=True, 55 | scope='overfeat', 56 | global_pool=False): 57 | """Contains the model definition for the OverFeat network. 58 | 59 | The definition for the network was obtained from: 60 | OverFeat: Integrated Recognition, Localization and Detection using 61 | Convolutional Networks 62 | Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus and 63 | Yann LeCun, 2014 64 | http://arxiv.org/abs/1312.6229 65 | 66 | Note: All the fully_connected layers have been transformed to conv2d layers. 67 | To use in classification mode, resize input to 231x231. To use in fully 68 | convolutional mode, set spatial_squeeze to false. 69 | 70 | Args: 71 | inputs: a tensor of size [batch_size, height, width, channels]. 72 | num_classes: number of predicted classes. If 0 or None, the logits layer is 73 | omitted and the input features to the logits layer are returned instead. 74 | is_training: whether or not the model is being trained. 75 | dropout_keep_prob: the probability that activations are kept in the dropout 76 | layers during training. 77 | spatial_squeeze: whether or not should squeeze the spatial dimensions of the 78 | outputs. Useful to remove unnecessary dimensions for classification. 79 | scope: Optional scope for the variables. 80 | global_pool: Optional boolean flag. If True, the input to the classification 81 | layer is avgpooled to size 1x1, for any input size. (This is not part 82 | of the original OverFeat.) 83 | 84 | Returns: 85 | net: the output of the logits layer (if num_classes is a non-zero integer), 86 | or the non-dropped-out input to the logits layer (if num_classes is 0 or 87 | None). 88 | end_points: a dict of tensors with intermediate activations. 
89 | """ 90 | with tf.variable_scope(scope, 'overfeat', [inputs]) as sc: 91 | end_points_collection = sc.original_name_scope + '_end_points' 92 | # Collect outputs for conv2d, fully_connected and max_pool2d 93 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], 94 | outputs_collections=end_points_collection): 95 | net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID', 96 | scope='conv1') 97 | net = slim.max_pool2d(net, [2, 2], scope='pool1') 98 | net = slim.conv2d(net, 256, [5, 5], padding='VALID', scope='conv2') 99 | net = slim.max_pool2d(net, [2, 2], scope='pool2') 100 | net = slim.conv2d(net, 512, [3, 3], scope='conv3') 101 | net = slim.conv2d(net, 1024, [3, 3], scope='conv4') 102 | net = slim.conv2d(net, 1024, [3, 3], scope='conv5') 103 | net = slim.max_pool2d(net, [2, 2], scope='pool5') 104 | 105 | # Use conv2d instead of fully_connected layers. 106 | with slim.arg_scope([slim.conv2d], 107 | weights_initializer=trunc_normal(0.005), 108 | biases_initializer=tf.constant_initializer(0.1)): 109 | net = slim.conv2d(net, 3072, [6, 6], padding='VALID', scope='fc6') 110 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 111 | scope='dropout6') 112 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 113 | # Convert end_points_collection into a end_point dict. 114 | end_points = slim.utils.convert_collection_to_dict( 115 | end_points_collection) 116 | if global_pool: 117 | net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool') 118 | end_points['global_pool'] = net 119 | if num_classes: 120 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 121 | scope='dropout7') 122 | net = slim.conv2d(net, num_classes, [1, 1], 123 | activation_fn=None, 124 | normalizer_fn=None, 125 | biases_initializer=tf.zeros_initializer(), 126 | scope='fc8') 127 | if spatial_squeeze: 128 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 129 | end_points[sc.name + '/fc8'] = net 130 | return net, end_points 131 | overfeat.default_image_size = 231 132 | -------------------------------------------------------------------------------- /common/nets/overfeat_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for slim.nets.overfeat.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import tensorflow as tf 21 | 22 | from nets import overfeat 23 | 24 | slim = tf.contrib.slim 25 | 26 | 27 | class OverFeatTest(tf.test.TestCase): 28 | 29 | def testBuild(self): 30 | batch_size = 5 31 | height, width = 231, 231 32 | num_classes = 1000 33 | with self.test_session(): 34 | inputs = tf.random_uniform((batch_size, height, width, 3)) 35 | logits, _ = overfeat.overfeat(inputs, num_classes) 36 | self.assertEquals(logits.op.name, 'overfeat/fc8/squeezed') 37 | self.assertListEqual(logits.get_shape().as_list(), 38 | [batch_size, num_classes]) 39 | 40 | def testFullyConvolutional(self): 41 | batch_size = 1 42 | height, width = 281, 281 43 | num_classes = 1000 44 | with self.test_session(): 45 | inputs = tf.random_uniform((batch_size, height, width, 3)) 46 | logits, _ = overfeat.overfeat(inputs, num_classes, spatial_squeeze=False) 47 | self.assertEquals(logits.op.name, 'overfeat/fc8/BiasAdd') 48 | self.assertListEqual(logits.get_shape().as_list(), 49 | [batch_size, 2, 2, num_classes]) 50 | 51 | def testGlobalPool(self): 52 | batch_size = 1 53 | height, width = 281, 281 54 | num_classes = 1000 55 | with self.test_session(): 56 | inputs = tf.random_uniform((batch_size, height, width, 3)) 57 | logits, _ = overfeat.overfeat(inputs, num_classes, spatial_squeeze=False, 58 | global_pool=True) 59 | self.assertEquals(logits.op.name, 'overfeat/fc8/BiasAdd') 60 | self.assertListEqual(logits.get_shape().as_list(), 61 | [batch_size, 1, 1, num_classes]) 62 | 63 | def testEndPoints(self): 64 | batch_size = 5 65 | height, width = 231, 231 66 | num_classes = 1000 67 | with self.test_session(): 68 | inputs = tf.random_uniform((batch_size, height, width, 3)) 69 | _, end_points = overfeat.overfeat(inputs, num_classes) 70 | expected_names = ['overfeat/conv1', 71 | 'overfeat/pool1', 72 | 'overfeat/conv2', 73 | 'overfeat/pool2', 74 | 'overfeat/conv3', 75 | 'overfeat/conv4', 76 | 'overfeat/conv5', 77 | 'overfeat/pool5', 78 | 'overfeat/fc6', 79 | 'overfeat/fc7', 80 | 'overfeat/fc8' 81 | ] 82 | self.assertSetEqual(set(end_points.keys()), set(expected_names)) 83 | 84 | def testNoClasses(self): 85 | batch_size = 5 86 | height, width = 231, 231 87 | num_classes = None 88 | with self.test_session(): 89 | inputs = tf.random_uniform((batch_size, height, width, 3)) 90 | net, end_points = overfeat.overfeat(inputs, num_classes) 91 | expected_names = ['overfeat/conv1', 92 | 'overfeat/pool1', 93 | 'overfeat/conv2', 94 | 'overfeat/pool2', 95 | 'overfeat/conv3', 96 | 'overfeat/conv4', 97 | 'overfeat/conv5', 98 | 'overfeat/pool5', 99 | 'overfeat/fc6', 100 | 'overfeat/fc7' 101 | ] 102 | self.assertSetEqual(set(end_points.keys()), set(expected_names)) 103 | self.assertTrue(net.op.name.startswith('overfeat/fc7')) 104 | 105 | def testModelVariables(self): 106 | batch_size = 5 107 | height, width = 231, 231 108 | num_classes = 1000 109 | with self.test_session(): 110 | inputs = tf.random_uniform((batch_size, height, width, 3)) 111 | overfeat.overfeat(inputs, num_classes) 112 | expected_names = ['overfeat/conv1/weights', 113 | 'overfeat/conv1/biases', 114 | 'overfeat/conv2/weights', 115 | 'overfeat/conv2/biases', 116 | 'overfeat/conv3/weights', 117 | 'overfeat/conv3/biases', 118 | 'overfeat/conv4/weights', 119 | 'overfeat/conv4/biases', 120 | 'overfeat/conv5/weights', 121 | 
'overfeat/conv5/biases', 122 | 'overfeat/fc6/weights', 123 | 'overfeat/fc6/biases', 124 | 'overfeat/fc7/weights', 125 | 'overfeat/fc7/biases', 126 | 'overfeat/fc8/weights', 127 | 'overfeat/fc8/biases', 128 | ] 129 | model_variables = [v.op.name for v in slim.get_model_variables()] 130 | self.assertSetEqual(set(model_variables), set(expected_names)) 131 | 132 | def testEvaluation(self): 133 | batch_size = 2 134 | height, width = 231, 231 135 | num_classes = 1000 136 | with self.test_session(): 137 | eval_inputs = tf.random_uniform((batch_size, height, width, 3)) 138 | logits, _ = overfeat.overfeat(eval_inputs, is_training=False) 139 | self.assertListEqual(logits.get_shape().as_list(), 140 | [batch_size, num_classes]) 141 | predictions = tf.argmax(logits, 1) 142 | self.assertListEqual(predictions.get_shape().as_list(), [batch_size]) 143 | 144 | def testTrainEvalWithReuse(self): 145 | train_batch_size = 2 146 | eval_batch_size = 1 147 | train_height, train_width = 231, 231 148 | eval_height, eval_width = 281, 281 149 | num_classes = 1000 150 | with self.test_session(): 151 | train_inputs = tf.random_uniform( 152 | (train_batch_size, train_height, train_width, 3)) 153 | logits, _ = overfeat.overfeat(train_inputs) 154 | self.assertListEqual(logits.get_shape().as_list(), 155 | [train_batch_size, num_classes]) 156 | tf.get_variable_scope().reuse_variables() 157 | eval_inputs = tf.random_uniform( 158 | (eval_batch_size, eval_height, eval_width, 3)) 159 | logits, _ = overfeat.overfeat(eval_inputs, is_training=False, 160 | spatial_squeeze=False) 161 | self.assertListEqual(logits.get_shape().as_list(), 162 | [eval_batch_size, 2, 2, num_classes]) 163 | logits = tf.reduce_mean(logits, [1, 2]) 164 | predictions = tf.argmax(logits, 1) 165 | self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size]) 166 | 167 | def testForward(self): 168 | batch_size = 1 169 | height, width = 231, 231 170 | with self.test_session() as sess: 171 | inputs = tf.random_uniform((batch_size, height, width, 3)) 172 | logits, _ = overfeat.overfeat(inputs) 173 | sess.run(tf.global_variables_initializer()) 174 | output = sess.run(logits) 175 | self.assertTrue(output.any()) 176 | 177 | if __name__ == '__main__': 178 | tf.test.main() 179 | -------------------------------------------------------------------------------- /common/nets/pix2pix_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================= 15 | """Tests for pix2pix.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | from nets import pix2pix 23 | 24 | 25 | class GeneratorTest(tf.test.TestCase): 26 | 27 | def _reduced_default_blocks(self): 28 | """Returns the default blocks, scaled down to make test run faster.""" 29 | return [pix2pix.Block(b.num_filters // 32, b.decoder_keep_prob) 30 | for b in pix2pix._default_generator_blocks()] 31 | 32 | def test_output_size_nn_upsample_conv(self): 33 | batch_size = 2 34 | height, width = 256, 256 35 | num_outputs = 4 36 | 37 | images = tf.ones((batch_size, height, width, 3)) 38 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 39 | logits, _ = pix2pix.pix2pix_generator( 40 | images, num_outputs, blocks=self._reduced_default_blocks(), 41 | upsample_method='nn_upsample_conv') 42 | 43 | with self.test_session() as session: 44 | session.run(tf.global_variables_initializer()) 45 | np_outputs = session.run(logits) 46 | self.assertListEqual([batch_size, height, width, num_outputs], 47 | list(np_outputs.shape)) 48 | 49 | def test_output_size_conv2d_transpose(self): 50 | batch_size = 2 51 | height, width = 256, 256 52 | num_outputs = 4 53 | 54 | images = tf.ones((batch_size, height, width, 3)) 55 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 56 | logits, _ = pix2pix.pix2pix_generator( 57 | images, num_outputs, blocks=self._reduced_default_blocks(), 58 | upsample_method='conv2d_transpose') 59 | 60 | with self.test_session() as session: 61 | session.run(tf.global_variables_initializer()) 62 | np_outputs = session.run(logits) 63 | self.assertListEqual([batch_size, height, width, num_outputs], 64 | list(np_outputs.shape)) 65 | 66 | def test_block_number_dictates_number_of_layers(self): 67 | batch_size = 2 68 | height, width = 256, 256 69 | num_outputs = 4 70 | 71 | images = tf.ones((batch_size, height, width, 3)) 72 | blocks = [ 73 | pix2pix.Block(64, 0.5), 74 | pix2pix.Block(128, 0), 75 | ] 76 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 77 | _, end_points = pix2pix.pix2pix_generator( 78 | images, num_outputs, blocks) 79 | 80 | num_encoder_layers = 0 81 | num_decoder_layers = 0 82 | for end_point in end_points: 83 | if end_point.startswith('encoder'): 84 | num_encoder_layers += 1 85 | elif end_point.startswith('decoder'): 86 | num_decoder_layers += 1 87 | 88 | self.assertEqual(num_encoder_layers, len(blocks)) 89 | self.assertEqual(num_decoder_layers, len(blocks)) 90 | 91 | 92 | class DiscriminatorTest(tf.test.TestCase): 93 | 94 | def _layer_output_size(self, input_size, kernel_size=4, stride=2, pad=2): 95 | return (input_size + pad * 2 - kernel_size) // stride + 1 96 | 97 | def test_four_layers(self): 98 | batch_size = 2 99 | input_size = 256 100 | 101 | output_size = self._layer_output_size(input_size) 102 | output_size = self._layer_output_size(output_size) 103 | output_size = self._layer_output_size(output_size) 104 | output_size = self._layer_output_size(output_size, stride=1) 105 | output_size = self._layer_output_size(output_size, stride=1) 106 | 107 | images = tf.ones((batch_size, input_size, input_size, 3)) 108 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 109 | logits, end_points = pix2pix.pix2pix_discriminator( 110 | images, num_filters=[64, 128, 256, 512]) 111 | self.assertListEqual([batch_size, 
output_size, output_size, 1], 112 | logits.shape.as_list()) 113 | self.assertListEqual([batch_size, output_size, output_size, 1], 114 | end_points['predictions'].shape.as_list()) 115 | 116 | def test_four_layers_no_padding(self): 117 | batch_size = 2 118 | input_size = 256 119 | 120 | output_size = self._layer_output_size(input_size, pad=0) 121 | output_size = self._layer_output_size(output_size, pad=0) 122 | output_size = self._layer_output_size(output_size, pad=0) 123 | output_size = self._layer_output_size(output_size, stride=1, pad=0) 124 | output_size = self._layer_output_size(output_size, stride=1, pad=0) 125 | 126 | images = tf.ones((batch_size, input_size, input_size, 3)) 127 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 128 | logits, end_points = pix2pix.pix2pix_discriminator( 129 | images, num_filters=[64, 128, 256, 512], padding=0) 130 | self.assertListEqual([batch_size, output_size, output_size, 1], 131 | logits.shape.as_list()) 132 | self.assertListEqual([batch_size, output_size, output_size, 1], 133 | end_points['predictions'].shape.as_list()) 134 | 135 | def test_four_layers_wrog_paddig(self): 136 | batch_size = 2 137 | input_size = 256 138 | 139 | images = tf.ones((batch_size, input_size, input_size, 3)) 140 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 141 | with self.assertRaises(TypeError): 142 | pix2pix.pix2pix_discriminator( 143 | images, num_filters=[64, 128, 256, 512], padding=1.5) 144 | 145 | def test_four_layers_negative_padding(self): 146 | batch_size = 2 147 | input_size = 256 148 | 149 | images = tf.ones((batch_size, input_size, input_size, 3)) 150 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 151 | with self.assertRaises(ValueError): 152 | pix2pix.pix2pix_discriminator( 153 | images, num_filters=[64, 128, 256, 512], padding=-1) 154 | 155 | if __name__ == '__main__': 156 | tf.test.main() 157 | -------------------------------------------------------------------------------- /common/nets/s3dg_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for networks.s3dg.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from nets import s3dg 24 | 25 | 26 | class S3DGTest(tf.test.TestCase): 27 | 28 | def testBuildClassificationNetwork(self): 29 | batch_size = 5 30 | num_frames = 64 31 | height, width = 224, 224 32 | num_classes = 1000 33 | 34 | inputs = tf.random_uniform((batch_size, num_frames, height, width, 3)) 35 | logits, end_points = s3dg.s3dg(inputs, num_classes) 36 | self.assertTrue(logits.op.name.startswith('InceptionV1/Logits')) 37 | self.assertListEqual(logits.get_shape().as_list(), 38 | [batch_size, num_classes]) 39 | self.assertTrue('Predictions' in end_points) 40 | self.assertListEqual(end_points['Predictions'].get_shape().as_list(), 41 | [batch_size, num_classes]) 42 | 43 | def testBuildBaseNetwork(self): 44 | batch_size = 5 45 | num_frames = 64 46 | height, width = 224, 224 47 | 48 | inputs = tf.random_uniform((batch_size, num_frames, height, width, 3)) 49 | mixed_6c, end_points = s3dg.s3dg_base(inputs) 50 | self.assertTrue(mixed_6c.op.name.startswith('InceptionV1/Mixed_5c')) 51 | self.assertListEqual(mixed_6c.get_shape().as_list(), 52 | [batch_size, 8, 7, 7, 1024]) 53 | expected_endpoints = ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1', 54 | 'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b', 55 | 'Mixed_3c', 'MaxPool_4a_3x3', 'Mixed_4b', 'Mixed_4c', 56 | 'Mixed_4d', 'Mixed_4e', 'Mixed_4f', 'MaxPool_5a_2x2', 57 | 'Mixed_5b', 'Mixed_5c'] 58 | self.assertItemsEqual(end_points.keys(), expected_endpoints) 59 | 60 | def testBuildOnlyUptoFinalEndpointNoGating(self): 61 | batch_size = 5 62 | num_frames = 64 63 | height, width = 224, 224 64 | endpoints = ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1', 65 | 'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c', 66 | 'MaxPool_4a_3x3', 'Mixed_4b', 'Mixed_4c', 'Mixed_4d', 67 | 'Mixed_4e', 'Mixed_4f', 'MaxPool_5a_2x2', 'Mixed_5b', 68 | 'Mixed_5c'] 69 | for index, endpoint in enumerate(endpoints): 70 | with tf.Graph().as_default(): 71 | inputs = tf.random_uniform((batch_size, num_frames, height, width, 3)) 72 | out_tensor, end_points = s3dg.s3dg_base( 73 | inputs, final_endpoint=endpoint, gating_startat=None) 74 | print(endpoint, out_tensor.op.name) 75 | self.assertTrue(out_tensor.op.name.startswith( 76 | 'InceptionV1/' + endpoint)) 77 | self.assertItemsEqual(endpoints[:index+1], end_points) 78 | 79 | def testBuildAndCheckAllEndPointsUptoMixed5c(self): 80 | batch_size = 5 81 | num_frames = 64 82 | height, width = 224, 224 83 | 84 | inputs = tf.random_uniform((batch_size, num_frames, height, width, 3)) 85 | _, end_points = s3dg.s3dg_base(inputs, 86 | final_endpoint='Mixed_5c') 87 | endpoints_shapes = {'Conv2d_1a_7x7': [5, 32, 112, 112, 64], 88 | 'MaxPool_2a_3x3': [5, 32, 56, 56, 64], 89 | 'Conv2d_2b_1x1': [5, 32, 56, 56, 64], 90 | 'Conv2d_2c_3x3': [5, 32, 56, 56, 192], 91 | 'MaxPool_3a_3x3': [5, 32, 28, 28, 192], 92 | 'Mixed_3b': [5, 32, 28, 28, 256], 93 | 'Mixed_3c': [5, 32, 28, 28, 480], 94 | 'MaxPool_4a_3x3': [5, 16, 14, 14, 480], 95 | 'Mixed_4b': [5, 16, 14, 14, 512], 96 | 'Mixed_4c': [5, 16, 14, 14, 512], 97 | 'Mixed_4d': [5, 16, 14, 14, 512], 98 | 'Mixed_4e': [5, 16, 14, 14, 528], 99 | 'Mixed_4f': [5, 16, 14, 14, 832], 100 | 'MaxPool_5a_2x2': [5, 8, 7, 7, 832], 101 | 'Mixed_5b': [5, 8, 7, 7, 832], 102 | 'Mixed_5c': [5, 8, 7, 7, 1024]} 103 | 104 | 
self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys()) 105 | for endpoint_name, expected_shape in endpoints_shapes.iteritems(): 106 | self.assertTrue(endpoint_name in end_points) 107 | self.assertListEqual(end_points[endpoint_name].get_shape().as_list(), 108 | expected_shape) 109 | 110 | def testHalfSizeImages(self): 111 | batch_size = 5 112 | num_frames = 64 113 | height, width = 112, 112 114 | 115 | inputs = tf.random_uniform((batch_size, num_frames, height, width, 3)) 116 | mixed_5c, _ = s3dg.s3dg_base(inputs) 117 | self.assertTrue(mixed_5c.op.name.startswith('InceptionV1/Mixed_5c')) 118 | self.assertListEqual(mixed_5c.get_shape().as_list(), 119 | [batch_size, 8, 4, 4, 1024]) 120 | 121 | def testTenFrames(self): 122 | batch_size = 5 123 | num_frames = 10 124 | height, width = 224, 224 125 | 126 | inputs = tf.random_uniform((batch_size, num_frames, height, width, 3)) 127 | mixed_5c, _ = s3dg.s3dg_base(inputs) 128 | self.assertTrue(mixed_5c.op.name.startswith('InceptionV1/Mixed_5c')) 129 | self.assertListEqual(mixed_5c.get_shape().as_list(), 130 | [batch_size, 2, 7, 7, 1024]) 131 | 132 | def testEvaluation(self): 133 | batch_size = 2 134 | num_frames = 64 135 | height, width = 224, 224 136 | num_classes = 1000 137 | 138 | eval_inputs = tf.random_uniform((batch_size, num_frames, height, width, 3)) 139 | logits, _ = s3dg.s3dg(eval_inputs, num_classes, 140 | is_training=False) 141 | predictions = tf.argmax(logits, 1) 142 | 143 | with self.test_session() as sess: 144 | sess.run(tf.global_variables_initializer()) 145 | output = sess.run(predictions) 146 | self.assertEquals(output.shape, (batch_size,)) 147 | 148 | 149 | if __name__ == '__main__': 150 | tf.test.main() 151 | -------------------------------------------------------------------------------- /common/scst/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue Mar 28 18:12:55 2017 4 | 5 | @author: jiahuei 6 | """ 7 | 8 | -------------------------------------------------------------------------------- /common/scst/cider_ruotianluo/README.md: -------------------------------------------------------------------------------- 1 | CIDEr for Self-Critical Sequence Training (SCST) 2 | =================== 3 | 4 | This module is based on the repo `ruotianluo/cider` (with modifications). 5 | 6 | Code for Consensus-based Image Description Evaluation. Provides CIDEr as well as 7 | CIDEr-D (CIDEr Defended), which is more robust to gaming effects. 8 | 9 | 10 | ## Important Note 11 | 12 | CIDEr by default (with idf parameter set to "corpus" mode) computes IDF values 13 | using the reference sentences provided. 14 | Thus, the CIDEr score for a reference dataset with only 1 image will be zero. 15 | When evaluating using one (or few) images, set idf to "coco-val-df" instead, 16 | which uses IDF from the MSCOCO Validation Dataset for reliable results. 17 | 18 | To enable the IDF mode "coco-val-df": 19 | 1. Download the [IDF file](https://github.com/ruotianluo/cider/blob/dbb3960165d86202ed3c417b412a000fc8e717f3/data/coco-val.p) 20 | 1. Rename the file to `coco-val-df.p` 21 | 1.
Place the file in `./cider_ruotianluo/data` 22 | 23 | 24 | ## Dependencies 25 | - java 1.8.0 26 | - python 2.7 27 | 28 | 29 | ## References & Acknowledgments 30 | To see the code differences, refer to this fork: 31 | - [[jiahuei/cider]](https://github.com/jiahuei/cider) 32 | 33 | Thanks to the developers of: 34 | - [[ruotianluo/cider]](https://github.com/ruotianluo/cider/tree/dbb3960165d86202ed3c417b412a000fc8e717f3) 35 | 36 | -------------------------------------------------------------------------------- /common/scst/cider_ruotianluo/license.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015, Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 2. Redistributions in binary form must reproduce the above copyright notice, 10 | this list of conditions and the following disclaimer in the documentation 11 | and/or other materials provided with the distribution. 12 | 13 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 14 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 15 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 16 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 17 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 18 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 19 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 20 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 21 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 22 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 23 | 24 | The views and conclusions contained in the software and documentation are those 25 | of the authors and should not be interpreted as representing official policies, 26 | either expressed or implied, of the FreeBSD Project. 
27 | -------------------------------------------------------------------------------- /common/scst/cider_ruotianluo/pyciderevalcap/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /common/scst/cider_ruotianluo/pyciderevalcap/cider/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /common/scst/cider_ruotianluo/pyciderevalcap/cider/cider.py: -------------------------------------------------------------------------------- 1 | # Filename: cider.py 2 | # 3 | # 4 | # Description: Describes the class to compute the CIDEr 5 | # (Consensus-Based Image Description Evaluation) Metric 6 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) 7 | # 8 | # Creation Date: Sun Feb 8 14:16:54 2015 9 | # 10 | # Authors: Ramakrishna Vedantam and 11 | # Tsung-Yi Lin 12 | from __future__ import absolute_import 13 | from __future__ import division 14 | from __future__ import print_function 15 | 16 | from .cider_scorer import CiderScorer 17 | 18 | 19 | class Cider: 20 | """ 21 | Main Class to compute the CIDEr metric 22 | 23 | """ 24 | def __init__(self, n=4, df="corpus"): 25 | """ 26 | Initialize the CIDEr scoring function 27 | : param n (int): n-gram size 28 | : param df (string): specifies where to get the IDF values from 29 | takes values 'corpus', 'coco-train' 30 | : return: None 31 | """ 32 | # set cider to sum over 1 to 4-grams 33 | self._n = n 34 | self._df = df 35 | self.cider_scorer = CiderScorer(n=self._n, df_mode=self._df) 36 | 37 | def compute_score(self, gts, res): 38 | """ 39 | Main function to compute CIDEr score 40 | : param gts (dict) : {image:tokenized reference sentence} 41 | : param res (dict) : {image:tokenized candidate sentence} 42 | : return: cider (float) : computed CIDEr score for the corpus 43 | """ 44 | 45 | # clear all the previous hypos and refs 46 | self.cider_scorer.clear() 47 | 48 | for id in gts: 49 | hypo = res[id] 50 | ref = gts[id] 51 | 52 | # Sanity check. 
53 | assert(type(hypo) is list) 54 | assert(len(hypo) == 1) 55 | assert(type(ref) is list) 56 | assert(len(ref) > 0) 57 | self.cider_scorer += (hypo[0], ref) 58 | 59 | (score, scores) = self.cider_scorer.compute_score() 60 | 61 | return score, scores 62 | 63 | def method(self): 64 | return "CIDEr" 65 | -------------------------------------------------------------------------------- /common/scst/cider_ruotianluo/pyciderevalcap/ciderD/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /common/scst/cider_ruotianluo/pyciderevalcap/ciderD/ciderD.py: -------------------------------------------------------------------------------- 1 | # Filename: ciderD.py 2 | # 3 | # Description: Describes the class to compute the CIDEr-D (Consensus-Based Image Description Evaluation) Metric 4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) 5 | # 6 | # Creation Date: Sun Feb 8 14:16:54 2015 7 | # 8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin 9 | from __future__ import absolute_import 10 | from __future__ import division 11 | from __future__ import print_function 12 | 13 | from .ciderD_scorer import CiderScorer 14 | 15 | 16 | class CiderD: 17 | """ 18 | Main Class to compute the CIDEr metric 19 | 20 | """ 21 | def __init__(self, n=4, sigma=6.0, df="corpus"): 22 | # set cider to sum over 1 to 4-grams 23 | self._n = n 24 | # set the standard deviation parameter for gaussian penalty 25 | self._sigma = sigma 26 | # set which where to compute document frequencies from 27 | self._df = df 28 | self.cider_scorer = CiderScorer(n=self._n, df_mode=self._df) 29 | 30 | def compute_score(self, gts, res): 31 | """ 32 | Main function to compute CIDEr score 33 | :param hypo_for_image (dict) : dictionary with key and value 34 | ref_for_image (dict) : dictionary with key and value 35 | :return: cider (float) : computed CIDEr score for the corpus 36 | """ 37 | 38 | # clear all the previous hypos and refs 39 | tmp_cider_scorer = self.cider_scorer.copy_empty() 40 | tmp_cider_scorer.clear() 41 | 42 | assert(sorted(gts.keys()) == sorted(res.keys())) 43 | for id in gts: 44 | hypo = res[id] 45 | ref = gts[id] 46 | 47 | # Sanity check. 
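            # Same contract as Cider.compute_score above: one tokenized candidate
            # string per image id and a non-empty list of tokenized references.
            # Each (candidate, references) pair is accumulated into the temporary
            # scorer created by copy_empty(), so repeated calls do not share state.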
48 | assert(type(hypo) is list) 49 | assert(len(hypo) == 1) 50 | assert(type(ref) is list) 51 | assert(len(ref) > 0) 52 | tmp_cider_scorer += (hypo[0], ref) 53 | 54 | (score, scores) = tmp_cider_scorer.compute_score() 55 | 56 | return score, scores 57 | 58 | def method(self): 59 | return "CIDEr-D" 60 | -------------------------------------------------------------------------------- /common/scst/cider_ruotianluo/pyciderevalcap/eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | __author__ = 'rama' 6 | 7 | from .tokenizer.ptbtokenizer import PTBTokenizer 8 | from .cider.cider import Cider 9 | from .ciderD.ciderD import CiderD 10 | 11 | 12 | class CIDErEvalCap: 13 | def __init__(self, gts, res, df): 14 | print('tokenization...') 15 | tokenizer = PTBTokenizer('gts') 16 | _gts = tokenizer.tokenize(gts) 17 | print('tokenized refs') 18 | tokenizer = PTBTokenizer('res') 19 | _res = tokenizer.tokenize(res) 20 | print('tokenized cands') 21 | 22 | self.gts = _gts 23 | self.res = _res 24 | self.df = df 25 | 26 | def evaluate(self): 27 | # ================================================= 28 | # Set up scorers 29 | # ================================================= 30 | 31 | print('setting up scorers...') 32 | scorers = [ 33 | (Cider(df=self.df), "CIDEr"), (CiderD(df=self.df), "CIDErD") 34 | ] 35 | 36 | # ================================================= 37 | # Compute scores 38 | # ================================================= 39 | metric_scores = {} 40 | for scorer, method in scorers: 41 | print('computing %s score...' % (scorer.method())) 42 | score, scores = scorer.compute_score(self.gts, self.res) 43 | print("Mean %s score: %0.3f" % (method, score)) 44 | metric_scores[method] = list(scores) 45 | return metric_scores 46 | -------------------------------------------------------------------------------- /common/scst/cider_ruotianluo/pyciderevalcap/tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'hfang' 2 | -------------------------------------------------------------------------------- /common/scst/cider_ruotianluo/pyciderevalcap/tokenizer/ptbtokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : ptbtokenizer.py 4 | # 5 | # Description : Do the PTB Tokenization and remove punctuations. 
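#                  (The wrapper below writes the input sentences to a temporary
#                  file, runs edu.stanford.nlp.process.PTBTokenizer from the
#                  bundled stanford-corenlp-3.4.1.jar via `java -cp`, and then
#                  drops the tokens listed in PUNCTUATIONS from the output.)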
6 | # 7 | # Creation Date : 29-12-2014 8 | # Last Modified : Thu Mar 19 09:53:35 2015 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | import os 15 | import pdb # python debugger 16 | import sys 17 | import subprocess 18 | import re 19 | import tempfile 20 | import itertools 21 | 22 | # path to the stanford corenlp jar 23 | STANFORD_CORENLP_3_4_1_JAR = 'stanford-corenlp-3.4.1.jar' 24 | 25 | # punctuations to be removed from the sentences 26 | PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \ 27 | ".", "?", "!", ",", ":", "-", "--", "...", ";"] 28 | 29 | class PTBTokenizer: 30 | """Python wrapper of Stanford PTBTokenizer""" 31 | def __init__(self, _source='gts'): 32 | self.source = _source 33 | 34 | def tokenize(self, captions_for_image): 35 | cmd = ['java', '-cp', STANFORD_CORENLP_3_4_1_JAR, \ 36 | 'edu.stanford.nlp.process.PTBTokenizer', \ 37 | '-preserveLines', '-lowerCase'] 38 | 39 | # ====================================================== 40 | # prepare data for PTB Tokenizer 41 | # ====================================================== 42 | 43 | if self.source == 'gts': 44 | image_id = [k for k, v in list(captions_for_image.items()) for _ in range(len(v))] 45 | sentences = '\n'.join([c['caption'].replace('\n', ' ') for k, v in list(captions_for_image.items()) for c in v]) 46 | final_tokenized_captions_for_image = {} 47 | 48 | elif self.source == 'res': 49 | index = [i for i, v in enumerate(captions_for_image)] 50 | image_id = [v["image_id"] for v in captions_for_image] 51 | sentences = '\n'.join(v["caption"].replace('\n', ' ') for v in captions_for_image ) 52 | final_tokenized_captions_for_index = [] 53 | 54 | # ====================================================== 55 | # save sentences to temporary file 56 | # ====================================================== 57 | path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__)) 58 | tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname) 59 | tmp_file.write(sentences.encode('utf-8')) 60 | tmp_file.close() 61 | 62 | # ====================================================== 63 | # tokenize sentence 64 | # ====================================================== 65 | cmd.append(os.path.basename(tmp_file.name)) 66 | p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dirname, \ 67 | stdout=subprocess.PIPE) 68 | token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0] 69 | lines = token_lines.decode("utf-8").split('\n') 70 | # remove temp file 71 | os.remove(tmp_file.name) 72 | 73 | # ====================================================== 74 | # create dictionary for tokenized captions 75 | # ====================================================== 76 | if self.source == 'gts': 77 | for k, line in zip(image_id, lines): 78 | if not k in final_tokenized_captions_for_image: 79 | final_tokenized_captions_for_image[k] = [] 80 | tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \ 81 | if w not in PUNCTUATIONS]) 82 | final_tokenized_captions_for_image[k].append(tokenized_caption) 83 | 84 | return final_tokenized_captions_for_image 85 | 86 | elif self.source == 'res': 87 | for k, img, line in zip(index, image_id, lines): 88 | tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \ 89 | if w not in PUNCTUATIONS]) 90 | final_tokenized_captions_for_index.append({'image_id': img, 'caption': [tokenized_caption]}) 91 | 92 | return 
final_tokenized_captions_for_index 93 | -------------------------------------------------------------------------------- /common/scst/cider_ruotianluo/pyciderevalcap/tokenizer/stanford-corenlp-3.4.1.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiahuei/COMIC-Compact-Image-Captioning-with-Attention/73165e0aac2816e89732571814f978801958e1ac/common/scst/cider_ruotianluo/pyciderevalcap/tokenizer/stanford-corenlp-3.4.1.jar -------------------------------------------------------------------------------- /common/scst/cider_ruotianluo/pydataformat/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'rama' 2 | -------------------------------------------------------------------------------- /common/scst/cider_ruotianluo/pydataformat/jsonify_refs.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code to convert mat file with structures into json files 3 | Created on : 5/18/15 3:27 PM by rama 4 | """ 5 | 6 | import scipy.io as io 7 | import os 8 | import re 9 | import json 10 | import string 11 | import pdb 12 | 13 | pathToMat = '/Users/rama/Research/data/pyCider/' 14 | matfile = 'pascal_cands.mat' 15 | jsonfile = 'pascal_cands' 16 | 17 | data = io.loadmat(os.path.join(pathToMat, matfile)) 18 | refs = list(data['cands'][0]) 19 | 20 | A = [] 21 | B = [] 22 | 23 | for image in refs: 24 | for sentences in image[1]: 25 | for i, sent in enumerate(sentences): 26 | sent_struct = {} 27 | imname = str(image[0][0]).split('/')[-1] 28 | sent_struct['image_id'] = imname 29 | string_sent = sent[0].strip().split('\\') 30 | if len(string_sent) == 1: 31 | sent_struct['caption'] = string_sent[0] 32 | else: 33 | sent_struct['caption'] = ' '.join(string_sent[:-1]) 34 | if i == 1: 35 | A.append(sent_struct) 36 | else: 37 | B.append(sent_struct) 38 | 39 | with open(os.path.join(pathToMat, jsonfile + 'A.json'), 'w') as outfile: 40 | json.dump(A, outfile) 41 | 42 | with open(os.path.join(pathToMat, jsonfile + 'B.json'), 'w') as outfile: 43 | json.dump(B, outfile) 44 | -------------------------------------------------------------------------------- /common/scst/cider_ruotianluo/pydataformat/loadData.py: -------------------------------------------------------------------------------- 1 | """ 2 | Load the reference and candidate json files, which are to be evaluated using CIDEr. 3 | 4 | Reference file: list of dict('image_id': image_id, 'caption': caption). 5 | Candidate file: list of dict('image_id': image_id, 'caption': caption). 
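An illustrative entry (the id and caption values here are only placeholders):
    {"image_id": "img_0001.jpg", "caption": "a dog chasing a ball"}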
6 | 7 | """ 8 | import json 9 | import os 10 | from collections import defaultdict 11 | 12 | class LoadData(): 13 | def __init__(self, path): 14 | self.pathToData = path 15 | 16 | def readJson(self, refname, candname): 17 | 18 | path_to_ref_file = os.path.join(self.pathToData, refname) 19 | path_to_cand_file = os.path.join(self.pathToData, candname) 20 | 21 | ref_list = json.loads(open(path_to_ref_file, 'r').read()) 22 | cand_list = json.loads(open(path_to_cand_file, 'r').read()) 23 | 24 | gts = defaultdict(list) 25 | res = [] 26 | 27 | for l in ref_list: 28 | gts[l['image_id']].append({"caption": l['caption']}) 29 | 30 | res = cand_list; 31 | return gts, res 32 | -------------------------------------------------------------------------------- /common/scst/prepro_ngrams.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Mon Apr 1 18:02:06 2019 5 | 6 | @author: jiahuei 7 | 8 | Adapted from `https://github.com/ruotianluo/self-critical.pytorch/blob/master/scripts/prepro_ngrams.py` 9 | 10 | """ 11 | from __future__ import absolute_import 12 | from __future__ import division 13 | from __future__ import print_function 14 | 15 | 16 | import os, json, argparse, time 17 | from six.moves import cPickle as pickle 18 | from collections import defaultdict 19 | from tqdm import tqdm 20 | CURR_DIR = os.path.dirname(os.path.realpath(__file__)) 21 | pjoin = os.path.join 22 | 23 | 24 | def precook(s, n=4, out=False): 25 | """ 26 | Takes a string as input and returns an object that can be given to 27 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 28 | can take string arguments as well. 29 | :param s: string : sentence to be converted into ngrams 30 | :param n: int : number of ngrams for which representation is calculated 31 | :return: term frequency vector for occuring ngrams 32 | """ 33 | words = s.split() 34 | counts = defaultdict(int) 35 | for k in xrange(1,n+1): 36 | for i in xrange(len(words)-k+1): 37 | ngram = tuple(words[i:i+k]) 38 | counts[ngram] += 1 39 | return counts 40 | 41 | 42 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average" 43 | '''Takes a list of reference sentences for a single segment 44 | and returns an object that encapsulates everything that BLEU 45 | needs to know about them. 46 | :param refs: list of string : reference sentences for some image 47 | :param n: int : number of ngrams for which (ngram) representation is calculated 48 | :return: result (list of dict) 49 | ''' 50 | return [precook(ref, n) for ref in refs] 51 | 52 | 53 | def create_crefs(refs): 54 | crefs = [] 55 | for ref in tqdm(refs, ncols=100, desc='create_crefs'): 56 | # ref is a list of 5 captions 57 | crefs.append(cook_refs(ref)) 58 | return crefs 59 | 60 | 61 | def compute_doc_freq(crefs): 62 | ''' 63 | Compute term frequency for reference data. 64 | This will be used to compute idf (inverse document frequency later) 65 | From `cider_scorer.py` in `coco_caption`. 
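    For example, if the bigram ('a', 'dog') occurs in the reference captions of
    three different images, its document frequency is 3.0, no matter how many
    times it appears within any single image's references.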
66 | ''' 67 | document_frequency = defaultdict(float) 68 | for refs in tqdm(crefs, ncols=100, desc='compute_doc_freq'): 69 | # refs, k ref captions of one image 70 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.iteritems()]): 71 | document_frequency[ngram] += 1 72 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 73 | return document_frequency 74 | 75 | 76 | def get_ngrams(refs_words, wtoi, params): 77 | """ 78 | Calculates the n-grams and lengths 79 | """ 80 | #refs_idxs = [] 81 | #for ref_words in tqdm(refs_words, ncols=100, desc='Token-to-idx'): 82 | # # `ref_words` is a list of captions for an image 83 | # ref_idxs = [] 84 | # for caption in ref_words: 85 | # tokens = caption.split(' ') 86 | # idx = [wtoi.get(t, wtoi['']) for t in tokens] 87 | # idx = ' '.join([str(i) for i in idx]) 88 | # ref_idxs.append(idx) 89 | # refs_idxs.append(ref_idxs) 90 | 91 | print('\nINFO: Computing term frequency: word.') 92 | time.sleep(0.1) 93 | ngram_words = compute_doc_freq(create_crefs(refs_words)) 94 | print('\nINFO: Computing term frequency: indices. (SKIPPED)') 95 | #time.sleep(0.1) 96 | #ngram_idxs = compute_doc_freq(create_crefs(refs_idxs)) 97 | ngram_idxs = None 98 | return ngram_words, ngram_idxs, len(refs_words) 99 | 100 | 101 | if __name__ == '__main__': 102 | 103 | parser = argparse.ArgumentParser() 104 | base_dir = os.path.dirname((os.path.dirname(CURR_DIR))) 105 | 106 | parser.add_argument( 107 | '--dataset_dir', type=str, default='', 108 | help='The dataset directory.') 109 | parser.add_argument( 110 | '--dataset_file_pattern', type=str, 111 | default='mscoco_{}_w5_s20_include_restval', 112 | help='The dataset file pattern, example: `mscoco_{}_w5_s20`.') 113 | parser.add_argument( 114 | '--split', type=str, default='train', 115 | help='The split for generating n-grams.') 116 | args = parser.parse_args() 117 | 118 | dataset = args.dataset_file_pattern.split('_')[0] 119 | if args.dataset_dir == '': 120 | args.dataset_dir = pjoin(base_dir, 'datasets', dataset) 121 | 122 | # Data format: filepath,w0 w1 w2 w3 w4 ... 
wN 123 | fp = pjoin(args.dataset_dir, 'captions', 124 | args.dataset_file_pattern.format(args.split)) 125 | with open(fp + '.txt', 'r') as f: 126 | data = [l.strip().split(',') for l in f.readlines()] 127 | data_dict = {} 128 | for d in data: 129 | if d[0] not in data_dict: 130 | data_dict[d[0]] = [] 131 | data_dict[d[0]].append(d[1].replace(' ', '')) 132 | 133 | captions_group = [v for v in data_dict.values()] 134 | assert len(data_dict.keys()) == len(captions_group) 135 | 136 | fp = pjoin(args.dataset_dir, 'captions', 137 | args.dataset_file_pattern.format('wtoi')) 138 | with open(fp + '.json', 'r') as f: 139 | wtoi = json.load(f) 140 | 141 | print('\nINFO: Data reading complete.') 142 | #time.sleep(0.2) 143 | ngram_words, ngram_idxs, ref_len = get_ngrams(captions_group, wtoi, args) 144 | 145 | time.sleep(0.2) 146 | print('\nINFO: Saving output files.') 147 | 148 | fp = pjoin(args.dataset_dir, 'captions', args.dataset_file_pattern) 149 | with open(fp.format('scst-words') + '.p', 'w') as f: 150 | pickle.dump({'document_frequency': ngram_words, 'ref_len': ref_len}, 151 | f, pickle.HIGHEST_PROTOCOL) 152 | #with open(fp.format('scst-idxs') + '.p', 'w') as f: 153 | # pickle.dump({'document_frequency': ngram_idxs, 'ref_len': ref_len}, 154 | # f, pickle.HIGHEST_PROTOCOL) 155 | 156 | print('\nINFO: Completed.') 157 | 158 | 159 | -------------------------------------------------------------------------------- /common/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Jan 9 23:38:59 2019 4 | 5 | @author: jiahuei 6 | 7 | Utility functions. 8 | 9 | """ 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | import os, math, time 15 | import requests 16 | import tarfile 17 | from zipfile import ZipFile 18 | from tqdm import tqdm 19 | #from PIL import Image 20 | #Image.MAX_IMAGE_PIXELS = None 21 | # By default, PIL limit is around 89 Mpix (~ 9459 ** 2) 22 | 23 | 24 | _EXT = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG'] 25 | pjoin = os.path.join 26 | 27 | try: 28 | from natsort import natsorted, ns 29 | except ImportError: 30 | natsorted = None 31 | 32 | 33 | def maybe_download_from_url(url, dest_dir, wget=True, file_size=None): 34 | """ 35 | Downloads file from URL, streaming large files. 36 | """ 37 | fname = url.split('/')[-1] 38 | if not os.path.exists(dest_dir): 39 | os.makedirs(dest_dir) 40 | fpath = pjoin(dest_dir, fname) 41 | if os.path.isfile(fpath): 42 | print('INFO: Found file `{}`'.format(fname)) 43 | return fpath 44 | if wget: 45 | import subprocess 46 | subprocess.call(['wget', url], cwd=dest_dir) 47 | else: 48 | import requests 49 | response = requests.get(url, stream=True) 50 | chunk_size = 1024 ** 2 # 1 MB 51 | if response.ok: 52 | print('INFO: Downloading `{}`'.format(fname)) 53 | else: 54 | print('ERROR: Download error. Server response: {}'.format(response)) 55 | return False 56 | time.sleep(0.2) 57 | 58 | # Case-insensitive Dictionary of Response Headers. 59 | # The length of the request body in octets (8-bit bytes). 
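        # The Content-Length header may be absent (e.g. chunked transfer
        # encoding); if the lookup below fails, file_size keeps its argument
        # value (None by default) and the tqdm progress bar runs without a
        # fixed total.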
60 | try: 61 | file_size = int(response.headers['Content-Length']) 62 | except: 63 | pass 64 | if file_size is None: 65 | num_iters = None 66 | else: 67 | num_iters = math.ceil(file_size / chunk_size) 68 | tqdm_kwargs = dict(desc = 'Download progress', 69 | total = num_iters, 70 | unit = 'MB') 71 | with open(fpath, 'wb') as handle: 72 | for chunk in tqdm(response.iter_content(chunk_size), **tqdm_kwargs): 73 | if not chunk: break 74 | handle.write(chunk) 75 | print('INFO: Download complete: `{}`'.format(fname)) 76 | return fpath 77 | 78 | 79 | def maybe_download_from_google_drive(id, fpath, file_size=None): 80 | URL = 'https://docs.google.com/uc?export=download' 81 | chunk_size = 1024 ** 2 # 1 MB 82 | fname = os.path.basename(fpath) 83 | out_path = os.path.split(fpath)[0] 84 | if not os.path.exists(out_path): 85 | os.makedirs(out_path) 86 | if os.path.isfile(fpath): 87 | print('INFO: Found file `{}`'.format(fname)) 88 | return fpath 89 | print('INFO: Downloading `{}`'.format(fname)) 90 | 91 | session = requests.Session() 92 | response = session.get(URL, params = { 'id' : id }, stream = True) 93 | token = get_confirm_token(response) 94 | 95 | if token: 96 | params = { 'id' : id, 'confirm' : token } 97 | response = session.get(URL, params = params, stream = True) 98 | 99 | if file_size is not None: 100 | num_iters = math.ceil(file_size / chunk_size) 101 | else: 102 | num_iters = None 103 | tqdm_kwargs = dict(desc = 'Download progress', 104 | total = num_iters, 105 | unit = 'MB') 106 | with open(fpath, 'wb') as handle: 107 | for chunk in tqdm(response.iter_content(chunk_size), **tqdm_kwargs): 108 | if not chunk: break 109 | handle.write(chunk) 110 | print('INFO: Download complete: `{}`'.format(fname)) 111 | return fpath 112 | 113 | 114 | def get_confirm_token(response): 115 | for key, value in response.cookies.items(): 116 | if key.startswith('download_warning'): 117 | return value 118 | return None 119 | 120 | 121 | def extract_tar_gz(fpath): 122 | """ 123 | Extracts tar.gz file into the containing directory. 124 | """ 125 | tar = tarfile.open(fpath, 'r') 126 | members = tar.getmembers() 127 | opath = os.path.split(fpath)[0] 128 | for m in tqdm(iterable=members, 129 | total=len(members), 130 | desc='Extracting `{}`'.format(os.path.split(fpath)[1])): 131 | tar.extract(member=m, path=opath) 132 | #tar.extractall(path=opath, members=progress(tar)) # members=None to extract all 133 | tar.close() 134 | 135 | 136 | def extract_zip(fpath): 137 | """ 138 | Extracts zip file into the containing directory. 139 | """ 140 | with ZipFile(fpath, 'r') as zip_ref: 141 | for m in tqdm( 142 | iterable=zip_ref.namelist(), 143 | total=len(zip_ref.namelist()), 144 | desc='Extracting `{}`'.format(os.path.split(fpath)[1])): 145 | zip_ref.extract(member=m, path=os.path.split(fpath)[0]) 146 | #zip_ref.extractall(os.path.split(fpath)[0]) 147 | 148 | 149 | def maybe_get_ckpt_file(net_params, remove_tar=True): 150 | """ 151 | Download, extract, remove. 
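    A minimal sketch of the expected `net_params` mapping (the keys are taken
    from the code below; the path and URL values here are only placeholders):
        net_params = {'ckpt_path': '/path/to/inception_v1.ckpt',
                      'url': 'http://example.com/inception_v1.tar.gz'}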
152 | """ 153 | if os.path.isfile(net_params['ckpt_path']): 154 | pass 155 | else: 156 | url = net_params['url'] 157 | tar_gz_path = maybe_download_from_url( 158 | url, os.path.split(net_params['ckpt_path'])[0]) 159 | extract_tar_gz(tar_gz_path) 160 | if remove_tar: os.remove(tar_gz_path) 161 | 162 | 163 | 164 | 165 | 166 | 167 | -------------------------------------------------------------------------------- /datasets/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue Mar 28 18:12:55 2017 4 | 5 | @author: jiahuei 6 | """ 7 | 8 | -------------------------------------------------------------------------------- /datasets/preprocessing/coco_prepro.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Tue Feb 14 18:34:56 2017 5 | 6 | @author: jiahuei 7 | 8 | V8 9 | """ 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | import os, sys, json, argparse 15 | import prepro_base as prepro 16 | CURR_DIR = os.path.dirname(os.path.realpath(__file__)) 17 | sys.path.append(os.path.join(CURR_DIR, '..', '..', 'common')) 18 | import utils 19 | pjoin = os.path.join 20 | 21 | 22 | #wtoi_file = 'coco_wtoi_w5_s20_include_restval.json' 23 | #itow_file = 'coco_itow_w5_s20_include_restval.json' 24 | 25 | 26 | def create_parser(): 27 | parser = argparse.ArgumentParser( 28 | formatter_class=argparse.RawDescriptionHelpFormatter) 29 | parser.add_argument( 30 | '--dataset_dir', type=str, default='') 31 | parser.add_argument( 32 | '--output_prefix', type=str, default='mscoco') 33 | parser.add_argument( 34 | '--retokenise', type=bool, default=False) 35 | parser.add_argument( 36 | '--include_restval', type=bool, default=True) 37 | parser.add_argument( 38 | '--word_count_thres', type=int, default=5) 39 | parser.add_argument( 40 | '--caption_len_thres', type=int, default=20) 41 | parser.add_argument( 42 | '--pad_value', type=int, default=-1) 43 | parser.add_argument( 44 | '--wtoi_file', type=str, default=None) 45 | parser.add_argument( 46 | '--itow_file', type=str, default=None) 47 | 48 | return parser 49 | 50 | if __name__ == '__main__': 51 | 52 | parser = create_parser() 53 | args = parser.parse_args() 54 | 55 | if args.dataset_dir == '': 56 | dset_dir = pjoin(os.path.dirname(CURR_DIR), 'mscoco') 57 | else: 58 | dset_dir = args.dataset_dir 59 | out_path = pjoin(dset_dir, 'captions') 60 | json_path = pjoin(dset_dir, 'dataset_coco.json') 61 | 62 | ### Get the caption JSON files ### 63 | if os.path.isfile(json_path): 64 | print('INFO: Found file: `dataset_coco.json`') 65 | else: 66 | zip_path = utils.maybe_download_from_url( 67 | r'https://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip', 68 | dset_dir) 69 | utils.extract_zip(zip_path) 70 | os.remove(zip_path) 71 | 72 | #if os.path.isfile(pjoin(dset_dir, 'annotations', 'captions_val2014.json')): 73 | # print('INFO: Found file: `captions_val2014.json`') 74 | #else: 75 | # zip_path = utils.maybe_download_from_url( 76 | # r'http://images.cocodataset.org/annotations/annotations_trainval2014.zip', 77 | # dset_dir) 78 | # utils.extract_zip(zip_path) 79 | # os.remove(zip_path) 80 | 81 | ### Read the raw JSON file ### 82 | 83 | with open(json_path, 'r') as f: 84 | dataset_coco = json.load(f) 85 | 86 | ### Tokenise captions ### 87 | 88 | tokenised_coco = 
prepro.tokenise(dataset_coco, 89 | image_id_key='imgid', 90 | retokenise=args.retokenise) 91 | 92 | ### Build vocabulary ### 93 | 94 | build_vocab = args.wtoi_file is None or args.itow_file is None 95 | if build_vocab: 96 | wtoi, itow = prepro.build_vocab(tokenised_coco, 97 | args.word_count_thres, 98 | args.caption_len_thres, 99 | vocab_size=None, 100 | include_restval=args.include_restval, 101 | pad_value=args.pad_value, 102 | include_GO_EOS_tokens=True) 103 | else: 104 | print('INFO: Reusing provided vocabulary.\n') 105 | with open(os.path.join(out_path, args.wtoi_file), 'r') as f: 106 | wtoi = json.load(f) 107 | with open(os.path.join(out_path, args.itow_file), 'r') as f: 108 | itow = json.load(f) 109 | 110 | ### Convert tokenised words to text files ### 111 | 112 | tokenised_coco = prepro.tokenised_word_to_txt_V1(tokenised_coco, 113 | args.caption_len_thres, 114 | args.include_restval) 115 | 116 | print('\nINFO: Example captions:') 117 | for j in range(5): 118 | print(tokenised_coco['train'][j]) 119 | print('\n') 120 | 121 | ### Output files ### 122 | 123 | if args.output_prefix is not None: 124 | if not os.path.exists(out_path): 125 | os.makedirs(out_path) 126 | suffix = [] 127 | suffix.append('w{:d}_s{:d}'.format( 128 | args.word_count_thres, args.caption_len_thres)) 129 | if args.include_restval: 130 | suffix.append('include_restval') 131 | if args.retokenise: 132 | suffix.append('retokenised') 133 | suffix = '_'.join(suffix) 134 | 135 | for split in tokenised_coco.keys(): 136 | filename = '{}_{}_{}.txt'.format(args.output_prefix, split, suffix) 137 | with open(pjoin(out_path, filename), 'w') as f: 138 | f.write('\r\n'.join(tokenised_coco[split])) 139 | 140 | # Assert no overlaps between sets 141 | train_set = set([s.split(',')[0] for s in tokenised_coco['train']]) 142 | valid_set = set([s.split(',')[0] for s in tokenised_coco['valid']]) 143 | test_set = set([s.split(',')[0] for s in tokenised_coco['test']]) 144 | assert not bool(train_set.intersection(valid_set)) 145 | assert not bool(train_set.intersection(test_set)) 146 | assert not bool(valid_set.intersection(test_set)) 147 | 148 | # Write validation file list 149 | with open(pjoin(out_path, 'filenames_valid.txt'), 'w') as f: 150 | f.write('\r\n'.join(list(valid_set))) 151 | 152 | # Write test file list 153 | with open(pjoin(out_path, 'filenames_test.txt'), 'w') as f: 154 | f.write('\r\n'.join(list(test_set))) 155 | 156 | if build_vocab: 157 | with open('%s/%s_wtoi_%s.json' % 158 | (out_path, args.output_prefix, suffix), 'w') as f: 159 | json.dump(wtoi, f) 160 | with open('%s/%s_itow_%s.json' % 161 | (out_path, args.output_prefix, suffix), 'w') as f: 162 | json.dump(itow, f) 163 | 164 | print('INFO: Saved output text files.\n') 165 | 166 | 167 | ### Get the image files ### 168 | print('INFO: Checking image files.') 169 | img_all = train_set.union(valid_set).union(test_set) 170 | tpath = pjoin(dset_dir, 'train2014') 171 | vpath = pjoin(dset_dir, 'val2014') 172 | ext = exv = [] 173 | if os.path.exists(tpath): 174 | ext = os.listdir(tpath) 175 | ext = [pjoin('train2014', i) for i in ext] 176 | if os.path.exists(vpath): 177 | exv = os.listdir(vpath) 178 | exv = [pjoin('val2014', i) for i in exv] 179 | ex = set(ext + exv) 180 | img_exists = len(ex.intersection(img_all)) == len(img_all) 181 | 182 | if img_exists: 183 | print('INFO: Found exising image files.') 184 | else: 185 | zip_path = utils.maybe_download_from_url( 186 | r'http://images.cocodataset.org/zips/train2014.zip', 187 | dset_dir) 188 | utils.extract_zip(zip_path) 189 | 
os.remove(zip_path) 190 | zip_path = utils.maybe_download_from_url( 191 | r'http://images.cocodataset.org/zips/val2014.zip', 192 | dset_dir) 193 | utils.extract_zip(zip_path) 194 | os.remove(zip_path) 195 | zip_path = utils.maybe_download_from_url( 196 | r'http://images.cocodataset.org/zips/test2014.zip', 197 | dset_dir) 198 | utils.extract_zip(zip_path) 199 | os.remove(zip_path) 200 | -------------------------------------------------------------------------------- /datasets/preprocessing/ptb_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : ptbtokenizer.py 4 | # 5 | # Description : Do the PTB Tokenization and remove punctuations. 6 | # 7 | # Creation Date : 29-12-2014 8 | # Last Modified : Thu Mar 19 09:53:35 2015 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | import os 12 | #import sys 13 | import subprocess 14 | import tempfile 15 | import re 16 | #import itertools 17 | 18 | # path to the stanford corenlp jar 19 | STANFORD_CORENLP_3_4_1_JAR = 'stanford-corenlp-3.4.1.jar' 20 | 21 | # punctuations to be removed from the sentences 22 | PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \ 23 | ".", "?", "!", ",", ":", "-", "--", "...", ";"] 24 | 25 | class PTBTokenizer: 26 | """ 27 | Python wrapper of Stanford PTBTokenizer. 28 | 29 | The tokeniser will split " xxx's " into " xxx 's ". 30 | Thus, words like " chef's " (one word) 31 | will become " chef 's " (two words). 32 | """ 33 | 34 | def tokenize(self, sentences_list, remove_non_alphanumerics=False): 35 | pattern = re.compile(r'([^\s\w]|_)+', re.UNICODE) # matches non-alphanumerics 36 | #pattern = re.compile(r'([^\w]|_)+', re.UNICODE) # matches non-alphanumerics and whitespaces 37 | cmd = ['java', '-cp', STANFORD_CORENLP_3_4_1_JAR, \ 38 | 'edu.stanford.nlp.process.PTBTokenizer', \ 39 | '-preserveLines', '-lowerCase'] 40 | 41 | # ====================================================== 42 | # prepare data for PTB Tokenizer 43 | # ====================================================== 44 | sentences = '\n'.join([s.replace('\n', ' ') for s in sentences_list]) 45 | 46 | # ====================================================== 47 | # save sentences to temporary file 48 | # ====================================================== 49 | path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__)) 50 | tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname) 51 | tmp_file.write(sentences) 52 | tmp_file.close() 53 | 54 | # ====================================================== 55 | # tokenize sentence 56 | # ====================================================== 57 | cmd.append(os.path.basename(tmp_file.name)) 58 | p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dirname, \ 59 | stdout=subprocess.PIPE) 60 | token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0] 61 | tokenized_caption_w_punc = token_lines.split('\n') 62 | # remove temp file 63 | os.remove(tmp_file.name) 64 | 65 | # ====================================================== 66 | # remove punctuations and optionally non-alphanumerics 67 | # ====================================================== 68 | tokenized_caption = [] 69 | if remove_non_alphanumerics: 70 | for line in tokenized_caption_w_punc: 71 | line = ' '.join([w for w in line.rstrip().split(' ') \ 72 | if w not in PUNCTUATIONS]) 73 | line = re.sub(pattern, '', line) 74 | tokenized_caption.append(line) 75 | else: 76 | for line in tokenized_caption_w_punc: 77 | 
tokenized_caption.append(' '.join([w for w in line.rstrip().split(' ') \ 78 | if w not in PUNCTUATIONS])) 79 | 80 | return tokenized_caption, tokenized_caption_w_punc 81 | -------------------------------------------------------------------------------- /src/example.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 3 | 4 | cd ${DIR} 5 | 6 | ### Training 7 | # Default 8 | python train.py 9 | 10 | # Custom MS-COCO directory 11 | python train.py \ 12 | --dataset_dir '/home/jiahuei/Documents/3_Datasets/MSCOCO_captions' 13 | 14 | # Word token, custom MS-COCO directory, GPU 1 15 | python train.py \ 16 | --token_type 'word' \ 17 | --dataset_dir '/home/jiahuei/Documents/3_Datasets/MSCOCO_captions' \ 18 | --gpu '1' 19 | 20 | # InstaPIC 21 | python train.py \ 22 | --dataset_file_pattern 'insta_{}_v25595_s15' \ 23 | --batch_size_eval 50 24 | 25 | # Custom InstaPIC directory 26 | python train.py \ 27 | --dataset_file_pattern 'insta_{}_v25595_s15' \ 28 | --dataset_dir '/home/jiahuei/Documents/3_Datasets/InstaPIC' \ 29 | --batch_size_eval 50 30 | 31 | 32 | ### Inference 33 | # Default dataset and checkpoint directories (MSCOCO, COMIC-256) 34 | python infer.py 35 | 36 | # Custom dataset and checkpoint directories 37 | python infer.py \ 38 | --infer_checkpoints_dir 'mscoco/word_add_softmax_h8_tie_lstm_run_01' \ 39 | --dataset_dir '/home/jiahuei/Documents/3_Datasets/MSCOCO_captions' \ 40 | --gpu '1' 41 | 42 | # InstaPIC 43 | python infer.py \ 44 | --infer_checkpoints_dir 'insta/word_add_softmax_h8_tie_lstm_run_01' \ 45 | --dataset_dir '/home/jiahuei/Documents/3_Datasets/InstaPIC' \ 46 | --annotations_file 'insta_testval_clean.json' 47 | 48 | # Custom InstaPIC directory 49 | python infer.py \ 50 | --infer_checkpoints_dir 'insta/word_add_softmax_h8_tie_lstm_run_01' \ 51 | --dataset_dir '/home/jiahuei/Documents/3_Datasets/InstaPIC' \ 52 | --annotations_file 'insta_testval_clean.json' 53 | 54 | 55 | 56 | -------------------------------------------------------------------------------- /src/infer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Mon Aug 28 17:53:58 2017 5 | 6 | @author: jiahuei 7 | """ 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import os, sys, argparse 13 | CURR_DIR = os.path.dirname(os.path.realpath(__file__)) 14 | BASE_DIR = os.path.dirname(CURR_DIR) 15 | sys.path.append(os.path.join(BASE_DIR, 'common')) 16 | import infer_fn as infer 17 | import configuration as conf 18 | #import ops 19 | from natural_sort import natural_keys as nat_key 20 | pjoin = os.path.join 21 | 22 | 23 | def create_parser(): 24 | parser = argparse.ArgumentParser( 25 | formatter_class=argparse.RawDescriptionHelpFormatter) 26 | 27 | parser.add_argument( 28 | '--infer_set', type=str, default='test', 29 | choices=['test', 'valid', 'coco_test', 'coco_valid'], 30 | help='The split to perform inference on.') 31 | parser.add_argument( 32 | '--infer_checkpoints_dir', type=str, 33 | default=pjoin('mscoco', 'radix_b256_add_LN_softmax_h8_tie_lstm_run_01'), 34 | help='The directory containing the checkpoint files.') 35 | parser.add_argument( 36 | '--infer_checkpoints', type=str, default='all', 37 | help='The checkpoint numbers to be evaluated. 
Comma-separated.') 38 | parser.add_argument( 39 | '--annotations_file', type=str, default='captions_val2014.json', 40 | help='The annotations / reference file for calculating scores.') 41 | parser.add_argument( 42 | '--dataset_dir', type=str, 43 | default=pjoin(BASE_DIR, 'datasets', 'mscoco'), 44 | help='Dataset directory.') 45 | parser.add_argument( 46 | '--run_inference', type=bool, default=True, 47 | help='Whether to perform inference.') 48 | parser.add_argument( 49 | '--get_metric_score', type=bool, default=True, 50 | help='Whether to perform metric score calculations.') 51 | parser.add_argument( 52 | '--save_attention_maps', type=bool, default=False, 53 | help='Whether to save attention maps to disk as pickle file.') 54 | parser.add_argument( 55 | '--gpu', type=str, default='0', 56 | help='The gpu number.') 57 | parser.add_argument( 58 | '--per_process_gpu_memory_fraction', type=float, default=0.75, 59 | help='The fraction of GPU memory allocated.') 60 | 61 | parser.add_argument( 62 | '--infer_beam_size', type=int, default=3, 63 | help='The beam size.') 64 | parser.add_argument( 65 | '--infer_length_penalty_weight', type=float, default=0.0, 66 | help='The length penalty weight used in beam search.') 67 | parser.add_argument( 68 | '--infer_max_length', type=int, default=30, 69 | help='The maximum caption length allowed during inference.') 70 | parser.add_argument( 71 | '--batch_size_infer', type=int, default=25, 72 | help='The batch size.') 73 | 74 | return parser 75 | 76 | 77 | if __name__ == '__main__': 78 | ckpt_prefix = 'model_compact-' 79 | 80 | parser = create_parser() 81 | args = parser.parse_args() 82 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 83 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 84 | default_exp_dir = pjoin(os.path.dirname(CURR_DIR), 'experiments') 85 | args.infer_checkpoints_dir = pjoin(default_exp_dir, args.infer_checkpoints_dir) 86 | 87 | args.annotations_file = pjoin( 88 | BASE_DIR, 'common', 'coco_caption', 'annotations', args.annotations_file) 89 | 90 | if args.infer_checkpoints == 'all': 91 | files = sorted(os.listdir(args.infer_checkpoints_dir), key=nat_key) 92 | files = [f for f in files if ckpt_prefix in f] 93 | files = [f.replace('.index', '') for f in files if '.index' in f] 94 | files = [f.replace(ckpt_prefix, '') for f in files] 95 | if len(files) > 20: 96 | files = files[-12:] 97 | args.infer_checkpoints = files 98 | else: 99 | args.infer_checkpoints = args.infer_checkpoints.split(',') 100 | if len(args.infer_checkpoints) < 1: 101 | raise ValueError('`infer_checkpoints` must be either `all` or ' 102 | 'a list of comma-separated checkpoint numbers.') 103 | 104 | ### 105 | 106 | c = conf.load_config(pjoin(args.infer_checkpoints_dir, 'config.pkl')) 107 | c.__dict__.update(args.__dict__) 108 | ckpt_dir = c.infer_checkpoints_dir 109 | 110 | save_name = 'beam_{}_lpen_{}'.format( 111 | c.infer_beam_size, c.infer_length_penalty_weight) 112 | if c.infer_set == 'test': 113 | save_name = 'infer_test_' + save_name 114 | elif c.infer_set == 'valid': 115 | save_name = 'infer_valid_' + save_name 116 | elif c.infer_set == 'coco_test': 117 | save_name = 'infer_cocoTest_' + save_name 118 | elif c.infer_set == 'coco_valid': 119 | save_name = 'infer_cocoValid_' + save_name 120 | 121 | c.infer_save_path = pjoin(ckpt_dir, save_name) 122 | 123 | ############################################################################### 124 | 125 | if os.path.exists(c.infer_save_path): 126 | print('\nINFO: `eval_log_path` already exists.') 127 | else: 128 | print('\nINFO: 
`eval_log_path` will be created.') 129 | os.mkdir(c.infer_save_path) 130 | 131 | # Loop through the checkpoint files 132 | scores_combined = {} 133 | for ckpt_num in c.infer_checkpoints: 134 | curr_ckpt_path = pjoin(ckpt_dir, ckpt_prefix + ckpt_num) 135 | infer.evaluate_model( 136 | config=c, 137 | curr_ckpt_path=curr_ckpt_path, 138 | scores_combined=scores_combined) 139 | print('\n') 140 | 141 | 142 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Aug 23 17:12:32 2017 4 | 5 | @author: jiahuei 6 | 7 | """ 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import sys, os 13 | import tensorflow as tf 14 | from model_base import ModelBase 15 | CURR_DIR = os.path.dirname(os.path.realpath(__file__)) 16 | sys.path.append(os.path.join(CURR_DIR, '..', 'common')) 17 | import ops 18 | _shape = ops.shape 19 | 20 | 21 | class CaptionModel(ModelBase): 22 | 23 | def __init__(self, 24 | config, 25 | mode, 26 | batch_ops=None, 27 | reuse=False, 28 | name=None): 29 | assert mode in ['train', 'eval', 'infer'] 30 | print('INFO: Building graph for: {}'.format(name)) 31 | super(CaptionModel, self).__init__(config) 32 | self.mode = mode 33 | self.batch_ops = batch_ops 34 | self.reuse = reuse 35 | self.name = name 36 | self._batch_size = _shape(self.batch_ops[0])[0] 37 | 38 | # Start to build the model 39 | c = self._config 40 | is_inference = self.mode == 'infer' 41 | vs_kwargs = dict(reuse=tf.AUTO_REUSE, 42 | initializer=self._get_initialiser()) 43 | 44 | if self.is_training(): 45 | self._create_gstep() 46 | if c.legacy: 47 | self._create_lr() 48 | else: 49 | self._create_cosine_lr(c.max_step) 50 | 51 | with tf.variable_scope('Model', **vs_kwargs): 52 | self._process_inputs() 53 | with tf.variable_scope('encoder'): 54 | self._encoder() 55 | with tf.variable_scope('decoder'): 56 | self._decoder_rnn() 57 | 58 | # We place the optimisation graph out of 'Model' scope 59 | self._train_caption_model() 60 | 61 | if is_inference: 62 | attention_maps = self.dec_attn_maps 63 | if attention_maps is None: 64 | self.infer_output = [self.dec_preds, tf.zeros([])] 65 | else: 66 | self.infer_output = [self.dec_preds, attention_maps] 67 | return None 68 | 69 | # Log softmax temperature value 70 | t = tf.get_collection('softmax_temperatures') 71 | if len(t) > 0: tf.summary.scalar('softmax_temperature', t[0]) 72 | self.summary_op = tf.summary.merge_all() 73 | print('INFO: Model `{}` initialisation complete.'.format(mode)) 74 | 75 | 76 | class CaptionModel_SCST(ModelBase): 77 | 78 | def __init__(self, 79 | config, 80 | scst_mode, 81 | reuse=False): 82 | assert scst_mode in ['train', 'sample'] 83 | #assert config.token_type == 'word' 84 | 85 | print('INFO: Building graph for: {}'.format(scst_mode)) 86 | super(CaptionModel_SCST, self).__init__(config) 87 | self.mode = scst_mode if scst_mode == 'train' else 'infer' 88 | c = self._config 89 | batch_size = c.batch_size_train 90 | if self.is_training(): 91 | batch_size *= (c.scst_beam_size +0) 92 | im_size = c.cnn_input_size 93 | self.imgs = tf.placeholder( 94 | dtype=tf.float32, 95 | shape=[batch_size, im_size[0], im_size[1], 3]) 96 | self.captions = tf.placeholder_with_default( 97 | input=tf.zeros(shape=[batch_size, 1], dtype=tf.int32), 98 | shape=[batch_size, None]) 99 | self.rewards = tf.placeholder(dtype=tf.float32, 
shape=[batch_size]) 100 | self.batch_ops = [self.imgs, self.captions] 101 | self.reuse = reuse 102 | self.name = scst_mode 103 | self._batch_size = _shape(self.batch_ops[0])[0] 104 | 105 | # Start to build the model 106 | vs_kwargs = dict(reuse=tf.AUTO_REUSE, 107 | initializer=self._get_initialiser()) 108 | 109 | if self.is_training(): 110 | self._create_gstep() 111 | self._create_cosine_lr(c.max_step) 112 | 113 | with tf.variable_scope('Model', **vs_kwargs): 114 | self._process_inputs() 115 | with tf.variable_scope('encoder'): 116 | self._encoder() 117 | with tf.variable_scope('decoder'): 118 | if self.is_training(): 119 | self._decoder_rnn_scst() 120 | else: 121 | with tf.name_scope('greedy'): 122 | self._decoder_rnn_scst(1) 123 | self.dec_preds_greedy = self.dec_preds 124 | with tf.name_scope('beam'): 125 | self._decoder_rnn_scst(c.scst_beam_size) 126 | self.dec_preds_beam = self.dec_preds 127 | #with tf.name_scope('sample'): 128 | # self._decoder_rnn_scst(0) 129 | # self.dec_preds_sample = self.dec_preds 130 | 131 | # Generated captions can be obtained by calling self.dec_preds 132 | 133 | # We place the optimisation graph out of 'Model' scope 134 | self.train_scst = self._train_caption_model(scst=True) 135 | 136 | 137 | # Log softmax temperature value 138 | t = tf.get_collection('softmax_temperatures') 139 | if len(t) > 0: tf.summary.scalar('softmax_temperature', t[0]) 140 | self.summary_op = tf.summary.merge_all() 141 | print('INFO: Model `{}` initialisation complete.'.format(scst_mode)) 142 | 143 | 144 | -------------------------------------------------------------------------------- /src/setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 3 | 4 | 5 | cd ${DIR}/../common/coco_caption 6 | printf "\nSetting up Stanford CoreNLP for SPICE ...\n" 7 | bash get_stanford_models.sh 8 | 9 | 10 | cd ${DIR}/../datasets/preprocessing 11 | printf "\nRunning pre-processing script for MS-COCO ...\n" 12 | python coco_prepro.py --dataset_dir '' 13 | 14 | printf "\nRunning pre-processing script for InstaPIC-1.1M ...\n" 15 | python insta_prepro.py --dataset_dir '' 16 | 17 | 18 | cd ${DIR}/../common/scst 19 | printf "\nRunning pre-processing script for SCST (MS-COCO) ...\n" 20 | python prepro_ngrams.py --dataset_dir '' 21 | 22 | printf "\nRunning pre-processing script for SCST (InstaPIC-1.1M) ...\n" 23 | python prepro_ngrams.py --dataset_dir '' --dataset_file_pattern 'insta_{}_v25595_s15' 24 | 25 | 26 | printf "\nSetup complete.\n" 27 | --------------------------------------------------------------------------------