├── .DS_Store ├── .idea ├── .gitignore ├── TextKG.iml ├── deployment.xml ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── misc.xml └── modules.xml ├── README.md ├── densevid_eval ├── .DS_Store ├── LICENCE ├── README.md ├── __init__.py ├── anet_data │ ├── anet_entities_test_1.json │ ├── anet_entities_test_1_para.json │ ├── anet_entities_test_2.json │ ├── anet_entities_test_2_para.json │ ├── anet_entities_val_1.json │ ├── anet_entities_val_1_para.json │ ├── anet_entities_val_2.json │ ├── anet_entities_val_2_para.json │ ├── readme.txt │ └── train.json ├── coco-caption │ ├── .gitignore │ ├── README.md │ ├── __init__.py │ ├── annotations │ │ └── captions_val2014.json │ ├── cocoEvalCapDemo.ipynb │ ├── license.txt │ ├── pycocoevalcap │ │ ├── __init__.py │ │ ├── bleu │ │ │ ├── LICENSE │ │ │ ├── __init__.py │ │ │ ├── bleu.py │ │ │ └── bleu_scorer.py │ │ ├── cider │ │ │ ├── __init__.py │ │ │ ├── cider.py │ │ │ └── cider_scorer.py │ │ ├── eval.py │ │ ├── meteor │ │ │ ├── __init__.py │ │ │ └── meteor.py │ │ ├── rouge │ │ │ ├── __init__.py │ │ │ └── rouge.py │ │ └── tokenizer │ │ │ ├── __init__.py │ │ │ └── ptbtokenizer.py │ └── pycocotools │ │ ├── __init__.py │ │ └── coco.py ├── evaluate.py ├── evaluateCaptionsDiversity.py ├── evaluateRepetition.py ├── get_caption_stat.py ├── merge_dicts_by_prefix.py ├── para-evaluate.py └── yc2_data │ ├── train_list.txt │ ├── val_list.txt │ ├── val_yc2.json │ ├── yc2_annotations_trainval.json │ ├── yc2_duration_frame.csv │ ├── yc2_train_anet_format.json │ ├── yc2_val_anet_format.json │ └── yc2_val_anet_format_para.json ├── figures └── network.png └── src ├── __init__.py ├── build_vocab.py ├── caption_eval.py ├── rtransformer ├── __init__.py ├── beam_search.py ├── decode_strategy.py ├── masked_transformer.py ├── model.py ├── optimization.py └── recursive_caption_dataset.py ├── train.py ├── translate.py ├── translator.py ├── utils ├── __init__.py ├── checkpoint.py ├── json.py ├── logging.py ├── register.py ├── train_utils.py └── writer.py └── utils_func.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GX77/TextKG/2dec9157c1ebcdfecd0e3d1a1b769c195c4d9245/.DS_Store -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # 默认忽略的文件 2 | /shelf/ 3 | /workspace.xml 4 | # Datasource local storage ignored files 5 | /dataSources/ 6 | /dataSources.local.xml 7 | # 基于编辑器的 HTTP 客户端请求 8 | /httpRequests/ 9 | -------------------------------------------------------------------------------- /.idea/TextKG.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 15 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 98 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TextKG 2 | 3 | Text with Knowledge Graph Augmented Transformer for Video Captioning 4 | 5 | [[Paper]](https://arxiv.org/abs/2303.12423) 6 | 7 | Official code for **Text with Knowledge Graph Augmented Transformer for Video Captioning**.
8 | 9 | *Xin Gu, Guang Chen, Yufei Wang, Libo Zhang, Tiejian Luo, Longyin Wen* 10 | 11 | Accepted by CVPR2023
12 | 13 | ## Introduction 14 | Existing video captioning methods generally have long tail problems. We present TextKG, a knowledge graph (KG) augmented transformer for video captioning, which integrates external knowledge and exploits multi-modality information in videos to address the challenge of long-tail words. 15 | 16 | ## Approach 17 | ### Knowledge Graphs Construction 18 | - General knowledge graph (G-KG) is designed to include the most key information in general scenarios in which we are interested, such as cooking and activity. It is built from the public available giant knowledge graph ConceptNet by extracting keywords in ConceptNet with the connected edges and neighboring nodes. 19 | - Specific knowledge graph (S-KG) is built to cover key information in specific scenarios. We extract speech transcripts from videos using an automatic speech recognition (ASR) model. We gather phrases such as “adjective and noun”, "noun and noun”, and "adverb and verb" to construct the S-KG. 20 | 21 | ### Two-Stream Transformer 22 | Our approach comprises an external stream that can utilize external knowledge information and an internal stream that can leverage the multimodal information from the video. 23 | 24 | ## Architecture 25 | motivation 26 | 27 |
Figure 1.TextKG Network
28 | 29 | 30 | ## Usage 31 | 32 | Our proposed TextKG is implemented with PyTorch. 33 | 34 | #### Environment 35 | 36 | - Python = 3.7 37 | - PyTorch = 1.4 38 | - pycocoevalcap 39 | 40 | #### 1.Installation 41 | 42 | - Clone this repo: 43 | 44 | ``` 45 | git clone https://github.com/GX77/TextKG.git 46 | cd TextKG 47 | ``` 48 | 49 | #### 2.Download datasets 50 | 51 | - [YouCooKII](https://drive.google.com/file/d/1mj76DwNexFCYovUt8BREeHccQn_z_By9/view?usp=sharing) 52 | 53 | ## Training & Testing 54 | 55 | #### YouCooKII 56 | 57 | ```bash 58 | # Training 59 | python3 train.py --res_root_dir YOUR_DIR --dset_name yc2 60 | 61 | # Test 62 | python3 translate.py --res_dir YOUR_DIR 63 | ``` 64 | We will add other datasets later. 65 | 66 | 67 | ## Citation 68 | 69 | If our research and this repository are helpful to your work, please cite with: 70 | 71 | ``` 72 | @InProceedings{Gu_2023_CVPR, 73 | author = {Gu, Xin and Chen, Guang and Wang, Yufei and Zhang, Libo and Luo, Tiejian and Wen, Longyin}, 74 | title = {Text With Knowledge Graph Augmented Transformer for Video Captioning}, 75 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 76 | month = {June}, 77 | year = {2023}, 78 | pages = {18941-18951} 79 | } 80 | ``` 81 | 82 | 83 | 84 | ## Acknowledge 85 | 86 | Code of the decoding part is based on [MART](https://github.com/jayleicn/recurrent-transformer). 87 | -------------------------------------------------------------------------------- /densevid_eval/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GX77/TextKG/2dec9157c1ebcdfecd0e3d1a1b769c195c4d9245/densevid_eval/.DS_Store -------------------------------------------------------------------------------- /densevid_eval/LICENCE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Ranjay Krishna 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /densevid_eval/README.md: -------------------------------------------------------------------------------- 1 | # Dense Captioning Events in Video - Evaluation Code 2 | 3 | This is a modified copy from https://github.com/jamespark3922/densevid_eval, 4 | which is again a modified copy of [densevid_eval](https://github.com/ranjaykrishna/densevid_eval). 5 | Instead of using sentence metrics, we evaluate captions at the paragraph level, 6 | as described in [Move Forward and Tell (ECCV18)](https://arxiv.org/abs/1807.10018) 7 | 8 | ## Usage 9 | ``` 10 | python para-evaluate.py -s YOUR_SUBMISSION_FILE.JSON --verbose 11 | ``` 12 | 13 | ## Paper 14 | Visit [the project page](http://cs.stanford.edu/people/ranjaykrishna/densevid) for details on activitynet captions. 15 | 16 | ## Citation 17 | ``` 18 | @inproceedings{krishna2017dense, 19 | title={Dense-Captioning Events in Videos}, 20 | author={Krishna, Ranjay and Hata, Kenji and Ren, Frederic and Fei-Fei, Li and Niebles, Juan Carlos}, 21 | booktitle={ArXiv}, 22 | year={2017} 23 | } 24 | ``` 25 | 26 | ## License 27 | 28 | MIT License copyright Ranjay Krishna 29 | 30 | -------------------------------------------------------------------------------- /densevid_eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GX77/TextKG/2dec9157c1ebcdfecd0e3d1a1b769c195c4d9245/densevid_eval/__init__.py -------------------------------------------------------------------------------- /densevid_eval/anet_data/readme.txt: -------------------------------------------------------------------------------- 1 | ANet-Entities val/test splits (re-split from ANet-caption val_1 and val_2 splits): 2 | https://dl.fbaipublicfiles.com/ActivityNet-Entities/ActivityNet-Entities/anet_entities_captions.tar.gz 3 | 4 | ANet-caption original splits: 5 | http://cs.stanford.edu/people/ranjaykrishna/densevid/captions.zip 6 | 7 | Experiment settings: 8 | Training: use GT segments/sentences in `train.json`, 9 | Validation: use GT segments in `anet_entities_val_1.json`, evaluate against references `anet_entities_val_1_para.json` and `anet_entities_val_2_para.json` 10 | Test: use GT segments in `anet_entities_test_1.json`, evaluate against references `anet_entities_test_1_para.json` and `anet_entities_test_2_para.json` 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /densevid_eval/coco-caption/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | -------------------------------------------------------------------------------- /densevid_eval/coco-caption/README.md: -------------------------------------------------------------------------------- 1 | Microsoft COCO Caption Evaluation 2 | =================== 3 | 4 | Evaluation codes for MS COCO caption generation. 5 | 6 | ## Requirements ## 7 | - java 1.8.0 8 | - python 2.7 9 | 10 | ## Files ## 11 | ./ 12 | - cocoEvalCapDemo.py (demo script) 13 | 14 | ./annotation 15 | - captions_val2014.json (MS COCO 2014 caption validation set) 16 | - Visit MS COCO [download](http://mscoco.org/dataset/#download) page for more details. 17 | 18 | ./results 19 | - captions_val2014_fakecap_results.json (an example of fake results for running demo) 20 | - Visit MS COCO [format](http://mscoco.org/dataset/#format) page for more details. 21 | 22 | ./pycocoevalcap: The folder where all evaluation codes are stored. 23 | - evals.py: The file includes COCOEavlCap class that can be used to evaluate results on COCO. 24 | - tokenizer: Python wrapper of Stanford CoreNLP PTBTokenizer 25 | - bleu: Bleu evalutation codes 26 | - meteor: Meteor evaluation codes 27 | - rouge: Rouge-L evaluation codes 28 | - cider: CIDEr evaluation codes 29 | 30 | ## References ## 31 | 32 | - [Microsoft COCO Captions: Data Collection and Evaluation Server](http://arxiv.org/abs/1504.00325) 33 | - PTBTokenizer: We use the [Stanford Tokenizer](http://nlp.stanford.edu/software/tokenizer.shtml) which is included in [Stanford CoreNLP 3.4.1](http://nlp.stanford.edu/software/corenlp.shtml). 34 | - BLEU: [BLEU: a Method for Automatic Evaluation of Machine Translation](http://www.aclweb.org/anthology/P02-1040.pdf) 35 | - Meteor: [Project page](http://www.cs.cmu.edu/~alavie/METEOR/) with related publications. We use the latest version (1.5) of the [Code](https://github.com/mjdenkowski/meteor). Changes have been made to the source code to properly aggreate the statistics for the entire corpus. 36 | - Rouge-L: [ROUGE: A Package for Automatic Evaluation of Summaries](http://anthology.aclweb.org/W/W04/W04-1013.pdf) 37 | - CIDEr: [CIDEr: Consensus-based Image Description Evaluation] (http://arxiv.org/pdf/1411.5726.pdf) 38 | 39 | ## Developers ## 40 | - Xinlei Chen (CMU) 41 | - Hao Fang (University of Washington) 42 | - Tsung-Yi Lin (Cornell) 43 | - Ramakrishna Vedantam (Virgina Tech) 44 | 45 | ## Acknowledgement ## 46 | - David Chiang (University of Norte Dame) 47 | - Michael Denkowski (CMU) 48 | - Alexander Rush (Harvard University) 49 | -------------------------------------------------------------------------------- /densevid_eval/coco-caption/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GX77/TextKG/2dec9157c1ebcdfecd0e3d1a1b769c195c4d9245/densevid_eval/coco-caption/__init__.py -------------------------------------------------------------------------------- /densevid_eval/coco-caption/license.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015, Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 2. Redistributions in binary form must reproduce the above copyright notice, 10 | this list of conditions and the following disclaimer in the documentation 11 | and/or other materials provided with the distribution. 12 | 13 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 14 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 15 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 16 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 17 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 18 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 19 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 20 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 21 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 22 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 23 | 24 | The views and conclusions contained in the software and documentation are those 25 | of the authors and should not be interpreted as representing official policies, 26 | either expressed or implied, of the FreeBSD Project. 27 | -------------------------------------------------------------------------------- /densevid_eval/coco-caption/pycocoevalcap/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /densevid_eval/coco-caption/pycocoevalcap/bleu/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015 Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /densevid_eval/coco-caption/pycocoevalcap/bleu/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /densevid_eval/coco-caption/pycocoevalcap/bleu/bleu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : bleu.py 4 | # 5 | # Description : Wrapper for BLEU scorer. 6 | # 7 | # Creation Date : 06-01-2015 8 | # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | import sys 11 | sys.path.append('/mnt/bd/gxvolume/Try_asr/densevid_eval/coco-caption/pycocoevalcap/bleu/') 12 | from bleu_scorer import BleuScorer 13 | 14 | 15 | class Bleu: 16 | def __init__(self, n=4): 17 | # default compute Blue score up to 4 18 | self._n = n 19 | self._hypo_for_image = {} 20 | self.ref_for_image = {} 21 | 22 | def compute_score(self, gts, res): 23 | 24 | assert(gts.keys() == res.keys()) 25 | imgIds = gts.keys() 26 | 27 | bleu_scorer = BleuScorer(n=self._n) 28 | for id in imgIds: 29 | hypo = res[id] 30 | ref = gts[id] 31 | 32 | # Sanity check. 33 | assert(type(hypo) is list) 34 | assert(len(hypo) == 1) 35 | assert(type(ref) is list) 36 | assert(len(ref) >= 1) 37 | 38 | bleu_scorer += (hypo[0], ref) 39 | 40 | #score, scores = bleu_scorer.compute_score(option='shortest') 41 | score, scores = bleu_scorer.compute_score(option='closest', verbose=1) 42 | #score, scores = bleu_scorer.compute_score(option='average', verbose=1) 43 | 44 | # return (bleu, bleu_info) 45 | return score, scores 46 | 47 | def method(self): 48 | return "Bleu" 49 | -------------------------------------------------------------------------------- /densevid_eval/coco-caption/pycocoevalcap/bleu/bleu_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # bleu_scorer.py 4 | # David Chiang 5 | 6 | # Copyright (c) 2004-2006 University of Maryland. All rights 7 | # reserved. Do not redistribute without permission from the 8 | # author. Not for commercial use. 9 | 10 | # Modified by: 11 | # Hao Fang 12 | # Tsung-Yi Lin 13 | 14 | '''Provides: 15 | cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test(). 16 | cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked(). 17 | ''' 18 | 19 | import copy 20 | import sys, math, re 21 | from collections import defaultdict 22 | 23 | def precook(s, n=4, out=False): 24 | """Takes a string as input and returns an object that can be given to 25 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 26 | can take string arguments as well.""" 27 | words = s.split() 28 | counts = defaultdict(int) 29 | for k in range(1,n+1): 30 | for i in range(len(words)-k+1): 31 | ngram = tuple(words[i:i+k]) 32 | counts[ngram] += 1 33 | return (len(words), counts) 34 | 35 | def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average" 36 | '''Takes a list of reference sentences for a single segment 37 | and returns an object that encapsulates everything that BLEU 38 | needs to know about them.''' 39 | 40 | reflen = [] 41 | maxcounts = {} 42 | for ref in refs: 43 | rl, counts = precook(ref, n) 44 | reflen.append(rl) 45 | for (ngram,count) in counts.items(): 46 | maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 47 | 48 | # Calculate effective reference sentence length. 49 | if eff == "shortest": 50 | reflen = min(reflen) 51 | elif eff == "average": 52 | reflen = float(sum(reflen))/len(reflen) 53 | 54 | ## lhuang: N.B.: leave reflen computaiton to the very end!! 55 | 56 | ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design) 57 | 58 | return (reflen, maxcounts) 59 | 60 | def cook_test(test, xxx_todo_changeme, eff=None, n=4): 61 | '''Takes a test sentence and returns an object that 62 | encapsulates everything that BLEU needs to know about it.''' 63 | (reflen, refmaxcounts) = xxx_todo_changeme 64 | testlen, counts = precook(test, n, True) 65 | 66 | result = {} 67 | 68 | # Calculate effective reference sentence length. 69 | 70 | if eff == "closest": 71 | result["reflen"] = min((abs(l-testlen), l) for l in reflen)[1] 72 | else: ## i.e., "average" or "shortest" or None 73 | result["reflen"] = reflen 74 | 75 | result["testlen"] = testlen 76 | 77 | result["guess"] = [max(0,testlen-k+1) for k in range(1,n+1)] 78 | 79 | result['correct'] = [0]*n 80 | for (ngram, count) in counts.items(): 81 | result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram,0), count) 82 | 83 | return result 84 | 85 | class BleuScorer(object): 86 | """Bleu scorer. 87 | """ 88 | 89 | __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen" 90 | # special_reflen is used in oracle (proportional effective ref len for a node). 91 | 92 | def copy(self): 93 | ''' copy the refs.''' 94 | new = BleuScorer(n=self.n) 95 | new.ctest = copy.copy(self.ctest) 96 | new.crefs = copy.copy(self.crefs) 97 | new._score = None 98 | return new 99 | 100 | def __init__(self, test=None, refs=None, n=4, special_reflen=None): 101 | ''' singular instance ''' 102 | 103 | self.n = n 104 | self.crefs = [] 105 | self.ctest = [] 106 | self.cook_append(test, refs) 107 | self.special_reflen = special_reflen 108 | 109 | def cook_append(self, test, refs): 110 | '''called by constructor and __iadd__ to avoid creating new instances.''' 111 | 112 | if refs is not None: 113 | self.crefs.append(cook_refs(refs)) 114 | if test is not None: 115 | cooked_test = cook_test(test, self.crefs[-1]) 116 | self.ctest.append(cooked_test) ## N.B.: -1 117 | else: 118 | self.ctest.append(None) # lens of crefs and ctest have to match 119 | 120 | self._score = None ## need to recompute 121 | 122 | def ratio(self, option=None): 123 | self.compute_score(option=option) 124 | return self._ratio 125 | 126 | def score_ratio(self, option=None): 127 | '''return (bleu, len_ratio) pair''' 128 | return (self.fscore(option=option), self.ratio(option=option)) 129 | 130 | def score_ratio_str(self, option=None): 131 | return "%.4f (%.2f)" % self.score_ratio(option) 132 | 133 | def reflen(self, option=None): 134 | self.compute_score(option=option) 135 | return self._reflen 136 | 137 | def testlen(self, option=None): 138 | self.compute_score(option=option) 139 | return self._testlen 140 | 141 | def retest(self, new_test): 142 | if type(new_test) is str: 143 | new_test = [new_test] 144 | assert len(new_test) == len(self.crefs), new_test 145 | self.ctest = [] 146 | for t, rs in zip(new_test, self.crefs): 147 | self.ctest.append(cook_test(t, rs)) 148 | self._score = None 149 | 150 | return self 151 | 152 | def rescore(self, new_test): 153 | ''' replace test(s) with new test(s), and returns the new score.''' 154 | 155 | return self.retest(new_test).compute_score() 156 | 157 | def size(self): 158 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) 159 | return len(self.crefs) 160 | 161 | def __iadd__(self, other): 162 | '''add an instance (e.g., from another sentence).''' 163 | 164 | if type(other) is tuple: 165 | ## avoid creating new BleuScorer instances 166 | self.cook_append(other[0], other[1]) 167 | else: 168 | assert self.compatible(other), "incompatible BLEUs." 169 | self.ctest.extend(other.ctest) 170 | self.crefs.extend(other.crefs) 171 | self._score = None ## need to recompute 172 | 173 | return self 174 | 175 | def compatible(self, other): 176 | return isinstance(other, BleuScorer) and self.n == other.n 177 | 178 | def single_reflen(self, option="average"): 179 | return self._single_reflen(self.crefs[0][0], option) 180 | 181 | def _single_reflen(self, reflens, option=None, testlen=None): 182 | 183 | if option == "shortest": 184 | reflen = min(reflens) 185 | elif option == "average": 186 | reflen = float(sum(reflens))/len(reflens) 187 | elif option == "closest": 188 | reflen = min((abs(l-testlen), l) for l in reflens)[1] 189 | else: 190 | assert False, "unsupported reflen option %s" % option 191 | 192 | return reflen 193 | 194 | def recompute_score(self, option=None, verbose=0): 195 | self._score = None 196 | return self.compute_score(option, verbose) 197 | 198 | def compute_score(self, option=None, verbose=0): 199 | n = self.n 200 | small = 1e-9 201 | tiny = 1e-15 ## so that if guess is 0 still return 0 202 | bleu_list = [[] for _ in range(n)] 203 | 204 | if self._score is not None: 205 | return self._score 206 | 207 | if option is None: 208 | option = "average" if len(self.crefs) == 1 else "closest" 209 | 210 | self._testlen = 0 211 | self._reflen = 0 212 | totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n} 213 | 214 | # for each sentence 215 | for comps in self.ctest: 216 | testlen = comps['testlen'] 217 | self._testlen += testlen 218 | 219 | if self.special_reflen is None: ## need computation 220 | reflen = self._single_reflen(comps['reflen'], option, testlen) 221 | else: 222 | reflen = self.special_reflen 223 | 224 | self._reflen += reflen 225 | 226 | for key in ['guess','correct']: 227 | for k in range(n): 228 | totalcomps[key][k] += comps[key][k] 229 | 230 | # append per image bleu score 231 | bleu = 1. 232 | for k in range(n): 233 | bleu *= (float(comps['correct'][k]) + tiny) \ 234 | /(float(comps['guess'][k]) + small) 235 | bleu_list[k].append(bleu ** (1./(k+1))) 236 | ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division 237 | if ratio < 1: 238 | for k in range(n): 239 | bleu_list[k][-1] *= math.exp(1 - 1/ratio) 240 | 241 | if verbose > 1: 242 | print(comps, reflen) 243 | 244 | totalcomps['reflen'] = self._reflen 245 | totalcomps['testlen'] = self._testlen 246 | 247 | bleus = [] 248 | bleu = 1. 249 | for k in range(n): 250 | bleu *= float(totalcomps['correct'][k] + tiny) \ 251 | / (totalcomps['guess'][k] + small) 252 | bleus.append(bleu ** (1./(k+1))) 253 | ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division 254 | if ratio < 1: 255 | for k in range(n): 256 | bleus[k] *= math.exp(1 - 1/ratio) 257 | 258 | if verbose > 0: 259 | print(totalcomps) 260 | print("ratio:", ratio) 261 | 262 | self._score = bleus 263 | return self._score, bleu_list 264 | -------------------------------------------------------------------------------- /densevid_eval/coco-caption/pycocoevalcap/cider/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /densevid_eval/coco-caption/pycocoevalcap/cider/cider.py: -------------------------------------------------------------------------------- 1 | # Filename: cider.py 2 | # 3 | # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric 4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) 5 | # 6 | # Creation Date: Sun Feb 8 14:16:54 2015 7 | # 8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin 9 | import sys 10 | sys.path.append('/mnt/bd/gxvolume/Try_asr/densevid_eval/coco-caption/pycocoevalcap/cider/') 11 | from cider_scorer import CiderScorer 12 | import pdb 13 | 14 | class Cider: 15 | """ 16 | Main Class to compute the CIDEr metric 17 | 18 | """ 19 | def __init__(self, test=None, refs=None, n=4, sigma=6.0): 20 | # set cider to sum over 1 to 4-grams 21 | self._n = n 22 | # set the standard deviation parameter for gaussian penalty 23 | self._sigma = sigma 24 | 25 | def compute_score(self, gts, res): 26 | """ 27 | Main function to compute CIDEr score 28 | :param hypo_for_image (dict) : dictionary with key and value 29 | ref_for_image (dict) : dictionary with key and value 30 | :return: cider (float) : computed CIDEr score for the corpus 31 | """ 32 | 33 | assert(gts.keys() == res.keys()) 34 | imgIds = gts.keys() 35 | 36 | cider_scorer = CiderScorer(n=self._n, sigma=self._sigma) 37 | 38 | for id in imgIds: 39 | hypo = res[id] 40 | ref = gts[id] 41 | 42 | # Sanity check. 43 | assert(type(hypo) is list) 44 | assert(len(hypo) == 1) 45 | assert(type(ref) is list) 46 | assert(len(ref) > 0) 47 | 48 | cider_scorer += (hypo[0], ref) 49 | 50 | (score, scores) = cider_scorer.compute_score() 51 | 52 | return(score, scores) 53 | 54 | def method(self): 55 | return "CIDEr" -------------------------------------------------------------------------------- /densevid_eval/coco-caption/pycocoevalcap/cider/cider_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Tsung-Yi Lin 3 | # Ramakrishna Vedantam 4 | 5 | import copy 6 | from collections import defaultdict 7 | import numpy as np 8 | import pdb 9 | import math 10 | 11 | def precook(s, n=4, out=False): 12 | """ 13 | Takes a string as input and returns an object that can be given to 14 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 15 | can take string arguments as well. 16 | :param s: string : sentence to be converted into ngrams 17 | :param n: int : number of ngrams for which representation is calculated 18 | :return: term frequency vector for occuring ngrams 19 | """ 20 | words = s.split() 21 | counts = defaultdict(int) 22 | for k in range(1,n+1): 23 | for i in range(len(words)-k+1): 24 | ngram = tuple(words[i:i+k]) 25 | counts[ngram] += 1 26 | return counts 27 | 28 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average" 29 | '''Takes a list of reference sentences for a single segment 30 | and returns an object that encapsulates everything that BLEU 31 | needs to know about them. 32 | :param refs: list of string : reference sentences for some image 33 | :param n: int : number of ngrams for which (ngram) representation is calculated 34 | :return: result (list of dict) 35 | ''' 36 | return [precook(ref, n) for ref in refs] 37 | 38 | def cook_test(test, n=4): 39 | '''Takes a test sentence and returns an object that 40 | encapsulates everything that BLEU needs to know about it. 41 | :param test: list of string : hypothesis sentence for some image 42 | :param n: int : number of ngrams for which (ngram) representation is calculated 43 | :return: result (dict) 44 | ''' 45 | return precook(test, n, True) 46 | 47 | class CiderScorer(object): 48 | """CIDEr scorer. 49 | """ 50 | 51 | def copy(self): 52 | ''' copy the refs.''' 53 | new = CiderScorer(n=self.n) 54 | new.ctest = copy.copy(self.ctest) 55 | new.crefs = copy.copy(self.crefs) 56 | return new 57 | 58 | def __init__(self, test=None, refs=None, n=4, sigma=6.0): 59 | ''' singular instance ''' 60 | self.n = n 61 | self.sigma = sigma 62 | self.crefs = [] 63 | self.ctest = [] 64 | self.document_frequency = defaultdict(float) 65 | self.cook_append(test, refs) 66 | self.ref_len = None 67 | 68 | def cook_append(self, test, refs): 69 | '''called by constructor and __iadd__ to avoid creating new instances.''' 70 | 71 | if refs is not None: 72 | self.crefs.append(cook_refs(refs)) 73 | if test is not None: 74 | self.ctest.append(cook_test(test)) ## N.B.: -1 75 | else: 76 | self.ctest.append(None) # lens of crefs and ctest have to match 77 | 78 | def size(self): 79 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) 80 | return len(self.crefs) 81 | 82 | def __iadd__(self, other): 83 | '''add an instance (e.g., from another sentence).''' 84 | 85 | if type(other) is tuple: 86 | ## avoid creating new CiderScorer instances 87 | self.cook_append(other[0], other[1]) 88 | else: 89 | self.ctest.extend(other.ctest) 90 | self.crefs.extend(other.crefs) 91 | 92 | return self 93 | def compute_doc_freq(self): 94 | ''' 95 | Compute term frequency for reference data. 96 | This will be used to compute idf (inverse document frequency later) 97 | The term frequency is stored in the object 98 | :return: None 99 | ''' 100 | for refs in self.crefs: 101 | # refs, k ref captions of one image 102 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]): 103 | self.document_frequency[ngram] += 1 104 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 105 | 106 | def compute_cider(self): 107 | def counts2vec(cnts): 108 | """ 109 | Function maps counts of ngram to vector of tfidf weights. 110 | The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights. 111 | The n-th entry of array denotes length of n-grams. 112 | :param cnts: 113 | :return: vec (array of dict), norm (array of float), length (int) 114 | """ 115 | vec = [defaultdict(float) for _ in range(self.n)] 116 | length = 0 117 | norm = [0.0 for _ in range(self.n)] 118 | for (ngram,term_freq) in cnts.items(): 119 | # give word count 1 if it doesn't appear in reference corpus 120 | df = np.log(max(1.0, self.document_frequency[ngram])) 121 | # ngram index 122 | n = len(ngram)-1 123 | # tf (term_freq) * idf (precomputed idf) for n-grams 124 | vec[n][ngram] = float(term_freq)*(self.ref_len - df) 125 | # compute norm for the vector. the norm will be used for computing similarity 126 | norm[n] += pow(vec[n][ngram], 2) 127 | 128 | if n == 1: 129 | length += term_freq 130 | norm = [np.sqrt(n) for n in norm] 131 | return vec, norm, length 132 | 133 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): 134 | ''' 135 | Compute the cosine similarity of two vectors. 136 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis 137 | :param vec_ref: array of dictionary for vector corresponding to reference 138 | :param norm_hyp: array of float for vector corresponding to hypothesis 139 | :param norm_ref: array of float for vector corresponding to reference 140 | :param length_hyp: int containing length of hypothesis 141 | :param length_ref: int containing length of reference 142 | :return: array of score for each n-grams cosine similarity 143 | ''' 144 | delta = float(length_hyp - length_ref) 145 | # measure consine similarity 146 | val = np.array([0.0 for _ in range(self.n)]) 147 | for n in range(self.n): 148 | # ngram 149 | for (ngram,count) in vec_hyp[n].items(): 150 | # vrama91 : added clipping 151 | val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram] 152 | 153 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0): 154 | val[n] /= (norm_hyp[n]*norm_ref[n]) 155 | 156 | assert(not math.isnan(val[n])) 157 | # vrama91: added a length based gaussian penalty 158 | val[n] *= np.e**(-(delta**2)/(2*self.sigma**2)) 159 | return val 160 | 161 | # compute log reference length 162 | self.ref_len = np.log(float(len(self.crefs))) 163 | 164 | scores = [] 165 | for test, refs in zip(self.ctest, self.crefs): 166 | # compute vector for test captions 167 | vec, norm, length = counts2vec(test) 168 | # compute vector for ref captions 169 | score = np.array([0.0 for _ in range(self.n)]) 170 | for ref in refs: 171 | vec_ref, norm_ref, length_ref = counts2vec(ref) 172 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) 173 | # change by vrama91 - mean of ngram scores, instead of sum 174 | score_avg = np.mean(score) 175 | # divide by number of references 176 | score_avg /= len(refs) 177 | # multiply score by 10 178 | score_avg *= 10.0 179 | # append score of an image to the score list 180 | scores.append(score_avg) 181 | return scores 182 | 183 | def compute_score(self, option=None, verbose=0): 184 | # compute idf 185 | self.compute_doc_freq() 186 | # assert to check document frequency 187 | assert(len(self.ctest) >= max(self.document_frequency.values())) 188 | # compute cider score 189 | score = self.compute_cider() 190 | # debug 191 | # print score 192 | return np.mean(np.array(score)), np.array(score) -------------------------------------------------------------------------------- /densevid_eval/coco-caption/pycocoevalcap/eval.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | from tokenizer.ptbtokenizer import PTBTokenizer 3 | from bleu.bleu import Bleu 4 | from meteor.meteor import Meteor 5 | from rouge.rouge import Rouge 6 | from cider.cider import Cider 7 | 8 | class COCOEvalCap: 9 | def __init__(self, coco, cocoRes): 10 | self.evalImgs = [] 11 | self.eval = {} 12 | self.imgToEval = {} 13 | self.coco = coco 14 | self.cocoRes = cocoRes 15 | self.params = {'image_id': coco.getImgIds()} 16 | 17 | def evaluate(self): 18 | imgIds = self.params['image_id'] 19 | # imgIds = self.coco.getImgIds() 20 | gts = {} 21 | res = {} 22 | for imgId in imgIds: 23 | gts[imgId] = self.coco.imgToAnns[imgId] 24 | res[imgId] = self.cocoRes.imgToAnns[imgId] 25 | 26 | # ================================================= 27 | # Set up scorers 28 | # ================================================= 29 | print('tokenization...') 30 | tokenizer = PTBTokenizer() 31 | gts = tokenizer.tokenize(gts) 32 | res = tokenizer.tokenize(res) 33 | 34 | # ================================================= 35 | # Set up scorers 36 | # ================================================= 37 | print('setting up scorers...') 38 | scorers = [ 39 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]), 40 | (Meteor(),"METEOR"), 41 | (Rouge(), "ROUGE_L"), 42 | (Cider(), "CIDEr") 43 | ] 44 | 45 | # ================================================= 46 | # Compute scores 47 | # ================================================= 48 | for scorer, method in scorers: 49 | print('computing %s score...'%(scorer.method())) 50 | score, scores = scorer.compute_score(gts, res) 51 | if type(method) == list: 52 | for sc, scs, m in zip(score, scores, method): 53 | self.setEval(sc, m) 54 | self.setImgToEvalImgs(scs, gts.keys(), m) 55 | print('%s: %0.3f'%(m, sc)) 56 | else: 57 | self.setEval(score, method) 58 | self.setImgToEvalImgs(scores, gts.keys(), method) 59 | print("%s: %0.3f"%(method, score)) 60 | self.setEvalImgs() 61 | 62 | def setEval(self, score, method): 63 | self.eval[method] = score 64 | 65 | def setImgToEvalImgs(self, scores, imgIds, method): 66 | for imgId, score in zip(imgIds, scores): 67 | if not imgId in self.imgToEval: 68 | self.imgToEval[imgId] = {} 69 | self.imgToEval[imgId]["image_id"] = imgId 70 | self.imgToEval[imgId][method] = score 71 | 72 | def setEvalImgs(self): 73 | self.evalImgs = [eval for imgId, eval in self.imgToEval.items()] -------------------------------------------------------------------------------- /densevid_eval/coco-caption/pycocoevalcap/meteor/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /densevid_eval/coco-caption/pycocoevalcap/meteor/meteor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Python wrapper for METEOR implementation, by Xinlei Chen 4 | # Acknowledge Michael Denkowski for the generous discussion and help 5 | 6 | import os 7 | import sys 8 | import subprocess 9 | import threading 10 | from signal import signal, SIGPIPE, SIG_DFL 11 | 12 | # 让 python 忽略 SIGPIPE 信号,并且不抛出异常 13 | signal(SIGPIPE,SIG_DFL) 14 | 15 | # Assumes meteor-1.5.jar is in the same directory as meteor.py. Change as needed. 16 | METEOR_JAR = 'meteor-1.5.jar' 17 | # print METEOR_JAR 18 | 19 | class Meteor: 20 | 21 | def __init__(self): 22 | self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR, \ 23 | '-', '-', '-stdio', '-l', 'en', '-norm'] 24 | self.meteor_p = subprocess.Popen(self.meteor_cmd, \ 25 | cwd=os.path.dirname(os.path.abspath(__file__)), \ 26 | stdin=subprocess.PIPE, \ 27 | stdout=subprocess.PIPE, \ 28 | stderr=subprocess.PIPE) 29 | # Used to guarantee thread safety 30 | self.lock = threading.Lock() 31 | 32 | def compute_score(self, gts, res): 33 | assert(gts.keys() == res.keys()) 34 | imgIds = gts.keys() 35 | scores = [] 36 | 37 | eval_line = 'EVAL' 38 | self.lock.acquire() 39 | for i in imgIds: 40 | assert(len(res[i]) == 1) 41 | stat = self._stat(res[i][0], gts[i]) 42 | eval_line += ' ||| {}'.format(stat) 43 | 44 | self.meteor_p.stdin.write('{}\n'.format(eval_line)) 45 | for i in range(0,len(imgIds)): 46 | scores.append(float(self.meteor_p.stdout.readline().strip())) 47 | score = float(self.meteor_p.stdout.readline().strip()) 48 | self.lock.release() 49 | 50 | return score, scores 51 | 52 | def method(self): 53 | return "METEOR" 54 | 55 | def _stat(self, hypothesis_str, reference_list): 56 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 57 | hypothesis_str = hypothesis_str.replace('|||', '').replace(' ', ' ') 58 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) 59 | w = bytes(score_line,encoding='utf-8') 60 | self.meteor_p.stdin.write(w) 61 | return self.meteor_p.stdout.readline().strip() 62 | 63 | 64 | def _score(self, hypothesis_str, reference_list): 65 | self.lock.acquire() 66 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 67 | hypothesis_str = hypothesis_str.replace('|||','').replace(' ',' ') 68 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) 69 | self.meteor_p.stdin.write('{}\n'.format(score_line)) 70 | stats = self.meteor_p.stdout.readline().strip() 71 | eval_line = 'EVAL ||| {}'.format(stats) 72 | # EVAL ||| stats 73 | self.meteor_p.stdin.write('{}\n'.format(eval_line)) 74 | score = float(self.meteor_p.stdout.readline().strip()) 75 | # bug fix: there are two values returned by the jar file, one average, and one all, so do it twice 76 | # thanks for Andrej for pointing this out 77 | score = float(self.meteor_p.stdout.readline().strip()) 78 | self.lock.release() 79 | return score 80 | 81 | def __del__(self): 82 | self.lock.acquire() 83 | self.meteor_p.stdin.close() 84 | self.meteor_p.kill() 85 | self.meteor_p.wait() 86 | self.lock.release() 87 | -------------------------------------------------------------------------------- /densevid_eval/coco-caption/pycocoevalcap/rouge/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'vrama91' 2 | -------------------------------------------------------------------------------- /densevid_eval/coco-caption/pycocoevalcap/rouge/rouge.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : rouge.py 4 | # 5 | # Description : Computes ROUGE-L metric as described by Lin and Hovey (2004) 6 | # 7 | # Creation Date : 2015-01-07 06:03 8 | # Author : Ramakrishna Vedantam 9 | 10 | import numpy as np 11 | import pdb 12 | 13 | def my_lcs(string, sub): 14 | """ 15 | Calculates longest common subsequence for a pair of tokenized strings 16 | :param string : list of str : tokens from a string split using whitespace 17 | :param sub : list of str : shorter string, also split using whitespace 18 | :returns: length (list of int): length of the longest common subsequence between the two strings 19 | 20 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS 21 | """ 22 | if(len(string)< len(sub)): 23 | sub, string = string, sub 24 | 25 | lengths = [[0 for i in range(0,len(sub)+1)] for j in range(0,len(string)+1)] 26 | 27 | for j in range(1,len(sub)+1): 28 | for i in range(1,len(string)+1): 29 | if(string[i-1] == sub[j-1]): 30 | lengths[i][j] = lengths[i-1][j-1] + 1 31 | else: 32 | lengths[i][j] = max(lengths[i-1][j] , lengths[i][j-1]) 33 | 34 | return lengths[len(string)][len(sub)] 35 | 36 | class Rouge(): 37 | ''' 38 | Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set 39 | 40 | ''' 41 | def __init__(self): 42 | # vrama91: updated the value below based on discussion with Hovey 43 | self.beta = 1.2 44 | 45 | def calc_score(self, candidate, refs): 46 | """ 47 | Compute ROUGE-L score given one candidate and references for an image 48 | :param candidate: str : candidate sentence to be evaluated 49 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated 50 | :returns score: int (ROUGE-L score for the candidate evaluated against references) 51 | """ 52 | assert(len(candidate)==1) 53 | assert(len(refs)>0) 54 | prec = [] 55 | rec = [] 56 | 57 | # split into tokens 58 | token_c = candidate[0].split(" ") 59 | 60 | for reference in refs: 61 | # split into tokens 62 | token_r = reference.split(" ") 63 | # compute the longest common subsequence 64 | lcs = my_lcs(token_r, token_c) 65 | prec.append(lcs/float(len(token_c))) 66 | rec.append(lcs/float(len(token_r))) 67 | 68 | prec_max = max(prec) 69 | rec_max = max(rec) 70 | 71 | if(prec_max!=0 and rec_max !=0): 72 | score = ((1 + self.beta**2)*prec_max*rec_max)/float(rec_max + self.beta**2*prec_max) 73 | else: 74 | score = 0.0 75 | return score 76 | 77 | def compute_score(self, gts, res): 78 | """ 79 | Computes Rouge-L score given a set of reference and candidate sentences for the dataset 80 | Invoked by evaluate_captions.py 81 | :param hypo_for_image: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values 82 | :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values 83 | :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images) 84 | """ 85 | assert(gts.keys() == res.keys()) 86 | imgIds = gts.keys() 87 | 88 | score = [] 89 | for id in imgIds: 90 | hypo = res[id] 91 | ref = gts[id] 92 | 93 | score.append(self.calc_score(hypo, ref)) 94 | 95 | # Sanity check. 96 | assert(type(hypo) is list) 97 | assert(len(hypo) == 1) 98 | assert(type(ref) is list) 99 | assert(len(ref) > 0) 100 | 101 | average_score = np.mean(np.array(score)) 102 | return average_score, np.array(score) 103 | 104 | def method(self): 105 | return "Rouge" 106 | -------------------------------------------------------------------------------- /densevid_eval/coco-caption/pycocoevalcap/tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'hfang' 2 | -------------------------------------------------------------------------------- /densevid_eval/coco-caption/pycocoevalcap/tokenizer/ptbtokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : ptbtokenizer.py 4 | # 5 | # Description : Do the PTB Tokenization and remove punctuations. 6 | # 7 | # Creation Date : 29-12-2014 8 | # Last Modified : Thu Mar 19 09:53:35 2015 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | import os 12 | import sys 13 | import subprocess 14 | import tempfile 15 | import itertools 16 | 17 | # path to the stanford corenlp jar 18 | STANFORD_CORENLP_3_4_1_JAR = 'stanford-corenlp-3.4.1.jar' 19 | 20 | # punctuations to be removed from the sentences 21 | PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \ 22 | ".", "?", "!", ",", ":", "-", "--", "...", ";"] 23 | 24 | class PTBTokenizer: 25 | """Python wrapper of Stanford PTBTokenizer""" 26 | 27 | def tokenize(self, captions_for_image): 28 | cmd = ['java', '-cp', STANFORD_CORENLP_3_4_1_JAR, \ 29 | 'edu.stanford.nlp.process.PTBTokenizer', \ 30 | '-preserveLines', '-lowerCase'] 31 | 32 | # ====================================================== 33 | # prepare data for PTB Tokenizer 34 | # ====================================================== 35 | final_tokenized_captions_for_image = {} 36 | image_id = [k for k, v in captions_for_image.items() for _ in range(len(v))] 37 | sentences = '\n'.join([c['caption'].replace('\n', ' ') for k, v in captions_for_image.items() for c in v]) 38 | 39 | # ====================================================== 40 | # save sentences to temporary file 41 | # ====================================================== 42 | path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__)) 43 | tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname) 44 | tmp_file.write(sentences) 45 | tmp_file.close() 46 | 47 | # ====================================================== 48 | # tokenize sentence 49 | # ====================================================== 50 | cmd.append(os.path.basename(tmp_file.name)) 51 | p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dirname, \ 52 | stdout=subprocess.PIPE) 53 | token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0] 54 | lines = token_lines.split('\n') 55 | # remove temp file 56 | os.remove(tmp_file.name) 57 | 58 | # ====================================================== 59 | # create dictionary for tokenized captions 60 | # ====================================================== 61 | for k, line in zip(image_id, lines): 62 | if not k in final_tokenized_captions_for_image: 63 | final_tokenized_captions_for_image[k] = [] 64 | tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \ 65 | if w not in PUNCTUATIONS]) 66 | final_tokenized_captions_for_image[k].append(tokenized_caption) 67 | 68 | return final_tokenized_captions_for_image 69 | -------------------------------------------------------------------------------- /densevid_eval/coco-caption/pycocotools/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /densevid_eval/coco-caption/pycocotools/coco.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | __version__ = '1.0.1' 3 | # Interface for accessing the Microsoft COCO dataset. 4 | 5 | # Microsoft COCO is a large image dataset designed for object detection, 6 | # segmentation, and caption generation. pycocotools is a Python API that 7 | # assists in loading, parsing and visualizing the annotations in COCO. 8 | # Please visit http://mscoco.org/ for more information on COCO, including 9 | # for the data, paper, and tutorials. The exact format of the annotations 10 | # is also described on the COCO website. For example usage of the pycocotools 11 | # please see pycocotools_demo.ipynb. In addition to this API, please download both 12 | # the COCO images and annotations in order to run the demo. 13 | 14 | # An alternative to using the API is to load the annotations directly 15 | # into Python dictionary 16 | # Using the API provides additional utility functions. Note that this API 17 | # supports both *instance* and *caption* annotations. In the case of 18 | # captions not all functions are defined (e.g. categories are undefined). 19 | 20 | # The following API functions are defined: 21 | # COCO - COCO api class that loads COCO annotation file and prepare data structures. 22 | # decodeMask - Decode binary mask M encoded via run-length encoding. 23 | # encodeMask - Encode binary mask M using run-length encoding. 24 | # getAnnIds - Get ann ids that satisfy given filter conditions. 25 | # getCatIds - Get cat ids that satisfy given filter conditions. 26 | # getImgIds - Get img ids that satisfy given filter conditions. 27 | # loadAnns - Load anns with the specified ids. 28 | # loadCats - Load cats with the specified ids. 29 | # loadImgs - Load imgs with the specified ids. 30 | # segToMask - Convert polygon segmentation to binary mask. 31 | # showAnns - Display the specified annotations. 32 | # loadRes - Load result file and create result api object. 33 | # Throughout the API "ann"=annotation, "cat"=category, and "img"=image. 34 | # Help on each functions can be accessed by: "help COCO>function". 35 | 36 | # See also COCO>decodeMask, 37 | # COCO>encodeMask, COCO>getAnnIds, COCO>getCatIds, 38 | # COCO>getImgIds, COCO>loadAnns, COCO>loadCats, 39 | # COCO>loadImgs, COCO>segToMask, COCO>showAnns 40 | 41 | # Microsoft COCO Toolbox. Version 1.0 42 | # Data, paper, and tutorials available at: http://mscoco.org/ 43 | # Code written by Piotr Dollar and Tsung-Yi Lin, 2014. 44 | # Licensed under the Simplified BSD License [see bsd.txt] 45 | 46 | import json 47 | import datetime 48 | import matplotlib.pyplot as plt 49 | from matplotlib.collections import PatchCollection 50 | from matplotlib.patches import Polygon 51 | import numpy as np 52 | from skimage.draw import polygon 53 | import copy 54 | 55 | class COCO: 56 | def __init__(self, annotation_file=None): 57 | """ 58 | Constructor of Microsoft COCO helper class for reading and visualizing annotations. 59 | :param annotation_file (str): location of annotation file 60 | :param image_folder (str): location to the folder that hosts images. 61 | :return: 62 | """ 63 | # load dataset 64 | self.dataset = {} 65 | self.anns = [] 66 | self.imgToAnns = {} 67 | self.catToImgs = {} 68 | self.imgs = [] 69 | self.cats = [] 70 | if not annotation_file == None: 71 | print('loading annotations into memory...') 72 | time_t = datetime.datetime.utcnow() 73 | dataset = json.load(open(annotation_file, 'r')) 74 | print(datetime.datetime.utcnow() - time_t) 75 | self.dataset = dataset 76 | self.createIndex() 77 | 78 | def createIndex(self): 79 | # create index 80 | print('creating index...') 81 | imgToAnns = {ann['image_id']: [] for ann in self.dataset['annotations']} 82 | anns = {ann['id']: [] for ann in self.dataset['annotations']} 83 | for ann in self.dataset['annotations']: 84 | imgToAnns[ann['image_id']] += [ann] 85 | anns[ann['id']] = ann 86 | 87 | imgs = {im['id']: {} for im in self.dataset['images']} 88 | for img in self.dataset['images']: 89 | imgs[img['id']] = img 90 | 91 | cats = [] 92 | catToImgs = [] 93 | if self.dataset['type'] == 'instances': 94 | cats = {cat['id']: [] for cat in self.dataset['categories']} 95 | for cat in self.dataset['categories']: 96 | cats[cat['id']] = cat 97 | catToImgs = {cat['id']: [] for cat in self.dataset['categories']} 98 | for ann in self.dataset['annotations']: 99 | catToImgs[ann['category_id']] += [ann['image_id']] 100 | 101 | print('index created!') 102 | 103 | # create class members 104 | self.anns = anns 105 | self.imgToAnns = imgToAnns 106 | self.catToImgs = catToImgs 107 | self.imgs = imgs 108 | self.cats = cats 109 | 110 | def info(self): 111 | """ 112 | Print information about the annotation file. 113 | :return: 114 | """ 115 | for key, value in self.datset['info'].items(): 116 | print('%s: %s'%(key, value)) 117 | 118 | def getAnnIds(self, imgIds=[], catIds=[], areaRng=[], iscrowd=None): 119 | """ 120 | Get ann ids that satisfy given filter conditions. default skips that filter 121 | :param imgIds (int array) : get anns for given imgs 122 | catIds (int array) : get anns for given cats 123 | areaRng (float array) : get anns for given area range (e.g. [0 inf]) 124 | iscrowd (boolean) : get anns for given crowd label (False or True) 125 | :return: ids (int array) : integer array of ann ids 126 | """ 127 | imgIds = imgIds if type(imgIds) == list else [imgIds] 128 | catIds = catIds if type(catIds) == list else [catIds] 129 | 130 | if len(imgIds) == len(catIds) == len(areaRng) == 0: 131 | anns = self.dataset['annotations'] 132 | else: 133 | if not len(imgIds) == 0: 134 | anns = sum([self.imgToAnns[imgId] for imgId in imgIds if imgId in self.imgToAnns],[]) 135 | else: 136 | anns = self.dataset['annotations'] 137 | anns = anns if len(catIds) == 0 else [ann for ann in anns if ann['category_id'] in catIds] 138 | anns = anns if len(areaRng) == 0 else [ann for ann in anns if ann['area'] > areaRng[0] and ann['area'] < areaRng[1]] 139 | if self.dataset['type'] == 'instances': 140 | if not iscrowd == None: 141 | ids = [ann['id'] for ann in anns if ann['iscrowd'] == iscrowd] 142 | else: 143 | ids = [ann['id'] for ann in anns] 144 | else: 145 | ids = [ann['id'] for ann in anns] 146 | return ids 147 | 148 | def getCatIds(self, catNms=[], supNms=[], catIds=[]): 149 | """ 150 | filtering parameters. default skips that filter. 151 | :param catNms (str array) : get cats for given cat names 152 | :param supNms (str array) : get cats for given supercategory names 153 | :param catIds (int array) : get cats for given cat ids 154 | :return: ids (int array) : integer array of cat ids 155 | """ 156 | catNms = catNms if type(catNms) == list else [catNms] 157 | supNms = supNms if type(supNms) == list else [supNms] 158 | catIds = catIds if type(catIds) == list else [catIds] 159 | 160 | if len(catNms) == len(supNms) == len(catIds) == 0: 161 | cats = self.dataset['categories'] 162 | else: 163 | cats = self.dataset['categories'] 164 | cats = cats if len(catNms) == 0 else [cat for cat in cats if cat['name'] in catNms] 165 | cats = cats if len(supNms) == 0 else [cat for cat in cats if cat['supercategory'] in supNms] 166 | cats = cats if len(catIds) == 0 else [cat for cat in cats if cat['id'] in catIds] 167 | ids = [cat['id'] for cat in cats] 168 | return ids 169 | 170 | def getImgIds(self, imgIds=[], catIds=[]): 171 | ''' 172 | Get img ids that satisfy given filter conditions. 173 | :param imgIds (int array) : get imgs for given ids 174 | :param catIds (int array) : get imgs with all given cats 175 | :return: ids (int array) : integer array of img ids 176 | ''' 177 | imgIds = imgIds if type(imgIds) == list else [imgIds] 178 | catIds = catIds if type(catIds) == list else [catIds] 179 | 180 | if len(imgIds) == len(catIds) == 0: 181 | ids = self.imgs.keys() 182 | else: 183 | ids = set(imgIds) 184 | for catId in catIds: 185 | if len(ids) == 0: 186 | ids = set(self.catToImgs[catId]) 187 | else: 188 | ids &= set(self.catToImgs[catId]) 189 | return list(ids) 190 | 191 | def loadAnns(self, ids=[]): 192 | """ 193 | Load anns with the specified ids. 194 | :param ids (int array) : integer ids specifying anns 195 | :return: anns (object array) : loaded ann objects 196 | """ 197 | if type(ids) == list: 198 | return [self.anns[id] for id in ids] 199 | elif type(ids) == int: 200 | return [self.anns[ids]] 201 | 202 | def loadCats(self, ids=[]): 203 | """ 204 | Load cats with the specified ids. 205 | :param ids (int array) : integer ids specifying cats 206 | :return: cats (object array) : loaded cat objects 207 | """ 208 | if type(ids) == list: 209 | return [self.cats[id] for id in ids] 210 | elif type(ids) == int: 211 | return [self.cats[ids]] 212 | 213 | def loadImgs(self, ids=[]): 214 | """ 215 | Load anns with the specified ids. 216 | :param ids (int array) : integer ids specifying img 217 | :return: imgs (object array) : loaded img objects 218 | """ 219 | if type(ids) == list: 220 | return [self.imgs[id] for id in ids] 221 | elif type(ids) == int: 222 | return [self.imgs[ids]] 223 | 224 | def showAnns(self, anns): 225 | """ 226 | Display the specified annotations. 227 | :param anns (array of object): annotations to display 228 | :return: None 229 | """ 230 | if len(anns) == 0: 231 | return 0 232 | if self.dataset['type'] == 'instances': 233 | ax = plt.gca() 234 | polygons = [] 235 | color = [] 236 | for ann in anns: 237 | c = np.random.random((1, 3)).tolist()[0] 238 | if type(ann['segmentation']) == list: 239 | # polygon 240 | for seg in ann['segmentation']: 241 | poly = np.array(seg).reshape((len(seg)/2, 2)) 242 | polygons.append(Polygon(poly, True,alpha=0.4)) 243 | color.append(c) 244 | else: 245 | # mask 246 | mask = COCO.decodeMask(ann['segmentation']) 247 | img = np.ones( (mask.shape[0], mask.shape[1], 3) ) 248 | if ann['iscrowd'] == 1: 249 | color_mask = np.array([2.0,166.0,101.0])/255 250 | if ann['iscrowd'] == 0: 251 | color_mask = np.random.random((1, 3)).tolist()[0] 252 | for i in range(3): 253 | img[:,:,i] = color_mask[i] 254 | ax.imshow(np.dstack( (img, mask*0.5) )) 255 | p = PatchCollection(polygons, facecolors=color, edgecolors=(0,0,0,1), linewidths=3, alpha=0.4) 256 | ax.add_collection(p) 257 | if self.dataset['type'] == 'captions': 258 | for ann in anns: 259 | print(ann['caption']) 260 | 261 | def loadRes(self, resFile): 262 | """ 263 | Load result file and return a result api object. 264 | :param resFile (str) : file name of result file 265 | :return: res (obj) : result api object 266 | """ 267 | res = COCO() 268 | res.dataset['images'] = [img for img in self.dataset['images']] 269 | res.dataset['info'] = copy.deepcopy(self.dataset['info']) 270 | res.dataset['type'] = copy.deepcopy(self.dataset['type']) 271 | res.dataset['licenses'] = copy.deepcopy(self.dataset['licenses']) 272 | 273 | print('Loading and preparing results... ') 274 | time_t = datetime.datetime.utcnow() 275 | anns = json.load(open(resFile)) 276 | assert type(anns) == list, 'results in not an array of objects' 277 | annsImgIds = [ann['image_id'] for ann in anns] 278 | assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), \ 279 | 'Results do not correspond to current coco set' 280 | if 'caption' in anns[0]: 281 | imgIds = set([img['id'] for img in res.dataset['images']]) & set([ann['image_id'] for ann in anns]) 282 | res.dataset['images'] = [img for img in res.dataset['images'] if img['id'] in imgIds] 283 | for id, ann in enumerate(anns): 284 | ann['id'] = id 285 | elif 'bbox' in anns[0] and not anns[0]['bbox'] == []: 286 | res.dataset['categories'] = copy.deepcopy(self.dataset['categories']) 287 | for id, ann in enumerate(anns): 288 | bb = ann['bbox'] 289 | x1, x2, y1, y2 = [bb[0], bb[0]+bb[2], bb[1], bb[1]+bb[3]] 290 | ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]] 291 | ann['area'] = bb[2]*bb[3] 292 | ann['id'] = id 293 | ann['iscrowd'] = 0 294 | elif 'segmentation' in anns[0]: 295 | res.dataset['categories'] = copy.deepcopy(self.dataset['categories']) 296 | for id, ann in enumerate(anns): 297 | ann['area']=sum(ann['segmentation']['counts'][2:-1:2]) 298 | ann['bbox'] = [] 299 | ann['id'] = id 300 | ann['iscrowd'] = 0 301 | print('DONE (t=%0.2fs)'%((datetime.datetime.utcnow() - time_t).total_seconds())) 302 | 303 | res.dataset['annotations'] = anns 304 | res.createIndex() 305 | return res 306 | 307 | 308 | @staticmethod 309 | def decodeMask(R): 310 | """ 311 | Decode binary mask M encoded via run-length encoding. 312 | :param R (object RLE) : run-length encoding of binary mask 313 | :return: M (bool 2D array) : decoded binary mask 314 | """ 315 | N = len(R['counts']) 316 | M = np.zeros( (R['size'][0]*R['size'][1], )) 317 | n = 0 318 | val = 1 319 | for pos in range(N): 320 | val = not val 321 | for c in range(R['counts'][pos]): 322 | R['counts'][pos] 323 | M[n] = val 324 | n += 1 325 | return M.reshape((R['size']), order='F') 326 | 327 | @staticmethod 328 | def encodeMask(M): 329 | """ 330 | Encode binary mask M using run-length encoding. 331 | :param M (bool 2D array) : binary mask to encode 332 | :return: R (object RLE) : run-length encoding of binary mask 333 | """ 334 | [h, w] = M.shape 335 | M = M.flatten(order='F') 336 | N = len(M) 337 | counts_list = [] 338 | pos = 0 339 | # counts 340 | counts_list.append(1) 341 | diffs = np.logical_xor(M[0:N-1], M[1:N]) 342 | for diff in diffs: 343 | if diff: 344 | pos +=1 345 | counts_list.append(1) 346 | else: 347 | counts_list[pos] += 1 348 | # if array starts from 1. start with 0 counts for 0 349 | if M[0] == 1: 350 | counts_list = [0] + counts_list 351 | return {'size': [h, w], 352 | 'counts': counts_list , 353 | } 354 | 355 | @staticmethod 356 | def segToMask( S, h, w ): 357 | """ 358 | Convert polygon segmentation to binary mask. 359 | :param S (float array) : polygon segmentation mask 360 | :param h (int) : target mask height 361 | :param w (int) : target mask width 362 | :return: M (bool 2D array) : binary mask 363 | """ 364 | M = np.zeros((h,w), dtype=np.bool) 365 | for s in S: 366 | N = len(s) 367 | rr, cc = polygon(np.array(s[1:N:2]), np.array(s[0:N:2])) # (y, x) 368 | M[rr, cc] = 1 369 | return M -------------------------------------------------------------------------------- /densevid_eval/evaluate.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dense-Captioning Events in Videos Eval 3 | # Copyright (c) 2017 Ranjay Krishna 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ranjay Krishna 6 | # -------------------------------------------------------- 7 | 8 | import argparse 9 | import string 10 | import json 11 | import sys 12 | sys.path.insert(0, './coco-caption') # Hack to allow the import of pycocoeval 13 | 14 | from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer 15 | from pycocoevalcap.bleu.bleu import Bleu 16 | from pycocoevalcap.meteor.meteor import Meteor 17 | from pycocoevalcap.rouge.rouge import Rouge 18 | from pycocoevalcap.cider.cider import Cider 19 | from pycocoevalcap.spice.spice import Spice 20 | from sets import Set 21 | import numpy as np 22 | 23 | def remove_nonascii(text): 24 | return ''.join([i if ord(i) < 128 else ' ' for i in text]) 25 | 26 | class ANETcaptions(object): 27 | PREDICTION_FIELDS = ['results', 'version', 'external_data'] 28 | 29 | def __init__(self, ground_truth_filenames=None, prediction_filename=None, 30 | tious=None, max_proposals=1000, 31 | prediction_fields=PREDICTION_FIELDS, verbose=False): 32 | # Check that the gt and submission files exist and load them 33 | if len(tious) == 0: 34 | raise IOError('Please input a valid tIoU.') 35 | if not ground_truth_filenames: 36 | raise IOError('Please input a valid ground truth file.') 37 | if not prediction_filename: 38 | raise IOError('Please input a valid prediction file.') 39 | 40 | self.verbose = verbose 41 | self.tious = tious 42 | self.max_proposals = max_proposals 43 | self.pred_fields = prediction_fields 44 | self.ground_truths = self.import_ground_truths(ground_truth_filenames) 45 | self.prediction = self.import_prediction(prediction_filename) 46 | self.tokenizer = PTBTokenizer() 47 | 48 | # Set up scorers, if not verbose, we only use the one we're 49 | # testing on: METEOR 50 | if self.verbose: 51 | self.scorers = [ 52 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]), 53 | (Meteor(),"METEOR"), 54 | (Rouge(), "ROUGE_L"), 55 | (Cider(), "CIDEr"), 56 | (Spice(), "SPICE") 57 | ] 58 | else: 59 | self.scorers = [(Meteor(), "METEOR")] 60 | 61 | def import_prediction(self, prediction_filename): 62 | if self.verbose: 63 | print("| Loading submission...") 64 | submission = json.load(open(prediction_filename)) 65 | if not all([field in submission.keys() for field in self.pred_fields]): 66 | raise IOError('Please input a valid ground truth file.') 67 | # Ensure that every video is limited to the correct maximum number of proposals. 68 | results = {} 69 | len_captions = 0 70 | for vid_id in submission['results']: 71 | results[vid_id] = submission['results'][vid_id][:self.max_proposals] 72 | len_captions+= len(submission['results'][vid_id][:self.max_proposals]) 73 | print('len of results:', len(results)) 74 | print('len of captions:', len_captions) 75 | return results 76 | 77 | def import_ground_truths(self, filenames): 78 | gts = [] 79 | self.n_ref_vids = Set() 80 | for filename in filenames: 81 | gt = json.load(open(filename)) 82 | self.n_ref_vids.update(gt.keys()) 83 | gts.append(gt) 84 | if self.verbose: 85 | print("| Loading GT. #files: %d, #videos: %d" % (len(filenames), len(self.n_ref_vids))) 86 | return gts 87 | 88 | def iou(self, interval_1, interval_2): 89 | start_i, end_i = interval_1[0], interval_1[1] 90 | start, end = interval_2[0], interval_2[1] 91 | intersection = max(0, min(end, end_i) - max(start, start_i)) 92 | union = min(max(end, end_i) - min(start, start_i), end-start + end_i-start_i) 93 | iou = float(intersection) / (union + 1e-8) 94 | return iou 95 | 96 | def check_gt_exists(self, vid_id): 97 | for gt in self.ground_truths: 98 | if vid_id in gt: 99 | return True 100 | return False 101 | 102 | def get_gt_vid_ids(self): 103 | vid_ids = set([]) 104 | for gt in self.ground_truths: 105 | vid_ids |= set(gt.keys()) 106 | return list(vid_ids) 107 | 108 | def evaluate(self): 109 | aggregator = {} 110 | self.scores = {} 111 | for tiou in self.tious: 112 | scores = self.evaluate_tiou(tiou) 113 | for metric, score in scores.items(): 114 | if metric not in self.scores: 115 | self.scores[metric] = [] 116 | self.scores[metric].append(score) 117 | if self.verbose: 118 | self.scores['Recall'] = [] 119 | self.scores['Precision'] = [] 120 | for tiou in self.tious: 121 | precision, recall = self.evaluate_detection(tiou) 122 | self.scores['Recall'].append(recall) 123 | self.scores['Precision'].append(precision) 124 | 125 | def evaluate_detection(self, tiou): 126 | gt_vid_ids = self.get_gt_vid_ids() 127 | # Recall is the percentage of ground truth that is covered by the predictions 128 | # Precision is the percentage of predictions that are valid 129 | recall = [0] * len(gt_vid_ids) 130 | precision = [0] * len(gt_vid_ids) 131 | for vid_i, vid_id in enumerate(gt_vid_ids): 132 | best_recall = 0 133 | best_precision = 0 134 | for gt in self.ground_truths: 135 | if vid_id not in gt: 136 | continue 137 | refs = gt[vid_id] 138 | ref_set_covered = set([]) 139 | pred_set_covered = set([]) 140 | num_gt = 0 141 | num_pred = 0 142 | if vid_id in self.prediction: 143 | for pred_i, pred in enumerate(self.prediction[vid_id]): 144 | pred_timestamp = pred['timestamp'] 145 | for ref_i, ref_timestamp in enumerate(refs['timestamps']): 146 | if self.iou(pred_timestamp, ref_timestamp) > tiou: 147 | ref_set_covered.add(ref_i) 148 | pred_set_covered.add(pred_i) 149 | 150 | new_precision = float(len(pred_set_covered)) / (pred_i + 1) 151 | best_precision = max(best_precision, new_precision) 152 | new_recall = float(len(ref_set_covered)) / len(refs['timestamps']) 153 | best_recall = max(best_recall, new_recall) 154 | recall[vid_i] = best_recall 155 | precision[vid_i] = best_precision 156 | return sum(precision) / len(precision), sum(recall) / len(recall) 157 | 158 | def evaluate_tiou(self, tiou): 159 | # This method averages the tIoU precision from METEOR, Bleu, etc. across videos 160 | res = {} 161 | gts = {} 162 | gt_vid_ids = self.get_gt_vid_ids() 163 | 164 | unique_index = 0 165 | 166 | # video id to unique caption ids mapping 167 | vid2capid = {} 168 | 169 | cur_res = {} 170 | cur_gts = {} 171 | 172 | for vid_id in gt_vid_ids: 173 | 174 | vid2capid[vid_id] = [] 175 | 176 | # If the video does not have a prediction, then Vwe give it no matches 177 | # We set it to empty, and use this as a sanity check later on 178 | if vid_id not in self.prediction: 179 | pass 180 | 181 | # If we do have a prediction, then we find the scores based on all the 182 | # valid tIoU overlaps 183 | else: 184 | # For each prediction, we look at the tIoU with ground truth 185 | for i,pred in enumerate(self.prediction[vid_id]): 186 | has_added = False 187 | for gt in self.ground_truths: 188 | if vid_id not in gt: 189 | print('skipped') 190 | continue 191 | gt_captions = gt[vid_id] 192 | for caption_idx, caption_timestamp in enumerate(gt_captions['timestamps']): 193 | if True or self.iou(pred['timestamp'], caption_timestamp) >= tiou: 194 | gt_caption = gt_captions['sentences'][i] # for now we use gt proposal 195 | cur_res[unique_index] = [{'caption': remove_nonascii(pred['sentence'])}] 196 | cur_gts[unique_index] = [{'caption': remove_nonascii(gt_caption)}] # for now we use gt proposal 197 | #cur_gts[unique_index] = [{'caption': remove_nonascii(gt_captions['sentences'][caption_idx])}] 198 | vid2capid[vid_id].append(unique_index) 199 | unique_index += 1 200 | has_added = True 201 | break # for now we use gt proposal 202 | 203 | # If the predicted caption does not overlap with any ground truth, 204 | # we should compare it with garbage 205 | if not has_added: 206 | cur_res[unique_index] = [{'caption': remove_nonascii(pred['sentence'])}] 207 | cur_gts[unique_index] = [{'caption': 'abc123!@#'}] 208 | vid2capid[vid_id].append(unique_index) 209 | unique_index += 1 210 | 211 | # Each scorer will compute across all videos and take average score 212 | output = {} 213 | 214 | # call tokenizer here for all predictions and gts 215 | tokenize_res = self.tokenizer.tokenize(cur_res) 216 | tokenize_gts = self.tokenizer.tokenize(cur_gts) 217 | 218 | # reshape back 219 | for vid in vid2capid.keys(): 220 | res[vid] = {index:tokenize_res[index] for index in vid2capid[vid]} 221 | gts[vid] = {index:tokenize_gts[index] for index in vid2capid[vid]} 222 | 223 | for scorer, method in self.scorers: 224 | if self.verbose: 225 | print('computing %s score...'%(scorer.method())) 226 | 227 | # For each video, take all the valid pairs (based from tIoU) and compute the score 228 | all_scores = {} 229 | 230 | if method == "SPICE": # don't want to compute spice for 10000 times 231 | print("getting spice score...") 232 | score, scores = scorer.compute_score(tokenize_gts, tokenize_res) 233 | all_scores[0] = score 234 | else: 235 | for i,vid_id in enumerate(gt_vid_ids): 236 | if len(res[vid_id]) == 0 or len(gts[vid_id]) == 0: 237 | if type(method) == list: 238 | score = [0] * len(method) 239 | else: 240 | score = 0 241 | else: 242 | score, scores = scorer.compute_score(gts[vid_id], res[vid_id]) 243 | all_scores[vid_id] = score 244 | 245 | #print all_scores.values() 246 | if type(method) == list: 247 | scores = np.mean(all_scores.values(), axis=0) 248 | for m in range(len(method)): 249 | output[method[m]] = scores[m] 250 | if self.verbose: 251 | print("Calculated tIoU: %1.1f, %s: %0.3f" % (tiou, method[m], output[method[m]])) 252 | else: 253 | output[method] = np.mean(all_scores.values()) 254 | if self.verbose: 255 | print("Calculated tIoU: %1.1f, %s: %0.3f" % (tiou, method, output[method])) 256 | return output 257 | 258 | def main(args): 259 | # Call coco eval 260 | evaluator = ANETcaptions(ground_truth_filenames=args.references, 261 | prediction_filename=args.submission, 262 | tious=args.tious, 263 | max_proposals=args.max_proposals_per_video, 264 | verbose=args.verbose) 265 | evaluator.evaluate() 266 | 267 | # Output the results 268 | if args.verbose: 269 | for i, tiou in enumerate(args.tious): 270 | print('-' * 80) 271 | print("tIoU: " , tiou) 272 | print('-' * 80) 273 | for metric in evaluator.scores: 274 | score = evaluator.scores[metric][i] 275 | print('| %s: %2.4f'%(metric, 100*score)) 276 | 277 | # Print the averages 278 | print('-' * 80) 279 | print("Average across all tIoUs") 280 | print('-' * 80) 281 | output = {} 282 | for metric in evaluator.scores: 283 | score = evaluator.scores[metric] 284 | print('| %s: %2.4f'%(metric, 100 * sum(score) / float(len(score)))) 285 | output[metric] = 100 * sum(score) / float(len(score)) 286 | json.dump(output,open(args.output,'w')) 287 | print(output) 288 | if __name__=='__main__': 289 | parser = argparse.ArgumentParser(description='Evaluate the results stored in a submissions file.') 290 | parser.add_argument('-s', '--submission', type=str, default='sample_submission.json', 291 | help='sample submission file for ActivityNet Captions Challenge.') 292 | parser.add_argument('-r', '--references', type=str, nargs='+', default=['data/val_1.json'], 293 | help='reference files with ground truth captions to compare results against. delimited (,) str') 294 | parser.add_argument('-o', '--output', type=str, default='result.json', 295 | help='output file with final language metrics.') 296 | parser.add_argument('--tious', type=float, nargs='+', default=[0.3], 297 | help='Choose the tIoUs to average over.') 298 | parser.add_argument('-ppv', '--max-proposals-per-video', type=int, default=1000, 299 | help='maximum propoasls per video.') 300 | parser.add_argument('-v', '--verbose', action='store_true', 301 | help='Print intermediate steps.') 302 | args = parser.parse_args() 303 | 304 | main(args) 305 | -------------------------------------------------------------------------------- /densevid_eval/evaluateCaptionsDiversity.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import sys 4 | 5 | def getNgrams(words_pred, unigrams, bigrams, trigrams, fourgrams): 6 | # N=1 7 | for w in words_pred: 8 | if w not in unigrams: 9 | unigrams[w] = 0 10 | unigrams[w] += 1 11 | # N=2 12 | for i, w in enumerate(words_pred): 13 | if i', ' ') 76 | query = query.replace('`', ' ') 77 | query = query.replace('#', ' ') 78 | query = query.replace(u'\u2019', "'") 79 | while query[-1] == ' ': 80 | query = query[0:-1] 81 | while query[0] == ' ': 82 | query = query[1:] 83 | while ' ' in query: 84 | query = query.replace(' ', ' ') 85 | # print(query) 86 | if query not in trainingQueries: 87 | trainingQueries[query] = 0 88 | trainingQueries[query] += 1 89 | 90 | vocab = {} 91 | novel_sentence = [] 92 | uniq_sentence = {} 93 | count_sent = 0 94 | sent_length = [] 95 | 96 | for vid in data_gt['results']: 97 | for i, _ in enumerate(data_gt['results'][vid]): 98 | 99 | try: 100 | pred_sentence = data_predicted['results'][vid][i]['sentence'].lower() 101 | except: 102 | continue 103 | 104 | if pred_sentence[-1] == '.': 105 | pred_sentence = pred_sentence[0:-1] 106 | while pred_sentence[-1] == ' ': 107 | pred_sentence = pred_sentence[0:-1] 108 | pred_sentence = pred_sentence.replace(',', ' ') 109 | while ' ' in pred_sentence: 110 | pred_sentence = pred_sentence.replace(' ', ' ') 111 | 112 | if pred_sentence in trainingQueries: 113 | novel_sentence.append(0) 114 | else: 115 | novel_sentence.append(1) 116 | 117 | if pred_sentence not in uniq_sentence: 118 | uniq_sentence[pred_sentence] = 0 119 | uniq_sentence[pred_sentence] += 1 120 | 121 | words_pred = pred_sentence.split(' ') 122 | for w in words_pred: 123 | if w not in vocab: 124 | vocab[w] = 0 125 | vocab[w] += 1 126 | 127 | sent_length.append(len(words_pred)) 128 | count_sent += 1 129 | 130 | print ('Vocab: %d\t Novel Sent: %.2f\t Uniq Sent: %.2f\t Sent length: %.2f' % 131 | (len(vocab), np.mean(novel_sentence), len(uniq_sentence)/float(count_sent), np.mean(sent_length))) 132 | 133 | def activity_stats(data_predicted, data_gt, data_annos): 134 | print('#### Per activity ####') 135 | 136 | # Per activity 137 | data_annos = data_annos['database'] 138 | activities = {} 139 | vid2act = {} 140 | for vid_id in data_annos: 141 | vid2act[vid_id] = [] 142 | annos = data_annos[vid_id]['annotations'] 143 | for anno in annos: 144 | if anno['label'] not in vid2act[vid_id]: 145 | vid2act[vid_id].append(anno['label']) 146 | if anno['label'] not in activities: 147 | activities[anno['label']] = True 148 | div1 = {} 149 | div2 = {} 150 | div3 = {} 151 | div4 = {} 152 | re = {} 153 | sentences = {} 154 | 155 | for act in activities: 156 | div1[act] = -1 157 | div2[act] = -1 158 | div3[act] = -1 159 | div4[act] = -1 160 | re[act] = -1 161 | sentences[act] = [] 162 | 163 | for vid in data_gt['results']: 164 | 165 | act = vid2act[vid[2:]][0] 166 | 167 | if vid not in data_predicted['results']: 168 | continue 169 | 170 | for i, _ in enumerate(data_gt['results'][vid]): 171 | 172 | try: 173 | pred_sentence = data_predicted['results'][vid][i]['sentence'] 174 | except: 175 | continue 176 | 177 | if pred_sentence[-1] == '.': 178 | pred_sentence = pred_sentence[0:-1] 179 | while pred_sentence[-1] == ' ': 180 | pred_sentence = pred_sentence[0:-1] 181 | pred_sentence = pred_sentence.replace(',', ' ') 182 | while ' ' in pred_sentence: 183 | pred_sentence = pred_sentence.replace(' ', ' ') 184 | 185 | sentences[act].append(pred_sentence) 186 | 187 | for act in activities: 188 | unigrams = {} 189 | bigrams = {} 190 | trigrams = {} 191 | fourgrams = {} 192 | 193 | for pred_sentence in sentences[act]: 194 | words_pred = pred_sentence.split(' ') 195 | unigrams, bigrams, trigrams, fourgrams = getNgrams(words_pred, unigrams, bigrams, trigrams, fourgrams) 196 | 197 | sum_unigrams = sum([unigrams[un] for un in unigrams]) 198 | vid_div1 = float(len(unigrams)) / float(sum_unigrams) 199 | vid_div2 = float(len(bigrams)) / float(sum_unigrams) 200 | vid_div3 = float(len(trigrams)) / float(sum_unigrams) 201 | vid_div4 = float(len(fourgrams)) / float(sum_unigrams) 202 | 203 | vid_re = float(sum([max(fourgrams[f]-1,0) for f in fourgrams])) / float(sum([fourgrams[f] for f in fourgrams])) 204 | 205 | div1[act] = vid_div1 206 | div2[act] = vid_div2 207 | div3[act] = vid_div3 208 | div4[act] = vid_div4 209 | re[act] = vid_re 210 | 211 | mean_div1 = np.mean([div1[act] for act in activities]) 212 | mean_div2 = np.mean([div2[act] for act in activities]) 213 | mean_div3 = np.mean([div3[act] for act in activities]) 214 | mean_div4 = np.mean([div4[act] for act in activities]) 215 | mean_re = np.mean([re[act] for act in activities]) 216 | 217 | print ('Div-1: %.4f\t Div-2: %.4f\t RE: %.4f' % (mean_div1, mean_div2, mean_re)) 218 | 219 | def video_stats(data_predicted, data_gt): 220 | # print('#### Per video ####') 221 | 222 | # Per video 223 | 224 | div1 = [] 225 | div2 = [] 226 | div3 = [] 227 | div4 = [] 228 | re1 = [] 229 | re2 = [] 230 | re3 = [] 231 | re4 = [] 232 | 233 | for vid in data_gt['results']: 234 | 235 | unigrams = {} 236 | bigrams = {} 237 | trigrams = {} 238 | fourgrams = {} 239 | 240 | if vid not in data_predicted['results']: 241 | continue 242 | 243 | for i, _ in enumerate(data_gt['results'][vid]): 244 | 245 | try: 246 | pred_sentence = data_predicted['results'][vid][i]['sentence'] 247 | except: 248 | continue 249 | 250 | if pred_sentence[-1] == '.': 251 | pred_sentence = pred_sentence[0:-1] 252 | while pred_sentence[-1] == ' ': 253 | pred_sentence = pred_sentence[0:-1] 254 | pred_sentence = pred_sentence.replace(',', ' ') 255 | while ' ' in pred_sentence: 256 | pred_sentence = pred_sentence.replace(' ', ' ') 257 | 258 | words_pred = pred_sentence.split(' ') 259 | unigrams, bigrams, trigrams, fourgrams = getNgrams(words_pred, unigrams, bigrams, trigrams, fourgrams) 260 | 261 | sum_unigrams = sum([unigrams[un] for un in unigrams]) 262 | vid_div1 = float(len(unigrams)) / float(sum_unigrams) 263 | vid_div2 = float(len(bigrams)) / float(sum_unigrams) 264 | vid_div3 = float(len(trigrams)) / float(sum_unigrams) 265 | vid_div4 = float(len(fourgrams)) / float(sum_unigrams) 266 | 267 | vid_re1 = float(sum([max(unigrams[f] - 1, 0) for f in unigrams])) / float(sum([unigrams[f] for f in unigrams])) 268 | vid_re2 = float(sum([max(bigrams[f] - 1, 0) for f in bigrams])) / float(sum([bigrams[f] for f in bigrams])) 269 | vid_re3 = float(sum([max(trigrams[f] - 1, 0) for f in trigrams])) / float(sum([trigrams[f] for f in trigrams])) 270 | vid_re4 = float(sum([max(fourgrams[f]-1,0) for f in fourgrams])) / float(sum([fourgrams[f] for f in fourgrams])) 271 | 272 | div1.append(vid_div1) 273 | div2.append(vid_div2) 274 | div3.append(vid_div3) 275 | div4.append(vid_div4) 276 | re1.append(vid_re1) 277 | re2.append(vid_re2) 278 | re3.append(vid_re3) 279 | re4.append(vid_re4) 280 | 281 | #print ('tDiv-1: %.4f\t Div-2: %.4f\t RE-4: %.4f' % (np.mean(div1), np.mean(div2),np.mean(re4))) 282 | 283 | if __name__=='__main__': 284 | submission = sys.argv[1] 285 | evaluateDiversity(submission) 286 | -------------------------------------------------------------------------------- /densevid_eval/evaluateRepetition.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import sys 4 | 5 | 6 | def save_json(data, filepath): 7 | with open(filepath, "w") as f: 8 | json.dump(data, f) 9 | 10 | 11 | def save_json_pretty(data, filepath): 12 | with open(filepath, "w") as f: 13 | f.write(json.dumps(data, indent=4, sort_keys=True)) 14 | 15 | 16 | def get_ngrams(words_pred, unigrams, bigrams, trigrams, fourgrams): 17 | # N=1 18 | for w in words_pred: 19 | if w not in unigrams: 20 | unigrams[w] = 0 21 | unigrams[w] += 1 22 | # N=2 23 | for i, w in enumerate(words_pred): 24 | if i 0 else [''] for k in gt_vid_ids} 124 | para_res = {vid2idx[k]: [' '.join(parse_para(self.prediction[k]))] \ 125 | if k in self.prediction and len(self.prediction[k]) > 0 else [''] for k in gt_vid_ids} 126 | 127 | # Each scorer will compute across all videos and take average score 128 | output = {} 129 | num = len(res) 130 | hard_samples = {} 131 | easy_samples = {} 132 | for scorer, method in self.scorers: 133 | if self.verbose: 134 | print('computing %s score...'%(scorer.method())) 135 | 136 | if method != 'Self_Bleu': 137 | score, scores = scorer.compute_score(gts, res) 138 | else: 139 | score, scores = scorer.compute_score(gts, para_res) 140 | scores = np.asarray(scores) 141 | 142 | if type(method) == list: 143 | for m in range(len(method)): 144 | output[method[m]] = score[m] 145 | if self.verbose: 146 | print("%s: %0.3f" % (method[m], output[method[m]])) 147 | for m, i in enumerate(scores.argmin(1)): 148 | if i not in hard_samples: 149 | hard_samples[i] = [] 150 | hard_samples[i].append(method[m]) 151 | for m, i in enumerate(scores.argmax(1)): 152 | if i not in easy_samples: 153 | easy_samples[i] = [] 154 | easy_samples[i].append(method[m]) 155 | else: 156 | output[method] = score 157 | if self.verbose: 158 | print("%s: %0.3f" % (method, output[method])) 159 | i = scores.argmin() 160 | if i not in hard_samples: 161 | hard_samples[i] = [] 162 | hard_samples[i].append(method) 163 | i = scores.argmax() 164 | if i not in easy_samples: 165 | easy_samples[i] = [] 166 | easy_samples[i].append(method) 167 | # print('# scored video =', num) 168 | 169 | self.hard_samples = {gt_vid_ids[i]: v for i, v in hard_samples.items()} 170 | self.easy_samples = {gt_vid_ids[i]: v for i, v in easy_samples.items()} 171 | return output 172 | 173 | def main(args): 174 | # Call coco eval 175 | evaluator = ANETcaptions(ground_truth_filenames=args.references, 176 | prediction_filename=args.submission, 177 | verbose=args.verbose, 178 | all_scorer=args.all_scorer) 179 | evaluator.evaluate() 180 | output = {} 181 | # Output the results 182 | for metric, score in evaluator.scores.items(): 183 | # print('| %s: %2.4f'%(metric, 100*score)) 184 | output[metric] = score 185 | json.dump(output, open(args.output, 'w')) 186 | #print(output) 187 | 188 | import time 189 | if __name__=='__main__': 190 | parser = argparse.ArgumentParser(description='Evaluate the results stored in a submissions file.') 191 | parser.add_argument('-s', '--submission', type=str, default='sample_submission.json', 192 | help='sample submission file for ActivityNet Captions Challenge.') 193 | parser.add_argument('-r', '--references', type=str, nargs='+', required=True, 194 | help='reference files with ground truth captions to compare results against. delimited (,) str') 195 | parser.add_argument('-o', '--output', type=str, default=None, help='output file with final language metrics.') 196 | parser.add_argument('-v', '--verbose', action='store_true', 197 | help='Print intermediate steps.') 198 | parser.add_argument('--time', '--t', action = 'store_true', 199 | help = 'Count running time.') 200 | parser.add_argument('--all_scorer', '--a', action = 'store_true', 201 | help = 'Use all scorer.') 202 | args = parser.parse_args() 203 | 204 | if args.output is None: 205 | r_path = args.submission 206 | r_path_splits = r_path.split(".") 207 | r_path_splits = r_path_splits[:-1] + ["_metric", r_path_splits[-1]] 208 | args.output = ".".join(r_path_splits) 209 | 210 | if args.time: 211 | start_time = time.time() 212 | main(args) 213 | if args.time: 214 | print('time = %.2f' % (time.time() - start_time)) 215 | -------------------------------------------------------------------------------- /densevid_eval/yc2_data/val_list.txt: -------------------------------------------------------------------------------- 1 | 405/sdB8qBlLS2E 2 | 405/fn9anlEL4FI 3 | 405/-dh_uGahzYo 4 | 405/BktdaTg6_E4 5 | 325/U_yVc8Dl048 6 | 325/RnSl1LVrItI 7 | 325/vVZsj1t9R70 8 | 325/ysRLGUndzgg 9 | 306/HF49t8uVJOE 10 | 306/GXnzgRC3sd4 11 | 306/pxQd53yvSaA 12 | 306/uAzzevo-FME 13 | 306/vU2lND4YQjM 14 | 306/Mzn6Q4gUDBo 15 | 303/JylDlRtH9Tc 16 | 303/9pJToG30LdM 17 | 303/2IkN3hTEZ2Y 18 | 303/ulrh6C5V_VI 19 | 303/2zFAZy0zSbw 20 | 303/6uHoTJSLoL8 21 | 303/EkuM7L31bMQ 22 | 302/PQ97HXmsFR0 23 | 302/vDDeMg2dhEM 24 | 302/WlHWRPyA7_g 25 | 302/c4WaDsqP38k 26 | 302/9guuyTr8EUg 27 | 302/LYj5-CdRIz0 28 | 321/v_dkYNq8G9Y 29 | 321/vXlmXrKC0FE 30 | 321/UHhuaRTF1UY 31 | 221/zzT6RoI4JPU 32 | 221/cCWDR-jUv9U 33 | 221/GmkRlWA2kGI 34 | 221/G-AUY-jWzck 35 | 221/PV93b0xisN8 36 | 221/2heP32bqOV0 37 | 210/LjfTvZ-cmzs 38 | 210/YMYNv3cZ9SE 39 | 210/_GTwKEPmB-U 40 | 210/nHZsE7T7hwI 41 | 210/gTqhgReBDw0 42 | 210/N35UyfIwhVI 43 | 225/NAMZY2LbeFY 44 | 225/KYoelaJY5LA 45 | 225/2HsWZdKKBGg 46 | 225/-Ju39A-G0Dk 47 | 225/pTjoGIvSfE8 48 | 225/lwdypoLpMW4 49 | 107/peld2w63tpM 50 | 107/VLS3ZJt9GMg 51 | 107/zljhtdoqpv0 52 | 107/3jDAyeKeYFA 53 | 107/-k7trpuj3X8 54 | 201/7D4uMKxLDT0 55 | 201/zBexcthy_tA 56 | 201/-ErPSunMfcs 57 | 201/30Q8k57Kbz4 58 | 201/2SxbO4VAgN8 59 | 201/mZwK0TBI1iY 60 | 119/bAC0cZIQVOk 61 | 119/0ShsPjf9shQ 62 | 119/4bEtf7u4YtE 63 | 119/9BNRMHGepS4 64 | 119/wqpqx-Qm7lk 65 | 119/C73qiF138VU 66 | 419/yizxI2Gf_ww 67 | 419/bY4_F8J8HOM 68 | 419/JK0DTF9Edtk 69 | 419/5W3jHo5d7hM 70 | 419/4K9h7ojJYkc 71 | 120/iqcnbNqVc7U 72 | 120/ucaCmhNo78k 73 | 120/QYl_wwBKt18 74 | 120/CWxjNRIKjA0 75 | 120/OrXZqt42OVs 76 | 418/D4mU_NtbneA 77 | 418/LeCwqp8Bic8 78 | 418/DpuofwnCI8A 79 | 418/DrXVuj1Qowo 80 | 204/cF45-iVw--w 81 | 204/EJm2J0WqRcY 82 | 204/nfVXBQwOCMc 83 | 204/OF-Zh5FrxGc 84 | 204/LpBsoQ6TAL0 85 | 319/RUxugNYxFqg 86 | 319/tkuST4Ku37s 87 | 319/FcjEswcaJW4 88 | 319/X4GOx3EW3Rw 89 | 309/Re46osq_NkI 90 | 309/H6acK-N2wMs 91 | 309/abfhnSaZFlA 92 | 309/-AwyG1JcMp8 93 | 309/tYg3lQ5aZv8 94 | 212/9GX8f5EwwE4 95 | 212/XUyqiWN8WFI 96 | 212/HBUz55JRRm8 97 | 212/0uaKitJaqmI 98 | 314/TO_W2RYL2mA 99 | 314/0hb6NShH9hY 100 | 314/p6LSW9kuRCE 101 | 314/2rJ3KKx0oRk 102 | 314/qRSZEN6g8jY 103 | 314/RqgN6iWMkb0 104 | 317/b_uKIQ4dn3A 105 | 317/R-EnNr_oH8A 106 | 317/noS_n5k3oxM 107 | 317/H5NPxWpfYNU 108 | 317/AcWeYhS3cDs 109 | 209/btikV_DUoCM 110 | 209/xPiv3hP5888 111 | 209/5nh2CP22dgY 112 | 228/Odv6ltYAMw4 113 | 228/jT75QMjRkD0 114 | 228/YP4B9gLNOIM 115 | 228/R5IAGR2SeaE 116 | 228/WQlMXudBGT4 117 | 228/IDiovuOcKW8 118 | 122/LfSYF1N5i_Q 119 | 122/viwpmylgps0 120 | 122/lKm5Ji1Fr4U 121 | 122/GgM8IIglBLw 122 | 122/9F5FvWheSrg 123 | 307/YRZ8zZElALQ 124 | 307/eQZEf3NCCo4 125 | 307/Vq5gxXh9zLM 126 | 307/FliMoBfG72Y 127 | 216/Z5bpo2sBsl8 128 | 216/1iv2xhPN3vk 129 | 216/xx698BRyqG4 130 | 216/EsQbw20TQPA 131 | 216/QUt050AXQMw 132 | 216/JPbFE731Y0c 133 | 215/EpNUSTO2BI4 134 | 215/95WMX64RIBc 135 | 215/woTrhsB_bcA 136 | 215/xwQBrf2CAvc 137 | 215/-ju7_ZORsZw 138 | 215/Hh-uza7bwgE 139 | 313/uHv9xRooPMc 140 | 313/mi8NwUqf7nM 141 | 313/CotdlwupDSI 142 | 313/524UzHtbAcY 143 | 108/LWuuCndtJr0 144 | 108/NjAtxfaLwCk 145 | 108/tPLVNKgs8Lk 146 | 108/7ebZWviUfUA 147 | 108/PYjrGqPHGhY 148 | 108/2iWUUcW08ac 149 | 124/f2uDKzq8WM0 150 | 124/-goI2-eJO1w 151 | 124/NTyhMGmuWik 152 | 124/soLZjUyn0CI 153 | 124/1vJp-jaIaeE 154 | 124/Jtusyjv7GiY 155 | 124/RKhfv-spUaI 156 | 213/YX6v3tY7OPg 157 | 213/4Zl5NvXPi-0 158 | 213/dMhoqii0Cq0 159 | 213/4Y8vVGsv4JE 160 | 213/FzhJGCaaYVs 161 | 213/c3JFGGhkArA 162 | 213/i0qYuhtSQHI 163 | 423/AMBH5L6x3dQ 164 | 423/p-NnIyGFZVw 165 | 423/QISvGTL2VDc 166 | 423/SjA7PFoZcNQ 167 | 110/VPFmudvabUg 168 | 110/PTpRTJKAEoI 169 | 110/9GIPE0aeVNI 170 | 110/OWtnI3m-p8g 171 | 110/I1JgU6TK-yc 172 | 110/FSWZXBbEyFw 173 | 403/SOMsxGGSTUk 174 | 403/ACyY0jTrm5c 175 | 403/vp_dOhmfGcs 176 | 403/VmaEuPzlPII 177 | 403/m88rF0rwHo8 178 | 403/gYWqhml_YJQ 179 | 117/hLTNXDKU_Pk 180 | 117/oAE7nqQeMBQ 181 | 117/8XcSP7kKOIo 182 | 117/MCtF5tRCRUk 183 | 117/wk0nfwGyPBI 184 | 425/OUhxy5BANfk 185 | 425/G-spzGkKIHM 186 | 425/57e54HEcrUE 187 | 425/6nVIgasiUtw 188 | 425/W6DgS0s0qcI 189 | 425/yreC9D4yYiM 190 | 106/xhXcJ6bhX2w 191 | 106/EedEYHqLfP8 192 | 106/e8S1vFC8zYk 193 | 106/RY10IUcz3bk 194 | 106/UmJk0WSl9Uc 195 | 106/88YovCsnMxs 196 | 219/R3Jc1fXwSnU 197 | 219/o8HaMr9E8J8 198 | 219/OpURFOTdycE 199 | 219/i9CMFh31Bs0 200 | 219/5Pa79r5Q-ZI 201 | 219/tQ6-_e59Zrk 202 | 218/1Ihxcua2HBc 203 | 218/ljyO7IaGWLY 204 | 218/TMpt-41UTOk 205 | 127/OEfzgobszUA 206 | 127/sj4BJSnjubc 207 | 127/Xz3-xRyBBog 208 | 127/F2qYQZ7Q68s 209 | 127/g6eV_7U5HX8 210 | 127/3WXM2FAueb8 211 | 127/KTQeLdmlzBo 212 | 203/XAHNVoKV1Bc 213 | 203/bmZB3aszZlA 214 | 203/QKjmdrMA2t8 215 | 203/XbTA0SGOdwk 216 | 203/aCvIo-M06xI 217 | 101/SVo2W3ux1pU 218 | 101/uOXlG8Tglc8 219 | 101/10dZTHlkb8w 220 | 101/4eWzsx1vAi8 221 | 101/c9eELn4axpg 222 | 406/BAoQWVV-bh4 223 | 406/8CaadFo3sw0 224 | 406/Pf4UNA-izQo 225 | 406/T_o_T3LEYLY 226 | 406/ysUibvVCpP8 227 | 406/6H8tPeQGhMY 228 | 412/mV3m2svj3XE 229 | 412/eWBSMD3BiHM 230 | 412/sBJJ0Cj0GG4 231 | 412/U_2DFd2ZMfs 232 | 116/eMsfAhVj2e4 233 | 116/cMzyB4m3VHY 234 | 116/5cn9KJfaQXk 235 | 116/kWLYcM3uVVc 236 | 116/9iH8GK1pcEM 237 | 116/MPCU71Hg-i4 238 | 421/ekgZfuxsz_4 239 | 421/4SnAlRlxlFk 240 | 421/c00gy-NVzaw 241 | 421/bxgdUWKOwtQ 242 | 421/s4CktGpWaZE 243 | 422/rf_mGLJPnDk 244 | 422/lBguj96fa5w 245 | 422/3dUm-m3iFaI 246 | 422/N1-rqFfCm9M 247 | 214/mR0inCVvBzY 248 | 214/ZQGfcC62Pys 249 | 214/RHddz6qeJKk 250 | 214/luDzsPatsGw 251 | 214/aCkbw-aI4xU 252 | 103/gZuDMKXWU_E 253 | 103/o42iehActZo 254 | 103/JxCBGlPgr5o 255 | 103/y4y22RQH05c 256 | 103/rwYaDqXFH88 257 | 103/v174YTbr2N8 258 | 103/zLBRrWd4DTo 259 | 207/nz_LHDf0uqE 260 | 207/wHWDBQ9_7FU 261 | 207/4h33GFHLPNg 262 | 207/sjh57ujp52M 263 | 207/J5Tw7KRnSyc 264 | 324/a4RwXrA1hiE 265 | 324/hAzH-GS4cvc 266 | 324/7r6JQycloEs 267 | 324/Pk88LQ7hxbg 268 | 324/4apR0YypAGc 269 | 324/o9kndEZvsnY 270 | 121/p-PFp1c0FKs 271 | 121/-GlSSp5ZOCQ 272 | 121/2-mxsib6pJo 273 | 121/7R5MVNE-ePU 274 | 121/_Vzpj0cXoSM 275 | 121/hs2h7nb5PHQ 276 | 318/yTPJ_u_qxDU 277 | 318/aYjy__xnegM 278 | 318/2Zr72r4OCe8 279 | 318/EnP2j1caRVs 280 | 318/3z_QhNnSFtM 281 | 409/oC5OvA4BK-E 282 | 409/NXnQys_ejeg 283 | 409/DBgap0YANhs 284 | 409/HdVETeyupXE 285 | 409/so-RuJQY1d0 286 | 112/7NptUiW8hJw 287 | 112/jEo9VXYVrxs 288 | 112/5VnaolWGIy4 289 | 112/UkqQAynrM2g 290 | 112/TgttBprZXDY 291 | 112/e1gtgMczUwE 292 | 304/wii9jNiNl9Y 293 | 304/u95xkc4DfAs 294 | 304/cQ8mt5ACO0A 295 | 304/sSO2wO-yaHw 296 | 404/im-aWyUQGrg 297 | 404/wW_kszdGIJw 298 | 404/1rMT2uMF78E 299 | 404/RFE7qdhjgXc 300 | 404/HJHV2nYz1L8 301 | 111/TAXAVvroOgk 302 | 111/xw9aAfqanDo 303 | 111/wQc0xmPurDc 304 | 111/gXINt_KMK3M 305 | 111/83uz_q4_nyk 306 | 111/JWcAs8biQFU 307 | 111/wokMK-w7XiA 308 | 301/Y2HYSmo4KaI 309 | 301/ntiGX3X-spA 310 | 301/4nxbRG6-sfw 311 | 109/5Oq5giRXtag 312 | 109/vLcBGs389k4 313 | 109/A8eDWlCYaq8 314 | 109/RznLeKVI3yo 315 | 109/NYhsc9ikk4I 316 | 109/WqfselLH4MQ 317 | 104/M8SHMUBnm4A 318 | 104/Nbh64ntT3EM 319 | 104/NZtwPf32YN4 320 | 323/8QblSYQpAoM 321 | 323/mUk0FmDrBb8 322 | 323/tKsGWxiWWCg 323 | 308/jnewhlK2USg 324 | 308/WYAFPvlDB_A 325 | 308/zF3TOfktwd4 326 | 308/Ws7JgPJsVjs 327 | 308/hkVfzjA1HA0 328 | 308/wR8Ybxpnbwc 329 | 224/yWEq4_EG1us 330 | 224/02nUKT0A7uE 331 | 224/nuwCjQVlBrg 332 | 224/E9O9-6TQUw0 333 | 305/PNlctwVmbLY 334 | 305/lpWOv7Y3JHM 335 | 305/_mL1gihKDw0 336 | 305/7-FatJyHj_g 337 | 305/x3if1znl5Fg 338 | 416/nVERaEFJWLQ 339 | 416/oJZUxU9szWA 340 | 416/_xIIpW8iMps 341 | 416/fnbXolhuE7k 342 | 416/a5FoLWnEiAI 343 | 416/DVW7nZeeVlk 344 | 114/IDu5czNIM1w 345 | 114/dMbb10O9hGs 346 | 114/84i8Qdnyd0k 347 | 202/iuQjb1-WAzs 348 | 202/2IcWR76i1bo 349 | 202/oR2QDpoatcQ 350 | 202/eHk6NSLGAkc 351 | 202/awQYyYgulLw 352 | 113/YNpVeU1pVZA 353 | 113/S07Fr83GcBI 354 | 113/RllWJUvrxEY 355 | 113/7jO6rYyhuJk 356 | 113/oDsUh1es_lo 357 | 310/gEYyWqs1oL0 358 | 310/7E8Lj_Ktfok 359 | 310/fpPQcbr5VC0 360 | 310/0EuykeOvGg4 361 | 230/3meb_5kcPFg 362 | 230/vWrOd9Ur0po 363 | 230/x6noOknBPDI 364 | 230/3V4MxH2GuIU 365 | 230/6XBocXgvfTs 366 | 230/3aFiXsrKSoQ 367 | 208/bmxWJNbqCk8 368 | 208/We2CzpjPD3k 369 | 208/LQDP3xm8aRk 370 | 208/186EQzPPHW8 371 | 208/T_fPNAK5Ecg 372 | 208/ffyHeyRpYvo 373 | 401/lRwMt_eHjxU 374 | 401/ikmPrpgWQ5M 375 | 401/6seOEuK0ojg 376 | 401/YhevdroG7a4 377 | 401/mNhj7SA7c4g 378 | 401/NRovp9c9e-4 379 | 401/m9gNbLw0Dcg 380 | 226/lC8B_Yx6Qzg 381 | 226/VH0SmCfAov4 382 | 226/xHr8X2Wpmno 383 | 226/Ew5YKc6xmLE 384 | 226/ffhliBglDhY 385 | 226/05ZSU-5UkXw 386 | 311/BRqTCiAc7uk 387 | 311/B1YQYS9BMdk 388 | 311/lkmVVQIsdEE 389 | 311/vq8C5DTfOKc 390 | 311/W2gnFLOi_AQ 391 | 229/XEifm-iXMvs 392 | 229/HdQzPLk_KiA 393 | 229/cMMoRNhHJrI 394 | 229/cDYCtBwin5g 395 | 229/rKtI8FQGhHo 396 | 229/YA6lhxwLrUI 397 | 413/paiJGvLILKE 398 | 413/r9AtdDfDVmo 399 | 413/m3kFrdCHitg 400 | 102/TF1iWaX2-DM 401 | 102/O7ONcb3qhMU 402 | 102/J7gBorrGvDU 403 | 102/pNAwkqm4t3A 404 | 410/29Wkj1LqaK8 405 | 410/oYLrSflCI2g 406 | 410/86Mb6cYFJig 407 | 222/_ilIn1kmNSA 408 | 222/D7K6_0gtpHQ 409 | 222/RubyHelAHBE 410 | 115/F564e476ULM 411 | 115/igF8D7iE46o 412 | 115/DHpQOhQhW3A 413 | 115/qAoqhmjk3iY 414 | 115/kchoaU2HL-o 415 | 115/9ekEjxd-A_Y 416 | 115/xkKuIlYSMMU 417 | 115/yxjnWx6TaQ8 418 | 223/SkawoKeyNoQ 419 | 223/uYBTguvz4tc 420 | 316/lH7pgsnyGrI 421 | 316/kGxmudExRVk 422 | 211/VwtkHIturro 423 | 211/ZjKY9v48fTc 424 | 211/DeiX_otgD1Q 425 | 211/WlkaUxBwURQ 426 | 211/qkluMpILLdQ 427 | 105/PTUxCvCz8Bc 428 | 105/V53XmPeyjIU 429 | 105/jbjg6w5taGU 430 | 105/FrzEHqqi1RY 431 | 105/wlq30WwXwSM 432 | 105/Dao0vasGPMQ 433 | 126/iq7aiv9MPvA 434 | 126/PHpk4ITk-SE 435 | 126/mhEVgpfF-IU 436 | 126/NK2tAXi3cT4 437 | 126/tGaAAI3aAUs 438 | 205/zqTXQ-YqrgQ 439 | 205/RWtVm_5_D2s 440 | 205/zPCtV7YcmkA 441 | 205/efnHOsT7k9s 442 | 205/InDwfZmSikI 443 | 206/LKrI9pGpM78 444 | 206/cdsDsUcLJZM 445 | 206/UIElE5H_iHc 446 | 206/rWdhkAXfEAY 447 | 206/9Y9_OBnJub0 448 | 206/vSRZRp2Ovqc 449 | 227/TfITvKr5M3k 450 | 227/4ZbNtfqKkiI 451 | 227/sv8jRCmi3Ro 452 | 227/eYOn2ZVB4nc 453 | 403/JqjwJIV6pI0 454 | 404/XXUmUPDosYQ 455 | 410/p-gN4cbmunQ 456 | 422/Ky0zf0v2F5A 457 | 401/sGzBQrg1adY 458 | -------------------------------------------------------------------------------- /figures/network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GX77/TextKG/2dec9157c1ebcdfecd0e3d1a1b769c195c4d9245/figures/network.png -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GX77/TextKG/2dec9157c1ebcdfecd0e3d1a1b769c195c4d9245/src/__init__.py -------------------------------------------------------------------------------- /src/build_vocab.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import nltk 4 | import numpy as np 5 | from tqdm import tqdm 6 | import torch 7 | from src.utils import load_json, flat_list_of_lists, save_json 8 | from rtransformer.recursive_caption_dataset import RecursiveCaptionDataset as RCDataset 9 | 10 | 11 | def build_vocab_idx(word_insts, min_word_count): 12 | full_vocab = set(w for sent in word_insts for w in sent) 13 | print("[Info] Original Vocabulary size =", len(full_vocab)) 14 | 15 | word2idx = { 16 | RCDataset.PAD_TOKEN: RCDataset.PAD, 17 | RCDataset.CLS_TOKEN: RCDataset.CLS, 18 | RCDataset.SEP_TOKEN: RCDataset.SEP, 19 | RCDataset.VID_TOKEN: RCDataset.VID, 20 | RCDataset.BOS_TOKEN: RCDataset.BOS, 21 | RCDataset.EOS_TOKEN: RCDataset.EOS, 22 | RCDataset.UNK_TOKEN: RCDataset.UNK, 23 | } 24 | 25 | word_count = {w: 0 for w in full_vocab} 26 | 27 | for sent in word_insts: 28 | for word in sent: 29 | word_count[word] += 1 30 | 31 | ignored_word_count = 0 32 | for word, count in word_count.items(): 33 | if word not in word2idx: 34 | if count > min_word_count: 35 | word2idx[word] = len(word2idx) 36 | else: 37 | ignored_word_count += 1 38 | 39 | print("[Info] Trimmed vocabulary size = {},".format(len(word2idx)), 40 | "each with minimum occurrence = {}".format(min_word_count)) 41 | print("[Info] Ignored word count = {}".format(ignored_word_count)) 42 | return word2idx 43 | 44 | 45 | def load_transform_data(data_path): 46 | data = load_json(data_path) 47 | transformed_data = [] 48 | for v_id, cap in data.items(): 49 | cap["v_id"] = v_id 50 | transformed_data.append(cap) 51 | return transformed_data 52 | 53 | 54 | def load_glove(filename): 55 | """ returns { word (str) : vector_embedding (torch.FloatTensor) } 56 | """ 57 | glove = {} 58 | with open(filename,encoding='utf-8') as f: 59 | for line in f.readlines(): 60 | values = line.strip("\n").split(" ") # space separator 61 | word = values[0] 62 | vector = np.asarray([float(e) for e in values[1:]]) 63 | glove[word] = vector 64 | return glove 65 | 66 | 67 | def extract_glove(word2idx, raw_glove_path, vocab_glove_path, glove_dim=300): 68 | # Make glove embedding. 69 | print("Loading glove embedding at path : {}.\n".format(raw_glove_path)) 70 | glove_full = load_glove(raw_glove_path) 71 | print("Glove Loaded, building word2idx, idx2word mapping.\n") 72 | idx2word = {v: k for k, v in word2idx.items()} 73 | 74 | glove_matrix = np.zeros([len(word2idx), glove_dim]) 75 | glove_keys = glove_full.keys() 76 | for i in tqdm(range(len(idx2word))): 77 | w = idx2word[i] 78 | w_embed = glove_full[w] if w in glove_keys else np.random.randn(glove_dim) * 0.4 79 | glove_matrix[i, :] = w_embed 80 | print("vocab embedding size is :", glove_matrix.shape) 81 | torch.save(glove_matrix, vocab_glove_path) 82 | 83 | 84 | def main(): 85 | parser = argparse.ArgumentParser() 86 | parser.add_argument("--train_path", type=str, default='../densevid_eval/yc2_data/yc2_train_anet_format.json') 87 | parser.add_argument("--dset_name", type=str, default="yc2", choices=["anet", "yc2"]) 88 | parser.add_argument("--cache", type=str, default="./cache") 89 | parser.add_argument("--min_word_count", type=int, default=3) 90 | parser.add_argument("--raw_glove_path", type=str, default='/mnt/Video_feature/path/glove.6B.300d.txt') 91 | 92 | opt = parser.parse_args() 93 | if not os.path.exists(opt.cache): 94 | os.makedirs(opt.cache) 95 | 96 | # load, merge, clean, split data 97 | train_data = load_json(opt.train_path) 98 | all_sentences = flat_list_of_lists([v["sentences"] for k, v in train_data.items()]) 99 | all_sentences = [nltk.tokenize.word_tokenize(sen.lower()) 100 | for sen in all_sentences] 101 | word2idx = build_vocab_idx(all_sentences, opt.min_word_count) 102 | print("[Info] Dumping the processed data to json file", opt.cache) 103 | word2idx_path = os.path.join(opt.cache, "{}_word2idx_3.json".format(opt.dset_name)) 104 | save_json(word2idx, word2idx_path, save_pretty=True) 105 | print("[Info] Finish.") 106 | 107 | vocab_glove_path = os.path.join(opt.cache, "{}_vocab_glove_3.pt".format(opt.dset_name)) 108 | extract_glove(word2idx, opt.raw_glove_path, vocab_glove_path) 109 | 110 | 111 | if __name__ == "__main__": 112 | main() 113 | -------------------------------------------------------------------------------- /src/caption_eval.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2022/11/21 18:44 3 | # @Author : Yaojie Shen 4 | # @Project : MM-Video 5 | # @File : caption_test.py 6 | import os 7 | import logging 8 | from collections import defaultdict 9 | from tqdm import tqdm 10 | 11 | import torch.distributed as dist 12 | import torch.utils.data 13 | 14 | from utils.train_utils import gather_object_multiple_gpu, get_timestamp, CudaPreFetcher 15 | from utils.json import save_json, load_json 16 | from utils.train_utils import Timer 17 | 18 | from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer 19 | from pycocoevalcap.bleu.bleu import Bleu 20 | from pycocoevalcap.meteor.meteor import Meteor 21 | from pycocoevalcap.rouge.rouge import Rouge 22 | from pycocoevalcap.cider.cider import Cider 23 | from pycocoevalcap.spice.spice import Spice 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | class Translator(object): 29 | """Load with trained model and handle the beam search""" 30 | 31 | def __init__(self, checkpoint, model=None): 32 | self.max_v_len = checkpoint['opt'].max_v_len 33 | self.max_t_len = checkpoint['opt'].max_t_len 34 | self.PAD = 0 35 | self.BOS = 4 36 | 37 | self.model = model 38 | self.model.eval() 39 | 40 | self.timer = Timer(synchronize=True, history_size=500, precision=3) 41 | 42 | def translate_batch_single_sentence_greedy(self, inputs, model): 43 | inputs_ids = inputs["input_ids"] 44 | input_masks = inputs["input_mask"] 45 | # max_t_len = self.max_t_len 46 | max_t_len = 21 47 | inputs_ids[:, :] = 0. 48 | input_masks[:, :] = 0. 49 | assert torch.sum(input_masks[:, :]) == 0, "Initially, all text tokens should be masked" 50 | bsz = len(inputs_ids) 51 | next_symbols = torch.IntTensor([self.BOS] * bsz) # (N, ) 52 | 53 | self.timer.reset() 54 | for dec_idx in range(max_t_len): 55 | inputs_ids[:, dec_idx] = next_symbols.clone() 56 | input_masks[:, dec_idx] = 1 57 | outputs = model(inputs) 58 | pred_scores = outputs["prediction_scores"] 59 | # pred_scores[:, :, 49406] = -1e10 60 | next_words = pred_scores[:, dec_idx].max(1)[1] # TODO / NOTE changed 61 | next_symbols = next_words.cpu() 62 | if "visual_output" in outputs: 63 | inputs["visual_output"] = outputs["visual_output"] 64 | else: 65 | logger.debug("visual_output is not in the output of model, this may slow down the caption test") 66 | self.timer(stage_name="inference") 67 | logger.debug(f"inference toke {self.timer.get_info()['average']['inference']} ms") 68 | return inputs_ids 69 | 70 | def translate_batch(self, model_inputs): 71 | """while we used *_list as the input names, they could be non-list for single sentence decoding case""" 72 | return self.translate_batch_single_sentence_greedy(model_inputs, self.model) 73 | 74 | 75 | def convert_ids_to_sentence(tokens): 76 | from .clip.clip import _tokenizer 77 | text = _tokenizer.decode(tokens) 78 | text_list = text.split(" ") 79 | new = [] 80 | for i in range(len(text_list)): 81 | if i == 0: 82 | new.append(text_list[i].split(">")[-1]) 83 | elif "<|endoftext|>" in text_list[i]: 84 | break 85 | else: 86 | new.append(text_list[i]) 87 | return " ".join(new) 88 | 89 | 90 | def run_translate(data_loader, translator, epoch, opt): 91 | # submission template 92 | batch_res = {"version": "VERSION 1.0", 93 | "results": defaultdict(list), 94 | "external_data": {"used": "true", "details": "ay"}} 95 | for bid, batch in enumerate(tqdm(data_loader, 96 | dynamic_ncols=True, 97 | disable=dist.is_initialized() and dist.get_rank() != 0)): 98 | if torch.cuda.is_available(): 99 | batch = CudaPreFetcher.cuda(batch) 100 | dec_seq = translator.translate_batch(batch) 101 | 102 | # example_idx indicates which example is in the batch 103 | for example_idx, (cur_gen_sen, cur_meta) in enumerate(zip(dec_seq, batch['metadata'][1])): 104 | cur_data = { 105 | "sentence": convert_ids_to_sentence(cur_gen_sen.tolist()), 106 | "gt_sentence": cur_meta 107 | } 108 | # print(cur_data) 109 | batch_res["results"][batch['metadata'][0][example_idx].split("video")[-1]].append(cur_data) 110 | translator.timer.print() 111 | return batch_res 112 | 113 | 114 | class EvalCap: 115 | def __init__(self, annos, rests, cls_tokenizer=PTBTokenizer, 116 | use_scorers=['Bleu', 'METEOR', 'ROUGE_L', 'CIDEr']): # ,'SPICE']): 117 | self.evalImgs = [] 118 | self.eval = {} 119 | self.imgToEval = {} 120 | self.annos = annos 121 | self.rests = rests 122 | self.Tokenizer = cls_tokenizer 123 | self.use_scorers = use_scorers 124 | 125 | def evaluate(self): 126 | res = {} 127 | for r in self.rests: 128 | res[str(r['image_id'])] = [{'caption': r['caption']}] 129 | 130 | gts = {} 131 | for imgId in self.annos: 132 | gts[str(imgId)] = [{'caption': self.annos[imgId]}] 133 | 134 | # ================================================= 135 | # Set up scorers 136 | # ================================================= 137 | # print('tokenization...') 138 | tokenizer = self.Tokenizer() 139 | gts = tokenizer.tokenize(gts) 140 | res = tokenizer.tokenize(res) 141 | # gts = {k: v for k, v in gts.items() if k in res.keys()} 142 | # ================================================= 143 | # Set up scorers 144 | # ================================================= 145 | # print('setting up scorers...') 146 | use_scorers = self.use_scorers 147 | scorers = [] 148 | if 'Bleu' in use_scorers: 149 | scorers.append((Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"])) 150 | if 'METEOR' in use_scorers: 151 | scorers.append((Meteor(), "METEOR")) 152 | if 'ROUGE_L' in use_scorers: 153 | scorers.append((Rouge(), "ROUGE_L")) 154 | if 'CIDEr' in use_scorers: 155 | scorers.append((Cider(), "CIDEr")) 156 | if 'SPICE' in use_scorers: 157 | scorers.append((Spice(), "SPICE")) 158 | 159 | # ================================================= 160 | # Compute scores 161 | # ================================================= 162 | for scorer, method in scorers: 163 | score, scores = scorer.compute_score(gts, res) 164 | if type(method) == list: 165 | for sc, scs, m in zip(score, scores, method): 166 | self.setEval(sc, m) 167 | self.setImgToEvalImgs(scs, gts.keys(), m) 168 | # print("%s: %0.1f" % (m, sc*100)) 169 | else: 170 | self.setEval(score, method) 171 | self.setImgToEvalImgs(scores, gts.keys(), method) 172 | # print("%s: %0.1f" % (method, score*100)) 173 | self.setEvalImgs() 174 | 175 | def setEval(self, score, method): 176 | self.eval[method] = score 177 | 178 | def setImgToEvalImgs(self, scores, imgIds, method): 179 | for imgId, score in zip(imgIds, scores): 180 | if not imgId in self.imgToEval: 181 | self.imgToEval[imgId] = {} 182 | self.imgToEval[imgId]["image_id"] = imgId 183 | self.imgToEval[imgId][method] = score 184 | 185 | def setEvalImgs(self): 186 | self.evalImgs = [eval for imgId, eval in self.imgToEval.items()] 187 | 188 | 189 | def evaluate(submission, reference): 190 | tokenizer = PTBTokenizer # for English 191 | annos = reference 192 | data = submission['results'] 193 | rests = [] 194 | for name, value in data.items(): 195 | rests.append({'image_id': str(name), 'caption': value[0]['sentence']}) 196 | eval_cap = EvalCap(annos, rests, tokenizer) 197 | 198 | eval_cap.evaluate() 199 | 200 | all_score = {} 201 | for metric, score in eval_cap.eval.items(): 202 | all_score[metric] = score 203 | return all_score 204 | 205 | if __name__ == "__main__": 206 | ours = load_json("") 207 | gt = load_json("") 208 | evaluate(ours, gt) -------------------------------------------------------------------------------- /src/rtransformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GX77/TextKG/2dec9157c1ebcdfecd0e3d1a1b769c195c4d9245/src/rtransformer/__init__.py -------------------------------------------------------------------------------- /src/rtransformer/beam_search.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/translate/beam_search.py 3 | """ 4 | import torch 5 | 6 | from rtransformer.decode_strategy import DecodeStrategy, length_penalty_builder 7 | from rtransformer.recursive_caption_dataset import RecursiveCaptionDataset as RCDataset 8 | 9 | import logging 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class BeamSearch(DecodeStrategy): 14 | """Generation beam search. 15 | 16 | Note that the attributes list is not exhaustive. Rather, it highlights 17 | tensors to document their shape. (Since the state variables' "batch" 18 | size decreases as beams finish, we denote this axis with a B rather than 19 | ``batch_size``). 20 | 21 | Args: 22 | beam_size (int): Number of beams to use (see base ``parallel_paths``). 23 | batch_size (int): See base. 24 | pad (int): See base. 25 | bos (int): See base. 26 | eos (int): See base. 27 | n_best (int): Don't stop until at least this many beams have 28 | reached EOS. 29 | mb_device (torch.device or str): See base ``device``. 30 | min_length (int): See base. 31 | max_length (int): See base. 32 | block_ngram_repeat (int): See base. 33 | exclusion_tokens (set[int]): See base. 34 | 35 | Attributes: 36 | top_beam_finished (ByteTensor): Shape ``(B,)``. 37 | _batch_offset (LongTensor): Shape ``(B,)``. 38 | _beam_offset (LongTensor): Shape ``(batch_size x beam_size,)``. 39 | alive_seq (LongTensor): See base. 40 | topk_log_probs (FloatTensor): Shape ``(B x beam_size,)``. These 41 | are the scores used for the topk operation. 42 | select_indices (LongTensor or NoneType): Shape 43 | ``(B x beam_size,)``. This is just a flat view of the 44 | ``_batch_index``. 45 | topk_scores (FloatTensor): Shape 46 | ``(B, beam_size)``. These are the 47 | scores a sequence will receive if it finishes. 48 | topk_ids (LongTensor): Shape ``(B, beam_size)``. These are the 49 | word indices of the topk predictions. 50 | _batch_index (LongTensor): Shape ``(B, beam_size)``. 51 | _prev_penalty (FloatTensor or NoneType): Shape 52 | ``(B, beam_size)``. Initialized to ``None``. 53 | _coverage (FloatTensor or NoneType): Shape 54 | ``(1, B x beam_size, inp_seq_len)``. 55 | hypotheses (list[list[Tuple[Tensor]]]): Contains a tuple 56 | of score (float), sequence (long), and attention (float or None). 57 | """ 58 | 59 | def __init__(self, beam_size, batch_size, pad, bos, eos, n_best, mb_device, 60 | min_length, max_length, block_ngram_repeat, exclusion_tokens, 61 | length_penalty_name=None, length_penalty_alpha=0.): 62 | super(BeamSearch, self).__init__( 63 | pad, bos, eos, batch_size, mb_device, beam_size, min_length, 64 | block_ngram_repeat, exclusion_tokens, max_length) 65 | # beam parameters 66 | self.beam_size = beam_size 67 | self.n_best = n_best 68 | self.batch_size = batch_size 69 | self.length_penalty_name = length_penalty_name 70 | self.length_penalty_func = length_penalty_builder(length_penalty_name) 71 | self.length_penalty_alpha = length_penalty_alpha 72 | 73 | # result caching 74 | self.hypotheses = [[] for _ in range(batch_size)] 75 | 76 | # beam state 77 | self.top_beam_finished = torch.zeros([batch_size], dtype=torch.uint8) 78 | self.best_scores = torch.full([batch_size], -1e10, dtype=torch.float, 79 | device=mb_device) # (N, ) 80 | 81 | self._batch_offset = torch.arange(batch_size, dtype=torch.long) # (N, ) 82 | self._beam_offset = torch.arange( 83 | 0, batch_size * beam_size, step=beam_size, dtype=torch.long, 84 | device=mb_device) # (N, ) 85 | self.topk_log_probs = torch.tensor( 86 | [0.0] + [float("-inf")] * (beam_size - 1), device=mb_device 87 | ).repeat(batch_size) # (B*N), guess: store the current beam probabilities 88 | self.select_indices = None 89 | 90 | # buffers for the topk scores and 'backpointer' 91 | self.topk_scores = torch.empty((batch_size, beam_size), 92 | dtype=torch.float, device=mb_device) # (N, B) 93 | self.topk_ids = torch.empty((batch_size, beam_size), dtype=torch.long, 94 | device=mb_device) # (N, B) 95 | self._batch_index = torch.empty([batch_size, beam_size], 96 | dtype=torch.long, device=mb_device) # (N, B) 97 | self.done = False 98 | # "global state" of the old beam 99 | self._prev_penalty = None 100 | self._coverage = None 101 | 102 | @property 103 | def current_predictions(self): 104 | return self.alive_seq[:, -1] 105 | 106 | @property 107 | def current_origin(self): 108 | return self.select_indices 109 | 110 | @property 111 | def current_backptr(self): 112 | # for testing 113 | return self.select_indices.view(self.batch_size, self.beam_size)\ 114 | .fmod(self.beam_size) 115 | 116 | def advance(self, log_probs): 117 | """ current step log_probs (N * B, vocab_size), attn (1, N * B, L) 118 | Which attention is this??? Guess: the one with the encoder outputs 119 | """ 120 | vocab_size = log_probs.size(-1) 121 | 122 | # using integer division to get an integer _B without casting 123 | _B = log_probs.shape[0] // self.beam_size # batch_size 124 | 125 | # force the output to be longer than self.min_length, 126 | # by setting prob(EOS) to be a very small number when < min_length 127 | self.ensure_min_length(log_probs) 128 | 129 | # Multiply probs by the beam probability. 130 | # logger.info("after log_probs {} {}".format(log_probs.shape, log_probs)) 131 | log_probs += self.topk_log_probs.view(_B * self.beam_size, 1) 132 | # logger.info("after log_probs {} {}".format(log_probs.shape, log_probs)) 133 | 134 | self.block_ngram_repeats(log_probs) 135 | 136 | # if the sequence ends now, then the penalty is the current 137 | # length + 1, to include the EOS token, length_penalty is a float number 138 | step = len(self) 139 | length_penalty = self.length_penalty_func(step+1, self.length_penalty_alpha) 140 | 141 | # Flatten probs into a list of possibilities. 142 | # pick topk in all the paths 143 | curr_scores = log_probs / length_penalty 144 | curr_scores = curr_scores.reshape(_B, self.beam_size * vocab_size) 145 | # self.topk_scores and self.topk_ids => (N, B) 146 | torch.topk(curr_scores, self.beam_size, dim=-1, 147 | out=(self.topk_scores, self.topk_ids)) 148 | 149 | # Recover log probs. 150 | # Length penalty is just a scalar. It doesn't matter if it's applied 151 | # before or after the topk. 152 | torch.mul(self.topk_scores, length_penalty, out=self.topk_log_probs) 153 | 154 | # Resolve beam origin and map to batch index flat representation. 155 | torch.div(self.topk_ids, vocab_size, out=self._batch_index) # _batch_index (N * B) 156 | self._batch_index += self._beam_offset[:_B].unsqueeze(1) 157 | self.select_indices = self._batch_index.view(_B * self.beam_size) 158 | 159 | self.topk_ids.fmod_(vocab_size) # resolve true word ids 160 | 161 | # Append last prediction. 162 | self.alive_seq = torch.cat( 163 | [self.alive_seq.index_select(0, self.select_indices), 164 | self.topk_ids.view(_B * self.beam_size, 1)], -1) # (N * B, step_size) 165 | 166 | self.is_finished = self.topk_ids.eq(self.eos) # (N, B) 167 | self.ensure_max_length() 168 | 169 | def update_finished(self): 170 | # Penalize beams that finished. 171 | _B_old = self.topk_log_probs.shape[0] # batch_size might be changing??? as the beams finished 172 | step = self.alive_seq.shape[-1] # 1 greater than the step in advance, as we advanced 1 step 173 | self.topk_log_probs.masked_fill_(self.is_finished, -1e10) 174 | # on real data (newstest2017) with the pretrained transformer, 175 | # it's faster to not move this back to the original device 176 | self.is_finished = self.is_finished.to('cpu') # (N, B) 177 | self.top_beam_finished |= self.is_finished[:, 0].eq(1) # (N, ) initialized as zeros 178 | predictions = self.alive_seq.view(_B_old, self.beam_size, step) 179 | non_finished_batch = [] 180 | for i in range(self.is_finished.size(0)): # (N, ) 181 | b = self._batch_offset[i] # (0, ..., N-1) 182 | finished_hyp = self.is_finished[i].nonzero().view(-1) 183 | # Store finished hypotheses for this batch. 184 | for j in finished_hyp: 185 | self.hypotheses[b].append([self.topk_scores[i, j], 186 | predictions[i, j, 1:]]) 187 | # End condition is the top beam finished and we can return 188 | # n_best hypotheses. 189 | finish_flag = self.top_beam_finished[i] != 0 190 | if finish_flag and len(self.hypotheses[b]) >= self.n_best: 191 | best_hyp = sorted( 192 | self.hypotheses[b], key=lambda x: x[0], reverse=True) # sort by scores 193 | for n, (score, pred) in enumerate(best_hyp): 194 | if n >= self.n_best: 195 | break 196 | self.scores[b].append(score) 197 | self.predictions[b].append(pred) 198 | else: 199 | non_finished_batch.append(i) 200 | non_finished = torch.tensor(non_finished_batch) 201 | # If all sentences are translated, no need to go further. 202 | if len(non_finished) == 0: 203 | self.done = True 204 | return 205 | 206 | _B_new = non_finished.shape[0] 207 | # Remove finished batches for the next step. (Not finished beam!!!) 208 | self.top_beam_finished = self.top_beam_finished.index_select( 209 | 0, non_finished) 210 | self._batch_offset = self._batch_offset.index_select(0, non_finished) 211 | non_finished = non_finished.to(self.topk_ids.device) 212 | self.topk_log_probs = self.topk_log_probs.index_select(0, 213 | non_finished) 214 | self._batch_index = self._batch_index.index_select(0, non_finished) 215 | self.select_indices = self._batch_index.view(_B_new * self.beam_size) 216 | self.alive_seq = predictions.index_select(0, non_finished) \ 217 | .view(-1, self.alive_seq.size(-1)) 218 | self.topk_scores = self.topk_scores.index_select(0, non_finished) 219 | self.topk_ids = self.topk_ids.index_select(0, non_finished) 220 | -------------------------------------------------------------------------------- /src/rtransformer/decode_strategy.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/translate/decode_strategy.py 3 | """ 4 | import torch 5 | import logging 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | class DecodeStrategy(object): 10 | """Base class for generation strategies. 11 | 12 | Args: 13 | pad (int): Magic integer in output vocab. 14 | bos (int): Magic integer in output vocab. 15 | eos (int): Magic integer in output vocab. 16 | batch_size (int): Current batch size. 17 | device (torch.device or str): Device for memory bank (encoder). 18 | parallel_paths (int): Decoding strategies like beam search 19 | use parallel paths. Each batch is repeated ``parallel_paths`` 20 | times in relevant state tensors. 21 | min_length (int): Shortest acceptable generation, not counting 22 | begin-of-sentence or end-of-sentence. 23 | max_length (int): Longest acceptable sequence, not counting 24 | begin-of-sentence (presumably there has been no EOS 25 | yet if max_length is used as a cutoff). 26 | block_ngram_repeat (int): Block beams where 27 | ``block_ngram_repeat``-grams repeat. 28 | exclusion_tokens (set[int]): If a gram contains any of these 29 | tokens, it may repeat. 30 | 31 | Attributes: 32 | pad (int): See above. 33 | bos (int): See above. 34 | eos (int): See above. 35 | predictions (list[list[LongTensor]]): For each batch, holds a 36 | list of beam prediction sequences. 37 | scores (list[list[FloatTensor]]): For each batch, holds a 38 | list of scores. 39 | attention (list[list[FloatTensor or list[]]]): For each 40 | batch, holds a list of attention sequence tensors 41 | (or empty lists) having shape ``(step, inp_seq_len)`` where 42 | ``inp_seq_len`` is the length of the sample (not the max 43 | length of all inp seqs). 44 | alive_seq (LongTensor): Shape ``(B x parallel_paths, step)``. 45 | This sequence grows in the ``step`` axis on each call to 46 | :func:`advance()`. 47 | is_finished (ByteTensor or NoneType): Shape 48 | ``(B, parallel_paths)``. Initialized to ``None``. 49 | alive_attn (FloatTensor or NoneType): If tensor, shape is 50 | ``(step, B x parallel_paths, inp_seq_len)``, where ``inp_seq_len`` 51 | is the (max) length of the input sequence. 52 | min_length (int): See above. 53 | max_length (int): See above. 54 | block_ngram_repeat (int): See above. 55 | exclusion_tokens (set[int]): See above. 56 | done (bool): See above. 57 | """ 58 | 59 | def __init__(self, pad, bos, eos, batch_size, device, parallel_paths, 60 | min_length, block_ngram_repeat, exclusion_tokens, max_length): 61 | 62 | # magic indices 63 | self.pad = pad 64 | self.bos = bos 65 | self.eos = eos 66 | 67 | # result caching 68 | self.predictions = [[] for _ in range(batch_size)] 69 | self.scores = [[] for _ in range(batch_size)] 70 | self.attention = [[] for _ in range(batch_size)] 71 | 72 | self.alive_seq = torch.full( 73 | [batch_size * parallel_paths, 1], self.bos, 74 | dtype=torch.long, device=device) # (N * B, step_size=1) 75 | self.is_finished = torch.zeros( 76 | [batch_size, parallel_paths], 77 | dtype=torch.uint8, device=device) 78 | self.alive_attn = None 79 | 80 | self.min_length = min_length 81 | self.max_length = max_length 82 | self.block_ngram_repeat = block_ngram_repeat 83 | self.exclusion_tokens = exclusion_tokens 84 | 85 | self.done = False 86 | 87 | def __len__(self): 88 | return self.alive_seq.shape[1] # steps length 89 | 90 | def ensure_min_length(self, log_probs): 91 | if len(self) <= self.min_length: 92 | log_probs[:, self.eos] = -1e20 93 | 94 | def ensure_max_length(self): 95 | # add one to account for BOS. Don't account for EOS because hitting 96 | # this implies it hasn't been found. 97 | if len(self) == self.max_length + 1: 98 | self.is_finished.fill_(1) 99 | 100 | def block_ngram_repeats(self, log_probs): 101 | # log_probs (N * B, vocab_size) 102 | cur_len = len(self) 103 | if self.block_ngram_repeat > 0 and cur_len > 1: 104 | for path_idx in range(self.alive_seq.shape[0]): # N * B 105 | # skip BOS 106 | hyp = self.alive_seq[path_idx, 1:] 107 | ngrams = set() 108 | fail = False 109 | gram = [] 110 | for i in range(cur_len - 1): 111 | # Last n tokens, n = block_ngram_repeat 112 | gram = (gram + [hyp[i].item()])[-self.block_ngram_repeat:] 113 | # skip the blocking if any token in gram is excluded 114 | if set(gram) & self.exclusion_tokens: 115 | continue 116 | if tuple(gram) in ngrams: 117 | fail = True 118 | ngrams.add(tuple(gram)) 119 | if fail: 120 | log_probs[path_idx] = -10e20 # all the words in this path 121 | 122 | def advance(self, log_probs): 123 | """DecodeStrategy subclasses should override :func:`advance()`. 124 | 125 | Advance is used to update ``self.alive_seq``, ``self.is_finished``, 126 | and, when appropriate, ``self.alive_attn``. 127 | """ 128 | 129 | raise NotImplementedError() 130 | 131 | def update_finished(self): 132 | """DecodeStrategy subclasses should override :func:`update_finished()`. 133 | 134 | ``update_finished`` is used to update ``self.predictions``, 135 | ``self.scores``, and other "output" attributes. 136 | """ 137 | 138 | raise NotImplementedError() 139 | 140 | 141 | def length_penalty_builder(length_penalty_name="none"): 142 | """implement length penalty""" 143 | def length_wu(cur_len, alpha=0.): 144 | """GNMT length re-ranking score. 145 | See "Google's Neural Machine Translation System" :cite:`wu2016google`. 146 | """ 147 | return ((5 + cur_len) / 6.0) ** alpha 148 | 149 | def length_average(cur_len, alpha=0.): 150 | """Returns the current sequence length.""" 151 | return cur_len 152 | 153 | def length_none(cur_len, alpha=0.): 154 | """Returns unmodified scores.""" 155 | return 1.0 156 | 157 | if length_penalty_name == "none": 158 | return length_none 159 | elif length_penalty_name == "wu": 160 | return length_wu 161 | elif length_penalty_name == "avg": 162 | return length_average 163 | else: 164 | raise NotImplementedError 165 | -------------------------------------------------------------------------------- /src/rtransformer/masked_transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | 7 | References: 8 | https://github.com/salesforce/densecap/blob/master/model/transformer.py 9 | 10 | Modified by Jie Lei 11 | """ 12 | 13 | import torch 14 | from torch import nn 15 | from torch.nn import functional as F 16 | 17 | import math 18 | import numpy as np 19 | from rtransformer.model import LabelSmoothingLoss 20 | 21 | 22 | INF = 1e10 23 | 24 | 25 | def positional_encodings_like(x, t=None): 26 | if t is None: 27 | positions = torch.arange(0, x.size(1)).float() 28 | if x.is_cuda: 29 | positions = positions.cuda(x.get_device()) 30 | else: 31 | positions = t 32 | encodings = torch.zeros(*x.size()[1:]) 33 | if x.is_cuda: 34 | encodings = encodings.cuda(x.get_device()) 35 | 36 | for channel in range(x.size(-1)): 37 | if channel % 2 == 0: 38 | encodings[:, channel] = torch.sin(positions / 10000 ** (channel / x.size(2))) 39 | else: 40 | encodings[:, channel] = torch.cos(positions / 10000 ** ((channel - 1) / x.size(2))) 41 | return encodings 42 | 43 | 44 | class LayerNorm(nn.Module): 45 | def __init__(self, d_model, eps=1e-6): 46 | super(LayerNorm, self).__init__() 47 | self.gamma = nn.Parameter(torch.ones(d_model)) 48 | self.beta = nn.Parameter(torch.zeros(d_model)) 49 | self.eps = eps 50 | 51 | def forward(self, x): 52 | mean = x.mean(-1, keepdim=True) 53 | std = x.std(-1, keepdim=True) 54 | return self.gamma * (x - mean) / (std + self.eps) + self.beta 55 | 56 | 57 | class ResidualBlock(nn.Module): 58 | def __init__(self, layer, d_model, drop_ratio): 59 | super(ResidualBlock, self).__init__() 60 | self.layer = layer 61 | self.dropout = nn.Dropout(drop_ratio) 62 | self.layernorm = LayerNorm(d_model) 63 | 64 | def forward(self, *x): 65 | return self.layernorm(x[0] + self.dropout(self.layer(*x))) 66 | 67 | 68 | class Attention(nn.Module): 69 | 70 | def __init__(self, d_key, drop_ratio, causal): 71 | super(Attention, self).__init__() 72 | self.scale = math.sqrt(d_key) 73 | self.dropout = nn.Dropout(drop_ratio) 74 | self.causal = causal 75 | 76 | def forward(self, query, key, value): 77 | dot_products = torch.bmm(query, key.transpose(1, 2)) 78 | if query.dim() == 3 and (self is None or self.causal): 79 | tri = torch.ones(key.size(1), key.size(1)).triu(1) * INF 80 | if key.is_cuda: 81 | tri = tri.cuda(key.get_device()) 82 | dot_products.data.sub_(tri.unsqueeze(0)) 83 | return torch.bmm(self.dropout(F.softmax(dot_products / self.scale, dim=-1)), value) 84 | 85 | 86 | class MultiHead(nn.Module): 87 | def __init__(self, d_key, d_value, n_heads, drop_ratio, causal=False): 88 | super(MultiHead, self).__init__() 89 | self.attention = Attention(d_key, drop_ratio, causal=causal) 90 | self.wq = nn.Linear(d_key, d_key, bias=False) 91 | self.wk = nn.Linear(d_key, d_key, bias=False) 92 | self.wv = nn.Linear(d_value, d_value, bias=False) 93 | self.wo = nn.Linear(d_value, d_key, bias=False) 94 | self.n_heads = n_heads 95 | 96 | def forward(self, query, key, value): 97 | query, key, value = self.wq(query), self.wk(key), self.wv(value) 98 | query, key, value = ( 99 | x.chunk(self.n_heads, -1) for x in (query, key, value)) 100 | return self.wo(torch.cat([self.attention(q, k, v) 101 | for q, k, v in zip(query, key, value)], -1)) 102 | 103 | 104 | class FeedForward(nn.Module): 105 | def __init__(self, d_model, d_hidden): 106 | super(FeedForward, self).__init__() 107 | self.linear1 = nn.Linear(d_model, d_hidden) 108 | self.linear2 = nn.Linear(d_hidden, d_model) 109 | 110 | def forward(self, x): 111 | return self.linear2(F.relu(self.linear1(x))) 112 | 113 | 114 | class EncoderLayer(nn.Module): 115 | def __init__(self, d_model, d_hidden, n_heads, drop_ratio): 116 | super(EncoderLayer, self).__init__() 117 | self.selfattn = ResidualBlock( 118 | MultiHead(d_model, d_model, n_heads, drop_ratio, causal=False), 119 | d_model, drop_ratio) 120 | self.feedforward = ResidualBlock(FeedForward(d_model, d_hidden), 121 | d_model, drop_ratio) 122 | 123 | def forward(self, x): 124 | return self.feedforward(self.selfattn(x, x, x)) 125 | 126 | 127 | class DecoderLayer(nn.Module): 128 | def __init__(self, d_model, d_hidden, n_heads, drop_ratio): 129 | super(DecoderLayer, self).__init__() 130 | self.selfattn = ResidualBlock( 131 | MultiHead(d_model, d_model, n_heads, drop_ratio, causal=True), 132 | d_model, drop_ratio) 133 | self.attention = ResidualBlock( 134 | MultiHead(d_model, d_model, n_heads, drop_ratio), 135 | d_model, drop_ratio) 136 | self.feedforward = ResidualBlock(FeedForward(d_model, d_hidden), 137 | d_model, drop_ratio) 138 | 139 | def forward(self, x, encoding): 140 | """ 141 | Args: 142 | x: (N, Lt, D) 143 | encoding: (N, Lv, D) 144 | """ 145 | x = self.selfattn(x, x, x) # (N, Lt, D) 146 | return self.feedforward(self.attention(x, encoding, encoding)) # (N, Lt, D) 147 | 148 | 149 | class Encoder(nn.Module): 150 | def __init__(self, vfeat_size, d_model, d_hidden, n_layers, n_heads, drop_ratio): 151 | super(Encoder, self).__init__() 152 | self.video_embeddings = nn.Sequential( 153 | LayerNorm(vfeat_size), 154 | nn.Dropout(drop_ratio), 155 | nn.Linear(vfeat_size, d_model) 156 | ) 157 | self.layers = nn.ModuleList( 158 | [EncoderLayer(d_model, d_hidden, n_heads, drop_ratio) 159 | for i in range(n_layers)]) 160 | self.dropout = nn.Dropout(drop_ratio) 161 | 162 | def forward(self, x, mask=None): 163 | """ 164 | 165 | Args: 166 | x: (N, Lv, Dv) 167 | mask: (N, Lv) 168 | 169 | Returns: 170 | 171 | """ 172 | x = self.video_embeddings(x) # (N, Lv, D) 173 | x = x + positional_encodings_like(x) 174 | x = self.dropout(x) 175 | mask.unsqueeze_(-1) 176 | if mask is not None: 177 | x = x*mask 178 | encoding = [] 179 | for layer in self.layers: 180 | x = layer(x) 181 | if mask is not None: 182 | x = x*mask 183 | encoding.append(x) 184 | return encoding 185 | 186 | 187 | class Decoder(nn.Module): 188 | def __init__(self, d_model, d_hidden, vocab_size, n_layers, n_heads, 189 | drop_ratio): 190 | super(Decoder, self).__init__() 191 | self.layers = nn.ModuleList( 192 | [DecoderLayer(d_model, d_hidden, n_heads, drop_ratio) 193 | for i in range(n_layers)]) 194 | self.out = nn.Linear(d_model, vocab_size) 195 | self.dropout = nn.Dropout(drop_ratio) 196 | self.d_model = d_model 197 | self.d_out = vocab_size 198 | 199 | def forward(self, x, encoding): 200 | """ 201 | Args: 202 | x: (N, Lt) 203 | encoding: [(N, Lv, D), ] * num_hidden_layers 204 | 205 | """ 206 | x = F.embedding(x, self.out.weight * math.sqrt(self.d_model)) # (N, Lt, D) 207 | x = x + positional_encodings_like(x) # (N, Lt, D) 208 | x = self.dropout(x) # (N, Lt, D) 209 | for layer, enc in zip(self.layers, encoding): 210 | x = layer(x, enc) # (N, Lt, D) 211 | return x # (N, Lt, D) at last layer 212 | 213 | 214 | class MTransformer(nn.Module): 215 | def __init__(self, config): 216 | super(MTransformer, self).__init__() 217 | self.config = config 218 | vfeat_size = config.video_feature_size 219 | d_model = config.hidden_size # 1024 220 | d_hidden = config.intermediate_size # 2048 221 | n_layers = config.num_hidden_layers # 6 222 | n_heads = config.num_attention_heads # 8 223 | drop_ratio = config.hidden_dropout_prob # 0.1 224 | self.vocab_size = config.vocab_size 225 | self.encoder = Encoder(vfeat_size, d_model, d_hidden, n_layers, 226 | n_heads, drop_ratio) 227 | self.decoder = Decoder(d_model, d_hidden, self.vocab_size, 228 | n_layers, n_heads, drop_ratio) 229 | self.loss_func = LabelSmoothingLoss(config.label_smoothing, config.vocab_size, ignore_index=-1) \ 230 | if "label_smoothing" in config and config.label_smoothing > 0 else nn.CrossEntropyLoss(ignore_index=-1) 231 | 232 | def encode(self, video_features, video_masks): 233 | """ 234 | Args: 235 | video_features: (N, Lv, Dv) 236 | video_masks: (N, Lv) with 1 indicates valid bits 237 | """ 238 | return self.encoder(video_features, video_masks) 239 | 240 | def decode(self, text_input_ids, text_masks, text_input_labels, encoder_outputs, video_masks): 241 | """ 242 | Args: 243 | text_input_ids: (N, Lt) 244 | text_masks: (N, Lt) with 1 indicates valid bits, 245 | text_input_labels: (N, Lt) with `-1` on ignored positions 246 | encoder_outputs: (N, Lv, D) 247 | video_masks: not used, leave here to maintain a common API with untied model 248 | """ 249 | # the triangular mask is generated and applied inside the attention module 250 | h = self.decoder(text_input_ids, encoder_outputs) # (N, Lt, D) 251 | prediction_scores = self.decoder.out(h) # (N, Lt, vocab_size) 252 | caption_loss = self.loss_func(prediction_scores.view(-1, self.config.vocab_size), 253 | text_input_labels.view(-1)) # float 254 | return caption_loss, prediction_scores 255 | 256 | def forward(self, video_features, video_masks, text_input_ids, text_masks, text_input_labels): 257 | """ 258 | Args: 259 | video_features: (N, Lv, Dv) 260 | video_masks: (N, Lv) with 1 indicates valid bits 261 | text_input_ids: (N, Lt) 262 | text_masks: (N, Lt) with 1 indicates valid bits 263 | text_input_labels: (N, Lt) with `-1` on ignored positions (in some sense duplicate with text_masks) 264 | """ 265 | encoder_layer_outputs = self.encode(video_features, video_masks) # [(N, Lv, D), ] * num_hidden_layers 266 | caption_loss, prediction_scores = self.decode( 267 | text_input_ids, text_masks, text_input_labels, encoder_layer_outputs, None) # float, (N, Lt, vocab_size) 268 | return caption_loss, prediction_scores 269 | -------------------------------------------------------------------------------- /src/rtransformer/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | import logging 23 | import abc 24 | import sys 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | if sys.version_info >= (3, 4): 30 | ABC = abc.ABC 31 | else: 32 | ABC = abc.ABCMeta('ABC', (), {}) 33 | 34 | 35 | class _LRSchedule(ABC): 36 | """ Parent of all LRSchedules here. """ 37 | warn_t_total = False # is set to True for schedules where progressing beyond t_total steps doesn't make sense 38 | def __init__(self, warmup=0.002, t_total=-1, **kw): 39 | """ 40 | :param warmup: what fraction of t_total steps will be used for linear warmup 41 | :param t_total: how many training steps (updates) are planned 42 | :param kw: 43 | """ 44 | super(_LRSchedule, self).__init__(**kw) 45 | if t_total < 0: 46 | logger.warning("t_total value of {} results in schedule not being applied".format(t_total)) 47 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 48 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 49 | warmup = max(warmup, 0.) 50 | self.warmup, self.t_total = float(warmup), float(t_total) 51 | self.warned_for_t_total_at_progress = -1 52 | 53 | def get_lr(self, step, nowarn=False): 54 | """ 55 | :param step: which of t_total steps we're on 56 | :param nowarn: set to True to suppress warning regarding training beyond specified 't_total' steps 57 | :return: learning rate multiplier for current update 58 | """ 59 | if self.t_total < 0: 60 | return 1. 61 | progress = float(step) / self.t_total 62 | ret = self.get_lr_(progress) 63 | # warning for exceeding t_total (only active with warmup_linear 64 | if not nowarn and self.warn_t_total and progress > 1. and progress > self.warned_for_t_total_at_progress: 65 | logger.warning( 66 | "Training beyond specified 't_total'. Learning rate multiplier set to {}. Please set 't_total' of {} correctly." 67 | .format(ret, self.__class__.__name__)) 68 | self.warned_for_t_total_at_progress = progress 69 | # end warning 70 | return ret 71 | 72 | @abc.abstractmethod 73 | def get_lr_(self, progress): 74 | """ 75 | :param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress 76 | :return: learning rate multiplier for current update 77 | """ 78 | return 1. 79 | 80 | 81 | class ConstantLR(_LRSchedule): 82 | def get_lr_(self, progress): 83 | return 1. 84 | 85 | 86 | class WarmupCosineSchedule(_LRSchedule): 87 | """ 88 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 89 | Decreases learning rate from 1. to 0. over remaining `1 - warmup` steps following a cosine curve. 90 | If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup. 91 | """ 92 | warn_t_total = True 93 | def __init__(self, warmup=0.002, t_total=-1, cycles=.5, **kw): 94 | """ 95 | :param warmup: see LRSchedule 96 | :param t_total: see LRSchedule 97 | :param cycles: number of cycles. Default: 0.5, corresponding to cosine decay from 1. at progress==warmup and 0 at progress==1. 98 | :param kw: 99 | """ 100 | super(WarmupCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, **kw) 101 | self.cycles = cycles 102 | 103 | def get_lr_(self, progress): 104 | if progress < self.warmup: 105 | return progress / self.warmup 106 | else: 107 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup 108 | return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress)) 109 | 110 | 111 | class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule): 112 | """ 113 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 114 | If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying 115 | learning rate (with hard restarts). 116 | """ 117 | def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw): 118 | super(WarmupCosineWithHardRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw) 119 | assert(cycles >= 1.) 120 | 121 | def get_lr_(self, progress): 122 | if progress < self.warmup: 123 | return progress / self.warmup 124 | else: 125 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup 126 | ret = 0.5 * (1. + math.cos(math.pi * ((self.cycles * progress) % 1))) 127 | return ret 128 | 129 | 130 | class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedule): 131 | """ 132 | All training progress is divided in `cycles` (default=1.) parts of equal length. 133 | Every part follows a schedule with the first `warmup` fraction of the training steps linearly increasing from 0. to 1., 134 | followed by a learning rate decreasing from 1. to 0. following a cosine curve. 135 | """ 136 | def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw): 137 | assert(warmup * cycles < 1.) 138 | warmup = warmup * cycles if warmup >= 0 else warmup 139 | super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw) 140 | 141 | def get_lr_(self, progress): 142 | progress = progress * self.cycles % 1. 143 | if progress < self.warmup: 144 | return progress / self.warmup 145 | else: 146 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup 147 | ret = 0.5 * (1. + math.cos(math.pi * progress)) 148 | return ret 149 | 150 | 151 | class WarmupConstantSchedule(_LRSchedule): 152 | """ 153 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 154 | Keeps learning rate equal to 1. after warmup. 155 | """ 156 | def get_lr_(self, progress): 157 | if progress < self.warmup: 158 | return progress / self.warmup 159 | return 1. 160 | 161 | 162 | class WarmupLinearSchedule(_LRSchedule): 163 | """ 164 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 165 | Linearly decreases learning rate from 1. to 0. over remaining `1 - warmup` steps. 166 | """ 167 | warn_t_total = True 168 | def get_lr_(self, progress): 169 | if progress < self.warmup: 170 | return progress / self.warmup 171 | return max((progress - 1.) / (self.warmup - 1.), 0.) 172 | 173 | 174 | SCHEDULES = { 175 | None: ConstantLR, 176 | "none": ConstantLR, 177 | "warmup_cosine": WarmupCosineSchedule, 178 | "warmup_constant": WarmupConstantSchedule, 179 | "warmup_linear": WarmupLinearSchedule 180 | } 181 | 182 | 183 | class EMA(object): 184 | """ Exponential Moving Average for model parameters. 185 | references: 186 | [1] https://github.com/BangLiu/QANet-PyTorch/blob/master/model/modules/ema.py 187 | [2] https://github.com/hengruo/QANet-pytorch/blob/e2de07cd2c711d525f5ffee35c3764335d4b501d/main.py""" 188 | def __init__(self, decay): 189 | self.decay = decay 190 | self.shadow = {} 191 | self.original = {} 192 | 193 | def register(self, name, val): 194 | self.shadow[name] = val.clone() 195 | 196 | def __call__(self, model, step): 197 | decay = min(self.decay, (1 + step) / (10.0 + step)) 198 | for name, param in model.named_parameters(): 199 | if param.requires_grad: 200 | assert name in self.shadow 201 | new_average = \ 202 | (1.0 - decay) * param.data + decay * self.shadow[name] 203 | self.shadow[name] = new_average.clone() 204 | 205 | def assign(self, model): 206 | for name, param in model.named_parameters(): 207 | if param.requires_grad: 208 | assert name in self.shadow 209 | self.original[name] = param.data.clone() 210 | param.data = self.shadow[name] 211 | 212 | def resume(self, model): 213 | for name, param in model.named_parameters(): 214 | if param.requires_grad: 215 | assert name in self.shadow 216 | param.data = self.original[name] 217 | 218 | 219 | class BertAdam(Optimizer): 220 | """Implements BERT version of Adam algorithm with weight decay fix. 221 | Params: 222 | lr: learning rate 223 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 224 | t_total: total number of training steps for the learning 225 | rate schedule, -1 means constant learning rate of 1. (no warmup regardless of warmup setting). Default: -1 226 | schedule: schedule to use for the warmup (see above). 227 | Can be `'warmup_linear'`, `'warmup_constant'`, `'warmup_cosine'`, `'none'`, `None` or a `_LRSchedule` object (see below). 228 | If `None` or `'none'`, learning rate is always kept constant. 229 | Default : `'warmup_linear'` 230 | b1: Adams b1. Default: 0.9 231 | b2: Adams b2. Default: 0.999 232 | e: Adams epsilon. Default: 1e-6 233 | weight_decay: Weight decay. Default: 0.01 234 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0 235 | """ 236 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', 237 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs): 238 | if lr is not required and lr < 0.0: 239 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 240 | if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES: 241 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 242 | if not 0.0 <= b1 < 1.0: 243 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 244 | if not 0.0 <= b2 < 1.0: 245 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 246 | if not e >= 0.0: 247 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 248 | # initialize schedule object 249 | if not isinstance(schedule, _LRSchedule): 250 | schedule_type = SCHEDULES[schedule] 251 | schedule = schedule_type(warmup=warmup, t_total=t_total) 252 | else: 253 | if warmup != -1 or t_total != -1: 254 | logger.warning("warmup and t_total on the optimizer are ineffective when _LRSchedule object is provided as schedule. " 255 | "Please specify custom warmup and t_total in _LRSchedule object.") 256 | defaults = dict(lr=lr, schedule=schedule, 257 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, 258 | max_grad_norm=max_grad_norm) 259 | super(BertAdam, self).__init__(params, defaults) 260 | 261 | def get_lr(self): 262 | lr = [] 263 | for group in self.param_groups: 264 | for p in group['params']: 265 | state = self.state[p] 266 | if len(state) == 0: 267 | return [0] 268 | lr_scheduled = group['lr'] 269 | lr_scheduled *= group['schedule'].get_lr(state['step']) 270 | lr.append(lr_scheduled) 271 | return lr 272 | 273 | def step(self, closure=None): 274 | """Performs a single optimization step. 275 | 276 | Arguments: 277 | closure (callable, optional): A closure that reevaluates the model 278 | and returns the loss. 279 | """ 280 | loss = None 281 | if closure is not None: 282 | loss = closure() 283 | 284 | for group in self.param_groups: 285 | for p in group['params']: 286 | if p.grad is None: 287 | continue 288 | grad = p.grad.data 289 | if grad.is_sparse: 290 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 291 | 292 | state = self.state[p] 293 | 294 | # State initialization 295 | if len(state) == 0: 296 | state['step'] = 0 297 | # Exponential moving average of gradient values 298 | state['next_m'] = torch.zeros_like(p.data) 299 | # Exponential moving average of squared gradient values 300 | state['next_v'] = torch.zeros_like(p.data) 301 | 302 | next_m, next_v = state['next_m'], state['next_v'] 303 | beta1, beta2 = group['b1'], group['b2'] 304 | 305 | # Add grad clipping 306 | if group['max_grad_norm'] > 0: 307 | clip_grad_norm_(p, group['max_grad_norm']) 308 | 309 | # Decay the first and second moment running average coefficient 310 | # In-place operations to update the averages at the same time 311 | #next_m.mul_(beta1).add_(1 - beta1, grad) 312 | #next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 313 | next_m.mul_(beta1).add_(grad, alpha=1 - beta1) 314 | next_v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 315 | 316 | update = next_m / (next_v.sqrt() + group['e']) 317 | 318 | # Just adding the square of the weights to the loss function is *not* 319 | # the correct way of using L2 regularization/weight decay with Adam, 320 | # since that will interact with the m and v parameters in strange ways. 321 | # 322 | # Instead we want to decay the weights in a manner that doesn't interact 323 | # with the m/v parameters. This is equivalent to adding the square 324 | # of the weights to the loss with plain (non-momentum) SGD. 325 | if group['weight_decay'] > 0.0: 326 | update += group['weight_decay'] * p.data 327 | 328 | lr_scheduled = group['lr'] 329 | lr_scheduled *= group['schedule'].get_lr(state['step']) 330 | 331 | update_with_lr = lr_scheduled * update 332 | p.data.add_(-update_with_lr) 333 | 334 | state['step'] += 1 335 | 336 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 337 | # No bias correction 338 | # bias_correction1 = 1 - beta1 ** state['step'] 339 | # bias_correction2 = 1 - beta2 ** state['step'] 340 | 341 | return loss 342 | -------------------------------------------------------------------------------- /src/translate.py: -------------------------------------------------------------------------------- 1 | """ Translate input text with trained model. """ 2 | 3 | import os 4 | import torch 5 | from torch.utils.data import DataLoader 6 | import argparse 7 | from tqdm import tqdm 8 | import random 9 | import numpy as np 10 | import subprocess 11 | from collections import defaultdict 12 | 13 | from translator import Translator 14 | from rtransformer.recursive_caption_dataset import \ 15 | single_sentence_collate, prepare_batch_inputs 16 | from rtransformer.recursive_caption_dataset import RecursiveCaptionDataset as RCDataset 17 | from utils_func import load_json, merge_dicts, save_json 18 | 19 | 20 | def sort_res(res_dict): 21 | """res_dict: the submission json entry `results`""" 22 | final_res_dict = {} 23 | for k, v in res_dict.items(): 24 | final_res_dict[k] = sorted(v, key=lambda x: float(x["timestamp"][0])) 25 | return final_res_dict 26 | 27 | def run_translate(eval_data_loader, translator, opt): 28 | # submission template 29 | batch_res = {"version": "VERSION 1.0", 30 | "results": defaultdict(list), 31 | "external_data": {"used": "true", "details": "ay"}} 32 | 33 | for raw_batch in tqdm(eval_data_loader, mininterval=2, desc=" - (Translate)"): 34 | meta = raw_batch[2] # list(dict), len == bsz 35 | batched_data = prepare_batch_inputs(raw_batch[0], device=translator.device) 36 | model_inputs = [ 37 | batched_data["input_ids"], 38 | batched_data["input_ids2"], 39 | batched_data["video_feature"], 40 | batched_data['region_feature'], 41 | batched_data["input_mask"], 42 | batched_data["input_mask2"], 43 | batched_data["token_type_ids"], 44 | batched_data["token_type_ids2"], 45 | batched_data['kg_mask'], 46 | batched_data['l_mask'], 47 | ] 48 | 49 | dec_seq = translator.translate_batch( 50 | model_inputs, use_beam=opt.use_beam, recurrent=False, untied=opt.untied or opt.mtrans) 51 | 52 | # example_idx indicates which example is in the batch 53 | for example_idx, (cur_gen_sen, cur_meta) in enumerate(zip(dec_seq, meta)): 54 | cur_data = { 55 | "sentence": eval_data_loader.dataset.convert_ids_to_sentence( 56 | cur_gen_sen.cpu().tolist()), 57 | "timestamp": cur_meta["timestamp"], 58 | "gt_sentence": cur_meta["gt_sentence"] 59 | } 60 | batch_res["results"][cur_meta["name"]].append(cur_data) 61 | 62 | batch_res["results"] = sort_res(batch_res["results"]) 63 | return batch_res 64 | 65 | 66 | def get_data_loader(opt, eval_mode="val"): 67 | eval_dataset = RCDataset( 68 | dset_name=opt.dset_name, 69 | data_dir=opt.data_dir, video_feature_dir=opt.video_feature_dir, 70 | duration_file=opt.v_duration_file, 71 | word2idx_path=opt.word2idx_path, max_t_len=opt.max_t_len, 72 | max_v_len=opt.max_v_len, max_n_sen=opt.max_n_sen + 10, mode=eval_mode, 73 | recurrent=opt.recurrent, untied=opt.untied or opt.mtrans) 74 | 75 | collate_fn = single_sentence_collate 76 | eval_data_loader = DataLoader(eval_dataset, collate_fn=collate_fn, 77 | batch_size=opt.batch_size, shuffle=False, num_workers=8) 78 | return eval_data_loader 79 | 80 | 81 | def main(): 82 | parser = argparse.ArgumentParser(description="translate.py") 83 | 84 | parser.add_argument("--eval_splits", type=str, nargs="+", default=["val", ], 85 | choices=["val", "test"], help="evaluate on val/test set, yc2 only has val") 86 | parser.add_argument("--res_dir", required=True, help="path to dir containing model .pt file") 87 | parser.add_argument("--batch_size", type=int, default=100, help="batch size") 88 | 89 | # beam search configs 90 | parser.add_argument("--use_beam", action="store_true", help="use beam search, otherwise greedy search") 91 | parser.add_argument("--beam_size", type=int, default=2, help="beam size") 92 | parser.add_argument("--n_best", type=int, default=1, help="stop searching when get n_best from beam search") 93 | parser.add_argument("--min_sen_len", type=int, default=5, help="minimum length of the decoded sentences") 94 | parser.add_argument("--max_sen_len", type=int, default=30, help="maximum length of the decoded sentences") 95 | parser.add_argument("--block_ngram_repeat", type=int, default=0, help="block repetition of ngrams during decoding.") 96 | parser.add_argument("--length_penalty_name", default="none", 97 | choices=["none", "wu", "avg"], help="length penalty to use.") 98 | parser.add_argument("--length_penalty_alpha", type=float, default=0., 99 | help="Google NMT length penalty parameter (higher = longer generation)") 100 | parser.add_argument("--eval_tool_dir", type=str, default="/mnt/Pycharm_Remote/recurrent_transformer/densevid_eval") 101 | 102 | parser.add_argument("--no_cuda", action="store_true") 103 | parser.add_argument("--seed", default=2019, type=int) 104 | parser.add_argument("--debug", action="store_true") 105 | 106 | opt = parser.parse_args() 107 | opt.cuda = not opt.no_cuda 108 | 109 | # random seed 110 | random.seed(opt.seed) 111 | np.random.seed(opt.seed) 112 | torch.manual_seed(opt.seed) 113 | 114 | checkpoint = torch.load(os.path.join(opt.res_dir, "model.chkpt")) 115 | 116 | # add some of the train configs 117 | train_opt = checkpoint["opt"] # EDict(load_json(os.path.join(opt.res_dir, "model.cfg.json"))) 118 | for k in train_opt.__dict__: 119 | if k not in opt.__dict__: 120 | setattr(opt, k, getattr(train_opt, k)) 121 | print("train_opt", train_opt) 122 | 123 | decoding_strategy = "beam{}_lp_{}_la_{}".format( 124 | opt.beam_size, opt.length_penalty_name, opt.length_penalty_alpha) if opt.use_beam else "greedy" 125 | save_json(vars(opt), 126 | os.path.join(opt.res_dir, "{}_eval_cfg.json".format(decoding_strategy)), 127 | save_pretty=True) 128 | 129 | if opt.dset_name == "anet": 130 | reference_files_map = { 131 | "val": [os.path.join(opt.data_dir, e) for e in 132 | ["anet_entities_val_1_para.json", "anet_entities_val_2_para.json"]], 133 | "test": [os.path.join(opt.data_dir, e) for e in 134 | ["anet_entities_test_1_para.json", "anet_entities_test_2_para.json"]]} 135 | else: # yc2 136 | reference_files_map = {"val": [os.path.join(opt.data_dir, "yc2_val_anet_format_para.json")]} 137 | for eval_mode in opt.eval_splits: 138 | print("Start evaluating {}".format(eval_mode)) 139 | # add 10 at max_n_sen to make the inference stage use all the segments 140 | eval_data_loader = get_data_loader(opt, eval_mode=eval_mode) 141 | eval_references = reference_files_map[eval_mode] 142 | 143 | # setup model 144 | translator = Translator(opt, checkpoint) 145 | 146 | pred_file = os.path.join(opt.res_dir, "{}_pred_{}.json".format(decoding_strategy, eval_mode)) 147 | pred_file = os.path.abspath(pred_file) 148 | if not os.path.exists(pred_file): 149 | json_res = run_translate(eval_data_loader, translator, opt=opt) 150 | save_json(json_res, pred_file, save_pretty=True) 151 | else: 152 | print("Using existing prediction file at {}".format(pred_file)) 153 | 154 | # COCO language evaluation 155 | lang_file = pred_file.replace(".json", "_lang.json") 156 | eval_command = ["python", "para-evaluate.py", "-s", pred_file, "-o", lang_file, 157 | "-v", "-r"] + eval_references 158 | subprocess.call(eval_command, cwd=opt.eval_tool_dir) 159 | 160 | # basic stats 161 | stat_filepath = pred_file.replace(".json", "_stat.json") 162 | eval_stat_cmd = ["python", "get_caption_stat.py", "-s", pred_file, "-r", eval_references[0], 163 | "-o", stat_filepath, "-v"] 164 | subprocess.call(eval_stat_cmd, cwd=opt.eval_tool_dir) 165 | 166 | # repetition evaluation 167 | rep_filepath = pred_file.replace(".json", "_rep.json") 168 | eval_rep_cmd = ["python", "evaluateRepetition.py", "-s", pred_file, 169 | "-r", eval_references[0], "-o", rep_filepath] 170 | subprocess.call(eval_rep_cmd, cwd=opt.eval_tool_dir) 171 | 172 | metric_filepaths = [lang_file, stat_filepath, rep_filepath] 173 | all_metrics = merge_dicts([load_json(e) for e in metric_filepaths]) 174 | all_metrics_filepath = pred_file.replace(".json", "_all_metrics.json") 175 | save_json(all_metrics, all_metrics_filepath, save_pretty=True) 176 | 177 | print("pred_file {} lang_file {}".format(pred_file, lang_file)) 178 | print("[Info] Finished {}.".format(eval_mode)) 179 | 180 | 181 | if __name__ == "__main__": 182 | main() 183 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2022/11/12 22:30 3 | # @Author : Yaojie Shen 4 | # @Project : MM-Video 5 | # @File : __init__.py 6 | -------------------------------------------------------------------------------- /src/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2022/11/13 00:25 3 | # @Author : Yaojie Shen 4 | # @Project : MM-Video 5 | # @File : checkpoint.py 6 | 7 | 8 | import os 9 | import torch 10 | import logging 11 | from typing import * 12 | from collections import OrderedDict 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | def auto_resume(ckpt_folder): 18 | try: 19 | ckpt_files = [ckpt for ckpt in os.listdir(ckpt_folder) if ckpt.endswith(".pth")] 20 | except FileNotFoundError: 21 | ckpt_files = [] 22 | if len(ckpt_files) > 0: 23 | return max([os.path.join(ckpt_folder, file) for file in ckpt_files], key=os.path.getmtime) 24 | else: 25 | return None 26 | 27 | 28 | def save_checkpoint(ckpt_folder, epoch, model, optimizer, scheduler, config, prefix="checkpoint"): 29 | if hasattr(model, 'module'): 30 | model = model.module 31 | stat_dict = { 32 | "epoch": epoch, 33 | "model": model.state_dict(), 34 | "optimizer": optimizer.state_dict(), 35 | "scheduler": scheduler.state_dict() if scheduler is not None else "", 36 | "config": config 37 | } 38 | ckpt_path = os.path.join(ckpt_folder, f"{prefix}_{epoch}.pth") 39 | os.makedirs(ckpt_folder, exist_ok=True) 40 | torch.save(stat_dict, ckpt_path) 41 | return ckpt_path 42 | 43 | 44 | def load_checkpoint(ckpt_file, model: torch.nn.Module, optimizer: Union[torch.optim.Optimizer, None], scheduler: Any, 45 | restart_train=False, rewrite: Tuple[str, str] = None): 46 | if hasattr(model, 'module'): 47 | model = model.module 48 | state_dict = torch.load(ckpt_file, map_location="cpu") 49 | if rewrite is not None: 50 | logger.info("rewrite model checkpoint prefix: %s->%s", *rewrite) 51 | state_dict["model"] = {k.replace(*rewrite) if k.startswith(rewrite[0]) else k: v 52 | for k, v in state_dict["model"].items()} 53 | try: 54 | missing = model.load_state_dict(state_dict["model"], strict=False) 55 | logger.debug(f"checkpoint key missing: {missing}") 56 | except RuntimeError: 57 | print("fail to directly recover from checkpoint, try to match each layers...") 58 | net_dict = model.state_dict() 59 | print("find %s layers", len(state_dict["model"].items())) 60 | missing_keys = [k for k, v in state_dict["model"].items() if k not in net_dict or net_dict[k].shape != v.shape] 61 | print("missing key: %s", missing_keys) 62 | state_dict["model"] = {k: v for k, v in state_dict["model"].items() if 63 | (k in net_dict and net_dict[k].shape == v.shape)} 64 | print("resume %s layers from checkpoint", len(state_dict["model"].items())) 65 | net_dict.update(state_dict["model"]) 66 | model.load_state_dict(OrderedDict(net_dict)) 67 | 68 | if not restart_train: 69 | if optimizer is not None and state_dict["optimizer"]: 70 | optimizer.load_state_dict(state_dict["optimizer"]) 71 | if scheduler is not None and state_dict["scheduler"]: 72 | scheduler.load_state_dict(state_dict["scheduler"]) 73 | epoch = state_dict["epoch"] 74 | else: 75 | logger.info("restart train, optimizer and scheduler will not be resumed") 76 | epoch = 0 77 | 78 | del state_dict 79 | torch.cuda.empty_cache() 80 | return epoch # start epoch 81 | 82 | 83 | def save_model(model_file: str, model: torch.nn.Module): 84 | if hasattr(model, "module"): 85 | model = model.module 86 | torch.save(model.state_dict(), model_file) 87 | 88 | 89 | def load_model(model_file: str, model: torch.nn.Module, strict=True): 90 | if hasattr(model, "module"): 91 | model = model.module 92 | state_dict = torch.load(model_file, map_location="cpu") 93 | 94 | missing_keys: List[str] = [] 95 | unexpected_keys: List[str] = [] 96 | error_msgs: List[str] = [] 97 | 98 | # copy state_dict so _load_from_state_dict can modify it 99 | metadata = getattr(state_dict, '_metadata', None) 100 | state_dict = state_dict.copy() 101 | if metadata is not None: 102 | # mypy isn't aware that "_metadata" exists in state_dict 103 | state_dict._metadata = metadata # type: ignore[attr-defined] 104 | 105 | def load(module, prefix=''): 106 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) 107 | module._load_from_state_dict( 108 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 109 | for name, child in module._modules.items(): 110 | if child is not None: 111 | load(child, prefix + name + '.') 112 | 113 | load(model) 114 | del load 115 | 116 | if len(missing_keys) > 0: 117 | logger.info("Weights of {} not initialized from pretrained model: {}" 118 | .format(model.__class__.__name__, "\n " + "\n ".join(missing_keys))) 119 | if len(unexpected_keys) > 0: 120 | logger.info("Weights from pretrained model not used in {}: {}" 121 | .format(model.__class__.__name__, "\n " + "\n ".join(unexpected_keys))) 122 | if len(error_msgs) > 0: 123 | logger.info("Weights from pretrained model cause errors in {}: {}" 124 | .format(model.__class__.__name__, "\n " + "\n ".join(error_msgs))) 125 | 126 | if len(missing_keys) == 0 and len(unexpected_keys) == 0 and len(error_msgs) == 0: 127 | logger.info("All keys loaded successfully for {}".format(model.__class__.__name__)) 128 | 129 | if strict and len(error_msgs) > 0: 130 | raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( 131 | model.__class__.__name__, "\n\t".join(error_msgs))) 132 | -------------------------------------------------------------------------------- /src/utils/json.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2023/2/18 06:03 3 | # @Author : Yaojie Shen 4 | # @Project : MM-Video 5 | # @File : json.py 6 | 7 | import json 8 | 9 | 10 | def load_json(file_path): 11 | with open(file_path, "r") as f: 12 | return json.load(f) 13 | 14 | 15 | def save_json(data, filename, save_pretty=False, sort_keys=False): 16 | class MyEncoder(json.JSONEncoder): 17 | 18 | def default(self, obj): 19 | if isinstance(obj, bytes): # bytes->str 20 | return str(obj, encoding='utf-8') 21 | return json.JSONEncoder.default(self, obj) 22 | 23 | with open(filename, "w") as f: 24 | if save_pretty: 25 | f.write(json.dumps(data, cls=MyEncoder, indent=4, sort_keys=sort_keys)) 26 | else: 27 | json.dump(data, f) 28 | -------------------------------------------------------------------------------- /src/utils/logging.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2022/11/12 22:32 3 | # @Author : Yaojie Shen 4 | # @Project : MM-Video 5 | # @File : logging.py 6 | 7 | 8 | import os 9 | import logging 10 | import colorlog 11 | import torch.distributed as dist 12 | 13 | level_dict = { 14 | "critical": logging.CRITICAL, 15 | "error": logging.ERROR, 16 | "warning": logging.WARNING, 17 | "info": logging.INFO, 18 | "debug": logging.DEBUG, 19 | "notset": logging.NOTSET 20 | } 21 | 22 | 23 | # noinspection SpellCheckingInspection 24 | def setup_logging(cfg): 25 | # log file 26 | if len(str(cfg.LOG.LOGGER_FILE).split(".")) == 2: 27 | file_name, extension = str(cfg.LOG.LOGGER_FILE).split(".") 28 | log_file_debug = os.path.join(cfg.LOG.DIR, f"{file_name}_debug.{extension}") 29 | log_file_info = os.path.join(cfg.LOG.DIR, f"{file_name}_info.{extension}") 30 | elif len(str(cfg.LOG.LOGGER_FILE).split(".")) == 1: 31 | file_name = cfg.LOG.LOGGER_FILE 32 | log_file_debug = os.path.join(cfg.LOG.DIR, f"{file_name}_debug") 33 | log_file_info = os.path.join(cfg.LOG.DIR, f"{file_name}_info") 34 | else: 35 | raise ValueError("cfg.LOG.LOGGER_FILE is invalid: %s", cfg.LOG.LOGGER_FILE) 36 | logger = logging.getLogger(__name__.split(".")[0]) 37 | logger.setLevel(logging.DEBUG) 38 | logger.handlers.clear() 39 | formatter = logging.Formatter( 40 | f"[%(asctime)s][%(levelname)s]{f'[Rank {dist.get_rank()}]' if dist.is_initialized() else ''} " 41 | "%(filename)s: %(lineno)3d: %(message)s", 42 | datefmt="%m/%d %H:%M:%S", 43 | ) 44 | color_formatter = colorlog.ColoredFormatter( 45 | f"%(log_color)s%(bold)s%(levelname)-8s%(reset)s" 46 | f"%(log_color)s[%(asctime)s]" 47 | f"{f'[Rank {dist.get_rank()}]' if dist.is_initialized() else ''}" 48 | "[%(filename)s: %(lineno)3d]:%(reset)s " 49 | "%(message)s", 50 | datefmt="%m/%d %H:%M:%S", 51 | ) 52 | # log file 53 | if os.path.dirname(log_file_debug): # dir name is not empty 54 | os.makedirs(os.path.dirname(log_file_debug), exist_ok=True) 55 | # console 56 | handler_console = logging.StreamHandler() 57 | assert cfg.LOG.LOGGER_CONSOLE_LEVEL.lower() in level_dict, \ 58 | f"Log level {cfg.LOG.LOGGER_CONSOLE_LEVEL} is invalid" 59 | handler_console.setLevel(level_dict[cfg.LOG.LOGGER_CONSOLE_LEVEL.lower()]) 60 | handler_console.setFormatter(color_formatter if cfg.LOG.LOGGER_CONSOLE_COLORFUL else formatter) 61 | logger.addHandler(handler_console) 62 | # debug level 63 | handler_debug = logging.FileHandler(log_file_debug, mode="a") 64 | handler_debug.setLevel(logging.DEBUG) 65 | handler_debug.setFormatter(formatter) 66 | logger.addHandler(handler_debug) 67 | # info level 68 | handler_info = logging.FileHandler(log_file_info, mode="a") 69 | handler_info.setLevel(logging.INFO) 70 | handler_info.setFormatter(formatter) 71 | logger.addHandler(handler_info) 72 | 73 | logger.propagate = False 74 | 75 | 76 | def show_registry(): 77 | from mm_video.data.build import DATASET_REGISTRY, COLLATE_FN_REGISTER 78 | from mm_video.modeling.model import MODEL_REGISTRY 79 | from mm_video.modeling.optimizer import OPTIMIZER_REGISTRY 80 | from mm_video.modeling.loss import LOSS_REGISTRY 81 | from mm_video.modeling.meter import METER_REGISTRY 82 | 83 | logger = logging.getLogger(__name__) 84 | logger.debug(DATASET_REGISTRY) 85 | logger.debug(COLLATE_FN_REGISTER) 86 | logger.debug(MODEL_REGISTRY) 87 | logger.debug(OPTIMIZER_REGISTRY) 88 | logger.debug(LOSS_REGISTRY) 89 | logger.debug(METER_REGISTRY) 90 | -------------------------------------------------------------------------------- /src/utils/register.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2023/2/19 23:23 3 | # @Author : Yaojie Shen 4 | # @Project : MM-Video 5 | # @File : register.py 6 | 7 | 8 | 9 | class LooseRegister: 10 | def __init__(self, init_dict=None): 11 | self._dict = init_dict if init_dict is not None else {} 12 | 13 | def register(self, name, target): 14 | self._dict[name] = target 15 | 16 | def __getitem__(self, item: str): 17 | for k in self._dict.keys(): 18 | if item.startswith(k): 19 | return self._dict[k] 20 | raise KeyError(f"Key {item} not found in {list(self._dict.keys())}") 21 | 22 | def __contains__(self, key): 23 | return key in self._dict 24 | 25 | 26 | if __name__ == '__main__': 27 | LOOSE_REGISTER = LooseRegister() 28 | 29 | @LOOSE_REGISTER.register() 30 | def func_a(): 31 | print("a") 32 | 33 | 34 | def show_b(): 35 | print("b") 36 | 37 | -------------------------------------------------------------------------------- /src/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2022/11/12 22:31 3 | # @Author : Yaojie Shen 4 | # @Project : MM-Video 5 | # @File : train_utils.py 6 | 7 | import datetime 8 | import os 9 | import typing 10 | import torch 11 | import time 12 | import random 13 | import itertools 14 | import numpy as np 15 | import logging 16 | from tabulate import tabulate 17 | from collections import defaultdict 18 | from typing import * 19 | 20 | import torch.distributed as dist 21 | from fvcore.common.config import CfgNode 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | class Timer: 27 | def __init__(self, synchronize=False, history_size=1000, precision=3): 28 | self._precision = precision 29 | self._stage_index = 0 30 | self._time_info = {} 31 | self._time_history = defaultdict(list) 32 | self._history_size = history_size 33 | if synchronize: 34 | assert torch.cuda.is_available(), "cuda is not available for synchronize" 35 | self._synchronize = synchronize 36 | self._time = self._get_time() 37 | 38 | def _get_time(self): 39 | return round(time.time() * 1000, self._precision) 40 | 41 | def __call__(self, stage_name=None, reset=True): 42 | if self._synchronize: 43 | torch.cuda.synchronize(torch.cuda.current_device()) 44 | 45 | current_time = self._get_time() 46 | duration = (current_time - self._time) 47 | if reset: 48 | self._time = current_time 49 | 50 | if stage_name is None: 51 | self._time_info[self._stage_index] = duration 52 | else: 53 | self._time_info[stage_name] = duration 54 | self._time_history[stage_name] = self._time_history[stage_name][-self._history_size:] 55 | self._time_history[stage_name].append(duration) 56 | 57 | return duration 58 | 59 | def reset(self): 60 | if self._synchronize: 61 | torch.cuda.synchronize(torch.cuda.current_device()) 62 | self._time = self._get_time() 63 | 64 | def __str__(self): 65 | return str(self.get_info()) 66 | 67 | def get_info(self): 68 | info = { 69 | "current": {k: round(v, self._precision) for k, v in self._time_info.items()}, 70 | "average": {k: round(sum(v) / len(v), self._precision) for k, v in self._time_history.items()} 71 | } 72 | return info 73 | 74 | def print(self): 75 | data = [[k, round(sum(v) / len(v), self._precision)] for k, v in self._time_history.items()] 76 | print(tabulate(data, headers=["Stage", "Time (ms)"], tablefmt="simple")) 77 | 78 | 79 | class CudaPreFetcher: 80 | def __init__(self, data_loader): 81 | self.dl = data_loader 82 | self.loader = iter(data_loader) 83 | self.stream = torch.cuda.Stream() 84 | self.batch = None 85 | 86 | def preload(self): 87 | try: 88 | self.batch = next(self.loader) 89 | except StopIteration: 90 | self.batch = None 91 | return 92 | with torch.cuda.stream(self.stream): 93 | self.batch = self.cuda(self.batch) 94 | 95 | @staticmethod 96 | def cuda(x: typing.Any): 97 | if isinstance(x, list) or isinstance(x, tuple): 98 | return [CudaPreFetcher.cuda(i) for i in x] 99 | elif isinstance(x, dict): 100 | return {k: CudaPreFetcher.cuda(v) for k, v in x.items()} 101 | elif isinstance(x, torch.Tensor): 102 | return x.cuda(non_blocking=True) 103 | else: 104 | return x 105 | 106 | def __next__(self): 107 | torch.cuda.current_stream().wait_stream(self.stream) 108 | batch = self.batch 109 | if batch is None: 110 | raise StopIteration 111 | self.preload() 112 | return batch 113 | 114 | def __iter__(self): 115 | self.preload() 116 | return self 117 | 118 | def __len__(self): 119 | return len(self.dl) 120 | 121 | 122 | def manual_seed(cfg: CfgNode): 123 | if cfg.SYS.DETERMINISTIC: 124 | torch.manual_seed(cfg.SYS.SEED) 125 | random.seed(cfg.SYS.SEED) 126 | np.random.seed(cfg.SYS.SEED) 127 | torch.cuda.manual_seed(cfg.SYS.SEED) 128 | torch.backends.cudnn.deterministic = True 129 | torch.backends.cudnn.benchmark = True 130 | logger.debug("Manual seed is set") 131 | else: 132 | logger.warning("Manual seed is not used") 133 | 134 | 135 | def init_distributed(proc: int, cfg: CfgNode): 136 | if cfg.SYS.MULTIPROCESS: # initialize multiprocess 137 | word_size = cfg.SYS.NUM_GPU * cfg.SYS.NUM_SHARDS 138 | rank = cfg.SYS.NUM_GPU * cfg.SYS.SHARD_ID + proc 139 | dist.init_process_group(backend="nccl", init_method=cfg.SYS.INIT_METHOD, world_size=word_size, rank=rank) 140 | torch.cuda.set_device(cfg.SYS.GPU_DEVICES[proc]) 141 | 142 | 143 | def save_config(cfg: CfgNode): 144 | if not dist.is_initialized() or dist.get_rank() == 0: 145 | config_file = os.path.join(cfg.LOG.DIR, f"config_{get_timestamp()}.yaml") 146 | with open(config_file, "w") as f: 147 | f.write(cfg.dump()) 148 | logger.debug("config is saved to %s", config_file) 149 | 150 | 151 | def get_timestamp(): 152 | return datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S') 153 | 154 | 155 | def gather_object_multiple_gpu(list_object: List[Any]): 156 | """ 157 | gather a list of something from multiple GPU 158 | :param list_object: 159 | """ 160 | gathered_objects = [None for _ in range(dist.get_world_size())] 161 | dist.all_gather_object(gathered_objects, list_object) 162 | return list(itertools.chain(*gathered_objects)) 163 | -------------------------------------------------------------------------------- /src/utils/writer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2022/11/13 00:41 3 | # @Author : Yaojie Shen 4 | # @Project : MM-Video 5 | # @File : writer.py 6 | 7 | from torch.utils.tensorboard import SummaryWriter 8 | import torch.distributed as dist 9 | 10 | 11 | class DummySummaryWriter: 12 | """ 13 | Issue: https://github.com/pytorch/pytorch/issues/24236 14 | """ 15 | def __init__(*args, **kwargs): 16 | pass 17 | 18 | def __call__(self, *args, **kwargs): 19 | return self 20 | 21 | def __getattr__(self, *args, **kwargs): 22 | return self 23 | 24 | 25 | def get_writer(*args, **kwargs): 26 | if not dist.is_initialized() or dist.get_rank() == 0: 27 | return SummaryWriter(*args, **kwargs) 28 | else: 29 | return DummySummaryWriter() 30 | -------------------------------------------------------------------------------- /src/utils_func.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | # !/usr/bin/env python3 4 | # -*- coding: utf-8 -*- 5 | 6 | import json 7 | 8 | 9 | class MyEncoder(json.JSONEncoder): 10 | 11 | def default(self, obj): 12 | """ 13 | 只要检查到了是bytes类型的数据就把它转为str类型 14 | :param obj: 15 | :return: 16 | """ 17 | if isinstance(obj, bytes): 18 | return str(obj, encoding='utf-8') 19 | return json.JSONEncoder.default(self, obj) 20 | 21 | 22 | def save_json(data, filename, save_pretty=False, sort_keys=False): 23 | with open(filename, "w") as f: 24 | if save_pretty: 25 | f.write(json.dumps(data, cls=MyEncoder, indent=4, sort_keys=sort_keys)) 26 | else: 27 | json.dump(data, f) 28 | 29 | 30 | def save_parsed_args_to_json(parsed_args, file_path, pretty=True): 31 | args_dict = vars(parsed_args) 32 | save_json(args_dict, file_path, save_pretty=pretty) 33 | 34 | 35 | def load_json(file_path): 36 | with open(file_path, "r") as f: 37 | return json.load(f) 38 | 39 | 40 | def set_lr(optimizer, decay_factor): 41 | for group in optimizer.param_groups: 42 | group["lr"] = group["lr"] * decay_factor 43 | 44 | 45 | def flat_list_of_lists(l): 46 | """flatten a list of lists [[1,2], [3,4]] to [1,2,3,4]""" 47 | return [item for sublist in l for item in sublist] 48 | 49 | 50 | def count_parameters(model, verbose=True): 51 | """Count number of parameters in PyTorch model, 52 | References: https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/7. 53 | 54 | from utils.utils import count_parameters 55 | count_parameters(model) 56 | import sys 57 | sys.exit(1) 58 | """ 59 | n_all = sum(p.numel() for p in model.parameters()) 60 | n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) 61 | #if verbose: 62 | # print("Parameter Count: all {:,d}; trainable {:,d}".format(n_all, n_trainable)) 63 | return n_all, n_trainable 64 | 65 | 66 | def sum_parameters(model, verbose=True): 67 | """Count number of parameters in PyTorch model, 68 | References: https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/7. 69 | 70 | from utils.utils import count_parameters 71 | count_parameters(model) 72 | import sys 73 | sys.exit(1) 74 | """ 75 | p_sum = sum(p.sum().item() for p in model.parameters()) 76 | if verbose: 77 | print("Parameter sum {}".format(p_sum)) 78 | return p_sum 79 | 80 | 81 | def merge_dicts(list_dicts): 82 | merged_dict = list_dicts[0].copy() 83 | for i in range(1, len(list_dicts)): 84 | merged_dict.update(list_dicts[i]) 85 | return merged_dict 86 | 87 | 88 | def merge_json_files(paths, merged_path): 89 | merged_dict = merge_dicts([load_json(e) for e in paths]) 90 | save_json(merged_dict, merged_path) 91 | 92 | 93 | --------------------------------------------------------------------------------